Files
AI-CS/backend/service/ai_service.go
T
2026-03-25 18:50:58 +08:00

790 lines
28 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"log"
"strings"
"time"
"github.com/2930134478/AI-CS/backend/infra"
"github.com/2930134478/AI-CS/backend/infra/search"
"github.com/2930134478/AI-CS/backend/models"
"github.com/2930134478/AI-CS/backend/repository"
"github.com/2930134478/AI-CS/backend/service/rag"
"github.com/2930134478/AI-CS/backend/utils"
"gorm.io/gorm"
)
// AIService AI 服务(负责调用 AI 生成回复)
type AIService struct {
aiConfigRepo *repository.AIConfigRepository
messageRepo *repository.MessageRepository
conversationRepo *repository.ConversationRepository
retrievalService *rag.RetrievalService
providerFactory *AIProviderFactory
webSearchProvider search.WebSearchProvider // 可选,自建联网时用
embeddingConfigSvc *EmbeddingConfigService // 读取联网方式:厂商内置 / 自建
promptConfigSvc *PromptConfigService // 可选,提示词配置(为空则用代码内默认)
storageService infra.StorageService // 可选,用于多模态识图时读取消息附件
systemLogSvc *SystemLogService // 可选,结构化日志服务
}
// NewAIService 创建 AI 服务实例。webSearchProvider、storageService 可为 nil。
func NewAIService(
aiConfigRepo *repository.AIConfigRepository,
messageRepo *repository.MessageRepository,
conversationRepo *repository.ConversationRepository,
retrievalService *rag.RetrievalService,
webSearchProvider search.WebSearchProvider,
embeddingConfigSvc *EmbeddingConfigService,
promptConfigSvc *PromptConfigService,
storageService infra.StorageService,
systemLogSvc *SystemLogService,
) *AIService {
return &AIService{
aiConfigRepo: aiConfigRepo,
messageRepo: messageRepo,
conversationRepo: conversationRepo,
retrievalService: retrievalService,
providerFactory: NewAIProviderFactory(),
webSearchProvider: webSearchProvider,
embeddingConfigSvc: embeddingConfigSvc,
promptConfigSvc: promptConfigSvc,
storageService: storageService,
systemLogSvc: systemLogSvc,
}
}
// GenerateAIResponse 为对话生成 AI 回复(兼容旧调用,使用默认数据源选项)。
// 返回: AI 回复内容,若失败返回错误。
func (s *AIService) GenerateAIResponse(conversationID uint, userMessage string, userID uint) (string, error) {
res, err := s.GenerateAIResponseWithOptions(conversationID, userMessage, userID, nil)
if err != nil {
return "", err
}
return res.Content, nil
}
// GenerateAIResponseWithOptions 根据数据源开关生成一条合成回复,并返回使用的来源标记。
// opts 为 nil 时使用默认:知识库+大模型开,联网关。
func (s *AIService) GenerateAIResponseWithOptions(conversationID uint, userMessage string, userID uint, opts *GenerateAIResponseInput) (*GenerateAIResponseResult, error) {
useKB := true
useLLM := true
useWeb := false
needWeb := false
if opts != nil {
if opts.UseKnowledgeBase != nil {
useKB = *opts.UseKnowledgeBase
}
if opts.UseLLM != nil {
useLLM = *opts.UseLLM
}
if opts.UseWebSearch != nil {
useWeb = *opts.UseWebSearch
}
needWeb = opts.NeedWebSearch
}
conversation, err := s.conversationRepo.GetByID(conversationID)
if err != nil {
return nil, fmt.Errorf("获取对话失败: %v", err)
}
// 以下 config 为「AI 配置」:对话/联网均使用此接口;与「知识库向量配置」(embedding,如 nekoai)无关。
var config *models.AIConfig
if conversation.AIConfigID != nil {
config, err = s.aiConfigRepo.GetByID(*conversation.AIConfigID)
if err != nil {
return nil, fmt.Errorf("获取 AI 配置失败: %v", err)
}
if !config.IsActive {
return nil, errors.New("该模型配置已禁用")
}
} else {
config, err = s.aiConfigRepo.GetActiveByUserID(userID, "text")
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("未找到 AI 配置,请先在设置中配置 AI 服务")
}
return nil, fmt.Errorf("获取 AI 配置失败: %v", err)
}
}
apiKey, err := utils.DecryptAPIKey(config.APIKey)
if err != nil {
return nil, fmt.Errorf("解密 API Key 失败: %v", err)
}
// 若当前 AI 配置为生图模型(model_type=image),则直接走生图逻辑,
// 不参与 RAG/联网与文本对话流程。前端仍显示在「AI 客服」渠道下。
if config.ModelType == "image" {
log.Printf("[生图] 对话ID=%d 使用 model_type=image 配置 id=%d,走 GenerateImageReply", conversationID, config.ID)
return s.GenerateImageReply(conversationID, userMessage, userID)
}
// 调试:确认本条对话实际使用的 AI 配置(便于排查联网/厂商内置是否走对接口)
if needWeb || useWeb {
convAIConfigID := "nil"
if conversation.AIConfigID != nil {
convAIConfigID = fmt.Sprintf("%d", *conversation.AIConfigID)
}
apiURLMask := config.APIURL
if len(apiURLMask) > 50 {
apiURLMask = apiURLMask[:50] + "..."
}
log.Printf("[联网] 对话ID=%d 使用的AI配置: conversation.ai_config_id=%s, config.id=%d, provider=%s, api_url=%s",
conversationID, convAIConfigID, config.ID, config.Provider, apiURLMask)
}
history, err := s.buildConversationHistory(conversationID)
if err != nil {
log.Printf("⚠️ 获取对话历史失败: %v", err)
history = []MessageHistory{}
}
// 多模态识图:当前条带图时读取文件并转 base64 供 provider 使用
var imageBase64, imageMimeType string
if opts != nil && opts.Attachment != nil && opts.Attachment.FileType == "image" && opts.Attachment.FileURL != "" && s.storageService != nil {
data, err := s.storageService.ReadMessageFile(opts.Attachment.FileURL)
if err != nil {
log.Printf("⚠️ 读取消息图片失败: %v", err)
} else {
imageBase64 = base64.StdEncoding.EncodeToString(data)
imageMimeType = opts.Attachment.MimeType
if imageMimeType == "" {
imageMimeType = "image/jpeg"
}
}
}
var ragContext string
ragStartedAt := time.Now()
if useKB && s.retrievalService != nil {
ragContext, err = s.retrieveRAGContext(context.Background(), userMessage, conversation)
if err != nil {
log.Printf("⚠️ RAG 检索失败: %v", err)
}
if s.systemLogSvc != nil {
hit := strings.TrimSpace(ragContext) != ""
convID := conversationID
uID := userID
_ = s.systemLogSvc.Create(CreateSystemLogInput{
Level: "info",
Category: "rag",
Event: "rag_context_result",
Source: "backend",
ConversationID: &convID,
UserID: &uID,
Message: "RAG 检索完成",
Meta: map[string]interface{}{
"hit": hit,
"context_len": len(ragContext),
"elapsed_ms": time.Since(ragStartedAt).Milliseconds(),
"use_kb": useKB,
"need_web": needWeb,
"use_web": useWeb,
},
})
}
}
var adapterConfig *AdapterConfig
if config.AdapterConfig != "" {
_ = json.Unmarshal([]byte(config.AdapterConfig), &adapterConfig)
}
aiConfig := AIConfig{
APIURL: config.APIURL,
APIKey: apiKey,
Model: config.Model,
ModelType: config.ModelType,
Provider: config.Provider,
AdapterConfig: adapterConfig,
}
provider, err := s.providerFactory.CreateProvider(aiConfig)
if err != nil {
return nil, fmt.Errorf("创建 AI 提供商失败: %v", err)
}
var sources []string
enhancedMessage := userMessage
// 1) 有知识库匹配:以知识库为主生成;若本回合允许联网,则用增强 prompt + 联网工具,由模型在无关/不足时用自身知识或联网
if ragContext != "" {
sources = append(sources, "knowledge_base")
if needWeb && useWeb {
webSource := "custom"
if s.embeddingConfigSvc != nil {
webSource, _ = s.embeddingConfigSvc.GetWebSearchSource()
}
enhancedMessage = s.buildRAGPromptWithWebOptional(userMessage, ragContext)
content, usedWeb, err := s.generateWithWebTools(context.Background(), provider, history, enhancedMessage, webSource, imageBase64, imageMimeType)
if err != nil {
log.Printf("⚠️ RAG+联网(function calling)失败: %v,回退到仅 RAG", err)
if s.systemLogSvc != nil {
_ = s.systemLogSvc.Create(CreateSystemLogInput{
Level: "warn",
Category: "ai",
Event: "rag_web_fallback",
Source: "backend",
ConversationID: &conversationID,
UserID: &userID,
Message: "RAG+联网失败,回退到仅RAG",
Meta: map[string]interface{}{
"error": err.Error(),
"web_source": webSource,
"ai_config": config.ID,
},
})
}
if webSource == "vendor" && (strings.Contains(err.Error(), "web_search") || strings.Contains(err.Error(), "Supported values")) {
log.Printf("💡 提示:当前对话使用的 AI 配置接口不支持 type \"web_search\"。若需联网,请改用支持该能力的模型(如 Poixe),或在设置中将联网方式改为「自建」并配置 SERPER_API_KEY。")
}
enhancedMessage = s.buildRAGPrompt(userMessage, ragContext)
} else if content != "" {
sources = append(sources, "llm")
if usedWeb {
sources = append(sources, "web")
}
if s.systemLogSvc != nil {
convID := conversationID
uID := userID
_ = s.systemLogSvc.Create(CreateSystemLogInput{
Level: "info",
Category: "ai",
Event: "ai_web_success",
Source: "backend",
ConversationID: &convID,
UserID: &uID,
Message: "RAG+联网生成成功",
Meta: map[string]interface{}{
"sources": strings.Join(sources, ","),
},
})
}
return &GenerateAIResponseResult{
Content: content,
SourcesUsed: strings.Join(sources, ","),
}, nil
} else {
enhancedMessage = s.buildRAGPrompt(userMessage, ragContext)
}
} else {
enhancedMessage = s.buildRAGPrompt(userMessage, ragContext)
}
} else {
// 2) 无知识库匹配:本回合允许联网时走「模型决定搜」function calling;否则仅用大模型知识
if needWeb && useWeb {
webSource := "custom"
if s.embeddingConfigSvc != nil {
webSource, _ = s.embeddingConfigSvc.GetWebSearchSource()
}
content, usedWeb, err := s.generateWithWebTools(context.Background(), provider, history, userMessage, webSource, imageBase64, imageMimeType)
if err != nil {
log.Printf("⚠️ 联网(function calling)失败: %v,回退到仅大模型", err)
if s.systemLogSvc != nil {
_ = s.systemLogSvc.Create(CreateSystemLogInput{
Level: "warn",
Category: "ai",
Event: "web_fallback_to_llm",
Source: "backend",
ConversationID: &conversationID,
UserID: &userID,
Message: "联网失败,回退到仅大模型",
Meta: map[string]interface{}{
"error": err.Error(),
"web_source": webSource,
"ai_config": config.ID,
},
})
}
if webSource == "vendor" && (strings.Contains(err.Error(), "web_search") || strings.Contains(err.Error(), "Supported values")) {
log.Printf("💡 提示:当前对话使用的 AI 配置接口不支持 type \"web_search\"。若需联网,请改用支持该能力的模型(如 Poixe),或在设置中将联网方式改为「自建」并配置 SERPER_API_KEY。")
}
} else if content != "" {
sources = append(sources, "llm")
if usedWeb {
sources = append(sources, "web")
}
if s.systemLogSvc != nil {
convID := conversationID
uID := userID
_ = s.systemLogSvc.Create(CreateSystemLogInput{
Level: "info",
Category: "ai",
Event: "ai_web_success",
Source: "backend",
ConversationID: &convID,
UserID: &uID,
Message: "联网生成成功",
Meta: map[string]interface{}{
"sources": strings.Join(sources, ","),
},
})
}
return &GenerateAIResponseResult{
Content: content,
SourcesUsed: strings.Join(sources, ","),
}, nil
}
}
if useLLM && len(sources) == 0 {
enhancedMessage = s.buildNoKBPrompt(userMessage)
sources = append(sources, "llm")
} else if useLLM && len(sources) > 0 {
sources = append(sources, "llm")
}
}
// 无任何来源时(例如 useKB 且无匹配,useLLM 关):使用可配置回复语
if len(sources) == 0 {
reply := s.getNoSourceReply()
return &GenerateAIResponseResult{
Content: reply,
SourcesUsed: "",
}, nil
}
response, err := provider.GenerateResponse(history, enhancedMessage, imageBase64, imageMimeType)
if err != nil {
log.Printf("❌ AI 调用失败: %v", err)
if s.systemLogSvc != nil {
_ = s.systemLogSvc.Create(CreateSystemLogInput{
Level: "error",
Category: "ai",
Event: "ai_generate_failed",
Source: "backend",
ConversationID: &conversationID,
UserID: &userID,
Message: "AI 调用失败,返回兜底回复",
Meta: map[string]interface{}{
"error": err.Error(),
"ai_config": config.ID,
},
})
}
return &GenerateAIResponseResult{
Content: s.getAIFailReply(),
SourcesUsed: strings.Join(sources, ","),
GenerationFailed: true,
}, nil
}
if s.systemLogSvc != nil {
convID := conversationID
uID := userID
event := "ai_llm_success"
if strings.Contains(strings.Join(sources, ","), "knowledge_base") {
event = "ai_rag_success"
}
_ = s.systemLogSvc.Create(CreateSystemLogInput{
Level: "info",
Category: "ai",
Event: event,
Source: "backend",
ConversationID: &convID,
UserID: &uID,
Message: "AI 生成成功",
Meta: map[string]interface{}{
"sources": strings.Join(sources, ","),
},
})
}
return &GenerateAIResponseResult{
Content: response,
SourcesUsed: strings.Join(sources, ","),
}, nil
}
// GenerateImageReply 生图渠道专用:根据用户描述生成图片并保存到存储,返回说明文案与图片 URL。
func (s *AIService) GenerateImageReply(conversationID uint, prompt string, userID uint) (*GenerateAIResponseResult, error) {
conversation, err := s.conversationRepo.GetByID(conversationID)
if err != nil {
return nil, fmt.Errorf("获取对话失败: %v", err)
}
if conversation.AIConfigID == nil {
return nil, errors.New("生图渠道需要选择生图模型,请先在渠道中选择「生图绘画」并选择模型")
}
config, err := s.aiConfigRepo.GetByID(*conversation.AIConfigID)
if err != nil {
return nil, fmt.Errorf("获取 AI 配置失败: %v", err)
}
if !config.IsActive {
return nil, errors.New("该生图模型已禁用")
}
if config.ModelType != "image" {
return nil, fmt.Errorf("当前选择的不是生图模型,model_type=%s", config.ModelType)
}
apiKey, err := utils.DecryptAPIKey(config.APIKey)
if err != nil {
return nil, fmt.Errorf("解密 API Key 失败: %v", err)
}
var adapterConfig *AdapterConfig
if config.AdapterConfig != "" {
_ = json.Unmarshal([]byte(config.AdapterConfig), &adapterConfig)
}
aiConfig := AIConfig{
APIURL: config.APIURL,
APIKey: apiKey,
Model: config.Model,
ModelType: config.ModelType,
Provider: config.Provider,
AdapterConfig: adapterConfig,
}
provider, err := s.providerFactory.CreateProvider(aiConfig)
if err != nil {
return nil, err
}
imgProvider, ok := provider.(ImageGenerationProvider)
if !ok {
return nil, errors.New("当前提供商不支持生图")
}
imageData, mimeType, err := imgProvider.GenerateImage(prompt)
if err != nil {
return nil, err
}
if s.storageService == nil {
return nil, errors.New("存储服务未配置,无法保存生成图片")
}
ext := ".png"
if strings.Contains(mimeType, "jpeg") || strings.Contains(mimeType, "jpg") {
ext = ".jpg"
}
fileURL, err := s.storageService.SaveMessageFile(conversationID, bytes.NewReader(imageData), "generated"+ext)
if err != nil {
return nil, fmt.Errorf("保存生成图片失败: %v", err)
}
content := "已根据您的描述生成图片。"
return &GenerateAIResponseResult{
Content: content,
SourcesUsed: "",
GeneratedFileURL: &fileURL,
}, nil
}
func (s *AIService) buildNoKBPrompt(userMessage string) string {
if s.promptConfigSvc != nil {
tpl, err := s.promptConfigSvc.GetNoKBPromptTemplate()
if err == nil && tpl != "" {
return replaceUserMessageOnly(tpl, userMessage)
}
}
return fmt.Sprintf(`你是一个智能客服助手。当前未使用知识库,请仅基于你的知识回答用户问题。
用户问题:%s
请简洁、友好地回答。若无法回答,可建议用户联系人工客服。`, userMessage)
}
func (s *AIService) buildWebSearchPrompt(userMessage string, webContext string) string {
if s.promptConfigSvc != nil {
tpl, err := s.promptConfigSvc.GetWebSearchResultPromptTemplate()
if err == nil && tpl != "" {
return replaceWebSearchPlaceholders(tpl, webContext, userMessage)
}
}
return fmt.Sprintf(`你是一个智能客服助手。请结合以下联网搜索结果回答用户问题。
联网搜索结果:
%s
用户问题:%s
请基于以上内容给出简洁、准确的回答。`, webContext, userMessage)
}
// replaceUserMessageOnly 仅替换 {{user_message}}
func replaceUserMessageOnly(template, userMessage string) string {
return strings.ReplaceAll(template, "{{user_message}}", userMessage)
}
// replaceWebSearchPlaceholders 替换 {{web_context}}、{{user_message}}
func replaceWebSearchPlaceholders(template, webContext, userMessage string) string {
template = strings.ReplaceAll(template, "{{web_context}}", webContext)
template = strings.ReplaceAll(template, "{{user_message}}", userMessage)
return template
}
// getNoSourceReply 无任何来源时返回给用户的一句话(可配置)
func (s *AIService) getNoSourceReply() string {
if s.promptConfigSvc != nil {
reply, err := s.promptConfigSvc.GetNoSourceReply()
if err == nil && strings.TrimSpace(reply) != "" {
return strings.TrimSpace(reply)
}
}
return "当前知识库暂无与此问题相关的内容,您可以尝试联系人工客服。"
}
// getAIFailReply AI 调用失败时返回给用户的一句话(可配置)
func (s *AIService) getAIFailReply() string {
if s.promptConfigSvc != nil {
reply, err := s.promptConfigSvc.GetAIFailReply()
if err == nil && strings.TrimSpace(reply) != "" {
return strings.TrimSpace(reply)
}
}
return "AI客服好像出了点差错,请联系人工客服解决"
}
// buildConversationHistory 构建对话历史(用于 AI 上下文)。
func (s *AIService) buildConversationHistory(conversationID uint) ([]MessageHistory, error) {
// 获取最近的对话消息(最多 10 条,避免上下文过长)
messages, err := s.messageRepo.ListByConversationID(conversationID)
if err != nil {
return nil, err
}
// 只取最近 10 条消息
startIdx := 0
if len(messages) > 10 {
startIdx = len(messages) - 10
}
history := make([]MessageHistory, 0)
for i := startIdx; i < len(messages); i++ {
msg := messages[i]
// 跳过系统消息
if msg.MessageType == "system_message" {
continue
}
role := "user"
if msg.SenderIsAgent {
role = "assistant"
}
history = append(history, MessageHistory{
Role: role,
Content: msg.Content,
})
}
return history, nil
}
// retrieveRAGContext 从知识库中检索相关文档内容
// query: 用户查询文本
// conversation: 对话信息(可能包含知识库 ID)
// 返回: 检索到的文档内容(格式化后的字符串)
func (s *AIService) retrieveRAGContext(ctx context.Context, query string, conversation *models.Conversation) (string, error) {
// 确定知识库 ID(可以从对话中获取,或为空表示搜索所有知识库)
// TODO: 后续在 Conversation 模型增加 KnowledgeBaseID 字段
var knowledgeBaseID *uint
// knowledgeBaseID = conversation.KnowledgeBaseID // 暂时注释,等模型字段添加后启用
// 执行 RAG 检索(Top-K = 5,返回最相关的 5 个文档片段)
// 使用重排序优化检索结果
topK := 5
results, err := s.retrievalService.RetrieveWithRerank(ctx, query, topK, knowledgeBaseID)
if err != nil {
return "", fmt.Errorf("RAG 检索失败: %w", err)
}
if len(results) == 0 {
// 没有检索到相关文档
return "", nil
}
// 格式化检索结果
var contextParts []string
for i, result := range results {
// 只使用相似度较高的结果(Score 越小表示相似度越高)
// 如果使用余弦相似度,通常阈值在 0.7-0.9 之间
// 这里我们暂时不过滤,让所有结果都参与
contextParts = append(contextParts, fmt.Sprintf("文档片段 %d:\n%s", i+1, result.Content))
}
return strings.Join(contextParts, "\n\n"), nil
}
// buildRAGPrompt 构建包含 RAG 上下文的 Prompt
// userMessage: 用户原始消息
// ragContext: RAG 检索到的文档内容
// 返回: 增强后的用户消息(包含知识库上下文)。若已配置提示词服务则使用可配置模板(占位符 {{rag_context}}、{{user_message}}),否则使用代码内默认。
func (s *AIService) buildRAGPrompt(userMessage string, ragContext string) string {
if s.promptConfigSvc != nil {
tpl, err := s.promptConfigSvc.GetRAGPromptTemplate()
if err == nil && tpl != "" {
return replacePromptPlaceholders(tpl, ragContext, userMessage)
}
}
return s.buildRAGPromptFallback(userMessage, ragContext)
}
// buildRAGPromptFallback 代码内默认 RAG 提示词(与 prompt_config_service 默认一致,用于 promptConfigSvc 为空或出错时)
func (s *AIService) buildRAGPromptFallback(userMessage string, ragContext string) string {
return fmt.Sprintf(`你是一个智能客服助手,请基于以下知识库内容回答用户的问题。
知识库内容:
%s
用户问题:%s
请根据知识库内容回答用户的问题。如果知识库中没有相关信息,请礼貌地告知用户,并建议联系人工客服。
回答要求:
1. 基于知识库内容,提供准确、有用的回答
2. 如果知识库中有相关信息,请直接引用并解释
3. 如果知识库中没有相关信息,请诚实告知
4. 保持友好、专业的语气
5. 回答要简洁明了,避免冗长`, ragContext, userMessage)
}
// replacePromptPlaceholders 将模板中的 {{rag_context}}、{{user_message}} 替换为实际值
func replacePromptPlaceholders(template, ragContext, userMessage string) string {
template = strings.ReplaceAll(template, "{{rag_context}}", ragContext)
template = strings.ReplaceAll(template, "{{user_message}}", userMessage)
return template
}
// buildRAGPromptWithWebOptional 构建 RAG prompt,并允许在知识库无关或不足时用自身知识或联网。
// 与 buildRAGPrompt 区别:明确说明可先基于知识库,若无关/弱相关可基于自身知识,若仍不足可由模型决定是否联网(需配合传入 web_search 工具使用)。
func (s *AIService) buildRAGPromptWithWebOptional(userMessage string, ragContext string) string {
if s.promptConfigSvc != nil {
tpl, err := s.promptConfigSvc.GetRAGPromptWithWebOptionalTemplate()
if err == nil && tpl != "" {
return replacePromptPlaceholders(tpl, ragContext, userMessage)
}
}
return s.buildRAGPromptWithWebOptionalFallback(userMessage, ragContext)
}
// buildRAGPromptWithWebOptionalFallback 代码内默认(RAG+联网可选)
func (s *AIService) buildRAGPromptWithWebOptionalFallback(userMessage string, ragContext string) string {
return fmt.Sprintf(`你是一个智能客服助手。请优先基于以下知识库内容回答用户的问题。
知识库内容:
%s
用户问题:%s
回答要求:
1. 若知识库内容与问题明确相关,请基于知识库给出准确、简洁的回答。
2. 若知识库内容与问题无关或仅弱相关,可先基于你自身的知识回答,不必拘泥于知识库。
3. 若你自身知识仍不足以回答(例如需要最新资讯、实时数据),你可决定是否使用联网搜索获取信息后再回答。
4. 保持友好、专业,回答简洁明了。`, ragContext, userMessage)
}
const maxWebToolRounds = 5
// webSearchToolDefinition 返回 type: "function" 的 web_search 工具定义,仅用于「自建」联网(Serper 执行)。
func (s *AIService) webSearchToolDefinition() []map[string]interface{} {
return []map[string]interface{}{
{
"type": "function",
"function": map[string]interface{}{
"name": "web_search",
"description": "Search the web for current information. Use when you need up-to-date or external information to answer the user.",
"parameters": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"query": map[string]string{"type": "string", "description": "Search query"},
},
"required": []string{"query"},
},
},
},
}
}
// generateWithWebTools 使用 function calling 做联网(模型决定是否搜)。webSource: vendor / custom。
// 联网请求始终发往当前对话的「AI 配置」对话接口(与知识库向量配置/embedding 无关)。
// - vendor(模式一:厂商内置):在 tools 里传 type "web_search",由厂商在自家 API 内封装并执行搜索,无需自建。
// - custom(模式二:自建):在 tools 里传 type "function" 的自定义函数(如 web_search),由本服务调用 Serper 等执行并回填。
func (s *AIService) generateWithWebTools(ctx context.Context, provider AIProvider, history []MessageHistory, userMessage string, webSource string, imageBase64 string, imageMimeType string) (content string, usedWeb bool, err error) {
messages := s.historyToOpenAIMessages(history, userMessage, imageBase64, imageMimeType)
var tools []map[string]interface{}
useFunctionFormat := false
switch webSource {
case "vendor":
// 模式一:厂商内置,仅传 web_search,由厂商执行
tools = []map[string]interface{}{
{"type": "web_search"},
}
case "custom":
if s.webSearchProvider == nil {
return "", false, nil
}
useFunctionFormat = true
tools = s.webSearchToolDefinition()
default:
tools = nil
}
if len(tools) == 0 {
return "", false, nil
}
rounds := 0
for rounds < maxWebToolRounds {
rounds++
respContent, toolCalls, callErr := provider.GenerateResponseWithTools(messages, tools)
if callErr != nil {
return "", usedWeb, callErr
}
if len(toolCalls) == 0 {
return respContent, usedWeb, nil
}
if useFunctionFormat {
usedWeb = true
}
// 追加 assistant 消息(含 tool_calls
assistantMsg := map[string]interface{}{"role": "assistant", "content": respContent}
tcList := make([]map[string]interface{}, 0, len(toolCalls))
for _, tc := range toolCalls {
tcList = append(tcList, map[string]interface{}{
"id": tc.ID,
"type": "function",
"function": map[string]interface{}{"name": tc.Name, "arguments": tc.Arguments},
})
}
assistantMsg["tool_calls"] = tcList
messages = append(messages, assistantMsg)
for _, tc := range toolCalls {
toolResult := ""
if useFunctionFormat && tc.Name == "web_search" && s.webSearchProvider != nil {
var args struct {
Query string `json:"query"`
}
_ = json.Unmarshal([]byte(tc.Arguments), &args)
query := args.Query
if query == "" {
query = userMessage
}
toolResult, _ = s.webSearchProvider.Search(ctx, query)
}
messages = append(messages, map[string]interface{}{
"role": "tool",
"tool_call_id": tc.ID,
"content": toolResult,
})
}
}
return "", usedWeb, fmt.Errorf("联网工具调用超过 %d 轮", maxWebToolRounds)
}
func (s *AIService) historyToOpenAIMessages(history []MessageHistory, userMessage string, imageBase64 string, imageMimeType string) []map[string]interface{} {
out := make([]map[string]interface{}, 0, len(history)+1)
for _, h := range history {
out = append(out, map[string]interface{}{"role": h.Role, "content": h.Content})
}
var lastContent interface{} = userMessage
if imageBase64 != "" {
dataURL := "data:" + imageMimeType + ";base64," + imageBase64
if imageMimeType == "" {
dataURL = "data:image/jpeg;base64," + imageBase64
}
lastContent = []map[string]interface{}{
{"type": "text", "text": userMessage},
{"type": "image_url", "image_url": map[string]string{"url": dataURL}},
}
}
out = append(out, map[string]interface{}{"role": "user", "content": lastContent})
return out
}