package controller import ( "bytes" "io" "log" "net/http" "path/filepath" "strconv" "strings" "github.com/2930134478/AI-CS/backend/infra" "github.com/2930134478/AI-CS/backend/service" "github.com/gin-gonic/gin" ) // MessageController 负责处理消息相关的 HTTP 请求。 type MessageController struct { messageService *service.MessageService conversationService *service.ConversationService userService *service.UserService storageService infra.StorageService } // NewMessageController 创建 MessageController 实例。 func NewMessageController( messageService *service.MessageService, conversationService *service.ConversationService, userService *service.UserService, storageService infra.StorageService, ) *MessageController { return &MessageController{ messageService: messageService, conversationService: conversationService, userService: userService, storageService: storageService, } } type createMessageRequest struct { ConversationID uint `json:"conversation_id"` Content string `json:"content"` SenderIsAgent bool `json:"sender_is_agent"` SenderID uint `json:"sender_id"` FileURL *string `json:"file_url"` FileType *string `json:"file_type"` FileName *string `json:"file_name"` FileSize *int64 `json:"file_size"` MimeType *string `json:"mime_type"` // 回复数据源开关(仅 AI 模式有效),不传则默认:知识库+大模型开,联网关 UseKnowledgeBase *bool `json:"use_knowledge_base"` UseLLM *bool `json:"use_llm"` UseWebSearch *bool `json:"use_web_search"` NeedWebSearch bool `json:"need_web_search"` } // CreateMessage 处理发送消息的请求。 func (mc *MessageController) CreateMessage(c *gin.Context) { var req createMessageRequest if err := c.ShouldBindJSON(&req); err != nil || req.ConversationID == 0 { c.JSON(http.StatusBadRequest, gin.H{"error": "请求参数错误"}) return } userID := getUserIDFromHeader(c) // 兼容 demo 自测场景:已登录客服也允许按访客身份发送消息(sender_is_agent=false)。 // 访客消息 sender_id 仍由服务端强制置 0,避免前端注入身份。 // 客服消息必须绑定当前登录用户(X-User-Id),并以服务端用户 ID 为准,避免伪造 sender_id。 if req.SenderIsAgent { if userID == 0 { c.JSON(http.StatusForbidden, gin.H{"error": "未授权访问,请提供 X-User-Id 请求头"}) return } req.SenderID = userID if mc.userService != nil { // 按会话类型进行权限校验: // - visitor 会话:需要 chat 权限 // - internal 会话:需要 kb_test 权限,且仅会话创建者可发送 detail, err := mc.conversationService.GetConversationDetail(req.ConversationID, userID) if err != nil { c.JSON(http.StatusForbidden, gin.H{"error": "无权限访问该会话"}) return } if detail.ConversationType == "internal" { if detail.AgentID != userID { c.JSON(http.StatusForbidden, gin.H{"error": "仅内部会话创建者可发送消息"}) return } if err := mc.userService.CheckPermission(userID, string(service.PermKBTest)); err != nil { c.JSON(http.StatusForbidden, gin.H{"error": err.Error()}) return } } else { if err := mc.userService.CheckPermission(userID, string(service.PermChat)); err != nil { c.JSON(http.StatusForbidden, gin.H{"error": err.Error()}) return } } } } else { // 访客消息的 sender_id 统一由服务端置 0,避免前端注入。 req.SenderID = 0 } // 验证:必须有内容或文件 if req.Content == "" && req.FileURL == nil { c.JSON(http.StatusBadRequest, gin.H{"error": "消息内容或文件不能同时为空"}) return } msg, err := mc.messageService.CreateMessage(service.CreateMessageInput{ ConversationID: req.ConversationID, Content: req.Content, SenderID: req.SenderID, SenderIsAgent: req.SenderIsAgent, FileURL: req.FileURL, FileType: req.FileType, FileName: req.FileName, FileSize: req.FileSize, MimeType: req.MimeType, UseKnowledgeBase: req.UseKnowledgeBase, UseLLM: req.UseLLM, UseWebSearch: req.UseWebSearch, NeedWebSearch: req.NeedWebSearch, }) if err != nil { log.Printf("❌ 创建消息失败: 对话ID=%d, 错误=%v", req.ConversationID, err) switch err { case service.ErrConversationClosed: c.JSON(http.StatusBadRequest, gin.H{"error": "会话已关闭"}) case service.ErrConversationNotFound: c.JSON(http.StatusBadRequest, gin.H{"error": "会话不存在"}) default: c.JSON(http.StatusInternalServerError, gin.H{"error": "创建消息失败"}) } return } // 返回持久化后的完整消息:客服端/访客端可在发送成功后立即更新 UI,避免仅依赖 WebSocket 时出现「空了要等刷新」 c.JSON(http.StatusOK, msg) } // ListMessages 返回指定会话的消息列表。 // 查询参数: // - conversation_id: 会话ID(必需) // - include_ai_messages: 是否包含 AI 消息(可选,默认 false) func (mc *MessageController) ListMessages(c *gin.Context) { conversationIDStr := c.Query("conversation_id") if conversationIDStr == "" { c.JSON(http.StatusBadRequest, gin.H{"error": "会话ID不能为空"}) return } conversationID, err := strconv.ParseUint(conversationIDStr, 10, 64) if err != nil || conversationID == 0 { c.JSON(http.StatusBadRequest, gin.H{"error": "会话ID不合法"}) return } if mc.userService != nil { userID := getUserIDFromHeader(c) detail, detailErr := mc.conversationService.GetConversationDetail(uint(conversationID), userID) if detailErr != nil && userID > 0 { c.JSON(http.StatusForbidden, gin.H{"error": "无权限访问该会话"}) return } if detail != nil { if detail.ConversationType == "internal" { if userID == 0 || detail.AgentID != userID { c.JSON(http.StatusForbidden, gin.H{"error": "无权限访问内部会话"}) return } if err := mc.userService.CheckPermission(userID, string(service.PermKBTest)); err != nil { c.JSON(http.StatusForbidden, gin.H{"error": err.Error()}) return } } else if userID > 0 { if err := mc.userService.CheckPermission(userID, string(service.PermChat)); err != nil { c.JSON(http.StatusForbidden, gin.H{"error": err.Error()}) return } } } } // 解析 include_ai_messages 参数(默认 false) includeAIMessages := c.DefaultQuery("include_ai_messages", "false") == "true" messages, err := mc.messageService.ListMessages(uint(conversationID), includeAIMessages) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "查询消息失败"}) return } c.JSON(http.StatusOK, messages) } type markMessagesReadRequest struct { ConversationID uint `json:"conversation_id"` ReaderIsAgent bool `json:"reader_is_agent"` } // MarkMessagesRead 将指定会话的消息标记为已读。 func (mc *MessageController) MarkMessagesRead(c *gin.Context) { var req markMessagesReadRequest if err := c.ShouldBindJSON(&req); err != nil || req.ConversationID == 0 { c.JSON(http.StatusBadRequest, gin.H{"error": "请求参数错误"}) return } if mc.userService != nil { userID := getUserIDFromHeader(c) detail, detailErr := mc.conversationService.GetConversationDetail(req.ConversationID, userID) if detailErr != nil && userID > 0 { c.JSON(http.StatusForbidden, gin.H{"error": "无权限访问该会话"}) return } if detail != nil { if detail.ConversationType == "internal" { if userID == 0 || detail.AgentID != userID { c.JSON(http.StatusForbidden, gin.H{"error": "无权限访问内部会话"}) return } } if req.ReaderIsAgent { if userID == 0 { c.JSON(http.StatusForbidden, gin.H{"error": "未授权访问,请提供 X-User-Id 请求头"}) return } if detail.ConversationType == "internal" { if err := mc.userService.CheckPermission(userID, string(service.PermKBTest)); err != nil { c.JSON(http.StatusForbidden, gin.H{"error": err.Error()}) return } } else { if err := mc.userService.CheckPermission(userID, string(service.PermChat)); err != nil { c.JSON(http.StatusForbidden, gin.H{"error": err.Error()}) return } } } } } result, err := mc.messageService.MarkMessagesRead(req.ConversationID, req.ReaderIsAgent) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "更新消息状态失败"}) return } c.JSON(http.StatusOK, gin.H{ "updated": len(result.MessageIDs), "message_ids": result.MessageIDs, "conversation_id": result.ConversationID, "unread_count": result.UnreadCount, "read_at": formatTimeValue(result.ReadAt), }) } // UploadFile 处理文件上传请求。 // 请求格式:multipart/form-data // - file: 文件内容(必需) // - conversation_id: 对话ID(可选,用于组织目录) // // 认证方式: // - 方式1:提供 X-User-Id 请求头(客服上传) // - 方式2:提供 conversation_id 参数(访客上传,会验证对话是否存在且未关闭) func (mc *MessageController) UploadFile(c *gin.Context) { // ⚠️ 认证检查:必须满足以下条件之一 // 1. 提供 X-User-Id 请求头(客服) // 2. 提供 conversation_id 参数(访客) userID := getUserIDFromHeader(c) conversationIDStr := c.PostForm("conversation_id") // 如果既没有用户ID,也没有对话ID,拒绝访问 if userID == 0 && conversationIDStr == "" { c.JSON(http.StatusUnauthorized, gin.H{"error": "未授权访问,请提供 X-User-Id 请求头或 conversation_id 参数"}) return } // 如果是访客上传(没有用户ID,但有对话ID),验证对话是否存在且未关闭 if userID == 0 && conversationIDStr != "" { convID, err := strconv.ParseUint(conversationIDStr, 10, 64) if err != nil || convID == 0 { c.JSON(http.StatusBadRequest, gin.H{"error": "对话ID不合法"}) return } // 验证对话是否存在且未关闭 conv, err := mc.conversationService.GetConversationDetail(uint(convID), 0) if err != nil { c.JSON(http.StatusForbidden, gin.H{"error": "对话不存在或已关闭"}) return } if conv.Status == "closed" { c.JSON(http.StatusForbidden, gin.H{"error": "对话已关闭"}) return } } // 解析文件 file, err := c.FormFile("file") if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "文件不能为空"}) return } // 验证文件大小(10MB) const maxFileSize = 10 * 1024 * 1024 // 10MB if file.Size > maxFileSize { c.JSON(http.StatusBadRequest, gin.H{"error": "文件大小超过限制(最大10MB)"}) return } // ⚠️ 加强:验证文件类型(扩展名) ext := strings.ToLower(filepath.Ext(file.Filename)) allowedExts := map[string]bool{ ".jpg": true, ".jpeg": true, ".png": true, ".gif": true, ".webp": true, ".pdf": true, ".doc": true, ".docx": true, ".txt": true, } if !allowedExts[ext] { c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的文件类型"}) return } // ⚠️ 加强:验证 MIME 类型(防止伪造扩展名) mimeType := file.Header.Get("Content-Type") allowedMimeTypes := map[string]bool{ "image/jpeg": true, "image/jpg": true, "image/png": true, "image/gif": true, "image/webp": true, "application/pdf": true, "application/msword": true, "application/vnd.openxmlformats-officedocument.wordprocessingml.document": true, // .docx "text/plain": true, } if !allowedMimeTypes[mimeType] { c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的文件 MIME 类型: " + mimeType}) return } // ⚠️ 加强:清理文件名,防止路径遍历攻击 safeFilename := filepath.Base(file.Filename) safeFilename = strings.ReplaceAll(safeFilename, "..", "") safeFilename = strings.ReplaceAll(safeFilename, "/", "") safeFilename = strings.ReplaceAll(safeFilename, "\\", "") // 移除所有非字母数字、点、下划线、连字符的字符 var cleaned strings.Builder for _, r := range safeFilename { if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '.' || r == '_' || r == '-' { cleaned.WriteRune(r) } } safeFilename = cleaned.String() // 限制文件名长度 if len(safeFilename) > 100 { // 保留扩展名 ext := filepath.Ext(safeFilename) nameWithoutExt := strings.TrimSuffix(safeFilename, ext) if len(nameWithoutExt) > 100-len(ext) { safeFilename = nameWithoutExt[:100-len(ext)] + ext } } // ⚠️ 加强:验证文件内容(magic number 检查,防止伪造扩展名) fileContent, err := file.Open() if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "无法读取文件"}) return } defer fileContent.Close() // 读取文件前几个字节(magic number) magicBytes := make([]byte, 12) n, err := fileContent.Read(magicBytes) if err != nil && err != io.EOF { c.JSON(http.StatusBadRequest, gin.H{"error": "无法读取文件内容"}) return } // 验证文件内容是否匹配扩展名 if !isValidFileContent(ext, magicBytes[:n]) { c.JSON(http.StatusBadRequest, gin.H{"error": "文件内容与扩展名不匹配,可能是伪造的文件类型"}) return } // 重置文件指针,以便后续保存 if _, err := fileContent.Seek(0, io.SeekStart); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "无法重置文件指针"}) return } // 获取对话ID(如果之前已经解析过,直接使用;否则从表单获取) var conversationID uint if conversationIDStr != "" { if id, err := strconv.ParseUint(conversationIDStr, 10, 64); err == nil { conversationID = uint(id) } } // 保存文件(使用清理后的文件名,fileContent 已经在上面打开并验证过) fileURL, err := mc.storageService.SaveMessageFile(conversationID, fileContent, safeFilename) if err != nil { log.Printf("❌ 保存文件失败: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "保存文件失败"}) return } // 判断文件类型 fileType := "document" if strings.HasPrefix(mimeType, "image/") { fileType = "image" } // 返回文件信息(使用清理后的文件名) c.JSON(http.StatusOK, gin.H{ "success": true, "data": gin.H{ "file_url": fileURL, "file_type": fileType, "file_name": safeFilename, "file_size": file.Size, "mime_type": mimeType, }, }) } // isValidFileContent 验证文件内容是否与扩展名匹配(通过 magic number 检查) func isValidFileContent(ext string, magicBytes []byte) bool { if len(magicBytes) < 4 { return false } ext = strings.ToLower(ext) // 检查各种文件类型的 magic number switch ext { case ".jpg", ".jpeg": // JPEG: FF D8 FF return len(magicBytes) >= 3 && magicBytes[0] == 0xFF && magicBytes[1] == 0xD8 && magicBytes[2] == 0xFF case ".png": // PNG: 89 50 4E 47 return len(magicBytes) >= 4 && magicBytes[0] == 0x89 && magicBytes[1] == 0x50 && magicBytes[2] == 0x4E && magicBytes[3] == 0x47 case ".gif": // GIF: 47 49 46 38 (GIF8) return len(magicBytes) >= 4 && magicBytes[0] == 0x47 && magicBytes[1] == 0x49 && magicBytes[2] == 0x46 && magicBytes[3] == 0x38 case ".webp": // WebP: RIFF ... WEBP if len(magicBytes) >= 12 { return bytes.Equal(magicBytes[0:4], []byte("RIFF")) && bytes.Equal(magicBytes[8:12], []byte("WEBP")) } return false case ".pdf": // PDF: 25 50 44 46 (%PDF) return len(magicBytes) >= 4 && magicBytes[0] == 0x25 && magicBytes[1] == 0x50 && magicBytes[2] == 0x44 && magicBytes[3] == 0x46 case ".txt": // 文本文件:检查是否为可打印字符(ASCII 32-126)或 UTF-8 BOM // UTF-8 BOM: EF BB BF if len(magicBytes) >= 3 && magicBytes[0] == 0xEF && magicBytes[1] == 0xBB && magicBytes[2] == 0xBF { return true } // 检查前几个字节是否都是可打印字符 for i := 0; i < len(magicBytes) && i < 10; i++ { if magicBytes[i] < 0x20 && magicBytes[i] != 0x09 && magicBytes[i] != 0x0A && magicBytes[i] != 0x0D { // 不是可打印字符、制表符、换行符或回车符 return false } } return true case ".doc": // DOC (OLE2): D0 CF 11 E0 A1 B1 1A E1 return len(magicBytes) >= 8 && magicBytes[0] == 0xD0 && magicBytes[1] == 0xCF && magicBytes[2] == 0x11 && magicBytes[3] == 0xE0 case ".docx": // DOCX (ZIP): 50 4B 03 04 (PK..) return len(magicBytes) >= 4 && magicBytes[0] == 0x50 && magicBytes[1] == 0x4B && magicBytes[2] == 0x03 && magicBytes[3] == 0x04 default: return false } }