Files
new-api/relay/image_handler.go
2026-06-23 05:05:25 +08:00

357 lines
12 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package relay
import (
	"bytes"
	"fmt"
	"io"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/constant"
	"github.com/QuantumNous/new-api/dto"
	"github.com/QuantumNous/new-api/logger"
	"github.com/QuantumNous/new-api/model"
	relaycommon "github.com/QuantumNous/new-api/relay/common"
	"github.com/QuantumNous/new-api/relay/helper"
	"github.com/QuantumNous/new-api/service"
	"github.com/QuantumNous/new-api/setting/model_setting"
	"github.com/QuantumNous/new-api/setting/operation_setting"
	"github.com/QuantumNous/new-api/types"
	"github.com/gin-gonic/gin"
)
// determineActualSizeTier returns the billing resolution tier for an image
// response.
//
// It prefers the real pixel dimensions of the image files saved locally
// during DoResponse, whose filenames are read from the "local_image_files"
// gin context key. The longest edge across all readable images selects the
// tier: up to 1024 is 1K, up to 2048 is 2K, anything larger is 4K.
//
// When no dimensions can be obtained (no local files, e.g. b64_json
// responses, or every lookup failed), it classifies requestSize instead.
// If requestSize is not recognised either, it returns the 2K tier.
func determineActualSizeTier(c *gin.Context, requestSize string) string {
	stored, _ := c.Get("local_image_files")
	paths, _ := stored.([]string)

	longestEdge := 0
	for _, path := range paths {
		width, height, err := service.GetImageDimensions(path)
		if err != nil {
			logger.LogDebug(c, "failed to get image dimensions for %s: %v", path, err)
			continue
		}
		if width > longestEdge {
			longestEdge = width
		}
		if height > longestEdge {
			longestEdge = height
		}
	}

	switch {
	case longestEdge == 0:
		// Nothing could be measured; use the request's size parameter below.
	case longestEdge <= 1024:
		return operation_setting.ImageSizeTier1K
	case longestEdge <= 2048:
		return operation_setting.ImageSizeTier2K
	default:
		return operation_setting.ImageSizeTier4K
	}

	if tier, ok := operation_setting.ClassifyImageSizeTier(requestSize); ok {
		return tier
	}
	// Unknown size: default to the 2K tier.
	return operation_setting.ImageSizeTier2K
}
// imageStreamKeepaliveInterval is how often ImageHelper writes an SSE comment
// to a streaming client while it waits for the upstream image. It must stay
// well below reverse-proxy idle timeouts such as Nginx's proxy_read_timeout
// (60 seconds by default).
const imageStreamKeepaliveInterval = 15 * time.Second

