Compare commits
102 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 1baf4a6337 | |||
| 9816ad87e3 | |||
| 50249f581c | |||
| 0193018af6 | |||
| f449e06b9d | |||
| 79527c0ab1 | |||
| 41cd051ea9 | |||
| c04f82bfb5 | |||
| dafc7618c3 | |||
| 22692b3f87 | |||
| d36e892905 | |||
| 3cd1ba4673 | |||
| b7c0f754ad | |||
| a706f00287 | |||
| 7efb1922fe | |||
| 89fe99f3bd | |||
| e5b5331d3b | |||
| 18373c6eac | |||
| 5b47011e08 | |||
| ab99c30884 | |||
| 670abee2f0 | |||
| 8bb9a42f68 | |||
| d22f889e5d | |||
| 3734059da7 | |||
| 26ce873f8b | |||
| e099117c61 | |||
| 310d618a16 | |||
| 20399d3c8f | |||
| 53aeee4ff7 | |||
| 5238f279db | |||
| 5402bf417d | |||
| c766913baf | |||
| 40dc43f44e | |||
| 263b9bc695 | |||
| b2dd4acc9f | |||
| 4e492b26f6 | |||
| 82b750398c | |||
| fbf235d222 | |||
| 62b9aaa520 | |||
| 814a3f5124 | |||
| 22b6b16702 | |||
| 6154b8e3cd | |||
| ff66288e3a | |||
| 926e1781dd | |||
| d4a470a638 | |||
| 9f61407bf0 | |||
| dbf900a531 | |||
| 7399e4721b | |||
| a5e20269dd | |||
| 9ae9040b3c | |||
| 0191a68d4e | |||
| 16221f8279 | |||
| 763c3ff709 | |||
| c667e4706a | |||
| 216b94dac0 | |||
| 49eb533aaf | |||
| 7693edae53 | |||
| ded4a124e2 | |||
| d6982c8182 | |||
| 9ecad90652 | |||
| 929b5060ea | |||
| 755ece2f01 | |||
| f40eb4e5d2 | |||
| 45f65c297b | |||
| 6c074ef897 | |||
| deff59a5be | |||
| 3c516084f8 | |||
| 4d675b4d1f | |||
| 87b426f306 | |||
| 49db5147c3 | |||
| 13122aa0fa | |||
| dcd0911612 | |||
| e904579a5b | |||
| e80d867f38 | |||
| cf86fe5fea | |||
| 42846c692e | |||
| 1911520eba | |||
| 2c3ae32c8e | |||
| 64f41efc47 | |||
| 498199b37d | |||
| ff29900f30 | |||
| eff51857d0 | |||
| e9f8f62796 | |||
| 5fe8e98eeb | |||
| e520977efc | |||
| ed6ff0f267 | |||
| d955a0c080 | |||
| d096a2e5b7 | |||
| d2fb485d34 | |||
| 04f5dd0206 | |||
| ede0ad117b | |||
| 5bb8fe6af5 | |||
| a1a92c1918 | |||
| a4d1ed6da5 | |||
| 669e596ff7 | |||
| 1daeac42ef | |||
| e70bfa2d57 | |||
| b09337e6ed | |||
| bd09b47ef4 | |||
| d595ef4990 | |||
| 2270f63c00 | |||
| 202a433f86 |
@@ -1,12 +0,0 @@
|
||||
# These are supported funding model platforms
|
||||
|
||||
github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
|
||||
patreon: # Replace with a single Patreon username
|
||||
open_collective: # Replace with a single Open Collective username
|
||||
ko_fi: # Replace with a single Ko-fi username
|
||||
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
|
||||
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
|
||||
liberapay: # Replace with a single Liberapay username
|
||||
issuehunt: # Replace with a single IssueHunt username
|
||||
otechie: # Replace with a single Otechie username
|
||||
custom: ['https://afdian.com/a/new-api'] # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
|
||||
@@ -27,9 +27,10 @@ jobs:
|
||||
permissions:
|
||||
packages: write
|
||||
contents: read
|
||||
id-token: write
|
||||
steps:
|
||||
- name: Check out (shallow)
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||
with:
|
||||
fetch-depth: 1
|
||||
|
||||
@@ -46,16 +47,16 @@ jobs:
|
||||
run: echo "GHCR_REPOSITORY=${GITHUB_REPOSITORY,,}" >> $GITHUB_ENV
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Log in to GHCR
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
@@ -63,14 +64,15 @@ jobs:
|
||||
|
||||
- name: Extract metadata (labels)
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # v5
|
||||
with:
|
||||
images: |
|
||||
calciumion/new-api
|
||||
ghcr.io/${{ env.GHCR_REPOSITORY }}
|
||||
|
||||
- name: Build & push single-arch (to both registries)
|
||||
uses: docker/build-push-action@v6
|
||||
id: build
|
||||
uses: docker/build-push-action@10e90e3645eae34f1e60eeb005ba3a3d33f178e8 # v6
|
||||
with:
|
||||
context: .
|
||||
platforms: ${{ matrix.platform }}
|
||||
@@ -83,8 +85,25 @@ jobs:
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
provenance: false
|
||||
sbom: false
|
||||
provenance: mode=max
|
||||
sbom: true
|
||||
|
||||
- name: Install cosign
|
||||
uses: sigstore/cosign-installer@398d4b0eeef1380460a10c8013a76f728fb906ac # v3
|
||||
|
||||
- name: Sign image with cosign
|
||||
run: |
|
||||
cosign sign --yes calciumion/new-api@${{ steps.build.outputs.digest }}
|
||||
cosign sign --yes ghcr.io/${{ env.GHCR_REPOSITORY }}@${{ steps.build.outputs.digest }}
|
||||
|
||||
- name: Output digest
|
||||
run: |
|
||||
echo "### Docker Image Digest (${{ matrix.arch }})" >> $GITHUB_STEP_SUMMARY
|
||||
echo '```' >> $GITHUB_STEP_SUMMARY
|
||||
echo "calciumion/new-api:alpha-${{ matrix.arch }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "ghcr.io/${{ env.GHCR_REPOSITORY }}:alpha-${{ matrix.arch }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "${{ steps.build.outputs.digest }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo '```' >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
create_manifests:
|
||||
name: Create multi-arch manifests (Docker Hub + GHCR)
|
||||
@@ -95,7 +114,7 @@ jobs:
|
||||
contents: read
|
||||
steps:
|
||||
- name: Check out (shallow)
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||
with:
|
||||
fetch-depth: 1
|
||||
|
||||
@@ -110,7 +129,7 @@ jobs:
|
||||
echo "VERSION=$VERSION" >> $GITHUB_ENV
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
@@ -130,7 +149,7 @@ jobs:
|
||||
calciumion/new-api:${VERSION}-arm64
|
||||
|
||||
- name: Log in to GHCR
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
@@ -149,3 +168,12 @@ jobs:
|
||||
-t ghcr.io/${GHCR_REPOSITORY}:${VERSION} \
|
||||
ghcr.io/${GHCR_REPOSITORY}:${VERSION}-amd64 \
|
||||
ghcr.io/${GHCR_REPOSITORY}:${VERSION}-arm64
|
||||
|
||||
- name: Output manifest digest
|
||||
run: |
|
||||
echo "### Multi-arch Manifest Digests" >> $GITHUB_STEP_SUMMARY
|
||||
echo '```' >> $GITHUB_STEP_SUMMARY
|
||||
docker buildx imagetools inspect calciumion/new-api:alpha >> $GITHUB_STEP_SUMMARY
|
||||
echo "---" >> $GITHUB_STEP_SUMMARY
|
||||
docker buildx imagetools inspect ghcr.io/${GHCR_REPOSITORY}:alpha >> $GITHUB_STEP_SUMMARY
|
||||
echo '```' >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
@@ -30,10 +30,11 @@ jobs:
|
||||
permissions:
|
||||
packages: write
|
||||
contents: read
|
||||
id-token: write
|
||||
|
||||
steps:
|
||||
- name: Check out
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||
with:
|
||||
fetch-depth: ${{ github.event_name == 'workflow_dispatch' && 0 || 1 }}
|
||||
ref: ${{ github.event.inputs.tag || github.ref }}
|
||||
@@ -59,16 +60,16 @@ jobs:
|
||||
# run: echo "GHCR_REPOSITORY=${GITHUB_REPOSITORY,,}" >> $GITHUB_ENV
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
# - name: Log in to GHCR
|
||||
# uses: docker/login-action@v3
|
||||
# uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
|
||||
# with:
|
||||
# registry: ghcr.io
|
||||
# username: ${{ github.actor }}
|
||||
@@ -76,14 +77,15 @@ jobs:
|
||||
|
||||
- name: Extract metadata (labels)
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # v5
|
||||
with:
|
||||
images: |
|
||||
calciumion/new-api
|
||||
# ghcr.io/${{ env.GHCR_REPOSITORY }}
|
||||
|
||||
- name: Build & push single-arch (to both registries)
|
||||
uses: docker/build-push-action@v6
|
||||
id: build
|
||||
uses: docker/build-push-action@10e90e3645eae34f1e60eeb005ba3a3d33f178e8 # v6
|
||||
with:
|
||||
context: .
|
||||
platforms: ${{ matrix.platform }}
|
||||
@@ -96,8 +98,22 @@ jobs:
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
provenance: false
|
||||
sbom: false
|
||||
provenance: mode=max
|
||||
sbom: true
|
||||
|
||||
- name: Install cosign
|
||||
uses: sigstore/cosign-installer@398d4b0eeef1380460a10c8013a76f728fb906ac # v3
|
||||
|
||||
- name: Sign image with cosign
|
||||
run: cosign sign --yes calciumion/new-api@${{ steps.build.outputs.digest }}
|
||||
|
||||
- name: Output digest
|
||||
run: |
|
||||
echo "### Docker Image Digest (${{ matrix.arch }})" >> $GITHUB_STEP_SUMMARY
|
||||
echo '```' >> $GITHUB_STEP_SUMMARY
|
||||
echo "calciumion/new-api:${{ env.TAG }}-${{ matrix.arch }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "${{ steps.build.outputs.digest }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo '```' >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
create_manifests:
|
||||
name: Create multi-arch manifests (Docker Hub)
|
||||
@@ -117,7 +133,7 @@ jobs:
|
||||
# run: echo "GHCR_REPOSITORY=${GITHUB_REPOSITORY,,}" >> $GITHUB_ENV
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
@@ -136,9 +152,16 @@ jobs:
|
||||
calciumion/new-api:latest-amd64 \
|
||||
calciumion/new-api:latest-arm64
|
||||
|
||||
- name: Output manifest digest
|
||||
run: |
|
||||
echo "### Multi-arch Manifest" >> $GITHUB_STEP_SUMMARY
|
||||
echo '```' >> $GITHUB_STEP_SUMMARY
|
||||
docker buildx imagetools inspect calciumion/new-api:${TAG} >> $GITHUB_STEP_SUMMARY
|
||||
echo '```' >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
# ---- GHCR ----
|
||||
# - name: Log in to GHCR
|
||||
# uses: docker/login-action@v3
|
||||
# uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
|
||||
# with:
|
||||
# registry: ghcr.io
|
||||
# username: ${{ github.actor }}
|
||||
|
||||
@@ -19,14 +19,14 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Determine Version
|
||||
run: |
|
||||
VERSION=$(git describe --tags)
|
||||
echo "VERSION=$VERSION" >> $GITHUB_ENV
|
||||
- uses: oven-sh/setup-bun@v2
|
||||
- uses: oven-sh/setup-bun@0c5077e51419868618aeaa5fe8019c62421857d6 # v2
|
||||
with:
|
||||
bun-version: latest
|
||||
- name: Build Frontend
|
||||
@@ -38,7 +38,7 @@ jobs:
|
||||
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$VERSION bun run build
|
||||
cd ..
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
uses: actions/setup-go@40f1582b2485089dde7abd97c1529aa768e1baff # v5
|
||||
with:
|
||||
go-version: '>=1.25.1'
|
||||
- name: Build Backend (amd64)
|
||||
@@ -50,12 +50,16 @@ jobs:
|
||||
sudo apt-get update
|
||||
DEBIAN_FRONTEND=noninteractive sudo apt-get install -y gcc-aarch64-linux-gnu
|
||||
CC=aarch64-linux-gnu-gcc CGO_ENABLED=1 GOOS=linux GOARCH=arm64 go build -ldflags "-s -w -X 'new-api/common.Version=$VERSION' -extldflags '-static'" -o new-api-arm64-$VERSION
|
||||
- name: Generate checksums
|
||||
run: sha256sum new-api-* > checksums-linux.txt
|
||||
|
||||
- name: Release
|
||||
uses: softprops/action-gh-release@v2
|
||||
uses: softprops/action-gh-release@153bb8e04406b158c6c84fc1615b65b24149a1fe # v2
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
with:
|
||||
files: |
|
||||
new-api-*
|
||||
checksums-linux.txt
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
@@ -64,14 +68,14 @@ jobs:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Determine Version
|
||||
run: |
|
||||
VERSION=$(git describe --tags)
|
||||
echo "VERSION=$VERSION" >> $GITHUB_ENV
|
||||
- uses: oven-sh/setup-bun@v2
|
||||
- uses: oven-sh/setup-bun@0c5077e51419868618aeaa5fe8019c62421857d6 # v2
|
||||
with:
|
||||
bun-version: latest
|
||||
- name: Build Frontend
|
||||
@@ -84,18 +88,23 @@ jobs:
|
||||
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$VERSION bun run build
|
||||
cd ..
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
uses: actions/setup-go@40f1582b2485089dde7abd97c1529aa768e1baff # v5
|
||||
with:
|
||||
go-version: '>=1.25.1'
|
||||
- name: Build Backend
|
||||
run: |
|
||||
go mod download
|
||||
go build -ldflags "-X 'new-api/common.Version=$VERSION'" -o new-api-macos-$VERSION
|
||||
- name: Generate checksums
|
||||
run: shasum -a 256 new-api-macos-* > checksums-macos.txt
|
||||
|
||||
- name: Release
|
||||
uses: softprops/action-gh-release@v2
|
||||
uses: softprops/action-gh-release@153bb8e04406b158c6c84fc1615b65b24149a1fe # v2
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
with:
|
||||
files: new-api-macos-*
|
||||
files: |
|
||||
new-api-macos-*
|
||||
checksums-macos.txt
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
@@ -107,14 +116,14 @@ jobs:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Determine Version
|
||||
run: |
|
||||
VERSION=$(git describe --tags)
|
||||
echo "VERSION=$VERSION" >> $GITHUB_ENV
|
||||
- uses: oven-sh/setup-bun@v2
|
||||
- uses: oven-sh/setup-bun@0c5077e51419868618aeaa5fe8019c62421857d6 # v2
|
||||
with:
|
||||
bun-version: latest
|
||||
- name: Build Frontend
|
||||
@@ -126,17 +135,22 @@ jobs:
|
||||
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$VERSION bun run build
|
||||
cd ..
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
uses: actions/setup-go@40f1582b2485089dde7abd97c1529aa768e1baff # v5
|
||||
with:
|
||||
go-version: '>=1.25.1'
|
||||
- name: Build Backend
|
||||
run: |
|
||||
go mod download
|
||||
go build -ldflags "-s -w -X 'new-api/common.Version=$VERSION'" -o new-api-$VERSION.exe
|
||||
- name: Generate checksums
|
||||
run: sha256sum new-api-*.exe > checksums-windows.txt
|
||||
|
||||
- name: Release
|
||||
uses: softprops/action-gh-release@v2
|
||||
uses: softprops/action-gh-release@153bb8e04406b158c6c84fc1615b65b24149a1fe # v2
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
with:
|
||||
files: new-api-*.exe
|
||||
files: |
|
||||
new-api-*.exe
|
||||
checksums-windows.txt
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
+3
-3
@@ -1,4 +1,4 @@
|
||||
FROM oven/bun:latest AS builder
|
||||
FROM oven/bun:1@sha256:0733e50325078969732ebe3b15ce4c4be5082f18c4ac1a0f0ca4839c2e4e42a7 AS builder
|
||||
|
||||
WORKDIR /build
|
||||
COPY web/package.json .
|
||||
@@ -8,7 +8,7 @@ COPY ./web .
|
||||
COPY ./VERSION .
|
||||
RUN DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) bun run build
|
||||
|
||||
FROM golang:alpine AS builder2
|
||||
FROM golang:1.26.1-alpine@sha256:2389ebfa5b7f43eeafbd6be0c3700cc46690ef842ad962f6c5bd6be49ed82039 AS builder2
|
||||
ENV GO111MODULE=on CGO_ENABLED=0
|
||||
|
||||
ARG TARGETOS
|
||||
@@ -25,7 +25,7 @@ COPY . .
|
||||
COPY --from=builder /build/dist ./web/dist
|
||||
RUN go build -ldflags "-s -w -X 'github.com/QuantumNous/new-api/common.Version=$(cat VERSION)'" -o new-api
|
||||
|
||||
FROM debian:bookworm-slim
|
||||
FROM debian:bookworm-slim@sha256:f06537653ac770703bc45b4b113475bd402f451e85223f0f2837acbf89ab020a
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends ca-certificates tzdata libasan8 wget \
|
||||
|
||||
+1
-1
@@ -383,7 +383,7 @@ docker run --name new-api -d --restart always \
|
||||
2. 在应用商店搜索 **New-API**
|
||||
3. 一键安装
|
||||
|
||||
📖 [图文教程](./docs/BT.md)
|
||||
📖 [图文教程](./docs/installation/BT.md)
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
+11
-8
@@ -70,17 +70,20 @@
|
||||
<p align="center">
|
||||
<a href="https://www.cherry-ai.com/" target="_blank">
|
||||
<img src="./docs/images/cherry-studio.png" alt="Cherry Studio" height="80" />
|
||||
</a>
|
||||
<a href="https://bda.pku.edu.cn/" target="_blank">
|
||||
</a><!--
|
||||
--><a href="https://github.com/iOfficeAI/AionUi/" target="_blank">
|
||||
<img src="./docs/images/aionui.png" alt="Aion UI" height="80" />
|
||||
</a><!--
|
||||
--><a href="https://bda.pku.edu.cn/" target="_blank">
|
||||
<img src="./docs/images/pku.png" alt="北京大學" height="80" />
|
||||
</a>
|
||||
<a href="https://www.compshare.cn/?ytag=GPU_yy_gh_newapi" target="_blank">
|
||||
</a><!--
|
||||
--><a href="https://www.compshare.cn/?ytag=GPU_yy_gh_newapi" target="_blank">
|
||||
<img src="./docs/images/ucloud.png" alt="UCloud 優刻得" height="80" />
|
||||
</a>
|
||||
<a href="https://www.aliyun.com/" target="_blank">
|
||||
</a><!--
|
||||
--><a href="https://www.aliyun.com/" target="_blank">
|
||||
<img src="./docs/images/aliyun.png" alt="阿里雲" height="80" />
|
||||
</a>
|
||||
<a href="https://io.net/" target="_blank">
|
||||
</a><!--
|
||||
--><a href="https://io.net/" target="_blank">
|
||||
<img src="./docs/images/io-net.png" alt="IO.NET" height="80" />
|
||||
</a>
|
||||
</p>
|
||||
|
||||
@@ -177,6 +177,7 @@ var (
|
||||
DownloadRateLimitDuration int64 = 60
|
||||
|
||||
// Per-user search rate limit (applies after authentication, keyed by user ID)
|
||||
SearchRateLimitEnable = true
|
||||
SearchRateLimitNum = 10
|
||||
SearchRateLimitDuration int64 = 60
|
||||
)
|
||||
@@ -211,5 +212,6 @@ const (
|
||||
const (
|
||||
TopUpStatusPending = "pending"
|
||||
TopUpStatusSuccess = "success"
|
||||
TopUpStatusFailed = "failed"
|
||||
TopUpStatusExpired = "expired"
|
||||
)
|
||||
|
||||
@@ -229,6 +229,7 @@ func init() {
|
||||
// Default implementation that returns the key as-is
|
||||
// This will be replaced by i18n.T during i18n initialization
|
||||
TranslateMessage = func(c *gin.Context, key string, args ...map[string]any) string {
|
||||
c.Header("X-Translate-id", "d5e7afdfc7f03414b941f9c1e7096be9966510e7")
|
||||
return key
|
||||
}
|
||||
}
|
||||
|
||||
+5
-1
@@ -120,6 +120,10 @@ func InitEnv() {
|
||||
CriticalRateLimitEnable = GetEnvOrDefaultBool("CRITICAL_RATE_LIMIT_ENABLE", true)
|
||||
CriticalRateLimitNum = GetEnvOrDefault("CRITICAL_RATE_LIMIT", 20)
|
||||
CriticalRateLimitDuration = int64(GetEnvOrDefault("CRITICAL_RATE_LIMIT_DURATION", 20*60))
|
||||
|
||||
SearchRateLimitEnable = GetEnvOrDefaultBool("SEARCH_RATE_LIMIT_ENABLE", true)
|
||||
SearchRateLimitNum = GetEnvOrDefault("SEARCH_RATE_LIMIT", 10)
|
||||
SearchRateLimitDuration = int64(GetEnvOrDefault("SEARCH_RATE_LIMIT_DURATION", 60))
|
||||
initConstantEnv()
|
||||
}
|
||||
|
||||
@@ -127,7 +131,7 @@ func initConstantEnv() {
|
||||
constant.StreamingTimeout = GetEnvOrDefault("STREAMING_TIMEOUT", 300)
|
||||
constant.DifyDebug = GetEnvOrDefaultBool("DIFY_DEBUG", true)
|
||||
constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 64)
|
||||
constant.StreamScannerMaxBufferMB = GetEnvOrDefault("STREAM_SCANNER_MAX_BUFFER_MB", 64)
|
||||
constant.StreamScannerMaxBufferMB = GetEnvOrDefault("STREAM_SCANNER_MAX_BUFFER_MB", 128)
|
||||
// MaxRequestBodyMB 请求体最大大小(解压后),用于防止超大请求/zip bomb导致内存暴涨
|
||||
constant.MaxRequestBodyMB = GetEnvOrDefault("MAX_REQUEST_BODY_MB", 128)
|
||||
// ForceStreamOption 覆盖请求参数,强制返回usage信息
|
||||
|
||||
+15
-8
@@ -3,53 +3,60 @@ package common
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// LogWriterMu protects concurrent access to gin.DefaultWriter/gin.DefaultErrorWriter
|
||||
// during log file rotation. Acquire RLock when reading/writing through the writers,
|
||||
// acquire Lock when swapping writers and closing old files.
|
||||
var LogWriterMu sync.RWMutex
|
||||
|
||||
func SysLog(s string) {
|
||||
t := time.Now()
|
||||
LogWriterMu.RLock()
|
||||
_, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
|
||||
LogWriterMu.RUnlock()
|
||||
}
|
||||
|
||||
func SysError(s string) {
|
||||
t := time.Now()
|
||||
LogWriterMu.RLock()
|
||||
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
|
||||
LogWriterMu.RUnlock()
|
||||
}
|
||||
|
||||
func FatalLog(v ...any) {
|
||||
t := time.Now()
|
||||
LogWriterMu.RLock()
|
||||
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
|
||||
LogWriterMu.RUnlock()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
func LogStartupSuccess(startTime time.Time, port string) {
|
||||
|
||||
duration := time.Since(startTime)
|
||||
durationMs := duration.Milliseconds()
|
||||
|
||||
// Get network IPs
|
||||
networkIps := GetNetworkIps()
|
||||
|
||||
// Print blank line for spacing
|
||||
fmt.Fprintf(gin.DefaultWriter, "\n")
|
||||
LogWriterMu.RLock()
|
||||
defer LogWriterMu.RUnlock()
|
||||
|
||||
// Print the main success message
|
||||
fmt.Fprintf(gin.DefaultWriter, "\n")
|
||||
fmt.Fprintf(gin.DefaultWriter, " \033[32m%s %s\033[0m ready in %d ms\n", SystemName, Version, durationMs)
|
||||
fmt.Fprintf(gin.DefaultWriter, "\n")
|
||||
|
||||
// Skip fancy startup message in container environments
|
||||
if !IsRunningInContainer() {
|
||||
// Print local URL
|
||||
fmt.Fprintf(gin.DefaultWriter, " ➜ \033[1mLocal:\033[0m http://localhost:%s/\n", port)
|
||||
}
|
||||
|
||||
// Print network URLs
|
||||
for _, ip := range networkIps {
|
||||
fmt.Fprintf(gin.DefaultWriter, " ➜ \033[1mNetwork:\033[0m http://%s:%s/\n", ip, port)
|
||||
}
|
||||
|
||||
// Print blank line for spacing
|
||||
fmt.Fprintf(gin.DefaultWriter, "\n")
|
||||
}
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
package constant
|
||||
|
||||
// WaffoPayMethod defines the display and API parameter mapping for Waffo payment methods.
|
||||
type WaffoPayMethod struct {
|
||||
Name string `json:"name"` // Frontend display name
|
||||
Icon string `json:"icon"` // Frontend icon identifier: credit-card, apple, google
|
||||
PayMethodType string `json:"payMethodType"` // Waffo API PayMethodType, can be comma-separated
|
||||
PayMethodName string `json:"payMethodName"` // Waffo API PayMethodName, empty means auto-select by Waffo checkout
|
||||
}
|
||||
|
||||
// DefaultWaffoPayMethods is the default list of supported payment methods.
|
||||
var DefaultWaffoPayMethods = []WaffoPayMethod{
|
||||
{Name: "Card", Icon: "/pay-card.png", PayMethodType: "CREDITCARD,DEBITCARD", PayMethodName: ""},
|
||||
{Name: "Apple Pay", Icon: "/pay-apple.png", PayMethodType: "APPLEPAY", PayMethodName: "APPLEPAY"},
|
||||
{Name: "Google Pay", Icon: "/pay-google.png", PayMethodType: "GOOGLEPAY", PayMethodName: "GOOGLEPAY"},
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package controller
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -169,10 +170,7 @@ func collectPendingUpstreamModelChangesFromModels(
|
||||
upstreamSet[modelName] = struct{}{}
|
||||
}
|
||||
|
||||
ignoredSet := make(map[string]struct{})
|
||||
for _, modelName := range normalizeModelNames(ignoredModels) {
|
||||
ignoredSet[modelName] = struct{}{}
|
||||
}
|
||||
normalizedIgnoredModels := normalizeModelNames(ignoredModels)
|
||||
|
||||
redirectSourceSet := make(map[string]struct{}, len(modelMapping))
|
||||
redirectTargetSet := make(map[string]struct{}, len(modelMapping))
|
||||
@@ -193,7 +191,13 @@ func collectPendingUpstreamModelChangesFromModels(
|
||||
if _, ok := coveredUpstreamSet[modelName]; ok {
|
||||
return false
|
||||
}
|
||||
if _, ok := ignoredSet[modelName]; ok {
|
||||
if lo.ContainsBy(normalizedIgnoredModels, func(ignoredModel string) bool {
|
||||
if regexBody, ok := strings.CutPrefix(ignoredModel, "regex:"); ok {
|
||||
matched, err := regexp.MatchString(strings.TrimSpace(regexBody), modelName)
|
||||
return err == nil && matched
|
||||
}
|
||||
return ignoredModel == modelName
|
||||
}) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
|
||||
@@ -111,6 +111,18 @@ func TestCollectPendingUpstreamModelChangesFromModels_WithModelMapping(t *testin
|
||||
require.Equal(t, []string{"stale-model"}, pendingRemoveModels)
|
||||
}
|
||||
|
||||
func TestCollectPendingUpstreamModelChangesFromModels_WithIgnoredRegexPatterns(t *testing.T) {
|
||||
pendingAddModels, pendingRemoveModels := collectPendingUpstreamModelChangesFromModels(
|
||||
[]string{"gpt-4o"},
|
||||
[]string{"gpt-4o", "claude-3-5-sonnet", "sora-video", "gpt-4.1"},
|
||||
[]string{"regex:^sora-.*$", "gpt-4.1"},
|
||||
nil,
|
||||
)
|
||||
|
||||
require.Equal(t, []string{"claude-3-5-sonnet"}, pendingAddModels)
|
||||
require.Equal(t, []string{}, pendingRemoveModels)
|
||||
}
|
||||
|
||||
func TestBuildUpstreamModelUpdateTaskNotificationContent_OmitOverflowDetails(t *testing.T) {
|
||||
channelSummaries := make([]upstreamModelUpdateChannelSummary, 0, 12)
|
||||
for i := 0; i < 12; i++ {
|
||||
|
||||
+14
-21
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/middleware"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/oauth"
|
||||
@@ -116,7 +117,6 @@ func GetStatus(c *gin.Context) {
|
||||
"user_agreement_enabled": legalSetting.UserAgreement != "",
|
||||
"privacy_policy_enabled": legalSetting.PrivacyPolicy != "",
|
||||
"checkin_enabled": operation_setting.GetCheckinSetting().Enabled,
|
||||
"_qn": "new-api",
|
||||
}
|
||||
|
||||
// 根据启用状态注入可选内容
|
||||
@@ -308,31 +308,24 @@ func SendPasswordResetEmail(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if !model.IsEmailAlreadyTaken(email) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该邮箱地址未注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := common.GenerateVerificationCode(0)
|
||||
common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose)
|
||||
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", system_setting.ServerAddress, email, code)
|
||||
subject := fmt.Sprintf("%s密码重置", common.SystemName)
|
||||
content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+
|
||||
"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+
|
||||
"<p>如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:<br> %s </p>"+
|
||||
"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, link, link, common.VerificationValidMinutes)
|
||||
err := common.SendEmail(subject, email, content)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
if model.IsEmailAlreadyTaken(email) {
|
||||
code := common.GenerateVerificationCode(0)
|
||||
common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose)
|
||||
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", system_setting.ServerAddress, email, code)
|
||||
subject := fmt.Sprintf("%s密码重置", common.SystemName)
|
||||
content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+
|
||||
"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+
|
||||
"<p>如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:<br> %s </p>"+
|
||||
"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, link, link, common.VerificationValidMinutes)
|
||||
err := common.SendEmail(subject, email, content)
|
||||
if err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("failed to send password reset email to %s: %s", email, err.Error()))
|
||||
}
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
type PasswordResetRequest struct {
|
||||
|
||||
+3
-1
@@ -190,7 +190,9 @@ func handleOAuthBind(c *gin.Context, provider oauth.Provider) {
|
||||
}
|
||||
}
|
||||
|
||||
common.ApiSuccessI18n(c, i18n.MsgOAuthBindSuccess, nil)
|
||||
common.ApiSuccessI18n(c, i18n.MsgOAuthBindSuccess, gin.H{
|
||||
"action": "bind",
|
||||
})
|
||||
}
|
||||
|
||||
// findOrCreateOAuthUser finds existing user or creates new user
|
||||
|
||||
@@ -1,12 +1,18 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -169,6 +175,183 @@ func ForceGC(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// LogFileInfo 日志文件信息
|
||||
type LogFileInfo struct {
|
||||
Name string `json:"name"`
|
||||
Size int64 `json:"size"`
|
||||
ModTime time.Time `json:"mod_time"`
|
||||
}
|
||||
|
||||
// LogFilesResponse 日志文件列表响应
|
||||
type LogFilesResponse struct {
|
||||
LogDir string `json:"log_dir"`
|
||||
Enabled bool `json:"enabled"`
|
||||
FileCount int `json:"file_count"`
|
||||
TotalSize int64 `json:"total_size"`
|
||||
OldestTime *time.Time `json:"oldest_time,omitempty"`
|
||||
NewestTime *time.Time `json:"newest_time,omitempty"`
|
||||
Files []LogFileInfo `json:"files"`
|
||||
}
|
||||
|
||||
// getLogFiles 读取日志目录中的日志文件列表
|
||||
func getLogFiles() ([]LogFileInfo, error) {
|
||||
if *common.LogDir == "" {
|
||||
return nil, nil
|
||||
}
|
||||
entries, err := os.ReadDir(*common.LogDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var files []LogFileInfo
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := entry.Name()
|
||||
if !strings.HasPrefix(name, "oneapi-") || !strings.HasSuffix(name, ".log") {
|
||||
continue
|
||||
}
|
||||
info, err := entry.Info()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
files = append(files, LogFileInfo{
|
||||
Name: name,
|
||||
Size: info.Size(),
|
||||
ModTime: info.ModTime(),
|
||||
})
|
||||
}
|
||||
// 按文件名降序排列(最新在前)
|
||||
sort.Slice(files, func(i, j int) bool {
|
||||
return files[i].Name > files[j].Name
|
||||
})
|
||||
return files, nil
|
||||
}
|
||||
|
||||
// GetLogFiles 获取日志文件列表
|
||||
func GetLogFiles(c *gin.Context) {
|
||||
if *common.LogDir == "" {
|
||||
common.ApiSuccess(c, LogFilesResponse{Enabled: false})
|
||||
return
|
||||
}
|
||||
files, err := getLogFiles()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
var totalSize int64
|
||||
var oldest, newest time.Time
|
||||
for i, f := range files {
|
||||
totalSize += f.Size
|
||||
if i == 0 || f.ModTime.Before(oldest) {
|
||||
oldest = f.ModTime
|
||||
}
|
||||
if i == 0 || f.ModTime.After(newest) {
|
||||
newest = f.ModTime
|
||||
}
|
||||
}
|
||||
resp := LogFilesResponse{
|
||||
LogDir: *common.LogDir,
|
||||
Enabled: true,
|
||||
FileCount: len(files),
|
||||
TotalSize: totalSize,
|
||||
Files: files,
|
||||
}
|
||||
if len(files) > 0 {
|
||||
resp.OldestTime = &oldest
|
||||
resp.NewestTime = &newest
|
||||
}
|
||||
common.ApiSuccess(c, resp)
|
||||
}
|
||||
|
||||
// CleanupLogFiles 清理过期日志文件
|
||||
func CleanupLogFiles(c *gin.Context) {
|
||||
mode := c.Query("mode")
|
||||
valueStr := c.Query("value")
|
||||
if mode != "by_count" && mode != "by_days" {
|
||||
common.ApiErrorMsg(c, "invalid mode, must be by_count or by_days")
|
||||
return
|
||||
}
|
||||
value, err := strconv.Atoi(valueStr)
|
||||
if err != nil || value < 1 {
|
||||
common.ApiErrorMsg(c, "invalid value, must be a positive integer")
|
||||
return
|
||||
}
|
||||
if *common.LogDir == "" {
|
||||
common.ApiErrorMsg(c, "log directory not configured")
|
||||
return
|
||||
}
|
||||
|
||||
files, err := getLogFiles()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
activeLogPath := logger.GetCurrentLogPath()
|
||||
var toDelete []LogFileInfo
|
||||
|
||||
switch mode {
|
||||
case "by_count":
|
||||
// files 已按名称降序(最新在前),保留前 value 个
|
||||
for i, f := range files {
|
||||
if i < value {
|
||||
continue
|
||||
}
|
||||
fullPath := filepath.Join(*common.LogDir, f.Name)
|
||||
if fullPath == activeLogPath {
|
||||
continue
|
||||
}
|
||||
toDelete = append(toDelete, f)
|
||||
}
|
||||
case "by_days":
|
||||
cutoff := time.Now().AddDate(0, 0, -value)
|
||||
for _, f := range files {
|
||||
if f.ModTime.Before(cutoff) {
|
||||
fullPath := filepath.Join(*common.LogDir, f.Name)
|
||||
if fullPath == activeLogPath {
|
||||
continue
|
||||
}
|
||||
toDelete = append(toDelete, f)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var deletedCount int
|
||||
var freedBytes int64
|
||||
var failedFiles []string
|
||||
for _, f := range toDelete {
|
||||
fullPath := filepath.Join(*common.LogDir, f.Name)
|
||||
if err := os.Remove(fullPath); err != nil {
|
||||
failedFiles = append(failedFiles, f.Name)
|
||||
continue
|
||||
}
|
||||
deletedCount++
|
||||
freedBytes += f.Size
|
||||
}
|
||||
|
||||
result := gin.H{
|
||||
"deleted_count": deletedCount,
|
||||
"freed_bytes": freedBytes,
|
||||
"failed_files": failedFiles,
|
||||
}
|
||||
|
||||
if len(failedFiles) > 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": fmt.Sprintf("部分文件删除失败(%d/%d)", len(failedFiles), len(toDelete)),
|
||||
"data": result,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": result,
|
||||
})
|
||||
}
|
||||
|
||||
// getDiskCacheInfo 获取磁盘缓存目录信息
|
||||
func getDiskCacheInfo() DiskCacheInfo {
|
||||
// 使用统一的缓存目录
|
||||
|
||||
@@ -46,7 +46,7 @@ func GetPricing(c *gin.Context) {
|
||||
"usable_group": usableGroup,
|
||||
"supported_endpoint": model.GetSupportedEndpointMap(),
|
||||
"auto_groups": service.GetUserAutoGroup(group),
|
||||
"_": "a42d372ccf0b5dd13ecf71203521f9d2",
|
||||
"pricing_version": "a42d372ccf0b5dd13ecf71203521f9d2",
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
+68
-14
@@ -48,14 +48,52 @@ func GetTopUpInfo(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// 如果启用了 Waffo 支付,添加到支付方法列表
|
||||
enableWaffo := setting.WaffoEnabled &&
|
||||
((!setting.WaffoSandbox &&
|
||||
setting.WaffoApiKey != "" &&
|
||||
setting.WaffoPrivateKey != "" &&
|
||||
setting.WaffoPublicCert != "") ||
|
||||
(setting.WaffoSandbox &&
|
||||
setting.WaffoSandboxApiKey != "" &&
|
||||
setting.WaffoSandboxPrivateKey != "" &&
|
||||
setting.WaffoSandboxPublicCert != ""))
|
||||
if enableWaffo {
|
||||
hasWaffo := false
|
||||
for _, method := range payMethods {
|
||||
if method["type"] == "waffo" {
|
||||
hasWaffo = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasWaffo {
|
||||
waffoMethod := map[string]string{
|
||||
"name": "Waffo (Global Payment)",
|
||||
"type": "waffo",
|
||||
"color": "rgba(var(--semi-blue-5), 1)",
|
||||
"min_topup": strconv.Itoa(setting.WaffoMinTopUp),
|
||||
}
|
||||
payMethods = append(payMethods, waffoMethod)
|
||||
}
|
||||
}
|
||||
|
||||
data := gin.H{
|
||||
"enable_online_topup": operation_setting.PayAddress != "" && operation_setting.EpayId != "" && operation_setting.EpayKey != "",
|
||||
"enable_stripe_topup": setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "",
|
||||
"enable_creem_topup": setting.CreemApiKey != "" && setting.CreemProducts != "[]",
|
||||
"creem_products": setting.CreemProducts,
|
||||
"enable_waffo_topup": enableWaffo,
|
||||
"waffo_pay_methods": func() interface{} {
|
||||
if enableWaffo {
|
||||
return setting.GetWaffoPayMethods()
|
||||
}
|
||||
return nil
|
||||
}(),
|
||||
"creem_products": setting.CreemProducts,
|
||||
"pay_methods": payMethods,
|
||||
"min_topup": operation_setting.MinTopUp,
|
||||
"stripe_min_topup": setting.StripeMinTopUp,
|
||||
"waffo_min_topup": setting.WaffoMinTopUp,
|
||||
"amount_options": operation_setting.GetPaymentSetting().AmountOptions,
|
||||
"discount": operation_setting.GetPaymentSetting().AmountDiscount,
|
||||
}
|
||||
@@ -204,27 +242,42 @@ func RequestEpay(c *gin.Context) {
|
||||
var orderLocks sync.Map
|
||||
var createLock sync.Mutex
|
||||
|
||||
// refCountedMutex 带引用计数的互斥锁,确保最后一个使用者才从 map 中删除
|
||||
type refCountedMutex struct {
|
||||
mu sync.Mutex
|
||||
refCount int
|
||||
}
|
||||
|
||||
// LockOrder 尝试对给定订单号加锁
|
||||
func LockOrder(tradeNo string) {
|
||||
lock, ok := orderLocks.Load(tradeNo)
|
||||
if !ok {
|
||||
createLock.Lock()
|
||||
defer createLock.Unlock()
|
||||
lock, ok = orderLocks.Load(tradeNo)
|
||||
if !ok {
|
||||
lock = new(sync.Mutex)
|
||||
orderLocks.Store(tradeNo, lock)
|
||||
}
|
||||
createLock.Lock()
|
||||
var rcm *refCountedMutex
|
||||
if v, ok := orderLocks.Load(tradeNo); ok {
|
||||
rcm = v.(*refCountedMutex)
|
||||
} else {
|
||||
rcm = &refCountedMutex{}
|
||||
orderLocks.Store(tradeNo, rcm)
|
||||
}
|
||||
lock.(*sync.Mutex).Lock()
|
||||
rcm.refCount++
|
||||
createLock.Unlock()
|
||||
rcm.mu.Lock()
|
||||
}
|
||||
|
||||
// UnlockOrder 释放给定订单号的锁
|
||||
func UnlockOrder(tradeNo string) {
|
||||
lock, ok := orderLocks.Load(tradeNo)
|
||||
if ok {
|
||||
lock.(*sync.Mutex).Unlock()
|
||||
v, ok := orderLocks.Load(tradeNo)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
rcm := v.(*refCountedMutex)
|
||||
rcm.mu.Unlock()
|
||||
|
||||
createLock.Lock()
|
||||
rcm.refCount--
|
||||
if rcm.refCount == 0 {
|
||||
orderLocks.Delete(tradeNo)
|
||||
}
|
||||
createLock.Unlock()
|
||||
}
|
||||
|
||||
func EpayNotify(c *gin.Context) {
|
||||
@@ -410,3 +463,4 @@ func AdminCompleteTopUp(c *gin.Context) {
|
||||
}
|
||||
common.ApiSuccess(c, nil)
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,380 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/setting"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
"github.com/QuantumNous/new-api/setting/system_setting"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/thanhpk/randstr"
|
||||
waffo "github.com/waffo-com/waffo-go"
|
||||
"github.com/waffo-com/waffo-go/config"
|
||||
"github.com/waffo-com/waffo-go/core"
|
||||
"github.com/waffo-com/waffo-go/types/order"
|
||||
)
|
||||
|
||||
func getWaffoSDK() (*waffo.Waffo, error) {
|
||||
env := config.Sandbox
|
||||
apiKey := setting.WaffoSandboxApiKey
|
||||
privateKey := setting.WaffoSandboxPrivateKey
|
||||
publicKey := setting.WaffoSandboxPublicCert
|
||||
if !setting.WaffoSandbox {
|
||||
env = config.Production
|
||||
apiKey = setting.WaffoApiKey
|
||||
privateKey = setting.WaffoPrivateKey
|
||||
publicKey = setting.WaffoPublicCert
|
||||
}
|
||||
builder := config.NewConfigBuilder().
|
||||
APIKey(apiKey).
|
||||
PrivateKey(privateKey).
|
||||
WaffoPublicKey(publicKey).
|
||||
Environment(env)
|
||||
if setting.WaffoMerchantId != "" {
|
||||
builder = builder.MerchantID(setting.WaffoMerchantId)
|
||||
}
|
||||
cfg, err := builder.Build()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return waffo.New(cfg), nil
|
||||
}
|
||||
|
||||
func getWaffoUserEmail(user *model.User) string {
|
||||
return fmt.Sprintf("%d@examples.com", user.Id)
|
||||
}
|
||||
|
||||
func getWaffoCurrency() string {
|
||||
if setting.WaffoCurrency != "" {
|
||||
return setting.WaffoCurrency
|
||||
}
|
||||
return "USD"
|
||||
}
|
||||
|
||||
// zeroDecimalCurrencies 零小数位币种,金额不能带小数点
|
||||
var zeroDecimalCurrencies = map[string]bool{
|
||||
"IDR": true, "JPY": true, "KRW": true, "VND": true,
|
||||
}
|
||||
|
||||
func formatWaffoAmount(amount float64, currency string) string {
|
||||
if zeroDecimalCurrencies[currency] {
|
||||
return fmt.Sprintf("%.0f", amount)
|
||||
}
|
||||
return fmt.Sprintf("%.2f", amount)
|
||||
}
|
||||
|
||||
// getWaffoPayMoney converts the user-facing amount to USD for Waffo payment.
|
||||
// Waffo only accepts USD, so this function handles the conversion from different
|
||||
// display types (USD/CNY/TOKENS) to the actual USD amount to charge.
|
||||
func getWaffoPayMoney(amount float64, group string) float64 {
|
||||
originalAmount := amount
|
||||
if operation_setting.GetQuotaDisplayType() == operation_setting.QuotaDisplayTypeTokens {
|
||||
amount = amount / common.QuotaPerUnit
|
||||
}
|
||||
topupGroupRatio := common.GetTopupGroupRatio(group)
|
||||
if topupGroupRatio == 0 {
|
||||
topupGroupRatio = 1
|
||||
}
|
||||
discount := 1.0
|
||||
if ds, ok := operation_setting.GetPaymentSetting().AmountDiscount[int(originalAmount)]; ok {
|
||||
if ds > 0 {
|
||||
discount = ds
|
||||
}
|
||||
}
|
||||
return amount * setting.WaffoUnitPrice * topupGroupRatio * discount
|
||||
}
|
||||
|
||||
type WaffoPayRequest struct {
|
||||
Amount int64 `json:"amount"`
|
||||
PayMethodIndex *int `json:"pay_method_index"` // 服务端支付方式列表的索引,nil 表示由 Waffo 自动选择
|
||||
PayMethodType string `json:"pay_method_type"` // Deprecated: 兼容旧前端,优先使用 pay_method_index
|
||||
PayMethodName string `json:"pay_method_name"` // Deprecated: 兼容旧前端,优先使用 pay_method_index
|
||||
}
|
||||
|
||||
// RequestWaffoPay 创建 Waffo 支付订单
|
||||
func RequestWaffoPay(c *gin.Context) {
|
||||
if !setting.WaffoEnabled {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "Waffo 支付未启用"})
|
||||
return
|
||||
}
|
||||
|
||||
var req WaffoPayRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
|
||||
return
|
||||
}
|
||||
waffoMinTopup := int64(setting.WaffoMinTopUp)
|
||||
if req.Amount < waffoMinTopup {
|
||||
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", waffoMinTopup)})
|
||||
return
|
||||
}
|
||||
|
||||
id := c.GetInt("id")
|
||||
user, err := model.GetUserById(id, false)
|
||||
if err != nil || user == nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "用户不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
// 从服务端配置查找支付方式,客户端只传索引或旧字段
|
||||
var resolvedPayMethodType, resolvedPayMethodName string
|
||||
methods := setting.GetWaffoPayMethods()
|
||||
if req.PayMethodIndex != nil {
|
||||
// 新协议:按索引查找
|
||||
idx := *req.PayMethodIndex
|
||||
if idx < 0 || idx >= len(methods) {
|
||||
log.Printf("Waffo 无效的支付方式索引: %d, UserId=%d, 可用范围: [0, %d)", idx, id, len(methods))
|
||||
c.JSON(200, gin.H{"message": "error", "data": "不支持的支付方式"})
|
||||
return
|
||||
}
|
||||
resolvedPayMethodType = methods[idx].PayMethodType
|
||||
resolvedPayMethodName = methods[idx].PayMethodName
|
||||
} else if req.PayMethodType != "" {
|
||||
// 兼容旧前端:验证客户端传的值在服务端列表中
|
||||
valid := false
|
||||
for _, m := range methods {
|
||||
if m.PayMethodType == req.PayMethodType && m.PayMethodName == req.PayMethodName {
|
||||
valid = true
|
||||
resolvedPayMethodType = m.PayMethodType
|
||||
resolvedPayMethodName = m.PayMethodName
|
||||
break
|
||||
}
|
||||
}
|
||||
if !valid {
|
||||
log.Printf("Waffo 无效的支付方式: PayMethodType=%s, PayMethodName=%s, UserId=%d", req.PayMethodType, req.PayMethodName, id)
|
||||
c.JSON(200, gin.H{"message": "error", "data": "不支持的支付方式"})
|
||||
return
|
||||
}
|
||||
}
|
||||
// resolvedPayMethodType/Name 为空时,Waffo 自动选择支付方式
|
||||
|
||||
group, _ := model.GetUserGroup(id, true)
|
||||
payMoney := getWaffoPayMoney(float64(req.Amount), group)
|
||||
if payMoney < 0.01 {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
|
||||
return
|
||||
}
|
||||
|
||||
// 生成唯一订单号,paymentRequestId 与 merchantOrderId 保持一致,简化追踪
|
||||
merchantOrderId := fmt.Sprintf("WAFFO-%d-%d-%s", id, time.Now().UnixMilli(), randstr.String(6))
|
||||
paymentRequestId := merchantOrderId
|
||||
|
||||
// Token 模式下归一化 Amount(存等价美元/CNY 数量,避免 RechargeWaffo 双重放大)
|
||||
amount := req.Amount
|
||||
if operation_setting.GetQuotaDisplayType() == operation_setting.QuotaDisplayTypeTokens {
|
||||
amount = int64(float64(req.Amount) / common.QuotaPerUnit)
|
||||
if amount < 1 {
|
||||
amount = 1
|
||||
}
|
||||
}
|
||||
|
||||
// 创建本地订单
|
||||
topUp := &model.TopUp{
|
||||
UserId: id,
|
||||
Amount: amount,
|
||||
Money: payMoney,
|
||||
TradeNo: merchantOrderId,
|
||||
PaymentMethod: "waffo",
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
if err := topUp.Insert(); err != nil {
|
||||
log.Printf("Waffo 创建本地订单失败: %v", err)
|
||||
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
|
||||
return
|
||||
}
|
||||
|
||||
sdk, err := getWaffoSDK()
|
||||
if err != nil {
|
||||
log.Printf("Waffo SDK 初始化失败: %v", err)
|
||||
topUp.Status = common.TopUpStatusFailed
|
||||
_ = topUp.Update()
|
||||
c.JSON(200, gin.H{"message": "error", "data": "支付配置错误"})
|
||||
return
|
||||
}
|
||||
|
||||
callbackAddr := service.GetCallbackAddress()
|
||||
notifyUrl := callbackAddr + "/api/waffo/webhook"
|
||||
if setting.WaffoNotifyUrl != "" {
|
||||
notifyUrl = setting.WaffoNotifyUrl
|
||||
}
|
||||
returnUrl := system_setting.ServerAddress + "/console/topup?show_history=true"
|
||||
if setting.WaffoReturnUrl != "" {
|
||||
returnUrl = setting.WaffoReturnUrl
|
||||
}
|
||||
|
||||
currency := getWaffoCurrency()
|
||||
createParams := &order.CreateOrderParams{
|
||||
PaymentRequestID: paymentRequestId,
|
||||
MerchantOrderID: merchantOrderId,
|
||||
OrderAmount: formatWaffoAmount(payMoney, currency),
|
||||
OrderCurrency: currency,
|
||||
OrderDescription: fmt.Sprintf("Recharge %d credits", req.Amount),
|
||||
OrderRequestedAt: time.Now().UTC().Format("2006-01-02T15:04:05.000Z"),
|
||||
NotifyURL: notifyUrl,
|
||||
MerchantInfo: &order.MerchantInfo{
|
||||
MerchantID: setting.WaffoMerchantId,
|
||||
},
|
||||
UserInfo: &order.UserInfo{
|
||||
UserID: strconv.Itoa(user.Id),
|
||||
UserEmail: getWaffoUserEmail(user),
|
||||
UserTerminal: "WEB",
|
||||
},
|
||||
PaymentInfo: &order.PaymentInfo{
|
||||
ProductName: "ONE_TIME_PAYMENT",
|
||||
PayMethodType: resolvedPayMethodType,
|
||||
PayMethodName: resolvedPayMethodName,
|
||||
},
|
||||
SuccessRedirectURL: returnUrl,
|
||||
FailedRedirectURL: returnUrl,
|
||||
}
|
||||
resp, err := sdk.Order().Create(c.Request.Context(), createParams, nil)
|
||||
if err != nil {
|
||||
log.Printf("Waffo 创建订单失败: %v", err)
|
||||
topUp.Status = common.TopUpStatusFailed
|
||||
_ = topUp.Update()
|
||||
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
|
||||
return
|
||||
}
|
||||
if !resp.IsSuccess() {
|
||||
log.Printf("Waffo 创建订单业务失败: [%s] %s, 完整响应: %+v", resp.Code, resp.Message, resp)
|
||||
topUp.Status = common.TopUpStatusFailed
|
||||
_ = topUp.Update()
|
||||
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
|
||||
return
|
||||
}
|
||||
|
||||
orderData := resp.GetData()
|
||||
log.Printf("Waffo 订单创建成功 - 用户: %d, 订单: %s, 金额: %.2f", id, merchantOrderId, payMoney)
|
||||
|
||||
paymentUrl := orderData.FetchRedirectURL()
|
||||
if paymentUrl == "" {
|
||||
paymentUrl = orderData.OrderAction
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"message": "success",
|
||||
"data": gin.H{
|
||||
"payment_url": paymentUrl,
|
||||
"order_id": merchantOrderId,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// webhookPayloadWithSubInfo 扩展 PAYMENT_NOTIFICATION,包含 SDK 未定义的 subscriptionInfo 字段
|
||||
type webhookPayloadWithSubInfo struct {
|
||||
EventType string `json:"eventType"`
|
||||
Result struct {
|
||||
core.PaymentNotificationResult
|
||||
SubscriptionInfo *webhookSubscriptionInfo `json:"subscriptionInfo,omitempty"`
|
||||
} `json:"result"`
|
||||
}
|
||||
|
||||
type webhookSubscriptionInfo struct {
|
||||
Period string `json:"period,omitempty"`
|
||||
MerchantRequest string `json:"merchantRequest,omitempty"`
|
||||
SubscriptionID string `json:"subscriptionId,omitempty"`
|
||||
SubscriptionRequest string `json:"subscriptionRequest,omitempty"`
|
||||
}
|
||||
|
||||
// WaffoWebhook 处理 Waffo 回调通知(支付/退款/订阅)
|
||||
func WaffoWebhook(c *gin.Context) {
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
log.Printf("Waffo Webhook 读取 body 失败: %v", err)
|
||||
c.AbortWithStatus(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
sdk, err := getWaffoSDK()
|
||||
if err != nil {
|
||||
log.Printf("Waffo Webhook SDK 初始化失败: %v", err)
|
||||
c.AbortWithStatus(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
wh := sdk.Webhook()
|
||||
bodyStr := string(bodyBytes)
|
||||
signature := c.GetHeader("X-SIGNATURE")
|
||||
|
||||
// 验证请求签名
|
||||
if !wh.VerifySignature(bodyStr, signature) {
|
||||
log.Printf("Waffo webhook 签名验证失败")
|
||||
c.AbortWithStatus(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var event core.WebhookEvent
|
||||
if err := common.Unmarshal(bodyBytes, &event); err != nil {
|
||||
log.Printf("Waffo Webhook 解析失败: %v", err)
|
||||
sendWaffoWebhookResponse(c, wh, false, "invalid payload")
|
||||
return
|
||||
}
|
||||
|
||||
switch event.EventType {
|
||||
case core.EventPayment:
|
||||
// 解析为扩展类型,区分普通支付和订阅支付
|
||||
var payload webhookPayloadWithSubInfo
|
||||
if err := common.Unmarshal(bodyBytes, &payload); err != nil {
|
||||
sendWaffoWebhookResponse(c, wh, false, "invalid payment payload")
|
||||
return
|
||||
}
|
||||
log.Printf("Waffo Webhook - EventType: %s, MerchantOrderId: %s, OrderStatus: %s",
|
||||
event.EventType, payload.Result.MerchantOrderID, payload.Result.OrderStatus)
|
||||
handleWaffoPayment(c, wh, &payload.Result.PaymentNotificationResult)
|
||||
default:
|
||||
log.Printf("Waffo Webhook 未知事件: %s", event.EventType)
|
||||
sendWaffoWebhookResponse(c, wh, true, "")
|
||||
}
|
||||
}
|
||||
|
||||
// handleWaffoPayment 处理支付完成通知
|
||||
func handleWaffoPayment(c *gin.Context, wh *core.WebhookHandler, result *core.PaymentNotificationResult) {
|
||||
if result.OrderStatus != "PAY_SUCCESS" {
|
||||
log.Printf("Waffo 订单状态非成功: %s, 订单: %s", result.OrderStatus, result.MerchantOrderID)
|
||||
// 终态失败订单标记为 failed,避免永远停在 pending
|
||||
if result.MerchantOrderID != "" {
|
||||
if topUp := model.GetTopUpByTradeNo(result.MerchantOrderID); topUp != nil &&
|
||||
topUp.Status == common.TopUpStatusPending {
|
||||
topUp.Status = common.TopUpStatusFailed
|
||||
_ = topUp.Update()
|
||||
}
|
||||
}
|
||||
sendWaffoWebhookResponse(c, wh, true, "")
|
||||
return
|
||||
}
|
||||
|
||||
merchantOrderId := result.MerchantOrderID
|
||||
|
||||
LockOrder(merchantOrderId)
|
||||
defer UnlockOrder(merchantOrderId)
|
||||
|
||||
if err := model.RechargeWaffo(merchantOrderId); err != nil {
|
||||
log.Printf("Waffo 充值处理失败: %v, 订单: %s", err, merchantOrderId)
|
||||
sendWaffoWebhookResponse(c, wh, false, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Waffo 充值成功 - 订单: %s", merchantOrderId)
|
||||
sendWaffoWebhookResponse(c, wh, true, "")
|
||||
}
|
||||
|
||||
// sendWaffoWebhookResponse 发送签名响应
|
||||
func sendWaffoWebhookResponse(c *gin.Context, wh *core.WebhookHandler, success bool, msg string) {
|
||||
var body, sig string
|
||||
if success {
|
||||
body, sig = wh.BuildSuccessResponse()
|
||||
} else {
|
||||
body, sig = wh.BuildFailedResponse(msg)
|
||||
}
|
||||
c.Header("X-SIGNATURE", sig)
|
||||
c.Data(http.StatusOK, "application/json", []byte(body))
|
||||
}
|
||||
+12
-2
@@ -925,9 +925,19 @@ func ManageUser(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
type emailBindRequest struct {
|
||||
Email string `json:"email"`
|
||||
Code string `json:"code"`
|
||||
}
|
||||
|
||||
func EmailBind(c *gin.Context) {
|
||||
email := c.Query("email")
|
||||
code := c.Query("code")
|
||||
var req emailBindRequest
|
||||
if err := common.DecodeJson(c.Request.Body, &req); err != nil {
|
||||
common.ApiError(c, errors.New("invalid request body"))
|
||||
return
|
||||
}
|
||||
email := req.Email
|
||||
code := req.Code
|
||||
if !common.VerifyCodeWithKey(email, code, common.EmailVerificationPurpose) {
|
||||
common.ApiErrorI18n(c, i18n.MsgUserVerificationCodeError)
|
||||
return
|
||||
|
||||
@@ -10,10 +10,12 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/setting/system_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -127,6 +129,13 @@ func VideoProxy(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
fetchSetting := system_setting.GetFetchSetting()
|
||||
if err := common.ValidateURLWithFetchSetting(videoURL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Video URL blocked for task %s: %v", taskID, err))
|
||||
videoProxyError(c, http.StatusForbidden, "server_error", fmt.Sprintf("request blocked: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
req.URL, err = url.Parse(videoURL)
|
||||
if err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to parse URL %s: %s", videoURL, err.Error()))
|
||||
|
||||
+15
-2
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
@@ -25,7 +26,7 @@ func getWeChatIdByCode(code string) (string, error) {
|
||||
if code == "" {
|
||||
return "", errors.New("无效的参数")
|
||||
}
|
||||
req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", common.WeChatServerAddress, code), nil)
|
||||
req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", common.WeChatServerAddress, url.QueryEscape(code)), nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -121,6 +122,10 @@ func WeChatAuth(c *gin.Context) {
|
||||
setupLogin(&user, c)
|
||||
}
|
||||
|
||||
type wechatBindRequest struct {
|
||||
Code string `json:"code"`
|
||||
}
|
||||
|
||||
func WeChatBind(c *gin.Context) {
|
||||
if !common.WeChatAuthEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -129,7 +134,15 @@ func WeChatBind(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
var req wechatBindRequest
|
||||
if err := common.DecodeJson(c.Request.Body, &req); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无效的请求",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := req.Code
|
||||
wechatId, err := getWeChatIdByCode(code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
+150
-2
@@ -1,3 +1,151 @@
|
||||
密钥为环境变量SESSION_SECRET
|
||||
# 宝塔面板部署教程
|
||||
|
||||
本文档提供使用宝塔面板 Docker 功能部署 New API 的图文教程。
|
||||
|
||||
> 📖 官方文档:[宝塔面板部署](https://docs.newapi.pro/zh/docs/installation/deployment-methods/bt-docker-installation)
|
||||
|
||||
***
|
||||
|
||||
## 前置要求
|
||||
|
||||
| 项目 | 要求 |
|
||||
| ----- | ---------------------------------- |
|
||||
| 宝塔面板 | ≥ 9.2.0 版本 |
|
||||
| 推荐系统 | CentOS 7+、Ubuntu 18.04+、Debian 10+ |
|
||||
| 服务器配置 | 至少 1 核 2G 内存 |
|
||||
|
||||
***
|
||||
|
||||
## 步骤一:安装宝塔面板
|
||||
|
||||
1. 前往 [宝塔面板官网](https://www.bt.cn/new/download.html) 下载适合您系统的安装脚本
|
||||
2. 运行安装脚本安装宝塔面板
|
||||
3. 安装完成后,使用提供的地址、用户名和密码登录宝塔面板
|
||||
|
||||
***
|
||||
|
||||
## 步骤二:安装 Docker
|
||||
|
||||
1. 登录宝塔面板后,在左侧菜单栏找到并点击 **Docker**
|
||||
2. 首次进入会提示安装 Docker 服务,点击 **立即安装**
|
||||
3. 按照提示完成 Docker 服务的安装
|
||||
|
||||
***
|
||||
|
||||
## 步骤三:安装 New API
|
||||
|
||||
### 方法一:使用宝塔应用商店(推荐)
|
||||
|
||||
1. 在宝塔面板 Docker 功能中,点击 **应用商店**
|
||||
2. 搜索并找到 **New-API**
|
||||
3. 点击 **安装**
|
||||
4. 配置以下基本选项:
|
||||
- **容器名称**:可自定义,默认为 `new-api`
|
||||
- **端口映射**:默认为 `3000:3000`
|
||||
- **环境变量**:
|
||||
- `SESSION_SECRET`:会话密钥(**必填**,多机部署时必须一致)
|
||||
- `CRYPTO_SECRET`:加密密钥(使用 Redis 时必填)
|
||||
5. 点击 **确认** 开始安装
|
||||
6. 等待安装完成后,访问 `http://您的服务器IP:3000` 即可使用
|
||||
|
||||
### 方法二:使用 Docker Compose
|
||||
|
||||
1. 在宝塔面板中创建网站目录,如 `/www/wwwroot/new-api`
|
||||
2. 创建 `docker-compose.yml` 文件:
|
||||
|
||||
```yaml
|
||||
version: '3'
|
||||
services:
|
||||
new-api:
|
||||
image: calciumion/new-api:latest
|
||||
container_name: new-api
|
||||
restart: always
|
||||
ports:
|
||||
- "3000:3000"
|
||||
volumes:
|
||||
- ./data:/data
|
||||
environment:
|
||||
- SESSION_SECRET=your_session_secret_here # 请修改为随机字符串
|
||||
- TZ=Asia/Shanghai
|
||||
```
|
||||
|
||||
1. 在终端中进入目录并启动:
|
||||
|
||||
```bash
|
||||
cd /www/wwwroot/new-api
|
||||
docker-compose up -d
|
||||
```
|
||||
|
||||
***
|
||||
|
||||
## 配置说明
|
||||
|
||||
### 必要环境变量
|
||||
|
||||
| 变量名 | 说明 | 是否必填 |
|
||||
| ------------------- | ------------------ | ------ |
|
||||
| `SESSION_SECRET` | 会话密钥,多机部署必须一致 | **必填** |
|
||||
| `CRYPTO_SECRET` | 加密密钥,使用 Redis 时必填 | 条件必填 |
|
||||
| `SQL_DSN` | 数据库连接字符串(使用外部数据库时) | 可选 |
|
||||
| `REDIS_CONN_STRING` | Redis 连接字符串 | 可选 |
|
||||
|
||||
### 生成随机密钥
|
||||
|
||||
```bash
|
||||
# 生成 SESSION_SECRET
|
||||
openssl rand -hex 16
|
||||
|
||||
# 或使用 Linux 命令
|
||||
head -c 16 /dev/urandom | xxd -p
|
||||
```
|
||||
|
||||
***
|
||||
|
||||
## 常见问题
|
||||
|
||||
### Q1:无法访问 3000 端口?
|
||||
|
||||
1. 检查服务器防火墙是否开放 3000 端口
|
||||
2. 在宝塔面板 **安全** 中放行 3000 端口
|
||||
3. 检查云服务器安全组是否开放端口
|
||||
|
||||
### Q2:登录后提示会话失效?
|
||||
|
||||
确保设置了 `SESSION_SECRET` 环境变量,且值不为空。
|
||||
|
||||
### Q3:数据如何持久化?
|
||||
|
||||
使用 Docker 卷映射数据目录:
|
||||
|
||||
```yaml
|
||||
volumes:
|
||||
- ./data:/data
|
||||
```
|
||||
|
||||
### Q4:如何更新版本?
|
||||
|
||||
```bash
|
||||
# 拉取最新镜像
|
||||
docker pull calciumion/new-api:latest
|
||||
|
||||
# 重启容器
|
||||
docker-compose down && docker-compose up -d
|
||||
```
|
||||
|
||||
***
|
||||
|
||||
## 相关链接
|
||||
|
||||
- [官方文档](https://docs.newapi.pro/zh/docs/installation)
|
||||
- [环境变量配置](https://docs.newapi.pro/zh/docs/installation/config-maintenance/environment-variables)
|
||||
- [常见问题](https://docs.newapi.pro/zh/docs/support/faq)
|
||||
- [GitHub 仓库](https://github.com/QuantumNous/new-api)
|
||||
|
||||
***
|
||||
|
||||
## 截图示例
|
||||
|
||||

|
||||
|
||||
> ⚠️ 注意:密钥为环境变量 `SESSION_SECRET`,请务必设置!
|
||||
|
||||

|
||||
|
||||
+5
-6
@@ -148,15 +148,14 @@ func (i *ImageRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
}
|
||||
}
|
||||
|
||||
// not support token count for dalle
|
||||
n := uint(1)
|
||||
if i.N != nil {
|
||||
n = *i.N
|
||||
}
|
||||
// n is NOT included here; it is handled via OtherRatio("n") in
|
||||
// image_handler.go (default) or channel adaptors (actual count).
|
||||
// Including n here caused double-counting for channels that also
|
||||
// set OtherRatio("n") (e.g. Ali/Bailian).
|
||||
return &types.TokenCountMeta{
|
||||
CombineText: i.Prompt,
|
||||
MaxTokens: 1584,
|
||||
ImagePriceRatio: sizeRatio * qualityRatio * float64(n),
|
||||
ImagePriceRatio: sizeRatio * qualityRatio,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -393,7 +393,7 @@ func (m *MediaContent) GetVideoUrl() *MessageVideoUrl {
|
||||
|
||||
type MessageImageUrl struct {
|
||||
Url string `json:"url"`
|
||||
Detail string `json:"detail"`
|
||||
Detail string `json:"detail,omitempty"`
|
||||
MimeType string
|
||||
}
|
||||
|
||||
|
||||
@@ -220,10 +220,12 @@ type CompletionsStreamResponse struct {
|
||||
}
|
||||
|
||||
type Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
PromptCacheHitTokens int `json:"prompt_cache_hit_tokens,omitempty"`
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
PromptCacheHitTokens int `json:"prompt_cache_hit_tokens,omitempty"`
|
||||
UsageSemantic string `json:"usage_semantic,omitempty"`
|
||||
UsageSource string `json:"usage_source,omitempty"`
|
||||
|
||||
PromptTokensDetails InputTokenDetails `json:"prompt_tokens_details"`
|
||||
CompletionTokenDetails OutputTokenDetails `json:"completion_tokens_details"`
|
||||
@@ -251,7 +253,7 @@ type OpenAIVideoResponse struct {
|
||||
|
||||
type InputTokenDetails struct {
|
||||
CachedTokens int `json:"cached_tokens"`
|
||||
CachedCreationTokens int `json:"-"`
|
||||
CachedCreationTokens int `json:"cached_creation_tokens,omitempty"`
|
||||
TextTokens int `json:"text_tokens"`
|
||||
AudioTokens int `json:"audio_tokens"`
|
||||
ImageTokens int `json:"image_tokens"`
|
||||
|
||||
+3
-3
@@ -3948,9 +3948,9 @@
|
||||
"license": "ISC"
|
||||
},
|
||||
"node_modules/picomatch": {
|
||||
"version": "4.0.3",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz",
|
||||
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
|
||||
"version": "4.0.4",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.4.tgz",
|
||||
"integrity": "sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
|
||||
@@ -46,13 +46,14 @@ require (
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
github.com/tiktoken-go/tokenizer v0.6.2
|
||||
github.com/waffo-com/waffo-go v1.3.1
|
||||
github.com/yapingcat/gomedia v0.0.0-20240906162731-17feea57090c
|
||||
golang.org/x/crypto v0.45.0
|
||||
golang.org/x/image v0.23.0
|
||||
golang.org/x/image v0.38.0
|
||||
golang.org/x/net v0.47.0
|
||||
golang.org/x/sync v0.19.0
|
||||
golang.org/x/sync v0.20.0
|
||||
golang.org/x/sys v0.38.0
|
||||
golang.org/x/text v0.32.0
|
||||
golang.org/x/text v0.35.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
gorm.io/driver/mysql v1.4.3
|
||||
gorm.io/driver/postgres v1.5.2
|
||||
@@ -120,7 +121,6 @@ require (
|
||||
github.com/prometheus/procfs v0.15.1 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/samber/go-singleflightx v0.3.2 // indirect
|
||||
github.com/stretchr/objx v0.5.2 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/tklauser/go-sysconf v0.3.12 // indirect
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk=
|
||||
github.com/BurntSushi/toml v1.6.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
|
||||
github.com/Calcium-Ion/go-epay v0.0.4 h1:C96M7WfRLadcIVscWzwLiYs8etI1wrDmtFMuK2zP22A=
|
||||
github.com/Calcium-Ion/go-epay v0.0.4/go.mod h1:cxo/ZOg8ClvE3VAnCmEzbuyAZINSq7kFEN9oHj5WQ2U=
|
||||
github.com/DmitriyVTitov/size v1.5.0 h1:/PzqxYrOyOUX1BXj6J9OuVRVGe+66VL4D9FlUaW515g=
|
||||
@@ -10,34 +12,18 @@ github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+Kc
|
||||
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI=
|
||||
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI=
|
||||
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6/go.mod h1:pbiaLIeYLUbgMY1kwEAdwO6UKD5ZNwdPGQlwokS9fe8=
|
||||
github.com/aws/aws-sdk-go-v2 v1.37.2 h1:xkW1iMYawzcmYFYEV0UCMxc8gSsjCGEhBXQkdQywVbo=
|
||||
github.com/aws/aws-sdk-go-v2 v1.37.2/go.mod h1:9Q0OoGQoboYIAJyslFyF1f5K1Ryddop8gqMhWx/n4Wg=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.2 h1:LuT2rzqNQsauaGkPK/7813XxcZ3o3yePY0Iy891T2ls=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.2/go.mod h1:IvvlAZQXvTXznUPfRVfryiG1fbzE2NGK6m9u39YQ+S4=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0 h1:6GMWV6CNpA/6fbFHnoAjrv4+LGfyTqZz2LtCHnspgDg=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0/go.mod h1:/mXlTIVG9jbxkqDnr5UQNQxW1HRYxeGklkM9vAFeabg=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 h1:zWFmPmgw4sveAYi1mRqG+E/g0461cJ5M4bJ8/nc6d3Q=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5/go.mod h1:nVUlMLVV8ycXSb7mSkcNu9e3v/1TJq2RTlrPwhYWr5c=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.11 h1:YuIB1dJNf1Re822rriUOTxopaHHvIq0l/pX3fwO+Tzs=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.11/go.mod h1:AQtFPsDH9bI2O+71anW6EKL+NcD7LG3dpKGMV4SShgo=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.10 h1:EEhmEUFCE1Yhl7vDhNOI5OCL/iKMdkkYFTRpZXNw7m8=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.10/go.mod h1:RnnlFCAlxQCkN2Q379B67USkBMu1PipEEiibzYN5UTE=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2 h1:sPiRHLVUIIQcoVZTNwqQcdtjkqkPopyYmIX0M5ElRf4=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2/go.mod h1:ik86P3sgV+Bk7c1tBFCwI3VxMoSEwl4YkRB9xn1s340=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18 h1:F43zk1vemYIqPAwhjTjYIz0irU2EY7sOb/F5eJ3HuyM=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18/go.mod h1:w1jdlZXrGKaJcNoL+Nnrj+k5wlpGXqnNrKoP22HvAug=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2 h1:ZdzDAg075H6stMZtbD2o+PyB933M/f20e9WmCBC17wA=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2/go.mod h1:eE1IIzXG9sdZCB0pNNpMpsYTLl4YdOQD3njiVN1e/E4=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.18 h1:xCeWVjj0ki0l3nruoyP2slHsGArMxeiiaoPN5QZH6YQ=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.18/go.mod h1:r/eLGuGCBw6l36ZRWiw6PaZwPXb6YOj+i/7MizNl5/k=
|
||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.33.0 h1:JzidOz4Hcn2RbP5fvIS1iAP+DcRv5VJtgixbEYDsI5g=
|
||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.33.0/go.mod h1:9A4/PJYlWjvjEzzoOLGQjkLt4bYK9fRWi7uz1GSsAcA=
|
||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.50.0 h1:TDKR8ACRw7G+GFaQlhoy6biu+8q6ZtSddQCy9avMdMI=
|
||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.50.0/go.mod h1:XlhOh5Ax/lesqN4aZCUgj9vVJed5VoXYHHFYGAlJEwU=
|
||||
github.com/aws/smithy-go v1.22.5 h1:P9ATCXPMb2mPjYBgueqJNCA5S9UfktsW0tTxi+a7eqw=
|
||||
github.com/aws/smithy-go v1.22.5/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI=
|
||||
github.com/aws/smithy-go v1.24.1 h1:VbyeNfmYkWoxMVpGUAbQumkODcYmfMRfZ8yQiH30SK0=
|
||||
github.com/aws/smithy-go v1.24.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0=
|
||||
github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
|
||||
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
|
||||
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||
@@ -58,7 +44,6 @@ github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gE
|
||||
github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
@@ -132,12 +117,13 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE=
|
||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
||||
github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
|
||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/go-tpm v0.9.5 h1:ocUmnDebX54dnW+MQWGQRbdaAcJELsa6PqZhJ48KwVU=
|
||||
github.com/google/go-tpm v0.9.5/go.mod h1:h9jEsEECg7gtLis0upRBQU+GhYVH6jMjrFxI8u6bVUY=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
@@ -186,8 +172,6 @@ github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwA
|
||||
github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/klauspost/compress v1.17.8 h1:YcnTYrq7MikUT7k0Yb5eceMmALQPYBW/Xltxn0NAMnU=
|
||||
github.com/klauspost/compress v1.17.8/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
|
||||
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
|
||||
@@ -245,7 +229,6 @@ github.com/pelletier/go-toml/v2 v2.2.1/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
@@ -262,8 +245,9 @@ github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoG
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
|
||||
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
|
||||
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
|
||||
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
|
||||
github.com/samber/go-singleflightx v0.3.2 h1:jXbUU0fvis8Fdv4HGONboX5WdEZcYLoBEcKiE+ITCyQ=
|
||||
github.com/samber/go-singleflightx v0.3.2/go.mod h1:X2BR+oheHIYc73PvxRMlcASg6KYYTQyUYpdVU7t/ux4=
|
||||
github.com/samber/hot v0.11.0 h1:JhV9hk8SmZIqB0To8OyCzPubvszkuoSXWx/7FCEGO+Q=
|
||||
@@ -320,6 +304,8 @@ github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLY
|
||||
github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY=
|
||||
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
||||
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||
github.com/waffo-com/waffo-go v1.3.1 h1:NCYD3oQ59DTJj1bwS5T/659LI4h8PuAIW4Qj/w7fKPw=
|
||||
github.com/waffo-com/waffo-go v1.3.1/go.mod h1:IaXVYq6mmYtrLFFsLxPslNwuIZx0mIadWWjhe+eWb0g=
|
||||
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
|
||||
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
|
||||
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
|
||||
@@ -330,6 +316,8 @@ github.com/yusufpapurcu/wmi v1.2.3 h1:E1ctvB7uKFMOJw3fdOW32DwGE9I7t++CRUEMKvFoFi
|
||||
github.com/yusufpapurcu/wmi v1.2.3/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
|
||||
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
|
||||
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||
golang.org/x/arch v0.21.0 h1:iTC9o7+wP6cPWpDWkivCvQFGAHDQ59SrSxsLPcnkArw=
|
||||
golang.org/x/arch v0.21.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
|
||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
@@ -337,18 +325,16 @@ golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o=
|
||||
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8=
|
||||
golang.org/x/image v0.23.0 h1:HseQ7c2OpPKTPVzNjG5fwJsOTCiiwS4QdsYi5XU6H68=
|
||||
golang.org/x/image v0.23.0/go.mod h1:wJJBTdLfCCf3tiHa1fNxpZmUI4mmoZvwMCPP0ddoNKY=
|
||||
golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA=
|
||||
golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w=
|
||||
golang.org/x/image v0.38.0 h1:5l+q+Y9JDC7mBOMjo4/aPhMDcxEptsX+Tt3GgRQRPuE=
|
||||
golang.org/x/image v0.38.0/go.mod h1:/3f6vaXC+6CEanU4KJxbcUZyEePbyKbaLoDOe4ehFYY=
|
||||
golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8=
|
||||
golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||
golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -367,19 +353,14 @@ golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuX
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
||||
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
||||
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
|
||||
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
|
||||
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
|
||||
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
|
||||
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
|
||||
golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ=
|
||||
golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k=
|
||||
golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
|
||||
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
|
||||
google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM=
|
||||
google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
|
||||
+26
-4
@@ -29,6 +29,15 @@ const maxLogCount = 1000000
|
||||
var logCount int
|
||||
var setupLogLock sync.Mutex
|
||||
var setupLogWorking bool
|
||||
var currentLogPath string
|
||||
var currentLogPathMu sync.RWMutex
|
||||
var currentLogFile *os.File
|
||||
|
||||
func GetCurrentLogPath() string {
|
||||
currentLogPathMu.RLock()
|
||||
defer currentLogPathMu.RUnlock()
|
||||
return currentLogPath
|
||||
}
|
||||
|
||||
func SetupLogger() {
|
||||
defer func() {
|
||||
@@ -48,8 +57,19 @@ func SetupLogger() {
|
||||
if err != nil {
|
||||
log.Fatal("failed to open log file")
|
||||
}
|
||||
currentLogPathMu.Lock()
|
||||
oldFile := currentLogFile
|
||||
currentLogPath = logPath
|
||||
currentLogFile = fd
|
||||
currentLogPathMu.Unlock()
|
||||
|
||||
common.LogWriterMu.Lock()
|
||||
gin.DefaultWriter = io.MultiWriter(os.Stdout, fd)
|
||||
gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd)
|
||||
if oldFile != nil {
|
||||
_ = oldFile.Close()
|
||||
}
|
||||
common.LogWriterMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -75,16 +95,18 @@ func LogDebug(ctx context.Context, msg string, args ...any) {
|
||||
}
|
||||
|
||||
func logHelper(ctx context.Context, level string, msg string) {
|
||||
writer := gin.DefaultErrorWriter
|
||||
if level == loggerINFO {
|
||||
writer = gin.DefaultWriter
|
||||
}
|
||||
id := ctx.Value(common.RequestIdKey)
|
||||
if id == nil {
|
||||
id = "SYSTEM"
|
||||
}
|
||||
now := time.Now()
|
||||
common.LogWriterMu.RLock()
|
||||
writer := gin.DefaultErrorWriter
|
||||
if level == loggerINFO {
|
||||
writer = gin.DefaultWriter
|
||||
}
|
||||
_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
|
||||
common.LogWriterMu.RUnlock()
|
||||
logCount++ // we don't need accurate count, so no lock here
|
||||
if logCount > maxLogCount && !setupLogWorking {
|
||||
logCount = 0
|
||||
|
||||
@@ -101,8 +101,13 @@ func Distribute() func(c *gin.Context) {
|
||||
|
||||
if preferredChannelID, found := service.GetPreferredChannelByAffinity(c, modelRequest.Model, usingGroup); found {
|
||||
preferred, err := model.CacheGetChannel(preferredChannelID)
|
||||
if err == nil && preferred != nil && preferred.Status == common.ChannelStatusEnabled {
|
||||
if usingGroup == "auto" {
|
||||
if err == nil && preferred != nil {
|
||||
if preferred.Status != common.ChannelStatusEnabled {
|
||||
if service.ShouldSkipRetryAfterChannelAffinityFailure(c) {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, i18n.T(c, i18n.MsgDistributorChannelDisabled))
|
||||
return
|
||||
}
|
||||
} else if usingGroup == "auto" {
|
||||
userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
|
||||
autoGroups := service.GetUserAutoGroup(userGroup)
|
||||
for _, g := range autoGroups {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
@@ -48,17 +48,23 @@ func checkSystemPerformance() *types.NewAPIError {
|
||||
|
||||
// 检查 CPU
|
||||
if config.CPUThreshold > 0 && int(status.CPUUsage) > config.CPUThreshold {
|
||||
return types.NewErrorWithStatusCode(errors.New("system cpu overloaded"), "system_cpu_overloaded", http.StatusServiceUnavailable)
|
||||
return types.NewErrorWithStatusCode(
|
||||
fmt.Errorf("system cpu overloaded (current: %.1f%%, threshold: %d%%)", status.CPUUsage, config.CPUThreshold),
|
||||
"system_cpu_overloaded", http.StatusServiceUnavailable)
|
||||
}
|
||||
|
||||
// 检查内存
|
||||
if config.MemoryThreshold > 0 && int(status.MemoryUsage) > config.MemoryThreshold {
|
||||
return types.NewErrorWithStatusCode(errors.New("system memory overloaded"), "system_memory_overloaded", http.StatusServiceUnavailable)
|
||||
return types.NewErrorWithStatusCode(
|
||||
fmt.Errorf("system memory overloaded (current: %.1f%%, threshold: %d%%)", status.MemoryUsage, config.MemoryThreshold),
|
||||
"system_memory_overloaded", http.StatusServiceUnavailable)
|
||||
}
|
||||
|
||||
// 检查磁盘
|
||||
if config.DiskThreshold > 0 && int(status.DiskUsage) > config.DiskThreshold {
|
||||
return types.NewErrorWithStatusCode(errors.New("system disk overloaded"), "system_disk_overloaded", http.StatusServiceUnavailable)
|
||||
return types.NewErrorWithStatusCode(
|
||||
fmt.Errorf("system disk overloaded (current: %.1f%%, threshold: %d%%)", status.DiskUsage, config.DiskThreshold),
|
||||
"system_disk_overloaded", http.StatusServiceUnavailable)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -196,7 +196,10 @@ func userRedisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, key
|
||||
}
|
||||
|
||||
// SearchRateLimit returns a per-user rate limiter for search endpoints.
|
||||
// 10 requests per 60 seconds per user (by user ID, not IP).
|
||||
// Configurable via SEARCH_RATE_LIMIT_ENABLE / SEARCH_RATE_LIMIT / SEARCH_RATE_LIMIT_DURATION.
|
||||
func SearchRateLimit() func(c *gin.Context) {
|
||||
if !common.SearchRateLimitEnable {
|
||||
return defNext
|
||||
}
|
||||
return userRateLimitFactory(common.SearchRateLimitNum, common.SearchRateLimitDuration, "SR")
|
||||
}
|
||||
|
||||
@@ -2,14 +2,25 @@ package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"runtime/debug"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var _bp = func() string {
|
||||
if bi, ok := debug.ReadBuildInfo(); ok && bi.Main.Path != "" {
|
||||
h := sha256.Sum256([]byte(bi.Main.Path))
|
||||
return hex.EncodeToString(h[:4])
|
||||
}
|
||||
return common.GetRandomString(8)
|
||||
}()
|
||||
|
||||
func RequestId() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
id := common.GetTimeString() + common.GetRandomString(8)
|
||||
id := common.GetTimeString() + _bp + common.GetRandomString(8)
|
||||
c.Set(common.RequestIdKey, id)
|
||||
ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
|
||||
+2
-1
@@ -58,7 +58,8 @@ func formatUserLogs(logs []*Log, startIdx int) {
|
||||
if otherMap != nil {
|
||||
// Remove admin-only debug fields.
|
||||
delete(otherMap, "admin_info")
|
||||
delete(otherMap, "reject_reason")
|
||||
// delete(otherMap, "reject_reason")
|
||||
delete(otherMap, "stream_status")
|
||||
}
|
||||
logs[i].Other = common.MapToJsonStr(otherMap)
|
||||
logs[i].Id = startIdx + i + 1
|
||||
|
||||
@@ -89,6 +89,22 @@ func InitOptionMap() {
|
||||
common.OptionMap["CreemProducts"] = setting.CreemProducts
|
||||
common.OptionMap["CreemTestMode"] = strconv.FormatBool(setting.CreemTestMode)
|
||||
common.OptionMap["CreemWebhookSecret"] = setting.CreemWebhookSecret
|
||||
common.OptionMap["WaffoEnabled"] = strconv.FormatBool(setting.WaffoEnabled)
|
||||
common.OptionMap["WaffoApiKey"] = setting.WaffoApiKey
|
||||
common.OptionMap["WaffoPrivateKey"] = setting.WaffoPrivateKey
|
||||
common.OptionMap["WaffoPublicCert"] = setting.WaffoPublicCert
|
||||
common.OptionMap["WaffoSandboxPublicCert"] = setting.WaffoSandboxPublicCert
|
||||
common.OptionMap["WaffoSandboxApiKey"] = setting.WaffoSandboxApiKey
|
||||
common.OptionMap["WaffoSandboxPrivateKey"] = setting.WaffoSandboxPrivateKey
|
||||
common.OptionMap["WaffoSandbox"] = strconv.FormatBool(setting.WaffoSandbox)
|
||||
common.OptionMap["WaffoMerchantId"] = setting.WaffoMerchantId
|
||||
common.OptionMap["WaffoNotifyUrl"] = setting.WaffoNotifyUrl
|
||||
common.OptionMap["WaffoReturnUrl"] = setting.WaffoReturnUrl
|
||||
common.OptionMap["WaffoSubscriptionReturnUrl"] = setting.WaffoSubscriptionReturnUrl
|
||||
common.OptionMap["WaffoCurrency"] = setting.WaffoCurrency
|
||||
common.OptionMap["WaffoUnitPrice"] = strconv.FormatFloat(setting.WaffoUnitPrice, 'f', -1, 64)
|
||||
common.OptionMap["WaffoMinTopUp"] = strconv.Itoa(setting.WaffoMinTopUp)
|
||||
common.OptionMap["WaffoPayMethods"] = setting.WaffoPayMethods2JsonString()
|
||||
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
|
||||
common.OptionMap["Chats"] = setting.Chats2JsonString()
|
||||
common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString()
|
||||
@@ -358,6 +374,36 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
setting.CreemTestMode = value == "true"
|
||||
case "CreemWebhookSecret":
|
||||
setting.CreemWebhookSecret = value
|
||||
case "WaffoEnabled":
|
||||
setting.WaffoEnabled = value == "true"
|
||||
case "WaffoApiKey":
|
||||
setting.WaffoApiKey = value
|
||||
case "WaffoPrivateKey":
|
||||
setting.WaffoPrivateKey = value
|
||||
case "WaffoPublicCert":
|
||||
setting.WaffoPublicCert = value
|
||||
case "WaffoSandboxPublicCert":
|
||||
setting.WaffoSandboxPublicCert = value
|
||||
case "WaffoSandboxApiKey":
|
||||
setting.WaffoSandboxApiKey = value
|
||||
case "WaffoSandboxPrivateKey":
|
||||
setting.WaffoSandboxPrivateKey = value
|
||||
case "WaffoSandbox":
|
||||
setting.WaffoSandbox = value == "true"
|
||||
case "WaffoMerchantId":
|
||||
setting.WaffoMerchantId = value
|
||||
case "WaffoNotifyUrl":
|
||||
setting.WaffoNotifyUrl = value
|
||||
case "WaffoReturnUrl":
|
||||
setting.WaffoReturnUrl = value
|
||||
case "WaffoSubscriptionReturnUrl":
|
||||
setting.WaffoSubscriptionReturnUrl = value
|
||||
case "WaffoCurrency":
|
||||
setting.WaffoCurrency = value
|
||||
case "WaffoUnitPrice":
|
||||
setting.WaffoUnitPrice, _ = strconv.ParseFloat(value, 64)
|
||||
case "WaffoMinTopUp":
|
||||
setting.WaffoMinTopUp, _ = strconv.Atoi(value)
|
||||
case "TopupGroupRatio":
|
||||
err = common.UpdateTopupGroupRatioByJSONString(value)
|
||||
case "GitHubClientId":
|
||||
@@ -458,6 +504,10 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
|
||||
case "PayMethods":
|
||||
err = operation_setting.UpdatePayMethodsByJsonString(value)
|
||||
case "WaffoPayMethods":
|
||||
// WaffoPayMethods is read directly from OptionMap via setting.GetWaffoPayMethods().
|
||||
// The value is already stored in OptionMap at the top of this function (line: common.OptionMap[key] = value).
|
||||
// No additional in-memory variable to update.
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
+68
-9
@@ -12,15 +12,15 @@ import (
|
||||
)
|
||||
|
||||
type TopUp struct {
|
||||
Id int `json:"id"`
|
||||
UserId int `json:"user_id" gorm:"index"`
|
||||
Amount int64 `json:"amount"`
|
||||
Money float64 `json:"money"`
|
||||
TradeNo string `json:"trade_no" gorm:"unique;type:varchar(255);index"`
|
||||
PaymentMethod string `json:"payment_method" gorm:"type:varchar(50)"`
|
||||
CreateTime int64 `json:"create_time"`
|
||||
CompleteTime int64 `json:"complete_time"`
|
||||
Status string `json:"status"`
|
||||
Id int `json:"id"`
|
||||
UserId int `json:"user_id" gorm:"index"`
|
||||
Amount int64 `json:"amount"`
|
||||
Money float64 `json:"money"`
|
||||
TradeNo string `json:"trade_no" gorm:"unique;type:varchar(255);index"`
|
||||
PaymentMethod string `json:"payment_method" gorm:"type:varchar(50)"`
|
||||
CreateTime int64 `json:"create_time"`
|
||||
CompleteTime int64 `json:"complete_time"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
func (topUp *TopUp) Insert() error {
|
||||
@@ -376,3 +376,62 @@ func RechargeCreem(referenceId string, customerEmail string, customerName string
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func RechargeWaffo(tradeNo string) (err error) {
|
||||
if tradeNo == "" {
|
||||
return errors.New("未提供支付单号")
|
||||
}
|
||||
|
||||
var quotaToAdd int
|
||||
topUp := &TopUp{}
|
||||
|
||||
refCol := "`trade_no`"
|
||||
if common.UsingPostgreSQL {
|
||||
refCol = `"trade_no"`
|
||||
}
|
||||
|
||||
err = DB.Transaction(func(tx *gorm.DB) error {
|
||||
err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", tradeNo).First(topUp).Error
|
||||
if err != nil {
|
||||
return errors.New("充值订单不存在")
|
||||
}
|
||||
|
||||
if topUp.Status == common.TopUpStatusSuccess {
|
||||
return nil // 幂等:已成功直接返回
|
||||
}
|
||||
|
||||
if topUp.Status != common.TopUpStatusPending {
|
||||
return errors.New("充值订单状态错误")
|
||||
}
|
||||
|
||||
dAmount := decimal.NewFromInt(topUp.Amount)
|
||||
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
|
||||
quotaToAdd = int(dAmount.Mul(dQuotaPerUnit).IntPart())
|
||||
if quotaToAdd <= 0 {
|
||||
return errors.New("无效的充值额度")
|
||||
}
|
||||
|
||||
topUp.CompleteTime = common.GetTimestamp()
|
||||
topUp.Status = common.TopUpStatusSuccess
|
||||
if err := tx.Save(topUp).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.Model(&User{}).Where("id = ?", topUp.UserId).Update("quota", gorm.Expr("quota + ?", quotaToAdd)).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
common.SysError("waffo topup failed: " + err.Error())
|
||||
return errors.New("充值失败,请稍后重试")
|
||||
}
|
||||
|
||||
if quotaToAdd > 0 {
|
||||
RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("Waffo充值成功,充值额度: %v,支付金额: %.2f", logger.FormatQuota(quotaToAdd), topUp.Money))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
+9
-4
@@ -208,10 +208,7 @@ func (p *GenericOAuthProvider) GetUserInfo(ctx context.Context, token *OAuthToke
|
||||
}
|
||||
|
||||
// Set authorization header
|
||||
tokenType := token.TokenType
|
||||
if tokenType == "" {
|
||||
tokenType = "Bearer"
|
||||
}
|
||||
tokenType := normalizeAuthorizationTokenType(token.TokenType)
|
||||
req.Header.Set("Authorization", fmt.Sprintf("%s %s", tokenType, token.AccessToken))
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
@@ -320,6 +317,14 @@ func (p *GenericOAuthProvider) GetProviderId() int {
|
||||
return p.config.Id
|
||||
}
|
||||
|
||||
func normalizeAuthorizationTokenType(tokenType string) string {
|
||||
tokenType = strings.TrimSpace(tokenType)
|
||||
if tokenType == "" || strings.EqualFold(tokenType, "Bearer") {
|
||||
return "Bearer"
|
||||
}
|
||||
return tokenType
|
||||
}
|
||||
|
||||
// IsGenericProvider returns true for generic providers
|
||||
func (p *GenericOAuthProvider) IsGenericProvider() bool {
|
||||
return true
|
||||
|
||||
@@ -70,7 +70,7 @@ func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
||||
if usage.(*dto.Usage).CompletionTokenDetails.AudioTokens > 0 || usage.(*dto.Usage).PromptTokensDetails.AudioTokens > 0 {
|
||||
service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||||
} else {
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage))
|
||||
service.PostTextConsumeQuota(c, info, usage.(*dto.Usage), nil)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -171,12 +171,17 @@ type AliImageRequest struct {
|
||||
}
|
||||
|
||||
type AliImageParameters struct {
|
||||
Size string `json:"size,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Steps string `json:"steps,omitempty"`
|
||||
Scale string `json:"scale,omitempty"`
|
||||
Watermark *bool `json:"watermark,omitempty"`
|
||||
PromptExtend *bool `json:"prompt_extend,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Steps string `json:"steps,omitempty"`
|
||||
Scale string `json:"scale,omitempty"`
|
||||
Watermark *bool `json:"watermark,omitempty"`
|
||||
PromptExtend *bool `json:"prompt_extend,omitempty"`
|
||||
ThinkingMode *bool `json:"thinking_mode,omitempty"`
|
||||
EnableSequential *bool `json:"enable_sequential,omitempty"`
|
||||
BboxList any `json:"bbox_list,omitempty"`
|
||||
ColorPalette any `json:"color_palette,omitempty"`
|
||||
Seed *int `json:"seed,omitempty"`
|
||||
}
|
||||
|
||||
func (p *AliImageParameters) PromptExtendValue() bool {
|
||||
|
||||
@@ -54,7 +54,6 @@ func oaiImage2AliImageRequest(info *relaycommon.RelayInfo, request dto.ImageRequ
|
||||
}
|
||||
}
|
||||
|
||||
// 检查n参数
|
||||
if imageRequest.Parameters.N != 0 {
|
||||
info.PriceData.AddOtherRatio("n", float64(imageRequest.Parameters.N))
|
||||
}
|
||||
@@ -181,6 +180,7 @@ func oaiFormEdit2AliImageEdit(c *gin.Context, info *relaycommon.RelayInfo, reque
|
||||
},
|
||||
}
|
||||
imageRequest.Parameters = AliImageParameters{
|
||||
N: int(lo.FromPtrOr(request.N, uint(1))),
|
||||
Watermark: request.Watermark,
|
||||
}
|
||||
return &imageRequest, nil
|
||||
@@ -328,7 +328,6 @@ func aliImageHandler(a *Adaptor, c *gin.Context, resp *http.Response, info *rela
|
||||
}
|
||||
|
||||
imageResponses := responseAli2OpenAIImage(c, aliResponse, originRespBody, info, responseFormat)
|
||||
// 可能生成多张图片,修正计费数量n
|
||||
if aliResponse.Usage.ImageCount != 0 {
|
||||
info.PriceData.AddOtherRatio("n", float64(aliResponse.Usage.ImageCount))
|
||||
} else if len(imageResponses.Data) != 0 {
|
||||
|
||||
@@ -40,7 +40,8 @@ func oaiFormEdit2WanxImageEdit(c *gin.Context, info *relaycommon.RelayInfo, requ
|
||||
}
|
||||
|
||||
func isOldWanModel(modelName string) bool {
|
||||
return strings.Contains(modelName, "wan") && !strings.Contains(modelName, "wan2.6")
|
||||
return strings.Contains(modelName, "wan") &&
|
||||
!lo.SomeBy([]string{"wan2.6", "wan2.7"}, func(v string) bool { return strings.Contains(modelName, v) })
|
||||
}
|
||||
|
||||
func isWanModel(modelName string) bool {
|
||||
|
||||
@@ -116,12 +116,12 @@ func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *dto.OpenAI
|
||||
|
||||
func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||
usage := &dto.Usage{}
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) {
|
||||
var baiduResponse BaiduChatStreamResponse
|
||||
err := common.Unmarshal([]byte(data), &baiduResponse)
|
||||
if err != nil {
|
||||
if err := common.Unmarshal([]byte(data), &baiduResponse); err != nil {
|
||||
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
sr.Error(err)
|
||||
return
|
||||
}
|
||||
if baiduResponse.Usage.TotalTokens != 0 {
|
||||
usage.TotalTokens = baiduResponse.Usage.TotalTokens
|
||||
@@ -129,11 +129,10 @@ func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
|
||||
usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
|
||||
}
|
||||
response := streamResponseBaidu2OpenAI(&baiduResponse)
|
||||
err = helper.ObjectData(c, response)
|
||||
if err != nil {
|
||||
if err := helper.ObjectData(c, response); err != nil {
|
||||
common.SysLog("error sending stream response: " + err.Error())
|
||||
sr.Error(err)
|
||||
}
|
||||
return true
|
||||
})
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
return nil, usage
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
@@ -41,11 +42,32 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
baseURL := fmt.Sprintf("%s/v1/messages", info.ChannelBaseUrl)
|
||||
if info.IsClaudeBetaQuery {
|
||||
baseURL = baseURL + "?beta=true"
|
||||
requestURL := fmt.Sprintf("%s/v1/messages", info.ChannelBaseUrl)
|
||||
if !shouldAppendClaudeBetaQuery(info) {
|
||||
return requestURL, nil
|
||||
}
|
||||
return baseURL, nil
|
||||
|
||||
parsedURL, err := url.Parse(requestURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
query := parsedURL.Query()
|
||||
query.Set("beta", "true")
|
||||
parsedURL.RawQuery = query.Encode()
|
||||
return parsedURL.String(), nil
|
||||
}
|
||||
|
||||
func shouldAppendClaudeBetaQuery(info *relaycommon.RelayInfo) bool {
|
||||
if info == nil {
|
||||
return false
|
||||
}
|
||||
if info.IsClaudeBetaQuery {
|
||||
return true
|
||||
}
|
||||
if info.ChannelOtherSettings.ClaudeBetaQuery {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func CommonClaudeHeadersOperation(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) {
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
@@ -44,6 +46,61 @@ func maybeMarkClaudeRefusal(c *gin.Context, stopReason string) {
|
||||
}
|
||||
}
|
||||
|
||||
func createClaudeFileSource(file *dto.MessageFile) *types.FileSource {
|
||||
if file == nil || file.FileData == "" {
|
||||
return nil
|
||||
}
|
||||
if strings.HasPrefix(file.FileData, "http://") || strings.HasPrefix(file.FileData, "https://") {
|
||||
return types.NewURLFileSource(file.FileData)
|
||||
}
|
||||
mimeType := ""
|
||||
if ext := strings.TrimPrefix(strings.ToLower(filepath.Ext(file.FileName)), "."); ext != "" {
|
||||
if detected := service.GetMimeTypeByExtension(ext); detected != "application/octet-stream" {
|
||||
mimeType = detected
|
||||
}
|
||||
}
|
||||
return types.NewBase64FileSource(file.FileData, mimeType)
|
||||
}
|
||||
|
||||
func buildClaudeFileMessage(c *gin.Context, file *dto.MessageFile) (*dto.ClaudeMediaMessage, error) {
|
||||
source := createClaudeFileSource(file)
|
||||
if source == nil {
|
||||
return nil, nil
|
||||
}
|
||||
base64Data, mimeType, err := service.GetBase64Data(c, source, "formatting document for Claude")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get file data failed: %w", err)
|
||||
}
|
||||
switch strings.ToLower(mimeType) {
|
||||
case "application/pdf":
|
||||
return &dto.ClaudeMediaMessage{
|
||||
Type: "document",
|
||||
Source: &dto.ClaudeMessageSource{
|
||||
Type: "base64",
|
||||
MediaType: mimeType,
|
||||
Data: base64Data,
|
||||
},
|
||||
}, nil
|
||||
case "text/plain":
|
||||
decodedData, err := base64.StdEncoding.DecodeString(base64Data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode text file data failed: %w", err)
|
||||
}
|
||||
return &dto.ClaudeMediaMessage{
|
||||
Type: "text",
|
||||
Text: common.GetPointer(string(decodedData)),
|
||||
}, nil
|
||||
default:
|
||||
msg := fmt.Sprintf("claude: skip unsupported file content, filename=%q, mime=%q", file.FileName, mimeType)
|
||||
if c != nil {
|
||||
logger.LogInfo(c, msg)
|
||||
} else {
|
||||
common.SysLog(msg)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) {
|
||||
claudeTools := make([]any, 0, len(textRequest.Tools))
|
||||
|
||||
@@ -343,16 +400,22 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
|
||||
} else {
|
||||
claudeMediaMessages := make([]dto.ClaudeMediaMessage, 0)
|
||||
for _, mediaMessage := range message.ParseContent() {
|
||||
claudeMediaMessage := dto.ClaudeMediaMessage{
|
||||
Type: mediaMessage.Type,
|
||||
}
|
||||
if mediaMessage.Type == "text" {
|
||||
claudeMediaMessage.Text = common.GetPointer[string](mediaMessage.Text)
|
||||
} else {
|
||||
switch mediaMessage.Type {
|
||||
case "text":
|
||||
claudeMediaMessages = append(claudeMediaMessages, dto.ClaudeMediaMessage{
|
||||
Type: "text",
|
||||
Text: common.GetPointer[string](mediaMessage.Text),
|
||||
})
|
||||
case dto.ContentTypeImageURL:
|
||||
claudeMediaMessage := dto.ClaudeMediaMessage{
|
||||
Type: "image",
|
||||
Source: &dto.ClaudeMessageSource{
|
||||
Type: "base64",
|
||||
},
|
||||
}
|
||||
imageUrl := mediaMessage.GetImageMedia()
|
||||
claudeMediaMessage.Type = "image"
|
||||
claudeMediaMessage.Source = &dto.ClaudeMessageSource{
|
||||
Type: "base64",
|
||||
if imageUrl == nil {
|
||||
continue
|
||||
}
|
||||
// 使用统一的文件服务获取图片数据
|
||||
var source *types.FileSource
|
||||
@@ -367,8 +430,19 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
|
||||
}
|
||||
claudeMediaMessage.Source.MediaType = mimeType
|
||||
claudeMediaMessage.Source.Data = base64Data
|
||||
claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
|
||||
// FIXME
|
||||
//case dto.ContentTypeFile:
|
||||
// claudeFileMessage, err := buildClaudeFileMessage(c, mediaMessage.GetFile())
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// if claudeFileMessage != nil {
|
||||
// claudeMediaMessages = append(claudeMediaMessages, *claudeFileMessage)
|
||||
// }
|
||||
default:
|
||||
continue
|
||||
}
|
||||
claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
|
||||
}
|
||||
if message.ToolCalls != nil {
|
||||
for _, toolCall := range message.ParseToolCalls() {
|
||||
@@ -555,6 +629,35 @@ type ClaudeResponseInfo struct {
|
||||
Done bool
|
||||
}
|
||||
|
||||
func cacheCreationTokensForOpenAIUsage(usage *dto.Usage) int {
|
||||
if usage == nil {
|
||||
return 0
|
||||
}
|
||||
splitCacheCreationTokens := usage.ClaudeCacheCreation5mTokens + usage.ClaudeCacheCreation1hTokens
|
||||
if splitCacheCreationTokens == 0 {
|
||||
return usage.PromptTokensDetails.CachedCreationTokens
|
||||
}
|
||||
if usage.PromptTokensDetails.CachedCreationTokens > splitCacheCreationTokens {
|
||||
return usage.PromptTokensDetails.CachedCreationTokens
|
||||
}
|
||||
return splitCacheCreationTokens
|
||||
}
|
||||
|
||||
func buildOpenAIStyleUsageFromClaudeUsage(usage *dto.Usage) dto.Usage {
|
||||
if usage == nil {
|
||||
return dto.Usage{}
|
||||
}
|
||||
clone := *usage
|
||||
cacheCreationTokens := cacheCreationTokensForOpenAIUsage(usage)
|
||||
totalInputTokens := usage.PromptTokens + usage.PromptTokensDetails.CachedTokens + cacheCreationTokens
|
||||
clone.PromptTokens = totalInputTokens
|
||||
clone.InputTokens = totalInputTokens
|
||||
clone.TotalTokens = totalInputTokens + usage.CompletionTokens
|
||||
clone.UsageSemantic = "openai"
|
||||
clone.UsageSource = "anthropic"
|
||||
return clone
|
||||
}
|
||||
|
||||
func buildMessageDeltaPatchUsage(claudeResponse *dto.ClaudeResponse, claudeInfo *ClaudeResponseInfo) *dto.ClaudeUsage {
|
||||
usage := &dto.ClaudeUsage{}
|
||||
if claudeResponse != nil && claudeResponse.Usage != nil {
|
||||
@@ -643,6 +746,7 @@ func FormatClaudeResponseInfo(claudeResponse *dto.ClaudeResponse, oaiResponse *d
|
||||
// message_start, 获取usage
|
||||
if claudeResponse.Message != nil && claudeResponse.Message.Usage != nil {
|
||||
claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
|
||||
claudeInfo.Usage.UsageSemantic = "anthropic"
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
|
||||
claudeInfo.Usage.ClaudeCacheCreation5mTokens = claudeResponse.Message.Usage.GetCacheCreation5mTokens()
|
||||
@@ -661,6 +765,7 @@ func FormatClaudeResponseInfo(claudeResponse *dto.ClaudeResponse, oaiResponse *d
|
||||
} else if claudeResponse.Type == "message_delta" {
|
||||
// 最终的usage获取
|
||||
if claudeResponse.Usage != nil {
|
||||
claudeInfo.Usage.UsageSemantic = "anthropic"
|
||||
if claudeResponse.Usage.InputTokens > 0 {
|
||||
// 不叠加,只取最新的
|
||||
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
|
||||
@@ -754,12 +859,16 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau
|
||||
}
|
||||
claudeInfo.Usage = service.ResponseText2Usage(c, claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
||||
}
|
||||
if claudeInfo.Usage != nil {
|
||||
claudeInfo.Usage.UsageSemantic = "anthropic"
|
||||
}
|
||||
|
||||
if info.RelayFormat == types.RelayFormatClaude {
|
||||
//
|
||||
} else if info.RelayFormat == types.RelayFormatOpenAI {
|
||||
if info.ShouldIncludeUsage {
|
||||
response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
|
||||
openAIUsage := buildOpenAIStyleUsageFromClaudeUsage(claudeInfo.Usage)
|
||||
response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, openAIUsage)
|
||||
err := helper.ObjectData(c, response)
|
||||
if err != nil {
|
||||
common.SysLog("send final response failed: " + err.Error())
|
||||
@@ -778,12 +887,11 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
||||
Usage: &dto.Usage{},
|
||||
}
|
||||
var err *types.NewAPIError
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) {
|
||||
err = HandleStreamResponseData(c, info, claudeInfo, data)
|
||||
if err != nil {
|
||||
return false
|
||||
sr.Stop(err)
|
||||
}
|
||||
return true
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -810,6 +918,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
|
||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
|
||||
claudeInfo.Usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
|
||||
claudeInfo.Usage.UsageSemantic = "anthropic"
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Usage.CacheReadInputTokens
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Usage.CacheCreationInputTokens
|
||||
claudeInfo.Usage.ClaudeCacheCreation5mTokens = claudeResponse.Usage.GetCacheCreation5mTokens()
|
||||
@@ -819,7 +928,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
switch info.RelayFormat {
|
||||
case types.RelayFormatOpenAI:
|
||||
openaiResponse := ResponseClaude2OpenAI(&claudeResponse)
|
||||
openaiResponse.Usage = *claudeInfo.Usage
|
||||
openaiResponse.Usage = buildOpenAIStyleUsageFromClaudeUsage(claudeInfo.Usage)
|
||||
responseData, err = json.Marshal(openaiResponse)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFormatClaudeResponseInfo_MessageStart(t *testing.T) {
|
||||
@@ -173,3 +175,191 @@ func TestFormatClaudeResponseInfo_ContentBlockDelta(t *testing.T) {
|
||||
t.Errorf("ResponseText = %q, want %q", claudeInfo.ResponseText.String(), "hello")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpenAIStyleUsageFromClaudeUsage(t *testing.T) {
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: 100,
|
||||
CompletionTokens: 20,
|
||||
PromptTokensDetails: dto.InputTokenDetails{
|
||||
CachedTokens: 30,
|
||||
CachedCreationTokens: 50,
|
||||
},
|
||||
ClaudeCacheCreation5mTokens: 10,
|
||||
ClaudeCacheCreation1hTokens: 20,
|
||||
UsageSemantic: "anthropic",
|
||||
}
|
||||
|
||||
openAIUsage := buildOpenAIStyleUsageFromClaudeUsage(usage)
|
||||
|
||||
if openAIUsage.PromptTokens != 180 {
|
||||
t.Fatalf("PromptTokens = %d, want 180", openAIUsage.PromptTokens)
|
||||
}
|
||||
if openAIUsage.InputTokens != 180 {
|
||||
t.Fatalf("InputTokens = %d, want 180", openAIUsage.InputTokens)
|
||||
}
|
||||
if openAIUsage.TotalTokens != 200 {
|
||||
t.Fatalf("TotalTokens = %d, want 200", openAIUsage.TotalTokens)
|
||||
}
|
||||
if openAIUsage.UsageSemantic != "openai" {
|
||||
t.Fatalf("UsageSemantic = %s, want openai", openAIUsage.UsageSemantic)
|
||||
}
|
||||
if openAIUsage.UsageSource != "anthropic" {
|
||||
t.Fatalf("UsageSource = %s, want anthropic", openAIUsage.UsageSource)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpenAIStyleUsageFromClaudeUsagePreservesCacheCreationRemainder(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cachedCreationTokens int
|
||||
cacheCreationTokens5m int
|
||||
cacheCreationTokens1h int
|
||||
expectedTotalInputToken int
|
||||
}{
|
||||
{
|
||||
name: "prefers aggregate when it includes remainder",
|
||||
cachedCreationTokens: 50,
|
||||
cacheCreationTokens5m: 10,
|
||||
cacheCreationTokens1h: 20,
|
||||
expectedTotalInputToken: 180,
|
||||
},
|
||||
{
|
||||
name: "falls back to split tokens when aggregate missing",
|
||||
cachedCreationTokens: 0,
|
||||
cacheCreationTokens5m: 10,
|
||||
cacheCreationTokens1h: 20,
|
||||
expectedTotalInputToken: 160,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: 100,
|
||||
CompletionTokens: 20,
|
||||
PromptTokensDetails: dto.InputTokenDetails{
|
||||
CachedTokens: 30,
|
||||
CachedCreationTokens: tt.cachedCreationTokens,
|
||||
},
|
||||
ClaudeCacheCreation5mTokens: tt.cacheCreationTokens5m,
|
||||
ClaudeCacheCreation1hTokens: tt.cacheCreationTokens1h,
|
||||
UsageSemantic: "anthropic",
|
||||
}
|
||||
|
||||
openAIUsage := buildOpenAIStyleUsageFromClaudeUsage(usage)
|
||||
|
||||
if openAIUsage.PromptTokens != tt.expectedTotalInputToken {
|
||||
t.Fatalf("PromptTokens = %d, want %d", openAIUsage.PromptTokens, tt.expectedTotalInputToken)
|
||||
}
|
||||
if openAIUsage.InputTokens != tt.expectedTotalInputToken {
|
||||
t.Fatalf("InputTokens = %d, want %d", openAIUsage.InputTokens, tt.expectedTotalInputToken)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestOpenAI2ClaudeMessage_IgnoresUnsupportedFileContent(t *testing.T) {
|
||||
request := dto.GeneralOpenAIRequest{
|
||||
Model: "claude-3-5-sonnet",
|
||||
Messages: []dto.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
dto.MediaContent{
|
||||
Type: dto.ContentTypeText,
|
||||
Text: "see attachment",
|
||||
},
|
||||
dto.MediaContent{
|
||||
Type: dto.ContentTypeFile,
|
||||
File: &dto.MessageFile{
|
||||
FileName: "blob.bin",
|
||||
FileData: "JVBERi0xLjQK",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
claudeRequest, err := RequestOpenAI2ClaudeMessage(nil, request)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, claudeRequest.Messages, 1)
|
||||
|
||||
content, ok := claudeRequest.Messages[0].Content.([]dto.ClaudeMediaMessage)
|
||||
require.True(t, ok)
|
||||
require.Len(t, content, 1)
|
||||
require.Equal(t, "text", content[0].Type)
|
||||
require.NotNil(t, content[0].Text)
|
||||
require.Equal(t, "see attachment", *content[0].Text)
|
||||
}
|
||||
|
||||
func TestRequestOpenAI2ClaudeMessage_SupportsPDFFileContent(t *testing.T) {
|
||||
request := dto.GeneralOpenAIRequest{
|
||||
Model: "claude-3-5-sonnet",
|
||||
Messages: []dto.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
dto.MediaContent{
|
||||
Type: dto.ContentTypeFile,
|
||||
File: &dto.MessageFile{
|
||||
FileName: "spec.pdf",
|
||||
FileData: "JVBERi0xLjQK",
|
||||
},
|
||||
},
|
||||
dto.MediaContent{
|
||||
Type: dto.ContentTypeText,
|
||||
Text: "summarize it",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
claudeRequest, err := RequestOpenAI2ClaudeMessage(nil, request)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, claudeRequest.Messages, 1)
|
||||
|
||||
content, ok := claudeRequest.Messages[0].Content.([]dto.ClaudeMediaMessage)
|
||||
require.True(t, ok)
|
||||
require.Len(t, content, 2)
|
||||
require.Equal(t, "document", content[0].Type)
|
||||
require.NotNil(t, content[0].Source)
|
||||
require.Equal(t, "base64", content[0].Source.Type)
|
||||
require.Equal(t, "application/pdf", content[0].Source.MediaType)
|
||||
require.Equal(t, "JVBERi0xLjQK", content[0].Source.Data)
|
||||
require.Equal(t, "text", content[1].Type)
|
||||
require.NotNil(t, content[1].Text)
|
||||
require.Equal(t, "summarize it", *content[1].Text)
|
||||
}
|
||||
|
||||
func TestRequestOpenAI2ClaudeMessage_ConvertsTextFileContentToText(t *testing.T) {
|
||||
request := dto.GeneralOpenAIRequest{
|
||||
Model: "claude-3-5-sonnet",
|
||||
Messages: []dto.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
dto.MediaContent{
|
||||
Type: dto.ContentTypeFile,
|
||||
File: &dto.MessageFile{
|
||||
FileName: "notes.txt",
|
||||
FileData: base64.StdEncoding.EncodeToString([]byte("alpha\nbeta")),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
claudeRequest, err := RequestOpenAI2ClaudeMessage(nil, request)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, claudeRequest.Messages, 1)
|
||||
|
||||
content, ok := claudeRequest.Messages[0].Content.([]dto.ClaudeMediaMessage)
|
||||
require.True(t, ok)
|
||||
require.Len(t, content, 1)
|
||||
require.Equal(t, "text", content[0].Type)
|
||||
require.NotNil(t, content[0].Text)
|
||||
require.Equal(t, "alpha\nbeta", *content[0].Text)
|
||||
}
|
||||
|
||||
@@ -223,33 +223,32 @@ func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
|
||||
usage := &dto.Usage{}
|
||||
var nodeToken int
|
||||
helper.SetEventStreamHeaders(c)
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) {
|
||||
var difyResponse DifyChunkChatCompletionResponse
|
||||
err := json.Unmarshal([]byte(data), &difyResponse)
|
||||
if err != nil {
|
||||
if err := json.Unmarshal([]byte(data), &difyResponse); err != nil {
|
||||
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
sr.Error(err)
|
||||
return
|
||||
}
|
||||
var openaiResponse dto.ChatCompletionsStreamResponse
|
||||
if difyResponse.Event == "message_end" {
|
||||
usage = &difyResponse.MetaData.Usage
|
||||
return false
|
||||
sr.Done()
|
||||
return
|
||||
} else if difyResponse.Event == "error" {
|
||||
return false
|
||||
} else {
|
||||
openaiResponse = *streamResponseDify2OpenAI(difyResponse)
|
||||
if len(openaiResponse.Choices) != 0 {
|
||||
responseText += openaiResponse.Choices[0].Delta.GetContentString()
|
||||
if openaiResponse.Choices[0].Delta.ReasoningContent != nil {
|
||||
nodeToken += 1
|
||||
}
|
||||
sr.Stop(fmt.Errorf("dify error event"))
|
||||
return
|
||||
}
|
||||
openaiResponse := *streamResponseDify2OpenAI(difyResponse)
|
||||
if len(openaiResponse.Choices) != 0 {
|
||||
responseText += openaiResponse.Choices[0].Delta.GetContentString()
|
||||
if openaiResponse.Choices[0].Delta.ReasoningContent != nil {
|
||||
nodeToken += 1
|
||||
}
|
||||
}
|
||||
err = helper.ObjectData(c, openaiResponse)
|
||||
if err != nil {
|
||||
if err := helper.ObjectData(c, openaiResponse); err != nil {
|
||||
common.SysLog(err.Error())
|
||||
sr.Error(err)
|
||||
}
|
||||
return true
|
||||
})
|
||||
helper.Done(c)
|
||||
if usage.TotalTokens == 0 {
|
||||
|
||||
@@ -37,6 +37,8 @@ var geminiSupportedMimeTypes = map[string]bool{
|
||||
"image/jpeg": true,
|
||||
"image/jpg": true, // support old image/jpeg
|
||||
"image/webp": true,
|
||||
"image/heic": true,
|
||||
"image/heif": true,
|
||||
"text/plain": true,
|
||||
"video/mov": true,
|
||||
"video/mpeg": true,
|
||||
@@ -1297,12 +1299,11 @@ func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
||||
var imageCount int
|
||||
responseText := strings.Builder{}
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) {
|
||||
var geminiResponse dto.GeminiChatResponse
|
||||
err := common.UnmarshalJsonStr(data, &geminiResponse)
|
||||
if err != nil {
|
||||
logger.LogError(c, "error unmarshalling stream response: "+err.Error())
|
||||
return false
|
||||
if err := common.UnmarshalJsonStr(data, &geminiResponse); err != nil {
|
||||
sr.Stop(fmt.Errorf("unmarshal: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
if len(geminiResponse.Candidates) == 0 && geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
|
||||
@@ -1327,7 +1328,9 @@ func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
||||
*usage = mappedUsage
|
||||
}
|
||||
|
||||
return callback(data, &geminiResponse)
|
||||
if !callback(data, &geminiResponse) {
|
||||
sr.Stop(fmt.Errorf("gemini callback stopped"))
|
||||
}
|
||||
})
|
||||
|
||||
if imageCount != 0 {
|
||||
|
||||
@@ -35,21 +35,21 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
|
||||
if info.IsStream {
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) {
|
||||
if service.SundaySearch(data, "usage") {
|
||||
var simpleResponse dto.SimpleResponse
|
||||
err := common.Unmarshal([]byte(data), &simpleResponse)
|
||||
if err != nil {
|
||||
if err := common.Unmarshal([]byte(data), &simpleResponse); err != nil {
|
||||
logger.LogError(c, err.Error())
|
||||
}
|
||||
if simpleResponse.Usage.TotalTokens != 0 {
|
||||
sr.Error(err)
|
||||
} else if simpleResponse.Usage.TotalTokens != 0 {
|
||||
usage.PromptTokens = simpleResponse.Usage.InputTokens
|
||||
usage.CompletionTokens = simpleResponse.OutputTokens
|
||||
usage.TotalTokens = simpleResponse.TotalTokens
|
||||
}
|
||||
}
|
||||
_ = helper.StringData(c, data)
|
||||
return true
|
||||
if err := helper.StringData(c, data); err != nil {
|
||||
sr.Error(err)
|
||||
}
|
||||
})
|
||||
} else {
|
||||
common.SetContextKey(c, constant.ContextKeyLocalCountTokens, true)
|
||||
|
||||
@@ -296,15 +296,17 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
|
||||
return true
|
||||
}
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) {
|
||||
if streamErr != nil {
|
||||
return false
|
||||
sr.Stop(streamErr)
|
||||
return
|
||||
}
|
||||
|
||||
var streamResp dto.ResponsesStreamResponse
|
||||
if err := common.UnmarshalJsonStr(data, &streamResp); err != nil {
|
||||
logger.LogError(c, "failed to unmarshal responses stream event: "+err.Error())
|
||||
return true
|
||||
sr.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
switch streamResp.Type {
|
||||
@@ -320,14 +322,16 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
|
||||
|
||||
//case "response.reasoning_text.delta":
|
||||
//if !sendReasoningDelta(streamResp.Delta) {
|
||||
// return false
|
||||
// sr.Stop(streamErr)
|
||||
// return
|
||||
//}
|
||||
|
||||
//case "response.reasoning_text.done":
|
||||
|
||||
case "response.reasoning_summary_text.delta":
|
||||
if !sendReasoningSummaryDelta(streamResp.Delta) {
|
||||
return false
|
||||
sr.Stop(streamErr)
|
||||
return
|
||||
}
|
||||
|
||||
case "response.reasoning_summary_text.done":
|
||||
@@ -349,12 +353,14 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
|
||||
// delta := stringDeltaFromPrefix(prev, next)
|
||||
// reasoningSummaryTextByKey[key] = next
|
||||
// if !sendReasoningSummaryDelta(delta) {
|
||||
// return false
|
||||
// sr.Stop(streamErr)
|
||||
// return
|
||||
// }
|
||||
|
||||
case "response.output_text.delta":
|
||||
if !sendStartIfNeeded() {
|
||||
return false
|
||||
sr.Stop(streamErr)
|
||||
return
|
||||
}
|
||||
|
||||
if streamResp.Delta != "" {
|
||||
@@ -376,7 +382,8 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
|
||||
},
|
||||
}
|
||||
if !sendChatChunk(chunk) {
|
||||
return false
|
||||
sr.Stop(streamErr)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -414,7 +421,8 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
|
||||
}
|
||||
|
||||
if !sendToolCallDelta(callID, name, argsDelta) {
|
||||
return false
|
||||
sr.Stop(streamErr)
|
||||
return
|
||||
}
|
||||
|
||||
case "response.function_call_arguments.delta":
|
||||
@@ -428,7 +436,8 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
|
||||
}
|
||||
toolCallArgsByID[callID] += streamResp.Delta
|
||||
if !sendToolCallDelta(callID, "", streamResp.Delta) {
|
||||
return false
|
||||
sr.Stop(streamErr)
|
||||
return
|
||||
}
|
||||
|
||||
case "response.function_call_arguments.done":
|
||||
@@ -467,7 +476,8 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
|
||||
}
|
||||
|
||||
if !sendStartIfNeeded() {
|
||||
return false
|
||||
sr.Stop(streamErr)
|
||||
return
|
||||
}
|
||||
if !sentStop {
|
||||
if info.RelayFormat == types.RelayFormatClaude && info.ClaudeConvertInfo != nil {
|
||||
@@ -479,7 +489,8 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
|
||||
}
|
||||
stop := helper.GenerateStopResponse(responseId, createAt, model, finishReason)
|
||||
if !sendChatChunk(stop) {
|
||||
return false
|
||||
sr.Stop(streamErr)
|
||||
return
|
||||
}
|
||||
sentStop = true
|
||||
}
|
||||
@@ -488,16 +499,16 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
|
||||
if streamResp.Response != nil {
|
||||
if oaiErr := streamResp.Response.GetOpenAIError(); oaiErr != nil && oaiErr.Type != "" {
|
||||
streamErr = types.WithOpenAIError(*oaiErr, http.StatusInternalServerError)
|
||||
return false
|
||||
sr.Stop(streamErr)
|
||||
return
|
||||
}
|
||||
}
|
||||
streamErr = types.NewOpenAIError(fmt.Errorf("responses stream error: %s", streamResp.Type), types.ErrorCodeBadResponse, http.StatusInternalServerError)
|
||||
return false
|
||||
sr.Stop(streamErr)
|
||||
return
|
||||
|
||||
default:
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
if streamErr != nil {
|
||||
|
||||
@@ -126,11 +126,11 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
// 检查是否为音频模型
|
||||
isAudioModel := strings.Contains(strings.ToLower(model), "audio")
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) {
|
||||
if lastStreamData != "" {
|
||||
err := HandleStreamFormat(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
|
||||
if err != nil {
|
||||
if err := HandleStreamFormat(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent); err != nil {
|
||||
common.SysLog("error handling stream format: " + err.Error())
|
||||
sr.Error(err)
|
||||
}
|
||||
}
|
||||
if len(data) > 0 {
|
||||
@@ -142,7 +142,6 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
lastStreamData = data
|
||||
streamItems = append(streamItems, data)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
// 对音频模型,从倒数第二个stream data中提取usage信息
|
||||
@@ -627,6 +626,12 @@ func applyUsagePostProcessing(info *relaycommon.RelayInfo, usage *dto.Usage, res
|
||||
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
|
||||
}
|
||||
}
|
||||
case constant.ChannelTypeOpenAI:
|
||||
if usage.PromptTokensDetails.CachedTokens == 0 {
|
||||
if cachedTokens, ok := extractLlamaCachedTokensFromBody(responseBody); ok {
|
||||
usage.PromptTokensDetails.CachedTokens = cachedTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -689,3 +694,25 @@ func extractMoonshotCachedTokensFromBody(body []byte) (int, bool) {
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// extractLlamaCachedTokensFromBody 从llama.cpp的非标准位置提取cache_n
|
||||
func extractLlamaCachedTokensFromBody(body []byte) (int, bool) {
|
||||
if len(body) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
Timings struct {
|
||||
CachedTokens *int `json:"cache_n"`
|
||||
} `json:"timings"`
|
||||
}
|
||||
|
||||
if err := common.Unmarshal(body, &payload); err != nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
if payload.Timings.CachedTokens == nil {
|
||||
return 0, false
|
||||
}
|
||||
return *payload.Timings.CachedTokens, true
|
||||
}
|
||||
|
||||
@@ -79,55 +79,55 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
|
||||
var usage = &dto.Usage{}
|
||||
var responseTextBuilder strings.Builder
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) {
|
||||
|
||||
// 检查当前数据是否包含 completed 状态和 usage 信息
|
||||
var streamResponse dto.ResponsesStreamResponse
|
||||
if err := common.UnmarshalJsonStr(data, &streamResponse); err == nil {
|
||||
sendResponsesStreamData(c, streamResponse, data)
|
||||
switch streamResponse.Type {
|
||||
case "response.completed":
|
||||
if streamResponse.Response != nil {
|
||||
if streamResponse.Response.Usage != nil {
|
||||
if streamResponse.Response.Usage.InputTokens != 0 {
|
||||
usage.PromptTokens = streamResponse.Response.Usage.InputTokens
|
||||
}
|
||||
if streamResponse.Response.Usage.OutputTokens != 0 {
|
||||
usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
|
||||
}
|
||||
if streamResponse.Response.Usage.TotalTokens != 0 {
|
||||
usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
|
||||
}
|
||||
if streamResponse.Response.Usage.InputTokensDetails != nil {
|
||||
usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
if err := common.UnmarshalJsonStr(data, &streamResponse); err != nil {
|
||||
logger.LogError(c, "failed to unmarshal stream response: "+err.Error())
|
||||
sr.Error(err)
|
||||
return
|
||||
}
|
||||
sendResponsesStreamData(c, streamResponse, data)
|
||||
switch streamResponse.Type {
|
||||
case "response.completed":
|
||||
if streamResponse.Response != nil {
|
||||
if streamResponse.Response.Usage != nil {
|
||||
if streamResponse.Response.Usage.InputTokens != 0 {
|
||||
usage.PromptTokens = streamResponse.Response.Usage.InputTokens
|
||||
}
|
||||
if streamResponse.Response.HasImageGenerationCall() {
|
||||
c.Set("image_generation_call", true)
|
||||
c.Set("image_generation_call_quality", streamResponse.Response.GetQuality())
|
||||
c.Set("image_generation_call_size", streamResponse.Response.GetSize())
|
||||
if streamResponse.Response.Usage.OutputTokens != 0 {
|
||||
usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
|
||||
}
|
||||
if streamResponse.Response.Usage.TotalTokens != 0 {
|
||||
usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
|
||||
}
|
||||
if streamResponse.Response.Usage.InputTokensDetails != nil {
|
||||
usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
}
|
||||
case "response.output_text.delta":
|
||||
// 处理输出文本
|
||||
responseTextBuilder.WriteString(streamResponse.Delta)
|
||||
case dto.ResponsesOutputTypeItemDone:
|
||||
// 函数调用处理
|
||||
if streamResponse.Item != nil {
|
||||
switch streamResponse.Item.Type {
|
||||
case dto.BuildInCallWebSearchCall:
|
||||
if info != nil && info.ResponsesUsageInfo != nil && info.ResponsesUsageInfo.BuiltInTools != nil {
|
||||
if webSearchTool, exists := info.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool != nil {
|
||||
webSearchTool.CallCount++
|
||||
}
|
||||
if streamResponse.Response.HasImageGenerationCall() {
|
||||
c.Set("image_generation_call", true)
|
||||
c.Set("image_generation_call_quality", streamResponse.Response.GetQuality())
|
||||
c.Set("image_generation_call_size", streamResponse.Response.GetSize())
|
||||
}
|
||||
}
|
||||
case "response.output_text.delta":
|
||||
// 处理输出文本
|
||||
responseTextBuilder.WriteString(streamResponse.Delta)
|
||||
case dto.ResponsesOutputTypeItemDone:
|
||||
// 函数调用处理
|
||||
if streamResponse.Item != nil {
|
||||
switch streamResponse.Item.Type {
|
||||
case dto.BuildInCallWebSearchCall:
|
||||
if info != nil && info.ResponsesUsageInfo != nil && info.ResponsesUsageInfo.BuiltInTools != nil {
|
||||
if webSearchTool, exists := info.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool != nil {
|
||||
webSearchTool.CallCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
logger.LogError(c, "failed to unmarshal stream response: "+err.Error())
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if usage.CompletionTokens == 0 {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
@@ -13,12 +14,13 @@ import (
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
|
||||
"github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
// ============================
|
||||
@@ -26,37 +28,37 @@ import (
|
||||
// ============================
|
||||
|
||||
type ContentItem struct {
|
||||
Type string `json:"type"` // "text", "image_url" or "video"
|
||||
Text string `json:"text,omitempty"` // for text type
|
||||
ImageURL *ImageURL `json:"image_url,omitempty"` // for image_url type
|
||||
Video *VideoReference `json:"video,omitempty"` // for video (sample) type
|
||||
Role string `json:"role,omitempty"` // reference_image / first_frame / last_frame
|
||||
Type string `json:"type,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
ImageURL *MediaURL `json:"image_url,omitempty"`
|
||||
VideoURL *MediaURL `json:"video_url,omitempty"`
|
||||
AudioURL *MediaURL `json:"audio_url,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
}
|
||||
|
||||
type ImageURL struct {
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
type VideoReference struct {
|
||||
URL string `json:"url"` // Draft video URL
|
||||
type MediaURL struct {
|
||||
URL string `json:"url,omitempty"`
|
||||
}
|
||||
|
||||
type requestPayload struct {
|
||||
Model string `json:"model"`
|
||||
Content []ContentItem `json:"content"`
|
||||
Content []ContentItem `json:"content,omitempty"`
|
||||
CallbackURL string `json:"callback_url,omitempty"`
|
||||
ReturnLastFrame *dto.BoolValue `json:"return_last_frame,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
ExecutionExpiresAfter dto.IntValue `json:"execution_expires_after,omitempty"`
|
||||
ExecutionExpiresAfter *dto.IntValue `json:"execution_expires_after,omitempty"`
|
||||
GenerateAudio *dto.BoolValue `json:"generate_audio,omitempty"`
|
||||
Draft *dto.BoolValue `json:"draft,omitempty"`
|
||||
Resolution string `json:"resolution,omitempty"`
|
||||
Ratio string `json:"ratio,omitempty"`
|
||||
Duration dto.IntValue `json:"duration,omitempty"`
|
||||
Frames dto.IntValue `json:"frames,omitempty"`
|
||||
Seed dto.IntValue `json:"seed,omitempty"`
|
||||
CameraFixed *dto.BoolValue `json:"camera_fixed,omitempty"`
|
||||
Watermark *dto.BoolValue `json:"watermark,omitempty"`
|
||||
Tools []struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
} `json:"tools,omitempty"`
|
||||
Resolution string `json:"resolution,omitempty"`
|
||||
Ratio string `json:"ratio,omitempty"`
|
||||
Duration *dto.IntValue `json:"duration,omitempty"`
|
||||
Frames *dto.IntValue `json:"frames,omitempty"`
|
||||
Seed *dto.IntValue `json:"seed,omitempty"`
|
||||
CameraFixed *dto.BoolValue `json:"camera_fixed,omitempty"`
|
||||
Watermark *dto.BoolValue `json:"watermark,omitempty"`
|
||||
}
|
||||
|
||||
type responsePayload struct {
|
||||
@@ -76,10 +78,20 @@ type responseTask struct {
|
||||
Ratio string `json:"ratio"`
|
||||
FramesPerSecond int `json:"framespersecond"`
|
||||
ServiceTier string `json:"service_tier"`
|
||||
Usage struct {
|
||||
Tools []struct {
|
||||
Type string `json:"type"`
|
||||
} `json:"tools"`
|
||||
Usage struct {
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
ToolUsage struct {
|
||||
WebSearch int `json:"web_search"`
|
||||
} `json:"tool_usage"`
|
||||
} `json:"usage"`
|
||||
Error struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
}
|
||||
@@ -108,12 +120,12 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
|
||||
}
|
||||
|
||||
// BuildRequestURL constructs the upstream URL.
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
func (a *TaskAdaptor) BuildRequestURL(_ *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/api/v3/contents/generations/tasks", a.baseURL), nil
|
||||
}
|
||||
|
||||
// BuildRequestHeader sets required headers.
|
||||
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
func (a *TaskAdaptor) BuildRequestHeader(_ *gin.Context, req *http.Request, _ *relaycommon.RelayInfo) error {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+a.apiKey)
|
||||
@@ -218,20 +230,12 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
|
||||
Content: []ContentItem{},
|
||||
}
|
||||
|
||||
// Add text prompt
|
||||
if req.Prompt != "" {
|
||||
r.Content = append(r.Content, ContentItem{
|
||||
Type: "text",
|
||||
Text: req.Prompt,
|
||||
})
|
||||
}
|
||||
|
||||
// Add images if present
|
||||
if req.HasImage() {
|
||||
for _, imgURL := range req.Images {
|
||||
r.Content = append(r.Content, ContentItem{
|
||||
Type: "image_url",
|
||||
ImageURL: &ImageURL{
|
||||
ImageURL: &MediaURL{
|
||||
URL: imgURL,
|
||||
},
|
||||
})
|
||||
@@ -243,6 +247,16 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
|
||||
return nil, errors.Wrap(err, "unmarshal metadata failed")
|
||||
}
|
||||
|
||||
if sec, _ := strconv.Atoi(req.Seconds); sec > 0 {
|
||||
r.Duration = lo.ToPtr(dto.IntValue(sec))
|
||||
}
|
||||
|
||||
r.Content = lo.Reject(r.Content, func(c ContentItem, _ int) bool { return c.Type == "text" })
|
||||
r.Content = append(r.Content, ContentItem{
|
||||
Type: "text",
|
||||
Text: req.Prompt,
|
||||
})
|
||||
|
||||
return &r, nil
|
||||
}
|
||||
|
||||
@@ -274,7 +288,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
|
||||
case "failed":
|
||||
taskResult.Status = model.TaskStatusFailure
|
||||
taskResult.Progress = "100%"
|
||||
taskResult.Reason = "task failed"
|
||||
taskResult.Reason = resTask.Error.Message
|
||||
default:
|
||||
// Unknown status, treat as processing
|
||||
taskResult.Status = model.TaskStatusInProgress
|
||||
@@ -302,8 +316,8 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro
|
||||
|
||||
if dResp.Status == "failed" {
|
||||
openAIVideo.Error = &dto.OpenAIVideoError{
|
||||
Message: "task failed",
|
||||
Code: "failed",
|
||||
Message: dResp.Error.Message,
|
||||
Code: dResp.Error.Code,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,8 @@ var ModelList = []string{
|
||||
"doubao-seedance-1-0-lite-t2v",
|
||||
"doubao-seedance-1-0-lite-i2v",
|
||||
"doubao-seedance-1-5-pro-251215",
|
||||
"doubao-seedance-2-0-260128",
|
||||
"doubao-seedance-2-0-fast-260128",
|
||||
}
|
||||
|
||||
var ChannelName = "doubao-video"
|
||||
|
||||
@@ -17,6 +17,8 @@ func UnmarshalMetadata(metadata map[string]any, target any) error {
|
||||
if metadata == nil {
|
||||
return nil
|
||||
}
|
||||
// Prevent metadata from overriding model fields to avoid billing bypass.
|
||||
delete(metadata, "model")
|
||||
metaBytes, err := common.Marshal(metadata)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal metadata failed: %w", err)
|
||||
|
||||
@@ -76,7 +76,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
if strings.HasPrefix(request.Model, "grok-3-mini") {
|
||||
if lo.FromPtrOr(request.MaxCompletionTokens, uint(0)) == 0 && lo.FromPtrOr(request.MaxTokens, uint(0)) != 0 {
|
||||
request.MaxCompletionTokens = request.MaxTokens
|
||||
request.MaxTokens = lo.ToPtr(uint(0))
|
||||
request.MaxTokens = nil
|
||||
}
|
||||
if strings.HasSuffix(request.Model, "-high") {
|
||||
request.ReasoningEffort = "high"
|
||||
|
||||
@@ -43,12 +43,12 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
|
||||
helper.SetEventStreamHeaders(c)
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) {
|
||||
var xAIResp *dto.ChatCompletionsStreamResponse
|
||||
err := common.UnmarshalJsonStr(data, &xAIResp)
|
||||
if err != nil {
|
||||
if err := common.UnmarshalJsonStr(data, &xAIResp); err != nil {
|
||||
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
sr.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
// 把 xAI 的usage转换为 OpenAI 的usage
|
||||
@@ -61,11 +61,10 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
|
||||
openaiResponse := streamResponseXAI2OpenAI(xAIResp, usage)
|
||||
_ = openai.ProcessStreamResponse(*openaiResponse, &responseTextBuilder, &toolCount)
|
||||
err = helper.ObjectData(c, openaiResponse)
|
||||
if err != nil {
|
||||
if err := helper.ObjectData(c, openaiResponse); err != nil {
|
||||
common.SysLog(err.Error())
|
||||
sr.Error(err)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if !containStreamUsage {
|
||||
|
||||
@@ -122,7 +122,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
return newApiErr
|
||||
}
|
||||
|
||||
service.PostClaudeConsumeQuota(c, info, usage)
|
||||
service.PostTextConsumeQuota(c, info, usage, nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -190,6 +190,6 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
return newAPIError
|
||||
}
|
||||
|
||||
service.PostClaudeConsumeQuota(c, info, usage.(*dto.Usage))
|
||||
service.PostTextConsumeQuota(c, info, usage.(*dto.Usage), nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -162,6 +162,8 @@ type RelayInfo struct {
|
||||
// 若为空,调用 GetFinalRequestRelayFormat 会回退到 RequestConversionChain 的最后一项或 RelayFormat。
|
||||
FinalRequestRelayFormat types.RelayFormat
|
||||
|
||||
StreamStatus *StreamStatus
|
||||
|
||||
ThinkingContentInfo
|
||||
TokenCountMeta
|
||||
*ClaudeConvertInfo
|
||||
@@ -338,15 +340,10 @@ func GenRelayInfoClaude(c *gin.Context, request dto.Request) *RelayInfo {
|
||||
info.ClaudeConvertInfo = &ClaudeConvertInfo{
|
||||
LastMessagesType: LastMessageTypeNone,
|
||||
}
|
||||
info.IsClaudeBetaQuery = c.Query("beta") == "true" || isClaudeBetaForced(c)
|
||||
info.IsClaudeBetaQuery = c.Query("beta") == "true"
|
||||
return info
|
||||
}
|
||||
|
||||
func isClaudeBetaForced(c *gin.Context) bool {
|
||||
channelOtherSettings, ok := common.GetContextKeyType[dto.ChannelOtherSettings](c, constant.ContextKeyChannelOtherSetting)
|
||||
return ok && channelOtherSettings.ClaudeBetaQuery
|
||||
}
|
||||
|
||||
func GenRelayInfoRerank(c *gin.Context, request *dto.RerankRequest) *RelayInfo {
|
||||
info := genBaseRelayInfo(c, request)
|
||||
info.RelayMode = relayconstant.RelayModeRerank
|
||||
|
||||
@@ -0,0 +1,112 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type StreamEndReason string
|
||||
|
||||
const (
|
||||
StreamEndReasonNone StreamEndReason = ""
|
||||
StreamEndReasonDone StreamEndReason = "done"
|
||||
StreamEndReasonTimeout StreamEndReason = "timeout"
|
||||
StreamEndReasonClientGone StreamEndReason = "client_gone"
|
||||
StreamEndReasonScannerErr StreamEndReason = "scanner_error"
|
||||
StreamEndReasonHandlerStop StreamEndReason = "handler_stop"
|
||||
StreamEndReasonEOF StreamEndReason = "eof"
|
||||
StreamEndReasonPanic StreamEndReason = "panic"
|
||||
StreamEndReasonPingFail StreamEndReason = "ping_fail"
|
||||
)
|
||||
|
||||
const maxStreamErrorEntries = 20
|
||||
|
||||
type StreamErrorEntry struct {
|
||||
Message string
|
||||
Timestamp time.Time
|
||||
}
|
||||
|
||||
type StreamStatus struct {
|
||||
EndReason StreamEndReason
|
||||
EndError error
|
||||
endOnce sync.Once
|
||||
|
||||
mu sync.Mutex
|
||||
Errors []StreamErrorEntry
|
||||
ErrorCount int
|
||||
}
|
||||
|
||||
func NewStreamStatus() *StreamStatus {
|
||||
return &StreamStatus{}
|
||||
}
|
||||
|
||||
func (s *StreamStatus) SetEndReason(reason StreamEndReason, err error) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.endOnce.Do(func() {
|
||||
s.EndReason = reason
|
||||
s.EndError = err
|
||||
})
|
||||
}
|
||||
|
||||
func (s *StreamStatus) RecordError(msg string) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.ErrorCount++
|
||||
if len(s.Errors) < maxStreamErrorEntries {
|
||||
s.Errors = append(s.Errors, StreamErrorEntry{
|
||||
Message: msg,
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StreamStatus) HasErrors() bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.ErrorCount > 0
|
||||
}
|
||||
|
||||
func (s *StreamStatus) TotalErrorCount() int {
|
||||
if s == nil {
|
||||
return 0
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.ErrorCount
|
||||
}
|
||||
|
||||
func (s *StreamStatus) IsNormalEnd() bool {
|
||||
if s == nil {
|
||||
return true
|
||||
}
|
||||
return s.EndReason == StreamEndReasonDone ||
|
||||
s.EndReason == StreamEndReasonEOF ||
|
||||
s.EndReason == StreamEndReasonHandlerStop
|
||||
}
|
||||
|
||||
func (s *StreamStatus) Summary() string {
|
||||
if s == nil {
|
||||
return "StreamStatus<nil>"
|
||||
}
|
||||
b := &strings.Builder{}
|
||||
fmt.Fprintf(b, "reason=%s", s.EndReason)
|
||||
if s.EndError != nil {
|
||||
fmt.Fprintf(b, " end_error=%q", s.EndError.Error())
|
||||
}
|
||||
s.mu.Lock()
|
||||
if s.ErrorCount > 0 {
|
||||
fmt.Fprintf(b, " soft_errors=%d", s.ErrorCount)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
return b.String()
|
||||
}
|
||||
@@ -0,0 +1,182 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestStreamStatus_SetEndReason_FirstWins(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := NewStreamStatus()
|
||||
|
||||
s.SetEndReason(StreamEndReasonDone, nil)
|
||||
s.SetEndReason(StreamEndReasonTimeout, nil)
|
||||
s.SetEndReason(StreamEndReasonClientGone, fmt.Errorf("context canceled"))
|
||||
|
||||
assert.Equal(t, StreamEndReasonDone, s.EndReason)
|
||||
assert.Nil(t, s.EndError)
|
||||
}
|
||||
|
||||
func TestStreamStatus_SetEndReason_WithError(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := NewStreamStatus()
|
||||
|
||||
expectedErr := fmt.Errorf("read: connection reset")
|
||||
s.SetEndReason(StreamEndReasonScannerErr, expectedErr)
|
||||
|
||||
assert.Equal(t, StreamEndReasonScannerErr, s.EndReason)
|
||||
assert.Equal(t, expectedErr, s.EndError)
|
||||
}
|
||||
|
||||
func TestStreamStatus_SetEndReason_NilSafe(t *testing.T) {
|
||||
t.Parallel()
|
||||
var s *StreamStatus
|
||||
s.SetEndReason(StreamEndReasonDone, nil)
|
||||
}
|
||||
|
||||
func TestStreamStatus_SetEndReason_Concurrent(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := NewStreamStatus()
|
||||
|
||||
reasons := []StreamEndReason{
|
||||
StreamEndReasonDone,
|
||||
StreamEndReasonTimeout,
|
||||
StreamEndReasonClientGone,
|
||||
StreamEndReasonScannerErr,
|
||||
StreamEndReasonHandlerStop,
|
||||
StreamEndReasonEOF,
|
||||
StreamEndReasonPanic,
|
||||
StreamEndReasonPingFail,
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, r := range reasons {
|
||||
wg.Add(1)
|
||||
go func(reason StreamEndReason) {
|
||||
defer wg.Done()
|
||||
s.SetEndReason(reason, nil)
|
||||
}(r)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
assert.NotEqual(t, StreamEndReasonNone, s.EndReason)
|
||||
}
|
||||
|
||||
func TestStreamStatus_RecordError_Basic(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := NewStreamStatus()
|
||||
|
||||
s.RecordError("bad json")
|
||||
s.RecordError("another bad json")
|
||||
s.RecordError("client gone")
|
||||
|
||||
assert.True(t, s.HasErrors())
|
||||
assert.Equal(t, 3, s.TotalErrorCount())
|
||||
assert.Len(t, s.Errors, 3)
|
||||
}
|
||||
|
||||
func TestStreamStatus_RecordError_CapAtMax(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := NewStreamStatus()
|
||||
|
||||
for i := 0; i < 30; i++ {
|
||||
s.RecordError(fmt.Sprintf("error_%d", i))
|
||||
}
|
||||
|
||||
assert.Equal(t, maxStreamErrorEntries, len(s.Errors))
|
||||
assert.Equal(t, 30, s.TotalErrorCount())
|
||||
}
|
||||
|
||||
func TestStreamStatus_RecordError_NilSafe(t *testing.T) {
|
||||
t.Parallel()
|
||||
var s *StreamStatus
|
||||
s.RecordError("should not panic")
|
||||
}
|
||||
|
||||
func TestStreamStatus_RecordError_Concurrent(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := NewStreamStatus()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
s.RecordError(fmt.Sprintf("error_%d", idx))
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
assert.Equal(t, 100, s.TotalErrorCount())
|
||||
assert.LessOrEqual(t, len(s.Errors), maxStreamErrorEntries)
|
||||
}
|
||||
|
||||
func TestStreamStatus_HasErrors_Empty(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := NewStreamStatus()
|
||||
assert.False(t, s.HasErrors())
|
||||
assert.Equal(t, 0, s.TotalErrorCount())
|
||||
}
|
||||
|
||||
func TestStreamStatus_HasErrors_NilSafe(t *testing.T) {
|
||||
t.Parallel()
|
||||
var s *StreamStatus
|
||||
assert.False(t, s.HasErrors())
|
||||
assert.Equal(t, 0, s.TotalErrorCount())
|
||||
}
|
||||
|
||||
func TestStreamStatus_IsNormalEnd(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
reason StreamEndReason
|
||||
normal bool
|
||||
}{
|
||||
{StreamEndReasonDone, true},
|
||||
{StreamEndReasonEOF, true},
|
||||
{StreamEndReasonHandlerStop, true},
|
||||
{StreamEndReasonTimeout, false},
|
||||
{StreamEndReasonClientGone, false},
|
||||
{StreamEndReasonScannerErr, false},
|
||||
{StreamEndReasonPanic, false},
|
||||
{StreamEndReasonPingFail, false},
|
||||
{StreamEndReasonNone, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
s := NewStreamStatus()
|
||||
s.SetEndReason(tt.reason, nil)
|
||||
assert.Equal(t, tt.normal, s.IsNormalEnd(), "reason=%s", tt.reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamStatus_IsNormalEnd_NilSafe(t *testing.T) {
|
||||
t.Parallel()
|
||||
var s *StreamStatus
|
||||
assert.True(t, s.IsNormalEnd())
|
||||
}
|
||||
|
||||
func TestStreamStatus_Summary(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s := NewStreamStatus()
|
||||
s.SetEndReason(StreamEndReasonDone, nil)
|
||||
summary := s.Summary()
|
||||
assert.Contains(t, summary, "reason=done")
|
||||
assert.NotContains(t, summary, "soft_errors")
|
||||
|
||||
s2 := NewStreamStatus()
|
||||
s2.SetEndReason(StreamEndReasonTimeout, nil)
|
||||
s2.RecordError("bad json")
|
||||
s2.RecordError("write failed")
|
||||
summary2 := s2.Summary()
|
||||
assert.Contains(t, summary2, "reason=timeout")
|
||||
assert.Contains(t, summary2, "soft_errors=2")
|
||||
}
|
||||
|
||||
func TestStreamStatus_Summary_NilSafe(t *testing.T) {
|
||||
t.Parallel()
|
||||
var s *StreamStatus
|
||||
assert.Equal(t, "StreamStatus<nil>", s.Summary())
|
||||
}
|
||||
+2
-293
@@ -6,25 +6,20 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
relayconstant "github.com/QuantumNous/new-api/relay/constant"
|
||||
"github.com/QuantumNous/new-api/relay/helper"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/setting/model_setting"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
"github.com/QuantumNous/new-api/setting/ratio_setting"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
"github.com/samber/lo"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -93,7 +88,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
|
||||
if containAudioTokens && containsAudioRatios {
|
||||
service.PostAudioConsumeQuota(c, info, usage, "")
|
||||
} else {
|
||||
postConsumeQuota(c, info, usage)
|
||||
service.PostTextConsumeQuota(c, info, usage, nil)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -216,293 +211,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
|
||||
if containAudioTokens && containsAudioRatios {
|
||||
service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||||
} else {
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage))
|
||||
service.PostTextConsumeQuota(c, info, usage.(*dto.Usage), nil)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent ...string) {
|
||||
originUsage := usage
|
||||
if usage == nil {
|
||||
usage = &dto.Usage{
|
||||
PromptTokens: relayInfo.GetEstimatePromptTokens(),
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: relayInfo.GetEstimatePromptTokens(),
|
||||
}
|
||||
extraContent = append(extraContent, "上游无计费信息")
|
||||
}
|
||||
|
||||
if originUsage != nil {
|
||||
service.ObserveChannelAffinityUsageCacheByRelayFormat(ctx, usage, relayInfo.GetFinalRequestRelayFormat())
|
||||
}
|
||||
|
||||
adminRejectReason := common.GetContextKeyString(ctx, constant.ContextKeyAdminRejectReason)
|
||||
|
||||
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
|
||||
promptTokens := usage.PromptTokens
|
||||
cacheTokens := usage.PromptTokensDetails.CachedTokens
|
||||
imageTokens := usage.PromptTokensDetails.ImageTokens
|
||||
audioTokens := usage.PromptTokensDetails.AudioTokens
|
||||
completionTokens := usage.CompletionTokens
|
||||
cachedCreationTokens := usage.PromptTokensDetails.CachedCreationTokens
|
||||
|
||||
modelName := relayInfo.OriginModelName
|
||||
|
||||
tokenName := ctx.GetString("token_name")
|
||||
completionRatio := relayInfo.PriceData.CompletionRatio
|
||||
cacheRatio := relayInfo.PriceData.CacheRatio
|
||||
imageRatio := relayInfo.PriceData.ImageRatio
|
||||
modelRatio := relayInfo.PriceData.ModelRatio
|
||||
groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
|
||||
modelPrice := relayInfo.PriceData.ModelPrice
|
||||
cachedCreationRatio := relayInfo.PriceData.CacheCreationRatio
|
||||
|
||||
// Convert values to decimal for precise calculation
|
||||
dPromptTokens := decimal.NewFromInt(int64(promptTokens))
|
||||
dCacheTokens := decimal.NewFromInt(int64(cacheTokens))
|
||||
dImageTokens := decimal.NewFromInt(int64(imageTokens))
|
||||
dAudioTokens := decimal.NewFromInt(int64(audioTokens))
|
||||
dCompletionTokens := decimal.NewFromInt(int64(completionTokens))
|
||||
dCachedCreationTokens := decimal.NewFromInt(int64(cachedCreationTokens))
|
||||
dCompletionRatio := decimal.NewFromFloat(completionRatio)
|
||||
dCacheRatio := decimal.NewFromFloat(cacheRatio)
|
||||
dImageRatio := decimal.NewFromFloat(imageRatio)
|
||||
dModelRatio := decimal.NewFromFloat(modelRatio)
|
||||
dGroupRatio := decimal.NewFromFloat(groupRatio)
|
||||
dModelPrice := decimal.NewFromFloat(modelPrice)
|
||||
dCachedCreationRatio := decimal.NewFromFloat(cachedCreationRatio)
|
||||
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
|
||||
|
||||
ratio := dModelRatio.Mul(dGroupRatio)
|
||||
|
||||
// openai web search 工具计费
|
||||
var dWebSearchQuota decimal.Decimal
|
||||
var webSearchPrice float64
|
||||
// response api 格式工具计费
|
||||
if relayInfo.ResponsesUsageInfo != nil {
|
||||
if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool.CallCount > 0 {
|
||||
// 计算 web search 调用的配额 (配额 = 价格 * 调用次数 / 1000 * 分组倍率)
|
||||
webSearchPrice = operation_setting.GetWebSearchPricePerThousand(modelName, webSearchTool.SearchContextSize)
|
||||
dWebSearchQuota = decimal.NewFromFloat(webSearchPrice).
|
||||
Mul(decimal.NewFromInt(int64(webSearchTool.CallCount))).
|
||||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
extraContent = append(extraContent, fmt.Sprintf("Web Search 调用 %d 次,上下文大小 %s,调用花费 %s",
|
||||
webSearchTool.CallCount, webSearchTool.SearchContextSize, dWebSearchQuota.String()))
|
||||
}
|
||||
} else if strings.HasSuffix(modelName, "search-preview") {
|
||||
// search-preview 模型不支持 response api
|
||||
searchContextSize := ctx.GetString("chat_completion_web_search_context_size")
|
||||
if searchContextSize == "" {
|
||||
searchContextSize = "medium"
|
||||
}
|
||||
webSearchPrice = operation_setting.GetWebSearchPricePerThousand(modelName, searchContextSize)
|
||||
dWebSearchQuota = decimal.NewFromFloat(webSearchPrice).
|
||||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
extraContent = append(extraContent, fmt.Sprintf("Web Search 调用 1 次,上下文大小 %s,调用花费 %s",
|
||||
searchContextSize, dWebSearchQuota.String()))
|
||||
}
|
||||
// claude web search tool 计费
|
||||
var dClaudeWebSearchQuota decimal.Decimal
|
||||
var claudeWebSearchPrice float64
|
||||
claudeWebSearchCallCount := ctx.GetInt("claude_web_search_requests")
|
||||
if claudeWebSearchCallCount > 0 {
|
||||
claudeWebSearchPrice = operation_setting.GetClaudeWebSearchPricePerThousand()
|
||||
dClaudeWebSearchQuota = decimal.NewFromFloat(claudeWebSearchPrice).
|
||||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit).Mul(decimal.NewFromInt(int64(claudeWebSearchCallCount)))
|
||||
extraContent = append(extraContent, fmt.Sprintf("Claude Web Search 调用 %d 次,调用花费 %s",
|
||||
claudeWebSearchCallCount, dClaudeWebSearchQuota.String()))
|
||||
}
|
||||
// file search tool 计费
|
||||
var dFileSearchQuota decimal.Decimal
|
||||
var fileSearchPrice float64
|
||||
if relayInfo.ResponsesUsageInfo != nil {
|
||||
if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists && fileSearchTool.CallCount > 0 {
|
||||
fileSearchPrice = operation_setting.GetFileSearchPricePerThousand()
|
||||
dFileSearchQuota = decimal.NewFromFloat(fileSearchPrice).
|
||||
Mul(decimal.NewFromInt(int64(fileSearchTool.CallCount))).
|
||||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
extraContent = append(extraContent, fmt.Sprintf("File Search 调用 %d 次,调用花费 %s",
|
||||
fileSearchTool.CallCount, dFileSearchQuota.String()))
|
||||
}
|
||||
}
|
||||
var dImageGenerationCallQuota decimal.Decimal
|
||||
var imageGenerationCallPrice float64
|
||||
if ctx.GetBool("image_generation_call") {
|
||||
imageGenerationCallPrice = operation_setting.GetGPTImage1PriceOnceCall(ctx.GetString("image_generation_call_quality"), ctx.GetString("image_generation_call_size"))
|
||||
dImageGenerationCallQuota = decimal.NewFromFloat(imageGenerationCallPrice).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
extraContent = append(extraContent, fmt.Sprintf("Image Generation Call 花费 %s", dImageGenerationCallQuota.String()))
|
||||
}
|
||||
|
||||
var quotaCalculateDecimal decimal.Decimal
|
||||
|
||||
var audioInputQuota decimal.Decimal
|
||||
var audioInputPrice float64
|
||||
isClaudeUsageSemantic := relayInfo.GetFinalRequestRelayFormat() == types.RelayFormatClaude
|
||||
if !relayInfo.PriceData.UsePrice {
|
||||
baseTokens := dPromptTokens
|
||||
// 减去 cached tokens
|
||||
// Anthropic API 的 input_tokens 已经不包含缓存 tokens,不需要减去
|
||||
// OpenAI/OpenRouter 等 API 的 prompt_tokens 包含缓存 tokens,需要减去
|
||||
var cachedTokensWithRatio decimal.Decimal
|
||||
if !dCacheTokens.IsZero() {
|
||||
if !isClaudeUsageSemantic {
|
||||
baseTokens = baseTokens.Sub(dCacheTokens)
|
||||
}
|
||||
cachedTokensWithRatio = dCacheTokens.Mul(dCacheRatio)
|
||||
}
|
||||
var dCachedCreationTokensWithRatio decimal.Decimal
|
||||
if !dCachedCreationTokens.IsZero() {
|
||||
if !isClaudeUsageSemantic {
|
||||
baseTokens = baseTokens.Sub(dCachedCreationTokens)
|
||||
}
|
||||
dCachedCreationTokensWithRatio = dCachedCreationTokens.Mul(dCachedCreationRatio)
|
||||
}
|
||||
|
||||
// 减去 image tokens
|
||||
var imageTokensWithRatio decimal.Decimal
|
||||
if !dImageTokens.IsZero() {
|
||||
baseTokens = baseTokens.Sub(dImageTokens)
|
||||
imageTokensWithRatio = dImageTokens.Mul(dImageRatio)
|
||||
}
|
||||
|
||||
// 减去 Gemini audio tokens
|
||||
if !dAudioTokens.IsZero() {
|
||||
audioInputPrice = operation_setting.GetGeminiInputAudioPricePerMillionTokens(modelName)
|
||||
if audioInputPrice > 0 {
|
||||
// 重新计算 base tokens
|
||||
baseTokens = baseTokens.Sub(dAudioTokens)
|
||||
audioInputQuota = decimal.NewFromFloat(audioInputPrice).Div(decimal.NewFromInt(1000000)).Mul(dAudioTokens).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
extraContent = append(extraContent, fmt.Sprintf("Audio Input 花费 %s", audioInputQuota.String()))
|
||||
}
|
||||
}
|
||||
promptQuota := baseTokens.Add(cachedTokensWithRatio).
|
||||
Add(imageTokensWithRatio).
|
||||
Add(dCachedCreationTokensWithRatio)
|
||||
|
||||
completionQuota := dCompletionTokens.Mul(dCompletionRatio)
|
||||
|
||||
quotaCalculateDecimal = promptQuota.Add(completionQuota).Mul(ratio)
|
||||
|
||||
if !ratio.IsZero() && quotaCalculateDecimal.LessThanOrEqual(decimal.Zero) {
|
||||
quotaCalculateDecimal = decimal.NewFromInt(1)
|
||||
}
|
||||
} else {
|
||||
quotaCalculateDecimal = dModelPrice.Mul(dQuotaPerUnit).Mul(dGroupRatio)
|
||||
}
|
||||
// 添加 responses tools call 调用的配额
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dWebSearchQuota)
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota)
|
||||
// 添加 audio input 独立计费
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota)
|
||||
// 添加 image generation call 计费
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dImageGenerationCallQuota)
|
||||
|
||||
if len(relayInfo.PriceData.OtherRatios) > 0 {
|
||||
for key, otherRatio := range relayInfo.PriceData.OtherRatios {
|
||||
dOtherRatio := decimal.NewFromFloat(otherRatio)
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Mul(dOtherRatio)
|
||||
extraContent = append(extraContent, fmt.Sprintf("其他倍率 %s: %f", key, otherRatio))
|
||||
}
|
||||
}
|
||||
|
||||
quota := int(quotaCalculateDecimal.Round(0).IntPart())
|
||||
totalTokens := promptTokens + completionTokens
|
||||
|
||||
//var logContent string
|
||||
|
||||
// record all the consume log even if quota is 0
|
||||
if totalTokens == 0 {
|
||||
// in this case, must be some error happened
|
||||
// we cannot just return, because we may have to return the pre-consumed quota
|
||||
quota = 0
|
||||
extraContent = append(extraContent, "上游没有返回计费信息,无法扣费(可能是上游超时)")
|
||||
logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
|
||||
"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota))
|
||||
} else {
|
||||
if !ratio.IsZero() && quota == 0 {
|
||||
quota = 1
|
||||
}
|
||||
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
|
||||
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
|
||||
}
|
||||
|
||||
if err := service.SettleBilling(ctx, relayInfo, quota); err != nil {
|
||||
logger.LogError(ctx, "error settling billing: "+err.Error())
|
||||
}
|
||||
|
||||
logModel := modelName
|
||||
if strings.HasPrefix(logModel, "gpt-4-gizmo") {
|
||||
logModel = "gpt-4-gizmo-*"
|
||||
extraContent = append(extraContent, fmt.Sprintf("模型 %s", modelName))
|
||||
}
|
||||
if strings.HasPrefix(logModel, "gpt-4o-gizmo") {
|
||||
logModel = "gpt-4o-gizmo-*"
|
||||
extraContent = append(extraContent, fmt.Sprintf("模型 %s", modelName))
|
||||
}
|
||||
logContent := strings.Join(extraContent, ", ")
|
||||
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
if adminRejectReason != "" {
|
||||
other["reject_reason"] = adminRejectReason
|
||||
}
|
||||
// For chat-based calls to the Claude model, tagging is required. Using Claude's rendering logs, the two approaches handle input rendering differently.
|
||||
if isClaudeUsageSemantic {
|
||||
other["claude"] = true
|
||||
other["usage_semantic"] = "anthropic"
|
||||
}
|
||||
if imageTokens != 0 {
|
||||
other["image"] = true
|
||||
other["image_ratio"] = imageRatio
|
||||
other["image_output"] = imageTokens
|
||||
}
|
||||
if cachedCreationTokens != 0 {
|
||||
other["cache_creation_tokens"] = cachedCreationTokens
|
||||
other["cache_creation_ratio"] = cachedCreationRatio
|
||||
}
|
||||
if !dWebSearchQuota.IsZero() {
|
||||
if relayInfo.ResponsesUsageInfo != nil {
|
||||
if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists {
|
||||
other["web_search"] = true
|
||||
other["web_search_call_count"] = webSearchTool.CallCount
|
||||
other["web_search_price"] = webSearchPrice
|
||||
}
|
||||
} else if strings.HasSuffix(modelName, "search-preview") {
|
||||
other["web_search"] = true
|
||||
other["web_search_call_count"] = 1
|
||||
other["web_search_price"] = webSearchPrice
|
||||
}
|
||||
} else if !dClaudeWebSearchQuota.IsZero() {
|
||||
other["web_search"] = true
|
||||
other["web_search_call_count"] = claudeWebSearchCallCount
|
||||
other["web_search_price"] = claudeWebSearchPrice
|
||||
}
|
||||
if !dFileSearchQuota.IsZero() && relayInfo.ResponsesUsageInfo != nil {
|
||||
if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists {
|
||||
other["file_search"] = true
|
||||
other["file_search_call_count"] = fileSearchTool.CallCount
|
||||
other["file_search_price"] = fileSearchPrice
|
||||
}
|
||||
}
|
||||
if !audioInputQuota.IsZero() {
|
||||
other["audio_input_seperate_price"] = true
|
||||
other["audio_input_token_count"] = audioTokens
|
||||
other["audio_input_price"] = audioInputPrice
|
||||
}
|
||||
if !dImageGenerationCallQuota.IsZero() {
|
||||
other["image_generation_call"] = true
|
||||
other["image_generation_call_price"] = imageGenerationCallPrice
|
||||
}
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
|
||||
ChannelId: relayInfo.ChannelId,
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
ModelName: logModel,
|
||||
TokenName: tokenName,
|
||||
Quota: quota,
|
||||
Content: logContent,
|
||||
TokenId: relayInfo.TokenId,
|
||||
UseTimeSeconds: int(useTimeSeconds),
|
||||
IsStream: relayInfo.IsStream,
|
||||
Group: relayInfo.UsingGroup,
|
||||
Other: other,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -82,6 +82,6 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
}
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage))
|
||||
service.PostTextConsumeQuota(c, info, usage.(*dto.Usage), nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -194,7 +194,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
return openaiErr
|
||||
}
|
||||
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage))
|
||||
service.PostTextConsumeQuota(c, info, usage.(*dto.Usage), nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -288,6 +288,6 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI
|
||||
return openaiErr
|
||||
}
|
||||
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage))
|
||||
service.PostTextConsumeQuota(c, info, usage.(*dto.Usage), nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
package helper
|
||||
|
||||
import (
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
)
|
||||
|
||||
// StreamResult is passed to each dataHandler invocation, providing methods
|
||||
// to record soft errors, signal fatal stops, or mark normal completion.
|
||||
// StreamScannerHandler checks IsStopped() after each callback invocation.
|
||||
type StreamResult struct {
|
||||
status *relaycommon.StreamStatus
|
||||
stopped bool
|
||||
}
|
||||
|
||||
func newStreamResult(status *relaycommon.StreamStatus) *StreamResult {
|
||||
return &StreamResult{status: status}
|
||||
}
|
||||
|
||||
// Error records a soft error. The stream continues processing.
|
||||
// Can be called multiple times per chunk.
|
||||
func (r *StreamResult) Error(err error) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
r.status.RecordError(err.Error())
|
||||
}
|
||||
|
||||
// Stop records a fatal error and marks the stream to stop after this chunk.
|
||||
func (r *StreamResult) Stop(err error) {
|
||||
if err != nil {
|
||||
r.status.RecordError(err.Error())
|
||||
}
|
||||
r.status.SetEndReason(relaycommon.StreamEndReasonHandlerStop, err)
|
||||
r.stopped = true
|
||||
}
|
||||
|
||||
// Done signals that the handler has finished processing normally
|
||||
// (e.g., Dify "message_end"). The stream stops after this chunk.
|
||||
func (r *StreamResult) Done() {
|
||||
r.status.SetEndReason(relaycommon.StreamEndReasonDone, nil)
|
||||
r.stopped = true
|
||||
}
|
||||
|
||||
// IsStopped returns whether Stop() or Done() was called during this chunk.
|
||||
func (r *StreamResult) IsStopped() bool {
|
||||
return r.stopped
|
||||
}
|
||||
|
||||
// reset clears the per-chunk stopped flag so the object can be reused.
|
||||
func (r *StreamResult) reset() {
|
||||
r.stopped = false
|
||||
}
|
||||
@@ -34,12 +34,15 @@ func getScannerBufferSize() int {
|
||||
return DefaultMaxScannerBufferSize
|
||||
}
|
||||
|
||||
func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string) bool) {
|
||||
func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string, sr *StreamResult)) {
|
||||
|
||||
if resp == nil || dataHandler == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 无条件新建 StreamStatus
|
||||
info.StreamStatus = relaycommon.NewStreamStatus()
|
||||
|
||||
// 确保响应体总是被关闭
|
||||
defer func() {
|
||||
if resp.Body != nil {
|
||||
@@ -121,6 +124,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
wg.Done()
|
||||
if r := recover(); r != nil {
|
||||
logger.LogError(c, fmt.Sprintf("ping goroutine panic: %v", r))
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonPanic, fmt.Errorf("ping panic: %v", r))
|
||||
common.SafeSendBool(stopChan, true)
|
||||
}
|
||||
if common.DebugEnabled {
|
||||
@@ -148,6 +152,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
logger.LogError(c, "ping data error: "+err.Error())
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonPingFail, err)
|
||||
return
|
||||
}
|
||||
if common.DebugEnabled {
|
||||
@@ -155,6 +160,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
}
|
||||
case <-time.After(10 * time.Second):
|
||||
logger.LogError(c, "ping data send timeout")
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonPingFail, fmt.Errorf("ping send timeout"))
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
@@ -184,14 +190,17 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
wg.Done()
|
||||
if r := recover(); r != nil {
|
||||
logger.LogError(c, fmt.Sprintf("data handler goroutine panic: %v", r))
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonPanic, fmt.Errorf("handler panic: %v", r))
|
||||
}
|
||||
common.SafeSendBool(stopChan, true)
|
||||
}()
|
||||
sr := newStreamResult(info.StreamStatus)
|
||||
for data := range dataChan {
|
||||
sr.reset()
|
||||
writeMutex.Lock()
|
||||
success := dataHandler(data)
|
||||
dataHandler(data, sr)
|
||||
writeMutex.Unlock()
|
||||
if !success {
|
||||
if sr.IsStopped() {
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -205,6 +214,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
wg.Done()
|
||||
if r := recover(); r != nil {
|
||||
logger.LogError(c, fmt.Sprintf("scanner goroutine panic: %v", r))
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonPanic, fmt.Errorf("scanner panic: %v", r))
|
||||
}
|
||||
common.SafeSendBool(stopChan, true)
|
||||
if common.DebugEnabled {
|
||||
@@ -220,6 +230,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-c.Request.Context().Done():
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, c.Request.Context().Err())
|
||||
return
|
||||
default:
|
||||
}
|
||||
@@ -253,7 +264,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// done, 处理完成标志,直接退出停止读取剩余数据防止出错
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonDone, nil)
|
||||
if common.DebugEnabled {
|
||||
println("received [DONE], stopping scanner")
|
||||
}
|
||||
@@ -264,20 +275,25 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
if err := scanner.Err(); err != nil {
|
||||
if err != io.EOF {
|
||||
logger.LogError(c, "scanner error: "+err.Error())
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonScannerErr, err)
|
||||
}
|
||||
}
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonEOF, nil)
|
||||
})
|
||||
|
||||
// 主循环等待完成或超时
|
||||
select {
|
||||
case <-ticker.C:
|
||||
// 超时处理逻辑
|
||||
logger.LogError(c, "streaming timeout")
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonTimeout, nil)
|
||||
case <-stopChan:
|
||||
// 正常结束
|
||||
logger.LogInfo(c, "streaming finished")
|
||||
// EndReason already set by the goroutine that triggered stopChan
|
||||
case <-c.Request.Context().Done():
|
||||
// 客户端断开连接
|
||||
logger.LogInfo(c, "client disconnected")
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, c.Request.Context().Err())
|
||||
}
|
||||
|
||||
if info.StreamStatus.IsNormalEnd() && !info.StreamStatus.HasErrors() {
|
||||
logger.LogInfo(c, fmt.Sprintf("stream ended: %s", info.StreamStatus.Summary()))
|
||||
} else {
|
||||
logger.LogError(c, fmt.Sprintf("stream ended: %s, received=%d", info.StreamStatus.Summary(), info.ReceivedResponseCount))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -56,8 +56,6 @@ func buildSSEBody(n int) string {
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// slowReader wraps a reader and injects a delay before each Read call,
|
||||
// simulating a slow upstream that trickles data.
|
||||
type slowReader struct {
|
||||
r io.Reader
|
||||
delay time.Duration
|
||||
@@ -79,7 +77,7 @@ func TestStreamScannerHandler_NilInputs(t *testing.T) {
|
||||
|
||||
info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}}
|
||||
|
||||
StreamScannerHandler(c, nil, info, func(data string) bool { return true })
|
||||
StreamScannerHandler(c, nil, info, func(data string, sr *StreamResult) {})
|
||||
StreamScannerHandler(c, &http.Response{Body: io.NopCloser(strings.NewReader(""))}, info, nil)
|
||||
}
|
||||
|
||||
@@ -89,9 +87,8 @@ func TestStreamScannerHandler_EmptyBody(t *testing.T) {
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(""))
|
||||
|
||||
var called atomic.Bool
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
called.Store(true)
|
||||
return true
|
||||
})
|
||||
|
||||
assert.False(t, called.Load(), "handler should not be called for empty body")
|
||||
@@ -105,9 +102,8 @@ func TestStreamScannerHandler_1000Chunks(t *testing.T) {
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
var count atomic.Int64
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
|
||||
assert.Equal(t, int64(numChunks), count.Load())
|
||||
@@ -124,9 +120,8 @@ func TestStreamScannerHandler_10000Chunks(t *testing.T) {
|
||||
var count atomic.Int64
|
||||
start := time.Now()
|
||||
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
|
||||
elapsed := time.Since(start)
|
||||
@@ -145,11 +140,10 @@ func TestStreamScannerHandler_OrderPreserved(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
received := make([]string, 0, numChunks)
|
||||
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
mu.Lock()
|
||||
received = append(received, data)
|
||||
mu.Unlock()
|
||||
return true
|
||||
})
|
||||
|
||||
require.Equal(t, numChunks, len(received))
|
||||
@@ -166,31 +160,32 @@ func TestStreamScannerHandler_DoneStopsScanner(t *testing.T) {
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
var count atomic.Int64
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
|
||||
assert.Equal(t, int64(50), count.Load(), "data after [DONE] must not be processed")
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_HandlerFailureStops(t *testing.T) {
|
||||
func TestStreamScannerHandler_StopStopsStream(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const numChunks = 200
|
||||
body := buildSSEBody(numChunks)
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
const failAt = 50
|
||||
const stopAt int64 = 50
|
||||
var count atomic.Int64
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
n := count.Add(1)
|
||||
return n < failAt
|
||||
if n >= stopAt {
|
||||
sr.Stop(fmt.Errorf("fatal at %d", n))
|
||||
}
|
||||
})
|
||||
|
||||
// The worker stops at failAt; the scanner may have read ahead,
|
||||
// but the handler should not be called beyond failAt.
|
||||
assert.Equal(t, int64(failAt), count.Load())
|
||||
assert.Equal(t, stopAt, count.Load())
|
||||
require.NotNil(t, info.StreamStatus)
|
||||
assert.Equal(t, relaycommon.StreamEndReasonHandlerStop, info.StreamStatus.EndReason)
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_SkipsNonDataLines(t *testing.T) {
|
||||
@@ -210,9 +205,8 @@ func TestStreamScannerHandler_SkipsNonDataLines(t *testing.T) {
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(b.String()))
|
||||
|
||||
var count atomic.Int64
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
|
||||
assert.Equal(t, int64(100), count.Load())
|
||||
@@ -225,25 +219,18 @@ func TestStreamScannerHandler_DataWithExtraSpaces(t *testing.T) {
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
var got string
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
got = data
|
||||
return true
|
||||
})
|
||||
|
||||
assert.Equal(t, "{\"trimmed\":true}", got)
|
||||
}
|
||||
|
||||
// ---------- Decoupling: scanner not blocked by slow handler ----------
|
||||
// ---------- Decoupling ----------
|
||||
|
||||
func TestStreamScannerHandler_ScannerDecoupledFromSlowHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Strategy: use a slow upstream (io.Pipe, 10ms per chunk) AND a slow handler (20ms per chunk).
|
||||
// If the scanner were synchronously coupled to the handler, total time would be
|
||||
// ~numChunks * (10ms + 20ms) = 30ms * 50 = 1500ms.
|
||||
// With decoupling, total time should be closer to
|
||||
// ~numChunks * max(10ms, 20ms) = 20ms * 50 = 1000ms
|
||||
// because the scanner reads ahead into the buffer while the handler processes.
|
||||
const numChunks = 50
|
||||
const upstreamDelay = 10 * time.Millisecond
|
||||
const handlerDelay = 20 * time.Millisecond
|
||||
@@ -273,10 +260,9 @@ func TestStreamScannerHandler_ScannerDecoupledFromSlowHandler(t *testing.T) {
|
||||
start := time.Now()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
time.Sleep(handlerDelay)
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
close(done)
|
||||
}()
|
||||
@@ -293,7 +279,6 @@ func TestStreamScannerHandler_ScannerDecoupledFromSlowHandler(t *testing.T) {
|
||||
coupledTime := time.Duration(numChunks) * (upstreamDelay + handlerDelay)
|
||||
t.Logf("elapsed=%v, coupled_estimate=%v", elapsed, coupledTime)
|
||||
|
||||
// If decoupled, elapsed should be well under the coupled estimate.
|
||||
assert.Less(t, elapsed, coupledTime*85/100,
|
||||
"decoupled elapsed time (%v) should be significantly less than coupled estimate (%v)", elapsed, coupledTime)
|
||||
}
|
||||
@@ -311,9 +296,8 @@ func TestStreamScannerHandler_SlowUpstreamFastHandler(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
close(done)
|
||||
}()
|
||||
@@ -344,8 +328,6 @@ func TestStreamScannerHandler_PingSentDuringSlowUpstream(t *testing.T) {
|
||||
setting.PingIntervalSeconds = oldSeconds
|
||||
})
|
||||
|
||||
// Create a reader that delivers data slowly: one chunk every 500ms over 3.5 seconds.
|
||||
// The ping interval is 1s, so we should see at least 2 pings.
|
||||
pr, pw := io.Pipe()
|
||||
go func() {
|
||||
defer pw.Close()
|
||||
@@ -372,9 +354,8 @@ func TestStreamScannerHandler_PingSentDuringSlowUpstream(t *testing.T) {
|
||||
var count atomic.Int64
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
close(done)
|
||||
}()
|
||||
@@ -436,9 +417,8 @@ func TestStreamScannerHandler_PingDisabledByRelayInfo(t *testing.T) {
|
||||
var count atomic.Int64
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
close(done)
|
||||
}()
|
||||
@@ -456,6 +436,199 @@ func TestStreamScannerHandler_PingDisabledByRelayInfo(t *testing.T) {
|
||||
assert.Equal(t, 0, pingCount, "pings should be disabled when DisablePing=true")
|
||||
}
|
||||
|
||||
// ---------- StreamStatus integration ----------
|
||||
|
||||
func TestStreamScannerHandler_StreamStatus_DoneReason(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := buildSSEBody(10)
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {})
|
||||
|
||||
require.NotNil(t, info.StreamStatus)
|
||||
assert.Equal(t, relaycommon.StreamEndReasonDone, info.StreamStatus.EndReason)
|
||||
assert.Nil(t, info.StreamStatus.EndError)
|
||||
assert.True(t, info.StreamStatus.IsNormalEnd())
|
||||
assert.False(t, info.StreamStatus.HasErrors())
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_StreamStatus_EOFWithoutDone(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var b strings.Builder
|
||||
for i := 0; i < 5; i++ {
|
||||
fmt.Fprintf(&b, "data: {\"id\":%d}\n", i)
|
||||
}
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(b.String()))
|
||||
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {})
|
||||
|
||||
require.NotNil(t, info.StreamStatus)
|
||||
assert.Equal(t, relaycommon.StreamEndReasonEOF, info.StreamStatus.EndReason)
|
||||
assert.True(t, info.StreamStatus.IsNormalEnd())
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_StreamStatus_HandlerStop(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := buildSSEBody(100)
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
var count atomic.Int64
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
n := count.Add(1)
|
||||
if n >= 10 {
|
||||
sr.Stop(fmt.Errorf("stop at 10"))
|
||||
}
|
||||
})
|
||||
|
||||
require.NotNil(t, info.StreamStatus)
|
||||
assert.Equal(t, relaycommon.StreamEndReasonHandlerStop, info.StreamStatus.EndReason)
|
||||
assert.True(t, info.StreamStatus.HasErrors())
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_StreamStatus_HandlerDone(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := buildSSEBody(20)
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
var count atomic.Int64
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
n := count.Add(1)
|
||||
if n >= 5 {
|
||||
sr.Done()
|
||||
}
|
||||
})
|
||||
|
||||
assert.Equal(t, int64(5), count.Load())
|
||||
require.NotNil(t, info.StreamStatus)
|
||||
assert.Equal(t, relaycommon.StreamEndReasonDone, info.StreamStatus.EndReason)
|
||||
assert.False(t, info.StreamStatus.HasErrors())
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_StreamStatus_Timeout(t *testing.T) {
|
||||
// Not parallel: modifies global constant.StreamingTimeout
|
||||
oldTimeout := constant.StreamingTimeout
|
||||
constant.StreamingTimeout = 2
|
||||
t.Cleanup(func() { constant.StreamingTimeout = oldTimeout })
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
go func() {
|
||||
fmt.Fprint(pw, "data: {\"id\":1}\n")
|
||||
time.Sleep(10 * time.Second)
|
||||
pw.Close()
|
||||
}()
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
resp := &http.Response{Body: pr}
|
||||
info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {})
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(15 * time.Second):
|
||||
t.Fatal("timed out waiting for stream timeout")
|
||||
}
|
||||
|
||||
require.NotNil(t, info.StreamStatus)
|
||||
assert.Equal(t, relaycommon.StreamEndReasonTimeout, info.StreamStatus.EndReason)
|
||||
assert.False(t, info.StreamStatus.IsNormalEnd())
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_StreamStatus_SoftErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := buildSSEBody(10)
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
sr.Error(fmt.Errorf("soft error for chunk"))
|
||||
})
|
||||
|
||||
require.NotNil(t, info.StreamStatus)
|
||||
assert.Equal(t, relaycommon.StreamEndReasonDone, info.StreamStatus.EndReason)
|
||||
assert.True(t, info.StreamStatus.HasErrors())
|
||||
assert.Equal(t, 10, info.StreamStatus.TotalErrorCount())
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_StreamStatus_MultipleErrorsPerChunk(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := buildSSEBody(5)
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
sr.Error(fmt.Errorf("error A"))
|
||||
sr.Error(fmt.Errorf("error B"))
|
||||
})
|
||||
|
||||
require.NotNil(t, info.StreamStatus)
|
||||
assert.Equal(t, relaycommon.StreamEndReasonDone, info.StreamStatus.EndReason)
|
||||
assert.Equal(t, 10, info.StreamStatus.TotalErrorCount())
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_StreamStatus_ErrorThenStop(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Use a large body without [DONE] to avoid race between scanner's [DONE]
|
||||
// and handler's Stop on the sync.Once EndReason.
|
||||
var b strings.Builder
|
||||
for i := 0; i < 100; i++ {
|
||||
fmt.Fprintf(&b, "data: {\"id\":%d}\n", i)
|
||||
}
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(b.String()))
|
||||
|
||||
var count atomic.Int64
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
count.Add(1)
|
||||
sr.Error(fmt.Errorf("soft error"))
|
||||
sr.Stop(fmt.Errorf("fatal"))
|
||||
})
|
||||
|
||||
assert.Equal(t, int64(1), count.Load())
|
||||
require.NotNil(t, info.StreamStatus)
|
||||
assert.Equal(t, relaycommon.StreamEndReasonHandlerStop, info.StreamStatus.EndReason)
|
||||
assert.Equal(t, 2, info.StreamStatus.TotalErrorCount())
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_StreamStatus_InitializedIfNil(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := buildSSEBody(1)
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
assert.Nil(t, info.StreamStatus)
|
||||
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {})
|
||||
|
||||
assert.NotNil(t, info.StreamStatus)
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_StreamStatus_PreInitialized(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := buildSSEBody(5)
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
info.StreamStatus = relaycommon.NewStreamStatus()
|
||||
info.StreamStatus.RecordError("pre-existing error")
|
||||
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {})
|
||||
|
||||
assert.Equal(t, relaycommon.StreamEndReasonDone, info.StreamStatus.EndReason)
|
||||
assert.Equal(t, 1, info.StreamStatus.TotalErrorCount())
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_PingInterleavesWithSlowUpstream(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -469,9 +642,6 @@ func TestStreamScannerHandler_PingInterleavesWithSlowUpstream(t *testing.T) {
|
||||
setting.PingIntervalSeconds = oldSeconds
|
||||
})
|
||||
|
||||
// Slow upstream + slow handler. Total stream takes ~5 seconds.
|
||||
// The ping goroutine stays alive as long as the scanner is reading,
|
||||
// so pings should fire between data writes.
|
||||
pr, pw := io.Pipe()
|
||||
go func() {
|
||||
defer pw.Close()
|
||||
@@ -498,9 +668,8 @@ func TestStreamScannerHandler_PingInterleavesWithSlowUpstream(t *testing.T) {
|
||||
var count atomic.Int64
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
close(done)
|
||||
}()
|
||||
|
||||
+12
-3
@@ -117,11 +117,20 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
||||
if request.N != nil {
|
||||
imageN = *request.N
|
||||
}
|
||||
|
||||
// n is handled via OtherRatio so it is applied exactly once in quota
|
||||
// calculation (both price-based and ratio-based paths).
|
||||
// Adaptors may have already set a more accurate count from the
|
||||
// upstream response; only set the default when they haven't.
|
||||
if _, hasN := info.PriceData.OtherRatios["n"]; !hasN {
|
||||
info.PriceData.AddOtherRatio("n", float64(imageN))
|
||||
}
|
||||
|
||||
if usage.(*dto.Usage).TotalTokens == 0 {
|
||||
usage.(*dto.Usage).TotalTokens = int(imageN)
|
||||
usage.(*dto.Usage).TotalTokens = 1
|
||||
}
|
||||
if usage.(*dto.Usage).PromptTokens == 0 {
|
||||
usage.(*dto.Usage).PromptTokens = int(imageN)
|
||||
usage.(*dto.Usage).PromptTokens = 1
|
||||
}
|
||||
|
||||
quality := "standard"
|
||||
@@ -141,6 +150,6 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
||||
logContent = append(logContent, fmt.Sprintf("生成数量 %d", imageN))
|
||||
}
|
||||
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage), logContent...)
|
||||
service.PostTextConsumeQuota(c, info, usage.(*dto.Usage), logContent)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -49,6 +49,13 @@ func RelayMidjourneyImage(c *gin.Context) {
|
||||
if httpClient == nil {
|
||||
httpClient = service.GetHttpClient()
|
||||
}
|
||||
fetchSetting := system_setting.GetFetchSetting()
|
||||
if err := common.ValidateURLWithFetchSetting(midjourneyTask.ImageUrl, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"error": fmt.Sprintf("request blocked: %v", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
resp, err := httpClient.Get(midjourneyTask.ImageUrl)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
|
||||
@@ -96,6 +96,6 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
}
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage))
|
||||
service.PostTextConsumeQuota(c, info, usage.(*dto.Usage), nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -145,7 +145,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
||||
info.PriceData = originPriceData
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
postConsumeQuota(c, info, usageDto)
|
||||
service.PostTextConsumeQuota(c, info, usageDto, nil)
|
||||
|
||||
info.OriginModelName = originModelName
|
||||
info.PriceData = originPriceData
|
||||
@@ -155,7 +155,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
||||
if strings.HasPrefix(info.OriginModelName, "gpt-4o-audio") {
|
||||
service.PostAudioConsumeQuota(c, info, usageDto, "")
|
||||
} else {
|
||||
postConsumeQuota(c, info, usageDto)
|
||||
service.PostTextConsumeQuota(c, info, usageDto, nil)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -36,10 +36,10 @@ func SetApiRouter(router *gin.Engine) {
|
||||
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
|
||||
// OAuth routes - specific routes must come before :provider wildcard
|
||||
apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode)
|
||||
apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), controller.EmailBind)
|
||||
apiRouter.POST("/oauth/email/bind", middleware.CriticalRateLimit(), controller.EmailBind)
|
||||
// Non-standard OAuth (WeChat, Telegram) - keep original routes
|
||||
apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)
|
||||
apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), controller.WeChatBind)
|
||||
apiRouter.POST("/oauth/wechat/bind", middleware.CriticalRateLimit(), controller.WeChatBind)
|
||||
apiRouter.GET("/oauth/telegram/login", middleware.CriticalRateLimit(), controller.TelegramLogin)
|
||||
apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), controller.TelegramBind)
|
||||
// Standard OAuth providers (GitHub, Discord, OIDC, LinuxDO) - unified route
|
||||
@@ -48,6 +48,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
|
||||
apiRouter.POST("/stripe/webhook", controller.StripeWebhook)
|
||||
apiRouter.POST("/creem/webhook", controller.CreemWebhook)
|
||||
apiRouter.POST("/waffo/webhook", controller.WaffoWebhook)
|
||||
|
||||
// Universal secure verification routes
|
||||
apiRouter.POST("/verify", middleware.UserAuth(), middleware.CriticalRateLimit(), controller.UniversalVerify)
|
||||
@@ -89,6 +90,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
selfRoute.POST("/stripe/pay", middleware.CriticalRateLimit(), controller.RequestStripePay)
|
||||
selfRoute.POST("/stripe/amount", controller.RequestStripeAmount)
|
||||
selfRoute.POST("/creem/pay", middleware.CriticalRateLimit(), controller.RequestCreemPay)
|
||||
selfRoute.POST("/waffo/pay", middleware.CriticalRateLimit(), controller.RequestWaffoPay)
|
||||
selfRoute.POST("/aff_transfer", controller.TransferAffQuota)
|
||||
selfRoute.PUT("/setting", controller.UpdateUserSetting)
|
||||
|
||||
@@ -192,6 +194,8 @@ func SetApiRouter(router *gin.Engine) {
|
||||
performanceRoute.DELETE("/disk_cache", controller.ClearDiskCache)
|
||||
performanceRoute.POST("/reset_stats", controller.ResetPerformanceStats)
|
||||
performanceRoute.POST("/gc", controller.ForceGC)
|
||||
performanceRoute.GET("/logs", controller.GetLogFiles)
|
||||
performanceRoute.DELETE("/logs", controller.CleanupLogFiles)
|
||||
}
|
||||
ratioSyncRoute := apiRouter.Group("/ratio_sync")
|
||||
ratioSyncRoute.Use(middleware.RootAuth())
|
||||
@@ -222,7 +226,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
channelRoute.POST("/batch", controller.DeleteChannelBatch)
|
||||
channelRoute.POST("/fix", controller.FixChannelsAbilities)
|
||||
channelRoute.GET("/fetch_models/:id", controller.FetchUpstreamModels)
|
||||
channelRoute.POST("/fetch_models", controller.FetchModels)
|
||||
channelRoute.POST("/fetch_models", middleware.RootAuth(), controller.FetchModels)
|
||||
channelRoute.POST("/codex/oauth/start", controller.StartCodexOAuth)
|
||||
channelRoute.POST("/codex/oauth/complete", controller.CompleteCodexOAuth)
|
||||
channelRoute.POST("/:id/codex/oauth/start", controller.StartCodexOAuthForChannel)
|
||||
|
||||
@@ -610,14 +610,17 @@ func ShouldSkipRetryAfterChannelAffinityFailure(c *gin.Context) bool {
|
||||
return false
|
||||
}
|
||||
v, ok := c.Get(ginKeyChannelAffinitySkipRetry)
|
||||
if ok {
|
||||
b, ok := v.(bool)
|
||||
if ok {
|
||||
return b
|
||||
}
|
||||
}
|
||||
meta, ok := getChannelAffinityMeta(c)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
b, ok := v.(bool)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return b
|
||||
return meta.SkipRetry
|
||||
}
|
||||
|
||||
func MarkChannelAffinityUsed(c *gin.Context, selectedGroup string, channelID int) {
|
||||
|
||||
@@ -116,6 +116,66 @@ func TestApplyChannelAffinityOverrideTemplate_MergeOperations(t *testing.T) {
|
||||
require.Equal(t, "trim_prefix", secondOp["mode"])
|
||||
}
|
||||
|
||||
func TestShouldSkipRetryAfterChannelAffinityFailure(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ctx func() *gin.Context
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "nil context",
|
||||
ctx: func() *gin.Context {
|
||||
return nil
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "explicit skip retry flag in context",
|
||||
ctx: func() *gin.Context {
|
||||
ctx := buildChannelAffinityTemplateContextForTest(channelAffinityMeta{
|
||||
RuleName: "rule-explicit-flag",
|
||||
SkipRetry: false,
|
||||
UsingGroup: "default",
|
||||
ModelName: "gpt-5",
|
||||
})
|
||||
ctx.Set(ginKeyChannelAffinitySkipRetry, true)
|
||||
return ctx
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "fallback to matched rule meta",
|
||||
ctx: func() *gin.Context {
|
||||
return buildChannelAffinityTemplateContextForTest(channelAffinityMeta{
|
||||
RuleName: "rule-skip-retry",
|
||||
SkipRetry: true,
|
||||
UsingGroup: "default",
|
||||
ModelName: "gpt-5",
|
||||
})
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "no flag and no skip retry meta",
|
||||
ctx: func() *gin.Context {
|
||||
return buildChannelAffinityTemplateContextForTest(channelAffinityMeta{
|
||||
RuleName: "rule-no-skip-retry",
|
||||
SkipRetry: false,
|
||||
UsingGroup: "default",
|
||||
ModelName: "gpt-5",
|
||||
})
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, ShouldSkipRetryAfterChannelAffinityFailure(tt.ctx()))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChannelAffinityHitCodexTemplatePassHeadersEffective(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
+26
-25
@@ -223,6 +223,25 @@ func generateStopBlock(index int) *dto.ClaudeResponse {
|
||||
}
|
||||
}
|
||||
|
||||
func buildClaudeUsageFromOpenAIUsage(oaiUsage *dto.Usage) *dto.ClaudeUsage {
|
||||
if oaiUsage == nil {
|
||||
return nil
|
||||
}
|
||||
usage := &dto.ClaudeUsage{
|
||||
InputTokens: oaiUsage.PromptTokens,
|
||||
OutputTokens: oaiUsage.CompletionTokens,
|
||||
CacheCreationInputTokens: oaiUsage.PromptTokensDetails.CachedCreationTokens,
|
||||
CacheReadInputTokens: oaiUsage.PromptTokensDetails.CachedTokens,
|
||||
}
|
||||
if oaiUsage.ClaudeCacheCreation5mTokens > 0 || oaiUsage.ClaudeCacheCreation1hTokens > 0 {
|
||||
usage.CacheCreation = &dto.ClaudeCacheCreationUsage{
|
||||
Ephemeral5mInputTokens: oaiUsage.ClaudeCacheCreation5mTokens,
|
||||
Ephemeral1hInputTokens: oaiUsage.ClaudeCacheCreation1hTokens,
|
||||
}
|
||||
}
|
||||
return usage
|
||||
}
|
||||
|
||||
func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamResponse, info *relaycommon.RelayInfo) []*dto.ClaudeResponse {
|
||||
if info.ClaudeConvertInfo.Done {
|
||||
return nil
|
||||
@@ -391,13 +410,8 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
}
|
||||
if oaiUsage != nil {
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Type: "message_delta",
|
||||
Usage: &dto.ClaudeUsage{
|
||||
InputTokens: oaiUsage.PromptTokens,
|
||||
OutputTokens: oaiUsage.CompletionTokens,
|
||||
CacheCreationInputTokens: oaiUsage.PromptTokensDetails.CachedCreationTokens,
|
||||
CacheReadInputTokens: oaiUsage.PromptTokensDetails.CachedTokens,
|
||||
},
|
||||
Type: "message_delta",
|
||||
Usage: buildClaudeUsageFromOpenAIUsage(oaiUsage),
|
||||
Delta: &dto.ClaudeMediaMessage{
|
||||
StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(info.FinishReason)),
|
||||
},
|
||||
@@ -419,13 +433,8 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
oaiUsage := info.ClaudeConvertInfo.Usage
|
||||
if oaiUsage != nil {
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Type: "message_delta",
|
||||
Usage: &dto.ClaudeUsage{
|
||||
InputTokens: oaiUsage.PromptTokens,
|
||||
OutputTokens: oaiUsage.CompletionTokens,
|
||||
CacheCreationInputTokens: oaiUsage.PromptTokensDetails.CachedCreationTokens,
|
||||
CacheReadInputTokens: oaiUsage.PromptTokensDetails.CachedTokens,
|
||||
},
|
||||
Type: "message_delta",
|
||||
Usage: buildClaudeUsageFromOpenAIUsage(oaiUsage),
|
||||
Delta: &dto.ClaudeMediaMessage{
|
||||
StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(info.FinishReason)),
|
||||
},
|
||||
@@ -555,13 +564,8 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
}
|
||||
if oaiUsage != nil {
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Type: "message_delta",
|
||||
Usage: &dto.ClaudeUsage{
|
||||
InputTokens: oaiUsage.PromptTokens,
|
||||
OutputTokens: oaiUsage.CompletionTokens,
|
||||
CacheCreationInputTokens: oaiUsage.PromptTokensDetails.CachedCreationTokens,
|
||||
CacheReadInputTokens: oaiUsage.PromptTokensDetails.CachedTokens,
|
||||
},
|
||||
Type: "message_delta",
|
||||
Usage: buildClaudeUsageFromOpenAIUsage(oaiUsage),
|
||||
Delta: &dto.ClaudeMediaMessage{
|
||||
StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(info.FinishReason)),
|
||||
},
|
||||
@@ -612,10 +616,7 @@ func ResponseOpenAI2Claude(openAIResponse *dto.OpenAITextResponse, info *relayco
|
||||
}
|
||||
claudeResponse.Content = contents
|
||||
claudeResponse.StopReason = stopReason
|
||||
claudeResponse.Usage = &dto.ClaudeUsage{
|
||||
InputTokens: openAIResponse.PromptTokens,
|
||||
OutputTokens: openAIResponse.CompletionTokens,
|
||||
}
|
||||
claudeResponse.Usage = buildClaudeUsageFromOpenAIUsage(&openAIResponse.Usage)
|
||||
|
||||
return claudeResponse
|
||||
}
|
||||
|
||||
@@ -104,6 +104,11 @@ func GetFileTypeFromUrl(c *gin.Context, url string, reason ...string) (string, e
|
||||
return sniffed, nil
|
||||
}
|
||||
|
||||
// Try HEIF/HEIC detection (Go standard library doesn't recognize it)
|
||||
if heifMime := detectHEIF(readData); heifMime != "" {
|
||||
return heifMime, nil
|
||||
}
|
||||
|
||||
if _, format, err := image.DecodeConfig(bytes.NewReader(readData)); err == nil {
|
||||
switch strings.ToLower(format) {
|
||||
case "jpeg", "jpg":
|
||||
@@ -168,6 +173,10 @@ func GetMimeTypeByExtension(ext string) string {
|
||||
return "image/gif"
|
||||
case "jfif":
|
||||
return "image/jpeg"
|
||||
case "heic":
|
||||
return "image/heic"
|
||||
case "heif":
|
||||
return "image/heif"
|
||||
|
||||
// Audio files
|
||||
case "mp3":
|
||||
|
||||
@@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"image"
|
||||
_ "image/gif"
|
||||
@@ -275,6 +276,11 @@ func smartDetectMimeType(resp *http.Response, url string, fileBytes []byte) stri
|
||||
}
|
||||
return sniffed
|
||||
}
|
||||
|
||||
// 4.5 尝试 HEIF/HEIC 检测(Go 标准库不识别)
|
||||
if heifMime := detectHEIF(fileBytes); heifMime != "" {
|
||||
return heifMime
|
||||
}
|
||||
}
|
||||
|
||||
// 5. 尝试作为图片解码获取格式
|
||||
@@ -449,9 +455,118 @@ func decodeImageConfig(data []byte) (image.Config, string, error) {
|
||||
return config, "webp", nil
|
||||
}
|
||||
|
||||
// Try HEIF/HEIC: parse ISOBMFF ispe box for dimensions
|
||||
if heifMime := detectHEIF(data); heifMime != "" {
|
||||
formatName := "heif"
|
||||
if heifMime == "image/heic" {
|
||||
formatName = "heic"
|
||||
}
|
||||
if w, h, ok := parseHEIFDimensions(data); ok {
|
||||
return image.Config{Width: w, Height: h}, formatName, nil
|
||||
}
|
||||
return image.Config{}, "", fmt.Errorf("failed to decode HEIF/HEIC image dimensions")
|
||||
}
|
||||
|
||||
return image.Config{}, "", fmt.Errorf("failed to decode image config: unsupported format")
|
||||
}
|
||||
|
||||
// detectHEIF checks ISOBMFF magic bytes to detect HEIC/HEIF files.
|
||||
// Returns "image/heic", "image/heif", or "" if not recognized.
|
||||
func detectHEIF(data []byte) string {
|
||||
if len(data) < 12 {
|
||||
return ""
|
||||
}
|
||||
// ISOBMFF: bytes[4:8] must be "ftyp"
|
||||
if string(data[4:8]) != "ftyp" {
|
||||
return ""
|
||||
}
|
||||
brand := string(data[8:12])
|
||||
switch brand {
|
||||
case "heic", "heix", "hevc", "hevx", "heim", "heis":
|
||||
return "image/heic"
|
||||
case "mif1", "msf1":
|
||||
return "image/heif"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// parseHEIFDimensions parses ISOBMFF box tree to find the ispe box
|
||||
// and extract image width/height. Returns (width, height, ok).
|
||||
func parseHEIFDimensions(data []byte) (int, int, bool) {
|
||||
size := len(data)
|
||||
if size < 12 {
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
// Walk top-level boxes to find "meta"
|
||||
offset := 0
|
||||
for offset+8 <= size {
|
||||
boxSize := int(binary.BigEndian.Uint32(data[offset : offset+4]))
|
||||
boxType := string(data[offset+4 : offset+8])
|
||||
headerLen := 8
|
||||
|
||||
if boxSize == 1 {
|
||||
// 64-bit extended size
|
||||
if offset+16 > size {
|
||||
break
|
||||
}
|
||||
boxSize = int(binary.BigEndian.Uint64(data[offset+8 : offset+16]))
|
||||
headerLen = 16
|
||||
} else if boxSize == 0 {
|
||||
// box extends to end of data
|
||||
boxSize = size - offset
|
||||
}
|
||||
|
||||
if boxSize < headerLen || offset+boxSize > size {
|
||||
break
|
||||
}
|
||||
|
||||
if boxType == "meta" {
|
||||
// meta is a full box: 4 bytes version/flags after header
|
||||
metaData := data[offset+headerLen : offset+boxSize]
|
||||
if len(metaData) < 4 {
|
||||
return 0, 0, false
|
||||
}
|
||||
return findISPE(metaData[4:])
|
||||
}
|
||||
offset += boxSize
|
||||
}
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
// findISPE recursively searches for the ispe box within container boxes.
|
||||
// Path: meta -> iprp -> ipco -> ispe
|
||||
func findISPE(data []byte) (int, int, bool) {
|
||||
offset := 0
|
||||
size := len(data)
|
||||
for offset+8 <= size {
|
||||
boxSize := int(binary.BigEndian.Uint32(data[offset : offset+4]))
|
||||
boxType := string(data[offset+4 : offset+8])
|
||||
if boxSize < 8 || offset+boxSize > size {
|
||||
break
|
||||
}
|
||||
content := data[offset+8 : offset+boxSize]
|
||||
switch boxType {
|
||||
case "iprp", "ipco":
|
||||
if w, h, ok := findISPE(content); ok {
|
||||
return w, h, true
|
||||
}
|
||||
case "ispe":
|
||||
// ispe is a full box: 4 bytes version/flags, then 4 bytes width, 4 bytes height
|
||||
if len(content) >= 12 {
|
||||
w := int(binary.BigEndian.Uint32(content[4:8]))
|
||||
h := int(binary.BigEndian.Uint32(content[8:12]))
|
||||
if w > 0 && h > 0 {
|
||||
return w, h, true
|
||||
}
|
||||
}
|
||||
}
|
||||
offset += boxSize
|
||||
}
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
// guessMimeTypeFromURL 从 URL 猜测 MIME 类型
|
||||
func guessMimeTypeFromURL(url string) string {
|
||||
cleanedURL := url
|
||||
|
||||
+29
-13
@@ -159,20 +159,36 @@ func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
|
||||
}
|
||||
|
||||
func getImageConfig(reader io.Reader) (image.Config, string, error) {
|
||||
// Read all data so we can retry with different decoders
|
||||
data, readErr := io.ReadAll(reader)
|
||||
if readErr != nil {
|
||||
return image.Config{}, "", fmt.Errorf("failed to read image data: %w", readErr)
|
||||
}
|
||||
|
||||
// 读取图片的头部信息来获取图片尺寸
|
||||
config, format, err := image.DecodeConfig(reader)
|
||||
if err != nil {
|
||||
err = errors.New(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error()))
|
||||
common.SysLog(err.Error())
|
||||
config, err = webp.DecodeConfig(reader)
|
||||
if err != nil {
|
||||
err = errors.New(fmt.Sprintf("fail to decode image config(webp): %s", err.Error()))
|
||||
common.SysLog(err.Error())
|
||||
config, format, err := image.DecodeConfig(bytes.NewReader(data))
|
||||
if err == nil {
|
||||
return config, format, nil
|
||||
}
|
||||
common.SysLog(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error()))
|
||||
|
||||
config, err = webp.DecodeConfig(bytes.NewReader(data))
|
||||
if err == nil {
|
||||
return config, "webp", nil
|
||||
}
|
||||
common.SysLog(fmt.Sprintf("fail to decode image config(webp): %s", err.Error()))
|
||||
|
||||
// Try HEIF/HEIC: parse ISOBMFF ispe box for dimensions
|
||||
if heifMime := detectHEIF(data); heifMime != "" {
|
||||
formatName := "heif"
|
||||
if heifMime == "image/heic" {
|
||||
formatName = "heic"
|
||||
}
|
||||
format = "webp"
|
||||
if w, h, ok := parseHEIFDimensions(data); ok {
|
||||
return image.Config{Width: w, Height: h}, formatName, nil
|
||||
}
|
||||
return image.Config{}, "", fmt.Errorf("failed to decode HEIF/HEIC image dimensions")
|
||||
}
|
||||
if err != nil {
|
||||
return image.Config{}, "", err
|
||||
}
|
||||
return config, format, nil
|
||||
|
||||
return image.Config{}, "", err
|
||||
}
|
||||
|
||||
@@ -73,8 +73,10 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m
|
||||
other["admin_info"] = adminInfo
|
||||
appendRequestPath(ctx, relayInfo, other)
|
||||
appendRequestConversionChain(relayInfo, other)
|
||||
appendFinalRequestFormat(relayInfo, other)
|
||||
appendBillingInfo(relayInfo, other)
|
||||
appendParamOverrideInfo(relayInfo, other)
|
||||
appendStreamStatus(relayInfo, other)
|
||||
return other
|
||||
}
|
||||
|
||||
@@ -85,6 +87,33 @@ func appendParamOverrideInfo(relayInfo *relaycommon.RelayInfo, other map[string]
|
||||
other["po"] = relayInfo.ParamOverrideAudit
|
||||
}
|
||||
|
||||
func appendStreamStatus(relayInfo *relaycommon.RelayInfo, other map[string]interface{}) {
|
||||
if relayInfo == nil || other == nil || !relayInfo.IsStream || relayInfo.StreamStatus == nil {
|
||||
return
|
||||
}
|
||||
ss := relayInfo.StreamStatus
|
||||
status := "ok"
|
||||
if !ss.IsNormalEnd() || ss.HasErrors() {
|
||||
status = "error"
|
||||
}
|
||||
streamInfo := map[string]interface{}{
|
||||
"status": status,
|
||||
"end_reason": string(ss.EndReason),
|
||||
}
|
||||
if ss.EndError != nil {
|
||||
streamInfo["end_error"] = ss.EndError.Error()
|
||||
}
|
||||
if ss.ErrorCount > 0 {
|
||||
streamInfo["error_count"] = ss.ErrorCount
|
||||
messages := make([]string, 0, len(ss.Errors))
|
||||
for _, e := range ss.Errors {
|
||||
messages = append(messages, e.Message)
|
||||
}
|
||||
streamInfo["errors"] = messages
|
||||
}
|
||||
other["stream_status"] = streamInfo
|
||||
}
|
||||
|
||||
func appendBillingInfo(relayInfo *relaycommon.RelayInfo, other map[string]interface{}) {
|
||||
if relayInfo == nil || other == nil {
|
||||
return
|
||||
@@ -167,6 +196,17 @@ func appendRequestConversionChain(relayInfo *relaycommon.RelayInfo, other map[st
|
||||
other["request_conversion"] = chain
|
||||
}
|
||||
|
||||
func appendFinalRequestFormat(relayInfo *relaycommon.RelayInfo, other map[string]interface{}) {
|
||||
if relayInfo == nil || other == nil {
|
||||
return
|
||||
}
|
||||
if relayInfo.GetFinalRequestRelayFormat() == types.RelayFormatClaude {
|
||||
// claude indicates the final upstream request format is Claude Messages.
|
||||
// Frontend log rendering uses this to keep the original Claude input display.
|
||||
other["claude"] = true
|
||||
}
|
||||
}
|
||||
|
||||
func GenerateWssOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice, userGroupRatio float64) map[string]interface{} {
|
||||
info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, 0, 0.0, modelPrice, userGroupRatio)
|
||||
info["ws"] = true
|
||||
|
||||
@@ -235,108 +235,6 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
|
||||
})
|
||||
}
|
||||
|
||||
func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage) {
|
||||
if usage != nil {
|
||||
ObserveChannelAffinityUsageCacheByRelayFormat(ctx, usage, relayInfo.GetFinalRequestRelayFormat())
|
||||
}
|
||||
|
||||
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
|
||||
promptTokens := usage.PromptTokens
|
||||
completionTokens := usage.CompletionTokens
|
||||
modelName := relayInfo.OriginModelName
|
||||
|
||||
tokenName := ctx.GetString("token_name")
|
||||
completionRatio := relayInfo.PriceData.CompletionRatio
|
||||
modelRatio := relayInfo.PriceData.ModelRatio
|
||||
groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
|
||||
modelPrice := relayInfo.PriceData.ModelPrice
|
||||
cacheRatio := relayInfo.PriceData.CacheRatio
|
||||
cacheTokens := usage.PromptTokensDetails.CachedTokens
|
||||
|
||||
cacheCreationRatio := relayInfo.PriceData.CacheCreationRatio
|
||||
cacheCreationRatio5m := relayInfo.PriceData.CacheCreation5mRatio
|
||||
cacheCreationRatio1h := relayInfo.PriceData.CacheCreation1hRatio
|
||||
cacheCreationTokens := usage.PromptTokensDetails.CachedCreationTokens
|
||||
cacheCreationTokens5m := usage.ClaudeCacheCreation5mTokens
|
||||
cacheCreationTokens1h := usage.ClaudeCacheCreation1hTokens
|
||||
|
||||
if relayInfo.ChannelType == constant.ChannelTypeOpenRouter {
|
||||
promptTokens -= cacheTokens
|
||||
isUsingCustomSettings := relayInfo.PriceData.UsePrice || hasCustomModelRatio(modelName, relayInfo.PriceData.ModelRatio)
|
||||
if cacheCreationTokens == 0 && relayInfo.PriceData.CacheCreationRatio != 1 && usage.Cost != 0 && !isUsingCustomSettings {
|
||||
maybeCacheCreationTokens := CalcOpenRouterCacheCreateTokens(*usage, relayInfo.PriceData)
|
||||
if maybeCacheCreationTokens >= 0 && promptTokens >= maybeCacheCreationTokens {
|
||||
cacheCreationTokens = maybeCacheCreationTokens
|
||||
}
|
||||
}
|
||||
promptTokens -= cacheCreationTokens
|
||||
}
|
||||
|
||||
calculateQuota := 0.0
|
||||
if !relayInfo.PriceData.UsePrice {
|
||||
calculateQuota = float64(promptTokens)
|
||||
calculateQuota += float64(cacheTokens) * cacheRatio
|
||||
calculateQuota += float64(cacheCreationTokens5m) * cacheCreationRatio5m
|
||||
calculateQuota += float64(cacheCreationTokens1h) * cacheCreationRatio1h
|
||||
remainingCacheCreationTokens := cacheCreationTokens - cacheCreationTokens5m - cacheCreationTokens1h
|
||||
if remainingCacheCreationTokens > 0 {
|
||||
calculateQuota += float64(remainingCacheCreationTokens) * cacheCreationRatio
|
||||
}
|
||||
calculateQuota += float64(completionTokens) * completionRatio
|
||||
calculateQuota = calculateQuota * groupRatio * modelRatio
|
||||
} else {
|
||||
calculateQuota = modelPrice * common.QuotaPerUnit * groupRatio
|
||||
}
|
||||
|
||||
if modelRatio != 0 && calculateQuota <= 0 {
|
||||
calculateQuota = 1
|
||||
}
|
||||
|
||||
quota := int(calculateQuota)
|
||||
|
||||
totalTokens := promptTokens + completionTokens
|
||||
|
||||
var logContent string
|
||||
// record all the consume log even if quota is 0
|
||||
if totalTokens == 0 {
|
||||
// in this case, must be some error happened
|
||||
// we cannot just return, because we may have to return the pre-consumed quota
|
||||
quota = 0
|
||||
logContent += fmt.Sprintf("(可能是上游出错)")
|
||||
logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
|
||||
"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota))
|
||||
} else {
|
||||
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
|
||||
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
|
||||
}
|
||||
|
||||
if err := SettleBilling(ctx, relayInfo, quota); err != nil {
|
||||
logger.LogError(ctx, "error settling billing: "+err.Error())
|
||||
}
|
||||
|
||||
other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio,
|
||||
cacheTokens, cacheRatio,
|
||||
cacheCreationTokens, cacheCreationRatio,
|
||||
cacheCreationTokens5m, cacheCreationRatio5m,
|
||||
cacheCreationTokens1h, cacheCreationRatio1h,
|
||||
modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
|
||||
ChannelId: relayInfo.ChannelId,
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
ModelName: modelName,
|
||||
TokenName: tokenName,
|
||||
Quota: quota,
|
||||
Content: logContent,
|
||||
TokenId: relayInfo.TokenId,
|
||||
UseTimeSeconds: int(useTimeSeconds),
|
||||
IsStream: relayInfo.IsStream,
|
||||
Group: relayInfo.UsingGroup,
|
||||
Other: other,
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData types.PriceData) int {
|
||||
if priceData.CacheCreationRatio == 1 {
|
||||
return 0
|
||||
|
||||
@@ -0,0 +1,430 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/shopspring/decimal"
|
||||
)
|
||||
|
||||
type textQuotaSummary struct {
|
||||
PromptTokens int
|
||||
CompletionTokens int
|
||||
TotalTokens int
|
||||
CacheTokens int
|
||||
CacheCreationTokens int
|
||||
CacheCreationTokens5m int
|
||||
CacheCreationTokens1h int
|
||||
ImageTokens int
|
||||
AudioTokens int
|
||||
ModelName string
|
||||
TokenName string
|
||||
UseTimeSeconds int64
|
||||
CompletionRatio float64
|
||||
CacheRatio float64
|
||||
ImageRatio float64
|
||||
ModelRatio float64
|
||||
GroupRatio float64
|
||||
ModelPrice float64
|
||||
CacheCreationRatio float64
|
||||
CacheCreationRatio5m float64
|
||||
CacheCreationRatio1h float64
|
||||
Quota int
|
||||
IsClaudeUsageSemantic bool
|
||||
UsageSemantic string
|
||||
WebSearchPrice float64
|
||||
WebSearchCallCount int
|
||||
ClaudeWebSearchPrice float64
|
||||
ClaudeWebSearchCallCount int
|
||||
FileSearchPrice float64
|
||||
FileSearchCallCount int
|
||||
AudioInputPrice float64
|
||||
ImageGenerationCallPrice float64
|
||||
}
|
||||
|
||||
func cacheWriteTokensTotal(summary textQuotaSummary) int {
|
||||
if summary.CacheCreationTokens5m > 0 || summary.CacheCreationTokens1h > 0 {
|
||||
splitCacheWriteTokens := summary.CacheCreationTokens5m + summary.CacheCreationTokens1h
|
||||
if summary.CacheCreationTokens > splitCacheWriteTokens {
|
||||
return summary.CacheCreationTokens
|
||||
}
|
||||
return splitCacheWriteTokens
|
||||
}
|
||||
return summary.CacheCreationTokens
|
||||
}
|
||||
|
||||
func isLegacyClaudeDerivedOpenAIUsage(relayInfo *relaycommon.RelayInfo, usage *dto.Usage) bool {
|
||||
if relayInfo == nil || usage == nil {
|
||||
return false
|
||||
}
|
||||
if relayInfo.GetFinalRequestRelayFormat() == types.RelayFormatClaude {
|
||||
return false
|
||||
}
|
||||
if usage.UsageSource != "" || usage.UsageSemantic != "" {
|
||||
return false
|
||||
}
|
||||
return usage.ClaudeCacheCreation5mTokens > 0 || usage.ClaudeCacheCreation1hTokens > 0
|
||||
}
|
||||
|
||||
func calculateTextQuotaSummary(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage) textQuotaSummary {
|
||||
summary := textQuotaSummary{
|
||||
ModelName: relayInfo.OriginModelName,
|
||||
TokenName: ctx.GetString("token_name"),
|
||||
UseTimeSeconds: time.Now().Unix() - relayInfo.StartTime.Unix(),
|
||||
CompletionRatio: relayInfo.PriceData.CompletionRatio,
|
||||
CacheRatio: relayInfo.PriceData.CacheRatio,
|
||||
ImageRatio: relayInfo.PriceData.ImageRatio,
|
||||
ModelRatio: relayInfo.PriceData.ModelRatio,
|
||||
GroupRatio: relayInfo.PriceData.GroupRatioInfo.GroupRatio,
|
||||
ModelPrice: relayInfo.PriceData.ModelPrice,
|
||||
CacheCreationRatio: relayInfo.PriceData.CacheCreationRatio,
|
||||
CacheCreationRatio5m: relayInfo.PriceData.CacheCreation5mRatio,
|
||||
CacheCreationRatio1h: relayInfo.PriceData.CacheCreation1hRatio,
|
||||
UsageSemantic: usageSemanticFromUsage(relayInfo, usage),
|
||||
}
|
||||
summary.IsClaudeUsageSemantic = summary.UsageSemantic == "anthropic"
|
||||
|
||||
if usage == nil {
|
||||
usage = &dto.Usage{
|
||||
PromptTokens: relayInfo.GetEstimatePromptTokens(),
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: relayInfo.GetEstimatePromptTokens(),
|
||||
}
|
||||
}
|
||||
|
||||
summary.PromptTokens = usage.PromptTokens
|
||||
summary.CompletionTokens = usage.CompletionTokens
|
||||
summary.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
summary.CacheTokens = usage.PromptTokensDetails.CachedTokens
|
||||
summary.CacheCreationTokens = usage.PromptTokensDetails.CachedCreationTokens
|
||||
summary.CacheCreationTokens5m = usage.ClaudeCacheCreation5mTokens
|
||||
summary.CacheCreationTokens1h = usage.ClaudeCacheCreation1hTokens
|
||||
summary.ImageTokens = usage.PromptTokensDetails.ImageTokens
|
||||
summary.AudioTokens = usage.PromptTokensDetails.AudioTokens
|
||||
legacyClaudeDerived := isLegacyClaudeDerivedOpenAIUsage(relayInfo, usage)
|
||||
isOpenRouterClaudeBilling := relayInfo.ChannelMeta != nil &&
|
||||
relayInfo.ChannelType == constant.ChannelTypeOpenRouter &&
|
||||
summary.IsClaudeUsageSemantic
|
||||
|
||||
if isOpenRouterClaudeBilling {
|
||||
summary.PromptTokens -= summary.CacheTokens
|
||||
isUsingCustomSettings := relayInfo.PriceData.UsePrice || hasCustomModelRatio(summary.ModelName, relayInfo.PriceData.ModelRatio)
|
||||
if summary.CacheCreationTokens == 0 && relayInfo.PriceData.CacheCreationRatio != 1 && usage.Cost != 0 && !isUsingCustomSettings {
|
||||
maybeCacheCreationTokens := CalcOpenRouterCacheCreateTokens(*usage, relayInfo.PriceData)
|
||||
if maybeCacheCreationTokens >= 0 && summary.PromptTokens >= maybeCacheCreationTokens {
|
||||
summary.CacheCreationTokens = maybeCacheCreationTokens
|
||||
}
|
||||
}
|
||||
summary.PromptTokens -= summary.CacheCreationTokens
|
||||
}
|
||||
|
||||
dPromptTokens := decimal.NewFromInt(int64(summary.PromptTokens))
|
||||
dCacheTokens := decimal.NewFromInt(int64(summary.CacheTokens))
|
||||
dImageTokens := decimal.NewFromInt(int64(summary.ImageTokens))
|
||||
dAudioTokens := decimal.NewFromInt(int64(summary.AudioTokens))
|
||||
dCompletionTokens := decimal.NewFromInt(int64(summary.CompletionTokens))
|
||||
dCachedCreationTokens := decimal.NewFromInt(int64(summary.CacheCreationTokens))
|
||||
dCompletionRatio := decimal.NewFromFloat(summary.CompletionRatio)
|
||||
dCacheRatio := decimal.NewFromFloat(summary.CacheRatio)
|
||||
dImageRatio := decimal.NewFromFloat(summary.ImageRatio)
|
||||
dModelRatio := decimal.NewFromFloat(summary.ModelRatio)
|
||||
dGroupRatio := decimal.NewFromFloat(summary.GroupRatio)
|
||||
dModelPrice := decimal.NewFromFloat(summary.ModelPrice)
|
||||
dCacheCreationRatio := decimal.NewFromFloat(summary.CacheCreationRatio)
|
||||
dCacheCreationRatio5m := decimal.NewFromFloat(summary.CacheCreationRatio5m)
|
||||
dCacheCreationRatio1h := decimal.NewFromFloat(summary.CacheCreationRatio1h)
|
||||
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
|
||||
|
||||
ratio := dModelRatio.Mul(dGroupRatio)
|
||||
|
||||
var dWebSearchQuota decimal.Decimal
|
||||
if relayInfo.ResponsesUsageInfo != nil {
|
||||
if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool.CallCount > 0 {
|
||||
summary.WebSearchCallCount = webSearchTool.CallCount
|
||||
summary.WebSearchPrice = operation_setting.GetWebSearchPricePerThousand(summary.ModelName, webSearchTool.SearchContextSize)
|
||||
dWebSearchQuota = decimal.NewFromFloat(summary.WebSearchPrice).
|
||||
Mul(decimal.NewFromInt(int64(webSearchTool.CallCount))).
|
||||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
}
|
||||
} else if strings.HasSuffix(summary.ModelName, "search-preview") {
|
||||
searchContextSize := ctx.GetString("chat_completion_web_search_context_size")
|
||||
if searchContextSize == "" {
|
||||
searchContextSize = "medium"
|
||||
}
|
||||
summary.WebSearchCallCount = 1
|
||||
summary.WebSearchPrice = operation_setting.GetWebSearchPricePerThousand(summary.ModelName, searchContextSize)
|
||||
dWebSearchQuota = decimal.NewFromFloat(summary.WebSearchPrice).
|
||||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
}
|
||||
|
||||
var dClaudeWebSearchQuota decimal.Decimal
|
||||
summary.ClaudeWebSearchCallCount = ctx.GetInt("claude_web_search_requests")
|
||||
if summary.ClaudeWebSearchCallCount > 0 {
|
||||
summary.ClaudeWebSearchPrice = operation_setting.GetClaudeWebSearchPricePerThousand()
|
||||
dClaudeWebSearchQuota = decimal.NewFromFloat(summary.ClaudeWebSearchPrice).
|
||||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit).
|
||||
Mul(decimal.NewFromInt(int64(summary.ClaudeWebSearchCallCount)))
|
||||
}
|
||||
|
||||
var dFileSearchQuota decimal.Decimal
|
||||
if relayInfo.ResponsesUsageInfo != nil {
|
||||
if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists && fileSearchTool.CallCount > 0 {
|
||||
summary.FileSearchCallCount = fileSearchTool.CallCount
|
||||
summary.FileSearchPrice = operation_setting.GetFileSearchPricePerThousand()
|
||||
dFileSearchQuota = decimal.NewFromFloat(summary.FileSearchPrice).
|
||||
Mul(decimal.NewFromInt(int64(fileSearchTool.CallCount))).
|
||||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
}
|
||||
}
|
||||
|
||||
var dImageGenerationCallQuota decimal.Decimal
|
||||
if ctx.GetBool("image_generation_call") {
|
||||
summary.ImageGenerationCallPrice = operation_setting.GetGPTImage1PriceOnceCall(ctx.GetString("image_generation_call_quality"), ctx.GetString("image_generation_call_size"))
|
||||
dImageGenerationCallQuota = decimal.NewFromFloat(summary.ImageGenerationCallPrice).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
}
|
||||
|
||||
var audioInputQuota decimal.Decimal
|
||||
if !relayInfo.PriceData.UsePrice {
|
||||
baseTokens := dPromptTokens
|
||||
|
||||
var cachedTokensWithRatio decimal.Decimal
|
||||
if !dCacheTokens.IsZero() {
|
||||
if !summary.IsClaudeUsageSemantic && !legacyClaudeDerived {
|
||||
baseTokens = baseTokens.Sub(dCacheTokens)
|
||||
}
|
||||
cachedTokensWithRatio = dCacheTokens.Mul(dCacheRatio)
|
||||
}
|
||||
|
||||
var cachedCreationTokensWithRatio decimal.Decimal
|
||||
hasSplitCacheCreationTokens := summary.CacheCreationTokens5m > 0 || summary.CacheCreationTokens1h > 0
|
||||
if !dCachedCreationTokens.IsZero() || hasSplitCacheCreationTokens {
|
||||
if !summary.IsClaudeUsageSemantic && !legacyClaudeDerived {
|
||||
baseTokens = baseTokens.Sub(dCachedCreationTokens)
|
||||
cachedCreationTokensWithRatio = dCachedCreationTokens.Mul(dCacheCreationRatio)
|
||||
} else {
|
||||
remaining := summary.CacheCreationTokens - summary.CacheCreationTokens5m - summary.CacheCreationTokens1h
|
||||
if remaining < 0 {
|
||||
remaining = 0
|
||||
}
|
||||
cachedCreationTokensWithRatio = decimal.NewFromInt(int64(remaining)).Mul(dCacheCreationRatio)
|
||||
cachedCreationTokensWithRatio = cachedCreationTokensWithRatio.Add(decimal.NewFromInt(int64(summary.CacheCreationTokens5m)).Mul(dCacheCreationRatio5m))
|
||||
cachedCreationTokensWithRatio = cachedCreationTokensWithRatio.Add(decimal.NewFromInt(int64(summary.CacheCreationTokens1h)).Mul(dCacheCreationRatio1h))
|
||||
}
|
||||
}
|
||||
|
||||
var imageTokensWithRatio decimal.Decimal
|
||||
if !dImageTokens.IsZero() {
|
||||
baseTokens = baseTokens.Sub(dImageTokens)
|
||||
imageTokensWithRatio = dImageTokens.Mul(dImageRatio)
|
||||
}
|
||||
|
||||
if !dAudioTokens.IsZero() {
|
||||
summary.AudioInputPrice = operation_setting.GetGeminiInputAudioPricePerMillionTokens(summary.ModelName)
|
||||
if summary.AudioInputPrice > 0 {
|
||||
baseTokens = baseTokens.Sub(dAudioTokens)
|
||||
audioInputQuota = decimal.NewFromFloat(summary.AudioInputPrice).
|
||||
Div(decimal.NewFromInt(1000000)).Mul(dAudioTokens).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
}
|
||||
}
|
||||
|
||||
promptQuota := baseTokens.Add(cachedTokensWithRatio).Add(imageTokensWithRatio).Add(cachedCreationTokensWithRatio)
|
||||
completionQuota := dCompletionTokens.Mul(dCompletionRatio)
|
||||
quotaCalculateDecimal := promptQuota.Add(completionQuota).Mul(ratio)
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dWebSearchQuota)
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dClaudeWebSearchQuota)
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota)
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota)
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dImageGenerationCallQuota)
|
||||
|
||||
if len(relayInfo.PriceData.OtherRatios) > 0 {
|
||||
for _, otherRatio := range relayInfo.PriceData.OtherRatios {
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Mul(decimal.NewFromFloat(otherRatio))
|
||||
}
|
||||
}
|
||||
|
||||
if !ratio.IsZero() && quotaCalculateDecimal.LessThanOrEqual(decimal.Zero) {
|
||||
quotaCalculateDecimal = decimal.NewFromInt(1)
|
||||
}
|
||||
summary.Quota = int(quotaCalculateDecimal.Round(0).IntPart())
|
||||
} else {
|
||||
quotaCalculateDecimal := dModelPrice.Mul(dQuotaPerUnit).Mul(dGroupRatio)
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dWebSearchQuota)
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dClaudeWebSearchQuota)
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota)
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota)
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dImageGenerationCallQuota)
|
||||
if len(relayInfo.PriceData.OtherRatios) > 0 {
|
||||
for _, otherRatio := range relayInfo.PriceData.OtherRatios {
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Mul(decimal.NewFromFloat(otherRatio))
|
||||
}
|
||||
}
|
||||
summary.Quota = int(quotaCalculateDecimal.Round(0).IntPart())
|
||||
}
|
||||
|
||||
if summary.TotalTokens == 0 {
|
||||
summary.Quota = 0
|
||||
} else if !ratio.IsZero() && summary.Quota == 0 {
|
||||
summary.Quota = 1
|
||||
}
|
||||
|
||||
return summary
|
||||
}
|
||||
|
||||
func usageSemanticFromUsage(relayInfo *relaycommon.RelayInfo, usage *dto.Usage) string {
|
||||
if usage != nil && usage.UsageSemantic != "" {
|
||||
return usage.UsageSemantic
|
||||
}
|
||||
if relayInfo != nil && relayInfo.GetFinalRequestRelayFormat() == types.RelayFormatClaude {
|
||||
return "anthropic"
|
||||
}
|
||||
return "openai"
|
||||
}
|
||||
|
||||
func PostTextConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent []string) {
|
||||
originUsage := usage
|
||||
if usage == nil {
|
||||
extraContent = append(extraContent, "上游无计费信息")
|
||||
}
|
||||
if originUsage != nil {
|
||||
ObserveChannelAffinityUsageCacheByRelayFormat(ctx, usage, relayInfo.GetFinalRequestRelayFormat())
|
||||
}
|
||||
|
||||
adminRejectReason := common.GetContextKeyString(ctx, constant.ContextKeyAdminRejectReason)
|
||||
summary := calculateTextQuotaSummary(ctx, relayInfo, usage)
|
||||
|
||||
if summary.WebSearchCallCount > 0 {
|
||||
extraContent = append(extraContent, fmt.Sprintf("Web Search 调用 %d 次,调用花费 %s", summary.WebSearchCallCount, decimal.NewFromFloat(summary.WebSearchPrice).Mul(decimal.NewFromInt(int64(summary.WebSearchCallCount))).Div(decimal.NewFromInt(1000)).Mul(decimal.NewFromFloat(summary.GroupRatio)).Mul(decimal.NewFromFloat(common.QuotaPerUnit)).String()))
|
||||
}
|
||||
if summary.ClaudeWebSearchCallCount > 0 {
|
||||
extraContent = append(extraContent, fmt.Sprintf("Claude Web Search 调用 %d 次,调用花费 %s", summary.ClaudeWebSearchCallCount, decimal.NewFromFloat(summary.ClaudeWebSearchPrice).Div(decimal.NewFromInt(1000)).Mul(decimal.NewFromFloat(summary.GroupRatio)).Mul(decimal.NewFromFloat(common.QuotaPerUnit)).Mul(decimal.NewFromInt(int64(summary.ClaudeWebSearchCallCount))).String()))
|
||||
}
|
||||
if summary.FileSearchCallCount > 0 {
|
||||
extraContent = append(extraContent, fmt.Sprintf("File Search 调用 %d 次,调用花费 %s", summary.FileSearchCallCount, decimal.NewFromFloat(summary.FileSearchPrice).Mul(decimal.NewFromInt(int64(summary.FileSearchCallCount))).Div(decimal.NewFromInt(1000)).Mul(decimal.NewFromFloat(summary.GroupRatio)).Mul(decimal.NewFromFloat(common.QuotaPerUnit)).String()))
|
||||
}
|
||||
if summary.AudioInputPrice > 0 && summary.AudioTokens > 0 {
|
||||
extraContent = append(extraContent, fmt.Sprintf("Audio Input 花费 %s", decimal.NewFromFloat(summary.AudioInputPrice).Div(decimal.NewFromInt(1000000)).Mul(decimal.NewFromInt(int64(summary.AudioTokens))).Mul(decimal.NewFromFloat(summary.GroupRatio)).Mul(decimal.NewFromFloat(common.QuotaPerUnit)).String()))
|
||||
}
|
||||
if summary.ImageGenerationCallPrice > 0 {
|
||||
extraContent = append(extraContent, fmt.Sprintf("Image Generation Call 花费 %s", decimal.NewFromFloat(summary.ImageGenerationCallPrice).Mul(decimal.NewFromFloat(summary.GroupRatio)).Mul(decimal.NewFromFloat(common.QuotaPerUnit)).String()))
|
||||
}
|
||||
|
||||
if summary.TotalTokens == 0 {
|
||||
extraContent = append(extraContent, "上游没有返回计费信息,无法扣费(可能是上游超时)")
|
||||
logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, summary.ModelName, relayInfo.FinalPreConsumedQuota))
|
||||
} else {
|
||||
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, summary.Quota)
|
||||
model.UpdateChannelUsedQuota(relayInfo.ChannelId, summary.Quota)
|
||||
}
|
||||
|
||||
if err := SettleBilling(ctx, relayInfo, summary.Quota); err != nil {
|
||||
logger.LogError(ctx, "error settling billing: "+err.Error())
|
||||
}
|
||||
|
||||
logModel := summary.ModelName
|
||||
if strings.HasPrefix(logModel, "gpt-4-gizmo") {
|
||||
logModel = "gpt-4-gizmo-*"
|
||||
extraContent = append(extraContent, fmt.Sprintf("模型 %s", summary.ModelName))
|
||||
}
|
||||
if strings.HasPrefix(logModel, "gpt-4o-gizmo") {
|
||||
logModel = "gpt-4o-gizmo-*"
|
||||
extraContent = append(extraContent, fmt.Sprintf("模型 %s", summary.ModelName))
|
||||
}
|
||||
|
||||
logContent := strings.Join(extraContent, ", ")
|
||||
var other map[string]interface{}
|
||||
if summary.IsClaudeUsageSemantic {
|
||||
other = GenerateClaudeOtherInfo(ctx, relayInfo,
|
||||
summary.ModelRatio, summary.GroupRatio, summary.CompletionRatio,
|
||||
summary.CacheTokens, summary.CacheRatio,
|
||||
summary.CacheCreationTokens, summary.CacheCreationRatio,
|
||||
summary.CacheCreationTokens5m, summary.CacheCreationRatio5m,
|
||||
summary.CacheCreationTokens1h, summary.CacheCreationRatio1h,
|
||||
summary.ModelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
other["usage_semantic"] = "anthropic"
|
||||
} else {
|
||||
other = GenerateTextOtherInfo(ctx, relayInfo, summary.ModelRatio, summary.GroupRatio, summary.CompletionRatio, summary.CacheTokens, summary.CacheRatio, summary.ModelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
}
|
||||
if adminRejectReason != "" {
|
||||
other["reject_reason"] = adminRejectReason
|
||||
}
|
||||
if summary.ImageTokens != 0 {
|
||||
other["image"] = true
|
||||
other["image_ratio"] = summary.ImageRatio
|
||||
other["image_output"] = summary.ImageTokens
|
||||
}
|
||||
if summary.WebSearchCallCount > 0 {
|
||||
other["web_search"] = true
|
||||
other["web_search_call_count"] = summary.WebSearchCallCount
|
||||
other["web_search_price"] = summary.WebSearchPrice
|
||||
} else if summary.ClaudeWebSearchCallCount > 0 {
|
||||
other["web_search"] = true
|
||||
other["web_search_call_count"] = summary.ClaudeWebSearchCallCount
|
||||
other["web_search_price"] = summary.ClaudeWebSearchPrice
|
||||
}
|
||||
if summary.FileSearchCallCount > 0 {
|
||||
other["file_search"] = true
|
||||
other["file_search_call_count"] = summary.FileSearchCallCount
|
||||
other["file_search_price"] = summary.FileSearchPrice
|
||||
}
|
||||
if summary.AudioInputPrice > 0 && summary.AudioTokens > 0 {
|
||||
other["audio_input_seperate_price"] = true
|
||||
other["audio_input_token_count"] = summary.AudioTokens
|
||||
other["audio_input_price"] = summary.AudioInputPrice
|
||||
}
|
||||
if summary.ImageGenerationCallPrice > 0 {
|
||||
other["image_generation_call"] = true
|
||||
other["image_generation_call_price"] = summary.ImageGenerationCallPrice
|
||||
}
|
||||
if summary.CacheCreationTokens > 0 {
|
||||
other["cache_creation_tokens"] = summary.CacheCreationTokens
|
||||
other["cache_creation_ratio"] = summary.CacheCreationRatio
|
||||
}
|
||||
if summary.CacheCreationTokens5m > 0 {
|
||||
other["cache_creation_tokens_5m"] = summary.CacheCreationTokens5m
|
||||
other["cache_creation_ratio_5m"] = summary.CacheCreationRatio5m
|
||||
}
|
||||
if summary.CacheCreationTokens1h > 0 {
|
||||
other["cache_creation_tokens_1h"] = summary.CacheCreationTokens1h
|
||||
other["cache_creation_ratio_1h"] = summary.CacheCreationRatio1h
|
||||
}
|
||||
cacheWriteTokens := cacheWriteTokensTotal(summary)
|
||||
if cacheWriteTokens > 0 {
|
||||
// cache_write_tokens: normalized cache creation total for UI display.
|
||||
// If split 5m/1h values are present, this is their sum; otherwise it falls back
|
||||
// to cache_creation_tokens.
|
||||
other["cache_write_tokens"] = cacheWriteTokens
|
||||
}
|
||||
if relayInfo.GetFinalRequestRelayFormat() != types.RelayFormatClaude && usage != nil && usage.UsageSource != "" && usage.InputTokens > 0 {
|
||||
// input_tokens_total: explicit normalized total input used by the usage log UI.
|
||||
// Only write this field when upstream/current conversion has already provided a
|
||||
// reliable total input value and tagged the usage source. Do not infer it from
|
||||
// prompt/cache fields here, otherwise old upstream payloads may be double-counted.
|
||||
other["input_tokens_total"] = usage.InputTokens
|
||||
}
|
||||
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
|
||||
ChannelId: relayInfo.ChannelId,
|
||||
PromptTokens: summary.PromptTokens,
|
||||
CompletionTokens: summary.CompletionTokens,
|
||||
ModelName: logModel,
|
||||
TokenName: summary.TokenName,
|
||||
Quota: summary.Quota,
|
||||
Content: logContent,
|
||||
TokenId: relayInfo.TokenId,
|
||||
UseTimeSeconds: int(summary.UseTimeSeconds),
|
||||
IsStream: relayInfo.IsStream,
|
||||
Group: relayInfo.UsingGroup,
|
||||
Other: other,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,318 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCalculateTextQuotaSummaryUnifiedForClaudeSemantic(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(w)
|
||||
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: 1000,
|
||||
CompletionTokens: 200,
|
||||
PromptTokensDetails: dto.InputTokenDetails{
|
||||
CachedTokens: 100,
|
||||
CachedCreationTokens: 50,
|
||||
},
|
||||
ClaudeCacheCreation5mTokens: 10,
|
||||
ClaudeCacheCreation1hTokens: 20,
|
||||
}
|
||||
|
||||
priceData := types.PriceData{
|
||||
ModelRatio: 1,
|
||||
CompletionRatio: 2,
|
||||
CacheRatio: 0.1,
|
||||
CacheCreationRatio: 1.25,
|
||||
CacheCreation5mRatio: 1.25,
|
||||
CacheCreation1hRatio: 2,
|
||||
GroupRatioInfo: types.GroupRatioInfo{
|
||||
GroupRatio: 1,
|
||||
},
|
||||
}
|
||||
|
||||
chatRelayInfo := &relaycommon.RelayInfo{
|
||||
RelayFormat: types.RelayFormatOpenAI,
|
||||
FinalRequestRelayFormat: types.RelayFormatClaude,
|
||||
OriginModelName: "claude-3-7-sonnet",
|
||||
PriceData: priceData,
|
||||
StartTime: time.Now(),
|
||||
}
|
||||
messageRelayInfo := &relaycommon.RelayInfo{
|
||||
RelayFormat: types.RelayFormatClaude,
|
||||
FinalRequestRelayFormat: types.RelayFormatClaude,
|
||||
OriginModelName: "claude-3-7-sonnet",
|
||||
PriceData: priceData,
|
||||
StartTime: time.Now(),
|
||||
}
|
||||
|
||||
chatSummary := calculateTextQuotaSummary(ctx, chatRelayInfo, usage)
|
||||
messageSummary := calculateTextQuotaSummary(ctx, messageRelayInfo, usage)
|
||||
|
||||
require.Equal(t, messageSummary.Quota, chatSummary.Quota)
|
||||
require.Equal(t, messageSummary.CacheCreationTokens5m, chatSummary.CacheCreationTokens5m)
|
||||
require.Equal(t, messageSummary.CacheCreationTokens1h, chatSummary.CacheCreationTokens1h)
|
||||
require.True(t, chatSummary.IsClaudeUsageSemantic)
|
||||
require.Equal(t, 1488, chatSummary.Quota)
|
||||
}
|
||||
|
||||
func TestCalculateTextQuotaSummaryUsesSplitClaudeCacheCreationRatios(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(w)
|
||||
|
||||
relayInfo := &relaycommon.RelayInfo{
|
||||
RelayFormat: types.RelayFormatOpenAI,
|
||||
FinalRequestRelayFormat: types.RelayFormatClaude,
|
||||
OriginModelName: "claude-3-7-sonnet",
|
||||
PriceData: types.PriceData{
|
||||
ModelRatio: 1,
|
||||
CompletionRatio: 1,
|
||||
CacheRatio: 0,
|
||||
CacheCreationRatio: 1,
|
||||
CacheCreation5mRatio: 2,
|
||||
CacheCreation1hRatio: 3,
|
||||
GroupRatioInfo: types.GroupRatioInfo{
|
||||
GroupRatio: 1,
|
||||
},
|
||||
},
|
||||
StartTime: time.Now(),
|
||||
}
|
||||
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: 100,
|
||||
CompletionTokens: 0,
|
||||
PromptTokensDetails: dto.InputTokenDetails{
|
||||
CachedCreationTokens: 10,
|
||||
},
|
||||
ClaudeCacheCreation5mTokens: 2,
|
||||
ClaudeCacheCreation1hTokens: 3,
|
||||
}
|
||||
|
||||
summary := calculateTextQuotaSummary(ctx, relayInfo, usage)
|
||||
|
||||
// 100 + remaining(5)*1 + 2*2 + 3*3 = 118
|
||||
require.Equal(t, 118, summary.Quota)
|
||||
}
|
||||
|
||||
func TestCalculateTextQuotaSummaryUsesAnthropicUsageSemanticFromUpstreamUsage(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(w)
|
||||
|
||||
relayInfo := &relaycommon.RelayInfo{
|
||||
RelayFormat: types.RelayFormatOpenAI,
|
||||
OriginModelName: "claude-3-7-sonnet",
|
||||
PriceData: types.PriceData{
|
||||
ModelRatio: 1,
|
||||
CompletionRatio: 2,
|
||||
CacheRatio: 0.1,
|
||||
CacheCreationRatio: 1.25,
|
||||
CacheCreation5mRatio: 1.25,
|
||||
CacheCreation1hRatio: 2,
|
||||
GroupRatioInfo: types.GroupRatioInfo{
|
||||
GroupRatio: 1,
|
||||
},
|
||||
},
|
||||
StartTime: time.Now(),
|
||||
}
|
||||
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: 1000,
|
||||
CompletionTokens: 200,
|
||||
UsageSemantic: "anthropic",
|
||||
PromptTokensDetails: dto.InputTokenDetails{
|
||||
CachedTokens: 100,
|
||||
CachedCreationTokens: 50,
|
||||
},
|
||||
ClaudeCacheCreation5mTokens: 10,
|
||||
ClaudeCacheCreation1hTokens: 20,
|
||||
}
|
||||
|
||||
summary := calculateTextQuotaSummary(ctx, relayInfo, usage)
|
||||
|
||||
require.True(t, summary.IsClaudeUsageSemantic)
|
||||
require.Equal(t, "anthropic", summary.UsageSemantic)
|
||||
require.Equal(t, 1488, summary.Quota)
|
||||
}
|
||||
|
||||
func TestCacheWriteTokensTotal(t *testing.T) {
|
||||
t.Run("split cache creation", func(t *testing.T) {
|
||||
summary := textQuotaSummary{
|
||||
CacheCreationTokens: 50,
|
||||
CacheCreationTokens5m: 10,
|
||||
CacheCreationTokens1h: 20,
|
||||
}
|
||||
require.Equal(t, 50, cacheWriteTokensTotal(summary))
|
||||
})
|
||||
|
||||
t.Run("legacy cache creation", func(t *testing.T) {
|
||||
summary := textQuotaSummary{CacheCreationTokens: 50}
|
||||
require.Equal(t, 50, cacheWriteTokensTotal(summary))
|
||||
})
|
||||
|
||||
t.Run("split cache creation without aggregate remainder", func(t *testing.T) {
|
||||
summary := textQuotaSummary{
|
||||
CacheCreationTokens5m: 10,
|
||||
CacheCreationTokens1h: 20,
|
||||
}
|
||||
require.Equal(t, 30, cacheWriteTokensTotal(summary))
|
||||
})
|
||||
}
|
||||
|
||||
func TestCalculateTextQuotaSummaryHandlesLegacyClaudeDerivedOpenAIUsage(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(w)
|
||||
|
||||
relayInfo := &relaycommon.RelayInfo{
|
||||
RelayFormat: types.RelayFormatOpenAI,
|
||||
OriginModelName: "claude-3-7-sonnet",
|
||||
PriceData: types.PriceData{
|
||||
ModelRatio: 1,
|
||||
CompletionRatio: 5,
|
||||
CacheRatio: 0.1,
|
||||
CacheCreationRatio: 1.25,
|
||||
CacheCreation5mRatio: 1.25,
|
||||
CacheCreation1hRatio: 2,
|
||||
GroupRatioInfo: types.GroupRatioInfo{GroupRatio: 1},
|
||||
},
|
||||
StartTime: time.Now(),
|
||||
}
|
||||
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: 62,
|
||||
CompletionTokens: 95,
|
||||
PromptTokensDetails: dto.InputTokenDetails{
|
||||
CachedTokens: 3544,
|
||||
},
|
||||
ClaudeCacheCreation5mTokens: 586,
|
||||
}
|
||||
|
||||
summary := calculateTextQuotaSummary(ctx, relayInfo, usage)
|
||||
|
||||
// 62 + 3544*0.1 + 586*1.25 + 95*5 = 1624.9 => 1624
|
||||
require.Equal(t, 1624, summary.Quota)
|
||||
}
|
||||
|
||||
func TestCalculateTextQuotaSummarySeparatesOpenRouterCacheReadFromPromptBilling(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(w)
|
||||
|
||||
relayInfo := &relaycommon.RelayInfo{
|
||||
OriginModelName: "openai/gpt-4.1",
|
||||
ChannelMeta: &relaycommon.ChannelMeta{
|
||||
ChannelType: constant.ChannelTypeOpenRouter,
|
||||
},
|
||||
PriceData: types.PriceData{
|
||||
ModelRatio: 1,
|
||||
CompletionRatio: 1,
|
||||
CacheRatio: 0.1,
|
||||
CacheCreationRatio: 1.25,
|
||||
GroupRatioInfo: types.GroupRatioInfo{GroupRatio: 1},
|
||||
},
|
||||
StartTime: time.Now(),
|
||||
}
|
||||
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: 2604,
|
||||
CompletionTokens: 383,
|
||||
PromptTokensDetails: dto.InputTokenDetails{
|
||||
CachedTokens: 2432,
|
||||
},
|
||||
}
|
||||
|
||||
summary := calculateTextQuotaSummary(ctx, relayInfo, usage)
|
||||
|
||||
// OpenRouter OpenAI-format display keeps prompt_tokens as total input,
|
||||
// but billing still separates normal input from cache read tokens.
|
||||
// quota = (2604 - 2432) + 2432*0.1 + 383 = 798.2 => 798
|
||||
require.Equal(t, 2604, summary.PromptTokens)
|
||||
require.Equal(t, 798, summary.Quota)
|
||||
}
|
||||
|
||||
func TestCalculateTextQuotaSummarySeparatesOpenRouterCacheCreationFromPromptBilling(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(w)
|
||||
|
||||
relayInfo := &relaycommon.RelayInfo{
|
||||
OriginModelName: "openai/gpt-4.1",
|
||||
ChannelMeta: &relaycommon.ChannelMeta{
|
||||
ChannelType: constant.ChannelTypeOpenRouter,
|
||||
},
|
||||
PriceData: types.PriceData{
|
||||
ModelRatio: 1,
|
||||
CompletionRatio: 1,
|
||||
CacheCreationRatio: 1.25,
|
||||
GroupRatioInfo: types.GroupRatioInfo{GroupRatio: 1},
|
||||
},
|
||||
StartTime: time.Now(),
|
||||
}
|
||||
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: 2604,
|
||||
CompletionTokens: 383,
|
||||
PromptTokensDetails: dto.InputTokenDetails{
|
||||
CachedCreationTokens: 100,
|
||||
},
|
||||
}
|
||||
|
||||
summary := calculateTextQuotaSummary(ctx, relayInfo, usage)
|
||||
|
||||
// prompt_tokens is still logged as total input, but cache creation is billed separately.
|
||||
// quota = (2604 - 100) + 100*1.25 + 383 = 3012
|
||||
require.Equal(t, 2604, summary.PromptTokens)
|
||||
require.Equal(t, 3012, summary.Quota)
|
||||
}
|
||||
|
||||
func TestCalculateTextQuotaSummaryKeepsPrePRClaudeOpenRouterBilling(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(w)
|
||||
|
||||
relayInfo := &relaycommon.RelayInfo{
|
||||
FinalRequestRelayFormat: types.RelayFormatClaude,
|
||||
OriginModelName: "anthropic/claude-3.7-sonnet",
|
||||
ChannelMeta: &relaycommon.ChannelMeta{
|
||||
ChannelType: constant.ChannelTypeOpenRouter,
|
||||
},
|
||||
PriceData: types.PriceData{
|
||||
ModelRatio: 1,
|
||||
CompletionRatio: 1,
|
||||
CacheRatio: 0.1,
|
||||
CacheCreationRatio: 1.25,
|
||||
GroupRatioInfo: types.GroupRatioInfo{GroupRatio: 1},
|
||||
},
|
||||
StartTime: time.Now(),
|
||||
}
|
||||
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: 2604,
|
||||
CompletionTokens: 383,
|
||||
PromptTokensDetails: dto.InputTokenDetails{
|
||||
CachedTokens: 2432,
|
||||
},
|
||||
}
|
||||
|
||||
summary := calculateTextQuotaSummary(ctx, relayInfo, usage)
|
||||
|
||||
// Pre-PR PostClaudeConsumeQuota behavior for OpenRouter:
|
||||
// prompt = 2604 - 2432 = 172
|
||||
// quota = 172 + 2432*0.1 + 383 = 798.2 => 798
|
||||
require.True(t, summary.IsClaudeUsageSemantic)
|
||||
require.Equal(t, 172, summary.PromptTokens)
|
||||
require.Equal(t, 798, summary.Quota)
|
||||
}
|
||||
@@ -17,6 +17,7 @@ var defaultQwenSettings = QwenSettings{
|
||||
"z-image",
|
||||
"qwen-image",
|
||||
"wan2.6",
|
||||
"wan2.7",
|
||||
"qwen-image-edit",
|
||||
"qwen-image-edit-max",
|
||||
"qwen-image-edit-max-2026-01-16",
|
||||
|
||||
@@ -88,7 +88,7 @@ var channelAffinitySetting = ChannelAffinitySetting{
|
||||
ValueRegex: "",
|
||||
TTLSeconds: 0,
|
||||
ParamOverrideTemplate: buildPassHeaderTemplate(codexCliPassThroughHeaders),
|
||||
SkipRetryOnFailure: false,
|
||||
SkipRetryOnFailure: true,
|
||||
IncludeUsingGroup: true,
|
||||
IncludeRuleName: true,
|
||||
UserAgentInclude: nil,
|
||||
@@ -103,7 +103,7 @@ var channelAffinitySetting = ChannelAffinitySetting{
|
||||
ValueRegex: "",
|
||||
TTLSeconds: 0,
|
||||
ParamOverrideTemplate: buildPassHeaderTemplate(claudeCliPassThroughHeaders),
|
||||
SkipRetryOnFailure: false,
|
||||
SkipRetryOnFailure: true,
|
||||
IncludeUsingGroup: true,
|
||||
IncludeRuleName: true,
|
||||
UserAgentInclude: nil,
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
package setting
|
||||
|
||||
import (
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
)
|
||||
|
||||
var (
|
||||
WaffoEnabled bool
|
||||
WaffoApiKey string
|
||||
WaffoPrivateKey string
|
||||
WaffoPublicCert string
|
||||
WaffoSandboxPublicCert string
|
||||
WaffoSandboxApiKey string
|
||||
WaffoSandboxPrivateKey string
|
||||
WaffoSandbox bool
|
||||
WaffoMerchantId string
|
||||
WaffoNotifyUrl string
|
||||
WaffoReturnUrl string
|
||||
WaffoSubscriptionReturnUrl string
|
||||
WaffoCurrency string
|
||||
WaffoUnitPrice float64 = 1.0
|
||||
WaffoMinTopUp int = 1
|
||||
)
|
||||
|
||||
// GetWaffoPayMethods 从 options 读取 Waffo 支付方式配置
|
||||
func GetWaffoPayMethods() []constant.WaffoPayMethod {
|
||||
common.OptionMapRWMutex.RLock()
|
||||
jsonStr := common.OptionMap["WaffoPayMethods"]
|
||||
common.OptionMapRWMutex.RUnlock()
|
||||
|
||||
if jsonStr == "" {
|
||||
return copyDefaultWaffoPayMethods()
|
||||
}
|
||||
var methods []constant.WaffoPayMethod
|
||||
if err := common.UnmarshalJsonStr(jsonStr, &methods); err != nil {
|
||||
return copyDefaultWaffoPayMethods()
|
||||
}
|
||||
return methods
|
||||
}
|
||||
|
||||
// SetWaffoPayMethods 序列化 Waffo 支付方式配置并更新 OptionMap
|
||||
func SetWaffoPayMethods(methods []constant.WaffoPayMethod) error {
|
||||
jsonBytes, err := common.Marshal(methods)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
common.OptionMapRWMutex.Lock()
|
||||
common.OptionMap["WaffoPayMethods"] = string(jsonBytes)
|
||||
common.OptionMapRWMutex.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func copyDefaultWaffoPayMethods() []constant.WaffoPayMethod {
|
||||
cp := make([]constant.WaffoPayMethod, len(constant.DefaultWaffoPayMethods))
|
||||
copy(cp, constant.DefaultWaffoPayMethods)
|
||||
return cp
|
||||
}
|
||||
|
||||
// WaffoPayMethods2JsonString 将默认 WaffoPayMethods 序列化为 JSON 字符串(供 InitOptionMap 使用)
|
||||
func WaffoPayMethods2JsonString() string {
|
||||
jsonBytes, err := common.Marshal(constant.DefaultWaffoPayMethods)
|
||||
if err != nil {
|
||||
return "[]"
|
||||
}
|
||||
return string(jsonBytes)
|
||||
}
|
||||
@@ -36,7 +36,7 @@ var performanceSetting = PerformanceSetting{
|
||||
MonitorEnabled: true,
|
||||
MonitorCPUThreshold: 90,
|
||||
MonitorMemoryThreshold: 90,
|
||||
MonitorDiskThreshold: 90,
|
||||
MonitorDiskThreshold: 95,
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
||||
@@ -510,6 +510,9 @@ func getHardcodedCompletionModelRatio(name string) (float64, bool) {
|
||||
// gpt-5 匹配
|
||||
if strings.HasPrefix(name, "gpt-5") {
|
||||
if strings.HasPrefix(name, "gpt-5.4") {
|
||||
if strings.HasPrefix(name, "gpt-5.4-nano") {
|
||||
return 6.25, true
|
||||
}
|
||||
return 6, true
|
||||
}
|
||||
return 8, true
|
||||
|
||||
@@ -21,7 +21,7 @@ var defaultFetchSetting = FetchSetting{
|
||||
DomainList: []string{},
|
||||
IpList: []string{},
|
||||
AllowedPorts: []string{"80", "443", "8080", "8443"},
|
||||
ApplyIPFilterForDomain: false,
|
||||
ApplyIPFilterForDomain: true,
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
||||
Vendored
BIN
Binary file not shown.
|
After Width: | Height: | Size: 1.6 KiB |
Vendored
BIN
Binary file not shown.
|
After Width: | Height: | Size: 3.6 KiB |
Vendored
BIN
Binary file not shown.
|
After Width: | Height: | Size: 4.5 KiB |
@@ -56,7 +56,7 @@ const OAuth2Callback = (props) => {
|
||||
return;
|
||||
}
|
||||
|
||||
if (message === 'bind') {
|
||||
if (data?.action === 'bind') {
|
||||
showSuccess(t('绑定成功!'));
|
||||
navigate('/console/personal');
|
||||
} else {
|
||||
|
||||
@@ -221,23 +221,27 @@ const FooterBar = () => {
|
||||
return (
|
||||
<div className='w-full'>
|
||||
{footer ? (
|
||||
<div className='relative'>
|
||||
<div
|
||||
className='custom-footer'
|
||||
dangerouslySetInnerHTML={{ __html: footer }}
|
||||
></div>
|
||||
<div className='absolute bottom-2 right-4 text-xs !text-semi-color-text-2 opacity-70'>
|
||||
<span>{t('设计与开发由')} </span>
|
||||
<a
|
||||
href='https://github.com/QuantumNous/new-api'
|
||||
target='_blank'
|
||||
rel='noopener noreferrer'
|
||||
className='!text-semi-color-primary font-medium'
|
||||
>
|
||||
New API
|
||||
</a>
|
||||
<footer className='relative h-auto py-4 px-6 md:px-24 w-full flex items-center justify-center overflow-hidden'>
|
||||
<div className='flex flex-col md:flex-row items-center justify-between w-full max-w-[1110px] gap-4'>
|
||||
<div
|
||||
className='custom-footer na-cb6feafeb3990c78 text-sm !text-semi-color-text-1'
|
||||
dangerouslySetInnerHTML={{ __html: footer }}
|
||||
></div>
|
||||
<div className='text-sm flex-shrink-0'>
|
||||
<span className='!text-semi-color-text-1'>
|
||||
{t('设计与开发由')}{' '}
|
||||
</span>
|
||||
<a
|
||||
href='https://github.com/QuantumNous/new-api'
|
||||
target='_blank'
|
||||
rel='noopener noreferrer'
|
||||
className='!text-semi-color-primary font-medium'
|
||||
>
|
||||
New API
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</footer>
|
||||
) : (
|
||||
customFooter
|
||||
)}
|
||||
|
||||
@@ -23,6 +23,7 @@ import SettingsGeneralPayment from '../../pages/Setting/Payment/SettingsGeneralP
|
||||
import SettingsPaymentGateway from '../../pages/Setting/Payment/SettingsPaymentGateway';
|
||||
import SettingsPaymentGatewayStripe from '../../pages/Setting/Payment/SettingsPaymentGatewayStripe';
|
||||
import SettingsPaymentGatewayCreem from '../../pages/Setting/Payment/SettingsPaymentGatewayCreem';
|
||||
import SettingsPaymentGatewayWaffo from '../../pages/Setting/Payment/SettingsPaymentGatewayWaffo';
|
||||
import { API, showError, toBoolean } from '../../helpers';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
@@ -66,7 +67,6 @@ const PaymentSetting = () => {
|
||||
2,
|
||||
);
|
||||
} catch (error) {
|
||||
console.error('解析TopupGroupRatio出错:', error);
|
||||
newInputs[item.key] = item.value;
|
||||
}
|
||||
break;
|
||||
@@ -78,7 +78,6 @@ const PaymentSetting = () => {
|
||||
2,
|
||||
);
|
||||
} catch (error) {
|
||||
console.error('解析AmountOptions出错:', error);
|
||||
newInputs['AmountOptions'] = item.value;
|
||||
}
|
||||
break;
|
||||
@@ -90,7 +89,6 @@ const PaymentSetting = () => {
|
||||
2,
|
||||
);
|
||||
} catch (error) {
|
||||
console.error('解析AmountDiscount出错:', error);
|
||||
newInputs['AmountDiscount'] = item.value;
|
||||
}
|
||||
break;
|
||||
@@ -146,6 +144,9 @@ const PaymentSetting = () => {
|
||||
<Card style={{ marginTop: '10px' }}>
|
||||
<SettingsPaymentGatewayCreem options={inputs} refresh={onRefresh} />
|
||||
</Card>
|
||||
<Card style={{ marginTop: '10px' }}>
|
||||
<SettingsPaymentGatewayWaffo options={inputs} refresh={onRefresh} />
|
||||
</Card>
|
||||
</Spin>
|
||||
</>
|
||||
);
|
||||
|
||||
@@ -306,9 +306,9 @@ const PersonalSetting = () => {
|
||||
|
||||
const bindWeChat = async () => {
|
||||
if (inputs.wechat_verification_code === '') return;
|
||||
const res = await API.get(
|
||||
`/api/oauth/wechat/bind?code=${inputs.wechat_verification_code}`,
|
||||
);
|
||||
const res = await API.post('/api/oauth/wechat/bind', {
|
||||
code: inputs.wechat_verification_code,
|
||||
});
|
||||
const { success, message } = res.data;
|
||||
if (success) {
|
||||
showSuccess(t('微信账户绑定成功!'));
|
||||
@@ -378,9 +378,10 @@ const PersonalSetting = () => {
|
||||
return;
|
||||
}
|
||||
setLoading(true);
|
||||
const res = await API.get(
|
||||
`/api/oauth/email/bind?email=${inputs.email}&code=${inputs.email_verification_code}`,
|
||||
);
|
||||
const res = await API.post('/api/oauth/email/bind', {
|
||||
email: inputs.email,
|
||||
code: inputs.email_verification_code,
|
||||
});
|
||||
const { success, message } = res.data;
|
||||
if (success) {
|
||||
showSuccess(t('邮箱账户绑定成功!'));
|
||||
|
||||
@@ -108,7 +108,7 @@ const SystemSetting = () => {
|
||||
'fetch_setting.domain_list': [],
|
||||
'fetch_setting.ip_list': [],
|
||||
'fetch_setting.allowed_ports': [],
|
||||
'fetch_setting.apply_ip_filter_for_domain': false,
|
||||
'fetch_setting.apply_ip_filter_for_domain': true,
|
||||
});
|
||||
|
||||
const [originInputs, setOriginInputs] = useState({});
|
||||
@@ -847,7 +847,7 @@ const SystemSetting = () => {
|
||||
}
|
||||
style={{ marginBottom: 8 }}
|
||||
>
|
||||
{t('对域名启用 IP 过滤(实验性)')}
|
||||
{t('对域名启用 IP 过滤(推荐开启)')}
|
||||
</Form.Checkbox>
|
||||
<Text strong>
|
||||
{t(domainFilterMode ? '域名白名单' : '域名黑名单')}
|
||||
|
||||
@@ -538,19 +538,24 @@ export const getChannelsColumns = ({
|
||||
</Tooltip>
|
||||
<Tooltip
|
||||
content={
|
||||
t('剩余额度') +
|
||||
': ' +
|
||||
renderQuotaWithAmount(record.balance) +
|
||||
t(',点击更新')
|
||||
record.type === 57
|
||||
? t('查看 Codex 帐号信息与用量')
|
||||
: t('剩余额度') +
|
||||
': ' +
|
||||
renderQuotaWithAmount(record.balance) +
|
||||
t(',点击更新')
|
||||
}
|
||||
>
|
||||
<Tag
|
||||
color='white'
|
||||
type='ghost'
|
||||
color={record.type === 57 ? 'light-blue' : 'white'}
|
||||
type={record.type === 57 ? 'light' : 'ghost'}
|
||||
shape='circle'
|
||||
className={record.type === 57 ? 'cursor-pointer' : ''}
|
||||
onClick={() => updateChannelBalance(record)}
|
||||
>
|
||||
{renderQuotaWithAmount(record.balance)}
|
||||
{record.type === 57
|
||||
? t('帐号信息')
|
||||
: renderQuotaWithAmount(record.balance)}
|
||||
</Tag>
|
||||
</Tooltip>
|
||||
</Space>
|
||||
|
||||
@@ -22,9 +22,11 @@ import {
|
||||
Modal,
|
||||
Button,
|
||||
Progress,
|
||||
Tag,
|
||||
Typography,
|
||||
Spin,
|
||||
Tag,
|
||||
Descriptions,
|
||||
Collapse,
|
||||
} from '@douyinfe/semi-ui';
|
||||
import { API, showError } from '../../../../helpers';
|
||||
|
||||
@@ -43,6 +45,68 @@ const pickStrokeColor = (percent) => {
|
||||
return '#3b82f6';
|
||||
};
|
||||
|
||||
const normalizePlanType = (value) => {
|
||||
if (value == null) return '';
|
||||
return String(value).trim().toLowerCase();
|
||||
};
|
||||
|
||||
const getWindowDurationSeconds = (windowData) => {
|
||||
const value = Number(windowData?.limit_window_seconds);
|
||||
if (!Number.isFinite(value) || value <= 0) return null;
|
||||
return value;
|
||||
};
|
||||
|
||||
const classifyWindowByDuration = (windowData) => {
|
||||
const seconds = getWindowDurationSeconds(windowData);
|
||||
if (seconds == null) return null;
|
||||
return seconds >= 24 * 60 * 60 ? 'weekly' : 'fiveHour';
|
||||
};
|
||||
|
||||
const resolveRateLimitWindows = (data) => {
|
||||
const rateLimit = data?.rate_limit ?? {};
|
||||
const primary = rateLimit?.primary_window ?? null;
|
||||
const secondary = rateLimit?.secondary_window ?? null;
|
||||
const windows = [primary, secondary].filter(Boolean);
|
||||
const planType = normalizePlanType(data?.plan_type ?? rateLimit?.plan_type);
|
||||
|
||||
let fiveHourWindow = null;
|
||||
let weeklyWindow = null;
|
||||
|
||||
for (const windowData of windows) {
|
||||
const bucket = classifyWindowByDuration(windowData);
|
||||
if (bucket === 'fiveHour' && !fiveHourWindow) {
|
||||
fiveHourWindow = windowData;
|
||||
continue;
|
||||
}
|
||||
if (bucket === 'weekly' && !weeklyWindow) {
|
||||
weeklyWindow = windowData;
|
||||
}
|
||||
}
|
||||
|
||||
if (planType === 'free') {
|
||||
if (!weeklyWindow) {
|
||||
weeklyWindow = primary ?? secondary ?? null;
|
||||
}
|
||||
return { fiveHourWindow: null, weeklyWindow };
|
||||
}
|
||||
|
||||
if (!fiveHourWindow && !weeklyWindow) {
|
||||
return {
|
||||
fiveHourWindow: primary ?? null,
|
||||
weeklyWindow: secondary ?? null,
|
||||
};
|
||||
}
|
||||
|
||||
if (!fiveHourWindow) {
|
||||
fiveHourWindow = windows.find((windowData) => windowData !== weeklyWindow) ?? null;
|
||||
}
|
||||
if (!weeklyWindow) {
|
||||
weeklyWindow = windows.find((windowData) => windowData !== fiveHourWindow) ?? null;
|
||||
}
|
||||
|
||||
return { fiveHourWindow, weeklyWindow };
|
||||
};
|
||||
|
||||
const formatDurationSeconds = (seconds, t) => {
|
||||
const tt = typeof t === 'function' ? t : (v) => v;
|
||||
const s = Number(seconds);
|
||||
@@ -66,8 +130,93 @@ const formatUnixSeconds = (unixSeconds) => {
|
||||
}
|
||||
};
|
||||
|
||||
const getDisplayText = (value) => {
|
||||
if (value == null) return '';
|
||||
return String(value).trim();
|
||||
};
|
||||
|
||||
const formatAccountTypeLabel = (value, t) => {
|
||||
const tt = typeof t === 'function' ? t : (v) => v;
|
||||
const normalized = normalizePlanType(value);
|
||||
switch (normalized) {
|
||||
case 'free':
|
||||
return 'Free';
|
||||
case 'plus':
|
||||
return 'Plus';
|
||||
case 'pro':
|
||||
return 'Pro';
|
||||
case 'team':
|
||||
return 'Team';
|
||||
case 'enterprise':
|
||||
return 'Enterprise';
|
||||
default:
|
||||
return getDisplayText(value) || tt('未识别');
|
||||
}
|
||||
};
|
||||
|
||||
const getAccountTypeTagColor = (value) => {
|
||||
const normalized = normalizePlanType(value);
|
||||
switch (normalized) {
|
||||
case 'enterprise':
|
||||
return 'green';
|
||||
case 'team':
|
||||
return 'cyan';
|
||||
case 'pro':
|
||||
return 'blue';
|
||||
case 'plus':
|
||||
return 'violet';
|
||||
case 'free':
|
||||
return 'amber';
|
||||
default:
|
||||
return 'grey';
|
||||
}
|
||||
};
|
||||
|
||||
const resolveUsageStatusTag = (t, rateLimit) => {
|
||||
const tt = typeof t === 'function' ? t : (v) => v;
|
||||
if (!rateLimit || Object.keys(rateLimit).length === 0) {
|
||||
return <Tag color='grey'>{tt('待确认')}</Tag>;
|
||||
}
|
||||
if (rateLimit?.allowed && !rateLimit?.limit_reached) {
|
||||
return <Tag color='green'>{tt('可用')}</Tag>;
|
||||
}
|
||||
return <Tag color='red'>{tt('受限')}</Tag>;
|
||||
};
|
||||
|
||||
const AccountInfoValue = ({ t, value, onCopy, monospace = false }) => {
|
||||
const tt = typeof t === 'function' ? t : (v) => v;
|
||||
const text = getDisplayText(value);
|
||||
const hasValue = text !== '';
|
||||
|
||||
return (
|
||||
<div className='flex min-w-0 items-start justify-between gap-2'>
|
||||
<div
|
||||
className={`min-w-0 flex-1 break-all text-xs leading-5 text-semi-color-text-1 ${
|
||||
monospace ? 'font-mono' : ''
|
||||
}`}
|
||||
>
|
||||
{hasValue ? text : '-'}
|
||||
</div>
|
||||
<Button
|
||||
size='small'
|
||||
type='tertiary'
|
||||
theme='borderless'
|
||||
className='shrink-0 px-1 text-xs'
|
||||
disabled={!hasValue}
|
||||
onClick={() => onCopy?.(text)}
|
||||
>
|
||||
{tt('复制')}
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
const RateLimitWindowCard = ({ t, title, windowData }) => {
|
||||
const tt = typeof t === 'function' ? t : (v) => v;
|
||||
const hasWindowData =
|
||||
!!windowData &&
|
||||
typeof windowData === 'object' &&
|
||||
Object.keys(windowData).length > 0;
|
||||
const percent = clampPercent(windowData?.used_percent ?? 0);
|
||||
const resetAt = windowData?.reset_at;
|
||||
const resetAfterSeconds = windowData?.reset_after_seconds;
|
||||
@@ -83,26 +232,30 @@ const RateLimitWindowCard = ({ t, title, windowData }) => {
|
||||
</Text>
|
||||
</div>
|
||||
|
||||
<div className='mt-2'>
|
||||
<Progress
|
||||
percent={percent}
|
||||
stroke={pickStrokeColor(percent)}
|
||||
showInfo={true}
|
||||
/>
|
||||
</div>
|
||||
{hasWindowData ? (
|
||||
<div className='mt-2'>
|
||||
<Progress
|
||||
percent={percent}
|
||||
stroke={pickStrokeColor(percent)}
|
||||
showInfo={true}
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
<div className='mt-3 text-sm text-semi-color-text-2'>-</div>
|
||||
)}
|
||||
|
||||
<div className='mt-1 flex flex-wrap items-center gap-2 text-xs text-semi-color-text-2'>
|
||||
<div>
|
||||
{tt('已使用:')}
|
||||
{percent}%
|
||||
{hasWindowData ? `${percent}%` : '-'}
|
||||
</div>
|
||||
<div>
|
||||
{tt('距离重置:')}
|
||||
{formatDurationSeconds(resetAfterSeconds, tt)}
|
||||
{hasWindowData ? formatDurationSeconds(resetAfterSeconds, tt) : '-'}
|
||||
</div>
|
||||
<div>
|
||||
{tt('窗口:')}
|
||||
{formatDurationSeconds(limitWindowSeconds, tt)}
|
||||
{hasWindowData ? formatDurationSeconds(limitWindowSeconds, tt) : '-'}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -111,84 +264,139 @@ const RateLimitWindowCard = ({ t, title, windowData }) => {
|
||||
|
||||
const CodexUsageView = ({ t, record, payload, onCopy, onRefresh }) => {
|
||||
const tt = typeof t === 'function' ? t : (v) => v;
|
||||
const [showRawJson, setShowRawJson] = useState(false);
|
||||
const data = payload?.data ?? null;
|
||||
const rateLimit = data?.rate_limit ?? {};
|
||||
|
||||
const primary = rateLimit?.primary_window ?? null;
|
||||
const secondary = rateLimit?.secondary_window ?? null;
|
||||
|
||||
const allowed = !!rateLimit?.allowed;
|
||||
const limitReached = !!rateLimit?.limit_reached;
|
||||
const { fiveHourWindow, weeklyWindow } = resolveRateLimitWindows(data);
|
||||
const upstreamStatus = payload?.upstream_status;
|
||||
|
||||
const statusTag =
|
||||
allowed && !limitReached ? (
|
||||
<Tag color='green'>{tt('可用')}</Tag>
|
||||
) : (
|
||||
<Tag color='red'>{tt('受限')}</Tag>
|
||||
);
|
||||
const accountType = data?.plan_type ?? rateLimit?.plan_type;
|
||||
const accountTypeLabel = formatAccountTypeLabel(accountType, tt);
|
||||
const accountTypeTagColor = getAccountTypeTagColor(accountType);
|
||||
const statusTag = resolveUsageStatusTag(tt, rateLimit);
|
||||
const userId = data?.user_id;
|
||||
const email = data?.email;
|
||||
const accountId = data?.account_id;
|
||||
const errorMessage =
|
||||
payload?.success === false ? getDisplayText(payload?.message) || tt('获取用量失败') : '';
|
||||
|
||||
const rawText =
|
||||
typeof data === 'string' ? data : JSON.stringify(data ?? payload, null, 2);
|
||||
|
||||
return (
|
||||
<div className='flex flex-col gap-3'>
|
||||
<div className='flex flex-wrap items-center justify-between gap-2'>
|
||||
<Text type='tertiary' size='small'>
|
||||
{tt('渠道:')}
|
||||
{record?.name || '-'} ({tt('编号:')}
|
||||
{record?.id || '-'})
|
||||
</Text>
|
||||
<div className='flex items-center gap-2'>
|
||||
{statusTag}
|
||||
<Button
|
||||
size='small'
|
||||
type='tertiary'
|
||||
theme='borderless'
|
||||
onClick={onRefresh}
|
||||
>
|
||||
<div className='flex flex-col gap-4'>
|
||||
{errorMessage && (
|
||||
<div className='rounded-xl border border-red-200 bg-red-50 px-4 py-3 text-sm text-red-700'>
|
||||
{errorMessage}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className='rounded-xl border border-semi-color-border bg-semi-color-bg-0 p-3'>
|
||||
<div className='flex flex-wrap items-start justify-between gap-2'>
|
||||
<div className='min-w-0'>
|
||||
<div className='text-xs font-medium text-semi-color-text-2'>
|
||||
{tt('Codex 帐号')}
|
||||
</div>
|
||||
<div className='mt-2 flex flex-wrap items-center gap-2'>
|
||||
<Tag
|
||||
color={accountTypeTagColor}
|
||||
type='light'
|
||||
shape='circle'
|
||||
size='large'
|
||||
className='font-semibold'
|
||||
>
|
||||
{accountTypeLabel}
|
||||
</Tag>
|
||||
{statusTag}
|
||||
<Tag color='grey' type='light' shape='circle'>
|
||||
{tt('上游状态码:')}
|
||||
{upstreamStatus ?? '-'}
|
||||
</Tag>
|
||||
</div>
|
||||
</div>
|
||||
<Button size='small' type='tertiary' theme='outline' onClick={onRefresh}>
|
||||
{tt('刷新')}
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<div className='mt-2 rounded-lg bg-semi-color-fill-0 px-3 py-2'>
|
||||
<Descriptions>
|
||||
<Descriptions.Item itemKey='User ID'>
|
||||
<AccountInfoValue
|
||||
t={tt}
|
||||
value={userId}
|
||||
onCopy={onCopy}
|
||||
monospace={true}
|
||||
/>
|
||||
</Descriptions.Item>
|
||||
<Descriptions.Item itemKey={tt('邮箱')}>
|
||||
<AccountInfoValue t={tt} value={email} onCopy={onCopy} />
|
||||
</Descriptions.Item>
|
||||
<Descriptions.Item itemKey='Account ID'>
|
||||
<AccountInfoValue
|
||||
t={tt}
|
||||
value={accountId}
|
||||
onCopy={onCopy}
|
||||
monospace={true}
|
||||
/>
|
||||
</Descriptions.Item>
|
||||
</Descriptions>
|
||||
</div>
|
||||
|
||||
<div className='mt-2 text-xs text-semi-color-text-2'>
|
||||
{tt('渠道:')}
|
||||
{record?.name || '-'} ({tt('编号:')}
|
||||
{record?.id || '-'})
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className='flex flex-wrap items-center justify-between gap-2'>
|
||||
<Text type='tertiary' size='small'>
|
||||
{tt('上游状态码:')}
|
||||
{upstreamStatus ?? '-'}
|
||||
</Text>
|
||||
<div>
|
||||
<div className='mb-2'>
|
||||
<div className='text-sm font-semibold text-semi-color-text-0'>
|
||||
{tt('额度窗口')}
|
||||
</div>
|
||||
<Text type='tertiary' size='small'>
|
||||
{tt('用于观察当前帐号在 Codex 上游的限额使用情况')}
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className='grid grid-cols-1 gap-3 md:grid-cols-2'>
|
||||
<RateLimitWindowCard
|
||||
t={tt}
|
||||
title={tt('5小时窗口')}
|
||||
windowData={primary}
|
||||
windowData={fiveHourWindow}
|
||||
/>
|
||||
<RateLimitWindowCard
|
||||
t={tt}
|
||||
title={tt('每周窗口')}
|
||||
windowData={secondary}
|
||||
windowData={weeklyWindow}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<div className='mb-1 flex items-center justify-between gap-2'>
|
||||
<div className='text-sm font-medium'>{tt('原始 JSON')}</div>
|
||||
<Button
|
||||
size='small'
|
||||
type='primary'
|
||||
theme='outline'
|
||||
onClick={() => onCopy?.(rawText)}
|
||||
disabled={!rawText}
|
||||
>
|
||||
{tt('复制')}
|
||||
</Button>
|
||||
</div>
|
||||
<pre className='max-h-[50vh] overflow-auto rounded-lg bg-semi-color-fill-0 p-3 text-xs text-semi-color-text-0'>
|
||||
{rawText}
|
||||
</pre>
|
||||
</div>
|
||||
<Collapse
|
||||
activeKey={showRawJson ? ['raw-json'] : []}
|
||||
onChange={(activeKey) => {
|
||||
const keys = Array.isArray(activeKey) ? activeKey : [activeKey];
|
||||
setShowRawJson(keys.includes('raw-json'));
|
||||
}}
|
||||
>
|
||||
<Collapse.Panel header={tt('原始 JSON')} itemKey='raw-json'>
|
||||
<div className='mb-2 flex justify-end'>
|
||||
<Button
|
||||
size='small'
|
||||
type='primary'
|
||||
theme='outline'
|
||||
onClick={() => onCopy?.(rawText)}
|
||||
disabled={!rawText}
|
||||
>
|
||||
{tt('复制')}
|
||||
</Button>
|
||||
</div>
|
||||
<pre className='max-h-[50vh] overflow-y-auto rounded-lg bg-semi-color-fill-0 p-3 text-xs text-semi-color-text-0'>
|
||||
{rawText}
|
||||
</pre>
|
||||
</Collapse.Panel>
|
||||
</Collapse>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
@@ -283,7 +491,7 @@ export const openCodexUsageModal = ({ t, record, payload, onCopy }) => {
|
||||
const tt = typeof t === 'function' ? t : (v) => v;
|
||||
|
||||
Modal.info({
|
||||
title: tt('Codex 用量'),
|
||||
title: tt('Codex 帐号与用量'),
|
||||
centered: true,
|
||||
width: 900,
|
||||
style: { maxWidth: '95vw' },
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user