Files
new-api/relay/image_handler.go
T
admin ced375076c
Docker Build / Build and Push Docker Image (push) Successful in 3m57s
feat: save generated images to local server storage
2026-06-22 02:15:38 +08:00

297 lines
10 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package relay
import (
"bytes"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/relay/helper"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting/model_setting"
"github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
)
// ImageHelper relays an image-generation request to the upstream channel
// selected in info and charges quota for the result.
//
// Flow: deep-copy the request, apply model mapping, build the upstream body
// (either the client's raw body in pass-through mode, or the adaptor-converted
// request with param overrides applied), send it with the adaptor, and let the
// adaptor process the response. On success, quota is charged through
// service.PostTextConsumeQuota. On failure, logImageError records a
// zero-quota consume log with the reason.
//
// When the client asked for a streaming response, SSE headers are committed
// immediately and keepalive comments are sent every imageKeepaliveInterval
// for the whole DoRequest + DoResponse window. Image generation can take 60s
// or more, and reverse proxies with idle timeouts (e.g. Nginx
// proxy_read_timeout, 60s by default) would otherwise drop the connection.
func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
	info.InitChannelMeta(c)

	imageReq, ok := info.Request.(*dto.ImageRequest)
	if !ok {
		return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected dto.ImageRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
	}

	// Everything below works on this copy; info.Request is not passed on.
	request, err := common.DeepCopy(imageReq)
	if err != nil {
		return types.NewError(fmt.Errorf("failed to copy request to ImageRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
	}

	err = helper.ModelMappedHelper(c, info, request)
	if err != nil {
		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
	}

	adaptor := GetAdaptor(info.ApiType)
	if adaptor == nil {
		return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
	}
	adaptor.Init(info)

	var requestBody io.Reader
	if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
		// Pass-through mode: forward the client's stored request body unchanged.
		storage, err := common.GetBodyStorage(c)
		if err != nil {
			return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
		}
		requestBody = common.ReaderOnly(storage)
	} else {
		convertedRequest, err := adaptor.ConvertImageRequest(c, info, *request)
		if err != nil {
			return types.NewError(err, types.ErrorCodeConvertRequestFailed)
		}
		relaycommon.AppendRequestConversionFromRequest(info, convertedRequest)
		switch body := convertedRequest.(type) {
		case *bytes.Buffer:
			// The adaptor already produced an encoded body; send it as-is.
			requestBody = body
		default:
			jsonData, err := common.Marshal(convertedRequest)
			if err != nil {
				return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
			}
			// Apply per-channel param overrides to the serialized JSON.
			if len(info.ParamOverride) > 0 {
				jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info)
				if err != nil {
					return newAPIErrorFromParamOverride(err)
				}
			}
			logger.LogDebug(c, "image request body: %s", jsonData)
			outbound, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData)
			if err != nil {
				return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
			}
			defer closer.Close()
			jsonData = nil // only the outbound body is used from here on
			info.UpstreamRequestBodySize = size
			requestBody = outbound
		}
	}

	statusCodeMappingStr := c.GetString("status_code_mapping")

	// sseStarted records whether a 200 + SSE response has been committed to
	// the client. Only then must later errors be delivered as in-stream
	// events. info.IsStream alone is not enough: it is also switched to true
	// below when the upstream answers with text/event-stream.
	sseStarted := info.IsStream
	stopKeepalive := func() {}
	if sseStarted {
		stopKeepalive = startImageStreamKeepalive(c)
		// Ensures the keepalive goroutine has exited (and c.Writer is
		// restored) before this handler returns, even if DoRequest or
		// DoResponse panics. A *gin.Context must not be used from a goroutine
		// after the handler returns (see gin's Context.Copy documentation).
		defer stopKeepalive()
	}

	resp, err := adaptor.DoRequest(c, info, requestBody)
	// Do NOT stop the keepalive here on success. For SSE upstreams DoRequest
	// returns as soon as response headers arrive; the real wait happens inside
	// DoResponse while it reads the upstream events.
	if err != nil {
		stopKeepalive()
		logImageError(c, info, request, fmt.Sprintf("请求失败: %s", err.Error()))
		if sseStarted {
			writeImageStreamError(c, err.Error())
		}
		return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
	}

	var httpResp *http.Response
	if resp != nil {
		httpResp = resp.(*http.Response)
		info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
		if httpResp.StatusCode == http.StatusCreated && info.ApiType == constant.APITypeReplicate {
			// replicate channel returns 201 Created when using Prefer: wait, treat it as success.
			httpResp.StatusCode = http.StatusOK
		}
		if httpResp.StatusCode != http.StatusOK {
			stopKeepalive()
			newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
			// Apply the channel's configured status-code mapping.
			service.ResetStatusCode(newAPIError, statusCodeMappingStr)
			logImageError(c, info, request, fmt.Sprintf("上游返回错误 (HTTP %d): %s", newAPIError.StatusCode, newAPIError.Error()))
			if sseStarted {
				writeImageStreamError(c, newAPIError.Error())
			}
			return newAPIError
		}
	}

	usage, newAPIError := adaptor.DoResponse(c, httpResp, info)
	// DoResponse has finished; stop the keepalive before this goroutine
	// touches the response again.
	stopKeepalive()
	if newAPIError != nil {
		// Apply the channel's configured status-code mapping.
		service.ResetStatusCode(newAPIError, statusCodeMappingStr)
		logImageError(c, info, request, fmt.Sprintf("响应处理失败: %s", newAPIError.Error()))
		// NOTE(review): unlike the DoRequest/upstream-status paths, no
		// in-stream error event is written here even when sseStarted; confirm
		// DoResponse reports the failure to the client itself.
		return newAPIError
	}

	usageInfo, ok := usage.(*dto.Usage)
	if !ok || usageInfo == nil {
		// An adaptor that returns no usage would otherwise cause a panic on
		// the type assertion; fall through to the minimum usage set below.
		usageInfo = &dto.Usage{}
	}

	imageN := uint(1)
	if request.N != nil {
		imageN = *request.N
	}
	// n is handled via OtherRatio so it is applied exactly once in quota
	// calculation (both price-based and ratio-based paths).
	// Adaptors may have already set a more accurate count from the
	// upstream response; only set the default when they haven't.
	if info.PriceData.UsePrice { // only price model use N ratio
		if _, hasN := info.PriceData.OtherRatios["n"]; !hasN {
			info.PriceData.AddOtherRatio("n", float64(imageN))
		}
	}
	if usageInfo.TotalTokens == 0 {
		usageInfo.TotalTokens = 1
	}
	if usageInfo.PromptTokens == 0 {
		usageInfo.PromptTokens = 1
	}

	quality := request.Quality
	if quality == "" {
		quality = "standard"
	}
	var logContent []string
	if len(request.Size) > 0 {
		logContent = append(logContent, fmt.Sprintf("大小 %s", request.Size))
	}
	if len(quality) > 0 {
		logContent = append(logContent, fmt.Sprintf("品质 %s", quality))
	}
	if imageN > 0 {
		logContent = append(logContent, fmt.Sprintf("生成数量 %d", imageN))
	}
	// If the model is configured for per-size billing, inject the size tier
	// and image count into the gin context so PostTextConsumeQuota can apply
	// the flat per-image surcharge instead of token billing.
	if operation_setting.IsImagePerSizeBilling(info.OriginModelName) {
		sizeTier, ok := operation_setting.ClassifyImageSizeTier(request.Size)
		if !ok {
			sizeTier = operation_setting.ImageSizeTier2K // default to 2K when unknown
		}
		c.Set("image_per_size_billing", true)
		c.Set("image_size_tier", sizeTier)
		c.Set("image_per_size_count", int(imageN))
		logContent = append(logContent, fmt.Sprintf("分辨率档位 %s", sizeTier))
	}
	service.PostTextConsumeQuota(c, info, usageInfo, logContent)
	return nil
}

