Compare commits
180 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| de471d0fa0 | |||
| ea5b152313 | |||
| 75e608bdc8 | |||
| c52ff411a7 | |||
| ff407c4607 | |||
| ae380a5a81 | |||
| 71519301e4 | |||
| 497ed7c39a | |||
| 8be5131c29 | |||
| 6a9fd7611a | |||
| c42757f9af | |||
| c450214d78 | |||
| 73e701e49a | |||
| 537c292801 | |||
| b330253fe2 | |||
| d703ccbb79 | |||
| 697ef3fcca | |||
| 1249eebf46 | |||
| a34c284017 | |||
| 3b2c0a0471 | |||
| 207f98252d | |||
| 113b9c8ecb | |||
| a4c069f88d | |||
| 7caf77db63 | |||
| 6faf989549 | |||
| 719d06ecd6 | |||
| 5464d14973 | |||
| 446e1ce10b | |||
| cd8cdebdcb | |||
| 790d832756 | |||
| 63c83d5abd | |||
| 64326a86d5 | |||
| 5c4ed6206e | |||
| 43f9869246 | |||
| 8eb31a7c82 | |||
| ca47fd18f5 | |||
| 296120641a | |||
| 5d18485e9a | |||
| 60fc1eee31 | |||
| da8cf3eef0 | |||
| 1ddffd236c | |||
| 5eeb3c9f18 | |||
| 2d8bdc1b7c | |||
| 2723bd66ad | |||
| 2c4c2002c6 | |||
| 8cf49db0a8 | |||
| 53a884a6fe | |||
| 60d63f87ff | |||
| 603886d422 | |||
| 3c056b00b1 | |||
| e9d408e718 | |||
| c175a41a2b | |||
| 4e335a9997 | |||
| 6f1f8b60f1 | |||
| a318279d5f | |||
| d1542a65ac | |||
| 26ac8b5dc1 | |||
| a8a96a7e60 | |||
| 45f676514a | |||
| 9e4ab6e7bb | |||
| 45aea22451 | |||
| 5b9aa9e77a | |||
| 80bf73b41c | |||
| e83ec743c8 | |||
| 824c6d9133 | |||
| 76f8112ee7 | |||
| 8afd1d1247 | |||
| b4b58ee887 | |||
| 97ddfeab59 | |||
| 6db071543c | |||
| d83db15801 | |||
| 2780f71b8f | |||
| 052da457eb | |||
| 6f415428d3 | |||
| 59a93cf5c7 | |||
| 867d8acfc3 | |||
| 30d3a3a5f7 | |||
| d2576ddcd3 | |||
| 4ca47ee236 | |||
| 16dd7237c0 | |||
| 1915344838 | |||
| 15ff8e0268 | |||
| a1c82841b5 | |||
| 1e6f31b235 | |||
| 2eaa943d9f | |||
| 7a5348caa3 | |||
| f5753a2b31 | |||
| 4dd68bad52 | |||
| 0f043ae404 | |||
| 75c05bb4b8 | |||
| 81d3dc08e5 | |||
| adc390c5fb | |||
| e8c36762fd | |||
| e2dbd02cbb | |||
| c8d3768087 | |||
| 32805849d6 | |||
| 01c2128e23 | |||
| 189913b7a0 | |||
| d2f7f9ee3a | |||
| 83068d115e | |||
| 4a188deeaa | |||
| 933ea0cddc | |||
| b53319361f | |||
| 5681c92b3f | |||
| 6e5a359110 | |||
| 87cc22d7ec | |||
| 3aa113b5a3 | |||
| 77d3157592 | |||
| 00d23abf64 | |||
| 580ad97c02 | |||
| 39e05118ff | |||
| 9e59ffc3d8 | |||
| abad0d3cc0 | |||
| b0ac0429cf | |||
| d17b566bcc | |||
| 7aaa533265 | |||
| 7791b78429 | |||
| cb5c0453f5 | |||
| 4d20e053cb | |||
| 0ff9c35e62 | |||
| 0bbcaa8999 | |||
| 1e9ff8a0de | |||
| 9a2e60dff2 | |||
| b596de739d | |||
| 45d54c1613 | |||
| 086044650d | |||
| 0c7aceb831 | |||
| b2e25b7df2 | |||
| 230a3592f8 | |||
| afb470e405 | |||
| 1588027084 | |||
| 38bf2d8daa | |||
| e8c836d705 | |||
| 979aeceb5c | |||
| e79cee1e9e | |||
| 63ead2bf7f | |||
| 5b86ce0d70 | |||
| 74985fa877 | |||
| 1d32037364 | |||
| dc245ae764 | |||
| f8add4ca49 | |||
| 65f8afe922 | |||
| 5bc4c74813 | |||
| 30025aeba3 | |||
| c91ba0c4eb | |||
| f223db9330 | |||
| 9e283ab10b | |||
| a8b7c92e5f | |||
| 6b6c9904ac | |||
| 1011934987 | |||
| bc8110ce36 | |||
| ad224ecf5b | |||
| a64f26d1d2 | |||
| 3360882642 | |||
| b37b6d80b3 | |||
| 3d850d38b6 | |||
| 349d5429ca | |||
| 465c5edab9 | |||
| ff06067a18 | |||
| 51ca897cf4 | |||
| 1288028181 | |||
| 2a528d46cb | |||
| 583da45296 | |||
| b302be30e3 | |||
| 88437a1869 | |||
| b08febaa3c | |||
| 92a0959448 | |||
| 49bc3a1175 | |||
| 0354c38bef | |||
| ebbe315533 | |||
| fddf54ccc5 | |||
| b9bc6f0e21 | |||
| f2c7647ecf | |||
| 19f1821fc8 | |||
| 8e5e89bb5b | |||
| e13d673454 | |||
| ae6a03364d | |||
| 006e801652 | |||
| 6f11d19877 | |||
| 58ba867dd6 |
@@ -56,6 +56,8 @@
|
||||
# 对话超时设置
|
||||
# 所有请求超时时间,单位秒,默认为0,表示不限制
|
||||
# RELAY_TIMEOUT=0
|
||||
# Relay HTTP 客户端空闲连接超时时间,单位秒,默认跟随 Go 标准库,设置为0表示不限制
|
||||
# RELAY_IDLE_CONN_TIMEOUT=90
|
||||
# 流模式无响应超时时间,单位秒,如果出现空补全可以尝试改为更大值
|
||||
# STREAMING_TIMEOUT=300
|
||||
|
||||
|
||||
@@ -0,0 +1,92 @@
|
||||
name: Docker Build
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- master
|
||||
tags:
|
||||
- 'v*'
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
name: Build and Push Docker Image
|
||||
runs-on: act-runner-4c6g
|
||||
env:
|
||||
RUNNER_TOOL_CACHE: /toolcache
|
||||
|
||||
steps:
|
||||
- name: Install Docker CLI
|
||||
run: |
|
||||
if ! command -v docker &> /dev/null; then
|
||||
if command -v apk &> /dev/null; then
|
||||
apk add --no-cache docker-cli
|
||||
elif command -v apt-get &> /dev/null; then
|
||||
apt-get update && apt-get install -y docker.io
|
||||
else
|
||||
curl -fsSL https://download.docker.com/linux/static/stable/x86_64/docker-24.0.7.tgz | tar xz -C /tmp
|
||||
mv /tmp/docker/docker /usr/local/bin/
|
||||
chmod +x /usr/local/bin/docker
|
||||
fi
|
||||
fi
|
||||
docker --version
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Resolve tag & write VERSION
|
||||
id: version
|
||||
run: |
|
||||
if echo "${{ github.ref }}" | grep -q "^refs/tags/"; then
|
||||
TAG=${GITHUB_REF#refs/tags/}
|
||||
else
|
||||
SHORT_SHA=$(git rev-parse --short HEAD)
|
||||
TAG="dev-${SHORT_SHA}"
|
||||
fi
|
||||
echo "TAG=${TAG}" >> $GITHUB_ENV
|
||||
echo "${TAG}" > VERSION
|
||||
echo "Building tag: ${TAG}"
|
||||
cat VERSION
|
||||
|
||||
- name: Login to Gitea Container Registry
|
||||
run: |
|
||||
echo "${{ secrets.PACKAGES_TOKEN }}" | docker login git.viaeon.com -u "${{ github.actor }}" --password-stdin
|
||||
|
||||
- name: Build Docker image
|
||||
run: |
|
||||
echo "Building image with tag: ${{ env.TAG }}"
|
||||
docker build \
|
||||
--label "org.opencontainers.image.source=https://git.viaeon.com/admin/new-api" \
|
||||
--label "org.opencontainers.image.revision=${{ github.sha }}" \
|
||||
-t git.viaeon.com/admin/new-api:${{ env.TAG }} \
|
||||
-t git.viaeon.com/admin/new-api:latest .
|
||||
|
||||
- name: Push Docker image
|
||||
run: |
|
||||
echo "Pushing ${{ env.TAG }}..."
|
||||
docker push git.viaeon.com/admin/new-api:${{ env.TAG }}
|
||||
echo "Pushing latest..."
|
||||
docker push git.viaeon.com/admin/new-api:latest
|
||||
|
||||
- name: Cleanup Docker
|
||||
if: always()
|
||||
run: |
|
||||
echo "Removing local images..."
|
||||
docker rmi git.viaeon.com/admin/new-api:${{ env.TAG }} git.viaeon.com/admin/new-api:latest 2>/dev/null || true
|
||||
echo "Pruning unused Docker resources..."
|
||||
docker system prune -af --volumes 2>/dev/null || true
|
||||
echo "Docker disk usage:"
|
||||
docker system df
|
||||
|
||||
- name: Deploy via SSH
|
||||
if: success()
|
||||
run: |
|
||||
if [ -z "${{ secrets.DEPLOY_SSH_HOST }}" ]; then
|
||||
echo "DEPLOY_SSH_HOST not set, skip deploy"
|
||||
exit 0
|
||||
fi
|
||||
apk add --no-cache sshpass 2>/dev/null || apt-get update && apt-get install -y sshpass 2>/dev/null || true
|
||||
sshpass -p "${{ secrets.DEPLOY_SSH_PASS }}" ssh -o StrictHostKeyChecking=no -p ${{ secrets.DEPLOY_SSH_PORT || 22 }} ${{ secrets.DEPLOY_SSH_USER }}@${{ secrets.DEPLOY_SSH_HOST }} "cd ${{ secrets.DEPLOY_DIR || '/opt/new-api' }} && docker compose pull && docker compose up -d"
|
||||
@@ -0,0 +1,73 @@
|
||||
name: Docker Build (alpha)
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- alpha
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
name: Build and Push Alpha Docker Image
|
||||
runs-on: act-runner-4c6g
|
||||
env:
|
||||
RUNNER_TOOL_CACHE: /toolcache
|
||||
|
||||
steps:
|
||||
- name: Install Docker CLI
|
||||
run: |
|
||||
if ! command -v docker &> /dev/null; then
|
||||
if command -v apk &> /dev/null; then
|
||||
apk add --no-cache docker-cli
|
||||
elif command -v apt-get &> /dev/null; then
|
||||
apt-get update && apt-get install -y docker.io
|
||||
else
|
||||
curl -fsSL https://download.docker.com/linux/static/stable/x86_64/docker-24.0.7.tgz | tar xz -C /tmp
|
||||
mv /tmp/docker/docker /usr/local/bin/
|
||||
chmod +x /usr/local/bin/docker
|
||||
fi
|
||||
fi
|
||||
docker --version
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 1
|
||||
|
||||
- name: Determine alpha version
|
||||
id: version
|
||||
run: |
|
||||
VERSION="alpha-$(date +'%Y%m%d')-$(git rev-parse --short HEAD)"
|
||||
echo "$VERSION" > VERSION
|
||||
echo "VERSION=$VERSION" >> $GITHUB_ENV
|
||||
echo "Publishing version: $VERSION"
|
||||
|
||||
- name: Login to Gitea Container Registry
|
||||
run: |
|
||||
echo "${{ secrets.PACKAGES_TOKEN }}" | docker login git.viaeon.com -u "${{ github.actor }}" --password-stdin 2>&1
|
||||
|
||||
- name: Build Docker image
|
||||
run: |
|
||||
echo "Building alpha image..."
|
||||
docker build \
|
||||
--label "org.opencontainers.image.source=https://git.viaeon.com/admin/new-api" \
|
||||
--label "org.opencontainers.image.revision=${{ github.sha }}" \
|
||||
-t git.viaeon.com/admin/new-api:${{ env.VERSION }} \
|
||||
-t git.viaeon.com/admin/new-api:alpha . 2>&1
|
||||
|
||||
- name: Push Docker image
|
||||
run: |
|
||||
echo "Pushing ${{ env.VERSION }}..."
|
||||
docker push git.viaeon.com/admin/new-api:${{ env.VERSION }}
|
||||
echo "Pushing alpha..."
|
||||
docker push git.viaeon.com/admin/new-api:alpha
|
||||
|
||||
- name: Cleanup Docker
|
||||
if: always()
|
||||
run: |
|
||||
echo "Removing local images..."
|
||||
docker rmi git.viaeon.com/admin/new-api:${{ env.VERSION }} git.viaeon.com/admin/new-api:alpha 2>/dev/null || true
|
||||
echo "Pruning unused Docker resources..."
|
||||
docker system prune -af --volumes 2>/dev/null || true
|
||||
echo "Docker disk usage:"
|
||||
docker system df
|
||||
@@ -0,0 +1,73 @@
|
||||
name: Docker Build (nightly)
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- nightly
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
name: Build and Push Nightly Docker Image
|
||||
runs-on: act-runner-4c6g
|
||||
env:
|
||||
RUNNER_TOOL_CACHE: /toolcache
|
||||
|
||||
steps:
|
||||
- name: Install Docker CLI
|
||||
run: |
|
||||
if ! command -v docker &> /dev/null; then
|
||||
if command -v apk &> /dev/null; then
|
||||
apk add --no-cache docker-cli
|
||||
elif command -v apt-get &> /dev/null; then
|
||||
apt-get update && apt-get install -y docker.io
|
||||
else
|
||||
curl -fsSL https://download.docker.com/linux/static/stable/x86_64/docker-24.0.7.tgz | tar xz -C /tmp
|
||||
mv /tmp/docker/docker /usr/local/bin/
|
||||
chmod +x /usr/local/bin/docker
|
||||
fi
|
||||
fi
|
||||
docker --version
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 1
|
||||
|
||||
- name: Determine nightly version
|
||||
id: version
|
||||
run: |
|
||||
VERSION="nightly-$(date +'%Y%m%d')-$(git rev-parse --short HEAD)"
|
||||
echo "$VERSION" > VERSION
|
||||
echo "VERSION=$VERSION" >> $GITHUB_ENV
|
||||
echo "Publishing version: $VERSION"
|
||||
|
||||
- name: Login to Gitea Container Registry
|
||||
run: |
|
||||
echo "${{ secrets.PACKAGES_TOKEN }}" | docker login git.viaeon.com -u "${{ github.actor }}" --password-stdin 2>&1
|
||||
|
||||
- name: Build Docker image
|
||||
run: |
|
||||
echo "Building nightly image..."
|
||||
docker build \
|
||||
--label "org.opencontainers.image.source=https://git.viaeon.com/admin/new-api" \
|
||||
--label "org.opencontainers.image.revision=${{ github.sha }}" \
|
||||
-t git.viaeon.com/admin/new-api:${{ env.VERSION }} \
|
||||
-t git.viaeon.com/admin/new-api:nightly . 2>&1
|
||||
|
||||
- name: Push Docker image
|
||||
run: |
|
||||
echo "Pushing ${{ env.VERSION }}..."
|
||||
docker push git.viaeon.com/admin/new-api:${{ env.VERSION }}
|
||||
echo "Pushing nightly..."
|
||||
docker push git.viaeon.com/admin/new-api:nightly
|
||||
|
||||
- name: Cleanup Docker
|
||||
if: always()
|
||||
run: |
|
||||
echo "Removing local images..."
|
||||
docker rmi git.viaeon.com/admin/new-api:${{ env.VERSION }} git.viaeon.com/admin/new-api:nightly 2>/dev/null || true
|
||||
echo "Pruning unused Docker resources..."
|
||||
docker system prune -af --volumes 2>/dev/null || true
|
||||
echo "Docker disk usage:"
|
||||
docker system df
|
||||
@@ -0,0 +1,82 @@
|
||||
name: PR Check
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, reopened]
|
||||
|
||||
jobs:
|
||||
pr-quality:
|
||||
name: PR Quality Check
|
||||
runs-on: act-runner-4c6g
|
||||
env:
|
||||
RUNNER_TOOL_CACHE: /toolcache
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Check PR description
|
||||
env:
|
||||
GITEA_TOKEN: ${{ secrets.PACKAGES_TOKEN }}
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
REPO: admin/new-api
|
||||
GITEA_URL: https://git.viaeon.com
|
||||
run: |
|
||||
# 获取 PR 信息
|
||||
PR_INFO=$(curl -s -H "Authorization: token ${GITEA_TOKEN}" \
|
||||
"${GITEA_URL}/api/v1/repos/${REPO}/pulls/${PR_NUMBER}")
|
||||
|
||||
PR_BODY=$(echo "$PR_INFO" | jq -r '.body // empty')
|
||||
PR_TITLE=$(echo "$PR_INFO" | jq -r '.title // empty')
|
||||
PR_USER=$(echo "$PR_INFO" | jq -r '.user.login // empty')
|
||||
|
||||
FAILED=0
|
||||
REASONS=""
|
||||
|
||||
# 检查 PR 描述是否为空
|
||||
if [ -z "$PR_BODY" ] || [ "$PR_BODY" = "null" ]; then
|
||||
FAILED=1
|
||||
REASONS="${REASONS}- PR description is empty\n"
|
||||
fi
|
||||
|
||||
# 检查 PR 标题是否为空
|
||||
if [ -z "$PR_TITLE" ] || [ "$PR_TITLE" = "null" ]; then
|
||||
FAILED=1
|
||||
REASONS="${REASONS}- PR title is empty\n"
|
||||
fi
|
||||
|
||||
# 检查是否包含纯 AI 生成标记
|
||||
if echo "$PR_BODY" | grep -qi "Generated with Claude Code"; then
|
||||
FAILED=1
|
||||
REASONS="${REASONS}- PR appears to be purely AI-generated without meaningful human involvement\n"
|
||||
fi
|
||||
|
||||
if [ "$FAILED" -eq 1 ]; then
|
||||
echo "PR check failed:"
|
||||
echo -e "$REASONS"
|
||||
|
||||
# 添加标签
|
||||
curl -s -X POST \
|
||||
"${GITEA_URL}/api/v1/repos/${REPO}/issues/${PR_NUMBER}/labels" \
|
||||
-H "Authorization: token ${GITEA_TOKEN}" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"labels": ["pr-check-failed"]}'
|
||||
|
||||
# 添加评论
|
||||
curl -s -X POST \
|
||||
"${GITEA_URL}/api/v1/repos/${REPO}/issues/${PR_NUMBER}/comments" \
|
||||
-H "Authorization: token ${GITEA_TOKEN}" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"body": "感谢您的提交。由于该 PR 未遵循我们的贡献模板,且被识别为缺乏人工参与的纯 AI 生成内容,我们将先予以关闭。我们更欢迎经过人工审核、验证并带有个人思考的贡献。如果您认为这其中存在误解,请回复告知。"}'
|
||||
|
||||
# 关闭 PR
|
||||
curl -s -X PATCH \
|
||||
"${GITEA_URL}/api/v1/repos/${REPO}/pulls/${PR_NUMBER}" \
|
||||
-H "Authorization: token ${GITEA_TOKEN}" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"state": "closed"}'
|
||||
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "PR check passed!"
|
||||
@@ -0,0 +1,161 @@
|
||||
name: Release (Linux)
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
- '!*-alpha*'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
tag:
|
||||
description: 'Tag name to build (e.g., v0.10.8)'
|
||||
required: true
|
||||
type: string
|
||||
|
||||
jobs:
|
||||
build-linux:
|
||||
name: Linux Release
|
||||
runs-on: act-runner-4c6g
|
||||
env:
|
||||
RUNNER_TOOL_CACHE: /toolcache
|
||||
|
||||
steps:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
export PATH="/toolcache/bin:$PATH"
|
||||
# Install Go
|
||||
if ! command -v go &> /dev/null; then
|
||||
curl -fsSL https://go.dev/dl/go1.25.1.linux-amd64.tar.gz | tar -C /usr/local -xzf -
|
||||
echo "export PATH=\$PATH:/usr/local/go/bin" >> ~/.bashrc
|
||||
export PATH=$PATH:/usr/local/go/bin
|
||||
fi
|
||||
go version
|
||||
# Install Bun
|
||||
if ! command -v bun &> /dev/null; then
|
||||
curl -fsSL https://bun.sh/install | bash
|
||||
export PATH="$HOME/.bun/bin:$PATH"
|
||||
fi
|
||||
bun --version
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Determine Version
|
||||
run: |
|
||||
if [ -n "${{ github.event.inputs.tag }}" ]; then
|
||||
TAG="${{ github.event.inputs.tag }}"
|
||||
else
|
||||
TAG=${GITHUB_REF#refs/tags/}
|
||||
fi
|
||||
VERSION=$(git describe --tags 2>/dev/null || echo "$TAG")
|
||||
echo "VERSION=$VERSION" >> $GITHUB_ENV
|
||||
echo "Building version: $VERSION"
|
||||
|
||||
- name: Build Frontend (default)
|
||||
env:
|
||||
CI: ""
|
||||
run: |
|
||||
export PATH="$HOME/.bun/bin:/usr/local/go/bin:$PATH"
|
||||
cd web
|
||||
bun install --frozen-lockfile
|
||||
cd default
|
||||
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$VERSION bun run build
|
||||
cd ../..
|
||||
|
||||
- name: Build Frontend (classic)
|
||||
env:
|
||||
CI: ""
|
||||
run: |
|
||||
export PATH="$HOME/.bun/bin:/usr/local/go/bin:$PATH"
|
||||
cd web
|
||||
bun install --frozen-lockfile
|
||||
cd classic
|
||||
VITE_REACT_APP_VERSION=$VERSION bun run build
|
||||
cd ../..
|
||||
|
||||
- name: Build Backend (amd64)
|
||||
run: |
|
||||
export PATH="/usr/local/go/bin:$PATH"
|
||||
go mod download
|
||||
go build -ldflags "-s -w -X 'new-api/common.Version=$VERSION' -extldflags '-static'" -o new-api-$VERSION
|
||||
|
||||
- name: Build Backend (arm64)
|
||||
run: |
|
||||
export PATH="/usr/local/go/bin:$PATH"
|
||||
sudo apt-get update
|
||||
DEBIAN_FRONTEND=noninteractive sudo apt-get install -y gcc-aarch64-linux-gnu
|
||||
CC=aarch64-linux-gnu-gcc CGO_ENABLED=1 GOOS=linux GOARCH=arm64 go build -ldflags "-s -w -X 'new-api/common.Version=$VERSION' -extldflags '-static'" -o new-api-arm64-$VERSION
|
||||
|
||||
- name: Generate checksums
|
||||
run: sha256sum new-api-* > checksums-linux.txt
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: linux-build
|
||||
path: |
|
||||
new-api-*
|
||||
checksums-linux.txt
|
||||
|
||||
release:
|
||||
name: Create Gitea Release
|
||||
needs: [build-linux]
|
||||
runs-on: act-runner-4c6g
|
||||
env:
|
||||
RUNNER_TOOL_CACHE: /toolcache
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Determine Version
|
||||
run: |
|
||||
if [ -n "${{ github.event.inputs.tag }}" ]; then
|
||||
TAG="${{ github.event.inputs.tag }}"
|
||||
else
|
||||
TAG=${GITHUB_REF#refs/tags/}
|
||||
fi
|
||||
echo "TAG=$TAG" >> $GITHUB_ENV
|
||||
|
||||
- name: Download artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
path: artifacts
|
||||
|
||||
- name: Create Gitea Release
|
||||
env:
|
||||
GITEA_TOKEN: ${{ secrets.PACKAGES_TOKEN }}
|
||||
run: |
|
||||
# 使用 Gitea API 创建 Release
|
||||
TAG="${{ env.TAG }}"
|
||||
REPO="admin/new-api"
|
||||
GITEA_URL="https://git.viaeon.com"
|
||||
|
||||
# 创建 Release
|
||||
RELEASE_ID=$(curl -s -X POST \
|
||||
"${GITEA_URL}/api/v1/repos/${REPO}/releases" \
|
||||
-H "Authorization: token ${GITEA_TOKEN}" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "{
|
||||
\"tag_name\": \"${TAG}\",
|
||||
\"name\": \"${TAG}\",
|
||||
\"body\": \"Release ${TAG}\",
|
||||
\"draft\": false,
|
||||
\"prerelease\": false
|
||||
}" | jq -r '.id')
|
||||
|
||||
echo "Created release ID: ${RELEASE_ID}"
|
||||
|
||||
# 上传附件
|
||||
find artifacts -type f | while read file; do
|
||||
echo "Uploading: ${file}"
|
||||
curl -s -X POST \
|
||||
"${GITEA_URL}/api/v1/repos/${REPO}/releases/${RELEASE_ID}/assets" \
|
||||
-H "Authorization: token ${GITEA_TOKEN}" \
|
||||
-F "attachment=@${file}" \
|
||||
-F "name=$(basename ${file})"
|
||||
done
|
||||
|
||||
echo "Release ${TAG} created successfully!"
|
||||
@@ -11,6 +11,8 @@ assignees: ''
|
||||
|
||||
- 文档:https://docs.newapi.ai/
|
||||
- 使用问题先看或先问:https://deepwiki.com/QuantumNous/new-api
|
||||
- 开启透传后的转发相关反馈不接受 issue;透传模式会直接转发请求,请自行确认上游行为。
|
||||
- 不接受 coding plan、逆向渠道等技术支持类 issue。
|
||||
- 警告:删除本模板、删除小节标题或随意清空内容的 issue,可能会被直接关闭;重复恶意提交者可能会被 block。
|
||||
|
||||
**您当前的 newapi 版本**
|
||||
@@ -20,13 +22,18 @@ assignees: ''
|
||||
**提交确认**
|
||||
|
||||
[//]: # (方框内删除已有的空格,填 x 号)
|
||||
+ [ ] 我已确认目前没有类似 issue
|
||||
+ [ ] 我已完整查看过文档 https://docs.newapi.ai/ 和项目 README,尤其是常见问题部分
|
||||
+ [ ] 我未删除此模板中的任何引导内容或小节标题,并会按要求完整填写
|
||||
+ [ ] 我理解项目维护者精力有限,不遵循模板要求的 issue 可能会被无视或直接关闭
|
||||
- [ ] **非重复 issue:** 我已搜索现有 [Issues](https://github.com/QuantumNous/new-api/issues?q=is%3Aissue),确认目前没有类似 issue。
|
||||
- [ ] **提交前必读:** 我已完整阅读上方“提交前必读”,并已查看文档 https://docs.newapi.ai/、项目 README 且向 AI 提问,确认这不是使用、配置或接入类问题。
|
||||
- [ ] **模板完整:** 我未删除此模板中的任何引导内容或小节标题,并会按要求完整填写。
|
||||
- [ ] **维护成本:** 我理解项目维护者精力有限,不遵循模板要求的 issue 可能会被无视或直接关闭。
|
||||
|
||||
**问题描述**
|
||||
|
||||
请尽可能说明问题现象、影响范围,以及你判断它是程序问题而不是上游行为或使用问题的依据。
|
||||
|
||||
- 转发问题请尽可能说明渠道类型、转换格式、上游原生支持依据和服务端日志。
|
||||
- 计费问题请尽可能附请求返回的 `usage` 示例。
|
||||
|
||||
**复现步骤**
|
||||
|
||||
**预期结果**
|
||||
|
||||
@@ -11,6 +11,8 @@ assignees: ''
|
||||
|
||||
- Docs: https://docs.newapi.ai/
|
||||
- Usage questions first: https://deepwiki.com/QuantumNous/new-api
|
||||
- Issues about forwarding behavior after enabling pass-through mode are not accepted; pass-through mode forwards requests directly, so please verify upstream behavior yourself.
|
||||
- Technical support requests such as coding plans or reverse-engineering channels are not accepted as issues.
|
||||
- Warning: issues with this template removed, section headings deleted, or content cleared may be closed directly. Repeated abusive submissions may result in a block.
|
||||
|
||||
**Your current newapi version**
|
||||
@@ -20,13 +22,18 @@ Please fill this in, for example: `v1.0.0`
|
||||
**Submission Checks**
|
||||
|
||||
[//]: # (Remove the space in the box and fill with an x)
|
||||
+ [ ] I have confirmed there are no similar issues
|
||||
+ [ ] I have thoroughly read the docs at https://docs.newapi.ai/ and the project README, especially the FAQ section
|
||||
+ [ ] I have not removed any guidance or section headings from this template and will complete it as requested
|
||||
+ [ ] I understand that maintainers have limited time and issues that do not follow this template may be ignored or closed directly
|
||||
- [ ] **Non-duplicate issue:** I have searched existing [Issues](https://github.com/QuantumNous/new-api/issues?q=is%3Aissue) and confirmed there are no similar issues.
|
||||
- [ ] **Read this first:** I have fully read the section above, reviewed the docs at https://docs.newapi.ai/ and the project README, and asked AI first, confirming this is not a usage, configuration, or integration question.
|
||||
- [ ] **Template intact:** I have not removed any guidance or section headings from this template and will complete it as requested.
|
||||
- [ ] **Maintainer time:** I understand that maintainers have limited time and issues that do not follow this template may be ignored or closed directly.
|
||||
|
||||
**Issue Description**
|
||||
|
||||
Describe the symptom, impact scope, and why you believe this is an application issue rather than upstream behavior or a usage question with as much detail as possible.
|
||||
|
||||
- For forwarding issues, include the channel type, conversion format, upstream native-support evidence, and server logs when possible.
|
||||
- For billing issues, include an example of the returned `usage` when possible.
|
||||
|
||||
**Steps to Reproduce**
|
||||
|
||||
**Expected Result**
|
||||
|
||||
@@ -11,6 +11,8 @@ assignees: ''
|
||||
|
||||
- 文档:https://docs.newapi.ai/
|
||||
- 使用问题先看或先问:https://deepwiki.com/QuantumNous/new-api
|
||||
- 开启透传后的转发相关反馈不接受 issue;透传模式会直接转发请求,请自行确认上游行为。
|
||||
- 不接受 coding plan、逆向渠道等技术支持类 issue。
|
||||
- 警告:删除本模板、删除小节标题或随意清空内容的 issue,可能会被直接关闭;重复恶意提交者可能会被 block。
|
||||
|
||||
**您当前的 newapi 版本**
|
||||
@@ -20,10 +22,10 @@ assignees: ''
|
||||
**提交确认**
|
||||
|
||||
[//]: # (方框内删除已有的空格,填 x 号)
|
||||
+ [ ] 我已确认目前没有类似 issue
|
||||
+ [ ] 我已完整查看过文档 https://docs.newapi.ai/ 和项目 README,已确定现有版本无法满足需求
|
||||
+ [ ] 我未删除此模板中的任何引导内容或小节标题,并会按要求完整填写
|
||||
+ [ ] 我理解项目维护者精力有限,不遵循模板要求的 issue 可能会被无视或直接关闭
|
||||
- [ ] **非重复 issue:** 我已搜索现有 [Issues](https://github.com/QuantumNous/new-api/issues?q=is%3Aissue),确认目前没有类似 issue。
|
||||
- [ ] **提交前必读:** 我已完整阅读上方“提交前必读”,并已查看文档 https://docs.newapi.ai/、项目 README 且向 AI 提问,确认这不是使用、配置或接入类问题,且现有版本无法满足需求。
|
||||
- [ ] **模板完整:** 我未删除此模板中的任何引导内容或小节标题,并会按要求完整填写。
|
||||
- [ ] **维护成本:** 我理解项目维护者精力有限,不遵循模板要求的 issue 可能会被无视或直接关闭。
|
||||
|
||||
**功能描述**
|
||||
|
||||
|
||||
@@ -11,6 +11,8 @@ assignees: ''
|
||||
|
||||
- Docs: https://docs.newapi.ai/
|
||||
- Usage questions first: https://deepwiki.com/QuantumNous/new-api
|
||||
- Issues about forwarding behavior after enabling pass-through mode are not accepted; pass-through mode forwards requests directly, so please verify upstream behavior yourself.
|
||||
- Technical support requests such as coding plans or reverse-engineering channels are not accepted as issues.
|
||||
- Warning: issues with this template removed, section headings deleted, or content cleared may be closed directly. Repeated abusive submissions may result in a block.
|
||||
|
||||
**Your current newapi version**
|
||||
@@ -20,10 +22,10 @@ Please fill this in, for example: `v1.0.0`
|
||||
**Submission Checks**
|
||||
|
||||
[//]: # (Remove the space in the box and fill with an x)
|
||||
+ [ ] I have confirmed there are no similar issues
|
||||
+ [ ] I have thoroughly read the docs at https://docs.newapi.ai/ and the project README, and confirmed the current version cannot meet my needs
|
||||
+ [ ] I have not removed any guidance or section headings from this template and will complete it as requested
|
||||
+ [ ] I understand that maintainers have limited time and issues that do not follow this template may be ignored or closed directly
|
||||
- [ ] **Non-duplicate issue:** I have searched existing [Issues](https://github.com/QuantumNous/new-api/issues?q=is%3Aissue) and confirmed there are no similar issues.
|
||||
- [ ] **Read this first:** I have fully read the section above, reviewed the docs at https://docs.newapi.ai/ and the project README, and asked AI first, confirming this is not a usage, configuration, or integration question, and that the current version cannot meet my needs.
|
||||
- [ ] **Template intact:** I have not removed any guidance or section headings from this template and will complete it as requested.
|
||||
- [ ] **Maintainer time:** I understand that maintainers have limited time and issues that do not follow this template may be ignored or closed directly.
|
||||
|
||||
**Feature Description**
|
||||
|
||||
|
||||
@@ -33,16 +33,18 @@ jobs:
|
||||
env:
|
||||
CI: ""
|
||||
run: |
|
||||
cd web/default
|
||||
bun install
|
||||
cd web
|
||||
bun install --frozen-lockfile
|
||||
cd default
|
||||
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$VERSION bun run build
|
||||
cd ../..
|
||||
- name: Build Frontend (classic)
|
||||
env:
|
||||
CI: ""
|
||||
run: |
|
||||
cd web/classic
|
||||
bun install
|
||||
cd web
|
||||
bun install --frozen-lockfile
|
||||
cd classic
|
||||
VITE_REACT_APP_VERSION=$VERSION bun run build
|
||||
cd ../..
|
||||
- name: Set up Go
|
||||
@@ -91,16 +93,18 @@ jobs:
|
||||
CI: ""
|
||||
NODE_OPTIONS: "--max-old-space-size=4096"
|
||||
run: |
|
||||
cd web/default
|
||||
bun install
|
||||
cd web
|
||||
bun install --frozen-lockfile
|
||||
cd default
|
||||
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$VERSION bun run build
|
||||
cd ../..
|
||||
- name: Build Frontend (classic)
|
||||
env:
|
||||
CI: ""
|
||||
run: |
|
||||
cd web/classic
|
||||
bun install
|
||||
cd web
|
||||
bun install --frozen-lockfile
|
||||
cd classic
|
||||
VITE_REACT_APP_VERSION=$VERSION bun run build
|
||||
cd ../..
|
||||
- name: Set up Go
|
||||
@@ -146,16 +150,18 @@ jobs:
|
||||
env:
|
||||
CI: ""
|
||||
run: |
|
||||
cd web/default
|
||||
bun install
|
||||
cd web
|
||||
bun install --frozen-lockfile
|
||||
cd default
|
||||
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$VERSION bun run build
|
||||
cd ../..
|
||||
- name: Build Frontend (classic)
|
||||
env:
|
||||
CI: ""
|
||||
run: |
|
||||
cd web/classic
|
||||
bun install
|
||||
cd web
|
||||
bun install --frozen-lockfile
|
||||
cd classic
|
||||
VITE_REACT_APP_VERSION=$VERSION bun run build
|
||||
cd ../..
|
||||
- name: Set up Go
|
||||
|
||||
+3
-1
@@ -7,9 +7,10 @@ upload
|
||||
*.db
|
||||
build
|
||||
*.db-journal
|
||||
logs
|
||||
/logs
|
||||
web/default/dist
|
||||
web/classic/dist
|
||||
web/daisy/dist
|
||||
web/node_modules
|
||||
web/dist
|
||||
.env
|
||||
@@ -35,3 +36,4 @@ data/
|
||||
.test
|
||||
token_estimator_test.go
|
||||
skills-lock.json
|
||||
.playwright-mcp
|
||||
|
||||
@@ -0,0 +1,170 @@
|
||||
# ModelsToken 管理平台 - 产品需求文档 (PRD)
|
||||
|
||||
## 1. 产品概述
|
||||
|
||||
ModelsToken 是一个 AI API 管理与分发平台,为开发者和企业提供统一的 AI 模型接入、密钥管理、用量计费、渠道代理等一站式服务。新前端将采用 React + DaisyUI 5 + TypeScript 构建,替换现有的 Default/Classic 双前端,同时新增本地文档管理功能。
|
||||
|
||||
- 目标用户:AI 应用开发者、企业运维人员、API 服务管理者
|
||||
- 核心价值:简化 AI API 的管理复杂度,提供直观的操作界面和完整的文档支持
|
||||
|
||||
## 2. 核心功能
|
||||
|
||||
### 2.1 用户角色
|
||||
|
||||
| 角色 | 注册方式 | 核心权限 |
|
||||
|------|----------|----------|
|
||||
| 普通用户 | 用户名/邮箱/OAuth | 密钥管理、充值、订阅、日志查看、文档访问 |
|
||||
| 管理员 | 由超级管理员指定 | 渠道管理、用户管理、兑换码、模型管理、订阅管理 |
|
||||
| 超级管理员 | 系统初始化 | 全部权限 + 系统设置 |
|
||||
|
||||
### 2.2 功能模块
|
||||
|
||||
#### 公共页面(无需登录)
|
||||
1. **首页**:Hero 区域、特性展示、快速入门指引
|
||||
2. **登录页**:用户名/密码、OAuth 登录(GitHub/Discord/OIDC/LinuxDO/微信/Telegram/自定义)
|
||||
3. **注册页**:注册表单 + Turnstile 人机验证
|
||||
4. **忘记密码**:邮箱重置链接
|
||||
5. **模型定价**:模型价格列表、搜索筛选
|
||||
6. **关于页面**:项目信息、版本、许可证
|
||||
7. **用户协议/隐私政策**
|
||||
8. **初始化向导**:首次部署配置
|
||||
|
||||
#### 用户功能(需登录)
|
||||
1. **仪表盘**:额度概览、使用趋势图、API 信息面板、公告、FAQ
|
||||
2. **API 密钥管理**:创建/编辑/删除/批量操作、额度限制、模型限制、IP 限制
|
||||
3. **钱包/充值**:余额查看、兑换码充值、在线支付(易支付/Stripe/Creem/Waffo)、签到
|
||||
4. **订阅管理**:查看计划、购买订阅、当前订阅状态
|
||||
5. **使用日志**:请求日志搜索/筛选、MJ 日志、任务日志、统计图表
|
||||
6. **个人设置**:资料编辑、2FA 设置、Passkey 管理、OAuth 绑定、语言切换
|
||||
7. **Playground**:API 在线调试、Chat Completions 测试
|
||||
8. **文档中心**(新增):本地文档管理、分类浏览、搜索、Markdown 渲染
|
||||
|
||||
#### 管理员功能
|
||||
1. **渠道管理**:CRUD、测试、余额更新、标签管理、批量操作、多密钥、Codex OAuth、Ollama 管理
|
||||
2. **用户管理**:列表/搜索/创建/编辑/升降级/启禁/额度调整
|
||||
3. **兑换码管理**:CRUD、批量删除无效码
|
||||
4. **模型管理**:模型元数据 CRUD、上游同步、缺失模型检测
|
||||
5. **供应商管理**:CRUD
|
||||
6. **订阅管理**:计划 CRUD、用户订阅管理
|
||||
7. **部署管理**:io.net 部署 CRUD、容器管理、日志
|
||||
|
||||
#### 超级管理员 - 系统设置
|
||||
1. **站点设置**:名称/Logo/页脚/公告/首页内容/服务器地址
|
||||
2. **认证设置**:注册/登录开关、OAuth 配置、Turnstile、Passkey、自定义 OAuth
|
||||
3. **计费设置**:额度/倍率/支付配置/签到/分组倍率
|
||||
4. **内容设置**:公告/FAQ/Uptime Kuma/聊天/绘图/Midjourney
|
||||
5. **模型设置**:透传/思维模型/Gemini/Claude 配置
|
||||
6. **运维设置**:重试/自动禁用/SMTP/性能监控/日志
|
||||
7. **安全设置**:速率限制/敏感词/SSRF 防护/IP 过滤
|
||||
|
||||
### 2.3 新增功能 - 本地文档管理
|
||||
|
||||
| 功能 | 说明 |
|
||||
|------|------|
|
||||
| 文档分类 | 支持多级分类树,管理员可创建/编辑/删除分类 |
|
||||
| 文档 CRUD | 管理员创建/编辑/删除文档,支持 Markdown 编辑器 |
|
||||
| 文档浏览 | 用户按分类浏览文档,支持搜索 |
|
||||
| 文档搜索 | 全文搜索文档标题和内容 |
|
||||
| 文档版本 | 文档更新历史记录 |
|
||||
| 权限控制 | 可设置文档为公开/登录可见/管理员可见 |
|
||||
|
||||
## 3. 核心流程
|
||||
|
||||
### 3.1 用户认证流程
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
"访问平台" --> "已登录?"
|
||||
"已登录?" -->|"是"| "仪表盘"
|
||||
"已登录?" -->|"否"| "登录页"
|
||||
"登录页" --> "输入凭证"
|
||||
"输入凭证" --> "需要2FA?"
|
||||
"需要2FA?" -->|"是"| "输入2FA码"
|
||||
"需要2FA?" -->|"否"| "验证成功"
|
||||
"输入2FA码" --> "验证成功"
|
||||
"验证成功" --> "仪表盘"
|
||||
"登录页" --> "OAuth登录"
|
||||
"OAuth登录" --> "OAuth回调"
|
||||
"OAuth回调" --> "已绑定账号?"
|
||||
"已绑定账号?" -->|"是"| "仪表盘"
|
||||
"已绑定账号?" -->|"否"| "绑定/注册"
|
||||
```
|
||||
|
||||
### 3.2 API 调用流程
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
"创建API密钥" --> "配置密钥参数"
|
||||
"配置密钥参数" --> "使用密钥调用API"
|
||||
"使用密钥调用API" --> "平台路由到渠道"
|
||||
"平台路由到渠道" --> "返回结果"
|
||||
"返回结果" --> "记录日志"
|
||||
"记录日志" --> "扣除额度"
|
||||
```
|
||||
|
||||
### 3.3 文档管理流程(新增)
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
"管理员创建分类" --> "创建文档"
|
||||
"创建文档" --> "Markdown编辑"
|
||||
"Markdown编辑" --> "设置可见性"
|
||||
"设置可见性" --> "发布文档"
|
||||
"发布文档" --> "用户浏览/搜索"
|
||||
```
|
||||
|
||||
## 4. 用户界面设计
|
||||
|
||||
### 4.1 设计风格
|
||||
|
||||
- **主色调**:深蓝 (#1e293b) + 亮蓝 (#3b82f6) 渐变,搭配 DaisyUI 的 `business` 主题
|
||||
- **辅助色**:翡翠绿 (#10b981) 用于成功/在线状态,琥珀色 (#f59e0b) 用于警告
|
||||
- **按钮风格**:DaisyUI 默认圆角按钮,主要操作用 `btn-primary`,危险操作用 `btn-error`
|
||||
- **字体**:JetBrains Mono(代码/密钥)+ Noto Sans SC(中文正文)
|
||||
- **布局风格**:左侧固定导航栏 + 顶部状态栏 + 主内容区,响应式折叠
|
||||
- **图标**:Lucide React 图标库
|
||||
- **动效**:DaisyUI 内置动画 + 页面切换淡入
|
||||
|
||||
### 4.2 页面设计概览
|
||||
|
||||
| 页面 | 模块 | UI 元素 |
|
||||
|------|------|---------|
|
||||
| 首页 | Hero | 渐变背景、特性卡片、快速开始按钮 |
|
||||
| 登录 | 表单 | 居中卡片、OAuth 按钮组、Turnstile |
|
||||
| 仪表盘 | 统计卡片 | 4 列额度卡片、折线图、公告栏、API 信息 |
|
||||
| 密钥管理 | 数据表 | 搜索栏、筛选器、表格、批量操作栏 |
|
||||
| 渠道管理 | 数据表+表单 | 标签筛选、测试按钮、多密钥管理抽屉 |
|
||||
| 系统设置 | 标签页 | 7 大分类侧边导航、表单分组、开关/输入框 |
|
||||
| 文档中心 | 侧边树+内容 | 分类树导航、Markdown 渲染、搜索框、面包屑 |
|
||||
| Playground | 分栏 | 左侧参数面板、右侧响应面板、模型选择器 |
|
||||
|
||||
### 4.3 响应式设计
|
||||
|
||||
- 桌面优先(1280px+)
|
||||
- 平板适配(768px-1279px):侧边栏折叠为抽屉
|
||||
- 移动端适配(<768px):单列布局,表格改为卡片列表
|
||||
|
||||
### 4.4 布局结构
|
||||
|
||||
```
|
||||
┌──────────────────────────────────────────────┐
|
||||
│ 顶部导航栏 (Navbar) │
|
||||
│ Logo | 搜索 | 通知 | 用户菜单 | 主题切换 │
|
||||
├──────┬───────────────────────────────────────┤
|
||||
│ │ │
|
||||
│ 侧边 │ 主内容区 │
|
||||
│ 导航 │ │
|
||||
│ 栏 │ ┌─────────────────────────────────┐ │
|
||||
│ │ │ 面包屑 + 页面标题 + 操作按钮 │ │
|
||||
│ 仪表盘│ ├─────────────────────────────────┤ │
|
||||
│ 密钥 │ │ │ │
|
||||
│ 渠道 │ │ 页面内容 │ │
|
||||
│ 用户 │ │ │ │
|
||||
│ 日志 │ │ │ │
|
||||
│ 钱包 │ └─────────────────────────────────┘ │
|
||||
│ 订阅 │ │
|
||||
│ 文档 │ │
|
||||
│ 设置 │ │
|
||||
│ │ │
|
||||
└──────┴───────────────────────────────────────┘
|
||||
```
|
||||
@@ -0,0 +1,459 @@
|
||||
# ModelsToken 管理平台 - 技术架构文档
|
||||
|
||||
## 1. 架构设计
|
||||
|
||||
```mermaid
|
||||
flowchart TB
|
||||
subgraph "前端 (React + DaisyUI 5)"
|
||||
A["React 18"] --> B["React Router v6"]
|
||||
B --> C["页面组件"]
|
||||
C --> D["DaisyUI 5 组件"]
|
||||
D --> E["Tailwind CSS 4"]
|
||||
A --> F["Zustand 状态管理"]
|
||||
A --> G["React Query 数据请求"]
|
||||
A --> H["i18next 国际化"]
|
||||
A --> I["React Markdown 渲染"]
|
||||
end
|
||||
subgraph "后端 (Go + Gin)"
|
||||
J["Gin HTTP Server"]
|
||||
J --> K["API 路由"]
|
||||
J --> L["Relay 代理"]
|
||||
K --> M["控制器"]
|
||||
M --> N["模型层"]
|
||||
N --> O["数据库 (SQLite/MySQL/PostgreSQL)"]
|
||||
end
|
||||
C -->|"Axios HTTP"| K
|
||||
```
|
||||
|
||||
## 2. 技术说明
|
||||
|
||||
- **前端框架**:React 18 + TypeScript
|
||||
- **UI 库**:DaisyUI 5 + Tailwind CSS 4
|
||||
- **构建工具**:Vite 6
|
||||
- **路由**:React Router v6(懒加载)
|
||||
- **状态管理**:Zustand(轻量级,替代 Redux)
|
||||
- **数据请求**:TanStack React Query v5 + Axios
|
||||
- **国际化**:i18next + react-i18next
|
||||
- **图表**:Recharts
|
||||
- **Markdown**:react-markdown + remark-gfm + rehype-highlight
|
||||
- **图标**:Lucide React
|
||||
- **代码高亮**:highlight.js
|
||||
- **表单验证**:React Hook Form + Zod
|
||||
- **通知**:react-hot-toast
|
||||
- **项目目录**:`web/daisy/`
|
||||
|
||||
## 3. 路由定义
|
||||
|
||||
### 3.1 公共路由
|
||||
|
||||
| 路由 | 用途 |
|
||||
|------|------|
|
||||
| `/` | 首页 |
|
||||
| `/login` | 登录 |
|
||||
| `/register` | 注册 |
|
||||
| `/forgot-password` | 忘记密码 |
|
||||
| `/reset-password` | 密码重置确认 |
|
||||
| `/setup` | 初始化向导 |
|
||||
| `/pricing` | 模型定价 |
|
||||
| `/about` | 关于 |
|
||||
| `/user-agreement` | 用户协议 |
|
||||
| `/privacy-policy` | 隐私政策 |
|
||||
| `/oauth/callback/:provider` | OAuth 回调 |
|
||||
|
||||
### 3.2 认证后路由
|
||||
|
||||
| 路由 | 用途 |
|
||||
|------|------|
|
||||
| `/dashboard` | 仪表盘 |
|
||||
| `/tokens` | API 密钥管理 |
|
||||
| `/wallet` | 钱包/充值 |
|
||||
| `/subscriptions` | 订阅管理 |
|
||||
| `/logs` | 使用日志 |
|
||||
| `/logs/midjourney` | MJ 日志 |
|
||||
| `/logs/tasks` | 任务日志 |
|
||||
| `/profile` | 个人设置 |
|
||||
| `/playground` | Playground |
|
||||
| `/docs` | 文档中心(新增) |
|
||||
| `/docs/:slug` | 文档详情(新增) |
|
||||
|
||||
### 3.3 管理员路由
|
||||
|
||||
| 路由 | 用途 |
|
||||
|------|------|
|
||||
| `/admin/channels` | 渠道管理 |
|
||||
| `/admin/users` | 用户管理 |
|
||||
| `/admin/redemptions` | 兑换码管理 |
|
||||
| `/admin/models` | 模型管理 |
|
||||
| `/admin/vendors` | 供应商管理 |
|
||||
| `/admin/deployments` | 部署管理 |
|
||||
| `/admin/subscriptions` | 订阅计划管理 |
|
||||
|
||||
### 3.4 超级管理员路由
|
||||
|
||||
| 路由 | 用途 |
|
||||
|------|------|
|
||||
| `/settings/site` | 站点设置 |
|
||||
| `/settings/auth` | 认证设置 |
|
||||
| `/settings/billing` | 计费设置 |
|
||||
| `/settings/content` | 内容设置 |
|
||||
| `/settings/models` | 模型设置 |
|
||||
| `/settings/operations` | 运维设置 |
|
||||
| `/settings/security` | 安全设置 |
|
||||
| `/settings/docs` | 文档管理(新增) |
|
||||
|
||||
## 4. API 定义
|
||||
|
||||
### 4.1 核心类型
|
||||
|
||||
```typescript
|
||||
// 用户
|
||||
interface User {
|
||||
id: number;
|
||||
username: string;
|
||||
display_name: string;
|
||||
email: string;
|
||||
role: number; // 1=user, 10=admin, 100=root
|
||||
status: number;
|
||||
quota: number;
|
||||
used_quota: number;
|
||||
request_count: number;
|
||||
group: string;
|
||||
aff_code: string;
|
||||
inviter_id: number;
|
||||
language: string;
|
||||
access_token: string;
|
||||
created_time: number;
|
||||
}
|
||||
|
||||
// 渠道
|
||||
interface Channel {
|
||||
id: number;
|
||||
type: number;
|
||||
key: string;
|
||||
openai_organization?: string;
|
||||
base_url: string;
|
||||
models: string;
|
||||
model_mapping?: string;
|
||||
group: string;
|
||||
groups: string[];
|
||||
name: string;
|
||||
priority: number;
|
||||
weight: number;
|
||||
status: number;
|
||||
tag?: string;
|
||||
setting?: string;
|
||||
test_time: number;
|
||||
response_time: number;
|
||||
balance: number;
|
||||
balance_updated_time: number;
|
||||
created_time: number;
|
||||
}
|
||||
|
||||
// 令牌
|
||||
interface Token {
|
||||
id: number;
|
||||
user_id: number;
|
||||
key: string;
|
||||
status: number;
|
||||
name: string;
|
||||
created_time: number;
|
||||
accessed_time: number;
|
||||
expired_time: number;
|
||||
remain_quota: number;
|
||||
unlimited_quota: boolean;
|
||||
used_quota: number;
|
||||
models: string;
|
||||
subnet: string;
|
||||
group: string;
|
||||
}
|
||||
|
||||
// 日志
|
||||
interface Log {
|
||||
id: number;
|
||||
user_id: number;
|
||||
created_at: number;
|
||||
type: number;
|
||||
content: string;
|
||||
username: string;
|
||||
token_name: string;
|
||||
model_name: string;
|
||||
quota: number;
|
||||
prompt_tokens: number;
|
||||
completion_tokens: number;
|
||||
channel_id: number;
|
||||
token_id: number;
|
||||
group: string;
|
||||
request_id: string;
|
||||
ip: string;
|
||||
detail: string;
|
||||
}
|
||||
|
||||
// 订阅计划
|
||||
interface SubscriptionPlan {
|
||||
id: number;
|
||||
name: string;
|
||||
description: string;
|
||||
price: number;
|
||||
currency: string;
|
||||
duration_days: number;
|
||||
quota: number;
|
||||
models: string;
|
||||
enabled: boolean;
|
||||
sort_order: number;
|
||||
created_time: number;
|
||||
}
|
||||
|
||||
// 文档(新增)
|
||||
interface Document {
|
||||
id: number;
|
||||
title: string;
|
||||
slug: string;
|
||||
content: string; // Markdown
|
||||
category_id: number;
|
||||
category?: DocumentCategory;
|
||||
visibility: 'public' | 'auth' | 'admin';
|
||||
sort_order: number;
|
||||
created_at: string;
|
||||
updated_at: string;
|
||||
author_id: number;
|
||||
author?: User;
|
||||
versions?: DocumentVersion[];
|
||||
}
|
||||
|
||||
interface DocumentCategory {
|
||||
id: number;
|
||||
name: string;
|
||||
slug: string;
|
||||
parent_id: number | null;
|
||||
children?: DocumentCategory[];
|
||||
sort_order: number;
|
||||
}
|
||||
|
||||
interface DocumentVersion {
|
||||
id: number;
|
||||
document_id: number;
|
||||
content: string;
|
||||
created_at: string;
|
||||
author_id: number;
|
||||
}
|
||||
```
|
||||
|
||||
### 4.2 新增文档管理 API
|
||||
|
||||
| 端点 | 方法 | 权限 | 说明 |
|
||||
|------|------|------|------|
|
||||
| `/api/docs/categories` | GET | 公开 | 获取分类树 |
|
||||
| `/api/docs/categories` | POST | Admin | 创建分类 |
|
||||
| `/api/docs/categories/:id` | PUT | Admin | 更新分类 |
|
||||
| `/api/docs/categories/:id` | DELETE | Admin | 删除分类 |
|
||||
| `/api/docs/` | GET | 按可见性 | 文档列表(支持搜索) |
|
||||
| `/api/docs/:slug` | GET | 按可见性 | 获取文档详情 |
|
||||
| `/api/docs/` | POST | Admin | 创建文档 |
|
||||
| `/api/docs/:id` | PUT | Admin | 更新文档 |
|
||||
| `/api/docs/:id` | DELETE | Admin | 删除文档 |
|
||||
| `/api/docs/:id/versions` | GET | Admin | 文档版本历史 |
|
||||
|
||||
## 5. 项目目录结构
|
||||
|
||||
```
|
||||
web/daisy/
|
||||
├── index.html
|
||||
├── package.json
|
||||
├── tsconfig.json
|
||||
├── vite.config.ts
|
||||
├── tailwind.config.ts
|
||||
├── public/
|
||||
│ └── manifest.json
|
||||
└── src/
|
||||
├── main.tsx # 入口
|
||||
├── App.tsx # 根组件 + 路由
|
||||
├── vite-env.d.ts
|
||||
├── api/ # API 请求层
|
||||
│ ├── client.ts # Axios 实例 + 拦截器
|
||||
│ ├── auth.ts # 认证 API
|
||||
│ ├── channel.ts # 渠道 API
|
||||
│ ├── token.ts # 令牌 API
|
||||
│ ├── user.ts # 用户 API
|
||||
│ ├── log.ts # 日志 API
|
||||
│ ├── subscription.ts # 订阅 API
|
||||
│ ├── redemption.ts # 兑换码 API
|
||||
│ ├── model.ts # 模型 API
|
||||
│ ├── vendor.ts # 供应商 API
|
||||
│ ├── deployment.ts # 部署 API
|
||||
│ ├── option.ts # 系统设置 API
|
||||
│ ├── payment.ts # 支付 API
|
||||
│ └── doc.ts # 文档 API(新增)
|
||||
├── stores/ # Zustand 状态
|
||||
│ ├── auth.ts # 认证状态
|
||||
│ └── ui.ts # UI 状态(侧边栏/主题)
|
||||
├── hooks/ # 自定义 Hooks
|
||||
│ ├── useAuth.ts
|
||||
│ ├── usePermission.ts
|
||||
│ └── useQuota.ts
|
||||
├── components/ # 通用组件
|
||||
│ ├── layout/
|
||||
│ │ ├── AppLayout.tsx # 主布局
|
||||
│ │ ├── Sidebar.tsx # 侧边导航
|
||||
│ │ ├── Navbar.tsx # 顶部导航
|
||||
│ │ └── Breadcrumb.tsx # 面包屑
|
||||
│ ├── common/
|
||||
│ │ ├── QuotaDisplay.tsx # 额度显示
|
||||
│ │ ├── ModelBadge.tsx # 模型标签
|
||||
│ │ ├── StatusBadge.tsx # 状态标签
|
||||
│ │ ├── SearchInput.tsx # 搜索框
|
||||
│ │ ├── DataTable.tsx # 数据表格
|
||||
│ │ ├── ConfirmDialog.tsx # 确认对话框
|
||||
│ │ └── LoadingSpinner.tsx # 加载动画
|
||||
│ └── charts/
|
||||
│ ├── QuotaChart.tsx # 额度趋势图
|
||||
│ └── StatsChart.tsx # 统计图表
|
||||
├── pages/ # 页面组件
|
||||
│ ├── public/
|
||||
│ │ ├── Home.tsx
|
||||
│ │ ├── Login.tsx
|
||||
│ │ ├── Register.tsx
|
||||
│ │ ├── ForgotPassword.tsx
|
||||
│ │ ├── Pricing.tsx
|
||||
│ │ ├── About.tsx
|
||||
│ │ └── Setup.tsx
|
||||
│ ├── dashboard/
|
||||
│ │ └── Dashboard.tsx
|
||||
│ ├── tokens/
|
||||
│ │ ├── TokenList.tsx
|
||||
│ │ └── TokenForm.tsx
|
||||
│ ├── channels/
|
||||
│ │ ├── ChannelList.tsx
|
||||
│ │ └── ChannelForm.tsx
|
||||
│ ├── users/
|
||||
│ │ ├── UserList.tsx
|
||||
│ │ └── UserForm.tsx
|
||||
│ ├── logs/
|
||||
│ │ ├── LogList.tsx
|
||||
│ │ ├── MidjourneyLog.tsx
|
||||
│ │ └── TaskLog.tsx
|
||||
│ ├── wallet/
|
||||
│ │ └── Wallet.tsx
|
||||
│ ├── subscriptions/
|
||||
│ │ ├── PlanList.tsx
|
||||
│ │ └── MySubscription.tsx
|
||||
│ ├── redemptions/
|
||||
│ │ └── RedemptionList.tsx
|
||||
│ ├── models/
|
||||
│ │ └── ModelList.tsx
|
||||
│ ├── vendors/
|
||||
│ │ └── VendorList.tsx
|
||||
│ ├── deployments/
|
||||
│ │ └── DeploymentList.tsx
|
||||
│ ├── playground/
|
||||
│ │ └── Playground.tsx
|
||||
│ ├── profile/
|
||||
│ │ └── Profile.tsx
|
||||
│ ├── docs/ # 文档中心(新增)
|
||||
│ │ ├── DocCenter.tsx # 文档浏览主页
|
||||
│ │ ├── DocViewer.tsx # 文档阅读页
|
||||
│ │ ├── DocEditor.tsx # 文档编辑页(管理员)
|
||||
│ │ └── DocCategoryManager.tsx # 分类管理(管理员)
|
||||
│ └── settings/
|
||||
│ ├── SiteSettings.tsx
|
||||
│ ├── AuthSettings.tsx
|
||||
│ ├── BillingSettings.tsx
|
||||
│ ├── ContentSettings.tsx
|
||||
│ ├── ModelSettings.tsx
|
||||
│ ├── OperationsSettings.tsx
|
||||
│ ├── SecuritySettings.tsx
|
||||
│ └── DocSettings.tsx # 文档设置(新增)
|
||||
├── i18n/ # 国际化
|
||||
│ ├── index.ts
|
||||
│ └── locales/
|
||||
│ ├── en.json
|
||||
│ └── zh.json
|
||||
├── lib/ # 工具函数
|
||||
│ ├── constants.ts
|
||||
│ ├── utils.ts
|
||||
│ ├── quota.ts
|
||||
│ └── channel-types.ts
|
||||
└── types/ # TypeScript 类型
|
||||
├── api.ts
|
||||
├── channel.ts
|
||||
├── token.ts
|
||||
├── user.ts
|
||||
├── log.ts
|
||||
├── subscription.ts
|
||||
├── doc.ts
|
||||
└── option.ts
|
||||
```
|
||||
|
||||
## 6. 数据模型(新增文档管理)
|
||||
|
||||
```mermaid
|
||||
erDiagram
|
||||
"document_categories" {
|
||||
int id PK
|
||||
string name
|
||||
string slug UK
|
||||
int parent_id FK
|
||||
int sort_order
|
||||
timestamp created_at
|
||||
}
|
||||
"documents" {
|
||||
int id PK
|
||||
string title
|
||||
string slug UK
|
||||
text content
|
||||
int category_id FK
|
||||
string visibility
|
||||
int sort_order
|
||||
int author_id FK
|
||||
timestamp created_at
|
||||
timestamp updated_at
|
||||
}
|
||||
"document_versions" {
|
||||
int id PK
|
||||
int document_id FK
|
||||
text content
|
||||
int author_id FK
|
||||
timestamp created_at
|
||||
}
|
||||
"document_categories" ||--o{ "document_categories" : "parent"
|
||||
"document_categories" ||--o{ "documents" : "has"
|
||||
"documents" ||--o{ "document_versions" : "has"
|
||||
```
|
||||
|
||||
### DDL
|
||||
|
||||
```sql
|
||||
CREATE TABLE document_categories (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name VARCHAR(100) NOT NULL,
|
||||
slug VARCHAR(100) NOT NULL UNIQUE,
|
||||
parent_id INTEGER REFERENCES document_categories(id) ON DELETE SET NULL,
|
||||
sort_order INTEGER DEFAULT 0,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE TABLE documents (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
title VARCHAR(200) NOT NULL,
|
||||
slug VARCHAR(200) NOT NULL UNIQUE,
|
||||
content TEXT NOT NULL,
|
||||
category_id INTEGER REFERENCES document_categories(id) ON DELETE SET NULL,
|
||||
visibility VARCHAR(20) DEFAULT 'public' CHECK (visibility IN ('public', 'auth', 'admin')),
|
||||
sort_order INTEGER DEFAULT 0,
|
||||
author_id INTEGER NOT NULL,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE TABLE document_versions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
document_id INTEGER NOT NULL REFERENCES documents(id) ON DELETE CASCADE,
|
||||
content TEXT NOT NULL,
|
||||
author_id INTEGER NOT NULL,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE INDEX idx_documents_slug ON documents(slug);
|
||||
CREATE INDEX idx_documents_category ON documents(category_id);
|
||||
CREATE INDEX idx_documents_visibility ON documents(visibility);
|
||||
CREATE INDEX idx_document_versions_doc ON document_versions(document_id);
|
||||
```
|
||||
+18
-16
@@ -1,22 +1,24 @@
|
||||
FROM oven/bun:1@sha256:0733e50325078969732ebe3b15ce4c4be5082f18c4ac1a0f0ca4839c2e4e42a7 AS builder
|
||||
|
||||
WORKDIR /build
|
||||
COPY web/default/package.json .
|
||||
COPY web/default/bun.lock .
|
||||
RUN bun install
|
||||
COPY ./web/default .
|
||||
COPY ./VERSION .
|
||||
RUN DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) bun run build
|
||||
WORKDIR /build/web
|
||||
COPY web/package.json web/bun.lock ./
|
||||
COPY web/default/package.json ./default/package.json
|
||||
COPY web/classic/package.json ./classic/package.json
|
||||
RUN bun install --frozen-lockfile
|
||||
COPY ./web/default ./default
|
||||
COPY ./VERSION /build/VERSION
|
||||
RUN cd default && DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat /build/VERSION) bun run build
|
||||
|
||||
FROM oven/bun:1@sha256:0733e50325078969732ebe3b15ce4c4be5082f18c4ac1a0f0ca4839c2e4e42a7 AS builder-classic
|
||||
|
||||
WORKDIR /build
|
||||
COPY web/classic/package.json .
|
||||
COPY web/classic/bun.lock .
|
||||
RUN bun install
|
||||
COPY ./web/classic .
|
||||
COPY ./VERSION .
|
||||
RUN VITE_REACT_APP_VERSION=$(cat VERSION) bun run build
|
||||
WORKDIR /build/web
|
||||
COPY web/package.json web/bun.lock ./
|
||||
COPY web/default/package.json ./default/package.json
|
||||
COPY web/classic/package.json ./classic/package.json
|
||||
RUN bun install --frozen-lockfile
|
||||
COPY ./web/classic ./classic
|
||||
COPY ./VERSION /build/VERSION
|
||||
RUN cd classic && VITE_REACT_APP_VERSION=$(cat /build/VERSION) bun run build
|
||||
|
||||
FROM golang:1.26.1-alpine@sha256:2389ebfa5b7f43eeafbd6be0c3700cc46690ef842ad962f6c5bd6be49ed82039 AS builder2
|
||||
ENV GO111MODULE=on CGO_ENABLED=0
|
||||
@@ -32,8 +34,8 @@ ADD go.mod go.sum ./
|
||||
RUN go mod download
|
||||
|
||||
COPY . .
|
||||
COPY --from=builder /build/dist ./web/default/dist
|
||||
COPY --from=builder-classic /build/dist ./web/classic/dist
|
||||
COPY --from=builder /build/web/default/dist ./web/default/dist
|
||||
COPY --from=builder-classic /build/web/classic/dist ./web/classic/dist
|
||||
RUN go build -ldflags "-s -w -X 'github.com/QuantumNous/new-api/common.Version=$(cat VERSION)'" -o new-api
|
||||
|
||||
FROM debian:bookworm-slim@sha256:f06537653ac770703bc45b4b113475bd402f451e85223f0f2837acbf89ab020a
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
new-api Notices
|
||||
|
||||
new-api
|
||||
Copyright (c) QuantumNous and contributors.
|
||||
Copyright (c) modelstoken and contributors.
|
||||
|
||||
This project is licensed under the GNU Affero General Public License v3.0.
|
||||
See LICENSE for the full project license terms.
|
||||
@@ -19,7 +19,7 @@ Modified versions that present a user interface must also preserve a visible
|
||||
link to the original project in a prominent about, legal, footer, or
|
||||
attribution location:
|
||||
|
||||
https://github.com/QuantumNous/new-api
|
||||
https://git.viaeon.com/admin/new-api
|
||||
|
||||
Modified versions must not misrepresent the origin of the software and must
|
||||
mark their changes in accordance with AGPLv3 Section 7(c).
|
||||
|
||||
@@ -316,6 +316,7 @@ docker run --name new-api -d --restart always \
|
||||
| `CRYPTO_SECRET` | Encryption secret (required for Redis) | - |
|
||||
| `SQL_DSN` | Database connection string | - |
|
||||
| `REDIS_CONN_STRING` | Redis connection string | - |
|
||||
| `RELAY_IDLE_CONN_TIMEOUT` | Idle keep-alive timeout for relay HTTP clients, seconds. Defaults to Go standard library behavior; set `0` to disable | `90` |
|
||||
| `STREAMING_TIMEOUT` | Streaming timeout (seconds) | `300` |
|
||||
| `STREAM_SCANNER_MAX_BUFFER_MB` | Max per-line buffer (MB) for the stream scanner; increase when upstream sends huge image/base64 payloads | `64` |
|
||||
| `MAX_REQUEST_BODY_MB` | Max request body size (MB, counted **after decompression**; prevents huge requests/zip bombs from exhausting memory). Exceeding it returns `413` | `32` |
|
||||
|
||||
+2
-1
@@ -14,7 +14,7 @@ import (
|
||||
|
||||
var StartTime = time.Now().Unix() // unit: second
|
||||
var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change
|
||||
var SystemName = "New API"
|
||||
var SystemName = "ModelsToken"
|
||||
var Footer = ""
|
||||
var Logo = ""
|
||||
var TopUpLink = ""
|
||||
@@ -170,6 +170,7 @@ var BatchUpdateInterval int
|
||||
|
||||
var RelayTimeout int // unit is second
|
||||
|
||||
var RelayIdleConnTimeout int // unit is second
|
||||
var RelayMaxIdleConns int
|
||||
var RelayMaxIdleConnsPerHost int
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ func checkWriter(writer io.Writer) stringWriter {
|
||||
// W3C Working Draft 29 October 2009
|
||||
// http://www.w3.org/TR/2009/WD-eventsource-20091029/
|
||||
|
||||
var contentType = []string{"text/event-stream"}
|
||||
var writeContentType = []string{"text/event-stream"}
|
||||
var noCache = []string{"no-cache"}
|
||||
|
||||
var fieldReplacer = strings.NewReplacer(
|
||||
@@ -79,7 +79,7 @@ func (r CustomEvent) WriteContentType(w http.ResponseWriter) {
|
||||
r.Mutex.Lock()
|
||||
defer r.Mutex.Unlock()
|
||||
header := w.Header()
|
||||
header["Content-Type"] = contentType
|
||||
header["Content-Type"] = writeContentType
|
||||
|
||||
if _, exist := header["Cache-Control"]; !exist {
|
||||
header["Cache-Control"] = noCache
|
||||
|
||||
@@ -51,17 +51,21 @@ type themeAwareFileSystem struct {
|
||||
}
|
||||
|
||||
func (t *themeAwareFileSystem) Exists(prefix string, path string) bool {
|
||||
if GetTheme() == "classic" {
|
||||
switch GetTheme() {
|
||||
case "classic":
|
||||
return t.classicFS.Exists(prefix, path)
|
||||
default:
|
||||
return t.defaultFS.Exists(prefix, path)
|
||||
}
|
||||
return t.defaultFS.Exists(prefix, path)
|
||||
}
|
||||
|
||||
func (t *themeAwareFileSystem) Open(name string) (http.File, error) {
|
||||
if GetTheme() == "classic" {
|
||||
switch GetTheme() {
|
||||
case "classic":
|
||||
return t.classicFS.Open(name)
|
||||
default:
|
||||
return t.defaultFS.Open(name)
|
||||
}
|
||||
return t.defaultFS.Open(name)
|
||||
}
|
||||
|
||||
func NewThemeAwareFS(defaultFS, classicFS static.ServeFileSystem) static.ServeFileSystem {
|
||||
|
||||
+19
-1
@@ -110,11 +110,29 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
contentType := c.Request.Header.Get("Content-Type")
|
||||
|
||||
// disk-backed JSON: stream-decode directly from the file to avoid
|
||||
// materializing the entire payload back into a transient []byte
|
||||
// (diskStorage.Bytes() would ReadFull the whole file into the heap).
|
||||
if storage.IsDisk() && strings.HasPrefix(contentType, "application/json") {
|
||||
if _, seekErr := storage.Seek(0, io.SeekStart); seekErr != nil {
|
||||
return seekErr
|
||||
}
|
||||
if err := DecodeJson(storage, v); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, seekErr := storage.Seek(0, io.SeekStart); seekErr != nil {
|
||||
return seekErr
|
||||
}
|
||||
c.Request.Body = io.NopCloser(storage)
|
||||
return nil
|
||||
}
|
||||
|
||||
requestBody, err := storage.Bytes()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
contentType := c.Request.Header.Get("Content-Type")
|
||||
if strings.HasPrefix(contentType, "application/json") {
|
||||
err = Unmarshal(requestBody, v)
|
||||
} else if strings.Contains(contentType, gin.MIMEPOSTForm) {
|
||||
|
||||
+4
-2
@@ -102,6 +102,7 @@ func InitEnv() {
|
||||
SyncFrequency = GetEnvOrDefault("SYNC_FREQUENCY", 60)
|
||||
BatchUpdateInterval = GetEnvOrDefault("BATCH_UPDATE_INTERVAL", 5)
|
||||
RelayTimeout = GetEnvOrDefault("RELAY_TIMEOUT", 0)
|
||||
RelayIdleConnTimeout = GetEnvOrDefault("RELAY_IDLE_CONN_TIMEOUT", 90)
|
||||
RelayMaxIdleConns = GetEnvOrDefault("RELAY_MAX_IDLE_CONNS", 500)
|
||||
RelayMaxIdleConnsPerHost = GetEnvOrDefault("RELAY_MAX_IDLE_CONNS_PER_HOST", 100)
|
||||
|
||||
@@ -111,11 +112,11 @@ func InitEnv() {
|
||||
|
||||
// Initialize rate limit variables
|
||||
GlobalApiRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_API_RATE_LIMIT_ENABLE", true)
|
||||
GlobalApiRateLimitNum = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 180)
|
||||
GlobalApiRateLimitNum = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 360)
|
||||
GlobalApiRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_API_RATE_LIMIT_DURATION", 180))
|
||||
|
||||
GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true)
|
||||
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
|
||||
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 120)
|
||||
GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180))
|
||||
|
||||
CriticalRateLimitEnable = GetEnvOrDefaultBool("CRITICAL_RATE_LIMIT_ENABLE", true)
|
||||
@@ -135,6 +136,7 @@ func initConstantEnv() {
|
||||
constant.StreamScannerMaxBufferMB = GetEnvOrDefault("STREAM_SCANNER_MAX_BUFFER_MB", 128)
|
||||
// MaxRequestBodyMB 请求体最大大小(解压后),用于防止超大请求/zip bomb导致内存暴涨
|
||||
constant.MaxRequestBodyMB = GetEnvOrDefault("MAX_REQUEST_BODY_MB", 128)
|
||||
constant.AnonymousRequestBodyLimitKB = GetEnvOrDefault("ANONYMOUS_REQUEST_BODY_LIMIT_KB", 512)
|
||||
// ForceStreamOption 覆盖请求参数,强制返回usage信息
|
||||
constant.ForceStreamOption = GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
|
||||
constant.CountToken = GetEnvOrDefaultBool("CountToken", true)
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
package common
|
||||
|
||||
import "github.com/QuantumNous/new-api/constant"
|
||||
|
||||
const defaultAnonymousRequestBodyLimitKB = 512
|
||||
|
||||
func GetAnonymousRequestBodyLimitBytes() int64 {
|
||||
limitKB := constant.AnonymousRequestBodyLimitKB
|
||||
if limitKB < 0 {
|
||||
limitKB = defaultAnonymousRequestBodyLimitKB
|
||||
}
|
||||
return int64(limitKB) << 10
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package common
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
@@ -20,6 +21,16 @@ var (
|
||||
maskApiKeyPattern = regexp.MustCompile(`(['"]?)api_key:([^\s'"]+)(['"]?)`)
|
||||
)
|
||||
|
||||
const LocalLogContentLimit = 2048
|
||||
|
||||
// LocalLogPreview limits log-only content unless debug logging is enabled.
|
||||
func LocalLogPreview(content string) string {
|
||||
if DebugEnabled || len(content) <= LocalLogContentLimit {
|
||||
return content
|
||||
}
|
||||
return fmt.Sprintf("%s... [truncated, original_length=%d, limit=%d]", content[:LocalLogContentLimit], len(content), LocalLogContentLimit)
|
||||
}
|
||||
|
||||
func GetStringIfEmpty(str string, defaultValue string) string {
|
||||
if str == "" {
|
||||
return defaultValue
|
||||
|
||||
@@ -206,4 +206,8 @@ var ChannelSpecialBases = map[string]ChannelSpecialBase{
|
||||
ClaudeBaseURL: "https://ark.cn-beijing.volces.com/api/coding",
|
||||
OpenAIBaseURL: "https://ark.cn-beijing.volces.com/api/coding/v3",
|
||||
},
|
||||
"tencent-coding-plan": {
|
||||
ClaudeBaseURL: "https://api.lkeap.cloud.tencent.com/coding",
|
||||
OpenAIBaseURL: "https://api.lkeap.cloud.tencent.com/coding/v3",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ var GetMediaToken bool
|
||||
var GetMediaTokenNotStream bool
|
||||
var UpdateTask bool
|
||||
var MaxRequestBodyMB int
|
||||
var AnonymousRequestBodyLimitKB int
|
||||
var AzureDefaultAPIVersion string
|
||||
var NotifyLimitCount int
|
||||
var NotificationLimitDurationMinute int
|
||||
|
||||
@@ -57,7 +57,24 @@ func normalizeChannelTestEndpoint(channel *model.Channel, modelName, endpointTyp
|
||||
return normalized
|
||||
}
|
||||
|
||||
func testChannel(channel *model.Channel, testModel string, endpointType string, isStream bool) testResult {
|
||||
func resolveChannelTestUserID(c *gin.Context) (int, error) {
|
||||
if c != nil {
|
||||
if userID := c.GetInt("id"); userID > 0 {
|
||||
return userID, nil
|
||||
}
|
||||
}
|
||||
|
||||
var rootUser model.User
|
||||
if err := model.DB.Select("id").Where("role = ?", common.RoleRootUser).First(&rootUser).Error; err != nil {
|
||||
return 0, fmt.Errorf("failed to resolve channel test user: %w", err)
|
||||
}
|
||||
if rootUser.Id == 0 {
|
||||
return 0, errors.New("failed to resolve channel test user")
|
||||
}
|
||||
return rootUser.Id, nil
|
||||
}
|
||||
|
||||
func testChannel(channel *model.Channel, testUserID int, testModel string, endpointType string, isStream bool) testResult {
|
||||
tik := time.Now()
|
||||
var unsupportedTestChannelTypes = []int{
|
||||
constant.ChannelTypeMidjourney,
|
||||
@@ -143,7 +160,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
|
||||
Header: make(http.Header),
|
||||
}
|
||||
|
||||
cache, err := model.GetUserCache(1)
|
||||
cache, err := model.GetUserCache(testUserID)
|
||||
if err != nil {
|
||||
return testResult{
|
||||
localErr: err,
|
||||
@@ -151,13 +168,13 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
|
||||
}
|
||||
}
|
||||
cache.WriteContext(c)
|
||||
c.Set("id", 1)
|
||||
c.Set("id", testUserID)
|
||||
|
||||
//c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Set("channel", channel.Type)
|
||||
c.Set("base_url", channel.GetBaseURL())
|
||||
group, _ := model.GetUserGroup(1, false)
|
||||
group, _ := model.GetUserGroup(testUserID, false)
|
||||
c.Set("group", group)
|
||||
|
||||
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, testModel)
|
||||
@@ -484,7 +501,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
consumedTime := float64(milliseconds) / 1000.0
|
||||
other := buildTestLogOther(c, info, priceData, usage, tieredResult)
|
||||
model.RecordConsumeLog(c, 1, model.RecordConsumeLogParams{
|
||||
model.RecordConsumeLog(c, testUserID, model.RecordConsumeLogParams{
|
||||
ChannelId: channel.Id,
|
||||
PromptTokens: usage.PromptTokens,
|
||||
CompletionTokens: usage.CompletionTokens,
|
||||
@@ -797,7 +814,7 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel,
|
||||
testRequest.StreamOptions = &dto.StreamOptions{IncludeUsage: true}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(model, "o") {
|
||||
if dto.IsOpenAIReasoningOModel(model) {
|
||||
testRequest.MaxCompletionTokens = lo.ToPtr(uint(16))
|
||||
} else if strings.Contains(model, "thinking") {
|
||||
if !strings.Contains(model, "claude") {
|
||||
@@ -834,8 +851,13 @@ func TestChannel(c *gin.Context) {
|
||||
testModel := c.Query("model")
|
||||
endpointType := c.Query("endpoint_type")
|
||||
isStream, _ := strconv.ParseBool(c.Query("stream"))
|
||||
testUserID, err := resolveChannelTestUserID(c)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
tik := time.Now()
|
||||
result := testChannel(channel, testModel, endpointType, isStream)
|
||||
result := testChannel(channel, testUserID, testModel, endpointType, isStream)
|
||||
if result.localErr != nil {
|
||||
resp := gin.H{
|
||||
"success": false,
|
||||
@@ -872,6 +894,10 @@ var testAllChannelsLock sync.Mutex
|
||||
var testAllChannelsRunning bool = false
|
||||
|
||||
func testAllChannels(notify bool) error {
|
||||
testUserID, err := resolveChannelTestUserID(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
testAllChannelsLock.Lock()
|
||||
if testAllChannelsRunning {
|
||||
@@ -902,7 +928,7 @@ func testAllChannels(notify bool) error {
|
||||
}
|
||||
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
|
||||
tik := time.Now()
|
||||
result := testChannel(channel, "", "", shouldUseStreamForAutomaticChannelTest(channel))
|
||||
result := testChannel(channel, testUserID, "", "", shouldUseStreamForAutomaticChannelTest(channel))
|
||||
tok := time.Now()
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
|
||||
|
||||
@@ -1218,7 +1218,7 @@ func CopyChannel(c *gin.Context) {
|
||||
}
|
||||
|
||||
// insert
|
||||
if err := model.BatchInsertChannels([]model.Channel{clone}); err != nil {
|
||||
if err := clone.Insert(); err != nil {
|
||||
common.SysError("failed to clone channel: " + err.Error())
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "复制渠道失败,请稍后重试"})
|
||||
return
|
||||
|
||||
@@ -69,3 +69,14 @@ func TestBuildTestLogOtherInjectsTieredInfo(t *testing.T) {
|
||||
require.Equal(t, "base", other["matched_tier"])
|
||||
require.NotEmpty(t, other["expr_b64"])
|
||||
}
|
||||
|
||||
func TestResolveChannelTestUserIDUsesRequestUser(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||
ctx.Set("id", 2)
|
||||
|
||||
userID, err := resolveChannelTestUserID(ctx)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 2, userID)
|
||||
}
|
||||
|
||||
@@ -312,7 +312,11 @@ func fetchChannelUpstreamModelIDs(channel *model.Channel) ([]string, error) {
|
||||
url = fmt.Sprintf("%s/v1/models", baseURL)
|
||||
}
|
||||
default:
|
||||
url = fmt.Sprintf("%s/v1/models", baseURL)
|
||||
if plan, ok := constant.ChannelSpecialBases[baseURL]; ok && plan.OpenAIBaseURL != "" {
|
||||
url = fmt.Sprintf("%s/models", plan.OpenAIBaseURL)
|
||||
} else {
|
||||
url = fmt.Sprintf("%s/v1/models", baseURL)
|
||||
}
|
||||
}
|
||||
|
||||
key, _, apiErr := channel.GetNextEnabledKey()
|
||||
|
||||
@@ -0,0 +1,256 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// GetCategories 获取文档分类列表(公开)
|
||||
func GetCategories(c *gin.Context) {
|
||||
categories, err := model.GetDocumentCategories()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
common.ApiSuccess(c, categories)
|
||||
}
|
||||
|
||||
// CreateCategory 创建文档分类(管理员)
|
||||
func CreateCategory(c *gin.Context) {
|
||||
var category model.DocumentCategory
|
||||
if err := c.ShouldBindJSON(&category); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if category.Name == "" {
|
||||
common.ApiErrorMsg(c, "分类名称不能为空")
|
||||
return
|
||||
}
|
||||
if category.Slug == "" {
|
||||
common.ApiErrorMsg(c, "分类标识不能为空")
|
||||
return
|
||||
}
|
||||
if err := model.CreateDocumentCategory(&category); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
common.ApiSuccess(c, &category)
|
||||
}
|
||||
|
||||
// UpdateCategory 更新文档分类(管理员)
|
||||
func UpdateCategory(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
var category model.DocumentCategory
|
||||
if err := c.ShouldBindJSON(&category); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
category.Id = id
|
||||
if err := model.UpdateDocumentCategory(&category); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
common.ApiSuccess(c, &category)
|
||||
}
|
||||
|
||||
// DeleteCategory 删除文档分类(管理员)
|
||||
func DeleteCategory(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if err := model.DeleteDocumentCategory(id); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
common.ApiSuccess(c, nil)
|
||||
}
|
||||
|
||||
// GetDocuments 获取文档列表(公开,根据认证状态过滤可见性)
|
||||
func GetDocuments(c *gin.Context) {
|
||||
keyword := c.Query("keyword")
|
||||
categoryIdStr := c.Query("category_id")
|
||||
|
||||
var categoryId *int
|
||||
if categoryIdStr != "" {
|
||||
id, err := strconv.Atoi(categoryIdStr)
|
||||
if err == nil {
|
||||
categoryId = &id
|
||||
}
|
||||
}
|
||||
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
|
||||
// 根据用户认证状态决定可见性过滤
|
||||
visibility := c.Query("visibility")
|
||||
role := c.GetInt("role")
|
||||
|
||||
var documents []*model.Document
|
||||
var total int64
|
||||
var err error
|
||||
|
||||
if role >= common.RoleAdminUser {
|
||||
// 管理员可看所有
|
||||
documents, total, err = model.GetDocuments(keyword, visibility, categoryId, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||
} else if role >= common.RoleCommonUser {
|
||||
// 普通用户只能看 public 和 auth
|
||||
if visibility == "public" || visibility == "auth" {
|
||||
documents, total, err = model.GetDocuments(keyword, visibility, categoryId, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||
} else {
|
||||
documents, total, err = model.GetDocumentsByVisibility(keyword, []string{"public", "auth"}, categoryId, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||
}
|
||||
} else {
|
||||
// 未登录用户只能看 public
|
||||
documents, total, err = model.GetDocuments(keyword, "public", categoryId, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(documents)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
}
|
||||
|
||||
// GetDocument 获取单个文档(根据可见性检查权限)
|
||||
func GetDocument(c *gin.Context) {
|
||||
slug := c.Param("slug")
|
||||
doc, err := model.GetDocumentBySlug(slug)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查可见性权限
|
||||
role := c.GetInt("role")
|
||||
switch doc.Visibility {
|
||||
case "admin":
|
||||
if role < common.RoleAdminUser {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无权访问该文档",
|
||||
})
|
||||
return
|
||||
}
|
||||
case "auth":
|
||||
if role < common.RoleCommonUser {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "请先登录后查看该文档",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
common.ApiSuccess(c, doc)
|
||||
}
|
||||
|
||||
// CreateDocument 创建文档(管理员)
|
||||
func CreateDocument(c *gin.Context) {
|
||||
var doc model.Document
|
||||
if err := c.ShouldBindJSON(&doc); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if doc.Title == "" {
|
||||
common.ApiErrorMsg(c, "文档标题不能为空")
|
||||
return
|
||||
}
|
||||
if doc.Slug == "" {
|
||||
common.ApiErrorMsg(c, "文档标识不能为空")
|
||||
return
|
||||
}
|
||||
if doc.Content == "" {
|
||||
common.ApiErrorMsg(c, "文档内容不能为空")
|
||||
return
|
||||
}
|
||||
if doc.Visibility == "" {
|
||||
doc.Visibility = "public"
|
||||
}
|
||||
doc.AuthorId = c.GetInt("id")
|
||||
if err := model.CreateDocument(&doc); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
common.ApiSuccess(c, &doc)
|
||||
}
|
||||
|
||||
// UpdateDocument 更新文档(管理员,自动创建版本记录)
|
||||
func UpdateDocument(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
var doc model.Document
|
||||
if err := c.ShouldBindJSON(&doc); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
doc.Id = id
|
||||
|
||||
// 获取旧文档内容,自动创建版本记录
|
||||
oldDoc, err := model.GetDocumentById(id)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
version := &model.DocumentVersion{
|
||||
DocumentId: oldDoc.Id,
|
||||
Content: oldDoc.Content,
|
||||
AuthorId: oldDoc.AuthorId,
|
||||
}
|
||||
if err := model.CreateDocumentVersion(version); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := model.UpdateDocument(&doc); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
common.ApiSuccess(c, &doc)
|
||||
}
|
||||
|
||||
// DeleteDocument 删除文档(管理员)
|
||||
func DeleteDocument(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if err := model.DeleteDocument(id); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
common.ApiSuccess(c, nil)
|
||||
}
|
||||
|
||||
// GetDocumentVersions 获取文档版本历史(管理员)
|
||||
func GetDocumentVersions(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
versions, total, err := model.GetDocumentVersions(id, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(versions)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
}
|
||||
@@ -88,6 +88,7 @@ func GetStatus(c *gin.Context) {
|
||||
"demo_site_enabled": operation_setting.DemoSiteEnabled,
|
||||
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
|
||||
"register_enabled": common.RegisterEnabled,
|
||||
"password_login_enabled": common.PasswordLoginEnabled,
|
||||
"password_register_enabled": common.PasswordRegisterEnabled,
|
||||
"default_use_auto_group": setting.DefaultUseAutoGroup,
|
||||
|
||||
|
||||
+120
-43
@@ -3,6 +3,7 @@ package controller
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
@@ -109,9 +110,102 @@ func init() {
|
||||
})
|
||||
}
|
||||
|
||||
func ListModels(c *gin.Context, modelType int) {
|
||||
userOpenAiModels := make([]dto.OpenAIModels, 0)
|
||||
func channelOwnerName(channelType int) string {
|
||||
apiType, success := common.ChannelType2APIType(channelType)
|
||||
if !success {
|
||||
return strings.ToLower(constant.GetChannelTypeName(channelType))
|
||||
}
|
||||
adaptor := relay.GetAdaptor(apiType)
|
||||
if adaptor == nil {
|
||||
return strings.ToLower(constant.GetChannelTypeName(channelType))
|
||||
}
|
||||
adaptor.Init(&relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{
|
||||
ChannelType: channelType,
|
||||
}})
|
||||
if name := strings.TrimSpace(adaptor.GetChannelName()); name != "" {
|
||||
return name
|
||||
}
|
||||
return strings.ToLower(constant.GetChannelTypeName(channelType))
|
||||
}
|
||||
|
||||
func getPreferredModelOwners(modelNames []string, groups []string) map[string]string {
|
||||
channelTypes, err := model.GetPreferredModelOwnerChannelTypes(modelNames, groups)
|
||||
if err != nil {
|
||||
common.SysLog(fmt.Sprintf("GetPreferredModelOwnerChannelTypes error: %v", err))
|
||||
return map[string]string{}
|
||||
}
|
||||
|
||||
ownerByChannelType := make(map[int]string)
|
||||
owners := make(map[string]string, len(channelTypes))
|
||||
for modelName, channelType := range channelTypes {
|
||||
owner, ok := ownerByChannelType[channelType]
|
||||
if !ok {
|
||||
owner = channelOwnerName(channelType)
|
||||
ownerByChannelType[channelType] = owner
|
||||
}
|
||||
if owner != "" {
|
||||
owners[modelName] = owner
|
||||
}
|
||||
}
|
||||
return owners
|
||||
}
|
||||
|
||||
func buildOpenAIModel(modelName string, ownerByModel map[string]string) dto.OpenAIModels {
|
||||
var oaiModel dto.OpenAIModels
|
||||
if staticModel, ok := openAIModelsMap[modelName]; ok {
|
||||
oaiModel = staticModel
|
||||
} else {
|
||||
oaiModel = dto.OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "custom",
|
||||
}
|
||||
}
|
||||
if owner, ok := ownerByModel[modelName]; ok && owner != "" {
|
||||
oaiModel.OwnedBy = owner
|
||||
}
|
||||
oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(modelName)
|
||||
return oaiModel
|
||||
}
|
||||
|
||||
type modelListGroups struct {
|
||||
userGroup string
|
||||
tokenGroup string
|
||||
ownerGroups []string
|
||||
}
|
||||
|
||||
func getModelListGroups(c *gin.Context) (modelListGroups, error) {
|
||||
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
|
||||
userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
|
||||
if userGroup == "" && (tokenGroup == "" || tokenGroup == "auto") {
|
||||
var err error
|
||||
userGroup, err = model.GetUserGroup(c.GetInt("id"), false)
|
||||
if err != nil {
|
||||
return modelListGroups{}, err
|
||||
}
|
||||
}
|
||||
|
||||
if tokenGroup == "auto" {
|
||||
return modelListGroups{
|
||||
userGroup: userGroup,
|
||||
tokenGroup: tokenGroup,
|
||||
ownerGroups: service.GetUserAutoGroup(userGroup),
|
||||
}, nil
|
||||
}
|
||||
|
||||
group := userGroup
|
||||
if tokenGroup != "" {
|
||||
group = tokenGroup
|
||||
}
|
||||
return modelListGroups{
|
||||
userGroup: userGroup,
|
||||
tokenGroup: tokenGroup,
|
||||
ownerGroups: []string{group},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func ListModels(c *gin.Context, modelType int) {
|
||||
acceptUnsetRatioModel := operation_setting.SelfUseModeEnabled
|
||||
if !acceptUnsetRatioModel {
|
||||
userId := c.GetInt("id")
|
||||
@@ -123,6 +217,16 @@ func ListModels(c *gin.Context, modelType int) {
|
||||
}
|
||||
}
|
||||
|
||||
userModelNames := make([]string, 0)
|
||||
groups, err := getModelListGroups(c)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "get user group failed",
|
||||
})
|
||||
return
|
||||
}
|
||||
ownerGroups := groups.ownerGroups
|
||||
modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
|
||||
if modelLimitEnable {
|
||||
s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
|
||||
@@ -138,37 +242,12 @@ func ListModels(c *gin.Context, modelType int) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if oaiModel, ok := openAIModelsMap[allowModel]; ok {
|
||||
oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(allowModel)
|
||||
userOpenAiModels = append(userOpenAiModels, oaiModel)
|
||||
} else {
|
||||
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
|
||||
Id: allowModel,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "custom",
|
||||
SupportedEndpointTypes: model.GetModelSupportEndpointTypes(allowModel),
|
||||
})
|
||||
}
|
||||
userModelNames = append(userModelNames, allowModel)
|
||||
}
|
||||
} else {
|
||||
userId := c.GetInt("id")
|
||||
userGroup, err := model.GetUserGroup(userId, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "get user group failed",
|
||||
})
|
||||
return
|
||||
}
|
||||
group := userGroup
|
||||
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
|
||||
if tokenGroup != "" {
|
||||
group = tokenGroup
|
||||
}
|
||||
var models []string
|
||||
if tokenGroup == "auto" {
|
||||
for _, autoGroup := range service.GetUserAutoGroup(userGroup) {
|
||||
if groups.tokenGroup == "auto" {
|
||||
for _, autoGroup := range ownerGroups {
|
||||
groupModels := model.GetGroupEnabledModels(autoGroup)
|
||||
for _, g := range groupModels {
|
||||
if !common.StringsContains(models, g) {
|
||||
@@ -177,7 +256,7 @@ func ListModels(c *gin.Context, modelType int) {
|
||||
}
|
||||
}
|
||||
} else {
|
||||
models = model.GetGroupEnabledModels(group)
|
||||
models = model.GetGroupEnabledModels(ownerGroups[0])
|
||||
}
|
||||
for _, modelName := range models {
|
||||
if !acceptUnsetRatioModel {
|
||||
@@ -185,21 +264,19 @@ func ListModels(c *gin.Context, modelType int) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if oaiModel, ok := openAIModelsMap[modelName]; ok {
|
||||
oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(modelName)
|
||||
userOpenAiModels = append(userOpenAiModels, oaiModel)
|
||||
} else {
|
||||
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "custom",
|
||||
SupportedEndpointTypes: model.GetModelSupportEndpointTypes(modelName),
|
||||
})
|
||||
}
|
||||
userModelNames = append(userModelNames, modelName)
|
||||
}
|
||||
}
|
||||
|
||||
ownerByModel := map[string]string{}
|
||||
if len(ownerGroups) > 0 {
|
||||
ownerByModel = getPreferredModelOwners(userModelNames, ownerGroups)
|
||||
}
|
||||
userOpenAiModels := make([]dto.OpenAIModels, 0, len(userModelNames))
|
||||
for _, modelName := range userModelNames {
|
||||
userOpenAiModels = append(userOpenAiModels, buildOpenAIModel(modelName, ownerByModel))
|
||||
}
|
||||
|
||||
switch modelType {
|
||||
case constant.ChannelTypeAnthropic:
|
||||
useranthropicModels := make([]dto.AnthropicModel, len(userOpenAiModels))
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestChannelOwnerNameUsesAdaptorChannelName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
channelType int
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "openai",
|
||||
channelType: constant.ChannelTypeOpenAI,
|
||||
expected: "openai",
|
||||
},
|
||||
{
|
||||
name: "codex",
|
||||
channelType: constant.ChannelTypeCodex,
|
||||
expected: "codex",
|
||||
},
|
||||
{
|
||||
name: "openrouter",
|
||||
channelType: constant.ChannelTypeOpenRouter,
|
||||
expected: "openrouter",
|
||||
},
|
||||
{
|
||||
name: "azure fallback",
|
||||
channelType: constant.ChannelTypeAzure,
|
||||
expected: "azure",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.expected, channelOwnerName(tt.channelType))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpenAIModelOverridesOwnedBy(t *testing.T) {
|
||||
modelItem := buildOpenAIModel("gpt-5.4", map[string]string{"gpt-5.4": "openai"})
|
||||
require.Equal(t, "gpt-5.4", modelItem.Id)
|
||||
require.Equal(t, "openai", modelItem.OwnedBy)
|
||||
}
|
||||
|
||||
func TestBuildOpenAIModelFallsBackToCustomForUnknownModels(t *testing.T) {
|
||||
modelItem := buildOpenAIModel("custom-test-model", nil)
|
||||
require.Equal(t, "custom-test-model", modelItem.Id)
|
||||
require.Equal(t, "custom", modelItem.OwnedBy)
|
||||
}
|
||||
|
||||
func TestGetModelListGroupsUsesUserGroupWhenTokenGroupIsEmpty(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||
common.SetContextKey(ctx, constant.ContextKeyUserGroup, "default")
|
||||
|
||||
groups, err := getModelListGroups(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "default", groups.userGroup)
|
||||
require.Empty(t, groups.tokenGroup)
|
||||
require.Equal(t, []string{"default"}, groups.ownerGroups)
|
||||
}
|
||||
|
||||
func TestGetModelListGroupsUsesExplicitTokenGroup(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||
common.SetContextKey(ctx, constant.ContextKeyUserGroup, "default")
|
||||
common.SetContextKey(ctx, constant.ContextKeyTokenGroup, "vip")
|
||||
|
||||
groups, err := getModelListGroups(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "default", groups.userGroup)
|
||||
require.Equal(t, "vip", groups.tokenGroup)
|
||||
require.Equal(t, []string{"vip"}, groups.ownerGroups)
|
||||
}
|
||||
+2
-23
@@ -7,7 +7,6 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/i18n"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/setting"
|
||||
"github.com/QuantumNous/new-api/setting/console_setting"
|
||||
@@ -29,10 +28,6 @@ var completionRatioMetaOptionKeys = []string{
|
||||
"AudioCompletionRatio",
|
||||
}
|
||||
|
||||
func isPaymentComplianceOptionKey(key string) bool {
|
||||
return strings.HasPrefix(key, "payment_setting.compliance_")
|
||||
}
|
||||
|
||||
func isPositiveOptionValue(value string) bool {
|
||||
intValue, err := strconv.Atoi(strings.TrimSpace(value))
|
||||
if err == nil {
|
||||
@@ -42,15 +37,6 @@ func isPositiveOptionValue(value string) bool {
|
||||
return err == nil && floatValue > 0
|
||||
}
|
||||
|
||||
func isVisiblePublicKeyOption(key string) bool {
|
||||
switch key {
|
||||
case "WaffoPancakeWebhookPublicKey", "WaffoPancakeWebhookTestKey":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func collectModelNamesFromOptionValue(raw string, modelNames map[string]struct{}) {
|
||||
if strings.TrimSpace(raw) == "" {
|
||||
return
|
||||
@@ -95,7 +81,7 @@ func GetOptions(c *gin.Context) {
|
||||
strings.HasSuffix(k, "Key") ||
|
||||
strings.HasSuffix(k, "secret") ||
|
||||
strings.HasSuffix(k, "api_key")
|
||||
if isSensitiveKey && !isVisiblePublicKeyOption(k) {
|
||||
if isSensitiveKey {
|
||||
continue
|
||||
}
|
||||
options = append(options, &model.Option{
|
||||
@@ -148,15 +134,8 @@ func UpdateOption(c *gin.Context) {
|
||||
}
|
||||
switch option.Key {
|
||||
case "QuotaForInviter", "QuotaForInvitee":
|
||||
if isPositiveOptionValue(option.Value.(string)) && !operation_setting.IsPaymentComplianceConfirmed() {
|
||||
common.ApiErrorI18n(c, i18n.MsgPaymentComplianceRequired)
|
||||
return
|
||||
}
|
||||
// no compliance check needed
|
||||
default:
|
||||
if isPaymentComplianceOptionKey(option.Key) {
|
||||
common.ApiErrorMsg(c, "合规确认字段不允许通过通用设置接口修改")
|
||||
return
|
||||
}
|
||||
}
|
||||
switch option.Key {
|
||||
case "GitHubOAuthEnabled":
|
||||
|
||||
@@ -1,82 +0,0 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/i18n"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type PaymentComplianceRequest struct {
|
||||
Confirmed bool `json:"confirmed"`
|
||||
}
|
||||
|
||||
func requirePaymentCompliance(c *gin.Context) bool {
|
||||
if !operation_setting.IsPaymentComplianceConfirmed() {
|
||||
common.ApiErrorI18n(c, i18n.MsgPaymentComplianceRequired)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func ConfirmPaymentCompliance(c *gin.Context) {
|
||||
if c.GetBool("use_access_token") {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "This operation requires dashboard session authentication. API access token is not allowed.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var req PaymentComplianceRequest
|
||||
if err := common.DecodeJson(c.Request.Body, &req); err != nil {
|
||||
common.ApiErrorMsg(c, "参数错误")
|
||||
return
|
||||
}
|
||||
if !req.Confirmed {
|
||||
common.ApiErrorMsg(c, "请确认合规声明")
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
userId := c.GetInt("id")
|
||||
clientIP := c.ClientIP()
|
||||
|
||||
updates := map[string]string{
|
||||
"payment_setting.compliance_confirmed": "true",
|
||||
"payment_setting.compliance_terms_version": operation_setting.CurrentComplianceTermsVersion,
|
||||
"payment_setting.compliance_confirmed_at": strconv.FormatInt(now, 10),
|
||||
"payment_setting.compliance_confirmed_by": strconv.Itoa(userId),
|
||||
"payment_setting.compliance_confirmed_ip": clientIP,
|
||||
}
|
||||
|
||||
for key, value := range updates {
|
||||
if err := model.UpdateOption(key, value); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf(
|
||||
"payment compliance confirmed user_id=%d ip=%s terms_version=%s confirmed_at=%d",
|
||||
userId,
|
||||
clientIP,
|
||||
operation_setting.CurrentComplianceTermsVersion,
|
||||
now,
|
||||
))
|
||||
|
||||
common.ApiSuccess(c, gin.H{
|
||||
"confirmed": true,
|
||||
"terms_version": operation_setting.CurrentComplianceTermsVersion,
|
||||
"confirmed_at": now,
|
||||
"confirmed_by": userId,
|
||||
})
|
||||
}
|
||||
@@ -7,14 +7,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
)
|
||||
|
||||
func isPaymentComplianceConfirmed() bool {
|
||||
return operation_setting.IsPaymentComplianceConfirmed()
|
||||
}
|
||||
|
||||
func isStripeTopUpEnabled() bool {
|
||||
if !isPaymentComplianceConfirmed() {
|
||||
return false
|
||||
}
|
||||
return strings.TrimSpace(setting.StripeApiSecret) != "" &&
|
||||
strings.TrimSpace(setting.StripeWebhookSecret) != "" &&
|
||||
strings.TrimSpace(setting.StripePriceId) != ""
|
||||
@@ -29,9 +22,6 @@ func isStripeWebhookEnabled() bool {
|
||||
}
|
||||
|
||||
func isCreemTopUpEnabled() bool {
|
||||
if !isPaymentComplianceConfirmed() {
|
||||
return false
|
||||
}
|
||||
products := strings.TrimSpace(setting.CreemProducts)
|
||||
return strings.TrimSpace(setting.CreemApiKey) != "" &&
|
||||
products != "" &&
|
||||
@@ -47,9 +37,6 @@ func isCreemWebhookEnabled() bool {
|
||||
}
|
||||
|
||||
func isWaffoTopUpEnabled() bool {
|
||||
if !isPaymentComplianceConfirmed() {
|
||||
return false
|
||||
}
|
||||
if !setting.WaffoEnabled {
|
||||
return false
|
||||
}
|
||||
@@ -74,27 +61,13 @@ func isWaffoWebhookEnabled() bool {
|
||||
}
|
||||
|
||||
func isWaffoPancakeTopUpEnabled() bool {
|
||||
if !isPaymentComplianceConfirmed() {
|
||||
return false
|
||||
}
|
||||
if !setting.WaffoPancakeEnabled {
|
||||
return false
|
||||
}
|
||||
|
||||
return isWaffoPancakeWebhookConfigured() &&
|
||||
strings.TrimSpace(setting.WaffoPancakeMerchantID) != "" &&
|
||||
return strings.TrimSpace(setting.WaffoPancakeMerchantID) != "" &&
|
||||
strings.TrimSpace(setting.WaffoPancakePrivateKey) != "" &&
|
||||
strings.TrimSpace(setting.WaffoPancakeStoreID) != "" &&
|
||||
strings.TrimSpace(setting.WaffoPancakeProductID) != ""
|
||||
}
|
||||
|
||||
func isWaffoPancakeWebhookConfigured() bool {
|
||||
currentWebhookKey := strings.TrimSpace(setting.WaffoPancakeWebhookPublicKey)
|
||||
if setting.WaffoPancakeSandbox {
|
||||
currentWebhookKey = strings.TrimSpace(setting.WaffoPancakeWebhookTestKey)
|
||||
}
|
||||
|
||||
return currentWebhookKey != ""
|
||||
return isWaffoPancakeTopUpEnabled()
|
||||
}
|
||||
|
||||
func isWaffoPancakeWebhookEnabled() bool {
|
||||
@@ -102,9 +75,6 @@ func isWaffoPancakeWebhookEnabled() bool {
|
||||
}
|
||||
|
||||
func isEpayTopUpEnabled() bool {
|
||||
if !isPaymentComplianceConfirmed() {
|
||||
return false
|
||||
}
|
||||
return isEpayWebhookConfigured() && len(operation_setting.PayMethods) > 0
|
||||
}
|
||||
|
||||
|
||||
@@ -8,21 +8,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func confirmPaymentComplianceForTest(t *testing.T) {
|
||||
t.Helper()
|
||||
paymentSetting := operation_setting.GetPaymentSetting()
|
||||
originalConfirmed := paymentSetting.ComplianceConfirmed
|
||||
originalTermsVersion := paymentSetting.ComplianceTermsVersion
|
||||
t.Cleanup(func() {
|
||||
paymentSetting.ComplianceConfirmed = originalConfirmed
|
||||
paymentSetting.ComplianceTermsVersion = originalTermsVersion
|
||||
})
|
||||
paymentSetting.ComplianceConfirmed = true
|
||||
paymentSetting.ComplianceTermsVersion = operation_setting.CurrentComplianceTermsVersion
|
||||
}
|
||||
|
||||
func TestStripeWebhookEnabledRequiresTopUpAndWebhookConfig(t *testing.T) {
|
||||
confirmPaymentComplianceForTest(t)
|
||||
originalAPISecret := setting.StripeApiSecret
|
||||
originalWebhookSecret := setting.StripeWebhookSecret
|
||||
originalPriceID := setting.StripePriceId
|
||||
@@ -45,7 +31,6 @@ func TestStripeWebhookEnabledRequiresTopUpAndWebhookConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCreemWebhookEnabledRequiresTopUpAndWebhookConfig(t *testing.T) {
|
||||
confirmPaymentComplianceForTest(t)
|
||||
originalAPIKey := setting.CreemApiKey
|
||||
originalProducts := setting.CreemProducts
|
||||
originalWebhookSecret := setting.CreemWebhookSecret
|
||||
@@ -68,7 +53,6 @@ func TestCreemWebhookEnabledRequiresTopUpAndWebhookConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestWaffoWebhookEnabledRequiresTopUpAndWebhookConfig(t *testing.T) {
|
||||
confirmPaymentComplianceForTest(t)
|
||||
originalEnabled := setting.WaffoEnabled
|
||||
originalSandbox := setting.WaffoSandbox
|
||||
originalAPIKey := setting.WaffoApiKey
|
||||
@@ -113,52 +97,32 @@ func TestWaffoWebhookEnabledRequiresTopUpAndWebhookConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestWaffoPancakeWebhookEnabledRequiresTopUpAndWebhookConfig(t *testing.T) {
|
||||
confirmPaymentComplianceForTest(t)
|
||||
originalEnabled := setting.WaffoPancakeEnabled
|
||||
originalSandbox := setting.WaffoPancakeSandbox
|
||||
originalMerchantID := setting.WaffoPancakeMerchantID
|
||||
originalPrivateKey := setting.WaffoPancakePrivateKey
|
||||
originalWebhookPublicKey := setting.WaffoPancakeWebhookPublicKey
|
||||
originalWebhookTestKey := setting.WaffoPancakeWebhookTestKey
|
||||
originalStoreID := setting.WaffoPancakeStoreID
|
||||
originalProductID := setting.WaffoPancakeProductID
|
||||
t.Cleanup(func() {
|
||||
setting.WaffoPancakeEnabled = originalEnabled
|
||||
setting.WaffoPancakeSandbox = originalSandbox
|
||||
setting.WaffoPancakeMerchantID = originalMerchantID
|
||||
setting.WaffoPancakePrivateKey = originalPrivateKey
|
||||
setting.WaffoPancakeWebhookPublicKey = originalWebhookPublicKey
|
||||
setting.WaffoPancakeWebhookTestKey = originalWebhookTestKey
|
||||
setting.WaffoPancakeStoreID = originalStoreID
|
||||
setting.WaffoPancakeProductID = originalProductID
|
||||
})
|
||||
|
||||
setting.WaffoPancakeEnabled = true
|
||||
setting.WaffoPancakeSandbox = false
|
||||
setting.WaffoPancakeMerchantID = "merchant"
|
||||
setting.WaffoPancakeMerchantID = ""
|
||||
setting.WaffoPancakePrivateKey = "private"
|
||||
setting.WaffoPancakeStoreID = "store"
|
||||
setting.WaffoPancakeProductID = "product"
|
||||
setting.WaffoPancakeWebhookPublicKey = ""
|
||||
require.False(t, isWaffoPancakeWebhookEnabled())
|
||||
|
||||
setting.WaffoPancakeWebhookPublicKey = "public"
|
||||
setting.WaffoPancakeMerchantID = "merchant"
|
||||
require.True(t, isWaffoPancakeWebhookEnabled())
|
||||
|
||||
setting.WaffoPancakeEnabled = false
|
||||
setting.WaffoPancakeProductID = ""
|
||||
require.False(t, isWaffoPancakeWebhookEnabled())
|
||||
|
||||
setting.WaffoPancakeEnabled = true
|
||||
setting.WaffoPancakeSandbox = true
|
||||
setting.WaffoPancakeWebhookTestKey = ""
|
||||
setting.WaffoPancakeProductID = "product"
|
||||
setting.WaffoPancakePrivateKey = ""
|
||||
require.False(t, isWaffoPancakeWebhookEnabled())
|
||||
|
||||
setting.WaffoPancakeWebhookTestKey = "test_public"
|
||||
require.True(t, isWaffoPancakeWebhookEnabled())
|
||||
}
|
||||
|
||||
func TestEpayWebhookEnabledRequiresTopUpAndWebhookConfig(t *testing.T) {
|
||||
confirmPaymentComplianceForTest(t)
|
||||
originalPayAddress := operation_setting.PayAddress
|
||||
originalEpayID := operation_setting.EpayId
|
||||
originalEpayKey := operation_setting.EpayKey
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/i18n"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -60,11 +59,6 @@ func GetRedemption(c *gin.Context) {
|
||||
}
|
||||
|
||||
func AddRedemption(c *gin.Context) {
|
||||
if !operation_setting.IsPaymentComplianceConfirmed() {
|
||||
common.ApiErrorI18n(c, i18n.MsgPaymentComplianceRequired)
|
||||
return
|
||||
}
|
||||
|
||||
redemption := model.Redemption{}
|
||||
err := c.ShouldBindJSON(&redemption)
|
||||
if err != nil {
|
||||
|
||||
+2
-2
@@ -88,7 +88,7 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
||||
|
||||
defer func() {
|
||||
if newAPIError != nil {
|
||||
logger.LogError(c, fmt.Sprintf("relay error: %s", newAPIError.Error()))
|
||||
logger.LogError(c, fmt.Sprintf("relay error: %s", common.LocalLogPreview(newAPIError.Error())))
|
||||
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
|
||||
switch relayFormat {
|
||||
case types.RelayFormatOpenAIRealtime:
|
||||
@@ -354,7 +354,7 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b
|
||||
}
|
||||
|
||||
func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
|
||||
logger.LogError(c, fmt.Sprintf("channel error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
|
||||
logger.LogError(c, fmt.Sprintf("channel error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, common.LocalLogPreview(err.Error())))
|
||||
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
|
||||
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
|
||||
if service.ShouldDisableChannel(err) && channelError.AutoBan {
|
||||
|
||||
+28
-26
@@ -6,7 +6,6 @@ import (
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
"github.com/QuantumNous/new-api/setting/ratio_setting"
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
@@ -22,14 +21,13 @@ type BillingPreferenceRequest struct {
|
||||
BillingPreference string `json:"billing_preference"`
|
||||
}
|
||||
|
||||
type SubscriptionBalancePayRequest struct {
|
||||
PlanId int `json:"plan_id"`
|
||||
}
|
||||
|
||||
// ---- User APIs ----
|
||||
|
||||
func GetSubscriptionPlans(c *gin.Context) {
|
||||
if !operation_setting.IsPaymentComplianceConfirmed() {
|
||||
common.ApiSuccess(c, []SubscriptionPlanDTO{})
|
||||
return
|
||||
}
|
||||
|
||||
var plans []model.SubscriptionPlan
|
||||
if err := model.DB.Where("enabled = ?", true).Order("sort_order desc, id desc").Find(&plans).Error; err != nil {
|
||||
common.ApiError(c, err)
|
||||
@@ -37,6 +35,7 @@ func GetSubscriptionPlans(c *gin.Context) {
|
||||
}
|
||||
result := make([]SubscriptionPlanDTO, 0, len(plans))
|
||||
for _, p := range plans {
|
||||
p.NormalizeDefaults()
|
||||
result = append(result, SubscriptionPlanDTO{
|
||||
Plan: p,
|
||||
})
|
||||
@@ -92,6 +91,21 @@ func UpdateSubscriptionPreference(c *gin.Context) {
|
||||
common.ApiSuccess(c, gin.H{"billing_preference": pref})
|
||||
}
|
||||
|
||||
func SubscriptionRequestBalancePay(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
var req SubscriptionBalancePayRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil || req.PlanId <= 0 {
|
||||
common.ApiErrorMsg(c, "参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
if err := model.PurchaseSubscriptionWithBalance(userId, req.PlanId); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
common.ApiSuccess(c, nil)
|
||||
}
|
||||
|
||||
// ---- Admin APIs ----
|
||||
|
||||
func AdminListSubscriptionPlans(c *gin.Context) {
|
||||
@@ -102,6 +116,7 @@ func AdminListSubscriptionPlans(c *gin.Context) {
|
||||
}
|
||||
result := make([]SubscriptionPlanDTO, 0, len(plans))
|
||||
for _, p := range plans {
|
||||
p.NormalizeDefaults()
|
||||
result = append(result, SubscriptionPlanDTO{
|
||||
Plan: p,
|
||||
})
|
||||
@@ -114,10 +129,6 @@ type AdminUpsertSubscriptionPlanRequest struct {
|
||||
}
|
||||
|
||||
func AdminCreateSubscriptionPlan(c *gin.Context) {
|
||||
if !requirePaymentCompliance(c) {
|
||||
return
|
||||
}
|
||||
|
||||
var req AdminUpsertSubscriptionPlanRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
common.ApiErrorMsg(c, "参数错误")
|
||||
@@ -140,6 +151,9 @@ func AdminCreateSubscriptionPlan(c *gin.Context) {
|
||||
req.Plan.Currency = "USD"
|
||||
}
|
||||
req.Plan.Currency = "USD"
|
||||
if req.Plan.AllowBalancePay == nil {
|
||||
req.Plan.AllowBalancePay = common.GetPointer(true)
|
||||
}
|
||||
if req.Plan.DurationUnit == "" {
|
||||
req.Plan.DurationUnit = model.SubscriptionDurationMonth
|
||||
}
|
||||
@@ -176,10 +190,6 @@ func AdminCreateSubscriptionPlan(c *gin.Context) {
|
||||
}
|
||||
|
||||
func AdminUpdateSubscriptionPlan(c *gin.Context) {
|
||||
if !requirePaymentCompliance(c) {
|
||||
return
|
||||
}
|
||||
|
||||
id, _ := strconv.Atoi(c.Param("id"))
|
||||
if id <= 0 {
|
||||
common.ApiErrorMsg(c, "无效的ID")
|
||||
@@ -248,6 +258,7 @@ func AdminUpdateSubscriptionPlan(c *gin.Context) {
|
||||
"sort_order": req.Plan.SortOrder,
|
||||
"stripe_price_id": req.Plan.StripePriceId,
|
||||
"creem_product_id": req.Plan.CreemProductId,
|
||||
"waffo_pancake_product_id": req.Plan.WaffoPancakeProductId,
|
||||
"max_purchase_per_user": req.Plan.MaxPurchasePerUser,
|
||||
"total_amount": req.Plan.TotalAmount,
|
||||
"upgrade_group": req.Plan.UpgradeGroup,
|
||||
@@ -255,6 +266,9 @@ func AdminUpdateSubscriptionPlan(c *gin.Context) {
|
||||
"quota_reset_custom_seconds": req.Plan.QuotaResetCustomSeconds,
|
||||
"updated_at": common.GetTimestamp(),
|
||||
}
|
||||
if req.Plan.AllowBalancePay != nil {
|
||||
updateMap["allow_balance_pay"] = *req.Plan.AllowBalancePay
|
||||
}
|
||||
if err := tx.Model(&model.SubscriptionPlan{}).Where("id = ?", id).Updates(updateMap).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -273,10 +287,6 @@ type AdminUpdateSubscriptionPlanStatusRequest struct {
|
||||
}
|
||||
|
||||
func AdminUpdateSubscriptionPlanStatus(c *gin.Context) {
|
||||
if !requirePaymentCompliance(c) {
|
||||
return
|
||||
}
|
||||
|
||||
id, _ := strconv.Atoi(c.Param("id"))
|
||||
if id <= 0 {
|
||||
common.ApiErrorMsg(c, "无效的ID")
|
||||
@@ -301,10 +311,6 @@ type AdminBindSubscriptionRequest struct {
|
||||
}
|
||||
|
||||
func AdminBindSubscription(c *gin.Context) {
|
||||
if !requirePaymentCompliance(c) {
|
||||
return
|
||||
}
|
||||
|
||||
var req AdminBindSubscriptionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil || req.UserId <= 0 || req.PlanId <= 0 {
|
||||
common.ApiErrorMsg(c, "参数错误")
|
||||
@@ -344,10 +350,6 @@ type AdminCreateUserSubscriptionRequest struct {
|
||||
|
||||
// AdminCreateUserSubscription creates a new user subscription from a plan (no payment).
|
||||
func AdminCreateUserSubscription(c *gin.Context) {
|
||||
if !requirePaymentCompliance(c) {
|
||||
return
|
||||
}
|
||||
|
||||
userId, _ := strconv.Atoi(c.Param("id"))
|
||||
if userId <= 0 {
|
||||
common.ApiErrorMsg(c, "无效的用户ID")
|
||||
|
||||
@@ -21,10 +21,6 @@ type SubscriptionCreemPayRequest struct {
|
||||
}
|
||||
|
||||
func SubscriptionRequestCreemPay(c *gin.Context) {
|
||||
if !requirePaymentCompliance(c) {
|
||||
return
|
||||
}
|
||||
|
||||
var req SubscriptionCreemPayRequest
|
||||
|
||||
// Keep body for debugging consistency (like RequestCreemPay)
|
||||
|
||||
@@ -22,10 +22,6 @@ type SubscriptionEpayPayRequest struct {
|
||||
}
|
||||
|
||||
func SubscriptionRequestEpay(c *gin.Context) {
|
||||
if !requirePaymentCompliance(c) {
|
||||
return
|
||||
}
|
||||
|
||||
var req SubscriptionEpayPayRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil || req.PlanId <= 0 {
|
||||
common.ApiErrorMsg(c, "参数错误")
|
||||
|
||||
@@ -21,10 +21,6 @@ type SubscriptionStripePayRequest struct {
|
||||
}
|
||||
|
||||
func SubscriptionRequestStripePay(c *gin.Context) {
|
||||
if !requirePaymentCompliance(c) {
|
||||
return
|
||||
}
|
||||
|
||||
var req SubscriptionStripePayRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil || req.PlanId <= 0 {
|
||||
common.ApiErrorMsg(c, "参数错误")
|
||||
|
||||
@@ -0,0 +1,126 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/setting"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/shopspring/decimal"
|
||||
"github.com/thanhpk/randstr"
|
||||
)
|
||||
|
||||
type SubscriptionWaffoPancakePayRequest struct {
|
||||
PlanId int `json:"plan_id"`
|
||||
}
|
||||
|
||||
func SubscriptionRequestWaffoPancakePay(c *gin.Context) {
|
||||
var req SubscriptionWaffoPancakePayRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil || req.PlanId <= 0 {
|
||||
common.ApiErrorMsg(c, "参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
plan, err := model.GetSubscriptionPlanById(req.PlanId)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if !plan.Enabled {
|
||||
common.ApiErrorMsg(c, "套餐未启用")
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(plan.WaffoPancakeProductId) == "" {
|
||||
common.ApiErrorMsg(c, "该套餐未配置 WaffoPancakeProductId")
|
||||
return
|
||||
}
|
||||
// Plan targets its own Pancake product, so we only require credentials
|
||||
// here — not the gateway-level WaffoPancakeProductID.
|
||||
if strings.TrimSpace(setting.WaffoPancakeMerchantID) == "" ||
|
||||
strings.TrimSpace(setting.WaffoPancakePrivateKey) == "" {
|
||||
common.ApiErrorMsg(c, "Waffo Pancake 未配置或密钥无效")
|
||||
return
|
||||
}
|
||||
|
||||
userId := c.GetInt("id")
|
||||
user, err := model.GetUserById(userId, false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if user == nil {
|
||||
common.ApiErrorMsg(c, "用户不存在")
|
||||
return
|
||||
}
|
||||
|
||||
if plan.MaxPurchasePerUser > 0 {
|
||||
count, err := model.CountUserSubscriptionsByPlan(userId, plan.Id)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if count >= int64(plan.MaxPurchasePerUser) {
|
||||
common.ApiErrorMsg(c, "已达到该套餐购买上限")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// WAFFO_PANCAKE_SUB- prefix (vs. wallet's WAFFO_PANCAKE-) drives webhook
|
||||
// dispatch in WaffoPancakeWebhook.
|
||||
tradeNo := fmt.Sprintf("WAFFO_PANCAKE_SUB-%d-%d-%s", userId, time.Now().UnixMilli(), randstr.String(6))
|
||||
|
||||
order := &model.SubscriptionOrder{
|
||||
UserId: userId,
|
||||
PlanId: plan.Id,
|
||||
Money: plan.PriceAmount,
|
||||
TradeNo: tradeNo,
|
||||
PaymentMethod: model.PaymentMethodWaffoPancake,
|
||||
PaymentProvider: model.PaymentProviderWaffoPancake,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
if err := order.Insert(); err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake 订阅订单创建失败 user_id=%d plan_id=%d trade_no=%s error=%q", userId, plan.Id, tradeNo, err.Error()))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"})
|
||||
return
|
||||
}
|
||||
|
||||
expiresInSeconds := 45 * 60
|
||||
session, err := service.CreateWaffoPancakeCheckoutSession(c.Request.Context(), &service.WaffoPancakeCreateSessionParams{
|
||||
ProductID: plan.WaffoPancakeProductId,
|
||||
BuyerIdentity: service.WaffoPancakeBuyerIdentityFromUserID(user.Id),
|
||||
PriceSnapshot: &service.WaffoPancakePriceSnapshot{
|
||||
Amount: decimal.NewFromFloat(plan.PriceAmount).StringFixed(2),
|
||||
TaxCategory: "saas",
|
||||
},
|
||||
BuyerEmail: getWaffoPancakeBuyerEmail(user),
|
||||
ExpiresInSeconds: &expiresInSeconds,
|
||||
OrderMerchantExternalID: tradeNo,
|
||||
})
|
||||
if err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake 订阅结账会话创建失败 user_id=%d plan_id=%d trade_no=%s error=%q", userId, plan.Id, tradeNo, err.Error()))
|
||||
order.Status = common.TopUpStatusFailed
|
||||
_ = order.Update()
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉起支付失败"})
|
||||
return
|
||||
}
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo Pancake 订阅订单创建成功 user_id=%d plan_id=%d trade_no=%s session_id=%s money=%.2f", userId, plan.Id, tradeNo, session.SessionID, plan.PriceAmount))
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "success",
|
||||
"data": gin.H{
|
||||
"checkout_url": session.CheckoutURL,
|
||||
"session_id": session.SessionID,
|
||||
"expires_at": session.ExpiresAt,
|
||||
"order_id": tradeNo,
|
||||
"token": session.Token,
|
||||
"token_expires_at": session.TokenExpiresAt,
|
||||
},
|
||||
})
|
||||
}
|
||||
+23
-28
@@ -22,13 +22,8 @@ import (
|
||||
)
|
||||
|
||||
func GetTopUpInfo(c *gin.Context) {
|
||||
complianceConfirmed := operation_setting.IsPaymentComplianceConfirmed()
|
||||
|
||||
// 获取支付方式
|
||||
payMethods := operation_setting.PayMethods
|
||||
if !complianceConfirmed {
|
||||
payMethods = []map[string]string{}
|
||||
}
|
||||
|
||||
// 如果启用了 Stripe 支付,添加到支付方法列表
|
||||
if isStripeTopUpEnabled() {
|
||||
@@ -52,6 +47,27 @@ func GetTopUpInfo(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// Waffo Pancake displayed above the legacy Waffo gateway.
|
||||
enableWaffoPancake := isWaffoPancakeTopUpEnabled()
|
||||
if enableWaffoPancake {
|
||||
hasWaffoPancake := false
|
||||
for _, method := range payMethods {
|
||||
if method["type"] == model.PaymentMethodWaffoPancake {
|
||||
hasWaffoPancake = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasWaffoPancake {
|
||||
payMethods = append(payMethods, map[string]string{
|
||||
"name": "Waffo Pancake",
|
||||
"type": model.PaymentMethodWaffoPancake,
|
||||
"color": "rgba(var(--semi-orange-5), 1)",
|
||||
"min_topup": strconv.Itoa(setting.WaffoPancakeMinTopUp),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 如果启用了 Waffo 支付,添加到支付方法列表
|
||||
enableWaffo := isWaffoTopUpEnabled()
|
||||
if enableWaffo {
|
||||
@@ -74,35 +90,14 @@ func GetTopUpInfo(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
enableWaffoPancake := isWaffoPancakeTopUpEnabled()
|
||||
if enableWaffoPancake {
|
||||
hasWaffoPancake := false
|
||||
for _, method := range payMethods {
|
||||
if method["type"] == model.PaymentMethodWaffoPancake {
|
||||
hasWaffoPancake = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasWaffoPancake {
|
||||
payMethods = append(payMethods, map[string]string{
|
||||
"name": "Waffo Pancake",
|
||||
"type": model.PaymentMethodWaffoPancake,
|
||||
"color": "rgba(var(--semi-orange-5), 1)",
|
||||
"min_topup": strconv.Itoa(setting.WaffoPancakeMinTopUp),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
data := gin.H{
|
||||
"enable_online_topup": isEpayTopUpEnabled(),
|
||||
"enable_stripe_topup": isStripeTopUpEnabled(),
|
||||
"enable_creem_topup": isCreemTopUpEnabled(),
|
||||
"enable_waffo_topup": enableWaffo,
|
||||
"enable_waffo_pancake_topup": enableWaffoPancake,
|
||||
"enable_redemption": complianceConfirmed,
|
||||
"payment_compliance_confirmed": complianceConfirmed,
|
||||
"payment_compliance_terms_version": operation_setting.CurrentComplianceTermsVersion,
|
||||
"enable_redemption": true,
|
||||
"payment_compliance_confirmed": true,
|
||||
"waffo_pay_methods": func() interface{} {
|
||||
if enableWaffo {
|
||||
return setting.GetWaffoPayMethods()
|
||||
|
||||
@@ -96,33 +96,257 @@ func getWaffoPancakeBuyerEmail(user *model.User) string {
|
||||
if user != nil && strings.TrimSpace(user.Email) != "" {
|
||||
return user.Email
|
||||
}
|
||||
if user != nil {
|
||||
return fmt.Sprintf("%d@new-api.local", user.Id)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func getWaffoPancakeReturnURL() string {
|
||||
if strings.TrimSpace(setting.WaffoPancakeReturnURL) != "" {
|
||||
return setting.WaffoPancakeReturnURL
|
||||
// The admin config endpoints below accept typed-but-not-yet-saved creds in
|
||||
// the body and fall back to persisted creds when the body is blank (see
|
||||
// resolveWaffoPancakeAdminCreds). Only SaveWaffoPancake writes to OptionMap.
|
||||
|
||||
type waffoPancakeCredsRequest struct {
|
||||
MerchantID string `json:"merchant_id"`
|
||||
PrivateKey string `json:"private_key"`
|
||||
}
|
||||
|
||||
type saveWaffoPancakeRequest struct {
|
||||
MerchantID string `json:"merchant_id"`
|
||||
PrivateKey string `json:"private_key"`
|
||||
ReturnURL string `json:"return_url"`
|
||||
StoreID string `json:"store_id"`
|
||||
ProductID string `json:"product_id"`
|
||||
}
|
||||
|
||||
type createWaffoPancakePairRequest struct {
|
||||
MerchantID string `json:"merchant_id"`
|
||||
PrivateKey string `json:"private_key"`
|
||||
ReturnURL string `json:"return_url"`
|
||||
}
|
||||
|
||||
// SaveWaffoPancake atomically persists all five operator-controlled fields.
|
||||
// Catalog / pair endpoints are transient — only this one writes the OptionMap.
|
||||
func SaveWaffoPancake(c *gin.Context) {
|
||||
var req saveWaffoPancakeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
|
||||
return
|
||||
}
|
||||
return paymentReturnPath("/console/topup?show_history=true")
|
||||
if err := service.SaveWaffoPancakeConfig(
|
||||
c.Request.Context(),
|
||||
req.MerchantID,
|
||||
req.PrivateKey,
|
||||
req.ReturnURL,
|
||||
req.StoreID,
|
||||
req.ProductID,
|
||||
); err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf(
|
||||
"Waffo Pancake 保存配置失败 store_id=%q product_id=%q error=%q",
|
||||
req.StoreID, req.ProductID, err.Error(),
|
||||
))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "保存配置失败"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "success",
|
||||
"data": gin.H{
|
||||
"product_id": setting.WaffoPancakeProductID,
|
||||
"store_id": setting.WaffoPancakeStoreID,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// resolveWaffoPancakeAdminCreds prefers body creds (typed-but-not-yet-saved
|
||||
// values, for verification) and falls back to persisted creds when the body
|
||||
// is blank (so returning admins don't have to re-paste the private key,
|
||||
// which is stripped from GET /api/option/).
|
||||
func resolveWaffoPancakeAdminCreds(bodyMerchantID, bodyPrivateKey string) (string, string) {
|
||||
m := strings.TrimSpace(bodyMerchantID)
|
||||
k := strings.TrimSpace(bodyPrivateKey)
|
||||
if m == "" && k == "" {
|
||||
return setting.WaffoPancakeMerchantID, setting.WaffoPancakePrivateKey
|
||||
}
|
||||
return m, k
|
||||
}
|
||||
|
||||
// CreateWaffoPancakePair mints a Store + OnetimeProduct pair in one round-
|
||||
// trip. Surfaces an orphan-store flag when the product half fails so the
|
||||
// frontend can preselect / retry without losing context.
|
||||
func CreateWaffoPancakePair(c *gin.Context) {
|
||||
var req createWaffoPancakePairRequest
|
||||
if c.Request.ContentLength > 0 {
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
|
||||
return
|
||||
}
|
||||
}
|
||||
merchantID, privateKey := resolveWaffoPancakeAdminCreds(req.MerchantID, req.PrivateKey)
|
||||
if merchantID == "" || privateKey == "" {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "Waffo Pancake 凭证未配置"})
|
||||
return
|
||||
}
|
||||
result, err := service.CreateWaffoPancakePrimaryPair(
|
||||
c.Request.Context(), merchantID, privateKey, req.ReturnURL,
|
||||
)
|
||||
if err != nil {
|
||||
orphan := result != nil && result.OrphanStore
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf(
|
||||
"Waffo Pancake 创建店铺与产品失败 orphan_store=%t store_id=%q error=%q",
|
||||
orphan, func() string {
|
||||
if result == nil {
|
||||
return ""
|
||||
}
|
||||
return result.StoreID
|
||||
}(), err.Error(),
|
||||
))
|
||||
data := gin.H{"error": err.Error()}
|
||||
if orphan {
|
||||
data["store_id"] = result.StoreID
|
||||
data["store_name"] = result.StoreName
|
||||
data["orphan_store"] = true
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": data})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "success",
|
||||
"data": gin.H{
|
||||
"store_id": result.StoreID,
|
||||
"store_name": result.StoreName,
|
||||
"product_id": result.ProductID,
|
||||
"product_name": result.ProductName,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// ListWaffoPancakeCatalog returns the merchant's Stores + OnetimeProducts.
|
||||
// Doubles as a credential probe (a successful 200 proves the resolved creds
|
||||
// authenticate). See resolveWaffoPancakeAdminCreds for credential resolution.
|
||||
func ListWaffoPancakeCatalog(c *gin.Context) {
|
||||
var req waffoPancakeCredsRequest
|
||||
// An empty body means "use persisted creds"; only fail on malformed JSON.
|
||||
if c.Request.ContentLength > 0 {
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
|
||||
return
|
||||
}
|
||||
}
|
||||
merchantID, privateKey := resolveWaffoPancakeAdminCreds(req.MerchantID, req.PrivateKey)
|
||||
if merchantID == "" || privateKey == "" {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "Waffo Pancake 凭证未配置"})
|
||||
return
|
||||
}
|
||||
catalog, err := service.ListWaffoPancakeCatalog(c.Request.Context(), merchantID, privateKey)
|
||||
if err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf(
|
||||
"Waffo Pancake 拉取店铺与产品目录失败 error=%q", err.Error(),
|
||||
))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉取目录失败"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "success", "data": catalog})
|
||||
}
|
||||
|
||||
type createWaffoPancakeSubscriptionProductRequest struct {
|
||||
Name string `json:"name"`
|
||||
Amount string `json:"amount"`
|
||||
}
|
||||
|
||||
// CreateWaffoPancakeSubscriptionProduct mints an OnetimeProduct (not
|
||||
// SubscriptionProduct — see service.CreateWaffoPancakeProductForPlan)
|
||||
// sized to a plan's `name` + `amount`, using persisted Pancake credentials
|
||||
// + StoreID. Reads from the form, not the plan row, so newly-typed unsaved
|
||||
// plans can mint a product too.
|
||||
func CreateWaffoPancakeSubscriptionProduct(c *gin.Context) {
|
||||
var req createWaffoPancakeSubscriptionProductRequest
|
||||
if c.Request.ContentLength > 0 {
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
|
||||
return
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(req.Name) == "" {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "套餐名称不能为空"})
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(req.Amount) == "" {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "套餐价格不能为空"})
|
||||
return
|
||||
}
|
||||
merchantID, privateKey := resolveWaffoPancakeAdminCreds("", "")
|
||||
storeID := strings.TrimSpace(setting.WaffoPancakeStoreID)
|
||||
if merchantID == "" || privateKey == "" || storeID == "" {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "Waffo Pancake 未完成配置,请先在支付设置中完成网关绑定"})
|
||||
return
|
||||
}
|
||||
productID, err := service.CreateWaffoPancakeProductForPlan(
|
||||
c.Request.Context(),
|
||||
merchantID,
|
||||
privateKey,
|
||||
storeID,
|
||||
req.Name,
|
||||
req.Amount,
|
||||
setting.WaffoPancakeReturnURL,
|
||||
)
|
||||
if err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf(
|
||||
"Waffo Pancake 创建套餐产品失败 store_id=%q name=%q amount=%q error=%q",
|
||||
storeID, req.Name, req.Amount, err.Error(),
|
||||
))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建套餐产品失败"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "success",
|
||||
"data": gin.H{
|
||||
"product_id": productID,
|
||||
"product_name": req.Name,
|
||||
"store_id": storeID,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// ListWaffoPancakeSubscriptionProductOptions returns the OnetimeProducts
|
||||
// in the saved Pancake store, for the subscription-plan dropdown. The name
|
||||
// reflects new-api's plan concept; under the hood it's still OnetimeProducts.
|
||||
func ListWaffoPancakeSubscriptionProductOptions(c *gin.Context) {
|
||||
merchantID, privateKey := resolveWaffoPancakeAdminCreds("", "")
|
||||
storeID := strings.TrimSpace(setting.WaffoPancakeStoreID)
|
||||
if merchantID == "" || privateKey == "" || storeID == "" {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "Waffo Pancake 未完成配置,请先在支付设置中完成网关绑定"})
|
||||
return
|
||||
}
|
||||
catalog, err := service.ListWaffoPancakeCatalog(c.Request.Context(), merchantID, privateKey)
|
||||
if err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf(
|
||||
"Waffo Pancake 拉取订阅产品列表失败 store_id=%q error=%q", storeID, err.Error(),
|
||||
))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉取产品列表失败"})
|
||||
return
|
||||
}
|
||||
products := []service.WaffoPancakeCatalogProduct{}
|
||||
for _, store := range catalog.Stores {
|
||||
if store.ID == storeID {
|
||||
products = store.OnetimeProducts
|
||||
break
|
||||
}
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "success",
|
||||
"data": gin.H{
|
||||
"store_id": storeID,
|
||||
"products": products,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func getWaffoPancakeBuyerIdentity(user *model.User) string {
|
||||
if user == nil {
|
||||
return ""
|
||||
}
|
||||
return service.WaffoPancakeBuyerIdentityFromUserID(user.Id)
|
||||
}
|
||||
|
||||
func RequestWaffoPancakePay(c *gin.Context) {
|
||||
if !setting.WaffoPancakeEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "Waffo Pancake 支付未启用"})
|
||||
return
|
||||
}
|
||||
currentWebhookKey := setting.WaffoPancakeWebhookPublicKey
|
||||
if setting.WaffoPancakeSandbox {
|
||||
currentWebhookKey = setting.WaffoPancakeWebhookTestKey
|
||||
}
|
||||
if strings.TrimSpace(setting.WaffoPancakeMerchantID) == "" ||
|
||||
strings.TrimSpace(setting.WaffoPancakePrivateKey) == "" ||
|
||||
strings.TrimSpace(currentWebhookKey) == "" ||
|
||||
strings.TrimSpace(setting.WaffoPancakeStoreID) == "" ||
|
||||
strings.TrimSpace(setting.WaffoPancakeProductID) == "" {
|
||||
if !isWaffoPancakeTopUpEnabled() {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "Waffo Pancake 配置不完整"})
|
||||
return
|
||||
}
|
||||
@@ -175,18 +399,15 @@ func RequestWaffoPancakePay(c *gin.Context) {
|
||||
|
||||
expiresInSeconds := 45 * 60
|
||||
session, err := service.CreateWaffoPancakeCheckoutSession(c.Request.Context(), &service.WaffoPancakeCreateSessionParams{
|
||||
StoreID: setting.WaffoPancakeStoreID,
|
||||
ProductID: setting.WaffoPancakeProductID,
|
||||
ProductType: "onetime",
|
||||
Currency: strings.ToUpper(strings.TrimSpace(setting.WaffoPancakeCurrency)),
|
||||
ProductID: setting.WaffoPancakeProductID,
|
||||
BuyerIdentity: getWaffoPancakeBuyerIdentity(user),
|
||||
PriceSnapshot: &service.WaffoPancakePriceSnapshot{
|
||||
Amount: formatWaffoPancakeAmount(payMoney),
|
||||
TaxIncluded: false,
|
||||
TaxCategory: "saas",
|
||||
},
|
||||
BuyerEmail: getWaffoPancakeBuyerEmail(user),
|
||||
SuccessURL: getWaffoPancakeReturnURL(),
|
||||
ExpiresInSeconds: &expiresInSeconds,
|
||||
BuyerEmail: getWaffoPancakeBuyerEmail(user),
|
||||
ExpiresInSeconds: &expiresInSeconds,
|
||||
OrderMerchantExternalID: tradeNo,
|
||||
})
|
||||
if err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake 创建结账会话失败 user_id=%d trade_no=%s error=%q", id, tradeNo, err.Error()))
|
||||
@@ -200,10 +421,12 @@ func RequestWaffoPancakePay(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "success",
|
||||
"data": gin.H{
|
||||
"checkout_url": session.CheckoutURL,
|
||||
"session_id": session.SessionID,
|
||||
"expires_at": session.ExpiresAt,
|
||||
"order_id": tradeNo,
|
||||
"checkout_url": session.CheckoutURL,
|
||||
"session_id": session.SessionID,
|
||||
"expires_at": session.ExpiresAt,
|
||||
"order_id": tradeNo,
|
||||
"token": session.Token,
|
||||
"token_expires_at": session.TokenExpiresAt,
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -215,6 +438,19 @@ func WaffoPancakeWebhook(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// :env splits test vs prod traffic at the routing layer — operator
|
||||
// registers each URL in the matching webhook slot in Pancake's dashboard.
|
||||
// We then enforce event.mode == expectedEnv to catch mis-registrations.
|
||||
expectedEnv := strings.TrimSpace(c.Param("env"))
|
||||
if expectedEnv != "test" && expectedEnv != "prod" {
|
||||
logger.LogWarn(c.Request.Context(), fmt.Sprintf(
|
||||
"Waffo Pancake webhook 路径环境段无效 env=%q path=%q client_ip=%s",
|
||||
expectedEnv, c.Request.RequestURI, c.ClientIP(),
|
||||
))
|
||||
c.String(http.StatusNotFound, "unknown env")
|
||||
return
|
||||
}
|
||||
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake webhook 读取请求体失败 path=%q client_ip=%s error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
|
||||
@@ -232,15 +468,57 @@ func WaffoPancakeWebhook(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if !strings.EqualFold(strings.TrimSpace(event.Mode), expectedEnv) {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf(
|
||||
"Waffo Pancake webhook 环境不匹配 expected=%q actual_mode=%q event_id=%s order_id=%s client_ip=%s",
|
||||
expectedEnv, event.Mode, event.ID, event.Data.OrderID, c.ClientIP(),
|
||||
))
|
||||
c.String(http.StatusOK, "OK")
|
||||
return
|
||||
}
|
||||
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo Pancake webhook 验签成功 event_type=%s event_id=%s order_id=%s client_ip=%s", event.NormalizedEventType(), event.ID, event.Data.OrderID, c.ClientIP()))
|
||||
if event.NormalizedEventType() != "order.completed" {
|
||||
c.String(http.StatusOK, "OK")
|
||||
return
|
||||
}
|
||||
|
||||
// Dispatch by trade_no prefix. OrderMerchantExternalID = our trade_no;
|
||||
// OrderID is Pancake's internal ORD_* (logs only).
|
||||
rawTradeNo := strings.TrimSpace(event.Data.OrderMerchantExternalID)
|
||||
isSubscription := strings.HasPrefix(rawTradeNo, "WAFFO_PANCAKE_SUB-")
|
||||
|
||||
if isSubscription {
|
||||
tradeNo, err := service.ResolveWaffoPancakeSubscriptionTradeNo(event)
|
||||
if err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf(
|
||||
"Waffo Pancake webhook 订阅订单解析失败 event_id=%s order_id=%s buyer_identity=%q client_ip=%s error=%q",
|
||||
event.ID, event.Data.OrderID, event.Data.MerchantProvidedBuyerIdentity, c.ClientIP(), err.Error(),
|
||||
))
|
||||
c.String(http.StatusOK, "OK")
|
||||
return
|
||||
}
|
||||
LockOrder(tradeNo)
|
||||
defer UnlockOrder(tradeNo)
|
||||
if err := model.CompleteSubscriptionOrder(tradeNo, string(bodyBytes), model.PaymentProviderWaffoPancake, ""); err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake 订阅完成失败 trade_no=%s event_id=%s order_id=%s client_ip=%s error=%q", tradeNo, event.ID, event.Data.OrderID, c.ClientIP(), err.Error()))
|
||||
c.String(http.StatusInternalServerError, "retry")
|
||||
return
|
||||
}
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo Pancake 订阅完成 trade_no=%s event_id=%s order_id=%s client_ip=%s", tradeNo, event.ID, event.Data.OrderID, c.ClientIP()))
|
||||
c.String(http.StatusOK, "OK")
|
||||
return
|
||||
}
|
||||
|
||||
tradeNo, err := service.ResolveWaffoPancakeTradeNo(event)
|
||||
if err != nil {
|
||||
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Waffo Pancake webhook 订单号映射失败 event_id=%s order_id=%s error=%q", event.ID, event.Data.OrderID, err.Error()))
|
||||
// LogError (not LogWarn): covers order-not-found and buyer-identity
|
||||
// mismatch — both warrant human attention. 200 OK so Waffo doesn't
|
||||
// retry a permanently-unresolvable webhook.
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf(
|
||||
"Waffo Pancake webhook 订单解析失败 event_id=%s order_id=%s buyer_identity=%q client_ip=%s error=%q",
|
||||
event.ID, event.Data.OrderID, event.Data.MerchantProvidedBuyerIdentity, c.ClientIP(), err.Error(),
|
||||
))
|
||||
c.String(http.StatusOK, "OK")
|
||||
return
|
||||
}
|
||||
|
||||
+13
-11
@@ -17,7 +17,6 @@ import (
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/setting"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
|
||||
@@ -251,8 +250,20 @@ func GetAllUsers(c *gin.Context) {
|
||||
func SearchUsers(c *gin.Context) {
|
||||
keyword := c.Query("keyword")
|
||||
group := c.Query("group")
|
||||
var role *int
|
||||
if roleStr := c.Query("role"); roleStr != "" {
|
||||
if parsed, err := strconv.Atoi(roleStr); err == nil {
|
||||
role = &parsed
|
||||
}
|
||||
}
|
||||
var status *int
|
||||
if statusStr := c.Query("status"); statusStr != "" {
|
||||
if parsed, err := strconv.Atoi(statusStr); err == nil {
|
||||
status = &parsed
|
||||
}
|
||||
}
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
users, total, err := model.SearchUsers(keyword, group, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||
users, total, err := model.SearchUsers(keyword, group, role, status, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
@@ -332,10 +343,6 @@ type TransferAffQuotaRequest struct {
|
||||
}
|
||||
|
||||
func TransferAffQuota(c *gin.Context) {
|
||||
if !requirePaymentCompliance(c) {
|
||||
return
|
||||
}
|
||||
|
||||
id := c.GetInt("id")
|
||||
user, err := model.GetUserById(id, true)
|
||||
if err != nil {
|
||||
@@ -1092,11 +1099,6 @@ func getTopUpLock(userID int) *topUpTryLock {
|
||||
}
|
||||
|
||||
func TopUp(c *gin.Context) {
|
||||
if !operation_setting.IsPaymentComplianceConfirmed() {
|
||||
common.ApiErrorI18n(c, i18n.MsgPaymentComplianceRequired)
|
||||
return
|
||||
}
|
||||
|
||||
id := c.GetInt("id")
|
||||
lock := getTopUpLock(id)
|
||||
if !lock.TryLock() {
|
||||
|
||||
+2
-1
@@ -16,7 +16,7 @@ version: '3.4' # For compatibility with older Docker versions
|
||||
|
||||
services:
|
||||
new-api:
|
||||
image: calciumion/new-api:latest
|
||||
image: git.viaeon.com/admin/new-api:latest
|
||||
container_name: new-api
|
||||
restart: always
|
||||
command: --log-dir /app/logs
|
||||
@@ -34,6 +34,7 @@ services:
|
||||
- BATCH_UPDATE_ENABLED=true # 是否启用批量更新 (Whether to enable batch update)
|
||||
- NODE_NAME=new-api-node-1 # 节点名称,用于审计日志中标识节点身份;多节点/容器部署时建议设置 (Node name used in audit logs; recommended when running multiple instances or in containers)
|
||||
# - STREAMING_TIMEOUT=300 # 流模式无响应超时时间,单位秒,默认120秒,如果出现空补全可以尝试改为更大值 (Streaming timeout in seconds, default is 120s. Increase if experiencing empty completions)
|
||||
# - RELAY_IDLE_CONN_TIMEOUT=90 # Relay HTTP 客户端空闲连接超时时间,单位秒,默认跟随 Go 标准库,设置为0表示不限制 (Relay HTTP client idle keep-alive timeout in seconds, defaults to Go standard library; set 0 to disable)
|
||||
# - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!! (multi-node deployment, set this to a random string!!!!!!!)
|
||||
# - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed
|
||||
# - GOOGLE_ANALYTICS_ID=G-XXXXXXXXXX # Google Analytics 的测量 ID (Google Analytics Measurement ID)
|
||||
|
||||
+6
-6
@@ -26,11 +26,11 @@ type ImageRequest struct {
|
||||
OutputFormat json.RawMessage `json:"output_format,omitempty"`
|
||||
OutputCompression json.RawMessage `json:"output_compression,omitempty"`
|
||||
PartialImages json.RawMessage `json:"partial_images,omitempty"`
|
||||
// Stream bool `json:"stream,omitempty"`
|
||||
Images json.RawMessage `json:"images,omitempty"`
|
||||
Mask json.RawMessage `json:"mask,omitempty"`
|
||||
InputFidelity json.RawMessage `json:"input_fidelity,omitempty"`
|
||||
Watermark *bool `json:"watermark,omitempty"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
Images json.RawMessage `json:"images,omitempty"`
|
||||
Mask json.RawMessage `json:"mask,omitempty"`
|
||||
InputFidelity json.RawMessage `json:"input_fidelity,omitempty"`
|
||||
Watermark *bool `json:"watermark,omitempty"`
|
||||
// zhipu 4v
|
||||
WatermarkEnabled json.RawMessage `json:"watermark_enabled,omitempty"`
|
||||
UserId json.RawMessage `json:"user_id,omitempty"`
|
||||
@@ -163,7 +163,7 @@ func (i *ImageRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
}
|
||||
|
||||
func (i *ImageRequest) IsStream(c *gin.Context) bool {
|
||||
return false
|
||||
return i.Stream != nil && *i.Stream
|
||||
}
|
||||
|
||||
func (i *ImageRequest) SetModelName(modelName string) {
|
||||
|
||||
+12
-2
@@ -213,12 +213,22 @@ func (r *GeneralOpenAIRequest) ToMap() map[string]any {
|
||||
return result
|
||||
}
|
||||
|
||||
func IsOpenAIReasoningOModel(modelName string) bool {
|
||||
return strings.HasPrefix(modelName, "o1") ||
|
||||
strings.HasPrefix(modelName, "o3") ||
|
||||
strings.HasPrefix(modelName, "o4")
|
||||
}
|
||||
|
||||
func IsOpenAIGPT5Model(modelName string) bool {
|
||||
return strings.HasPrefix(modelName, "gpt-5")
|
||||
}
|
||||
|
||||
func (r *GeneralOpenAIRequest) GetSystemRoleName() string {
|
||||
if strings.HasPrefix(r.Model, "o") {
|
||||
if IsOpenAIReasoningOModel(r.Model) {
|
||||
if !strings.HasPrefix(r.Model, "o1-mini") && !strings.HasPrefix(r.Model, "o1-preview") {
|
||||
return "developer"
|
||||
}
|
||||
} else if strings.HasPrefix(r.Model, "gpt-5") {
|
||||
} else if IsOpenAIGPT5Model(r.Model) {
|
||||
return "developer"
|
||||
}
|
||||
return "system"
|
||||
|
||||
@@ -71,3 +71,27 @@ func TestOpenAIResponsesRequestPreserveExplicitZeroValues(t *testing.T) {
|
||||
require.True(t, gjson.GetBytes(encoded, "stream").Exists())
|
||||
require.True(t, gjson.GetBytes(encoded, "top_p").Exists())
|
||||
}
|
||||
|
||||
func TestGeneralOpenAIRequestGetSystemRoleName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
want string
|
||||
}{
|
||||
{name: "o1 uses developer", model: "o1", want: "developer"},
|
||||
{name: "o3 family uses developer", model: "o3-mini-high", want: "developer"},
|
||||
{name: "o4 family uses developer", model: "o4-mini", want: "developer"},
|
||||
{name: "o1 mini stays system", model: "o1-mini", want: "system"},
|
||||
{name: "o1 preview stays system", model: "o1-preview", want: "system"},
|
||||
{name: "gpt 5 uses developer", model: "gpt-5", want: "developer"},
|
||||
{name: "omni is not o series", model: "omni-moderation-latest", want: "system"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := GeneralOpenAIRequest{Model: tt.model}
|
||||
|
||||
require.Equal(t, tt.want, req.GetSystemRoleName())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Vendored
+2
-2
@@ -18,10 +18,10 @@
|
||||
"openai",
|
||||
"claude"
|
||||
],
|
||||
"author": "QuantumNous",
|
||||
"author": "modelstoken",
|
||||
"repository": {
|
||||
"type": "git",
|
||||
"url": "https://github.com/QuantumNous/new-api"
|
||||
"url": "https://git.viaeon.com/admin/new-api"
|
||||
},
|
||||
"devDependencies": {
|
||||
"cross-env": "^7.0.3",
|
||||
|
||||
@@ -60,6 +60,8 @@ require (
|
||||
gorm.io/gorm v1.25.2
|
||||
)
|
||||
|
||||
require github.com/waffo-com/waffo-pancake-sdk-go v0.3.1
|
||||
|
||||
require (
|
||||
github.com/DmitriyVTitov/size v1.5.0 // indirect
|
||||
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 // indirect
|
||||
|
||||
@@ -308,6 +308,12 @@ github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65E
|
||||
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||
github.com/waffo-com/waffo-go v1.3.1 h1:NCYD3oQ59DTJj1bwS5T/659LI4h8PuAIW4Qj/w7fKPw=
|
||||
github.com/waffo-com/waffo-go v1.3.1/go.mod h1:IaXVYq6mmYtrLFFsLxPslNwuIZx0mIadWWjhe+eWb0g=
|
||||
github.com/waffo-com/waffo-pancake-sdk-go v0.1.1 h1:YOI7+3zTBlTB7Ou6+ZXnJV2JvW/ag9d7CwE/TxH3Hls=
|
||||
github.com/waffo-com/waffo-pancake-sdk-go v0.1.1/go.mod h1:5MBCGH/nqRRA5sHO/lQB/96r4BTAqy8QpWxn53m9htI=
|
||||
github.com/waffo-com/waffo-pancake-sdk-go v0.2.0 h1:cCSgccM66p7feTtgRqUUGT50tYQOhahsoPXavd+ib1U=
|
||||
github.com/waffo-com/waffo-pancake-sdk-go v0.2.0/go.mod h1:5MBCGH/nqRRA5sHO/lQB/96r4BTAqy8QpWxn53m9htI=
|
||||
github.com/waffo-com/waffo-pancake-sdk-go v0.3.1 h1:ngQSN/oVB35xTwFPLfg++bxPC+SptcF145Mb6c62YCc=
|
||||
github.com/waffo-com/waffo-pancake-sdk-go v0.3.1/go.mod h1:OB2MyFIQaefoPO0FV3J+yu9sDP8RVFQ+sbFsXqGuObc=
|
||||
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
|
||||
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
|
||||
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
|
||||
|
||||
@@ -152,7 +152,6 @@ const (
|
||||
MsgPaymentWebhookNotConfig = "payment.webhook_not_configured"
|
||||
MsgPaymentPriceIdNotConfig = "payment.price_id_not_configured"
|
||||
MsgPaymentCreemNotConfig = "payment.creem_not_configured"
|
||||
MsgPaymentComplianceRequired = "payment.compliance_required"
|
||||
)
|
||||
|
||||
// Topup related messages
|
||||
|
||||
@@ -164,7 +164,7 @@ func main() {
|
||||
common.SysLog(fmt.Sprintf("panic detected: %v", err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err),
|
||||
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://git.viaeon.com/admin/new-api/issues", err),
|
||||
"type": "new_api_panic",
|
||||
},
|
||||
})
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
FRONTEND_DIR = ./web/default
|
||||
FRONTEND_CLASSIC_DIR = ./web/classic
|
||||
BACKEND_DIR = .
|
||||
DEV_FRONTEND_DEFAULT_PORT ?= 5173
|
||||
DEV_FRONTEND_CLASSIC_PORT ?= 5174
|
||||
DEV_COMPOSE_FILE = docker-compose.dev.yml
|
||||
DEV_POSTGRES_SERVICE = postgres
|
||||
DEV_BACKEND_SERVICE = new-api
|
||||
@@ -14,11 +16,13 @@ all: build-all-frontends start-backend
|
||||
|
||||
build-frontend:
|
||||
@echo "Building default frontend..."
|
||||
@cd $(FRONTEND_DIR) && bun install && DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat ../../VERSION) bun run build
|
||||
@cd ./web && bun install --frozen-lockfile
|
||||
@cd $(FRONTEND_DIR) && DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat ../../VERSION) bun run build
|
||||
|
||||
build-frontend-classic:
|
||||
@echo "Building classic frontend..."
|
||||
@cd $(FRONTEND_CLASSIC_DIR) && bun install && VITE_REACT_APP_VERSION=$(cat ../../VERSION) bun run build
|
||||
@cd ./web && bun install --frozen-lockfile
|
||||
@cd $(FRONTEND_CLASSIC_DIR) && VITE_REACT_APP_VERSION=$(cat ../../VERSION) bun run build
|
||||
|
||||
build-all-frontends: build-frontend build-frontend-classic
|
||||
|
||||
@@ -35,12 +39,35 @@ dev-api-rebuild:
|
||||
@docker compose -f $(DEV_COMPOSE_FILE) up -d --build $(DEV_BACKEND_SERVICE)
|
||||
|
||||
dev-web:
|
||||
@echo "Starting frontend dev server..."
|
||||
@cd $(FRONTEND_DIR) && bun install && bun run dev
|
||||
@echo "Starting both frontend dev servers..."
|
||||
@echo "Default frontend: http://localhost:$(DEV_FRONTEND_DEFAULT_PORT)"
|
||||
@echo "Classic frontend: http://localhost:$(DEV_FRONTEND_CLASSIC_PORT)"
|
||||
@cd ./web && bun install
|
||||
@(cd $(FRONTEND_DIR) && bun run dev -- --host 0.0.0.0 --port $(DEV_FRONTEND_DEFAULT_PORT)) & \
|
||||
default_pid=$$!; \
|
||||
(cd $(FRONTEND_CLASSIC_DIR) && bun run dev -- --host 0.0.0.0 --port $(DEV_FRONTEND_CLASSIC_PORT)) & \
|
||||
classic_pid=$$!; \
|
||||
trap 'kill $$default_pid $$classic_pid 2>/dev/null; wait $$default_pid $$classic_pid 2>/dev/null; exit 130' INT TERM; \
|
||||
while kill -0 $$default_pid 2>/dev/null && kill -0 $$classic_pid 2>/dev/null; do \
|
||||
sleep 1; \
|
||||
done; \
|
||||
if ! kill -0 $$default_pid 2>/dev/null; then \
|
||||
wait $$default_pid; \
|
||||
status=$$?; \
|
||||
kill $$classic_pid 2>/dev/null; \
|
||||
wait $$classic_pid 2>/dev/null; \
|
||||
exit $$status; \
|
||||
fi; \
|
||||
wait $$classic_pid; \
|
||||
status=$$?; \
|
||||
kill $$default_pid 2>/dev/null; \
|
||||
wait $$default_pid 2>/dev/null; \
|
||||
exit $$status
|
||||
|
||||
dev-web-classic:
|
||||
@echo "Starting classic frontend dev server..."
|
||||
@cd $(FRONTEND_CLASSIC_DIR) && bun install && bun run dev
|
||||
@cd ./web && bun install
|
||||
@cd $(FRONTEND_CLASSIC_DIR) && bun run dev -- --host 0.0.0.0 --port $(DEV_FRONTEND_CLASSIC_PORT)
|
||||
|
||||
dev: dev-api dev-web
|
||||
|
||||
|
||||
@@ -163,6 +163,10 @@ func TryUserAuth() func(c *gin.Context) {
|
||||
if id != nil {
|
||||
c.Set("id", id)
|
||||
}
|
||||
role := session.Get("role")
|
||||
if role != nil {
|
||||
c.Set("role", role)
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package middleware
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strconv"
|
||||
@@ -20,6 +21,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type ModelRequest struct {
|
||||
@@ -100,14 +102,10 @@ func Distribute() func(c *gin.Context) {
|
||||
}
|
||||
|
||||
if preferredChannelID, found := service.GetPreferredChannelByAffinity(c, modelRequest.Model, usingGroup); found {
|
||||
affinityUsable := false
|
||||
preferred, err := model.CacheGetChannel(preferredChannelID)
|
||||
if err == nil && preferred != nil {
|
||||
if preferred.Status != common.ChannelStatusEnabled {
|
||||
if service.ShouldSkipRetryAfterChannelAffinityFailure(c) {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, i18n.T(c, i18n.MsgDistributorAffinityChannelDisabled))
|
||||
return
|
||||
}
|
||||
} else if usingGroup == "auto" {
|
||||
if err == nil && preferred != nil && preferred.Status == common.ChannelStatusEnabled {
|
||||
if usingGroup == "auto" {
|
||||
userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
|
||||
autoGroups := service.GetUserAutoGroup(userGroup)
|
||||
for _, g := range autoGroups {
|
||||
@@ -115,6 +113,7 @@ func Distribute() func(c *gin.Context) {
|
||||
selectGroup = g
|
||||
common.SetContextKey(c, constant.ContextKeyAutoGroup, g)
|
||||
channel = preferred
|
||||
affinityUsable = true
|
||||
service.MarkChannelAffinityUsed(c, g, preferred.Id)
|
||||
break
|
||||
}
|
||||
@@ -122,9 +121,13 @@ func Distribute() func(c *gin.Context) {
|
||||
} else if model.IsChannelEnabledForGroupModel(usingGroup, modelRequest.Model, preferred.Id) {
|
||||
channel = preferred
|
||||
selectGroup = usingGroup
|
||||
affinityUsable = true
|
||||
service.MarkChannelAffinityUsed(c, usingGroup, preferred.Id)
|
||||
}
|
||||
}
|
||||
if !affinityUsable && !service.ShouldKeepChannelAffinityOnChannelDisabled() {
|
||||
service.ClearCurrentChannelAffinityCache(c)
|
||||
}
|
||||
}
|
||||
|
||||
if channel == nil {
|
||||
@@ -170,6 +173,14 @@ func Distribute() func(c *gin.Context) {
|
||||
// - application/x-www-form-urlencoded
|
||||
// - multipart/form-data
|
||||
func getModelFromRequest(c *gin.Context) (*ModelRequest, error) {
|
||||
if strings.HasPrefix(c.Request.Header.Get("Content-Type"), "application/json") {
|
||||
modelRequest, err := getModelFromJSONBody(c)
|
||||
if err != nil {
|
||||
return nil, errors.New(i18n.T(c, i18n.MsgDistributorInvalidRequest, map[string]any{"Error": err.Error()}))
|
||||
}
|
||||
return modelRequest, nil
|
||||
}
|
||||
|
||||
var modelRequest ModelRequest
|
||||
err := common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
if err != nil {
|
||||
@@ -178,6 +189,50 @@ func getModelFromRequest(c *gin.Context) (*ModelRequest, error) {
|
||||
return &modelRequest, nil
|
||||
}
|
||||
|
||||
func getModelFromJSONBody(c *gin.Context) (*ModelRequest, error) {
|
||||
storage, err := common.GetBodyStorage(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
requestBody, err := storage.Bytes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !gjson.ValidBytes(requestBody) {
|
||||
return nil, errors.New("invalid JSON request body")
|
||||
}
|
||||
|
||||
values := gjson.GetManyBytes(requestBody, "model", "group")
|
||||
model, err := getJSONStringValue(values[0], "model")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
group, err := getJSONStringValue(values[1], "group")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, seekErr := storage.Seek(0, io.SeekStart); seekErr != nil {
|
||||
return nil, seekErr
|
||||
}
|
||||
c.Request.Body = io.NopCloser(storage)
|
||||
|
||||
return &ModelRequest{
|
||||
Model: model,
|
||||
Group: group,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func getJSONStringValue(result gjson.Result, field string) (string, error) {
|
||||
if !result.Exists() || result.Type == gjson.Null {
|
||||
return "", nil
|
||||
}
|
||||
if result.Type != gjson.String {
|
||||
return "", fmt.Errorf("field %s must be a string", field)
|
||||
}
|
||||
return result.String(), nil
|
||||
}
|
||||
|
||||
func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||
var modelRequest ModelRequest
|
||||
shouldSelectChannel := true
|
||||
@@ -244,6 +299,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||
} else if c.Request.Method == http.MethodGet {
|
||||
relayMode = relayconstant.RelayModeVideoFetchByID
|
||||
shouldSelectChannel = false
|
||||
modelRequest.Model = getTaskOriginModelName(c)
|
||||
}
|
||||
c.Set("relay_mode", relayMode)
|
||||
} else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") {
|
||||
@@ -258,6 +314,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||
} else if c.Request.Method == http.MethodGet {
|
||||
relayMode = relayconstant.RelayModeVideoFetchByID
|
||||
shouldSelectChannel = false
|
||||
modelRequest.Model = getTaskOriginModelName(c)
|
||||
}
|
||||
if _, ok := c.Get("relay_mode"); !ok {
|
||||
c.Set("relay_mode", relayMode)
|
||||
@@ -342,6 +399,31 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||
return &modelRequest, shouldSelectChannel, nil
|
||||
}
|
||||
|
||||
// 修复 #4834: GET /v1/video/generations/:task_id && /v1/video/:task_id 此前不解析 model,
|
||||
// 当 token 启用「可用模型限制」时,下游 modelLimitEnable 校验会因
|
||||
// modelRequest.Model 为空而误报 "This token has no access to model"。
|
||||
// 从已存储的任务记录中回填 OriginModelName 即可让校验走在正确的模型上。
|
||||
func getTaskOriginModelName(c *gin.Context) string {
|
||||
if !common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled) {
|
||||
return ""
|
||||
}
|
||||
|
||||
taskId := c.Param("task_id")
|
||||
if taskId == "" {
|
||||
// jimeng adapter
|
||||
taskId = c.GetString("task_id")
|
||||
}
|
||||
if taskId == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
userId := c.GetInt("id")
|
||||
if task, exist, err := model.GetByTaskId(userId, taskId); err == nil && exist && task != nil {
|
||||
return task.Properties.OriginModelName
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) *types.NewAPIError {
|
||||
c.Set("original_model", modelName) // for retry
|
||||
if channel == nil {
|
||||
|
||||
@@ -17,7 +17,7 @@ func RelayPanicRecover() gin.HandlerFunc {
|
||||
common.SysLog(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack())))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err),
|
||||
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://git.viaeon.com/admin/new-api/issues", err),
|
||||
"type": "new_api_panic",
|
||||
},
|
||||
})
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func AnonymousRequestBodyLimit() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
maxBytes := common.GetAnonymousRequestBodyLimitBytes()
|
||||
if maxBytes <= 0 || c.Request.Body == nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
originalBody := c.Request.Body
|
||||
limitedBody, err := readAnonymousRequestBody(originalBody, maxBytes)
|
||||
_ = originalBody.Close()
|
||||
if err != nil {
|
||||
if common.IsRequestBodyTooLargeError(err) {
|
||||
c.AbortWithStatus(http.StatusRequestEntityTooLarge)
|
||||
return
|
||||
}
|
||||
c.AbortWithStatus(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(limitedBody))
|
||||
c.Request.ContentLength = int64(len(limitedBody))
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func readAnonymousRequestBody(body io.Reader, maxBytes int64) ([]byte, error) {
|
||||
data, err := io.ReadAll(io.LimitReader(body, maxBytes+1))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if int64(len(data)) > maxBytes {
|
||||
return nil, common.ErrRequestBodyTooLarge
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
+33
-2
@@ -643,13 +643,25 @@ func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int, reason
|
||||
if len(keys) == 0 {
|
||||
channel.Status = status
|
||||
} else {
|
||||
var keyIndex int
|
||||
keyIndex := -1
|
||||
for i, key := range keys {
|
||||
if key == usingKey {
|
||||
keyIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if keyIndex < 0 {
|
||||
if usingKey != "" {
|
||||
common.SysLog(fmt.Sprintf("failed to update multi-key status: channel_id=%d, using key not found", channel.Id))
|
||||
return
|
||||
}
|
||||
channel.Status = status
|
||||
info := channel.GetOtherInfo()
|
||||
info["status_reason"] = reason
|
||||
info["status_time"] = common.GetTimestamp()
|
||||
channel.SetOtherInfo(info)
|
||||
return
|
||||
}
|
||||
if channel.ChannelInfo.MultiKeyStatusList == nil {
|
||||
channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
|
||||
}
|
||||
@@ -666,16 +678,31 @@ func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int, reason
|
||||
channel.ChannelInfo.MultiKeyDisabledReason[keyIndex] = reason
|
||||
channel.ChannelInfo.MultiKeyDisabledTime[keyIndex] = common.GetTimestamp()
|
||||
}
|
||||
if len(channel.ChannelInfo.MultiKeyStatusList) >= channel.ChannelInfo.MultiKeySize {
|
||||
if !hasEnabledMultiKey(keys, channel.ChannelInfo.MultiKeyStatusList) {
|
||||
channel.Status = common.ChannelStatusAutoDisabled
|
||||
info := channel.GetOtherInfo()
|
||||
info["status_reason"] = "All keys are disabled"
|
||||
info["status_time"] = common.GetTimestamp()
|
||||
channel.SetOtherInfo(info)
|
||||
} else if status == common.ChannelStatusEnabled {
|
||||
channel.Status = common.ChannelStatusEnabled
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func hasEnabledMultiKey(keys []string, statusList map[int]int) bool {
|
||||
for i := range keys {
|
||||
if statusList == nil {
|
||||
return true
|
||||
}
|
||||
status, ok := statusList[i]
|
||||
if !ok || status == common.ChannelStatusEnabled {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func UpdateChannelStatus(channelId int, usingKey string, status int, reason string) bool {
|
||||
if common.MemoryCacheEnabled {
|
||||
channelStatusLock.Lock()
|
||||
@@ -687,11 +714,15 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
|
||||
}
|
||||
if channelCache.ChannelInfo.IsMultiKey {
|
||||
// Use per-channel lock to prevent concurrent map read/write with GetNextEnabledKey
|
||||
beforeStatus := channelCache.Status
|
||||
pollingLock := GetChannelPollingLock(channelId)
|
||||
pollingLock.Lock()
|
||||
// 如果是多Key模式,更新缓存中的状态
|
||||
handlerMultiKeyUpdate(channelCache, usingKey, status, reason)
|
||||
pollingLock.Unlock()
|
||||
if beforeStatus != channelCache.Status {
|
||||
CacheUpdateChannelStatus(channelId, channelCache.Status)
|
||||
}
|
||||
//CacheUpdateChannel(channelCache)
|
||||
//return true
|
||||
} else {
|
||||
|
||||
@@ -0,0 +1,96 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type Document struct {
|
||||
Id int `json:"id" gorm:"primaryKey"`
|
||||
Title string `json:"title" gorm:"not null"`
|
||||
Slug string `json:"slug" gorm:"type:varchar(255);uniqueIndex;not null"`
|
||||
Content string `json:"content" gorm:"type:text;not null"`
|
||||
CategoryId *int `json:"category_id" gorm:"index"`
|
||||
Visibility string `json:"visibility" gorm:"default:'public'"` // public, auth, admin
|
||||
SortOrder int `json:"sort_order" gorm:"default:0"`
|
||||
AuthorId int `json:"author_id" gorm:"not null"`
|
||||
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
||||
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
|
||||
}
|
||||
|
||||
func GetDocuments(keyword string, visibility string, categoryId *int, startIdx int, num int) ([]*Document, int64, error) {
|
||||
query := DB.Model(&Document{})
|
||||
if keyword != "" {
|
||||
like := "%" + keyword + "%"
|
||||
query = query.Where("title LIKE ? OR content LIKE ?", like, like)
|
||||
}
|
||||
if visibility != "" {
|
||||
query = query.Where("visibility = ?", visibility)
|
||||
}
|
||||
if categoryId != nil {
|
||||
query = query.Where("category_id = ?", *categoryId)
|
||||
}
|
||||
var total int64
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
var documents []*Document
|
||||
if err := query.Order("sort_order ASC, id DESC").Offset(startIdx).Limit(num).Find(&documents).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return documents, total, nil
|
||||
}
|
||||
|
||||
func GetDocumentsByVisibility(keyword string, visibilities []string, categoryId *int, startIdx int, num int) ([]*Document, int64, error) {
|
||||
query := DB.Model(&Document{})
|
||||
if keyword != "" {
|
||||
like := "%" + keyword + "%"
|
||||
query = query.Where("title LIKE ? OR content LIKE ?", like, like)
|
||||
}
|
||||
if len(visibilities) > 0 {
|
||||
query = query.Where("visibility IN ?", visibilities)
|
||||
}
|
||||
if categoryId != nil {
|
||||
query = query.Where("category_id = ?", *categoryId)
|
||||
}
|
||||
var total int64
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
var documents []*Document
|
||||
if err := query.Order("sort_order ASC, id DESC").Offset(startIdx).Limit(num).Find(&documents).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return documents, total, nil
|
||||
}
|
||||
|
||||
func GetDocumentBySlug(slug string) (*Document, error) {
|
||||
var doc Document
|
||||
err := DB.Where("slug = ?", slug).First(&doc).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &doc, nil
|
||||
}
|
||||
|
||||
func GetDocumentById(id int) (*Document, error) {
|
||||
var doc Document
|
||||
err := DB.First(&doc, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &doc, nil
|
||||
}
|
||||
|
||||
func CreateDocument(doc *Document) error {
|
||||
return DB.Create(doc).Error
|
||||
}
|
||||
|
||||
func UpdateDocument(doc *Document) error {
|
||||
return DB.Model(doc).Select("title", "slug", "content", "category_id", "visibility", "sort_order").Updates(doc).Error
|
||||
}
|
||||
|
||||
func DeleteDocument(id int) error {
|
||||
// Delete associated versions first
|
||||
DB.Where("document_id = ?", id).Delete(&DocumentVersion{})
|
||||
return DB.Delete(&Document{}, id).Error
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type DocumentCategory struct {
|
||||
Id int `json:"id" gorm:"primaryKey"`
|
||||
Name string `json:"name" gorm:"not null"`
|
||||
Slug string `json:"slug" gorm:"type:varchar(255);uniqueIndex;not null"`
|
||||
ParentId *int `json:"parent_id" gorm:"index"`
|
||||
SortOrder int `json:"sort_order" gorm:"default:0"`
|
||||
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
||||
}
|
||||
|
||||
func GetDocumentCategories() ([]*DocumentCategory, error) {
|
||||
var categories []*DocumentCategory
|
||||
err := DB.Order("sort_order ASC, id ASC").Find(&categories).Error
|
||||
return categories, err
|
||||
}
|
||||
|
||||
func GetDocumentCategoryTree() ([]*DocumentCategory, error) {
|
||||
var categories []*DocumentCategory
|
||||
err := DB.Order("sort_order ASC, id ASC").Find(&categories).Error
|
||||
return categories, err
|
||||
}
|
||||
|
||||
func CreateDocumentCategory(category *DocumentCategory) error {
|
||||
return DB.Create(category).Error
|
||||
}
|
||||
|
||||
func UpdateDocumentCategory(category *DocumentCategory) error {
|
||||
return DB.Model(category).Select("name", "slug", "parent_id", "sort_order").Updates(category).Error
|
||||
}
|
||||
|
||||
func DeleteDocumentCategory(id int) error {
|
||||
return DB.Delete(&DocumentCategory{}, id).Error
|
||||
}
|
||||
@@ -0,0 +1,30 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type DocumentVersion struct {
|
||||
Id int `json:"id" gorm:"primaryKey"`
|
||||
DocumentId int `json:"document_id" gorm:"index;not null"`
|
||||
Content string `json:"content" gorm:"type:text;not null"`
|
||||
AuthorId int `json:"author_id" gorm:"not null"`
|
||||
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
||||
}
|
||||
|
||||
func GetDocumentVersions(documentId int, startIdx int, num int) ([]*DocumentVersion, int64, error) {
|
||||
query := DB.Model(&DocumentVersion{}).Where("document_id = ?", documentId)
|
||||
var total int64
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
var versions []*DocumentVersion
|
||||
if err := query.Order("id DESC").Offset(startIdx).Limit(num).Find(&versions).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return versions, total, nil
|
||||
}
|
||||
|
||||
func CreateDocumentVersion(version *DocumentVersion) error {
|
||||
return DB.Create(version).Error
|
||||
}
|
||||
+65
-49
@@ -17,25 +17,39 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func applyExplicitLogTextFilter(tx *gorm.DB, column string, value string) (*gorm.DB, error) {
|
||||
if value == "" {
|
||||
return tx, nil
|
||||
}
|
||||
if strings.Contains(value, "%") {
|
||||
pattern, err := sanitizeLikePattern(value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tx.Where(column+" LIKE ? ESCAPE '!'", pattern), nil
|
||||
}
|
||||
return tx.Where(column+" = ?", value), nil
|
||||
}
|
||||
|
||||
type Log struct {
|
||||
Id int `json:"id" gorm:"index:idx_created_at_id,priority:1;index:idx_user_id_id,priority:2"`
|
||||
UserId int `json:"user_id" gorm:"index;index:idx_user_id_id,priority:1"`
|
||||
CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:2;index:idx_created_at_type"`
|
||||
Type int `json:"type" gorm:"index:idx_created_at_type"`
|
||||
Content string `json:"content"`
|
||||
Username string `json:"username" gorm:"index;index:index_username_model_name,priority:2;default:''"`
|
||||
TokenName string `json:"token_name" gorm:"index;default:''"`
|
||||
ModelName string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"`
|
||||
Quota int `json:"quota" gorm:"default:0"`
|
||||
PromptTokens int `json:"prompt_tokens" gorm:"default:0"`
|
||||
CompletionTokens int `json:"completion_tokens" gorm:"default:0"`
|
||||
UseTime int `json:"use_time" gorm:"default:0"`
|
||||
IsStream bool `json:"is_stream"`
|
||||
ChannelId int `json:"channel" gorm:"index"`
|
||||
ChannelName string `json:"channel_name" gorm:"->"`
|
||||
TokenId int `json:"token_id" gorm:"default:0;index"`
|
||||
Group string `json:"group" gorm:"index"`
|
||||
Ip string `json:"ip" gorm:"index;default:''"`
|
||||
Id int `json:"id" gorm:"index:idx_created_at_id,priority:2;index:idx_user_id_id,priority:2"`
|
||||
UserId int `json:"user_id" gorm:"index;index:idx_user_id_id,priority:1"`
|
||||
CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:1;index:idx_created_at_type"`
|
||||
Type int `json:"type" gorm:"index:idx_created_at_type"`
|
||||
Content string `json:"content"`
|
||||
Username string `json:"username" gorm:"index;index:index_username_model_name,priority:2;default:''"`
|
||||
TokenName string `json:"token_name" gorm:"index;default:''"`
|
||||
ModelName string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"`
|
||||
Quota int `json:"quota" gorm:"default:0"`
|
||||
PromptTokens int `json:"prompt_tokens" gorm:"default:0"`
|
||||
CompletionTokens int `json:"completion_tokens" gorm:"default:0"`
|
||||
UseTime int `json:"use_time" gorm:"default:0"`
|
||||
IsStream bool `json:"is_stream"`
|
||||
ChannelId int `json:"channel" gorm:"index"`
|
||||
ChannelName string `json:"channel_name" gorm:"->"`
|
||||
TokenId int `json:"token_id" gorm:"default:0;index"`
|
||||
Group string `json:"group" gorm:"index"`
|
||||
Ip string `json:"ip" gorm:"index;default:''"`
|
||||
RequestId string `json:"request_id,omitempty" gorm:"type:varchar(64);index:idx_logs_request_id;default:''"`
|
||||
UpstreamRequestId string `json:"upstream_request_id,omitempty" gorm:"type:varchar(128);index:idx_logs_upstream_request_id;default:''"`
|
||||
Other string `json:"other"`
|
||||
@@ -146,7 +160,7 @@ func RecordTopupLog(userId int, content string, callerIp string, paymentMethod s
|
||||
|
||||
func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string, tokenName string, content string, tokenId int, useTimeSeconds int,
|
||||
isStream bool, group string, other map[string]interface{}) {
|
||||
logger.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content))
|
||||
logger.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, common.LocalLogPreview(content)))
|
||||
username := c.GetString("username")
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
upstreamRequestId := c.GetString(common.UpstreamRequestIdKey)
|
||||
@@ -309,9 +323,15 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
|
||||
tx = LOG_DB.Where("logs.type = ?", logType)
|
||||
}
|
||||
|
||||
tx = applyLogContainsFilter(tx, "logs.model_name", modelName)
|
||||
tx = applyLogContainsFilter(tx, "logs.username", username)
|
||||
tx = applyLogContainsFilter(tx, "logs.token_name", tokenName)
|
||||
if tx, err = applyExplicitLogTextFilter(tx, "logs.model_name", modelName); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if tx, err = applyExplicitLogTextFilter(tx, "logs.username", username); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if tokenName != "" {
|
||||
tx = tx.Where("logs.token_name = ?", tokenName)
|
||||
}
|
||||
if requestId != "" {
|
||||
tx = tx.Where("logs.request_id = ?", requestId)
|
||||
}
|
||||
@@ -334,7 +354,7 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
err = tx.Order("logs.id desc").Limit(num).Offset(startIdx).Find(&logs).Error
|
||||
err = tx.Order("logs.created_at desc, logs.id desc").Limit(num).Offset(startIdx).Find(&logs).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
@@ -392,8 +412,12 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
|
||||
tx = LOG_DB.Where("logs.user_id = ? and logs.type = ?", userId, logType)
|
||||
}
|
||||
|
||||
tx = applyLogContainsFilter(tx, "logs.model_name", modelName)
|
||||
tx = applyLogContainsFilter(tx, "logs.token_name", tokenName)
|
||||
if tx, err = applyExplicitLogTextFilter(tx, "logs.model_name", modelName); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if tokenName != "" {
|
||||
tx = tx.Where("logs.token_name = ?", tokenName)
|
||||
}
|
||||
if requestId != "" {
|
||||
tx = tx.Where("logs.request_id = ?", requestId)
|
||||
}
|
||||
@@ -430,42 +454,34 @@ type Stat struct {
|
||||
Tpm int `json:"tpm"`
|
||||
}
|
||||
|
||||
func logContainsPattern(input string) (string, bool) {
|
||||
input = strings.TrimSpace(input)
|
||||
if input == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
replacer := strings.NewReplacer("!", "!!", "%", "!%", "_", "!_")
|
||||
return "%" + replacer.Replace(input) + "%", true
|
||||
}
|
||||
|
||||
func applyLogContainsFilter(tx *gorm.DB, column string, value string) *gorm.DB {
|
||||
pattern, ok := logContainsPattern(value)
|
||||
if !ok {
|
||||
return tx
|
||||
}
|
||||
return tx.Where(column+" LIKE ? ESCAPE '!'", pattern)
|
||||
}
|
||||
|
||||
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int, group string) (stat Stat, err error) {
|
||||
tx := LOG_DB.Table("logs").Select("sum(quota) quota")
|
||||
|
||||
// 为rpm和tpm创建单独的查询
|
||||
rpmTpmQuery := LOG_DB.Table("logs").Select("count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm")
|
||||
|
||||
tx = applyLogContainsFilter(tx, "username", username)
|
||||
rpmTpmQuery = applyLogContainsFilter(rpmTpmQuery, "username", username)
|
||||
tx = applyLogContainsFilter(tx, "token_name", tokenName)
|
||||
rpmTpmQuery = applyLogContainsFilter(rpmTpmQuery, "token_name", tokenName)
|
||||
if tx, err = applyExplicitLogTextFilter(tx, "username", username); err != nil {
|
||||
return stat, err
|
||||
}
|
||||
if rpmTpmQuery, err = applyExplicitLogTextFilter(rpmTpmQuery, "username", username); err != nil {
|
||||
return stat, err
|
||||
}
|
||||
if tokenName != "" {
|
||||
tx = tx.Where("token_name = ?", tokenName)
|
||||
rpmTpmQuery = rpmTpmQuery.Where("token_name = ?", tokenName)
|
||||
}
|
||||
if startTimestamp != 0 {
|
||||
tx = tx.Where("created_at >= ?", startTimestamp)
|
||||
}
|
||||
if endTimestamp != 0 {
|
||||
tx = tx.Where("created_at <= ?", endTimestamp)
|
||||
}
|
||||
tx = applyLogContainsFilter(tx, "model_name", modelName)
|
||||
rpmTpmQuery = applyLogContainsFilter(rpmTpmQuery, "model_name", modelName)
|
||||
if tx, err = applyExplicitLogTextFilter(tx, "model_name", modelName); err != nil {
|
||||
return stat, err
|
||||
}
|
||||
if rpmTpmQuery, err = applyExplicitLogTextFilter(rpmTpmQuery, "model_name", modelName); err != nil {
|
||||
return stat, err
|
||||
}
|
||||
if channel != 0 {
|
||||
tx = tx.Where("channel_id = ?", channel)
|
||||
rpmTpmQuery = rpmTpmQuery.Where("channel_id = ?", channel)
|
||||
|
||||
@@ -281,6 +281,9 @@ func migrateDB() error {
|
||||
&CustomOAuthProvider{},
|
||||
&UserOAuthBinding{},
|
||||
&PerfMetric{},
|
||||
&DocumentCategory{},
|
||||
&Document{},
|
||||
&DocumentVersion{},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -330,6 +333,9 @@ func migrateDBFast() error {
|
||||
{&CustomOAuthProvider{}, "CustomOAuthProvider"},
|
||||
{&UserOAuthBinding{}, "UserOAuthBinding"},
|
||||
{&PerfMetric{}, "PerfMetric"},
|
||||
{&DocumentCategory{}, "DocumentCategory"},
|
||||
{&Document{}, "Document"},
|
||||
{&DocumentVersion{}, "DocumentVersion"},
|
||||
}
|
||||
// 动态计算migration数量,确保errChan缓冲区足够大
|
||||
errChan := make(chan error, len(migrations))
|
||||
@@ -397,8 +403,10 @@ func ensureSubscriptionPlanTableSQLite() error {
|
||||
` + "`custom_seconds`" + ` bigint NOT NULL DEFAULT 0,
|
||||
` + "`enabled`" + ` numeric DEFAULT 1,
|
||||
` + "`sort_order`" + ` integer DEFAULT 0,
|
||||
` + "`allow_balance_pay`" + ` numeric DEFAULT 1,
|
||||
` + "`stripe_price_id`" + ` varchar(128) DEFAULT '',
|
||||
` + "`creem_product_id`" + ` varchar(128) DEFAULT '',
|
||||
` + "`waffo_pancake_product_id`" + ` varchar(128) DEFAULT '',
|
||||
` + "`max_purchase_per_user`" + ` integer DEFAULT 0,
|
||||
` + "`upgrade_group`" + ` varchar(64) DEFAULT '',
|
||||
` + "`total_amount`" + ` bigint NOT NULL DEFAULT 0,
|
||||
@@ -430,8 +438,10 @@ PRIMARY KEY (` + "`id`" + `)
|
||||
{Name: "custom_seconds", DDL: "`custom_seconds` bigint NOT NULL DEFAULT 0"},
|
||||
{Name: "enabled", DDL: "`enabled` numeric DEFAULT 1"},
|
||||
{Name: "sort_order", DDL: "`sort_order` integer DEFAULT 0"},
|
||||
{Name: "allow_balance_pay", DDL: "`allow_balance_pay` numeric DEFAULT 1"},
|
||||
{Name: "stripe_price_id", DDL: "`stripe_price_id` varchar(128) DEFAULT ''"},
|
||||
{Name: "creem_product_id", DDL: "`creem_product_id` varchar(128) DEFAULT ''"},
|
||||
{Name: "waffo_pancake_product_id", DDL: "`waffo_pancake_product_id` varchar(128) DEFAULT ''"},
|
||||
{Name: "max_purchase_per_user", DDL: "`max_purchase_per_user` integer DEFAULT 0"},
|
||||
{Name: "upgrade_group", DDL: "`upgrade_group` varchar(64) DEFAULT ''"},
|
||||
{Name: "total_amount", DDL: "`total_amount` bigint NOT NULL DEFAULT 0"},
|
||||
|
||||
@@ -2,6 +2,7 @@ package model
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
|
||||
@@ -135,6 +136,62 @@ func GetBoundChannelsByModelsMap(modelNames []string) (map[string][]BoundChannel
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func normalizeLookupValues(values []string) []string {
|
||||
seen := make(map[string]struct{}, len(values))
|
||||
normalized := make([]string, 0, len(values))
|
||||
for _, value := range values {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[value]; ok {
|
||||
continue
|
||||
}
|
||||
seen[value] = struct{}{}
|
||||
normalized = append(normalized, value)
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
func GetPreferredModelOwnerChannelTypes(modelNames []string, groups []string) (map[string]int, error) {
|
||||
result := make(map[string]int)
|
||||
modelNames = normalizeLookupValues(modelNames)
|
||||
if len(modelNames) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
type row struct {
|
||||
Model string
|
||||
ChannelType int
|
||||
}
|
||||
var rows []row
|
||||
|
||||
query := DB.Table("abilities").
|
||||
Select("abilities.model as model, channels.type as channel_type").
|
||||
Joins("JOIN channels ON abilities.channel_id = channels.id").
|
||||
Where("abilities.model IN ? AND abilities.enabled = ? AND channels.status = ?", modelNames, true, common.ChannelStatusEnabled).
|
||||
Order("COALESCE(abilities.priority, 0) DESC").
|
||||
Order("abilities.weight DESC").
|
||||
Order("abilities.channel_id ASC")
|
||||
|
||||
groups = normalizeLookupValues(groups)
|
||||
if len(groups) > 0 {
|
||||
query = query.Where("abilities."+commonGroupCol+" IN ?", groups)
|
||||
}
|
||||
|
||||
if err := query.Scan(&rows).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, r := range rows {
|
||||
if _, ok := result[r.Model]; ok {
|
||||
continue
|
||||
}
|
||||
result[r.Model] = r.ChannelType
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func SearchModels(keyword string, vendor string, offset int, limit int) ([]*Model, int64, error) {
|
||||
var models []*Model
|
||||
db := DB.Model(&Model{})
|
||||
|
||||
@@ -0,0 +1,141 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func clearPreferredOwnerTables(t *testing.T) {
|
||||
t.Helper()
|
||||
require.NoError(t, DB.Exec("DELETE FROM abilities").Error)
|
||||
require.NoError(t, DB.Exec("DELETE FROM channels").Error)
|
||||
}
|
||||
|
||||
func insertPreferredOwnerCandidate(
|
||||
t *testing.T,
|
||||
channelID int,
|
||||
modelName string,
|
||||
group string,
|
||||
channelType int,
|
||||
priority int64,
|
||||
weight uint,
|
||||
channelStatus int,
|
||||
abilityEnabled bool,
|
||||
) {
|
||||
t.Helper()
|
||||
require.NoError(t, DB.Create(&Channel{
|
||||
Id: channelID,
|
||||
Type: channelType,
|
||||
Key: fmt.Sprintf("key-%d", channelID),
|
||||
Status: channelStatus,
|
||||
Name: fmt.Sprintf("channel-%d", channelID),
|
||||
}).Error)
|
||||
require.NoError(t, DB.Create(&Ability{
|
||||
Group: group,
|
||||
Model: modelName,
|
||||
ChannelId: channelID,
|
||||
Enabled: abilityEnabled,
|
||||
Priority: &priority,
|
||||
Weight: weight,
|
||||
}).Error)
|
||||
}
|
||||
|
||||
func TestGetPreferredModelOwnerChannelTypes(t *testing.T) {
|
||||
const modelName = "gpt-5.4"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(t *testing.T)
|
||||
groups []string
|
||||
expected int
|
||||
found bool
|
||||
}{
|
||||
{
|
||||
name: "openai only",
|
||||
setup: func(t *testing.T) {
|
||||
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeOpenAI, 0, 0, common.ChannelStatusEnabled, true)
|
||||
},
|
||||
groups: []string{"default"},
|
||||
expected: constant.ChannelTypeOpenAI,
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "codex only",
|
||||
setup: func(t *testing.T) {
|
||||
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeCodex, 0, 0, common.ChannelStatusEnabled, true)
|
||||
},
|
||||
groups: []string{"default"},
|
||||
expected: constant.ChannelTypeCodex,
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "priority wins",
|
||||
setup: func(t *testing.T) {
|
||||
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeOpenAI, 1, 100, common.ChannelStatusEnabled, true)
|
||||
insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeCodex, 2, 0, common.ChannelStatusEnabled, true)
|
||||
},
|
||||
groups: []string{"default"},
|
||||
expected: constant.ChannelTypeCodex,
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "weight wins when priority is equal",
|
||||
setup: func(t *testing.T) {
|
||||
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeOpenAI, 1, 10, common.ChannelStatusEnabled, true)
|
||||
insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeCodex, 1, 20, common.ChannelStatusEnabled, true)
|
||||
},
|
||||
groups: []string{"default"},
|
||||
expected: constant.ChannelTypeCodex,
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "channel id stabilizes exact ties",
|
||||
setup: func(t *testing.T) {
|
||||
insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeCodex, 1, 10, common.ChannelStatusEnabled, true)
|
||||
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeOpenAI, 1, 10, common.ChannelStatusEnabled, true)
|
||||
},
|
||||
groups: []string{"default"},
|
||||
expected: constant.ChannelTypeOpenAI,
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "group filter excludes other groups",
|
||||
setup: func(t *testing.T) {
|
||||
insertPreferredOwnerCandidate(t, 1, modelName, "vip", constant.ChannelTypeCodex, 10, 100, common.ChannelStatusEnabled, true)
|
||||
insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeOpenAI, 1, 0, common.ChannelStatusEnabled, true)
|
||||
},
|
||||
groups: []string{"default"},
|
||||
expected: constant.ChannelTypeOpenAI,
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "disabled candidates are ignored",
|
||||
setup: func(t *testing.T) {
|
||||
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeCodex, 10, 100, common.ChannelStatusEnabled, false)
|
||||
insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeOpenAI, 1, 0, common.ChannelStatusManuallyDisabled, true)
|
||||
},
|
||||
groups: []string{"default"},
|
||||
found: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
clearPreferredOwnerTables(t)
|
||||
tt.setup(t)
|
||||
|
||||
owners, err := GetPreferredModelOwnerChannelTypes([]string{modelName}, tt.groups)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, ok := owners[modelName]
|
||||
require.Equal(t, tt.found, ok)
|
||||
if tt.found {
|
||||
require.Equal(t, tt.expected, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+38
-19
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/setting/performance_setting"
|
||||
"github.com/QuantumNous/new-api/setting/ratio_setting"
|
||||
"github.com/QuantumNous/new-api/setting/system_setting"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Option struct {
|
||||
@@ -106,18 +107,13 @@ func InitOptionMap() {
|
||||
common.OptionMap["WaffoUnitPrice"] = strconv.FormatFloat(setting.WaffoUnitPrice, 'f', -1, 64)
|
||||
common.OptionMap["WaffoMinTopUp"] = strconv.Itoa(setting.WaffoMinTopUp)
|
||||
common.OptionMap["WaffoPayMethods"] = setting.WaffoPayMethods2JsonString()
|
||||
common.OptionMap["WaffoPancakeEnabled"] = strconv.FormatBool(setting.WaffoPancakeEnabled)
|
||||
common.OptionMap["WaffoPancakeSandbox"] = strconv.FormatBool(setting.WaffoPancakeSandbox)
|
||||
common.OptionMap["WaffoPancakeMerchantID"] = setting.WaffoPancakeMerchantID
|
||||
common.OptionMap["WaffoPancakePrivateKey"] = setting.WaffoPancakePrivateKey
|
||||
common.OptionMap["WaffoPancakeWebhookPublicKey"] = setting.WaffoPancakeWebhookPublicKey
|
||||
common.OptionMap["WaffoPancakeWebhookTestKey"] = setting.WaffoPancakeWebhookTestKey
|
||||
common.OptionMap["WaffoPancakeStoreID"] = setting.WaffoPancakeStoreID
|
||||
common.OptionMap["WaffoPancakeProductID"] = setting.WaffoPancakeProductID
|
||||
common.OptionMap["WaffoPancakeReturnURL"] = setting.WaffoPancakeReturnURL
|
||||
common.OptionMap["WaffoPancakeCurrency"] = setting.WaffoPancakeCurrency
|
||||
common.OptionMap["WaffoPancakeUnitPrice"] = strconv.FormatFloat(setting.WaffoPancakeUnitPrice, 'f', -1, 64)
|
||||
common.OptionMap["WaffoPancakeMinTopUp"] = strconv.Itoa(setting.WaffoPancakeMinTopUp)
|
||||
common.OptionMap["WaffoPancakeStoreID"] = setting.WaffoPancakeStoreID
|
||||
common.OptionMap["WaffoPancakeProductID"] = setting.WaffoPancakeProductID
|
||||
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
|
||||
common.OptionMap["Chats"] = setting.Chats2JsonString()
|
||||
common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString()
|
||||
@@ -222,6 +218,39 @@ func UpdateOption(key string, value string) error {
|
||||
return updateOptionMap(key, value)
|
||||
}
|
||||
|
||||
// UpdateOptionsBulk persists multiple key/value pairs in a single database
|
||||
// transaction, then dispatches them through updateOptionMap in one pass. If
|
||||
// any DB write fails the whole transaction rolls back and no in-memory state
|
||||
// is touched — safe for callers that must commit a set of related options
|
||||
// atomically (e.g. payment gateway binding).
|
||||
func UpdateOptionsBulk(values map[string]string) error {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
err := DB.Transaction(func(tx *gorm.DB) error {
|
||||
for k, v := range values {
|
||||
option := Option{Key: k}
|
||||
if err := tx.FirstOrCreate(&option, Option{Key: k}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
option.Value = v
|
||||
if err := tx.Save(&option).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for k, v := range values {
|
||||
if err := updateOptionMap(k, v); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateOptionMap(key string, value string) (err error) {
|
||||
common.OptionMapRWMutex.Lock()
|
||||
defer common.OptionMapRWMutex.Unlock()
|
||||
@@ -419,26 +448,16 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
setting.WaffoUnitPrice, _ = strconv.ParseFloat(value, 64)
|
||||
case "WaffoMinTopUp":
|
||||
setting.WaffoMinTopUp, _ = strconv.Atoi(value)
|
||||
case "WaffoPancakeEnabled":
|
||||
setting.WaffoPancakeEnabled = value == "true"
|
||||
case "WaffoPancakeSandbox":
|
||||
setting.WaffoPancakeSandbox = value == "true"
|
||||
case "WaffoPancakeMerchantID":
|
||||
setting.WaffoPancakeMerchantID = value
|
||||
case "WaffoPancakePrivateKey":
|
||||
setting.WaffoPancakePrivateKey = value
|
||||
case "WaffoPancakeWebhookPublicKey":
|
||||
setting.WaffoPancakeWebhookPublicKey = value
|
||||
case "WaffoPancakeWebhookTestKey":
|
||||
setting.WaffoPancakeWebhookTestKey = value
|
||||
case "WaffoPancakeReturnURL":
|
||||
setting.WaffoPancakeReturnURL = value
|
||||
case "WaffoPancakeStoreID":
|
||||
setting.WaffoPancakeStoreID = value
|
||||
case "WaffoPancakeProductID":
|
||||
setting.WaffoPancakeProductID = value
|
||||
case "WaffoPancakeReturnURL":
|
||||
setting.WaffoPancakeReturnURL = value
|
||||
case "WaffoPancakeCurrency":
|
||||
setting.WaffoPancakeCurrency = value
|
||||
case "WaffoPancakeUnitPrice":
|
||||
setting.WaffoPancakeUnitPrice, _ = strconv.ParseFloat(value, 64)
|
||||
case "WaffoPancakeMinTopUp":
|
||||
|
||||
+117
-2
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/pkg/cachex"
|
||||
"github.com/samber/hot"
|
||||
"github.com/shopspring/decimal"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@@ -159,8 +160,11 @@ type SubscriptionPlan struct {
|
||||
Enabled bool `json:"enabled" gorm:"default:true"`
|
||||
SortOrder int `json:"sort_order" gorm:"type:int;default:0"`
|
||||
|
||||
StripePriceId string `json:"stripe_price_id" gorm:"type:varchar(128);default:''"`
|
||||
CreemProductId string `json:"creem_product_id" gorm:"type:varchar(128);default:''"`
|
||||
AllowBalancePay *bool `json:"allow_balance_pay" gorm:"default:true"`
|
||||
|
||||
StripePriceId string `json:"stripe_price_id" gorm:"type:varchar(128);default:''"`
|
||||
CreemProductId string `json:"creem_product_id" gorm:"type:varchar(128);default:''"`
|
||||
WaffoPancakeProductId string `json:"waffo_pancake_product_id" gorm:"type:varchar(128);default:''"`
|
||||
|
||||
// Max purchases per user (0 = unlimited)
|
||||
MaxPurchasePerUser int `json:"max_purchase_per_user" gorm:"type:int;default:0"`
|
||||
@@ -191,6 +195,12 @@ func (p *SubscriptionPlan) BeforeUpdate(tx *gorm.DB) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *SubscriptionPlan) NormalizeDefaults() {
|
||||
if p.AllowBalancePay == nil {
|
||||
p.AllowBalancePay = common.GetPointer(true)
|
||||
}
|
||||
}
|
||||
|
||||
// Subscription order (payment -> webhook -> create UserSubscription)
|
||||
type SubscriptionOrder struct {
|
||||
Id int `json:"id"`
|
||||
@@ -358,6 +368,7 @@ func getSubscriptionPlanByIdTx(tx *gorm.DB, id int) (*SubscriptionPlan, error) {
|
||||
key := subscriptionPlanCacheKey(id)
|
||||
if key != "" {
|
||||
if cached, found, err := getSubscriptionPlanCache().Get(key); err == nil && found {
|
||||
cached.NormalizeDefaults()
|
||||
return &cached, nil
|
||||
}
|
||||
}
|
||||
@@ -369,6 +380,7 @@ func getSubscriptionPlanByIdTx(tx *gorm.DB, id int) (*SubscriptionPlan, error) {
|
||||
if err := query.Where("id = ?", id).First(&plan).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
plan.NormalizeDefaults()
|
||||
_ = getSubscriptionPlanCache().SetWithTTL(key, plan, subscriptionPlanCacheTTL())
|
||||
return &plan, nil
|
||||
}
|
||||
@@ -664,6 +676,109 @@ func AdminBindSubscription(userId int, planId int, sourceNote string) (string, e
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func calcSubscriptionBalanceQuota(priceAmount float64) (int, error) {
|
||||
if priceAmount <= 0 {
|
||||
return 0, nil
|
||||
}
|
||||
if common.QuotaPerUnit <= 0 {
|
||||
return 0, errors.New("额度单位配置错误")
|
||||
}
|
||||
quota := decimal.NewFromFloat(priceAmount).
|
||||
Mul(decimal.NewFromFloat(common.QuotaPerUnit)).
|
||||
Ceil().
|
||||
IntPart()
|
||||
return int(quota), nil
|
||||
}
|
||||
|
||||
// PurchaseSubscriptionWithBalance creates a subscription by deducting the user's wallet quota.
|
||||
func PurchaseSubscriptionWithBalance(userId int, planId int) error {
|
||||
if userId <= 0 || planId <= 0 {
|
||||
return errors.New("invalid userId or planId")
|
||||
}
|
||||
|
||||
var logPlanTitle string
|
||||
var logMoney float64
|
||||
var chargedQuota int
|
||||
var upgradeGroup string
|
||||
err := DB.Transaction(func(tx *gorm.DB) error {
|
||||
plan, err := getSubscriptionPlanByIdTx(tx, planId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !plan.Enabled {
|
||||
return errors.New("套餐未启用")
|
||||
}
|
||||
if plan.PriceAmount < 0 {
|
||||
return errors.New("套餐价格不能为负数")
|
||||
}
|
||||
if plan.AllowBalancePay != nil && !*plan.AllowBalancePay {
|
||||
return errors.New("该套餐不允许使用余额兑换")
|
||||
}
|
||||
|
||||
requiredQuota, err := calcSubscriptionBalanceQuota(plan.PriceAmount)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var user User
|
||||
if err := tx.Set("gorm:query_option", "FOR UPDATE").Where("id = ?", userId).First(&user).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if requiredQuota > 0 && user.Quota < requiredQuota {
|
||||
return errors.New("余额不足")
|
||||
}
|
||||
if requiredQuota > 0 {
|
||||
if err := tx.Model(&User{}).Where("id = ?", userId).
|
||||
Update("quota", gorm.Expr("quota - ?", requiredQuota)).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := CreateUserSubscriptionFromPlanTx(tx, userId, plan, PaymentMethodBalance); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
now := common.GetTimestamp()
|
||||
tradeNo := fmt.Sprintf("SUBBALUSR%dNO%s%d", userId, common.GetRandomString(6), time.Now().UnixNano())
|
||||
order := &SubscriptionOrder{
|
||||
UserId: userId,
|
||||
PlanId: plan.Id,
|
||||
Money: plan.PriceAmount,
|
||||
TradeNo: tradeNo,
|
||||
PaymentMethod: PaymentMethodBalance,
|
||||
PaymentProvider: PaymentProviderBalance,
|
||||
Status: common.TopUpStatusSuccess,
|
||||
CreateTime: now,
|
||||
CompleteTime: now,
|
||||
ProviderPayload: fmt.Sprintf("charged_quota=%d", requiredQuota),
|
||||
}
|
||||
if err := tx.Create(order).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logPlanTitle = plan.Title
|
||||
logMoney = plan.PriceAmount
|
||||
chargedQuota = requiredQuota
|
||||
upgradeGroup = strings.TrimSpace(plan.UpgradeGroup)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if chargedQuota > 0 {
|
||||
if err := cacheDecrUserQuota(userId, int64(chargedQuota)); err != nil {
|
||||
common.SysLog("failed to decrease user quota cache after subscription balance purchase: " + err.Error())
|
||||
}
|
||||
}
|
||||
if upgradeGroup != "" {
|
||||
_ = UpdateUserGroupCache(userId, upgradeGroup)
|
||||
}
|
||||
msg := fmt.Sprintf("使用余额购买订阅成功,套餐: %s,支付金额: %.2f,扣除额度: %d", logPlanTitle, logMoney, chargedQuota)
|
||||
RecordLog(userId, LogTypeTopup, msg)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAllActiveUserSubscriptions returns all active subscriptions for a user.
|
||||
func GetAllActiveUserSubscriptions(userId int) ([]SubscriptionSummary, error) {
|
||||
if userId <= 0 {
|
||||
|
||||
@@ -40,6 +40,7 @@ func TestMain(m *testing.M) {
|
||||
&Token{},
|
||||
&Log{},
|
||||
&Channel{},
|
||||
&Ability{},
|
||||
&TopUp{},
|
||||
&SubscriptionPlan{},
|
||||
&SubscriptionOrder{},
|
||||
@@ -60,6 +61,7 @@ func truncateTables(t *testing.T) {
|
||||
DB.Exec("DELETE FROM tokens")
|
||||
DB.Exec("DELETE FROM logs")
|
||||
DB.Exec("DELETE FROM channels")
|
||||
DB.Exec("DELETE FROM abilities")
|
||||
DB.Exec("DELETE FROM top_ups")
|
||||
DB.Exec("DELETE FROM subscription_orders")
|
||||
DB.Exec("DELETE FROM subscription_plans")
|
||||
|
||||
@@ -29,6 +29,7 @@ const (
|
||||
PaymentMethodCreem = "creem"
|
||||
PaymentMethodWaffo = "waffo"
|
||||
PaymentMethodWaffoPancake = "waffo_pancake"
|
||||
PaymentMethodBalance = "balance"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -37,6 +38,7 @@ const (
|
||||
PaymentProviderCreem = "creem"
|
||||
PaymentProviderWaffo = "waffo"
|
||||
PaymentProviderWaffoPancake = "waffo_pancake"
|
||||
PaymentProviderBalance = "balance"
|
||||
)
|
||||
|
||||
var (
|
||||
|
||||
+33
-20
@@ -11,7 +11,6 @@ import (
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
"gorm.io/gorm"
|
||||
@@ -225,7 +224,7 @@ func GetAllUsers(pageInfo *common.PageInfo) (users []*User, total int64, err err
|
||||
return users, total, nil
|
||||
}
|
||||
|
||||
func SearchUsers(keyword string, group string, startIdx int, num int) ([]*User, int64, error) {
|
||||
func SearchUsers(keyword string, group string, role *int, status *int, startIdx int, num int) ([]*User, int64, error) {
|
||||
var users []*User
|
||||
var total int64
|
||||
var err error
|
||||
@@ -246,28 +245,25 @@ func SearchUsers(keyword string, group string, startIdx int, num int) ([]*User,
|
||||
|
||||
// 构建搜索条件
|
||||
likeCondition := "username LIKE ? OR email LIKE ? OR display_name LIKE ?"
|
||||
likeArgs := []interface{}{"%" + keyword + "%", "%" + keyword + "%", "%" + keyword + "%"}
|
||||
|
||||
// 尝试将关键字转换为整数ID
|
||||
keywordInt, err := strconv.Atoi(keyword)
|
||||
if err == nil {
|
||||
// 如果是数字,同时搜索ID和其他字段
|
||||
likeCondition = "id = ? OR " + likeCondition
|
||||
if group != "" {
|
||||
query = query.Where("("+likeCondition+") AND "+commonGroupCol+" = ?",
|
||||
keywordInt, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
|
||||
} else {
|
||||
query = query.Where(likeCondition,
|
||||
keywordInt, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
|
||||
}
|
||||
} else {
|
||||
// 非数字关键字,只搜索字符串字段
|
||||
if group != "" {
|
||||
query = query.Where("("+likeCondition+") AND "+commonGroupCol+" = ?",
|
||||
"%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
|
||||
} else {
|
||||
query = query.Where(likeCondition,
|
||||
"%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
|
||||
}
|
||||
likeArgs = append([]interface{}{keywordInt}, likeArgs...)
|
||||
}
|
||||
|
||||
query = query.Where("("+likeCondition+")", likeArgs...)
|
||||
if group != "" {
|
||||
query = query.Where(commonGroupCol+" = ?", group)
|
||||
}
|
||||
if role != nil {
|
||||
query = query.Where("role = ?", *role)
|
||||
}
|
||||
if status != nil {
|
||||
query = query.Where("status = ?", *status)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
@@ -421,7 +417,7 @@ func (user *User) Insert(inviterId int) error {
|
||||
if common.QuotaForNewUser > 0 {
|
||||
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", logger.LogQuota(common.QuotaForNewUser)))
|
||||
}
|
||||
if inviterId != 0 && operation_setting.IsPaymentComplianceConfirmed() {
|
||||
if inviterId != 0 {
|
||||
if common.QuotaForInvitee > 0 {
|
||||
_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee, true)
|
||||
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", logger.LogQuota(common.QuotaForInvitee)))
|
||||
@@ -482,7 +478,7 @@ func (user *User) FinalizeOAuthUserCreation(inviterId int) {
|
||||
if common.QuotaForNewUser > 0 {
|
||||
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", logger.LogQuota(common.QuotaForNewUser)))
|
||||
}
|
||||
if inviterId != 0 && operation_setting.IsPaymentComplianceConfirmed() {
|
||||
if inviterId != 0 {
|
||||
if common.QuotaForInvitee > 0 {
|
||||
_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee, true)
|
||||
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", logger.LogQuota(common.QuotaForInvitee)))
|
||||
@@ -987,6 +983,23 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
|
||||
//}
|
||||
}
|
||||
|
||||
func updateUserQuotaUsedQuotaAndRequestCount(id int, quota int, usedQuota int, requestCount int) {
|
||||
if quota == 0 && usedQuota == 0 && requestCount == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
err := DB.Model(&User{}).Where("id = ?", id).Updates(
|
||||
map[string]interface{}{
|
||||
"quota": gorm.Expr("quota + ?", quota),
|
||||
"used_quota": gorm.Expr("used_quota + ?", usedQuota),
|
||||
"request_count": gorm.Expr("request_count + ?", requestCount),
|
||||
},
|
||||
).Error
|
||||
if err != nil {
|
||||
common.SysLog("failed to batch update user quota, used quota and request count: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func updateUserUsedQuota(id int, quota int) {
|
||||
err := DB.Model(&User{}).Where("id = ?", id).Updates(
|
||||
map[string]interface{}{
|
||||
|
||||
+26
-11
@@ -67,33 +67,48 @@ func batchUpdate() {
|
||||
}
|
||||
|
||||
common.SysLog("batch update started")
|
||||
stores := make([]map[int]int, BatchUpdateTypeCount)
|
||||
for i := 0; i < BatchUpdateTypeCount; i++ {
|
||||
batchUpdateLocks[i].Lock()
|
||||
store := batchUpdateStores[i]
|
||||
stores[i] = batchUpdateStores[i]
|
||||
batchUpdateStores[i] = make(map[int]int)
|
||||
batchUpdateLocks[i].Unlock()
|
||||
// TODO: maybe we can combine updates with same key?
|
||||
}
|
||||
|
||||
for i, store := range stores {
|
||||
if i == BatchUpdateTypeUserQuota || i == BatchUpdateTypeUsedQuota || i == BatchUpdateTypeRequestCount {
|
||||
continue
|
||||
}
|
||||
for key, value := range store {
|
||||
switch i {
|
||||
case BatchUpdateTypeUserQuota:
|
||||
err := increaseUserQuota(key, value)
|
||||
if err != nil {
|
||||
common.SysLog("failed to batch update user quota: " + err.Error())
|
||||
}
|
||||
case BatchUpdateTypeTokenQuota:
|
||||
err := increaseTokenQuota(key, value)
|
||||
if err != nil {
|
||||
common.SysLog("failed to batch update token quota: " + err.Error())
|
||||
}
|
||||
case BatchUpdateTypeUsedQuota:
|
||||
updateUserUsedQuota(key, value)
|
||||
case BatchUpdateTypeRequestCount:
|
||||
updateUserRequestCount(key, value)
|
||||
case BatchUpdateTypeChannelUsedQuota:
|
||||
updateChannelUsedQuota(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
userQuotaStore := stores[BatchUpdateTypeUserQuota]
|
||||
usedQuotaStore := stores[BatchUpdateTypeUsedQuota]
|
||||
requestCountStore := stores[BatchUpdateTypeRequestCount]
|
||||
|
||||
userIDs := make(map[int]struct{}, len(userQuotaStore)+len(usedQuotaStore)+len(requestCountStore))
|
||||
for key := range userQuotaStore {
|
||||
userIDs[key] = struct{}{}
|
||||
}
|
||||
for key := range usedQuotaStore {
|
||||
userIDs[key] = struct{}{}
|
||||
}
|
||||
for key := range requestCountStore {
|
||||
userIDs[key] = struct{}{}
|
||||
}
|
||||
for key := range userIDs {
|
||||
updateUserQuotaUsedQuotaAndRequestCount(key, userQuotaStore[key], usedQuotaStore[key], requestCountStore[key])
|
||||
}
|
||||
common.SysLog("batch update finished")
|
||||
}
|
||||
|
||||
|
||||
@@ -25,6 +25,23 @@ import (
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// applyUpstreamContentLength populates req.ContentLength when the upstream
|
||||
// body is wrapped in a BodyStorage (see relay/common/outbound_body.go).
|
||||
//
|
||||
// net/http.NewRequest only auto-detects ContentLength for *bytes.Reader,
|
||||
// *bytes.Buffer and *strings.Reader. When the body is a type-erased io.Reader
|
||||
// (which is the case for ReaderOnly(BodyStorage)), the Content-Length header
|
||||
// would otherwise be omitted, forcing chunked transfer encoding and breaking
|
||||
// some upstreams that require an explicit Content-Length.
|
||||
func applyUpstreamContentLength(req *http.Request, info *common.RelayInfo) {
|
||||
if info == nil {
|
||||
return
|
||||
}
|
||||
if info.UpstreamRequestBodySize > 0 && req.ContentLength <= 0 {
|
||||
req.ContentLength = info.UpstreamRequestBodySize
|
||||
}
|
||||
}
|
||||
|
||||
func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Header) {
|
||||
if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
|
||||
// multipart/form-data
|
||||
@@ -297,6 +314,7 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new request failed: %w", err)
|
||||
}
|
||||
applyUpstreamContentLength(req, info)
|
||||
headers := req.Header
|
||||
err = a.SetupRequestHeader(c, &headers, info)
|
||||
if err != nil {
|
||||
@@ -326,6 +344,7 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new request failed: %w", err)
|
||||
}
|
||||
applyUpstreamContentLength(req, info)
|
||||
// set form data
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
headers := req.Header
|
||||
@@ -522,6 +541,7 @@ func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.RelayInfo, req
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new request failed: %w", err)
|
||||
}
|
||||
applyUpstreamContentLength(req, info)
|
||||
req.GetBody = func() (io.ReadCloser, error) {
|
||||
return io.NopCloser(requestBody), nil
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ var awsModelIDMap = map[string]string{
|
||||
"claude-opus-4-5-20251101": "anthropic.claude-opus-4-5-20251101-v1:0",
|
||||
"claude-opus-4-6": "anthropic.claude-opus-4-6-v1",
|
||||
"claude-opus-4-7": "anthropic.claude-opus-4-7",
|
||||
"claude-opus-4-8": "anthropic.claude-opus-4-8",
|
||||
// Nova models
|
||||
"nova-micro-v1:0": "amazon.nova-micro-v1:0",
|
||||
"nova-lite-v1:0": "amazon.nova-lite-v1:0",
|
||||
@@ -97,6 +98,11 @@ var awsModelCanCrossRegionMap = map[string]map[string]bool{
|
||||
"ap": true,
|
||||
"eu": true,
|
||||
},
|
||||
"anthropic.claude-opus-4-8": {
|
||||
"us": true,
|
||||
"ap": true,
|
||||
"eu": true,
|
||||
},
|
||||
"anthropic.claude-haiku-4-5-20251001-v1:0": {
|
||||
"us": true,
|
||||
"ap": true,
|
||||
|
||||
@@ -33,6 +33,13 @@ var ModelList = []string{
|
||||
"claude-opus-4-7-medium",
|
||||
"claude-opus-4-7-low",
|
||||
"claude-opus-4-7-thinking",
|
||||
"claude-opus-4-8",
|
||||
"claude-opus-4-8-max",
|
||||
"claude-opus-4-8-xhigh",
|
||||
"claude-opus-4-8-high",
|
||||
"claude-opus-4-8-medium",
|
||||
"claude-opus-4-8-low",
|
||||
"claude-opus-4-8-thinking",
|
||||
}
|
||||
|
||||
var ChannelName = "claude"
|
||||
|
||||
@@ -154,14 +154,17 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
|
||||
}
|
||||
|
||||
if baseModel, effortLevel, ok := reasoning.TrimEffortSuffix(textRequest.Model); ok && effortLevel != "" &&
|
||||
(strings.HasPrefix(textRequest.Model, "claude-opus-4-6") || strings.HasPrefix(textRequest.Model, "claude-opus-4-7")) {
|
||||
(strings.HasPrefix(textRequest.Model, "claude-opus-4-6") ||
|
||||
strings.HasPrefix(textRequest.Model, "claude-opus-4-7") ||
|
||||
strings.HasPrefix(textRequest.Model, "claude-opus-4-8")) {
|
||||
claudeRequest.Model = baseModel
|
||||
claudeRequest.Thinking = &dto.Thinking{
|
||||
Type: "adaptive",
|
||||
}
|
||||
claudeRequest.OutputConfig = json.RawMessage(fmt.Sprintf(`{"effort":"%s"}`, effortLevel))
|
||||
if strings.HasPrefix(baseModel, "claude-opus-4-7") {
|
||||
// Opus 4.7 rejects non-default temperature/top_p/top_k with 400
|
||||
if strings.HasPrefix(baseModel, "claude-opus-4-7") ||
|
||||
strings.HasPrefix(baseModel, "claude-opus-4-8") {
|
||||
// Opus 4.7/4.8 reject non-default temperature/top_p/top_k with 400
|
||||
// and defaults display to "omitted"; restore the 4.6 visible summary.
|
||||
claudeRequest.Thinking.Display = "summarized"
|
||||
claudeRequest.Temperature = nil
|
||||
@@ -175,8 +178,9 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
|
||||
strings.HasSuffix(textRequest.Model, "-thinking") {
|
||||
|
||||
trimmedModel := strings.TrimSuffix(textRequest.Model, "-thinking")
|
||||
if strings.HasPrefix(trimmedModel, "claude-opus-4-7") {
|
||||
// Opus 4.7 rejects thinking.type="enabled"; use adaptive at high effort.
|
||||
if strings.HasPrefix(trimmedModel, "claude-opus-4-7") ||
|
||||
strings.HasPrefix(trimmedModel, "claude-opus-4-8") {
|
||||
// Opus 4.7/4.8 reject thinking.type="enabled"; use adaptive at high effort.
|
||||
claudeRequest.Thinking = &dto.Thinking{Type: "adaptive", Display: "summarized"}
|
||||
claudeRequest.OutputConfig = json.RawMessage(`{"effort":"high"}`)
|
||||
claudeRequest.Temperature = nil
|
||||
@@ -442,10 +446,7 @@ func StreamResponseClaude2OpenAI(claudeResponse *dto.ClaudeResponse) *dto.ChatCo
|
||||
tools := make([]dto.ToolCallResponse, 0)
|
||||
fcIdx := 0
|
||||
if claudeResponse.Index != nil {
|
||||
fcIdx = *claudeResponse.Index - 1
|
||||
if fcIdx < 0 {
|
||||
fcIdx = 0
|
||||
}
|
||||
fcIdx = *claudeResponse.Index
|
||||
}
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
if claudeResponse.Type == "message_start" {
|
||||
|
||||
@@ -9,6 +9,10 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func commonPointer[T any](value T) *T {
|
||||
return &value
|
||||
}
|
||||
|
||||
func TestFormatClaudeResponseInfo_MessageStart(t *testing.T) {
|
||||
claudeInfo := &ClaudeResponseInfo{
|
||||
Usage: &dto.Usage{},
|
||||
@@ -310,6 +314,58 @@ func TestRequestOpenAI2ClaudeMessage_IgnoresUnsupportedFileContent(t *testing.T)
|
||||
require.Equal(t, "see attachment", *content[0].Text)
|
||||
}
|
||||
|
||||
func TestRequestOpenAI2ClaudeMessage_ClaudeOpus48HighUsesAdaptiveThinking(t *testing.T) {
|
||||
request := dto.GeneralOpenAIRequest{
|
||||
Model: "claude-opus-4-8-high",
|
||||
Temperature: commonPointer(0.7),
|
||||
TopP: commonPointer(0.9),
|
||||
TopK: commonPointer(40),
|
||||
Messages: []dto.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "hello",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
claudeRequest, err := RequestOpenAI2ClaudeMessage(nil, request)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "claude-opus-4-8", claudeRequest.Model)
|
||||
require.NotNil(t, claudeRequest.Thinking)
|
||||
require.Equal(t, "adaptive", claudeRequest.Thinking.Type)
|
||||
require.Equal(t, "summarized", claudeRequest.Thinking.Display)
|
||||
require.JSONEq(t, `{"effort":"high"}`, string(claudeRequest.OutputConfig))
|
||||
require.Nil(t, claudeRequest.Temperature)
|
||||
require.Nil(t, claudeRequest.TopP)
|
||||
require.Nil(t, claudeRequest.TopK)
|
||||
}
|
||||
|
||||
func TestRequestOpenAI2ClaudeMessage_ClaudeOpus48ThinkingUsesAdaptiveHighEffort(t *testing.T) {
|
||||
request := dto.GeneralOpenAIRequest{
|
||||
Model: "claude-opus-4-8-thinking",
|
||||
Temperature: commonPointer(0.7),
|
||||
TopP: commonPointer(0.9),
|
||||
TopK: commonPointer(40),
|
||||
Messages: []dto.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "hello",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
claudeRequest, err := RequestOpenAI2ClaudeMessage(nil, request)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "claude-opus-4-8", claudeRequest.Model)
|
||||
require.NotNil(t, claudeRequest.Thinking)
|
||||
require.Equal(t, "adaptive", claudeRequest.Thinking.Type)
|
||||
require.Equal(t, "summarized", claudeRequest.Thinking.Display)
|
||||
require.JSONEq(t, `{"effort":"high"}`, string(claudeRequest.OutputConfig))
|
||||
require.Nil(t, claudeRequest.Temperature)
|
||||
require.Nil(t, claudeRequest.TopP)
|
||||
require.Nil(t, claudeRequest.TopK)
|
||||
}
|
||||
|
||||
func TestRequestOpenAI2ClaudeMessage_SupportsPDFFileContent(t *testing.T) {
|
||||
request := dto.GeneralOpenAIRequest{
|
||||
Model: "claude-3-5-sonnet",
|
||||
|
||||
@@ -30,7 +30,7 @@ func convertCf2CompletionsRequest(textRequest dto.GeneralOpenAIRequest) *CfReque
|
||||
}
|
||||
|
||||
func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner := helper.NewStreamScanner(resp.Body)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
|
||||
helper.SetEventStreamHeaders(c)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package cohere
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -86,7 +85,7 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
||||
createdTime := common.GetTimestamp()
|
||||
usage := &dto.Usage{}
|
||||
responseText := ""
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner := helper.NewStreamScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
@@ -106,6 +105,9 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
||||
data := scanner.Text()
|
||||
dataChan <- data
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
common.SysLog("error reading stream: " + err.Error())
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
helper.SetEventStreamHeaders(c)
|
||||
|
||||
@@ -98,7 +98,7 @@ func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res
|
||||
}
|
||||
|
||||
func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner := helper.NewStreamScanner(resp.Body)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
helper.SetEventStreamHeaders(c)
|
||||
id := helper.GetResponseID(c)
|
||||
|
||||
@@ -159,9 +159,14 @@ func requestOpenAI2Dify(c *gin.Context, info *relaycommon.RelayInfo, request dto
|
||||
media := mediaContent.GetImageMedia()
|
||||
var file *DifyFile
|
||||
if media.IsRemoteImage() {
|
||||
file.Type = media.MimeType
|
||||
file.TransferMode = "remote_url"
|
||||
file.URL = media.Url
|
||||
// 修复 #2083: 远程图片分支此前未初始化 file,
|
||||
// 导致 file.Type = ... 触发 nil pointer dereference
|
||||
// 而 panic(500: "invalid memory address or nil pointer dereference")。
|
||||
file = &DifyFile{
|
||||
Type: media.MimeType,
|
||||
TransferMode: "remote_url",
|
||||
URL: media.Url,
|
||||
}
|
||||
} else {
|
||||
file = uploadDifyFile(c, info, difyReq.User, mediaContent)
|
||||
}
|
||||
|
||||
@@ -1079,17 +1079,47 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse)
|
||||
FinishReason: constant.FinishReasonStop,
|
||||
}
|
||||
if len(candidate.Content.Parts) > 0 {
|
||||
var texts []string
|
||||
// 使用 strings.Builder 直接累积最终 content,避免:
|
||||
// 1) 每张 inline image 生成一次中间 "" 字符串
|
||||
// 2) 末尾 strings.Join 再分配一份等大缓冲
|
||||
// Gemini 图片返回时 InlineData.Data 可能是数 MB 的 base64,
|
||||
// 上述两份临时分配在高并发下会显著放大堆驻留。
|
||||
var content strings.Builder
|
||||
var inlineGrow int
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.InlineData != nil {
|
||||
inlineGrow += len(part.InlineData.MimeType) + len(part.InlineData.Data) + 32
|
||||
}
|
||||
}
|
||||
if inlineGrow > 0 {
|
||||
content.Grow(inlineGrow)
|
||||
}
|
||||
appended := 0
|
||||
writeSep := func() {
|
||||
if appended > 0 {
|
||||
content.WriteByte('\n')
|
||||
}
|
||||
appended++
|
||||
}
|
||||
var toolCalls []dto.ToolCallResponse
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.InlineData != nil {
|
||||
// 媒体内容
|
||||
if strings.HasPrefix(part.InlineData.MimeType, "image") {
|
||||
imgText := ""
|
||||
texts = append(texts, imgText)
|
||||
writeSep()
|
||||
content.WriteString("
|
||||
content.WriteString(part.InlineData.MimeType)
|
||||
content.WriteString(";base64,")
|
||||
content.WriteString(part.InlineData.Data)
|
||||
content.WriteByte(')')
|
||||
} else {
|
||||
// 其他媒体类型,直接显示链接
|
||||
texts = append(texts, fmt.Sprintf("[media](data:%s;base64,%s)", part.InlineData.MimeType, part.InlineData.Data))
|
||||
writeSep()
|
||||
content.WriteString("[media](data:")
|
||||
content.WriteString(part.InlineData.MimeType)
|
||||
content.WriteString(";base64,")
|
||||
content.WriteString(part.InlineData.Data)
|
||||
content.WriteByte(')')
|
||||
}
|
||||
} else if part.FunctionCall != nil {
|
||||
choice.FinishReason = constant.FinishReasonToolCalls
|
||||
@@ -1100,13 +1130,22 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse)
|
||||
choice.Message.ReasoningContent = &part.Text
|
||||
} else {
|
||||
if part.ExecutableCode != nil {
|
||||
texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```")
|
||||
writeSep()
|
||||
content.WriteString("```")
|
||||
content.WriteString(part.ExecutableCode.Language)
|
||||
content.WriteByte('\n')
|
||||
content.WriteString(part.ExecutableCode.Code)
|
||||
content.WriteString("\n```")
|
||||
} else if part.CodeExecutionResult != nil {
|
||||
texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```")
|
||||
writeSep()
|
||||
content.WriteString("```output\n")
|
||||
content.WriteString(part.CodeExecutionResult.Output)
|
||||
content.WriteString("\n```")
|
||||
} else {
|
||||
// 过滤掉空行
|
||||
if part.Text != "\n" {
|
||||
texts = append(texts, part.Text)
|
||||
writeSep()
|
||||
content.WriteString(part.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1115,7 +1154,7 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse)
|
||||
choice.Message.SetToolCalls(toolCalls)
|
||||
isToolCall = true
|
||||
}
|
||||
choice.Message.SetStringContent(strings.Join(texts, "\n"))
|
||||
choice.Message.SetStringContent(content.String())
|
||||
|
||||
}
|
||||
if candidate.FinishReason != nil {
|
||||
@@ -1169,7 +1208,25 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*d
|
||||
//Role: "assistant",
|
||||
},
|
||||
}
|
||||
var texts []string
|
||||
// 使用 strings.Builder 直接累积 delta content,避免每张 image / 每个
|
||||
// 文本片段都先 `+` 拼出一份临时 string,再 strings.Join 再拷贝一遍。
|
||||
var content strings.Builder
|
||||
var inlineGrow int
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.InlineData != nil {
|
||||
inlineGrow += len(part.InlineData.MimeType) + len(part.InlineData.Data) + 32
|
||||
}
|
||||
}
|
||||
if inlineGrow > 0 {
|
||||
content.Grow(inlineGrow)
|
||||
}
|
||||
appended := 0
|
||||
writeSep := func() {
|
||||
if appended > 0 {
|
||||
content.WriteByte('\n')
|
||||
}
|
||||
appended++
|
||||
}
|
||||
isTools := false
|
||||
isThought := false
|
||||
if candidate.FinishReason != nil {
|
||||
@@ -1207,8 +1264,12 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*d
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.InlineData != nil {
|
||||
if strings.HasPrefix(part.InlineData.MimeType, "image") {
|
||||
imgText := ""
|
||||
texts = append(texts, imgText)
|
||||
writeSep()
|
||||
content.WriteString("
|
||||
content.WriteString(part.InlineData.MimeType)
|
||||
content.WriteString(";base64,")
|
||||
content.WriteString(part.InlineData.Data)
|
||||
content.WriteByte(')')
|
||||
}
|
||||
} else if part.FunctionCall != nil {
|
||||
isTools = true
|
||||
@@ -1219,23 +1280,33 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*d
|
||||
|
||||
} else if part.Thought {
|
||||
isThought = true
|
||||
texts = append(texts, part.Text)
|
||||
writeSep()
|
||||
content.WriteString(part.Text)
|
||||
} else {
|
||||
if part.ExecutableCode != nil {
|
||||
texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```\n")
|
||||
writeSep()
|
||||
content.WriteString("```")
|
||||
content.WriteString(part.ExecutableCode.Language)
|
||||
content.WriteByte('\n')
|
||||
content.WriteString(part.ExecutableCode.Code)
|
||||
content.WriteString("\n```\n")
|
||||
} else if part.CodeExecutionResult != nil {
|
||||
texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```\n")
|
||||
writeSep()
|
||||
content.WriteString("```output\n")
|
||||
content.WriteString(part.CodeExecutionResult.Output)
|
||||
content.WriteString("\n```\n")
|
||||
} else {
|
||||
if part.Text != "\n" {
|
||||
texts = append(texts, part.Text)
|
||||
writeSep()
|
||||
content.WriteString(part.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if isThought {
|
||||
choice.Delta.SetReasoningContent(strings.Join(texts, "\n"))
|
||||
choice.Delta.SetReasoningContent(content.String())
|
||||
} else {
|
||||
choice.Delta.SetContentString(strings.Join(texts, "\n"))
|
||||
choice.Delta.SetContentString(content.String())
|
||||
}
|
||||
if isTools {
|
||||
choice.FinishReason = &constant.FinishReasonToolCalls
|
||||
@@ -1339,6 +1410,14 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
||||
response.Id = id
|
||||
response.Created = createAt
|
||||
response.Model = info.UpstreamModelName
|
||||
if response.IsToolCall() {
|
||||
finishReason = constant.FinishReasonToolCalls
|
||||
if info.RelayFormat == types.RelayFormatClaude {
|
||||
for choiceIdx := range response.Choices {
|
||||
response.Choices[choiceIdx].FinishReason = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
for choiceIdx := range response.Choices {
|
||||
choiceKey := response.Choices[choiceIdx].Index
|
||||
for toolIdx := range response.Choices[choiceIdx].Delta.ToolCalls {
|
||||
@@ -1399,7 +1478,9 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
||||
logger.LogError(c, err.Error())
|
||||
}
|
||||
if isStop {
|
||||
_ = handleStream(c, info, helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, finishReason))
|
||||
if info.RelayFormat != types.RelayFormatClaude {
|
||||
_ = handleStream(c, info, helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, finishReason))
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
@@ -1409,6 +1490,10 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
||||
}
|
||||
|
||||
response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
|
||||
if info.RelayFormat == types.RelayFormatClaude && info.ClaudeConvertInfo != nil && !info.ClaudeConvertInfo.Done {
|
||||
response = helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, finishReason)
|
||||
response.Usage = usage
|
||||
}
|
||||
handleErr := handleFinalStream(c, info, response)
|
||||
if handleErr != nil {
|
||||
common.SysLog("send final response failed: " + handleErr.Error())
|
||||
|
||||
@@ -5,7 +5,9 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
channelconstant "github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
@@ -79,9 +81,23 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request.Temperature != nil && isTemperatureOneOnlyModel(getUpstreamModelName(info, request.Model)) && *request.Temperature != 1.0 {
|
||||
request.Temperature = common.GetPointer[float64](1.0)
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func getUpstreamModelName(info *relaycommon.RelayInfo, fallback string) string {
|
||||
if info != nil && info.ChannelMeta != nil && info.UpstreamModelName != "" {
|
||||
return info.UpstreamModelName
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func isTemperatureOneOnlyModel(model string) bool {
|
||||
return strings.EqualFold(model, "kimi-k2.6")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
|
||||
// TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
package moonshot
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestConvertOpenAIRequestKimiK26UsesOnlyAllowedTemperature(t *testing.T) {
|
||||
request := &dto.GeneralOpenAIRequest{
|
||||
Model: "kimi-k2.6",
|
||||
Temperature: common.GetPointer[float64](0.7),
|
||||
}
|
||||
info := &relaycommon.RelayInfo{
|
||||
ChannelMeta: &relaycommon.ChannelMeta{
|
||||
UpstreamModelName: "kimi-k2.6",
|
||||
},
|
||||
}
|
||||
|
||||
converted, err := (&Adaptor{}).ConvertOpenAIRequest(nil, info, request)
|
||||
|
||||
require.NoError(t, err)
|
||||
convertedRequest, ok := converted.(*dto.GeneralOpenAIRequest)
|
||||
require.True(t, ok)
|
||||
require.NotNil(t, convertedRequest.Temperature)
|
||||
require.Equal(t, 1.0, *convertedRequest.Temperature)
|
||||
}
|
||||
|
||||
func TestConvertOpenAIRequestKimiK26KeepsOmittedTemperatureOmitted(t *testing.T) {
|
||||
request := &dto.GeneralOpenAIRequest{
|
||||
Model: "kimi-k2.6",
|
||||
}
|
||||
info := &relaycommon.RelayInfo{
|
||||
ChannelMeta: &relaycommon.ChannelMeta{
|
||||
UpstreamModelName: "kimi-k2.6",
|
||||
},
|
||||
}
|
||||
|
||||
converted, err := (&Adaptor{}).ConvertOpenAIRequest(nil, info, request)
|
||||
|
||||
require.NoError(t, err)
|
||||
convertedRequest, ok := converted.(*dto.GeneralOpenAIRequest)
|
||||
require.True(t, ok)
|
||||
require.Nil(t, convertedRequest.Temperature)
|
||||
}
|
||||
|
||||
func TestConvertOpenAIRequestOtherMoonshotModelKeepsTemperature(t *testing.T) {
|
||||
request := &dto.GeneralOpenAIRequest{
|
||||
Model: "kimi-k2.5",
|
||||
Temperature: common.GetPointer[float64](0.7),
|
||||
}
|
||||
info := &relaycommon.RelayInfo{
|
||||
ChannelMeta: &relaycommon.ChannelMeta{
|
||||
UpstreamModelName: "kimi-k2.5",
|
||||
},
|
||||
}
|
||||
|
||||
converted, err := (&Adaptor{}).ConvertOpenAIRequest(nil, info, request)
|
||||
|
||||
require.NoError(t, err)
|
||||
convertedRequest, ok := converted.(*dto.GeneralOpenAIRequest)
|
||||
require.True(t, ok)
|
||||
require.NotNil(t, convertedRequest.Temperature)
|
||||
require.Equal(t, 0.7, *convertedRequest.Temperature)
|
||||
}
|
||||
@@ -1,7 +1,6 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -12,6 +11,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/relay/helper"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
@@ -397,7 +397,7 @@ func PullOllamaModelStream(baseURL, apiKey, modelName string, progressCallback f
|
||||
}
|
||||
|
||||
// 读取流式响应
|
||||
scanner := bufio.NewScanner(response.Body)
|
||||
scanner := helper.NewStreamScanner(response.Body)
|
||||
successful := false
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -70,7 +69,7 @@ func ollamaStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
||||
defer service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
helper.SetEventStreamHeaders(c)
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner := helper.NewStreamScanner(resp.Body)
|
||||
usage := &dto.Usage{}
|
||||
var model = info.UpstreamModelName
|
||||
var responseId = common.GetUUID()
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
@@ -163,6 +164,20 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
url = strings.Replace(url, "{model}", info.UpstreamModelName, -1)
|
||||
return url, nil
|
||||
default:
|
||||
// Handle coding plan special base URLs
|
||||
if specialPlan, ok := constant.ChannelSpecialBases[info.ChannelBaseUrl]; ok && specialPlan.OpenAIBaseURL != "" {
|
||||
if info.RelayFormat == types.RelayFormatClaude {
|
||||
return fmt.Sprintf("%s/v1/messages", specialPlan.ClaudeBaseURL), nil
|
||||
}
|
||||
switch info.RelayMode {
|
||||
case relayconstant.RelayModeEmbeddings:
|
||||
return fmt.Sprintf("%s/embeddings", specialPlan.OpenAIBaseURL), nil
|
||||
case relayconstant.RelayModeImagesGenerations:
|
||||
return fmt.Sprintf("%s/images/generations", specialPlan.OpenAIBaseURL), nil
|
||||
default:
|
||||
return fmt.Sprintf("%s/chat/completions", specialPlan.OpenAIBaseURL), nil
|
||||
}
|
||||
}
|
||||
if (info.RelayFormat == types.RelayFormatClaude || info.RelayFormat == types.RelayFormatGemini) &&
|
||||
info.RelayMode != relayconstant.RelayModeResponses &&
|
||||
info.RelayMode != relayconstant.RelayModeResponsesCompact {
|
||||
@@ -220,7 +235,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *
|
||||
header.Set("HTTP-Referer", "https://www.newapi.ai")
|
||||
}
|
||||
if header.Get("X-OpenRouter-Title") == "" {
|
||||
header.Set("X-OpenRouter-Title", "New API")
|
||||
header.Set("X-OpenRouter-Title", "ModelsToken")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -310,18 +325,20 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
}
|
||||
|
||||
}
|
||||
if strings.HasPrefix(info.UpstreamModelName, "o") || strings.HasPrefix(info.UpstreamModelName, "gpt-5") {
|
||||
isOModel := dto.IsOpenAIReasoningOModel(info.UpstreamModelName)
|
||||
isGPT5Model := dto.IsOpenAIGPT5Model(info.UpstreamModelName)
|
||||
if isOModel || isGPT5Model {
|
||||
if lo.FromPtrOr(request.MaxCompletionTokens, uint(0)) == 0 && lo.FromPtrOr(request.MaxTokens, uint(0)) != 0 {
|
||||
request.MaxCompletionTokens = request.MaxTokens
|
||||
request.MaxTokens = nil
|
||||
}
|
||||
|
||||
if strings.HasPrefix(info.UpstreamModelName, "o") {
|
||||
if isOModel {
|
||||
request.Temperature = nil
|
||||
}
|
||||
|
||||
// gpt-5系列模型适配 归零不再支持的参数
|
||||
if strings.HasPrefix(info.UpstreamModelName, "gpt-5") {
|
||||
if isGPT5Model {
|
||||
request.Temperature = nil
|
||||
request.TopP = nil
|
||||
request.LogProbs = nil
|
||||
@@ -437,10 +454,13 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
// 使用已解析的 multipart 表单,避免重复解析
|
||||
mf := c.Request.MultipartForm
|
||||
if mf == nil {
|
||||
if _, err := c.MultipartForm(); err != nil {
|
||||
return nil, errors.New("failed to parse multipart form")
|
||||
form, err := common.ParseMultipartFormReusable(c)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse multipart form: %w", err)
|
||||
}
|
||||
mf = c.Request.MultipartForm
|
||||
c.Request.MultipartForm = form
|
||||
c.Request.PostForm = url.Values(form.Value)
|
||||
mf = form
|
||||
}
|
||||
|
||||
// 写入所有非文件字段
|
||||
@@ -623,7 +643,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
case relayconstant.RelayModeAudioTranscription:
|
||||
err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
|
||||
case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
|
||||
usage, err = OpenaiHandlerWithUsage(c, info, resp)
|
||||
if info.IsStream {
|
||||
usage, err = OpenaiImageStreamHandler(c, info, resp)
|
||||
} else {
|
||||
usage, err = OpenaiImageHandler(c, info, resp)
|
||||
}
|
||||
case relayconstant.RelayModeRerank:
|
||||
usage, err = common_handler.RerankHandler(c, info, resp)
|
||||
case relayconstant.RelayModeResponses:
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
@@ -92,78 +91,28 @@ func ProcessStreamResponse(streamResponse dto.ChatCompletionsStreamResponse, res
|
||||
return nil
|
||||
}
|
||||
|
||||
func processTokens(relayMode int, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
|
||||
streamResp := "[" + strings.Join(streamItems, ",") + "]"
|
||||
|
||||
func processTokenData(relayMode int, data string, responseTextBuilder *strings.Builder, toolCount *int) error {
|
||||
switch relayMode {
|
||||
case relayconstant.RelayModeChatCompletions:
|
||||
return processChatCompletions(streamResp, streamItems, responseTextBuilder, toolCount)
|
||||
var streamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := common.UnmarshalJsonStr(data, &streamResponse); err != nil {
|
||||
return err
|
||||
}
|
||||
return ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount)
|
||||
case relayconstant.RelayModeCompletions:
|
||||
return processCompletions(streamResp, streamItems, responseTextBuilder)
|
||||
var streamResponse dto.CompletionsStreamResponse
|
||||
if err := common.UnmarshalJsonStr(data, &streamResponse); err != nil {
|
||||
return err
|
||||
}
|
||||
processCompletionsStreamResponse(streamResponse, responseTextBuilder)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func processChatCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
|
||||
var streamResponses []dto.ChatCompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
|
||||
// 一次性解析失败,逐个解析
|
||||
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||
for _, item := range streamItems {
|
||||
var streamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount); err != nil {
|
||||
common.SysLog("error processing stream response: " + err.Error())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
func processCompletionsStreamResponse(streamResponse dto.CompletionsStreamResponse, responseTextBuilder *strings.Builder) {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Text)
|
||||
}
|
||||
|
||||
// 批量处理所有响应
|
||||
for _, streamResponse := range streamResponses {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
||||
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
|
||||
if choice.Delta.ToolCalls != nil {
|
||||
if len(choice.Delta.ToolCalls) > *toolCount {
|
||||
*toolCount = len(choice.Delta.ToolCalls)
|
||||
}
|
||||
for _, tool := range choice.Delta.ToolCalls {
|
||||
responseTextBuilder.WriteString(tool.Function.Name)
|
||||
responseTextBuilder.WriteString(tool.Function.Arguments)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func processCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder) error {
|
||||
var streamResponses []dto.CompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
|
||||
// 一次性解析失败,逐个解析
|
||||
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||
for _, item := range streamItems {
|
||||
var streamResponse dto.CompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
|
||||
continue
|
||||
}
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Text)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 批量处理所有响应
|
||||
for _, streamResponse := range streamResponses {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Text)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleLastResponse(lastStreamData string, responseId *string, createAt *int64,
|
||||
|
||||
@@ -0,0 +1,98 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
relayconstant "github.com/QuantumNous/new-api/relay/constant"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestConvertImageEditRequestMultipart verifies that ConvertImageRequest
|
||||
// re-serializes multipart image edit requests with all fields (including
|
||||
// stream) and the file intact, both when the form was already parsed and when
|
||||
// it must be re-parsed from the reusable body.
|
||||
func TestConvertImageEditRequestMultipart(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
newMultipartContext := func(t *testing.T, prompt string) *gin.Context {
|
||||
var body bytes.Buffer
|
||||
writer := multipart.NewWriter(&body)
|
||||
require.NoError(t, writer.WriteField("model", "gpt-image-1"))
|
||||
require.NoError(t, writer.WriteField("prompt", prompt))
|
||||
require.NoError(t, writer.WriteField("stream", "true"))
|
||||
require.NoError(t, writer.WriteField("partial_images", "3"))
|
||||
part, err := writer.CreateFormFile("image", "input.png")
|
||||
require.NoError(t, err)
|
||||
_, err = part.Write([]byte("fake image"))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, writer.Close())
|
||||
|
||||
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/edits", &body)
|
||||
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
return c
|
||||
}
|
||||
|
||||
convertAndReplay := func(t *testing.T, c *gin.Context, prompt string) {
|
||||
info := &relaycommon.RelayInfo{
|
||||
RelayMode: relayconstant.RelayModeImagesEdits,
|
||||
}
|
||||
request := dto.ImageRequest{
|
||||
Model: "gpt-image-1",
|
||||
Prompt: prompt,
|
||||
Stream: common.GetPointer(true),
|
||||
}
|
||||
|
||||
converted, err := (&Adaptor{}).ConvertImageRequest(c, info, request)
|
||||
require.NoError(t, err)
|
||||
convertedBody, ok := converted.(*bytes.Buffer)
|
||||
require.True(t, ok)
|
||||
|
||||
replayedRequest := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(convertedBody.Bytes()))
|
||||
replayedRequest.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
require.NoError(t, replayedRequest.ParseMultipartForm(32<<20))
|
||||
|
||||
require.Equal(t, "gpt-image-1", replayedRequest.PostForm.Get("model"))
|
||||
require.Equal(t, prompt, replayedRequest.PostForm.Get("prompt"))
|
||||
require.Equal(t, "true", replayedRequest.PostForm.Get("stream"))
|
||||
require.Equal(t, "3", replayedRequest.PostForm.Get("partial_images"))
|
||||
require.Len(t, replayedRequest.MultipartForm.File["image"], 1)
|
||||
|
||||
file, err := replayedRequest.MultipartForm.File["image"][0].Open()
|
||||
require.NoError(t, err)
|
||||
defer file.Close()
|
||||
fileBytes, err := io.ReadAll(file)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("fake image"), fileBytes)
|
||||
}
|
||||
|
||||
t.Run("with pre-parsed form", func(t *testing.T) {
|
||||
prompt := "edit this image"
|
||||
c := newMultipartContext(t, prompt)
|
||||
require.NoError(t, c.Request.ParseMultipartForm(32<<20))
|
||||
|
||||
convertAndReplay(t, c, prompt)
|
||||
})
|
||||
|
||||
t.Run("re-parses reusable body when form is missing", func(t *testing.T) {
|
||||
prompt := "edit without pre-parsed form"
|
||||
c := newMultipartContext(t, prompt)
|
||||
|
||||
storage, err := common.GetBodyStorage(c)
|
||||
require.NoError(t, err)
|
||||
c.Request.Body = io.NopCloser(storage)
|
||||
c.Request.MultipartForm = nil
|
||||
c.Request.PostForm = nil
|
||||
|
||||
convertAndReplay(t, c, prompt)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,173 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newImageTestContext(t *testing.T, body, contentType string, isStream bool) (*gin.Context, *httptest.ResponseRecorder, *http.Response, *relaycommon.RelayInfo) {
|
||||
t.Helper()
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
Header: http.Header{"Content-Type": []string{contentType}},
|
||||
}
|
||||
info := &relaycommon.RelayInfo{
|
||||
ChannelMeta: &relaycommon.ChannelMeta{},
|
||||
IsStream: isStream,
|
||||
}
|
||||
return c, recorder, resp, info
|
||||
}
|
||||
|
||||
// TestOpenaiImageStreamHandlerForwardsSSEAndUsage covers the core SSE path:
|
||||
// chunks are forwarded with rebuilt event lines, usage is extracted and
|
||||
// normalized (input_tokens -> prompt_tokens with details), and [DONE] is
|
||||
// re-emitted to the client.
|
||||
func TestOpenaiImageStreamHandlerForwardsSSEAndUsage(t *testing.T) {
|
||||
oldMode := gin.Mode()
|
||||
gin.SetMode(gin.TestMode)
|
||||
t.Cleanup(func() { gin.SetMode(oldMode) })
|
||||
|
||||
oldTimeout := constant.StreamingTimeout
|
||||
constant.StreamingTimeout = 30
|
||||
t.Cleanup(func() { constant.StreamingTimeout = oldTimeout })
|
||||
|
||||
body := strings.Join([]string{
|
||||
`event: image_generation.partial_image`,
|
||||
`data: {"type":"image_generation.partial_image","b64_json":"partial"}`,
|
||||
``,
|
||||
`data: {"usage":{"input_tokens":3,"output_tokens":4,"total_tokens":7,"input_tokens_details":{"image_tokens":2,"text_tokens":1}}}`,
|
||||
``,
|
||||
`data: [DONE]`,
|
||||
``,
|
||||
}, "\n")
|
||||
|
||||
c, recorder, resp, info := newImageTestContext(t, body, "text/event-stream", true)
|
||||
|
||||
usage, err := OpenaiImageStreamHandler(c, info, resp)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 3, usage.PromptTokens)
|
||||
require.Equal(t, 4, usage.CompletionTokens)
|
||||
require.Equal(t, 7, usage.TotalTokens)
|
||||
require.Equal(t, 2, usage.PromptTokensDetails.ImageTokens)
|
||||
require.Equal(t, 1, usage.PromptTokensDetails.TextTokens)
|
||||
require.Contains(t, recorder.Body.String(), `event: image_generation.partial_image`)
|
||||
require.Contains(t, recorder.Body.String(), `data: {"type":"image_generation.partial_image","b64_json":"partial"}`)
|
||||
require.Contains(t, recorder.Body.String(), `data: {"usage":{"input_tokens":3,"output_tokens":4,"total_tokens":7,"input_tokens_details":{"image_tokens":2,"text_tokens":1}}}`)
|
||||
require.Contains(t, recorder.Body.String(), `data: [DONE]`)
|
||||
require.Equal(t, "text/event-stream", recorder.Header().Get("Content-Type"))
|
||||
}
|
||||
|
||||
// TestOpenaiImageStreamHandlerWrapsJSONResponse covers the non-SSE fallback:
|
||||
// a JSON upstream response is wrapped into pseudo-SSE completed events.
|
||||
func TestOpenaiImageStreamHandlerWrapsJSONResponse(t *testing.T) {
|
||||
oldMode := gin.Mode()
|
||||
gin.SetMode(gin.TestMode)
|
||||
t.Cleanup(func() { gin.SetMode(oldMode) })
|
||||
|
||||
body := `{"created":1710000000,"data":[{"b64_json":"final","revised_prompt":"draw a cat"}],"usage":{"input_tokens":3,"output_tokens":4,"total_tokens":7,"input_tokens_details":{"image_tokens":2,"text_tokens":1}}}`
|
||||
|
||||
c, recorder, resp, info := newImageTestContext(t, body, "application/json", true)
|
||||
|
||||
usage, err := OpenaiImageStreamHandler(c, info, resp)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 3, usage.PromptTokens)
|
||||
require.Equal(t, 4, usage.CompletionTokens)
|
||||
require.Equal(t, 7, usage.TotalTokens)
|
||||
require.Equal(t, 2, usage.PromptTokensDetails.ImageTokens)
|
||||
require.Equal(t, 1, usage.PromptTokensDetails.TextTokens)
|
||||
require.Equal(t, "text/event-stream", recorder.Header().Get("Content-Type"))
|
||||
require.Empty(t, recorder.Header().Get("Content-Length"))
|
||||
require.Contains(t, recorder.Body.String(), `event: image_generation.completed`)
|
||||
require.Contains(t, recorder.Body.String(), `"type":"image_generation.completed"`)
|
||||
require.Contains(t, recorder.Body.String(), `"b64_json":"final"`)
|
||||
require.Contains(t, recorder.Body.String(), `"revised_prompt":"draw a cat"`)
|
||||
require.Contains(t, recorder.Body.String(), `data: [DONE]`)
|
||||
}
|
||||
|
||||
// TestOpenaiImageHandlersReturnJSONError covers JSON error responses for both
|
||||
// entry points: the non-streaming handler and the stream handler's non-SSE
|
||||
// fallback. Neither must leak the error body to the client.
|
||||
func TestOpenaiImageHandlersReturnJSONError(t *testing.T) {
|
||||
oldMode := gin.Mode()
|
||||
gin.SetMode(gin.TestMode)
|
||||
t.Cleanup(func() { gin.SetMode(oldMode) })
|
||||
|
||||
body := `{"error":{"message":"content moderation failed","type":"upstream_error","code":"content_moderation_failed","status":502}}`
|
||||
|
||||
t.Run("non-streaming handler", func(t *testing.T) {
|
||||
c, recorder, resp, info := newImageTestContext(t, body, "application/json", false)
|
||||
|
||||
usage, err := OpenaiImageHandler(c, info, resp)
|
||||
require.Nil(t, usage)
|
||||
require.NotNil(t, err)
|
||||
require.Equal(t, http.StatusOK, err.StatusCode)
|
||||
oaiError := err.ToOpenAIError()
|
||||
require.Equal(t, "content moderation failed", oaiError.Message)
|
||||
require.Equal(t, "upstream_error", oaiError.Type)
|
||||
require.Equal(t, "content_moderation_failed", oaiError.Code)
|
||||
require.Empty(t, recorder.Body.String())
|
||||
})
|
||||
|
||||
t.Run("stream handler JSON fallback", func(t *testing.T) {
|
||||
c, recorder, resp, info := newImageTestContext(t, body, "application/json", true)
|
||||
|
||||
usage, err := OpenaiImageStreamHandler(c, info, resp)
|
||||
require.Nil(t, usage)
|
||||
require.NotNil(t, err)
|
||||
require.Equal(t, http.StatusOK, err.StatusCode)
|
||||
require.Equal(t, "content moderation failed", err.ToOpenAIError().Message)
|
||||
require.Empty(t, recorder.Body.String())
|
||||
})
|
||||
}
|
||||
|
||||
// TestOpenaiImageStreamHandlerRecordsUpstreamErrorEvent verifies that an error
|
||||
// event inside the SSE stream is recorded as a soft error while the payload is
|
||||
// still forwarded to the client.
|
||||
func TestOpenaiImageStreamHandlerRecordsUpstreamErrorEvent(t *testing.T) {
|
||||
oldMode := gin.Mode()
|
||||
gin.SetMode(gin.TestMode)
|
||||
t.Cleanup(func() { gin.SetMode(oldMode) })
|
||||
|
||||
oldTimeout := constant.StreamingTimeout
|
||||
constant.StreamingTimeout = 30
|
||||
t.Cleanup(func() { constant.StreamingTimeout = oldTimeout })
|
||||
|
||||
body := strings.Join([]string{
|
||||
`event: image_generation.partial_image`,
|
||||
`data: {"type":"image_generation.partial_image","b64_json":"partial"}`,
|
||||
``,
|
||||
`event: error`,
|
||||
`data: {"type":"upstream_error","error":{"message":"stream error: stream ID 77; INTERNAL_ERROR; received from peer"}}`,
|
||||
``,
|
||||
}, "\n")
|
||||
|
||||
c, recorder, resp, info := newImageTestContext(t, body, "text/event-stream", true)
|
||||
|
||||
usage, err := OpenaiImageStreamHandler(c, info, resp)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, usage)
|
||||
require.NotNil(t, info.StreamStatus)
|
||||
require.Equal(t, relaycommon.StreamEndReasonEOF, info.StreamStatus.EndReason)
|
||||
require.True(t, info.StreamStatus.HasErrors())
|
||||
require.Equal(t, 1, info.StreamStatus.TotalErrorCount())
|
||||
require.Contains(t, info.StreamStatus.Errors[0].Message, "INTERNAL_ERROR")
|
||||
// The scanner strips the upstream "event: error" line; the event name is
|
||||
// rebuilt from the JSON "type" field (upstream_error). The error message
|
||||
// is still forwarded in the data: payload (stream ID 77).
|
||||
require.Contains(t, recorder.Body.String(), `event: upstream_error`)
|
||||
require.Contains(t, recorder.Body.String(), `stream ID 77`)
|
||||
}
|
||||
@@ -14,12 +14,9 @@ import (
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/relay/helper"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
|
||||
@@ -119,7 +116,6 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
var responseTextBuilder strings.Builder
|
||||
var toolCount int
|
||||
var usage = &dto.Usage{}
|
||||
var streamItems []string // store stream items
|
||||
var lastStreamData string
|
||||
var secondLastStreamData string // 存储倒数第二个stream data,用于音频模型
|
||||
|
||||
@@ -140,7 +136,10 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
}
|
||||
|
||||
lastStreamData = data
|
||||
streamItems = append(streamItems, data)
|
||||
if err := processTokenData(info.RelayMode, data, &responseTextBuilder, &toolCount); err != nil {
|
||||
logger.LogError(c, "error processing stream token data: "+err.Error())
|
||||
sr.Error(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
@@ -175,11 +174,6 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
}
|
||||
}
|
||||
|
||||
// 处理token计算
|
||||
if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
|
||||
logger.LogError(c, "error processing tokens: "+err.Error())
|
||||
}
|
||||
|
||||
if !containStreamUsage {
|
||||
usage = service.ResponseText2Usage(c, responseTextBuilder.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||||
usage.CompletionTokens += toolCount * 7
|
||||
@@ -296,421 +290,3 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
||||
|
||||
return &simpleResponse.Usage, nil
|
||||
}
|
||||
|
||||
func streamTTSResponse(c *gin.Context, resp *http.Response) {
|
||||
c.Writer.WriteHeaderNow()
|
||||
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
logger.LogWarn(c, "streaming not supported")
|
||||
_, err := io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
logger.LogWarn(c, err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
buffer := make([]byte, 4096)
|
||||
for {
|
||||
n, err := resp.Body.Read(buffer)
|
||||
//logger.LogInfo(c, fmt.Sprintf("streamTTSResponse read %d bytes", n))
|
||||
if n > 0 {
|
||||
if _, writeErr := c.Writer.Write(buffer[:n]); writeErr != nil {
|
||||
logger.LogError(c, writeErr.Error())
|
||||
break
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
logger.LogError(c, err.Error())
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.RealtimeUsage) {
|
||||
if info == nil || info.ClientWs == nil || info.TargetWs == nil {
|
||||
return types.NewError(fmt.Errorf("invalid websocket connection"), types.ErrorCodeBadResponse), nil
|
||||
}
|
||||
|
||||
info.IsStream = true
|
||||
clientConn := info.ClientWs
|
||||
targetConn := info.TargetWs
|
||||
|
||||
clientClosed := make(chan struct{})
|
||||
targetClosed := make(chan struct{})
|
||||
sendChan := make(chan []byte, 100)
|
||||
receiveChan := make(chan []byte, 100)
|
||||
errChan := make(chan error, 2)
|
||||
|
||||
usage := &dto.RealtimeUsage{}
|
||||
localUsage := &dto.RealtimeUsage{}
|
||||
sumUsage := &dto.RealtimeUsage{}
|
||||
|
||||
gopool.Go(func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
errChan <- fmt.Errorf("panic in client reader: %v", r)
|
||||
}
|
||||
}()
|
||||
for {
|
||||
select {
|
||||
case <-c.Done():
|
||||
return
|
||||
default:
|
||||
_, message, err := clientConn.ReadMessage()
|
||||
if err != nil {
|
||||
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
||||
errChan <- fmt.Errorf("error reading from client: %v", err)
|
||||
}
|
||||
close(clientClosed)
|
||||
return
|
||||
}
|
||||
|
||||
realtimeEvent := &dto.RealtimeEvent{}
|
||||
err = common.Unmarshal(message, realtimeEvent)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdate {
|
||||
if realtimeEvent.Session != nil {
|
||||
if realtimeEvent.Session.Tools != nil {
|
||||
info.RealtimeTools = realtimeEvent.Session.Tools
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error counting text token: %v", err)
|
||||
return
|
||||
}
|
||||
logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
|
||||
localUsage.TotalTokens += textToken + audioToken
|
||||
localUsage.InputTokens += textToken + audioToken
|
||||
localUsage.InputTokenDetails.TextTokens += textToken
|
||||
localUsage.InputTokenDetails.AudioTokens += audioToken
|
||||
|
||||
err = helper.WssString(c, targetConn, string(message))
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error writing to target: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case sendChan <- message:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
gopool.Go(func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
errChan <- fmt.Errorf("panic in target reader: %v", r)
|
||||
}
|
||||
}()
|
||||
for {
|
||||
select {
|
||||
case <-c.Done():
|
||||
return
|
||||
default:
|
||||
_, message, err := targetConn.ReadMessage()
|
||||
if err != nil {
|
||||
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
||||
errChan <- fmt.Errorf("error reading from target: %v", err)
|
||||
}
|
||||
close(targetClosed)
|
||||
return
|
||||
}
|
||||
info.SetFirstResponseTime()
|
||||
realtimeEvent := &dto.RealtimeEvent{}
|
||||
err = common.Unmarshal(message, realtimeEvent)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone {
|
||||
realtimeUsage := realtimeEvent.Response.Usage
|
||||
if realtimeUsage != nil {
|
||||
usage.TotalTokens += realtimeUsage.TotalTokens
|
||||
usage.InputTokens += realtimeUsage.InputTokens
|
||||
usage.OutputTokens += realtimeUsage.OutputTokens
|
||||
usage.InputTokenDetails.AudioTokens += realtimeUsage.InputTokenDetails.AudioTokens
|
||||
usage.InputTokenDetails.CachedTokens += realtimeUsage.InputTokenDetails.CachedTokens
|
||||
usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens
|
||||
usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens
|
||||
usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens
|
||||
err := preConsumeUsage(c, info, usage, sumUsage)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error consume usage: %v", err)
|
||||
return
|
||||
}
|
||||
// 本次计费完成,清除
|
||||
usage = &dto.RealtimeUsage{}
|
||||
|
||||
localUsage = &dto.RealtimeUsage{}
|
||||
} else {
|
||||
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error counting text token: %v", err)
|
||||
return
|
||||
}
|
||||
logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
|
||||
localUsage.TotalTokens += textToken + audioToken
|
||||
info.IsFirstRequest = false
|
||||
localUsage.InputTokens += textToken + audioToken
|
||||
localUsage.InputTokenDetails.TextTokens += textToken
|
||||
localUsage.InputTokenDetails.AudioTokens += audioToken
|
||||
err = preConsumeUsage(c, info, localUsage, sumUsage)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error consume usage: %v", err)
|
||||
return
|
||||
}
|
||||
// 本次计费完成,清除
|
||||
localUsage = &dto.RealtimeUsage{}
|
||||
// print now usage
|
||||
}
|
||||
logger.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
|
||||
logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
|
||||
logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
|
||||
|
||||
} else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
|
||||
realtimeSession := realtimeEvent.Session
|
||||
if realtimeSession != nil {
|
||||
// update audio format
|
||||
info.InputAudioFormat = common.GetStringIfEmpty(realtimeSession.InputAudioFormat, info.InputAudioFormat)
|
||||
info.OutputAudioFormat = common.GetStringIfEmpty(realtimeSession.OutputAudioFormat, info.OutputAudioFormat)
|
||||
}
|
||||
} else {
|
||||
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error counting text token: %v", err)
|
||||
return
|
||||
}
|
||||
logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
|
||||
localUsage.TotalTokens += textToken + audioToken
|
||||
localUsage.OutputTokens += textToken + audioToken
|
||||
localUsage.OutputTokenDetails.TextTokens += textToken
|
||||
localUsage.OutputTokenDetails.AudioTokens += audioToken
|
||||
}
|
||||
|
||||
err = helper.WssString(c, clientConn, string(message))
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error writing to client: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case receiveChan <- message:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
select {
|
||||
case <-clientClosed:
|
||||
case <-targetClosed:
|
||||
case err := <-errChan:
|
||||
//return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
|
||||
logger.LogError(c, "realtime error: "+err.Error())
|
||||
case <-c.Done():
|
||||
}
|
||||
|
||||
if usage.TotalTokens != 0 {
|
||||
_ = preConsumeUsage(c, info, usage, sumUsage)
|
||||
}
|
||||
|
||||
if localUsage.TotalTokens != 0 {
|
||||
_ = preConsumeUsage(c, info, localUsage, sumUsage)
|
||||
}
|
||||
|
||||
// check usage total tokens, if 0, use local usage
|
||||
|
||||
return nil, sumUsage
|
||||
}
|
||||
|
||||
func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.RealtimeUsage, totalUsage *dto.RealtimeUsage) error {
|
||||
if usage == nil || totalUsage == nil {
|
||||
return fmt.Errorf("invalid usage pointer")
|
||||
}
|
||||
|
||||
totalUsage.TotalTokens += usage.TotalTokens
|
||||
totalUsage.InputTokens += usage.InputTokens
|
||||
totalUsage.OutputTokens += usage.OutputTokens
|
||||
totalUsage.InputTokenDetails.CachedTokens += usage.InputTokenDetails.CachedTokens
|
||||
totalUsage.InputTokenDetails.TextTokens += usage.InputTokenDetails.TextTokens
|
||||
totalUsage.InputTokenDetails.AudioTokens += usage.InputTokenDetails.AudioTokens
|
||||
totalUsage.OutputTokenDetails.TextTokens += usage.OutputTokenDetails.TextTokens
|
||||
totalUsage.OutputTokenDetails.AudioTokens += usage.OutputTokenDetails.AudioTokens
|
||||
// clear usage
|
||||
err := service.PreWssConsumeQuota(ctx, info, usage)
|
||||
return err
|
||||
}
|
||||
|
||||
func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
defer service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
var usageResp dto.SimpleResponse
|
||||
err = common.Unmarshal(responseBody, &usageResp)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// 写入新的 response body
|
||||
service.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
|
||||
// Once we've written to the client, we should not return errors anymore
|
||||
// because the upstream has already consumed resources and returned content
|
||||
// We should still perform billing even if parsing fails
|
||||
// format
|
||||
if usageResp.InputTokens > 0 {
|
||||
usageResp.PromptTokens += usageResp.InputTokens
|
||||
}
|
||||
if usageResp.OutputTokens > 0 {
|
||||
usageResp.CompletionTokens += usageResp.OutputTokens
|
||||
}
|
||||
if usageResp.InputTokensDetails != nil {
|
||||
usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens
|
||||
usageResp.PromptTokensDetails.TextTokens += usageResp.InputTokensDetails.TextTokens
|
||||
}
|
||||
applyUsagePostProcessing(info, &usageResp.Usage, responseBody)
|
||||
return &usageResp.Usage, nil
|
||||
}
|
||||
|
||||
func applyUsagePostProcessing(info *relaycommon.RelayInfo, usage *dto.Usage, responseBody []byte) {
|
||||
if info == nil || usage == nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch info.ChannelType {
|
||||
case constant.ChannelTypeDeepSeek:
|
||||
if usage.PromptTokensDetails.CachedTokens == 0 && usage.PromptCacheHitTokens != 0 {
|
||||
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
|
||||
}
|
||||
case constant.ChannelTypeZhipu_v4:
|
||||
// 智普的cached_tokens在标准位置: usage.prompt_tokens_details.cached_tokens
|
||||
if usage.PromptTokensDetails.CachedTokens == 0 {
|
||||
if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 0 {
|
||||
usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens
|
||||
} else if cachedTokens, ok := extractCachedTokensFromBody(responseBody); ok {
|
||||
usage.PromptTokensDetails.CachedTokens = cachedTokens
|
||||
} else if usage.PromptCacheHitTokens > 0 {
|
||||
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
|
||||
}
|
||||
}
|
||||
case constant.ChannelTypeMoonshot:
|
||||
// Moonshot的cached_tokens在非标准位置: choices[].usage.cached_tokens
|
||||
if usage.PromptTokensDetails.CachedTokens == 0 {
|
||||
if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 0 {
|
||||
usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens
|
||||
} else if cachedTokens, ok := extractMoonshotCachedTokensFromBody(responseBody); ok {
|
||||
usage.PromptTokensDetails.CachedTokens = cachedTokens
|
||||
} else if cachedTokens, ok := extractCachedTokensFromBody(responseBody); ok {
|
||||
usage.PromptTokensDetails.CachedTokens = cachedTokens
|
||||
} else if usage.PromptCacheHitTokens > 0 {
|
||||
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
|
||||
}
|
||||
}
|
||||
case constant.ChannelTypeOpenAI:
|
||||
if usage.PromptTokensDetails.CachedTokens == 0 {
|
||||
if cachedTokens, ok := extractLlamaCachedTokensFromBody(responseBody); ok {
|
||||
usage.PromptTokensDetails.CachedTokens = cachedTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func extractCachedTokensFromBody(body []byte) (int, bool) {
|
||||
if len(body) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
Usage struct {
|
||||
PromptTokensDetails struct {
|
||||
CachedTokens *int `json:"cached_tokens"`
|
||||
} `json:"prompt_tokens_details"`
|
||||
CachedTokens *int `json:"cached_tokens"`
|
||||
PromptCacheHitTokens *int `json:"prompt_cache_hit_tokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
|
||||
if err := common.Unmarshal(body, &payload); err != nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
if payload.Usage.PromptTokensDetails.CachedTokens != nil {
|
||||
return *payload.Usage.PromptTokensDetails.CachedTokens, true
|
||||
}
|
||||
if payload.Usage.CachedTokens != nil {
|
||||
return *payload.Usage.CachedTokens, true
|
||||
}
|
||||
if payload.Usage.PromptCacheHitTokens != nil {
|
||||
return *payload.Usage.PromptCacheHitTokens, true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// extractMoonshotCachedTokensFromBody 从Moonshot的非标准位置提取cached_tokens
|
||||
// Moonshot的流式响应格式: {"choices":[{"usage":{"cached_tokens":111}}]}
|
||||
func extractMoonshotCachedTokensFromBody(body []byte) (int, bool) {
|
||||
if len(body) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
Choices []struct {
|
||||
Usage struct {
|
||||
CachedTokens *int `json:"cached_tokens"`
|
||||
} `json:"usage"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
if err := common.Unmarshal(body, &payload); err != nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// 遍历choices查找cached_tokens
|
||||
for _, choice := range payload.Choices {
|
||||
if choice.Usage.CachedTokens != nil && *choice.Usage.CachedTokens > 0 {
|
||||
return *choice.Usage.CachedTokens, true
|
||||
}
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// extractLlamaCachedTokensFromBody 从llama.cpp的非标准位置提取cache_n
|
||||
func extractLlamaCachedTokensFromBody(body []byte) (int, bool) {
|
||||
if len(body) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
Timings struct {
|
||||
CachedTokens *int `json:"cache_n"`
|
||||
} `json:"timings"`
|
||||
}
|
||||
|
||||
if err := common.Unmarshal(body, &payload); err != nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
if payload.Timings.CachedTokens == nil {
|
||||
return 0, false
|
||||
}
|
||||
return *payload.Timings.CachedTokens, true
|
||||
}
|
||||
|
||||
@@ -0,0 +1,287 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/relay/helper"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// OpenaiImageHandler handles non-streaming OpenAI image responses
|
||||
// (generations/edits), returning the parsed usage for billing.
|
||||
func OpenaiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
defer service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
var usageResp dto.SimpleResponse
|
||||
err = common.Unmarshal(responseBody, &usageResp)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if oaiError := usageResp.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
|
||||
return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
|
||||
}
|
||||
|
||||
// 写入新的 response body
|
||||
service.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
|
||||
normalizeOpenAIUsage(&usageResp.Usage)
|
||||
applyUsagePostProcessing(info, &usageResp.Usage, responseBody)
|
||||
return &usageResp.Usage, nil
|
||||
}
|
||||
|
||||
// normalizeOpenAIUsage maps the OpenAI Images usage shape (input_tokens /
|
||||
// output_tokens / input_tokens_details) onto the canonical prompt/completion
|
||||
// fields. It is used only on the OpenAI image relay paths (generations/edits,
|
||||
// streaming and non-streaming): the image API never returns prompt_tokens /
|
||||
// completion_tokens, so the overwrite (=) semantics here are equivalent to the
|
||||
// previous additive (+=) behavior while avoiding any future double-counting if
|
||||
// both field sets are ever populated. Do not reuse this on chat/embedding paths
|
||||
// without revisiting the overwrite semantics.
|
||||
func normalizeOpenAIUsage(usage *dto.Usage) {
|
||||
if usage == nil {
|
||||
return
|
||||
}
|
||||
if usage.InputTokens != 0 {
|
||||
usage.PromptTokens = usage.InputTokens
|
||||
}
|
||||
if usage.OutputTokens != 0 {
|
||||
usage.CompletionTokens = usage.OutputTokens
|
||||
}
|
||||
if usage.InputTokensDetails != nil {
|
||||
usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens
|
||||
usage.PromptTokensDetails.CachedCreationTokens = usage.InputTokensDetails.CachedCreationTokens
|
||||
usage.PromptTokensDetails.ImageTokens = usage.InputTokensDetails.ImageTokens
|
||||
usage.PromptTokensDetails.TextTokens = usage.InputTokensDetails.TextTokens
|
||||
usage.PromptTokensDetails.AudioTokens = usage.InputTokensDetails.AudioTokens
|
||||
}
|
||||
if usage.TotalTokens == 0 {
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
}
|
||||
}
|
||||
|
||||
func OpenaiImageStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
if resp == nil || resp.Body == nil {
|
||||
logger.LogError(c, "invalid image stream response")
|
||||
return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
contentType := strings.ToLower(resp.Header.Get("Content-Type"))
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||
return OpenaiImageHandler(c, info, resp)
|
||||
}
|
||||
if !strings.Contains(contentType, "text/event-stream") {
|
||||
return OpenaiImageJSONAsStreamHandler(c, info, resp)
|
||||
}
|
||||
// Reuse the shared streaming engine (helper.StreamScannerHandler) so the
|
||||
// image streaming path gets the same ping keepalive, streaming-timeout
|
||||
// watchdog, client-disconnect detection, panic recovery and goroutine
|
||||
// cleanup as every other relay stream. The scanner delivers only the
|
||||
// "data:" payload, so the SSE "event:" line is rebuilt from the JSON "type"
|
||||
// field (real OpenAI image events keep event == type).
|
||||
usage := &dto.Usage{}
|
||||
var lastStreamData []byte
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) {
|
||||
raw := common.StringToByteSlice(data)
|
||||
lastStreamData = raw
|
||||
if isOpenAIImageStreamErrorEvent(raw) {
|
||||
// Record the error as a soft error; the scanner drives the final
|
||||
// EndReason. HasErrors() flags the failure for logging/handling.
|
||||
sr.Error(fmt.Errorf("%s", extractOpenAIImageStreamErrorMessage(raw)))
|
||||
}
|
||||
var usageResp dto.SimpleResponse
|
||||
if err := common.Unmarshal(raw, &usageResp); err == nil {
|
||||
normalizeOpenAIUsage(&usageResp.Usage)
|
||||
if service.ValidUsage(&usageResp.Usage) {
|
||||
usage = &usageResp.Usage
|
||||
}
|
||||
}
|
||||
writeOpenaiImageStreamChunk(c, raw)
|
||||
})
|
||||
|
||||
// StreamScannerHandler consumes the upstream [DONE]; re-emit it so the
|
||||
// client still receives a terminal data: [DONE].
|
||||
if info != nil && info.StreamStatus != nil && info.StreamStatus.EndReason == relaycommon.StreamEndReasonDone {
|
||||
helper.Done(c)
|
||||
}
|
||||
|
||||
applyUsagePostProcessing(info, usage, lastStreamData)
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
// writeOpenaiImageStreamChunk rebuilds the SSE frame for an image stream chunk:
|
||||
// it emits an "event:" line derived from the JSON "type" field (when present)
|
||||
// followed by the verbatim "data:" payload, mirroring helper.ResponseChunkData.
|
||||
func writeOpenaiImageStreamChunk(c *gin.Context, data []byte) {
|
||||
var payload struct {
|
||||
Type string `json:"type"`
|
||||
}
|
||||
_ = common.Unmarshal(data, &payload)
|
||||
if eventName := strings.TrimSpace(payload.Type); eventName != "" {
|
||||
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", eventName)})
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(data)})
|
||||
_ = helper.FlushWriter(c)
|
||||
}
|
||||
|
||||
// isOpenAIImageStreamErrorEvent detects upstream error chunks by JSON content
|
||||
// only ("type" of error/upstream_error, or a non-empty "error" field). The SSE
|
||||
// "event:" line is not available here: StreamScannerHandler delivers only the
|
||||
// "data:" payload. A payload carrying just a "message" key is deliberately NOT
|
||||
// treated as an error to avoid false positives.
|
||||
func isOpenAIImageStreamErrorEvent(data []byte) bool {
|
||||
if !json.Valid(data) {
|
||||
return false
|
||||
}
|
||||
var payload struct {
|
||||
Type string `json:"type"`
|
||||
Error json.RawMessage `json:"error"`
|
||||
}
|
||||
if err := common.Unmarshal(data, &payload); err != nil {
|
||||
return false
|
||||
}
|
||||
payloadType := strings.ToLower(strings.TrimSpace(payload.Type))
|
||||
return payloadType == "error" || payloadType == "upstream_error" || len(payload.Error) > 0
|
||||
}
|
||||
|
||||
func extractOpenAIImageStreamErrorMessage(data []byte) string {
|
||||
if len(data) == 0 || !json.Valid(data) {
|
||||
return "upstream image stream returned error event"
|
||||
}
|
||||
var payload struct {
|
||||
Message string `json:"message"`
|
||||
Error json.RawMessage `json:"error"`
|
||||
}
|
||||
if err := common.Unmarshal(data, &payload); err != nil {
|
||||
return "upstream image stream returned error event"
|
||||
}
|
||||
if msg := strings.TrimSpace(payload.Message); msg != "" {
|
||||
return msg
|
||||
}
|
||||
if len(payload.Error) > 0 {
|
||||
var nested struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
if err := common.Unmarshal(payload.Error, &nested); err == nil {
|
||||
if msg := strings.TrimSpace(nested.Message); msg != "" {
|
||||
return msg
|
||||
}
|
||||
}
|
||||
if msg := strings.TrimSpace(common.JsonRawMessageToString(payload.Error)); msg != "" {
|
||||
return msg
|
||||
}
|
||||
}
|
||||
return "upstream image stream returned error event"
|
||||
}
|
||||
|
||||
func OpenaiImageJSONAsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
defer service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
var imageResp dto.ImageResponse
|
||||
if err := common.Unmarshal(responseBody, &imageResp); err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
var usageResp dto.SimpleResponse
|
||||
_ = common.Unmarshal(responseBody, &usageResp)
|
||||
if oaiError := usageResp.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
|
||||
return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
|
||||
}
|
||||
normalizeOpenAIUsage(&usageResp.Usage)
|
||||
applyUsagePostProcessing(info, &usageResp.Usage, responseBody)
|
||||
|
||||
helper.SetEventStreamHeaders(c)
|
||||
c.Status(http.StatusOK)
|
||||
|
||||
created := imageResp.Created
|
||||
if created == 0 {
|
||||
created = time.Now().Unix()
|
||||
}
|
||||
if info != nil {
|
||||
info.SetFirstResponseTime()
|
||||
}
|
||||
for _, image := range imageResp.Data {
|
||||
payload := map[string]any{
|
||||
"type": "image_generation.completed",
|
||||
"created_at": created,
|
||||
}
|
||||
if image.Url != "" {
|
||||
payload["url"] = image.Url
|
||||
}
|
||||
if image.B64Json != "" {
|
||||
payload["b64_json"] = image.B64Json
|
||||
}
|
||||
if image.RevisedPrompt != "" {
|
||||
payload["revised_prompt"] = image.RevisedPrompt
|
||||
}
|
||||
if service.ValidUsage(&usageResp.Usage) {
|
||||
payload["usage"] = usageResp.Usage
|
||||
}
|
||||
if err := writeOpenaiImageStreamPayload(c, "image_generation.completed", payload); err != nil {
|
||||
if info != nil && info.StreamStatus != nil {
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, err)
|
||||
}
|
||||
return &usageResp.Usage, nil
|
||||
}
|
||||
}
|
||||
if err := writeOpenaiImageStreamDone(c); err != nil {
|
||||
if info != nil && info.StreamStatus != nil {
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, err)
|
||||
}
|
||||
return &usageResp.Usage, nil
|
||||
}
|
||||
if info != nil {
|
||||
info.ReceivedResponseCount += len(imageResp.Data)
|
||||
if info.StreamStatus == nil {
|
||||
info.StreamStatus = relaycommon.NewStreamStatus()
|
||||
}
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonDone, nil)
|
||||
}
|
||||
return &usageResp.Usage, nil
|
||||
}
|
||||
|
||||
func writeOpenaiImageStreamPayload(c *gin.Context, eventName string, payload any) error {
|
||||
data, err := common.Marshal(payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if eventName != "" {
|
||||
if _, err := fmt.Fprintf(c.Writer, "event: %s\n", eventName); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", data); err != nil {
|
||||
return err
|
||||
}
|
||||
return helper.FlushWriter(c)
|
||||
}
|
||||
|
||||
func writeOpenaiImageStreamDone(c *gin.Context) error {
|
||||
if _, err := fmt.Fprint(c.Writer, "data: [DONE]\n\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
return helper.FlushWriter(c)
|
||||
}
|
||||
@@ -0,0 +1,242 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/relay/helper"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.RealtimeUsage) {
|
||||
if info == nil || info.ClientWs == nil || info.TargetWs == nil {
|
||||
return types.NewError(fmt.Errorf("invalid websocket connection"), types.ErrorCodeBadResponse), nil
|
||||
}
|
||||
|
||||
info.IsStream = true
|
||||
clientConn := info.ClientWs
|
||||
targetConn := info.TargetWs
|
||||
|
||||
clientClosed := make(chan struct{})
|
||||
targetClosed := make(chan struct{})
|
||||
sendChan := make(chan []byte, 100)
|
||||
receiveChan := make(chan []byte, 100)
|
||||
errChan := make(chan error, 2)
|
||||
|
||||
usage := &dto.RealtimeUsage{}
|
||||
localUsage := &dto.RealtimeUsage{}
|
||||
sumUsage := &dto.RealtimeUsage{}
|
||||
|
||||
gopool.Go(func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
errChan <- fmt.Errorf("panic in client reader: %v", r)
|
||||
}
|
||||
}()
|
||||
for {
|
||||
select {
|
||||
case <-c.Done():
|
||||
return
|
||||
default:
|
||||
_, message, err := clientConn.ReadMessage()
|
||||
if err != nil {
|
||||
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
||||
errChan <- fmt.Errorf("error reading from client: %v", err)
|
||||
}
|
||||
close(clientClosed)
|
||||
return
|
||||
}
|
||||
|
||||
realtimeEvent := &dto.RealtimeEvent{}
|
||||
err = common.Unmarshal(message, realtimeEvent)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdate {
|
||||
if realtimeEvent.Session != nil {
|
||||
if realtimeEvent.Session.Tools != nil {
|
||||
info.RealtimeTools = realtimeEvent.Session.Tools
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error counting text token: %v", err)
|
||||
return
|
||||
}
|
||||
logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
|
||||
localUsage.TotalTokens += textToken + audioToken
|
||||
localUsage.InputTokens += textToken + audioToken
|
||||
localUsage.InputTokenDetails.TextTokens += textToken
|
||||
localUsage.InputTokenDetails.AudioTokens += audioToken
|
||||
|
||||
err = helper.WssString(c, targetConn, string(message))
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error writing to target: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case sendChan <- message:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
gopool.Go(func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
errChan <- fmt.Errorf("panic in target reader: %v", r)
|
||||
}
|
||||
}()
|
||||
for {
|
||||
select {
|
||||
case <-c.Done():
|
||||
return
|
||||
default:
|
||||
_, message, err := targetConn.ReadMessage()
|
||||
if err != nil {
|
||||
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
||||
errChan <- fmt.Errorf("error reading from target: %v", err)
|
||||
}
|
||||
close(targetClosed)
|
||||
return
|
||||
}
|
||||
info.SetFirstResponseTime()
|
||||
realtimeEvent := &dto.RealtimeEvent{}
|
||||
err = common.Unmarshal(message, realtimeEvent)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone {
|
||||
realtimeUsage := realtimeEvent.Response.Usage
|
||||
if realtimeUsage != nil {
|
||||
usage.TotalTokens += realtimeUsage.TotalTokens
|
||||
usage.InputTokens += realtimeUsage.InputTokens
|
||||
usage.OutputTokens += realtimeUsage.OutputTokens
|
||||
usage.InputTokenDetails.AudioTokens += realtimeUsage.InputTokenDetails.AudioTokens
|
||||
usage.InputTokenDetails.CachedTokens += realtimeUsage.InputTokenDetails.CachedTokens
|
||||
usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens
|
||||
usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens
|
||||
usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens
|
||||
err := preConsumeUsage(c, info, usage, sumUsage)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error consume usage: %v", err)
|
||||
return
|
||||
}
|
||||
// 本次计费完成,清除
|
||||
usage = &dto.RealtimeUsage{}
|
||||
|
||||
localUsage = &dto.RealtimeUsage{}
|
||||
} else {
|
||||
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error counting text token: %v", err)
|
||||
return
|
||||
}
|
||||
logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
|
||||
localUsage.TotalTokens += textToken + audioToken
|
||||
info.IsFirstRequest = false
|
||||
localUsage.InputTokens += textToken + audioToken
|
||||
localUsage.InputTokenDetails.TextTokens += textToken
|
||||
localUsage.InputTokenDetails.AudioTokens += audioToken
|
||||
err = preConsumeUsage(c, info, localUsage, sumUsage)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error consume usage: %v", err)
|
||||
return
|
||||
}
|
||||
// 本次计费完成,清除
|
||||
localUsage = &dto.RealtimeUsage{}
|
||||
// print now usage
|
||||
}
|
||||
logger.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
|
||||
logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
|
||||
logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
|
||||
|
||||
} else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
|
||||
realtimeSession := realtimeEvent.Session
|
||||
if realtimeSession != nil {
|
||||
// update audio format
|
||||
info.InputAudioFormat = common.GetStringIfEmpty(realtimeSession.InputAudioFormat, info.InputAudioFormat)
|
||||
info.OutputAudioFormat = common.GetStringIfEmpty(realtimeSession.OutputAudioFormat, info.OutputAudioFormat)
|
||||
}
|
||||
} else {
|
||||
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error counting text token: %v", err)
|
||||
return
|
||||
}
|
||||
logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
|
||||
localUsage.TotalTokens += textToken + audioToken
|
||||
localUsage.OutputTokens += textToken + audioToken
|
||||
localUsage.OutputTokenDetails.TextTokens += textToken
|
||||
localUsage.OutputTokenDetails.AudioTokens += audioToken
|
||||
}
|
||||
|
||||
err = helper.WssString(c, clientConn, string(message))
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error writing to client: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case receiveChan <- message:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
select {
|
||||
case <-clientClosed:
|
||||
case <-targetClosed:
|
||||
case err := <-errChan:
|
||||
//return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
|
||||
logger.LogError(c, "realtime error: "+err.Error())
|
||||
case <-c.Done():
|
||||
}
|
||||
|
||||
if usage.TotalTokens != 0 {
|
||||
_ = preConsumeUsage(c, info, usage, sumUsage)
|
||||
}
|
||||
|
||||
if localUsage.TotalTokens != 0 {
|
||||
_ = preConsumeUsage(c, info, localUsage, sumUsage)
|
||||
}
|
||||
|
||||
// check usage total tokens, if 0, use local usage
|
||||
|
||||
return nil, sumUsage
|
||||
}
|
||||
|
||||
func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.RealtimeUsage, totalUsage *dto.RealtimeUsage) error {
|
||||
if usage == nil || totalUsage == nil {
|
||||
return fmt.Errorf("invalid usage pointer")
|
||||
}
|
||||
|
||||
totalUsage.TotalTokens += usage.TotalTokens
|
||||
totalUsage.InputTokens += usage.InputTokens
|
||||
totalUsage.OutputTokens += usage.OutputTokens
|
||||
totalUsage.InputTokenDetails.CachedTokens += usage.InputTokenDetails.CachedTokens
|
||||
totalUsage.InputTokenDetails.TextTokens += usage.InputTokenDetails.TextTokens
|
||||
totalUsage.InputTokenDetails.AudioTokens += usage.InputTokenDetails.AudioTokens
|
||||
totalUsage.OutputTokenDetails.TextTokens += usage.OutputTokenDetails.TextTokens
|
||||
totalUsage.OutputTokenDetails.AudioTokens += usage.OutputTokenDetails.AudioTokens
|
||||
// clear usage
|
||||
err := service.PreWssConsumeQuota(ctx, info, usage)
|
||||
return err
|
||||
}
|
||||
@@ -0,0 +1,133 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
)
|
||||
|
||||
func applyUsagePostProcessing(info *relaycommon.RelayInfo, usage *dto.Usage, responseBody []byte) {
|
||||
if info == nil || usage == nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch info.ChannelType {
|
||||
case constant.ChannelTypeDeepSeek:
|
||||
if usage.PromptTokensDetails.CachedTokens == 0 && usage.PromptCacheHitTokens != 0 {
|
||||
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
|
||||
}
|
||||
case constant.ChannelTypeZhipu_v4:
|
||||
// 智普的cached_tokens在标准位置: usage.prompt_tokens_details.cached_tokens
|
||||
if usage.PromptTokensDetails.CachedTokens == 0 {
|
||||
if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 0 {
|
||||
usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens
|
||||
} else if cachedTokens, ok := extractCachedTokensFromBody(responseBody); ok {
|
||||
usage.PromptTokensDetails.CachedTokens = cachedTokens
|
||||
} else if usage.PromptCacheHitTokens > 0 {
|
||||
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
|
||||
}
|
||||
}
|
||||
case constant.ChannelTypeMoonshot:
|
||||
// Moonshot的cached_tokens在非标准位置: choices[].usage.cached_tokens
|
||||
if usage.PromptTokensDetails.CachedTokens == 0 {
|
||||
if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 0 {
|
||||
usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens
|
||||
} else if cachedTokens, ok := extractMoonshotCachedTokensFromBody(responseBody); ok {
|
||||
usage.PromptTokensDetails.CachedTokens = cachedTokens
|
||||
} else if cachedTokens, ok := extractCachedTokensFromBody(responseBody); ok {
|
||||
usage.PromptTokensDetails.CachedTokens = cachedTokens
|
||||
} else if usage.PromptCacheHitTokens > 0 {
|
||||
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
|
||||
}
|
||||
}
|
||||
case constant.ChannelTypeOpenAI:
|
||||
if usage.PromptTokensDetails.CachedTokens == 0 {
|
||||
if cachedTokens, ok := extractLlamaCachedTokensFromBody(responseBody); ok {
|
||||
usage.PromptTokensDetails.CachedTokens = cachedTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func extractCachedTokensFromBody(body []byte) (int, bool) {
|
||||
if len(body) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
Usage struct {
|
||||
PromptTokensDetails struct {
|
||||
CachedTokens *int `json:"cached_tokens"`
|
||||
} `json:"prompt_tokens_details"`
|
||||
CachedTokens *int `json:"cached_tokens"`
|
||||
PromptCacheHitTokens *int `json:"prompt_cache_hit_tokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
|
||||
if err := common.Unmarshal(body, &payload); err != nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
if payload.Usage.PromptTokensDetails.CachedTokens != nil {
|
||||
return *payload.Usage.PromptTokensDetails.CachedTokens, true
|
||||
}
|
||||
if payload.Usage.CachedTokens != nil {
|
||||
return *payload.Usage.CachedTokens, true
|
||||
}
|
||||
if payload.Usage.PromptCacheHitTokens != nil {
|
||||
return *payload.Usage.PromptCacheHitTokens, true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// extractMoonshotCachedTokensFromBody 从Moonshot的非标准位置提取cached_tokens
|
||||
// Moonshot的流式响应格式: {"choices":[{"usage":{"cached_tokens":111}}]}
|
||||
func extractMoonshotCachedTokensFromBody(body []byte) (int, bool) {
|
||||
if len(body) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
Choices []struct {
|
||||
Usage struct {
|
||||
CachedTokens *int `json:"cached_tokens"`
|
||||
} `json:"usage"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
if err := common.Unmarshal(body, &payload); err != nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// 遍历choices查找cached_tokens
|
||||
for _, choice := range payload.Choices {
|
||||
if choice.Usage.CachedTokens != nil && *choice.Usage.CachedTokens > 0 {
|
||||
return *choice.Usage.CachedTokens, true
|
||||
}
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// extractLlamaCachedTokensFromBody 从llama.cpp的非标准位置提取cache_n
|
||||
func extractLlamaCachedTokensFromBody(body []byte) (int, bool) {
|
||||
if len(body) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
Timings struct {
|
||||
CachedTokens *int `json:"cache_n"`
|
||||
} `json:"timings"`
|
||||
}
|
||||
|
||||
if err := common.Unmarshal(body, &payload); err != nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
if payload.Timings.CachedTokens == nil {
|
||||
return 0, false
|
||||
}
|
||||
return *payload.Timings.CachedTokens, true
|
||||
}
|
||||
@@ -92,7 +92,7 @@ func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *dto.Cha
|
||||
|
||||
func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
var responseText string
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner := helper.NewStreamScanner(resp.Body)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
|
||||
helper.SetEventStreamHeaders(c)
|
||||
|
||||
@@ -45,6 +45,7 @@ var claudeModelMap = map[string]string{
|
||||
"claude-opus-4-5-20251101": "claude-opus-4-5@20251101",
|
||||
"claude-opus-4-6": "claude-opus-4-6",
|
||||
"claude-opus-4-7": "claude-opus-4-7",
|
||||
"claude-opus-4-8": "claude-opus-4-8",
|
||||
}
|
||||
|
||||
const anthropicVersion = "vertex-2023-10-16"
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user