Compare commits

..

2 Commits

Author SHA1 Message Date
CaIon c29d80f015 chore: remove channel cache test 2026-05-23 13:23:32 +08:00
t0ng7u 3ba01a7dcd 🐛 fix(channel): evict auto-disabled multi-key channels from cache
Ensure multi-key channels are removed from the in-memory routing cache when all keys become auto-disabled, preventing subsequent requests from repeatedly selecting channels with no available keys.

Also make multi-key status updates more robust by handling missing key matches, checking actual enabled key availability, and restoring the channel status when a key is re-enabled. Add regression coverage for disabled cached channels and multi-key cache eviction.
2026-05-20 12:47:46 +08:00
521 changed files with 26434 additions and 33544 deletions
-2
View File
@@ -56,8 +56,6 @@
# 对话超时设置
# 所有请求超时时间,单位秒,默认为0,表示不限制
# RELAY_TIMEOUT=0
# Relay HTTP 客户端空闲连接超时时间,单位秒,默认跟随 Go 标准库,设置为0表示不限制
# RELAY_IDLE_CONN_TIMEOUT=90
# 流模式无响应超时时间,单位秒,如果出现空补全可以尝试改为更大值
# STREAMING_TIMEOUT=300
+4 -11
View File
@@ -11,8 +11,6 @@ assignees: ''
- 文档:https://docs.newapi.ai/
- 使用问题先看或先问:https://deepwiki.com/QuantumNous/new-api
- 开启透传后的转发相关反馈不接受 issue;透传模式会直接转发请求,请自行确认上游行为。
- 不接受 coding plan、逆向渠道等技术支持类 issue。
- 警告:删除本模板、删除小节标题或随意清空内容的 issue,可能会被直接关闭;重复恶意提交者可能会被 block。
**您当前的 newapi 版本**
@@ -22,18 +20,13 @@ assignees: ''
**提交确认**
[//]: # (方框内删除已有的空格,填 x 号)
- [ ] **非重复 issue:** 我已搜索现有 [Issues](https://github.com/QuantumNous/new-api/issues?q=is%3Aissue)确认目前没有类似 issue
- [ ] **提交前必读:** 我已完整阅读上方“提交前必读”,并已查看文档 https://docs.newapi.ai/项目 README 且向 AI 提问,确认这不是使用、配置或接入类问题。
- [ ] **模板完整:** 我未删除此模板中的任何引导内容或小节标题,并会按要求完整填写
- [ ] **维护成本:** 我理解项目维护者精力有限,不遵循模板要求的 issue 可能会被无视或直接关闭
+ [ ] 我已确认目前没有类似 issue
+ [ ] 我已完整查看文档 https://docs.newapi.ai/项目 README,尤其是常见问题部分
+ [ ] 我未删除此模板中的任何引导内容或小节标题,并会按要求完整填写
+ [ ] 我理解项目维护者精力有限,不遵循模板要求的 issue 可能会被无视或直接关闭
**问题描述**
请尽可能说明问题现象、影响范围,以及你判断它是程序问题而不是上游行为或使用问题的依据。
- 转发问题请尽可能说明渠道类型、转换格式、上游原生支持依据和服务端日志。
- 计费问题请尽可能附请求返回的 `usage` 示例。
**复现步骤**
**预期结果**
+4 -11
View File
@@ -11,8 +11,6 @@ assignees: ''
- Docs: https://docs.newapi.ai/
- Usage questions first: https://deepwiki.com/QuantumNous/new-api
- Issues about forwarding behavior after enabling pass-through mode are not accepted; pass-through mode forwards requests directly, so please verify upstream behavior yourself.
- Technical support requests such as coding plans or reverse-engineering channels are not accepted as issues.
- Warning: issues with this template removed, section headings deleted, or content cleared may be closed directly. Repeated abusive submissions may result in a block.
**Your current newapi version**
@@ -22,18 +20,13 @@ Please fill this in, for example: `v1.0.0`
**Submission Checks**
[//]: # (Remove the space in the box and fill with an x)
- [ ] **Non-duplicate issue:** I have searched existing [Issues](https://github.com/QuantumNous/new-api/issues?q=is%3Aissue) and confirmed there are no similar issues.
- [ ] **Read this first:** I have fully read the section above, reviewed the docs at https://docs.newapi.ai/ and the project README, and asked AI first, confirming this is not a usage, configuration, or integration question.
- [ ] **Template intact:** I have not removed any guidance or section headings from this template and will complete it as requested.
- [ ] **Maintainer time:** I understand that maintainers have limited time and issues that do not follow this template may be ignored or closed directly.
+ [ ] I have confirmed there are no similar issues
+ [ ] I have thoroughly read the docs at https://docs.newapi.ai/ and the project README, especially the FAQ section
+ [ ] I have not removed any guidance or section headings from this template and will complete it as requested
+ [ ] I understand that maintainers have limited time and issues that do not follow this template may be ignored or closed directly
**Issue Description**
Describe the symptom, impact scope, and why you believe this is an application issue rather than upstream behavior or a usage question with as much detail as possible.
- For forwarding issues, include the channel type, conversion format, upstream native-support evidence, and server logs when possible.
- For billing issues, include an example of the returned `usage` when possible.
**Steps to Reproduce**
**Expected Result**
+4 -6
View File
@@ -11,8 +11,6 @@ assignees: ''
- 文档:https://docs.newapi.ai/
- 使用问题先看或先问:https://deepwiki.com/QuantumNous/new-api
- 开启透传后的转发相关反馈不接受 issue;透传模式会直接转发请求,请自行确认上游行为。
- 不接受 coding plan、逆向渠道等技术支持类 issue。
- 警告:删除本模板、删除小节标题或随意清空内容的 issue,可能会被直接关闭;重复恶意提交者可能会被 block。
**您当前的 newapi 版本**
@@ -22,10 +20,10 @@ assignees: ''
**提交确认**
[//]: # (方框内删除已有的空格,填 x 号)
- [ ] **非重复 issue:** 我已搜索现有 [Issues](https://github.com/QuantumNous/new-api/issues?q=is%3Aissue)确认目前没有类似 issue
- [ ] **提交前必读:** 我已完整阅读上方“提交前必读”,并已查看文档 https://docs.newapi.ai/项目 README 且向 AI 提问,确认这不是使用、配置或接入类问题,且现有版本无法满足需求
- [ ] **模板完整:** 我未删除此模板中的任何引导内容或小节标题,并会按要求完整填写
- [ ] **维护成本:** 我理解项目维护者精力有限,不遵循模板要求的 issue 可能会被无视或直接关闭
+ [ ] 我已确认目前没有类似 issue
+ [ ] 我已完整查看文档 https://docs.newapi.ai/项目 README,已确定现有版本无法满足需求
+ [ ] 我未删除此模板中的任何引导内容或小节标题,并会按要求完整填写
+ [ ] 我理解项目维护者精力有限,不遵循模板要求的 issue 可能会被无视或直接关闭
**功能描述**
+4 -6
View File
@@ -11,8 +11,6 @@ assignees: ''
- Docs: https://docs.newapi.ai/
- Usage questions first: https://deepwiki.com/QuantumNous/new-api
- Issues about forwarding behavior after enabling pass-through mode are not accepted; pass-through mode forwards requests directly, so please verify upstream behavior yourself.
- Technical support requests such as coding plans or reverse-engineering channels are not accepted as issues.
- Warning: issues with this template removed, section headings deleted, or content cleared may be closed directly. Repeated abusive submissions may result in a block.
**Your current newapi version**
@@ -22,10 +20,10 @@ Please fill this in, for example: `v1.0.0`
**Submission Checks**
[//]: # (Remove the space in the box and fill with an x)
- [ ] **Non-duplicate issue:** I have searched existing [Issues](https://github.com/QuantumNous/new-api/issues?q=is%3Aissue) and confirmed there are no similar issues.
- [ ] **Read this first:** I have fully read the section above, reviewed the docs at https://docs.newapi.ai/ and the project README, and asked AI first, confirming this is not a usage, configuration, or integration question, and that the current version cannot meet my needs.
- [ ] **Template intact:** I have not removed any guidance or section headings from this template and will complete it as requested.
- [ ] **Maintainer time:** I understand that maintainers have limited time and issues that do not follow this template may be ignored or closed directly.
+ [ ] I have confirmed there are no similar issues
+ [ ] I have thoroughly read the docs at https://docs.newapi.ai/ and the project README, and confirmed the current version cannot meet my needs
+ [ ] I have not removed any guidance or section headings from this template and will complete it as requested
+ [ ] I understand that maintainers have limited time and issues that do not follow this template may be ignored or closed directly
**Feature Description**
+12 -18
View File
@@ -33,18 +33,16 @@ jobs:
env:
CI: ""
run: |
cd web
bun install --frozen-lockfile
cd default
cd web/default
bun install
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$VERSION bun run build
cd ../..
- name: Build Frontend (classic)
env:
CI: ""
run: |
cd web
bun install --frozen-lockfile
cd classic
cd web/classic
bun install
VITE_REACT_APP_VERSION=$VERSION bun run build
cd ../..
- name: Set up Go
@@ -93,18 +91,16 @@ jobs:
CI: ""
NODE_OPTIONS: "--max-old-space-size=4096"
run: |
cd web
bun install --frozen-lockfile
cd default
cd web/default
bun install
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$VERSION bun run build
cd ../..
- name: Build Frontend (classic)
env:
CI: ""
run: |
cd web
bun install --frozen-lockfile
cd classic
cd web/classic
bun install
VITE_REACT_APP_VERSION=$VERSION bun run build
cd ../..
- name: Set up Go
@@ -150,18 +146,16 @@ jobs:
env:
CI: ""
run: |
cd web
bun install --frozen-lockfile
cd default
cd web/default
bun install
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$VERSION bun run build
cd ../..
- name: Build Frontend (classic)
env:
CI: ""
run: |
cd web
bun install --frozen-lockfile
cd classic
cd web/classic
bun install
VITE_REACT_APP_VERSION=$VERSION bun run build
cd ../..
- name: Set up Go
-1
View File
@@ -35,4 +35,3 @@ data/
.test
token_estimator_test.go
skills-lock.json
.playwright-mcp
+16 -18
View File
@@ -1,24 +1,22 @@
FROM oven/bun:1@sha256:0733e50325078969732ebe3b15ce4c4be5082f18c4ac1a0f0ca4839c2e4e42a7 AS builder
WORKDIR /build/web
COPY web/package.json web/bun.lock ./
COPY web/default/package.json ./default/package.json
COPY web/classic/package.json ./classic/package.json
RUN bun install --frozen-lockfile
COPY ./web/default ./default
COPY ./VERSION /build/VERSION
RUN cd default && DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat /build/VERSION) bun run build
WORKDIR /build
COPY web/default/package.json .
COPY web/default/bun.lock .
RUN bun install
COPY ./web/default .
COPY ./VERSION .
RUN DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) bun run build
FROM oven/bun:1@sha256:0733e50325078969732ebe3b15ce4c4be5082f18c4ac1a0f0ca4839c2e4e42a7 AS builder-classic
WORKDIR /build/web
COPY web/package.json web/bun.lock ./
COPY web/default/package.json ./default/package.json
COPY web/classic/package.json ./classic/package.json
RUN bun install --frozen-lockfile
COPY ./web/classic ./classic
COPY ./VERSION /build/VERSION
RUN cd classic && VITE_REACT_APP_VERSION=$(cat /build/VERSION) bun run build
WORKDIR /build
COPY web/classic/package.json .
COPY web/classic/bun.lock .
RUN bun install
COPY ./web/classic .
COPY ./VERSION .
RUN VITE_REACT_APP_VERSION=$(cat VERSION) bun run build
FROM golang:1.26.1-alpine@sha256:2389ebfa5b7f43eeafbd6be0c3700cc46690ef842ad962f6c5bd6be49ed82039 AS builder2
ENV GO111MODULE=on CGO_ENABLED=0
@@ -34,8 +32,8 @@ ADD go.mod go.sum ./
RUN go mod download
COPY . .
COPY --from=builder /build/web/default/dist ./web/default/dist
COPY --from=builder-classic /build/web/classic/dist ./web/classic/dist
COPY --from=builder /build/dist ./web/default/dist
COPY --from=builder-classic /build/dist ./web/classic/dist
RUN go build -ldflags "-s -w -X 'github.com/QuantumNous/new-api/common.Version=$(cat VERSION)'" -o new-api
FROM debian:bookworm-slim@sha256:f06537653ac770703bc45b4b113475bd402f451e85223f0f2837acbf89ab020a
-1
View File
@@ -316,7 +316,6 @@ docker run --name new-api -d --restart always \
| `CRYPTO_SECRET` | Encryption secret (required for Redis) | - |
| `SQL_DSN` | Database connection string | - |
| `REDIS_CONN_STRING` | Redis connection string | - |
| `RELAY_IDLE_CONN_TIMEOUT` | Idle keep-alive timeout for relay HTTP clients, seconds. Defaults to Go standard library behavior; set `0` to disable | `90` |
| `STREAMING_TIMEOUT` | Streaming timeout (seconds) | `300` |
| `STREAM_SCANNER_MAX_BUFFER_MB` | Max per-line buffer (MB) for the stream scanner; increase when upstream sends huge image/base64 payloads | `64` |
| `MAX_REQUEST_BODY_MB` | Max request body size (MB, counted **after decompression**; prevents huge requests/zip bombs from exhausting memory). Exceeding it returns `413` | `32` |
-1
View File
@@ -170,7 +170,6 @@ var BatchUpdateInterval int
var RelayTimeout int // unit is second
var RelayIdleConnTimeout int // unit is second
var RelayMaxIdleConns int
var RelayMaxIdleConnsPerHost int
+2 -2
View File
@@ -37,7 +37,7 @@ func checkWriter(writer io.Writer) stringWriter {
// W3C Working Draft 29 October 2009
// http://www.w3.org/TR/2009/WD-eventsource-20091029/
var writeContentType = []string{"text/event-stream"}
var contentType = []string{"text/event-stream"}
var noCache = []string{"no-cache"}
var fieldReplacer = strings.NewReplacer(
@@ -79,7 +79,7 @@ func (r CustomEvent) WriteContentType(w http.ResponseWriter) {
r.Mutex.Lock()
defer r.Mutex.Unlock()
header := w.Header()
header["Content-Type"] = writeContentType
header["Content-Type"] = contentType
if _, exist := header["Cache-Control"]; !exist {
header["Cache-Control"] = noCache
+1 -19
View File
@@ -110,29 +110,11 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
if err != nil {
return err
}
contentType := c.Request.Header.Get("Content-Type")
// disk-backed JSON: stream-decode directly from the file to avoid
// materializing the entire payload back into a transient []byte
// (diskStorage.Bytes() would ReadFull the whole file into the heap).
if storage.IsDisk() && strings.HasPrefix(contentType, "application/json") {
if _, seekErr := storage.Seek(0, io.SeekStart); seekErr != nil {
return seekErr
}
if err := DecodeJson(storage, v); err != nil {
return err
}
if _, seekErr := storage.Seek(0, io.SeekStart); seekErr != nil {
return seekErr
}
c.Request.Body = io.NopCloser(storage)
return nil
}
requestBody, err := storage.Bytes()
if err != nil {
return err
}
contentType := c.Request.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/json") {
err = Unmarshal(requestBody, v)
} else if strings.Contains(contentType, gin.MIMEPOSTForm) {
+2 -4
View File
@@ -102,7 +102,6 @@ func InitEnv() {
SyncFrequency = GetEnvOrDefault("SYNC_FREQUENCY", 60)
BatchUpdateInterval = GetEnvOrDefault("BATCH_UPDATE_INTERVAL", 5)
RelayTimeout = GetEnvOrDefault("RELAY_TIMEOUT", 0)
RelayIdleConnTimeout = GetEnvOrDefault("RELAY_IDLE_CONN_TIMEOUT", 90)
RelayMaxIdleConns = GetEnvOrDefault("RELAY_MAX_IDLE_CONNS", 500)
RelayMaxIdleConnsPerHost = GetEnvOrDefault("RELAY_MAX_IDLE_CONNS_PER_HOST", 100)
@@ -112,11 +111,11 @@ func InitEnv() {
// Initialize rate limit variables
GlobalApiRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_API_RATE_LIMIT_ENABLE", true)
GlobalApiRateLimitNum = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 360)
GlobalApiRateLimitNum = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 180)
GlobalApiRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_API_RATE_LIMIT_DURATION", 180))
GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true)
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 120)
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180))
CriticalRateLimitEnable = GetEnvOrDefaultBool("CRITICAL_RATE_LIMIT_ENABLE", true)
@@ -136,7 +135,6 @@ func initConstantEnv() {
constant.StreamScannerMaxBufferMB = GetEnvOrDefault("STREAM_SCANNER_MAX_BUFFER_MB", 128)
// MaxRequestBodyMB 请求体最大大小(解压后),用于防止超大请求/zip bomb导致内存暴涨
constant.MaxRequestBodyMB = GetEnvOrDefault("MAX_REQUEST_BODY_MB", 128)
constant.AnonymousRequestBodyLimitKB = GetEnvOrDefault("ANONYMOUS_REQUEST_BODY_LIMIT_KB", 512)
// ForceStreamOption 覆盖请求参数,强制返回usage信息
constant.ForceStreamOption = GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
constant.CountToken = GetEnvOrDefaultBool("CountToken", true)
-13
View File
@@ -1,13 +0,0 @@
package common
import "github.com/QuantumNous/new-api/constant"
const defaultAnonymousRequestBodyLimitKB = 512
func GetAnonymousRequestBodyLimitBytes() int64 {
limitKB := constant.AnonymousRequestBodyLimitKB
if limitKB < 0 {
limitKB = defaultAnonymousRequestBodyLimitKB
}
return int64(limitKB) << 10
}
-11
View File
@@ -3,7 +3,6 @@ package common
import (
"encoding/base64"
"encoding/json"
"fmt"
"net/url"
"regexp"
"strconv"
@@ -21,16 +20,6 @@ var (
maskApiKeyPattern = regexp.MustCompile(`(['"]?)api_key:([^\s'"]+)(['"]?)`)
)
const LocalLogContentLimit = 2048
// LocalLogPreview limits log-only content unless debug logging is enabled.
func LocalLogPreview(content string) string {
if DebugEnabled || len(content) <= LocalLogContentLimit {
return content
}
return fmt.Sprintf("%s... [truncated, original_length=%d, limit=%d]", content[:LocalLogContentLimit], len(content), LocalLogContentLimit)
}
func GetStringIfEmpty(str string, defaultValue string) string {
if str == "" {
return defaultValue
-1
View File
@@ -10,7 +10,6 @@ var GetMediaToken bool
var GetMediaTokenNotStream bool
var UpdateTask bool
var MaxRequestBodyMB int
var AnonymousRequestBodyLimitKB int
var AzureDefaultAPIVersion string
var NotifyLimitCount int
var NotificationLimitDurationMinute int
+8 -34
View File
@@ -57,24 +57,7 @@ func normalizeChannelTestEndpoint(channel *model.Channel, modelName, endpointTyp
return normalized
}
func resolveChannelTestUserID(c *gin.Context) (int, error) {
if c != nil {
if userID := c.GetInt("id"); userID > 0 {
return userID, nil
}
}
var rootUser model.User
if err := model.DB.Select("id").Where("role = ?", common.RoleRootUser).First(&rootUser).Error; err != nil {
return 0, fmt.Errorf("failed to resolve channel test user: %w", err)
}
if rootUser.Id == 0 {
return 0, errors.New("failed to resolve channel test user")
}
return rootUser.Id, nil
}
func testChannel(channel *model.Channel, testUserID int, testModel string, endpointType string, isStream bool) testResult {
func testChannel(channel *model.Channel, testModel string, endpointType string, isStream bool) testResult {
tik := time.Now()
var unsupportedTestChannelTypes = []int{
constant.ChannelTypeMidjourney,
@@ -160,7 +143,7 @@ func testChannel(channel *model.Channel, testUserID int, testModel string, endpo
Header: make(http.Header),
}
cache, err := model.GetUserCache(testUserID)
cache, err := model.GetUserCache(1)
if err != nil {
return testResult{
localErr: err,
@@ -168,13 +151,13 @@ func testChannel(channel *model.Channel, testUserID int, testModel string, endpo
}
}
cache.WriteContext(c)
c.Set("id", testUserID)
c.Set("id", 1)
//c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
c.Request.Header.Set("Content-Type", "application/json")
c.Set("channel", channel.Type)
c.Set("base_url", channel.GetBaseURL())
group, _ := model.GetUserGroup(testUserID, false)
group, _ := model.GetUserGroup(1, false)
c.Set("group", group)
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, testModel)
@@ -501,7 +484,7 @@ func testChannel(channel *model.Channel, testUserID int, testModel string, endpo
milliseconds := tok.Sub(tik).Milliseconds()
consumedTime := float64(milliseconds) / 1000.0
other := buildTestLogOther(c, info, priceData, usage, tieredResult)
model.RecordConsumeLog(c, testUserID, model.RecordConsumeLogParams{
model.RecordConsumeLog(c, 1, model.RecordConsumeLogParams{
ChannelId: channel.Id,
PromptTokens: usage.PromptTokens,
CompletionTokens: usage.CompletionTokens,
@@ -814,7 +797,7 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel,
testRequest.StreamOptions = &dto.StreamOptions{IncludeUsage: true}
}
if dto.IsOpenAIReasoningOModel(model) {
if strings.HasPrefix(model, "o") {
testRequest.MaxCompletionTokens = lo.ToPtr(uint(16))
} else if strings.Contains(model, "thinking") {
if !strings.Contains(model, "claude") {
@@ -851,13 +834,8 @@ func TestChannel(c *gin.Context) {
testModel := c.Query("model")
endpointType := c.Query("endpoint_type")
isStream, _ := strconv.ParseBool(c.Query("stream"))
testUserID, err := resolveChannelTestUserID(c)
if err != nil {
common.ApiError(c, err)
return
}
tik := time.Now()
result := testChannel(channel, testUserID, testModel, endpointType, isStream)
result := testChannel(channel, testModel, endpointType, isStream)
if result.localErr != nil {
resp := gin.H{
"success": false,
@@ -894,10 +872,6 @@ var testAllChannelsLock sync.Mutex
var testAllChannelsRunning bool = false
func testAllChannels(notify bool) error {
testUserID, err := resolveChannelTestUserID(nil)
if err != nil {
return err
}
testAllChannelsLock.Lock()
if testAllChannelsRunning {
@@ -928,7 +902,7 @@ func testAllChannels(notify bool) error {
}
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
tik := time.Now()
result := testChannel(channel, testUserID, "", "", shouldUseStreamForAutomaticChannelTest(channel))
result := testChannel(channel, "", "", shouldUseStreamForAutomaticChannelTest(channel))
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
+1 -1
View File
@@ -1218,7 +1218,7 @@ func CopyChannel(c *gin.Context) {
}
// insert
if err := clone.Insert(); err != nil {
if err := model.BatchInsertChannels([]model.Channel{clone}); err != nil {
common.SysError("failed to clone channel: " + err.Error())
c.JSON(http.StatusOK, gin.H{"success": false, "message": "复制渠道失败,请稍后重试"})
return
-11
View File
@@ -69,14 +69,3 @@ func TestBuildTestLogOtherInjectsTieredInfo(t *testing.T) {
require.Equal(t, "base", other["matched_tier"])
require.NotEmpty(t, other["expr_b64"])
}
func TestResolveChannelTestUserIDUsesRequestUser(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
ctx.Set("id", 2)
userID, err := resolveChannelTestUserID(ctx)
require.NoError(t, err)
require.Equal(t, 2, userID)
}
-1
View File
@@ -88,7 +88,6 @@ func GetStatus(c *gin.Context) {
"demo_site_enabled": operation_setting.DemoSiteEnabled,
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
"register_enabled": common.RegisterEnabled,
"password_login_enabled": common.PasswordLoginEnabled,
"password_register_enabled": common.PasswordRegisterEnabled,
"default_use_auto_group": setting.DefaultUseAutoGroup,
+43 -120
View File
@@ -3,7 +3,6 @@ package controller
import (
"fmt"
"net/http"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
@@ -110,102 +109,9 @@ func init() {
})
}
func channelOwnerName(channelType int) string {
apiType, success := common.ChannelType2APIType(channelType)
if !success {
return strings.ToLower(constant.GetChannelTypeName(channelType))
}
adaptor := relay.GetAdaptor(apiType)
if adaptor == nil {
return strings.ToLower(constant.GetChannelTypeName(channelType))
}
adaptor.Init(&relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{
ChannelType: channelType,
}})
if name := strings.TrimSpace(adaptor.GetChannelName()); name != "" {
return name
}
return strings.ToLower(constant.GetChannelTypeName(channelType))
}
func getPreferredModelOwners(modelNames []string, groups []string) map[string]string {
channelTypes, err := model.GetPreferredModelOwnerChannelTypes(modelNames, groups)
if err != nil {
common.SysLog(fmt.Sprintf("GetPreferredModelOwnerChannelTypes error: %v", err))
return map[string]string{}
}
ownerByChannelType := make(map[int]string)
owners := make(map[string]string, len(channelTypes))
for modelName, channelType := range channelTypes {
owner, ok := ownerByChannelType[channelType]
if !ok {
owner = channelOwnerName(channelType)
ownerByChannelType[channelType] = owner
}
if owner != "" {
owners[modelName] = owner
}
}
return owners
}
func buildOpenAIModel(modelName string, ownerByModel map[string]string) dto.OpenAIModels {
var oaiModel dto.OpenAIModels
if staticModel, ok := openAIModelsMap[modelName]; ok {
oaiModel = staticModel
} else {
oaiModel = dto.OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: "custom",
}
}
if owner, ok := ownerByModel[modelName]; ok && owner != "" {
oaiModel.OwnedBy = owner
}
oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(modelName)
return oaiModel
}
type modelListGroups struct {
userGroup string
tokenGroup string
ownerGroups []string
}
func getModelListGroups(c *gin.Context) (modelListGroups, error) {
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
if userGroup == "" && (tokenGroup == "" || tokenGroup == "auto") {
var err error
userGroup, err = model.GetUserGroup(c.GetInt("id"), false)
if err != nil {
return modelListGroups{}, err
}
}
if tokenGroup == "auto" {
return modelListGroups{
userGroup: userGroup,
tokenGroup: tokenGroup,
ownerGroups: service.GetUserAutoGroup(userGroup),
}, nil
}
group := userGroup
if tokenGroup != "" {
group = tokenGroup
}
return modelListGroups{
userGroup: userGroup,
tokenGroup: tokenGroup,
ownerGroups: []string{group},
}, nil
}
func ListModels(c *gin.Context, modelType int) {
userOpenAiModels := make([]dto.OpenAIModels, 0)
acceptUnsetRatioModel := operation_setting.SelfUseModeEnabled
if !acceptUnsetRatioModel {
userId := c.GetInt("id")
@@ -217,16 +123,6 @@ func ListModels(c *gin.Context, modelType int) {
}
}
userModelNames := make([]string, 0)
groups, err := getModelListGroups(c)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "get user group failed",
})
return
}
ownerGroups := groups.ownerGroups
modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
if modelLimitEnable {
s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
@@ -242,12 +138,37 @@ func ListModels(c *gin.Context, modelType int) {
continue
}
}
userModelNames = append(userModelNames, allowModel)
if oaiModel, ok := openAIModelsMap[allowModel]; ok {
oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(allowModel)
userOpenAiModels = append(userOpenAiModels, oaiModel)
} else {
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
Id: allowModel,
Object: "model",
Created: 1626777600,
OwnedBy: "custom",
SupportedEndpointTypes: model.GetModelSupportEndpointTypes(allowModel),
})
}
}
} else {
userId := c.GetInt("id")
userGroup, err := model.GetUserGroup(userId, false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "get user group failed",
})
return
}
group := userGroup
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
if tokenGroup != "" {
group = tokenGroup
}
var models []string
if groups.tokenGroup == "auto" {
for _, autoGroup := range ownerGroups {
if tokenGroup == "auto" {
for _, autoGroup := range service.GetUserAutoGroup(userGroup) {
groupModels := model.GetGroupEnabledModels(autoGroup)
for _, g := range groupModels {
if !common.StringsContains(models, g) {
@@ -256,7 +177,7 @@ func ListModels(c *gin.Context, modelType int) {
}
}
} else {
models = model.GetGroupEnabledModels(ownerGroups[0])
models = model.GetGroupEnabledModels(group)
}
for _, modelName := range models {
if !acceptUnsetRatioModel {
@@ -264,19 +185,21 @@ func ListModels(c *gin.Context, modelType int) {
continue
}
}
userModelNames = append(userModelNames, modelName)
if oaiModel, ok := openAIModelsMap[modelName]; ok {
oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(modelName)
userOpenAiModels = append(userOpenAiModels, oaiModel)
} else {
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: "custom",
SupportedEndpointTypes: model.GetModelSupportEndpointTypes(modelName),
})
}
}
}
ownerByModel := map[string]string{}
if len(ownerGroups) > 0 {
ownerByModel = getPreferredModelOwners(userModelNames, ownerGroups)
}
userOpenAiModels := make([]dto.OpenAIModels, 0, len(userModelNames))
for _, modelName := range userModelNames {
userOpenAiModels = append(userOpenAiModels, buildOpenAIModel(modelName, ownerByModel))
}
switch modelType {
case constant.ChannelTypeAnthropic:
useranthropicModels := make([]dto.AnthropicModel, len(userOpenAiModels))
-85
View File
@@ -1,85 +0,0 @@
package controller
import (
"net/http/httptest"
"testing"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestChannelOwnerNameUsesAdaptorChannelName(t *testing.T) {
tests := []struct {
name string
channelType int
expected string
}{
{
name: "openai",
channelType: constant.ChannelTypeOpenAI,
expected: "openai",
},
{
name: "codex",
channelType: constant.ChannelTypeCodex,
expected: "codex",
},
{
name: "openrouter",
channelType: constant.ChannelTypeOpenRouter,
expected: "openrouter",
},
{
name: "azure fallback",
channelType: constant.ChannelTypeAzure,
expected: "azure",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.expected, channelOwnerName(tt.channelType))
})
}
}
func TestBuildOpenAIModelOverridesOwnedBy(t *testing.T) {
modelItem := buildOpenAIModel("gpt-5.4", map[string]string{"gpt-5.4": "openai"})
require.Equal(t, "gpt-5.4", modelItem.Id)
require.Equal(t, "openai", modelItem.OwnedBy)
}
func TestBuildOpenAIModelFallsBackToCustomForUnknownModels(t *testing.T) {
modelItem := buildOpenAIModel("custom-test-model", nil)
require.Equal(t, "custom-test-model", modelItem.Id)
require.Equal(t, "custom", modelItem.OwnedBy)
}
func TestGetModelListGroupsUsesUserGroupWhenTokenGroupIsEmpty(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
common.SetContextKey(ctx, constant.ContextKeyUserGroup, "default")
groups, err := getModelListGroups(ctx)
require.NoError(t, err)
require.Equal(t, "default", groups.userGroup)
require.Empty(t, groups.tokenGroup)
require.Equal(t, []string{"default"}, groups.ownerGroups)
}
func TestGetModelListGroupsUsesExplicitTokenGroup(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
common.SetContextKey(ctx, constant.ContextKeyUserGroup, "default")
common.SetContextKey(ctx, constant.ContextKeyTokenGroup, "vip")
groups, err := getModelListGroups(ctx)
require.NoError(t, err)
require.Equal(t, "default", groups.userGroup)
require.Equal(t, "vip", groups.tokenGroup)
require.Equal(t, []string{"vip"}, groups.ownerGroups)
}
+10 -1
View File
@@ -42,6 +42,15 @@ func isPositiveOptionValue(value string) bool {
return err == nil && floatValue > 0
}
func isVisiblePublicKeyOption(key string) bool {
switch key {
case "WaffoPancakeWebhookPublicKey", "WaffoPancakeWebhookTestKey":
return true
default:
return false
}
}
func collectModelNamesFromOptionValue(raw string, modelNames map[string]struct{}) {
if strings.TrimSpace(raw) == "" {
return
@@ -86,7 +95,7 @@ func GetOptions(c *gin.Context) {
strings.HasSuffix(k, "Key") ||
strings.HasSuffix(k, "secret") ||
strings.HasSuffix(k, "api_key")
if isSensitiveKey {
if isSensitiveKey && !isVisiblePublicKeyOption(k) {
continue
}
options = append(options, &model.Option{
+13 -4
View File
@@ -77,15 +77,24 @@ func isWaffoPancakeTopUpEnabled() bool {
if !isPaymentComplianceConfirmed() {
return false
}
// Presence-of-credentials = enabled. Webhook public keys ship inside
// the SDK; mode (test/prod) is read from each event.
return strings.TrimSpace(setting.WaffoPancakeMerchantID) != "" &&
if !setting.WaffoPancakeEnabled {
return false
}
return isWaffoPancakeWebhookConfigured() &&
strings.TrimSpace(setting.WaffoPancakeMerchantID) != "" &&
strings.TrimSpace(setting.WaffoPancakePrivateKey) != "" &&
strings.TrimSpace(setting.WaffoPancakeStoreID) != "" &&
strings.TrimSpace(setting.WaffoPancakeProductID) != ""
}
func isWaffoPancakeWebhookConfigured() bool {
return isWaffoPancakeTopUpEnabled()
currentWebhookKey := strings.TrimSpace(setting.WaffoPancakeWebhookPublicKey)
if setting.WaffoPancakeSandbox {
currentWebhookKey = strings.TrimSpace(setting.WaffoPancakeWebhookTestKey)
}
return currentWebhookKey != ""
}
func isWaffoPancakeWebhookEnabled() bool {
@@ -114,32 +114,47 @@ func TestWaffoWebhookEnabledRequiresTopUpAndWebhookConfig(t *testing.T) {
func TestWaffoPancakeWebhookEnabledRequiresTopUpAndWebhookConfig(t *testing.T) {
confirmPaymentComplianceForTest(t)
originalEnabled := setting.WaffoPancakeEnabled
originalSandbox := setting.WaffoPancakeSandbox
originalMerchantID := setting.WaffoPancakeMerchantID
originalPrivateKey := setting.WaffoPancakePrivateKey
originalWebhookPublicKey := setting.WaffoPancakeWebhookPublicKey
originalWebhookTestKey := setting.WaffoPancakeWebhookTestKey
originalStoreID := setting.WaffoPancakeStoreID
originalProductID := setting.WaffoPancakeProductID
t.Cleanup(func() {
setting.WaffoPancakeEnabled = originalEnabled
setting.WaffoPancakeSandbox = originalSandbox
setting.WaffoPancakeMerchantID = originalMerchantID
setting.WaffoPancakePrivateKey = originalPrivateKey
setting.WaffoPancakeWebhookPublicKey = originalWebhookPublicKey
setting.WaffoPancakeWebhookTestKey = originalWebhookTestKey
setting.WaffoPancakeStoreID = originalStoreID
setting.WaffoPancakeProductID = originalProductID
})
// Presence of all three credentials enables the gateway. Webhook public
// keys are bundled in the SDK and there is no separate Enabled toggle —
// clear any of the three fields to disable.
setting.WaffoPancakeMerchantID = ""
setting.WaffoPancakeEnabled = true
setting.WaffoPancakeSandbox = false
setting.WaffoPancakeMerchantID = "merchant"
setting.WaffoPancakePrivateKey = "private"
setting.WaffoPancakeStoreID = "store"
setting.WaffoPancakeProductID = "product"
setting.WaffoPancakeWebhookPublicKey = ""
require.False(t, isWaffoPancakeWebhookEnabled())
setting.WaffoPancakeMerchantID = "merchant"
setting.WaffoPancakeWebhookPublicKey = "public"
require.True(t, isWaffoPancakeWebhookEnabled())
setting.WaffoPancakeProductID = ""
setting.WaffoPancakeEnabled = false
require.False(t, isWaffoPancakeWebhookEnabled())
setting.WaffoPancakeProductID = "product"
setting.WaffoPancakePrivateKey = ""
setting.WaffoPancakeEnabled = true
setting.WaffoPancakeSandbox = true
setting.WaffoPancakeWebhookTestKey = ""
require.False(t, isWaffoPancakeWebhookEnabled())
setting.WaffoPancakeWebhookTestKey = "test_public"
require.True(t, isWaffoPancakeWebhookEnabled())
}
func TestEpayWebhookEnabledRequiresTopUpAndWebhookConfig(t *testing.T) {
+2 -2
View File
@@ -88,7 +88,7 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
defer func() {
if newAPIError != nil {
logger.LogError(c, fmt.Sprintf("relay error: %s", common.LocalLogPreview(newAPIError.Error())))
logger.LogError(c, fmt.Sprintf("relay error: %s", newAPIError.Error()))
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
switch relayFormat {
case types.RelayFormatOpenAIRealtime:
@@ -354,7 +354,7 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b
}
func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
logger.LogError(c, fmt.Sprintf("channel error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, common.LocalLogPreview(err.Error())))
logger.LogError(c, fmt.Sprintf("channel error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
if service.ShouldDisableChannel(err) && channelError.AutoBan {
-32
View File
@@ -22,10 +22,6 @@ type BillingPreferenceRequest struct {
BillingPreference string `json:"billing_preference"`
}
type SubscriptionBalancePayRequest struct {
PlanId int `json:"plan_id"`
}
// ---- User APIs ----
func GetSubscriptionPlans(c *gin.Context) {
@@ -41,7 +37,6 @@ func GetSubscriptionPlans(c *gin.Context) {
}
result := make([]SubscriptionPlanDTO, 0, len(plans))
for _, p := range plans {
p.NormalizeDefaults()
result = append(result, SubscriptionPlanDTO{
Plan: p,
})
@@ -97,25 +92,6 @@ func UpdateSubscriptionPreference(c *gin.Context) {
common.ApiSuccess(c, gin.H{"billing_preference": pref})
}
func SubscriptionRequestBalancePay(c *gin.Context) {
if !requirePaymentCompliance(c) {
return
}
userId := c.GetInt("id")
var req SubscriptionBalancePayRequest
if err := c.ShouldBindJSON(&req); err != nil || req.PlanId <= 0 {
common.ApiErrorMsg(c, "参数错误")
return
}
if err := model.PurchaseSubscriptionWithBalance(userId, req.PlanId); err != nil {
common.ApiError(c, err)
return
}
common.ApiSuccess(c, nil)
}
// ---- Admin APIs ----
func AdminListSubscriptionPlans(c *gin.Context) {
@@ -126,7 +102,6 @@ func AdminListSubscriptionPlans(c *gin.Context) {
}
result := make([]SubscriptionPlanDTO, 0, len(plans))
for _, p := range plans {
p.NormalizeDefaults()
result = append(result, SubscriptionPlanDTO{
Plan: p,
})
@@ -165,9 +140,6 @@ func AdminCreateSubscriptionPlan(c *gin.Context) {
req.Plan.Currency = "USD"
}
req.Plan.Currency = "USD"
if req.Plan.AllowBalancePay == nil {
req.Plan.AllowBalancePay = common.GetPointer(true)
}
if req.Plan.DurationUnit == "" {
req.Plan.DurationUnit = model.SubscriptionDurationMonth
}
@@ -276,7 +248,6 @@ func AdminUpdateSubscriptionPlan(c *gin.Context) {
"sort_order": req.Plan.SortOrder,
"stripe_price_id": req.Plan.StripePriceId,
"creem_product_id": req.Plan.CreemProductId,
"waffo_pancake_product_id": req.Plan.WaffoPancakeProductId,
"max_purchase_per_user": req.Plan.MaxPurchasePerUser,
"total_amount": req.Plan.TotalAmount,
"upgrade_group": req.Plan.UpgradeGroup,
@@ -284,9 +255,6 @@ func AdminUpdateSubscriptionPlan(c *gin.Context) {
"quota_reset_custom_seconds": req.Plan.QuotaResetCustomSeconds,
"updated_at": common.GetTimestamp(),
}
if req.Plan.AllowBalancePay != nil {
updateMap["allow_balance_pay"] = *req.Plan.AllowBalancePay
}
if err := tx.Model(&model.SubscriptionPlan{}).Where("id = ?", id).Updates(updateMap).Error; err != nil {
return err
}
@@ -1,130 +0,0 @@
package controller
import (
"fmt"
"net/http"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting"
"github.com/gin-gonic/gin"
"github.com/shopspring/decimal"
"github.com/thanhpk/randstr"
)
type SubscriptionWaffoPancakePayRequest struct {
PlanId int `json:"plan_id"`
}
func SubscriptionRequestWaffoPancakePay(c *gin.Context) {
if !requirePaymentCompliance(c) {
return
}
var req SubscriptionWaffoPancakePayRequest
if err := c.ShouldBindJSON(&req); err != nil || req.PlanId <= 0 {
common.ApiErrorMsg(c, "参数错误")
return
}
plan, err := model.GetSubscriptionPlanById(req.PlanId)
if err != nil {
common.ApiError(c, err)
return
}
if !plan.Enabled {
common.ApiErrorMsg(c, "套餐未启用")
return
}
if strings.TrimSpace(plan.WaffoPancakeProductId) == "" {
common.ApiErrorMsg(c, "该套餐未配置 WaffoPancakeProductId")
return
}
// Plan targets its own Pancake product, so we only require credentials
// here — not the gateway-level WaffoPancakeProductID.
if strings.TrimSpace(setting.WaffoPancakeMerchantID) == "" ||
strings.TrimSpace(setting.WaffoPancakePrivateKey) == "" {
common.ApiErrorMsg(c, "Waffo Pancake 未配置或密钥无效")
return
}
userId := c.GetInt("id")
user, err := model.GetUserById(userId, false)
if err != nil {
common.ApiError(c, err)
return
}
if user == nil {
common.ApiErrorMsg(c, "用户不存在")
return
}
if plan.MaxPurchasePerUser > 0 {
count, err := model.CountUserSubscriptionsByPlan(userId, plan.Id)
if err != nil {
common.ApiError(c, err)
return
}
if count >= int64(plan.MaxPurchasePerUser) {
common.ApiErrorMsg(c, "已达到该套餐购买上限")
return
}
}
// WAFFO_PANCAKE_SUB- prefix (vs. wallet's WAFFO_PANCAKE-) drives webhook
// dispatch in WaffoPancakeWebhook.
tradeNo := fmt.Sprintf("WAFFO_PANCAKE_SUB-%d-%d-%s", userId, time.Now().UnixMilli(), randstr.String(6))
order := &model.SubscriptionOrder{
UserId: userId,
PlanId: plan.Id,
Money: plan.PriceAmount,
TradeNo: tradeNo,
PaymentMethod: model.PaymentMethodWaffoPancake,
PaymentProvider: model.PaymentProviderWaffoPancake,
CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending,
}
if err := order.Insert(); err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake 订阅订单创建失败 user_id=%d plan_id=%d trade_no=%s error=%q", userId, plan.Id, tradeNo, err.Error()))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"})
return
}
expiresInSeconds := 45 * 60
session, err := service.CreateWaffoPancakeCheckoutSession(c.Request.Context(), &service.WaffoPancakeCreateSessionParams{
ProductID: plan.WaffoPancakeProductId,
BuyerIdentity: service.WaffoPancakeBuyerIdentityFromUserID(user.Id),
PriceSnapshot: &service.WaffoPancakePriceSnapshot{
Amount: decimal.NewFromFloat(plan.PriceAmount).StringFixed(2),
TaxCategory: "saas",
},
BuyerEmail: getWaffoPancakeBuyerEmail(user),
ExpiresInSeconds: &expiresInSeconds,
OrderMerchantExternalID: tradeNo,
})
if err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake 订阅结账会话创建失败 user_id=%d plan_id=%d trade_no=%s error=%q", userId, plan.Id, tradeNo, err.Error()))
order.Status = common.TopUpStatusFailed
_ = order.Update()
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉起支付失败"})
return
}
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo Pancake 订阅订单创建成功 user_id=%d plan_id=%d trade_no=%s session_id=%s money=%.2f", userId, plan.Id, tradeNo, session.SessionID, plan.PriceAmount))
c.JSON(http.StatusOK, gin.H{
"message": "success",
"data": gin.H{
"checkout_url": session.CheckoutURL,
"session_id": session.SessionID,
"expires_at": session.ExpiresAt,
"order_id": tradeNo,
"token": session.Token,
"token_expires_at": session.TokenExpiresAt,
},
})
}
+20 -21
View File
@@ -52,27 +52,6 @@ func GetTopUpInfo(c *gin.Context) {
}
}
// Waffo Pancake displayed above the legacy Waffo gateway.
enableWaffoPancake := isWaffoPancakeTopUpEnabled()
if enableWaffoPancake {
hasWaffoPancake := false
for _, method := range payMethods {
if method["type"] == model.PaymentMethodWaffoPancake {
hasWaffoPancake = true
break
}
}
if !hasWaffoPancake {
payMethods = append(payMethods, map[string]string{
"name": "Waffo Pancake",
"type": model.PaymentMethodWaffoPancake,
"color": "rgba(var(--semi-orange-5), 1)",
"min_topup": strconv.Itoa(setting.WaffoPancakeMinTopUp),
})
}
}
// 如果启用了 Waffo 支付,添加到支付方法列表
enableWaffo := isWaffoTopUpEnabled()
if enableWaffo {
@@ -95,6 +74,26 @@ func GetTopUpInfo(c *gin.Context) {
}
}
enableWaffoPancake := isWaffoPancakeTopUpEnabled()
if enableWaffoPancake {
hasWaffoPancake := false
for _, method := range payMethods {
if method["type"] == model.PaymentMethodWaffoPancake {
hasWaffoPancake = true
break
}
}
if !hasWaffoPancake {
payMethods = append(payMethods, map[string]string{
"name": "Waffo Pancake",
"type": model.PaymentMethodWaffoPancake,
"color": "rgba(var(--semi-orange-5), 1)",
"min_topup": strconv.Itoa(setting.WaffoPancakeMinTopUp),
})
}
}
data := gin.H{
"enable_online_topup": isEpayTopUpEnabled(),
"enable_stripe_topup": isStripeTopUpEnabled(),
+33 -311
View File
@@ -96,257 +96,33 @@ func getWaffoPancakeBuyerEmail(user *model.User) string {
if user != nil && strings.TrimSpace(user.Email) != "" {
return user.Email
}
if user != nil {
return fmt.Sprintf("%d@new-api.local", user.Id)
}
return ""
}
// The admin config endpoints below accept typed-but-not-yet-saved creds in
// the body and fall back to persisted creds when the body is blank (see
// resolveWaffoPancakeAdminCreds). Only SaveWaffoPancake writes to OptionMap.
type waffoPancakeCredsRequest struct {
MerchantID string `json:"merchant_id"`
PrivateKey string `json:"private_key"`
}
type saveWaffoPancakeRequest struct {
MerchantID string `json:"merchant_id"`
PrivateKey string `json:"private_key"`
ReturnURL string `json:"return_url"`
StoreID string `json:"store_id"`
ProductID string `json:"product_id"`
}
type createWaffoPancakePairRequest struct {
MerchantID string `json:"merchant_id"`
PrivateKey string `json:"private_key"`
ReturnURL string `json:"return_url"`
}
// SaveWaffoPancake atomically persists all five operator-controlled fields.
// Catalog / pair endpoints are transient — only this one writes the OptionMap.
func SaveWaffoPancake(c *gin.Context) {
var req saveWaffoPancakeRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
return
func getWaffoPancakeReturnURL() string {
if strings.TrimSpace(setting.WaffoPancakeReturnURL) != "" {
return setting.WaffoPancakeReturnURL
}
if err := service.SaveWaffoPancakeConfig(
c.Request.Context(),
req.MerchantID,
req.PrivateKey,
req.ReturnURL,
req.StoreID,
req.ProductID,
); err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf(
"Waffo Pancake 保存配置失败 store_id=%q product_id=%q error=%q",
req.StoreID, req.ProductID, err.Error(),
))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "保存配置失败"})
return
}
c.JSON(http.StatusOK, gin.H{
"message": "success",
"data": gin.H{
"product_id": setting.WaffoPancakeProductID,
"store_id": setting.WaffoPancakeStoreID,
},
})
}
// resolveWaffoPancakeAdminCreds prefers body creds (typed-but-not-yet-saved
// values, for verification) and falls back to persisted creds when the body
// is blank (so returning admins don't have to re-paste the private key,
// which is stripped from GET /api/option/).
func resolveWaffoPancakeAdminCreds(bodyMerchantID, bodyPrivateKey string) (string, string) {
m := strings.TrimSpace(bodyMerchantID)
k := strings.TrimSpace(bodyPrivateKey)
if m == "" && k == "" {
return setting.WaffoPancakeMerchantID, setting.WaffoPancakePrivateKey
}
return m, k
}
// CreateWaffoPancakePair mints a Store + OnetimeProduct pair in one round-
// trip. Surfaces an orphan-store flag when the product half fails so the
// frontend can preselect / retry without losing context.
func CreateWaffoPancakePair(c *gin.Context) {
var req createWaffoPancakePairRequest
if c.Request.ContentLength > 0 {
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
return
}
}
merchantID, privateKey := resolveWaffoPancakeAdminCreds(req.MerchantID, req.PrivateKey)
if merchantID == "" || privateKey == "" {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "Waffo Pancake 凭证未配置"})
return
}
result, err := service.CreateWaffoPancakePrimaryPair(
c.Request.Context(), merchantID, privateKey, req.ReturnURL,
)
if err != nil {
orphan := result != nil && result.OrphanStore
logger.LogError(c.Request.Context(), fmt.Sprintf(
"Waffo Pancake 创建店铺与产品失败 orphan_store=%t store_id=%q error=%q",
orphan, func() string {
if result == nil {
return ""
}
return result.StoreID
}(), err.Error(),
))
data := gin.H{"error": err.Error()}
if orphan {
data["store_id"] = result.StoreID
data["store_name"] = result.StoreName
data["orphan_store"] = true
}
c.JSON(http.StatusOK, gin.H{"message": "error", "data": data})
return
}
c.JSON(http.StatusOK, gin.H{
"message": "success",
"data": gin.H{
"store_id": result.StoreID,
"store_name": result.StoreName,
"product_id": result.ProductID,
"product_name": result.ProductName,
},
})
}
// ListWaffoPancakeCatalog returns the merchant's Stores + OnetimeProducts.
// Doubles as a credential probe (a successful 200 proves the resolved creds
// authenticate). See resolveWaffoPancakeAdminCreds for credential resolution.
func ListWaffoPancakeCatalog(c *gin.Context) {
var req waffoPancakeCredsRequest
// An empty body means "use persisted creds"; only fail on malformed JSON.
if c.Request.ContentLength > 0 {
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
return
}
}
merchantID, privateKey := resolveWaffoPancakeAdminCreds(req.MerchantID, req.PrivateKey)
if merchantID == "" || privateKey == "" {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "Waffo Pancake 凭证未配置"})
return
}
catalog, err := service.ListWaffoPancakeCatalog(c.Request.Context(), merchantID, privateKey)
if err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf(
"Waffo Pancake 拉取店铺与产品目录失败 error=%q", err.Error(),
))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉取目录失败"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "success", "data": catalog})
}
type createWaffoPancakeSubscriptionProductRequest struct {
Name string `json:"name"`
Amount string `json:"amount"`
}
// CreateWaffoPancakeSubscriptionProduct mints an OnetimeProduct (not
// SubscriptionProduct — see service.CreateWaffoPancakeProductForPlan)
// sized to a plan's `name` + `amount`, using persisted Pancake credentials
// + StoreID. Reads from the form, not the plan row, so newly-typed unsaved
// plans can mint a product too.
func CreateWaffoPancakeSubscriptionProduct(c *gin.Context) {
var req createWaffoPancakeSubscriptionProductRequest
if c.Request.ContentLength > 0 {
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
return
}
}
if strings.TrimSpace(req.Name) == "" {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "套餐名称不能为空"})
return
}
if strings.TrimSpace(req.Amount) == "" {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "套餐价格不能为空"})
return
}
merchantID, privateKey := resolveWaffoPancakeAdminCreds("", "")
storeID := strings.TrimSpace(setting.WaffoPancakeStoreID)
if merchantID == "" || privateKey == "" || storeID == "" {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "Waffo Pancake 未完成配置,请先在支付设置中完成网关绑定"})
return
}
productID, err := service.CreateWaffoPancakeProductForPlan(
c.Request.Context(),
merchantID,
privateKey,
storeID,
req.Name,
req.Amount,
setting.WaffoPancakeReturnURL,
)
if err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf(
"Waffo Pancake 创建套餐产品失败 store_id=%q name=%q amount=%q error=%q",
storeID, req.Name, req.Amount, err.Error(),
))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建套餐产品失败"})
return
}
c.JSON(http.StatusOK, gin.H{
"message": "success",
"data": gin.H{
"product_id": productID,
"product_name": req.Name,
"store_id": storeID,
},
})
}
// ListWaffoPancakeSubscriptionProductOptions returns the OnetimeProducts
// in the saved Pancake store, for the subscription-plan dropdown. The name
// reflects new-api's plan concept; under the hood it's still OnetimeProducts.
func ListWaffoPancakeSubscriptionProductOptions(c *gin.Context) {
merchantID, privateKey := resolveWaffoPancakeAdminCreds("", "")
storeID := strings.TrimSpace(setting.WaffoPancakeStoreID)
if merchantID == "" || privateKey == "" || storeID == "" {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "Waffo Pancake 未完成配置,请先在支付设置中完成网关绑定"})
return
}
catalog, err := service.ListWaffoPancakeCatalog(c.Request.Context(), merchantID, privateKey)
if err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf(
"Waffo Pancake 拉取订阅产品列表失败 store_id=%q error=%q", storeID, err.Error(),
))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉取产品列表失败"})
return
}
products := []service.WaffoPancakeCatalogProduct{}
for _, store := range catalog.Stores {
if store.ID == storeID {
products = store.OnetimeProducts
break
}
}
c.JSON(http.StatusOK, gin.H{
"message": "success",
"data": gin.H{
"store_id": storeID,
"products": products,
},
})
}
func getWaffoPancakeBuyerIdentity(user *model.User) string {
if user == nil {
return ""
}
return service.WaffoPancakeBuyerIdentityFromUserID(user.Id)
return paymentReturnPath("/console/topup?show_history=true")
}
func RequestWaffoPancakePay(c *gin.Context) {
if !isWaffoPancakeTopUpEnabled() {
if !setting.WaffoPancakeEnabled {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "Waffo Pancake 支付未启用"})
return
}
currentWebhookKey := setting.WaffoPancakeWebhookPublicKey
if setting.WaffoPancakeSandbox {
currentWebhookKey = setting.WaffoPancakeWebhookTestKey
}
if strings.TrimSpace(setting.WaffoPancakeMerchantID) == "" ||
strings.TrimSpace(setting.WaffoPancakePrivateKey) == "" ||
strings.TrimSpace(currentWebhookKey) == "" ||
strings.TrimSpace(setting.WaffoPancakeStoreID) == "" ||
strings.TrimSpace(setting.WaffoPancakeProductID) == "" {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "Waffo Pancake 配置不完整"})
return
}
@@ -399,15 +175,18 @@ func RequestWaffoPancakePay(c *gin.Context) {
expiresInSeconds := 45 * 60
session, err := service.CreateWaffoPancakeCheckoutSession(c.Request.Context(), &service.WaffoPancakeCreateSessionParams{
ProductID: setting.WaffoPancakeProductID,
BuyerIdentity: getWaffoPancakeBuyerIdentity(user),
StoreID: setting.WaffoPancakeStoreID,
ProductID: setting.WaffoPancakeProductID,
ProductType: "onetime",
Currency: strings.ToUpper(strings.TrimSpace(setting.WaffoPancakeCurrency)),
PriceSnapshot: &service.WaffoPancakePriceSnapshot{
Amount: formatWaffoPancakeAmount(payMoney),
TaxIncluded: false,
TaxCategory: "saas",
},
BuyerEmail: getWaffoPancakeBuyerEmail(user),
ExpiresInSeconds: &expiresInSeconds,
OrderMerchantExternalID: tradeNo,
BuyerEmail: getWaffoPancakeBuyerEmail(user),
SuccessURL: getWaffoPancakeReturnURL(),
ExpiresInSeconds: &expiresInSeconds,
})
if err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake 创建结账会话失败 user_id=%d trade_no=%s error=%q", id, tradeNo, err.Error()))
@@ -421,12 +200,10 @@ func RequestWaffoPancakePay(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"message": "success",
"data": gin.H{
"checkout_url": session.CheckoutURL,
"session_id": session.SessionID,
"expires_at": session.ExpiresAt,
"order_id": tradeNo,
"token": session.Token,
"token_expires_at": session.TokenExpiresAt,
"checkout_url": session.CheckoutURL,
"session_id": session.SessionID,
"expires_at": session.ExpiresAt,
"order_id": tradeNo,
},
})
}
@@ -438,19 +215,6 @@ func WaffoPancakeWebhook(c *gin.Context) {
return
}
// :env splits test vs prod traffic at the routing layer — operator
// registers each URL in the matching webhook slot in Pancake's dashboard.
// We then enforce event.mode == expectedEnv to catch mis-registrations.
expectedEnv := strings.TrimSpace(c.Param("env"))
if expectedEnv != "test" && expectedEnv != "prod" {
logger.LogWarn(c.Request.Context(), fmt.Sprintf(
"Waffo Pancake webhook 路径环境段无效 env=%q path=%q client_ip=%s",
expectedEnv, c.Request.RequestURI, c.ClientIP(),
))
c.String(http.StatusNotFound, "unknown env")
return
}
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake webhook 读取请求体失败 path=%q client_ip=%s error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
@@ -468,57 +232,15 @@ func WaffoPancakeWebhook(c *gin.Context) {
return
}
if !strings.EqualFold(strings.TrimSpace(event.Mode), expectedEnv) {
logger.LogError(c.Request.Context(), fmt.Sprintf(
"Waffo Pancake webhook 环境不匹配 expected=%q actual_mode=%q event_id=%s order_id=%s client_ip=%s",
expectedEnv, event.Mode, event.ID, event.Data.OrderID, c.ClientIP(),
))
c.String(http.StatusOK, "OK")
return
}
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo Pancake webhook 验签成功 event_type=%s event_id=%s order_id=%s client_ip=%s", event.NormalizedEventType(), event.ID, event.Data.OrderID, c.ClientIP()))
if event.NormalizedEventType() != "order.completed" {
c.String(http.StatusOK, "OK")
return
}
// Dispatch by trade_no prefix. OrderMerchantExternalID = our trade_no;
// OrderID is Pancake's internal ORD_* (logs only).
rawTradeNo := strings.TrimSpace(event.Data.OrderMerchantExternalID)
isSubscription := strings.HasPrefix(rawTradeNo, "WAFFO_PANCAKE_SUB-")
if isSubscription {
tradeNo, err := service.ResolveWaffoPancakeSubscriptionTradeNo(event)
if err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf(
"Waffo Pancake webhook 订阅订单解析失败 event_id=%s order_id=%s buyer_identity=%q client_ip=%s error=%q",
event.ID, event.Data.OrderID, event.Data.MerchantProvidedBuyerIdentity, c.ClientIP(), err.Error(),
))
c.String(http.StatusOK, "OK")
return
}
LockOrder(tradeNo)
defer UnlockOrder(tradeNo)
if err := model.CompleteSubscriptionOrder(tradeNo, string(bodyBytes), model.PaymentProviderWaffoPancake, ""); err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake 订阅完成失败 trade_no=%s event_id=%s order_id=%s client_ip=%s error=%q", tradeNo, event.ID, event.Data.OrderID, c.ClientIP(), err.Error()))
c.String(http.StatusInternalServerError, "retry")
return
}
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo Pancake 订阅完成 trade_no=%s event_id=%s order_id=%s client_ip=%s", tradeNo, event.ID, event.Data.OrderID, c.ClientIP()))
c.String(http.StatusOK, "OK")
return
}
tradeNo, err := service.ResolveWaffoPancakeTradeNo(event)
if err != nil {
// LogError (not LogWarn): covers order-not-found and buyer-identity
// mismatch — both warrant human attention. 200 OK so Waffo doesn't
// retry a permanently-unresolvable webhook.
logger.LogError(c.Request.Context(), fmt.Sprintf(
"Waffo Pancake webhook 订单解析失败 event_id=%s order_id=%s buyer_identity=%q client_ip=%s error=%q",
event.ID, event.Data.OrderID, event.Data.MerchantProvidedBuyerIdentity, c.ClientIP(), err.Error(),
))
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Waffo Pancake webhook 订单号映射失败 event_id=%s order_id=%s error=%q", event.ID, event.Data.OrderID, err.Error()))
c.String(http.StatusOK, "OK")
return
}
+1 -13
View File
@@ -251,20 +251,8 @@ func GetAllUsers(c *gin.Context) {
func SearchUsers(c *gin.Context) {
keyword := c.Query("keyword")
group := c.Query("group")
var role *int
if roleStr := c.Query("role"); roleStr != "" {
if parsed, err := strconv.Atoi(roleStr); err == nil {
role = &parsed
}
}
var status *int
if statusStr := c.Query("status"); statusStr != "" {
if parsed, err := strconv.Atoi(statusStr); err == nil {
status = &parsed
}
}
pageInfo := common.GetPageQuery(c)
users, total, err := model.SearchUsers(keyword, group, role, status, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
users, total, err := model.SearchUsers(keyword, group, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
if err != nil {
common.ApiError(c, err)
return
-1
View File
@@ -34,7 +34,6 @@ services:
- BATCH_UPDATE_ENABLED=true # 是否启用批量更新 (Whether to enable batch update)
- NODE_NAME=new-api-node-1 # 节点名称,用于审计日志中标识节点身份;多节点/容器部署时建议设置 (Node name used in audit logs; recommended when running multiple instances or in containers)
# - STREAMING_TIMEOUT=300 # 流模式无响应超时时间,单位秒,默认120秒,如果出现空补全可以尝试改为更大值 (Streaming timeout in seconds, default is 120s. Increase if experiencing empty completions
# - RELAY_IDLE_CONN_TIMEOUT=90 # Relay HTTP 客户端空闲连接超时时间,单位秒,默认跟随 Go 标准库,设置为0表示不限制 (Relay HTTP client idle keep-alive timeout in seconds, defaults to Go standard library; set 0 to disable)
# - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!! (multi-node deployment, set this to a random string!!!!!!!
# - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed
# - GOOGLE_ANALYTICS_ID=G-XXXXXXXXXX # Google Analytics 的测量 ID (Google Analytics Measurement ID)
+6 -6
View File
@@ -26,11 +26,11 @@ type ImageRequest struct {
OutputFormat json.RawMessage `json:"output_format,omitempty"`
OutputCompression json.RawMessage `json:"output_compression,omitempty"`
PartialImages json.RawMessage `json:"partial_images,omitempty"`
Stream *bool `json:"stream,omitempty"`
Images json.RawMessage `json:"images,omitempty"`
Mask json.RawMessage `json:"mask,omitempty"`
InputFidelity json.RawMessage `json:"input_fidelity,omitempty"`
Watermark *bool `json:"watermark,omitempty"`
// Stream bool `json:"stream,omitempty"`
Images json.RawMessage `json:"images,omitempty"`
Mask json.RawMessage `json:"mask,omitempty"`
InputFidelity json.RawMessage `json:"input_fidelity,omitempty"`
Watermark *bool `json:"watermark,omitempty"`
// zhipu 4v
WatermarkEnabled json.RawMessage `json:"watermark_enabled,omitempty"`
UserId json.RawMessage `json:"user_id,omitempty"`
@@ -163,7 +163,7 @@ func (i *ImageRequest) GetTokenCountMeta() *types.TokenCountMeta {
}
func (i *ImageRequest) IsStream(c *gin.Context) bool {
return i.Stream != nil && *i.Stream
return false
}
func (i *ImageRequest) SetModelName(modelName string) {
+2 -12
View File
@@ -213,22 +213,12 @@ func (r *GeneralOpenAIRequest) ToMap() map[string]any {
return result
}
func IsOpenAIReasoningOModel(modelName string) bool {
return strings.HasPrefix(modelName, "o1") ||
strings.HasPrefix(modelName, "o3") ||
strings.HasPrefix(modelName, "o4")
}
func IsOpenAIGPT5Model(modelName string) bool {
return strings.HasPrefix(modelName, "gpt-5")
}
func (r *GeneralOpenAIRequest) GetSystemRoleName() string {
if IsOpenAIReasoningOModel(r.Model) {
if strings.HasPrefix(r.Model, "o") {
if !strings.HasPrefix(r.Model, "o1-mini") && !strings.HasPrefix(r.Model, "o1-preview") {
return "developer"
}
} else if IsOpenAIGPT5Model(r.Model) {
} else if strings.HasPrefix(r.Model, "gpt-5") {
return "developer"
}
return "system"
-24
View File
@@ -71,27 +71,3 @@ func TestOpenAIResponsesRequestPreserveExplicitZeroValues(t *testing.T) {
require.True(t, gjson.GetBytes(encoded, "stream").Exists())
require.True(t, gjson.GetBytes(encoded, "top_p").Exists())
}
func TestGeneralOpenAIRequestGetSystemRoleName(t *testing.T) {
tests := []struct {
name string
model string
want string
}{
{name: "o1 uses developer", model: "o1", want: "developer"},
{name: "o3 family uses developer", model: "o3-mini-high", want: "developer"},
{name: "o4 family uses developer", model: "o4-mini", want: "developer"},
{name: "o1 mini stays system", model: "o1-mini", want: "system"},
{name: "o1 preview stays system", model: "o1-preview", want: "system"},
{name: "gpt 5 uses developer", model: "gpt-5", want: "developer"},
{name: "omni is not o series", model: "omni-moderation-latest", want: "system"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := GeneralOpenAIRequest{Model: tt.model}
require.Equal(t, tt.want, req.GetSystemRoleName())
})
}
}
-2
View File
@@ -60,8 +60,6 @@ require (
gorm.io/gorm v1.25.2
)
require github.com/waffo-com/waffo-pancake-sdk-go v0.3.1
require (
github.com/DmitriyVTitov/size v1.5.0 // indirect
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 // indirect
-6
View File
@@ -308,12 +308,6 @@ github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65E
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/waffo-com/waffo-go v1.3.1 h1:NCYD3oQ59DTJj1bwS5T/659LI4h8PuAIW4Qj/w7fKPw=
github.com/waffo-com/waffo-go v1.3.1/go.mod h1:IaXVYq6mmYtrLFFsLxPslNwuIZx0mIadWWjhe+eWb0g=
github.com/waffo-com/waffo-pancake-sdk-go v0.1.1 h1:YOI7+3zTBlTB7Ou6+ZXnJV2JvW/ag9d7CwE/TxH3Hls=
github.com/waffo-com/waffo-pancake-sdk-go v0.1.1/go.mod h1:5MBCGH/nqRRA5sHO/lQB/96r4BTAqy8QpWxn53m9htI=
github.com/waffo-com/waffo-pancake-sdk-go v0.2.0 h1:cCSgccM66p7feTtgRqUUGT50tYQOhahsoPXavd+ib1U=
github.com/waffo-com/waffo-pancake-sdk-go v0.2.0/go.mod h1:5MBCGH/nqRRA5sHO/lQB/96r4BTAqy8QpWxn53m9htI=
github.com/waffo-com/waffo-pancake-sdk-go v0.3.1 h1:ngQSN/oVB35xTwFPLfg++bxPC+SptcF145Mb6c62YCc=
github.com/waffo-com/waffo-pancake-sdk-go v0.3.1/go.mod h1:OB2MyFIQaefoPO0FV3J+yu9sDP8RVFQ+sbFsXqGuObc=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
+5 -32
View File
@@ -1,8 +1,6 @@
FRONTEND_DIR = ./web/default
FRONTEND_CLASSIC_DIR = ./web/classic
BACKEND_DIR = .
DEV_FRONTEND_DEFAULT_PORT ?= 5173
DEV_FRONTEND_CLASSIC_PORT ?= 5174
DEV_COMPOSE_FILE = docker-compose.dev.yml
DEV_POSTGRES_SERVICE = postgres
DEV_BACKEND_SERVICE = new-api
@@ -16,13 +14,11 @@ all: build-all-frontends start-backend
build-frontend:
@echo "Building default frontend..."
@cd ./web && bun install --frozen-lockfile
@cd $(FRONTEND_DIR) && DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat ../../VERSION) bun run build
@cd $(FRONTEND_DIR) && bun install && DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat ../../VERSION) bun run build
build-frontend-classic:
@echo "Building classic frontend..."
@cd ./web && bun install --frozen-lockfile
@cd $(FRONTEND_CLASSIC_DIR) && VITE_REACT_APP_VERSION=$(cat ../../VERSION) bun run build
@cd $(FRONTEND_CLASSIC_DIR) && bun install && VITE_REACT_APP_VERSION=$(cat ../../VERSION) bun run build
build-all-frontends: build-frontend build-frontend-classic
@@ -39,35 +35,12 @@ dev-api-rebuild:
@docker compose -f $(DEV_COMPOSE_FILE) up -d --build $(DEV_BACKEND_SERVICE)
dev-web:
@echo "Starting both frontend dev servers..."
@echo "Default frontend: http://localhost:$(DEV_FRONTEND_DEFAULT_PORT)"
@echo "Classic frontend: http://localhost:$(DEV_FRONTEND_CLASSIC_PORT)"
@cd ./web && bun install
@(cd $(FRONTEND_DIR) && bun run dev -- --host 0.0.0.0 --port $(DEV_FRONTEND_DEFAULT_PORT)) & \
default_pid=$$!; \
(cd $(FRONTEND_CLASSIC_DIR) && bun run dev -- --host 0.0.0.0 --port $(DEV_FRONTEND_CLASSIC_PORT)) & \
classic_pid=$$!; \
trap 'kill $$default_pid $$classic_pid 2>/dev/null; wait $$default_pid $$classic_pid 2>/dev/null; exit 130' INT TERM; \
while kill -0 $$default_pid 2>/dev/null && kill -0 $$classic_pid 2>/dev/null; do \
sleep 1; \
done; \
if ! kill -0 $$default_pid 2>/dev/null; then \
wait $$default_pid; \
status=$$?; \
kill $$classic_pid 2>/dev/null; \
wait $$classic_pid 2>/dev/null; \
exit $$status; \
fi; \
wait $$classic_pid; \
status=$$?; \
kill $$default_pid 2>/dev/null; \
wait $$default_pid 2>/dev/null; \
exit $$status
@echo "Starting frontend dev server..."
@cd $(FRONTEND_DIR) && bun install && bun run dev
dev-web-classic:
@echo "Starting classic frontend dev server..."
@cd ./web && bun install
@cd $(FRONTEND_CLASSIC_DIR) && bun run dev -- --host 0.0.0.0 --port $(DEV_FRONTEND_CLASSIC_PORT)
@cd $(FRONTEND_CLASSIC_DIR) && bun install && bun run dev
dev: dev-api dev-web
+7 -89
View File
@@ -3,7 +3,6 @@ package middleware
import (
"errors"
"fmt"
"io"
"net/http"
"slices"
"strconv"
@@ -21,7 +20,6 @@ import (
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
)
type ModelRequest struct {
@@ -102,10 +100,14 @@ func Distribute() func(c *gin.Context) {
}
if preferredChannelID, found := service.GetPreferredChannelByAffinity(c, modelRequest.Model, usingGroup); found {
affinityUsable := false
preferred, err := model.CacheGetChannel(preferredChannelID)
if err == nil && preferred != nil && preferred.Status == common.ChannelStatusEnabled {
if usingGroup == "auto" {
if err == nil && preferred != nil {
if preferred.Status != common.ChannelStatusEnabled {
if service.ShouldSkipRetryAfterChannelAffinityFailure(c) {
abortWithOpenAiMessage(c, http.StatusForbidden, i18n.T(c, i18n.MsgDistributorAffinityChannelDisabled))
return
}
} else if usingGroup == "auto" {
userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
autoGroups := service.GetUserAutoGroup(userGroup)
for _, g := range autoGroups {
@@ -113,7 +115,6 @@ func Distribute() func(c *gin.Context) {
selectGroup = g
common.SetContextKey(c, constant.ContextKeyAutoGroup, g)
channel = preferred
affinityUsable = true
service.MarkChannelAffinityUsed(c, g, preferred.Id)
break
}
@@ -121,13 +122,9 @@ func Distribute() func(c *gin.Context) {
} else if model.IsChannelEnabledForGroupModel(usingGroup, modelRequest.Model, preferred.Id) {
channel = preferred
selectGroup = usingGroup
affinityUsable = true
service.MarkChannelAffinityUsed(c, usingGroup, preferred.Id)
}
}
if !affinityUsable && !service.ShouldKeepChannelAffinityOnChannelDisabled() {
service.ClearCurrentChannelAffinityCache(c)
}
}
if channel == nil {
@@ -173,14 +170,6 @@ func Distribute() func(c *gin.Context) {
// - application/x-www-form-urlencoded
// - multipart/form-data
func getModelFromRequest(c *gin.Context) (*ModelRequest, error) {
if strings.HasPrefix(c.Request.Header.Get("Content-Type"), "application/json") {
modelRequest, err := getModelFromJSONBody(c)
if err != nil {
return nil, errors.New(i18n.T(c, i18n.MsgDistributorInvalidRequest, map[string]any{"Error": err.Error()}))
}
return modelRequest, nil
}
var modelRequest ModelRequest
err := common.UnmarshalBodyReusable(c, &modelRequest)
if err != nil {
@@ -189,50 +178,6 @@ func getModelFromRequest(c *gin.Context) (*ModelRequest, error) {
return &modelRequest, nil
}
func getModelFromJSONBody(c *gin.Context) (*ModelRequest, error) {
storage, err := common.GetBodyStorage(c)
if err != nil {
return nil, err
}
requestBody, err := storage.Bytes()
if err != nil {
return nil, err
}
if !gjson.ValidBytes(requestBody) {
return nil, errors.New("invalid JSON request body")
}
values := gjson.GetManyBytes(requestBody, "model", "group")
model, err := getJSONStringValue(values[0], "model")
if err != nil {
return nil, err
}
group, err := getJSONStringValue(values[1], "group")
if err != nil {
return nil, err
}
if _, seekErr := storage.Seek(0, io.SeekStart); seekErr != nil {
return nil, seekErr
}
c.Request.Body = io.NopCloser(storage)
return &ModelRequest{
Model: model,
Group: group,
}, nil
}
func getJSONStringValue(result gjson.Result, field string) (string, error) {
if !result.Exists() || result.Type == gjson.Null {
return "", nil
}
if result.Type != gjson.String {
return "", fmt.Errorf("field %s must be a string", field)
}
return result.String(), nil
}
func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
var modelRequest ModelRequest
shouldSelectChannel := true
@@ -299,7 +244,6 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
} else if c.Request.Method == http.MethodGet {
relayMode = relayconstant.RelayModeVideoFetchByID
shouldSelectChannel = false
modelRequest.Model = getTaskOriginModelName(c)
}
c.Set("relay_mode", relayMode)
} else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") {
@@ -314,7 +258,6 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
} else if c.Request.Method == http.MethodGet {
relayMode = relayconstant.RelayModeVideoFetchByID
shouldSelectChannel = false
modelRequest.Model = getTaskOriginModelName(c)
}
if _, ok := c.Get("relay_mode"); !ok {
c.Set("relay_mode", relayMode)
@@ -399,31 +342,6 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
return &modelRequest, shouldSelectChannel, nil
}
// 修复 #4834: GET /v1/video/generations/:task_id && /v1/video/:task_id 此前不解析 model
// 当 token 启用「可用模型限制」时,下游 modelLimitEnable 校验会因
// modelRequest.Model 为空而误报 "This token has no access to model"。
// 从已存储的任务记录中回填 OriginModelName 即可让校验走在正确的模型上。
func getTaskOriginModelName(c *gin.Context) string {
if !common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled) {
return ""
}
taskId := c.Param("task_id")
if taskId == "" {
// jimeng adapter
taskId = c.GetString("task_id")
}
if taskId == "" {
return ""
}
userId := c.GetInt("id")
if task, exist, err := model.GetByTaskId(userId, taskId); err == nil && exist && task != nil {
return task.Properties.OriginModelName
}
return ""
}
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) *types.NewAPIError {
c.Set("original_model", modelName) // for retry
if channel == nil {
-47
View File
@@ -1,47 +0,0 @@
package middleware
import (
"bytes"
"io"
"net/http"
"github.com/QuantumNous/new-api/common"
"github.com/gin-gonic/gin"
)
func AnonymousRequestBodyLimit() gin.HandlerFunc {
return func(c *gin.Context) {
maxBytes := common.GetAnonymousRequestBodyLimitBytes()
if maxBytes <= 0 || c.Request.Body == nil {
c.Next()
return
}
originalBody := c.Request.Body
limitedBody, err := readAnonymousRequestBody(originalBody, maxBytes)
_ = originalBody.Close()
if err != nil {
if common.IsRequestBodyTooLargeError(err) {
c.AbortWithStatus(http.StatusRequestEntityTooLarge)
return
}
c.AbortWithStatus(http.StatusBadRequest)
return
}
c.Request.Body = io.NopCloser(bytes.NewReader(limitedBody))
c.Request.ContentLength = int64(len(limitedBody))
c.Next()
}
}
func readAnonymousRequestBody(body io.Reader, maxBytes int64) ([]byte, error) {
data, err := io.ReadAll(io.LimitReader(body, maxBytes+1))
if err != nil {
return nil, err
}
if int64(len(data)) > maxBytes {
return nil, common.ErrRequestBodyTooLarge
}
return data, nil
}
+49 -65
View File
@@ -17,39 +17,25 @@ import (
"gorm.io/gorm"
)
func applyExplicitLogTextFilter(tx *gorm.DB, column string, value string) (*gorm.DB, error) {
if value == "" {
return tx, nil
}
if strings.Contains(value, "%") {
pattern, err := sanitizeLikePattern(value)
if err != nil {
return nil, err
}
return tx.Where(column+" LIKE ? ESCAPE '!'", pattern), nil
}
return tx.Where(column+" = ?", value), nil
}
type Log struct {
Id int `json:"id" gorm:"index:idx_created_at_id,priority:2;index:idx_user_id_id,priority:2"`
UserId int `json:"user_id" gorm:"index;index:idx_user_id_id,priority:1"`
CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:1;index:idx_created_at_type"`
Type int `json:"type" gorm:"index:idx_created_at_type"`
Content string `json:"content"`
Username string `json:"username" gorm:"index;index:index_username_model_name,priority:2;default:''"`
TokenName string `json:"token_name" gorm:"index;default:''"`
ModelName string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"`
Quota int `json:"quota" gorm:"default:0"`
PromptTokens int `json:"prompt_tokens" gorm:"default:0"`
CompletionTokens int `json:"completion_tokens" gorm:"default:0"`
UseTime int `json:"use_time" gorm:"default:0"`
IsStream bool `json:"is_stream"`
ChannelId int `json:"channel" gorm:"index"`
ChannelName string `json:"channel_name" gorm:"->"`
TokenId int `json:"token_id" gorm:"default:0;index"`
Group string `json:"group" gorm:"index"`
Ip string `json:"ip" gorm:"index;default:''"`
Id int `json:"id" gorm:"index:idx_created_at_id,priority:1;index:idx_user_id_id,priority:2"`
UserId int `json:"user_id" gorm:"index;index:idx_user_id_id,priority:1"`
CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:2;index:idx_created_at_type"`
Type int `json:"type" gorm:"index:idx_created_at_type"`
Content string `json:"content"`
Username string `json:"username" gorm:"index;index:index_username_model_name,priority:2;default:''"`
TokenName string `json:"token_name" gorm:"index;default:''"`
ModelName string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"`
Quota int `json:"quota" gorm:"default:0"`
PromptTokens int `json:"prompt_tokens" gorm:"default:0"`
CompletionTokens int `json:"completion_tokens" gorm:"default:0"`
UseTime int `json:"use_time" gorm:"default:0"`
IsStream bool `json:"is_stream"`
ChannelId int `json:"channel" gorm:"index"`
ChannelName string `json:"channel_name" gorm:"->"`
TokenId int `json:"token_id" gorm:"default:0;index"`
Group string `json:"group" gorm:"index"`
Ip string `json:"ip" gorm:"index;default:''"`
RequestId string `json:"request_id,omitempty" gorm:"type:varchar(64);index:idx_logs_request_id;default:''"`
UpstreamRequestId string `json:"upstream_request_id,omitempty" gorm:"type:varchar(128);index:idx_logs_upstream_request_id;default:''"`
Other string `json:"other"`
@@ -160,7 +146,7 @@ func RecordTopupLog(userId int, content string, callerIp string, paymentMethod s
func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string, tokenName string, content string, tokenId int, useTimeSeconds int,
isStream bool, group string, other map[string]interface{}) {
logger.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, common.LocalLogPreview(content)))
logger.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content))
username := c.GetString("username")
requestId := c.GetString(common.RequestIdKey)
upstreamRequestId := c.GetString(common.UpstreamRequestIdKey)
@@ -323,15 +309,9 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
tx = LOG_DB.Where("logs.type = ?", logType)
}
if tx, err = applyExplicitLogTextFilter(tx, "logs.model_name", modelName); err != nil {
return nil, 0, err
}
if tx, err = applyExplicitLogTextFilter(tx, "logs.username", username); err != nil {
return nil, 0, err
}
if tokenName != "" {
tx = tx.Where("logs.token_name = ?", tokenName)
}
tx = applyLogContainsFilter(tx, "logs.model_name", modelName)
tx = applyLogContainsFilter(tx, "logs.username", username)
tx = applyLogContainsFilter(tx, "logs.token_name", tokenName)
if requestId != "" {
tx = tx.Where("logs.request_id = ?", requestId)
}
@@ -354,7 +334,7 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
if err != nil {
return nil, 0, err
}
err = tx.Order("logs.created_at desc, logs.id desc").Limit(num).Offset(startIdx).Find(&logs).Error
err = tx.Order("logs.id desc").Limit(num).Offset(startIdx).Find(&logs).Error
if err != nil {
return nil, 0, err
}
@@ -412,12 +392,8 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
tx = LOG_DB.Where("logs.user_id = ? and logs.type = ?", userId, logType)
}
if tx, err = applyExplicitLogTextFilter(tx, "logs.model_name", modelName); err != nil {
return nil, 0, err
}
if tokenName != "" {
tx = tx.Where("logs.token_name = ?", tokenName)
}
tx = applyLogContainsFilter(tx, "logs.model_name", modelName)
tx = applyLogContainsFilter(tx, "logs.token_name", tokenName)
if requestId != "" {
tx = tx.Where("logs.request_id = ?", requestId)
}
@@ -454,34 +430,42 @@ type Stat struct {
Tpm int `json:"tpm"`
}
func logContainsPattern(input string) (string, bool) {
input = strings.TrimSpace(input)
if input == "" {
return "", false
}
replacer := strings.NewReplacer("!", "!!", "%", "!%", "_", "!_")
return "%" + replacer.Replace(input) + "%", true
}
func applyLogContainsFilter(tx *gorm.DB, column string, value string) *gorm.DB {
pattern, ok := logContainsPattern(value)
if !ok {
return tx
}
return tx.Where(column+" LIKE ? ESCAPE '!'", pattern)
}
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int, group string) (stat Stat, err error) {
tx := LOG_DB.Table("logs").Select("sum(quota) quota")
// 为rpm和tpm创建单独的查询
rpmTpmQuery := LOG_DB.Table("logs").Select("count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm")
if tx, err = applyExplicitLogTextFilter(tx, "username", username); err != nil {
return stat, err
}
if rpmTpmQuery, err = applyExplicitLogTextFilter(rpmTpmQuery, "username", username); err != nil {
return stat, err
}
if tokenName != "" {
tx = tx.Where("token_name = ?", tokenName)
rpmTpmQuery = rpmTpmQuery.Where("token_name = ?", tokenName)
}
tx = applyLogContainsFilter(tx, "username", username)
rpmTpmQuery = applyLogContainsFilter(rpmTpmQuery, "username", username)
tx = applyLogContainsFilter(tx, "token_name", tokenName)
rpmTpmQuery = applyLogContainsFilter(rpmTpmQuery, "token_name", tokenName)
if startTimestamp != 0 {
tx = tx.Where("created_at >= ?", startTimestamp)
}
if endTimestamp != 0 {
tx = tx.Where("created_at <= ?", endTimestamp)
}
if tx, err = applyExplicitLogTextFilter(tx, "model_name", modelName); err != nil {
return stat, err
}
if rpmTpmQuery, err = applyExplicitLogTextFilter(rpmTpmQuery, "model_name", modelName); err != nil {
return stat, err
}
tx = applyLogContainsFilter(tx, "model_name", modelName)
rpmTpmQuery = applyLogContainsFilter(rpmTpmQuery, "model_name", modelName)
if channel != 0 {
tx = tx.Where("channel_id = ?", channel)
rpmTpmQuery = rpmTpmQuery.Where("channel_id = ?", channel)
-4
View File
@@ -397,10 +397,8 @@ func ensureSubscriptionPlanTableSQLite() error {
` + "`custom_seconds`" + ` bigint NOT NULL DEFAULT 0,
` + "`enabled`" + ` numeric DEFAULT 1,
` + "`sort_order`" + ` integer DEFAULT 0,
` + "`allow_balance_pay`" + ` numeric DEFAULT 1,
` + "`stripe_price_id`" + ` varchar(128) DEFAULT '',
` + "`creem_product_id`" + ` varchar(128) DEFAULT '',
` + "`waffo_pancake_product_id`" + ` varchar(128) DEFAULT '',
` + "`max_purchase_per_user`" + ` integer DEFAULT 0,
` + "`upgrade_group`" + ` varchar(64) DEFAULT '',
` + "`total_amount`" + ` bigint NOT NULL DEFAULT 0,
@@ -432,10 +430,8 @@ PRIMARY KEY (` + "`id`" + `)
{Name: "custom_seconds", DDL: "`custom_seconds` bigint NOT NULL DEFAULT 0"},
{Name: "enabled", DDL: "`enabled` numeric DEFAULT 1"},
{Name: "sort_order", DDL: "`sort_order` integer DEFAULT 0"},
{Name: "allow_balance_pay", DDL: "`allow_balance_pay` numeric DEFAULT 1"},
{Name: "stripe_price_id", DDL: "`stripe_price_id` varchar(128) DEFAULT ''"},
{Name: "creem_product_id", DDL: "`creem_product_id` varchar(128) DEFAULT ''"},
{Name: "waffo_pancake_product_id", DDL: "`waffo_pancake_product_id` varchar(128) DEFAULT ''"},
{Name: "max_purchase_per_user", DDL: "`max_purchase_per_user` integer DEFAULT 0"},
{Name: "upgrade_group", DDL: "`upgrade_group` varchar(64) DEFAULT ''"},
{Name: "total_amount", DDL: "`total_amount` bigint NOT NULL DEFAULT 0"},
-57
View File
@@ -2,7 +2,6 @@ package model
import (
"strconv"
"strings"
"github.com/QuantumNous/new-api/common"
@@ -136,62 +135,6 @@ func GetBoundChannelsByModelsMap(modelNames []string) (map[string][]BoundChannel
return result, nil
}
func normalizeLookupValues(values []string) []string {
seen := make(map[string]struct{}, len(values))
normalized := make([]string, 0, len(values))
for _, value := range values {
value = strings.TrimSpace(value)
if value == "" {
continue
}
if _, ok := seen[value]; ok {
continue
}
seen[value] = struct{}{}
normalized = append(normalized, value)
}
return normalized
}
func GetPreferredModelOwnerChannelTypes(modelNames []string, groups []string) (map[string]int, error) {
result := make(map[string]int)
modelNames = normalizeLookupValues(modelNames)
if len(modelNames) == 0 {
return result, nil
}
type row struct {
Model string
ChannelType int
}
var rows []row
query := DB.Table("abilities").
Select("abilities.model as model, channels.type as channel_type").
Joins("JOIN channels ON abilities.channel_id = channels.id").
Where("abilities.model IN ? AND abilities.enabled = ? AND channels.status = ?", modelNames, true, common.ChannelStatusEnabled).
Order("COALESCE(abilities.priority, 0) DESC").
Order("abilities.weight DESC").
Order("abilities.channel_id ASC")
groups = normalizeLookupValues(groups)
if len(groups) > 0 {
query = query.Where("abilities."+commonGroupCol+" IN ?", groups)
}
if err := query.Scan(&rows).Error; err != nil {
return nil, err
}
for _, r := range rows {
if _, ok := result[r.Model]; ok {
continue
}
result[r.Model] = r.ChannelType
}
return result, nil
}
func SearchModels(keyword string, vendor string, offset int, limit int) ([]*Model, int64, error) {
var models []*Model
db := DB.Model(&Model{})
-141
View File
@@ -1,141 +0,0 @@
package model
import (
"fmt"
"testing"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/stretchr/testify/require"
)
func clearPreferredOwnerTables(t *testing.T) {
t.Helper()
require.NoError(t, DB.Exec("DELETE FROM abilities").Error)
require.NoError(t, DB.Exec("DELETE FROM channels").Error)
}
func insertPreferredOwnerCandidate(
t *testing.T,
channelID int,
modelName string,
group string,
channelType int,
priority int64,
weight uint,
channelStatus int,
abilityEnabled bool,
) {
t.Helper()
require.NoError(t, DB.Create(&Channel{
Id: channelID,
Type: channelType,
Key: fmt.Sprintf("key-%d", channelID),
Status: channelStatus,
Name: fmt.Sprintf("channel-%d", channelID),
}).Error)
require.NoError(t, DB.Create(&Ability{
Group: group,
Model: modelName,
ChannelId: channelID,
Enabled: abilityEnabled,
Priority: &priority,
Weight: weight,
}).Error)
}
func TestGetPreferredModelOwnerChannelTypes(t *testing.T) {
const modelName = "gpt-5.4"
tests := []struct {
name string
setup func(t *testing.T)
groups []string
expected int
found bool
}{
{
name: "openai only",
setup: func(t *testing.T) {
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeOpenAI, 0, 0, common.ChannelStatusEnabled, true)
},
groups: []string{"default"},
expected: constant.ChannelTypeOpenAI,
found: true,
},
{
name: "codex only",
setup: func(t *testing.T) {
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeCodex, 0, 0, common.ChannelStatusEnabled, true)
},
groups: []string{"default"},
expected: constant.ChannelTypeCodex,
found: true,
},
{
name: "priority wins",
setup: func(t *testing.T) {
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeOpenAI, 1, 100, common.ChannelStatusEnabled, true)
insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeCodex, 2, 0, common.ChannelStatusEnabled, true)
},
groups: []string{"default"},
expected: constant.ChannelTypeCodex,
found: true,
},
{
name: "weight wins when priority is equal",
setup: func(t *testing.T) {
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeOpenAI, 1, 10, common.ChannelStatusEnabled, true)
insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeCodex, 1, 20, common.ChannelStatusEnabled, true)
},
groups: []string{"default"},
expected: constant.ChannelTypeCodex,
found: true,
},
{
name: "channel id stabilizes exact ties",
setup: func(t *testing.T) {
insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeCodex, 1, 10, common.ChannelStatusEnabled, true)
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeOpenAI, 1, 10, common.ChannelStatusEnabled, true)
},
groups: []string{"default"},
expected: constant.ChannelTypeOpenAI,
found: true,
},
{
name: "group filter excludes other groups",
setup: func(t *testing.T) {
insertPreferredOwnerCandidate(t, 1, modelName, "vip", constant.ChannelTypeCodex, 10, 100, common.ChannelStatusEnabled, true)
insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeOpenAI, 1, 0, common.ChannelStatusEnabled, true)
},
groups: []string{"default"},
expected: constant.ChannelTypeOpenAI,
found: true,
},
{
name: "disabled candidates are ignored",
setup: func(t *testing.T) {
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeCodex, 10, 100, common.ChannelStatusEnabled, false)
insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeOpenAI, 1, 0, common.ChannelStatusManuallyDisabled, true)
},
groups: []string{"default"},
found: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
clearPreferredOwnerTables(t)
tt.setup(t)
owners, err := GetPreferredModelOwnerChannelTypes([]string{modelName}, tt.groups)
require.NoError(t, err)
got, ok := owners[modelName]
require.Equal(t, tt.found, ok)
if tt.found {
require.Equal(t, tt.expected, got)
}
})
}
}
+20 -39
View File
@@ -12,7 +12,6 @@ import (
"github.com/QuantumNous/new-api/setting/performance_setting"
"github.com/QuantumNous/new-api/setting/ratio_setting"
"github.com/QuantumNous/new-api/setting/system_setting"
"gorm.io/gorm"
)
type Option struct {
@@ -107,13 +106,18 @@ func InitOptionMap() {
common.OptionMap["WaffoUnitPrice"] = strconv.FormatFloat(setting.WaffoUnitPrice, 'f', -1, 64)
common.OptionMap["WaffoMinTopUp"] = strconv.Itoa(setting.WaffoMinTopUp)
common.OptionMap["WaffoPayMethods"] = setting.WaffoPayMethods2JsonString()
common.OptionMap["WaffoPancakeEnabled"] = strconv.FormatBool(setting.WaffoPancakeEnabled)
common.OptionMap["WaffoPancakeSandbox"] = strconv.FormatBool(setting.WaffoPancakeSandbox)
common.OptionMap["WaffoPancakeMerchantID"] = setting.WaffoPancakeMerchantID
common.OptionMap["WaffoPancakePrivateKey"] = setting.WaffoPancakePrivateKey
common.OptionMap["WaffoPancakeReturnURL"] = setting.WaffoPancakeReturnURL
common.OptionMap["WaffoPancakeUnitPrice"] = strconv.FormatFloat(setting.WaffoPancakeUnitPrice, 'f', -1, 64)
common.OptionMap["WaffoPancakeMinTopUp"] = strconv.Itoa(setting.WaffoPancakeMinTopUp)
common.OptionMap["WaffoPancakeWebhookPublicKey"] = setting.WaffoPancakeWebhookPublicKey
common.OptionMap["WaffoPancakeWebhookTestKey"] = setting.WaffoPancakeWebhookTestKey
common.OptionMap["WaffoPancakeStoreID"] = setting.WaffoPancakeStoreID
common.OptionMap["WaffoPancakeProductID"] = setting.WaffoPancakeProductID
common.OptionMap["WaffoPancakeReturnURL"] = setting.WaffoPancakeReturnURL
common.OptionMap["WaffoPancakeCurrency"] = setting.WaffoPancakeCurrency
common.OptionMap["WaffoPancakeUnitPrice"] = strconv.FormatFloat(setting.WaffoPancakeUnitPrice, 'f', -1, 64)
common.OptionMap["WaffoPancakeMinTopUp"] = strconv.Itoa(setting.WaffoPancakeMinTopUp)
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
common.OptionMap["Chats"] = setting.Chats2JsonString()
common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString()
@@ -218,39 +222,6 @@ func UpdateOption(key string, value string) error {
return updateOptionMap(key, value)
}
// UpdateOptionsBulk persists multiple key/value pairs in a single database
// transaction, then dispatches them through updateOptionMap in one pass. If
// any DB write fails the whole transaction rolls back and no in-memory state
// is touched — safe for callers that must commit a set of related options
// atomically (e.g. payment gateway binding).
func UpdateOptionsBulk(values map[string]string) error {
if len(values) == 0 {
return nil
}
err := DB.Transaction(func(tx *gorm.DB) error {
for k, v := range values {
option := Option{Key: k}
if err := tx.FirstOrCreate(&option, Option{Key: k}).Error; err != nil {
return err
}
option.Value = v
if err := tx.Save(&option).Error; err != nil {
return err
}
}
return nil
})
if err != nil {
return err
}
for k, v := range values {
if err := updateOptionMap(k, v); err != nil {
return err
}
}
return nil
}
func updateOptionMap(key string, value string) (err error) {
common.OptionMapRWMutex.Lock()
defer common.OptionMapRWMutex.Unlock()
@@ -448,16 +419,26 @@ func updateOptionMap(key string, value string) (err error) {
setting.WaffoUnitPrice, _ = strconv.ParseFloat(value, 64)
case "WaffoMinTopUp":
setting.WaffoMinTopUp, _ = strconv.Atoi(value)
case "WaffoPancakeEnabled":
setting.WaffoPancakeEnabled = value == "true"
case "WaffoPancakeSandbox":
setting.WaffoPancakeSandbox = value == "true"
case "WaffoPancakeMerchantID":
setting.WaffoPancakeMerchantID = value
case "WaffoPancakePrivateKey":
setting.WaffoPancakePrivateKey = value
case "WaffoPancakeReturnURL":
setting.WaffoPancakeReturnURL = value
case "WaffoPancakeWebhookPublicKey":
setting.WaffoPancakeWebhookPublicKey = value
case "WaffoPancakeWebhookTestKey":
setting.WaffoPancakeWebhookTestKey = value
case "WaffoPancakeStoreID":
setting.WaffoPancakeStoreID = value
case "WaffoPancakeProductID":
setting.WaffoPancakeProductID = value
case "WaffoPancakeReturnURL":
setting.WaffoPancakeReturnURL = value
case "WaffoPancakeCurrency":
setting.WaffoPancakeCurrency = value
case "WaffoPancakeUnitPrice":
setting.WaffoPancakeUnitPrice, _ = strconv.ParseFloat(value, 64)
case "WaffoPancakeMinTopUp":
+2 -117
View File
@@ -11,7 +11,6 @@ import (
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/pkg/cachex"
"github.com/samber/hot"
"github.com/shopspring/decimal"
"gorm.io/gorm"
)
@@ -160,11 +159,8 @@ type SubscriptionPlan struct {
Enabled bool `json:"enabled" gorm:"default:true"`
SortOrder int `json:"sort_order" gorm:"type:int;default:0"`
AllowBalancePay *bool `json:"allow_balance_pay" gorm:"default:true"`
StripePriceId string `json:"stripe_price_id" gorm:"type:varchar(128);default:''"`
CreemProductId string `json:"creem_product_id" gorm:"type:varchar(128);default:''"`
WaffoPancakeProductId string `json:"waffo_pancake_product_id" gorm:"type:varchar(128);default:''"`
StripePriceId string `json:"stripe_price_id" gorm:"type:varchar(128);default:''"`
CreemProductId string `json:"creem_product_id" gorm:"type:varchar(128);default:''"`
// Max purchases per user (0 = unlimited)
MaxPurchasePerUser int `json:"max_purchase_per_user" gorm:"type:int;default:0"`
@@ -195,12 +191,6 @@ func (p *SubscriptionPlan) BeforeUpdate(tx *gorm.DB) error {
return nil
}
func (p *SubscriptionPlan) NormalizeDefaults() {
if p.AllowBalancePay == nil {
p.AllowBalancePay = common.GetPointer(true)
}
}
// Subscription order (payment -> webhook -> create UserSubscription)
type SubscriptionOrder struct {
Id int `json:"id"`
@@ -368,7 +358,6 @@ func getSubscriptionPlanByIdTx(tx *gorm.DB, id int) (*SubscriptionPlan, error) {
key := subscriptionPlanCacheKey(id)
if key != "" {
if cached, found, err := getSubscriptionPlanCache().Get(key); err == nil && found {
cached.NormalizeDefaults()
return &cached, nil
}
}
@@ -380,7 +369,6 @@ func getSubscriptionPlanByIdTx(tx *gorm.DB, id int) (*SubscriptionPlan, error) {
if err := query.Where("id = ?", id).First(&plan).Error; err != nil {
return nil, err
}
plan.NormalizeDefaults()
_ = getSubscriptionPlanCache().SetWithTTL(key, plan, subscriptionPlanCacheTTL())
return &plan, nil
}
@@ -676,109 +664,6 @@ func AdminBindSubscription(userId int, planId int, sourceNote string) (string, e
return "", nil
}
func calcSubscriptionBalanceQuota(priceAmount float64) (int, error) {
if priceAmount <= 0 {
return 0, nil
}
if common.QuotaPerUnit <= 0 {
return 0, errors.New("额度单位配置错误")
}
quota := decimal.NewFromFloat(priceAmount).
Mul(decimal.NewFromFloat(common.QuotaPerUnit)).
Ceil().
IntPart()
return int(quota), nil
}
// PurchaseSubscriptionWithBalance creates a subscription by deducting the user's wallet quota.
func PurchaseSubscriptionWithBalance(userId int, planId int) error {
if userId <= 0 || planId <= 0 {
return errors.New("invalid userId or planId")
}
var logPlanTitle string
var logMoney float64
var chargedQuota int
var upgradeGroup string
err := DB.Transaction(func(tx *gorm.DB) error {
plan, err := getSubscriptionPlanByIdTx(tx, planId)
if err != nil {
return err
}
if !plan.Enabled {
return errors.New("套餐未启用")
}
if plan.PriceAmount < 0 {
return errors.New("套餐价格不能为负数")
}
if plan.AllowBalancePay != nil && !*plan.AllowBalancePay {
return errors.New("该套餐不允许使用余额兑换")
}
requiredQuota, err := calcSubscriptionBalanceQuota(plan.PriceAmount)
if err != nil {
return err
}
var user User
if err := tx.Set("gorm:query_option", "FOR UPDATE").Where("id = ?", userId).First(&user).Error; err != nil {
return err
}
if requiredQuota > 0 && user.Quota < requiredQuota {
return errors.New("余额不足")
}
if requiredQuota > 0 {
if err := tx.Model(&User{}).Where("id = ?", userId).
Update("quota", gorm.Expr("quota - ?", requiredQuota)).Error; err != nil {
return err
}
}
if _, err := CreateUserSubscriptionFromPlanTx(tx, userId, plan, PaymentMethodBalance); err != nil {
return err
}
now := common.GetTimestamp()
tradeNo := fmt.Sprintf("SUBBALUSR%dNO%s%d", userId, common.GetRandomString(6), time.Now().UnixNano())
order := &SubscriptionOrder{
UserId: userId,
PlanId: plan.Id,
Money: plan.PriceAmount,
TradeNo: tradeNo,
PaymentMethod: PaymentMethodBalance,
PaymentProvider: PaymentProviderBalance,
Status: common.TopUpStatusSuccess,
CreateTime: now,
CompleteTime: now,
ProviderPayload: fmt.Sprintf("charged_quota=%d", requiredQuota),
}
if err := tx.Create(order).Error; err != nil {
return err
}
logPlanTitle = plan.Title
logMoney = plan.PriceAmount
chargedQuota = requiredQuota
upgradeGroup = strings.TrimSpace(plan.UpgradeGroup)
return nil
})
if err != nil {
return err
}
if chargedQuota > 0 {
if err := cacheDecrUserQuota(userId, int64(chargedQuota)); err != nil {
common.SysLog("failed to decrease user quota cache after subscription balance purchase: " + err.Error())
}
}
if upgradeGroup != "" {
_ = UpdateUserGroupCache(userId, upgradeGroup)
}
msg := fmt.Sprintf("使用余额购买订阅成功,套餐: %s,支付金额: %.2f,扣除额度: %d", logPlanTitle, logMoney, chargedQuota)
RecordLog(userId, LogTypeTopup, msg)
return nil
}
// GetAllActiveUserSubscriptions returns all active subscriptions for a user.
func GetAllActiveUserSubscriptions(userId int) ([]SubscriptionSummary, error) {
if userId <= 0 {
-2
View File
@@ -40,7 +40,6 @@ func TestMain(m *testing.M) {
&Token{},
&Log{},
&Channel{},
&Ability{},
&TopUp{},
&SubscriptionPlan{},
&SubscriptionOrder{},
@@ -61,7 +60,6 @@ func truncateTables(t *testing.T) {
DB.Exec("DELETE FROM tokens")
DB.Exec("DELETE FROM logs")
DB.Exec("DELETE FROM channels")
DB.Exec("DELETE FROM abilities")
DB.Exec("DELETE FROM top_ups")
DB.Exec("DELETE FROM subscription_orders")
DB.Exec("DELETE FROM subscription_plans")
-2
View File
@@ -29,7 +29,6 @@ const (
PaymentMethodCreem = "creem"
PaymentMethodWaffo = "waffo"
PaymentMethodWaffoPancake = "waffo_pancake"
PaymentMethodBalance = "balance"
)
const (
@@ -38,7 +37,6 @@ const (
PaymentProviderCreem = "creem"
PaymentProviderWaffo = "waffo"
PaymentProviderWaffoPancake = "waffo_pancake"
PaymentProviderBalance = "balance"
)
var (
+17 -31
View File
@@ -225,7 +225,7 @@ func GetAllUsers(pageInfo *common.PageInfo) (users []*User, total int64, err err
return users, total, nil
}
func SearchUsers(keyword string, group string, role *int, status *int, startIdx int, num int) ([]*User, int64, error) {
func SearchUsers(keyword string, group string, startIdx int, num int) ([]*User, int64, error) {
var users []*User
var total int64
var err error
@@ -246,25 +246,28 @@ func SearchUsers(keyword string, group string, role *int, status *int, startIdx
// 构建搜索条件
likeCondition := "username LIKE ? OR email LIKE ? OR display_name LIKE ?"
likeArgs := []interface{}{"%" + keyword + "%", "%" + keyword + "%", "%" + keyword + "%"}
// 尝试将关键字转换为整数ID
keywordInt, err := strconv.Atoi(keyword)
if err == nil {
// 如果是数字,同时搜索ID和其他字段
likeCondition = "id = ? OR " + likeCondition
likeArgs = append([]interface{}{keywordInt}, likeArgs...)
}
query = query.Where("("+likeCondition+")", likeArgs...)
if group != "" {
query = query.Where(commonGroupCol+" = ?", group)
}
if role != nil {
query = query.Where("role = ?", *role)
}
if status != nil {
query = query.Where("status = ?", *status)
if group != "" {
query = query.Where("("+likeCondition+") AND "+commonGroupCol+" = ?",
keywordInt, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
} else {
query = query.Where(likeCondition,
keywordInt, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
}
} else {
// 非数字关键字,只搜索字符串字段
if group != "" {
query = query.Where("("+likeCondition+") AND "+commonGroupCol+" = ?",
"%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
} else {
query = query.Where(likeCondition,
"%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
}
}
// 获取总数
@@ -984,23 +987,6 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
//}
}
func updateUserQuotaUsedQuotaAndRequestCount(id int, quota int, usedQuota int, requestCount int) {
if quota == 0 && usedQuota == 0 && requestCount == 0 {
return
}
err := DB.Model(&User{}).Where("id = ?", id).Updates(
map[string]interface{}{
"quota": gorm.Expr("quota + ?", quota),
"used_quota": gorm.Expr("used_quota + ?", usedQuota),
"request_count": gorm.Expr("request_count + ?", requestCount),
},
).Error
if err != nil {
common.SysLog("failed to batch update user quota, used quota and request count: " + err.Error())
}
}
func updateUserUsedQuota(id int, quota int) {
err := DB.Model(&User{}).Where("id = ?", id).Updates(
map[string]interface{}{
+11 -26
View File
@@ -67,48 +67,33 @@ func batchUpdate() {
}
common.SysLog("batch update started")
stores := make([]map[int]int, BatchUpdateTypeCount)
for i := 0; i < BatchUpdateTypeCount; i++ {
batchUpdateLocks[i].Lock()
stores[i] = batchUpdateStores[i]
store := batchUpdateStores[i]
batchUpdateStores[i] = make(map[int]int)
batchUpdateLocks[i].Unlock()
}
for i, store := range stores {
if i == BatchUpdateTypeUserQuota || i == BatchUpdateTypeUsedQuota || i == BatchUpdateTypeRequestCount {
continue
}
// TODO: maybe we can combine updates with same key?
for key, value := range store {
switch i {
case BatchUpdateTypeUserQuota:
err := increaseUserQuota(key, value)
if err != nil {
common.SysLog("failed to batch update user quota: " + err.Error())
}
case BatchUpdateTypeTokenQuota:
err := increaseTokenQuota(key, value)
if err != nil {
common.SysLog("failed to batch update token quota: " + err.Error())
}
case BatchUpdateTypeUsedQuota:
updateUserUsedQuota(key, value)
case BatchUpdateTypeRequestCount:
updateUserRequestCount(key, value)
case BatchUpdateTypeChannelUsedQuota:
updateChannelUsedQuota(key, value)
}
}
}
userQuotaStore := stores[BatchUpdateTypeUserQuota]
usedQuotaStore := stores[BatchUpdateTypeUsedQuota]
requestCountStore := stores[BatchUpdateTypeRequestCount]
userIDs := make(map[int]struct{}, len(userQuotaStore)+len(usedQuotaStore)+len(requestCountStore))
for key := range userQuotaStore {
userIDs[key] = struct{}{}
}
for key := range usedQuotaStore {
userIDs[key] = struct{}{}
}
for key := range requestCountStore {
userIDs[key] = struct{}{}
}
for key := range userIDs {
updateUserQuotaUsedQuotaAndRequestCount(key, userQuotaStore[key], usedQuotaStore[key], requestCountStore[key])
}
common.SysLog("batch update finished")
}
-20
View File
@@ -25,23 +25,6 @@ import (
"github.com/gorilla/websocket"
)
// applyUpstreamContentLength populates req.ContentLength when the upstream
// body is wrapped in a BodyStorage (see relay/common/outbound_body.go).
//
// net/http.NewRequest only auto-detects ContentLength for *bytes.Reader,
// *bytes.Buffer and *strings.Reader. When the body is a type-erased io.Reader
// (which is the case for ReaderOnly(BodyStorage)), the Content-Length header
// would otherwise be omitted, forcing chunked transfer encoding and breaking
// some upstreams that require an explicit Content-Length.
func applyUpstreamContentLength(req *http.Request, info *common.RelayInfo) {
if info == nil {
return
}
if info.UpstreamRequestBodySize > 0 && req.ContentLength <= 0 {
req.ContentLength = info.UpstreamRequestBodySize
}
}
func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Header) {
if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
// multipart/form-data
@@ -314,7 +297,6 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
if err != nil {
return nil, fmt.Errorf("new request failed: %w", err)
}
applyUpstreamContentLength(req, info)
headers := req.Header
err = a.SetupRequestHeader(c, &headers, info)
if err != nil {
@@ -344,7 +326,6 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod
if err != nil {
return nil, fmt.Errorf("new request failed: %w", err)
}
applyUpstreamContentLength(req, info)
// set form data
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
headers := req.Header
@@ -541,7 +522,6 @@ func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.RelayInfo, req
if err != nil {
return nil, fmt.Errorf("new request failed: %w", err)
}
applyUpstreamContentLength(req, info)
req.GetBody = func() (io.ReadCloser, error) {
return io.NopCloser(requestBody), nil
}
-6
View File
@@ -19,7 +19,6 @@ var awsModelIDMap = map[string]string{
"claude-opus-4-5-20251101": "anthropic.claude-opus-4-5-20251101-v1:0",
"claude-opus-4-6": "anthropic.claude-opus-4-6-v1",
"claude-opus-4-7": "anthropic.claude-opus-4-7",
"claude-opus-4-8": "anthropic.claude-opus-4-8",
// Nova models
"nova-micro-v1:0": "amazon.nova-micro-v1:0",
"nova-lite-v1:0": "amazon.nova-lite-v1:0",
@@ -98,11 +97,6 @@ var awsModelCanCrossRegionMap = map[string]map[string]bool{
"ap": true,
"eu": true,
},
"anthropic.claude-opus-4-8": {
"us": true,
"ap": true,
"eu": true,
},
"anthropic.claude-haiku-4-5-20251001-v1:0": {
"us": true,
"ap": true,
-7
View File
@@ -33,13 +33,6 @@ var ModelList = []string{
"claude-opus-4-7-medium",
"claude-opus-4-7-low",
"claude-opus-4-7-thinking",
"claude-opus-4-8",
"claude-opus-4-8-max",
"claude-opus-4-8-xhigh",
"claude-opus-4-8-high",
"claude-opus-4-8-medium",
"claude-opus-4-8-low",
"claude-opus-4-8-thinking",
}
var ChannelName = "claude"
+9 -10
View File
@@ -154,17 +154,14 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
}
if baseModel, effortLevel, ok := reasoning.TrimEffortSuffix(textRequest.Model); ok && effortLevel != "" &&
(strings.HasPrefix(textRequest.Model, "claude-opus-4-6") ||
strings.HasPrefix(textRequest.Model, "claude-opus-4-7") ||
strings.HasPrefix(textRequest.Model, "claude-opus-4-8")) {
(strings.HasPrefix(textRequest.Model, "claude-opus-4-6") || strings.HasPrefix(textRequest.Model, "claude-opus-4-7")) {
claudeRequest.Model = baseModel
claudeRequest.Thinking = &dto.Thinking{
Type: "adaptive",
}
claudeRequest.OutputConfig = json.RawMessage(fmt.Sprintf(`{"effort":"%s"}`, effortLevel))
if strings.HasPrefix(baseModel, "claude-opus-4-7") ||
strings.HasPrefix(baseModel, "claude-opus-4-8") {
// Opus 4.7/4.8 reject non-default temperature/top_p/top_k with 400
if strings.HasPrefix(baseModel, "claude-opus-4-7") {
// Opus 4.7 rejects non-default temperature/top_p/top_k with 400
// and defaults display to "omitted"; restore the 4.6 visible summary.
claudeRequest.Thinking.Display = "summarized"
claudeRequest.Temperature = nil
@@ -178,9 +175,8 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
strings.HasSuffix(textRequest.Model, "-thinking") {
trimmedModel := strings.TrimSuffix(textRequest.Model, "-thinking")
if strings.HasPrefix(trimmedModel, "claude-opus-4-7") ||
strings.HasPrefix(trimmedModel, "claude-opus-4-8") {
// Opus 4.7/4.8 reject thinking.type="enabled"; use adaptive at high effort.
if strings.HasPrefix(trimmedModel, "claude-opus-4-7") {
// Opus 4.7 rejects thinking.type="enabled"; use adaptive at high effort.
claudeRequest.Thinking = &dto.Thinking{Type: "adaptive", Display: "summarized"}
claudeRequest.OutputConfig = json.RawMessage(`{"effort":"high"}`)
claudeRequest.Temperature = nil
@@ -446,7 +442,10 @@ func StreamResponseClaude2OpenAI(claudeResponse *dto.ClaudeResponse) *dto.ChatCo
tools := make([]dto.ToolCallResponse, 0)
fcIdx := 0
if claudeResponse.Index != nil {
fcIdx = *claudeResponse.Index
fcIdx = *claudeResponse.Index - 1
if fcIdx < 0 {
fcIdx = 0
}
}
var choice dto.ChatCompletionsStreamResponseChoice
if claudeResponse.Type == "message_start" {
-56
View File
@@ -9,10 +9,6 @@ import (
"github.com/stretchr/testify/require"
)
func commonPointer[T any](value T) *T {
return &value
}
func TestFormatClaudeResponseInfo_MessageStart(t *testing.T) {
claudeInfo := &ClaudeResponseInfo{
Usage: &dto.Usage{},
@@ -314,58 +310,6 @@ func TestRequestOpenAI2ClaudeMessage_IgnoresUnsupportedFileContent(t *testing.T)
require.Equal(t, "see attachment", *content[0].Text)
}
func TestRequestOpenAI2ClaudeMessage_ClaudeOpus48HighUsesAdaptiveThinking(t *testing.T) {
request := dto.GeneralOpenAIRequest{
Model: "claude-opus-4-8-high",
Temperature: commonPointer(0.7),
TopP: commonPointer(0.9),
TopK: commonPointer(40),
Messages: []dto.Message{
{
Role: "user",
Content: "hello",
},
},
}
claudeRequest, err := RequestOpenAI2ClaudeMessage(nil, request)
require.NoError(t, err)
require.Equal(t, "claude-opus-4-8", claudeRequest.Model)
require.NotNil(t, claudeRequest.Thinking)
require.Equal(t, "adaptive", claudeRequest.Thinking.Type)
require.Equal(t, "summarized", claudeRequest.Thinking.Display)
require.JSONEq(t, `{"effort":"high"}`, string(claudeRequest.OutputConfig))
require.Nil(t, claudeRequest.Temperature)
require.Nil(t, claudeRequest.TopP)
require.Nil(t, claudeRequest.TopK)
}
func TestRequestOpenAI2ClaudeMessage_ClaudeOpus48ThinkingUsesAdaptiveHighEffort(t *testing.T) {
request := dto.GeneralOpenAIRequest{
Model: "claude-opus-4-8-thinking",
Temperature: commonPointer(0.7),
TopP: commonPointer(0.9),
TopK: commonPointer(40),
Messages: []dto.Message{
{
Role: "user",
Content: "hello",
},
},
}
claudeRequest, err := RequestOpenAI2ClaudeMessage(nil, request)
require.NoError(t, err)
require.Equal(t, "claude-opus-4-8", claudeRequest.Model)
require.NotNil(t, claudeRequest.Thinking)
require.Equal(t, "adaptive", claudeRequest.Thinking.Type)
require.Equal(t, "summarized", claudeRequest.Thinking.Display)
require.JSONEq(t, `{"effort":"high"}`, string(claudeRequest.OutputConfig))
require.Nil(t, claudeRequest.Temperature)
require.Nil(t, claudeRequest.TopP)
require.Nil(t, claudeRequest.TopK)
}
func TestRequestOpenAI2ClaudeMessage_SupportsPDFFileContent(t *testing.T) {
request := dto.GeneralOpenAIRequest{
Model: "claude-3-5-sonnet",
+1 -1
View File
@@ -30,7 +30,7 @@ func convertCf2CompletionsRequest(textRequest dto.GeneralOpenAIRequest) *CfReque
}
func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
scanner := helper.NewStreamScanner(resp.Body)
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
helper.SetEventStreamHeaders(c)
+2 -4
View File
@@ -1,6 +1,7 @@
package cohere
import (
"bufio"
"encoding/json"
"io"
"net/http"
@@ -85,7 +86,7 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
createdTime := common.GetTimestamp()
usage := &dto.Usage{}
responseText := ""
scanner := helper.NewStreamScanner(resp.Body)
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
@@ -105,9 +106,6 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
data := scanner.Text()
dataChan <- data
}
if err := scanner.Err(); err != nil {
common.SysLog("error reading stream: " + err.Error())
}
stopChan <- true
}()
helper.SetEventStreamHeaders(c)
+1 -1
View File
@@ -98,7 +98,7 @@ func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res
}
func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
scanner := helper.NewStreamScanner(resp.Body)
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
helper.SetEventStreamHeaders(c)
id := helper.GetResponseID(c)
+3 -8
View File
@@ -159,14 +159,9 @@ func requestOpenAI2Dify(c *gin.Context, info *relaycommon.RelayInfo, request dto
media := mediaContent.GetImageMedia()
var file *DifyFile
if media.IsRemoteImage() {
// 修复 #2083: 远程图片分支此前未初始化 file,
// 导致 file.Type = ... 触发 nil pointer dereference
// 而 panic500: "invalid memory address or nil pointer dereference")。
file = &DifyFile{
Type: media.MimeType,
TransferMode: "remote_url",
URL: media.Url,
}
file.Type = media.MimeType
file.TransferMode = "remote_url"
file.URL = media.Url
} else {
file = uploadDifyFile(c, info, difyReq.User, mediaContent)
}
+18 -103
View File
@@ -1079,47 +1079,17 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse)
FinishReason: constant.FinishReasonStop,
}
if len(candidate.Content.Parts) > 0 {
// 使用 strings.Builder 直接累积最终 content,避免:
// 1) 每张 inline image 生成一次中间 "![image](...)" 字符串
// 2) 末尾 strings.Join 再分配一份等大缓冲
// Gemini 图片返回时 InlineData.Data 可能是数 MB 的 base64
// 上述两份临时分配在高并发下会显著放大堆驻留。
var content strings.Builder
var inlineGrow int
for _, part := range candidate.Content.Parts {
if part.InlineData != nil {
inlineGrow += len(part.InlineData.MimeType) + len(part.InlineData.Data) + 32
}
}
if inlineGrow > 0 {
content.Grow(inlineGrow)
}
appended := 0
writeSep := func() {
if appended > 0 {
content.WriteByte('\n')
}
appended++
}
var texts []string
var toolCalls []dto.ToolCallResponse
for _, part := range candidate.Content.Parts {
if part.InlineData != nil {
// 媒体内容
if strings.HasPrefix(part.InlineData.MimeType, "image") {
writeSep()
content.WriteString("![image](data:")
content.WriteString(part.InlineData.MimeType)
content.WriteString(";base64,")
content.WriteString(part.InlineData.Data)
content.WriteByte(')')
imgText := "![image](data:" + part.InlineData.MimeType + ";base64," + part.InlineData.Data + ")"
texts = append(texts, imgText)
} else {
// 其他媒体类型,直接显示链接
writeSep()
content.WriteString("[media](data:")
content.WriteString(part.InlineData.MimeType)
content.WriteString(";base64,")
content.WriteString(part.InlineData.Data)
content.WriteByte(')')
texts = append(texts, fmt.Sprintf("[media](data:%s;base64,%s)", part.InlineData.MimeType, part.InlineData.Data))
}
} else if part.FunctionCall != nil {
choice.FinishReason = constant.FinishReasonToolCalls
@@ -1130,22 +1100,13 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse)
choice.Message.ReasoningContent = &part.Text
} else {
if part.ExecutableCode != nil {
writeSep()
content.WriteString("```")
content.WriteString(part.ExecutableCode.Language)
content.WriteByte('\n')
content.WriteString(part.ExecutableCode.Code)
content.WriteString("\n```")
texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```")
} else if part.CodeExecutionResult != nil {
writeSep()
content.WriteString("```output\n")
content.WriteString(part.CodeExecutionResult.Output)
content.WriteString("\n```")
texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```")
} else {
// 过滤掉空行
if part.Text != "\n" {
writeSep()
content.WriteString(part.Text)
texts = append(texts, part.Text)
}
}
}
@@ -1154,7 +1115,7 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse)
choice.Message.SetToolCalls(toolCalls)
isToolCall = true
}
choice.Message.SetStringContent(content.String())
choice.Message.SetStringContent(strings.Join(texts, "\n"))
}
if candidate.FinishReason != nil {
@@ -1208,25 +1169,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*d
//Role: "assistant",
},
}
// 使用 strings.Builder 直接累积 delta content,避免每张 image / 每个
// 文本片段都先 `+` 拼出一份临时 string,再 strings.Join 再拷贝一遍。
var content strings.Builder
var inlineGrow int
for _, part := range candidate.Content.Parts {
if part.InlineData != nil {
inlineGrow += len(part.InlineData.MimeType) + len(part.InlineData.Data) + 32
}
}
if inlineGrow > 0 {
content.Grow(inlineGrow)
}
appended := 0
writeSep := func() {
if appended > 0 {
content.WriteByte('\n')
}
appended++
}
var texts []string
isTools := false
isThought := false
if candidate.FinishReason != nil {
@@ -1264,12 +1207,8 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*d
for _, part := range candidate.Content.Parts {
if part.InlineData != nil {
if strings.HasPrefix(part.InlineData.MimeType, "image") {
writeSep()
content.WriteString("![image](data:")
content.WriteString(part.InlineData.MimeType)
content.WriteString(";base64,")
content.WriteString(part.InlineData.Data)
content.WriteByte(')')
imgText := "![image](data:" + part.InlineData.MimeType + ";base64," + part.InlineData.Data + ")"
texts = append(texts, imgText)
}
} else if part.FunctionCall != nil {
isTools = true
@@ -1280,33 +1219,23 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*d
} else if part.Thought {
isThought = true
writeSep()
content.WriteString(part.Text)
texts = append(texts, part.Text)
} else {
if part.ExecutableCode != nil {
writeSep()
content.WriteString("```")
content.WriteString(part.ExecutableCode.Language)
content.WriteByte('\n')
content.WriteString(part.ExecutableCode.Code)
content.WriteString("\n```\n")
texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```\n")
} else if part.CodeExecutionResult != nil {
writeSep()
content.WriteString("```output\n")
content.WriteString(part.CodeExecutionResult.Output)
content.WriteString("\n```\n")
texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```\n")
} else {
if part.Text != "\n" {
writeSep()
content.WriteString(part.Text)
texts = append(texts, part.Text)
}
}
}
}
if isThought {
choice.Delta.SetReasoningContent(content.String())
choice.Delta.SetReasoningContent(strings.Join(texts, "\n"))
} else {
choice.Delta.SetContentString(content.String())
choice.Delta.SetContentString(strings.Join(texts, "\n"))
}
if isTools {
choice.FinishReason = &constant.FinishReasonToolCalls
@@ -1410,14 +1339,6 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
response.Id = id
response.Created = createAt
response.Model = info.UpstreamModelName
if response.IsToolCall() {
finishReason = constant.FinishReasonToolCalls
if info.RelayFormat == types.RelayFormatClaude {
for choiceIdx := range response.Choices {
response.Choices[choiceIdx].FinishReason = nil
}
}
}
for choiceIdx := range response.Choices {
choiceKey := response.Choices[choiceIdx].Index
for toolIdx := range response.Choices[choiceIdx].Delta.ToolCalls {
@@ -1478,9 +1399,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
logger.LogError(c, err.Error())
}
if isStop {
if info.RelayFormat != types.RelayFormatClaude {
_ = handleStream(c, info, helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, finishReason))
}
_ = handleStream(c, info, helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, finishReason))
}
return true
})
@@ -1490,10 +1409,6 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
}
response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
if info.RelayFormat == types.RelayFormatClaude && info.ClaudeConvertInfo != nil && !info.ClaudeConvertInfo.Done {
response = helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, finishReason)
response.Usage = usage
}
handleErr := handleFinalStream(c, info, response)
if handleErr != nil {
common.SysLog("send final response failed: " + handleErr.Error())
-16
View File
@@ -5,9 +5,7 @@ import (
"fmt"
"io"
"net/http"
"strings"
"github.com/QuantumNous/new-api/common"
channelconstant "github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/relay/channel"
@@ -81,23 +79,9 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
}
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request.Temperature != nil && isTemperatureOneOnlyModel(getUpstreamModelName(info, request.Model)) && *request.Temperature != 1.0 {
request.Temperature = common.GetPointer[float64](1.0)
}
return request, nil
}
func getUpstreamModelName(info *relaycommon.RelayInfo, fallback string) string {
if info != nil && info.ChannelMeta != nil && info.UpstreamModelName != "" {
return info.UpstreamModelName
}
return fallback
}
func isTemperatureOneOnlyModel(model string) bool {
return strings.EqualFold(model, "kimi-k2.6")
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
-68
View File
@@ -1,68 +0,0 @@
package moonshot
import (
"testing"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/dto"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/stretchr/testify/require"
)
func TestConvertOpenAIRequestKimiK26UsesOnlyAllowedTemperature(t *testing.T) {
request := &dto.GeneralOpenAIRequest{
Model: "kimi-k2.6",
Temperature: common.GetPointer[float64](0.7),
}
info := &relaycommon.RelayInfo{
ChannelMeta: &relaycommon.ChannelMeta{
UpstreamModelName: "kimi-k2.6",
},
}
converted, err := (&Adaptor{}).ConvertOpenAIRequest(nil, info, request)
require.NoError(t, err)
convertedRequest, ok := converted.(*dto.GeneralOpenAIRequest)
require.True(t, ok)
require.NotNil(t, convertedRequest.Temperature)
require.Equal(t, 1.0, *convertedRequest.Temperature)
}
func TestConvertOpenAIRequestKimiK26KeepsOmittedTemperatureOmitted(t *testing.T) {
request := &dto.GeneralOpenAIRequest{
Model: "kimi-k2.6",
}
info := &relaycommon.RelayInfo{
ChannelMeta: &relaycommon.ChannelMeta{
UpstreamModelName: "kimi-k2.6",
},
}
converted, err := (&Adaptor{}).ConvertOpenAIRequest(nil, info, request)
require.NoError(t, err)
convertedRequest, ok := converted.(*dto.GeneralOpenAIRequest)
require.True(t, ok)
require.Nil(t, convertedRequest.Temperature)
}
func TestConvertOpenAIRequestOtherMoonshotModelKeepsTemperature(t *testing.T) {
request := &dto.GeneralOpenAIRequest{
Model: "kimi-k2.5",
Temperature: common.GetPointer[float64](0.7),
}
info := &relaycommon.RelayInfo{
ChannelMeta: &relaycommon.ChannelMeta{
UpstreamModelName: "kimi-k2.5",
},
}
converted, err := (&Adaptor{}).ConvertOpenAIRequest(nil, info, request)
require.NoError(t, err)
convertedRequest, ok := converted.(*dto.GeneralOpenAIRequest)
require.True(t, ok)
require.NotNil(t, convertedRequest.Temperature)
require.Equal(t, 0.7, *convertedRequest.Temperature)
}
+2 -2
View File
@@ -1,6 +1,7 @@
package ollama
import (
"bufio"
"encoding/json"
"fmt"
"io"
@@ -11,7 +12,6 @@ import (
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/dto"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/relay/helper"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/types"
@@ -397,7 +397,7 @@ func PullOllamaModelStream(baseURL, apiKey, modelName string, progressCallback f
}
// 读取流式响应
scanner := helper.NewStreamScanner(response.Body)
scanner := bufio.NewScanner(response.Body)
successful := false
for scanner.Scan() {
line := scanner.Text()
+2 -1
View File
@@ -1,6 +1,7 @@
package ollama
import (
"bufio"
"encoding/json"
"fmt"
"io"
@@ -69,7 +70,7 @@ func ollamaStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
defer service.CloseResponseBodyGracefully(resp)
helper.SetEventStreamHeaders(c)
scanner := helper.NewStreamScanner(resp.Body)
scanner := bufio.NewScanner(resp.Body)
usage := &dto.Usage{}
var model = info.UpstreamModelName
var responseId = common.GetUUID()
+7 -17
View File
@@ -9,7 +9,6 @@ import (
"mime/multipart"
"net/http"
"net/textproto"
"net/url"
"path/filepath"
"strings"
@@ -311,20 +310,18 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
}
}
isOModel := dto.IsOpenAIReasoningOModel(info.UpstreamModelName)
isGPT5Model := dto.IsOpenAIGPT5Model(info.UpstreamModelName)
if isOModel || isGPT5Model {
if strings.HasPrefix(info.UpstreamModelName, "o") || strings.HasPrefix(info.UpstreamModelName, "gpt-5") {
if lo.FromPtrOr(request.MaxCompletionTokens, uint(0)) == 0 && lo.FromPtrOr(request.MaxTokens, uint(0)) != 0 {
request.MaxCompletionTokens = request.MaxTokens
request.MaxTokens = nil
}
if isOModel {
if strings.HasPrefix(info.UpstreamModelName, "o") {
request.Temperature = nil
}
// gpt-5系列模型适配 归零不再支持的参数
if isGPT5Model {
if strings.HasPrefix(info.UpstreamModelName, "gpt-5") {
request.Temperature = nil
request.TopP = nil
request.LogProbs = nil
@@ -440,13 +437,10 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
// 使用已解析的 multipart 表单,避免重复解析
mf := c.Request.MultipartForm
if mf == nil {
form, err := common.ParseMultipartFormReusable(c)
if err != nil {
return nil, fmt.Errorf("failed to parse multipart form: %w", err)
if _, err := c.MultipartForm(); err != nil {
return nil, errors.New("failed to parse multipart form")
}
c.Request.MultipartForm = form
c.Request.PostForm = url.Values(form.Value)
mf = form
mf = c.Request.MultipartForm
}
// 写入所有非文件字段
@@ -629,11 +623,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
case relayconstant.RelayModeAudioTranscription:
err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
if info.IsStream {
usage, err = OpenaiImageStreamHandler(c, info, resp)
} else {
usage, err = OpenaiImageHandler(c, info, resp)
}
usage, err = OpenaiHandlerWithUsage(c, info, resp)
case relayconstant.RelayModeRerank:
usage, err = common_handler.RerankHandler(c, info, resp)
case relayconstant.RelayModeResponses:
+65 -14
View File
@@ -1,6 +1,7 @@
package openai
import (
"encoding/json"
"strings"
"github.com/QuantumNous/new-api/common"
@@ -91,28 +92,78 @@ func ProcessStreamResponse(streamResponse dto.ChatCompletionsStreamResponse, res
return nil
}
func processTokenData(relayMode int, data string, responseTextBuilder *strings.Builder, toolCount *int) error {
func processTokens(relayMode int, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
streamResp := "[" + strings.Join(streamItems, ",") + "]"
switch relayMode {
case relayconstant.RelayModeChatCompletions:
var streamResponse dto.ChatCompletionsStreamResponse
if err := common.UnmarshalJsonStr(data, &streamResponse); err != nil {
return err
}
return ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount)
return processChatCompletions(streamResp, streamItems, responseTextBuilder, toolCount)
case relayconstant.RelayModeCompletions:
var streamResponse dto.CompletionsStreamResponse
if err := common.UnmarshalJsonStr(data, &streamResponse); err != nil {
return err
}
processCompletionsStreamResponse(streamResponse, responseTextBuilder)
return processCompletions(streamResp, streamItems, responseTextBuilder)
}
return nil
}
func processCompletionsStreamResponse(streamResponse dto.CompletionsStreamResponse, responseTextBuilder *strings.Builder) {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
func processChatCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
var streamResponses []dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
// 一次性解析失败,逐个解析
common.SysLog("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
return err
}
if err := ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount); err != nil {
common.SysLog("error processing stream response: " + err.Error())
}
}
return nil
}
// 批量处理所有响应
for _, streamResponse := range streamResponses {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > *toolCount {
*toolCount = len(choice.Delta.ToolCalls)
}
for _, tool := range choice.Delta.ToolCalls {
responseTextBuilder.WriteString(tool.Function.Name)
responseTextBuilder.WriteString(tool.Function.Arguments)
}
}
}
}
return nil
}
func processCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder) error {
var streamResponses []dto.CompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
// 一次性解析失败,逐个解析
common.SysLog("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.CompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
continue
}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
}
}
return nil
}
// 批量处理所有响应
for _, streamResponse := range streamResponses {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
}
}
return nil
}
func handleLastResponse(lastStreamData string, responseId *string, createAt *int64,
-98
View File
@@ -1,98 +0,0 @@
package openai
import (
"bytes"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"testing"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/dto"
relaycommon "github.com/QuantumNous/new-api/relay/common"
relayconstant "github.com/QuantumNous/new-api/relay/constant"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// TestConvertImageEditRequestMultipart verifies that ConvertImageRequest
// re-serializes multipart image edit requests with all fields (including
// stream) and the file intact, both when the form was already parsed and when
// it must be re-parsed from the reusable body.
func TestConvertImageEditRequestMultipart(t *testing.T) {
gin.SetMode(gin.TestMode)
newMultipartContext := func(t *testing.T, prompt string) *gin.Context {
var body bytes.Buffer
writer := multipart.NewWriter(&body)
require.NoError(t, writer.WriteField("model", "gpt-image-1"))
require.NoError(t, writer.WriteField("prompt", prompt))
require.NoError(t, writer.WriteField("stream", "true"))
require.NoError(t, writer.WriteField("partial_images", "3"))
part, err := writer.CreateFormFile("image", "input.png")
require.NoError(t, err)
_, err = part.Write([]byte("fake image"))
require.NoError(t, err)
require.NoError(t, writer.Close())
c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/edits", &body)
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
return c
}
convertAndReplay := func(t *testing.T, c *gin.Context, prompt string) {
info := &relaycommon.RelayInfo{
RelayMode: relayconstant.RelayModeImagesEdits,
}
request := dto.ImageRequest{
Model: "gpt-image-1",
Prompt: prompt,
Stream: common.GetPointer(true),
}
converted, err := (&Adaptor{}).ConvertImageRequest(c, info, request)
require.NoError(t, err)
convertedBody, ok := converted.(*bytes.Buffer)
require.True(t, ok)
replayedRequest := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(convertedBody.Bytes()))
replayedRequest.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
require.NoError(t, replayedRequest.ParseMultipartForm(32<<20))
require.Equal(t, "gpt-image-1", replayedRequest.PostForm.Get("model"))
require.Equal(t, prompt, replayedRequest.PostForm.Get("prompt"))
require.Equal(t, "true", replayedRequest.PostForm.Get("stream"))
require.Equal(t, "3", replayedRequest.PostForm.Get("partial_images"))
require.Len(t, replayedRequest.MultipartForm.File["image"], 1)
file, err := replayedRequest.MultipartForm.File["image"][0].Open()
require.NoError(t, err)
defer file.Close()
fileBytes, err := io.ReadAll(file)
require.NoError(t, err)
require.Equal(t, []byte("fake image"), fileBytes)
}
t.Run("with pre-parsed form", func(t *testing.T) {
prompt := "edit this image"
c := newMultipartContext(t, prompt)
require.NoError(t, c.Request.ParseMultipartForm(32<<20))
convertAndReplay(t, c, prompt)
})
t.Run("re-parses reusable body when form is missing", func(t *testing.T) {
prompt := "edit without pre-parsed form"
c := newMultipartContext(t, prompt)
storage, err := common.GetBodyStorage(c)
require.NoError(t, err)
c.Request.Body = io.NopCloser(storage)
c.Request.MultipartForm = nil
c.Request.PostForm = nil
convertAndReplay(t, c, prompt)
})
}
-173
View File
@@ -1,173 +0,0 @@
package openai
import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/QuantumNous/new-api/constant"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func newImageTestContext(t *testing.T, body, contentType string, isStream bool) (*gin.Context, *httptest.ResponseRecorder, *http.Response, *relaycommon.RelayInfo) {
t.Helper()
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(body)),
Header: http.Header{"Content-Type": []string{contentType}},
}
info := &relaycommon.RelayInfo{
ChannelMeta: &relaycommon.ChannelMeta{},
IsStream: isStream,
}
return c, recorder, resp, info
}
// TestOpenaiImageStreamHandlerForwardsSSEAndUsage covers the core SSE path:
// chunks are forwarded with rebuilt event lines, usage is extracted and
// normalized (input_tokens -> prompt_tokens with details), and [DONE] is
// re-emitted to the client.
func TestOpenaiImageStreamHandlerForwardsSSEAndUsage(t *testing.T) {
oldMode := gin.Mode()
gin.SetMode(gin.TestMode)
t.Cleanup(func() { gin.SetMode(oldMode) })
oldTimeout := constant.StreamingTimeout
constant.StreamingTimeout = 30
t.Cleanup(func() { constant.StreamingTimeout = oldTimeout })
body := strings.Join([]string{
`event: image_generation.partial_image`,
`data: {"type":"image_generation.partial_image","b64_json":"partial"}`,
``,
`data: {"usage":{"input_tokens":3,"output_tokens":4,"total_tokens":7,"input_tokens_details":{"image_tokens":2,"text_tokens":1}}}`,
``,
`data: [DONE]`,
``,
}, "\n")
c, recorder, resp, info := newImageTestContext(t, body, "text/event-stream", true)
usage, err := OpenaiImageStreamHandler(c, info, resp)
require.Nil(t, err)
require.Equal(t, 3, usage.PromptTokens)
require.Equal(t, 4, usage.CompletionTokens)
require.Equal(t, 7, usage.TotalTokens)
require.Equal(t, 2, usage.PromptTokensDetails.ImageTokens)
require.Equal(t, 1, usage.PromptTokensDetails.TextTokens)
require.Contains(t, recorder.Body.String(), `event: image_generation.partial_image`)
require.Contains(t, recorder.Body.String(), `data: {"type":"image_generation.partial_image","b64_json":"partial"}`)
require.Contains(t, recorder.Body.String(), `data: {"usage":{"input_tokens":3,"output_tokens":4,"total_tokens":7,"input_tokens_details":{"image_tokens":2,"text_tokens":1}}}`)
require.Contains(t, recorder.Body.String(), `data: [DONE]`)
require.Equal(t, "text/event-stream", recorder.Header().Get("Content-Type"))
}
// TestOpenaiImageStreamHandlerWrapsJSONResponse covers the non-SSE fallback:
// a JSON upstream response is wrapped into pseudo-SSE completed events.
func TestOpenaiImageStreamHandlerWrapsJSONResponse(t *testing.T) {
oldMode := gin.Mode()
gin.SetMode(gin.TestMode)
t.Cleanup(func() { gin.SetMode(oldMode) })
body := `{"created":1710000000,"data":[{"b64_json":"final","revised_prompt":"draw a cat"}],"usage":{"input_tokens":3,"output_tokens":4,"total_tokens":7,"input_tokens_details":{"image_tokens":2,"text_tokens":1}}}`
c, recorder, resp, info := newImageTestContext(t, body, "application/json", true)
usage, err := OpenaiImageStreamHandler(c, info, resp)
require.Nil(t, err)
require.Equal(t, 3, usage.PromptTokens)
require.Equal(t, 4, usage.CompletionTokens)
require.Equal(t, 7, usage.TotalTokens)
require.Equal(t, 2, usage.PromptTokensDetails.ImageTokens)
require.Equal(t, 1, usage.PromptTokensDetails.TextTokens)
require.Equal(t, "text/event-stream", recorder.Header().Get("Content-Type"))
require.Empty(t, recorder.Header().Get("Content-Length"))
require.Contains(t, recorder.Body.String(), `event: image_generation.completed`)
require.Contains(t, recorder.Body.String(), `"type":"image_generation.completed"`)
require.Contains(t, recorder.Body.String(), `"b64_json":"final"`)
require.Contains(t, recorder.Body.String(), `"revised_prompt":"draw a cat"`)
require.Contains(t, recorder.Body.String(), `data: [DONE]`)
}
// TestOpenaiImageHandlersReturnJSONError covers JSON error responses for both
// entry points: the non-streaming handler and the stream handler's non-SSE
// fallback. Neither must leak the error body to the client.
func TestOpenaiImageHandlersReturnJSONError(t *testing.T) {
oldMode := gin.Mode()
gin.SetMode(gin.TestMode)
t.Cleanup(func() { gin.SetMode(oldMode) })
body := `{"error":{"message":"content moderation failed","type":"upstream_error","code":"content_moderation_failed","status":502}}`
t.Run("non-streaming handler", func(t *testing.T) {
c, recorder, resp, info := newImageTestContext(t, body, "application/json", false)
usage, err := OpenaiImageHandler(c, info, resp)
require.Nil(t, usage)
require.NotNil(t, err)
require.Equal(t, http.StatusOK, err.StatusCode)
oaiError := err.ToOpenAIError()
require.Equal(t, "content moderation failed", oaiError.Message)
require.Equal(t, "upstream_error", oaiError.Type)
require.Equal(t, "content_moderation_failed", oaiError.Code)
require.Empty(t, recorder.Body.String())
})
t.Run("stream handler JSON fallback", func(t *testing.T) {
c, recorder, resp, info := newImageTestContext(t, body, "application/json", true)
usage, err := OpenaiImageStreamHandler(c, info, resp)
require.Nil(t, usage)
require.NotNil(t, err)
require.Equal(t, http.StatusOK, err.StatusCode)
require.Equal(t, "content moderation failed", err.ToOpenAIError().Message)
require.Empty(t, recorder.Body.String())
})
}
// TestOpenaiImageStreamHandlerRecordsUpstreamErrorEvent verifies that an error
// event inside the SSE stream is recorded as a soft error while the payload is
// still forwarded to the client.
func TestOpenaiImageStreamHandlerRecordsUpstreamErrorEvent(t *testing.T) {
oldMode := gin.Mode()
gin.SetMode(gin.TestMode)
t.Cleanup(func() { gin.SetMode(oldMode) })
oldTimeout := constant.StreamingTimeout
constant.StreamingTimeout = 30
t.Cleanup(func() { constant.StreamingTimeout = oldTimeout })
body := strings.Join([]string{
`event: image_generation.partial_image`,
`data: {"type":"image_generation.partial_image","b64_json":"partial"}`,
``,
`event: error`,
`data: {"type":"upstream_error","error":{"message":"stream error: stream ID 77; INTERNAL_ERROR; received from peer"}}`,
``,
}, "\n")
c, recorder, resp, info := newImageTestContext(t, body, "text/event-stream", true)
usage, err := OpenaiImageStreamHandler(c, info, resp)
require.Nil(t, err)
require.NotNil(t, usage)
require.NotNil(t, info.StreamStatus)
require.Equal(t, relaycommon.StreamEndReasonEOF, info.StreamStatus.EndReason)
require.True(t, info.StreamStatus.HasErrors())
require.Equal(t, 1, info.StreamStatus.TotalErrorCount())
require.Contains(t, info.StreamStatus.Errors[0].Message, "INTERNAL_ERROR")
// The scanner strips the upstream "event: error" line; the event name is
// rebuilt from the JSON "type" field (upstream_error). The error message
// is still forwarded in the data: payload (stream ID 77).
require.Contains(t, recorder.Body.String(), `event: upstream_error`)
require.Contains(t, recorder.Body.String(), `stream ID 77`)
}
+428 -4
View File
@@ -14,9 +14,12 @@ import (
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/relay/helper"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/types"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
)
func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
@@ -116,6 +119,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
var responseTextBuilder strings.Builder
var toolCount int
var usage = &dto.Usage{}
var streamItems []string // store stream items
var lastStreamData string
var secondLastStreamData string // 存储倒数第二个stream data,用于音频模型
@@ -136,10 +140,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
}
lastStreamData = data
if err := processTokenData(info.RelayMode, data, &responseTextBuilder, &toolCount); err != nil {
logger.LogError(c, "error processing stream token data: "+err.Error())
sr.Error(err)
}
streamItems = append(streamItems, data)
}
})
@@ -174,6 +175,11 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
}
}
// 处理token计算
if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
logger.LogError(c, "error processing tokens: "+err.Error())
}
if !containStreamUsage {
usage = service.ResponseText2Usage(c, responseTextBuilder.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
usage.CompletionTokens += toolCount * 7
@@ -290,3 +296,421 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
return &simpleResponse.Usage, nil
}
func streamTTSResponse(c *gin.Context, resp *http.Response) {
c.Writer.WriteHeaderNow()
flusher, ok := c.Writer.(http.Flusher)
if !ok {
logger.LogWarn(c, "streaming not supported")
_, err := io.Copy(c.Writer, resp.Body)
if err != nil {
logger.LogWarn(c, err.Error())
}
return
}
buffer := make([]byte, 4096)
for {
n, err := resp.Body.Read(buffer)
//logger.LogInfo(c, fmt.Sprintf("streamTTSResponse read %d bytes", n))
if n > 0 {
if _, writeErr := c.Writer.Write(buffer[:n]); writeErr != nil {
logger.LogError(c, writeErr.Error())
break
}
flusher.Flush()
}
if err != nil {
if err != io.EOF {
logger.LogError(c, err.Error())
}
break
}
}
}
func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.RealtimeUsage) {
if info == nil || info.ClientWs == nil || info.TargetWs == nil {
return types.NewError(fmt.Errorf("invalid websocket connection"), types.ErrorCodeBadResponse), nil
}
info.IsStream = true
clientConn := info.ClientWs
targetConn := info.TargetWs
clientClosed := make(chan struct{})
targetClosed := make(chan struct{})
sendChan := make(chan []byte, 100)
receiveChan := make(chan []byte, 100)
errChan := make(chan error, 2)
usage := &dto.RealtimeUsage{}
localUsage := &dto.RealtimeUsage{}
sumUsage := &dto.RealtimeUsage{}
gopool.Go(func() {
defer func() {
if r := recover(); r != nil {
errChan <- fmt.Errorf("panic in client reader: %v", r)
}
}()
for {
select {
case <-c.Done():
return
default:
_, message, err := clientConn.ReadMessage()
if err != nil {
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
errChan <- fmt.Errorf("error reading from client: %v", err)
}
close(clientClosed)
return
}
realtimeEvent := &dto.RealtimeEvent{}
err = common.Unmarshal(message, realtimeEvent)
if err != nil {
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
return
}
if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdate {
if realtimeEvent.Session != nil {
if realtimeEvent.Session.Tools != nil {
info.RealtimeTools = realtimeEvent.Session.Tools
}
}
}
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
if err != nil {
errChan <- fmt.Errorf("error counting text token: %v", err)
return
}
logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
localUsage.TotalTokens += textToken + audioToken
localUsage.InputTokens += textToken + audioToken
localUsage.InputTokenDetails.TextTokens += textToken
localUsage.InputTokenDetails.AudioTokens += audioToken
err = helper.WssString(c, targetConn, string(message))
if err != nil {
errChan <- fmt.Errorf("error writing to target: %v", err)
return
}
select {
case sendChan <- message:
default:
}
}
}
})
gopool.Go(func() {
defer func() {
if r := recover(); r != nil {
errChan <- fmt.Errorf("panic in target reader: %v", r)
}
}()
for {
select {
case <-c.Done():
return
default:
_, message, err := targetConn.ReadMessage()
if err != nil {
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
errChan <- fmt.Errorf("error reading from target: %v", err)
}
close(targetClosed)
return
}
info.SetFirstResponseTime()
realtimeEvent := &dto.RealtimeEvent{}
err = common.Unmarshal(message, realtimeEvent)
if err != nil {
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
return
}
if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone {
realtimeUsage := realtimeEvent.Response.Usage
if realtimeUsage != nil {
usage.TotalTokens += realtimeUsage.TotalTokens
usage.InputTokens += realtimeUsage.InputTokens
usage.OutputTokens += realtimeUsage.OutputTokens
usage.InputTokenDetails.AudioTokens += realtimeUsage.InputTokenDetails.AudioTokens
usage.InputTokenDetails.CachedTokens += realtimeUsage.InputTokenDetails.CachedTokens
usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens
usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens
usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens
err := preConsumeUsage(c, info, usage, sumUsage)
if err != nil {
errChan <- fmt.Errorf("error consume usage: %v", err)
return
}
// 本次计费完成,清除
usage = &dto.RealtimeUsage{}
localUsage = &dto.RealtimeUsage{}
} else {
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
if err != nil {
errChan <- fmt.Errorf("error counting text token: %v", err)
return
}
logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
localUsage.TotalTokens += textToken + audioToken
info.IsFirstRequest = false
localUsage.InputTokens += textToken + audioToken
localUsage.InputTokenDetails.TextTokens += textToken
localUsage.InputTokenDetails.AudioTokens += audioToken
err = preConsumeUsage(c, info, localUsage, sumUsage)
if err != nil {
errChan <- fmt.Errorf("error consume usage: %v", err)
return
}
// 本次计费完成,清除
localUsage = &dto.RealtimeUsage{}
// print now usage
}
logger.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
} else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
realtimeSession := realtimeEvent.Session
if realtimeSession != nil {
// update audio format
info.InputAudioFormat = common.GetStringIfEmpty(realtimeSession.InputAudioFormat, info.InputAudioFormat)
info.OutputAudioFormat = common.GetStringIfEmpty(realtimeSession.OutputAudioFormat, info.OutputAudioFormat)
}
} else {
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
if err != nil {
errChan <- fmt.Errorf("error counting text token: %v", err)
return
}
logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
localUsage.TotalTokens += textToken + audioToken
localUsage.OutputTokens += textToken + audioToken
localUsage.OutputTokenDetails.TextTokens += textToken
localUsage.OutputTokenDetails.AudioTokens += audioToken
}
err = helper.WssString(c, clientConn, string(message))
if err != nil {
errChan <- fmt.Errorf("error writing to client: %v", err)
return
}
select {
case receiveChan <- message:
default:
}
}
}
})
select {
case <-clientClosed:
case <-targetClosed:
case err := <-errChan:
//return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
logger.LogError(c, "realtime error: "+err.Error())
case <-c.Done():
}
if usage.TotalTokens != 0 {
_ = preConsumeUsage(c, info, usage, sumUsage)
}
if localUsage.TotalTokens != 0 {
_ = preConsumeUsage(c, info, localUsage, sumUsage)
}
// check usage total tokens, if 0, use local usage
return nil, sumUsage
}
func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.RealtimeUsage, totalUsage *dto.RealtimeUsage) error {
if usage == nil || totalUsage == nil {
return fmt.Errorf("invalid usage pointer")
}
totalUsage.TotalTokens += usage.TotalTokens
totalUsage.InputTokens += usage.InputTokens
totalUsage.OutputTokens += usage.OutputTokens
totalUsage.InputTokenDetails.CachedTokens += usage.InputTokenDetails.CachedTokens
totalUsage.InputTokenDetails.TextTokens += usage.InputTokenDetails.TextTokens
totalUsage.InputTokenDetails.AudioTokens += usage.InputTokenDetails.AudioTokens
totalUsage.OutputTokenDetails.TextTokens += usage.OutputTokenDetails.TextTokens
totalUsage.OutputTokenDetails.AudioTokens += usage.OutputTokenDetails.AudioTokens
// clear usage
err := service.PreWssConsumeQuota(ctx, info, usage)
return err
}
func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
defer service.CloseResponseBodyGracefully(resp)
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
var usageResp dto.SimpleResponse
err = common.Unmarshal(responseBody, &usageResp)
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
// 写入新的 response body
service.IOCopyBytesGracefully(c, resp, responseBody)
// Once we've written to the client, we should not return errors anymore
// because the upstream has already consumed resources and returned content
// We should still perform billing even if parsing fails
// format
if usageResp.InputTokens > 0 {
usageResp.PromptTokens += usageResp.InputTokens
}
if usageResp.OutputTokens > 0 {
usageResp.CompletionTokens += usageResp.OutputTokens
}
if usageResp.InputTokensDetails != nil {
usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens
usageResp.PromptTokensDetails.TextTokens += usageResp.InputTokensDetails.TextTokens
}
applyUsagePostProcessing(info, &usageResp.Usage, responseBody)
return &usageResp.Usage, nil
}
func applyUsagePostProcessing(info *relaycommon.RelayInfo, usage *dto.Usage, responseBody []byte) {
if info == nil || usage == nil {
return
}
switch info.ChannelType {
case constant.ChannelTypeDeepSeek:
if usage.PromptTokensDetails.CachedTokens == 0 && usage.PromptCacheHitTokens != 0 {
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
}
case constant.ChannelTypeZhipu_v4:
// 智普的cached_tokens在标准位置: usage.prompt_tokens_details.cached_tokens
if usage.PromptTokensDetails.CachedTokens == 0 {
if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 0 {
usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens
} else if cachedTokens, ok := extractCachedTokensFromBody(responseBody); ok {
usage.PromptTokensDetails.CachedTokens = cachedTokens
} else if usage.PromptCacheHitTokens > 0 {
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
}
}
case constant.ChannelTypeMoonshot:
// Moonshot的cached_tokens在非标准位置: choices[].usage.cached_tokens
if usage.PromptTokensDetails.CachedTokens == 0 {
if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 0 {
usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens
} else if cachedTokens, ok := extractMoonshotCachedTokensFromBody(responseBody); ok {
usage.PromptTokensDetails.CachedTokens = cachedTokens
} else if cachedTokens, ok := extractCachedTokensFromBody(responseBody); ok {
usage.PromptTokensDetails.CachedTokens = cachedTokens
} else if usage.PromptCacheHitTokens > 0 {
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
}
}
case constant.ChannelTypeOpenAI:
if usage.PromptTokensDetails.CachedTokens == 0 {
if cachedTokens, ok := extractLlamaCachedTokensFromBody(responseBody); ok {
usage.PromptTokensDetails.CachedTokens = cachedTokens
}
}
}
}
func extractCachedTokensFromBody(body []byte) (int, bool) {
if len(body) == 0 {
return 0, false
}
var payload struct {
Usage struct {
PromptTokensDetails struct {
CachedTokens *int `json:"cached_tokens"`
} `json:"prompt_tokens_details"`
CachedTokens *int `json:"cached_tokens"`
PromptCacheHitTokens *int `json:"prompt_cache_hit_tokens"`
} `json:"usage"`
}
if err := common.Unmarshal(body, &payload); err != nil {
return 0, false
}
if payload.Usage.PromptTokensDetails.CachedTokens != nil {
return *payload.Usage.PromptTokensDetails.CachedTokens, true
}
if payload.Usage.CachedTokens != nil {
return *payload.Usage.CachedTokens, true
}
if payload.Usage.PromptCacheHitTokens != nil {
return *payload.Usage.PromptCacheHitTokens, true
}
return 0, false
}
// extractMoonshotCachedTokensFromBody 从Moonshot的非标准位置提取cached_tokens
// Moonshot的流式响应格式: {"choices":[{"usage":{"cached_tokens":111}}]}
func extractMoonshotCachedTokensFromBody(body []byte) (int, bool) {
if len(body) == 0 {
return 0, false
}
var payload struct {
Choices []struct {
Usage struct {
CachedTokens *int `json:"cached_tokens"`
} `json:"usage"`
} `json:"choices"`
}
if err := common.Unmarshal(body, &payload); err != nil {
return 0, false
}
// 遍历choices查找cached_tokens
for _, choice := range payload.Choices {
if choice.Usage.CachedTokens != nil && *choice.Usage.CachedTokens > 0 {
return *choice.Usage.CachedTokens, true
}
}
return 0, false
}
// extractLlamaCachedTokensFromBody 从llama.cpp的非标准位置提取cache_n
func extractLlamaCachedTokensFromBody(body []byte) (int, bool) {
if len(body) == 0 {
return 0, false
}
var payload struct {
Timings struct {
CachedTokens *int `json:"cache_n"`
} `json:"timings"`
}
if err := common.Unmarshal(body, &payload); err != nil {
return 0, false
}
if payload.Timings.CachedTokens == nil {
return 0, false
}
return *payload.Timings.CachedTokens, true
}
-287
View File
@@ -1,287 +0,0 @@
package openai
import (
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/logger"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/relay/helper"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
)
// OpenaiImageHandler handles non-streaming OpenAI image responses
// (generations/edits), returning the parsed usage for billing.
func OpenaiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
defer service.CloseResponseBodyGracefully(resp)
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
var usageResp dto.SimpleResponse
err = common.Unmarshal(responseBody, &usageResp)
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if oaiError := usageResp.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
}
// 写入新的 response body
service.IOCopyBytesGracefully(c, resp, responseBody)
normalizeOpenAIUsage(&usageResp.Usage)
applyUsagePostProcessing(info, &usageResp.Usage, responseBody)
return &usageResp.Usage, nil
}
// normalizeOpenAIUsage maps the OpenAI Images usage shape (input_tokens /
// output_tokens / input_tokens_details) onto the canonical prompt/completion
// fields. It is used only on the OpenAI image relay paths (generations/edits,
// streaming and non-streaming): the image API never returns prompt_tokens /
// completion_tokens, so the overwrite (=) semantics here are equivalent to the
// previous additive (+=) behavior while avoiding any future double-counting if
// both field sets are ever populated. Do not reuse this on chat/embedding paths
// without revisiting the overwrite semantics.
func normalizeOpenAIUsage(usage *dto.Usage) {
if usage == nil {
return
}
if usage.InputTokens != 0 {
usage.PromptTokens = usage.InputTokens
}
if usage.OutputTokens != 0 {
usage.CompletionTokens = usage.OutputTokens
}
if usage.InputTokensDetails != nil {
usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens
usage.PromptTokensDetails.CachedCreationTokens = usage.InputTokensDetails.CachedCreationTokens
usage.PromptTokensDetails.ImageTokens = usage.InputTokensDetails.ImageTokens
usage.PromptTokensDetails.TextTokens = usage.InputTokensDetails.TextTokens
usage.PromptTokensDetails.AudioTokens = usage.InputTokensDetails.AudioTokens
}
if usage.TotalTokens == 0 {
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
}
}
func OpenaiImageStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
if resp == nil || resp.Body == nil {
logger.LogError(c, "invalid image stream response")
return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError)
}
contentType := strings.ToLower(resp.Header.Get("Content-Type"))
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
return OpenaiImageHandler(c, info, resp)
}
if !strings.Contains(contentType, "text/event-stream") {
return OpenaiImageJSONAsStreamHandler(c, info, resp)
}
// Reuse the shared streaming engine (helper.StreamScannerHandler) so the
// image streaming path gets the same ping keepalive, streaming-timeout
// watchdog, client-disconnect detection, panic recovery and goroutine
// cleanup as every other relay stream. The scanner delivers only the
// "data:" payload, so the SSE "event:" line is rebuilt from the JSON "type"
// field (real OpenAI image events keep event == type).
usage := &dto.Usage{}
var lastStreamData []byte
helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) {
raw := common.StringToByteSlice(data)
lastStreamData = raw
if isOpenAIImageStreamErrorEvent(raw) {
// Record the error as a soft error; the scanner drives the final
// EndReason. HasErrors() flags the failure for logging/handling.
sr.Error(fmt.Errorf("%s", extractOpenAIImageStreamErrorMessage(raw)))
}
var usageResp dto.SimpleResponse
if err := common.Unmarshal(raw, &usageResp); err == nil {
normalizeOpenAIUsage(&usageResp.Usage)
if service.ValidUsage(&usageResp.Usage) {
usage = &usageResp.Usage
}
}
writeOpenaiImageStreamChunk(c, raw)
})
// StreamScannerHandler consumes the upstream [DONE]; re-emit it so the
// client still receives a terminal data: [DONE].
if info != nil && info.StreamStatus != nil && info.StreamStatus.EndReason == relaycommon.StreamEndReasonDone {
helper.Done(c)
}
applyUsagePostProcessing(info, usage, lastStreamData)
return usage, nil
}
// writeOpenaiImageStreamChunk rebuilds the SSE frame for an image stream chunk:
// it emits an "event:" line derived from the JSON "type" field (when present)
// followed by the verbatim "data:" payload, mirroring helper.ResponseChunkData.
func writeOpenaiImageStreamChunk(c *gin.Context, data []byte) {
var payload struct {
Type string `json:"type"`
}
_ = common.Unmarshal(data, &payload)
if eventName := strings.TrimSpace(payload.Type); eventName != "" {
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", eventName)})
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(data)})
_ = helper.FlushWriter(c)
}
// isOpenAIImageStreamErrorEvent detects upstream error chunks by JSON content
// only ("type" of error/upstream_error, or a non-empty "error" field). The SSE
// "event:" line is not available here: StreamScannerHandler delivers only the
// "data:" payload. A payload carrying just a "message" key is deliberately NOT
// treated as an error to avoid false positives.
func isOpenAIImageStreamErrorEvent(data []byte) bool {
if !json.Valid(data) {
return false
}
var payload struct {
Type string `json:"type"`
Error json.RawMessage `json:"error"`
}
if err := common.Unmarshal(data, &payload); err != nil {
return false
}
payloadType := strings.ToLower(strings.TrimSpace(payload.Type))
return payloadType == "error" || payloadType == "upstream_error" || len(payload.Error) > 0
}
func extractOpenAIImageStreamErrorMessage(data []byte) string {
if len(data) == 0 || !json.Valid(data) {
return "upstream image stream returned error event"
}
var payload struct {
Message string `json:"message"`
Error json.RawMessage `json:"error"`
}
if err := common.Unmarshal(data, &payload); err != nil {
return "upstream image stream returned error event"
}
if msg := strings.TrimSpace(payload.Message); msg != "" {
return msg
}
if len(payload.Error) > 0 {
var nested struct {
Message string `json:"message"`
}
if err := common.Unmarshal(payload.Error, &nested); err == nil {
if msg := strings.TrimSpace(nested.Message); msg != "" {
return msg
}
}
if msg := strings.TrimSpace(common.JsonRawMessageToString(payload.Error)); msg != "" {
return msg
}
}
return "upstream image stream returned error event"
}
func OpenaiImageJSONAsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
defer service.CloseResponseBodyGracefully(resp)
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
var imageResp dto.ImageResponse
if err := common.Unmarshal(responseBody, &imageResp); err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
var usageResp dto.SimpleResponse
_ = common.Unmarshal(responseBody, &usageResp)
if oaiError := usageResp.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
}
normalizeOpenAIUsage(&usageResp.Usage)
applyUsagePostProcessing(info, &usageResp.Usage, responseBody)
helper.SetEventStreamHeaders(c)
c.Status(http.StatusOK)
created := imageResp.Created
if created == 0 {
created = time.Now().Unix()
}
if info != nil {
info.SetFirstResponseTime()
}
for _, image := range imageResp.Data {
payload := map[string]any{
"type": "image_generation.completed",
"created_at": created,
}
if image.Url != "" {
payload["url"] = image.Url
}
if image.B64Json != "" {
payload["b64_json"] = image.B64Json
}
if image.RevisedPrompt != "" {
payload["revised_prompt"] = image.RevisedPrompt
}
if service.ValidUsage(&usageResp.Usage) {
payload["usage"] = usageResp.Usage
}
if err := writeOpenaiImageStreamPayload(c, "image_generation.completed", payload); err != nil {
if info != nil && info.StreamStatus != nil {
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, err)
}
return &usageResp.Usage, nil
}
}
if err := writeOpenaiImageStreamDone(c); err != nil {
if info != nil && info.StreamStatus != nil {
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, err)
}
return &usageResp.Usage, nil
}
if info != nil {
info.ReceivedResponseCount += len(imageResp.Data)
if info.StreamStatus == nil {
info.StreamStatus = relaycommon.NewStreamStatus()
}
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonDone, nil)
}
return &usageResp.Usage, nil
}
func writeOpenaiImageStreamPayload(c *gin.Context, eventName string, payload any) error {
data, err := common.Marshal(payload)
if err != nil {
return err
}
if eventName != "" {
if _, err := fmt.Fprintf(c.Writer, "event: %s\n", eventName); err != nil {
return err
}
}
if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", data); err != nil {
return err
}
return helper.FlushWriter(c)
}
func writeOpenaiImageStreamDone(c *gin.Context) error {
if _, err := fmt.Fprint(c.Writer, "data: [DONE]\n\n"); err != nil {
return err
}
return helper.FlushWriter(c)
}
-242
View File
@@ -1,242 +0,0 @@
package openai
import (
"fmt"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/logger"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/relay/helper"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/types"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
)
func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.RealtimeUsage) {
if info == nil || info.ClientWs == nil || info.TargetWs == nil {
return types.NewError(fmt.Errorf("invalid websocket connection"), types.ErrorCodeBadResponse), nil
}
info.IsStream = true
clientConn := info.ClientWs
targetConn := info.TargetWs
clientClosed := make(chan struct{})
targetClosed := make(chan struct{})
sendChan := make(chan []byte, 100)
receiveChan := make(chan []byte, 100)
errChan := make(chan error, 2)
usage := &dto.RealtimeUsage{}
localUsage := &dto.RealtimeUsage{}
sumUsage := &dto.RealtimeUsage{}
gopool.Go(func() {
defer func() {
if r := recover(); r != nil {
errChan <- fmt.Errorf("panic in client reader: %v", r)
}
}()
for {
select {
case <-c.Done():
return
default:
_, message, err := clientConn.ReadMessage()
if err != nil {
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
errChan <- fmt.Errorf("error reading from client: %v", err)
}
close(clientClosed)
return
}
realtimeEvent := &dto.RealtimeEvent{}
err = common.Unmarshal(message, realtimeEvent)
if err != nil {
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
return
}
if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdate {
if realtimeEvent.Session != nil {
if realtimeEvent.Session.Tools != nil {
info.RealtimeTools = realtimeEvent.Session.Tools
}
}
}
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
if err != nil {
errChan <- fmt.Errorf("error counting text token: %v", err)
return
}
logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
localUsage.TotalTokens += textToken + audioToken
localUsage.InputTokens += textToken + audioToken
localUsage.InputTokenDetails.TextTokens += textToken
localUsage.InputTokenDetails.AudioTokens += audioToken
err = helper.WssString(c, targetConn, string(message))
if err != nil {
errChan <- fmt.Errorf("error writing to target: %v", err)
return
}
select {
case sendChan <- message:
default:
}
}
}
})
gopool.Go(func() {
defer func() {
if r := recover(); r != nil {
errChan <- fmt.Errorf("panic in target reader: %v", r)
}
}()
for {
select {
case <-c.Done():
return
default:
_, message, err := targetConn.ReadMessage()
if err != nil {
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
errChan <- fmt.Errorf("error reading from target: %v", err)
}
close(targetClosed)
return
}
info.SetFirstResponseTime()
realtimeEvent := &dto.RealtimeEvent{}
err = common.Unmarshal(message, realtimeEvent)
if err != nil {
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
return
}
if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone {
realtimeUsage := realtimeEvent.Response.Usage
if realtimeUsage != nil {
usage.TotalTokens += realtimeUsage.TotalTokens
usage.InputTokens += realtimeUsage.InputTokens
usage.OutputTokens += realtimeUsage.OutputTokens
usage.InputTokenDetails.AudioTokens += realtimeUsage.InputTokenDetails.AudioTokens
usage.InputTokenDetails.CachedTokens += realtimeUsage.InputTokenDetails.CachedTokens
usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens
usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens
usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens
err := preConsumeUsage(c, info, usage, sumUsage)
if err != nil {
errChan <- fmt.Errorf("error consume usage: %v", err)
return
}
// 本次计费完成,清除
usage = &dto.RealtimeUsage{}
localUsage = &dto.RealtimeUsage{}
} else {
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
if err != nil {
errChan <- fmt.Errorf("error counting text token: %v", err)
return
}
logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
localUsage.TotalTokens += textToken + audioToken
info.IsFirstRequest = false
localUsage.InputTokens += textToken + audioToken
localUsage.InputTokenDetails.TextTokens += textToken
localUsage.InputTokenDetails.AudioTokens += audioToken
err = preConsumeUsage(c, info, localUsage, sumUsage)
if err != nil {
errChan <- fmt.Errorf("error consume usage: %v", err)
return
}
// 本次计费完成,清除
localUsage = &dto.RealtimeUsage{}
// print now usage
}
logger.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
} else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
realtimeSession := realtimeEvent.Session
if realtimeSession != nil {
// update audio format
info.InputAudioFormat = common.GetStringIfEmpty(realtimeSession.InputAudioFormat, info.InputAudioFormat)
info.OutputAudioFormat = common.GetStringIfEmpty(realtimeSession.OutputAudioFormat, info.OutputAudioFormat)
}
} else {
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
if err != nil {
errChan <- fmt.Errorf("error counting text token: %v", err)
return
}
logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
localUsage.TotalTokens += textToken + audioToken
localUsage.OutputTokens += textToken + audioToken
localUsage.OutputTokenDetails.TextTokens += textToken
localUsage.OutputTokenDetails.AudioTokens += audioToken
}
err = helper.WssString(c, clientConn, string(message))
if err != nil {
errChan <- fmt.Errorf("error writing to client: %v", err)
return
}
select {
case receiveChan <- message:
default:
}
}
}
})
select {
case <-clientClosed:
case <-targetClosed:
case err := <-errChan:
//return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
logger.LogError(c, "realtime error: "+err.Error())
case <-c.Done():
}
if usage.TotalTokens != 0 {
_ = preConsumeUsage(c, info, usage, sumUsage)
}
if localUsage.TotalTokens != 0 {
_ = preConsumeUsage(c, info, localUsage, sumUsage)
}
// check usage total tokens, if 0, use local usage
return nil, sumUsage
}
func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.RealtimeUsage, totalUsage *dto.RealtimeUsage) error {
if usage == nil || totalUsage == nil {
return fmt.Errorf("invalid usage pointer")
}
totalUsage.TotalTokens += usage.TotalTokens
totalUsage.InputTokens += usage.InputTokens
totalUsage.OutputTokens += usage.OutputTokens
totalUsage.InputTokenDetails.CachedTokens += usage.InputTokenDetails.CachedTokens
totalUsage.InputTokenDetails.TextTokens += usage.InputTokenDetails.TextTokens
totalUsage.InputTokenDetails.AudioTokens += usage.InputTokenDetails.AudioTokens
totalUsage.OutputTokenDetails.TextTokens += usage.OutputTokenDetails.TextTokens
totalUsage.OutputTokenDetails.AudioTokens += usage.OutputTokenDetails.AudioTokens
// clear usage
err := service.PreWssConsumeQuota(ctx, info, usage)
return err
}
-133
View File
@@ -1,133 +0,0 @@
package openai
import (
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
relaycommon "github.com/QuantumNous/new-api/relay/common"
)
func applyUsagePostProcessing(info *relaycommon.RelayInfo, usage *dto.Usage, responseBody []byte) {
if info == nil || usage == nil {
return
}
switch info.ChannelType {
case constant.ChannelTypeDeepSeek:
if usage.PromptTokensDetails.CachedTokens == 0 && usage.PromptCacheHitTokens != 0 {
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
}
case constant.ChannelTypeZhipu_v4:
// 智普的cached_tokens在标准位置: usage.prompt_tokens_details.cached_tokens
if usage.PromptTokensDetails.CachedTokens == 0 {
if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 0 {
usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens
} else if cachedTokens, ok := extractCachedTokensFromBody(responseBody); ok {
usage.PromptTokensDetails.CachedTokens = cachedTokens
} else if usage.PromptCacheHitTokens > 0 {
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
}
}
case constant.ChannelTypeMoonshot:
// Moonshot的cached_tokens在非标准位置: choices[].usage.cached_tokens
if usage.PromptTokensDetails.CachedTokens == 0 {
if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 0 {
usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens
} else if cachedTokens, ok := extractMoonshotCachedTokensFromBody(responseBody); ok {
usage.PromptTokensDetails.CachedTokens = cachedTokens
} else if cachedTokens, ok := extractCachedTokensFromBody(responseBody); ok {
usage.PromptTokensDetails.CachedTokens = cachedTokens
} else if usage.PromptCacheHitTokens > 0 {
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
}
}
case constant.ChannelTypeOpenAI:
if usage.PromptTokensDetails.CachedTokens == 0 {
if cachedTokens, ok := extractLlamaCachedTokensFromBody(responseBody); ok {
usage.PromptTokensDetails.CachedTokens = cachedTokens
}
}
}
}
func extractCachedTokensFromBody(body []byte) (int, bool) {
if len(body) == 0 {
return 0, false
}
var payload struct {
Usage struct {
PromptTokensDetails struct {
CachedTokens *int `json:"cached_tokens"`
} `json:"prompt_tokens_details"`
CachedTokens *int `json:"cached_tokens"`
PromptCacheHitTokens *int `json:"prompt_cache_hit_tokens"`
} `json:"usage"`
}
if err := common.Unmarshal(body, &payload); err != nil {
return 0, false
}
if payload.Usage.PromptTokensDetails.CachedTokens != nil {
return *payload.Usage.PromptTokensDetails.CachedTokens, true
}
if payload.Usage.CachedTokens != nil {
return *payload.Usage.CachedTokens, true
}
if payload.Usage.PromptCacheHitTokens != nil {
return *payload.Usage.PromptCacheHitTokens, true
}
return 0, false
}
// extractMoonshotCachedTokensFromBody 从Moonshot的非标准位置提取cached_tokens
// Moonshot的流式响应格式: {"choices":[{"usage":{"cached_tokens":111}}]}
func extractMoonshotCachedTokensFromBody(body []byte) (int, bool) {
if len(body) == 0 {
return 0, false
}
var payload struct {
Choices []struct {
Usage struct {
CachedTokens *int `json:"cached_tokens"`
} `json:"usage"`
} `json:"choices"`
}
if err := common.Unmarshal(body, &payload); err != nil {
return 0, false
}
// 遍历choices查找cached_tokens
for _, choice := range payload.Choices {
if choice.Usage.CachedTokens != nil && *choice.Usage.CachedTokens > 0 {
return *choice.Usage.CachedTokens, true
}
}
return 0, false
}
// extractLlamaCachedTokensFromBody 从llama.cpp的非标准位置提取cache_n
func extractLlamaCachedTokensFromBody(body []byte) (int, bool) {
if len(body) == 0 {
return 0, false
}
var payload struct {
Timings struct {
CachedTokens *int `json:"cache_n"`
} `json:"timings"`
}
if err := common.Unmarshal(body, &payload); err != nil {
return 0, false
}
if payload.Timings.CachedTokens == nil {
return 0, false
}
return *payload.Timings.CachedTokens, true
}
+1 -1
View File
@@ -92,7 +92,7 @@ func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *dto.Cha
func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
var responseText string
scanner := helper.NewStreamScanner(resp.Body)
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
helper.SetEventStreamHeaders(c)
-1
View File
@@ -45,7 +45,6 @@ var claudeModelMap = map[string]string{
"claude-opus-4-5-20251101": "claude-opus-4-5@20251101",
"claude-opus-4-6": "claude-opus-4-6",
"claude-opus-4-7": "claude-opus-4-7",
"claude-opus-4-8": "claude-opus-4-8",
}
const anthropicVersion = "vertex-2023-10-16"
+1 -1
View File
@@ -114,7 +114,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
switch info.RelayMode {
case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
usage, err = openai.OpenaiImageHandler(c, info, resp)
usage, err = openai.OpenaiHandlerWithUsage(c, info, resp)
case constant.RelayModeResponses:
if info.IsStream {
usage, err = openai.OaiResponsesStreamHandler(c, info, resp)
+1 -4
View File
@@ -157,7 +157,7 @@ func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*dt
func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
var usage *dto.Usage
scanner := helper.NewStreamScanner(resp.Body)
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
dataChan := make(chan string)
metaChan := make(chan string)
@@ -180,9 +180,6 @@ func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
}
}
}
if err := scanner.Err(); err != nil {
common.SysLog("error reading stream: " + err.Error())
}
stopChan <- true
}()
helper.SetEventStreamHeaders(c)
+2 -8
View File
@@ -1,6 +1,7 @@
package relay
import (
"bytes"
"io"
"net/http"
"strings"
@@ -124,14 +125,7 @@ func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, ad
return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
defer closer.Close()
jsonData = nil
info.UpstreamRequestBodySize = size
var requestBody io.Reader = body
var requestBody io.Reader = bytes.NewBuffer(jsonData)
var httpResp *http.Response
resp, err := adaptor.DoRequest(c, info, requestBody)
+7 -18
View File
@@ -1,6 +1,7 @@
package relay
import (
"bytes"
"encoding/json"
"fmt"
"io"
@@ -53,17 +54,14 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
}
if baseModel, effortLevel, ok := reasoning.TrimEffortSuffix(request.Model); ok && effortLevel != "" &&
(strings.HasPrefix(request.Model, "claude-opus-4-6") ||
strings.HasPrefix(request.Model, "claude-opus-4-7") ||
strings.HasPrefix(request.Model, "claude-opus-4-8")) {
(strings.HasPrefix(request.Model, "claude-opus-4-6") || strings.HasPrefix(request.Model, "claude-opus-4-7")) {
request.Model = baseModel
request.Thinking = &dto.Thinking{
Type: "adaptive",
}
request.OutputConfig = json.RawMessage(fmt.Sprintf(`{"effort":"%s"}`, effortLevel))
if strings.HasPrefix(request.Model, "claude-opus-4-7") ||
strings.HasPrefix(request.Model, "claude-opus-4-8") {
// Opus 4.7/4.8 reject non-default temperature/top_p/top_k with 400
if strings.HasPrefix(request.Model, "claude-opus-4-7") {
// Opus 4.7 rejects non-default temperature/top_p/top_k with 400
// and defaults display to "omitted"; restore the 4.6 visible summary.
request.Thinking.Display = "summarized"
request.Temperature = nil
@@ -77,9 +75,8 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
strings.HasSuffix(request.Model, "-thinking") {
if request.Thinking == nil {
baseModel := strings.TrimSuffix(request.Model, "-thinking")
if strings.HasPrefix(baseModel, "claude-opus-4-7") ||
strings.HasPrefix(baseModel, "claude-opus-4-8") {
// Opus 4.7/4.8 reject thinking.type="enabled"; use adaptive at high effort.
if strings.HasPrefix(baseModel, "claude-opus-4-7") {
// Opus 4.7 rejects thinking.type="enabled"; use adaptive at high effort.
request.Thinking = &dto.Thinking{Type: "adaptive", Display: "summarized"}
request.OutputConfig = json.RawMessage(`{"effort":"high"}`)
request.Temperature = nil
@@ -155,7 +152,6 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
if err != nil {
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
}
info.UpstreamRequestBodySize = storage.Size()
requestBody = common.ReaderOnly(storage)
} else {
convertedRequest, err := adaptor.ConvertClaudeRequest(c, info, request)
@@ -183,14 +179,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
}
logger.LogDebug(c, "requestBody: %s", jsonData)
body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
defer closer.Close()
jsonData = nil
info.UpstreamRequestBodySize = size
requestBody = body
requestBody = bytes.NewBuffer(jsonData)
}
statusCodeMappingStr := c.GetString("status_code_mapping")
-31
View File
@@ -1,31 +0,0 @@
package common
import (
"io"
"github.com/QuantumNous/new-api/common"
)
// NewOutboundJSONBody wraps the already-marshaled upstream request body into a
// BodyStorage. When disk cache is enabled and the payload exceeds the configured
// threshold, the data is written to a temp file and the original []byte can be
// GC'd, significantly reducing the heap residency while waiting for the
// upstream provider to respond (the dominant cost for large base64 payloads).
//
// In memory mode the underlying memoryStorage reuses the same backing array,
// so this is equivalent to bytes.NewReader(data) in terms of memory usage.
//
// The caller MUST invoke closer.Close() once the upstream call has finished
// (typically via defer) to release the disk file / memory accounting.
//
// The returned reader is wrapped with common.ReaderOnly to prevent the HTTP
// transport from prematurely closing the underlying BodyStorage. The returned
// size is meant to be propagated to http.Request.ContentLength because the
// type-erased io.Reader prevents net/http from auto-detecting it.
func NewOutboundJSONBody(data []byte) (body io.Reader, size int64, closer io.Closer, err error) {
storage, err := common.CreateBodyStorage(data)
if err != nil {
return nil, 0, nil, err
}
return common.ReaderOnly(storage), storage.Size(), storage, nil
}
+133 -168
View File
@@ -153,8 +153,9 @@ func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, c
}
}
// 使用新方法(基于 []byte,避免整包 string 拷贝)
return applyOperations(workingJSON, operations, conditionContext)
// 使用新方法
result, err := applyOperations(string(workingJSON), operations, conditionContext)
return []byte(result), err
}
// 直接使用旧方法
@@ -509,13 +510,13 @@ func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation,
return operations, true
}
func checkConditions(data []byte, contextJSON string, conditions []ConditionOperation, logic string) (bool, error) {
func checkConditions(jsonStr, contextJSON string, conditions []ConditionOperation, logic string) (bool, error) {
if len(conditions) == 0 {
return true, nil // 没有条件,直接通过
}
results := make([]bool, len(conditions))
for i, condition := range conditions {
result, err := checkSingleCondition(data, contextJSON, condition)
result, err := checkSingleCondition(jsonStr, contextJSON, condition)
if err != nil {
return false, err
}
@@ -528,10 +529,10 @@ func checkConditions(data []byte, contextJSON string, conditions []ConditionOper
return lo.SomeBy(results, func(item bool) bool { return item }), nil
}
func checkSingleCondition(data []byte, contextJSON string, condition ConditionOperation) (bool, error) {
func checkSingleCondition(jsonStr, contextJSON string, condition ConditionOperation) (bool, error) {
// 处理负数索引
path := processNegativeIndex(data, condition.Path)
value := gjson.GetBytes(data, path)
path := processNegativeIndex(jsonStr, condition.Path)
value := gjson.Get(jsonStr, path)
if !value.Exists() && contextJSON != "" {
value = gjson.Get(contextJSON, condition.Path)
}
@@ -560,7 +561,7 @@ func checkSingleCondition(data []byte, contextJSON string, condition ConditionOp
return result, nil
}
func processNegativeIndex(data []byte, path string) string {
func processNegativeIndex(jsonStr string, path string) string {
matches := negativeIndexRegexp.FindAllStringSubmatch(path, -1)
if len(matches) == 0 {
@@ -577,7 +578,7 @@ func processNegativeIndex(data []byte, path string) string {
arrayPath = arrayPath[:len(arrayPath)-1]
}
array := gjson.GetBytes(data, arrayPath)
array := gjson.Get(jsonStr, arrayPath)
if array.IsArray() {
length := len(array.Array())
actualIndex := length + index
@@ -666,76 +667,36 @@ func compareNumeric(jsonValue, targetValue gjson.Result, operator string) (bool,
}
}
// applyOperationsLegacy 原参数覆盖方法
//
// 旧实现把整个 jsonData unmarshal 成 map[string]interface{} 再 marshal 回来,
// 对包含大 base64 字段(如 Gemini inlineData.data)的请求会放大数倍内存
// interface 装箱、map bucket、再次 marshal)。
// 这里改成在 []byte 上直接调用 sjson.SetBytes,按顶层 key 逐个写入,
// 不再把 payload 解码到 map[string]interface{}。
//
// 语义保持:每个 paramOverride 顶层 key 视为字面 key(不解析点号路径),
// 与旧的 reqMap[key] = value 一致。包含 `.` `*` `?` `\` 的 key 会被转义,
// 防止被 sjson 当作嵌套路径或通配符。
// applyOperationsLegacy 原参数覆盖方法
func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}, auditRecorder *paramOverrideAuditRecorder) ([]byte, error) {
if len(paramOverride) == 0 {
return jsonData, nil
reqMap := make(map[string]interface{})
err := common.Unmarshal(jsonData, &reqMap)
if err != nil {
return nil, err
}
result := jsonData
for key, value := range paramOverride {
escaped := escapeSjsonLiteralKey(key)
next, err := sjson.SetBytes(result, escaped, value)
if err != nil {
return nil, err
}
result = next
reqMap[key] = value
auditRecorder.recordOperation("set", key, "", "", value)
}
return result, nil
return common.Marshal(reqMap)
}
// escapeSjsonLiteralKey 把可能被 sjson 误判为路径或通配符的字符转义,
// 用于把字面 key 安全地传给 sjson.SetBytes / sjson.DeleteBytes。
func escapeSjsonLiteralKey(key string) string {
if !strings.ContainsAny(key, ".*?\\") {
return key
}
var sb strings.Builder
sb.Grow(len(key) + 4)
for i := 0; i < len(key); i++ {
c := key[i]
switch c {
case '.', '*', '?', '\\':
sb.WriteByte('\\')
}
sb.WriteByte(c)
}
return sb.String()
}
// applyOperations 在 []byte 上原地应用所有 param override 操作。
//
// 旧实现走 string-based gjson/sjson,在 ApplyParamOverride 入口会做
// string(jsonData) 与最终 []byte(result) 各一次整包拷贝,对大 base64
// payload 来说每次重试都额外多花 2 倍 body 体积的临时内存。
// 这里改成全程在 []byte 上工作,sjson.SetBytes / gjson.GetBytes 都是
// 直接读写 []byte,每个操作只会产生一份新 buffer。
func applyOperations(jsonData []byte, operations []ParamOperation, conditionContext map[string]interface{}) ([]byte, error) {
func applyOperations(jsonStr string, operations []ParamOperation, conditionContext map[string]interface{}) (string, error) {
context := ensureContextMap(conditionContext)
auditRecorder := getParamOverrideAuditRecorder(context)
contextJSON, err := marshalContextJSON(context)
if err != nil {
return nil, fmt.Errorf("failed to marshal condition context: %v", err)
return "", fmt.Errorf("failed to marshal condition context: %v", err)
}
result := jsonData
result := jsonStr
for _, op := range operations {
// 检查条件是否满足
ok, err := checkConditions(result, contextJSON, op.Conditions, op.Logic)
if err != nil {
return nil, err
return "", err
}
if !ok {
continue // 条件不满足,跳过当前操作
@@ -746,7 +707,7 @@ func applyOperations(jsonData []byte, operations []ParamOperation, conditionCont
if isPathBasedOperation(op.Mode) {
opPaths, err = resolveOperationPaths(result, opPath)
if err != nil {
return nil, err
return "", err
}
if len(opPaths) == 0 {
continue
@@ -764,10 +725,10 @@ func applyOperations(jsonData []byte, operations []ParamOperation, conditionCont
}
case "set":
for _, path := range opPaths {
if op.KeepOrigin && gjson.GetBytes(result, path).Exists() {
if op.KeepOrigin && gjson.Get(result, path).Exists() {
continue
}
result, err = sjson.SetBytes(result, path, op.Value)
result, err = sjson.Set(result, path, op.Value)
if err != nil {
break
}
@@ -782,7 +743,7 @@ func applyOperations(jsonData []byte, operations []ParamOperation, conditionCont
}
case "copy":
if op.From == "" || op.To == "" {
return nil, fmt.Errorf("copy from/to is required")
return "", fmt.Errorf("copy from/to is required")
}
opFrom := processNegativeIndex(result, op.From)
opTo := processNegativeIndex(result, op.To)
@@ -882,9 +843,9 @@ func applyOperations(jsonData []byte, operations []ParamOperation, conditionCont
auditRecorder.recordOperation("return_error", op.Path, "", "", op.Value)
returnErr, parseErr := parseParamOverrideReturnError(op.Value)
if parseErr != nil {
return nil, parseErr
return "", parseErr
}
return nil, returnErr
return "", returnErr
case "prune_objects":
for _, path := range opPaths {
result, err = pruneObjects(result, path, contextJSON, op.Value)
@@ -941,7 +902,7 @@ func applyOperations(jsonData []byte, operations []ParamOperation, conditionCont
case "pass_headers":
headerNames, parseErr := parseHeaderPassThroughNames(op.Value)
if parseErr != nil {
return nil, parseErr
return "", parseErr
}
for _, headerName := range headerNames {
if err = copyHeaderInContext(context, headerName, headerName, op.KeepOrigin); err != nil {
@@ -963,10 +924,10 @@ func applyOperations(jsonData []byte, operations []ParamOperation, conditionCont
contextJSON, err = marshalContextJSON(context)
}
default:
return nil, fmt.Errorf("unknown operation: %s", op.Mode)
return "", fmt.Errorf("unknown operation: %s", op.Mode)
}
if err != nil {
return nil, fmt.Errorf("operation %s failed: %w", op.Mode, err)
return "", fmt.Errorf("operation %s failed: %w", op.Mode, err)
}
}
return result, nil
@@ -1400,11 +1361,11 @@ func parseSyncTarget(spec string) (syncTarget, error) {
}
}
func readSyncTargetValue(data []byte, context map[string]interface{}, target syncTarget) (interface{}, bool, error) {
func readSyncTargetValue(jsonStr string, context map[string]interface{}, target syncTarget) (interface{}, bool, error) {
switch target.kind {
case "json":
path := processNegativeIndex(data, target.key)
value := gjson.GetBytes(data, path)
path := processNegativeIndex(jsonStr, target.key)
value := gjson.Get(jsonStr, path)
if !value.Exists() || value.Type == gjson.Null {
return nil, false, nil
}
@@ -1423,52 +1384,52 @@ func readSyncTargetValue(data []byte, context map[string]interface{}, target syn
}
}
func writeSyncTargetValue(data []byte, context map[string]interface{}, target syncTarget, value interface{}) ([]byte, error) {
func writeSyncTargetValue(jsonStr string, context map[string]interface{}, target syncTarget, value interface{}) (string, error) {
switch target.kind {
case "json":
path := processNegativeIndex(data, target.key)
nextJSON, err := sjson.SetBytes(data, path, value)
path := processNegativeIndex(jsonStr, target.key)
nextJSON, err := sjson.Set(jsonStr, path, value)
if err != nil {
return nil, err
return "", err
}
return nextJSON, nil
case "header":
if err := setHeaderOverrideInContext(context, target.key, value, false); err != nil {
return nil, err
return "", err
}
return data, nil
return jsonStr, nil
default:
return nil, fmt.Errorf("unsupported sync_fields target kind: %s", target.kind)
return "", fmt.Errorf("unsupported sync_fields target kind: %s", target.kind)
}
}
func syncFieldsBetweenTargets(data []byte, context map[string]interface{}, fromSpec string, toSpec string) ([]byte, error) {
func syncFieldsBetweenTargets(jsonStr string, context map[string]interface{}, fromSpec string, toSpec string) (string, error) {
fromTarget, err := parseSyncTarget(fromSpec)
if err != nil {
return nil, err
return "", err
}
toTarget, err := parseSyncTarget(toSpec)
if err != nil {
return nil, err
return "", err
}
fromValue, fromExists, err := readSyncTargetValue(data, context, fromTarget)
fromValue, fromExists, err := readSyncTargetValue(jsonStr, context, fromTarget)
if err != nil {
return nil, err
return "", err
}
toValue, toExists, err := readSyncTargetValue(data, context, toTarget)
toValue, toExists, err := readSyncTargetValue(jsonStr, context, toTarget)
if err != nil {
return nil, err
return "", err
}
// If one side exists and the other side is missing, sync the missing side.
if fromExists && !toExists {
return writeSyncTargetValue(data, context, toTarget, fromValue)
return writeSyncTargetValue(jsonStr, context, toTarget, fromValue)
}
if toExists && !fromExists {
return writeSyncTargetValue(data, context, fromTarget, toValue)
return writeSyncTargetValue(jsonStr, context, fromTarget, toValue)
}
return data, nil
return jsonStr, nil
}
func ensureMapKeyInContext(context map[string]interface{}, key string) map[string]interface{} {
@@ -1542,24 +1503,24 @@ func syncRuntimeHeaderOverrideFromContext(info *RelayInfo, context map[string]in
info.UseRuntimeHeadersOverride = true
}
func moveValue(data []byte, fromPath, toPath string) ([]byte, error) {
sourceValue := gjson.GetBytes(data, fromPath)
func moveValue(jsonStr, fromPath, toPath string) (string, error) {
sourceValue := gjson.Get(jsonStr, fromPath)
if !sourceValue.Exists() {
return data, fmt.Errorf("source path does not exist: %s", fromPath)
return jsonStr, fmt.Errorf("source path does not exist: %s", fromPath)
}
result, err := sjson.SetBytes(data, toPath, sourceValue.Value())
result, err := sjson.Set(jsonStr, toPath, sourceValue.Value())
if err != nil {
return nil, err
return "", err
}
return sjson.DeleteBytes(result, fromPath)
return sjson.Delete(result, fromPath)
}
func copyValue(data []byte, fromPath, toPath string) ([]byte, error) {
sourceValue := gjson.GetBytes(data, fromPath)
func copyValue(jsonStr, fromPath, toPath string) (string, error) {
sourceValue := gjson.Get(jsonStr, fromPath)
if !sourceValue.Exists() {
return data, fmt.Errorf("source path does not exist: %s", fromPath)
return jsonStr, fmt.Errorf("source path does not exist: %s", fromPath)
}
return sjson.SetBytes(data, toPath, sourceValue.Value())
return sjson.Set(jsonStr, toPath, sourceValue.Value())
}
func isPathBasedOperation(mode string) bool {
@@ -1571,16 +1532,16 @@ func isPathBasedOperation(mode string) bool {
}
}
func resolveOperationPaths(data []byte, path string) ([]string, error) {
func resolveOperationPaths(jsonStr, path string) ([]string, error) {
if !strings.Contains(path, "*") {
return []string{path}, nil
}
return expandWildcardPaths(data, path)
return expandWildcardPaths(jsonStr, path)
}
func expandWildcardPaths(data []byte, path string) ([]string, error) {
func expandWildcardPaths(jsonStr, path string) ([]string, error) {
var root interface{}
if err := common.Unmarshal(data, &root); err != nil {
if err := common.Unmarshal([]byte(jsonStr), &root); err != nil {
return nil, err
}
@@ -1641,28 +1602,28 @@ func collectWildcardPaths(node interface{}, segments []string, prefix []string)
}
}
func deleteValue(data []byte, path string) ([]byte, error) {
func deleteValue(jsonStr, path string) (string, error) {
if strings.TrimSpace(path) == "" {
return data, nil
return jsonStr, nil
}
return sjson.DeleteBytes(data, path)
return sjson.Delete(jsonStr, path)
}
func modifyValue(data []byte, path string, value interface{}, keepOrigin, isPrepend bool) ([]byte, error) {
current := gjson.GetBytes(data, path)
func modifyValue(jsonStr, path string, value interface{}, keepOrigin, isPrepend bool) (string, error) {
current := gjson.Get(jsonStr, path)
switch {
case current.IsArray():
return modifyArray(data, path, value, isPrepend)
return modifyArray(jsonStr, path, value, isPrepend)
case current.Type == gjson.String:
return modifyString(data, path, value, isPrepend)
return modifyString(jsonStr, path, value, isPrepend)
case current.Type == gjson.JSON:
return mergeObjects(data, path, value, keepOrigin)
return mergeObjects(jsonStr, path, value, keepOrigin)
}
return data, fmt.Errorf("operation not supported for type: %v", current.Type)
return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
}
func modifyArray(data []byte, path string, value interface{}, isPrepend bool) ([]byte, error) {
current := gjson.GetBytes(data, path)
func modifyArray(jsonStr, path string, value interface{}, isPrepend bool) (string, error) {
current := gjson.Get(jsonStr, path)
var newArray []interface{}
// 添加新值
addValue := func() {
@@ -1686,11 +1647,11 @@ func modifyArray(data []byte, path string, value interface{}, isPrepend bool) ([
addOriginal()
addValue()
}
return sjson.SetBytes(data, path, newArray)
return sjson.Set(jsonStr, path, newArray)
}
func modifyString(data []byte, path string, value interface{}, isPrepend bool) ([]byte, error) {
current := gjson.GetBytes(data, path)
func modifyString(jsonStr, path string, value interface{}, isPrepend bool) (string, error) {
current := gjson.Get(jsonStr, path)
valueStr := fmt.Sprintf("%v", value)
var newStr string
if isPrepend {
@@ -1698,17 +1659,17 @@ func modifyString(data []byte, path string, value interface{}, isPrepend bool) (
} else {
newStr = current.String() + valueStr
}
return sjson.SetBytes(data, path, newStr)
return sjson.Set(jsonStr, path, newStr)
}
func trimStringValue(data []byte, path string, value interface{}, isPrefix bool) ([]byte, error) {
current := gjson.GetBytes(data, path)
func trimStringValue(jsonStr, path string, value interface{}, isPrefix bool) (string, error) {
current := gjson.Get(jsonStr, path)
if current.Type != gjson.String {
return data, fmt.Errorf("operation not supported for type: %v", current.Type)
return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
}
if value == nil {
return data, fmt.Errorf("trim value is required")
return jsonStr, fmt.Errorf("trim value is required")
}
valueStr := fmt.Sprintf("%v", value)
@@ -1718,69 +1679,69 @@ func trimStringValue(data []byte, path string, value interface{}, isPrefix bool)
} else {
newStr = strings.TrimSuffix(current.String(), valueStr)
}
return sjson.SetBytes(data, path, newStr)
return sjson.Set(jsonStr, path, newStr)
}
func ensureStringAffix(data []byte, path string, value interface{}, isPrefix bool) ([]byte, error) {
current := gjson.GetBytes(data, path)
func ensureStringAffix(jsonStr, path string, value interface{}, isPrefix bool) (string, error) {
current := gjson.Get(jsonStr, path)
if current.Type != gjson.String {
return data, fmt.Errorf("operation not supported for type: %v", current.Type)
return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
}
if value == nil {
return data, fmt.Errorf("ensure value is required")
return jsonStr, fmt.Errorf("ensure value is required")
}
valueStr := fmt.Sprintf("%v", value)
if valueStr == "" {
return data, fmt.Errorf("ensure value is required")
return jsonStr, fmt.Errorf("ensure value is required")
}
currentStr := current.String()
if isPrefix {
if strings.HasPrefix(currentStr, valueStr) {
return data, nil
return jsonStr, nil
}
return sjson.SetBytes(data, path, valueStr+currentStr)
return sjson.Set(jsonStr, path, valueStr+currentStr)
}
if strings.HasSuffix(currentStr, valueStr) {
return data, nil
return jsonStr, nil
}
return sjson.SetBytes(data, path, currentStr+valueStr)
return sjson.Set(jsonStr, path, currentStr+valueStr)
}
func transformStringValue(data []byte, path string, transform func(string) string) ([]byte, error) {
current := gjson.GetBytes(data, path)
func transformStringValue(jsonStr, path string, transform func(string) string) (string, error) {
current := gjson.Get(jsonStr, path)
if current.Type != gjson.String {
return data, fmt.Errorf("operation not supported for type: %v", current.Type)
return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
}
return sjson.SetBytes(data, path, transform(current.String()))
return sjson.Set(jsonStr, path, transform(current.String()))
}
func replaceStringValue(data []byte, path, from, to string) ([]byte, error) {
current := gjson.GetBytes(data, path)
func replaceStringValue(jsonStr, path, from, to string) (string, error) {
current := gjson.Get(jsonStr, path)
if current.Type != gjson.String {
return data, fmt.Errorf("operation not supported for type: %v", current.Type)
return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
}
if from == "" {
return data, fmt.Errorf("replace from is required")
return jsonStr, fmt.Errorf("replace from is required")
}
return sjson.SetBytes(data, path, strings.ReplaceAll(current.String(), from, to))
return sjson.Set(jsonStr, path, strings.ReplaceAll(current.String(), from, to))
}
func regexReplaceStringValue(data []byte, path, pattern, replacement string) ([]byte, error) {
current := gjson.GetBytes(data, path)
func regexReplaceStringValue(jsonStr, path, pattern, replacement string) (string, error) {
current := gjson.Get(jsonStr, path)
if current.Type != gjson.String {
return data, fmt.Errorf("operation not supported for type: %v", current.Type)
return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
}
if pattern == "" {
return data, fmt.Errorf("regex pattern is required")
return jsonStr, fmt.Errorf("regex pattern is required")
}
re, err := regexp.Compile(pattern)
if err != nil {
return data, err
return jsonStr, err
}
return sjson.SetBytes(data, path, re.ReplaceAllString(current.String(), replacement))
return sjson.Set(jsonStr, path, re.ReplaceAllString(current.String(), replacement))
}
type pruneObjectsOptions struct {
@@ -1789,33 +1750,37 @@ type pruneObjectsOptions struct {
recursive bool
}
func pruneObjects(data []byte, path, contextJSON string, value interface{}) ([]byte, error) {
func pruneObjects(jsonStr, path, contextJSON string, value interface{}) (string, error) {
options, err := parsePruneObjectsOptions(value)
if err != nil {
return nil, err
return "", err
}
if path == "" {
var root interface{}
if err := common.Unmarshal(data, &root); err != nil {
return nil, err
if err := common.Unmarshal([]byte(jsonStr), &root); err != nil {
return "", err
}
cleaned, _, err := pruneObjectsNode(root, options, contextJSON, true)
if err != nil {
return nil, err
return "", err
}
return common.Marshal(cleaned)
cleanedBytes, err := common.Marshal(cleaned)
if err != nil {
return "", err
}
return string(cleanedBytes), nil
}
target := gjson.GetBytes(data, path)
target := gjson.Get(jsonStr, path)
if !target.Exists() {
return data, nil
return jsonStr, nil
}
var targetNode interface{}
if target.Type == gjson.JSON {
if err := common.UnmarshalJsonStr(target.Raw, &targetNode); err != nil {
return nil, err
if err := common.Unmarshal([]byte(target.Raw), &targetNode); err != nil {
return "", err
}
} else {
targetNode = target.Value()
@@ -1823,13 +1788,13 @@ func pruneObjects(data []byte, path, contextJSON string, value interface{}) ([]b
cleaned, _, err := pruneObjectsNode(targetNode, options, contextJSON, true)
if err != nil {
return nil, err
return "", err
}
cleanedBytes, err := common.Marshal(cleaned)
if err != nil {
return nil, err
return "", err
}
return sjson.SetRawBytes(data, path, cleanedBytes)
return sjson.SetRaw(jsonStr, path, string(cleanedBytes))
}
func parsePruneObjectsOptions(value interface{}) (pruneObjectsOptions, error) {
@@ -2005,16 +1970,16 @@ func shouldPruneObject(node map[string]interface{}, options pruneObjectsOptions,
if err != nil {
return false, err
}
return checkConditions(nodeBytes, contextJSON, options.conditions, options.logic)
return checkConditions(string(nodeBytes), contextJSON, options.conditions, options.logic)
}
func mergeObjects(data []byte, path string, value interface{}, keepOrigin bool) ([]byte, error) {
current := gjson.GetBytes(data, path)
func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (string, error) {
current := gjson.Get(jsonStr, path)
var currentMap, newMap map[string]interface{}
// 解析当前值current.Raw 是 data 的子串,避免再分配一份)
if err := common.UnmarshalJsonStr(current.Raw, &currentMap); err != nil {
return nil, err
// 解析当前值
if err := common.Unmarshal([]byte(current.Raw), &currentMap); err != nil {
return "", err
}
// 解析新值
switch v := value.(type) {
@@ -2023,7 +1988,7 @@ func mergeObjects(data []byte, path string, value interface{}, keepOrigin bool)
default:
jsonBytes, _ := common.Marshal(v)
if err := common.Unmarshal(jsonBytes, &newMap); err != nil {
return nil, err
return "", err
}
}
// 合并
@@ -2036,7 +2001,7 @@ func mergeObjects(data []byte, path string, value interface{}, keepOrigin bool)
result[k] = v
}
}
return sjson.SetBytes(data, path, result)
return sjson.Set(jsonStr, path, result)
}
// BuildParamOverrideContext 提供 ApplyParamOverride 可用的上下文信息。
-11
View File
@@ -2054,17 +2054,6 @@ func TestRemoveDisabledFieldsDefaultFiltering(t *testing.T) {
assertJSONEqual(t, `{"cache_control":{"type":"ephemeral"},"store":true}`, string(out))
}
func TestRemoveDisabledFieldsNoControlledFieldsKeepsBody(t *testing.T) {
input := `{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}]}`
settings := dto.ChannelOtherSettings{}
out, err := RemoveDisabledFields([]byte(input), settings, false)
if err != nil {
t.Fatalf("RemoveDisabledFields returned error: %v", err)
}
require.Equal(t, input, string(out))
}
func TestRemoveDisabledFieldsAllowInferenceGeo(t *testing.T) {
input := `{
"inference_geo":"eu",
-30
View File
@@ -18,7 +18,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/tidwall/gjson"
)
type ThinkingContentInfo struct {
@@ -154,13 +153,6 @@ type RelayInfo struct {
UseRuntimeHeadersOverride bool
ParamOverrideAudit []string
// UpstreamRequestBodySize is the byte size of the marshaled upstream request
// body. It is set when the body is wrapped in a BodyStorage (see
// relay/common/outbound_body.go), so that DoApiRequest can populate
// http.Request.ContentLength manually (net/http only auto-detects it for
// *bytes.Reader/Buffer/strings.Reader). 0 means "let net/http decide".
UpstreamRequestBodySize int64
PriceData types.PriceData
// TieredBillingSnapshot is a frozen snapshot of tiered billing rules
@@ -793,9 +785,6 @@ func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOther
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || channelPassThroughEnabled {
return jsonData, nil
}
if !hasRemovableDisabledField(jsonData, channelOtherSettings) {
return jsonData, nil
}
var data map[string]interface{}
if err := common.Unmarshal(jsonData, &data); err != nil {
@@ -862,25 +851,6 @@ func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOther
return jsonDataAfter, nil
}
func hasRemovableDisabledField(jsonData []byte, channelOtherSettings dto.ChannelOtherSettings) bool {
values := gjson.GetManyBytes(
jsonData,
"service_tier",
"inference_geo",
"speed",
"store",
"safety_identifier",
"stream_options.include_obfuscation",
)
return (!channelOtherSettings.AllowServiceTier && values[0].Exists()) ||
(!channelOtherSettings.AllowInferenceGeo && values[1].Exists()) ||
(!channelOtherSettings.AllowSpeed && values[2].Exists()) ||
(channelOtherSettings.DisableStore && values[3].Exists()) ||
(!channelOtherSettings.AllowSafetyIdentifier && values[4].Exists()) ||
(!channelOtherSettings.AllowIncludeObfuscation && values[5].Exists())
}
// RemoveGeminiDisabledFields removes disabled fields from Gemini request JSON data
// Currently supports removing functionResponse.id field which Vertex AI does not support
func RemoveGeminiDisabledFields(jsonData []byte) ([]byte, error) {
+2 -8
View File
@@ -1,6 +1,7 @@
package relay
import (
"bytes"
"fmt"
"io"
"net/http"
@@ -175,14 +176,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
logger.LogDebug(c, "text request body: %s", jsonData)
body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
defer closer.Close()
jsonData = nil
info.UpstreamRequestBodySize = size
requestBody = body
requestBody = bytes.NewBuffer(jsonData)
}
var httpResp *http.Response
+2 -8
View File
@@ -1,6 +1,7 @@
package relay
import (
"bytes"
"fmt"
"io"
"net/http"
@@ -58,14 +59,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
}
logger.LogDebug(c, "converted embedding request body: %s", jsonData)
body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
defer closer.Close()
jsonData = nil
info.UpstreamRequestBodySize = size
var requestBody io.Reader = body
var requestBody io.Reader = bytes.NewBuffer(jsonData)
statusCodeMappingStr := c.GetString("status_code_mapping")
resp, err := adaptor.DoRequest(c, info, requestBody)
if err != nil {
+3 -16
View File
@@ -1,6 +1,7 @@
package relay
import (
"bytes"
"fmt"
"io"
"net/http"
@@ -164,14 +165,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
logger.LogDebug(c, "Gemini request body: %s", jsonData)
body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
defer closer.Close()
jsonData = nil
info.UpstreamRequestBodySize = size
requestBody = body
requestBody = bytes.NewReader(jsonData)
}
resp, err := adaptor.DoRequest(c, info, requestBody)
@@ -269,14 +263,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI
}
}
logger.LogDebug(c, "Gemini embedding request body: %s", jsonData)
body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
defer closer.Close()
jsonData = nil
info.UpstreamRequestBodySize = size
requestBody = body
requestBody = bytes.NewReader(jsonData)
resp, err := adaptor.DoRequest(c, info, requestBody)
if err != nil {
-71
View File
@@ -1,71 +0,0 @@
package helper
import (
"bytes"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/QuantumNous/new-api/common"
relayconstant "github.com/QuantumNous/new-api/relay/constant"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// TestGetAndValidOpenAIImageRequestMultipartStream verifies multipart image
// edit parsing: the stream field is parsed and validated, and the request body
// stays replayable for the upstream request.
func TestGetAndValidOpenAIImageRequestMultipartStream(t *testing.T) {
gin.SetMode(gin.TestMode)
newContext := func(t *testing.T, streamValue string, withImage bool) (*gin.Context, string) {
var body bytes.Buffer
writer := multipart.NewWriter(&body)
require.NoError(t, writer.WriteField("model", "gpt-image-1"))
require.NoError(t, writer.WriteField("prompt", "edit this image"))
require.NoError(t, writer.WriteField("stream", streamValue))
if withImage {
part, err := writer.CreateFormFile("image", "input.png")
require.NoError(t, err)
_, err = part.Write([]byte("fake image"))
require.NoError(t, err)
}
require.NoError(t, writer.Close())
originalBody := body.String()
c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/edits", &body)
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
return c, originalBody
}
t.Run("valid stream value keeps body replayable", func(t *testing.T) {
c, originalBody := newContext(t, "true", true)
req, err := GetAndValidOpenAIImageRequest(c, relayconstant.RelayModeImagesEdits)
require.NoError(t, err)
require.NotNil(t, req.Stream)
require.True(t, *req.Stream)
require.True(t, req.IsStream(c))
bodyAfterValidation, err := io.ReadAll(c.Request.Body)
require.NoError(t, err)
require.Equal(t, originalBody, string(bodyAfterValidation))
form, err := common.ParseMultipartFormReusable(c)
require.NoError(t, err)
require.Equal(t, "true", url.Values(form.Value).Get("stream"))
require.Len(t, form.File["image"], 1)
})
t.Run("invalid stream value is rejected", func(t *testing.T) {
c, _ := newContext(t, "notabool", false)
_, err := GetAndValidOpenAIImageRequest(c, relayconstant.RelayModeImagesEdits)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid stream value")
})
}
+4 -9
View File
@@ -22,8 +22,8 @@ import (
)
const (
InitialScannerBufferSize = 64 << 10 // 64KB (64*1024)
DefaultMaxScannerBufferSize = 128 << 20 // 64MB (64*1024*1024) default SSE buffer size
InitialScannerBufferSize = 64 << 10 // 64KB (64*1024)
DefaultMaxScannerBufferSize = 64 << 20 // 64MB (64*1024*1024) default SSE buffer size
DefaultPingInterval = 10 * time.Second
)
@@ -34,12 +34,6 @@ func getScannerBufferSize() int {
return DefaultMaxScannerBufferSize
}
func NewStreamScanner(reader io.Reader) *bufio.Scanner {
scanner := bufio.NewScanner(reader)
scanner.Buffer(make([]byte, InitialScannerBufferSize), getScannerBufferSize())
return scanner
}
func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string, sr *StreamResult)) {
if resp == nil || dataHandler == nil {
@@ -60,7 +54,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
var (
stopChan = make(chan bool, 3) // 增加缓冲区避免阻塞
scanner = NewStreamScanner(resp.Body)
scanner = bufio.NewScanner(resp.Body)
ticker = time.NewTicker(streamingTimeout)
pingTicker *time.Ticker
writeMutex sync.Mutex // Mutex to protect concurrent writes
@@ -110,6 +104,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
close(stopChan)
}()
scanner.Buffer(make([]byte, InitialScannerBufferSize), getScannerBufferSize())
scanner.Split(bufio.ScanLines)
SetEventStreamHeaders(c)
+2 -19
View File
@@ -1,7 +1,6 @@
package helper
import (
"bufio"
"fmt"
"io"
"net/http"
@@ -82,22 +81,6 @@ func TestStreamScannerHandler_NilInputs(t *testing.T) {
StreamScannerHandler(c, &http.Response{Body: io.NopCloser(strings.NewReader(""))}, info, nil)
}
func TestNewStreamScanner_AllowsLargeStreamLine(t *testing.T) {
oldBufferMB := constant.StreamScannerMaxBufferMB
constant.StreamScannerMaxBufferMB = 1
t.Cleanup(func() {
constant.StreamScannerMaxBufferMB = oldBufferMB
})
payload := strings.Repeat("x", 128<<10)
scanner := NewStreamScanner(strings.NewReader("data: " + payload + "\n"))
scanner.Split(bufio.ScanLines)
require.True(t, scanner.Scan())
assert.Equal(t, "data: "+payload, scanner.Text())
require.NoError(t, scanner.Err())
}
func TestStreamScannerHandler_EmptyBody(t *testing.T) {
t.Parallel()
@@ -631,7 +614,7 @@ func TestStreamScannerHandler_StreamStatus_InitializedIfNil(t *testing.T) {
assert.NotNil(t, info.StreamStatus)
}
func TestStreamScannerHandler_StreamStatus_ReplacesPreInitialized(t *testing.T) {
func TestStreamScannerHandler_StreamStatus_PreInitialized(t *testing.T) {
t.Parallel()
body := buildSSEBody(5)
@@ -643,7 +626,7 @@ func TestStreamScannerHandler_StreamStatus_ReplacesPreInitialized(t *testing.T)
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {})
assert.Equal(t, relaycommon.StreamEndReasonDone, info.StreamStatus.EndReason)
assert.Equal(t, 0, info.StreamStatus.TotalErrorCount())
assert.Equal(t, 1, info.StreamStatus.TotalErrorCount())
}
func TestStreamScannerHandler_PingInterleavesWithSlowUpstream(t *testing.T) {
+2 -13
View File
@@ -4,8 +4,6 @@ import (
"errors"
"fmt"
"math"
"net/url"
"strconv"
"strings"
"github.com/QuantumNous/new-api/common"
@@ -146,25 +144,16 @@ func GetAndValidOpenAIImageRequest(c *gin.Context, relayMode int) (*dto.ImageReq
switch relayMode {
case relayconstant.RelayModeImagesEdits:
if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
form, err := common.ParseMultipartFormReusable(c)
_, err := c.MultipartForm()
if err != nil {
return nil, fmt.Errorf("failed to parse image edit form request: %w", err)
}
formData := url.Values(form.Value)
c.Request.MultipartForm = form
c.Request.PostForm = formData
formData := c.Request.PostForm
imageRequest.Prompt = formData.Get("prompt")
imageRequest.Model = formData.Get("model")
imageRequest.N = common.GetPointer(uint(common.String2Int(formData.Get("n"))))
imageRequest.Quality = formData.Get("quality")
imageRequest.Size = formData.Get("size")
if streamValue := strings.TrimSpace(formData.Get("stream")); streamValue != "" {
stream, err := strconv.ParseBool(streamValue)
if err != nil {
return nil, fmt.Errorf("invalid stream value: %w", err)
}
imageRequest.Stream = common.GetPointer(stream)
}
if imageValue := formData.Get("image"); imageValue != "" {
imageRequest.Image, _ = common.Marshal(imageValue)
}
+4 -11
View File
@@ -77,14 +77,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
}
logger.LogDebug(c, "image request body: %s", jsonData)
body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
defer closer.Close()
jsonData = nil
info.UpstreamRequestBodySize = size
requestBody = body
requestBody = bytes.NewBuffer(jsonData)
}
}
@@ -140,9 +133,9 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
usage.(*dto.Usage).PromptTokens = 1
}
quality := request.Quality
if quality == "" {
quality = "standard"
quality := "standard"
if request.Quality == "hd" {
quality = "hd"
}
var logContent []string
+2 -8
View File
@@ -1,6 +1,7 @@
package relay
import (
"bytes"
"fmt"
"io"
"net/http"
@@ -68,14 +69,7 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
}
logger.LogDebug(c, "Rerank request body: %s", jsonData)
body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
defer closer.Close()
jsonData = nil
info.UpstreamRequestBodySize = size
requestBody = body
requestBody = bytes.NewBuffer(jsonData)
}
resp, err := adaptor.DoRequest(c, info, requestBody)
+2 -8
View File
@@ -1,6 +1,7 @@
package relay
import (
"bytes"
"fmt"
"io"
"net/http"
@@ -103,14 +104,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
}
logger.LogDebug(c, "requestBody: %s", jsonData)
body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
defer closer.Close()
jsonData = nil
info.UpstreamRequestBodySize = size
requestBody = body
requestBody = bytes.NewBuffer(jsonData)
}
var httpResp *http.Response
+18 -28
View File
@@ -17,10 +17,9 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.Use(gzip.Gzip(gzip.DefaultCompression))
apiRouter.Use(middleware.BodyStorageCleanup()) // 清理请求体存储
apiRouter.Use(middleware.GlobalAPIRateLimit())
anonymousRequestBodyLimit := middleware.AnonymousRequestBodyLimit()
{
apiRouter.GET("/setup", controller.GetSetup)
apiRouter.POST("/setup", anonymousRequestBodyLimit, controller.PostSetup)
apiRouter.POST("/setup", controller.PostSetup)
apiRouter.GET("/status", controller.GetStatus)
apiRouter.GET("/uptime/status", controller.GetUptimeKumaStatus)
apiRouter.GET("/models", middleware.UserAuth(), controller.DashboardListModels)
@@ -41,39 +40,37 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.GET("/rankings", middleware.HeaderNavModuleAuth("rankings"), controller.GetRankings)
apiRouter.GET("/verification", middleware.EmailVerificationRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification)
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), anonymousRequestBodyLimit, controller.ResetPassword)
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
// OAuth routes - specific routes must come before :provider wildcard
apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode)
apiRouter.POST("/oauth/email/bind", middleware.CriticalRateLimit(), anonymousRequestBodyLimit, controller.EmailBind)
apiRouter.POST("/oauth/email/bind", middleware.CriticalRateLimit(), controller.EmailBind)
// Non-standard OAuth (WeChat, Telegram) - keep original routes
apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)
apiRouter.POST("/oauth/wechat/bind", middleware.CriticalRateLimit(), anonymousRequestBodyLimit, controller.WeChatBind)
apiRouter.POST("/oauth/wechat/bind", middleware.CriticalRateLimit(), controller.WeChatBind)
apiRouter.GET("/oauth/telegram/login", middleware.CriticalRateLimit(), controller.TelegramLogin)
apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), controller.TelegramBind)
// Standard OAuth providers (GitHub, Discord, OIDC, LinuxDO) - unified route
apiRouter.GET("/oauth/:provider", middleware.CriticalRateLimit(), controller.HandleOAuth)
apiRouter.GET("/ratio_config", middleware.CriticalRateLimit(), controller.GetRatioConfig)
apiRouter.POST("/stripe/webhook", anonymousRequestBodyLimit, controller.StripeWebhook)
apiRouter.POST("/creem/webhook", anonymousRequestBodyLimit, controller.CreemWebhook)
apiRouter.POST("/waffo/webhook", anonymousRequestBodyLimit, controller.WaffoWebhook)
// :env separates test vs prod URLs so the operator can register each
// in Pancake's matching webhook slot; handler enforces env match.
apiRouter.POST("/waffo-pancake/webhook/:env", anonymousRequestBodyLimit, controller.WaffoPancakeWebhook)
apiRouter.POST("/stripe/webhook", controller.StripeWebhook)
apiRouter.POST("/creem/webhook", controller.CreemWebhook)
apiRouter.POST("/waffo/webhook", controller.WaffoWebhook)
//apiRouter.POST("/waffo-pancake/webhook", controller.WaffoPancakeWebhook)
// Universal secure verification routes
apiRouter.POST("/verify", middleware.UserAuth(), middleware.CriticalRateLimit(), controller.UniversalVerify)
userRoute := apiRouter.Group("/user")
{
userRoute.POST("/register", middleware.CriticalRateLimit(), anonymousRequestBodyLimit, middleware.TurnstileCheck(), controller.Register)
userRoute.POST("/login", middleware.CriticalRateLimit(), anonymousRequestBodyLimit, middleware.TurnstileCheck(), controller.Login)
userRoute.POST("/login/2fa", middleware.CriticalRateLimit(), anonymousRequestBodyLimit, controller.Verify2FALogin)
userRoute.POST("/passkey/login/begin", middleware.CriticalRateLimit(), anonymousRequestBodyLimit, controller.PasskeyLoginBegin)
userRoute.POST("/passkey/login/finish", middleware.CriticalRateLimit(), anonymousRequestBodyLimit, controller.PasskeyLoginFinish)
userRoute.POST("/register", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Register)
userRoute.POST("/login", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Login)
userRoute.POST("/login/2fa", middleware.CriticalRateLimit(), controller.Verify2FALogin)
userRoute.POST("/passkey/login/begin", middleware.CriticalRateLimit(), controller.PasskeyLoginBegin)
userRoute.POST("/passkey/login/finish", middleware.CriticalRateLimit(), controller.PasskeyLoginFinish)
//userRoute.POST("/tokenlog", middleware.CriticalRateLimit(), controller.TokenLog)
userRoute.GET("/logout", controller.Logout)
userRoute.POST("/epay/notify", anonymousRequestBodyLimit, controller.EpayNotify)
userRoute.POST("/epay/notify", controller.EpayNotify)
userRoute.GET("/epay/notify", controller.EpayNotify)
userRoute.GET("/groups", controller.GetUserGroups)
@@ -103,8 +100,8 @@ func SetApiRouter(router *gin.Engine) {
selfRoute.POST("/creem/pay", middleware.CriticalRateLimit(), controller.RequestCreemPay)
selfRoute.POST("/waffo/amount", controller.RequestWaffoAmount)
selfRoute.POST("/waffo/pay", middleware.CriticalRateLimit(), controller.RequestWaffoPay)
selfRoute.POST("/waffo-pancake/amount", controller.RequestWaffoPancakeAmount)
selfRoute.POST("/waffo-pancake/pay", middleware.CriticalRateLimit(), controller.RequestWaffoPancakePay)
//selfRoute.POST("/waffo-pancake/amount", controller.RequestWaffoPancakeAmount)
//selfRoute.POST("/waffo-pancake/pay", middleware.CriticalRateLimit(), controller.RequestWaffoPancakePay)
selfRoute.POST("/aff_transfer", controller.TransferAffQuota)
selfRoute.PUT("/setting", controller.UpdateUserSetting)
@@ -154,11 +151,9 @@ func SetApiRouter(router *gin.Engine) {
subscriptionRoute.GET("/plans", controller.GetSubscriptionPlans)
subscriptionRoute.GET("/self", controller.GetSubscriptionSelf)
subscriptionRoute.PUT("/self/preference", controller.UpdateSubscriptionPreference)
subscriptionRoute.POST("/balance/pay", middleware.CriticalRateLimit(), controller.SubscriptionRequestBalancePay)
subscriptionRoute.POST("/epay/pay", middleware.CriticalRateLimit(), controller.SubscriptionRequestEpay)
subscriptionRoute.POST("/stripe/pay", middleware.CriticalRateLimit(), controller.SubscriptionRequestStripePay)
subscriptionRoute.POST("/creem/pay", middleware.CriticalRateLimit(), controller.SubscriptionRequestCreemPay)
subscriptionRoute.POST("/waffo-pancake/pay", middleware.CriticalRateLimit(), controller.SubscriptionRequestWaffoPancakePay)
}
subscriptionAdminRoute := apiRouter.Group("/subscription/admin")
subscriptionAdminRoute.Use(middleware.AdminAuth())
@@ -177,10 +172,10 @@ func SetApiRouter(router *gin.Engine) {
}
// Subscription payment callbacks (no auth)
apiRouter.POST("/subscription/epay/notify", anonymousRequestBodyLimit, controller.SubscriptionEpayNotify)
apiRouter.POST("/subscription/epay/notify", controller.SubscriptionEpayNotify)
apiRouter.GET("/subscription/epay/notify", controller.SubscriptionEpayNotify)
apiRouter.GET("/subscription/epay/return", controller.SubscriptionEpayReturn)
apiRouter.POST("/subscription/epay/return", anonymousRequestBodyLimit, controller.SubscriptionEpayReturn)
apiRouter.POST("/subscription/epay/return", controller.SubscriptionEpayReturn)
optionRoute := apiRouter.Group("/option")
optionRoute.Use(middleware.RootAuth())
{
@@ -191,11 +186,6 @@ func SetApiRouter(router *gin.Engine) {
optionRoute.DELETE("/channel_affinity_cache", controller.ClearChannelAffinityCache)
optionRoute.POST("/rest_model_ratio", controller.ResetModelRatio)
optionRoute.POST("/migrate_console_setting", controller.MigrateConsoleSetting) // 用于迁移检测的旧键,下个版本会删除
optionRoute.POST("/waffo-pancake/catalog", controller.ListWaffoPancakeCatalog)
optionRoute.POST("/waffo-pancake/pair", controller.CreateWaffoPancakePair)
optionRoute.POST("/waffo-pancake/save", controller.SaveWaffoPancake)
optionRoute.POST("/waffo-pancake/subscription-product", controller.CreateWaffoPancakeSubscriptionProduct)
optionRoute.POST("/waffo-pancake/subscription-product-options", controller.ListWaffoPancakeSubscriptionProductOptions)
}
// Custom OAuth provider management (root only)
+1 -1
View File
@@ -17,7 +17,7 @@ func formatNotifyType(channelId int, status int) string {
// disable & notify
func DisableChannel(channelError types.ChannelError, reason string) {
common.SysLog(fmt.Sprintf("通道「%s」(#%d)发生错误,准备禁用,原因:%s", channelError.ChannelName, channelError.ChannelId, common.LocalLogPreview(reason)))
common.SysLog(fmt.Sprintf("通道「%s」(#%d)发生错误,准备禁用,原因:%s", channelError.ChannelName, channelError.ChannelId, reason))
// 检查是否启用自动禁用功能
if !channelError.AutoBan {
-32
View File
@@ -641,38 +641,6 @@ func ShouldSkipRetryAfterChannelAffinityFailure(c *gin.Context) bool {
return meta.SkipRetry
}
func ClearCurrentChannelAffinityCache(c *gin.Context) bool {
if c == nil {
return false
}
cacheKey, _, ok := getChannelAffinityContext(c)
if !ok || cacheKey == "" {
return false
}
cache := getChannelAffinityCache()
deleted, err := cache.DeleteMany([]string{cacheKey})
if err != nil {
common.SysError(fmt.Sprintf("channel affinity cache delete current failed: err=%v", err))
return false
}
c.Set(ginKeyChannelAffinitySkipRetry, false)
for _, ok := range deleted {
if ok {
return true
}
}
return false
}
func ShouldKeepChannelAffinityOnChannelDisabled() bool {
setting := operation_setting.GetChannelAffinitySetting()
if setting == nil {
return false
}
return setting.KeepOnChannelDisabled
}
func MarkChannelAffinityUsed(c *gin.Context, selectedGroup string, channelID int) {
if c == nil || channelID <= 0 {
return
-27
View File
@@ -236,33 +236,6 @@ func TestGetPreferredChannelByAffinity_RequestHeaderKeySource(t *testing.T) {
require.Equal(t, buildChannelAffinityKeyHint(affinityValue), meta.KeyHint)
}
func TestClearCurrentChannelAffinityCache(t *testing.T) {
gin.SetMode(gin.TestMode)
cacheKeySuffix := fmt.Sprintf("codex cli trace:default:clear-current-%d", time.Now().UnixNano())
cacheKeyFull := channelAffinityCacheNamespace + ":" + cacheKeySuffix
cache := getChannelAffinityCache()
require.NoError(t, cache.SetWithTTL(cacheKeySuffix, 9527, time.Minute))
t.Cleanup(func() {
_, _ = cache.DeleteMany([]string{cacheKeySuffix})
})
ctx := buildChannelAffinityTemplateContextForTest(channelAffinityMeta{
CacheKey: cacheKeyFull,
TTLSeconds: 60,
RuleName: "codex cli trace",
SkipRetry: true,
})
require.True(t, ShouldSkipRetryAfterChannelAffinityFailure(ctx))
deleted := ClearCurrentChannelAffinityCache(ctx)
require.True(t, deleted)
_, found, err := cache.Get(cacheKeySuffix)
require.NoError(t, err)
require.False(t, found)
require.False(t, ShouldSkipRetryAfterChannelAffinityFailure(ctx))
}
func TestChannelAffinityHitCodexTemplatePassHeadersEffective(t *testing.T) {
gin.SetMode(gin.TestMode)
+3 -5
View File
@@ -92,13 +92,11 @@ func RelayErrorHandler(ctx context.Context, resp *http.Response, showBodyWhenFai
}
CloseResponseBodyGracefully(resp)
var errResponse dto.GeneralErrorResponse
responseBodyText := string(responseBody)
responseBodyPreview := common.LocalLogPreview(responseBodyText)
buildErrWithBody := func(message string) error {
if message == "" {
return fmt.Errorf("bad response status code %d, body: %s", resp.StatusCode, responseBodyText)
return fmt.Errorf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody))
}
return fmt.Errorf("bad response status code %d, message: %s, body: %s", resp.StatusCode, message, responseBodyText)
return fmt.Errorf("bad response status code %d, message: %s, body: %s", resp.StatusCode, message, string(responseBody))
}
err = common.Unmarshal(responseBody, &errResponse)
@@ -106,7 +104,7 @@ func RelayErrorHandler(ctx context.Context, resp *http.Response, showBodyWhenFai
if showBodyWhenFail {
newApiErr.Err = buildErrWithBody("")
} else {
logger.LogError(ctx, fmt.Sprintf("bad response status code %d, body: %s", resp.StatusCode, responseBodyPreview))
logger.LogError(ctx, fmt.Sprintf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody)))
newApiErr.Err = fmt.Errorf("bad response status code %d", resp.StatusCode)
}
return
-104
View File
@@ -1,17 +1,9 @@
package service
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"strings"
"testing"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
@@ -63,99 +55,3 @@ func TestResetStatusCode(t *testing.T) {
})
}
}
func TestRelayErrorHandlerTruncatesInvalidJSONBodyInLog(t *testing.T) {
withDebugEnabled(t, false)
body := strings.Repeat("b", common.LocalLogContentLimit+256)
var logBuffer bytes.Buffer
common.LogWriterMu.Lock()
oldWriter := gin.DefaultErrorWriter
gin.DefaultErrorWriter = &logBuffer
common.LogWriterMu.Unlock()
t.Cleanup(func() {
common.LogWriterMu.Lock()
gin.DefaultErrorWriter = oldWriter
common.LogWriterMu.Unlock()
})
resp := &http.Response{
StatusCode: http.StatusInternalServerError,
Body: io.NopCloser(strings.NewReader(body)),
}
newAPIError := RelayErrorHandler(context.Background(), resp, false)
require.NotNil(t, newAPIError)
require.Equal(t, "bad response status code 500", newAPIError.Error())
require.Contains(t, logBuffer.String(), "[truncated")
require.Contains(t, logBuffer.String(), fmt.Sprintf("original_length=%d", len(body)))
require.NotContains(t, logBuffer.String(), strings.Repeat("b", common.LocalLogContentLimit+1))
}
func TestRelayErrorHandlerKeepsStructuredErrorMessage(t *testing.T) {
message := strings.Repeat("c", common.LocalLogContentLimit+256)
body := `{"message":"` + message + `"}`
resp := &http.Response{
StatusCode: http.StatusInternalServerError,
Body: io.NopCloser(strings.NewReader(body)),
}
newAPIError := RelayErrorHandler(context.Background(), resp, false)
require.NotNil(t, newAPIError)
require.Equal(t, message, newAPIError.Error())
}
func TestRelayErrorHandlerKeepsOpenAIErrorMessage(t *testing.T) {
message := strings.Repeat("d", common.LocalLogContentLimit+256)
body := `{"error":{"message":"` + message + `","type":"server_error","code":"server_error"}}`
resp := &http.Response{
StatusCode: http.StatusInternalServerError,
Body: io.NopCloser(strings.NewReader(body)),
}
newAPIError := RelayErrorHandler(context.Background(), resp, false)
require.NotNil(t, newAPIError)
require.Equal(t, message, newAPIError.Error())
}
func TestRelayErrorHandlerKeepsInvalidJSONBodyInDebugLog(t *testing.T) {
withDebugEnabled(t, true)
body := strings.Repeat("e", common.LocalLogContentLimit+256)
var logBuffer bytes.Buffer
common.LogWriterMu.Lock()
oldWriter := gin.DefaultErrorWriter
gin.DefaultErrorWriter = &logBuffer
common.LogWriterMu.Unlock()
t.Cleanup(func() {
common.LogWriterMu.Lock()
gin.DefaultErrorWriter = oldWriter
common.LogWriterMu.Unlock()
})
resp := &http.Response{
StatusCode: http.StatusInternalServerError,
Body: io.NopCloser(strings.NewReader(body)),
}
newAPIError := RelayErrorHandler(context.Background(), resp, false)
require.NotNil(t, newAPIError)
require.NotContains(t, logBuffer.String(), "[truncated")
require.Contains(t, logBuffer.String(), body)
}
func withDebugEnabled(t *testing.T, enabled bool) {
t.Helper()
oldDebug := common.DebugEnabled
common.DebugEnabled = enabled
t.Cleanup(func() {
common.DebugEnabled = oldDebug
})
}
-3
View File
@@ -37,7 +37,6 @@ func InitHttpClient() {
transport := &http.Transport{
MaxIdleConns: common.RelayMaxIdleConns,
MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost,
IdleConnTimeout: time.Duration(common.RelayIdleConnTimeout) * time.Second,
ForceAttemptHTTP2: true,
Proxy: http.ProxyFromEnvironment, // Support HTTP_PROXY, HTTPS_PROXY, NO_PROXY env vars
}
@@ -109,7 +108,6 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
transport := &http.Transport{
MaxIdleConns: common.RelayMaxIdleConns,
MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost,
IdleConnTimeout: time.Duration(common.RelayIdleConnTimeout) * time.Second,
ForceAttemptHTTP2: true,
Proxy: http.ProxyURL(parsedURL),
}
@@ -149,7 +147,6 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
transport := &http.Transport{
MaxIdleConns: common.RelayMaxIdleConns,
MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost,
IdleConnTimeout: time.Duration(common.RelayIdleConnTimeout) * time.Second,
ForceAttemptHTTP2: true,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
+321 -406
View File
@@ -1,483 +1,398 @@
package service
import (
"bytes"
"context"
"crypto"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"fmt"
"io"
"math"
"net/http"
"strconv"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting"
pancake "github.com/waffo-com/waffo-pancake-sdk-go"
)
// WaffoPancakePriceSnapshot is the per-session price override sent with checkout.
const (
waffoPancakeAuthBaseURL = "https://waffo-pancake-auth-service.vercel.app"
waffoPancakeCheckoutPath = "/v1/actions/checkout/create-session"
waffoPancakeDefaultTolerance = 5 * time.Minute
)
type WaffoPancakePriceSnapshot struct {
Amount string
TaxCategory string
Amount string `json:"amount"`
TaxIncluded bool `json:"taxIncluded"`
TaxCategory string `json:"taxCategory"`
}
// WaffoPancakeCreateSessionParams is the input to CreateWaffoPancakeCheckoutSession.
// BuyerIdentity must be stable per user (see WaffoPancakeBuyerIdentityFromUserID).
// OrderMerchantExternalID = our trade_no; Pancake echoes it back in webhooks.
type WaffoPancakeCreateSessionParams struct {
ProductID string
BuyerIdentity string
PriceSnapshot *WaffoPancakePriceSnapshot
BuyerEmail string
ExpiresInSeconds *int
OrderMerchantExternalID string
StoreID string `json:"storeId"`
ProductID string `json:"productId"`
ProductType string `json:"productType"`
Currency string `json:"currency"`
PriceSnapshot *WaffoPancakePriceSnapshot `json:"priceSnapshot,omitempty"`
BuyerEmail string `json:"buyerEmail,omitempty"`
SuccessURL string `json:"successUrl,omitempty"`
ExpiresInSeconds *int `json:"expiresInSeconds,omitempty"`
}
// WaffoPancakeCheckoutSession is the response of CreateWaffoPancakeCheckoutSession.
// CheckoutURL already carries the `#token=...` fragment; Token / TokenExpiresAt
// are exposed separately for self-service flows driven from new-api's own UI.
type WaffoPancakeCheckoutSession struct {
SessionID string
CheckoutURL string
ExpiresAt string
OrderID string
Token string
TokenExpiresAt string
SessionID string `json:"sessionId"`
CheckoutURL string `json:"checkoutUrl"`
ExpiresAt string `json:"expiresAt"`
OrderID string `json:"orderId"`
}
// WaffoPancakeWebhookEvent mirrors the SDK's WebhookEvent shape using plain
// strings so controllers don't have to import the SDK package.
type WaffoPancakeWebhookEvent struct {
ID string
Timestamp string
EventType string
EventID string
StoreID string
Mode string
Data WaffoPancakeWebhookData
type waffoPancakeAPIError struct {
Message string `json:"message"`
Layer string `json:"layer"`
}
type WaffoPancakeWebhookData struct {
// OrderID = Pancake ORD_* (logs); OrderMerchantExternalID = our trade_no (lookup).
OrderID string
OrderMerchantExternalID string
BuyerEmail string
Currency string
Amount string
TaxAmount string
ProductName string
MerchantProvidedBuyerIdentity string
type waffoPancakeCreateSessionResponse struct {
Data *WaffoPancakeCheckoutSession `json:"data"`
Errors []waffoPancakeAPIError `json:"errors"`
}
// NormalizedEventType returns the event type or empty string for a nil event.
func (e *WaffoPancakeWebhookEvent) NormalizedEventType() string {
type waffoPancakeWebhookData struct {
ID string `json:"id"`
OrderID string `json:"orderId"`
BuyerEmail string `json:"buyerEmail"`
Currency string `json:"currency"`
Amount dto.StringValue `json:"amount"`
TaxAmount dto.StringValue `json:"taxAmount"`
ProductName string `json:"productName"`
}
type waffoPancakeWebhookEvent struct {
ID string `json:"id"`
Timestamp string `json:"timestamp"`
EventType string `json:"eventType"`
EventID string `json:"eventId"`
StoreID string `json:"storeId"`
Mode string `json:"mode"`
Data waffoPancakeWebhookData `json:"data"`
}
func (e *waffoPancakeWebhookEvent) NormalizedEventType() string {
if e == nil {
return ""
}
return e.EventType
}
// newWaffoPancakeClient builds an SDK client from persisted settings. The
// runtime checkout / webhook paths use this; configuration endpoints use
// newWaffoPancakeClientFromCreds so the operator can verify typed-but-not-
// yet-saved credentials.
func newWaffoPancakeClient() (*pancake.Client, error) {
return pancake.New(pancake.Config{
MerchantID: setting.WaffoPancakeMerchantID,
PrivateKey: setting.WaffoPancakePrivateKey,
})
}
func newWaffoPancakeClientFromCreds(merchantID, privateKey string) (*pancake.Client, error) {
if strings.TrimSpace(merchantID) == "" || strings.TrimSpace(privateKey) == "" {
return nil, fmt.Errorf("merchant id and private key are required")
}
return pancake.New(pancake.Config{
MerchantID: merchantID,
PrivateKey: privateKey,
})
}
// CreateWaffoPancakeCheckoutSession creates an Authenticated-mode checkout
// session: the order is bound to BuyerIdentity (stable per user) so it stays
// attributable even if the buyer edits the email on Waffo's checkout form.
func CreateWaffoPancakeCheckoutSession(ctx context.Context, params *WaffoPancakeCreateSessionParams) (*WaffoPancakeCheckoutSession, error) {
if params == nil {
return nil, fmt.Errorf("missing checkout params")
}
if strings.TrimSpace(params.BuyerIdentity) == "" {
return nil, fmt.Errorf("missing buyer identity")
}
if strings.TrimSpace(params.OrderMerchantExternalID) == "" {
return nil, fmt.Errorf("missing order merchant external id")
}
client, err := newWaffoPancakeClient()
body, err := common.Marshal(params)
if err != nil {
return nil, fmt.Errorf("build Waffo Pancake client: %w", err)
return nil, fmt.Errorf("marshal Waffo Pancake checkout payload: %w", err)
}
sdkParams := pancake.AuthenticatedCheckoutParams{
CreateCheckoutSessionParams: pancake.CreateCheckoutSessionParams{
ProductID: params.ProductID,
Currency: "USD",
BuyerEmail: optionalString(params.BuyerEmail),
ExpiresInSeconds: params.ExpiresInSeconds,
OrderMerchantExternalID: optionalString(params.OrderMerchantExternalID),
},
BuyerIdentity: params.BuyerIdentity,
privateKey, err := normalizeRSAPrivateKey(setting.WaffoPancakePrivateKey)
if err != nil {
return nil, err
}
if params.PriceSnapshot != nil {
sdkParams.PriceSnapshot = &pancake.PriceInfo{
Amount: params.PriceSnapshot.Amount,
TaxCategory: pancake.TaxCategory(params.PriceSnapshot.TaxCategory),
timestamp := strconv.FormatInt(time.Now().Unix(), 10)
signature, err := signWaffoPancakeRequest(http.MethodPost, waffoPancakeCheckoutPath, timestamp, string(body), privateKey)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, waffoPancakeAuthBaseURL+waffoPancakeCheckoutPath, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("build Waffo Pancake checkout request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-Merchant-Id", setting.WaffoPancakeMerchantID)
req.Header.Set("X-Timestamp", timestamp)
req.Header.Set("X-Signature", signature)
if setting.WaffoPancakeSandbox {
req.Header.Set("X-Environment", "test")
} else {
req.Header.Set("X-Environment", "prod")
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("request Waffo Pancake checkout session: %w", err)
}
defer resp.Body.Close()
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read Waffo Pancake checkout response: %w", err)
}
var result waffoPancakeCreateSessionResponse
if err := common.Unmarshal(responseBody, &result); err != nil {
return nil, fmt.Errorf("decode Waffo Pancake checkout response: %w", err)
}
if resp.StatusCode >= http.StatusBadRequest {
if len(result.Errors) > 0 {
return nil, fmt.Errorf("Waffo Pancake error (%d): %s", resp.StatusCode, result.Errors[0].Message)
}
return nil, fmt.Errorf("Waffo Pancake checkout request failed with status %d", resp.StatusCode)
}
if len(result.Errors) > 0 {
return nil, fmt.Errorf("Waffo Pancake error: %s", result.Errors[0].Message)
}
if result.Data == nil || result.Data.CheckoutURL == "" || strings.TrimSpace(result.Data.SessionID) == "" {
return nil, fmt.Errorf("Waffo Pancake returned empty checkout session")
}
return result.Data, nil
}
func VerifyConfiguredWaffoPancakeWebhook(payload string, signatureHeader string) (*waffoPancakeWebhookEvent, error) {
environment := resolveWaffoPancakeWebhookEnvironment(payload)
return verifyWaffoPancakeWebhook(payload, signatureHeader, environment)
}
func ResolveWaffoPancakeTradeNo(event *waffoPancakeWebhookEvent) (string, error) {
if event == nil {
return "", fmt.Errorf("missing webhook event")
}
if tradeNo := strings.TrimSpace(event.Data.OrderID); tradeNo != "" {
topUp := model.GetTopUpByTradeNo(tradeNo)
if topUp != nil && topUp.PaymentMethod == model.PaymentMethodWaffoPancake {
return tradeNo, nil
}
return "", fmt.Errorf("waffo pancake order not found for webhook orderId=%s", tradeNo)
}
return "", fmt.Errorf("missing webhook orderId")
}
func normalizeRSAPrivateKey(raw string) (string, error) {
return normalizePEMKey(raw, "PRIVATE KEY", "RSA PRIVATE KEY")
}
func normalizeRSAPublicKey(raw string) (string, error) {
return normalizePEMKey(raw, "PUBLIC KEY", "RSA PUBLIC KEY")
}
func normalizePEMKey(raw string, pkcs8Type string, pkcs1Type string) (string, error) {
if strings.TrimSpace(raw) == "" {
return "", fmt.Errorf("%s is empty", strings.ToLower(pkcs8Type))
}
normalized := strings.TrimSpace(strings.ReplaceAll(raw, `\n`, "\n"))
if strings.Contains(normalized, "BEGIN ") {
block, _ := pem.Decode([]byte(normalized))
if block == nil {
return "", fmt.Errorf("invalid PEM encoded %s", strings.ToLower(pkcs8Type))
}
return string(pem.EncodeToMemory(block)), nil
}
der, err := base64.StdEncoding.DecodeString(strings.ReplaceAll(normalized, "\n", ""))
if err != nil {
return "", fmt.Errorf("invalid base64 encoded %s: %w", strings.ToLower(pkcs8Type), err)
}
pemType := pkcs8Type
if pkcs8Type == "PRIVATE KEY" {
if _, err := x509.ParsePKCS8PrivateKey(der); err != nil {
if _, err := x509.ParsePKCS1PrivateKey(der); err == nil {
pemType = pkcs1Type
} else {
return "", fmt.Errorf("invalid RSA private key")
}
}
} else {
if _, err := x509.ParsePKIXPublicKey(der); err != nil {
if _, err := x509.ParsePKCS1PublicKey(der); err == nil {
pemType = pkcs1Type
} else {
return "", fmt.Errorf("invalid RSA public key")
}
}
}
session, err := client.Checkout.Authenticated.Create(ctx, sdkParams)
return string(pem.EncodeToMemory(&pem.Block{Type: pemType, Bytes: der})), nil
}
func signWaffoPancakeRequest(method string, path string, timestamp string, body string, privateKeyPEM string) (string, error) {
block, _ := pem.Decode([]byte(privateKeyPEM))
if block == nil {
return "", fmt.Errorf("invalid RSA private key PEM")
}
var privateKey *rsa.PrivateKey
switch block.Type {
case "PRIVATE KEY":
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
return "", fmt.Errorf("parse PKCS#8 private key: %w", err)
}
parsed, ok := key.(*rsa.PrivateKey)
if !ok {
return "", fmt.Errorf("private key is not RSA")
}
privateKey = parsed
case "RSA PRIVATE KEY":
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return "", fmt.Errorf("parse PKCS#1 private key: %w", err)
}
privateKey = key
default:
return "", fmt.Errorf("unsupported private key type: %s", block.Type)
}
canonicalRequest := buildWaffoPancakeCanonicalRequest(method, path, timestamp, body)
digest := sha256.Sum256([]byte(canonicalRequest))
signature, err := rsa.SignPKCS1v15(nil, privateKey, crypto.SHA256, digest[:])
if err != nil {
return nil, err
return "", fmt.Errorf("sign Waffo Pancake request: %w", err)
}
if session == nil || strings.TrimSpace(session.CheckoutURL) == "" || strings.TrimSpace(session.SessionID) == "" {
return nil, fmt.Errorf("Waffo Pancake returned empty checkout session")
}
return &WaffoPancakeCheckoutSession{
SessionID: session.SessionID,
CheckoutURL: session.CheckoutURL,
ExpiresAt: session.ExpiresAt,
Token: session.Token,
TokenExpiresAt: session.TokenExpiresAt,
}, nil
return base64.StdEncoding.EncodeToString(signature), nil
}
func optionalString(s string) *string {
if strings.TrimSpace(s) == "" {
return nil
func buildWaffoPancakeCanonicalRequest(method string, path string, timestamp string, body string) string {
bodyHash := sha256.Sum256([]byte(body))
return fmt.Sprintf(
"%s\n%s\n%s\n%s",
strings.ToUpper(method),
path,
timestamp,
base64.StdEncoding.EncodeToString(bodyHash[:]),
)
}
func verifyWaffoPancakeWebhook(payload string, signatureHeader string, environment string) (*waffoPancakeWebhookEvent, error) {
if signatureHeader == "" {
return nil, fmt.Errorf("missing X-Waffo-Signature header")
}
v := s
return &v
}
// WaffoPancakeBuyerIdentityFromUserID renders the canonical buyer identity
// for checkout. Webhook handlers compare against the value rendered here to
// reject identity mismatches, so both call sites must use this function.
func WaffoPancakeBuyerIdentityFromUserID(userID int) string {
return fmt.Sprintf("new-api-user-%d", userID)
}
timestampPart, signaturePart := parseWaffoPancakeSignatureHeader(signatureHeader)
if timestampPart == "" || signaturePart == "" {
return nil, fmt.Errorf("malformed X-Waffo-Signature header")
}
// VerifyConfiguredWaffoPancakeWebhook verifies the signature header. The SDK
// picks the matching test / prod public key from the payload's `mode` field.
func VerifyConfiguredWaffoPancakeWebhook(payload string, signatureHeader string) (*WaffoPancakeWebhookEvent, error) {
evt, err := pancake.VerifyWebhookTyped[pancake.WebhookEventData](payload, signatureHeader, nil)
timestampMs, err := strconv.ParseInt(timestampPart, 10, 64)
if err != nil {
return nil, err
return nil, fmt.Errorf("invalid timestamp in X-Waffo-Signature header")
}
identity := ""
if evt.Data.MerchantProvidedBuyerIdentity != nil {
identity = *evt.Data.MerchantProvidedBuyerIdentity
if math.Abs(float64(time.Now().UnixMilli()-timestampMs)) > float64(waffoPancakeDefaultTolerance.Milliseconds()) {
return nil, fmt.Errorf("webhook timestamp outside tolerance window")
}
externalID := ""
if evt.Data.OrderMerchantExternalID != nil {
externalID = *evt.Data.OrderMerchantExternalID
signatureInput := fmt.Sprintf("%s.%s", timestampPart, payload)
if err := verifyWaffoPancakeWebhookWithKey(signatureInput, signaturePart, resolveWaffoPancakeWebhookPublicKey(environment)); err != nil {
return nil, fmt.Errorf("invalid webhook signature")
}
return &WaffoPancakeWebhookEvent{
ID: evt.ID,
Timestamp: evt.Timestamp,
EventType: evt.EventType,
EventID: evt.EventID,
StoreID: evt.StoreID,
Mode: string(evt.Mode),
Data: WaffoPancakeWebhookData{
OrderID: evt.Data.OrderID,
OrderMerchantExternalID: externalID,
BuyerEmail: evt.Data.BuyerEmail,
Currency: evt.Data.Currency,
Amount: evt.Data.Amount,
TaxAmount: evt.Data.TaxAmount,
ProductName: evt.Data.ProductName,
MerchantProvidedBuyerIdentity: identity,
},
}, nil
var event waffoPancakeWebhookEvent
if err := common.Unmarshal([]byte(payload), &event); err != nil {
return nil, fmt.Errorf("parse Waffo Pancake webhook payload: %w", err)
}
return &event, nil
}
// ResolveWaffoPancakeTradeNo maps a verified webhook event to a local TopUp
// trade_no via OrderMerchantExternalID, and rejects buyer-identity mismatches.
func ResolveWaffoPancakeTradeNo(event *WaffoPancakeWebhookEvent) (string, error) {
if event == nil {
return "", fmt.Errorf("missing webhook event")
func parseWaffoPancakeSignatureHeader(header string) (string, string) {
var timestampPart string
var signaturePart string
for _, pair := range strings.Split(header, ",") {
key, value, found := strings.Cut(strings.TrimSpace(pair), "=")
if !found {
continue
}
switch key {
case "t":
timestampPart = value
case "v1":
signaturePart = value
}
}
tradeNo := strings.TrimSpace(event.Data.OrderMerchantExternalID)
if tradeNo == "" {
return "", fmt.Errorf("missing webhook orderMerchantExternalId")
}
topUp := model.GetTopUpByTradeNo(tradeNo)
if topUp == nil || topUp.PaymentProvider != model.PaymentProviderWaffoPancake {
return "", fmt.Errorf("waffo pancake order not found for tradeNo=%s", tradeNo)
}
expectedIdentity := WaffoPancakeBuyerIdentityFromUserID(topUp.UserId)
actualIdentity := strings.TrimSpace(event.Data.MerchantProvidedBuyerIdentity)
if actualIdentity != expectedIdentity {
return "", fmt.Errorf(
"waffo pancake buyer identity mismatch for tradeNo=%s: expected=%q actual=%q",
tradeNo,
expectedIdentity,
actualIdentity,
)
}
return tradeNo, nil
return timestampPart, signaturePart
}
// ResolveWaffoPancakeSubscriptionTradeNo is the SubscriptionOrder counterpart
// of ResolveWaffoPancakeTradeNo.
func ResolveWaffoPancakeSubscriptionTradeNo(event *WaffoPancakeWebhookEvent) (string, error) {
if event == nil {
return "", fmt.Errorf("missing webhook event")
func resolveWaffoPancakeWebhookEnvironment(payload string) string {
var envelope struct {
Mode string `json:"mode"`
}
tradeNo := strings.TrimSpace(event.Data.OrderMerchantExternalID)
if tradeNo == "" {
return "", fmt.Errorf("missing webhook orderMerchantExternalId")
if err := common.Unmarshal([]byte(payload), &envelope); err != nil {
if setting.WaffoPancakeSandbox {
return "test"
}
return "prod"
}
order := model.GetSubscriptionOrderByTradeNo(tradeNo)
if order == nil || order.PaymentProvider != model.PaymentProviderWaffoPancake {
return "", fmt.Errorf("waffo pancake subscription order not found for tradeNo=%s", tradeNo)
switch strings.ToLower(strings.TrimSpace(envelope.Mode)) {
case "test":
return "test"
case "prod":
return "prod"
default:
if setting.WaffoPancakeSandbox {
return "test"
}
return "prod"
}
expectedIdentity := WaffoPancakeBuyerIdentityFromUserID(order.UserId)
actualIdentity := strings.TrimSpace(event.Data.MerchantProvidedBuyerIdentity)
if actualIdentity != expectedIdentity {
return "", fmt.Errorf(
"waffo pancake buyer identity mismatch for subscription tradeNo=%s: expected=%q actual=%q",
tradeNo,
expectedIdentity,
actualIdentity,
)
}
return tradeNo, nil
}
// Deterministic default names for "+ Create": stable bodies mean stable
// X-Idempotency-Key, which lets Pancake dedupe retries server-side.
const (
defaultWaffoPancakeStoreName = "new-api-store"
defaultWaffoPancakeProductName = "new-api-charge-product"
)
// CreateWaffoPancakePrimaryStore creates a Pancake Store using in-flight
// (not-yet-persisted) credentials and returns the new store ID.
func CreateWaffoPancakePrimaryStore(ctx context.Context, merchantID, privateKey string) (string, error) {
client, err := newWaffoPancakeClientFromCreds(merchantID, privateKey)
if err != nil {
return "", err
func resolveWaffoPancakeWebhookPublicKey(environment string) string {
if environment == "prod" {
return strings.TrimSpace(setting.WaffoPancakeWebhookPublicKey)
}
storeRes, err := client.Stores.Create(ctx, pancake.CreateStoreParams{
Name: defaultWaffoPancakeStoreName,
})
if err != nil {
return "", fmt.Errorf("create Waffo Pancake store: %w", err)
}
return storeRes.Store.ID, nil
return strings.TrimSpace(setting.WaffoPancakeWebhookTestKey)
}
// CreateWaffoPancakeProductForPlan mints (and publishes) a Pancake
// OnetimeProduct priced at `amount` USD, used as a subscription plan's
// SubscriptionPlan.WaffoPancakeProductId.
//
// OnetimeProduct (not SubscriptionProduct) because new-api has no renewal-
// event handling; Pancake auto-renewing without new-api extending user
// access would be a UX divergence. Revisit if renewal handling is added.
func CreateWaffoPancakeProductForPlan(ctx context.Context, merchantID, privateKey, storeID, name, amount, returnURL string) (string, error) {
storeID = strings.TrimSpace(storeID)
if storeID == "" {
return "", fmt.Errorf("store id is required to create a product")
}
name = strings.TrimSpace(name)
if name == "" {
return "", fmt.Errorf("plan name is required")
}
amount = strings.TrimSpace(amount)
if amount == "" {
return "", fmt.Errorf("plan price is required")
}
client, err := newWaffoPancakeClientFromCreds(merchantID, privateKey)
func verifyWaffoPancakeWebhookWithKey(signatureInput string, signaturePart string, rawPublicKey string) error {
publicKeyPEM, err := normalizeRSAPublicKey(rawPublicKey)
if err != nil {
return "", err
return err
}
prodRes, err := client.OnetimeProducts.Create(ctx, pancake.CreateOnetimeProductParams{
StoreID: storeID,
Name: name,
Prices: pancake.Prices{
"USD": {
Amount: amount,
TaxCategory: pancake.TaxCategory("saas"),
},
},
SuccessURL: optionalString(strings.TrimSpace(returnURL)),
})
if err != nil {
return "", fmt.Errorf("create Waffo Pancake plan product: %w", err)
}
productID := prodRes.Product.ID
if _, err := client.OnetimeProducts.Publish(ctx, pancake.PublishOnetimeProductParams{ID: productID}); err != nil {
return "", fmt.Errorf("publish Waffo Pancake plan product: %w", err)
}
return productID, nil
}
// CreateWaffoPancakePrimaryProduct mints (and publishes) the wallet-top-up
// OnetimeProduct under storeID. Per-checkout price overrides via PriceSnapshot
// are what make the "1.00" seed price irrelevant at runtime.
func CreateWaffoPancakePrimaryProduct(ctx context.Context, merchantID, privateKey, storeID, returnURL string) (string, error) {
storeID = strings.TrimSpace(storeID)
if storeID == "" {
return "", fmt.Errorf("store id is required to create a product")
block, _ := pem.Decode([]byte(publicKeyPEM))
if block == nil {
return fmt.Errorf("invalid RSA public key PEM")
}
client, err := newWaffoPancakeClientFromCreds(merchantID, privateKey)
if err != nil {
return "", err
}
prodRes, err := client.OnetimeProducts.Create(ctx, pancake.CreateOnetimeProductParams{
StoreID: storeID,
Name: defaultWaffoPancakeProductName,
Prices: pancake.Prices{
"USD": {
Amount: "1.00", // overridden at checkout via PriceSnapshot
TaxCategory: pancake.TaxCategory("saas"),
},
},
SuccessURL: optionalString(strings.TrimSpace(returnURL)),
})
if err != nil {
return "", fmt.Errorf("create Waffo Pancake product: %w", err)
}
productID := prodRes.Product.ID
if _, err := client.OnetimeProducts.Publish(ctx, pancake.PublishOnetimeProductParams{ID: productID}); err != nil {
return "", fmt.Errorf("publish Waffo Pancake product: %w", err)
}
return productID, nil
}
// WaffoPancakePairResult is the response of CreateWaffoPancakePrimaryPair.
// When OrphanStore is true the store was created but the product wasn't,
// so the caller can surface a partial-failure message with StoreID.
type WaffoPancakePairResult struct {
StoreID string
StoreName string
ProductID string
ProductName string
OrphanStore bool
}
var publicKey *rsa.PublicKey
switch block.Type {
case "PUBLIC KEY":
key, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return fmt.Errorf("parse PKIX public key: %w", err)
}
parsed, ok := key.(*rsa.PublicKey)
if !ok {
return fmt.Errorf("public key is not RSA")
}
publicKey = parsed
case "RSA PUBLIC KEY":
key, err := x509.ParsePKCS1PublicKey(block.Bytes)
if err != nil {
return fmt.Errorf("parse PKCS#1 public key: %w", err)
}
publicKey = key
default:
return fmt.Errorf("unsupported public key type: %s", block.Type)
}
// CreateWaffoPancakePrimaryPair mints a Store + OnetimeProduct in one
// round-trip — the canonical "+ Create" entry point. Nothing is persisted
// to settings; the operator's final Save commits the chosen IDs.
func CreateWaffoPancakePrimaryPair(ctx context.Context, merchantID, privateKey, returnURL string) (*WaffoPancakePairResult, error) {
storeID, err := CreateWaffoPancakePrimaryStore(ctx, merchantID, privateKey)
signature, err := base64.StdEncoding.DecodeString(signaturePart)
if err != nil {
return nil, err
return fmt.Errorf("decode webhook signature: %w", err)
}
productID, err := CreateWaffoPancakePrimaryProduct(ctx, merchantID, privateKey, storeID, returnURL)
if err != nil {
return &WaffoPancakePairResult{
StoreID: storeID,
StoreName: defaultWaffoPancakeStoreName,
OrphanStore: true,
}, fmt.Errorf("store created at %s but product creation failed: %w", storeID, err)
}
return &WaffoPancakePairResult{
StoreID: storeID,
StoreName: defaultWaffoPancakeStoreName,
ProductID: productID,
ProductName: defaultWaffoPancakeProductName,
}, nil
}
// SaveWaffoPancakeConfig persists the operator-controlled fields atomically
// at the end of the configuration flow via model.UpdateOptionsBulk (single
// DB transaction). A blank privateKey is treated as "keep current"
// (Stripe-style API-secret UX) and is omitted from the bulk payload.
func SaveWaffoPancakeConfig(ctx context.Context, merchantID, privateKey, returnURL, storeID, productID string) error {
merchantID = strings.TrimSpace(merchantID)
storeID = strings.TrimSpace(storeID)
productID = strings.TrimSpace(productID)
if merchantID == "" || storeID == "" || productID == "" {
return fmt.Errorf("merchant id, store id, and product id are required to save")
}
values := map[string]string{
"WaffoPancakeMerchantID": merchantID,
"WaffoPancakeReturnURL": strings.TrimSpace(returnURL),
"WaffoPancakeStoreID": storeID,
"WaffoPancakeProductID": productID,
}
if pk := strings.TrimSpace(privateKey); pk != "" {
values["WaffoPancakePrivateKey"] = pk
}
if err := model.UpdateOptionsBulk(values); err != nil {
return fmt.Errorf("persist Waffo Pancake config: %w", err)
digest := sha256.Sum256([]byte(signatureInput))
if err := rsa.VerifyPKCS1v15(publicKey, crypto.SHA256, digest[:], signature); err != nil {
return fmt.Errorf("verify webhook signature: %w", err)
}
return nil
}
type WaffoPancakeCatalogProduct struct {
ID string `json:"id"`
Name string `json:"name"`
Status string `json:"status"`
}
// WaffoPancakeCatalogStore nests its OnetimeProducts so the UI can render a
// dependent store→product select without a second round-trip.
type WaffoPancakeCatalogStore struct {
ID string `json:"id"`
Name string `json:"name"`
Status string `json:"status"`
ProdEnabled bool `json:"prodEnabled"`
OnetimeProducts []WaffoPancakeCatalogProduct `json:"onetimeProducts"`
}
type WaffoPancakeCatalog struct {
Stores []WaffoPancakeCatalogStore `json:"stores"`
}
// ListWaffoPancakeCatalog queries Pancake's GraphQL `stores` for the
// merchant's stores + onetime products. A successful call also proves
// the supplied credentials authenticate (doubles as a credential probe).
func ListWaffoPancakeCatalog(ctx context.Context, merchantID, privateKey string) (*WaffoPancakeCatalog, error) {
client, err := newWaffoPancakeClientFromCreds(merchantID, privateKey)
if err != nil {
return nil, err
}
type queryShape struct {
Stores []WaffoPancakeCatalogStore `json:"stores"`
}
// `limit: 100` because the API returns a single store when limit is
// omitted, even for multi-store merchants. Bump to paginated fetches
// (via `offset`) if real catalogs ever cross the cap.
resp, err := pancake.GraphQLQuery[queryShape](ctx, client, pancake.GraphQLParams{
Query: `query {
stores(limit: 100) {
id
name
status
prodEnabled
onetimeProducts {
id
name
status
}
}
}`,
})
if err != nil {
return nil, fmt.Errorf("query Waffo Pancake catalog: %w", err)
}
if len(resp.Errors) > 0 {
return nil, fmt.Errorf("waffo pancake catalog query returned %d errors: %s",
len(resp.Errors), resp.Errors[0].Message)
}
// Drop non-active products. Operators should only see items they can
// actually bind without later hitting "product unavailable" at checkout.
stores := resp.Data.Stores
for i := range stores {
active := stores[i].OnetimeProducts[:0]
for _, p := range stores[i].OnetimeProducts {
if strings.EqualFold(strings.TrimSpace(p.Status), "active") {
active = append(active, p)
}
}
stores[i].OnetimeProducts = active
}
return &WaffoPancakeCatalog{Stores: stores}, nil
}

Some files were not shown because too many files have changed in this diff Show More