// imageKeepaliveInterval is how often a keepalive comment is sent while a
// streaming image request is in flight. It must stay below the idle timeout
// of any reverse proxy in front of the server (Nginx proxy_read_timeout
// defaults to 60s).
const imageKeepaliveInterval = 15 * time.Second

// imageKeepaliveComment is an SSE comment line. SSE clients ignore comment
// lines, but they keep bytes flowing on an otherwise idle connection.
const imageKeepaliveComment = ": keepalive\n\n"

// startImageStreamKeepalive commits a 200 SSE response, sends one keepalive
// comment right away, and starts a goroutine that sends another every
// imageKeepaliveInterval. The goroutine runs until the returned stop function
// is called or the client's request context is done.
//
// While it runs, c.Writer is replaced by an imageStreamWriter, so the
// goroutine and the request goroutine (e.g. the adaptor's DoResponse) never
// write to the response at the same time. stop blocks until the goroutine has
// exited and then restores the original c.Writer. stop is idempotent but must
// only be called from the request goroutine.
func startImageStreamKeepalive(c *gin.Context) (stop func()) {
	helper.SetEventStreamHeaders(c)
	c.Status(http.StatusOK)

	orig := c.Writer
	w := newImageStreamWriter(orig)
	c.Writer = w
	w.writeKeepalive()
	_ = helper.FlushWriter(c)

	var clientGone <-chan struct{}
	if c.Request != nil {
		clientGone = c.Request.Context().Done()
	}
	quit := make(chan struct{})
	exited := make(chan struct{})
	go func() {
		defer close(exited)
		ticker := time.NewTicker(imageKeepaliveInterval)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				if w.writeKeepalive() {
					// Best-effort: a failed flush just means the client is gone.
					_ = helper.FlushWriter(c)
				}
			case <-clientGone:
				return
			case <-quit:
				return
			}
		}
	}()

	stopped := false
	return func() {
		if stopped {
			return
		}
		stopped = true
		close(quit)
		<-exited // no keepalive write can happen after this point
		if c.Writer == w {
			c.Writer = orig
		}
	}
}

