Compare commits
90 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| dc83c4af31 | |||
| 960bf9c49e | |||
| 12a48c620e | |||
| eacc245bad | |||
| 03758a4a85 | |||
| 8fc0eb78e2 | |||
| 427fb7eaf6 | |||
| 4cd0e3651d | |||
| 677d02f2ab | |||
| c583382af7 | |||
| b1950655e7 | |||
| 22cbbcab4c | |||
| 873067f7d1 | |||
| 82c2008d2c | |||
| 495e4f5e17 | |||
| 23fde25b15 | |||
| ac90d9f185 | |||
| bb5b9eaca2 | |||
| c9611c493f | |||
| 1baf4a6337 | |||
| 9816ad87e3 | |||
| 50249f581c | |||
| 0193018af6 | |||
| f449e06b9d | |||
| 79527c0ab1 | |||
| 41cd051ea9 | |||
| c04f82bfb5 | |||
| dafc7618c3 | |||
| 22692b3f87 | |||
| d36e892905 | |||
| 3cd1ba4673 | |||
| b7c0f754ad | |||
| a706f00287 | |||
| 7efb1922fe | |||
| 89fe99f3bd | |||
| e5b5331d3b | |||
| 18373c6eac | |||
| 5b47011e08 | |||
| 314ae40820 | |||
| ab99c30884 | |||
| 670abee2f0 | |||
| 8bb9a42f68 | |||
| d22f889e5d | |||
| 3734059da7 | |||
| 26ce873f8b | |||
| e099117c61 | |||
| 310d618a16 | |||
| 20399d3c8f | |||
| 53aeee4ff7 | |||
| 5238f279db | |||
| 5402bf417d | |||
| c766913baf | |||
| 40dc43f44e | |||
| 263b9bc695 | |||
| b2dd4acc9f | |||
| 4e492b26f6 | |||
| 82b750398c | |||
| fbf235d222 | |||
| 62b9aaa520 | |||
| 814a3f5124 | |||
| 22b6b16702 | |||
| 6154b8e3cd | |||
| ff66288e3a | |||
| 926e1781dd | |||
| d4a470a638 | |||
| 9f61407bf0 | |||
| dbf900a531 | |||
| 7399e4721b | |||
| a5e20269dd | |||
| 9ae9040b3c | |||
| 0191a68d4e | |||
| 16221f8279 | |||
| 763c3ff709 | |||
| c667e4706a | |||
| 216b94dac0 | |||
| 49eb533aaf | |||
| 7693edae53 | |||
| ded4a124e2 | |||
| d6982c8182 | |||
| 9ecad90652 | |||
| 929b5060ea | |||
| 755ece2f01 | |||
| f40eb4e5d2 | |||
| 45f65c297b | |||
| 6c074ef897 | |||
| eff51857d0 | |||
| e9f8f62796 | |||
| 5fe8e98eeb | |||
| e520977efc | |||
| b09337e6ed |
@@ -27,9 +27,10 @@ jobs:
|
||||
permissions:
|
||||
packages: write
|
||||
contents: read
|
||||
id-token: write
|
||||
steps:
|
||||
- name: Check out (shallow)
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||
with:
|
||||
fetch-depth: 1
|
||||
|
||||
@@ -46,16 +47,16 @@ jobs:
|
||||
run: echo "GHCR_REPOSITORY=${GITHUB_REPOSITORY,,}" >> $GITHUB_ENV
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Log in to GHCR
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
@@ -63,14 +64,15 @@ jobs:
|
||||
|
||||
- name: Extract metadata (labels)
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # v5
|
||||
with:
|
||||
images: |
|
||||
calciumion/new-api
|
||||
ghcr.io/${{ env.GHCR_REPOSITORY }}
|
||||
|
||||
- name: Build & push single-arch (to both registries)
|
||||
uses: docker/build-push-action@v6
|
||||
id: build
|
||||
uses: docker/build-push-action@10e90e3645eae34f1e60eeb005ba3a3d33f178e8 # v6
|
||||
with:
|
||||
context: .
|
||||
platforms: ${{ matrix.platform }}
|
||||
@@ -83,8 +85,25 @@ jobs:
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
provenance: false
|
||||
sbom: false
|
||||
provenance: mode=max
|
||||
sbom: true
|
||||
|
||||
- name: Install cosign
|
||||
uses: sigstore/cosign-installer@398d4b0eeef1380460a10c8013a76f728fb906ac # v3
|
||||
|
||||
- name: Sign image with cosign
|
||||
run: |
|
||||
cosign sign --yes calciumion/new-api@${{ steps.build.outputs.digest }}
|
||||
cosign sign --yes ghcr.io/${{ env.GHCR_REPOSITORY }}@${{ steps.build.outputs.digest }}
|
||||
|
||||
- name: Output digest
|
||||
run: |
|
||||
echo "### Docker Image Digest (${{ matrix.arch }})" >> $GITHUB_STEP_SUMMARY
|
||||
echo '```' >> $GITHUB_STEP_SUMMARY
|
||||
echo "calciumion/new-api:alpha-${{ matrix.arch }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "ghcr.io/${{ env.GHCR_REPOSITORY }}:alpha-${{ matrix.arch }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "${{ steps.build.outputs.digest }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo '```' >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
create_manifests:
|
||||
name: Create multi-arch manifests (Docker Hub + GHCR)
|
||||
@@ -95,7 +114,7 @@ jobs:
|
||||
contents: read
|
||||
steps:
|
||||
- name: Check out (shallow)
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||
with:
|
||||
fetch-depth: 1
|
||||
|
||||
@@ -110,7 +129,7 @@ jobs:
|
||||
echo "VERSION=$VERSION" >> $GITHUB_ENV
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
@@ -130,7 +149,7 @@ jobs:
|
||||
calciumion/new-api:${VERSION}-arm64
|
||||
|
||||
- name: Log in to GHCR
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
@@ -149,3 +168,12 @@ jobs:
|
||||
-t ghcr.io/${GHCR_REPOSITORY}:${VERSION} \
|
||||
ghcr.io/${GHCR_REPOSITORY}:${VERSION}-amd64 \
|
||||
ghcr.io/${GHCR_REPOSITORY}:${VERSION}-arm64
|
||||
|
||||
- name: Output manifest digest
|
||||
run: |
|
||||
echo "### Multi-arch Manifest Digests" >> $GITHUB_STEP_SUMMARY
|
||||
echo '```' >> $GITHUB_STEP_SUMMARY
|
||||
docker buildx imagetools inspect calciumion/new-api:alpha >> $GITHUB_STEP_SUMMARY
|
||||
echo "---" >> $GITHUB_STEP_SUMMARY
|
||||
docker buildx imagetools inspect ghcr.io/${GHCR_REPOSITORY}:alpha >> $GITHUB_STEP_SUMMARY
|
||||
echo '```' >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
@@ -30,10 +30,11 @@ jobs:
|
||||
permissions:
|
||||
packages: write
|
||||
contents: read
|
||||
id-token: write
|
||||
|
||||
steps:
|
||||
- name: Check out
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||
with:
|
||||
fetch-depth: ${{ github.event_name == 'workflow_dispatch' && 0 || 1 }}
|
||||
ref: ${{ github.event.inputs.tag || github.ref }}
|
||||
@@ -59,16 +60,16 @@ jobs:
|
||||
# run: echo "GHCR_REPOSITORY=${GITHUB_REPOSITORY,,}" >> $GITHUB_ENV
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
# - name: Log in to GHCR
|
||||
# uses: docker/login-action@v3
|
||||
# uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
|
||||
# with:
|
||||
# registry: ghcr.io
|
||||
# username: ${{ github.actor }}
|
||||
@@ -76,14 +77,15 @@ jobs:
|
||||
|
||||
- name: Extract metadata (labels)
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # v5
|
||||
with:
|
||||
images: |
|
||||
calciumion/new-api
|
||||
# ghcr.io/${{ env.GHCR_REPOSITORY }}
|
||||
|
||||
- name: Build & push single-arch (to both registries)
|
||||
uses: docker/build-push-action@v6
|
||||
id: build
|
||||
uses: docker/build-push-action@10e90e3645eae34f1e60eeb005ba3a3d33f178e8 # v6
|
||||
with:
|
||||
context: .
|
||||
platforms: ${{ matrix.platform }}
|
||||
@@ -96,8 +98,22 @@ jobs:
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
provenance: false
|
||||
sbom: false
|
||||
provenance: mode=max
|
||||
sbom: true
|
||||
|
||||
- name: Install cosign
|
||||
uses: sigstore/cosign-installer@398d4b0eeef1380460a10c8013a76f728fb906ac # v3
|
||||
|
||||
- name: Sign image with cosign
|
||||
run: cosign sign --yes calciumion/new-api@${{ steps.build.outputs.digest }}
|
||||
|
||||
- name: Output digest
|
||||
run: |
|
||||
echo "### Docker Image Digest (${{ matrix.arch }})" >> $GITHUB_STEP_SUMMARY
|
||||
echo '```' >> $GITHUB_STEP_SUMMARY
|
||||
echo "calciumion/new-api:${{ env.TAG }}-${{ matrix.arch }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "${{ steps.build.outputs.digest }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo '```' >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
create_manifests:
|
||||
name: Create multi-arch manifests (Docker Hub)
|
||||
@@ -117,7 +133,7 @@ jobs:
|
||||
# run: echo "GHCR_REPOSITORY=${GITHUB_REPOSITORY,,}" >> $GITHUB_ENV
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
@@ -136,9 +152,16 @@ jobs:
|
||||
calciumion/new-api:latest-amd64 \
|
||||
calciumion/new-api:latest-arm64
|
||||
|
||||
- name: Output manifest digest
|
||||
run: |
|
||||
echo "### Multi-arch Manifest" >> $GITHUB_STEP_SUMMARY
|
||||
echo '```' >> $GITHUB_STEP_SUMMARY
|
||||
docker buildx imagetools inspect calciumion/new-api:${TAG} >> $GITHUB_STEP_SUMMARY
|
||||
echo '```' >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
# ---- GHCR ----
|
||||
# - name: Log in to GHCR
|
||||
# uses: docker/login-action@v3
|
||||
# uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
|
||||
# with:
|
||||
# registry: ghcr.io
|
||||
# username: ${{ github.actor }}
|
||||
|
||||
@@ -19,14 +19,14 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Determine Version
|
||||
run: |
|
||||
VERSION=$(git describe --tags)
|
||||
echo "VERSION=$VERSION" >> $GITHUB_ENV
|
||||
- uses: oven-sh/setup-bun@v2
|
||||
- uses: oven-sh/setup-bun@0c5077e51419868618aeaa5fe8019c62421857d6 # v2
|
||||
with:
|
||||
bun-version: latest
|
||||
- name: Build Frontend
|
||||
@@ -38,7 +38,7 @@ jobs:
|
||||
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$VERSION bun run build
|
||||
cd ..
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
uses: actions/setup-go@40f1582b2485089dde7abd97c1529aa768e1baff # v5
|
||||
with:
|
||||
go-version: '>=1.25.1'
|
||||
- name: Build Backend (amd64)
|
||||
@@ -50,12 +50,16 @@ jobs:
|
||||
sudo apt-get update
|
||||
DEBIAN_FRONTEND=noninteractive sudo apt-get install -y gcc-aarch64-linux-gnu
|
||||
CC=aarch64-linux-gnu-gcc CGO_ENABLED=1 GOOS=linux GOARCH=arm64 go build -ldflags "-s -w -X 'new-api/common.Version=$VERSION' -extldflags '-static'" -o new-api-arm64-$VERSION
|
||||
- name: Generate checksums
|
||||
run: sha256sum new-api-* > checksums-linux.txt
|
||||
|
||||
- name: Release
|
||||
uses: softprops/action-gh-release@v2
|
||||
uses: softprops/action-gh-release@153bb8e04406b158c6c84fc1615b65b24149a1fe # v2
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
with:
|
||||
files: |
|
||||
new-api-*
|
||||
checksums-linux.txt
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
@@ -64,14 +68,14 @@ jobs:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Determine Version
|
||||
run: |
|
||||
VERSION=$(git describe --tags)
|
||||
echo "VERSION=$VERSION" >> $GITHUB_ENV
|
||||
- uses: oven-sh/setup-bun@v2
|
||||
- uses: oven-sh/setup-bun@0c5077e51419868618aeaa5fe8019c62421857d6 # v2
|
||||
with:
|
||||
bun-version: latest
|
||||
- name: Build Frontend
|
||||
@@ -84,18 +88,23 @@ jobs:
|
||||
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$VERSION bun run build
|
||||
cd ..
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
uses: actions/setup-go@40f1582b2485089dde7abd97c1529aa768e1baff # v5
|
||||
with:
|
||||
go-version: '>=1.25.1'
|
||||
- name: Build Backend
|
||||
run: |
|
||||
go mod download
|
||||
go build -ldflags "-X 'new-api/common.Version=$VERSION'" -o new-api-macos-$VERSION
|
||||
- name: Generate checksums
|
||||
run: shasum -a 256 new-api-macos-* > checksums-macos.txt
|
||||
|
||||
- name: Release
|
||||
uses: softprops/action-gh-release@v2
|
||||
uses: softprops/action-gh-release@153bb8e04406b158c6c84fc1615b65b24149a1fe # v2
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
with:
|
||||
files: new-api-macos-*
|
||||
files: |
|
||||
new-api-macos-*
|
||||
checksums-macos.txt
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
@@ -107,14 +116,14 @@ jobs:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Determine Version
|
||||
run: |
|
||||
VERSION=$(git describe --tags)
|
||||
echo "VERSION=$VERSION" >> $GITHUB_ENV
|
||||
- uses: oven-sh/setup-bun@v2
|
||||
- uses: oven-sh/setup-bun@0c5077e51419868618aeaa5fe8019c62421857d6 # v2
|
||||
with:
|
||||
bun-version: latest
|
||||
- name: Build Frontend
|
||||
@@ -126,17 +135,22 @@ jobs:
|
||||
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$VERSION bun run build
|
||||
cd ..
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
uses: actions/setup-go@40f1582b2485089dde7abd97c1529aa768e1baff # v5
|
||||
with:
|
||||
go-version: '>=1.25.1'
|
||||
- name: Build Backend
|
||||
run: |
|
||||
go mod download
|
||||
go build -ldflags "-s -w -X 'new-api/common.Version=$VERSION'" -o new-api-$VERSION.exe
|
||||
- name: Generate checksums
|
||||
run: sha256sum new-api-*.exe > checksums-windows.txt
|
||||
|
||||
- name: Release
|
||||
uses: softprops/action-gh-release@v2
|
||||
uses: softprops/action-gh-release@153bb8e04406b158c6c84fc1615b65b24149a1fe # v2
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
with:
|
||||
files: new-api-*.exe
|
||||
files: |
|
||||
new-api-*.exe
|
||||
checksums-windows.txt
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
+3
-3
@@ -1,4 +1,4 @@
|
||||
FROM oven/bun:latest AS builder
|
||||
FROM oven/bun:1@sha256:0733e50325078969732ebe3b15ce4c4be5082f18c4ac1a0f0ca4839c2e4e42a7 AS builder
|
||||
|
||||
WORKDIR /build
|
||||
COPY web/package.json .
|
||||
@@ -8,7 +8,7 @@ COPY ./web .
|
||||
COPY ./VERSION .
|
||||
RUN DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) bun run build
|
||||
|
||||
FROM golang:alpine AS builder2
|
||||
FROM golang:1.26.1-alpine@sha256:2389ebfa5b7f43eeafbd6be0c3700cc46690ef842ad962f6c5bd6be49ed82039 AS builder2
|
||||
ENV GO111MODULE=on CGO_ENABLED=0
|
||||
|
||||
ARG TARGETOS
|
||||
@@ -25,7 +25,7 @@ COPY . .
|
||||
COPY --from=builder /build/dist ./web/dist
|
||||
RUN go build -ldflags "-s -w -X 'github.com/QuantumNous/new-api/common.Version=$(cat VERSION)'" -o new-api
|
||||
|
||||
FROM debian:bookworm-slim
|
||||
FROM debian:bookworm-slim@sha256:f06537653ac770703bc45b4b113475bd402f451e85223f0f2837acbf89ab020a
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends ca-certificates tzdata libasan8 wget \
|
||||
|
||||
+11
-8
@@ -70,17 +70,20 @@
|
||||
<p align="center">
|
||||
<a href="https://www.cherry-ai.com/" target="_blank">
|
||||
<img src="./docs/images/cherry-studio.png" alt="Cherry Studio" height="80" />
|
||||
</a>
|
||||
<a href="https://bda.pku.edu.cn/" target="_blank">
|
||||
</a><!--
|
||||
--><a href="https://github.com/iOfficeAI/AionUi/" target="_blank">
|
||||
<img src="./docs/images/aionui.png" alt="Aion UI" height="80" />
|
||||
</a><!--
|
||||
--><a href="https://bda.pku.edu.cn/" target="_blank">
|
||||
<img src="./docs/images/pku.png" alt="北京大學" height="80" />
|
||||
</a>
|
||||
<a href="https://www.compshare.cn/?ytag=GPU_yy_gh_newapi" target="_blank">
|
||||
</a><!--
|
||||
--><a href="https://www.compshare.cn/?ytag=GPU_yy_gh_newapi" target="_blank">
|
||||
<img src="./docs/images/ucloud.png" alt="UCloud 優刻得" height="80" />
|
||||
</a>
|
||||
<a href="https://www.aliyun.com/" target="_blank">
|
||||
</a><!--
|
||||
--><a href="https://www.aliyun.com/" target="_blank">
|
||||
<img src="./docs/images/aliyun.png" alt="阿里雲" height="80" />
|
||||
</a>
|
||||
<a href="https://io.net/" target="_blank">
|
||||
</a><!--
|
||||
--><a href="https://io.net/" target="_blank">
|
||||
<img src="./docs/images/io-net.png" alt="IO.NET" height="80" />
|
||||
</a>
|
||||
</p>
|
||||
|
||||
@@ -229,6 +229,7 @@ func init() {
|
||||
// Default implementation that returns the key as-is
|
||||
// This will be replaced by i18n.T during i18n initialization
|
||||
TranslateMessage = func(c *gin.Context, key string, args ...map[string]any) string {
|
||||
c.Header("X-Translate-id", "d5e7afdfc7f03414b941f9c1e7096be9966510e7")
|
||||
return key
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package controller
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -169,10 +170,7 @@ func collectPendingUpstreamModelChangesFromModels(
|
||||
upstreamSet[modelName] = struct{}{}
|
||||
}
|
||||
|
||||
ignoredSet := make(map[string]struct{})
|
||||
for _, modelName := range normalizeModelNames(ignoredModels) {
|
||||
ignoredSet[modelName] = struct{}{}
|
||||
}
|
||||
normalizedIgnoredModels := normalizeModelNames(ignoredModels)
|
||||
|
||||
redirectSourceSet := make(map[string]struct{}, len(modelMapping))
|
||||
redirectTargetSet := make(map[string]struct{}, len(modelMapping))
|
||||
@@ -193,7 +191,13 @@ func collectPendingUpstreamModelChangesFromModels(
|
||||
if _, ok := coveredUpstreamSet[modelName]; ok {
|
||||
return false
|
||||
}
|
||||
if _, ok := ignoredSet[modelName]; ok {
|
||||
if lo.ContainsBy(normalizedIgnoredModels, func(ignoredModel string) bool {
|
||||
if regexBody, ok := strings.CutPrefix(ignoredModel, "regex:"); ok {
|
||||
matched, err := regexp.MatchString(strings.TrimSpace(regexBody), modelName)
|
||||
return err == nil && matched
|
||||
}
|
||||
return ignoredModel == modelName
|
||||
}) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
|
||||
@@ -111,6 +111,18 @@ func TestCollectPendingUpstreamModelChangesFromModels_WithModelMapping(t *testin
|
||||
require.Equal(t, []string{"stale-model"}, pendingRemoveModels)
|
||||
}
|
||||
|
||||
func TestCollectPendingUpstreamModelChangesFromModels_WithIgnoredRegexPatterns(t *testing.T) {
|
||||
pendingAddModels, pendingRemoveModels := collectPendingUpstreamModelChangesFromModels(
|
||||
[]string{"gpt-4o"},
|
||||
[]string{"gpt-4o", "claude-3-5-sonnet", "sora-video", "gpt-4.1"},
|
||||
[]string{"regex:^sora-.*$", "gpt-4.1"},
|
||||
nil,
|
||||
)
|
||||
|
||||
require.Equal(t, []string{"claude-3-5-sonnet"}, pendingAddModels)
|
||||
require.Equal(t, []string{}, pendingRemoveModels)
|
||||
}
|
||||
|
||||
func TestBuildUpstreamModelUpdateTaskNotificationContent_OmitOverflowDetails(t *testing.T) {
|
||||
channelSummaries := make([]upstreamModelUpdateChannelSummary, 0, 12)
|
||||
for i := 0; i < 12; i++ {
|
||||
|
||||
+14
-21
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/middleware"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/oauth"
|
||||
@@ -116,7 +117,6 @@ func GetStatus(c *gin.Context) {
|
||||
"user_agreement_enabled": legalSetting.UserAgreement != "",
|
||||
"privacy_policy_enabled": legalSetting.PrivacyPolicy != "",
|
||||
"checkin_enabled": operation_setting.GetCheckinSetting().Enabled,
|
||||
"_qn": "new-api",
|
||||
}
|
||||
|
||||
// 根据启用状态注入可选内容
|
||||
@@ -308,31 +308,24 @@ func SendPasswordResetEmail(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if !model.IsEmailAlreadyTaken(email) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该邮箱地址未注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := common.GenerateVerificationCode(0)
|
||||
common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose)
|
||||
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", system_setting.ServerAddress, email, code)
|
||||
subject := fmt.Sprintf("%s密码重置", common.SystemName)
|
||||
content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+
|
||||
"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+
|
||||
"<p>如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:<br> %s </p>"+
|
||||
"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, link, link, common.VerificationValidMinutes)
|
||||
err := common.SendEmail(subject, email, content)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
if model.IsEmailAlreadyTaken(email) {
|
||||
code := common.GenerateVerificationCode(0)
|
||||
common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose)
|
||||
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", system_setting.ServerAddress, email, code)
|
||||
subject := fmt.Sprintf("%s密码重置", common.SystemName)
|
||||
content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+
|
||||
"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+
|
||||
"<p>如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:<br> %s </p>"+
|
||||
"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, link, link, common.VerificationValidMinutes)
|
||||
err := common.SendEmail(subject, email, content)
|
||||
if err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("failed to send password reset email to %s: %s", email, err.Error()))
|
||||
}
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
type PasswordResetRequest struct {
|
||||
|
||||
+3
-1
@@ -190,7 +190,9 @@ func handleOAuthBind(c *gin.Context, provider oauth.Provider) {
|
||||
}
|
||||
}
|
||||
|
||||
common.ApiSuccessI18n(c, i18n.MsgOAuthBindSuccess, nil)
|
||||
common.ApiSuccessI18n(c, i18n.MsgOAuthBindSuccess, gin.H{
|
||||
"action": "bind",
|
||||
})
|
||||
}
|
||||
|
||||
// findOrCreateOAuthUser finds existing user or creates new user
|
||||
|
||||
@@ -46,7 +46,7 @@ func GetPricing(c *gin.Context) {
|
||||
"usable_group": usableGroup,
|
||||
"supported_endpoint": model.GetSupportedEndpointMap(),
|
||||
"auto_groups": service.GetUserAutoGroup(group),
|
||||
"_": "a42d372ccf0b5dd13ecf71203521f9d2",
|
||||
"pricing_version": "a42d372ccf0b5dd13ecf71203521f9d2",
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
+1
-1
@@ -581,7 +581,7 @@ func RelayTask(c *gin.Context) {
|
||||
ModelRatio: relayInfo.PriceData.ModelRatio,
|
||||
OtherRatios: relayInfo.PriceData.OtherRatios,
|
||||
OriginModelName: relayInfo.OriginModelName,
|
||||
PerCallBilling: common.StringsContains(constant.TaskPricePatches, relayInfo.OriginModelName),
|
||||
PerCallBilling: common.StringsContains(constant.TaskPricePatches, relayInfo.OriginModelName) || relayInfo.PriceData.UsePrice,
|
||||
}
|
||||
task.Quota = result.Quota
|
||||
task.Data = result.TaskData
|
||||
|
||||
@@ -334,3 +334,26 @@ func DeleteTokenBatch(c *gin.Context) {
|
||||
"data": count,
|
||||
})
|
||||
}
|
||||
|
||||
func GetTokenKeysBatch(c *gin.Context) {
|
||||
tokenBatch := TokenBatch{}
|
||||
if err := c.ShouldBindJSON(&tokenBatch); err != nil || len(tokenBatch.Ids) == 0 {
|
||||
common.ApiErrorI18n(c, i18n.MsgInvalidParams)
|
||||
return
|
||||
}
|
||||
if len(tokenBatch.Ids) > 100 {
|
||||
common.ApiErrorI18n(c, i18n.MsgBatchTooMany, map[string]any{"Max": 100})
|
||||
return
|
||||
}
|
||||
userId := c.GetInt("id")
|
||||
tokens, err := model.GetTokenKeysByIds(tokenBatch.Ids, userId)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
keysMap := make(map[int]string)
|
||||
for _, t := range tokens {
|
||||
keysMap[t.Id] = t.GetFullKey()
|
||||
}
|
||||
common.ApiSuccess(c, gin.H{"keys": keysMap})
|
||||
}
|
||||
|
||||
+12
-2
@@ -925,9 +925,19 @@ func ManageUser(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
type emailBindRequest struct {
|
||||
Email string `json:"email"`
|
||||
Code string `json:"code"`
|
||||
}
|
||||
|
||||
func EmailBind(c *gin.Context) {
|
||||
email := c.Query("email")
|
||||
code := c.Query("code")
|
||||
var req emailBindRequest
|
||||
if err := common.DecodeJson(c.Request.Body, &req); err != nil {
|
||||
common.ApiError(c, errors.New("invalid request body"))
|
||||
return
|
||||
}
|
||||
email := req.Email
|
||||
code := req.Code
|
||||
if !common.VerifyCodeWithKey(email, code, common.EmailVerificationPurpose) {
|
||||
common.ApiErrorI18n(c, i18n.MsgUserVerificationCodeError)
|
||||
return
|
||||
|
||||
@@ -10,10 +10,12 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/setting/system_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -127,6 +129,13 @@ func VideoProxy(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
fetchSetting := system_setting.GetFetchSetting()
|
||||
if err := common.ValidateURLWithFetchSetting(videoURL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Video URL blocked for task %s: %v", taskID, err))
|
||||
videoProxyError(c, http.StatusForbidden, "server_error", fmt.Sprintf("request blocked: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
req.URL, err = url.Parse(videoURL)
|
||||
if err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to parse URL %s: %s", videoURL, err.Error()))
|
||||
|
||||
+15
-2
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
@@ -25,7 +26,7 @@ func getWeChatIdByCode(code string) (string, error) {
|
||||
if code == "" {
|
||||
return "", errors.New("无效的参数")
|
||||
}
|
||||
req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", common.WeChatServerAddress, code), nil)
|
||||
req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", common.WeChatServerAddress, url.QueryEscape(code)), nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -121,6 +122,10 @@ func WeChatAuth(c *gin.Context) {
|
||||
setupLogin(&user, c)
|
||||
}
|
||||
|
||||
type wechatBindRequest struct {
|
||||
Code string `json:"code"`
|
||||
}
|
||||
|
||||
func WeChatBind(c *gin.Context) {
|
||||
if !common.WeChatAuthEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -129,7 +134,15 @@ func WeChatBind(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
var req wechatBindRequest
|
||||
if err := common.DecodeJson(c.Request.Body, &req); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无效的请求",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := req.Code
|
||||
wechatId, err := getWeChatIdByCode(code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
+24
-30
@@ -98,6 +98,20 @@ func (c *ClaudeMediaMessage) ParseMediaContent() []ClaudeMediaMessage {
|
||||
return mediaContent
|
||||
}
|
||||
|
||||
func (m *ClaudeMediaMessage) ToFileSource() types.FileSource {
|
||||
if m.Source == nil {
|
||||
return nil
|
||||
}
|
||||
data := m.Source.Url
|
||||
if data == "" {
|
||||
data = common.Interface2String(m.Source.Data)
|
||||
}
|
||||
if data == "" {
|
||||
return nil
|
||||
}
|
||||
return types.NewFileSourceFromData(data, m.Source.MediaType)
|
||||
}
|
||||
|
||||
type ClaudeMessageSource struct {
|
||||
Type string `json:"type"`
|
||||
MediaType string `json:"media_type,omitempty"`
|
||||
@@ -223,14 +237,6 @@ type OutputConfigForEffort struct {
|
||||
Effort string `json:"effort,omitempty"`
|
||||
}
|
||||
|
||||
// createClaudeFileSource 根据数据内容创建正确类型的 FileSource
|
||||
func createClaudeFileSource(data string) *types.FileSource {
|
||||
if strings.HasPrefix(data, "http://") || strings.HasPrefix(data, "https://") {
|
||||
return types.NewURLFileSource(data)
|
||||
}
|
||||
return types.NewBase64FileSource(data, "")
|
||||
}
|
||||
|
||||
func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
maxTokens := 0
|
||||
if c.MaxTokens != nil {
|
||||
@@ -258,17 +264,11 @@ func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
case "text":
|
||||
texts = append(texts, media.GetText())
|
||||
case "image":
|
||||
if media.Source != nil {
|
||||
data := media.Source.Url
|
||||
if data == "" {
|
||||
data = common.Interface2String(media.Source.Data)
|
||||
}
|
||||
if data != "" {
|
||||
fileMeta = append(fileMeta, &types.FileMeta{
|
||||
FileType: types.FileTypeImage,
|
||||
Source: createClaudeFileSource(data),
|
||||
})
|
||||
}
|
||||
if source := media.ToFileSource(); source != nil {
|
||||
fileMeta = append(fileMeta, &types.FileMeta{
|
||||
FileType: types.FileTypeImage,
|
||||
Source: source,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -293,17 +293,11 @@ func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
case "text":
|
||||
texts = append(texts, media.GetText())
|
||||
case "image":
|
||||
if media.Source != nil {
|
||||
data := media.Source.Url
|
||||
if data == "" {
|
||||
data = common.Interface2String(media.Source.Data)
|
||||
}
|
||||
if data != "" {
|
||||
fileMeta = append(fileMeta, &types.FileMeta{
|
||||
FileType: types.FileTypeImage,
|
||||
Source: createClaudeFileSource(data),
|
||||
})
|
||||
}
|
||||
if source := media.ToFileSource(); source != nil {
|
||||
fileMeta = append(fileMeta, &types.FileMeta{
|
||||
FileType: types.FileTypeImage,
|
||||
Source: source,
|
||||
})
|
||||
}
|
||||
case "tool_use":
|
||||
if media.Name != "" {
|
||||
|
||||
+13
-11
@@ -64,14 +64,6 @@ type LatLng struct {
|
||||
Longitude *float64 `json:"longitude,omitempty"`
|
||||
}
|
||||
|
||||
// createGeminiFileSource 根据数据内容创建正确类型的 FileSource
|
||||
func createGeminiFileSource(data string, mimeType string) *types.FileSource {
|
||||
if strings.HasPrefix(data, "http://") || strings.HasPrefix(data, "https://") {
|
||||
return types.NewURLFileSource(data)
|
||||
}
|
||||
return types.NewBase64FileSource(data, mimeType)
|
||||
}
|
||||
|
||||
func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
var files []*types.FileMeta = make([]*types.FileMeta, 0)
|
||||
|
||||
@@ -87,9 +79,8 @@ func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
if part.Text != "" {
|
||||
inputTexts = append(inputTexts, part.Text)
|
||||
}
|
||||
if part.InlineData != nil && part.InlineData.Data != "" {
|
||||
if source := part.InlineData.ToFileSource(); source != nil {
|
||||
mimeType := part.InlineData.MimeType
|
||||
source := createGeminiFileSource(part.InlineData.Data, mimeType)
|
||||
var fileType types.FileType
|
||||
if strings.HasPrefix(mimeType, "image/") {
|
||||
fileType = types.FileTypeImage
|
||||
@@ -103,7 +94,6 @@ func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
files = append(files, &types.FileMeta{
|
||||
FileType: fileType,
|
||||
Source: source,
|
||||
MimeType: mimeType,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -121,6 +111,11 @@ func (r *GeminiChatRequest) IsStream(c *gin.Context) bool {
|
||||
if c.Query("alt") == "sse" {
|
||||
return true
|
||||
}
|
||||
// Native Gemini API uses URL action to indicate streaming:
|
||||
// /v1beta/models/{model}:streamGenerateContent
|
||||
if strings.Contains(c.Request.URL.Path, "streamGenerateContent") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -210,6 +205,13 @@ type GeminiInlineData struct {
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
func (d *GeminiInlineData) ToFileSource() types.FileSource {
|
||||
if d == nil || d.Data == "" {
|
||||
return nil
|
||||
}
|
||||
return types.NewFileSourceFromData(d.Data, d.MimeType)
|
||||
}
|
||||
|
||||
// UnmarshalJSON custom unmarshaler for GeminiInlineData to support snake_case and camelCase for MimeType
|
||||
func (g *GeminiInlineData) UnmarshalJSON(data []byte) error {
|
||||
type Alias GeminiInlineData // Use type alias to avoid recursion
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGeminiChatRequest_IsStream(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
query string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "streamGenerateContent without alt=sse",
|
||||
path: "/v1beta/models/gemini-2.0-flash:streamGenerateContent",
|
||||
query: "key=sk-xxx",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "streamGenerateContent with alt=sse",
|
||||
path: "/v1beta/models/gemini-2.0-flash:streamGenerateContent",
|
||||
query: "alt=sse&key=sk-xxx",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "generateContent without alt=sse",
|
||||
path: "/v1beta/models/gemini-2.0-flash:generateContent",
|
||||
query: "key=sk-xxx",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "generateContent with alt=sse",
|
||||
path: "/v1beta/models/gemini-2.0-flash:generateContent",
|
||||
query: "alt=sse",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "GenerateContent capitalized",
|
||||
path: "/v1beta/models/gemini-2.0-flash:GenerateContent",
|
||||
query: "key=sk-xxx",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "embedding path",
|
||||
path: "/v1beta/models/gemini-2.0-flash:embedContent",
|
||||
query: "",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
url := tt.path
|
||||
if tt.query != "" {
|
||||
url += "?" + tt.query
|
||||
}
|
||||
c.Request, _ = http.NewRequest("POST", url, nil)
|
||||
|
||||
req := &GeminiChatRequest{}
|
||||
assert.Equal(t, tt.expected, req.IsStream(c))
|
||||
})
|
||||
}
|
||||
}
|
||||
+5
-6
@@ -148,15 +148,14 @@ func (i *ImageRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
}
|
||||
}
|
||||
|
||||
// not support token count for dalle
|
||||
n := uint(1)
|
||||
if i.N != nil {
|
||||
n = *i.N
|
||||
}
|
||||
// n is NOT included here; it is handled via OtherRatio("n") in
|
||||
// image_handler.go (default) or channel adaptors (actual count).
|
||||
// Including n here caused double-counting for channels that also
|
||||
// set OtherRatio("n") (e.g. Ali/Bailian).
|
||||
return &types.TokenCountMeta{
|
||||
CombineText: i.Prompt,
|
||||
MaxTokens: 1584,
|
||||
ImagePriceRatio: sizeRatio * qualityRatio * float64(n),
|
||||
ImagePriceRatio: sizeRatio * qualityRatio,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+54
-48
@@ -108,14 +108,6 @@ type GeneralOpenAIRequest struct {
|
||||
ReasoningSplit json.RawMessage `json:"reasoning_split,omitempty"`
|
||||
}
|
||||
|
||||
// createFileSource 根据数据内容创建正确类型的 FileSource
|
||||
func createFileSource(data string) *types.FileSource {
|
||||
if strings.HasPrefix(data, "http://") || strings.HasPrefix(data, "https://") {
|
||||
return types.NewURLFileSource(data)
|
||||
}
|
||||
return types.NewBase64FileSource(data, "")
|
||||
}
|
||||
|
||||
func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
var tokenCountMeta types.TokenCountMeta
|
||||
var texts = make([]string, 0)
|
||||
@@ -159,44 +151,24 @@ func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
}
|
||||
arrayContent := message.ParseContent()
|
||||
for _, m := range arrayContent {
|
||||
if m.Type == ContentTypeImageURL {
|
||||
imageUrl := m.GetImageMedia()
|
||||
if imageUrl != nil && imageUrl.Url != "" {
|
||||
source := createFileSource(imageUrl.Url)
|
||||
fileMeta = append(fileMeta, &types.FileMeta{
|
||||
FileType: types.FileTypeImage,
|
||||
Source: source,
|
||||
Detail: imageUrl.Detail,
|
||||
})
|
||||
source := m.ToFileSource()
|
||||
if source != nil {
|
||||
meta := &types.FileMeta{Source: source}
|
||||
switch m.Type {
|
||||
case ContentTypeImageURL:
|
||||
meta.FileType = types.FileTypeImage
|
||||
if img := m.GetImageMedia(); img != nil {
|
||||
meta.Detail = img.Detail
|
||||
}
|
||||
case ContentTypeInputAudio:
|
||||
meta.FileType = types.FileTypeAudio
|
||||
case ContentTypeFile:
|
||||
meta.FileType = types.FileTypeFile
|
||||
case ContentTypeVideoUrl:
|
||||
meta.FileType = types.FileTypeVideo
|
||||
}
|
||||
} else if m.Type == ContentTypeInputAudio {
|
||||
inputAudio := m.GetInputAudio()
|
||||
if inputAudio != nil && inputAudio.Data != "" {
|
||||
source := createFileSource(inputAudio.Data)
|
||||
fileMeta = append(fileMeta, &types.FileMeta{
|
||||
FileType: types.FileTypeAudio,
|
||||
Source: source,
|
||||
})
|
||||
}
|
||||
} else if m.Type == ContentTypeFile {
|
||||
file := m.GetFile()
|
||||
if file != nil && file.FileData != "" {
|
||||
source := createFileSource(file.FileData)
|
||||
fileMeta = append(fileMeta, &types.FileMeta{
|
||||
FileType: types.FileTypeFile,
|
||||
Source: source,
|
||||
})
|
||||
}
|
||||
} else if m.Type == ContentTypeVideoUrl {
|
||||
videoUrl := m.GetVideoUrl()
|
||||
if videoUrl != nil && videoUrl.Url != "" {
|
||||
source := createFileSource(videoUrl.Url)
|
||||
fileMeta = append(fileMeta, &types.FileMeta{
|
||||
FileType: types.FileTypeVideo,
|
||||
Source: source,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
fileMeta = append(fileMeta, meta)
|
||||
} else if m.Type == ContentTypeText {
|
||||
texts = append(texts, m.Text)
|
||||
}
|
||||
}
|
||||
@@ -391,9 +363,43 @@ func (m *MediaContent) GetVideoUrl() *MessageVideoUrl {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MediaContent) ToFileSource() types.FileSource {
|
||||
switch m.Type {
|
||||
case ContentTypeImageURL:
|
||||
img := m.GetImageMedia()
|
||||
if img == nil || img.Url == "" {
|
||||
return nil
|
||||
}
|
||||
return types.NewFileSourceFromData(img.Url, img.MimeType)
|
||||
case ContentTypeInputAudio:
|
||||
audio := m.GetInputAudio()
|
||||
if audio == nil || audio.Data == "" {
|
||||
return nil
|
||||
}
|
||||
mimeType := ""
|
||||
if audio.Format != "" {
|
||||
mimeType = "audio/" + audio.Format
|
||||
}
|
||||
return types.NewFileSourceFromData(audio.Data, mimeType)
|
||||
case ContentTypeFile:
|
||||
file := m.GetFile()
|
||||
if file == nil || file.FileData == "" {
|
||||
return nil
|
||||
}
|
||||
return types.NewFileSourceFromData(file.FileData, "")
|
||||
case ContentTypeVideoUrl:
|
||||
video := m.GetVideoUrl()
|
||||
if video == nil || video.Url == "" {
|
||||
return nil
|
||||
}
|
||||
return types.NewFileSourceFromData(video.Url, "")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type MessageImageUrl struct {
|
||||
Url string `json:"url"`
|
||||
Detail string `json:"detail"`
|
||||
Detail string `json:"detail,omitempty"`
|
||||
MimeType string
|
||||
}
|
||||
|
||||
@@ -865,7 +871,7 @@ func (r *OpenAIResponsesRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
if input.ImageUrl != "" {
|
||||
fileMeta = append(fileMeta, &types.FileMeta{
|
||||
FileType: types.FileTypeImage,
|
||||
Source: createFileSource(input.ImageUrl),
|
||||
Source: types.NewFileSourceFromData(input.ImageUrl, ""),
|
||||
Detail: input.Detail,
|
||||
})
|
||||
}
|
||||
@@ -873,7 +879,7 @@ func (r *OpenAIResponsesRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
if input.FileUrl != "" {
|
||||
fileMeta = append(fileMeta, &types.FileMeta{
|
||||
FileType: types.FileTypeFile,
|
||||
Source: createFileSource(input.FileUrl),
|
||||
Source: types.NewFileSourceFromData(input.FileUrl, ""),
|
||||
})
|
||||
}
|
||||
} else {
|
||||
|
||||
@@ -220,10 +220,12 @@ type CompletionsStreamResponse struct {
|
||||
}
|
||||
|
||||
type Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
PromptCacheHitTokens int `json:"prompt_cache_hit_tokens,omitempty"`
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
PromptCacheHitTokens int `json:"prompt_cache_hit_tokens,omitempty"`
|
||||
UsageSemantic string `json:"usage_semantic,omitempty"`
|
||||
UsageSource string `json:"usage_source,omitempty"`
|
||||
|
||||
PromptTokensDetails InputTokenDetails `json:"prompt_tokens_details"`
|
||||
CompletionTokenDetails OutputTokenDetails `json:"completion_tokens_details"`
|
||||
@@ -251,7 +253,7 @@ type OpenAIVideoResponse struct {
|
||||
|
||||
type InputTokenDetails struct {
|
||||
CachedTokens int `json:"cached_tokens"`
|
||||
CachedCreationTokens int `json:"-"`
|
||||
CachedCreationTokens int `json:"cached_creation_tokens,omitempty"`
|
||||
TextTokens int `json:"text_tokens"`
|
||||
AudioTokens int `json:"audio_tokens"`
|
||||
ImageTokens int `json:"image_tokens"`
|
||||
|
||||
+13
-13
@@ -9,7 +9,7 @@
|
||||
"version": "1.0.0",
|
||||
"devDependencies": {
|
||||
"cross-env": "^7.0.3",
|
||||
"electron": "35.7.5",
|
||||
"electron": "39.8.5",
|
||||
"electron-builder": "^26.7.0"
|
||||
}
|
||||
},
|
||||
@@ -777,9 +777,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@xmldom/xmldom": {
|
||||
"version": "0.8.11",
|
||||
"resolved": "https://registry.npmjs.org/@xmldom/xmldom/-/xmldom-0.8.11.tgz",
|
||||
"integrity": "sha512-cQzWCtO6C8TQiYl1ruKNn2U6Ao4o4WBBcbL61yJl84x+j5sOWWFU9X7DpND8XZG3daDppSsigMdfAIl2upQBRw==",
|
||||
"version": "0.8.12",
|
||||
"resolved": "https://registry.npmjs.org/@xmldom/xmldom/-/xmldom-0.8.12.tgz",
|
||||
"integrity": "sha512-9k/gHF6n/pAi/9tqr3m3aqkuiNosYTurLLUtc7xQ9sxB/wm7WPygCv8GYa6mS0fLJEHhqMC1ATYhz++U/lRHqg==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
@@ -2145,9 +2145,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/electron": {
|
||||
"version": "35.7.5",
|
||||
"resolved": "https://registry.npmjs.org/electron/-/electron-35.7.5.tgz",
|
||||
"integrity": "sha512-dnL+JvLraKZl7iusXTVTGYs10TKfzUi30uEDTqsmTm0guN9V2tbOjTzyIZbh9n3ygUjgEYyo+igAwMRXIi3IPw==",
|
||||
"version": "39.8.5",
|
||||
"resolved": "https://registry.npmjs.org/electron/-/electron-39.8.5.tgz",
|
||||
"integrity": "sha512-q6+LiQIcTadSyvtPgLDQkCtVA9jQJXQVMrQcctfOJILh6OFMN+UJJLRkuUTy8CZDYeCIBn1ZycqsL1dAXugxZA==",
|
||||
"dev": true,
|
||||
"hasInstallScript": true,
|
||||
"license": "MIT",
|
||||
@@ -3279,9 +3279,9 @@
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/lodash": {
|
||||
"version": "4.17.23",
|
||||
"resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.23.tgz",
|
||||
"integrity": "sha512-LgVTMpQtIopCi79SJeDiP0TfWi5CNEc/L/aRdTh3yIvmZXTnheWpKjSZhnvMl8iXbC1tFg9gdHHDMLoV7CnG+w==",
|
||||
"version": "4.18.1",
|
||||
"resolved": "https://registry.npmjs.org/lodash/-/lodash-4.18.1.tgz",
|
||||
"integrity": "sha512-dMInicTPVE8d1e5otfwmmjlxkZoUpiVLwyeTdUsi/Caj/gfzzblBcCE5sRHV/AsjuCmxWrte2TNGSYuCeCq+0Q==",
|
||||
"dev": true,
|
||||
"license": "MIT"
|
||||
},
|
||||
@@ -3948,9 +3948,9 @@
|
||||
"license": "ISC"
|
||||
},
|
||||
"node_modules/picomatch": {
|
||||
"version": "4.0.3",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz",
|
||||
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
|
||||
"version": "4.0.4",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.4.tgz",
|
||||
"integrity": "sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
|
||||
Vendored
+1
-1
@@ -25,7 +25,7 @@
|
||||
},
|
||||
"devDependencies": {
|
||||
"cross-env": "^7.0.3",
|
||||
"electron": "35.7.5",
|
||||
"electron": "39.8.5",
|
||||
"electron-builder": "^26.7.0"
|
||||
},
|
||||
"build": {
|
||||
|
||||
@@ -49,11 +49,11 @@ require (
|
||||
github.com/waffo-com/waffo-go v1.3.1
|
||||
github.com/yapingcat/gomedia v0.0.0-20240906162731-17feea57090c
|
||||
golang.org/x/crypto v0.45.0
|
||||
golang.org/x/image v0.23.0
|
||||
golang.org/x/image v0.38.0
|
||||
golang.org/x/net v0.47.0
|
||||
golang.org/x/sync v0.19.0
|
||||
golang.org/x/sync v0.20.0
|
||||
golang.org/x/sys v0.38.0
|
||||
golang.org/x/text v0.32.0
|
||||
golang.org/x/text v0.35.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
gorm.io/driver/mysql v1.4.3
|
||||
gorm.io/driver/postgres v1.5.2
|
||||
|
||||
@@ -325,16 +325,16 @@ golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o=
|
||||
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8=
|
||||
golang.org/x/image v0.23.0 h1:HseQ7c2OpPKTPVzNjG5fwJsOTCiiwS4QdsYi5XU6H68=
|
||||
golang.org/x/image v0.23.0/go.mod h1:wJJBTdLfCCf3tiHa1fNxpZmUI4mmoZvwMCPP0ddoNKY=
|
||||
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
|
||||
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
|
||||
golang.org/x/image v0.38.0 h1:5l+q+Y9JDC7mBOMjo4/aPhMDcxEptsX+Tt3GgRQRPuE=
|
||||
golang.org/x/image v0.38.0/go.mod h1:/3f6vaXC+6CEanU4KJxbcUZyEePbyKbaLoDOe4ehFYY=
|
||||
golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8=
|
||||
golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||
golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -353,11 +353,11 @@ golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuX
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
|
||||
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
|
||||
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
|
||||
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ=
|
||||
golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ=
|
||||
golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k=
|
||||
golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
|
||||
@@ -25,6 +25,7 @@ const (
|
||||
MsgDeleteFailed = "common.delete_failed"
|
||||
MsgAlreadyExists = "common.already_exists"
|
||||
MsgNameCannotBeEmpty = "common.name_cannot_be_empty"
|
||||
MsgBatchTooMany = "common.batch_too_many"
|
||||
)
|
||||
|
||||
// Token related messages
|
||||
|
||||
@@ -21,6 +21,7 @@ common.delete_success: "Deletion successful"
|
||||
common.delete_failed: "Deletion failed"
|
||||
common.already_exists: "Already exists"
|
||||
common.name_cannot_be_empty: "Name cannot be empty"
|
||||
common.batch_too_many: "Too many items in batch request, maximum is {{.Max}}"
|
||||
|
||||
# Token messages
|
||||
token.name_too_long: "Token name is too long"
|
||||
|
||||
@@ -22,6 +22,7 @@ common.delete_success: "删除成功"
|
||||
common.delete_failed: "删除失败"
|
||||
common.already_exists: "已存在"
|
||||
common.name_cannot_be_empty: "名称不能为空"
|
||||
common.batch_too_many: "批量请求数量过多,最多 {{.Max}} 条"
|
||||
|
||||
# Token messages
|
||||
token.name_too_long: "令牌名称过长"
|
||||
|
||||
@@ -22,6 +22,7 @@ common.delete_success: "刪除成功"
|
||||
common.delete_failed: "刪除失敗"
|
||||
common.already_exists: "已存在"
|
||||
common.name_cannot_be_empty: "名稱不能為空"
|
||||
common.batch_too_many: "批次請求數量過多,最多 {{.Max}} 條"
|
||||
|
||||
# Token messages
|
||||
token.name_too_long: "令牌名稱過長"
|
||||
|
||||
@@ -101,8 +101,13 @@ func Distribute() func(c *gin.Context) {
|
||||
|
||||
if preferredChannelID, found := service.GetPreferredChannelByAffinity(c, modelRequest.Model, usingGroup); found {
|
||||
preferred, err := model.CacheGetChannel(preferredChannelID)
|
||||
if err == nil && preferred != nil && preferred.Status == common.ChannelStatusEnabled {
|
||||
if usingGroup == "auto" {
|
||||
if err == nil && preferred != nil {
|
||||
if preferred.Status != common.ChannelStatusEnabled {
|
||||
if service.ShouldSkipRetryAfterChannelAffinityFailure(c) {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, i18n.T(c, i18n.MsgDistributorChannelDisabled))
|
||||
return
|
||||
}
|
||||
} else if usingGroup == "auto" {
|
||||
userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
|
||||
autoGroups := service.GetUserAutoGroup(userGroup)
|
||||
for _, g := range autoGroups {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
@@ -48,17 +48,23 @@ func checkSystemPerformance() *types.NewAPIError {
|
||||
|
||||
// 检查 CPU
|
||||
if config.CPUThreshold > 0 && int(status.CPUUsage) > config.CPUThreshold {
|
||||
return types.NewErrorWithStatusCode(errors.New("system cpu overloaded"), "system_cpu_overloaded", http.StatusServiceUnavailable)
|
||||
return types.NewErrorWithStatusCode(
|
||||
fmt.Errorf("system cpu overloaded (current: %.1f%%, threshold: %d%%)", status.CPUUsage, config.CPUThreshold),
|
||||
"system_cpu_overloaded", http.StatusServiceUnavailable)
|
||||
}
|
||||
|
||||
// 检查内存
|
||||
if config.MemoryThreshold > 0 && int(status.MemoryUsage) > config.MemoryThreshold {
|
||||
return types.NewErrorWithStatusCode(errors.New("system memory overloaded"), "system_memory_overloaded", http.StatusServiceUnavailable)
|
||||
return types.NewErrorWithStatusCode(
|
||||
fmt.Errorf("system memory overloaded (current: %.1f%%, threshold: %d%%)", status.MemoryUsage, config.MemoryThreshold),
|
||||
"system_memory_overloaded", http.StatusServiceUnavailable)
|
||||
}
|
||||
|
||||
// 检查磁盘
|
||||
if config.DiskThreshold > 0 && int(status.DiskUsage) > config.DiskThreshold {
|
||||
return types.NewErrorWithStatusCode(errors.New("system disk overloaded"), "system_disk_overloaded", http.StatusServiceUnavailable)
|
||||
return types.NewErrorWithStatusCode(
|
||||
fmt.Errorf("system disk overloaded (current: %.1f%%, threshold: %d%%)", status.DiskUsage, config.DiskThreshold),
|
||||
"system_disk_overloaded", http.StatusServiceUnavailable)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -2,14 +2,25 @@ package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"runtime/debug"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var _bp = func() string {
|
||||
if bi, ok := debug.ReadBuildInfo(); ok && bi.Main.Path != "" {
|
||||
h := sha256.Sum256([]byte(bi.Main.Path))
|
||||
return hex.EncodeToString(h[:4])
|
||||
}
|
||||
return common.GetRandomString(8)
|
||||
}()
|
||||
|
||||
func RequestId() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
id := common.GetTimeString() + common.GetRandomString(8)
|
||||
id := common.GetTimeString() + _bp + common.GetRandomString(8)
|
||||
c.Set(common.RequestIdKey, id)
|
||||
ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
|
||||
+2
-1
@@ -58,7 +58,8 @@ func formatUserLogs(logs []*Log, startIdx int) {
|
||||
if otherMap != nil {
|
||||
// Remove admin-only debug fields.
|
||||
delete(otherMap, "admin_info")
|
||||
delete(otherMap, "reject_reason")
|
||||
// delete(otherMap, "reject_reason")
|
||||
delete(otherMap, "stream_status")
|
||||
}
|
||||
logs[i].Other = common.MapToJsonStr(otherMap)
|
||||
logs[i].Id = startIdx + i + 1
|
||||
|
||||
@@ -481,3 +481,11 @@ func BatchDeleteTokens(ids []int, userId int) (int, error) {
|
||||
|
||||
return len(tokens), nil
|
||||
}
|
||||
|
||||
func GetTokenKeysByIds(ids []int, userId int) ([]Token, error) {
|
||||
var tokens []Token
|
||||
err := DB.Select("id", commonKeyCol).
|
||||
Where("user_id = ? AND id IN (?)", userId, ids).
|
||||
Find(&tokens).Error
|
||||
return tokens, err
|
||||
}
|
||||
|
||||
@@ -70,7 +70,7 @@ func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
||||
if usage.(*dto.Usage).CompletionTokenDetails.AudioTokens > 0 || usage.(*dto.Usage).PromptTokensDetails.AudioTokens > 0 {
|
||||
service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||||
} else {
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage))
|
||||
service.PostTextConsumeQuota(c, info, usage.(*dto.Usage), nil)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -171,12 +171,17 @@ type AliImageRequest struct {
|
||||
}
|
||||
|
||||
type AliImageParameters struct {
|
||||
Size string `json:"size,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Steps string `json:"steps,omitempty"`
|
||||
Scale string `json:"scale,omitempty"`
|
||||
Watermark *bool `json:"watermark,omitempty"`
|
||||
PromptExtend *bool `json:"prompt_extend,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Steps string `json:"steps,omitempty"`
|
||||
Scale string `json:"scale,omitempty"`
|
||||
Watermark *bool `json:"watermark,omitempty"`
|
||||
PromptExtend *bool `json:"prompt_extend,omitempty"`
|
||||
ThinkingMode *bool `json:"thinking_mode,omitempty"`
|
||||
EnableSequential *bool `json:"enable_sequential,omitempty"`
|
||||
BboxList any `json:"bbox_list,omitempty"`
|
||||
ColorPalette any `json:"color_palette,omitempty"`
|
||||
Seed *int `json:"seed,omitempty"`
|
||||
}
|
||||
|
||||
func (p *AliImageParameters) PromptExtendValue() bool {
|
||||
|
||||
@@ -54,7 +54,6 @@ func oaiImage2AliImageRequest(info *relaycommon.RelayInfo, request dto.ImageRequ
|
||||
}
|
||||
}
|
||||
|
||||
// 检查n参数
|
||||
if imageRequest.Parameters.N != 0 {
|
||||
info.PriceData.AddOtherRatio("n", float64(imageRequest.Parameters.N))
|
||||
}
|
||||
@@ -181,6 +180,7 @@ func oaiFormEdit2AliImageEdit(c *gin.Context, info *relaycommon.RelayInfo, reque
|
||||
},
|
||||
}
|
||||
imageRequest.Parameters = AliImageParameters{
|
||||
N: int(lo.FromPtrOr(request.N, uint(1))),
|
||||
Watermark: request.Watermark,
|
||||
}
|
||||
return &imageRequest, nil
|
||||
@@ -328,7 +328,6 @@ func aliImageHandler(a *Adaptor, c *gin.Context, resp *http.Response, info *rela
|
||||
}
|
||||
|
||||
imageResponses := responseAli2OpenAIImage(c, aliResponse, originRespBody, info, responseFormat)
|
||||
// 可能生成多张图片,修正计费数量n
|
||||
if aliResponse.Usage.ImageCount != 0 {
|
||||
info.PriceData.AddOtherRatio("n", float64(aliResponse.Usage.ImageCount))
|
||||
} else if len(imageResponses.Data) != 0 {
|
||||
|
||||
@@ -40,7 +40,8 @@ func oaiFormEdit2WanxImageEdit(c *gin.Context, info *relaycommon.RelayInfo, requ
|
||||
}
|
||||
|
||||
func isOldWanModel(modelName string) bool {
|
||||
return strings.Contains(modelName, "wan") && !strings.Contains(modelName, "wan2.6")
|
||||
return strings.Contains(modelName, "wan") &&
|
||||
!lo.SomeBy([]string{"wan2.6", "wan2.7"}, func(v string) bool { return strings.Contains(modelName, v) })
|
||||
}
|
||||
|
||||
func isWanModel(modelName string) bool {
|
||||
|
||||
@@ -116,12 +116,12 @@ func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *dto.OpenAI
|
||||
|
||||
func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||
usage := &dto.Usage{}
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) {
|
||||
var baiduResponse BaiduChatStreamResponse
|
||||
err := common.Unmarshal([]byte(data), &baiduResponse)
|
||||
if err != nil {
|
||||
if err := common.Unmarshal([]byte(data), &baiduResponse); err != nil {
|
||||
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
sr.Error(err)
|
||||
return
|
||||
}
|
||||
if baiduResponse.Usage.TotalTokens != 0 {
|
||||
usage.TotalTokens = baiduResponse.Usage.TotalTokens
|
||||
@@ -129,11 +129,10 @@ func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
|
||||
usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
|
||||
}
|
||||
response := streamResponseBaidu2OpenAI(&baiduResponse)
|
||||
err = helper.ObjectData(c, response)
|
||||
if err != nil {
|
||||
if err := helper.ObjectData(c, response); err != nil {
|
||||
common.SysLog("error sending stream response: " + err.Error())
|
||||
sr.Error(err)
|
||||
}
|
||||
return true
|
||||
})
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
return nil, usage
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
@@ -41,11 +42,32 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
baseURL := fmt.Sprintf("%s/v1/messages", info.ChannelBaseUrl)
|
||||
if info.IsClaudeBetaQuery {
|
||||
baseURL = baseURL + "?beta=true"
|
||||
requestURL := fmt.Sprintf("%s/v1/messages", info.ChannelBaseUrl)
|
||||
if !shouldAppendClaudeBetaQuery(info) {
|
||||
return requestURL, nil
|
||||
}
|
||||
return baseURL, nil
|
||||
|
||||
parsedURL, err := url.Parse(requestURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
query := parsedURL.Query()
|
||||
query.Set("beta", "true")
|
||||
parsedURL.RawQuery = query.Encode()
|
||||
return parsedURL.String(), nil
|
||||
}
|
||||
|
||||
func shouldAppendClaudeBetaQuery(info *relaycommon.RelayInfo) bool {
|
||||
if info == nil {
|
||||
return false
|
||||
}
|
||||
if info.IsClaudeBetaQuery {
|
||||
return true
|
||||
}
|
||||
if info.ChannelOtherSettings.ClaudeBetaQuery {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func CommonClaudeHeadersOperation(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) {
|
||||
|
||||
@@ -85,7 +85,7 @@ func TestBuildMessageDeltaPatchUsage(t *testing.T) {
|
||||
require.EqualValues(t, 50, usage.CacheCreationInputTokens)
|
||||
require.EqualValues(t, 53, usage.OutputTokens)
|
||||
require.NotNil(t, usage.CacheCreation)
|
||||
require.EqualValues(t, 10, usage.CacheCreation.Ephemeral5mInputTokens)
|
||||
require.EqualValues(t, 30, usage.CacheCreation.Ephemeral5mInputTokens)
|
||||
require.EqualValues(t, 20, usage.CacheCreation.Ephemeral1hInputTokens)
|
||||
})
|
||||
|
||||
@@ -108,4 +108,22 @@ func TestBuildMessageDeltaPatchUsage(t *testing.T) {
|
||||
require.EqualValues(t, 7, usage.CacheReadInputTokens)
|
||||
require.EqualValues(t, 6, usage.CacheCreationInputTokens)
|
||||
})
|
||||
|
||||
t.Run("default aggregate cache creation to 5m when split missing", func(t *testing.T) {
|
||||
claudeResponse := &dto.ClaudeResponse{Usage: &dto.ClaudeUsage{
|
||||
OutputTokens: 53,
|
||||
CacheCreationInputTokens: 50,
|
||||
}}
|
||||
claudeInfo := &ClaudeResponseInfo{Usage: &dto.Usage{
|
||||
PromptTokensDetails: dto.InputTokenDetails{
|
||||
CachedCreationTokens: 50,
|
||||
},
|
||||
}}
|
||||
|
||||
usage := buildMessageDeltaPatchUsage(claudeResponse, claudeInfo)
|
||||
require.NotNil(t, usage)
|
||||
require.NotNil(t, usage.CacheCreation)
|
||||
require.EqualValues(t, 50, usage.CacheCreation.Ephemeral5mInputTokens)
|
||||
require.EqualValues(t, 0, usage.CacheCreation.Ephemeral1hInputTokens)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -85,7 +85,7 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
|
||||
|
||||
// 解析 UserLocation JSON
|
||||
var userLocationMap map[string]interface{}
|
||||
if err := json.Unmarshal(textRequest.WebSearchOptions.UserLocation, &userLocationMap); err == nil {
|
||||
if err := common.Unmarshal(textRequest.WebSearchOptions.UserLocation, &userLocationMap); err == nil {
|
||||
// 检查是否有 approximate 字段
|
||||
if approximateData, ok := userLocationMap["approximate"].(map[string]interface{}); ok {
|
||||
if timezone, ok := approximateData["timezone"].(string); ok && timezone != "" {
|
||||
@@ -177,7 +177,7 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
|
||||
}
|
||||
// TODO: 临时处理
|
||||
// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
|
||||
claudeRequest.TopP = common.GetPointer[float64](0)
|
||||
claudeRequest.TopP = nil
|
||||
claudeRequest.Temperature = common.GetPointer[float64](1.0)
|
||||
if !model_setting.ShouldPreserveThinkingSuffix(textRequest.Model) {
|
||||
claudeRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking")
|
||||
@@ -343,33 +343,39 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
|
||||
} else {
|
||||
claudeMediaMessages := make([]dto.ClaudeMediaMessage, 0)
|
||||
for _, mediaMessage := range message.ParseContent() {
|
||||
claudeMediaMessage := dto.ClaudeMediaMessage{
|
||||
Type: mediaMessage.Type,
|
||||
}
|
||||
if mediaMessage.Type == "text" {
|
||||
claudeMediaMessage.Text = common.GetPointer[string](mediaMessage.Text)
|
||||
} else {
|
||||
imageUrl := mediaMessage.GetImageMedia()
|
||||
claudeMediaMessage.Type = "image"
|
||||
claudeMediaMessage.Source = &dto.ClaudeMessageSource{
|
||||
Type: "base64",
|
||||
}
|
||||
// 使用统一的文件服务获取图片数据
|
||||
var source *types.FileSource
|
||||
if strings.HasPrefix(imageUrl.Url, "http") {
|
||||
source = types.NewURLFileSource(imageUrl.Url)
|
||||
} else {
|
||||
source = types.NewBase64FileSource(imageUrl.Url, "")
|
||||
switch mediaMessage.Type {
|
||||
case "text":
|
||||
claudeMediaMessages = append(claudeMediaMessages, dto.ClaudeMediaMessage{
|
||||
Type: "text",
|
||||
Text: common.GetPointer[string](mediaMessage.Text),
|
||||
})
|
||||
default:
|
||||
source := mediaMessage.ToFileSource()
|
||||
if source == nil {
|
||||
continue
|
||||
}
|
||||
base64Data, mimeType, err := service.GetBase64Data(c, source, "formatting image for Claude")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get file data failed: %s", err.Error())
|
||||
}
|
||||
claudeMediaMessage := dto.ClaudeMediaMessage{
|
||||
Source: &dto.ClaudeMessageSource{
|
||||
Type: "base64",
|
||||
},
|
||||
}
|
||||
if strings.HasPrefix(mimeType, "application/pdf") {
|
||||
claudeMediaMessage.Type = "document"
|
||||
} else {
|
||||
claudeMediaMessage.Type = "image"
|
||||
}
|
||||
|
||||
claudeMediaMessage.Source.MediaType = mimeType
|
||||
claudeMediaMessage.Source.Data = base64Data
|
||||
claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
|
||||
continue
|
||||
}
|
||||
claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
|
||||
}
|
||||
|
||||
if message.ToolCalls != nil {
|
||||
for _, toolCall := range message.ParseToolCalls() {
|
||||
inputObj := make(map[string]any)
|
||||
@@ -555,6 +561,40 @@ type ClaudeResponseInfo struct {
|
||||
Done bool
|
||||
}
|
||||
|
||||
func cacheCreationTokensForOpenAIUsage(usage *dto.Usage) int {
|
||||
if usage == nil {
|
||||
return 0
|
||||
}
|
||||
splitCacheCreationTokens := usage.ClaudeCacheCreation5mTokens + usage.ClaudeCacheCreation1hTokens
|
||||
if splitCacheCreationTokens == 0 {
|
||||
return usage.PromptTokensDetails.CachedCreationTokens
|
||||
}
|
||||
if usage.PromptTokensDetails.CachedCreationTokens > splitCacheCreationTokens {
|
||||
return usage.PromptTokensDetails.CachedCreationTokens
|
||||
}
|
||||
return splitCacheCreationTokens
|
||||
}
|
||||
|
||||
func buildOpenAIStyleUsageFromClaudeUsage(usage *dto.Usage) dto.Usage {
|
||||
if usage == nil {
|
||||
return dto.Usage{}
|
||||
}
|
||||
clone := *usage
|
||||
clone.ClaudeCacheCreation5mTokens, clone.ClaudeCacheCreation1hTokens = service.NormalizeCacheCreationSplit(
|
||||
usage.PromptTokensDetails.CachedCreationTokens,
|
||||
usage.ClaudeCacheCreation5mTokens,
|
||||
usage.ClaudeCacheCreation1hTokens,
|
||||
)
|
||||
cacheCreationTokens := cacheCreationTokensForOpenAIUsage(usage)
|
||||
totalInputTokens := usage.PromptTokens + usage.PromptTokensDetails.CachedTokens + cacheCreationTokens
|
||||
clone.PromptTokens = totalInputTokens
|
||||
clone.InputTokens = totalInputTokens
|
||||
clone.TotalTokens = totalInputTokens + usage.CompletionTokens
|
||||
clone.UsageSemantic = "openai"
|
||||
clone.UsageSource = "anthropic"
|
||||
return clone
|
||||
}
|
||||
|
||||
func buildMessageDeltaPatchUsage(claudeResponse *dto.ClaudeResponse, claudeInfo *ClaudeResponseInfo) *dto.ClaudeUsage {
|
||||
usage := &dto.ClaudeUsage{}
|
||||
if claudeResponse != nil && claudeResponse.Usage != nil {
|
||||
@@ -574,11 +614,26 @@ func buildMessageDeltaPatchUsage(claudeResponse *dto.ClaudeResponse, claudeInfo
|
||||
if usage.CacheCreationInputTokens == 0 && claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens > 0 {
|
||||
usage.CacheCreationInputTokens = claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens
|
||||
}
|
||||
if usage.CacheCreation == nil && (claudeInfo.Usage.ClaudeCacheCreation5mTokens > 0 || claudeInfo.Usage.ClaudeCacheCreation1hTokens > 0) {
|
||||
usage.CacheCreation = &dto.ClaudeCacheCreationUsage{
|
||||
Ephemeral5mInputTokens: claudeInfo.Usage.ClaudeCacheCreation5mTokens,
|
||||
Ephemeral1hInputTokens: claudeInfo.Usage.ClaudeCacheCreation1hTokens,
|
||||
}
|
||||
cacheCreation5m := 0
|
||||
cacheCreation1h := 0
|
||||
if usage.CacheCreation != nil {
|
||||
cacheCreation5m = usage.CacheCreation.Ephemeral5mInputTokens
|
||||
cacheCreation1h = usage.CacheCreation.Ephemeral1hInputTokens
|
||||
} else {
|
||||
cacheCreation5m = claudeInfo.Usage.ClaudeCacheCreation5mTokens
|
||||
cacheCreation1h = claudeInfo.Usage.ClaudeCacheCreation1hTokens
|
||||
}
|
||||
cacheCreation5m, cacheCreation1h = service.NormalizeCacheCreationSplit(
|
||||
usage.CacheCreationInputTokens,
|
||||
cacheCreation5m,
|
||||
cacheCreation1h,
|
||||
)
|
||||
if usage.CacheCreation == nil && (cacheCreation5m > 0 || cacheCreation1h > 0) {
|
||||
usage.CacheCreation = &dto.ClaudeCacheCreationUsage{}
|
||||
}
|
||||
if usage.CacheCreation != nil {
|
||||
usage.CacheCreation.Ephemeral5mInputTokens = cacheCreation5m
|
||||
usage.CacheCreation.Ephemeral1hInputTokens = cacheCreation1h
|
||||
}
|
||||
return usage
|
||||
}
|
||||
@@ -643,6 +698,7 @@ func FormatClaudeResponseInfo(claudeResponse *dto.ClaudeResponse, oaiResponse *d
|
||||
// message_start, 获取usage
|
||||
if claudeResponse.Message != nil && claudeResponse.Message.Usage != nil {
|
||||
claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
|
||||
claudeInfo.Usage.UsageSemantic = "anthropic"
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
|
||||
claudeInfo.Usage.ClaudeCacheCreation5mTokens = claudeResponse.Message.Usage.GetCacheCreation5mTokens()
|
||||
@@ -661,6 +717,7 @@ func FormatClaudeResponseInfo(claudeResponse *dto.ClaudeResponse, oaiResponse *d
|
||||
} else if claudeResponse.Type == "message_delta" {
|
||||
// 最终的usage获取
|
||||
if claudeResponse.Usage != nil {
|
||||
claudeInfo.Usage.UsageSemantic = "anthropic"
|
||||
if claudeResponse.Usage.InputTokens > 0 {
|
||||
// 不叠加,只取最新的
|
||||
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
|
||||
@@ -754,12 +811,16 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau
|
||||
}
|
||||
claudeInfo.Usage = service.ResponseText2Usage(c, claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
||||
}
|
||||
if claudeInfo.Usage != nil {
|
||||
claudeInfo.Usage.UsageSemantic = "anthropic"
|
||||
}
|
||||
|
||||
if info.RelayFormat == types.RelayFormatClaude {
|
||||
//
|
||||
} else if info.RelayFormat == types.RelayFormatOpenAI {
|
||||
if info.ShouldIncludeUsage {
|
||||
response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
|
||||
openAIUsage := buildOpenAIStyleUsageFromClaudeUsage(claudeInfo.Usage)
|
||||
response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, openAIUsage)
|
||||
err := helper.ObjectData(c, response)
|
||||
if err != nil {
|
||||
common.SysLog("send final response failed: " + err.Error())
|
||||
@@ -778,12 +839,11 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
||||
Usage: &dto.Usage{},
|
||||
}
|
||||
var err *types.NewAPIError
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) {
|
||||
err = HandleStreamResponseData(c, info, claudeInfo, data)
|
||||
if err != nil {
|
||||
return false
|
||||
sr.Stop(err)
|
||||
}
|
||||
return true
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -810,6 +870,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
|
||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
|
||||
claudeInfo.Usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
|
||||
claudeInfo.Usage.UsageSemantic = "anthropic"
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Usage.CacheReadInputTokens
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Usage.CacheCreationInputTokens
|
||||
claudeInfo.Usage.ClaudeCacheCreation5mTokens = claudeResponse.Usage.GetCacheCreation5mTokens()
|
||||
@@ -819,7 +880,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
switch info.RelayFormat {
|
||||
case types.RelayFormatOpenAI:
|
||||
openaiResponse := ResponseClaude2OpenAI(&claudeResponse)
|
||||
openaiResponse.Usage = *claudeInfo.Usage
|
||||
openaiResponse.Usage = buildOpenAIStyleUsageFromClaudeUsage(claudeInfo.Usage)
|
||||
responseData, err = json.Marshal(openaiResponse)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFormatClaudeResponseInfo_MessageStart(t *testing.T) {
|
||||
@@ -173,3 +175,208 @@ func TestFormatClaudeResponseInfo_ContentBlockDelta(t *testing.T) {
|
||||
t.Errorf("ResponseText = %q, want %q", claudeInfo.ResponseText.String(), "hello")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpenAIStyleUsageFromClaudeUsage(t *testing.T) {
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: 100,
|
||||
CompletionTokens: 20,
|
||||
PromptTokensDetails: dto.InputTokenDetails{
|
||||
CachedTokens: 30,
|
||||
CachedCreationTokens: 50,
|
||||
},
|
||||
ClaudeCacheCreation5mTokens: 10,
|
||||
ClaudeCacheCreation1hTokens: 20,
|
||||
UsageSemantic: "anthropic",
|
||||
}
|
||||
|
||||
openAIUsage := buildOpenAIStyleUsageFromClaudeUsage(usage)
|
||||
|
||||
if openAIUsage.PromptTokens != 180 {
|
||||
t.Fatalf("PromptTokens = %d, want 180", openAIUsage.PromptTokens)
|
||||
}
|
||||
if openAIUsage.InputTokens != 180 {
|
||||
t.Fatalf("InputTokens = %d, want 180", openAIUsage.InputTokens)
|
||||
}
|
||||
if openAIUsage.TotalTokens != 200 {
|
||||
t.Fatalf("TotalTokens = %d, want 200", openAIUsage.TotalTokens)
|
||||
}
|
||||
if openAIUsage.UsageSemantic != "openai" {
|
||||
t.Fatalf("UsageSemantic = %s, want openai", openAIUsage.UsageSemantic)
|
||||
}
|
||||
if openAIUsage.UsageSource != "anthropic" {
|
||||
t.Fatalf("UsageSource = %s, want anthropic", openAIUsage.UsageSource)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpenAIStyleUsageFromClaudeUsagePreservesCacheCreationRemainder(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cachedCreationTokens int
|
||||
cacheCreationTokens5m int
|
||||
cacheCreationTokens1h int
|
||||
expectedTotalInputToken int
|
||||
}{
|
||||
{
|
||||
name: "prefers aggregate when it includes remainder",
|
||||
cachedCreationTokens: 50,
|
||||
cacheCreationTokens5m: 10,
|
||||
cacheCreationTokens1h: 20,
|
||||
expectedTotalInputToken: 180,
|
||||
},
|
||||
{
|
||||
name: "falls back to split tokens when aggregate missing",
|
||||
cachedCreationTokens: 0,
|
||||
cacheCreationTokens5m: 10,
|
||||
cacheCreationTokens1h: 20,
|
||||
expectedTotalInputToken: 160,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: 100,
|
||||
CompletionTokens: 20,
|
||||
PromptTokensDetails: dto.InputTokenDetails{
|
||||
CachedTokens: 30,
|
||||
CachedCreationTokens: tt.cachedCreationTokens,
|
||||
},
|
||||
ClaudeCacheCreation5mTokens: tt.cacheCreationTokens5m,
|
||||
ClaudeCacheCreation1hTokens: tt.cacheCreationTokens1h,
|
||||
UsageSemantic: "anthropic",
|
||||
}
|
||||
|
||||
openAIUsage := buildOpenAIStyleUsageFromClaudeUsage(usage)
|
||||
|
||||
if openAIUsage.PromptTokens != tt.expectedTotalInputToken {
|
||||
t.Fatalf("PromptTokens = %d, want %d", openAIUsage.PromptTokens, tt.expectedTotalInputToken)
|
||||
}
|
||||
if openAIUsage.InputTokens != tt.expectedTotalInputToken {
|
||||
t.Fatalf("InputTokens = %d, want %d", openAIUsage.InputTokens, tt.expectedTotalInputToken)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpenAIStyleUsageFromClaudeUsageDefaultsAggregateCacheCreationTo5m(t *testing.T) {
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: 100,
|
||||
CompletionTokens: 20,
|
||||
PromptTokensDetails: dto.InputTokenDetails{
|
||||
CachedTokens: 30,
|
||||
CachedCreationTokens: 50,
|
||||
},
|
||||
UsageSemantic: "anthropic",
|
||||
}
|
||||
|
||||
openAIUsage := buildOpenAIStyleUsageFromClaudeUsage(usage)
|
||||
|
||||
require.Equal(t, 50, openAIUsage.ClaudeCacheCreation5mTokens)
|
||||
require.Equal(t, 0, openAIUsage.ClaudeCacheCreation1hTokens)
|
||||
}
|
||||
|
||||
func TestRequestOpenAI2ClaudeMessage_IgnoresUnsupportedFileContent(t *testing.T) {
|
||||
request := dto.GeneralOpenAIRequest{
|
||||
Model: "claude-3-5-sonnet",
|
||||
Messages: []dto.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
dto.MediaContent{
|
||||
Type: dto.ContentTypeText,
|
||||
Text: "see attachment",
|
||||
},
|
||||
dto.MediaContent{
|
||||
Type: dto.ContentTypeFile,
|
||||
File: &dto.MessageFile{
|
||||
FileName: "blob.bin",
|
||||
FileData: "JVBERi0xLjQK",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
claudeRequest, err := RequestOpenAI2ClaudeMessage(nil, request)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, claudeRequest.Messages, 1)
|
||||
|
||||
content, ok := claudeRequest.Messages[0].Content.([]dto.ClaudeMediaMessage)
|
||||
require.True(t, ok)
|
||||
require.Len(t, content, 1)
|
||||
require.Equal(t, "text", content[0].Type)
|
||||
require.NotNil(t, content[0].Text)
|
||||
require.Equal(t, "see attachment", *content[0].Text)
|
||||
}
|
||||
|
||||
func TestRequestOpenAI2ClaudeMessage_SupportsPDFFileContent(t *testing.T) {
|
||||
request := dto.GeneralOpenAIRequest{
|
||||
Model: "claude-3-5-sonnet",
|
||||
Messages: []dto.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
dto.MediaContent{
|
||||
Type: dto.ContentTypeFile,
|
||||
File: &dto.MessageFile{
|
||||
FileName: "spec.pdf",
|
||||
FileData: "JVBERi0xLjQK",
|
||||
},
|
||||
},
|
||||
dto.MediaContent{
|
||||
Type: dto.ContentTypeText,
|
||||
Text: "summarize it",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
claudeRequest, err := RequestOpenAI2ClaudeMessage(nil, request)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, claudeRequest.Messages, 1)
|
||||
|
||||
content, ok := claudeRequest.Messages[0].Content.([]dto.ClaudeMediaMessage)
|
||||
require.True(t, ok)
|
||||
require.Len(t, content, 2)
|
||||
require.Equal(t, "document", content[0].Type)
|
||||
require.NotNil(t, content[0].Source)
|
||||
require.Equal(t, "base64", content[0].Source.Type)
|
||||
require.Equal(t, "application/pdf", content[0].Source.MediaType)
|
||||
require.Equal(t, "JVBERi0xLjQK", content[0].Source.Data)
|
||||
require.Equal(t, "text", content[1].Type)
|
||||
require.NotNil(t, content[1].Text)
|
||||
require.Equal(t, "summarize it", *content[1].Text)
|
||||
}
|
||||
|
||||
func TestRequestOpenAI2ClaudeMessage_ConvertsTextFileContentToText(t *testing.T) {
|
||||
request := dto.GeneralOpenAIRequest{
|
||||
Model: "claude-3-5-sonnet",
|
||||
Messages: []dto.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
dto.MediaContent{
|
||||
Type: dto.ContentTypeFile,
|
||||
File: &dto.MessageFile{
|
||||
FileName: "notes.txt",
|
||||
FileData: base64.StdEncoding.EncodeToString([]byte("alpha\nbeta")),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
claudeRequest, err := RequestOpenAI2ClaudeMessage(nil, request)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, claudeRequest.Messages, 1)
|
||||
|
||||
content, ok := claudeRequest.Messages[0].Content.([]dto.ClaudeMediaMessage)
|
||||
require.True(t, ok)
|
||||
require.Len(t, content, 1)
|
||||
require.Equal(t, "text", content[0].Type)
|
||||
require.NotNil(t, content[0].Text)
|
||||
require.Equal(t, "alpha\nbeta", *content[0].Text)
|
||||
}
|
||||
|
||||
@@ -223,33 +223,32 @@ func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
|
||||
usage := &dto.Usage{}
|
||||
var nodeToken int
|
||||
helper.SetEventStreamHeaders(c)
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) {
|
||||
var difyResponse DifyChunkChatCompletionResponse
|
||||
err := json.Unmarshal([]byte(data), &difyResponse)
|
||||
if err != nil {
|
||||
if err := json.Unmarshal([]byte(data), &difyResponse); err != nil {
|
||||
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
sr.Error(err)
|
||||
return
|
||||
}
|
||||
var openaiResponse dto.ChatCompletionsStreamResponse
|
||||
if difyResponse.Event == "message_end" {
|
||||
usage = &difyResponse.MetaData.Usage
|
||||
return false
|
||||
sr.Done()
|
||||
return
|
||||
} else if difyResponse.Event == "error" {
|
||||
return false
|
||||
} else {
|
||||
openaiResponse = *streamResponseDify2OpenAI(difyResponse)
|
||||
if len(openaiResponse.Choices) != 0 {
|
||||
responseText += openaiResponse.Choices[0].Delta.GetContentString()
|
||||
if openaiResponse.Choices[0].Delta.ReasoningContent != nil {
|
||||
nodeToken += 1
|
||||
}
|
||||
sr.Stop(fmt.Errorf("dify error event"))
|
||||
return
|
||||
}
|
||||
openaiResponse := *streamResponseDify2OpenAI(difyResponse)
|
||||
if len(openaiResponse.Choices) != 0 {
|
||||
responseText += openaiResponse.Choices[0].Delta.GetContentString()
|
||||
if openaiResponse.Choices[0].Delta.ReasoningContent != nil {
|
||||
nodeToken += 1
|
||||
}
|
||||
}
|
||||
err = helper.ObjectData(c, openaiResponse)
|
||||
if err != nil {
|
||||
if err := helper.ObjectData(c, openaiResponse); err != nil {
|
||||
common.SysLog(err.Error())
|
||||
sr.Error(err)
|
||||
}
|
||||
return true
|
||||
})
|
||||
helper.Done(c)
|
||||
if usage.TotalTokens == 0 {
|
||||
|
||||
@@ -37,6 +37,8 @@ var geminiSupportedMimeTypes = map[string]bool{
|
||||
"image/jpeg": true,
|
||||
"image/jpg": true, // support old image/jpeg
|
||||
"image/webp": true,
|
||||
"image/heic": true,
|
||||
"image/heif": true,
|
||||
"text/plain": true,
|
||||
"video/mov": true,
|
||||
"video/mpeg": true,
|
||||
@@ -583,14 +585,10 @@ func CovertOpenAI2Gemini(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
|
||||
Text: part.Text,
|
||||
})
|
||||
}
|
||||
} else if part.Type == dto.ContentTypeImageURL {
|
||||
// 使用统一的文件服务获取图片数据
|
||||
var source *types.FileSource
|
||||
imageUrl := part.GetImageMedia().Url
|
||||
if strings.HasPrefix(imageUrl, "http") {
|
||||
source = types.NewURLFileSource(imageUrl)
|
||||
} else {
|
||||
source = types.NewBase64FileSource(imageUrl, "")
|
||||
} else {
|
||||
source := part.ToFileSource()
|
||||
if source == nil {
|
||||
continue
|
||||
}
|
||||
base64Data, mimeType, err := service.GetBase64Data(c, source, "formatting image for Gemini")
|
||||
if err != nil {
|
||||
@@ -602,36 +600,6 @@ func CovertOpenAI2Gemini(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
|
||||
return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", mimeType, source.GetIdentifier(), getSupportedMimeTypesList())
|
||||
}
|
||||
|
||||
parts = append(parts, dto.GeminiPart{
|
||||
InlineData: &dto.GeminiInlineData{
|
||||
MimeType: mimeType,
|
||||
Data: base64Data,
|
||||
},
|
||||
})
|
||||
} else if part.Type == dto.ContentTypeFile {
|
||||
if part.GetFile().FileId != "" {
|
||||
return nil, fmt.Errorf("only base64 file is supported in gemini")
|
||||
}
|
||||
fileSource := types.NewBase64FileSource(part.GetFile().FileData, "")
|
||||
base64Data, mimeType, err := service.GetBase64Data(c, fileSource, "formatting file for Gemini")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode base64 file data failed: %s", err.Error())
|
||||
}
|
||||
parts = append(parts, dto.GeminiPart{
|
||||
InlineData: &dto.GeminiInlineData{
|
||||
MimeType: mimeType,
|
||||
Data: base64Data,
|
||||
},
|
||||
})
|
||||
} else if part.Type == dto.ContentTypeInputAudio {
|
||||
if part.GetInputAudio().Data == "" {
|
||||
return nil, fmt.Errorf("only base64 audio is supported in gemini")
|
||||
}
|
||||
audioSource := types.NewBase64FileSource(part.GetInputAudio().Data, "audio/"+part.GetInputAudio().Format)
|
||||
base64Data, mimeType, err := service.GetBase64Data(c, audioSource, "formatting audio for Gemini")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode base64 audio data failed: %s", err.Error())
|
||||
}
|
||||
parts = append(parts, dto.GeminiPart{
|
||||
InlineData: &dto.GeminiInlineData{
|
||||
MimeType: mimeType,
|
||||
@@ -1297,12 +1265,11 @@ func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
||||
var imageCount int
|
||||
responseText := strings.Builder{}
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) {
|
||||
var geminiResponse dto.GeminiChatResponse
|
||||
err := common.UnmarshalJsonStr(data, &geminiResponse)
|
||||
if err != nil {
|
||||
logger.LogError(c, "error unmarshalling stream response: "+err.Error())
|
||||
return false
|
||||
if err := common.UnmarshalJsonStr(data, &geminiResponse); err != nil {
|
||||
sr.Stop(fmt.Errorf("unmarshal: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
if len(geminiResponse.Candidates) == 0 && geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
|
||||
@@ -1327,7 +1294,9 @@ func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
||||
*usage = mappedUsage
|
||||
}
|
||||
|
||||
return callback(data, &geminiResponse)
|
||||
if !callback(data, &geminiResponse) {
|
||||
sr.Stop(fmt.Errorf("gemini callback stopped"))
|
||||
}
|
||||
})
|
||||
|
||||
if imageCount != 0 {
|
||||
|
||||
@@ -98,15 +98,8 @@ func openAIChatToOllamaChat(c *gin.Context, r *dto.GeneralOpenAIRequest) (*Ollam
|
||||
parts := m.ParseContent()
|
||||
for _, part := range parts {
|
||||
if part.Type == dto.ContentTypeImageURL {
|
||||
img := part.GetImageMedia()
|
||||
if img != nil && img.Url != "" {
|
||||
// 使用统一的文件服务获取图片数据
|
||||
var source *types.FileSource
|
||||
if strings.HasPrefix(img.Url, "http") {
|
||||
source = types.NewURLFileSource(img.Url)
|
||||
} else {
|
||||
source = types.NewBase64FileSource(img.Url, "")
|
||||
}
|
||||
source := part.ToFileSource()
|
||||
if source != nil {
|
||||
base64Data, _, err := service.GetBase64Data(c, source, "fetch image for ollama chat")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -35,21 +35,21 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
|
||||
if info.IsStream {
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) {
|
||||
if service.SundaySearch(data, "usage") {
|
||||
var simpleResponse dto.SimpleResponse
|
||||
err := common.Unmarshal([]byte(data), &simpleResponse)
|
||||
if err != nil {
|
||||
if err := common.Unmarshal([]byte(data), &simpleResponse); err != nil {
|
||||
logger.LogError(c, err.Error())
|
||||
}
|
||||
if simpleResponse.Usage.TotalTokens != 0 {
|
||||
sr.Error(err)
|
||||
} else if simpleResponse.Usage.TotalTokens != 0 {
|
||||
usage.PromptTokens = simpleResponse.Usage.InputTokens
|
||||
usage.CompletionTokens = simpleResponse.OutputTokens
|
||||
usage.TotalTokens = simpleResponse.TotalTokens
|
||||
}
|
||||
}
|
||||
_ = helper.StringData(c, data)
|
||||
return true
|
||||
if err := helper.StringData(c, data); err != nil {
|
||||
sr.Error(err)
|
||||
}
|
||||
})
|
||||
} else {
|
||||
common.SetContextKey(c, constant.ContextKeyLocalCountTokens, true)
|
||||
|
||||
@@ -296,15 +296,17 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
|
||||
return true
|
||||
}
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) {
|
||||
if streamErr != nil {
|
||||
return false
|
||||
sr.Stop(streamErr)
|
||||
return
|
||||
}
|
||||
|
||||
var streamResp dto.ResponsesStreamResponse
|
||||
if err := common.UnmarshalJsonStr(data, &streamResp); err != nil {
|
||||
logger.LogError(c, "failed to unmarshal responses stream event: "+err.Error())
|
||||
return true
|
||||
sr.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
switch streamResp.Type {
|
||||
@@ -320,14 +322,16 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
|
||||
|
||||
//case "response.reasoning_text.delta":
|
||||
//if !sendReasoningDelta(streamResp.Delta) {
|
||||
// return false
|
||||
// sr.Stop(streamErr)
|
||||
// return
|
||||
//}
|
||||
|
||||
//case "response.reasoning_text.done":
|
||||
|
||||
case "response.reasoning_summary_text.delta":
|
||||
if !sendReasoningSummaryDelta(streamResp.Delta) {
|
||||
return false
|
||||
sr.Stop(streamErr)
|
||||
return
|
||||
}
|
||||
|
||||
case "response.reasoning_summary_text.done":
|
||||
@@ -349,12 +353,14 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
|
||||
// delta := stringDeltaFromPrefix(prev, next)
|
||||
// reasoningSummaryTextByKey[key] = next
|
||||
// if !sendReasoningSummaryDelta(delta) {
|
||||
// return false
|
||||
// sr.Stop(streamErr)
|
||||
// return
|
||||
// }
|
||||
|
||||
case "response.output_text.delta":
|
||||
if !sendStartIfNeeded() {
|
||||
return false
|
||||
sr.Stop(streamErr)
|
||||
return
|
||||
}
|
||||
|
||||
if streamResp.Delta != "" {
|
||||
@@ -376,7 +382,8 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
|
||||
},
|
||||
}
|
||||
if !sendChatChunk(chunk) {
|
||||
return false
|
||||
sr.Stop(streamErr)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -414,7 +421,8 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
|
||||
}
|
||||
|
||||
if !sendToolCallDelta(callID, name, argsDelta) {
|
||||
return false
|
||||
sr.Stop(streamErr)
|
||||
return
|
||||
}
|
||||
|
||||
case "response.function_call_arguments.delta":
|
||||
@@ -428,7 +436,8 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
|
||||
}
|
||||
toolCallArgsByID[callID] += streamResp.Delta
|
||||
if !sendToolCallDelta(callID, "", streamResp.Delta) {
|
||||
return false
|
||||
sr.Stop(streamErr)
|
||||
return
|
||||
}
|
||||
|
||||
case "response.function_call_arguments.done":
|
||||
@@ -467,7 +476,8 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
|
||||
}
|
||||
|
||||
if !sendStartIfNeeded() {
|
||||
return false
|
||||
sr.Stop(streamErr)
|
||||
return
|
||||
}
|
||||
if !sentStop {
|
||||
if info.RelayFormat == types.RelayFormatClaude && info.ClaudeConvertInfo != nil {
|
||||
@@ -479,7 +489,8 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
|
||||
}
|
||||
stop := helper.GenerateStopResponse(responseId, createAt, model, finishReason)
|
||||
if !sendChatChunk(stop) {
|
||||
return false
|
||||
sr.Stop(streamErr)
|
||||
return
|
||||
}
|
||||
sentStop = true
|
||||
}
|
||||
@@ -488,16 +499,16 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
|
||||
if streamResp.Response != nil {
|
||||
if oaiErr := streamResp.Response.GetOpenAIError(); oaiErr != nil && oaiErr.Type != "" {
|
||||
streamErr = types.WithOpenAIError(*oaiErr, http.StatusInternalServerError)
|
||||
return false
|
||||
sr.Stop(streamErr)
|
||||
return
|
||||
}
|
||||
}
|
||||
streamErr = types.NewOpenAIError(fmt.Errorf("responses stream error: %s", streamResp.Type), types.ErrorCodeBadResponse, http.StatusInternalServerError)
|
||||
return false
|
||||
sr.Stop(streamErr)
|
||||
return
|
||||
|
||||
default:
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
if streamErr != nil {
|
||||
|
||||
@@ -126,11 +126,11 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
// 检查是否为音频模型
|
||||
isAudioModel := strings.Contains(strings.ToLower(model), "audio")
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) {
|
||||
if lastStreamData != "" {
|
||||
err := HandleStreamFormat(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
|
||||
if err != nil {
|
||||
if err := HandleStreamFormat(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent); err != nil {
|
||||
common.SysLog("error handling stream format: " + err.Error())
|
||||
sr.Error(err)
|
||||
}
|
||||
}
|
||||
if len(data) > 0 {
|
||||
@@ -142,7 +142,6 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
lastStreamData = data
|
||||
streamItems = append(streamItems, data)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
// 对音频模型,从倒数第二个stream data中提取usage信息
|
||||
|
||||
@@ -79,55 +79,55 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
|
||||
var usage = &dto.Usage{}
|
||||
var responseTextBuilder strings.Builder
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) {
|
||||
|
||||
// 检查当前数据是否包含 completed 状态和 usage 信息
|
||||
var streamResponse dto.ResponsesStreamResponse
|
||||
if err := common.UnmarshalJsonStr(data, &streamResponse); err == nil {
|
||||
sendResponsesStreamData(c, streamResponse, data)
|
||||
switch streamResponse.Type {
|
||||
case "response.completed":
|
||||
if streamResponse.Response != nil {
|
||||
if streamResponse.Response.Usage != nil {
|
||||
if streamResponse.Response.Usage.InputTokens != 0 {
|
||||
usage.PromptTokens = streamResponse.Response.Usage.InputTokens
|
||||
}
|
||||
if streamResponse.Response.Usage.OutputTokens != 0 {
|
||||
usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
|
||||
}
|
||||
if streamResponse.Response.Usage.TotalTokens != 0 {
|
||||
usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
|
||||
}
|
||||
if streamResponse.Response.Usage.InputTokensDetails != nil {
|
||||
usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
if err := common.UnmarshalJsonStr(data, &streamResponse); err != nil {
|
||||
logger.LogError(c, "failed to unmarshal stream response: "+err.Error())
|
||||
sr.Error(err)
|
||||
return
|
||||
}
|
||||
sendResponsesStreamData(c, streamResponse, data)
|
||||
switch streamResponse.Type {
|
||||
case "response.completed":
|
||||
if streamResponse.Response != nil {
|
||||
if streamResponse.Response.Usage != nil {
|
||||
if streamResponse.Response.Usage.InputTokens != 0 {
|
||||
usage.PromptTokens = streamResponse.Response.Usage.InputTokens
|
||||
}
|
||||
if streamResponse.Response.HasImageGenerationCall() {
|
||||
c.Set("image_generation_call", true)
|
||||
c.Set("image_generation_call_quality", streamResponse.Response.GetQuality())
|
||||
c.Set("image_generation_call_size", streamResponse.Response.GetSize())
|
||||
if streamResponse.Response.Usage.OutputTokens != 0 {
|
||||
usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
|
||||
}
|
||||
if streamResponse.Response.Usage.TotalTokens != 0 {
|
||||
usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
|
||||
}
|
||||
if streamResponse.Response.Usage.InputTokensDetails != nil {
|
||||
usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
}
|
||||
case "response.output_text.delta":
|
||||
// 处理输出文本
|
||||
responseTextBuilder.WriteString(streamResponse.Delta)
|
||||
case dto.ResponsesOutputTypeItemDone:
|
||||
// 函数调用处理
|
||||
if streamResponse.Item != nil {
|
||||
switch streamResponse.Item.Type {
|
||||
case dto.BuildInCallWebSearchCall:
|
||||
if info != nil && info.ResponsesUsageInfo != nil && info.ResponsesUsageInfo.BuiltInTools != nil {
|
||||
if webSearchTool, exists := info.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool != nil {
|
||||
webSearchTool.CallCount++
|
||||
}
|
||||
if streamResponse.Response.HasImageGenerationCall() {
|
||||
c.Set("image_generation_call", true)
|
||||
c.Set("image_generation_call_quality", streamResponse.Response.GetQuality())
|
||||
c.Set("image_generation_call_size", streamResponse.Response.GetSize())
|
||||
}
|
||||
}
|
||||
case "response.output_text.delta":
|
||||
// 处理输出文本
|
||||
responseTextBuilder.WriteString(streamResponse.Delta)
|
||||
case dto.ResponsesOutputTypeItemDone:
|
||||
// 函数调用处理
|
||||
if streamResponse.Item != nil {
|
||||
switch streamResponse.Item.Type {
|
||||
case dto.BuildInCallWebSearchCall:
|
||||
if info != nil && info.ResponsesUsageInfo != nil && info.ResponsesUsageInfo.BuiltInTools != nil {
|
||||
if webSearchTool, exists := info.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool != nil {
|
||||
webSearchTool.CallCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
logger.LogError(c, "failed to unmarshal stream response: "+err.Error())
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if usage.CompletionTokens == 0 {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
@@ -13,12 +14,13 @@ import (
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
|
||||
"github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
// ============================
|
||||
@@ -26,37 +28,37 @@ import (
|
||||
// ============================
|
||||
|
||||
type ContentItem struct {
|
||||
Type string `json:"type"` // "text", "image_url" or "video"
|
||||
Text string `json:"text,omitempty"` // for text type
|
||||
ImageURL *ImageURL `json:"image_url,omitempty"` // for image_url type
|
||||
Video *VideoReference `json:"video,omitempty"` // for video (sample) type
|
||||
Role string `json:"role,omitempty"` // reference_image / first_frame / last_frame
|
||||
Type string `json:"type,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
ImageURL *MediaURL `json:"image_url,omitempty"`
|
||||
VideoURL *MediaURL `json:"video_url,omitempty"`
|
||||
AudioURL *MediaURL `json:"audio_url,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
}
|
||||
|
||||
type ImageURL struct {
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
type VideoReference struct {
|
||||
URL string `json:"url"` // Draft video URL
|
||||
type MediaURL struct {
|
||||
URL string `json:"url,omitempty"`
|
||||
}
|
||||
|
||||
type requestPayload struct {
|
||||
Model string `json:"model"`
|
||||
Content []ContentItem `json:"content"`
|
||||
Content []ContentItem `json:"content,omitempty"`
|
||||
CallbackURL string `json:"callback_url,omitempty"`
|
||||
ReturnLastFrame *dto.BoolValue `json:"return_last_frame,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
ExecutionExpiresAfter dto.IntValue `json:"execution_expires_after,omitempty"`
|
||||
ExecutionExpiresAfter *dto.IntValue `json:"execution_expires_after,omitempty"`
|
||||
GenerateAudio *dto.BoolValue `json:"generate_audio,omitempty"`
|
||||
Draft *dto.BoolValue `json:"draft,omitempty"`
|
||||
Resolution string `json:"resolution,omitempty"`
|
||||
Ratio string `json:"ratio,omitempty"`
|
||||
Duration dto.IntValue `json:"duration,omitempty"`
|
||||
Frames dto.IntValue `json:"frames,omitempty"`
|
||||
Seed dto.IntValue `json:"seed,omitempty"`
|
||||
CameraFixed *dto.BoolValue `json:"camera_fixed,omitempty"`
|
||||
Watermark *dto.BoolValue `json:"watermark,omitempty"`
|
||||
Tools []struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
} `json:"tools,omitempty"`
|
||||
Resolution string `json:"resolution,omitempty"`
|
||||
Ratio string `json:"ratio,omitempty"`
|
||||
Duration *dto.IntValue `json:"duration,omitempty"`
|
||||
Frames *dto.IntValue `json:"frames,omitempty"`
|
||||
Seed *dto.IntValue `json:"seed,omitempty"`
|
||||
CameraFixed *dto.BoolValue `json:"camera_fixed,omitempty"`
|
||||
Watermark *dto.BoolValue `json:"watermark,omitempty"`
|
||||
}
|
||||
|
||||
type responsePayload struct {
|
||||
@@ -76,10 +78,20 @@ type responseTask struct {
|
||||
Ratio string `json:"ratio"`
|
||||
FramesPerSecond int `json:"framespersecond"`
|
||||
ServiceTier string `json:"service_tier"`
|
||||
Usage struct {
|
||||
Tools []struct {
|
||||
Type string `json:"type"`
|
||||
} `json:"tools"`
|
||||
Usage struct {
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
ToolUsage struct {
|
||||
WebSearch int `json:"web_search"`
|
||||
} `json:"tool_usage"`
|
||||
} `json:"usage"`
|
||||
Error struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
}
|
||||
@@ -108,18 +120,61 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
|
||||
}
|
||||
|
||||
// BuildRequestURL constructs the upstream URL.
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
func (a *TaskAdaptor) BuildRequestURL(_ *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/api/v3/contents/generations/tasks", a.baseURL), nil
|
||||
}
|
||||
|
||||
// BuildRequestHeader sets required headers.
|
||||
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
func (a *TaskAdaptor) BuildRequestHeader(_ *gin.Context, req *http.Request, _ *relaycommon.RelayInfo) error {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+a.apiKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
// EstimateBilling 检测请求 metadata 中是否包含视频输入,返回视频折扣 OtherRatio。
|
||||
func (a *TaskAdaptor) EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64 {
|
||||
req, err := relaycommon.GetTaskRequest(c)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if hasVideoInMetadata(req.Metadata) {
|
||||
if ratio, ok := GetVideoInputRatio(info.OriginModelName); ok {
|
||||
return map[string]float64{"video_input": ratio}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// hasVideoInMetadata 直接检查 metadata 的 content 数组是否包含 video_url 条目,
|
||||
// 避免构建完整的上游 requestPayload。
|
||||
func hasVideoInMetadata(metadata map[string]interface{}) bool {
|
||||
if metadata == nil {
|
||||
return false
|
||||
}
|
||||
contentRaw, ok := metadata["content"]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
contentSlice, ok := contentRaw.([]interface{})
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
for _, item := range contentSlice {
|
||||
itemMap, ok := item.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if itemMap["type"] == "video_url" {
|
||||
return true
|
||||
}
|
||||
if _, has := itemMap["video_url"]; has {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// BuildRequestBody converts request into Doubao specific format.
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
|
||||
req, err := relaycommon.GetTaskRequest(c)
|
||||
@@ -218,20 +273,12 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
|
||||
Content: []ContentItem{},
|
||||
}
|
||||
|
||||
// Add text prompt
|
||||
if req.Prompt != "" {
|
||||
r.Content = append(r.Content, ContentItem{
|
||||
Type: "text",
|
||||
Text: req.Prompt,
|
||||
})
|
||||
}
|
||||
|
||||
// Add images if present
|
||||
if req.HasImage() {
|
||||
for _, imgURL := range req.Images {
|
||||
r.Content = append(r.Content, ContentItem{
|
||||
Type: "image_url",
|
||||
ImageURL: &ImageURL{
|
||||
ImageURL: &MediaURL{
|
||||
URL: imgURL,
|
||||
},
|
||||
})
|
||||
@@ -243,6 +290,16 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
|
||||
return nil, errors.Wrap(err, "unmarshal metadata failed")
|
||||
}
|
||||
|
||||
if sec, _ := strconv.Atoi(req.Seconds); sec > 0 {
|
||||
r.Duration = lo.ToPtr(dto.IntValue(sec))
|
||||
}
|
||||
|
||||
r.Content = lo.Reject(r.Content, func(c ContentItem, _ int) bool { return c.Type == "text" })
|
||||
r.Content = append(r.Content, ContentItem{
|
||||
Type: "text",
|
||||
Text: req.Prompt,
|
||||
})
|
||||
|
||||
return &r, nil
|
||||
}
|
||||
|
||||
@@ -274,7 +331,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
|
||||
case "failed":
|
||||
taskResult.Status = model.TaskStatusFailure
|
||||
taskResult.Progress = "100%"
|
||||
taskResult.Reason = "task failed"
|
||||
taskResult.Reason = resTask.Error.Message
|
||||
default:
|
||||
// Unknown status, treat as processing
|
||||
taskResult.Status = model.TaskStatusInProgress
|
||||
@@ -302,8 +359,8 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro
|
||||
|
||||
if dResp.Status == "failed" {
|
||||
openAIVideo.Error = &dto.OpenAIVideoError{
|
||||
Message: "task failed",
|
||||
Code: "failed",
|
||||
Message: dResp.Error.Message,
|
||||
Code: dResp.Error.Code,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,21 @@ var ModelList = []string{
|
||||
"doubao-seedance-1-0-lite-t2v",
|
||||
"doubao-seedance-1-0-lite-i2v",
|
||||
"doubao-seedance-1-5-pro-251215",
|
||||
"doubao-seedance-2-0-260128",
|
||||
"doubao-seedance-2-0-fast-260128",
|
||||
}
|
||||
|
||||
var ChannelName = "doubao-video"
|
||||
|
||||
// videoInputRatioMap 视频输入折扣比率(含视频单价 / 不含视频单价)。
|
||||
// 管理员应将 ModelRatio 设置为"不含视频"的较高费率,
|
||||
// 系统在检测到视频输入时自动乘以此折扣。
|
||||
var videoInputRatioMap = map[string]float64{
|
||||
"doubao-seedance-2-0-260128": 28.0 / 46.0, // ~0.6087
|
||||
"doubao-seedance-2-0-fast-260128": 22.0 / 37.0, // ~0.5946
|
||||
}
|
||||
|
||||
func GetVideoInputRatio(modelName string) (float64, bool) {
|
||||
r, ok := videoInputRatioMap[modelName]
|
||||
return r, ok
|
||||
}
|
||||
|
||||
@@ -17,6 +17,8 @@ func UnmarshalMetadata(metadata map[string]any, target any) error {
|
||||
if metadata == nil {
|
||||
return nil
|
||||
}
|
||||
// Prevent metadata from overriding model fields to avoid billing bypass.
|
||||
delete(metadata, "model")
|
||||
metaBytes, err := common.Marshal(metadata)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal metadata failed: %w", err)
|
||||
|
||||
@@ -76,7 +76,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
if strings.HasPrefix(request.Model, "grok-3-mini") {
|
||||
if lo.FromPtrOr(request.MaxCompletionTokens, uint(0)) == 0 && lo.FromPtrOr(request.MaxTokens, uint(0)) != 0 {
|
||||
request.MaxCompletionTokens = request.MaxTokens
|
||||
request.MaxTokens = lo.ToPtr(uint(0))
|
||||
request.MaxTokens = nil
|
||||
}
|
||||
if strings.HasSuffix(request.Model, "-high") {
|
||||
request.ReasoningEffort = "high"
|
||||
|
||||
@@ -43,12 +43,12 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
|
||||
helper.SetEventStreamHeaders(c)
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) {
|
||||
var xAIResp *dto.ChatCompletionsStreamResponse
|
||||
err := common.UnmarshalJsonStr(data, &xAIResp)
|
||||
if err != nil {
|
||||
if err := common.UnmarshalJsonStr(data, &xAIResp); err != nil {
|
||||
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
sr.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
// 把 xAI 的usage转换为 OpenAI 的usage
|
||||
@@ -61,11 +61,10 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
|
||||
openaiResponse := streamResponseXAI2OpenAI(xAIResp, usage)
|
||||
_ = openai.ProcessStreamResponse(*openaiResponse, &responseTextBuilder, &toolCount)
|
||||
err = helper.ObjectData(c, openaiResponse)
|
||||
if err != nil {
|
||||
if err := helper.ObjectData(c, openaiResponse); err != nil {
|
||||
common.SysLog(err.Error())
|
||||
sr.Error(err)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if !containStreamUsage {
|
||||
|
||||
@@ -122,7 +122,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
return newApiErr
|
||||
}
|
||||
|
||||
service.PostClaudeConsumeQuota(c, info, usage)
|
||||
service.PostTextConsumeQuota(c, info, usage, nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -190,6 +190,6 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
return newAPIError
|
||||
}
|
||||
|
||||
service.PostClaudeConsumeQuota(c, info, usage.(*dto.Usage))
|
||||
service.PostTextConsumeQuota(c, info, usage.(*dto.Usage), nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -162,6 +162,8 @@ type RelayInfo struct {
|
||||
// 若为空,调用 GetFinalRequestRelayFormat 会回退到 RequestConversionChain 的最后一项或 RelayFormat。
|
||||
FinalRequestRelayFormat types.RelayFormat
|
||||
|
||||
StreamStatus *StreamStatus
|
||||
|
||||
ThinkingContentInfo
|
||||
TokenCountMeta
|
||||
*ClaudeConvertInfo
|
||||
@@ -338,15 +340,10 @@ func GenRelayInfoClaude(c *gin.Context, request dto.Request) *RelayInfo {
|
||||
info.ClaudeConvertInfo = &ClaudeConvertInfo{
|
||||
LastMessagesType: LastMessageTypeNone,
|
||||
}
|
||||
info.IsClaudeBetaQuery = c.Query("beta") == "true" || isClaudeBetaForced(c)
|
||||
info.IsClaudeBetaQuery = c.Query("beta") == "true"
|
||||
return info
|
||||
}
|
||||
|
||||
func isClaudeBetaForced(c *gin.Context) bool {
|
||||
channelOtherSettings, ok := common.GetContextKeyType[dto.ChannelOtherSettings](c, constant.ContextKeyChannelOtherSetting)
|
||||
return ok && channelOtherSettings.ClaudeBetaQuery
|
||||
}
|
||||
|
||||
func GenRelayInfoRerank(c *gin.Context, request *dto.RerankRequest) *RelayInfo {
|
||||
info := genBaseRelayInfo(c, request)
|
||||
info.RelayMode = relayconstant.RelayModeRerank
|
||||
|
||||
@@ -0,0 +1,112 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type StreamEndReason string
|
||||
|
||||
const (
|
||||
StreamEndReasonNone StreamEndReason = ""
|
||||
StreamEndReasonDone StreamEndReason = "done"
|
||||
StreamEndReasonTimeout StreamEndReason = "timeout"
|
||||
StreamEndReasonClientGone StreamEndReason = "client_gone"
|
||||
StreamEndReasonScannerErr StreamEndReason = "scanner_error"
|
||||
StreamEndReasonHandlerStop StreamEndReason = "handler_stop"
|
||||
StreamEndReasonEOF StreamEndReason = "eof"
|
||||
StreamEndReasonPanic StreamEndReason = "panic"
|
||||
StreamEndReasonPingFail StreamEndReason = "ping_fail"
|
||||
)
|
||||
|
||||
const maxStreamErrorEntries = 20
|
||||
|
||||
type StreamErrorEntry struct {
|
||||
Message string
|
||||
Timestamp time.Time
|
||||
}
|
||||
|
||||
type StreamStatus struct {
|
||||
EndReason StreamEndReason
|
||||
EndError error
|
||||
endOnce sync.Once
|
||||
|
||||
mu sync.Mutex
|
||||
Errors []StreamErrorEntry
|
||||
ErrorCount int
|
||||
}
|
||||
|
||||
func NewStreamStatus() *StreamStatus {
|
||||
return &StreamStatus{}
|
||||
}
|
||||
|
||||
func (s *StreamStatus) SetEndReason(reason StreamEndReason, err error) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.endOnce.Do(func() {
|
||||
s.EndReason = reason
|
||||
s.EndError = err
|
||||
})
|
||||
}
|
||||
|
||||
func (s *StreamStatus) RecordError(msg string) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.ErrorCount++
|
||||
if len(s.Errors) < maxStreamErrorEntries {
|
||||
s.Errors = append(s.Errors, StreamErrorEntry{
|
||||
Message: msg,
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StreamStatus) HasErrors() bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.ErrorCount > 0
|
||||
}
|
||||
|
||||
func (s *StreamStatus) TotalErrorCount() int {
|
||||
if s == nil {
|
||||
return 0
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.ErrorCount
|
||||
}
|
||||
|
||||
func (s *StreamStatus) IsNormalEnd() bool {
|
||||
if s == nil {
|
||||
return true
|
||||
}
|
||||
return s.EndReason == StreamEndReasonDone ||
|
||||
s.EndReason == StreamEndReasonEOF ||
|
||||
s.EndReason == StreamEndReasonHandlerStop
|
||||
}
|
||||
|
||||
func (s *StreamStatus) Summary() string {
|
||||
if s == nil {
|
||||
return "StreamStatus<nil>"
|
||||
}
|
||||
b := &strings.Builder{}
|
||||
fmt.Fprintf(b, "reason=%s", s.EndReason)
|
||||
if s.EndError != nil {
|
||||
fmt.Fprintf(b, " end_error=%q", s.EndError.Error())
|
||||
}
|
||||
s.mu.Lock()
|
||||
if s.ErrorCount > 0 {
|
||||
fmt.Fprintf(b, " soft_errors=%d", s.ErrorCount)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
return b.String()
|
||||
}
|
||||
@@ -0,0 +1,182 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestStreamStatus_SetEndReason_FirstWins(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := NewStreamStatus()
|
||||
|
||||
s.SetEndReason(StreamEndReasonDone, nil)
|
||||
s.SetEndReason(StreamEndReasonTimeout, nil)
|
||||
s.SetEndReason(StreamEndReasonClientGone, fmt.Errorf("context canceled"))
|
||||
|
||||
assert.Equal(t, StreamEndReasonDone, s.EndReason)
|
||||
assert.Nil(t, s.EndError)
|
||||
}
|
||||
|
||||
func TestStreamStatus_SetEndReason_WithError(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := NewStreamStatus()
|
||||
|
||||
expectedErr := fmt.Errorf("read: connection reset")
|
||||
s.SetEndReason(StreamEndReasonScannerErr, expectedErr)
|
||||
|
||||
assert.Equal(t, StreamEndReasonScannerErr, s.EndReason)
|
||||
assert.Equal(t, expectedErr, s.EndError)
|
||||
}
|
||||
|
||||
func TestStreamStatus_SetEndReason_NilSafe(t *testing.T) {
|
||||
t.Parallel()
|
||||
var s *StreamStatus
|
||||
s.SetEndReason(StreamEndReasonDone, nil)
|
||||
}
|
||||
|
||||
func TestStreamStatus_SetEndReason_Concurrent(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := NewStreamStatus()
|
||||
|
||||
reasons := []StreamEndReason{
|
||||
StreamEndReasonDone,
|
||||
StreamEndReasonTimeout,
|
||||
StreamEndReasonClientGone,
|
||||
StreamEndReasonScannerErr,
|
||||
StreamEndReasonHandlerStop,
|
||||
StreamEndReasonEOF,
|
||||
StreamEndReasonPanic,
|
||||
StreamEndReasonPingFail,
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, r := range reasons {
|
||||
wg.Add(1)
|
||||
go func(reason StreamEndReason) {
|
||||
defer wg.Done()
|
||||
s.SetEndReason(reason, nil)
|
||||
}(r)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
assert.NotEqual(t, StreamEndReasonNone, s.EndReason)
|
||||
}
|
||||
|
||||
func TestStreamStatus_RecordError_Basic(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := NewStreamStatus()
|
||||
|
||||
s.RecordError("bad json")
|
||||
s.RecordError("another bad json")
|
||||
s.RecordError("client gone")
|
||||
|
||||
assert.True(t, s.HasErrors())
|
||||
assert.Equal(t, 3, s.TotalErrorCount())
|
||||
assert.Len(t, s.Errors, 3)
|
||||
}
|
||||
|
||||
func TestStreamStatus_RecordError_CapAtMax(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := NewStreamStatus()
|
||||
|
||||
for i := 0; i < 30; i++ {
|
||||
s.RecordError(fmt.Sprintf("error_%d", i))
|
||||
}
|
||||
|
||||
assert.Equal(t, maxStreamErrorEntries, len(s.Errors))
|
||||
assert.Equal(t, 30, s.TotalErrorCount())
|
||||
}
|
||||
|
||||
func TestStreamStatus_RecordError_NilSafe(t *testing.T) {
|
||||
t.Parallel()
|
||||
var s *StreamStatus
|
||||
s.RecordError("should not panic")
|
||||
}
|
||||
|
||||
func TestStreamStatus_RecordError_Concurrent(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := NewStreamStatus()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
s.RecordError(fmt.Sprintf("error_%d", idx))
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
assert.Equal(t, 100, s.TotalErrorCount())
|
||||
assert.LessOrEqual(t, len(s.Errors), maxStreamErrorEntries)
|
||||
}
|
||||
|
||||
func TestStreamStatus_HasErrors_Empty(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := NewStreamStatus()
|
||||
assert.False(t, s.HasErrors())
|
||||
assert.Equal(t, 0, s.TotalErrorCount())
|
||||
}
|
||||
|
||||
func TestStreamStatus_HasErrors_NilSafe(t *testing.T) {
|
||||
t.Parallel()
|
||||
var s *StreamStatus
|
||||
assert.False(t, s.HasErrors())
|
||||
assert.Equal(t, 0, s.TotalErrorCount())
|
||||
}
|
||||
|
||||
func TestStreamStatus_IsNormalEnd(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
reason StreamEndReason
|
||||
normal bool
|
||||
}{
|
||||
{StreamEndReasonDone, true},
|
||||
{StreamEndReasonEOF, true},
|
||||
{StreamEndReasonHandlerStop, true},
|
||||
{StreamEndReasonTimeout, false},
|
||||
{StreamEndReasonClientGone, false},
|
||||
{StreamEndReasonScannerErr, false},
|
||||
{StreamEndReasonPanic, false},
|
||||
{StreamEndReasonPingFail, false},
|
||||
{StreamEndReasonNone, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
s := NewStreamStatus()
|
||||
s.SetEndReason(tt.reason, nil)
|
||||
assert.Equal(t, tt.normal, s.IsNormalEnd(), "reason=%s", tt.reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamStatus_IsNormalEnd_NilSafe(t *testing.T) {
|
||||
t.Parallel()
|
||||
var s *StreamStatus
|
||||
assert.True(t, s.IsNormalEnd())
|
||||
}
|
||||
|
||||
func TestStreamStatus_Summary(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s := NewStreamStatus()
|
||||
s.SetEndReason(StreamEndReasonDone, nil)
|
||||
summary := s.Summary()
|
||||
assert.Contains(t, summary, "reason=done")
|
||||
assert.NotContains(t, summary, "soft_errors")
|
||||
|
||||
s2 := NewStreamStatus()
|
||||
s2.SetEndReason(StreamEndReasonTimeout, nil)
|
||||
s2.RecordError("bad json")
|
||||
s2.RecordError("write failed")
|
||||
summary2 := s2.Summary()
|
||||
assert.Contains(t, summary2, "reason=timeout")
|
||||
assert.Contains(t, summary2, "soft_errors=2")
|
||||
}
|
||||
|
||||
func TestStreamStatus_Summary_NilSafe(t *testing.T) {
|
||||
t.Parallel()
|
||||
var s *StreamStatus
|
||||
assert.Equal(t, "StreamStatus<nil>", s.Summary())
|
||||
}
|
||||
+2
-293
@@ -6,25 +6,20 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
relayconstant "github.com/QuantumNous/new-api/relay/constant"
|
||||
"github.com/QuantumNous/new-api/relay/helper"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/setting/model_setting"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
"github.com/QuantumNous/new-api/setting/ratio_setting"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
"github.com/samber/lo"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -93,7 +88,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
|
||||
if containAudioTokens && containsAudioRatios {
|
||||
service.PostAudioConsumeQuota(c, info, usage, "")
|
||||
} else {
|
||||
postConsumeQuota(c, info, usage)
|
||||
service.PostTextConsumeQuota(c, info, usage, nil)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -216,293 +211,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
|
||||
if containAudioTokens && containsAudioRatios {
|
||||
service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||||
} else {
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage))
|
||||
service.PostTextConsumeQuota(c, info, usage.(*dto.Usage), nil)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent ...string) {
|
||||
originUsage := usage
|
||||
if usage == nil {
|
||||
usage = &dto.Usage{
|
||||
PromptTokens: relayInfo.GetEstimatePromptTokens(),
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: relayInfo.GetEstimatePromptTokens(),
|
||||
}
|
||||
extraContent = append(extraContent, "上游无计费信息")
|
||||
}
|
||||
|
||||
if originUsage != nil {
|
||||
service.ObserveChannelAffinityUsageCacheByRelayFormat(ctx, usage, relayInfo.GetFinalRequestRelayFormat())
|
||||
}
|
||||
|
||||
adminRejectReason := common.GetContextKeyString(ctx, constant.ContextKeyAdminRejectReason)
|
||||
|
||||
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
|
||||
promptTokens := usage.PromptTokens
|
||||
cacheTokens := usage.PromptTokensDetails.CachedTokens
|
||||
imageTokens := usage.PromptTokensDetails.ImageTokens
|
||||
audioTokens := usage.PromptTokensDetails.AudioTokens
|
||||
completionTokens := usage.CompletionTokens
|
||||
cachedCreationTokens := usage.PromptTokensDetails.CachedCreationTokens
|
||||
|
||||
modelName := relayInfo.OriginModelName
|
||||
|
||||
tokenName := ctx.GetString("token_name")
|
||||
completionRatio := relayInfo.PriceData.CompletionRatio
|
||||
cacheRatio := relayInfo.PriceData.CacheRatio
|
||||
imageRatio := relayInfo.PriceData.ImageRatio
|
||||
modelRatio := relayInfo.PriceData.ModelRatio
|
||||
groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
|
||||
modelPrice := relayInfo.PriceData.ModelPrice
|
||||
cachedCreationRatio := relayInfo.PriceData.CacheCreationRatio
|
||||
|
||||
// Convert values to decimal for precise calculation
|
||||
dPromptTokens := decimal.NewFromInt(int64(promptTokens))
|
||||
dCacheTokens := decimal.NewFromInt(int64(cacheTokens))
|
||||
dImageTokens := decimal.NewFromInt(int64(imageTokens))
|
||||
dAudioTokens := decimal.NewFromInt(int64(audioTokens))
|
||||
dCompletionTokens := decimal.NewFromInt(int64(completionTokens))
|
||||
dCachedCreationTokens := decimal.NewFromInt(int64(cachedCreationTokens))
|
||||
dCompletionRatio := decimal.NewFromFloat(completionRatio)
|
||||
dCacheRatio := decimal.NewFromFloat(cacheRatio)
|
||||
dImageRatio := decimal.NewFromFloat(imageRatio)
|
||||
dModelRatio := decimal.NewFromFloat(modelRatio)
|
||||
dGroupRatio := decimal.NewFromFloat(groupRatio)
|
||||
dModelPrice := decimal.NewFromFloat(modelPrice)
|
||||
dCachedCreationRatio := decimal.NewFromFloat(cachedCreationRatio)
|
||||
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
|
||||
|
||||
ratio := dModelRatio.Mul(dGroupRatio)
|
||||
|
||||
// openai web search 工具计费
|
||||
var dWebSearchQuota decimal.Decimal
|
||||
var webSearchPrice float64
|
||||
// response api 格式工具计费
|
||||
if relayInfo.ResponsesUsageInfo != nil {
|
||||
if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool.CallCount > 0 {
|
||||
// 计算 web search 调用的配额 (配额 = 价格 * 调用次数 / 1000 * 分组倍率)
|
||||
webSearchPrice = operation_setting.GetWebSearchPricePerThousand(modelName, webSearchTool.SearchContextSize)
|
||||
dWebSearchQuota = decimal.NewFromFloat(webSearchPrice).
|
||||
Mul(decimal.NewFromInt(int64(webSearchTool.CallCount))).
|
||||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
extraContent = append(extraContent, fmt.Sprintf("Web Search 调用 %d 次,上下文大小 %s,调用花费 %s",
|
||||
webSearchTool.CallCount, webSearchTool.SearchContextSize, dWebSearchQuota.String()))
|
||||
}
|
||||
} else if strings.HasSuffix(modelName, "search-preview") {
|
||||
// search-preview 模型不支持 response api
|
||||
searchContextSize := ctx.GetString("chat_completion_web_search_context_size")
|
||||
if searchContextSize == "" {
|
||||
searchContextSize = "medium"
|
||||
}
|
||||
webSearchPrice = operation_setting.GetWebSearchPricePerThousand(modelName, searchContextSize)
|
||||
dWebSearchQuota = decimal.NewFromFloat(webSearchPrice).
|
||||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
extraContent = append(extraContent, fmt.Sprintf("Web Search 调用 1 次,上下文大小 %s,调用花费 %s",
|
||||
searchContextSize, dWebSearchQuota.String()))
|
||||
}
|
||||
// claude web search tool 计费
|
||||
var dClaudeWebSearchQuota decimal.Decimal
|
||||
var claudeWebSearchPrice float64
|
||||
claudeWebSearchCallCount := ctx.GetInt("claude_web_search_requests")
|
||||
if claudeWebSearchCallCount > 0 {
|
||||
claudeWebSearchPrice = operation_setting.GetClaudeWebSearchPricePerThousand()
|
||||
dClaudeWebSearchQuota = decimal.NewFromFloat(claudeWebSearchPrice).
|
||||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit).Mul(decimal.NewFromInt(int64(claudeWebSearchCallCount)))
|
||||
extraContent = append(extraContent, fmt.Sprintf("Claude Web Search 调用 %d 次,调用花费 %s",
|
||||
claudeWebSearchCallCount, dClaudeWebSearchQuota.String()))
|
||||
}
|
||||
// file search tool 计费
|
||||
var dFileSearchQuota decimal.Decimal
|
||||
var fileSearchPrice float64
|
||||
if relayInfo.ResponsesUsageInfo != nil {
|
||||
if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists && fileSearchTool.CallCount > 0 {
|
||||
fileSearchPrice = operation_setting.GetFileSearchPricePerThousand()
|
||||
dFileSearchQuota = decimal.NewFromFloat(fileSearchPrice).
|
||||
Mul(decimal.NewFromInt(int64(fileSearchTool.CallCount))).
|
||||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
extraContent = append(extraContent, fmt.Sprintf("File Search 调用 %d 次,调用花费 %s",
|
||||
fileSearchTool.CallCount, dFileSearchQuota.String()))
|
||||
}
|
||||
}
|
||||
var dImageGenerationCallQuota decimal.Decimal
|
||||
var imageGenerationCallPrice float64
|
||||
if ctx.GetBool("image_generation_call") {
|
||||
imageGenerationCallPrice = operation_setting.GetGPTImage1PriceOnceCall(ctx.GetString("image_generation_call_quality"), ctx.GetString("image_generation_call_size"))
|
||||
dImageGenerationCallQuota = decimal.NewFromFloat(imageGenerationCallPrice).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
extraContent = append(extraContent, fmt.Sprintf("Image Generation Call 花费 %s", dImageGenerationCallQuota.String()))
|
||||
}
|
||||
|
||||
var quotaCalculateDecimal decimal.Decimal
|
||||
|
||||
var audioInputQuota decimal.Decimal
|
||||
var audioInputPrice float64
|
||||
isClaudeUsageSemantic := relayInfo.GetFinalRequestRelayFormat() == types.RelayFormatClaude
|
||||
if !relayInfo.PriceData.UsePrice {
|
||||
baseTokens := dPromptTokens
|
||||
// 减去 cached tokens
|
||||
// Anthropic API 的 input_tokens 已经不包含缓存 tokens,不需要减去
|
||||
// OpenAI/OpenRouter 等 API 的 prompt_tokens 包含缓存 tokens,需要减去
|
||||
var cachedTokensWithRatio decimal.Decimal
|
||||
if !dCacheTokens.IsZero() {
|
||||
if !isClaudeUsageSemantic {
|
||||
baseTokens = baseTokens.Sub(dCacheTokens)
|
||||
}
|
||||
cachedTokensWithRatio = dCacheTokens.Mul(dCacheRatio)
|
||||
}
|
||||
var dCachedCreationTokensWithRatio decimal.Decimal
|
||||
if !dCachedCreationTokens.IsZero() {
|
||||
if !isClaudeUsageSemantic {
|
||||
baseTokens = baseTokens.Sub(dCachedCreationTokens)
|
||||
}
|
||||
dCachedCreationTokensWithRatio = dCachedCreationTokens.Mul(dCachedCreationRatio)
|
||||
}
|
||||
|
||||
// 减去 image tokens
|
||||
var imageTokensWithRatio decimal.Decimal
|
||||
if !dImageTokens.IsZero() {
|
||||
baseTokens = baseTokens.Sub(dImageTokens)
|
||||
imageTokensWithRatio = dImageTokens.Mul(dImageRatio)
|
||||
}
|
||||
|
||||
// 减去 Gemini audio tokens
|
||||
if !dAudioTokens.IsZero() {
|
||||
audioInputPrice = operation_setting.GetGeminiInputAudioPricePerMillionTokens(modelName)
|
||||
if audioInputPrice > 0 {
|
||||
// 重新计算 base tokens
|
||||
baseTokens = baseTokens.Sub(dAudioTokens)
|
||||
audioInputQuota = decimal.NewFromFloat(audioInputPrice).Div(decimal.NewFromInt(1000000)).Mul(dAudioTokens).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
extraContent = append(extraContent, fmt.Sprintf("Audio Input 花费 %s", audioInputQuota.String()))
|
||||
}
|
||||
}
|
||||
promptQuota := baseTokens.Add(cachedTokensWithRatio).
|
||||
Add(imageTokensWithRatio).
|
||||
Add(dCachedCreationTokensWithRatio)
|
||||
|
||||
completionQuota := dCompletionTokens.Mul(dCompletionRatio)
|
||||
|
||||
quotaCalculateDecimal = promptQuota.Add(completionQuota).Mul(ratio)
|
||||
|
||||
if !ratio.IsZero() && quotaCalculateDecimal.LessThanOrEqual(decimal.Zero) {
|
||||
quotaCalculateDecimal = decimal.NewFromInt(1)
|
||||
}
|
||||
} else {
|
||||
quotaCalculateDecimal = dModelPrice.Mul(dQuotaPerUnit).Mul(dGroupRatio)
|
||||
}
|
||||
// 添加 responses tools call 调用的配额
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dWebSearchQuota)
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota)
|
||||
// 添加 audio input 独立计费
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota)
|
||||
// 添加 image generation call 计费
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dImageGenerationCallQuota)
|
||||
|
||||
if len(relayInfo.PriceData.OtherRatios) > 0 {
|
||||
for key, otherRatio := range relayInfo.PriceData.OtherRatios {
|
||||
dOtherRatio := decimal.NewFromFloat(otherRatio)
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Mul(dOtherRatio)
|
||||
extraContent = append(extraContent, fmt.Sprintf("其他倍率 %s: %f", key, otherRatio))
|
||||
}
|
||||
}
|
||||
|
||||
quota := int(quotaCalculateDecimal.Round(0).IntPart())
|
||||
totalTokens := promptTokens + completionTokens
|
||||
|
||||
//var logContent string
|
||||
|
||||
// record all the consume log even if quota is 0
|
||||
if totalTokens == 0 {
|
||||
// in this case, must be some error happened
|
||||
// we cannot just return, because we may have to return the pre-consumed quota
|
||||
quota = 0
|
||||
extraContent = append(extraContent, "上游没有返回计费信息,无法扣费(可能是上游超时)")
|
||||
logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
|
||||
"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota))
|
||||
} else {
|
||||
if !ratio.IsZero() && quota == 0 {
|
||||
quota = 1
|
||||
}
|
||||
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
|
||||
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
|
||||
}
|
||||
|
||||
if err := service.SettleBilling(ctx, relayInfo, quota); err != nil {
|
||||
logger.LogError(ctx, "error settling billing: "+err.Error())
|
||||
}
|
||||
|
||||
logModel := modelName
|
||||
if strings.HasPrefix(logModel, "gpt-4-gizmo") {
|
||||
logModel = "gpt-4-gizmo-*"
|
||||
extraContent = append(extraContent, fmt.Sprintf("模型 %s", modelName))
|
||||
}
|
||||
if strings.HasPrefix(logModel, "gpt-4o-gizmo") {
|
||||
logModel = "gpt-4o-gizmo-*"
|
||||
extraContent = append(extraContent, fmt.Sprintf("模型 %s", modelName))
|
||||
}
|
||||
logContent := strings.Join(extraContent, ", ")
|
||||
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
if adminRejectReason != "" {
|
||||
other["reject_reason"] = adminRejectReason
|
||||
}
|
||||
// For chat-based calls to the Claude model, tagging is required. Using Claude's rendering logs, the two approaches handle input rendering differently.
|
||||
if isClaudeUsageSemantic {
|
||||
other["claude"] = true
|
||||
other["usage_semantic"] = "anthropic"
|
||||
}
|
||||
if imageTokens != 0 {
|
||||
other["image"] = true
|
||||
other["image_ratio"] = imageRatio
|
||||
other["image_output"] = imageTokens
|
||||
}
|
||||
if cachedCreationTokens != 0 {
|
||||
other["cache_creation_tokens"] = cachedCreationTokens
|
||||
other["cache_creation_ratio"] = cachedCreationRatio
|
||||
}
|
||||
if !dWebSearchQuota.IsZero() {
|
||||
if relayInfo.ResponsesUsageInfo != nil {
|
||||
if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists {
|
||||
other["web_search"] = true
|
||||
other["web_search_call_count"] = webSearchTool.CallCount
|
||||
other["web_search_price"] = webSearchPrice
|
||||
}
|
||||
} else if strings.HasSuffix(modelName, "search-preview") {
|
||||
other["web_search"] = true
|
||||
other["web_search_call_count"] = 1
|
||||
other["web_search_price"] = webSearchPrice
|
||||
}
|
||||
} else if !dClaudeWebSearchQuota.IsZero() {
|
||||
other["web_search"] = true
|
||||
other["web_search_call_count"] = claudeWebSearchCallCount
|
||||
other["web_search_price"] = claudeWebSearchPrice
|
||||
}
|
||||
if !dFileSearchQuota.IsZero() && relayInfo.ResponsesUsageInfo != nil {
|
||||
if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists {
|
||||
other["file_search"] = true
|
||||
other["file_search_call_count"] = fileSearchTool.CallCount
|
||||
other["file_search_price"] = fileSearchPrice
|
||||
}
|
||||
}
|
||||
if !audioInputQuota.IsZero() {
|
||||
other["audio_input_seperate_price"] = true
|
||||
other["audio_input_token_count"] = audioTokens
|
||||
other["audio_input_price"] = audioInputPrice
|
||||
}
|
||||
if !dImageGenerationCallQuota.IsZero() {
|
||||
other["image_generation_call"] = true
|
||||
other["image_generation_call_price"] = imageGenerationCallPrice
|
||||
}
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
|
||||
ChannelId: relayInfo.ChannelId,
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
ModelName: logModel,
|
||||
TokenName: tokenName,
|
||||
Quota: quota,
|
||||
Content: logContent,
|
||||
TokenId: relayInfo.TokenId,
|
||||
UseTimeSeconds: int(useTimeSeconds),
|
||||
IsStream: relayInfo.IsStream,
|
||||
Group: relayInfo.UsingGroup,
|
||||
Other: other,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -82,6 +82,6 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
}
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage))
|
||||
service.PostTextConsumeQuota(c, info, usage.(*dto.Usage), nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -194,7 +194,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
return openaiErr
|
||||
}
|
||||
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage))
|
||||
service.PostTextConsumeQuota(c, info, usage.(*dto.Usage), nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -288,6 +288,6 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI
|
||||
return openaiErr
|
||||
}
|
||||
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage))
|
||||
service.PostTextConsumeQuota(c, info, usage.(*dto.Usage), nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
+29
-15
@@ -139,21 +139,23 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
||||
return priceData, nil
|
||||
}
|
||||
|
||||
// ModelPriceHelperPerCall 按次计费的 PriceHelper (MJ、Task)
|
||||
// ModelPriceHelperPerCall 按次/按量计费的 PriceHelper (MJ、Task)
|
||||
func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) (types.PriceData, error) {
|
||||
groupRatioInfo := HandleGroupRatio(c, info)
|
||||
|
||||
modelPrice, success := ratio_setting.GetModelPrice(info.OriginModelName, true)
|
||||
// 如果没有配置价格,检查模型倍率配置
|
||||
if !success {
|
||||
usePrice := success
|
||||
var modelRatio float64
|
||||
|
||||
// 没有配置费用,也要使用默认费用,否则按费率计费模型无法使用
|
||||
if !success {
|
||||
defaultPrice, ok := ratio_setting.GetDefaultModelPriceMap()[info.OriginModelName]
|
||||
if ok {
|
||||
modelPrice = defaultPrice
|
||||
usePrice = true
|
||||
} else {
|
||||
// 没有配置倍率也不接受没配置,那就返回错误
|
||||
_, ratioSuccess, matchName := ratio_setting.GetModelRatio(info.OriginModelName)
|
||||
var ratioSuccess bool
|
||||
var matchName string
|
||||
modelRatio, ratioSuccess, matchName = ratio_setting.GetModelRatio(info.OriginModelName)
|
||||
acceptUnsetRatio := false
|
||||
if info.UserSetting.AcceptUnsetRatioModel {
|
||||
acceptUnsetRatio = true
|
||||
@@ -161,25 +163,37 @@ func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) (types
|
||||
if !ratioSuccess && !acceptUnsetRatio {
|
||||
return types.PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置,请联系管理员设置或开始自用模式;Model %s ratio or price not set, please set or start self-use mode", matchName, matchName)
|
||||
}
|
||||
// 未配置价格但配置了倍率,使用默认预扣价格
|
||||
modelPrice = float64(common.PreConsumedQuota) / common.QuotaPerUnit
|
||||
}
|
||||
|
||||
}
|
||||
quota := int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
|
||||
|
||||
// 免费模型检测(与 ModelPriceHelper 对齐)
|
||||
var quota int
|
||||
freeModel := false
|
||||
if !operation_setting.GetQuotaSetting().EnableFreeModelPreConsume {
|
||||
if groupRatioInfo.GroupRatio == 0 || modelPrice == 0 {
|
||||
quota = 0
|
||||
freeModel = true
|
||||
|
||||
if usePrice {
|
||||
quota = int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
|
||||
if !operation_setting.GetQuotaSetting().EnableFreeModelPreConsume {
|
||||
if groupRatioInfo.GroupRatio == 0 || modelPrice == 0 {
|
||||
quota = 0
|
||||
freeModel = true
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 按量计费:以模型倍率的一半作为预扣额度
|
||||
quota = int(modelRatio / 2 * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
|
||||
modelPrice = -1
|
||||
if !operation_setting.GetQuotaSetting().EnableFreeModelPreConsume {
|
||||
if groupRatioInfo.GroupRatio == 0 || modelRatio == 0 {
|
||||
quota = 0
|
||||
freeModel = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
priceData := types.PriceData{
|
||||
FreeModel: freeModel,
|
||||
ModelPrice: modelPrice,
|
||||
ModelRatio: modelRatio,
|
||||
UsePrice: usePrice,
|
||||
Quota: quota,
|
||||
GroupRatioInfo: groupRatioInfo,
|
||||
}
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
package helper
|
||||
|
||||
import (
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
)
|
||||
|
||||
// StreamResult is passed to each dataHandler invocation, providing methods
|
||||
// to record soft errors, signal fatal stops, or mark normal completion.
|
||||
// StreamScannerHandler checks IsStopped() after each callback invocation.
|
||||
type StreamResult struct {
|
||||
status *relaycommon.StreamStatus
|
||||
stopped bool
|
||||
}
|
||||
|
||||
func newStreamResult(status *relaycommon.StreamStatus) *StreamResult {
|
||||
return &StreamResult{status: status}
|
||||
}
|
||||
|
||||
// Error records a soft error. The stream continues processing.
|
||||
// Can be called multiple times per chunk.
|
||||
func (r *StreamResult) Error(err error) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
r.status.RecordError(err.Error())
|
||||
}
|
||||
|
||||
// Stop records a fatal error and marks the stream to stop after this chunk.
|
||||
func (r *StreamResult) Stop(err error) {
|
||||
if err != nil {
|
||||
r.status.RecordError(err.Error())
|
||||
}
|
||||
r.status.SetEndReason(relaycommon.StreamEndReasonHandlerStop, err)
|
||||
r.stopped = true
|
||||
}
|
||||
|
||||
// Done signals that the handler has finished processing normally
|
||||
// (e.g., Dify "message_end"). The stream stops after this chunk.
|
||||
func (r *StreamResult) Done() {
|
||||
r.status.SetEndReason(relaycommon.StreamEndReasonDone, nil)
|
||||
r.stopped = true
|
||||
}
|
||||
|
||||
// IsStopped returns whether Stop() or Done() was called during this chunk.
|
||||
func (r *StreamResult) IsStopped() bool {
|
||||
return r.stopped
|
||||
}
|
||||
|
||||
// reset clears the per-chunk stopped flag so the object can be reused.
|
||||
func (r *StreamResult) reset() {
|
||||
r.stopped = false
|
||||
}
|
||||
@@ -34,12 +34,15 @@ func getScannerBufferSize() int {
|
||||
return DefaultMaxScannerBufferSize
|
||||
}
|
||||
|
||||
func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string) bool) {
|
||||
func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string, sr *StreamResult)) {
|
||||
|
||||
if resp == nil || dataHandler == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 无条件新建 StreamStatus
|
||||
info.StreamStatus = relaycommon.NewStreamStatus()
|
||||
|
||||
// 确保响应体总是被关闭
|
||||
defer func() {
|
||||
if resp.Body != nil {
|
||||
@@ -121,6 +124,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
wg.Done()
|
||||
if r := recover(); r != nil {
|
||||
logger.LogError(c, fmt.Sprintf("ping goroutine panic: %v", r))
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonPanic, fmt.Errorf("ping panic: %v", r))
|
||||
common.SafeSendBool(stopChan, true)
|
||||
}
|
||||
if common.DebugEnabled {
|
||||
@@ -148,6 +152,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
logger.LogError(c, "ping data error: "+err.Error())
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonPingFail, err)
|
||||
return
|
||||
}
|
||||
if common.DebugEnabled {
|
||||
@@ -155,6 +160,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
}
|
||||
case <-time.After(10 * time.Second):
|
||||
logger.LogError(c, "ping data send timeout")
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonPingFail, fmt.Errorf("ping send timeout"))
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
@@ -184,14 +190,17 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
wg.Done()
|
||||
if r := recover(); r != nil {
|
||||
logger.LogError(c, fmt.Sprintf("data handler goroutine panic: %v", r))
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonPanic, fmt.Errorf("handler panic: %v", r))
|
||||
}
|
||||
common.SafeSendBool(stopChan, true)
|
||||
}()
|
||||
sr := newStreamResult(info.StreamStatus)
|
||||
for data := range dataChan {
|
||||
sr.reset()
|
||||
writeMutex.Lock()
|
||||
success := dataHandler(data)
|
||||
dataHandler(data, sr)
|
||||
writeMutex.Unlock()
|
||||
if !success {
|
||||
if sr.IsStopped() {
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -205,6 +214,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
wg.Done()
|
||||
if r := recover(); r != nil {
|
||||
logger.LogError(c, fmt.Sprintf("scanner goroutine panic: %v", r))
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonPanic, fmt.Errorf("scanner panic: %v", r))
|
||||
}
|
||||
common.SafeSendBool(stopChan, true)
|
||||
if common.DebugEnabled {
|
||||
@@ -220,6 +230,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-c.Request.Context().Done():
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, c.Request.Context().Err())
|
||||
return
|
||||
default:
|
||||
}
|
||||
@@ -253,7 +264,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// done, 处理完成标志,直接退出停止读取剩余数据防止出错
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonDone, nil)
|
||||
if common.DebugEnabled {
|
||||
println("received [DONE], stopping scanner")
|
||||
}
|
||||
@@ -264,20 +275,25 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
if err := scanner.Err(); err != nil {
|
||||
if err != io.EOF {
|
||||
logger.LogError(c, "scanner error: "+err.Error())
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonScannerErr, err)
|
||||
}
|
||||
}
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonEOF, nil)
|
||||
})
|
||||
|
||||
// 主循环等待完成或超时
|
||||
select {
|
||||
case <-ticker.C:
|
||||
// 超时处理逻辑
|
||||
logger.LogError(c, "streaming timeout")
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonTimeout, nil)
|
||||
case <-stopChan:
|
||||
// 正常结束
|
||||
logger.LogInfo(c, "streaming finished")
|
||||
// EndReason already set by the goroutine that triggered stopChan
|
||||
case <-c.Request.Context().Done():
|
||||
// 客户端断开连接
|
||||
logger.LogInfo(c, "client disconnected")
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, c.Request.Context().Err())
|
||||
}
|
||||
|
||||
if info.StreamStatus.IsNormalEnd() && !info.StreamStatus.HasErrors() {
|
||||
logger.LogInfo(c, fmt.Sprintf("stream ended: %s", info.StreamStatus.Summary()))
|
||||
} else {
|
||||
logger.LogError(c, fmt.Sprintf("stream ended: %s, received=%d", info.StreamStatus.Summary(), info.ReceivedResponseCount))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -56,8 +56,6 @@ func buildSSEBody(n int) string {
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// slowReader wraps a reader and injects a delay before each Read call,
|
||||
// simulating a slow upstream that trickles data.
|
||||
type slowReader struct {
|
||||
r io.Reader
|
||||
delay time.Duration
|
||||
@@ -79,7 +77,7 @@ func TestStreamScannerHandler_NilInputs(t *testing.T) {
|
||||
|
||||
info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}}
|
||||
|
||||
StreamScannerHandler(c, nil, info, func(data string) bool { return true })
|
||||
StreamScannerHandler(c, nil, info, func(data string, sr *StreamResult) {})
|
||||
StreamScannerHandler(c, &http.Response{Body: io.NopCloser(strings.NewReader(""))}, info, nil)
|
||||
}
|
||||
|
||||
@@ -89,9 +87,8 @@ func TestStreamScannerHandler_EmptyBody(t *testing.T) {
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(""))
|
||||
|
||||
var called atomic.Bool
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
called.Store(true)
|
||||
return true
|
||||
})
|
||||
|
||||
assert.False(t, called.Load(), "handler should not be called for empty body")
|
||||
@@ -105,9 +102,8 @@ func TestStreamScannerHandler_1000Chunks(t *testing.T) {
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
var count atomic.Int64
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
|
||||
assert.Equal(t, int64(numChunks), count.Load())
|
||||
@@ -124,9 +120,8 @@ func TestStreamScannerHandler_10000Chunks(t *testing.T) {
|
||||
var count atomic.Int64
|
||||
start := time.Now()
|
||||
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
|
||||
elapsed := time.Since(start)
|
||||
@@ -145,11 +140,10 @@ func TestStreamScannerHandler_OrderPreserved(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
received := make([]string, 0, numChunks)
|
||||
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
mu.Lock()
|
||||
received = append(received, data)
|
||||
mu.Unlock()
|
||||
return true
|
||||
})
|
||||
|
||||
require.Equal(t, numChunks, len(received))
|
||||
@@ -166,31 +160,32 @@ func TestStreamScannerHandler_DoneStopsScanner(t *testing.T) {
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
var count atomic.Int64
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
|
||||
assert.Equal(t, int64(50), count.Load(), "data after [DONE] must not be processed")
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_HandlerFailureStops(t *testing.T) {
|
||||
func TestStreamScannerHandler_StopStopsStream(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const numChunks = 200
|
||||
body := buildSSEBody(numChunks)
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
const failAt = 50
|
||||
const stopAt int64 = 50
|
||||
var count atomic.Int64
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
n := count.Add(1)
|
||||
return n < failAt
|
||||
if n >= stopAt {
|
||||
sr.Stop(fmt.Errorf("fatal at %d", n))
|
||||
}
|
||||
})
|
||||
|
||||
// The worker stops at failAt; the scanner may have read ahead,
|
||||
// but the handler should not be called beyond failAt.
|
||||
assert.Equal(t, int64(failAt), count.Load())
|
||||
assert.Equal(t, stopAt, count.Load())
|
||||
require.NotNil(t, info.StreamStatus)
|
||||
assert.Equal(t, relaycommon.StreamEndReasonHandlerStop, info.StreamStatus.EndReason)
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_SkipsNonDataLines(t *testing.T) {
|
||||
@@ -210,9 +205,8 @@ func TestStreamScannerHandler_SkipsNonDataLines(t *testing.T) {
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(b.String()))
|
||||
|
||||
var count atomic.Int64
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
|
||||
assert.Equal(t, int64(100), count.Load())
|
||||
@@ -225,25 +219,18 @@ func TestStreamScannerHandler_DataWithExtraSpaces(t *testing.T) {
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
var got string
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
got = data
|
||||
return true
|
||||
})
|
||||
|
||||
assert.Equal(t, "{\"trimmed\":true}", got)
|
||||
}
|
||||
|
||||
// ---------- Decoupling: scanner not blocked by slow handler ----------
|
||||
// ---------- Decoupling ----------
|
||||
|
||||
func TestStreamScannerHandler_ScannerDecoupledFromSlowHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Strategy: use a slow upstream (io.Pipe, 10ms per chunk) AND a slow handler (20ms per chunk).
|
||||
// If the scanner were synchronously coupled to the handler, total time would be
|
||||
// ~numChunks * (10ms + 20ms) = 30ms * 50 = 1500ms.
|
||||
// With decoupling, total time should be closer to
|
||||
// ~numChunks * max(10ms, 20ms) = 20ms * 50 = 1000ms
|
||||
// because the scanner reads ahead into the buffer while the handler processes.
|
||||
const numChunks = 50
|
||||
const upstreamDelay = 10 * time.Millisecond
|
||||
const handlerDelay = 20 * time.Millisecond
|
||||
@@ -273,10 +260,9 @@ func TestStreamScannerHandler_ScannerDecoupledFromSlowHandler(t *testing.T) {
|
||||
start := time.Now()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
time.Sleep(handlerDelay)
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
close(done)
|
||||
}()
|
||||
@@ -293,7 +279,6 @@ func TestStreamScannerHandler_ScannerDecoupledFromSlowHandler(t *testing.T) {
|
||||
coupledTime := time.Duration(numChunks) * (upstreamDelay + handlerDelay)
|
||||
t.Logf("elapsed=%v, coupled_estimate=%v", elapsed, coupledTime)
|
||||
|
||||
// If decoupled, elapsed should be well under the coupled estimate.
|
||||
assert.Less(t, elapsed, coupledTime*85/100,
|
||||
"decoupled elapsed time (%v) should be significantly less than coupled estimate (%v)", elapsed, coupledTime)
|
||||
}
|
||||
@@ -311,9 +296,8 @@ func TestStreamScannerHandler_SlowUpstreamFastHandler(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
close(done)
|
||||
}()
|
||||
@@ -344,8 +328,6 @@ func TestStreamScannerHandler_PingSentDuringSlowUpstream(t *testing.T) {
|
||||
setting.PingIntervalSeconds = oldSeconds
|
||||
})
|
||||
|
||||
// Create a reader that delivers data slowly: one chunk every 500ms over 3.5 seconds.
|
||||
// The ping interval is 1s, so we should see at least 2 pings.
|
||||
pr, pw := io.Pipe()
|
||||
go func() {
|
||||
defer pw.Close()
|
||||
@@ -372,9 +354,8 @@ func TestStreamScannerHandler_PingSentDuringSlowUpstream(t *testing.T) {
|
||||
var count atomic.Int64
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
close(done)
|
||||
}()
|
||||
@@ -436,9 +417,8 @@ func TestStreamScannerHandler_PingDisabledByRelayInfo(t *testing.T) {
|
||||
var count atomic.Int64
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
close(done)
|
||||
}()
|
||||
@@ -456,6 +436,199 @@ func TestStreamScannerHandler_PingDisabledByRelayInfo(t *testing.T) {
|
||||
assert.Equal(t, 0, pingCount, "pings should be disabled when DisablePing=true")
|
||||
}
|
||||
|
||||
// ---------- StreamStatus integration ----------
|
||||
|
||||
func TestStreamScannerHandler_StreamStatus_DoneReason(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := buildSSEBody(10)
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {})
|
||||
|
||||
require.NotNil(t, info.StreamStatus)
|
||||
assert.Equal(t, relaycommon.StreamEndReasonDone, info.StreamStatus.EndReason)
|
||||
assert.Nil(t, info.StreamStatus.EndError)
|
||||
assert.True(t, info.StreamStatus.IsNormalEnd())
|
||||
assert.False(t, info.StreamStatus.HasErrors())
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_StreamStatus_EOFWithoutDone(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var b strings.Builder
|
||||
for i := 0; i < 5; i++ {
|
||||
fmt.Fprintf(&b, "data: {\"id\":%d}\n", i)
|
||||
}
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(b.String()))
|
||||
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {})
|
||||
|
||||
require.NotNil(t, info.StreamStatus)
|
||||
assert.Equal(t, relaycommon.StreamEndReasonEOF, info.StreamStatus.EndReason)
|
||||
assert.True(t, info.StreamStatus.IsNormalEnd())
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_StreamStatus_HandlerStop(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := buildSSEBody(100)
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
var count atomic.Int64
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
n := count.Add(1)
|
||||
if n >= 10 {
|
||||
sr.Stop(fmt.Errorf("stop at 10"))
|
||||
}
|
||||
})
|
||||
|
||||
require.NotNil(t, info.StreamStatus)
|
||||
assert.Equal(t, relaycommon.StreamEndReasonHandlerStop, info.StreamStatus.EndReason)
|
||||
assert.True(t, info.StreamStatus.HasErrors())
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_StreamStatus_HandlerDone(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := buildSSEBody(20)
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
var count atomic.Int64
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
n := count.Add(1)
|
||||
if n >= 5 {
|
||||
sr.Done()
|
||||
}
|
||||
})
|
||||
|
||||
assert.Equal(t, int64(5), count.Load())
|
||||
require.NotNil(t, info.StreamStatus)
|
||||
assert.Equal(t, relaycommon.StreamEndReasonDone, info.StreamStatus.EndReason)
|
||||
assert.False(t, info.StreamStatus.HasErrors())
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_StreamStatus_Timeout(t *testing.T) {
|
||||
// Not parallel: modifies global constant.StreamingTimeout
|
||||
oldTimeout := constant.StreamingTimeout
|
||||
constant.StreamingTimeout = 2
|
||||
t.Cleanup(func() { constant.StreamingTimeout = oldTimeout })
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
go func() {
|
||||
fmt.Fprint(pw, "data: {\"id\":1}\n")
|
||||
time.Sleep(10 * time.Second)
|
||||
pw.Close()
|
||||
}()
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
resp := &http.Response{Body: pr}
|
||||
info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {})
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(15 * time.Second):
|
||||
t.Fatal("timed out waiting for stream timeout")
|
||||
}
|
||||
|
||||
require.NotNil(t, info.StreamStatus)
|
||||
assert.Equal(t, relaycommon.StreamEndReasonTimeout, info.StreamStatus.EndReason)
|
||||
assert.False(t, info.StreamStatus.IsNormalEnd())
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_StreamStatus_SoftErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := buildSSEBody(10)
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
sr.Error(fmt.Errorf("soft error for chunk"))
|
||||
})
|
||||
|
||||
require.NotNil(t, info.StreamStatus)
|
||||
assert.Equal(t, relaycommon.StreamEndReasonDone, info.StreamStatus.EndReason)
|
||||
assert.True(t, info.StreamStatus.HasErrors())
|
||||
assert.Equal(t, 10, info.StreamStatus.TotalErrorCount())
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_StreamStatus_MultipleErrorsPerChunk(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := buildSSEBody(5)
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
sr.Error(fmt.Errorf("error A"))
|
||||
sr.Error(fmt.Errorf("error B"))
|
||||
})
|
||||
|
||||
require.NotNil(t, info.StreamStatus)
|
||||
assert.Equal(t, relaycommon.StreamEndReasonDone, info.StreamStatus.EndReason)
|
||||
assert.Equal(t, 10, info.StreamStatus.TotalErrorCount())
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_StreamStatus_ErrorThenStop(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Use a large body without [DONE] to avoid race between scanner's [DONE]
|
||||
// and handler's Stop on the sync.Once EndReason.
|
||||
var b strings.Builder
|
||||
for i := 0; i < 100; i++ {
|
||||
fmt.Fprintf(&b, "data: {\"id\":%d}\n", i)
|
||||
}
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(b.String()))
|
||||
|
||||
var count atomic.Int64
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
count.Add(1)
|
||||
sr.Error(fmt.Errorf("soft error"))
|
||||
sr.Stop(fmt.Errorf("fatal"))
|
||||
})
|
||||
|
||||
assert.Equal(t, int64(1), count.Load())
|
||||
require.NotNil(t, info.StreamStatus)
|
||||
assert.Equal(t, relaycommon.StreamEndReasonHandlerStop, info.StreamStatus.EndReason)
|
||||
assert.Equal(t, 2, info.StreamStatus.TotalErrorCount())
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_StreamStatus_InitializedIfNil(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := buildSSEBody(1)
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
assert.Nil(t, info.StreamStatus)
|
||||
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {})
|
||||
|
||||
assert.NotNil(t, info.StreamStatus)
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_StreamStatus_PreInitialized(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := buildSSEBody(5)
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
info.StreamStatus = relaycommon.NewStreamStatus()
|
||||
info.StreamStatus.RecordError("pre-existing error")
|
||||
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {})
|
||||
|
||||
assert.Equal(t, relaycommon.StreamEndReasonDone, info.StreamStatus.EndReason)
|
||||
assert.Equal(t, 1, info.StreamStatus.TotalErrorCount())
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_PingInterleavesWithSlowUpstream(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -469,9 +642,6 @@ func TestStreamScannerHandler_PingInterleavesWithSlowUpstream(t *testing.T) {
|
||||
setting.PingIntervalSeconds = oldSeconds
|
||||
})
|
||||
|
||||
// Slow upstream + slow handler. Total stream takes ~5 seconds.
|
||||
// The ping goroutine stays alive as long as the scanner is reading,
|
||||
// so pings should fire between data writes.
|
||||
pr, pw := io.Pipe()
|
||||
go func() {
|
||||
defer pw.Close()
|
||||
@@ -498,9 +668,8 @@ func TestStreamScannerHandler_PingInterleavesWithSlowUpstream(t *testing.T) {
|
||||
var count atomic.Int64
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
close(done)
|
||||
}()
|
||||
|
||||
+12
-3
@@ -117,11 +117,20 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
||||
if request.N != nil {
|
||||
imageN = *request.N
|
||||
}
|
||||
|
||||
// n is handled via OtherRatio so it is applied exactly once in quota
|
||||
// calculation (both price-based and ratio-based paths).
|
||||
// Adaptors may have already set a more accurate count from the
|
||||
// upstream response; only set the default when they haven't.
|
||||
if _, hasN := info.PriceData.OtherRatios["n"]; !hasN {
|
||||
info.PriceData.AddOtherRatio("n", float64(imageN))
|
||||
}
|
||||
|
||||
if usage.(*dto.Usage).TotalTokens == 0 {
|
||||
usage.(*dto.Usage).TotalTokens = int(imageN)
|
||||
usage.(*dto.Usage).TotalTokens = 1
|
||||
}
|
||||
if usage.(*dto.Usage).PromptTokens == 0 {
|
||||
usage.(*dto.Usage).PromptTokens = int(imageN)
|
||||
usage.(*dto.Usage).PromptTokens = 1
|
||||
}
|
||||
|
||||
quality := "standard"
|
||||
@@ -141,6 +150,6 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
||||
logContent = append(logContent, fmt.Sprintf("生成数量 %d", imageN))
|
||||
}
|
||||
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage), logContent...)
|
||||
service.PostTextConsumeQuota(c, info, usage.(*dto.Usage), logContent)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -49,6 +49,13 @@ func RelayMidjourneyImage(c *gin.Context) {
|
||||
if httpClient == nil {
|
||||
httpClient = service.GetHttpClient()
|
||||
}
|
||||
fetchSetting := system_setting.GetFetchSetting()
|
||||
if err := common.ValidateURLWithFetchSetting(midjourneyTask.ImageUrl, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"error": fmt.Sprintf("request blocked: %v", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
resp, err := httpClient.Get(midjourneyTask.ImageUrl)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
|
||||
@@ -96,6 +96,6 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
}
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage))
|
||||
service.PostTextConsumeQuota(c, info, usage.(*dto.Usage), nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -145,7 +145,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
||||
info.PriceData = originPriceData
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
postConsumeQuota(c, info, usageDto)
|
||||
service.PostTextConsumeQuota(c, info, usageDto, nil)
|
||||
|
||||
info.OriginModelName = originModelName
|
||||
info.PriceData = originPriceData
|
||||
@@ -155,7 +155,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
||||
if strings.HasPrefix(info.OriginModelName, "gpt-4o-audio") {
|
||||
service.PostAudioConsumeQuota(c, info, usageDto, "")
|
||||
} else {
|
||||
postConsumeQuota(c, info, usageDto)
|
||||
service.PostTextConsumeQuota(c, info, usageDto, nil)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -36,10 +36,10 @@ func SetApiRouter(router *gin.Engine) {
|
||||
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
|
||||
// OAuth routes - specific routes must come before :provider wildcard
|
||||
apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode)
|
||||
apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), controller.EmailBind)
|
||||
apiRouter.POST("/oauth/email/bind", middleware.CriticalRateLimit(), controller.EmailBind)
|
||||
// Non-standard OAuth (WeChat, Telegram) - keep original routes
|
||||
apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)
|
||||
apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), controller.WeChatBind)
|
||||
apiRouter.POST("/oauth/wechat/bind", middleware.CriticalRateLimit(), controller.WeChatBind)
|
||||
apiRouter.GET("/oauth/telegram/login", middleware.CriticalRateLimit(), controller.TelegramLogin)
|
||||
apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), controller.TelegramBind)
|
||||
// Standard OAuth providers (GitHub, Discord, OIDC, LinuxDO) - unified route
|
||||
@@ -226,7 +226,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
channelRoute.POST("/batch", controller.DeleteChannelBatch)
|
||||
channelRoute.POST("/fix", controller.FixChannelsAbilities)
|
||||
channelRoute.GET("/fetch_models/:id", controller.FetchUpstreamModels)
|
||||
channelRoute.POST("/fetch_models", controller.FetchModels)
|
||||
channelRoute.POST("/fetch_models", middleware.RootAuth(), controller.FetchModels)
|
||||
channelRoute.POST("/codex/oauth/start", controller.StartCodexOAuth)
|
||||
channelRoute.POST("/codex/oauth/complete", controller.CompleteCodexOAuth)
|
||||
channelRoute.POST("/:id/codex/oauth/start", controller.StartCodexOAuthForChannel)
|
||||
@@ -257,6 +257,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
tokenRoute.PUT("/", controller.UpdateToken)
|
||||
tokenRoute.DELETE("/:id", controller.DeleteToken)
|
||||
tokenRoute.POST("/batch", controller.DeleteTokenBatch)
|
||||
tokenRoute.POST("/batch/keys", middleware.CriticalRateLimit(), middleware.DisableCache(), controller.GetTokenKeysBatch)
|
||||
}
|
||||
|
||||
usageRoute := apiRouter.Group("/usage")
|
||||
|
||||
@@ -610,14 +610,17 @@ func ShouldSkipRetryAfterChannelAffinityFailure(c *gin.Context) bool {
|
||||
return false
|
||||
}
|
||||
v, ok := c.Get(ginKeyChannelAffinitySkipRetry)
|
||||
if ok {
|
||||
b, ok := v.(bool)
|
||||
if ok {
|
||||
return b
|
||||
}
|
||||
}
|
||||
meta, ok := getChannelAffinityMeta(c)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
b, ok := v.(bool)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return b
|
||||
return meta.SkipRetry
|
||||
}
|
||||
|
||||
func MarkChannelAffinityUsed(c *gin.Context, selectedGroup string, channelID int) {
|
||||
|
||||
@@ -116,6 +116,66 @@ func TestApplyChannelAffinityOverrideTemplate_MergeOperations(t *testing.T) {
|
||||
require.Equal(t, "trim_prefix", secondOp["mode"])
|
||||
}
|
||||
|
||||
func TestShouldSkipRetryAfterChannelAffinityFailure(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ctx func() *gin.Context
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "nil context",
|
||||
ctx: func() *gin.Context {
|
||||
return nil
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "explicit skip retry flag in context",
|
||||
ctx: func() *gin.Context {
|
||||
ctx := buildChannelAffinityTemplateContextForTest(channelAffinityMeta{
|
||||
RuleName: "rule-explicit-flag",
|
||||
SkipRetry: false,
|
||||
UsingGroup: "default",
|
||||
ModelName: "gpt-5",
|
||||
})
|
||||
ctx.Set(ginKeyChannelAffinitySkipRetry, true)
|
||||
return ctx
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "fallback to matched rule meta",
|
||||
ctx: func() *gin.Context {
|
||||
return buildChannelAffinityTemplateContextForTest(channelAffinityMeta{
|
||||
RuleName: "rule-skip-retry",
|
||||
SkipRetry: true,
|
||||
UsingGroup: "default",
|
||||
ModelName: "gpt-5",
|
||||
})
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "no flag and no skip retry meta",
|
||||
ctx: func() *gin.Context {
|
||||
return buildChannelAffinityTemplateContextForTest(channelAffinityMeta{
|
||||
RuleName: "rule-no-skip-retry",
|
||||
SkipRetry: false,
|
||||
UsingGroup: "default",
|
||||
ModelName: "gpt-5",
|
||||
})
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, ShouldSkipRetryAfterChannelAffinityFailure(tt.ctx()))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChannelAffinityHitCodexTemplatePassHeadersEffective(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
+58
-35
@@ -223,6 +223,35 @@ func generateStopBlock(index int) *dto.ClaudeResponse {
|
||||
}
|
||||
}
|
||||
|
||||
func buildClaudeUsageFromOpenAIUsage(oaiUsage *dto.Usage) *dto.ClaudeUsage {
|
||||
if oaiUsage == nil {
|
||||
return nil
|
||||
}
|
||||
cacheCreation5m, cacheCreation1h := NormalizeCacheCreationSplit(
|
||||
oaiUsage.PromptTokensDetails.CachedCreationTokens,
|
||||
oaiUsage.ClaudeCacheCreation5mTokens,
|
||||
oaiUsage.ClaudeCacheCreation1hTokens,
|
||||
)
|
||||
usage := &dto.ClaudeUsage{
|
||||
InputTokens: oaiUsage.PromptTokens,
|
||||
OutputTokens: oaiUsage.CompletionTokens,
|
||||
CacheCreationInputTokens: oaiUsage.PromptTokensDetails.CachedCreationTokens,
|
||||
CacheReadInputTokens: oaiUsage.PromptTokensDetails.CachedTokens,
|
||||
}
|
||||
if cacheCreation5m > 0 || cacheCreation1h > 0 {
|
||||
usage.CacheCreation = &dto.ClaudeCacheCreationUsage{
|
||||
Ephemeral5mInputTokens: cacheCreation5m,
|
||||
Ephemeral1hInputTokens: cacheCreation1h,
|
||||
}
|
||||
}
|
||||
return usage
|
||||
}
|
||||
|
||||
func NormalizeCacheCreationSplit(totalTokens int, tokens5m int, tokens1h int) (int, int) {
|
||||
remainder := lo.Max([]int{totalTokens - tokens5m - tokens1h, 0})
|
||||
return tokens5m + remainder, tokens1h
|
||||
}
|
||||
|
||||
func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamResponse, info *relaycommon.RelayInfo) []*dto.ClaudeResponse {
|
||||
if info.ClaudeConvertInfo.Done {
|
||||
return nil
|
||||
@@ -391,13 +420,8 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
}
|
||||
if oaiUsage != nil {
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Type: "message_delta",
|
||||
Usage: &dto.ClaudeUsage{
|
||||
InputTokens: oaiUsage.PromptTokens,
|
||||
OutputTokens: oaiUsage.CompletionTokens,
|
||||
CacheCreationInputTokens: oaiUsage.PromptTokensDetails.CachedCreationTokens,
|
||||
CacheReadInputTokens: oaiUsage.PromptTokensDetails.CachedTokens,
|
||||
},
|
||||
Type: "message_delta",
|
||||
Usage: buildClaudeUsageFromOpenAIUsage(oaiUsage),
|
||||
Delta: &dto.ClaudeMediaMessage{
|
||||
StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(info.FinishReason)),
|
||||
},
|
||||
@@ -412,28 +436,28 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
}
|
||||
|
||||
if len(openAIResponse.Choices) == 0 {
|
||||
// no choices
|
||||
// 可能为非标准的 OpenAI 响应,判断是否已经完成
|
||||
if info.ClaudeConvertInfo.Done {
|
||||
// Some OpenAI-compatible upstreams end with a usage-only SSE chunk.
|
||||
oaiUsage := openAIResponse.Usage
|
||||
if oaiUsage == nil {
|
||||
oaiUsage = info.ClaudeConvertInfo.Usage
|
||||
}
|
||||
if oaiUsage != nil {
|
||||
stopOpenBlocks()
|
||||
oaiUsage := info.ClaudeConvertInfo.Usage
|
||||
if oaiUsage != nil {
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Type: "message_delta",
|
||||
Usage: &dto.ClaudeUsage{
|
||||
InputTokens: oaiUsage.PromptTokens,
|
||||
OutputTokens: oaiUsage.CompletionTokens,
|
||||
CacheCreationInputTokens: oaiUsage.PromptTokensDetails.CachedCreationTokens,
|
||||
CacheReadInputTokens: oaiUsage.PromptTokensDetails.CachedTokens,
|
||||
},
|
||||
Delta: &dto.ClaudeMediaMessage{
|
||||
StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(info.FinishReason)),
|
||||
},
|
||||
})
|
||||
stopReason := stopReasonOpenAI2Claude(info.FinishReason)
|
||||
if stopReason == "" {
|
||||
stopReason = "end_turn"
|
||||
}
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Type: "message_delta",
|
||||
Usage: buildClaudeUsageFromOpenAIUsage(oaiUsage),
|
||||
Delta: &dto.ClaudeMediaMessage{
|
||||
StopReason: common.GetPointer[string](stopReason),
|
||||
},
|
||||
})
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Type: "message_stop",
|
||||
})
|
||||
info.ClaudeConvertInfo.Done = true
|
||||
}
|
||||
return claudeResponses
|
||||
} else {
|
||||
@@ -441,6 +465,13 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
doneChunk := chosenChoice.FinishReason != nil && *chosenChoice.FinishReason != ""
|
||||
if doneChunk {
|
||||
info.FinishReason = *chosenChoice.FinishReason
|
||||
oaiUsage := openAIResponse.Usage
|
||||
if oaiUsage == nil {
|
||||
oaiUsage = info.ClaudeConvertInfo.Usage
|
||||
// Some upstreams emit finish_reason first, then send a final usage-only chunk.
|
||||
// Defer closing until usage is available so the final message_delta carries it.
|
||||
return claudeResponses
|
||||
}
|
||||
}
|
||||
|
||||
var claudeResponse dto.ClaudeResponse
|
||||
@@ -555,13 +586,8 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
}
|
||||
if oaiUsage != nil {
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Type: "message_delta",
|
||||
Usage: &dto.ClaudeUsage{
|
||||
InputTokens: oaiUsage.PromptTokens,
|
||||
OutputTokens: oaiUsage.CompletionTokens,
|
||||
CacheCreationInputTokens: oaiUsage.PromptTokensDetails.CachedCreationTokens,
|
||||
CacheReadInputTokens: oaiUsage.PromptTokensDetails.CachedTokens,
|
||||
},
|
||||
Type: "message_delta",
|
||||
Usage: buildClaudeUsageFromOpenAIUsage(oaiUsage),
|
||||
Delta: &dto.ClaudeMediaMessage{
|
||||
StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(info.FinishReason)),
|
||||
},
|
||||
@@ -612,10 +638,7 @@ func ResponseOpenAI2Claude(openAIResponse *dto.OpenAITextResponse, info *relayco
|
||||
}
|
||||
claudeResponse.Content = contents
|
||||
claudeResponse.StopReason = stopReason
|
||||
claudeResponse.Usage = &dto.ClaudeUsage{
|
||||
InputTokens: openAIResponse.PromptTokens,
|
||||
OutputTokens: openAIResponse.CompletionTokens,
|
||||
}
|
||||
claudeResponse.Usage = buildClaudeUsageFromOpenAIUsage(&openAIResponse.Usage)
|
||||
|
||||
return claudeResponse
|
||||
}
|
||||
|
||||
@@ -104,6 +104,11 @@ func GetFileTypeFromUrl(c *gin.Context, url string, reason ...string) (string, e
|
||||
return sniffed, nil
|
||||
}
|
||||
|
||||
// Try HEIF/HEIC detection (Go standard library doesn't recognize it)
|
||||
if heifMime := detectHEIF(readData); heifMime != "" {
|
||||
return heifMime, nil
|
||||
}
|
||||
|
||||
if _, format, err := image.DecodeConfig(bytes.NewReader(readData)); err == nil {
|
||||
switch strings.ToLower(format) {
|
||||
case "jpeg", "jpg":
|
||||
@@ -168,6 +173,10 @@ func GetMimeTypeByExtension(ext string) string {
|
||||
return "image/gif"
|
||||
case "jfif":
|
||||
return "image/jpeg"
|
||||
case "heic":
|
||||
return "image/heic"
|
||||
case "heif":
|
||||
return "image/heif"
|
||||
|
||||
// Audio files
|
||||
case "mp3":
|
||||
|
||||
+168
-32
@@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"image"
|
||||
_ "image/gif"
|
||||
@@ -24,14 +25,26 @@ import (
|
||||
// FileService 统一的文件处理服务
|
||||
// 提供文件下载、解码、缓存等功能的统一入口
|
||||
|
||||
// getContextCacheKey 生成 context 缓存的 key
|
||||
// getContextCacheKey 生成 URL context 缓存的 key
|
||||
func getContextCacheKey(url string) string {
|
||||
return fmt.Sprintf("file_cache_%s", common.GenerateHMAC(url))
|
||||
}
|
||||
|
||||
// getBase64ContextCacheKey 生成 base64 context 缓存的 key
|
||||
// 使用 length + MIME + 前 128 字符作为输入,避免对整个 base64 数据做 hash
|
||||
func getBase64ContextCacheKey(data string, mimeType string) string {
|
||||
keyMaterial := fmt.Sprintf("%d:%s:", len(data), mimeType)
|
||||
if len(data) > 128 {
|
||||
keyMaterial += data[:128]
|
||||
} else {
|
||||
keyMaterial += data
|
||||
}
|
||||
return fmt.Sprintf("b64_cache_%s", common.GenerateHMAC(keyMaterial))
|
||||
}
|
||||
|
||||
// LoadFileSource 加载文件源数据
|
||||
// 这是统一的入口,会自动处理缓存和不同的来源类型
|
||||
func LoadFileSource(c *gin.Context, source *types.FileSource, reason ...string) (*types.CachedFileData, error) {
|
||||
func LoadFileSource(c *gin.Context, source types.FileSource, reason ...string) (*types.CachedFileData, error) {
|
||||
if source == nil {
|
||||
return nil, fmt.Errorf("file source is nil")
|
||||
}
|
||||
@@ -42,7 +55,6 @@ func LoadFileSource(c *gin.Context, source *types.FileSource, reason ...string)
|
||||
|
||||
// 1. 快速检查内部缓存
|
||||
if source.HasCache() {
|
||||
// 即使命中内部缓存,也要确保注册到清理列表(如果尚未注册)
|
||||
if c != nil {
|
||||
registerSourceForCleanup(c, source)
|
||||
}
|
||||
@@ -61,39 +73,49 @@ func LoadFileSource(c *gin.Context, source *types.FileSource, reason ...string)
|
||||
return source.GetCache(), nil
|
||||
}
|
||||
|
||||
// 4. 如果是 URL,检查 Context 缓存
|
||||
var contextKey string
|
||||
if source.IsURL() && c != nil {
|
||||
contextKey = getContextCacheKey(source.URL)
|
||||
if cachedData, exists := c.Get(contextKey); exists {
|
||||
data := cachedData.(*types.CachedFileData)
|
||||
source.SetCache(data)
|
||||
registerSourceForCleanup(c, source)
|
||||
return data, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 5. 执行加载逻辑
|
||||
// 4. 根据来源类型加载(含 URL context 缓存查找)
|
||||
var cachedData *types.CachedFileData
|
||||
var contextKey string
|
||||
var err error
|
||||
|
||||
if source.IsURL() {
|
||||
cachedData, err = loadFromURL(c, source.URL, reason...)
|
||||
} else {
|
||||
cachedData, err = loadFromBase64(source.Base64Data, source.MimeType)
|
||||
switch s := source.(type) {
|
||||
case *types.URLSource:
|
||||
if c != nil {
|
||||
contextKey = getContextCacheKey(s.URL)
|
||||
if cached, exists := c.Get(contextKey); exists {
|
||||
data := cached.(*types.CachedFileData)
|
||||
source.SetCache(data)
|
||||
registerSourceForCleanup(c, source)
|
||||
return data, nil
|
||||
}
|
||||
}
|
||||
cachedData, err = loadFromURL(c, s.URL, reason...)
|
||||
case *types.Base64Source:
|
||||
if c != nil {
|
||||
contextKey = getBase64ContextCacheKey(s.Base64Data, s.MimeType)
|
||||
if cached, exists := c.Get(contextKey); exists {
|
||||
data := cached.(*types.CachedFileData)
|
||||
source.SetCache(data)
|
||||
registerSourceForCleanup(c, source)
|
||||
return data, nil
|
||||
}
|
||||
}
|
||||
cachedData, err = loadFromBase64(s.Base64Data, s.MimeType)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported file source type: %T", source)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 6. 设置缓存
|
||||
// 5. 设置缓存
|
||||
source.SetCache(cachedData)
|
||||
if contextKey != "" && c != nil {
|
||||
c.Set(contextKey, cachedData)
|
||||
}
|
||||
|
||||
// 7. 注册到 context 以便请求结束时自动清理
|
||||
// 6. 注册到 context 以便请求结束时自动清理
|
||||
if c != nil {
|
||||
registerSourceForCleanup(c, source)
|
||||
}
|
||||
@@ -102,15 +124,15 @@ func LoadFileSource(c *gin.Context, source *types.FileSource, reason ...string)
|
||||
}
|
||||
|
||||
// registerSourceForCleanup 注册 FileSource 到 context 以便请求结束时清理
|
||||
func registerSourceForCleanup(c *gin.Context, source *types.FileSource) {
|
||||
func registerSourceForCleanup(c *gin.Context, source types.FileSource) {
|
||||
if source.IsRegistered() {
|
||||
return
|
||||
}
|
||||
|
||||
key := string(constant.ContextKeyFileSourcesToCleanup)
|
||||
var sources []*types.FileSource
|
||||
var sources []types.FileSource
|
||||
if existing, exists := c.Get(key); exists {
|
||||
sources = existing.([]*types.FileSource)
|
||||
sources = existing.([]types.FileSource)
|
||||
}
|
||||
sources = append(sources, source)
|
||||
c.Set(key, sources)
|
||||
@@ -122,12 +144,12 @@ func registerSourceForCleanup(c *gin.Context, source *types.FileSource) {
|
||||
func CleanupFileSources(c *gin.Context) {
|
||||
key := string(constant.ContextKeyFileSourcesToCleanup)
|
||||
if sources, exists := c.Get(key); exists {
|
||||
for _, source := range sources.([]*types.FileSource) {
|
||||
for _, source := range sources.([]types.FileSource) {
|
||||
if cache := source.GetCache(); cache != nil {
|
||||
cache.Close()
|
||||
}
|
||||
}
|
||||
c.Set(key, nil) // 清除引用
|
||||
c.Set(key, nil)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -275,6 +297,11 @@ func smartDetectMimeType(resp *http.Response, url string, fileBytes []byte) stri
|
||||
}
|
||||
return sniffed
|
||||
}
|
||||
|
||||
// 4.5 尝试 HEIF/HEIC 检测(Go 标准库不识别)
|
||||
if heifMime := detectHEIF(fileBytes); heifMime != "" {
|
||||
return heifMime
|
||||
}
|
||||
}
|
||||
|
||||
// 5. 尝试作为图片解码获取格式
|
||||
@@ -357,7 +384,7 @@ func loadFromBase64(base64String string, providedMimeType string) (*types.Cached
|
||||
}
|
||||
|
||||
// GetImageConfig 获取图片配置
|
||||
func GetImageConfig(c *gin.Context, source *types.FileSource) (image.Config, string, error) {
|
||||
func GetImageConfig(c *gin.Context, source types.FileSource) (image.Config, string, error) {
|
||||
cachedData, err := LoadFileSource(c, source, "get_image_config")
|
||||
if err != nil {
|
||||
return image.Config{}, "", err
|
||||
@@ -388,7 +415,7 @@ func GetImageConfig(c *gin.Context, source *types.FileSource) (image.Config, str
|
||||
}
|
||||
|
||||
// GetBase64Data 获取 base64 编码的数据
|
||||
func GetBase64Data(c *gin.Context, source *types.FileSource, reason ...string) (string, string, error) {
|
||||
func GetBase64Data(c *gin.Context, source types.FileSource, reason ...string) (string, string, error) {
|
||||
cachedData, err := LoadFileSource(c, source, reason...)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
@@ -401,13 +428,13 @@ func GetBase64Data(c *gin.Context, source *types.FileSource, reason ...string) (
|
||||
}
|
||||
|
||||
// GetMimeType 获取文件的 MIME 类型
|
||||
func GetMimeType(c *gin.Context, source *types.FileSource) (string, error) {
|
||||
func GetMimeType(c *gin.Context, source types.FileSource) (string, error) {
|
||||
if source.HasCache() {
|
||||
return source.GetCache().MimeType, nil
|
||||
}
|
||||
|
||||
if source.IsURL() {
|
||||
mimeType, err := GetFileTypeFromUrl(c, source.URL, "get_mime_type")
|
||||
if urlSource, ok := source.(*types.URLSource); ok {
|
||||
mimeType, err := GetFileTypeFromUrl(c, urlSource.URL, "get_mime_type")
|
||||
if err == nil && mimeType != "" && mimeType != "application/octet-stream" {
|
||||
return mimeType, nil
|
||||
}
|
||||
@@ -449,9 +476,118 @@ func decodeImageConfig(data []byte) (image.Config, string, error) {
|
||||
return config, "webp", nil
|
||||
}
|
||||
|
||||
// Try HEIF/HEIC: parse ISOBMFF ispe box for dimensions
|
||||
if heifMime := detectHEIF(data); heifMime != "" {
|
||||
formatName := "heif"
|
||||
if heifMime == "image/heic" {
|
||||
formatName = "heic"
|
||||
}
|
||||
if w, h, ok := parseHEIFDimensions(data); ok {
|
||||
return image.Config{Width: w, Height: h}, formatName, nil
|
||||
}
|
||||
return image.Config{}, "", fmt.Errorf("failed to decode HEIF/HEIC image dimensions")
|
||||
}
|
||||
|
||||
return image.Config{}, "", fmt.Errorf("failed to decode image config: unsupported format")
|
||||
}
|
||||
|
||||
// detectHEIF checks ISOBMFF magic bytes to detect HEIC/HEIF files.
|
||||
// Returns "image/heic", "image/heif", or "" if not recognized.
|
||||
func detectHEIF(data []byte) string {
|
||||
if len(data) < 12 {
|
||||
return ""
|
||||
}
|
||||
// ISOBMFF: bytes[4:8] must be "ftyp"
|
||||
if string(data[4:8]) != "ftyp" {
|
||||
return ""
|
||||
}
|
||||
brand := string(data[8:12])
|
||||
switch brand {
|
||||
case "heic", "heix", "hevc", "hevx", "heim", "heis":
|
||||
return "image/heic"
|
||||
case "mif1", "msf1":
|
||||
return "image/heif"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// parseHEIFDimensions parses ISOBMFF box tree to find the ispe box
|
||||
// and extract image width/height. Returns (width, height, ok).
|
||||
func parseHEIFDimensions(data []byte) (int, int, bool) {
|
||||
size := len(data)
|
||||
if size < 12 {
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
// Walk top-level boxes to find "meta"
|
||||
offset := 0
|
||||
for offset+8 <= size {
|
||||
boxSize := int(binary.BigEndian.Uint32(data[offset : offset+4]))
|
||||
boxType := string(data[offset+4 : offset+8])
|
||||
headerLen := 8
|
||||
|
||||
if boxSize == 1 {
|
||||
// 64-bit extended size
|
||||
if offset+16 > size {
|
||||
break
|
||||
}
|
||||
boxSize = int(binary.BigEndian.Uint64(data[offset+8 : offset+16]))
|
||||
headerLen = 16
|
||||
} else if boxSize == 0 {
|
||||
// box extends to end of data
|
||||
boxSize = size - offset
|
||||
}
|
||||
|
||||
if boxSize < headerLen || offset+boxSize > size {
|
||||
break
|
||||
}
|
||||
|
||||
if boxType == "meta" {
|
||||
// meta is a full box: 4 bytes version/flags after header
|
||||
metaData := data[offset+headerLen : offset+boxSize]
|
||||
if len(metaData) < 4 {
|
||||
return 0, 0, false
|
||||
}
|
||||
return findISPE(metaData[4:])
|
||||
}
|
||||
offset += boxSize
|
||||
}
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
// findISPE recursively searches for the ispe box within container boxes.
|
||||
// Path: meta -> iprp -> ipco -> ispe
|
||||
func findISPE(data []byte) (int, int, bool) {
|
||||
offset := 0
|
||||
size := len(data)
|
||||
for offset+8 <= size {
|
||||
boxSize := int(binary.BigEndian.Uint32(data[offset : offset+4]))
|
||||
boxType := string(data[offset+4 : offset+8])
|
||||
if boxSize < 8 || offset+boxSize > size {
|
||||
break
|
||||
}
|
||||
content := data[offset+8 : offset+boxSize]
|
||||
switch boxType {
|
||||
case "iprp", "ipco":
|
||||
if w, h, ok := findISPE(content); ok {
|
||||
return w, h, true
|
||||
}
|
||||
case "ispe":
|
||||
// ispe is a full box: 4 bytes version/flags, then 4 bytes width, 4 bytes height
|
||||
if len(content) >= 12 {
|
||||
w := int(binary.BigEndian.Uint32(content[4:8]))
|
||||
h := int(binary.BigEndian.Uint32(content[8:12]))
|
||||
if w > 0 && h > 0 {
|
||||
return w, h, true
|
||||
}
|
||||
}
|
||||
}
|
||||
offset += boxSize
|
||||
}
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
// guessMimeTypeFromURL 从 URL 猜测 MIME 类型
|
||||
func guessMimeTypeFromURL(url string) string {
|
||||
cleanedURL := url
|
||||
|
||||
+29
-13
@@ -159,20 +159,36 @@ func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
|
||||
}
|
||||
|
||||
func getImageConfig(reader io.Reader) (image.Config, string, error) {
|
||||
// Read all data so we can retry with different decoders
|
||||
data, readErr := io.ReadAll(reader)
|
||||
if readErr != nil {
|
||||
return image.Config{}, "", fmt.Errorf("failed to read image data: %w", readErr)
|
||||
}
|
||||
|
||||
// 读取图片的头部信息来获取图片尺寸
|
||||
config, format, err := image.DecodeConfig(reader)
|
||||
if err != nil {
|
||||
err = errors.New(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error()))
|
||||
common.SysLog(err.Error())
|
||||
config, err = webp.DecodeConfig(reader)
|
||||
if err != nil {
|
||||
err = errors.New(fmt.Sprintf("fail to decode image config(webp): %s", err.Error()))
|
||||
common.SysLog(err.Error())
|
||||
config, format, err := image.DecodeConfig(bytes.NewReader(data))
|
||||
if err == nil {
|
||||
return config, format, nil
|
||||
}
|
||||
common.SysLog(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error()))
|
||||
|
||||
config, err = webp.DecodeConfig(bytes.NewReader(data))
|
||||
if err == nil {
|
||||
return config, "webp", nil
|
||||
}
|
||||
common.SysLog(fmt.Sprintf("fail to decode image config(webp): %s", err.Error()))
|
||||
|
||||
// Try HEIF/HEIC: parse ISOBMFF ispe box for dimensions
|
||||
if heifMime := detectHEIF(data); heifMime != "" {
|
||||
formatName := "heif"
|
||||
if heifMime == "image/heic" {
|
||||
formatName = "heic"
|
||||
}
|
||||
format = "webp"
|
||||
if w, h, ok := parseHEIFDimensions(data); ok {
|
||||
return image.Config{Width: w, Height: h}, formatName, nil
|
||||
}
|
||||
return image.Config{}, "", fmt.Errorf("failed to decode HEIF/HEIC image dimensions")
|
||||
}
|
||||
if err != nil {
|
||||
return image.Config{}, "", err
|
||||
}
|
||||
return config, format, nil
|
||||
|
||||
return image.Config{}, "", err
|
||||
}
|
||||
|
||||
@@ -73,8 +73,10 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m
|
||||
other["admin_info"] = adminInfo
|
||||
appendRequestPath(ctx, relayInfo, other)
|
||||
appendRequestConversionChain(relayInfo, other)
|
||||
appendFinalRequestFormat(relayInfo, other)
|
||||
appendBillingInfo(relayInfo, other)
|
||||
appendParamOverrideInfo(relayInfo, other)
|
||||
appendStreamStatus(relayInfo, other)
|
||||
return other
|
||||
}
|
||||
|
||||
@@ -85,6 +87,33 @@ func appendParamOverrideInfo(relayInfo *relaycommon.RelayInfo, other map[string]
|
||||
other["po"] = relayInfo.ParamOverrideAudit
|
||||
}
|
||||
|
||||
func appendStreamStatus(relayInfo *relaycommon.RelayInfo, other map[string]interface{}) {
|
||||
if relayInfo == nil || other == nil || !relayInfo.IsStream || relayInfo.StreamStatus == nil {
|
||||
return
|
||||
}
|
||||
ss := relayInfo.StreamStatus
|
||||
status := "ok"
|
||||
if !ss.IsNormalEnd() || ss.HasErrors() {
|
||||
status = "error"
|
||||
}
|
||||
streamInfo := map[string]interface{}{
|
||||
"status": status,
|
||||
"end_reason": string(ss.EndReason),
|
||||
}
|
||||
if ss.EndError != nil {
|
||||
streamInfo["end_error"] = ss.EndError.Error()
|
||||
}
|
||||
if ss.ErrorCount > 0 {
|
||||
streamInfo["error_count"] = ss.ErrorCount
|
||||
messages := make([]string, 0, len(ss.Errors))
|
||||
for _, e := range ss.Errors {
|
||||
messages = append(messages, e.Message)
|
||||
}
|
||||
streamInfo["errors"] = messages
|
||||
}
|
||||
other["stream_status"] = streamInfo
|
||||
}
|
||||
|
||||
func appendBillingInfo(relayInfo *relaycommon.RelayInfo, other map[string]interface{}) {
|
||||
if relayInfo == nil || other == nil {
|
||||
return
|
||||
@@ -167,6 +196,17 @@ func appendRequestConversionChain(relayInfo *relaycommon.RelayInfo, other map[st
|
||||
other["request_conversion"] = chain
|
||||
}
|
||||
|
||||
func appendFinalRequestFormat(relayInfo *relaycommon.RelayInfo, other map[string]interface{}) {
|
||||
if relayInfo == nil || other == nil {
|
||||
return
|
||||
}
|
||||
if relayInfo.GetFinalRequestRelayFormat() == types.RelayFormatClaude {
|
||||
// claude indicates the final upstream request format is Claude Messages.
|
||||
// Frontend log rendering uses this to keep the original Claude input display.
|
||||
other["claude"] = true
|
||||
}
|
||||
}
|
||||
|
||||
func GenerateWssOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice, userGroupRatio float64) map[string]interface{} {
|
||||
info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, 0, 0.0, modelPrice, userGroupRatio)
|
||||
info["ws"] = true
|
||||
|
||||
@@ -235,108 +235,6 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
|
||||
})
|
||||
}
|
||||
|
||||
func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage) {
|
||||
if usage != nil {
|
||||
ObserveChannelAffinityUsageCacheByRelayFormat(ctx, usage, relayInfo.GetFinalRequestRelayFormat())
|
||||
}
|
||||
|
||||
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
|
||||
promptTokens := usage.PromptTokens
|
||||
completionTokens := usage.CompletionTokens
|
||||
modelName := relayInfo.OriginModelName
|
||||
|
||||
tokenName := ctx.GetString("token_name")
|
||||
completionRatio := relayInfo.PriceData.CompletionRatio
|
||||
modelRatio := relayInfo.PriceData.ModelRatio
|
||||
groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
|
||||
modelPrice := relayInfo.PriceData.ModelPrice
|
||||
cacheRatio := relayInfo.PriceData.CacheRatio
|
||||
cacheTokens := usage.PromptTokensDetails.CachedTokens
|
||||
|
||||
cacheCreationRatio := relayInfo.PriceData.CacheCreationRatio
|
||||
cacheCreationRatio5m := relayInfo.PriceData.CacheCreation5mRatio
|
||||
cacheCreationRatio1h := relayInfo.PriceData.CacheCreation1hRatio
|
||||
cacheCreationTokens := usage.PromptTokensDetails.CachedCreationTokens
|
||||
cacheCreationTokens5m := usage.ClaudeCacheCreation5mTokens
|
||||
cacheCreationTokens1h := usage.ClaudeCacheCreation1hTokens
|
||||
|
||||
if relayInfo.ChannelType == constant.ChannelTypeOpenRouter {
|
||||
promptTokens -= cacheTokens
|
||||
isUsingCustomSettings := relayInfo.PriceData.UsePrice || hasCustomModelRatio(modelName, relayInfo.PriceData.ModelRatio)
|
||||
if cacheCreationTokens == 0 && relayInfo.PriceData.CacheCreationRatio != 1 && usage.Cost != 0 && !isUsingCustomSettings {
|
||||
maybeCacheCreationTokens := CalcOpenRouterCacheCreateTokens(*usage, relayInfo.PriceData)
|
||||
if maybeCacheCreationTokens >= 0 && promptTokens >= maybeCacheCreationTokens {
|
||||
cacheCreationTokens = maybeCacheCreationTokens
|
||||
}
|
||||
}
|
||||
promptTokens -= cacheCreationTokens
|
||||
}
|
||||
|
||||
calculateQuota := 0.0
|
||||
if !relayInfo.PriceData.UsePrice {
|
||||
calculateQuota = float64(promptTokens)
|
||||
calculateQuota += float64(cacheTokens) * cacheRatio
|
||||
calculateQuota += float64(cacheCreationTokens5m) * cacheCreationRatio5m
|
||||
calculateQuota += float64(cacheCreationTokens1h) * cacheCreationRatio1h
|
||||
remainingCacheCreationTokens := cacheCreationTokens - cacheCreationTokens5m - cacheCreationTokens1h
|
||||
if remainingCacheCreationTokens > 0 {
|
||||
calculateQuota += float64(remainingCacheCreationTokens) * cacheCreationRatio
|
||||
}
|
||||
calculateQuota += float64(completionTokens) * completionRatio
|
||||
calculateQuota = calculateQuota * groupRatio * modelRatio
|
||||
} else {
|
||||
calculateQuota = modelPrice * common.QuotaPerUnit * groupRatio
|
||||
}
|
||||
|
||||
if modelRatio != 0 && calculateQuota <= 0 {
|
||||
calculateQuota = 1
|
||||
}
|
||||
|
||||
quota := int(calculateQuota)
|
||||
|
||||
totalTokens := promptTokens + completionTokens
|
||||
|
||||
var logContent string
|
||||
// record all the consume log even if quota is 0
|
||||
if totalTokens == 0 {
|
||||
// in this case, must be some error happened
|
||||
// we cannot just return, because we may have to return the pre-consumed quota
|
||||
quota = 0
|
||||
logContent += fmt.Sprintf("(可能是上游出错)")
|
||||
logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
|
||||
"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota))
|
||||
} else {
|
||||
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
|
||||
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
|
||||
}
|
||||
|
||||
if err := SettleBilling(ctx, relayInfo, quota); err != nil {
|
||||
logger.LogError(ctx, "error settling billing: "+err.Error())
|
||||
}
|
||||
|
||||
other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio,
|
||||
cacheTokens, cacheRatio,
|
||||
cacheCreationTokens, cacheCreationRatio,
|
||||
cacheCreationTokens5m, cacheCreationRatio5m,
|
||||
cacheCreationTokens1h, cacheCreationRatio1h,
|
||||
modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
|
||||
ChannelId: relayInfo.ChannelId,
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
ModelName: modelName,
|
||||
TokenName: tokenName,
|
||||
Quota: quota,
|
||||
Content: logContent,
|
||||
TokenId: relayInfo.TokenId,
|
||||
UseTimeSeconds: int(useTimeSeconds),
|
||||
IsStream: relayInfo.IsStream,
|
||||
Group: relayInfo.UsingGroup,
|
||||
Other: other,
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData types.PriceData) int {
|
||||
if priceData.CacheCreationRatio == 1 {
|
||||
return 0
|
||||
|
||||
+20
-4
@@ -36,8 +36,12 @@ func LogTaskConsumption(c *gin.Context, info *relaycommon.RelayInfo) {
|
||||
}
|
||||
}
|
||||
other := make(map[string]interface{})
|
||||
other["is_task"] = true
|
||||
other["request_path"] = c.Request.URL.Path
|
||||
other["model_price"] = info.PriceData.ModelPrice
|
||||
if info.PriceData.ModelRatio > 0 {
|
||||
other["model_ratio"] = info.PriceData.ModelRatio
|
||||
}
|
||||
other["group_ratio"] = info.PriceData.GroupRatioInfo.GroupRatio
|
||||
if info.PriceData.GroupRatioInfo.HasSpecialRatio {
|
||||
other["user_group_ratio"] = info.PriceData.GroupRatioInfo.GroupSpecialRatio
|
||||
@@ -117,6 +121,9 @@ func taskBillingOther(task *model.Task) map[string]interface{} {
|
||||
other := make(map[string]interface{})
|
||||
if bc := task.PrivateData.BillingContext; bc != nil {
|
||||
other["model_price"] = bc.ModelPrice
|
||||
if bc.ModelRatio > 0 {
|
||||
other["model_ratio"] = bc.ModelRatio
|
||||
}
|
||||
other["group_ratio"] = bc.GroupRatio
|
||||
if len(bc.OtherRatios) > 0 {
|
||||
for k, v := range bc.OtherRatios {
|
||||
@@ -222,7 +229,6 @@ func RecalculateTaskQuota(ctx context.Context, task *model.Task, actualQuota int
|
||||
}
|
||||
other := taskBillingOther(task)
|
||||
other["task_id"] = task.TaskID
|
||||
//other["reason"] = reason
|
||||
other["pre_consumed_quota"] = preConsumedQuota
|
||||
other["actual_quota"] = actualQuota
|
||||
model.RecordTaskBillingLog(model.RecordTaskBillingLogParams{
|
||||
@@ -277,9 +283,19 @@ func RecalculateTaskQuotaByTokens(ctx context.Context, task *model.Task, totalTo
|
||||
finalGroupRatio = groupRatio
|
||||
}
|
||||
|
||||
// 计算实际应扣费额度: totalTokens * modelRatio * groupRatio
|
||||
actualQuota := int(float64(totalTokens) * modelRatio * finalGroupRatio)
|
||||
// 计算 OtherRatios 乘积(视频折扣、时长等)
|
||||
otherMultiplier := 1.0
|
||||
if bc := task.PrivateData.BillingContext; bc != nil {
|
||||
for _, r := range bc.OtherRatios {
|
||||
if r != 1.0 && r > 0 {
|
||||
otherMultiplier *= r
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
reason := fmt.Sprintf("token重算:tokens=%d, modelRatio=%.2f, groupRatio=%.2f", totalTokens, modelRatio, finalGroupRatio)
|
||||
// 计算实际应扣费额度: totalTokens * modelRatio * groupRatio * otherMultiplier
|
||||
actualQuota := int(float64(totalTokens) * modelRatio * finalGroupRatio * otherMultiplier)
|
||||
|
||||
reason := fmt.Sprintf("token重算:tokens=%d, modelRatio=%.2f, groupRatio=%.2f, otherMultiplier=%.4f", totalTokens, modelRatio, finalGroupRatio, otherMultiplier)
|
||||
RecalculateTaskQuota(ctx, task, actualQuota, reason)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,430 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/shopspring/decimal"
|
||||
)
|
||||
|
||||
type textQuotaSummary struct {
|
||||
PromptTokens int
|
||||
CompletionTokens int
|
||||
TotalTokens int
|
||||
CacheTokens int
|
||||
CacheCreationTokens int
|
||||
CacheCreationTokens5m int
|
||||
CacheCreationTokens1h int
|
||||
ImageTokens int
|
||||
AudioTokens int
|
||||
ModelName string
|
||||
TokenName string
|
||||
UseTimeSeconds int64
|
||||
CompletionRatio float64
|
||||
CacheRatio float64
|
||||
ImageRatio float64
|
||||
ModelRatio float64
|
||||
GroupRatio float64
|
||||
ModelPrice float64
|
||||
CacheCreationRatio float64
|
||||
CacheCreationRatio5m float64
|
||||
CacheCreationRatio1h float64
|
||||
Quota int
|
||||
IsClaudeUsageSemantic bool
|
||||
UsageSemantic string
|
||||
WebSearchPrice float64
|
||||
WebSearchCallCount int
|
||||
ClaudeWebSearchPrice float64
|
||||
ClaudeWebSearchCallCount int
|
||||
FileSearchPrice float64
|
||||
FileSearchCallCount int
|
||||
AudioInputPrice float64
|
||||
ImageGenerationCallPrice float64
|
||||
}
|
||||
|
||||
func cacheWriteTokensTotal(summary textQuotaSummary) int {
|
||||
if summary.CacheCreationTokens5m > 0 || summary.CacheCreationTokens1h > 0 {
|
||||
splitCacheWriteTokens := summary.CacheCreationTokens5m + summary.CacheCreationTokens1h
|
||||
if summary.CacheCreationTokens > splitCacheWriteTokens {
|
||||
return summary.CacheCreationTokens
|
||||
}
|
||||
return splitCacheWriteTokens
|
||||
}
|
||||
return summary.CacheCreationTokens
|
||||
}
|
||||
|
||||
func isLegacyClaudeDerivedOpenAIUsage(relayInfo *relaycommon.RelayInfo, usage *dto.Usage) bool {
|
||||
if relayInfo == nil || usage == nil {
|
||||
return false
|
||||
}
|
||||
if relayInfo.GetFinalRequestRelayFormat() == types.RelayFormatClaude {
|
||||
return false
|
||||
}
|
||||
if usage.UsageSource != "" || usage.UsageSemantic != "" {
|
||||
return false
|
||||
}
|
||||
return usage.ClaudeCacheCreation5mTokens > 0 || usage.ClaudeCacheCreation1hTokens > 0
|
||||
}
|
||||
|
||||
func calculateTextQuotaSummary(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage) textQuotaSummary {
|
||||
summary := textQuotaSummary{
|
||||
ModelName: relayInfo.OriginModelName,
|
||||
TokenName: ctx.GetString("token_name"),
|
||||
UseTimeSeconds: time.Now().Unix() - relayInfo.StartTime.Unix(),
|
||||
CompletionRatio: relayInfo.PriceData.CompletionRatio,
|
||||
CacheRatio: relayInfo.PriceData.CacheRatio,
|
||||
ImageRatio: relayInfo.PriceData.ImageRatio,
|
||||
ModelRatio: relayInfo.PriceData.ModelRatio,
|
||||
GroupRatio: relayInfo.PriceData.GroupRatioInfo.GroupRatio,
|
||||
ModelPrice: relayInfo.PriceData.ModelPrice,
|
||||
CacheCreationRatio: relayInfo.PriceData.CacheCreationRatio,
|
||||
CacheCreationRatio5m: relayInfo.PriceData.CacheCreation5mRatio,
|
||||
CacheCreationRatio1h: relayInfo.PriceData.CacheCreation1hRatio,
|
||||
UsageSemantic: usageSemanticFromUsage(relayInfo, usage),
|
||||
}
|
||||
summary.IsClaudeUsageSemantic = summary.UsageSemantic == "anthropic"
|
||||
|
||||
if usage == nil {
|
||||
usage = &dto.Usage{
|
||||
PromptTokens: relayInfo.GetEstimatePromptTokens(),
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: relayInfo.GetEstimatePromptTokens(),
|
||||
}
|
||||
}
|
||||
|
||||
summary.PromptTokens = usage.PromptTokens
|
||||
summary.CompletionTokens = usage.CompletionTokens
|
||||
summary.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
summary.CacheTokens = usage.PromptTokensDetails.CachedTokens
|
||||
summary.CacheCreationTokens = usage.PromptTokensDetails.CachedCreationTokens
|
||||
summary.CacheCreationTokens5m = usage.ClaudeCacheCreation5mTokens
|
||||
summary.CacheCreationTokens1h = usage.ClaudeCacheCreation1hTokens
|
||||
summary.ImageTokens = usage.PromptTokensDetails.ImageTokens
|
||||
summary.AudioTokens = usage.PromptTokensDetails.AudioTokens
|
||||
legacyClaudeDerived := isLegacyClaudeDerivedOpenAIUsage(relayInfo, usage)
|
||||
isOpenRouterClaudeBilling := relayInfo.ChannelMeta != nil &&
|
||||
relayInfo.ChannelType == constant.ChannelTypeOpenRouter &&
|
||||
summary.IsClaudeUsageSemantic
|
||||
|
||||
if isOpenRouterClaudeBilling {
|
||||
summary.PromptTokens -= summary.CacheTokens
|
||||
isUsingCustomSettings := relayInfo.PriceData.UsePrice || hasCustomModelRatio(summary.ModelName, relayInfo.PriceData.ModelRatio)
|
||||
if summary.CacheCreationTokens == 0 && relayInfo.PriceData.CacheCreationRatio != 1 && usage.Cost != 0 && !isUsingCustomSettings {
|
||||
maybeCacheCreationTokens := CalcOpenRouterCacheCreateTokens(*usage, relayInfo.PriceData)
|
||||
if maybeCacheCreationTokens >= 0 && summary.PromptTokens >= maybeCacheCreationTokens {
|
||||
summary.CacheCreationTokens = maybeCacheCreationTokens
|
||||
}
|
||||
}
|
||||
summary.PromptTokens -= summary.CacheCreationTokens
|
||||
}
|
||||
|
||||
dPromptTokens := decimal.NewFromInt(int64(summary.PromptTokens))
|
||||
dCacheTokens := decimal.NewFromInt(int64(summary.CacheTokens))
|
||||
dImageTokens := decimal.NewFromInt(int64(summary.ImageTokens))
|
||||
dAudioTokens := decimal.NewFromInt(int64(summary.AudioTokens))
|
||||
dCompletionTokens := decimal.NewFromInt(int64(summary.CompletionTokens))
|
||||
dCachedCreationTokens := decimal.NewFromInt(int64(summary.CacheCreationTokens))
|
||||
dCompletionRatio := decimal.NewFromFloat(summary.CompletionRatio)
|
||||
dCacheRatio := decimal.NewFromFloat(summary.CacheRatio)
|
||||
dImageRatio := decimal.NewFromFloat(summary.ImageRatio)
|
||||
dModelRatio := decimal.NewFromFloat(summary.ModelRatio)
|
||||
dGroupRatio := decimal.NewFromFloat(summary.GroupRatio)
|
||||
dModelPrice := decimal.NewFromFloat(summary.ModelPrice)
|
||||
dCacheCreationRatio := decimal.NewFromFloat(summary.CacheCreationRatio)
|
||||
dCacheCreationRatio5m := decimal.NewFromFloat(summary.CacheCreationRatio5m)
|
||||
dCacheCreationRatio1h := decimal.NewFromFloat(summary.CacheCreationRatio1h)
|
||||
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
|
||||
|
||||
ratio := dModelRatio.Mul(dGroupRatio)
|
||||
|
||||
var dWebSearchQuota decimal.Decimal
|
||||
if relayInfo.ResponsesUsageInfo != nil {
|
||||
if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool.CallCount > 0 {
|
||||
summary.WebSearchCallCount = webSearchTool.CallCount
|
||||
summary.WebSearchPrice = operation_setting.GetWebSearchPricePerThousand(summary.ModelName, webSearchTool.SearchContextSize)
|
||||
dWebSearchQuota = decimal.NewFromFloat(summary.WebSearchPrice).
|
||||
Mul(decimal.NewFromInt(int64(webSearchTool.CallCount))).
|
||||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
}
|
||||
} else if strings.HasSuffix(summary.ModelName, "search-preview") {
|
||||
searchContextSize := ctx.GetString("chat_completion_web_search_context_size")
|
||||
if searchContextSize == "" {
|
||||
searchContextSize = "medium"
|
||||
}
|
||||
summary.WebSearchCallCount = 1
|
||||
summary.WebSearchPrice = operation_setting.GetWebSearchPricePerThousand(summary.ModelName, searchContextSize)
|
||||
dWebSearchQuota = decimal.NewFromFloat(summary.WebSearchPrice).
|
||||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
}
|
||||
|
||||
var dClaudeWebSearchQuota decimal.Decimal
|
||||
summary.ClaudeWebSearchCallCount = ctx.GetInt("claude_web_search_requests")
|
||||
if summary.ClaudeWebSearchCallCount > 0 {
|
||||
summary.ClaudeWebSearchPrice = operation_setting.GetClaudeWebSearchPricePerThousand()
|
||||
dClaudeWebSearchQuota = decimal.NewFromFloat(summary.ClaudeWebSearchPrice).
|
||||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit).
|
||||
Mul(decimal.NewFromInt(int64(summary.ClaudeWebSearchCallCount)))
|
||||
}
|
||||
|
||||
var dFileSearchQuota decimal.Decimal
|
||||
if relayInfo.ResponsesUsageInfo != nil {
|
||||
if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists && fileSearchTool.CallCount > 0 {
|
||||
summary.FileSearchCallCount = fileSearchTool.CallCount
|
||||
summary.FileSearchPrice = operation_setting.GetFileSearchPricePerThousand()
|
||||
dFileSearchQuota = decimal.NewFromFloat(summary.FileSearchPrice).
|
||||
Mul(decimal.NewFromInt(int64(fileSearchTool.CallCount))).
|
||||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
}
|
||||
}
|
||||
|
||||
var dImageGenerationCallQuota decimal.Decimal
|
||||
if ctx.GetBool("image_generation_call") {
|
||||
summary.ImageGenerationCallPrice = operation_setting.GetGPTImage1PriceOnceCall(ctx.GetString("image_generation_call_quality"), ctx.GetString("image_generation_call_size"))
|
||||
dImageGenerationCallQuota = decimal.NewFromFloat(summary.ImageGenerationCallPrice).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
}
|
||||
|
||||
var audioInputQuota decimal.Decimal
|
||||
if !relayInfo.PriceData.UsePrice {
|
||||
baseTokens := dPromptTokens
|
||||
|
||||
var cachedTokensWithRatio decimal.Decimal
|
||||
if !dCacheTokens.IsZero() {
|
||||
if !summary.IsClaudeUsageSemantic && !legacyClaudeDerived {
|
||||
baseTokens = baseTokens.Sub(dCacheTokens)
|
||||
}
|
||||
cachedTokensWithRatio = dCacheTokens.Mul(dCacheRatio)
|
||||
}
|
||||
|
||||
var cachedCreationTokensWithRatio decimal.Decimal
|
||||
hasSplitCacheCreationTokens := summary.CacheCreationTokens5m > 0 || summary.CacheCreationTokens1h > 0
|
||||
if !dCachedCreationTokens.IsZero() || hasSplitCacheCreationTokens {
|
||||
if !summary.IsClaudeUsageSemantic && !legacyClaudeDerived {
|
||||
baseTokens = baseTokens.Sub(dCachedCreationTokens)
|
||||
cachedCreationTokensWithRatio = dCachedCreationTokens.Mul(dCacheCreationRatio)
|
||||
} else {
|
||||
remaining := summary.CacheCreationTokens - summary.CacheCreationTokens5m - summary.CacheCreationTokens1h
|
||||
if remaining < 0 {
|
||||
remaining = 0
|
||||
}
|
||||
cachedCreationTokensWithRatio = decimal.NewFromInt(int64(remaining)).Mul(dCacheCreationRatio)
|
||||
cachedCreationTokensWithRatio = cachedCreationTokensWithRatio.Add(decimal.NewFromInt(int64(summary.CacheCreationTokens5m)).Mul(dCacheCreationRatio5m))
|
||||
cachedCreationTokensWithRatio = cachedCreationTokensWithRatio.Add(decimal.NewFromInt(int64(summary.CacheCreationTokens1h)).Mul(dCacheCreationRatio1h))
|
||||
}
|
||||
}
|
||||
|
||||
var imageTokensWithRatio decimal.Decimal
|
||||
if !dImageTokens.IsZero() {
|
||||
baseTokens = baseTokens.Sub(dImageTokens)
|
||||
imageTokensWithRatio = dImageTokens.Mul(dImageRatio)
|
||||
}
|
||||
|
||||
if !dAudioTokens.IsZero() {
|
||||
summary.AudioInputPrice = operation_setting.GetGeminiInputAudioPricePerMillionTokens(summary.ModelName)
|
||||
if summary.AudioInputPrice > 0 {
|
||||
baseTokens = baseTokens.Sub(dAudioTokens)
|
||||
audioInputQuota = decimal.NewFromFloat(summary.AudioInputPrice).
|
||||
Div(decimal.NewFromInt(1000000)).Mul(dAudioTokens).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
}
|
||||
}
|
||||
|
||||
promptQuota := baseTokens.Add(cachedTokensWithRatio).Add(imageTokensWithRatio).Add(cachedCreationTokensWithRatio)
|
||||
completionQuota := dCompletionTokens.Mul(dCompletionRatio)
|
||||
quotaCalculateDecimal := promptQuota.Add(completionQuota).Mul(ratio)
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dWebSearchQuota)
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dClaudeWebSearchQuota)
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota)
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota)
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dImageGenerationCallQuota)
|
||||
|
||||
if len(relayInfo.PriceData.OtherRatios) > 0 {
|
||||
for _, otherRatio := range relayInfo.PriceData.OtherRatios {
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Mul(decimal.NewFromFloat(otherRatio))
|
||||
}
|
||||
}
|
||||
|
||||
if !ratio.IsZero() && quotaCalculateDecimal.LessThanOrEqual(decimal.Zero) {
|
||||
quotaCalculateDecimal = decimal.NewFromInt(1)
|
||||
}
|
||||
summary.Quota = int(quotaCalculateDecimal.Round(0).IntPart())
|
||||
} else {
|
||||
quotaCalculateDecimal := dModelPrice.Mul(dQuotaPerUnit).Mul(dGroupRatio)
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dWebSearchQuota)
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dClaudeWebSearchQuota)
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota)
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota)
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dImageGenerationCallQuota)
|
||||
if len(relayInfo.PriceData.OtherRatios) > 0 {
|
||||
for _, otherRatio := range relayInfo.PriceData.OtherRatios {
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Mul(decimal.NewFromFloat(otherRatio))
|
||||
}
|
||||
}
|
||||
summary.Quota = int(quotaCalculateDecimal.Round(0).IntPart())
|
||||
}
|
||||
|
||||
if summary.TotalTokens == 0 {
|
||||
summary.Quota = 0
|
||||
} else if !ratio.IsZero() && summary.Quota == 0 {
|
||||
summary.Quota = 1
|
||||
}
|
||||
|
||||
return summary
|
||||
}
|
||||
|
||||
func usageSemanticFromUsage(relayInfo *relaycommon.RelayInfo, usage *dto.Usage) string {
|
||||
if usage != nil && usage.UsageSemantic != "" {
|
||||
return usage.UsageSemantic
|
||||
}
|
||||
if relayInfo != nil && relayInfo.GetFinalRequestRelayFormat() == types.RelayFormatClaude {
|
||||
return "anthropic"
|
||||
}
|
||||
return "openai"
|
||||
}
|
||||
|
||||
func PostTextConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent []string) {
|
||||
originUsage := usage
|
||||
if usage == nil {
|
||||
extraContent = append(extraContent, "上游无计费信息")
|
||||
}
|
||||
if originUsage != nil {
|
||||
ObserveChannelAffinityUsageCacheByRelayFormat(ctx, usage, relayInfo.GetFinalRequestRelayFormat())
|
||||
}
|
||||
|
||||
adminRejectReason := common.GetContextKeyString(ctx, constant.ContextKeyAdminRejectReason)
|
||||
summary := calculateTextQuotaSummary(ctx, relayInfo, usage)
|
||||
|
||||
if summary.WebSearchCallCount > 0 {
|
||||
extraContent = append(extraContent, fmt.Sprintf("Web Search 调用 %d 次,调用花费 %s", summary.WebSearchCallCount, decimal.NewFromFloat(summary.WebSearchPrice).Mul(decimal.NewFromInt(int64(summary.WebSearchCallCount))).Div(decimal.NewFromInt(1000)).Mul(decimal.NewFromFloat(summary.GroupRatio)).Mul(decimal.NewFromFloat(common.QuotaPerUnit)).String()))
|
||||
}
|
||||
if summary.ClaudeWebSearchCallCount > 0 {
|
||||
extraContent = append(extraContent, fmt.Sprintf("Claude Web Search 调用 %d 次,调用花费 %s", summary.ClaudeWebSearchCallCount, decimal.NewFromFloat(summary.ClaudeWebSearchPrice).Div(decimal.NewFromInt(1000)).Mul(decimal.NewFromFloat(summary.GroupRatio)).Mul(decimal.NewFromFloat(common.QuotaPerUnit)).Mul(decimal.NewFromInt(int64(summary.ClaudeWebSearchCallCount))).String()))
|
||||
}
|
||||
if summary.FileSearchCallCount > 0 {
|
||||
extraContent = append(extraContent, fmt.Sprintf("File Search 调用 %d 次,调用花费 %s", summary.FileSearchCallCount, decimal.NewFromFloat(summary.FileSearchPrice).Mul(decimal.NewFromInt(int64(summary.FileSearchCallCount))).Div(decimal.NewFromInt(1000)).Mul(decimal.NewFromFloat(summary.GroupRatio)).Mul(decimal.NewFromFloat(common.QuotaPerUnit)).String()))
|
||||
}
|
||||
if summary.AudioInputPrice > 0 && summary.AudioTokens > 0 {
|
||||
extraContent = append(extraContent, fmt.Sprintf("Audio Input 花费 %s", decimal.NewFromFloat(summary.AudioInputPrice).Div(decimal.NewFromInt(1000000)).Mul(decimal.NewFromInt(int64(summary.AudioTokens))).Mul(decimal.NewFromFloat(summary.GroupRatio)).Mul(decimal.NewFromFloat(common.QuotaPerUnit)).String()))
|
||||
}
|
||||
if summary.ImageGenerationCallPrice > 0 {
|
||||
extraContent = append(extraContent, fmt.Sprintf("Image Generation Call 花费 %s", decimal.NewFromFloat(summary.ImageGenerationCallPrice).Mul(decimal.NewFromFloat(summary.GroupRatio)).Mul(decimal.NewFromFloat(common.QuotaPerUnit)).String()))
|
||||
}
|
||||
|
||||
if summary.TotalTokens == 0 {
|
||||
extraContent = append(extraContent, "上游没有返回计费信息,无法扣费(可能是上游超时)")
|
||||
logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, summary.ModelName, relayInfo.FinalPreConsumedQuota))
|
||||
} else {
|
||||
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, summary.Quota)
|
||||
model.UpdateChannelUsedQuota(relayInfo.ChannelId, summary.Quota)
|
||||
}
|
||||
|
||||
if err := SettleBilling(ctx, relayInfo, summary.Quota); err != nil {
|
||||
logger.LogError(ctx, "error settling billing: "+err.Error())
|
||||
}
|
||||
|
||||
logModel := summary.ModelName
|
||||
if strings.HasPrefix(logModel, "gpt-4-gizmo") {
|
||||
logModel = "gpt-4-gizmo-*"
|
||||
extraContent = append(extraContent, fmt.Sprintf("模型 %s", summary.ModelName))
|
||||
}
|
||||
if strings.HasPrefix(logModel, "gpt-4o-gizmo") {
|
||||
logModel = "gpt-4o-gizmo-*"
|
||||
extraContent = append(extraContent, fmt.Sprintf("模型 %s", summary.ModelName))
|
||||
}
|
||||
|
||||
logContent := strings.Join(extraContent, ", ")
|
||||
var other map[string]interface{}
|
||||
if summary.IsClaudeUsageSemantic {
|
||||
other = GenerateClaudeOtherInfo(ctx, relayInfo,
|
||||
summary.ModelRatio, summary.GroupRatio, summary.CompletionRatio,
|
||||
summary.CacheTokens, summary.CacheRatio,
|
||||
summary.CacheCreationTokens, summary.CacheCreationRatio,
|
||||
summary.CacheCreationTokens5m, summary.CacheCreationRatio5m,
|
||||
summary.CacheCreationTokens1h, summary.CacheCreationRatio1h,
|
||||
summary.ModelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
other["usage_semantic"] = "anthropic"
|
||||
} else {
|
||||
other = GenerateTextOtherInfo(ctx, relayInfo, summary.ModelRatio, summary.GroupRatio, summary.CompletionRatio, summary.CacheTokens, summary.CacheRatio, summary.ModelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
}
|
||||
if adminRejectReason != "" {
|
||||
other["reject_reason"] = adminRejectReason
|
||||
}
|
||||
if summary.ImageTokens != 0 {
|
||||
other["image"] = true
|
||||
other["image_ratio"] = summary.ImageRatio
|
||||
other["image_output"] = summary.ImageTokens
|
||||
}
|
||||
if summary.WebSearchCallCount > 0 {
|
||||
other["web_search"] = true
|
||||
other["web_search_call_count"] = summary.WebSearchCallCount
|
||||
other["web_search_price"] = summary.WebSearchPrice
|
||||
} else if summary.ClaudeWebSearchCallCount > 0 {
|
||||
other["web_search"] = true
|
||||
other["web_search_call_count"] = summary.ClaudeWebSearchCallCount
|
||||
other["web_search_price"] = summary.ClaudeWebSearchPrice
|
||||
}
|
||||
if summary.FileSearchCallCount > 0 {
|
||||
other["file_search"] = true
|
||||
other["file_search_call_count"] = summary.FileSearchCallCount
|
||||
other["file_search_price"] = summary.FileSearchPrice
|
||||
}
|
||||
if summary.AudioInputPrice > 0 && summary.AudioTokens > 0 {
|
||||
other["audio_input_seperate_price"] = true
|
||||
other["audio_input_token_count"] = summary.AudioTokens
|
||||
other["audio_input_price"] = summary.AudioInputPrice
|
||||
}
|
||||
if summary.ImageGenerationCallPrice > 0 {
|
||||
other["image_generation_call"] = true
|
||||
other["image_generation_call_price"] = summary.ImageGenerationCallPrice
|
||||
}
|
||||
if summary.CacheCreationTokens > 0 {
|
||||
other["cache_creation_tokens"] = summary.CacheCreationTokens
|
||||
other["cache_creation_ratio"] = summary.CacheCreationRatio
|
||||
}
|
||||
if summary.CacheCreationTokens5m > 0 {
|
||||
other["cache_creation_tokens_5m"] = summary.CacheCreationTokens5m
|
||||
other["cache_creation_ratio_5m"] = summary.CacheCreationRatio5m
|
||||
}
|
||||
if summary.CacheCreationTokens1h > 0 {
|
||||
other["cache_creation_tokens_1h"] = summary.CacheCreationTokens1h
|
||||
other["cache_creation_ratio_1h"] = summary.CacheCreationRatio1h
|
||||
}
|
||||
cacheWriteTokens := cacheWriteTokensTotal(summary)
|
||||
if cacheWriteTokens > 0 {
|
||||
// cache_write_tokens: normalized cache creation total for UI display.
|
||||
// If split 5m/1h values are present, this is their sum; otherwise it falls back
|
||||
// to cache_creation_tokens.
|
||||
other["cache_write_tokens"] = cacheWriteTokens
|
||||
}
|
||||
if relayInfo.GetFinalRequestRelayFormat() != types.RelayFormatClaude && usage != nil && usage.UsageSource != "" && usage.InputTokens > 0 {
|
||||
// input_tokens_total: explicit normalized total input used by the usage log UI.
|
||||
// Only write this field when upstream/current conversion has already provided a
|
||||
// reliable total input value and tagged the usage source. Do not infer it from
|
||||
// prompt/cache fields here, otherwise old upstream payloads may be double-counted.
|
||||
other["input_tokens_total"] = usage.InputTokens
|
||||
}
|
||||
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
|
||||
ChannelId: relayInfo.ChannelId,
|
||||
PromptTokens: summary.PromptTokens,
|
||||
CompletionTokens: summary.CompletionTokens,
|
||||
ModelName: logModel,
|
||||
TokenName: summary.TokenName,
|
||||
Quota: summary.Quota,
|
||||
Content: logContent,
|
||||
TokenId: relayInfo.TokenId,
|
||||
UseTimeSeconds: int(summary.UseTimeSeconds),
|
||||
IsStream: relayInfo.IsStream,
|
||||
Group: relayInfo.UsingGroup,
|
||||
Other: other,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,318 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCalculateTextQuotaSummaryUnifiedForClaudeSemantic(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(w)
|
||||
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: 1000,
|
||||
CompletionTokens: 200,
|
||||
PromptTokensDetails: dto.InputTokenDetails{
|
||||
CachedTokens: 100,
|
||||
CachedCreationTokens: 50,
|
||||
},
|
||||
ClaudeCacheCreation5mTokens: 10,
|
||||
ClaudeCacheCreation1hTokens: 20,
|
||||
}
|
||||
|
||||
priceData := types.PriceData{
|
||||
ModelRatio: 1,
|
||||
CompletionRatio: 2,
|
||||
CacheRatio: 0.1,
|
||||
CacheCreationRatio: 1.25,
|
||||
CacheCreation5mRatio: 1.25,
|
||||
CacheCreation1hRatio: 2,
|
||||
GroupRatioInfo: types.GroupRatioInfo{
|
||||
GroupRatio: 1,
|
||||
},
|
||||
}
|
||||
|
||||
chatRelayInfo := &relaycommon.RelayInfo{
|
||||
RelayFormat: types.RelayFormatOpenAI,
|
||||
FinalRequestRelayFormat: types.RelayFormatClaude,
|
||||
OriginModelName: "claude-3-7-sonnet",
|
||||
PriceData: priceData,
|
||||
StartTime: time.Now(),
|
||||
}
|
||||
messageRelayInfo := &relaycommon.RelayInfo{
|
||||
RelayFormat: types.RelayFormatClaude,
|
||||
FinalRequestRelayFormat: types.RelayFormatClaude,
|
||||
OriginModelName: "claude-3-7-sonnet",
|
||||
PriceData: priceData,
|
||||
StartTime: time.Now(),
|
||||
}
|
||||
|
||||
chatSummary := calculateTextQuotaSummary(ctx, chatRelayInfo, usage)
|
||||
messageSummary := calculateTextQuotaSummary(ctx, messageRelayInfo, usage)
|
||||
|
||||
require.Equal(t, messageSummary.Quota, chatSummary.Quota)
|
||||
require.Equal(t, messageSummary.CacheCreationTokens5m, chatSummary.CacheCreationTokens5m)
|
||||
require.Equal(t, messageSummary.CacheCreationTokens1h, chatSummary.CacheCreationTokens1h)
|
||||
require.True(t, chatSummary.IsClaudeUsageSemantic)
|
||||
require.Equal(t, 1488, chatSummary.Quota)
|
||||
}
|
||||
|
||||
func TestCalculateTextQuotaSummaryUsesSplitClaudeCacheCreationRatios(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(w)
|
||||
|
||||
relayInfo := &relaycommon.RelayInfo{
|
||||
RelayFormat: types.RelayFormatOpenAI,
|
||||
FinalRequestRelayFormat: types.RelayFormatClaude,
|
||||
OriginModelName: "claude-3-7-sonnet",
|
||||
PriceData: types.PriceData{
|
||||
ModelRatio: 1,
|
||||
CompletionRatio: 1,
|
||||
CacheRatio: 0,
|
||||
CacheCreationRatio: 1,
|
||||
CacheCreation5mRatio: 2,
|
||||
CacheCreation1hRatio: 3,
|
||||
GroupRatioInfo: types.GroupRatioInfo{
|
||||
GroupRatio: 1,
|
||||
},
|
||||
},
|
||||
StartTime: time.Now(),
|
||||
}
|
||||
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: 100,
|
||||
CompletionTokens: 0,
|
||||
PromptTokensDetails: dto.InputTokenDetails{
|
||||
CachedCreationTokens: 10,
|
||||
},
|
||||
ClaudeCacheCreation5mTokens: 2,
|
||||
ClaudeCacheCreation1hTokens: 3,
|
||||
}
|
||||
|
||||
summary := calculateTextQuotaSummary(ctx, relayInfo, usage)
|
||||
|
||||
// 100 + remaining(5)*1 + 2*2 + 3*3 = 118
|
||||
require.Equal(t, 118, summary.Quota)
|
||||
}
|
||||
|
||||
func TestCalculateTextQuotaSummaryUsesAnthropicUsageSemanticFromUpstreamUsage(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(w)
|
||||
|
||||
relayInfo := &relaycommon.RelayInfo{
|
||||
RelayFormat: types.RelayFormatOpenAI,
|
||||
OriginModelName: "claude-3-7-sonnet",
|
||||
PriceData: types.PriceData{
|
||||
ModelRatio: 1,
|
||||
CompletionRatio: 2,
|
||||
CacheRatio: 0.1,
|
||||
CacheCreationRatio: 1.25,
|
||||
CacheCreation5mRatio: 1.25,
|
||||
CacheCreation1hRatio: 2,
|
||||
GroupRatioInfo: types.GroupRatioInfo{
|
||||
GroupRatio: 1,
|
||||
},
|
||||
},
|
||||
StartTime: time.Now(),
|
||||
}
|
||||
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: 1000,
|
||||
CompletionTokens: 200,
|
||||
UsageSemantic: "anthropic",
|
||||
PromptTokensDetails: dto.InputTokenDetails{
|
||||
CachedTokens: 100,
|
||||
CachedCreationTokens: 50,
|
||||
},
|
||||
ClaudeCacheCreation5mTokens: 10,
|
||||
ClaudeCacheCreation1hTokens: 20,
|
||||
}
|
||||
|
||||
summary := calculateTextQuotaSummary(ctx, relayInfo, usage)
|
||||
|
||||
require.True(t, summary.IsClaudeUsageSemantic)
|
||||
require.Equal(t, "anthropic", summary.UsageSemantic)
|
||||
require.Equal(t, 1488, summary.Quota)
|
||||
}
|
||||
|
||||
func TestCacheWriteTokensTotal(t *testing.T) {
|
||||
t.Run("split cache creation", func(t *testing.T) {
|
||||
summary := textQuotaSummary{
|
||||
CacheCreationTokens: 50,
|
||||
CacheCreationTokens5m: 10,
|
||||
CacheCreationTokens1h: 20,
|
||||
}
|
||||
require.Equal(t, 50, cacheWriteTokensTotal(summary))
|
||||
})
|
||||
|
||||
t.Run("legacy cache creation", func(t *testing.T) {
|
||||
summary := textQuotaSummary{CacheCreationTokens: 50}
|
||||
require.Equal(t, 50, cacheWriteTokensTotal(summary))
|
||||
})
|
||||
|
||||
t.Run("split cache creation without aggregate remainder", func(t *testing.T) {
|
||||
summary := textQuotaSummary{
|
||||
CacheCreationTokens5m: 10,
|
||||
CacheCreationTokens1h: 20,
|
||||
}
|
||||
require.Equal(t, 30, cacheWriteTokensTotal(summary))
|
||||
})
|
||||
}
|
||||
|
||||
func TestCalculateTextQuotaSummaryHandlesLegacyClaudeDerivedOpenAIUsage(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(w)
|
||||
|
||||
relayInfo := &relaycommon.RelayInfo{
|
||||
RelayFormat: types.RelayFormatOpenAI,
|
||||
OriginModelName: "claude-3-7-sonnet",
|
||||
PriceData: types.PriceData{
|
||||
ModelRatio: 1,
|
||||
CompletionRatio: 5,
|
||||
CacheRatio: 0.1,
|
||||
CacheCreationRatio: 1.25,
|
||||
CacheCreation5mRatio: 1.25,
|
||||
CacheCreation1hRatio: 2,
|
||||
GroupRatioInfo: types.GroupRatioInfo{GroupRatio: 1},
|
||||
},
|
||||
StartTime: time.Now(),
|
||||
}
|
||||
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: 62,
|
||||
CompletionTokens: 95,
|
||||
PromptTokensDetails: dto.InputTokenDetails{
|
||||
CachedTokens: 3544,
|
||||
},
|
||||
ClaudeCacheCreation5mTokens: 586,
|
||||
}
|
||||
|
||||
summary := calculateTextQuotaSummary(ctx, relayInfo, usage)
|
||||
|
||||
// 62 + 3544*0.1 + 586*1.25 + 95*5 = 1624.9 => 1624
|
||||
require.Equal(t, 1624, summary.Quota)
|
||||
}
|
||||
|
||||
func TestCalculateTextQuotaSummarySeparatesOpenRouterCacheReadFromPromptBilling(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(w)
|
||||
|
||||
relayInfo := &relaycommon.RelayInfo{
|
||||
OriginModelName: "openai/gpt-4.1",
|
||||
ChannelMeta: &relaycommon.ChannelMeta{
|
||||
ChannelType: constant.ChannelTypeOpenRouter,
|
||||
},
|
||||
PriceData: types.PriceData{
|
||||
ModelRatio: 1,
|
||||
CompletionRatio: 1,
|
||||
CacheRatio: 0.1,
|
||||
CacheCreationRatio: 1.25,
|
||||
GroupRatioInfo: types.GroupRatioInfo{GroupRatio: 1},
|
||||
},
|
||||
StartTime: time.Now(),
|
||||
}
|
||||
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: 2604,
|
||||
CompletionTokens: 383,
|
||||
PromptTokensDetails: dto.InputTokenDetails{
|
||||
CachedTokens: 2432,
|
||||
},
|
||||
}
|
||||
|
||||
summary := calculateTextQuotaSummary(ctx, relayInfo, usage)
|
||||
|
||||
// OpenRouter OpenAI-format display keeps prompt_tokens as total input,
|
||||
// but billing still separates normal input from cache read tokens.
|
||||
// quota = (2604 - 2432) + 2432*0.1 + 383 = 798.2 => 798
|
||||
require.Equal(t, 2604, summary.PromptTokens)
|
||||
require.Equal(t, 798, summary.Quota)
|
||||
}
|
||||
|
||||
func TestCalculateTextQuotaSummarySeparatesOpenRouterCacheCreationFromPromptBilling(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(w)
|
||||
|
||||
relayInfo := &relaycommon.RelayInfo{
|
||||
OriginModelName: "openai/gpt-4.1",
|
||||
ChannelMeta: &relaycommon.ChannelMeta{
|
||||
ChannelType: constant.ChannelTypeOpenRouter,
|
||||
},
|
||||
PriceData: types.PriceData{
|
||||
ModelRatio: 1,
|
||||
CompletionRatio: 1,
|
||||
CacheCreationRatio: 1.25,
|
||||
GroupRatioInfo: types.GroupRatioInfo{GroupRatio: 1},
|
||||
},
|
||||
StartTime: time.Now(),
|
||||
}
|
||||
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: 2604,
|
||||
CompletionTokens: 383,
|
||||
PromptTokensDetails: dto.InputTokenDetails{
|
||||
CachedCreationTokens: 100,
|
||||
},
|
||||
}
|
||||
|
||||
summary := calculateTextQuotaSummary(ctx, relayInfo, usage)
|
||||
|
||||
// prompt_tokens is still logged as total input, but cache creation is billed separately.
|
||||
// quota = (2604 - 100) + 100*1.25 + 383 = 3012
|
||||
require.Equal(t, 2604, summary.PromptTokens)
|
||||
require.Equal(t, 3012, summary.Quota)
|
||||
}
|
||||
|
||||
func TestCalculateTextQuotaSummaryKeepsPrePRClaudeOpenRouterBilling(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(w)
|
||||
|
||||
relayInfo := &relaycommon.RelayInfo{
|
||||
FinalRequestRelayFormat: types.RelayFormatClaude,
|
||||
OriginModelName: "anthropic/claude-3.7-sonnet",
|
||||
ChannelMeta: &relaycommon.ChannelMeta{
|
||||
ChannelType: constant.ChannelTypeOpenRouter,
|
||||
},
|
||||
PriceData: types.PriceData{
|
||||
ModelRatio: 1,
|
||||
CompletionRatio: 1,
|
||||
CacheRatio: 0.1,
|
||||
CacheCreationRatio: 1.25,
|
||||
GroupRatioInfo: types.GroupRatioInfo{GroupRatio: 1},
|
||||
},
|
||||
StartTime: time.Now(),
|
||||
}
|
||||
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: 2604,
|
||||
CompletionTokens: 383,
|
||||
PromptTokensDetails: dto.InputTokenDetails{
|
||||
CachedTokens: 2432,
|
||||
},
|
||||
}
|
||||
|
||||
summary := calculateTextQuotaSummary(ctx, relayInfo, usage)
|
||||
|
||||
// Pre-PR PostClaudeConsumeQuota behavior for OpenRouter:
|
||||
// prompt = 2604 - 2432 = 172
|
||||
// quota = 172 + 2432*0.1 + 383 = 798.2 => 798
|
||||
require.True(t, summary.IsClaudeUsageSemantic)
|
||||
require.Equal(t, 172, summary.PromptTokens)
|
||||
require.Equal(t, 798, summary.Quota)
|
||||
}
|
||||
@@ -100,8 +100,6 @@ func getImageToken(c *gin.Context, fileMeta *types.FileMeta, model string, strea
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
fileMeta.MimeType = format
|
||||
|
||||
if config.Width == 0 || config.Height == 0 {
|
||||
// not an image, but might be a valid file
|
||||
if format != "" {
|
||||
@@ -268,7 +266,6 @@ func EstimateRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *rela
|
||||
}
|
||||
continue
|
||||
}
|
||||
file.MimeType = cachedData.MimeType
|
||||
file.FileType = DetectFileType(cachedData.MimeType)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ var defaultQwenSettings = QwenSettings{
|
||||
"z-image",
|
||||
"qwen-image",
|
||||
"wan2.6",
|
||||
"wan2.7",
|
||||
"qwen-image-edit",
|
||||
"qwen-image-edit-max",
|
||||
"qwen-image-edit-max-2026-01-16",
|
||||
|
||||
@@ -88,7 +88,7 @@ var channelAffinitySetting = ChannelAffinitySetting{
|
||||
ValueRegex: "",
|
||||
TTLSeconds: 0,
|
||||
ParamOverrideTemplate: buildPassHeaderTemplate(codexCliPassThroughHeaders),
|
||||
SkipRetryOnFailure: false,
|
||||
SkipRetryOnFailure: true,
|
||||
IncludeUsingGroup: true,
|
||||
IncludeRuleName: true,
|
||||
UserAgentInclude: nil,
|
||||
@@ -103,7 +103,7 @@ var channelAffinitySetting = ChannelAffinitySetting{
|
||||
ValueRegex: "",
|
||||
TTLSeconds: 0,
|
||||
ParamOverrideTemplate: buildPassHeaderTemplate(claudeCliPassThroughHeaders),
|
||||
SkipRetryOnFailure: false,
|
||||
SkipRetryOnFailure: true,
|
||||
IncludeUsingGroup: true,
|
||||
IncludeRuleName: true,
|
||||
UserAgentInclude: nil,
|
||||
|
||||
@@ -36,7 +36,7 @@ var performanceSetting = PerformanceSetting{
|
||||
MonitorEnabled: true,
|
||||
MonitorCPUThreshold: 90,
|
||||
MonitorMemoryThreshold: 90,
|
||||
MonitorDiskThreshold: 90,
|
||||
MonitorDiskThreshold: 95,
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
||||
@@ -21,7 +21,7 @@ var defaultFetchSetting = FetchSetting{
|
||||
DomainList: []string{},
|
||||
IpList: []string{},
|
||||
AllowedPorts: []string{"80", "443", "8080", "8443"},
|
||||
ApplyIPFilterForDomain: false,
|
||||
ApplyIPFilterForDomain: true,
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
||||
+128
-127
@@ -4,39 +4,144 @@ import (
|
||||
"fmt"
|
||||
"image"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// FileSourceType 文件来源类型
|
||||
type FileSourceType string
|
||||
|
||||
const (
|
||||
FileSourceTypeURL FileSourceType = "url" // URL 来源
|
||||
FileSourceTypeBase64 FileSourceType = "base64" // Base64 内联数据
|
||||
)
|
||||
|
||||
// FileSource 统一的文件来源抽象
|
||||
// FileSource 统一的文件来源抽象接口
|
||||
// 支持 URL 和 base64 两种来源,提供懒加载和缓存机制
|
||||
type FileSource struct {
|
||||
Type FileSourceType `json:"type"` // 来源类型
|
||||
URL string `json:"url,omitempty"` // URL(当 Type 为 url 时)
|
||||
Base64Data string `json:"base64_data,omitempty"` // Base64 数据(当 Type 为 base64 时)
|
||||
MimeType string `json:"mime_type,omitempty"` // MIME 类型(可选,会自动检测)
|
||||
type FileSource interface {
|
||||
IsURL() bool
|
||||
GetIdentifier() string
|
||||
GetRawData() string
|
||||
ClearRawData()
|
||||
|
||||
// 内部缓存(不导出,不序列化)
|
||||
SetCache(data *CachedFileData)
|
||||
GetCache() *CachedFileData
|
||||
HasCache() bool
|
||||
ClearCache()
|
||||
|
||||
IsRegistered() bool
|
||||
SetRegistered(registered bool)
|
||||
Mu() *sync.Mutex
|
||||
}
|
||||
|
||||
// baseFileSource 共享的缓存/锁/清理注册状态
|
||||
type baseFileSource struct {
|
||||
cachedData *CachedFileData
|
||||
cacheLoaded bool
|
||||
registered bool // 是否已注册到清理列表
|
||||
mu sync.Mutex // 保护加载过程
|
||||
registered bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// Mu 获取内部锁
|
||||
func (f *FileSource) Mu() *sync.Mutex {
|
||||
return &f.mu
|
||||
func (b *baseFileSource) SetCache(data *CachedFileData) {
|
||||
b.cachedData = data
|
||||
b.cacheLoaded = true
|
||||
}
|
||||
|
||||
// CachedFileData 缓存的文件数据
|
||||
// 支持内存缓存和磁盘缓存两种模式
|
||||
func (b *baseFileSource) GetCache() *CachedFileData {
|
||||
return b.cachedData
|
||||
}
|
||||
|
||||
func (b *baseFileSource) HasCache() bool {
|
||||
return b.cacheLoaded && b.cachedData != nil
|
||||
}
|
||||
|
||||
func (b *baseFileSource) ClearCache() {
|
||||
if b.cachedData != nil {
|
||||
b.cachedData.Close()
|
||||
}
|
||||
b.cachedData = nil
|
||||
b.cacheLoaded = false
|
||||
}
|
||||
|
||||
func (b *baseFileSource) IsRegistered() bool {
|
||||
return b.registered
|
||||
}
|
||||
|
||||
func (b *baseFileSource) SetRegistered(registered bool) {
|
||||
b.registered = registered
|
||||
}
|
||||
|
||||
func (b *baseFileSource) Mu() *sync.Mutex {
|
||||
return &b.mu
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// URLSource — URL 来源的 FileSource 实现
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type URLSource struct {
|
||||
baseFileSource
|
||||
URL string
|
||||
}
|
||||
|
||||
func (u *URLSource) IsURL() bool { return true }
|
||||
|
||||
func (u *URLSource) GetIdentifier() string {
|
||||
if len(u.URL) > 100 {
|
||||
return u.URL[:100] + "..."
|
||||
}
|
||||
return u.URL
|
||||
}
|
||||
|
||||
func (u *URLSource) GetRawData() string { return u.URL }
|
||||
|
||||
func (u *URLSource) ClearRawData() {}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Base64Source — Base64 内联数据来源的 FileSource 实现
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type Base64Source struct {
|
||||
baseFileSource
|
||||
Base64Data string
|
||||
MimeType string
|
||||
}
|
||||
|
||||
func (b *Base64Source) IsURL() bool { return false }
|
||||
|
||||
func (b *Base64Source) GetIdentifier() string {
|
||||
if len(b.Base64Data) > 50 {
|
||||
return "base64:" + b.Base64Data[:50] + "..."
|
||||
}
|
||||
return "base64:" + b.Base64Data
|
||||
}
|
||||
|
||||
func (b *Base64Source) GetRawData() string { return b.Base64Data }
|
||||
|
||||
func (b *Base64Source) ClearRawData() {
|
||||
if len(b.Base64Data) > 1024 {
|
||||
b.Base64Data = ""
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Constructors
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func NewURLFileSource(url string) *URLSource {
|
||||
return &URLSource{URL: url}
|
||||
}
|
||||
|
||||
func NewBase64FileSource(base64Data string, mimeType string) *Base64Source {
|
||||
return &Base64Source{
|
||||
Base64Data: base64Data,
|
||||
MimeType: mimeType,
|
||||
}
|
||||
}
|
||||
|
||||
func NewFileSourceFromData(data string, mimeType string) FileSource {
|
||||
if strings.HasPrefix(data, "http://") || strings.HasPrefix(data, "https://") {
|
||||
return NewURLFileSource(data)
|
||||
}
|
||||
return NewBase64FileSource(data, mimeType)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// CachedFileData — 缓存的文件数据(支持内存和磁盘两种模式)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type CachedFileData struct {
|
||||
base64Data string // 内存中的 base64 数据(小文件)
|
||||
MimeType string // MIME 类型
|
||||
@@ -45,18 +150,15 @@ type CachedFileData struct {
|
||||
ImageConfig *image.Config // 图片配置(如果是图片)
|
||||
ImageFormat string // 图片格式(如果是图片)
|
||||
|
||||
// 磁盘缓存相关
|
||||
diskPath string // 磁盘缓存文件路径(大文件)
|
||||
isDisk bool // 是否使用磁盘缓存
|
||||
diskMu sync.Mutex // 磁盘操作锁(保护磁盘文件的读取和删除)
|
||||
diskClosed bool // 是否已关闭/清理
|
||||
statDecremented bool // 是否已扣减统计
|
||||
|
||||
// 统计回调,避免循环依赖
|
||||
OnClose func(size int64)
|
||||
}
|
||||
|
||||
// NewMemoryCachedData 创建内存缓存的数据
|
||||
func NewMemoryCachedData(base64Data string, mimeType string, size int64) *CachedFileData {
|
||||
return &CachedFileData{
|
||||
base64Data: base64Data,
|
||||
@@ -66,7 +168,6 @@ func NewMemoryCachedData(base64Data string, mimeType string, size int64) *Cached
|
||||
}
|
||||
}
|
||||
|
||||
// NewDiskCachedData 创建磁盘缓存的数据
|
||||
func NewDiskCachedData(diskPath string, mimeType string, size int64) *CachedFileData {
|
||||
return &CachedFileData{
|
||||
diskPath: diskPath,
|
||||
@@ -76,7 +177,6 @@ func NewDiskCachedData(diskPath string, mimeType string, size int64) *CachedFile
|
||||
}
|
||||
}
|
||||
|
||||
// GetBase64Data 获取 base64 数据(自动处理内存/磁盘)
|
||||
func (c *CachedFileData) GetBase64Data() (string, error) {
|
||||
if !c.isDisk {
|
||||
return c.base64Data, nil
|
||||
@@ -89,7 +189,6 @@ func (c *CachedFileData) GetBase64Data() (string, error) {
|
||||
return "", fmt.Errorf("disk cache already closed")
|
||||
}
|
||||
|
||||
// 从磁盘读取
|
||||
data, err := os.ReadFile(c.diskPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read from disk cache: %w", err)
|
||||
@@ -97,22 +196,19 @@ func (c *CachedFileData) GetBase64Data() (string, error) {
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
// SetBase64Data 设置 base64 数据(仅用于内存模式)
|
||||
func (c *CachedFileData) SetBase64Data(data string) {
|
||||
if !c.isDisk {
|
||||
c.base64Data = data
|
||||
}
|
||||
}
|
||||
|
||||
// IsDisk 是否使用磁盘缓存
|
||||
func (c *CachedFileData) IsDisk() bool {
|
||||
return c.isDisk
|
||||
}
|
||||
|
||||
// Close 关闭并清理资源
|
||||
func (c *CachedFileData) Close() error {
|
||||
if !c.isDisk {
|
||||
c.base64Data = "" // 释放内存
|
||||
c.base64Data = ""
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -126,7 +222,6 @@ func (c *CachedFileData) Close() error {
|
||||
c.diskClosed = true
|
||||
if c.diskPath != "" {
|
||||
err := os.Remove(c.diskPath)
|
||||
// 只有在删除成功且未扣减过统计时,才执行回调
|
||||
if err == nil && !c.statDecremented && c.OnClose != nil {
|
||||
c.OnClose(c.DiskSize)
|
||||
c.statDecremented = true
|
||||
@@ -135,97 +230,3 @@ func (c *CachedFileData) Close() error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewURLFileSource 创建 URL 来源的 FileSource
|
||||
func NewURLFileSource(url string) *FileSource {
|
||||
return &FileSource{
|
||||
Type: FileSourceTypeURL,
|
||||
URL: url,
|
||||
}
|
||||
}
|
||||
|
||||
// NewBase64FileSource 创建 base64 来源的 FileSource
|
||||
func NewBase64FileSource(base64Data string, mimeType string) *FileSource {
|
||||
return &FileSource{
|
||||
Type: FileSourceTypeBase64,
|
||||
Base64Data: base64Data,
|
||||
MimeType: mimeType,
|
||||
}
|
||||
}
|
||||
|
||||
// IsURL 判断是否是 URL 来源
|
||||
func (f *FileSource) IsURL() bool {
|
||||
return f.Type == FileSourceTypeURL
|
||||
}
|
||||
|
||||
// IsBase64 判断是否是 base64 来源
|
||||
func (f *FileSource) IsBase64() bool {
|
||||
return f.Type == FileSourceTypeBase64
|
||||
}
|
||||
|
||||
// GetIdentifier 获取文件标识符(用于日志和错误追踪)
|
||||
func (f *FileSource) GetIdentifier() string {
|
||||
if f.IsURL() {
|
||||
if len(f.URL) > 100 {
|
||||
return f.URL[:100] + "..."
|
||||
}
|
||||
return f.URL
|
||||
}
|
||||
if len(f.Base64Data) > 50 {
|
||||
return "base64:" + f.Base64Data[:50] + "..."
|
||||
}
|
||||
return "base64:" + f.Base64Data
|
||||
}
|
||||
|
||||
// GetRawData 获取原始数据(URL 或完整的 base64 字符串)
|
||||
func (f *FileSource) GetRawData() string {
|
||||
if f.IsURL() {
|
||||
return f.URL
|
||||
}
|
||||
return f.Base64Data
|
||||
}
|
||||
|
||||
// SetCache 设置缓存数据
|
||||
func (f *FileSource) SetCache(data *CachedFileData) {
|
||||
f.cachedData = data
|
||||
f.cacheLoaded = true
|
||||
}
|
||||
|
||||
// IsRegistered 是否已注册到清理列表
|
||||
func (f *FileSource) IsRegistered() bool {
|
||||
return f.registered
|
||||
}
|
||||
|
||||
// SetRegistered 设置注册状态
|
||||
func (f *FileSource) SetRegistered(registered bool) {
|
||||
f.registered = registered
|
||||
}
|
||||
|
||||
// GetCache 获取缓存数据
|
||||
func (f *FileSource) GetCache() *CachedFileData {
|
||||
return f.cachedData
|
||||
}
|
||||
|
||||
// HasCache 是否有缓存
|
||||
func (f *FileSource) HasCache() bool {
|
||||
return f.cacheLoaded && f.cachedData != nil
|
||||
}
|
||||
|
||||
// ClearCache 清除缓存,释放内存和磁盘文件
|
||||
func (f *FileSource) ClearCache() {
|
||||
// 如果有缓存数据,先关闭它(会清理磁盘文件)
|
||||
if f.cachedData != nil {
|
||||
f.cachedData.Close()
|
||||
}
|
||||
f.cachedData = nil
|
||||
f.cacheLoaded = false
|
||||
}
|
||||
|
||||
// ClearRawData 清除原始数据,只保留必要的元信息
|
||||
// 用于在处理完成后释放大文件的内存
|
||||
func (f *FileSource) ClearRawData() {
|
||||
// 保留 URL(通常很短),只清除大的 base64 数据
|
||||
if f.IsBase64() && len(f.Base64Data) > 1024 {
|
||||
f.Base64Data = ""
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,13 +32,12 @@ type TokenCountMeta struct {
|
||||
|
||||
type FileMeta struct {
|
||||
FileType
|
||||
MimeType string
|
||||
Source *FileSource // 统一的文件来源(URL 或 base64)
|
||||
Detail string // 图片细节级别(low/high/auto)
|
||||
Source FileSource // 统一的文件来源(URL 或 base64)
|
||||
Detail string // 图片细节级别(low/high/auto)
|
||||
}
|
||||
|
||||
// NewFileMeta 创建新的 FileMeta
|
||||
func NewFileMeta(fileType FileType, source *FileSource) *FileMeta {
|
||||
func NewFileMeta(fileType FileType, source FileSource) *FileMeta {
|
||||
return &FileMeta{
|
||||
FileType: fileType,
|
||||
Source: source,
|
||||
@@ -46,7 +45,7 @@ func NewFileMeta(fileType FileType, source *FileSource) *FileMeta {
|
||||
}
|
||||
|
||||
// NewImageFileMeta 创建图片类型的 FileMeta
|
||||
func NewImageFileMeta(source *FileSource, detail string) *FileMeta {
|
||||
func NewImageFileMeta(source FileSource, detail string) *FileMeta {
|
||||
return &FileMeta{
|
||||
FileType: FileTypeImage,
|
||||
Source: source,
|
||||
|
||||
@@ -56,7 +56,7 @@ const OAuth2Callback = (props) => {
|
||||
return;
|
||||
}
|
||||
|
||||
if (message === 'bind') {
|
||||
if (data?.action === 'bind') {
|
||||
showSuccess(t('绑定成功!'));
|
||||
navigate('/console/personal');
|
||||
} else {
|
||||
|
||||
@@ -221,23 +221,27 @@ const FooterBar = () => {
|
||||
return (
|
||||
<div className='w-full'>
|
||||
{footer ? (
|
||||
<div className='relative'>
|
||||
<div
|
||||
className='custom-footer'
|
||||
dangerouslySetInnerHTML={{ __html: footer }}
|
||||
></div>
|
||||
<div className='absolute bottom-2 right-4 text-xs !text-semi-color-text-2 opacity-70'>
|
||||
<span>{t('设计与开发由')} </span>
|
||||
<a
|
||||
href='https://github.com/QuantumNous/new-api'
|
||||
target='_blank'
|
||||
rel='noopener noreferrer'
|
||||
className='!text-semi-color-primary font-medium'
|
||||
>
|
||||
New API
|
||||
</a>
|
||||
<footer className='relative h-auto py-4 px-6 md:px-24 w-full flex items-center justify-center overflow-hidden'>
|
||||
<div className='flex flex-col md:flex-row items-center justify-between w-full max-w-[1110px] gap-4'>
|
||||
<div
|
||||
className='custom-footer na-cb6feafeb3990c78 text-sm !text-semi-color-text-1'
|
||||
dangerouslySetInnerHTML={{ __html: footer }}
|
||||
></div>
|
||||
<div className='text-sm flex-shrink-0'>
|
||||
<span className='!text-semi-color-text-1'>
|
||||
{t('设计与开发由')}{' '}
|
||||
</span>
|
||||
<a
|
||||
href='https://github.com/QuantumNous/new-api'
|
||||
target='_blank'
|
||||
rel='noopener noreferrer'
|
||||
className='!text-semi-color-primary font-medium'
|
||||
>
|
||||
New API
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</footer>
|
||||
) : (
|
||||
customFooter
|
||||
)}
|
||||
|
||||
@@ -18,7 +18,14 @@ For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import React from 'react';
|
||||
import { Input, Slider, Typography, Button, Tag } from '@douyinfe/semi-ui';
|
||||
import {
|
||||
Input,
|
||||
InputNumber,
|
||||
Slider,
|
||||
Typography,
|
||||
Button,
|
||||
Tag,
|
||||
} from '@douyinfe/semi-ui';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
Hash,
|
||||
@@ -241,15 +248,14 @@ const ParameterControl = ({
|
||||
disabled={disabled}
|
||||
/>
|
||||
</div>
|
||||
<Input
|
||||
<InputNumber
|
||||
placeholder='MaxTokens'
|
||||
name='max_tokens'
|
||||
required
|
||||
autoComplete='new-password'
|
||||
defaultValue={0}
|
||||
value={inputs.max_tokens}
|
||||
onChange={(value) => onInputChange('max_tokens', value)}
|
||||
className='!rounded-lg'
|
||||
onNumberChange={(value) => onInputChange('max_tokens', value)}
|
||||
min={0}
|
||||
precision={0}
|
||||
style={{ width: '100%' }}
|
||||
disabled={!parameterEnabled.max_tokens || disabled}
|
||||
/>
|
||||
</div>
|
||||
|
||||
@@ -65,11 +65,15 @@ export const loadConfig = () => {
|
||||
const savedConfig = localStorage.getItem(STORAGE_KEYS.CONFIG);
|
||||
if (savedConfig) {
|
||||
const parsedConfig = JSON.parse(savedConfig);
|
||||
const parsedMaxTokens = parseInt(parsedConfig?.inputs?.max_tokens, 10);
|
||||
|
||||
const mergedConfig = {
|
||||
inputs: {
|
||||
...DEFAULT_CONFIG.inputs,
|
||||
...parsedConfig.inputs,
|
||||
max_tokens: Number.isNaN(parsedMaxTokens)
|
||||
? parsedConfig?.inputs?.max_tokens
|
||||
: parsedMaxTokens,
|
||||
},
|
||||
parameterEnabled: {
|
||||
...DEFAULT_CONFIG.parameterEnabled,
|
||||
|
||||
@@ -306,9 +306,9 @@ const PersonalSetting = () => {
|
||||
|
||||
const bindWeChat = async () => {
|
||||
if (inputs.wechat_verification_code === '') return;
|
||||
const res = await API.get(
|
||||
`/api/oauth/wechat/bind?code=${inputs.wechat_verification_code}`,
|
||||
);
|
||||
const res = await API.post('/api/oauth/wechat/bind', {
|
||||
code: inputs.wechat_verification_code,
|
||||
});
|
||||
const { success, message } = res.data;
|
||||
if (success) {
|
||||
showSuccess(t('微信账户绑定成功!'));
|
||||
@@ -378,9 +378,10 @@ const PersonalSetting = () => {
|
||||
return;
|
||||
}
|
||||
setLoading(true);
|
||||
const res = await API.get(
|
||||
`/api/oauth/email/bind?email=${inputs.email}&code=${inputs.email_verification_code}`,
|
||||
);
|
||||
const res = await API.post('/api/oauth/email/bind', {
|
||||
email: inputs.email,
|
||||
code: inputs.email_verification_code,
|
||||
});
|
||||
const { success, message } = res.data;
|
||||
if (success) {
|
||||
showSuccess(t('邮箱账户绑定成功!'));
|
||||
|
||||
@@ -21,9 +21,8 @@ import React, { useEffect, useState } from 'react';
|
||||
import { Card, Spin, Tabs } from '@douyinfe/semi-ui';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import ModelPricingCombined from '../../pages/Setting/Ratio/ModelPricingCombined';
|
||||
import GroupRatioSettings from '../../pages/Setting/Ratio/GroupRatioSettings';
|
||||
import ModelRatioSettings from '../../pages/Setting/Ratio/ModelRatioSettings';
|
||||
import ModelSettingsVisualEditor from '../../pages/Setting/Ratio/ModelSettingsVisualEditor';
|
||||
import ModelRatioNotSetEditor from '../../pages/Setting/Ratio/ModelRationNotSetEditor';
|
||||
import UpstreamRatioSync from '../../pages/Setting/Ratio/UpstreamRatioSync';
|
||||
|
||||
@@ -95,18 +94,14 @@ const RatioSetting = () => {
|
||||
|
||||
return (
|
||||
<Spin spinning={loading} size='large'>
|
||||
{/* 模型倍率设置以及价格编辑器 */}
|
||||
<Card style={{ marginTop: '10px' }}>
|
||||
<Tabs type='card' defaultActiveKey='visual'>
|
||||
<Tabs.TabPane tab={t('模型倍率设置')} itemKey='model'>
|
||||
<ModelRatioSettings options={inputs} refresh={onRefresh} />
|
||||
<Tabs type='card' defaultActiveKey='pricing'>
|
||||
<Tabs.TabPane tab={t('模型定价设置')} itemKey='pricing'>
|
||||
<ModelPricingCombined options={inputs} refresh={onRefresh} />
|
||||
</Tabs.TabPane>
|
||||
<Tabs.TabPane tab={t('分组相关设置')} itemKey='group'>
|
||||
<GroupRatioSettings options={inputs} refresh={onRefresh} />
|
||||
</Tabs.TabPane>
|
||||
<Tabs.TabPane tab={t('价格设置')} itemKey='visual'>
|
||||
<ModelSettingsVisualEditor options={inputs} refresh={onRefresh} />
|
||||
</Tabs.TabPane>
|
||||
<Tabs.TabPane tab={t('未设置价格模型')} itemKey='unset_models'>
|
||||
<ModelRatioNotSetEditor options={inputs} refresh={onRefresh} />
|
||||
</Tabs.TabPane>
|
||||
|
||||
@@ -108,7 +108,7 @@ const SystemSetting = () => {
|
||||
'fetch_setting.domain_list': [],
|
||||
'fetch_setting.ip_list': [],
|
||||
'fetch_setting.allowed_ports': [],
|
||||
'fetch_setting.apply_ip_filter_for_domain': false,
|
||||
'fetch_setting.apply_ip_filter_for_domain': true,
|
||||
});
|
||||
|
||||
const [originInputs, setOriginInputs] = useState({});
|
||||
@@ -847,7 +847,7 @@ const SystemSetting = () => {
|
||||
}
|
||||
style={{ marginBottom: 8 }}
|
||||
>
|
||||
{t('对域名启用 IP 过滤(实验性)')}
|
||||
{t('对域名启用 IP 过滤(推荐开启)')}
|
||||
</Form.Checkbox>
|
||||
<Text strong>
|
||||
{t(domainFilterMode ? '域名白名单' : '域名黑名单')}
|
||||
|
||||
@@ -538,19 +538,24 @@ export const getChannelsColumns = ({
|
||||
</Tooltip>
|
||||
<Tooltip
|
||||
content={
|
||||
t('剩余额度') +
|
||||
': ' +
|
||||
renderQuotaWithAmount(record.balance) +
|
||||
t(',点击更新')
|
||||
record.type === 57
|
||||
? t('查看 Codex 帐号信息与用量')
|
||||
: t('剩余额度') +
|
||||
': ' +
|
||||
renderQuotaWithAmount(record.balance) +
|
||||
t(',点击更新')
|
||||
}
|
||||
>
|
||||
<Tag
|
||||
color='white'
|
||||
type='ghost'
|
||||
color={record.type === 57 ? 'light-blue' : 'white'}
|
||||
type={record.type === 57 ? 'light' : 'ghost'}
|
||||
shape='circle'
|
||||
className={record.type === 57 ? 'cursor-pointer' : ''}
|
||||
onClick={() => updateChannelBalance(record)}
|
||||
>
|
||||
{renderQuotaWithAmount(record.balance)}
|
||||
{record.type === 57
|
||||
? t('帐号信息')
|
||||
: renderQuotaWithAmount(record.balance)}
|
||||
</Tag>
|
||||
</Tooltip>
|
||||
</Space>
|
||||
|
||||
@@ -22,9 +22,11 @@ import {
|
||||
Modal,
|
||||
Button,
|
||||
Progress,
|
||||
Tag,
|
||||
Typography,
|
||||
Spin,
|
||||
Tag,
|
||||
Descriptions,
|
||||
Collapse,
|
||||
} from '@douyinfe/semi-ui';
|
||||
import { API, showError } from '../../../../helpers';
|
||||
|
||||
@@ -128,6 +130,87 @@ const formatUnixSeconds = (unixSeconds) => {
|
||||
}
|
||||
};
|
||||
|
||||
const getDisplayText = (value) => {
|
||||
if (value == null) return '';
|
||||
return String(value).trim();
|
||||
};
|
||||
|
||||
const formatAccountTypeLabel = (value, t) => {
|
||||
const tt = typeof t === 'function' ? t : (v) => v;
|
||||
const normalized = normalizePlanType(value);
|
||||
switch (normalized) {
|
||||
case 'free':
|
||||
return 'Free';
|
||||
case 'plus':
|
||||
return 'Plus';
|
||||
case 'pro':
|
||||
return 'Pro';
|
||||
case 'team':
|
||||
return 'Team';
|
||||
case 'enterprise':
|
||||
return 'Enterprise';
|
||||
default:
|
||||
return getDisplayText(value) || tt('未识别');
|
||||
}
|
||||
};
|
||||
|
||||
const getAccountTypeTagColor = (value) => {
|
||||
const normalized = normalizePlanType(value);
|
||||
switch (normalized) {
|
||||
case 'enterprise':
|
||||
return 'green';
|
||||
case 'team':
|
||||
return 'cyan';
|
||||
case 'pro':
|
||||
return 'blue';
|
||||
case 'plus':
|
||||
return 'violet';
|
||||
case 'free':
|
||||
return 'amber';
|
||||
default:
|
||||
return 'grey';
|
||||
}
|
||||
};
|
||||
|
||||
const resolveUsageStatusTag = (t, rateLimit) => {
|
||||
const tt = typeof t === 'function' ? t : (v) => v;
|
||||
if (!rateLimit || Object.keys(rateLimit).length === 0) {
|
||||
return <Tag color='grey'>{tt('待确认')}</Tag>;
|
||||
}
|
||||
if (rateLimit?.allowed && !rateLimit?.limit_reached) {
|
||||
return <Tag color='green'>{tt('可用')}</Tag>;
|
||||
}
|
||||
return <Tag color='red'>{tt('受限')}</Tag>;
|
||||
};
|
||||
|
||||
const AccountInfoValue = ({ t, value, onCopy, monospace = false }) => {
|
||||
const tt = typeof t === 'function' ? t : (v) => v;
|
||||
const text = getDisplayText(value);
|
||||
const hasValue = text !== '';
|
||||
|
||||
return (
|
||||
<div className='flex min-w-0 items-start justify-between gap-2'>
|
||||
<div
|
||||
className={`min-w-0 flex-1 break-all text-xs leading-5 text-semi-color-text-1 ${
|
||||
monospace ? 'font-mono' : ''
|
||||
}`}
|
||||
>
|
||||
{hasValue ? text : '-'}
|
||||
</div>
|
||||
<Button
|
||||
size='small'
|
||||
type='tertiary'
|
||||
theme='borderless'
|
||||
className='shrink-0 px-1 text-xs'
|
||||
disabled={!hasValue}
|
||||
onClick={() => onCopy?.(text)}
|
||||
>
|
||||
{tt('复制')}
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
const RateLimitWindowCard = ({ t, title, windowData }) => {
|
||||
const tt = typeof t === 'function' ? t : (v) => v;
|
||||
const hasWindowData =
|
||||
@@ -181,50 +264,100 @@ const RateLimitWindowCard = ({ t, title, windowData }) => {
|
||||
|
||||
const CodexUsageView = ({ t, record, payload, onCopy, onRefresh }) => {
|
||||
const tt = typeof t === 'function' ? t : (v) => v;
|
||||
const [showRawJson, setShowRawJson] = useState(false);
|
||||
const data = payload?.data ?? null;
|
||||
const rateLimit = data?.rate_limit ?? {};
|
||||
const { fiveHourWindow, weeklyWindow } = resolveRateLimitWindows(data);
|
||||
|
||||
const allowed = !!rateLimit?.allowed;
|
||||
const limitReached = !!rateLimit?.limit_reached;
|
||||
const upstreamStatus = payload?.upstream_status;
|
||||
|
||||
const statusTag =
|
||||
allowed && !limitReached ? (
|
||||
<Tag color='green'>{tt('可用')}</Tag>
|
||||
) : (
|
||||
<Tag color='red'>{tt('受限')}</Tag>
|
||||
);
|
||||
const accountType = data?.plan_type ?? rateLimit?.plan_type;
|
||||
const accountTypeLabel = formatAccountTypeLabel(accountType, tt);
|
||||
const accountTypeTagColor = getAccountTypeTagColor(accountType);
|
||||
const statusTag = resolveUsageStatusTag(tt, rateLimit);
|
||||
const userId = data?.user_id;
|
||||
const email = data?.email;
|
||||
const accountId = data?.account_id;
|
||||
const errorMessage =
|
||||
payload?.success === false ? getDisplayText(payload?.message) || tt('获取用量失败') : '';
|
||||
|
||||
const rawText =
|
||||
typeof data === 'string' ? data : JSON.stringify(data ?? payload, null, 2);
|
||||
|
||||
return (
|
||||
<div className='flex flex-col gap-3'>
|
||||
<div className='flex flex-wrap items-center justify-between gap-2'>
|
||||
<Text type='tertiary' size='small'>
|
||||
{tt('渠道:')}
|
||||
{record?.name || '-'} ({tt('编号:')}
|
||||
{record?.id || '-'})
|
||||
</Text>
|
||||
<div className='flex items-center gap-2'>
|
||||
{statusTag}
|
||||
<Button
|
||||
size='small'
|
||||
type='tertiary'
|
||||
theme='borderless'
|
||||
onClick={onRefresh}
|
||||
>
|
||||
<div className='flex flex-col gap-4'>
|
||||
{errorMessage && (
|
||||
<div className='rounded-xl border border-red-200 bg-red-50 px-4 py-3 text-sm text-red-700'>
|
||||
{errorMessage}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className='rounded-xl border border-semi-color-border bg-semi-color-bg-0 p-3'>
|
||||
<div className='flex flex-wrap items-start justify-between gap-2'>
|
||||
<div className='min-w-0'>
|
||||
<div className='text-xs font-medium text-semi-color-text-2'>
|
||||
{tt('Codex 帐号')}
|
||||
</div>
|
||||
<div className='mt-2 flex flex-wrap items-center gap-2'>
|
||||
<Tag
|
||||
color={accountTypeTagColor}
|
||||
type='light'
|
||||
shape='circle'
|
||||
size='large'
|
||||
className='font-semibold'
|
||||
>
|
||||
{accountTypeLabel}
|
||||
</Tag>
|
||||
{statusTag}
|
||||
<Tag color='grey' type='light' shape='circle'>
|
||||
{tt('上游状态码:')}
|
||||
{upstreamStatus ?? '-'}
|
||||
</Tag>
|
||||
</div>
|
||||
</div>
|
||||
<Button size='small' type='tertiary' theme='outline' onClick={onRefresh}>
|
||||
{tt('刷新')}
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<div className='mt-2 rounded-lg bg-semi-color-fill-0 px-3 py-2'>
|
||||
<Descriptions>
|
||||
<Descriptions.Item itemKey='User ID'>
|
||||
<AccountInfoValue
|
||||
t={tt}
|
||||
value={userId}
|
||||
onCopy={onCopy}
|
||||
monospace={true}
|
||||
/>
|
||||
</Descriptions.Item>
|
||||
<Descriptions.Item itemKey={tt('邮箱')}>
|
||||
<AccountInfoValue t={tt} value={email} onCopy={onCopy} />
|
||||
</Descriptions.Item>
|
||||
<Descriptions.Item itemKey='Account ID'>
|
||||
<AccountInfoValue
|
||||
t={tt}
|
||||
value={accountId}
|
||||
onCopy={onCopy}
|
||||
monospace={true}
|
||||
/>
|
||||
</Descriptions.Item>
|
||||
</Descriptions>
|
||||
</div>
|
||||
|
||||
<div className='mt-2 text-xs text-semi-color-text-2'>
|
||||
{tt('渠道:')}
|
||||
{record?.name || '-'} ({tt('编号:')}
|
||||
{record?.id || '-'})
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className='flex flex-wrap items-center justify-between gap-2'>
|
||||
<Text type='tertiary' size='small'>
|
||||
{tt('上游状态码:')}
|
||||
{upstreamStatus ?? '-'}
|
||||
</Text>
|
||||
<div>
|
||||
<div className='mb-2'>
|
||||
<div className='text-sm font-semibold text-semi-color-text-0'>
|
||||
{tt('额度窗口')}
|
||||
</div>
|
||||
<Text type='tertiary' size='small'>
|
||||
{tt('用于观察当前帐号在 Codex 上游的限额使用情况')}
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className='grid grid-cols-1 gap-3 md:grid-cols-2'>
|
||||
@@ -240,23 +373,30 @@ const CodexUsageView = ({ t, record, payload, onCopy, onRefresh }) => {
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<div className='mb-1 flex items-center justify-between gap-2'>
|
||||
<div className='text-sm font-medium'>{tt('原始 JSON')}</div>
|
||||
<Button
|
||||
size='small'
|
||||
type='primary'
|
||||
theme='outline'
|
||||
onClick={() => onCopy?.(rawText)}
|
||||
disabled={!rawText}
|
||||
>
|
||||
{tt('复制')}
|
||||
</Button>
|
||||
</div>
|
||||
<pre className='max-h-[50vh] overflow-auto rounded-lg bg-semi-color-fill-0 p-3 text-xs text-semi-color-text-0'>
|
||||
{rawText}
|
||||
</pre>
|
||||
</div>
|
||||
<Collapse
|
||||
activeKey={showRawJson ? ['raw-json'] : []}
|
||||
onChange={(activeKey) => {
|
||||
const keys = Array.isArray(activeKey) ? activeKey : [activeKey];
|
||||
setShowRawJson(keys.includes('raw-json'));
|
||||
}}
|
||||
>
|
||||
<Collapse.Panel header={tt('原始 JSON')} itemKey='raw-json'>
|
||||
<div className='mb-2 flex justify-end'>
|
||||
<Button
|
||||
size='small'
|
||||
type='primary'
|
||||
theme='outline'
|
||||
onClick={() => onCopy?.(rawText)}
|
||||
disabled={!rawText}
|
||||
>
|
||||
{tt('复制')}
|
||||
</Button>
|
||||
</div>
|
||||
<pre className='max-h-[50vh] overflow-y-auto rounded-lg bg-semi-color-fill-0 p-3 text-xs text-semi-color-text-0'>
|
||||
{rawText}
|
||||
</pre>
|
||||
</Collapse.Panel>
|
||||
</Collapse>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
@@ -351,7 +491,7 @@ export const openCodexUsageModal = ({ t, record, payload, onCopy }) => {
|
||||
const tt = typeof t === 'function' ? t : (v) => v;
|
||||
|
||||
Modal.info({
|
||||
title: tt('Codex 用量'),
|
||||
title: tt('Codex 帐号与用量'),
|
||||
centered: true,
|
||||
width: 900,
|
||||
style: { maxWidth: '95vw' },
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user