Files
AI-CS/backend/service/ai_service.go
T
2026-02-02 21:41:47 +08:00

237 lines
7.5 KiB
Go

package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"log"
"strings"
"github.com/2930134478/AI-CS/backend/models"
"github.com/2930134478/AI-CS/backend/repository"
"github.com/2930134478/AI-CS/backend/service/rag"
"github.com/2930134478/AI-CS/backend/utils"
"gorm.io/gorm"
)
// AIService AI 服务(负责调用 AI 生成回复)
type AIService struct {
aiConfigRepo *repository.AIConfigRepository
messageRepo *repository.MessageRepository
conversationRepo *repository.ConversationRepository
retrievalService *rag.RetrievalService // RAG 检索服务
providerFactory *AIProviderFactory
}
// NewAIService 创建 AI 服务实例。
func NewAIService(
aiConfigRepo *repository.AIConfigRepository,
messageRepo *repository.MessageRepository,
conversationRepo *repository.ConversationRepository,
retrievalService *rag.RetrievalService, // 添加 RAG 检索服务
) *AIService {
return &AIService{
aiConfigRepo: aiConfigRepo,
messageRepo: messageRepo,
conversationRepo: conversationRepo,
retrievalService: retrievalService,
providerFactory: NewAIProviderFactory(),
}
}
// GenerateAIResponse 为对话生成 AI 回复。
// conversationID: 对话ID
// userMessage: 用户消息
// userID: 用户ID(用于回退查找 AI 配置)
// 返回: AI 回复内容,如果失败返回错误
func (s *AIService) GenerateAIResponse(conversationID uint, userMessage string, userID uint) (string, error) {
// 1. 获取对话信息,优先使用对话绑定的 AI 配置
conversation, err := s.conversationRepo.GetByID(conversationID)
if err != nil {
return "", fmt.Errorf("获取对话失败: %v", err)
}
var config *models.AIConfig
if conversation.AIConfigID != nil {
// 使用对话绑定的配置(多厂商支持)
config, err = s.aiConfigRepo.GetByID(*conversation.AIConfigID)
if err != nil {
return "", fmt.Errorf("获取 AI 配置失败: %v", err)
}
// 验证配置是否启用
if !config.IsActive {
return "", errors.New("该模型配置已禁用")
}
} else {
// 回退:使用用户默认配置(向后兼容)
config, err = s.aiConfigRepo.GetActiveByUserID(userID, "text")
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", errors.New("未找到 AI 配置,请先在设置中配置 AI 服务")
}
return "", fmt.Errorf("获取 AI 配置失败: %v", err)
}
}
// 2. 解密 API Key
apiKey, err := utils.DecryptAPIKey(config.APIKey)
if err != nil {
return "", fmt.Errorf("解密 API Key 失败: %v", err)
}
// 3. 获取对话历史(用于上下文)
history, err := s.buildConversationHistory(conversationID)
if err != nil {
log.Printf("⚠️ 获取对话历史失败: %v", err)
// 即使获取历史失败,也继续处理(使用空历史)
history = []MessageHistory{}
}
// 4. RAG 检索:从知识库中检索相关文档
ragContext := ""
if s.retrievalService != nil {
ragContext, err = s.retrieveRAGContext(context.Background(), userMessage, conversation)
if err != nil {
log.Printf("⚠️ RAG 检索失败: %v,继续使用无知识库上下文", err)
// RAG 检索失败不影响主流程,继续处理
}
}
// 5. 构建增强的用户消息(包含 RAG 上下文)
enhancedUserMessage := userMessage
if ragContext != "" {
enhancedUserMessage = s.buildRAGPrompt(userMessage, ragContext)
}
// 6. 解析适配器配置(如果有)
var adapterConfig *AdapterConfig
if config.AdapterConfig != "" {
if err := json.Unmarshal([]byte(config.AdapterConfig), &adapterConfig); err != nil {
log.Printf("⚠️ 解析适配器配置失败: %v,使用默认配置", err)
}
}
// 7. 创建 AI 提供商
aiConfig := AIConfig{
APIURL: config.APIURL,
APIKey: apiKey,
Model: config.Model,
ModelType: config.ModelType,
Provider: config.Provider,
AdapterConfig: adapterConfig,
}
provider, err := s.providerFactory.CreateProvider(aiConfig)
if err != nil {
return "", fmt.Errorf("创建 AI 提供商失败: %v", err)
}
// 8. 调用 AI 生成回复(使用增强的消息)
response, err := provider.GenerateResponse(history, enhancedUserMessage)
if err != nil {
// AI 调用失败,返回友好的错误消息
log.Printf("❌ AI 调用失败: %v", err)
return "AI客服好像出了点差错,请联系人工客服解决", nil
}
return response, nil
}
// buildConversationHistory 构建对话历史(用于 AI 上下文)。
func (s *AIService) buildConversationHistory(conversationID uint) ([]MessageHistory, error) {
// 获取最近的对话消息(最多 10 条,避免上下文过长)
messages, err := s.messageRepo.ListByConversationID(conversationID)
if err != nil {
return nil, err
}
// 只取最近 10 条消息
startIdx := 0
if len(messages) > 10 {
startIdx = len(messages) - 10
}
history := make([]MessageHistory, 0)
for i := startIdx; i < len(messages); i++ {
msg := messages[i]
// 跳过系统消息
if msg.MessageType == "system_message" {
continue
}
role := "user"
if msg.SenderIsAgent {
role = "assistant"
}
history = append(history, MessageHistory{
Role: role,
Content: msg.Content,
})
}
return history, nil
}
// retrieveRAGContext 从知识库中检索相关文档内容
// query: 用户查询文本
// conversation: 对话信息(可能包含知识库 ID)
// 返回: 检索到的文档内容(格式化后的字符串)
func (s *AIService) retrieveRAGContext(ctx context.Context, query string, conversation *models.Conversation) (string, error) {
// 确定知识库 ID(可以从对话中获取,或为空表示搜索所有知识库)
// TODO: 后续在 Conversation 模型增加 KnowledgeBaseID 字段
var knowledgeBaseID *uint
// knowledgeBaseID = conversation.KnowledgeBaseID // 暂时注释,等模型字段添加后启用
// 执行 RAG 检索(Top-K = 5,返回最相关的 5 个文档片段)
// 使用重排序优化检索结果
topK := 5
results, err := s.retrievalService.RetrieveWithRerank(ctx, query, topK, knowledgeBaseID)
if err != nil {
return "", fmt.Errorf("RAG 检索失败: %w", err)
}
if len(results) == 0 {
// 没有检索到相关文档
return "", nil
}
// 格式化检索结果
var contextParts []string
for i, result := range results {
// 只使用相似度较高的结果(Score 越小表示相似度越高)
// 如果使用余弦相似度,通常阈值在 0.7-0.9 之间
// 这里我们暂时不过滤,让所有结果都参与
contextParts = append(contextParts, fmt.Sprintf("文档片段 %d:\n%s", i+1, result.Content))
}
return strings.Join(contextParts, "\n\n"), nil
}
// buildRAGPrompt 构建包含 RAG 上下文的 Prompt
// userMessage: 用户原始消息
// ragContext: RAG 检索到的文档内容
// 返回: 增强后的用户消息(包含知识库上下文)
func (s *AIService) buildRAGPrompt(userMessage string, ragContext string) string {
// 构建 RAG Prompt 模板
// 参考 PandaWiki 的 Prompt 格式
prompt := fmt.Sprintf(`你是一个智能客服助手,请基于以下知识库内容回答用户的问题。
知识库内容:
%s
用户问题:%s
请根据知识库内容回答用户的问题。如果知识库中没有相关信息,请礼貌地告知用户,并建议联系人工客服。
回答要求:
1. 基于知识库内容,提供准确、有用的回答
2. 如果知识库中有相关信息,请直接引用并解释
3. 如果知识库中没有相关信息,请诚实告知
4. 保持友好、专业的语气
5. 回答要简洁明了,避免冗长`, ragContext, userMessage)
return prompt
}