// imageStreamWriter serializes body writes, flushes and status/size access on
// a gin.ResponseWriter shared by the keepalive goroutine and the request
// goroutine. It also remembers the last two bytes written. Keepalive comments
// are therefore only inserted at SSE event boundaries (after a blank line),
// never in the middle of an event the adaptor writes in several chunks.
type imageStreamWriter struct {
	gin.ResponseWriter
	sem  chan struct{} // 1-slot semaphore used as a mutex
	tail [2]byte       // last two bytes written to the body; guarded by sem
}

// newImageStreamWriter wraps rw. An empty body counts as being at an event
// boundary, so the tail starts as a blank line.
func newImageStreamWriter(rw gin.ResponseWriter) *imageStreamWriter {
	return &imageStreamWriter{
		ResponseWriter: rw,
		sem:            make(chan struct{}, 1),
		tail:           [2]byte{'\n', '\n'},
	}
}

func (w *imageStreamWriter) lock()   { w.sem <- struct{}{} }
func (w *imageStreamWriter) unlock() { <-w.sem }

// pushTail shifts b into the two-byte tail. The caller must hold the lock.
func (w *imageStreamWriter) pushTail(b byte) {
	w.tail[0], w.tail[1] = w.tail[1], b
}

// writeStringLocked writes s and updates the tail. The caller must hold the lock.
func (w *imageStreamWriter) writeStringLocked(s string) (int, error) {
	n, err := w.ResponseWriter.WriteString(s)
	for i := n - 2; i < n; i++ {
		if i >= 0 {
			w.pushTail(s[i])
		}
	}
	return n, err
}

func (w *imageStreamWriter) Write(p []byte) (int, error) {
	w.lock()
	defer w.unlock()
	n, err := w.ResponseWriter.Write(p)
	for i := n - 2; i < n; i++ {
		if i >= 0 {
			w.pushTail(p[i])
		}
	}
	return n, err
}

func (w *imageStreamWriter) WriteString(s string) (int, error) {
	w.lock()
	defer w.unlock()
	return w.writeStringLocked(s)
}