// ImageHelper relays an image request to the upstream channel selected in
// info and records usage/quota for it.
//
// The request body is either passed through unchanged (global or channel
// pass-through setting) or converted by the channel adaptor, with param
// overrides applied to JSON bodies.
//
// When the client asked for a streaming response, SSE headers are committed
// immediately and a keepalive comment is written periodically until the
// upstream response has been processed. After those headers are sent,
// request and upstream-status errors are also reported inside the stream via
// writeImageStreamError.
//
// Failures are written to the consume log via logImageError. On success,
// quota is charged with service.PostTextConsumeQuota. For models configured
// for per-size billing, the billed image count and resolution tier come from
// the actual response rather than the request.
func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
	info.InitChannelMeta(c)

	imageReq, ok := info.Request.(*dto.ImageRequest)
	if !ok {
		return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected dto.ImageRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
	}
	// Work on a deep copy rather than info.Request itself.
	request, err := common.DeepCopy(imageReq)
	if err != nil {
		return types.NewError(fmt.Errorf("failed to copy request to ImageRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
	}

	err = helper.ModelMappedHelper(c, info, request)
	if err != nil {
		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
	}

	adaptor := GetAdaptor(info.ApiType)
	if adaptor == nil {
		return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
	}
	adaptor.Init(info)

	var requestBody io.Reader
	if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
		// Pass-through: forward the client's stored request body unchanged.
		storage, err := common.GetBodyStorage(c)
		if err != nil {
			return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
		}
		requestBody = common.ReaderOnly(storage)
	} else {
		convertedRequest, err := adaptor.ConvertImageRequest(c, info, *request)
		if err != nil {
			return types.NewError(err, types.ErrorCodeConvertRequestFailed)
		}
		relaycommon.AppendRequestConversionFromRequest(info, convertedRequest)
		switch buf := convertedRequest.(type) {
		case *bytes.Buffer:
			// The adaptor already produced the wire body. Forward it as-is;
			// param override is not applied on this path.
			requestBody = buf
		default:
			jsonData, err := common.Marshal(convertedRequest)
			if err != nil {
				return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
			}
			if len(info.ParamOverride) > 0 {
				jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info)
				if err != nil {
					return newAPIErrorFromParamOverride(err)
				}
			}
			logger.LogDebug(c, "image request body: %s", jsonData)
			body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData)
			if err != nil {
				return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
			}
			// body is read by DoRequest below, so release it only when
			// ImageHelper returns.
			defer closer.Close()
			// Drop this reference to the marshalled JSON; the outbound body
			// is used from here on.
			jsonData = nil
			info.UpstreamRequestBodySize = size
			requestBody = body
		}
	}

	statusCodeMappingStr := c.GetString("status_code_mapping")

	// When the client requested a streaming response, send the SSE headers now
	// and start a periodic keepalive. Image generation can take 60 seconds or
	// more; without any traffic, reverse proxies (e.g. Nginx
	// proxy_read_timeout, 60s by default) or browsers close the connection.
	//
	// The keepalive must cover both DoRequest and DoResponse. For SSE
	// upstreams, DoRequest returns as soon as response headers arrive. The
	// real wait happens in DoResponse while it reads upstream events. Stopping
	// after DoRequest would leave the connection idle during DoResponse and
	// lead to proxy timeouts (client_gone / context canceled).
	//
	// sseHeadersSent is captured before the upstream Content-Type can flip
	// info.IsStream. Only if we committed SSE headers ourselves may an error be
	// written into the stream.
	sseHeadersSent := info.IsStream
	stopKeepalive := func() {}
	if sseHeadersSent {
		stopKeepalive = startImageStreamKeepalive(c, imageStreamKeepaliveInterval)
	}
	// Guarantees the keepalive goroutine is gone on every return path,
	// including panics in the adaptor. stopKeepalive is idempotent.
	defer stopKeepalive()

	// NOTE(review): while DoResponse runs, the keepalive goroutine and the
	// adaptor both write to c.Writer without a shared lock, so a keepalive
	// comment can interleave with an event the adaptor is writing. Fully
	// serializing these writes would need a lock shared with the adaptors.

	resp, err := adaptor.DoRequest(c, info, requestBody)
	// Do not stop the keepalive here on success: for SSE upstreams the real
	// wait is in DoResponse, so it must keep running until that finishes.
	if err != nil {
		stopKeepalive()
		logImageError(c, info, request, fmt.Sprintf("请求失败: %s", err.Error()))
		if sseHeadersSent {
			writeImageStreamError(c, err.Error())
		}
		return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
	}

	var httpResp *http.Response
	if resp != nil {
		httpResp = resp.(*http.Response)
		info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
		if httpResp.StatusCode != http.StatusOK {
			if httpResp.StatusCode == http.StatusCreated && info.ApiType == constant.APITypeReplicate {
				// The Replicate channel returns 201 Created when using
				// "Prefer: wait"; treat it as success.
				httpResp.StatusCode = http.StatusOK
			} else {
				stopKeepalive()
				newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
				service.ResetStatusCode(newAPIError, statusCodeMappingStr)
				logImageError(c, info, request, fmt.Sprintf("上游返回错误 (HTTP %d): %s", newAPIError.StatusCode, newAPIError.Error()))
				if sseHeadersSent {
					// Prefer the more concise OpenAI-style message, falling
					// back to Error().
					errMsg := newAPIError.ToOpenAIError().Message
					if errMsg == "" {
						errMsg = newAPIError.Error()
					}
					writeImageStreamError(c, errMsg)
				}
				return newAPIError
			}
		}
	}

	usage, newAPIError := adaptor.DoResponse(c, httpResp, info)
	// DoResponse has finished; it is now safe to stop the keepalive. This
	// blocks until the keepalive goroutine has exited.
	stopKeepalive()
	if newAPIError != nil {
		service.ResetStatusCode(newAPIError, statusCodeMappingStr)
		logImageError(c, info, request, fmt.Sprintf("响应处理失败: %s", newAPIError.Error()))
		// NOTE(review): unlike the two error paths above, no in-stream error
		// event is written here. Whether DoResponse already reports the
		// failure to a streaming client is not visible from this file.
		return newAPIError
	}

	imageN := uint(1)
	if request.N != nil {
		imageN = *request.N
	}
	// For price-based models, n is applied through OtherRatios so that it is
	// counted exactly once in quota calculation. Adaptors may already have
	// set a more accurate count from the upstream response, so only add the
	// default when "n" is missing.
	if info.PriceData.UsePrice {
		if _, hasN := info.PriceData.OtherRatios["n"]; !hasN {
			info.PriceData.AddOtherRatio("n", float64(imageN))
		}
	}

	quality := request.Quality
	if quality == "" {
		quality = "standard"
	}

	isPerSizeBilling := operation_setting.IsImagePerSizeBilling(info.OriginModelName)
	var logContent []string
	if isPerSizeBilling {
		// Per-size billing uses the count and tier of what was actually
		// returned, not what was requested.
		actualCount := info.ReceivedResponseCount
		if actualCount <= 0 {
			actualCount = int(imageN)
		}
		actualSizeTier := determineActualSizeTier(c, request.Size)
		c.Set("image_per_size_billing", true)
		c.Set("image_size_tier", actualSizeTier)
		c.Set("image_per_size_count", actualCount)
		if actualSizeTier != "" && request.Size != "" {
			requestTier, _ := operation_setting.ClassifyImageSizeTier(request.Size)
			if requestTier != actualSizeTier {
				logContent = append(logContent, fmt.Sprintf("请求大小 %s (档位 %s), 实际档位 %s", request.Size, requestTier, actualSizeTier))
			}
		}
		if actualCount != int(imageN) {
			logContent = append(logContent, fmt.Sprintf("请求数量 %d, 实际返回 %d", imageN, actualCount))
		}
	} else {
		// Non-per-size billing: ensure non-zero token counts for quota
		// calculation.
		textUsage := usage.(*dto.Usage)
		if textUsage.TotalTokens == 0 {
			textUsage.TotalTokens = 1
		}
		if textUsage.PromptTokens == 0 {
			textUsage.PromptTokens = 1
		}
		if len(request.Size) > 0 {
			logContent = append(logContent, fmt.Sprintf("大小 %s", request.Size))
		}
		if len(quality) > 0 {
			logContent = append(logContent, fmt.Sprintf("品质 %s", quality))
		}
		if imageN > 0 {
			logContent = append(logContent, fmt.Sprintf("生成数量 %d", imageN))
		}
	}
	service.PostTextConsumeQuota(c, info, usage.(*dto.Usage), logContent)
	return nil
}

