perf: reduce heap residency for large base64 relay requests

Three layered optimizations targeting Gemini-style 5MB base64 payloads where
RSS could balloon to tens of GB under concurrent load:

1. Byte-based param override (relay/common/override.go)
   - Switch legacy/operations hot paths from common.Marshal round-trips and
     map[string]any conversions to gjson/sjson on []byte directly.
   - Avoids cloning 5MB strings during each Set/Delete operation.

2. strings.Builder for Gemini response markdown (relay/channel/gemini/relay-gemini.go)
   - Replace string concatenation + strings.Join when assembling
     "![image](data:...;base64,DATA)" content for inline image responses.
   - Pre-allocates capacity from inline_data byte sizes.

3. Outbound BodyStorage + streaming Decoder (this commit's core)
   - New relay/common/outbound_body.go helper wraps marshaled upstream bodies
     in common.BodyStorage, allowing disk-cache mode to offload jsonData to
     a temp file while waiting for upstream TTFB. The original []byte can
     then be GC'd, removing ~5MB/req of heap residency during the longest
     window of a request.
   - All 7 relay handlers (gemini/claude/responses/embedding/image/compatible/
     rerank) plus chat_completions_via_responses adopt the helper with
     defer closer.Close() and explicit jsonData = nil.
   - relay/common/relay_info.go: new UpstreamRequestBodySize so
     relay/channel/api_request.go can populate req.ContentLength (lost when
     body becomes a type-erased io.Reader).
   - common/gin.go UnmarshalBodyReusable: when storage is disk-backed and
     content-type is JSON, decode via DecodeJson(storage) instead of
     storage.Bytes()+Unmarshal, removing one transient 5MB copy per request.
     memory mode and form/multipart paths unchanged.
This commit is contained in:
CaIon
2026-05-22 19:08:38 +08:00
parent b9bc6f0e21
commit fddf54ccc5
15 changed files with 407 additions and 169 deletions
+2 -2
View File
@@ -37,7 +37,7 @@ func checkWriter(writer io.Writer) stringWriter {
// W3C Working Draft 29 October 2009
// http://www.w3.org/TR/2009/WD-eventsource-20091029/
var contentType = []string{"text/event-stream"}
var writeContentType = []string{"text/event-stream"}
var noCache = []string{"no-cache"}
var fieldReplacer = strings.NewReplacer(
@@ -79,7 +79,7 @@ func (r CustomEvent) WriteContentType(w http.ResponseWriter) {
r.Mutex.Lock()
defer r.Mutex.Unlock()
header := w.Header()
header["Content-Type"] = contentType
header["Content-Type"] = writeContentType
if _, exist := header["Cache-Control"]; !exist {
header["Cache-Control"] = noCache
+19 -1
View File
@@ -110,11 +110,29 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
if err != nil {
return err
}
contentType := c.Request.Header.Get("Content-Type")
// disk-backed JSON: stream-decode directly from the file to avoid
// materializing the entire payload back into a transient []byte
// (diskStorage.Bytes() would ReadFull the whole file into the heap).
if storage.IsDisk() && strings.HasPrefix(contentType, "application/json") {
if _, seekErr := storage.Seek(0, io.SeekStart); seekErr != nil {
return seekErr
}
if err := DecodeJson(storage, v); err != nil {
return err
}
if _, seekErr := storage.Seek(0, io.SeekStart); seekErr != nil {
return seekErr
}
c.Request.Body = io.NopCloser(storage)
return nil
}
requestBody, err := storage.Bytes()
if err != nil {
return err
}
contentType := c.Request.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/json") {
err = Unmarshal(requestBody, v)
} else if strings.Contains(contentType, gin.MIMEPOSTForm) {
+20
View File
@@ -25,6 +25,23 @@ import (
"github.com/gorilla/websocket"
)
// applyUpstreamContentLength populates req.ContentLength when the upstream
// body is wrapped in a BodyStorage (see relay/common/outbound_body.go).
//
// net/http.NewRequest only auto-detects ContentLength for *bytes.Reader,
// *bytes.Buffer and *strings.Reader. When the body is a type-erased io.Reader
// (which is the case for ReaderOnly(BodyStorage)), the Content-Length header
// would otherwise be omitted, forcing chunked transfer encoding and breaking
// some upstreams that require an explicit Content-Length.
func applyUpstreamContentLength(req *http.Request, info *common.RelayInfo) {
if info == nil {
return
}
if info.UpstreamRequestBodySize > 0 && req.ContentLength <= 0 {
req.ContentLength = info.UpstreamRequestBodySize
}
}
func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Header) {
if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
// multipart/form-data
@@ -297,6 +314,7 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
if err != nil {
return nil, fmt.Errorf("new request failed: %w", err)
}
applyUpstreamContentLength(req, info)
headers := req.Header
err = a.SetupRequestHeader(c, &headers, info)
if err != nil {
@@ -326,6 +344,7 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod
if err != nil {
return nil, fmt.Errorf("new request failed: %w", err)
}
applyUpstreamContentLength(req, info)
// set form data
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
headers := req.Header
@@ -522,6 +541,7 @@ func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.RelayInfo, req
if err != nil {
return nil, fmt.Errorf("new request failed: %w", err)
}
applyUpstreamContentLength(req, info)
req.GetBody = func() (io.ReadCloser, error) {
return io.NopCloser(requestBody), nil
}
+88 -17
View File
@@ -1079,17 +1079,47 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse)
FinishReason: constant.FinishReasonStop,
}
if len(candidate.Content.Parts) > 0 {
var texts []string
// 使用 strings.Builder 直接累积最终 content,避免:
// 1) 每张 inline image 生成一次中间 "![image](...)" 字符串
// 2) 末尾 strings.Join 再分配一份等大缓冲
// Gemini 图片返回时 InlineData.Data 可能是数 MB 的 base64
// 上述两份临时分配在高并发下会显著放大堆驻留。
var content strings.Builder
var inlineGrow int
for _, part := range candidate.Content.Parts {
if part.InlineData != nil {
inlineGrow += len(part.InlineData.MimeType) + len(part.InlineData.Data) + 32
}
}
if inlineGrow > 0 {
content.Grow(inlineGrow)
}
appended := 0
writeSep := func() {
if appended > 0 {
content.WriteByte('\n')
}
appended++
}
var toolCalls []dto.ToolCallResponse
for _, part := range candidate.Content.Parts {
if part.InlineData != nil {
// 媒体内容
if strings.HasPrefix(part.InlineData.MimeType, "image") {
imgText := "![image](data:" + part.InlineData.MimeType + ";base64," + part.InlineData.Data + ")"
texts = append(texts, imgText)
writeSep()
content.WriteString("![image](data:")
content.WriteString(part.InlineData.MimeType)
content.WriteString(";base64,")
content.WriteString(part.InlineData.Data)
content.WriteByte(')')
} else {
// 其他媒体类型,直接显示链接
texts = append(texts, fmt.Sprintf("[media](data:%s;base64,%s)", part.InlineData.MimeType, part.InlineData.Data))
writeSep()
content.WriteString("[media](data:")
content.WriteString(part.InlineData.MimeType)
content.WriteString(";base64,")
content.WriteString(part.InlineData.Data)
content.WriteByte(')')
}
} else if part.FunctionCall != nil {
choice.FinishReason = constant.FinishReasonToolCalls
@@ -1100,13 +1130,22 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse)
choice.Message.ReasoningContent = &part.Text
} else {
if part.ExecutableCode != nil {
texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```")
writeSep()
content.WriteString("```")
content.WriteString(part.ExecutableCode.Language)
content.WriteByte('\n')
content.WriteString(part.ExecutableCode.Code)
content.WriteString("\n```")
} else if part.CodeExecutionResult != nil {
texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```")
writeSep()
content.WriteString("```output\n")
content.WriteString(part.CodeExecutionResult.Output)
content.WriteString("\n```")
} else {
// 过滤掉空行
if part.Text != "\n" {
texts = append(texts, part.Text)
writeSep()
content.WriteString(part.Text)
}
}
}
@@ -1115,7 +1154,7 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse)
choice.Message.SetToolCalls(toolCalls)
isToolCall = true
}
choice.Message.SetStringContent(strings.Join(texts, "\n"))
choice.Message.SetStringContent(content.String())
}
if candidate.FinishReason != nil {
@@ -1169,7 +1208,25 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*d
//Role: "assistant",
},
}
var texts []string
// 使用 strings.Builder 直接累积 delta content,避免每张 image / 每个
// 文本片段都先 `+` 拼出一份临时 string,再 strings.Join 再拷贝一遍。
var content strings.Builder
var inlineGrow int
for _, part := range candidate.Content.Parts {
if part.InlineData != nil {
inlineGrow += len(part.InlineData.MimeType) + len(part.InlineData.Data) + 32
}
}
if inlineGrow > 0 {
content.Grow(inlineGrow)
}
appended := 0
writeSep := func() {
if appended > 0 {
content.WriteByte('\n')
}
appended++
}
isTools := false
isThought := false
if candidate.FinishReason != nil {
@@ -1207,8 +1264,12 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*d
for _, part := range candidate.Content.Parts {
if part.InlineData != nil {
if strings.HasPrefix(part.InlineData.MimeType, "image") {
imgText := "![image](data:" + part.InlineData.MimeType + ";base64," + part.InlineData.Data + ")"
texts = append(texts, imgText)
writeSep()
content.WriteString("![image](data:")
content.WriteString(part.InlineData.MimeType)
content.WriteString(";base64,")
content.WriteString(part.InlineData.Data)
content.WriteByte(')')
}
} else if part.FunctionCall != nil {
isTools = true
@@ -1219,23 +1280,33 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*d
} else if part.Thought {
isThought = true
texts = append(texts, part.Text)
writeSep()
content.WriteString(part.Text)
} else {
if part.ExecutableCode != nil {
texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```\n")
writeSep()
content.WriteString("```")
content.WriteString(part.ExecutableCode.Language)
content.WriteByte('\n')
content.WriteString(part.ExecutableCode.Code)
content.WriteString("\n```\n")
} else if part.CodeExecutionResult != nil {
texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```\n")
writeSep()
content.WriteString("```output\n")
content.WriteString(part.CodeExecutionResult.Output)
content.WriteString("\n```\n")
} else {
if part.Text != "\n" {
texts = append(texts, part.Text)
writeSep()
content.WriteString(part.Text)
}
}
}
}
if isThought {
choice.Delta.SetReasoningContent(strings.Join(texts, "\n"))
choice.Delta.SetReasoningContent(content.String())
} else {
choice.Delta.SetContentString(strings.Join(texts, "\n"))
choice.Delta.SetContentString(content.String())
}
if isTools {
choice.FinishReason = &constant.FinishReasonToolCalls
+8 -2
View File
@@ -1,7 +1,6 @@
package relay
import (
"bytes"
"io"
"net/http"
"strings"
@@ -125,7 +124,14 @@ func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, ad
return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
var requestBody io.Reader = bytes.NewBuffer(jsonData)
body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
defer closer.Close()
jsonData = nil
info.UpstreamRequestBodySize = size
var requestBody io.Reader = body
var httpResp *http.Response
resp, err := adaptor.DoRequest(c, info, requestBody)
+8 -2
View File
@@ -1,7 +1,6 @@
package relay
import (
"bytes"
"encoding/json"
"fmt"
"io"
@@ -179,7 +178,14 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
}
logger.LogDebug(c, "requestBody: %s", jsonData)
requestBody = bytes.NewBuffer(jsonData)
body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
defer closer.Close()
jsonData = nil
info.UpstreamRequestBodySize = size
requestBody = body
}
statusCodeMappingStr := c.GetString("status_code_mapping")
+31
View File
@@ -0,0 +1,31 @@
package common
import (
"io"
"github.com/QuantumNous/new-api/common"
)
// NewOutboundJSONBody wraps the already-marshaled upstream request body into a
// BodyStorage. When disk cache is enabled and the payload exceeds the configured
// threshold, the data is written to a temp file and the original []byte can be
// GC'd, significantly reducing the heap residency while waiting for the
// upstream provider to respond (the dominant cost for large base64 payloads).
//
// In memory mode the underlying memoryStorage reuses the same backing array,
// so this is equivalent to bytes.NewReader(data) in terms of memory usage.
//
// The caller MUST invoke closer.Close() once the upstream call has finished
// (typically via defer) to release the disk file / memory accounting.
//
// The returned reader is wrapped with common.ReaderOnly to prevent the HTTP
// transport from prematurely closing the underlying BodyStorage. The returned
// size is meant to be propagated to http.Request.ContentLength because the
// type-erased io.Reader prevents net/http from auto-detecting it.
func NewOutboundJSONBody(data []byte) (body io.Reader, size int64, closer io.Closer, err error) {
storage, err := common.CreateBodyStorage(data)
if err != nil {
return nil, 0, nil, err
}
return common.ReaderOnly(storage), storage.Size(), storage, nil
}
+168 -133
View File
@@ -153,9 +153,8 @@ func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, c
}
}
// 使用新方法
result, err := applyOperations(string(workingJSON), operations, conditionContext)
return []byte(result), err
// 使用新方法(基于 []byte,避免整包 string 拷贝)
return applyOperations(workingJSON, operations, conditionContext)
}
// 直接使用旧方法
@@ -510,13 +509,13 @@ func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation,
return operations, true
}
func checkConditions(jsonStr, contextJSON string, conditions []ConditionOperation, logic string) (bool, error) {
func checkConditions(data []byte, contextJSON string, conditions []ConditionOperation, logic string) (bool, error) {
if len(conditions) == 0 {
return true, nil // 没有条件,直接通过
}
results := make([]bool, len(conditions))
for i, condition := range conditions {
result, err := checkSingleCondition(jsonStr, contextJSON, condition)
result, err := checkSingleCondition(data, contextJSON, condition)
if err != nil {
return false, err
}
@@ -529,10 +528,10 @@ func checkConditions(jsonStr, contextJSON string, conditions []ConditionOperatio
return lo.SomeBy(results, func(item bool) bool { return item }), nil
}
func checkSingleCondition(jsonStr, contextJSON string, condition ConditionOperation) (bool, error) {
func checkSingleCondition(data []byte, contextJSON string, condition ConditionOperation) (bool, error) {
// 处理负数索引
path := processNegativeIndex(jsonStr, condition.Path)
value := gjson.Get(jsonStr, path)
path := processNegativeIndex(data, condition.Path)
value := gjson.GetBytes(data, path)
if !value.Exists() && contextJSON != "" {
value = gjson.Get(contextJSON, condition.Path)
}
@@ -561,7 +560,7 @@ func checkSingleCondition(jsonStr, contextJSON string, condition ConditionOperat
return result, nil
}
func processNegativeIndex(jsonStr string, path string) string {
func processNegativeIndex(data []byte, path string) string {
matches := negativeIndexRegexp.FindAllStringSubmatch(path, -1)
if len(matches) == 0 {
@@ -578,7 +577,7 @@ func processNegativeIndex(jsonStr string, path string) string {
arrayPath = arrayPath[:len(arrayPath)-1]
}
array := gjson.Get(jsonStr, arrayPath)
array := gjson.GetBytes(data, arrayPath)
if array.IsArray() {
length := len(array.Array())
actualIndex := length + index
@@ -667,36 +666,76 @@ func compareNumeric(jsonValue, targetValue gjson.Result, operator string) (bool,
}
}
// applyOperationsLegacy 原参数覆盖方法
// applyOperationsLegacy 原参数覆盖方法
//
// 旧实现把整个 jsonData unmarshal 成 map[string]interface{} 再 marshal 回来,
// 对包含大 base64 字段(如 Gemini inlineData.data)的请求会放大数倍内存
// interface 装箱、map bucket、再次 marshal)。
// 这里改成在 []byte 上直接调用 sjson.SetBytes,按顶层 key 逐个写入,
// 不再把 payload 解码到 map[string]interface{}。
//
// 语义保持:每个 paramOverride 顶层 key 视为字面 key(不解析点号路径),
// 与旧的 reqMap[key] = value 一致。包含 `.` `*` `?` `\` 的 key 会被转义,
// 防止被 sjson 当作嵌套路径或通配符。
func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}, auditRecorder *paramOverrideAuditRecorder) ([]byte, error) {
reqMap := make(map[string]interface{})
err := common.Unmarshal(jsonData, &reqMap)
if err != nil {
return nil, err
if len(paramOverride) == 0 {
return jsonData, nil
}
result := jsonData
for key, value := range paramOverride {
reqMap[key] = value
escaped := escapeSjsonLiteralKey(key)
next, err := sjson.SetBytes(result, escaped, value)
if err != nil {
return nil, err
}
result = next
auditRecorder.recordOperation("set", key, "", "", value)
}
return common.Marshal(reqMap)
return result, nil
}
func applyOperations(jsonStr string, operations []ParamOperation, conditionContext map[string]interface{}) (string, error) {
// escapeSjsonLiteralKey 把可能被 sjson 误判为路径或通配符的字符转义,
// 用于把字面 key 安全地传给 sjson.SetBytes / sjson.DeleteBytes。
func escapeSjsonLiteralKey(key string) string {
if !strings.ContainsAny(key, ".*?\\") {
return key
}
var sb strings.Builder
sb.Grow(len(key) + 4)
for i := 0; i < len(key); i++ {
c := key[i]
switch c {
case '.', '*', '?', '\\':
sb.WriteByte('\\')
}
sb.WriteByte(c)
}
return sb.String()
}
// applyOperations 在 []byte 上原地应用所有 param override 操作。
//
// 旧实现走 string-based gjson/sjson,在 ApplyParamOverride 入口会做
// string(jsonData) 与最终 []byte(result) 各一次整包拷贝,对大 base64
// payload 来说每次重试都额外多花 2 倍 body 体积的临时内存。
// 这里改成全程在 []byte 上工作,sjson.SetBytes / gjson.GetBytes 都是
// 直接读写 []byte,每个操作只会产生一份新 buffer。
func applyOperations(jsonData []byte, operations []ParamOperation, conditionContext map[string]interface{}) ([]byte, error) {
context := ensureContextMap(conditionContext)
auditRecorder := getParamOverrideAuditRecorder(context)
contextJSON, err := marshalContextJSON(context)
if err != nil {
return "", fmt.Errorf("failed to marshal condition context: %v", err)
return nil, fmt.Errorf("failed to marshal condition context: %v", err)
}
result := jsonStr
result := jsonData
for _, op := range operations {
// 检查条件是否满足
ok, err := checkConditions(result, contextJSON, op.Conditions, op.Logic)
if err != nil {
return "", err
return nil, err
}
if !ok {
continue // 条件不满足,跳过当前操作
@@ -707,7 +746,7 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
if isPathBasedOperation(op.Mode) {
opPaths, err = resolveOperationPaths(result, opPath)
if err != nil {
return "", err
return nil, err
}
if len(opPaths) == 0 {
continue
@@ -725,10 +764,10 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
}
case "set":
for _, path := range opPaths {
if op.KeepOrigin && gjson.Get(result, path).Exists() {
if op.KeepOrigin && gjson.GetBytes(result, path).Exists() {
continue
}
result, err = sjson.Set(result, path, op.Value)
result, err = sjson.SetBytes(result, path, op.Value)
if err != nil {
break
}
@@ -743,7 +782,7 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
}
case "copy":
if op.From == "" || op.To == "" {
return "", fmt.Errorf("copy from/to is required")
return nil, fmt.Errorf("copy from/to is required")
}
opFrom := processNegativeIndex(result, op.From)
opTo := processNegativeIndex(result, op.To)
@@ -843,9 +882,9 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
auditRecorder.recordOperation("return_error", op.Path, "", "", op.Value)
returnErr, parseErr := parseParamOverrideReturnError(op.Value)
if parseErr != nil {
return "", parseErr
return nil, parseErr
}
return "", returnErr
return nil, returnErr
case "prune_objects":
for _, path := range opPaths {
result, err = pruneObjects(result, path, contextJSON, op.Value)
@@ -902,7 +941,7 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
case "pass_headers":
headerNames, parseErr := parseHeaderPassThroughNames(op.Value)
if parseErr != nil {
return "", parseErr
return nil, parseErr
}
for _, headerName := range headerNames {
if err = copyHeaderInContext(context, headerName, headerName, op.KeepOrigin); err != nil {
@@ -924,10 +963,10 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
contextJSON, err = marshalContextJSON(context)
}
default:
return "", fmt.Errorf("unknown operation: %s", op.Mode)
return nil, fmt.Errorf("unknown operation: %s", op.Mode)
}
if err != nil {
return "", fmt.Errorf("operation %s failed: %w", op.Mode, err)
return nil, fmt.Errorf("operation %s failed: %w", op.Mode, err)
}
}
return result, nil
@@ -1361,11 +1400,11 @@ func parseSyncTarget(spec string) (syncTarget, error) {
}
}
func readSyncTargetValue(jsonStr string, context map[string]interface{}, target syncTarget) (interface{}, bool, error) {
func readSyncTargetValue(data []byte, context map[string]interface{}, target syncTarget) (interface{}, bool, error) {
switch target.kind {
case "json":
path := processNegativeIndex(jsonStr, target.key)
value := gjson.Get(jsonStr, path)
path := processNegativeIndex(data, target.key)
value := gjson.GetBytes(data, path)
if !value.Exists() || value.Type == gjson.Null {
return nil, false, nil
}
@@ -1384,52 +1423,52 @@ func readSyncTargetValue(jsonStr string, context map[string]interface{}, target
}
}
func writeSyncTargetValue(jsonStr string, context map[string]interface{}, target syncTarget, value interface{}) (string, error) {
func writeSyncTargetValue(data []byte, context map[string]interface{}, target syncTarget, value interface{}) ([]byte, error) {
switch target.kind {
case "json":
path := processNegativeIndex(jsonStr, target.key)
nextJSON, err := sjson.Set(jsonStr, path, value)
path := processNegativeIndex(data, target.key)
nextJSON, err := sjson.SetBytes(data, path, value)
if err != nil {
return "", err
return nil, err
}
return nextJSON, nil
case "header":
if err := setHeaderOverrideInContext(context, target.key, value, false); err != nil {
return "", err
return nil, err
}
return jsonStr, nil
return data, nil
default:
return "", fmt.Errorf("unsupported sync_fields target kind: %s", target.kind)
return nil, fmt.Errorf("unsupported sync_fields target kind: %s", target.kind)
}
}
func syncFieldsBetweenTargets(jsonStr string, context map[string]interface{}, fromSpec string, toSpec string) (string, error) {
func syncFieldsBetweenTargets(data []byte, context map[string]interface{}, fromSpec string, toSpec string) ([]byte, error) {
fromTarget, err := parseSyncTarget(fromSpec)
if err != nil {
return "", err
return nil, err
}
toTarget, err := parseSyncTarget(toSpec)
if err != nil {
return "", err
return nil, err
}
fromValue, fromExists, err := readSyncTargetValue(jsonStr, context, fromTarget)
fromValue, fromExists, err := readSyncTargetValue(data, context, fromTarget)
if err != nil {
return "", err
return nil, err
}
toValue, toExists, err := readSyncTargetValue(jsonStr, context, toTarget)
toValue, toExists, err := readSyncTargetValue(data, context, toTarget)
if err != nil {
return "", err
return nil, err
}
// If one side exists and the other side is missing, sync the missing side.
if fromExists && !toExists {
return writeSyncTargetValue(jsonStr, context, toTarget, fromValue)
return writeSyncTargetValue(data, context, toTarget, fromValue)
}
if toExists && !fromExists {
return writeSyncTargetValue(jsonStr, context, fromTarget, toValue)
return writeSyncTargetValue(data, context, fromTarget, toValue)
}
return jsonStr, nil
return data, nil
}
func ensureMapKeyInContext(context map[string]interface{}, key string) map[string]interface{} {
@@ -1503,24 +1542,24 @@ func syncRuntimeHeaderOverrideFromContext(info *RelayInfo, context map[string]in
info.UseRuntimeHeadersOverride = true
}
func moveValue(jsonStr, fromPath, toPath string) (string, error) {
sourceValue := gjson.Get(jsonStr, fromPath)
func moveValue(data []byte, fromPath, toPath string) ([]byte, error) {
sourceValue := gjson.GetBytes(data, fromPath)
if !sourceValue.Exists() {
return jsonStr, fmt.Errorf("source path does not exist: %s", fromPath)
return data, fmt.Errorf("source path does not exist: %s", fromPath)
}
result, err := sjson.Set(jsonStr, toPath, sourceValue.Value())
result, err := sjson.SetBytes(data, toPath, sourceValue.Value())
if err != nil {
return "", err
return nil, err
}
return sjson.Delete(result, fromPath)
return sjson.DeleteBytes(result, fromPath)
}
func copyValue(jsonStr, fromPath, toPath string) (string, error) {
sourceValue := gjson.Get(jsonStr, fromPath)
func copyValue(data []byte, fromPath, toPath string) ([]byte, error) {
sourceValue := gjson.GetBytes(data, fromPath)
if !sourceValue.Exists() {
return jsonStr, fmt.Errorf("source path does not exist: %s", fromPath)
return data, fmt.Errorf("source path does not exist: %s", fromPath)
}
return sjson.Set(jsonStr, toPath, sourceValue.Value())
return sjson.SetBytes(data, toPath, sourceValue.Value())
}
func isPathBasedOperation(mode string) bool {
@@ -1532,16 +1571,16 @@ func isPathBasedOperation(mode string) bool {
}
}
func resolveOperationPaths(jsonStr, path string) ([]string, error) {
func resolveOperationPaths(data []byte, path string) ([]string, error) {
if !strings.Contains(path, "*") {
return []string{path}, nil
}
return expandWildcardPaths(jsonStr, path)
return expandWildcardPaths(data, path)
}
func expandWildcardPaths(jsonStr, path string) ([]string, error) {
func expandWildcardPaths(data []byte, path string) ([]string, error) {
var root interface{}
if err := common.Unmarshal([]byte(jsonStr), &root); err != nil {
if err := common.Unmarshal(data, &root); err != nil {
return nil, err
}
@@ -1602,28 +1641,28 @@ func collectWildcardPaths(node interface{}, segments []string, prefix []string)
}
}
func deleteValue(jsonStr, path string) (string, error) {
func deleteValue(data []byte, path string) ([]byte, error) {
if strings.TrimSpace(path) == "" {
return jsonStr, nil
return data, nil
}
return sjson.Delete(jsonStr, path)
return sjson.DeleteBytes(data, path)
}
func modifyValue(jsonStr, path string, value interface{}, keepOrigin, isPrepend bool) (string, error) {
current := gjson.Get(jsonStr, path)
func modifyValue(data []byte, path string, value interface{}, keepOrigin, isPrepend bool) ([]byte, error) {
current := gjson.GetBytes(data, path)
switch {
case current.IsArray():
return modifyArray(jsonStr, path, value, isPrepend)
return modifyArray(data, path, value, isPrepend)
case current.Type == gjson.String:
return modifyString(jsonStr, path, value, isPrepend)
return modifyString(data, path, value, isPrepend)
case current.Type == gjson.JSON:
return mergeObjects(jsonStr, path, value, keepOrigin)
return mergeObjects(data, path, value, keepOrigin)
}
return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
return data, fmt.Errorf("operation not supported for type: %v", current.Type)
}
func modifyArray(jsonStr, path string, value interface{}, isPrepend bool) (string, error) {
current := gjson.Get(jsonStr, path)
func modifyArray(data []byte, path string, value interface{}, isPrepend bool) ([]byte, error) {
current := gjson.GetBytes(data, path)
var newArray []interface{}
// 添加新值
addValue := func() {
@@ -1647,11 +1686,11 @@ func modifyArray(jsonStr, path string, value interface{}, isPrepend bool) (strin
addOriginal()
addValue()
}
return sjson.Set(jsonStr, path, newArray)
return sjson.SetBytes(data, path, newArray)
}
func modifyString(jsonStr, path string, value interface{}, isPrepend bool) (string, error) {
current := gjson.Get(jsonStr, path)
func modifyString(data []byte, path string, value interface{}, isPrepend bool) ([]byte, error) {
current := gjson.GetBytes(data, path)
valueStr := fmt.Sprintf("%v", value)
var newStr string
if isPrepend {
@@ -1659,17 +1698,17 @@ func modifyString(jsonStr, path string, value interface{}, isPrepend bool) (stri
} else {
newStr = current.String() + valueStr
}
return sjson.Set(jsonStr, path, newStr)
return sjson.SetBytes(data, path, newStr)
}
func trimStringValue(jsonStr, path string, value interface{}, isPrefix bool) (string, error) {
current := gjson.Get(jsonStr, path)
func trimStringValue(data []byte, path string, value interface{}, isPrefix bool) ([]byte, error) {
current := gjson.GetBytes(data, path)
if current.Type != gjson.String {
return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
return data, fmt.Errorf("operation not supported for type: %v", current.Type)
}
if value == nil {
return jsonStr, fmt.Errorf("trim value is required")
return data, fmt.Errorf("trim value is required")
}
valueStr := fmt.Sprintf("%v", value)
@@ -1679,69 +1718,69 @@ func trimStringValue(jsonStr, path string, value interface{}, isPrefix bool) (st
} else {
newStr = strings.TrimSuffix(current.String(), valueStr)
}
return sjson.Set(jsonStr, path, newStr)
return sjson.SetBytes(data, path, newStr)
}
func ensureStringAffix(jsonStr, path string, value interface{}, isPrefix bool) (string, error) {
current := gjson.Get(jsonStr, path)
func ensureStringAffix(data []byte, path string, value interface{}, isPrefix bool) ([]byte, error) {
current := gjson.GetBytes(data, path)
if current.Type != gjson.String {
return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
return data, fmt.Errorf("operation not supported for type: %v", current.Type)
}
if value == nil {
return jsonStr, fmt.Errorf("ensure value is required")
return data, fmt.Errorf("ensure value is required")
}
valueStr := fmt.Sprintf("%v", value)
if valueStr == "" {
return jsonStr, fmt.Errorf("ensure value is required")
return data, fmt.Errorf("ensure value is required")
}
currentStr := current.String()
if isPrefix {
if strings.HasPrefix(currentStr, valueStr) {
return jsonStr, nil
return data, nil
}
return sjson.Set(jsonStr, path, valueStr+currentStr)
return sjson.SetBytes(data, path, valueStr+currentStr)
}
if strings.HasSuffix(currentStr, valueStr) {
return jsonStr, nil
return data, nil
}
return sjson.Set(jsonStr, path, currentStr+valueStr)
return sjson.SetBytes(data, path, currentStr+valueStr)
}
func transformStringValue(jsonStr, path string, transform func(string) string) (string, error) {
current := gjson.Get(jsonStr, path)
func transformStringValue(data []byte, path string, transform func(string) string) ([]byte, error) {
current := gjson.GetBytes(data, path)
if current.Type != gjson.String {
return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
return data, fmt.Errorf("operation not supported for type: %v", current.Type)
}
return sjson.Set(jsonStr, path, transform(current.String()))
return sjson.SetBytes(data, path, transform(current.String()))
}
func replaceStringValue(jsonStr, path, from, to string) (string, error) {
current := gjson.Get(jsonStr, path)
func replaceStringValue(data []byte, path, from, to string) ([]byte, error) {
current := gjson.GetBytes(data, path)
if current.Type != gjson.String {
return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
return data, fmt.Errorf("operation not supported for type: %v", current.Type)
}
if from == "" {
return jsonStr, fmt.Errorf("replace from is required")
return data, fmt.Errorf("replace from is required")
}
return sjson.Set(jsonStr, path, strings.ReplaceAll(current.String(), from, to))
return sjson.SetBytes(data, path, strings.ReplaceAll(current.String(), from, to))
}
func regexReplaceStringValue(jsonStr, path, pattern, replacement string) (string, error) {
current := gjson.Get(jsonStr, path)
func regexReplaceStringValue(data []byte, path, pattern, replacement string) ([]byte, error) {
current := gjson.GetBytes(data, path)
if current.Type != gjson.String {
return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
return data, fmt.Errorf("operation not supported for type: %v", current.Type)
}
if pattern == "" {
return jsonStr, fmt.Errorf("regex pattern is required")
return data, fmt.Errorf("regex pattern is required")
}
re, err := regexp.Compile(pattern)
if err != nil {
return jsonStr, err
return data, err
}
return sjson.Set(jsonStr, path, re.ReplaceAllString(current.String(), replacement))
return sjson.SetBytes(data, path, re.ReplaceAllString(current.String(), replacement))
}
type pruneObjectsOptions struct {
@@ -1750,37 +1789,33 @@ type pruneObjectsOptions struct {
recursive bool
}
func pruneObjects(jsonStr, path, contextJSON string, value interface{}) (string, error) {
func pruneObjects(data []byte, path, contextJSON string, value interface{}) ([]byte, error) {
options, err := parsePruneObjectsOptions(value)
if err != nil {
return "", err
return nil, err
}
if path == "" {
var root interface{}
if err := common.Unmarshal([]byte(jsonStr), &root); err != nil {
return "", err
if err := common.Unmarshal(data, &root); err != nil {
return nil, err
}
cleaned, _, err := pruneObjectsNode(root, options, contextJSON, true)
if err != nil {
return "", err
return nil, err
}
cleanedBytes, err := common.Marshal(cleaned)
if err != nil {
return "", err
}
return string(cleanedBytes), nil
return common.Marshal(cleaned)
}
target := gjson.Get(jsonStr, path)
target := gjson.GetBytes(data, path)
if !target.Exists() {
return jsonStr, nil
return data, nil
}
var targetNode interface{}
if target.Type == gjson.JSON {
if err := common.Unmarshal([]byte(target.Raw), &targetNode); err != nil {
return "", err
if err := common.UnmarshalJsonStr(target.Raw, &targetNode); err != nil {
return nil, err
}
} else {
targetNode = target.Value()
@@ -1788,13 +1823,13 @@ func pruneObjects(jsonStr, path, contextJSON string, value interface{}) (string,
cleaned, _, err := pruneObjectsNode(targetNode, options, contextJSON, true)
if err != nil {
return "", err
return nil, err
}
cleanedBytes, err := common.Marshal(cleaned)
if err != nil {
return "", err
return nil, err
}
return sjson.SetRaw(jsonStr, path, string(cleanedBytes))
return sjson.SetRawBytes(data, path, cleanedBytes)
}
func parsePruneObjectsOptions(value interface{}) (pruneObjectsOptions, error) {
@@ -1970,16 +2005,16 @@ func shouldPruneObject(node map[string]interface{}, options pruneObjectsOptions,
if err != nil {
return false, err
}
return checkConditions(string(nodeBytes), contextJSON, options.conditions, options.logic)
return checkConditions(nodeBytes, contextJSON, options.conditions, options.logic)
}
func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (string, error) {
current := gjson.Get(jsonStr, path)
func mergeObjects(data []byte, path string, value interface{}, keepOrigin bool) ([]byte, error) {
current := gjson.GetBytes(data, path)
var currentMap, newMap map[string]interface{}
// 解析当前值
if err := common.Unmarshal([]byte(current.Raw), &currentMap); err != nil {
return "", err
// 解析当前值current.Raw 是 data 的子串,避免再分配一份)
if err := common.UnmarshalJsonStr(current.Raw, &currentMap); err != nil {
return nil, err
}
// 解析新值
switch v := value.(type) {
@@ -1988,7 +2023,7 @@ func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (str
default:
jsonBytes, _ := common.Marshal(v)
if err := common.Unmarshal(jsonBytes, &newMap); err != nil {
return "", err
return nil, err
}
}
// 合并
@@ -2001,7 +2036,7 @@ func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (str
result[k] = v
}
}
return sjson.Set(jsonStr, path, result)
return sjson.SetBytes(data, path, result)
}
// BuildParamOverrideContext 提供 ApplyParamOverride 可用的上下文信息。
+7
View File
@@ -154,6 +154,13 @@ type RelayInfo struct {
UseRuntimeHeadersOverride bool
ParamOverrideAudit []string
// UpstreamRequestBodySize is the byte size of the marshaled upstream request
// body. It is set when the body is wrapped in a BodyStorage (see
// relay/common/outbound_body.go), so that DoApiRequest can populate
// http.Request.ContentLength manually (net/http only auto-detects it for
// *bytes.Reader/Buffer/strings.Reader). 0 means "let net/http decide".
UpstreamRequestBodySize int64
PriceData types.PriceData
// TieredBillingSnapshot is a frozen snapshot of tiered billing rules
+8 -2
View File
@@ -1,7 +1,6 @@
package relay
import (
"bytes"
"fmt"
"io"
"net/http"
@@ -176,7 +175,14 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
logger.LogDebug(c, "text request body: %s", jsonData)
requestBody = bytes.NewBuffer(jsonData)
body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
defer closer.Close()
jsonData = nil
info.UpstreamRequestBodySize = size
requestBody = body
}
var httpResp *http.Response
+8 -2
View File
@@ -1,7 +1,6 @@
package relay
import (
"bytes"
"fmt"
"io"
"net/http"
@@ -59,7 +58,14 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
}
logger.LogDebug(c, "converted embedding request body: %s", jsonData)
var requestBody io.Reader = bytes.NewBuffer(jsonData)
body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
defer closer.Close()
jsonData = nil
info.UpstreamRequestBodySize = size
var requestBody io.Reader = body
statusCodeMappingStr := c.GetString("status_code_mapping")
resp, err := adaptor.DoRequest(c, info, requestBody)
if err != nil {
+16 -3
View File
@@ -1,7 +1,6 @@
package relay
import (
"bytes"
"fmt"
"io"
"net/http"
@@ -165,7 +164,14 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
logger.LogDebug(c, "Gemini request body: %s", jsonData)
requestBody = bytes.NewReader(jsonData)
body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
defer closer.Close()
jsonData = nil
info.UpstreamRequestBodySize = size
requestBody = body
}
resp, err := adaptor.DoRequest(c, info, requestBody)
@@ -263,7 +269,14 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI
}
}
logger.LogDebug(c, "Gemini embedding request body: %s", jsonData)
requestBody = bytes.NewReader(jsonData)
body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
defer closer.Close()
jsonData = nil
info.UpstreamRequestBodySize = size
requestBody = body
resp, err := adaptor.DoRequest(c, info, requestBody)
if err != nil {
+8 -1
View File
@@ -77,7 +77,14 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
}
logger.LogDebug(c, "image request body: %s", jsonData)
requestBody = bytes.NewBuffer(jsonData)
body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
defer closer.Close()
jsonData = nil
info.UpstreamRequestBodySize = size
requestBody = body
}
}
+8 -2
View File
@@ -1,7 +1,6 @@
package relay
import (
"bytes"
"fmt"
"io"
"net/http"
@@ -69,7 +68,14 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
}
logger.LogDebug(c, "Rerank request body: %s", jsonData)
requestBody = bytes.NewBuffer(jsonData)
body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
defer closer.Close()
jsonData = nil
info.UpstreamRequestBodySize = size
requestBody = body
}
resp, err := adaptor.DoRequest(c, info, requestBody)
+8 -2
View File
@@ -1,7 +1,6 @@
package relay
import (
"bytes"
"fmt"
"io"
"net/http"
@@ -104,7 +103,14 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
}
logger.LogDebug(c, "requestBody: %s", jsonData)
requestBody = bytes.NewBuffer(jsonData)
body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
defer closer.Close()
jsonData = nil
info.UpstreamRequestBodySize = size
requestBody = body
}
var httpResp *http.Response