Practical Examples

Practical code patterns and examples for common AI application scenarios.

Practical Recipes

1. Chatbot Loop with a Fallback Chain

use ai_lib::{AiClient, Provider, ChatCompletionRequest, Message, Content, Role};
use std::collections::VecDeque;

pub struct Chatbot {
    clients: Vec<AiClient>,
    conversation_history: VecDeque<Message>,
    max_history: usize,
}

impl Chatbot {
    pub fn new(providers: Vec<Provider>) -> Result<Self, Box<dyn std::error::Error>> {
        let clients = providers
            .into_iter()
            .map(|provider| AiClient::new(provider))
            .collect::<Result<Vec<_>, _>>()?;
        
        Ok(Self {
            clients,
            conversation_history: VecDeque::new(),
            max_history: 10,
        })
    }
    
    pub async fn chat(&mut self, user_input: &str) -> Result<String, Box<dyn std::error::Error>> {
// Append the user message to the history
        self.conversation_history.push_back(Message::user(Content::new_text(user_input.to_string())));
        
// Cap the history length
        if self.conversation_history.len() > self.max_history {
            self.conversation_history.pop_front();
        }
        
// Try each client in order until one succeeds
        for client in &self.clients {
            let request = ChatCompletionRequest::new(
                client.default_chat_model(),
                self.conversation_history.iter().cloned().collect()
            );
            
            match client.chat_completion(request).await {
                Ok(response) => {
                    let assistant_message = response.first_text()?;
                    
// Append the assistant reply to the history
                    self.conversation_history.push_back(Message {
                        role: Role::Assistant,
                        content: Content::new_text(assistant_message.clone()),
                        function_call: None,
                    });
                    
                    return Ok(assistant_message);
                }
                Err(e) => {
                    eprintln!("Provider failed: {}", e);
                    continue;
                }
            }
        }
        
        Err("All providers failed".into())
    }
}
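
A minimal usage sketch for the Chatbot above. The concrete Provider variants (Provider::Groq, Provider::OpenAI) are placeholders for whichever providers you have credentials for; substitute your own and export the matching API keys.

// Hypothetical usage; the Provider variants below are placeholders.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Order matters: the first provider is tried first, the rest serve as fallbacks.
    let mut bot = Chatbot::new(vec![Provider::Groq, Provider::OpenAI])?;

    let reply = bot.chat("Hello, who are you?").await?;
    println!("assistant: {}", reply);

    // The second turn reuses the accumulated conversation history.
    let reply = bot.chat("Summarize what I just asked.").await?;
    println!("assistant: {}", reply);

    Ok(())
}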

2. Tool Call Execution Loop

use ai_lib::{AiClient, Provider, ChatCompletionRequest, Message, Content, Role, Tool, FunctionCall, FunctionCallPolicy};
use serde_json::Value;

pub struct ToolExecutor {
    client: AiClient,
    tools: Vec<Tool>,
}

impl ToolExecutor {
    pub fn new(provider: Provider) -> Result<Self, Box<dyn std::error::Error>> {
        let client = AiClient::new(provider)?;
        Ok(Self {
            client,
            tools: Vec::new(),
        })
    }
    
    pub fn add_tool(&mut self, tool: Tool) {
        self.tools.push(tool);
    }
    
    pub async fn execute_with_tools(&self, user_input: &str) -> Result<String, Box<dyn std::error::Error>> {
        let mut conversation = vec![Message::user(Content::new_text(user_input.to_string()))];
        
        loop {
            let request = ChatCompletionRequest::new(
                self.client.default_chat_model(),
                conversation.clone()
            )
            .with_functions(self.tools.clone())
            .with_function_call(FunctionCallPolicy::Auto);
            
            let response = self.client.chat_completion(request).await?;
            
            if let Some(choice) = response.choices.first() {
                if let Some(function_call) = &choice.message.function_call {
// Execute the tool call
                    let tool_result = self.execute_tool(function_call).await?;
                    
// Append the assistant's tool call and its result to the conversation
                    conversation.push(choice.message.clone());
                    conversation.push(Message {
                        role: Role::Assistant,
                        content: Content::new_json(tool_result),
                        function_call: None,
                    });
                } else {
// No tool call: return the final answer
                    return Ok(choice.message.content.as_text());
                }
            } else {
                return Err("No response from model".into());
            }
        }
    }
    
    async fn execute_tool(&self, function_call: &FunctionCall) -> Result<Value, Box<dyn std::error::Error>> {
        match function_call.name.as_str() {
            "get_weather" => {
                let location = function_call.arguments
                    .get("location")
                    .and_then(|v| v.as_str())
                    .unwrap_or("unknown");
                
// Simulated weather API call
                Ok(serde_json::json!({
                    "temperature": 22.5,
                    "condition": "sunny",
                    "location": location
                }))
            }
            "get_news" => {
                let topic = function_call.arguments
                    .get("topic")
                    .and_then(|v| v.as_str())
                    .unwrap_or("general");
                
                Ok(serde_json::json!({
                    "headlines": [
                        format!("Breaking: {} news update 1", topic),
                        format!("Latest: {} news update 2", topic),
                    ]
                }))
            }
            _ => Err(format!("Unknown tool: {}", function_call.name).into()),
        }
    }
}
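
The string-keyed lookups in execute_tool are fine for simple tools. For tools with richer arguments, the same serde_json::Value that those lookups read from can be deserialized into a typed struct. A hedged sketch follows; WeatherArgs and parse_weather_args are hypothetical names introduced only for illustration, and serde's derive feature is assumed.

use serde::Deserialize;
use serde_json::Value;

// Hypothetical argument type for the "get_weather" tool shown above.
#[derive(Deserialize)]
struct WeatherArgs {
    location: String,
    #[serde(default)]
    unit: Option<String>,
}

// Assumes `arguments` is the same serde_json::Value the dispatcher above reads from.
fn parse_weather_args(arguments: &Value) -> Result<WeatherArgs, Box<dyn std::error::Error>> {
    // from_value reports a clear error when a required field is missing or mistyped.
    let args: WeatherArgs = serde_json::from_value(arguments.clone())?;
    Ok(args)
}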

3. Multimodal Image + Text Requests

use ai_lib::{AiClient, Provider, ChatCompletionRequest, Message, Content};

pub struct MultimodalAnalyzer {
    client: AiClient,
}

impl MultimodalAnalyzer {
    pub fn new(provider: Provider) -> Result<Self, Box<dyn std::error::Error>> {
        let client = AiClient::new(provider)?;
        Ok(Self { client })
    }
    
    pub async fn analyze_image(&self, image_url: &str, prompt: &str) -> Result<String, Box<dyn std::error::Error>> {
        let message = Message::user(Content::new_image(
            Some(image_url.to_string()),
            Some("image/jpeg".to_string()),
            Some("image.jpg".to_string())
        ));
        
        let text_message = Message::user(Content::new_text(prompt.to_string()));
        
        let request = ChatCompletionRequest::new(
            self.client.default_chat_model(),
            vec![message, text_message]
        );
        
        let response = self.client.chat_completion(request).await?;
        Ok(response.first_text()?)
    }
    
    pub async fn compare_images(&self, image1_url: &str, image2_url: &str) -> Result<String, Box<dyn std::error::Error>> {
        let message1 = Message::user(Content::new_image(
            Some(image1_url.to_string()),
            Some("image/jpeg".to_string()),
            Some("image1.jpg".to_string())
        ));
        
        let message2 = Message::user(Content::new_image(
            Some(image2_url.to_string()),
            Some("image/jpeg".to_string()),
            Some("image2.jpg".to_string())
        ));
        
        let text_message = Message::user(Content::new_text("请比较这两张图片的相似性和差异"));
        
        let request = ChatCompletionRequest::new(
            self.client.default_chat_model(),
            vec![message1, message2, text_message]
        );
        
        let response = self.client.chat_completion(request).await?;
        Ok(response.first_text()?)
    }
}
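
A short usage sketch for MultimodalAnalyzer. The Provider variant and the image URL are placeholders; use a provider and model that actually accept image input.

// Hypothetical usage; Provider::OpenAI and the URL are placeholders.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let analyzer = MultimodalAnalyzer::new(Provider::OpenAI)?;

    let description = analyzer
        .analyze_image(
            "https://example.com/cat.jpg",
            "Describe what is in this image.",
        )
        .await?;

    println!("{}", description);
    Ok(())
}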

4. Latency Racing (Hedging) Example

use ai_lib::{AiClient, Provider, ChatCompletionRequest, Message, Content};
use tokio::time::{timeout, Duration};
use futures::future::select_all;

pub struct HedgingClient {
    primary_client: AiClient,
    fallback_clients: Vec<AiClient>,
    hedge_delay: Duration,
}

impl HedgingClient {
    pub fn new(primary: Provider, fallbacks: Vec<Provider>) -> Result<Self, Box<dyn std::error::Error>> {
        let primary_client = AiClient::new(primary)?;
        let fallback_clients = fallbacks
            .into_iter()
            .map(|provider| AiClient::new(provider))
            .collect::<Result<Vec<_>, _>>()?;
        
        Ok(Self {
            primary_client,
            fallback_clients,
            hedge_delay: Duration::from_millis(100),
        })
    }
    
    pub async fn chat_with_hedging(&self, user_input: &str) -> Result<String, Box<dyn std::error::Error>> {
        let request = ChatCompletionRequest::new(
            self.primary_client.default_chat_model(),
            vec![Message::user(Content::new_text(user_input.to_string()))]
        );
        
// Start the primary request immediately
        let primary_future = self.primary_client.chat_completion(request.clone());
        
// Start the hedged fallback requests after a delay
        let hedge_future = async {
            tokio::time::sleep(self.hedge_delay).await;
            
            // select_all requires Unpin futures, so box them; note that it
            // panics on an empty set, so at least one fallback client is assumed.
            let mut futures = Vec::new();
            for client in &self.fallback_clients {
                futures.push(Box::pin(client.chat_completion(request.clone())));
            }
            
            select_all(futures).await
        };
        
// Take whichever request finishes first
        tokio::select! {
            result = primary_future => {
                match result {
                    Ok(response) => Ok(response.first_text()?),
                    Err(e) => {
                        eprintln!("Primary request failed: {}", e);
                        Err(e.into())
                    }
                }
            }
            (hedge_result, _index, _remaining) = hedge_future => {
                match hedge_result {
                    Ok(response) => {
                        println!("Hedged request succeeded");
                        Ok(response.first_text()?)
                    }
                    Err(e) => {
                        eprintln!("First hedged request failed: {}", e);
                        Err(e.into())
                    }
                }
            }
        }
    }
}
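
A usage sketch for HedgingClient. The 100 ms hedge delay above is only a starting point: in practice it is usually set near the gap between the primary provider's typical and worst-case latency, so hedged requests fire only when the primary is unusually slow. The Provider variants below are placeholders.

// Hypothetical usage; the Provider variants are placeholders.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = HedgingClient::new(
        Provider::OpenAI,                          // primary
        vec![Provider::Groq, Provider::DeepSeek],  // hedged fallbacks
    )?;

    let answer = client
        .chat_with_hedging("Explain request hedging in one sentence.")
        .await?;
    println!("{}", answer);
    Ok(())
}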

5. Cost-Aware Routing Strategy

use ai_lib::{AiClient, Provider, ChatCompletionRequest, ChatCompletionResponse, Message, Content, ModelArray, ModelEndpoint, LoadBalancingStrategy};
use std::collections::HashMap;

pub struct CostAwareRouter {
    model_array: ModelArray,
    cost_per_token: HashMap<String, f64>,
}

impl CostAwareRouter {
    pub fn new() -> Self {
        let mut model_array = ModelArray::new("cost-aware")
            .with_strategy(LoadBalancingStrategy::Weighted);
        
// Register models with different costs
        model_array.add_endpoint(ModelEndpoint {
            name: "gpt-4".into(),
            url: "https://api.openai.com".into(),
            weight: 0.3, // high cost, low weight
            healthy: true,
        });
        
        model_array.add_endpoint(ModelEndpoint {
            name: "gpt-3.5-turbo".into(),
            url: "https://api.openai.com".into(),
            weight: 0.7, // low cost, high weight
            healthy: true,
        });
        
        let mut cost_per_token = HashMap::new();
        cost_per_token.insert("gpt-4".to_string(), 0.00003);
        cost_per_token.insert("gpt-3.5-turbo".to_string(), 0.000002);
        
        Self {
            model_array,
            cost_per_token,
        }
    }
    
    pub async fn route_request(&self, request: &ChatCompletionRequest) -> Result<String, Box<dyn std::error::Error>> {
// Select a model based on the request's estimated complexity
        let complexity = self.estimate_complexity(request);
        let model = self.select_model_by_complexity(complexity);
        
        let client = AiClient::with_model_array(self.model_array.clone())?;
        let response = client.chat_completion(request.clone()).await?;
        
// Record the cost
        let cost = self.calculate_cost(&model, &response);
        println!("Request cost: ${:.4}", cost);
        
        Ok(response.first_text()?)
    }
    
    fn estimate_complexity(&self, request: &ChatCompletionRequest) -> f64 {
        let total_chars = request.messages
            .iter()
            .map(|msg| msg.content.as_text().len())
            .sum::<usize>() as f64;
        
        // Rough complexity estimate: prompt length in characters as a proxy for token count
        total_chars / 1000.0
    }
    
    fn select_model_by_complexity(&self, complexity: f64) -> String {
        if complexity > 0.5 {
            "gpt-4".to_string() // 复杂请求使用GPT-4
        } else {
            "gpt-3.5-turbo".to_string() // 简单请求使用GPT-3.5
        }
    }
    
    fn calculate_cost(&self, model: &str, response: &ChatCompletionResponse) -> f64 {
        let cost_per_token = self.cost_per_token.get(model).unwrap_or(&0.0);
        let total_tokens = response.usage.total_tokens as f64;
        cost_per_token * total_tokens
    }
}
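
To make the thresholds concrete: estimate_complexity divides total prompt length by 1000, so a roughly 1,200-character prompt scores 1.2 (> 0.5) and routes to gpt-4, while a roughly 200-character prompt scores 0.2 and stays on gpt-3.5-turbo; at the listed rate, a 1,500-token exchange on gpt-3.5-turbo costs about 1500 × $0.000002 = $0.003. A quick sanity-check sketch, assumed to live in the same module as CostAwareRouter since the helpers are private:

// Sanity check of the routing threshold and cost math used above.
fn main() {
    let router = CostAwareRouter::new();

    // ~1,200 characters -> complexity 1.2 -> routed to gpt-4.
    assert_eq!(router.select_model_by_complexity(1.2), "gpt-4");
    // ~200 characters -> complexity 0.2 -> routed to gpt-3.5-turbo.
    assert_eq!(router.select_model_by_complexity(0.2), "gpt-3.5-turbo");

    // 1,500 tokens at $0.000002 per token.
    let cost = 1500.0 * 0.000002;
    println!("gpt-3.5-turbo cost: ${:.4}", cost); // $0.0030
}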

6. Batch Processing Optimization

use ai_lib::{AiClient, Provider, ChatCompletionRequest, Message, Content};
use tokio::sync::Semaphore;
use std::sync::Arc;

pub struct BatchProcessor {
    client: AiClient,
    semaphore: Arc<Semaphore>,
    batch_size: usize,
}

impl BatchProcessor {
    pub fn new(provider: Provider, max_concurrency: usize, batch_size: usize) -> Result<Self, Box<dyn std::error::Error>> {
        let client = AiClient::new(provider)?;
        let semaphore = Arc::new(Semaphore::new(max_concurrency));
        
        Ok(Self {
            client,
            semaphore,
            batch_size,
        })
    }
    
    pub async fn process_batch(&self, requests: Vec<ChatCompletionRequest>) -> Result<Vec<String>, Box<dyn std::error::Error>> {
        let mut results = Vec::new();
        let chunks = requests.chunks(self.batch_size);
        
        for chunk in chunks {
            let chunk_results = self.process_chunk(chunk.to_vec()).await?;
            results.extend(chunk_results);
        }
        
        Ok(results)
    }
    
    async fn process_chunk(&self, requests: Vec<ChatCompletionRequest>) -> Result<Vec<String>, Box<dyn std::error::Error>> {
        let mut handles = Vec::new();
        
        for request in requests {
            let permit = self.semaphore.clone().acquire_owned().await?;
            let client = self.client.clone();
            
            let handle = tokio::spawn(async move {
                let _permit = permit;
                client.chat_completion(request).await
            });
            
            handles.push(handle);
        }
        
        let mut results = Vec::new();
        for handle in handles {
            let response = handle.await??;
            results.push(response.first_text()?);
        }
        
        Ok(results)
    }
}
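
A usage sketch for BatchProcessor: one ChatCompletionRequest per input, processed in chunks with bounded concurrency. Provider::Groq is a placeholder, and a throwaway AiClient is used here only to obtain a default model name for the requests.

// Hypothetical usage; Provider::Groq is a placeholder.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let processor = BatchProcessor::new(Provider::Groq, 4, 8)?; // 4 in flight, chunks of 8
    let helper = AiClient::new(Provider::Groq)?; // only used to pick a model name

    let inputs = vec!["Translate 'hello' into French", "Translate 'hello' into German"];
    let requests: Vec<ChatCompletionRequest> = inputs
        .into_iter()
        .map(|text| {
            ChatCompletionRequest::new(
                helper.default_chat_model(),
                vec![Message::user(Content::new_text(text.to_string()))],
            )
        })
        .collect();

    for answer in processor.process_batch(requests).await? {
        println!("{}", answer);
    }
    Ok(())
}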

7. Streaming Optimization

use ai_lib::{AiClient, Provider, ChatCompletionRequest, Message, Content};
use futures::StreamExt;
use tokio::io::{AsyncWrite, AsyncWriteExt};

pub struct StreamingProcessor {
    client: AiClient,
}

impl StreamingProcessor {
    pub fn new(provider: Provider) -> Result<Self, Box<dyn std::error::Error>> {
        let client = AiClient::new(provider)?;
        Ok(Self { client })
    }
    
    pub async fn stream_to_writer<W: AsyncWrite + Unpin>(
        &self,
        request: ChatCompletionRequest,
        writer: &mut W,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let mut stream = self.client.chat_completion_stream(request).await?;
        
        while let Some(chunk) = stream.next().await {
            let chunk = chunk?;
            if let Some(delta) = chunk.choices.first().and_then(|c| c.delta.content.clone()) {
                writer.write_all(delta.as_bytes()).await?;
                writer.flush().await?;
            }
        }
        
        Ok(())
    }
    
    pub async fn stream_with_callback<F>(
        &self,
        request: ChatCompletionRequest,
        mut callback: F,
    ) -> Result<(), Box<dyn std::error::Error>>
    where
        F: FnMut(&str) -> Result<(), Box<dyn std::error::Error>>,
    {
        let mut stream = self.client.chat_completion_stream(request).await?;
        
        while let Some(chunk) = stream.next().await {
            let chunk = chunk?;
            if let Some(delta) = chunk.choices.first().and_then(|c| c.delta.content.clone()) {
                callback(&delta)?;
            }
        }
        
        Ok(())
    }
}
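
A usage sketch that streams a completion straight to standard output: tokio::io::stdout() implements AsyncWrite, so it can be passed to stream_to_writer directly. Provider::Groq is a placeholder.

// Hypothetical usage; Provider::Groq is a placeholder.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let processor = StreamingProcessor::new(Provider::Groq)?;
    let helper = AiClient::new(Provider::Groq)?; // only used to pick a model name

    let request = ChatCompletionRequest::new(
        helper.default_chat_model(),
        vec![Message::user(Content::new_text("Write a short haiku about rivers.".to_string()))],
    );

    // Tokens are written to stdout as they arrive.
    let mut stdout = tokio::io::stdout();
    processor.stream_to_writer(request, &mut stdout).await?;
    Ok(())
}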

Best Practices

Error Handling

use ai_lib::{AiClient, Provider, ChatCompletionRequest, Message, Content};
use std::time::Duration;
use tokio::time::timeout;

pub async fn robust_chat_completion(
    client: &AiClient,
    request: ChatCompletionRequest,
    max_retries: usize,
) -> Result<String, Box<dyn std::error::Error>> {
    let mut last_error: Option<Box<dyn std::error::Error>> = None;
    
    for attempt in 0..max_retries {
        match timeout(Duration::from_secs(30), client.chat_completion(request.clone())).await {
            Ok(Ok(response)) => return Ok(response.first_text()?),
            Ok(Err(e)) => {
                last_error = Some(e.into());
                if attempt < max_retries - 1 {
                    tokio::time::sleep(Duration::from_millis(1000 * (attempt + 1) as u64)).await;
                }
            }
            Err(_) => {
                last_error = Some("Request timeout".into());
                if attempt < max_retries - 1 {
                    tokio::time::sleep(Duration::from_millis(1000 * (attempt + 1) as u64)).await;
                }
            }
        }
    }
    
    Err(last_error.unwrap_or("All retries failed".into()))
}
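
Calling the retry helper is straightforward; the sketch below allows up to three attempts, each capped at 30 seconds, with the linear backoff defined above. Provider::Groq is a placeholder.

// Hypothetical usage of robust_chat_completion; Provider::Groq is a placeholder.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = AiClient::new(Provider::Groq)?;
    let request = ChatCompletionRequest::new(
        client.default_chat_model(),
        vec![Message::user(Content::new_text("Ping".to_string()))],
    );

    let answer = robust_chat_completion(&client, request, 3).await?;
    println!("{}", answer);
    Ok(())
}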

Performance Monitoring

use ai_lib::{AiClient, Provider, ChatCompletionRequest, Message, Content};
use std::time::{Duration, Instant};

pub async fn monitored_chat_completion(
    client: &AiClient,
    request: ChatCompletionRequest,
) -> Result<(String, Duration), Box<dyn std::error::Error>> {
    let start = Instant::now();
    
    let response = client.chat_completion(request).await?;
    let duration = start.elapsed();
    
    println!("Request completed in {:?}", duration);
    
    Ok((response.first_text()?, duration))
}