// startImageStreamKeepalive commits a 200 text/event-stream response, writes
// an initial ": keepalive" SSE comment, and starts a goroutine that writes
// another such comment every interval. A non-positive interval falls back to
// imageStreamKeepaliveInterval.
//
// The goroutine exits when the returned stop function is called or when the
// client's request context is done. stop is safe to call more than once and
// does not return until the goroutine has exited. After it returns, the
// caller can write to c.Writer without racing the keepalive, and no
// keepalive write can occur once the handler has returned.
func startImageStreamKeepalive(c *gin.Context, interval time.Duration) (stop func()) {
	if interval <= 0 {
		interval = imageStreamKeepaliveInterval
	}

	helper.SetEventStreamHeaders(c)
	c.Status(http.StatusOK)
	fmt.Fprint(c.Writer, ": keepalive\n\n")
	_ = helper.FlushWriter(c)

	// A nil channel never becomes ready, so without a request only stop ends
	// the goroutine.
	var clientGone <-chan struct{}
	if c.Request != nil {
		clientGone = c.Request.Context().Done()
	}

	quit := make(chan struct{})
	exited := make(chan struct{})
	go func() {
		defer close(exited)
		ticker := time.NewTicker(interval)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				fmt.Fprint(c.Writer, ": keepalive\n\n")
				_ = helper.FlushWriter(c)
			case <-quit:
				return
			case <-clientGone:
				return
			}
		}
	}()

	var once sync.Once
	return func() {
		once.Do(func() {
			close(quit)
			<-exited
		})
	}
}
// logImageError writes a zero-quota consume log entry for a failed image
// request, so the failure reason is visible in the user's usage logs.
//
// The log content lists the requested size (if any), the quality ("standard"
// when unset), the requested image count (1 when unset, omitted when 0), and
// the failure message. The "other" payload carries is_image_error,
// error_message and, when available, the request path. It does nothing when
// info is nil.
func logImageError(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ImageRequest, errMsg string) {
	if info == nil {
		return
	}

	count := 1
	if request.N != nil {
		count = int(*request.N)
	}
	qualityLabel := request.Quality
	if qualityLabel == "" {
		qualityLabel = "standard"
	}

	parts := make([]string, 0, 4)
	if request.Size != "" {
		parts = append(parts, fmt.Sprintf("大小 %s", request.Size))
	}
	parts = append(parts, fmt.Sprintf("品质 %s", qualityLabel))
	if count > 0 {
		parts = append(parts, fmt.Sprintf("生成数量 %d", count))
	}
	parts = append(parts, fmt.Sprintf("失败: %s", errMsg))

	extra := map[string]interface{}{
		"is_image_error": true,
		"error_message":  errMsg,
	}
	if c.Request != nil && c.Request.URL != nil {
		extra["request_path"] = c.Request.URL.Path
	}

	params := model.RecordConsumeLogParams{
		ChannelId: info.ChannelId,
		ModelName: info.OriginModelName,
		TokenName: c.GetString("token_name"),
		TokenId:   info.TokenId,
		Quota:     0,
		Content:   strings.Join(parts, ", "),
		IsStream:  false,
		Group:     info.UsingGroup,
		Other:     extra,
	}
	model.RecordConsumeLog(c, info.UserId, params)
}
// writeImageStreamError sends an error event followed by [DONE] on the SSE
// stream and flushes it. Use it when SSE headers have already been flushed to
// the client (status 200), so a later error must be delivered inside the
// stream rather than as a new HTTP status code.
//
// The event payload is {"type":"error","message":errMsg}. If marshalling
// fails, the same JSON is built by hand with errMsg escaped. Quotes,
// backslashes or newlines in the message therefore cannot produce invalid
// JSON or split the SSE data line.
func writeImageStreamError(c *gin.Context, errMsg string) {
	payload := map[string]any{
		"type":    "error",
		"message": errMsg,
	}
	data, err := common.Marshal(payload)
	if err != nil {
		data = []byte(`{"type":"error","message":"` + escapeImageStreamJSONString(errMsg) + `"}`)
	}
	fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", data)
	fmt.Fprint(c.Writer, "data: [DONE]\n\n")
	_ = helper.FlushWriter(c)
}

// escapeImageStreamJSONString escapes s for embedding inside a JSON string
// literal, without the surrounding quotes. It escapes '"', '\\' and all
// control characters below U+0020 as required by RFC 8259 section 7. The
// result never contains a raw newline.
func escapeImageStreamJSONString(s string) string {
	var b strings.Builder
	b.Grow(len(s))
	for _, r := range s {
		switch {
		case r == '"':
			b.WriteString(`\"`)
		case r == '\\':
			b.WriteString(`\\`)
		case r == '\n':
			b.WriteString(`\n`)
		case r == '\r':
			b.WriteString(`\r`)
		case r == '\t':
			b.WriteString(`\t`)
		case r < 0x20:
			fmt.Fprintf(&b, `\u%04x`, r)
		default:
			b.WriteRune(r)
		}
	}
	return b.String()
}