func (w *imageStreamWriter) WriteHeader(code int) {
	w.lock()
	defer w.unlock()
	w.ResponseWriter.WriteHeader(code)
}

func (w *imageStreamWriter) WriteHeaderNow() {
	w.lock()
	defer w.unlock()
	w.ResponseWriter.WriteHeaderNow()
}

func (w *imageStreamWriter) Flush() {
	w.lock()
	defer w.unlock()
	w.ResponseWriter.Flush()
}

func (w *imageStreamWriter) Status() int {
	w.lock()
	defer w.unlock()
	return w.ResponseWriter.Status()
}

func (w *imageStreamWriter) Size() int {
	w.lock()
	defer w.unlock()
	return w.ResponseWriter.Size()
}

func (w *imageStreamWriter) Written() bool {
	w.lock()
	defer w.unlock()
	return w.ResponseWriter.Written()
}

// writeKeepalive writes imageKeepaliveComment if the body currently ends on
// an SSE event boundary. It reports whether any bytes were written, so the
// caller knows whether a flush is worthwhile.
func (w *imageStreamWriter) writeKeepalive() bool {
	w.lock()
	defer w.unlock()
	if w.tail != [2]byte{'\n', '\n'} {
		return false
	}
	n, _ := w.writeStringLocked(imageKeepaliveComment)
	return n > 0
}
// logImageError records a zero-quota consume log for a failed image
// generation request, so users can see the failure reason in their usage logs.
//
// The log content summarizes the request (size, quality, count) followed by
// the failure message. The Other map carries "is_image_error" = true, the raw
// errMsg under "error_message", and the request path when available.
//
// It does nothing when info is nil. request may be nil: the request summary
// is then omitted, but the failure itself is still logged.
func logImageError(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ImageRequest, errMsg string) {
	if info == nil {
		return
	}

	var logContent []string
	if request != nil {
		if len(request.Size) > 0 {
			logContent = append(logContent, fmt.Sprintf("大小 %s", request.Size))
		}
		quality := request.Quality
		if quality == "" {
			quality = "standard"
		}
		logContent = append(logContent, fmt.Sprintf("品质 %s", quality))
		imageN := 1
		if request.N != nil {
			imageN = int(*request.N)
		}
		if imageN > 0 {
			logContent = append(logContent, fmt.Sprintf("生成数量 %d", imageN))
		}
	}
	logContent = append(logContent, fmt.Sprintf("失败: %s", errMsg))

	other := map[string]any{
		"is_image_error": true,
		"error_message":  errMsg,
	}
	if c.Request != nil && c.Request.URL != nil {
		other["request_path"] = c.Request.URL.Path
	}

	model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
		ChannelId: info.ChannelId,
		ModelName: info.OriginModelName,
		TokenName: c.GetString("token_name"),
		TokenId:   info.TokenId,
		Quota:     0,
		Content:   strings.Join(logContent, ", "),
		IsStream:  false,
		Group:     info.UsingGroup,
		Other:     other,
	})
}
// writeImageStreamError writes an SSE "error" event followed by the
// "data: [DONE]" sentinel, then flushes.
//
// Use it only after SSE headers (status 200) have already been sent to the
// client. At that point the HTTP status can no longer change, so the error
// must be delivered inside the stream.
//
// The message is always JSON-encoded. Splicing errMsg raw into the payload
// would produce invalid JSON on quotes or backslashes. A newline would also
// end the "data:" line and let the rest be parsed as separate SSE fields.
func writeImageStreamError(c *gin.Context, errMsg string) {
	data, err := common.Marshal(map[string]any{
		"type":    "error",
		"message": errMsg,
	})
	if err != nil {
		// Marshalling a map of strings is not expected to fail; if it does,
		// send a fixed, well-formed payload instead of hand-built JSON.
		fmt.Fprint(c.Writer, "event: error\ndata: {\"type\":\"error\",\"message\":\"internal error\"}\n\n")
	} else {
		fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", data)
	}
	fmt.Fprint(c.Writer, "data: [DONE]\n\n")
	// Best-effort: the client may already be gone.
	_ = helper.FlushWriter(c)
}