修复文件上传接口

This commit is contained in:
537yaha
2026-02-27 12:49:18 +08:00
parent 9c981dd20b
commit 48cfccf422
5 changed files with 249 additions and 29 deletions
+50 -5
View File
@@ -4,7 +4,10 @@ import (
"context"
"log"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
"github.com/2930134478/AI-CS/backend/service"
"github.com/gin-gonic/gin"
@@ -12,14 +15,14 @@ import (
// ImportController is the controller for import requests.
// It holds the import service that performs the imports and the
// embedding-config service that is consulted for knowledge-base access checks.
type ImportController struct {
importService *service.ImportService
importService *service.ImportService
embeddingConfigService *service.EmbeddingConfigService
}
// NewImportController creates an ImportController instance wired to the
// given import service and embedding-config service.
func NewImportController(importService *service.ImportService, embeddingConfigService *service.EmbeddingConfigService) *ImportController {
return &ImportController{
importService: importService,
importService: importService,
embeddingConfigService: embeddingConfigService,
}
}
@@ -27,7 +30,9 @@ func NewImportController(importService *service.ImportService, embeddingConfigSe
func (c *ImportController) checkKBAccess(ctx *gin.Context) bool {
userID := getUserIDFromHeader(ctx)
if userID == 0 {
return true
// ⚠️ 修复:改为拒绝访问,而不是允许
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "未授权访问,请提供 X-User-Id 请求头"})
return false
}
if err := c.embeddingConfigService.CheckKnowledgeBaseAccess(userID); err != nil {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
@@ -67,11 +72,37 @@ func (c *ImportController) ImportDocuments(ctx *gin.Context) {
return
}
// ⚠️ 添加:文件类型验证
allowedExts := map[string]bool{
".md": true,
".txt": true,
".pdf": true,
".doc": true,
".docx": true,
}
// 保存文件到临时目录
filePaths := make([]string, 0, len(files))
for _, file := range files {
// ⚠️ 添加:验证文件类型
ext := strings.ToLower(filepath.Ext(file.Filename))
if !allowedExts[ext] {
log.Printf("不支持的文件类型: %s (扩展名: %s)", file.Filename, ext)
continue
}
// ⚠️ 添加:清理文件名,防止路径遍历攻击
safeFilename := filepath.Base(file.Filename)
safeFilename = strings.ReplaceAll(safeFilename, "..", "")
safeFilename = strings.ReplaceAll(safeFilename, "/", "")
safeFilename = strings.ReplaceAll(safeFilename, "\\", "")
// 限制文件名长度
if len(safeFilename) > 255 {
safeFilename = safeFilename[:255]
}
// 保存文件
filePath := "/tmp/" + file.Filename
filePath := "/tmp/" + safeFilename
if err := ctx.SaveUploadedFile(file, filePath); err != nil {
log.Printf("保存文件失败: %v", err)
continue
@@ -79,6 +110,20 @@ func (c *ImportController) ImportDocuments(ctx *gin.Context) {
filePaths = append(filePaths, filePath)
}
if len(filePaths) == 0 {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "没有有效的文件(所有文件都被拒绝或保存失败)"})
return
}
// ⚠️ 添加:导入后清理临时文件
defer func() {
for _, path := range filePaths {
if err := os.Remove(path); err != nil {
log.Printf("清理临时文件失败: %v", err)
}
}
}()
// 导入文件
result, err := c.importService.ImportFiles(context.Background(), uint(kbID), filePaths)
if err != nil {
@@ -98,7 +143,7 @@ func (c *ImportController) ImportFromURLs(ctx *gin.Context) {
}
var req struct {
KnowledgeBaseID uint `json:"knowledge_base_id" binding:"required"`
URLs []string `json:"urls" binding:"required"`
URLs []string `json:"urls" binding:"required"`
}
if err := ctx.ShouldBindJSON(&req); err != nil {
+172 -22
View File
@@ -1,6 +1,8 @@
package controller
import (
"bytes"
"io"
"log"
"net/http"
"path/filepath"
@@ -14,15 +16,17 @@ import (
// MessageController handles message-related HTTP requests.
// It holds the message service, the conversation service and the storage
// service used by its handlers.
type MessageController struct {
messageService *service.MessageService
storageService infra.StorageService
messageService *service.MessageService
conversationService *service.ConversationService
storageService infra.StorageService
}
// NewMessageController creates a MessageController instance from the
// services passed in and stores each one in the matching field.
func NewMessageController(messageService *service.MessageService, storageService infra.StorageService) *MessageController {
func NewMessageController(messageService *service.MessageService, conversationService *service.ConversationService, storageService infra.StorageService) *MessageController {
return &MessageController{
messageService: messageService,
storageService: storageService,
messageService: messageService,
conversationService: conversationService,
storageService: storageService,
}
}
@@ -141,7 +145,41 @@ func (mc *MessageController) MarkMessagesRead(c *gin.Context) {
// 请求格式:multipart/form-data
// - file: 文件内容(必需)
// - conversation_id: 对话ID(可选,用于组织目录)
// 认证方式:
// - 方式1:提供 X-User-Id 请求头(客服上传)
// - 方式2:提供 conversation_id 参数(访客上传,会验证对话是否存在且未关闭)
func (mc *MessageController) UploadFile(c *gin.Context) {
// ⚠️ 认证检查:必须满足以下条件之一
// 1. 提供 X-User-Id 请求头(客服)
// 2. 提供 conversation_id 参数(访客)
userID := getUserIDFromHeader(c)
conversationIDStr := c.PostForm("conversation_id")
// 如果既没有用户ID,也没有对话ID,拒绝访问
if userID == 0 && conversationIDStr == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "未授权访问,请提供 X-User-Id 请求头或 conversation_id 参数"})
return
}
// 如果是访客上传(没有用户ID,但有对话ID),验证对话是否存在且未关闭
if userID == 0 && conversationIDStr != "" {
convID, err := strconv.ParseUint(conversationIDStr, 10, 64)
if err != nil || convID == 0 {
c.JSON(http.StatusBadRequest, gin.H{"error": "对话ID不合法"})
return
}
// 验证对话是否存在且未关闭
conv, err := mc.conversationService.GetConversationDetail(uint(convID), 0)
if err != nil {
c.JSON(http.StatusForbidden, gin.H{"error": "对话不存在或已关闭"})
return
}
if conv.Status == "closed" {
c.JSON(http.StatusForbidden, gin.H{"error": "对话已关闭"})
return
}
}
// 解析文件
file, err := c.FormFile("file")
if err != nil {
@@ -156,7 +194,7 @@ func (mc *MessageController) UploadFile(c *gin.Context) {
return
}
// 验证文件类型
// ⚠️ 加强:验证文件类型(扩展名)
ext := strings.ToLower(filepath.Ext(file.Filename))
allowedExts := map[string]bool{
".jpg": true,
@@ -174,25 +212,85 @@ func (mc *MessageController) UploadFile(c *gin.Context) {
return
}
// 获取对话ID(可选)
// ⚠️ 加强:验证 MIME 类型(防止伪造扩展名)
mimeType := file.Header.Get("Content-Type")
allowedMimeTypes := map[string]bool{
"image/jpeg": true,
"image/jpg": true,
"image/png": true,
"image/gif": true,
"image/webp": true,
"application/pdf": true,
"application/msword": true,
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": true, // .docx
"text/plain": true,
}
if !allowedMimeTypes[mimeType] {
c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的文件 MIME 类型: " + mimeType})
return
}
// ⚠️ 加强:清理文件名,防止路径遍历攻击
safeFilename := filepath.Base(file.Filename)
safeFilename = strings.ReplaceAll(safeFilename, "..", "")
safeFilename = strings.ReplaceAll(safeFilename, "/", "")
safeFilename = strings.ReplaceAll(safeFilename, "\\", "")
// 移除所有非字母数字、点、下划线、连字符的字符
var cleaned strings.Builder
for _, r := range safeFilename {
if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '.' || r == '_' || r == '-' {
cleaned.WriteRune(r)
}
}
safeFilename = cleaned.String()
// 限制文件名长度
if len(safeFilename) > 100 {
// 保留扩展名
ext := filepath.Ext(safeFilename)
nameWithoutExt := strings.TrimSuffix(safeFilename, ext)
if len(nameWithoutExt) > 100-len(ext) {
safeFilename = nameWithoutExt[:100-len(ext)] + ext
}
}
// ⚠️ 加强:验证文件内容(magic number 检查,防止伪造扩展名)
fileContent, err := file.Open()
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无法读取文件"})
return
}
defer fileContent.Close()
// 读取文件前几个字节(magic number)
magicBytes := make([]byte, 12)
n, err := fileContent.Read(magicBytes)
if err != nil && err != io.EOF {
c.JSON(http.StatusBadRequest, gin.H{"error": "无法读取文件内容"})
return
}
// 验证文件内容是否匹配扩展名
if !isValidFileContent(ext, magicBytes[:n]) {
c.JSON(http.StatusBadRequest, gin.H{"error": "文件内容与扩展名不匹配,可能是伪造的文件类型"})
return
}
// 重置文件指针,以便后续保存
if _, err := fileContent.Seek(0, io.SeekStart); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "无法重置文件指针"})
return
}
// 获取对话ID(如果之前已经解析过,直接使用;否则从表单获取)
var conversationID uint
if conversationIDStr := c.PostForm("conversation_id"); conversationIDStr != "" {
if conversationIDStr != "" {
if id, err := strconv.ParseUint(conversationIDStr, 10, 64); err == nil {
conversationID = uint(id)
}
}
// 打开文件
src, err := file.Open()
if err != nil {
log.Printf("❌ 打开文件失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "打开文件失败"})
return
}
defer src.Close()
// 保存文件
fileURL, err := mc.storageService.SaveMessageFile(conversationID, src, file.Filename)
// 保存文件(使用清理后的文件名,fileContent 已经在上面打开并验证过)
fileURL, err := mc.storageService.SaveMessageFile(conversationID, fileContent, safeFilename)
if err != nil {
log.Printf("❌ 保存文件失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存文件失败"})
@@ -201,20 +299,72 @@ func (mc *MessageController) UploadFile(c *gin.Context) {
// 判断文件类型
fileType := "document"
mimeType := file.Header.Get("Content-Type")
if strings.HasPrefix(mimeType, "image/") {
fileType = "image"
}
// 返回文件信息
// 返回文件信息(使用清理后的文件名)
c.JSON(http.StatusOK, gin.H{
"success": true,
"data": gin.H{
"file_url": fileURL,
"file_type": fileType,
"file_name": file.Filename,
"file_name": safeFilename,
"file_size": file.Size,
"mime_type": mimeType,
},
})
}
// isValidFileContent reports whether the leading bytes of a file match the
// signature ("magic number") expected for its extension.
//
// Parameters:
//   - ext: the file extension including the leading dot (for example ".png");
//     it is matched case-insensitively.
//   - magicBytes: the first bytes of the file. Twelve bytes are enough for
//     every signature checked here (WebP inspects bytes 8-11).
//
// It returns false for empty content and for any extension it does not know.
// Plain text has no signature, so ".txt" is judged heuristically: the sample
// is accepted when it starts with a UTF-8 BOM or contains no control bytes
// other than tab, line feed and carriage return.
func isValidFileContent(ext string, magicBytes []byte) bool {
	ext = strings.ToLower(ext)

	if ext == ".txt" {
		// A text file may legitimately be shorter than any binary signature
		// (e.g. a two-byte note), so only a completely empty sample is refused.
		if len(magicBytes) == 0 {
			return false
		}
		// UTF-8 BOM: EF BB BF
		if bytes.HasPrefix(magicBytes, []byte{0xEF, 0xBB, 0xBF}) {
			return true
		}
		// Inspect the whole sample, not just a fixed-size prefix of it.
		for _, b := range magicBytes {
			if b < 0x20 && b != '\t' && b != '\n' && b != '\r' {
				return false
			}
		}
		return true
	}

	// Binary formats: anything shorter than four bytes cannot be identified.
	if len(magicBytes) < 4 {
		return false
	}

	switch ext {
	case ".jpg", ".jpeg":
		// JPEG: FF D8 FF
		return bytes.HasPrefix(magicBytes, []byte{0xFF, 0xD8, 0xFF})
	case ".png":
		// PNG: 89 50 4E 47
		return bytes.HasPrefix(magicBytes, []byte{0x89, 0x50, 0x4E, 0x47})
	case ".gif":
		// GIF: 47 49 46 38 ("GIF8", common to GIF87a and GIF89a)
		return bytes.HasPrefix(magicBytes, []byte("GIF8"))
	case ".webp":
		// WebP: "RIFF" <4-byte size> "WEBP"
		return len(magicBytes) >= 12 &&
			string(magicBytes[0:4]) == "RIFF" &&
			string(magicBytes[8:12]) == "WEBP"
	case ".pdf":
		// PDF: 25 50 44 46 ("%PDF")
		return bytes.HasPrefix(magicBytes, []byte("%PDF"))
	case ".doc":
		// DOC (OLE2 compound file): D0 CF 11 E0 A1 B1 1A E1.
		// All eight bytes are compared, not just the first four.
		return bytes.HasPrefix(magicBytes, []byte{0xD0, 0xCF, 0x11, 0xE0, 0xA1, 0xB1, 0x1A, 0xE1})
	case ".docx":
		// DOCX (ZIP container): 50 4B 03 04 ("PK\x03\x04")
		return bytes.HasPrefix(magicBytes, []byte{0x50, 0x4B, 0x03, 0x04})
	default:
		return false
	}
}
+1 -1
View File
@@ -307,7 +307,7 @@ func main() {
// 初始化控制器
authController := controller.NewAuthController(authService)
conversationController := controller.NewConversationController(conversationService, aiConfigService)
messageController := controller.NewMessageController(messageService, storageService)
messageController := controller.NewMessageController(messageService, conversationService, storageService)
adminController := controller.NewAdminController(authService, userService)
profileController := controller.NewProfileController(profileService)
aiConfigController := controller.NewAIConfigController(aiConfigService)
+25
View File
@@ -2,6 +2,8 @@ package middleware
import (
"log"
"net/http"
"strconv"
"time"
"github.com/gin-contrib/cors"
@@ -26,3 +28,26 @@ func CORS() gin.HandlerFunc {
AllowCredentials: false,
})
}
// RequireAuth returns a Gin middleware that only lets a request through when
// it carries a usable X-User-Id header.
//
// The header must be present and parse as a non-zero base-10 unsigned
// integer; otherwise the request is aborted with 401 and a JSON error body.
// On success the ID is stored in the context under "user_id" as a uint and
// the next handler runs.
//
// NOTE(review): the header value is taken at face value here — confirm that
// something upstream authenticates it.
func RequireAuth() gin.HandlerFunc {
	return func(c *gin.Context) {
		rawID := c.GetHeader("X-User-Id")
		if rawID == "" {
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "未授权访问,请提供 X-User-Id 请求头"})
			return
		}

		parsedID, parseErr := strconv.ParseUint(rawID, 10, 64)
		if parseErr != nil || parsedID == 0 {
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "用户ID不合法"})
			return
		}

		// Make the caller's ID available to downstream handlers.
		c.Set("user_id", uint(parsedID))
		c.Next()
	}
}
+1 -1
View File
@@ -39,7 +39,7 @@ func RegisterRoutes(r *gin.Engine, controllers ControllerSet, wsHandler gin.Hand
// Message
r.POST("/messages", controllers.Message.CreateMessage)
r.POST("/messages/upload", controllers.Message.UploadFile) // 文件上传接口
r.POST("/messages/upload", controllers.Message.UploadFile) // 文件上传接口(支持客服和访客上传)
r.GET("/messages", controllers.Message.ListMessages)
r.PUT("/messages/read", controllers.Message.MarkMessagesRead)