mirror of
https://github.com/2930134478/AI-CS.git
synced 2026-06-15 00:44:30 +08:00
修复文件上传接口
This commit is contained in:
@@ -4,7 +4,10 @@ import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/2930134478/AI-CS/backend/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -12,14 +15,14 @@ import (
|
||||
|
||||
// ImportController 导入控制器
|
||||
type ImportController struct {
|
||||
importService *service.ImportService
|
||||
importService *service.ImportService
|
||||
embeddingConfigService *service.EmbeddingConfigService
|
||||
}
|
||||
|
||||
// NewImportController 创建导入控制器实例
|
||||
func NewImportController(importService *service.ImportService, embeddingConfigService *service.EmbeddingConfigService) *ImportController {
|
||||
return &ImportController{
|
||||
importService: importService,
|
||||
importService: importService,
|
||||
embeddingConfigService: embeddingConfigService,
|
||||
}
|
||||
}
|
||||
@@ -27,7 +30,9 @@ func NewImportController(importService *service.ImportService, embeddingConfigSe
|
||||
func (c *ImportController) checkKBAccess(ctx *gin.Context) bool {
|
||||
userID := getUserIDFromHeader(ctx)
|
||||
if userID == 0 {
|
||||
return true
|
||||
// ⚠️ 修复:改为拒绝访问,而不是允许
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "未授权访问,请提供 X-User-Id 请求头"})
|
||||
return false
|
||||
}
|
||||
if err := c.embeddingConfigService.CheckKnowledgeBaseAccess(userID); err != nil {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
@@ -67,11 +72,37 @@ func (c *ImportController) ImportDocuments(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// ⚠️ 添加:文件类型验证
|
||||
allowedExts := map[string]bool{
|
||||
".md": true,
|
||||
".txt": true,
|
||||
".pdf": true,
|
||||
".doc": true,
|
||||
".docx": true,
|
||||
}
|
||||
|
||||
// 保存文件到临时目录
|
||||
filePaths := make([]string, 0, len(files))
|
||||
for _, file := range files {
|
||||
// ⚠️ 添加:验证文件类型
|
||||
ext := strings.ToLower(filepath.Ext(file.Filename))
|
||||
if !allowedExts[ext] {
|
||||
log.Printf("不支持的文件类型: %s (扩展名: %s)", file.Filename, ext)
|
||||
continue
|
||||
}
|
||||
|
||||
// ⚠️ 添加:清理文件名,防止路径遍历攻击
|
||||
safeFilename := filepath.Base(file.Filename)
|
||||
safeFilename = strings.ReplaceAll(safeFilename, "..", "")
|
||||
safeFilename = strings.ReplaceAll(safeFilename, "/", "")
|
||||
safeFilename = strings.ReplaceAll(safeFilename, "\\", "")
|
||||
// 限制文件名长度
|
||||
if len(safeFilename) > 255 {
|
||||
safeFilename = safeFilename[:255]
|
||||
}
|
||||
|
||||
// 保存文件
|
||||
filePath := "/tmp/" + file.Filename
|
||||
filePath := "/tmp/" + safeFilename
|
||||
if err := ctx.SaveUploadedFile(file, filePath); err != nil {
|
||||
log.Printf("保存文件失败: %v", err)
|
||||
continue
|
||||
@@ -79,6 +110,20 @@ func (c *ImportController) ImportDocuments(ctx *gin.Context) {
|
||||
filePaths = append(filePaths, filePath)
|
||||
}
|
||||
|
||||
if len(filePaths) == 0 {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "没有有效的文件(所有文件都被拒绝或保存失败)"})
|
||||
return
|
||||
}
|
||||
|
||||
// ⚠️ 添加:导入后清理临时文件
|
||||
defer func() {
|
||||
for _, path := range filePaths {
|
||||
if err := os.Remove(path); err != nil {
|
||||
log.Printf("清理临时文件失败: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// 导入文件
|
||||
result, err := c.importService.ImportFiles(context.Background(), uint(kbID), filePaths)
|
||||
if err != nil {
|
||||
@@ -98,7 +143,7 @@ func (c *ImportController) ImportFromURLs(ctx *gin.Context) {
|
||||
}
|
||||
var req struct {
|
||||
KnowledgeBaseID uint `json:"knowledge_base_id" binding:"required"`
|
||||
URLs []string `json:"urls" binding:"required"`
|
||||
URLs []string `json:"urls" binding:"required"`
|
||||
}
|
||||
|
||||
if err := ctx.ShouldBindJSON(&req); err != nil {
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
@@ -14,15 +16,17 @@ import (
|
||||
|
||||
// MessageController 负责处理消息相关的 HTTP 请求。
|
||||
type MessageController struct {
|
||||
messageService *service.MessageService
|
||||
storageService infra.StorageService
|
||||
messageService *service.MessageService
|
||||
conversationService *service.ConversationService
|
||||
storageService infra.StorageService
|
||||
}
|
||||
|
||||
// NewMessageController 创建 MessageController 实例。
|
||||
func NewMessageController(messageService *service.MessageService, storageService infra.StorageService) *MessageController {
|
||||
func NewMessageController(messageService *service.MessageService, conversationService *service.ConversationService, storageService infra.StorageService) *MessageController {
|
||||
return &MessageController{
|
||||
messageService: messageService,
|
||||
storageService: storageService,
|
||||
messageService: messageService,
|
||||
conversationService: conversationService,
|
||||
storageService: storageService,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,7 +145,41 @@ func (mc *MessageController) MarkMessagesRead(c *gin.Context) {
|
||||
// 请求格式:multipart/form-data
|
||||
// - file: 文件内容(必需)
|
||||
// - conversation_id: 对话ID(可选,用于组织目录)
|
||||
// 认证方式:
|
||||
// - 方式1:提供 X-User-Id 请求头(客服上传)
|
||||
// - 方式2:提供 conversation_id 参数(访客上传,会验证对话是否存在且未关闭)
|
||||
func (mc *MessageController) UploadFile(c *gin.Context) {
|
||||
// ⚠️ 认证检查:必须满足以下条件之一
|
||||
// 1. 提供 X-User-Id 请求头(客服)
|
||||
// 2. 提供 conversation_id 参数(访客)
|
||||
userID := getUserIDFromHeader(c)
|
||||
conversationIDStr := c.PostForm("conversation_id")
|
||||
|
||||
// 如果既没有用户ID,也没有对话ID,拒绝访问
|
||||
if userID == 0 && conversationIDStr == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "未授权访问,请提供 X-User-Id 请求头或 conversation_id 参数"})
|
||||
return
|
||||
}
|
||||
|
||||
// 如果是访客上传(没有用户ID,但有对话ID),验证对话是否存在且未关闭
|
||||
if userID == 0 && conversationIDStr != "" {
|
||||
convID, err := strconv.ParseUint(conversationIDStr, 10, 64)
|
||||
if err != nil || convID == 0 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "对话ID不合法"})
|
||||
return
|
||||
}
|
||||
// 验证对话是否存在且未关闭
|
||||
conv, err := mc.conversationService.GetConversationDetail(uint(convID), 0)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "对话不存在或已关闭"})
|
||||
return
|
||||
}
|
||||
if conv.Status == "closed" {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "对话已关闭"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 解析文件
|
||||
file, err := c.FormFile("file")
|
||||
if err != nil {
|
||||
@@ -156,7 +194,7 @@ func (mc *MessageController) UploadFile(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 验证文件类型
|
||||
// ⚠️ 加强:验证文件类型(扩展名)
|
||||
ext := strings.ToLower(filepath.Ext(file.Filename))
|
||||
allowedExts := map[string]bool{
|
||||
".jpg": true,
|
||||
@@ -174,25 +212,85 @@ func (mc *MessageController) UploadFile(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 获取对话ID(可选)
|
||||
// ⚠️ 加强:验证 MIME 类型(防止伪造扩展名)
|
||||
mimeType := file.Header.Get("Content-Type")
|
||||
allowedMimeTypes := map[string]bool{
|
||||
"image/jpeg": true,
|
||||
"image/jpg": true,
|
||||
"image/png": true,
|
||||
"image/gif": true,
|
||||
"image/webp": true,
|
||||
"application/pdf": true,
|
||||
"application/msword": true,
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": true, // .docx
|
||||
"text/plain": true,
|
||||
}
|
||||
if !allowedMimeTypes[mimeType] {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的文件 MIME 类型: " + mimeType})
|
||||
return
|
||||
}
|
||||
|
||||
// ⚠️ 加强:清理文件名,防止路径遍历攻击
|
||||
safeFilename := filepath.Base(file.Filename)
|
||||
safeFilename = strings.ReplaceAll(safeFilename, "..", "")
|
||||
safeFilename = strings.ReplaceAll(safeFilename, "/", "")
|
||||
safeFilename = strings.ReplaceAll(safeFilename, "\\", "")
|
||||
// 移除所有非字母数字、点、下划线、连字符的字符
|
||||
var cleaned strings.Builder
|
||||
for _, r := range safeFilename {
|
||||
if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '.' || r == '_' || r == '-' {
|
||||
cleaned.WriteRune(r)
|
||||
}
|
||||
}
|
||||
safeFilename = cleaned.String()
|
||||
// 限制文件名长度
|
||||
if len(safeFilename) > 100 {
|
||||
// 保留扩展名
|
||||
ext := filepath.Ext(safeFilename)
|
||||
nameWithoutExt := strings.TrimSuffix(safeFilename, ext)
|
||||
if len(nameWithoutExt) > 100-len(ext) {
|
||||
safeFilename = nameWithoutExt[:100-len(ext)] + ext
|
||||
}
|
||||
}
|
||||
|
||||
// ⚠️ 加强:验证文件内容(magic number 检查,防止伪造扩展名)
|
||||
fileContent, err := file.Open()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无法读取文件"})
|
||||
return
|
||||
}
|
||||
defer fileContent.Close()
|
||||
|
||||
// 读取文件前几个字节(magic number)
|
||||
magicBytes := make([]byte, 12)
|
||||
n, err := fileContent.Read(magicBytes)
|
||||
if err != nil && err != io.EOF {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无法读取文件内容"})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证文件内容是否匹配扩展名
|
||||
if !isValidFileContent(ext, magicBytes[:n]) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "文件内容与扩展名不匹配,可能是伪造的文件类型"})
|
||||
return
|
||||
}
|
||||
|
||||
// 重置文件指针,以便后续保存
|
||||
if _, err := fileContent.Seek(0, io.SeekStart); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "无法重置文件指针"})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取对话ID(如果之前已经解析过,直接使用;否则从表单获取)
|
||||
var conversationID uint
|
||||
if conversationIDStr := c.PostForm("conversation_id"); conversationIDStr != "" {
|
||||
if conversationIDStr != "" {
|
||||
if id, err := strconv.ParseUint(conversationIDStr, 10, 64); err == nil {
|
||||
conversationID = uint(id)
|
||||
}
|
||||
}
|
||||
|
||||
// 打开文件
|
||||
src, err := file.Open()
|
||||
if err != nil {
|
||||
log.Printf("❌ 打开文件失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "打开文件失败"})
|
||||
return
|
||||
}
|
||||
defer src.Close()
|
||||
|
||||
// 保存文件
|
||||
fileURL, err := mc.storageService.SaveMessageFile(conversationID, src, file.Filename)
|
||||
// 保存文件(使用清理后的文件名,fileContent 已经在上面打开并验证过)
|
||||
fileURL, err := mc.storageService.SaveMessageFile(conversationID, fileContent, safeFilename)
|
||||
if err != nil {
|
||||
log.Printf("❌ 保存文件失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存文件失败"})
|
||||
@@ -201,20 +299,72 @@ func (mc *MessageController) UploadFile(c *gin.Context) {
|
||||
|
||||
// 判断文件类型
|
||||
fileType := "document"
|
||||
mimeType := file.Header.Get("Content-Type")
|
||||
if strings.HasPrefix(mimeType, "image/") {
|
||||
fileType = "image"
|
||||
}
|
||||
|
||||
// 返回文件信息
|
||||
// 返回文件信息(使用清理后的文件名)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": gin.H{
|
||||
"file_url": fileURL,
|
||||
"file_type": fileType,
|
||||
"file_name": file.Filename,
|
||||
"file_name": safeFilename,
|
||||
"file_size": file.Size,
|
||||
"mime_type": mimeType,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// isValidFileContent 验证文件内容是否与扩展名匹配(通过 magic number 检查)
|
||||
func isValidFileContent(ext string, magicBytes []byte) bool {
|
||||
if len(magicBytes) < 4 {
|
||||
return false
|
||||
}
|
||||
|
||||
ext = strings.ToLower(ext)
|
||||
|
||||
// 检查各种文件类型的 magic number
|
||||
switch ext {
|
||||
case ".jpg", ".jpeg":
|
||||
// JPEG: FF D8 FF
|
||||
return len(magicBytes) >= 3 && magicBytes[0] == 0xFF && magicBytes[1] == 0xD8 && magicBytes[2] == 0xFF
|
||||
case ".png":
|
||||
// PNG: 89 50 4E 47
|
||||
return len(magicBytes) >= 4 && magicBytes[0] == 0x89 && magicBytes[1] == 0x50 && magicBytes[2] == 0x4E && magicBytes[3] == 0x47
|
||||
case ".gif":
|
||||
// GIF: 47 49 46 38 (GIF8)
|
||||
return len(magicBytes) >= 4 && magicBytes[0] == 0x47 && magicBytes[1] == 0x49 && magicBytes[2] == 0x46 && magicBytes[3] == 0x38
|
||||
case ".webp":
|
||||
// WebP: RIFF ... WEBP
|
||||
if len(magicBytes) >= 12 {
|
||||
return bytes.Equal(magicBytes[0:4], []byte("RIFF")) && bytes.Equal(magicBytes[8:12], []byte("WEBP"))
|
||||
}
|
||||
return false
|
||||
case ".pdf":
|
||||
// PDF: 25 50 44 46 (%PDF)
|
||||
return len(magicBytes) >= 4 && magicBytes[0] == 0x25 && magicBytes[1] == 0x50 && magicBytes[2] == 0x44 && magicBytes[3] == 0x46
|
||||
case ".txt":
|
||||
// 文本文件:检查是否为可打印字符(ASCII 32-126)或 UTF-8 BOM
|
||||
// UTF-8 BOM: EF BB BF
|
||||
if len(magicBytes) >= 3 && magicBytes[0] == 0xEF && magicBytes[1] == 0xBB && magicBytes[2] == 0xBF {
|
||||
return true
|
||||
}
|
||||
// 检查前几个字节是否都是可打印字符
|
||||
for i := 0; i < len(magicBytes) && i < 10; i++ {
|
||||
if magicBytes[i] < 0x20 && magicBytes[i] != 0x09 && magicBytes[i] != 0x0A && magicBytes[i] != 0x0D {
|
||||
// 不是可打印字符、制表符、换行符或回车符
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
case ".doc":
|
||||
// DOC (OLE2): D0 CF 11 E0 A1 B1 1A E1
|
||||
return len(magicBytes) >= 8 && magicBytes[0] == 0xD0 && magicBytes[1] == 0xCF && magicBytes[2] == 0x11 && magicBytes[3] == 0xE0
|
||||
case ".docx":
|
||||
// DOCX (ZIP): 50 4B 03 04 (PK..)
|
||||
return len(magicBytes) >= 4 && magicBytes[0] == 0x50 && magicBytes[1] == 0x4B && magicBytes[2] == 0x03 && magicBytes[3] == 0x04
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
+1
-1
@@ -307,7 +307,7 @@ func main() {
|
||||
// 初始化控制器
|
||||
authController := controller.NewAuthController(authService)
|
||||
conversationController := controller.NewConversationController(conversationService, aiConfigService)
|
||||
messageController := controller.NewMessageController(messageService, storageService)
|
||||
messageController := controller.NewMessageController(messageService, conversationService, storageService)
|
||||
adminController := controller.NewAdminController(authService, userService)
|
||||
profileController := controller.NewProfileController(profileService)
|
||||
aiConfigController := controller.NewAIConfigController(aiConfigService)
|
||||
|
||||
@@ -2,6 +2,8 @@ package middleware
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-contrib/cors"
|
||||
@@ -26,3 +28,26 @@ func CORS() gin.HandlerFunc {
|
||||
AllowCredentials: false,
|
||||
})
|
||||
}
|
||||
|
||||
// RequireAuth 认证中间件:要求请求头中包含有效的 X-User-Id
|
||||
func RequireAuth() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
userIDStr := c.GetHeader("X-User-Id")
|
||||
if userIDStr == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "未授权访问,请提供 X-User-Id 请求头"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
userID, err := strconv.ParseUint(userIDStr, 10, 64)
|
||||
if err != nil || userID == 0 {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "用户ID不合法"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 将用户ID存储到上下文中,供后续使用
|
||||
c.Set("user_id", uint(userID))
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,7 +39,7 @@ func RegisterRoutes(r *gin.Engine, controllers ControllerSet, wsHandler gin.Hand
|
||||
|
||||
// Message
|
||||
r.POST("/messages", controllers.Message.CreateMessage)
|
||||
r.POST("/messages/upload", controllers.Message.UploadFile) // 文件上传接口
|
||||
r.POST("/messages/upload", controllers.Message.UploadFile) // 文件上传接口(支持客服和访客上传)
|
||||
r.GET("/messages", controllers.Message.ListMessages)
|
||||
r.PUT("/messages/read", controllers.Message.MarkMessagesRead)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user