Compare commits

..

2 Commits

Author SHA1 Message Date
CaIon 35f6880ad2 fix: update hero section text and localization strings for clarity 2026-05-12 16:43:56 +08:00
CaIon d07b159c18 fix: theme-aware payment paths, auto-group validation, route guards, perf group filtering
- Add common.ThemeAwarePath to generate correct redirect URLs based on
  active theme (default vs classic), replacing hardcoded /console/* paths
  in 7 controllers and service/quota.go (#4765)
- Validate auto-group availability against getUserGroups before defaulting
  form values; playground falls back to 'default' group when 'auto' is
  unavailable (#4796, #4799)
- Enforce HeaderNavModules settings in rankings route (frontend + backend
  API) and SidebarModulesAdmin in playground route to block direct URL
  access when features are disabled (#4704, #4512)
- Filter perf_metrics API response to only include currently configured
  groups, hiding stale data from deleted groups (#4790)
- Preserve query params (pay=success/fail) in /console/topup → /wallet
  frontend redirect
2026-05-12 16:43:32 +08:00
492 changed files with 13904 additions and 22151 deletions
-2
View File
@@ -56,8 +56,6 @@
# 对话超时设置
# 所有请求超时时间,单位秒,默认为0,表示不限制
# RELAY_TIMEOUT=0
# Relay HTTP 客户端空闲连接超时时间,单位秒,默认跟随 Go 标准库,设置为0表示不限制
# RELAY_IDLE_CONN_TIMEOUT=90
# 流模式无响应超时时间,单位秒,如果出现空补全可以尝试改为更大值
# STREAMING_TIMEOUT=300
+4 -11
View File
@@ -11,8 +11,6 @@ assignees: ''
- 文档:https://docs.newapi.ai/
- 使用问题先看或先问:https://deepwiki.com/QuantumNous/new-api
- 开启透传后的转发相关反馈不接受 issue;透传模式会直接转发请求,请自行确认上游行为。
- 不接受 coding plan、逆向渠道等技术支持类 issue。
- 警告:删除本模板、删除小节标题或随意清空内容的 issue,可能会被直接关闭;重复恶意提交者可能会被 block。
**您当前的 newapi 版本**
@@ -22,18 +20,13 @@ assignees: ''
**提交确认**
[//]: # (方框内删除已有的空格,填 x 号)
- [ ] **非重复 issue:** 我已搜索现有 [Issues](https://github.com/QuantumNous/new-api/issues?q=is%3Aissue)确认目前没有类似 issue
- [ ] **提交前必读:** 我已完整阅读上方“提交前必读”,并已查看文档 https://docs.newapi.ai/项目 README 且向 AI 提问,确认这不是使用、配置或接入类问题。
- [ ] **模板完整:** 我未删除此模板中的任何引导内容或小节标题,并会按要求完整填写
- [ ] **维护成本:** 我理解项目维护者精力有限,不遵循模板要求的 issue 可能会被无视或直接关闭
+ [ ] 我已确认目前没有类似 issue
+ [ ] 我已完整查看文档 https://docs.newapi.ai/项目 README,尤其是常见问题部分
+ [ ] 我未删除此模板中的任何引导内容或小节标题,并会按要求完整填写
+ [ ] 我理解项目维护者精力有限,不遵循模板要求的 issue 可能会被无视或直接关闭
**问题描述**
请尽可能说明问题现象、影响范围,以及你判断它是程序问题而不是上游行为或使用问题的依据。
- 转发问题请尽可能说明渠道类型、转换格式、上游原生支持依据和服务端日志。
- 计费问题请尽可能附请求返回的 `usage` 示例。
**复现步骤**
**预期结果**
+4 -11
View File
@@ -11,8 +11,6 @@ assignees: ''
- Docs: https://docs.newapi.ai/
- Usage questions first: https://deepwiki.com/QuantumNous/new-api
- Issues about forwarding behavior after enabling pass-through mode are not accepted; pass-through mode forwards requests directly, so please verify upstream behavior yourself.
- Technical support requests such as coding plans or reverse-engineering channels are not accepted as issues.
- Warning: issues with this template removed, section headings deleted, or content cleared may be closed directly. Repeated abusive submissions may result in a block.
**Your current newapi version**
@@ -22,18 +20,13 @@ Please fill this in, for example: `v1.0.0`
**Submission Checks**
[//]: # (Remove the space in the box and fill with an x)
- [ ] **Non-duplicate issue:** I have searched existing [Issues](https://github.com/QuantumNous/new-api/issues?q=is%3Aissue) and confirmed there are no similar issues.
- [ ] **Read this first:** I have fully read the section above, reviewed the docs at https://docs.newapi.ai/ and the project README, and asked AI first, confirming this is not a usage, configuration, or integration question.
- [ ] **Template intact:** I have not removed any guidance or section headings from this template and will complete it as requested.
- [ ] **Maintainer time:** I understand that maintainers have limited time and issues that do not follow this template may be ignored or closed directly.
+ [ ] I have confirmed there are no similar issues
+ [ ] I have thoroughly read the docs at https://docs.newapi.ai/ and the project README, especially the FAQ section
+ [ ] I have not removed any guidance or section headings from this template and will complete it as requested
+ [ ] I understand that maintainers have limited time and issues that do not follow this template may be ignored or closed directly
**Issue Description**
Describe the symptom, impact scope, and why you believe this is an application issue rather than upstream behavior or a usage question with as much detail as possible.
- For forwarding issues, include the channel type, conversion format, upstream native-support evidence, and server logs when possible.
- For billing issues, include an example of the returned `usage` when possible.
**Steps to Reproduce**
**Expected Result**
+4 -6
View File
@@ -11,8 +11,6 @@ assignees: ''
- 文档:https://docs.newapi.ai/
- 使用问题先看或先问:https://deepwiki.com/QuantumNous/new-api
- 开启透传后的转发相关反馈不接受 issue;透传模式会直接转发请求,请自行确认上游行为。
- 不接受 coding plan、逆向渠道等技术支持类 issue。
- 警告:删除本模板、删除小节标题或随意清空内容的 issue,可能会被直接关闭;重复恶意提交者可能会被 block。
**您当前的 newapi 版本**
@@ -22,10 +20,10 @@ assignees: ''
**提交确认**
[//]: # (方框内删除已有的空格,填 x 号)
- [ ] **非重复 issue:** 我已搜索现有 [Issues](https://github.com/QuantumNous/new-api/issues?q=is%3Aissue)确认目前没有类似 issue
- [ ] **提交前必读:** 我已完整阅读上方“提交前必读”,并已查看文档 https://docs.newapi.ai/项目 README 且向 AI 提问,确认这不是使用、配置或接入类问题,且现有版本无法满足需求
- [ ] **模板完整:** 我未删除此模板中的任何引导内容或小节标题,并会按要求完整填写
- [ ] **维护成本:** 我理解项目维护者精力有限,不遵循模板要求的 issue 可能会被无视或直接关闭
+ [ ] 我已确认目前没有类似 issue
+ [ ] 我已完整查看文档 https://docs.newapi.ai/项目 README,已确定现有版本无法满足需求
+ [ ] 我未删除此模板中的任何引导内容或小节标题,并会按要求完整填写
+ [ ] 我理解项目维护者精力有限,不遵循模板要求的 issue 可能会被无视或直接关闭
**功能描述**
+4 -6
View File
@@ -11,8 +11,6 @@ assignees: ''
- Docs: https://docs.newapi.ai/
- Usage questions first: https://deepwiki.com/QuantumNous/new-api
- Issues about forwarding behavior after enabling pass-through mode are not accepted; pass-through mode forwards requests directly, so please verify upstream behavior yourself.
- Technical support requests such as coding plans or reverse-engineering channels are not accepted as issues.
- Warning: issues with this template removed, section headings deleted, or content cleared may be closed directly. Repeated abusive submissions may result in a block.
**Your current newapi version**
@@ -22,10 +20,10 @@ Please fill this in, for example: `v1.0.0`
**Submission Checks**
[//]: # (Remove the space in the box and fill with an x)
- [ ] **Non-duplicate issue:** I have searched existing [Issues](https://github.com/QuantumNous/new-api/issues?q=is%3Aissue) and confirmed there are no similar issues.
- [ ] **Read this first:** I have fully read the section above, reviewed the docs at https://docs.newapi.ai/ and the project README, and asked AI first, confirming this is not a usage, configuration, or integration question, and that the current version cannot meet my needs.
- [ ] **Template intact:** I have not removed any guidance or section headings from this template and will complete it as requested.
- [ ] **Maintainer time:** I understand that maintainers have limited time and issues that do not follow this template may be ignored or closed directly.
+ [ ] I have confirmed there are no similar issues
+ [ ] I have thoroughly read the docs at https://docs.newapi.ai/ and the project README, and confirmed the current version cannot meet my needs
+ [ ] I have not removed any guidance or section headings from this template and will complete it as requested
+ [ ] I understand that maintainers have limited time and issues that do not follow this template may be ignored or closed directly
**Feature Description**
+12 -18
View File
@@ -33,18 +33,16 @@ jobs:
env:
CI: ""
run: |
cd web
bun install --frozen-lockfile
cd default
cd web/default
bun install
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$VERSION bun run build
cd ../..
- name: Build Frontend (classic)
env:
CI: ""
run: |
cd web
bun install --frozen-lockfile
cd classic
cd web/classic
bun install
VITE_REACT_APP_VERSION=$VERSION bun run build
cd ../..
- name: Set up Go
@@ -93,18 +91,16 @@ jobs:
CI: ""
NODE_OPTIONS: "--max-old-space-size=4096"
run: |
cd web
bun install --frozen-lockfile
cd default
cd web/default
bun install
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$VERSION bun run build
cd ../..
- name: Build Frontend (classic)
env:
CI: ""
run: |
cd web
bun install --frozen-lockfile
cd classic
cd web/classic
bun install
VITE_REACT_APP_VERSION=$VERSION bun run build
cd ../..
- name: Set up Go
@@ -150,18 +146,16 @@ jobs:
env:
CI: ""
run: |
cd web
bun install --frozen-lockfile
cd default
cd web/default
bun install
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$VERSION bun run build
cd ../..
- name: Build Frontend (classic)
env:
CI: ""
run: |
cd web
bun install --frozen-lockfile
cd classic
cd web/classic
bun install
VITE_REACT_APP_VERSION=$VERSION bun run build
cd ../..
- name: Set up Go
-1
View File
@@ -35,4 +35,3 @@ data/
.test
token_estimator_test.go
skills-lock.json
.playwright-mcp
+16 -18
View File
@@ -1,24 +1,22 @@
FROM oven/bun:1@sha256:0733e50325078969732ebe3b15ce4c4be5082f18c4ac1a0f0ca4839c2e4e42a7 AS builder
WORKDIR /build/web
COPY web/package.json web/bun.lock ./
COPY web/default/package.json ./default/package.json
COPY web/classic/package.json ./classic/package.json
RUN bun install --frozen-lockfile
COPY ./web/default ./default
COPY ./VERSION /build/VERSION
RUN cd default && DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat /build/VERSION) bun run build
WORKDIR /build
COPY web/default/package.json .
COPY web/default/bun.lock .
RUN bun install
COPY ./web/default .
COPY ./VERSION .
RUN DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) bun run build
FROM oven/bun:1@sha256:0733e50325078969732ebe3b15ce4c4be5082f18c4ac1a0f0ca4839c2e4e42a7 AS builder-classic
WORKDIR /build/web
COPY web/package.json web/bun.lock ./
COPY web/default/package.json ./default/package.json
COPY web/classic/package.json ./classic/package.json
RUN bun install --frozen-lockfile
COPY ./web/classic ./classic
COPY ./VERSION /build/VERSION
RUN cd classic && VITE_REACT_APP_VERSION=$(cat /build/VERSION) bun run build
WORKDIR /build
COPY web/classic/package.json .
COPY web/classic/bun.lock .
RUN bun install
COPY ./web/classic .
COPY ./VERSION .
RUN VITE_REACT_APP_VERSION=$(cat VERSION) bun run build
FROM golang:1.26.1-alpine@sha256:2389ebfa5b7f43eeafbd6be0c3700cc46690ef842ad962f6c5bd6be49ed82039 AS builder2
ENV GO111MODULE=on CGO_ENABLED=0
@@ -34,8 +32,8 @@ ADD go.mod go.sum ./
RUN go mod download
COPY . .
COPY --from=builder /build/web/default/dist ./web/default/dist
COPY --from=builder-classic /build/web/classic/dist ./web/classic/dist
COPY --from=builder /build/dist ./web/default/dist
COPY --from=builder-classic /build/dist ./web/classic/dist
RUN go build -ldflags "-s -w -X 'github.com/QuantumNous/new-api/common.Version=$(cat VERSION)'" -o new-api
FROM debian:bookworm-slim@sha256:f06537653ac770703bc45b4b113475bd402f451e85223f0f2837acbf89ab020a
+1 -1
View File
@@ -413,7 +413,7 @@ docker run --name new-api -d --restart always \
| Project | Description |
|------|------|
| [new-api-key-tool](https://github.com/Calcium-Ion/new-api-key-tool) | Key quota query tool |
| [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool) | Key quota query tool |
| [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon) | New API high-performance optimized version |
---
+2 -2
View File
@@ -206,7 +206,7 @@ docker run --name new-api -d --restart always \
- 🤖 Connexion par autorisation LinuxDO
- 📱 Connexion par autorisation Telegram
- 🔑 Authentification unifiée OIDC
- 🔍 Requête de quota d'utilisation de clé (avec [new-api-key-tool](https://github.com/Calcium-Ion/new-api-key-tool))
- 🔍 Requête de quota d'utilisation de clé (avec [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool))
### 🚀 Fonctionnalités avancées
@@ -420,7 +420,7 @@ docker run --name new-api -d --restart always \
| Projet | Description |
|------|------|
| [new-api-key-tool](https://github.com/Calcium-Ion/new-api-key-tool) | Outil de recherche de quota d'utilisation avec une clé |
| [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool) | Outil de recherche de quota d'utilisation avec une clé |
| [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon) | Version optimisée haute performance de New API |
---
+2 -2
View File
@@ -206,7 +206,7 @@ docker run --name new-api -d --restart always \
- 🤖 LinuxDO認証ログイン
- 📱 Telegram認証ログイン
- 🔑 OIDC統一認証
- 🔍 Key使用量クォータ照会([new-api-key-tool](https://github.com/Calcium-Ion/new-api-key-tool)と併用)
- 🔍 Key使用量クォータ照会([neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)と併用)
@@ -420,7 +420,7 @@ docker run --name new-api -d --restart always \
| プロジェクト | 説明 |
|------|------|
| [new-api-key-tool](https://github.com/Calcium-Ion/new-api-key-tool) | キー使用量クォータ照会ツール |
| [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool) | キー使用量クォータ照会ツール |
| [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon) | New API高性能最適化版 |
---
+2 -3
View File
@@ -206,7 +206,7 @@ docker run --name new-api -d --restart always \
- 🤖 LinuxDO authorization login
- 📱 Telegram authorization login
- 🔑 OIDC unified authentication
- 🔍 Key quota query usage (with [new-api-key-tool](https://github.com/Calcium-Ion/new-api-key-tool))
- 🔍 Key quota query usage (with [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool))
### 🚀 Advanced Features
@@ -316,7 +316,6 @@ docker run --name new-api -d --restart always \
| `CRYPTO_SECRET` | Encryption secret (required for Redis) | - |
| `SQL_DSN` | Database connection string | - |
| `REDIS_CONN_STRING` | Redis connection string | - |
| `RELAY_IDLE_CONN_TIMEOUT` | Idle keep-alive timeout for relay HTTP clients, seconds. Defaults to Go standard library behavior; set `0` to disable | `90` |
| `STREAMING_TIMEOUT` | Streaming timeout (seconds) | `300` |
| `STREAM_SCANNER_MAX_BUFFER_MB` | Max per-line buffer (MB) for the stream scanner; increase when upstream sends huge image/base64 payloads | `64` |
| `MAX_REQUEST_BODY_MB` | Max request body size (MB, counted **after decompression**; prevents huge requests/zip bombs from exhausting memory). Exceeding it returns `413` | `32` |
@@ -421,7 +420,7 @@ docker run --name new-api -d --restart always \
| Project | Description |
|------|------|
| [new-api-key-tool](https://github.com/Calcium-Ion/new-api-key-tool) | Key quota query tool |
| [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool) | Key quota query tool |
| [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon) | New API high-performance optimized version |
---
+2 -2
View File
@@ -206,7 +206,7 @@ docker run --name new-api -d --restart always \
- 🤖 LinuxDO 授权登录
- 📱 Telegram 授权登录
- 🔑 OIDC 统一认证
- 🔍 Key 查询使用额度(配合 [new-api-key-tool](https://github.com/Calcium-Ion/new-api-key-tool)
- 🔍 Key 查询使用额度(配合 [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)
### 🚀 高级功能
@@ -420,7 +420,7 @@ docker run --name new-api -d --restart always \
| 项目 | 说明 |
|------|------|
| [new-api-key-tool](https://github.com/Calcium-Ion/new-api-key-tool) | Key 额度查询工具 |
| [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool) | Key 额度查询工具 |
| [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon) | New API 高性能优化版 |
---
+2 -2
View File
@@ -206,7 +206,7 @@ docker run --name new-api -d --restart always \
- 🤖 LinuxDO 授權登錄
- 📱 Telegram 授權登錄
- 🔑 OIDC 統一認證
- 🔍 Key 查詢使用額度(配合 [new-api-key-tool](https://github.com/Calcium-Ion/new-api-key-tool)
- 🔍 Key 查詢使用額度(配合 [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)
### 🚀 高級功能
@@ -420,7 +420,7 @@ docker run --name new-api -d --restart always \
| 項目 | 說明 |
|------|------|
| [new-api-key-tool](https://github.com/Calcium-Ion/new-api-key-tool) | Key 額度查詢工具 |
| [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool) | Key 額度查詢工具 |
| [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon) | New API 高性能優化版 |
---
+1 -3
View File
@@ -170,7 +170,6 @@ var BatchUpdateInterval int
var RelayTimeout int // unit is second
var RelayIdleConnTimeout int // unit is second
var RelayMaxIdleConns int
var RelayMaxIdleConnsPerHost int
@@ -180,8 +179,7 @@ var GeminiSafetySetting string
var CohereSafetySetting string
const (
RequestIdKey = "X-Oneapi-Request-Id"
UpstreamRequestIdKey = "X-Upstream-Request-Id"
RequestIdKey = "X-Oneapi-Request-Id"
)
const (
+2 -2
View File
@@ -37,7 +37,7 @@ func checkWriter(writer io.Writer) stringWriter {
// W3C Working Draft 29 October 2009
// http://www.w3.org/TR/2009/WD-eventsource-20091029/
var writeContentType = []string{"text/event-stream"}
var contentType = []string{"text/event-stream"}
var noCache = []string{"no-cache"}
var fieldReplacer = strings.NewReplacer(
@@ -79,7 +79,7 @@ func (r CustomEvent) WriteContentType(w http.ResponseWriter) {
r.Mutex.Lock()
defer r.Mutex.Unlock()
header := w.Header()
header["Content-Type"] = writeContentType
header["Content-Type"] = contentType
if _, exist := header["Cache-Control"]; !exist {
header["Cache-Control"] = noCache
+1 -19
View File
@@ -110,29 +110,11 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
if err != nil {
return err
}
contentType := c.Request.Header.Get("Content-Type")
// disk-backed JSON: stream-decode directly from the file to avoid
// materializing the entire payload back into a transient []byte
// (diskStorage.Bytes() would ReadFull the whole file into the heap).
if storage.IsDisk() && strings.HasPrefix(contentType, "application/json") {
if _, seekErr := storage.Seek(0, io.SeekStart); seekErr != nil {
return seekErr
}
if err := DecodeJson(storage, v); err != nil {
return err
}
if _, seekErr := storage.Seek(0, io.SeekStart); seekErr != nil {
return seekErr
}
c.Request.Body = io.NopCloser(storage)
return nil
}
requestBody, err := storage.Bytes()
if err != nil {
return err
}
contentType := c.Request.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/json") {
err = Unmarshal(requestBody, v)
} else if strings.Contains(contentType, gin.MIMEPOSTForm) {
-2
View File
@@ -102,7 +102,6 @@ func InitEnv() {
SyncFrequency = GetEnvOrDefault("SYNC_FREQUENCY", 60)
BatchUpdateInterval = GetEnvOrDefault("BATCH_UPDATE_INTERVAL", 5)
RelayTimeout = GetEnvOrDefault("RELAY_TIMEOUT", 0)
RelayIdleConnTimeout = GetEnvOrDefault("RELAY_IDLE_CONN_TIMEOUT", 90)
RelayMaxIdleConns = GetEnvOrDefault("RELAY_MAX_IDLE_CONNS", 500)
RelayMaxIdleConnsPerHost = GetEnvOrDefault("RELAY_MAX_IDLE_CONNS_PER_HOST", 100)
@@ -136,7 +135,6 @@ func initConstantEnv() {
constant.StreamScannerMaxBufferMB = GetEnvOrDefault("STREAM_SCANNER_MAX_BUFFER_MB", 128)
// MaxRequestBodyMB 请求体最大大小(解压后),用于防止超大请求/zip bomb导致内存暴涨
constant.MaxRequestBodyMB = GetEnvOrDefault("MAX_REQUEST_BODY_MB", 128)
constant.AnonymousRequestBodyLimitKB = GetEnvOrDefault("ANONYMOUS_REQUEST_BODY_LIMIT_KB", 512)
// ForceStreamOption 覆盖请求参数,强制返回usage信息
constant.ForceStreamOption = GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
constant.CountToken = GetEnvOrDefaultBool("CountToken", true)
-13
View File
@@ -1,13 +0,0 @@
package common
import "github.com/QuantumNous/new-api/constant"
const defaultAnonymousRequestBodyLimitKB = 512
func GetAnonymousRequestBodyLimitBytes() int64 {
limitKB := constant.AnonymousRequestBodyLimitKB
if limitKB < 0 {
limitKB = defaultAnonymousRequestBodyLimitKB
}
return int64(limitKB) << 10
}
-11
View File
@@ -3,7 +3,6 @@ package common
import (
"encoding/base64"
"encoding/json"
"fmt"
"net/url"
"regexp"
"strconv"
@@ -21,16 +20,6 @@ var (
maskApiKeyPattern = regexp.MustCompile(`(['"]?)api_key:([^\s'"]+)(['"]?)`)
)
const LocalLogContentLimit = 2048
// LocalLogPreview limits log-only content unless debug logging is enabled.
func LocalLogPreview(content string) string {
if DebugEnabled || len(content) <= LocalLogContentLimit {
return content
}
return fmt.Sprintf("%s... [truncated, original_length=%d, limit=%d]", content[:LocalLogContentLimit], len(content), LocalLogContentLimit)
}
func GetStringIfEmpty(str string, defaultValue string) string {
if str == "" {
return defaultValue
-1
View File
@@ -10,7 +10,6 @@ var GetMediaToken bool
var GetMediaTokenNotStream bool
var UpdateTask bool
var MaxRequestBodyMB int
var AnonymousRequestBodyLimitKB int
var AzureDefaultAPIVersion string
var NotifyLimitCount int
var NotificationLimitDurationMinute int
+8 -34
View File
@@ -57,24 +57,7 @@ func normalizeChannelTestEndpoint(channel *model.Channel, modelName, endpointTyp
return normalized
}
func resolveChannelTestUserID(c *gin.Context) (int, error) {
if c != nil {
if userID := c.GetInt("id"); userID > 0 {
return userID, nil
}
}
var rootUser model.User
if err := model.DB.Select("id").Where("role = ?", common.RoleRootUser).First(&rootUser).Error; err != nil {
return 0, fmt.Errorf("failed to resolve channel test user: %w", err)
}
if rootUser.Id == 0 {
return 0, errors.New("failed to resolve channel test user")
}
return rootUser.Id, nil
}
func testChannel(channel *model.Channel, testUserID int, testModel string, endpointType string, isStream bool) testResult {
func testChannel(channel *model.Channel, testModel string, endpointType string, isStream bool) testResult {
tik := time.Now()
var unsupportedTestChannelTypes = []int{
constant.ChannelTypeMidjourney,
@@ -160,7 +143,7 @@ func testChannel(channel *model.Channel, testUserID int, testModel string, endpo
Header: make(http.Header),
}
cache, err := model.GetUserCache(testUserID)
cache, err := model.GetUserCache(1)
if err != nil {
return testResult{
localErr: err,
@@ -168,13 +151,13 @@ func testChannel(channel *model.Channel, testUserID int, testModel string, endpo
}
}
cache.WriteContext(c)
c.Set("id", testUserID)
c.Set("id", 1)
//c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
c.Request.Header.Set("Content-Type", "application/json")
c.Set("channel", channel.Type)
c.Set("base_url", channel.GetBaseURL())
group, _ := model.GetUserGroup(testUserID, false)
group, _ := model.GetUserGroup(1, false)
c.Set("group", group)
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, testModel)
@@ -501,7 +484,7 @@ func testChannel(channel *model.Channel, testUserID int, testModel string, endpo
milliseconds := tok.Sub(tik).Milliseconds()
consumedTime := float64(milliseconds) / 1000.0
other := buildTestLogOther(c, info, priceData, usage, tieredResult)
model.RecordConsumeLog(c, testUserID, model.RecordConsumeLogParams{
model.RecordConsumeLog(c, 1, model.RecordConsumeLogParams{
ChannelId: channel.Id,
PromptTokens: usage.PromptTokens,
CompletionTokens: usage.CompletionTokens,
@@ -814,7 +797,7 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel,
testRequest.StreamOptions = &dto.StreamOptions{IncludeUsage: true}
}
if dto.IsOpenAIReasoningOModel(model) {
if strings.HasPrefix(model, "o") {
testRequest.MaxCompletionTokens = lo.ToPtr(uint(16))
} else if strings.Contains(model, "thinking") {
if !strings.Contains(model, "claude") {
@@ -851,13 +834,8 @@ func TestChannel(c *gin.Context) {
testModel := c.Query("model")
endpointType := c.Query("endpoint_type")
isStream, _ := strconv.ParseBool(c.Query("stream"))
testUserID, err := resolveChannelTestUserID(c)
if err != nil {
common.ApiError(c, err)
return
}
tik := time.Now()
result := testChannel(channel, testUserID, testModel, endpointType, isStream)
result := testChannel(channel, testModel, endpointType, isStream)
if result.localErr != nil {
resp := gin.H{
"success": false,
@@ -894,10 +872,6 @@ var testAllChannelsLock sync.Mutex
var testAllChannelsRunning bool = false
func testAllChannels(notify bool) error {
testUserID, err := resolveChannelTestUserID(nil)
if err != nil {
return err
}
testAllChannelsLock.Lock()
if testAllChannelsRunning {
@@ -928,7 +902,7 @@ func testAllChannels(notify bool) error {
}
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
tik := time.Now()
result := testChannel(channel, testUserID, "", "", shouldUseStreamForAutomaticChannelTest(channel))
result := testChannel(channel, "", "", shouldUseStreamForAutomaticChannelTest(channel))
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
+40 -64
View File
@@ -19,7 +19,6 @@ import (
"github.com/QuantumNous/new-api/service"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type OpenAIModel struct {
@@ -69,33 +68,12 @@ func clearChannelInfo(channel *model.Channel) {
}
}
func applyChannelStatusFilter(query *gorm.DB, statusFilter int) *gorm.DB {
if statusFilter == common.ChannelStatusEnabled {
return query.Where("status = ?", common.ChannelStatusEnabled)
}
if statusFilter == 0 {
return query.Where("status != ?", common.ChannelStatusEnabled)
}
return query
}
func buildChannelListQuery(group string, statusFilter int, typeFilter int) *gorm.DB {
query := model.DB.Model(&model.Channel{})
query = model.ApplyChannelGroupFilter(query, group)
query = applyChannelStatusFilter(query, statusFilter)
if typeFilter >= 0 {
query = query.Where("type = ?", typeFilter)
}
return query
}
func GetAllChannels(c *gin.Context) {
pageInfo := common.GetPageQuery(c)
channelData := make([]*model.Channel, 0)
idSort, _ := strconv.ParseBool(c.Query("id_sort"))
sortOptions := model.NewChannelSortOptions(c.Query("sort_by"), c.Query("sort_order"), idSort)
enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
groupFilter := model.NormalizeChannelGroupFilter(c.Query("group"))
statusParam := c.Query("status")
// statusFilter: -1 all, 1 enabled, 0 disabled (include auto & manual)
statusFilter := parseStatusFilter(statusParam)
@@ -111,45 +89,50 @@ func GetAllChannels(c *gin.Context) {
var total int64
if enableTagMode {
tags, err := model.GetPaginatedChannelTags(buildChannelListQuery(groupFilter, statusFilter, typeFilter), pageInfo.GetStartIdx(), pageInfo.GetPageSize())
tags, err := model.GetPaginatedTags(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
if err != nil {
common.SysError("failed to get paginated tags: " + err.Error())
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取标签失败,请稍后重试"})
return
}
total, err = model.CountChannelTags(buildChannelListQuery(groupFilter, statusFilter, typeFilter))
if err != nil {
common.SysError("failed to count tags: " + err.Error())
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取标签数量失败,请稍后重试"})
return
}
for _, tag := range tags {
if tag == nil || *tag == "" {
continue
}
var tagChannels []*model.Channel
err := sortOptions.Apply(buildChannelListQuery(groupFilter, statusFilter, typeFilter).Where("tag = ?", *tag)).
Omit("key").
Find(&tagChannels).Error
tagChannels, err := model.GetChannelsByTag(*tag, idSort, false, sortOptions)
if err != nil {
common.SysError("failed to get channels by tag: " + err.Error())
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取标签渠道失败,请稍后重试"})
return
continue
}
channelData = append(channelData, tagChannels...)
filtered := make([]*model.Channel, 0)
for _, ch := range tagChannels {
if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
continue
}
if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
continue
}
if typeFilter >= 0 && ch.Type != typeFilter {
continue
}
filtered = append(filtered, ch)
}
channelData = append(channelData, filtered...)
}
total, _ = model.CountAllTags()
} else {
if err := buildChannelListQuery(groupFilter, statusFilter, typeFilter).Count(&total).Error; err != nil {
common.SysError("failed to count channels: " + err.Error())
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取渠道数量失败,请稍后重试"})
return
baseQuery := model.DB.Model(&model.Channel{})
if typeFilter >= 0 {
baseQuery = baseQuery.Where("type = ?", typeFilter)
}
if statusFilter == common.ChannelStatusEnabled {
baseQuery = baseQuery.Where("status = ?", common.ChannelStatusEnabled)
} else if statusFilter == 0 {
baseQuery = baseQuery.Where("status != ?", common.ChannelStatusEnabled)
}
err := sortOptions.Apply(buildChannelListQuery(groupFilter, statusFilter, typeFilter)).
Limit(pageInfo.GetPageSize()).
Offset(pageInfo.GetStartIdx()).
Omit("key").
Find(&channelData).Error
baseQuery.Count(&total)
err := sortOptions.Apply(baseQuery).Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("key").Find(&channelData).Error
if err != nil {
common.SysError("failed to get channels: " + err.Error())
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取渠道列表失败,请稍后重试"})
@@ -161,16 +144,17 @@ func GetAllChannels(c *gin.Context) {
clearChannelInfo(datum)
}
countQuery := buildChannelListQuery(groupFilter, statusFilter, -1)
countQuery := model.DB.Model(&model.Channel{})
if statusFilter == common.ChannelStatusEnabled {
countQuery = countQuery.Where("status = ?", common.ChannelStatusEnabled)
} else if statusFilter == 0 {
countQuery = countQuery.Where("status != ?", common.ChannelStatusEnabled)
}
var results []struct {
Type int64
Count int64
}
if err := countQuery.Select("type, count(*) as count").Group("type").Find(&results).Error; err != nil {
common.SysError("failed to count channel types: " + err.Error())
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取渠道类型统计失败,请稍后重试"})
return
}
_ = countQuery.Select("type, count(*) as count").Group("type").Find(&results).Error
typeCounts := make(map[int64]int64)
for _, r := range results {
typeCounts[r.Type] = r.Count
@@ -278,18 +262,10 @@ func SearchChannels(c *gin.Context) {
}
for _, tag := range tags {
if tag != nil && *tag != "" {
var tagChannels []*model.Channel
err := sortOptions.Apply(buildChannelListQuery(group, -1, -1).Where("tag = ?", *tag)).
Omit("key").
Find(&tagChannels).Error
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
tagChannel, err := model.GetChannelsByTag(*tag, idSort, false, sortOptions)
if err == nil {
channelData = append(channelData, tagChannel...)
}
channelData = append(channelData, tagChannels...)
}
}
} else {
@@ -1218,7 +1194,7 @@ func CopyChannel(c *gin.Context) {
}
// insert
if err := clone.Insert(); err != nil {
if err := model.BatchInsertChannels([]model.Channel{clone}); err != nil {
common.SysError("failed to clone channel: " + err.Error())
c.JSON(http.StatusOK, gin.H{"success": false, "message": "复制渠道失败,请稍后重试"})
return
-11
View File
@@ -69,14 +69,3 @@ func TestBuildTestLogOtherInjectsTieredInfo(t *testing.T) {
require.Equal(t, "base", other["matched_tier"])
require.NotEmpty(t, other["expr_b64"])
}
func TestResolveChannelTestUserIDUsesRequestUser(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
ctx.Set("id", 2)
userID, err := resolveChannelTestUserID(ctx)
require.NoError(t, err)
require.Equal(t, 2, userID)
}
+2 -2
View File
@@ -501,7 +501,7 @@ func GetUserOAuthBindingsByAdmin(c *gin.Context) {
}
myRole := c.GetInt("role")
if !canManageTargetRole(myRole, targetUser.Role) {
if myRole <= targetUser.Role && myRole != common.RoleRootUser {
common.ApiErrorMsg(c, "no permission")
return
}
@@ -560,7 +560,7 @@ func UnbindCustomOAuthByAdmin(c *gin.Context) {
}
myRole := c.GetInt("role")
if !canManageTargetRole(myRole, targetUser.Role) {
if myRole <= targetUser.Role && myRole != common.RoleRootUser {
common.ApiErrorMsg(c, "no permission")
return
}
+2 -4
View File
@@ -21,8 +21,7 @@ func GetAllLogs(c *gin.Context) {
channel, _ := strconv.Atoi(c.Query("channel"))
group := c.Query("group")
requestId := c.Query("request_id")
upstreamRequestId := c.Query("upstream_request_id")
logs, total, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), channel, group, requestId, upstreamRequestId)
logs, total, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), channel, group, requestId)
if err != nil {
common.ApiError(c, err)
return
@@ -43,8 +42,7 @@ func GetUserLogs(c *gin.Context) {
modelName := c.Query("model_name")
group := c.Query("group")
requestId := c.Query("request_id")
upstreamRequestId := c.Query("upstream_request_id")
logs, total, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), group, requestId, upstreamRequestId)
logs, total, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), group, requestId)
if err != nil {
common.ApiError(c, err)
return
-3
View File
@@ -87,9 +87,6 @@ func GetStatus(c *gin.Context) {
"chats": setting.Chats,
"demo_site_enabled": operation_setting.DemoSiteEnabled,
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
"register_enabled": common.RegisterEnabled,
"password_login_enabled": common.PasswordLoginEnabled,
"password_register_enabled": common.PasswordRegisterEnabled,
"default_use_auto_group": setting.DefaultUseAutoGroup,
"usd_exchange_rate": operation_setting.USDExchangeRate,
+43 -120
View File
@@ -3,7 +3,6 @@ package controller
import (
"fmt"
"net/http"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
@@ -110,102 +109,9 @@ func init() {
})
}
func channelOwnerName(channelType int) string {
apiType, success := common.ChannelType2APIType(channelType)
if !success {
return strings.ToLower(constant.GetChannelTypeName(channelType))
}
adaptor := relay.GetAdaptor(apiType)
if adaptor == nil {
return strings.ToLower(constant.GetChannelTypeName(channelType))
}
adaptor.Init(&relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{
ChannelType: channelType,
}})
if name := strings.TrimSpace(adaptor.GetChannelName()); name != "" {
return name
}
return strings.ToLower(constant.GetChannelTypeName(channelType))
}
func getPreferredModelOwners(modelNames []string, groups []string) map[string]string {
channelTypes, err := model.GetPreferredModelOwnerChannelTypes(modelNames, groups)
if err != nil {
common.SysLog(fmt.Sprintf("GetPreferredModelOwnerChannelTypes error: %v", err))
return map[string]string{}
}
ownerByChannelType := make(map[int]string)
owners := make(map[string]string, len(channelTypes))
for modelName, channelType := range channelTypes {
owner, ok := ownerByChannelType[channelType]
if !ok {
owner = channelOwnerName(channelType)
ownerByChannelType[channelType] = owner
}
if owner != "" {
owners[modelName] = owner
}
}
return owners
}
func buildOpenAIModel(modelName string, ownerByModel map[string]string) dto.OpenAIModels {
var oaiModel dto.OpenAIModels
if staticModel, ok := openAIModelsMap[modelName]; ok {
oaiModel = staticModel
} else {
oaiModel = dto.OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: "custom",
}
}
if owner, ok := ownerByModel[modelName]; ok && owner != "" {
oaiModel.OwnedBy = owner
}
oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(modelName)
return oaiModel
}
type modelListGroups struct {
userGroup string
tokenGroup string
ownerGroups []string
}
func getModelListGroups(c *gin.Context) (modelListGroups, error) {
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
if userGroup == "" && (tokenGroup == "" || tokenGroup == "auto") {
var err error
userGroup, err = model.GetUserGroup(c.GetInt("id"), false)
if err != nil {
return modelListGroups{}, err
}
}
if tokenGroup == "auto" {
return modelListGroups{
userGroup: userGroup,
tokenGroup: tokenGroup,
ownerGroups: service.GetUserAutoGroup(userGroup),
}, nil
}
group := userGroup
if tokenGroup != "" {
group = tokenGroup
}
return modelListGroups{
userGroup: userGroup,
tokenGroup: tokenGroup,
ownerGroups: []string{group},
}, nil
}
func ListModels(c *gin.Context, modelType int) {
userOpenAiModels := make([]dto.OpenAIModels, 0)
acceptUnsetRatioModel := operation_setting.SelfUseModeEnabled
if !acceptUnsetRatioModel {
userId := c.GetInt("id")
@@ -217,16 +123,6 @@ func ListModels(c *gin.Context, modelType int) {
}
}
userModelNames := make([]string, 0)
groups, err := getModelListGroups(c)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "get user group failed",
})
return
}
ownerGroups := groups.ownerGroups
modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
if modelLimitEnable {
s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
@@ -242,12 +138,37 @@ func ListModels(c *gin.Context, modelType int) {
continue
}
}
userModelNames = append(userModelNames, allowModel)
if oaiModel, ok := openAIModelsMap[allowModel]; ok {
oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(allowModel)
userOpenAiModels = append(userOpenAiModels, oaiModel)
} else {
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
Id: allowModel,
Object: "model",
Created: 1626777600,
OwnedBy: "custom",
SupportedEndpointTypes: model.GetModelSupportEndpointTypes(allowModel),
})
}
}
} else {
userId := c.GetInt("id")
userGroup, err := model.GetUserGroup(userId, false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "get user group failed",
})
return
}
group := userGroup
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
if tokenGroup != "" {
group = tokenGroup
}
var models []string
if groups.tokenGroup == "auto" {
for _, autoGroup := range ownerGroups {
if tokenGroup == "auto" {
for _, autoGroup := range service.GetUserAutoGroup(userGroup) {
groupModels := model.GetGroupEnabledModels(autoGroup)
for _, g := range groupModels {
if !common.StringsContains(models, g) {
@@ -256,7 +177,7 @@ func ListModels(c *gin.Context, modelType int) {
}
}
} else {
models = model.GetGroupEnabledModels(ownerGroups[0])
models = model.GetGroupEnabledModels(group)
}
for _, modelName := range models {
if !acceptUnsetRatioModel {
@@ -264,19 +185,21 @@ func ListModels(c *gin.Context, modelType int) {
continue
}
}
userModelNames = append(userModelNames, modelName)
if oaiModel, ok := openAIModelsMap[modelName]; ok {
oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(modelName)
userOpenAiModels = append(userOpenAiModels, oaiModel)
} else {
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: "custom",
SupportedEndpointTypes: model.GetModelSupportEndpointTypes(modelName),
})
}
}
}
ownerByModel := map[string]string{}
if len(ownerGroups) > 0 {
ownerByModel = getPreferredModelOwners(userModelNames, ownerGroups)
}
userOpenAiModels := make([]dto.OpenAIModels, 0, len(userModelNames))
for _, modelName := range userModelNames {
userOpenAiModels = append(userOpenAiModels, buildOpenAIModel(modelName, ownerByModel))
}
switch modelType {
case constant.ChannelTypeAnthropic:
useranthropicModels := make([]dto.AnthropicModel, len(userOpenAiModels))
-85
View File
@@ -1,85 +0,0 @@
package controller
import (
"net/http/httptest"
"testing"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestChannelOwnerNameUsesAdaptorChannelName(t *testing.T) {
tests := []struct {
name string
channelType int
expected string
}{
{
name: "openai",
channelType: constant.ChannelTypeOpenAI,
expected: "openai",
},
{
name: "codex",
channelType: constant.ChannelTypeCodex,
expected: "codex",
},
{
name: "openrouter",
channelType: constant.ChannelTypeOpenRouter,
expected: "openrouter",
},
{
name: "azure fallback",
channelType: constant.ChannelTypeAzure,
expected: "azure",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.expected, channelOwnerName(tt.channelType))
})
}
}
func TestBuildOpenAIModelOverridesOwnedBy(t *testing.T) {
modelItem := buildOpenAIModel("gpt-5.4", map[string]string{"gpt-5.4": "openai"})
require.Equal(t, "gpt-5.4", modelItem.Id)
require.Equal(t, "openai", modelItem.OwnedBy)
}
func TestBuildOpenAIModelFallsBackToCustomForUnknownModels(t *testing.T) {
modelItem := buildOpenAIModel("custom-test-model", nil)
require.Equal(t, "custom-test-model", modelItem.Id)
require.Equal(t, "custom", modelItem.OwnedBy)
}
func TestGetModelListGroupsUsesUserGroupWhenTokenGroupIsEmpty(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
common.SetContextKey(ctx, constant.ContextKeyUserGroup, "default")
groups, err := getModelListGroups(ctx)
require.NoError(t, err)
require.Equal(t, "default", groups.userGroup)
require.Empty(t, groups.tokenGroup)
require.Equal(t, []string{"default"}, groups.ownerGroups)
}
func TestGetModelListGroupsUsesExplicitTokenGroup(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
common.SetContextKey(ctx, constant.ContextKeyUserGroup, "default")
common.SetContextKey(ctx, constant.ContextKeyTokenGroup, "vip")
groups, err := getModelListGroups(ctx)
require.NoError(t, err)
require.Equal(t, "default", groups.userGroup)
require.Equal(t, "vip", groups.tokenGroup)
require.Equal(t, []string{"vip"}, groups.ownerGroups)
}
+9 -25
View File
@@ -3,11 +3,9 @@ package controller
import (
"fmt"
"net/http"
"strconv"
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/i18n"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting"
"github.com/QuantumNous/new-api/setting/console_setting"
@@ -29,17 +27,13 @@ var completionRatioMetaOptionKeys = []string{
"AudioCompletionRatio",
}
func isPaymentComplianceOptionKey(key string) bool {
return strings.HasPrefix(key, "payment_setting.compliance_")
}
func isPositiveOptionValue(value string) bool {
intValue, err := strconv.Atoi(strings.TrimSpace(value))
if err == nil {
return intValue > 0
func isVisiblePublicKeyOption(key string) bool {
switch key {
case "WaffoPancakeWebhookPublicKey", "WaffoPancakeWebhookTestKey":
return true
default:
return false
}
floatValue, err := strconv.ParseFloat(strings.TrimSpace(value), 64)
return err == nil && floatValue > 0
}
func collectModelNamesFromOptionValue(raw string, modelNames map[string]struct{}) {
@@ -86,7 +80,7 @@ func GetOptions(c *gin.Context) {
strings.HasSuffix(k, "Key") ||
strings.HasSuffix(k, "secret") ||
strings.HasSuffix(k, "api_key")
if isSensitiveKey {
if isSensitiveKey && !isVisiblePublicKeyOption(k) {
continue
}
options = append(options, &model.Option{
@@ -110,6 +104,7 @@ func GetOptions(c *gin.Context) {
"message": "",
"data": options,
})
return
}
type OptionUpdateRequest struct {
@@ -138,18 +133,6 @@ func UpdateOption(c *gin.Context) {
option.Value = fmt.Sprintf("%v", option.Value)
}
switch option.Key {
case "QuotaForInviter", "QuotaForInvitee":
if isPositiveOptionValue(option.Value.(string)) && !operation_setting.IsPaymentComplianceConfirmed() {
common.ApiErrorI18n(c, i18n.MsgPaymentComplianceRequired)
return
}
default:
if isPaymentComplianceOptionKey(option.Key) {
common.ApiErrorMsg(c, "合规确认字段不允许通过通用设置接口修改")
return
}
}
switch option.Key {
case "GitHubOAuthEnabled":
if option.Value == "true" && common.GitHubClientId == "" {
c.JSON(http.StatusOK, gin.H{
@@ -341,4 +324,5 @@ func UpdateOption(c *gin.Context) {
"success": true,
"message": "",
})
return
}
-5
View File
@@ -350,11 +350,6 @@ func AdminResetPasskey(c *gin.Context) {
common.ApiError(c, err)
return
}
myRole := c.GetInt("role")
if !canManageTargetRole(myRole, user.Role) {
common.ApiErrorMsg(c, "no permission")
return
}
if _, err := model.GetPasskeyByUserID(user.Id); err != nil {
if errors.Is(err, model.ErrPasskeyNotFound) {
-82
View File
@@ -1,82 +0,0 @@
package controller
import (
"fmt"
"net/http"
"strconv"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/i18n"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/gin-gonic/gin"
)
type PaymentComplianceRequest struct {
Confirmed bool `json:"confirmed"`
}
func requirePaymentCompliance(c *gin.Context) bool {
if !operation_setting.IsPaymentComplianceConfirmed() {
common.ApiErrorI18n(c, i18n.MsgPaymentComplianceRequired)
return false
}
return true
}
func ConfirmPaymentCompliance(c *gin.Context) {
if c.GetBool("use_access_token") {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "This operation requires dashboard session authentication. API access token is not allowed.",
})
return
}
var req PaymentComplianceRequest
if err := common.DecodeJson(c.Request.Body, &req); err != nil {
common.ApiErrorMsg(c, "参数错误")
return
}
if !req.Confirmed {
common.ApiErrorMsg(c, "请确认合规声明")
return
}
now := time.Now().Unix()
userId := c.GetInt("id")
clientIP := c.ClientIP()
updates := map[string]string{
"payment_setting.compliance_confirmed": "true",
"payment_setting.compliance_terms_version": operation_setting.CurrentComplianceTermsVersion,
"payment_setting.compliance_confirmed_at": strconv.FormatInt(now, 10),
"payment_setting.compliance_confirmed_by": strconv.Itoa(userId),
"payment_setting.compliance_confirmed_ip": clientIP,
}
for key, value := range updates {
if err := model.UpdateOption(key, value); err != nil {
common.ApiError(c, err)
return
}
}
logger.LogInfo(c.Request.Context(), fmt.Sprintf(
"payment compliance confirmed user_id=%d ip=%s terms_version=%s confirmed_at=%d",
userId,
clientIP,
operation_setting.CurrentComplianceTermsVersion,
now,
))
common.ApiSuccess(c, gin.H{
"confirmed": true,
"terms_version": operation_setting.CurrentComplianceTermsVersion,
"confirmed_at": now,
"confirmed_by": userId,
})
}
+11 -21
View File
@@ -7,14 +7,7 @@ import (
"github.com/QuantumNous/new-api/setting/operation_setting"
)
func isPaymentComplianceConfirmed() bool {
return operation_setting.IsPaymentComplianceConfirmed()
}
func isStripeTopUpEnabled() bool {
if !isPaymentComplianceConfirmed() {
return false
}
return strings.TrimSpace(setting.StripeApiSecret) != "" &&
strings.TrimSpace(setting.StripeWebhookSecret) != "" &&
strings.TrimSpace(setting.StripePriceId) != ""
@@ -29,9 +22,6 @@ func isStripeWebhookEnabled() bool {
}
func isCreemTopUpEnabled() bool {
if !isPaymentComplianceConfirmed() {
return false
}
products := strings.TrimSpace(setting.CreemProducts)
return strings.TrimSpace(setting.CreemApiKey) != "" &&
products != "" &&
@@ -47,9 +37,6 @@ func isCreemWebhookEnabled() bool {
}
func isWaffoTopUpEnabled() bool {
if !isPaymentComplianceConfirmed() {
return false
}
if !setting.WaffoEnabled {
return false
}
@@ -74,18 +61,24 @@ func isWaffoWebhookEnabled() bool {
}
func isWaffoPancakeTopUpEnabled() bool {
if !isPaymentComplianceConfirmed() {
if !setting.WaffoPancakeEnabled {
return false
}
// Presence-of-credentials = enabled. Webhook public keys ship inside
// the SDK; mode (test/prod) is read from each event.
return strings.TrimSpace(setting.WaffoPancakeMerchantID) != "" &&
return isWaffoPancakeWebhookConfigured() &&
strings.TrimSpace(setting.WaffoPancakeMerchantID) != "" &&
strings.TrimSpace(setting.WaffoPancakePrivateKey) != "" &&
strings.TrimSpace(setting.WaffoPancakeStoreID) != "" &&
strings.TrimSpace(setting.WaffoPancakeProductID) != ""
}
func isWaffoPancakeWebhookConfigured() bool {
return isWaffoPancakeTopUpEnabled()
currentWebhookKey := strings.TrimSpace(setting.WaffoPancakeWebhookPublicKey)
if setting.WaffoPancakeSandbox {
currentWebhookKey = strings.TrimSpace(setting.WaffoPancakeWebhookTestKey)
}
return currentWebhookKey != ""
}
func isWaffoPancakeWebhookEnabled() bool {
@@ -93,9 +86,6 @@ func isWaffoPancakeWebhookEnabled() bool {
}
func isEpayTopUpEnabled() bool {
if !isPaymentComplianceConfirmed() {
return false
}
return isEpayWebhookConfigured() && len(operation_setting.PayMethods) > 0
}
+23 -26
View File
@@ -8,21 +8,7 @@ import (
"github.com/stretchr/testify/require"
)
func confirmPaymentComplianceForTest(t *testing.T) {
t.Helper()
paymentSetting := operation_setting.GetPaymentSetting()
originalConfirmed := paymentSetting.ComplianceConfirmed
originalTermsVersion := paymentSetting.ComplianceTermsVersion
t.Cleanup(func() {
paymentSetting.ComplianceConfirmed = originalConfirmed
paymentSetting.ComplianceTermsVersion = originalTermsVersion
})
paymentSetting.ComplianceConfirmed = true
paymentSetting.ComplianceTermsVersion = operation_setting.CurrentComplianceTermsVersion
}
func TestStripeWebhookEnabledRequiresTopUpAndWebhookConfig(t *testing.T) {
confirmPaymentComplianceForTest(t)
originalAPISecret := setting.StripeApiSecret
originalWebhookSecret := setting.StripeWebhookSecret
originalPriceID := setting.StripePriceId
@@ -45,7 +31,6 @@ func TestStripeWebhookEnabledRequiresTopUpAndWebhookConfig(t *testing.T) {
}
func TestCreemWebhookEnabledRequiresTopUpAndWebhookConfig(t *testing.T) {
confirmPaymentComplianceForTest(t)
originalAPIKey := setting.CreemApiKey
originalProducts := setting.CreemProducts
originalWebhookSecret := setting.CreemWebhookSecret
@@ -68,7 +53,6 @@ func TestCreemWebhookEnabledRequiresTopUpAndWebhookConfig(t *testing.T) {
}
func TestWaffoWebhookEnabledRequiresTopUpAndWebhookConfig(t *testing.T) {
confirmPaymentComplianceForTest(t)
originalEnabled := setting.WaffoEnabled
originalSandbox := setting.WaffoSandbox
originalAPIKey := setting.WaffoApiKey
@@ -113,37 +97,50 @@ func TestWaffoWebhookEnabledRequiresTopUpAndWebhookConfig(t *testing.T) {
}
func TestWaffoPancakeWebhookEnabledRequiresTopUpAndWebhookConfig(t *testing.T) {
confirmPaymentComplianceForTest(t)
originalEnabled := setting.WaffoPancakeEnabled
originalSandbox := setting.WaffoPancakeSandbox
originalMerchantID := setting.WaffoPancakeMerchantID
originalPrivateKey := setting.WaffoPancakePrivateKey
originalWebhookPublicKey := setting.WaffoPancakeWebhookPublicKey
originalWebhookTestKey := setting.WaffoPancakeWebhookTestKey
originalStoreID := setting.WaffoPancakeStoreID
originalProductID := setting.WaffoPancakeProductID
t.Cleanup(func() {
setting.WaffoPancakeEnabled = originalEnabled
setting.WaffoPancakeSandbox = originalSandbox
setting.WaffoPancakeMerchantID = originalMerchantID
setting.WaffoPancakePrivateKey = originalPrivateKey
setting.WaffoPancakeWebhookPublicKey = originalWebhookPublicKey
setting.WaffoPancakeWebhookTestKey = originalWebhookTestKey
setting.WaffoPancakeStoreID = originalStoreID
setting.WaffoPancakeProductID = originalProductID
})
// Presence of all three credentials enables the gateway. Webhook public
// keys are bundled in the SDK and there is no separate Enabled toggle —
// clear any of the three fields to disable.
setting.WaffoPancakeMerchantID = ""
setting.WaffoPancakeEnabled = true
setting.WaffoPancakeSandbox = false
setting.WaffoPancakeMerchantID = "merchant"
setting.WaffoPancakePrivateKey = "private"
setting.WaffoPancakeStoreID = "store"
setting.WaffoPancakeProductID = "product"
setting.WaffoPancakeWebhookPublicKey = ""
require.False(t, isWaffoPancakeWebhookEnabled())
setting.WaffoPancakeMerchantID = "merchant"
setting.WaffoPancakeWebhookPublicKey = "public"
require.True(t, isWaffoPancakeWebhookEnabled())
setting.WaffoPancakeProductID = ""
setting.WaffoPancakeEnabled = false
require.False(t, isWaffoPancakeWebhookEnabled())
setting.WaffoPancakeProductID = "product"
setting.WaffoPancakePrivateKey = ""
setting.WaffoPancakeEnabled = true
setting.WaffoPancakeSandbox = true
setting.WaffoPancakeWebhookTestKey = ""
require.False(t, isWaffoPancakeWebhookEnabled())
setting.WaffoPancakeWebhookTestKey = "test_public"
require.True(t, isWaffoPancakeWebhookEnabled())
}
func TestEpayWebhookEnabledRequiresTopUpAndWebhookConfig(t *testing.T) {
confirmPaymentComplianceForTest(t)
originalPayAddress := operation_setting.PayAddress
originalEpayID := operation_setting.EpayId
originalEpayKey := operation_setting.EpayKey
+9 -8
View File
@@ -8,7 +8,6 @@ import (
"github.com/QuantumNous/new-api/setting/ratio_setting"
"github.com/gin-gonic/gin"
"github.com/samber/lo"
)
func GetPerfMetricsSummary(c *gin.Context) {
@@ -19,8 +18,7 @@ func GetPerfMetricsSummary(c *gin.Context) {
}
}
activeGroups := append(lo.Keys(ratio_setting.GetGroupRatioCopy()), "auto")
result, err := perfmetrics.QuerySummaryAll(hours, activeGroups)
result, err := perfmetrics.QuerySummaryAll(hours)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"success": false,
@@ -74,9 +72,12 @@ func GetPerfMetrics(c *gin.Context) {
}
func filterActiveGroups(groups []perfmetrics.GroupResult) []perfmetrics.GroupResult {
activeRatios := ratio_setting.GetGroupRatioCopy()
return lo.Filter(groups, func(g perfmetrics.GroupResult, _ int) bool {
_, ok := activeRatios[g.Group]
return ok || g.Group == "auto"
})
activeGroups := ratio_setting.GetGroupRatioCopy()
filtered := make([]perfmetrics.GroupResult, 0, len(groups))
for _, g := range groups {
if _, ok := activeGroups[g.Group]; ok || g.Group == "auto" {
filtered = append(filtered, g)
}
}
return filtered
}
+40
View File
@@ -3,11 +3,51 @@ package controller
import (
"net/http"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/service"
"github.com/gin-gonic/gin"
)
func isRankingsEnabled() bool {
common.OptionMapRWMutex.RLock()
raw := common.OptionMap["HeaderNavModules"]
common.OptionMapRWMutex.RUnlock()
if raw == "" {
return true
}
var parsed map[string]interface{}
if err := common.Unmarshal([]byte(raw), &parsed); err != nil {
return true
}
rankings, ok := parsed["rankings"]
if !ok {
return true
}
switch v := rankings.(type) {
case bool:
return v
case map[string]interface{}:
if enabled, ok := v["enabled"]; ok {
if b, ok := enabled.(bool); ok {
return b
}
}
return true
}
return true
}
func GetRankings(c *gin.Context) {
if !isRankingsEnabled() {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "rankings is disabled",
})
return
}
result, err := service.GetRankingsSnapshot(c.DefaultQuery("period", "week"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{
-6
View File
@@ -8,7 +8,6 @@ import (
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/i18n"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/gin-gonic/gin"
)
@@ -60,11 +59,6 @@ func GetRedemption(c *gin.Context) {
}
func AddRedemption(c *gin.Context) {
if !operation_setting.IsPaymentComplianceConfirmed() {
common.ApiErrorI18n(c, i18n.MsgPaymentComplianceRequired)
return
}
redemption := model.Redemption{}
err := c.ShouldBindJSON(&redemption)
if err != nil {
+2 -2
View File
@@ -88,7 +88,7 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
defer func() {
if newAPIError != nil {
logger.LogError(c, fmt.Sprintf("relay error: %s", common.LocalLogPreview(newAPIError.Error())))
logger.LogError(c, fmt.Sprintf("relay error: %s", newAPIError.Error()))
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
switch relayFormat {
case types.RelayFormatOpenAIRealtime:
@@ -354,7 +354,7 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b
}
func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
logger.LogError(c, fmt.Sprintf("channel error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, common.LocalLogPreview(err.Error())))
logger.LogError(c, fmt.Sprintf("channel error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
if service.ShouldDisableChannel(err) && channelError.AutoBan {
-58
View File
@@ -6,7 +6,6 @@ import (
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/QuantumNous/new-api/setting/ratio_setting"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
@@ -22,18 +21,9 @@ type BillingPreferenceRequest struct {
BillingPreference string `json:"billing_preference"`
}
type SubscriptionBalancePayRequest struct {
PlanId int `json:"plan_id"`
}
// ---- User APIs ----
func GetSubscriptionPlans(c *gin.Context) {
if !operation_setting.IsPaymentComplianceConfirmed() {
common.ApiSuccess(c, []SubscriptionPlanDTO{})
return
}
var plans []model.SubscriptionPlan
if err := model.DB.Where("enabled = ?", true).Order("sort_order desc, id desc").Find(&plans).Error; err != nil {
common.ApiError(c, err)
@@ -41,7 +31,6 @@ func GetSubscriptionPlans(c *gin.Context) {
}
result := make([]SubscriptionPlanDTO, 0, len(plans))
for _, p := range plans {
p.NormalizeDefaults()
result = append(result, SubscriptionPlanDTO{
Plan: p,
})
@@ -97,25 +86,6 @@ func UpdateSubscriptionPreference(c *gin.Context) {
common.ApiSuccess(c, gin.H{"billing_preference": pref})
}
func SubscriptionRequestBalancePay(c *gin.Context) {
if !requirePaymentCompliance(c) {
return
}
userId := c.GetInt("id")
var req SubscriptionBalancePayRequest
if err := c.ShouldBindJSON(&req); err != nil || req.PlanId <= 0 {
common.ApiErrorMsg(c, "参数错误")
return
}
if err := model.PurchaseSubscriptionWithBalance(userId, req.PlanId); err != nil {
common.ApiError(c, err)
return
}
common.ApiSuccess(c, nil)
}
// ---- Admin APIs ----
func AdminListSubscriptionPlans(c *gin.Context) {
@@ -126,7 +96,6 @@ func AdminListSubscriptionPlans(c *gin.Context) {
}
result := make([]SubscriptionPlanDTO, 0, len(plans))
for _, p := range plans {
p.NormalizeDefaults()
result = append(result, SubscriptionPlanDTO{
Plan: p,
})
@@ -139,10 +108,6 @@ type AdminUpsertSubscriptionPlanRequest struct {
}
func AdminCreateSubscriptionPlan(c *gin.Context) {
if !requirePaymentCompliance(c) {
return
}
var req AdminUpsertSubscriptionPlanRequest
if err := c.ShouldBindJSON(&req); err != nil {
common.ApiErrorMsg(c, "参数错误")
@@ -165,9 +130,6 @@ func AdminCreateSubscriptionPlan(c *gin.Context) {
req.Plan.Currency = "USD"
}
req.Plan.Currency = "USD"
if req.Plan.AllowBalancePay == nil {
req.Plan.AllowBalancePay = common.GetPointer(true)
}
if req.Plan.DurationUnit == "" {
req.Plan.DurationUnit = model.SubscriptionDurationMonth
}
@@ -204,10 +166,6 @@ func AdminCreateSubscriptionPlan(c *gin.Context) {
}
func AdminUpdateSubscriptionPlan(c *gin.Context) {
if !requirePaymentCompliance(c) {
return
}
id, _ := strconv.Atoi(c.Param("id"))
if id <= 0 {
common.ApiErrorMsg(c, "无效的ID")
@@ -276,7 +234,6 @@ func AdminUpdateSubscriptionPlan(c *gin.Context) {
"sort_order": req.Plan.SortOrder,
"stripe_price_id": req.Plan.StripePriceId,
"creem_product_id": req.Plan.CreemProductId,
"waffo_pancake_product_id": req.Plan.WaffoPancakeProductId,
"max_purchase_per_user": req.Plan.MaxPurchasePerUser,
"total_amount": req.Plan.TotalAmount,
"upgrade_group": req.Plan.UpgradeGroup,
@@ -284,9 +241,6 @@ func AdminUpdateSubscriptionPlan(c *gin.Context) {
"quota_reset_custom_seconds": req.Plan.QuotaResetCustomSeconds,
"updated_at": common.GetTimestamp(),
}
if req.Plan.AllowBalancePay != nil {
updateMap["allow_balance_pay"] = *req.Plan.AllowBalancePay
}
if err := tx.Model(&model.SubscriptionPlan{}).Where("id = ?", id).Updates(updateMap).Error; err != nil {
return err
}
@@ -305,10 +259,6 @@ type AdminUpdateSubscriptionPlanStatusRequest struct {
}
func AdminUpdateSubscriptionPlanStatus(c *gin.Context) {
if !requirePaymentCompliance(c) {
return
}
id, _ := strconv.Atoi(c.Param("id"))
if id <= 0 {
common.ApiErrorMsg(c, "无效的ID")
@@ -333,10 +283,6 @@ type AdminBindSubscriptionRequest struct {
}
func AdminBindSubscription(c *gin.Context) {
if !requirePaymentCompliance(c) {
return
}
var req AdminBindSubscriptionRequest
if err := c.ShouldBindJSON(&req); err != nil || req.UserId <= 0 || req.PlanId <= 0 {
common.ApiErrorMsg(c, "参数错误")
@@ -376,10 +322,6 @@ type AdminCreateUserSubscriptionRequest struct {
// AdminCreateUserSubscription creates a new user subscription from a plan (no payment).
func AdminCreateUserSubscription(c *gin.Context) {
if !requirePaymentCompliance(c) {
return
}
userId, _ := strconv.Atoi(c.Param("id"))
if userId <= 0 {
common.ApiErrorMsg(c, "无效的用户ID")
-4
View File
@@ -21,10 +21,6 @@ type SubscriptionCreemPayRequest struct {
}
func SubscriptionRequestCreemPay(c *gin.Context) {
if !requirePaymentCompliance(c) {
return
}
var req SubscriptionCreemPayRequest
// Keep body for debugging consistency (like RequestCreemPay)
-4
View File
@@ -22,10 +22,6 @@ type SubscriptionEpayPayRequest struct {
}
func SubscriptionRequestEpay(c *gin.Context) {
if !requirePaymentCompliance(c) {
return
}
var req SubscriptionEpayPayRequest
if err := c.ShouldBindJSON(&req); err != nil || req.PlanId <= 0 {
common.ApiErrorMsg(c, "参数错误")
@@ -21,10 +21,6 @@ type SubscriptionStripePayRequest struct {
}
func SubscriptionRequestStripePay(c *gin.Context) {
if !requirePaymentCompliance(c) {
return
}
var req SubscriptionStripePayRequest
if err := c.ShouldBindJSON(&req); err != nil || req.PlanId <= 0 {
common.ApiErrorMsg(c, "参数错误")
@@ -1,130 +0,0 @@
package controller
import (
"fmt"
"net/http"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting"
"github.com/gin-gonic/gin"
"github.com/shopspring/decimal"
"github.com/thanhpk/randstr"
)
type SubscriptionWaffoPancakePayRequest struct {
PlanId int `json:"plan_id"`
}
func SubscriptionRequestWaffoPancakePay(c *gin.Context) {
if !requirePaymentCompliance(c) {
return
}
var req SubscriptionWaffoPancakePayRequest
if err := c.ShouldBindJSON(&req); err != nil || req.PlanId <= 0 {
common.ApiErrorMsg(c, "参数错误")
return
}
plan, err := model.GetSubscriptionPlanById(req.PlanId)
if err != nil {
common.ApiError(c, err)
return
}
if !plan.Enabled {
common.ApiErrorMsg(c, "套餐未启用")
return
}
if strings.TrimSpace(plan.WaffoPancakeProductId) == "" {
common.ApiErrorMsg(c, "该套餐未配置 WaffoPancakeProductId")
return
}
// Plan targets its own Pancake product, so we only require credentials
// here — not the gateway-level WaffoPancakeProductID.
if strings.TrimSpace(setting.WaffoPancakeMerchantID) == "" ||
strings.TrimSpace(setting.WaffoPancakePrivateKey) == "" {
common.ApiErrorMsg(c, "Waffo Pancake 未配置或密钥无效")
return
}
userId := c.GetInt("id")
user, err := model.GetUserById(userId, false)
if err != nil {
common.ApiError(c, err)
return
}
if user == nil {
common.ApiErrorMsg(c, "用户不存在")
return
}
if plan.MaxPurchasePerUser > 0 {
count, err := model.CountUserSubscriptionsByPlan(userId, plan.Id)
if err != nil {
common.ApiError(c, err)
return
}
if count >= int64(plan.MaxPurchasePerUser) {
common.ApiErrorMsg(c, "已达到该套餐购买上限")
return
}
}
// WAFFO_PANCAKE_SUB- prefix (vs. wallet's WAFFO_PANCAKE-) drives webhook
// dispatch in WaffoPancakeWebhook.
tradeNo := fmt.Sprintf("WAFFO_PANCAKE_SUB-%d-%d-%s", userId, time.Now().UnixMilli(), randstr.String(6))
order := &model.SubscriptionOrder{
UserId: userId,
PlanId: plan.Id,
Money: plan.PriceAmount,
TradeNo: tradeNo,
PaymentMethod: model.PaymentMethodWaffoPancake,
PaymentProvider: model.PaymentProviderWaffoPancake,
CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending,
}
if err := order.Insert(); err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake 订阅订单创建失败 user_id=%d plan_id=%d trade_no=%s error=%q", userId, plan.Id, tradeNo, err.Error()))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"})
return
}
expiresInSeconds := 45 * 60
session, err := service.CreateWaffoPancakeCheckoutSession(c.Request.Context(), &service.WaffoPancakeCreateSessionParams{
ProductID: plan.WaffoPancakeProductId,
BuyerIdentity: service.WaffoPancakeBuyerIdentityFromUserID(user.Id),
PriceSnapshot: &service.WaffoPancakePriceSnapshot{
Amount: decimal.NewFromFloat(plan.PriceAmount).StringFixed(2),
TaxCategory: "saas",
},
BuyerEmail: getWaffoPancakeBuyerEmail(user),
ExpiresInSeconds: &expiresInSeconds,
OrderMerchantExternalID: tradeNo,
})
if err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake 订阅结账会话创建失败 user_id=%d plan_id=%d trade_no=%s error=%q", userId, plan.Id, tradeNo, err.Error()))
order.Status = common.TopUpStatusFailed
_ = order.Update()
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉起支付失败"})
return
}
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo Pancake 订阅订单创建成功 user_id=%d plan_id=%d trade_no=%s session_id=%s money=%.2f", userId, plan.Id, tradeNo, session.SessionID, plan.PriceAmount))
c.JSON(http.StatusOK, gin.H{
"message": "success",
"data": gin.H{
"checkout_url": session.CheckoutURL,
"session_id": session.SessionID,
"expires_at": session.ExpiresAt,
"order_id": tradeNo,
"token": session.Token,
"token_expires_at": session.TokenExpiresAt,
},
})
}
+3 -3
View File
@@ -96,13 +96,13 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
}
logger.LogDebug(ctx, "UpdateVideoSingleTask response: %s", responseBody)
logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask response: %s", string(responseBody)))
taskResult := &relaycommon.TaskInfo{}
// try parse as New API response format
var responseItems dto.TaskResponse[model.Task]
if err = common.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
logger.LogDebug(ctx, "UpdateVideoSingleTask parsed as new api response format: %+v", responseItems)
logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask parsed as new api response format: %+v", responseItems))
t := responseItems.Data
taskResult.TaskID = t.TaskID
taskResult.Status = string(t.Status)
@@ -116,7 +116,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
task.Data = redactVideoResponseBody(responseBody)
}
logger.LogDebug(ctx, "UpdateVideoSingleTask taskResult: %+v", taskResult)
logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask taskResult: %+v", taskResult))
now := time.Now().Unix()
if taskResult.Status == "" {
+25 -34
View File
@@ -22,13 +22,8 @@ import (
)
func GetTopUpInfo(c *gin.Context) {
complianceConfirmed := operation_setting.IsPaymentComplianceConfirmed()
// 获取支付方式
payMethods := operation_setting.PayMethods
if !complianceConfirmed {
payMethods = []map[string]string{}
}
// 如果启用了 Stripe 支付,添加到支付方法列表
if isStripeTopUpEnabled() {
@@ -52,27 +47,6 @@ func GetTopUpInfo(c *gin.Context) {
}
}
// Waffo Pancake displayed above the legacy Waffo gateway.
enableWaffoPancake := isWaffoPancakeTopUpEnabled()
if enableWaffoPancake {
hasWaffoPancake := false
for _, method := range payMethods {
if method["type"] == model.PaymentMethodWaffoPancake {
hasWaffoPancake = true
break
}
}
if !hasWaffoPancake {
payMethods = append(payMethods, map[string]string{
"name": "Waffo Pancake",
"type": model.PaymentMethodWaffoPancake,
"color": "rgba(var(--semi-orange-5), 1)",
"min_topup": strconv.Itoa(setting.WaffoPancakeMinTopUp),
})
}
}
// 如果启用了 Waffo 支付,添加到支付方法列表
enableWaffo := isWaffoTopUpEnabled()
if enableWaffo {
@@ -95,15 +69,32 @@ func GetTopUpInfo(c *gin.Context) {
}
}
enableWaffoPancake := isWaffoPancakeTopUpEnabled()
if enableWaffoPancake {
hasWaffoPancake := false
for _, method := range payMethods {
if method["type"] == model.PaymentMethodWaffoPancake {
hasWaffoPancake = true
break
}
}
if !hasWaffoPancake {
payMethods = append(payMethods, map[string]string{
"name": "Waffo Pancake",
"type": model.PaymentMethodWaffoPancake,
"color": "rgba(var(--semi-orange-5), 1)",
"min_topup": strconv.Itoa(setting.WaffoPancakeMinTopUp),
})
}
}
data := gin.H{
"enable_online_topup": isEpayTopUpEnabled(),
"enable_stripe_topup": isStripeTopUpEnabled(),
"enable_creem_topup": isCreemTopUpEnabled(),
"enable_waffo_topup": enableWaffo,
"enable_waffo_pancake_topup": enableWaffoPancake,
"enable_redemption": complianceConfirmed,
"payment_compliance_confirmed": complianceConfirmed,
"payment_compliance_terms_version": operation_setting.CurrentComplianceTermsVersion,
"enable_online_topup": isEpayTopUpEnabled(),
"enable_stripe_topup": isStripeTopUpEnabled(),
"enable_creem_topup": isCreemTopUpEnabled(),
"enable_waffo_topup": enableWaffo,
"enable_waffo_pancake_topup": enableWaffoPancake,
"waffo_pay_methods": func() interface{} {
if enableWaffo {
return setting.GetWaffoPayMethods()
+33 -311
View File
@@ -96,257 +96,33 @@ func getWaffoPancakeBuyerEmail(user *model.User) string {
if user != nil && strings.TrimSpace(user.Email) != "" {
return user.Email
}
if user != nil {
return fmt.Sprintf("%d@new-api.local", user.Id)
}
return ""
}
// The admin config endpoints below accept typed-but-not-yet-saved creds in
// the body and fall back to persisted creds when the body is blank (see
// resolveWaffoPancakeAdminCreds). Only SaveWaffoPancake writes to OptionMap.
type waffoPancakeCredsRequest struct {
MerchantID string `json:"merchant_id"`
PrivateKey string `json:"private_key"`
}
type saveWaffoPancakeRequest struct {
MerchantID string `json:"merchant_id"`
PrivateKey string `json:"private_key"`
ReturnURL string `json:"return_url"`
StoreID string `json:"store_id"`
ProductID string `json:"product_id"`
}
type createWaffoPancakePairRequest struct {
MerchantID string `json:"merchant_id"`
PrivateKey string `json:"private_key"`
ReturnURL string `json:"return_url"`
}
// SaveWaffoPancake atomically persists all five operator-controlled fields.
// Catalog / pair endpoints are transient — only this one writes the OptionMap.
func SaveWaffoPancake(c *gin.Context) {
var req saveWaffoPancakeRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
return
func getWaffoPancakeReturnURL() string {
if strings.TrimSpace(setting.WaffoPancakeReturnURL) != "" {
return setting.WaffoPancakeReturnURL
}
if err := service.SaveWaffoPancakeConfig(
c.Request.Context(),
req.MerchantID,
req.PrivateKey,
req.ReturnURL,
req.StoreID,
req.ProductID,
); err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf(
"Waffo Pancake 保存配置失败 store_id=%q product_id=%q error=%q",
req.StoreID, req.ProductID, err.Error(),
))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "保存配置失败"})
return
}
c.JSON(http.StatusOK, gin.H{
"message": "success",
"data": gin.H{
"product_id": setting.WaffoPancakeProductID,
"store_id": setting.WaffoPancakeStoreID,
},
})
}
// resolveWaffoPancakeAdminCreds prefers body creds (typed-but-not-yet-saved
// values, for verification) and falls back to persisted creds when the body
// is blank (so returning admins don't have to re-paste the private key,
// which is stripped from GET /api/option/).
func resolveWaffoPancakeAdminCreds(bodyMerchantID, bodyPrivateKey string) (string, string) {
m := strings.TrimSpace(bodyMerchantID)
k := strings.TrimSpace(bodyPrivateKey)
if m == "" && k == "" {
return setting.WaffoPancakeMerchantID, setting.WaffoPancakePrivateKey
}
return m, k
}
// CreateWaffoPancakePair mints a Store + OnetimeProduct pair in one round-
// trip. Surfaces an orphan-store flag when the product half fails so the
// frontend can preselect / retry without losing context.
func CreateWaffoPancakePair(c *gin.Context) {
var req createWaffoPancakePairRequest
if c.Request.ContentLength > 0 {
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
return
}
}
merchantID, privateKey := resolveWaffoPancakeAdminCreds(req.MerchantID, req.PrivateKey)
if merchantID == "" || privateKey == "" {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "Waffo Pancake 凭证未配置"})
return
}
result, err := service.CreateWaffoPancakePrimaryPair(
c.Request.Context(), merchantID, privateKey, req.ReturnURL,
)
if err != nil {
orphan := result != nil && result.OrphanStore
logger.LogError(c.Request.Context(), fmt.Sprintf(
"Waffo Pancake 创建店铺与产品失败 orphan_store=%t store_id=%q error=%q",
orphan, func() string {
if result == nil {
return ""
}
return result.StoreID
}(), err.Error(),
))
data := gin.H{"error": err.Error()}
if orphan {
data["store_id"] = result.StoreID
data["store_name"] = result.StoreName
data["orphan_store"] = true
}
c.JSON(http.StatusOK, gin.H{"message": "error", "data": data})
return
}
c.JSON(http.StatusOK, gin.H{
"message": "success",
"data": gin.H{
"store_id": result.StoreID,
"store_name": result.StoreName,
"product_id": result.ProductID,
"product_name": result.ProductName,
},
})
}
// ListWaffoPancakeCatalog returns the merchant's Stores + OnetimeProducts.
// Doubles as a credential probe (a successful 200 proves the resolved creds
// authenticate). See resolveWaffoPancakeAdminCreds for credential resolution.
func ListWaffoPancakeCatalog(c *gin.Context) {
var req waffoPancakeCredsRequest
// An empty body means "use persisted creds"; only fail on malformed JSON.
if c.Request.ContentLength > 0 {
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
return
}
}
merchantID, privateKey := resolveWaffoPancakeAdminCreds(req.MerchantID, req.PrivateKey)
if merchantID == "" || privateKey == "" {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "Waffo Pancake 凭证未配置"})
return
}
catalog, err := service.ListWaffoPancakeCatalog(c.Request.Context(), merchantID, privateKey)
if err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf(
"Waffo Pancake 拉取店铺与产品目录失败 error=%q", err.Error(),
))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉取目录失败"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "success", "data": catalog})
}
type createWaffoPancakeSubscriptionProductRequest struct {
Name string `json:"name"`
Amount string `json:"amount"`
}
// CreateWaffoPancakeSubscriptionProduct mints an OnetimeProduct (not
// SubscriptionProduct — see service.CreateWaffoPancakeProductForPlan)
// sized to a plan's `name` + `amount`, using persisted Pancake credentials
// + StoreID. Reads from the form, not the plan row, so newly-typed unsaved
// plans can mint a product too.
func CreateWaffoPancakeSubscriptionProduct(c *gin.Context) {
var req createWaffoPancakeSubscriptionProductRequest
if c.Request.ContentLength > 0 {
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
return
}
}
if strings.TrimSpace(req.Name) == "" {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "套餐名称不能为空"})
return
}
if strings.TrimSpace(req.Amount) == "" {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "套餐价格不能为空"})
return
}
merchantID, privateKey := resolveWaffoPancakeAdminCreds("", "")
storeID := strings.TrimSpace(setting.WaffoPancakeStoreID)
if merchantID == "" || privateKey == "" || storeID == "" {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "Waffo Pancake 未完成配置,请先在支付设置中完成网关绑定"})
return
}
productID, err := service.CreateWaffoPancakeProductForPlan(
c.Request.Context(),
merchantID,
privateKey,
storeID,
req.Name,
req.Amount,
setting.WaffoPancakeReturnURL,
)
if err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf(
"Waffo Pancake 创建套餐产品失败 store_id=%q name=%q amount=%q error=%q",
storeID, req.Name, req.Amount, err.Error(),
))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建套餐产品失败"})
return
}
c.JSON(http.StatusOK, gin.H{
"message": "success",
"data": gin.H{
"product_id": productID,
"product_name": req.Name,
"store_id": storeID,
},
})
}
// ListWaffoPancakeSubscriptionProductOptions returns the OnetimeProducts
// in the saved Pancake store, for the subscription-plan dropdown. The name
// reflects new-api's plan concept; under the hood it's still OnetimeProducts.
func ListWaffoPancakeSubscriptionProductOptions(c *gin.Context) {
merchantID, privateKey := resolveWaffoPancakeAdminCreds("", "")
storeID := strings.TrimSpace(setting.WaffoPancakeStoreID)
if merchantID == "" || privateKey == "" || storeID == "" {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "Waffo Pancake 未完成配置,请先在支付设置中完成网关绑定"})
return
}
catalog, err := service.ListWaffoPancakeCatalog(c.Request.Context(), merchantID, privateKey)
if err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf(
"Waffo Pancake 拉取订阅产品列表失败 store_id=%q error=%q", storeID, err.Error(),
))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉取产品列表失败"})
return
}
products := []service.WaffoPancakeCatalogProduct{}
for _, store := range catalog.Stores {
if store.ID == storeID {
products = store.OnetimeProducts
break
}
}
c.JSON(http.StatusOK, gin.H{
"message": "success",
"data": gin.H{
"store_id": storeID,
"products": products,
},
})
}
func getWaffoPancakeBuyerIdentity(user *model.User) string {
if user == nil {
return ""
}
return service.WaffoPancakeBuyerIdentityFromUserID(user.Id)
return paymentReturnPath("/console/topup?show_history=true")
}
func RequestWaffoPancakePay(c *gin.Context) {
if !isWaffoPancakeTopUpEnabled() {
if !setting.WaffoPancakeEnabled {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "Waffo Pancake 支付未启用"})
return
}
currentWebhookKey := setting.WaffoPancakeWebhookPublicKey
if setting.WaffoPancakeSandbox {
currentWebhookKey = setting.WaffoPancakeWebhookTestKey
}
if strings.TrimSpace(setting.WaffoPancakeMerchantID) == "" ||
strings.TrimSpace(setting.WaffoPancakePrivateKey) == "" ||
strings.TrimSpace(currentWebhookKey) == "" ||
strings.TrimSpace(setting.WaffoPancakeStoreID) == "" ||
strings.TrimSpace(setting.WaffoPancakeProductID) == "" {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "Waffo Pancake 配置不完整"})
return
}
@@ -399,15 +175,18 @@ func RequestWaffoPancakePay(c *gin.Context) {
expiresInSeconds := 45 * 60
session, err := service.CreateWaffoPancakeCheckoutSession(c.Request.Context(), &service.WaffoPancakeCreateSessionParams{
ProductID: setting.WaffoPancakeProductID,
BuyerIdentity: getWaffoPancakeBuyerIdentity(user),
StoreID: setting.WaffoPancakeStoreID,
ProductID: setting.WaffoPancakeProductID,
ProductType: "onetime",
Currency: strings.ToUpper(strings.TrimSpace(setting.WaffoPancakeCurrency)),
PriceSnapshot: &service.WaffoPancakePriceSnapshot{
Amount: formatWaffoPancakeAmount(payMoney),
TaxIncluded: false,
TaxCategory: "saas",
},
BuyerEmail: getWaffoPancakeBuyerEmail(user),
ExpiresInSeconds: &expiresInSeconds,
OrderMerchantExternalID: tradeNo,
BuyerEmail: getWaffoPancakeBuyerEmail(user),
SuccessURL: getWaffoPancakeReturnURL(),
ExpiresInSeconds: &expiresInSeconds,
})
if err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake 创建结账会话失败 user_id=%d trade_no=%s error=%q", id, tradeNo, err.Error()))
@@ -421,12 +200,10 @@ func RequestWaffoPancakePay(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"message": "success",
"data": gin.H{
"checkout_url": session.CheckoutURL,
"session_id": session.SessionID,
"expires_at": session.ExpiresAt,
"order_id": tradeNo,
"token": session.Token,
"token_expires_at": session.TokenExpiresAt,
"checkout_url": session.CheckoutURL,
"session_id": session.SessionID,
"expires_at": session.ExpiresAt,
"order_id": tradeNo,
},
})
}
@@ -438,19 +215,6 @@ func WaffoPancakeWebhook(c *gin.Context) {
return
}
// :env splits test vs prod traffic at the routing layer — operator
// registers each URL in the matching webhook slot in Pancake's dashboard.
// We then enforce event.mode == expectedEnv to catch mis-registrations.
expectedEnv := strings.TrimSpace(c.Param("env"))
if expectedEnv != "test" && expectedEnv != "prod" {
logger.LogWarn(c.Request.Context(), fmt.Sprintf(
"Waffo Pancake webhook 路径环境段无效 env=%q path=%q client_ip=%s",
expectedEnv, c.Request.RequestURI, c.ClientIP(),
))
c.String(http.StatusNotFound, "unknown env")
return
}
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake webhook 读取请求体失败 path=%q client_ip=%s error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
@@ -468,57 +232,15 @@ func WaffoPancakeWebhook(c *gin.Context) {
return
}
if !strings.EqualFold(strings.TrimSpace(event.Mode), expectedEnv) {
logger.LogError(c.Request.Context(), fmt.Sprintf(
"Waffo Pancake webhook 环境不匹配 expected=%q actual_mode=%q event_id=%s order_id=%s client_ip=%s",
expectedEnv, event.Mode, event.ID, event.Data.OrderID, c.ClientIP(),
))
c.String(http.StatusOK, "OK")
return
}
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo Pancake webhook 验签成功 event_type=%s event_id=%s order_id=%s client_ip=%s", event.NormalizedEventType(), event.ID, event.Data.OrderID, c.ClientIP()))
if event.NormalizedEventType() != "order.completed" {
c.String(http.StatusOK, "OK")
return
}
// Dispatch by trade_no prefix. OrderMerchantExternalID = our trade_no;
// OrderID is Pancake's internal ORD_* (logs only).
rawTradeNo := strings.TrimSpace(event.Data.OrderMerchantExternalID)
isSubscription := strings.HasPrefix(rawTradeNo, "WAFFO_PANCAKE_SUB-")
if isSubscription {
tradeNo, err := service.ResolveWaffoPancakeSubscriptionTradeNo(event)
if err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf(
"Waffo Pancake webhook 订阅订单解析失败 event_id=%s order_id=%s buyer_identity=%q client_ip=%s error=%q",
event.ID, event.Data.OrderID, event.Data.MerchantProvidedBuyerIdentity, c.ClientIP(), err.Error(),
))
c.String(http.StatusOK, "OK")
return
}
LockOrder(tradeNo)
defer UnlockOrder(tradeNo)
if err := model.CompleteSubscriptionOrder(tradeNo, string(bodyBytes), model.PaymentProviderWaffoPancake, ""); err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake 订阅完成失败 trade_no=%s event_id=%s order_id=%s client_ip=%s error=%q", tradeNo, event.ID, event.Data.OrderID, c.ClientIP(), err.Error()))
c.String(http.StatusInternalServerError, "retry")
return
}
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo Pancake 订阅完成 trade_no=%s event_id=%s order_id=%s client_ip=%s", tradeNo, event.ID, event.Data.OrderID, c.ClientIP()))
c.String(http.StatusOK, "OK")
return
}
tradeNo, err := service.ResolveWaffoPancakeTradeNo(event)
if err != nil {
// LogError (not LogWarn): covers order-not-found and buyer-identity
// mismatch — both warrant human attention. 200 OK so Waffo doesn't
// retry a permanently-unresolvable webhook.
logger.LogError(c.Request.Context(), fmt.Sprintf(
"Waffo Pancake webhook 订单解析失败 event_id=%s order_id=%s buyer_identity=%q client_ip=%s error=%q",
event.ID, event.Data.OrderID, event.Data.MerchantProvidedBuyerIdentity, c.ClientIP(), err.Error(),
))
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Waffo Pancake webhook 订单号映射失败 event_id=%s order_id=%s error=%q", event.ID, event.Data.OrderID, err.Error()))
c.String(http.StatusOK, "OK")
return
}
+1 -1
View File
@@ -520,7 +520,7 @@ func AdminDisable2FA(c *gin.Context) {
}
myRole := c.GetInt("role")
if !canManageTargetRole(myRole, targetUser.Role) {
if myRole <= targetUser.Role && myRole != common.RoleRootUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无权操作同级或更高级用户的2FA设置",
+10 -38
View File
@@ -17,7 +17,6 @@ import (
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting"
"github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/QuantumNous/new-api/constant"
@@ -251,20 +250,8 @@ func GetAllUsers(c *gin.Context) {
func SearchUsers(c *gin.Context) {
keyword := c.Query("keyword")
group := c.Query("group")
var role *int
if roleStr := c.Query("role"); roleStr != "" {
if parsed, err := strconv.Atoi(roleStr); err == nil {
role = &parsed
}
}
var status *int
if statusStr := c.Query("status"); statusStr != "" {
if parsed, err := strconv.Atoi(statusStr); err == nil {
status = &parsed
}
}
pageInfo := common.GetPageQuery(c)
users, total, err := model.SearchUsers(keyword, group, role, status, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
users, total, err := model.SearchUsers(keyword, group, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
if err != nil {
common.ApiError(c, err)
return
@@ -276,10 +263,6 @@ func SearchUsers(c *gin.Context) {
return
}
func canManageTargetRole(myRole int, targetRole int) bool {
return myRole == common.RoleRootUser || myRole > targetRole
}
func GetUser(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
@@ -292,7 +275,7 @@ func GetUser(c *gin.Context) {
return
}
myRole := c.GetInt("role")
if !canManageTargetRole(myRole, user.Role) {
if myRole <= user.Role && myRole != common.RoleRootUser {
common.ApiErrorI18n(c, i18n.MsgUserNoPermissionSameLevel)
return
}
@@ -344,10 +327,6 @@ type TransferAffQuotaRequest struct {
}
func TransferAffQuota(c *gin.Context) {
if !requirePaymentCompliance(c) {
return
}
id := c.GetInt("id")
user, err := model.GetUserById(id, true)
if err != nil {
@@ -583,11 +562,11 @@ func UpdateUser(c *gin.Context) {
return
}
myRole := c.GetInt("role")
if !canManageTargetRole(myRole, originUser.Role) {
if myRole <= originUser.Role && myRole != common.RoleRootUser {
common.ApiErrorI18n(c, i18n.MsgUserNoPermissionHigherLevel)
return
}
if !canManageTargetRole(myRole, updatedUser.Role) {
if myRole <= updatedUser.Role && myRole != common.RoleRootUser {
common.ApiErrorI18n(c, i18n.MsgUserCannotCreateHigherLevel)
return
}
@@ -626,7 +605,7 @@ func AdminClearUserBinding(c *gin.Context) {
}
myRole := c.GetInt("role")
if !canManageTargetRole(myRole, user.Role) {
if myRole <= user.Role && myRole != common.RoleRootUser {
common.ApiErrorI18n(c, i18n.MsgUserNoPermissionSameLevel)
return
}
@@ -794,14 +773,12 @@ func DeleteUser(c *gin.Context) {
}
err = model.HardDeleteUserById(id)
if err != nil {
common.ApiError(c, err)
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
return
}
func DeleteSelf(c *gin.Context) {
@@ -890,7 +867,7 @@ func ManageUser(c *gin.Context) {
return
}
myRole := c.GetInt("role")
if !canManageTargetRole(myRole, user.Role) {
if myRole <= user.Role && myRole != common.RoleRootUser {
common.ApiErrorI18n(c, i18n.MsgUserNoPermissionHigherLevel)
return
}
@@ -1104,11 +1081,6 @@ func getTopUpLock(userID int) *topUpTryLock {
}
func TopUp(c *gin.Context) {
if !operation_setting.IsPaymentComplianceConfirmed() {
common.ApiErrorI18n(c, i18n.MsgPaymentComplianceRequired)
return
}
id := c.GetInt("id")
lock := getTopUpLock(id)
if !lock.TryLock() {
-1
View File
@@ -34,7 +34,6 @@ services:
- BATCH_UPDATE_ENABLED=true # 是否启用批量更新 (Whether to enable batch update)
- NODE_NAME=new-api-node-1 # 节点名称,用于审计日志中标识节点身份;多节点/容器部署时建议设置 (Node name used in audit logs; recommended when running multiple instances or in containers)
# - STREAMING_TIMEOUT=300 # 流模式无响应超时时间,单位秒,默认120秒,如果出现空补全可以尝试改为更大值 (Streaming timeout in seconds, default is 120s. Increase if experiencing empty completions
# - RELAY_IDLE_CONN_TIMEOUT=90 # Relay HTTP 客户端空闲连接超时时间,单位秒,默认跟随 Go 标准库,设置为0表示不限制 (Relay HTTP client idle keep-alive timeout in seconds, defaults to Go standard library; set 0 to disable)
# - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!! (multi-node deployment, set this to a random string!!!!!!!
# - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed
# - GOOGLE_ANALYTICS_ID=G-XXXXXXXXXX # Google Analytics 的测量 ID (Google Analytics Measurement ID)
+2 -12
View File
@@ -213,22 +213,12 @@ func (r *GeneralOpenAIRequest) ToMap() map[string]any {
return result
}
func IsOpenAIReasoningOModel(modelName string) bool {
return strings.HasPrefix(modelName, "o1") ||
strings.HasPrefix(modelName, "o3") ||
strings.HasPrefix(modelName, "o4")
}
func IsOpenAIGPT5Model(modelName string) bool {
return strings.HasPrefix(modelName, "gpt-5")
}
func (r *GeneralOpenAIRequest) GetSystemRoleName() string {
if IsOpenAIReasoningOModel(r.Model) {
if strings.HasPrefix(r.Model, "o") {
if !strings.HasPrefix(r.Model, "o1-mini") && !strings.HasPrefix(r.Model, "o1-preview") {
return "developer"
}
} else if IsOpenAIGPT5Model(r.Model) {
} else if strings.HasPrefix(r.Model, "gpt-5") {
return "developer"
}
return "system"
-24
View File
@@ -71,27 +71,3 @@ func TestOpenAIResponsesRequestPreserveExplicitZeroValues(t *testing.T) {
require.True(t, gjson.GetBytes(encoded, "stream").Exists())
require.True(t, gjson.GetBytes(encoded, "top_p").Exists())
}
func TestGeneralOpenAIRequestGetSystemRoleName(t *testing.T) {
tests := []struct {
name string
model string
want string
}{
{name: "o1 uses developer", model: "o1", want: "developer"},
{name: "o3 family uses developer", model: "o3-mini-high", want: "developer"},
{name: "o4 family uses developer", model: "o4-mini", want: "developer"},
{name: "o1 mini stays system", model: "o1-mini", want: "system"},
{name: "o1 preview stays system", model: "o1-preview", want: "system"},
{name: "gpt 5 uses developer", model: "gpt-5", want: "developer"},
{name: "omni is not o series", model: "omni-moderation-latest", want: "system"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := GeneralOpenAIRequest{Model: tt.model}
require.Equal(t, tt.want, req.GetSystemRoleName())
})
}
}
Generated Vendored
+3 -3
View File
@@ -3097,9 +3097,9 @@
"license": "ISC"
},
"node_modules/ip-address": {
"version": "10.2.0",
"resolved": "https://registry.npmjs.org/ip-address/-/ip-address-10.2.0.tgz",
"integrity": "sha512-/+S6j4E9AHvW9SWMSEY9Xfy66O5PWvVEJ08O0y5JGyEKQpojb0K0GKpz/v5HJ/G0vi3D2sjGK78119oXZeE0qA==",
"version": "10.1.0",
"resolved": "https://registry.npmjs.org/ip-address/-/ip-address-10.1.0.tgz",
"integrity": "sha512-XXADHxXmvT9+CRxhXg56LJovE+bmWnEWB78LB83VZTprKTmaC5QfruXocxzTZ2Kl0DNwKuBdlIhjL8LeY8Sf8Q==",
"dev": true,
"license": "MIT",
"engines": {
-2
View File
@@ -60,8 +60,6 @@ require (
gorm.io/gorm v1.25.2
)
require github.com/waffo-com/waffo-pancake-sdk-go v0.3.1
require (
github.com/DmitriyVTitov/size v1.5.0 // indirect
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 // indirect
-6
View File
@@ -308,12 +308,6 @@ github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65E
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/waffo-com/waffo-go v1.3.1 h1:NCYD3oQ59DTJj1bwS5T/659LI4h8PuAIW4Qj/w7fKPw=
github.com/waffo-com/waffo-go v1.3.1/go.mod h1:IaXVYq6mmYtrLFFsLxPslNwuIZx0mIadWWjhe+eWb0g=
github.com/waffo-com/waffo-pancake-sdk-go v0.1.1 h1:YOI7+3zTBlTB7Ou6+ZXnJV2JvW/ag9d7CwE/TxH3Hls=
github.com/waffo-com/waffo-pancake-sdk-go v0.1.1/go.mod h1:5MBCGH/nqRRA5sHO/lQB/96r4BTAqy8QpWxn53m9htI=
github.com/waffo-com/waffo-pancake-sdk-go v0.2.0 h1:cCSgccM66p7feTtgRqUUGT50tYQOhahsoPXavd+ib1U=
github.com/waffo-com/waffo-pancake-sdk-go v0.2.0/go.mod h1:5MBCGH/nqRRA5sHO/lQB/96r4BTAqy8QpWxn53m9htI=
github.com/waffo-com/waffo-pancake-sdk-go v0.3.1 h1:ngQSN/oVB35xTwFPLfg++bxPC+SptcF145Mb6c62YCc=
github.com/waffo-com/waffo-pancake-sdk-go v0.3.1/go.mod h1:OB2MyFIQaefoPO0FV3J+yu9sDP8RVFQ+sbFsXqGuObc=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
+10 -11
View File
@@ -142,17 +142,16 @@ const (
// Payment related messages
const (
MsgPaymentNotConfigured = "payment.not_configured"
MsgPaymentMethodNotExists = "payment.method_not_exists"
MsgPaymentCallbackError = "payment.callback_error"
MsgPaymentCreateFailed = "payment.create_failed"
MsgPaymentStartFailed = "payment.start_failed"
MsgPaymentAmountTooLow = "payment.amount_too_low"
MsgPaymentStripeNotConfig = "payment.stripe_not_configured"
MsgPaymentWebhookNotConfig = "payment.webhook_not_configured"
MsgPaymentPriceIdNotConfig = "payment.price_id_not_configured"
MsgPaymentCreemNotConfig = "payment.creem_not_configured"
MsgPaymentComplianceRequired = "payment.compliance_required"
MsgPaymentNotConfigured = "payment.not_configured"
MsgPaymentMethodNotExists = "payment.method_not_exists"
MsgPaymentCallbackError = "payment.callback_error"
MsgPaymentCreateFailed = "payment.create_failed"
MsgPaymentStartFailed = "payment.start_failed"
MsgPaymentAmountTooLow = "payment.amount_too_low"
MsgPaymentStripeNotConfig = "payment.stripe_not_configured"
MsgPaymentWebhookNotConfig = "payment.webhook_not_configured"
MsgPaymentPriceIdNotConfig = "payment.price_id_not_configured"
MsgPaymentCreemNotConfig = "payment.creem_not_configured"
)
// Topup related messages
-1
View File
@@ -134,7 +134,6 @@ payment.stripe_not_configured: "Stripe is not configured or key is invalid"
payment.webhook_not_configured: "Webhook is not configured"
payment.price_id_not_configured: "StripePriceId is not configured for this plan"
payment.creem_not_configured: "CreemProductId is not configured for this plan"
payment.compliance_required: "Payment, redemption, subscription, and invitation reward features are disabled. The administrator must confirm compliance terms before enabling them."
# Topup messages
topup.not_provided: "Payment order number not provided"
-1
View File
@@ -135,7 +135,6 @@ payment.stripe_not_configured: "Stripe 未配置或密钥无效"
payment.webhook_not_configured: "Webhook 未配置"
payment.price_id_not_configured: "该套餐未配置 StripePriceId"
payment.creem_not_configured: "该套餐未配置 CreemProductId"
payment.compliance_required: "支付、兑换码、订阅计划和邀请返利功能已禁用。管理员需先确认合规声明后方可启用。"
# Topup messages
topup.not_provided: "未提供支付单号"
-1
View File
@@ -135,7 +135,6 @@ payment.stripe_not_configured: "Stripe 未設定或密鑰無效"
payment.webhook_not_configured: "Webhook 未設定"
payment.price_id_not_configured: "該訂閱方案未設定 StripePriceId"
payment.creem_not_configured: "該訂閱方案未設定 CreemProductId"
payment.compliance_required: "支付、兌換碼、訂閱方案和邀請返利功能已停用。管理員需先確認合規聲明後方可啟用。"
# Topup messages
topup.not_provided: "未提供支付單號"
+4 -9
View File
@@ -95,11 +95,9 @@ func LogDebug(ctx context.Context, msg string, args ...any) {
}
func logHelper(ctx context.Context, level string, msg string) {
var id any = "SYSTEM"
if ctx != nil {
if requestID := ctx.Value(common.RequestIdKey); requestID != nil {
id = requestID
}
id := ctx.Value(common.RequestIdKey)
if id == nil {
id = "SYSTEM"
}
now := time.Now()
common.LogWriterMu.RLock()
@@ -174,13 +172,10 @@ func FormatQuota(quota int) string {
// LogJson 仅供测试使用 only for test
func LogJson(ctx context.Context, msg string, obj any) {
if !common.DebugEnabled {
return
}
jsonStr, err := common.Marshal(obj)
if err != nil {
LogError(ctx, fmt.Sprintf("json marshal failed: %s", err.Error()))
return
}
LogDebug(ctx, "%s | %s", msg, jsonStr)
LogDebug(ctx, fmt.Sprintf("%s | %s", msg, string(jsonStr)))
}
+7 -68
View File
@@ -1,28 +1,18 @@
FRONTEND_DIR = ./web/default
FRONTEND_CLASSIC_DIR = ./web/classic
BACKEND_DIR = .
DEV_FRONTEND_DEFAULT_PORT ?= 5173
DEV_FRONTEND_CLASSIC_PORT ?= 5174
DEV_COMPOSE_FILE = docker-compose.dev.yml
DEV_POSTGRES_SERVICE = postgres
DEV_BACKEND_SERVICE = new-api
DEV_POSTGRES_DB = new-api
DEV_POSTGRES_USER = root
DEV_SQLITE_PATH ?= one-api.db
.PHONY: all build-frontend build-frontend-classic build-all-frontends start-backend dev dev-api dev-api-rebuild dev-web dev-web-classic reset-setup
.PHONY: all build-frontend build-frontend-classic build-all-frontends start-backend dev dev-api dev-web dev-web-classic
all: build-all-frontends start-backend
build-frontend:
@echo "Building default frontend..."
@cd ./web && bun install --frozen-lockfile
@cd $(FRONTEND_DIR) && DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat ../../VERSION) bun run build
@cd $(FRONTEND_DIR) && bun install && DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat ../../VERSION) bun run build
build-frontend-classic:
@echo "Building classic frontend..."
@cd ./web && bun install --frozen-lockfile
@cd $(FRONTEND_CLASSIC_DIR) && VITE_REACT_APP_VERSION=$(cat ../../VERSION) bun run build
@cd $(FRONTEND_CLASSIC_DIR) && bun install && VITE_REACT_APP_VERSION=$(cat ../../VERSION) bun run build
build-all-frontends: build-frontend build-frontend-classic
@@ -32,65 +22,14 @@ start-backend:
dev-api:
@echo "Starting backend services (docker)..."
@docker compose -f $(DEV_COMPOSE_FILE) up -d
dev-api-rebuild:
@echo "Rebuilding and starting backend service (docker)..."
@docker compose -f $(DEV_COMPOSE_FILE) up -d --build $(DEV_BACKEND_SERVICE)
@docker compose -f docker-compose.dev.yml up -d
dev-web:
@echo "Starting both frontend dev servers..."
@echo "Default frontend: http://localhost:$(DEV_FRONTEND_DEFAULT_PORT)"
@echo "Classic frontend: http://localhost:$(DEV_FRONTEND_CLASSIC_PORT)"
@cd ./web && bun install
@(cd $(FRONTEND_DIR) && bun run dev -- --host 0.0.0.0 --port $(DEV_FRONTEND_DEFAULT_PORT)) & \
default_pid=$$!; \
(cd $(FRONTEND_CLASSIC_DIR) && bun run dev -- --host 0.0.0.0 --port $(DEV_FRONTEND_CLASSIC_PORT)) & \
classic_pid=$$!; \
trap 'kill $$default_pid $$classic_pid 2>/dev/null; wait $$default_pid $$classic_pid 2>/dev/null; exit 130' INT TERM; \
while kill -0 $$default_pid 2>/dev/null && kill -0 $$classic_pid 2>/dev/null; do \
sleep 1; \
done; \
if ! kill -0 $$default_pid 2>/dev/null; then \
wait $$default_pid; \
status=$$?; \
kill $$classic_pid 2>/dev/null; \
wait $$classic_pid 2>/dev/null; \
exit $$status; \
fi; \
wait $$classic_pid; \
status=$$?; \
kill $$default_pid 2>/dev/null; \
wait $$default_pid 2>/dev/null; \
exit $$status
@echo "Starting frontend dev server..."
@cd $(FRONTEND_DIR) && bun install && bun run dev
dev-web-classic:
@echo "Starting classic frontend dev server..."
@cd ./web && bun install
@cd $(FRONTEND_CLASSIC_DIR) && bun run dev -- --host 0.0.0.0 --port $(DEV_FRONTEND_CLASSIC_PORT)
@cd $(FRONTEND_CLASSIC_DIR) && bun install && bun run dev
dev: dev-api dev-web
reset-setup:
@echo "Resetting local setup wizard state..."
@if docker compose -f $(DEV_COMPOSE_FILE) ps --services --status running | grep -qx "$(DEV_POSTGRES_SERVICE)"; then \
echo "Detected running docker dev PostgreSQL. Removing setup record and root users..."; \
docker compose -f $(DEV_COMPOSE_FILE) exec -T $(DEV_POSTGRES_SERVICE) \
psql -U $(DEV_POSTGRES_USER) -d $(DEV_POSTGRES_DB) \
-c 'DELETE FROM setups;' \
-c 'DELETE FROM users WHERE role = 100;' \
-c "DELETE FROM options WHERE key IN ('SelfUseModeEnabled', 'DemoSiteEnabled');"; \
echo "Restarting docker dev backend so setup status is recalculated..."; \
docker compose -f $(DEV_COMPOSE_FILE) restart $(DEV_BACKEND_SERVICE); \
elif db_path="$${SQLITE_PATH:-$(DEV_SQLITE_PATH)}"; db_path="$${db_path%%\?*}"; [ -f "$$db_path" ]; then \
db_path="$${SQLITE_PATH:-$(DEV_SQLITE_PATH)}"; \
db_path="$${db_path%%\?*}"; \
echo "Detected local SQLite database: $$db_path"; \
sqlite3 "$$db_path" \
"DELETE FROM setups; DELETE FROM users WHERE role = 100; DELETE FROM options WHERE key IN ('SelfUseModeEnabled', 'DemoSiteEnabled');"; \
echo "SQLite setup state reset. Restart the local backend process before testing the setup wizard."; \
else \
echo "No running docker dev PostgreSQL or local SQLite database found."; \
echo "Start the dev stack with 'make dev-api', or set SQLITE_PATH/DEV_SQLITE_PATH to your local SQLite database."; \
exit 1; \
fi
+7 -89
View File
@@ -3,7 +3,6 @@ package middleware
import (
"errors"
"fmt"
"io"
"net/http"
"slices"
"strconv"
@@ -21,7 +20,6 @@ import (
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
)
type ModelRequest struct {
@@ -102,10 +100,14 @@ func Distribute() func(c *gin.Context) {
}
if preferredChannelID, found := service.GetPreferredChannelByAffinity(c, modelRequest.Model, usingGroup); found {
affinityUsable := false
preferred, err := model.CacheGetChannel(preferredChannelID)
if err == nil && preferred != nil && preferred.Status == common.ChannelStatusEnabled {
if usingGroup == "auto" {
if err == nil && preferred != nil {
if preferred.Status != common.ChannelStatusEnabled {
if service.ShouldSkipRetryAfterChannelAffinityFailure(c) {
abortWithOpenAiMessage(c, http.StatusForbidden, i18n.T(c, i18n.MsgDistributorAffinityChannelDisabled))
return
}
} else if usingGroup == "auto" {
userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
autoGroups := service.GetUserAutoGroup(userGroup)
for _, g := range autoGroups {
@@ -113,7 +115,6 @@ func Distribute() func(c *gin.Context) {
selectGroup = g
common.SetContextKey(c, constant.ContextKeyAutoGroup, g)
channel = preferred
affinityUsable = true
service.MarkChannelAffinityUsed(c, g, preferred.Id)
break
}
@@ -121,13 +122,9 @@ func Distribute() func(c *gin.Context) {
} else if model.IsChannelEnabledForGroupModel(usingGroup, modelRequest.Model, preferred.Id) {
channel = preferred
selectGroup = usingGroup
affinityUsable = true
service.MarkChannelAffinityUsed(c, usingGroup, preferred.Id)
}
}
if !affinityUsable && !service.ShouldKeepChannelAffinityOnChannelDisabled() {
service.ClearCurrentChannelAffinityCache(c)
}
}
if channel == nil {
@@ -173,14 +170,6 @@ func Distribute() func(c *gin.Context) {
// - application/x-www-form-urlencoded
// - multipart/form-data
func getModelFromRequest(c *gin.Context) (*ModelRequest, error) {
if strings.HasPrefix(c.Request.Header.Get("Content-Type"), "application/json") {
modelRequest, err := getModelFromJSONBody(c)
if err != nil {
return nil, errors.New(i18n.T(c, i18n.MsgDistributorInvalidRequest, map[string]any{"Error": err.Error()}))
}
return modelRequest, nil
}
var modelRequest ModelRequest
err := common.UnmarshalBodyReusable(c, &modelRequest)
if err != nil {
@@ -189,50 +178,6 @@ func getModelFromRequest(c *gin.Context) (*ModelRequest, error) {
return &modelRequest, nil
}
func getModelFromJSONBody(c *gin.Context) (*ModelRequest, error) {
storage, err := common.GetBodyStorage(c)
if err != nil {
return nil, err
}
requestBody, err := storage.Bytes()
if err != nil {
return nil, err
}
if !gjson.ValidBytes(requestBody) {
return nil, errors.New("invalid JSON request body")
}
values := gjson.GetManyBytes(requestBody, "model", "group")
model, err := getJSONStringValue(values[0], "model")
if err != nil {
return nil, err
}
group, err := getJSONStringValue(values[1], "group")
if err != nil {
return nil, err
}
if _, seekErr := storage.Seek(0, io.SeekStart); seekErr != nil {
return nil, seekErr
}
c.Request.Body = io.NopCloser(storage)
return &ModelRequest{
Model: model,
Group: group,
}, nil
}
func getJSONStringValue(result gjson.Result, field string) (string, error) {
if !result.Exists() || result.Type == gjson.Null {
return "", nil
}
if result.Type != gjson.String {
return "", fmt.Errorf("field %s must be a string", field)
}
return result.String(), nil
}
func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
var modelRequest ModelRequest
shouldSelectChannel := true
@@ -299,7 +244,6 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
} else if c.Request.Method == http.MethodGet {
relayMode = relayconstant.RelayModeVideoFetchByID
shouldSelectChannel = false
modelRequest.Model = getTaskOriginModelName(c)
}
c.Set("relay_mode", relayMode)
} else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") {
@@ -314,7 +258,6 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
} else if c.Request.Method == http.MethodGet {
relayMode = relayconstant.RelayModeVideoFetchByID
shouldSelectChannel = false
modelRequest.Model = getTaskOriginModelName(c)
}
if _, ok := c.Get("relay_mode"); !ok {
c.Set("relay_mode", relayMode)
@@ -399,31 +342,6 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
return &modelRequest, shouldSelectChannel, nil
}
// 修复 #4834: GET /v1/video/generations/:task_id && /v1/video/:task_id 此前不解析 model
// 当 token 启用「可用模型限制」时,下游 modelLimitEnable 校验会因
// modelRequest.Model 为空而误报 "This token has no access to model"。
// 从已存储的任务记录中回填 OriginModelName 即可让校验走在正确的模型上。
func getTaskOriginModelName(c *gin.Context) string {
if !common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled) {
return ""
}
taskId := c.Param("task_id")
if taskId == "" {
// jimeng adapter
taskId = c.GetString("task_id")
}
if taskId == "" {
return ""
}
userId := c.GetInt("id")
if task, exist, err := model.GetByTaskId(userId, taskId); err == nil && exist && task != nil {
return task.Properties.OriginModelName
}
return ""
}
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) *types.NewAPIError {
c.Set("original_model", modelName) // for retry
if channel == nil {
-135
View File
@@ -1,135 +0,0 @@
package middleware
import (
"fmt"
"net/http"
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/gin-gonic/gin"
)
type headerNavAccess struct {
Enabled bool
RequireAuth bool
}
func getHeaderNavAccess(module string) headerNavAccess {
fallback := headerNavAccess{
Enabled: true,
RequireAuth: false,
}
common.OptionMapRWMutex.RLock()
raw := common.OptionMap["HeaderNavModules"]
common.OptionMapRWMutex.RUnlock()
if strings.TrimSpace(raw) == "" {
return fallback
}
var parsed map[string]any
if err := common.Unmarshal([]byte(raw), &parsed); err != nil {
return fallback
}
return parseHeaderNavAccess(parsed[module], fallback)
}
func parseHeaderNavAccess(raw any, fallback headerNavAccess) headerNavAccess {
switch value := raw.(type) {
case bool:
return headerNavAccess{
Enabled: value,
RequireAuth: fallback.RequireAuth,
}
case string:
return headerNavAccess{
Enabled: parseHeaderNavBool(value, fallback.Enabled),
RequireAuth: fallback.RequireAuth,
}
case float64:
return headerNavAccess{
Enabled: parseHeaderNavBool(value, fallback.Enabled),
RequireAuth: fallback.RequireAuth,
}
case map[string]any:
access := fallback
if enabled, ok := value["enabled"]; ok {
access.Enabled = parseHeaderNavBool(enabled, fallback.Enabled)
}
if requireAuth, ok := value["requireAuth"]; ok {
access.RequireAuth = parseHeaderNavBool(requireAuth, fallback.RequireAuth)
}
return access
default:
return fallback
}
}
func parseHeaderNavBool(value any, fallback bool) bool {
switch v := value.(type) {
case bool:
return v
case string:
switch strings.ToLower(strings.TrimSpace(v)) {
case "true", "1":
return true
case "false", "0":
return false
default:
return fallback
}
case float64:
if v == 1 {
return true
}
if v == 0 {
return false
}
return fallback
case int:
if v == 1 {
return true
}
if v == 0 {
return false
}
return fallback
default:
return fallback
}
}
func HeaderNavModuleAuth(module string) gin.HandlerFunc {
return func(c *gin.Context) {
access := getHeaderNavAccess(module)
if !access.Enabled {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": fmt.Sprintf("%s is disabled", module),
})
c.Abort()
return
}
if access.RequireAuth {
UserAuth()(c)
return
}
TryUserAuth()(c)
}
}
func HeaderNavModulePublicOrUserAuth(module string) gin.HandlerFunc {
return func(c *gin.Context) {
access := getHeaderNavAccess(module)
if !access.Enabled || access.RequireAuth {
UserAuth()(c)
return
}
TryUserAuth()(c)
}
}
-167
View File
@@ -1,167 +0,0 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/QuantumNous/new-api/common"
"github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/cookie"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func withHeaderNavModules(t *testing.T, raw string) {
t.Helper()
common.OptionMapRWMutex.Lock()
if common.OptionMap == nil {
common.OptionMap = map[string]string{}
}
previous, hadPrevious := common.OptionMap["HeaderNavModules"]
common.OptionMap["HeaderNavModules"] = raw
common.OptionMapRWMutex.Unlock()
t.Cleanup(func() {
common.OptionMapRWMutex.Lock()
defer common.OptionMapRWMutex.Unlock()
if hadPrevious {
common.OptionMap["HeaderNavModules"] = previous
return
}
delete(common.OptionMap, "HeaderNavModules")
})
}
func performHeaderNavRequest(t *testing.T, handler gin.HandlerFunc, authenticated bool) *httptest.ResponseRecorder {
t.Helper()
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(sessions.Sessions("session", cookie.NewStore([]byte("header-nav-test"))))
router.GET("/login", func(c *gin.Context) {
session := sessions.Default(c)
session.Set("username", "tester")
session.Set("role", common.RoleCommonUser)
session.Set("id", 1)
session.Set("status", common.UserStatusEnabled)
session.Set("group", "default")
if err := session.Save(); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"success": false})
return
}
c.Status(http.StatusNoContent)
})
router.GET("/api/test", handler, func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"success": true})
})
var cookies []*http.Cookie
if authenticated {
loginRecorder := httptest.NewRecorder()
loginRequest := httptest.NewRequest(http.MethodGet, "/login", nil)
router.ServeHTTP(loginRecorder, loginRequest)
require.Equal(t, http.StatusNoContent, loginRecorder.Code)
cookies = loginRecorder.Result().Cookies()
}
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodGet, "/api/test", nil)
if authenticated {
request.Header.Set("New-Api-User", "1")
for _, cookie := range cookies {
request.AddCookie(cookie)
}
}
router.ServeHTTP(recorder, request)
return recorder
}
func TestHeaderNavModuleAuthAllowsDefaultPublicAccess(t *testing.T) {
withHeaderNavModules(t, "")
recorder := performHeaderNavRequest(t, HeaderNavModuleAuth("pricing"), false)
require.Equal(t, http.StatusOK, recorder.Code)
}
func TestHeaderNavModuleAuthRejectsDisabledPricing(t *testing.T) {
raw := `{"pricing":{"enabled":false,"requireAuth":false}}`
withHeaderNavModules(t, raw)
recorder := performHeaderNavRequest(t, HeaderNavModuleAuth("pricing"), false)
require.Equal(t, http.StatusForbidden, recorder.Code)
}
func TestHeaderNavModuleAuthRequiresLoginForPricing(t *testing.T) {
raw := `{"pricing":{"enabled":true,"requireAuth":true}}`
withHeaderNavModules(t, raw)
recorder := performHeaderNavRequest(t, HeaderNavModuleAuth("pricing"), false)
require.Equal(t, http.StatusUnauthorized, recorder.Code)
}
func TestHeaderNavModuleAuthRequiresLoginForRankings(t *testing.T) {
raw := `{"rankings":{"enabled":true,"requireAuth":true}}`
withHeaderNavModules(t, raw)
recorder := performHeaderNavRequest(t, HeaderNavModuleAuth("rankings"), false)
require.Equal(t, http.StatusUnauthorized, recorder.Code)
}
func TestHeaderNavModuleAuthRejectsLegacyDisabledModule(t *testing.T) {
raw := `{"rankings":false}`
withHeaderNavModules(t, raw)
recorder := performHeaderNavRequest(t, HeaderNavModuleAuth("rankings"), false)
require.Equal(t, http.StatusForbidden, recorder.Code)
}
func TestHeaderNavModulePublicOrUserAuthAllowsDefaultPublicAccess(t *testing.T) {
withHeaderNavModules(t, "")
recorder := performHeaderNavRequest(t, HeaderNavModulePublicOrUserAuth("pricing"), false)
require.Equal(t, http.StatusOK, recorder.Code)
}
func TestHeaderNavModulePublicOrUserAuthRequiresLoginWhenDisabled(t *testing.T) {
raw := `{"pricing":{"enabled":false,"requireAuth":false}}`
withHeaderNavModules(t, raw)
recorder := performHeaderNavRequest(t, HeaderNavModulePublicOrUserAuth("pricing"), false)
require.Equal(t, http.StatusUnauthorized, recorder.Code)
}
func TestHeaderNavModulePublicOrUserAuthAllowsLoggedInWhenDisabled(t *testing.T) {
raw := `{"pricing":{"enabled":false,"requireAuth":false}}`
withHeaderNavModules(t, raw)
recorder := performHeaderNavRequest(t, HeaderNavModulePublicOrUserAuth("pricing"), true)
require.Equal(t, http.StatusOK, recorder.Code)
}
func TestHeaderNavModulePublicOrUserAuthRequiresLoginWhenRequireAuth(t *testing.T) {
raw := `{"pricing":{"enabled":true,"requireAuth":true}}`
withHeaderNavModules(t, raw)
recorder := performHeaderNavRequest(t, HeaderNavModulePublicOrUserAuth("pricing"), false)
require.Equal(t, http.StatusUnauthorized, recorder.Code)
}
func TestHeaderNavModulePublicOrUserAuthRequiresLoginForLegacyDisabledModule(t *testing.T) {
raw := `{"pricing":false}`
withHeaderNavModules(t, raw)
recorder := performHeaderNavRequest(t, HeaderNavModulePublicOrUserAuth("pricing"), false)
require.Equal(t, http.StatusUnauthorized, recorder.Code)
}
-47
View File
@@ -1,47 +0,0 @@
package middleware
import (
"bytes"
"io"
"net/http"
"github.com/QuantumNous/new-api/common"
"github.com/gin-gonic/gin"
)
func AnonymousRequestBodyLimit() gin.HandlerFunc {
return func(c *gin.Context) {
maxBytes := common.GetAnonymousRequestBodyLimitBytes()
if maxBytes <= 0 || c.Request.Body == nil {
c.Next()
return
}
originalBody := c.Request.Body
limitedBody, err := readAnonymousRequestBody(originalBody, maxBytes)
_ = originalBody.Close()
if err != nil {
if common.IsRequestBodyTooLargeError(err) {
c.AbortWithStatus(http.StatusRequestEntityTooLarge)
return
}
c.AbortWithStatus(http.StatusBadRequest)
return
}
c.Request.Body = io.NopCloser(bytes.NewReader(limitedBody))
c.Request.ContentLength = int64(len(limitedBody))
c.Next()
}
}
func readAnonymousRequestBody(body io.Reader, maxBytes int64) ([]byte, error) {
data, err := io.ReadAll(io.LimitReader(body, maxBytes+1))
if err != nil {
return nil, err
}
if int64(len(data)) > maxBytes {
return nil, common.ErrRequestBodyTooLarge
}
return data, nil
}
+40 -91
View File
@@ -12,7 +12,6 @@ import (
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/types"
"github.com/samber/lo"
@@ -129,38 +128,6 @@ func resolveChannelSortOptions(idSort bool, sortOptions []ChannelSortOptions) Ch
return options
}
func NormalizeChannelGroupFilter(group string) string {
group = strings.TrimSpace(group)
if group == "" || strings.EqualFold(group, "all") || strings.EqualFold(group, "null") {
return ""
}
return group
}
func channelGroupFilterCondition() string {
if common.UsingMySQL {
return `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ? ESCAPE '!'`
}
return `(',' || ` + commonGroupCol + ` || ',') LIKE ? ESCAPE '!'`
}
func channelGroupFilterPattern(group string) string {
group = strings.NewReplacer(
"!", "!!",
"%", "!%",
"_", "!_",
).Replace(group)
return "%," + group + ",%"
}
func ApplyChannelGroupFilter(query *gorm.DB, group string) *gorm.DB {
group = NormalizeChannelGroupFilter(group)
if group == "" {
return query
}
return query.Where(channelGroupFilterCondition(), channelGroupFilterPattern(group))
}
// Value implements driver.Valuer interface
func (c ChannelInfo) Value() (driver.Value, error) {
return common.Marshal(&c)
@@ -251,9 +218,10 @@ func (channel *Channel) GetNextEnabledKey() (string, int, *types.NewAPIError) {
if err != nil {
return "", 0, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
}
//println("before polling index:", channel.ChannelInfo.MultiKeyPollingIndex)
defer func() {
if common.DebugEnabled {
logger.LogDebug(nil, "channel %d polling index: %d", channel.Id, channel.ChannelInfo.MultiKeyPollingIndex)
println(fmt.Sprintf("channel %d polling index: %d", channel.Id, channel.ChannelInfo.MultiKeyPollingIndex))
}
if !common.MemoryCacheEnabled {
_ = channel.SaveChannelInfo()
@@ -397,12 +365,25 @@ func SearchChannels(keyword string, group string, model string, idSort bool, sor
baseQuery := DB.Model(&Channel{}).Omit("key")
// 构造WHERE子句
whereClause := "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
args := []any{common.String2Int(keyword), "%" + keyword + "%", keyword, "%" + keyword + "%", "%" + model + "%"}
baseQuery = ApplyChannelGroupFilter(baseQuery.Where(whereClause, args...), group)
var whereClause string
var args []interface{}
if group != "" && group != "null" {
var groupCondition string
if common.UsingMySQL {
groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?`
} else {
// sqlite, PostgreSQL
groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?`
}
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
} else {
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
}
// 执行查询
err := order.Apply(baseQuery).Find(&channels).Error
err := order.Apply(baseQuery.Where(whereClause, args...)).Find(&channels).Error
if err != nil {
return nil, err
}
@@ -643,25 +624,13 @@ func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int, reason
if len(keys) == 0 {
channel.Status = status
} else {
keyIndex := -1
var keyIndex int
for i, key := range keys {
if key == usingKey {
keyIndex = i
break
}
}
if keyIndex < 0 {
if usingKey != "" {
common.SysLog(fmt.Sprintf("failed to update multi-key status: channel_id=%d, using key not found", channel.Id))
return
}
channel.Status = status
info := channel.GetOtherInfo()
info["status_reason"] = reason
info["status_time"] = common.GetTimestamp()
channel.SetOtherInfo(info)
return
}
if channel.ChannelInfo.MultiKeyStatusList == nil {
channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
}
@@ -678,31 +647,16 @@ func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int, reason
channel.ChannelInfo.MultiKeyDisabledReason[keyIndex] = reason
channel.ChannelInfo.MultiKeyDisabledTime[keyIndex] = common.GetTimestamp()
}
if !hasEnabledMultiKey(keys, channel.ChannelInfo.MultiKeyStatusList) {
if len(channel.ChannelInfo.MultiKeyStatusList) >= channel.ChannelInfo.MultiKeySize {
channel.Status = common.ChannelStatusAutoDisabled
info := channel.GetOtherInfo()
info["status_reason"] = "All keys are disabled"
info["status_time"] = common.GetTimestamp()
channel.SetOtherInfo(info)
} else if status == common.ChannelStatusEnabled {
channel.Status = common.ChannelStatusEnabled
}
}
}
func hasEnabledMultiKey(keys []string, statusList map[int]int) bool {
for i := range keys {
if statusList == nil {
return true
}
status, ok := statusList[i]
if !ok || status == common.ChannelStatusEnabled {
return true
}
}
return false
}
func UpdateChannelStatus(channelId int, usingKey string, status int, reason string) bool {
if common.MemoryCacheEnabled {
channelStatusLock.Lock()
@@ -714,15 +668,11 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
}
if channelCache.ChannelInfo.IsMultiKey {
// Use per-channel lock to prevent concurrent map read/write with GetNextEnabledKey
beforeStatus := channelCache.Status
pollingLock := GetChannelPollingLock(channelId)
pollingLock.Lock()
// 如果是多Key模式,更新缓存中的状态
handlerMultiKeyUpdate(channelCache, usingKey, status, reason)
pollingLock.Unlock()
if beforeStatus != channelCache.Status {
CacheUpdateChannelStatus(channelId, channelCache.Status)
}
//CacheUpdateChannel(channelCache)
//return true
} else {
@@ -878,18 +828,8 @@ func DeleteDisabledChannel() (int64, error) {
}
func GetPaginatedTags(offset int, limit int) ([]*string, error) {
return GetPaginatedChannelTags(DB.Model(&Channel{}), offset, limit)
}
func GetPaginatedChannelTags(query *gorm.DB, offset int, limit int) ([]*string, error) {
var tags []*string
err := query.
Select("DISTINCT tag").
Where("tag is not null AND tag != ''").
Order(clause.OrderByColumn{Column: clause.Column{Name: "tag"}}).
Offset(offset).
Limit(limit).
Find(&tags).Error
err := DB.Model(&Channel{}).Select("DISTINCT tag").Where("tag != ''").Offset(offset).Limit(limit).Find(&tags).Error
return tags, err
}
@@ -917,11 +857,24 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
baseQuery := DB.Model(&Channel{}).Omit("key")
// 构造WHERE子句
whereClause := "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
args := []any{common.String2Int(keyword), "%" + keyword + "%", keyword, "%" + keyword + "%", "%" + model + "%"}
baseQuery = ApplyChannelGroupFilter(baseQuery.Where(whereClause, args...), group)
var whereClause string
var args []interface{}
if group != "" && group != "null" {
var groupCondition string
if common.UsingMySQL {
groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?`
} else {
// sqlite, PostgreSQL
groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?`
}
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
} else {
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
}
subQuery := baseQuery.
subQuery := baseQuery.Where(whereClause, args...).
Select("tag").
Where("tag != ''").
Order(order)
@@ -1062,12 +1015,8 @@ func CountAllChannels() (int64, error) {
// CountAllTags returns number of non-empty distinct tags
func CountAllTags() (int64, error) {
return CountChannelTags(DB.Model(&Channel{}))
}
func CountChannelTags(query *gorm.DB) (int64, error) {
var total int64
err := query.Where("tag is not null AND tag != ''").Distinct("tag").Count(&total).Error
err := DB.Model(&Channel{}).Where("tag is not null AND tag != ''").Distinct("tag").Count(&total).Error
return total, err
}
+4 -8
View File
@@ -11,7 +11,6 @@ import (
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/setting/ratio_setting"
)
@@ -258,12 +257,9 @@ func CacheUpdateChannel(channel *Channel) {
return
}
if channelsIDM == nil {
channelsIDM = make(map[int]*Channel)
}
if oldChannel, ok := channelsIDM[channel.Id]; ok {
logger.LogDebug(nil, "CacheUpdateChannel before: id=%d, name=%s, status=%d, polling_index=%d", channel.Id, channel.Name, channel.Status, oldChannel.ChannelInfo.MultiKeyPollingIndex)
}
println("CacheUpdateChannel:", channel.Id, channel.Name, channel.Status, channel.ChannelInfo.MultiKeyPollingIndex)
println("before:", channelsIDM[channel.Id].ChannelInfo.MultiKeyPollingIndex)
channelsIDM[channel.Id] = channel
logger.LogDebug(nil, "CacheUpdateChannel after: id=%d, name=%s, status=%d, polling_index=%d", channel.Id, channel.Name, channel.Status, channel.ChannelInfo.MultiKeyPollingIndex)
println("after :", channelsIDM[channel.Id].ChannelInfo.MultiKeyPollingIndex)
}
+48 -70
View File
@@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
@@ -17,42 +16,27 @@ import (
"gorm.io/gorm"
)
func applyExplicitLogTextFilter(tx *gorm.DB, column string, value string) (*gorm.DB, error) {
if value == "" {
return tx, nil
}
if strings.Contains(value, "%") {
pattern, err := sanitizeLikePattern(value)
if err != nil {
return nil, err
}
return tx.Where(column+" LIKE ? ESCAPE '!'", pattern), nil
}
return tx.Where(column+" = ?", value), nil
}
type Log struct {
Id int `json:"id" gorm:"index:idx_created_at_id,priority:2;index:idx_user_id_id,priority:2"`
UserId int `json:"user_id" gorm:"index;index:idx_user_id_id,priority:1"`
CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:1;index:idx_created_at_type"`
Type int `json:"type" gorm:"index:idx_created_at_type"`
Content string `json:"content"`
Username string `json:"username" gorm:"index;index:index_username_model_name,priority:2;default:''"`
TokenName string `json:"token_name" gorm:"index;default:''"`
ModelName string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"`
Quota int `json:"quota" gorm:"default:0"`
PromptTokens int `json:"prompt_tokens" gorm:"default:0"`
CompletionTokens int `json:"completion_tokens" gorm:"default:0"`
UseTime int `json:"use_time" gorm:"default:0"`
IsStream bool `json:"is_stream"`
ChannelId int `json:"channel" gorm:"index"`
ChannelName string `json:"channel_name" gorm:"->"`
TokenId int `json:"token_id" gorm:"default:0;index"`
Group string `json:"group" gorm:"index"`
Ip string `json:"ip" gorm:"index;default:''"`
RequestId string `json:"request_id,omitempty" gorm:"type:varchar(64);index:idx_logs_request_id;default:''"`
UpstreamRequestId string `json:"upstream_request_id,omitempty" gorm:"type:varchar(128);index:idx_logs_upstream_request_id;default:''"`
Other string `json:"other"`
Id int `json:"id" gorm:"index:idx_created_at_id,priority:1;index:idx_user_id_id,priority:2"`
UserId int `json:"user_id" gorm:"index;index:idx_user_id_id,priority:1"`
CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:2;index:idx_created_at_type"`
Type int `json:"type" gorm:"index:idx_created_at_type"`
Content string `json:"content"`
Username string `json:"username" gorm:"index;index:index_username_model_name,priority:2;default:''"`
TokenName string `json:"token_name" gorm:"index;default:''"`
ModelName string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"`
Quota int `json:"quota" gorm:"default:0"`
PromptTokens int `json:"prompt_tokens" gorm:"default:0"`
CompletionTokens int `json:"completion_tokens" gorm:"default:0"`
UseTime int `json:"use_time" gorm:"default:0"`
IsStream bool `json:"is_stream"`
ChannelId int `json:"channel" gorm:"index"`
ChannelName string `json:"channel_name" gorm:"->"`
TokenId int `json:"token_id" gorm:"default:0;index"`
Group string `json:"group" gorm:"index"`
Ip string `json:"ip" gorm:"index;default:''"`
RequestId string `json:"request_id,omitempty" gorm:"type:varchar(64);index:idx_logs_request_id;default:''"`
Other string `json:"other"`
}
// don't use iota, avoid change log type value
@@ -160,10 +144,9 @@ func RecordTopupLog(userId int, content string, callerIp string, paymentMethod s
func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string, tokenName string, content string, tokenId int, useTimeSeconds int,
isStream bool, group string, other map[string]interface{}) {
logger.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, common.LocalLogPreview(content)))
logger.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content))
username := c.GetString("username")
requestId := c.GetString(common.RequestIdKey)
upstreamRequestId := c.GetString(common.UpstreamRequestIdKey)
otherStr := common.MapToJsonStr(other)
// 判断是否需要记录 IP
needRecordIp := false
@@ -194,9 +177,8 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string,
}
return ""
}(),
RequestId: requestId,
UpstreamRequestId: upstreamRequestId,
Other: otherStr,
RequestId: requestId,
Other: otherStr,
}
err := LOG_DB.Create(log).Error
if err != nil {
@@ -226,7 +208,6 @@ func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams)
logger.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params)))
username := c.GetString("username")
requestId := c.GetString(common.RequestIdKey)
upstreamRequestId := c.GetString(common.UpstreamRequestIdKey)
otherStr := common.MapToJsonStr(params.Other)
// 判断是否需要记录 IP
needRecordIp := false
@@ -257,9 +238,8 @@ func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams)
}
return ""
}(),
RequestId: requestId,
UpstreamRequestId: upstreamRequestId,
Other: otherStr,
RequestId: requestId,
Other: otherStr,
}
err := LOG_DB.Create(log).Error
if err != nil {
@@ -315,7 +295,7 @@ func RecordTaskBillingLog(params RecordTaskBillingLogParams) {
}
}
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int, group string, requestId string, upstreamRequestId string) (logs []*Log, total int64, err error) {
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int, group string, requestId string) (logs []*Log, total int64, err error) {
var tx *gorm.DB
if logType == LogTypeUnknown {
tx = LOG_DB
@@ -323,11 +303,11 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
tx = LOG_DB.Where("logs.type = ?", logType)
}
if tx, err = applyExplicitLogTextFilter(tx, "logs.model_name", modelName); err != nil {
return nil, 0, err
if modelName != "" {
tx = tx.Where("logs.model_name like ?", modelName)
}
if tx, err = applyExplicitLogTextFilter(tx, "logs.username", username); err != nil {
return nil, 0, err
if username != "" {
tx = tx.Where("logs.username = ?", username)
}
if tokenName != "" {
tx = tx.Where("logs.token_name = ?", tokenName)
@@ -335,9 +315,6 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
if requestId != "" {
tx = tx.Where("logs.request_id = ?", requestId)
}
if upstreamRequestId != "" {
tx = tx.Where("logs.upstream_request_id = ?", upstreamRequestId)
}
if startTimestamp != 0 {
tx = tx.Where("logs.created_at >= ?", startTimestamp)
}
@@ -354,7 +331,7 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
if err != nil {
return nil, 0, err
}
err = tx.Order("logs.created_at desc, logs.id desc").Limit(num).Offset(startIdx).Find(&logs).Error
err = tx.Order("logs.id desc").Limit(num).Offset(startIdx).Find(&logs).Error
if err != nil {
return nil, 0, err
}
@@ -404,7 +381,7 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
const logSearchCountLimit = 10000
func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int, group string, requestId string, upstreamRequestId string) (logs []*Log, total int64, err error) {
func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int, group string, requestId string) (logs []*Log, total int64, err error) {
var tx *gorm.DB
if logType == LogTypeUnknown {
tx = LOG_DB.Where("logs.user_id = ?", userId)
@@ -412,8 +389,12 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
tx = LOG_DB.Where("logs.user_id = ? and logs.type = ?", userId, logType)
}
if tx, err = applyExplicitLogTextFilter(tx, "logs.model_name", modelName); err != nil {
return nil, 0, err
if modelName != "" {
modelNamePattern, err := sanitizeLikePattern(modelName)
if err != nil {
return nil, 0, err
}
tx = tx.Where("logs.model_name LIKE ? ESCAPE '!'", modelNamePattern)
}
if tokenName != "" {
tx = tx.Where("logs.token_name = ?", tokenName)
@@ -421,9 +402,6 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
if requestId != "" {
tx = tx.Where("logs.request_id = ?", requestId)
}
if upstreamRequestId != "" {
tx = tx.Where("logs.upstream_request_id = ?", upstreamRequestId)
}
if startTimestamp != 0 {
tx = tx.Where("logs.created_at >= ?", startTimestamp)
}
@@ -460,11 +438,9 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
// 为rpm和tpm创建单独的查询
rpmTpmQuery := LOG_DB.Table("logs").Select("count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm")
if tx, err = applyExplicitLogTextFilter(tx, "username", username); err != nil {
return stat, err
}
if rpmTpmQuery, err = applyExplicitLogTextFilter(rpmTpmQuery, "username", username); err != nil {
return stat, err
if username != "" {
tx = tx.Where("username = ?", username)
rpmTpmQuery = rpmTpmQuery.Where("username = ?", username)
}
if tokenName != "" {
tx = tx.Where("token_name = ?", tokenName)
@@ -476,11 +452,13 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
if endTimestamp != 0 {
tx = tx.Where("created_at <= ?", endTimestamp)
}
if tx, err = applyExplicitLogTextFilter(tx, "model_name", modelName); err != nil {
return stat, err
}
if rpmTpmQuery, err = applyExplicitLogTextFilter(rpmTpmQuery, "model_name", modelName); err != nil {
return stat, err
if modelName != "" {
modelNamePattern, err := sanitizeLikePattern(modelName)
if err != nil {
return stat, err
}
tx = tx.Where("model_name LIKE ? ESCAPE '!'", modelNamePattern)
rpmTpmQuery = rpmTpmQuery.Where("model_name LIKE ? ESCAPE '!'", modelNamePattern)
}
if channel != 0 {
tx = tx.Where("channel_id = ?", channel)
-4
View File
@@ -397,10 +397,8 @@ func ensureSubscriptionPlanTableSQLite() error {
` + "`custom_seconds`" + ` bigint NOT NULL DEFAULT 0,
` + "`enabled`" + ` numeric DEFAULT 1,
` + "`sort_order`" + ` integer DEFAULT 0,
` + "`allow_balance_pay`" + ` numeric DEFAULT 1,
` + "`stripe_price_id`" + ` varchar(128) DEFAULT '',
` + "`creem_product_id`" + ` varchar(128) DEFAULT '',
` + "`waffo_pancake_product_id`" + ` varchar(128) DEFAULT '',
` + "`max_purchase_per_user`" + ` integer DEFAULT 0,
` + "`upgrade_group`" + ` varchar(64) DEFAULT '',
` + "`total_amount`" + ` bigint NOT NULL DEFAULT 0,
@@ -432,10 +430,8 @@ PRIMARY KEY (` + "`id`" + `)
{Name: "custom_seconds", DDL: "`custom_seconds` bigint NOT NULL DEFAULT 0"},
{Name: "enabled", DDL: "`enabled` numeric DEFAULT 1"},
{Name: "sort_order", DDL: "`sort_order` integer DEFAULT 0"},
{Name: "allow_balance_pay", DDL: "`allow_balance_pay` numeric DEFAULT 1"},
{Name: "stripe_price_id", DDL: "`stripe_price_id` varchar(128) DEFAULT ''"},
{Name: "creem_product_id", DDL: "`creem_product_id` varchar(128) DEFAULT ''"},
{Name: "waffo_pancake_product_id", DDL: "`waffo_pancake_product_id` varchar(128) DEFAULT ''"},
{Name: "max_purchase_per_user", DDL: "`max_purchase_per_user` integer DEFAULT 0"},
{Name: "upgrade_group", DDL: "`upgrade_group` varchar(64) DEFAULT ''"},
{Name: "total_amount", DDL: "`total_amount` bigint NOT NULL DEFAULT 0"},
-57
View File
@@ -2,7 +2,6 @@ package model
import (
"strconv"
"strings"
"github.com/QuantumNous/new-api/common"
@@ -136,62 +135,6 @@ func GetBoundChannelsByModelsMap(modelNames []string) (map[string][]BoundChannel
return result, nil
}
func normalizeLookupValues(values []string) []string {
seen := make(map[string]struct{}, len(values))
normalized := make([]string, 0, len(values))
for _, value := range values {
value = strings.TrimSpace(value)
if value == "" {
continue
}
if _, ok := seen[value]; ok {
continue
}
seen[value] = struct{}{}
normalized = append(normalized, value)
}
return normalized
}
func GetPreferredModelOwnerChannelTypes(modelNames []string, groups []string) (map[string]int, error) {
result := make(map[string]int)
modelNames = normalizeLookupValues(modelNames)
if len(modelNames) == 0 {
return result, nil
}
type row struct {
Model string
ChannelType int
}
var rows []row
query := DB.Table("abilities").
Select("abilities.model as model, channels.type as channel_type").
Joins("JOIN channels ON abilities.channel_id = channels.id").
Where("abilities.model IN ? AND abilities.enabled = ? AND channels.status = ?", modelNames, true, common.ChannelStatusEnabled).
Order("COALESCE(abilities.priority, 0) DESC").
Order("abilities.weight DESC").
Order("abilities.channel_id ASC")
groups = normalizeLookupValues(groups)
if len(groups) > 0 {
query = query.Where("abilities."+commonGroupCol+" IN ?", groups)
}
if err := query.Scan(&rows).Error; err != nil {
return nil, err
}
for _, r := range rows {
if _, ok := result[r.Model]; ok {
continue
}
result[r.Model] = r.ChannelType
}
return result, nil
}
func SearchModels(keyword string, vendor string, offset int, limit int) ([]*Model, int64, error) {
var models []*Model
db := DB.Model(&Model{})
-141
View File
@@ -1,141 +0,0 @@
package model
import (
"fmt"
"testing"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/stretchr/testify/require"
)
func clearPreferredOwnerTables(t *testing.T) {
t.Helper()
require.NoError(t, DB.Exec("DELETE FROM abilities").Error)
require.NoError(t, DB.Exec("DELETE FROM channels").Error)
}
func insertPreferredOwnerCandidate(
t *testing.T,
channelID int,
modelName string,
group string,
channelType int,
priority int64,
weight uint,
channelStatus int,
abilityEnabled bool,
) {
t.Helper()
require.NoError(t, DB.Create(&Channel{
Id: channelID,
Type: channelType,
Key: fmt.Sprintf("key-%d", channelID),
Status: channelStatus,
Name: fmt.Sprintf("channel-%d", channelID),
}).Error)
require.NoError(t, DB.Create(&Ability{
Group: group,
Model: modelName,
ChannelId: channelID,
Enabled: abilityEnabled,
Priority: &priority,
Weight: weight,
}).Error)
}
func TestGetPreferredModelOwnerChannelTypes(t *testing.T) {
const modelName = "gpt-5.4"
tests := []struct {
name string
setup func(t *testing.T)
groups []string
expected int
found bool
}{
{
name: "openai only",
setup: func(t *testing.T) {
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeOpenAI, 0, 0, common.ChannelStatusEnabled, true)
},
groups: []string{"default"},
expected: constant.ChannelTypeOpenAI,
found: true,
},
{
name: "codex only",
setup: func(t *testing.T) {
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeCodex, 0, 0, common.ChannelStatusEnabled, true)
},
groups: []string{"default"},
expected: constant.ChannelTypeCodex,
found: true,
},
{
name: "priority wins",
setup: func(t *testing.T) {
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeOpenAI, 1, 100, common.ChannelStatusEnabled, true)
insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeCodex, 2, 0, common.ChannelStatusEnabled, true)
},
groups: []string{"default"},
expected: constant.ChannelTypeCodex,
found: true,
},
{
name: "weight wins when priority is equal",
setup: func(t *testing.T) {
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeOpenAI, 1, 10, common.ChannelStatusEnabled, true)
insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeCodex, 1, 20, common.ChannelStatusEnabled, true)
},
groups: []string{"default"},
expected: constant.ChannelTypeCodex,
found: true,
},
{
name: "channel id stabilizes exact ties",
setup: func(t *testing.T) {
insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeCodex, 1, 10, common.ChannelStatusEnabled, true)
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeOpenAI, 1, 10, common.ChannelStatusEnabled, true)
},
groups: []string{"default"},
expected: constant.ChannelTypeOpenAI,
found: true,
},
{
name: "group filter excludes other groups",
setup: func(t *testing.T) {
insertPreferredOwnerCandidate(t, 1, modelName, "vip", constant.ChannelTypeCodex, 10, 100, common.ChannelStatusEnabled, true)
insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeOpenAI, 1, 0, common.ChannelStatusEnabled, true)
},
groups: []string{"default"},
expected: constant.ChannelTypeOpenAI,
found: true,
},
{
name: "disabled candidates are ignored",
setup: func(t *testing.T) {
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeCodex, 10, 100, common.ChannelStatusEnabled, false)
insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeOpenAI, 1, 0, common.ChannelStatusManuallyDisabled, true)
},
groups: []string{"default"},
found: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
clearPreferredOwnerTables(t)
tt.setup(t)
owners, err := GetPreferredModelOwnerChannelTypes([]string{modelName}, tt.groups)
require.NoError(t, err)
got, ok := owners[modelName]
require.Equal(t, tt.found, ok)
if tt.found {
require.Equal(t, tt.expected, got)
}
})
}
}
+20 -39
View File
@@ -12,7 +12,6 @@ import (
"github.com/QuantumNous/new-api/setting/performance_setting"
"github.com/QuantumNous/new-api/setting/ratio_setting"
"github.com/QuantumNous/new-api/setting/system_setting"
"gorm.io/gorm"
)
type Option struct {
@@ -107,13 +106,18 @@ func InitOptionMap() {
common.OptionMap["WaffoUnitPrice"] = strconv.FormatFloat(setting.WaffoUnitPrice, 'f', -1, 64)
common.OptionMap["WaffoMinTopUp"] = strconv.Itoa(setting.WaffoMinTopUp)
common.OptionMap["WaffoPayMethods"] = setting.WaffoPayMethods2JsonString()
common.OptionMap["WaffoPancakeEnabled"] = strconv.FormatBool(setting.WaffoPancakeEnabled)
common.OptionMap["WaffoPancakeSandbox"] = strconv.FormatBool(setting.WaffoPancakeSandbox)
common.OptionMap["WaffoPancakeMerchantID"] = setting.WaffoPancakeMerchantID
common.OptionMap["WaffoPancakePrivateKey"] = setting.WaffoPancakePrivateKey
common.OptionMap["WaffoPancakeReturnURL"] = setting.WaffoPancakeReturnURL
common.OptionMap["WaffoPancakeUnitPrice"] = strconv.FormatFloat(setting.WaffoPancakeUnitPrice, 'f', -1, 64)
common.OptionMap["WaffoPancakeMinTopUp"] = strconv.Itoa(setting.WaffoPancakeMinTopUp)
common.OptionMap["WaffoPancakeWebhookPublicKey"] = setting.WaffoPancakeWebhookPublicKey
common.OptionMap["WaffoPancakeWebhookTestKey"] = setting.WaffoPancakeWebhookTestKey
common.OptionMap["WaffoPancakeStoreID"] = setting.WaffoPancakeStoreID
common.OptionMap["WaffoPancakeProductID"] = setting.WaffoPancakeProductID
common.OptionMap["WaffoPancakeReturnURL"] = setting.WaffoPancakeReturnURL
common.OptionMap["WaffoPancakeCurrency"] = setting.WaffoPancakeCurrency
common.OptionMap["WaffoPancakeUnitPrice"] = strconv.FormatFloat(setting.WaffoPancakeUnitPrice, 'f', -1, 64)
common.OptionMap["WaffoPancakeMinTopUp"] = strconv.Itoa(setting.WaffoPancakeMinTopUp)
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
common.OptionMap["Chats"] = setting.Chats2JsonString()
common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString()
@@ -218,39 +222,6 @@ func UpdateOption(key string, value string) error {
return updateOptionMap(key, value)
}
// UpdateOptionsBulk persists multiple key/value pairs in a single database
// transaction, then dispatches them through updateOptionMap in one pass. If
// any DB write fails the whole transaction rolls back and no in-memory state
// is touched — safe for callers that must commit a set of related options
// atomically (e.g. payment gateway binding).
func UpdateOptionsBulk(values map[string]string) error {
if len(values) == 0 {
return nil
}
err := DB.Transaction(func(tx *gorm.DB) error {
for k, v := range values {
option := Option{Key: k}
if err := tx.FirstOrCreate(&option, Option{Key: k}).Error; err != nil {
return err
}
option.Value = v
if err := tx.Save(&option).Error; err != nil {
return err
}
}
return nil
})
if err != nil {
return err
}
for k, v := range values {
if err := updateOptionMap(k, v); err != nil {
return err
}
}
return nil
}
func updateOptionMap(key string, value string) (err error) {
common.OptionMapRWMutex.Lock()
defer common.OptionMapRWMutex.Unlock()
@@ -448,16 +419,26 @@ func updateOptionMap(key string, value string) (err error) {
setting.WaffoUnitPrice, _ = strconv.ParseFloat(value, 64)
case "WaffoMinTopUp":
setting.WaffoMinTopUp, _ = strconv.Atoi(value)
case "WaffoPancakeEnabled":
setting.WaffoPancakeEnabled = value == "true"
case "WaffoPancakeSandbox":
setting.WaffoPancakeSandbox = value == "true"
case "WaffoPancakeMerchantID":
setting.WaffoPancakeMerchantID = value
case "WaffoPancakePrivateKey":
setting.WaffoPancakePrivateKey = value
case "WaffoPancakeReturnURL":
setting.WaffoPancakeReturnURL = value
case "WaffoPancakeWebhookPublicKey":
setting.WaffoPancakeWebhookPublicKey = value
case "WaffoPancakeWebhookTestKey":
setting.WaffoPancakeWebhookTestKey = value
case "WaffoPancakeStoreID":
setting.WaffoPancakeStoreID = value
case "WaffoPancakeProductID":
setting.WaffoPancakeProductID = value
case "WaffoPancakeReturnURL":
setting.WaffoPancakeReturnURL = value
case "WaffoPancakeCurrency":
setting.WaffoPancakeCurrency = value
case "WaffoPancakeUnitPrice":
setting.WaffoPancakeUnitPrice, _ = strconv.ParseFloat(value, 64)
case "WaffoPancakeMinTopUp":
+10 -17
View File
@@ -37,13 +37,13 @@ func UpsertPerfMetric(metric *PerfMetric) error {
{Name: "bucket_ts"},
},
DoUpdates: clause.Assignments(map[string]interface{}{
"request_count": gorm.Expr("perf_metrics.request_count + ?", metric.RequestCount),
"success_count": gorm.Expr("perf_metrics.success_count + ?", metric.SuccessCount),
"total_latency_ms": gorm.Expr("perf_metrics.total_latency_ms + ?", metric.TotalLatencyMs),
"ttft_sum_ms": gorm.Expr("perf_metrics.ttft_sum_ms + ?", metric.TtftSumMs),
"ttft_count": gorm.Expr("perf_metrics.ttft_count + ?", metric.TtftCount),
"output_tokens": gorm.Expr("perf_metrics.output_tokens + ?", metric.OutputTokens),
"generation_ms": gorm.Expr("perf_metrics.generation_ms + ?", metric.GenerationMs),
"request_count": gorm.Expr("request_count + ?", metric.RequestCount),
"success_count": gorm.Expr("success_count + ?", metric.SuccessCount),
"total_latency_ms": gorm.Expr("total_latency_ms + ?", metric.TotalLatencyMs),
"ttft_sum_ms": gorm.Expr("ttft_sum_ms + ?", metric.TtftSumMs),
"ttft_count": gorm.Expr("ttft_count + ?", metric.TtftCount),
"output_tokens": gorm.Expr("output_tokens + ?", metric.OutputTokens),
"generation_ms": gorm.Expr("generation_ms + ?", metric.GenerationMs),
}),
}).Create(metric).Error
}
@@ -68,18 +68,11 @@ type PerfMetricSummary struct {
GenerationMs int64 `json:"generation_ms"`
}
func GetPerfMetricsSummaryAll(startTs int64, endTs int64, groups []string) ([]PerfMetricSummary, error) {
func GetPerfMetricsSummaryAll(startTs int64, endTs int64) ([]PerfMetricSummary, error) {
var summaries []PerfMetricSummary
query := DB.Model(&PerfMetric{}).
err := DB.Model(&PerfMetric{}).
Select("model_name, SUM(request_count) as request_count, SUM(success_count) as success_count, SUM(total_latency_ms) as total_latency_ms, SUM(output_tokens) as output_tokens, SUM(generation_ms) as generation_ms").
Where("bucket_ts >= ? AND bucket_ts <= ?", startTs, endTs)
if groups != nil {
if len(groups) == 0 {
return summaries, nil
}
query = query.Where(commonGroupCol+" IN ?", groups)
}
err := query.
Where("bucket_ts >= ? AND bucket_ts <= ?", startTs, endTs).
Group("model_name").
Having("SUM(request_count) > 0").
Find(&summaries).Error
+2 -117
View File
@@ -11,7 +11,6 @@ import (
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/pkg/cachex"
"github.com/samber/hot"
"github.com/shopspring/decimal"
"gorm.io/gorm"
)
@@ -160,11 +159,8 @@ type SubscriptionPlan struct {
Enabled bool `json:"enabled" gorm:"default:true"`
SortOrder int `json:"sort_order" gorm:"type:int;default:0"`
AllowBalancePay *bool `json:"allow_balance_pay" gorm:"default:true"`
StripePriceId string `json:"stripe_price_id" gorm:"type:varchar(128);default:''"`
CreemProductId string `json:"creem_product_id" gorm:"type:varchar(128);default:''"`
WaffoPancakeProductId string `json:"waffo_pancake_product_id" gorm:"type:varchar(128);default:''"`
StripePriceId string `json:"stripe_price_id" gorm:"type:varchar(128);default:''"`
CreemProductId string `json:"creem_product_id" gorm:"type:varchar(128);default:''"`
// Max purchases per user (0 = unlimited)
MaxPurchasePerUser int `json:"max_purchase_per_user" gorm:"type:int;default:0"`
@@ -195,12 +191,6 @@ func (p *SubscriptionPlan) BeforeUpdate(tx *gorm.DB) error {
return nil
}
func (p *SubscriptionPlan) NormalizeDefaults() {
if p.AllowBalancePay == nil {
p.AllowBalancePay = common.GetPointer(true)
}
}
// Subscription order (payment -> webhook -> create UserSubscription)
type SubscriptionOrder struct {
Id int `json:"id"`
@@ -368,7 +358,6 @@ func getSubscriptionPlanByIdTx(tx *gorm.DB, id int) (*SubscriptionPlan, error) {
key := subscriptionPlanCacheKey(id)
if key != "" {
if cached, found, err := getSubscriptionPlanCache().Get(key); err == nil && found {
cached.NormalizeDefaults()
return &cached, nil
}
}
@@ -380,7 +369,6 @@ func getSubscriptionPlanByIdTx(tx *gorm.DB, id int) (*SubscriptionPlan, error) {
if err := query.Where("id = ?", id).First(&plan).Error; err != nil {
return nil, err
}
plan.NormalizeDefaults()
_ = getSubscriptionPlanCache().SetWithTTL(key, plan, subscriptionPlanCacheTTL())
return &plan, nil
}
@@ -676,109 +664,6 @@ func AdminBindSubscription(userId int, planId int, sourceNote string) (string, e
return "", nil
}
func calcSubscriptionBalanceQuota(priceAmount float64) (int, error) {
if priceAmount <= 0 {
return 0, nil
}
if common.QuotaPerUnit <= 0 {
return 0, errors.New("额度单位配置错误")
}
quota := decimal.NewFromFloat(priceAmount).
Mul(decimal.NewFromFloat(common.QuotaPerUnit)).
Ceil().
IntPart()
return int(quota), nil
}
// PurchaseSubscriptionWithBalance creates a subscription by deducting the user's wallet quota.
func PurchaseSubscriptionWithBalance(userId int, planId int) error {
if userId <= 0 || planId <= 0 {
return errors.New("invalid userId or planId")
}
var logPlanTitle string
var logMoney float64
var chargedQuota int
var upgradeGroup string
err := DB.Transaction(func(tx *gorm.DB) error {
plan, err := getSubscriptionPlanByIdTx(tx, planId)
if err != nil {
return err
}
if !plan.Enabled {
return errors.New("套餐未启用")
}
if plan.PriceAmount < 0 {
return errors.New("套餐价格不能为负数")
}
if plan.AllowBalancePay != nil && !*plan.AllowBalancePay {
return errors.New("该套餐不允许使用余额兑换")
}
requiredQuota, err := calcSubscriptionBalanceQuota(plan.PriceAmount)
if err != nil {
return err
}
var user User
if err := tx.Set("gorm:query_option", "FOR UPDATE").Where("id = ?", userId).First(&user).Error; err != nil {
return err
}
if requiredQuota > 0 && user.Quota < requiredQuota {
return errors.New("余额不足")
}
if requiredQuota > 0 {
if err := tx.Model(&User{}).Where("id = ?", userId).
Update("quota", gorm.Expr("quota - ?", requiredQuota)).Error; err != nil {
return err
}
}
if _, err := CreateUserSubscriptionFromPlanTx(tx, userId, plan, PaymentMethodBalance); err != nil {
return err
}
now := common.GetTimestamp()
tradeNo := fmt.Sprintf("SUBBALUSR%dNO%s%d", userId, common.GetRandomString(6), time.Now().UnixNano())
order := &SubscriptionOrder{
UserId: userId,
PlanId: plan.Id,
Money: plan.PriceAmount,
TradeNo: tradeNo,
PaymentMethod: PaymentMethodBalance,
PaymentProvider: PaymentProviderBalance,
Status: common.TopUpStatusSuccess,
CreateTime: now,
CompleteTime: now,
ProviderPayload: fmt.Sprintf("charged_quota=%d", requiredQuota),
}
if err := tx.Create(order).Error; err != nil {
return err
}
logPlanTitle = plan.Title
logMoney = plan.PriceAmount
chargedQuota = requiredQuota
upgradeGroup = strings.TrimSpace(plan.UpgradeGroup)
return nil
})
if err != nil {
return err
}
if chargedQuota > 0 {
if err := cacheDecrUserQuota(userId, int64(chargedQuota)); err != nil {
common.SysLog("failed to decrease user quota cache after subscription balance purchase: " + err.Error())
}
}
if upgradeGroup != "" {
_ = UpdateUserGroupCache(userId, upgradeGroup)
}
msg := fmt.Sprintf("使用余额购买订阅成功,套餐: %s,支付金额: %.2f,扣除额度: %d", logPlanTitle, logMoney, chargedQuota)
RecordLog(userId, LogTypeTopup, msg)
return nil
}
// GetAllActiveUserSubscriptions returns all active subscriptions for a user.
func GetAllActiveUserSubscriptions(userId int) ([]SubscriptionSummary, error) {
if userId <= 0 {
-5
View File
@@ -26,7 +26,6 @@ func TestMain(m *testing.M) {
common.RedisEnabled = false
common.BatchUpdateEnabled = false
common.LogConsumeEnabled = true
initCol()
sqlDB, err := db.DB()
if err != nil {
@@ -40,12 +39,10 @@ func TestMain(m *testing.M) {
&Token{},
&Log{},
&Channel{},
&Ability{},
&TopUp{},
&SubscriptionPlan{},
&SubscriptionOrder{},
&UserSubscription{},
&PerfMetric{},
); err != nil {
panic("failed to migrate: " + err.Error())
}
@@ -61,12 +58,10 @@ func truncateTables(t *testing.T) {
DB.Exec("DELETE FROM tokens")
DB.Exec("DELETE FROM logs")
DB.Exec("DELETE FROM channels")
DB.Exec("DELETE FROM abilities")
DB.Exec("DELETE FROM top_ups")
DB.Exec("DELETE FROM subscription_orders")
DB.Exec("DELETE FROM subscription_plans")
DB.Exec("DELETE FROM user_subscriptions")
DB.Exec("DELETE FROM perf_metrics")
})
}
-2
View File
@@ -29,7 +29,6 @@ const (
PaymentMethodCreem = "creem"
PaymentMethodWaffo = "waffo"
PaymentMethodWaffoPancake = "waffo_pancake"
PaymentMethodBalance = "balance"
)
const (
@@ -38,7 +37,6 @@ const (
PaymentProviderCreem = "creem"
PaymentProviderWaffo = "waffo"
PaymentProviderWaffoPancake = "waffo_pancake"
PaymentProviderBalance = "balance"
)
var (
+21 -36
View File
@@ -11,7 +11,6 @@ import (
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm"
@@ -35,8 +34,8 @@ type User struct {
OidcId string `json:"oidc_id" gorm:"column:oidc_id;index"`
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
TelegramId string `json:"telegram_id" gorm:"column:telegram_id;index"`
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
AccessToken *string `json:"-" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
AccessToken *string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
Quota int `json:"quota" gorm:"type:int;default:0"`
UsedQuota int `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota
RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number
@@ -225,7 +224,7 @@ func GetAllUsers(pageInfo *common.PageInfo) (users []*User, total int64, err err
return users, total, nil
}
func SearchUsers(keyword string, group string, role *int, status *int, startIdx int, num int) ([]*User, int64, error) {
func SearchUsers(keyword string, group string, startIdx int, num int) ([]*User, int64, error) {
var users []*User
var total int64
var err error
@@ -246,25 +245,28 @@ func SearchUsers(keyword string, group string, role *int, status *int, startIdx
// 构建搜索条件
likeCondition := "username LIKE ? OR email LIKE ? OR display_name LIKE ?"
likeArgs := []interface{}{"%" + keyword + "%", "%" + keyword + "%", "%" + keyword + "%"}
// 尝试将关键字转换为整数ID
keywordInt, err := strconv.Atoi(keyword)
if err == nil {
// 如果是数字,同时搜索ID和其他字段
likeCondition = "id = ? OR " + likeCondition
likeArgs = append([]interface{}{keywordInt}, likeArgs...)
}
query = query.Where("("+likeCondition+")", likeArgs...)
if group != "" {
query = query.Where(commonGroupCol+" = ?", group)
}
if role != nil {
query = query.Where("role = ?", *role)
}
if status != nil {
query = query.Where("status = ?", *status)
if group != "" {
query = query.Where("("+likeCondition+") AND "+commonGroupCol+" = ?",
keywordInt, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
} else {
query = query.Where(likeCondition,
keywordInt, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
}
} else {
// 非数字关键字,只搜索字符串字段
if group != "" {
query = query.Where("("+likeCondition+") AND "+commonGroupCol+" = ?",
"%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
} else {
query = query.Where(likeCondition,
"%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
}
}
// 获取总数
@@ -418,7 +420,7 @@ func (user *User) Insert(inviterId int) error {
if common.QuotaForNewUser > 0 {
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", logger.LogQuota(common.QuotaForNewUser)))
}
if inviterId != 0 && operation_setting.IsPaymentComplianceConfirmed() {
if inviterId != 0 {
if common.QuotaForInvitee > 0 {
_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee, true)
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", logger.LogQuota(common.QuotaForInvitee)))
@@ -479,7 +481,7 @@ func (user *User) FinalizeOAuthUserCreation(inviterId int) {
if common.QuotaForNewUser > 0 {
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", logger.LogQuota(common.QuotaForNewUser)))
}
if inviterId != 0 && operation_setting.IsPaymentComplianceConfirmed() {
if inviterId != 0 {
if common.QuotaForInvitee > 0 {
_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee, true)
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", logger.LogQuota(common.QuotaForInvitee)))
@@ -984,23 +986,6 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
//}
}
func updateUserQuotaUsedQuotaAndRequestCount(id int, quota int, usedQuota int, requestCount int) {
if quota == 0 && usedQuota == 0 && requestCount == 0 {
return
}
err := DB.Model(&User{}).Where("id = ?", id).Updates(
map[string]interface{}{
"quota": gorm.Expr("quota + ?", quota),
"used_quota": gorm.Expr("used_quota + ?", usedQuota),
"request_count": gorm.Expr("request_count + ?", requestCount),
},
).Error
if err != nil {
common.SysLog("failed to batch update user quota, used quota and request count: " + err.Error())
}
}
func updateUserUsedQuota(id int, quota int) {
err := DB.Model(&User{}).Where("id = ?", id).Updates(
map[string]interface{}{
+11 -26
View File
@@ -67,48 +67,33 @@ func batchUpdate() {
}
common.SysLog("batch update started")
stores := make([]map[int]int, BatchUpdateTypeCount)
for i := 0; i < BatchUpdateTypeCount; i++ {
batchUpdateLocks[i].Lock()
stores[i] = batchUpdateStores[i]
store := batchUpdateStores[i]
batchUpdateStores[i] = make(map[int]int)
batchUpdateLocks[i].Unlock()
}
for i, store := range stores {
if i == BatchUpdateTypeUserQuota || i == BatchUpdateTypeUsedQuota || i == BatchUpdateTypeRequestCount {
continue
}
// TODO: maybe we can combine updates with same key?
for key, value := range store {
switch i {
case BatchUpdateTypeUserQuota:
err := increaseUserQuota(key, value)
if err != nil {
common.SysLog("failed to batch update user quota: " + err.Error())
}
case BatchUpdateTypeTokenQuota:
err := increaseTokenQuota(key, value)
if err != nil {
common.SysLog("failed to batch update token quota: " + err.Error())
}
case BatchUpdateTypeUsedQuota:
updateUserUsedQuota(key, value)
case BatchUpdateTypeRequestCount:
updateUserRequestCount(key, value)
case BatchUpdateTypeChannelUsedQuota:
updateChannelUsedQuota(key, value)
}
}
}
userQuotaStore := stores[BatchUpdateTypeUserQuota]
usedQuotaStore := stores[BatchUpdateTypeUsedQuota]
requestCountStore := stores[BatchUpdateTypeRequestCount]
userIDs := make(map[int]struct{}, len(userQuotaStore)+len(usedQuotaStore)+len(requestCountStore))
for key := range userQuotaStore {
userIDs[key] = struct{}{}
}
for key := range usedQuotaStore {
userIDs[key] = struct{}{}
}
for key := range requestCountStore {
userIDs[key] = struct{}{}
}
for key := range userIDs {
updateUserQuotaUsedQuotaAndRequestCount(key, userQuotaStore[key], usedQuotaStore[key], requestCountStore[key])
}
common.SysLog("batch update finished")
}
+2 -19
View File
@@ -122,7 +122,7 @@ func Query(params QueryParams) (QueryResult, error) {
return buildQueryResult(params.Model, merged), nil
}
func QuerySummaryAll(hours int, groups []string) (SummaryAllResult, error) {
func QuerySummaryAll(hours int) (SummaryAllResult, error) {
if hours <= 0 {
hours = 24
}
@@ -131,9 +131,8 @@ func QuerySummaryAll(hours int, groups []string) (SummaryAllResult, error) {
}
endTs := time.Now().Unix()
startTs := endTs - int64(hours)*3600
allowedGroups := allowedGroupSet(groups)
rows, err := model.GetPerfMetricsSummaryAll(startTs, endTs, groups)
rows, err := model.GetPerfMetricsSummaryAll(startTs, endTs)
if err != nil {
return SummaryAllResult{}, err
}
@@ -154,11 +153,6 @@ func QuerySummaryAll(hours int, groups []string) (SummaryAllResult, error) {
if k.bucketTs < startTs || k.bucketTs > endTs {
return true
}
if allowedGroups != nil {
if _, ok := allowedGroups[k.group]; !ok {
return true
}
}
snap := value.(*atomicBucket).snapshot()
if snap.requestCount == 0 {
return true
@@ -199,17 +193,6 @@ func QuerySummaryAll(hours int, groups []string) (SummaryAllResult, error) {
return SummaryAllResult{Models: models}, nil
}
func allowedGroupSet(groups []string) map[string]struct{} {
if groups == nil {
return nil
}
allowed := make(map[string]struct{}, len(groups))
for _, group := range groups {
allowed[group] = struct{}{}
}
return allowed
}
func bucketStart(ts int64) int64 {
bucketSeconds := perf_metrics_setting.GetBucketSeconds()
if bucketSeconds <= 0 {
+4 -3
View File
@@ -229,7 +229,7 @@ func asyncTaskWait(c *gin.Context, info *relaycommon.RelayInfo, taskID string) (
time.Sleep(time.Duration(5) * time.Second)
for {
logger.LogDebug(c, "asyncTaskWait step %d/%d, wait %d seconds", step, maxStep, waitSeconds)
logger.LogDebug(c, fmt.Sprintf("asyncTaskWait step %d/%d, wait %d seconds", step, maxStep, waitSeconds))
step++
rsp, err, body := updateTask(info, taskID)
responseBody = body
@@ -320,10 +320,11 @@ func aliImageHandler(a *Adaptor, c *gin.Context, resp *http.Response, info *rela
}
}
//logger.LogDebug(c, "ali_async_task_result: "+string(originRespBody))
if a.IsSyncImageModel {
logger.LogDebug(c, "ali_sync_image_result: %s", originRespBody)
logger.LogDebug(c, "ali_sync_image_result: "+string(originRespBody))
} else {
logger.LogDebug(c, "ali_async_image_result: %s", originRespBody)
logger.LogDebug(c, "ali_async_image_result: "+string(originRespBody))
}
imageResponses := responseAli2OpenAIImage(c, aliResponse, originRespBody, info, responseFormat)
+30 -34
View File
@@ -25,23 +25,6 @@ import (
"github.com/gorilla/websocket"
)
// applyUpstreamContentLength populates req.ContentLength when the upstream
// body is wrapped in a BodyStorage (see relay/common/outbound_body.go).
//
// net/http.NewRequest only auto-detects ContentLength for *bytes.Reader,
// *bytes.Buffer and *strings.Reader. When the body is a type-erased io.Reader
// (which is the case for ReaderOnly(BodyStorage)), the Content-Length header
// would otherwise be omitted, forcing chunked transfer encoding and breaking
// some upstreams that require an explicit Content-Length.
func applyUpstreamContentLength(req *http.Request, info *common.RelayInfo) {
if info == nil {
return
}
if info.UpstreamRequestBodySize > 0 && req.ContentLength <= 0 {
req.ContentLength = info.UpstreamRequestBodySize
}
}
func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Header) {
if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
// multipart/form-data
@@ -309,12 +292,13 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
if err != nil {
return nil, fmt.Errorf("get request url failed: %w", err)
}
logger.LogDebug(c, "fullRequestURL: %s", fullRequestURL)
if common2.DebugEnabled {
println("fullRequestURL:", fullRequestURL)
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return nil, fmt.Errorf("new request failed: %w", err)
}
applyUpstreamContentLength(req, info)
headers := req.Header
err = a.SetupRequestHeader(c, &headers, info)
if err != nil {
@@ -339,12 +323,13 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod
if err != nil {
return nil, fmt.Errorf("get request url failed: %w", err)
}
logger.LogDebug(c, "fullRequestURL: %s", fullRequestURL)
if common2.DebugEnabled {
println("fullRequestURL:", fullRequestURL)
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return nil, fmt.Errorf("new request failed: %w", err)
}
applyUpstreamContentLength(req, info)
// set form data
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
headers := req.Header
@@ -403,9 +388,13 @@ func startPingKeepAlive(c *gin.Context, pingInterval time.Duration) context.Canc
defer func() {
// 增加panic恢复处理
if r := recover(); r != nil {
logger.LogDebug(c, "SSE ping goroutine panic recovered: %v", r)
if common2.DebugEnabled {
println("SSE ping goroutine panic recovered:", fmt.Sprintf("%v", r))
}
}
if common2.DebugEnabled {
println("SSE ping goroutine stopped.")
}
logger.LogDebug(c, "SSE ping goroutine stopped")
}()
if pingInterval <= 0 {
@@ -416,11 +405,15 @@ func startPingKeepAlive(c *gin.Context, pingInterval time.Duration) context.Canc
// 确保在任何情况下都清理ticker
defer func() {
ticker.Stop()
logger.LogDebug(c, "SSE ping ticker stopped")
if common2.DebugEnabled {
println("SSE ping ticker stopped")
}
}()
var pingMutex sync.Mutex
logger.LogDebug(c, "SSE ping goroutine started")
if common2.DebugEnabled {
println("SSE ping goroutine started")
}
// 增加超时控制,防止goroutine长时间运行
maxPingDuration := 120 * time.Minute // 最大ping持续时间
@@ -432,7 +425,9 @@ func startPingKeepAlive(c *gin.Context, pingInterval time.Duration) context.Canc
// 发送 ping 数据
case <-ticker.C:
if err := sendPingData(c, &pingMutex); err != nil {
logger.LogDebug(c, "SSE ping error, stopping goroutine: %s", err.Error())
if common2.DebugEnabled {
println("SSE ping error, stopping goroutine:", err.Error())
}
return
}
// 收到退出信号
@@ -443,7 +438,9 @@ func startPingKeepAlive(c *gin.Context, pingInterval time.Duration) context.Canc
return
// 超时保护,防止goroutine无限运行
case <-pingTimeout.C:
logger.LogDebug(c, "SSE ping goroutine timeout, stopping")
if common2.DebugEnabled {
println("SSE ping goroutine timeout, stopping")
}
return
}
}
@@ -466,7 +463,9 @@ func sendPingData(c *gin.Context, mutex *sync.Mutex) error {
return
}
logger.LogDebug(c, "SSE ping data sent")
if common2.DebugEnabled {
println("SSE ping data sent.")
}
done <- nil
}()
@@ -508,7 +507,9 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
defer func() {
if stopPinger != nil {
stopPinger()
logger.LogDebug(c, "SSE ping goroutine stopped by defer")
if common2.DebugEnabled {
println("SSE ping goroutine stopped by defer")
}
}
}()
}
@@ -523,10 +524,6 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
return nil, errors.New("resp is nil")
}
if upID := resp.Header.Get(common2.RequestIdKey); upID != "" {
c.Set(common2.UpstreamRequestIdKey, upID)
}
_ = req.Body.Close()
_ = c.Request.Body.Close()
return resp, nil
@@ -541,7 +538,6 @@ func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.RelayInfo, req
if err != nil {
return nil, fmt.Errorf("new request failed: %w", err)
}
applyUpstreamContentLength(req, info)
req.GetBody = func() (io.ReadCloser, error) {
return io.NopCloser(requestBody), nil
}
-6
View File
@@ -19,7 +19,6 @@ var awsModelIDMap = map[string]string{
"claude-opus-4-5-20251101": "anthropic.claude-opus-4-5-20251101-v1:0",
"claude-opus-4-6": "anthropic.claude-opus-4-6-v1",
"claude-opus-4-7": "anthropic.claude-opus-4-7",
"claude-opus-4-8": "anthropic.claude-opus-4-8",
// Nova models
"nova-micro-v1:0": "amazon.nova-micro-v1:0",
"nova-lite-v1:0": "amazon.nova-lite-v1:0",
@@ -98,11 +97,6 @@ var awsModelCanCrossRegionMap = map[string]map[string]bool{
"ap": true,
"eu": true,
},
"anthropic.claude-opus-4-8": {
"us": true,
"ap": true,
"eu": true,
},
"anthropic.claude-haiku-4-5-20251001-v1:0": {
"us": true,
"ap": true,
-7
View File
@@ -33,13 +33,6 @@ var ModelList = []string{
"claude-opus-4-7-medium",
"claude-opus-4-7-low",
"claude-opus-4-7-thinking",
"claude-opus-4-8",
"claude-opus-4-8-max",
"claude-opus-4-8-xhigh",
"claude-opus-4-8-high",
"claude-opus-4-8-medium",
"claude-opus-4-8-low",
"claude-opus-4-8-thinking",
}
var ChannelName = "claude"
+12 -11
View File
@@ -154,17 +154,14 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
}
if baseModel, effortLevel, ok := reasoning.TrimEffortSuffix(textRequest.Model); ok && effortLevel != "" &&
(strings.HasPrefix(textRequest.Model, "claude-opus-4-6") ||
strings.HasPrefix(textRequest.Model, "claude-opus-4-7") ||
strings.HasPrefix(textRequest.Model, "claude-opus-4-8")) {
(strings.HasPrefix(textRequest.Model, "claude-opus-4-6") || strings.HasPrefix(textRequest.Model, "claude-opus-4-7")) {
claudeRequest.Model = baseModel
claudeRequest.Thinking = &dto.Thinking{
Type: "adaptive",
}
claudeRequest.OutputConfig = json.RawMessage(fmt.Sprintf(`{"effort":"%s"}`, effortLevel))
if strings.HasPrefix(baseModel, "claude-opus-4-7") ||
strings.HasPrefix(baseModel, "claude-opus-4-8") {
// Opus 4.7/4.8 reject non-default temperature/top_p/top_k with 400
if strings.HasPrefix(baseModel, "claude-opus-4-7") {
// Opus 4.7 rejects non-default temperature/top_p/top_k with 400
// and defaults display to "omitted"; restore the 4.6 visible summary.
claudeRequest.Thinking.Display = "summarized"
claudeRequest.Temperature = nil
@@ -178,9 +175,8 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
strings.HasSuffix(textRequest.Model, "-thinking") {
trimmedModel := strings.TrimSuffix(textRequest.Model, "-thinking")
if strings.HasPrefix(trimmedModel, "claude-opus-4-7") ||
strings.HasPrefix(trimmedModel, "claude-opus-4-8") {
// Opus 4.7/4.8 reject thinking.type="enabled"; use adaptive at high effort.
if strings.HasPrefix(trimmedModel, "claude-opus-4-7") {
// Opus 4.7 rejects thinking.type="enabled"; use adaptive at high effort.
claudeRequest.Thinking = &dto.Thinking{Type: "adaptive", Display: "summarized"}
claudeRequest.OutputConfig = json.RawMessage(`{"effort":"high"}`)
claudeRequest.Temperature = nil
@@ -446,7 +442,10 @@ func StreamResponseClaude2OpenAI(claudeResponse *dto.ClaudeResponse) *dto.ChatCo
tools := make([]dto.ToolCallResponse, 0)
fcIdx := 0
if claudeResponse.Index != nil {
fcIdx = *claudeResponse.Index
fcIdx = *claudeResponse.Index - 1
if fcIdx < 0 {
fcIdx = 0
}
}
var choice dto.ChatCompletionsStreamResponseChoice
if claudeResponse.Type == "message_start" {
@@ -950,7 +949,9 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
logger.LogDebug(c, "responseBody: %s", responseBody)
if common.DebugEnabled {
println("responseBody: ", string(responseBody))
}
handleErr := HandleClaudeResponseData(c, info, claudeInfo, resp, responseBody)
if handleErr != nil {
return nil, handleErr
-56
View File
@@ -9,10 +9,6 @@ import (
"github.com/stretchr/testify/require"
)
func commonPointer[T any](value T) *T {
return &value
}
func TestFormatClaudeResponseInfo_MessageStart(t *testing.T) {
claudeInfo := &ClaudeResponseInfo{
Usage: &dto.Usage{},
@@ -314,58 +310,6 @@ func TestRequestOpenAI2ClaudeMessage_IgnoresUnsupportedFileContent(t *testing.T)
require.Equal(t, "see attachment", *content[0].Text)
}
func TestRequestOpenAI2ClaudeMessage_ClaudeOpus48HighUsesAdaptiveThinking(t *testing.T) {
request := dto.GeneralOpenAIRequest{
Model: "claude-opus-4-8-high",
Temperature: commonPointer(0.7),
TopP: commonPointer(0.9),
TopK: commonPointer(40),
Messages: []dto.Message{
{
Role: "user",
Content: "hello",
},
},
}
claudeRequest, err := RequestOpenAI2ClaudeMessage(nil, request)
require.NoError(t, err)
require.Equal(t, "claude-opus-4-8", claudeRequest.Model)
require.NotNil(t, claudeRequest.Thinking)
require.Equal(t, "adaptive", claudeRequest.Thinking.Type)
require.Equal(t, "summarized", claudeRequest.Thinking.Display)
require.JSONEq(t, `{"effort":"high"}`, string(claudeRequest.OutputConfig))
require.Nil(t, claudeRequest.Temperature)
require.Nil(t, claudeRequest.TopP)
require.Nil(t, claudeRequest.TopK)
}
func TestRequestOpenAI2ClaudeMessage_ClaudeOpus48ThinkingUsesAdaptiveHighEffort(t *testing.T) {
request := dto.GeneralOpenAIRequest{
Model: "claude-opus-4-8-thinking",
Temperature: commonPointer(0.7),
TopP: commonPointer(0.9),
TopK: commonPointer(40),
Messages: []dto.Message{
{
Role: "user",
Content: "hello",
},
},
}
claudeRequest, err := RequestOpenAI2ClaudeMessage(nil, request)
require.NoError(t, err)
require.Equal(t, "claude-opus-4-8", claudeRequest.Model)
require.NotNil(t, claudeRequest.Thinking)
require.Equal(t, "adaptive", claudeRequest.Thinking.Type)
require.Equal(t, "summarized", claudeRequest.Thinking.Display)
require.JSONEq(t, `{"effort":"high"}`, string(claudeRequest.OutputConfig))
require.Nil(t, claudeRequest.Temperature)
require.Nil(t, claudeRequest.TopP)
require.Nil(t, claudeRequest.TopK)
}
func TestRequestOpenAI2ClaudeMessage_SupportsPDFFileContent(t *testing.T) {
request := dto.GeneralOpenAIRequest{
Model: "claude-3-5-sonnet",
+1 -1
View File
@@ -30,7 +30,7 @@ func convertCf2CompletionsRequest(textRequest dto.GeneralOpenAIRequest) *CfReque
}
func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
scanner := helper.NewStreamScanner(resp.Body)
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
helper.SetEventStreamHeaders(c)
+2 -4
View File
@@ -1,6 +1,7 @@
package cohere
import (
"bufio"
"encoding/json"
"io"
"net/http"
@@ -85,7 +86,7 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
createdTime := common.GetTimestamp()
usage := &dto.Usage{}
responseText := ""
scanner := helper.NewStreamScanner(resp.Body)
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
@@ -105,9 +106,6 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
data := scanner.Text()
dataChan <- data
}
if err := scanner.Err(); err != nil {
common.SysLog("error reading stream: " + err.Error())
}
stopChan <- true
}()
helper.SetEventStreamHeaders(c)
+1 -1
View File
@@ -98,7 +98,7 @@ func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res
}
func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
scanner := helper.NewStreamScanner(resp.Body)
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
helper.SetEventStreamHeaders(c)
id := helper.GetResponseID(c)
+3 -8
View File
@@ -159,14 +159,9 @@ func requestOpenAI2Dify(c *gin.Context, info *relaycommon.RelayInfo, request dto
media := mediaContent.GetImageMedia()
var file *DifyFile
if media.IsRemoteImage() {
// 修复 #2083: 远程图片分支此前未初始化 file,
// 导致 file.Type = ... 触发 nil pointer dereference
// 而 panic500: "invalid memory address or nil pointer dereference")。
file = &DifyFile{
Type: media.MimeType,
TransferMode: "remote_url",
URL: media.Url,
}
file.Type = media.MimeType
file.TransferMode = "remote_url"
file.URL = media.Url
} else {
file = uploadDifyFile(c, info, difyReq.User, mediaContent)
}
+6 -2
View File
@@ -26,7 +26,9 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
logger.LogDebug(c, "Gemini native response body: %s", responseBody)
if common.DebugEnabled {
println(string(responseBody))
}
// 解析为 Gemini 原生响应格式
var geminiResponse dto.GeminiChatResponse
@@ -55,7 +57,9 @@ func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *rel
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
logger.LogDebug(c, "Gemini native embedding response body: %s", responseBody)
if common.DebugEnabled {
println(string(responseBody))
}
usage := service.ResponseText2Usage(c, "", info.UpstreamModelName, info.GetEstimatePromptTokens())
+22 -105
View File
@@ -1079,47 +1079,17 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse)
FinishReason: constant.FinishReasonStop,
}
if len(candidate.Content.Parts) > 0 {
// 使用 strings.Builder 直接累积最终 content,避免:
// 1) 每张 inline image 生成一次中间 "![image](...)" 字符串
// 2) 末尾 strings.Join 再分配一份等大缓冲
// Gemini 图片返回时 InlineData.Data 可能是数 MB 的 base64
// 上述两份临时分配在高并发下会显著放大堆驻留。
var content strings.Builder
var inlineGrow int
for _, part := range candidate.Content.Parts {
if part.InlineData != nil {
inlineGrow += len(part.InlineData.MimeType) + len(part.InlineData.Data) + 32
}
}
if inlineGrow > 0 {
content.Grow(inlineGrow)
}
appended := 0
writeSep := func() {
if appended > 0 {
content.WriteByte('\n')
}
appended++
}
var texts []string
var toolCalls []dto.ToolCallResponse
for _, part := range candidate.Content.Parts {
if part.InlineData != nil {
// 媒体内容
if strings.HasPrefix(part.InlineData.MimeType, "image") {
writeSep()
content.WriteString("![image](data:")
content.WriteString(part.InlineData.MimeType)
content.WriteString(";base64,")
content.WriteString(part.InlineData.Data)
content.WriteByte(')')
imgText := "![image](data:" + part.InlineData.MimeType + ";base64," + part.InlineData.Data + ")"
texts = append(texts, imgText)
} else {
// 其他媒体类型,直接显示链接
writeSep()
content.WriteString("[media](data:")
content.WriteString(part.InlineData.MimeType)
content.WriteString(";base64,")
content.WriteString(part.InlineData.Data)
content.WriteByte(')')
texts = append(texts, fmt.Sprintf("[media](data:%s;base64,%s)", part.InlineData.MimeType, part.InlineData.Data))
}
} else if part.FunctionCall != nil {
choice.FinishReason = constant.FinishReasonToolCalls
@@ -1130,22 +1100,13 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse)
choice.Message.ReasoningContent = &part.Text
} else {
if part.ExecutableCode != nil {
writeSep()
content.WriteString("```")
content.WriteString(part.ExecutableCode.Language)
content.WriteByte('\n')
content.WriteString(part.ExecutableCode.Code)
content.WriteString("\n```")
texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```")
} else if part.CodeExecutionResult != nil {
writeSep()
content.WriteString("```output\n")
content.WriteString(part.CodeExecutionResult.Output)
content.WriteString("\n```")
texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```")
} else {
// 过滤掉空行
if part.Text != "\n" {
writeSep()
content.WriteString(part.Text)
texts = append(texts, part.Text)
}
}
}
@@ -1154,7 +1115,7 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse)
choice.Message.SetToolCalls(toolCalls)
isToolCall = true
}
choice.Message.SetStringContent(content.String())
choice.Message.SetStringContent(strings.Join(texts, "\n"))
}
if candidate.FinishReason != nil {
@@ -1208,25 +1169,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*d
//Role: "assistant",
},
}
// 使用 strings.Builder 直接累积 delta content,避免每张 image / 每个
// 文本片段都先 `+` 拼出一份临时 string,再 strings.Join 再拷贝一遍。
var content strings.Builder
var inlineGrow int
for _, part := range candidate.Content.Parts {
if part.InlineData != nil {
inlineGrow += len(part.InlineData.MimeType) + len(part.InlineData.Data) + 32
}
}
if inlineGrow > 0 {
content.Grow(inlineGrow)
}
appended := 0
writeSep := func() {
if appended > 0 {
content.WriteByte('\n')
}
appended++
}
var texts []string
isTools := false
isThought := false
if candidate.FinishReason != nil {
@@ -1264,12 +1207,8 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*d
for _, part := range candidate.Content.Parts {
if part.InlineData != nil {
if strings.HasPrefix(part.InlineData.MimeType, "image") {
writeSep()
content.WriteString("![image](data:")
content.WriteString(part.InlineData.MimeType)
content.WriteString(";base64,")
content.WriteString(part.InlineData.Data)
content.WriteByte(')')
imgText := "![image](data:" + part.InlineData.MimeType + ";base64," + part.InlineData.Data + ")"
texts = append(texts, imgText)
}
} else if part.FunctionCall != nil {
isTools = true
@@ -1280,33 +1219,23 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*d
} else if part.Thought {
isThought = true
writeSep()
content.WriteString(part.Text)
texts = append(texts, part.Text)
} else {
if part.ExecutableCode != nil {
writeSep()
content.WriteString("```")
content.WriteString(part.ExecutableCode.Language)
content.WriteByte('\n')
content.WriteString(part.ExecutableCode.Code)
content.WriteString("\n```\n")
texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```\n")
} else if part.CodeExecutionResult != nil {
writeSep()
content.WriteString("```output\n")
content.WriteString(part.CodeExecutionResult.Output)
content.WriteString("\n```\n")
texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```\n")
} else {
if part.Text != "\n" {
writeSep()
content.WriteString(part.Text)
texts = append(texts, part.Text)
}
}
}
}
if isThought {
choice.Delta.SetReasoningContent(content.String())
choice.Delta.SetReasoningContent(strings.Join(texts, "\n"))
} else {
choice.Delta.SetContentString(content.String())
choice.Delta.SetContentString(strings.Join(texts, "\n"))
}
if isTools {
choice.FinishReason = &constant.FinishReasonToolCalls
@@ -1410,14 +1339,6 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
response.Id = id
response.Created = createAt
response.Model = info.UpstreamModelName
if response.IsToolCall() {
finishReason = constant.FinishReasonToolCalls
if info.RelayFormat == types.RelayFormatClaude {
for choiceIdx := range response.Choices {
response.Choices[choiceIdx].FinishReason = nil
}
}
}
for choiceIdx := range response.Choices {
choiceKey := response.Choices[choiceIdx].Index
for toolIdx := range response.Choices[choiceIdx].Delta.ToolCalls {
@@ -1441,7 +1362,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
}
}
logger.LogDebug(c, "info.SendResponseCount = %d", info.SendResponseCount)
logger.LogDebug(c, fmt.Sprintf("info.SendResponseCount = %d", info.SendResponseCount))
if info.SendResponseCount == 0 {
// send first response
emptyResponse := helper.GenerateStartEmptyResponse(id, createAt, info.UpstreamModelName, nil)
@@ -1478,9 +1399,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
logger.LogError(c, err.Error())
}
if isStop {
if info.RelayFormat != types.RelayFormatClaude {
_ = handleStream(c, info, helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, finishReason))
}
_ = handleStream(c, info, helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, finishReason))
}
return true
})
@@ -1490,10 +1409,6 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
}
response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
if info.RelayFormat == types.RelayFormatClaude && info.ClaudeConvertInfo != nil && !info.ClaudeConvertInfo.Done {
response = helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, finishReason)
response.Usage = usage
}
handleErr := handleFinalStream(c, info, response)
if handleErr != nil {
common.SysLog("send final response failed: " + handleErr.Error())
@@ -1507,7 +1422,9 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
service.CloseResponseBodyGracefully(resp)
logger.LogDebug(c, "Gemini response body: %s", responseBody)
if common.DebugEnabled {
println(string(responseBody))
}
var geminiResponse dto.GeminiChatResponse
err = common.Unmarshal(responseBody, &geminiResponse)
if err != nil {
-4
View File
@@ -11,7 +11,6 @@ import (
"github.com/QuantumNous/new-api/dto"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
)
@@ -185,9 +184,6 @@ func handleChatCompletionResponse(c *gin.Context, resp *http.Response, info *rel
// Set response headers
for key, values := range resp.Header {
if !service.ShouldCopyUpstreamHeader(c, key, values) {
continue
}
for _, value := range values {
c.Header(key, value)
}
+2 -2
View File
@@ -1,6 +1,7 @@
package ollama
import (
"bufio"
"encoding/json"
"fmt"
"io"
@@ -11,7 +12,6 @@ import (
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/dto"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/relay/helper"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/types"
@@ -397,7 +397,7 @@ func PullOllamaModelStream(baseURL, apiKey, modelName string, progressCallback f
}
// 读取流式响应
scanner := helper.NewStreamScanner(response.Body)
scanner := bufio.NewScanner(response.Body)
successful := false
for scanner.Scan() {
line := scanner.Text()
+2 -1
View File
@@ -1,6 +1,7 @@
package ollama
import (
"bufio"
"encoding/json"
"fmt"
"io"
@@ -69,7 +70,7 @@ func ollamaStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
defer service.CloseResponseBodyGracefully(resp)
helper.SetEventStreamHeaders(c)
scanner := helper.NewStreamScanner(resp.Body)
scanner := bufio.NewScanner(resp.Body)
usage := &dto.Usage{}
var model = info.UpstreamModelName
var responseId = common.GetUUID()
+8 -10
View File
@@ -310,20 +310,18 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
}
}
isOModel := dto.IsOpenAIReasoningOModel(info.UpstreamModelName)
isGPT5Model := dto.IsOpenAIGPT5Model(info.UpstreamModelName)
if isOModel || isGPT5Model {
if strings.HasPrefix(info.UpstreamModelName, "o") || strings.HasPrefix(info.UpstreamModelName, "gpt-5") {
if lo.FromPtrOr(request.MaxCompletionTokens, uint(0)) == 0 && lo.FromPtrOr(request.MaxTokens, uint(0)) != 0 {
request.MaxCompletionTokens = request.MaxTokens
request.MaxTokens = nil
}
if isOModel {
if strings.HasPrefix(info.UpstreamModelName, "o") {
request.Temperature = nil
}
// gpt-5系列模型适配 归零不再支持的参数
if isGPT5Model {
if strings.HasPrefix(info.UpstreamModelName, "gpt-5") {
request.Temperature = nil
request.TopP = nil
request.LogProbs = nil
@@ -379,7 +377,7 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
}
// 打印类似 curl 命令格式的信息
logger.LogDebug(c.Request.Context(), "--form 'model=\"%s\"'", request.Model)
logger.LogDebug(c.Request.Context(), fmt.Sprintf("--form 'model=\"%s\"'", request.Model))
// 遍历表单字段并打印输出
for key, values := range formData.Value {
@@ -388,7 +386,7 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
}
for _, value := range values {
writer.WriteField(key, value)
logger.LogDebug(c.Request.Context(), "--form '%s=\"%s\"'", key, value)
logger.LogDebug(c.Request.Context(), fmt.Sprintf("--form '%s=\"%s\"'", key, value))
}
}
@@ -400,8 +398,8 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
// 使用 formData 中的第一个文件
fileHeader := fileHeaders[0]
logger.LogDebug(c.Request.Context(), "--form 'file=@\"%s\"' (size: %d bytes, content-type: %s)",
fileHeader.Filename, fileHeader.Size, fileHeader.Header.Get("Content-Type"))
logger.LogDebug(c.Request.Context(), fmt.Sprintf("--form 'file=@\"%s\"' (size: %d bytes, content-type: %s)",
fileHeader.Filename, fileHeader.Size, fileHeader.Header.Get("Content-Type")))
file, err := fileHeader.Open()
if err != nil {
@@ -420,7 +418,7 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
// 关闭 multipart 编写器以设置分界线
writer.Close()
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
logger.LogDebug(c.Request.Context(), "--header 'Content-Type: %s'", writer.FormDataContentType())
logger.LogDebug(c.Request.Context(), fmt.Sprintf("--header 'Content-Type: %s'", writer.FormDataContentType()))
return &requestBody, nil
}
}
-3
View File
@@ -30,9 +30,6 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
usage.PromptTokens = info.GetEstimatePromptTokens()
usage.TotalTokens = info.GetEstimatePromptTokens()
for k, v := range resp.Header {
if !service.ShouldCopyUpstreamHeader(c, k, v) {
continue
}
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
+65 -14
View File
@@ -1,6 +1,7 @@
package openai
import (
"encoding/json"
"strings"
"github.com/QuantumNous/new-api/common"
@@ -91,28 +92,78 @@ func ProcessStreamResponse(streamResponse dto.ChatCompletionsStreamResponse, res
return nil
}
func processTokenData(relayMode int, data string, responseTextBuilder *strings.Builder, toolCount *int) error {
func processTokens(relayMode int, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
streamResp := "[" + strings.Join(streamItems, ",") + "]"
switch relayMode {
case relayconstant.RelayModeChatCompletions:
var streamResponse dto.ChatCompletionsStreamResponse
if err := common.UnmarshalJsonStr(data, &streamResponse); err != nil {
return err
}
return ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount)
return processChatCompletions(streamResp, streamItems, responseTextBuilder, toolCount)
case relayconstant.RelayModeCompletions:
var streamResponse dto.CompletionsStreamResponse
if err := common.UnmarshalJsonStr(data, &streamResponse); err != nil {
return err
}
processCompletionsStreamResponse(streamResponse, responseTextBuilder)
return processCompletions(streamResp, streamItems, responseTextBuilder)
}
return nil
}
func processCompletionsStreamResponse(streamResponse dto.CompletionsStreamResponse, responseTextBuilder *strings.Builder) {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
func processChatCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
var streamResponses []dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
// 一次性解析失败,逐个解析
common.SysLog("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
return err
}
if err := ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount); err != nil {
common.SysLog("error processing stream response: " + err.Error())
}
}
return nil
}
// 批量处理所有响应
for _, streamResponse := range streamResponses {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > *toolCount {
*toolCount = len(choice.Delta.ToolCalls)
}
for _, tool := range choice.Delta.ToolCalls {
responseTextBuilder.WriteString(tool.Function.Name)
responseTextBuilder.WriteString(tool.Function.Arguments)
}
}
}
}
return nil
}
func processCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder) error {
var streamResponses []dto.CompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
// 一次性解析失败,逐个解析
common.SysLog("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.CompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
continue
}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
}
}
return nil
}
// 批量处理所有响应
for _, streamResponse := range streamResponses {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
}
}
return nil
}
func handleLastResponse(lastStreamData string, responseId *string, createAt *int64,
+12 -7
View File
@@ -119,6 +119,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
var responseTextBuilder strings.Builder
var toolCount int
var usage = &dto.Usage{}
var streamItems []string // store stream items
var lastStreamData string
var secondLastStreamData string // 存储倒数第二个stream data,用于音频模型
@@ -139,10 +140,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
}
lastStreamData = data
if err := processTokenData(info.RelayMode, data, &responseTextBuilder, &toolCount); err != nil {
logger.LogError(c, "error processing stream token data: "+err.Error())
sr.Error(err)
}
streamItems = append(streamItems, data)
}
})
@@ -157,9 +155,9 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
containStreamUsage = true
if common.DebugEnabled {
logger.LogDebug(c, "Audio model usage extracted from second last SSE: PromptTokens=%d, CompletionTokens=%d, TotalTokens=%d, InputTokens=%d, OutputTokens=%d",
logger.LogDebug(c, fmt.Sprintf("Audio model usage extracted from second last SSE: PromptTokens=%d, CompletionTokens=%d, TotalTokens=%d, InputTokens=%d, OutputTokens=%d",
usage.PromptTokens, usage.CompletionTokens, usage.TotalTokens,
usage.InputTokens, usage.OutputTokens)
usage.InputTokens, usage.OutputTokens))
}
}
}
@@ -177,6 +175,11 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
}
}
// 处理token计算
if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
logger.LogError(c, "error processing tokens: "+err.Error())
}
if !containStreamUsage {
usage = service.ResponseText2Usage(c, responseTextBuilder.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
usage.CompletionTokens += toolCount * 7
@@ -197,7 +200,9 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
logger.LogDebug(c, "upstream response body: %s", responseBody)
if common.DebugEnabled {
println("upstream response body:", string(responseBody))
}
// Unmarshal to simpleResponse
if info.ChannelType == constant.ChannelTypeOpenRouter && info.ChannelOtherSettings.IsOpenRouterEnterprise() {
// 尝试解析为 openrouter enterprise
+1 -1
View File
@@ -92,7 +92,7 @@ func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *dto.Cha
func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
var responseText string
scanner := helper.NewStreamScanner(resp.Body)
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
helper.SetEventStreamHeaders(c)
-1
View File
@@ -45,7 +45,6 @@ var claudeModelMap = map[string]string{
"claude-opus-4-5-20251101": "claude-opus-4-5@20251101",
"claude-opus-4-6": "claude-opus-4-6",
"claude-opus-4-7": "claude-opus-4-7",
"claude-opus-4-8": "claude-opus-4-8",
}
const anthropicVersion = "vertex-2023-10-16"
+1 -4
View File
@@ -157,7 +157,7 @@ func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*dt
func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
var usage *dto.Usage
scanner := helper.NewStreamScanner(resp.Body)
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
dataChan := make(chan string)
metaChan := make(chan string)
@@ -180,9 +180,6 @@ func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
}
}
}
if err := scanner.Err(); err != nil {
common.SysLog("error reading stream: " + err.Error())
}
stopChan <- true
}()
helper.SetEventStreamHeaders(c)

Some files were not shown because too many files have changed in this diff Show More