Compare commits
239 Commits
v0.9.0-alpha.2
...
v0.9.0
| Author | SHA1 | Date | |
|---|---|---|---|
| e8afc25c71 | |||
| e2c2d182fa | |||
| bbfbce9c67 | |||
| 1b1953e21a | |||
| b3e67d5ef7 | |||
| 8319198122 | |||
| de73bfff78 | |||
| 80cfa0d0df | |||
| 8fcc49377c | |||
| 23a82b9646 | |||
| baf134cd50 | |||
| ab5351c270 | |||
| dffbd39cde | |||
| 1de5216148 | |||
| e53cbd96ad | |||
| 6d81312e7e | |||
| 4f5c343791 | |||
| f0183785c9 | |||
| 1bbabda081 | |||
| 22b724ca44 | |||
| 25dbd39d1e | |||
| 31d5eb87ba | |||
| 14af08750f | |||
| 4083126788 | |||
| d1fc9bd712 | |||
| ada7b96823 | |||
| 785d0c4284 | |||
| 42e5794d00 | |||
| ff42b7fa88 | |||
| ffe78e99ee | |||
| 69e76dae34 | |||
| f10dab864f | |||
| 94129c48ea | |||
| c28e89abb8 | |||
| 5dc2d775e9 | |||
| d8e36a7057 | |||
| d15b31ab5c | |||
| 54f118d9ba | |||
| 55c8271311 | |||
| 7d9728519c | |||
| 99afd05d7d | |||
| 60b7db29fc | |||
| 06785c019f | |||
| 4f44bbed31 | |||
| 5bb732394f | |||
| 3d86279240 | |||
| 11fa5042df | |||
| a00d8e25eb | |||
| 8ed175a5f8 | |||
| d2dcd8beb3 | |||
| ed71c9fcf3 | |||
| 6d7c00634c | |||
| d0d6168e2f | |||
| 0d57b1acd4 | |||
| 41cf516ec5 | |||
| aee5b8caa7 | |||
| fc7234a1c9 | |||
| d081dcff46 | |||
| 95f5cb3980 | |||
| 63bf35a111 | |||
| 5b90f60519 | |||
| 7826099221 | |||
| 4bf9ef5579 | |||
| fe71af943c | |||
| 64058614cb | |||
| 8ceb0316ce | |||
| 1bc0010f5c | |||
| 6db0e77931 | |||
| 742871db57 | |||
| 9e4c4d3bd1 | |||
| c7dc4ad1ef | |||
| e6f78733e1 | |||
| f7c4eda0f3 | |||
| c21219fcff | |||
| 7e698f658a | |||
| 906f797be5 | |||
| cebba71d7c | |||
| 3963863eb0 | |||
| 13d14dc8a8 | |||
| a45513a7a6 | |||
| d1f3f2c395 | |||
| d92c9db61c | |||
| d47779b0a3 | |||
| 3401abe4d2 | |||
| 61f2aa541b | |||
| 17cf9c2d60 | |||
| 943e36ebaf | |||
| 80a7d7eb51 | |||
| d6fcc7a51d | |||
| 1555e6832e | |||
| a102bed25a | |||
| 8004b4ff5f | |||
| 59a5cb8d32 | |||
| 6e62b76105 | |||
| 67719dc087 | |||
| 1b8bcfb000 | |||
| b0f86bd82e | |||
| 0e34de8fe2 | |||
| 3e5bc637de | |||
| f249cf9acc | |||
| d904f9b486 | |||
| 5d384fa4f6 | |||
| 6579d71fe1 | |||
| 6033d4318e | |||
| 781a708173 | |||
| bb897a893e | |||
| 79c7d8f477 | |||
| 5c7395adc6 | |||
| 7e513ad06d | |||
| 05ad8c1223 | |||
| a955b43f3e | |||
| 3baca03895 | |||
| a69f166e9e | |||
| 1a5ba75068 | |||
| c94a74c398 | |||
| 87d0d12d39 | |||
| 3d7a05bdac | |||
| 15a36a65b3 | |||
| 488f6a2e71 | |||
| 103cb4e084 | |||
| eddaebf745 | |||
| 3b50fcb72b | |||
| 99be6d557b | |||
| 157df7ae84 | |||
| a8062c4e9f | |||
| 82c032e8c3 | |||
| 2d1512096a | |||
| 868f1b036a | |||
| 1f887c2e78 | |||
| 90a30a860b | |||
| 313bbc1c8b | |||
| b10f28e5c6 | |||
| fc4d593301 | |||
| bee947000f | |||
| 48a6123c08 | |||
| 94e7f10367 | |||
| 1de4daa7d5 | |||
| fe53424ed8 | |||
| 12b9f0f560 | |||
| eece073235 | |||
| 13301d8544 | |||
| 060ce89286 | |||
| ff57ad2072 | |||
| 1ee3d1cc50 | |||
| 54e03c4214 | |||
| 82e87b8498 | |||
| 1dff565ba1 | |||
| 30fc4d3082 | |||
| 9674065465 | |||
| d9024452dc | |||
| b43ecc2504 | |||
| 2ab2edec7b | |||
| ada604db4e | |||
| 858585d974 | |||
| d09b99c1f0 | |||
| 184cdb5c28 | |||
| d95d260555 | |||
| cc514c7d18 | |||
| 2eae3a6116 | |||
| b6542c6840 | |||
| b383d1d7df | |||
| 436eb77ae2 | |||
| d2e85a7d15 | |||
| 6f322fdbab | |||
| db28d0de4e | |||
| 45130574fc | |||
| b6a3f8ee81 | |||
| 18209c7d47 | |||
| 6c4777bc82 | |||
| 69686cdc43 | |||
| 47ddf76af9 | |||
| 44b7c605e3 | |||
| 8a1f6534fc | |||
| 57b1ddcc5f | |||
| bbc362c301 | |||
| ba6ed31a1a | |||
| 4f22ab6477 | |||
| cfa7399612 | |||
| 24fbd0bfb0 | |||
| 891100dc70 | |||
| 4da4371003 | |||
| 61601d38c4 | |||
| f181ca254b | |||
| 13fcb8fd5d | |||
| 519c26d5f6 | |||
| 041dd9b4cc | |||
| ed3b11a304 | |||
| 3e61e6eb52 | |||
| aad6314c51 | |||
| bc6c3042b8 | |||
| 7411c24954 | |||
| bbe381f656 | |||
| b505121790 | |||
| 45368deac3 | |||
| 66289f7991 | |||
| 2f176cff7f | |||
| 6ce92b9e19 | |||
| 3004ede0ce | |||
| 4e3f008ae9 | |||
| a6c5b6d09f | |||
| d094df09a2 | |||
| 14ea270406 | |||
| c06d9485c3 | |||
| 9e649148bb | |||
| 7323dcf906 | |||
| 5683f4b95f | |||
| 3c30b7c4cb | |||
| 22125a6dd8 | |||
| 3a90d5054c | |||
| 11ce48be47 | |||
| 1d878e5e88 | |||
| 28d381982d | |||
| a5420c90c0 | |||
| 5fbbb24201 | |||
| 02401b9a38 | |||
| 4b6031b59c | |||
| cf9faeb901 | |||
| f8f0ee1e3e | |||
| f6ee0b9791 | |||
| 0e95efdf34 | |||
| beb61343cd | |||
| 77b100ba2b | |||
| 28dd5f5f0c | |||
| 269bb0c896 | |||
| 215546a805 | |||
| 63e7fb697f | |||
| 30e4679384 | |||
| 893104a173 | |||
| 8284cff9b7 | |||
| c57f33bff6 | |||
| c2e49f2b2a | |||
| ebb624148b | |||
| 4a49a80d10 | |||
| 5520bf4dbe | |||
| f8c2c78ee2 | |||
| dbced40e01 | |||
| b30cf70c91 | |||
| a7e02f8961 | |||
| 9dc153eda1 |
@@ -56,8 +56,6 @@
|
||||
# SESSION_SECRET=random_string
|
||||
|
||||
# 其他配置
|
||||
# 渠道测试频率(单位:秒)
|
||||
# CHANNEL_TEST_FREQUENCY=10
|
||||
# 生成默认token
|
||||
# GENERATE_DEFAULT_TOKEN=false
|
||||
# Cohere 安全设置
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
name: Check PR Branching Strategy
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened, edited]
|
||||
|
||||
jobs:
|
||||
check-branching-strategy:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Enforce branching strategy
|
||||
run: |
|
||||
if [[ "${{ github.base_ref }}" == "main" ]]; then
|
||||
if [[ "${{ github.head_ref }}" != "alpha" ]]; then
|
||||
echo "Error: Pull requests to 'main' are only allowed from the 'alpha' branch."
|
||||
exit 1
|
||||
fi
|
||||
elif [[ "${{ github.base_ref }}" != "alpha" ]]; then
|
||||
echo "Error: Pull requests must be targeted to the 'alpha' or 'main' branch."
|
||||
exit 1
|
||||
fi
|
||||
echo "Branching strategy check passed."
|
||||
@@ -40,6 +40,28 @@
|
||||
> - Users must comply with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**, and must not use it for illegal purposes.
|
||||
> - According to the [《Interim Measures for the Management of Generative Artificial Intelligence Services》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm), please do not provide any unregistered generative AI services to the public in China.
|
||||
|
||||
<h2>🤝 Trusted Partners</h2>
|
||||
<p id="premium-sponsors"> </p>
|
||||
<p align="center"><strong>No particular order</strong></p>
|
||||
<p align="center">
|
||||
<a href="https://www.cherry-ai.com/" target=_blank><img
|
||||
src="./docs/images/cherry-studio.png" alt="Cherry Studio" height="120"
|
||||
/></a>
|
||||
<a href="https://bda.pku.edu.cn/" target=_blank><img
|
||||
src="./docs/images/pku.png" alt="Peking University" height="120"
|
||||
/></a>
|
||||
<a href="https://www.compshare.cn/?ytag=GPU_yy_gh_newapi" target=_blank><img
|
||||
src="./docs/images/ucloud.png" alt="UCloud" height="120"
|
||||
/></a>
|
||||
<a href="https://www.aliyun.com/" target=_blank><img
|
||||
src="./docs/images/aliyun.png" alt="Alibaba Cloud" height="120"
|
||||
/></a>
|
||||
<a href="https://io.net/" target=_blank><img
|
||||
src="./docs/images/io-net.png" alt="IO.NET" height="120"
|
||||
/></a>
|
||||
</p>
|
||||
<p> </p>
|
||||
|
||||
## 📚 Documentation
|
||||
|
||||
For detailed documentation, please visit our official Wiki: [https://docs.newapi.pro/](https://docs.newapi.pro/)
|
||||
@@ -189,24 +211,6 @@ If you have any questions, please refer to [Help and Support](https://docs.newap
|
||||
- [Issue Feedback](https://docs.newapi.pro/support/feedback-issues)
|
||||
- [FAQ](https://docs.newapi.pro/support/faq)
|
||||
|
||||
## 🤝 Trusted Partners
|
||||
|
||||
<p align="center">
|
||||
<a href="https://www.cherry-ai.com/" target="_blank"><img
|
||||
src="./docs/images/cherry-studio.svg" alt="Cherry Studio" height="58"
|
||||
/></a>
|
||||
|
||||
<a href="https://bda.pku.edu.cn/" target="_blank"><img
|
||||
src="./docs/images/pku.png" alt="Peking University" height="58"
|
||||
/></a>
|
||||
|
||||
<a href="https://www.compshare.cn/?ytag=GPU_yy_gh_newapi" target="_blank"><img
|
||||
src="./docs/images/ucloud.svg" alt="UCloud" height="58"
|
||||
/></a>
|
||||
</p>
|
||||
|
||||
<p align="center"><em>No particular order</em></p>
|
||||
|
||||
## 🌟 Star History
|
||||
|
||||
[](https://star-history.com/#Calcium-Ion/new-api&Date)
|
||||
|
||||
@@ -40,6 +40,28 @@
|
||||
> - 使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
|
||||
> - 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。
|
||||
|
||||
<h2>🤝 我们信任的合作伙伴</h2>
|
||||
<p id="premium-sponsors"> </p>
|
||||
<p align="center"><strong>排名不分先后</strong></p>
|
||||
<p align="center">
|
||||
<a href="https://www.cherry-ai.com/" target=_blank><img
|
||||
src="./docs/images/cherry-studio.png" alt="Cherry Studio" height="120"
|
||||
/></a>
|
||||
<a href="https://bda.pku.edu.cn/" target=_blank><img
|
||||
src="./docs/images/pku.png" alt="北京大学" height="120"
|
||||
/></a>
|
||||
<a href="https://www.compshare.cn/?ytag=GPU_yy_gh_newapi" target=_blank><img
|
||||
src="./docs/images/ucloud.png" alt="UCloud 优刻得" height="120"
|
||||
/></a>
|
||||
<a href="https://www.aliyun.com/" target=_blank><img
|
||||
src="./docs/images/aliyun.png" alt="阿里云" height="120"
|
||||
/></a>
|
||||
<a href="https://io.net/" target=_blank><img
|
||||
src="./docs/images/io-net.png" alt="IO.NET" height="120"
|
||||
/></a>
|
||||
</p>
|
||||
<p> </p>
|
||||
|
||||
## 📚 文档
|
||||
|
||||
详细文档请访问我们的官方Wiki:[https://docs.newapi.pro/](https://docs.newapi.pro/)
|
||||
@@ -74,7 +96,11 @@ New API提供了丰富的功能,详细特性请参考[特性说明](https://do
|
||||
- 添加后缀 `-thinking` 启用思考模式 (例如: `claude-3-7-sonnet-20250219-thinking`)
|
||||
16. 🔄 思考转内容功能
|
||||
17. 🔄 针对用户的模型限流功能
|
||||
18. 💰 缓存计费支持,开启后可以在缓存命中时按照设定的比例计费:
|
||||
18. 🔄 请求格式转换功能,支持以下三种格式转换:
|
||||
1. OpenAI Chat Completions => Claude Messages
|
||||
2. Clade Messages => OpenAI Chat Completions (可用于Claude Code调用第三方模型)
|
||||
3. OpenAI Chat Completions => Gemini Chat
|
||||
19. 💰 缓存计费支持,开启后可以在缓存命中时按照设定的比例计费:
|
||||
1. 在 `系统设置-运营设置` 中设置 `提示缓存倍率` 选项
|
||||
2. 在渠道中设置 `提示缓存倍率`,范围 0-1,例如设置为 0.5 表示缓存命中时按照 50% 计费
|
||||
3. 支持的渠道:
|
||||
@@ -188,24 +214,6 @@ docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:1234
|
||||
- [反馈问题](https://docs.newapi.pro/support/feedback-issues)
|
||||
- [常见问题](https://docs.newapi.pro/support/faq)
|
||||
|
||||
## 🤝 我们信任的合作伙伴
|
||||
|
||||
<p align="center">
|
||||
<a href="https://www.cherry-ai.com/" target="_blank"><img
|
||||
src="./docs/images/cherry-studio.svg" alt="Cherry Studio" height="58"
|
||||
/></a>
|
||||
|
||||
<a href="https://bda.pku.edu.cn/" target="_blank"><img
|
||||
src="./docs/images/pku.png" alt="北京大学" height="58"
|
||||
/></a>
|
||||
|
||||
<a href="https://www.compshare.cn/?ytag=GPU_yy_gh_newapi" target="_blank"><img
|
||||
src="./docs/images/ucloud.svg" alt="UCloud 优刻得" height="58"
|
||||
/></a>
|
||||
</p>
|
||||
|
||||
<p align="center"><em>排名不分先后</em></p>
|
||||
|
||||
## 🌟 Star History
|
||||
|
||||
[](https://star-history.com/#Calcium-Ion/new-api&Date)
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/jinzhu/copier"
|
||||
)
|
||||
|
||||
func DeepCopy[T any](src *T) (*T, error) {
|
||||
if src == nil {
|
||||
return nil, fmt.Errorf("copy source cannot be nil")
|
||||
}
|
||||
var dst T
|
||||
err := copier.CopyWithOption(&dst, src, copier.Option{DeepCopy: true, IgnoreEmpty: true})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &dst, nil
|
||||
}
|
||||
@@ -2,12 +2,13 @@ package common
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/constant"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const KeyRequestBody = "key_request_body"
|
||||
|
||||
@@ -20,3 +20,25 @@ func DecodeJson(reader *bytes.Reader, v any) error {
|
||||
func Marshal(v any) ([]byte, error) {
|
||||
return json.Marshal(v)
|
||||
}
|
||||
|
||||
func GetJsonType(data json.RawMessage) string {
|
||||
data = bytes.TrimSpace(data)
|
||||
if len(data) == 0 {
|
||||
return "unknown"
|
||||
}
|
||||
firstChar := bytes.TrimSpace(data)[0]
|
||||
switch firstChar {
|
||||
case '{':
|
||||
return "object"
|
||||
case '[':
|
||||
return "array"
|
||||
case '"':
|
||||
return "string"
|
||||
case 't', 'f':
|
||||
return "boolean"
|
||||
case 'n':
|
||||
return "null"
|
||||
default:
|
||||
return "number"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -123,8 +123,16 @@ func Interface2String(inter interface{}) string {
|
||||
return fmt.Sprintf("%d", inter.(int))
|
||||
case float64:
|
||||
return fmt.Sprintf("%f", inter.(float64))
|
||||
case bool:
|
||||
if inter.(bool) {
|
||||
return "true"
|
||||
} else {
|
||||
return "false"
|
||||
}
|
||||
case nil:
|
||||
return ""
|
||||
}
|
||||
return "Not Implemented"
|
||||
return fmt.Sprintf("%v", inter)
|
||||
}
|
||||
|
||||
func UnescapeHTML(x string) interface{} {
|
||||
@@ -257,32 +265,32 @@ func GetAudioDuration(ctx context.Context, filename string, ext string) (float64
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "failed to get audio duration")
|
||||
}
|
||||
durationStr := string(bytes.TrimSpace(output))
|
||||
if durationStr == "N/A" {
|
||||
// Create a temporary output file name
|
||||
tmpFp, err := os.CreateTemp("", "audio-*"+ext)
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "failed to create temporary file")
|
||||
}
|
||||
tmpName := tmpFp.Name()
|
||||
// Close immediately so ffmpeg can open the file on Windows.
|
||||
_ = tmpFp.Close()
|
||||
defer os.Remove(tmpName)
|
||||
durationStr := string(bytes.TrimSpace(output))
|
||||
if durationStr == "N/A" {
|
||||
// Create a temporary output file name
|
||||
tmpFp, err := os.CreateTemp("", "audio-*"+ext)
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "failed to create temporary file")
|
||||
}
|
||||
tmpName := tmpFp.Name()
|
||||
// Close immediately so ffmpeg can open the file on Windows.
|
||||
_ = tmpFp.Close()
|
||||
defer os.Remove(tmpName)
|
||||
|
||||
// ffmpeg -y -i filename -vcodec copy -acodec copy <tmpName>
|
||||
ffmpegCmd := exec.CommandContext(ctx, "ffmpeg", "-y", "-i", filename, "-vcodec", "copy", "-acodec", "copy", tmpName)
|
||||
if err := ffmpegCmd.Run(); err != nil {
|
||||
return 0, errors.Wrap(err, "failed to run ffmpeg")
|
||||
}
|
||||
// ffmpeg -y -i filename -vcodec copy -acodec copy <tmpName>
|
||||
ffmpegCmd := exec.CommandContext(ctx, "ffmpeg", "-y", "-i", filename, "-vcodec", "copy", "-acodec", "copy", tmpName)
|
||||
if err := ffmpegCmd.Run(); err != nil {
|
||||
return 0, errors.Wrap(err, "failed to run ffmpeg")
|
||||
}
|
||||
|
||||
// Recalculate the duration of the new file
|
||||
c = exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", tmpName)
|
||||
output, err := c.Output()
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "failed to get audio duration after ffmpeg")
|
||||
}
|
||||
durationStr = string(bytes.TrimSpace(output))
|
||||
}
|
||||
// Recalculate the duration of the new file
|
||||
c = exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", tmpName)
|
||||
output, err := c.Output()
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "failed to get audio duration after ffmpeg")
|
||||
}
|
||||
durationStr = string(bytes.TrimSpace(output))
|
||||
}
|
||||
return strconv.ParseFloat(durationStr, 64)
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,8 @@ package constant
|
||||
type ContextKey string
|
||||
|
||||
const (
|
||||
ContextKeyPromptTokens ContextKey = "prompt_tokens"
|
||||
ContextKeyTokenCountMeta ContextKey = "token_count_meta"
|
||||
ContextKeyPromptTokens ContextKey = "prompt_tokens"
|
||||
|
||||
ContextKeyOriginalModel ContextKey = "original_model"
|
||||
ContextKeyRequestStartTime ContextKey = "request_start_time"
|
||||
@@ -26,6 +27,7 @@ const (
|
||||
ContextKeyChannelSetting ContextKey = "channel_setting"
|
||||
ContextKeyChannelOtherSetting ContextKey = "channel_other_setting"
|
||||
ContextKeyChannelParamOverride ContextKey = "param_override"
|
||||
ContextKeyChannelHeaderOverride ContextKey = "header_override"
|
||||
ContextKeyChannelOrganization ContextKey = "channel_organization"
|
||||
ContextKeyChannelAutoBan ContextKey = "auto_ban"
|
||||
ContextKeyChannelModelMapping ContextKey = "model_mapping"
|
||||
|
||||
@@ -135,7 +135,11 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He
|
||||
for k := range headers {
|
||||
req.Header.Add(k, headers.Get(k))
|
||||
}
|
||||
res, err := service.GetHttpClient().Do(req)
|
||||
client, err := service.NewProxyHttpClient(channel.GetSetting().Proxy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/types"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -445,7 +446,7 @@ func testAllChannels(notify bool) error {
|
||||
|
||||
// disable channel
|
||||
if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
|
||||
go processChannelError(result.context, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
||||
processChannelError(result.context, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
||||
}
|
||||
|
||||
// enable channel
|
||||
@@ -477,15 +478,26 @@ func TestAllChannels(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
func AutomaticallyTestChannels(frequency int) {
|
||||
if frequency <= 0 {
|
||||
common.SysLog("CHANNEL_TEST_FREQUENCY is not set or invalid, skipping automatic channel test")
|
||||
return
|
||||
}
|
||||
for {
|
||||
time.Sleep(time.Duration(frequency) * time.Minute)
|
||||
common.SysLog("testing all channels")
|
||||
_ = testAllChannels(false)
|
||||
common.SysLog("channel test finished")
|
||||
}
|
||||
var autoTestChannelsOnce sync.Once
|
||||
|
||||
func AutomaticallyTestChannels() {
|
||||
autoTestChannelsOnce.Do(func() {
|
||||
for {
|
||||
if !operation_setting.GetMonitorSetting().AutoTestChannelEnabled {
|
||||
time.Sleep(10 * time.Minute)
|
||||
continue
|
||||
}
|
||||
frequency := operation_setting.GetMonitorSetting().AutoTestChannelMinutes
|
||||
common.SysLog(fmt.Sprintf("automatically test channels with interval %d minutes", frequency))
|
||||
for {
|
||||
time.Sleep(time.Duration(frequency) * time.Minute)
|
||||
common.SysLog("automatically testing all channels")
|
||||
_ = testAllChannels(false)
|
||||
common.SysLog("automatically channel test finished")
|
||||
if !operation_setting.GetMonitorSetting().AutoTestChannelEnabled {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -380,6 +380,85 @@ func GetChannel(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// GetChannelKey 验证2FA后获取渠道密钥
|
||||
func GetChannelKey(c *gin.Context) {
|
||||
type GetChannelKeyRequest struct {
|
||||
Code string `json:"code" binding:"required"`
|
||||
}
|
||||
|
||||
var req GetChannelKeyRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
common.ApiError(c, fmt.Errorf("参数错误: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
userId := c.GetInt("id")
|
||||
channelId, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
common.ApiError(c, fmt.Errorf("渠道ID格式错误: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// 获取2FA记录并验证
|
||||
twoFA, err := model.GetTwoFAByUserId(userId)
|
||||
if err != nil {
|
||||
common.ApiError(c, fmt.Errorf("获取2FA信息失败: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if twoFA == nil || !twoFA.IsEnabled {
|
||||
common.ApiError(c, fmt.Errorf("用户未启用2FA,无法查看密钥"))
|
||||
return
|
||||
}
|
||||
|
||||
// 统一的2FA验证逻辑
|
||||
if !validateTwoFactorAuth(twoFA, req.Code) {
|
||||
common.ApiError(c, fmt.Errorf("验证码或备用码错误,请重试"))
|
||||
return
|
||||
}
|
||||
|
||||
// 获取渠道信息(包含密钥)
|
||||
channel, err := model.GetChannelById(channelId, true)
|
||||
if err != nil {
|
||||
common.ApiError(c, fmt.Errorf("获取渠道信息失败: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if channel == nil {
|
||||
common.ApiError(c, fmt.Errorf("渠道不存在"))
|
||||
return
|
||||
}
|
||||
|
||||
// 记录操作日志
|
||||
model.RecordLog(userId, model.LogTypeSystem, fmt.Sprintf("查看渠道密钥信息 (渠道ID: %d)", channelId))
|
||||
|
||||
// 统一的成功响应格式
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "验证成功",
|
||||
"data": map[string]interface{}{
|
||||
"key": channel.Key,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// validateTwoFactorAuth 统一的2FA验证函数
|
||||
func validateTwoFactorAuth(twoFA *model.TwoFA, code string) bool {
|
||||
// 尝试验证TOTP
|
||||
if cleanCode, err := common.ValidateNumericCode(code); err == nil {
|
||||
if isValid, _ := twoFA.ValidateTOTPAndUpdateUsage(cleanCode); isValid {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// 尝试验证备用码
|
||||
if isValid, err := twoFA.ValidateBackupCodeAndUpdateUsage(code); err == nil && isValid {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// validateChannel 通用的渠道校验函数
|
||||
func validateChannel(channel *model.Channel, isAdd bool) error {
|
||||
// 校验 channel settings
|
||||
|
||||
@@ -39,6 +39,8 @@ func TestStatus(c *gin.Context) {
|
||||
func GetStatus(c *gin.Context) {
|
||||
|
||||
cs := console_setting.GetConsoleSetting()
|
||||
common.OptionMapRWMutex.RLock()
|
||||
defer common.OptionMapRWMutex.RUnlock()
|
||||
|
||||
data := gin.H{
|
||||
"version": common.Version,
|
||||
@@ -89,6 +91,10 @@ func GetStatus(c *gin.Context) {
|
||||
"announcements_enabled": cs.AnnouncementsEnabled,
|
||||
"faq_enabled": cs.FAQEnabled,
|
||||
|
||||
// 模块管理配置
|
||||
"HeaderNavModules": common.OptionMap["HeaderNavModules"],
|
||||
"SidebarModulesAdmin": common.OptionMap["SidebarModulesAdmin"],
|
||||
|
||||
"oidc_enabled": system_setting.GetOIDCSettings().Enabled,
|
||||
"oidc_client_id": system_setting.GetOIDCSettings().ClientId,
|
||||
"oidc_authorization_endpoint": system_setting.GetOIDCSettings().AuthorizationEndpoint,
|
||||
|
||||
@@ -207,6 +207,7 @@ func ListModels(c *gin.Context, modelType int) {
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"data": userOpenAiModels,
|
||||
"object": "list",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,604 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// 上游地址
|
||||
const (
|
||||
upstreamModelsURL = "https://basellm.github.io/llm-metadata/api/newapi/models.json"
|
||||
upstreamVendorsURL = "https://basellm.github.io/llm-metadata/api/newapi/vendors.json"
|
||||
)
|
||||
|
||||
func normalizeLocale(locale string) (string, bool) {
|
||||
l := strings.ToLower(strings.TrimSpace(locale))
|
||||
switch l {
|
||||
case "en", "zh", "ja":
|
||||
return l, true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func getUpstreamBase() string {
|
||||
return common.GetEnvOrDefaultString("SYNC_UPSTREAM_BASE", "https://basellm.github.io/llm-metadata")
|
||||
}
|
||||
|
||||
func getUpstreamURLs(locale string) (modelsURL, vendorsURL string) {
|
||||
base := strings.TrimRight(getUpstreamBase(), "/")
|
||||
if l, ok := normalizeLocale(locale); ok && l != "" {
|
||||
return fmt.Sprintf("%s/api/i18n/%s/newapi/models.json", base, l),
|
||||
fmt.Sprintf("%s/api/i18n/%s/newapi/vendors.json", base, l)
|
||||
}
|
||||
return fmt.Sprintf("%s/api/newapi/models.json", base), fmt.Sprintf("%s/api/newapi/vendors.json", base)
|
||||
}
|
||||
|
||||
type upstreamEnvelope[T any] struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
Data []T `json:"data"`
|
||||
}
|
||||
|
||||
type upstreamModel struct {
|
||||
Description string `json:"description"`
|
||||
Endpoints json.RawMessage `json:"endpoints"`
|
||||
Icon string `json:"icon"`
|
||||
ModelName string `json:"model_name"`
|
||||
NameRule int `json:"name_rule"`
|
||||
Status int `json:"status"`
|
||||
Tags string `json:"tags"`
|
||||
VendorName string `json:"vendor_name"`
|
||||
}
|
||||
|
||||
type upstreamVendor struct {
|
||||
Description string `json:"description"`
|
||||
Icon string `json:"icon"`
|
||||
Name string `json:"name"`
|
||||
Status int `json:"status"`
|
||||
}
|
||||
|
||||
var (
|
||||
etagCache = make(map[string]string)
|
||||
bodyCache = make(map[string][]byte)
|
||||
cacheMutex sync.RWMutex
|
||||
)
|
||||
|
||||
type overwriteField struct {
|
||||
ModelName string `json:"model_name"`
|
||||
Fields []string `json:"fields"`
|
||||
}
|
||||
|
||||
type syncRequest struct {
|
||||
Overwrite []overwriteField `json:"overwrite"`
|
||||
Locale string `json:"locale"`
|
||||
}
|
||||
|
||||
func newHTTPClient() *http.Client {
|
||||
timeoutSec := common.GetEnvOrDefault("SYNC_HTTP_TIMEOUT_SECONDS", 10)
|
||||
dialer := &net.Dialer{Timeout: time.Duration(timeoutSec) * time.Second}
|
||||
transport := &http.Transport{
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: time.Duration(timeoutSec) * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
ResponseHeaderTimeout: time.Duration(timeoutSec) * time.Second,
|
||||
}
|
||||
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
host = addr
|
||||
}
|
||||
if strings.HasSuffix(host, "github.io") {
|
||||
if conn, err := dialer.DialContext(ctx, "tcp4", addr); err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
return dialer.DialContext(ctx, "tcp6", addr)
|
||||
}
|
||||
return dialer.DialContext(ctx, network, addr)
|
||||
}
|
||||
return &http.Client{Transport: transport}
|
||||
}
|
||||
|
||||
var httpClient = newHTTPClient()
|
||||
|
||||
func fetchJSON[T any](ctx context.Context, url string, out *upstreamEnvelope[T]) error {
|
||||
var lastErr error
|
||||
attempts := common.GetEnvOrDefault("SYNC_HTTP_RETRY", 3)
|
||||
if attempts < 1 {
|
||||
attempts = 1
|
||||
}
|
||||
baseDelay := 200 * time.Millisecond
|
||||
maxMB := common.GetEnvOrDefault("SYNC_HTTP_MAX_MB", 10)
|
||||
maxBytes := int64(maxMB) << 20
|
||||
for attempt := 0; attempt < attempts; attempt++ {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// ETag conditional request
|
||||
cacheMutex.RLock()
|
||||
if et := etagCache[url]; et != "" {
|
||||
req.Header.Set("If-None-Match", et)
|
||||
}
|
||||
cacheMutex.RUnlock()
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
// backoff with jitter
|
||||
sleep := baseDelay * time.Duration(1<<attempt)
|
||||
jitter := time.Duration(rand.Intn(150)) * time.Millisecond
|
||||
time.Sleep(sleep + jitter)
|
||||
continue
|
||||
}
|
||||
func() {
|
||||
defer resp.Body.Close()
|
||||
switch resp.StatusCode {
|
||||
case http.StatusOK:
|
||||
// read body into buffer for caching and flexible decode
|
||||
limited := io.LimitReader(resp.Body, maxBytes)
|
||||
buf, err := io.ReadAll(limited)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
return
|
||||
}
|
||||
// cache body and ETag
|
||||
cacheMutex.Lock()
|
||||
if et := resp.Header.Get("ETag"); et != "" {
|
||||
etagCache[url] = et
|
||||
}
|
||||
bodyCache[url] = buf
|
||||
cacheMutex.Unlock()
|
||||
|
||||
// Try decode as envelope first
|
||||
if err := json.Unmarshal(buf, out); err != nil {
|
||||
// Try decode as pure array
|
||||
var arr []T
|
||||
if err2 := json.Unmarshal(buf, &arr); err2 != nil {
|
||||
lastErr = err
|
||||
return
|
||||
}
|
||||
out.Success = true
|
||||
out.Data = arr
|
||||
out.Message = ""
|
||||
} else {
|
||||
if !out.Success && len(out.Data) == 0 && out.Message == "" {
|
||||
out.Success = true
|
||||
}
|
||||
}
|
||||
lastErr = nil
|
||||
case http.StatusNotModified:
|
||||
// use cache
|
||||
cacheMutex.RLock()
|
||||
buf := bodyCache[url]
|
||||
cacheMutex.RUnlock()
|
||||
if len(buf) == 0 {
|
||||
lastErr = errors.New("cache miss for 304 response")
|
||||
return
|
||||
}
|
||||
if err := json.Unmarshal(buf, out); err != nil {
|
||||
var arr []T
|
||||
if err2 := json.Unmarshal(buf, &arr); err2 != nil {
|
||||
lastErr = err
|
||||
return
|
||||
}
|
||||
out.Success = true
|
||||
out.Data = arr
|
||||
out.Message = ""
|
||||
} else {
|
||||
if !out.Success && len(out.Data) == 0 && out.Message == "" {
|
||||
out.Success = true
|
||||
}
|
||||
}
|
||||
lastErr = nil
|
||||
default:
|
||||
lastErr = errors.New(resp.Status)
|
||||
}
|
||||
}()
|
||||
if lastErr == nil {
|
||||
return nil
|
||||
}
|
||||
sleep := baseDelay * time.Duration(1<<attempt)
|
||||
jitter := time.Duration(rand.Intn(150)) * time.Millisecond
|
||||
time.Sleep(sleep + jitter)
|
||||
}
|
||||
return lastErr
|
||||
}
|
||||
|
||||
func ensureVendorID(vendorName string, vendorByName map[string]upstreamVendor, vendorIDCache map[string]int, createdVendors *int) int {
|
||||
if vendorName == "" {
|
||||
return 0
|
||||
}
|
||||
if id, ok := vendorIDCache[vendorName]; ok {
|
||||
return id
|
||||
}
|
||||
var existing model.Vendor
|
||||
if err := model.DB.Where("name = ?", vendorName).First(&existing).Error; err == nil {
|
||||
vendorIDCache[vendorName] = existing.Id
|
||||
return existing.Id
|
||||
}
|
||||
uv := vendorByName[vendorName]
|
||||
v := &model.Vendor{
|
||||
Name: vendorName,
|
||||
Description: uv.Description,
|
||||
Icon: coalesce(uv.Icon, ""),
|
||||
Status: chooseStatus(uv.Status, 1),
|
||||
}
|
||||
if err := v.Insert(); err == nil {
|
||||
*createdVendors++
|
||||
vendorIDCache[vendorName] = v.Id
|
||||
return v.Id
|
||||
}
|
||||
vendorIDCache[vendorName] = 0
|
||||
return 0
|
||||
}
|
||||
|
||||
// SyncUpstreamModels 同步上游模型与供应商,仅对「未配置模型」生效
|
||||
func SyncUpstreamModels(c *gin.Context) {
|
||||
var req syncRequest
|
||||
// 允许空体
|
||||
_ = c.ShouldBindJSON(&req)
|
||||
// 1) 获取未配置模型列表
|
||||
missing, err := model.GetMissingModels()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
if len(missing) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "data": gin.H{
|
||||
"created_models": 0,
|
||||
"created_vendors": 0,
|
||||
"skipped_models": []string{},
|
||||
}})
|
||||
return
|
||||
}
|
||||
|
||||
// 2) 拉取上游 vendors 与 models
|
||||
timeoutSec := common.GetEnvOrDefault("SYNC_HTTP_TIMEOUT_SECONDS", 15)
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(timeoutSec)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
modelsURL, vendorsURL := getUpstreamURLs(req.Locale)
|
||||
var vendorsEnv upstreamEnvelope[upstreamVendor]
|
||||
var modelsEnv upstreamEnvelope[upstreamModel]
|
||||
var fetchErr error
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
// vendor 失败不拦截
|
||||
_ = fetchJSON(ctx, vendorsURL, &vendorsEnv)
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := fetchJSON(ctx, modelsURL, &modelsEnv); err != nil {
|
||||
fetchErr = err
|
||||
}
|
||||
}()
|
||||
wg.Wait()
|
||||
if fetchErr != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取上游模型失败: " + fetchErr.Error(), "locale": req.Locale, "source_urls": gin.H{"models_url": modelsURL, "vendors_url": vendorsURL}})
|
||||
return
|
||||
}
|
||||
|
||||
// 建立映射
|
||||
vendorByName := make(map[string]upstreamVendor)
|
||||
for _, v := range vendorsEnv.Data {
|
||||
if v.Name != "" {
|
||||
vendorByName[v.Name] = v
|
||||
}
|
||||
}
|
||||
modelByName := make(map[string]upstreamModel)
|
||||
for _, m := range modelsEnv.Data {
|
||||
if m.ModelName != "" {
|
||||
modelByName[m.ModelName] = m
|
||||
}
|
||||
}
|
||||
|
||||
// 3) 执行同步:仅创建缺失模型;若上游缺失该模型则跳过
|
||||
createdModels := 0
|
||||
createdVendors := 0
|
||||
updatedModels := 0
|
||||
var skipped []string
|
||||
var createdList []string
|
||||
var updatedList []string
|
||||
|
||||
// 本地缓存:vendorName -> id
|
||||
vendorIDCache := make(map[string]int)
|
||||
|
||||
for _, name := range missing {
|
||||
up, ok := modelByName[name]
|
||||
if !ok {
|
||||
skipped = append(skipped, name)
|
||||
continue
|
||||
}
|
||||
|
||||
// 若本地已存在且设置为不同步,则跳过(极端情况:缺失列表与本地状态不同步时)
|
||||
var existing model.Model
|
||||
if err := model.DB.Where("model_name = ?", name).First(&existing).Error; err == nil {
|
||||
if existing.SyncOfficial == 0 {
|
||||
skipped = append(skipped, name)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// 确保 vendor 存在
|
||||
vendorID := ensureVendorID(up.VendorName, vendorByName, vendorIDCache, &createdVendors)
|
||||
|
||||
// 创建模型
|
||||
mi := &model.Model{
|
||||
ModelName: name,
|
||||
Description: up.Description,
|
||||
Icon: up.Icon,
|
||||
Tags: up.Tags,
|
||||
VendorID: vendorID,
|
||||
Status: chooseStatus(up.Status, 1),
|
||||
NameRule: up.NameRule,
|
||||
}
|
||||
if err := mi.Insert(); err == nil {
|
||||
createdModels++
|
||||
createdList = append(createdList, name)
|
||||
} else {
|
||||
skipped = append(skipped, name)
|
||||
}
|
||||
}
|
||||
|
||||
// 4) 处理可选覆盖(更新本地已有模型的差异字段)
|
||||
if len(req.Overwrite) > 0 {
|
||||
// vendorIDCache 已用于创建阶段,可复用
|
||||
for _, ow := range req.Overwrite {
|
||||
up, ok := modelByName[ow.ModelName]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
var local model.Model
|
||||
if err := model.DB.Where("model_name = ?", ow.ModelName).First(&local).Error; err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 跳过被禁用官方同步的模型
|
||||
if local.SyncOfficial == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// 映射 vendor
|
||||
newVendorID := ensureVendorID(up.VendorName, vendorByName, vendorIDCache, &createdVendors)
|
||||
|
||||
// 应用字段覆盖(事务)
|
||||
_ = model.DB.Transaction(func(tx *gorm.DB) error {
|
||||
needUpdate := false
|
||||
if containsField(ow.Fields, "description") {
|
||||
local.Description = up.Description
|
||||
needUpdate = true
|
||||
}
|
||||
if containsField(ow.Fields, "icon") {
|
||||
local.Icon = up.Icon
|
||||
needUpdate = true
|
||||
}
|
||||
if containsField(ow.Fields, "tags") {
|
||||
local.Tags = up.Tags
|
||||
needUpdate = true
|
||||
}
|
||||
if containsField(ow.Fields, "vendor") {
|
||||
local.VendorID = newVendorID
|
||||
needUpdate = true
|
||||
}
|
||||
if containsField(ow.Fields, "name_rule") {
|
||||
local.NameRule = up.NameRule
|
||||
needUpdate = true
|
||||
}
|
||||
if containsField(ow.Fields, "status") {
|
||||
local.Status = chooseStatus(up.Status, local.Status)
|
||||
needUpdate = true
|
||||
}
|
||||
if !needUpdate {
|
||||
return nil
|
||||
}
|
||||
if err := tx.Save(&local).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
updatedModels++
|
||||
updatedList = append(updatedList, ow.ModelName)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": gin.H{
|
||||
"created_models": createdModels,
|
||||
"created_vendors": createdVendors,
|
||||
"updated_models": updatedModels,
|
||||
"skipped_models": skipped,
|
||||
"created_list": createdList,
|
||||
"updated_list": updatedList,
|
||||
"source": gin.H{
|
||||
"locale": req.Locale,
|
||||
"models_url": modelsURL,
|
||||
"vendors_url": vendorsURL,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func containsField(fields []string, key string) bool {
|
||||
key = strings.ToLower(strings.TrimSpace(key))
|
||||
for _, f := range fields {
|
||||
if strings.ToLower(strings.TrimSpace(f)) == key {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func coalesce(a, b string) string {
|
||||
if strings.TrimSpace(a) != "" {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func chooseStatus(primary, fallback int) int {
|
||||
if primary == 0 && fallback != 0 {
|
||||
return fallback
|
||||
}
|
||||
if primary != 0 {
|
||||
return primary
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
// SyncUpstreamPreview 预览上游与本地的差异(仅用于弹窗选择)
|
||||
func SyncUpstreamPreview(c *gin.Context) {
|
||||
// 1) 拉取上游数据
|
||||
timeoutSec := common.GetEnvOrDefault("SYNC_HTTP_TIMEOUT_SECONDS", 15)
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(timeoutSec)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
locale := c.Query("locale")
|
||||
modelsURL, vendorsURL := getUpstreamURLs(locale)
|
||||
|
||||
var vendorsEnv upstreamEnvelope[upstreamVendor]
|
||||
var modelsEnv upstreamEnvelope[upstreamModel]
|
||||
var fetchErr error
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = fetchJSON(ctx, vendorsURL, &vendorsEnv)
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := fetchJSON(ctx, modelsURL, &modelsEnv); err != nil {
|
||||
fetchErr = err
|
||||
}
|
||||
}()
|
||||
wg.Wait()
|
||||
if fetchErr != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取上游模型失败: " + fetchErr.Error(), "locale": locale, "source_urls": gin.H{"models_url": modelsURL, "vendors_url": vendorsURL}})
|
||||
return
|
||||
}
|
||||
|
||||
vendorByName := make(map[string]upstreamVendor)
|
||||
for _, v := range vendorsEnv.Data {
|
||||
if v.Name != "" {
|
||||
vendorByName[v.Name] = v
|
||||
}
|
||||
}
|
||||
modelByName := make(map[string]upstreamModel)
|
||||
upstreamNames := make([]string, 0, len(modelsEnv.Data))
|
||||
for _, m := range modelsEnv.Data {
|
||||
if m.ModelName != "" {
|
||||
modelByName[m.ModelName] = m
|
||||
upstreamNames = append(upstreamNames, m.ModelName)
|
||||
}
|
||||
}
|
||||
|
||||
// 2) 本地已有模型
|
||||
var locals []model.Model
|
||||
if len(upstreamNames) > 0 {
|
||||
_ = model.DB.Where("model_name IN ? AND sync_official <> 0", upstreamNames).Find(&locals).Error
|
||||
}
|
||||
|
||||
// 本地 vendor 名称映射
|
||||
vendorIdSet := make(map[int]struct{})
|
||||
for _, m := range locals {
|
||||
if m.VendorID != 0 {
|
||||
vendorIdSet[m.VendorID] = struct{}{}
|
||||
}
|
||||
}
|
||||
vendorIDs := make([]int, 0, len(vendorIdSet))
|
||||
for id := range vendorIdSet {
|
||||
vendorIDs = append(vendorIDs, id)
|
||||
}
|
||||
idToVendorName := make(map[int]string)
|
||||
if len(vendorIDs) > 0 {
|
||||
var dbVendors []model.Vendor
|
||||
_ = model.DB.Where("id IN ?", vendorIDs).Find(&dbVendors).Error
|
||||
for _, v := range dbVendors {
|
||||
idToVendorName[v.Id] = v.Name
|
||||
}
|
||||
}
|
||||
|
||||
// 3) 缺失且上游存在的模型
|
||||
missingList, _ := model.GetMissingModels()
|
||||
var missing []string
|
||||
for _, name := range missingList {
|
||||
if _, ok := modelByName[name]; ok {
|
||||
missing = append(missing, name)
|
||||
}
|
||||
}
|
||||
|
||||
// 4) 计算冲突字段
|
||||
type conflictField struct {
|
||||
Field string `json:"field"`
|
||||
Local interface{} `json:"local"`
|
||||
Upstream interface{} `json:"upstream"`
|
||||
}
|
||||
type conflictItem struct {
|
||||
ModelName string `json:"model_name"`
|
||||
Fields []conflictField `json:"fields"`
|
||||
}
|
||||
|
||||
var conflicts []conflictItem
|
||||
for _, local := range locals {
|
||||
up, ok := modelByName[local.ModelName]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
fields := make([]conflictField, 0, 6)
|
||||
if strings.TrimSpace(local.Description) != strings.TrimSpace(up.Description) {
|
||||
fields = append(fields, conflictField{Field: "description", Local: local.Description, Upstream: up.Description})
|
||||
}
|
||||
if strings.TrimSpace(local.Icon) != strings.TrimSpace(up.Icon) {
|
||||
fields = append(fields, conflictField{Field: "icon", Local: local.Icon, Upstream: up.Icon})
|
||||
}
|
||||
if strings.TrimSpace(local.Tags) != strings.TrimSpace(up.Tags) {
|
||||
fields = append(fields, conflictField{Field: "tags", Local: local.Tags, Upstream: up.Tags})
|
||||
}
|
||||
// vendor 对比使用名称
|
||||
localVendor := idToVendorName[local.VendorID]
|
||||
if strings.TrimSpace(localVendor) != strings.TrimSpace(up.VendorName) {
|
||||
fields = append(fields, conflictField{Field: "vendor", Local: localVendor, Upstream: up.VendorName})
|
||||
}
|
||||
if local.NameRule != up.NameRule {
|
||||
fields = append(fields, conflictField{Field: "name_rule", Local: local.NameRule, Upstream: up.NameRule})
|
||||
}
|
||||
if local.Status != chooseStatus(up.Status, local.Status) {
|
||||
fields = append(fields, conflictField{Field: "status", Local: local.Status, Upstream: up.Status})
|
||||
}
|
||||
if len(fields) > 0 {
|
||||
conflicts = append(conflicts, conflictItem{ModelName: local.ModelName, Fields: fields})
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": gin.H{
|
||||
"missing": missing,
|
||||
"conflicts": conflicts,
|
||||
"source": gin.H{
|
||||
"locale": locale,
|
||||
"models_url": modelsURL,
|
||||
"vendors_url": vendorsURL,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
@@ -35,8 +36,13 @@ func GetOptions(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
type OptionUpdateRequest struct {
|
||||
Key string `json:"key"`
|
||||
Value any `json:"value"`
|
||||
}
|
||||
|
||||
func UpdateOption(c *gin.Context) {
|
||||
var option model.Option
|
||||
var option OptionUpdateRequest
|
||||
err := json.NewDecoder(c.Request.Body).Decode(&option)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
@@ -45,6 +51,16 @@ func UpdateOption(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
switch option.Value.(type) {
|
||||
case bool:
|
||||
option.Value = common.Interface2String(option.Value.(bool))
|
||||
case float64:
|
||||
option.Value = common.Interface2String(option.Value.(float64))
|
||||
case int:
|
||||
option.Value = common.Interface2String(option.Value.(int))
|
||||
default:
|
||||
option.Value = fmt.Sprintf("%v", option.Value)
|
||||
}
|
||||
switch option.Key {
|
||||
case "GitHubOAuthEnabled":
|
||||
if option.Value == "true" && common.GitHubClientId == "" {
|
||||
@@ -104,7 +120,7 @@ func UpdateOption(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
case "GroupRatio":
|
||||
err = ratio_setting.CheckGroupRatio(option.Value)
|
||||
err = ratio_setting.CheckGroupRatio(option.Value.(string))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -113,7 +129,7 @@ func UpdateOption(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
case "ModelRequestRateLimitGroup":
|
||||
err = setting.CheckModelRequestRateLimitGroup(option.Value)
|
||||
err = setting.CheckModelRequestRateLimitGroup(option.Value.(string))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -122,7 +138,7 @@ func UpdateOption(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
case "console_setting.api_info":
|
||||
err = console_setting.ValidateConsoleSettings(option.Value, "ApiInfo")
|
||||
err = console_setting.ValidateConsoleSettings(option.Value.(string), "ApiInfo")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -131,7 +147,7 @@ func UpdateOption(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
case "console_setting.announcements":
|
||||
err = console_setting.ValidateConsoleSettings(option.Value, "Announcements")
|
||||
err = console_setting.ValidateConsoleSettings(option.Value.(string), "Announcements")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -140,7 +156,7 @@ func UpdateOption(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
case "console_setting.faq":
|
||||
err = console_setting.ValidateConsoleSettings(option.Value, "FAQ")
|
||||
err = console_setting.ValidateConsoleSettings(option.Value.(string), "FAQ")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -149,7 +165,7 @@ func UpdateOption(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
case "console_setting.uptime_kuma_groups":
|
||||
err = console_setting.ValidateConsoleSettings(option.Value, "UptimeKumaGroups")
|
||||
err = console_setting.ValidateConsoleSettings(option.Value.(string), "UptimeKumaGroups")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -158,7 +174,7 @@ func UpdateOption(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
}
|
||||
err = model.UpdateOption(option.Key, option.Value)
|
||||
err = model.UpdateOption(option.Key, option.Value.(string))
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
|
||||
@@ -1,24 +1,24 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/setting/ratio_setting"
|
||||
"net/http"
|
||||
"one-api/setting/ratio_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetRatioConfig(c *gin.Context) {
|
||||
if !ratio_setting.IsExposeRatioEnabled() {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "倍率配置接口未启用",
|
||||
})
|
||||
return
|
||||
}
|
||||
if !ratio_setting.IsExposeRatioEnabled() {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "倍率配置接口未启用",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": ratio_setting.GetExposedData(),
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": ratio_setting.GetExposedData(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"one-api/logger"
|
||||
"strings"
|
||||
@@ -21,8 +23,26 @@ const (
|
||||
defaultTimeoutSeconds = 10
|
||||
defaultEndpoint = "/api/ratio_config"
|
||||
maxConcurrentFetches = 8
|
||||
maxRatioConfigBytes = 10 << 20 // 10MB
|
||||
floatEpsilon = 1e-9
|
||||
)
|
||||
|
||||
func nearlyEqual(a, b float64) bool {
|
||||
if a > b {
|
||||
return a-b < floatEpsilon
|
||||
}
|
||||
return b-a < floatEpsilon
|
||||
}
|
||||
|
||||
func valuesEqual(a, b interface{}) bool {
|
||||
af, aok := a.(float64)
|
||||
bf, bok := b.(float64)
|
||||
if aok && bok {
|
||||
return nearlyEqual(af, bf)
|
||||
}
|
||||
return a == b
|
||||
}
|
||||
|
||||
var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"}
|
||||
|
||||
type upstreamResult struct {
|
||||
@@ -87,7 +107,23 @@ func FetchUpstreamRatios(c *gin.Context) {
|
||||
|
||||
sem := make(chan struct{}, maxConcurrentFetches)
|
||||
|
||||
client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}}
|
||||
dialer := &net.Dialer{Timeout: 10 * time.Second}
|
||||
transport := &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, ResponseHeaderTimeout: 10 * time.Second}
|
||||
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
host = addr
|
||||
}
|
||||
// 对 github.io 优先尝试 IPv4,失败则回退 IPv6
|
||||
if strings.HasSuffix(host, "github.io") {
|
||||
if conn, err := dialer.DialContext(ctx, "tcp4", addr); err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
return dialer.DialContext(ctx, "tcp6", addr)
|
||||
}
|
||||
return dialer.DialContext(ctx, network, addr)
|
||||
}
|
||||
client := &http.Client{Transport: transport}
|
||||
|
||||
for _, chn := range upstreams {
|
||||
wg.Add(1)
|
||||
@@ -98,12 +134,17 @@ func FetchUpstreamRatios(c *gin.Context) {
|
||||
defer func() { <-sem }()
|
||||
|
||||
endpoint := chItem.Endpoint
|
||||
if endpoint == "" {
|
||||
endpoint = defaultEndpoint
|
||||
} else if !strings.HasPrefix(endpoint, "/") {
|
||||
endpoint = "/" + endpoint
|
||||
var fullURL string
|
||||
if strings.HasPrefix(endpoint, "http://") || strings.HasPrefix(endpoint, "https://") {
|
||||
fullURL = endpoint
|
||||
} else {
|
||||
if endpoint == "" {
|
||||
endpoint = defaultEndpoint
|
||||
} else if !strings.HasPrefix(endpoint, "/") {
|
||||
endpoint = "/" + endpoint
|
||||
}
|
||||
fullURL = chItem.BaseURL + endpoint
|
||||
}
|
||||
fullURL := chItem.BaseURL + endpoint
|
||||
|
||||
uniqueName := chItem.Name
|
||||
if chItem.ID != 0 {
|
||||
@@ -120,10 +161,19 @@ func FetchUpstreamRatios(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := client.Do(httpReq)
|
||||
if err != nil {
|
||||
logger.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error())
|
||||
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
||||
// 简单重试:最多 3 次,指数退避
|
||||
var resp *http.Response
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < 3; attempt++ {
|
||||
resp, lastErr = client.Do(httpReq)
|
||||
if lastErr == nil {
|
||||
break
|
||||
}
|
||||
time.Sleep(time.Duration(200*(1<<attempt)) * time.Millisecond)
|
||||
}
|
||||
if lastErr != nil {
|
||||
logger.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+lastErr.Error())
|
||||
ch <- upstreamResult{Name: uniqueName, Err: lastErr.Error()}
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
@@ -132,6 +182,12 @@ func FetchUpstreamRatios(c *gin.Context) {
|
||||
ch <- upstreamResult{Name: uniqueName, Err: resp.Status}
|
||||
return
|
||||
}
|
||||
|
||||
// Content-Type 和响应体大小校验
|
||||
if ct := resp.Header.Get("Content-Type"); ct != "" && !strings.Contains(strings.ToLower(ct), "application/json") {
|
||||
logger.LogWarn(c.Request.Context(), "unexpected content-type from "+chItem.Name+": "+ct)
|
||||
}
|
||||
limited := io.LimitReader(resp.Body, maxRatioConfigBytes)
|
||||
// 兼容两种上游接口格式:
|
||||
// type1: /api/ratio_config -> data 为 map[string]any,包含 model_ratio/completion_ratio/cache_ratio/model_price
|
||||
// type2: /api/pricing -> data 为 []Pricing 列表,需要转换为与 type1 相同的 map 格式
|
||||
@@ -141,7 +197,7 @@ func FetchUpstreamRatios(c *gin.Context) {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
|
||||
if err := json.NewDecoder(limited).Decode(&body); err != nil {
|
||||
logger.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error())
|
||||
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
||||
return
|
||||
@@ -152,6 +208,8 @@ func FetchUpstreamRatios(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 若 Data 为空,将继续按 type1 尝试解析(与多数静态 ratio_config 兼容)
|
||||
|
||||
// 尝试按 type1 解析
|
||||
var type1Data map[string]any
|
||||
if err := json.Unmarshal(body.Data, &type1Data); err == nil {
|
||||
@@ -357,9 +415,9 @@ func buildDifferences(localData map[string]any, successfulChannels []struct {
|
||||
upstreamValue = val
|
||||
hasUpstreamValue = true
|
||||
|
||||
if localValue != nil && localValue != val {
|
||||
if localValue != nil && !valuesEqual(localValue, val) {
|
||||
hasDifference = true
|
||||
} else if localValue == val {
|
||||
} else if valuesEqual(localValue, val) {
|
||||
upstreamValue = "same"
|
||||
}
|
||||
}
|
||||
@@ -466,6 +524,13 @@ func GetSyncableChannels(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
syncableChannels = append(syncableChannels, dto.SyncableChannel{
|
||||
ID: -100,
|
||||
Name: "官方倍率预设",
|
||||
BaseURL: "https://basellm.github.io",
|
||||
Status: 1,
|
||||
})
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
|
||||
@@ -21,6 +21,8 @@ import (
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
@@ -61,8 +63,8 @@ func geminiRelayHandler(c *gin.Context, info *relaycommon.RelayInfo) *types.NewA
|
||||
func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
||||
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
group := c.GetString("group")
|
||||
originalModel := c.GetString("original_model")
|
||||
group := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
|
||||
originalModel := common.GetContextKeyString(c, constant.ContextKeyOriginalModel)
|
||||
|
||||
var (
|
||||
newAPIError *types.NewAPIError
|
||||
@@ -127,12 +129,16 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
||||
return
|
||||
}
|
||||
|
||||
relayInfo.SetPromptTokens(tokens)
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, tokens, meta)
|
||||
if err != nil {
|
||||
newAPIError = types.NewError(err, types.ErrorCodeModelPriceError)
|
||||
return
|
||||
}
|
||||
|
||||
// common.SetContextKey(c, constant.ContextKeyTokenCountMeta, meta)
|
||||
|
||||
preConsumedQuota, newAPIError := service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
if newAPIError != nil {
|
||||
return
|
||||
@@ -170,35 +176,9 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
||||
|
||||
if newAPIError == nil {
|
||||
return
|
||||
} else {
|
||||
if constant.ErrorLogEnabled && types.IsRecordErrorLog(newAPIError) {
|
||||
// 保存错误日志到mysql中
|
||||
userId := c.GetInt("id")
|
||||
tokenName := c.GetString("token_name")
|
||||
modelName := c.GetString("original_model")
|
||||
tokenId := c.GetInt("token_id")
|
||||
userGroup := c.GetString("group")
|
||||
channelId := c.GetInt("channel_id")
|
||||
other := make(map[string]interface{})
|
||||
other["error_type"] = newAPIError.GetErrorType()
|
||||
other["error_code"] = newAPIError.GetErrorCode()
|
||||
other["status_code"] = newAPIError.StatusCode
|
||||
other["channel_id"] = channelId
|
||||
other["channel_name"] = c.GetString("channel_name")
|
||||
other["channel_type"] = c.GetInt("channel_type")
|
||||
adminInfo := make(map[string]interface{})
|
||||
adminInfo["use_channel"] = c.GetStringSlice("use_channel")
|
||||
isMultiKey := common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey)
|
||||
if isMultiKey {
|
||||
adminInfo["is_multi_key"] = true
|
||||
adminInfo["multi_key_index"] = common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex)
|
||||
}
|
||||
other["admin_info"] = adminInfo
|
||||
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, newAPIError.MaskSensitiveError(), tokenId, 0, false, userGroup, other)
|
||||
}
|
||||
}
|
||||
|
||||
go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
||||
processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
||||
|
||||
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
|
||||
break
|
||||
@@ -296,12 +276,42 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b
|
||||
}
|
||||
|
||||
func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
|
||||
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
|
||||
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
|
||||
logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
|
||||
if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
|
||||
service.DisableChannel(channelError, err.Error())
|
||||
|
||||
gopool.Go(func() {
|
||||
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
|
||||
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
|
||||
if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
|
||||
service.DisableChannel(channelError, err.Error())
|
||||
}
|
||||
})
|
||||
|
||||
if constant.ErrorLogEnabled && types.IsRecordErrorLog(err) {
|
||||
// 保存错误日志到mysql中
|
||||
userId := c.GetInt("id")
|
||||
tokenName := c.GetString("token_name")
|
||||
modelName := c.GetString("original_model")
|
||||
tokenId := c.GetInt("token_id")
|
||||
userGroup := c.GetString("group")
|
||||
channelId := c.GetInt("channel_id")
|
||||
other := make(map[string]interface{})
|
||||
other["error_type"] = err.GetErrorType()
|
||||
other["error_code"] = err.GetErrorCode()
|
||||
other["status_code"] = err.StatusCode
|
||||
other["channel_id"] = channelId
|
||||
other["channel_name"] = c.GetString("channel_name")
|
||||
other["channel_type"] = c.GetInt("channel_type")
|
||||
adminInfo := make(map[string]interface{})
|
||||
adminInfo["use_channel"] = c.GetStringSlice("use_channel")
|
||||
isMultiKey := common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey)
|
||||
if isMultiKey {
|
||||
adminInfo["is_multi_key"] = true
|
||||
adminInfo["multi_key_index"] = common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex)
|
||||
}
|
||||
other["admin_info"] = adminInfo
|
||||
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveError(), tokenId, 0, false, userGroup, other)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func RelayMidjourney(c *gin.Context) {
|
||||
@@ -374,11 +384,14 @@ func RelayNotFound(c *gin.Context) {
|
||||
func RelayTask(c *gin.Context) {
|
||||
retryTimes := common.RetryTimes
|
||||
channelId := c.GetInt("channel_id")
|
||||
relayMode := c.GetInt("relay_mode")
|
||||
group := c.GetString("group")
|
||||
originalModel := c.GetString("original_model")
|
||||
c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
|
||||
taskErr := taskRelayHandler(c, relayMode)
|
||||
relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatTask, nil, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
taskErr := taskRelayHandler(c, relayInfo)
|
||||
if taskErr == nil {
|
||||
retryTimes = 0
|
||||
}
|
||||
@@ -398,7 +411,7 @@ func RelayTask(c *gin.Context) {
|
||||
|
||||
requestBody, _ := common.GetRequestBody(c)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
taskErr = taskRelayHandler(c, relayMode)
|
||||
taskErr = taskRelayHandler(c, relayInfo)
|
||||
}
|
||||
useChannel := c.GetStringSlice("use_channel")
|
||||
if len(useChannel) > 1 {
|
||||
@@ -413,13 +426,13 @@ func RelayTask(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
|
||||
func taskRelayHandler(c *gin.Context, relayInfo *relaycommon.RelayInfo) *dto.TaskError {
|
||||
var err *dto.TaskError
|
||||
switch relayMode {
|
||||
switch relayInfo.RelayMode {
|
||||
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeVideoFetchByID:
|
||||
err = relay.RelayTaskFetch(c, relayMode)
|
||||
err = relay.RelayTaskFetch(c, relayInfo.RelayMode)
|
||||
default:
|
||||
err = relay.RelayTaskSubmit(c, relayMode)
|
||||
err = relay.RelayTaskSubmit(c, relayInfo)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -113,7 +113,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
task.StartTime = now
|
||||
}
|
||||
case model.TaskStatusSuccess:
|
||||
task.Progress = "100%"
|
||||
task.Progress = "100%"
|
||||
if task.FinishTime == 0 {
|
||||
task.FinishTime = now
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -82,6 +83,57 @@ func GetTokenStatus(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
func GetTokenUsage(c *gin.Context) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"success": false,
|
||||
"message": "No Authorization header",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
parts := strings.Split(authHeader, " ")
|
||||
if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"success": false,
|
||||
"message": "Invalid Bearer token",
|
||||
})
|
||||
return
|
||||
}
|
||||
tokenKey := parts[1]
|
||||
|
||||
token, err := model.GetTokenByKey(strings.TrimPrefix(tokenKey, "sk-"), false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
expiredAt := token.ExpiredTime
|
||||
if expiredAt == -1 {
|
||||
expiredAt = 0
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": true,
|
||||
"message": "ok",
|
||||
"data": gin.H{
|
||||
"object": "token_usage",
|
||||
"name": token.Name,
|
||||
"total_granted": token.RemainQuota + token.UsedQuota,
|
||||
"total_used": token.UsedQuota,
|
||||
"total_available": token.RemainQuota,
|
||||
"unlimited_quota": token.UnlimitedQuota,
|
||||
"model_limits": token.GetModelLimitsMap(),
|
||||
"model_limits_enabled": token.ModelLimitsEnabled,
|
||||
"expires_at": expiredAt,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func AddToken(c *gin.Context) {
|
||||
token := model.Token{}
|
||||
err := c.ShouldBindJSON(&token)
|
||||
|
||||
@@ -31,7 +31,7 @@ type Monitor struct {
|
||||
|
||||
type UptimeGroupResult struct {
|
||||
CategoryName string `json:"categoryName"`
|
||||
Monitors []Monitor `json:"monitors"`
|
||||
Monitors []Monitor `json:"monitors"`
|
||||
}
|
||||
|
||||
func getAndDecode(ctx context.Context, client *http.Client, url string, dest interface{}) error {
|
||||
@@ -57,29 +57,29 @@ func fetchGroupData(ctx context.Context, client *http.Client, groupConfig map[st
|
||||
url, _ := groupConfig["url"].(string)
|
||||
slug, _ := groupConfig["slug"].(string)
|
||||
categoryName, _ := groupConfig["categoryName"].(string)
|
||||
|
||||
|
||||
result := UptimeGroupResult{
|
||||
CategoryName: categoryName,
|
||||
Monitors: []Monitor{},
|
||||
Monitors: []Monitor{},
|
||||
}
|
||||
|
||||
|
||||
if url == "" || slug == "" {
|
||||
return result
|
||||
}
|
||||
|
||||
baseURL := strings.TrimSuffix(url, "/")
|
||||
|
||||
|
||||
var statusData struct {
|
||||
PublicGroupList []struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
MonitorList []struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
} `json:"monitorList"`
|
||||
} `json:"publicGroupList"`
|
||||
}
|
||||
|
||||
|
||||
var heartbeatData struct {
|
||||
HeartbeatList map[string][]struct {
|
||||
Status int `json:"status"`
|
||||
@@ -88,11 +88,11 @@ func fetchGroupData(ctx context.Context, client *http.Client, groupConfig map[st
|
||||
}
|
||||
|
||||
g, gCtx := errgroup.WithContext(ctx)
|
||||
g.Go(func() error {
|
||||
return getAndDecode(gCtx, client, baseURL+apiStatusPath+slug, &statusData)
|
||||
g.Go(func() error {
|
||||
return getAndDecode(gCtx, client, baseURL+apiStatusPath+slug, &statusData)
|
||||
})
|
||||
g.Go(func() error {
|
||||
return getAndDecode(gCtx, client, baseURL+apiHeartbeatPath+slug, &heartbeatData)
|
||||
g.Go(func() error {
|
||||
return getAndDecode(gCtx, client, baseURL+apiHeartbeatPath+slug, &heartbeatData)
|
||||
})
|
||||
|
||||
if g.Wait() != nil {
|
||||
@@ -139,7 +139,7 @@ func GetUptimeKumaStatus(c *gin.Context) {
|
||||
|
||||
client := &http.Client{Timeout: httpTimeout}
|
||||
results := make([]UptimeGroupResult, len(groups))
|
||||
|
||||
|
||||
g, gCtx := errgroup.WithContext(ctx)
|
||||
for i, group := range groups {
|
||||
i, group := i, group
|
||||
@@ -148,7 +148,7 @@ func GetUptimeKumaStatus(c *gin.Context) {
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
g.Wait()
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": results})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -210,6 +210,7 @@ func Register(c *gin.Context) {
|
||||
Password: user.Password,
|
||||
DisplayName: user.Username,
|
||||
InviterId: inviterId,
|
||||
Role: common.RoleCommonUser, // 明确设置角色为普通用户
|
||||
}
|
||||
if common.EmailVerificationEnabled {
|
||||
cleanUser.Email = user.Email
|
||||
@@ -426,6 +427,7 @@ func GetAffCode(c *gin.Context) {
|
||||
|
||||
func GetSelf(c *gin.Context) {
|
||||
id := c.GetInt("id")
|
||||
userRole := c.GetInt("role")
|
||||
user, err := model.GetUserById(id, false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
@@ -434,14 +436,134 @@ func GetSelf(c *gin.Context) {
|
||||
// Hide admin remarks: set to empty to trigger omitempty tag, ensuring the remark field is not included in JSON returned to regular users
|
||||
user.Remark = ""
|
||||
|
||||
// 计算用户权限信息
|
||||
permissions := calculateUserPermissions(userRole)
|
||||
|
||||
// 获取用户设置并提取sidebar_modules
|
||||
userSetting := user.GetSetting()
|
||||
|
||||
// 构建响应数据,包含用户信息和权限
|
||||
responseData := map[string]interface{}{
|
||||
"id": user.Id,
|
||||
"username": user.Username,
|
||||
"display_name": user.DisplayName,
|
||||
"role": user.Role,
|
||||
"status": user.Status,
|
||||
"email": user.Email,
|
||||
"group": user.Group,
|
||||
"quota": user.Quota,
|
||||
"used_quota": user.UsedQuota,
|
||||
"request_count": user.RequestCount,
|
||||
"aff_code": user.AffCode,
|
||||
"aff_count": user.AffCount,
|
||||
"aff_quota": user.AffQuota,
|
||||
"aff_history_quota": user.AffHistoryQuota,
|
||||
"inviter_id": user.InviterId,
|
||||
"linux_do_id": user.LinuxDOId,
|
||||
"setting": user.Setting,
|
||||
"stripe_customer": user.StripeCustomer,
|
||||
"sidebar_modules": userSetting.SidebarModules, // 正确提取sidebar_modules字段
|
||||
"permissions": permissions, // 新增权限字段
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": user,
|
||||
"data": responseData,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 计算用户权限的辅助函数
|
||||
func calculateUserPermissions(userRole int) map[string]interface{} {
|
||||
permissions := map[string]interface{}{}
|
||||
|
||||
// 根据用户角色计算权限
|
||||
if userRole == common.RoleRootUser {
|
||||
// 超级管理员不需要边栏设置功能
|
||||
permissions["sidebar_settings"] = false
|
||||
permissions["sidebar_modules"] = map[string]interface{}{}
|
||||
} else if userRole == common.RoleAdminUser {
|
||||
// 管理员可以设置边栏,但不包含系统设置功能
|
||||
permissions["sidebar_settings"] = true
|
||||
permissions["sidebar_modules"] = map[string]interface{}{
|
||||
"admin": map[string]interface{}{
|
||||
"setting": false, // 管理员不能访问系统设置
|
||||
},
|
||||
}
|
||||
} else {
|
||||
// 普通用户只能设置个人功能,不包含管理员区域
|
||||
permissions["sidebar_settings"] = true
|
||||
permissions["sidebar_modules"] = map[string]interface{}{
|
||||
"admin": false, // 普通用户不能访问管理员区域
|
||||
}
|
||||
}
|
||||
|
||||
return permissions
|
||||
}
|
||||
|
||||
// 根据用户角色生成默认的边栏配置
|
||||
func generateDefaultSidebarConfig(userRole int) string {
|
||||
defaultConfig := map[string]interface{}{}
|
||||
|
||||
// 聊天区域 - 所有用户都可以访问
|
||||
defaultConfig["chat"] = map[string]interface{}{
|
||||
"enabled": true,
|
||||
"playground": true,
|
||||
"chat": true,
|
||||
}
|
||||
|
||||
// 控制台区域 - 所有用户都可以访问
|
||||
defaultConfig["console"] = map[string]interface{}{
|
||||
"enabled": true,
|
||||
"detail": true,
|
||||
"token": true,
|
||||
"log": true,
|
||||
"midjourney": true,
|
||||
"task": true,
|
||||
}
|
||||
|
||||
// 个人中心区域 - 所有用户都可以访问
|
||||
defaultConfig["personal"] = map[string]interface{}{
|
||||
"enabled": true,
|
||||
"topup": true,
|
||||
"personal": true,
|
||||
}
|
||||
|
||||
// 管理员区域 - 根据角色决定
|
||||
if userRole == common.RoleAdminUser {
|
||||
// 管理员可以访问管理员区域,但不能访问系统设置
|
||||
defaultConfig["admin"] = map[string]interface{}{
|
||||
"enabled": true,
|
||||
"channel": true,
|
||||
"models": true,
|
||||
"redemption": true,
|
||||
"user": true,
|
||||
"setting": false, // 管理员不能访问系统设置
|
||||
}
|
||||
} else if userRole == common.RoleRootUser {
|
||||
// 超级管理员可以访问所有功能
|
||||
defaultConfig["admin"] = map[string]interface{}{
|
||||
"enabled": true,
|
||||
"channel": true,
|
||||
"models": true,
|
||||
"redemption": true,
|
||||
"user": true,
|
||||
"setting": true,
|
||||
}
|
||||
}
|
||||
// 普通用户不包含admin区域
|
||||
|
||||
// 转换为JSON字符串
|
||||
configBytes, err := json.Marshal(defaultConfig)
|
||||
if err != nil {
|
||||
common.SysLog("生成默认边栏配置失败: " + err.Error())
|
||||
return ""
|
||||
}
|
||||
|
||||
return string(configBytes)
|
||||
}
|
||||
|
||||
func GetUserModels(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
@@ -528,8 +650,8 @@ func UpdateUser(c *gin.Context) {
|
||||
}
|
||||
|
||||
func UpdateSelf(c *gin.Context) {
|
||||
var user model.User
|
||||
err := json.NewDecoder(c.Request.Body).Decode(&user)
|
||||
var requestData map[string]interface{}
|
||||
err := json.NewDecoder(c.Request.Body).Decode(&requestData)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -537,6 +659,60 @@ func UpdateSelf(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否是sidebar_modules更新请求
|
||||
if sidebarModules, exists := requestData["sidebar_modules"]; exists {
|
||||
userId := c.GetInt("id")
|
||||
user, err := model.GetUserById(userId, false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 获取当前用户设置
|
||||
currentSetting := user.GetSetting()
|
||||
|
||||
// 更新sidebar_modules字段
|
||||
if sidebarModulesStr, ok := sidebarModules.(string); ok {
|
||||
currentSetting.SidebarModules = sidebarModulesStr
|
||||
}
|
||||
|
||||
// 保存更新后的设置
|
||||
user.SetSetting(currentSetting)
|
||||
if err := user.Update(false); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "更新设置失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "设置更新成功",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 原有的用户信息更新逻辑
|
||||
var user model.User
|
||||
requestDataBytes, err := json.Marshal(requestData)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无效的参数",
|
||||
})
|
||||
return
|
||||
}
|
||||
err = json.Unmarshal(requestDataBytes, &user)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无效的参数",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if user.Password == "" {
|
||||
user.Password = "$I_LOVE_U" // make Validator happy :)
|
||||
}
|
||||
@@ -679,6 +855,7 @@ func CreateUser(c *gin.Context) {
|
||||
Username: user.Username,
|
||||
Password: user.Password,
|
||||
DisplayName: user.DisplayName,
|
||||
Role: user.Role, // 保持管理员设置的角色
|
||||
}
|
||||
if err := cleanUser.Insert(0); err != nil {
|
||||
common.ApiError(c, err)
|
||||
@@ -844,18 +1021,64 @@ type topUpRequest struct {
|
||||
Key string `json:"key"`
|
||||
}
|
||||
|
||||
var topUpLock = sync.Mutex{}
|
||||
var topUpLocks sync.Map
|
||||
var topUpCreateLock sync.Mutex
|
||||
|
||||
type topUpTryLock struct {
|
||||
ch chan struct{}
|
||||
}
|
||||
|
||||
func newTopUpTryLock() *topUpTryLock {
|
||||
return &topUpTryLock{ch: make(chan struct{}, 1)}
|
||||
}
|
||||
|
||||
func (l *topUpTryLock) TryLock() bool {
|
||||
select {
|
||||
case l.ch <- struct{}{}:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (l *topUpTryLock) Unlock() {
|
||||
select {
|
||||
case <-l.ch:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func getTopUpLock(userID int) *topUpTryLock {
|
||||
if v, ok := topUpLocks.Load(userID); ok {
|
||||
return v.(*topUpTryLock)
|
||||
}
|
||||
topUpCreateLock.Lock()
|
||||
defer topUpCreateLock.Unlock()
|
||||
if v, ok := topUpLocks.Load(userID); ok {
|
||||
return v.(*topUpTryLock)
|
||||
}
|
||||
l := newTopUpTryLock()
|
||||
topUpLocks.Store(userID, l)
|
||||
return l
|
||||
}
|
||||
|
||||
func TopUp(c *gin.Context) {
|
||||
topUpLock.Lock()
|
||||
defer topUpLock.Unlock()
|
||||
id := c.GetInt("id")
|
||||
lock := getTopUpLock(id)
|
||||
if !lock.TryLock() {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "充值处理中,请稍后重试",
|
||||
})
|
||||
return
|
||||
}
|
||||
defer lock.Unlock()
|
||||
req := topUpRequest{}
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
id := c.GetInt("id")
|
||||
quota, err := model.Redeem(req.Key, id)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
@@ -866,7 +1089,6 @@ func TopUp(c *gin.Context) {
|
||||
"message": "",
|
||||
"data": quota,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
type UpdateUserSettingRequest struct {
|
||||
@@ -875,6 +1097,7 @@ type UpdateUserSettingRequest struct {
|
||||
WebhookUrl string `json:"webhook_url,omitempty"`
|
||||
WebhookSecret string `json:"webhook_secret,omitempty"`
|
||||
NotificationEmail string `json:"notification_email,omitempty"`
|
||||
BarkUrl string `json:"bark_url,omitempty"`
|
||||
AcceptUnsetModelRatioModel bool `json:"accept_unset_model_ratio_model"`
|
||||
RecordIpLog bool `json:"record_ip_log"`
|
||||
}
|
||||
@@ -890,7 +1113,7 @@ func UpdateUserSetting(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 验证预警类型
|
||||
if req.QuotaWarningType != dto.NotifyTypeEmail && req.QuotaWarningType != dto.NotifyTypeWebhook {
|
||||
if req.QuotaWarningType != dto.NotifyTypeEmail && req.QuotaWarningType != dto.NotifyTypeWebhook && req.QuotaWarningType != dto.NotifyTypeBark {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无效的预警类型",
|
||||
@@ -938,6 +1161,33 @@ func UpdateUserSetting(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// 如果是Bark类型,验证Bark URL
|
||||
if req.QuotaWarningType == dto.NotifyTypeBark {
|
||||
if req.BarkUrl == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "Bark推送URL不能为空",
|
||||
})
|
||||
return
|
||||
}
|
||||
// 验证URL格式
|
||||
if _, err := url.ParseRequestURI(req.BarkUrl); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无效的Bark推送URL",
|
||||
})
|
||||
return
|
||||
}
|
||||
// 检查是否是HTTP或HTTPS
|
||||
if !strings.HasPrefix(req.BarkUrl, "https://") && !strings.HasPrefix(req.BarkUrl, "http://") {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "Bark推送URL必须以http://或https://开头",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
userId := c.GetInt("id")
|
||||
user, err := model.GetUserById(userId, true)
|
||||
if err != nil {
|
||||
@@ -966,6 +1216,11 @@ func UpdateUserSetting(c *gin.Context) {
|
||||
settings.NotificationEmail = req.NotificationEmail
|
||||
}
|
||||
|
||||
// 如果是Bark类型,添加Bark URL到设置中
|
||||
if req.QuotaWarningType == dto.NotifyTypeBark {
|
||||
settings.BarkUrl = req.BarkUrl
|
||||
}
|
||||
|
||||
// 更新用户设置
|
||||
user.SetSetting(settings)
|
||||
if err := user.Update(false); err != nil {
|
||||
|
||||
|
After Width: | Height: | Size: 5.0 KiB |
|
After Width: | Height: | Size: 11 KiB |
@@ -1,55 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<svg id="_图层_2" data-name="图层_2" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 198.45 66.73">
|
||||
<defs>
|
||||
<style>
|
||||
.cls-1 {
|
||||
fill: #ea5e5d;
|
||||
}
|
||||
|
||||
.cls-2 {
|
||||
fill: #23af69;
|
||||
}
|
||||
|
||||
.cls-3 {
|
||||
fill: #ea5756;
|
||||
}
|
||||
</style>
|
||||
</defs>
|
||||
<g id="_图层_1-2" data-name="图层_1">
|
||||
<g>
|
||||
<g>
|
||||
<g>
|
||||
<path class="cls-1" d="M16.72,51.21c-4.45,0-8.64-1.78-11.81-5.01-3.17-3.23-4.91-7.51-4.91-12.04s1.74-8.81,4.91-12.04,7.36-5.01,11.81-5.01,8.71,1.82,11.82,4.99c2.32,2.36,2.32,6.2,0,8.56-2.32,2.36-6.08,2.36-8.4,0-.9-.92-2.15-1.45-3.43-1.45-2.63,0-4.85,2.26-4.85,4.94s2.22,4.94,4.85,4.94c1.28,0,2.52-.53,3.43-1.45,2.32-2.36,6.08-2.36,8.4,0,2.32,2.36,2.32,6.2,0,8.56-3.11,3.17-7.42,4.99-11.82,4.99Z"/>
|
||||
<path class="cls-1" d="M32.05,66.73c-4.45,0-8.64-1.78-11.81-5.01s-4.91-7.51-4.91-12.04,1.79-8.88,4.9-12.06c2.32-2.36,6.08-2.36,8.4,0,2.32,2.36,2.32,6.2,0,8.56-.9.92-1.42,2.19-1.42,3.49,0,2.68,2.22,4.94,4.85,4.94s4.85-2.26,4.85-4.94c0-.95-.23-2.31-1.32-3.43-3.13-3.19-4.92-7.6-4.92-12.09s1.74-8.81,4.91-12.04,7.36-5.01,11.81-5.01,8.64,1.78,11.81,5.01,4.91,7.51,4.91,12.04-1.79,8.88-4.9,12.06c-2.32,2.36-6.08,2.36-8.4,0-2.32-2.36-2.32-6.2,0-8.56.9-.92,1.42-2.19,1.42-3.49,0-2.68-2.22-4.94-4.85-4.94s-4.85,2.26-4.85,4.94c0,1.31.53,2.6,1.45,3.53,3.1,3.16,4.8,7.42,4.8,11.99s-1.74,8.81-4.91,12.04c-3.17,3.23-7.36,5.01-11.81,5.01Z"/>
|
||||
</g>
|
||||
<path class="cls-2" d="M32.05,19.09l-9.72-9.12c-1.5-1.4-1.57-3.75-.17-5.25,1.4-1.49,3.75-1.57,5.25-.17l3.89,3.65,5.53-6.83c1.29-1.59,3.63-1.84,5.22-.55,1.59,1.29,1.84,3.63.55,5.22l-10.56,13.05Z"/>
|
||||
</g>
|
||||
<g>
|
||||
<path class="cls-3" d="M93.93,24.6l.55-.39c.69-.4,1.17-.61,1.46-.61.63,0,1.3.57,2.03,1.7.44.71.67,1.27.67,1.7s-.14.78-.41,1.06c-.27.28-.59.54-.96.76-.36.22-.71.43-1.05.64-.33.2-1.02.47-2.05.79-1.03.32-2.03.49-2.99.49s-1.93-.13-2.91-.38c-.98-.25-1.99-.68-3.03-1.27-1.04-.6-1.98-1.32-2.81-2.18-.83-.86-1.51-1.96-2.05-3.31-.54-1.35-.8-2.81-.8-4.38s.26-3.01.79-4.29c.53-1.28,1.2-2.35,2.02-3.19.82-.84,1.75-1.54,2.81-2.11,1.98-1.09,3.97-1.64,5.98-1.64.95,0,1.92.15,2.9.44.98.29,1.72.59,2.23.9l.73.42c.36.22.65.4.85.55.53.42.79.91.79,1.44s-.21,1.1-.64,1.68c-.79,1.09-1.5,1.64-2.12,1.64-.36,0-.88-.22-1.55-.67-.85-.69-1.98-1.03-3.4-1.03-1.31,0-2.61.46-3.88,1.36-.61.44-1.11,1.07-1.52,1.88-.4.81-.61,1.72-.61,2.75s.2,1.94.61,2.75c.4.81.92,1.45,1.55,1.91,1.23.89,2.52,1.34,3.85,1.34.63,0,1.22-.08,1.77-.24.56-.16.96-.32,1.2-.49Z"/>
|
||||
<path class="cls-3" d="M114.38,9.07c.16-.3.43-.52.82-.64.38-.12.87-.18,1.46-.18s1.05.05,1.4.15c.34.1.61.22.79.36.18.14.32.34.42.61.1.34.15.87.15,1.58v16.84c0,.47-.02.81-.05,1.05-.03.23-.13.5-.29.8-.28.55-1.07.82-2.37.82-1.42,0-2.25-.37-2.49-1.12-.12-.34-.18-.87-.18-1.58v-6.16h-8.04v6.19c0,.47-.02.81-.05,1.05-.03.23-.13.5-.29.8-.28.55-1.07.82-2.37.82-1.42,0-2.25-.37-2.49-1.12-.12-.34-.18-.87-.18-1.58V10.92c0-.46.02-.81.05-1.05.03-.23.13-.5.29-.8.28-.55,1.07-.82,2.37-.82,1.42,0,2.25.37,2.52,1.12.1.34.15.87.15,1.58v6.19h8.04v-6.22c0-.46.02-.81.05-1.05.03-.23.13-.5.29-.8Z"/>
|
||||
<path class="cls-3" d="M127.21,25.1h9.34c.47,0,.81.02,1.05.05.23.03.5.13.8.29.55.28.82,1.07.82,2.37,0,1.42-.37,2.25-1.12,2.49-.34.12-.87.18-1.58.18h-12.01c-1.42,0-2.25-.38-2.49-1.15-.12-.32-.18-.84-.18-1.55V10.9c0-1.03.19-1.73.58-2.11.38-.37,1.11-.56,2.18-.56h11.95c.47,0,.81.02,1.05.05.23.03.5.13.8.29.55.28.82,1.07.82,2.37,0,1.42-.37,2.25-1.12,2.49-.34.12-.87.18-1.58.18h-9.31v3.06h6.01c.46,0,.81.02,1.05.05.23.03.5.13.8.29.55.28.82,1.07.82,2.37,0,1.42-.38,2.25-1.15,2.49-.34.12-.87.18-1.58.18h-5.95v3.06Z"/>
|
||||
<path class="cls-3" d="M196.96,8.79c.99.69,1.49,1.35,1.49,2,0,.38-.23.92-.7,1.61l-6.55,9.8v5.79c0,.47-.02.81-.05,1.05-.03.23-.13.5-.29.8-.16.3-.43.52-.82.64-.38.12-.9.18-1.55.18s-1.16-.06-1.55-.18c-.38-.12-.66-.34-.82-.65-.16-.31-.26-.59-.29-.82-.03-.23-.05-.59-.05-1.08v-5.73l-6.55-9.8c-.47-.69-.7-1.22-.7-1.61,0-.65.44-1.27,1.33-1.87.89-.6,1.53-.9,1.91-.9s.69.08.91.24c.34.22.71.64,1.09,1.24l4.7,7.52,4.7-7.52c.38-.61.72-1.01,1-1.2s.61-.29.99-.29.97.25,1.77.76Z"/>
|
||||
<g>
|
||||
<path class="cls-3" d="M81.93,56.63c-.53-.65-.79-1.23-.79-1.74s.43-1.2,1.3-2.05c.51-.49,1.04-.73,1.61-.73s1.36.51,2.37,1.52c.28.34.69.67,1.21.99.53.31,1.01.47,1.46.47,1.88,0,2.82-.77,2.82-2.31,0-.46-.26-.85-.77-1.17-.52-.31-1.16-.54-1.93-.68-.77-.14-1.6-.37-2.49-.68-.89-.31-1.72-.68-2.49-1.11-.77-.42-1.41-1.1-1.93-2.02-.52-.92-.77-2.03-.77-3.32,0-1.78.66-3.33,1.99-4.66s3.13-1.99,5.42-1.99c1.21,0,2.32.16,3.32.47,1,.31,1.69.63,2.08.96l.76.58c.63.59.94,1.08.94,1.49s-.24.96-.73,1.67c-.69,1.01-1.4,1.52-2.12,1.52-.42,0-.95-.2-1.58-.61-.06-.04-.18-.14-.35-.3-.17-.16-.33-.29-.47-.39-.42-.26-.97-.39-1.62-.39s-1.2.16-1.64.47c-.43.31-.65.75-.65,1.3s.26,1.01.77,1.35c.52.34,1.16.58,1.93.7.77.12,1.61.31,2.52.56.91.25,1.75.56,2.52.93.77.36,1.41,1,1.93,1.9.52.9.77,2.01.77,3.32s-.26,2.47-.79,3.47c-.53,1-1.21,1.77-2.06,2.32-1.64,1.07-3.39,1.61-5.25,1.61-.95,0-1.85-.12-2.7-.35-.85-.23-1.54-.52-2.06-.86-1.07-.65-1.82-1.27-2.24-1.88l-.27-.33Z"/>
|
||||
<path class="cls-3" d="M100.74,37.49h16.87c.65,0,1.12.08,1.43.23.3.15.51.39.61.71.1.32.15.75.15,1.27s-.05.95-.15,1.26c-.1.31-.27.53-.52.65-.36.18-.88.27-1.55.27h-5.79v15.26c0,.47-.02.81-.05,1.03s-.12.48-.27.77c-.15.29-.42.5-.8.62-.38.12-.89.18-1.52.18s-1.13-.06-1.5-.18c-.37-.12-.64-.33-.79-.62-.15-.29-.24-.56-.27-.79-.03-.23-.05-.58-.05-1.05v-15.23h-5.82c-.65,0-1.12-.08-1.43-.23-.3-.15-.51-.39-.61-.71-.1-.32-.15-.75-.15-1.27s.05-.95.15-1.26c.1-.31.27-.53.52-.65.36-.18.88-.27,1.55-.27Z"/>
|
||||
<path class="cls-3" d="M135.99,38.34c.2-.32.5-.55.88-.67.38-.12.86-.18,1.44-.18s1.04.05,1.38.15c.34.1.61.22.79.36.18.14.31.35.39.64.12.34.18.87.18,1.58v9.16c0,2.67-.83,5.1-2.49,7.28-.81,1.03-1.85,1.87-3.12,2.5s-2.68.96-4.23.96-2.95-.32-4.22-.97c-1.26-.65-2.29-1.5-3.08-2.55-1.64-2.14-2.46-4.57-2.46-7.28v-9.13c0-.49.02-.84.05-1.08.03-.23.13-.5.29-.8.16-.3.43-.52.82-.64.38-.12.9-.18,1.55-.18s1.16.06,1.55.18c.38.12.65.33.79.64.24.47.36,1.1.36,1.91v9.1c0,1.23.3,2.41.91,3.52.3.57.76,1.02,1.37,1.36.61.34,1.32.52,2.15.52,1.48,0,2.58-.55,3.31-1.64.73-1.09,1.09-2.36,1.09-3.79v-9.28c0-.79.1-1.34.3-1.67Z"/>
|
||||
<path class="cls-3" d="M146.18,37.49l5.61.03c2.93,0,5.51,1.06,7.74,3.17,2.22,2.11,3.34,4.71,3.34,7.8s-1.09,5.73-3.26,7.93c-2.17,2.2-4.81,3.31-7.9,3.31h-5.55c-1.23,0-2-.25-2.31-.76-.24-.42-.36-1.07-.36-1.94v-16.87c0-.49.02-.84.05-1.06s.13-.49.29-.79c.28-.55,1.07-.82,2.37-.82ZM151.79,54.35c1.46,0,2.77-.54,3.94-1.62,1.17-1.08,1.76-2.44,1.76-4.08s-.57-3.01-1.71-4.11c-1.14-1.1-2.48-1.65-4.02-1.65h-2.91v11.47h2.94Z"/>
|
||||
<path class="cls-3" d="M164.84,40.19c0-.46.02-.81.05-1.05.03-.23.13-.5.29-.8.28-.55,1.07-.82,2.37-.82,1.42,0,2.25.37,2.52,1.12.1.34.15.87.15,1.58v16.87c0,.49-.02.84-.05,1.06s-.13.49-.29.79c-.28.55-1.07.82-2.37.82-1.42,0-2.25-.38-2.49-1.15-.12-.32-.18-.84-.18-1.55v-16.87Z"/>
|
||||
<path class="cls-3" d="M183.07,37.24c2.99,0,5.59,1.08,7.8,3.25,2.2,2.16,3.31,4.85,3.31,8.05s-1.05,5.94-3.16,8.19c-2.1,2.26-4.69,3.38-7.77,3.38s-5.69-1.11-7.84-3.34c-2.15-2.22-3.23-4.87-3.23-7.95,0-1.68.3-3.25.91-4.72.61-1.47,1.42-2.7,2.43-3.69,1.01-.99,2.17-1.77,3.49-2.34,1.31-.57,2.67-.85,4.07-.85ZM177.55,48.68c0,1.8.58,3.26,1.74,4.38,1.16,1.12,2.46,1.68,3.9,1.68s2.73-.55,3.88-1.64c1.15-1.09,1.73-2.56,1.73-4.4s-.58-3.32-1.74-4.43c-1.16-1.11-2.46-1.67-3.9-1.67s-2.73.56-3.88,1.68c-1.15,1.12-1.73,2.58-1.73,4.38Z"/>
|
||||
</g>
|
||||
<g>
|
||||
<path class="cls-3" d="M176.92,11.06c-.03-.23-.13-.5-.29-.8-.28-.55-1.07-.82-2.37-.82h-6.55c-1.78,0-3.51.65-5.19,1.94-.81.63-1.48,1.48-2,2.55-.53,1.07-.79,2.27-.79,3.58,0,2.29.76,4.17,2.28,5.64-.44,1.07-1.13,2.66-2.06,4.76-.3.73-.45,1.25-.45,1.58,0,.77.63,1.42,1.88,1.94.65.28,1.17.43,1.56.43s.72-.1.97-.29c.25-.19.44-.39.56-.59.2-.38.99-2.21,2.37-5.49l.94.06h3.82v3.43c0,.47.02.81.05,1.05.03.23.13.5.29.8.28.55,1.07.82,2.37.82,1.42,0,2.25-.37,2.49-1.12.12-.34.18-.87.18-1.58V12.11c0-.46-.02-.81-.05-1.05ZM172.81,19.44c-.09.14-.48.77-1.24.91-.2.04-.37.03-.48.02-.02.14-.04.26-.06.38-.16.83-.38,1.05-.57,1.07-.29.05-.51-.35-.93-.9-.23.01-.46.02-.69.02-.51,0-1.01-.03-1.49-.09-.25-.03-.5-.07-.74-.11-1.18-.32-2.03-1.27-2.03-2.4v-1.37c0-1.13.86-2.08,2.03-2.4.24-.04.49-.08.74-.11.48-.06.98-.09,1.49-.09s1.01.03,1.49.09c.25.03.5.07.74.11.6.16,1.12.49,1.49.93.34.41.55.92.55,1.47v1.37c0,.23-.01.66-.29,1.1Z"/>
|
||||
<circle class="cls-2" cx="167.24" cy="17.67" r=".49"/>
|
||||
<circle class="cls-2" cx="168.88" cy="17.71" r=".49"/>
|
||||
<circle class="cls-2" cx="170.59" cy="17.71" r=".49"/>
|
||||
</g>
|
||||
<g>
|
||||
<path class="cls-3" d="M141.01,8.24c.03-.23.13-.5.29-.8.28-.55,1.07-.82,2.37-.82h6.55c1.78,0,3.51.65,5.19,1.94.81.63,1.48,1.48,2,2.55.53,1.07.79,2.27.79,3.58,0,2.29-.76,4.17-2.28,5.64.44,1.07,1.13,2.66,2.06,4.76.3.73.45,1.25.45,1.58,0,.77-.63,1.42-1.88,1.94-.65.28-1.17.43-1.56.43s-.72-.1-.97-.29c-.25-.19-.44-.39-.56-.59-.2-.38-.99-2.21-2.37-5.49l-.94.06h-3.82v3.43c0,.47-.02.81-.05,1.05-.03.23-.13.5-.29.8-.28.55-1.07.82-2.37.82-1.42,0-2.25-.37-2.49-1.12-.12-.34-.18-.87-.18-1.58V9.28c0-.46.02-.81.05-1.05ZM145.12,16.62c.09.14.48.77,1.24.91.2.04.37.03.48.02.02.14.04.26.06.38.16.83.38,1.05.57,1.07.29.05.51-.35.93-.9.23.01.46.02.69.02.51,0,1.01-.03,1.49-.09.25-.03.5-.07.74-.11,1.18-.32,2.03-1.27,2.03-2.4v-1.37c0-1.13-.86-2.08-2.03-2.4-.24-.04-.49-.08-.74-.11-.48-.06-.98-.09-1.49-.09s-1.01.03-1.49.09c-.25.03-.5.07-.74.11-.6.16-1.12.49-1.49.93-.34.41-.55.92-.55,1.47v1.37c0,.23.01.66.29,1.1Z"/>
|
||||
<circle class="cls-2" cx="150.69" cy="14.84" r=".49"/>
|
||||
<circle class="cls-2" cx="149.05" cy="14.89" r=".49"/>
|
||||
<circle class="cls-2" cx="147.35" cy="14.89" r=".49"/>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 9.5 KiB |
|
After Width: | Height: | Size: 2.0 KiB |
|
Before Width: | Height: | Size: 50 KiB After Width: | Height: | Size: 12 KiB |
|
After Width: | Height: | Size: 11 KiB |
|
Before Width: | Height: | Size: 7.9 KiB |
@@ -26,6 +26,12 @@ func (r *AudioRequest) IsStream(c *gin.Context) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *AudioRequest) SetModelName(modelName string) {
|
||||
if modelName != "" {
|
||||
r.Model = modelName
|
||||
}
|
||||
}
|
||||
|
||||
type AudioResponse struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
@@ -231,7 +231,7 @@ func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
data = common.Interface2String(media.Source.Data)
|
||||
}
|
||||
if data != "" {
|
||||
fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, Data: data})
|
||||
fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, OriginData: data})
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -263,7 +263,7 @@ func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
data = common.Interface2String(media.Source.Data)
|
||||
}
|
||||
if data != "" {
|
||||
fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, Data: data})
|
||||
fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, OriginData: data})
|
||||
}
|
||||
}
|
||||
case "tool_use":
|
||||
@@ -321,8 +321,14 @@ func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
return &tokenCountMeta
|
||||
}
|
||||
|
||||
func (claudeRequest *ClaudeRequest) IsStream(c *gin.Context) bool {
|
||||
return claudeRequest.Stream
|
||||
func (c *ClaudeRequest) IsStream(ctx *gin.Context) bool {
|
||||
return c.Stream
|
||||
}
|
||||
|
||||
func (c *ClaudeRequest) SetModelName(modelName string) {
|
||||
if modelName != "" {
|
||||
c.Model = modelName
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClaudeRequest) SearchToolNameByToolCallId(toolCallId string) string {
|
||||
@@ -482,14 +488,14 @@ func (c *ClaudeResponse) GetClaudeError() *types.ClaudeError {
|
||||
case string:
|
||||
// 处理简单字符串错误
|
||||
return &types.ClaudeError{
|
||||
Type: "error",
|
||||
Type: "upstream_error",
|
||||
Message: err,
|
||||
}
|
||||
default:
|
||||
// 未知类型,尝试转换为字符串
|
||||
return &types.ClaudeError{
|
||||
Type: "unknown_error",
|
||||
Message: fmt.Sprintf("%v", err),
|
||||
Type: "unknown_upstream_error",
|
||||
Message: fmt.Sprintf("unknown_error: %v", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,6 +48,12 @@ func (r *EmbeddingRequest) IsStream(c *gin.Context) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *EmbeddingRequest) SetModelName(modelName string) {
|
||||
if modelName != "" {
|
||||
r.Model = modelName
|
||||
}
|
||||
}
|
||||
|
||||
func (r *EmbeddingRequest) ParseInput() []string {
|
||||
if r.Input == nil {
|
||||
return make([]string, 0)
|
||||
|
||||
@@ -2,11 +2,12 @@ package dto
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/common"
|
||||
"one-api/logger"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type GeminiChatRequest struct {
|
||||
@@ -35,23 +36,23 @@ func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
if part.InlineData != nil && part.InlineData.Data != "" {
|
||||
if strings.HasPrefix(part.InlineData.MimeType, "image/") {
|
||||
files = append(files, &types.FileMeta{
|
||||
FileType: types.FileTypeImage,
|
||||
Data: part.InlineData.Data,
|
||||
FileType: types.FileTypeImage,
|
||||
OriginData: part.InlineData.Data,
|
||||
})
|
||||
} else if strings.HasPrefix(part.InlineData.MimeType, "audio/") {
|
||||
files = append(files, &types.FileMeta{
|
||||
FileType: types.FileTypeAudio,
|
||||
Data: part.InlineData.Data,
|
||||
FileType: types.FileTypeAudio,
|
||||
OriginData: part.InlineData.Data,
|
||||
})
|
||||
} else if strings.HasPrefix(part.InlineData.MimeType, "video/") {
|
||||
files = append(files, &types.FileMeta{
|
||||
FileType: types.FileTypeVideo,
|
||||
Data: part.InlineData.Data,
|
||||
FileType: types.FileTypeVideo,
|
||||
OriginData: part.InlineData.Data,
|
||||
})
|
||||
} else {
|
||||
files = append(files, &types.FileMeta{
|
||||
FileType: types.FileTypeFile,
|
||||
Data: part.InlineData.Data,
|
||||
FileType: types.FileTypeFile,
|
||||
OriginData: part.InlineData.Data,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -73,6 +74,10 @@ func (r *GeminiChatRequest) IsStream(c *gin.Context) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *GeminiChatRequest) SetModelName(modelName string) {
|
||||
// GeminiChatRequest does not have a model field, so this method does nothing.
|
||||
}
|
||||
|
||||
func (r *GeminiChatRequest) GetTools() []GeminiChatTool {
|
||||
var tools []GeminiChatTool
|
||||
if strings.HasSuffix(string(r.Tools), "[") {
|
||||
@@ -264,14 +269,15 @@ type GeminiChatResponse struct {
|
||||
}
|
||||
|
||||
type GeminiUsageMetadata struct {
|
||||
PromptTokenCount int `json:"promptTokenCount"`
|
||||
CandidatesTokenCount int `json:"candidatesTokenCount"`
|
||||
TotalTokenCount int `json:"totalTokenCount"`
|
||||
ThoughtsTokenCount int `json:"thoughtsTokenCount"`
|
||||
PromptTokensDetails []GeminiPromptTokensDetails `json:"promptTokensDetails"`
|
||||
PromptTokenCount int `json:"promptTokenCount"`
|
||||
CandidatesTokenCount int `json:"candidatesTokenCount"`
|
||||
TotalTokenCount int `json:"totalTokenCount"`
|
||||
ThoughtsTokenCount int `json:"thoughtsTokenCount"`
|
||||
PromptTokensDetails []GeminiModalityTokenCount `json:"promptTokensDetails"`
|
||||
CandidatesTokensDetails []GeminiModalityTokenCount `json:"candidatesTokensDetails"`
|
||||
}
|
||||
|
||||
type GeminiPromptTokensDetails struct {
|
||||
type GeminiModalityTokenCount struct {
|
||||
Modality string `json:"modality"`
|
||||
TokenCount int `json:"tokenCount"`
|
||||
}
|
||||
@@ -312,10 +318,61 @@ type GeminiEmbeddingRequest struct {
|
||||
OutputDimensionality int `json:"outputDimensionality,omitempty"`
|
||||
}
|
||||
|
||||
func (r *GeminiEmbeddingRequest) IsStream(c *gin.Context) bool {
|
||||
// Gemini embedding requests are not streamed
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *GeminiEmbeddingRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
var inputTexts []string
|
||||
for _, part := range r.Content.Parts {
|
||||
if part.Text != "" {
|
||||
inputTexts = append(inputTexts, part.Text)
|
||||
}
|
||||
}
|
||||
inputText := strings.Join(inputTexts, "\n")
|
||||
return &types.TokenCountMeta{
|
||||
CombineText: inputText,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GeminiEmbeddingRequest) SetModelName(modelName string) {
|
||||
if modelName != "" {
|
||||
r.Model = modelName
|
||||
}
|
||||
}
|
||||
|
||||
type GeminiBatchEmbeddingRequest struct {
|
||||
Requests []*GeminiEmbeddingRequest `json:"requests"`
|
||||
}
|
||||
|
||||
func (r *GeminiBatchEmbeddingRequest) IsStream(c *gin.Context) bool {
|
||||
// Gemini batch embedding requests are not streamed
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *GeminiBatchEmbeddingRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
var inputTexts []string
|
||||
for _, request := range r.Requests {
|
||||
meta := request.GetTokenCountMeta()
|
||||
if meta != nil && meta.CombineText != "" {
|
||||
inputTexts = append(inputTexts, meta.CombineText)
|
||||
}
|
||||
}
|
||||
inputText := strings.Join(inputTexts, "\n")
|
||||
return &types.TokenCountMeta{
|
||||
CombineText: inputText,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GeminiBatchEmbeddingRequest) SetModelName(modelName string) {
|
||||
if modelName != "" {
|
||||
for _, req := range r.Requests {
|
||||
req.SetModelName(modelName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type GeminiEmbeddingResponse struct {
|
||||
Embedding ContentEmbedding `json:"embedding"`
|
||||
}
|
||||
|
||||
@@ -2,7 +2,9 @@ package dto
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"one-api/common"
|
||||
"one-api/types"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -12,10 +14,10 @@ type ImageRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt" binding:"required"`
|
||||
N uint `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Quality string `json:"quality,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
Style json.RawMessage `json:"style,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Quality string `json:"quality,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
Style json.RawMessage `json:"style,omitempty"`
|
||||
User json.RawMessage `json:"user,omitempty"`
|
||||
ExtraFields json.RawMessage `json:"extra_fields,omitempty"`
|
||||
Background json.RawMessage `json:"background,omitempty"`
|
||||
@@ -25,6 +27,70 @@ type ImageRequest struct {
|
||||
PartialImages json.RawMessage `json:"partial_images,omitempty"`
|
||||
// Stream bool `json:"stream,omitempty"`
|
||||
Watermark *bool `json:"watermark,omitempty"`
|
||||
// 用匿名参数接收额外参数
|
||||
Extra map[string]json.RawMessage `json:"-"`
|
||||
}
|
||||
|
||||
func (i *ImageRequest) UnmarshalJSON(data []byte) error {
|
||||
// 先解析成 map[string]interface{}
|
||||
var rawMap map[string]json.RawMessage
|
||||
if err := common.Unmarshal(data, &rawMap); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 用 struct tag 获取所有已定义字段名
|
||||
knownFields := GetJSONFieldNames(reflect.TypeOf(*i))
|
||||
|
||||
// 再正常解析已定义字段
|
||||
type Alias ImageRequest
|
||||
var known Alias
|
||||
if err := common.Unmarshal(data, &known); err != nil {
|
||||
return err
|
||||
}
|
||||
*i = ImageRequest(known)
|
||||
|
||||
// 提取多余字段
|
||||
i.Extra = make(map[string]json.RawMessage)
|
||||
for k, v := range rawMap {
|
||||
if _, ok := knownFields[k]; !ok {
|
||||
i.Extra[k] = v
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetJSONFieldNames(t reflect.Type) map[string]struct{} {
|
||||
fields := make(map[string]struct{})
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
|
||||
// 跳过匿名字段(例如 ExtraFields)
|
||||
if field.Anonymous {
|
||||
continue
|
||||
}
|
||||
|
||||
tag := field.Tag.Get("json")
|
||||
if tag == "-" || tag == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 取逗号前字段名(排除 omitempty 等)
|
||||
name := tag
|
||||
if commaIdx := indexComma(tag); commaIdx != -1 {
|
||||
name = tag[:commaIdx]
|
||||
}
|
||||
fields[name] = struct{}{}
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
func indexComma(s string) int {
|
||||
for i := 0; i < len(s); i++ {
|
||||
if s[i] == ',' {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func (i *ImageRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
@@ -63,9 +129,16 @@ func (i *ImageRequest) IsStream(c *gin.Context) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (i *ImageRequest) SetModelName(modelName string) {
|
||||
if modelName != "" {
|
||||
i.Model = modelName
|
||||
}
|
||||
}
|
||||
|
||||
type ImageResponse struct {
|
||||
Data []ImageData `json:"data"`
|
||||
Created int64 `json:"created"`
|
||||
Extra any `json:"extra,omitempty"`
|
||||
}
|
||||
type ImageData struct {
|
||||
Url string `json:"url"`
|
||||
|
||||
@@ -57,18 +57,24 @@ type GeneralOpenAIRequest struct {
|
||||
Dimensions int `json:"dimensions,omitempty"`
|
||||
Modalities json.RawMessage `json:"modalities,omitempty"`
|
||||
Audio json.RawMessage `json:"audio,omitempty"`
|
||||
EnableThinking any `json:"enable_thinking,omitempty"` // ali
|
||||
THINKING json.RawMessage `json:"thinking,omitempty"` // doubao,zhipu_v4
|
||||
ExtraBody json.RawMessage `json:"extra_body,omitempty"`
|
||||
SearchParameters any `json:"search_parameters,omitempty"` //xai
|
||||
WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"`
|
||||
// gemini
|
||||
ExtraBody json.RawMessage `json:"extra_body,omitempty"`
|
||||
//xai
|
||||
SearchParameters json.RawMessage `json:"search_parameters,omitempty"`
|
||||
// claude
|
||||
WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"`
|
||||
// OpenRouter Params
|
||||
Usage json.RawMessage `json:"usage,omitempty"`
|
||||
Reasoning json.RawMessage `json:"reasoning,omitempty"`
|
||||
// Ali Qwen Params
|
||||
VlHighResolutionImages json.RawMessage `json:"vl_high_resolution_images,omitempty"`
|
||||
// 用匿名参数接收额外参数,例如ollama的think参数在此接收
|
||||
Extra map[string]json.RawMessage `json:"-"`
|
||||
EnableThinking any `json:"enable_thinking,omitempty"`
|
||||
// ollama Params
|
||||
Think json.RawMessage `json:"think,omitempty"`
|
||||
// baidu v2
|
||||
WebSearch json.RawMessage `json:"web_search,omitempty"`
|
||||
// doubao,zhipu_v4
|
||||
THINKING json.RawMessage `json:"thinking,omitempty"`
|
||||
}
|
||||
|
||||
func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
@@ -115,12 +121,14 @@ func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
if m.Type == ContentTypeImageURL {
|
||||
imageUrl := m.GetImageMedia()
|
||||
if imageUrl != nil {
|
||||
meta := &types.FileMeta{
|
||||
FileType: types.FileTypeImage,
|
||||
if imageUrl.Url != "" {
|
||||
meta := &types.FileMeta{
|
||||
FileType: types.FileTypeImage,
|
||||
}
|
||||
meta.OriginData = imageUrl.Url
|
||||
meta.Detail = imageUrl.Detail
|
||||
fileMeta = append(fileMeta, meta)
|
||||
}
|
||||
meta.Data = imageUrl.Url
|
||||
meta.Detail = imageUrl.Detail
|
||||
fileMeta = append(fileMeta, meta)
|
||||
}
|
||||
} else if m.Type == ContentTypeInputAudio {
|
||||
inputAudio := m.GetInputAudio()
|
||||
@@ -128,7 +136,7 @@ func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
meta := &types.FileMeta{
|
||||
FileType: types.FileTypeAudio,
|
||||
}
|
||||
meta.Data = inputAudio.Data
|
||||
meta.OriginData = inputAudio.Data
|
||||
fileMeta = append(fileMeta, meta)
|
||||
}
|
||||
} else if m.Type == ContentTypeFile {
|
||||
@@ -137,16 +145,16 @@ func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
meta := &types.FileMeta{
|
||||
FileType: types.FileTypeFile,
|
||||
}
|
||||
meta.Data = file.FileData
|
||||
meta.OriginData = file.FileData
|
||||
fileMeta = append(fileMeta, meta)
|
||||
}
|
||||
} else if m.Type == ContentTypeVideoUrl {
|
||||
videoUrl := m.GetVideoUrl()
|
||||
if videoUrl != nil {
|
||||
if videoUrl != nil && videoUrl.Url != "" {
|
||||
meta := &types.FileMeta{
|
||||
FileType: types.FileTypeVideo,
|
||||
}
|
||||
meta.Data = videoUrl.Url
|
||||
meta.OriginData = videoUrl.Url
|
||||
fileMeta = append(fileMeta, meta)
|
||||
}
|
||||
} else {
|
||||
@@ -181,6 +189,12 @@ func (r *GeneralOpenAIRequest) IsStream(c *gin.Context) bool {
|
||||
return r.Stream
|
||||
}
|
||||
|
||||
func (r *GeneralOpenAIRequest) SetModelName(modelName string) {
|
||||
if modelName != "" {
|
||||
r.Model = modelName
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GeneralOpenAIRequest) ToMap() map[string]any {
|
||||
result := make(map[string]any)
|
||||
data, _ := common.Marshal(r)
|
||||
@@ -752,27 +766,27 @@ type WebSearchOptions struct {
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/responses/create
|
||||
type OpenAIResponsesRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input any `json:"input,omitempty"`
|
||||
Include json.RawMessage `json:"include,omitempty"`
|
||||
Instructions json.RawMessage `json:"instructions,omitempty"`
|
||||
MaxOutputTokens uint `json:"max_output_tokens,omitempty"`
|
||||
Metadata json.RawMessage `json:"metadata,omitempty"`
|
||||
ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"`
|
||||
PreviousResponseID string `json:"previous_response_id,omitempty"`
|
||||
Reasoning *Reasoning `json:"reasoning,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
Store bool `json:"store,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
Text json.RawMessage `json:"text,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
Tools []map[string]any `json:"tools,omitempty"` // 需要处理的参数很少,MCP 参数太多不确定,所以用 map
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
Truncation string `json:"truncation,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
MaxToolCalls uint `json:"max_tool_calls,omitempty"`
|
||||
Prompt json.RawMessage `json:"prompt,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Input json.RawMessage `json:"input,omitempty"`
|
||||
Include json.RawMessage `json:"include,omitempty"`
|
||||
Instructions json.RawMessage `json:"instructions,omitempty"`
|
||||
MaxOutputTokens uint `json:"max_output_tokens,omitempty"`
|
||||
Metadata json.RawMessage `json:"metadata,omitempty"`
|
||||
ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"`
|
||||
PreviousResponseID string `json:"previous_response_id,omitempty"`
|
||||
Reasoning *Reasoning `json:"reasoning,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
Store bool `json:"store,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
Text json.RawMessage `json:"text,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
Tools json.RawMessage `json:"tools,omitempty"` // 需要处理的参数很少,MCP 参数太多不确定,所以用 map
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
Truncation string `json:"truncation,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
MaxToolCalls uint `json:"max_tool_calls,omitempty"`
|
||||
Prompt json.RawMessage `json:"prompt,omitempty"`
|
||||
}
|
||||
|
||||
func (r *OpenAIResponsesRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
@@ -783,16 +797,20 @@ func (r *OpenAIResponsesRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
inputs := r.ParseInput()
|
||||
for _, input := range inputs {
|
||||
if input.Type == "input_image" {
|
||||
fileMeta = append(fileMeta, &types.FileMeta{
|
||||
FileType: types.FileTypeImage,
|
||||
Data: input.ImageUrl,
|
||||
Detail: input.Detail,
|
||||
})
|
||||
if input.ImageUrl != "" {
|
||||
fileMeta = append(fileMeta, &types.FileMeta{
|
||||
FileType: types.FileTypeImage,
|
||||
OriginData: input.ImageUrl,
|
||||
Detail: input.Detail,
|
||||
})
|
||||
}
|
||||
} else if input.Type == "input_file" {
|
||||
fileMeta = append(fileMeta, &types.FileMeta{
|
||||
FileType: types.FileTypeFile,
|
||||
Data: input.FileUrl,
|
||||
})
|
||||
if input.FileUrl != "" {
|
||||
fileMeta = append(fileMeta, &types.FileMeta{
|
||||
FileType: types.FileTypeFile,
|
||||
OriginData: input.FileUrl,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
texts = append(texts, input.Text)
|
||||
}
|
||||
@@ -820,8 +838,7 @@ func (r *OpenAIResponsesRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
}
|
||||
|
||||
if len(r.Tools) > 0 {
|
||||
toolStr, _ := common.Marshal(r.Tools)
|
||||
texts = append(texts, string(toolStr))
|
||||
texts = append(texts, string(r.Tools))
|
||||
}
|
||||
|
||||
return &types.TokenCountMeta{
|
||||
@@ -835,6 +852,20 @@ func (r *OpenAIResponsesRequest) IsStream(c *gin.Context) bool {
|
||||
return r.Stream
|
||||
}
|
||||
|
||||
func (r *OpenAIResponsesRequest) SetModelName(modelName string) {
|
||||
if modelName != "" {
|
||||
r.Model = modelName
|
||||
}
|
||||
}
|
||||
|
||||
func (r *OpenAIResponsesRequest) GetToolsMap() []map[string]any {
|
||||
var toolsMap []map[string]any
|
||||
if len(r.Tools) > 0 {
|
||||
_ = common.Unmarshal(r.Tools, &toolsMap)
|
||||
}
|
||||
return toolsMap
|
||||
}
|
||||
|
||||
type Reasoning struct {
|
||||
Effort string `json:"effort,omitempty"`
|
||||
Summary string `json:"summary,omitempty"`
|
||||
@@ -861,13 +892,21 @@ func (r *OpenAIResponsesRequest) ParseInput() []MediaInput {
|
||||
var inputs []MediaInput
|
||||
|
||||
// Try string first
|
||||
if str, ok := r.Input.(string); ok {
|
||||
// if str, ok := common.GetJsonType(r.Input); ok {
|
||||
// inputs = append(inputs, MediaInput{Type: "input_text", Text: str})
|
||||
// return inputs
|
||||
// }
|
||||
if common.GetJsonType(r.Input) == "string" {
|
||||
var str string
|
||||
_ = common.Unmarshal(r.Input, &str)
|
||||
inputs = append(inputs, MediaInput{Type: "input_text", Text: str})
|
||||
return inputs
|
||||
}
|
||||
|
||||
// Try array of parts
|
||||
if array, ok := r.Input.([]any); ok {
|
||||
if common.GetJsonType(r.Input) == "array" {
|
||||
var array []any
|
||||
_ = common.Unmarshal(r.Input, &array)
|
||||
for _, itemAny := range array {
|
||||
// Already parsed MediaInput
|
||||
if media, ok := itemAny.(MediaInput); ok {
|
||||
|
||||
@@ -110,7 +110,7 @@ func (c *ChatCompletionsStreamResponseChoiceDelta) GetReasoningContent() string
|
||||
|
||||
func (c *ChatCompletionsStreamResponseChoiceDelta) SetReasoningContent(s string) {
|
||||
c.ReasoningContent = &s
|
||||
c.Reasoning = &s
|
||||
//c.Reasoning = &s
|
||||
}
|
||||
|
||||
type ToolCallResponse struct {
|
||||
|
||||
@@ -1,23 +1,23 @@
|
||||
package dto
|
||||
|
||||
type UpstreamDTO struct {
|
||||
ID int `json:"id,omitempty"`
|
||||
Name string `json:"name" binding:"required"`
|
||||
BaseURL string `json:"base_url" binding:"required"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
ID int `json:"id,omitempty"`
|
||||
Name string `json:"name" binding:"required"`
|
||||
BaseURL string `json:"base_url" binding:"required"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
}
|
||||
|
||||
type UpstreamRequest struct {
|
||||
ChannelIDs []int64 `json:"channel_ids"`
|
||||
Upstreams []UpstreamDTO `json:"upstreams"`
|
||||
Timeout int `json:"timeout"`
|
||||
ChannelIDs []int64 `json:"channel_ids"`
|
||||
Upstreams []UpstreamDTO `json:"upstreams"`
|
||||
Timeout int `json:"timeout"`
|
||||
}
|
||||
|
||||
// TestResult 上游测试连通性结果
|
||||
type TestResult struct {
|
||||
Name string `json:"name"`
|
||||
Status string `json:"status"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Status string `json:"status"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// DifferenceItem 差异项
|
||||
@@ -25,14 +25,14 @@ type TestResult struct {
|
||||
// Upstreams 为各渠道的上游值,具体数值 / "same" / nil
|
||||
|
||||
type DifferenceItem struct {
|
||||
Current interface{} `json:"current"`
|
||||
Upstreams map[string]interface{} `json:"upstreams"`
|
||||
Confidence map[string]bool `json:"confidence"`
|
||||
Current interface{} `json:"current"`
|
||||
Upstreams map[string]interface{} `json:"upstreams"`
|
||||
Confidence map[string]bool `json:"confidence"`
|
||||
}
|
||||
|
||||
type SyncableChannel struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
BaseURL string `json:"base_url"`
|
||||
Status int `json:"status"`
|
||||
}
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
BaseURL string `json:"base_url"`
|
||||
Status int `json:"status"`
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
type Request interface {
|
||||
GetTokenCountMeta() *types.TokenCountMeta
|
||||
IsStream(c *gin.Context) bool
|
||||
SetModelName(modelName string)
|
||||
}
|
||||
|
||||
type BaseRequest struct {
|
||||
@@ -18,7 +19,7 @@ func (b *BaseRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
TokenType: types.TokenTypeTokenizer,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BaseRequest) IsStream(c *gin.Context) bool {
|
||||
return false
|
||||
}
|
||||
func (b *BaseRequest) SetModelName(modelName string) {}
|
||||
|
||||
@@ -37,6 +37,12 @@ func (r *RerankRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RerankRequest) SetModelName(modelName string) {
|
||||
if modelName != "" {
|
||||
r.Model = modelName
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RerankRequest) GetReturnDocuments() bool {
|
||||
if r.ReturnDocuments == nil {
|
||||
return false
|
||||
|
||||
@@ -6,11 +6,14 @@ type UserSetting struct {
|
||||
WebhookUrl string `json:"webhook_url,omitempty"` // WebhookUrl webhook地址
|
||||
WebhookSecret string `json:"webhook_secret,omitempty"` // WebhookSecret webhook密钥
|
||||
NotificationEmail string `json:"notification_email,omitempty"` // NotificationEmail 通知邮箱地址
|
||||
BarkUrl string `json:"bark_url,omitempty"` // BarkUrl Bark推送URL
|
||||
AcceptUnsetRatioModel bool `json:"accept_unset_model_ratio_model,omitempty"` // AcceptUnsetRatioModel 是否接受未设置价格的模型
|
||||
RecordIpLog bool `json:"record_ip_log,omitempty"` // 是否记录请求和错误日志IP
|
||||
SidebarModules string `json:"sidebar_modules,omitempty"` // SidebarModules 左侧边栏模块配置
|
||||
}
|
||||
|
||||
var (
|
||||
NotifyTypeEmail = "email" // Email 邮件
|
||||
NotifyTypeWebhook = "webhook" // Webhook
|
||||
NotifyTypeBark = "bark" // Bark 推送
|
||||
)
|
||||
|
||||
@@ -23,6 +23,7 @@ require (
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/websocket v1.5.0
|
||||
github.com/jinzhu/copier v0.4.0
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/pquerna/otp v1.5.0
|
||||
@@ -31,6 +32,8 @@ require (
|
||||
github.com/shopspring/decimal v1.4.0
|
||||
github.com/stripe/stripe-go/v81 v81.4.0
|
||||
github.com/thanhpk/randstr v1.0.6
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
github.com/tiktoken-go/tokenizer v0.6.2
|
||||
golang.org/x/crypto v0.35.0
|
||||
golang.org/x/image v0.23.0
|
||||
@@ -82,6 +85,8 @@ require (
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.1 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/tklauser/go-sysconf v0.3.12 // indirect
|
||||
github.com/tklauser/numcpus v0.6.1 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
|
||||
@@ -120,6 +120,8 @@ github.com/jackc/pgx/v5 v5.7.1 h1:x7SYsPBYDkHDksogeSmZZ5xzThcTgRz++I5E+ePFUcs=
|
||||
github.com/jackc/pgx/v5 v5.7.1/go.mod h1:e7O26IywZZ+naJtWWos6i6fvWK+29etgITqrqHLfoZA=
|
||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8=
|
||||
github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
@@ -204,6 +206,15 @@ github.com/stripe/stripe-go/v81 v81.4.0 h1:AuD9XzdAvl193qUCSaLocf8H+nRopOouXhxqJ
|
||||
github.com/stripe/stripe-go/v81 v81.4.0/go.mod h1:C/F4jlmnGNacvYtBp/LUHCvVUJEZffFQCobkzwY1WOo=
|
||||
github.com/thanhpk/randstr v1.0.6 h1:psAOktJFD4vV9NEVb3qkhRSMvYh4ORRaj1+w/hn4B+o=
|
||||
github.com/thanhpk/randstr v1.0.6/go.mod h1:M/H2P1eNLZzlDwAzpkkkUvoyNNMbzRGhESZuEQk3r0U=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/tiktoken-go/tokenizer v0.6.2 h1:t0GN2DvcUZSFWT/62YOgoqb10y7gSXBGs0A+4VCQK+g=
|
||||
github.com/tiktoken-go/tokenizer v0.6.2/go.mod h1:6UCYI/DtOallbmL7sSy30p6YQv60qNyU/4aVigPOx6w=
|
||||
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
|
||||
|
||||
@@ -94,13 +94,9 @@ func main() {
|
||||
}
|
||||
go controller.AutomaticallyUpdateChannels(frequency)
|
||||
}
|
||||
if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" {
|
||||
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY"))
|
||||
if err != nil {
|
||||
common.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error())
|
||||
}
|
||||
go controller.AutomaticallyTestChannels(frequency)
|
||||
}
|
||||
|
||||
go controller.AutomaticallyTestChannels()
|
||||
|
||||
if common.IsMasterNode && constant.UpdateTask {
|
||||
gopool.Go(func() {
|
||||
controller.UpdateMidjourneyTaskBulk()
|
||||
|
||||
@@ -192,12 +192,9 @@ func TokenAuth() func(c *gin.Context) {
|
||||
}
|
||||
c.Request.Header.Set("Authorization", "Bearer "+key)
|
||||
}
|
||||
anthropicKey := c.Request.Header.Get("x-api-key")
|
||||
// 检查path包含/v1/messages
|
||||
// 或者是否 x-api-key 不为空且存在anthropic-version
|
||||
// 谁知道有多少不符合规范没写anthropic-version的
|
||||
// 所以就这样随它去吧(
|
||||
if strings.Contains(c.Request.URL.Path, "/v1/messages") {
|
||||
anthropicKey := c.Request.Header.Get("x-api-key")
|
||||
if anthropicKey != "" {
|
||||
c.Request.Header.Set("Authorization", "Bearer "+anthropicKey)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
package middleware
|
||||
|
||||
import "github.com/gin-gonic/gin"
|
||||
|
||||
func DisableCache() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Header("Cache-Control", "no-store, no-cache, must-revalidate, private, max-age=0")
|
||||
c.Header("Pragma", "no-cache")
|
||||
c.Header("Expires", "0")
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
@@ -185,7 +185,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||
modelRequest.Model = modelName
|
||||
}
|
||||
c.Set("relay_mode", relayMode)
|
||||
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") && !strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
|
||||
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") && !strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
|
||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
}
|
||||
if err != nil {
|
||||
@@ -208,7 +208,10 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
||||
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e")
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
|
||||
modelRequest.Model = common.GetStringIfEmpty(c.PostForm("model"), "gpt-image-1")
|
||||
//modelRequest.Model = common.GetStringIfEmpty(c.PostForm("model"), "gpt-image-1")
|
||||
if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
|
||||
modelRequest.Model = c.PostForm("model")
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
|
||||
relayMode := relayconstant.RelayModeAudioSpeech
|
||||
@@ -248,6 +251,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
|
||||
common.SetContextKey(c, constant.ContextKeyChannelSetting, channel.GetSetting())
|
||||
common.SetContextKey(c, constant.ContextKeyChannelOtherSetting, channel.GetOtherSettings())
|
||||
common.SetContextKey(c, constant.ContextKeyChannelParamOverride, channel.GetParamOverride())
|
||||
common.SetContextKey(c, constant.ContextKeyChannelHeaderOverride, channel.GetHeaderOverride())
|
||||
if nil != channel.OpenAIOrganization && *channel.OpenAIOrganization != "" {
|
||||
common.SetContextKey(c, constant.ContextKeyChannelOrganization, *channel.OpenAIOrganization)
|
||||
}
|
||||
|
||||
@@ -18,12 +18,12 @@ func StatsMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 增加活跃连接数
|
||||
atomic.AddInt64(&globalStats.activeConnections, 1)
|
||||
|
||||
|
||||
// 确保在请求结束时减少连接数
|
||||
defer func() {
|
||||
atomic.AddInt64(&globalStats.activeConnections, -1)
|
||||
}()
|
||||
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
@@ -38,4 +38,4 @@ func GetStats() StatsInfo {
|
||||
return StatsInfo{
|
||||
ActiveConnections: atomic.LoadInt64(&globalStats.activeConnections),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,6 +46,8 @@ type Channel struct {
|
||||
Tag *string `json:"tag" gorm:"index"`
|
||||
Setting *string `json:"setting" gorm:"type:text"` // 渠道额外设置
|
||||
ParamOverride *string `json:"param_override" gorm:"type:text"`
|
||||
HeaderOverride *string `json:"header_override" gorm:"type:text"`
|
||||
Remark string `json:"remark,omitempty" gorm:"type:varchar(255)" validate:"max=255"`
|
||||
// add after v0.8.5
|
||||
ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"`
|
||||
|
||||
@@ -111,6 +113,10 @@ func (channel *Channel) GetNextEnabledKey() (string, int, *types.NewAPIError) {
|
||||
return "", 0, types.NewError(errors.New("no keys available"), types.ErrorCodeChannelNoAvailableKey)
|
||||
}
|
||||
|
||||
lock := GetChannelPollingLock(channel.Id)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
statusList := channel.ChannelInfo.MultiKeyStatusList
|
||||
// helper to get key status, default to enabled when missing
|
||||
getStatus := func(idx int) int {
|
||||
@@ -142,9 +148,6 @@ func (channel *Channel) GetNextEnabledKey() (string, int, *types.NewAPIError) {
|
||||
return keys[selectedIdx], selectedIdx, nil
|
||||
case constant.MultiKeyModePolling:
|
||||
// Use channel-specific lock to ensure thread-safe polling
|
||||
lock := GetChannelPollingLock(channel.Id)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
channelInfo, err := CacheGetChannelInfo(channel.Id)
|
||||
if err != nil {
|
||||
@@ -246,6 +249,10 @@ func (channel *Channel) Save() error {
|
||||
return DB.Save(channel).Error
|
||||
}
|
||||
|
||||
func (channel *Channel) SaveWithoutKey() error {
|
||||
return DB.Omit("key").Save(channel).Error
|
||||
}
|
||||
|
||||
func GetAllChannels(startIdx int, num int, selectAll bool, idSort bool) ([]*Channel, error) {
|
||||
var channels []*Channel
|
||||
var err error
|
||||
@@ -600,8 +607,12 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
|
||||
return false
|
||||
}
|
||||
if channelCache.ChannelInfo.IsMultiKey {
|
||||
// Use per-channel lock to prevent concurrent map read/write with GetNextEnabledKey
|
||||
pollingLock := GetChannelPollingLock(channelId)
|
||||
pollingLock.Lock()
|
||||
// 如果是多Key模式,更新缓存中的状态
|
||||
handlerMultiKeyUpdate(channelCache, usingKey, status, reason)
|
||||
pollingLock.Unlock()
|
||||
//CacheUpdateChannel(channelCache)
|
||||
//return true
|
||||
} else {
|
||||
@@ -632,7 +643,11 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
|
||||
|
||||
if channel.ChannelInfo.IsMultiKey {
|
||||
beforeStatus := channel.Status
|
||||
// Protect map writes with the same per-channel lock used by readers
|
||||
pollingLock := GetChannelPollingLock(channelId)
|
||||
pollingLock.Lock()
|
||||
handlerMultiKeyUpdate(channel, usingKey, status, reason)
|
||||
pollingLock.Unlock()
|
||||
if beforeStatus != channel.Status {
|
||||
shouldUpdateAbilities = true
|
||||
}
|
||||
@@ -644,7 +659,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
|
||||
channel.Status = status
|
||||
shouldUpdateAbilities = true
|
||||
}
|
||||
err = channel.Save()
|
||||
err = channel.SaveWithoutKey()
|
||||
if err != nil {
|
||||
common.SysLog(fmt.Sprintf("failed to update channel status: channel_id=%d, status=%d, error=%v", channel.Id, status, err))
|
||||
return false
|
||||
@@ -875,6 +890,17 @@ func (channel *Channel) GetParamOverride() map[string]interface{} {
|
||||
return paramOverride
|
||||
}
|
||||
|
||||
func (channel *Channel) GetHeaderOverride() map[string]interface{} {
|
||||
headerOverride := make(map[string]interface{})
|
||||
if channel.HeaderOverride != nil && *channel.HeaderOverride != "" {
|
||||
err := common.Unmarshal([]byte(*channel.HeaderOverride), &headerOverride)
|
||||
if err != nil {
|
||||
common.SysLog(fmt.Sprintf("failed to unmarshal header override: channel_id=%d, error=%v", channel.Id, err))
|
||||
}
|
||||
}
|
||||
return headerOverride
|
||||
}
|
||||
|
||||
func GetChannelsByIds(ids []int) ([]*Channel, error) {
|
||||
var channels []*Channel
|
||||
err := DB.Where("id in (?)", ids).Find(&channels).Error
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/logger"
|
||||
"one-api/types"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -150,10 +151,10 @@ type RecordConsumeLogParams struct {
|
||||
}
|
||||
|
||||
func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams) {
|
||||
logger.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params)))
|
||||
if !common.LogConsumeEnabled {
|
||||
return
|
||||
}
|
||||
logger.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params)))
|
||||
username := c.GetString("username")
|
||||
otherStr := common.MapToJsonStr(params.Other)
|
||||
// 判断是否需要记录 IP
|
||||
@@ -236,26 +237,22 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
channelIdsMap := make(map[int]struct{})
|
||||
channelMap := make(map[int]string)
|
||||
channelIds := types.NewSet[int]()
|
||||
for _, log := range logs {
|
||||
if log.ChannelId != 0 {
|
||||
channelIdsMap[log.ChannelId] = struct{}{}
|
||||
channelIds.Add(log.ChannelId)
|
||||
}
|
||||
}
|
||||
|
||||
channelIds := make([]int, 0, len(channelIdsMap))
|
||||
for channelId := range channelIdsMap {
|
||||
channelIds = append(channelIds, channelId)
|
||||
}
|
||||
if len(channelIds) > 0 {
|
||||
if channelIds.Len() > 0 {
|
||||
var channels []struct {
|
||||
Id int `gorm:"column:id"`
|
||||
Name string `gorm:"column:name"`
|
||||
}
|
||||
if err = DB.Table("channels").Select("id, name").Where("id IN ?", channelIds).Find(&channels).Error; err != nil {
|
||||
if err = DB.Table("channels").Select("id, name").Where("id IN ?", channelIds.Items()).Find(&channels).Error; err != nil {
|
||||
return logs, total, err
|
||||
}
|
||||
channelMap := make(map[int]string, len(channels))
|
||||
for _, channel := range channels {
|
||||
channelMap[channel.Id] = channel.Name
|
||||
}
|
||||
|
||||
@@ -64,22 +64,6 @@ var DB *gorm.DB
|
||||
|
||||
var LOG_DB *gorm.DB
|
||||
|
||||
// dropIndexIfExists drops a MySQL index only if it exists to avoid noisy 1091 errors
|
||||
func dropIndexIfExists(tableName string, indexName string) {
|
||||
if !common.UsingMySQL {
|
||||
return
|
||||
}
|
||||
var count int64
|
||||
// Check index existence via information_schema
|
||||
err := DB.Raw(
|
||||
"SELECT COUNT(1) FROM information_schema.statistics WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?",
|
||||
tableName, indexName,
|
||||
).Scan(&count).Error
|
||||
if err == nil && count > 0 {
|
||||
_ = DB.Exec("ALTER TABLE " + tableName + " DROP INDEX " + indexName + ";").Error
|
||||
}
|
||||
}
|
||||
|
||||
func createRootAccountIfNeed() error {
|
||||
var user User
|
||||
//if user.Status != common.UserStatusEnabled {
|
||||
@@ -263,16 +247,6 @@ func InitLogDB() (err error) {
|
||||
}
|
||||
|
||||
func migrateDB() error {
|
||||
// 修复旧版本留下的唯一索引,允许软删除后重新插入同名记录
|
||||
// 删除单列唯一索引(列级 UNIQUE)及早期命名方式,防止与新复合唯一索引 (model_name, deleted_at) 冲突
|
||||
dropIndexIfExists("models", "uk_model_name") // 新版复合索引名称(若已存在)
|
||||
dropIndexIfExists("models", "model_name") // 旧版列级唯一索引名称
|
||||
|
||||
dropIndexIfExists("vendors", "uk_vendor_name") // 新版复合索引名称(若已存在)
|
||||
dropIndexIfExists("vendors", "name") // 旧版列级唯一索引名称
|
||||
//if !common.UsingPostgreSQL {
|
||||
// return migrateDBFast()
|
||||
//}
|
||||
err := DB.AutoMigrate(
|
||||
&Channel{},
|
||||
&Token{},
|
||||
@@ -299,13 +273,6 @@ func migrateDB() error {
|
||||
}
|
||||
|
||||
func migrateDBFast() error {
|
||||
// 修复旧版本留下的唯一索引,允许软删除后重新插入同名记录
|
||||
// 删除单列唯一索引(列级 UNIQUE)及早期命名方式,防止与新复合唯一索引冲突
|
||||
dropIndexIfExists("models", "uk_model_name")
|
||||
dropIndexIfExists("models", "model_name")
|
||||
|
||||
dropIndexIfExists("vendors", "uk_vendor_name")
|
||||
dropIndexIfExists("vendors", "name")
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
|
||||
@@ -20,17 +20,18 @@ type BoundChannel struct {
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
Id int `json:"id"`
|
||||
ModelName string `json:"model_name" gorm:"size:128;not null;uniqueIndex:uk_model_name,priority:1"`
|
||||
Description string `json:"description,omitempty" gorm:"type:text"`
|
||||
Icon string `json:"icon,omitempty" gorm:"type:varchar(128)"`
|
||||
Tags string `json:"tags,omitempty" gorm:"type:varchar(255)"`
|
||||
VendorID int `json:"vendor_id,omitempty" gorm:"index"`
|
||||
Endpoints string `json:"endpoints,omitempty" gorm:"type:text"`
|
||||
Status int `json:"status" gorm:"default:1"`
|
||||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||
UpdatedTime int64 `json:"updated_time" gorm:"bigint"`
|
||||
DeletedAt gorm.DeletedAt `json:"-" gorm:"index;uniqueIndex:uk_model_name,priority:2"`
|
||||
Id int `json:"id"`
|
||||
ModelName string `json:"model_name" gorm:"size:128;not null;uniqueIndex:uk_model_name_delete_at,priority:1"`
|
||||
Description string `json:"description,omitempty" gorm:"type:text"`
|
||||
Icon string `json:"icon,omitempty" gorm:"type:varchar(128)"`
|
||||
Tags string `json:"tags,omitempty" gorm:"type:varchar(255)"`
|
||||
VendorID int `json:"vendor_id,omitempty" gorm:"index"`
|
||||
Endpoints string `json:"endpoints,omitempty" gorm:"type:text"`
|
||||
Status int `json:"status" gorm:"default:1"`
|
||||
SyncOfficial int `json:"sync_official" gorm:"default:1"`
|
||||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||
UpdatedTime int64 `json:"updated_time" gorm:"bigint"`
|
||||
DeletedAt gorm.DeletedAt `json:"-" gorm:"index;uniqueIndex:uk_model_name_delete_at,priority:2"`
|
||||
|
||||
BoundChannels []BoundChannel `json:"bound_channels,omitempty" gorm:"-"`
|
||||
EnableGroups []string `json:"enable_groups,omitempty" gorm:"-"`
|
||||
|
||||
@@ -155,9 +155,12 @@ func updatePricing() {
|
||||
vendorMap[vendors[i].Id] = &vendors[i]
|
||||
}
|
||||
|
||||
// 初始化默认供应商映射
|
||||
initDefaultVendorMapping(metaMap, vendorMap, enableAbilities)
|
||||
|
||||
// 构建对前端友好的供应商列表
|
||||
vendorsList = make([]PricingVendor, 0, len(vendors))
|
||||
for _, v := range vendors {
|
||||
vendorsList = make([]PricingVendor, 0, len(vendorMap))
|
||||
for _, v := range vendorMap {
|
||||
vendorsList = append(vendorsList, PricingVendor{
|
||||
ID: v.Id,
|
||||
Name: v.Name,
|
||||
|
||||
@@ -0,0 +1,128 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// 简化的供应商映射规则
|
||||
var defaultVendorRules = map[string]string{
|
||||
"gpt": "OpenAI",
|
||||
"dall-e": "OpenAI",
|
||||
"whisper": "OpenAI",
|
||||
"o1": "OpenAI",
|
||||
"o3": "OpenAI",
|
||||
"claude": "Anthropic",
|
||||
"gemini": "Google",
|
||||
"moonshot": "Moonshot",
|
||||
"kimi": "Moonshot",
|
||||
"chatglm": "智谱",
|
||||
"glm-": "智谱",
|
||||
"qwen": "阿里巴巴",
|
||||
"deepseek": "DeepSeek",
|
||||
"abab": "MiniMax",
|
||||
"ernie": "百度",
|
||||
"spark": "讯飞",
|
||||
"hunyuan": "腾讯",
|
||||
"command": "Cohere",
|
||||
"@cf/": "Cloudflare",
|
||||
"360": "360",
|
||||
"yi": "零一万物",
|
||||
"jina": "Jina",
|
||||
"mistral": "Mistral",
|
||||
"grok": "xAI",
|
||||
"llama": "Meta",
|
||||
"doubao": "字节跳动",
|
||||
"kling": "快手",
|
||||
"jimeng": "即梦",
|
||||
"vidu": "Vidu",
|
||||
}
|
||||
|
||||
// 供应商默认图标映射
|
||||
var defaultVendorIcons = map[string]string{
|
||||
"OpenAI": "OpenAI",
|
||||
"Anthropic": "Claude.Color",
|
||||
"Google": "Gemini.Color",
|
||||
"Moonshot": "Moonshot",
|
||||
"智谱": "Zhipu.Color",
|
||||
"阿里巴巴": "Qwen.Color",
|
||||
"DeepSeek": "DeepSeek.Color",
|
||||
"MiniMax": "Minimax.Color",
|
||||
"百度": "Wenxin.Color",
|
||||
"讯飞": "Spark.Color",
|
||||
"腾讯": "Hunyuan.Color",
|
||||
"Cohere": "Cohere.Color",
|
||||
"Cloudflare": "Cloudflare.Color",
|
||||
"360": "Ai360.Color",
|
||||
"零一万物": "Yi.Color",
|
||||
"Jina": "Jina",
|
||||
"Mistral": "Mistral.Color",
|
||||
"xAI": "XAI",
|
||||
"Meta": "Ollama",
|
||||
"字节跳动": "Doubao.Color",
|
||||
"快手": "Kling.Color",
|
||||
"即梦": "Jimeng.Color",
|
||||
"Vidu": "Vidu",
|
||||
"微软": "AzureAI",
|
||||
"Microsoft": "AzureAI",
|
||||
"Azure": "AzureAI",
|
||||
}
|
||||
|
||||
// initDefaultVendorMapping 简化的默认供应商映射
|
||||
func initDefaultVendorMapping(metaMap map[string]*Model, vendorMap map[int]*Vendor, enableAbilities []AbilityWithChannel) {
|
||||
for _, ability := range enableAbilities {
|
||||
modelName := ability.Model
|
||||
if _, exists := metaMap[modelName]; exists {
|
||||
continue
|
||||
}
|
||||
|
||||
// 匹配供应商
|
||||
vendorID := 0
|
||||
modelLower := strings.ToLower(modelName)
|
||||
for pattern, vendorName := range defaultVendorRules {
|
||||
if strings.Contains(modelLower, pattern) {
|
||||
vendorID = getOrCreateVendor(vendorName, vendorMap)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 创建模型元数据
|
||||
metaMap[modelName] = &Model{
|
||||
ModelName: modelName,
|
||||
VendorID: vendorID,
|
||||
Status: 1,
|
||||
NameRule: NameRuleExact,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 查找或创建供应商
|
||||
func getOrCreateVendor(vendorName string, vendorMap map[int]*Vendor) int {
|
||||
// 查找现有供应商
|
||||
for id, vendor := range vendorMap {
|
||||
if vendor.Name == vendorName {
|
||||
return id
|
||||
}
|
||||
}
|
||||
|
||||
// 创建新供应商
|
||||
newVendor := &Vendor{
|
||||
Name: vendorName,
|
||||
Status: 1,
|
||||
Icon: getDefaultVendorIcon(vendorName),
|
||||
}
|
||||
|
||||
if err := newVendor.Insert(); err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
vendorMap[newVendor.Id] = newVendor
|
||||
return newVendor.Id
|
||||
}
|
||||
|
||||
// 获取供应商默认图标
|
||||
func getDefaultVendorIcon(vendorName string) string {
|
||||
if icon, exists := defaultVendorIcons[vendorName]; exists {
|
||||
return icon
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -77,7 +77,7 @@ type SyncTaskQueryParams struct {
|
||||
UserIDs []int
|
||||
}
|
||||
|
||||
func InitTask(platform constant.TaskPlatform, relayInfo *commonRelay.TaskRelayInfo) *Task {
|
||||
func InitTask(platform constant.TaskPlatform, relayInfo *commonRelay.RelayInfo) *Task {
|
||||
t := &Task{
|
||||
UserId: relayInfo.UserId,
|
||||
SubmitTime: time.Now().Unix(),
|
||||
|
||||
@@ -16,7 +16,7 @@ type TwoFA struct {
|
||||
Id int `json:"id" gorm:"primaryKey"`
|
||||
UserId int `json:"user_id" gorm:"unique;not null;index"`
|
||||
Secret string `json:"-" gorm:"type:varchar(255);not null"` // TOTP密钥,不返回给前端
|
||||
IsEnabled bool `json:"is_enabled" gorm:"default:false"`
|
||||
IsEnabled bool `json:"is_enabled"`
|
||||
FailedAttempts int `json:"failed_attempts" gorm:"default:0"`
|
||||
LockedUntil *time.Time `json:"locked_until,omitempty"`
|
||||
LastUsedAt *time.Time `json:"last_used_at,omitempty"`
|
||||
@@ -30,7 +30,7 @@ type TwoFABackupCode struct {
|
||||
Id int `json:"id" gorm:"primaryKey"`
|
||||
UserId int `json:"user_id" gorm:"not null;index"`
|
||||
CodeHash string `json:"-" gorm:"type:varchar(255);not null"` // 备用码哈希
|
||||
IsUsed bool `json:"is_used" gorm:"default:false"`
|
||||
IsUsed bool `json:"is_used"`
|
||||
UsedAt *time.Time `json:"used_at,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
|
||||
|
||||
@@ -21,12 +21,6 @@ type QuotaData struct {
|
||||
}
|
||||
|
||||
func UpdateQuotaData() {
|
||||
// recover
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
common.SysLog(fmt.Sprintf("UpdateQuotaData panic: %s", r))
|
||||
}
|
||||
}()
|
||||
for {
|
||||
if common.DataExportEnabled {
|
||||
common.SysLog("正在更新数据看板数据...")
|
||||
|
||||
@@ -91,6 +91,68 @@ func (user *User) SetSetting(setting dto.UserSetting) {
|
||||
user.Setting = string(settingBytes)
|
||||
}
|
||||
|
||||
// 根据用户角色生成默认的边栏配置
|
||||
func generateDefaultSidebarConfigForRole(userRole int) string {
|
||||
defaultConfig := map[string]interface{}{}
|
||||
|
||||
// 聊天区域 - 所有用户都可以访问
|
||||
defaultConfig["chat"] = map[string]interface{}{
|
||||
"enabled": true,
|
||||
"playground": true,
|
||||
"chat": true,
|
||||
}
|
||||
|
||||
// 控制台区域 - 所有用户都可以访问
|
||||
defaultConfig["console"] = map[string]interface{}{
|
||||
"enabled": true,
|
||||
"detail": true,
|
||||
"token": true,
|
||||
"log": true,
|
||||
"midjourney": true,
|
||||
"task": true,
|
||||
}
|
||||
|
||||
// 个人中心区域 - 所有用户都可以访问
|
||||
defaultConfig["personal"] = map[string]interface{}{
|
||||
"enabled": true,
|
||||
"topup": true,
|
||||
"personal": true,
|
||||
}
|
||||
|
||||
// 管理员区域 - 根据角色决定
|
||||
if userRole == common.RoleAdminUser {
|
||||
// 管理员可以访问管理员区域,但不能访问系统设置
|
||||
defaultConfig["admin"] = map[string]interface{}{
|
||||
"enabled": true,
|
||||
"channel": true,
|
||||
"models": true,
|
||||
"redemption": true,
|
||||
"user": true,
|
||||
"setting": false, // 管理员不能访问系统设置
|
||||
}
|
||||
} else if userRole == common.RoleRootUser {
|
||||
// 超级管理员可以访问所有功能
|
||||
defaultConfig["admin"] = map[string]interface{}{
|
||||
"enabled": true,
|
||||
"channel": true,
|
||||
"models": true,
|
||||
"redemption": true,
|
||||
"user": true,
|
||||
"setting": true,
|
||||
}
|
||||
}
|
||||
// 普通用户不包含admin区域
|
||||
|
||||
// 转换为JSON字符串
|
||||
configBytes, err := json.Marshal(defaultConfig)
|
||||
if err != nil {
|
||||
common.SysLog("生成默认边栏配置失败: " + err.Error())
|
||||
return ""
|
||||
}
|
||||
|
||||
return string(configBytes)
|
||||
}
|
||||
|
||||
// CheckUserExistOrDeleted check if user exist or deleted, if not exist, return false, nil, if deleted or exist, return true, nil
|
||||
func CheckUserExistOrDeleted(username string, email string) (bool, error) {
|
||||
var user User
|
||||
@@ -320,10 +382,34 @@ func (user *User) Insert(inviterId int) error {
|
||||
user.Quota = common.QuotaForNewUser
|
||||
//user.SetAccessToken(common.GetUUID())
|
||||
user.AffCode = common.GetRandomString(4)
|
||||
|
||||
// 初始化用户设置,包括默认的边栏配置
|
||||
if user.Setting == "" {
|
||||
defaultSetting := dto.UserSetting{}
|
||||
// 这里暂时不设置SidebarModules,因为需要在用户创建后根据角色设置
|
||||
user.SetSetting(defaultSetting)
|
||||
}
|
||||
|
||||
result := DB.Create(user)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
// 用户创建成功后,根据角色初始化边栏配置
|
||||
// 需要重新获取用户以确保有正确的ID和Role
|
||||
var createdUser User
|
||||
if err := DB.Where("username = ?", user.Username).First(&createdUser).Error; err == nil {
|
||||
// 生成基于角色的默认边栏配置
|
||||
defaultSidebarConfig := generateDefaultSidebarConfigForRole(createdUser.Role)
|
||||
if defaultSidebarConfig != "" {
|
||||
currentSetting := createdUser.GetSetting()
|
||||
currentSetting.SidebarModules = defaultSidebarConfig
|
||||
createdUser.SetSetting(currentSetting)
|
||||
createdUser.Update(false)
|
||||
common.SysLog(fmt.Sprintf("为新用户 %s (角色: %d) 初始化边栏配置", createdUser.Username, createdUser.Role))
|
||||
}
|
||||
}
|
||||
|
||||
if common.QuotaForNewUser > 0 {
|
||||
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", logger.LogQuota(common.QuotaForNewUser)))
|
||||
}
|
||||
|
||||
@@ -14,13 +14,13 @@ import (
|
||||
|
||||
type Vendor struct {
|
||||
Id int `json:"id"`
|
||||
Name string `json:"name" gorm:"size:128;not null;uniqueIndex:uk_vendor_name,priority:1"`
|
||||
Name string `json:"name" gorm:"size:128;not null;uniqueIndex:uk_vendor_name_delete_at,priority:1"`
|
||||
Description string `json:"description,omitempty" gorm:"type:text"`
|
||||
Icon string `json:"icon,omitempty" gorm:"type:varchar(128)"`
|
||||
Status int `json:"status" gorm:"default:1"`
|
||||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||
UpdatedTime int64 `json:"updated_time" gorm:"bigint"`
|
||||
DeletedAt gorm.DeletedAt `json:"-" gorm:"index;uniqueIndex:uk_vendor_name,priority:2"`
|
||||
DeletedAt gorm.DeletedAt `json:"-" gorm:"index;uniqueIndex:uk_vendor_name_delete_at,priority:2"`
|
||||
}
|
||||
|
||||
// Insert 创建新的供应商记录
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
@@ -16,12 +17,17 @@ import (
|
||||
func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
|
||||
info.InitChannelMeta(c)
|
||||
|
||||
audioRequest, ok := info.Request.(*dto.AudioRequest)
|
||||
audioReq, ok := info.Request.(*dto.AudioRequest)
|
||||
if !ok {
|
||||
return types.NewError(errors.New("invalid request type"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
err := helper.ModelMappedHelper(c, info, audioRequest)
|
||||
request, err := common.DeepCopy(audioReq)
|
||||
if err != nil {
|
||||
return types.NewError(fmt.Errorf("failed to copy request to AudioRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, info, request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
@@ -32,7 +38,7 @@ func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
||||
}
|
||||
adaptor.Init(info)
|
||||
|
||||
ioReader, err := adaptor.ConvertAudioRequest(c, info, *audioRequest)
|
||||
ioReader, err := adaptor.ConvertAudioRequest(c, info, *request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
@@ -30,16 +30,16 @@ type Adaptor interface {
|
||||
}
|
||||
|
||||
type TaskAdaptor interface {
|
||||
Init(info *relaycommon.TaskRelayInfo)
|
||||
Init(info *relaycommon.RelayInfo)
|
||||
|
||||
ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) *dto.TaskError
|
||||
ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError
|
||||
|
||||
BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error)
|
||||
BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error
|
||||
BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error)
|
||||
BuildRequestURL(info *relaycommon.RelayInfo) (string, error)
|
||||
BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
|
||||
BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error)
|
||||
|
||||
DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error)
|
||||
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, err *dto.TaskError)
|
||||
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
|
||||
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, err *dto.TaskError)
|
||||
|
||||
GetModelList() []string
|
||||
GetChannelName() string
|
||||
|
||||
@@ -3,7 +3,6 @@ package ali
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
@@ -14,6 +13,8 @@ import (
|
||||
"one-api/relay/constant"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
@@ -44,6 +45,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.ChannelBaseUrl)
|
||||
case constant.RelayModeImagesGenerations:
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.ChannelBaseUrl)
|
||||
case constant.RelayModeImagesEdits:
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/multimodal-generation/generation", info.ChannelBaseUrl)
|
||||
case constant.RelayModeCompletions:
|
||||
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.ChannelBaseUrl)
|
||||
default:
|
||||
@@ -63,6 +66,12 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
if c.GetString("plugin") != "" {
|
||||
req.Set("X-DashScope-Plugin", c.GetString("plugin"))
|
||||
}
|
||||
if info.RelayMode == constant.RelayModeImagesGenerations {
|
||||
req.Set("X-DashScope-Async", "enable")
|
||||
}
|
||||
if info.RelayMode == constant.RelayModeImagesEdits {
|
||||
req.Set("Content-Type", "application/json")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -90,8 +99,30 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
aliRequest := oaiImage2Ali(request)
|
||||
return aliRequest, nil
|
||||
if info.RelayMode == constant.RelayModeImagesGenerations {
|
||||
aliRequest, err := oaiImage2Ali(request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert image request failed: %w", err)
|
||||
}
|
||||
return aliRequest, nil
|
||||
} else if info.RelayMode == constant.RelayModeImagesEdits {
|
||||
// ali image edit https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2976416
|
||||
// 如果用户使用表单,则需要解析表单数据
|
||||
if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
|
||||
aliRequest, err := oaiFormEdit2AliImageEdit(c, info, request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert image edit form request failed: %w", err)
|
||||
}
|
||||
return aliRequest, nil
|
||||
} else {
|
||||
aliRequest, err := oaiImage2Ali(request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert image request failed: %w", err)
|
||||
}
|
||||
return aliRequest, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("unsupported image relay mode: %d", info.RelayMode)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
@@ -120,15 +151,24 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
switch info.RelayFormat {
|
||||
case types.RelayFormatClaude:
|
||||
if info.IsStream {
|
||||
err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
|
||||
return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
|
||||
} else {
|
||||
err, usage = claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
|
||||
return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
|
||||
}
|
||||
default:
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.DoResponse(c, resp, info)
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeImagesGenerations:
|
||||
err, usage = aliImageHandler(c, resp, info)
|
||||
case constant.RelayModeImagesEdits:
|
||||
err, usage = aliImageEditHandler(c, resp, info)
|
||||
case constant.RelayModeRerank:
|
||||
err, usage = RerankHandler(c, resp, info)
|
||||
default:
|
||||
adaptor := openai.Adaptor{}
|
||||
usage, err = adaptor.DoResponse(c, resp, info)
|
||||
}
|
||||
return usage, err
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
|
||||
@@ -3,10 +3,15 @@ package ali
|
||||
import "one-api/dto"
|
||||
|
||||
type AliMessage struct {
|
||||
Content string `json:"content"`
|
||||
Content any `json:"content"`
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
||||
type AliMediaContent struct {
|
||||
Image string `json:"image,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
type AliInput struct {
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
//History []AliMessage `json:"history,omitempty"`
|
||||
@@ -70,13 +75,14 @@ type TaskResult struct {
|
||||
}
|
||||
|
||||
type AliOutput struct {
|
||||
TaskId string `json:"task_id,omitempty"`
|
||||
TaskStatus string `json:"task_status,omitempty"`
|
||||
Text string `json:"text"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Code string `json:"code,omitempty"`
|
||||
Results []TaskResult `json:"results,omitempty"`
|
||||
TaskId string `json:"task_id,omitempty"`
|
||||
TaskStatus string `json:"task_status,omitempty"`
|
||||
Text string `json:"text"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Code string `json:"code,omitempty"`
|
||||
Results []TaskResult `json:"results,omitempty"`
|
||||
Choices []map[string]any `json:"choices,omitempty"`
|
||||
}
|
||||
|
||||
type AliResponse struct {
|
||||
@@ -86,20 +92,26 @@ type AliResponse struct {
|
||||
}
|
||||
|
||||
type AliImageRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input struct {
|
||||
Prompt string `json:"prompt"`
|
||||
NegativePrompt string `json:"negative_prompt,omitempty"`
|
||||
} `json:"input"`
|
||||
Parameters struct {
|
||||
Size string `json:"size,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Steps string `json:"steps,omitempty"`
|
||||
Scale string `json:"scale,omitempty"`
|
||||
} `json:"parameters,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Input any `json:"input"`
|
||||
Parameters any `json:"parameters,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
}
|
||||
|
||||
type AliImageParameters struct {
|
||||
Size string `json:"size,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Steps string `json:"steps,omitempty"`
|
||||
Scale string `json:"scale,omitempty"`
|
||||
Watermark *bool `json:"watermark,omitempty"`
|
||||
}
|
||||
|
||||
type AliImageInput struct {
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
NegativePrompt string `json:"negative_prompt,omitempty"`
|
||||
Messages []AliMessage `json:"messages,omitempty"`
|
||||
}
|
||||
|
||||
type AliRerankParameters struct {
|
||||
TopN *int `json:"top_n,omitempty"`
|
||||
ReturnDocuments *bool `json:"return_documents,omitempty"`
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
package ali
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
@@ -18,15 +20,135 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func oaiImage2Ali(request dto.ImageRequest) *AliImageRequest {
|
||||
func oaiImage2Ali(request dto.ImageRequest) (*AliImageRequest, error) {
|
||||
var imageRequest AliImageRequest
|
||||
imageRequest.Model = request.Model
|
||||
imageRequest.ResponseFormat = request.ResponseFormat
|
||||
logger.LogJson(context.Background(), "oaiImage2Ali request extra", request.Extra)
|
||||
if request.Extra != nil {
|
||||
if val, ok := request.Extra["parameters"]; ok {
|
||||
err := common.Unmarshal(val, &imageRequest.Parameters)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid parameters field: %w", err)
|
||||
}
|
||||
}
|
||||
if val, ok := request.Extra["input"]; ok {
|
||||
err := common.Unmarshal(val, &imageRequest.Input)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid input field: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if imageRequest.Parameters == nil {
|
||||
imageRequest.Parameters = AliImageParameters{
|
||||
Size: strings.Replace(request.Size, "x", "*", -1),
|
||||
N: int(request.N),
|
||||
Watermark: request.Watermark,
|
||||
}
|
||||
}
|
||||
|
||||
if imageRequest.Input == nil {
|
||||
imageRequest.Input = AliImageInput{
|
||||
Prompt: request.Prompt,
|
||||
}
|
||||
}
|
||||
|
||||
return &imageRequest, nil
|
||||
}
|
||||
|
||||
func oaiFormEdit2AliImageEdit(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (*AliImageRequest, error) {
|
||||
var imageRequest AliImageRequest
|
||||
imageRequest.Input.Prompt = request.Prompt
|
||||
imageRequest.Model = request.Model
|
||||
imageRequest.Parameters.Size = strings.Replace(request.Size, "x", "*", -1)
|
||||
imageRequest.Parameters.N = int(request.N)
|
||||
imageRequest.ResponseFormat = request.ResponseFormat
|
||||
|
||||
return &imageRequest
|
||||
mf := c.Request.MultipartForm
|
||||
if mf == nil {
|
||||
if _, err := c.MultipartForm(); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse image edit form request: %w", err)
|
||||
}
|
||||
mf = c.Request.MultipartForm
|
||||
}
|
||||
|
||||
var imageFiles []*multipart.FileHeader
|
||||
var exists bool
|
||||
|
||||
// First check for standard "image" field
|
||||
if imageFiles, exists = mf.File["image"]; !exists || len(imageFiles) == 0 {
|
||||
// If not found, check for "image[]" field
|
||||
if imageFiles, exists = mf.File["image[]"]; !exists || len(imageFiles) == 0 {
|
||||
// If still not found, iterate through all fields to find any that start with "image["
|
||||
foundArrayImages := false
|
||||
for fieldName, files := range mf.File {
|
||||
if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
|
||||
foundArrayImages = true
|
||||
imageFiles = append(imageFiles, files...)
|
||||
}
|
||||
}
|
||||
|
||||
// If no image fields found at all
|
||||
if !foundArrayImages && (len(imageFiles) == 0) {
|
||||
return nil, errors.New("image is required")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(imageFiles) == 0 {
|
||||
return nil, errors.New("image is required")
|
||||
}
|
||||
|
||||
if len(imageFiles) > 1 {
|
||||
return nil, errors.New("only one image is supported for qwen edit")
|
||||
}
|
||||
|
||||
// 获取base64编码的图片
|
||||
var imageBase64s []string
|
||||
for _, file := range imageFiles {
|
||||
image, err := file.Open()
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to open image file")
|
||||
}
|
||||
|
||||
// 读取文件内容
|
||||
imageData, err := io.ReadAll(image)
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to read image file")
|
||||
}
|
||||
|
||||
// 获取MIME类型
|
||||
mimeType := http.DetectContentType(imageData)
|
||||
|
||||
// 编码为base64
|
||||
base64Data := base64.StdEncoding.EncodeToString(imageData)
|
||||
|
||||
// 构造data URL格式
|
||||
dataURL := fmt.Sprintf("data:%s;base64,%s", mimeType, base64Data)
|
||||
imageBase64s = append(imageBase64s, dataURL)
|
||||
image.Close()
|
||||
}
|
||||
|
||||
//dto.MediaContent{}
|
||||
mediaContents := make([]AliMediaContent, len(imageBase64s))
|
||||
for i, b64 := range imageBase64s {
|
||||
mediaContents[i] = AliMediaContent{
|
||||
Image: b64,
|
||||
}
|
||||
}
|
||||
mediaContents = append(mediaContents, AliMediaContent{
|
||||
Text: request.Prompt,
|
||||
})
|
||||
imageRequest.Input = AliImageInput{
|
||||
Messages: []AliMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: mediaContents,
|
||||
},
|
||||
},
|
||||
}
|
||||
imageRequest.Parameters = AliImageParameters{
|
||||
Watermark: request.Watermark,
|
||||
}
|
||||
return &imageRequest, nil
|
||||
}
|
||||
|
||||
func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error, []byte) {
|
||||
@@ -52,7 +174,7 @@ func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
|
||||
var response AliResponse
|
||||
err = json.Unmarshal(responseBody, &response)
|
||||
err = common.Unmarshal(responseBody, &response)
|
||||
if err != nil {
|
||||
common.SysLog("updateTask NewDecoder err: " + err.Error())
|
||||
return &aliResponse, err, nil
|
||||
@@ -61,8 +183,8 @@ func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error
|
||||
return &response, nil, responseBody
|
||||
}
|
||||
|
||||
func asyncTaskWait(info *relaycommon.RelayInfo, taskID string) (*AliResponse, []byte, error) {
|
||||
waitSeconds := 3
|
||||
func asyncTaskWait(c *gin.Context, info *relaycommon.RelayInfo, taskID string) (*AliResponse, []byte, error) {
|
||||
waitSeconds := 10
|
||||
step := 0
|
||||
maxStep := 20
|
||||
|
||||
@@ -70,11 +192,14 @@ func asyncTaskWait(info *relaycommon.RelayInfo, taskID string) (*AliResponse, []
|
||||
var responseBody []byte
|
||||
|
||||
for {
|
||||
logger.LogDebug(c, fmt.Sprintf("asyncTaskWait step %d/%d, wait %d seconds", step, maxStep, waitSeconds))
|
||||
step++
|
||||
rsp, err, body := updateTask(info, taskID)
|
||||
responseBody = body
|
||||
if err != nil {
|
||||
return &taskResponse, responseBody, err
|
||||
logger.LogWarn(c, "asyncTaskWait UpdateTask err: "+err.Error())
|
||||
time.Sleep(time.Duration(waitSeconds) * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
if rsp.Output.TaskStatus == "" {
|
||||
@@ -100,7 +225,7 @@ func asyncTaskWait(info *relaycommon.RelayInfo, taskID string) (*AliResponse, []
|
||||
return nil, nil, fmt.Errorf("aliAsyncTaskWait timeout")
|
||||
}
|
||||
|
||||
func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relaycommon.RelayInfo, responseFormat string) *dto.ImageResponse {
|
||||
func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, originBody []byte, info *relaycommon.RelayInfo, responseFormat string) *dto.ImageResponse {
|
||||
imageResponse := dto.ImageResponse{
|
||||
Created: info.StartTime.Unix(),
|
||||
}
|
||||
@@ -124,6 +249,9 @@ func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relayc
|
||||
RevisedPrompt: "",
|
||||
})
|
||||
}
|
||||
var mapResponse map[string]any
|
||||
_ = common.Unmarshal(originBody, &mapResponse)
|
||||
imageResponse.Extra = mapResponse
|
||||
return &imageResponse
|
||||
}
|
||||
|
||||
@@ -136,7 +264,7 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
||||
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
|
||||
}
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &aliTaskResponse)
|
||||
err = common.Unmarshal(responseBody, &aliTaskResponse)
|
||||
if err != nil {
|
||||
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
|
||||
}
|
||||
@@ -146,7 +274,7 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
||||
return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil
|
||||
}
|
||||
|
||||
aliResponse, _, err := asyncTaskWait(info, aliTaskResponse.Output.TaskId)
|
||||
aliResponse, originRespBody, err := asyncTaskWait(c, info, aliTaskResponse.Output.TaskId)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponse), nil
|
||||
}
|
||||
@@ -160,13 +288,52 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
||||
}, resp.StatusCode), nil
|
||||
}
|
||||
|
||||
fullTextResponse := responseAli2OpenAIImage(c, aliResponse, info, responseFormat)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
fullTextResponse := responseAli2OpenAIImage(c, aliResponse, originRespBody, info, responseFormat)
|
||||
jsonResponse, err := common.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
c.Writer.Write(jsonResponse)
|
||||
service.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||
return nil, &dto.Usage{}
|
||||
}
|
||||
|
||||
func aliImageEditHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
|
||||
var aliResponse AliResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
err = common.Unmarshal(responseBody, &aliResponse)
|
||||
if err != nil {
|
||||
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
if aliResponse.Message != "" {
|
||||
logger.LogError(c, "ali_task_failed: "+aliResponse.Message)
|
||||
return types.NewError(errors.New(aliResponse.Message), types.ErrorCodeBadResponse), nil
|
||||
}
|
||||
var fullTextResponse dto.ImageResponse
|
||||
if len(aliResponse.Output.Choices) > 0 {
|
||||
fullTextResponse = dto.ImageResponse{
|
||||
Created: info.StartTime.Unix(),
|
||||
Data: []dto.ImageData{
|
||||
{
|
||||
Url: aliResponse.Output.Choices[0]["message"].(map[string]any)["content"].([]any)[0].(map[string]any)["image"].(string),
|
||||
B64Json: "",
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
var mapResponse map[string]any
|
||||
_ = common.Unmarshal(responseBody, &mapResponse)
|
||||
fullTextResponse.Extra = mapResponse
|
||||
jsonResponse, err := common.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
service.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||
return nil, &dto.Usage{}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/types"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -47,7 +48,19 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new request failed: %w", err)
|
||||
}
|
||||
err = a.SetupRequestHeader(c, &req.Header, info)
|
||||
headers := req.Header
|
||||
headerOverride := make(map[string]string)
|
||||
for k, v := range info.HeadersOverride {
|
||||
if str, ok := v.(string); ok {
|
||||
headerOverride[k] = str
|
||||
} else {
|
||||
return nil, types.NewError(err, types.ErrorCodeChannelHeaderOverrideInvalid)
|
||||
}
|
||||
}
|
||||
for key, value := range headerOverride {
|
||||
headers.Set(key, value)
|
||||
}
|
||||
err = a.SetupRequestHeader(c, &headers, info)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setup request header failed: %w", err)
|
||||
}
|
||||
@@ -72,8 +85,19 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod
|
||||
}
|
||||
// set form data
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
|
||||
err = a.SetupRequestHeader(c, &req.Header, info)
|
||||
headers := req.Header
|
||||
headerOverride := make(map[string]string)
|
||||
for k, v := range info.HeadersOverride {
|
||||
if str, ok := v.(string); ok {
|
||||
headerOverride[k] = str
|
||||
} else {
|
||||
return nil, types.NewError(err, types.ErrorCodeChannelHeaderOverrideInvalid)
|
||||
}
|
||||
}
|
||||
for key, value := range headerOverride {
|
||||
headers.Set(key, value)
|
||||
}
|
||||
err = a.SetupRequestHeader(c, &headers, info)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setup request header failed: %w", err)
|
||||
}
|
||||
@@ -253,7 +277,7 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
fullRequestURL, err := a.BuildRequestURL(info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -270,7 +294,7 @@ func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.TaskRelayInfo,
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setup request header failed: %w", err)
|
||||
}
|
||||
resp, err := doRequest(c, req, info.RelayInfo)
|
||||
resp, err := doRequest(c, req, info)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("do request failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -63,7 +63,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
|
||||
var claudeReq *dto.ClaudeRequest
|
||||
var err error
|
||||
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(*request)
|
||||
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(c, *request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -130,7 +130,12 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
|
||||
Usage: &dto.Usage{},
|
||||
}
|
||||
|
||||
handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, awsResp.Body, RequestModeMessage)
|
||||
// 复制上游 Content-Type 到客户端响应头
|
||||
if awsResp.ContentType != nil && *awsResp.ContentType != "" {
|
||||
c.Writer.Header().Set("Content-Type", *awsResp.ContentType)
|
||||
}
|
||||
|
||||
handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, nil, awsResp.Body, RequestModeMessage)
|
||||
if handlerErr != nil {
|
||||
return handlerErr, nil
|
||||
}
|
||||
|
||||
@@ -81,20 +81,23 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
if strings.HasSuffix(info.UpstreamModelName, "-search") {
|
||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-search")
|
||||
request.Model = info.UpstreamModelName
|
||||
toMap := request.ToMap()
|
||||
toMap["web_search"] = map[string]any{
|
||||
"enable": true,
|
||||
"enable_citation": true,
|
||||
"enable_trace": true,
|
||||
"enable_status": false,
|
||||
if len(request.WebSearch) == 0 {
|
||||
toMap := request.ToMap()
|
||||
toMap["web_search"] = map[string]any{
|
||||
"enable": true,
|
||||
"enable_citation": true,
|
||||
"enable_trace": true,
|
||||
"enable_status": false,
|
||||
}
|
||||
return toMap, nil
|
||||
}
|
||||
return toMap, nil
|
||||
return request, nil
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
return nil, nil
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||
|
||||
@@ -78,7 +78,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
if a.RequestMode == RequestModeCompletion {
|
||||
return RequestOpenAI2ClaudeComplete(*request), nil
|
||||
} else {
|
||||
return RequestOpenAI2ClaudeMessage(*request)
|
||||
return RequestOpenAI2ClaudeMessage(c, *request)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -102,9 +102,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.IsStream {
|
||||
err, usage = ClaudeStreamHandler(c, resp, info, a.RequestMode)
|
||||
return ClaudeStreamHandler(c, resp, info, a.RequestMode)
|
||||
} else {
|
||||
err, usage = ClaudeHandler(c, resp, info, a.RequestMode)
|
||||
return ClaudeHandler(c, resp, info, a.RequestMode)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -32,7 +32,7 @@ func stopReasonClaude2OpenAI(reason string) string {
|
||||
case "end_turn":
|
||||
return "stop"
|
||||
case "max_tokens":
|
||||
return "max_tokens"
|
||||
return "length"
|
||||
case "tool_use":
|
||||
return "tool_calls"
|
||||
default:
|
||||
@@ -71,7 +71,7 @@ func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *dto.Cla
|
||||
return &claudeRequest
|
||||
}
|
||||
|
||||
func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) {
|
||||
func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) {
|
||||
claudeTools := make([]any, 0, len(textRequest.Tools))
|
||||
|
||||
for _, tool := range textRequest.Tools {
|
||||
@@ -274,19 +274,28 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
|
||||
|
||||
claudeMessages := make([]dto.ClaudeMessage, 0)
|
||||
isFirstMessage := true
|
||||
// 初始化system消息数组,用于累积多个system消息
|
||||
var systemMessages []dto.ClaudeMediaMessage
|
||||
|
||||
for _, message := range formatMessages {
|
||||
if message.Role == "system" {
|
||||
// 根据Claude API规范,system字段使用数组格式更有通用性
|
||||
if message.IsStringContent() {
|
||||
claudeRequest.System = message.StringContent()
|
||||
systemMessages = append(systemMessages, dto.ClaudeMediaMessage{
|
||||
Type: "text",
|
||||
Text: common.GetPointer[string](message.StringContent()),
|
||||
})
|
||||
} else {
|
||||
contents := message.ParseContent()
|
||||
content := ""
|
||||
for _, ctx := range contents {
|
||||
// 支持复合内容的system消息(虽然不常见,但需要考虑完整性)
|
||||
for _, ctx := range message.ParseContent() {
|
||||
if ctx.Type == "text" {
|
||||
content += ctx.Text
|
||||
systemMessages = append(systemMessages, dto.ClaudeMediaMessage{
|
||||
Type: "text",
|
||||
Text: common.GetPointer[string](ctx.Text),
|
||||
})
|
||||
}
|
||||
// 未来可以在这里扩展对图片等其他类型的支持
|
||||
}
|
||||
claudeRequest.System = content
|
||||
}
|
||||
} else {
|
||||
if isFirstMessage {
|
||||
@@ -355,7 +364,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
|
||||
// 判断是否是url
|
||||
if strings.HasPrefix(imageUrl.Url, "http") {
|
||||
// 是url,获取图片的类型和base64编码的数据
|
||||
fileData, err := service.GetFileBase64FromUrl(imageUrl.Url)
|
||||
fileData, err := service.GetFileBase64FromUrl(c, imageUrl.Url, "formatting image for Claude")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
|
||||
}
|
||||
@@ -392,6 +401,12 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
|
||||
claudeMessages = append(claudeMessages, claudeMessage)
|
||||
}
|
||||
}
|
||||
|
||||
// 设置累积的system消息
|
||||
if len(systemMessages) > 0 {
|
||||
claudeRequest.System = systemMessages
|
||||
}
|
||||
|
||||
claudeRequest.Prompt = ""
|
||||
claudeRequest.Messages = claudeMessages
|
||||
return &claudeRequest, nil
|
||||
@@ -426,7 +441,10 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse
|
||||
choice.Delta.Role = "assistant"
|
||||
} else if claudeResponse.Type == "content_block_start" {
|
||||
if claudeResponse.ContentBlock != nil {
|
||||
//choice.Delta.SetContentString(claudeResponse.ContentBlock.Text)
|
||||
// 如果是文本块,尽可能发送首段文本(若存在)
|
||||
if claudeResponse.ContentBlock.Type == "text" && claudeResponse.ContentBlock.Text != nil {
|
||||
choice.Delta.SetContentString(*claudeResponse.ContentBlock.Text)
|
||||
}
|
||||
if claudeResponse.ContentBlock.Type == "tool_use" {
|
||||
tools = append(tools, dto.ToolCallResponse{
|
||||
Index: common.GetPointer(fcIdx),
|
||||
@@ -674,7 +692,7 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau
|
||||
}
|
||||
}
|
||||
|
||||
func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
|
||||
func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.Usage, *types.NewAPIError) {
|
||||
claudeInfo := &ClaudeResponseInfo{
|
||||
ResponseId: helper.GetResponseID(c),
|
||||
Created: common.GetTimestamp(),
|
||||
@@ -691,14 +709,14 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
||||
return true
|
||||
})
|
||||
if err != nil {
|
||||
return err, nil
|
||||
return nil, err
|
||||
}
|
||||
|
||||
HandleStreamFinalResponse(c, info, claudeInfo, requestMode)
|
||||
return nil, claudeInfo.Usage
|
||||
return claudeInfo.Usage, nil
|
||||
}
|
||||
|
||||
func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *types.NewAPIError {
|
||||
func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, httpResp *http.Response, data []byte, requestMode int) *types.NewAPIError {
|
||||
var claudeResponse dto.ClaudeResponse
|
||||
err := common.Unmarshal(data, &claudeResponse)
|
||||
if err != nil {
|
||||
@@ -736,11 +754,11 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
c.Set("claude_web_search_requests", claudeResponse.Usage.ServerToolUse.WebSearchRequests)
|
||||
}
|
||||
|
||||
service.IOCopyBytesGracefully(c, nil, responseData)
|
||||
service.IOCopyBytesGracefully(c, httpResp, responseData)
|
||||
return nil
|
||||
}
|
||||
|
||||
func ClaudeHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
|
||||
func ClaudeHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.Usage, *types.NewAPIError) {
|
||||
defer service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
claudeInfo := &ClaudeResponseInfo{
|
||||
@@ -752,16 +770,16 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
|
||||
}
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
if common.DebugEnabled {
|
||||
println("responseBody: ", string(responseBody))
|
||||
}
|
||||
handleErr := HandleClaudeResponseData(c, info, claudeInfo, responseBody, requestMode)
|
||||
handleErr := HandleClaudeResponseData(c, info, claudeInfo, resp, responseBody, requestMode)
|
||||
if handleErr != nil {
|
||||
return handleErr, nil
|
||||
return nil, handleErr
|
||||
}
|
||||
return nil, claudeInfo.Usage
|
||||
return claudeInfo.Usage, nil
|
||||
}
|
||||
|
||||
func mapToolChoice(toolChoice any, parallelToolCalls *bool) *dto.ClaudeToolChoice {
|
||||
|
||||
@@ -59,15 +59,22 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
return nil, errors.New("not supported model for image generation")
|
||||
}
|
||||
|
||||
// convert size to aspect ratio
|
||||
// convert size to aspect ratio but allow user to specify aspect ratio
|
||||
aspectRatio := "1:1" // default aspect ratio
|
||||
switch request.Size {
|
||||
case "1024x1024":
|
||||
aspectRatio = "1:1"
|
||||
case "1024x1792":
|
||||
aspectRatio = "9:16"
|
||||
case "1792x1024":
|
||||
aspectRatio = "16:9"
|
||||
size := strings.TrimSpace(request.Size)
|
||||
if size != "" {
|
||||
if strings.Contains(size, ":") {
|
||||
aspectRatio = size
|
||||
} else {
|
||||
switch size {
|
||||
case "1024x1024":
|
||||
aspectRatio = "1:1"
|
||||
case "1024x1792":
|
||||
aspectRatio = "9:16"
|
||||
case "1792x1024":
|
||||
aspectRatio = "16:9"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// build gemini imagen request
|
||||
@@ -142,7 +149,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
||||
geminiRequest, err := CovertGemini2OpenAI(*request, info)
|
||||
geminiRequest, err := CovertGemini2OpenAI(c, *request, info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -46,6 +46,32 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
|
||||
|
||||
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||||
|
||||
if strings.HasPrefix(info.UpstreamModelName, "gemini-2.5-flash-image-preview") {
|
||||
imageOutputCounts := 0
|
||||
for _, candidate := range geminiResponse.Candidates {
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.InlineData != nil && strings.HasPrefix(part.InlineData.MimeType, "image/") {
|
||||
imageOutputCounts++
|
||||
}
|
||||
}
|
||||
}
|
||||
if imageOutputCounts != 0 {
|
||||
usage.CompletionTokens = usage.CompletionTokens - imageOutputCounts*1290
|
||||
usage.TotalTokens = usage.TotalTokens - imageOutputCounts*1290
|
||||
c.Set("gemini_image_tokens", imageOutputCounts*1290)
|
||||
}
|
||||
}
|
||||
|
||||
// if strings.HasPrefix(info.UpstreamModelName, "gemini-2.5-flash-image-preview") {
|
||||
// for _, detail := range geminiResponse.UsageMetadata.CandidatesTokensDetails {
|
||||
// if detail.Modality == "IMAGE" {
|
||||
// usage.CompletionTokens = usage.CompletionTokens - detail.TokenCount
|
||||
// usage.TotalTokens = usage.TotalTokens - detail.TokenCount
|
||||
// c.Set("gemini_image_tokens", detail.TokenCount)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
|
||||
if detail.Modality == "AUDIO" {
|
||||
usage.PromptTokensDetails.AudioTokens = detail.TokenCount
|
||||
@@ -136,6 +162,16 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn
|
||||
usage.PromptTokensDetails.TextTokens = detail.TokenCount
|
||||
}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(info.UpstreamModelName, "gemini-2.5-flash-image-preview") {
|
||||
for _, detail := range geminiResponse.UsageMetadata.CandidatesTokensDetails {
|
||||
if detail.Modality == "IMAGE" {
|
||||
usage.CompletionTokens = usage.CompletionTokens - detail.TokenCount
|
||||
usage.TotalTokens = usage.TotalTokens - detail.TokenCount
|
||||
c.Set("gemini_image_tokens", detail.TokenCount)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 直接发送 GeminiChatResponse 响应
|
||||
|
||||
@@ -178,7 +178,7 @@ func ThinkingAdaptor(geminiRequest *dto.GeminiChatRequest, info *relaycommon.Rel
|
||||
}
|
||||
|
||||
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
||||
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*dto.GeminiChatRequest, error) {
|
||||
func CovertGemini2OpenAI(c *gin.Context, textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*dto.GeminiChatRequest, error) {
|
||||
|
||||
geminiRequest := dto.GeminiChatRequest{
|
||||
Contents: make([]dto.GeminiChatContent, 0, len(textRequest.Messages)),
|
||||
@@ -390,7 +390,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
// 判断是否是url
|
||||
if strings.HasPrefix(part.GetImageMedia().Url, "http") {
|
||||
// 是url,获取文件的类型和base64编码的数据
|
||||
fileData, err := service.GetFileBase64FromUrl(part.GetImageMedia().Url)
|
||||
fileData, err := service.GetFileBase64FromUrl(c, part.GetImageMedia().Url, "formatting image for Gemini")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get file base64 from url '%s' failed: %w", part.GetImageMedia().Url, err)
|
||||
}
|
||||
@@ -749,7 +749,16 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse)
|
||||
var texts []string
|
||||
var toolCalls []dto.ToolCallResponse
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.FunctionCall != nil {
|
||||
if part.InlineData != nil {
|
||||
// 媒体内容
|
||||
if strings.HasPrefix(part.InlineData.MimeType, "image") {
|
||||
imgText := ""
|
||||
texts = append(texts, imgText)
|
||||
} else {
|
||||
// 其他媒体类型,直接显示链接
|
||||
texts = append(texts, fmt.Sprintf("[media](data:%s;base64,%s)", part.InlineData.MimeType, part.InlineData.Data))
|
||||
}
|
||||
} else if part.FunctionCall != nil {
|
||||
choice.FinishReason = constant.FinishReasonToolCalls
|
||||
if call := getResponseToolCall(&part); call != nil {
|
||||
toolCalls = append(toolCalls, *call)
|
||||
@@ -935,7 +944,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logger.LogDebug(c, fmt.Sprintf("info.SendResponseCount = %d", info.SendResponseCount))
|
||||
if info.SendResponseCount == 0 {
|
||||
// send first response
|
||||
emptyResponse := helper.GenerateStartEmptyResponse(id, createAt, info.UpstreamModelName, nil)
|
||||
@@ -953,6 +962,11 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
||||
if response.IsFinished() {
|
||||
response.Choices[0].FinishReason = nil
|
||||
}
|
||||
} else {
|
||||
err = handleStream(c, info, emptyResponse)
|
||||
if err != nil {
|
||||
logger.LogError(c, err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,4 +6,4 @@ var ModelList = []string{
|
||||
"m3e-small",
|
||||
}
|
||||
|
||||
var ChannelName = "mokaai"
|
||||
var ChannelName = "mokaai"
|
||||
|
||||
@@ -89,17 +89,16 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
switch info.RelayFormat {
|
||||
case types.RelayFormatOpenAI:
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.DoResponse(c, resp, info)
|
||||
case types.RelayFormatClaude:
|
||||
if info.IsStream {
|
||||
err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
|
||||
return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
|
||||
} else {
|
||||
err, usage = claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
|
||||
return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
|
||||
}
|
||||
default:
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.DoResponse(c, resp, info)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
|
||||
@@ -31,7 +31,7 @@ func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
openaiRequest.(*dto.GeneralOpenAIRequest).StreamOptions = &dto.StreamOptions{
|
||||
IncludeUsage: true,
|
||||
}
|
||||
return requestOpenAI2Ollama(openaiRequest.(*dto.GeneralOpenAIRequest))
|
||||
return requestOpenAI2Ollama(c, openaiRequest.(*dto.GeneralOpenAIRequest))
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
@@ -69,7 +69,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
return requestOpenAI2Ollama(request)
|
||||
return requestOpenAI2Ollama(c, request)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func requestOpenAI2Ollama(request *dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
|
||||
func requestOpenAI2Ollama(c *gin.Context, request *dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
|
||||
messages := make([]dto.Message, 0, len(request.Messages))
|
||||
for _, message := range request.Messages {
|
||||
if !message.IsStringContent() {
|
||||
@@ -24,7 +24,7 @@ func requestOpenAI2Ollama(request *dto.GeneralOpenAIRequest) (*OllamaRequest, er
|
||||
imageUrl := mediaMessage.GetImageMedia()
|
||||
// check if not base64
|
||||
if strings.HasPrefix(imageUrl.Url, "http") {
|
||||
fileData, err := service.GetFileBase64FromUrl(imageUrl.Url)
|
||||
fileData, err := service.GetFileBase64FromUrl(c, imageUrl.Url, "formatting image for Ollama")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -68,9 +68,7 @@ func requestOpenAI2Ollama(request *dto.GeneralOpenAIRequest) (*OllamaRequest, er
|
||||
StreamOptions: request.StreamOptions,
|
||||
Suffix: request.Suffix,
|
||||
}
|
||||
if think, ok := request.Extra["think"]; ok {
|
||||
ollamaRequest.Think = think
|
||||
}
|
||||
ollamaRequest.Think = request.Think
|
||||
return ollamaRequest, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -237,6 +237,8 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
}
|
||||
request.Reasoning = marshal
|
||||
}
|
||||
// 清空多余的ReasoningEffort
|
||||
request.ReasoningEffort = ""
|
||||
} else {
|
||||
if len(request.Reasoning) == 0 {
|
||||
// 适配 OpenAI 的 ReasoningEffort 格式
|
||||
@@ -254,7 +256,40 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
}
|
||||
}
|
||||
}
|
||||
request.ReasoningEffort = ""
|
||||
}
|
||||
|
||||
// https://docs.anthropic.com/en/api/openai-sdk#extended-thinking-support
|
||||
// 没有做排除3.5Haiku等,要出问题再加吧,最佳兼容性(不是
|
||||
if request.THINKING != nil && strings.HasPrefix(info.UpstreamModelName, "anthropic") {
|
||||
var thinking dto.Thinking // Claude标准Thinking格式
|
||||
if err := json.Unmarshal(request.THINKING, &thinking); err != nil {
|
||||
return nil, fmt.Errorf("error Unmarshal thinking: %w", err)
|
||||
}
|
||||
|
||||
// 只有当 thinking.Type 是 "enabled" 时才处理
|
||||
if thinking.Type == "enabled" {
|
||||
// 检查 BudgetTokens 是否为 nil
|
||||
if thinking.BudgetTokens == nil {
|
||||
return nil, fmt.Errorf("BudgetTokens is nil when thinking is enabled")
|
||||
}
|
||||
|
||||
reasoning := openrouter.RequestReasoning{
|
||||
MaxTokens: *thinking.BudgetTokens,
|
||||
}
|
||||
|
||||
marshal, err := common.Marshal(reasoning)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshalling reasoning: %w", err)
|
||||
}
|
||||
|
||||
request.Reasoning = marshal
|
||||
}
|
||||
|
||||
// 清空 THINKING
|
||||
request.THINKING = nil
|
||||
}
|
||||
|
||||
}
|
||||
if strings.HasPrefix(info.UpstreamModelName, "o") || strings.HasPrefix(info.UpstreamModelName, "gpt-5") {
|
||||
if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
|
||||
@@ -503,7 +538,13 @@ func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommo
|
||||
// 转换模型推理力度后缀
|
||||
effort, originModel := parseReasoningEffortFromModelSuffix(request.Model)
|
||||
if effort != "" {
|
||||
request.Reasoning.Effort = effort
|
||||
if request.Reasoning == nil {
|
||||
request.Reasoning = &dto.Reasoning{
|
||||
Effort: effort,
|
||||
}
|
||||
} else {
|
||||
request.Reasoning.Effort = effort
|
||||
}
|
||||
request.Model = originModel
|
||||
}
|
||||
return request, nil
|
||||
|
||||
@@ -12,16 +12,25 @@ var ModelList = []string{
|
||||
"gpt-4o", "gpt-4o-2024-05-13", "gpt-4o-2024-08-06", "gpt-4o-2024-11-20",
|
||||
"gpt-4o-mini", "gpt-4o-mini-2024-07-18",
|
||||
"gpt-4.5-preview", "gpt-4.5-preview-2025-02-27",
|
||||
"gpt-4.1", "gpt-4.1-2025-04-14",
|
||||
"gpt-4.1-mini", "gpt-4.1-mini-2025-04-14",
|
||||
"gpt-4.1-nano", "gpt-4.1-nano-2025-04-14",
|
||||
"o1", "o1-2024-12-17",
|
||||
"o1-preview", "o1-preview-2024-09-12",
|
||||
"o1-mini", "o1-mini-2024-09-12",
|
||||
"o1-pro", "o1-pro-2025-03-19",
|
||||
"o3-mini", "o3-mini-2025-01-31",
|
||||
"o3-mini-high", "o3-mini-2025-01-31-high",
|
||||
"o3-mini-low", "o3-mini-2025-01-31-low",
|
||||
"o3-mini-medium", "o3-mini-2025-01-31-medium",
|
||||
"o3", "o3-2025-04-16",
|
||||
"o3-pro", "o3-pro-2025-06-10",
|
||||
"o3-deep-research", "o3-deep-research-2025-06-26",
|
||||
"o4-mini", "o4-mini-2025-04-16",
|
||||
"o4-mini-deep-research", "o4-mini-deep-research-2025-06-26",
|
||||
"gpt-5", "gpt-5-2025-08-07", "gpt-5-chat-latest",
|
||||
"gpt-5-mini", "gpt-5-mini-2025-08-07",
|
||||
"gpt-5-nano", "gpt-5-nano-2025-08-07",
|
||||
"o1", "o1-2024-12-17",
|
||||
"gpt-4o-audio-preview", "gpt-4o-audio-preview-2024-10-01",
|
||||
"gpt-4o-realtime-preview", "gpt-4o-realtime-preview-2024-10-01", "gpt-4o-realtime-preview-2024-12-17",
|
||||
"gpt-4o-mini-realtime-preview", "gpt-4o-mini-realtime-preview-2024-12-17",
|
||||
@@ -30,7 +39,7 @@ var ModelList = []string{
|
||||
"text-moderation-latest", "text-moderation-stable",
|
||||
"text-davinci-edit-001",
|
||||
"davinci-002", "babbage-002",
|
||||
"dall-e-3",
|
||||
"dall-e-3", "gpt-image-1",
|
||||
"whisper-1",
|
||||
"tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106",
|
||||
}
|
||||
|
||||
@@ -2,9 +2,6 @@ package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/samber/lo"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/logger"
|
||||
@@ -15,6 +12,8 @@ import (
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/samber/lo"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -71,11 +70,7 @@ func handleGeminiFormat(c *gin.Context, data string, info *relaycommon.RelayInfo
|
||||
|
||||
// send gemini format response
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)})
|
||||
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
} else {
|
||||
return errors.New("streaming error: flusher not found")
|
||||
}
|
||||
_ = helper.FlushWriter(c)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -253,9 +248,7 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
|
||||
|
||||
// 发送最终的 Gemini 响应
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)})
|
||||
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
_ = helper.FlushWriter(c)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
@@ -197,21 +198,34 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
||||
forceFormat = true
|
||||
}
|
||||
|
||||
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
|
||||
completionTokens := 0
|
||||
for _, choice := range simpleResponse.Choices {
|
||||
ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
|
||||
completionTokens += ctkm
|
||||
usageModified := false
|
||||
if simpleResponse.Usage.PromptTokens == 0 {
|
||||
completionTokens := simpleResponse.Usage.CompletionTokens
|
||||
if completionTokens == 0 {
|
||||
for _, choice := range simpleResponse.Choices {
|
||||
ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
|
||||
completionTokens += ctkm
|
||||
}
|
||||
}
|
||||
simpleResponse.Usage = dto.Usage{
|
||||
PromptTokens: info.PromptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: info.PromptTokens + completionTokens,
|
||||
}
|
||||
usageModified = true
|
||||
}
|
||||
|
||||
switch info.RelayFormat {
|
||||
case types.RelayFormatOpenAI:
|
||||
if usageModified {
|
||||
var bodyMap map[string]interface{}
|
||||
err = common.Unmarshal(responseBody, &bodyMap)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
bodyMap["usage"] = simpleResponse.Usage
|
||||
responseBody, _ = common.Marshal(bodyMap)
|
||||
}
|
||||
if forceFormat {
|
||||
responseBody, err = common.Marshal(simpleResponse)
|
||||
if err != nil {
|
||||
@@ -267,11 +281,6 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) {
|
||||
defer service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
// count tokens by audio file duration
|
||||
audioTokens, err := countAudioTokens(c)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeCountTokenFailed), nil
|
||||
}
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
|
||||
@@ -279,6 +288,26 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
// 写入新的 response body
|
||||
service.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
|
||||
var responseData struct {
|
||||
Usage *dto.Usage `json:"usage"`
|
||||
}
|
||||
if err := json.Unmarshal(responseBody, &responseData); err == nil && responseData.Usage != nil {
|
||||
if responseData.Usage.TotalTokens > 0 {
|
||||
usage := responseData.Usage
|
||||
if usage.PromptTokens == 0 {
|
||||
usage.PromptTokens = usage.InputTokens
|
||||
}
|
||||
if usage.CompletionTokens == 0 {
|
||||
usage.CompletionTokens = usage.OutputTokens
|
||||
}
|
||||
return nil, usage
|
||||
}
|
||||
}
|
||||
|
||||
audioTokens, err := countAudioTokens(c)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeCountTokenFailed), nil
|
||||
}
|
||||
usage := &dto.Usage{}
|
||||
usage.PromptTokens = audioTokens
|
||||
usage.CompletionTokens = 0
|
||||
|
||||
@@ -46,9 +46,17 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
||||
usage.PromptTokensDetails.CachedTokens = responsesResponse.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
}
|
||||
if info == nil || info.ResponsesUsageInfo == nil || info.ResponsesUsageInfo.BuiltInTools == nil {
|
||||
return &usage, nil
|
||||
}
|
||||
// 解析 Tools 用量
|
||||
for _, tool := range responsesResponse.Tools {
|
||||
info.ResponsesUsageInfo.BuiltInTools[common.Interface2String(tool["type"])].CallCount++
|
||||
buildToolinfo, ok := info.ResponsesUsageInfo.BuiltInTools[common.Interface2String(tool["type"])]
|
||||
if !ok || buildToolinfo == nil {
|
||||
logger.LogError(c, fmt.Sprintf("BuiltInTools not found for tool type: %v", tool["type"]))
|
||||
continue
|
||||
}
|
||||
buildToolinfo.CallCount++
|
||||
}
|
||||
return &usage, nil
|
||||
}
|
||||
@@ -72,10 +80,16 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
|
||||
sendResponsesStreamData(c, streamResponse, data)
|
||||
switch streamResponse.Type {
|
||||
case "response.completed":
|
||||
if streamResponse.Response.Usage != nil {
|
||||
usage.PromptTokens = streamResponse.Response.Usage.InputTokens
|
||||
usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
|
||||
usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
|
||||
if streamResponse.Response != nil && streamResponse.Response.Usage != nil {
|
||||
if streamResponse.Response.Usage.InputTokens != 0 {
|
||||
usage.PromptTokens = streamResponse.Response.Usage.InputTokens
|
||||
}
|
||||
if streamResponse.Response.Usage.OutputTokens != 0 {
|
||||
usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
|
||||
}
|
||||
if streamResponse.Response.Usage.TotalTokens != 0 {
|
||||
usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
|
||||
}
|
||||
if streamResponse.Response.Usage.InputTokensDetails != nil {
|
||||
usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
@@ -92,6 +106,8 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
logger.LogError(c, "failed to unmarshal stream response: "+err.Error())
|
||||
}
|
||||
return true
|
||||
})
|
||||
@@ -106,5 +122,11 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
|
||||
}
|
||||
}
|
||||
|
||||
if usage.PromptTokens == 0 && usage.CompletionTokens != 0 {
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
}
|
||||
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
@@ -74,7 +74,7 @@ type TaskAdaptor struct {
|
||||
baseURL string
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
||||
func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
a.ChannelType = info.ChannelType
|
||||
a.baseURL = info.ChannelBaseUrl
|
||||
|
||||
@@ -87,7 +87,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
||||
}
|
||||
|
||||
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
||||
// Accept only POST /v1/video/generations as "generate" action.
|
||||
action := constant.TaskActionGenerate
|
||||
info.Action = action
|
||||
@@ -108,19 +108,19 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
|
||||
}
|
||||
|
||||
// BuildRequestURL constructs the upstream URL.
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/?Action=CVSync2AsyncSubmitTask&Version=2022-08-31", a.baseURL), nil
|
||||
}
|
||||
|
||||
// BuildRequestHeader sets required headers.
|
||||
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
|
||||
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
return a.signRequest(req, a.accessKey, a.secretKey)
|
||||
}
|
||||
|
||||
// BuildRequestBody converts request into Jimeng specific format.
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) {
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
|
||||
v, exists := c.Get("task_request")
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("request not found in context")
|
||||
@@ -139,12 +139,12 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRel
|
||||
}
|
||||
|
||||
// DoRequest delegates to common helper.
|
||||
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoTaskApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
// DoResponse handles upstream response, returns taskID etc.
|
||||
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
||||
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
|
||||
@@ -4,13 +4,14 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/samber/lo"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/model"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/samber/lo"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/pkg/errors"
|
||||
@@ -37,15 +38,46 @@ type SubmitReq struct {
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
type TrajectoryPoint struct {
|
||||
X int `json:"x"`
|
||||
Y int `json:"y"`
|
||||
}
|
||||
|
||||
type DynamicMask struct {
|
||||
Mask string `json:"mask,omitempty"`
|
||||
Trajectories []TrajectoryPoint `json:"trajectories,omitempty"`
|
||||
}
|
||||
|
||||
type CameraConfig struct {
|
||||
Horizontal float64 `json:"horizontal,omitempty"`
|
||||
Vertical float64 `json:"vertical,omitempty"`
|
||||
Pan float64 `json:"pan,omitempty"`
|
||||
Tilt float64 `json:"tilt,omitempty"`
|
||||
Roll float64 `json:"roll,omitempty"`
|
||||
Zoom float64 `json:"zoom,omitempty"`
|
||||
}
|
||||
|
||||
type CameraControl struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Config *CameraConfig `json:"config,omitempty"`
|
||||
}
|
||||
|
||||
type requestPayload struct {
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
Image string `json:"image,omitempty"`
|
||||
Mode string `json:"mode,omitempty"`
|
||||
Duration string `json:"duration,omitempty"`
|
||||
AspectRatio string `json:"aspect_ratio,omitempty"`
|
||||
ModelName string `json:"model_name,omitempty"`
|
||||
Model string `json:"model,omitempty"` // Compatible with upstreams that only recognize "model"
|
||||
CfgScale float64 `json:"cfg_scale,omitempty"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
Image string `json:"image,omitempty"`
|
||||
ImageTail string `json:"image_tail,omitempty"`
|
||||
NegativePrompt string `json:"negative_prompt,omitempty"`
|
||||
Mode string `json:"mode,omitempty"`
|
||||
Duration string `json:"duration,omitempty"`
|
||||
AspectRatio string `json:"aspect_ratio,omitempty"`
|
||||
ModelName string `json:"model_name,omitempty"`
|
||||
Model string `json:"model,omitempty"` // Compatible with upstreams that only recognize "model"
|
||||
CfgScale float64 `json:"cfg_scale,omitempty"`
|
||||
StaticMask string `json:"static_mask,omitempty"`
|
||||
DynamicMasks []DynamicMask `json:"dynamic_masks,omitempty"`
|
||||
CameraControl *CameraControl `json:"camera_control,omitempty"`
|
||||
CallbackUrl string `json:"callback_url,omitempty"`
|
||||
ExternalTaskId string `json:"external_task_id,omitempty"`
|
||||
}
|
||||
|
||||
type responsePayload struct {
|
||||
@@ -79,7 +111,7 @@ type TaskAdaptor struct {
|
||||
baseURL string
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
||||
func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
a.ChannelType = info.ChannelType
|
||||
a.baseURL = info.ChannelBaseUrl
|
||||
a.apiKey = info.ApiKey
|
||||
@@ -88,7 +120,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
||||
}
|
||||
|
||||
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
||||
// Accept only POST /v1/video/generations as "generate" action.
|
||||
action := constant.TaskActionGenerate
|
||||
info.Action = action
|
||||
@@ -109,13 +141,13 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
|
||||
}
|
||||
|
||||
// BuildRequestURL constructs the upstream URL.
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
path := lo.Ternary(info.Action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video")
|
||||
return fmt.Sprintf("%s%s", a.baseURL, path), nil
|
||||
}
|
||||
|
||||
// BuildRequestHeader sets required headers.
|
||||
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
|
||||
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
token, err := a.createJWTToken()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create JWT token: %w", err)
|
||||
@@ -129,7 +161,7 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
|
||||
}
|
||||
|
||||
// BuildRequestBody converts request into Kling specific format.
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) {
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
|
||||
v, exists := c.Get("task_request")
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("request not found in context")
|
||||
@@ -140,6 +172,9 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRel
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if body.Image == "" && body.ImageTail == "" {
|
||||
c.Set("action", constant.TaskActionTextGenerate)
|
||||
}
|
||||
data, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -148,7 +183,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRel
|
||||
}
|
||||
|
||||
// DoRequest delegates to common helper.
|
||||
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
if action := c.GetString("action"); action != "" {
|
||||
info.Action = action
|
||||
}
|
||||
@@ -156,7 +191,7 @@ func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo,
|
||||
}
|
||||
|
||||
// DoResponse handles upstream response, returns taskID etc.
|
||||
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
||||
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
@@ -222,14 +257,19 @@ func (a *TaskAdaptor) GetChannelName() string {
|
||||
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) {
|
||||
r := requestPayload{
|
||||
Prompt: req.Prompt,
|
||||
Image: req.Image,
|
||||
Mode: defaultString(req.Mode, "std"),
|
||||
Duration: fmt.Sprintf("%d", defaultInt(req.Duration, 5)),
|
||||
AspectRatio: a.getAspectRatio(req.Size),
|
||||
ModelName: req.Model,
|
||||
Model: req.Model, // Keep consistent with model_name, double writing improves compatibility
|
||||
CfgScale: 0.5,
|
||||
Prompt: req.Prompt,
|
||||
Image: req.Image,
|
||||
Mode: defaultString(req.Mode, "std"),
|
||||
Duration: fmt.Sprintf("%d", defaultInt(req.Duration, 5)),
|
||||
AspectRatio: a.getAspectRatio(req.Size),
|
||||
ModelName: req.Model,
|
||||
Model: req.Model, // Keep consistent with model_name, double writing improves compatibility
|
||||
CfgScale: 0.5,
|
||||
StaticMask: "",
|
||||
DynamicMasks: []DynamicMask{},
|
||||
CameraControl: nil,
|
||||
CallbackUrl: "",
|
||||
ExternalTaskId: "",
|
||||
}
|
||||
if r.ModelName == "" {
|
||||
r.ModelName = "kling-v1"
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
@@ -16,6 +15,8 @@ import (
|
||||
"one-api/service"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type TaskAdaptor struct {
|
||||
@@ -26,11 +27,11 @@ func (a *TaskAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) {
|
||||
return nil, fmt.Errorf("not implement") // todo implement this method if needed
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
||||
func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
a.ChannelType = info.ChannelType
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
||||
action := strings.ToUpper(c.Param("action"))
|
||||
|
||||
var sunoRequest *dto.SunoSubmitReq
|
||||
@@ -58,20 +59,20 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
baseURL := info.ChannelBaseUrl
|
||||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, "/suno/submit/"+info.Action)
|
||||
return fullRequestURL, nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
|
||||
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) {
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
|
||||
sunoRequest, ok := c.Get("task_request")
|
||||
if !ok {
|
||||
err := common.UnmarshalBodyReusable(c, &sunoRequest)
|
||||
@@ -86,11 +87,11 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRel
|
||||
return bytes.NewReader(data), nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoTaskApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
||||
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
|
||||
@@ -84,12 +84,12 @@ type TaskAdaptor struct {
|
||||
baseURL string
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
||||
func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
a.ChannelType = info.ChannelType
|
||||
a.baseURL = info.ChannelBaseUrl
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) *dto.TaskError {
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError {
|
||||
var req SubmitReq
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
return service.TaskErrorWrapper(err, "invalid_request_body", http.StatusBadRequest)
|
||||
@@ -109,7 +109,7 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.TaskRelayInfo) (io.Reader, error) {
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo) (io.Reader, error) {
|
||||
v, exists := c.Get("task_request")
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("request not found in context")
|
||||
@@ -132,7 +132,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.TaskRelayI
|
||||
return bytes.NewReader(data), nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
var path string
|
||||
switch info.Action {
|
||||
case constant.TaskActionGenerate:
|
||||
@@ -143,21 +143,21 @@ func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string,
|
||||
return fmt.Sprintf("%s/ent/v2%s", a.baseURL, path), nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
|
||||
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Token "+info.ApiKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
if action := c.GetString("action"); action != "" {
|
||||
info.Action = action
|
||||
}
|
||||
return channel.DoTaskApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
||||
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
|
||||
@@ -66,8 +66,8 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
geminiAdaptor := gemini.Adaptor{}
|
||||
return geminiAdaptor.ConvertImageRequest(c, info, request)
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
@@ -181,8 +181,62 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
if a.RequestMode == RequestModeGemini && strings.HasPrefix(info.UpstreamModelName, "imagen") {
|
||||
prompt := ""
|
||||
for _, m := range request.Messages {
|
||||
if m.Role == "user" {
|
||||
prompt = m.StringContent()
|
||||
if prompt != "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if prompt == "" {
|
||||
if p, ok := request.Prompt.(string); ok {
|
||||
prompt = p
|
||||
}
|
||||
}
|
||||
if prompt == "" {
|
||||
return nil, errors.New("prompt is required for image generation")
|
||||
}
|
||||
|
||||
imgReq := dto.ImageRequest{
|
||||
Model: request.Model,
|
||||
Prompt: prompt,
|
||||
N: 1,
|
||||
Size: "1024x1024",
|
||||
}
|
||||
if request.N > 0 {
|
||||
imgReq.N = uint(request.N)
|
||||
}
|
||||
if request.Size != "" {
|
||||
imgReq.Size = request.Size
|
||||
}
|
||||
if len(request.ExtraBody) > 0 {
|
||||
var extra map[string]any
|
||||
if err := json.Unmarshal(request.ExtraBody, &extra); err == nil {
|
||||
if n, ok := extra["n"].(float64); ok && n > 0 {
|
||||
imgReq.N = uint(n)
|
||||
}
|
||||
if size, ok := extra["size"].(string); ok {
|
||||
imgReq.Size = size
|
||||
}
|
||||
// accept aspectRatio in extra body (top-level or under parameters)
|
||||
if ar, ok := extra["aspectRatio"].(string); ok && ar != "" {
|
||||
imgReq.Size = ar
|
||||
}
|
||||
if params, ok := extra["parameters"].(map[string]any); ok {
|
||||
if ar, ok := params["aspectRatio"].(string); ok && ar != "" {
|
||||
imgReq.Size = ar
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
c.Set("request_model", request.Model)
|
||||
return a.ConvertImageRequest(c, info, imgReq)
|
||||
}
|
||||
if a.RequestMode == RequestModeClaude {
|
||||
claudeReq, err := claude.RequestOpenAI2ClaudeMessage(*request)
|
||||
claudeReq, err := claude.RequestOpenAI2ClaudeMessage(c, *request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -191,7 +245,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
info.UpstreamModelName = claudeReq.Model
|
||||
return vertexClaudeReq, nil
|
||||
} else if a.RequestMode == RequestModeGemini {
|
||||
geminiRequest, err := gemini.CovertGemini2OpenAI(*request, info)
|
||||
geminiRequest, err := gemini.CovertGemini2OpenAI(c, *request, info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -225,31 +279,31 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
if info.IsStream {
|
||||
switch a.RequestMode {
|
||||
case RequestModeClaude:
|
||||
err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
|
||||
return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
|
||||
case RequestModeGemini:
|
||||
if info.RelayMode == constant.RelayModeGemini {
|
||||
usage, err = gemini.GeminiTextGenerationStreamHandler(c, info, resp)
|
||||
return gemini.GeminiTextGenerationStreamHandler(c, info, resp)
|
||||
} else {
|
||||
usage, err = gemini.GeminiChatStreamHandler(c, info, resp)
|
||||
return gemini.GeminiChatStreamHandler(c, info, resp)
|
||||
}
|
||||
case RequestModeLlama:
|
||||
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||
return openai.OaiStreamHandler(c, info, resp)
|
||||
}
|
||||
} else {
|
||||
switch a.RequestMode {
|
||||
case RequestModeClaude:
|
||||
err, usage = claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
|
||||
return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
|
||||
case RequestModeGemini:
|
||||
if info.RelayMode == constant.RelayModeGemini {
|
||||
usage, err = gemini.GeminiTextGenerationHandler(c, info, resp)
|
||||
return gemini.GeminiTextGenerationHandler(c, info, resp)
|
||||
} else {
|
||||
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
|
||||
return gemini.GeminiImageHandler(c, info, resp)
|
||||
}
|
||||
usage, err = gemini.GeminiChatHandler(c, info, resp)
|
||||
return gemini.GeminiChatHandler(c, info, resp)
|
||||
}
|
||||
case RequestModeLlama:
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
return openai.OpenaiHandler(c, info, resp)
|
||||
}
|
||||
}
|
||||
return
|
||||
|
||||
@@ -2,6 +2,7 @@ package volcengine
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -214,6 +215,12 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
// 适配 方舟deepseek混合模型 的 thinking 后缀
|
||||
if strings.HasSuffix(info.UpstreamModelName, "-thinking") && strings.HasPrefix(info.UpstreamModelName, "deepseek") {
|
||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
|
||||
request.Model = info.UpstreamModelName
|
||||
request.THINKING = json.RawMessage(`{"type": "enabled"}`)
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/claude"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
@@ -23,10 +24,8 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
@@ -43,12 +42,16 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
baseUrl := fmt.Sprintf("%s/api/paas/v4", info.ChannelBaseUrl)
|
||||
switch info.RelayMode {
|
||||
case relayconstant.RelayModeEmbeddings:
|
||||
return fmt.Sprintf("%s/embeddings", baseUrl), nil
|
||||
switch info.RelayFormat {
|
||||
case types.RelayFormatClaude:
|
||||
return fmt.Sprintf("%s/api/anthropic/v1/messages", info.ChannelBaseUrl), nil
|
||||
default:
|
||||
return fmt.Sprintf("%s/chat/completions", baseUrl), nil
|
||||
switch info.RelayMode {
|
||||
case relayconstant.RelayModeEmbeddings:
|
||||
return fmt.Sprintf("%s/api/paas/v4/embeddings", info.ChannelBaseUrl), nil
|
||||
default:
|
||||
return fmt.Sprintf("%s/api/paas/v4/chat/completions", info.ChannelBaseUrl), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -86,12 +89,17 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.IsStream {
|
||||
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||
} else {
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
switch info.RelayFormat {
|
||||
case types.RelayFormatClaude:
|
||||
if info.IsStream {
|
||||
return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
|
||||
} else {
|
||||
return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
|
||||
}
|
||||
default:
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.DoResponse(c, resp, info)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
|
||||
@@ -21,13 +21,18 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
|
||||
info.InitChannelMeta(c)
|
||||
|
||||
textRequest, ok := info.Request.(*dto.ClaudeRequest)
|
||||
claudeReq, ok := info.Request.(*dto.ClaudeRequest)
|
||||
|
||||
if !ok {
|
||||
common.FatalLog(fmt.Sprintf("invalid request type, expected *dto.ClaudeRequest, got %T", info.Request))
|
||||
return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected *dto.ClaudeRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
err := helper.ModelMappedHelper(c, info, textRequest)
|
||||
request, err := common.DeepCopy(claudeReq)
|
||||
if err != nil {
|
||||
return types.NewError(fmt.Errorf("failed to copy request to ClaudeRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, info, request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
@@ -38,30 +43,30 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
}
|
||||
adaptor.Init(info)
|
||||
|
||||
if textRequest.MaxTokens == 0 {
|
||||
textRequest.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(textRequest.Model))
|
||||
if request.MaxTokens == 0 {
|
||||
request.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(request.Model))
|
||||
}
|
||||
|
||||
if model_setting.GetClaudeSettings().ThinkingAdapterEnabled &&
|
||||
strings.HasSuffix(textRequest.Model, "-thinking") {
|
||||
if textRequest.Thinking == nil {
|
||||
strings.HasSuffix(request.Model, "-thinking") {
|
||||
if request.Thinking == nil {
|
||||
// 因为BudgetTokens 必须大于1024
|
||||
if textRequest.MaxTokens < 1280 {
|
||||
textRequest.MaxTokens = 1280
|
||||
if request.MaxTokens < 1280 {
|
||||
request.MaxTokens = 1280
|
||||
}
|
||||
|
||||
// BudgetTokens 为 max_tokens 的 80%
|
||||
textRequest.Thinking = &dto.Thinking{
|
||||
request.Thinking = &dto.Thinking{
|
||||
Type: "enabled",
|
||||
BudgetTokens: common.GetPointer[int](int(float64(textRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage)),
|
||||
BudgetTokens: common.GetPointer[int](int(float64(request.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage)),
|
||||
}
|
||||
// TODO: 临时处理
|
||||
// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
|
||||
textRequest.TopP = 0
|
||||
textRequest.Temperature = common.GetPointer[float64](1.0)
|
||||
request.TopP = 0
|
||||
request.Temperature = common.GetPointer[float64](1.0)
|
||||
}
|
||||
textRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking")
|
||||
info.UpstreamModelName = textRequest.Model
|
||||
request.Model = strings.TrimSuffix(request.Model, "-thinking")
|
||||
info.UpstreamModelName = request.Model
|
||||
}
|
||||
|
||||
var requestBody io.Reader
|
||||
@@ -72,7 +77,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
}
|
||||
requestBody = bytes.NewBuffer(body)
|
||||
} else {
|
||||
convertedRequest, err := adaptor.ConvertClaudeRequest(c, info, textRequest)
|
||||
convertedRequest, err := adaptor.ConvertClaudeRequest(c, info, request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
@@ -83,12 +88,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
|
||||
// apply param override
|
||||
if len(info.ParamOverride) > 0 {
|
||||
reqMap := make(map[string]interface{})
|
||||
_ = common.Unmarshal(jsonData, &reqMap)
|
||||
for key, value := range info.ParamOverride {
|
||||
reqMap[key] = value
|
||||
}
|
||||
jsonData, err = common.Marshal(reqMap)
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
@@ -0,0 +1,435 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ConditionOperation struct {
|
||||
Path string `json:"path"` // JSON路径
|
||||
Mode string `json:"mode"` // full, prefix, suffix, contains, gt, gte, lt, lte
|
||||
Value interface{} `json:"value"` // 匹配的值
|
||||
Invert bool `json:"invert"` // 反选功能,true表示取反结果
|
||||
PassMissingKey bool `json:"pass_missing_key"` // 未获取到json key时的行为
|
||||
}
|
||||
|
||||
type ParamOperation struct {
|
||||
Path string `json:"path"`
|
||||
Mode string `json:"mode"` // delete, set, move, prepend, append
|
||||
Value interface{} `json:"value"`
|
||||
KeepOrigin bool `json:"keep_origin"`
|
||||
From string `json:"from,omitempty"`
|
||||
To string `json:"to,omitempty"`
|
||||
Conditions []ConditionOperation `json:"conditions,omitempty"` // 条件列表
|
||||
Logic string `json:"logic,omitempty"` // AND, OR (默认OR)
|
||||
}
|
||||
|
||||
func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}) ([]byte, error) {
|
||||
if len(paramOverride) == 0 {
|
||||
return jsonData, nil
|
||||
}
|
||||
|
||||
// 尝试断言为操作格式
|
||||
if operations, ok := tryParseOperations(paramOverride); ok {
|
||||
// 使用新方法
|
||||
result, err := applyOperations(string(jsonData), operations)
|
||||
return []byte(result), err
|
||||
}
|
||||
|
||||
// 直接使用旧方法
|
||||
return applyOperationsLegacy(jsonData, paramOverride)
|
||||
}
|
||||
|
||||
func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation, bool) {
|
||||
// 检查是否包含 "operations" 字段
|
||||
if opsValue, exists := paramOverride["operations"]; exists {
|
||||
if opsSlice, ok := opsValue.([]interface{}); ok {
|
||||
var operations []ParamOperation
|
||||
for _, op := range opsSlice {
|
||||
if opMap, ok := op.(map[string]interface{}); ok {
|
||||
operation := ParamOperation{}
|
||||
|
||||
// 断言必要字段
|
||||
if path, ok := opMap["path"].(string); ok {
|
||||
operation.Path = path
|
||||
}
|
||||
if mode, ok := opMap["mode"].(string); ok {
|
||||
operation.Mode = mode
|
||||
} else {
|
||||
return nil, false // mode 是必需的
|
||||
}
|
||||
|
||||
// 可选字段
|
||||
if value, exists := opMap["value"]; exists {
|
||||
operation.Value = value
|
||||
}
|
||||
if keepOrigin, ok := opMap["keep_origin"].(bool); ok {
|
||||
operation.KeepOrigin = keepOrigin
|
||||
}
|
||||
if from, ok := opMap["from"].(string); ok {
|
||||
operation.From = from
|
||||
}
|
||||
if to, ok := opMap["to"].(string); ok {
|
||||
operation.To = to
|
||||
}
|
||||
if logic, ok := opMap["logic"].(string); ok {
|
||||
operation.Logic = logic
|
||||
} else {
|
||||
operation.Logic = "OR" // 默认为OR
|
||||
}
|
||||
|
||||
// 解析条件
|
||||
if conditions, exists := opMap["conditions"]; exists {
|
||||
if condSlice, ok := conditions.([]interface{}); ok {
|
||||
for _, cond := range condSlice {
|
||||
if condMap, ok := cond.(map[string]interface{}); ok {
|
||||
condition := ConditionOperation{}
|
||||
if path, ok := condMap["path"].(string); ok {
|
||||
condition.Path = path
|
||||
}
|
||||
if mode, ok := condMap["mode"].(string); ok {
|
||||
condition.Mode = mode
|
||||
}
|
||||
if value, ok := condMap["value"]; ok {
|
||||
condition.Value = value
|
||||
}
|
||||
if invert, ok := condMap["invert"].(bool); ok {
|
||||
condition.Invert = invert
|
||||
}
|
||||
if passMissingKey, ok := condMap["pass_missing_key"].(bool); ok {
|
||||
condition.PassMissingKey = passMissingKey
|
||||
}
|
||||
operation.Conditions = append(operation.Conditions, condition)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
operations = append(operations, operation)
|
||||
} else {
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
return operations, true
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func checkConditions(jsonStr string, conditions []ConditionOperation, logic string) (bool, error) {
|
||||
if len(conditions) == 0 {
|
||||
return true, nil // 没有条件,直接通过
|
||||
}
|
||||
results := make([]bool, len(conditions))
|
||||
for i, condition := range conditions {
|
||||
result, err := checkSingleCondition(jsonStr, condition)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
results[i] = result
|
||||
}
|
||||
|
||||
if strings.ToUpper(logic) == "AND" {
|
||||
for _, result := range results {
|
||||
if !result {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
} else {
|
||||
for _, result := range results {
|
||||
if result {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
func checkSingleCondition(jsonStr string, condition ConditionOperation) (bool, error) {
|
||||
// 处理负数索引
|
||||
path := processNegativeIndex(jsonStr, condition.Path)
|
||||
value := gjson.Get(jsonStr, path)
|
||||
if !value.Exists() {
|
||||
if condition.PassMissingKey {
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// 利用gjson的类型解析
|
||||
targetBytes, err := json.Marshal(condition.Value)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to marshal condition value: %v", err)
|
||||
}
|
||||
targetValue := gjson.ParseBytes(targetBytes)
|
||||
|
||||
result, err := compareGjsonValues(value, targetValue, strings.ToLower(condition.Mode))
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("comparison failed for path %s: %v", condition.Path, err)
|
||||
}
|
||||
|
||||
if condition.Invert {
|
||||
result = !result
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func processNegativeIndex(jsonStr string, path string) string {
|
||||
re := regexp.MustCompile(`\.(-\d+)`)
|
||||
matches := re.FindAllStringSubmatch(path, -1)
|
||||
|
||||
if len(matches) == 0 {
|
||||
return path
|
||||
}
|
||||
|
||||
result := path
|
||||
for _, match := range matches {
|
||||
negIndex := match[1]
|
||||
index, _ := strconv.Atoi(negIndex)
|
||||
|
||||
arrayPath := strings.Split(path, negIndex)[0]
|
||||
if strings.HasSuffix(arrayPath, ".") {
|
||||
arrayPath = arrayPath[:len(arrayPath)-1]
|
||||
}
|
||||
|
||||
array := gjson.Get(jsonStr, arrayPath)
|
||||
if array.IsArray() {
|
||||
length := len(array.Array())
|
||||
actualIndex := length + index
|
||||
if actualIndex >= 0 && actualIndex < length {
|
||||
result = strings.Replace(result, match[0], "."+strconv.Itoa(actualIndex), 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// compareGjsonValues 直接比较两个gjson.Result,支持所有比较模式
|
||||
func compareGjsonValues(jsonValue, targetValue gjson.Result, mode string) (bool, error) {
|
||||
switch mode {
|
||||
case "full":
|
||||
return compareEqual(jsonValue, targetValue)
|
||||
case "prefix":
|
||||
return strings.HasPrefix(jsonValue.String(), targetValue.String()), nil
|
||||
case "suffix":
|
||||
return strings.HasSuffix(jsonValue.String(), targetValue.String()), nil
|
||||
case "contains":
|
||||
return strings.Contains(jsonValue.String(), targetValue.String()), nil
|
||||
case "gt":
|
||||
return compareNumeric(jsonValue, targetValue, "gt")
|
||||
case "gte":
|
||||
return compareNumeric(jsonValue, targetValue, "gte")
|
||||
case "lt":
|
||||
return compareNumeric(jsonValue, targetValue, "lt")
|
||||
case "lte":
|
||||
return compareNumeric(jsonValue, targetValue, "lte")
|
||||
default:
|
||||
return false, fmt.Errorf("unsupported comparison mode: %s", mode)
|
||||
}
|
||||
}
|
||||
|
||||
func compareEqual(jsonValue, targetValue gjson.Result) (bool, error) {
|
||||
// 对布尔值特殊处理
|
||||
if (jsonValue.Type == gjson.True || jsonValue.Type == gjson.False) &&
|
||||
(targetValue.Type == gjson.True || targetValue.Type == gjson.False) {
|
||||
return jsonValue.Bool() == targetValue.Bool(), nil
|
||||
}
|
||||
|
||||
// 如果类型不同,报错
|
||||
if jsonValue.Type != targetValue.Type {
|
||||
return false, fmt.Errorf("compare for different types, got %v and %v", jsonValue.Type, targetValue.Type)
|
||||
}
|
||||
|
||||
switch jsonValue.Type {
|
||||
case gjson.True, gjson.False:
|
||||
return jsonValue.Bool() == targetValue.Bool(), nil
|
||||
case gjson.Number:
|
||||
return jsonValue.Num == targetValue.Num, nil
|
||||
case gjson.String:
|
||||
return jsonValue.String() == targetValue.String(), nil
|
||||
default:
|
||||
return jsonValue.String() == targetValue.String(), nil
|
||||
}
|
||||
}
|
||||
|
||||
func compareNumeric(jsonValue, targetValue gjson.Result, operator string) (bool, error) {
|
||||
// 只有数字类型才支持数值比较
|
||||
if jsonValue.Type != gjson.Number || targetValue.Type != gjson.Number {
|
||||
return false, fmt.Errorf("numeric comparison requires both values to be numbers, got %v and %v", jsonValue.Type, targetValue.Type)
|
||||
}
|
||||
|
||||
jsonNum := jsonValue.Num
|
||||
targetNum := targetValue.Num
|
||||
|
||||
switch operator {
|
||||
case "gt":
|
||||
return jsonNum > targetNum, nil
|
||||
case "gte":
|
||||
return jsonNum >= targetNum, nil
|
||||
case "lt":
|
||||
return jsonNum < targetNum, nil
|
||||
case "lte":
|
||||
return jsonNum <= targetNum, nil
|
||||
default:
|
||||
return false, fmt.Errorf("unsupported numeric operator: %s", operator)
|
||||
}
|
||||
}
|
||||
|
||||
// applyOperationsLegacy 原参数覆盖方法
|
||||
func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}) ([]byte, error) {
|
||||
reqMap := make(map[string]interface{})
|
||||
err := json.Unmarshal(jsonData, &reqMap)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for key, value := range paramOverride {
|
||||
reqMap[key] = value
|
||||
}
|
||||
|
||||
return json.Marshal(reqMap)
|
||||
}
|
||||
|
||||
func applyOperations(jsonStr string, operations []ParamOperation) (string, error) {
|
||||
result := jsonStr
|
||||
for _, op := range operations {
|
||||
// 检查条件是否满足
|
||||
ok, err := checkConditions(result, op.Conditions, op.Logic)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !ok {
|
||||
continue // 条件不满足,跳过当前操作
|
||||
}
|
||||
// 处理路径中的负数索引
|
||||
opPath := processNegativeIndex(result, op.Path)
|
||||
opFrom := processNegativeIndex(result, op.From)
|
||||
opTo := processNegativeIndex(result, op.To)
|
||||
|
||||
switch op.Mode {
|
||||
case "delete":
|
||||
result, err = sjson.Delete(result, opPath)
|
||||
case "set":
|
||||
if op.KeepOrigin && gjson.Get(result, opPath).Exists() {
|
||||
continue
|
||||
}
|
||||
result, err = sjson.Set(result, opPath, op.Value)
|
||||
case "move":
|
||||
result, err = moveValue(result, opFrom, opTo)
|
||||
case "prepend":
|
||||
result, err = modifyValue(result, opPath, op.Value, op.KeepOrigin, true)
|
||||
case "append":
|
||||
result, err = modifyValue(result, opPath, op.Value, op.KeepOrigin, false)
|
||||
default:
|
||||
return "", fmt.Errorf("unknown operation: %s", op.Mode)
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("operation %s failed: %v", op.Mode, err)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func moveValue(jsonStr, fromPath, toPath string) (string, error) {
|
||||
sourceValue := gjson.Get(jsonStr, fromPath)
|
||||
if !sourceValue.Exists() {
|
||||
return jsonStr, fmt.Errorf("source path does not exist: %s", fromPath)
|
||||
}
|
||||
result, err := sjson.Set(jsonStr, toPath, sourceValue.Value())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return sjson.Delete(result, fromPath)
|
||||
}
|
||||
|
||||
func modifyValue(jsonStr, path string, value interface{}, keepOrigin, isPrepend bool) (string, error) {
|
||||
current := gjson.Get(jsonStr, path)
|
||||
switch {
|
||||
case current.IsArray():
|
||||
return modifyArray(jsonStr, path, value, isPrepend)
|
||||
case current.Type == gjson.String:
|
||||
return modifyString(jsonStr, path, value, isPrepend)
|
||||
case current.Type == gjson.JSON:
|
||||
return mergeObjects(jsonStr, path, value, keepOrigin)
|
||||
}
|
||||
return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
|
||||
}
|
||||
|
||||
func modifyArray(jsonStr, path string, value interface{}, isPrepend bool) (string, error) {
|
||||
current := gjson.Get(jsonStr, path)
|
||||
var newArray []interface{}
|
||||
// 添加新值
|
||||
addValue := func() {
|
||||
if arr, ok := value.([]interface{}); ok {
|
||||
newArray = append(newArray, arr...)
|
||||
} else {
|
||||
newArray = append(newArray, value)
|
||||
}
|
||||
}
|
||||
// 添加原值
|
||||
addOriginal := func() {
|
||||
current.ForEach(func(_, val gjson.Result) bool {
|
||||
newArray = append(newArray, val.Value())
|
||||
return true
|
||||
})
|
||||
}
|
||||
if isPrepend {
|
||||
addValue()
|
||||
addOriginal()
|
||||
} else {
|
||||
addOriginal()
|
||||
addValue()
|
||||
}
|
||||
return sjson.Set(jsonStr, path, newArray)
|
||||
}
|
||||
|
||||
func modifyString(jsonStr, path string, value interface{}, isPrepend bool) (string, error) {
|
||||
current := gjson.Get(jsonStr, path)
|
||||
valueStr := fmt.Sprintf("%v", value)
|
||||
var newStr string
|
||||
if isPrepend {
|
||||
newStr = valueStr + current.String()
|
||||
} else {
|
||||
newStr = current.String() + valueStr
|
||||
}
|
||||
return sjson.Set(jsonStr, path, newStr)
|
||||
}
|
||||
|
||||
func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (string, error) {
|
||||
current := gjson.Get(jsonStr, path)
|
||||
var currentMap, newMap map[string]interface{}
|
||||
|
||||
// 解析当前值
|
||||
if err := json.Unmarshal([]byte(current.Raw), ¤tMap); err != nil {
|
||||
return "", err
|
||||
}
|
||||
// 解析新值
|
||||
switch v := value.(type) {
|
||||
case map[string]interface{}:
|
||||
newMap = v
|
||||
default:
|
||||
jsonBytes, _ := json.Marshal(v)
|
||||
if err := json.Unmarshal(jsonBytes, &newMap); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
// 合并
|
||||
result := make(map[string]interface{})
|
||||
for k, v := range currentMap {
|
||||
result[k] = v
|
||||
}
|
||||
for k, v := range newMap {
|
||||
if !keepOrigin || result[k] == nil {
|
||||
result[k] = v
|
||||
}
|
||||
}
|
||||
return sjson.Set(jsonStr, path, result)
|
||||
}
|
||||
@@ -63,6 +63,7 @@ type ChannelMeta struct {
|
||||
Organization string
|
||||
ChannelCreateTime int64
|
||||
ParamOverride map[string]interface{}
|
||||
HeadersOverride map[string]interface{}
|
||||
ChannelSetting dto.ChannelSettings
|
||||
ChannelOtherSettings dto.ChannelOtherSettings
|
||||
UpstreamModelName string
|
||||
@@ -115,11 +116,13 @@ type RelayInfo struct {
|
||||
*RerankerInfo
|
||||
*ResponsesUsageInfo
|
||||
*ChannelMeta
|
||||
*TaskRelayInfo
|
||||
}
|
||||
|
||||
func (info *RelayInfo) InitChannelMeta(c *gin.Context) {
|
||||
channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
|
||||
paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
|
||||
headerOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelHeaderOverride)
|
||||
apiType, _ := common.ChannelType2APIType(channelType)
|
||||
channelMeta := &ChannelMeta{
|
||||
ChannelType: channelType,
|
||||
@@ -133,11 +136,19 @@ func (info *RelayInfo) InitChannelMeta(c *gin.Context) {
|
||||
Organization: c.GetString("channel_organization"),
|
||||
ChannelCreateTime: c.GetInt64("channel_create_time"),
|
||||
ParamOverride: paramOverride,
|
||||
HeadersOverride: headerOverride,
|
||||
UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
|
||||
IsModelMapped: false,
|
||||
SupportStreamOptions: false,
|
||||
}
|
||||
|
||||
if channelType == constant.ChannelTypeAzure {
|
||||
channelMeta.ApiVersion = GetAPIVersion(c)
|
||||
}
|
||||
if channelType == constant.ChannelTypeVertexAi {
|
||||
channelMeta.ApiVersion = c.GetString("region")
|
||||
}
|
||||
|
||||
channelSetting, ok := common.GetContextKeyType[dto.ChannelSettings](c, constant.ContextKeyChannelSetting)
|
||||
if ok {
|
||||
channelMeta.ChannelSetting = channelSetting
|
||||
@@ -151,7 +162,14 @@ func (info *RelayInfo) InitChannelMeta(c *gin.Context) {
|
||||
if streamSupportedChannels[channelMeta.ChannelType] {
|
||||
channelMeta.SupportStreamOptions = true
|
||||
}
|
||||
|
||||
info.ChannelMeta = channelMeta
|
||||
|
||||
// reset some fields based on channel meta
|
||||
// 重置某些字段,例如模型名称等
|
||||
if info.Request != nil {
|
||||
info.Request.SetModelName(info.OriginModelName)
|
||||
}
|
||||
}
|
||||
|
||||
func (info *RelayInfo) ToString() string {
|
||||
@@ -296,7 +314,7 @@ func GenRelayInfoResponses(c *gin.Context, request *dto.OpenAIResponsesRequest)
|
||||
BuiltInTools: make(map[string]*BuildInToolInfo),
|
||||
}
|
||||
if len(request.Tools) > 0 {
|
||||
for _, tool := range request.Tools {
|
||||
for _, tool := range request.GetToolsMap() {
|
||||
toolType := common.Interface2String(tool["type"])
|
||||
info.ResponsesUsageInfo.BuiltInTools[toolType] = &BuildInToolInfo{
|
||||
ToolName: toolType,
|
||||
@@ -383,6 +401,10 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo {
|
||||
},
|
||||
}
|
||||
|
||||
if info.RelayMode == relayconstant.RelayModeUnknown {
|
||||
info.RelayMode = c.GetInt("relay_mode")
|
||||
}
|
||||
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/pg") {
|
||||
info.IsPlayground = true
|
||||
info.RequestURLPath = strings.TrimPrefix(info.RequestURLPath, "/pg")
|
||||
@@ -448,24 +470,12 @@ func (info *RelayInfo) HasSendResponse() bool {
|
||||
}
|
||||
|
||||
type TaskRelayInfo struct {
|
||||
*RelayInfo
|
||||
Action string
|
||||
OriginTaskID string
|
||||
|
||||
ConsumeQuota bool
|
||||
}
|
||||
|
||||
func GenTaskRelayInfo(c *gin.Context) (*TaskRelayInfo, error) {
|
||||
relayInfo, err := GenRelayInfo(c, types.RelayFormatTask, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
info := &TaskRelayInfo{
|
||||
RelayInfo: relayInfo,
|
||||
}
|
||||
return info, nil
|
||||
}
|
||||
|
||||
type TaskSubmitReq struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Model string `json:"model,omitempty"`
|
||||
|
||||
@@ -2,12 +2,10 @@ package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
_ "image/gif"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"one-api/constant"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
|
||||
|
||||
@@ -25,38 +25,40 @@ import (
|
||||
)
|
||||
|
||||
func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
|
||||
|
||||
info.InitChannelMeta(c)
|
||||
|
||||
textRequest, ok := info.Request.(*dto.GeneralOpenAIRequest)
|
||||
|
||||
textReq, ok := info.Request.(*dto.GeneralOpenAIRequest)
|
||||
if !ok {
|
||||
//return types.NewErrorWithStatusCode(errors.New("invalid request type"), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
|
||||
common.FatalLog("invalid request type, expected dto.GeneralOpenAIRequest, got %T", info.Request)
|
||||
return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected dto.GeneralOpenAIRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
if textRequest.WebSearchOptions != nil {
|
||||
c.Set("chat_completion_web_search_context_size", textRequest.WebSearchOptions.SearchContextSize)
|
||||
request, err := common.DeepCopy(textReq)
|
||||
if err != nil {
|
||||
return types.NewError(fmt.Errorf("failed to copy request to GeneralOpenAIRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
err := helper.ModelMappedHelper(c, info, textRequest)
|
||||
if request.WebSearchOptions != nil {
|
||||
c.Set("chat_completion_web_search_context_size", request.WebSearchOptions.SearchContextSize)
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, info, request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
includeUsage := true
|
||||
// 判断用户是否需要返回使用情况
|
||||
if textRequest.StreamOptions != nil {
|
||||
includeUsage = textRequest.StreamOptions.IncludeUsage
|
||||
if request.StreamOptions != nil {
|
||||
includeUsage = request.StreamOptions.IncludeUsage
|
||||
}
|
||||
|
||||
// 如果不支持StreamOptions,将StreamOptions设置为nil
|
||||
if !info.SupportStreamOptions || !textRequest.Stream {
|
||||
textRequest.StreamOptions = nil
|
||||
if !info.SupportStreamOptions || !request.Stream {
|
||||
request.StreamOptions = nil
|
||||
} else {
|
||||
// 如果支持StreamOptions,且请求中没有设置StreamOptions,根据配置文件设置StreamOptions
|
||||
if constant.ForceStreamOption {
|
||||
textRequest.StreamOptions = &dto.StreamOptions{
|
||||
request.StreamOptions = &dto.StreamOptions{
|
||||
IncludeUsage: true,
|
||||
}
|
||||
}
|
||||
@@ -81,7 +83,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
|
||||
}
|
||||
requestBody = bytes.NewBuffer(body)
|
||||
} else {
|
||||
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, textRequest)
|
||||
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
@@ -128,17 +130,12 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
|
||||
|
||||
jsonData, err := common.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
return types.NewError(err, types.ErrorCodeJsonMarshalFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
// apply param override
|
||||
if len(info.ParamOverride) > 0 {
|
||||
reqMap := make(map[string]interface{})
|
||||
_ = common.Unmarshal(jsonData, &reqMap)
|
||||
for key, value := range info.ParamOverride {
|
||||
reqMap[key] = value
|
||||
}
|
||||
jsonData, err = common.Marshal(reqMap)
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
@@ -317,11 +314,22 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
} else {
|
||||
quotaCalculateDecimal = dModelPrice.Mul(dQuotaPerUnit).Mul(dGroupRatio)
|
||||
}
|
||||
var dGeminiImageOutputQuota decimal.Decimal
|
||||
var imageOutputPrice float64
|
||||
if strings.HasPrefix(modelName, "gemini-2.5-flash-image-preview") {
|
||||
imageOutputPrice = operation_setting.GetGeminiImageOutputPricePerMillionTokens(modelName)
|
||||
if imageOutputPrice > 0 {
|
||||
dImageOutputTokens := decimal.NewFromInt(int64(ctx.GetInt("gemini_image_tokens")))
|
||||
dGeminiImageOutputQuota = decimal.NewFromFloat(imageOutputPrice).Div(decimal.NewFromInt(1000000)).Mul(dImageOutputTokens).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
}
|
||||
}
|
||||
// 添加 responses tools call 调用的配额
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dWebSearchQuota)
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota)
|
||||
// 添加 audio input 独立计费
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota)
|
||||
// 添加 Gemini image output 计费
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dGeminiImageOutputQuota)
|
||||
|
||||
quota := int(quotaCalculateDecimal.Round(0).IntPart())
|
||||
totalTokens := promptTokens + completionTokens
|
||||
@@ -416,6 +424,10 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
other["audio_input_token_count"] = audioTokens
|
||||
other["audio_input_price"] = audioInputPrice
|
||||
}
|
||||
if !dGeminiImageOutputQuota.IsZero() {
|
||||
other["image_output_token_count"] = ctx.GetInt("gemini_image_tokens")
|
||||
other["image_output_price"] = imageOutputPrice
|
||||
}
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
|
||||
ChannelId: relayInfo.ChannelId,
|
||||
PromptTokens: promptTokens,
|
||||
@@ -84,6 +84,8 @@ func Path2RelayMode(path string) int {
|
||||
relayMode = RelayModeRealtime
|
||||
} else if strings.HasPrefix(path, "/v1beta/models") || strings.HasPrefix(path, "/v1/models") {
|
||||
relayMode = RelayModeGemini
|
||||
} else if strings.HasPrefix(path, "/mj") {
|
||||
relayMode = Path2RelayModeMidjourney(path)
|
||||
}
|
||||
return relayMode
|
||||
}
|
||||
|
||||
@@ -16,15 +16,19 @@ import (
|
||||
)
|
||||
|
||||
func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
|
||||
|
||||
info.InitChannelMeta(c)
|
||||
|
||||
embeddingRequest, ok := info.Request.(*dto.EmbeddingRequest)
|
||||
embeddingReq, ok := info.Request.(*dto.EmbeddingRequest)
|
||||
if !ok {
|
||||
common.FatalLog(fmt.Sprintf("invalid request type, expected *dto.EmbeddingRequest, got %T", info.Request))
|
||||
return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected *dto.EmbeddingRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
err := helper.ModelMappedHelper(c, info, embeddingRequest)
|
||||
request, err := common.DeepCopy(embeddingReq)
|
||||
if err != nil {
|
||||
return types.NewError(fmt.Errorf("failed to copy request to EmbeddingRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, info, request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
@@ -35,7 +39,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
||||
}
|
||||
adaptor.Init(info)
|
||||
|
||||
convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, info, *embeddingRequest)
|
||||
convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, info, *request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
@@ -53,13 +53,18 @@ func trimModelThinking(modelName string) string {
|
||||
func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
|
||||
info.InitChannelMeta(c)
|
||||
|
||||
request, ok := info.Request.(*dto.GeminiChatRequest)
|
||||
geminiReq, ok := info.Request.(*dto.GeminiChatRequest)
|
||||
if !ok {
|
||||
common.FatalLog(fmt.Sprintf("invalid request type, expected *dto.GeminiChatRequest, got %T", info.Request))
|
||||
return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected *dto.GeminiChatRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
request, err := common.DeepCopy(geminiReq)
|
||||
if err != nil {
|
||||
return types.NewError(fmt.Errorf("failed to copy request to GeminiChatRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
// model mapped 模型映射
|
||||
err := helper.ModelMappedHelper(c, info, request)
|
||||
err = helper.ModelMappedHelper(c, info, request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
@@ -123,12 +128,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
|
||||
// apply param override
|
||||
if len(info.ParamOverride) > 0 {
|
||||
reqMap := make(map[string]interface{})
|
||||
_ = common.Unmarshal(jsonData, &reqMap)
|
||||
for key, value := range info.ParamOverride {
|
||||
reqMap[key] = value
|
||||
}
|
||||
jsonData, err = common.Marshal(reqMap)
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
@@ -175,7 +175,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI
|
||||
isBatch := strings.HasSuffix(c.Request.URL.Path, "batchEmbedContents")
|
||||
info.IsGeminiBatchEmbedding = isBatch
|
||||
|
||||
var req any
|
||||
var req dto.Request
|
||||
var err error
|
||||
var inputTexts []string
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package helper
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@@ -14,73 +13,68 @@ import (
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
func FlushWriter(c *gin.Context) error {
|
||||
if c.Writer == nil {
|
||||
return nil
|
||||
}
|
||||
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
return nil
|
||||
}
|
||||
return errors.New("streaming error: flusher not found")
|
||||
}
|
||||
|
||||
func SetEventStreamHeaders(c *gin.Context) {
|
||||
// 检查是否已经设置过头部
|
||||
if _, exists := c.Get("event_stream_headers_set"); exists {
|
||||
return
|
||||
}
|
||||
|
||||
// 设置标志,表示头部已经设置过
|
||||
c.Set("event_stream_headers_set", true)
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("Transfer-Encoding", "chunked")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
|
||||
// 设置标志,表示头部已经设置过
|
||||
c.Set("event_stream_headers_set", true)
|
||||
}
|
||||
|
||||
func ClaudeData(c *gin.Context, resp dto.ClaudeResponse) error {
|
||||
jsonData, err := json.Marshal(resp)
|
||||
jsonData, err := common.Marshal(resp)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
} else {
|
||||
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonData)})
|
||||
}
|
||||
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
} else {
|
||||
return errors.New("streaming error: flusher not found")
|
||||
}
|
||||
_ = FlushWriter(c)
|
||||
return nil
|
||||
}
|
||||
|
||||
func ClaudeChunkData(c *gin.Context, resp dto.ClaudeResponse, data string) {
|
||||
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
|
||||
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s\n", data)})
|
||||
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
_ = FlushWriter(c)
|
||||
}
|
||||
|
||||
func ResponseChunkData(c *gin.Context, resp dto.ResponsesStreamResponse, data string) {
|
||||
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
|
||||
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s", data)})
|
||||
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
_ = FlushWriter(c)
|
||||
}
|
||||
|
||||
func StringData(c *gin.Context, str string) error {
|
||||
//str = strings.TrimPrefix(str, "data: ")
|
||||
//str = strings.TrimSuffix(str, "\r")
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + str})
|
||||
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
} else {
|
||||
return errors.New("streaming error: flusher not found")
|
||||
}
|
||||
_ = FlushWriter(c)
|
||||
return nil
|
||||
}
|
||||
|
||||
func PingData(c *gin.Context) error {
|
||||
c.Writer.Write([]byte(": PING\n\n"))
|
||||
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
} else {
|
||||
return errors.New("streaming error: flusher not found")
|
||||
}
|
||||
_ = FlushWriter(c)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -109,7 +103,7 @@ func WssString(c *gin.Context, ws *websocket.Conn, str string) error {
|
||||
}
|
||||
|
||||
func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error {
|
||||
jsonData, err := json.Marshal(object)
|
||||
jsonData, err := common.Marshal(object)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error marshalling object: %w", err)
|
||||
}
|
||||
|
||||
@@ -4,15 +4,12 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"one-api/dto"
|
||||
common2 "one-api/logger"
|
||||
"one-api/relay/common"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/dto"
|
||||
"one-api/relay/common"
|
||||
)
|
||||
|
||||
func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request any) error {
|
||||
func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request dto.Request) error {
|
||||
// map model name
|
||||
modelMapping := c.GetString("model_mapping")
|
||||
if modelMapping != "" && modelMapping != "{}" {
|
||||
@@ -54,40 +51,7 @@ func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request any) erro
|
||||
}
|
||||
}
|
||||
if request != nil {
|
||||
switch info.RelayFormat {
|
||||
case types.RelayFormatGemini:
|
||||
// Gemini 模型映射
|
||||
case types.RelayFormatClaude:
|
||||
if claudeRequest, ok := request.(*dto.ClaudeRequest); ok {
|
||||
claudeRequest.Model = info.UpstreamModelName
|
||||
}
|
||||
case types.RelayFormatOpenAIResponses:
|
||||
if openAIResponsesRequest, ok := request.(*dto.OpenAIResponsesRequest); ok {
|
||||
openAIResponsesRequest.Model = info.UpstreamModelName
|
||||
}
|
||||
case types.RelayFormatOpenAIAudio:
|
||||
if openAIAudioRequest, ok := request.(*dto.AudioRequest); ok {
|
||||
openAIAudioRequest.Model = info.UpstreamModelName
|
||||
}
|
||||
case types.RelayFormatOpenAIImage:
|
||||
if imageRequest, ok := request.(*dto.ImageRequest); ok {
|
||||
imageRequest.Model = info.UpstreamModelName
|
||||
}
|
||||
case types.RelayFormatRerank:
|
||||
if rerankRequest, ok := request.(*dto.RerankRequest); ok {
|
||||
rerankRequest.Model = info.UpstreamModelName
|
||||
}
|
||||
case types.RelayFormatEmbedding:
|
||||
if embeddingRequest, ok := request.(*dto.EmbeddingRequest); ok {
|
||||
embeddingRequest.Model = info.UpstreamModelName
|
||||
}
|
||||
default:
|
||||
if openAIRequest, ok := request.(*dto.GeneralOpenAIRequest); ok {
|
||||
openAIRequest.Model = info.UpstreamModelName
|
||||
} else {
|
||||
common2.LogWarn(c, fmt.Sprintf("model mapped but request type %T not supported", request))
|
||||
}
|
||||
}
|
||||
request.SetModelName(info.UpstreamModelName)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||