Compare commits

...

30 Commits

Author SHA1 Message Date
1808837298@qq.com ac00e9bbb3 init openrouter adaptor 2025-02-27 00:01:21 +08:00
1808837298@qq.com 0646fa1892 fix: gemini&claude tool call format #795 #766 2025-02-26 23:56:10 +08:00
1808837298@qq.com 23de62ec0d fix: claude tool call format #795 #766 2025-02-26 23:40:16 +08:00
1808837298@qq.com c3b0e57ea4 feat: Add Jina reranking support for OpenAI adaptor 2025-02-26 21:46:06 +08:00
1808837298@qq.com ce03e77906 fix: Update Gemini safety settings to use 'OFF' as default 2025-02-26 19:20:17 +08:00
1808837298@qq.com 832f4b2b1a fix: Update Gemini safety settings category 2025-02-26 19:18:00 +08:00
1808837298@qq.com 7100c787d4 fix: Update Gemini safety settings default value 2025-02-26 19:01:45 +08:00
1808837298@qq.com 8a30d64a75 feat: Add Gemini version settings configuration support (close #568) 2025-02-26 18:19:09 +08:00
1808837298@qq.com 0a369cc193 feat: Add Gemini safety settings configuration support (close #703) 2025-02-26 16:54:43 +08:00
1808837298@qq.com 5ba44f5ad5 feat: Update Claude relay temperature setting 2025-02-25 22:01:05 +08:00
1808837298@qq.com d04d78a116 refactor: Enhance user context and quota management
- Add new context keys for user-related information
- Modify user cache and authentication middleware to populate context
- Refactor quota and notification services to use context-based user data
- Remove redundant database queries by leveraging context information
- Update various components to use new context-based user retrieval methods
2025-02-25 20:56:16 +08:00
1808837298@qq.com 8c2323d74d feat: redis poolsize 2025-02-25 19:39:29 +08:00
1808837298@qq.com 583678d9ff fix: Adjust Claude thinking mode request parameters 2025-02-25 16:52:45 +08:00
1808837298@qq.com fd38e59f78 docs: Update README 2025-02-25 16:31:42 +08:00
Calcium-Ion f5cbab77cf Merge pull request #788 from MartialBE/main
feat: Add Claude 3.7 Sonnet thinking mode support
2025-02-25 15:21:39 +08:00
1808837298@qq.com d4706d6b8e Merge branch 'main' into thinking
# Conflicts:
#	relay/channel/claude/dto.go
2025-02-25 15:21:22 +08:00
1808837298@qq.com 6c8016e5f8 feat: Add support for Claude thinking parameter in request 2025-02-25 14:37:03 +08:00
MartialBE 7160012fe2 feat: Add Claude 3.7 Sonnet thinking mode support 2025-02-25 14:10:43 +08:00
1808837298@qq.com c62276fcc4 feat: Add Claude 3.7 Sonnet model to AWS channel mapping 2025-02-25 02:55:23 +08:00
1808837298@qq.com 15a3b44689 feat: Add support for Claude 3.7 Sonnet model 2025-02-25 02:51:31 +08:00
1808837298@qq.com 8f3c7280cf feat: Support max_tokens parameter for Ollama channel #782 2025-02-24 17:35:49 +08:00
Calcium-Ion e5e73a33f0 Merge pull request #781 from zeyugao/main
feat: Pass extra_body in OpenAI request to the backend
2025-02-24 16:29:48 +08:00
Calcium-Ion 2d15f63eaa Merge pull request #783 from Calcium-Ion/rate-limit
feat: Add model request rate limiting functionality
2025-02-24 16:29:23 +08:00
1808837298@qq.com 6f3072895a feat: Add model rate limit settings in system configuration 2025-02-24 16:27:20 +08:00
1808837298@qq.com 1763145fea feat: Add model request rate limiting functionality 2025-02-24 16:20:55 +08:00
1808837298@qq.com 66831a1bde feat: Add support for different Dify bot types and request URLs 2025-02-24 14:18:30 +08:00
1808837298@qq.com fd44ac7c0c feat: Enhance token counting and content parsing for messages 2025-02-24 14:18:15 +08:00
Elsa f5bf67c636 Pass extra_body to the backend 2025-02-24 10:52:55 +08:00
1808837298@qq.com 7becf62a7a fix: Improve 429 error logging with detailed message 2025-02-23 21:26:31 +08:00
1808837298@qq.com 40c0333eaa fix typo 2025-02-23 17:27:33 +08:00
58 changed files with 1359 additions and 412 deletions
-4
View File
@@ -50,10 +50,6 @@
# CHANNEL_TEST_FREQUENCY=10
# 生成默认token
# GENERATE_DEFAULT_TOKEN=false
# Gemini 安全设置
# GEMINI_SAFETY_SETTING=BLOCK_NONE
# Gemini版本设置
# GEMINI_MODEL_MAP=gemini-1.0-pro:v1
# Cohere 安全设置
# COHERE_SAFETY_SETTING=NONE
# 是否统计图片token
+2 -1
View File
@@ -63,7 +63,8 @@
- Add suffix `-high` to set high reasoning effort (e.g., `o3-mini-high`)
- Add suffix `-medium` to set medium reasoning effort
- Add suffix `-low` to set low reasoning effort
17. 🔄 Thinking to content option `thinking_to_content` in `Channel->Edit->Channel Extra Settings`, default is `false`, when `true`, the `reasoning_conetnt` of the thinking content will be converted to `<think>` tags and concatenated to the content returned.
17. 🔄 Thinking to content option `thinking_to_content` in `Channel->Edit->Channel Extra Settings`, default is `false`, when `true`, the `reasoning_content` of the thinking content will be converted to `<think>` tags and concatenated to the content returned.
18. 🔄 Model rate limit, support setting total request limit and successful request limit in `System Settings->Rate Limit Settings`
## Model Support
This version additionally supports:
+12 -5
View File
@@ -66,10 +66,14 @@
15.**[OpenAI Realtime API](https://platform.openai.com/docs/guides/realtime/integration)** - 支持OpenAI的Realtime API,支持Azure渠道
16. 支持使用路由/chat2link 进入聊天界面
17. 🧠 支持通过模型名称后缀设置 reasoning effort
- 添加后缀 `-high` 设置为 high reasoning effort (例如: `o3-mini-high`)
- 添加后缀 `-medium` 设置为 medium reasoning effort (例如: `o3-mini-medium`)
- 添加后缀 `-low` 设置为 low reasoning effort (例如: `o3-mini-low`)
18. 🔄 思考转内容,支持在 `渠道-编辑-渠道额外设置` 设置 `thinking_to_content` 选项,默认`false`,开启后会将思考内容`reasoning_conetnt`转换为`<think>`标签拼接到内容中返回。
1. OpenAI o系列模型
- 添加后缀 `-high` 设置为 high reasoning effort (例如: `o3-mini-high`)
- 添加后缀 `-medium` 设置为 medium reasoning effort (例如: `o3-mini-medium`)
- 添加后缀 `-low` 设置为 low reasoning effort (例如: `o3-mini-low`)
2. Claude 思考模型
- 添加后缀 `-thinking` 启用思考模式 (例如: `claude-3-7-sonnet-20250219-thinking`)
18. 🔄 思考转内容,支持在 `渠道-编辑-渠道额外设置` 中设置 `thinking_to_content` 选项,默认`false`,开启后会将思考内容`reasoning_content`转换为`<think>`标签拼接到内容中返回。
19. 🔄 模型限流,支持在 `系统设置-速率限制设置` 中设置模型限流,支持设置总请求数限制和成功请求数限制
## 模型支持
此版本额外支持以下模型:
@@ -90,7 +94,6 @@
- `GET_MEDIA_TOKEN`:是否统计图片token,默认为 `true`,关闭后将不再在本地计算图片token,可能会导致和上游计费不同,此项覆盖 `GET_MEDIA_TOKEN_NOT_STREAM` 选项作用。
- `GET_MEDIA_TOKEN_NOT_STREAM`:是否在非流(`stream=false`)情况下统计图片token,默认为 `true`
- `UPDATE_TASK`:是否更新异步任务(Midjourney、Suno),默认为 `true`,关闭后将不会更新任务进度。
- `GEMINI_MODEL_MAP`Gemini模型指定版本(v1/v1beta),使用"模型:版本"指定,","分隔,例如:-e GEMINI_MODEL_MAP="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta",为空则使用默认配置(v1beta)
- `COHERE_SAFETY_SETTING`Cohere模型[安全设置](https://docs.cohere.com/docs/safety-modes#overview),可选值为 `NONE`, `CONTEXTUAL`, `STRICT`,默认为 `NONE`
- `GEMINI_VISION_MAX_IMAGE_NUM`:Gemini模型最大图片数量,默认为 `16`,设置为 `-1` 则不限制。
- `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小,单位 MB,默认为 `20`
@@ -99,6 +102,10 @@
- `NOTIFICATION_LIMIT_DURATION_MINUTE`:通知限制的持续时间(分钟),默认为 `10`
- `NOTIFY_LIMIT_COUNT`:用户通知在指定持续时间内的最大数量,默认为 `2`
## 已废弃的环境变量
- ~~`GEMINI_MODEL_MAP`(已废弃)~~:改为到`设置-模型相关设置`中设置
- ~~`GEMINI_SAFETY_SETTING`(已废弃)~~:改为到`设置-模型相关设置`中设置
## 部署
> [!TIP]
+1 -1
View File
@@ -276,7 +276,7 @@ var ChannelBaseURLs = []string{
"https://api.cohere.ai", //34
"https://api.minimax.chat", //35
"", //36
"", //37
"https://api.dify.ai", //37
"https://api.jina.ai", //38
"https://api.cloudflare.com", //39
"https://api.siliconflow.cn", //40
+88 -86
View File
@@ -83,92 +83,94 @@ var defaultModelRatio = map[string]float64{
"text-curie-001": 1,
//"text-davinci-002": 10,
//"text-davinci-003": 10,
"text-davinci-edit-001": 10,
"code-davinci-edit-001": 10,
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
"tts-1": 7.5, // 1k characters -> $0.015
"tts-1-1106": 7.5, // 1k characters -> $0.015
"tts-1-hd": 15, // 1k characters -> $0.03
"tts-1-hd-1106": 15, // 1k characters -> $0.03
"davinci": 10,
"curie": 10,
"babbage": 10,
"ada": 10,
"text-embedding-3-small": 0.01,
"text-embedding-3-large": 0.065,
"text-embedding-ada-002": 0.05,
"text-search-ada-doc-001": 10,
"text-moderation-stable": 0.1,
"text-moderation-latest": 0.1,
"claude-instant-1": 0.4, // $0.8 / 1M tokens
"claude-2.0": 4, // $8 / 1M tokens
"claude-2.1": 4, // $8 / 1M tokens
"claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens
"claude-3-5-haiku-20241022": 0.5, // $1 / 1M tokens
"claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens
"claude-3-5-sonnet-20240620": 1.5,
"claude-3-5-sonnet-20241022": 1.5,
"claude-3-opus-20240229": 7.5, // $15 / 1M tokens
"ERNIE-4.0-8K": 0.120 * RMB,
"ERNIE-3.5-8K": 0.012 * RMB,
"ERNIE-3.5-8K-0205": 0.024 * RMB,
"ERNIE-3.5-8K-1222": 0.012 * RMB,
"ERNIE-Bot-8K": 0.024 * RMB,
"ERNIE-3.5-4K-0205": 0.012 * RMB,
"ERNIE-Speed-8K": 0.004 * RMB,
"ERNIE-Speed-128K": 0.004 * RMB,
"ERNIE-Lite-8K-0922": 0.008 * RMB,
"ERNIE-Lite-8K-0308": 0.003 * RMB,
"ERNIE-Tiny-8K": 0.001 * RMB,
"BLOOMZ-7B": 0.004 * RMB,
"Embedding-V1": 0.002 * RMB,
"bge-large-zh": 0.002 * RMB,
"bge-large-en": 0.002 * RMB,
"tao-8k": 0.002 * RMB,
"PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-1.0-pro-vision-001": 1,
"gemini-1.0-pro-001": 1,
"gemini-1.5-pro-latest": 1.75, // $3.5 / 1M tokens
"gemini-1.5-pro-exp-0827": 1.75, // $3.5 / 1M tokens
"gemini-1.5-flash-latest": 1,
"gemini-1.5-flash-exp-0827": 1,
"gemini-1.0-pro-latest": 1,
"gemini-1.0-pro-vision-latest": 1,
"gemini-ultra": 1,
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
"glm-4": 7.143, // ¥0.1 / 1k tokens
"glm-4v": 0.05 * RMB, // ¥0.05 / 1k tokens
"glm-4-alltools": 0.1 * RMB, // ¥0.1 / 1k tokens
"glm-3-turbo": 0.3572,
"glm-4-plus": 0.05 * RMB,
"glm-4-0520": 0.1 * RMB,
"glm-4-air": 0.001 * RMB,
"glm-4-airx": 0.01 * RMB,
"glm-4-long": 0.001 * RMB,
"glm-4-flash": 0,
"glm-4v-plus": 0.01 * RMB,
"qwen-turbo": 0.8572, // ¥0.012 / 1k tokens
"qwen-plus": 10, // ¥0.14 / 1k tokens
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
"SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v4.0": 1.2858,
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
"360gpt-turbo": 0.0858, // ¥0.0012 / 1k tokens
"360gpt-turbo-responsibility-8k": 0.8572, // ¥0.012 / 1k tokens
"360gpt-pro": 0.8572, // ¥0.012 / 1k tokens
"360gpt2-pro": 0.8572, // ¥0.012 / 1k tokens
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
"text-davinci-edit-001": 10,
"code-davinci-edit-001": 10,
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
"tts-1": 7.5, // 1k characters -> $0.015
"tts-1-1106": 7.5, // 1k characters -> $0.015
"tts-1-hd": 15, // 1k characters -> $0.03
"tts-1-hd-1106": 15, // 1k characters -> $0.03
"davinci": 10,
"curie": 10,
"babbage": 10,
"ada": 10,
"text-embedding-3-small": 0.01,
"text-embedding-3-large": 0.065,
"text-embedding-ada-002": 0.05,
"text-search-ada-doc-001": 10,
"text-moderation-stable": 0.1,
"text-moderation-latest": 0.1,
"claude-instant-1": 0.4, // $0.8 / 1M tokens
"claude-2.0": 4, // $8 / 1M tokens
"claude-2.1": 4, // $8 / 1M tokens
"claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens
"claude-3-5-haiku-20241022": 0.5, // $1 / 1M tokens
"claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens
"claude-3-5-sonnet-20240620": 1.5,
"claude-3-5-sonnet-20241022": 1.5,
"claude-3-7-sonnet-20250219": 1.5,
"claude-3-7-sonnet-20250219-thinking": 1.5,
"claude-3-opus-20240229": 7.5, // $15 / 1M tokens
"ERNIE-4.0-8K": 0.120 * RMB,
"ERNIE-3.5-8K": 0.012 * RMB,
"ERNIE-3.5-8K-0205": 0.024 * RMB,
"ERNIE-3.5-8K-1222": 0.012 * RMB,
"ERNIE-Bot-8K": 0.024 * RMB,
"ERNIE-3.5-4K-0205": 0.012 * RMB,
"ERNIE-Speed-8K": 0.004 * RMB,
"ERNIE-Speed-128K": 0.004 * RMB,
"ERNIE-Lite-8K-0922": 0.008 * RMB,
"ERNIE-Lite-8K-0308": 0.003 * RMB,
"ERNIE-Tiny-8K": 0.001 * RMB,
"BLOOMZ-7B": 0.004 * RMB,
"Embedding-V1": 0.002 * RMB,
"bge-large-zh": 0.002 * RMB,
"bge-large-en": 0.002 * RMB,
"tao-8k": 0.002 * RMB,
"PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-1.0-pro-vision-001": 1,
"gemini-1.0-pro-001": 1,
"gemini-1.5-pro-latest": 1.75, // $3.5 / 1M tokens
"gemini-1.5-pro-exp-0827": 1.75, // $3.5 / 1M tokens
"gemini-1.5-flash-latest": 1,
"gemini-1.5-flash-exp-0827": 1,
"gemini-1.0-pro-latest": 1,
"gemini-1.0-pro-vision-latest": 1,
"gemini-ultra": 1,
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
"glm-4": 7.143, // ¥0.1 / 1k tokens
"glm-4v": 0.05 * RMB, // ¥0.05 / 1k tokens
"glm-4-alltools": 0.1 * RMB, // ¥0.1 / 1k tokens
"glm-3-turbo": 0.3572,
"glm-4-plus": 0.05 * RMB,
"glm-4-0520": 0.1 * RMB,
"glm-4-air": 0.001 * RMB,
"glm-4-airx": 0.01 * RMB,
"glm-4-long": 0.001 * RMB,
"glm-4-flash": 0,
"glm-4v-plus": 0.01 * RMB,
"qwen-turbo": 0.8572, // ¥0.012 / 1k tokens
"qwen-plus": 10, // ¥0.14 / 1k tokens
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
"SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.5": 1.2858, // 0.018 / 1k tokens
"SparkDesk-v4.0": 1.2858,
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
"360gpt-turbo": 0.0858, // ¥0.0012 / 1k tokens
"360gpt-turbo-responsibility-8k": 0.8572, // ¥0.012 / 1k tokens
"360gpt-pro": 0.8572, // ¥0.012 / 1k tokens
"360gpt2-pro": 0.8572, // ¥0.012 / 1k tokens
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
// https://platform.lingyiwanwu.com/docs#-计费单元
// 已经按照 7.2 来换算美元价格
"yi-34b-chat-0205": 0.18,
+34 -1
View File
@@ -32,6 +32,7 @@ func InitRedisClient() (err error) {
if err != nil {
FatalLog("failed to parse Redis connection string: " + err.Error())
}
opt.PoolSize = GetEnvOrDefault("REDIS_POOL_SIZE", 10)
RDB = redis.NewClient(opt)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
@@ -41,6 +42,10 @@ func InitRedisClient() (err error) {
if err != nil {
FatalLog("Redis ping test failed: " + err.Error())
}
if DebugEnabled {
SysLog(fmt.Sprintf("Redis connected to %s", opt.Addr))
SysLog(fmt.Sprintf("Redis database: %d", opt.DB))
}
return err
}
@@ -53,13 +58,20 @@ func ParseRedisOption() *redis.Options {
}
func RedisSet(key string, value string, expiration time.Duration) error {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis SET: key=%s, value=%s, expiration=%v", key, value, expiration))
}
ctx := context.Background()
return RDB.Set(ctx, key, value, expiration).Err()
}
func RedisGet(key string) (string, error) {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis GET: key=%s", key))
}
ctx := context.Background()
return RDB.Get(ctx, key).Result()
val, err := RDB.Get(ctx, key).Result()
return val, err
}
//func RedisExpire(key string, expiration time.Duration) error {
@@ -73,16 +85,25 @@ func RedisGet(key string) (string, error) {
//}
func RedisDel(key string) error {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis DEL: key=%s", key))
}
ctx := context.Background()
return RDB.Del(ctx, key).Err()
}
func RedisHDelObj(key string) error {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis HDEL: key=%s", key))
}
ctx := context.Background()
return RDB.HDel(ctx, key).Err()
}
func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis HSET: key=%s, obj=%+v, expiration=%v", key, obj, expiration))
}
ctx := context.Background()
data := make(map[string]interface{})
@@ -130,6 +151,9 @@ func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
}
func RedisHGetObj(key string, obj interface{}) error {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis HGETALL: key=%s", key))
}
ctx := context.Background()
result, err := RDB.HGetAll(ctx, key).Result()
@@ -208,6 +232,9 @@ func RedisHGetObj(key string, obj interface{}) error {
// RedisIncr Add this function to handle atomic increments
func RedisIncr(key string, delta int64) error {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis INCR: key=%s, delta=%d", key, delta))
}
// 检查键的剩余生存时间
ttlCmd := RDB.TTL(context.Background(), key)
ttl, err := ttlCmd.Result()
@@ -238,6 +265,9 @@ func RedisIncr(key string, delta int64) error {
}
func RedisHIncrBy(key, field string, delta int64) error {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis HINCRBY: key=%s, field=%s, delta=%d", key, field, delta))
}
ttlCmd := RDB.TTL(context.Background(), key)
ttl, err := ttlCmd.Result()
if err != nil && !errors.Is(err, redis.Nil) {
@@ -262,6 +292,9 @@ func RedisHIncrBy(key, field string, delta int64) error {
}
func RedisHSetField(key, field string, value interface{}) error {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis HSET field: key=%s, field=%s, value=%v", key, field, value))
}
ttlCmd := RDB.TTL(context.Background(), key)
ttl, err := ttlCmd.Result()
if err != nil && !errors.Is(err, redis.Nil) {
+19
View File
@@ -5,6 +5,7 @@ import (
"context"
crand "crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"github.com/pkg/errors"
"html/template"
@@ -213,6 +214,24 @@ func RandomSleep() {
time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond)
}
func GetPointer[T any](v T) *T {
return &v
}
func Any2Type[T any](data any) (T, error) {
var zero T
bytes, err := json.Marshal(data)
if err != nil {
return zero, err
}
var res T
err = json.Unmarshal(bytes, &res)
if err != nil {
return zero, err
}
return res, nil
}
// SaveTmpFile saves data to a temporary file. The filename would be apppended with a random string.
func SaveTmpFile(filename string, data io.Reader) (string, error) {
f, err := os.CreateTemp(os.TempDir(), filename)
+5
View File
@@ -2,4 +2,9 @@ package constant
const (
ContextKeyRequestStartTime = "request_start_time"
ContextKeyUserSetting = "user_setting"
ContextKeyUserQuota = "user_quota"
ContextKeyUserStatus = "user_status"
ContextKeyUserEmail = "user_email"
ContextKeyUserGroup = "user_group"
)
+15 -18
View File
@@ -1,10 +1,7 @@
package constant
import (
"fmt"
"one-api/common"
"os"
"strings"
)
var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 60)
@@ -23,9 +20,9 @@ var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
var AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2024-12-01-preview")
var GeminiModelMap = map[string]string{
"gemini-1.0-pro": "v1",
}
//var GeminiModelMap = map[string]string{
// "gemini-1.0-pro": "v1",
//}
var GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
@@ -33,18 +30,18 @@ var NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
var NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
func InitEnv() {
modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
if modelVersionMapStr == "" {
return
}
for _, pair := range strings.Split(modelVersionMapStr, ",") {
parts := strings.Split(pair, ":")
if len(parts) == 2 {
GeminiModelMap[parts[0]] = parts[1]
} else {
common.SysError(fmt.Sprintf("invalid model version map: %s", pair))
}
}
//modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
//if modelVersionMapStr == "" {
// return
//}
//for _, pair := range strings.Split(modelVersionMapStr, ",") {
// parts := strings.Split(pair, ":")
// if len(parts) == 2 {
// GeminiModelMap[parts[0]] = parts[1]
// } else {
// common.SysError(fmt.Sprintf("invalid model version map: %s", pair))
// }
//}
}
// GenerateDefaultToken 是否生成初始令牌,默认关闭。
+1 -1
View File
@@ -159,7 +159,7 @@ func UpdateMidjourneyTaskBulk() {
common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
} else {
if shouldReturnQuota {
err = model.IncreaseUserQuota(task.UserId, task.Quota)
err = model.IncreaseUserQuota(task.UserId, task.Quota, false)
if err != nil {
common.LogError(ctx, "fail to increase user quota: "+err.Error())
}
+1
View File
@@ -85,6 +85,7 @@ func Relay(c *gin.Context) {
if openaiErr != nil {
if openaiErr.StatusCode == http.StatusTooManyRequests {
common.LogError(c, fmt.Sprintf("origin 429 error: %s", openaiErr.Error.Message))
openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
}
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
+1 -1
View File
@@ -159,7 +159,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
} else {
quota := task.Quota
if quota != 0 {
err = model.IncreaseUserQuota(task.UserId, quota)
err = model.IncreaseUserQuota(task.UserId, quota, false)
if err != nil {
common.LogError(ctx, "fail to increase user quota: "+err.Error())
}
+1 -1
View File
@@ -210,7 +210,7 @@ func EpayNotify(c *gin.Context) {
}
//user, _ := model.GetUserById(topUp.UserId, false)
//user.Quota += topUp.Amount * 500000
err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*int(common.QuotaPerUnit))
err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*int(common.QuotaPerUnit), true)
if err != nil {
log.Printf("易支付回调更新用户失败: %v", topUp)
return
+59 -40
View File
@@ -1,6 +1,9 @@
package dto
import "encoding/json"
import (
"encoding/json"
"strings"
)
type ResponseFormat struct {
Type string `json:"type,omitempty"`
@@ -15,49 +18,52 @@ type FormatJsonSchema struct {
}
type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Prefix any `json:"prefix,omitempty"`
Suffix any `json:"suffix,omitempty"`
Stream bool `json:"stream,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
ReasoningEffort string `json:"reasoning_effort,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
N int `json:"n,omitempty"`
Input any `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"`
Functions any `json:"functions,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
EncodingFormat any `json:"encoding_format,omitempty"`
Seed float64 `json:"seed,omitempty"`
Tools []ToolCall `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"`
LogProbs bool `json:"logprobs,omitempty"`
TopLogProbs int `json:"top_logprobs,omitempty"`
Dimensions int `json:"dimensions,omitempty"`
Modalities any `json:"modalities,omitempty"`
Audio any `json:"audio,omitempty"`
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Prefix any `json:"prefix,omitempty"`
Suffix any `json:"suffix,omitempty"`
Stream bool `json:"stream,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
ReasoningEffort string `json:"reasoning_effort,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
N int `json:"n,omitempty"`
Input any `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"`
Functions any `json:"functions,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
EncodingFormat any `json:"encoding_format,omitempty"`
Seed float64 `json:"seed,omitempty"`
Tools []ToolCallRequest `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"`
LogProbs bool `json:"logprobs,omitempty"`
TopLogProbs int `json:"top_logprobs,omitempty"`
Dimensions int `json:"dimensions,omitempty"`
Modalities any `json:"modalities,omitempty"`
Audio any `json:"audio,omitempty"`
ExtraBody any `json:"extra_body,omitempty"`
}
type OpenAITools struct {
Type string `json:"type"`
Function OpenAIFunction `json:"function"`
type ToolCallRequest struct {
ID string `json:"id,omitempty"`
Type string `json:"type"`
Function FunctionRequest `json:"function"`
}
type OpenAIFunction struct {
type FunctionRequest struct {
Description string `json:"description,omitempty"`
Name string `json:"name"`
Parameters any `json:"parameters,omitempty"`
Arguments string `json:"arguments,omitempty"`
}
type StreamOptions struct {
@@ -133,11 +139,11 @@ func (m *Message) SetPrefix(prefix bool) {
m.Prefix = &prefix
}
func (m *Message) ParseToolCalls() []ToolCall {
func (m *Message) ParseToolCalls() []ToolCallRequest {
if m.ToolCalls == nil {
return nil
}
var toolCalls []ToolCall
var toolCalls []ToolCallRequest
if err := json.Unmarshal(m.ToolCalls, &toolCalls); err == nil {
return toolCalls
}
@@ -153,11 +159,24 @@ func (m *Message) StringContent() string {
if m.parsedStringContent != nil {
return *m.parsedStringContent
}
var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
m.parsedStringContent = &stringContent
return stringContent
}
return string(m.Content)
contentStr := new(strings.Builder)
arrayContent := m.ParseContent()
for _, content := range arrayContent {
if content.Type == ContentTypeText {
contentStr.WriteString(content.Text)
}
}
stringContent = contentStr.String()
m.parsedStringContent = &stringContent
return stringContent
}
func (m *Message) SetStringContent(content string) {
+12 -12
View File
@@ -62,10 +62,10 @@ type ChatCompletionsStreamResponseChoice struct {
}
type ChatCompletionsStreamResponseChoiceDelta struct {
Content *string `json:"content,omitempty"`
ReasoningContent *string `json:"reasoning_content,omitempty"`
Role string `json:"role,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
Content *string `json:"content,omitempty"`
ReasoningContent *string `json:"reasoning_content,omitempty"`
Role string `json:"role,omitempty"`
ToolCalls []ToolCallResponse `json:"tool_calls,omitempty"`
}
func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) {
@@ -90,24 +90,24 @@ func (c *ChatCompletionsStreamResponseChoiceDelta) SetReasoningContent(s string)
c.ReasoningContent = &s
}
type ToolCall struct {
type ToolCallResponse struct {
// Index is not nil only in chat completion chunk object
Index *int `json:"index,omitempty"`
ID string `json:"id,omitempty"`
Type any `json:"type"`
Function FunctionCall `json:"function"`
Index *int `json:"index,omitempty"`
ID string `json:"id,omitempty"`
Type any `json:"type"`
Function FunctionResponse `json:"function"`
}
func (c *ToolCall) SetIndex(i int) {
func (c *ToolCallResponse) SetIndex(i int) {
c.Index = &i
}
type FunctionCall struct {
type FunctionResponse struct {
Description string `json:"description,omitempty"`
Name string `json:"name,omitempty"`
// call function with arguments in JSON format
Parameters any `json:"parameters,omitempty"` // request
Arguments string `json:"arguments,omitempty"`
Arguments string `json:"arguments"` // response
}
type ChatCompletionsStreamResponse struct {
+5 -1
View File
@@ -199,15 +199,19 @@ func TokenAuth() func(c *gin.Context) {
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
return
}
userEnabled, err := model.IsUserEnabled(token.UserId, false)
userCache, err := model.GetUserCache(token.UserId)
if err != nil {
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
return
}
userEnabled := userCache.Status == common.UserStatusEnabled
if !userEnabled {
abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁")
return
}
userCache.WriteContext(c)
c.Set("id", token.UserId)
c.Set("token_id", token.Id)
c.Set("token_key", token.Key)
+1 -2
View File
@@ -32,7 +32,6 @@ func Distribute() func(c *gin.Context) {
return
}
}
userId := c.GetInt("id")
var channel *model.Channel
channelId, ok := c.Get("specific_channel_id")
modelRequest, shouldSelectChannel, err := getModelRequest(c)
@@ -40,7 +39,7 @@ func Distribute() func(c *gin.Context) {
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
return
}
userGroup, _ := model.GetUserGroup(userId, false)
userGroup := c.GetString(constant.ContextKeyUserGroup)
tokenGroup := c.GetString("token_group")
if tokenGroup != "" {
// check common.UserUsableGroups[userGroup]
+172
View File
@@ -0,0 +1,172 @@
package middleware
import (
"context"
"fmt"
"net/http"
"one-api/common"
"one-api/setting"
"strconv"
"time"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
)
const (
ModelRequestRateLimitCountMark = "MRRL"
ModelRequestRateLimitSuccessCountMark = "MRRLS"
)
// 检查Redis中的请求限制
func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, maxCount int, duration int64) (bool, error) {
// 如果maxCount为0,表示不限制
if maxCount == 0 {
return true, nil
}
// 获取当前计数
length, err := rdb.LLen(ctx, key).Result()
if err != nil {
return false, err
}
// 如果未达到限制,允许请求
if length < int64(maxCount) {
return true, nil
}
// 检查时间窗口
oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
oldTime, err := time.Parse(timeFormat, oldTimeStr)
if err != nil {
return false, err
}
nowTimeStr := time.Now().Format(timeFormat)
nowTime, err := time.Parse(timeFormat, nowTimeStr)
if err != nil {
return false, err
}
// 如果在时间窗口内已达到限制,拒绝请求
subTime := nowTime.Sub(oldTime).Seconds()
if int64(subTime) < duration {
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
return false, nil
}
return true, nil
}
// 记录Redis请求
func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxCount int) {
// 如果maxCount为0,不记录请求
if maxCount == 0 {
return
}
now := time.Now().Format(timeFormat)
rdb.LPush(ctx, key, now)
rdb.LTrim(ctx, key, 0, int64(maxCount-1))
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
}
// Redis限流处理器
func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
return func(c *gin.Context) {
userId := strconv.Itoa(c.GetInt("id"))
ctx := context.Background()
rdb := common.RDB
// 1. 检查总请求数限制(当totalMaxCount为0时会自动跳过)
totalKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitCountMark, userId)
allowed, err := checkRedisRateLimit(ctx, rdb, totalKey, totalMaxCount, duration)
if err != nil {
fmt.Println("检查总请求数限制失败:", err.Error())
abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
return
}
if !allowed {
abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
}
// 2. 检查成功请求数限制
successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId)
allowed, err = checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration)
if err != nil {
fmt.Println("检查成功请求数限制失败:", err.Error())
abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
return
}
if !allowed {
abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到请求数限制:%d分钟内最多请求%d次", setting.ModelRequestRateLimitDurationMinutes, successMaxCount))
return
}
// 3. 记录总请求(当totalMaxCount为0时会自动跳过)
recordRedisRequest(ctx, rdb, totalKey, totalMaxCount)
// 4. 处理请求
c.Next()
// 5. 如果请求成功,记录成功请求
if c.Writer.Status() < 400 {
recordRedisRequest(ctx, rdb, successKey, successMaxCount)
}
}
}
// 内存限流处理器
func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration)
return func(c *gin.Context) {
userId := strconv.Itoa(c.GetInt("id"))
totalKey := ModelRequestRateLimitCountMark + userId
successKey := ModelRequestRateLimitSuccessCountMark + userId
// 1. 检查总请求数限制(当totalMaxCount为0时跳过)
if totalMaxCount > 0 && !inMemoryRateLimiter.Request(totalKey, totalMaxCount, duration) {
c.Status(http.StatusTooManyRequests)
c.Abort()
return
}
// 2. 检查成功请求数限制
// 使用一个临时key来检查限制,这样可以避免实际记录
checkKey := successKey + "_check"
if !inMemoryRateLimiter.Request(checkKey, successMaxCount, duration) {
c.Status(http.StatusTooManyRequests)
c.Abort()
return
}
// 3. 处理请求
c.Next()
// 4. 如果请求成功,记录到实际的成功请求计数中
if c.Writer.Status() < 400 {
inMemoryRateLimiter.Request(successKey, successMaxCount, duration)
}
}
}
// ModelRequestRateLimit 模型请求限流中间件
func ModelRequestRateLimit() func(c *gin.Context) {
// 如果未启用限流,直接放行
if !setting.ModelRequestRateLimitEnabled {
return defNext
}
// 计算限流参数
duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60)
totalMaxCount := setting.ModelRequestRateLimitCount
successMaxCount := setting.ModelRequestRateLimitSuccessCount
// 根据存储类型选择限流处理器
if common.RedisEnabled {
return redisRateLimitHandler(duration, totalMaxCount, successMaxCount)
} else {
return memoryRateLimitHandler(duration, totalMaxCount, successMaxCount)
}
}
+5 -5
View File
@@ -1,8 +1,8 @@
package model
import (
"context"
"fmt"
"github.com/gin-gonic/gin"
"one-api/common"
"os"
"strings"
@@ -87,14 +87,14 @@ func RecordLog(userId int, logType int, content string) {
}
}
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int,
func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens int, completionTokens int,
modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int,
isStream bool, group string, other map[string]interface{}) {
common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
common.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
if !common.LogConsumeEnabled {
return
}
username, _ := GetUsernameById(userId, false)
username := c.GetString("username")
otherStr := common.MapToJsonStr(other)
log := &Log{
UserId: userId,
@@ -116,7 +116,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
}
err := LOG_DB.Create(log).Error
if err != nil {
common.LogError(ctx, "failed to record log: "+err.Error())
common.LogError(c, "failed to record log: "+err.Error())
}
if common.DataExportEnabled {
gopool.Go(func() {
+20
View File
@@ -3,6 +3,7 @@ package model
import (
"one-api/common"
"one-api/setting"
"one-api/setting/model_setting"
"strconv"
"strings"
"time"
@@ -85,6 +86,9 @@ func InitOptionMap() {
common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee)
common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
common.OptionMap["ShouldPreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount)
common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
@@ -105,12 +109,15 @@ func InitOptionMap() {
common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(setting.MjActionCheckSuccessEnabled)
common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(setting.CheckSensitiveEnabled)
common.OptionMap["DemoSiteEnabled"] = strconv.FormatBool(setting.DemoSiteEnabled)
common.OptionMap["ModelRequestRateLimitEnabled"] = strconv.FormatBool(setting.ModelRequestRateLimitEnabled)
common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(setting.CheckSensitiveOnPromptEnabled)
//common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(setting.StopOnSensitiveEnabled)
common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
common.OptionMap["AutomaticDisableKeywords"] = setting.AutomaticDisableKeywordsToString()
common.OptionMap["GeminiSafetySettings"] = model_setting.GeminiSafetySettingsJsonString()
common.OptionMap["GeminiVersionSettings"] = model_setting.GeminiVersionSettingsJsonString()
common.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase()
@@ -226,6 +233,9 @@ func updateOptionMap(key string, value string) (err error) {
setting.DemoSiteEnabled = boolValue
case "CheckSensitiveOnPromptEnabled":
setting.CheckSensitiveOnPromptEnabled = boolValue
case "ModelRequestRateLimitEnabled":
setting.ModelRequestRateLimitEnabled = boolValue
//case "CheckSensitiveOnCompletionEnabled":
// constant.CheckSensitiveOnCompletionEnabled = boolValue
case "StopOnSensitiveEnabled":
@@ -308,6 +318,12 @@ func updateOptionMap(key string, value string) (err error) {
common.QuotaRemindThreshold, _ = strconv.Atoi(value)
case "ShouldPreConsumedQuota":
common.PreConsumedQuota, _ = strconv.Atoi(value)
case "ModelRequestRateLimitCount":
setting.ModelRequestRateLimitCount, _ = strconv.Atoi(value)
case "ModelRequestRateLimitDurationMinutes":
setting.ModelRequestRateLimitDurationMinutes, _ = strconv.Atoi(value)
case "ModelRequestRateLimitSuccessCount":
setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value)
case "RetryTimes":
common.RetryTimes, _ = strconv.Atoi(value)
case "DataExportInterval":
@@ -338,6 +354,10 @@ func updateOptionMap(key string, value string) (err error) {
setting.SensitiveWordsFromString(value)
case "AutomaticDisableKeywords":
setting.AutomaticDisableKeywordsFromString(value)
case "GeminiSafetySettings":
model_setting.GeminiSafetySettingFromJsonString(value)
case "GeminiVersionSettings":
model_setting.GeminiVersionSettingFromJsonString(value)
case "StreamCacheQueueLength":
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
}
+33 -33
View File
@@ -320,7 +320,7 @@ func (user *User) Insert(inviterId int) error {
}
if inviterId != 0 {
if common.QuotaForInvitee > 0 {
_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee)
_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee, true)
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee)))
}
if common.QuotaForInviter > 0 {
@@ -502,35 +502,35 @@ func IsAdmin(userId int) bool {
return user.Role >= common.RoleAdminUser
}
// IsUserEnabled checks user status from Redis first, falls back to DB if needed
func IsUserEnabled(id int, fromDB bool) (status bool, err error) {
defer func() {
// Update Redis cache asynchronously on successful DB read
if shouldUpdateRedis(fromDB, err) {
gopool.Go(func() {
if err := updateUserStatusCache(id, status); err != nil {
common.SysError("failed to update user status cache: " + err.Error())
}
})
}
}()
if !fromDB && common.RedisEnabled {
// Try Redis first
status, err := getUserStatusCache(id)
if err == nil {
return status == common.UserStatusEnabled, nil
}
// Don't return error - fall through to DB
}
fromDB = true
var user User
err = DB.Where("id = ?", id).Select("status").Find(&user).Error
if err != nil {
return false, err
}
return user.Status == common.UserStatusEnabled, nil
}
//// IsUserEnabled checks user status from Redis first, falls back to DB if needed
//func IsUserEnabled(id int, fromDB bool) (status bool, err error) {
// defer func() {
// // Update Redis cache asynchronously on successful DB read
// if shouldUpdateRedis(fromDB, err) {
// gopool.Go(func() {
// if err := updateUserStatusCache(id, status); err != nil {
// common.SysError("failed to update user status cache: " + err.Error())
// }
// })
// }
// }()
// if !fromDB && common.RedisEnabled {
// // Try Redis first
// status, err := getUserStatusCache(id)
// if err == nil {
// return status == common.UserStatusEnabled, nil
// }
// // Don't return error - fall through to DB
// }
// fromDB = true
// var user User
// err = DB.Where("id = ?", id).Select("status").Find(&user).Error
// if err != nil {
// return false, err
// }
//
// return user.Status == common.UserStatusEnabled, nil
//}
func ValidateAccessToken(token string) (user *User) {
if token == "" {
@@ -639,7 +639,7 @@ func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err
return common.StrToMap(setting), nil
}
func IncreaseUserQuota(id int, quota int) (err error) {
func IncreaseUserQuota(id int, quota int, db bool) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
@@ -649,7 +649,7 @@ func IncreaseUserQuota(id int, quota int) (err error) {
common.SysError("failed to increase user quota: " + err.Error())
}
})
if common.BatchUpdateEnabled {
if !db && common.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUserQuota, id, quota)
return nil
}
@@ -694,7 +694,7 @@ func DeltaUpdateUserQuota(id int, delta int) (err error) {
return nil
}
if delta > 0 {
return IncreaseUserQuota(id, delta)
return IncreaseUserQuota(id, delta, false)
} else {
return DecreaseUserQuota(id, -delta)
}
+10
View File
@@ -3,6 +3,7 @@ package model
import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"one-api/common"
"one-api/constant"
"time"
@@ -21,6 +22,15 @@ type UserBase struct {
Setting string `json:"setting"`
}
func (user *UserBase) WriteContext(c *gin.Context) {
c.Set(constant.ContextKeyUserGroup, user.Group)
c.Set(constant.ContextKeyUserQuota, user.Quota)
c.Set(constant.ContextKeyUserStatus, user.Status)
c.Set(constant.ContextKeyUserEmail, user.Email)
c.Set("username", user.Username)
c.Set(constant.ContextKeyUserSetting, user.GetSetting())
}
func (user *UserBase) GetSetting() map[string]interface{} {
if user.Setting == "" {
return nil
+1 -1
View File
@@ -130,7 +130,7 @@ func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.TaskRelayInfo,
if err != nil {
return nil, fmt.Errorf("setup request header failed: %w", err)
}
resp, err := doRequest(c, req, info.ToRelayInfo())
resp, err := doRequest(c, req, info.RelayInfo)
if err != nil {
return nil, fmt.Errorf("do request failed: %w", err)
}
+2 -1
View File
@@ -9,7 +9,8 @@ var awsModelIDMap = map[string]string{
"claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0",
"claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0",
"claude-3-5-sonnet-20241022": "anthropic.claude-3-5-sonnet-20241022-v2:0",
"claude-3-5-haiku-20241022": "anthropic.claude-3-5-haiku-20241022-v1:0",
"claude-3-5-haiku-20241022": "anthropic.claude-3-5-haiku-20241022-v1:0",
"claude-3-7-sonnet-20250219": "anthropic.claude-3-7-sonnet-20250219-v1:0",
}
var ChannelName = "aws"
+2
View File
@@ -16,6 +16,7 @@ type AwsClaudeRequest struct {
StopSequences []string `json:"stop_sequences,omitempty"`
Tools []claude.Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
Thinking *claude.Thinking `json:"thinking,omitempty"`
}
func copyRequest(req *claude.ClaudeRequest) *AwsClaudeRequest {
@@ -30,5 +31,6 @@ func copyRequest(req *claude.ClaudeRequest) *AwsClaudeRequest {
StopSequences: req.StopSequences,
Tools: req.Tools,
ToolChoice: req.ToolChoice,
Thinking: req.Thinking,
}
}
+2
View File
@@ -11,6 +11,8 @@ var ModelList = []string{
"claude-3-5-haiku-20241022",
"claude-3-5-sonnet-20240620",
"claude-3-5-sonnet-20241022",
"claude-3-7-sonnet-20250219",
"claude-3-7-sonnet-20250219-thinking",
}
var ChannelName = "claude"
+12 -3
View File
@@ -11,6 +11,9 @@ type ClaudeMediaMessage struct {
Usage *ClaudeUsage `json:"usage,omitempty"`
StopReason *string `json:"stop_reason,omitempty"`
PartialJson string `json:"partial_json,omitempty"`
Thinking string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"`
Delta string `json:"delta,omitempty"`
// tool_calls
Id string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
@@ -54,9 +57,15 @@ type ClaudeRequest struct {
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
//ClaudeMetadata `json:"metadata,omitempty"`
Stream bool `json:"stream,omitempty"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
Stream bool `json:"stream,omitempty"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
Thinking *Thinking `json:"thinking,omitempty"`
}
type Thinking struct {
Type string `json:"type"`
BudgetTokens int `json:"budget_tokens"`
}
type ClaudeError struct {
+50 -10
View File
@@ -92,6 +92,29 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
Stream: textRequest.Stream,
Tools: claudeTools,
}
if strings.HasSuffix(textRequest.Model, "-thinking") {
if claudeRequest.MaxTokens == 0 {
claudeRequest.MaxTokens = 8192
}
// 因为BudgetTokens 必须大于1024
if claudeRequest.MaxTokens < 1280 {
claudeRequest.MaxTokens = 1280
}
// BudgetTokens 为 max_tokens 的 80%
claudeRequest.Thinking = &Thinking{
Type: "enabled",
BudgetTokens: int(float64(claudeRequest.MaxTokens) * 0.8),
}
// TODO: 临时处理
// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
claudeRequest.TopP = 0
claudeRequest.Temperature = common.GetPointer[float64](1.0)
claudeRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking")
}
if claudeRequest.MaxTokens == 0 {
claudeRequest.MaxTokens = 4096
}
@@ -273,7 +296,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
response.Object = "chat.completion.chunk"
response.Model = claudeResponse.Model
response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
tools := make([]dto.ToolCall, 0)
tools := make([]dto.ToolCallResponse, 0)
var choice dto.ChatCompletionsStreamResponseChoice
if reqMode == RequestModeCompletion {
choice.Delta.SetContentString(claudeResponse.Completion)
@@ -292,10 +315,10 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
if claudeResponse.ContentBlock != nil {
//choice.Delta.SetContentString(claudeResponse.ContentBlock.Text)
if claudeResponse.ContentBlock.Type == "tool_use" {
tools = append(tools, dto.ToolCall{
tools = append(tools, dto.ToolCallResponse{
ID: claudeResponse.ContentBlock.Id,
Type: "function",
Function: dto.FunctionCall{
Function: dto.FunctionResponse{
Name: claudeResponse.ContentBlock.Name,
Arguments: "",
},
@@ -308,12 +331,20 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
if claudeResponse.Delta != nil {
choice.Index = claudeResponse.Index
choice.Delta.SetContentString(claudeResponse.Delta.Text)
if claudeResponse.Delta.Type == "input_json_delta" {
tools = append(tools, dto.ToolCall{
Function: dto.FunctionCall{
switch claudeResponse.Delta.Type {
case "input_json_delta":
tools = append(tools, dto.ToolCallResponse{
Function: dto.FunctionResponse{
Arguments: claudeResponse.Delta.PartialJson,
},
})
case "signature_delta":
// 加密的不处理
signatureContent := "\n"
choice.Delta.ReasoningContent = &signatureContent
case "thinking_delta":
thinkingContent := claudeResponse.Delta.Thinking
choice.Delta.ReasoningContent = &thinkingContent
}
}
} else if claudeResponse.Type == "message_delta" {
@@ -351,7 +382,9 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
if len(claudeResponse.Content) > 0 {
responseText = claudeResponse.Content[0].Text
}
tools := make([]dto.ToolCall, 0)
tools := make([]dto.ToolCallResponse, 0)
thinkingContent := ""
if reqMode == RequestModeCompletion {
content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " "))
choice := dto.OpenAITextResponseChoice{
@@ -367,16 +400,22 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
} else {
fullTextResponse.Id = claudeResponse.Id
for _, message := range claudeResponse.Content {
if message.Type == "tool_use" {
switch message.Type {
case "tool_use":
args, _ := json.Marshal(message.Input)
tools = append(tools, dto.ToolCall{
tools = append(tools, dto.ToolCallResponse{
ID: message.Id,
Type: "function", // compatible with other OpenAI derivative applications
Function: dto.FunctionCall{
Function: dto.FunctionResponse{
Name: message.Name,
Arguments: string(args),
},
})
case "thinking":
// 加密的不管, 只输出明文的推理过程
thinkingContent = message.Thinking
case "text":
responseText = message.Text
}
}
}
@@ -391,6 +430,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
if len(tools) > 0 {
choice.Message.SetToolCalls(tools)
}
choice.Message.ReasoningContent = thinkingContent
fullTextResponse.Model = claudeResponse.Model
choices = append(choices, choice)
fullTextResponse.Choices = choices
+28 -2
View File
@@ -9,9 +9,18 @@ import (
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"strings"
)
const (
BotTypeChatFlow = 1 // chatflow default
BotTypeAgent = 2
BotTypeWorkFlow = 3
BotTypeCompletion = 4
)
type Adaptor struct {
BotType int
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
@@ -25,10 +34,28 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
if strings.HasPrefix(info.UpstreamModelName, "agent") {
a.BotType = BotTypeAgent
} else if strings.HasPrefix(info.UpstreamModelName, "workflow") {
a.BotType = BotTypeWorkFlow
} else if strings.HasPrefix(info.UpstreamModelName, "chat") {
a.BotType = BotTypeCompletion
} else {
a.BotType = BotTypeChatFlow
}
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil
switch a.BotType {
case BotTypeWorkFlow:
return fmt.Sprintf("%s/v1/workflows/run", info.BaseUrl), nil
case BotTypeCompletion:
return fmt.Sprintf("%s/v1/completion-messages", info.BaseUrl), nil
case BotTypeAgent:
fallthrough
default:
return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil
}
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
@@ -53,7 +80,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
+2 -10
View File
@@ -7,11 +7,11 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
"one-api/setting/model_setting"
"strings"
@@ -64,15 +64,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
// 从映射中获取模型名称对应的版本,如果找不到就使用 info.ApiVersion 或默认的版本 "v1beta"
version, beta := constant.GeminiModelMap[info.UpstreamModelName]
if !beta {
if info.ApiVersion != "" {
version = info.ApiVersion
} else {
version = "v1beta"
}
}
version := model_setting.GetGeminiVersionSetting(info.UpstreamModelName)
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
return fmt.Sprintf("%s/%s/models/%s:predict", info.BaseUrl, version, info.UpstreamModelName), nil
+8
View File
@@ -20,4 +20,12 @@ var ModelList = []string{
"imagen-3.0-generate-002",
}
var SafetySettingList = []string{
"HARM_CATEGORY_HARASSMENT",
"HARM_CATEGORY_HATE_SPEECH",
"HARM_CATEGORY_SEXUALLY_EXPLICIT",
"HARM_CATEGORY_DANGEROUS_CONTENT",
"HARM_CATEGORY_CIVIC_INTEGRITY",
}
var ChannelName = "google gemini"
+24 -35
View File
@@ -11,6 +11,7 @@ import (
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"one-api/setting/model_setting"
"strings"
"unicode/utf8"
@@ -22,28 +23,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
geminiRequest := GeminiChatRequest{
Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
SafetySettings: []GeminiChatSafetySettings{
{
Category: "HARM_CATEGORY_HARASSMENT",
Threshold: common.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_HATE_SPEECH",
Threshold: common.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
Threshold: common.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_DANGEROUS_CONTENT",
Threshold: common.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_CIVIC_INTEGRITY",
Threshold: common.GeminiSafetySetting,
},
},
//SafetySettings: []GeminiChatSafetySettings{},
GenerationConfig: GeminiChatGenerationConfig{
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
@@ -52,9 +32,18 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
},
}
safetySettings := make([]GeminiChatSafetySettings, 0, len(SafetySettingList))
for _, category := range SafetySettingList {
safetySettings = append(safetySettings, GeminiChatSafetySettings{
Category: category,
Threshold: model_setting.GetGeminiSafetySetting(category),
})
}
geminiRequest.SafetySettings = safetySettings
// openaiContent.FuncToToolCalls()
if textRequest.Tools != nil {
functions := make([]dto.FunctionCall, 0, len(textRequest.Tools))
functions := make([]dto.FunctionRequest, 0, len(textRequest.Tools))
googleSearch := false
codeExecution := false
for _, tool := range textRequest.Tools {
@@ -349,7 +338,7 @@ func unescapeMapOrSlice(data interface{}) interface{} {
return data
}
func getToolCall(item *GeminiPart) *dto.ToolCall {
func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse {
var argsBytes []byte
var err error
if result, ok := item.FunctionCall.Arguments.(map[string]interface{}); ok {
@@ -361,10 +350,10 @@ func getToolCall(item *GeminiPart) *dto.ToolCall {
if err != nil {
return nil
}
return &dto.ToolCall{
return &dto.ToolCallResponse{
ID: fmt.Sprintf("call_%s", common.GetUUID()),
Type: "function",
Function: dto.FunctionCall{
Function: dto.FunctionResponse{
Arguments: string(argsBytes),
Name: item.FunctionCall.FunctionName,
},
@@ -379,7 +368,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
}
content, _ := json.Marshal("")
is_tool_call := false
isToolCall := false
for _, candidate := range response.Candidates {
choice := dto.OpenAITextResponseChoice{
Index: int(candidate.Index),
@@ -391,12 +380,12 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
}
if len(candidate.Content.Parts) > 0 {
var texts []string
var tool_calls []dto.ToolCall
var toolCalls []dto.ToolCallResponse
for _, part := range candidate.Content.Parts {
if part.FunctionCall != nil {
choice.FinishReason = constant.FinishReasonToolCalls
if call := getToolCall(&part); call != nil {
tool_calls = append(tool_calls, *call)
if call := getResponseToolCall(&part); call != nil {
toolCalls = append(toolCalls, *call)
}
} else {
if part.ExecutableCode != nil {
@@ -411,9 +400,9 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
}
}
}
if len(tool_calls) > 0 {
choice.Message.SetToolCalls(tool_calls)
is_tool_call = true
if len(toolCalls) > 0 {
choice.Message.SetToolCalls(toolCalls)
isToolCall = true
}
choice.Message.SetStringContent(strings.Join(texts, "\n"))
@@ -429,7 +418,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
choice.FinishReason = constant.FinishReasonContentFilter
}
}
if is_tool_call {
if isToolCall {
choice.FinishReason = constant.FinishReasonToolCalls
}
@@ -468,7 +457,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
for _, part := range candidate.Content.Parts {
if part.FunctionCall != nil {
isTools = true
if call := getToolCall(&part); call != nil {
if call := getResponseToolCall(&part); call != nil {
call.SetIndex(len(choice.Delta.ToolCalls))
choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call)
}
+1 -1
View File
@@ -61,7 +61,7 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.RelayMode == constant.RelayModeRerank {
err, usage = jinaRerankHandler(c, resp)
err, usage = JinaRerankHandler(c, resp)
} else if info.RelayMode == constant.RelayModeEmbeddings {
err, usage = jinaEmbeddingHandler(c, resp)
}
+1 -1
View File
@@ -9,7 +9,7 @@ import (
"one-api/service"
)
func jinaRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func JinaRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+16 -15
View File
@@ -3,21 +3,22 @@ package ollama
import "one-api/dto"
type OllamaRequest struct {
Model string `json:"model,omitempty"`
Messages []dto.Message `json:"messages,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
Seed float64 `json:"seed,omitempty"`
Topp float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
Tools []dto.ToolCall `json:"tools,omitempty"`
ResponseFormat any `json:"response_format,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
Suffix any `json:"suffix,omitempty"`
StreamOptions *dto.StreamOptions `json:"stream_options,omitempty"`
Prompt any `json:"prompt,omitempty"`
Model string `json:"model,omitempty"`
Messages []dto.Message `json:"messages,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
Seed float64 `json:"seed,omitempty"`
Topp float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
Tools []dto.ToolCallRequest `json:"tools,omitempty"`
ResponseFormat any `json:"response_format,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
Suffix any `json:"suffix,omitempty"`
StreamOptions *dto.StreamOptions `json:"stream_options,omitempty"`
Prompt any `json:"prompt,omitempty"`
}
type Options struct {
+1
View File
@@ -58,6 +58,7 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, err
TopK: request.TopK,
Stop: Stop,
Tools: request.Tools,
MaxTokens: request.MaxTokens,
ResponseFormat: request.ResponseFormat,
FrequencyPenalty: request.FrequencyPenalty,
PresencePenalty: request.PresencePenalty,
+4 -1
View File
@@ -14,6 +14,7 @@ import (
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/ai360"
"one-api/relay/channel/jina"
"one-api/relay/channel/lingyiwanwu"
"one-api/relay/channel/minimax"
"one-api/relay/channel/moonshot"
@@ -146,7 +147,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, errors.New("not implemented")
return request, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
@@ -228,6 +229,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
case constant.RelayModeImagesGenerations:
err, usage = OpenaiTTSHandler(c, resp, info)
case constant.RelayModeRerank:
err, usage = jina.JinaRerankHandler(c, resp)
default:
if info.IsStream {
err, usage = OaiStreamHandler(c, resp, info)
+74
View File
@@ -0,0 +1,74 @@
package openrouter
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
req.Set("HTTP-Referer", "https://github.com/Calcium-Ion/new-api")
req.Set("X-Title", "New API")
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}
+5
View File
@@ -0,0 +1,5 @@
package openrouter
var ModelList = []string{}
var ChannelName = "openrouter"
+8 -54
View File
@@ -50,6 +50,9 @@ type RelayInfo struct {
AudioUsage bool
ReasoningEffort string
ChannelSetting map[string]interface{}
UserSetting map[string]interface{}
UserEmail string
UserQuota int
}
// 定义支持流式选项的通道类型
@@ -89,6 +92,9 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
apiType, _ := relayconstant.ChannelType2APIType(channelType)
info := &RelayInfo{
UserQuota: c.GetInt(constant.ContextKeyUserQuota),
UserSetting: c.GetStringMap(constant.ContextKeyUserSetting),
UserEmail: c.GetString(constant.ContextKeyUserEmail),
IsFirstResponse: true,
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
BaseUrl: c.GetString("base_url"),
@@ -148,19 +154,7 @@ func (info *RelayInfo) SetFirstResponseTime() {
}
type TaskRelayInfo struct {
ChannelType int
ChannelId int
TokenId int
UserId int
Group string
StartTime time.Time
ApiType int
RelayMode int
UpstreamModelName string
RequestURLPath string
ApiKey string
BaseUrl string
*RelayInfo
Action string
OriginTaskID string
@@ -168,48 +162,8 @@ type TaskRelayInfo struct {
}
func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
channelType := c.GetInt("channel_type")
channelId := c.GetInt("channel_id")
tokenId := c.GetInt("token_id")
userId := c.GetInt("id")
group := c.GetString("group")
startTime := time.Now()
apiType, _ := relayconstant.ChannelType2APIType(channelType)
info := &TaskRelayInfo{
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
BaseUrl: c.GetString("base_url"),
RequestURLPath: c.Request.URL.String(),
ChannelType: channelType,
ChannelId: channelId,
TokenId: tokenId,
UserId: userId,
Group: group,
StartTime: startTime,
ApiType: apiType,
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
}
if info.BaseUrl == "" {
info.BaseUrl = common.ChannelBaseURLs[channelType]
RelayInfo: GenRelayInfo(c),
}
return info
}
func (info *TaskRelayInfo) ToRelayInfo() *RelayInfo {
return &RelayInfo{
ChannelType: info.ChannelType,
ChannelId: info.ChannelId,
TokenId: info.TokenId,
UserId: info.UserId,
Group: info.Group,
StartTime: info.StartTime,
ApiType: info.ApiType,
RelayMode: info.RelayMode,
UpstreamModelName: info.UpstreamModelName,
RequestURLPath: info.RequestURLPath,
ApiKey: info.ApiKey,
BaseUrl: info.BaseUrl,
}
}
+3
View File
@@ -30,6 +30,7 @@ const (
APITypeMokaAI
APITypeVolcEngine
APITypeBaiduV2
APITypeOpenRouter
APITypeDummy // this one is only for count, do not add any channel after this
)
@@ -86,6 +87,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
apiType = APITypeVolcEngine
case common.ChannelTypeBaiduV2:
apiType = APITypeBaiduV2
case common.ChannelTypeOpenRouter:
apiType = APITypeOpenRouter
}
if apiType == -1 {
return APITypeOpenAI, false
+6 -7
View File
@@ -2,7 +2,6 @@ package relay
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
@@ -192,7 +191,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
if err != nil {
return &mjResp.Response
}
defer func(ctx context.Context) {
defer func() {
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
err := service.PostConsumeQuota(relayInfo, quota, 0, true)
if err != nil {
@@ -208,14 +207,14 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
other := make(map[string]interface{})
other["model_price"] = modelPrice
other["group_ratio"] = groupRatio
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName,
model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
quota, logContent, tokenId, userQuota, 0, false, group, other)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}
}(c.Request.Context())
}()
midjResponse := &mjResp.Response
midjourneyTask := &model.Midjourney{
UserId: userId,
@@ -498,7 +497,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
}
midjResponse := &midjResponseWithStatus.Response
defer func(ctx context.Context) {
defer func() {
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
err := service.PostConsumeQuota(relayInfo, quota, 0, true)
if err != nil {
@@ -510,14 +509,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
other := make(map[string]interface{})
other["model_price"] = modelPrice
other["group_ratio"] = groupRatio
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName,
model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
quota, logContent, tokenId, userQuota, 0, false, group, other)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}
}(c.Request.Context())
}()
// 文档:https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md
//1-提交成功
+2 -1
View File
@@ -248,6 +248,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
if userQuota-preConsumedQuota < 0 {
return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), "insufficient_user_quota", http.StatusForbidden)
}
relayInfo.UserQuota = userQuota
if userQuota > 100*preConsumedQuota {
// 用户额度充足,判断令牌额度是否充足
if !relayInfo.TokenUnlimited {
@@ -267,7 +268,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
}
if preConsumedQuota > 0 {
err = service.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
err := service.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
if err != nil {
return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
+3
View File
@@ -18,6 +18,7 @@ import (
"one-api/relay/channel/mokaai"
"one-api/relay/channel/ollama"
"one-api/relay/channel/openai"
"one-api/relay/channel/openrouter"
"one-api/relay/channel/palm"
"one-api/relay/channel/perplexity"
"one-api/relay/channel/siliconflow"
@@ -83,6 +84,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
return &volcengine.Adaptor{}
case constant.APITypeBaiduV2:
return &baidu_v2.Adaptor{}
case constant.APITypeOpenRouter:
return &openrouter.Adaptor{}
}
return nil
}
+4 -5
View File
@@ -2,7 +2,6 @@ package relay
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
@@ -109,11 +108,11 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
return
}
defer func(ctx context.Context) {
defer func() {
// release quota
if relayInfo.ConsumeQuota && taskErr == nil {
err := service.PostConsumeQuota(relayInfo.ToRelayInfo(), quota, 0, true)
err := service.PostConsumeQuota(relayInfo.RelayInfo, quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
@@ -123,13 +122,13 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
other := make(map[string]interface{})
other["model_price"] = modelPrice
other["group_ratio"] = groupRatio
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, 0, 0,
model.RecordConsumeLog(c, relayInfo.UserId, relayInfo.ChannelId, 0, 0,
modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, relayInfo.Group, other)
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
}
}
}(c.Request.Context())
}()
taskID, taskData, taskErr := adaptor.DoResponse(c, resp, relayInfo)
if taskErr != nil {
+1
View File
@@ -24,6 +24,7 @@ func SetRelayRouter(router *gin.Engine) {
}
relayV1Router := router.Group("/v1")
relayV1Router.Use(middleware.TokenAuth())
relayV1Router.Use(middleware.ModelRequestRateLimit())
{
// WebSocket 路由
wsRouter := relayV1Router.Group("")
+7 -11
View File
@@ -276,7 +276,7 @@ func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQu
if quota > 0 {
err = model.DecreaseUserQuota(relayInfo.UserId, quota)
} else {
err = model.IncreaseUserQuota(relayInfo.UserId, -quota)
err = model.IncreaseUserQuota(relayInfo.UserId, -quota, false)
}
if err != nil {
return err
@@ -295,20 +295,16 @@ func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQu
if sendEmail {
if (quota + preConsumedQuota) != 0 {
checkAndSendQuotaNotify(relayInfo.UserId, quota, preConsumedQuota)
checkAndSendQuotaNotify(relayInfo, quota, preConsumedQuota)
}
}
return nil
}
func checkAndSendQuotaNotify(userId int, quota int, preConsumedQuota int) {
func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQuota int) {
gopool.Go(func() {
userCache, err := model.GetUserCache(userId)
if err != nil {
common.SysError("failed to get user cache: " + err.Error())
}
userSetting := userCache.GetSetting()
userSetting := relayInfo.UserSetting
threshold := common.QuotaRemindThreshold
if userCustomThreshold, ok := userSetting[constant2.UserSettingQuotaWarningThreshold]; ok {
threshold = int(userCustomThreshold.(float64))
@@ -317,16 +313,16 @@ func checkAndSendQuotaNotify(userId int, quota int, preConsumedQuota int) {
//noMoreQuota := userCache.Quota-(quota+preConsumedQuota) <= 0
quotaTooLow := false
consumeQuota := quota + preConsumedQuota
if userCache.Quota-consumeQuota < threshold {
if relayInfo.UserQuota-consumeQuota < threshold {
quotaTooLow = true
}
if quotaTooLow {
prompt := "您的额度即将用尽"
topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
content := "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。<br/>充值链接:<a href='{{value}}'>{{value}}</a>"
err = NotifyUser(userCache, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, common.FormatQuota(userCache.Quota), topUpLink, topUpLink}))
err := NotifyUser(relayInfo.UserId, relayInfo.UserEmail, relayInfo.UserSetting, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, common.FormatQuota(relayInfo.UserQuota), topUpLink, topUpLink}))
if err != nil {
common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", userId, err.Error()))
common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", relayInfo.UserId, err.Error()))
}
}
})
+22 -30
View File
@@ -1,7 +1,6 @@
package service
import (
"encoding/json"
"errors"
"fmt"
"image"
@@ -78,6 +77,9 @@ func getTokenEncoder(model string) *tiktoken.Tiktoken {
}
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
if text == "" {
return 0
}
return len(tokenEncoder.Encode(text, nil, nil))
}
@@ -167,12 +169,7 @@ func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenA
}
tkm += msgTokens
if request.Tools != nil {
toolsData, _ := json.Marshal(request.Tools)
var openaiTools []dto.OpenAITools
err := json.Unmarshal(toolsData, &openaiTools)
if err != nil {
return 0, errors.New(fmt.Sprintf("count_tools_token_fail: %s", err.Error()))
}
openaiTools := request.Tools
countStr := ""
for _, tool := range openaiTools {
countStr = tool.Function.Name
@@ -282,30 +279,25 @@ func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, mod
tokenNum += tokensPerMessage
tokenNum += getTokenNum(tokenEncoder, message.Role)
if len(message.Content) > 0 {
if message.IsStringContent() {
stringContent := message.StringContent()
tokenNum += getTokenNum(tokenEncoder, stringContent)
if message.Name != nil {
tokenNum += tokensPerName
tokenNum += getTokenNum(tokenEncoder, *message.Name)
}
} else {
arrayContent := message.ParseContent()
for _, m := range arrayContent {
if m.Type == dto.ContentTypeImageURL {
imageUrl := m.ImageUrl.(dto.MessageImageUrl)
imageTokenNum, err := getImageToken(info, &imageUrl, model, stream)
if err != nil {
return 0, err
}
tokenNum += imageTokenNum
log.Printf("image token num: %d", imageTokenNum)
} else if m.Type == dto.ContentTypeInputAudio {
// TODO: 音频token数量计算
tokenNum += 100
} else {
tokenNum += getTokenNum(tokenEncoder, m.Text)
if message.Name != nil {
tokenNum += tokensPerName
tokenNum += getTokenNum(tokenEncoder, *message.Name)
}
arrayContent := message.ParseContent()
for _, m := range arrayContent {
if m.Type == dto.ContentTypeImageURL {
imageUrl := m.ImageUrl.(dto.MessageImageUrl)
imageTokenNum, err := getImageToken(info, &imageUrl, model, stream)
if err != nil {
return 0, err
}
tokenNum += imageTokenNum
log.Printf("image token num: %d", imageTokenNum)
} else if m.Type == dto.ContentTypeInputAudio {
// TODO: 音频token数量计算
tokenNum += 100
} else {
tokenNum += getTokenNum(tokenEncoder, m.Text)
}
}
}
+7 -9
View File
@@ -11,47 +11,45 @@ import (
func NotifyRootUser(t string, subject string, content string) {
user := model.GetRootUser().ToBaseUser()
_ = NotifyUser(user, dto.NewNotify(t, subject, content, nil))
_ = NotifyUser(user.Id, user.Email, user.GetSetting(), dto.NewNotify(t, subject, content, nil))
}
func NotifyUser(user *model.UserBase, data dto.Notify) error {
userSetting := user.GetSetting()
func NotifyUser(userId int, userEmail string, userSetting map[string]interface{}, data dto.Notify) error {
notifyType, ok := userSetting[constant.UserSettingNotifyType]
if !ok {
notifyType = constant.NotifyTypeEmail
}
// Check notification limit
canSend, err := CheckNotificationLimit(user.Id, data.Type)
canSend, err := CheckNotificationLimit(userId, data.Type)
if err != nil {
common.SysError(fmt.Sprintf("failed to check notification limit: %s", err.Error()))
return err
}
if !canSend {
return fmt.Errorf("notification limit exceeded for user %d with type %s", user.Id, notifyType)
return fmt.Errorf("notification limit exceeded for user %d with type %s", userId, notifyType)
}
switch notifyType {
case constant.NotifyTypeEmail:
userEmail := user.Email
// check setting email
if settingEmail, ok := userSetting[constant.UserSettingNotificationEmail]; ok {
userEmail = settingEmail.(string)
}
if userEmail == "" {
common.SysLog(fmt.Sprintf("user %d has no email, skip sending email", user.Id))
common.SysLog(fmt.Sprintf("user %d has no email, skip sending email", userId))
return nil
}
return sendEmailNotify(userEmail, data)
case constant.NotifyTypeWebhook:
webhookURL, ok := userSetting[constant.UserSettingWebhookUrl]
if !ok {
common.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", user.Id))
common.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", userId))
return nil
}
webhookURLStr, ok := webhookURL.(string)
if !ok {
common.SysError(fmt.Sprintf("user %d webhook url is not string type", user.Id))
common.SysError(fmt.Sprintf("user %d webhook url is not string type", userId))
return nil
}
+1
View File
@@ -0,0 +1 @@
package model_setting
+83
View File
@@ -0,0 +1,83 @@
package model_setting
import (
"encoding/json"
"one-api/common"
)
var geminiSafetySettings = map[string]string{
"default": "OFF",
"HARM_CATEGORY_CIVIC_INTEGRITY": "BLOCK_NONE",
}
func GetGeminiSafetySetting(key string) string {
if value, ok := geminiSafetySettings[key]; ok {
return value
}
return geminiSafetySettings["default"]
}
func GeminiSafetySettingFromJsonString(jsonString string) {
geminiSafetySettings = map[string]string{}
err := json.Unmarshal([]byte(jsonString), &geminiSafetySettings)
if err != nil {
geminiSafetySettings = map[string]string{
"default": "OFF",
"HARM_CATEGORY_CIVIC_INTEGRITY": "BLOCK_NONE",
}
}
// check must have default
if _, ok := geminiSafetySettings["default"]; !ok {
geminiSafetySettings["default"] = common.GeminiSafetySetting
}
}
func GeminiSafetySettingsJsonString() string {
// check must have default
if _, ok := geminiSafetySettings["default"]; !ok {
geminiSafetySettings["default"] = common.GeminiSafetySetting
}
jsonString, err := json.Marshal(geminiSafetySettings)
if err != nil {
return "{}"
}
return string(jsonString)
}
var geminiVersionSettings = map[string]string{
"default": "v1beta",
"gemini-1.0-pro": "v1",
}
func GetGeminiVersionSetting(key string) string {
if value, ok := geminiVersionSettings[key]; ok {
return value
}
return geminiVersionSettings["default"]
}
func GeminiVersionSettingFromJsonString(jsonString string) {
geminiVersionSettings = map[string]string{}
err := json.Unmarshal([]byte(jsonString), &geminiVersionSettings)
if err != nil {
geminiVersionSettings = map[string]string{
"default": "v1beta",
}
}
// check must have default
if _, ok := geminiVersionSettings["default"]; !ok {
geminiVersionSettings["default"] = "v1beta"
}
}
func GeminiVersionSettingsJsonString() string {
// check must have default
if _, ok := geminiVersionSettings["default"]; !ok {
geminiVersionSettings["default"] = "v1beta"
}
jsonString, err := json.Marshal(geminiVersionSettings)
if err != nil {
return "{}"
}
return string(jsonString)
}
+6
View File
@@ -0,0 +1,6 @@
package setting
var ModelRequestRateLimitEnabled = false
var ModelRequestRateLimitDurationMinutes = 1
var ModelRequestRateLimitCount = 0
var ModelRequestRateLimitSuccessCount = 1000
+73
View File
@@ -0,0 +1,73 @@
import React, { useEffect, useState } from 'react';
import { Card, Spin, Tabs } from '@douyinfe/semi-ui';
import { API, showError, showSuccess } from '../helpers';
import SettingsChats from '../pages/Setting/Operation/SettingsChats.js';
import { useTranslation } from 'react-i18next';
import SettingGeminiModel from '../pages/Setting/Model/SettingGeminiModel.js';
const ModelSetting = () => {
const { t } = useTranslation();
let [inputs, setInputs] = useState({
GeminiSafetySettings: '',
GeminiVersionSettings: '',
});
let [loading, setLoading] = useState(false);
const getOptions = async () => {
const res = await API.get('/api/option/');
const { success, message, data } = res.data;
if (success) {
let newInputs = {};
data.forEach((item) => {
if (
item.key === 'GeminiSafetySettings' ||
item.key === 'GeminiVersionSettings'
) {
item.value = JSON.stringify(JSON.parse(item.value), null, 2);
}
if (
item.key.endsWith('Enabled')
) {
newInputs[item.key] = item.value === 'true' ? true : false;
} else {
newInputs[item.key] = item.value;
}
});
setInputs(newInputs);
} else {
showError(message);
}
};
async function onRefresh() {
try {
setLoading(true);
await getOptions();
// showSuccess('刷新成功');
} catch (error) {
showError('刷新失败');
} finally {
setLoading(false);
}
}
useEffect(() => {
onRefresh();
}, []);
return (
<>
<Spin spinning={loading} size='large'>
{/* Gemini */}
<Card style={{ marginTop: '10px' }}>
<SettingGeminiModel options={inputs} refresh={onRefresh} />
</Card>
</Spin>
</>
);
};
export default ModelSetting;
+80
View File
@@ -0,0 +1,80 @@
import React, { useEffect, useState } from 'react';
import { Card, Spin, Tabs } from '@douyinfe/semi-ui';
import SettingsGeneral from '../pages/Setting/Operation/SettingsGeneral.js';
import SettingsDrawing from '../pages/Setting/Operation/SettingsDrawing.js';
import SettingsSensitiveWords from '../pages/Setting/Operation/SettingsSensitiveWords.js';
import SettingsLog from '../pages/Setting/Operation/SettingsLog.js';
import SettingsDataDashboard from '../pages/Setting/Operation/SettingsDataDashboard.js';
import SettingsMonitoring from '../pages/Setting/Operation/SettingsMonitoring.js';
import SettingsCreditLimit from '../pages/Setting/Operation/SettingsCreditLimit.js';
import SettingsMagnification from '../pages/Setting/Operation/SettingsMagnification.js';
import ModelSettingsVisualEditor from '../pages/Setting/Operation/ModelSettingsVisualEditor.js';
import GroupRatioSettings from '../pages/Setting/Operation/GroupRatioSettings.js';
import ModelRatioSettings from '../pages/Setting/Operation/ModelRatioSettings.js';
import { API, showError, showSuccess } from '../helpers';
import SettingsChats from '../pages/Setting/Operation/SettingsChats.js';
import { useTranslation } from 'react-i18next';
import RequestRateLimit from '../pages/Setting/RateLimit/SettingsRequestRateLimit.js';
const RateLimitSetting = () => {
const { t } = useTranslation();
let [inputs, setInputs] = useState({
ModelRequestRateLimitEnabled: false,
ModelRequestRateLimitCount: 0,
ModelRequestRateLimitSuccessCount: 1000,
ModelRequestRateLimitDurationMinutes: 1,
});
let [loading, setLoading] = useState(false);
const getOptions = async () => {
const res = await API.get('/api/option/');
const { success, message, data } = res.data;
if (success) {
let newInputs = {};
data.forEach((item) => {
if (
item.key.endsWith('Enabled')
) {
newInputs[item.key] = item.value === 'true' ? true : false;
} else {
newInputs[item.key] = item.value;
}
});
setInputs(newInputs);
} else {
showError(message);
}
};
async function onRefresh() {
try {
setLoading(true);
await getOptions();
// showSuccess('刷新成功');
} catch (error) {
showError('刷新失败');
} finally {
setLoading(false);
}
}
useEffect(() => {
onRefresh();
}, []);
return (
<>
<Spin spinning={loading} size='large'>
{/* AI请求速率限制 */}
<Card style={{ marginTop: '10px' }}>
<RequestRateLimit options={inputs} refresh={onRefresh} />
</Card>
</Spin>
</>
);
};
export default RateLimitSetting;
+13 -3
View File
@@ -856,7 +856,7 @@
"IP黑名单": "IP blacklist",
"不允许的IP,一行一个": "IPs not allowed, one per line",
"请选择该渠道所支持的模型": "Please select the model supported by this channel",
"次": "Second-rate",
"次": "times",
"达到限速报错内容": "Error content when the speed limit is reached",
"不填则使用默认报错": "If not filled in, the default error will be reported.",
"Midjouney 设置 (可选)": "Midjouney settings (optional)",
@@ -1271,5 +1271,15 @@
"留空则使用账号绑定的邮箱": "If left blank, the email address bound to the account will be used",
"代理站地址": "Base URL",
"对于官方渠道,new-api已经内置地址,除非是第三方代理站点或者Azure的特殊接入地址,否则不需要填写": "For official channels, the new-api has a built-in address. Unless it is a third-party proxy site or a special Azure access address, there is no need to fill it in",
"渠道额外设置": "Channel extra settings"
}
"渠道额外设置": "Channel extra settings",
"模型请求速率限制": "Model request rate limit",
"启用用户模型请求速率限制(可能会影响高并发性能)": "Enable user model request rate limit (may affect high concurrency performance)",
"限制周期": "Limit period",
"用户每周期最多请求次数": "User max request times per period",
"用户每周期最多请求完成次数": "User max successful request times per period",
"包括失败请求的次数,0代表不限制": "Including failed request times, 0 means no limit",
"频率限制的周期(分钟)": "Rate limit period (minutes)",
"只包括请求成功的次数": "Only include successful request times",
"保存模型速率限制": "Save model rate limit settings",
"速率限制设置": "Rate limit settings"
}
@@ -0,0 +1,139 @@
import React, { useEffect, useState, useRef } from 'react';
import { Button, Col, Form, Row, Spin } from '@douyinfe/semi-ui';
import {
compareObjects,
API,
showError,
showSuccess,
showWarning, verifyJSON
} from '../../../helpers';
import { useTranslation } from 'react-i18next';
const GEMINI_SETTING_EXAMPLE = {
'default': 'OFF',
'HARM_CATEGORY_CIVIC_INTEGRITY': 'BLOCK_NONE',
};
const GEMINI_VERSION_EXAMPLE = {
'default': 'v1beta',
};
export default function SettingGeminiModel(props) {
const { t } = useTranslation();
const [loading, setLoading] = useState(false);
const [inputs, setInputs] = useState({
GeminiSafetySettings: '',
GeminiVersionSettings: '',
});
const refForm = useRef();
const [inputsRow, setInputsRow] = useState(inputs);
function onSubmit() {
const updateArray = compareObjects(inputs, inputsRow);
if (!updateArray.length) return showWarning(t('你似乎并没有修改什么'));
const requestQueue = updateArray.map((item) => {
let value = '';
if (typeof inputs[item.key] === 'boolean') {
value = String(inputs[item.key]);
} else {
value = inputs[item.key];
}
return API.put('/api/option/', {
key: item.key,
value,
});
});
setLoading(true);
Promise.all(requestQueue)
.then((res) => {
if (requestQueue.length === 1) {
if (res.includes(undefined)) return;
} else if (requestQueue.length > 1) {
if (res.includes(undefined)) return showError(t('部分保存失败,请重试'));
}
showSuccess(t('保存成功'));
props.refresh();
})
.catch(() => {
showError(t('保存失败,请重试'));
})
.finally(() => {
setLoading(false);
});
}
useEffect(() => {
const currentInputs = {};
for (let key in props.options) {
if (Object.keys(inputs).includes(key)) {
currentInputs[key] = props.options[key];
}
}
setInputs(currentInputs);
setInputsRow(structuredClone(currentInputs));
refForm.current.setValues(currentInputs);
}, [props.options]);
return (
<>
<Spin spinning={loading}>
<Form
values={inputs}
getFormApi={(formAPI) => (refForm.current = formAPI)}
style={{ marginBottom: 15 }}
>
<Form.Section text={t('Gemini设置')}>
<Row>
<Col span={16}>
<Form.TextArea
label={t('Gemini安全设置')}
placeholder={t('为一个 JSON 文本,例如:') + '\n' + JSON.stringify(GEMINI_SETTING_EXAMPLE, null, 2)}
field={'GeminiSafetySettings'}
extraText={t('default为默认设置,可单独设置每个分类的安全等级')}
autosize={{ minRows: 6, maxRows: 12 }}
trigger='blur'
stopValidateWithError
rules={[
{
validator: (rule, value) => verifyJSON(value),
message: t('不是合法的 JSON 字符串')
}
]}
onChange={(value) => setInputs({ ...inputs, GeminiSafetySettings: value })}
/>
</Col>
</Row>
<Row>
<Col span={16}>
<Form.TextArea
label={t('Gemini版本设置')}
placeholder={t('为一个 JSON 文本,例如:') + '\n' + JSON.stringify(GEMINI_VERSION_EXAMPLE, null, 2)}
field={'GeminiVersionSettings'}
extraText={t('default为默认设置,可单独设置每个模型的版本')}
autosize={{ minRows: 6, maxRows: 12 }}
trigger='blur'
stopValidateWithError
rules={[
{
validator: (rule, value) => verifyJSON(value),
message: t('不是合法的 JSON 字符串')
}
]}
onChange={(value) => setInputs({ ...inputs, GeminiVersionSettings: value })}
/>
</Col>
</Row>
<Row>
<Button size='default' onClick={onSubmit}>
{t('保存')}
</Button>
</Row>
</Form.Section>
</Form>
</Spin>
</>
);
}
@@ -0,0 +1,159 @@
import React, { useEffect, useState, useRef } from 'react';
import { Button, Col, Form, Row, Spin } from '@douyinfe/semi-ui';
import {
compareObjects,
API,
showError,
showSuccess,
showWarning,
} from '../../../helpers';
import { useTranslation } from 'react-i18next';
export default function RequestRateLimit(props) {
const { t } = useTranslation();
const [loading, setLoading] = useState(false);
const [inputs, setInputs] = useState({
ModelRequestRateLimitEnabled: false,
ModelRequestRateLimitCount: -1,
ModelRequestRateLimitSuccessCount: 1000,
ModelRequestRateLimitDurationMinutes: 1
});
const refForm = useRef();
const [inputsRow, setInputsRow] = useState(inputs);
function onSubmit() {
const updateArray = compareObjects(inputs, inputsRow);
if (!updateArray.length) return showWarning(t('你似乎并没有修改什么'));
const requestQueue = updateArray.map((item) => {
let value = '';
if (typeof inputs[item.key] === 'boolean') {
value = String(inputs[item.key]);
} else {
value = inputs[item.key];
}
return API.put('/api/option/', {
key: item.key,
value,
});
});
setLoading(true);
Promise.all(requestQueue)
.then((res) => {
if (requestQueue.length === 1) {
if (res.includes(undefined)) return;
} else if (requestQueue.length > 1) {
if (res.includes(undefined)) return showError(t('部分保存失败,请重试'));
}
showSuccess(t('保存成功'));
props.refresh();
})
.catch(() => {
showError(t('保存失败,请重试'));
})
.finally(() => {
setLoading(false);
});
}
useEffect(() => {
const currentInputs = {};
for (let key in props.options) {
if (Object.keys(inputs).includes(key)) {
currentInputs[key] = props.options[key];
}
}
setInputs(currentInputs);
setInputsRow(structuredClone(currentInputs));
refForm.current.setValues(currentInputs);
}, [props.options]);
return (
<>
<Spin spinning={loading}>
<Form
values={inputs}
getFormApi={(formAPI) => (refForm.current = formAPI)}
style={{ marginBottom: 15 }}
>
<Form.Section text={t('模型请求速率限制')}>
<Row gutter={16}>
<Col span={8}>
<Form.Switch
field={'ModelRequestRateLimitEnabled'}
label={t('启用用户模型请求速率限制(可能会影响高并发性能)')}
size='default'
checkedText=''
uncheckedText=''
onChange={(value) => {
setInputs({
...inputs,
ModelRequestRateLimitEnabled: value,
});
}}
/>
</Col>
</Row>
<Row>
<Col span={8}>
<Form.InputNumber
label={t('限制周期')}
step={1}
min={0}
suffix={t('分钟')}
extraText={t('频率限制的周期(分钟)')}
field={'ModelRequestRateLimitDurationMinutes'}
onChange={(value) =>
setInputs({
...inputs,
ModelRequestRateLimitDurationMinutes: String(value),
})
}
/>
</Col>
</Row>
<Row>
<Col span={8}>
<Form.InputNumber
label={t('用户每周期最多请求次数')}
step={1}
min={0}
suffix={t('次')}
extraText={t('包括失败请求的次数,0代表不限制')}
field={'ModelRequestRateLimitCount'}
onChange={(value) =>
setInputs({
...inputs,
ModelRequestRateLimitCount: String(value),
})
}
/>
</Col>
<Col span={8}>
<Form.InputNumber
label={t('用户每周期最多请求完成次数')}
step={1}
min={1}
suffix={t('次')}
extraText={t('只包括请求成功的次数')}
field={'ModelRequestRateLimitSuccessCount'}
onChange={(value) =>
setInputs({
...inputs,
ModelRequestRateLimitSuccessCount: String(value),
})
}
/>
</Col>
</Row>
<Row>
<Button size='default' onClick={onSubmit}>
{t('保存模型速率限制')}
</Button>
</Row>
</Form.Section>
</Form>
</Spin>
</>
);
}
+12
View File
@@ -8,6 +8,8 @@ import { isRoot } from '../../helpers';
import OtherSetting from '../../components/OtherSetting';
import PersonalSetting from '../../components/PersonalSetting';
import OperationSetting from '../../components/OperationSetting';
import RateLimitSetting from '../../components/RateLimitSetting.js';
import ModelSetting from '../../components/ModelSetting.js';
const Setting = () => {
const { t } = useTranslation();
@@ -28,6 +30,16 @@ const Setting = () => {
content: <OperationSetting />,
itemKey: 'operation',
});
panes.push({
tab: t('速率限制设置'),
content: <RateLimitSetting />,
itemKey: 'ratelimit',
});
panes.push({
tab: t('模型相关设置'),
content: <ModelSetting />,
itemKey: 'models',
});
panes.push({
tab: t('系统设置'),
content: <SystemSetting />,