Compare commits
118 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 5238f279db | |||
| 5402bf417d | |||
| c766913baf | |||
| 40dc43f44e | |||
| fbf235d222 | |||
| 62b9aaa520 | |||
| 22b6b16702 | |||
| 6154b8e3cd | |||
| ff66288e3a | |||
| 926e1781dd | |||
| d4a470a638 | |||
| 9f61407bf0 | |||
| dbf900a531 | |||
| 7399e4721b | |||
| a5e20269dd | |||
| 9ae9040b3c | |||
| 0191a68d4e | |||
| 16221f8279 | |||
| 763c3ff709 | |||
| c667e4706a | |||
| 216b94dac0 | |||
| 49eb533aaf | |||
| 7693edae53 | |||
| ded4a124e2 | |||
| d6982c8182 | |||
| 9ecad90652 | |||
| 929b5060ea | |||
| 755ece2f01 | |||
| f40eb4e5d2 | |||
| 45f65c297b | |||
| 6c074ef897 | |||
| deff59a5be | |||
| 3c516084f8 | |||
| 4d675b4d1f | |||
| 87b426f306 | |||
| 49db5147c3 | |||
| 13122aa0fa | |||
| dcd0911612 | |||
| e904579a5b | |||
| e80d867f38 | |||
| cf86fe5fea | |||
| 42846c692e | |||
| 1911520eba | |||
| 2c3ae32c8e | |||
| 64f41efc47 | |||
| 498199b37d | |||
| ff29900f30 | |||
| eff51857d0 | |||
| e9f8f62796 | |||
| 5fe8e98eeb | |||
| e520977efc | |||
| ed6ff0f267 | |||
| d955a0c080 | |||
| d096a2e5b7 | |||
| d2fb485d34 | |||
| 04f5dd0206 | |||
| ede0ad117b | |||
| 5bb8fe6af5 | |||
| a1a92c1918 | |||
| a4d1ed6da5 | |||
| 669e596ff7 | |||
| 1daeac42ef | |||
| e70bfa2d57 | |||
| b09337e6ed | |||
| bd09b47ef4 | |||
| d595ef4990 | |||
| 2270f63c00 | |||
| 8ed2ea6ec1 | |||
| 202a433f86 | |||
| 620e066b39 | |||
| 0246b20bf1 | |||
| 69551ab2de | |||
| 8aa8b81e03 | |||
| bc80477b1a | |||
| 5db25f47f1 | |||
| a4fd2246ba | |||
| 4e5e7b5828 | |||
| 95738594b4 | |||
| efab41c476 | |||
| c77c82421e | |||
| e4144d60f8 | |||
| 63f4595ef8 | |||
| 5e856f0263 | |||
| b9f1d01e00 | |||
| 5d620b9640 | |||
| 264bc963e0 | |||
| 9fbb782230 | |||
| da8a52f50a | |||
| 9fdb0bc248 | |||
| 24ec27f844 | |||
| 5e9cc681f5 | |||
| 7e68e1b36a | |||
| 45a59d32fb | |||
| c1c07d063d | |||
| 7fc39363d7 | |||
| 7b62694f60 | |||
| 3b5d1daf39 | |||
| d087cc5025 | |||
| d67f446b66 | |||
| ac72f90fc5 | |||
| 3f662e4bc0 | |||
| 287af7ebee | |||
| aa89ea2db5 | |||
| 8d7d880db5 | |||
| 50ec2bac6b | |||
| c0a0285f74 | |||
| fb76abb329 | |||
| 9905599d27 | |||
| 329416d67b | |||
| ffb06d084b | |||
| 2e20ede2a0 | |||
| 9cfaa68e5a | |||
| 57d525869a | |||
| 3defef3588 | |||
| 172f92aa72 | |||
| 12aacf27b6 | |||
| 728607b8f5 | |||
| 88b7322483 |
@@ -1,12 +0,0 @@
|
||||
# These are supported funding model platforms
|
||||
|
||||
github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
|
||||
patreon: # Replace with a single Patreon username
|
||||
open_collective: # Replace with a single Open Collective username
|
||||
ko_fi: # Replace with a single Ko-fi username
|
||||
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
|
||||
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
|
||||
liberapay: # Replace with a single Liberapay username
|
||||
issuehunt: # Replace with a single IssueHunt username
|
||||
otechie: # Replace with a single Otechie username
|
||||
custom: ['https://afdian.com/a/new-api'] # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
|
||||
@@ -7,14 +7,23 @@ assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**例行检查**
|
||||
## 提交前必读(请勿删除本节)
|
||||
|
||||
- 文档:https://docs.newapi.ai/
|
||||
- 使用问题先看或先问:https://deepwiki.com/QuantumNous/new-api
|
||||
- 警告:删除本模板、删除小节标题或随意清空内容的 issue,可能会被直接关闭;重复恶意提交者可能会被 block。
|
||||
|
||||
**您当前的 newapi 版本**
|
||||
|
||||
请填写,例如:`v1.0.0`
|
||||
|
||||
**提交确认**
|
||||
|
||||
[//]: # (方框内删除已有的空格,填 x 号)
|
||||
+ [ ] 我已确认目前没有类似 issue
|
||||
+ [ ] 我已确认我已升级到最新版本
|
||||
+ [ ] 我已完整查看过项目 README,尤其是常见问题部分
|
||||
+ [ ] 我理解并愿意跟进此 issue,协助测试和提供反馈
|
||||
+ [ ] 我理解并认可上述内容,并理解项目维护者精力有限,**不遵循规则的 issue 可能会被无视或直接关闭**
|
||||
+ [ ] 我已完整查看过文档 https://docs.newapi.ai/ 和项目 README,尤其是常见问题部分
|
||||
+ [ ] 我未删除此模板中的任何引导内容或小节标题,并会按要求完整填写
|
||||
+ [ ] 我理解项目维护者精力有限,不遵循模板要求的 issue 可能会被无视或直接关闭
|
||||
|
||||
**问题描述**
|
||||
|
||||
@@ -23,4 +32,3 @@ assignees: ''
|
||||
**预期结果**
|
||||
|
||||
**相关截图**
|
||||
如果没有的话,请删除此节。
|
||||
@@ -7,14 +7,23 @@ assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**Routine Checks**
|
||||
## Read This First (Do Not Remove This Section)
|
||||
|
||||
- Docs: https://docs.newapi.ai/
|
||||
- Usage questions first: https://deepwiki.com/QuantumNous/new-api
|
||||
- Warning: issues that remove this template, delete section headings, or clear out required content may be closed directly. Repeated abusive submissions may result in a block.
|
||||
|
||||
**Your current newapi version**
|
||||
|
||||
Please fill this in, for example: `v1.0.0`
|
||||
|
||||
**Submission Checks**
|
||||
|
||||
[//]: # (Remove the space in the box and fill with an x)
|
||||
+ [ ] I have confirmed there are no similar issues currently
|
||||
+ [ ] I have confirmed I have upgraded to the latest version
|
||||
+ [ ] I have thoroughly read the project README, especially the FAQ section
|
||||
+ [ ] I understand and am willing to follow up on this issue, assist with testing and provide feedback
|
||||
+ [ ] I understand and acknowledge the above, and understand that project maintainers have limited time and energy, **issues that do not follow the rules may be ignored or closed directly**
|
||||
+ [ ] I have confirmed there are no similar issues
|
||||
+ [ ] I have thoroughly read the docs at https://docs.newapi.ai/ and the project README, especially the FAQ section
|
||||
+ [ ] I have not removed any guidance or section headings from this template and will complete it as requested
|
||||
+ [ ] I understand that maintainers have limited time and issues that do not follow this template may be ignored or closed directly
|
||||
|
||||
**Issue Description**
|
||||
|
||||
@@ -23,4 +32,3 @@ assignees: ''
|
||||
**Expected Result**
|
||||
|
||||
**Related Screenshots**
|
||||
If none, please delete this section.
|
||||
@@ -1,5 +1,8 @@
|
||||
blank_issues_enabled: false
|
||||
contact_links:
|
||||
- name: 项目群聊
|
||||
url: https://private-user-images.githubusercontent.com/61247483/283011625-de536a8a-0161-47a7-a0a2-66ef6de81266.jpeg?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTEiLCJleHAiOjE3MDIyMjQzOTAsIm5iZiI6MTcwMjIyNDA5MCwicGF0aCI6Ii82MTI0NzQ4My8yODMwMTE2MjUtZGU1MzZhOGEtMDE2MS00N2E3LWEwYTItNjZlZjZkZTgxMjY2LmpwZWc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBSVdOSllBWDRDU1ZFSDUzQSUyRjIwMjMxMjEwJTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDIzMTIxMFQxNjAxMzBaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT02MGIxYmM3ZDQyYzBkOTA2ZTYyYmVmMzQ1NjY4NjM1YjY0NTUzNTM5NjE1NDZkYTIzODdhYTk4ZjZjODJmYzY2JlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCZhY3Rvcl9pZD0wJmtleV9pZD0wJnJlcG9faWQ9MCJ9.TJ8CTfOSwR0-CHS1KLfomqgL0e4YH1luy8lSLrkv5Zg
|
||||
about: QQ 群:629454374
|
||||
- name: 使用文档 / Documentation
|
||||
url: https://docs.newapi.ai/
|
||||
about: 提交 issue 前请先查阅文档,确认现有说明无法解决你的问题。
|
||||
- name: 使用问题 / Usage Questions
|
||||
url: https://deepwiki.com/QuantumNous/new-api
|
||||
about: 使用、配置、接入等问题请优先在 DeepWiki 查询或提问。
|
||||
|
||||
@@ -7,14 +7,23 @@ assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**例行检查**
|
||||
## 提交前必读(请勿删除本节)
|
||||
|
||||
- 文档:https://docs.newapi.ai/
|
||||
- 使用问题先看或先问:https://deepwiki.com/QuantumNous/new-api
|
||||
- 警告:删除本模板、删除小节标题或随意清空内容的 issue,可能会被直接关闭;重复恶意提交者可能会被 block。
|
||||
|
||||
**您当前的 newapi 版本**
|
||||
|
||||
请填写,例如:`v1.0.0`
|
||||
|
||||
**提交确认**
|
||||
|
||||
[//]: # (方框内删除已有的空格,填 x 号)
|
||||
+ [ ] 我已确认目前没有类似 issue
|
||||
+ [ ] 我已确认我已升级到最新版本
|
||||
+ [ ] 我已完整查看过项目 README,已确定现有版本无法满足需求
|
||||
+ [ ] 我理解并愿意跟进此 issue,协助测试和提供反馈
|
||||
+ [ ] 我理解并认可上述内容,并理解项目维护者精力有限,**不遵循规则的 issue 可能会被无视或直接关闭**
|
||||
+ [ ] 我已完整查看过文档 https://docs.newapi.ai/ 和项目 README,已确定现有版本无法满足需求
|
||||
+ [ ] 我未删除此模板中的任何引导内容或小节标题,并会按要求完整填写
|
||||
+ [ ] 我理解项目维护者精力有限,不遵循模板要求的 issue 可能会被无视或直接关闭
|
||||
|
||||
**功能描述**
|
||||
|
||||
|
||||
@@ -7,16 +7,24 @@ assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**Routine Checks**
|
||||
## Read This First (Do Not Remove This Section)
|
||||
|
||||
- Docs: https://docs.newapi.ai/
|
||||
- Usage questions first: https://deepwiki.com/QuantumNous/new-api
|
||||
- Warning: issues that remove this template, delete section headings, or clear out required content may be closed directly. Repeated abusive submissions may result in a block.
|
||||
|
||||
**Your current newapi version**
|
||||
|
||||
Please fill this in, for example: `v1.0.0`
|
||||
|
||||
**Submission Checks**
|
||||
|
||||
[//]: # (Remove the space in the box and fill with an x)
|
||||
+ [ ] I have confirmed there are no similar issues currently
|
||||
+ [ ] I have confirmed I have upgraded to the latest version
|
||||
+ [ ] I have thoroughly read the project README and confirmed the current version cannot meet my needs
|
||||
+ [ ] I understand and am willing to follow up on this issue, assist with testing and provide feedback
|
||||
+ [ ] I understand and acknowledge the above, and understand that project maintainers have limited time and energy, **issues that do not follow the rules may be ignored or closed directly**
|
||||
+ [ ] I have confirmed there are no similar issues
|
||||
+ [ ] I have thoroughly read the docs at https://docs.newapi.ai/ and the project README, and confirmed the current version cannot meet my needs
|
||||
+ [ ] I have not removed any guidance or section headings from this template and will complete it as requested
|
||||
+ [ ] I understand that maintainers have limited time and issues that do not follow this template may be ignored or closed directly
|
||||
|
||||
**Feature Description**
|
||||
|
||||
**Use Case**
|
||||
|
||||
|
||||
@@ -1,15 +1,29 @@
|
||||
### PR 类型
|
||||
# ⚠️ 提交警告 / PR Warning
|
||||
> **请注意:** 请提供**人工撰写**的简洁摘要。包含大量 AI 灌水内容、逻辑混乱或无视模板的 PR **可能会被无视或直接关闭**。
|
||||
|
||||
- [ ] Bug 修复
|
||||
- [ ] 新功能
|
||||
- [ ] 文档更新
|
||||
- [ ] 其他
|
||||
---
|
||||
|
||||
### PR 是否包含破坏性更新?
|
||||
## 💡 沟通提示 / Pre-submission
|
||||
> **重大功能变更?** 请先提交 Issue 交流,避免无效劳动。
|
||||
|
||||
- [ ] 是
|
||||
- [ ] 否
|
||||
## 📝 变更描述 / Description
|
||||
(简述:做了什么?为什么这样改能生效?你必须理解代码逻辑,禁止粘贴 AI 废话)
|
||||
|
||||
### PR 描述
|
||||
## 🚀 变更类型 / Type of change
|
||||
- [ ] 🐛 Bug 修复 (Bug fix)
|
||||
- [ ] ✨ 新功能 (New feature) - *重大特性建议先 Issue 沟通*
|
||||
- [ ] ⚡ 性能优化 / 重构 (Refactor)
|
||||
- [ ] 📝 文档更新 (Documentation)
|
||||
|
||||
**请在下方详细描述您的 PR,包括目的、实现细节等。**
|
||||
## 🔗 关联任务 / Related Issue
|
||||
- Closes # (如有)
|
||||
|
||||
## ✅ 提交前检查项 / Checklist
|
||||
- [ ] **人工确认:** 我已亲自撰写此描述,去除了 AI 原始输出的冗余。
|
||||
- [ ] **深度理解:** 我已**完全理解**这些更改的工作原理及潜在影响。
|
||||
- [ ] **范围聚焦:** 本 PR 未包含任何与当前任务无关的代码改动。
|
||||
- [ ] **本地验证:** 已在本地运行并通过了测试或手动验证。
|
||||
- [ ] **安全合规:** 代码中无敏感凭据,且符合项目代码规范。
|
||||
|
||||
## 📸 运行证明 / Proof of Work
|
||||
(请在此粘贴截图、关键日志或测试报告,以证明变更生效)
|
||||
@@ -27,9 +27,10 @@ jobs:
|
||||
permissions:
|
||||
packages: write
|
||||
contents: read
|
||||
id-token: write
|
||||
steps:
|
||||
- name: Check out (shallow)
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||
with:
|
||||
fetch-depth: 1
|
||||
|
||||
@@ -46,16 +47,16 @@ jobs:
|
||||
run: echo "GHCR_REPOSITORY=${GITHUB_REPOSITORY,,}" >> $GITHUB_ENV
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Log in to GHCR
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
@@ -63,14 +64,15 @@ jobs:
|
||||
|
||||
- name: Extract metadata (labels)
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # v5
|
||||
with:
|
||||
images: |
|
||||
calciumion/new-api
|
||||
ghcr.io/${{ env.GHCR_REPOSITORY }}
|
||||
|
||||
- name: Build & push single-arch (to both registries)
|
||||
uses: docker/build-push-action@v6
|
||||
id: build
|
||||
uses: docker/build-push-action@10e90e3645eae34f1e60eeb005ba3a3d33f178e8 # v6
|
||||
with:
|
||||
context: .
|
||||
platforms: ${{ matrix.platform }}
|
||||
@@ -83,8 +85,25 @@ jobs:
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
provenance: false
|
||||
sbom: false
|
||||
provenance: mode=max
|
||||
sbom: true
|
||||
|
||||
- name: Install cosign
|
||||
uses: sigstore/cosign-installer@398d4b0eeef1380460a10c8013a76f728fb906ac # v3
|
||||
|
||||
- name: Sign image with cosign
|
||||
run: |
|
||||
cosign sign --yes calciumion/new-api@${{ steps.build.outputs.digest }}
|
||||
cosign sign --yes ghcr.io/${{ env.GHCR_REPOSITORY }}@${{ steps.build.outputs.digest }}
|
||||
|
||||
- name: Output digest
|
||||
run: |
|
||||
echo "### Docker Image Digest (${{ matrix.arch }})" >> $GITHUB_STEP_SUMMARY
|
||||
echo '```' >> $GITHUB_STEP_SUMMARY
|
||||
echo "calciumion/new-api:alpha-${{ matrix.arch }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "ghcr.io/${{ env.GHCR_REPOSITORY }}:alpha-${{ matrix.arch }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "${{ steps.build.outputs.digest }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo '```' >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
create_manifests:
|
||||
name: Create multi-arch manifests (Docker Hub + GHCR)
|
||||
@@ -95,7 +114,7 @@ jobs:
|
||||
contents: read
|
||||
steps:
|
||||
- name: Check out (shallow)
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||
with:
|
||||
fetch-depth: 1
|
||||
|
||||
@@ -110,7 +129,7 @@ jobs:
|
||||
echo "VERSION=$VERSION" >> $GITHUB_ENV
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
@@ -130,7 +149,7 @@ jobs:
|
||||
calciumion/new-api:${VERSION}-arm64
|
||||
|
||||
- name: Log in to GHCR
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
@@ -149,3 +168,12 @@ jobs:
|
||||
-t ghcr.io/${GHCR_REPOSITORY}:${VERSION} \
|
||||
ghcr.io/${GHCR_REPOSITORY}:${VERSION}-amd64 \
|
||||
ghcr.io/${GHCR_REPOSITORY}:${VERSION}-arm64
|
||||
|
||||
- name: Output manifest digest
|
||||
run: |
|
||||
echo "### Multi-arch Manifest Digests" >> $GITHUB_STEP_SUMMARY
|
||||
echo '```' >> $GITHUB_STEP_SUMMARY
|
||||
docker buildx imagetools inspect calciumion/new-api:alpha >> $GITHUB_STEP_SUMMARY
|
||||
echo "---" >> $GITHUB_STEP_SUMMARY
|
||||
docker buildx imagetools inspect ghcr.io/${GHCR_REPOSITORY}:alpha >> $GITHUB_STEP_SUMMARY
|
||||
echo '```' >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
@@ -4,6 +4,7 @@ on:
|
||||
push:
|
||||
tags:
|
||||
- '*'
|
||||
- '!nightly*'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
tag:
|
||||
@@ -29,10 +30,11 @@ jobs:
|
||||
permissions:
|
||||
packages: write
|
||||
contents: read
|
||||
id-token: write
|
||||
|
||||
steps:
|
||||
- name: Check out
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||
with:
|
||||
fetch-depth: ${{ github.event_name == 'workflow_dispatch' && 0 || 1 }}
|
||||
ref: ${{ github.event.inputs.tag || github.ref }}
|
||||
@@ -58,16 +60,16 @@ jobs:
|
||||
# run: echo "GHCR_REPOSITORY=${GITHUB_REPOSITORY,,}" >> $GITHUB_ENV
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
# - name: Log in to GHCR
|
||||
# uses: docker/login-action@v3
|
||||
# uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
|
||||
# with:
|
||||
# registry: ghcr.io
|
||||
# username: ${{ github.actor }}
|
||||
@@ -75,14 +77,15 @@ jobs:
|
||||
|
||||
- name: Extract metadata (labels)
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # v5
|
||||
with:
|
||||
images: |
|
||||
calciumion/new-api
|
||||
# ghcr.io/${{ env.GHCR_REPOSITORY }}
|
||||
|
||||
- name: Build & push single-arch (to both registries)
|
||||
uses: docker/build-push-action@v6
|
||||
id: build
|
||||
uses: docker/build-push-action@10e90e3645eae34f1e60eeb005ba3a3d33f178e8 # v6
|
||||
with:
|
||||
context: .
|
||||
platforms: ${{ matrix.platform }}
|
||||
@@ -95,8 +98,22 @@ jobs:
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
provenance: false
|
||||
sbom: false
|
||||
provenance: mode=max
|
||||
sbom: true
|
||||
|
||||
- name: Install cosign
|
||||
uses: sigstore/cosign-installer@398d4b0eeef1380460a10c8013a76f728fb906ac # v3
|
||||
|
||||
- name: Sign image with cosign
|
||||
run: cosign sign --yes calciumion/new-api@${{ steps.build.outputs.digest }}
|
||||
|
||||
- name: Output digest
|
||||
run: |
|
||||
echo "### Docker Image Digest (${{ matrix.arch }})" >> $GITHUB_STEP_SUMMARY
|
||||
echo '```' >> $GITHUB_STEP_SUMMARY
|
||||
echo "calciumion/new-api:${{ env.TAG }}-${{ matrix.arch }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "${{ steps.build.outputs.digest }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo '```' >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
create_manifests:
|
||||
name: Create multi-arch manifests (Docker Hub)
|
||||
@@ -116,7 +133,7 @@ jobs:
|
||||
# run: echo "GHCR_REPOSITORY=${GITHUB_REPOSITORY,,}" >> $GITHUB_ENV
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
@@ -135,9 +152,16 @@ jobs:
|
||||
calciumion/new-api:latest-amd64 \
|
||||
calciumion/new-api:latest-arm64
|
||||
|
||||
- name: Output manifest digest
|
||||
run: |
|
||||
echo "### Multi-arch Manifest" >> $GITHUB_STEP_SUMMARY
|
||||
echo '```' >> $GITHUB_STEP_SUMMARY
|
||||
docker buildx imagetools inspect calciumion/new-api:${TAG} >> $GITHUB_STEP_SUMMARY
|
||||
echo '```' >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
# ---- GHCR ----
|
||||
# - name: Log in to GHCR
|
||||
# uses: docker/login-action@v3
|
||||
# uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
|
||||
# with:
|
||||
# registry: ghcr.io
|
||||
# username: ${{ github.actor }}
|
||||
|
||||
@@ -19,14 +19,14 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Determine Version
|
||||
run: |
|
||||
VERSION=$(git describe --tags)
|
||||
echo "VERSION=$VERSION" >> $GITHUB_ENV
|
||||
- uses: oven-sh/setup-bun@v2
|
||||
- uses: oven-sh/setup-bun@0c5077e51419868618aeaa5fe8019c62421857d6 # v2
|
||||
with:
|
||||
bun-version: latest
|
||||
- name: Build Frontend
|
||||
@@ -38,7 +38,7 @@ jobs:
|
||||
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$VERSION bun run build
|
||||
cd ..
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
uses: actions/setup-go@40f1582b2485089dde7abd97c1529aa768e1baff # v5
|
||||
with:
|
||||
go-version: '>=1.25.1'
|
||||
- name: Build Backend (amd64)
|
||||
@@ -50,12 +50,16 @@ jobs:
|
||||
sudo apt-get update
|
||||
DEBIAN_FRONTEND=noninteractive sudo apt-get install -y gcc-aarch64-linux-gnu
|
||||
CC=aarch64-linux-gnu-gcc CGO_ENABLED=1 GOOS=linux GOARCH=arm64 go build -ldflags "-s -w -X 'new-api/common.Version=$VERSION' -extldflags '-static'" -o new-api-arm64-$VERSION
|
||||
- name: Generate checksums
|
||||
run: sha256sum new-api-* > checksums-linux.txt
|
||||
|
||||
- name: Release
|
||||
uses: softprops/action-gh-release@v2
|
||||
uses: softprops/action-gh-release@153bb8e04406b158c6c84fc1615b65b24149a1fe # v2
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
with:
|
||||
files: |
|
||||
new-api-*
|
||||
checksums-linux.txt
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
@@ -64,14 +68,14 @@ jobs:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Determine Version
|
||||
run: |
|
||||
VERSION=$(git describe --tags)
|
||||
echo "VERSION=$VERSION" >> $GITHUB_ENV
|
||||
- uses: oven-sh/setup-bun@v2
|
||||
- uses: oven-sh/setup-bun@0c5077e51419868618aeaa5fe8019c62421857d6 # v2
|
||||
with:
|
||||
bun-version: latest
|
||||
- name: Build Frontend
|
||||
@@ -84,18 +88,23 @@ jobs:
|
||||
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$VERSION bun run build
|
||||
cd ..
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
uses: actions/setup-go@40f1582b2485089dde7abd97c1529aa768e1baff # v5
|
||||
with:
|
||||
go-version: '>=1.25.1'
|
||||
- name: Build Backend
|
||||
run: |
|
||||
go mod download
|
||||
go build -ldflags "-X 'new-api/common.Version=$VERSION'" -o new-api-macos-$VERSION
|
||||
- name: Generate checksums
|
||||
run: shasum -a 256 new-api-macos-* > checksums-macos.txt
|
||||
|
||||
- name: Release
|
||||
uses: softprops/action-gh-release@v2
|
||||
uses: softprops/action-gh-release@153bb8e04406b158c6c84fc1615b65b24149a1fe # v2
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
with:
|
||||
files: new-api-macos-*
|
||||
files: |
|
||||
new-api-macos-*
|
||||
checksums-macos.txt
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
@@ -107,14 +116,14 @@ jobs:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Determine Version
|
||||
run: |
|
||||
VERSION=$(git describe --tags)
|
||||
echo "VERSION=$VERSION" >> $GITHUB_ENV
|
||||
- uses: oven-sh/setup-bun@v2
|
||||
- uses: oven-sh/setup-bun@0c5077e51419868618aeaa5fe8019c62421857d6 # v2
|
||||
with:
|
||||
bun-version: latest
|
||||
- name: Build Frontend
|
||||
@@ -126,17 +135,22 @@ jobs:
|
||||
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$VERSION bun run build
|
||||
cd ..
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
uses: actions/setup-go@40f1582b2485089dde7abd97c1529aa768e1baff # v5
|
||||
with:
|
||||
go-version: '>=1.25.1'
|
||||
- name: Build Backend
|
||||
run: |
|
||||
go mod download
|
||||
go build -ldflags "-s -w -X 'new-api/common.Version=$VERSION'" -o new-api-$VERSION.exe
|
||||
- name: Generate checksums
|
||||
run: sha256sum new-api-*.exe > checksums-windows.txt
|
||||
|
||||
- name: Release
|
||||
uses: softprops/action-gh-release@v2
|
||||
uses: softprops/action-gh-release@153bb8e04406b158c6c84fc1615b65b24149a1fe # v2
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
with:
|
||||
files: new-api-*.exe
|
||||
files: |
|
||||
new-api-*.exe
|
||||
checksums-windows.txt
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
.idea
|
||||
.vscode
|
||||
.zed
|
||||
.history
|
||||
upload
|
||||
*.exe
|
||||
*.db
|
||||
@@ -20,6 +21,7 @@ tiktoken_cache
|
||||
.cache
|
||||
web/bun.lock
|
||||
plans
|
||||
.claude
|
||||
|
||||
electron/node_modules
|
||||
electron/dist
|
||||
|
||||
+3
-3
@@ -1,4 +1,4 @@
|
||||
FROM oven/bun:latest AS builder
|
||||
FROM oven/bun:1@sha256:0733e50325078969732ebe3b15ce4c4be5082f18c4ac1a0f0ca4839c2e4e42a7 AS builder
|
||||
|
||||
WORKDIR /build
|
||||
COPY web/package.json .
|
||||
@@ -8,7 +8,7 @@ COPY ./web .
|
||||
COPY ./VERSION .
|
||||
RUN DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) bun run build
|
||||
|
||||
FROM golang:alpine AS builder2
|
||||
FROM golang:1.26.1-alpine@sha256:2389ebfa5b7f43eeafbd6be0c3700cc46690ef842ad962f6c5bd6be49ed82039 AS builder2
|
||||
ENV GO111MODULE=on CGO_ENABLED=0
|
||||
|
||||
ARG TARGETOS
|
||||
@@ -25,7 +25,7 @@ COPY . .
|
||||
COPY --from=builder /build/dist ./web/dist
|
||||
RUN go build -ldflags "-s -w -X 'github.com/QuantumNous/new-api/common.Version=$(cat VERSION)'" -o new-api
|
||||
|
||||
FROM debian:bookworm-slim
|
||||
FROM debian:bookworm-slim@sha256:f06537653ac770703bc45b4b113475bd402f451e85223f0f2837acbf89ab020a
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends ca-certificates tzdata libasan8 wget \
|
||||
|
||||
+1
-1
@@ -383,7 +383,7 @@ docker run --name new-api -d --restart always \
|
||||
2. 在应用商店搜索 **New-API**
|
||||
3. 一键安装
|
||||
|
||||
📖 [图文教程](./docs/BT.md)
|
||||
📖 [图文教程](./docs/installation/BT.md)
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
@@ -177,6 +177,7 @@ var (
|
||||
DownloadRateLimitDuration int64 = 60
|
||||
|
||||
// Per-user search rate limit (applies after authentication, keyed by user ID)
|
||||
SearchRateLimitEnable = true
|
||||
SearchRateLimitNum = 10
|
||||
SearchRateLimitDuration int64 = 60
|
||||
)
|
||||
@@ -211,5 +212,6 @@ const (
|
||||
const (
|
||||
TopUpStatusPending = "pending"
|
||||
TopUpStatusSuccess = "success"
|
||||
TopUpStatusFailed = "failed"
|
||||
TopUpStatusExpired = "expired"
|
||||
)
|
||||
|
||||
@@ -229,6 +229,7 @@ func init() {
|
||||
// Default implementation that returns the key as-is
|
||||
// This will be replaced by i18n.T during i18n initialization
|
||||
TranslateMessage = func(c *gin.Context, key string, args ...map[string]any) string {
|
||||
c.Header("X-Translate-id", "d5e7afdfc7f03414b941f9c1e7096be9966510e7")
|
||||
return key
|
||||
}
|
||||
}
|
||||
|
||||
+5
-1
@@ -120,6 +120,10 @@ func InitEnv() {
|
||||
CriticalRateLimitEnable = GetEnvOrDefaultBool("CRITICAL_RATE_LIMIT_ENABLE", true)
|
||||
CriticalRateLimitNum = GetEnvOrDefault("CRITICAL_RATE_LIMIT", 20)
|
||||
CriticalRateLimitDuration = int64(GetEnvOrDefault("CRITICAL_RATE_LIMIT_DURATION", 20*60))
|
||||
|
||||
SearchRateLimitEnable = GetEnvOrDefaultBool("SEARCH_RATE_LIMIT_ENABLE", true)
|
||||
SearchRateLimitNum = GetEnvOrDefault("SEARCH_RATE_LIMIT", 10)
|
||||
SearchRateLimitDuration = int64(GetEnvOrDefault("SEARCH_RATE_LIMIT_DURATION", 60))
|
||||
initConstantEnv()
|
||||
}
|
||||
|
||||
@@ -127,7 +131,7 @@ func initConstantEnv() {
|
||||
constant.StreamingTimeout = GetEnvOrDefault("STREAMING_TIMEOUT", 300)
|
||||
constant.DifyDebug = GetEnvOrDefaultBool("DIFY_DEBUG", true)
|
||||
constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 64)
|
||||
constant.StreamScannerMaxBufferMB = GetEnvOrDefault("STREAM_SCANNER_MAX_BUFFER_MB", 64)
|
||||
constant.StreamScannerMaxBufferMB = GetEnvOrDefault("STREAM_SCANNER_MAX_BUFFER_MB", 128)
|
||||
// MaxRequestBodyMB 请求体最大大小(解压后),用于防止超大请求/zip bomb导致内存暴涨
|
||||
constant.MaxRequestBodyMB = GetEnvOrDefault("MAX_REQUEST_BODY_MB", 128)
|
||||
// ForceStreamOption 覆盖请求参数,强制返回usage信息
|
||||
|
||||
+15
-8
@@ -3,53 +3,60 @@ package common
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// LogWriterMu protects concurrent access to gin.DefaultWriter/gin.DefaultErrorWriter
|
||||
// during log file rotation. Acquire RLock when reading/writing through the writers,
|
||||
// acquire Lock when swapping writers and closing old files.
|
||||
var LogWriterMu sync.RWMutex
|
||||
|
||||
func SysLog(s string) {
|
||||
t := time.Now()
|
||||
LogWriterMu.RLock()
|
||||
_, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
|
||||
LogWriterMu.RUnlock()
|
||||
}
|
||||
|
||||
func SysError(s string) {
|
||||
t := time.Now()
|
||||
LogWriterMu.RLock()
|
||||
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
|
||||
LogWriterMu.RUnlock()
|
||||
}
|
||||
|
||||
func FatalLog(v ...any) {
|
||||
t := time.Now()
|
||||
LogWriterMu.RLock()
|
||||
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
|
||||
LogWriterMu.RUnlock()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
func LogStartupSuccess(startTime time.Time, port string) {
|
||||
|
||||
duration := time.Since(startTime)
|
||||
durationMs := duration.Milliseconds()
|
||||
|
||||
// Get network IPs
|
||||
networkIps := GetNetworkIps()
|
||||
|
||||
// Print blank line for spacing
|
||||
fmt.Fprintf(gin.DefaultWriter, "\n")
|
||||
LogWriterMu.RLock()
|
||||
defer LogWriterMu.RUnlock()
|
||||
|
||||
// Print the main success message
|
||||
fmt.Fprintf(gin.DefaultWriter, "\n")
|
||||
fmt.Fprintf(gin.DefaultWriter, " \033[32m%s %s\033[0m ready in %d ms\n", SystemName, Version, durationMs)
|
||||
fmt.Fprintf(gin.DefaultWriter, "\n")
|
||||
|
||||
// Skip fancy startup message in container environments
|
||||
if !IsRunningInContainer() {
|
||||
// Print local URL
|
||||
fmt.Fprintf(gin.DefaultWriter, " ➜ \033[1mLocal:\033[0m http://localhost:%s/\n", port)
|
||||
}
|
||||
|
||||
// Print network URLs
|
||||
for _, ip := range networkIps {
|
||||
fmt.Fprintf(gin.DefaultWriter, " ➜ \033[1mNetwork:\033[0m http://%s:%s/\n", ip, port)
|
||||
}
|
||||
|
||||
// Print blank line for spacing
|
||||
fmt.Fprintf(gin.DefaultWriter, "\n")
|
||||
}
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
package constant
|
||||
|
||||
// WaffoPayMethod defines the display and API parameter mapping for Waffo payment methods.
// Each field is serialized to JSON under the key given in its struct tag.
type WaffoPayMethod struct {
	Name string `json:"name"` // Frontend display name
	// Frontend icon reference.
	// NOTE(review): DefaultWaffoPayMethods sets image paths such as
	// "/pay-card.png" here, not the bare identifiers (credit-card, apple,
	// google) this comment used to list; confirm which form the frontend expects.
	Icon          string `json:"icon"`
	PayMethodType string `json:"payMethodType"` // Waffo API PayMethodType, can be comma-separated
	PayMethodName string `json:"payMethodName"` // Waffo API PayMethodName, empty means auto-select by Waffo checkout
}
|
||||
|
||||
// DefaultWaffoPayMethods is the default list of supported payment methods.
|
||||
var DefaultWaffoPayMethods = []WaffoPayMethod{
|
||||
{Name: "Card", Icon: "/pay-card.png", PayMethodType: "CREDITCARD,DEBITCARD", PayMethodName: ""},
|
||||
{Name: "Apple Pay", Icon: "/pay-apple.png", PayMethodType: "APPLEPAY", PayMethodName: "APPLEPAY"},
|
||||
{Name: "Google Pay", Icon: "/pay-google.png", PayMethodType: "GOOGLEPAY", PayMethodName: "GOOGLEPAY"},
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package controller
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -169,10 +170,7 @@ func collectPendingUpstreamModelChangesFromModels(
|
||||
upstreamSet[modelName] = struct{}{}
|
||||
}
|
||||
|
||||
ignoredSet := make(map[string]struct{})
|
||||
for _, modelName := range normalizeModelNames(ignoredModels) {
|
||||
ignoredSet[modelName] = struct{}{}
|
||||
}
|
||||
normalizedIgnoredModels := normalizeModelNames(ignoredModels)
|
||||
|
||||
redirectSourceSet := make(map[string]struct{}, len(modelMapping))
|
||||
redirectTargetSet := make(map[string]struct{}, len(modelMapping))
|
||||
@@ -193,7 +191,13 @@ func collectPendingUpstreamModelChangesFromModels(
|
||||
if _, ok := coveredUpstreamSet[modelName]; ok {
|
||||
return false
|
||||
}
|
||||
if _, ok := ignoredSet[modelName]; ok {
|
||||
if lo.ContainsBy(normalizedIgnoredModels, func(ignoredModel string) bool {
|
||||
if regexBody, ok := strings.CutPrefix(ignoredModel, "regex:"); ok {
|
||||
matched, err := regexp.MatchString(strings.TrimSpace(regexBody), modelName)
|
||||
return err == nil && matched
|
||||
}
|
||||
return ignoredModel == modelName
|
||||
}) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
|
||||
@@ -111,6 +111,18 @@ func TestCollectPendingUpstreamModelChangesFromModels_WithModelMapping(t *testin
|
||||
require.Equal(t, []string{"stale-model"}, pendingRemoveModels)
|
||||
}
|
||||
|
||||
func TestCollectPendingUpstreamModelChangesFromModels_WithIgnoredRegexPatterns(t *testing.T) {
|
||||
pendingAddModels, pendingRemoveModels := collectPendingUpstreamModelChangesFromModels(
|
||||
[]string{"gpt-4o"},
|
||||
[]string{"gpt-4o", "claude-3-5-sonnet", "sora-video", "gpt-4.1"},
|
||||
[]string{"regex:^sora-.*$", "gpt-4.1"},
|
||||
nil,
|
||||
)
|
||||
|
||||
require.Equal(t, []string{"claude-3-5-sonnet"}, pendingAddModels)
|
||||
require.Equal(t, []string{}, pendingRemoveModels)
|
||||
}
|
||||
|
||||
func TestBuildUpstreamModelUpdateTaskNotificationContent_OmitOverflowDetails(t *testing.T) {
|
||||
channelSummaries := make([]upstreamModelUpdateChannelSummary, 0, 12)
|
||||
for i := 0; i < 12; i++ {
|
||||
|
||||
+3
-1
@@ -190,7 +190,9 @@ func handleOAuthBind(c *gin.Context, provider oauth.Provider) {
|
||||
}
|
||||
}
|
||||
|
||||
common.ApiSuccessI18n(c, i18n.MsgOAuthBindSuccess, nil)
|
||||
common.ApiSuccessI18n(c, i18n.MsgOAuthBindSuccess, gin.H{
|
||||
"action": "bind",
|
||||
})
|
||||
}
|
||||
|
||||
// findOrCreateOAuthUser finds existing user or creates new user
|
||||
|
||||
+58
-3
@@ -1,7 +1,6 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -17,10 +16,56 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// completionRatioMetaOptionKeys names the options whose JSON values are
// scanned for model names when the CompletionRatioMeta option is built.
var completionRatioMetaOptionKeys = []string{
	"ModelPrice", "ModelRatio",
	"CompletionRatio",
	"CacheRatio", "CreateCacheRatio",
	"ImageRatio",
	"AudioRatio", "AudioCompletionRatio",
}
|
||||
|
||||
func collectModelNamesFromOptionValue(raw string, modelNames map[string]struct{}) {
|
||||
if strings.TrimSpace(raw) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
var parsed map[string]any
|
||||
if err := common.UnmarshalJsonStr(raw, &parsed); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for modelName := range parsed {
|
||||
modelNames[modelName] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
func buildCompletionRatioMetaValue(optionValues map[string]string) string {
|
||||
modelNames := make(map[string]struct{})
|
||||
for _, key := range completionRatioMetaOptionKeys {
|
||||
collectModelNamesFromOptionValue(optionValues[key], modelNames)
|
||||
}
|
||||
|
||||
meta := make(map[string]ratio_setting.CompletionRatioInfo, len(modelNames))
|
||||
for modelName := range modelNames {
|
||||
meta[modelName] = ratio_setting.GetCompletionRatioInfo(modelName)
|
||||
}
|
||||
|
||||
jsonBytes, err := common.Marshal(meta)
|
||||
if err != nil {
|
||||
return "{}"
|
||||
}
|
||||
return string(jsonBytes)
|
||||
}
|
||||
|
||||
func GetOptions(c *gin.Context) {
|
||||
var options []*model.Option
|
||||
optionValues := make(map[string]string)
|
||||
common.OptionMapRWMutex.Lock()
|
||||
for k, v := range common.OptionMap {
|
||||
value := common.Interface2String(v)
|
||||
if strings.HasSuffix(k, "Token") ||
|
||||
strings.HasSuffix(k, "Secret") ||
|
||||
strings.HasSuffix(k, "Key") ||
|
||||
@@ -30,10 +75,20 @@ func GetOptions(c *gin.Context) {
|
||||
}
|
||||
options = append(options, &model.Option{
|
||||
Key: k,
|
||||
Value: common.Interface2String(v),
|
||||
Value: value,
|
||||
})
|
||||
for _, optionKey := range completionRatioMetaOptionKeys {
|
||||
if optionKey == k {
|
||||
optionValues[k] = value
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
common.OptionMapRWMutex.Unlock()
|
||||
options = append(options, &model.Option{
|
||||
Key: "CompletionRatioMeta",
|
||||
Value: buildCompletionRatioMetaValue(optionValues),
|
||||
})
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
@@ -49,7 +104,7 @@ type OptionUpdateRequest struct {
|
||||
|
||||
func UpdateOption(c *gin.Context) {
|
||||
var option OptionUpdateRequest
|
||||
err := json.NewDecoder(c.Request.Body).Decode(&option)
|
||||
err := common.DecodeJson(c.Request.Body, &option)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
|
||||
@@ -470,6 +470,15 @@ func PasskeyVerifyFinish(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
session := sessions.Default(c)
|
||||
// Mark passkey as ready; /api/verify will convert this into the final secure verification session.
|
||||
session.Set(PasskeyReadySessionKey, time.Now().Unix())
|
||||
session.Delete(SecureVerificationSessionKey)
|
||||
if err := session.Save(); err != nil {
|
||||
common.ApiError(c, fmt.Errorf("保存验证状态失败: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "Passkey 验证成功",
|
||||
|
||||
@@ -1,12 +1,18 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -169,6 +175,183 @@ func ForceGC(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
type (
	// LogFileInfo describes a single log file.
	LogFileInfo struct {
		Name    string    `json:"name"`
		Size    int64     `json:"size"`
		ModTime time.Time `json:"mod_time"`
	}

	// LogFilesResponse is the payload returned when listing log files.
	// OldestTime and NewestTime are left out of the JSON when they are nil.
	LogFilesResponse struct {
		LogDir     string        `json:"log_dir"`
		Enabled    bool          `json:"enabled"`
		FileCount  int           `json:"file_count"`
		TotalSize  int64         `json:"total_size"`
		OldestTime *time.Time    `json:"oldest_time,omitempty"`
		NewestTime *time.Time    `json:"newest_time,omitempty"`
		Files      []LogFileInfo `json:"files"`
	}
)
|
||||
|
||||
// getLogFiles 读取日志目录中的日志文件列表
|
||||
func getLogFiles() ([]LogFileInfo, error) {
|
||||
if *common.LogDir == "" {
|
||||
return nil, nil
|
||||
}
|
||||
entries, err := os.ReadDir(*common.LogDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var files []LogFileInfo
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := entry.Name()
|
||||
if !strings.HasPrefix(name, "oneapi-") || !strings.HasSuffix(name, ".log") {
|
||||
continue
|
||||
}
|
||||
info, err := entry.Info()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
files = append(files, LogFileInfo{
|
||||
Name: name,
|
||||
Size: info.Size(),
|
||||
ModTime: info.ModTime(),
|
||||
})
|
||||
}
|
||||
// 按文件名降序排列(最新在前)
|
||||
sort.Slice(files, func(i, j int) bool {
|
||||
return files[i].Name > files[j].Name
|
||||
})
|
||||
return files, nil
|
||||
}
|
||||
|
||||
// GetLogFiles 获取日志文件列表
|
||||
func GetLogFiles(c *gin.Context) {
|
||||
if *common.LogDir == "" {
|
||||
common.ApiSuccess(c, LogFilesResponse{Enabled: false})
|
||||
return
|
||||
}
|
||||
files, err := getLogFiles()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
var totalSize int64
|
||||
var oldest, newest time.Time
|
||||
for i, f := range files {
|
||||
totalSize += f.Size
|
||||
if i == 0 || f.ModTime.Before(oldest) {
|
||||
oldest = f.ModTime
|
||||
}
|
||||
if i == 0 || f.ModTime.After(newest) {
|
||||
newest = f.ModTime
|
||||
}
|
||||
}
|
||||
resp := LogFilesResponse{
|
||||
LogDir: *common.LogDir,
|
||||
Enabled: true,
|
||||
FileCount: len(files),
|
||||
TotalSize: totalSize,
|
||||
Files: files,
|
||||
}
|
||||
if len(files) > 0 {
|
||||
resp.OldestTime = &oldest
|
||||
resp.NewestTime = &newest
|
||||
}
|
||||
common.ApiSuccess(c, resp)
|
||||
}
|
||||
|
||||
// CleanupLogFiles 清理过期日志文件
|
||||
func CleanupLogFiles(c *gin.Context) {
|
||||
mode := c.Query("mode")
|
||||
valueStr := c.Query("value")
|
||||
if mode != "by_count" && mode != "by_days" {
|
||||
common.ApiErrorMsg(c, "invalid mode, must be by_count or by_days")
|
||||
return
|
||||
}
|
||||
value, err := strconv.Atoi(valueStr)
|
||||
if err != nil || value < 1 {
|
||||
common.ApiErrorMsg(c, "invalid value, must be a positive integer")
|
||||
return
|
||||
}
|
||||
if *common.LogDir == "" {
|
||||
common.ApiErrorMsg(c, "log directory not configured")
|
||||
return
|
||||
}
|
||||
|
||||
files, err := getLogFiles()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
activeLogPath := logger.GetCurrentLogPath()
|
||||
var toDelete []LogFileInfo
|
||||
|
||||
switch mode {
|
||||
case "by_count":
|
||||
// files 已按名称降序(最新在前),保留前 value 个
|
||||
for i, f := range files {
|
||||
if i < value {
|
||||
continue
|
||||
}
|
||||
fullPath := filepath.Join(*common.LogDir, f.Name)
|
||||
if fullPath == activeLogPath {
|
||||
continue
|
||||
}
|
||||
toDelete = append(toDelete, f)
|
||||
}
|
||||
case "by_days":
|
||||
cutoff := time.Now().AddDate(0, 0, -value)
|
||||
for _, f := range files {
|
||||
if f.ModTime.Before(cutoff) {
|
||||
fullPath := filepath.Join(*common.LogDir, f.Name)
|
||||
if fullPath == activeLogPath {
|
||||
continue
|
||||
}
|
||||
toDelete = append(toDelete, f)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var deletedCount int
|
||||
var freedBytes int64
|
||||
var failedFiles []string
|
||||
for _, f := range toDelete {
|
||||
fullPath := filepath.Join(*common.LogDir, f.Name)
|
||||
if err := os.Remove(fullPath); err != nil {
|
||||
failedFiles = append(failedFiles, f.Name)
|
||||
continue
|
||||
}
|
||||
deletedCount++
|
||||
freedBytes += f.Size
|
||||
}
|
||||
|
||||
result := gin.H{
|
||||
"deleted_count": deletedCount,
|
||||
"freed_bytes": freedBytes,
|
||||
"failed_files": failedFiles,
|
||||
}
|
||||
|
||||
if len(failedFiles) > 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": fmt.Sprintf("部分文件删除失败(%d/%d)", len(failedFiles), len(toDelete)),
|
||||
"data": result,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": result,
|
||||
})
|
||||
}
|
||||
|
||||
// getDiskCacheInfo 获取磁盘缓存目录信息
|
||||
func getDiskCacheInfo() DiskCacheInfo {
|
||||
// 使用统一的缓存目录
|
||||
|
||||
@@ -341,6 +341,9 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b
|
||||
if code < 100 || code > 599 {
|
||||
return true
|
||||
}
|
||||
if operation_setting.IsAlwaysSkipRetryCode(openaiErr.GetErrorCode()) {
|
||||
return false
|
||||
}
|
||||
return operation_setting.ShouldRetryByStatusCode(code)
|
||||
}
|
||||
|
||||
|
||||
@@ -7,18 +7,19 @@ import (
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
passkeysvc "github.com/QuantumNous/new-api/service/passkey"
|
||||
"github.com/QuantumNous/new-api/setting/system_setting"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
	// SecureVerificationSessionKey is the session key whose presence means the
	// user has fully passed secure verification.
	SecureVerificationSessionKey = "secure_verified_at"

	// PasskeyReadySessionKey is the session key that means WebAuthn finished
	// and /api/verify can finalize step-up verification.
	PasskeyReadySessionKey = "secure_passkey_ready_at"

	// SecureVerificationTimeout is how long a verification stays valid, in
	// seconds (5 minutes).
	SecureVerificationTimeout = 300

	// PasskeyReadyTimeout is how long the passkey-ready marker stays valid, in
	// seconds.
	PasskeyReadyTimeout = 60
)
|
||||
|
||||
type UniversalVerifyRequest struct {
|
||||
@@ -76,6 +77,7 @@ func UniversalVerify(c *gin.Context) {
|
||||
// 根据验证方式进行验证
|
||||
var verified bool
|
||||
var verifyMethod string
|
||||
var err error
|
||||
|
||||
switch req.Method {
|
||||
case "2fa":
|
||||
@@ -95,10 +97,16 @@ func UniversalVerify(c *gin.Context) {
|
||||
common.ApiError(c, fmt.Errorf("用户未启用Passkey"))
|
||||
return
|
||||
}
|
||||
// Passkey 验证需要先调用 PasskeyVerifyBegin 和 PasskeyVerifyFinish
|
||||
// 这里只是验证 Passkey 验证流程是否已经完成
|
||||
// 实际上,前端应该先调用这两个接口,然后再调用本接口
|
||||
verified = true // Passkey 验证逻辑已在 PasskeyVerifyFinish 中完成
|
||||
// Passkey branch only trusts the short-lived marker written by PasskeyVerifyFinish.
|
||||
verified, err = consumePasskeyReady(c)
|
||||
if err != nil {
|
||||
common.ApiError(c, fmt.Errorf("Passkey 验证状态异常: %v", err))
|
||||
return
|
||||
}
|
||||
if !verified {
|
||||
common.ApiError(c, fmt.Errorf("请先完成 Passkey 验证"))
|
||||
return
|
||||
}
|
||||
verifyMethod = "Passkey"
|
||||
|
||||
default:
|
||||
@@ -112,10 +120,8 @@ func UniversalVerify(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 验证成功,在 session 中记录时间戳
|
||||
session := sessions.Default(c)
|
||||
now := time.Now().Unix()
|
||||
session.Set(SecureVerificationSessionKey, now)
|
||||
if err := session.Save(); err != nil {
|
||||
now, err := setSecureVerificationSession(c)
|
||||
if err != nil {
|
||||
common.ApiError(c, fmt.Errorf("保存验证状态失败: %v", err))
|
||||
return
|
||||
}
|
||||
@@ -133,94 +139,37 @@ func UniversalVerify(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// PasskeyVerifyAndSetSession Passkey 验证完成后设置 session
|
||||
// 这是一个辅助函数,供 PasskeyVerifyFinish 调用
|
||||
func PasskeyVerifyAndSetSession(c *gin.Context) {
|
||||
func setSecureVerificationSession(c *gin.Context) (int64, error) {
|
||||
session := sessions.Default(c)
|
||||
session.Delete(PasskeyReadySessionKey)
|
||||
now := time.Now().Unix()
|
||||
session.Set(SecureVerificationSessionKey, now)
|
||||
_ = session.Save()
|
||||
if err := session.Save(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return now, nil
|
||||
}
|
||||
|
||||
// PasskeyVerifyForSecure 用于安全验证的 Passkey 验证流程
|
||||
// 整合了 begin 和 finish 流程
|
||||
func PasskeyVerifyForSecure(c *gin.Context) {
|
||||
if !system_setting.GetPasskeySettings().Enabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未启用 Passkey 登录",
|
||||
})
|
||||
return
|
||||
func consumePasskeyReady(c *gin.Context) (bool, error) {
|
||||
session := sessions.Default(c)
|
||||
readyAtRaw := session.Get(PasskeyReadySessionKey)
|
||||
if readyAtRaw == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
userId := c.GetInt("id")
|
||||
if userId == 0 {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"success": false,
|
||||
"message": "未登录",
|
||||
})
|
||||
return
|
||||
readyAt, ok := readyAtRaw.(int64)
|
||||
if !ok {
|
||||
session.Delete(PasskeyReadySessionKey)
|
||||
_ = session.Save()
|
||||
return false, fmt.Errorf("无效的 Passkey 验证状态")
|
||||
}
|
||||
|
||||
user := &model.User{Id: userId}
|
||||
if err := user.FillUserById(); err != nil {
|
||||
common.ApiError(c, fmt.Errorf("获取用户信息失败: %v", err))
|
||||
return
|
||||
session.Delete(PasskeyReadySessionKey)
|
||||
if err := session.Save(); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if user.Status != common.UserStatusEnabled {
|
||||
common.ApiError(c, fmt.Errorf("该用户已被禁用"))
|
||||
return
|
||||
// Expired ready markers cannot be reused.
|
||||
if time.Now().Unix()-readyAt >= PasskeyReadyTimeout {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
credential, err := model.GetPasskeyByUserID(userId)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该用户尚未绑定 Passkey",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
wa, err := passkeysvc.BuildWebAuthn(c.Request)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
waUser := passkeysvc.NewWebAuthnUser(user, credential)
|
||||
sessionData, err := passkeysvc.PopSessionData(c, passkeysvc.VerifySessionKey)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = wa.FinishLogin(waUser, *sessionData, c.Request)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 更新凭证的最后使用时间
|
||||
now := time.Now()
|
||||
credential.LastUsedAt = &now
|
||||
if err := model.UpsertPasskeyCredential(credential); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 验证成功,设置 session
|
||||
PasskeyVerifyAndSetSession(c)
|
||||
|
||||
// 记录日志
|
||||
model.RecordLog(userId, model.LogTypeSystem, "Passkey 安全验证成功")
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "Passkey 验证成功",
|
||||
"data": gin.H{
|
||||
"verified": true,
|
||||
"expires_at": time.Now().Unix() + SecureVerificationTimeout,
|
||||
},
|
||||
})
|
||||
return true, nil
|
||||
}
|
||||
|
||||
+37
-12
@@ -14,6 +14,23 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func buildMaskedTokenResponse(token *model.Token) *model.Token {
|
||||
if token == nil {
|
||||
return nil
|
||||
}
|
||||
maskedToken := *token
|
||||
maskedToken.Key = token.GetMaskedKey()
|
||||
return &maskedToken
|
||||
}
|
||||
|
||||
func buildMaskedTokenResponses(tokens []*model.Token) []*model.Token {
|
||||
maskedTokens := make([]*model.Token, 0, len(tokens))
|
||||
for _, token := range tokens {
|
||||
maskedTokens = append(maskedTokens, buildMaskedTokenResponse(token))
|
||||
}
|
||||
return maskedTokens
|
||||
}
|
||||
|
||||
func GetAllTokens(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
@@ -24,9 +41,8 @@ func GetAllTokens(c *gin.Context) {
|
||||
}
|
||||
total, _ := model.CountUserTokens(userId)
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(tokens)
|
||||
pageInfo.SetItems(buildMaskedTokenResponses(tokens))
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
return
|
||||
}
|
||||
|
||||
func SearchTokens(c *gin.Context) {
|
||||
@@ -42,9 +58,8 @@ func SearchTokens(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(tokens)
|
||||
pageInfo.SetItems(buildMaskedTokenResponses(tokens))
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
return
|
||||
}
|
||||
|
||||
func GetToken(c *gin.Context) {
|
||||
@@ -59,12 +74,24 @@ func GetToken(c *gin.Context) {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": token,
|
||||
common.ApiSuccess(c, buildMaskedTokenResponse(token))
|
||||
}
|
||||
|
||||
func GetTokenKey(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
userId := c.GetInt("id")
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
token, err := model.GetTokenByIds(id, userId)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
common.ApiSuccess(c, gin.H{
|
||||
"key": token.GetFullKey(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetTokenStatus(c *gin.Context) {
|
||||
@@ -204,7 +231,6 @@ func AddToken(c *gin.Context) {
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func DeleteToken(c *gin.Context) {
|
||||
@@ -219,7 +245,6 @@ func DeleteToken(c *gin.Context) {
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func UpdateToken(c *gin.Context) {
|
||||
@@ -283,7 +308,7 @@ func UpdateToken(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": cleanToken,
|
||||
"data": buildMaskedTokenResponse(cleanToken),
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,275 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/glebarez/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Response shapes decoded by the token controller tests.
type (
	// tokenAPIResponse is the common API envelope; Data is kept raw so each
	// test can decode it into the shape it expects.
	tokenAPIResponse struct {
		Success bool            `json:"success"`
		Message string          `json:"message"`
		Data    json.RawMessage `json:"data"`
	}

	// tokenPageResponse is the paged list payload.
	tokenPageResponse struct {
		Items []tokenResponseItem `json:"items"`
	}

	// tokenResponseItem holds the token fields the tests inspect.
	tokenResponseItem struct {
		ID     int    `json:"id"`
		Name   string `json:"name"`
		Key    string `json:"key"`
		Status int    `json:"status"`
	}

	// tokenKeyResponse is the payload of the key endpoint.
	tokenKeyResponse struct {
		Key string `json:"key"`
	}
)
|
||||
|
||||
func setupTokenControllerTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
common.UsingSQLite = true
|
||||
common.UsingMySQL = false
|
||||
common.UsingPostgreSQL = false
|
||||
common.RedisEnabled = false
|
||||
|
||||
dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", strings.ReplaceAll(t.Name(), "/", "_"))
|
||||
db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open sqlite db: %v", err)
|
||||
}
|
||||
model.DB = db
|
||||
model.LOG_DB = db
|
||||
|
||||
if err := db.AutoMigrate(&model.Token{}); err != nil {
|
||||
t.Fatalf("failed to migrate token table: %v", err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
sqlDB, err := db.DB()
|
||||
if err == nil {
|
||||
_ = sqlDB.Close()
|
||||
}
|
||||
})
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func seedToken(t *testing.T, db *gorm.DB, userID int, name string, rawKey string) *model.Token {
|
||||
t.Helper()
|
||||
|
||||
token := &model.Token{
|
||||
UserId: userID,
|
||||
Name: name,
|
||||
Key: rawKey,
|
||||
Status: common.TokenStatusEnabled,
|
||||
CreatedTime: 1,
|
||||
AccessedTime: 1,
|
||||
ExpiredTime: -1,
|
||||
RemainQuota: 100,
|
||||
UnlimitedQuota: true,
|
||||
Group: "default",
|
||||
}
|
||||
if err := db.Create(token).Error; err != nil {
|
||||
t.Fatalf("failed to create token: %v", err)
|
||||
}
|
||||
return token
|
||||
}
|
||||
|
||||
func newAuthenticatedContext(t *testing.T, method string, target string, body any, userID int) (*gin.Context, *httptest.ResponseRecorder) {
|
||||
t.Helper()
|
||||
|
||||
var requestBody *bytes.Reader
|
||||
if body != nil {
|
||||
payload, err := common.Marshal(body)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal request body: %v", err)
|
||||
}
|
||||
requestBody = bytes.NewReader(payload)
|
||||
} else {
|
||||
requestBody = bytes.NewReader(nil)
|
||||
}
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(recorder)
|
||||
ctx.Request = httptest.NewRequest(method, target, requestBody)
|
||||
if body != nil {
|
||||
ctx.Request.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
ctx.Set("id", userID)
|
||||
return ctx, recorder
|
||||
}
|
||||
|
||||
func decodeAPIResponse(t *testing.T, recorder *httptest.ResponseRecorder) tokenAPIResponse {
|
||||
t.Helper()
|
||||
|
||||
var response tokenAPIResponse
|
||||
if err := common.Unmarshal(recorder.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("failed to decode api response: %v", err)
|
||||
}
|
||||
return response
|
||||
}
|
||||
|
||||
func TestGetAllTokensMasksKeyInResponse(t *testing.T) {
|
||||
db := setupTokenControllerTestDB(t)
|
||||
token := seedToken(t, db, 1, "list-token", "abcd1234efgh5678")
|
||||
seedToken(t, db, 2, "other-user-token", "zzzz1234yyyy5678")
|
||||
|
||||
ctx, recorder := newAuthenticatedContext(t, http.MethodGet, "/api/token/?p=1&size=10", nil, 1)
|
||||
GetAllTokens(ctx)
|
||||
|
||||
response := decodeAPIResponse(t, recorder)
|
||||
if !response.Success {
|
||||
t.Fatalf("expected success response, got message: %s", response.Message)
|
||||
}
|
||||
|
||||
var page tokenPageResponse
|
||||
if err := common.Unmarshal(response.Data, &page); err != nil {
|
||||
t.Fatalf("failed to decode token page response: %v", err)
|
||||
}
|
||||
if len(page.Items) != 1 {
|
||||
t.Fatalf("expected exactly one token, got %d", len(page.Items))
|
||||
}
|
||||
if page.Items[0].Key != token.GetMaskedKey() {
|
||||
t.Fatalf("expected masked key %q, got %q", token.GetMaskedKey(), page.Items[0].Key)
|
||||
}
|
||||
if strings.Contains(recorder.Body.String(), token.Key) {
|
||||
t.Fatalf("list response leaked raw token key: %s", recorder.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchTokensMasksKeyInResponse(t *testing.T) {
|
||||
db := setupTokenControllerTestDB(t)
|
||||
token := seedToken(t, db, 1, "searchable-token", "ijkl1234mnop5678")
|
||||
|
||||
ctx, recorder := newAuthenticatedContext(t, http.MethodGet, "/api/token/search?keyword=searchable-token&p=1&size=10", nil, 1)
|
||||
SearchTokens(ctx)
|
||||
|
||||
response := decodeAPIResponse(t, recorder)
|
||||
if !response.Success {
|
||||
t.Fatalf("expected success response, got message: %s", response.Message)
|
||||
}
|
||||
|
||||
var page tokenPageResponse
|
||||
if err := common.Unmarshal(response.Data, &page); err != nil {
|
||||
t.Fatalf("failed to decode search response: %v", err)
|
||||
}
|
||||
if len(page.Items) != 1 {
|
||||
t.Fatalf("expected exactly one search result, got %d", len(page.Items))
|
||||
}
|
||||
if page.Items[0].Key != token.GetMaskedKey() {
|
||||
t.Fatalf("expected masked search key %q, got %q", token.GetMaskedKey(), page.Items[0].Key)
|
||||
}
|
||||
if strings.Contains(recorder.Body.String(), token.Key) {
|
||||
t.Fatalf("search response leaked raw token key: %s", recorder.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTokenMasksKeyInResponse(t *testing.T) {
|
||||
db := setupTokenControllerTestDB(t)
|
||||
token := seedToken(t, db, 1, "detail-token", "qrst1234uvwx5678")
|
||||
|
||||
ctx, recorder := newAuthenticatedContext(t, http.MethodGet, "/api/token/"+strconv.Itoa(token.Id), nil, 1)
|
||||
ctx.Params = gin.Params{{Key: "id", Value: strconv.Itoa(token.Id)}}
|
||||
GetToken(ctx)
|
||||
|
||||
response := decodeAPIResponse(t, recorder)
|
||||
if !response.Success {
|
||||
t.Fatalf("expected success response, got message: %s", response.Message)
|
||||
}
|
||||
|
||||
var detail tokenResponseItem
|
||||
if err := common.Unmarshal(response.Data, &detail); err != nil {
|
||||
t.Fatalf("failed to decode token detail response: %v", err)
|
||||
}
|
||||
if detail.Key != token.GetMaskedKey() {
|
||||
t.Fatalf("expected masked detail key %q, got %q", token.GetMaskedKey(), detail.Key)
|
||||
}
|
||||
if strings.Contains(recorder.Body.String(), token.Key) {
|
||||
t.Fatalf("detail response leaked raw token key: %s", recorder.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateTokenMasksKeyInResponse(t *testing.T) {
|
||||
db := setupTokenControllerTestDB(t)
|
||||
token := seedToken(t, db, 1, "editable-token", "yzab1234cdef5678")
|
||||
|
||||
body := map[string]any{
|
||||
"id": token.Id,
|
||||
"name": "updated-token",
|
||||
"expired_time": -1,
|
||||
"remain_quota": 100,
|
||||
"unlimited_quota": true,
|
||||
"model_limits_enabled": false,
|
||||
"model_limits": "",
|
||||
"group": "default",
|
||||
"cross_group_retry": false,
|
||||
}
|
||||
|
||||
ctx, recorder := newAuthenticatedContext(t, http.MethodPut, "/api/token/", body, 1)
|
||||
UpdateToken(ctx)
|
||||
|
||||
response := decodeAPIResponse(t, recorder)
|
||||
if !response.Success {
|
||||
t.Fatalf("expected success response, got message: %s", response.Message)
|
||||
}
|
||||
|
||||
var detail tokenResponseItem
|
||||
if err := common.Unmarshal(response.Data, &detail); err != nil {
|
||||
t.Fatalf("failed to decode token update response: %v", err)
|
||||
}
|
||||
if detail.Key != token.GetMaskedKey() {
|
||||
t.Fatalf("expected masked update key %q, got %q", token.GetMaskedKey(), detail.Key)
|
||||
}
|
||||
if strings.Contains(recorder.Body.String(), token.Key) {
|
||||
t.Fatalf("update response leaked raw token key: %s", recorder.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTokenKeyRequiresOwnershipAndReturnsFullKey(t *testing.T) {
|
||||
db := setupTokenControllerTestDB(t)
|
||||
token := seedToken(t, db, 1, "owned-token", "owner1234token5678")
|
||||
|
||||
authorizedCtx, authorizedRecorder := newAuthenticatedContext(t, http.MethodPost, "/api/token/"+strconv.Itoa(token.Id)+"/key", nil, 1)
|
||||
authorizedCtx.Params = gin.Params{{Key: "id", Value: strconv.Itoa(token.Id)}}
|
||||
GetTokenKey(authorizedCtx)
|
||||
|
||||
authorizedResponse := decodeAPIResponse(t, authorizedRecorder)
|
||||
if !authorizedResponse.Success {
|
||||
t.Fatalf("expected authorized key fetch to succeed, got message: %s", authorizedResponse.Message)
|
||||
}
|
||||
|
||||
var keyData tokenKeyResponse
|
||||
if err := common.Unmarshal(authorizedResponse.Data, &keyData); err != nil {
|
||||
t.Fatalf("failed to decode token key response: %v", err)
|
||||
}
|
||||
if keyData.Key != token.GetFullKey() {
|
||||
t.Fatalf("expected full key %q, got %q", token.GetFullKey(), keyData.Key)
|
||||
}
|
||||
|
||||
unauthorizedCtx, unauthorizedRecorder := newAuthenticatedContext(t, http.MethodPost, "/api/token/"+strconv.Itoa(token.Id)+"/key", nil, 2)
|
||||
unauthorizedCtx.Params = gin.Params{{Key: "id", Value: strconv.Itoa(token.Id)}}
|
||||
GetTokenKey(unauthorizedCtx)
|
||||
|
||||
unauthorizedResponse := decodeAPIResponse(t, unauthorizedRecorder)
|
||||
if unauthorizedResponse.Success {
|
||||
t.Fatalf("expected unauthorized key fetch to fail")
|
||||
}
|
||||
if strings.Contains(unauthorizedRecorder.Body.String(), token.Key) {
|
||||
t.Fatalf("unauthorized key response leaked raw token key: %s", unauthorizedRecorder.Body.String())
|
||||
}
|
||||
}
|
||||
+68
-14
@@ -48,14 +48,52 @@ func GetTopUpInfo(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// 如果启用了 Waffo 支付,添加到支付方法列表
|
||||
enableWaffo := setting.WaffoEnabled &&
|
||||
((!setting.WaffoSandbox &&
|
||||
setting.WaffoApiKey != "" &&
|
||||
setting.WaffoPrivateKey != "" &&
|
||||
setting.WaffoPublicCert != "") ||
|
||||
(setting.WaffoSandbox &&
|
||||
setting.WaffoSandboxApiKey != "" &&
|
||||
setting.WaffoSandboxPrivateKey != "" &&
|
||||
setting.WaffoSandboxPublicCert != ""))
|
||||
if enableWaffo {
|
||||
hasWaffo := false
|
||||
for _, method := range payMethods {
|
||||
if method["type"] == "waffo" {
|
||||
hasWaffo = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasWaffo {
|
||||
waffoMethod := map[string]string{
|
||||
"name": "Waffo (Global Payment)",
|
||||
"type": "waffo",
|
||||
"color": "rgba(var(--semi-blue-5), 1)",
|
||||
"min_topup": strconv.Itoa(setting.WaffoMinTopUp),
|
||||
}
|
||||
payMethods = append(payMethods, waffoMethod)
|
||||
}
|
||||
}
|
||||
|
||||
data := gin.H{
|
||||
"enable_online_topup": operation_setting.PayAddress != "" && operation_setting.EpayId != "" && operation_setting.EpayKey != "",
|
||||
"enable_stripe_topup": setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "",
|
||||
"enable_creem_topup": setting.CreemApiKey != "" && setting.CreemProducts != "[]",
|
||||
"creem_products": setting.CreemProducts,
|
||||
"enable_waffo_topup": enableWaffo,
|
||||
"waffo_pay_methods": func() interface{} {
|
||||
if enableWaffo {
|
||||
return setting.GetWaffoPayMethods()
|
||||
}
|
||||
return nil
|
||||
}(),
|
||||
"creem_products": setting.CreemProducts,
|
||||
"pay_methods": payMethods,
|
||||
"min_topup": operation_setting.MinTopUp,
|
||||
"stripe_min_topup": setting.StripeMinTopUp,
|
||||
"waffo_min_topup": setting.WaffoMinTopUp,
|
||||
"amount_options": operation_setting.GetPaymentSetting().AmountOptions,
|
||||
"discount": operation_setting.GetPaymentSetting().AmountDiscount,
|
||||
}
|
||||
@@ -204,27 +242,42 @@ func RequestEpay(c *gin.Context) {
|
||||
// orderLocks maps a trade number to the *refCountedMutex that currently
// guards it. Entries exist only while at least one goroutine holds or waits
// for the lock of that trade number.
var orderLocks sync.Map

// createLock serializes creation, reference counting and removal of entries
// in orderLocks, so an entry is never deleted while another goroutine is
// about to lock it.
var createLock sync.Mutex

// refCountedMutex is a mutex with a reference count. The count lets the last
// user remove the entry from orderLocks instead of leaking one mutex per
// trade number.
type refCountedMutex struct {
	mu       sync.Mutex
	refCount int // holders plus waiters; only touched while createLock is held
}

// LockOrder blocks until the caller holds the lock for tradeNo.
// Every call must be paired with exactly one UnlockOrder for the same tradeNo.
func LockOrder(tradeNo string) {
	createLock.Lock()
	var rcm *refCountedMutex
	if v, ok := orderLocks.Load(tradeNo); ok {
		rcm = v.(*refCountedMutex)
	} else {
		rcm = &refCountedMutex{}
		orderLocks.Store(tradeNo, rcm)
	}
	// Register interest before releasing createLock so a concurrent
	// UnlockOrder cannot drop the entry we are about to wait on.
	rcm.refCount++
	createLock.Unlock()

	// Acquire the per-order mutex outside createLock; waiting here must not
	// block callers working on other trade numbers.
	rcm.mu.Lock()
}

// UnlockOrder releases the lock for tradeNo taken by LockOrder and removes
// the bookkeeping entry once no other goroutine holds or waits for it.
// It is a no-op when no entry exists for tradeNo.
func UnlockOrder(tradeNo string) {
	v, ok := orderLocks.Load(tradeNo)
	if !ok {
		return
	}
	rcm := v.(*refCountedMutex)
	rcm.mu.Unlock()

	createLock.Lock()
	rcm.refCount--
	if rcm.refCount == 0 {
		orderLocks.Delete(tradeNo)
	}
	createLock.Unlock()
}
|
||||
|
||||
func EpayNotify(c *gin.Context) {
|
||||
@@ -410,3 +463,4 @@ func AdminCompleteTopUp(c *gin.Context) {
|
||||
}
|
||||
common.ApiSuccess(c, nil)
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,380 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/setting"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
"github.com/QuantumNous/new-api/setting/system_setting"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/thanhpk/randstr"
|
||||
waffo "github.com/waffo-com/waffo-go"
|
||||
"github.com/waffo-com/waffo-go/config"
|
||||
"github.com/waffo-com/waffo-go/core"
|
||||
"github.com/waffo-com/waffo-go/types/order"
|
||||
)
|
||||
|
||||
func getWaffoSDK() (*waffo.Waffo, error) {
|
||||
env := config.Sandbox
|
||||
apiKey := setting.WaffoSandboxApiKey
|
||||
privateKey := setting.WaffoSandboxPrivateKey
|
||||
publicKey := setting.WaffoSandboxPublicCert
|
||||
if !setting.WaffoSandbox {
|
||||
env = config.Production
|
||||
apiKey = setting.WaffoApiKey
|
||||
privateKey = setting.WaffoPrivateKey
|
||||
publicKey = setting.WaffoPublicCert
|
||||
}
|
||||
builder := config.NewConfigBuilder().
|
||||
APIKey(apiKey).
|
||||
PrivateKey(privateKey).
|
||||
WaffoPublicKey(publicKey).
|
||||
Environment(env)
|
||||
if setting.WaffoMerchantId != "" {
|
||||
builder = builder.MerchantID(setting.WaffoMerchantId)
|
||||
}
|
||||
cfg, err := builder.Build()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return waffo.New(cfg), nil
|
||||
}
|
||||
|
||||
func getWaffoUserEmail(user *model.User) string {
|
||||
return fmt.Sprintf("%d@examples.com", user.Id)
|
||||
}
|
||||
|
||||
func getWaffoCurrency() string {
|
||||
if setting.WaffoCurrency != "" {
|
||||
return setting.WaffoCurrency
|
||||
}
|
||||
return "USD"
|
||||
}
|
||||
|
||||
// zeroDecimalCurrencies lists the currencies that have no minor unit; amounts
// in these currencies must be formatted without a decimal point.
var zeroDecimalCurrencies = map[string]bool{
	"IDR": true,
	"JPY": true,
	"KRW": true,
	"VND": true,
}

// formatWaffoAmount renders amount as a decimal string for the given currency:
// no fractional digits for zero-decimal currencies, two otherwise.
// The currency lookup is case-sensitive.
func formatWaffoAmount(amount float64, currency string) string {
	decimals := 2
	if zeroDecimalCurrencies[currency] {
		decimals = 0
	}
	return strconv.FormatFloat(amount, 'f', decimals, 64)
}
|
||||
|
||||
// getWaffoPayMoney converts the user-facing amount to USD for Waffo payment.
|
||||
// Waffo only accepts USD, so this function handles the conversion from different
|
||||
// display types (USD/CNY/TOKENS) to the actual USD amount to charge.
|
||||
func getWaffoPayMoney(amount float64, group string) float64 {
|
||||
originalAmount := amount
|
||||
if operation_setting.GetQuotaDisplayType() == operation_setting.QuotaDisplayTypeTokens {
|
||||
amount = amount / common.QuotaPerUnit
|
||||
}
|
||||
topupGroupRatio := common.GetTopupGroupRatio(group)
|
||||
if topupGroupRatio == 0 {
|
||||
topupGroupRatio = 1
|
||||
}
|
||||
discount := 1.0
|
||||
if ds, ok := operation_setting.GetPaymentSetting().AmountDiscount[int(originalAmount)]; ok {
|
||||
if ds > 0 {
|
||||
discount = ds
|
||||
}
|
||||
}
|
||||
return amount * setting.WaffoUnitPrice * topupGroupRatio * discount
|
||||
}
|
||||
|
||||
// WaffoPayRequest is the JSON body accepted by RequestWaffoPay.
type WaffoPayRequest struct {
	// Amount is the requested top-up quantity.
	Amount int64 `json:"amount"`
	// PayMethodIndex is an index into the server-side pay method list;
	// nil lets Waffo choose the payment method automatically.
	PayMethodIndex *int `json:"pay_method_index"`
	// Deprecated: kept for older front ends; prefer pay_method_index.
	PayMethodType string `json:"pay_method_type"`
	// Deprecated: kept for older front ends; prefer pay_method_index.
	PayMethodName string `json:"pay_method_name"`
}
|
||||
|
||||
// RequestWaffoPay 创建 Waffo 支付订单
|
||||
func RequestWaffoPay(c *gin.Context) {
|
||||
if !setting.WaffoEnabled {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "Waffo 支付未启用"})
|
||||
return
|
||||
}
|
||||
|
||||
var req WaffoPayRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
|
||||
return
|
||||
}
|
||||
waffoMinTopup := int64(setting.WaffoMinTopUp)
|
||||
if req.Amount < waffoMinTopup {
|
||||
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", waffoMinTopup)})
|
||||
return
|
||||
}
|
||||
|
||||
id := c.GetInt("id")
|
||||
user, err := model.GetUserById(id, false)
|
||||
if err != nil || user == nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "用户不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
// 从服务端配置查找支付方式,客户端只传索引或旧字段
|
||||
var resolvedPayMethodType, resolvedPayMethodName string
|
||||
methods := setting.GetWaffoPayMethods()
|
||||
if req.PayMethodIndex != nil {
|
||||
// 新协议:按索引查找
|
||||
idx := *req.PayMethodIndex
|
||||
if idx < 0 || idx >= len(methods) {
|
||||
log.Printf("Waffo 无效的支付方式索引: %d, UserId=%d, 可用范围: [0, %d)", idx, id, len(methods))
|
||||
c.JSON(200, gin.H{"message": "error", "data": "不支持的支付方式"})
|
||||
return
|
||||
}
|
||||
resolvedPayMethodType = methods[idx].PayMethodType
|
||||
resolvedPayMethodName = methods[idx].PayMethodName
|
||||
} else if req.PayMethodType != "" {
|
||||
// 兼容旧前端:验证客户端传的值在服务端列表中
|
||||
valid := false
|
||||
for _, m := range methods {
|
||||
if m.PayMethodType == req.PayMethodType && m.PayMethodName == req.PayMethodName {
|
||||
valid = true
|
||||
resolvedPayMethodType = m.PayMethodType
|
||||
resolvedPayMethodName = m.PayMethodName
|
||||
break
|
||||
}
|
||||
}
|
||||
if !valid {
|
||||
log.Printf("Waffo 无效的支付方式: PayMethodType=%s, PayMethodName=%s, UserId=%d", req.PayMethodType, req.PayMethodName, id)
|
||||
c.JSON(200, gin.H{"message": "error", "data": "不支持的支付方式"})
|
||||
return
|
||||
}
|
||||
}
|
||||
// resolvedPayMethodType/Name 为空时,Waffo 自动选择支付方式
|
||||
|
||||
group, _ := model.GetUserGroup(id, true)
|
||||
payMoney := getWaffoPayMoney(float64(req.Amount), group)
|
||||
if payMoney < 0.01 {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
|
||||
return
|
||||
}
|
||||
|
||||
// 生成唯一订单号,paymentRequestId 与 merchantOrderId 保持一致,简化追踪
|
||||
merchantOrderId := fmt.Sprintf("WAFFO-%d-%d-%s", id, time.Now().UnixMilli(), randstr.String(6))
|
||||
paymentRequestId := merchantOrderId
|
||||
|
||||
// Token 模式下归一化 Amount(存等价美元/CNY 数量,避免 RechargeWaffo 双重放大)
|
||||
amount := req.Amount
|
||||
if operation_setting.GetQuotaDisplayType() == operation_setting.QuotaDisplayTypeTokens {
|
||||
amount = int64(float64(req.Amount) / common.QuotaPerUnit)
|
||||
if amount < 1 {
|
||||
amount = 1
|
||||
}
|
||||
}
|
||||
|
||||
// 创建本地订单
|
||||
topUp := &model.TopUp{
|
||||
UserId: id,
|
||||
Amount: amount,
|
||||
Money: payMoney,
|
||||
TradeNo: merchantOrderId,
|
||||
PaymentMethod: "waffo",
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
if err := topUp.Insert(); err != nil {
|
||||
log.Printf("Waffo 创建本地订单失败: %v", err)
|
||||
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
|
||||
return
|
||||
}
|
||||
|
||||
sdk, err := getWaffoSDK()
|
||||
if err != nil {
|
||||
log.Printf("Waffo SDK 初始化失败: %v", err)
|
||||
topUp.Status = common.TopUpStatusFailed
|
||||
_ = topUp.Update()
|
||||
c.JSON(200, gin.H{"message": "error", "data": "支付配置错误"})
|
||||
return
|
||||
}
|
||||
|
||||
callbackAddr := service.GetCallbackAddress()
|
||||
notifyUrl := callbackAddr + "/api/waffo/webhook"
|
||||
if setting.WaffoNotifyUrl != "" {
|
||||
notifyUrl = setting.WaffoNotifyUrl
|
||||
}
|
||||
returnUrl := system_setting.ServerAddress + "/console/topup?show_history=true"
|
||||
if setting.WaffoReturnUrl != "" {
|
||||
returnUrl = setting.WaffoReturnUrl
|
||||
}
|
||||
|
||||
currency := getWaffoCurrency()
|
||||
createParams := &order.CreateOrderParams{
|
||||
PaymentRequestID: paymentRequestId,
|
||||
MerchantOrderID: merchantOrderId,
|
||||
OrderAmount: formatWaffoAmount(payMoney, currency),
|
||||
OrderCurrency: currency,
|
||||
OrderDescription: fmt.Sprintf("Recharge %d credits", req.Amount),
|
||||
OrderRequestedAt: time.Now().UTC().Format("2006-01-02T15:04:05.000Z"),
|
||||
NotifyURL: notifyUrl,
|
||||
MerchantInfo: &order.MerchantInfo{
|
||||
MerchantID: setting.WaffoMerchantId,
|
||||
},
|
||||
UserInfo: &order.UserInfo{
|
||||
UserID: strconv.Itoa(user.Id),
|
||||
UserEmail: getWaffoUserEmail(user),
|
||||
UserTerminal: "WEB",
|
||||
},
|
||||
PaymentInfo: &order.PaymentInfo{
|
||||
ProductName: "ONE_TIME_PAYMENT",
|
||||
PayMethodType: resolvedPayMethodType,
|
||||
PayMethodName: resolvedPayMethodName,
|
||||
},
|
||||
SuccessRedirectURL: returnUrl,
|
||||
FailedRedirectURL: returnUrl,
|
||||
}
|
||||
resp, err := sdk.Order().Create(c.Request.Context(), createParams, nil)
|
||||
if err != nil {
|
||||
log.Printf("Waffo 创建订单失败: %v", err)
|
||||
topUp.Status = common.TopUpStatusFailed
|
||||
_ = topUp.Update()
|
||||
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
|
||||
return
|
||||
}
|
||||
if !resp.IsSuccess() {
|
||||
log.Printf("Waffo 创建订单业务失败: [%s] %s, 完整响应: %+v", resp.Code, resp.Message, resp)
|
||||
topUp.Status = common.TopUpStatusFailed
|
||||
_ = topUp.Update()
|
||||
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
|
||||
return
|
||||
}
|
||||
|
||||
orderData := resp.GetData()
|
||||
log.Printf("Waffo 订单创建成功 - 用户: %d, 订单: %s, 金额: %.2f", id, merchantOrderId, payMoney)
|
||||
|
||||
paymentUrl := orderData.FetchRedirectURL()
|
||||
if paymentUrl == "" {
|
||||
paymentUrl = orderData.OrderAction
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"message": "success",
|
||||
"data": gin.H{
|
||||
"payment_url": paymentUrl,
|
||||
"order_id": merchantOrderId,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// webhookPayloadWithSubInfo 扩展 PAYMENT_NOTIFICATION,包含 SDK 未定义的 subscriptionInfo 字段
|
||||
type webhookPayloadWithSubInfo struct {
|
||||
EventType string `json:"eventType"`
|
||||
Result struct {
|
||||
core.PaymentNotificationResult
|
||||
SubscriptionInfo *webhookSubscriptionInfo `json:"subscriptionInfo,omitempty"`
|
||||
} `json:"result"`
|
||||
}
|
||||
|
||||
// webhookSubscriptionInfo mirrors the "subscriptionInfo" object of a Waffo
// payment notification. Every field is optional.
type webhookSubscriptionInfo struct {
	Period              string `json:"period,omitempty"`
	MerchantRequest     string `json:"merchantRequest,omitempty"`
	SubscriptionID      string `json:"subscriptionId,omitempty"`
	SubscriptionRequest string `json:"subscriptionRequest,omitempty"`
}
|
||||
|
||||
// WaffoWebhook 处理 Waffo 回调通知(支付/退款/订阅)
|
||||
func WaffoWebhook(c *gin.Context) {
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
log.Printf("Waffo Webhook 读取 body 失败: %v", err)
|
||||
c.AbortWithStatus(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
sdk, err := getWaffoSDK()
|
||||
if err != nil {
|
||||
log.Printf("Waffo Webhook SDK 初始化失败: %v", err)
|
||||
c.AbortWithStatus(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
wh := sdk.Webhook()
|
||||
bodyStr := string(bodyBytes)
|
||||
signature := c.GetHeader("X-SIGNATURE")
|
||||
|
||||
// 验证请求签名
|
||||
if !wh.VerifySignature(bodyStr, signature) {
|
||||
log.Printf("Waffo webhook 签名验证失败")
|
||||
c.AbortWithStatus(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var event core.WebhookEvent
|
||||
if err := common.Unmarshal(bodyBytes, &event); err != nil {
|
||||
log.Printf("Waffo Webhook 解析失败: %v", err)
|
||||
sendWaffoWebhookResponse(c, wh, false, "invalid payload")
|
||||
return
|
||||
}
|
||||
|
||||
switch event.EventType {
|
||||
case core.EventPayment:
|
||||
// 解析为扩展类型,区分普通支付和订阅支付
|
||||
var payload webhookPayloadWithSubInfo
|
||||
if err := common.Unmarshal(bodyBytes, &payload); err != nil {
|
||||
sendWaffoWebhookResponse(c, wh, false, "invalid payment payload")
|
||||
return
|
||||
}
|
||||
log.Printf("Waffo Webhook - EventType: %s, MerchantOrderId: %s, OrderStatus: %s",
|
||||
event.EventType, payload.Result.MerchantOrderID, payload.Result.OrderStatus)
|
||||
handleWaffoPayment(c, wh, &payload.Result.PaymentNotificationResult)
|
||||
default:
|
||||
log.Printf("Waffo Webhook 未知事件: %s", event.EventType)
|
||||
sendWaffoWebhookResponse(c, wh, true, "")
|
||||
}
|
||||
}
|
||||
|
||||
// handleWaffoPayment 处理支付完成通知
|
||||
func handleWaffoPayment(c *gin.Context, wh *core.WebhookHandler, result *core.PaymentNotificationResult) {
|
||||
if result.OrderStatus != "PAY_SUCCESS" {
|
||||
log.Printf("Waffo 订单状态非成功: %s, 订单: %s", result.OrderStatus, result.MerchantOrderID)
|
||||
// 终态失败订单标记为 failed,避免永远停在 pending
|
||||
if result.MerchantOrderID != "" {
|
||||
if topUp := model.GetTopUpByTradeNo(result.MerchantOrderID); topUp != nil &&
|
||||
topUp.Status == common.TopUpStatusPending {
|
||||
topUp.Status = common.TopUpStatusFailed
|
||||
_ = topUp.Update()
|
||||
}
|
||||
}
|
||||
sendWaffoWebhookResponse(c, wh, true, "")
|
||||
return
|
||||
}
|
||||
|
||||
merchantOrderId := result.MerchantOrderID
|
||||
|
||||
LockOrder(merchantOrderId)
|
||||
defer UnlockOrder(merchantOrderId)
|
||||
|
||||
if err := model.RechargeWaffo(merchantOrderId); err != nil {
|
||||
log.Printf("Waffo 充值处理失败: %v, 订单: %s", err, merchantOrderId)
|
||||
sendWaffoWebhookResponse(c, wh, false, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Waffo 充值成功 - 订单: %s", merchantOrderId)
|
||||
sendWaffoWebhookResponse(c, wh, true, "")
|
||||
}
|
||||
|
||||
// sendWaffoWebhookResponse 发送签名响应
|
||||
func sendWaffoWebhookResponse(c *gin.Context, wh *core.WebhookHandler, success bool, msg string) {
|
||||
var body, sig string
|
||||
if success {
|
||||
body, sig = wh.BuildSuccessResponse()
|
||||
} else {
|
||||
body, sig = wh.BuildFailedResponse(msg)
|
||||
}
|
||||
c.Header("X-SIGNATURE", sig)
|
||||
c.Data(http.StatusOK, "application/json", []byte(body))
|
||||
}
|
||||
@@ -35,7 +35,8 @@ func VideoProxy(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
task, exists, err := model.GetByOnlyTaskId(taskID)
|
||||
userID := c.GetInt("id")
|
||||
task, exists, err := model.GetByTaskId(userID, taskID)
|
||||
if err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to query task %s: %s", taskID, err.Error()))
|
||||
videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to query task")
|
||||
|
||||
@@ -43,6 +43,8 @@ services:
|
||||
- redis
|
||||
- postgres
|
||||
# - mysql # Uncomment if using MySQL
|
||||
networks:
|
||||
- new-api-network
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' || exit 1"]
|
||||
interval: 30s
|
||||
@@ -53,6 +55,8 @@ services:
|
||||
image: redis:latest
|
||||
container_name: redis
|
||||
restart: always
|
||||
networks:
|
||||
- new-api-network
|
||||
|
||||
postgres:
|
||||
image: postgres:15
|
||||
@@ -64,6 +68,8 @@ services:
|
||||
POSTGRES_DB: new-api
|
||||
volumes:
|
||||
- pg_data:/var/lib/postgresql/data
|
||||
networks:
|
||||
- new-api-network
|
||||
# ports:
|
||||
# - "5432:5432" # Uncomment if you need to access PostgreSQL from outside Docker
|
||||
|
||||
@@ -76,9 +82,15 @@ services:
|
||||
# MYSQL_DATABASE: new-api
|
||||
# volumes:
|
||||
# - mysql_data:/var/lib/mysql
|
||||
# networks:
|
||||
# - new-api-network
|
||||
# ports:
|
||||
# - "3306:3306" # Uncomment if you need to access MySQL from outside Docker
|
||||
|
||||
volumes:
|
||||
pg_data:
|
||||
# mysql_data:
|
||||
|
||||
networks:
|
||||
new-api-network:
|
||||
driver: bridge
|
||||
|
||||
+150
-2
@@ -1,3 +1,151 @@
|
||||
密钥为环境变量SESSION_SECRET
|
||||
# 宝塔面板部署教程
|
||||
|
||||
本文档提供使用宝塔面板 Docker 功能部署 New API 的图文教程。
|
||||
|
||||
> 📖 官方文档:[宝塔面板部署](https://docs.newapi.pro/zh/docs/installation/deployment-methods/bt-docker-installation)
|
||||
|
||||
***
|
||||
|
||||
## 前置要求
|
||||
|
||||
| 项目 | 要求 |
|
||||
| ----- | ---------------------------------- |
|
||||
| 宝塔面板 | ≥ 9.2.0 版本 |
|
||||
| 推荐系统 | CentOS 7+、Ubuntu 18.04+、Debian 10+ |
|
||||
| 服务器配置 | 至少 1 核 2G 内存 |
|
||||
|
||||
***
|
||||
|
||||
## 步骤一:安装宝塔面板
|
||||
|
||||
1. 前往 [宝塔面板官网](https://www.bt.cn/new/download.html) 下载适合您系统的安装脚本
|
||||
2. 运行安装脚本安装宝塔面板
|
||||
3. 安装完成后,使用提供的地址、用户名和密码登录宝塔面板
|
||||
|
||||
***
|
||||
|
||||
## 步骤二:安装 Docker
|
||||
|
||||
1. 登录宝塔面板后,在左侧菜单栏找到并点击 **Docker**
|
||||
2. 首次进入会提示安装 Docker 服务,点击 **立即安装**
|
||||
3. 按照提示完成 Docker 服务的安装
|
||||
|
||||
***
|
||||
|
||||
## 步骤三:安装 New API
|
||||
|
||||
### 方法一:使用宝塔应用商店(推荐)
|
||||
|
||||
1. 在宝塔面板 Docker 功能中,点击 **应用商店**
|
||||
2. 搜索并找到 **New-API**
|
||||
3. 点击 **安装**
|
||||
4. 配置以下基本选项:
|
||||
- **容器名称**:可自定义,默认为 `new-api`
|
||||
- **端口映射**:默认为 `3000:3000`
|
||||
- **环境变量**:
|
||||
- `SESSION_SECRET`:会话密钥(**必填**,多机部署时必须一致)
|
||||
- `CRYPTO_SECRET`:加密密钥(使用 Redis 时必填)
|
||||
5. 点击 **确认** 开始安装
|
||||
6. 等待安装完成后,访问 `http://您的服务器IP:3000` 即可使用
|
||||
|
||||
### 方法二:使用 Docker Compose
|
||||
|
||||
1. 在宝塔面板中创建网站目录,如 `/www/wwwroot/new-api`
|
||||
2. 创建 `docker-compose.yml` 文件:
|
||||
|
||||
```yaml
|
||||
version: '3'
|
||||
services:
|
||||
new-api:
|
||||
image: calciumion/new-api:latest
|
||||
container_name: new-api
|
||||
restart: always
|
||||
ports:
|
||||
- "3000:3000"
|
||||
volumes:
|
||||
- ./data:/data
|
||||
environment:
|
||||
- SESSION_SECRET=your_session_secret_here # 请修改为随机字符串
|
||||
- TZ=Asia/Shanghai
|
||||
```
|
||||
|
||||
3. 在终端中进入目录并启动:
|
||||
|
||||
```bash
|
||||
cd /www/wwwroot/new-api
|
||||
docker-compose up -d
|
||||
```
|
||||
|
||||
***
|
||||
|
||||
## 配置说明
|
||||
|
||||
### 必要环境变量
|
||||
|
||||
| 变量名 | 说明 | 是否必填 |
|
||||
| ------------------- | ------------------ | ------ |
|
||||
| `SESSION_SECRET` | 会话密钥,多机部署必须一致 | **必填** |
|
||||
| `CRYPTO_SECRET` | 加密密钥,使用 Redis 时必填 | 条件必填 |
|
||||
| `SQL_DSN` | 数据库连接字符串(使用外部数据库时) | 可选 |
|
||||
| `REDIS_CONN_STRING` | Redis 连接字符串 | 可选 |
|
||||
|
||||
### 生成随机密钥
|
||||
|
||||
```bash
|
||||
# 生成 SESSION_SECRET
|
||||
openssl rand -hex 16
|
||||
|
||||
# 或使用 Linux 命令
|
||||
head -c 16 /dev/urandom | xxd -p
|
||||
```
|
||||
|
||||
***
|
||||
|
||||
## 常见问题
|
||||
|
||||
### Q1:无法访问 3000 端口?
|
||||
|
||||
1. 检查服务器防火墙是否开放 3000 端口
|
||||
2. 在宝塔面板 **安全** 中放行 3000 端口
|
||||
3. 检查云服务器安全组是否开放端口
|
||||
|
||||
### Q2:登录后提示会话失效?
|
||||
|
||||
确保设置了 `SESSION_SECRET` 环境变量,且值不为空。
|
||||
|
||||
### Q3:数据如何持久化?
|
||||
|
||||
使用 Docker 卷映射数据目录:
|
||||
|
||||
```yaml
|
||||
volumes:
|
||||
- ./data:/data
|
||||
```
|
||||
|
||||
### Q4:如何更新版本?
|
||||
|
||||
```bash
|
||||
# 拉取最新镜像
|
||||
docker pull calciumion/new-api:latest
|
||||
|
||||
# 重启容器
|
||||
docker-compose down && docker-compose up -d
|
||||
```
|
||||
|
||||
***
|
||||
|
||||
## 相关链接
|
||||
|
||||
- [官方文档](https://docs.newapi.pro/zh/docs/installation)
|
||||
- [环境变量配置](https://docs.newapi.pro/zh/docs/installation/config-maintenance/environment-variables)
|
||||
- [常见问题](https://docs.newapi.pro/zh/docs/support/faq)
|
||||
- [GitHub 仓库](https://github.com/QuantumNous/new-api)
|
||||
|
||||
***
|
||||
|
||||
## 截图示例
|
||||
|
||||

|
||||
|
||||
> ⚠️ 注意:密钥为环境变量 `SESSION_SECRET`,请务必设置!
|
||||
|
||||

|
||||
|
||||
@@ -56,10 +56,10 @@ type GeneralOpenAIRequest struct {
|
||||
Tools []ToolCallRequest `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
FunctionCall json.RawMessage `json:"function_call,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
User json.RawMessage `json:"user,omitempty"`
|
||||
// ServiceTier specifies upstream service level and may affect billing.
|
||||
// This field is filtered by default and can be enabled via channel setting allow_service_tier.
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
ServiceTier json.RawMessage `json:"service_tier,omitempty"`
|
||||
LogProbs *bool `json:"logprobs,omitempty"`
|
||||
TopLogProbs *int `json:"top_logprobs,omitempty"`
|
||||
Dimensions *int `json:"dimensions,omitempty"`
|
||||
@@ -67,7 +67,7 @@ type GeneralOpenAIRequest struct {
|
||||
Audio json.RawMessage `json:"audio,omitempty"`
|
||||
// 安全标识符,用于帮助 OpenAI 检测可能违反使用政策的应用程序用户
|
||||
// 注意:此字段会向 OpenAI 发送用户标识信息,默认过滤,可通过 allow_safety_identifier 开启
|
||||
SafetyIdentifier string `json:"safety_identifier,omitempty"`
|
||||
SafetyIdentifier json.RawMessage `json:"safety_identifier,omitempty"`
|
||||
// Whether or not to store the output of this chat completion request for use in our model distillation or evals products.
|
||||
// 是否存储此次请求数据供 OpenAI 用于评估和优化产品
|
||||
// 注意:默认允许透传,可通过 disable_store 禁用;禁用后可能导致 Codex 无法正常使用
|
||||
@@ -100,10 +100,10 @@ type GeneralOpenAIRequest struct {
|
||||
THINKING json.RawMessage `json:"thinking,omitempty"`
|
||||
// pplx Params
|
||||
SearchDomainFilter json.RawMessage `json:"search_domain_filter,omitempty"`
|
||||
SearchRecencyFilter string `json:"search_recency_filter,omitempty"`
|
||||
SearchRecencyFilter json.RawMessage `json:"search_recency_filter,omitempty"`
|
||||
ReturnImages *bool `json:"return_images,omitempty"`
|
||||
ReturnRelatedQuestions *bool `json:"return_related_questions,omitempty"`
|
||||
SearchMode string `json:"search_mode,omitempty"`
|
||||
SearchMode json.RawMessage `json:"search_mode,omitempty"`
|
||||
// Minimax
|
||||
ReasoningSplit json.RawMessage `json:"reasoning_split,omitempty"`
|
||||
}
|
||||
@@ -393,7 +393,7 @@ func (m *MediaContent) GetVideoUrl() *MessageVideoUrl {
|
||||
|
||||
type MessageImageUrl struct {
|
||||
Url string `json:"url"`
|
||||
Detail string `json:"detail"`
|
||||
Detail string `json:"detail,omitempty"`
|
||||
MimeType string
|
||||
}
|
||||
|
||||
@@ -836,7 +836,7 @@ type OpenAIResponsesRequest struct {
|
||||
PromptCacheRetention json.RawMessage `json:"prompt_cache_retention,omitempty"`
|
||||
// SafetyIdentifier carries client identity for policy abuse detection.
|
||||
// This field is filtered by default and can be enabled via channel setting allow_safety_identifier.
|
||||
SafetyIdentifier string `json:"safety_identifier,omitempty"`
|
||||
SafetyIdentifier json.RawMessage `json:"safety_identifier,omitempty"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
@@ -844,8 +844,8 @@ type OpenAIResponsesRequest struct {
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
Tools json.RawMessage `json:"tools,omitempty"` // 需要处理的参数很少,MCP 参数太多不确定,所以用 map
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
Truncation string `json:"truncation,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
Truncation json.RawMessage `json:"truncation,omitempty"`
|
||||
User json.RawMessage `json:"user,omitempty"`
|
||||
MaxToolCalls *uint `json:"max_tool_calls,omitempty"`
|
||||
Prompt json.RawMessage `json:"prompt,omitempty"`
|
||||
// qwen
|
||||
|
||||
+11
-9
@@ -220,10 +220,12 @@ type CompletionsStreamResponse struct {
|
||||
}
|
||||
|
||||
type Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
PromptCacheHitTokens int `json:"prompt_cache_hit_tokens,omitempty"`
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
PromptCacheHitTokens int `json:"prompt_cache_hit_tokens,omitempty"`
|
||||
UsageSemantic string `json:"usage_semantic,omitempty"`
|
||||
UsageSource string `json:"usage_source,omitempty"`
|
||||
|
||||
PromptTokensDetails InputTokenDetails `json:"prompt_tokens_details"`
|
||||
CompletionTokenDetails OutputTokenDetails `json:"completion_tokens_details"`
|
||||
@@ -251,7 +253,7 @@ type OpenAIVideoResponse struct {
|
||||
|
||||
type InputTokenDetails struct {
|
||||
CachedTokens int `json:"cached_tokens"`
|
||||
CachedCreationTokens int `json:"-"`
|
||||
CachedCreationTokens int `json:"cached_creation_tokens,omitempty"`
|
||||
TextTokens int `json:"text_tokens"`
|
||||
AudioTokens int `json:"audio_tokens"`
|
||||
ImageTokens int `json:"image_tokens"`
|
||||
@@ -267,7 +269,7 @@ type OpenAIResponsesResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
CreatedAt int `json:"created_at"`
|
||||
Status string `json:"status"`
|
||||
Status json.RawMessage `json:"status"`
|
||||
Error any `json:"error,omitempty"`
|
||||
IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"`
|
||||
Instructions string `json:"instructions"`
|
||||
@@ -275,14 +277,14 @@ type OpenAIResponsesResponse struct {
|
||||
Model string `json:"model"`
|
||||
Output []ResponsesOutput `json:"output"`
|
||||
ParallelToolCalls bool `json:"parallel_tool_calls"`
|
||||
PreviousResponseID string `json:"previous_response_id"`
|
||||
PreviousResponseID json.RawMessage `json:"previous_response_id"`
|
||||
Reasoning *Reasoning `json:"reasoning"`
|
||||
Store bool `json:"store"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
ToolChoice string `json:"tool_choice"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice"`
|
||||
Tools []map[string]any `json:"tools"`
|
||||
TopP float64 `json:"top_p"`
|
||||
Truncation string `json:"truncation"`
|
||||
Truncation json.RawMessage `json:"truncation"`
|
||||
Usage *Usage `json:"usage"`
|
||||
User json.RawMessage `json:"user"`
|
||||
Metadata json.RawMessage `json:"metadata"`
|
||||
|
||||
@@ -46,13 +46,14 @@ require (
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
github.com/tiktoken-go/tokenizer v0.6.2
|
||||
github.com/waffo-com/waffo-go v1.3.1
|
||||
github.com/yapingcat/gomedia v0.0.0-20240906162731-17feea57090c
|
||||
golang.org/x/crypto v0.45.0
|
||||
golang.org/x/image v0.23.0
|
||||
golang.org/x/image v0.38.0
|
||||
golang.org/x/net v0.47.0
|
||||
golang.org/x/sync v0.19.0
|
||||
golang.org/x/sync v0.20.0
|
||||
golang.org/x/sys v0.38.0
|
||||
golang.org/x/text v0.32.0
|
||||
golang.org/x/text v0.35.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
gorm.io/driver/mysql v1.4.3
|
||||
gorm.io/driver/postgres v1.5.2
|
||||
@@ -120,7 +121,6 @@ require (
|
||||
github.com/prometheus/procfs v0.15.1 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/samber/go-singleflightx v0.3.2 // indirect
|
||||
github.com/stretchr/objx v0.5.2 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/tklauser/go-sysconf v0.3.12 // indirect
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk=
|
||||
github.com/BurntSushi/toml v1.6.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
|
||||
github.com/Calcium-Ion/go-epay v0.0.4 h1:C96M7WfRLadcIVscWzwLiYs8etI1wrDmtFMuK2zP22A=
|
||||
github.com/Calcium-Ion/go-epay v0.0.4/go.mod h1:cxo/ZOg8ClvE3VAnCmEzbuyAZINSq7kFEN9oHj5WQ2U=
|
||||
github.com/DmitriyVTitov/size v1.5.0 h1:/PzqxYrOyOUX1BXj6J9OuVRVGe+66VL4D9FlUaW515g=
|
||||
@@ -10,34 +12,18 @@ github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+Kc
|
||||
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI=
|
||||
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI=
|
||||
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6/go.mod h1:pbiaLIeYLUbgMY1kwEAdwO6UKD5ZNwdPGQlwokS9fe8=
|
||||
github.com/aws/aws-sdk-go-v2 v1.37.2 h1:xkW1iMYawzcmYFYEV0UCMxc8gSsjCGEhBXQkdQywVbo=
|
||||
github.com/aws/aws-sdk-go-v2 v1.37.2/go.mod h1:9Q0OoGQoboYIAJyslFyF1f5K1Ryddop8gqMhWx/n4Wg=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.2 h1:LuT2rzqNQsauaGkPK/7813XxcZ3o3yePY0Iy891T2ls=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.2/go.mod h1:IvvlAZQXvTXznUPfRVfryiG1fbzE2NGK6m9u39YQ+S4=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0 h1:6GMWV6CNpA/6fbFHnoAjrv4+LGfyTqZz2LtCHnspgDg=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0/go.mod h1:/mXlTIVG9jbxkqDnr5UQNQxW1HRYxeGklkM9vAFeabg=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 h1:zWFmPmgw4sveAYi1mRqG+E/g0461cJ5M4bJ8/nc6d3Q=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5/go.mod h1:nVUlMLVV8ycXSb7mSkcNu9e3v/1TJq2RTlrPwhYWr5c=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.11 h1:YuIB1dJNf1Re822rriUOTxopaHHvIq0l/pX3fwO+Tzs=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.11/go.mod h1:AQtFPsDH9bI2O+71anW6EKL+NcD7LG3dpKGMV4SShgo=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.10 h1:EEhmEUFCE1Yhl7vDhNOI5OCL/iKMdkkYFTRpZXNw7m8=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.10/go.mod h1:RnnlFCAlxQCkN2Q379B67USkBMu1PipEEiibzYN5UTE=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2 h1:sPiRHLVUIIQcoVZTNwqQcdtjkqkPopyYmIX0M5ElRf4=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2/go.mod h1:ik86P3sgV+Bk7c1tBFCwI3VxMoSEwl4YkRB9xn1s340=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18 h1:F43zk1vemYIqPAwhjTjYIz0irU2EY7sOb/F5eJ3HuyM=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18/go.mod h1:w1jdlZXrGKaJcNoL+Nnrj+k5wlpGXqnNrKoP22HvAug=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2 h1:ZdzDAg075H6stMZtbD2o+PyB933M/f20e9WmCBC17wA=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2/go.mod h1:eE1IIzXG9sdZCB0pNNpMpsYTLl4YdOQD3njiVN1e/E4=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.18 h1:xCeWVjj0ki0l3nruoyP2slHsGArMxeiiaoPN5QZH6YQ=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.18/go.mod h1:r/eLGuGCBw6l36ZRWiw6PaZwPXb6YOj+i/7MizNl5/k=
|
||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.33.0 h1:JzidOz4Hcn2RbP5fvIS1iAP+DcRv5VJtgixbEYDsI5g=
|
||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.33.0/go.mod h1:9A4/PJYlWjvjEzzoOLGQjkLt4bYK9fRWi7uz1GSsAcA=
|
||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.50.0 h1:TDKR8ACRw7G+GFaQlhoy6biu+8q6ZtSddQCy9avMdMI=
|
||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.50.0/go.mod h1:XlhOh5Ax/lesqN4aZCUgj9vVJed5VoXYHHFYGAlJEwU=
|
||||
github.com/aws/smithy-go v1.22.5 h1:P9ATCXPMb2mPjYBgueqJNCA5S9UfktsW0tTxi+a7eqw=
|
||||
github.com/aws/smithy-go v1.22.5/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI=
|
||||
github.com/aws/smithy-go v1.24.1 h1:VbyeNfmYkWoxMVpGUAbQumkODcYmfMRfZ8yQiH30SK0=
|
||||
github.com/aws/smithy-go v1.24.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0=
|
||||
github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
|
||||
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
|
||||
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||
@@ -58,7 +44,6 @@ github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gE
|
||||
github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
@@ -132,12 +117,13 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE=
|
||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
||||
github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
|
||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/go-tpm v0.9.5 h1:ocUmnDebX54dnW+MQWGQRbdaAcJELsa6PqZhJ48KwVU=
|
||||
github.com/google/go-tpm v0.9.5/go.mod h1:h9jEsEECg7gtLis0upRBQU+GhYVH6jMjrFxI8u6bVUY=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
@@ -186,8 +172,6 @@ github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwA
|
||||
github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/klauspost/compress v1.17.8 h1:YcnTYrq7MikUT7k0Yb5eceMmALQPYBW/Xltxn0NAMnU=
|
||||
github.com/klauspost/compress v1.17.8/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
|
||||
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
|
||||
@@ -245,7 +229,6 @@ github.com/pelletier/go-toml/v2 v2.2.1/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
@@ -262,8 +245,9 @@ github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoG
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
|
||||
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
|
||||
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
|
||||
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
|
||||
github.com/samber/go-singleflightx v0.3.2 h1:jXbUU0fvis8Fdv4HGONboX5WdEZcYLoBEcKiE+ITCyQ=
|
||||
github.com/samber/go-singleflightx v0.3.2/go.mod h1:X2BR+oheHIYc73PvxRMlcASg6KYYTQyUYpdVU7t/ux4=
|
||||
github.com/samber/hot v0.11.0 h1:JhV9hk8SmZIqB0To8OyCzPubvszkuoSXWx/7FCEGO+Q=
|
||||
@@ -320,6 +304,8 @@ github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLY
|
||||
github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY=
|
||||
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
||||
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||
github.com/waffo-com/waffo-go v1.3.1 h1:NCYD3oQ59DTJj1bwS5T/659LI4h8PuAIW4Qj/w7fKPw=
|
||||
github.com/waffo-com/waffo-go v1.3.1/go.mod h1:IaXVYq6mmYtrLFFsLxPslNwuIZx0mIadWWjhe+eWb0g=
|
||||
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
|
||||
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
|
||||
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
|
||||
@@ -330,6 +316,8 @@ github.com/yusufpapurcu/wmi v1.2.3 h1:E1ctvB7uKFMOJw3fdOW32DwGE9I7t++CRUEMKvFoFi
|
||||
github.com/yusufpapurcu/wmi v1.2.3/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
|
||||
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
|
||||
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||
golang.org/x/arch v0.21.0 h1:iTC9o7+wP6cPWpDWkivCvQFGAHDQ59SrSxsLPcnkArw=
|
||||
golang.org/x/arch v0.21.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
|
||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
@@ -337,18 +325,16 @@ golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o=
|
||||
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8=
|
||||
golang.org/x/image v0.23.0 h1:HseQ7c2OpPKTPVzNjG5fwJsOTCiiwS4QdsYi5XU6H68=
|
||||
golang.org/x/image v0.23.0/go.mod h1:wJJBTdLfCCf3tiHa1fNxpZmUI4mmoZvwMCPP0ddoNKY=
|
||||
golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA=
|
||||
golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w=
|
||||
golang.org/x/image v0.38.0 h1:5l+q+Y9JDC7mBOMjo4/aPhMDcxEptsX+Tt3GgRQRPuE=
|
||||
golang.org/x/image v0.38.0/go.mod h1:/3f6vaXC+6CEanU4KJxbcUZyEePbyKbaLoDOe4ehFYY=
|
||||
golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8=
|
||||
golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||
golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -367,19 +353,14 @@ golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuX
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
||||
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
||||
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
|
||||
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
|
||||
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
|
||||
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
|
||||
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
|
||||
golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ=
|
||||
golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k=
|
||||
golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
|
||||
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
|
||||
google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM=
|
||||
google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
|
||||
+26
-4
@@ -29,6 +29,15 @@ const maxLogCount = 1000000
|
||||
var logCount int
|
||||
var setupLogLock sync.Mutex
|
||||
var setupLogWorking bool
|
||||
var currentLogPath string
|
||||
var currentLogPathMu sync.RWMutex
|
||||
var currentLogFile *os.File
|
||||
|
||||
func GetCurrentLogPath() string {
|
||||
currentLogPathMu.RLock()
|
||||
defer currentLogPathMu.RUnlock()
|
||||
return currentLogPath
|
||||
}
|
||||
|
||||
func SetupLogger() {
|
||||
defer func() {
|
||||
@@ -48,8 +57,19 @@ func SetupLogger() {
|
||||
if err != nil {
|
||||
log.Fatal("failed to open log file")
|
||||
}
|
||||
currentLogPathMu.Lock()
|
||||
oldFile := currentLogFile
|
||||
currentLogPath = logPath
|
||||
currentLogFile = fd
|
||||
currentLogPathMu.Unlock()
|
||||
|
||||
common.LogWriterMu.Lock()
|
||||
gin.DefaultWriter = io.MultiWriter(os.Stdout, fd)
|
||||
gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd)
|
||||
if oldFile != nil {
|
||||
_ = oldFile.Close()
|
||||
}
|
||||
common.LogWriterMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -75,16 +95,18 @@ func LogDebug(ctx context.Context, msg string, args ...any) {
|
||||
}
|
||||
|
||||
func logHelper(ctx context.Context, level string, msg string) {
|
||||
writer := gin.DefaultErrorWriter
|
||||
if level == loggerINFO {
|
||||
writer = gin.DefaultWriter
|
||||
}
|
||||
id := ctx.Value(common.RequestIdKey)
|
||||
if id == nil {
|
||||
id = "SYSTEM"
|
||||
}
|
||||
now := time.Now()
|
||||
common.LogWriterMu.RLock()
|
||||
writer := gin.DefaultErrorWriter
|
||||
if level == loggerINFO {
|
||||
writer = gin.DefaultWriter
|
||||
}
|
||||
_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
|
||||
common.LogWriterMu.RUnlock()
|
||||
logCount++ // we don't need accurate count, so no lock here
|
||||
if logCount > maxLogCount && !setupLogWorking {
|
||||
logCount = 0
|
||||
|
||||
@@ -101,8 +101,13 @@ func Distribute() func(c *gin.Context) {
|
||||
|
||||
if preferredChannelID, found := service.GetPreferredChannelByAffinity(c, modelRequest.Model, usingGroup); found {
|
||||
preferred, err := model.CacheGetChannel(preferredChannelID)
|
||||
if err == nil && preferred != nil && preferred.Status == common.ChannelStatusEnabled {
|
||||
if usingGroup == "auto" {
|
||||
if err == nil && preferred != nil {
|
||||
if preferred.Status != common.ChannelStatusEnabled {
|
||||
if service.ShouldSkipRetryAfterChannelAffinityFailure(c) {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, i18n.T(c, i18n.MsgDistributorChannelDisabled))
|
||||
return
|
||||
}
|
||||
} else if usingGroup == "auto" {
|
||||
userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
|
||||
autoGroups := service.GetUserAutoGroup(userGroup)
|
||||
for _, g := range autoGroups {
|
||||
|
||||
@@ -196,7 +196,10 @@ func userRedisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, key
|
||||
}
|
||||
|
||||
// SearchRateLimit returns a per-user rate limiter for search endpoints.
|
||||
// 10 requests per 60 seconds per user (by user ID, not IP).
|
||||
// Configurable via SEARCH_RATE_LIMIT_ENABLE / SEARCH_RATE_LIMIT / SEARCH_RATE_LIMIT_DURATION.
|
||||
func SearchRateLimit() func(c *gin.Context) {
|
||||
if !common.SearchRateLimitEnable {
|
||||
return defNext
|
||||
}
|
||||
return userRateLimitFactory(common.SearchRateLimitNum, common.SearchRateLimitDuration, "SR")
|
||||
}
|
||||
|
||||
+2
-1
@@ -58,7 +58,8 @@ func formatUserLogs(logs []*Log, startIdx int) {
|
||||
if otherMap != nil {
|
||||
// Remove admin-only debug fields.
|
||||
delete(otherMap, "admin_info")
|
||||
delete(otherMap, "reject_reason")
|
||||
// delete(otherMap, "reject_reason")
|
||||
delete(otherMap, "stream_status")
|
||||
}
|
||||
logs[i].Other = common.MapToJsonStr(otherMap)
|
||||
logs[i].Id = startIdx + i + 1
|
||||
|
||||
@@ -89,6 +89,22 @@ func InitOptionMap() {
|
||||
common.OptionMap["CreemProducts"] = setting.CreemProducts
|
||||
common.OptionMap["CreemTestMode"] = strconv.FormatBool(setting.CreemTestMode)
|
||||
common.OptionMap["CreemWebhookSecret"] = setting.CreemWebhookSecret
|
||||
common.OptionMap["WaffoEnabled"] = strconv.FormatBool(setting.WaffoEnabled)
|
||||
common.OptionMap["WaffoApiKey"] = setting.WaffoApiKey
|
||||
common.OptionMap["WaffoPrivateKey"] = setting.WaffoPrivateKey
|
||||
common.OptionMap["WaffoPublicCert"] = setting.WaffoPublicCert
|
||||
common.OptionMap["WaffoSandboxPublicCert"] = setting.WaffoSandboxPublicCert
|
||||
common.OptionMap["WaffoSandboxApiKey"] = setting.WaffoSandboxApiKey
|
||||
common.OptionMap["WaffoSandboxPrivateKey"] = setting.WaffoSandboxPrivateKey
|
||||
common.OptionMap["WaffoSandbox"] = strconv.FormatBool(setting.WaffoSandbox)
|
||||
common.OptionMap["WaffoMerchantId"] = setting.WaffoMerchantId
|
||||
common.OptionMap["WaffoNotifyUrl"] = setting.WaffoNotifyUrl
|
||||
common.OptionMap["WaffoReturnUrl"] = setting.WaffoReturnUrl
|
||||
common.OptionMap["WaffoSubscriptionReturnUrl"] = setting.WaffoSubscriptionReturnUrl
|
||||
common.OptionMap["WaffoCurrency"] = setting.WaffoCurrency
|
||||
common.OptionMap["WaffoUnitPrice"] = strconv.FormatFloat(setting.WaffoUnitPrice, 'f', -1, 64)
|
||||
common.OptionMap["WaffoMinTopUp"] = strconv.Itoa(setting.WaffoMinTopUp)
|
||||
common.OptionMap["WaffoPayMethods"] = setting.WaffoPayMethods2JsonString()
|
||||
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
|
||||
common.OptionMap["Chats"] = setting.Chats2JsonString()
|
||||
common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString()
|
||||
@@ -358,6 +374,36 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
setting.CreemTestMode = value == "true"
|
||||
case "CreemWebhookSecret":
|
||||
setting.CreemWebhookSecret = value
|
||||
case "WaffoEnabled":
|
||||
setting.WaffoEnabled = value == "true"
|
||||
case "WaffoApiKey":
|
||||
setting.WaffoApiKey = value
|
||||
case "WaffoPrivateKey":
|
||||
setting.WaffoPrivateKey = value
|
||||
case "WaffoPublicCert":
|
||||
setting.WaffoPublicCert = value
|
||||
case "WaffoSandboxPublicCert":
|
||||
setting.WaffoSandboxPublicCert = value
|
||||
case "WaffoSandboxApiKey":
|
||||
setting.WaffoSandboxApiKey = value
|
||||
case "WaffoSandboxPrivateKey":
|
||||
setting.WaffoSandboxPrivateKey = value
|
||||
case "WaffoSandbox":
|
||||
setting.WaffoSandbox = value == "true"
|
||||
case "WaffoMerchantId":
|
||||
setting.WaffoMerchantId = value
|
||||
case "WaffoNotifyUrl":
|
||||
setting.WaffoNotifyUrl = value
|
||||
case "WaffoReturnUrl":
|
||||
setting.WaffoReturnUrl = value
|
||||
case "WaffoSubscriptionReturnUrl":
|
||||
setting.WaffoSubscriptionReturnUrl = value
|
||||
case "WaffoCurrency":
|
||||
setting.WaffoCurrency = value
|
||||
case "WaffoUnitPrice":
|
||||
setting.WaffoUnitPrice, _ = strconv.ParseFloat(value, 64)
|
||||
case "WaffoMinTopUp":
|
||||
setting.WaffoMinTopUp, _ = strconv.Atoi(value)
|
||||
case "TopupGroupRatio":
|
||||
err = common.UpdateTopupGroupRatioByJSONString(value)
|
||||
case "GitHubClientId":
|
||||
@@ -458,6 +504,10 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
|
||||
case "PayMethods":
|
||||
err = operation_setting.UpdatePayMethodsByJsonString(value)
|
||||
case "WaffoPayMethods":
|
||||
// WaffoPayMethods is read directly from OptionMap via setting.GetWaffoPayMethods().
|
||||
// The value is already stored in OptionMap at the top of this function (line: common.OptionMap[key] = value).
|
||||
// No additional in-memory variable to update.
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
+23
-1
@@ -25,6 +25,11 @@ type Pricing struct {
|
||||
ModelPrice float64 `json:"model_price"`
|
||||
OwnerBy string `json:"owner_by"`
|
||||
CompletionRatio float64 `json:"completion_ratio"`
|
||||
CacheRatio *float64 `json:"cache_ratio,omitempty"`
|
||||
CreateCacheRatio *float64 `json:"create_cache_ratio,omitempty"`
|
||||
ImageRatio *float64 `json:"image_ratio,omitempty"`
|
||||
AudioRatio *float64 `json:"audio_ratio,omitempty"`
|
||||
AudioCompletionRatio *float64 `json:"audio_completion_ratio,omitempty"`
|
||||
EnableGroup []string `json:"enable_groups"`
|
||||
SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
|
||||
PricingVersion string `json:"pricing_version,omitempty"`
|
||||
@@ -297,12 +302,29 @@ func updatePricing() {
|
||||
pricing.CompletionRatio = ratio_setting.GetCompletionRatio(model)
|
||||
pricing.QuotaType = 0
|
||||
}
|
||||
if cacheRatio, ok := ratio_setting.GetCacheRatio(model); ok {
|
||||
pricing.CacheRatio = &cacheRatio
|
||||
}
|
||||
if createCacheRatio, ok := ratio_setting.GetCreateCacheRatio(model); ok {
|
||||
pricing.CreateCacheRatio = &createCacheRatio
|
||||
}
|
||||
if imageRatio, ok := ratio_setting.GetImageRatio(model); ok {
|
||||
pricing.ImageRatio = &imageRatio
|
||||
}
|
||||
if ratio_setting.ContainsAudioRatio(model) {
|
||||
audioRatio := ratio_setting.GetAudioRatio(model)
|
||||
pricing.AudioRatio = &audioRatio
|
||||
}
|
||||
if ratio_setting.ContainsAudioCompletionRatio(model) {
|
||||
audioCompletionRatio := ratio_setting.GetAudioCompletionRatio(model)
|
||||
pricing.AudioCompletionRatio = &audioCompletionRatio
|
||||
}
|
||||
pricingMap = append(pricingMap, pricing)
|
||||
}
|
||||
|
||||
// 防止大更新后数据不通用
|
||||
if len(pricingMap) > 0 {
|
||||
pricingMap[0].PricingVersion = "82c4a357505fff6fee8462c3f7ec8a645bb95532669cb73b2cabee6a416ec24f"
|
||||
pricingMap[0].PricingVersion = "5a90f2b86c08bd983a9a2e6d66c255f4eaef9c4bc934386d2b6ae84ef0ff1f1f"
|
||||
}
|
||||
|
||||
// 刷新缓存映射,供高并发快速查询
|
||||
|
||||
+22
-1
@@ -35,6 +35,27 @@ func (token *Token) Clean() {
|
||||
token.Key = ""
|
||||
}
|
||||
|
||||
// MaskTokenKey returns a redacted form of key.
//
// Lengths are measured in bytes:
//   - empty key: empty string
//   - 1-4 bytes: every byte replaced by '*'
//   - 5-8 bytes: first 2 and last 2 bytes kept, joined by "****"
//   - longer:    first 4 and last 4 bytes kept, joined by "**********"
func MaskTokenKey(key string) string {
	switch n := len(key); {
	case n == 0:
		return ""
	case n <= 4:
		return strings.Repeat("*", n)
	case n <= 8:
		return key[:2] + "****" + key[n-2:]
	default:
		return key[:4] + "**********" + key[n-4:]
	}
}
|
||||
|
||||
func (token *Token) GetFullKey() string {
|
||||
return token.Key
|
||||
}
|
||||
|
||||
func (token *Token) GetMaskedKey() string {
|
||||
return MaskTokenKey(token.Key)
|
||||
}
|
||||
|
||||
func (token *Token) GetIpLimits() []string {
|
||||
// delete empty spaces
|
||||
//split with \n
|
||||
@@ -201,7 +222,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
|
||||
}
|
||||
keyPrefix := key[:3]
|
||||
keySuffix := key[len(key)-3:]
|
||||
return token, errors.New(fmt.Sprintf("[sk-%s***%s] 该令牌额度已用尽 !token.UnlimitedQuota && token.RemainQuota = %d", keyPrefix, keySuffix, token.RemainQuota))
|
||||
return token, fmt.Errorf("[sk-%s***%s] 该令牌额度已用尽 !token.UnlimitedQuota && token.RemainQuota = %d", keyPrefix, keySuffix, token.RemainQuota)
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
+68
-9
@@ -12,15 +12,15 @@ import (
|
||||
)
|
||||
|
||||
type TopUp struct {
|
||||
Id int `json:"id"`
|
||||
UserId int `json:"user_id" gorm:"index"`
|
||||
Amount int64 `json:"amount"`
|
||||
Money float64 `json:"money"`
|
||||
TradeNo string `json:"trade_no" gorm:"unique;type:varchar(255);index"`
|
||||
PaymentMethod string `json:"payment_method" gorm:"type:varchar(50)"`
|
||||
CreateTime int64 `json:"create_time"`
|
||||
CompleteTime int64 `json:"complete_time"`
|
||||
Status string `json:"status"`
|
||||
Id int `json:"id"`
|
||||
UserId int `json:"user_id" gorm:"index"`
|
||||
Amount int64 `json:"amount"`
|
||||
Money float64 `json:"money"`
|
||||
TradeNo string `json:"trade_no" gorm:"unique;type:varchar(255);index"`
|
||||
PaymentMethod string `json:"payment_method" gorm:"type:varchar(50)"`
|
||||
CreateTime int64 `json:"create_time"`
|
||||
CompleteTime int64 `json:"complete_time"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
func (topUp *TopUp) Insert() error {
|
||||
@@ -376,3 +376,62 @@ func RechargeCreem(referenceId string, customerEmail string, customerName string
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func RechargeWaffo(tradeNo string) (err error) {
|
||||
if tradeNo == "" {
|
||||
return errors.New("未提供支付单号")
|
||||
}
|
||||
|
||||
var quotaToAdd int
|
||||
topUp := &TopUp{}
|
||||
|
||||
refCol := "`trade_no`"
|
||||
if common.UsingPostgreSQL {
|
||||
refCol = `"trade_no"`
|
||||
}
|
||||
|
||||
err = DB.Transaction(func(tx *gorm.DB) error {
|
||||
err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", tradeNo).First(topUp).Error
|
||||
if err != nil {
|
||||
return errors.New("充值订单不存在")
|
||||
}
|
||||
|
||||
if topUp.Status == common.TopUpStatusSuccess {
|
||||
return nil // 幂等:已成功直接返回
|
||||
}
|
||||
|
||||
if topUp.Status != common.TopUpStatusPending {
|
||||
return errors.New("充值订单状态错误")
|
||||
}
|
||||
|
||||
dAmount := decimal.NewFromInt(topUp.Amount)
|
||||
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
|
||||
quotaToAdd = int(dAmount.Mul(dQuotaPerUnit).IntPart())
|
||||
if quotaToAdd <= 0 {
|
||||
return errors.New("无效的充值额度")
|
||||
}
|
||||
|
||||
topUp.CompleteTime = common.GetTimestamp()
|
||||
topUp.Status = common.TopUpStatusSuccess
|
||||
if err := tx.Save(topUp).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.Model(&User{}).Where("id = ?", topUp.UserId).Update("quota", gorm.Expr("quota + ?", quotaToAdd)).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
common.SysError("waffo topup failed: " + err.Error())
|
||||
return errors.New("充值失败,请稍后重试")
|
||||
}
|
||||
|
||||
if quotaToAdd > 0 {
|
||||
RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("Waffo充值成功,充值额度: %v,支付金额: %.2f", logger.FormatQuota(quotaToAdd), topUp.Money))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
+9
-4
@@ -208,10 +208,7 @@ func (p *GenericOAuthProvider) GetUserInfo(ctx context.Context, token *OAuthToke
|
||||
}
|
||||
|
||||
// Set authorization header
|
||||
tokenType := token.TokenType
|
||||
if tokenType == "" {
|
||||
tokenType = "Bearer"
|
||||
}
|
||||
tokenType := normalizeAuthorizationTokenType(token.TokenType)
|
||||
req.Header.Set("Authorization", fmt.Sprintf("%s %s", tokenType, token.AccessToken))
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
@@ -320,6 +317,14 @@ func (p *GenericOAuthProvider) GetProviderId() int {
|
||||
return p.config.Id
|
||||
}
|
||||
|
||||
// normalizeAuthorizationTokenType prepares a token type for use as the scheme
// of an Authorization header.
//
// Surrounding whitespace is trimmed. An empty result, or any case variant of
// "bearer", yields the canonical "Bearer"; every other value is returned
// trimmed but otherwise unchanged.
func normalizeAuthorizationTokenType(tokenType string) string {
	trimmed := strings.TrimSpace(tokenType)
	switch {
	case trimmed == "", strings.EqualFold(trimmed, "Bearer"):
		return "Bearer"
	default:
		return trimmed
	}
}
|
||||
|
||||
// IsGenericProvider returns true for generic providers
|
||||
func (p *GenericOAuthProvider) IsGenericProvider() bool {
|
||||
return true
|
||||
|
||||
@@ -70,7 +70,7 @@ func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
||||
if usage.(*dto.Usage).CompletionTokenDetails.AudioTokens > 0 || usage.(*dto.Usage).PromptTokensDetails.AudioTokens > 0 {
|
||||
service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||||
} else {
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage))
|
||||
service.PostTextConsumeQuota(c, info, usage.(*dto.Usage), nil)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
+11
-10
@@ -1,6 +1,7 @@
|
||||
package baidu
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
@@ -12,16 +13,16 @@ type BaiduMessage struct {
|
||||
}
|
||||
|
||||
type BaiduChatRequest struct {
|
||||
Messages []BaiduMessage `json:"messages"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
PenaltyScore float64 `json:"penalty_score,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
System string `json:"system,omitempty"`
|
||||
DisableSearch bool `json:"disable_search,omitempty"`
|
||||
EnableCitation bool `json:"enable_citation,omitempty"`
|
||||
MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
|
||||
UserId string `json:"user_id,omitempty"`
|
||||
Messages []BaiduMessage `json:"messages"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
PenaltyScore float64 `json:"penalty_score,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
System string `json:"system,omitempty"`
|
||||
DisableSearch bool `json:"disable_search,omitempty"`
|
||||
EnableCitation bool `json:"enable_citation,omitempty"`
|
||||
MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
|
||||
UserId json.RawMessage `json:"user_id,omitempty"`
|
||||
}
|
||||
|
||||
type Error struct {
|
||||
|
||||
@@ -116,12 +116,12 @@ func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *dto.OpenAI
|
||||
|
||||
func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||
usage := &dto.Usage{}
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) {
|
||||
var baiduResponse BaiduChatStreamResponse
|
||||
err := common.Unmarshal([]byte(data), &baiduResponse)
|
||||
if err != nil {
|
||||
if err := common.Unmarshal([]byte(data), &baiduResponse); err != nil {
|
||||
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
sr.Error(err)
|
||||
return
|
||||
}
|
||||
if baiduResponse.Usage.TotalTokens != 0 {
|
||||
usage.TotalTokens = baiduResponse.Usage.TotalTokens
|
||||
@@ -129,11 +129,10 @@ func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
|
||||
usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
|
||||
}
|
||||
response := streamResponseBaidu2OpenAI(&baiduResponse)
|
||||
err = helper.ObjectData(c, response)
|
||||
if err != nil {
|
||||
if err := helper.ObjectData(c, response); err != nil {
|
||||
common.SysLog("error sending stream response: " + err.Error())
|
||||
sr.Error(err)
|
||||
}
|
||||
return true
|
||||
})
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
return nil, usage
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
@@ -41,11 +42,32 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
baseURL := fmt.Sprintf("%s/v1/messages", info.ChannelBaseUrl)
|
||||
if info.IsClaudeBetaQuery {
|
||||
baseURL = baseURL + "?beta=true"
|
||||
requestURL := fmt.Sprintf("%s/v1/messages", info.ChannelBaseUrl)
|
||||
if !shouldAppendClaudeBetaQuery(info) {
|
||||
return requestURL, nil
|
||||
}
|
||||
return baseURL, nil
|
||||
|
||||
parsedURL, err := url.Parse(requestURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
query := parsedURL.Query()
|
||||
query.Set("beta", "true")
|
||||
parsedURL.RawQuery = query.Encode()
|
||||
return parsedURL.String(), nil
|
||||
}
|
||||
|
||||
func shouldAppendClaudeBetaQuery(info *relaycommon.RelayInfo) bool {
|
||||
if info == nil {
|
||||
return false
|
||||
}
|
||||
if info.IsClaudeBetaQuery {
|
||||
return true
|
||||
}
|
||||
if info.ChannelOtherSettings.ClaudeBetaQuery {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func CommonClaudeHeadersOperation(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) {
|
||||
|
||||
@@ -25,6 +25,7 @@ var ModelList = []string{
|
||||
"claude-opus-4-6-high",
|
||||
"claude-opus-4-6-medium",
|
||||
"claude-opus-4-6-low",
|
||||
"claude-sonnet-4-6",
|
||||
}
|
||||
|
||||
var ChannelName = "claude"
|
||||
|
||||
@@ -555,6 +555,35 @@ type ClaudeResponseInfo struct {
|
||||
Done bool
|
||||
}
|
||||
|
||||
func cacheCreationTokensForOpenAIUsage(usage *dto.Usage) int {
|
||||
if usage == nil {
|
||||
return 0
|
||||
}
|
||||
splitCacheCreationTokens := usage.ClaudeCacheCreation5mTokens + usage.ClaudeCacheCreation1hTokens
|
||||
if splitCacheCreationTokens == 0 {
|
||||
return usage.PromptTokensDetails.CachedCreationTokens
|
||||
}
|
||||
if usage.PromptTokensDetails.CachedCreationTokens > splitCacheCreationTokens {
|
||||
return usage.PromptTokensDetails.CachedCreationTokens
|
||||
}
|
||||
return splitCacheCreationTokens
|
||||
}
|
||||
|
||||
func buildOpenAIStyleUsageFromClaudeUsage(usage *dto.Usage) dto.Usage {
|
||||
if usage == nil {
|
||||
return dto.Usage{}
|
||||
}
|
||||
clone := *usage
|
||||
cacheCreationTokens := cacheCreationTokensForOpenAIUsage(usage)
|
||||
totalInputTokens := usage.PromptTokens + usage.PromptTokensDetails.CachedTokens + cacheCreationTokens
|
||||
clone.PromptTokens = totalInputTokens
|
||||
clone.InputTokens = totalInputTokens
|
||||
clone.TotalTokens = totalInputTokens + usage.CompletionTokens
|
||||
clone.UsageSemantic = "openai"
|
||||
clone.UsageSource = "anthropic"
|
||||
return clone
|
||||
}
|
||||
|
||||
func buildMessageDeltaPatchUsage(claudeResponse *dto.ClaudeResponse, claudeInfo *ClaudeResponseInfo) *dto.ClaudeUsage {
|
||||
usage := &dto.ClaudeUsage{}
|
||||
if claudeResponse != nil && claudeResponse.Usage != nil {
|
||||
@@ -643,6 +672,7 @@ func FormatClaudeResponseInfo(claudeResponse *dto.ClaudeResponse, oaiResponse *d
|
||||
// message_start, 获取usage
|
||||
if claudeResponse.Message != nil && claudeResponse.Message.Usage != nil {
|
||||
claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
|
||||
claudeInfo.Usage.UsageSemantic = "anthropic"
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
|
||||
claudeInfo.Usage.ClaudeCacheCreation5mTokens = claudeResponse.Message.Usage.GetCacheCreation5mTokens()
|
||||
@@ -661,6 +691,7 @@ func FormatClaudeResponseInfo(claudeResponse *dto.ClaudeResponse, oaiResponse *d
|
||||
} else if claudeResponse.Type == "message_delta" {
|
||||
// 最终的usage获取
|
||||
if claudeResponse.Usage != nil {
|
||||
claudeInfo.Usage.UsageSemantic = "anthropic"
|
||||
if claudeResponse.Usage.InputTokens > 0 {
|
||||
// 不叠加,只取最新的
|
||||
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
|
||||
@@ -754,12 +785,16 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau
|
||||
}
|
||||
claudeInfo.Usage = service.ResponseText2Usage(c, claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
||||
}
|
||||
if claudeInfo.Usage != nil {
|
||||
claudeInfo.Usage.UsageSemantic = "anthropic"
|
||||
}
|
||||
|
||||
if info.RelayFormat == types.RelayFormatClaude {
|
||||
//
|
||||
} else if info.RelayFormat == types.RelayFormatOpenAI {
|
||||
if info.ShouldIncludeUsage {
|
||||
response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
|
||||
openAIUsage := buildOpenAIStyleUsageFromClaudeUsage(claudeInfo.Usage)
|
||||
response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, openAIUsage)
|
||||
err := helper.ObjectData(c, response)
|
||||
if err != nil {
|
||||
common.SysLog("send final response failed: " + err.Error())
|
||||
@@ -778,12 +813,11 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
||||
Usage: &dto.Usage{},
|
||||
}
|
||||
var err *types.NewAPIError
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) {
|
||||
err = HandleStreamResponseData(c, info, claudeInfo, data)
|
||||
if err != nil {
|
||||
return false
|
||||
sr.Stop(err)
|
||||
}
|
||||
return true
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -810,6 +844,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
|
||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
|
||||
claudeInfo.Usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
|
||||
claudeInfo.Usage.UsageSemantic = "anthropic"
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Usage.CacheReadInputTokens
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Usage.CacheCreationInputTokens
|
||||
claudeInfo.Usage.ClaudeCacheCreation5mTokens = claudeResponse.Usage.GetCacheCreation5mTokens()
|
||||
@@ -819,7 +854,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
switch info.RelayFormat {
|
||||
case types.RelayFormatOpenAI:
|
||||
openaiResponse := ResponseClaude2OpenAI(&claudeResponse)
|
||||
openaiResponse.Usage = *claudeInfo.Usage
|
||||
openaiResponse.Usage = buildOpenAIStyleUsageFromClaudeUsage(claudeInfo.Usage)
|
||||
responseData, err = json.Marshal(openaiResponse)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
|
||||
@@ -173,3 +173,85 @@ func TestFormatClaudeResponseInfo_ContentBlockDelta(t *testing.T) {
|
||||
t.Errorf("ResponseText = %q, want %q", claudeInfo.ResponseText.String(), "hello")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpenAIStyleUsageFromClaudeUsage(t *testing.T) {
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: 100,
|
||||
CompletionTokens: 20,
|
||||
PromptTokensDetails: dto.InputTokenDetails{
|
||||
CachedTokens: 30,
|
||||
CachedCreationTokens: 50,
|
||||
},
|
||||
ClaudeCacheCreation5mTokens: 10,
|
||||
ClaudeCacheCreation1hTokens: 20,
|
||||
UsageSemantic: "anthropic",
|
||||
}
|
||||
|
||||
openAIUsage := buildOpenAIStyleUsageFromClaudeUsage(usage)
|
||||
|
||||
if openAIUsage.PromptTokens != 180 {
|
||||
t.Fatalf("PromptTokens = %d, want 180", openAIUsage.PromptTokens)
|
||||
}
|
||||
if openAIUsage.InputTokens != 180 {
|
||||
t.Fatalf("InputTokens = %d, want 180", openAIUsage.InputTokens)
|
||||
}
|
||||
if openAIUsage.TotalTokens != 200 {
|
||||
t.Fatalf("TotalTokens = %d, want 200", openAIUsage.TotalTokens)
|
||||
}
|
||||
if openAIUsage.UsageSemantic != "openai" {
|
||||
t.Fatalf("UsageSemantic = %s, want openai", openAIUsage.UsageSemantic)
|
||||
}
|
||||
if openAIUsage.UsageSource != "anthropic" {
|
||||
t.Fatalf("UsageSource = %s, want anthropic", openAIUsage.UsageSource)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpenAIStyleUsageFromClaudeUsagePreservesCacheCreationRemainder(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cachedCreationTokens int
|
||||
cacheCreationTokens5m int
|
||||
cacheCreationTokens1h int
|
||||
expectedTotalInputToken int
|
||||
}{
|
||||
{
|
||||
name: "prefers aggregate when it includes remainder",
|
||||
cachedCreationTokens: 50,
|
||||
cacheCreationTokens5m: 10,
|
||||
cacheCreationTokens1h: 20,
|
||||
expectedTotalInputToken: 180,
|
||||
},
|
||||
{
|
||||
name: "falls back to split tokens when aggregate missing",
|
||||
cachedCreationTokens: 0,
|
||||
cacheCreationTokens5m: 10,
|
||||
cacheCreationTokens1h: 20,
|
||||
expectedTotalInputToken: 160,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: 100,
|
||||
CompletionTokens: 20,
|
||||
PromptTokensDetails: dto.InputTokenDetails{
|
||||
CachedTokens: 30,
|
||||
CachedCreationTokens: tt.cachedCreationTokens,
|
||||
},
|
||||
ClaudeCacheCreation5mTokens: tt.cacheCreationTokens5m,
|
||||
ClaudeCacheCreation1hTokens: tt.cacheCreationTokens1h,
|
||||
UsageSemantic: "anthropic",
|
||||
}
|
||||
|
||||
openAIUsage := buildOpenAIStyleUsageFromClaudeUsage(usage)
|
||||
|
||||
if openAIUsage.PromptTokens != tt.expectedTotalInputToken {
|
||||
t.Fatalf("PromptTokens = %d, want %d", openAIUsage.PromptTokens, tt.expectedTotalInputToken)
|
||||
}
|
||||
if openAIUsage.InputTokens != tt.expectedTotalInputToken {
|
||||
t.Fatalf("InputTokens = %d, want %d", openAIUsage.InputTokens, tt.expectedTotalInputToken)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,7 +8,8 @@ import (
|
||||
// baseModelList holds the base model names of this channel; ModelList is
// derived from it through withCompactModelSuffix. Each name appears once.
var baseModelList = []string{
	"gpt-5", "gpt-5-codex", "gpt-5-codex-mini",
	"gpt-5.1", "gpt-5.1-codex", "gpt-5.1-codex-max", "gpt-5.1-codex-mini",
	"gpt-5.2", "gpt-5.2-codex", "gpt-5.3-codex", "gpt-5.3-codex-spark",
	"gpt-5.4",
}
|
||||
|
||||
var ModelList = withCompactModelSuffix(baseModelList)
|
||||
|
||||
@@ -17,7 +17,7 @@ type CozeEnterMessage struct {
|
||||
|
||||
type CozeChatRequest struct {
|
||||
BotId string `json:"bot_id"`
|
||||
UserId string `json:"user_id"`
|
||||
UserId json.RawMessage `json:"user_id"`
|
||||
AdditionalMessages []CozeEnterMessage `json:"additional_messages,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
CustomVariables json.RawMessage `json:"custom_variables,omitempty"`
|
||||
|
||||
@@ -34,8 +34,8 @@ func convertCozeChatRequest(c *gin.Context, request dto.GeneralOpenAIRequest) *C
|
||||
}
|
||||
}
|
||||
user := request.User
|
||||
if user == "" {
|
||||
user = helper.GetResponseID(c)
|
||||
if len(user) == 0 {
|
||||
user = json.RawMessage(helper.GetResponseID(c))
|
||||
}
|
||||
cozeRequest := &CozeChatRequest{
|
||||
BotId: c.GetString("bot_id"),
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package dify
|
||||
|
||||
import "github.com/QuantumNous/new-api/dto"
|
||||
import (
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
)
|
||||
|
||||
type DifyChatRequest struct {
|
||||
Inputs map[string]interface{} `json:"inputs"`
|
||||
|
||||
@@ -131,10 +131,16 @@ func requestOpenAI2Dify(c *gin.Context, info *relaycommon.RelayInfo, request dto
|
||||
}
|
||||
|
||||
user := request.User
|
||||
if user == "" {
|
||||
user = helper.GetResponseID(c)
|
||||
if len(user) == 0 {
|
||||
user = json.RawMessage(helper.GetResponseID(c))
|
||||
}
|
||||
difyReq.User = user
|
||||
var stringUser string
|
||||
err := json.Unmarshal(user, &stringUser)
|
||||
if err != nil {
|
||||
common.SysLog("failed to unmarshal user: " + err.Error())
|
||||
stringUser = helper.GetResponseID(c)
|
||||
}
|
||||
difyReq.User = stringUser
|
||||
|
||||
files := make([]DifyFile, 0)
|
||||
var content strings.Builder
|
||||
@@ -217,33 +223,32 @@ func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
|
||||
usage := &dto.Usage{}
|
||||
var nodeToken int
|
||||
helper.SetEventStreamHeaders(c)
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) {
|
||||
var difyResponse DifyChunkChatCompletionResponse
|
||||
err := json.Unmarshal([]byte(data), &difyResponse)
|
||||
if err != nil {
|
||||
if err := json.Unmarshal([]byte(data), &difyResponse); err != nil {
|
||||
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
sr.Error(err)
|
||||
return
|
||||
}
|
||||
var openaiResponse dto.ChatCompletionsStreamResponse
|
||||
if difyResponse.Event == "message_end" {
|
||||
usage = &difyResponse.MetaData.Usage
|
||||
return false
|
||||
sr.Done()
|
||||
return
|
||||
} else if difyResponse.Event == "error" {
|
||||
return false
|
||||
} else {
|
||||
openaiResponse = *streamResponseDify2OpenAI(difyResponse)
|
||||
if len(openaiResponse.Choices) != 0 {
|
||||
responseText += openaiResponse.Choices[0].Delta.GetContentString()
|
||||
if openaiResponse.Choices[0].Delta.ReasoningContent != nil {
|
||||
nodeToken += 1
|
||||
}
|
||||
sr.Stop(fmt.Errorf("dify error event"))
|
||||
return
|
||||
}
|
||||
openaiResponse := *streamResponseDify2OpenAI(difyResponse)
|
||||
if len(openaiResponse.Choices) != 0 {
|
||||
responseText += openaiResponse.Choices[0].Delta.GetContentString()
|
||||
if openaiResponse.Choices[0].Delta.ReasoningContent != nil {
|
||||
nodeToken += 1
|
||||
}
|
||||
}
|
||||
err = helper.ObjectData(c, openaiResponse)
|
||||
if err != nil {
|
||||
if err := helper.ObjectData(c, openaiResponse); err != nil {
|
||||
common.SysLog(err.Error())
|
||||
sr.Error(err)
|
||||
}
|
||||
return true
|
||||
})
|
||||
helper.Done(c)
|
||||
if usage.TotalTokens == 0 {
|
||||
|
||||
@@ -2,29 +2,34 @@ package gemini
|
||||
|
||||
var ModelList = []string{
|
||||
// stable version
|
||||
"gemini-1.5-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b",
|
||||
"gemini-2.0-flash",
|
||||
"gemini-2.5-flash", "gemini-2.5-pro", "gemini-2.0-flash",
|
||||
"gemini-2.0-flash-001", "gemini-2.0-flash-lite-001", "gemini-2.0-flash-lite",
|
||||
"gemini-2.5-flash-lite",
|
||||
// latest version
|
||||
"gemini-1.5-pro-latest", "gemini-1.5-flash-latest",
|
||||
"gemini-flash-latest", "gemini-flash-lite-latest", "gemini-pro-latest",
|
||||
"gemini-2.5-flash-native-audio-latest",
|
||||
// preview version
|
||||
"gemini-2.0-flash-lite-preview",
|
||||
"gemini-3-pro-preview",
|
||||
// gemini exp
|
||||
"gemini-exp-1206",
|
||||
// flash exp
|
||||
"gemini-2.0-flash-exp",
|
||||
// pro exp
|
||||
"gemini-2.0-pro-exp",
|
||||
// thinking exp
|
||||
"gemini-2.0-flash-thinking-exp",
|
||||
"gemini-2.5-pro-exp-03-25",
|
||||
"gemini-2.5-pro-preview-03-25",
|
||||
// imagen models
|
||||
"imagen-3.0-generate-002",
|
||||
"gemini-2.5-flash-preview-tts", "gemini-2.5-pro-preview-tts",
|
||||
"gemini-2.5-flash-image", "gemini-2.5-flash-lite-preview-09-2025",
|
||||
"gemini-3-pro-preview", "gemini-3-flash-preview", "gemini-3.1-pro-preview",
|
||||
"gemini-3.1-pro-preview-customtools", "gemini-3.1-flash-lite-preview",
|
||||
"gemini-3-pro-image-preview", "nano-banana-pro-preview",
|
||||
"gemini-3.1-flash-image-preview", "gemini-robotics-er-1.5-preview",
|
||||
"gemini-2.5-computer-use-preview-10-2025", "deep-research-pro-preview-12-2025",
|
||||
"gemini-2.5-flash-native-audio-preview-09-2025", "gemini-2.5-flash-native-audio-preview-12-2025",
|
||||
// gemma models
|
||||
"gemma-3-1b-it", "gemma-3-4b-it", "gemma-3-12b-it",
|
||||
"gemma-3-27b-it", "gemma-3n-e4b-it", "gemma-3n-e2b-it",
|
||||
// embedding models
|
||||
"gemini-embedding-exp-03-07",
|
||||
"text-embedding-004",
|
||||
"embedding-001",
|
||||
"gemini-embedding-001", "gemini-embedding-2-preview",
|
||||
// imagen models
|
||||
"imagen-4.0-generate-001", "imagen-4.0-ultra-generate-001",
|
||||
"imagen-4.0-fast-generate-001",
|
||||
// veo models
|
||||
"veo-2.0-generate-001", "veo-3.0-generate-001", "veo-3.0-fast-generate-001",
|
||||
"veo-3.1-generate-preview", "veo-3.1-fast-generate-preview",
|
||||
// other models
|
||||
"aqa",
|
||||
}
|
||||
|
||||
var SafetySettingList = []string{
|
||||
|
||||
@@ -1297,12 +1297,11 @@ func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
||||
var imageCount int
|
||||
responseText := strings.Builder{}
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) {
|
||||
var geminiResponse dto.GeminiChatResponse
|
||||
err := common.UnmarshalJsonStr(data, &geminiResponse)
|
||||
if err != nil {
|
||||
logger.LogError(c, "error unmarshalling stream response: "+err.Error())
|
||||
return false
|
||||
if err := common.UnmarshalJsonStr(data, &geminiResponse); err != nil {
|
||||
sr.Stop(fmt.Errorf("unmarshal: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
if len(geminiResponse.Candidates) == 0 && geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
|
||||
@@ -1327,7 +1326,9 @@ func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
||||
*usage = mappedUsage
|
||||
}
|
||||
|
||||
return callback(data, &geminiResponse)
|
||||
if !callback(data, &geminiResponse) {
|
||||
sr.Stop(fmt.Errorf("gemini callback stopped"))
|
||||
}
|
||||
})
|
||||
|
||||
if imageCount != 0 {
|
||||
|
||||
@@ -15,8 +15,10 @@ var ModelList = []string{
|
||||
"speech-01-hd",
|
||||
"speech-01-turbo",
|
||||
"MiniMax-M2.1",
|
||||
"MiniMax-M2.1-lightning",
|
||||
"MiniMax-M2.1-highspeed",
|
||||
"MiniMax-M2",
|
||||
"MiniMax-M2.5",
|
||||
"MiniMax-M2.5-highspeed",
|
||||
}
|
||||
|
||||
var ChannelName = "minimax"
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
package moonshot
|
||||
|
||||
// ModelList enumerates the model names listed for the moonshot package.
var ModelList = []string{
	// moonshot-v1 family
	"moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k",
	// kimi family
	"kimi-k2.5",
	"kimi-k2-0905-preview", "kimi-k2-turbo-preview",
	"kimi-k2-thinking", "kimi-k2-thinking-turbo",
}
|
||||
|
||||
var ChannelName = "moonshot"
|
||||
|
||||
@@ -225,8 +225,12 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *
|
||||
}
|
||||
}
|
||||
if info.ChannelType == constant.ChannelTypeOpenRouter {
|
||||
header.Set("HTTP-Referer", "https://www.newapi.ai")
|
||||
header.Set("X-Title", "New API")
|
||||
if header.Get("HTTP-Referer") == "" {
|
||||
header.Set("HTTP-Referer", "https://www.newapi.ai")
|
||||
}
|
||||
if header.Get("X-OpenRouter-Title") == "" {
|
||||
header.Set("X-OpenRouter-Title", "New API")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -35,21 +35,21 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
|
||||
if info.IsStream {
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) {
|
||||
if service.SundaySearch(data, "usage") {
|
||||
var simpleResponse dto.SimpleResponse
|
||||
err := common.Unmarshal([]byte(data), &simpleResponse)
|
||||
if err != nil {
|
||||
if err := common.Unmarshal([]byte(data), &simpleResponse); err != nil {
|
||||
logger.LogError(c, err.Error())
|
||||
}
|
||||
if simpleResponse.Usage.TotalTokens != 0 {
|
||||
sr.Error(err)
|
||||
} else if simpleResponse.Usage.TotalTokens != 0 {
|
||||
usage.PromptTokens = simpleResponse.Usage.InputTokens
|
||||
usage.CompletionTokens = simpleResponse.OutputTokens
|
||||
usage.TotalTokens = simpleResponse.TotalTokens
|
||||
}
|
||||
}
|
||||
_ = helper.StringData(c, data)
|
||||
return true
|
||||
if err := helper.StringData(c, data); err != nil {
|
||||
sr.Error(err)
|
||||
}
|
||||
})
|
||||
} else {
|
||||
common.SetContextKey(c, constant.ContextKeyLocalCountTokens, true)
|
||||
|
||||
@@ -296,15 +296,17 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
|
||||
return true
|
||||
}
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) {
|
||||
if streamErr != nil {
|
||||
return false
|
||||
sr.Stop(streamErr)
|
||||
return
|
||||
}
|
||||
|
||||
var streamResp dto.ResponsesStreamResponse
|
||||
if err := common.UnmarshalJsonStr(data, &streamResp); err != nil {
|
||||
logger.LogError(c, "failed to unmarshal responses stream event: "+err.Error())
|
||||
return true
|
||||
sr.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
switch streamResp.Type {
|
||||
@@ -320,14 +322,16 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
|
||||
|
||||
//case "response.reasoning_text.delta":
|
||||
//if !sendReasoningDelta(streamResp.Delta) {
|
||||
// return false
|
||||
// sr.Stop(streamErr)
|
||||
// return
|
||||
//}
|
||||
|
||||
//case "response.reasoning_text.done":
|
||||
|
||||
case "response.reasoning_summary_text.delta":
|
||||
if !sendReasoningSummaryDelta(streamResp.Delta) {
|
||||
return false
|
||||
sr.Stop(streamErr)
|
||||
return
|
||||
}
|
||||
|
||||
case "response.reasoning_summary_text.done":
|
||||
@@ -349,12 +353,14 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
|
||||
// delta := stringDeltaFromPrefix(prev, next)
|
||||
// reasoningSummaryTextByKey[key] = next
|
||||
// if !sendReasoningSummaryDelta(delta) {
|
||||
// return false
|
||||
// sr.Stop(streamErr)
|
||||
// return
|
||||
// }
|
||||
|
||||
case "response.output_text.delta":
|
||||
if !sendStartIfNeeded() {
|
||||
return false
|
||||
sr.Stop(streamErr)
|
||||
return
|
||||
}
|
||||
|
||||
if streamResp.Delta != "" {
|
||||
@@ -376,7 +382,8 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
|
||||
},
|
||||
}
|
||||
if !sendChatChunk(chunk) {
|
||||
return false
|
||||
sr.Stop(streamErr)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -414,7 +421,8 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
|
||||
}
|
||||
|
||||
if !sendToolCallDelta(callID, name, argsDelta) {
|
||||
return false
|
||||
sr.Stop(streamErr)
|
||||
return
|
||||
}
|
||||
|
||||
case "response.function_call_arguments.delta":
|
||||
@@ -428,7 +436,8 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
|
||||
}
|
||||
toolCallArgsByID[callID] += streamResp.Delta
|
||||
if !sendToolCallDelta(callID, "", streamResp.Delta) {
|
||||
return false
|
||||
sr.Stop(streamErr)
|
||||
return
|
||||
}
|
||||
|
||||
case "response.function_call_arguments.done":
|
||||
@@ -467,7 +476,8 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
|
||||
}
|
||||
|
||||
if !sendStartIfNeeded() {
|
||||
return false
|
||||
sr.Stop(streamErr)
|
||||
return
|
||||
}
|
||||
if !sentStop {
|
||||
if info.RelayFormat == types.RelayFormatClaude && info.ClaudeConvertInfo != nil {
|
||||
@@ -479,7 +489,8 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
|
||||
}
|
||||
stop := helper.GenerateStopResponse(responseId, createAt, model, finishReason)
|
||||
if !sendChatChunk(stop) {
|
||||
return false
|
||||
sr.Stop(streamErr)
|
||||
return
|
||||
}
|
||||
sentStop = true
|
||||
}
|
||||
@@ -488,16 +499,16 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
|
||||
if streamResp.Response != nil {
|
||||
if oaiErr := streamResp.Response.GetOpenAIError(); oaiErr != nil && oaiErr.Type != "" {
|
||||
streamErr = types.WithOpenAIError(*oaiErr, http.StatusInternalServerError)
|
||||
return false
|
||||
sr.Stop(streamErr)
|
||||
return
|
||||
}
|
||||
}
|
||||
streamErr = types.NewOpenAIError(fmt.Errorf("responses stream error: %s", streamResp.Type), types.ErrorCodeBadResponse, http.StatusInternalServerError)
|
||||
return false
|
||||
sr.Stop(streamErr)
|
||||
return
|
||||
|
||||
default:
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
if streamErr != nil {
|
||||
|
||||
@@ -3,14 +3,19 @@ package openai
|
||||
var ModelList = []string{
|
||||
"gpt-3.5-turbo", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-1106", "gpt-3.5-turbo-0125",
|
||||
"gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613",
|
||||
"gpt-3.5-turbo-instruct",
|
||||
"gpt-3.5-turbo-instruct", "gpt-3.5-turbo-instruct-0914",
|
||||
"gpt-4", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview",
|
||||
"gpt-4-32k", "gpt-4-32k-0613",
|
||||
"gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
|
||||
"gpt-4-vision-preview",
|
||||
"chatgpt-4o-latest",
|
||||
"gpt-4o", "gpt-4o-2024-05-13", "gpt-4o-2024-08-06", "gpt-4o-2024-11-20",
|
||||
"gpt-4o-transcribe", "gpt-4o-transcribe-diarize",
|
||||
"gpt-4o-search-preview", "gpt-4o-search-preview-2025-03-11",
|
||||
"gpt-4o-mini", "gpt-4o-mini-2024-07-18",
|
||||
"gpt-4o-mini-transcribe", "gpt-4o-mini-transcribe-2025-03-20", "gpt-4o-mini-transcribe-2025-12-15",
|
||||
"gpt-4o-mini-tts", "gpt-4o-mini-tts-2025-03-20", "gpt-4o-mini-tts-2025-12-15",
|
||||
"gpt-4o-mini-search-preview", "gpt-4o-mini-search-preview-2025-03-11",
|
||||
"gpt-4.5-preview", "gpt-4.5-preview-2025-02-27",
|
||||
"gpt-4.1", "gpt-4.1-2025-04-14",
|
||||
"gpt-4.1-mini", "gpt-4.1-mini-2025-04-14",
|
||||
@@ -31,17 +36,41 @@ var ModelList = []string{
|
||||
"gpt-5", "gpt-5-2025-08-07", "gpt-5-chat-latest",
|
||||
"gpt-5-mini", "gpt-5-mini-2025-08-07",
|
||||
"gpt-5-nano", "gpt-5-nano-2025-08-07",
|
||||
"gpt-4o-audio-preview", "gpt-4o-audio-preview-2024-10-01",
|
||||
"gpt-4o-realtime-preview", "gpt-4o-realtime-preview-2024-10-01", "gpt-4o-realtime-preview-2024-12-17",
|
||||
"gpt-5-codex",
|
||||
"gpt-5-pro", "gpt-5-pro-2025-10-06",
|
||||
"gpt-5-search-api", "gpt-5-search-api-2025-10-14",
|
||||
"gpt-5.1", "gpt-5.1-2025-11-13", "gpt-5.1-chat-latest",
|
||||
"gpt-5.1-codex", "gpt-5.1-codex-mini", "gpt-5.1-codex-max",
|
||||
"gpt-5.2", "gpt-5.2-2025-12-11", "gpt-5.2-chat-latest",
|
||||
"gpt-5.2-pro", "gpt-5.2-pro-2025-12-11",
|
||||
"gpt-5.2-codex",
|
||||
"gpt-5.3-chat-latest",
|
||||
"gpt-5.3-codex",
|
||||
"gpt-5.4", "gpt-5.4-2026-03-05",
|
||||
"gpt-5.4-pro", "gpt-5.4-pro-2026-03-05",
|
||||
"gpt-4o-audio-preview", "gpt-4o-audio-preview-2024-10-01", "gpt-4o-audio-preview-2024-12-17", "gpt-4o-audio-preview-2025-06-03",
|
||||
"gpt-4o-realtime-preview", "gpt-4o-realtime-preview-2024-10-01", "gpt-4o-realtime-preview-2024-12-17", "gpt-4o-realtime-preview-2025-06-03",
|
||||
"gpt-4o-mini-realtime-preview", "gpt-4o-mini-realtime-preview-2024-12-17",
|
||||
"gpt-4o-mini-audio-preview", "gpt-4o-mini-audio-preview-2024-12-17",
|
||||
"gpt-audio", "gpt-audio-2025-08-28",
|
||||
"gpt-audio-mini", "gpt-audio-mini-2025-10-06", "gpt-audio-mini-2025-12-15",
|
||||
"gpt-audio-1.5",
|
||||
"gpt-realtime", "gpt-realtime-2025-08-28",
|
||||
"gpt-realtime-mini", "gpt-realtime-mini-2025-10-06", "gpt-realtime-mini-2025-12-15",
|
||||
"gpt-realtime-1.5",
|
||||
"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
|
||||
"text-curie-001", "text-babbage-001", "text-ada-001",
|
||||
"text-moderation-latest", "text-moderation-stable",
|
||||
"omni-moderation-latest", "omni-moderation-2024-09-26",
|
||||
"text-davinci-edit-001",
|
||||
"davinci-002", "babbage-002",
|
||||
"dall-e-3", "gpt-image-1",
|
||||
"dall-e-2", "dall-e-3",
|
||||
"gpt-image-1", "gpt-image-1-mini", "gpt-image-1.5",
|
||||
"chatgpt-image-latest",
|
||||
"whisper-1",
|
||||
"tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106",
|
||||
"computer-use-preview", "computer-use-preview-2025-03-11",
|
||||
"sora-2", "sora-2-pro",
|
||||
}
|
||||
|
||||
var ChannelName = "openai"
|
||||
|
||||
@@ -126,11 +126,11 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
// 检查是否为音频模型
|
||||
isAudioModel := strings.Contains(strings.ToLower(model), "audio")
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) {
|
||||
if lastStreamData != "" {
|
||||
err := HandleStreamFormat(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
|
||||
if err != nil {
|
||||
if err := HandleStreamFormat(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent); err != nil {
|
||||
common.SysLog("error handling stream format: " + err.Error())
|
||||
sr.Error(err)
|
||||
}
|
||||
}
|
||||
if len(data) > 0 {
|
||||
@@ -142,7 +142,6 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
lastStreamData = data
|
||||
streamItems = append(streamItems, data)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
// 对音频模型,从倒数第二个stream data中提取usage信息
|
||||
@@ -627,6 +626,12 @@ func applyUsagePostProcessing(info *relaycommon.RelayInfo, usage *dto.Usage, res
|
||||
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
|
||||
}
|
||||
}
|
||||
case constant.ChannelTypeOpenAI:
|
||||
if usage.PromptTokensDetails.CachedTokens == 0 {
|
||||
if cachedTokens, ok := extractLlamaCachedTokensFromBody(responseBody); ok {
|
||||
usage.PromptTokensDetails.CachedTokens = cachedTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -689,3 +694,25 @@ func extractMoonshotCachedTokensFromBody(body []byte) (int, bool) {
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// extractLlamaCachedTokensFromBody 从llama.cpp的非标准位置提取cache_n
|
||||
func extractLlamaCachedTokensFromBody(body []byte) (int, bool) {
|
||||
if len(body) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
Timings struct {
|
||||
CachedTokens *int `json:"cache_n"`
|
||||
} `json:"timings"`
|
||||
}
|
||||
|
||||
if err := common.Unmarshal(body, &payload); err != nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
if payload.Timings.CachedTokens == nil {
|
||||
return 0, false
|
||||
}
|
||||
return *payload.Timings.CachedTokens, true
|
||||
}
|
||||
|
||||
@@ -79,55 +79,55 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
|
||||
var usage = &dto.Usage{}
|
||||
var responseTextBuilder strings.Builder
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) {
|
||||
|
||||
// 检查当前数据是否包含 completed 状态和 usage 信息
|
||||
var streamResponse dto.ResponsesStreamResponse
|
||||
if err := common.UnmarshalJsonStr(data, &streamResponse); err == nil {
|
||||
sendResponsesStreamData(c, streamResponse, data)
|
||||
switch streamResponse.Type {
|
||||
case "response.completed":
|
||||
if streamResponse.Response != nil {
|
||||
if streamResponse.Response.Usage != nil {
|
||||
if streamResponse.Response.Usage.InputTokens != 0 {
|
||||
usage.PromptTokens = streamResponse.Response.Usage.InputTokens
|
||||
}
|
||||
if streamResponse.Response.Usage.OutputTokens != 0 {
|
||||
usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
|
||||
}
|
||||
if streamResponse.Response.Usage.TotalTokens != 0 {
|
||||
usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
|
||||
}
|
||||
if streamResponse.Response.Usage.InputTokensDetails != nil {
|
||||
usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
if err := common.UnmarshalJsonStr(data, &streamResponse); err != nil {
|
||||
logger.LogError(c, "failed to unmarshal stream response: "+err.Error())
|
||||
sr.Error(err)
|
||||
return
|
||||
}
|
||||
sendResponsesStreamData(c, streamResponse, data)
|
||||
switch streamResponse.Type {
|
||||
case "response.completed":
|
||||
if streamResponse.Response != nil {
|
||||
if streamResponse.Response.Usage != nil {
|
||||
if streamResponse.Response.Usage.InputTokens != 0 {
|
||||
usage.PromptTokens = streamResponse.Response.Usage.InputTokens
|
||||
}
|
||||
if streamResponse.Response.HasImageGenerationCall() {
|
||||
c.Set("image_generation_call", true)
|
||||
c.Set("image_generation_call_quality", streamResponse.Response.GetQuality())
|
||||
c.Set("image_generation_call_size", streamResponse.Response.GetSize())
|
||||
if streamResponse.Response.Usage.OutputTokens != 0 {
|
||||
usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
|
||||
}
|
||||
if streamResponse.Response.Usage.TotalTokens != 0 {
|
||||
usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
|
||||
}
|
||||
if streamResponse.Response.Usage.InputTokensDetails != nil {
|
||||
usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
}
|
||||
case "response.output_text.delta":
|
||||
// 处理输出文本
|
||||
responseTextBuilder.WriteString(streamResponse.Delta)
|
||||
case dto.ResponsesOutputTypeItemDone:
|
||||
// 函数调用处理
|
||||
if streamResponse.Item != nil {
|
||||
switch streamResponse.Item.Type {
|
||||
case dto.BuildInCallWebSearchCall:
|
||||
if info != nil && info.ResponsesUsageInfo != nil && info.ResponsesUsageInfo.BuiltInTools != nil {
|
||||
if webSearchTool, exists := info.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool != nil {
|
||||
webSearchTool.CallCount++
|
||||
}
|
||||
if streamResponse.Response.HasImageGenerationCall() {
|
||||
c.Set("image_generation_call", true)
|
||||
c.Set("image_generation_call_quality", streamResponse.Response.GetQuality())
|
||||
c.Set("image_generation_call_size", streamResponse.Response.GetSize())
|
||||
}
|
||||
}
|
||||
case "response.output_text.delta":
|
||||
// 处理输出文本
|
||||
responseTextBuilder.WriteString(streamResponse.Delta)
|
||||
case dto.ResponsesOutputTypeItemDone:
|
||||
// 函数调用处理
|
||||
if streamResponse.Item != nil {
|
||||
switch streamResponse.Item.Type {
|
||||
case dto.BuildInCallWebSearchCall:
|
||||
if info != nil && info.ResponsesUsageInfo != nil && info.ResponsesUsageInfo.BuiltInTools != nil {
|
||||
if webSearchTool, exists := info.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool != nil {
|
||||
webSearchTool.CallCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
logger.LogError(c, "failed to unmarshal stream response: "+err.Error())
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if usage.CompletionTokens == 0 {
|
||||
|
||||
@@ -4,7 +4,9 @@ import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -80,15 +82,28 @@ type responsePayload struct {
|
||||
TaskId string `json:"task_id"`
|
||||
TaskStatus string `json:"task_status"`
|
||||
TaskStatusMsg string `json:"task_status_msg"`
|
||||
TaskResult struct {
|
||||
TaskInfo struct {
|
||||
ExternalTaskId string `json:"external_task_id"`
|
||||
} `json:"task_info"`
|
||||
WatermarkInfo struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
} `json:"watermark_info"`
|
||||
TaskResult struct {
|
||||
Videos []struct {
|
||||
Id string `json:"id"`
|
||||
Url string `json:"url"`
|
||||
Duration string `json:"duration"`
|
||||
Id string `json:"id"`
|
||||
Url string `json:"url"`
|
||||
WatermarkUrl string `json:"watermark_url"`
|
||||
Duration string `json:"duration"`
|
||||
} `json:"videos"`
|
||||
Images []struct {
|
||||
Index int `json:"index"`
|
||||
Url string `json:"url"`
|
||||
WatermarkUrl string `json:"watermark_url"`
|
||||
} `json:"images"`
|
||||
} `json:"task_result"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
FinalUnitDeduction string `json:"final_unit_deduction"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
@@ -338,15 +353,22 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
|
||||
taskInfo.Status = model.TaskStatusInProgress
|
||||
case "succeed":
|
||||
taskInfo.Status = model.TaskStatusSuccess
|
||||
if videos := resPayload.Data.TaskResult.Videos; len(videos) > 0 {
|
||||
video := videos[0]
|
||||
taskInfo.Url = video.Url
|
||||
}
|
||||
if tokens, err := strconv.ParseFloat(resPayload.Data.FinalUnitDeduction, 64); err == nil {
|
||||
rounded := int(math.Ceil(tokens))
|
||||
if rounded > 0 {
|
||||
taskInfo.CompletionTokens = rounded
|
||||
taskInfo.TotalTokens = rounded
|
||||
}
|
||||
}
|
||||
case "failed":
|
||||
taskInfo.Status = model.TaskStatusFailure
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown task status: %s", status)
|
||||
}
|
||||
if videos := resPayload.Data.TaskResult.Videos; len(videos) > 0 {
|
||||
video := videos[0]
|
||||
taskInfo.Url = video.Url
|
||||
}
|
||||
return taskInfo, nil
|
||||
}
|
||||
|
||||
@@ -383,5 +405,12 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro
|
||||
Code: fmt.Sprintf("%d", klingResp.Code),
|
||||
}
|
||||
}
|
||||
|
||||
// https://app.klingai.com/cn/dev/document-api/apiReference/model/textToVideo
|
||||
if data := klingResp.Data; data.TaskStatus == "failed" {
|
||||
openAIVideo.Error = &dto.OpenAIVideoError{
|
||||
Message: data.TaskStatusMsg,
|
||||
}
|
||||
}
|
||||
return common.Marshal(openAIVideo)
|
||||
}
|
||||
|
||||
@@ -17,6 +17,8 @@ func UnmarshalMetadata(metadata map[string]any, target any) error {
|
||||
if metadata == nil {
|
||||
return nil
|
||||
}
|
||||
// Prevent metadata from overriding model fields to avoid billing bypass.
|
||||
delete(metadata, "model")
|
||||
metaBytes, err := common.Marshal(metadata)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal metadata failed: %w", err)
|
||||
|
||||
@@ -43,12 +43,12 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
|
||||
helper.SetEventStreamHeaders(c)
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) {
|
||||
var xAIResp *dto.ChatCompletionsStreamResponse
|
||||
err := common.UnmarshalJsonStr(data, &xAIResp)
|
||||
if err != nil {
|
||||
if err := common.UnmarshalJsonStr(data, &xAIResp); err != nil {
|
||||
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
sr.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
// 把 xAI 的usage转换为 OpenAI 的usage
|
||||
@@ -61,11 +61,10 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
|
||||
openaiResponse := streamResponseXAI2OpenAI(xAIResp, usage)
|
||||
_ = openai.ProcessStreamResponse(*openaiResponse, &responseTextBuilder, &toolCount)
|
||||
err = helper.ObjectData(c, openaiResponse)
|
||||
if err != nil {
|
||||
if err := helper.ObjectData(c, openaiResponse); err != nil {
|
||||
common.SysLog(err.Error())
|
||||
sr.Error(err)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if !containStreamUsage {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package zhipu_4v
|
||||
|
||||
var ModelList = []string{
|
||||
"glm-4", "glm-4v", "glm-3-turbo", "glm-4-alltools", "glm-4-plus", "glm-4-0520", "glm-4-air", "glm-4-airx", "glm-4-long", "glm-4-flash", "glm-4v-plus", "glm-4.6",
|
||||
"glm-4", "glm-4v", "glm-3-turbo", "glm-4-alltools", "glm-4-plus", "glm-4-0520", "glm-4-air", "glm-4-airx", "glm-4-long", "glm-4-flash", "glm-4v-plus", "glm-4.6", "glm-4.6v", "glm-4.7", "glm-4.7-flash", "glm-5",
|
||||
}
|
||||
|
||||
var ChannelName = "zhipu_4v"
|
||||
|
||||
@@ -59,7 +59,6 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
Type: "adaptive",
|
||||
}
|
||||
request.OutputConfig = json.RawMessage(fmt.Sprintf(`{"effort":"%s"}`, effortLevel))
|
||||
request.TopP = common.GetPointer[float64](0)
|
||||
request.Temperature = common.GetPointer[float64](1.0)
|
||||
info.UpstreamModelName = request.Model
|
||||
} else if model_setting.GetClaudeSettings().ThinkingAdapterEnabled &&
|
||||
@@ -77,7 +76,6 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
}
|
||||
// TODO: 临时处理
|
||||
// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
|
||||
request.TopP = common.GetPointer[float64](0)
|
||||
request.Temperature = common.GetPointer[float64](1.0)
|
||||
}
|
||||
if !model_setting.ShouldPreserveThinkingSuffix(info.OriginModelName) {
|
||||
@@ -124,7 +122,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
return newApiErr
|
||||
}
|
||||
|
||||
service.PostClaudeConsumeQuota(c, info, usage)
|
||||
service.PostTextConsumeQuota(c, info, usage, nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -192,6 +190,6 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
return newAPIError
|
||||
}
|
||||
|
||||
service.PostClaudeConsumeQuota(c, info, usage.(*dto.Usage))
|
||||
service.PostTextConsumeQuota(c, info, usage.(*dto.Usage), nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
+266
-11
@@ -21,10 +21,23 @@ var negativeIndexRegexp = regexp.MustCompile(`\.(-\d+)`)
|
||||
// Keys used inside the param-override condition context map.
const (
	paramOverrideContextRequestHeaders = "request_headers"
	paramOverrideContextHeaderOverride = "header_override"
	// paramOverrideContextAuditRecorder is the context key under which the
	// *paramOverrideAuditRecorder for the current override run is stored.
	paramOverrideContextAuditRecorder = "__param_override_audit_recorder"
)

var (
	// errSourceHeaderNotFound is the sentinel error reported when a source
	// header does not exist.
	errSourceHeaderNotFound = errors.New("source header does not exist")

	// paramOverrideKeyAuditPaths is the set of JSON paths whose overrides are
	// audited even when debug mode is off.
	paramOverrideKeyAuditPaths = map[string]struct{}{
		"model":          {},
		"original_model": {},
		"upstream_model": {},
		"service_tier":   {},
		"inference_geo":  {},
	}
)

// paramOverrideAuditRecorder accumulates the human-readable audit lines
// produced while param-override operations are applied.
type paramOverrideAuditRecorder struct {
	lines []string
}
|
||||
|
||||
type ConditionOperation struct {
|
||||
Path string `json:"path"` // JSON路径
|
||||
Mode string `json:"mode"` // full, prefix, suffix, contains, gt, gte, lt, lte
|
||||
@@ -118,6 +131,7 @@ func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, c
|
||||
if len(paramOverride) == 0 {
|
||||
return jsonData, nil
|
||||
}
|
||||
auditRecorder := getParamOverrideAuditRecorder(conditionContext)
|
||||
|
||||
// 尝试断言为操作格式
|
||||
if operations, ok := tryParseOperations(paramOverride); ok {
|
||||
@@ -125,7 +139,7 @@ func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, c
|
||||
workingJSON := jsonData
|
||||
var err error
|
||||
if len(legacyOverride) > 0 {
|
||||
workingJSON, err = applyOperationsLegacy(workingJSON, legacyOverride)
|
||||
workingJSON, err = applyOperationsLegacy(workingJSON, legacyOverride, auditRecorder)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -137,7 +151,7 @@ func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, c
|
||||
}
|
||||
|
||||
// 直接使用旧方法
|
||||
return applyOperationsLegacy(jsonData, paramOverride)
|
||||
return applyOperationsLegacy(jsonData, paramOverride, auditRecorder)
|
||||
}
|
||||
|
||||
func buildLegacyParamOverride(paramOverride map[string]interface{}) map[string]interface{} {
|
||||
@@ -161,14 +175,200 @@ func ApplyParamOverrideWithRelayInfo(jsonData []byte, info *RelayInfo) ([]byte,
|
||||
}
|
||||
|
||||
overrideCtx := BuildParamOverrideContext(info)
|
||||
var recorder *paramOverrideAuditRecorder
|
||||
if shouldEnableParamOverrideAudit(paramOverride) {
|
||||
recorder = &paramOverrideAuditRecorder{}
|
||||
overrideCtx[paramOverrideContextAuditRecorder] = recorder
|
||||
}
|
||||
result, err := ApplyParamOverride(jsonData, paramOverride, overrideCtx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
syncRuntimeHeaderOverrideFromContext(info, overrideCtx)
|
||||
if info != nil {
|
||||
if recorder != nil {
|
||||
info.ParamOverrideAudit = recorder.lines
|
||||
} else {
|
||||
info.ParamOverrideAudit = nil
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func shouldEnableParamOverrideAudit(paramOverride map[string]interface{}) bool {
|
||||
if common.DebugEnabled {
|
||||
return true
|
||||
}
|
||||
if len(paramOverride) == 0 {
|
||||
return false
|
||||
}
|
||||
if operations, ok := tryParseOperations(paramOverride); ok {
|
||||
for _, operation := range operations {
|
||||
if shouldAuditParamPath(strings.TrimSpace(operation.Path)) ||
|
||||
shouldAuditParamPath(strings.TrimSpace(operation.To)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
for key := range buildLegacyParamOverride(paramOverride) {
|
||||
if shouldAuditParamPath(strings.TrimSpace(key)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
for key := range paramOverride {
|
||||
if shouldAuditParamPath(strings.TrimSpace(key)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func getParamOverrideAuditRecorder(context map[string]interface{}) *paramOverrideAuditRecorder {
|
||||
if context == nil {
|
||||
return nil
|
||||
}
|
||||
recorder, _ := context[paramOverrideContextAuditRecorder].(*paramOverrideAuditRecorder)
|
||||
return recorder
|
||||
}
|
||||
|
||||
func (r *paramOverrideAuditRecorder) recordOperation(mode, path, from, to string, value interface{}) {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
line := buildParamOverrideAuditLine(mode, path, from, to, value)
|
||||
if line == "" {
|
||||
return
|
||||
}
|
||||
if lo.Contains(r.lines, line) {
|
||||
return
|
||||
}
|
||||
r.lines = append(r.lines, line)
|
||||
}
|
||||
|
||||
func shouldAuditParamPath(path string) bool {
|
||||
path = strings.TrimSpace(path)
|
||||
if path == "" {
|
||||
return false
|
||||
}
|
||||
if common.DebugEnabled {
|
||||
return true
|
||||
}
|
||||
_, ok := paramOverrideKeyAuditPaths[path]
|
||||
return ok
|
||||
}
|
||||
|
||||
func shouldAuditOperation(mode, path, from, to string) bool {
|
||||
if common.DebugEnabled {
|
||||
return true
|
||||
}
|
||||
for _, candidate := range []string{path, to} {
|
||||
if shouldAuditParamPath(candidate) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func formatParamOverrideAuditValue(value interface{}) string {
|
||||
switch typed := value.(type) {
|
||||
case nil:
|
||||
return "<empty>"
|
||||
case string:
|
||||
return typed
|
||||
default:
|
||||
return common.GetJsonString(typed)
|
||||
}
|
||||
}
|
||||
|
||||
func buildParamOverrideAuditLine(mode, path, from, to string, value interface{}) string {
|
||||
mode = strings.TrimSpace(mode)
|
||||
path = strings.TrimSpace(path)
|
||||
from = strings.TrimSpace(from)
|
||||
to = strings.TrimSpace(to)
|
||||
|
||||
if !shouldAuditOperation(mode, path, from, to) {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch mode {
|
||||
case "set":
|
||||
if path == "" {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("set %s = %s", path, formatParamOverrideAuditValue(value))
|
||||
case "delete":
|
||||
if path == "" {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("delete %s", path)
|
||||
case "copy":
|
||||
if from == "" || to == "" {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("copy %s -> %s", from, to)
|
||||
case "move":
|
||||
if from == "" || to == "" {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("move %s -> %s", from, to)
|
||||
case "prepend":
|
||||
if path == "" {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("prepend %s with %s", path, formatParamOverrideAuditValue(value))
|
||||
case "append":
|
||||
if path == "" {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("append %s with %s", path, formatParamOverrideAuditValue(value))
|
||||
case "trim_prefix", "trim_suffix", "ensure_prefix", "ensure_suffix":
|
||||
if path == "" {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%s %s with %s", mode, path, formatParamOverrideAuditValue(value))
|
||||
case "trim_space", "to_lower", "to_upper":
|
||||
if path == "" {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%s %s", mode, path)
|
||||
case "replace", "regex_replace":
|
||||
if path == "" {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%s %s from %s to %s", mode, path, from, to)
|
||||
case "set_header":
|
||||
if path == "" {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("set_header %s = %s", path, formatParamOverrideAuditValue(value))
|
||||
case "delete_header":
|
||||
if path == "" {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("delete_header %s", path)
|
||||
case "copy_header", "move_header":
|
||||
if from == "" || to == "" {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%s %s -> %s", mode, from, to)
|
||||
case "pass_headers":
|
||||
return fmt.Sprintf("pass_headers %s", formatParamOverrideAuditValue(value))
|
||||
case "sync_fields":
|
||||
if from == "" || to == "" {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("sync_fields %s -> %s", from, to)
|
||||
case "return_error":
|
||||
return fmt.Sprintf("return_error %s", formatParamOverrideAuditValue(value))
|
||||
default:
|
||||
if path == "" {
|
||||
return mode
|
||||
}
|
||||
return fmt.Sprintf("%s %s", mode, path)
|
||||
}
|
||||
}
|
||||
|
||||
func getParamOverrideMap(info *RelayInfo) map[string]interface{} {
|
||||
if info == nil || info.ChannelMeta == nil {
|
||||
return nil
|
||||
@@ -455,7 +655,7 @@ func compareNumeric(jsonValue, targetValue gjson.Result, operator string) (bool,
|
||||
}
|
||||
|
||||
// applyOperationsLegacy 原参数覆盖方法
|
||||
func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}) ([]byte, error) {
|
||||
func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}, auditRecorder *paramOverrideAuditRecorder) ([]byte, error) {
|
||||
reqMap := make(map[string]interface{})
|
||||
err := common.Unmarshal(jsonData, &reqMap)
|
||||
if err != nil {
|
||||
@@ -464,6 +664,7 @@ func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}
|
||||
|
||||
for key, value := range paramOverride {
|
||||
reqMap[key] = value
|
||||
auditRecorder.recordOperation("set", key, "", "", value)
|
||||
}
|
||||
|
||||
return common.Marshal(reqMap)
|
||||
@@ -471,6 +672,7 @@ func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}
|
||||
|
||||
func applyOperations(jsonStr string, operations []ParamOperation, conditionContext map[string]interface{}) (string, error) {
|
||||
context := ensureContextMap(conditionContext)
|
||||
auditRecorder := getParamOverrideAuditRecorder(context)
|
||||
contextJSON, err := marshalContextJSON(context)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal condition context: %v", err)
|
||||
@@ -506,6 +708,7 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
auditRecorder.recordOperation("delete", path, "", "", nil)
|
||||
}
|
||||
case "set":
|
||||
for _, path := range opPaths {
|
||||
@@ -516,11 +719,15 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
auditRecorder.recordOperation("set", path, "", "", op.Value)
|
||||
}
|
||||
case "move":
|
||||
opFrom := processNegativeIndex(result, op.From)
|
||||
opTo := processNegativeIndex(result, op.To)
|
||||
result, err = moveValue(result, opFrom, opTo)
|
||||
if err == nil {
|
||||
auditRecorder.recordOperation("move", "", opFrom, opTo, nil)
|
||||
}
|
||||
case "copy":
|
||||
if op.From == "" || op.To == "" {
|
||||
return "", fmt.Errorf("copy from/to is required")
|
||||
@@ -528,12 +735,16 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
|
||||
opFrom := processNegativeIndex(result, op.From)
|
||||
opTo := processNegativeIndex(result, op.To)
|
||||
result, err = copyValue(result, opFrom, opTo)
|
||||
if err == nil {
|
||||
auditRecorder.recordOperation("copy", "", opFrom, opTo, nil)
|
||||
}
|
||||
case "prepend":
|
||||
for _, path := range opPaths {
|
||||
result, err = modifyValue(result, path, op.Value, op.KeepOrigin, true)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
auditRecorder.recordOperation("prepend", path, "", "", op.Value)
|
||||
}
|
||||
case "append":
|
||||
for _, path := range opPaths {
|
||||
@@ -541,6 +752,7 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
auditRecorder.recordOperation("append", path, "", "", op.Value)
|
||||
}
|
||||
case "trim_prefix":
|
||||
for _, path := range opPaths {
|
||||
@@ -548,6 +760,7 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
auditRecorder.recordOperation("trim_prefix", path, "", "", op.Value)
|
||||
}
|
||||
case "trim_suffix":
|
||||
for _, path := range opPaths {
|
||||
@@ -555,6 +768,7 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
auditRecorder.recordOperation("trim_suffix", path, "", "", op.Value)
|
||||
}
|
||||
case "ensure_prefix":
|
||||
for _, path := range opPaths {
|
||||
@@ -562,6 +776,7 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
auditRecorder.recordOperation("ensure_prefix", path, "", "", op.Value)
|
||||
}
|
||||
case "ensure_suffix":
|
||||
for _, path := range opPaths {
|
||||
@@ -569,6 +784,7 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
auditRecorder.recordOperation("ensure_suffix", path, "", "", op.Value)
|
||||
}
|
||||
case "trim_space":
|
||||
for _, path := range opPaths {
|
||||
@@ -576,6 +792,7 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
auditRecorder.recordOperation("trim_space", path, "", "", nil)
|
||||
}
|
||||
case "to_lower":
|
||||
for _, path := range opPaths {
|
||||
@@ -583,6 +800,7 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
auditRecorder.recordOperation("to_lower", path, "", "", nil)
|
||||
}
|
||||
case "to_upper":
|
||||
for _, path := range opPaths {
|
||||
@@ -590,6 +808,7 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
auditRecorder.recordOperation("to_upper", path, "", "", nil)
|
||||
}
|
||||
case "replace":
|
||||
for _, path := range opPaths {
|
||||
@@ -597,6 +816,7 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
auditRecorder.recordOperation("replace", path, op.From, op.To, nil)
|
||||
}
|
||||
case "regex_replace":
|
||||
for _, path := range opPaths {
|
||||
@@ -604,8 +824,10 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
auditRecorder.recordOperation("regex_replace", path, op.From, op.To, nil)
|
||||
}
|
||||
case "return_error":
|
||||
auditRecorder.recordOperation("return_error", op.Path, "", "", op.Value)
|
||||
returnErr, parseErr := parseParamOverrideReturnError(op.Value)
|
||||
if parseErr != nil {
|
||||
return "", parseErr
|
||||
@@ -621,11 +843,13 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
|
||||
case "set_header":
|
||||
err = setHeaderOverrideInContext(context, op.Path, op.Value, op.KeepOrigin)
|
||||
if err == nil {
|
||||
auditRecorder.recordOperation("set_header", op.Path, "", "", op.Value)
|
||||
contextJSON, err = marshalContextJSON(context)
|
||||
}
|
||||
case "delete_header":
|
||||
err = deleteHeaderOverrideInContext(context, op.Path)
|
||||
if err == nil {
|
||||
auditRecorder.recordOperation("delete_header", op.Path, "", "", nil)
|
||||
contextJSON, err = marshalContextJSON(context)
|
||||
}
|
||||
case "copy_header":
|
||||
@@ -642,6 +866,7 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
|
||||
err = nil
|
||||
}
|
||||
if err == nil {
|
||||
auditRecorder.recordOperation("copy_header", "", sourceHeader, targetHeader, nil)
|
||||
contextJSON, err = marshalContextJSON(context)
|
||||
}
|
||||
case "move_header":
|
||||
@@ -658,6 +883,7 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
|
||||
err = nil
|
||||
}
|
||||
if err == nil {
|
||||
auditRecorder.recordOperation("move_header", "", sourceHeader, targetHeader, nil)
|
||||
contextJSON, err = marshalContextJSON(context)
|
||||
}
|
||||
case "pass_headers":
|
||||
@@ -675,11 +901,13 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
|
||||
}
|
||||
}
|
||||
if err == nil {
|
||||
auditRecorder.recordOperation("pass_headers", "", "", "", headerNames)
|
||||
contextJSON, err = marshalContextJSON(context)
|
||||
}
|
||||
case "sync_fields":
|
||||
result, err = syncFieldsBetweenTargets(result, context, op.From, op.To)
|
||||
if err == nil {
|
||||
auditRecorder.recordOperation("sync_fields", "", op.From, op.To, nil)
|
||||
contextJSON, err = marshalContextJSON(context)
|
||||
}
|
||||
default:
|
||||
@@ -847,24 +1075,30 @@ func resolveHeaderOverrideValueByMapping(context map[string]interface{}, headerN
|
||||
return "", false, fmt.Errorf("header value mapping cannot be empty")
|
||||
}
|
||||
|
||||
sourceValue, exists := getHeaderValueFromContext(context, headerName)
|
||||
if !exists {
|
||||
return "", false, nil
|
||||
appendTokens, err := parseHeaderAppendTokens(mapping)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
sourceTokens := splitHeaderListValue(sourceValue)
|
||||
if len(sourceTokens) == 0 {
|
||||
return "", false, nil
|
||||
keepOnlyDeclared := parseHeaderKeepOnlyDeclared(mapping)
|
||||
|
||||
sourceValue, exists := getHeaderValueFromContext(context, headerName)
|
||||
sourceTokens := make([]string, 0)
|
||||
if exists {
|
||||
sourceTokens = splitHeaderListValue(sourceValue)
|
||||
}
|
||||
|
||||
wildcardValue, hasWildcard := mapping["*"]
|
||||
resultTokens := make([]string, 0, len(sourceTokens))
|
||||
resultTokens := make([]string, 0, len(sourceTokens)+len(appendTokens))
|
||||
for _, token := range sourceTokens {
|
||||
replacementRaw, hasReplacement := mapping[token]
|
||||
if !hasReplacement && hasWildcard {
|
||||
if !hasReplacement && hasWildcard && !keepOnlyDeclared {
|
||||
replacementRaw = wildcardValue
|
||||
hasReplacement = true
|
||||
}
|
||||
if !hasReplacement {
|
||||
if keepOnlyDeclared {
|
||||
continue
|
||||
}
|
||||
resultTokens = append(resultTokens, token)
|
||||
continue
|
||||
}
|
||||
@@ -875,6 +1109,7 @@ func resolveHeaderOverrideValueByMapping(context map[string]interface{}, headerN
|
||||
resultTokens = append(resultTokens, replacementTokens...)
|
||||
}
|
||||
|
||||
resultTokens = append(resultTokens, appendTokens...)
|
||||
resultTokens = lo.Uniq(resultTokens)
|
||||
if len(resultTokens) == 0 {
|
||||
return "", false, nil
|
||||
@@ -882,6 +1117,26 @@ func resolveHeaderOverrideValueByMapping(context map[string]interface{}, headerN
|
||||
return strings.Join(resultTokens, ","), true, nil
|
||||
}
|
||||
|
||||
func parseHeaderAppendTokens(mapping map[string]interface{}) ([]string, error) {
|
||||
appendRaw, ok := mapping["$append"]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
return parseHeaderReplacementTokens(appendRaw)
|
||||
}
|
||||
|
||||
// parseHeaderKeepOnlyDeclared reads the "$keep_only_declared" flag from a
// header mapping. It returns the flag's value when it is a bool, and false
// when the key is missing or holds any other type.
func parseHeaderKeepOnlyDeclared(mapping map[string]interface{}) bool {
	// A missing key yields a nil interface, for which the assertion fails and
	// leaves the zero value (false).
	enabled, _ := mapping["$keep_only_declared"].(bool)
	return enabled
}
|
||||
|
||||
func parseHeaderReplacementTokens(value interface{}) ([]string, error) {
|
||||
switch raw := value.(type) {
|
||||
case nil:
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
common2 "github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
@@ -1653,6 +1654,141 @@ func TestApplyParamOverrideSetHeaderMapDeleteWholeHeaderWhenAllTokensCleared(t *
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideSetHeaderMapAppendsTokens(t *testing.T) {
|
||||
input := []byte(`{"temperature":0.7}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"mode": "set_header",
|
||||
"path": "anthropic-beta",
|
||||
"value": map[string]interface{}{
|
||||
"$append": []interface{}{"context-1m-2025-08-07", "computer-use-2025-01-24"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := map[string]interface{}{
|
||||
"header_override": map[string]interface{}{
|
||||
"anthropic-beta": "computer-use-2025-01-24",
|
||||
},
|
||||
}
|
||||
|
||||
out, err := ApplyParamOverride(input, override, ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyParamOverride returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, `{"temperature":0.7}`, string(out))
|
||||
|
||||
headers, ok := ctx["header_override"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected header_override context map")
|
||||
}
|
||||
if headers["anthropic-beta"] != "computer-use-2025-01-24,context-1m-2025-08-07" {
|
||||
t.Fatalf("expected anthropic-beta to append new token without duplicates, got: %v", headers["anthropic-beta"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideSetHeaderMapAppendsTokensWhenHeaderMissing(t *testing.T) {
|
||||
input := []byte(`{"temperature":0.7}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"mode": "set_header",
|
||||
"path": "anthropic-beta",
|
||||
"value": map[string]interface{}{
|
||||
"$append": []interface{}{"context-1m-2025-08-07", "computer-use-2025-01-24"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctx := map[string]interface{}{}
|
||||
out, err := ApplyParamOverride(input, override, ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyParamOverride returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, `{"temperature":0.7}`, string(out))
|
||||
|
||||
headers, ok := ctx["header_override"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected header_override context map")
|
||||
}
|
||||
if headers["anthropic-beta"] != "context-1m-2025-08-07,computer-use-2025-01-24" {
|
||||
t.Fatalf("expected anthropic-beta to be created from appended tokens, got: %v", headers["anthropic-beta"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideSetHeaderMapKeepOnlyDeclaredDropsUndeclaredTokens(t *testing.T) {
|
||||
input := []byte(`{"temperature":0.7}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"mode": "set_header",
|
||||
"path": "anthropic-beta",
|
||||
"value": map[string]interface{}{
|
||||
"computer-use-2025-01-24": "computer-use-2025-01-24",
|
||||
"$append": []interface{}{"context-1m-2025-08-07"},
|
||||
"$keep_only_declared": true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := map[string]interface{}{
|
||||
"header_override": map[string]interface{}{
|
||||
"anthropic-beta": "advanced-tool-use-2025-11-20,computer-use-2025-01-24",
|
||||
},
|
||||
}
|
||||
|
||||
out, err := ApplyParamOverride(input, override, ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyParamOverride returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, `{"temperature":0.7}`, string(out))
|
||||
|
||||
headers, ok := ctx["header_override"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected header_override context map")
|
||||
}
|
||||
if headers["anthropic-beta"] != "computer-use-2025-01-24,context-1m-2025-08-07" {
|
||||
t.Fatalf("expected anthropic-beta to keep only declared tokens, got: %v", headers["anthropic-beta"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideSetHeaderMapKeepOnlyDeclaredDeletesHeaderWhenNothingDeclaredMatches(t *testing.T) {
|
||||
input := []byte(`{"temperature":0.7}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"mode": "set_header",
|
||||
"path": "anthropic-beta",
|
||||
"value": map[string]interface{}{
|
||||
"computer-use-2025-01-24": "computer-use-2025-01-24",
|
||||
"$keep_only_declared": true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := map[string]interface{}{
|
||||
"header_override": map[string]interface{}{
|
||||
"anthropic-beta": "advanced-tool-use-2025-11-20",
|
||||
},
|
||||
}
|
||||
|
||||
out, err := ApplyParamOverride(input, override, ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyParamOverride returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, `{"temperature":0.7}`, string(out))
|
||||
|
||||
headers, ok := ctx["header_override"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected header_override context map")
|
||||
}
|
||||
if _, exists := headers["anthropic-beta"]; exists {
|
||||
t.Fatalf("expected anthropic-beta to be deleted when no declared tokens remain, got: %v", headers["anthropic-beta"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideConditionsObjectShorthand(t *testing.T) {
|
||||
input := []byte(`{"temperature":0.7}`)
|
||||
override := map[string]interface{}{
|
||||
@@ -1931,6 +2067,105 @@ func TestRemoveDisabledFieldsAllowInferenceGeo(t *testing.T) {
|
||||
assertJSONEqual(t, `{"inference_geo":"eu","store":true}`, string(out))
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideWithRelayInfoRecordsOperationAuditInDebugMode(t *testing.T) {
|
||||
originalDebugEnabled := common2.DebugEnabled
|
||||
common2.DebugEnabled = true
|
||||
t.Cleanup(func() {
|
||||
common2.DebugEnabled = originalDebugEnabled
|
||||
})
|
||||
|
||||
info := &RelayInfo{
|
||||
ChannelMeta: &ChannelMeta{
|
||||
ParamOverride: map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"mode": "copy",
|
||||
"from": "metadata.target_model",
|
||||
"to": "model",
|
||||
},
|
||||
map[string]interface{}{
|
||||
"mode": "set",
|
||||
"path": "service_tier",
|
||||
"value": "flex",
|
||||
},
|
||||
map[string]interface{}{
|
||||
"mode": "set",
|
||||
"path": "temperature",
|
||||
"value": 0.1,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
out, err := ApplyParamOverrideWithRelayInfo([]byte(`{
|
||||
"model":"gpt-4.1",
|
||||
"temperature":0.7,
|
||||
"metadata":{"target_model":"gpt-4.1-mini"}
|
||||
}`), info)
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyParamOverrideWithRelayInfo returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, `{
|
||||
"model":"gpt-4.1-mini",
|
||||
"temperature":0.1,
|
||||
"service_tier":"flex",
|
||||
"metadata":{"target_model":"gpt-4.1-mini"}
|
||||
}`, string(out))
|
||||
|
||||
expected := []string{
|
||||
"copy metadata.target_model -> model",
|
||||
"set service_tier = flex",
|
||||
"set temperature = 0.1",
|
||||
}
|
||||
if !reflect.DeepEqual(info.ParamOverrideAudit, expected) {
|
||||
t.Fatalf("unexpected param override audit, got %#v", info.ParamOverrideAudit)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideWithRelayInfoRecordsOnlyKeyOperationsWhenDebugDisabled(t *testing.T) {
|
||||
originalDebugEnabled := common2.DebugEnabled
|
||||
common2.DebugEnabled = false
|
||||
t.Cleanup(func() {
|
||||
common2.DebugEnabled = originalDebugEnabled
|
||||
})
|
||||
|
||||
info := &RelayInfo{
|
||||
ChannelMeta: &ChannelMeta{
|
||||
ParamOverride: map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"mode": "copy",
|
||||
"from": "metadata.target_model",
|
||||
"to": "model",
|
||||
},
|
||||
map[string]interface{}{
|
||||
"mode": "set",
|
||||
"path": "temperature",
|
||||
"value": 0.1,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := ApplyParamOverrideWithRelayInfo([]byte(`{
|
||||
"model":"gpt-4.1",
|
||||
"temperature":0.7,
|
||||
"metadata":{"target_model":"gpt-4.1-mini"}
|
||||
}`), info)
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyParamOverrideWithRelayInfo returned error: %v", err)
|
||||
}
|
||||
|
||||
expected := []string{
|
||||
"copy metadata.target_model -> model",
|
||||
}
|
||||
if !reflect.DeepEqual(info.ParamOverrideAudit, expected) {
|
||||
t.Fatalf("unexpected param override audit, got %#v", info.ParamOverrideAudit)
|
||||
}
|
||||
}
|
||||
|
||||
func assertJSONEqual(t *testing.T, want, got string) {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -149,6 +149,7 @@ type RelayInfo struct {
|
||||
LastError *types.NewAPIError
|
||||
RuntimeHeadersOverride map[string]interface{}
|
||||
UseRuntimeHeadersOverride bool
|
||||
ParamOverrideAudit []string
|
||||
|
||||
PriceData types.PriceData
|
||||
|
||||
@@ -161,6 +162,8 @@ type RelayInfo struct {
|
||||
// 若为空,调用 GetFinalRequestRelayFormat 会回退到 RequestConversionChain 的最后一项或 RelayFormat。
|
||||
FinalRequestRelayFormat types.RelayFormat
|
||||
|
||||
StreamStatus *StreamStatus
|
||||
|
||||
ThinkingContentInfo
|
||||
TokenCountMeta
|
||||
*ClaudeConvertInfo
|
||||
@@ -337,15 +340,10 @@ func GenRelayInfoClaude(c *gin.Context, request dto.Request) *RelayInfo {
|
||||
info.ClaudeConvertInfo = &ClaudeConvertInfo{
|
||||
LastMessagesType: LastMessageTypeNone,
|
||||
}
|
||||
info.IsClaudeBetaQuery = c.Query("beta") == "true" || isClaudeBetaForced(c)
|
||||
info.IsClaudeBetaQuery = c.Query("beta") == "true"
|
||||
return info
|
||||
}
|
||||
|
||||
func isClaudeBetaForced(c *gin.Context) bool {
|
||||
channelOtherSettings, ok := common.GetContextKeyType[dto.ChannelOtherSettings](c, constant.ContextKeyChannelOtherSetting)
|
||||
return ok && channelOtherSettings.ClaudeBetaQuery
|
||||
}
|
||||
|
||||
func GenRelayInfoRerank(c *gin.Context, request *dto.RerankRequest) *RelayInfo {
|
||||
info := genBaseRelayInfo(c, request)
|
||||
info.RelayMode = relayconstant.RelayModeRerank
|
||||
|
||||
@@ -0,0 +1,112 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type StreamEndReason string
|
||||
|
||||
const (
|
||||
StreamEndReasonNone StreamEndReason = ""
|
||||
StreamEndReasonDone StreamEndReason = "done"
|
||||
StreamEndReasonTimeout StreamEndReason = "timeout"
|
||||
StreamEndReasonClientGone StreamEndReason = "client_gone"
|
||||
StreamEndReasonScannerErr StreamEndReason = "scanner_error"
|
||||
StreamEndReasonHandlerStop StreamEndReason = "handler_stop"
|
||||
StreamEndReasonEOF StreamEndReason = "eof"
|
||||
StreamEndReasonPanic StreamEndReason = "panic"
|
||||
StreamEndReasonPingFail StreamEndReason = "ping_fail"
|
||||
)
|
||||
|
||||
const maxStreamErrorEntries = 20
|
||||
|
||||
type StreamErrorEntry struct {
|
||||
Message string
|
||||
Timestamp time.Time
|
||||
}
|
||||
|
||||
type StreamStatus struct {
|
||||
EndReason StreamEndReason
|
||||
EndError error
|
||||
endOnce sync.Once
|
||||
|
||||
mu sync.Mutex
|
||||
Errors []StreamErrorEntry
|
||||
ErrorCount int
|
||||
}
|
||||
|
||||
func NewStreamStatus() *StreamStatus {
|
||||
return &StreamStatus{}
|
||||
}
|
||||
|
||||
func (s *StreamStatus) SetEndReason(reason StreamEndReason, err error) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.endOnce.Do(func() {
|
||||
s.EndReason = reason
|
||||
s.EndError = err
|
||||
})
|
||||
}
|
||||
|
||||
func (s *StreamStatus) RecordError(msg string) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.ErrorCount++
|
||||
if len(s.Errors) < maxStreamErrorEntries {
|
||||
s.Errors = append(s.Errors, StreamErrorEntry{
|
||||
Message: msg,
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StreamStatus) HasErrors() bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.ErrorCount > 0
|
||||
}
|
||||
|
||||
func (s *StreamStatus) TotalErrorCount() int {
|
||||
if s == nil {
|
||||
return 0
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.ErrorCount
|
||||
}
|
||||
|
||||
func (s *StreamStatus) IsNormalEnd() bool {
|
||||
if s == nil {
|
||||
return true
|
||||
}
|
||||
return s.EndReason == StreamEndReasonDone ||
|
||||
s.EndReason == StreamEndReasonEOF ||
|
||||
s.EndReason == StreamEndReasonHandlerStop
|
||||
}
|
||||
|
||||
func (s *StreamStatus) Summary() string {
|
||||
if s == nil {
|
||||
return "StreamStatus<nil>"
|
||||
}
|
||||
b := &strings.Builder{}
|
||||
fmt.Fprintf(b, "reason=%s", s.EndReason)
|
||||
if s.EndError != nil {
|
||||
fmt.Fprintf(b, " end_error=%q", s.EndError.Error())
|
||||
}
|
||||
s.mu.Lock()
|
||||
if s.ErrorCount > 0 {
|
||||
fmt.Fprintf(b, " soft_errors=%d", s.ErrorCount)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
return b.String()
|
||||
}
|
||||
@@ -0,0 +1,182 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestStreamStatus_SetEndReason_FirstWins(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := NewStreamStatus()
|
||||
|
||||
s.SetEndReason(StreamEndReasonDone, nil)
|
||||
s.SetEndReason(StreamEndReasonTimeout, nil)
|
||||
s.SetEndReason(StreamEndReasonClientGone, fmt.Errorf("context canceled"))
|
||||
|
||||
assert.Equal(t, StreamEndReasonDone, s.EndReason)
|
||||
assert.Nil(t, s.EndError)
|
||||
}
|
||||
|
||||
func TestStreamStatus_SetEndReason_WithError(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := NewStreamStatus()
|
||||
|
||||
expectedErr := fmt.Errorf("read: connection reset")
|
||||
s.SetEndReason(StreamEndReasonScannerErr, expectedErr)
|
||||
|
||||
assert.Equal(t, StreamEndReasonScannerErr, s.EndReason)
|
||||
assert.Equal(t, expectedErr, s.EndError)
|
||||
}
|
||||
|
||||
func TestStreamStatus_SetEndReason_NilSafe(t *testing.T) {
|
||||
t.Parallel()
|
||||
var s *StreamStatus
|
||||
s.SetEndReason(StreamEndReasonDone, nil)
|
||||
}
|
||||
|
||||
func TestStreamStatus_SetEndReason_Concurrent(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := NewStreamStatus()
|
||||
|
||||
reasons := []StreamEndReason{
|
||||
StreamEndReasonDone,
|
||||
StreamEndReasonTimeout,
|
||||
StreamEndReasonClientGone,
|
||||
StreamEndReasonScannerErr,
|
||||
StreamEndReasonHandlerStop,
|
||||
StreamEndReasonEOF,
|
||||
StreamEndReasonPanic,
|
||||
StreamEndReasonPingFail,
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, r := range reasons {
|
||||
wg.Add(1)
|
||||
go func(reason StreamEndReason) {
|
||||
defer wg.Done()
|
||||
s.SetEndReason(reason, nil)
|
||||
}(r)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
assert.NotEqual(t, StreamEndReasonNone, s.EndReason)
|
||||
}
|
||||
|
||||
func TestStreamStatus_RecordError_Basic(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := NewStreamStatus()
|
||||
|
||||
s.RecordError("bad json")
|
||||
s.RecordError("another bad json")
|
||||
s.RecordError("client gone")
|
||||
|
||||
assert.True(t, s.HasErrors())
|
||||
assert.Equal(t, 3, s.TotalErrorCount())
|
||||
assert.Len(t, s.Errors, 3)
|
||||
}
|
||||
|
||||
func TestStreamStatus_RecordError_CapAtMax(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := NewStreamStatus()
|
||||
|
||||
for i := 0; i < 30; i++ {
|
||||
s.RecordError(fmt.Sprintf("error_%d", i))
|
||||
}
|
||||
|
||||
assert.Equal(t, maxStreamErrorEntries, len(s.Errors))
|
||||
assert.Equal(t, 30, s.TotalErrorCount())
|
||||
}
|
||||
|
||||
func TestStreamStatus_RecordError_NilSafe(t *testing.T) {
|
||||
t.Parallel()
|
||||
var s *StreamStatus
|
||||
s.RecordError("should not panic")
|
||||
}
|
||||
|
||||
func TestStreamStatus_RecordError_Concurrent(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := NewStreamStatus()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
s.RecordError(fmt.Sprintf("error_%d", idx))
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
assert.Equal(t, 100, s.TotalErrorCount())
|
||||
assert.LessOrEqual(t, len(s.Errors), maxStreamErrorEntries)
|
||||
}
|
||||
|
||||
func TestStreamStatus_HasErrors_Empty(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := NewStreamStatus()
|
||||
assert.False(t, s.HasErrors())
|
||||
assert.Equal(t, 0, s.TotalErrorCount())
|
||||
}
|
||||
|
||||
func TestStreamStatus_HasErrors_NilSafe(t *testing.T) {
|
||||
t.Parallel()
|
||||
var s *StreamStatus
|
||||
assert.False(t, s.HasErrors())
|
||||
assert.Equal(t, 0, s.TotalErrorCount())
|
||||
}
|
||||
|
||||
func TestStreamStatus_IsNormalEnd(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
reason StreamEndReason
|
||||
normal bool
|
||||
}{
|
||||
{StreamEndReasonDone, true},
|
||||
{StreamEndReasonEOF, true},
|
||||
{StreamEndReasonHandlerStop, true},
|
||||
{StreamEndReasonTimeout, false},
|
||||
{StreamEndReasonClientGone, false},
|
||||
{StreamEndReasonScannerErr, false},
|
||||
{StreamEndReasonPanic, false},
|
||||
{StreamEndReasonPingFail, false},
|
||||
{StreamEndReasonNone, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
s := NewStreamStatus()
|
||||
s.SetEndReason(tt.reason, nil)
|
||||
assert.Equal(t, tt.normal, s.IsNormalEnd(), "reason=%s", tt.reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamStatus_IsNormalEnd_NilSafe(t *testing.T) {
|
||||
t.Parallel()
|
||||
var s *StreamStatus
|
||||
assert.True(t, s.IsNormalEnd())
|
||||
}
|
||||
|
||||
func TestStreamStatus_Summary(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s := NewStreamStatus()
|
||||
s.SetEndReason(StreamEndReasonDone, nil)
|
||||
summary := s.Summary()
|
||||
assert.Contains(t, summary, "reason=done")
|
||||
assert.NotContains(t, summary, "soft_errors")
|
||||
|
||||
s2 := NewStreamStatus()
|
||||
s2.SetEndReason(StreamEndReasonTimeout, nil)
|
||||
s2.RecordError("bad json")
|
||||
s2.RecordError("write failed")
|
||||
summary2 := s2.Summary()
|
||||
assert.Contains(t, summary2, "reason=timeout")
|
||||
assert.Contains(t, summary2, "soft_errors=2")
|
||||
}
|
||||
|
||||
func TestStreamStatus_Summary_NilSafe(t *testing.T) {
|
||||
t.Parallel()
|
||||
var s *StreamStatus
|
||||
assert.Equal(t, "StreamStatus<nil>", s.Summary())
|
||||
}
|
||||
+2
-293
@@ -6,25 +6,20 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
relayconstant "github.com/QuantumNous/new-api/relay/constant"
|
||||
"github.com/QuantumNous/new-api/relay/helper"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/setting/model_setting"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
"github.com/QuantumNous/new-api/setting/ratio_setting"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
"github.com/samber/lo"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -93,7 +88,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
|
||||
if containAudioTokens && containsAudioRatios {
|
||||
service.PostAudioConsumeQuota(c, info, usage, "")
|
||||
} else {
|
||||
postConsumeQuota(c, info, usage)
|
||||
service.PostTextConsumeQuota(c, info, usage, nil)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -216,293 +211,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
|
||||
if containAudioTokens && containsAudioRatios {
|
||||
service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||||
} else {
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage))
|
||||
service.PostTextConsumeQuota(c, info, usage.(*dto.Usage), nil)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent ...string) {
|
||||
originUsage := usage
|
||||
if usage == nil {
|
||||
usage = &dto.Usage{
|
||||
PromptTokens: relayInfo.GetEstimatePromptTokens(),
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: relayInfo.GetEstimatePromptTokens(),
|
||||
}
|
||||
extraContent = append(extraContent, "上游无计费信息")
|
||||
}
|
||||
|
||||
if originUsage != nil {
|
||||
service.ObserveChannelAffinityUsageCacheByRelayFormat(ctx, usage, relayInfo.GetFinalRequestRelayFormat())
|
||||
}
|
||||
|
||||
adminRejectReason := common.GetContextKeyString(ctx, constant.ContextKeyAdminRejectReason)
|
||||
|
||||
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
|
||||
promptTokens := usage.PromptTokens
|
||||
cacheTokens := usage.PromptTokensDetails.CachedTokens
|
||||
imageTokens := usage.PromptTokensDetails.ImageTokens
|
||||
audioTokens := usage.PromptTokensDetails.AudioTokens
|
||||
completionTokens := usage.CompletionTokens
|
||||
cachedCreationTokens := usage.PromptTokensDetails.CachedCreationTokens
|
||||
|
||||
modelName := relayInfo.OriginModelName
|
||||
|
||||
tokenName := ctx.GetString("token_name")
|
||||
completionRatio := relayInfo.PriceData.CompletionRatio
|
||||
cacheRatio := relayInfo.PriceData.CacheRatio
|
||||
imageRatio := relayInfo.PriceData.ImageRatio
|
||||
modelRatio := relayInfo.PriceData.ModelRatio
|
||||
groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
|
||||
modelPrice := relayInfo.PriceData.ModelPrice
|
||||
cachedCreationRatio := relayInfo.PriceData.CacheCreationRatio
|
||||
|
||||
// Convert values to decimal for precise calculation
|
||||
dPromptTokens := decimal.NewFromInt(int64(promptTokens))
|
||||
dCacheTokens := decimal.NewFromInt(int64(cacheTokens))
|
||||
dImageTokens := decimal.NewFromInt(int64(imageTokens))
|
||||
dAudioTokens := decimal.NewFromInt(int64(audioTokens))
|
||||
dCompletionTokens := decimal.NewFromInt(int64(completionTokens))
|
||||
dCachedCreationTokens := decimal.NewFromInt(int64(cachedCreationTokens))
|
||||
dCompletionRatio := decimal.NewFromFloat(completionRatio)
|
||||
dCacheRatio := decimal.NewFromFloat(cacheRatio)
|
||||
dImageRatio := decimal.NewFromFloat(imageRatio)
|
||||
dModelRatio := decimal.NewFromFloat(modelRatio)
|
||||
dGroupRatio := decimal.NewFromFloat(groupRatio)
|
||||
dModelPrice := decimal.NewFromFloat(modelPrice)
|
||||
dCachedCreationRatio := decimal.NewFromFloat(cachedCreationRatio)
|
||||
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
|
||||
|
||||
ratio := dModelRatio.Mul(dGroupRatio)
|
||||
|
||||
// openai web search 工具计费
|
||||
var dWebSearchQuota decimal.Decimal
|
||||
var webSearchPrice float64
|
||||
// response api 格式工具计费
|
||||
if relayInfo.ResponsesUsageInfo != nil {
|
||||
if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool.CallCount > 0 {
|
||||
// 计算 web search 调用的配额 (配额 = 价格 * 调用次数 / 1000 * 分组倍率)
|
||||
webSearchPrice = operation_setting.GetWebSearchPricePerThousand(modelName, webSearchTool.SearchContextSize)
|
||||
dWebSearchQuota = decimal.NewFromFloat(webSearchPrice).
|
||||
Mul(decimal.NewFromInt(int64(webSearchTool.CallCount))).
|
||||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
extraContent = append(extraContent, fmt.Sprintf("Web Search 调用 %d 次,上下文大小 %s,调用花费 %s",
|
||||
webSearchTool.CallCount, webSearchTool.SearchContextSize, dWebSearchQuota.String()))
|
||||
}
|
||||
} else if strings.HasSuffix(modelName, "search-preview") {
|
||||
// search-preview 模型不支持 response api
|
||||
searchContextSize := ctx.GetString("chat_completion_web_search_context_size")
|
||||
if searchContextSize == "" {
|
||||
searchContextSize = "medium"
|
||||
}
|
||||
webSearchPrice = operation_setting.GetWebSearchPricePerThousand(modelName, searchContextSize)
|
||||
dWebSearchQuota = decimal.NewFromFloat(webSearchPrice).
|
||||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
extraContent = append(extraContent, fmt.Sprintf("Web Search 调用 1 次,上下文大小 %s,调用花费 %s",
|
||||
searchContextSize, dWebSearchQuota.String()))
|
||||
}
|
||||
// claude web search tool 计费
|
||||
var dClaudeWebSearchQuota decimal.Decimal
|
||||
var claudeWebSearchPrice float64
|
||||
claudeWebSearchCallCount := ctx.GetInt("claude_web_search_requests")
|
||||
if claudeWebSearchCallCount > 0 {
|
||||
claudeWebSearchPrice = operation_setting.GetClaudeWebSearchPricePerThousand()
|
||||
dClaudeWebSearchQuota = decimal.NewFromFloat(claudeWebSearchPrice).
|
||||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit).Mul(decimal.NewFromInt(int64(claudeWebSearchCallCount)))
|
||||
extraContent = append(extraContent, fmt.Sprintf("Claude Web Search 调用 %d 次,调用花费 %s",
|
||||
claudeWebSearchCallCount, dClaudeWebSearchQuota.String()))
|
||||
}
|
||||
// file search tool 计费
|
||||
var dFileSearchQuota decimal.Decimal
|
||||
var fileSearchPrice float64
|
||||
if relayInfo.ResponsesUsageInfo != nil {
|
||||
if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists && fileSearchTool.CallCount > 0 {
|
||||
fileSearchPrice = operation_setting.GetFileSearchPricePerThousand()
|
||||
dFileSearchQuota = decimal.NewFromFloat(fileSearchPrice).
|
||||
Mul(decimal.NewFromInt(int64(fileSearchTool.CallCount))).
|
||||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
extraContent = append(extraContent, fmt.Sprintf("File Search 调用 %d 次,调用花费 %s",
|
||||
fileSearchTool.CallCount, dFileSearchQuota.String()))
|
||||
}
|
||||
}
|
||||
var dImageGenerationCallQuota decimal.Decimal
|
||||
var imageGenerationCallPrice float64
|
||||
if ctx.GetBool("image_generation_call") {
|
||||
imageGenerationCallPrice = operation_setting.GetGPTImage1PriceOnceCall(ctx.GetString("image_generation_call_quality"), ctx.GetString("image_generation_call_size"))
|
||||
dImageGenerationCallQuota = decimal.NewFromFloat(imageGenerationCallPrice).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
extraContent = append(extraContent, fmt.Sprintf("Image Generation Call 花费 %s", dImageGenerationCallQuota.String()))
|
||||
}
|
||||
|
||||
var quotaCalculateDecimal decimal.Decimal
|
||||
|
||||
var audioInputQuota decimal.Decimal
|
||||
var audioInputPrice float64
|
||||
isClaudeUsageSemantic := relayInfo.GetFinalRequestRelayFormat() == types.RelayFormatClaude
|
||||
if !relayInfo.PriceData.UsePrice {
|
||||
baseTokens := dPromptTokens
|
||||
// 减去 cached tokens
|
||||
// Anthropic API 的 input_tokens 已经不包含缓存 tokens,不需要减去
|
||||
// OpenAI/OpenRouter 等 API 的 prompt_tokens 包含缓存 tokens,需要减去
|
||||
var cachedTokensWithRatio decimal.Decimal
|
||||
if !dCacheTokens.IsZero() {
|
||||
if !isClaudeUsageSemantic {
|
||||
baseTokens = baseTokens.Sub(dCacheTokens)
|
||||
}
|
||||
cachedTokensWithRatio = dCacheTokens.Mul(dCacheRatio)
|
||||
}
|
||||
var dCachedCreationTokensWithRatio decimal.Decimal
|
||||
if !dCachedCreationTokens.IsZero() {
|
||||
if !isClaudeUsageSemantic {
|
||||
baseTokens = baseTokens.Sub(dCachedCreationTokens)
|
||||
}
|
||||
dCachedCreationTokensWithRatio = dCachedCreationTokens.Mul(dCachedCreationRatio)
|
||||
}
|
||||
|
||||
// 减去 image tokens
|
||||
var imageTokensWithRatio decimal.Decimal
|
||||
if !dImageTokens.IsZero() {
|
||||
baseTokens = baseTokens.Sub(dImageTokens)
|
||||
imageTokensWithRatio = dImageTokens.Mul(dImageRatio)
|
||||
}
|
||||
|
||||
// 减去 Gemini audio tokens
|
||||
if !dAudioTokens.IsZero() {
|
||||
audioInputPrice = operation_setting.GetGeminiInputAudioPricePerMillionTokens(modelName)
|
||||
if audioInputPrice > 0 {
|
||||
// 重新计算 base tokens
|
||||
baseTokens = baseTokens.Sub(dAudioTokens)
|
||||
audioInputQuota = decimal.NewFromFloat(audioInputPrice).Div(decimal.NewFromInt(1000000)).Mul(dAudioTokens).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
extraContent = append(extraContent, fmt.Sprintf("Audio Input 花费 %s", audioInputQuota.String()))
|
||||
}
|
||||
}
|
||||
promptQuota := baseTokens.Add(cachedTokensWithRatio).
|
||||
Add(imageTokensWithRatio).
|
||||
Add(dCachedCreationTokensWithRatio)
|
||||
|
||||
completionQuota := dCompletionTokens.Mul(dCompletionRatio)
|
||||
|
||||
quotaCalculateDecimal = promptQuota.Add(completionQuota).Mul(ratio)
|
||||
|
||||
if !ratio.IsZero() && quotaCalculateDecimal.LessThanOrEqual(decimal.Zero) {
|
||||
quotaCalculateDecimal = decimal.NewFromInt(1)
|
||||
}
|
||||
} else {
|
||||
quotaCalculateDecimal = dModelPrice.Mul(dQuotaPerUnit).Mul(dGroupRatio)
|
||||
}
|
||||
// 添加 responses tools call 调用的配额
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dWebSearchQuota)
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota)
|
||||
// 添加 audio input 独立计费
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota)
|
||||
// 添加 image generation call 计费
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dImageGenerationCallQuota)
|
||||
|
||||
if len(relayInfo.PriceData.OtherRatios) > 0 {
|
||||
for key, otherRatio := range relayInfo.PriceData.OtherRatios {
|
||||
dOtherRatio := decimal.NewFromFloat(otherRatio)
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Mul(dOtherRatio)
|
||||
extraContent = append(extraContent, fmt.Sprintf("其他倍率 %s: %f", key, otherRatio))
|
||||
}
|
||||
}
|
||||
|
||||
quota := int(quotaCalculateDecimal.Round(0).IntPart())
|
||||
totalTokens := promptTokens + completionTokens
|
||||
|
||||
//var logContent string
|
||||
|
||||
// record all the consume log even if quota is 0
|
||||
if totalTokens == 0 {
|
||||
// in this case, must be some error happened
|
||||
// we cannot just return, because we may have to return the pre-consumed quota
|
||||
quota = 0
|
||||
extraContent = append(extraContent, "上游没有返回计费信息,无法扣费(可能是上游超时)")
|
||||
logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
|
||||
"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota))
|
||||
} else {
|
||||
if !ratio.IsZero() && quota == 0 {
|
||||
quota = 1
|
||||
}
|
||||
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
|
||||
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
|
||||
}
|
||||
|
||||
if err := service.SettleBilling(ctx, relayInfo, quota); err != nil {
|
||||
logger.LogError(ctx, "error settling billing: "+err.Error())
|
||||
}
|
||||
|
||||
logModel := modelName
|
||||
if strings.HasPrefix(logModel, "gpt-4-gizmo") {
|
||||
logModel = "gpt-4-gizmo-*"
|
||||
extraContent = append(extraContent, fmt.Sprintf("模型 %s", modelName))
|
||||
}
|
||||
if strings.HasPrefix(logModel, "gpt-4o-gizmo") {
|
||||
logModel = "gpt-4o-gizmo-*"
|
||||
extraContent = append(extraContent, fmt.Sprintf("模型 %s", modelName))
|
||||
}
|
||||
logContent := strings.Join(extraContent, ", ")
|
||||
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
if adminRejectReason != "" {
|
||||
other["reject_reason"] = adminRejectReason
|
||||
}
|
||||
// For chat-based calls to the Claude model, tagging is required. Using Claude's rendering logs, the two approaches handle input rendering differently.
|
||||
if isClaudeUsageSemantic {
|
||||
other["claude"] = true
|
||||
other["usage_semantic"] = "anthropic"
|
||||
}
|
||||
if imageTokens != 0 {
|
||||
other["image"] = true
|
||||
other["image_ratio"] = imageRatio
|
||||
other["image_output"] = imageTokens
|
||||
}
|
||||
if cachedCreationTokens != 0 {
|
||||
other["cache_creation_tokens"] = cachedCreationTokens
|
||||
other["cache_creation_ratio"] = cachedCreationRatio
|
||||
}
|
||||
if !dWebSearchQuota.IsZero() {
|
||||
if relayInfo.ResponsesUsageInfo != nil {
|
||||
if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists {
|
||||
other["web_search"] = true
|
||||
other["web_search_call_count"] = webSearchTool.CallCount
|
||||
other["web_search_price"] = webSearchPrice
|
||||
}
|
||||
} else if strings.HasSuffix(modelName, "search-preview") {
|
||||
other["web_search"] = true
|
||||
other["web_search_call_count"] = 1
|
||||
other["web_search_price"] = webSearchPrice
|
||||
}
|
||||
} else if !dClaudeWebSearchQuota.IsZero() {
|
||||
other["web_search"] = true
|
||||
other["web_search_call_count"] = claudeWebSearchCallCount
|
||||
other["web_search_price"] = claudeWebSearchPrice
|
||||
}
|
||||
if !dFileSearchQuota.IsZero() && relayInfo.ResponsesUsageInfo != nil {
|
||||
if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists {
|
||||
other["file_search"] = true
|
||||
other["file_search_call_count"] = fileSearchTool.CallCount
|
||||
other["file_search_price"] = fileSearchPrice
|
||||
}
|
||||
}
|
||||
if !audioInputQuota.IsZero() {
|
||||
other["audio_input_seperate_price"] = true
|
||||
other["audio_input_token_count"] = audioTokens
|
||||
other["audio_input_price"] = audioInputPrice
|
||||
}
|
||||
if !dImageGenerationCallQuota.IsZero() {
|
||||
other["image_generation_call"] = true
|
||||
other["image_generation_call_price"] = imageGenerationCallPrice
|
||||
}
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
|
||||
ChannelId: relayInfo.ChannelId,
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
ModelName: logModel,
|
||||
TokenName: tokenName,
|
||||
Quota: quota,
|
||||
Content: logContent,
|
||||
TokenId: relayInfo.TokenId,
|
||||
UseTimeSeconds: int(useTimeSeconds),
|
||||
IsStream: relayInfo.IsStream,
|
||||
Group: relayInfo.UsingGroup,
|
||||
Other: other,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -82,6 +82,6 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
}
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage))
|
||||
service.PostTextConsumeQuota(c, info, usage.(*dto.Usage), nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -194,7 +194,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
return openaiErr
|
||||
}
|
||||
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage))
|
||||
service.PostTextConsumeQuota(c, info, usage.(*dto.Usage), nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -288,6 +288,6 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI
|
||||
return openaiErr
|
||||
}
|
||||
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage))
|
||||
service.PostTextConsumeQuota(c, info, usage.(*dto.Usage), nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
package helper
|
||||
|
||||
import (
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
)
|
||||
|
||||
// StreamResult is passed to each dataHandler invocation, providing methods
|
||||
// to record soft errors, signal fatal stops, or mark normal completion.
|
||||
// StreamScannerHandler checks IsStopped() after each callback invocation.
|
||||
type StreamResult struct {
|
||||
status *relaycommon.StreamStatus
|
||||
stopped bool
|
||||
}
|
||||
|
||||
func newStreamResult(status *relaycommon.StreamStatus) *StreamResult {
|
||||
return &StreamResult{status: status}
|
||||
}
|
||||
|
||||
// Error records a soft error. The stream continues processing.
|
||||
// Can be called multiple times per chunk.
|
||||
func (r *StreamResult) Error(err error) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
r.status.RecordError(err.Error())
|
||||
}
|
||||
|
||||
// Stop records a fatal error and marks the stream to stop after this chunk.
|
||||
func (r *StreamResult) Stop(err error) {
|
||||
if err != nil {
|
||||
r.status.RecordError(err.Error())
|
||||
}
|
||||
r.status.SetEndReason(relaycommon.StreamEndReasonHandlerStop, err)
|
||||
r.stopped = true
|
||||
}
|
||||
|
||||
// Done signals that the handler has finished processing normally
|
||||
// (e.g., Dify "message_end"). The stream stops after this chunk.
|
||||
func (r *StreamResult) Done() {
|
||||
r.status.SetEndReason(relaycommon.StreamEndReasonDone, nil)
|
||||
r.stopped = true
|
||||
}
|
||||
|
||||
// IsStopped returns whether Stop() or Done() was called during this chunk.
|
||||
func (r *StreamResult) IsStopped() bool {
|
||||
return r.stopped
|
||||
}
|
||||
|
||||
// reset clears the per-chunk stopped flag so the object can be reused.
|
||||
func (r *StreamResult) reset() {
|
||||
r.stopped = false
|
||||
}
|
||||
@@ -34,12 +34,15 @@ func getScannerBufferSize() int {
|
||||
return DefaultMaxScannerBufferSize
|
||||
}
|
||||
|
||||
func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string) bool) {
|
||||
func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string, sr *StreamResult)) {
|
||||
|
||||
if resp == nil || dataHandler == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 无条件新建 StreamStatus
|
||||
info.StreamStatus = relaycommon.NewStreamStatus()
|
||||
|
||||
// 确保响应体总是被关闭
|
||||
defer func() {
|
||||
if resp.Body != nil {
|
||||
@@ -121,6 +124,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
wg.Done()
|
||||
if r := recover(); r != nil {
|
||||
logger.LogError(c, fmt.Sprintf("ping goroutine panic: %v", r))
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonPanic, fmt.Errorf("ping panic: %v", r))
|
||||
common.SafeSendBool(stopChan, true)
|
||||
}
|
||||
if common.DebugEnabled {
|
||||
@@ -148,6 +152,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
logger.LogError(c, "ping data error: "+err.Error())
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonPingFail, err)
|
||||
return
|
||||
}
|
||||
if common.DebugEnabled {
|
||||
@@ -155,6 +160,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
}
|
||||
case <-time.After(10 * time.Second):
|
||||
logger.LogError(c, "ping data send timeout")
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonPingFail, fmt.Errorf("ping send timeout"))
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
@@ -184,14 +190,17 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
wg.Done()
|
||||
if r := recover(); r != nil {
|
||||
logger.LogError(c, fmt.Sprintf("data handler goroutine panic: %v", r))
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonPanic, fmt.Errorf("handler panic: %v", r))
|
||||
}
|
||||
common.SafeSendBool(stopChan, true)
|
||||
}()
|
||||
sr := newStreamResult(info.StreamStatus)
|
||||
for data := range dataChan {
|
||||
sr.reset()
|
||||
writeMutex.Lock()
|
||||
success := dataHandler(data)
|
||||
dataHandler(data, sr)
|
||||
writeMutex.Unlock()
|
||||
if !success {
|
||||
if sr.IsStopped() {
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -205,6 +214,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
wg.Done()
|
||||
if r := recover(); r != nil {
|
||||
logger.LogError(c, fmt.Sprintf("scanner goroutine panic: %v", r))
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonPanic, fmt.Errorf("scanner panic: %v", r))
|
||||
}
|
||||
common.SafeSendBool(stopChan, true)
|
||||
if common.DebugEnabled {
|
||||
@@ -220,6 +230,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-c.Request.Context().Done():
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, c.Request.Context().Err())
|
||||
return
|
||||
default:
|
||||
}
|
||||
@@ -253,7 +264,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// done, 处理完成标志,直接退出停止读取剩余数据防止出错
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonDone, nil)
|
||||
if common.DebugEnabled {
|
||||
println("received [DONE], stopping scanner")
|
||||
}
|
||||
@@ -264,20 +275,25 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
if err := scanner.Err(); err != nil {
|
||||
if err != io.EOF {
|
||||
logger.LogError(c, "scanner error: "+err.Error())
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonScannerErr, err)
|
||||
}
|
||||
}
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonEOF, nil)
|
||||
})
|
||||
|
||||
// 主循环等待完成或超时
|
||||
select {
|
||||
case <-ticker.C:
|
||||
// 超时处理逻辑
|
||||
logger.LogError(c, "streaming timeout")
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonTimeout, nil)
|
||||
case <-stopChan:
|
||||
// 正常结束
|
||||
logger.LogInfo(c, "streaming finished")
|
||||
// EndReason already set by the goroutine that triggered stopChan
|
||||
case <-c.Request.Context().Done():
|
||||
// 客户端断开连接
|
||||
logger.LogInfo(c, "client disconnected")
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, c.Request.Context().Err())
|
||||
}
|
||||
|
||||
if info.StreamStatus.IsNormalEnd() && !info.StreamStatus.HasErrors() {
|
||||
logger.LogInfo(c, fmt.Sprintf("stream ended: %s", info.StreamStatus.Summary()))
|
||||
} else {
|
||||
logger.LogError(c, fmt.Sprintf("stream ended: %s, received=%d", info.StreamStatus.Summary(), info.ReceivedResponseCount))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -56,8 +56,6 @@ func buildSSEBody(n int) string {
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// slowReader wraps a reader and injects a delay before each Read call,
|
||||
// simulating a slow upstream that trickles data.
|
||||
type slowReader struct {
|
||||
r io.Reader
|
||||
delay time.Duration
|
||||
@@ -79,7 +77,7 @@ func TestStreamScannerHandler_NilInputs(t *testing.T) {
|
||||
|
||||
info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}}
|
||||
|
||||
StreamScannerHandler(c, nil, info, func(data string) bool { return true })
|
||||
StreamScannerHandler(c, nil, info, func(data string, sr *StreamResult) {})
|
||||
StreamScannerHandler(c, &http.Response{Body: io.NopCloser(strings.NewReader(""))}, info, nil)
|
||||
}
|
||||
|
||||
@@ -89,9 +87,8 @@ func TestStreamScannerHandler_EmptyBody(t *testing.T) {
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(""))
|
||||
|
||||
var called atomic.Bool
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
called.Store(true)
|
||||
return true
|
||||
})
|
||||
|
||||
assert.False(t, called.Load(), "handler should not be called for empty body")
|
||||
@@ -105,9 +102,8 @@ func TestStreamScannerHandler_1000Chunks(t *testing.T) {
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
var count atomic.Int64
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
|
||||
assert.Equal(t, int64(numChunks), count.Load())
|
||||
@@ -124,9 +120,8 @@ func TestStreamScannerHandler_10000Chunks(t *testing.T) {
|
||||
var count atomic.Int64
|
||||
start := time.Now()
|
||||
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
|
||||
elapsed := time.Since(start)
|
||||
@@ -145,11 +140,10 @@ func TestStreamScannerHandler_OrderPreserved(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
received := make([]string, 0, numChunks)
|
||||
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
mu.Lock()
|
||||
received = append(received, data)
|
||||
mu.Unlock()
|
||||
return true
|
||||
})
|
||||
|
||||
require.Equal(t, numChunks, len(received))
|
||||
@@ -166,31 +160,32 @@ func TestStreamScannerHandler_DoneStopsScanner(t *testing.T) {
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
var count atomic.Int64
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
|
||||
assert.Equal(t, int64(50), count.Load(), "data after [DONE] must not be processed")
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_HandlerFailureStops(t *testing.T) {
|
||||
func TestStreamScannerHandler_StopStopsStream(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const numChunks = 200
|
||||
body := buildSSEBody(numChunks)
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
const failAt = 50
|
||||
const stopAt int64 = 50
|
||||
var count atomic.Int64
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
n := count.Add(1)
|
||||
return n < failAt
|
||||
if n >= stopAt {
|
||||
sr.Stop(fmt.Errorf("fatal at %d", n))
|
||||
}
|
||||
})
|
||||
|
||||
// The worker stops at failAt; the scanner may have read ahead,
|
||||
// but the handler should not be called beyond failAt.
|
||||
assert.Equal(t, int64(failAt), count.Load())
|
||||
assert.Equal(t, stopAt, count.Load())
|
||||
require.NotNil(t, info.StreamStatus)
|
||||
assert.Equal(t, relaycommon.StreamEndReasonHandlerStop, info.StreamStatus.EndReason)
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_SkipsNonDataLines(t *testing.T) {
|
||||
@@ -210,9 +205,8 @@ func TestStreamScannerHandler_SkipsNonDataLines(t *testing.T) {
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(b.String()))
|
||||
|
||||
var count atomic.Int64
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
|
||||
assert.Equal(t, int64(100), count.Load())
|
||||
@@ -225,25 +219,18 @@ func TestStreamScannerHandler_DataWithExtraSpaces(t *testing.T) {
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
var got string
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
got = data
|
||||
return true
|
||||
})
|
||||
|
||||
assert.Equal(t, "{\"trimmed\":true}", got)
|
||||
}
|
||||
|
||||
// ---------- Decoupling: scanner not blocked by slow handler ----------
|
||||
// ---------- Decoupling ----------
|
||||
|
||||
func TestStreamScannerHandler_ScannerDecoupledFromSlowHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Strategy: use a slow upstream (io.Pipe, 10ms per chunk) AND a slow handler (20ms per chunk).
|
||||
// If the scanner were synchronously coupled to the handler, total time would be
|
||||
// ~numChunks * (10ms + 20ms) = 30ms * 50 = 1500ms.
|
||||
// With decoupling, total time should be closer to
|
||||
// ~numChunks * max(10ms, 20ms) = 20ms * 50 = 1000ms
|
||||
// because the scanner reads ahead into the buffer while the handler processes.
|
||||
const numChunks = 50
|
||||
const upstreamDelay = 10 * time.Millisecond
|
||||
const handlerDelay = 20 * time.Millisecond
|
||||
@@ -273,10 +260,9 @@ func TestStreamScannerHandler_ScannerDecoupledFromSlowHandler(t *testing.T) {
|
||||
start := time.Now()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
time.Sleep(handlerDelay)
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
close(done)
|
||||
}()
|
||||
@@ -293,7 +279,6 @@ func TestStreamScannerHandler_ScannerDecoupledFromSlowHandler(t *testing.T) {
|
||||
coupledTime := time.Duration(numChunks) * (upstreamDelay + handlerDelay)
|
||||
t.Logf("elapsed=%v, coupled_estimate=%v", elapsed, coupledTime)
|
||||
|
||||
// If decoupled, elapsed should be well under the coupled estimate.
|
||||
assert.Less(t, elapsed, coupledTime*85/100,
|
||||
"decoupled elapsed time (%v) should be significantly less than coupled estimate (%v)", elapsed, coupledTime)
|
||||
}
|
||||
@@ -311,9 +296,8 @@ func TestStreamScannerHandler_SlowUpstreamFastHandler(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
close(done)
|
||||
}()
|
||||
@@ -344,8 +328,6 @@ func TestStreamScannerHandler_PingSentDuringSlowUpstream(t *testing.T) {
|
||||
setting.PingIntervalSeconds = oldSeconds
|
||||
})
|
||||
|
||||
// Create a reader that delivers data slowly: one chunk every 500ms over 3.5 seconds.
|
||||
// The ping interval is 1s, so we should see at least 2 pings.
|
||||
pr, pw := io.Pipe()
|
||||
go func() {
|
||||
defer pw.Close()
|
||||
@@ -372,9 +354,8 @@ func TestStreamScannerHandler_PingSentDuringSlowUpstream(t *testing.T) {
|
||||
var count atomic.Int64
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
close(done)
|
||||
}()
|
||||
@@ -436,9 +417,8 @@ func TestStreamScannerHandler_PingDisabledByRelayInfo(t *testing.T) {
|
||||
var count atomic.Int64
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
close(done)
|
||||
}()
|
||||
@@ -456,6 +436,199 @@ func TestStreamScannerHandler_PingDisabledByRelayInfo(t *testing.T) {
|
||||
assert.Equal(t, 0, pingCount, "pings should be disabled when DisablePing=true")
|
||||
}
|
||||
|
||||
// ---------- StreamStatus integration ----------
|
||||
|
||||
func TestStreamScannerHandler_StreamStatus_DoneReason(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := buildSSEBody(10)
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {})
|
||||
|
||||
require.NotNil(t, info.StreamStatus)
|
||||
assert.Equal(t, relaycommon.StreamEndReasonDone, info.StreamStatus.EndReason)
|
||||
assert.Nil(t, info.StreamStatus.EndError)
|
||||
assert.True(t, info.StreamStatus.IsNormalEnd())
|
||||
assert.False(t, info.StreamStatus.HasErrors())
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_StreamStatus_EOFWithoutDone(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var b strings.Builder
|
||||
for i := 0; i < 5; i++ {
|
||||
fmt.Fprintf(&b, "data: {\"id\":%d}\n", i)
|
||||
}
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(b.String()))
|
||||
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {})
|
||||
|
||||
require.NotNil(t, info.StreamStatus)
|
||||
assert.Equal(t, relaycommon.StreamEndReasonEOF, info.StreamStatus.EndReason)
|
||||
assert.True(t, info.StreamStatus.IsNormalEnd())
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_StreamStatus_HandlerStop(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := buildSSEBody(100)
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
var count atomic.Int64
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
n := count.Add(1)
|
||||
if n >= 10 {
|
||||
sr.Stop(fmt.Errorf("stop at 10"))
|
||||
}
|
||||
})
|
||||
|
||||
require.NotNil(t, info.StreamStatus)
|
||||
assert.Equal(t, relaycommon.StreamEndReasonHandlerStop, info.StreamStatus.EndReason)
|
||||
assert.True(t, info.StreamStatus.HasErrors())
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_StreamStatus_HandlerDone(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := buildSSEBody(20)
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
var count atomic.Int64
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
n := count.Add(1)
|
||||
if n >= 5 {
|
||||
sr.Done()
|
||||
}
|
||||
})
|
||||
|
||||
assert.Equal(t, int64(5), count.Load())
|
||||
require.NotNil(t, info.StreamStatus)
|
||||
assert.Equal(t, relaycommon.StreamEndReasonDone, info.StreamStatus.EndReason)
|
||||
assert.False(t, info.StreamStatus.HasErrors())
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_StreamStatus_Timeout(t *testing.T) {
|
||||
// Not parallel: modifies global constant.StreamingTimeout
|
||||
oldTimeout := constant.StreamingTimeout
|
||||
constant.StreamingTimeout = 2
|
||||
t.Cleanup(func() { constant.StreamingTimeout = oldTimeout })
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
go func() {
|
||||
fmt.Fprint(pw, "data: {\"id\":1}\n")
|
||||
time.Sleep(10 * time.Second)
|
||||
pw.Close()
|
||||
}()
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
resp := &http.Response{Body: pr}
|
||||
info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {})
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(15 * time.Second):
|
||||
t.Fatal("timed out waiting for stream timeout")
|
||||
}
|
||||
|
||||
require.NotNil(t, info.StreamStatus)
|
||||
assert.Equal(t, relaycommon.StreamEndReasonTimeout, info.StreamStatus.EndReason)
|
||||
assert.False(t, info.StreamStatus.IsNormalEnd())
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_StreamStatus_SoftErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := buildSSEBody(10)
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
sr.Error(fmt.Errorf("soft error for chunk"))
|
||||
})
|
||||
|
||||
require.NotNil(t, info.StreamStatus)
|
||||
assert.Equal(t, relaycommon.StreamEndReasonDone, info.StreamStatus.EndReason)
|
||||
assert.True(t, info.StreamStatus.HasErrors())
|
||||
assert.Equal(t, 10, info.StreamStatus.TotalErrorCount())
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_StreamStatus_MultipleErrorsPerChunk(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := buildSSEBody(5)
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
sr.Error(fmt.Errorf("error A"))
|
||||
sr.Error(fmt.Errorf("error B"))
|
||||
})
|
||||
|
||||
require.NotNil(t, info.StreamStatus)
|
||||
assert.Equal(t, relaycommon.StreamEndReasonDone, info.StreamStatus.EndReason)
|
||||
assert.Equal(t, 10, info.StreamStatus.TotalErrorCount())
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_StreamStatus_ErrorThenStop(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Use a large body without [DONE] to avoid race between scanner's [DONE]
|
||||
// and handler's Stop on the sync.Once EndReason.
|
||||
var b strings.Builder
|
||||
for i := 0; i < 100; i++ {
|
||||
fmt.Fprintf(&b, "data: {\"id\":%d}\n", i)
|
||||
}
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(b.String()))
|
||||
|
||||
var count atomic.Int64
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
count.Add(1)
|
||||
sr.Error(fmt.Errorf("soft error"))
|
||||
sr.Stop(fmt.Errorf("fatal"))
|
||||
})
|
||||
|
||||
assert.Equal(t, int64(1), count.Load())
|
||||
require.NotNil(t, info.StreamStatus)
|
||||
assert.Equal(t, relaycommon.StreamEndReasonHandlerStop, info.StreamStatus.EndReason)
|
||||
assert.Equal(t, 2, info.StreamStatus.TotalErrorCount())
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_StreamStatus_InitializedIfNil(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := buildSSEBody(1)
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
assert.Nil(t, info.StreamStatus)
|
||||
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {})
|
||||
|
||||
assert.NotNil(t, info.StreamStatus)
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_StreamStatus_PreInitialized(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := buildSSEBody(5)
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
info.StreamStatus = relaycommon.NewStreamStatus()
|
||||
info.StreamStatus.RecordError("pre-existing error")
|
||||
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {})
|
||||
|
||||
assert.Equal(t, relaycommon.StreamEndReasonDone, info.StreamStatus.EndReason)
|
||||
assert.Equal(t, 1, info.StreamStatus.TotalErrorCount())
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_PingInterleavesWithSlowUpstream(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -469,9 +642,6 @@ func TestStreamScannerHandler_PingInterleavesWithSlowUpstream(t *testing.T) {
|
||||
setting.PingIntervalSeconds = oldSeconds
|
||||
})
|
||||
|
||||
// Slow upstream + slow handler. Total stream takes ~5 seconds.
|
||||
// The ping goroutine stays alive as long as the scanner is reading,
|
||||
// so pings should fire between data writes.
|
||||
pr, pw := io.Pipe()
|
||||
go func() {
|
||||
defer pw.Close()
|
||||
@@ -498,9 +668,8 @@ func TestStreamScannerHandler_PingInterleavesWithSlowUpstream(t *testing.T) {
|
||||
var count atomic.Int64
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -141,6 +141,6 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
||||
logContent = append(logContent, fmt.Sprintf("生成数量 %d", imageN))
|
||||
}
|
||||
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage), logContent...)
|
||||
service.PostTextConsumeQuota(c, info, usage.(*dto.Usage), logContent)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -96,6 +96,6 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
}
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage))
|
||||
service.PostTextConsumeQuota(c, info, usage.(*dto.Usage), nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -145,7 +145,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
||||
info.PriceData = originPriceData
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
postConsumeQuota(c, info, usageDto)
|
||||
service.PostTextConsumeQuota(c, info, usageDto, nil)
|
||||
|
||||
info.OriginModelName = originModelName
|
||||
info.PriceData = originPriceData
|
||||
@@ -155,7 +155,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
||||
if strings.HasPrefix(info.OriginModelName, "gpt-4o-audio") {
|
||||
service.PostAudioConsumeQuota(c, info, usageDto, "")
|
||||
} else {
|
||||
postConsumeQuota(c, info, usageDto)
|
||||
service.PostTextConsumeQuota(c, info, usageDto, nil)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -48,6 +48,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
|
||||
apiRouter.POST("/stripe/webhook", controller.StripeWebhook)
|
||||
apiRouter.POST("/creem/webhook", controller.CreemWebhook)
|
||||
apiRouter.POST("/waffo/webhook", controller.WaffoWebhook)
|
||||
|
||||
// Universal secure verification routes
|
||||
apiRouter.POST("/verify", middleware.UserAuth(), middleware.CriticalRateLimit(), controller.UniversalVerify)
|
||||
@@ -89,6 +90,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
selfRoute.POST("/stripe/pay", middleware.CriticalRateLimit(), controller.RequestStripePay)
|
||||
selfRoute.POST("/stripe/amount", controller.RequestStripeAmount)
|
||||
selfRoute.POST("/creem/pay", middleware.CriticalRateLimit(), controller.RequestCreemPay)
|
||||
selfRoute.POST("/waffo/pay", middleware.CriticalRateLimit(), controller.RequestWaffoPay)
|
||||
selfRoute.POST("/aff_transfer", controller.TransferAffQuota)
|
||||
selfRoute.PUT("/setting", controller.UpdateUserSetting)
|
||||
|
||||
@@ -192,6 +194,8 @@ func SetApiRouter(router *gin.Engine) {
|
||||
performanceRoute.DELETE("/disk_cache", controller.ClearDiskCache)
|
||||
performanceRoute.POST("/reset_stats", controller.ResetPerformanceStats)
|
||||
performanceRoute.POST("/gc", controller.ForceGC)
|
||||
performanceRoute.GET("/logs", controller.GetLogFiles)
|
||||
performanceRoute.DELETE("/logs", controller.CleanupLogFiles)
|
||||
}
|
||||
ratioSyncRoute := apiRouter.Group("/ratio_sync")
|
||||
ratioSyncRoute.Use(middleware.RootAuth())
|
||||
@@ -248,6 +252,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
tokenRoute.GET("/", controller.GetAllTokens)
|
||||
tokenRoute.GET("/search", middleware.SearchRateLimit(), controller.SearchTokens)
|
||||
tokenRoute.GET("/:id", controller.GetToken)
|
||||
tokenRoute.POST("/:id/key", middleware.CriticalRateLimit(), middleware.DisableCache(), controller.GetTokenKey)
|
||||
tokenRoute.POST("/", controller.AddToken)
|
||||
tokenRoute.PUT("/", controller.UpdateToken)
|
||||
tokenRoute.DELETE("/:id", controller.DeleteToken)
|
||||
|
||||
@@ -214,7 +214,7 @@ func registerMjRouterGroup(relayMjRouter *gin.RouterGroup) {
|
||||
relayMjRouter.POST("/submit/blend", controller.RelayMidjourney)
|
||||
relayMjRouter.POST("/submit/edits", controller.RelayMidjourney)
|
||||
relayMjRouter.POST("/submit/video", controller.RelayMidjourney)
|
||||
relayMjRouter.POST("/notify", controller.RelayMidjourney)
|
||||
//relayMjRouter.POST("/notify", controller.RelayMidjourney)
|
||||
relayMjRouter.GET("/task/:id/fetch", controller.RelayMidjourney)
|
||||
relayMjRouter.GET("/task/:id/image-seed", controller.RelayMidjourney)
|
||||
relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney)
|
||||
|
||||
@@ -610,14 +610,17 @@ func ShouldSkipRetryAfterChannelAffinityFailure(c *gin.Context) bool {
|
||||
return false
|
||||
}
|
||||
v, ok := c.Get(ginKeyChannelAffinitySkipRetry)
|
||||
if ok {
|
||||
b, ok := v.(bool)
|
||||
if ok {
|
||||
return b
|
||||
}
|
||||
}
|
||||
meta, ok := getChannelAffinityMeta(c)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
b, ok := v.(bool)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return b
|
||||
return meta.SkipRetry
|
||||
}
|
||||
|
||||
func MarkChannelAffinityUsed(c *gin.Context, selectedGroup string, channelID int) {
|
||||
|
||||
@@ -116,6 +116,66 @@ func TestApplyChannelAffinityOverrideTemplate_MergeOperations(t *testing.T) {
|
||||
require.Equal(t, "trim_prefix", secondOp["mode"])
|
||||
}
|
||||
|
||||
func TestShouldSkipRetryAfterChannelAffinityFailure(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ctx func() *gin.Context
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "nil context",
|
||||
ctx: func() *gin.Context {
|
||||
return nil
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "explicit skip retry flag in context",
|
||||
ctx: func() *gin.Context {
|
||||
ctx := buildChannelAffinityTemplateContextForTest(channelAffinityMeta{
|
||||
RuleName: "rule-explicit-flag",
|
||||
SkipRetry: false,
|
||||
UsingGroup: "default",
|
||||
ModelName: "gpt-5",
|
||||
})
|
||||
ctx.Set(ginKeyChannelAffinitySkipRetry, true)
|
||||
return ctx
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "fallback to matched rule meta",
|
||||
ctx: func() *gin.Context {
|
||||
return buildChannelAffinityTemplateContextForTest(channelAffinityMeta{
|
||||
RuleName: "rule-skip-retry",
|
||||
SkipRetry: true,
|
||||
UsingGroup: "default",
|
||||
ModelName: "gpt-5",
|
||||
})
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "no flag and no skip retry meta",
|
||||
ctx: func() *gin.Context {
|
||||
return buildChannelAffinityTemplateContextForTest(channelAffinityMeta{
|
||||
RuleName: "rule-no-skip-retry",
|
||||
SkipRetry: false,
|
||||
UsingGroup: "default",
|
||||
ModelName: "gpt-5",
|
||||
})
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, ShouldSkipRetryAfterChannelAffinityFailure(tt.ctx()))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChannelAffinityHitCodexTemplatePassHeadersEffective(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
+26
-25
@@ -223,6 +223,25 @@ func generateStopBlock(index int) *dto.ClaudeResponse {
|
||||
}
|
||||
}
|
||||
|
||||
func buildClaudeUsageFromOpenAIUsage(oaiUsage *dto.Usage) *dto.ClaudeUsage {
|
||||
if oaiUsage == nil {
|
||||
return nil
|
||||
}
|
||||
usage := &dto.ClaudeUsage{
|
||||
InputTokens: oaiUsage.PromptTokens,
|
||||
OutputTokens: oaiUsage.CompletionTokens,
|
||||
CacheCreationInputTokens: oaiUsage.PromptTokensDetails.CachedCreationTokens,
|
||||
CacheReadInputTokens: oaiUsage.PromptTokensDetails.CachedTokens,
|
||||
}
|
||||
if oaiUsage.ClaudeCacheCreation5mTokens > 0 || oaiUsage.ClaudeCacheCreation1hTokens > 0 {
|
||||
usage.CacheCreation = &dto.ClaudeCacheCreationUsage{
|
||||
Ephemeral5mInputTokens: oaiUsage.ClaudeCacheCreation5mTokens,
|
||||
Ephemeral1hInputTokens: oaiUsage.ClaudeCacheCreation1hTokens,
|
||||
}
|
||||
}
|
||||
return usage
|
||||
}
|
||||
|
||||
func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamResponse, info *relaycommon.RelayInfo) []*dto.ClaudeResponse {
|
||||
if info.ClaudeConvertInfo.Done {
|
||||
return nil
|
||||
@@ -391,13 +410,8 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
}
|
||||
if oaiUsage != nil {
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Type: "message_delta",
|
||||
Usage: &dto.ClaudeUsage{
|
||||
InputTokens: oaiUsage.PromptTokens,
|
||||
OutputTokens: oaiUsage.CompletionTokens,
|
||||
CacheCreationInputTokens: oaiUsage.PromptTokensDetails.CachedCreationTokens,
|
||||
CacheReadInputTokens: oaiUsage.PromptTokensDetails.CachedTokens,
|
||||
},
|
||||
Type: "message_delta",
|
||||
Usage: buildClaudeUsageFromOpenAIUsage(oaiUsage),
|
||||
Delta: &dto.ClaudeMediaMessage{
|
||||
StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(info.FinishReason)),
|
||||
},
|
||||
@@ -419,13 +433,8 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
oaiUsage := info.ClaudeConvertInfo.Usage
|
||||
if oaiUsage != nil {
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Type: "message_delta",
|
||||
Usage: &dto.ClaudeUsage{
|
||||
InputTokens: oaiUsage.PromptTokens,
|
||||
OutputTokens: oaiUsage.CompletionTokens,
|
||||
CacheCreationInputTokens: oaiUsage.PromptTokensDetails.CachedCreationTokens,
|
||||
CacheReadInputTokens: oaiUsage.PromptTokensDetails.CachedTokens,
|
||||
},
|
||||
Type: "message_delta",
|
||||
Usage: buildClaudeUsageFromOpenAIUsage(oaiUsage),
|
||||
Delta: &dto.ClaudeMediaMessage{
|
||||
StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(info.FinishReason)),
|
||||
},
|
||||
@@ -555,13 +564,8 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
}
|
||||
if oaiUsage != nil {
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Type: "message_delta",
|
||||
Usage: &dto.ClaudeUsage{
|
||||
InputTokens: oaiUsage.PromptTokens,
|
||||
OutputTokens: oaiUsage.CompletionTokens,
|
||||
CacheCreationInputTokens: oaiUsage.PromptTokensDetails.CachedCreationTokens,
|
||||
CacheReadInputTokens: oaiUsage.PromptTokensDetails.CachedTokens,
|
||||
},
|
||||
Type: "message_delta",
|
||||
Usage: buildClaudeUsageFromOpenAIUsage(oaiUsage),
|
||||
Delta: &dto.ClaudeMediaMessage{
|
||||
StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(info.FinishReason)),
|
||||
},
|
||||
@@ -612,10 +616,7 @@ func ResponseOpenAI2Claude(openAIResponse *dto.OpenAITextResponse, info *relayco
|
||||
}
|
||||
claudeResponse.Content = contents
|
||||
claudeResponse.StopReason = stopReason
|
||||
claudeResponse.Usage = &dto.ClaudeUsage{
|
||||
InputTokens: openAIResponse.PromptTokens,
|
||||
OutputTokens: openAIResponse.CompletionTokens,
|
||||
}
|
||||
claudeResponse.Usage = buildClaudeUsageFromOpenAIUsage(&openAIResponse.Usage)
|
||||
|
||||
return claudeResponse
|
||||
}
|
||||
|
||||
@@ -73,10 +73,47 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m
|
||||
other["admin_info"] = adminInfo
|
||||
appendRequestPath(ctx, relayInfo, other)
|
||||
appendRequestConversionChain(relayInfo, other)
|
||||
appendFinalRequestFormat(relayInfo, other)
|
||||
appendBillingInfo(relayInfo, other)
|
||||
appendParamOverrideInfo(relayInfo, other)
|
||||
appendStreamStatus(relayInfo, other)
|
||||
return other
|
||||
}
|
||||
|
||||
func appendParamOverrideInfo(relayInfo *relaycommon.RelayInfo, other map[string]interface{}) {
|
||||
if relayInfo == nil || other == nil || len(relayInfo.ParamOverrideAudit) == 0 {
|
||||
return
|
||||
}
|
||||
other["po"] = relayInfo.ParamOverrideAudit
|
||||
}
|
||||
|
||||
func appendStreamStatus(relayInfo *relaycommon.RelayInfo, other map[string]interface{}) {
|
||||
if relayInfo == nil || other == nil || !relayInfo.IsStream || relayInfo.StreamStatus == nil {
|
||||
return
|
||||
}
|
||||
ss := relayInfo.StreamStatus
|
||||
status := "ok"
|
||||
if !ss.IsNormalEnd() || ss.HasErrors() {
|
||||
status = "error"
|
||||
}
|
||||
streamInfo := map[string]interface{}{
|
||||
"status": status,
|
||||
"end_reason": string(ss.EndReason),
|
||||
}
|
||||
if ss.EndError != nil {
|
||||
streamInfo["end_error"] = ss.EndError.Error()
|
||||
}
|
||||
if ss.ErrorCount > 0 {
|
||||
streamInfo["error_count"] = ss.ErrorCount
|
||||
messages := make([]string, 0, len(ss.Errors))
|
||||
for _, e := range ss.Errors {
|
||||
messages = append(messages, e.Message)
|
||||
}
|
||||
streamInfo["errors"] = messages
|
||||
}
|
||||
other["stream_status"] = streamInfo
|
||||
}
|
||||
|
||||
func appendBillingInfo(relayInfo *relaycommon.RelayInfo, other map[string]interface{}) {
|
||||
if relayInfo == nil || other == nil {
|
||||
return
|
||||
@@ -159,6 +196,17 @@ func appendRequestConversionChain(relayInfo *relaycommon.RelayInfo, other map[st
|
||||
other["request_conversion"] = chain
|
||||
}
|
||||
|
||||
func appendFinalRequestFormat(relayInfo *relaycommon.RelayInfo, other map[string]interface{}) {
|
||||
if relayInfo == nil || other == nil {
|
||||
return
|
||||
}
|
||||
if relayInfo.GetFinalRequestRelayFormat() == types.RelayFormatClaude {
|
||||
// claude indicates the final upstream request format is Claude Messages.
|
||||
// Frontend log rendering uses this to keep the original Claude input display.
|
||||
other["claude"] = true
|
||||
}
|
||||
}
|
||||
|
||||
func GenerateWssOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice, userGroupRatio float64) map[string]interface{} {
|
||||
info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, 0, 0.0, modelPrice, userGroupRatio)
|
||||
info["ws"] = true
|
||||
|
||||
@@ -235,108 +235,6 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
|
||||
})
|
||||
}
|
||||
|
||||
func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage) {
|
||||
if usage != nil {
|
||||
ObserveChannelAffinityUsageCacheByRelayFormat(ctx, usage, relayInfo.GetFinalRequestRelayFormat())
|
||||
}
|
||||
|
||||
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
|
||||
promptTokens := usage.PromptTokens
|
||||
completionTokens := usage.CompletionTokens
|
||||
modelName := relayInfo.OriginModelName
|
||||
|
||||
tokenName := ctx.GetString("token_name")
|
||||
completionRatio := relayInfo.PriceData.CompletionRatio
|
||||
modelRatio := relayInfo.PriceData.ModelRatio
|
||||
groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
|
||||
modelPrice := relayInfo.PriceData.ModelPrice
|
||||
cacheRatio := relayInfo.PriceData.CacheRatio
|
||||
cacheTokens := usage.PromptTokensDetails.CachedTokens
|
||||
|
||||
cacheCreationRatio := relayInfo.PriceData.CacheCreationRatio
|
||||
cacheCreationRatio5m := relayInfo.PriceData.CacheCreation5mRatio
|
||||
cacheCreationRatio1h := relayInfo.PriceData.CacheCreation1hRatio
|
||||
cacheCreationTokens := usage.PromptTokensDetails.CachedCreationTokens
|
||||
cacheCreationTokens5m := usage.ClaudeCacheCreation5mTokens
|
||||
cacheCreationTokens1h := usage.ClaudeCacheCreation1hTokens
|
||||
|
||||
if relayInfo.ChannelType == constant.ChannelTypeOpenRouter {
|
||||
promptTokens -= cacheTokens
|
||||
isUsingCustomSettings := relayInfo.PriceData.UsePrice || hasCustomModelRatio(modelName, relayInfo.PriceData.ModelRatio)
|
||||
if cacheCreationTokens == 0 && relayInfo.PriceData.CacheCreationRatio != 1 && usage.Cost != 0 && !isUsingCustomSettings {
|
||||
maybeCacheCreationTokens := CalcOpenRouterCacheCreateTokens(*usage, relayInfo.PriceData)
|
||||
if maybeCacheCreationTokens >= 0 && promptTokens >= maybeCacheCreationTokens {
|
||||
cacheCreationTokens = maybeCacheCreationTokens
|
||||
}
|
||||
}
|
||||
promptTokens -= cacheCreationTokens
|
||||
}
|
||||
|
||||
calculateQuota := 0.0
|
||||
if !relayInfo.PriceData.UsePrice {
|
||||
calculateQuota = float64(promptTokens)
|
||||
calculateQuota += float64(cacheTokens) * cacheRatio
|
||||
calculateQuota += float64(cacheCreationTokens5m) * cacheCreationRatio5m
|
||||
calculateQuota += float64(cacheCreationTokens1h) * cacheCreationRatio1h
|
||||
remainingCacheCreationTokens := cacheCreationTokens - cacheCreationTokens5m - cacheCreationTokens1h
|
||||
if remainingCacheCreationTokens > 0 {
|
||||
calculateQuota += float64(remainingCacheCreationTokens) * cacheCreationRatio
|
||||
}
|
||||
calculateQuota += float64(completionTokens) * completionRatio
|
||||
calculateQuota = calculateQuota * groupRatio * modelRatio
|
||||
} else {
|
||||
calculateQuota = modelPrice * common.QuotaPerUnit * groupRatio
|
||||
}
|
||||
|
||||
if modelRatio != 0 && calculateQuota <= 0 {
|
||||
calculateQuota = 1
|
||||
}
|
||||
|
||||
quota := int(calculateQuota)
|
||||
|
||||
totalTokens := promptTokens + completionTokens
|
||||
|
||||
var logContent string
|
||||
// record all the consume log even if quota is 0
|
||||
if totalTokens == 0 {
|
||||
// in this case, must be some error happened
|
||||
// we cannot just return, because we may have to return the pre-consumed quota
|
||||
quota = 0
|
||||
logContent += fmt.Sprintf("(可能是上游出错)")
|
||||
logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
|
||||
"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota))
|
||||
} else {
|
||||
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
|
||||
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
|
||||
}
|
||||
|
||||
if err := SettleBilling(ctx, relayInfo, quota); err != nil {
|
||||
logger.LogError(ctx, "error settling billing: "+err.Error())
|
||||
}
|
||||
|
||||
other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio,
|
||||
cacheTokens, cacheRatio,
|
||||
cacheCreationTokens, cacheCreationRatio,
|
||||
cacheCreationTokens5m, cacheCreationRatio5m,
|
||||
cacheCreationTokens1h, cacheCreationRatio1h,
|
||||
modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
|
||||
ChannelId: relayInfo.ChannelId,
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
ModelName: modelName,
|
||||
TokenName: tokenName,
|
||||
Quota: quota,
|
||||
Content: logContent,
|
||||
TokenId: relayInfo.TokenId,
|
||||
UseTimeSeconds: int(useTimeSeconds),
|
||||
IsStream: relayInfo.IsStream,
|
||||
Group: relayInfo.UsingGroup,
|
||||
Other: other,
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData types.PriceData) int {
|
||||
if priceData.CacheCreationRatio == 1 {
|
||||
return 0
|
||||
|
||||
@@ -0,0 +1,430 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/shopspring/decimal"
|
||||
)
|
||||
|
||||
// textQuotaSummary is the result of pricing one text request: the token
// counts and ratios that were used, the final quota, and the per-tool
// surcharges that were applied.
type textQuotaSummary struct {
	// Token counts taken from the usage record.
	PromptTokens, CompletionTokens, TotalTokens int

	// Cache-read tokens, followed by cache-creation tokens (total, 5m, 1h).
	CacheTokens, CacheCreationTokens, CacheCreationTokens5m, CacheCreationTokens1h int

	// Image and audio input tokens.
	ImageTokens, AudioTokens int

	// Request identity and elapsed time in seconds.
	ModelName, TokenName string
	UseTimeSeconds       int64

	// Pricing ratios copied from the relay's price data.
	CompletionRatio, CacheRatio, ImageRatio, ModelRatio, GroupRatio float64

	// Fixed per-call model price, used when price-based billing is enabled.
	ModelPrice float64

	// Cache-creation ratios (generic, 5m, 1h).
	CacheCreationRatio, CacheCreationRatio5m, CacheCreationRatio1h float64

	// Final quota to charge.
	Quota int

	// IsClaudeUsageSemantic is true exactly when UsageSemantic is "anthropic".
	IsClaudeUsageSemantic bool
	UsageSemantic         string

	// Built-in tool surcharges: price and call count.
	WebSearchPrice           float64
	WebSearchCallCount       int
	ClaudeWebSearchPrice     float64
	ClaudeWebSearchCallCount int
	FileSearchPrice          float64
	FileSearchCallCount      int

	// Separately priced audio input and image-generation call.
	AudioInputPrice, ImageGenerationCallPrice float64
}
|
||||
|
||||
func cacheWriteTokensTotal(summary textQuotaSummary) int {
|
||||
if summary.CacheCreationTokens5m > 0 || summary.CacheCreationTokens1h > 0 {
|
||||
splitCacheWriteTokens := summary.CacheCreationTokens5m + summary.CacheCreationTokens1h
|
||||
if summary.CacheCreationTokens > splitCacheWriteTokens {
|
||||
return summary.CacheCreationTokens
|
||||
}
|
||||
return splitCacheWriteTokens
|
||||
}
|
||||
return summary.CacheCreationTokens
|
||||
}
|
||||
|
||||
func isLegacyClaudeDerivedOpenAIUsage(relayInfo *relaycommon.RelayInfo, usage *dto.Usage) bool {
|
||||
if relayInfo == nil || usage == nil {
|
||||
return false
|
||||
}
|
||||
if relayInfo.GetFinalRequestRelayFormat() == types.RelayFormatClaude {
|
||||
return false
|
||||
}
|
||||
if usage.UsageSource != "" || usage.UsageSemantic != "" {
|
||||
return false
|
||||
}
|
||||
return usage.ClaudeCacheCreation5mTokens > 0 || usage.ClaudeCacheCreation1hTokens > 0
|
||||
}
|
||||
|
||||
func calculateTextQuotaSummary(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage) textQuotaSummary {
|
||||
summary := textQuotaSummary{
|
||||
ModelName: relayInfo.OriginModelName,
|
||||
TokenName: ctx.GetString("token_name"),
|
||||
UseTimeSeconds: time.Now().Unix() - relayInfo.StartTime.Unix(),
|
||||
CompletionRatio: relayInfo.PriceData.CompletionRatio,
|
||||
CacheRatio: relayInfo.PriceData.CacheRatio,
|
||||
ImageRatio: relayInfo.PriceData.ImageRatio,
|
||||
ModelRatio: relayInfo.PriceData.ModelRatio,
|
||||
GroupRatio: relayInfo.PriceData.GroupRatioInfo.GroupRatio,
|
||||
ModelPrice: relayInfo.PriceData.ModelPrice,
|
||||
CacheCreationRatio: relayInfo.PriceData.CacheCreationRatio,
|
||||
CacheCreationRatio5m: relayInfo.PriceData.CacheCreation5mRatio,
|
||||
CacheCreationRatio1h: relayInfo.PriceData.CacheCreation1hRatio,
|
||||
UsageSemantic: usageSemanticFromUsage(relayInfo, usage),
|
||||
}
|
||||
summary.IsClaudeUsageSemantic = summary.UsageSemantic == "anthropic"
|
||||
|
||||
if usage == nil {
|
||||
usage = &dto.Usage{
|
||||
PromptTokens: relayInfo.GetEstimatePromptTokens(),
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: relayInfo.GetEstimatePromptTokens(),
|
||||
}
|
||||
}
|
||||
|
||||
summary.PromptTokens = usage.PromptTokens
|
||||
summary.CompletionTokens = usage.CompletionTokens
|
||||
summary.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
summary.CacheTokens = usage.PromptTokensDetails.CachedTokens
|
||||
summary.CacheCreationTokens = usage.PromptTokensDetails.CachedCreationTokens
|
||||
summary.CacheCreationTokens5m = usage.ClaudeCacheCreation5mTokens
|
||||
summary.CacheCreationTokens1h = usage.ClaudeCacheCreation1hTokens
|
||||
summary.ImageTokens = usage.PromptTokensDetails.ImageTokens
|
||||
summary.AudioTokens = usage.PromptTokensDetails.AudioTokens
|
||||
legacyClaudeDerived := isLegacyClaudeDerivedOpenAIUsage(relayInfo, usage)
|
||||
isOpenRouterClaudeBilling := relayInfo.ChannelMeta != nil &&
|
||||
relayInfo.ChannelType == constant.ChannelTypeOpenRouter &&
|
||||
summary.IsClaudeUsageSemantic
|
||||
|
||||
if isOpenRouterClaudeBilling {
|
||||
summary.PromptTokens -= summary.CacheTokens
|
||||
isUsingCustomSettings := relayInfo.PriceData.UsePrice || hasCustomModelRatio(summary.ModelName, relayInfo.PriceData.ModelRatio)
|
||||
if summary.CacheCreationTokens == 0 && relayInfo.PriceData.CacheCreationRatio != 1 && usage.Cost != 0 && !isUsingCustomSettings {
|
||||
maybeCacheCreationTokens := CalcOpenRouterCacheCreateTokens(*usage, relayInfo.PriceData)
|
||||
if maybeCacheCreationTokens >= 0 && summary.PromptTokens >= maybeCacheCreationTokens {
|
||||
summary.CacheCreationTokens = maybeCacheCreationTokens
|
||||
}
|
||||
}
|
||||
summary.PromptTokens -= summary.CacheCreationTokens
|
||||
}
|
||||
|
||||
dPromptTokens := decimal.NewFromInt(int64(summary.PromptTokens))
|
||||
dCacheTokens := decimal.NewFromInt(int64(summary.CacheTokens))
|
||||
dImageTokens := decimal.NewFromInt(int64(summary.ImageTokens))
|
||||
dAudioTokens := decimal.NewFromInt(int64(summary.AudioTokens))
|
||||
dCompletionTokens := decimal.NewFromInt(int64(summary.CompletionTokens))
|
||||
dCachedCreationTokens := decimal.NewFromInt(int64(summary.CacheCreationTokens))
|
||||
dCompletionRatio := decimal.NewFromFloat(summary.CompletionRatio)
|
||||
dCacheRatio := decimal.NewFromFloat(summary.CacheRatio)
|
||||
dImageRatio := decimal.NewFromFloat(summary.ImageRatio)
|
||||
dModelRatio := decimal.NewFromFloat(summary.ModelRatio)
|
||||
dGroupRatio := decimal.NewFromFloat(summary.GroupRatio)
|
||||
dModelPrice := decimal.NewFromFloat(summary.ModelPrice)
|
||||
dCacheCreationRatio := decimal.NewFromFloat(summary.CacheCreationRatio)
|
||||
dCacheCreationRatio5m := decimal.NewFromFloat(summary.CacheCreationRatio5m)
|
||||
dCacheCreationRatio1h := decimal.NewFromFloat(summary.CacheCreationRatio1h)
|
||||
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
|
||||
|
||||
ratio := dModelRatio.Mul(dGroupRatio)
|
||||
|
||||
var dWebSearchQuota decimal.Decimal
|
||||
if relayInfo.ResponsesUsageInfo != nil {
|
||||
if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool.CallCount > 0 {
|
||||
summary.WebSearchCallCount = webSearchTool.CallCount
|
||||
summary.WebSearchPrice = operation_setting.GetWebSearchPricePerThousand(summary.ModelName, webSearchTool.SearchContextSize)
|
||||
dWebSearchQuota = decimal.NewFromFloat(summary.WebSearchPrice).
|
||||
Mul(decimal.NewFromInt(int64(webSearchTool.CallCount))).
|
||||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
}
|
||||
} else if strings.HasSuffix(summary.ModelName, "search-preview") {
|
||||
searchContextSize := ctx.GetString("chat_completion_web_search_context_size")
|
||||
if searchContextSize == "" {
|
||||
searchContextSize = "medium"
|
||||
}
|
||||
summary.WebSearchCallCount = 1
|
||||
summary.WebSearchPrice = operation_setting.GetWebSearchPricePerThousand(summary.ModelName, searchContextSize)
|
||||
dWebSearchQuota = decimal.NewFromFloat(summary.WebSearchPrice).
|
||||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
}
|
||||
|
||||
var dClaudeWebSearchQuota decimal.Decimal
|
||||
summary.ClaudeWebSearchCallCount = ctx.GetInt("claude_web_search_requests")
|
||||
if summary.ClaudeWebSearchCallCount > 0 {
|
||||
summary.ClaudeWebSearchPrice = operation_setting.GetClaudeWebSearchPricePerThousand()
|
||||
dClaudeWebSearchQuota = decimal.NewFromFloat(summary.ClaudeWebSearchPrice).
|
||||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit).
|
||||
Mul(decimal.NewFromInt(int64(summary.ClaudeWebSearchCallCount)))
|
||||
}
|
||||
|
||||
var dFileSearchQuota decimal.Decimal
|
||||
if relayInfo.ResponsesUsageInfo != nil {
|
||||
if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists && fileSearchTool.CallCount > 0 {
|
||||
summary.FileSearchCallCount = fileSearchTool.CallCount
|
||||
summary.FileSearchPrice = operation_setting.GetFileSearchPricePerThousand()
|
||||
dFileSearchQuota = decimal.NewFromFloat(summary.FileSearchPrice).
|
||||
Mul(decimal.NewFromInt(int64(fileSearchTool.CallCount))).
|
||||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
}
|
||||
}
|
||||
|
||||
var dImageGenerationCallQuota decimal.Decimal
|
||||
if ctx.GetBool("image_generation_call") {
|
||||
summary.ImageGenerationCallPrice = operation_setting.GetGPTImage1PriceOnceCall(ctx.GetString("image_generation_call_quality"), ctx.GetString("image_generation_call_size"))
|
||||
dImageGenerationCallQuota = decimal.NewFromFloat(summary.ImageGenerationCallPrice).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
}
|
||||
|
||||
var audioInputQuota decimal.Decimal
|
||||
if !relayInfo.PriceData.UsePrice {
|
||||
baseTokens := dPromptTokens
|
||||
|
||||
var cachedTokensWithRatio decimal.Decimal
|
||||
if !dCacheTokens.IsZero() {
|
||||
if !summary.IsClaudeUsageSemantic && !legacyClaudeDerived {
|
||||
baseTokens = baseTokens.Sub(dCacheTokens)
|
||||
}
|
||||
cachedTokensWithRatio = dCacheTokens.Mul(dCacheRatio)
|
||||
}
|
||||
|
||||
var cachedCreationTokensWithRatio decimal.Decimal
|
||||
hasSplitCacheCreationTokens := summary.CacheCreationTokens5m > 0 || summary.CacheCreationTokens1h > 0
|
||||
if !dCachedCreationTokens.IsZero() || hasSplitCacheCreationTokens {
|
||||
if !summary.IsClaudeUsageSemantic && !legacyClaudeDerived {
|
||||
baseTokens = baseTokens.Sub(dCachedCreationTokens)
|
||||
cachedCreationTokensWithRatio = dCachedCreationTokens.Mul(dCacheCreationRatio)
|
||||
} else {
|
||||
remaining := summary.CacheCreationTokens - summary.CacheCreationTokens5m - summary.CacheCreationTokens1h
|
||||
if remaining < 0 {
|
||||
remaining = 0
|
||||
}
|
||||
cachedCreationTokensWithRatio = decimal.NewFromInt(int64(remaining)).Mul(dCacheCreationRatio)
|
||||
cachedCreationTokensWithRatio = cachedCreationTokensWithRatio.Add(decimal.NewFromInt(int64(summary.CacheCreationTokens5m)).Mul(dCacheCreationRatio5m))
|
||||
cachedCreationTokensWithRatio = cachedCreationTokensWithRatio.Add(decimal.NewFromInt(int64(summary.CacheCreationTokens1h)).Mul(dCacheCreationRatio1h))
|
||||
}
|
||||
}
|
||||
|
||||
var imageTokensWithRatio decimal.Decimal
|
||||
if !dImageTokens.IsZero() {
|
||||
baseTokens = baseTokens.Sub(dImageTokens)
|
||||
imageTokensWithRatio = dImageTokens.Mul(dImageRatio)
|
||||
}
|
||||
|
||||
if !dAudioTokens.IsZero() {
|
||||
summary.AudioInputPrice = operation_setting.GetGeminiInputAudioPricePerMillionTokens(summary.ModelName)
|
||||
if summary.AudioInputPrice > 0 {
|
||||
baseTokens = baseTokens.Sub(dAudioTokens)
|
||||
audioInputQuota = decimal.NewFromFloat(summary.AudioInputPrice).
|
||||
Div(decimal.NewFromInt(1000000)).Mul(dAudioTokens).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
}
|
||||
}
|
||||
|
||||
promptQuota := baseTokens.Add(cachedTokensWithRatio).Add(imageTokensWithRatio).Add(cachedCreationTokensWithRatio)
|
||||
completionQuota := dCompletionTokens.Mul(dCompletionRatio)
|
||||
quotaCalculateDecimal := promptQuota.Add(completionQuota).Mul(ratio)
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dWebSearchQuota)
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dClaudeWebSearchQuota)
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota)
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota)
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dImageGenerationCallQuota)
|
||||
|
||||
if len(relayInfo.PriceData.OtherRatios) > 0 {
|
||||
for _, otherRatio := range relayInfo.PriceData.OtherRatios {
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Mul(decimal.NewFromFloat(otherRatio))
|
||||
}
|
||||
}
|
||||
|
||||
if !ratio.IsZero() && quotaCalculateDecimal.LessThanOrEqual(decimal.Zero) {
|
||||
quotaCalculateDecimal = decimal.NewFromInt(1)
|
||||
}
|
||||
summary.Quota = int(quotaCalculateDecimal.Round(0).IntPart())
|
||||
} else {
|
||||
quotaCalculateDecimal := dModelPrice.Mul(dQuotaPerUnit).Mul(dGroupRatio)
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dWebSearchQuota)
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dClaudeWebSearchQuota)
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota)
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota)
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dImageGenerationCallQuota)
|
||||
if len(relayInfo.PriceData.OtherRatios) > 0 {
|
||||
for _, otherRatio := range relayInfo.PriceData.OtherRatios {
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Mul(decimal.NewFromFloat(otherRatio))
|
||||
}
|
||||
}
|
||||
summary.Quota = int(quotaCalculateDecimal.Round(0).IntPart())
|
||||
}
|
||||
|
||||
if summary.TotalTokens == 0 {
|
||||
summary.Quota = 0
|
||||
} else if !ratio.IsZero() && summary.Quota == 0 {
|
||||
summary.Quota = 1
|
||||
}
|
||||
|
||||
return summary
|
||||
}
|
||||
|
||||
func usageSemanticFromUsage(relayInfo *relaycommon.RelayInfo, usage *dto.Usage) string {
|
||||
if usage != nil && usage.UsageSemantic != "" {
|
||||
return usage.UsageSemantic
|
||||
}
|
||||
if relayInfo != nil && relayInfo.GetFinalRequestRelayFormat() == types.RelayFormatClaude {
|
||||
return "anthropic"
|
||||
}
|
||||
return "openai"
|
||||
}
|
||||
|
||||
func PostTextConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent []string) {
|
||||
originUsage := usage
|
||||
if usage == nil {
|
||||
extraContent = append(extraContent, "上游无计费信息")
|
||||
}
|
||||
if originUsage != nil {
|
||||
ObserveChannelAffinityUsageCacheByRelayFormat(ctx, usage, relayInfo.GetFinalRequestRelayFormat())
|
||||
}
|
||||
|
||||
adminRejectReason := common.GetContextKeyString(ctx, constant.ContextKeyAdminRejectReason)
|
||||
summary := calculateTextQuotaSummary(ctx, relayInfo, usage)
|
||||
|
||||
if summary.WebSearchCallCount > 0 {
|
||||
extraContent = append(extraContent, fmt.Sprintf("Web Search 调用 %d 次,调用花费 %s", summary.WebSearchCallCount, decimal.NewFromFloat(summary.WebSearchPrice).Mul(decimal.NewFromInt(int64(summary.WebSearchCallCount))).Div(decimal.NewFromInt(1000)).Mul(decimal.NewFromFloat(summary.GroupRatio)).Mul(decimal.NewFromFloat(common.QuotaPerUnit)).String()))
|
||||
}
|
||||
if summary.ClaudeWebSearchCallCount > 0 {
|
||||
extraContent = append(extraContent, fmt.Sprintf("Claude Web Search 调用 %d 次,调用花费 %s", summary.ClaudeWebSearchCallCount, decimal.NewFromFloat(summary.ClaudeWebSearchPrice).Div(decimal.NewFromInt(1000)).Mul(decimal.NewFromFloat(summary.GroupRatio)).Mul(decimal.NewFromFloat(common.QuotaPerUnit)).Mul(decimal.NewFromInt(int64(summary.ClaudeWebSearchCallCount))).String()))
|
||||
}
|
||||
if summary.FileSearchCallCount > 0 {
|
||||
extraContent = append(extraContent, fmt.Sprintf("File Search 调用 %d 次,调用花费 %s", summary.FileSearchCallCount, decimal.NewFromFloat(summary.FileSearchPrice).Mul(decimal.NewFromInt(int64(summary.FileSearchCallCount))).Div(decimal.NewFromInt(1000)).Mul(decimal.NewFromFloat(summary.GroupRatio)).Mul(decimal.NewFromFloat(common.QuotaPerUnit)).String()))
|
||||
}
|
||||
if summary.AudioInputPrice > 0 && summary.AudioTokens > 0 {
|
||||
extraContent = append(extraContent, fmt.Sprintf("Audio Input 花费 %s", decimal.NewFromFloat(summary.AudioInputPrice).Div(decimal.NewFromInt(1000000)).Mul(decimal.NewFromInt(int64(summary.AudioTokens))).Mul(decimal.NewFromFloat(summary.GroupRatio)).Mul(decimal.NewFromFloat(common.QuotaPerUnit)).String()))
|
||||
}
|
||||
if summary.ImageGenerationCallPrice > 0 {
|
||||
extraContent = append(extraContent, fmt.Sprintf("Image Generation Call 花费 %s", decimal.NewFromFloat(summary.ImageGenerationCallPrice).Mul(decimal.NewFromFloat(summary.GroupRatio)).Mul(decimal.NewFromFloat(common.QuotaPerUnit)).String()))
|
||||
}
|
||||
|
||||
if summary.TotalTokens == 0 {
|
||||
extraContent = append(extraContent, "上游没有返回计费信息,无法扣费(可能是上游超时)")
|
||||
logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, summary.ModelName, relayInfo.FinalPreConsumedQuota))
|
||||
} else {
|
||||
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, summary.Quota)
|
||||
model.UpdateChannelUsedQuota(relayInfo.ChannelId, summary.Quota)
|
||||
}
|
||||
|
||||
if err := SettleBilling(ctx, relayInfo, summary.Quota); err != nil {
|
||||
logger.LogError(ctx, "error settling billing: "+err.Error())
|
||||
}
|
||||
|
||||
logModel := summary.ModelName
|
||||
if strings.HasPrefix(logModel, "gpt-4-gizmo") {
|
||||
logModel = "gpt-4-gizmo-*"
|
||||
extraContent = append(extraContent, fmt.Sprintf("模型 %s", summary.ModelName))
|
||||
}
|
||||
if strings.HasPrefix(logModel, "gpt-4o-gizmo") {
|
||||
logModel = "gpt-4o-gizmo-*"
|
||||
extraContent = append(extraContent, fmt.Sprintf("模型 %s", summary.ModelName))
|
||||
}
|
||||
|
||||
logContent := strings.Join(extraContent, ", ")
|
||||
var other map[string]interface{}
|
||||
if summary.IsClaudeUsageSemantic {
|
||||
other = GenerateClaudeOtherInfo(ctx, relayInfo,
|
||||
summary.ModelRatio, summary.GroupRatio, summary.CompletionRatio,
|
||||
summary.CacheTokens, summary.CacheRatio,
|
||||
summary.CacheCreationTokens, summary.CacheCreationRatio,
|
||||
summary.CacheCreationTokens5m, summary.CacheCreationRatio5m,
|
||||
summary.CacheCreationTokens1h, summary.CacheCreationRatio1h,
|
||||
summary.ModelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
other["usage_semantic"] = "anthropic"
|
||||
} else {
|
||||
other = GenerateTextOtherInfo(ctx, relayInfo, summary.ModelRatio, summary.GroupRatio, summary.CompletionRatio, summary.CacheTokens, summary.CacheRatio, summary.ModelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
}
|
||||
if adminRejectReason != "" {
|
||||
other["reject_reason"] = adminRejectReason
|
||||
}
|
||||
if summary.ImageTokens != 0 {
|
||||
other["image"] = true
|
||||
other["image_ratio"] = summary.ImageRatio
|
||||
other["image_output"] = summary.ImageTokens
|
||||
}
|
||||
if summary.WebSearchCallCount > 0 {
|
||||
other["web_search"] = true
|
||||
other["web_search_call_count"] = summary.WebSearchCallCount
|
||||
other["web_search_price"] = summary.WebSearchPrice
|
||||
} else if summary.ClaudeWebSearchCallCount > 0 {
|
||||
other["web_search"] = true
|
||||
other["web_search_call_count"] = summary.ClaudeWebSearchCallCount
|
||||
other["web_search_price"] = summary.ClaudeWebSearchPrice
|
||||
}
|
||||
if summary.FileSearchCallCount > 0 {
|
||||
other["file_search"] = true
|
||||
other["file_search_call_count"] = summary.FileSearchCallCount
|
||||
other["file_search_price"] = summary.FileSearchPrice
|
||||
}
|
||||
if summary.AudioInputPrice > 0 && summary.AudioTokens > 0 {
|
||||
other["audio_input_seperate_price"] = true
|
||||
other["audio_input_token_count"] = summary.AudioTokens
|
||||
other["audio_input_price"] = summary.AudioInputPrice
|
||||
}
|
||||
if summary.ImageGenerationCallPrice > 0 {
|
||||
other["image_generation_call"] = true
|
||||
other["image_generation_call_price"] = summary.ImageGenerationCallPrice
|
||||
}
|
||||
if summary.CacheCreationTokens > 0 {
|
||||
other["cache_creation_tokens"] = summary.CacheCreationTokens
|
||||
other["cache_creation_ratio"] = summary.CacheCreationRatio
|
||||
}
|
||||
if summary.CacheCreationTokens5m > 0 {
|
||||
other["cache_creation_tokens_5m"] = summary.CacheCreationTokens5m
|
||||
other["cache_creation_ratio_5m"] = summary.CacheCreationRatio5m
|
||||
}
|
||||
if summary.CacheCreationTokens1h > 0 {
|
||||
other["cache_creation_tokens_1h"] = summary.CacheCreationTokens1h
|
||||
other["cache_creation_ratio_1h"] = summary.CacheCreationRatio1h
|
||||
}
|
||||
cacheWriteTokens := cacheWriteTokensTotal(summary)
|
||||
if cacheWriteTokens > 0 {
|
||||
// cache_write_tokens: normalized cache creation total for UI display.
|
||||
// If split 5m/1h values are present, this is their sum; otherwise it falls back
|
||||
// to cache_creation_tokens.
|
||||
other["cache_write_tokens"] = cacheWriteTokens
|
||||
}
|
||||
if relayInfo.GetFinalRequestRelayFormat() != types.RelayFormatClaude && usage != nil && usage.UsageSource != "" && usage.InputTokens > 0 {
|
||||
// input_tokens_total: explicit normalized total input used by the usage log UI.
|
||||
// Only write this field when upstream/current conversion has already provided a
|
||||
// reliable total input value and tagged the usage source. Do not infer it from
|
||||
// prompt/cache fields here, otherwise old upstream payloads may be double-counted.
|
||||
other["input_tokens_total"] = usage.InputTokens
|
||||
}
|
||||
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
|
||||
ChannelId: relayInfo.ChannelId,
|
||||
PromptTokens: summary.PromptTokens,
|
||||
CompletionTokens: summary.CompletionTokens,
|
||||
ModelName: logModel,
|
||||
TokenName: summary.TokenName,
|
||||
Quota: summary.Quota,
|
||||
Content: logContent,
|
||||
TokenId: relayInfo.TokenId,
|
||||
UseTimeSeconds: int(summary.UseTimeSeconds),
|
||||
IsStream: relayInfo.IsStream,
|
||||
Group: relayInfo.UsingGroup,
|
||||
Other: other,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,318 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCalculateTextQuotaSummaryUnifiedForClaudeSemantic(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(w)
|
||||
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: 1000,
|
||||
CompletionTokens: 200,
|
||||
PromptTokensDetails: dto.InputTokenDetails{
|
||||
CachedTokens: 100,
|
||||
CachedCreationTokens: 50,
|
||||
},
|
||||
ClaudeCacheCreation5mTokens: 10,
|
||||
ClaudeCacheCreation1hTokens: 20,
|
||||
}
|
||||
|
||||
priceData := types.PriceData{
|
||||
ModelRatio: 1,
|
||||
CompletionRatio: 2,
|
||||
CacheRatio: 0.1,
|
||||
CacheCreationRatio: 1.25,
|
||||
CacheCreation5mRatio: 1.25,
|
||||
CacheCreation1hRatio: 2,
|
||||
GroupRatioInfo: types.GroupRatioInfo{
|
||||
GroupRatio: 1,
|
||||
},
|
||||
}
|
||||
|
||||
chatRelayInfo := &relaycommon.RelayInfo{
|
||||
RelayFormat: types.RelayFormatOpenAI,
|
||||
FinalRequestRelayFormat: types.RelayFormatClaude,
|
||||
OriginModelName: "claude-3-7-sonnet",
|
||||
PriceData: priceData,
|
||||
StartTime: time.Now(),
|
||||
}
|
||||
messageRelayInfo := &relaycommon.RelayInfo{
|
||||
RelayFormat: types.RelayFormatClaude,
|
||||
FinalRequestRelayFormat: types.RelayFormatClaude,
|
||||
OriginModelName: "claude-3-7-sonnet",
|
||||
PriceData: priceData,
|
||||
StartTime: time.Now(),
|
||||
}
|
||||
|
||||
chatSummary := calculateTextQuotaSummary(ctx, chatRelayInfo, usage)
|
||||
messageSummary := calculateTextQuotaSummary(ctx, messageRelayInfo, usage)
|
||||
|
||||
require.Equal(t, messageSummary.Quota, chatSummary.Quota)
|
||||
require.Equal(t, messageSummary.CacheCreationTokens5m, chatSummary.CacheCreationTokens5m)
|
||||
require.Equal(t, messageSummary.CacheCreationTokens1h, chatSummary.CacheCreationTokens1h)
|
||||
require.True(t, chatSummary.IsClaudeUsageSemantic)
|
||||
require.Equal(t, 1488, chatSummary.Quota)
|
||||
}
|
||||
|
||||
func TestCalculateTextQuotaSummaryUsesSplitClaudeCacheCreationRatios(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(w)
|
||||
|
||||
relayInfo := &relaycommon.RelayInfo{
|
||||
RelayFormat: types.RelayFormatOpenAI,
|
||||
FinalRequestRelayFormat: types.RelayFormatClaude,
|
||||
OriginModelName: "claude-3-7-sonnet",
|
||||
PriceData: types.PriceData{
|
||||
ModelRatio: 1,
|
||||
CompletionRatio: 1,
|
||||
CacheRatio: 0,
|
||||
CacheCreationRatio: 1,
|
||||
CacheCreation5mRatio: 2,
|
||||
CacheCreation1hRatio: 3,
|
||||
GroupRatioInfo: types.GroupRatioInfo{
|
||||
GroupRatio: 1,
|
||||
},
|
||||
},
|
||||
StartTime: time.Now(),
|
||||
}
|
||||
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: 100,
|
||||
CompletionTokens: 0,
|
||||
PromptTokensDetails: dto.InputTokenDetails{
|
||||
CachedCreationTokens: 10,
|
||||
},
|
||||
ClaudeCacheCreation5mTokens: 2,
|
||||
ClaudeCacheCreation1hTokens: 3,
|
||||
}
|
||||
|
||||
summary := calculateTextQuotaSummary(ctx, relayInfo, usage)
|
||||
|
||||
// 100 + remaining(5)*1 + 2*2 + 3*3 = 118
|
||||
require.Equal(t, 118, summary.Quota)
|
||||
}
|
||||
|
||||
func TestCalculateTextQuotaSummaryUsesAnthropicUsageSemanticFromUpstreamUsage(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(w)
|
||||
|
||||
relayInfo := &relaycommon.RelayInfo{
|
||||
RelayFormat: types.RelayFormatOpenAI,
|
||||
OriginModelName: "claude-3-7-sonnet",
|
||||
PriceData: types.PriceData{
|
||||
ModelRatio: 1,
|
||||
CompletionRatio: 2,
|
||||
CacheRatio: 0.1,
|
||||
CacheCreationRatio: 1.25,
|
||||
CacheCreation5mRatio: 1.25,
|
||||
CacheCreation1hRatio: 2,
|
||||
GroupRatioInfo: types.GroupRatioInfo{
|
||||
GroupRatio: 1,
|
||||
},
|
||||
},
|
||||
StartTime: time.Now(),
|
||||
}
|
||||
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: 1000,
|
||||
CompletionTokens: 200,
|
||||
UsageSemantic: "anthropic",
|
||||
PromptTokensDetails: dto.InputTokenDetails{
|
||||
CachedTokens: 100,
|
||||
CachedCreationTokens: 50,
|
||||
},
|
||||
ClaudeCacheCreation5mTokens: 10,
|
||||
ClaudeCacheCreation1hTokens: 20,
|
||||
}
|
||||
|
||||
summary := calculateTextQuotaSummary(ctx, relayInfo, usage)
|
||||
|
||||
require.True(t, summary.IsClaudeUsageSemantic)
|
||||
require.Equal(t, "anthropic", summary.UsageSemantic)
|
||||
require.Equal(t, 1488, summary.Quota)
|
||||
}
|
||||
|
||||
func TestCacheWriteTokensTotal(t *testing.T) {
|
||||
t.Run("split cache creation", func(t *testing.T) {
|
||||
summary := textQuotaSummary{
|
||||
CacheCreationTokens: 50,
|
||||
CacheCreationTokens5m: 10,
|
||||
CacheCreationTokens1h: 20,
|
||||
}
|
||||
require.Equal(t, 50, cacheWriteTokensTotal(summary))
|
||||
})
|
||||
|
||||
t.Run("legacy cache creation", func(t *testing.T) {
|
||||
summary := textQuotaSummary{CacheCreationTokens: 50}
|
||||
require.Equal(t, 50, cacheWriteTokensTotal(summary))
|
||||
})
|
||||
|
||||
t.Run("split cache creation without aggregate remainder", func(t *testing.T) {
|
||||
summary := textQuotaSummary{
|
||||
CacheCreationTokens5m: 10,
|
||||
CacheCreationTokens1h: 20,
|
||||
}
|
||||
require.Equal(t, 30, cacheWriteTokensTotal(summary))
|
||||
})
|
||||
}
|
||||
|
||||
func TestCalculateTextQuotaSummaryHandlesLegacyClaudeDerivedOpenAIUsage(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(w)
|
||||
|
||||
relayInfo := &relaycommon.RelayInfo{
|
||||
RelayFormat: types.RelayFormatOpenAI,
|
||||
OriginModelName: "claude-3-7-sonnet",
|
||||
PriceData: types.PriceData{
|
||||
ModelRatio: 1,
|
||||
CompletionRatio: 5,
|
||||
CacheRatio: 0.1,
|
||||
CacheCreationRatio: 1.25,
|
||||
CacheCreation5mRatio: 1.25,
|
||||
CacheCreation1hRatio: 2,
|
||||
GroupRatioInfo: types.GroupRatioInfo{GroupRatio: 1},
|
||||
},
|
||||
StartTime: time.Now(),
|
||||
}
|
||||
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: 62,
|
||||
CompletionTokens: 95,
|
||||
PromptTokensDetails: dto.InputTokenDetails{
|
||||
CachedTokens: 3544,
|
||||
},
|
||||
ClaudeCacheCreation5mTokens: 586,
|
||||
}
|
||||
|
||||
summary := calculateTextQuotaSummary(ctx, relayInfo, usage)
|
||||
|
||||
// 62 + 3544*0.1 + 586*1.25 + 95*5 = 1624.9 => 1624
|
||||
require.Equal(t, 1624, summary.Quota)
|
||||
}
|
||||
|
||||
func TestCalculateTextQuotaSummarySeparatesOpenRouterCacheReadFromPromptBilling(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(w)
|
||||
|
||||
relayInfo := &relaycommon.RelayInfo{
|
||||
OriginModelName: "openai/gpt-4.1",
|
||||
ChannelMeta: &relaycommon.ChannelMeta{
|
||||
ChannelType: constant.ChannelTypeOpenRouter,
|
||||
},
|
||||
PriceData: types.PriceData{
|
||||
ModelRatio: 1,
|
||||
CompletionRatio: 1,
|
||||
CacheRatio: 0.1,
|
||||
CacheCreationRatio: 1.25,
|
||||
GroupRatioInfo: types.GroupRatioInfo{GroupRatio: 1},
|
||||
},
|
||||
StartTime: time.Now(),
|
||||
}
|
||||
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: 2604,
|
||||
CompletionTokens: 383,
|
||||
PromptTokensDetails: dto.InputTokenDetails{
|
||||
CachedTokens: 2432,
|
||||
},
|
||||
}
|
||||
|
||||
summary := calculateTextQuotaSummary(ctx, relayInfo, usage)
|
||||
|
||||
// OpenRouter OpenAI-format display keeps prompt_tokens as total input,
|
||||
// but billing still separates normal input from cache read tokens.
|
||||
// quota = (2604 - 2432) + 2432*0.1 + 383 = 798.2 => 798
|
||||
require.Equal(t, 2604, summary.PromptTokens)
|
||||
require.Equal(t, 798, summary.Quota)
|
||||
}
|
||||
|
||||
func TestCalculateTextQuotaSummarySeparatesOpenRouterCacheCreationFromPromptBilling(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(w)
|
||||
|
||||
relayInfo := &relaycommon.RelayInfo{
|
||||
OriginModelName: "openai/gpt-4.1",
|
||||
ChannelMeta: &relaycommon.ChannelMeta{
|
||||
ChannelType: constant.ChannelTypeOpenRouter,
|
||||
},
|
||||
PriceData: types.PriceData{
|
||||
ModelRatio: 1,
|
||||
CompletionRatio: 1,
|
||||
CacheCreationRatio: 1.25,
|
||||
GroupRatioInfo: types.GroupRatioInfo{GroupRatio: 1},
|
||||
},
|
||||
StartTime: time.Now(),
|
||||
}
|
||||
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: 2604,
|
||||
CompletionTokens: 383,
|
||||
PromptTokensDetails: dto.InputTokenDetails{
|
||||
CachedCreationTokens: 100,
|
||||
},
|
||||
}
|
||||
|
||||
summary := calculateTextQuotaSummary(ctx, relayInfo, usage)
|
||||
|
||||
// prompt_tokens is still logged as total input, but cache creation is billed separately.
|
||||
// quota = (2604 - 100) + 100*1.25 + 383 = 3012
|
||||
require.Equal(t, 2604, summary.PromptTokens)
|
||||
require.Equal(t, 3012, summary.Quota)
|
||||
}
|
||||
|
||||
func TestCalculateTextQuotaSummaryKeepsPrePRClaudeOpenRouterBilling(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(w)
|
||||
|
||||
relayInfo := &relaycommon.RelayInfo{
|
||||
FinalRequestRelayFormat: types.RelayFormatClaude,
|
||||
OriginModelName: "anthropic/claude-3.7-sonnet",
|
||||
ChannelMeta: &relaycommon.ChannelMeta{
|
||||
ChannelType: constant.ChannelTypeOpenRouter,
|
||||
},
|
||||
PriceData: types.PriceData{
|
||||
ModelRatio: 1,
|
||||
CompletionRatio: 1,
|
||||
CacheRatio: 0.1,
|
||||
CacheCreationRatio: 1.25,
|
||||
GroupRatioInfo: types.GroupRatioInfo{GroupRatio: 1},
|
||||
},
|
||||
StartTime: time.Now(),
|
||||
}
|
||||
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: 2604,
|
||||
CompletionTokens: 383,
|
||||
PromptTokensDetails: dto.InputTokenDetails{
|
||||
CachedTokens: 2432,
|
||||
},
|
||||
}
|
||||
|
||||
summary := calculateTextQuotaSummary(ctx, relayInfo, usage)
|
||||
|
||||
// Pre-PR PostClaudeConsumeQuota behavior for OpenRouter:
|
||||
// prompt = 2604 - 2432 = 172
|
||||
// quota = 172 + 2432*0.1 + 383 = 798.2 => 798
|
||||
require.True(t, summary.IsClaudeUsageSemantic)
|
||||
require.Equal(t, 172, summary.PromptTokens)
|
||||
require.Equal(t, 798, summary.Quota)
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package model_setting
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/setting/config"
|
||||
)
|
||||
@@ -50,23 +51,36 @@ func GetClaudeSettings() *ClaudeSettings {
|
||||
func (c *ClaudeSettings) WriteHeaders(originModel string, httpHeader *http.Header) {
|
||||
if headers, ok := c.HeadersSettings[originModel]; ok {
|
||||
for headerKey, headerValues := range headers {
|
||||
// get existing values for this header key
|
||||
existingValues := httpHeader.Values(headerKey)
|
||||
existingValuesMap := make(map[string]bool)
|
||||
for _, v := range existingValues {
|
||||
existingValuesMap[v] = true
|
||||
}
|
||||
|
||||
// add only values that don't already exist
|
||||
for _, headerValue := range headerValues {
|
||||
if !existingValuesMap[headerValue] {
|
||||
httpHeader.Add(headerKey, headerValue)
|
||||
}
|
||||
mergedValues := normalizeHeaderListValues(
|
||||
append(append([]string(nil), httpHeader.Values(headerKey)...), headerValues...),
|
||||
)
|
||||
if len(mergedValues) == 0 {
|
||||
continue
|
||||
}
|
||||
httpHeader.Set(headerKey, strings.Join(mergedValues, ","))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// normalizeHeaderListValues flattens a list of header values into individual
// list items.
//
// Each input string is split on commas; every piece is trimmed of surrounding
// whitespace, empty pieces are dropped, and pieces already seen (exact,
// case-sensitive match) are skipped. Items are returned in order of first
// appearance. The result is always a non-nil slice, possibly empty.
func normalizeHeaderListValues(values []string) []string {
	items := make([]string, 0, len(values))
	seen := make(map[string]struct{}, len(values))
	for _, raw := range values {
		for _, piece := range strings.Split(raw, ",") {
			item := strings.TrimSpace(piece)
			if item == "" {
				continue
			}
			if _, dup := seen[item]; dup {
				continue
			}
			seen[item] = struct{}{}
			items = append(items, item)
		}
	}
	return items
}
|
||||
|
||||
func (c *ClaudeSettings) GetDefaultMaxTokens(model string) int {
|
||||
if maxTokens, ok := c.DefaultMaxTokens[model]; ok {
|
||||
return maxTokens
|
||||
|
||||
@@ -88,7 +88,7 @@ var channelAffinitySetting = ChannelAffinitySetting{
|
||||
ValueRegex: "",
|
||||
TTLSeconds: 0,
|
||||
ParamOverrideTemplate: buildPassHeaderTemplate(codexCliPassThroughHeaders),
|
||||
SkipRetryOnFailure: false,
|
||||
SkipRetryOnFailure: true,
|
||||
IncludeUsingGroup: true,
|
||||
IncludeRuleName: true,
|
||||
UserAgentInclude: nil,
|
||||
@@ -103,7 +103,7 @@ var channelAffinitySetting = ChannelAffinitySetting{
|
||||
ValueRegex: "",
|
||||
TTLSeconds: 0,
|
||||
ParamOverrideTemplate: buildPassHeaderTemplate(claudeCliPassThroughHeaders),
|
||||
SkipRetryOnFailure: false,
|
||||
SkipRetryOnFailure: true,
|
||||
IncludeUsingGroup: true,
|
||||
IncludeRuleName: true,
|
||||
UserAgentInclude: nil,
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
)
|
||||
|
||||
type StatusCodeRange struct {
|
||||
@@ -31,6 +33,10 @@ var alwaysSkipRetryStatusCodes = map[int]struct{}{
|
||||
524: {},
|
||||
}
|
||||
|
||||
var alwaysSkipRetryCodes = map[types.ErrorCode]struct{}{
|
||||
types.ErrorCodeBadResponseBody: {},
|
||||
}
|
||||
|
||||
func AutomaticDisableStatusCodesToString() string {
|
||||
return statusCodeRangesToString(AutomaticDisableStatusCodeRanges)
|
||||
}
|
||||
@@ -66,6 +72,11 @@ func IsAlwaysSkipRetryStatusCode(code int) bool {
|
||||
return exists
|
||||
}
|
||||
|
||||
func IsAlwaysSkipRetryCode(errorCode types.ErrorCode) bool {
|
||||
_, exists := alwaysSkipRetryCodes[errorCode]
|
||||
return exists
|
||||
}
|
||||
|
||||
func ShouldRetryByStatusCode(code int) bool {
|
||||
if IsAlwaysSkipRetryStatusCode(code) {
|
||||
return false
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
package setting
|
||||
|
||||
import (
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
)
|
||||
|
||||
// Package-level Waffo payment settings. Only WaffoUnitPrice and WaffoMinTopUp
// have non-zero defaults; everything else starts at its zero value.
// NOTE(review): the meaning of each field is inferred from its name only;
// confirm against the code that reads these settings.
var (
	// Feature toggle and environment switch.
	WaffoEnabled bool
	WaffoSandbox bool

	// Credentials without the "Sandbox" prefix.
	WaffoApiKey     string
	WaffoPrivateKey string
	WaffoPublicCert string

	// Credentials with the "Sandbox" prefix.
	WaffoSandboxApiKey     string
	WaffoSandboxPrivateKey string
	WaffoSandboxPublicCert string

	// Merchant identity and callback/redirect URLs.
	WaffoMerchantId            string
	WaffoNotifyUrl             string
	WaffoReturnUrl             string
	WaffoSubscriptionReturnUrl string

	// Pricing parameters.
	WaffoCurrency  string
	WaffoUnitPrice float64 = 1.0
	WaffoMinTopUp  int     = 1
)
|
||||
|
||||
// GetWaffoPayMethods 从 options 读取 Waffo 支付方式配置
|
||||
func GetWaffoPayMethods() []constant.WaffoPayMethod {
|
||||
common.OptionMapRWMutex.RLock()
|
||||
jsonStr := common.OptionMap["WaffoPayMethods"]
|
||||
common.OptionMapRWMutex.RUnlock()
|
||||
|
||||
if jsonStr == "" {
|
||||
return copyDefaultWaffoPayMethods()
|
||||
}
|
||||
var methods []constant.WaffoPayMethod
|
||||
if err := common.UnmarshalJsonStr(jsonStr, &methods); err != nil {
|
||||
return copyDefaultWaffoPayMethods()
|
||||
}
|
||||
return methods
|
||||
}
|
||||
|
||||
// SetWaffoPayMethods 序列化 Waffo 支付方式配置并更新 OptionMap
|
||||
func SetWaffoPayMethods(methods []constant.WaffoPayMethod) error {
|
||||
jsonBytes, err := common.Marshal(methods)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
common.OptionMapRWMutex.Lock()
|
||||
common.OptionMap["WaffoPayMethods"] = string(jsonBytes)
|
||||
common.OptionMapRWMutex.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func copyDefaultWaffoPayMethods() []constant.WaffoPayMethod {
|
||||
cp := make([]constant.WaffoPayMethod, len(constant.DefaultWaffoPayMethods))
|
||||
copy(cp, constant.DefaultWaffoPayMethods)
|
||||
return cp
|
||||
}
|
||||
|
||||
// WaffoPayMethods2JsonString 将默认 WaffoPayMethods 序列化为 JSON 字符串(供 InitOptionMap 使用)
|
||||
func WaffoPayMethods2JsonString() string {
|
||||
jsonBytes, err := common.Marshal(constant.DefaultWaffoPayMethods)
|
||||
if err != nil {
|
||||
return "[]"
|
||||
}
|
||||
return string(jsonBytes)
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user