Compare commits
83 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| f5753a2b31 | |||
| adc390c5fb | |||
| 32805849d6 | |||
| 01c2128e23 | |||
| 189913b7a0 | |||
| d2f7f9ee3a | |||
| 83068d115e | |||
| 4a188deeaa | |||
| 933ea0cddc | |||
| b53319361f | |||
| 87cc22d7ec | |||
| 3aa113b5a3 | |||
| 00d23abf64 | |||
| 580ad97c02 | |||
| b0ac0429cf | |||
| d17b566bcc | |||
| 7aaa533265 | |||
| 7791b78429 | |||
| cb5c0453f5 | |||
| 4d20e053cb | |||
| 0ff9c35e62 | |||
| 0bbcaa8999 | |||
| 1e9ff8a0de | |||
| 9a2e60dff2 | |||
| b596de739d | |||
| 45d54c1613 | |||
| 086044650d | |||
| 0c7aceb831 | |||
| b2e25b7df2 | |||
| 230a3592f8 | |||
| afb470e405 | |||
| 1588027084 | |||
| 38bf2d8daa | |||
| e8c836d705 | |||
| e79cee1e9e | |||
| 63ead2bf7f | |||
| 5b86ce0d70 | |||
| 74985fa877 | |||
| 1d32037364 | |||
| dc245ae764 | |||
| f8add4ca49 | |||
| 65f8afe922 | |||
| 5bc4c74813 | |||
| 30025aeba3 | |||
| c91ba0c4eb | |||
| f223db9330 | |||
| 9e283ab10b | |||
| a8b7c92e5f | |||
| 6b6c9904ac | |||
| 1011934987 | |||
| bc8110ce36 | |||
| ad224ecf5b | |||
| a64f26d1d2 | |||
| 3360882642 | |||
| b37b6d80b3 | |||
| 3d850d38b6 | |||
| 349d5429ca | |||
| 465c5edab9 | |||
| ff06067a18 | |||
| 51ca897cf4 | |||
| 1288028181 | |||
| 2a528d46cb | |||
| 583da45296 | |||
| b302be30e3 | |||
| 88437a1869 | |||
| b08febaa3c | |||
| 92a0959448 | |||
| 49bc3a1175 | |||
| 0354c38bef | |||
| ebbe315533 | |||
| fddf54ccc5 | |||
| b9bc6f0e21 | |||
| f2c7647ecf | |||
| 19f1821fc8 | |||
| 8e5e89bb5b | |||
| e13d673454 | |||
| ae6a03364d | |||
| 006e801652 | |||
| 6f11d19877 | |||
| 58ba867dd6 | |||
| 20d3e73734 | |||
| 2d1ca15384 | |||
| 0d4b25795a |
@@ -56,6 +56,8 @@
|
||||
# 对话超时设置
|
||||
# 所有请求超时时间,单位秒,默认为0,表示不限制
|
||||
# RELAY_TIMEOUT=0
|
||||
# Relay HTTP 客户端空闲连接超时时间,单位秒,默认跟随 Go 标准库,设置为0表示不限制
|
||||
# RELAY_IDLE_CONN_TIMEOUT=90
|
||||
# 流模式无响应超时时间,单位秒,如果出现空补全可以尝试改为更大值
|
||||
# STREAMING_TIMEOUT=300
|
||||
|
||||
|
||||
@@ -11,6 +11,8 @@ assignees: ''
|
||||
|
||||
- 文档:https://docs.newapi.ai/
|
||||
- 使用问题先看或先问:https://deepwiki.com/QuantumNous/new-api
|
||||
- 开启透传后的转发相关反馈不接受 issue;透传模式会直接转发请求,请自行确认上游行为。
|
||||
- 不接受 coding plan、逆向渠道等技术支持类 issue。
|
||||
- 警告:删除本模板、删除小节标题或随意清空内容的 issue,可能会被直接关闭;重复恶意提交者可能会被 block。
|
||||
|
||||
**您当前的 newapi 版本**
|
||||
@@ -20,13 +22,18 @@ assignees: ''
|
||||
**提交确认**
|
||||
|
||||
[//]: # (方框内删除已有的空格,填 x 号)
|
||||
+ [ ] 我已确认目前没有类似 issue
|
||||
+ [ ] 我已完整查看过文档 https://docs.newapi.ai/ 和项目 README,尤其是常见问题部分
|
||||
+ [ ] 我未删除此模板中的任何引导内容或小节标题,并会按要求完整填写
|
||||
+ [ ] 我理解项目维护者精力有限,不遵循模板要求的 issue 可能会被无视或直接关闭
|
||||
- [ ] **非重复 issue:** 我已搜索现有 [Issues](https://github.com/QuantumNous/new-api/issues?q=is%3Aissue),确认目前没有类似 issue。
|
||||
- [ ] **提交前必读:** 我已完整阅读上方“提交前必读”,并已查看文档 https://docs.newapi.ai/、项目 README 且向 AI 提问,确认这不是使用、配置或接入类问题。
|
||||
- [ ] **模板完整:** 我未删除此模板中的任何引导内容或小节标题,并会按要求完整填写。
|
||||
- [ ] **维护成本:** 我理解项目维护者精力有限,不遵循模板要求的 issue 可能会被无视或直接关闭。
|
||||
|
||||
**问题描述**
|
||||
|
||||
请尽可能说明问题现象、影响范围,以及你判断它是程序问题而不是上游行为或使用问题的依据。
|
||||
|
||||
- 转发问题请尽可能说明渠道类型、转换格式、上游原生支持依据和服务端日志。
|
||||
- 计费问题请尽可能附请求返回的 `usage` 示例。
|
||||
|
||||
**复现步骤**
|
||||
|
||||
**预期结果**
|
||||
|
||||
@@ -11,6 +11,8 @@ assignees: ''
|
||||
|
||||
- Docs: https://docs.newapi.ai/
|
||||
- Usage questions first: https://deepwiki.com/QuantumNous/new-api
|
||||
- Issues about forwarding behavior after enabling pass-through mode are not accepted; pass-through mode forwards requests directly, so please verify upstream behavior yourself.
|
||||
- Technical support requests such as coding plans or reverse-engineering channels are not accepted as issues.
|
||||
- Warning: issues with this template removed, section headings deleted, or content cleared may be closed directly. Repeated abusive submissions may result in a block.
|
||||
|
||||
**Your current newapi version**
|
||||
@@ -20,13 +22,18 @@ Please fill this in, for example: `v1.0.0`
|
||||
**Submission Checks**
|
||||
|
||||
[//]: # (Remove the space in the box and fill with an x)
|
||||
+ [ ] I have confirmed there are no similar issues
|
||||
+ [ ] I have thoroughly read the docs at https://docs.newapi.ai/ and the project README, especially the FAQ section
|
||||
+ [ ] I have not removed any guidance or section headings from this template and will complete it as requested
|
||||
+ [ ] I understand that maintainers have limited time and issues that do not follow this template may be ignored or closed directly
|
||||
- [ ] **Non-duplicate issue:** I have searched existing [Issues](https://github.com/QuantumNous/new-api/issues?q=is%3Aissue) and confirmed there are no similar issues.
|
||||
- [ ] **Read this first:** I have fully read the section above, reviewed the docs at https://docs.newapi.ai/ and the project README, and asked AI first, confirming this is not a usage, configuration, or integration question.
|
||||
- [ ] **Template intact:** I have not removed any guidance or section headings from this template and will complete it as requested.
|
||||
- [ ] **Maintainer time:** I understand that maintainers have limited time and issues that do not follow this template may be ignored or closed directly.
|
||||
|
||||
**Issue Description**
|
||||
|
||||
Describe the symptom, impact scope, and why you believe this is an application issue rather than upstream behavior or a usage question with as much detail as possible.
|
||||
|
||||
- For forwarding issues, include the channel type, conversion format, upstream native-support evidence, and server logs when possible.
|
||||
- For billing issues, include an example of the returned `usage` when possible.
|
||||
|
||||
**Steps to Reproduce**
|
||||
|
||||
**Expected Result**
|
||||
|
||||
@@ -11,6 +11,8 @@ assignees: ''
|
||||
|
||||
- 文档:https://docs.newapi.ai/
|
||||
- 使用问题先看或先问:https://deepwiki.com/QuantumNous/new-api
|
||||
- 开启透传后的转发相关反馈不接受 issue;透传模式会直接转发请求,请自行确认上游行为。
|
||||
- 不接受 coding plan、逆向渠道等技术支持类 issue。
|
||||
- 警告:删除本模板、删除小节标题或随意清空内容的 issue,可能会被直接关闭;重复恶意提交者可能会被 block。
|
||||
|
||||
**您当前的 newapi 版本**
|
||||
@@ -20,10 +22,10 @@ assignees: ''
|
||||
**提交确认**
|
||||
|
||||
[//]: # (方框内删除已有的空格,填 x 号)
|
||||
+ [ ] 我已确认目前没有类似 issue
|
||||
+ [ ] 我已完整查看过文档 https://docs.newapi.ai/ 和项目 README,已确定现有版本无法满足需求
|
||||
+ [ ] 我未删除此模板中的任何引导内容或小节标题,并会按要求完整填写
|
||||
+ [ ] 我理解项目维护者精力有限,不遵循模板要求的 issue 可能会被无视或直接关闭
|
||||
- [ ] **非重复 issue:** 我已搜索现有 [Issues](https://github.com/QuantumNous/new-api/issues?q=is%3Aissue),确认目前没有类似 issue。
|
||||
- [ ] **提交前必读:** 我已完整阅读上方“提交前必读”,并已查看文档 https://docs.newapi.ai/、项目 README 且向 AI 提问,确认这不是使用、配置或接入类问题,且现有版本无法满足需求。
|
||||
- [ ] **模板完整:** 我未删除此模板中的任何引导内容或小节标题,并会按要求完整填写。
|
||||
- [ ] **维护成本:** 我理解项目维护者精力有限,不遵循模板要求的 issue 可能会被无视或直接关闭。
|
||||
|
||||
**功能描述**
|
||||
|
||||
|
||||
@@ -11,6 +11,8 @@ assignees: ''
|
||||
|
||||
- Docs: https://docs.newapi.ai/
|
||||
- Usage questions first: https://deepwiki.com/QuantumNous/new-api
|
||||
- Issues about forwarding behavior after enabling pass-through mode are not accepted; pass-through mode forwards requests directly, so please verify upstream behavior yourself.
|
||||
- Technical support requests such as coding plans or reverse-engineering channels are not accepted as issues.
|
||||
- Warning: issues with this template removed, section headings deleted, or content cleared may be closed directly. Repeated abusive submissions may result in a block.
|
||||
|
||||
**Your current newapi version**
|
||||
@@ -20,10 +22,10 @@ Please fill this in, for example: `v1.0.0`
|
||||
**Submission Checks**
|
||||
|
||||
[//]: # (Remove the space in the box and fill with an x)
|
||||
+ [ ] I have confirmed there are no similar issues
|
||||
+ [ ] I have thoroughly read the docs at https://docs.newapi.ai/ and the project README, and confirmed the current version cannot meet my needs
|
||||
+ [ ] I have not removed any guidance or section headings from this template and will complete it as requested
|
||||
+ [ ] I understand that maintainers have limited time and issues that do not follow this template may be ignored or closed directly
|
||||
- [ ] **Non-duplicate issue:** I have searched existing [Issues](https://github.com/QuantumNous/new-api/issues?q=is%3Aissue) and confirmed there are no similar issues.
|
||||
- [ ] **Read this first:** I have fully read the section above, reviewed the docs at https://docs.newapi.ai/ and the project README, and asked AI first, confirming this is not a usage, configuration, or integration question, and that the current version cannot meet my needs.
|
||||
- [ ] **Template intact:** I have not removed any guidance or section headings from this template and will complete it as requested.
|
||||
- [ ] **Maintainer time:** I understand that maintainers have limited time and issues that do not follow this template may be ignored or closed directly.
|
||||
|
||||
**Feature Description**
|
||||
|
||||
|
||||
@@ -33,16 +33,18 @@ jobs:
|
||||
env:
|
||||
CI: ""
|
||||
run: |
|
||||
cd web/default
|
||||
bun install
|
||||
cd web
|
||||
bun install --frozen-lockfile
|
||||
cd default
|
||||
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$VERSION bun run build
|
||||
cd ../..
|
||||
- name: Build Frontend (classic)
|
||||
env:
|
||||
CI: ""
|
||||
run: |
|
||||
cd web/classic
|
||||
bun install
|
||||
cd web
|
||||
bun install --frozen-lockfile
|
||||
cd classic
|
||||
VITE_REACT_APP_VERSION=$VERSION bun run build
|
||||
cd ../..
|
||||
- name: Set up Go
|
||||
@@ -91,16 +93,18 @@ jobs:
|
||||
CI: ""
|
||||
NODE_OPTIONS: "--max-old-space-size=4096"
|
||||
run: |
|
||||
cd web/default
|
||||
bun install
|
||||
cd web
|
||||
bun install --frozen-lockfile
|
||||
cd default
|
||||
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$VERSION bun run build
|
||||
cd ../..
|
||||
- name: Build Frontend (classic)
|
||||
env:
|
||||
CI: ""
|
||||
run: |
|
||||
cd web/classic
|
||||
bun install
|
||||
cd web
|
||||
bun install --frozen-lockfile
|
||||
cd classic
|
||||
VITE_REACT_APP_VERSION=$VERSION bun run build
|
||||
cd ../..
|
||||
- name: Set up Go
|
||||
@@ -146,16 +150,18 @@ jobs:
|
||||
env:
|
||||
CI: ""
|
||||
run: |
|
||||
cd web/default
|
||||
bun install
|
||||
cd web
|
||||
bun install --frozen-lockfile
|
||||
cd default
|
||||
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$VERSION bun run build
|
||||
cd ../..
|
||||
- name: Build Frontend (classic)
|
||||
env:
|
||||
CI: ""
|
||||
run: |
|
||||
cd web/classic
|
||||
bun install
|
||||
cd web
|
||||
bun install --frozen-lockfile
|
||||
cd classic
|
||||
VITE_REACT_APP_VERSION=$VERSION bun run build
|
||||
cd ../..
|
||||
- name: Set up Go
|
||||
|
||||
@@ -35,3 +35,4 @@ data/
|
||||
.test
|
||||
token_estimator_test.go
|
||||
skills-lock.json
|
||||
.playwright-mcp
|
||||
|
||||
+18
-16
@@ -1,22 +1,24 @@
|
||||
FROM oven/bun:1@sha256:0733e50325078969732ebe3b15ce4c4be5082f18c4ac1a0f0ca4839c2e4e42a7 AS builder
|
||||
|
||||
WORKDIR /build
|
||||
COPY web/default/package.json .
|
||||
COPY web/default/bun.lock .
|
||||
RUN bun install
|
||||
COPY ./web/default .
|
||||
COPY ./VERSION .
|
||||
RUN DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) bun run build
|
||||
WORKDIR /build/web
|
||||
COPY web/package.json web/bun.lock ./
|
||||
COPY web/default/package.json ./default/package.json
|
||||
COPY web/classic/package.json ./classic/package.json
|
||||
RUN bun install --frozen-lockfile
|
||||
COPY ./web/default ./default
|
||||
COPY ./VERSION /build/VERSION
|
||||
RUN cd default && DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat /build/VERSION) bun run build
|
||||
|
||||
FROM oven/bun:1@sha256:0733e50325078969732ebe3b15ce4c4be5082f18c4ac1a0f0ca4839c2e4e42a7 AS builder-classic
|
||||
|
||||
WORKDIR /build
|
||||
COPY web/classic/package.json .
|
||||
COPY web/classic/bun.lock .
|
||||
RUN bun install
|
||||
COPY ./web/classic .
|
||||
COPY ./VERSION .
|
||||
RUN VITE_REACT_APP_VERSION=$(cat VERSION) bun run build
|
||||
WORKDIR /build/web
|
||||
COPY web/package.json web/bun.lock ./
|
||||
COPY web/default/package.json ./default/package.json
|
||||
COPY web/classic/package.json ./classic/package.json
|
||||
RUN bun install --frozen-lockfile
|
||||
COPY ./web/classic ./classic
|
||||
COPY ./VERSION /build/VERSION
|
||||
RUN cd classic && VITE_REACT_APP_VERSION=$(cat /build/VERSION) bun run build
|
||||
|
||||
FROM golang:1.26.1-alpine@sha256:2389ebfa5b7f43eeafbd6be0c3700cc46690ef842ad962f6c5bd6be49ed82039 AS builder2
|
||||
ENV GO111MODULE=on CGO_ENABLED=0
|
||||
@@ -32,8 +34,8 @@ ADD go.mod go.sum ./
|
||||
RUN go mod download
|
||||
|
||||
COPY . .
|
||||
COPY --from=builder /build/dist ./web/default/dist
|
||||
COPY --from=builder-classic /build/dist ./web/classic/dist
|
||||
COPY --from=builder /build/web/default/dist ./web/default/dist
|
||||
COPY --from=builder-classic /build/web/classic/dist ./web/classic/dist
|
||||
RUN go build -ldflags "-s -w -X 'github.com/QuantumNous/new-api/common.Version=$(cat VERSION)'" -o new-api
|
||||
|
||||
FROM debian:bookworm-slim@sha256:f06537653ac770703bc45b4b113475bd402f451e85223f0f2837acbf89ab020a
|
||||
|
||||
@@ -316,6 +316,7 @@ docker run --name new-api -d --restart always \
|
||||
| `CRYPTO_SECRET` | Encryption secret (required for Redis) | - |
|
||||
| `SQL_DSN` | Database connection string | - |
|
||||
| `REDIS_CONN_STRING` | Redis connection string | - |
|
||||
| `RELAY_IDLE_CONN_TIMEOUT` | Idle keep-alive timeout for relay HTTP clients, seconds. Defaults to Go standard library behavior; set `0` to disable | `90` |
|
||||
| `STREAMING_TIMEOUT` | Streaming timeout (seconds) | `300` |
|
||||
| `STREAM_SCANNER_MAX_BUFFER_MB` | Max per-line buffer (MB) for the stream scanner; increase when upstream sends huge image/base64 payloads | `64` |
|
||||
| `MAX_REQUEST_BODY_MB` | Max request body size (MB, counted **after decompression**; prevents huge requests/zip bombs from exhausting memory). Exceeding it returns `413` | `32` |
|
||||
|
||||
@@ -170,6 +170,7 @@ var BatchUpdateInterval int
|
||||
|
||||
var RelayTimeout int // unit is second
|
||||
|
||||
var RelayIdleConnTimeout int // unit is second
|
||||
var RelayMaxIdleConns int
|
||||
var RelayMaxIdleConnsPerHost int
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ func checkWriter(writer io.Writer) stringWriter {
|
||||
// W3C Working Draft 29 October 2009
|
||||
// http://www.w3.org/TR/2009/WD-eventsource-20091029/
|
||||
|
||||
var contentType = []string{"text/event-stream"}
|
||||
var writeContentType = []string{"text/event-stream"}
|
||||
var noCache = []string{"no-cache"}
|
||||
|
||||
var fieldReplacer = strings.NewReplacer(
|
||||
@@ -79,7 +79,7 @@ func (r CustomEvent) WriteContentType(w http.ResponseWriter) {
|
||||
r.Mutex.Lock()
|
||||
defer r.Mutex.Unlock()
|
||||
header := w.Header()
|
||||
header["Content-Type"] = contentType
|
||||
header["Content-Type"] = writeContentType
|
||||
|
||||
if _, exist := header["Cache-Control"]; !exist {
|
||||
header["Cache-Control"] = noCache
|
||||
|
||||
+19
-1
@@ -110,11 +110,29 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
contentType := c.Request.Header.Get("Content-Type")
|
||||
|
||||
// disk-backed JSON: stream-decode directly from the file to avoid
|
||||
// materializing the entire payload back into a transient []byte
|
||||
// (diskStorage.Bytes() would ReadFull the whole file into the heap).
|
||||
if storage.IsDisk() && strings.HasPrefix(contentType, "application/json") {
|
||||
if _, seekErr := storage.Seek(0, io.SeekStart); seekErr != nil {
|
||||
return seekErr
|
||||
}
|
||||
if err := DecodeJson(storage, v); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, seekErr := storage.Seek(0, io.SeekStart); seekErr != nil {
|
||||
return seekErr
|
||||
}
|
||||
c.Request.Body = io.NopCloser(storage)
|
||||
return nil
|
||||
}
|
||||
|
||||
requestBody, err := storage.Bytes()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
contentType := c.Request.Header.Get("Content-Type")
|
||||
if strings.HasPrefix(contentType, "application/json") {
|
||||
err = Unmarshal(requestBody, v)
|
||||
} else if strings.Contains(contentType, gin.MIMEPOSTForm) {
|
||||
|
||||
@@ -102,6 +102,7 @@ func InitEnv() {
|
||||
SyncFrequency = GetEnvOrDefault("SYNC_FREQUENCY", 60)
|
||||
BatchUpdateInterval = GetEnvOrDefault("BATCH_UPDATE_INTERVAL", 5)
|
||||
RelayTimeout = GetEnvOrDefault("RELAY_TIMEOUT", 0)
|
||||
RelayIdleConnTimeout = GetEnvOrDefault("RELAY_IDLE_CONN_TIMEOUT", 90)
|
||||
RelayMaxIdleConns = GetEnvOrDefault("RELAY_MAX_IDLE_CONNS", 500)
|
||||
RelayMaxIdleConnsPerHost = GetEnvOrDefault("RELAY_MAX_IDLE_CONNS_PER_HOST", 100)
|
||||
|
||||
@@ -135,6 +136,7 @@ func initConstantEnv() {
|
||||
constant.StreamScannerMaxBufferMB = GetEnvOrDefault("STREAM_SCANNER_MAX_BUFFER_MB", 128)
|
||||
// MaxRequestBodyMB 请求体最大大小(解压后),用于防止超大请求/zip bomb导致内存暴涨
|
||||
constant.MaxRequestBodyMB = GetEnvOrDefault("MAX_REQUEST_BODY_MB", 128)
|
||||
constant.AnonymousRequestBodyLimitKB = GetEnvOrDefault("ANONYMOUS_REQUEST_BODY_LIMIT_KB", 512)
|
||||
// ForceStreamOption 覆盖请求参数,强制返回usage信息
|
||||
constant.ForceStreamOption = GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
|
||||
constant.CountToken = GetEnvOrDefaultBool("CountToken", true)
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
package common
|
||||
|
||||
import "github.com/QuantumNous/new-api/constant"
|
||||
|
||||
const defaultAnonymousRequestBodyLimitKB = 512
|
||||
|
||||
func GetAnonymousRequestBodyLimitBytes() int64 {
|
||||
limitKB := constant.AnonymousRequestBodyLimitKB
|
||||
if limitKB < 0 {
|
||||
limitKB = defaultAnonymousRequestBodyLimitKB
|
||||
}
|
||||
return int64(limitKB) << 10
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package common
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
@@ -20,6 +21,16 @@ var (
|
||||
maskApiKeyPattern = regexp.MustCompile(`(['"]?)api_key:([^\s'"]+)(['"]?)`)
|
||||
)
|
||||
|
||||
const LocalLogContentLimit = 2048
|
||||
|
||||
// LocalLogPreview limits log-only content unless debug logging is enabled.
|
||||
func LocalLogPreview(content string) string {
|
||||
if DebugEnabled || len(content) <= LocalLogContentLimit {
|
||||
return content
|
||||
}
|
||||
return fmt.Sprintf("%s... [truncated, original_length=%d, limit=%d]", content[:LocalLogContentLimit], len(content), LocalLogContentLimit)
|
||||
}
|
||||
|
||||
func GetStringIfEmpty(str string, defaultValue string) string {
|
||||
if str == "" {
|
||||
return defaultValue
|
||||
|
||||
@@ -10,6 +10,7 @@ var GetMediaToken bool
|
||||
var GetMediaTokenNotStream bool
|
||||
var UpdateTask bool
|
||||
var MaxRequestBodyMB int
|
||||
var AnonymousRequestBodyLimitKB int
|
||||
var AzureDefaultAPIVersion string
|
||||
var NotifyLimitCount int
|
||||
var NotificationLimitDurationMinute int
|
||||
|
||||
@@ -57,7 +57,24 @@ func normalizeChannelTestEndpoint(channel *model.Channel, modelName, endpointTyp
|
||||
return normalized
|
||||
}
|
||||
|
||||
func testChannel(channel *model.Channel, testModel string, endpointType string, isStream bool) testResult {
|
||||
func resolveChannelTestUserID(c *gin.Context) (int, error) {
|
||||
if c != nil {
|
||||
if userID := c.GetInt("id"); userID > 0 {
|
||||
return userID, nil
|
||||
}
|
||||
}
|
||||
|
||||
var rootUser model.User
|
||||
if err := model.DB.Select("id").Where("role = ?", common.RoleRootUser).First(&rootUser).Error; err != nil {
|
||||
return 0, fmt.Errorf("failed to resolve channel test user: %w", err)
|
||||
}
|
||||
if rootUser.Id == 0 {
|
||||
return 0, errors.New("failed to resolve channel test user")
|
||||
}
|
||||
return rootUser.Id, nil
|
||||
}
|
||||
|
||||
func testChannel(channel *model.Channel, testUserID int, testModel string, endpointType string, isStream bool) testResult {
|
||||
tik := time.Now()
|
||||
var unsupportedTestChannelTypes = []int{
|
||||
constant.ChannelTypeMidjourney,
|
||||
@@ -143,7 +160,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
|
||||
Header: make(http.Header),
|
||||
}
|
||||
|
||||
cache, err := model.GetUserCache(1)
|
||||
cache, err := model.GetUserCache(testUserID)
|
||||
if err != nil {
|
||||
return testResult{
|
||||
localErr: err,
|
||||
@@ -151,13 +168,13 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
|
||||
}
|
||||
}
|
||||
cache.WriteContext(c)
|
||||
c.Set("id", 1)
|
||||
c.Set("id", testUserID)
|
||||
|
||||
//c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Set("channel", channel.Type)
|
||||
c.Set("base_url", channel.GetBaseURL())
|
||||
group, _ := model.GetUserGroup(1, false)
|
||||
group, _ := model.GetUserGroup(testUserID, false)
|
||||
c.Set("group", group)
|
||||
|
||||
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, testModel)
|
||||
@@ -484,7 +501,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
consumedTime := float64(milliseconds) / 1000.0
|
||||
other := buildTestLogOther(c, info, priceData, usage, tieredResult)
|
||||
model.RecordConsumeLog(c, 1, model.RecordConsumeLogParams{
|
||||
model.RecordConsumeLog(c, testUserID, model.RecordConsumeLogParams{
|
||||
ChannelId: channel.Id,
|
||||
PromptTokens: usage.PromptTokens,
|
||||
CompletionTokens: usage.CompletionTokens,
|
||||
@@ -797,7 +814,7 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel,
|
||||
testRequest.StreamOptions = &dto.StreamOptions{IncludeUsage: true}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(model, "o") {
|
||||
if dto.IsOpenAIReasoningOModel(model) {
|
||||
testRequest.MaxCompletionTokens = lo.ToPtr(uint(16))
|
||||
} else if strings.Contains(model, "thinking") {
|
||||
if !strings.Contains(model, "claude") {
|
||||
@@ -834,8 +851,13 @@ func TestChannel(c *gin.Context) {
|
||||
testModel := c.Query("model")
|
||||
endpointType := c.Query("endpoint_type")
|
||||
isStream, _ := strconv.ParseBool(c.Query("stream"))
|
||||
testUserID, err := resolveChannelTestUserID(c)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
tik := time.Now()
|
||||
result := testChannel(channel, testModel, endpointType, isStream)
|
||||
result := testChannel(channel, testUserID, testModel, endpointType, isStream)
|
||||
if result.localErr != nil {
|
||||
resp := gin.H{
|
||||
"success": false,
|
||||
@@ -872,6 +894,10 @@ var testAllChannelsLock sync.Mutex
|
||||
var testAllChannelsRunning bool = false
|
||||
|
||||
func testAllChannels(notify bool) error {
|
||||
testUserID, err := resolveChannelTestUserID(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
testAllChannelsLock.Lock()
|
||||
if testAllChannelsRunning {
|
||||
@@ -902,7 +928,7 @@ func testAllChannels(notify bool) error {
|
||||
}
|
||||
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
|
||||
tik := time.Now()
|
||||
result := testChannel(channel, "", "", shouldUseStreamForAutomaticChannelTest(channel))
|
||||
result := testChannel(channel, testUserID, "", "", shouldUseStreamForAutomaticChannelTest(channel))
|
||||
tok := time.Now()
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
|
||||
|
||||
@@ -1218,7 +1218,7 @@ func CopyChannel(c *gin.Context) {
|
||||
}
|
||||
|
||||
// insert
|
||||
if err := model.BatchInsertChannels([]model.Channel{clone}); err != nil {
|
||||
if err := clone.Insert(); err != nil {
|
||||
common.SysError("failed to clone channel: " + err.Error())
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "复制渠道失败,请稍后重试"})
|
||||
return
|
||||
|
||||
@@ -69,3 +69,14 @@ func TestBuildTestLogOtherInjectsTieredInfo(t *testing.T) {
|
||||
require.Equal(t, "base", other["matched_tier"])
|
||||
require.NotEmpty(t, other["expr_b64"])
|
||||
}
|
||||
|
||||
func TestResolveChannelTestUserIDUsesRequestUser(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||
ctx.Set("id", 2)
|
||||
|
||||
userID, err := resolveChannelTestUserID(ctx)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 2, userID)
|
||||
}
|
||||
|
||||
@@ -88,6 +88,7 @@ func GetStatus(c *gin.Context) {
|
||||
"demo_site_enabled": operation_setting.DemoSiteEnabled,
|
||||
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
|
||||
"register_enabled": common.RegisterEnabled,
|
||||
"password_login_enabled": common.PasswordLoginEnabled,
|
||||
"password_register_enabled": common.PasswordRegisterEnabled,
|
||||
"default_use_auto_group": setting.DefaultUseAutoGroup,
|
||||
|
||||
|
||||
+120
-43
@@ -3,6 +3,7 @@ package controller
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
@@ -109,9 +110,102 @@ func init() {
|
||||
})
|
||||
}
|
||||
|
||||
func ListModels(c *gin.Context, modelType int) {
|
||||
userOpenAiModels := make([]dto.OpenAIModels, 0)
|
||||
func channelOwnerName(channelType int) string {
|
||||
apiType, success := common.ChannelType2APIType(channelType)
|
||||
if !success {
|
||||
return strings.ToLower(constant.GetChannelTypeName(channelType))
|
||||
}
|
||||
adaptor := relay.GetAdaptor(apiType)
|
||||
if adaptor == nil {
|
||||
return strings.ToLower(constant.GetChannelTypeName(channelType))
|
||||
}
|
||||
adaptor.Init(&relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{
|
||||
ChannelType: channelType,
|
||||
}})
|
||||
if name := strings.TrimSpace(adaptor.GetChannelName()); name != "" {
|
||||
return name
|
||||
}
|
||||
return strings.ToLower(constant.GetChannelTypeName(channelType))
|
||||
}
|
||||
|
||||
func getPreferredModelOwners(modelNames []string, groups []string) map[string]string {
|
||||
channelTypes, err := model.GetPreferredModelOwnerChannelTypes(modelNames, groups)
|
||||
if err != nil {
|
||||
common.SysLog(fmt.Sprintf("GetPreferredModelOwnerChannelTypes error: %v", err))
|
||||
return map[string]string{}
|
||||
}
|
||||
|
||||
ownerByChannelType := make(map[int]string)
|
||||
owners := make(map[string]string, len(channelTypes))
|
||||
for modelName, channelType := range channelTypes {
|
||||
owner, ok := ownerByChannelType[channelType]
|
||||
if !ok {
|
||||
owner = channelOwnerName(channelType)
|
||||
ownerByChannelType[channelType] = owner
|
||||
}
|
||||
if owner != "" {
|
||||
owners[modelName] = owner
|
||||
}
|
||||
}
|
||||
return owners
|
||||
}
|
||||
|
||||
func buildOpenAIModel(modelName string, ownerByModel map[string]string) dto.OpenAIModels {
|
||||
var oaiModel dto.OpenAIModels
|
||||
if staticModel, ok := openAIModelsMap[modelName]; ok {
|
||||
oaiModel = staticModel
|
||||
} else {
|
||||
oaiModel = dto.OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "custom",
|
||||
}
|
||||
}
|
||||
if owner, ok := ownerByModel[modelName]; ok && owner != "" {
|
||||
oaiModel.OwnedBy = owner
|
||||
}
|
||||
oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(modelName)
|
||||
return oaiModel
|
||||
}
|
||||
|
||||
type modelListGroups struct {
|
||||
userGroup string
|
||||
tokenGroup string
|
||||
ownerGroups []string
|
||||
}
|
||||
|
||||
func getModelListGroups(c *gin.Context) (modelListGroups, error) {
|
||||
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
|
||||
userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
|
||||
if userGroup == "" && (tokenGroup == "" || tokenGroup == "auto") {
|
||||
var err error
|
||||
userGroup, err = model.GetUserGroup(c.GetInt("id"), false)
|
||||
if err != nil {
|
||||
return modelListGroups{}, err
|
||||
}
|
||||
}
|
||||
|
||||
if tokenGroup == "auto" {
|
||||
return modelListGroups{
|
||||
userGroup: userGroup,
|
||||
tokenGroup: tokenGroup,
|
||||
ownerGroups: service.GetUserAutoGroup(userGroup),
|
||||
}, nil
|
||||
}
|
||||
|
||||
group := userGroup
|
||||
if tokenGroup != "" {
|
||||
group = tokenGroup
|
||||
}
|
||||
return modelListGroups{
|
||||
userGroup: userGroup,
|
||||
tokenGroup: tokenGroup,
|
||||
ownerGroups: []string{group},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func ListModels(c *gin.Context, modelType int) {
|
||||
acceptUnsetRatioModel := operation_setting.SelfUseModeEnabled
|
||||
if !acceptUnsetRatioModel {
|
||||
userId := c.GetInt("id")
|
||||
@@ -123,6 +217,16 @@ func ListModels(c *gin.Context, modelType int) {
|
||||
}
|
||||
}
|
||||
|
||||
userModelNames := make([]string, 0)
|
||||
groups, err := getModelListGroups(c)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "get user group failed",
|
||||
})
|
||||
return
|
||||
}
|
||||
ownerGroups := groups.ownerGroups
|
||||
modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
|
||||
if modelLimitEnable {
|
||||
s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
|
||||
@@ -138,37 +242,12 @@ func ListModels(c *gin.Context, modelType int) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if oaiModel, ok := openAIModelsMap[allowModel]; ok {
|
||||
oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(allowModel)
|
||||
userOpenAiModels = append(userOpenAiModels, oaiModel)
|
||||
} else {
|
||||
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
|
||||
Id: allowModel,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "custom",
|
||||
SupportedEndpointTypes: model.GetModelSupportEndpointTypes(allowModel),
|
||||
})
|
||||
}
|
||||
userModelNames = append(userModelNames, allowModel)
|
||||
}
|
||||
} else {
|
||||
userId := c.GetInt("id")
|
||||
userGroup, err := model.GetUserGroup(userId, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "get user group failed",
|
||||
})
|
||||
return
|
||||
}
|
||||
group := userGroup
|
||||
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
|
||||
if tokenGroup != "" {
|
||||
group = tokenGroup
|
||||
}
|
||||
var models []string
|
||||
if tokenGroup == "auto" {
|
||||
for _, autoGroup := range service.GetUserAutoGroup(userGroup) {
|
||||
if groups.tokenGroup == "auto" {
|
||||
for _, autoGroup := range ownerGroups {
|
||||
groupModels := model.GetGroupEnabledModels(autoGroup)
|
||||
for _, g := range groupModels {
|
||||
if !common.StringsContains(models, g) {
|
||||
@@ -177,7 +256,7 @@ func ListModels(c *gin.Context, modelType int) {
|
||||
}
|
||||
}
|
||||
} else {
|
||||
models = model.GetGroupEnabledModels(group)
|
||||
models = model.GetGroupEnabledModels(ownerGroups[0])
|
||||
}
|
||||
for _, modelName := range models {
|
||||
if !acceptUnsetRatioModel {
|
||||
@@ -185,21 +264,19 @@ func ListModels(c *gin.Context, modelType int) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if oaiModel, ok := openAIModelsMap[modelName]; ok {
|
||||
oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(modelName)
|
||||
userOpenAiModels = append(userOpenAiModels, oaiModel)
|
||||
} else {
|
||||
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "custom",
|
||||
SupportedEndpointTypes: model.GetModelSupportEndpointTypes(modelName),
|
||||
})
|
||||
}
|
||||
userModelNames = append(userModelNames, modelName)
|
||||
}
|
||||
}
|
||||
|
||||
ownerByModel := map[string]string{}
|
||||
if len(ownerGroups) > 0 {
|
||||
ownerByModel = getPreferredModelOwners(userModelNames, ownerGroups)
|
||||
}
|
||||
userOpenAiModels := make([]dto.OpenAIModels, 0, len(userModelNames))
|
||||
for _, modelName := range userModelNames {
|
||||
userOpenAiModels = append(userOpenAiModels, buildOpenAIModel(modelName, ownerByModel))
|
||||
}
|
||||
|
||||
switch modelType {
|
||||
case constant.ChannelTypeAnthropic:
|
||||
useranthropicModels := make([]dto.AnthropicModel, len(userOpenAiModels))
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestChannelOwnerNameUsesAdaptorChannelName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
channelType int
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "openai",
|
||||
channelType: constant.ChannelTypeOpenAI,
|
||||
expected: "openai",
|
||||
},
|
||||
{
|
||||
name: "codex",
|
||||
channelType: constant.ChannelTypeCodex,
|
||||
expected: "codex",
|
||||
},
|
||||
{
|
||||
name: "openrouter",
|
||||
channelType: constant.ChannelTypeOpenRouter,
|
||||
expected: "openrouter",
|
||||
},
|
||||
{
|
||||
name: "azure fallback",
|
||||
channelType: constant.ChannelTypeAzure,
|
||||
expected: "azure",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.expected, channelOwnerName(tt.channelType))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpenAIModelOverridesOwnedBy(t *testing.T) {
|
||||
modelItem := buildOpenAIModel("gpt-5.4", map[string]string{"gpt-5.4": "openai"})
|
||||
require.Equal(t, "gpt-5.4", modelItem.Id)
|
||||
require.Equal(t, "openai", modelItem.OwnedBy)
|
||||
}
|
||||
|
||||
func TestBuildOpenAIModelFallsBackToCustomForUnknownModels(t *testing.T) {
|
||||
modelItem := buildOpenAIModel("custom-test-model", nil)
|
||||
require.Equal(t, "custom-test-model", modelItem.Id)
|
||||
require.Equal(t, "custom", modelItem.OwnedBy)
|
||||
}
|
||||
|
||||
func TestGetModelListGroupsUsesUserGroupWhenTokenGroupIsEmpty(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||
common.SetContextKey(ctx, constant.ContextKeyUserGroup, "default")
|
||||
|
||||
groups, err := getModelListGroups(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "default", groups.userGroup)
|
||||
require.Empty(t, groups.tokenGroup)
|
||||
require.Equal(t, []string{"default"}, groups.ownerGroups)
|
||||
}
|
||||
|
||||
func TestGetModelListGroupsUsesExplicitTokenGroup(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||
common.SetContextKey(ctx, constant.ContextKeyUserGroup, "default")
|
||||
common.SetContextKey(ctx, constant.ContextKeyTokenGroup, "vip")
|
||||
|
||||
groups, err := getModelListGroups(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "default", groups.userGroup)
|
||||
require.Equal(t, "vip", groups.tokenGroup)
|
||||
require.Equal(t, []string{"vip"}, groups.ownerGroups)
|
||||
}
|
||||
+1
-10
@@ -42,15 +42,6 @@ func isPositiveOptionValue(value string) bool {
|
||||
return err == nil && floatValue > 0
|
||||
}
|
||||
|
||||
func isVisiblePublicKeyOption(key string) bool {
|
||||
switch key {
|
||||
case "WaffoPancakeWebhookPublicKey", "WaffoPancakeWebhookTestKey":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func collectModelNamesFromOptionValue(raw string, modelNames map[string]struct{}) {
|
||||
if strings.TrimSpace(raw) == "" {
|
||||
return
|
||||
@@ -95,7 +86,7 @@ func GetOptions(c *gin.Context) {
|
||||
strings.HasSuffix(k, "Key") ||
|
||||
strings.HasSuffix(k, "secret") ||
|
||||
strings.HasSuffix(k, "api_key")
|
||||
if isSensitiveKey && !isVisiblePublicKeyOption(k) {
|
||||
if isSensitiveKey {
|
||||
continue
|
||||
}
|
||||
options = append(options, &model.Option{
|
||||
|
||||
@@ -77,24 +77,15 @@ func isWaffoPancakeTopUpEnabled() bool {
|
||||
if !isPaymentComplianceConfirmed() {
|
||||
return false
|
||||
}
|
||||
if !setting.WaffoPancakeEnabled {
|
||||
return false
|
||||
}
|
||||
|
||||
return isWaffoPancakeWebhookConfigured() &&
|
||||
strings.TrimSpace(setting.WaffoPancakeMerchantID) != "" &&
|
||||
// Presence-of-credentials = enabled. Webhook public keys ship inside
|
||||
// the SDK; mode (test/prod) is read from each event.
|
||||
return strings.TrimSpace(setting.WaffoPancakeMerchantID) != "" &&
|
||||
strings.TrimSpace(setting.WaffoPancakePrivateKey) != "" &&
|
||||
strings.TrimSpace(setting.WaffoPancakeStoreID) != "" &&
|
||||
strings.TrimSpace(setting.WaffoPancakeProductID) != ""
|
||||
}
|
||||
|
||||
func isWaffoPancakeWebhookConfigured() bool {
|
||||
currentWebhookKey := strings.TrimSpace(setting.WaffoPancakeWebhookPublicKey)
|
||||
if setting.WaffoPancakeSandbox {
|
||||
currentWebhookKey = strings.TrimSpace(setting.WaffoPancakeWebhookTestKey)
|
||||
}
|
||||
|
||||
return currentWebhookKey != ""
|
||||
return isWaffoPancakeTopUpEnabled()
|
||||
}
|
||||
|
||||
func isWaffoPancakeWebhookEnabled() bool {
|
||||
|
||||
@@ -114,47 +114,32 @@ func TestWaffoWebhookEnabledRequiresTopUpAndWebhookConfig(t *testing.T) {
|
||||
|
||||
func TestWaffoPancakeWebhookEnabledRequiresTopUpAndWebhookConfig(t *testing.T) {
|
||||
confirmPaymentComplianceForTest(t)
|
||||
originalEnabled := setting.WaffoPancakeEnabled
|
||||
originalSandbox := setting.WaffoPancakeSandbox
|
||||
originalMerchantID := setting.WaffoPancakeMerchantID
|
||||
originalPrivateKey := setting.WaffoPancakePrivateKey
|
||||
originalWebhookPublicKey := setting.WaffoPancakeWebhookPublicKey
|
||||
originalWebhookTestKey := setting.WaffoPancakeWebhookTestKey
|
||||
originalStoreID := setting.WaffoPancakeStoreID
|
||||
originalProductID := setting.WaffoPancakeProductID
|
||||
t.Cleanup(func() {
|
||||
setting.WaffoPancakeEnabled = originalEnabled
|
||||
setting.WaffoPancakeSandbox = originalSandbox
|
||||
setting.WaffoPancakeMerchantID = originalMerchantID
|
||||
setting.WaffoPancakePrivateKey = originalPrivateKey
|
||||
setting.WaffoPancakeWebhookPublicKey = originalWebhookPublicKey
|
||||
setting.WaffoPancakeWebhookTestKey = originalWebhookTestKey
|
||||
setting.WaffoPancakeStoreID = originalStoreID
|
||||
setting.WaffoPancakeProductID = originalProductID
|
||||
})
|
||||
|
||||
setting.WaffoPancakeEnabled = true
|
||||
setting.WaffoPancakeSandbox = false
|
||||
setting.WaffoPancakeMerchantID = "merchant"
|
||||
// Presence of all three credentials enables the gateway. Webhook public
|
||||
// keys are bundled in the SDK and there is no separate Enabled toggle —
|
||||
// clear any of the three fields to disable.
|
||||
setting.WaffoPancakeMerchantID = ""
|
||||
setting.WaffoPancakePrivateKey = "private"
|
||||
setting.WaffoPancakeStoreID = "store"
|
||||
setting.WaffoPancakeProductID = "product"
|
||||
setting.WaffoPancakeWebhookPublicKey = ""
|
||||
require.False(t, isWaffoPancakeWebhookEnabled())
|
||||
|
||||
setting.WaffoPancakeWebhookPublicKey = "public"
|
||||
setting.WaffoPancakeMerchantID = "merchant"
|
||||
require.True(t, isWaffoPancakeWebhookEnabled())
|
||||
|
||||
setting.WaffoPancakeEnabled = false
|
||||
setting.WaffoPancakeProductID = ""
|
||||
require.False(t, isWaffoPancakeWebhookEnabled())
|
||||
|
||||
setting.WaffoPancakeEnabled = true
|
||||
setting.WaffoPancakeSandbox = true
|
||||
setting.WaffoPancakeWebhookTestKey = ""
|
||||
setting.WaffoPancakeProductID = "product"
|
||||
setting.WaffoPancakePrivateKey = ""
|
||||
require.False(t, isWaffoPancakeWebhookEnabled())
|
||||
|
||||
setting.WaffoPancakeWebhookTestKey = "test_public"
|
||||
require.True(t, isWaffoPancakeWebhookEnabled())
|
||||
}
|
||||
|
||||
func TestEpayWebhookEnabledRequiresTopUpAndWebhookConfig(t *testing.T) {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/setting/ratio_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
func GetPerfMetricsSummary(c *gin.Context) {
|
||||
@@ -18,7 +19,8 @@ func GetPerfMetricsSummary(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
result, err := perfmetrics.QuerySummaryAll(hours)
|
||||
activeGroups := append(lo.Keys(ratio_setting.GetGroupRatioCopy()), "auto")
|
||||
result, err := perfmetrics.QuerySummaryAll(hours, activeGroups)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
@@ -72,12 +74,9 @@ func GetPerfMetrics(c *gin.Context) {
|
||||
}
|
||||
|
||||
func filterActiveGroups(groups []perfmetrics.GroupResult) []perfmetrics.GroupResult {
|
||||
activeGroups := ratio_setting.GetGroupRatioCopy()
|
||||
filtered := make([]perfmetrics.GroupResult, 0, len(groups))
|
||||
for _, g := range groups {
|
||||
if _, ok := activeGroups[g.Group]; ok || g.Group == "auto" {
|
||||
filtered = append(filtered, g)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
activeRatios := ratio_setting.GetGroupRatioCopy()
|
||||
return lo.Filter(groups, func(g perfmetrics.GroupResult, _ int) bool {
|
||||
_, ok := activeRatios[g.Group]
|
||||
return ok || g.Group == "auto"
|
||||
})
|
||||
}
|
||||
|
||||
+2
-2
@@ -88,7 +88,7 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
||||
|
||||
defer func() {
|
||||
if newAPIError != nil {
|
||||
logger.LogError(c, fmt.Sprintf("relay error: %s", newAPIError.Error()))
|
||||
logger.LogError(c, fmt.Sprintf("relay error: %s", common.LocalLogPreview(newAPIError.Error())))
|
||||
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
|
||||
switch relayFormat {
|
||||
case types.RelayFormatOpenAIRealtime:
|
||||
@@ -354,7 +354,7 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b
|
||||
}
|
||||
|
||||
func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
|
||||
logger.LogError(c, fmt.Sprintf("channel error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
|
||||
logger.LogError(c, fmt.Sprintf("channel error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, common.LocalLogPreview(err.Error())))
|
||||
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
|
||||
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
|
||||
if service.ShouldDisableChannel(err) && channelError.AutoBan {
|
||||
|
||||
@@ -22,6 +22,10 @@ type BillingPreferenceRequest struct {
|
||||
BillingPreference string `json:"billing_preference"`
|
||||
}
|
||||
|
||||
type SubscriptionBalancePayRequest struct {
|
||||
PlanId int `json:"plan_id"`
|
||||
}
|
||||
|
||||
// ---- User APIs ----
|
||||
|
||||
func GetSubscriptionPlans(c *gin.Context) {
|
||||
@@ -37,6 +41,7 @@ func GetSubscriptionPlans(c *gin.Context) {
|
||||
}
|
||||
result := make([]SubscriptionPlanDTO, 0, len(plans))
|
||||
for _, p := range plans {
|
||||
p.NormalizeDefaults()
|
||||
result = append(result, SubscriptionPlanDTO{
|
||||
Plan: p,
|
||||
})
|
||||
@@ -92,6 +97,25 @@ func UpdateSubscriptionPreference(c *gin.Context) {
|
||||
common.ApiSuccess(c, gin.H{"billing_preference": pref})
|
||||
}
|
||||
|
||||
func SubscriptionRequestBalancePay(c *gin.Context) {
|
||||
if !requirePaymentCompliance(c) {
|
||||
return
|
||||
}
|
||||
|
||||
userId := c.GetInt("id")
|
||||
var req SubscriptionBalancePayRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil || req.PlanId <= 0 {
|
||||
common.ApiErrorMsg(c, "参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
if err := model.PurchaseSubscriptionWithBalance(userId, req.PlanId); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
common.ApiSuccess(c, nil)
|
||||
}
|
||||
|
||||
// ---- Admin APIs ----
|
||||
|
||||
func AdminListSubscriptionPlans(c *gin.Context) {
|
||||
@@ -102,6 +126,7 @@ func AdminListSubscriptionPlans(c *gin.Context) {
|
||||
}
|
||||
result := make([]SubscriptionPlanDTO, 0, len(plans))
|
||||
for _, p := range plans {
|
||||
p.NormalizeDefaults()
|
||||
result = append(result, SubscriptionPlanDTO{
|
||||
Plan: p,
|
||||
})
|
||||
@@ -140,6 +165,9 @@ func AdminCreateSubscriptionPlan(c *gin.Context) {
|
||||
req.Plan.Currency = "USD"
|
||||
}
|
||||
req.Plan.Currency = "USD"
|
||||
if req.Plan.AllowBalancePay == nil {
|
||||
req.Plan.AllowBalancePay = common.GetPointer(true)
|
||||
}
|
||||
if req.Plan.DurationUnit == "" {
|
||||
req.Plan.DurationUnit = model.SubscriptionDurationMonth
|
||||
}
|
||||
@@ -248,6 +276,7 @@ func AdminUpdateSubscriptionPlan(c *gin.Context) {
|
||||
"sort_order": req.Plan.SortOrder,
|
||||
"stripe_price_id": req.Plan.StripePriceId,
|
||||
"creem_product_id": req.Plan.CreemProductId,
|
||||
"waffo_pancake_product_id": req.Plan.WaffoPancakeProductId,
|
||||
"max_purchase_per_user": req.Plan.MaxPurchasePerUser,
|
||||
"total_amount": req.Plan.TotalAmount,
|
||||
"upgrade_group": req.Plan.UpgradeGroup,
|
||||
@@ -255,6 +284,9 @@ func AdminUpdateSubscriptionPlan(c *gin.Context) {
|
||||
"quota_reset_custom_seconds": req.Plan.QuotaResetCustomSeconds,
|
||||
"updated_at": common.GetTimestamp(),
|
||||
}
|
||||
if req.Plan.AllowBalancePay != nil {
|
||||
updateMap["allow_balance_pay"] = *req.Plan.AllowBalancePay
|
||||
}
|
||||
if err := tx.Model(&model.SubscriptionPlan{}).Where("id = ?", id).Updates(updateMap).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -0,0 +1,130 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/setting"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/shopspring/decimal"
|
||||
"github.com/thanhpk/randstr"
|
||||
)
|
||||
|
||||
type SubscriptionWaffoPancakePayRequest struct {
|
||||
PlanId int `json:"plan_id"`
|
||||
}
|
||||
|
||||
func SubscriptionRequestWaffoPancakePay(c *gin.Context) {
|
||||
if !requirePaymentCompliance(c) {
|
||||
return
|
||||
}
|
||||
|
||||
var req SubscriptionWaffoPancakePayRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil || req.PlanId <= 0 {
|
||||
common.ApiErrorMsg(c, "参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
plan, err := model.GetSubscriptionPlanById(req.PlanId)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if !plan.Enabled {
|
||||
common.ApiErrorMsg(c, "套餐未启用")
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(plan.WaffoPancakeProductId) == "" {
|
||||
common.ApiErrorMsg(c, "该套餐未配置 WaffoPancakeProductId")
|
||||
return
|
||||
}
|
||||
// Plan targets its own Pancake product, so we only require credentials
|
||||
// here — not the gateway-level WaffoPancakeProductID.
|
||||
if strings.TrimSpace(setting.WaffoPancakeMerchantID) == "" ||
|
||||
strings.TrimSpace(setting.WaffoPancakePrivateKey) == "" {
|
||||
common.ApiErrorMsg(c, "Waffo Pancake 未配置或密钥无效")
|
||||
return
|
||||
}
|
||||
|
||||
userId := c.GetInt("id")
|
||||
user, err := model.GetUserById(userId, false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if user == nil {
|
||||
common.ApiErrorMsg(c, "用户不存在")
|
||||
return
|
||||
}
|
||||
|
||||
if plan.MaxPurchasePerUser > 0 {
|
||||
count, err := model.CountUserSubscriptionsByPlan(userId, plan.Id)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if count >= int64(plan.MaxPurchasePerUser) {
|
||||
common.ApiErrorMsg(c, "已达到该套餐购买上限")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// WAFFO_PANCAKE_SUB- prefix (vs. wallet's WAFFO_PANCAKE-) drives webhook
|
||||
// dispatch in WaffoPancakeWebhook.
|
||||
tradeNo := fmt.Sprintf("WAFFO_PANCAKE_SUB-%d-%d-%s", userId, time.Now().UnixMilli(), randstr.String(6))
|
||||
|
||||
order := &model.SubscriptionOrder{
|
||||
UserId: userId,
|
||||
PlanId: plan.Id,
|
||||
Money: plan.PriceAmount,
|
||||
TradeNo: tradeNo,
|
||||
PaymentMethod: model.PaymentMethodWaffoPancake,
|
||||
PaymentProvider: model.PaymentProviderWaffoPancake,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
if err := order.Insert(); err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake 订阅订单创建失败 user_id=%d plan_id=%d trade_no=%s error=%q", userId, plan.Id, tradeNo, err.Error()))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"})
|
||||
return
|
||||
}
|
||||
|
||||
expiresInSeconds := 45 * 60
|
||||
session, err := service.CreateWaffoPancakeCheckoutSession(c.Request.Context(), &service.WaffoPancakeCreateSessionParams{
|
||||
ProductID: plan.WaffoPancakeProductId,
|
||||
BuyerIdentity: service.WaffoPancakeBuyerIdentityFromUserID(user.Id),
|
||||
PriceSnapshot: &service.WaffoPancakePriceSnapshot{
|
||||
Amount: decimal.NewFromFloat(plan.PriceAmount).StringFixed(2),
|
||||
TaxCategory: "saas",
|
||||
},
|
||||
BuyerEmail: getWaffoPancakeBuyerEmail(user),
|
||||
ExpiresInSeconds: &expiresInSeconds,
|
||||
OrderMerchantExternalID: tradeNo,
|
||||
})
|
||||
if err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake 订阅结账会话创建失败 user_id=%d plan_id=%d trade_no=%s error=%q", userId, plan.Id, tradeNo, err.Error()))
|
||||
order.Status = common.TopUpStatusFailed
|
||||
_ = order.Update()
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉起支付失败"})
|
||||
return
|
||||
}
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo Pancake 订阅订单创建成功 user_id=%d plan_id=%d trade_no=%s session_id=%s money=%.2f", userId, plan.Id, tradeNo, session.SessionID, plan.PriceAmount))
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "success",
|
||||
"data": gin.H{
|
||||
"checkout_url": session.CheckoutURL,
|
||||
"session_id": session.SessionID,
|
||||
"expires_at": session.ExpiresAt,
|
||||
"order_id": tradeNo,
|
||||
"token": session.Token,
|
||||
"token_expires_at": session.TokenExpiresAt,
|
||||
},
|
||||
})
|
||||
}
|
||||
+21
-20
@@ -52,6 +52,27 @@ func GetTopUpInfo(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// Waffo Pancake displayed above the legacy Waffo gateway.
|
||||
enableWaffoPancake := isWaffoPancakeTopUpEnabled()
|
||||
if enableWaffoPancake {
|
||||
hasWaffoPancake := false
|
||||
for _, method := range payMethods {
|
||||
if method["type"] == model.PaymentMethodWaffoPancake {
|
||||
hasWaffoPancake = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasWaffoPancake {
|
||||
payMethods = append(payMethods, map[string]string{
|
||||
"name": "Waffo Pancake",
|
||||
"type": model.PaymentMethodWaffoPancake,
|
||||
"color": "rgba(var(--semi-orange-5), 1)",
|
||||
"min_topup": strconv.Itoa(setting.WaffoPancakeMinTopUp),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 如果启用了 Waffo 支付,添加到支付方法列表
|
||||
enableWaffo := isWaffoTopUpEnabled()
|
||||
if enableWaffo {
|
||||
@@ -74,26 +95,6 @@ func GetTopUpInfo(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
enableWaffoPancake := isWaffoPancakeTopUpEnabled()
|
||||
if enableWaffoPancake {
|
||||
hasWaffoPancake := false
|
||||
for _, method := range payMethods {
|
||||
if method["type"] == model.PaymentMethodWaffoPancake {
|
||||
hasWaffoPancake = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasWaffoPancake {
|
||||
payMethods = append(payMethods, map[string]string{
|
||||
"name": "Waffo Pancake",
|
||||
"type": model.PaymentMethodWaffoPancake,
|
||||
"color": "rgba(var(--semi-orange-5), 1)",
|
||||
"min_topup": strconv.Itoa(setting.WaffoPancakeMinTopUp),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
data := gin.H{
|
||||
"enable_online_topup": isEpayTopUpEnabled(),
|
||||
"enable_stripe_topup": isStripeTopUpEnabled(),
|
||||
|
||||
@@ -96,33 +96,257 @@ func getWaffoPancakeBuyerEmail(user *model.User) string {
|
||||
if user != nil && strings.TrimSpace(user.Email) != "" {
|
||||
return user.Email
|
||||
}
|
||||
if user != nil {
|
||||
return fmt.Sprintf("%d@new-api.local", user.Id)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func getWaffoPancakeReturnURL() string {
|
||||
if strings.TrimSpace(setting.WaffoPancakeReturnURL) != "" {
|
||||
return setting.WaffoPancakeReturnURL
|
||||
// The admin config endpoints below accept typed-but-not-yet-saved creds in
|
||||
// the body and fall back to persisted creds when the body is blank (see
|
||||
// resolveWaffoPancakeAdminCreds). Only SaveWaffoPancake writes to OptionMap.
|
||||
|
||||
type waffoPancakeCredsRequest struct {
|
||||
MerchantID string `json:"merchant_id"`
|
||||
PrivateKey string `json:"private_key"`
|
||||
}
|
||||
|
||||
type saveWaffoPancakeRequest struct {
|
||||
MerchantID string `json:"merchant_id"`
|
||||
PrivateKey string `json:"private_key"`
|
||||
ReturnURL string `json:"return_url"`
|
||||
StoreID string `json:"store_id"`
|
||||
ProductID string `json:"product_id"`
|
||||
}
|
||||
|
||||
type createWaffoPancakePairRequest struct {
|
||||
MerchantID string `json:"merchant_id"`
|
||||
PrivateKey string `json:"private_key"`
|
||||
ReturnURL string `json:"return_url"`
|
||||
}
|
||||
|
||||
// SaveWaffoPancake atomically persists all five operator-controlled fields.
|
||||
// Catalog / pair endpoints are transient — only this one writes the OptionMap.
|
||||
func SaveWaffoPancake(c *gin.Context) {
|
||||
var req saveWaffoPancakeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
|
||||
return
|
||||
}
|
||||
return paymentReturnPath("/console/topup?show_history=true")
|
||||
if err := service.SaveWaffoPancakeConfig(
|
||||
c.Request.Context(),
|
||||
req.MerchantID,
|
||||
req.PrivateKey,
|
||||
req.ReturnURL,
|
||||
req.StoreID,
|
||||
req.ProductID,
|
||||
); err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf(
|
||||
"Waffo Pancake 保存配置失败 store_id=%q product_id=%q error=%q",
|
||||
req.StoreID, req.ProductID, err.Error(),
|
||||
))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "保存配置失败"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "success",
|
||||
"data": gin.H{
|
||||
"product_id": setting.WaffoPancakeProductID,
|
||||
"store_id": setting.WaffoPancakeStoreID,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// resolveWaffoPancakeAdminCreds prefers body creds (typed-but-not-yet-saved
|
||||
// values, for verification) and falls back to persisted creds when the body
|
||||
// is blank (so returning admins don't have to re-paste the private key,
|
||||
// which is stripped from GET /api/option/).
|
||||
func resolveWaffoPancakeAdminCreds(bodyMerchantID, bodyPrivateKey string) (string, string) {
|
||||
m := strings.TrimSpace(bodyMerchantID)
|
||||
k := strings.TrimSpace(bodyPrivateKey)
|
||||
if m == "" && k == "" {
|
||||
return setting.WaffoPancakeMerchantID, setting.WaffoPancakePrivateKey
|
||||
}
|
||||
return m, k
|
||||
}
|
||||
|
||||
// CreateWaffoPancakePair mints a Store + OnetimeProduct pair in one round-
|
||||
// trip. Surfaces an orphan-store flag when the product half fails so the
|
||||
// frontend can preselect / retry without losing context.
|
||||
func CreateWaffoPancakePair(c *gin.Context) {
|
||||
var req createWaffoPancakePairRequest
|
||||
if c.Request.ContentLength > 0 {
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
|
||||
return
|
||||
}
|
||||
}
|
||||
merchantID, privateKey := resolveWaffoPancakeAdminCreds(req.MerchantID, req.PrivateKey)
|
||||
if merchantID == "" || privateKey == "" {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "Waffo Pancake 凭证未配置"})
|
||||
return
|
||||
}
|
||||
result, err := service.CreateWaffoPancakePrimaryPair(
|
||||
c.Request.Context(), merchantID, privateKey, req.ReturnURL,
|
||||
)
|
||||
if err != nil {
|
||||
orphan := result != nil && result.OrphanStore
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf(
|
||||
"Waffo Pancake 创建店铺与产品失败 orphan_store=%t store_id=%q error=%q",
|
||||
orphan, func() string {
|
||||
if result == nil {
|
||||
return ""
|
||||
}
|
||||
return result.StoreID
|
||||
}(), err.Error(),
|
||||
))
|
||||
data := gin.H{"error": err.Error()}
|
||||
if orphan {
|
||||
data["store_id"] = result.StoreID
|
||||
data["store_name"] = result.StoreName
|
||||
data["orphan_store"] = true
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": data})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "success",
|
||||
"data": gin.H{
|
||||
"store_id": result.StoreID,
|
||||
"store_name": result.StoreName,
|
||||
"product_id": result.ProductID,
|
||||
"product_name": result.ProductName,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// ListWaffoPancakeCatalog returns the merchant's Stores + OnetimeProducts.
|
||||
// Doubles as a credential probe (a successful 200 proves the resolved creds
|
||||
// authenticate). See resolveWaffoPancakeAdminCreds for credential resolution.
|
||||
func ListWaffoPancakeCatalog(c *gin.Context) {
|
||||
var req waffoPancakeCredsRequest
|
||||
// An empty body means "use persisted creds"; only fail on malformed JSON.
|
||||
if c.Request.ContentLength > 0 {
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
|
||||
return
|
||||
}
|
||||
}
|
||||
merchantID, privateKey := resolveWaffoPancakeAdminCreds(req.MerchantID, req.PrivateKey)
|
||||
if merchantID == "" || privateKey == "" {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "Waffo Pancake 凭证未配置"})
|
||||
return
|
||||
}
|
||||
catalog, err := service.ListWaffoPancakeCatalog(c.Request.Context(), merchantID, privateKey)
|
||||
if err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf(
|
||||
"Waffo Pancake 拉取店铺与产品目录失败 error=%q", err.Error(),
|
||||
))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉取目录失败"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "success", "data": catalog})
|
||||
}
|
||||
|
||||
type createWaffoPancakeSubscriptionProductRequest struct {
|
||||
Name string `json:"name"`
|
||||
Amount string `json:"amount"`
|
||||
}
|
||||
|
||||
// CreateWaffoPancakeSubscriptionProduct mints an OnetimeProduct (not
|
||||
// SubscriptionProduct — see service.CreateWaffoPancakeProductForPlan)
|
||||
// sized to a plan's `name` + `amount`, using persisted Pancake credentials
|
||||
// + StoreID. Reads from the form, not the plan row, so newly-typed unsaved
|
||||
// plans can mint a product too.
|
||||
func CreateWaffoPancakeSubscriptionProduct(c *gin.Context) {
|
||||
var req createWaffoPancakeSubscriptionProductRequest
|
||||
if c.Request.ContentLength > 0 {
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
|
||||
return
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(req.Name) == "" {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "套餐名称不能为空"})
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(req.Amount) == "" {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "套餐价格不能为空"})
|
||||
return
|
||||
}
|
||||
merchantID, privateKey := resolveWaffoPancakeAdminCreds("", "")
|
||||
storeID := strings.TrimSpace(setting.WaffoPancakeStoreID)
|
||||
if merchantID == "" || privateKey == "" || storeID == "" {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "Waffo Pancake 未完成配置,请先在支付设置中完成网关绑定"})
|
||||
return
|
||||
}
|
||||
productID, err := service.CreateWaffoPancakeProductForPlan(
|
||||
c.Request.Context(),
|
||||
merchantID,
|
||||
privateKey,
|
||||
storeID,
|
||||
req.Name,
|
||||
req.Amount,
|
||||
setting.WaffoPancakeReturnURL,
|
||||
)
|
||||
if err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf(
|
||||
"Waffo Pancake 创建套餐产品失败 store_id=%q name=%q amount=%q error=%q",
|
||||
storeID, req.Name, req.Amount, err.Error(),
|
||||
))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建套餐产品失败"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "success",
|
||||
"data": gin.H{
|
||||
"product_id": productID,
|
||||
"product_name": req.Name,
|
||||
"store_id": storeID,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// ListWaffoPancakeSubscriptionProductOptions returns the OnetimeProducts
|
||||
// in the saved Pancake store, for the subscription-plan dropdown. The name
|
||||
// reflects new-api's plan concept; under the hood it's still OnetimeProducts.
|
||||
func ListWaffoPancakeSubscriptionProductOptions(c *gin.Context) {
|
||||
merchantID, privateKey := resolveWaffoPancakeAdminCreds("", "")
|
||||
storeID := strings.TrimSpace(setting.WaffoPancakeStoreID)
|
||||
if merchantID == "" || privateKey == "" || storeID == "" {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "Waffo Pancake 未完成配置,请先在支付设置中完成网关绑定"})
|
||||
return
|
||||
}
|
||||
catalog, err := service.ListWaffoPancakeCatalog(c.Request.Context(), merchantID, privateKey)
|
||||
if err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf(
|
||||
"Waffo Pancake 拉取订阅产品列表失败 store_id=%q error=%q", storeID, err.Error(),
|
||||
))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉取产品列表失败"})
|
||||
return
|
||||
}
|
||||
products := []service.WaffoPancakeCatalogProduct{}
|
||||
for _, store := range catalog.Stores {
|
||||
if store.ID == storeID {
|
||||
products = store.OnetimeProducts
|
||||
break
|
||||
}
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "success",
|
||||
"data": gin.H{
|
||||
"store_id": storeID,
|
||||
"products": products,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func getWaffoPancakeBuyerIdentity(user *model.User) string {
|
||||
if user == nil {
|
||||
return ""
|
||||
}
|
||||
return service.WaffoPancakeBuyerIdentityFromUserID(user.Id)
|
||||
}
|
||||
|
||||
func RequestWaffoPancakePay(c *gin.Context) {
|
||||
if !setting.WaffoPancakeEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "Waffo Pancake 支付未启用"})
|
||||
return
|
||||
}
|
||||
currentWebhookKey := setting.WaffoPancakeWebhookPublicKey
|
||||
if setting.WaffoPancakeSandbox {
|
||||
currentWebhookKey = setting.WaffoPancakeWebhookTestKey
|
||||
}
|
||||
if strings.TrimSpace(setting.WaffoPancakeMerchantID) == "" ||
|
||||
strings.TrimSpace(setting.WaffoPancakePrivateKey) == "" ||
|
||||
strings.TrimSpace(currentWebhookKey) == "" ||
|
||||
strings.TrimSpace(setting.WaffoPancakeStoreID) == "" ||
|
||||
strings.TrimSpace(setting.WaffoPancakeProductID) == "" {
|
||||
if !isWaffoPancakeTopUpEnabled() {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "Waffo Pancake 配置不完整"})
|
||||
return
|
||||
}
|
||||
@@ -175,18 +399,15 @@ func RequestWaffoPancakePay(c *gin.Context) {
|
||||
|
||||
expiresInSeconds := 45 * 60
|
||||
session, err := service.CreateWaffoPancakeCheckoutSession(c.Request.Context(), &service.WaffoPancakeCreateSessionParams{
|
||||
StoreID: setting.WaffoPancakeStoreID,
|
||||
ProductID: setting.WaffoPancakeProductID,
|
||||
ProductType: "onetime",
|
||||
Currency: strings.ToUpper(strings.TrimSpace(setting.WaffoPancakeCurrency)),
|
||||
ProductID: setting.WaffoPancakeProductID,
|
||||
BuyerIdentity: getWaffoPancakeBuyerIdentity(user),
|
||||
PriceSnapshot: &service.WaffoPancakePriceSnapshot{
|
||||
Amount: formatWaffoPancakeAmount(payMoney),
|
||||
TaxIncluded: false,
|
||||
TaxCategory: "saas",
|
||||
},
|
||||
BuyerEmail: getWaffoPancakeBuyerEmail(user),
|
||||
SuccessURL: getWaffoPancakeReturnURL(),
|
||||
ExpiresInSeconds: &expiresInSeconds,
|
||||
BuyerEmail: getWaffoPancakeBuyerEmail(user),
|
||||
ExpiresInSeconds: &expiresInSeconds,
|
||||
OrderMerchantExternalID: tradeNo,
|
||||
})
|
||||
if err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake 创建结账会话失败 user_id=%d trade_no=%s error=%q", id, tradeNo, err.Error()))
|
||||
@@ -200,10 +421,12 @@ func RequestWaffoPancakePay(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "success",
|
||||
"data": gin.H{
|
||||
"checkout_url": session.CheckoutURL,
|
||||
"session_id": session.SessionID,
|
||||
"expires_at": session.ExpiresAt,
|
||||
"order_id": tradeNo,
|
||||
"checkout_url": session.CheckoutURL,
|
||||
"session_id": session.SessionID,
|
||||
"expires_at": session.ExpiresAt,
|
||||
"order_id": tradeNo,
|
||||
"token": session.Token,
|
||||
"token_expires_at": session.TokenExpiresAt,
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -215,6 +438,19 @@ func WaffoPancakeWebhook(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// :env splits test vs prod traffic at the routing layer — operator
|
||||
// registers each URL in the matching webhook slot in Pancake's dashboard.
|
||||
// We then enforce event.mode == expectedEnv to catch mis-registrations.
|
||||
expectedEnv := strings.TrimSpace(c.Param("env"))
|
||||
if expectedEnv != "test" && expectedEnv != "prod" {
|
||||
logger.LogWarn(c.Request.Context(), fmt.Sprintf(
|
||||
"Waffo Pancake webhook 路径环境段无效 env=%q path=%q client_ip=%s",
|
||||
expectedEnv, c.Request.RequestURI, c.ClientIP(),
|
||||
))
|
||||
c.String(http.StatusNotFound, "unknown env")
|
||||
return
|
||||
}
|
||||
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake webhook 读取请求体失败 path=%q client_ip=%s error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
|
||||
@@ -232,15 +468,57 @@ func WaffoPancakeWebhook(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if !strings.EqualFold(strings.TrimSpace(event.Mode), expectedEnv) {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf(
|
||||
"Waffo Pancake webhook 环境不匹配 expected=%q actual_mode=%q event_id=%s order_id=%s client_ip=%s",
|
||||
expectedEnv, event.Mode, event.ID, event.Data.OrderID, c.ClientIP(),
|
||||
))
|
||||
c.String(http.StatusOK, "OK")
|
||||
return
|
||||
}
|
||||
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo Pancake webhook 验签成功 event_type=%s event_id=%s order_id=%s client_ip=%s", event.NormalizedEventType(), event.ID, event.Data.OrderID, c.ClientIP()))
|
||||
if event.NormalizedEventType() != "order.completed" {
|
||||
c.String(http.StatusOK, "OK")
|
||||
return
|
||||
}
|
||||
|
||||
// Dispatch by trade_no prefix. OrderMerchantExternalID = our trade_no;
|
||||
// OrderID is Pancake's internal ORD_* (logs only).
|
||||
rawTradeNo := strings.TrimSpace(event.Data.OrderMerchantExternalID)
|
||||
isSubscription := strings.HasPrefix(rawTradeNo, "WAFFO_PANCAKE_SUB-")
|
||||
|
||||
if isSubscription {
|
||||
tradeNo, err := service.ResolveWaffoPancakeSubscriptionTradeNo(event)
|
||||
if err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf(
|
||||
"Waffo Pancake webhook 订阅订单解析失败 event_id=%s order_id=%s buyer_identity=%q client_ip=%s error=%q",
|
||||
event.ID, event.Data.OrderID, event.Data.MerchantProvidedBuyerIdentity, c.ClientIP(), err.Error(),
|
||||
))
|
||||
c.String(http.StatusOK, "OK")
|
||||
return
|
||||
}
|
||||
LockOrder(tradeNo)
|
||||
defer UnlockOrder(tradeNo)
|
||||
if err := model.CompleteSubscriptionOrder(tradeNo, string(bodyBytes), model.PaymentProviderWaffoPancake, ""); err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake 订阅完成失败 trade_no=%s event_id=%s order_id=%s client_ip=%s error=%q", tradeNo, event.ID, event.Data.OrderID, c.ClientIP(), err.Error()))
|
||||
c.String(http.StatusInternalServerError, "retry")
|
||||
return
|
||||
}
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo Pancake 订阅完成 trade_no=%s event_id=%s order_id=%s client_ip=%s", tradeNo, event.ID, event.Data.OrderID, c.ClientIP()))
|
||||
c.String(http.StatusOK, "OK")
|
||||
return
|
||||
}
|
||||
|
||||
tradeNo, err := service.ResolveWaffoPancakeTradeNo(event)
|
||||
if err != nil {
|
||||
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Waffo Pancake webhook 订单号映射失败 event_id=%s order_id=%s error=%q", event.ID, event.Data.OrderID, err.Error()))
|
||||
// LogError (not LogWarn): covers order-not-found and buyer-identity
|
||||
// mismatch — both warrant human attention. 200 OK so Waffo doesn't
|
||||
// retry a permanently-unresolvable webhook.
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf(
|
||||
"Waffo Pancake webhook 订单解析失败 event_id=%s order_id=%s buyer_identity=%q client_ip=%s error=%q",
|
||||
event.ID, event.Data.OrderID, event.Data.MerchantProvidedBuyerIdentity, c.ClientIP(), err.Error(),
|
||||
))
|
||||
c.String(http.StatusOK, "OK")
|
||||
return
|
||||
}
|
||||
|
||||
+13
-1
@@ -251,8 +251,20 @@ func GetAllUsers(c *gin.Context) {
|
||||
func SearchUsers(c *gin.Context) {
|
||||
keyword := c.Query("keyword")
|
||||
group := c.Query("group")
|
||||
var role *int
|
||||
if roleStr := c.Query("role"); roleStr != "" {
|
||||
if parsed, err := strconv.Atoi(roleStr); err == nil {
|
||||
role = &parsed
|
||||
}
|
||||
}
|
||||
var status *int
|
||||
if statusStr := c.Query("status"); statusStr != "" {
|
||||
if parsed, err := strconv.Atoi(statusStr); err == nil {
|
||||
status = &parsed
|
||||
}
|
||||
}
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
users, total, err := model.SearchUsers(keyword, group, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||
users, total, err := model.SearchUsers(keyword, group, role, status, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
|
||||
@@ -34,6 +34,7 @@ services:
|
||||
- BATCH_UPDATE_ENABLED=true # 是否启用批量更新 (Whether to enable batch update)
|
||||
- NODE_NAME=new-api-node-1 # 节点名称,用于审计日志中标识节点身份;多节点/容器部署时建议设置 (Node name used in audit logs; recommended when running multiple instances or in containers)
|
||||
# - STREAMING_TIMEOUT=300 # 流模式无响应超时时间,单位秒,默认120秒,如果出现空补全可以尝试改为更大值 (Streaming timeout in seconds, default is 120s. Increase if experiencing empty completions)
|
||||
# - RELAY_IDLE_CONN_TIMEOUT=90 # Relay HTTP 客户端空闲连接超时时间,单位秒,默认跟随 Go 标准库,设置为0表示不限制 (Relay HTTP client idle keep-alive timeout in seconds, defaults to Go standard library; set 0 to disable)
|
||||
# - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!! (multi-node deployment, set this to a random string!!!!!!!)
|
||||
# - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed
|
||||
# - GOOGLE_ANALYTICS_ID=G-XXXXXXXXXX # Google Analytics 的测量 ID (Google Analytics Measurement ID)
|
||||
|
||||
+12
-2
@@ -213,12 +213,22 @@ func (r *GeneralOpenAIRequest) ToMap() map[string]any {
|
||||
return result
|
||||
}
|
||||
|
||||
func IsOpenAIReasoningOModel(modelName string) bool {
|
||||
return strings.HasPrefix(modelName, "o1") ||
|
||||
strings.HasPrefix(modelName, "o3") ||
|
||||
strings.HasPrefix(modelName, "o4")
|
||||
}
|
||||
|
||||
func IsOpenAIGPT5Model(modelName string) bool {
|
||||
return strings.HasPrefix(modelName, "gpt-5")
|
||||
}
|
||||
|
||||
func (r *GeneralOpenAIRequest) GetSystemRoleName() string {
|
||||
if strings.HasPrefix(r.Model, "o") {
|
||||
if IsOpenAIReasoningOModel(r.Model) {
|
||||
if !strings.HasPrefix(r.Model, "o1-mini") && !strings.HasPrefix(r.Model, "o1-preview") {
|
||||
return "developer"
|
||||
}
|
||||
} else if strings.HasPrefix(r.Model, "gpt-5") {
|
||||
} else if IsOpenAIGPT5Model(r.Model) {
|
||||
return "developer"
|
||||
}
|
||||
return "system"
|
||||
|
||||
@@ -71,3 +71,27 @@ func TestOpenAIResponsesRequestPreserveExplicitZeroValues(t *testing.T) {
|
||||
require.True(t, gjson.GetBytes(encoded, "stream").Exists())
|
||||
require.True(t, gjson.GetBytes(encoded, "top_p").Exists())
|
||||
}
|
||||
|
||||
func TestGeneralOpenAIRequestGetSystemRoleName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
want string
|
||||
}{
|
||||
{name: "o1 uses developer", model: "o1", want: "developer"},
|
||||
{name: "o3 family uses developer", model: "o3-mini-high", want: "developer"},
|
||||
{name: "o4 family uses developer", model: "o4-mini", want: "developer"},
|
||||
{name: "o1 mini stays system", model: "o1-mini", want: "system"},
|
||||
{name: "o1 preview stays system", model: "o1-preview", want: "system"},
|
||||
{name: "gpt 5 uses developer", model: "gpt-5", want: "developer"},
|
||||
{name: "omni is not o series", model: "omni-moderation-latest", want: "system"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := GeneralOpenAIRequest{Model: tt.model}
|
||||
|
||||
require.Equal(t, tt.want, req.GetSystemRoleName())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -60,6 +60,8 @@ require (
|
||||
gorm.io/gorm v1.25.2
|
||||
)
|
||||
|
||||
require github.com/waffo-com/waffo-pancake-sdk-go v0.3.1
|
||||
|
||||
require (
|
||||
github.com/DmitriyVTitov/size v1.5.0 // indirect
|
||||
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 // indirect
|
||||
|
||||
@@ -308,6 +308,12 @@ github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65E
|
||||
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||
github.com/waffo-com/waffo-go v1.3.1 h1:NCYD3oQ59DTJj1bwS5T/659LI4h8PuAIW4Qj/w7fKPw=
|
||||
github.com/waffo-com/waffo-go v1.3.1/go.mod h1:IaXVYq6mmYtrLFFsLxPslNwuIZx0mIadWWjhe+eWb0g=
|
||||
github.com/waffo-com/waffo-pancake-sdk-go v0.1.1 h1:YOI7+3zTBlTB7Ou6+ZXnJV2JvW/ag9d7CwE/TxH3Hls=
|
||||
github.com/waffo-com/waffo-pancake-sdk-go v0.1.1/go.mod h1:5MBCGH/nqRRA5sHO/lQB/96r4BTAqy8QpWxn53m9htI=
|
||||
github.com/waffo-com/waffo-pancake-sdk-go v0.2.0 h1:cCSgccM66p7feTtgRqUUGT50tYQOhahsoPXavd+ib1U=
|
||||
github.com/waffo-com/waffo-pancake-sdk-go v0.2.0/go.mod h1:5MBCGH/nqRRA5sHO/lQB/96r4BTAqy8QpWxn53m9htI=
|
||||
github.com/waffo-com/waffo-pancake-sdk-go v0.3.1 h1:ngQSN/oVB35xTwFPLfg++bxPC+SptcF145Mb6c62YCc=
|
||||
github.com/waffo-com/waffo-pancake-sdk-go v0.3.1/go.mod h1:OB2MyFIQaefoPO0FV3J+yu9sDP8RVFQ+sbFsXqGuObc=
|
||||
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
|
||||
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
|
||||
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
FRONTEND_DIR = ./web/default
|
||||
FRONTEND_CLASSIC_DIR = ./web/classic
|
||||
BACKEND_DIR = .
|
||||
DEV_FRONTEND_DEFAULT_PORT ?= 5173
|
||||
DEV_FRONTEND_CLASSIC_PORT ?= 5174
|
||||
DEV_COMPOSE_FILE = docker-compose.dev.yml
|
||||
DEV_POSTGRES_SERVICE = postgres
|
||||
DEV_BACKEND_SERVICE = new-api
|
||||
@@ -14,11 +16,13 @@ all: build-all-frontends start-backend
|
||||
|
||||
build-frontend:
|
||||
@echo "Building default frontend..."
|
||||
@cd $(FRONTEND_DIR) && bun install && DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat ../../VERSION) bun run build
|
||||
@cd ./web && bun install --frozen-lockfile
|
||||
@cd $(FRONTEND_DIR) && DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat ../../VERSION) bun run build
|
||||
|
||||
build-frontend-classic:
|
||||
@echo "Building classic frontend..."
|
||||
@cd $(FRONTEND_CLASSIC_DIR) && bun install && VITE_REACT_APP_VERSION=$(cat ../../VERSION) bun run build
|
||||
@cd ./web && bun install --frozen-lockfile
|
||||
@cd $(FRONTEND_CLASSIC_DIR) && VITE_REACT_APP_VERSION=$(cat ../../VERSION) bun run build
|
||||
|
||||
build-all-frontends: build-frontend build-frontend-classic
|
||||
|
||||
@@ -35,12 +39,35 @@ dev-api-rebuild:
|
||||
@docker compose -f $(DEV_COMPOSE_FILE) up -d --build $(DEV_BACKEND_SERVICE)
|
||||
|
||||
dev-web:
|
||||
@echo "Starting frontend dev server..."
|
||||
@cd $(FRONTEND_DIR) && bun install && bun run dev
|
||||
@echo "Starting both frontend dev servers..."
|
||||
@echo "Default frontend: http://localhost:$(DEV_FRONTEND_DEFAULT_PORT)"
|
||||
@echo "Classic frontend: http://localhost:$(DEV_FRONTEND_CLASSIC_PORT)"
|
||||
@cd ./web && bun install
|
||||
@(cd $(FRONTEND_DIR) && bun run dev -- --host 0.0.0.0 --port $(DEV_FRONTEND_DEFAULT_PORT)) & \
|
||||
default_pid=$$!; \
|
||||
(cd $(FRONTEND_CLASSIC_DIR) && bun run dev -- --host 0.0.0.0 --port $(DEV_FRONTEND_CLASSIC_PORT)) & \
|
||||
classic_pid=$$!; \
|
||||
trap 'kill $$default_pid $$classic_pid 2>/dev/null; wait $$default_pid $$classic_pid 2>/dev/null; exit 130' INT TERM; \
|
||||
while kill -0 $$default_pid 2>/dev/null && kill -0 $$classic_pid 2>/dev/null; do \
|
||||
sleep 1; \
|
||||
done; \
|
||||
if ! kill -0 $$default_pid 2>/dev/null; then \
|
||||
wait $$default_pid; \
|
||||
status=$$?; \
|
||||
kill $$classic_pid 2>/dev/null; \
|
||||
wait $$classic_pid 2>/dev/null; \
|
||||
exit $$status; \
|
||||
fi; \
|
||||
wait $$classic_pid; \
|
||||
status=$$?; \
|
||||
kill $$default_pid 2>/dev/null; \
|
||||
wait $$default_pid 2>/dev/null; \
|
||||
exit $$status
|
||||
|
||||
dev-web-classic:
|
||||
@echo "Starting classic frontend dev server..."
|
||||
@cd $(FRONTEND_CLASSIC_DIR) && bun install && bun run dev
|
||||
@cd ./web && bun install
|
||||
@cd $(FRONTEND_CLASSIC_DIR) && bun run dev -- --host 0.0.0.0 --port $(DEV_FRONTEND_CLASSIC_PORT)
|
||||
|
||||
dev: dev-api dev-web
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package middleware
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strconv"
|
||||
@@ -20,6 +21,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type ModelRequest struct {
|
||||
@@ -100,14 +102,10 @@ func Distribute() func(c *gin.Context) {
|
||||
}
|
||||
|
||||
if preferredChannelID, found := service.GetPreferredChannelByAffinity(c, modelRequest.Model, usingGroup); found {
|
||||
affinityUsable := false
|
||||
preferred, err := model.CacheGetChannel(preferredChannelID)
|
||||
if err == nil && preferred != nil {
|
||||
if preferred.Status != common.ChannelStatusEnabled {
|
||||
if service.ShouldSkipRetryAfterChannelAffinityFailure(c) {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, i18n.T(c, i18n.MsgDistributorAffinityChannelDisabled))
|
||||
return
|
||||
}
|
||||
} else if usingGroup == "auto" {
|
||||
if err == nil && preferred != nil && preferred.Status == common.ChannelStatusEnabled {
|
||||
if usingGroup == "auto" {
|
||||
userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
|
||||
autoGroups := service.GetUserAutoGroup(userGroup)
|
||||
for _, g := range autoGroups {
|
||||
@@ -115,6 +113,7 @@ func Distribute() func(c *gin.Context) {
|
||||
selectGroup = g
|
||||
common.SetContextKey(c, constant.ContextKeyAutoGroup, g)
|
||||
channel = preferred
|
||||
affinityUsable = true
|
||||
service.MarkChannelAffinityUsed(c, g, preferred.Id)
|
||||
break
|
||||
}
|
||||
@@ -122,9 +121,13 @@ func Distribute() func(c *gin.Context) {
|
||||
} else if model.IsChannelEnabledForGroupModel(usingGroup, modelRequest.Model, preferred.Id) {
|
||||
channel = preferred
|
||||
selectGroup = usingGroup
|
||||
affinityUsable = true
|
||||
service.MarkChannelAffinityUsed(c, usingGroup, preferred.Id)
|
||||
}
|
||||
}
|
||||
if !affinityUsable && !service.ShouldKeepChannelAffinityOnChannelDisabled() {
|
||||
service.ClearCurrentChannelAffinityCache(c)
|
||||
}
|
||||
}
|
||||
|
||||
if channel == nil {
|
||||
@@ -170,6 +173,14 @@ func Distribute() func(c *gin.Context) {
|
||||
// - application/x-www-form-urlencoded
|
||||
// - multipart/form-data
|
||||
func getModelFromRequest(c *gin.Context) (*ModelRequest, error) {
|
||||
if strings.HasPrefix(c.Request.Header.Get("Content-Type"), "application/json") {
|
||||
modelRequest, err := getModelFromJSONBody(c)
|
||||
if err != nil {
|
||||
return nil, errors.New(i18n.T(c, i18n.MsgDistributorInvalidRequest, map[string]any{"Error": err.Error()}))
|
||||
}
|
||||
return modelRequest, nil
|
||||
}
|
||||
|
||||
var modelRequest ModelRequest
|
||||
err := common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
if err != nil {
|
||||
@@ -178,6 +189,50 @@ func getModelFromRequest(c *gin.Context) (*ModelRequest, error) {
|
||||
return &modelRequest, nil
|
||||
}
|
||||
|
||||
func getModelFromJSONBody(c *gin.Context) (*ModelRequest, error) {
|
||||
storage, err := common.GetBodyStorage(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
requestBody, err := storage.Bytes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !gjson.ValidBytes(requestBody) {
|
||||
return nil, errors.New("invalid JSON request body")
|
||||
}
|
||||
|
||||
values := gjson.GetManyBytes(requestBody, "model", "group")
|
||||
model, err := getJSONStringValue(values[0], "model")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
group, err := getJSONStringValue(values[1], "group")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, seekErr := storage.Seek(0, io.SeekStart); seekErr != nil {
|
||||
return nil, seekErr
|
||||
}
|
||||
c.Request.Body = io.NopCloser(storage)
|
||||
|
||||
return &ModelRequest{
|
||||
Model: model,
|
||||
Group: group,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func getJSONStringValue(result gjson.Result, field string) (string, error) {
|
||||
if !result.Exists() || result.Type == gjson.Null {
|
||||
return "", nil
|
||||
}
|
||||
if result.Type != gjson.String {
|
||||
return "", fmt.Errorf("field %s must be a string", field)
|
||||
}
|
||||
return result.String(), nil
|
||||
}
|
||||
|
||||
func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||
var modelRequest ModelRequest
|
||||
shouldSelectChannel := true
|
||||
@@ -244,6 +299,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||
} else if c.Request.Method == http.MethodGet {
|
||||
relayMode = relayconstant.RelayModeVideoFetchByID
|
||||
shouldSelectChannel = false
|
||||
modelRequest.Model = getTaskOriginModelName(c)
|
||||
}
|
||||
c.Set("relay_mode", relayMode)
|
||||
} else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") {
|
||||
@@ -258,6 +314,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||
} else if c.Request.Method == http.MethodGet {
|
||||
relayMode = relayconstant.RelayModeVideoFetchByID
|
||||
shouldSelectChannel = false
|
||||
modelRequest.Model = getTaskOriginModelName(c)
|
||||
}
|
||||
if _, ok := c.Get("relay_mode"); !ok {
|
||||
c.Set("relay_mode", relayMode)
|
||||
@@ -342,6 +399,31 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||
return &modelRequest, shouldSelectChannel, nil
|
||||
}
|
||||
|
||||
// 修复 #4834: GET /v1/video/generations/:task_id && /v1/video/:task_id 此前不解析 model,
|
||||
// 当 token 启用「可用模型限制」时,下游 modelLimitEnable 校验会因
|
||||
// modelRequest.Model 为空而误报 "This token has no access to model"。
|
||||
// 从已存储的任务记录中回填 OriginModelName 即可让校验走在正确的模型上。
|
||||
func getTaskOriginModelName(c *gin.Context) string {
|
||||
if !common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled) {
|
||||
return ""
|
||||
}
|
||||
|
||||
taskId := c.Param("task_id")
|
||||
if taskId == "" {
|
||||
// jimeng adapter
|
||||
taskId = c.GetString("task_id")
|
||||
}
|
||||
if taskId == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
userId := c.GetInt("id")
|
||||
if task, exist, err := model.GetByTaskId(userId, taskId); err == nil && exist && task != nil {
|
||||
return task.Properties.OriginModelName
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) *types.NewAPIError {
|
||||
c.Set("original_model", modelName) // for retry
|
||||
if channel == nil {
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func AnonymousRequestBodyLimit() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
maxBytes := common.GetAnonymousRequestBodyLimitBytes()
|
||||
if maxBytes <= 0 || c.Request.Body == nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
originalBody := c.Request.Body
|
||||
limitedBody, err := readAnonymousRequestBody(originalBody, maxBytes)
|
||||
_ = originalBody.Close()
|
||||
if err != nil {
|
||||
if common.IsRequestBodyTooLargeError(err) {
|
||||
c.AbortWithStatus(http.StatusRequestEntityTooLarge)
|
||||
return
|
||||
}
|
||||
c.AbortWithStatus(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(limitedBody))
|
||||
c.Request.ContentLength = int64(len(limitedBody))
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func readAnonymousRequestBody(body io.Reader, maxBytes int64) ([]byte, error) {
|
||||
data, err := io.ReadAll(io.LimitReader(body, maxBytes+1))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if int64(len(data)) > maxBytes {
|
||||
return nil, common.ErrRequestBodyTooLarge
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
+33
-2
@@ -643,13 +643,25 @@ func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int, reason
|
||||
if len(keys) == 0 {
|
||||
channel.Status = status
|
||||
} else {
|
||||
var keyIndex int
|
||||
keyIndex := -1
|
||||
for i, key := range keys {
|
||||
if key == usingKey {
|
||||
keyIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if keyIndex < 0 {
|
||||
if usingKey != "" {
|
||||
common.SysLog(fmt.Sprintf("failed to update multi-key status: channel_id=%d, using key not found", channel.Id))
|
||||
return
|
||||
}
|
||||
channel.Status = status
|
||||
info := channel.GetOtherInfo()
|
||||
info["status_reason"] = reason
|
||||
info["status_time"] = common.GetTimestamp()
|
||||
channel.SetOtherInfo(info)
|
||||
return
|
||||
}
|
||||
if channel.ChannelInfo.MultiKeyStatusList == nil {
|
||||
channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
|
||||
}
|
||||
@@ -666,16 +678,31 @@ func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int, reason
|
||||
channel.ChannelInfo.MultiKeyDisabledReason[keyIndex] = reason
|
||||
channel.ChannelInfo.MultiKeyDisabledTime[keyIndex] = common.GetTimestamp()
|
||||
}
|
||||
if len(channel.ChannelInfo.MultiKeyStatusList) >= channel.ChannelInfo.MultiKeySize {
|
||||
if !hasEnabledMultiKey(keys, channel.ChannelInfo.MultiKeyStatusList) {
|
||||
channel.Status = common.ChannelStatusAutoDisabled
|
||||
info := channel.GetOtherInfo()
|
||||
info["status_reason"] = "All keys are disabled"
|
||||
info["status_time"] = common.GetTimestamp()
|
||||
channel.SetOtherInfo(info)
|
||||
} else if status == common.ChannelStatusEnabled {
|
||||
channel.Status = common.ChannelStatusEnabled
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func hasEnabledMultiKey(keys []string, statusList map[int]int) bool {
|
||||
for i := range keys {
|
||||
if statusList == nil {
|
||||
return true
|
||||
}
|
||||
status, ok := statusList[i]
|
||||
if !ok || status == common.ChannelStatusEnabled {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func UpdateChannelStatus(channelId int, usingKey string, status int, reason string) bool {
|
||||
if common.MemoryCacheEnabled {
|
||||
channelStatusLock.Lock()
|
||||
@@ -687,11 +714,15 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
|
||||
}
|
||||
if channelCache.ChannelInfo.IsMultiKey {
|
||||
// Use per-channel lock to prevent concurrent map read/write with GetNextEnabledKey
|
||||
beforeStatus := channelCache.Status
|
||||
pollingLock := GetChannelPollingLock(channelId)
|
||||
pollingLock.Lock()
|
||||
// 如果是多Key模式,更新缓存中的状态
|
||||
handlerMultiKeyUpdate(channelCache, usingKey, status, reason)
|
||||
pollingLock.Unlock()
|
||||
if beforeStatus != channelCache.Status {
|
||||
CacheUpdateChannelStatus(channelId, channelCache.Status)
|
||||
}
|
||||
//CacheUpdateChannel(channelCache)
|
||||
//return true
|
||||
} else {
|
||||
|
||||
+65
-49
@@ -17,25 +17,39 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func applyExplicitLogTextFilter(tx *gorm.DB, column string, value string) (*gorm.DB, error) {
|
||||
if value == "" {
|
||||
return tx, nil
|
||||
}
|
||||
if strings.Contains(value, "%") {
|
||||
pattern, err := sanitizeLikePattern(value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tx.Where(column+" LIKE ? ESCAPE '!'", pattern), nil
|
||||
}
|
||||
return tx.Where(column+" = ?", value), nil
|
||||
}
|
||||
|
||||
type Log struct {
|
||||
Id int `json:"id" gorm:"index:idx_created_at_id,priority:1;index:idx_user_id_id,priority:2"`
|
||||
UserId int `json:"user_id" gorm:"index;index:idx_user_id_id,priority:1"`
|
||||
CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:2;index:idx_created_at_type"`
|
||||
Type int `json:"type" gorm:"index:idx_created_at_type"`
|
||||
Content string `json:"content"`
|
||||
Username string `json:"username" gorm:"index;index:index_username_model_name,priority:2;default:''"`
|
||||
TokenName string `json:"token_name" gorm:"index;default:''"`
|
||||
ModelName string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"`
|
||||
Quota int `json:"quota" gorm:"default:0"`
|
||||
PromptTokens int `json:"prompt_tokens" gorm:"default:0"`
|
||||
CompletionTokens int `json:"completion_tokens" gorm:"default:0"`
|
||||
UseTime int `json:"use_time" gorm:"default:0"`
|
||||
IsStream bool `json:"is_stream"`
|
||||
ChannelId int `json:"channel" gorm:"index"`
|
||||
ChannelName string `json:"channel_name" gorm:"->"`
|
||||
TokenId int `json:"token_id" gorm:"default:0;index"`
|
||||
Group string `json:"group" gorm:"index"`
|
||||
Ip string `json:"ip" gorm:"index;default:''"`
|
||||
Id int `json:"id" gorm:"index:idx_created_at_id,priority:2;index:idx_user_id_id,priority:2"`
|
||||
UserId int `json:"user_id" gorm:"index;index:idx_user_id_id,priority:1"`
|
||||
CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:1;index:idx_created_at_type"`
|
||||
Type int `json:"type" gorm:"index:idx_created_at_type"`
|
||||
Content string `json:"content"`
|
||||
Username string `json:"username" gorm:"index;index:index_username_model_name,priority:2;default:''"`
|
||||
TokenName string `json:"token_name" gorm:"index;default:''"`
|
||||
ModelName string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"`
|
||||
Quota int `json:"quota" gorm:"default:0"`
|
||||
PromptTokens int `json:"prompt_tokens" gorm:"default:0"`
|
||||
CompletionTokens int `json:"completion_tokens" gorm:"default:0"`
|
||||
UseTime int `json:"use_time" gorm:"default:0"`
|
||||
IsStream bool `json:"is_stream"`
|
||||
ChannelId int `json:"channel" gorm:"index"`
|
||||
ChannelName string `json:"channel_name" gorm:"->"`
|
||||
TokenId int `json:"token_id" gorm:"default:0;index"`
|
||||
Group string `json:"group" gorm:"index"`
|
||||
Ip string `json:"ip" gorm:"index;default:''"`
|
||||
RequestId string `json:"request_id,omitempty" gorm:"type:varchar(64);index:idx_logs_request_id;default:''"`
|
||||
UpstreamRequestId string `json:"upstream_request_id,omitempty" gorm:"type:varchar(128);index:idx_logs_upstream_request_id;default:''"`
|
||||
Other string `json:"other"`
|
||||
@@ -146,7 +160,7 @@ func RecordTopupLog(userId int, content string, callerIp string, paymentMethod s
|
||||
|
||||
func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string, tokenName string, content string, tokenId int, useTimeSeconds int,
|
||||
isStream bool, group string, other map[string]interface{}) {
|
||||
logger.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content))
|
||||
logger.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, common.LocalLogPreview(content)))
|
||||
username := c.GetString("username")
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
upstreamRequestId := c.GetString(common.UpstreamRequestIdKey)
|
||||
@@ -309,9 +323,15 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
|
||||
tx = LOG_DB.Where("logs.type = ?", logType)
|
||||
}
|
||||
|
||||
tx = applyLogContainsFilter(tx, "logs.model_name", modelName)
|
||||
tx = applyLogContainsFilter(tx, "logs.username", username)
|
||||
tx = applyLogContainsFilter(tx, "logs.token_name", tokenName)
|
||||
if tx, err = applyExplicitLogTextFilter(tx, "logs.model_name", modelName); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if tx, err = applyExplicitLogTextFilter(tx, "logs.username", username); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if tokenName != "" {
|
||||
tx = tx.Where("logs.token_name = ?", tokenName)
|
||||
}
|
||||
if requestId != "" {
|
||||
tx = tx.Where("logs.request_id = ?", requestId)
|
||||
}
|
||||
@@ -334,7 +354,7 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
err = tx.Order("logs.id desc").Limit(num).Offset(startIdx).Find(&logs).Error
|
||||
err = tx.Order("logs.created_at desc, logs.id desc").Limit(num).Offset(startIdx).Find(&logs).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
@@ -392,8 +412,12 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
|
||||
tx = LOG_DB.Where("logs.user_id = ? and logs.type = ?", userId, logType)
|
||||
}
|
||||
|
||||
tx = applyLogContainsFilter(tx, "logs.model_name", modelName)
|
||||
tx = applyLogContainsFilter(tx, "logs.token_name", tokenName)
|
||||
if tx, err = applyExplicitLogTextFilter(tx, "logs.model_name", modelName); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if tokenName != "" {
|
||||
tx = tx.Where("logs.token_name = ?", tokenName)
|
||||
}
|
||||
if requestId != "" {
|
||||
tx = tx.Where("logs.request_id = ?", requestId)
|
||||
}
|
||||
@@ -430,42 +454,34 @@ type Stat struct {
|
||||
Tpm int `json:"tpm"`
|
||||
}
|
||||
|
||||
func logContainsPattern(input string) (string, bool) {
|
||||
input = strings.TrimSpace(input)
|
||||
if input == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
replacer := strings.NewReplacer("!", "!!", "%", "!%", "_", "!_")
|
||||
return "%" + replacer.Replace(input) + "%", true
|
||||
}
|
||||
|
||||
func applyLogContainsFilter(tx *gorm.DB, column string, value string) *gorm.DB {
|
||||
pattern, ok := logContainsPattern(value)
|
||||
if !ok {
|
||||
return tx
|
||||
}
|
||||
return tx.Where(column+" LIKE ? ESCAPE '!'", pattern)
|
||||
}
|
||||
|
||||
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int, group string) (stat Stat, err error) {
|
||||
tx := LOG_DB.Table("logs").Select("sum(quota) quota")
|
||||
|
||||
// 为rpm和tpm创建单独的查询
|
||||
rpmTpmQuery := LOG_DB.Table("logs").Select("count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm")
|
||||
|
||||
tx = applyLogContainsFilter(tx, "username", username)
|
||||
rpmTpmQuery = applyLogContainsFilter(rpmTpmQuery, "username", username)
|
||||
tx = applyLogContainsFilter(tx, "token_name", tokenName)
|
||||
rpmTpmQuery = applyLogContainsFilter(rpmTpmQuery, "token_name", tokenName)
|
||||
if tx, err = applyExplicitLogTextFilter(tx, "username", username); err != nil {
|
||||
return stat, err
|
||||
}
|
||||
if rpmTpmQuery, err = applyExplicitLogTextFilter(rpmTpmQuery, "username", username); err != nil {
|
||||
return stat, err
|
||||
}
|
||||
if tokenName != "" {
|
||||
tx = tx.Where("token_name = ?", tokenName)
|
||||
rpmTpmQuery = rpmTpmQuery.Where("token_name = ?", tokenName)
|
||||
}
|
||||
if startTimestamp != 0 {
|
||||
tx = tx.Where("created_at >= ?", startTimestamp)
|
||||
}
|
||||
if endTimestamp != 0 {
|
||||
tx = tx.Where("created_at <= ?", endTimestamp)
|
||||
}
|
||||
tx = applyLogContainsFilter(tx, "model_name", modelName)
|
||||
rpmTpmQuery = applyLogContainsFilter(rpmTpmQuery, "model_name", modelName)
|
||||
if tx, err = applyExplicitLogTextFilter(tx, "model_name", modelName); err != nil {
|
||||
return stat, err
|
||||
}
|
||||
if rpmTpmQuery, err = applyExplicitLogTextFilter(rpmTpmQuery, "model_name", modelName); err != nil {
|
||||
return stat, err
|
||||
}
|
||||
if channel != 0 {
|
||||
tx = tx.Where("channel_id = ?", channel)
|
||||
rpmTpmQuery = rpmTpmQuery.Where("channel_id = ?", channel)
|
||||
|
||||
@@ -397,8 +397,10 @@ func ensureSubscriptionPlanTableSQLite() error {
|
||||
` + "`custom_seconds`" + ` bigint NOT NULL DEFAULT 0,
|
||||
` + "`enabled`" + ` numeric DEFAULT 1,
|
||||
` + "`sort_order`" + ` integer DEFAULT 0,
|
||||
` + "`allow_balance_pay`" + ` numeric DEFAULT 1,
|
||||
` + "`stripe_price_id`" + ` varchar(128) DEFAULT '',
|
||||
` + "`creem_product_id`" + ` varchar(128) DEFAULT '',
|
||||
` + "`waffo_pancake_product_id`" + ` varchar(128) DEFAULT '',
|
||||
` + "`max_purchase_per_user`" + ` integer DEFAULT 0,
|
||||
` + "`upgrade_group`" + ` varchar(64) DEFAULT '',
|
||||
` + "`total_amount`" + ` bigint NOT NULL DEFAULT 0,
|
||||
@@ -430,8 +432,10 @@ PRIMARY KEY (` + "`id`" + `)
|
||||
{Name: "custom_seconds", DDL: "`custom_seconds` bigint NOT NULL DEFAULT 0"},
|
||||
{Name: "enabled", DDL: "`enabled` numeric DEFAULT 1"},
|
||||
{Name: "sort_order", DDL: "`sort_order` integer DEFAULT 0"},
|
||||
{Name: "allow_balance_pay", DDL: "`allow_balance_pay` numeric DEFAULT 1"},
|
||||
{Name: "stripe_price_id", DDL: "`stripe_price_id` varchar(128) DEFAULT ''"},
|
||||
{Name: "creem_product_id", DDL: "`creem_product_id` varchar(128) DEFAULT ''"},
|
||||
{Name: "waffo_pancake_product_id", DDL: "`waffo_pancake_product_id` varchar(128) DEFAULT ''"},
|
||||
{Name: "max_purchase_per_user", DDL: "`max_purchase_per_user` integer DEFAULT 0"},
|
||||
{Name: "upgrade_group", DDL: "`upgrade_group` varchar(64) DEFAULT ''"},
|
||||
{Name: "total_amount", DDL: "`total_amount` bigint NOT NULL DEFAULT 0"},
|
||||
|
||||
@@ -2,6 +2,7 @@ package model
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
|
||||
@@ -135,6 +136,62 @@ func GetBoundChannelsByModelsMap(modelNames []string) (map[string][]BoundChannel
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func normalizeLookupValues(values []string) []string {
|
||||
seen := make(map[string]struct{}, len(values))
|
||||
normalized := make([]string, 0, len(values))
|
||||
for _, value := range values {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[value]; ok {
|
||||
continue
|
||||
}
|
||||
seen[value] = struct{}{}
|
||||
normalized = append(normalized, value)
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
func GetPreferredModelOwnerChannelTypes(modelNames []string, groups []string) (map[string]int, error) {
|
||||
result := make(map[string]int)
|
||||
modelNames = normalizeLookupValues(modelNames)
|
||||
if len(modelNames) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
type row struct {
|
||||
Model string
|
||||
ChannelType int
|
||||
}
|
||||
var rows []row
|
||||
|
||||
query := DB.Table("abilities").
|
||||
Select("abilities.model as model, channels.type as channel_type").
|
||||
Joins("JOIN channels ON abilities.channel_id = channels.id").
|
||||
Where("abilities.model IN ? AND abilities.enabled = ? AND channels.status = ?", modelNames, true, common.ChannelStatusEnabled).
|
||||
Order("COALESCE(abilities.priority, 0) DESC").
|
||||
Order("abilities.weight DESC").
|
||||
Order("abilities.channel_id ASC")
|
||||
|
||||
groups = normalizeLookupValues(groups)
|
||||
if len(groups) > 0 {
|
||||
query = query.Where("abilities."+commonGroupCol+" IN ?", groups)
|
||||
}
|
||||
|
||||
if err := query.Scan(&rows).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, r := range rows {
|
||||
if _, ok := result[r.Model]; ok {
|
||||
continue
|
||||
}
|
||||
result[r.Model] = r.ChannelType
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func SearchModels(keyword string, vendor string, offset int, limit int) ([]*Model, int64, error) {
|
||||
var models []*Model
|
||||
db := DB.Model(&Model{})
|
||||
|
||||
@@ -0,0 +1,141 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func clearPreferredOwnerTables(t *testing.T) {
|
||||
t.Helper()
|
||||
require.NoError(t, DB.Exec("DELETE FROM abilities").Error)
|
||||
require.NoError(t, DB.Exec("DELETE FROM channels").Error)
|
||||
}
|
||||
|
||||
func insertPreferredOwnerCandidate(
|
||||
t *testing.T,
|
||||
channelID int,
|
||||
modelName string,
|
||||
group string,
|
||||
channelType int,
|
||||
priority int64,
|
||||
weight uint,
|
||||
channelStatus int,
|
||||
abilityEnabled bool,
|
||||
) {
|
||||
t.Helper()
|
||||
require.NoError(t, DB.Create(&Channel{
|
||||
Id: channelID,
|
||||
Type: channelType,
|
||||
Key: fmt.Sprintf("key-%d", channelID),
|
||||
Status: channelStatus,
|
||||
Name: fmt.Sprintf("channel-%d", channelID),
|
||||
}).Error)
|
||||
require.NoError(t, DB.Create(&Ability{
|
||||
Group: group,
|
||||
Model: modelName,
|
||||
ChannelId: channelID,
|
||||
Enabled: abilityEnabled,
|
||||
Priority: &priority,
|
||||
Weight: weight,
|
||||
}).Error)
|
||||
}
|
||||
|
||||
func TestGetPreferredModelOwnerChannelTypes(t *testing.T) {
|
||||
const modelName = "gpt-5.4"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(t *testing.T)
|
||||
groups []string
|
||||
expected int
|
||||
found bool
|
||||
}{
|
||||
{
|
||||
name: "openai only",
|
||||
setup: func(t *testing.T) {
|
||||
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeOpenAI, 0, 0, common.ChannelStatusEnabled, true)
|
||||
},
|
||||
groups: []string{"default"},
|
||||
expected: constant.ChannelTypeOpenAI,
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "codex only",
|
||||
setup: func(t *testing.T) {
|
||||
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeCodex, 0, 0, common.ChannelStatusEnabled, true)
|
||||
},
|
||||
groups: []string{"default"},
|
||||
expected: constant.ChannelTypeCodex,
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "priority wins",
|
||||
setup: func(t *testing.T) {
|
||||
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeOpenAI, 1, 100, common.ChannelStatusEnabled, true)
|
||||
insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeCodex, 2, 0, common.ChannelStatusEnabled, true)
|
||||
},
|
||||
groups: []string{"default"},
|
||||
expected: constant.ChannelTypeCodex,
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "weight wins when priority is equal",
|
||||
setup: func(t *testing.T) {
|
||||
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeOpenAI, 1, 10, common.ChannelStatusEnabled, true)
|
||||
insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeCodex, 1, 20, common.ChannelStatusEnabled, true)
|
||||
},
|
||||
groups: []string{"default"},
|
||||
expected: constant.ChannelTypeCodex,
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "channel id stabilizes exact ties",
|
||||
setup: func(t *testing.T) {
|
||||
insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeCodex, 1, 10, common.ChannelStatusEnabled, true)
|
||||
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeOpenAI, 1, 10, common.ChannelStatusEnabled, true)
|
||||
},
|
||||
groups: []string{"default"},
|
||||
expected: constant.ChannelTypeOpenAI,
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "group filter excludes other groups",
|
||||
setup: func(t *testing.T) {
|
||||
insertPreferredOwnerCandidate(t, 1, modelName, "vip", constant.ChannelTypeCodex, 10, 100, common.ChannelStatusEnabled, true)
|
||||
insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeOpenAI, 1, 0, common.ChannelStatusEnabled, true)
|
||||
},
|
||||
groups: []string{"default"},
|
||||
expected: constant.ChannelTypeOpenAI,
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "disabled candidates are ignored",
|
||||
setup: func(t *testing.T) {
|
||||
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeCodex, 10, 100, common.ChannelStatusEnabled, false)
|
||||
insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeOpenAI, 1, 0, common.ChannelStatusManuallyDisabled, true)
|
||||
},
|
||||
groups: []string{"default"},
|
||||
found: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
clearPreferredOwnerTables(t)
|
||||
tt.setup(t)
|
||||
|
||||
owners, err := GetPreferredModelOwnerChannelTypes([]string{modelName}, tt.groups)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, ok := owners[modelName]
|
||||
require.Equal(t, tt.found, ok)
|
||||
if tt.found {
|
||||
require.Equal(t, tt.expected, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+38
-19
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/setting/performance_setting"
|
||||
"github.com/QuantumNous/new-api/setting/ratio_setting"
|
||||
"github.com/QuantumNous/new-api/setting/system_setting"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Option struct {
|
||||
@@ -106,18 +107,13 @@ func InitOptionMap() {
|
||||
common.OptionMap["WaffoUnitPrice"] = strconv.FormatFloat(setting.WaffoUnitPrice, 'f', -1, 64)
|
||||
common.OptionMap["WaffoMinTopUp"] = strconv.Itoa(setting.WaffoMinTopUp)
|
||||
common.OptionMap["WaffoPayMethods"] = setting.WaffoPayMethods2JsonString()
|
||||
common.OptionMap["WaffoPancakeEnabled"] = strconv.FormatBool(setting.WaffoPancakeEnabled)
|
||||
common.OptionMap["WaffoPancakeSandbox"] = strconv.FormatBool(setting.WaffoPancakeSandbox)
|
||||
common.OptionMap["WaffoPancakeMerchantID"] = setting.WaffoPancakeMerchantID
|
||||
common.OptionMap["WaffoPancakePrivateKey"] = setting.WaffoPancakePrivateKey
|
||||
common.OptionMap["WaffoPancakeWebhookPublicKey"] = setting.WaffoPancakeWebhookPublicKey
|
||||
common.OptionMap["WaffoPancakeWebhookTestKey"] = setting.WaffoPancakeWebhookTestKey
|
||||
common.OptionMap["WaffoPancakeStoreID"] = setting.WaffoPancakeStoreID
|
||||
common.OptionMap["WaffoPancakeProductID"] = setting.WaffoPancakeProductID
|
||||
common.OptionMap["WaffoPancakeReturnURL"] = setting.WaffoPancakeReturnURL
|
||||
common.OptionMap["WaffoPancakeCurrency"] = setting.WaffoPancakeCurrency
|
||||
common.OptionMap["WaffoPancakeUnitPrice"] = strconv.FormatFloat(setting.WaffoPancakeUnitPrice, 'f', -1, 64)
|
||||
common.OptionMap["WaffoPancakeMinTopUp"] = strconv.Itoa(setting.WaffoPancakeMinTopUp)
|
||||
common.OptionMap["WaffoPancakeStoreID"] = setting.WaffoPancakeStoreID
|
||||
common.OptionMap["WaffoPancakeProductID"] = setting.WaffoPancakeProductID
|
||||
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
|
||||
common.OptionMap["Chats"] = setting.Chats2JsonString()
|
||||
common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString()
|
||||
@@ -222,6 +218,39 @@ func UpdateOption(key string, value string) error {
|
||||
return updateOptionMap(key, value)
|
||||
}
|
||||
|
||||
// UpdateOptionsBulk persists multiple key/value pairs in a single database
|
||||
// transaction, then dispatches them through updateOptionMap in one pass. If
|
||||
// any DB write fails the whole transaction rolls back and no in-memory state
|
||||
// is touched — safe for callers that must commit a set of related options
|
||||
// atomically (e.g. payment gateway binding).
|
||||
func UpdateOptionsBulk(values map[string]string) error {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
err := DB.Transaction(func(tx *gorm.DB) error {
|
||||
for k, v := range values {
|
||||
option := Option{Key: k}
|
||||
if err := tx.FirstOrCreate(&option, Option{Key: k}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
option.Value = v
|
||||
if err := tx.Save(&option).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for k, v := range values {
|
||||
if err := updateOptionMap(k, v); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateOptionMap(key string, value string) (err error) {
|
||||
common.OptionMapRWMutex.Lock()
|
||||
defer common.OptionMapRWMutex.Unlock()
|
||||
@@ -419,26 +448,16 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
setting.WaffoUnitPrice, _ = strconv.ParseFloat(value, 64)
|
||||
case "WaffoMinTopUp":
|
||||
setting.WaffoMinTopUp, _ = strconv.Atoi(value)
|
||||
case "WaffoPancakeEnabled":
|
||||
setting.WaffoPancakeEnabled = value == "true"
|
||||
case "WaffoPancakeSandbox":
|
||||
setting.WaffoPancakeSandbox = value == "true"
|
||||
case "WaffoPancakeMerchantID":
|
||||
setting.WaffoPancakeMerchantID = value
|
||||
case "WaffoPancakePrivateKey":
|
||||
setting.WaffoPancakePrivateKey = value
|
||||
case "WaffoPancakeWebhookPublicKey":
|
||||
setting.WaffoPancakeWebhookPublicKey = value
|
||||
case "WaffoPancakeWebhookTestKey":
|
||||
setting.WaffoPancakeWebhookTestKey = value
|
||||
case "WaffoPancakeReturnURL":
|
||||
setting.WaffoPancakeReturnURL = value
|
||||
case "WaffoPancakeStoreID":
|
||||
setting.WaffoPancakeStoreID = value
|
||||
case "WaffoPancakeProductID":
|
||||
setting.WaffoPancakeProductID = value
|
||||
case "WaffoPancakeReturnURL":
|
||||
setting.WaffoPancakeReturnURL = value
|
||||
case "WaffoPancakeCurrency":
|
||||
setting.WaffoPancakeCurrency = value
|
||||
case "WaffoPancakeUnitPrice":
|
||||
setting.WaffoPancakeUnitPrice, _ = strconv.ParseFloat(value, 64)
|
||||
case "WaffoPancakeMinTopUp":
|
||||
|
||||
+10
-3
@@ -68,11 +68,18 @@ type PerfMetricSummary struct {
|
||||
GenerationMs int64 `json:"generation_ms"`
|
||||
}
|
||||
|
||||
func GetPerfMetricsSummaryAll(startTs int64, endTs int64) ([]PerfMetricSummary, error) {
|
||||
func GetPerfMetricsSummaryAll(startTs int64, endTs int64, groups []string) ([]PerfMetricSummary, error) {
|
||||
var summaries []PerfMetricSummary
|
||||
err := DB.Model(&PerfMetric{}).
|
||||
query := DB.Model(&PerfMetric{}).
|
||||
Select("model_name, SUM(request_count) as request_count, SUM(success_count) as success_count, SUM(total_latency_ms) as total_latency_ms, SUM(output_tokens) as output_tokens, SUM(generation_ms) as generation_ms").
|
||||
Where("bucket_ts >= ? AND bucket_ts <= ?", startTs, endTs).
|
||||
Where("bucket_ts >= ? AND bucket_ts <= ?", startTs, endTs)
|
||||
if groups != nil {
|
||||
if len(groups) == 0 {
|
||||
return summaries, nil
|
||||
}
|
||||
query = query.Where(commonGroupCol+" IN ?", groups)
|
||||
}
|
||||
err := query.
|
||||
Group("model_name").
|
||||
Having("SUM(request_count) > 0").
|
||||
Find(&summaries).Error
|
||||
|
||||
+117
-2
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/pkg/cachex"
|
||||
"github.com/samber/hot"
|
||||
"github.com/shopspring/decimal"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@@ -159,8 +160,11 @@ type SubscriptionPlan struct {
|
||||
Enabled bool `json:"enabled" gorm:"default:true"`
|
||||
SortOrder int `json:"sort_order" gorm:"type:int;default:0"`
|
||||
|
||||
StripePriceId string `json:"stripe_price_id" gorm:"type:varchar(128);default:''"`
|
||||
CreemProductId string `json:"creem_product_id" gorm:"type:varchar(128);default:''"`
|
||||
AllowBalancePay *bool `json:"allow_balance_pay" gorm:"default:true"`
|
||||
|
||||
StripePriceId string `json:"stripe_price_id" gorm:"type:varchar(128);default:''"`
|
||||
CreemProductId string `json:"creem_product_id" gorm:"type:varchar(128);default:''"`
|
||||
WaffoPancakeProductId string `json:"waffo_pancake_product_id" gorm:"type:varchar(128);default:''"`
|
||||
|
||||
// Max purchases per user (0 = unlimited)
|
||||
MaxPurchasePerUser int `json:"max_purchase_per_user" gorm:"type:int;default:0"`
|
||||
@@ -191,6 +195,12 @@ func (p *SubscriptionPlan) BeforeUpdate(tx *gorm.DB) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *SubscriptionPlan) NormalizeDefaults() {
|
||||
if p.AllowBalancePay == nil {
|
||||
p.AllowBalancePay = common.GetPointer(true)
|
||||
}
|
||||
}
|
||||
|
||||
// Subscription order (payment -> webhook -> create UserSubscription)
|
||||
type SubscriptionOrder struct {
|
||||
Id int `json:"id"`
|
||||
@@ -358,6 +368,7 @@ func getSubscriptionPlanByIdTx(tx *gorm.DB, id int) (*SubscriptionPlan, error) {
|
||||
key := subscriptionPlanCacheKey(id)
|
||||
if key != "" {
|
||||
if cached, found, err := getSubscriptionPlanCache().Get(key); err == nil && found {
|
||||
cached.NormalizeDefaults()
|
||||
return &cached, nil
|
||||
}
|
||||
}
|
||||
@@ -369,6 +380,7 @@ func getSubscriptionPlanByIdTx(tx *gorm.DB, id int) (*SubscriptionPlan, error) {
|
||||
if err := query.Where("id = ?", id).First(&plan).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
plan.NormalizeDefaults()
|
||||
_ = getSubscriptionPlanCache().SetWithTTL(key, plan, subscriptionPlanCacheTTL())
|
||||
return &plan, nil
|
||||
}
|
||||
@@ -664,6 +676,109 @@ func AdminBindSubscription(userId int, planId int, sourceNote string) (string, e
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func calcSubscriptionBalanceQuota(priceAmount float64) (int, error) {
|
||||
if priceAmount <= 0 {
|
||||
return 0, nil
|
||||
}
|
||||
if common.QuotaPerUnit <= 0 {
|
||||
return 0, errors.New("额度单位配置错误")
|
||||
}
|
||||
quota := decimal.NewFromFloat(priceAmount).
|
||||
Mul(decimal.NewFromFloat(common.QuotaPerUnit)).
|
||||
Ceil().
|
||||
IntPart()
|
||||
return int(quota), nil
|
||||
}
|
||||
|
||||
// PurchaseSubscriptionWithBalance creates a subscription by deducting the user's wallet quota.
|
||||
func PurchaseSubscriptionWithBalance(userId int, planId int) error {
|
||||
if userId <= 0 || planId <= 0 {
|
||||
return errors.New("invalid userId or planId")
|
||||
}
|
||||
|
||||
var logPlanTitle string
|
||||
var logMoney float64
|
||||
var chargedQuota int
|
||||
var upgradeGroup string
|
||||
err := DB.Transaction(func(tx *gorm.DB) error {
|
||||
plan, err := getSubscriptionPlanByIdTx(tx, planId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !plan.Enabled {
|
||||
return errors.New("套餐未启用")
|
||||
}
|
||||
if plan.PriceAmount < 0 {
|
||||
return errors.New("套餐价格不能为负数")
|
||||
}
|
||||
if plan.AllowBalancePay != nil && !*plan.AllowBalancePay {
|
||||
return errors.New("该套餐不允许使用余额兑换")
|
||||
}
|
||||
|
||||
requiredQuota, err := calcSubscriptionBalanceQuota(plan.PriceAmount)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var user User
|
||||
if err := tx.Set("gorm:query_option", "FOR UPDATE").Where("id = ?", userId).First(&user).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if requiredQuota > 0 && user.Quota < requiredQuota {
|
||||
return errors.New("余额不足")
|
||||
}
|
||||
if requiredQuota > 0 {
|
||||
if err := tx.Model(&User{}).Where("id = ?", userId).
|
||||
Update("quota", gorm.Expr("quota - ?", requiredQuota)).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := CreateUserSubscriptionFromPlanTx(tx, userId, plan, PaymentMethodBalance); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
now := common.GetTimestamp()
|
||||
tradeNo := fmt.Sprintf("SUBBALUSR%dNO%s%d", userId, common.GetRandomString(6), time.Now().UnixNano())
|
||||
order := &SubscriptionOrder{
|
||||
UserId: userId,
|
||||
PlanId: plan.Id,
|
||||
Money: plan.PriceAmount,
|
||||
TradeNo: tradeNo,
|
||||
PaymentMethod: PaymentMethodBalance,
|
||||
PaymentProvider: PaymentProviderBalance,
|
||||
Status: common.TopUpStatusSuccess,
|
||||
CreateTime: now,
|
||||
CompleteTime: now,
|
||||
ProviderPayload: fmt.Sprintf("charged_quota=%d", requiredQuota),
|
||||
}
|
||||
if err := tx.Create(order).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logPlanTitle = plan.Title
|
||||
logMoney = plan.PriceAmount
|
||||
chargedQuota = requiredQuota
|
||||
upgradeGroup = strings.TrimSpace(plan.UpgradeGroup)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if chargedQuota > 0 {
|
||||
if err := cacheDecrUserQuota(userId, int64(chargedQuota)); err != nil {
|
||||
common.SysLog("failed to decrease user quota cache after subscription balance purchase: " + err.Error())
|
||||
}
|
||||
}
|
||||
if upgradeGroup != "" {
|
||||
_ = UpdateUserGroupCache(userId, upgradeGroup)
|
||||
}
|
||||
msg := fmt.Sprintf("使用余额购买订阅成功,套餐: %s,支付金额: %.2f,扣除额度: %d", logPlanTitle, logMoney, chargedQuota)
|
||||
RecordLog(userId, LogTypeTopup, msg)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAllActiveUserSubscriptions returns all active subscriptions for a user.
|
||||
func GetAllActiveUserSubscriptions(userId int) ([]SubscriptionSummary, error) {
|
||||
if userId <= 0 {
|
||||
|
||||
@@ -26,6 +26,7 @@ func TestMain(m *testing.M) {
|
||||
common.RedisEnabled = false
|
||||
common.BatchUpdateEnabled = false
|
||||
common.LogConsumeEnabled = true
|
||||
initCol()
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
@@ -39,10 +40,12 @@ func TestMain(m *testing.M) {
|
||||
&Token{},
|
||||
&Log{},
|
||||
&Channel{},
|
||||
&Ability{},
|
||||
&TopUp{},
|
||||
&SubscriptionPlan{},
|
||||
&SubscriptionOrder{},
|
||||
&UserSubscription{},
|
||||
&PerfMetric{},
|
||||
); err != nil {
|
||||
panic("failed to migrate: " + err.Error())
|
||||
}
|
||||
@@ -58,10 +61,12 @@ func truncateTables(t *testing.T) {
|
||||
DB.Exec("DELETE FROM tokens")
|
||||
DB.Exec("DELETE FROM logs")
|
||||
DB.Exec("DELETE FROM channels")
|
||||
DB.Exec("DELETE FROM abilities")
|
||||
DB.Exec("DELETE FROM top_ups")
|
||||
DB.Exec("DELETE FROM subscription_orders")
|
||||
DB.Exec("DELETE FROM subscription_plans")
|
||||
DB.Exec("DELETE FROM user_subscriptions")
|
||||
DB.Exec("DELETE FROM perf_metrics")
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ const (
|
||||
PaymentMethodCreem = "creem"
|
||||
PaymentMethodWaffo = "waffo"
|
||||
PaymentMethodWaffoPancake = "waffo_pancake"
|
||||
PaymentMethodBalance = "balance"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -37,6 +38,7 @@ const (
|
||||
PaymentProviderCreem = "creem"
|
||||
PaymentProviderWaffo = "waffo"
|
||||
PaymentProviderWaffoPancake = "waffo_pancake"
|
||||
PaymentProviderBalance = "balance"
|
||||
)
|
||||
|
||||
var (
|
||||
|
||||
+31
-17
@@ -225,7 +225,7 @@ func GetAllUsers(pageInfo *common.PageInfo) (users []*User, total int64, err err
|
||||
return users, total, nil
|
||||
}
|
||||
|
||||
func SearchUsers(keyword string, group string, startIdx int, num int) ([]*User, int64, error) {
|
||||
func SearchUsers(keyword string, group string, role *int, status *int, startIdx int, num int) ([]*User, int64, error) {
|
||||
var users []*User
|
||||
var total int64
|
||||
var err error
|
||||
@@ -246,28 +246,25 @@ func SearchUsers(keyword string, group string, startIdx int, num int) ([]*User,
|
||||
|
||||
// 构建搜索条件
|
||||
likeCondition := "username LIKE ? OR email LIKE ? OR display_name LIKE ?"
|
||||
likeArgs := []interface{}{"%" + keyword + "%", "%" + keyword + "%", "%" + keyword + "%"}
|
||||
|
||||
// 尝试将关键字转换为整数ID
|
||||
keywordInt, err := strconv.Atoi(keyword)
|
||||
if err == nil {
|
||||
// 如果是数字,同时搜索ID和其他字段
|
||||
likeCondition = "id = ? OR " + likeCondition
|
||||
if group != "" {
|
||||
query = query.Where("("+likeCondition+") AND "+commonGroupCol+" = ?",
|
||||
keywordInt, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
|
||||
} else {
|
||||
query = query.Where(likeCondition,
|
||||
keywordInt, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
|
||||
}
|
||||
} else {
|
||||
// 非数字关键字,只搜索字符串字段
|
||||
if group != "" {
|
||||
query = query.Where("("+likeCondition+") AND "+commonGroupCol+" = ?",
|
||||
"%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
|
||||
} else {
|
||||
query = query.Where(likeCondition,
|
||||
"%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
|
||||
}
|
||||
likeArgs = append([]interface{}{keywordInt}, likeArgs...)
|
||||
}
|
||||
|
||||
query = query.Where("("+likeCondition+")", likeArgs...)
|
||||
if group != "" {
|
||||
query = query.Where(commonGroupCol+" = ?", group)
|
||||
}
|
||||
if role != nil {
|
||||
query = query.Where("role = ?", *role)
|
||||
}
|
||||
if status != nil {
|
||||
query = query.Where("status = ?", *status)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
@@ -987,6 +984,23 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
|
||||
//}
|
||||
}
|
||||
|
||||
func updateUserQuotaUsedQuotaAndRequestCount(id int, quota int, usedQuota int, requestCount int) {
|
||||
if quota == 0 && usedQuota == 0 && requestCount == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
err := DB.Model(&User{}).Where("id = ?", id).Updates(
|
||||
map[string]interface{}{
|
||||
"quota": gorm.Expr("quota + ?", quota),
|
||||
"used_quota": gorm.Expr("used_quota + ?", usedQuota),
|
||||
"request_count": gorm.Expr("request_count + ?", requestCount),
|
||||
},
|
||||
).Error
|
||||
if err != nil {
|
||||
common.SysLog("failed to batch update user quota, used quota and request count: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func updateUserUsedQuota(id int, quota int) {
|
||||
err := DB.Model(&User{}).Where("id = ?", id).Updates(
|
||||
map[string]interface{}{
|
||||
|
||||
+26
-11
@@ -67,33 +67,48 @@ func batchUpdate() {
|
||||
}
|
||||
|
||||
common.SysLog("batch update started")
|
||||
stores := make([]map[int]int, BatchUpdateTypeCount)
|
||||
for i := 0; i < BatchUpdateTypeCount; i++ {
|
||||
batchUpdateLocks[i].Lock()
|
||||
store := batchUpdateStores[i]
|
||||
stores[i] = batchUpdateStores[i]
|
||||
batchUpdateStores[i] = make(map[int]int)
|
||||
batchUpdateLocks[i].Unlock()
|
||||
// TODO: maybe we can combine updates with same key?
|
||||
}
|
||||
|
||||
for i, store := range stores {
|
||||
if i == BatchUpdateTypeUserQuota || i == BatchUpdateTypeUsedQuota || i == BatchUpdateTypeRequestCount {
|
||||
continue
|
||||
}
|
||||
for key, value := range store {
|
||||
switch i {
|
||||
case BatchUpdateTypeUserQuota:
|
||||
err := increaseUserQuota(key, value)
|
||||
if err != nil {
|
||||
common.SysLog("failed to batch update user quota: " + err.Error())
|
||||
}
|
||||
case BatchUpdateTypeTokenQuota:
|
||||
err := increaseTokenQuota(key, value)
|
||||
if err != nil {
|
||||
common.SysLog("failed to batch update token quota: " + err.Error())
|
||||
}
|
||||
case BatchUpdateTypeUsedQuota:
|
||||
updateUserUsedQuota(key, value)
|
||||
case BatchUpdateTypeRequestCount:
|
||||
updateUserRequestCount(key, value)
|
||||
case BatchUpdateTypeChannelUsedQuota:
|
||||
updateChannelUsedQuota(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
userQuotaStore := stores[BatchUpdateTypeUserQuota]
|
||||
usedQuotaStore := stores[BatchUpdateTypeUsedQuota]
|
||||
requestCountStore := stores[BatchUpdateTypeRequestCount]
|
||||
|
||||
userIDs := make(map[int]struct{}, len(userQuotaStore)+len(usedQuotaStore)+len(requestCountStore))
|
||||
for key := range userQuotaStore {
|
||||
userIDs[key] = struct{}{}
|
||||
}
|
||||
for key := range usedQuotaStore {
|
||||
userIDs[key] = struct{}{}
|
||||
}
|
||||
for key := range requestCountStore {
|
||||
userIDs[key] = struct{}{}
|
||||
}
|
||||
for key := range userIDs {
|
||||
updateUserQuotaUsedQuotaAndRequestCount(key, userQuotaStore[key], usedQuotaStore[key], requestCountStore[key])
|
||||
}
|
||||
common.SysLog("batch update finished")
|
||||
}
|
||||
|
||||
|
||||
@@ -122,7 +122,7 @@ func Query(params QueryParams) (QueryResult, error) {
|
||||
return buildQueryResult(params.Model, merged), nil
|
||||
}
|
||||
|
||||
func QuerySummaryAll(hours int) (SummaryAllResult, error) {
|
||||
func QuerySummaryAll(hours int, groups []string) (SummaryAllResult, error) {
|
||||
if hours <= 0 {
|
||||
hours = 24
|
||||
}
|
||||
@@ -131,8 +131,9 @@ func QuerySummaryAll(hours int) (SummaryAllResult, error) {
|
||||
}
|
||||
endTs := time.Now().Unix()
|
||||
startTs := endTs - int64(hours)*3600
|
||||
allowedGroups := allowedGroupSet(groups)
|
||||
|
||||
rows, err := model.GetPerfMetricsSummaryAll(startTs, endTs)
|
||||
rows, err := model.GetPerfMetricsSummaryAll(startTs, endTs, groups)
|
||||
if err != nil {
|
||||
return SummaryAllResult{}, err
|
||||
}
|
||||
@@ -153,6 +154,11 @@ func QuerySummaryAll(hours int) (SummaryAllResult, error) {
|
||||
if k.bucketTs < startTs || k.bucketTs > endTs {
|
||||
return true
|
||||
}
|
||||
if allowedGroups != nil {
|
||||
if _, ok := allowedGroups[k.group]; !ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
snap := value.(*atomicBucket).snapshot()
|
||||
if snap.requestCount == 0 {
|
||||
return true
|
||||
@@ -193,6 +199,17 @@ func QuerySummaryAll(hours int) (SummaryAllResult, error) {
|
||||
return SummaryAllResult{Models: models}, nil
|
||||
}
|
||||
|
||||
func allowedGroupSet(groups []string) map[string]struct{} {
|
||||
if groups == nil {
|
||||
return nil
|
||||
}
|
||||
allowed := make(map[string]struct{}, len(groups))
|
||||
for _, group := range groups {
|
||||
allowed[group] = struct{}{}
|
||||
}
|
||||
return allowed
|
||||
}
|
||||
|
||||
func bucketStart(ts int64) int64 {
|
||||
bucketSeconds := perf_metrics_setting.GetBucketSeconds()
|
||||
if bucketSeconds <= 0 {
|
||||
|
||||
@@ -25,6 +25,23 @@ import (
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// applyUpstreamContentLength populates req.ContentLength when the upstream
|
||||
// body is wrapped in a BodyStorage (see relay/common/outbound_body.go).
|
||||
//
|
||||
// net/http.NewRequest only auto-detects ContentLength for *bytes.Reader,
|
||||
// *bytes.Buffer and *strings.Reader. When the body is a type-erased io.Reader
|
||||
// (which is the case for ReaderOnly(BodyStorage)), the Content-Length header
|
||||
// would otherwise be omitted, forcing chunked transfer encoding and breaking
|
||||
// some upstreams that require an explicit Content-Length.
|
||||
func applyUpstreamContentLength(req *http.Request, info *common.RelayInfo) {
|
||||
if info == nil {
|
||||
return
|
||||
}
|
||||
if info.UpstreamRequestBodySize > 0 && req.ContentLength <= 0 {
|
||||
req.ContentLength = info.UpstreamRequestBodySize
|
||||
}
|
||||
}
|
||||
|
||||
func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Header) {
|
||||
if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
|
||||
// multipart/form-data
|
||||
@@ -297,6 +314,7 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new request failed: %w", err)
|
||||
}
|
||||
applyUpstreamContentLength(req, info)
|
||||
headers := req.Header
|
||||
err = a.SetupRequestHeader(c, &headers, info)
|
||||
if err != nil {
|
||||
@@ -326,6 +344,7 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new request failed: %w", err)
|
||||
}
|
||||
applyUpstreamContentLength(req, info)
|
||||
// set form data
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
headers := req.Header
|
||||
@@ -522,6 +541,7 @@ func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.RelayInfo, req
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new request failed: %w", err)
|
||||
}
|
||||
applyUpstreamContentLength(req, info)
|
||||
req.GetBody = func() (io.ReadCloser, error) {
|
||||
return io.NopCloser(requestBody), nil
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ var awsModelIDMap = map[string]string{
|
||||
"claude-opus-4-5-20251101": "anthropic.claude-opus-4-5-20251101-v1:0",
|
||||
"claude-opus-4-6": "anthropic.claude-opus-4-6-v1",
|
||||
"claude-opus-4-7": "anthropic.claude-opus-4-7",
|
||||
"claude-opus-4-8": "anthropic.claude-opus-4-8",
|
||||
// Nova models
|
||||
"nova-micro-v1:0": "amazon.nova-micro-v1:0",
|
||||
"nova-lite-v1:0": "amazon.nova-lite-v1:0",
|
||||
@@ -97,6 +98,11 @@ var awsModelCanCrossRegionMap = map[string]map[string]bool{
|
||||
"ap": true,
|
||||
"eu": true,
|
||||
},
|
||||
"anthropic.claude-opus-4-8": {
|
||||
"us": true,
|
||||
"ap": true,
|
||||
"eu": true,
|
||||
},
|
||||
"anthropic.claude-haiku-4-5-20251001-v1:0": {
|
||||
"us": true,
|
||||
"ap": true,
|
||||
|
||||
@@ -33,6 +33,13 @@ var ModelList = []string{
|
||||
"claude-opus-4-7-medium",
|
||||
"claude-opus-4-7-low",
|
||||
"claude-opus-4-7-thinking",
|
||||
"claude-opus-4-8",
|
||||
"claude-opus-4-8-max",
|
||||
"claude-opus-4-8-xhigh",
|
||||
"claude-opus-4-8-high",
|
||||
"claude-opus-4-8-medium",
|
||||
"claude-opus-4-8-low",
|
||||
"claude-opus-4-8-thinking",
|
||||
}
|
||||
|
||||
var ChannelName = "claude"
|
||||
|
||||
@@ -154,14 +154,17 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
|
||||
}
|
||||
|
||||
if baseModel, effortLevel, ok := reasoning.TrimEffortSuffix(textRequest.Model); ok && effortLevel != "" &&
|
||||
(strings.HasPrefix(textRequest.Model, "claude-opus-4-6") || strings.HasPrefix(textRequest.Model, "claude-opus-4-7")) {
|
||||
(strings.HasPrefix(textRequest.Model, "claude-opus-4-6") ||
|
||||
strings.HasPrefix(textRequest.Model, "claude-opus-4-7") ||
|
||||
strings.HasPrefix(textRequest.Model, "claude-opus-4-8")) {
|
||||
claudeRequest.Model = baseModel
|
||||
claudeRequest.Thinking = &dto.Thinking{
|
||||
Type: "adaptive",
|
||||
}
|
||||
claudeRequest.OutputConfig = json.RawMessage(fmt.Sprintf(`{"effort":"%s"}`, effortLevel))
|
||||
if strings.HasPrefix(baseModel, "claude-opus-4-7") {
|
||||
// Opus 4.7 rejects non-default temperature/top_p/top_k with 400
|
||||
if strings.HasPrefix(baseModel, "claude-opus-4-7") ||
|
||||
strings.HasPrefix(baseModel, "claude-opus-4-8") {
|
||||
// Opus 4.7/4.8 reject non-default temperature/top_p/top_k with 400
|
||||
// and defaults display to "omitted"; restore the 4.6 visible summary.
|
||||
claudeRequest.Thinking.Display = "summarized"
|
||||
claudeRequest.Temperature = nil
|
||||
@@ -175,8 +178,9 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
|
||||
strings.HasSuffix(textRequest.Model, "-thinking") {
|
||||
|
||||
trimmedModel := strings.TrimSuffix(textRequest.Model, "-thinking")
|
||||
if strings.HasPrefix(trimmedModel, "claude-opus-4-7") {
|
||||
// Opus 4.7 rejects thinking.type="enabled"; use adaptive at high effort.
|
||||
if strings.HasPrefix(trimmedModel, "claude-opus-4-7") ||
|
||||
strings.HasPrefix(trimmedModel, "claude-opus-4-8") {
|
||||
// Opus 4.7/4.8 reject thinking.type="enabled"; use adaptive at high effort.
|
||||
claudeRequest.Thinking = &dto.Thinking{Type: "adaptive", Display: "summarized"}
|
||||
claudeRequest.OutputConfig = json.RawMessage(`{"effort":"high"}`)
|
||||
claudeRequest.Temperature = nil
|
||||
@@ -442,10 +446,7 @@ func StreamResponseClaude2OpenAI(claudeResponse *dto.ClaudeResponse) *dto.ChatCo
|
||||
tools := make([]dto.ToolCallResponse, 0)
|
||||
fcIdx := 0
|
||||
if claudeResponse.Index != nil {
|
||||
fcIdx = *claudeResponse.Index - 1
|
||||
if fcIdx < 0 {
|
||||
fcIdx = 0
|
||||
}
|
||||
fcIdx = *claudeResponse.Index
|
||||
}
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
if claudeResponse.Type == "message_start" {
|
||||
|
||||
@@ -9,6 +9,10 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func commonPointer[T any](value T) *T {
|
||||
return &value
|
||||
}
|
||||
|
||||
func TestFormatClaudeResponseInfo_MessageStart(t *testing.T) {
|
||||
claudeInfo := &ClaudeResponseInfo{
|
||||
Usage: &dto.Usage{},
|
||||
@@ -310,6 +314,58 @@ func TestRequestOpenAI2ClaudeMessage_IgnoresUnsupportedFileContent(t *testing.T)
|
||||
require.Equal(t, "see attachment", *content[0].Text)
|
||||
}
|
||||
|
||||
func TestRequestOpenAI2ClaudeMessage_ClaudeOpus48HighUsesAdaptiveThinking(t *testing.T) {
|
||||
request := dto.GeneralOpenAIRequest{
|
||||
Model: "claude-opus-4-8-high",
|
||||
Temperature: commonPointer(0.7),
|
||||
TopP: commonPointer(0.9),
|
||||
TopK: commonPointer(40),
|
||||
Messages: []dto.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "hello",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
claudeRequest, err := RequestOpenAI2ClaudeMessage(nil, request)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "claude-opus-4-8", claudeRequest.Model)
|
||||
require.NotNil(t, claudeRequest.Thinking)
|
||||
require.Equal(t, "adaptive", claudeRequest.Thinking.Type)
|
||||
require.Equal(t, "summarized", claudeRequest.Thinking.Display)
|
||||
require.JSONEq(t, `{"effort":"high"}`, string(claudeRequest.OutputConfig))
|
||||
require.Nil(t, claudeRequest.Temperature)
|
||||
require.Nil(t, claudeRequest.TopP)
|
||||
require.Nil(t, claudeRequest.TopK)
|
||||
}
|
||||
|
||||
func TestRequestOpenAI2ClaudeMessage_ClaudeOpus48ThinkingUsesAdaptiveHighEffort(t *testing.T) {
|
||||
request := dto.GeneralOpenAIRequest{
|
||||
Model: "claude-opus-4-8-thinking",
|
||||
Temperature: commonPointer(0.7),
|
||||
TopP: commonPointer(0.9),
|
||||
TopK: commonPointer(40),
|
||||
Messages: []dto.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "hello",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
claudeRequest, err := RequestOpenAI2ClaudeMessage(nil, request)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "claude-opus-4-8", claudeRequest.Model)
|
||||
require.NotNil(t, claudeRequest.Thinking)
|
||||
require.Equal(t, "adaptive", claudeRequest.Thinking.Type)
|
||||
require.Equal(t, "summarized", claudeRequest.Thinking.Display)
|
||||
require.JSONEq(t, `{"effort":"high"}`, string(claudeRequest.OutputConfig))
|
||||
require.Nil(t, claudeRequest.Temperature)
|
||||
require.Nil(t, claudeRequest.TopP)
|
||||
require.Nil(t, claudeRequest.TopK)
|
||||
}
|
||||
|
||||
func TestRequestOpenAI2ClaudeMessage_SupportsPDFFileContent(t *testing.T) {
|
||||
request := dto.GeneralOpenAIRequest{
|
||||
Model: "claude-3-5-sonnet",
|
||||
|
||||
@@ -30,7 +30,7 @@ func convertCf2CompletionsRequest(textRequest dto.GeneralOpenAIRequest) *CfReque
|
||||
}
|
||||
|
||||
func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner := helper.NewStreamScanner(resp.Body)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
|
||||
helper.SetEventStreamHeaders(c)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package cohere
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -86,7 +85,7 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
||||
createdTime := common.GetTimestamp()
|
||||
usage := &dto.Usage{}
|
||||
responseText := ""
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner := helper.NewStreamScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
@@ -106,6 +105,9 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
||||
data := scanner.Text()
|
||||
dataChan <- data
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
common.SysLog("error reading stream: " + err.Error())
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
helper.SetEventStreamHeaders(c)
|
||||
|
||||
@@ -98,7 +98,7 @@ func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res
|
||||
}
|
||||
|
||||
func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner := helper.NewStreamScanner(resp.Body)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
helper.SetEventStreamHeaders(c)
|
||||
id := helper.GetResponseID(c)
|
||||
|
||||
@@ -159,9 +159,14 @@ func requestOpenAI2Dify(c *gin.Context, info *relaycommon.RelayInfo, request dto
|
||||
media := mediaContent.GetImageMedia()
|
||||
var file *DifyFile
|
||||
if media.IsRemoteImage() {
|
||||
file.Type = media.MimeType
|
||||
file.TransferMode = "remote_url"
|
||||
file.URL = media.Url
|
||||
// 修复 #2083: 远程图片分支此前未初始化 file,
|
||||
// 导致 file.Type = ... 触发 nil pointer dereference
|
||||
// 而 panic(500: "invalid memory address or nil pointer dereference")。
|
||||
file = &DifyFile{
|
||||
Type: media.MimeType,
|
||||
TransferMode: "remote_url",
|
||||
URL: media.Url,
|
||||
}
|
||||
} else {
|
||||
file = uploadDifyFile(c, info, difyReq.User, mediaContent)
|
||||
}
|
||||
|
||||
@@ -1079,17 +1079,47 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse)
|
||||
FinishReason: constant.FinishReasonStop,
|
||||
}
|
||||
if len(candidate.Content.Parts) > 0 {
|
||||
var texts []string
|
||||
// 使用 strings.Builder 直接累积最终 content,避免:
|
||||
// 1) 每张 inline image 生成一次中间 "" 字符串
|
||||
// 2) 末尾 strings.Join 再分配一份等大缓冲
|
||||
// Gemini 图片返回时 InlineData.Data 可能是数 MB 的 base64,
|
||||
// 上述两份临时分配在高并发下会显著放大堆驻留。
|
||||
var content strings.Builder
|
||||
var inlineGrow int
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.InlineData != nil {
|
||||
inlineGrow += len(part.InlineData.MimeType) + len(part.InlineData.Data) + 32
|
||||
}
|
||||
}
|
||||
if inlineGrow > 0 {
|
||||
content.Grow(inlineGrow)
|
||||
}
|
||||
appended := 0
|
||||
writeSep := func() {
|
||||
if appended > 0 {
|
||||
content.WriteByte('\n')
|
||||
}
|
||||
appended++
|
||||
}
|
||||
var toolCalls []dto.ToolCallResponse
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.InlineData != nil {
|
||||
// 媒体内容
|
||||
if strings.HasPrefix(part.InlineData.MimeType, "image") {
|
||||
imgText := ""
|
||||
texts = append(texts, imgText)
|
||||
writeSep()
|
||||
content.WriteString("
|
||||
content.WriteString(part.InlineData.MimeType)
|
||||
content.WriteString(";base64,")
|
||||
content.WriteString(part.InlineData.Data)
|
||||
content.WriteByte(')')
|
||||
} else {
|
||||
// 其他媒体类型,直接显示链接
|
||||
texts = append(texts, fmt.Sprintf("[media](data:%s;base64,%s)", part.InlineData.MimeType, part.InlineData.Data))
|
||||
writeSep()
|
||||
content.WriteString("[media](data:")
|
||||
content.WriteString(part.InlineData.MimeType)
|
||||
content.WriteString(";base64,")
|
||||
content.WriteString(part.InlineData.Data)
|
||||
content.WriteByte(')')
|
||||
}
|
||||
} else if part.FunctionCall != nil {
|
||||
choice.FinishReason = constant.FinishReasonToolCalls
|
||||
@@ -1100,13 +1130,22 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse)
|
||||
choice.Message.ReasoningContent = &part.Text
|
||||
} else {
|
||||
if part.ExecutableCode != nil {
|
||||
texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```")
|
||||
writeSep()
|
||||
content.WriteString("```")
|
||||
content.WriteString(part.ExecutableCode.Language)
|
||||
content.WriteByte('\n')
|
||||
content.WriteString(part.ExecutableCode.Code)
|
||||
content.WriteString("\n```")
|
||||
} else if part.CodeExecutionResult != nil {
|
||||
texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```")
|
||||
writeSep()
|
||||
content.WriteString("```output\n")
|
||||
content.WriteString(part.CodeExecutionResult.Output)
|
||||
content.WriteString("\n```")
|
||||
} else {
|
||||
// 过滤掉空行
|
||||
if part.Text != "\n" {
|
||||
texts = append(texts, part.Text)
|
||||
writeSep()
|
||||
content.WriteString(part.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1115,7 +1154,7 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse)
|
||||
choice.Message.SetToolCalls(toolCalls)
|
||||
isToolCall = true
|
||||
}
|
||||
choice.Message.SetStringContent(strings.Join(texts, "\n"))
|
||||
choice.Message.SetStringContent(content.String())
|
||||
|
||||
}
|
||||
if candidate.FinishReason != nil {
|
||||
@@ -1169,7 +1208,25 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*d
|
||||
//Role: "assistant",
|
||||
},
|
||||
}
|
||||
var texts []string
|
||||
// 使用 strings.Builder 直接累积 delta content,避免每张 image / 每个
|
||||
// 文本片段都先 `+` 拼出一份临时 string,再 strings.Join 再拷贝一遍。
|
||||
var content strings.Builder
|
||||
var inlineGrow int
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.InlineData != nil {
|
||||
inlineGrow += len(part.InlineData.MimeType) + len(part.InlineData.Data) + 32
|
||||
}
|
||||
}
|
||||
if inlineGrow > 0 {
|
||||
content.Grow(inlineGrow)
|
||||
}
|
||||
appended := 0
|
||||
writeSep := func() {
|
||||
if appended > 0 {
|
||||
content.WriteByte('\n')
|
||||
}
|
||||
appended++
|
||||
}
|
||||
isTools := false
|
||||
isThought := false
|
||||
if candidate.FinishReason != nil {
|
||||
@@ -1207,8 +1264,12 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*d
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.InlineData != nil {
|
||||
if strings.HasPrefix(part.InlineData.MimeType, "image") {
|
||||
imgText := ""
|
||||
texts = append(texts, imgText)
|
||||
writeSep()
|
||||
content.WriteString("
|
||||
content.WriteString(part.InlineData.MimeType)
|
||||
content.WriteString(";base64,")
|
||||
content.WriteString(part.InlineData.Data)
|
||||
content.WriteByte(')')
|
||||
}
|
||||
} else if part.FunctionCall != nil {
|
||||
isTools = true
|
||||
@@ -1219,23 +1280,33 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*d
|
||||
|
||||
} else if part.Thought {
|
||||
isThought = true
|
||||
texts = append(texts, part.Text)
|
||||
writeSep()
|
||||
content.WriteString(part.Text)
|
||||
} else {
|
||||
if part.ExecutableCode != nil {
|
||||
texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```\n")
|
||||
writeSep()
|
||||
content.WriteString("```")
|
||||
content.WriteString(part.ExecutableCode.Language)
|
||||
content.WriteByte('\n')
|
||||
content.WriteString(part.ExecutableCode.Code)
|
||||
content.WriteString("\n```\n")
|
||||
} else if part.CodeExecutionResult != nil {
|
||||
texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```\n")
|
||||
writeSep()
|
||||
content.WriteString("```output\n")
|
||||
content.WriteString(part.CodeExecutionResult.Output)
|
||||
content.WriteString("\n```\n")
|
||||
} else {
|
||||
if part.Text != "\n" {
|
||||
texts = append(texts, part.Text)
|
||||
writeSep()
|
||||
content.WriteString(part.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if isThought {
|
||||
choice.Delta.SetReasoningContent(strings.Join(texts, "\n"))
|
||||
choice.Delta.SetReasoningContent(content.String())
|
||||
} else {
|
||||
choice.Delta.SetContentString(strings.Join(texts, "\n"))
|
||||
choice.Delta.SetContentString(content.String())
|
||||
}
|
||||
if isTools {
|
||||
choice.FinishReason = &constant.FinishReasonToolCalls
|
||||
@@ -1339,6 +1410,14 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
||||
response.Id = id
|
||||
response.Created = createAt
|
||||
response.Model = info.UpstreamModelName
|
||||
if response.IsToolCall() {
|
||||
finishReason = constant.FinishReasonToolCalls
|
||||
if info.RelayFormat == types.RelayFormatClaude {
|
||||
for choiceIdx := range response.Choices {
|
||||
response.Choices[choiceIdx].FinishReason = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
for choiceIdx := range response.Choices {
|
||||
choiceKey := response.Choices[choiceIdx].Index
|
||||
for toolIdx := range response.Choices[choiceIdx].Delta.ToolCalls {
|
||||
@@ -1399,7 +1478,9 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
||||
logger.LogError(c, err.Error())
|
||||
}
|
||||
if isStop {
|
||||
_ = handleStream(c, info, helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, finishReason))
|
||||
if info.RelayFormat != types.RelayFormatClaude {
|
||||
_ = handleStream(c, info, helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, finishReason))
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
@@ -1409,6 +1490,10 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
||||
}
|
||||
|
||||
response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
|
||||
if info.RelayFormat == types.RelayFormatClaude && info.ClaudeConvertInfo != nil && !info.ClaudeConvertInfo.Done {
|
||||
response = helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, finishReason)
|
||||
response.Usage = usage
|
||||
}
|
||||
handleErr := handleFinalStream(c, info, response)
|
||||
if handleErr != nil {
|
||||
common.SysLog("send final response failed: " + handleErr.Error())
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -12,6 +11,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/relay/helper"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
@@ -397,7 +397,7 @@ func PullOllamaModelStream(baseURL, apiKey, modelName string, progressCallback f
|
||||
}
|
||||
|
||||
// 读取流式响应
|
||||
scanner := bufio.NewScanner(response.Body)
|
||||
scanner := helper.NewStreamScanner(response.Body)
|
||||
successful := false
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -70,7 +69,7 @@ func ollamaStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
||||
defer service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
helper.SetEventStreamHeaders(c)
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner := helper.NewStreamScanner(resp.Body)
|
||||
usage := &dto.Usage{}
|
||||
var model = info.UpstreamModelName
|
||||
var responseId = common.GetUUID()
|
||||
|
||||
@@ -310,18 +310,20 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
}
|
||||
|
||||
}
|
||||
if strings.HasPrefix(info.UpstreamModelName, "o") || strings.HasPrefix(info.UpstreamModelName, "gpt-5") {
|
||||
isOModel := dto.IsOpenAIReasoningOModel(info.UpstreamModelName)
|
||||
isGPT5Model := dto.IsOpenAIGPT5Model(info.UpstreamModelName)
|
||||
if isOModel || isGPT5Model {
|
||||
if lo.FromPtrOr(request.MaxCompletionTokens, uint(0)) == 0 && lo.FromPtrOr(request.MaxTokens, uint(0)) != 0 {
|
||||
request.MaxCompletionTokens = request.MaxTokens
|
||||
request.MaxTokens = nil
|
||||
}
|
||||
|
||||
if strings.HasPrefix(info.UpstreamModelName, "o") {
|
||||
if isOModel {
|
||||
request.Temperature = nil
|
||||
}
|
||||
|
||||
// gpt-5系列模型适配 归零不再支持的参数
|
||||
if strings.HasPrefix(info.UpstreamModelName, "gpt-5") {
|
||||
if isGPT5Model {
|
||||
request.Temperature = nil
|
||||
request.TopP = nil
|
||||
request.LogProbs = nil
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
@@ -92,78 +91,28 @@ func ProcessStreamResponse(streamResponse dto.ChatCompletionsStreamResponse, res
|
||||
return nil
|
||||
}
|
||||
|
||||
func processTokens(relayMode int, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
|
||||
streamResp := "[" + strings.Join(streamItems, ",") + "]"
|
||||
|
||||
func processTokenData(relayMode int, data string, responseTextBuilder *strings.Builder, toolCount *int) error {
|
||||
switch relayMode {
|
||||
case relayconstant.RelayModeChatCompletions:
|
||||
return processChatCompletions(streamResp, streamItems, responseTextBuilder, toolCount)
|
||||
var streamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := common.UnmarshalJsonStr(data, &streamResponse); err != nil {
|
||||
return err
|
||||
}
|
||||
return ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount)
|
||||
case relayconstant.RelayModeCompletions:
|
||||
return processCompletions(streamResp, streamItems, responseTextBuilder)
|
||||
var streamResponse dto.CompletionsStreamResponse
|
||||
if err := common.UnmarshalJsonStr(data, &streamResponse); err != nil {
|
||||
return err
|
||||
}
|
||||
processCompletionsStreamResponse(streamResponse, responseTextBuilder)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func processChatCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
|
||||
var streamResponses []dto.ChatCompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
|
||||
// 一次性解析失败,逐个解析
|
||||
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||
for _, item := range streamItems {
|
||||
var streamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount); err != nil {
|
||||
common.SysLog("error processing stream response: " + err.Error())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
func processCompletionsStreamResponse(streamResponse dto.CompletionsStreamResponse, responseTextBuilder *strings.Builder) {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Text)
|
||||
}
|
||||
|
||||
// 批量处理所有响应
|
||||
for _, streamResponse := range streamResponses {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
||||
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
|
||||
if choice.Delta.ToolCalls != nil {
|
||||
if len(choice.Delta.ToolCalls) > *toolCount {
|
||||
*toolCount = len(choice.Delta.ToolCalls)
|
||||
}
|
||||
for _, tool := range choice.Delta.ToolCalls {
|
||||
responseTextBuilder.WriteString(tool.Function.Name)
|
||||
responseTextBuilder.WriteString(tool.Function.Arguments)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func processCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder) error {
|
||||
var streamResponses []dto.CompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
|
||||
// 一次性解析失败,逐个解析
|
||||
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||
for _, item := range streamItems {
|
||||
var streamResponse dto.CompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
|
||||
continue
|
||||
}
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Text)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 批量处理所有响应
|
||||
for _, streamResponse := range streamResponses {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Text)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleLastResponse(lastStreamData string, responseId *string, createAt *int64,
|
||||
|
||||
@@ -119,7 +119,6 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
var responseTextBuilder strings.Builder
|
||||
var toolCount int
|
||||
var usage = &dto.Usage{}
|
||||
var streamItems []string // store stream items
|
||||
var lastStreamData string
|
||||
var secondLastStreamData string // 存储倒数第二个stream data,用于音频模型
|
||||
|
||||
@@ -140,7 +139,10 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
}
|
||||
|
||||
lastStreamData = data
|
||||
streamItems = append(streamItems, data)
|
||||
if err := processTokenData(info.RelayMode, data, &responseTextBuilder, &toolCount); err != nil {
|
||||
logger.LogError(c, "error processing stream token data: "+err.Error())
|
||||
sr.Error(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
@@ -175,11 +177,6 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
}
|
||||
}
|
||||
|
||||
// 处理token计算
|
||||
if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
|
||||
logger.LogError(c, "error processing tokens: "+err.Error())
|
||||
}
|
||||
|
||||
if !containStreamUsage {
|
||||
usage = service.ResponseText2Usage(c, responseTextBuilder.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||||
usage.CompletionTokens += toolCount * 7
|
||||
|
||||
@@ -92,7 +92,7 @@ func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *dto.Cha
|
||||
|
||||
func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
var responseText string
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner := helper.NewStreamScanner(resp.Body)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
|
||||
helper.SetEventStreamHeaders(c)
|
||||
|
||||
@@ -45,6 +45,7 @@ var claudeModelMap = map[string]string{
|
||||
"claude-opus-4-5-20251101": "claude-opus-4-5@20251101",
|
||||
"claude-opus-4-6": "claude-opus-4-6",
|
||||
"claude-opus-4-7": "claude-opus-4-7",
|
||||
"claude-opus-4-8": "claude-opus-4-8",
|
||||
}
|
||||
|
||||
const anthropicVersion = "vertex-2023-10-16"
|
||||
|
||||
@@ -157,7 +157,7 @@ func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*dt
|
||||
|
||||
func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
var usage *dto.Usage
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner := helper.NewStreamScanner(resp.Body)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
dataChan := make(chan string)
|
||||
metaChan := make(chan string)
|
||||
@@ -180,6 +180,9 @@ func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
common.SysLog("error reading stream: " + err.Error())
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
helper.SetEventStreamHeaders(c)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -125,7 +124,14 @@ func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, ad
|
||||
return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
var requestBody io.Reader = bytes.NewBuffer(jsonData)
|
||||
body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
defer closer.Close()
|
||||
jsonData = nil
|
||||
info.UpstreamRequestBodySize = size
|
||||
var requestBody io.Reader = body
|
||||
|
||||
var httpResp *http.Response
|
||||
resp, err := adaptor.DoRequest(c, info, requestBody)
|
||||
|
||||
+18
-7
@@ -1,7 +1,6 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -54,14 +53,17 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
}
|
||||
|
||||
if baseModel, effortLevel, ok := reasoning.TrimEffortSuffix(request.Model); ok && effortLevel != "" &&
|
||||
(strings.HasPrefix(request.Model, "claude-opus-4-6") || strings.HasPrefix(request.Model, "claude-opus-4-7")) {
|
||||
(strings.HasPrefix(request.Model, "claude-opus-4-6") ||
|
||||
strings.HasPrefix(request.Model, "claude-opus-4-7") ||
|
||||
strings.HasPrefix(request.Model, "claude-opus-4-8")) {
|
||||
request.Model = baseModel
|
||||
request.Thinking = &dto.Thinking{
|
||||
Type: "adaptive",
|
||||
}
|
||||
request.OutputConfig = json.RawMessage(fmt.Sprintf(`{"effort":"%s"}`, effortLevel))
|
||||
if strings.HasPrefix(request.Model, "claude-opus-4-7") {
|
||||
// Opus 4.7 rejects non-default temperature/top_p/top_k with 400
|
||||
if strings.HasPrefix(request.Model, "claude-opus-4-7") ||
|
||||
strings.HasPrefix(request.Model, "claude-opus-4-8") {
|
||||
// Opus 4.7/4.8 reject non-default temperature/top_p/top_k with 400
|
||||
// and defaults display to "omitted"; restore the 4.6 visible summary.
|
||||
request.Thinking.Display = "summarized"
|
||||
request.Temperature = nil
|
||||
@@ -75,8 +77,9 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
strings.HasSuffix(request.Model, "-thinking") {
|
||||
if request.Thinking == nil {
|
||||
baseModel := strings.TrimSuffix(request.Model, "-thinking")
|
||||
if strings.HasPrefix(baseModel, "claude-opus-4-7") {
|
||||
// Opus 4.7 rejects thinking.type="enabled"; use adaptive at high effort.
|
||||
if strings.HasPrefix(baseModel, "claude-opus-4-7") ||
|
||||
strings.HasPrefix(baseModel, "claude-opus-4-8") {
|
||||
// Opus 4.7/4.8 reject thinking.type="enabled"; use adaptive at high effort.
|
||||
request.Thinking = &dto.Thinking{Type: "adaptive", Display: "summarized"}
|
||||
request.OutputConfig = json.RawMessage(`{"effort":"high"}`)
|
||||
request.Temperature = nil
|
||||
@@ -152,6 +155,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
if err != nil {
|
||||
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
info.UpstreamRequestBodySize = storage.Size()
|
||||
requestBody = common.ReaderOnly(storage)
|
||||
} else {
|
||||
convertedRequest, err := adaptor.ConvertClaudeRequest(c, info, request)
|
||||
@@ -179,7 +183,14 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
}
|
||||
|
||||
logger.LogDebug(c, "requestBody: %s", jsonData)
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
defer closer.Close()
|
||||
jsonData = nil
|
||||
info.UpstreamRequestBodySize = size
|
||||
requestBody = body
|
||||
}
|
||||
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
)
|
||||
|
||||
// NewOutboundJSONBody wraps the already-marshaled upstream request body into a
|
||||
// BodyStorage. When disk cache is enabled and the payload exceeds the configured
|
||||
// threshold, the data is written to a temp file and the original []byte can be
|
||||
// GC'd, significantly reducing the heap residency while waiting for the
|
||||
// upstream provider to respond (the dominant cost for large base64 payloads).
|
||||
//
|
||||
// In memory mode the underlying memoryStorage reuses the same backing array,
|
||||
// so this is equivalent to bytes.NewReader(data) in terms of memory usage.
|
||||
//
|
||||
// The caller MUST invoke closer.Close() once the upstream call has finished
|
||||
// (typically via defer) to release the disk file / memory accounting.
|
||||
//
|
||||
// The returned reader is wrapped with common.ReaderOnly to prevent the HTTP
|
||||
// transport from prematurely closing the underlying BodyStorage. The returned
|
||||
// size is meant to be propagated to http.Request.ContentLength because the
|
||||
// type-erased io.Reader prevents net/http from auto-detecting it.
|
||||
func NewOutboundJSONBody(data []byte) (body io.Reader, size int64, closer io.Closer, err error) {
|
||||
storage, err := common.CreateBodyStorage(data)
|
||||
if err != nil {
|
||||
return nil, 0, nil, err
|
||||
}
|
||||
return common.ReaderOnly(storage), storage.Size(), storage, nil
|
||||
}
|
||||
+190
-143
@@ -26,13 +26,20 @@ const (
|
||||
|
||||
var errSourceHeaderNotFound = errors.New("source header does not exist")
|
||||
|
||||
var paramOverrideKeyAuditPaths = map[string]struct{}{
|
||||
"model": {},
|
||||
"original_model": {},
|
||||
"upstream_model": {},
|
||||
"service_tier": {},
|
||||
"inference_geo": {},
|
||||
"speed": {},
|
||||
var paramOverrideSensitivePathPrefixes = []string{
|
||||
"model",
|
||||
"original_model",
|
||||
"upstream_model",
|
||||
"service_tier",
|
||||
"inference_geo",
|
||||
"speed",
|
||||
"messages",
|
||||
"input",
|
||||
"instructions",
|
||||
"system",
|
||||
"contents",
|
||||
"systemInstruction",
|
||||
"system_instruction",
|
||||
}
|
||||
|
||||
type paramOverrideAuditRecorder struct {
|
||||
@@ -146,9 +153,8 @@ func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, c
|
||||
}
|
||||
}
|
||||
|
||||
// 使用新方法
|
||||
result, err := applyOperations(string(workingJSON), operations, conditionContext)
|
||||
return []byte(result), err
|
||||
// 使用新方法(基于 []byte,避免整包 string 拷贝)
|
||||
return applyOperations(workingJSON, operations, conditionContext)
|
||||
}
|
||||
|
||||
// 直接使用旧方法
|
||||
@@ -206,6 +212,7 @@ func shouldEnableParamOverrideAudit(paramOverride map[string]interface{}) bool {
|
||||
if operations, ok := tryParseOperations(paramOverride); ok {
|
||||
for _, operation := range operations {
|
||||
if shouldAuditParamPath(strings.TrimSpace(operation.Path)) ||
|
||||
shouldAuditParamPath(strings.TrimSpace(operation.From)) ||
|
||||
shouldAuditParamPath(strings.TrimSpace(operation.To)) {
|
||||
return true
|
||||
}
|
||||
@@ -255,15 +262,19 @@ func shouldAuditParamPath(path string) bool {
|
||||
if common.DebugEnabled {
|
||||
return true
|
||||
}
|
||||
_, ok := paramOverrideKeyAuditPaths[path]
|
||||
return ok
|
||||
for _, prefix := range paramOverrideSensitivePathPrefixes {
|
||||
if path == prefix || strings.HasPrefix(path, prefix+".") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func shouldAuditOperation(mode, path, from, to string) bool {
|
||||
if common.DebugEnabled {
|
||||
return true
|
||||
}
|
||||
for _, candidate := range []string{path, to} {
|
||||
for _, candidate := range []string{path, from, to} {
|
||||
if shouldAuditParamPath(candidate) {
|
||||
return true
|
||||
}
|
||||
@@ -498,13 +509,13 @@ func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation,
|
||||
return operations, true
|
||||
}
|
||||
|
||||
func checkConditions(jsonStr, contextJSON string, conditions []ConditionOperation, logic string) (bool, error) {
|
||||
func checkConditions(data []byte, contextJSON string, conditions []ConditionOperation, logic string) (bool, error) {
|
||||
if len(conditions) == 0 {
|
||||
return true, nil // 没有条件,直接通过
|
||||
}
|
||||
results := make([]bool, len(conditions))
|
||||
for i, condition := range conditions {
|
||||
result, err := checkSingleCondition(jsonStr, contextJSON, condition)
|
||||
result, err := checkSingleCondition(data, contextJSON, condition)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -517,10 +528,10 @@ func checkConditions(jsonStr, contextJSON string, conditions []ConditionOperatio
|
||||
return lo.SomeBy(results, func(item bool) bool { return item }), nil
|
||||
}
|
||||
|
||||
func checkSingleCondition(jsonStr, contextJSON string, condition ConditionOperation) (bool, error) {
|
||||
func checkSingleCondition(data []byte, contextJSON string, condition ConditionOperation) (bool, error) {
|
||||
// 处理负数索引
|
||||
path := processNegativeIndex(jsonStr, condition.Path)
|
||||
value := gjson.Get(jsonStr, path)
|
||||
path := processNegativeIndex(data, condition.Path)
|
||||
value := gjson.GetBytes(data, path)
|
||||
if !value.Exists() && contextJSON != "" {
|
||||
value = gjson.Get(contextJSON, condition.Path)
|
||||
}
|
||||
@@ -549,7 +560,7 @@ func checkSingleCondition(jsonStr, contextJSON string, condition ConditionOperat
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func processNegativeIndex(jsonStr string, path string) string {
|
||||
func processNegativeIndex(data []byte, path string) string {
|
||||
matches := negativeIndexRegexp.FindAllStringSubmatch(path, -1)
|
||||
|
||||
if len(matches) == 0 {
|
||||
@@ -566,7 +577,7 @@ func processNegativeIndex(jsonStr string, path string) string {
|
||||
arrayPath = arrayPath[:len(arrayPath)-1]
|
||||
}
|
||||
|
||||
array := gjson.Get(jsonStr, arrayPath)
|
||||
array := gjson.GetBytes(data, arrayPath)
|
||||
if array.IsArray() {
|
||||
length := len(array.Array())
|
||||
actualIndex := length + index
|
||||
@@ -655,36 +666,76 @@ func compareNumeric(jsonValue, targetValue gjson.Result, operator string) (bool,
|
||||
}
|
||||
}
|
||||
|
||||
// applyOperationsLegacy 原参数覆盖方法
|
||||
// applyOperationsLegacy 原参数覆盖方法。
|
||||
//
|
||||
// 旧实现把整个 jsonData unmarshal 成 map[string]interface{} 再 marshal 回来,
|
||||
// 对包含大 base64 字段(如 Gemini inlineData.data)的请求会放大数倍内存
|
||||
// (interface 装箱、map bucket、再次 marshal)。
|
||||
// 这里改成在 []byte 上直接调用 sjson.SetBytes,按顶层 key 逐个写入,
|
||||
// 不再把 payload 解码到 map[string]interface{}。
|
||||
//
|
||||
// 语义保持:每个 paramOverride 顶层 key 视为字面 key(不解析点号路径),
|
||||
// 与旧的 reqMap[key] = value 一致。包含 `.` `*` `?` `\` 的 key 会被转义,
|
||||
// 防止被 sjson 当作嵌套路径或通配符。
|
||||
func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}, auditRecorder *paramOverrideAuditRecorder) ([]byte, error) {
|
||||
reqMap := make(map[string]interface{})
|
||||
err := common.Unmarshal(jsonData, &reqMap)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if len(paramOverride) == 0 {
|
||||
return jsonData, nil
|
||||
}
|
||||
|
||||
result := jsonData
|
||||
for key, value := range paramOverride {
|
||||
reqMap[key] = value
|
||||
escaped := escapeSjsonLiteralKey(key)
|
||||
next, err := sjson.SetBytes(result, escaped, value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = next
|
||||
auditRecorder.recordOperation("set", key, "", "", value)
|
||||
}
|
||||
|
||||
return common.Marshal(reqMap)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func applyOperations(jsonStr string, operations []ParamOperation, conditionContext map[string]interface{}) (string, error) {
|
||||
// escapeSjsonLiteralKey 把可能被 sjson 误判为路径或通配符的字符转义,
|
||||
// 用于把字面 key 安全地传给 sjson.SetBytes / sjson.DeleteBytes。
|
||||
func escapeSjsonLiteralKey(key string) string {
|
||||
if !strings.ContainsAny(key, ".*?\\") {
|
||||
return key
|
||||
}
|
||||
var sb strings.Builder
|
||||
sb.Grow(len(key) + 4)
|
||||
for i := 0; i < len(key); i++ {
|
||||
c := key[i]
|
||||
switch c {
|
||||
case '.', '*', '?', '\\':
|
||||
sb.WriteByte('\\')
|
||||
}
|
||||
sb.WriteByte(c)
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// applyOperations 在 []byte 上原地应用所有 param override 操作。
|
||||
//
|
||||
// 旧实现走 string-based gjson/sjson,在 ApplyParamOverride 入口会做
|
||||
// string(jsonData) 与最终 []byte(result) 各一次整包拷贝,对大 base64
|
||||
// payload 来说每次重试都额外多花 2 倍 body 体积的临时内存。
|
||||
// 这里改成全程在 []byte 上工作,sjson.SetBytes / gjson.GetBytes 都是
|
||||
// 直接读写 []byte,每个操作只会产生一份新 buffer。
|
||||
func applyOperations(jsonData []byte, operations []ParamOperation, conditionContext map[string]interface{}) ([]byte, error) {
|
||||
context := ensureContextMap(conditionContext)
|
||||
auditRecorder := getParamOverrideAuditRecorder(context)
|
||||
contextJSON, err := marshalContextJSON(context)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal condition context: %v", err)
|
||||
return nil, fmt.Errorf("failed to marshal condition context: %v", err)
|
||||
}
|
||||
|
||||
result := jsonStr
|
||||
result := jsonData
|
||||
for _, op := range operations {
|
||||
// 检查条件是否满足
|
||||
ok, err := checkConditions(result, contextJSON, op.Conditions, op.Logic)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
if !ok {
|
||||
continue // 条件不满足,跳过当前操作
|
||||
@@ -695,7 +746,7 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
|
||||
if isPathBasedOperation(op.Mode) {
|
||||
opPaths, err = resolveOperationPaths(result, opPath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
if len(opPaths) == 0 {
|
||||
continue
|
||||
@@ -713,10 +764,10 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
|
||||
}
|
||||
case "set":
|
||||
for _, path := range opPaths {
|
||||
if op.KeepOrigin && gjson.Get(result, path).Exists() {
|
||||
if op.KeepOrigin && gjson.GetBytes(result, path).Exists() {
|
||||
continue
|
||||
}
|
||||
result, err = sjson.Set(result, path, op.Value)
|
||||
result, err = sjson.SetBytes(result, path, op.Value)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
@@ -731,7 +782,7 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
|
||||
}
|
||||
case "copy":
|
||||
if op.From == "" || op.To == "" {
|
||||
return "", fmt.Errorf("copy from/to is required")
|
||||
return nil, fmt.Errorf("copy from/to is required")
|
||||
}
|
||||
opFrom := processNegativeIndex(result, op.From)
|
||||
opTo := processNegativeIndex(result, op.To)
|
||||
@@ -831,9 +882,9 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
|
||||
auditRecorder.recordOperation("return_error", op.Path, "", "", op.Value)
|
||||
returnErr, parseErr := parseParamOverrideReturnError(op.Value)
|
||||
if parseErr != nil {
|
||||
return "", parseErr
|
||||
return nil, parseErr
|
||||
}
|
||||
return "", returnErr
|
||||
return nil, returnErr
|
||||
case "prune_objects":
|
||||
for _, path := range opPaths {
|
||||
result, err = pruneObjects(result, path, contextJSON, op.Value)
|
||||
@@ -890,7 +941,7 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
|
||||
case "pass_headers":
|
||||
headerNames, parseErr := parseHeaderPassThroughNames(op.Value)
|
||||
if parseErr != nil {
|
||||
return "", parseErr
|
||||
return nil, parseErr
|
||||
}
|
||||
for _, headerName := range headerNames {
|
||||
if err = copyHeaderInContext(context, headerName, headerName, op.KeepOrigin); err != nil {
|
||||
@@ -912,10 +963,10 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
|
||||
contextJSON, err = marshalContextJSON(context)
|
||||
}
|
||||
default:
|
||||
return "", fmt.Errorf("unknown operation: %s", op.Mode)
|
||||
return nil, fmt.Errorf("unknown operation: %s", op.Mode)
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("operation %s failed: %w", op.Mode, err)
|
||||
return nil, fmt.Errorf("operation %s failed: %w", op.Mode, err)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
@@ -1349,11 +1400,11 @@ func parseSyncTarget(spec string) (syncTarget, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func readSyncTargetValue(jsonStr string, context map[string]interface{}, target syncTarget) (interface{}, bool, error) {
|
||||
func readSyncTargetValue(data []byte, context map[string]interface{}, target syncTarget) (interface{}, bool, error) {
|
||||
switch target.kind {
|
||||
case "json":
|
||||
path := processNegativeIndex(jsonStr, target.key)
|
||||
value := gjson.Get(jsonStr, path)
|
||||
path := processNegativeIndex(data, target.key)
|
||||
value := gjson.GetBytes(data, path)
|
||||
if !value.Exists() || value.Type == gjson.Null {
|
||||
return nil, false, nil
|
||||
}
|
||||
@@ -1372,52 +1423,52 @@ func readSyncTargetValue(jsonStr string, context map[string]interface{}, target
|
||||
}
|
||||
}
|
||||
|
||||
func writeSyncTargetValue(jsonStr string, context map[string]interface{}, target syncTarget, value interface{}) (string, error) {
|
||||
func writeSyncTargetValue(data []byte, context map[string]interface{}, target syncTarget, value interface{}) ([]byte, error) {
|
||||
switch target.kind {
|
||||
case "json":
|
||||
path := processNegativeIndex(jsonStr, target.key)
|
||||
nextJSON, err := sjson.Set(jsonStr, path, value)
|
||||
path := processNegativeIndex(data, target.key)
|
||||
nextJSON, err := sjson.SetBytes(data, path, value)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
return nextJSON, nil
|
||||
case "header":
|
||||
if err := setHeaderOverrideInContext(context, target.key, value, false); err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
return jsonStr, nil
|
||||
return data, nil
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported sync_fields target kind: %s", target.kind)
|
||||
return nil, fmt.Errorf("unsupported sync_fields target kind: %s", target.kind)
|
||||
}
|
||||
}
|
||||
|
||||
func syncFieldsBetweenTargets(jsonStr string, context map[string]interface{}, fromSpec string, toSpec string) (string, error) {
|
||||
func syncFieldsBetweenTargets(data []byte, context map[string]interface{}, fromSpec string, toSpec string) ([]byte, error) {
|
||||
fromTarget, err := parseSyncTarget(fromSpec)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
toTarget, err := parseSyncTarget(toSpec)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fromValue, fromExists, err := readSyncTargetValue(jsonStr, context, fromTarget)
|
||||
fromValue, fromExists, err := readSyncTargetValue(data, context, fromTarget)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
toValue, toExists, err := readSyncTargetValue(jsonStr, context, toTarget)
|
||||
toValue, toExists, err := readSyncTargetValue(data, context, toTarget)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If one side exists and the other side is missing, sync the missing side.
|
||||
if fromExists && !toExists {
|
||||
return writeSyncTargetValue(jsonStr, context, toTarget, fromValue)
|
||||
return writeSyncTargetValue(data, context, toTarget, fromValue)
|
||||
}
|
||||
if toExists && !fromExists {
|
||||
return writeSyncTargetValue(jsonStr, context, fromTarget, toValue)
|
||||
return writeSyncTargetValue(data, context, fromTarget, toValue)
|
||||
}
|
||||
return jsonStr, nil
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func ensureMapKeyInContext(context map[string]interface{}, key string) map[string]interface{} {
|
||||
@@ -1491,24 +1542,24 @@ func syncRuntimeHeaderOverrideFromContext(info *RelayInfo, context map[string]in
|
||||
info.UseRuntimeHeadersOverride = true
|
||||
}
|
||||
|
||||
func moveValue(jsonStr, fromPath, toPath string) (string, error) {
|
||||
sourceValue := gjson.Get(jsonStr, fromPath)
|
||||
func moveValue(data []byte, fromPath, toPath string) ([]byte, error) {
|
||||
sourceValue := gjson.GetBytes(data, fromPath)
|
||||
if !sourceValue.Exists() {
|
||||
return jsonStr, fmt.Errorf("source path does not exist: %s", fromPath)
|
||||
return data, fmt.Errorf("source path does not exist: %s", fromPath)
|
||||
}
|
||||
result, err := sjson.Set(jsonStr, toPath, sourceValue.Value())
|
||||
result, err := sjson.SetBytes(data, toPath, sourceValue.Value())
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
return sjson.Delete(result, fromPath)
|
||||
return sjson.DeleteBytes(result, fromPath)
|
||||
}
|
||||
|
||||
func copyValue(jsonStr, fromPath, toPath string) (string, error) {
|
||||
sourceValue := gjson.Get(jsonStr, fromPath)
|
||||
func copyValue(data []byte, fromPath, toPath string) ([]byte, error) {
|
||||
sourceValue := gjson.GetBytes(data, fromPath)
|
||||
if !sourceValue.Exists() {
|
||||
return jsonStr, fmt.Errorf("source path does not exist: %s", fromPath)
|
||||
return data, fmt.Errorf("source path does not exist: %s", fromPath)
|
||||
}
|
||||
return sjson.Set(jsonStr, toPath, sourceValue.Value())
|
||||
return sjson.SetBytes(data, toPath, sourceValue.Value())
|
||||
}
|
||||
|
||||
func isPathBasedOperation(mode string) bool {
|
||||
@@ -1520,16 +1571,16 @@ func isPathBasedOperation(mode string) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func resolveOperationPaths(jsonStr, path string) ([]string, error) {
|
||||
func resolveOperationPaths(data []byte, path string) ([]string, error) {
|
||||
if !strings.Contains(path, "*") {
|
||||
return []string{path}, nil
|
||||
}
|
||||
return expandWildcardPaths(jsonStr, path)
|
||||
return expandWildcardPaths(data, path)
|
||||
}
|
||||
|
||||
func expandWildcardPaths(jsonStr, path string) ([]string, error) {
|
||||
func expandWildcardPaths(data []byte, path string) ([]string, error) {
|
||||
var root interface{}
|
||||
if err := common.Unmarshal([]byte(jsonStr), &root); err != nil {
|
||||
if err := common.Unmarshal(data, &root); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1590,28 +1641,28 @@ func collectWildcardPaths(node interface{}, segments []string, prefix []string)
|
||||
}
|
||||
}
|
||||
|
||||
func deleteValue(jsonStr, path string) (string, error) {
|
||||
func deleteValue(data []byte, path string) ([]byte, error) {
|
||||
if strings.TrimSpace(path) == "" {
|
||||
return jsonStr, nil
|
||||
return data, nil
|
||||
}
|
||||
return sjson.Delete(jsonStr, path)
|
||||
return sjson.DeleteBytes(data, path)
|
||||
}
|
||||
|
||||
func modifyValue(jsonStr, path string, value interface{}, keepOrigin, isPrepend bool) (string, error) {
|
||||
current := gjson.Get(jsonStr, path)
|
||||
func modifyValue(data []byte, path string, value interface{}, keepOrigin, isPrepend bool) ([]byte, error) {
|
||||
current := gjson.GetBytes(data, path)
|
||||
switch {
|
||||
case current.IsArray():
|
||||
return modifyArray(jsonStr, path, value, isPrepend)
|
||||
return modifyArray(data, path, value, isPrepend)
|
||||
case current.Type == gjson.String:
|
||||
return modifyString(jsonStr, path, value, isPrepend)
|
||||
return modifyString(data, path, value, isPrepend)
|
||||
case current.Type == gjson.JSON:
|
||||
return mergeObjects(jsonStr, path, value, keepOrigin)
|
||||
return mergeObjects(data, path, value, keepOrigin)
|
||||
}
|
||||
return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
|
||||
return data, fmt.Errorf("operation not supported for type: %v", current.Type)
|
||||
}
|
||||
|
||||
func modifyArray(jsonStr, path string, value interface{}, isPrepend bool) (string, error) {
|
||||
current := gjson.Get(jsonStr, path)
|
||||
func modifyArray(data []byte, path string, value interface{}, isPrepend bool) ([]byte, error) {
|
||||
current := gjson.GetBytes(data, path)
|
||||
var newArray []interface{}
|
||||
// 添加新值
|
||||
addValue := func() {
|
||||
@@ -1635,11 +1686,11 @@ func modifyArray(jsonStr, path string, value interface{}, isPrepend bool) (strin
|
||||
addOriginal()
|
||||
addValue()
|
||||
}
|
||||
return sjson.Set(jsonStr, path, newArray)
|
||||
return sjson.SetBytes(data, path, newArray)
|
||||
}
|
||||
|
||||
func modifyString(jsonStr, path string, value interface{}, isPrepend bool) (string, error) {
|
||||
current := gjson.Get(jsonStr, path)
|
||||
func modifyString(data []byte, path string, value interface{}, isPrepend bool) ([]byte, error) {
|
||||
current := gjson.GetBytes(data, path)
|
||||
valueStr := fmt.Sprintf("%v", value)
|
||||
var newStr string
|
||||
if isPrepend {
|
||||
@@ -1647,17 +1698,17 @@ func modifyString(jsonStr, path string, value interface{}, isPrepend bool) (stri
|
||||
} else {
|
||||
newStr = current.String() + valueStr
|
||||
}
|
||||
return sjson.Set(jsonStr, path, newStr)
|
||||
return sjson.SetBytes(data, path, newStr)
|
||||
}
|
||||
|
||||
func trimStringValue(jsonStr, path string, value interface{}, isPrefix bool) (string, error) {
|
||||
current := gjson.Get(jsonStr, path)
|
||||
func trimStringValue(data []byte, path string, value interface{}, isPrefix bool) ([]byte, error) {
|
||||
current := gjson.GetBytes(data, path)
|
||||
if current.Type != gjson.String {
|
||||
return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
|
||||
return data, fmt.Errorf("operation not supported for type: %v", current.Type)
|
||||
}
|
||||
|
||||
if value == nil {
|
||||
return jsonStr, fmt.Errorf("trim value is required")
|
||||
return data, fmt.Errorf("trim value is required")
|
||||
}
|
||||
valueStr := fmt.Sprintf("%v", value)
|
||||
|
||||
@@ -1667,69 +1718,69 @@ func trimStringValue(jsonStr, path string, value interface{}, isPrefix bool) (st
|
||||
} else {
|
||||
newStr = strings.TrimSuffix(current.String(), valueStr)
|
||||
}
|
||||
return sjson.Set(jsonStr, path, newStr)
|
||||
return sjson.SetBytes(data, path, newStr)
|
||||
}
|
||||
|
||||
func ensureStringAffix(jsonStr, path string, value interface{}, isPrefix bool) (string, error) {
|
||||
current := gjson.Get(jsonStr, path)
|
||||
func ensureStringAffix(data []byte, path string, value interface{}, isPrefix bool) ([]byte, error) {
|
||||
current := gjson.GetBytes(data, path)
|
||||
if current.Type != gjson.String {
|
||||
return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
|
||||
return data, fmt.Errorf("operation not supported for type: %v", current.Type)
|
||||
}
|
||||
|
||||
if value == nil {
|
||||
return jsonStr, fmt.Errorf("ensure value is required")
|
||||
return data, fmt.Errorf("ensure value is required")
|
||||
}
|
||||
valueStr := fmt.Sprintf("%v", value)
|
||||
if valueStr == "" {
|
||||
return jsonStr, fmt.Errorf("ensure value is required")
|
||||
return data, fmt.Errorf("ensure value is required")
|
||||
}
|
||||
|
||||
currentStr := current.String()
|
||||
if isPrefix {
|
||||
if strings.HasPrefix(currentStr, valueStr) {
|
||||
return jsonStr, nil
|
||||
return data, nil
|
||||
}
|
||||
return sjson.Set(jsonStr, path, valueStr+currentStr)
|
||||
return sjson.SetBytes(data, path, valueStr+currentStr)
|
||||
}
|
||||
|
||||
if strings.HasSuffix(currentStr, valueStr) {
|
||||
return jsonStr, nil
|
||||
return data, nil
|
||||
}
|
||||
return sjson.Set(jsonStr, path, currentStr+valueStr)
|
||||
return sjson.SetBytes(data, path, currentStr+valueStr)
|
||||
}
|
||||
|
||||
func transformStringValue(jsonStr, path string, transform func(string) string) (string, error) {
|
||||
current := gjson.Get(jsonStr, path)
|
||||
func transformStringValue(data []byte, path string, transform func(string) string) ([]byte, error) {
|
||||
current := gjson.GetBytes(data, path)
|
||||
if current.Type != gjson.String {
|
||||
return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
|
||||
return data, fmt.Errorf("operation not supported for type: %v", current.Type)
|
||||
}
|
||||
return sjson.Set(jsonStr, path, transform(current.String()))
|
||||
return sjson.SetBytes(data, path, transform(current.String()))
|
||||
}
|
||||
|
||||
func replaceStringValue(jsonStr, path, from, to string) (string, error) {
|
||||
current := gjson.Get(jsonStr, path)
|
||||
func replaceStringValue(data []byte, path, from, to string) ([]byte, error) {
|
||||
current := gjson.GetBytes(data, path)
|
||||
if current.Type != gjson.String {
|
||||
return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
|
||||
return data, fmt.Errorf("operation not supported for type: %v", current.Type)
|
||||
}
|
||||
if from == "" {
|
||||
return jsonStr, fmt.Errorf("replace from is required")
|
||||
return data, fmt.Errorf("replace from is required")
|
||||
}
|
||||
return sjson.Set(jsonStr, path, strings.ReplaceAll(current.String(), from, to))
|
||||
return sjson.SetBytes(data, path, strings.ReplaceAll(current.String(), from, to))
|
||||
}
|
||||
|
||||
func regexReplaceStringValue(jsonStr, path, pattern, replacement string) (string, error) {
|
||||
current := gjson.Get(jsonStr, path)
|
||||
func regexReplaceStringValue(data []byte, path, pattern, replacement string) ([]byte, error) {
|
||||
current := gjson.GetBytes(data, path)
|
||||
if current.Type != gjson.String {
|
||||
return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
|
||||
return data, fmt.Errorf("operation not supported for type: %v", current.Type)
|
||||
}
|
||||
if pattern == "" {
|
||||
return jsonStr, fmt.Errorf("regex pattern is required")
|
||||
return data, fmt.Errorf("regex pattern is required")
|
||||
}
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return jsonStr, err
|
||||
return data, err
|
||||
}
|
||||
return sjson.Set(jsonStr, path, re.ReplaceAllString(current.String(), replacement))
|
||||
return sjson.SetBytes(data, path, re.ReplaceAllString(current.String(), replacement))
|
||||
}
|
||||
|
||||
type pruneObjectsOptions struct {
|
||||
@@ -1738,37 +1789,33 @@ type pruneObjectsOptions struct {
|
||||
recursive bool
|
||||
}
|
||||
|
||||
func pruneObjects(jsonStr, path, contextJSON string, value interface{}) (string, error) {
|
||||
func pruneObjects(data []byte, path, contextJSON string, value interface{}) ([]byte, error) {
|
||||
options, err := parsePruneObjectsOptions(value)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if path == "" {
|
||||
var root interface{}
|
||||
if err := common.Unmarshal([]byte(jsonStr), &root); err != nil {
|
||||
return "", err
|
||||
if err := common.Unmarshal(data, &root); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cleaned, _, err := pruneObjectsNode(root, options, contextJSON, true)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
cleanedBytes, err := common.Marshal(cleaned)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(cleanedBytes), nil
|
||||
return common.Marshal(cleaned)
|
||||
}
|
||||
|
||||
target := gjson.Get(jsonStr, path)
|
||||
target := gjson.GetBytes(data, path)
|
||||
if !target.Exists() {
|
||||
return jsonStr, nil
|
||||
return data, nil
|
||||
}
|
||||
|
||||
var targetNode interface{}
|
||||
if target.Type == gjson.JSON {
|
||||
if err := common.Unmarshal([]byte(target.Raw), &targetNode); err != nil {
|
||||
return "", err
|
||||
if err := common.UnmarshalJsonStr(target.Raw, &targetNode); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
targetNode = target.Value()
|
||||
@@ -1776,13 +1823,13 @@ func pruneObjects(jsonStr, path, contextJSON string, value interface{}) (string,
|
||||
|
||||
cleaned, _, err := pruneObjectsNode(targetNode, options, contextJSON, true)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
cleanedBytes, err := common.Marshal(cleaned)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
return sjson.SetRaw(jsonStr, path, string(cleanedBytes))
|
||||
return sjson.SetRawBytes(data, path, cleanedBytes)
|
||||
}
|
||||
|
||||
func parsePruneObjectsOptions(value interface{}) (pruneObjectsOptions, error) {
|
||||
@@ -1958,16 +2005,16 @@ func shouldPruneObject(node map[string]interface{}, options pruneObjectsOptions,
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return checkConditions(string(nodeBytes), contextJSON, options.conditions, options.logic)
|
||||
return checkConditions(nodeBytes, contextJSON, options.conditions, options.logic)
|
||||
}
|
||||
|
||||
func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (string, error) {
|
||||
current := gjson.Get(jsonStr, path)
|
||||
func mergeObjects(data []byte, path string, value interface{}, keepOrigin bool) ([]byte, error) {
|
||||
current := gjson.GetBytes(data, path)
|
||||
var currentMap, newMap map[string]interface{}
|
||||
|
||||
// 解析当前值
|
||||
if err := common.Unmarshal([]byte(current.Raw), ¤tMap); err != nil {
|
||||
return "", err
|
||||
// 解析当前值(current.Raw 是 data 的子串,避免再分配一份)
|
||||
if err := common.UnmarshalJsonStr(current.Raw, ¤tMap); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 解析新值
|
||||
switch v := value.(type) {
|
||||
@@ -1976,7 +2023,7 @@ func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (str
|
||||
default:
|
||||
jsonBytes, _ := common.Marshal(v)
|
||||
if err := common.Unmarshal(jsonBytes, &newMap); err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
// 合并
|
||||
@@ -1989,7 +2036,7 @@ func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (str
|
||||
result[k] = v
|
||||
}
|
||||
}
|
||||
return sjson.Set(jsonStr, path, result)
|
||||
return sjson.SetBytes(data, path, result)
|
||||
}
|
||||
|
||||
// BuildParamOverrideContext 提供 ApplyParamOverride 可用的上下文信息。
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/setting/model_setting"
|
||||
"github.com/samber/lo"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestApplyParamOverrideTrimPrefix(t *testing.T) {
|
||||
@@ -2053,6 +2054,17 @@ func TestRemoveDisabledFieldsDefaultFiltering(t *testing.T) {
|
||||
assertJSONEqual(t, `{"cache_control":{"type":"ephemeral"},"store":true}`, string(out))
|
||||
}
|
||||
|
||||
func TestRemoveDisabledFieldsNoControlledFieldsKeepsBody(t *testing.T) {
|
||||
input := `{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}]}`
|
||||
settings := dto.ChannelOtherSettings{}
|
||||
|
||||
out, err := RemoveDisabledFields([]byte(input), settings, false)
|
||||
if err != nil {
|
||||
t.Fatalf("RemoveDisabledFields returned error: %v", err)
|
||||
}
|
||||
require.Equal(t, input, string(out))
|
||||
}
|
||||
|
||||
func TestRemoveDisabledFieldsAllowInferenceGeo(t *testing.T) {
|
||||
input := `{
|
||||
"inference_geo":"eu",
|
||||
@@ -2184,6 +2196,95 @@ func TestApplyParamOverrideWithRelayInfoRecordsOnlyKeyOperationsWhenDebugDisable
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideWithRelayInfoRecordsConversationBodyOperationsWhenDebugDisabled(t *testing.T) {
|
||||
originalDebugEnabled := common2.DebugEnabled
|
||||
common2.DebugEnabled = false
|
||||
t.Cleanup(func() {
|
||||
common2.DebugEnabled = originalDebugEnabled
|
||||
})
|
||||
|
||||
info := &RelayInfo{
|
||||
ChannelMeta: &ChannelMeta{
|
||||
ParamOverride: map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"mode": "replace",
|
||||
"path": "messages.0.content",
|
||||
"from": "hello",
|
||||
"to": "hi",
|
||||
},
|
||||
map[string]interface{}{
|
||||
"mode": "set",
|
||||
"path": "input.0.content.0.text",
|
||||
"value": "rewritten response input",
|
||||
},
|
||||
map[string]interface{}{
|
||||
"mode": "set",
|
||||
"path": "instructions",
|
||||
"value": "new instruction",
|
||||
},
|
||||
map[string]interface{}{
|
||||
"mode": "append",
|
||||
"path": "contents.0.parts",
|
||||
"value": map[string]interface{}{"text": "new gemini part"},
|
||||
},
|
||||
map[string]interface{}{
|
||||
"mode": "copy",
|
||||
"from": "system",
|
||||
"to": "metadata.system_copy",
|
||||
},
|
||||
map[string]interface{}{
|
||||
"mode": "set",
|
||||
"path": "temperature",
|
||||
"value": 0.1,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
out, err := ApplyParamOverrideWithRelayInfo([]byte(`{
|
||||
"messages":[{"role":"user","content":"hello world"}],
|
||||
"input":[{"role":"user","content":[{"type":"input_text","text":"original response input"}]}],
|
||||
"instructions":"old instruction",
|
||||
"system":"old system",
|
||||
"contents":[{"role":"user","parts":[{"text":"hello gemini"}]}],
|
||||
"temperature":0.7
|
||||
}`), info)
|
||||
require.NoError(t, err)
|
||||
assertJSONEqual(t, `{
|
||||
"messages":[{"role":"user","content":"hi world"}],
|
||||
"input":[{"role":"user","content":[{"type":"input_text","text":"rewritten response input"}]}],
|
||||
"instructions":"new instruction",
|
||||
"system":"old system",
|
||||
"contents":[{"role":"user","parts":[{"text":"hello gemini"},{"text":"new gemini part"}]}],
|
||||
"temperature":0.1,
|
||||
"metadata":{"system_copy":"old system"}
|
||||
}`, string(out))
|
||||
|
||||
require.Equal(t, []string{
|
||||
"replace messages.0.content from hello to hi",
|
||||
"set input.0.content.0.text = rewritten response input",
|
||||
"set instructions = new instruction",
|
||||
"append contents.0.parts with {\"text\":\"new gemini part\"}",
|
||||
"copy system -> metadata.system_copy",
|
||||
}, info.ParamOverrideAudit)
|
||||
}
|
||||
|
||||
func TestShouldAuditParamPathUsesFieldBoundaryPrefixMatching(t *testing.T) {
|
||||
originalDebugEnabled := common2.DebugEnabled
|
||||
common2.DebugEnabled = false
|
||||
t.Cleanup(func() {
|
||||
common2.DebugEnabled = originalDebugEnabled
|
||||
})
|
||||
|
||||
require.True(t, shouldAuditParamPath("messages"))
|
||||
require.True(t, shouldAuditParamPath("messages.0.content"))
|
||||
require.True(t, shouldAuditParamPath("systemInstruction.parts.0.text"))
|
||||
require.False(t, shouldAuditParamPath("model_name"))
|
||||
require.False(t, shouldAuditParamPath("message"))
|
||||
}
|
||||
|
||||
func assertJSONEqual(t *testing.T, want, got string) {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type ThinkingContentInfo struct {
|
||||
@@ -153,6 +154,13 @@ type RelayInfo struct {
|
||||
UseRuntimeHeadersOverride bool
|
||||
ParamOverrideAudit []string
|
||||
|
||||
// UpstreamRequestBodySize is the byte size of the marshaled upstream request
|
||||
// body. It is set when the body is wrapped in a BodyStorage (see
|
||||
// relay/common/outbound_body.go), so that DoApiRequest can populate
|
||||
// http.Request.ContentLength manually (net/http only auto-detects it for
|
||||
// *bytes.Reader/Buffer/strings.Reader). 0 means "let net/http decide".
|
||||
UpstreamRequestBodySize int64
|
||||
|
||||
PriceData types.PriceData
|
||||
|
||||
// TieredBillingSnapshot is a frozen snapshot of tiered billing rules
|
||||
@@ -785,6 +793,9 @@ func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOther
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || channelPassThroughEnabled {
|
||||
return jsonData, nil
|
||||
}
|
||||
if !hasRemovableDisabledField(jsonData, channelOtherSettings) {
|
||||
return jsonData, nil
|
||||
}
|
||||
|
||||
var data map[string]interface{}
|
||||
if err := common.Unmarshal(jsonData, &data); err != nil {
|
||||
@@ -851,6 +862,25 @@ func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOther
|
||||
return jsonDataAfter, nil
|
||||
}
|
||||
|
||||
func hasRemovableDisabledField(jsonData []byte, channelOtherSettings dto.ChannelOtherSettings) bool {
|
||||
values := gjson.GetManyBytes(
|
||||
jsonData,
|
||||
"service_tier",
|
||||
"inference_geo",
|
||||
"speed",
|
||||
"store",
|
||||
"safety_identifier",
|
||||
"stream_options.include_obfuscation",
|
||||
)
|
||||
|
||||
return (!channelOtherSettings.AllowServiceTier && values[0].Exists()) ||
|
||||
(!channelOtherSettings.AllowInferenceGeo && values[1].Exists()) ||
|
||||
(!channelOtherSettings.AllowSpeed && values[2].Exists()) ||
|
||||
(channelOtherSettings.DisableStore && values[3].Exists()) ||
|
||||
(!channelOtherSettings.AllowSafetyIdentifier && values[4].Exists()) ||
|
||||
(!channelOtherSettings.AllowIncludeObfuscation && values[5].Exists())
|
||||
}
|
||||
|
||||
// RemoveGeminiDisabledFields removes disabled fields from Gemini request JSON data
|
||||
// Currently supports removing functionResponse.id field which Vertex AI does not support
|
||||
func RemoveGeminiDisabledFields(jsonData []byte) ([]byte, error) {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -176,7 +175,14 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
|
||||
|
||||
logger.LogDebug(c, "text request body: %s", jsonData)
|
||||
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
defer closer.Close()
|
||||
jsonData = nil
|
||||
info.UpstreamRequestBodySize = size
|
||||
requestBody = body
|
||||
}
|
||||
|
||||
var httpResp *http.Response
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -59,7 +58,14 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
||||
}
|
||||
|
||||
logger.LogDebug(c, "converted embedding request body: %s", jsonData)
|
||||
var requestBody io.Reader = bytes.NewBuffer(jsonData)
|
||||
body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
defer closer.Close()
|
||||
jsonData = nil
|
||||
info.UpstreamRequestBodySize = size
|
||||
var requestBody io.Reader = body
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
resp, err := adaptor.DoRequest(c, info, requestBody)
|
||||
if err != nil {
|
||||
|
||||
+16
-3
@@ -1,7 +1,6 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -165,7 +164,14 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
|
||||
logger.LogDebug(c, "Gemini request body: %s", jsonData)
|
||||
|
||||
requestBody = bytes.NewReader(jsonData)
|
||||
body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
defer closer.Close()
|
||||
jsonData = nil
|
||||
info.UpstreamRequestBodySize = size
|
||||
requestBody = body
|
||||
}
|
||||
|
||||
resp, err := adaptor.DoRequest(c, info, requestBody)
|
||||
@@ -263,7 +269,14 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI
|
||||
}
|
||||
}
|
||||
logger.LogDebug(c, "Gemini embedding request body: %s", jsonData)
|
||||
requestBody = bytes.NewReader(jsonData)
|
||||
body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
defer closer.Close()
|
||||
jsonData = nil
|
||||
info.UpstreamRequestBodySize = size
|
||||
requestBody = body
|
||||
|
||||
resp, err := adaptor.DoRequest(c, info, requestBody)
|
||||
if err != nil {
|
||||
|
||||
@@ -34,6 +34,12 @@ func getScannerBufferSize() int {
|
||||
return DefaultMaxScannerBufferSize
|
||||
}
|
||||
|
||||
func NewStreamScanner(reader io.Reader) *bufio.Scanner {
|
||||
scanner := bufio.NewScanner(reader)
|
||||
scanner.Buffer(make([]byte, InitialScannerBufferSize), getScannerBufferSize())
|
||||
return scanner
|
||||
}
|
||||
|
||||
func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string, sr *StreamResult)) {
|
||||
|
||||
if resp == nil || dataHandler == nil {
|
||||
@@ -54,7 +60,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
|
||||
var (
|
||||
stopChan = make(chan bool, 3) // 增加缓冲区避免阻塞
|
||||
scanner = bufio.NewScanner(resp.Body)
|
||||
scanner = NewStreamScanner(resp.Body)
|
||||
ticker = time.NewTicker(streamingTimeout)
|
||||
pingTicker *time.Ticker
|
||||
writeMutex sync.Mutex // Mutex to protect concurrent writes
|
||||
@@ -104,7 +110,6 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
close(stopChan)
|
||||
}()
|
||||
|
||||
scanner.Buffer(make([]byte, InitialScannerBufferSize), getScannerBufferSize())
|
||||
scanner.Split(bufio.ScanLines)
|
||||
SetEventStreamHeaders(c)
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package helper
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -81,6 +82,22 @@ func TestStreamScannerHandler_NilInputs(t *testing.T) {
|
||||
StreamScannerHandler(c, &http.Response{Body: io.NopCloser(strings.NewReader(""))}, info, nil)
|
||||
}
|
||||
|
||||
func TestNewStreamScanner_AllowsLargeStreamLine(t *testing.T) {
|
||||
oldBufferMB := constant.StreamScannerMaxBufferMB
|
||||
constant.StreamScannerMaxBufferMB = 1
|
||||
t.Cleanup(func() {
|
||||
constant.StreamScannerMaxBufferMB = oldBufferMB
|
||||
})
|
||||
|
||||
payload := strings.Repeat("x", 128<<10)
|
||||
scanner := NewStreamScanner(strings.NewReader("data: " + payload + "\n"))
|
||||
scanner.Split(bufio.ScanLines)
|
||||
|
||||
require.True(t, scanner.Scan())
|
||||
assert.Equal(t, "data: "+payload, scanner.Text())
|
||||
require.NoError(t, scanner.Err())
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_EmptyBody(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
+11
-4
@@ -77,7 +77,14 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
||||
}
|
||||
|
||||
logger.LogDebug(c, "image request body: %s", jsonData)
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
defer closer.Close()
|
||||
jsonData = nil
|
||||
info.UpstreamRequestBodySize = size
|
||||
requestBody = body
|
||||
}
|
||||
}
|
||||
|
||||
@@ -133,9 +140,9 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
||||
usage.(*dto.Usage).PromptTokens = 1
|
||||
}
|
||||
|
||||
quality := "standard"
|
||||
if request.Quality == "hd" {
|
||||
quality = "hd"
|
||||
quality := request.Quality
|
||||
if quality == "" {
|
||||
quality = "standard"
|
||||
}
|
||||
|
||||
var logContent []string
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -69,7 +68,14 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
}
|
||||
|
||||
logger.LogDebug(c, "Rerank request body: %s", jsonData)
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
defer closer.Close()
|
||||
jsonData = nil
|
||||
info.UpstreamRequestBodySize = size
|
||||
requestBody = body
|
||||
}
|
||||
|
||||
resp, err := adaptor.DoRequest(c, info, requestBody)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -104,7 +103,14 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
||||
}
|
||||
|
||||
logger.LogDebug(c, "requestBody: %s", jsonData)
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
defer closer.Close()
|
||||
jsonData = nil
|
||||
info.UpstreamRequestBodySize = size
|
||||
requestBody = body
|
||||
}
|
||||
|
||||
var httpResp *http.Response
|
||||
|
||||
+28
-18
@@ -17,9 +17,10 @@ func SetApiRouter(router *gin.Engine) {
|
||||
apiRouter.Use(gzip.Gzip(gzip.DefaultCompression))
|
||||
apiRouter.Use(middleware.BodyStorageCleanup()) // 清理请求体存储
|
||||
apiRouter.Use(middleware.GlobalAPIRateLimit())
|
||||
anonymousRequestBodyLimit := middleware.AnonymousRequestBodyLimit()
|
||||
{
|
||||
apiRouter.GET("/setup", controller.GetSetup)
|
||||
apiRouter.POST("/setup", controller.PostSetup)
|
||||
apiRouter.POST("/setup", anonymousRequestBodyLimit, controller.PostSetup)
|
||||
apiRouter.GET("/status", controller.GetStatus)
|
||||
apiRouter.GET("/uptime/status", controller.GetUptimeKumaStatus)
|
||||
apiRouter.GET("/models", middleware.UserAuth(), controller.DashboardListModels)
|
||||
@@ -40,37 +41,39 @@ func SetApiRouter(router *gin.Engine) {
|
||||
apiRouter.GET("/rankings", middleware.HeaderNavModuleAuth("rankings"), controller.GetRankings)
|
||||
apiRouter.GET("/verification", middleware.EmailVerificationRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification)
|
||||
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
|
||||
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
|
||||
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), anonymousRequestBodyLimit, controller.ResetPassword)
|
||||
// OAuth routes - specific routes must come before :provider wildcard
|
||||
apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode)
|
||||
apiRouter.POST("/oauth/email/bind", middleware.CriticalRateLimit(), controller.EmailBind)
|
||||
apiRouter.POST("/oauth/email/bind", middleware.CriticalRateLimit(), anonymousRequestBodyLimit, controller.EmailBind)
|
||||
// Non-standard OAuth (WeChat, Telegram) - keep original routes
|
||||
apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)
|
||||
apiRouter.POST("/oauth/wechat/bind", middleware.CriticalRateLimit(), controller.WeChatBind)
|
||||
apiRouter.POST("/oauth/wechat/bind", middleware.CriticalRateLimit(), anonymousRequestBodyLimit, controller.WeChatBind)
|
||||
apiRouter.GET("/oauth/telegram/login", middleware.CriticalRateLimit(), controller.TelegramLogin)
|
||||
apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), controller.TelegramBind)
|
||||
// Standard OAuth providers (GitHub, Discord, OIDC, LinuxDO) - unified route
|
||||
apiRouter.GET("/oauth/:provider", middleware.CriticalRateLimit(), controller.HandleOAuth)
|
||||
apiRouter.GET("/ratio_config", middleware.CriticalRateLimit(), controller.GetRatioConfig)
|
||||
|
||||
apiRouter.POST("/stripe/webhook", controller.StripeWebhook)
|
||||
apiRouter.POST("/creem/webhook", controller.CreemWebhook)
|
||||
apiRouter.POST("/waffo/webhook", controller.WaffoWebhook)
|
||||
//apiRouter.POST("/waffo-pancake/webhook", controller.WaffoPancakeWebhook)
|
||||
apiRouter.POST("/stripe/webhook", anonymousRequestBodyLimit, controller.StripeWebhook)
|
||||
apiRouter.POST("/creem/webhook", anonymousRequestBodyLimit, controller.CreemWebhook)
|
||||
apiRouter.POST("/waffo/webhook", anonymousRequestBodyLimit, controller.WaffoWebhook)
|
||||
// :env separates test vs prod URLs so the operator can register each
|
||||
// in Pancake's matching webhook slot; handler enforces env match.
|
||||
apiRouter.POST("/waffo-pancake/webhook/:env", anonymousRequestBodyLimit, controller.WaffoPancakeWebhook)
|
||||
|
||||
// Universal secure verification routes
|
||||
apiRouter.POST("/verify", middleware.UserAuth(), middleware.CriticalRateLimit(), controller.UniversalVerify)
|
||||
|
||||
userRoute := apiRouter.Group("/user")
|
||||
{
|
||||
userRoute.POST("/register", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Register)
|
||||
userRoute.POST("/login", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Login)
|
||||
userRoute.POST("/login/2fa", middleware.CriticalRateLimit(), controller.Verify2FALogin)
|
||||
userRoute.POST("/passkey/login/begin", middleware.CriticalRateLimit(), controller.PasskeyLoginBegin)
|
||||
userRoute.POST("/passkey/login/finish", middleware.CriticalRateLimit(), controller.PasskeyLoginFinish)
|
||||
userRoute.POST("/register", middleware.CriticalRateLimit(), anonymousRequestBodyLimit, middleware.TurnstileCheck(), controller.Register)
|
||||
userRoute.POST("/login", middleware.CriticalRateLimit(), anonymousRequestBodyLimit, middleware.TurnstileCheck(), controller.Login)
|
||||
userRoute.POST("/login/2fa", middleware.CriticalRateLimit(), anonymousRequestBodyLimit, controller.Verify2FALogin)
|
||||
userRoute.POST("/passkey/login/begin", middleware.CriticalRateLimit(), anonymousRequestBodyLimit, controller.PasskeyLoginBegin)
|
||||
userRoute.POST("/passkey/login/finish", middleware.CriticalRateLimit(), anonymousRequestBodyLimit, controller.PasskeyLoginFinish)
|
||||
//userRoute.POST("/tokenlog", middleware.CriticalRateLimit(), controller.TokenLog)
|
||||
userRoute.GET("/logout", controller.Logout)
|
||||
userRoute.POST("/epay/notify", controller.EpayNotify)
|
||||
userRoute.POST("/epay/notify", anonymousRequestBodyLimit, controller.EpayNotify)
|
||||
userRoute.GET("/epay/notify", controller.EpayNotify)
|
||||
userRoute.GET("/groups", controller.GetUserGroups)
|
||||
|
||||
@@ -100,8 +103,8 @@ func SetApiRouter(router *gin.Engine) {
|
||||
selfRoute.POST("/creem/pay", middleware.CriticalRateLimit(), controller.RequestCreemPay)
|
||||
selfRoute.POST("/waffo/amount", controller.RequestWaffoAmount)
|
||||
selfRoute.POST("/waffo/pay", middleware.CriticalRateLimit(), controller.RequestWaffoPay)
|
||||
//selfRoute.POST("/waffo-pancake/amount", controller.RequestWaffoPancakeAmount)
|
||||
//selfRoute.POST("/waffo-pancake/pay", middleware.CriticalRateLimit(), controller.RequestWaffoPancakePay)
|
||||
selfRoute.POST("/waffo-pancake/amount", controller.RequestWaffoPancakeAmount)
|
||||
selfRoute.POST("/waffo-pancake/pay", middleware.CriticalRateLimit(), controller.RequestWaffoPancakePay)
|
||||
selfRoute.POST("/aff_transfer", controller.TransferAffQuota)
|
||||
selfRoute.PUT("/setting", controller.UpdateUserSetting)
|
||||
|
||||
@@ -151,9 +154,11 @@ func SetApiRouter(router *gin.Engine) {
|
||||
subscriptionRoute.GET("/plans", controller.GetSubscriptionPlans)
|
||||
subscriptionRoute.GET("/self", controller.GetSubscriptionSelf)
|
||||
subscriptionRoute.PUT("/self/preference", controller.UpdateSubscriptionPreference)
|
||||
subscriptionRoute.POST("/balance/pay", middleware.CriticalRateLimit(), controller.SubscriptionRequestBalancePay)
|
||||
subscriptionRoute.POST("/epay/pay", middleware.CriticalRateLimit(), controller.SubscriptionRequestEpay)
|
||||
subscriptionRoute.POST("/stripe/pay", middleware.CriticalRateLimit(), controller.SubscriptionRequestStripePay)
|
||||
subscriptionRoute.POST("/creem/pay", middleware.CriticalRateLimit(), controller.SubscriptionRequestCreemPay)
|
||||
subscriptionRoute.POST("/waffo-pancake/pay", middleware.CriticalRateLimit(), controller.SubscriptionRequestWaffoPancakePay)
|
||||
}
|
||||
subscriptionAdminRoute := apiRouter.Group("/subscription/admin")
|
||||
subscriptionAdminRoute.Use(middleware.AdminAuth())
|
||||
@@ -172,10 +177,10 @@ func SetApiRouter(router *gin.Engine) {
|
||||
}
|
||||
|
||||
// Subscription payment callbacks (no auth)
|
||||
apiRouter.POST("/subscription/epay/notify", controller.SubscriptionEpayNotify)
|
||||
apiRouter.POST("/subscription/epay/notify", anonymousRequestBodyLimit, controller.SubscriptionEpayNotify)
|
||||
apiRouter.GET("/subscription/epay/notify", controller.SubscriptionEpayNotify)
|
||||
apiRouter.GET("/subscription/epay/return", controller.SubscriptionEpayReturn)
|
||||
apiRouter.POST("/subscription/epay/return", controller.SubscriptionEpayReturn)
|
||||
apiRouter.POST("/subscription/epay/return", anonymousRequestBodyLimit, controller.SubscriptionEpayReturn)
|
||||
optionRoute := apiRouter.Group("/option")
|
||||
optionRoute.Use(middleware.RootAuth())
|
||||
{
|
||||
@@ -186,6 +191,11 @@ func SetApiRouter(router *gin.Engine) {
|
||||
optionRoute.DELETE("/channel_affinity_cache", controller.ClearChannelAffinityCache)
|
||||
optionRoute.POST("/rest_model_ratio", controller.ResetModelRatio)
|
||||
optionRoute.POST("/migrate_console_setting", controller.MigrateConsoleSetting) // 用于迁移检测的旧键,下个版本会删除
|
||||
optionRoute.POST("/waffo-pancake/catalog", controller.ListWaffoPancakeCatalog)
|
||||
optionRoute.POST("/waffo-pancake/pair", controller.CreateWaffoPancakePair)
|
||||
optionRoute.POST("/waffo-pancake/save", controller.SaveWaffoPancake)
|
||||
optionRoute.POST("/waffo-pancake/subscription-product", controller.CreateWaffoPancakeSubscriptionProduct)
|
||||
optionRoute.POST("/waffo-pancake/subscription-product-options", controller.ListWaffoPancakeSubscriptionProductOptions)
|
||||
}
|
||||
|
||||
// Custom OAuth provider management (root only)
|
||||
|
||||
+1
-1
@@ -17,7 +17,7 @@ func formatNotifyType(channelId int, status int) string {
|
||||
|
||||
// disable & notify
|
||||
func DisableChannel(channelError types.ChannelError, reason string) {
|
||||
common.SysLog(fmt.Sprintf("通道「%s」(#%d)发生错误,准备禁用,原因:%s", channelError.ChannelName, channelError.ChannelId, reason))
|
||||
common.SysLog(fmt.Sprintf("通道「%s」(#%d)发生错误,准备禁用,原因:%s", channelError.ChannelName, channelError.ChannelId, common.LocalLogPreview(reason)))
|
||||
|
||||
// 检查是否启用自动禁用功能
|
||||
if !channelError.AutoBan {
|
||||
|
||||
@@ -641,6 +641,38 @@ func ShouldSkipRetryAfterChannelAffinityFailure(c *gin.Context) bool {
|
||||
return meta.SkipRetry
|
||||
}
|
||||
|
||||
func ClearCurrentChannelAffinityCache(c *gin.Context) bool {
|
||||
if c == nil {
|
||||
return false
|
||||
}
|
||||
cacheKey, _, ok := getChannelAffinityContext(c)
|
||||
if !ok || cacheKey == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
cache := getChannelAffinityCache()
|
||||
deleted, err := cache.DeleteMany([]string{cacheKey})
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("channel affinity cache delete current failed: err=%v", err))
|
||||
return false
|
||||
}
|
||||
c.Set(ginKeyChannelAffinitySkipRetry, false)
|
||||
for _, ok := range deleted {
|
||||
if ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func ShouldKeepChannelAffinityOnChannelDisabled() bool {
|
||||
setting := operation_setting.GetChannelAffinitySetting()
|
||||
if setting == nil {
|
||||
return false
|
||||
}
|
||||
return setting.KeepOnChannelDisabled
|
||||
}
|
||||
|
||||
func MarkChannelAffinityUsed(c *gin.Context, selectedGroup string, channelID int) {
|
||||
if c == nil || channelID <= 0 {
|
||||
return
|
||||
|
||||
@@ -236,6 +236,33 @@ func TestGetPreferredChannelByAffinity_RequestHeaderKeySource(t *testing.T) {
|
||||
require.Equal(t, buildChannelAffinityKeyHint(affinityValue), meta.KeyHint)
|
||||
}
|
||||
|
||||
func TestClearCurrentChannelAffinityCache(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
cacheKeySuffix := fmt.Sprintf("codex cli trace:default:clear-current-%d", time.Now().UnixNano())
|
||||
cacheKeyFull := channelAffinityCacheNamespace + ":" + cacheKeySuffix
|
||||
cache := getChannelAffinityCache()
|
||||
require.NoError(t, cache.SetWithTTL(cacheKeySuffix, 9527, time.Minute))
|
||||
t.Cleanup(func() {
|
||||
_, _ = cache.DeleteMany([]string{cacheKeySuffix})
|
||||
})
|
||||
|
||||
ctx := buildChannelAffinityTemplateContextForTest(channelAffinityMeta{
|
||||
CacheKey: cacheKeyFull,
|
||||
TTLSeconds: 60,
|
||||
RuleName: "codex cli trace",
|
||||
SkipRetry: true,
|
||||
})
|
||||
require.True(t, ShouldSkipRetryAfterChannelAffinityFailure(ctx))
|
||||
|
||||
deleted := ClearCurrentChannelAffinityCache(ctx)
|
||||
require.True(t, deleted)
|
||||
_, found, err := cache.Get(cacheKeySuffix)
|
||||
require.NoError(t, err)
|
||||
require.False(t, found)
|
||||
require.False(t, ShouldSkipRetryAfterChannelAffinityFailure(ctx))
|
||||
}
|
||||
|
||||
func TestChannelAffinityHitCodexTemplatePassHeadersEffective(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
+5
-3
@@ -92,11 +92,13 @@ func RelayErrorHandler(ctx context.Context, resp *http.Response, showBodyWhenFai
|
||||
}
|
||||
CloseResponseBodyGracefully(resp)
|
||||
var errResponse dto.GeneralErrorResponse
|
||||
responseBodyText := string(responseBody)
|
||||
responseBodyPreview := common.LocalLogPreview(responseBodyText)
|
||||
buildErrWithBody := func(message string) error {
|
||||
if message == "" {
|
||||
return fmt.Errorf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody))
|
||||
return fmt.Errorf("bad response status code %d, body: %s", resp.StatusCode, responseBodyText)
|
||||
}
|
||||
return fmt.Errorf("bad response status code %d, message: %s, body: %s", resp.StatusCode, message, string(responseBody))
|
||||
return fmt.Errorf("bad response status code %d, message: %s, body: %s", resp.StatusCode, message, responseBodyText)
|
||||
}
|
||||
|
||||
err = common.Unmarshal(responseBody, &errResponse)
|
||||
@@ -104,7 +106,7 @@ func RelayErrorHandler(ctx context.Context, resp *http.Response, showBodyWhenFai
|
||||
if showBodyWhenFail {
|
||||
newApiErr.Err = buildErrWithBody("")
|
||||
} else {
|
||||
logger.LogError(ctx, fmt.Sprintf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody)))
|
||||
logger.LogError(ctx, fmt.Sprintf("bad response status code %d, body: %s", resp.StatusCode, responseBodyPreview))
|
||||
newApiErr.Err = fmt.Errorf("bad response status code %d", resp.StatusCode)
|
||||
}
|
||||
return
|
||||
|
||||
@@ -1,9 +1,17 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -55,3 +63,99 @@ func TestResetStatusCode(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRelayErrorHandlerTruncatesInvalidJSONBodyInLog(t *testing.T) {
|
||||
withDebugEnabled(t, false)
|
||||
|
||||
body := strings.Repeat("b", common.LocalLogContentLimit+256)
|
||||
var logBuffer bytes.Buffer
|
||||
|
||||
common.LogWriterMu.Lock()
|
||||
oldWriter := gin.DefaultErrorWriter
|
||||
gin.DefaultErrorWriter = &logBuffer
|
||||
common.LogWriterMu.Unlock()
|
||||
t.Cleanup(func() {
|
||||
common.LogWriterMu.Lock()
|
||||
gin.DefaultErrorWriter = oldWriter
|
||||
common.LogWriterMu.Unlock()
|
||||
})
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
}
|
||||
|
||||
newAPIError := RelayErrorHandler(context.Background(), resp, false)
|
||||
|
||||
require.NotNil(t, newAPIError)
|
||||
require.Equal(t, "bad response status code 500", newAPIError.Error())
|
||||
require.Contains(t, logBuffer.String(), "[truncated")
|
||||
require.Contains(t, logBuffer.String(), fmt.Sprintf("original_length=%d", len(body)))
|
||||
require.NotContains(t, logBuffer.String(), strings.Repeat("b", common.LocalLogContentLimit+1))
|
||||
}
|
||||
|
||||
func TestRelayErrorHandlerKeepsStructuredErrorMessage(t *testing.T) {
|
||||
message := strings.Repeat("c", common.LocalLogContentLimit+256)
|
||||
body := `{"message":"` + message + `"}`
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
}
|
||||
|
||||
newAPIError := RelayErrorHandler(context.Background(), resp, false)
|
||||
|
||||
require.NotNil(t, newAPIError)
|
||||
require.Equal(t, message, newAPIError.Error())
|
||||
}
|
||||
|
||||
func TestRelayErrorHandlerKeepsOpenAIErrorMessage(t *testing.T) {
|
||||
message := strings.Repeat("d", common.LocalLogContentLimit+256)
|
||||
body := `{"error":{"message":"` + message + `","type":"server_error","code":"server_error"}}`
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
}
|
||||
|
||||
newAPIError := RelayErrorHandler(context.Background(), resp, false)
|
||||
|
||||
require.NotNil(t, newAPIError)
|
||||
require.Equal(t, message, newAPIError.Error())
|
||||
}
|
||||
|
||||
func TestRelayErrorHandlerKeepsInvalidJSONBodyInDebugLog(t *testing.T) {
|
||||
withDebugEnabled(t, true)
|
||||
|
||||
body := strings.Repeat("e", common.LocalLogContentLimit+256)
|
||||
var logBuffer bytes.Buffer
|
||||
|
||||
common.LogWriterMu.Lock()
|
||||
oldWriter := gin.DefaultErrorWriter
|
||||
gin.DefaultErrorWriter = &logBuffer
|
||||
common.LogWriterMu.Unlock()
|
||||
t.Cleanup(func() {
|
||||
common.LogWriterMu.Lock()
|
||||
gin.DefaultErrorWriter = oldWriter
|
||||
common.LogWriterMu.Unlock()
|
||||
})
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
}
|
||||
|
||||
newAPIError := RelayErrorHandler(context.Background(), resp, false)
|
||||
|
||||
require.NotNil(t, newAPIError)
|
||||
require.NotContains(t, logBuffer.String(), "[truncated")
|
||||
require.Contains(t, logBuffer.String(), body)
|
||||
}
|
||||
|
||||
func withDebugEnabled(t *testing.T, enabled bool) {
|
||||
t.Helper()
|
||||
|
||||
oldDebug := common.DebugEnabled
|
||||
common.DebugEnabled = enabled
|
||||
t.Cleanup(func() {
|
||||
common.DebugEnabled = oldDebug
|
||||
})
|
||||
}
|
||||
|
||||
@@ -37,6 +37,7 @@ func InitHttpClient() {
|
||||
transport := &http.Transport{
|
||||
MaxIdleConns: common.RelayMaxIdleConns,
|
||||
MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost,
|
||||
IdleConnTimeout: time.Duration(common.RelayIdleConnTimeout) * time.Second,
|
||||
ForceAttemptHTTP2: true,
|
||||
Proxy: http.ProxyFromEnvironment, // Support HTTP_PROXY, HTTPS_PROXY, NO_PROXY env vars
|
||||
}
|
||||
@@ -108,6 +109,7 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
|
||||
transport := &http.Transport{
|
||||
MaxIdleConns: common.RelayMaxIdleConns,
|
||||
MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost,
|
||||
IdleConnTimeout: time.Duration(common.RelayIdleConnTimeout) * time.Second,
|
||||
ForceAttemptHTTP2: true,
|
||||
Proxy: http.ProxyURL(parsedURL),
|
||||
}
|
||||
@@ -147,6 +149,7 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
|
||||
transport := &http.Transport{
|
||||
MaxIdleConns: common.RelayMaxIdleConns,
|
||||
MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost,
|
||||
IdleConnTimeout: time.Duration(common.RelayIdleConnTimeout) * time.Second,
|
||||
ForceAttemptHTTP2: true,
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.Dial(network, addr)
|
||||
|
||||
+406
-321
@@ -1,398 +1,483 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/setting"
|
||||
pancake "github.com/waffo-com/waffo-pancake-sdk-go"
|
||||
)
|
||||
|
||||
const (
|
||||
waffoPancakeAuthBaseURL = "https://waffo-pancake-auth-service.vercel.app"
|
||||
waffoPancakeCheckoutPath = "/v1/actions/checkout/create-session"
|
||||
waffoPancakeDefaultTolerance = 5 * time.Minute
|
||||
)
|
||||
|
||||
// WaffoPancakePriceSnapshot is the per-session price override sent with checkout.
|
||||
type WaffoPancakePriceSnapshot struct {
|
||||
Amount string `json:"amount"`
|
||||
TaxIncluded bool `json:"taxIncluded"`
|
||||
TaxCategory string `json:"taxCategory"`
|
||||
Amount string
|
||||
TaxCategory string
|
||||
}
|
||||
|
||||
// WaffoPancakeCreateSessionParams is the input to CreateWaffoPancakeCheckoutSession.
|
||||
// BuyerIdentity must be stable per user (see WaffoPancakeBuyerIdentityFromUserID).
|
||||
// OrderMerchantExternalID = our trade_no; Pancake echoes it back in webhooks.
|
||||
type WaffoPancakeCreateSessionParams struct {
|
||||
StoreID string `json:"storeId"`
|
||||
ProductID string `json:"productId"`
|
||||
ProductType string `json:"productType"`
|
||||
Currency string `json:"currency"`
|
||||
PriceSnapshot *WaffoPancakePriceSnapshot `json:"priceSnapshot,omitempty"`
|
||||
BuyerEmail string `json:"buyerEmail,omitempty"`
|
||||
SuccessURL string `json:"successUrl,omitempty"`
|
||||
ExpiresInSeconds *int `json:"expiresInSeconds,omitempty"`
|
||||
ProductID string
|
||||
BuyerIdentity string
|
||||
PriceSnapshot *WaffoPancakePriceSnapshot
|
||||
BuyerEmail string
|
||||
ExpiresInSeconds *int
|
||||
OrderMerchantExternalID string
|
||||
}
|
||||
|
||||
// WaffoPancakeCheckoutSession is the response of CreateWaffoPancakeCheckoutSession.
|
||||
// CheckoutURL already carries the `#token=...` fragment; Token / TokenExpiresAt
|
||||
// are exposed separately for self-service flows driven from new-api's own UI.
|
||||
type WaffoPancakeCheckoutSession struct {
|
||||
SessionID string `json:"sessionId"`
|
||||
CheckoutURL string `json:"checkoutUrl"`
|
||||
ExpiresAt string `json:"expiresAt"`
|
||||
OrderID string `json:"orderId"`
|
||||
SessionID string
|
||||
CheckoutURL string
|
||||
ExpiresAt string
|
||||
OrderID string
|
||||
Token string
|
||||
TokenExpiresAt string
|
||||
}
|
||||
|
||||
type waffoPancakeAPIError struct {
|
||||
Message string `json:"message"`
|
||||
Layer string `json:"layer"`
|
||||
// WaffoPancakeWebhookEvent mirrors the SDK's WebhookEvent shape using plain
|
||||
// strings so controllers don't have to import the SDK package.
|
||||
type WaffoPancakeWebhookEvent struct {
|
||||
ID string
|
||||
Timestamp string
|
||||
EventType string
|
||||
EventID string
|
||||
StoreID string
|
||||
Mode string
|
||||
Data WaffoPancakeWebhookData
|
||||
}
|
||||
|
||||
type waffoPancakeCreateSessionResponse struct {
|
||||
Data *WaffoPancakeCheckoutSession `json:"data"`
|
||||
Errors []waffoPancakeAPIError `json:"errors"`
|
||||
type WaffoPancakeWebhookData struct {
|
||||
// OrderID = Pancake ORD_* (logs); OrderMerchantExternalID = our trade_no (lookup).
|
||||
OrderID string
|
||||
OrderMerchantExternalID string
|
||||
BuyerEmail string
|
||||
Currency string
|
||||
Amount string
|
||||
TaxAmount string
|
||||
ProductName string
|
||||
MerchantProvidedBuyerIdentity string
|
||||
}
|
||||
|
||||
type waffoPancakeWebhookData struct {
|
||||
ID string `json:"id"`
|
||||
OrderID string `json:"orderId"`
|
||||
BuyerEmail string `json:"buyerEmail"`
|
||||
Currency string `json:"currency"`
|
||||
Amount dto.StringValue `json:"amount"`
|
||||
TaxAmount dto.StringValue `json:"taxAmount"`
|
||||
ProductName string `json:"productName"`
|
||||
}
|
||||
|
||||
type waffoPancakeWebhookEvent struct {
|
||||
ID string `json:"id"`
|
||||
Timestamp string `json:"timestamp"`
|
||||
EventType string `json:"eventType"`
|
||||
EventID string `json:"eventId"`
|
||||
StoreID string `json:"storeId"`
|
||||
Mode string `json:"mode"`
|
||||
Data waffoPancakeWebhookData `json:"data"`
|
||||
}
|
||||
|
||||
func (e *waffoPancakeWebhookEvent) NormalizedEventType() string {
|
||||
// NormalizedEventType returns the event type or empty string for a nil event.
|
||||
func (e *WaffoPancakeWebhookEvent) NormalizedEventType() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
return e.EventType
|
||||
}
|
||||
|
||||
// newWaffoPancakeClient builds an SDK client from persisted settings. The
|
||||
// runtime checkout / webhook paths use this; configuration endpoints use
|
||||
// newWaffoPancakeClientFromCreds so the operator can verify typed-but-not-
|
||||
// yet-saved credentials.
|
||||
func newWaffoPancakeClient() (*pancake.Client, error) {
|
||||
return pancake.New(pancake.Config{
|
||||
MerchantID: setting.WaffoPancakeMerchantID,
|
||||
PrivateKey: setting.WaffoPancakePrivateKey,
|
||||
})
|
||||
}
|
||||
|
||||
func newWaffoPancakeClientFromCreds(merchantID, privateKey string) (*pancake.Client, error) {
|
||||
if strings.TrimSpace(merchantID) == "" || strings.TrimSpace(privateKey) == "" {
|
||||
return nil, fmt.Errorf("merchant id and private key are required")
|
||||
}
|
||||
return pancake.New(pancake.Config{
|
||||
MerchantID: merchantID,
|
||||
PrivateKey: privateKey,
|
||||
})
|
||||
}
|
||||
|
||||
// CreateWaffoPancakeCheckoutSession creates an Authenticated-mode checkout
|
||||
// session: the order is bound to BuyerIdentity (stable per user) so it stays
|
||||
// attributable even if the buyer edits the email on Waffo's checkout form.
|
||||
func CreateWaffoPancakeCheckoutSession(ctx context.Context, params *WaffoPancakeCreateSessionParams) (*WaffoPancakeCheckoutSession, error) {
|
||||
if params == nil {
|
||||
return nil, fmt.Errorf("missing checkout params")
|
||||
}
|
||||
|
||||
body, err := common.Marshal(params)
|
||||
if strings.TrimSpace(params.BuyerIdentity) == "" {
|
||||
return nil, fmt.Errorf("missing buyer identity")
|
||||
}
|
||||
if strings.TrimSpace(params.OrderMerchantExternalID) == "" {
|
||||
return nil, fmt.Errorf("missing order merchant external id")
|
||||
}
|
||||
client, err := newWaffoPancakeClient()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal Waffo Pancake checkout payload: %w", err)
|
||||
return nil, fmt.Errorf("build Waffo Pancake client: %w", err)
|
||||
}
|
||||
|
||||
privateKey, err := normalizeRSAPrivateKey(setting.WaffoPancakePrivateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
sdkParams := pancake.AuthenticatedCheckoutParams{
|
||||
CreateCheckoutSessionParams: pancake.CreateCheckoutSessionParams{
|
||||
ProductID: params.ProductID,
|
||||
Currency: "USD",
|
||||
BuyerEmail: optionalString(params.BuyerEmail),
|
||||
ExpiresInSeconds: params.ExpiresInSeconds,
|
||||
OrderMerchantExternalID: optionalString(params.OrderMerchantExternalID),
|
||||
},
|
||||
BuyerIdentity: params.BuyerIdentity,
|
||||
}
|
||||
|
||||
timestamp := strconv.FormatInt(time.Now().Unix(), 10)
|
||||
signature, err := signWaffoPancakeRequest(http.MethodPost, waffoPancakeCheckoutPath, timestamp, string(body), privateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, waffoPancakeAuthBaseURL+waffoPancakeCheckoutPath, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build Waffo Pancake checkout request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("X-Merchant-Id", setting.WaffoPancakeMerchantID)
|
||||
req.Header.Set("X-Timestamp", timestamp)
|
||||
req.Header.Set("X-Signature", signature)
|
||||
if setting.WaffoPancakeSandbox {
|
||||
req.Header.Set("X-Environment", "test")
|
||||
} else {
|
||||
req.Header.Set("X-Environment", "prod")
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request Waffo Pancake checkout session: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read Waffo Pancake checkout response: %w", err)
|
||||
}
|
||||
|
||||
var result waffoPancakeCreateSessionResponse
|
||||
if err := common.Unmarshal(responseBody, &result); err != nil {
|
||||
return nil, fmt.Errorf("decode Waffo Pancake checkout response: %w", err)
|
||||
}
|
||||
if resp.StatusCode >= http.StatusBadRequest {
|
||||
if len(result.Errors) > 0 {
|
||||
return nil, fmt.Errorf("Waffo Pancake error (%d): %s", resp.StatusCode, result.Errors[0].Message)
|
||||
if params.PriceSnapshot != nil {
|
||||
sdkParams.PriceSnapshot = &pancake.PriceInfo{
|
||||
Amount: params.PriceSnapshot.Amount,
|
||||
TaxCategory: pancake.TaxCategory(params.PriceSnapshot.TaxCategory),
|
||||
}
|
||||
return nil, fmt.Errorf("Waffo Pancake checkout request failed with status %d", resp.StatusCode)
|
||||
}
|
||||
if len(result.Errors) > 0 {
|
||||
return nil, fmt.Errorf("Waffo Pancake error: %s", result.Errors[0].Message)
|
||||
|
||||
session, err := client.Checkout.Authenticated.Create(ctx, sdkParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if result.Data == nil || result.Data.CheckoutURL == "" || strings.TrimSpace(result.Data.SessionID) == "" {
|
||||
if session == nil || strings.TrimSpace(session.CheckoutURL) == "" || strings.TrimSpace(session.SessionID) == "" {
|
||||
return nil, fmt.Errorf("Waffo Pancake returned empty checkout session")
|
||||
}
|
||||
return result.Data, nil
|
||||
return &WaffoPancakeCheckoutSession{
|
||||
SessionID: session.SessionID,
|
||||
CheckoutURL: session.CheckoutURL,
|
||||
ExpiresAt: session.ExpiresAt,
|
||||
Token: session.Token,
|
||||
TokenExpiresAt: session.TokenExpiresAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func VerifyConfiguredWaffoPancakeWebhook(payload string, signatureHeader string) (*waffoPancakeWebhookEvent, error) {
|
||||
environment := resolveWaffoPancakeWebhookEnvironment(payload)
|
||||
return verifyWaffoPancakeWebhook(payload, signatureHeader, environment)
|
||||
func optionalString(s string) *string {
|
||||
if strings.TrimSpace(s) == "" {
|
||||
return nil
|
||||
}
|
||||
v := s
|
||||
return &v
|
||||
}
|
||||
|
||||
func ResolveWaffoPancakeTradeNo(event *waffoPancakeWebhookEvent) (string, error) {
|
||||
// WaffoPancakeBuyerIdentityFromUserID renders the canonical buyer identity
|
||||
// for checkout. Webhook handlers compare against the value rendered here to
|
||||
// reject identity mismatches, so both call sites must use this function.
|
||||
func WaffoPancakeBuyerIdentityFromUserID(userID int) string {
|
||||
return fmt.Sprintf("new-api-user-%d", userID)
|
||||
}
|
||||
|
||||
// VerifyConfiguredWaffoPancakeWebhook verifies the signature header. The SDK
|
||||
// picks the matching test / prod public key from the payload's `mode` field.
|
||||
func VerifyConfiguredWaffoPancakeWebhook(payload string, signatureHeader string) (*WaffoPancakeWebhookEvent, error) {
|
||||
evt, err := pancake.VerifyWebhookTyped[pancake.WebhookEventData](payload, signatureHeader, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
identity := ""
|
||||
if evt.Data.MerchantProvidedBuyerIdentity != nil {
|
||||
identity = *evt.Data.MerchantProvidedBuyerIdentity
|
||||
}
|
||||
externalID := ""
|
||||
if evt.Data.OrderMerchantExternalID != nil {
|
||||
externalID = *evt.Data.OrderMerchantExternalID
|
||||
}
|
||||
return &WaffoPancakeWebhookEvent{
|
||||
ID: evt.ID,
|
||||
Timestamp: evt.Timestamp,
|
||||
EventType: evt.EventType,
|
||||
EventID: evt.EventID,
|
||||
StoreID: evt.StoreID,
|
||||
Mode: string(evt.Mode),
|
||||
Data: WaffoPancakeWebhookData{
|
||||
OrderID: evt.Data.OrderID,
|
||||
OrderMerchantExternalID: externalID,
|
||||
BuyerEmail: evt.Data.BuyerEmail,
|
||||
Currency: evt.Data.Currency,
|
||||
Amount: evt.Data.Amount,
|
||||
TaxAmount: evt.Data.TaxAmount,
|
||||
ProductName: evt.Data.ProductName,
|
||||
MerchantProvidedBuyerIdentity: identity,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ResolveWaffoPancakeTradeNo maps a verified webhook event to a local TopUp
|
||||
// trade_no via OrderMerchantExternalID, and rejects buyer-identity mismatches.
|
||||
func ResolveWaffoPancakeTradeNo(event *WaffoPancakeWebhookEvent) (string, error) {
|
||||
if event == nil {
|
||||
return "", fmt.Errorf("missing webhook event")
|
||||
}
|
||||
|
||||
if tradeNo := strings.TrimSpace(event.Data.OrderID); tradeNo != "" {
|
||||
topUp := model.GetTopUpByTradeNo(tradeNo)
|
||||
if topUp != nil && topUp.PaymentMethod == model.PaymentMethodWaffoPancake {
|
||||
return tradeNo, nil
|
||||
}
|
||||
return "", fmt.Errorf("waffo pancake order not found for webhook orderId=%s", tradeNo)
|
||||
tradeNo := strings.TrimSpace(event.Data.OrderMerchantExternalID)
|
||||
if tradeNo == "" {
|
||||
return "", fmt.Errorf("missing webhook orderMerchantExternalId")
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("missing webhook orderId")
|
||||
topUp := model.GetTopUpByTradeNo(tradeNo)
|
||||
if topUp == nil || topUp.PaymentProvider != model.PaymentProviderWaffoPancake {
|
||||
return "", fmt.Errorf("waffo pancake order not found for tradeNo=%s", tradeNo)
|
||||
}
|
||||
expectedIdentity := WaffoPancakeBuyerIdentityFromUserID(topUp.UserId)
|
||||
actualIdentity := strings.TrimSpace(event.Data.MerchantProvidedBuyerIdentity)
|
||||
if actualIdentity != expectedIdentity {
|
||||
return "", fmt.Errorf(
|
||||
"waffo pancake buyer identity mismatch for tradeNo=%s: expected=%q actual=%q",
|
||||
tradeNo,
|
||||
expectedIdentity,
|
||||
actualIdentity,
|
||||
)
|
||||
}
|
||||
return tradeNo, nil
|
||||
}
|
||||
|
||||
func normalizeRSAPrivateKey(raw string) (string, error) {
|
||||
return normalizePEMKey(raw, "PRIVATE KEY", "RSA PRIVATE KEY")
|
||||
// ResolveWaffoPancakeSubscriptionTradeNo is the SubscriptionOrder counterpart
|
||||
// of ResolveWaffoPancakeTradeNo.
|
||||
func ResolveWaffoPancakeSubscriptionTradeNo(event *WaffoPancakeWebhookEvent) (string, error) {
|
||||
if event == nil {
|
||||
return "", fmt.Errorf("missing webhook event")
|
||||
}
|
||||
tradeNo := strings.TrimSpace(event.Data.OrderMerchantExternalID)
|
||||
if tradeNo == "" {
|
||||
return "", fmt.Errorf("missing webhook orderMerchantExternalId")
|
||||
}
|
||||
order := model.GetSubscriptionOrderByTradeNo(tradeNo)
|
||||
if order == nil || order.PaymentProvider != model.PaymentProviderWaffoPancake {
|
||||
return "", fmt.Errorf("waffo pancake subscription order not found for tradeNo=%s", tradeNo)
|
||||
}
|
||||
expectedIdentity := WaffoPancakeBuyerIdentityFromUserID(order.UserId)
|
||||
actualIdentity := strings.TrimSpace(event.Data.MerchantProvidedBuyerIdentity)
|
||||
if actualIdentity != expectedIdentity {
|
||||
return "", fmt.Errorf(
|
||||
"waffo pancake buyer identity mismatch for subscription tradeNo=%s: expected=%q actual=%q",
|
||||
tradeNo,
|
||||
expectedIdentity,
|
||||
actualIdentity,
|
||||
)
|
||||
}
|
||||
return tradeNo, nil
|
||||
}
|
||||
|
||||
func normalizeRSAPublicKey(raw string) (string, error) {
|
||||
return normalizePEMKey(raw, "PUBLIC KEY", "RSA PUBLIC KEY")
|
||||
}
|
||||
// Deterministic default names for "+ Create": stable bodies mean stable
|
||||
// X-Idempotency-Key, which lets Pancake dedupe retries server-side.
|
||||
const (
|
||||
defaultWaffoPancakeStoreName = "new-api-store"
|
||||
defaultWaffoPancakeProductName = "new-api-charge-product"
|
||||
)
|
||||
|
||||
func normalizePEMKey(raw string, pkcs8Type string, pkcs1Type string) (string, error) {
|
||||
if strings.TrimSpace(raw) == "" {
|
||||
return "", fmt.Errorf("%s is empty", strings.ToLower(pkcs8Type))
|
||||
}
|
||||
|
||||
normalized := strings.TrimSpace(strings.ReplaceAll(raw, `\n`, "\n"))
|
||||
if strings.Contains(normalized, "BEGIN ") {
|
||||
block, _ := pem.Decode([]byte(normalized))
|
||||
if block == nil {
|
||||
return "", fmt.Errorf("invalid PEM encoded %s", strings.ToLower(pkcs8Type))
|
||||
}
|
||||
return string(pem.EncodeToMemory(block)), nil
|
||||
}
|
||||
|
||||
der, err := base64.StdEncoding.DecodeString(strings.ReplaceAll(normalized, "\n", ""))
|
||||
// CreateWaffoPancakePrimaryStore creates a Pancake Store using in-flight
|
||||
// (not-yet-persisted) credentials and returns the new store ID.
|
||||
func CreateWaffoPancakePrimaryStore(ctx context.Context, merchantID, privateKey string) (string, error) {
|
||||
client, err := newWaffoPancakeClientFromCreds(merchantID, privateKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid base64 encoded %s: %w", strings.ToLower(pkcs8Type), err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
pemType := pkcs8Type
|
||||
if pkcs8Type == "PRIVATE KEY" {
|
||||
if _, err := x509.ParsePKCS8PrivateKey(der); err != nil {
|
||||
if _, err := x509.ParsePKCS1PrivateKey(der); err == nil {
|
||||
pemType = pkcs1Type
|
||||
} else {
|
||||
return "", fmt.Errorf("invalid RSA private key")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if _, err := x509.ParsePKIXPublicKey(der); err != nil {
|
||||
if _, err := x509.ParsePKCS1PublicKey(der); err == nil {
|
||||
pemType = pkcs1Type
|
||||
} else {
|
||||
return "", fmt.Errorf("invalid RSA public key")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return string(pem.EncodeToMemory(&pem.Block{Type: pemType, Bytes: der})), nil
|
||||
}
|
||||
|
||||
func signWaffoPancakeRequest(method string, path string, timestamp string, body string, privateKeyPEM string) (string, error) {
|
||||
block, _ := pem.Decode([]byte(privateKeyPEM))
|
||||
if block == nil {
|
||||
return "", fmt.Errorf("invalid RSA private key PEM")
|
||||
}
|
||||
|
||||
var privateKey *rsa.PrivateKey
|
||||
switch block.Type {
|
||||
case "PRIVATE KEY":
|
||||
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse PKCS#8 private key: %w", err)
|
||||
}
|
||||
parsed, ok := key.(*rsa.PrivateKey)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("private key is not RSA")
|
||||
}
|
||||
privateKey = parsed
|
||||
case "RSA PRIVATE KEY":
|
||||
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse PKCS#1 private key: %w", err)
|
||||
}
|
||||
privateKey = key
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported private key type: %s", block.Type)
|
||||
}
|
||||
|
||||
canonicalRequest := buildWaffoPancakeCanonicalRequest(method, path, timestamp, body)
|
||||
digest := sha256.Sum256([]byte(canonicalRequest))
|
||||
signature, err := rsa.SignPKCS1v15(nil, privateKey, crypto.SHA256, digest[:])
|
||||
storeRes, err := client.Stores.Create(ctx, pancake.CreateStoreParams{
|
||||
Name: defaultWaffoPancakeStoreName,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("sign Waffo Pancake request: %w", err)
|
||||
return "", fmt.Errorf("create Waffo Pancake store: %w", err)
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(signature), nil
|
||||
return storeRes.Store.ID, nil
|
||||
}
|
||||
|
||||
func buildWaffoPancakeCanonicalRequest(method string, path string, timestamp string, body string) string {
|
||||
bodyHash := sha256.Sum256([]byte(body))
|
||||
return fmt.Sprintf(
|
||||
"%s\n%s\n%s\n%s",
|
||||
strings.ToUpper(method),
|
||||
path,
|
||||
timestamp,
|
||||
base64.StdEncoding.EncodeToString(bodyHash[:]),
|
||||
)
|
||||
}
|
||||
|
||||
func verifyWaffoPancakeWebhook(payload string, signatureHeader string, environment string) (*waffoPancakeWebhookEvent, error) {
|
||||
if signatureHeader == "" {
|
||||
return nil, fmt.Errorf("missing X-Waffo-Signature header")
|
||||
// CreateWaffoPancakeProductForPlan mints (and publishes) a Pancake
|
||||
// OnetimeProduct priced at `amount` USD, used as a subscription plan's
|
||||
// SubscriptionPlan.WaffoPancakeProductId.
|
||||
//
|
||||
// OnetimeProduct (not SubscriptionProduct) because new-api has no renewal-
|
||||
// event handling; Pancake auto-renewing without new-api extending user
|
||||
// access would be a UX divergence. Revisit if renewal handling is added.
|
||||
func CreateWaffoPancakeProductForPlan(ctx context.Context, merchantID, privateKey, storeID, name, amount, returnURL string) (string, error) {
|
||||
storeID = strings.TrimSpace(storeID)
|
||||
if storeID == "" {
|
||||
return "", fmt.Errorf("store id is required to create a product")
|
||||
}
|
||||
|
||||
timestampPart, signaturePart := parseWaffoPancakeSignatureHeader(signatureHeader)
|
||||
if timestampPart == "" || signaturePart == "" {
|
||||
return nil, fmt.Errorf("malformed X-Waffo-Signature header")
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" {
|
||||
return "", fmt.Errorf("plan name is required")
|
||||
}
|
||||
|
||||
timestampMs, err := strconv.ParseInt(timestampPart, 10, 64)
|
||||
amount = strings.TrimSpace(amount)
|
||||
if amount == "" {
|
||||
return "", fmt.Errorf("plan price is required")
|
||||
}
|
||||
client, err := newWaffoPancakeClientFromCreds(merchantID, privateKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid timestamp in X-Waffo-Signature header")
|
||||
return "", err
|
||||
}
|
||||
if math.Abs(float64(time.Now().UnixMilli()-timestampMs)) > float64(waffoPancakeDefaultTolerance.Milliseconds()) {
|
||||
return nil, fmt.Errorf("webhook timestamp outside tolerance window")
|
||||
}
|
||||
|
||||
signatureInput := fmt.Sprintf("%s.%s", timestampPart, payload)
|
||||
if err := verifyWaffoPancakeWebhookWithKey(signatureInput, signaturePart, resolveWaffoPancakeWebhookPublicKey(environment)); err != nil {
|
||||
return nil, fmt.Errorf("invalid webhook signature")
|
||||
}
|
||||
|
||||
var event waffoPancakeWebhookEvent
|
||||
if err := common.Unmarshal([]byte(payload), &event); err != nil {
|
||||
return nil, fmt.Errorf("parse Waffo Pancake webhook payload: %w", err)
|
||||
}
|
||||
return &event, nil
|
||||
}
|
||||
|
||||
func parseWaffoPancakeSignatureHeader(header string) (string, string) {
|
||||
var timestampPart string
|
||||
var signaturePart string
|
||||
for _, pair := range strings.Split(header, ",") {
|
||||
key, value, found := strings.Cut(strings.TrimSpace(pair), "=")
|
||||
if !found {
|
||||
continue
|
||||
}
|
||||
switch key {
|
||||
case "t":
|
||||
timestampPart = value
|
||||
case "v1":
|
||||
signaturePart = value
|
||||
}
|
||||
}
|
||||
return timestampPart, signaturePart
|
||||
}
|
||||
|
||||
func resolveWaffoPancakeWebhookEnvironment(payload string) string {
|
||||
var envelope struct {
|
||||
Mode string `json:"mode"`
|
||||
}
|
||||
if err := common.Unmarshal([]byte(payload), &envelope); err != nil {
|
||||
if setting.WaffoPancakeSandbox {
|
||||
return "test"
|
||||
}
|
||||
return "prod"
|
||||
}
|
||||
|
||||
switch strings.ToLower(strings.TrimSpace(envelope.Mode)) {
|
||||
case "test":
|
||||
return "test"
|
||||
case "prod":
|
||||
return "prod"
|
||||
default:
|
||||
if setting.WaffoPancakeSandbox {
|
||||
return "test"
|
||||
}
|
||||
return "prod"
|
||||
}
|
||||
}
|
||||
|
||||
func resolveWaffoPancakeWebhookPublicKey(environment string) string {
|
||||
if environment == "prod" {
|
||||
return strings.TrimSpace(setting.WaffoPancakeWebhookPublicKey)
|
||||
}
|
||||
return strings.TrimSpace(setting.WaffoPancakeWebhookTestKey)
|
||||
}
|
||||
|
||||
func verifyWaffoPancakeWebhookWithKey(signatureInput string, signaturePart string, rawPublicKey string) error {
|
||||
publicKeyPEM, err := normalizeRSAPublicKey(rawPublicKey)
|
||||
prodRes, err := client.OnetimeProducts.Create(ctx, pancake.CreateOnetimeProductParams{
|
||||
StoreID: storeID,
|
||||
Name: name,
|
||||
Prices: pancake.Prices{
|
||||
"USD": {
|
||||
Amount: amount,
|
||||
TaxCategory: pancake.TaxCategory("saas"),
|
||||
},
|
||||
},
|
||||
SuccessURL: optionalString(strings.TrimSpace(returnURL)),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
return "", fmt.Errorf("create Waffo Pancake plan product: %w", err)
|
||||
}
|
||||
|
||||
block, _ := pem.Decode([]byte(publicKeyPEM))
|
||||
if block == nil {
|
||||
return fmt.Errorf("invalid RSA public key PEM")
|
||||
productID := prodRes.Product.ID
|
||||
if _, err := client.OnetimeProducts.Publish(ctx, pancake.PublishOnetimeProductParams{ID: productID}); err != nil {
|
||||
return "", fmt.Errorf("publish Waffo Pancake plan product: %w", err)
|
||||
}
|
||||
return productID, nil
|
||||
}
|
||||
|
||||
var publicKey *rsa.PublicKey
|
||||
switch block.Type {
|
||||
case "PUBLIC KEY":
|
||||
key, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse PKIX public key: %w", err)
|
||||
}
|
||||
parsed, ok := key.(*rsa.PublicKey)
|
||||
if !ok {
|
||||
return fmt.Errorf("public key is not RSA")
|
||||
}
|
||||
publicKey = parsed
|
||||
case "RSA PUBLIC KEY":
|
||||
key, err := x509.ParsePKCS1PublicKey(block.Bytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse PKCS#1 public key: %w", err)
|
||||
}
|
||||
publicKey = key
|
||||
default:
|
||||
return fmt.Errorf("unsupported public key type: %s", block.Type)
|
||||
// CreateWaffoPancakePrimaryProduct mints (and publishes) the wallet-top-up
|
||||
// OnetimeProduct under storeID. Per-checkout price overrides via PriceSnapshot
|
||||
// are what make the "1.00" seed price irrelevant at runtime.
|
||||
func CreateWaffoPancakePrimaryProduct(ctx context.Context, merchantID, privateKey, storeID, returnURL string) (string, error) {
|
||||
storeID = strings.TrimSpace(storeID)
|
||||
if storeID == "" {
|
||||
return "", fmt.Errorf("store id is required to create a product")
|
||||
}
|
||||
|
||||
signature, err := base64.StdEncoding.DecodeString(signaturePart)
|
||||
client, err := newWaffoPancakeClientFromCreds(merchantID, privateKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decode webhook signature: %w", err)
|
||||
return "", err
|
||||
}
|
||||
prodRes, err := client.OnetimeProducts.Create(ctx, pancake.CreateOnetimeProductParams{
|
||||
StoreID: storeID,
|
||||
Name: defaultWaffoPancakeProductName,
|
||||
Prices: pancake.Prices{
|
||||
"USD": {
|
||||
Amount: "1.00", // overridden at checkout via PriceSnapshot
|
||||
TaxCategory: pancake.TaxCategory("saas"),
|
||||
},
|
||||
},
|
||||
SuccessURL: optionalString(strings.TrimSpace(returnURL)),
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create Waffo Pancake product: %w", err)
|
||||
}
|
||||
productID := prodRes.Product.ID
|
||||
if _, err := client.OnetimeProducts.Publish(ctx, pancake.PublishOnetimeProductParams{ID: productID}); err != nil {
|
||||
return "", fmt.Errorf("publish Waffo Pancake product: %w", err)
|
||||
}
|
||||
return productID, nil
|
||||
}
|
||||
|
||||
digest := sha256.Sum256([]byte(signatureInput))
|
||||
if err := rsa.VerifyPKCS1v15(publicKey, crypto.SHA256, digest[:], signature); err != nil {
|
||||
return fmt.Errorf("verify webhook signature: %w", err)
|
||||
// WaffoPancakePairResult is the response of CreateWaffoPancakePrimaryPair.
|
||||
// When OrphanStore is true the store was created but the product wasn't,
|
||||
// so the caller can surface a partial-failure message with StoreID.
|
||||
type WaffoPancakePairResult struct {
|
||||
StoreID string
|
||||
StoreName string
|
||||
ProductID string
|
||||
ProductName string
|
||||
OrphanStore bool
|
||||
}
|
||||
|
||||
// CreateWaffoPancakePrimaryPair mints a Store + OnetimeProduct in one
|
||||
// round-trip — the canonical "+ Create" entry point. Nothing is persisted
|
||||
// to settings; the operator's final Save commits the chosen IDs.
|
||||
func CreateWaffoPancakePrimaryPair(ctx context.Context, merchantID, privateKey, returnURL string) (*WaffoPancakePairResult, error) {
|
||||
storeID, err := CreateWaffoPancakePrimaryStore(ctx, merchantID, privateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
productID, err := CreateWaffoPancakePrimaryProduct(ctx, merchantID, privateKey, storeID, returnURL)
|
||||
if err != nil {
|
||||
return &WaffoPancakePairResult{
|
||||
StoreID: storeID,
|
||||
StoreName: defaultWaffoPancakeStoreName,
|
||||
OrphanStore: true,
|
||||
}, fmt.Errorf("store created at %s but product creation failed: %w", storeID, err)
|
||||
}
|
||||
return &WaffoPancakePairResult{
|
||||
StoreID: storeID,
|
||||
StoreName: defaultWaffoPancakeStoreName,
|
||||
ProductID: productID,
|
||||
ProductName: defaultWaffoPancakeProductName,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SaveWaffoPancakeConfig persists the operator-controlled fields atomically
|
||||
// at the end of the configuration flow via model.UpdateOptionsBulk (single
|
||||
// DB transaction). A blank privateKey is treated as "keep current"
|
||||
// (Stripe-style API-secret UX) and is omitted from the bulk payload.
|
||||
func SaveWaffoPancakeConfig(ctx context.Context, merchantID, privateKey, returnURL, storeID, productID string) error {
|
||||
merchantID = strings.TrimSpace(merchantID)
|
||||
storeID = strings.TrimSpace(storeID)
|
||||
productID = strings.TrimSpace(productID)
|
||||
if merchantID == "" || storeID == "" || productID == "" {
|
||||
return fmt.Errorf("merchant id, store id, and product id are required to save")
|
||||
}
|
||||
values := map[string]string{
|
||||
"WaffoPancakeMerchantID": merchantID,
|
||||
"WaffoPancakeReturnURL": strings.TrimSpace(returnURL),
|
||||
"WaffoPancakeStoreID": storeID,
|
||||
"WaffoPancakeProductID": productID,
|
||||
}
|
||||
if pk := strings.TrimSpace(privateKey); pk != "" {
|
||||
values["WaffoPancakePrivateKey"] = pk
|
||||
}
|
||||
if err := model.UpdateOptionsBulk(values); err != nil {
|
||||
return fmt.Errorf("persist Waffo Pancake config: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type WaffoPancakeCatalogProduct struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// WaffoPancakeCatalogStore nests its OnetimeProducts so the UI can render a
|
||||
// dependent store→product select without a second round-trip.
|
||||
type WaffoPancakeCatalogStore struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Status string `json:"status"`
|
||||
ProdEnabled bool `json:"prodEnabled"`
|
||||
OnetimeProducts []WaffoPancakeCatalogProduct `json:"onetimeProducts"`
|
||||
}
|
||||
|
||||
type WaffoPancakeCatalog struct {
|
||||
Stores []WaffoPancakeCatalogStore `json:"stores"`
|
||||
}
|
||||
|
||||
// ListWaffoPancakeCatalog queries Pancake's GraphQL `stores` for the
|
||||
// merchant's stores + onetime products. A successful call also proves
|
||||
// the supplied credentials authenticate (doubles as a credential probe).
|
||||
func ListWaffoPancakeCatalog(ctx context.Context, merchantID, privateKey string) (*WaffoPancakeCatalog, error) {
|
||||
client, err := newWaffoPancakeClientFromCreds(merchantID, privateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
type queryShape struct {
|
||||
Stores []WaffoPancakeCatalogStore `json:"stores"`
|
||||
}
|
||||
// `limit: 100` because the API returns a single store when limit is
|
||||
// omitted, even for multi-store merchants. Bump to paginated fetches
|
||||
// (via `offset`) if real catalogs ever cross the cap.
|
||||
resp, err := pancake.GraphQLQuery[queryShape](ctx, client, pancake.GraphQLParams{
|
||||
Query: `query {
|
||||
stores(limit: 100) {
|
||||
id
|
||||
name
|
||||
status
|
||||
prodEnabled
|
||||
onetimeProducts {
|
||||
id
|
||||
name
|
||||
status
|
||||
}
|
||||
}
|
||||
}`,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query Waffo Pancake catalog: %w", err)
|
||||
}
|
||||
if len(resp.Errors) > 0 {
|
||||
return nil, fmt.Errorf("waffo pancake catalog query returned %d errors: %s",
|
||||
len(resp.Errors), resp.Errors[0].Message)
|
||||
}
|
||||
// Drop non-active products. Operators should only see items they can
|
||||
// actually bind without later hitting "product unavailable" at checkout.
|
||||
stores := resp.Data.Stores
|
||||
for i := range stores {
|
||||
active := stores[i].OnetimeProducts[:0]
|
||||
for _, p := range stores[i].OnetimeProducts {
|
||||
if strings.EqualFold(strings.TrimSpace(p.Status), "active") {
|
||||
active = append(active, p)
|
||||
}
|
||||
}
|
||||
stores[i].OnetimeProducts = active
|
||||
}
|
||||
return &WaffoPancakeCatalog{Stores: stores}, nil
|
||||
}
|
||||
|
||||
+200
-78
@@ -1,6 +1,7 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -8,7 +9,6 @@ import (
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/setting"
|
||||
"github.com/glebarez/sqlite"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
@@ -29,7 +29,7 @@ func setupWaffoPancakeTestDB(t *testing.T) *gorm.DB {
|
||||
model.DB = db
|
||||
model.LOG_DB = db
|
||||
|
||||
require.NoError(t, db.AutoMigrate(&model.User{}, &model.TopUp{}))
|
||||
require.NoError(t, db.AutoMigrate(&model.User{}, &model.TopUp{}, &model.SubscriptionOrder{}))
|
||||
|
||||
t.Cleanup(func() {
|
||||
sqlDB, err := db.DB()
|
||||
@@ -41,44 +41,101 @@ func setupWaffoPancakeTestDB(t *testing.T) *gorm.DB {
|
||||
return db
|
||||
}
|
||||
|
||||
func TestWaffoPancakeCreateSessionResponseParsesDocumentedPayload(t *testing.T) {
|
||||
var result waffoPancakeCreateSessionResponse
|
||||
err := common.Unmarshal([]byte(`{
|
||||
"data": {
|
||||
"sessionId": "cs_550e8400-e29b-41d4-a716-446655440000",
|
||||
"checkoutUrl": "https://checkout.waffo.ai/my-store-abc123/checkout/cs_550e8400-e29b-41d4-a716-446655440000",
|
||||
"expiresAt": "2026-01-22T10:30:00.000Z"
|
||||
}
|
||||
}`), &result)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result.Data)
|
||||
require.Equal(t, "cs_550e8400-e29b-41d4-a716-446655440000", result.Data.SessionID)
|
||||
require.Empty(t, result.Data.OrderID)
|
||||
func TestCreateWaffoPancakeCheckoutSession_RequiresOrderMerchantExternalID(t *testing.T) {
|
||||
session, err := CreateWaffoPancakeCheckoutSession(context.Background(), &WaffoPancakeCreateSessionParams{
|
||||
ProductID: "PROD_checkout_guard",
|
||||
BuyerIdentity: WaffoPancakeBuyerIdentityFromUserID(1),
|
||||
})
|
||||
|
||||
require.Error(t, err)
|
||||
require.Nil(t, session)
|
||||
require.Contains(t, err.Error(), "missing order merchant external id")
|
||||
}
|
||||
|
||||
func TestResolveWaffoPancakeTradeNo_UsesWebhookOrderIDWhenLocalOrderExists(t *testing.T) {
|
||||
db := setupWaffoPancakeTestDB(t)
|
||||
|
||||
topUp := &model.TopUp{
|
||||
UserId: 1,
|
||||
Amount: 10,
|
||||
Money: 29,
|
||||
TradeNo: "ORD_5dXBtmF2HLlHfbPNm0Wcnz",
|
||||
PaymentMethod: model.PaymentMethodWaffoPancake,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
UserId: 1,
|
||||
Amount: 10,
|
||||
Money: 29,
|
||||
TradeNo: "ORD_5dXBtmF2HLlHfbPNm0Wcnz",
|
||||
PaymentMethod: model.PaymentMethodWaffoPancake,
|
||||
PaymentProvider: model.PaymentProviderWaffoPancake,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
require.NoError(t, db.Create(topUp).Error)
|
||||
|
||||
tradeNo, err := ResolveWaffoPancakeTradeNo(&waffoPancakeWebhookEvent{
|
||||
Data: waffoPancakeWebhookData{
|
||||
OrderID: "ORD_5dXBtmF2HLlHfbPNm0Wcnz",
|
||||
tradeNo, err := ResolveWaffoPancakeTradeNo(&WaffoPancakeWebhookEvent{
|
||||
Data: WaffoPancakeWebhookData{
|
||||
OrderID: "ORD_internal_pancake_id",
|
||||
OrderMerchantExternalID: "ORD_5dXBtmF2HLlHfbPNm0Wcnz",
|
||||
MerchantProvidedBuyerIdentity: WaffoPancakeBuyerIdentityFromUserID(topUp.UserId),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "ORD_5dXBtmF2HLlHfbPNm0Wcnz", tradeNo)
|
||||
}
|
||||
|
||||
func TestResolveWaffoPancakeTradeNo_RejectsBuyerIdentityMismatch(t *testing.T) {
|
||||
db := setupWaffoPancakeTestDB(t)
|
||||
|
||||
topUp := &model.TopUp{
|
||||
UserId: 42,
|
||||
Amount: 10,
|
||||
Money: 29,
|
||||
TradeNo: "ORD_identity_mismatch_case",
|
||||
PaymentMethod: model.PaymentMethodWaffoPancake,
|
||||
PaymentProvider: model.PaymentProviderWaffoPancake,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
require.NoError(t, db.Create(topUp).Error)
|
||||
|
||||
// Webhook reports the right order but a different buyer — could be a
|
||||
// crossed-wires bug or a tampered payload. Either way: reject.
|
||||
tradeNo, err := ResolveWaffoPancakeTradeNo(&WaffoPancakeWebhookEvent{
|
||||
Data: WaffoPancakeWebhookData{
|
||||
OrderID: "ORD_internal_pancake_id",
|
||||
OrderMerchantExternalID: "ORD_identity_mismatch_case",
|
||||
MerchantProvidedBuyerIdentity: WaffoPancakeBuyerIdentityFromUserID(99), // wrong user
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Empty(t, tradeNo)
|
||||
require.Contains(t, err.Error(), "buyer identity mismatch")
|
||||
}
|
||||
|
||||
func TestResolveWaffoPancakeTradeNo_RejectsMissingBuyerIdentity(t *testing.T) {
|
||||
db := setupWaffoPancakeTestDB(t)
|
||||
|
||||
topUp := &model.TopUp{
|
||||
UserId: 7,
|
||||
Amount: 10,
|
||||
Money: 29,
|
||||
TradeNo: "ORD_missing_identity",
|
||||
PaymentMethod: model.PaymentMethodWaffoPancake,
|
||||
PaymentProvider: model.PaymentProviderWaffoPancake,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
require.NoError(t, db.Create(topUp).Error)
|
||||
|
||||
// An empty MerchantProvidedBuyerIdentity means the order was either created
|
||||
// via the (now-deprecated) anonymous flow or the field was stripped — also
|
||||
// reject so that we never credit anonymous orders to a specific user.
|
||||
tradeNo, err := ResolveWaffoPancakeTradeNo(&WaffoPancakeWebhookEvent{
|
||||
Data: WaffoPancakeWebhookData{
|
||||
OrderID: "ORD_internal_pancake_id",
|
||||
OrderMerchantExternalID: "ORD_missing_identity",
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Empty(t, tradeNo)
|
||||
require.Contains(t, err.Error(), "buyer identity mismatch")
|
||||
}
|
||||
|
||||
func TestResolveWaffoPancakeTradeNo_FailsWhenWebhookOrderIDIsUnknown(t *testing.T) {
|
||||
db := setupWaffoPancakeTestDB(t)
|
||||
|
||||
@@ -91,67 +148,132 @@ func TestResolveWaffoPancakeTradeNo_FailsWhenWebhookOrderIDIsUnknown(t *testing.
|
||||
require.NoError(t, db.Create(user).Error)
|
||||
|
||||
topUp := &model.TopUp{
|
||||
UserId: user.Id,
|
||||
Amount: 10,
|
||||
Money: 29,
|
||||
TradeNo: "WAFFO_PANCAKE-42-123456-abc123",
|
||||
PaymentMethod: model.PaymentMethodWaffoPancake,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
UserId: user.Id,
|
||||
Amount: 10,
|
||||
Money: 29,
|
||||
TradeNo: "WAFFO_PANCAKE-42-123456-abc123",
|
||||
PaymentMethod: model.PaymentMethodWaffoPancake,
|
||||
PaymentProvider: model.PaymentProviderWaffoPancake,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
require.NoError(t, db.Create(topUp).Error)
|
||||
|
||||
tradeNo, err := ResolveWaffoPancakeTradeNo(&waffoPancakeWebhookEvent{
|
||||
Data: waffoPancakeWebhookData{
|
||||
OrderID: "ORD_unknown",
|
||||
BuyerEmail: user.Email,
|
||||
Amount: "29.00",
|
||||
tradeNo, err := ResolveWaffoPancakeTradeNo(&WaffoPancakeWebhookEvent{
|
||||
Data: WaffoPancakeWebhookData{
|
||||
OrderID: "ORD_internal_pancake_id",
|
||||
OrderMerchantExternalID: "WAFFO_PANCAKE-unknown",
|
||||
BuyerEmail: user.Email,
|
||||
Amount: "29.00",
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Empty(t, tradeNo)
|
||||
}
|
||||
|
||||
func TestResolveWaffoPancakeWebhookEnvironment(t *testing.T) {
|
||||
originalSandbox := setting.WaffoPancakeSandbox
|
||||
t.Cleanup(func() {
|
||||
setting.WaffoPancakeSandbox = originalSandbox
|
||||
// Parity tests for ResolveWaffoPancakeSubscriptionTradeNo — same four cases
|
||||
// as the TopUp resolver above, exercised against SubscriptionOrder records.
|
||||
// Drift between the two webhook flows is a real risk because they share
|
||||
// the same buyer-identity defence-in-depth pattern.
|
||||
|
||||
func TestResolveWaffoPancakeSubscriptionTradeNo_UsesWebhookOrderIDWhenLocalOrderExists(t *testing.T) {
|
||||
db := setupWaffoPancakeTestDB(t)
|
||||
|
||||
order := &model.SubscriptionOrder{
|
||||
UserId: 1,
|
||||
PlanId: 5,
|
||||
Money: 29,
|
||||
TradeNo: "WAFFO_PANCAKE_SUB-1-1700000000-abc123",
|
||||
PaymentMethod: model.PaymentMethodWaffoPancake,
|
||||
PaymentProvider: model.PaymentProviderWaffoPancake,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
require.NoError(t, db.Create(order).Error)
|
||||
|
||||
tradeNo, err := ResolveWaffoPancakeSubscriptionTradeNo(&WaffoPancakeWebhookEvent{
|
||||
Data: WaffoPancakeWebhookData{
|
||||
OrderID: "ORD_internal_pancake_id",
|
||||
OrderMerchantExternalID: "WAFFO_PANCAKE_SUB-1-1700000000-abc123",
|
||||
MerchantProvidedBuyerIdentity: WaffoPancakeBuyerIdentityFromUserID(order.UserId),
|
||||
},
|
||||
})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
payload string
|
||||
expected string
|
||||
sandbox bool
|
||||
}{
|
||||
{
|
||||
name: "test mode",
|
||||
payload: `{"mode":"test"}`,
|
||||
expected: "test",
|
||||
},
|
||||
{
|
||||
name: "prod mode",
|
||||
payload: `{"mode":"prod"}`,
|
||||
expected: "prod",
|
||||
},
|
||||
{
|
||||
name: "missing mode falls back to sandbox",
|
||||
payload: `{}`,
|
||||
expected: "test",
|
||||
sandbox: true,
|
||||
},
|
||||
{
|
||||
name: "invalid mode falls back to prod",
|
||||
payload: `{"mode":"staging"}`,
|
||||
expected: "prod",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
setting.WaffoPancakeSandbox = tc.sandbox
|
||||
environment := resolveWaffoPancakeWebhookEnvironment(tc.payload)
|
||||
require.Equal(t, tc.expected, environment)
|
||||
})
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "WAFFO_PANCAKE_SUB-1-1700000000-abc123", tradeNo)
|
||||
}
|
||||
|
||||
func TestResolveWaffoPancakeSubscriptionTradeNo_RejectsBuyerIdentityMismatch(t *testing.T) {
|
||||
db := setupWaffoPancakeTestDB(t)
|
||||
|
||||
order := &model.SubscriptionOrder{
|
||||
UserId: 42,
|
||||
PlanId: 5,
|
||||
Money: 29,
|
||||
TradeNo: "WAFFO_PANCAKE_SUB-42-mismatch",
|
||||
PaymentMethod: model.PaymentMethodWaffoPancake,
|
||||
PaymentProvider: model.PaymentProviderWaffoPancake,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
require.NoError(t, db.Create(order).Error)
|
||||
|
||||
tradeNo, err := ResolveWaffoPancakeSubscriptionTradeNo(&WaffoPancakeWebhookEvent{
|
||||
Data: WaffoPancakeWebhookData{
|
||||
OrderID: "ORD_internal_pancake_id",
|
||||
OrderMerchantExternalID: "WAFFO_PANCAKE_SUB-42-mismatch",
|
||||
MerchantProvidedBuyerIdentity: WaffoPancakeBuyerIdentityFromUserID(99), // wrong user
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Empty(t, tradeNo)
|
||||
require.Contains(t, err.Error(), "buyer identity mismatch")
|
||||
}
|
||||
|
||||
func TestResolveWaffoPancakeSubscriptionTradeNo_RejectsMissingBuyerIdentity(t *testing.T) {
|
||||
db := setupWaffoPancakeTestDB(t)
|
||||
|
||||
order := &model.SubscriptionOrder{
|
||||
UserId: 7,
|
||||
PlanId: 5,
|
||||
Money: 29,
|
||||
TradeNo: "WAFFO_PANCAKE_SUB-7-missing-identity",
|
||||
PaymentMethod: model.PaymentMethodWaffoPancake,
|
||||
PaymentProvider: model.PaymentProviderWaffoPancake,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
require.NoError(t, db.Create(order).Error)
|
||||
|
||||
tradeNo, err := ResolveWaffoPancakeSubscriptionTradeNo(&WaffoPancakeWebhookEvent{
|
||||
Data: WaffoPancakeWebhookData{
|
||||
OrderMerchantExternalID: "WAFFO_PANCAKE_SUB-7-missing-identity",
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Empty(t, tradeNo)
|
||||
require.Contains(t, err.Error(), "buyer identity mismatch")
|
||||
}
|
||||
|
||||
func TestResolveWaffoPancakeSubscriptionTradeNo_FailsWhenWebhookOrderIDIsUnknown(t *testing.T) {
|
||||
db := setupWaffoPancakeTestDB(t)
|
||||
|
||||
order := &model.SubscriptionOrder{
|
||||
UserId: 42,
|
||||
PlanId: 5,
|
||||
Money: 29,
|
||||
TradeNo: "WAFFO_PANCAKE_SUB-42-real-order",
|
||||
PaymentMethod: model.PaymentMethodWaffoPancake,
|
||||
PaymentProvider: model.PaymentProviderWaffoPancake,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
require.NoError(t, db.Create(order).Error)
|
||||
|
||||
tradeNo, err := ResolveWaffoPancakeSubscriptionTradeNo(&WaffoPancakeWebhookEvent{
|
||||
Data: WaffoPancakeWebhookData{
|
||||
OrderMerchantExternalID: "WAFFO_PANCAKE_SUB-unknown",
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Empty(t, tradeNo)
|
||||
}
|
||||
|
||||
@@ -28,11 +28,12 @@ type ChannelAffinityRule struct {
|
||||
}
|
||||
|
||||
type ChannelAffinitySetting struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
SwitchOnSuccess bool `json:"switch_on_success"`
|
||||
MaxEntries int `json:"max_entries"`
|
||||
DefaultTTLSeconds int `json:"default_ttl_seconds"`
|
||||
Rules []ChannelAffinityRule `json:"rules"`
|
||||
Enabled bool `json:"enabled"`
|
||||
SwitchOnSuccess bool `json:"switch_on_success"`
|
||||
KeepOnChannelDisabled bool `json:"keep_on_channel_disabled"`
|
||||
MaxEntries int `json:"max_entries"`
|
||||
DefaultTTLSeconds int `json:"default_ttl_seconds"`
|
||||
Rules []ChannelAffinityRule `json:"rules"`
|
||||
}
|
||||
|
||||
var codexCliPassThroughHeaders = []string{
|
||||
@@ -74,10 +75,11 @@ func buildPassHeaderTemplate(headers []string) map[string]interface{} {
|
||||
}
|
||||
|
||||
var channelAffinitySetting = ChannelAffinitySetting{
|
||||
Enabled: true,
|
||||
SwitchOnSuccess: true,
|
||||
MaxEntries: 100_000,
|
||||
DefaultTTLSeconds: 3600,
|
||||
Enabled: true,
|
||||
SwitchOnSuccess: true,
|
||||
KeepOnChannelDisabled: false,
|
||||
MaxEntries: 100_000,
|
||||
DefaultTTLSeconds: 3600,
|
||||
Rules: []ChannelAffinityRule{
|
||||
{
|
||||
Name: "codex cli trace",
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
package setting
|
||||
|
||||
// Waffo Pancake hosted checkout configuration. Gateway is enabled once
|
||||
// MerchantID + PrivateKey + ProductID are populated (no separate Enabled
|
||||
// flag, matching Stripe / Creem). StoreID + ProductID are operator-bound
|
||||
// via SaveWaffoPancakeConfig.
|
||||
var (
|
||||
WaffoPancakeEnabled bool
|
||||
WaffoPancakeSandbox bool
|
||||
WaffoPancakeMerchantID string
|
||||
WaffoPancakePrivateKey string
|
||||
WaffoPancakeWebhookPublicKey string
|
||||
WaffoPancakeWebhookTestKey string
|
||||
WaffoPancakeStoreID string
|
||||
WaffoPancakeProductID string
|
||||
WaffoPancakeReturnURL string
|
||||
WaffoPancakeCurrency string = "USD"
|
||||
WaffoPancakeUnitPrice float64 = 1.0
|
||||
WaffoPancakeMinTopUp int = 1
|
||||
WaffoPancakeMerchantID string
|
||||
WaffoPancakePrivateKey string
|
||||
WaffoPancakeReturnURL string
|
||||
WaffoPancakeUnitPrice float64 = 1.0
|
||||
WaffoPancakeMinTopUp int = 1
|
||||
WaffoPancakeStoreID string
|
||||
WaffoPancakeProductID string
|
||||
)
|
||||
|
||||
@@ -71,6 +71,13 @@ var defaultCacheRatio = map[string]float64{
|
||||
"claude-opus-4-7-high": 0.1,
|
||||
"claude-opus-4-7-medium": 0.1,
|
||||
"claude-opus-4-7-low": 0.1,
|
||||
"claude-opus-4-8": 0.1,
|
||||
"claude-opus-4-8-thinking": 0.1,
|
||||
"claude-opus-4-8-max": 0.1,
|
||||
"claude-opus-4-8-xhigh": 0.1,
|
||||
"claude-opus-4-8-high": 0.1,
|
||||
"claude-opus-4-8-medium": 0.1,
|
||||
"claude-opus-4-8-low": 0.1,
|
||||
}
|
||||
|
||||
var defaultCreateCacheRatio = map[string]float64{
|
||||
@@ -106,6 +113,13 @@ var defaultCreateCacheRatio = map[string]float64{
|
||||
"claude-opus-4-7-high": 1.25,
|
||||
"claude-opus-4-7-medium": 1.25,
|
||||
"claude-opus-4-7-low": 1.25,
|
||||
"claude-opus-4-8": 1.25,
|
||||
"claude-opus-4-8-thinking": 1.25,
|
||||
"claude-opus-4-8-max": 1.25,
|
||||
"claude-opus-4-8-xhigh": 1.25,
|
||||
"claude-opus-4-8-high": 1.25,
|
||||
"claude-opus-4-8-medium": 1.25,
|
||||
"claude-opus-4-8-low": 1.25,
|
||||
}
|
||||
|
||||
//var defaultCreateCacheRatio = map[string]float64{}
|
||||
|
||||
@@ -152,6 +152,12 @@ var defaultModelRatio = map[string]float64{
|
||||
"claude-opus-4-7-high": 2.5,
|
||||
"claude-opus-4-7-medium": 2.5,
|
||||
"claude-opus-4-7-low": 2.5,
|
||||
"claude-opus-4-8": 2.5,
|
||||
"claude-opus-4-8-max": 2.5,
|
||||
"claude-opus-4-8-xhigh": 2.5,
|
||||
"claude-opus-4-8-high": 2.5,
|
||||
"claude-opus-4-8-medium": 2.5,
|
||||
"claude-opus-4-8-low": 2.5,
|
||||
"claude-3-opus-20240229": 7.5, // $15 / 1M tokens
|
||||
"claude-opus-4-20250514": 7.5,
|
||||
"claude-opus-4-1-20250805": 7.5,
|
||||
|
||||
Vendored
+1245
-526
File diff suppressed because it is too large
Load Diff
Vendored
-2379
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user