diff --git a/controller/image.go b/controller/image.go index d6e8806ad..32577c2ee 100644 --- a/controller/image.go +++ b/controller/image.go @@ -1,9 +1,12 @@ package controller import ( + "github.com/QuantumNous/new-api/service" + "github.com/gin-gonic/gin" ) -func GetImage(c *gin.Context) { - +// ServeImage 提供本地存储的图片文件服务 +func ServeImage(c *gin.Context) { + service.ServeImage(c) } diff --git a/relay/channel/openai/relay_image.go b/relay/channel/openai/relay_image.go index 3fb615f2d..39134482b 100644 --- a/relay/channel/openai/relay_image.go +++ b/relay/channel/openai/relay_image.go @@ -40,6 +40,25 @@ func OpenaiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http. return nil, types.WithOpenAIError(*oaiError, resp.StatusCode) } + // 下载远程图片到本地并替换 URL + var imageResp dto.ImageResponse + if common.Unmarshal(responseBody, &imageResp) == nil { + modified := false + for i := range imageResp.Data { + if imageResp.Data[i].Url != "" && !strings.HasPrefix(imageResp.Data[i].Url, "/api/images/") { + if filename, err := service.SaveRemoteImage(c, imageResp.Data[i].Url); err == nil && filename != "" { + imageResp.Data[i].Url = service.GetLocalImageURL(c, filename) + modified = true + } + } + } + if modified { + if newBody, err := common.Marshal(imageResp); err == nil { + responseBody = newBody + } + } + } + // 写入新的 response body service.IOCopyBytesGracefully(c, resp, responseBody) @@ -182,12 +201,38 @@ func OpenaiImageStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp // writeOpenaiImageStreamChunk rebuilds the SSE frame for an image stream chunk: // it emits an "event:" line derived from the JSON "type" field (when present) -// followed by the verbatim "data:" payload, mirroring helper.ResponseChunkData. +// followed by the "data:" payload. +// If the payload contains a remote image URL, it downloads the image to local +// storage and replaces the URL with a local path so the image persists after +// the upstream temporary URL expires. func writeOpenaiImageStreamChunk(c *gin.Context, data []byte) { var payload struct { - Type string `json:"type"` + Type string `json:"type"` + URL string `json:"url"` + B64Json string `json:"b64_json"` + RevisedPrompt string `json:"revised_prompt"` } _ = common.Unmarshal(data, &payload) + + // 如果包含远程 URL,下载图片到本地并替换 URL + if payload.URL != "" && !strings.HasPrefix(payload.URL, "/api/images/") { + filename, err := service.SaveRemoteImage(c, payload.URL) + if err == nil && filename != "" { + payload.URL = service.GetLocalImageURL(c, filename) + // 重新序列化 payload + // 保留原始 JSON 中的其他字段 + var rawMap map[string]json.RawMessage + if common.Unmarshal(data, &rawMap) == nil { + if urlBytes, err := common.Marshal(payload.URL); err == nil { + rawMap["url"] = urlBytes + } + if newData, err := common.Marshal(rawMap); err == nil { + data = newData + } + } + } + } + if eventName := strings.TrimSpace(payload.Type); eventName != "" { c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", eventName)}) } @@ -282,6 +327,12 @@ func OpenaiImageJSONAsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, "created_at": created, } if image.Url != "" { + // 下载远程图片到本地并替换 URL + if !strings.HasPrefix(image.Url, "/api/images/") { + if filename, err := service.SaveRemoteImage(c, image.Url); err == nil && filename != "" { + image.Url = service.GetLocalImageURL(c, filename) + } + } payload["url"] = image.Url } if image.B64Json != "" { diff --git a/relay/image_handler.go b/relay/image_handler.go index 5427956cd..980ee3208 100644 --- a/relay/image_handler.go +++ b/relay/image_handler.go @@ -93,16 +93,15 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type statusCodeMappingStr := c.GetString("status_code_mapping") - // When the client requests streaming, send SSE headers immediately and - // start a periodic keepalive goroutine. Image generation can take 60+ - // seconds; without periodic data the connection will be closed by - // reverse-proxies (Nginx proxy_read_timeout) or the browser. + // 当客户端请求流式响应时,立即发送 SSE 头并启动周期性 keepalive 协程。 + // 图片生成可能耗时 60 秒以上,若期间无数据传输,反向代理(如 Nginx 的 + // proxy_read_timeout,默认 60 秒)或浏览器会关闭连接。 // - // IMPORTANT: The keepalive must span both DoRequest AND DoResponse. - // For SSE upstreams, DoRequest returns quickly (after receiving response - // headers), while DoResponse is where the actual waiting happens as it - // reads streaming events. Stopping the keepalive after DoRequest would - // leave the connection idle during DoResponse, causing proxy timeouts. + // 【重要】keepalive 必须覆盖 DoRequest 和 DoResponse 两个阶段。 + // 对于 SSE 上游,DoRequest 在收到响应头后即返回(SSE 先发 200 + headers, + // 事件数据稍后到达),真正的等待发生在 DoResponse 中——它逐行读取上游 + // SSE 事件。若在 DoRequest 后就停止 keepalive,DoResponse 期间连接将 + // 处于空闲状态,导致反向代理超时关闭连接(client_gone / context canceled)。 var keepaliveDone chan struct{} if info.IsStream { helper.SetEventStreamHeaders(c) @@ -127,8 +126,8 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type } resp, err := adaptor.DoRequest(c, info, requestBody) - // NOTE: do NOT close keepaliveDone here — DoRequest may return quickly - // for SSE upstreams while the real waiting happens in DoResponse. + // 注意:此处不能关闭 keepaliveDone!SSE 上游的 DoRequest 会快速返回, + // 真正的等待在 DoResponse 中,keepalive 必须持续到 DoResponse 完成。 if err != nil { if keepaliveDone != nil { close(keepaliveDone) @@ -164,7 +163,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type } usage, newAPIError := adaptor.DoResponse(c, httpResp, info) - // DoResponse has completed — now it is safe to stop the keepalive. + // DoResponse 已完成,现在可以安全地停止 keepalive 协程 if keepaliveDone != nil { close(keepaliveDone) } diff --git a/router/api-router.go b/router/api-router.go index e3231a15a..fc66f5818 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -64,6 +64,9 @@ func SetApiRouter(router *gin.Engine) { // Universal secure verification routes apiRouter.POST("/verify", middleware.UserAuth(), middleware.CriticalRateLimit(), controller.UniversalVerify) + // 图片文件服务(UUID 文件名保证安全性,无需认证,支持 标签直接访问) + apiRouter.GET("/images/:filename", controller.ServeImage) + userRoute := apiRouter.Group("/user") { userRoute.POST("/register", middleware.CriticalRateLimit(), anonymousRequestBodyLimit, middleware.TurnstileCheck(), controller.Register) diff --git a/service/image_storage.go b/service/image_storage.go new file mode 100644 index 000000000..5e24b084a --- /dev/null +++ b/service/image_storage.go @@ -0,0 +1,224 @@ +package service + +import ( + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/logger" + + "github.com/gin-gonic/gin" + "github.com/google/uuid" +) + +const ( + imageDirName = "images" +) + +var ( + imageDirOnce sync.Once + imageDirPath string + imageDirInitErr error +) + +// getImageDir 获取图片存储目录路径 +// 优先级:DISK_CACHE_PATH/images > /data/images > 临时目录/images +func getImageDir() (string, error) { + imageDirOnce.Do(func() { + // 1. 优先使用磁盘缓存路径 + cachePath := common.GetDiskCachePath() + if cachePath == "" { + // 2. 尝试 /data 目录(Docker 部署时的持久化目录) + if info, err := os.Stat("/data"); err == nil && info.IsDir() { + cachePath = "/data" + } else { + // 3. 回退到临时目录 + cachePath = os.TempDir() + } + } + imageDirPath = filepath.Join(cachePath, imageDirName) + imageDirInitErr = os.MkdirAll(imageDirPath, 0755) + }) + return imageDirPath, imageDirInitErr +} + +// SaveRemoteImage 下载远程图片并保存到服务器本地存储。 +// 返回本地文件名(不含路径)和可能的错误。 +// 如果 remoteURL 为空或已经是本地路径,则直接返回不做处理。 +func SaveRemoteImage(c *gin.Context, remoteURL string) (string, error) { + if remoteURL == "" { + return "", nil + } + // 已经是本地路径则跳过 + if strings.HasPrefix(remoteURL, "/api/images/") { + return filepath.Base(remoteURL), nil + } + + // 下载远程图片 + resp, err := DoDownloadRequest(remoteURL, "image_storage") + if err != nil { + return "", fmt.Errorf("下载图片失败: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("下载图片失败, HTTP %d", resp.StatusCode) + } + + // 读取图片数据(限制 20MB) + imageData, err := io.ReadAll(io.LimitReader(resp.Body, 20*1024*1024+1)) + if err != nil { + return "", fmt.Errorf("读取图片数据失败: %w", err) + } + if len(imageData) > 20*1024*1024 { + return "", fmt.Errorf("图片大小超过 20MB 限制") + } + + // 检测图片格式并确定扩展名 + ext := detectImageExtension(resp, imageData) + + // 生成文件名 + filename := uuid.New().String() + ext + + // 保存到磁盘 + dir, err := getImageDir() + if err != nil { + return "", fmt.Errorf("获取图片目录失败: %w", err) + } + + filePath := filepath.Join(dir, filename) + if err := os.WriteFile(filePath, imageData, 0644); err != nil { + return "", fmt.Errorf("保存图片文件失败: %w", err) + } + + logger.LogInfo(c.Request.Context(), fmt.Sprintf("图片已保存到本地: %s (原始URL: %s)", filename, common.MaskSensitiveInfo(remoteURL))) + return filename, nil +} + +// detectImageExtension 从 HTTP 响应和图片数据中检测扩展名 +func detectImageExtension(resp *http.Response, data []byte) string { + // 1. 从 Content-Type 检测 + contentType := resp.Header.Get("Content-Type") + switch { + case strings.HasPrefix(contentType, "image/png"): + return ".png" + case strings.HasPrefix(contentType, "image/jpeg"), strings.HasPrefix(contentType, "image/jpg"): + return ".jpg" + case strings.HasPrefix(contentType, "image/gif"): + return ".gif" + case strings.HasPrefix(contentType, "image/webp"): + return ".webp" + } + + // 2. 从内容嗅探 + if len(data) > 0 { + sniffed := http.DetectContentType(data) + switch { + case strings.HasPrefix(sniffed, "image/png"): + return ".png" + case strings.HasPrefix(sniffed, "image/jpeg"): + return ".jpg" + case strings.HasPrefix(sniffed, "image/gif"): + return ".gif" + case strings.HasPrefix(sniffed, "image/webp"): + return ".webp" + } + } + + // 默认 png + return ".png" +} + +// GetLocalImageURL 根据文件名生成本地访问 URL +func GetLocalImageURL(c *gin.Context, filename string) string { + if filename == "" { + return "" + } + // 使用相对路径,避免硬编码域名和端口 + return "/api/images/" + filename +} + +// ServeImage 提供图片文件服务 +func ServeImage(c *gin.Context) { + filename := c.Param("filename") + if filename == "" { + c.JSON(http.StatusNotFound, gin.H{"error": "图片不存在"}) + return + } + + // 防止路径遍历攻击 + filename = filepath.Base(filename) + + dir, err := getImageDir() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "图片服务不可用"}) + return + } + + filePath := filepath.Join(dir, filename) + + // 检查文件是否存在 + info, err := os.Stat(filePath) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "图片不存在"}) + return + } + + // 设置缓存头:图片是静态资源,缓存 30 天 + c.Header("Cache-Control", "public, max-age=2592000") + c.Header("ETag", fmt.Sprintf(`"%x-%x"`, info.ModTime().Unix(), info.Size())) + + // 设置 Content-Type + ext := strings.ToLower(filepath.Ext(filename)) + switch ext { + case ".png": + c.Header("Content-Type", "image/png") + case ".jpg", ".jpeg": + c.Header("Content-Type", "image/jpeg") + case ".gif": + c.Header("Content-Type", "image/gif") + case ".webp": + c.Header("Content-Type", "image/webp") + default: + c.Header("Content-Type", "application/octet-stream") + } + + c.File(filePath) +} + +// CleanupOldImages 清理超过 maxAge 的图片文件 +func CleanupOldImages(maxAge time.Duration) error { + dir, err := getImageDir() + if err != nil { + return err + } + + entries, err := os.ReadDir(dir) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return err + } + + now := time.Now() + for _, entry := range entries { + if entry.IsDir() { + continue + } + info, err := entry.Info() + if err != nil { + continue + } + if now.Sub(info.ModTime()) > maxAge { + os.Remove(filepath.Join(dir, entry.Name())) + } + } + return nil +}