357 lines
12 KiB
Go
357 lines
12 KiB
Go
package relay
|
||
|
||
import (
|
||
"bytes"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/QuantumNous/new-api/common"
|
||
"github.com/QuantumNous/new-api/constant"
|
||
"github.com/QuantumNous/new-api/dto"
|
||
"github.com/QuantumNous/new-api/logger"
|
||
"github.com/QuantumNous/new-api/model"
|
||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||
"github.com/QuantumNous/new-api/relay/helper"
|
||
"github.com/QuantumNous/new-api/service"
|
||
"github.com/QuantumNous/new-api/setting/model_setting"
|
||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||
"github.com/QuantumNous/new-api/types"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
)
|
||
|
||
// determineActualSizeTier determines the resolution tier based on the actual
|
||
// dimensions of the downloaded images. It falls back to the request size
|
||
// parameter when no local images can be found (e.g. b64_json responses).
|
||
func determineActualSizeTier(c *gin.Context, requestSize string) string {
|
||
// Try to get actual image dimensions from the local files that were
|
||
// saved during DoResponse. The relay_image.go handlers store the
|
||
// local filenames in the gin context.
|
||
localImages, _ := c.Get("local_image_files")
|
||
if filenames, ok := localImages.([]string); ok && len(filenames) > 0 {
|
||
maxEdge := 0
|
||
for _, filename := range filenames {
|
||
w, h, err := service.GetImageDimensions(filename)
|
||
if err != nil {
|
||
logger.LogDebug(c, "failed to get image dimensions for %s: %v", filename, err)
|
||
continue
|
||
}
|
||
if w > maxEdge {
|
||
maxEdge = w
|
||
}
|
||
if h > maxEdge {
|
||
maxEdge = h
|
||
}
|
||
}
|
||
if maxEdge > 0 {
|
||
switch {
|
||
case maxEdge <= 1024:
|
||
return operation_setting.ImageSizeTier1K
|
||
case maxEdge <= 2048:
|
||
return operation_setting.ImageSizeTier2K
|
||
default:
|
||
return operation_setting.ImageSizeTier4K
|
||
}
|
||
}
|
||
}
|
||
|
||
// Fallback to request size parameter
|
||
tier, ok := operation_setting.ClassifyImageSizeTier(requestSize)
|
||
if !ok {
|
||
return operation_setting.ImageSizeTier2K // default to 2K when unknown
|
||
}
|
||
return tier
|
||
}
|
||
|
||
func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
|
||
info.InitChannelMeta(c)
|
||
|
||
imageReq, ok := info.Request.(*dto.ImageRequest)
|
||
if !ok {
|
||
return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected dto.ImageRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
|
||
}
|
||
|
||
request, err := common.DeepCopy(imageReq)
|
||
if err != nil {
|
||
return types.NewError(fmt.Errorf("failed to copy request to ImageRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||
}
|
||
|
||
err = helper.ModelMappedHelper(c, info, request)
|
||
if err != nil {
|
||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||
}
|
||
|
||
adaptor := GetAdaptor(info.ApiType)
|
||
if adaptor == nil {
|
||
return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
||
}
|
||
adaptor.Init(info)
|
||
|
||
var requestBody io.Reader
|
||
|
||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
|
||
storage, err := common.GetBodyStorage(c)
|
||
if err != nil {
|
||
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
|
||
}
|
||
requestBody = common.ReaderOnly(storage)
|
||
} else {
|
||
convertedRequest, err := adaptor.ConvertImageRequest(c, info, *request)
|
||
if err != nil {
|
||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||
}
|
||
relaycommon.AppendRequestConversionFromRequest(info, convertedRequest)
|
||
|
||
switch convertedRequest.(type) {
|
||
case *bytes.Buffer:
|
||
requestBody = convertedRequest.(io.Reader)
|
||
default:
|
||
jsonData, err := common.Marshal(convertedRequest)
|
||
if err != nil {
|
||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||
}
|
||
|
||
// apply param override
|
||
if len(info.ParamOverride) > 0 {
|
||
jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info)
|
||
if err != nil {
|
||
return newAPIErrorFromParamOverride(err)
|
||
}
|
||
}
|
||
|
||
logger.LogDebug(c, "image request body: %s", jsonData)
|
||
body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData)
|
||
if err != nil {
|
||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||
}
|
||
defer closer.Close()
|
||
jsonData = nil
|
||
info.UpstreamRequestBodySize = size
|
||
requestBody = body
|
||
}
|
||
}
|
||
|
||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||
|
||
// 当客户端请求流式响应时,立即发送 SSE 头并启动周期性 keepalive 协程。
|
||
// 图片生成可能耗时 60 秒以上,若期间无数据传输,反向代理(如 Nginx 的
|
||
// proxy_read_timeout,默认 60 秒)或浏览器会关闭连接。
|
||
//
|
||
// 【重要】keepalive 必须覆盖 DoRequest 和 DoResponse 两个阶段。
|
||
// 对于 SSE 上游,DoRequest 在收到响应头后即返回(SSE 先发 200 + headers,
|
||
// 事件数据稍后到达),真正的等待发生在 DoResponse 中——它逐行读取上游
|
||
// SSE 事件。若在 DoRequest 后就停止 keepalive,DoResponse 期间连接将
|
||
// 处于空闲状态,导致反向代理超时关闭连接(client_gone / context canceled)。
|
||
var keepaliveDone chan struct{}
|
||
if info.IsStream {
|
||
helper.SetEventStreamHeaders(c)
|
||
c.Status(http.StatusOK)
|
||
fmt.Fprint(c.Writer, ": keepalive\n\n")
|
||
_ = helper.FlushWriter(c)
|
||
|
||
keepaliveDone = make(chan struct{})
|
||
go func() {
|
||
ticker := time.NewTicker(15 * time.Second)
|
||
defer ticker.Stop()
|
||
for {
|
||
select {
|
||
case <-ticker.C:
|
||
fmt.Fprint(c.Writer, ": keepalive\n\n")
|
||
_ = helper.FlushWriter(c)
|
||
case <-keepaliveDone:
|
||
return
|
||
}
|
||
}
|
||
}()
|
||
}
|
||
|
||
resp, err := adaptor.DoRequest(c, info, requestBody)
|
||
// 注意:此处不能关闭 keepaliveDone!SSE 上游的 DoRequest 会快速返回,
|
||
// 真正的等待在 DoResponse 中,keepalive 必须持续到 DoResponse 完成。
|
||
if err != nil {
|
||
if keepaliveDone != nil {
|
||
close(keepaliveDone)
|
||
}
|
||
logImageError(c, info, request, fmt.Sprintf("请求失败: %s", err.Error()))
|
||
if info.IsStream {
|
||
writeImageStreamError(c, err.Error())
|
||
}
|
||
return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
|
||
}
|
||
var httpResp *http.Response
|
||
if resp != nil {
|
||
httpResp = resp.(*http.Response)
|
||
info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||
if httpResp.StatusCode != http.StatusOK {
|
||
if httpResp.StatusCode == http.StatusCreated && info.ApiType == constant.APITypeReplicate {
|
||
// replicate channel returns 201 Created when using Prefer: wait, treat it as success.
|
||
httpResp.StatusCode = http.StatusOK
|
||
} else {
|
||
if keepaliveDone != nil {
|
||
close(keepaliveDone)
|
||
}
|
||
newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
|
||
// reset status code 重置状态码
|
||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||
logImageError(c, info, request, fmt.Sprintf("上游返回错误 (HTTP %d): %s", newAPIError.StatusCode, newAPIError.Error()))
|
||
if info.IsStream {
|
||
// 优先使用 OpenAI 风格的错误消息(更简洁),回退到 Error()
|
||
errMsg := newAPIError.ToOpenAIError().Message
|
||
if errMsg == "" {
|
||
errMsg = newAPIError.Error()
|
||
}
|
||
writeImageStreamError(c, errMsg)
|
||
}
|
||
return newAPIError
|
||
}
|
||
}
|
||
}
|
||
|
||
usage, newAPIError := adaptor.DoResponse(c, httpResp, info)
|
||
// DoResponse 已完成,现在可以安全地停止 keepalive 协程
|
||
if keepaliveDone != nil {
|
||
close(keepaliveDone)
|
||
}
|
||
if newAPIError != nil {
|
||
// reset status code 重置状态码
|
||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||
logImageError(c, info, request, fmt.Sprintf("响应处理失败: %s", newAPIError.Error()))
|
||
return newAPIError
|
||
}
|
||
|
||
imageN := uint(1)
|
||
if request.N != nil {
|
||
imageN = *request.N
|
||
}
|
||
|
||
// n is handled via OtherRatio so it is applied exactly once in quota
|
||
// calculation (both price-based and ratio-based paths).
|
||
// Adaptors may have already set a more accurate count from the
|
||
// upstream response; only set the default when they haven't.
|
||
if info.PriceData.UsePrice { // only price model use N ratio
|
||
if _, hasN := info.PriceData.OtherRatios["n"]; !hasN {
|
||
info.PriceData.AddOtherRatio("n", float64(imageN))
|
||
}
|
||
}
|
||
|
||
quality := request.Quality
|
||
if quality == "" {
|
||
quality = "standard"
|
||
}
|
||
|
||
isPerSizeBilling := operation_setting.IsImagePerSizeBilling(info.OriginModelName)
|
||
|
||
var logContent []string
|
||
|
||
if isPerSizeBilling {
|
||
// For per-size billing, calculate actual count and tier from the
|
||
// response rather than the request parameters.
|
||
actualCount := info.ReceivedResponseCount
|
||
if actualCount <= 0 {
|
||
actualCount = int(imageN)
|
||
}
|
||
actualSizeTier := determineActualSizeTier(c, request.Size)
|
||
c.Set("image_per_size_billing", true)
|
||
c.Set("image_size_tier", actualSizeTier)
|
||
c.Set("image_per_size_count", actualCount)
|
||
|
||
if actualSizeTier != "" && request.Size != "" {
|
||
requestTier, _ := operation_setting.ClassifyImageSizeTier(request.Size)
|
||
if requestTier != actualSizeTier {
|
||
logContent = append(logContent, fmt.Sprintf("请求大小 %s (档位 %s), 实际档位 %s", request.Size, requestTier, actualSizeTier))
|
||
}
|
||
}
|
||
if actualCount != int(imageN) {
|
||
logContent = append(logContent, fmt.Sprintf("请求数量 %d, 实际返回 %d", imageN, actualCount))
|
||
}
|
||
} else {
|
||
// For non-per-size billing, set token counts for quota calculation.
|
||
if usage.(*dto.Usage).TotalTokens == 0 {
|
||
usage.(*dto.Usage).TotalTokens = 1
|
||
}
|
||
if usage.(*dto.Usage).PromptTokens == 0 {
|
||
usage.(*dto.Usage).PromptTokens = 1
|
||
}
|
||
|
||
if len(request.Size) > 0 {
|
||
logContent = append(logContent, fmt.Sprintf("大小 %s", request.Size))
|
||
}
|
||
if len(quality) > 0 {
|
||
logContent = append(logContent, fmt.Sprintf("品质 %s", quality))
|
||
}
|
||
if imageN > 0 {
|
||
logContent = append(logContent, fmt.Sprintf("生成数量 %d", imageN))
|
||
}
|
||
}
|
||
|
||
service.PostTextConsumeQuota(c, info, usage.(*dto.Usage), logContent)
|
||
return nil
|
||
}
|
||
|
||
// logImageError records a consume log for failed image generation requests
|
||
// so that users can see the failure reason in their usage logs.
|
||
func logImageError(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ImageRequest, errMsg string) {
|
||
if info == nil {
|
||
return
|
||
}
|
||
|
||
var logContent []string
|
||
if len(request.Size) > 0 {
|
||
logContent = append(logContent, fmt.Sprintf("大小 %s", request.Size))
|
||
}
|
||
quality := request.Quality
|
||
if quality == "" {
|
||
quality = "standard"
|
||
}
|
||
logContent = append(logContent, fmt.Sprintf("品质 %s", quality))
|
||
imageN := 1
|
||
if request.N != nil {
|
||
imageN = int(*request.N)
|
||
}
|
||
if imageN > 0 {
|
||
logContent = append(logContent, fmt.Sprintf("生成数量 %d", imageN))
|
||
}
|
||
logContent = append(logContent, fmt.Sprintf("失败: %s", errMsg))
|
||
|
||
other := make(map[string]interface{})
|
||
if c.Request != nil && c.Request.URL != nil {
|
||
other["request_path"] = c.Request.URL.Path
|
||
}
|
||
other["is_image_error"] = true
|
||
other["error_message"] = errMsg
|
||
|
||
model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
|
||
ChannelId: info.ChannelId,
|
||
ModelName: info.OriginModelName,
|
||
TokenName: c.GetString("token_name"),
|
||
TokenId: info.TokenId,
|
||
Quota: 0,
|
||
Content: strings.Join(logContent, ", "),
|
||
IsStream: false,
|
||
Group: info.UsingGroup,
|
||
Other: other,
|
||
})
|
||
}
|
||
|
||
// writeImageStreamError sends an error event followed by [DONE] on the SSE
|
||
// stream. This is used when SSE headers have already been flushed to the
|
||
// client (status 200) and a subsequent error must be delivered inside the
|
||
// stream rather than as a new HTTP status code.
|
||
func writeImageStreamError(c *gin.Context, errMsg string) {
|
||
payload := map[string]any{
|
||
"type": "error",
|
||
"message": errMsg,
|
||
}
|
||
data, err := common.Marshal(payload)
|
||
if err == nil {
|
||
fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", data)
|
||
} else {
|
||
fmt.Fprintf(c.Writer, "event: error\ndata: {\"type\":\"error\",\"message\":\"%s\"}\n\n", errMsg)
|
||
}
|
||
fmt.Fprint(c.Writer, "data: [DONE]\n\n")
|
||
_ = helper.FlushWriter(c)
|
||
}
|