Compare commits

...

74 Commits

Author SHA1 Message Date
1808837298@qq.com 0a369cc193 feat: Add Gemini safety settings configuration support (close #703) 2025-02-26 16:54:43 +08:00
1808837298@qq.com 5ba44f5ad5 feat: Update Claude relay temperature setting 2025-02-25 22:01:05 +08:00
1808837298@qq.com d04d78a116 refactor: Enhance user context and quota management
- Add new context keys for user-related information
- Modify user cache and authentication middleware to populate context
- Refactor quota and notification services to use context-based user data
- Remove redundant database queries by leveraging context information
- Update various components to use new context-based user retrieval methods
2025-02-25 20:56:16 +08:00
1808837298@qq.com 8c2323d74d feat: redis poolsize 2025-02-25 19:39:29 +08:00
1808837298@qq.com 583678d9ff fix: Adjust Claude thinking mode request parameters 2025-02-25 16:52:45 +08:00
1808837298@qq.com fd38e59f78 docs: Update README 2025-02-25 16:31:42 +08:00
Calcium-Ion f5cbab77cf Merge pull request #788 from MartialBE/main
feat: Add Claude 3.7 Sonnet thinking mode support
2025-02-25 15:21:39 +08:00
1808837298@qq.com d4706d6b8e Merge branch 'main' into thinking
# Conflicts:
#	relay/channel/claude/dto.go
2025-02-25 15:21:22 +08:00
1808837298@qq.com 6c8016e5f8 feat: Add support for Claude thinking parameter in request 2025-02-25 14:37:03 +08:00
MartialBE 7160012fe2 feat: Add Claude 3.7 Sonnet thinking mode support 2025-02-25 14:10:43 +08:00
1808837298@qq.com c62276fcc4 feat: Add Claude 3.7 Sonnet model to AWS channel mapping 2025-02-25 02:55:23 +08:00
1808837298@qq.com 15a3b44689 feat: Add support for Claude 3.7 Sonnet model 2025-02-25 02:51:31 +08:00
1808837298@qq.com 8f3c7280cf feat: Support max_tokens parameter for Ollama channel #782 2025-02-24 17:35:49 +08:00
Calcium-Ion e5e73a33f0 Merge pull request #781 from zeyugao/main
feat: Pass extra_body in OpenAI request to the backend
2025-02-24 16:29:48 +08:00
Calcium-Ion 2d15f63eaa Merge pull request #783 from Calcium-Ion/rate-limit
feat: Add model request rate limiting functionality
2025-02-24 16:29:23 +08:00
1808837298@qq.com 6f3072895a feat: Add model rate limit settings in system configuration 2025-02-24 16:27:20 +08:00
1808837298@qq.com 1763145fea feat: Add model request rate limiting functionality 2025-02-24 16:20:55 +08:00
1808837298@qq.com 66831a1bde feat: Add support for different Dify bot types and request URLs 2025-02-24 14:18:30 +08:00
1808837298@qq.com fd44ac7c0c feat: Enhance token counting and content parsing for messages 2025-02-24 14:18:15 +08:00
Elsa f5bf67c636 Pass extra_body to the backend 2025-02-24 10:52:55 +08:00
1808837298@qq.com 7becf62a7a fix: Improve 429 error logging with detailed message 2025-02-23 21:26:31 +08:00
1808837298@qq.com 40c0333eaa fix typo 2025-02-23 17:27:33 +08:00
1808837298@qq.com 65021e2e0e feat: Add thinking-to-content option in channel extra settings #780 2025-02-23 17:13:08 +08:00
1808837298@qq.com 4597816a14 feat: Add thinking-to-content conversion for stream responses 2025-02-23 17:05:57 +08:00
1808837298@qq.com 991b6f8bb0 fix: mistral 2025-02-22 16:29:48 +08:00
1808837298@qq.com 0333576bee fix: fix image ratio calculation 2025-02-22 15:50:18 +08:00
Calcium-Ion 1c35a2fd0b Merge pull request #778 from utopeadia/main
美化日志界面刷新图标
2025-02-22 15:21:28 +08:00
1808837298@qq.com 7a13f9c99a fix: Ensure correct quota warning threshold type conversion 2025-02-22 15:19:55 +08:00
1808837298@qq.com 1c20d16c49 chore: update rerank.md 2025-02-22 15:13:26 +08:00
HowieWood 4f2d44187d 进一步美化刷新图标 2025-02-22 14:18:25 +08:00
HowieWood d3b8647019 优化日志刷新图标显示 2025-02-22 14:12:49 +08:00
1808837298@qq.com 3eb186825d fix: ShouldDisableChannel 2025-02-22 02:02:03 +08:00
1808837298@qq.com d71315cfa5 fix: mistral adaptor (close #774) 2025-02-21 22:21:19 +08:00
1808837298@qq.com 8101cd3ce3 feat: Add reasoning content support in OpenAI response handling 2025-02-21 18:52:51 +08:00
1808837298@qq.com 4a49d6c795 refactor: Improve message content parsing with robust type handling 2025-02-21 18:27:43 +08:00
1808837298@qq.com 4194f4bd21 refactor: Improve message content handling and quota error responses 2025-02-21 18:18:21 +08:00
1808837298@qq.com e1784f8981 refactor: Optimize sensitive word detection and text processing 2025-02-21 17:05:35 +08:00
1808837298@qq.com 78f9a30c39 feat: Enhance sensitive word detection with detailed logging 2025-02-21 16:57:30 +08:00
1808837298@qq.com 009333da8b refactor: Improve quota error messages with formatted quota display 2025-02-21 16:42:48 +08:00
1808837298@qq.com 23bfc06fd8 feat: Add base URL input with localized tooltip for channel configuration 2025-02-21 16:17:59 +08:00
1808837298@qq.com f64540cd1c feat: Add localization for notification and webhook settings 2025-02-21 15:36:24 +08:00
Calcium-Ion 5020091714 Merge pull request #775 from Calcium-Ion/model_mappping
refactor: Simplify model mapping and pricing logic across relay modules
2025-02-20 16:42:23 +08:00
1808837298@qq.com 1e0b414fc0 refactor: Simplify model mapping and pricing logic across relay modules 2025-02-20 16:41:46 +08:00
1808837298@qq.com 8dbac87e92 fix: Correct Ollama channel authentication header setting 2025-02-20 01:28:15 +08:00
Calcium-Ion db81b21f06 Merge pull request #773 from wellcoming/patch-1
fix: Fix Ollama channel authentication
2025-02-20 01:26:12 +08:00
Coming 368b5fbaa1 fix: Fix Ollama channel authentication 2025-02-20 00:52:30 +08:00
CalciumIon 1f6cb4bd5d feat: Improve mobile text truncation and sidebar visibility 2025-02-19 23:25:42 +08:00
1808837298@qq.com 5ef44f6898 feat: Improve image handling for Ollama channels 2025-02-19 20:45:42 +08:00
1808837298@qq.com 604a5b1ad5 feat: Enhance Ollama channel support with additional request parameters #771 2025-02-19 19:58:34 +08:00
1808837298@qq.com cc61f85bdd fix: Remove redundant error handling in distributor and relay modules 2025-02-19 18:47:28 +08:00
1808837298@qq.com 4ea2556b29 refactor: Replace manual goroutine creation with gopool.Go 2025-02-19 18:38:29 +08:00
Calcium-Ion ff66890dd2 Merge pull request #770 from Calcium-Ion/refactor_notify
feat: Add user notification settings and multiple notification methods
2025-02-19 14:54:54 +07:00
1808837298@qq.com 8a34d61654 chore: update env name and README 2025-02-19 15:54:33 +08:00
1808837298@qq.com 67f02e0a6a docs: Add proxy usage information note in SystemSetting component 2025-02-19 15:45:09 +08:00
1808837298@qq.com 097d02eb8b feat: Implement comprehensive webhook notification system 2025-02-19 15:40:54 +08:00
1808837298@qq.com 1ae0a38485 refactor: Optimize user caching and token retrieval methods 2025-02-19 15:12:26 +08:00
Calcium-Ion 45dd96e717 Merge pull request #768 from lgphone/main
bugfix: 配置文件 .env.example 示例配置错误
2025-02-18 19:35:08 +07:00
lgphone 10311f60e1 Update .env.example
修复示例配置中MySQL的DSN错误问题
2025-02-18 19:18:54 +08:00
Calcium-Ion 869fe957ae Merge pull request #763 from Sh1n3zZ/support-imagen-3.0-generate-002
feat: add Gemini Imagen image generation support
2025-02-18 15:32:32 +07:00
1808837298@qq.com af9d03140c fix: Extend temperature handling for OpenAI-like models
- Add support for suppressing temperature for o1 models
- Expand model prefix check to include 'o1' alongside 'o3' models
2025-02-18 16:00:56 +08:00
1808837298@qq.com 83e161a1d4 refactor: Simplify root user notification and remove global email variable
- Remove global `RootUserEmail` variable
- Modify channel testing and user notification methods to use `GetRootUser()`
- Update user cache and notification service to use more consistent user base type
- Add new channel test notification type
2025-02-18 15:59:17 +08:00
1808837298@qq.com 9a41c04416 feat: Implement notification rate limiting mechanism
- Add in-memory and Redis-based notification rate limiting
- Create configurable hourly notification limits
- Implement notification limit checking for user notifications
- Add environment variables for customizing notification limits
2025-02-18 15:30:43 +08:00
1808837298@qq.com e0ce8bf2b3 refactor: Improve CompletionRatio handling with thread-safe access and initialization 2025-02-18 15:01:43 +08:00
1808837298@qq.com 0fcd243f56 feat: Add user notification settings with quota warning and multiple notification methods
- Implement user notification settings with email and webhook options
- Add new user settings for quota warning threshold and notification preferences
- Create backend API and database support for user notification configuration
- Enhance frontend personal settings with notification configuration UI
- Support custom notification email and webhook URL
- Add service layer for sending user notifications
2025-02-18 14:54:21 +08:00
Sh1n3zZ 873a79f28f feat: add Gemini Imagen image generation support 2025-02-18 01:41:58 +08:00
1808837298@qq.com a353ea3eee Merge remote-tracking branch 'origin/main' 2025-02-17 18:15:13 +08:00
1808837298@qq.com 1cfe06c3d8 feat: Add support for DeepSeek completions endpoint 2025-02-17 18:15:01 +08:00
Calcium-Ion e661578515 Merge pull request #735 from jyc001/main
feat:Add Supoorts to FIM
2025-02-17 14:37:06 +07:00
1808837298@qq.com 65faa158b6 refactor: Optimize channel testing and model menu generation (fix #761) 2025-02-15 19:12:28 +08:00
e. 2c255b6598 Merge pull request #2 from jyc001/dev
fix: correct JSON tags for `Prompt` and `Suffix` in `GeneralOpenAIReq…
2025-02-08 00:37:37 +08:00
e. bec1b752d6 fix: correct JSON tags for Prompt and Suffix in GeneralOpenAIRequest 2025-02-08 00:36:42 +08:00
e. ca10ec1a0c Merge pull request #1 from jyc001/dev
Dev
2025-02-08 00:25:49 +08:00
e. cbca378e61 feat: add Suffix to GeneralOpenAIRequest in order to support FIM 2025-02-08 00:25:08 +08:00
e. 02bb1c6580 feat add FIM support for siliconflow 2025-02-08 00:23:35 +08:00
97 changed files with 3022 additions and 1035 deletions
+2 -2
View File
@@ -10,9 +10,9 @@
# 数据库相关配置
# 数据库连接字符串
# SQL_DSN=mysql://user:password@tcp(127.0.0.1:3306)/dbname?parseTime=true
# SQL_DSN=user:password@tcp(127.0.0.1:3306)/dbname?parseTime=true
# 日志数据库连接字符串
# LOG_SQL_DSN=mysql://user:password@tcp(127.0.0.1:3306)/logdb?parseTime=true
# LOG_SQL_DSN=user:password@tcp(127.0.0.1:3306)/logdb?parseTime=true
# SQLite数据库路径
# SQLITE_PATH=/path/to/sqlite.db
# 数据库最大空闲连接数
+4
View File
@@ -63,6 +63,8 @@
- Add suffix `-high` to set high reasoning effort (e.g., `o3-mini-high`)
- Add suffix `-medium` to set medium reasoning effort
- Add suffix `-low` to set low reasoning effort
17. 🔄 Thinking to content option `thinking_to_content` in `Channel->Edit->Channel Extra Settings`, default is `false`, when `true`, the `reasoning_content` of the thinking content will be converted to `<think>` tags and concatenated to the content returned.
18. 🔄 Model rate limit, support setting total request limit and successful request limit in `System Settings->Rate Limit Settings`
## Model Support
This version additionally supports:
@@ -89,6 +91,8 @@ You can add custom models gpt-4-gizmo-* in channels. These are third-party model
- `MAX_FILE_DOWNLOAD_MB`: Maximum file download size in MB, default `20`
- `CRYPTO_SECRET`: Encryption key for encrypting database content
- `AZURE_DEFAULT_API_VERSION`: Azure channel default API version, if not specified in channel settings, use this version, default `2024-12-01-preview`
- `NOTIFICATION_LIMIT_DURATION_MINUTE`: Duration of notification limit in minutes, default `10`
- `NOTIFY_LIMIT_COUNT`: Maximum number of user notifications in the specified duration, default `2`
## Deployment
+11 -3
View File
@@ -66,9 +66,14 @@
15.**[OpenAI Realtime API](https://platform.openai.com/docs/guides/realtime/integration)** - 支持OpenAI的Realtime API,支持Azure渠道
16. 支持使用路由/chat2link 进入聊天界面
17. 🧠 支持通过模型名称后缀设置 reasoning effort
- 添加后缀 `-high` 设置为 high reasoning effort (例如: `o3-mini-high`)
- 添加后缀 `-medium` 设置为 medium reasoning effort (例如: `o3-mini-medium`)
- 添加后缀 `-low` 设置为 low reasoning effort (例如: `o3-mini-low`)
1. OpenAI o系列模型
- 添加后缀 `-high` 设置为 high reasoning effort (例如: `o3-mini-high`)
- 添加后缀 `-medium` 设置为 medium reasoning effort (例如: `o3-mini-medium`)
- 添加后缀 `-low` 设置为 low reasoning effort (例如: `o3-mini-low`)
2. Claude 思考模型
- 添加后缀 `-thinking` 启用思考模式 (例如: `claude-3-7-sonnet-20250219-thinking`)
18. 🔄 思考转内容,支持在 `渠道-编辑-渠道额外设置` 中设置 `thinking_to_content` 选项,默认`false`,开启后会将思考内容`reasoning_content`转换为`<think>`标签拼接到内容中返回。
19. 🔄 模型限流,支持在 `系统设置-速率限制设置` 中设置模型限流,支持设置总请求数限制和成功请求数限制
## 模型支持
此版本额外支持以下模型:
@@ -95,6 +100,9 @@
- `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小,单位 MB,默认为 `20`
- `CRYPTO_SECRET`:加密密钥,用于加密数据库内容。
- `AZURE_DEFAULT_API_VERSION`:Azure渠道默认API版本,如果渠道设置中未指定API版本,则使用此版本,默认为 `2024-12-01-preview`
- `NOTIFICATION_LIMIT_DURATION_MINUTE`:通知限制的持续时间(分钟),默认为 `10`
- `NOTIFY_LIMIT_COUNT`:用户通知在指定持续时间内的最大数量,默认为 `2`
## 部署
> [!TIP]
+1 -1
View File
@@ -13,7 +13,7 @@ Request:
```json
{
"model": "rerank-multilingual-v3.0",
"model": "jina-reranker-v2-base-multilingual",
"query": "What is the capital of the United States?",
"top_n": 3,
"documents": [
+2 -2
View File
@@ -101,7 +101,7 @@ var PreConsumedQuota = 500
var RetryTimes = 0
var RootUserEmail = ""
//var RootUserEmail = ""
var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
@@ -276,7 +276,7 @@ var ChannelBaseURLs = []string{
"https://api.cohere.ai", //34
"https://api.minimax.chat", //35
"", //36
"", //37
"https://api.dify.ai", //37
"https://api.jina.ai", //38
"https://api.cloudflare.com", //39
"https://api.siliconflow.cn", //40
-13
View File
@@ -1,22 +1,9 @@
package common
import (
"fmt"
"runtime/debug"
"time"
)
func SafeGoroutine(f func()) {
go func() {
defer func() {
if r := recover(); r != nil {
SysError(fmt.Sprintf("child goroutine panic occured: error: %v, stack: %s", r, string(debug.Stack())))
}
}()
f()
}()
}
func SafeSendBool(ch chan bool, value bool) (closed bool) {
defer func() {
// Recover from panic if one occured. A panic would mean the channel was closed.
+11 -2
View File
@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"io"
"log"
@@ -80,9 +81,9 @@ func logHelper(ctx context.Context, level string, msg string) {
if logCount > maxLogCount && !setupLogWorking {
logCount = 0
setupLogWorking = true
go func() {
gopool.Go(func() {
SetupLogger()
}()
})
}
}
@@ -100,6 +101,14 @@ func LogQuota(quota int) string {
}
}
func FormatQuota(quota int) string {
if DisplayInCurrencyEnabled {
return fmt.Sprintf("%.6f", float64(quota)/QuotaPerUnit)
} else {
return fmt.Sprintf("%d", quota)
}
}
// LogJson 仅供测试使用 only for test
func LogJson(ctx context.Context, msg string, obj any) {
jsonStr, err := json.Marshal(obj)
+105 -109
View File
@@ -83,92 +83,94 @@ var defaultModelRatio = map[string]float64{
"text-curie-001": 1,
//"text-davinci-002": 10,
//"text-davinci-003": 10,
"text-davinci-edit-001": 10,
"code-davinci-edit-001": 10,
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
"tts-1": 7.5, // 1k characters -> $0.015
"tts-1-1106": 7.5, // 1k characters -> $0.015
"tts-1-hd": 15, // 1k characters -> $0.03
"tts-1-hd-1106": 15, // 1k characters -> $0.03
"davinci": 10,
"curie": 10,
"babbage": 10,
"ada": 10,
"text-embedding-3-small": 0.01,
"text-embedding-3-large": 0.065,
"text-embedding-ada-002": 0.05,
"text-search-ada-doc-001": 10,
"text-moderation-stable": 0.1,
"text-moderation-latest": 0.1,
"claude-instant-1": 0.4, // $0.8 / 1M tokens
"claude-2.0": 4, // $8 / 1M tokens
"claude-2.1": 4, // $8 / 1M tokens
"claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens
"claude-3-5-haiku-20241022": 0.5, // $1 / 1M tokens
"claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens
"claude-3-5-sonnet-20240620": 1.5,
"claude-3-5-sonnet-20241022": 1.5,
"claude-3-opus-20240229": 7.5, // $15 / 1M tokens
"ERNIE-4.0-8K": 0.120 * RMB,
"ERNIE-3.5-8K": 0.012 * RMB,
"ERNIE-3.5-8K-0205": 0.024 * RMB,
"ERNIE-3.5-8K-1222": 0.012 * RMB,
"ERNIE-Bot-8K": 0.024 * RMB,
"ERNIE-3.5-4K-0205": 0.012 * RMB,
"ERNIE-Speed-8K": 0.004 * RMB,
"ERNIE-Speed-128K": 0.004 * RMB,
"ERNIE-Lite-8K-0922": 0.008 * RMB,
"ERNIE-Lite-8K-0308": 0.003 * RMB,
"ERNIE-Tiny-8K": 0.001 * RMB,
"BLOOMZ-7B": 0.004 * RMB,
"Embedding-V1": 0.002 * RMB,
"bge-large-zh": 0.002 * RMB,
"bge-large-en": 0.002 * RMB,
"tao-8k": 0.002 * RMB,
"PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-1.0-pro-vision-001": 1,
"gemini-1.0-pro-001": 1,
"gemini-1.5-pro-latest": 1.75, // $3.5 / 1M tokens
"gemini-1.5-pro-exp-0827": 1.75, // $3.5 / 1M tokens
"gemini-1.5-flash-latest": 1,
"gemini-1.5-flash-exp-0827": 1,
"gemini-1.0-pro-latest": 1,
"gemini-1.0-pro-vision-latest": 1,
"gemini-ultra": 1,
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
"glm-4": 7.143, // ¥0.1 / 1k tokens
"glm-4v": 0.05 * RMB, // ¥0.05 / 1k tokens
"glm-4-alltools": 0.1 * RMB, // ¥0.1 / 1k tokens
"glm-3-turbo": 0.3572,
"glm-4-plus": 0.05 * RMB,
"glm-4-0520": 0.1 * RMB,
"glm-4-air": 0.001 * RMB,
"glm-4-airx": 0.01 * RMB,
"glm-4-long": 0.001 * RMB,
"glm-4-flash": 0,
"glm-4v-plus": 0.01 * RMB,
"qwen-turbo": 0.8572, // ¥0.012 / 1k tokens
"qwen-plus": 10, // ¥0.14 / 1k tokens
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
"SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v4.0": 1.2858,
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
"360gpt-turbo": 0.0858, // ¥0.0012 / 1k tokens
"360gpt-turbo-responsibility-8k": 0.8572, // ¥0.012 / 1k tokens
"360gpt-pro": 0.8572, // ¥0.012 / 1k tokens
"360gpt2-pro": 0.8572, // ¥0.012 / 1k tokens
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
"text-davinci-edit-001": 10,
"code-davinci-edit-001": 10,
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
"tts-1": 7.5, // 1k characters -> $0.015
"tts-1-1106": 7.5, // 1k characters -> $0.015
"tts-1-hd": 15, // 1k characters -> $0.03
"tts-1-hd-1106": 15, // 1k characters -> $0.03
"davinci": 10,
"curie": 10,
"babbage": 10,
"ada": 10,
"text-embedding-3-small": 0.01,
"text-embedding-3-large": 0.065,
"text-embedding-ada-002": 0.05,
"text-search-ada-doc-001": 10,
"text-moderation-stable": 0.1,
"text-moderation-latest": 0.1,
"claude-instant-1": 0.4, // $0.8 / 1M tokens
"claude-2.0": 4, // $8 / 1M tokens
"claude-2.1": 4, // $8 / 1M tokens
"claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens
"claude-3-5-haiku-20241022": 0.5, // $1 / 1M tokens
"claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens
"claude-3-5-sonnet-20240620": 1.5,
"claude-3-5-sonnet-20241022": 1.5,
"claude-3-7-sonnet-20250219": 1.5,
"claude-3-7-sonnet-20250219-thinking": 1.5,
"claude-3-opus-20240229": 7.5, // $15 / 1M tokens
"ERNIE-4.0-8K": 0.120 * RMB,
"ERNIE-3.5-8K": 0.012 * RMB,
"ERNIE-3.5-8K-0205": 0.024 * RMB,
"ERNIE-3.5-8K-1222": 0.012 * RMB,
"ERNIE-Bot-8K": 0.024 * RMB,
"ERNIE-3.5-4K-0205": 0.012 * RMB,
"ERNIE-Speed-8K": 0.004 * RMB,
"ERNIE-Speed-128K": 0.004 * RMB,
"ERNIE-Lite-8K-0922": 0.008 * RMB,
"ERNIE-Lite-8K-0308": 0.003 * RMB,
"ERNIE-Tiny-8K": 0.001 * RMB,
"BLOOMZ-7B": 0.004 * RMB,
"Embedding-V1": 0.002 * RMB,
"bge-large-zh": 0.002 * RMB,
"bge-large-en": 0.002 * RMB,
"tao-8k": 0.002 * RMB,
"PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-1.0-pro-vision-001": 1,
"gemini-1.0-pro-001": 1,
"gemini-1.5-pro-latest": 1.75, // $3.5 / 1M tokens
"gemini-1.5-pro-exp-0827": 1.75, // $3.5 / 1M tokens
"gemini-1.5-flash-latest": 1,
"gemini-1.5-flash-exp-0827": 1,
"gemini-1.0-pro-latest": 1,
"gemini-1.0-pro-vision-latest": 1,
"gemini-ultra": 1,
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
"glm-4": 7.143, // ¥0.1 / 1k tokens
"glm-4v": 0.05 * RMB, // ¥0.05 / 1k tokens
"glm-4-alltools": 0.1 * RMB, // ¥0.1 / 1k tokens
"glm-3-turbo": 0.3572,
"glm-4-plus": 0.05 * RMB,
"glm-4-0520": 0.1 * RMB,
"glm-4-air": 0.001 * RMB,
"glm-4-airx": 0.01 * RMB,
"glm-4-long": 0.001 * RMB,
"glm-4-flash": 0,
"glm-4v-plus": 0.01 * RMB,
"qwen-turbo": 0.8572, // ¥0.012 / 1k tokens
"qwen-plus": 10, // ¥0.14 / 1k tokens
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
"SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.5": 1.2858, // 0.018 / 1k tokens
"SparkDesk-v4.0": 1.2858,
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
"360gpt-turbo": 0.0858, // ¥0.0012 / 1k tokens
"360gpt-turbo-responsibility-8k": 0.8572, // ¥0.012 / 1k tokens
"360gpt-pro": 0.8572, // ¥0.012 / 1k tokens
"360gpt2-pro": 0.8572, // ¥0.012 / 1k tokens
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
// https://platform.lingyiwanwu.com/docs#-计费单元
// 已经按照 7.2 来换算美元价格
"yi-34b-chat-0205": 0.18,
@@ -233,7 +235,11 @@ var (
modelRatioMapMutex = sync.RWMutex{}
)
var CompletionRatio map[string]float64 = nil
var (
CompletionRatio map[string]float64 = nil
CompletionRatioMutex = sync.RWMutex{}
)
var defaultCompletionRatio = map[string]float64{
"gpt-4-gizmo-*": 2,
"gpt-4o-gizmo-*": 3,
@@ -334,10 +340,17 @@ func GetDefaultModelRatioMap() map[string]float64 {
return defaultModelRatio
}
func CompletionRatio2JSONString() string {
func GetCompletionRatioMap() map[string]float64 {
CompletionRatioMutex.Lock()
defer CompletionRatioMutex.Unlock()
if CompletionRatio == nil {
CompletionRatio = defaultCompletionRatio
}
return CompletionRatio
}
func CompletionRatio2JSONString() string {
GetCompletionRatioMap()
jsonBytes, err := json.Marshal(CompletionRatio)
if err != nil {
SysError("error marshalling completion ratio: " + err.Error())
@@ -346,11 +359,15 @@ func CompletionRatio2JSONString() string {
}
func UpdateCompletionRatioByJSONString(jsonStr string) error {
CompletionRatioMutex.Lock()
defer CompletionRatioMutex.Unlock()
CompletionRatio = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &CompletionRatio)
}
func GetCompletionRatio(name string) float64 {
GetCompletionRatioMap()
if strings.Contains(name, "/") {
if ratio, ok := CompletionRatio[name]; ok {
return ratio
@@ -476,24 +493,3 @@ func GetAudioCompletionRatio(name string) float64 {
}
return 2
}
//func GetAudioPricePerMinute(name string) float64 {
// if strings.HasPrefix(name, "gpt-4o-realtime") {
// return 0.06
// }
// return 0.06
//}
//
//func GetAudioCompletionPricePerMinute(name string) float64 {
// if strings.HasPrefix(name, "gpt-4o-realtime") {
// return 0.24
// }
// return 0.24
//}
func GetCompletionRatioMap() map[string]float64 {
if CompletionRatio == nil {
CompletionRatio = defaultCompletionRatio
}
return CompletionRatio
}
+34 -1
View File
@@ -32,6 +32,7 @@ func InitRedisClient() (err error) {
if err != nil {
FatalLog("failed to parse Redis connection string: " + err.Error())
}
opt.PoolSize = GetEnvOrDefault("REDIS_POOL_SIZE", 10)
RDB = redis.NewClient(opt)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
@@ -41,6 +42,10 @@ func InitRedisClient() (err error) {
if err != nil {
FatalLog("Redis ping test failed: " + err.Error())
}
if DebugEnabled {
SysLog(fmt.Sprintf("Redis connected to %s", opt.Addr))
SysLog(fmt.Sprintf("Redis database: %d", opt.DB))
}
return err
}
@@ -53,13 +58,20 @@ func ParseRedisOption() *redis.Options {
}
func RedisSet(key string, value string, expiration time.Duration) error {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis SET: key=%s, value=%s, expiration=%v", key, value, expiration))
}
ctx := context.Background()
return RDB.Set(ctx, key, value, expiration).Err()
}
func RedisGet(key string) (string, error) {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis GET: key=%s", key))
}
ctx := context.Background()
return RDB.Get(ctx, key).Result()
val, err := RDB.Get(ctx, key).Result()
return val, err
}
//func RedisExpire(key string, expiration time.Duration) error {
@@ -73,16 +85,25 @@ func RedisGet(key string) (string, error) {
//}
func RedisDel(key string) error {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis DEL: key=%s", key))
}
ctx := context.Background()
return RDB.Del(ctx, key).Err()
}
func RedisHDelObj(key string) error {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis HDEL: key=%s", key))
}
ctx := context.Background()
return RDB.HDel(ctx, key).Err()
}
func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis HSET: key=%s, obj=%+v, expiration=%v", key, obj, expiration))
}
ctx := context.Background()
data := make(map[string]interface{})
@@ -130,6 +151,9 @@ func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
}
func RedisHGetObj(key string, obj interface{}) error {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis HGETALL: key=%s", key))
}
ctx := context.Background()
result, err := RDB.HGetAll(ctx, key).Result()
@@ -208,6 +232,9 @@ func RedisHGetObj(key string, obj interface{}) error {
// RedisIncr Add this function to handle atomic increments
func RedisIncr(key string, delta int64) error {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis INCR: key=%s, delta=%d", key, delta))
}
// 检查键的剩余生存时间
ttlCmd := RDB.TTL(context.Background(), key)
ttl, err := ttlCmd.Result()
@@ -238,6 +265,9 @@ func RedisIncr(key string, delta int64) error {
}
func RedisHIncrBy(key, field string, delta int64) error {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis HINCRBY: key=%s, field=%s, delta=%d", key, field, delta))
}
ttlCmd := RDB.TTL(context.Background(), key)
ttl, err := ttlCmd.Result()
if err != nil && !errors.Is(err, redis.Nil) {
@@ -262,6 +292,9 @@ func RedisHIncrBy(key, field string, delta int64) error {
}
func RedisHSetField(key, field string, value interface{}) error {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis HSET field: key=%s, field=%s, value=%v", key, field, value))
}
ttlCmd := RDB.TTL(context.Background(), key)
ttl, err := ttlCmd.Result()
if err != nil && !errors.Is(err, redis.Nil) {
+19
View File
@@ -5,6 +5,7 @@ import (
"context"
crand "crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"github.com/pkg/errors"
"html/template"
@@ -213,6 +214,24 @@ func RandomSleep() {
time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond)
}
func GetPointer[T any](v T) *T {
return &v
}
func Any2Type[T any](data any) (T, error) {
var zero T
bytes, err := json.Marshal(data)
if err != nil {
return zero, err
}
var res T
err = json.Unmarshal(bytes, &res)
if err != nil {
return zero, err
}
return res, nil
}
// SaveTmpFile saves data to a temporary file. The filename would be apppended with a random string.
func SaveTmpFile(filename string, data io.Reader) (string, error) {
f, err := os.CreateTemp(os.TempDir(), filename)
+3 -2
View File
@@ -1,6 +1,7 @@
package constant
var (
ForceFormat = "force_format" // ForceFormat 强制格式化为OpenAI格式
ChanelSettingProxy = "proxy" // Proxy 代理
ForceFormat = "force_format" // ForceFormat 强制格式化为OpenAI格式
ChanelSettingProxy = "proxy" // Proxy 代理
ChannelSettingThinkingToContent = "thinking_to_content" // ThinkingToContent
)
+5
View File
@@ -2,4 +2,9 @@ package constant
const (
ContextKeyRequestStartTime = "request_start_time"
ContextKeyUserSetting = "user_setting"
ContextKeyUserQuota = "user_quota"
ContextKeyUserStatus = "user_status"
ContextKeyUserEmail = "user_email"
ContextKeyUserGroup = "user_group"
)
+4 -1
View File
@@ -29,6 +29,9 @@ var GeminiModelMap = map[string]string{
var GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
var NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
var NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
func InitEnv() {
modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
if modelVersionMapStr == "" {
@@ -44,5 +47,5 @@ func InitEnv() {
}
}
// 是否生成初始令牌,默认关闭。
// GenerateDefaultToken 是否生成初始令牌,默认关闭。
var GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
+14
View File
@@ -0,0 +1,14 @@
package constant
var (
UserSettingNotifyType = "notify_type" // QuotaWarningType 额度预警类型
UserSettingQuotaWarningThreshold = "quota_warning_threshold" // QuotaWarningThreshold 额度预警阈值
UserSettingWebhookUrl = "webhook_url" // WebhookUrl webhook地址
UserSettingWebhookSecret = "webhook_secret" // WebhookSecret webhook密钥
UserSettingNotificationEmail = "notification_email" // NotificationEmail 通知邮箱地址
)
var (
NotifyTypeEmail = "email" // Email 邮件
NotifyTypeWebhook = "webhook" // Webhook
)
+2 -7
View File
@@ -238,9 +238,7 @@ var testAllChannelsLock sync.Mutex
var testAllChannelsRunning bool = false
func testAllChannels(notify bool) error {
if common.RootUserEmail == "" {
common.RootUserEmail = model.GetRootUserEmail()
}
testAllChannelsLock.Lock()
if testAllChannelsRunning {
testAllChannelsLock.Unlock()
@@ -295,10 +293,7 @@ func testAllChannels(notify bool) error {
testAllChannelsRunning = false
testAllChannelsLock.Unlock()
if notify {
err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
if err != nil {
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
}
service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试完成")
}
})
return nil
+1 -1
View File
@@ -159,7 +159,7 @@ func UpdateMidjourneyTaskBulk() {
common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
} else {
if shouldReturnQuota {
err = model.IncreaseUserQuota(task.UserId, task.Quota)
err = model.IncreaseUserQuota(task.UserId, task.Quota, false)
if err != nil {
common.LogError(ctx, "fail to increase user quota: "+err.Error())
}
+1 -1
View File
@@ -17,7 +17,7 @@ func GetPricing(c *gin.Context) {
}
var group string
if exists {
user, err := model.GetUserById(userId.(int), false)
user, err := model.GetUserCache(userId.(int))
if err == nil {
group = user.Group
}
+2 -1
View File
@@ -24,7 +24,7 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
var err *dto.OpenAIErrorWithStatusCode
switch relayMode {
case relayconstant.RelayModeImagesGenerations:
err = relay.ImageHelper(c, relayMode)
err = relay.ImageHelper(c)
case relayconstant.RelayModeAudioSpeech:
fallthrough
case relayconstant.RelayModeAudioTranslation:
@@ -85,6 +85,7 @@ func Relay(c *gin.Context) {
if openaiErr != nil {
if openaiErr.StatusCode == http.StatusTooManyRequests {
common.LogError(c, fmt.Sprintf("origin 429 error: %s", openaiErr.Error.Message))
openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
}
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
+1 -1
View File
@@ -159,7 +159,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
} else {
quota := task.Quota
if quota != 0 {
err = model.IncreaseUserQuota(task.UserId, quota)
err = model.IncreaseUserQuota(task.UserId, quota, false)
if err != nil {
common.LogError(ctx, "fail to increase user quota: "+err.Error())
}
+1 -1
View File
@@ -210,7 +210,7 @@ func EpayNotify(c *gin.Context) {
}
//user, _ := model.GetUserById(topUp.UserId, false)
//user.Quota += topUp.Amount * 500000
err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*int(common.QuotaPerUnit))
err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*int(common.QuotaPerUnit), true)
if err != nil {
log.Printf("易支付回调更新用户失败: %v", topUp)
return
+114 -4
View File
@@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"one-api/common"
"one-api/model"
"one-api/setting"
@@ -471,7 +472,7 @@ func GetUserModels(c *gin.Context) {
if err != nil {
id = c.GetInt("id")
}
user, err := model.GetUserById(id, true)
user, err := model.GetUserCache(id)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -869,9 +870,6 @@ func EmailBind(c *gin.Context) {
})
return
}
if user.Role == common.RoleRootUser {
common.RootUserEmail = email
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
@@ -913,3 +911,115 @@ func TopUp(c *gin.Context) {
})
return
}
type UpdateUserSettingRequest struct {
QuotaWarningType string `json:"notify_type"`
QuotaWarningThreshold float64 `json:"quota_warning_threshold"`
WebhookUrl string `json:"webhook_url,omitempty"`
WebhookSecret string `json:"webhook_secret,omitempty"`
NotificationEmail string `json:"notification_email,omitempty"`
}
func UpdateUserSetting(c *gin.Context) {
var req UpdateUserSettingRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的参数",
})
return
}
// 验证预警类型
if req.QuotaWarningType != constant.NotifyTypeEmail && req.QuotaWarningType != constant.NotifyTypeWebhook {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的预警类型",
})
return
}
// 验证预警阈值
if req.QuotaWarningThreshold <= 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "预警阈值必须大于0",
})
return
}
// 如果是webhook类型,验证webhook地址
if req.QuotaWarningType == constant.NotifyTypeWebhook {
if req.WebhookUrl == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "Webhook地址不能为空",
})
return
}
// 验证URL格式
if _, err := url.ParseRequestURI(req.WebhookUrl); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的Webhook地址",
})
return
}
}
// 如果是邮件类型,验证邮箱地址
if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" {
// 验证邮箱格式
if !strings.Contains(req.NotificationEmail, "@") {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的邮箱地址",
})
return
}
}
userId := c.GetInt("id")
user, err := model.GetUserById(userId, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
// 构建设置
settings := map[string]interface{}{
constant.UserSettingNotifyType: req.QuotaWarningType,
constant.UserSettingQuotaWarningThreshold: req.QuotaWarningThreshold,
}
// 如果是webhook类型,添加webhook相关设置
if req.QuotaWarningType == constant.NotifyTypeWebhook {
settings[constant.UserSettingWebhookUrl] = req.WebhookUrl
if req.WebhookSecret != "" {
settings[constant.UserSettingWebhookSecret] = req.WebhookSecret
}
}
// 如果提供了通知邮箱,添加到设置中
if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" {
settings[constant.UserSettingNotificationEmail] = req.NotificationEmail
}
// 更新用户设置
user.SetSetting(settings)
if err := user.Update(false); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "更新设置失败: " + err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "设置已更新",
})
}
+1 -1
View File
@@ -24,7 +24,7 @@ services:
- redis
- mysql
healthcheck:
test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ]
test: ["CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $$2}'"]
interval: 30s
timeout: 10s
retries: 3
+5
View File
@@ -10,6 +10,10 @@
- 用于配置网络代理
- 类型为字符串,填写代理地址(例如 socks5 协议的代理地址)
3. thinking_to_content
- 用于标识是否将思考内容`reasoning_conetnt`转换为`<think>`标签拼接到内容中返回
- 类型为布尔值,设置为 true 时启用思考内容转换
--------------------------------------------------------------
## JSON 格式示例
@@ -19,6 +23,7 @@
```json
{
"force_format": true,
"thinking_to_content": true,
"proxy": "socks5://xxxxxxx"
}
```
+25
View File
@@ -0,0 +1,25 @@
package dto
type Notify struct {
Type string `json:"type"`
Title string `json:"title"`
Content string `json:"content"`
Values []interface{} `json:"values"`
}
const ContentValueParam = "{{value}}"
const (
NotifyTypeQuotaExceed = "quota_exceed"
NotifyTypeChannelUpdate = "channel_update"
NotifyTypeChannelTest = "channel_test"
)
func NewNotify(t string, title string, content string, values []interface{}) Notify {
return Notify{
Type: t,
Title: title,
Content: content,
Values: values,
}
}
+103 -44
View File
@@ -1,6 +1,9 @@
package dto
import "encoding/json"
import (
"encoding/json"
"strings"
)
type ResponseFormat struct {
Type string `json:"type,omitempty"`
@@ -18,6 +21,8 @@ type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Prefix any `json:"prefix,omitempty"`
Suffix any `json:"suffix,omitempty"`
Stream bool `json:"stream,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
@@ -45,6 +50,7 @@ type GeneralOpenAIRequest struct {
Dimensions int `json:"dimensions,omitempty"`
Modalities any `json:"modalities,omitempty"`
Audio any `json:"audio,omitempty"`
ExtraBody any `json:"extra_body,omitempty"`
}
type OpenAITools struct {
@@ -86,18 +92,20 @@ func (r GeneralOpenAIRequest) ParseInput() []string {
}
type Message struct {
Role string `json:"role"`
Content json.RawMessage `json:"content"`
Name *string `json:"name,omitempty"`
Prefix *bool `json:"prefix,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
ToolCalls json.RawMessage `json:"tool_calls,omitempty"`
ToolCallId string `json:"tool_call_id,omitempty"`
Role string `json:"role"`
Content json.RawMessage `json:"content"`
Name *string `json:"name,omitempty"`
Prefix *bool `json:"prefix,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
ToolCalls json.RawMessage `json:"tool_calls,omitempty"`
ToolCallId string `json:"tool_call_id,omitempty"`
parsedContent []MediaContent
parsedStringContent *string
}
type MediaContent struct {
Type string `json:"type"`
Text string `json:"text"`
Text string `json:"text,omitempty"`
ImageUrl any `json:"image_url,omitempty"`
InputAudio any `json:"input_audio,omitempty"`
}
@@ -146,88 +154,139 @@ func (m *Message) SetToolCalls(toolCalls any) {
}
func (m *Message) StringContent() string {
if m.parsedStringContent != nil {
return *m.parsedStringContent
}
var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
m.parsedStringContent = &stringContent
return stringContent
}
return string(m.Content)
contentStr := new(strings.Builder)
arrayContent := m.ParseContent()
for _, content := range arrayContent {
if content.Type == ContentTypeText {
contentStr.WriteString(content.Text)
}
}
stringContent = contentStr.String()
m.parsedStringContent = &stringContent
return stringContent
}
func (m *Message) SetStringContent(content string) {
jsonContent, _ := json.Marshal(content)
m.Content = jsonContent
m.parsedStringContent = &content
m.parsedContent = nil
}
func (m *Message) SetMediaContent(content []MediaContent) {
jsonContent, _ := json.Marshal(content)
m.Content = jsonContent
m.parsedContent = nil
m.parsedStringContent = nil
}
func (m *Message) IsStringContent() bool {
if m.parsedStringContent != nil {
return true
}
var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
m.parsedStringContent = &stringContent
return true
}
return false
}
func (m *Message) ParseContent() []MediaContent {
if m.parsedContent != nil {
return m.parsedContent
}
var contentList []MediaContent
// 先尝试解析为字符串
var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
contentList = append(contentList, MediaContent{
contentList = []MediaContent{{
Type: ContentTypeText,
Text: stringContent,
})
}}
m.parsedContent = contentList
return contentList
}
var arrayContent []json.RawMessage
// 尝试解析为数组
var arrayContent []map[string]interface{}
if err := json.Unmarshal(m.Content, &arrayContent); err == nil {
for _, contentItem := range arrayContent {
var contentMap map[string]any
if err := json.Unmarshal(contentItem, &contentMap); err != nil {
contentType, ok := contentItem["type"].(string)
if !ok {
continue
}
switch contentMap["type"] {
switch contentType {
case ContentTypeText:
if subStr, ok := contentMap["text"].(string); ok {
if text, ok := contentItem["text"].(string); ok {
contentList = append(contentList, MediaContent{
Type: ContentTypeText,
Text: subStr,
Text: text,
})
}
case ContentTypeImageURL:
if subObj, ok := contentMap["image_url"].(map[string]any); ok {
detail, ok := subObj["detail"]
if ok {
subObj["detail"] = detail.(string)
} else {
subObj["detail"] = "high"
}
imageUrl := contentItem["image_url"]
switch v := imageUrl.(type) {
case string:
contentList = append(contentList, MediaContent{
Type: ContentTypeImageURL,
ImageUrl: MessageImageUrl{
Url: subObj["url"].(string),
Detail: subObj["detail"].(string),
},
})
} else if url, ok := contentMap["image_url"].(string); ok {
contentList = append(contentList, MediaContent{
Type: ContentTypeImageURL,
ImageUrl: MessageImageUrl{
Url: url,
Url: v,
Detail: "high",
},
})
case map[string]interface{}:
url, ok1 := v["url"].(string)
detail, ok2 := v["detail"].(string)
if !ok2 {
detail = "high"
}
if ok1 {
contentList = append(contentList, MediaContent{
Type: ContentTypeImageURL,
ImageUrl: MessageImageUrl{
Url: url,
Detail: detail,
},
})
}
}
case ContentTypeInputAudio:
if subObj, ok := contentMap["input_audio"].(map[string]any); ok {
contentList = append(contentList, MediaContent{
Type: ContentTypeInputAudio,
InputAudio: MessageInputAudio{
Data: subObj["data"].(string),
Format: subObj["format"].(string),
},
})
if audioData, ok := contentItem["input_audio"].(map[string]interface{}); ok {
data, ok1 := audioData["data"].(string)
format, ok2 := audioData["format"].(string)
if ok1 && ok2 {
contentList = append(contentList, MediaContent{
Type: ContentTypeInputAudio,
InputAudio: MessageInputAudio{
Data: data,
Format: format,
},
})
}
}
}
}
return contentList
}
return nil
if len(contentList) > 0 {
m.parsedContent = contentList
}
return contentList
}
+29 -3
View File
@@ -62,9 +62,10 @@ type ChatCompletionsStreamResponseChoice struct {
}
type ChatCompletionsStreamResponseChoiceDelta struct {
Content *string `json:"content,omitempty"`
Role string `json:"role,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
Content *string `json:"content,omitempty"`
ReasoningContent *string `json:"reasoning_content,omitempty"`
Role string `json:"role,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}
func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) {
@@ -78,6 +79,17 @@ func (c *ChatCompletionsStreamResponseChoiceDelta) GetContentString() string {
return *c.Content
}
func (c *ChatCompletionsStreamResponseChoiceDelta) GetReasoningContent() string {
if c.ReasoningContent == nil {
return ""
}
return *c.ReasoningContent
}
func (c *ChatCompletionsStreamResponseChoiceDelta) SetReasoningContent(s string) {
c.ReasoningContent = &s
}
type ToolCall struct {
// Index is not nil only in chat completion chunk object
Index *int `json:"index,omitempty"`
@@ -108,6 +120,20 @@ type ChatCompletionsStreamResponse struct {
Usage *Usage `json:"usage"`
}
func (c *ChatCompletionsStreamResponse) Copy() *ChatCompletionsStreamResponse {
choices := make([]ChatCompletionsStreamResponseChoice, len(c.Choices))
copy(choices, c.Choices)
return &ChatCompletionsStreamResponse{
Id: c.Id,
Object: c.Object,
Created: c.Created,
Model: c.Model,
SystemFingerprint: c.SystemFingerprint,
Choices: choices,
Usage: c.Usage,
}
}
func (c *ChatCompletionsStreamResponse) GetSystemFingerprint() string {
if c.SystemFingerprint == nil {
return ""
+2 -2
View File
@@ -119,9 +119,9 @@ func main() {
}
if os.Getenv("ENABLE_PPROF") == "true" {
go func() {
gopool.Go(func() {
log.Println(http.ListenAndServe("0.0.0.0:8005", nil))
}()
})
go common.Monitor()
common.SysLog("pprof enabled")
}
+5 -1
View File
@@ -199,15 +199,19 @@ func TokenAuth() func(c *gin.Context) {
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
return
}
userEnabled, err := model.IsUserEnabled(token.UserId, false)
userCache, err := model.GetUserCache(token.UserId)
if err != nil {
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
return
}
userEnabled := userCache.Status == common.UserStatusEnabled
if !userEnabled {
abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁")
return
}
userCache.WriteContext(c)
c.Set("id", token.UserId)
c.Set("token_id", token.Id)
c.Set("token_key", token.Key)
+1 -6
View File
@@ -32,7 +32,6 @@ func Distribute() func(c *gin.Context) {
return
}
}
userId := c.GetInt("id")
var channel *model.Channel
channelId, ok := c.Get("specific_channel_id")
modelRequest, shouldSelectChannel, err := getModelRequest(c)
@@ -40,7 +39,7 @@ func Distribute() func(c *gin.Context) {
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
return
}
userGroup, _ := model.GetUserGroup(userId, false)
userGroup := c.GetString(constant.ContextKeyUserGroup)
tokenGroup := c.GetString("token_group")
if tokenGroup != "" {
// check common.UserUsableGroups[userGroup]
@@ -135,17 +134,14 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
midjourneyRequest := dto.MidjourneyRequest{}
err = common.UnmarshalBodyReusable(c, &midjourneyRequest)
if err != nil {
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error())
return nil, false, err
}
midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
if mjErr != nil {
abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description)
return nil, false, fmt.Errorf(mjErr.Description)
}
if midjourneyModel == "" {
if !success {
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型")
return nil, false, fmt.Errorf("无效的请求, 无法解析模型")
} else {
// task fetch, task fetch by condition, notify
@@ -170,7 +166,6 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
err = common.UnmarshalBodyReusable(c, &modelRequest)
}
if err != nil {
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
return nil, false, errors.New("无效的请求, " + err.Error())
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/realtime") {
+172
View File
@@ -0,0 +1,172 @@
package middleware
import (
"context"
"fmt"
"net/http"
"one-api/common"
"one-api/setting"
"strconv"
"time"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
)
const (
ModelRequestRateLimitCountMark = "MRRL"
ModelRequestRateLimitSuccessCountMark = "MRRLS"
)
// 检查Redis中的请求限制
func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, maxCount int, duration int64) (bool, error) {
// 如果maxCount为0,表示不限制
if maxCount == 0 {
return true, nil
}
// 获取当前计数
length, err := rdb.LLen(ctx, key).Result()
if err != nil {
return false, err
}
// 如果未达到限制,允许请求
if length < int64(maxCount) {
return true, nil
}
// 检查时间窗口
oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
oldTime, err := time.Parse(timeFormat, oldTimeStr)
if err != nil {
return false, err
}
nowTimeStr := time.Now().Format(timeFormat)
nowTime, err := time.Parse(timeFormat, nowTimeStr)
if err != nil {
return false, err
}
// 如果在时间窗口内已达到限制,拒绝请求
subTime := nowTime.Sub(oldTime).Seconds()
if int64(subTime) < duration {
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
return false, nil
}
return true, nil
}
// 记录Redis请求
func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxCount int) {
// 如果maxCount为0,不记录请求
if maxCount == 0 {
return
}
now := time.Now().Format(timeFormat)
rdb.LPush(ctx, key, now)
rdb.LTrim(ctx, key, 0, int64(maxCount-1))
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
}
// Redis限流处理器
func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
return func(c *gin.Context) {
userId := strconv.Itoa(c.GetInt("id"))
ctx := context.Background()
rdb := common.RDB
// 1. 检查总请求数限制(当totalMaxCount为0时会自动跳过)
totalKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitCountMark, userId)
allowed, err := checkRedisRateLimit(ctx, rdb, totalKey, totalMaxCount, duration)
if err != nil {
fmt.Println("检查总请求数限制失败:", err.Error())
abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
return
}
if !allowed {
abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
}
// 2. 检查成功请求数限制
successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId)
allowed, err = checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration)
if err != nil {
fmt.Println("检查成功请求数限制失败:", err.Error())
abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
return
}
if !allowed {
abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到请求数限制:%d分钟内最多请求%d次", setting.ModelRequestRateLimitDurationMinutes, successMaxCount))
return
}
// 3. 记录总请求(当totalMaxCount为0时会自动跳过)
recordRedisRequest(ctx, rdb, totalKey, totalMaxCount)
// 4. 处理请求
c.Next()
// 5. 如果请求成功,记录成功请求
if c.Writer.Status() < 400 {
recordRedisRequest(ctx, rdb, successKey, successMaxCount)
}
}
}
// 内存限流处理器
func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration)
return func(c *gin.Context) {
userId := strconv.Itoa(c.GetInt("id"))
totalKey := ModelRequestRateLimitCountMark + userId
successKey := ModelRequestRateLimitSuccessCountMark + userId
// 1. 检查总请求数限制(当totalMaxCount为0时跳过)
if totalMaxCount > 0 && !inMemoryRateLimiter.Request(totalKey, totalMaxCount, duration) {
c.Status(http.StatusTooManyRequests)
c.Abort()
return
}
// 2. 检查成功请求数限制
// 使用一个临时key来检查限制,这样可以避免实际记录
checkKey := successKey + "_check"
if !inMemoryRateLimiter.Request(checkKey, successMaxCount, duration) {
c.Status(http.StatusTooManyRequests)
c.Abort()
return
}
// 3. 处理请求
c.Next()
// 4. 如果请求成功,记录到实际的成功请求计数中
if c.Writer.Status() < 400 {
inMemoryRateLimiter.Request(successKey, successMaxCount, duration)
}
}
}
// ModelRequestRateLimit 模型请求限流中间件
func ModelRequestRateLimit() func(c *gin.Context) {
// 如果未启用限流,直接放行
if !setting.ModelRequestRateLimitEnabled {
return defNext
}
// 计算限流参数
duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60)
totalMaxCount := setting.ModelRequestRateLimitCount
successMaxCount := setting.ModelRequestRateLimitSuccessCount
// 根据存储类型选择限流处理器
if common.RedisEnabled {
return redisRateLimitHandler(duration, totalMaxCount, successMaxCount)
} else {
return memoryRateLimitHandler(duration, totalMaxCount, successMaxCount)
}
}
+5 -5
View File
@@ -1,8 +1,8 @@
package model
import (
"context"
"fmt"
"github.com/gin-gonic/gin"
"one-api/common"
"os"
"strings"
@@ -87,14 +87,14 @@ func RecordLog(userId int, logType int, content string) {
}
}
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int,
func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens int, completionTokens int,
modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int,
isStream bool, group string, other map[string]interface{}) {
common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
common.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
if !common.LogConsumeEnabled {
return
}
username, _ := GetUsernameById(userId, false)
username := c.GetString("username")
otherStr := common.MapToJsonStr(other)
log := &Log{
UserId: userId,
@@ -116,7 +116,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
}
err := LOG_DB.Create(log).Error
if err != nil {
common.LogError(ctx, "failed to record log: "+err.Error())
common.LogError(c, "failed to record log: "+err.Error())
}
if common.DataExportEnabled {
gopool.Go(func() {
+18 -2
View File
@@ -84,7 +84,10 @@ func InitOptionMap() {
common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter)
common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee)
common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
common.OptionMap["ShouldPreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount)
common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
@@ -105,12 +108,14 @@ func InitOptionMap() {
common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(setting.MjActionCheckSuccessEnabled)
common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(setting.CheckSensitiveEnabled)
common.OptionMap["DemoSiteEnabled"] = strconv.FormatBool(setting.DemoSiteEnabled)
common.OptionMap["ModelRequestRateLimitEnabled"] = strconv.FormatBool(setting.ModelRequestRateLimitEnabled)
common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(setting.CheckSensitiveOnPromptEnabled)
//common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(setting.StopOnSensitiveEnabled)
common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
common.OptionMap["AutomaticDisableKeywords"] = setting.AutomaticDisableKeywordsToString()
common.OptionMap["GeminiSafetySettings"] = setting.GeminiSafetySettingsJsonString()
common.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase()
@@ -226,6 +231,9 @@ func updateOptionMap(key string, value string) (err error) {
setting.DemoSiteEnabled = boolValue
case "CheckSensitiveOnPromptEnabled":
setting.CheckSensitiveOnPromptEnabled = boolValue
case "ModelRequestRateLimitEnabled":
setting.ModelRequestRateLimitEnabled = boolValue
//case "CheckSensitiveOnCompletionEnabled":
// constant.CheckSensitiveOnCompletionEnabled = boolValue
case "StopOnSensitiveEnabled":
@@ -306,8 +314,14 @@ func updateOptionMap(key string, value string) (err error) {
common.QuotaForInvitee, _ = strconv.Atoi(value)
case "QuotaRemindThreshold":
common.QuotaRemindThreshold, _ = strconv.Atoi(value)
case "PreConsumedQuota":
case "ShouldPreConsumedQuota":
common.PreConsumedQuota, _ = strconv.Atoi(value)
case "ModelRequestRateLimitCount":
setting.ModelRequestRateLimitCount, _ = strconv.Atoi(value)
case "ModelRequestRateLimitDurationMinutes":
setting.ModelRequestRateLimitDurationMinutes, _ = strconv.Atoi(value)
case "ModelRequestRateLimitSuccessCount":
setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value)
case "RetryTimes":
common.RetryTimes, _ = strconv.Atoi(value)
case "DataExportInterval":
@@ -338,6 +352,8 @@ func updateOptionMap(key string, value string) (err error) {
setting.SensitiveWordsFromString(value)
case "AutomaticDisableKeywords":
setting.AutomaticDisableKeywordsFromString(value)
case "GeminiSafetySettings":
setting.GeminiSafetySettingFromJsonString(value)
case "StreamCacheQueueLength":
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
}
+3 -82
View File
@@ -3,13 +3,11 @@ package model
import (
"errors"
"fmt"
"one-api/common"
"strings"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm"
"one-api/common"
relaycommon "one-api/relay/common"
"one-api/setting"
"strconv"
"strings"
)
type Token struct {
@@ -322,80 +320,3 @@ func decreaseTokenQuota(id int, quota int) (err error) {
).Error
return err
}
func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
if relayInfo.IsPlayground {
return nil
}
//if relayInfo.TokenUnlimited {
// return nil
//}
token, err := GetTokenById(relayInfo.TokenId)
if err != nil {
return err
}
if !relayInfo.TokenUnlimited && token.RemainQuota < quota {
return errors.New("令牌额度不足")
}
err = DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
if err != nil {
return err
}
return nil
}
func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) {
if quota > 0 {
err = DecreaseUserQuota(relayInfo.UserId, quota)
} else {
err = IncreaseUserQuota(relayInfo.UserId, -quota)
}
if err != nil {
return err
}
if !relayInfo.IsPlayground {
if quota > 0 {
err = DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
} else {
err = IncreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, -quota)
}
if err != nil {
return err
}
}
if sendEmail {
if (quota + preConsumedQuota) != 0 {
quotaTooLow := userQuota >= common.QuotaRemindThreshold && userQuota-(quota+preConsumedQuota) < common.QuotaRemindThreshold
noMoreQuota := userQuota-(quota+preConsumedQuota) <= 0
if quotaTooLow || noMoreQuota {
go func() {
email, err := GetUserEmail(relayInfo.UserId)
if err != nil {
common.SysError("failed to fetch user email: " + err.Error())
}
prompt := "您的额度即将用尽"
if noMoreQuota {
prompt = "您的额度已用尽"
}
if email != "" {
topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
err = common.SendEmail(prompt, email,
fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink))
if err != nil {
common.SysError("failed to send email" + err.Error())
}
common.SysLog("user quota is low, consumed quota: " + strconv.Itoa(quota) + ", user quota: " + strconv.Itoa(userQuota))
}
}()
}
}
}
return nil
}
+1 -1
View File
@@ -52,7 +52,7 @@ func cacheSetTokenField(key string, field string, value string) error {
func cacheGetTokenByKey(key string) (*Token, error) {
hmacKey := common.GenerateHMAC(key)
if !common.RedisEnabled {
return nil, nil
return nil, fmt.Errorf("redis is not enabled")
}
var token Token
err := common.RedisHGetObj(fmt.Sprintf("token:%s", hmacKey), &token)
+110 -46
View File
@@ -1,6 +1,7 @@
package model
import (
"encoding/json"
"errors"
"fmt"
"one-api/common"
@@ -38,6 +39,20 @@ type User struct {
InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
DeletedAt gorm.DeletedAt `gorm:"index"`
LinuxDOId string `json:"linux_do_id" gorm:"column:linux_do_id;index"`
Setting string `json:"setting" gorm:"type:text;column:setting"`
}
func (user *User) ToBaseUser() *UserBase {
cache := &UserBase{
Id: user.Id,
Group: user.Group,
Quota: user.Quota,
Status: user.Status,
Username: user.Username,
Setting: user.Setting,
Email: user.Email,
}
return cache
}
func (user *User) GetAccessToken() string {
@@ -51,6 +66,22 @@ func (user *User) SetAccessToken(token string) {
user.AccessToken = &token
}
func (user *User) GetSetting() map[string]interface{} {
if user.Setting == "" {
return nil
}
return common.StrToMap(user.Setting)
}
func (user *User) SetSetting(setting map[string]interface{}) {
settingBytes, err := json.Marshal(setting)
if err != nil {
common.SysError("failed to marshal setting: " + err.Error())
return
}
user.Setting = string(settingBytes)
}
// CheckUserExistOrDeleted check if user exist or deleted, if not exist, return false, nil, if deleted or exist, return true, nil
func CheckUserExistOrDeleted(username string, email string) (bool, error) {
var user User
@@ -289,7 +320,7 @@ func (user *User) Insert(inviterId int) error {
}
if inviterId != 0 {
if common.QuotaForInvitee > 0 {
_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee)
_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee, true)
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee)))
}
if common.QuotaForInviter > 0 {
@@ -315,8 +346,8 @@ func (user *User) Update(updatePassword bool) error {
return err
}
// 更新缓存
return updateUserCache(user.Id, user.Username, user.Group, user.Quota, user.Status)
// Update cache
return updateUserCache(*user)
}
func (user *User) Edit(updatePassword bool) error {
@@ -344,8 +375,8 @@ func (user *User) Edit(updatePassword bool) error {
return err
}
// 更新缓存
return updateUserCache(user.Id, user.Username, user.Group, user.Quota, user.Status)
// Update cache
return updateUserCache(*user)
}
func (user *User) Delete() error {
@@ -371,8 +402,8 @@ func (user *User) HardDelete() error {
// ValidateAndFill check password & user status
func (user *User) ValidateAndFill() (err error) {
// When querying with struct, GORM will only query with non-zero fields,
// that means if your fields value is 0, '', false or other zero values,
// it wont be used to build query conditions
// that means if your field's value is 0, '', false or other zero values,
// it won't be used to build query conditions
password := user.Password
username := strings.TrimSpace(user.Username)
if username == "" || password == "" {
@@ -471,35 +502,35 @@ func IsAdmin(userId int) bool {
return user.Role >= common.RoleAdminUser
}
// IsUserEnabled checks user status from Redis first, falls back to DB if needed
func IsUserEnabled(id int, fromDB bool) (status bool, err error) {
defer func() {
// Update Redis cache asynchronously on successful DB read
if shouldUpdateRedis(fromDB, err) {
gopool.Go(func() {
if err := updateUserStatusCache(id, status); err != nil {
common.SysError("failed to update user status cache: " + err.Error())
}
})
}
}()
if !fromDB && common.RedisEnabled {
// Try Redis first
status, err := getUserStatusCache(id)
if err == nil {
return status == common.UserStatusEnabled, nil
}
// Don't return error - fall through to DB
}
fromDB = true
var user User
err = DB.Where("id = ?", id).Select("status").Find(&user).Error
if err != nil {
return false, err
}
return user.Status == common.UserStatusEnabled, nil
}
//// IsUserEnabled checks user status from Redis first, falls back to DB if needed
//func IsUserEnabled(id int, fromDB bool) (status bool, err error) {
// defer func() {
// // Update Redis cache asynchronously on successful DB read
// if shouldUpdateRedis(fromDB, err) {
// gopool.Go(func() {
// if err := updateUserStatusCache(id, status); err != nil {
// common.SysError("failed to update user status cache: " + err.Error())
// }
// })
// }
// }()
// if !fromDB && common.RedisEnabled {
// // Try Redis first
// status, err := getUserStatusCache(id)
// if err == nil {
// return status == common.UserStatusEnabled, nil
// }
// // Don't return error - fall through to DB
// }
// fromDB = true
// var user User
// err = DB.Where("id = ?", id).Select("status").Find(&user).Error
// if err != nil {
// return false, err
// }
//
// return user.Status == common.UserStatusEnabled, nil
//}
func ValidateAccessToken(token string) (user *User) {
if token == "" {
@@ -531,7 +562,6 @@ func GetUserQuota(id int, fromDB bool) (quota int, err error) {
return quota, nil
}
// Don't return error - fall through to DB
//common.SysError("failed to get user quota from cache: " + err.Error())
}
fromDB = true
err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find(&quota).Error
@@ -580,7 +610,36 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) {
return group, nil
}
func IncreaseUserQuota(id int, quota int) (err error) {
// GetUserSetting gets setting from Redis first, falls back to DB if needed
func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err error) {
var setting string
defer func() {
// Update Redis cache asynchronously on successful DB read
if shouldUpdateRedis(fromDB, err) {
gopool.Go(func() {
if err := updateUserSettingCache(id, setting); err != nil {
common.SysError("failed to update user setting cache: " + err.Error())
}
})
}
}()
if !fromDB && common.RedisEnabled {
setting, err := getUserSettingCache(id)
if err == nil {
return setting, nil
}
// Don't return error - fall through to DB
}
fromDB = true
err = DB.Model(&User{}).Where("id = ?", id).Select("setting").Find(&setting).Error
if err != nil {
return map[string]interface{}{}, err
}
return common.StrToMap(setting), nil
}
func IncreaseUserQuota(id int, quota int, db bool) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
@@ -590,7 +649,7 @@ func IncreaseUserQuota(id int, quota int) (err error) {
common.SysError("failed to increase user quota: " + err.Error())
}
})
if common.BatchUpdateEnabled {
if !db && common.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUserQuota, id, quota)
return nil
}
@@ -635,15 +694,20 @@ func DeltaUpdateUserQuota(id int, delta int) (err error) {
return nil
}
if delta > 0 {
return IncreaseUserQuota(id, delta)
return IncreaseUserQuota(id, delta, false)
} else {
return DecreaseUserQuota(id, -delta)
}
}
func GetRootUserEmail() (email string) {
DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
return email
//func GetRootUserEmail() (email string) {
// DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
// return email
//}
func GetRootUser() (user *User) {
DB.Where("role = ?", common.RoleRootUser).First(&user)
return user
}
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
@@ -725,10 +789,10 @@ func IsLinuxDOIdAlreadyTaken(linuxDOId string) bool {
return !errors.Is(err, gorm.ErrRecordNotFound)
}
func (u *User) FillUserByLinuxDOId() error {
if u.LinuxDOId == "" {
func (user *User) FillUserByLinuxDOId() error {
if user.LinuxDOId == "" {
return errors.New("linux do id is empty")
}
err := DB.Where("linux_do_id = ?", u.LinuxDOId).First(u).Error
err := DB.Where("linux_do_id = ?", user.LinuxDOId).First(user).Error
return err
}
+177 -160
View File
@@ -1,206 +1,223 @@
package model
import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"one-api/common"
"one-api/constant"
"strconv"
"time"
"github.com/bytedance/gopkg/util/gopool"
)
// Change UserCache struct to userCache
type userCache struct {
// UserBase struct remains the same as it represents the cached data structure
type UserBase struct {
Id int `json:"id"`
Group string `json:"group"`
Email string `json:"email"`
Quota int `json:"quota"`
Status int `json:"status"`
Role int `json:"role"`
Username string `json:"username"`
Setting string `json:"setting"`
}
// Rename all exported functions to private ones
// invalidateUserCache clears all user related cache
func (user *UserBase) WriteContext(c *gin.Context) {
c.Set(constant.ContextKeyUserGroup, user.Group)
c.Set(constant.ContextKeyUserQuota, user.Quota)
c.Set(constant.ContextKeyUserStatus, user.Status)
c.Set(constant.ContextKeyUserEmail, user.Email)
c.Set("username", user.Username)
c.Set(constant.ContextKeyUserSetting, user.GetSetting())
}
func (user *UserBase) GetSetting() map[string]interface{} {
if user.Setting == "" {
return nil
}
return common.StrToMap(user.Setting)
}
func (user *UserBase) SetSetting(setting map[string]interface{}) {
settingBytes, err := json.Marshal(setting)
if err != nil {
common.SysError("failed to marshal setting: " + err.Error())
return
}
user.Setting = string(settingBytes)
}
// getUserCacheKey returns the key for user cache
func getUserCacheKey(userId int) string {
return fmt.Sprintf("user:%d", userId)
}
// invalidateUserCache clears user cache
func invalidateUserCache(userId int) error {
if !common.RedisEnabled {
return nil
}
return common.RedisHDelObj(getUserCacheKey(userId))
}
keys := []string{
fmt.Sprintf(constant.UserGroupKeyFmt, userId),
fmt.Sprintf(constant.UserQuotaKeyFmt, userId),
fmt.Sprintf(constant.UserEnabledKeyFmt, userId),
fmt.Sprintf(constant.UserUsernameKeyFmt, userId),
// updateUserCache updates all user cache fields using hash
func updateUserCache(user User) error {
if !common.RedisEnabled {
return nil
}
for _, key := range keys {
if err := common.RedisDel(key); err != nil {
return fmt.Errorf("failed to delete cache key %s: %w", key, err)
return common.RedisHSetObj(
getUserCacheKey(user.Id),
user.ToBaseUser(),
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
)
}
// GetUserCache gets complete user cache from hash
func GetUserCache(userId int) (userCache *UserBase, err error) {
var user *User
var fromDB bool
defer func() {
// Update Redis cache asynchronously on successful DB read
if shouldUpdateRedis(fromDB, err) && user != nil {
gopool.Go(func() {
if err := updateUserCache(*user); err != nil {
common.SysError("failed to update user status cache: " + err.Error())
}
})
}
}
return nil
}
}()
// updateUserGroupCache updates user group cache
func updateUserGroupCache(userId int, group string) error {
if !common.RedisEnabled {
return nil
}
return common.RedisSet(
fmt.Sprintf(constant.UserGroupKeyFmt, userId),
group,
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
)
}
// updateUserQuotaCache updates user quota cache
func updateUserQuotaCache(userId int, quota int) error {
if !common.RedisEnabled {
return nil
}
return common.RedisSet(
fmt.Sprintf(constant.UserQuotaKeyFmt, userId),
fmt.Sprintf("%d", quota),
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
)
}
// updateUserStatusCache updates user status cache
func updateUserStatusCache(userId int, userEnabled bool) error {
if !common.RedisEnabled {
return nil
}
enabled := "0"
if userEnabled {
enabled = "1"
}
return common.RedisSet(
fmt.Sprintf(constant.UserEnabledKeyFmt, userId),
enabled,
time.Duration(constant.UserId2StatusCacheSeconds)*time.Second,
)
}
// updateUserNameCache updates username cache
func updateUserNameCache(userId int, username string) error {
if !common.RedisEnabled {
return nil
}
return common.RedisSet(
fmt.Sprintf(constant.UserUsernameKeyFmt, userId),
username,
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
)
}
// updateUserCache updates all user cache fields
func updateUserCache(userId int, username string, userGroup string, quota int, status int) error {
if !common.RedisEnabled {
return nil
// Try getting from Redis first
userCache, err = cacheGetUserBase(userId)
if err == nil {
return userCache, nil
}
if err := updateUserGroupCache(userId, userGroup); err != nil {
return fmt.Errorf("update group cache: %w", err)
}
if err := updateUserQuotaCache(userId, quota); err != nil {
return fmt.Errorf("update quota cache: %w", err)
}
if err := updateUserStatusCache(userId, status == common.UserStatusEnabled); err != nil {
return fmt.Errorf("update status cache: %w", err)
}
if err := updateUserNameCache(userId, username); err != nil {
return fmt.Errorf("update username cache: %w", err)
}
return nil
}
// getUserGroupCache gets user group from cache
func getUserGroupCache(userId int) (string, error) {
if !common.RedisEnabled {
return "", nil
}
return common.RedisGet(fmt.Sprintf(constant.UserGroupKeyFmt, userId))
}
// getUserQuotaCache gets user quota from cache
func getUserQuotaCache(userId int) (int, error) {
if !common.RedisEnabled {
return 0, nil
}
quotaStr, err := common.RedisGet(fmt.Sprintf(constant.UserQuotaKeyFmt, userId))
// If Redis fails, get from DB
fromDB = true
user, err = GetUserById(userId, false)
if err != nil {
return 0, err
return nil, err // Return nil and error if DB lookup fails
}
return strconv.Atoi(quotaStr)
// Create cache object from user data
userCache = &UserBase{
Id: user.Id,
Group: user.Group,
Quota: user.Quota,
Status: user.Status,
Username: user.Username,
Setting: user.Setting,
Email: user.Email,
}
return userCache, nil
}
// getUserStatusCache gets user status from cache
func getUserStatusCache(userId int) (int, error) {
func cacheGetUserBase(userId int) (*UserBase, error) {
if !common.RedisEnabled {
return 0, nil
return nil, fmt.Errorf("redis is not enabled")
}
statusStr, err := common.RedisGet(fmt.Sprintf(constant.UserEnabledKeyFmt, userId))
var userCache UserBase
// Try getting from Redis first
err := common.RedisHGetObj(getUserCacheKey(userId), &userCache)
if err != nil {
return 0, err
return nil, err
}
return strconv.Atoi(statusStr)
return &userCache, nil
}
// getUserNameCache gets username from cache
func getUserNameCache(userId int) (string, error) {
if !common.RedisEnabled {
return "", nil
}
return common.RedisGet(fmt.Sprintf(constant.UserUsernameKeyFmt, userId))
}
// getUserCache gets complete user cache
func getUserCache(userId int) (*userCache, error) {
if !common.RedisEnabled {
return nil, nil
}
group, err := getUserGroupCache(userId)
if err != nil {
return nil, fmt.Errorf("get group cache: %w", err)
}
quota, err := getUserQuotaCache(userId)
if err != nil {
return nil, fmt.Errorf("get quota cache: %w", err)
}
status, err := getUserStatusCache(userId)
if err != nil {
return nil, fmt.Errorf("get status cache: %w", err)
}
username, err := getUserNameCache(userId)
if err != nil {
return nil, fmt.Errorf("get username cache: %w", err)
}
return &userCache{
Id: userId,
Group: group,
Quota: quota,
Status: status,
Username: username,
}, nil
}
// Add atomic quota operations
// Add atomic quota operations using hash fields
func cacheIncrUserQuota(userId int, delta int64) error {
if !common.RedisEnabled {
return nil
}
key := fmt.Sprintf(constant.UserQuotaKeyFmt, userId)
return common.RedisIncr(key, delta)
return common.RedisHIncrBy(getUserCacheKey(userId), "Quota", delta)
}
func cacheDecrUserQuota(userId int, delta int64) error {
return cacheIncrUserQuota(userId, -delta)
}
// Helper functions to get individual fields if needed
func getUserGroupCache(userId int) (string, error) {
cache, err := GetUserCache(userId)
if err != nil {
return "", err
}
return cache.Group, nil
}
func getUserQuotaCache(userId int) (int, error) {
cache, err := GetUserCache(userId)
if err != nil {
return 0, err
}
return cache.Quota, nil
}
func getUserStatusCache(userId int) (int, error) {
cache, err := GetUserCache(userId)
if err != nil {
return 0, err
}
return cache.Status, nil
}
func getUserNameCache(userId int) (string, error) {
cache, err := GetUserCache(userId)
if err != nil {
return "", err
}
return cache.Username, nil
}
func getUserSettingCache(userId int) (map[string]interface{}, error) {
setting := make(map[string]interface{})
cache, err := GetUserCache(userId)
if err != nil {
return setting, err
}
return cache.GetSetting(), nil
}
// New functions for individual field updates
func updateUserStatusCache(userId int, status bool) error {
if !common.RedisEnabled {
return nil
}
statusInt := common.UserStatusEnabled
if !status {
statusInt = common.UserStatusDisabled
}
return common.RedisHSetField(getUserCacheKey(userId), "Status", fmt.Sprintf("%d", statusInt))
}
func updateUserQuotaCache(userId int, quota int) error {
if !common.RedisEnabled {
return nil
}
return common.RedisHSetField(getUserCacheKey(userId), "Quota", fmt.Sprintf("%d", quota))
}
func updateUserGroupCache(userId int, group string) error {
if !common.RedisEnabled {
return nil
}
return common.RedisHSetField(getUserCacheKey(userId), "Group", group)
}
func updateUserNameCache(userId int, username string) error {
if !common.RedisEnabled {
return nil
}
return common.RedisHSetField(getUserCacheKey(userId), "Username", username)
}
func updateUserSettingCache(userId int, setting string) error {
if !common.RedisEnabled {
return nil
}
return common.RedisHSetField(getUserCacheKey(userId), "Setting", setting)
}
+1 -1
View File
@@ -130,7 +130,7 @@ func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.TaskRelayInfo,
if err != nil {
return nil, fmt.Errorf("setup request header failed: %w", err)
}
resp, err := doRequest(c, req, info.ToRelayInfo())
resp, err := doRequest(c, req, info.RelayInfo)
if err != nil {
return nil, fmt.Errorf("do request failed: %w", err)
}
+2 -1
View File
@@ -9,7 +9,8 @@ var awsModelIDMap = map[string]string{
"claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0",
"claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0",
"claude-3-5-sonnet-20241022": "anthropic.claude-3-5-sonnet-20241022-v2:0",
"claude-3-5-haiku-20241022": "anthropic.claude-3-5-haiku-20241022-v1:0",
"claude-3-5-haiku-20241022": "anthropic.claude-3-5-haiku-20241022-v1:0",
"claude-3-7-sonnet-20250219": "anthropic.claude-3-7-sonnet-20250219-v1:0",
}
var ChannelName = "aws"
+2
View File
@@ -16,6 +16,7 @@ type AwsClaudeRequest struct {
StopSequences []string `json:"stop_sequences,omitempty"`
Tools []claude.Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
Thinking *claude.Thinking `json:"thinking,omitempty"`
}
func copyRequest(req *claude.ClaudeRequest) *AwsClaudeRequest {
@@ -30,5 +31,6 @@ func copyRequest(req *claude.ClaudeRequest) *AwsClaudeRequest {
StopSequences: req.StopSequences,
Tools: req.Tools,
ToolChoice: req.ToolChoice,
Thinking: req.Thinking,
}
}
+2
View File
@@ -11,6 +11,8 @@ var ModelList = []string{
"claude-3-5-haiku-20241022",
"claude-3-5-sonnet-20240620",
"claude-3-5-sonnet-20241022",
"claude-3-7-sonnet-20250219",
"claude-3-7-sonnet-20250219-thinking",
}
var ChannelName = "claude"
+12 -3
View File
@@ -11,6 +11,9 @@ type ClaudeMediaMessage struct {
Usage *ClaudeUsage `json:"usage,omitempty"`
StopReason *string `json:"stop_reason,omitempty"`
PartialJson string `json:"partial_json,omitempty"`
Thinking string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"`
Delta string `json:"delta,omitempty"`
// tool_calls
Id string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
@@ -54,9 +57,15 @@ type ClaudeRequest struct {
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
//ClaudeMetadata `json:"metadata,omitempty"`
Stream bool `json:"stream,omitempty"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
Stream bool `json:"stream,omitempty"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
Thinking *Thinking `json:"thinking,omitempty"`
}
type Thinking struct {
Type string `json:"type"`
BudgetTokens int `json:"budget_tokens"`
}
type ClaudeError struct {
+42 -2
View File
@@ -92,6 +92,29 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
Stream: textRequest.Stream,
Tools: claudeTools,
}
if strings.HasSuffix(textRequest.Model, "-thinking") {
if claudeRequest.MaxTokens == 0 {
claudeRequest.MaxTokens = 8192
}
// 因为BudgetTokens 必须大于1024
if claudeRequest.MaxTokens < 1280 {
claudeRequest.MaxTokens = 1280
}
// BudgetTokens 为 max_tokens 的 80%
claudeRequest.Thinking = &Thinking{
Type: "enabled",
BudgetTokens: int(float64(claudeRequest.MaxTokens) * 0.8),
}
// TODO: 临时处理
// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
claudeRequest.TopP = 0
claudeRequest.Temperature = common.GetPointer[float64](1.0)
claudeRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking")
}
if claudeRequest.MaxTokens == 0 {
claudeRequest.MaxTokens = 4096
}
@@ -308,12 +331,20 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
if claudeResponse.Delta != nil {
choice.Index = claudeResponse.Index
choice.Delta.SetContentString(claudeResponse.Delta.Text)
if claudeResponse.Delta.Type == "input_json_delta" {
switch claudeResponse.Delta.Type {
case "input_json_delta":
tools = append(tools, dto.ToolCall{
Function: dto.FunctionCall{
Arguments: claudeResponse.Delta.PartialJson,
},
})
case "signature_delta":
// 加密的不处理
signatureContent := "\n"
choice.Delta.ReasoningContent = &signatureContent
case "thinking_delta":
thinkingContent := claudeResponse.Delta.Thinking
choice.Delta.ReasoningContent = &thinkingContent
}
}
} else if claudeResponse.Type == "message_delta" {
@@ -352,6 +383,8 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
responseText = claudeResponse.Content[0].Text
}
tools := make([]dto.ToolCall, 0)
thinkingContent := ""
if reqMode == RequestModeCompletion {
content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " "))
choice := dto.OpenAITextResponseChoice{
@@ -367,7 +400,8 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
} else {
fullTextResponse.Id = claudeResponse.Id
for _, message := range claudeResponse.Content {
if message.Type == "tool_use" {
switch message.Type {
case "tool_use":
args, _ := json.Marshal(message.Input)
tools = append(tools, dto.ToolCall{
ID: message.Id,
@@ -377,6 +411,11 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
Arguments: string(args),
},
})
case "thinking":
// 加密的不管, 只输出明文的推理过程
thinkingContent = message.Thinking
case "text":
responseText = message.Text
}
}
}
@@ -391,6 +430,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
if len(tools) > 0 {
choice.Message.SetToolCalls(tools)
}
choice.Message.ReasoningContent = thinkingContent
fullTextResponse.Model = claudeResponse.Model
choices = append(choices, choice)
fullTextResponse.Choices = choices
+2 -1
View File
@@ -4,13 +4,14 @@ import (
"bytes"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
+7 -1
View File
@@ -10,6 +10,7 @@ import (
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
)
type Adaptor struct {
@@ -29,7 +30,12 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
switch info.RelayMode {
case constant.RelayModeCompletions:
return fmt.Sprintf("%s/beta/completions", info.BaseUrl), nil
default:
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
}
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
+28 -2
View File
@@ -9,9 +9,18 @@ import (
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"strings"
)
const (
BotTypeChatFlow = 1 // chatflow default
BotTypeAgent = 2
BotTypeWorkFlow = 3
BotTypeCompletion = 4
)
type Adaptor struct {
BotType int
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
@@ -25,10 +34,28 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
if strings.HasPrefix(info.UpstreamModelName, "agent") {
a.BotType = BotTypeAgent
} else if strings.HasPrefix(info.UpstreamModelName, "workflow") {
a.BotType = BotTypeWorkFlow
} else if strings.HasPrefix(info.UpstreamModelName, "chat") {
a.BotType = BotTypeCompletion
} else {
a.BotType = BotTypeChatFlow
}
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil
switch a.BotType {
case BotTypeWorkFlow:
return fmt.Sprintf("%s/v1/workflows/run", info.BaseUrl), nil
case BotTypeCompletion:
return fmt.Sprintf("%s/v1/completion-messages", info.BaseUrl), nil
case BotTypeAgent:
fallthrough
default:
return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil
}
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
@@ -53,7 +80,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
+99 -4
View File
@@ -1,15 +1,21 @@
package gemini
import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
@@ -21,8 +27,36 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
if !strings.HasPrefix(info.UpstreamModelName, "imagen") {
return nil, errors.New("not supported model for image generation")
}
// convert size to aspect ratio
aspectRatio := "1:1" // default aspect ratio
switch request.Size {
case "1024x1024":
aspectRatio = "1:1"
case "1024x1792":
aspectRatio = "9:16"
case "1792x1024":
aspectRatio = "16:9"
}
// build gemini imagen request
geminiRequest := GeminiImageRequest{
Instances: []GeminiImageInstance{
{
Prompt: request.Prompt,
},
},
Parameters: GeminiImageParameters{
SampleCount: request.N,
AspectRatio: aspectRatio,
PersonGeneration: "allow_adult", // default allow adult
},
}
return geminiRequest, nil
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
@@ -40,6 +74,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
}
}
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
return fmt.Sprintf("%s/%s/models/%s:predict", info.BaseUrl, version, info.UpstreamModelName), nil
}
action := "generateContent"
if info.IsStream {
action = "streamGenerateContent?alt=sse"
@@ -73,12 +111,15 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
return GeminiImageHandler(c, resp, info)
}
if info.IsStream {
err, usage = GeminiChatStreamHandler(c, resp, info)
} else {
@@ -87,6 +128,60 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
return
}
func GeminiImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
responseBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return nil, service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError)
}
_ = resp.Body.Close()
var geminiResponse GeminiImageResponse
if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
return nil, service.OpenAIErrorWrapper(jsonErr, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
if len(geminiResponse.Predictions) == 0 {
return nil, service.OpenAIErrorWrapper(errors.New("no images generated"), "no_images", http.StatusBadRequest)
}
// convert to openai format response
openAIResponse := dto.ImageResponse{
Created: common.GetTimestamp(),
Data: make([]dto.ImageData, 0, len(geminiResponse.Predictions)),
}
for _, prediction := range geminiResponse.Predictions {
if prediction.RaiFilteredReason != "" {
continue // skip filtered image
}
openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{
B64Json: prediction.BytesBase64Encoded,
})
}
jsonResponse, jsonErr := json.Marshal(openAIResponse)
if jsonErr != nil {
return nil, service.OpenAIErrorWrapper(jsonErr, "marshal_response_failed", http.StatusInternalServerError)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, _ = c.Writer.Write(jsonResponse)
// https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb
// each image has fixed 258 tokens
const imageTokens = 258
generatedImages := len(openAIResponse.Data)
usage = &dto.Usage{
PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens
CompletionTokens: 0, // image generation does not calculate completion tokens
TotalTokens: imageTokens * generatedImages,
}
return usage, nil
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
+10
View File
@@ -16,6 +16,16 @@ var ModelList = []string{
"gemini-2.0-pro-exp",
// thinking exp
"gemini-2.0-flash-thinking-exp",
// imagen models
"imagen-3.0-generate-002",
}
var SafetySettingList = []string{
"HARM_CATEGORY_HARASSMENT",
"HARM_CATEGORY_VIOLENCE",
"HARM_CATEGORY_SEXUALLY_EXPLICIT",
"HARM_CATEGORY_DANGEROUS_CONTENT",
"HARM_CATEGORY_CIVIC_INTEGRITY",
}
var ChannelName = "google gemini"
+27
View File
@@ -109,3 +109,30 @@ type GeminiUsageMetadata struct {
CandidatesTokenCount int `json:"candidatesTokenCount"`
TotalTokenCount int `json:"totalTokenCount"`
}
// Imagen related structs
type GeminiImageRequest struct {
Instances []GeminiImageInstance `json:"instances"`
Parameters GeminiImageParameters `json:"parameters"`
}
type GeminiImageInstance struct {
Prompt string `json:"prompt"`
}
type GeminiImageParameters struct {
SampleCount int `json:"sampleCount,omitempty"`
AspectRatio string `json:"aspectRatio,omitempty"`
PersonGeneration string `json:"personGeneration,omitempty"`
}
type GeminiImageResponse struct {
Predictions []GeminiImagePrediction `json:"predictions"`
}
type GeminiImagePrediction struct {
MimeType string `json:"mimeType"`
BytesBase64Encoded string `json:"bytesBase64Encoded"`
RaiFilteredReason string `json:"raiFilteredReason,omitempty"`
SafetyAttributes any `json:"safetyAttributes,omitempty"`
}
+11 -22
View File
@@ -11,6 +11,7 @@ import (
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"one-api/setting"
"strings"
"unicode/utf8"
@@ -22,28 +23,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
geminiRequest := GeminiChatRequest{
Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
SafetySettings: []GeminiChatSafetySettings{
{
Category: "HARM_CATEGORY_HARASSMENT",
Threshold: common.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_HATE_SPEECH",
Threshold: common.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
Threshold: common.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_DANGEROUS_CONTENT",
Threshold: common.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_CIVIC_INTEGRITY",
Threshold: common.GeminiSafetySetting,
},
},
//SafetySettings: []GeminiChatSafetySettings{},
GenerationConfig: GeminiChatGenerationConfig{
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
@@ -52,6 +32,15 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
},
}
safetySettings := make([]GeminiChatSafetySettings, 0, len(SafetySettingList))
for _, category := range SafetySettingList {
safetySettings = append(safetySettings, GeminiChatSafetySettings{
Category: category,
Threshold: setting.GetGeminiSafetySetting(category),
})
}
geminiRequest.SafetySettings = safetySettings
// openaiContent.FuncToToolCalls()
if textRequest.Tools != nil {
functions := make([]dto.FunctionCall, 0, len(textRequest.Tools))
+1 -4
View File
@@ -41,9 +41,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
if request == nil {
return nil, errors.New("request is nil")
}
mistralReq := requestOpenAI2Mistral(*request)
//common.LogJson(c, "body", mistralReq)
return mistralReq, nil
return requestOpenAI2Mistral(request), nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
@@ -55,7 +53,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
+8 -12
View File
@@ -1,25 +1,21 @@
package mistral
import (
"encoding/json"
"one-api/dto"
)
func requestOpenAI2Mistral(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
func requestOpenAI2Mistral(request *dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
messages := make([]dto.Message, 0, len(request.Messages))
for _, message := range request.Messages {
if !message.IsStringContent() {
mediaMessages := message.ParseContent()
for j, mediaMessage := range mediaMessages {
if mediaMessage.Type == dto.ContentTypeImageURL {
imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
mediaMessage.ImageUrl = imageUrl.Url
mediaMessages[j] = mediaMessage
}
mediaMessages := message.ParseContent()
for j, mediaMessage := range mediaMessages {
if mediaMessage.Type == dto.ContentTypeImageURL {
imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
mediaMessage.ImageUrl = imageUrl.Url
mediaMessages[j] = mediaMessage
}
messageRaw, _ := json.Marshal(mediaMessages)
message.Content = messageRaw
}
message.SetMediaContent(mediaMessages)
messages = append(messages, dto.Message{
Role: message.Role,
Content: message.Content,
+2 -1
View File
@@ -39,6 +39,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}
@@ -46,7 +47,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
if request == nil {
return nil, errors.New("request is nil")
}
return requestOpenAI2Ollama(*request), nil
return requestOpenAI2Ollama(*request)
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+18 -14
View File
@@ -3,18 +3,22 @@ package ollama
import "one-api/dto"
type OllamaRequest struct {
Model string `json:"model,omitempty"`
Messages []dto.Message `json:"messages,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
Seed float64 `json:"seed,omitempty"`
Topp float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
Tools []dto.ToolCall `json:"tools,omitempty"`
ResponseFormat any `json:"response_format,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
Model string `json:"model,omitempty"`
Messages []dto.Message `json:"messages,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
Seed float64 `json:"seed,omitempty"`
Topp float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
Tools []dto.ToolCall `json:"tools,omitempty"`
ResponseFormat any `json:"response_format,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
Suffix any `json:"suffix,omitempty"`
StreamOptions *dto.StreamOptions `json:"stream_options,omitempty"`
Prompt any `json:"prompt,omitempty"`
}
type Options struct {
@@ -35,7 +39,7 @@ type OllamaEmbeddingRequest struct {
}
type OllamaEmbeddingResponse struct {
Error string `json:"error,omitempty"`
Model string `json:"model"`
Error string `json:"error,omitempty"`
Model string `json:"model"`
Embedding [][]float64 `json:"embeddings,omitempty"`
}
+30 -4
View File
@@ -9,14 +9,36 @@ import (
"net/http"
"one-api/dto"
"one-api/service"
"strings"
)
func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
messages := make([]dto.Message, 0, len(request.Messages))
for _, message := range request.Messages {
if !message.IsStringContent() {
mediaMessages := message.ParseContent()
for j, mediaMessage := range mediaMessages {
if mediaMessage.Type == dto.ContentTypeImageURL {
imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
// check if not base64
if strings.HasPrefix(imageUrl.Url, "http") {
fileData, err := service.GetFileBase64FromUrl(imageUrl.Url)
if err != nil {
return nil, err
}
imageUrl.Url = fmt.Sprintf("data:%s;base64,%s", fileData.MimeType, fileData.Base64Data)
}
mediaMessage.ImageUrl = imageUrl
mediaMessages[j] = mediaMessage
}
}
message.SetMediaContent(mediaMessages)
}
messages = append(messages, dto.Message{
Role: message.Role,
Content: message.Content,
Role: message.Role,
Content: message.Content,
ToolCalls: message.ToolCalls,
ToolCallId: message.ToolCallId,
})
}
str, ok := request.Stop.(string)
@@ -36,10 +58,14 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
TopK: request.TopK,
Stop: Stop,
Tools: request.Tools,
MaxTokens: request.MaxTokens,
ResponseFormat: request.ResponseFormat,
FrequencyPenalty: request.FrequencyPenalty,
PresencePenalty: request.PresencePenalty,
}
Prompt: request.Prompt,
StreamOptions: request.StreamOptions,
Suffix: request.Suffix,
}, nil
}
func requestOpenAI2Embeddings(request dto.EmbeddingRequest) *OllamaEmbeddingRequest {
+1 -1
View File
@@ -119,7 +119,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
request.MaxCompletionTokens = request.MaxTokens
request.MaxTokens = 0
}
if strings.HasPrefix(request.Model, "o3") {
if strings.HasPrefix(request.Model, "o3") || strings.HasPrefix(request.Model, "o1") {
request.Temperature = nil
}
if strings.HasSuffix(request.Model, "-high") {
+69 -19
View File
@@ -5,10 +5,6 @@ import (
"bytes"
"encoding/json"
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
"io"
"math"
"mime/multipart"
@@ -23,21 +19,66 @@ import (
"strings"
"sync"
"time"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
)
func sendStreamData(c *gin.Context, data string, forceFormat bool) error {
func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
if data == "" {
return nil
}
if forceFormat {
var lastStreamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(data), &lastStreamResponse); err != nil {
return err
}
if !forceFormat && !thinkToContent {
return service.StringData(c, data)
}
var lastStreamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(data), &lastStreamResponse); err != nil {
return err
}
if !thinkToContent {
return service.ObjectData(c, lastStreamResponse)
}
return service.StringData(c, data)
// Handle think to content conversion
if info.IsFirstResponse {
response := lastStreamResponse.Copy()
for i := range response.Choices {
response.Choices[i].Delta.SetContentString("<think>\n")
response.Choices[i].Delta.SetReasoningContent("")
}
service.ObjectData(c, response)
}
if lastStreamResponse.Choices == nil || len(lastStreamResponse.Choices) == 0 {
return service.ObjectData(c, lastStreamResponse)
}
// Process each choice
for i, choice := range lastStreamResponse.Choices {
// Handle transition from thinking to content
if len(choice.Delta.GetContentString()) > 0 && !info.SendLastReasoningResponse {
response := lastStreamResponse.Copy()
for j := range response.Choices {
response.Choices[j].Delta.SetContentString("\n</think>")
response.Choices[j].Delta.SetReasoningContent("")
}
info.SendLastReasoningResponse = true
service.ObjectData(c, response)
}
// Convert reasoning content to regular content
if len(choice.Delta.GetReasoningContent()) > 0 {
lastStreamResponse.Choices[i].Delta.SetContentString(choice.Delta.GetReasoningContent())
lastStreamResponse.Choices[i].Delta.SetReasoningContent("")
}
}
return service.ObjectData(c, lastStreamResponse)
}
func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
@@ -56,11 +97,14 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
var usage = &dto.Usage{}
var streamItems []string // store stream items
var forceFormat bool
var thinkToContent bool
if info.ChannelType == common.ChannelTypeCustom {
if forceFmt, ok := info.ChannelSetting["force_format"].(bool); ok {
forceFormat = forceFmt
}
if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
forceFormat = forceFmt
}
if think2Content, ok := info.ChannelSetting[constant.ChannelSettingThinkingToContent].(bool); ok {
thinkToContent = think2Content
}
toolCount := 0
@@ -84,9 +128,12 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
)
gopool.Go(func() {
for scanner.Scan() {
info.SetFirstResponseTime()
//info.SetFirstResponseTime()
ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
data := scanner.Text()
if common.DebugEnabled {
println(data)
}
if len(data) < 6 { // ignore blank line or wrong format
continue
}
@@ -98,10 +145,11 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
data = strings.TrimSpace(data)
if !strings.HasPrefix(data, "[DONE]") {
if lastStreamData != "" {
err := sendStreamData(c, lastStreamData, forceFormat)
err := sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
if err != nil {
common.LogError(c, "streaming error: "+err.Error())
}
info.SetFirstResponseTime()
}
lastStreamData = data
streamItems = append(streamItems, data)
@@ -141,7 +189,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
}
}
if shouldSendLastResp {
sendStreamData(c, lastStreamData, forceFormat)
sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
}
// 计算token
@@ -162,6 +210,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
//}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > toolCount {
toolCount = len(choice.Delta.ToolCalls)
@@ -182,6 +231,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
//}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > toolCount {
toolCount = len(choice.Delta.ToolCalls)
@@ -273,7 +323,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
completionTokens := 0
for _, choice := range simpleResponse.Choices {
ctkm, _ := service.CountTextToken(string(choice.Message.Content), model)
ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent, model)
completionTokens += ctkm
}
simpleResponse.Usage = dto.Usage{
+8
View File
@@ -36,6 +36,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil
} else if info.RelayMode == constant.RelayModeChatCompletions {
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
} else if info.RelayMode == constant.RelayModeCompletions {
return fmt.Sprintf("%s/v1/completions", info.BaseUrl), nil
}
return "", errors.New("invalid relay mode")
}
@@ -72,6 +74,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
case constant.RelayModeCompletions:
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
case constant.RelayModeEmbeddings:
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
+1 -2
View File
@@ -90,8 +90,7 @@ func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReq
mediaMessages[j] = mediaMessage
}
}
messageRaw, _ := json.Marshal(mediaMessages)
message.Content = messageRaw
message.SetMediaContent(mediaMessages)
}
messages = append(messages, dto.Message{
Role: message.Role,
+51 -84
View File
@@ -13,24 +13,25 @@ import (
)
type RelayInfo struct {
ChannelType int
ChannelId int
TokenId int
TokenKey string
UserId int
Group string
TokenUnlimited bool
StartTime time.Time
FirstResponseTime time.Time
setFirstResponse bool
ApiType int
IsStream bool
IsPlayground bool
UsePrice bool
RelayMode int
UpstreamModelName string
OriginModelName string
RecodeModelName string
ChannelType int
ChannelId int
TokenId int
TokenKey string
UserId int
Group string
TokenUnlimited bool
StartTime time.Time
FirstResponseTime time.Time
IsFirstResponse bool
SendLastReasoningResponse bool
ApiType int
IsStream bool
IsPlayground bool
UsePrice bool
RelayMode int
UpstreamModelName string
OriginModelName string
//RecodeModelName string
RequestURLPath string
ApiVersion string
PromptTokens int
@@ -39,6 +40,7 @@ type RelayInfo struct {
BaseUrl string
SupportStreamOptions bool
ShouldIncludeUsage bool
IsModelMapped bool
ClientWs *websocket.Conn
TargetWs *websocket.Conn
InputAudioFormat string
@@ -48,6 +50,21 @@ type RelayInfo struct {
AudioUsage bool
ReasoningEffort string
ChannelSetting map[string]interface{}
UserSetting map[string]interface{}
UserEmail string
UserQuota int
}
// 定义支持流式选项的通道类型
var streamSupportedChannels = map[int]bool{
common.ChannelTypeOpenAI: true,
common.ChannelTypeAnthropic: true,
common.ChannelTypeAws: true,
common.ChannelTypeGemini: true,
common.ChannelCloudflare: true,
common.ChannelTypeAzure: true,
common.ChannelTypeVolcEngine: true,
common.ChannelTypeOllama: true,
}
func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
@@ -75,6 +92,10 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
apiType, _ := relayconstant.ChannelType2APIType(channelType)
info := &RelayInfo{
UserQuota: c.GetInt(constant.ContextKeyUserQuota),
UserSetting: c.GetStringMap(constant.ContextKeyUserSetting),
UserEmail: c.GetString(constant.ContextKeyUserEmail),
IsFirstResponse: true,
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
BaseUrl: c.GetString("base_url"),
RequestURLPath: c.Request.URL.String(),
@@ -89,12 +110,13 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
FirstResponseTime: startTime.Add(-time.Second),
OriginModelName: c.GetString("original_model"),
UpstreamModelName: c.GetString("original_model"),
RecodeModelName: c.GetString("recode_model"),
ApiType: apiType,
ApiVersion: c.GetString("api_version"),
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
Organization: c.GetString("channel_organization"),
ChannelSetting: channelSetting,
//RecodeModelName: c.GetString("original_model"),
IsModelMapped: false,
ApiType: apiType,
ApiVersion: c.GetString("api_version"),
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
Organization: c.GetString("channel_organization"),
ChannelSetting: channelSetting,
}
if strings.HasPrefix(c.Request.URL.Path, "/pg") {
info.IsPlayground = true
@@ -110,10 +132,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
if info.ChannelType == common.ChannelTypeVertexAi {
info.ApiVersion = c.GetString("region")
}
if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic ||
info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini ||
info.ChannelType == common.ChannelCloudflare || info.ChannelType == common.ChannelTypeAzure ||
info.ChannelType == common.ChannelTypeVolcEngine {
if streamSupportedChannels[info.ChannelType] {
info.SupportStreamOptions = true
}
return info
@@ -128,26 +147,14 @@ func (info *RelayInfo) SetIsStream(isStream bool) {
}
func (info *RelayInfo) SetFirstResponseTime() {
if !info.setFirstResponse {
if info.IsFirstResponse {
info.FirstResponseTime = time.Now()
info.setFirstResponse = true
info.IsFirstResponse = false
}
}
type TaskRelayInfo struct {
ChannelType int
ChannelId int
TokenId int
UserId int
Group string
StartTime time.Time
ApiType int
RelayMode int
UpstreamModelName string
RequestURLPath string
ApiKey string
BaseUrl string
*RelayInfo
Action string
OriginTaskID string
@@ -155,48 +162,8 @@ type TaskRelayInfo struct {
}
func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
channelType := c.GetInt("channel_type")
channelId := c.GetInt("channel_id")
tokenId := c.GetInt("token_id")
userId := c.GetInt("id")
group := c.GetString("group")
startTime := time.Now()
apiType, _ := relayconstant.ChannelType2APIType(channelType)
info := &TaskRelayInfo{
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
BaseUrl: c.GetString("base_url"),
RequestURLPath: c.Request.URL.String(),
ChannelType: channelType,
ChannelId: channelId,
TokenId: tokenId,
UserId: userId,
Group: group,
StartTime: startTime,
ApiType: apiType,
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
}
if info.BaseUrl == "" {
info.BaseUrl = common.ChannelBaseURLs[channelType]
RelayInfo: GenRelayInfo(c),
}
return info
}
func (info *TaskRelayInfo) ToRelayInfo() *RelayInfo {
return &RelayInfo{
ChannelType: info.ChannelType,
ChannelId: info.ChannelId,
TokenId: info.TokenId,
UserId: info.UserId,
Group: info.Group,
StartTime: info.StartTime,
ApiType: info.ApiType,
RelayMode: info.RelayMode,
UpstreamModelName: info.UpstreamModelName,
RequestURLPath: info.RequestURLPath,
ApiKey: info.ApiKey,
BaseUrl: info.BaseUrl,
}
}
+25
View File
@@ -0,0 +1,25 @@
package helper
import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"one-api/relay/common"
)
func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error {
// map model name
modelMapping := c.GetString("model_mapping")
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return fmt.Errorf("unmarshal_model_mapping_failed")
}
if modelMap[info.OriginModelName] != "" {
info.UpstreamModelName = modelMap[info.OriginModelName]
info.IsModelMapped = true
}
}
return nil
}
+41
View File
@@ -0,0 +1,41 @@
package helper
import (
"github.com/gin-gonic/gin"
"one-api/common"
relaycommon "one-api/relay/common"
"one-api/setting"
)
type PriceData struct {
ModelPrice float64
ModelRatio float64
GroupRatio float64
UsePrice bool
ShouldPreConsumedQuota int
}
func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) PriceData {
modelPrice, usePrice := common.GetModelPrice(info.OriginModelName, false)
groupRatio := setting.GetGroupRatio(info.Group)
var preConsumedQuota int
var modelRatio float64
if !usePrice {
preConsumedTokens := common.PreConsumedQuota
if maxTokens != 0 {
preConsumedTokens = promptTokens + maxTokens
}
modelRatio = common.GetModelRatio(info.OriginModelName)
ratio := modelRatio * groupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
}
return PriceData{
ModelPrice: modelPrice,
ModelRatio: modelRatio,
GroupRatio: groupRatio,
UsePrice: usePrice,
ShouldPreConsumedQuota: preConsumedQuota,
}
}
+13 -20
View File
@@ -1,7 +1,6 @@
package relay
import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
@@ -11,8 +10,10 @@ import (
"one-api/model"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
"one-api/setting"
"strings"
)
func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) {
@@ -27,8 +28,9 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
return nil, errors.New("model is required")
}
if setting.ShouldCheckPromptSensitive() {
err := service.CheckSensitiveInput(audioRequest.Input)
words, err := service.CheckSensitiveInput(audioRequest.Input)
if err != nil {
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ",")))
return nil, err
}
}
@@ -73,15 +75,13 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
relayInfo.PromptTokens = promptTokens
}
modelRatio := common.GetModelRatio(audioRequest.Model)
groupRatio := setting.GetGroupRatio(relayInfo.Group)
ratio := modelRatio * groupRatio
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
priceData := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0)
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
}
preConsumedQuota, userQuota, openaiErr = preConsumeQuota(c, preConsumedQuota, relayInfo)
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil {
return openaiErr
}
@@ -91,19 +91,12 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
}
}()
// map model name
modelMapping := c.GetString("model_mapping")
if modelMapping != "" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[audioRequest.Model] != "" {
audioRequest.Model = modelMap[audioRequest.Model]
}
err = helper.ModelMappedHelper(c, relayInfo)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
}
relayInfo.UpstreamModelName = audioRequest.Model
audioRequest.Model = relayInfo.UpstreamModelName
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
@@ -140,7 +133,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
return openaiErr
}
postConsumeQuota(c, relayInfo, audioRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, 0, false, "")
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
return nil
}
+16 -24
View File
@@ -12,6 +12,7 @@ import (
"one-api/dto"
"one-api/model"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/setting"
"strings"
@@ -60,15 +61,16 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
// return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
//}
if setting.ShouldCheckPromptSensitive() {
err := service.CheckSensitiveInput(imageRequest.Prompt)
words, err := service.CheckSensitiveInput(imageRequest.Prompt)
if err != nil {
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ",")))
return nil, err
}
}
return imageRequest, nil
}
func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
relayInfo := relaycommon.GenRelayInfo(c)
imageRequest, err := getAndValidImageRequest(c, relayInfo)
@@ -77,29 +79,20 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
return service.OpenAIErrorWrapper(err, "invalid_image_request", http.StatusBadRequest)
}
// map model name
modelMapping := c.GetString("model_mapping")
if modelMapping != "" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[imageRequest.Model] != "" {
imageRequest.Model = modelMap[imageRequest.Model]
}
err = helper.ModelMappedHelper(c, relayInfo)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
}
relayInfo.UpstreamModelName = imageRequest.Model
modelPrice, success := common.GetModelPrice(imageRequest.Model, true)
if !success {
modelRatio := common.GetModelRatio(imageRequest.Model)
imageRequest.Model = relayInfo.UpstreamModelName
priceData := helper.ModelPriceHelper(c, relayInfo, 0, 0)
if !priceData.UsePrice {
// modelRatio 16 = modelPrice $0.04
// per 1 modelRatio = $0.04 / 16
modelPrice = 0.0025 * modelRatio
priceData.ModelPrice = 0.0025 * priceData.ModelRatio
}
groupRatio := setting.GetGroupRatio(relayInfo.Group)
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
sizeRatio := 1.0
@@ -122,11 +115,11 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
}
}
imageRatio := modelPrice * sizeRatio * qualityRatio * float64(imageRequest.N)
quota := int(imageRatio * groupRatio * common.QuotaPerUnit)
priceData.ModelPrice *= sizeRatio * qualityRatio * float64(imageRequest.N)
quota := int(priceData.ModelPrice * priceData.GroupRatio * common.QuotaPerUnit)
if userQuota-quota < 0 {
return service.OpenAIErrorWrapperLocal(errors.New(fmt.Sprintf("image pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, quota)), "insufficient_user_quota", http.StatusBadRequest)
return service.OpenAIErrorWrapperLocal(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), "insufficient_user_quota", http.StatusForbidden)
}
adaptor := GetAdaptor(relayInfo.ApiType)
@@ -184,7 +177,6 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
}
logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality)
postConsumeQuota(c, relayInfo, imageRequest.Model, usage, 0, 0, userQuota, 0, groupRatio, imageRatio, true, logContent)
postConsumeQuota(c, relayInfo, usage, 0, userQuota, priceData, logContent)
return nil
}
+8 -9
View File
@@ -2,7 +2,6 @@ package relay
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
@@ -192,9 +191,9 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
if err != nil {
return &mjResp.Response
}
defer func(ctx context.Context) {
defer func() {
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
err := model.PostConsumeQuota(relayInfo, userQuota, quota, 0, true)
err := service.PostConsumeQuota(relayInfo, quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
@@ -208,14 +207,14 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
other := make(map[string]interface{})
other["model_price"] = modelPrice
other["group_ratio"] = groupRatio
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName,
model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
quota, logContent, tokenId, userQuota, 0, false, group, other)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}
}(c.Request.Context())
}()
midjResponse := &mjResp.Response
midjourneyTask := &model.Midjourney{
UserId: userId,
@@ -498,9 +497,9 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
}
midjResponse := &midjResponseWithStatus.Response
defer func(ctx context.Context) {
defer func() {
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
err := model.PostConsumeQuota(relayInfo, userQuota, quota, 0, true)
err := service.PostConsumeQuota(relayInfo, quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
@@ -510,14 +509,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
other := make(map[string]interface{})
other["model_price"] = modelPrice
other["group_ratio"] = groupRatio
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName,
model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
quota, logContent, tokenId, userQuota, 0, false, group, other)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}
}(c.Request.Context())
}()
// 文档:https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md
//1-提交成功
+41 -61
View File
@@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"io"
"math"
"net/http"
@@ -14,6 +15,7 @@ import (
"one-api/model"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
"one-api/setting"
"strings"
@@ -75,40 +77,21 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
}
// map model name
//isModelMapped := false
modelMapping := c.GetString("model_mapping")
//isModelMapped := false
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[textRequest.Model] != "" {
//isModelMapped = true
textRequest.Model = modelMap[textRequest.Model]
// set upstream model name
//isModelMapped = true
}
}
relayInfo.UpstreamModelName = textRequest.Model
relayInfo.RecodeModelName = textRequest.Model
modelPrice, getModelPriceSuccess := common.GetModelPrice(textRequest.Model, false)
groupRatio := setting.GetGroupRatio(relayInfo.Group)
var preConsumedQuota int
var ratio float64
var modelRatio float64
//err := service.SensitiveWordsCheck(textRequest)
if setting.ShouldCheckPromptSensitive() {
err = checkRequestSensitive(textRequest, relayInfo)
words, err := checkRequestSensitive(textRequest, relayInfo)
if err != nil {
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
}
}
err = helper.ModelMappedHelper(c, relayInfo)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
}
textRequest.Model = relayInfo.UpstreamModelName
// 获取 promptTokens,如果上下文中已经存在,则直接使用
var promptTokens int
if value, exists := c.Get("prompt_tokens"); exists {
@@ -123,20 +106,10 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
c.Set("prompt_tokens", promptTokens)
}
if !getModelPriceSuccess {
preConsumedTokens := common.PreConsumedQuota
if textRequest.MaxTokens != 0 {
preConsumedTokens = promptTokens + int(textRequest.MaxTokens)
}
modelRatio = common.GetModelRatio(textRequest.Model)
ratio = modelRatio * groupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
}
priceData := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens))
// pre-consume quota 预消耗配额
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil {
return openaiErr
}
@@ -219,10 +192,10 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
return openaiErr
}
if strings.HasPrefix(relayInfo.RecodeModelName, "gpt-4o-audio") {
service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") {
service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
} else {
postConsumeQuota(c, relayInfo, relayInfo.RecodeModelName, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
}
return nil
}
@@ -247,19 +220,20 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
return promptTokens, err
}
func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) error {
func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) ([]string, error) {
var err error
var words []string
switch info.RelayMode {
case relayconstant.RelayModeChatCompletions:
err = service.CheckSensitiveMessages(textRequest.Messages)
words, err = service.CheckSensitiveMessages(textRequest.Messages)
case relayconstant.RelayModeCompletions:
err = service.CheckSensitiveInput(textRequest.Prompt)
words, err = service.CheckSensitiveInput(textRequest.Prompt)
case relayconstant.RelayModeModerations:
err = service.CheckSensitiveInput(textRequest.Input)
words, err = service.CheckSensitiveInput(textRequest.Input)
case relayconstant.RelayModeEmbeddings:
err = service.CheckSensitiveInput(textRequest.Input)
words, err = service.CheckSensitiveInput(textRequest.Input)
}
return err
return words, err
}
// 预扣费并返回用户剩余配额
@@ -272,8 +246,9 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
return 0, 0, service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
if userQuota-preConsumedQuota < 0 {
return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, preConsumedQuota), "insufficient_user_quota", http.StatusBadRequest)
return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), "insufficient_user_quota", http.StatusForbidden)
}
relayInfo.UserQuota = userQuota
if userQuota > 100*preConsumedQuota {
// 用户额度充足,判断令牌额度是否充足
if !relayInfo.TokenUnlimited {
@@ -282,18 +257,18 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
if tokenQuota > 100*preConsumedQuota {
// 令牌额度充足,信任令牌
preConsumedQuota = 0
common.LogInfo(c, fmt.Sprintf("user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, userQuota, relayInfo.TokenId, tokenQuota))
common.LogInfo(c, fmt.Sprintf("user %d quota %s and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota), relayInfo.TokenId, tokenQuota))
}
} else {
// in this case, we do not pre-consume quota
// because the user has enough quota
preConsumedQuota = 0
common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", relayInfo.UserId, userQuota))
common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %s, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota)))
}
}
if preConsumedQuota > 0 {
err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
err := service.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
if err != nil {
return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
@@ -307,20 +282,19 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, userQuota int, preConsumedQuota int) {
if preConsumedQuota != 0 {
go func() {
gopool.Go(func() {
relayInfoCopy := *relayInfo
err := model.PostConsumeQuota(&relayInfoCopy, userQuota, -preConsumedQuota, 0, false)
err := service.PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false)
if err != nil {
common.SysError("error return pre-consumed quota: " + err.Error())
}
}()
})
}
}
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
modelPrice float64, usePrice bool, extraContent string) {
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
if usage == nil {
usage = &dto.Usage{
PromptTokens: relayInfo.PromptTokens,
@@ -332,12 +306,18 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
promptTokens := usage.PromptTokens
completionTokens := usage.CompletionTokens
modelName := relayInfo.OriginModelName
tokenName := ctx.GetString("token_name")
completionRatio := common.GetCompletionRatio(modelName)
ratio := priceData.ModelRatio * priceData.GroupRatio
modelRatio := priceData.ModelRatio
groupRatio := priceData.GroupRatio
modelPrice := priceData.ModelPrice
usePrice := priceData.UsePrice
quota := 0
if !usePrice {
if !priceData.UsePrice {
quota = promptTokens + int(math.Round(float64(completionTokens)*completionRatio))
quota = int(math.Round(float64(quota) * ratio))
if ratio != 0 && quota <= 0 {
@@ -368,7 +348,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
//}
quotaDelta := quota - preConsumedQuota
if quotaDelta != 0 {
err := model.PostConsumeQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
err := service.PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
+9 -32
View File
@@ -10,8 +10,8 @@ import (
"one-api/dto"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
"one-api/setting"
)
func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
@@ -47,43 +47,20 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest)
}
// map model name
modelMapping := c.GetString("model_mapping")
//isModelMapped := false
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[embeddingRequest.Model] != "" {
embeddingRequest.Model = modelMap[embeddingRequest.Model]
// set upstream model name
//isModelMapped = true
}
err = helper.ModelMappedHelper(c, relayInfo)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
}
relayInfo.UpstreamModelName = embeddingRequest.Model
modelPrice, success := common.GetModelPrice(embeddingRequest.Model, false)
groupRatio := setting.GetGroupRatio(relayInfo.Group)
var preConsumedQuota int
var ratio float64
var modelRatio float64
embeddingRequest.Model = relayInfo.UpstreamModelName
promptToken := getEmbeddingPromptToken(*embeddingRequest)
if !success {
preConsumedTokens := promptToken
modelRatio = common.GetModelRatio(embeddingRequest.Model)
ratio = modelRatio * groupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
}
relayInfo.PromptTokens = promptToken
priceData := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
// pre-consume quota 预消耗配额
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil {
return openaiErr
}
@@ -132,6 +109,6 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
}
postConsumeQuota(c, relayInfo, embeddingRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "")
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
return nil
}
+9 -32
View File
@@ -9,8 +9,8 @@ import (
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/setting"
)
func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
@@ -40,43 +40,20 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest)
}
// map model name
modelMapping := c.GetString("model_mapping")
//isModelMapped := false
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[rerankRequest.Model] != "" {
rerankRequest.Model = modelMap[rerankRequest.Model]
// set upstream model name
//isModelMapped = true
}
err = helper.ModelMappedHelper(c, relayInfo)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
}
relayInfo.UpstreamModelName = rerankRequest.Model
modelPrice, success := common.GetModelPrice(rerankRequest.Model, false)
groupRatio := setting.GetGroupRatio(relayInfo.Group)
var preConsumedQuota int
var ratio float64
var modelRatio float64
rerankRequest.Model = relayInfo.UpstreamModelName
promptToken := getRerankPromptToken(*rerankRequest)
if !success {
preConsumedTokens := promptToken
modelRatio = common.GetModelRatio(rerankRequest.Model)
ratio = modelRatio * groupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
}
relayInfo.PromptTokens = promptToken
priceData := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
// pre-consume quota 预消耗配额
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil {
return openaiErr
}
@@ -124,6 +101,6 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
}
postConsumeQuota(c, relayInfo, rerankRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "")
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
return nil
}
+4 -5
View File
@@ -2,7 +2,6 @@ package relay
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
@@ -109,11 +108,11 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
return
}
defer func(ctx context.Context) {
defer func() {
// release quota
if relayInfo.ConsumeQuota && taskErr == nil {
err := model.PostConsumeQuota(relayInfo.ToRelayInfo(), userQuota, quota, 0, true)
err := service.PostConsumeQuota(relayInfo.RelayInfo, quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
@@ -123,13 +122,13 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
other := make(map[string]interface{})
other["model_price"] = modelPrice
other["group_ratio"] = groupRatio
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, 0, 0,
model.RecordConsumeLog(c, relayInfo.UserId, relayInfo.ChannelId, 0, 0,
modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, relayInfo.Group, other)
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
}
}
}(c.Request.Context())
}()
taskID, taskData, taskErr := adaptor.DoResponse(c, resp, relayInfo)
if taskErr != nil {
+1
View File
@@ -56,6 +56,7 @@ func SetApiRouter(router *gin.Engine) {
selfRoute.POST("/pay", controller.RequestEpay)
selfRoute.POST("/amount", controller.RequestAmount)
selfRoute.POST("/aff_transfer", controller.TransferAffQuota)
selfRoute.PUT("/setting", controller.UpdateUserSetting)
}
adminRoute := userRoute.Group("/")
+1
View File
@@ -24,6 +24,7 @@ func SetRelayRouter(router *gin.Engine) {
}
relayV1Router := router.Group("/v1")
relayV1Router.Use(middleware.TokenAuth())
relayV1Router.Use(middleware.ModelRequestRateLimit())
{
// WebSocket 路由
wsRouter := relayV1Router.Group("")
+37 -9
View File
@@ -2,6 +2,7 @@ package service
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"one-api/common"
@@ -9,19 +10,46 @@ import (
"strings"
)
// WorkerRequest Worker请求的数据结构
type WorkerRequest struct {
URL string `json:"url"`
Key string `json:"key"`
Method string `json:"method,omitempty"`
Headers map[string]string `json:"headers,omitempty"`
Body json.RawMessage `json:"body,omitempty"`
}
// DoWorkerRequest 通过Worker发送请求
func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) {
if !setting.EnableWorker() {
return nil, fmt.Errorf("worker not enabled")
}
if !strings.HasPrefix(req.URL, "https") {
return nil, fmt.Errorf("only support https url")
}
workerUrl := setting.WorkerUrl
if !strings.HasSuffix(workerUrl, "/") {
workerUrl += "/"
}
// 序列化worker请求数据
workerPayload, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("failed to marshal worker payload: %v", err)
}
return http.Post(workerUrl, "application/json", bytes.NewBuffer(workerPayload))
}
func DoDownloadRequest(originUrl string) (resp *http.Response, err error) {
if setting.EnableWorker() {
common.SysLog(fmt.Sprintf("downloading file from worker: %s", originUrl))
if !strings.HasPrefix(originUrl, "https") {
return nil, fmt.Errorf("only support https url")
req := &WorkerRequest{
URL: originUrl,
Key: setting.WorkerValidKey,
}
workerUrl := setting.WorkerUrl
if !strings.HasSuffix(workerUrl, "/") {
workerUrl += "/"
}
// post request to worker
data := []byte(`{"url":"` + originUrl + `","key":"` + setting.WorkerValidKey + `"}`)
return http.Post(setting.WorkerUrl, "application/json", bytes.NewBuffer(data))
return DoWorkerRequest(req)
} else {
common.SysLog(fmt.Sprintf("downloading from origin: %s", originUrl))
return http.Get(originUrl)
+5 -5
View File
@@ -4,7 +4,7 @@ import (
"fmt"
"net/http"
"one-api/common"
relaymodel "one-api/dto"
"one-api/dto"
"one-api/model"
"one-api/setting"
"strings"
@@ -15,17 +15,17 @@ func DisableChannel(channelId int, channelName string, reason string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled, reason)
subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
notifyRootUser(subject, content)
NotifyRootUser(subject, content, dto.NotifyTypeChannelUpdate)
}
func EnableChannel(channelId int, channelName string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled, "")
subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
notifyRootUser(subject, content)
NotifyRootUser(subject, content, dto.NotifyTypeChannelUpdate)
}
func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatusCode) bool {
func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) bool {
if !common.AutomaticDisableChannelEnabled {
return false
}
@@ -75,7 +75,7 @@ func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatus
return false
}
func ShouldEnableChannel(err error, openaiWithStatusErr *relaymodel.OpenAIErrorWithStatusCode, status int) bool {
func ShouldEnableChannel(err error, openaiWithStatusErr *dto.OpenAIErrorWithStatusCode, status int) bool {
if !common.AutomaticEnableChannelEnabled {
return false
}
+4
View File
@@ -16,6 +16,10 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m
if relayInfo.ReasoningEffort != "" {
other["reasoning_effort"] = relayInfo.ReasoningEffort
}
if relayInfo.IsModelMapped {
other["is_model_mapped"] = true
other["upstream_model_name"] = relayInfo.UpstreamModelName
}
adminInfo := make(map[string]interface{})
adminInfo["use_channel"] = ctx.GetStringSlice("use_channel")
other["admin_info"] = adminInfo
+117
View File
@@ -0,0 +1,117 @@
package service
import (
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"one-api/common"
"one-api/constant"
"strconv"
"sync"
"time"
)
// notifyLimitStore is used for in-memory rate limiting when Redis is disabled
var (
notifyLimitStore sync.Map
cleanupOnce sync.Once
)
type limitCount struct {
Count int
Timestamp time.Time
}
func getDuration() time.Duration {
minute := constant.NotificationLimitDurationMinute
return time.Duration(minute) * time.Minute
}
// startCleanupTask starts a background task to clean up expired entries
func startCleanupTask() {
gopool.Go(func() {
for {
time.Sleep(time.Hour)
now := time.Now()
notifyLimitStore.Range(func(key, value interface{}) bool {
if limit, ok := value.(limitCount); ok {
if now.Sub(limit.Timestamp) >= getDuration() {
notifyLimitStore.Delete(key)
}
}
return true
})
}
})
}
// CheckNotificationLimit checks if the user has exceeded their notification limit
// Returns true if the user can send notification, false if limit exceeded
func CheckNotificationLimit(userId int, notifyType string) (bool, error) {
if common.RedisEnabled {
return checkRedisLimit(userId, notifyType)
}
return checkMemoryLimit(userId, notifyType)
}
func checkRedisLimit(userId int, notifyType string) (bool, error) {
key := fmt.Sprintf("notify_limit:%d:%s:%s", userId, notifyType, time.Now().Format("2006010215"))
// Get current count
count, err := common.RedisGet(key)
if err != nil && err.Error() != "redis: nil" {
return false, fmt.Errorf("failed to get notification count: %w", err)
}
// If key doesn't exist, initialize it
if count == "" {
err = common.RedisSet(key, "1", getDuration())
return true, err
}
currentCount, _ := strconv.Atoi(count)
limit := constant.NotifyLimitCount
// Check if limit is already reached
if currentCount >= limit {
return false, nil
}
// Only increment if under limit
err = common.RedisIncr(key, 1)
if err != nil {
return false, fmt.Errorf("failed to increment notification count: %w", err)
}
return true, nil
}
func checkMemoryLimit(userId int, notifyType string) (bool, error) {
// Ensure cleanup task is started
cleanupOnce.Do(startCleanupTask)
key := fmt.Sprintf("%d:%s:%s", userId, notifyType, time.Now().Format("2006010215"))
now := time.Now()
// Get current limit count or initialize new one
var currentLimit limitCount
if value, ok := notifyLimitStore.Load(key); ok {
currentLimit = value.(limitCount)
// Check if the entry has expired
if now.Sub(currentLimit.Timestamp) >= getDuration() {
currentLimit = limitCount{Count: 0, Timestamp: now}
}
} else {
currentLimit = limitCount{Count: 0, Timestamp: now}
}
// Increment count
currentLimit.Count++
// Check against limits
limit := constant.NotifyLimitCount
// Store updated count
notifyLimitStore.Store(key, currentLimit)
return currentLimit.Count <= limit, nil
}
+102 -14
View File
@@ -3,11 +3,14 @@ package service
import (
"errors"
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"math"
"one-api/common"
constant2 "one-api/constant"
"one-api/dto"
"one-api/model"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/setting"
"strings"
"time"
@@ -66,7 +69,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
return err
}
modelName := relayInfo.UpstreamModelName
modelName := relayInfo.OriginModelName
textInputTokens := usage.InputTokenDetails.TextTokens
textOutTokens := usage.OutputTokenDetails.TextTokens
audioInputTokens := usage.InputTokenDetails.AudioTokens
@@ -92,14 +95,14 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
quota := calculateAudioQuota(quotaInfo)
if userQuota < quota {
return errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota))
return fmt.Errorf("user quota is not enough, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota))
}
if !token.UnlimitedQuota && token.RemainQuota < quota {
return errors.New(fmt.Sprintf("令牌额度不足,剩余额度为 %d", token.RemainQuota))
return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", common.FormatQuota(token.RemainQuota), common.FormatQuota(quota))
}
err = model.PostConsumeQuota(relayInfo, 0, quota, 0, false)
err = PostConsumeQuota(relayInfo, quota, 0, false)
if err != nil {
return err
}
@@ -120,7 +123,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
tokenName := ctx.GetString("token_name")
completionRatio := common.GetCompletionRatio(modelName)
audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName)
audioRatio := common.GetAudioRatio(relayInfo.OriginModelName)
audioCompletionRatio := common.GetAudioCompletionRatio(modelName)
quotaInfo := QuotaInfo{
@@ -171,8 +174,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
}
func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
usage *dto.Usage, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
modelPrice float64, usePrice bool, extraContent string) {
usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
textInputTokens := usage.PromptTokensDetails.TextTokens
@@ -182,9 +184,14 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
audioOutTokens := usage.CompletionTokenDetails.AudioTokens
tokenName := ctx.GetString("token_name")
completionRatio := common.GetCompletionRatio(relayInfo.RecodeModelName)
audioRatio := common.GetAudioRatio(relayInfo.RecodeModelName)
audioCompletionRatio := common.GetAudioCompletionRatio(relayInfo.RecodeModelName)
completionRatio := common.GetCompletionRatio(relayInfo.OriginModelName)
audioRatio := common.GetAudioRatio(relayInfo.OriginModelName)
audioCompletionRatio := common.GetAudioCompletionRatio(relayInfo.OriginModelName)
modelRatio := priceData.ModelRatio
groupRatio := priceData.GroupRatio
modelPrice := priceData.ModelPrice
usePrice := priceData.UsePrice
quotaInfo := QuotaInfo{
InputDetails: TokenDetails{
@@ -195,7 +202,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
TextTokens: textOutTokens,
AudioTokens: audioOutTokens,
},
ModelName: relayInfo.RecodeModelName,
ModelName: relayInfo.OriginModelName,
UsePrice: usePrice,
ModelRatio: modelRatio,
GroupRatio: groupRatio,
@@ -218,11 +225,11 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
quota = 0
logContent += fmt.Sprintf("(可能是上游超时)")
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
"tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.RecodeModelName, preConsumedQuota))
"tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.OriginModelName, preConsumedQuota))
} else {
quotaDelta := quota - preConsumedQuota
if quotaDelta != 0 {
err := model.PostConsumeQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
@@ -231,7 +238,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
}
logModel := relayInfo.RecodeModelName
logModel := relayInfo.OriginModelName
if extraContent != "" {
logContent += ", " + extraContent
}
@@ -239,3 +246,84 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
}
func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
if relayInfo.IsPlayground {
return nil
}
//if relayInfo.TokenUnlimited {
// return nil
//}
token, err := model.GetTokenByKey(relayInfo.TokenKey, false)
if err != nil {
return err
}
if !relayInfo.TokenUnlimited && token.RemainQuota < quota {
return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", common.FormatQuota(token.RemainQuota), common.FormatQuota(quota))
}
err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
if err != nil {
return err
}
return nil
}
func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQuota int, sendEmail bool) (err error) {
if quota > 0 {
err = model.DecreaseUserQuota(relayInfo.UserId, quota)
} else {
err = model.IncreaseUserQuota(relayInfo.UserId, -quota, false)
}
if err != nil {
return err
}
if !relayInfo.IsPlayground {
if quota > 0 {
err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
} else {
err = model.IncreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, -quota)
}
if err != nil {
return err
}
}
if sendEmail {
if (quota + preConsumedQuota) != 0 {
checkAndSendQuotaNotify(relayInfo, quota, preConsumedQuota)
}
}
return nil
}
func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQuota int) {
gopool.Go(func() {
userSetting := relayInfo.UserSetting
threshold := common.QuotaRemindThreshold
if userCustomThreshold, ok := userSetting[constant2.UserSettingQuotaWarningThreshold]; ok {
threshold = int(userCustomThreshold.(float64))
}
//noMoreQuota := userCache.Quota-(quota+preConsumedQuota) <= 0
quotaTooLow := false
consumeQuota := quota + preConsumedQuota
if relayInfo.UserQuota-consumeQuota < threshold {
quotaTooLow = true
}
if quotaTooLow {
prompt := "您的额度即将用尽"
topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
content := "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。<br/>充值链接:<a href='{{value}}'>{{value}}</a>"
err := NotifyUser(relayInfo.UserId, relayInfo.UserEmail, relayInfo.UserSetting, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, common.FormatQuota(relayInfo.UserQuota), topUpLink, topUpLink}))
if err != nil {
common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", relayInfo.UserId, err.Error()))
}
}
})
}
+38 -29
View File
@@ -8,48 +8,47 @@ import (
"strings"
)
func CheckSensitiveMessages(messages []dto.Message) error {
func CheckSensitiveMessages(messages []dto.Message) ([]string, error) {
if len(messages) == 0 {
return nil, nil
}
for _, message := range messages {
if len(message.Content) > 0 {
if message.IsStringContent() {
stringContent := message.StringContent()
if ok, words := SensitiveWordContains(stringContent); ok {
return errors.New("sensitive words: " + strings.Join(words, ","))
}
arrayContent := message.ParseContent()
for _, m := range arrayContent {
if m.Type == "image_url" {
// TODO: check image url
continue
}
} else {
arrayContent := message.ParseContent()
for _, m := range arrayContent {
if m.Type == "image_url" {
// TODO: check image url
} else {
if ok, words := SensitiveWordContains(m.Text); ok {
return errors.New("sensitive words: " + strings.Join(words, ","))
}
}
// 检查 text 是否为空
if m.Text == "" {
continue
}
if ok, words := SensitiveWordContains(m.Text); ok {
return words, errors.New("sensitive words detected")
}
}
}
return nil
return nil, nil
}
func CheckSensitiveText(text string) error {
func CheckSensitiveText(text string) ([]string, error) {
if ok, words := SensitiveWordContains(text); ok {
return errors.New("sensitive words: " + strings.Join(words, ","))
return words, errors.New("sensitive words detected")
}
return nil
return nil, nil
}
func CheckSensitiveInput(input any) error {
func CheckSensitiveInput(input any) ([]string, error) {
switch v := input.(type) {
case string:
return CheckSensitiveText(v)
case []string:
text := ""
var builder strings.Builder
for _, s := range v {
text += s
builder.WriteString(s)
}
return CheckSensitiveText(text)
return CheckSensitiveText(builder.String())
}
return CheckSensitiveText(fmt.Sprintf("%v", input))
}
@@ -59,8 +58,11 @@ func SensitiveWordContains(text string) (bool, []string) {
if len(setting.SensitiveWords) == 0 {
return false, nil
}
if len(text) == 0 {
return false, nil
}
checkText := strings.ToLower(text)
return AcSearch(checkText, setting.SensitiveWords, false)
return AcSearch(checkText, setting.SensitiveWords, true)
}
// SensitiveWordReplace 敏感词替换,返回是否包含敏感词和替换后的文本
@@ -72,14 +74,21 @@ func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string,
m := InitAc(setting.SensitiveWords)
hits := m.MultiPatternSearch([]rune(checkText), returnImmediately)
if len(hits) > 0 {
words := make([]string, 0)
words := make([]string, 0, len(hits))
var builder strings.Builder
builder.Grow(len(text))
lastPos := 0
for _, hit := range hits {
pos := hit.Pos
word := string(hit.Word)
text = text[:pos] + "**###**" + text[pos+len(word):]
builder.WriteString(text[lastPos:pos])
builder.WriteString("**###**")
lastPos = pos + len(word)
words = append(words, word)
}
return true, words, text
builder.WriteString(text[lastPos:])
return true, words, builder.String()
}
return false, nil, text
}
+21 -23
View File
@@ -78,6 +78,9 @@ func getTokenEncoder(model string) *tiktoken.Tiktoken {
}
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
if text == "" {
return 0
}
return len(tokenEncoder.Encode(text, nil, nil))
}
@@ -282,30 +285,25 @@ func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, mod
tokenNum += tokensPerMessage
tokenNum += getTokenNum(tokenEncoder, message.Role)
if len(message.Content) > 0 {
if message.IsStringContent() {
stringContent := message.StringContent()
tokenNum += getTokenNum(tokenEncoder, stringContent)
if message.Name != nil {
tokenNum += tokensPerName
tokenNum += getTokenNum(tokenEncoder, *message.Name)
}
} else {
arrayContent := message.ParseContent()
for _, m := range arrayContent {
if m.Type == dto.ContentTypeImageURL {
imageUrl := m.ImageUrl.(dto.MessageImageUrl)
imageTokenNum, err := getImageToken(info, &imageUrl, model, stream)
if err != nil {
return 0, err
}
tokenNum += imageTokenNum
log.Printf("image token num: %d", imageTokenNum)
} else if m.Type == dto.ContentTypeInputAudio {
// TODO: 音频token数量计算
tokenNum += 100
} else {
tokenNum += getTokenNum(tokenEncoder, m.Text)
if message.Name != nil {
tokenNum += tokensPerName
tokenNum += getTokenNum(tokenEncoder, *message.Name)
}
arrayContent := message.ParseContent()
for _, m := range arrayContent {
if m.Type == dto.ContentTypeImageURL {
imageUrl := m.ImageUrl.(dto.MessageImageUrl)
imageTokenNum, err := getImageToken(info, &imageUrl, model, stream)
if err != nil {
return 0, err
}
tokenNum += imageTokenNum
log.Printf("image token num: %d", imageTokenNum)
} else if m.Type == dto.ContentTypeInputAudio {
// TODO: 音频token数量计算
tokenNum += 100
} else {
tokenNum += getTokenNum(tokenEncoder, m.Text)
}
}
}
+66 -8
View File
@@ -3,15 +3,73 @@ package service
import (
"fmt"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/model"
"strings"
)
func notifyRootUser(subject string, content string) {
if common.RootUserEmail == "" {
common.RootUserEmail = model.GetRootUserEmail()
}
err := common.SendEmail(subject, common.RootUserEmail, content)
if err != nil {
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
}
func NotifyRootUser(t string, subject string, content string) {
user := model.GetRootUser().ToBaseUser()
_ = NotifyUser(user.Id, user.Email, user.GetSetting(), dto.NewNotify(t, subject, content, nil))
}
func NotifyUser(userId int, userEmail string, userSetting map[string]interface{}, data dto.Notify) error {
notifyType, ok := userSetting[constant.UserSettingNotifyType]
if !ok {
notifyType = constant.NotifyTypeEmail
}
// Check notification limit
canSend, err := CheckNotificationLimit(userId, data.Type)
if err != nil {
common.SysError(fmt.Sprintf("failed to check notification limit: %s", err.Error()))
return err
}
if !canSend {
return fmt.Errorf("notification limit exceeded for user %d with type %s", userId, notifyType)
}
switch notifyType {
case constant.NotifyTypeEmail:
// check setting email
if settingEmail, ok := userSetting[constant.UserSettingNotificationEmail]; ok {
userEmail = settingEmail.(string)
}
if userEmail == "" {
common.SysLog(fmt.Sprintf("user %d has no email, skip sending email", userId))
return nil
}
return sendEmailNotify(userEmail, data)
case constant.NotifyTypeWebhook:
webhookURL, ok := userSetting[constant.UserSettingWebhookUrl]
if !ok {
common.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", userId))
return nil
}
webhookURLStr, ok := webhookURL.(string)
if !ok {
common.SysError(fmt.Sprintf("user %d webhook url is not string type", userId))
return nil
}
// 获取 webhook secret
var webhookSecret string
if secret, ok := userSetting[constant.UserSettingWebhookSecret]; ok {
webhookSecret, _ = secret.(string)
}
return SendWebhookNotify(webhookURLStr, webhookSecret, data)
}
return nil
}
func sendEmailNotify(userEmail string, data dto.Notify) error {
// make email content
content := data.Content
// 处理占位符
for _, value := range data.Values {
content = strings.Replace(content, dto.ContentValueParam, fmt.Sprintf("%v", value), 1)
}
return common.SendEmail(data.Title, userEmail, content)
}
+118
View File
@@ -0,0 +1,118 @@
package service
import (
"bytes"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"net/http"
"one-api/dto"
"one-api/setting"
"time"
)
// WebhookPayload webhook 通知的负载数据
type WebhookPayload struct {
Type string `json:"type"`
Title string `json:"title"`
Content string `json:"content"`
Values []interface{} `json:"values,omitempty"`
Timestamp int64 `json:"timestamp"`
}
// generateSignature 生成 webhook 签名
func generateSignature(secret string, payload []byte) string {
h := hmac.New(sha256.New, []byte(secret))
h.Write(payload)
return hex.EncodeToString(h.Sum(nil))
}
// SendWebhookNotify 发送 webhook 通知
func SendWebhookNotify(webhookURL string, secret string, data dto.Notify) error {
// 处理占位符
content := data.Content
for _, value := range data.Values {
content = fmt.Sprintf(content, value)
}
// 构建 webhook 负载
payload := WebhookPayload{
Type: data.Type,
Title: data.Title,
Content: content,
Values: data.Values,
Timestamp: time.Now().Unix(),
}
// 序列化负载
payloadBytes, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("failed to marshal webhook payload: %v", err)
}
// 创建 HTTP 请求
var req *http.Request
var resp *http.Response
if setting.EnableWorker() {
// 构建worker请求数据
workerReq := &WorkerRequest{
URL: webhookURL,
Key: setting.WorkerValidKey,
Method: http.MethodPost,
Headers: map[string]string{
"Content-Type": "application/json",
},
Body: payloadBytes,
}
// 如果有secret,添加签名到headers
if secret != "" {
signature := generateSignature(secret, payloadBytes)
workerReq.Headers["X-Webhook-Signature"] = signature
workerReq.Headers["Authorization"] = "Bearer " + secret
}
resp, err = DoWorkerRequest(workerReq)
if err != nil {
return fmt.Errorf("failed to send webhook request through worker: %v", err)
}
defer resp.Body.Close()
// 检查响应状态
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return fmt.Errorf("webhook request failed with status code: %d", resp.StatusCode)
}
} else {
req, err = http.NewRequest(http.MethodPost, webhookURL, bytes.NewBuffer(payloadBytes))
if err != nil {
return fmt.Errorf("failed to create webhook request: %v", err)
}
// 设置请求头
req.Header.Set("Content-Type", "application/json")
// 如果有 secret,生成签名
if secret != "" {
signature := generateSignature(secret, payloadBytes)
req.Header.Set("X-Webhook-Signature", signature)
}
// 发送请求
client := GetImpatientHttpClient()
resp, err = client.Do(req)
if err != nil {
return fmt.Errorf("failed to send webhook request: %v", err)
}
defer resp.Body.Close()
// 检查响应状态
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return fmt.Errorf("webhook request failed with status code: %d", resp.StatusCode)
}
}
return nil
}
+45
View File
@@ -0,0 +1,45 @@
package setting
import (
"encoding/json"
"one-api/common"
)
var geminiSafetySettings = map[string]string{
"default": "OFF",
"HARM_CATEGORY_CIVIC_INTEGRITY": "BLOCK_NONE",
}
func GetGeminiSafetySetting(key string) string {
if value, ok := geminiSafetySettings[key]; ok {
return value
}
return geminiSafetySettings["default"]
}
func GeminiSafetySettingFromJsonString(jsonString string) {
geminiSafetySettings = map[string]string{}
err := json.Unmarshal([]byte(jsonString), &geminiSafetySettings)
if err != nil {
geminiSafetySettings = map[string]string{
"default": "OFF",
"HARM_CATEGORY_CIVIC_INTEGRITY": "BLOCK_NONE",
}
}
// check must have default
if _, ok := geminiSafetySettings["default"]; !ok {
geminiSafetySettings["default"] = common.GeminiSafetySetting
}
}
func GeminiSafetySettingsJsonString() string {
// check must have default
if _, ok := geminiSafetySettings["default"]; !ok {
geminiSafetySettings["default"] = common.GeminiSafetySetting
}
jsonString, err := json.Marshal(geminiSafetySettings)
if err != nil {
return "{}"
}
return string(jsonString)
}
+1
View File
@@ -23,6 +23,7 @@ func AutomaticDisableKeywordsFromString(s string) {
ak := strings.Split(s, "\n")
for _, k := range ak {
k = strings.TrimSpace(k)
k = strings.ToLower(k)
if k != "" {
AutomaticDisableKeywords = append(AutomaticDisableKeywords, k)
}
+6
View File
@@ -0,0 +1,6 @@
package setting
var ModelRequestRateLimitEnabled = false
var ModelRequestRateLimitDurationMinutes = 1
var ModelRequestRateLimitCount = 0
var ModelRequestRateLimitSuccessCount = 1000
+19 -16
View File
@@ -357,6 +357,13 @@ const ChannelsTable = () => {
dataIndex: 'operate',
render: (text, record, index) => {
if (record.children === undefined) {
// 构建模型测试菜单
const modelMenuItems = record.models.split(',').map(model => ({
node: 'item',
name: model,
onClick: () => testChannel(record, model)
}));
return (
<div>
<SplitButtonGroup
@@ -374,7 +381,7 @@ const ChannelsTable = () => {
<Dropdown
trigger="click"
position="bottomRight"
menu={record.test_models}
menu={modelMenuItems} // 使用即时生成的菜单项
>
<Button
style={{ padding: '8px 4px' }}
@@ -545,17 +552,6 @@ const ChannelsTable = () => {
let channelTags = {};
for (let i = 0; i < channels.length; i++) {
channels[i].key = '' + channels[i].id;
let test_models = [];
channels[i].models.split(',').forEach((item, index) => {
test_models.push({
node: 'item',
name: item,
onClick: () => {
testChannel(channels[i], item);
}
});
});
channels[i].test_models = test_models;
if (!enableTagMode) {
channelDates.push(channels[i]);
} else {
@@ -801,7 +797,8 @@ const ChannelsTable = () => {
const updateChannelProperty = (channelId, updateFn) => {
// Create a new copy of channels array
const newChannels = [...channels];
let updated = false;
// Find and update the correct channel
newChannels.forEach(channel => {
if (channel.children !== undefined) {
@@ -809,26 +806,32 @@ const ChannelsTable = () => {
channel.children.forEach(child => {
if (child.id === channelId) {
updateFn(child);
updated = true;
}
});
} else if (channel.id === channelId) {
// Direct channel match
updateFn(channel);
updated = true;
}
});
// Update state with new array to trigger re-render
setChannels(newChannels);
// Only update state if we actually modified a channel
if (updated) {
setChannels(newChannels);
}
};
const testChannel = async (record, model) => {
const res = await API.get(`/api/channel/test/${record.id}?model=${model}`);
const { success, message, time } = res.data;
if (success) {
// Also update the channels state to persist the change
updateChannelProperty(record.id, (channel) => {
channel.response_time = time * 1000;
channel.test_time = Date.now() / 1000;
});
showInfo(t('通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。').replace('${name}', record.name).replace('${time.toFixed(2)}', time.toFixed(2)));
} else {
showError(message);
+86 -14
View File
@@ -15,7 +15,7 @@ import {
Button, Descriptions,
Form,
Layout,
Modal,
Modal, Popover,
Select,
Space,
Spin,
@@ -34,6 +34,7 @@ import {
import Paragraph from '@douyinfe/semi-ui/lib/es/typography/paragraph';
import { getLogOther } from '../helpers/other.js';
import { StyleContext } from '../context/Style/index.js';
import { IconInherit, IconRefresh } from '@douyinfe/semi-icons';
const { Header } = Layout;
@@ -141,7 +142,78 @@ const LogsTable = () => {
</Tag>
);
}
}
}
function renderModelName(record) {
let other = getLogOther(record.other);
let modelMapped = other?.is_model_mapped && other?.upstream_model_name && other?.upstream_model_name !== '';
if (!modelMapped) {
return <Tag
color={stringToColor(record.model_name)}
size='large'
onClick={(event) => {
copyText(event, record.model_name).then(r => {});
}}
>
{' '}{record.model_name}{' '}
</Tag>;
} else {
return (
<>
<Space vertical align={'start'}>
<Popover content={
<div style={{padding: 10}}>
<Space vertical align={'start'}>
<Tag
color={stringToColor(record.model_name)}
size='large'
onClick={(event) => {
copyText(event, record.model_name).then(r => {});
}}
>
{t('请求并计费模型')}{' '}{record.model_name}{' '}
</Tag>
<Tag
color={stringToColor(other.upstream_model_name)}
size='large'
onClick={(event) => {
copyText(event, other.upstream_model_name).then(r => {});
}}
>
{t('实际模型')}{' '}{other.upstream_model_name}{' '}
</Tag>
</Space>
</div>
}>
<Tag
color={stringToColor(record.model_name)}
size='large'
onClick={(event) => {
copyText(event, record.model_name).then(r => {});
}}
suffixIcon={<IconRefresh style={{width: '0.8em', height: '0.8em', opacity: 0.6}} />}
>
{' '}{record.model_name}{' '}
</Tag>
</Popover>
{/*<Tooltip content={t('实际模型')}>*/}
{/* <Tag*/}
{/* color={stringToColor(other.upstream_model_name)}*/}
{/* size='large'*/}
{/* onClick={(event) => {*/}
{/* copyText(event, other.upstream_model_name).then(r => {});*/}
{/* }}*/}
{/* >*/}
{/* {' '}{other.upstream_model_name}{' '}*/}
{/* </Tag>*/}
{/*</Tooltip>*/}
</Space>
</>
);
}
}
const columns = [
{
@@ -272,18 +344,7 @@ const LogsTable = () => {
dataIndex: 'model_name',
render: (text, record, index) => {
return record.type === 0 || record.type === 2 ? (
<>
<Tag
color={stringToColor(text)}
size='large'
onClick={(event) => {
copyText(event, text);
}}
>
{' '}
{text}{' '}
</Tag>
</>
<>{renderModelName(record)}</>
) : (
<></>
);
@@ -580,6 +641,17 @@ const LogsTable = () => {
value: logs[i].content,
});
if (logs[i].type === 2) {
let modelMapped = other?.is_model_mapped && other?.upstream_model_name && other?.upstream_model_name !== '';
if (modelMapped) {
expandDataLocal.push({
key: t('请求并计费模型'),
value: logs[i].model_name,
});
expandDataLocal.push({
key: t('实际模型'),
value: other.upstream_model_name,
});
}
let content = '';
if (other?.ws || other?.audio) {
content = renderAudioModelPrice(
+82
View File
@@ -0,0 +1,82 @@
import React, { useEffect, useState } from 'react';
import { Card, Spin, Tabs } from '@douyinfe/semi-ui';
import SettingsGeneral from '../pages/Setting/Operation/SettingsGeneral.js';
import SettingsDrawing from '../pages/Setting/Operation/SettingsDrawing.js';
import SettingsSensitiveWords from '../pages/Setting/Operation/SettingsSensitiveWords.js';
import SettingsLog from '../pages/Setting/Operation/SettingsLog.js';
import SettingsDataDashboard from '../pages/Setting/Operation/SettingsDataDashboard.js';
import SettingsMonitoring from '../pages/Setting/Operation/SettingsMonitoring.js';
import SettingsCreditLimit from '../pages/Setting/Operation/SettingsCreditLimit.js';
import SettingsMagnification from '../pages/Setting/Operation/SettingsMagnification.js';
import ModelSettingsVisualEditor from '../pages/Setting/Operation/ModelSettingsVisualEditor.js';
import GroupRatioSettings from '../pages/Setting/Operation/GroupRatioSettings.js';
import ModelRatioSettings from '../pages/Setting/Operation/ModelRatioSettings.js';
import { API, showError, showSuccess } from '../helpers';
import SettingsChats from '../pages/Setting/Operation/SettingsChats.js';
import { useTranslation } from 'react-i18next';
import SettingGeminiModel from '../pages/Setting/Model/SettingGeminiModel.js';
const ModelSetting = () => {
const { t } = useTranslation();
let [inputs, setInputs] = useState({
GeminiSafetySettings: '',
});
let [loading, setLoading] = useState(false);
const getOptions = async () => {
const res = await API.get('/api/option/');
const { success, message, data } = res.data;
if (success) {
let newInputs = {};
data.forEach((item) => {
if (
item.key === 'GeminiSafetySettings'
) {
item.value = JSON.stringify(JSON.parse(item.value), null, 2);
}
if (
item.key.endsWith('Enabled')
) {
newInputs[item.key] = item.value === 'true' ? true : false;
} else {
newInputs[item.key] = item.value;
}
});
setInputs(newInputs);
} else {
showError(message);
}
};
async function onRefresh() {
try {
setLoading(true);
await getOptions();
// showSuccess('刷新成功');
} catch (error) {
showError('刷新失败');
} finally {
setLoading(false);
}
}
useEffect(() => {
onRefresh();
}, []);
return (
<>
<Spin spinning={loading} size='large'>
{/* Gemini */}
<Card style={{ marginTop: '10px' }}>
<SettingGeminiModel options={inputs} refresh={onRefresh} />
</Card>
</Spin>
</>
);
};
export default ModelSetting;
+177 -25
View File
@@ -26,6 +26,10 @@ import {
Tag,
Typography,
Collapsible,
Select,
Radio,
RadioGroup,
AutoComplete,
} from '@douyinfe/semi-ui';
import {
getQuotaPerUnit,
@@ -67,14 +71,16 @@ const PersonalSetting = () => {
const [transferAmount, setTransferAmount] = useState(0);
const [isModelsExpanded, setIsModelsExpanded] = useState(false);
const MODELS_DISPLAY_COUNT = 10; // 默认显示的模型数量
const [notificationSettings, setNotificationSettings] = useState({
warningType: 'email',
warningThreshold: 100000,
webhookUrl: '',
webhookSecret: '',
notificationEmail: ''
});
const [showWebhookDocs, setShowWebhookDocs] = useState(false);
useEffect(() => {
// let user = localStorage.getItem('user');
// if (user) {
// userDispatch({ type: 'login', payload: user });
// }
// console.log(localStorage.getItem('user'))
let status = localStorage.getItem('status');
if (status) {
status = JSON.parse(status);
@@ -105,6 +111,19 @@ const PersonalSetting = () => {
return () => clearInterval(countdownInterval); // Clean up on unmount
}, [disableButton, countdown]);
useEffect(() => {
if (userState?.user?.setting) {
const settings = JSON.parse(userState.user.setting);
setNotificationSettings({
warningType: settings.notify_type || 'email',
warningThreshold: settings.quota_warning_threshold || 500000,
webhookUrl: settings.webhook_url || '',
webhookSecret: settings.webhook_secret || '',
notificationEmail: settings.notification_email || ''
});
}
}, [userState?.user?.setting]);
const handleInputChange = (name, value) => {
setInputs((inputs) => ({...inputs, [name]: value}));
};
@@ -300,7 +319,36 @@ const PersonalSetting = () => {
}
};
const handleNotificationSettingChange = (type, value) => {
setNotificationSettings(prev => ({
...prev,
[type]: value.target ? value.target.value : value // 处理 Radio 事件对象
}));
};
const saveNotificationSettings = async () => {
try {
const res = await API.put('/api/user/setting', {
notify_type: notificationSettings.warningType,
quota_warning_threshold: parseFloat(notificationSettings.warningThreshold),
webhook_url: notificationSettings.webhookUrl,
webhook_secret: notificationSettings.webhookSecret,
notification_email: notificationSettings.notificationEmail
});
if (res.data.success) {
showSuccess(t('通知设置已更新'));
await getUserData();
} else {
showError(res.data.message);
}
} catch (error) {
showError(t('更新通知设置失败'));
}
};
return (
<div>
<Layout>
<Layout.Content>
@@ -526,9 +574,7 @@ const PersonalSetting = () => {
</div>
<div style={{marginTop: 10}}>
<Typography.Text strong>{t('微信')}</Typography.Text>
<div
style={{display: 'flex', justifyContent: 'space-between'}}
>
<div style={{display: 'flex', justifyContent: 'space-between'}}>
<div>
<Input
value={
@@ -541,12 +587,16 @@ const PersonalSetting = () => {
</div>
<div>
<Button
disabled={
(userState.user && userState.user.wechat_id !== '') ||
!status.wechat_login
}
disabled={!status.wechat_login}
onClick={() => {
setShowWeChatBindModal(true);
}}
>
{status.wechat_login ? t('绑定') : t('未启用')}
{userState.user && userState.user.wechat_id !== ''
? t('修改绑定')
: status.wechat_login
? t('绑定')
: t('未启用')}
</Button>
</div>
</div>
@@ -672,18 +722,8 @@ const PersonalSetting = () => {
style={{marginTop: '10px'}}
/>
)}
{status.wechat_login && (
<Button
onClick={() => {
setShowWeChatBindModal(true);
}}
>
{t('绑定微信账号')}
</Button>
)}
<Modal
onCancel={() => setShowWeChatBindModal(false)}
// onOpen={() => setShowWeChatBindModal(true)}
visible={showWeChatBindModal}
size={'small'}
>
@@ -707,9 +747,121 @@ const PersonalSetting = () => {
</Modal>
</div>
</Card>
<Card style={{marginTop: 10}}>
<Typography.Title heading={6}>{t('通知设置')}</Typography.Title>
<div style={{marginTop: 20}}>
<Typography.Text strong>{t('通知方式')}</Typography.Text>
<div style={{marginTop: 10}}>
<RadioGroup
value={notificationSettings.warningType}
onChange={value => handleNotificationSettingChange('warningType', value)}
>
<Radio value="email">{t('邮件通知')}</Radio>
<Radio value="webhook">{t('Webhook通知')}</Radio>
</RadioGroup>
</div>
</div>
{notificationSettings.warningType === 'webhook' && (
<>
<div style={{marginTop: 20}}>
<Typography.Text strong>{t('Webhook地址')}</Typography.Text>
<div style={{marginTop: 10}}>
<Input
value={notificationSettings.webhookUrl}
onChange={val => handleNotificationSettingChange('webhookUrl', val)}
placeholder={t('请输入Webhook地址,例如: https://example.com/webhook')}
/>
<Typography.Text type="secondary" style={{marginTop: 8, display: 'block'}}>
{t('只支持https,系统将以 POST 方式发送通知,请确保地址可以接收 POST 请求')}
</Typography.Text>
<Typography.Text type="secondary" style={{marginTop: 8, display: 'block'}}>
<div style={{cursor: 'pointer'}} onClick={() => setShowWebhookDocs(!showWebhookDocs)}>
{t('Webhook请求结构')} {showWebhookDocs ? '▼' : '▶'}
</div>
<Collapsible isOpen={showWebhookDocs}>
<pre style={{marginTop: 4, background: 'var(--semi-color-fill-0)', padding: 8, borderRadius: 4}}>
{`{
"type": "quota_exceed", // 通知类型
"title": "标题", // 通知标题
"content": "通知内容", // 通知内容,支持 {{value}} 变量占位符
"values": ["值1", "值2"], // 按顺序替换content中的 {{value}} 占位符
"timestamp": 1739950503 // 时间戳
}
示例
{
"type": "quota_exceed",
"title": "额度预警通知",
"content": "您的额度即将用尽,当前剩余额度为 {{value}}",
"values": ["$0.99"],
"timestamp": 1739950503
}`}
</pre>
</Collapsible>
</Typography.Text>
</div>
</div>
<div style={{marginTop: 20}}>
<Typography.Text strong>{t('接口凭证(可选)')}</Typography.Text>
<div style={{marginTop: 10}}>
<Input
value={notificationSettings.webhookSecret}
onChange={val => handleNotificationSettingChange('webhookSecret', val)}
placeholder={t('请输入密钥')}
/>
<Typography.Text type="secondary" style={{marginTop: 8, display: 'block'}}>
{t('密钥将以 Bearer 方式添加到请求头中,用于验证webhook请求的合法性')}
</Typography.Text>
<Typography.Text type="secondary" style={{marginTop: 4, display: 'block'}}>
{t('Authorization: Bearer your-secret-key')}
</Typography.Text>
</div>
</div>
</>
)}
{notificationSettings.warningType === 'email' && (
<div style={{marginTop: 20}}>
<Typography.Text strong>{t('通知邮箱')}</Typography.Text>
<div style={{marginTop: 10}}>
<Input
value={notificationSettings.notificationEmail}
onChange={val => handleNotificationSettingChange('notificationEmail', val)}
placeholder={t('留空则使用账号绑定的邮箱')}
/>
<Typography.Text type="secondary" style={{marginTop: 8, display: 'block'}}>
{t('设置用于接收额度预警的邮箱地址,不填则使用账号绑定的邮箱')}
</Typography.Text>
</div>
</div>
)}
<div style={{marginTop: 20}}>
<Typography.Text strong>{t('额度预警阈值')} {renderQuotaWithPrompt(notificationSettings.warningThreshold)}</Typography.Text>
<div style={{marginTop: 10}}>
<AutoComplete
value={notificationSettings.warningThreshold}
onChange={val => handleNotificationSettingChange('warningThreshold', val)}
style={{width: 200}}
placeholder={t('请输入预警额度')}
data={[
{ value: 100000, label: '0.2$' },
{ value: 500000, label: '1$' },
{ value: 1000000, label: '5$' },
{ value: 5000000, label: '10$' }
]}
/>
</div>
<Typography.Text type="secondary" style={{marginTop: 10, display: 'block'}}>
{t('当剩余额度低于此数值时,系统将通过选择的方式发送通知')}
</Typography.Text>
</div>
<div style={{marginTop: 20}}>
<Button type="primary" onClick={saveNotificationSettings}>
{t('保存设置')}
</Button>
</div>
</Card>
<Modal
onCancel={() => setShowEmailBindModal(false)}
// onOpen={() => setShowEmailBindModal(true)}
onOk={bindEmail}
visible={showEmailBindModal}
size={'small'}
+80
View File
@@ -0,0 +1,80 @@
import React, { useEffect, useState } from 'react';
import { Card, Spin, Tabs } from '@douyinfe/semi-ui';
import SettingsGeneral from '../pages/Setting/Operation/SettingsGeneral.js';
import SettingsDrawing from '../pages/Setting/Operation/SettingsDrawing.js';
import SettingsSensitiveWords from '../pages/Setting/Operation/SettingsSensitiveWords.js';
import SettingsLog from '../pages/Setting/Operation/SettingsLog.js';
import SettingsDataDashboard from '../pages/Setting/Operation/SettingsDataDashboard.js';
import SettingsMonitoring from '../pages/Setting/Operation/SettingsMonitoring.js';
import SettingsCreditLimit from '../pages/Setting/Operation/SettingsCreditLimit.js';
import SettingsMagnification from '../pages/Setting/Operation/SettingsMagnification.js';
import ModelSettingsVisualEditor from '../pages/Setting/Operation/ModelSettingsVisualEditor.js';
import GroupRatioSettings from '../pages/Setting/Operation/GroupRatioSettings.js';
import ModelRatioSettings from '../pages/Setting/Operation/ModelRatioSettings.js';
import { API, showError, showSuccess } from '../helpers';
import SettingsChats from '../pages/Setting/Operation/SettingsChats.js';
import { useTranslation } from 'react-i18next';
import RequestRateLimit from '../pages/Setting/RateLimit/SettingsRequestRateLimit.js';
const RateLimitSetting = () => {
const { t } = useTranslation();
let [inputs, setInputs] = useState({
ModelRequestRateLimitEnabled: false,
ModelRequestRateLimitCount: 0,
ModelRequestRateLimitSuccessCount: 1000,
ModelRequestRateLimitDurationMinutes: 1,
});
let [loading, setLoading] = useState(false);
const getOptions = async () => {
const res = await API.get('/api/option/');
const { success, message, data } = res.data;
if (success) {
let newInputs = {};
data.forEach((item) => {
if (
item.key.endsWith('Enabled')
) {
newInputs[item.key] = item.value === 'true' ? true : false;
} else {
newInputs[item.key] = item.value;
}
});
setInputs(newInputs);
} else {
showError(message);
}
};
async function onRefresh() {
try {
setLoading(true);
await getOptions();
// showSuccess('刷新成功');
} catch (error) {
showError('刷新失败');
} finally {
setLoading(false);
}
}
useEffect(() => {
onRefresh();
}, []);
return (
<>
<Spin spinning={loading} size='large'>
{/* AI请求速率限制 */}
<Card style={{ marginTop: '10px' }}>
<RequestRateLimit options={inputs} refresh={onRefresh} />
</Card>
</Spin>
</>
);
};
export default RateLimitSetting;
+6 -6
View File
@@ -80,7 +80,7 @@ const SiderBar = () => {
itemKey: 'channel',
to: '/channel',
icon: <IconLayers />,
className: isAdmin() ? 'semi-navigation-item-normal' : 'tableHiddle',
className: isAdmin() ? '' : 'tableHiddle',
},
{
text: t('聊天'),
@@ -101,7 +101,7 @@ const SiderBar = () => {
icon: <IconCalendarClock />,
className:
localStorage.getItem('enable_data_export') === 'true'
? 'semi-navigation-item-normal'
? ''
: 'tableHiddle',
},
{
@@ -109,7 +109,7 @@ const SiderBar = () => {
itemKey: 'redemption',
to: '/redemption',
icon: <IconGift />,
className: isAdmin() ? 'semi-navigation-item-normal' : 'tableHiddle',
className: isAdmin() ? '' : 'tableHiddle',
},
{
text: t('钱包'),
@@ -122,7 +122,7 @@ const SiderBar = () => {
itemKey: 'user',
to: '/user',
icon: <IconUser />,
className: isAdmin() ? 'semi-navigation-item-normal' : 'tableHiddle',
className: isAdmin() ? '' : 'tableHiddle',
},
{
text: t('日志'),
@@ -137,7 +137,7 @@ const SiderBar = () => {
icon: <IconImage />,
className:
localStorage.getItem('enable_drawing') === 'true'
? 'semi-navigation-item-normal'
? ''
: 'tableHiddle',
},
{
@@ -147,7 +147,7 @@ const SiderBar = () => {
icon: <IconChecklistStroked />,
className:
localStorage.getItem('enable_task') === 'true'
? 'semi-navigation-item-normal'
? ''
: 'tableHiddle',
},
{
+11
View File
@@ -368,6 +368,17 @@ const SystemSetting = () => {
</a>
</Header>
<Message info>
注意代理功能仅对图片请求和 Webhook 请求生效不会影响其他 API 请求如需配置 API 请求代理请参考
<a
href='https://github.com/Calcium-Ion/new-api/blob/main/docs/channel/other_setting.md'
target='_blank'
rel='noreferrer'
>
{' '}API 代理设置文档
</a>
</Message>
<Form.Group widths='equal'>
<Form.Input
label='Worker地址,不填写则不启用代理'
+69 -2
View File
@@ -1,6 +1,6 @@
import i18next from 'i18next';
import { Modal, Tag, Typography } from '@douyinfe/semi-ui';
import { copy, showSuccess } from './utils.js';
import { copy, isMobile, showSuccess } from './utils.js';
export function renderText(text, limit) {
if (text.length > limit) {
@@ -67,6 +67,73 @@ export function renderRatio(ratio) {
return <Tag color={color}>{ratio}x {i18next.t('倍率')}</Tag>;
}
const measureTextWidth = (text, style = {
fontSize: '14px',
fontFamily: '-apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif'
}, containerWidth) => {
const span = document.createElement('span');
span.style.visibility = 'hidden';
span.style.position = 'absolute';
span.style.whiteSpace = 'nowrap';
span.style.fontSize = style.fontSize;
span.style.fontFamily = style.fontFamily;
span.textContent = text;
document.body.appendChild(span);
const width = span.offsetWidth;
document.body.removeChild(span);
return width;
};
export function truncateText(text, maxWidth = 200) {
if (!isMobile()) {
return text;
}
if (!text) return text;
try {
// Handle percentage-based maxWidth
let actualMaxWidth = maxWidth;
if (typeof maxWidth === 'string' && maxWidth.endsWith('%')) {
const percentage = parseFloat(maxWidth) / 100;
// Use window width as fallback container width
actualMaxWidth = window.innerWidth * percentage;
}
const width = measureTextWidth(text);
if (width <= actualMaxWidth) return text;
let left = 0;
let right = text.length;
let result = text;
while (left <= right) {
const mid = Math.floor((left + right) / 2);
const truncated = text.slice(0, mid) + '...';
const currentWidth = measureTextWidth(truncated);
if (currentWidth <= actualMaxWidth) {
result = truncated;
left = mid + 1;
} else {
right = mid - 1;
}
}
return result;
} catch (error) {
console.warn('Text measurement failed, falling back to character count', error);
if (text.length > 20) {
return text.slice(0, 17) + '...';
}
return text;
}
}
export const renderGroupOption = (item) => {
const {
disabled,
@@ -386,7 +453,7 @@ export function renderQuotaWithPrompt(quota, digits) {
let displayInCurrency = localStorage.getItem('display_in_currency');
displayInCurrency = displayInCurrency === 'true';
if (displayInCurrency) {
return '|' + i18next.t('等价金额') + ': ' + renderQuota(quota, digits) + '';
return ' | ' + i18next.t('等价金额') + ': ' + renderQuota(quota, digits) + '';
}
return '';
}
+35 -3
View File
@@ -856,7 +856,7 @@
"IP黑名单": "IP blacklist",
"不允许的IP,一行一个": "IPs not allowed, one per line",
"请选择该渠道所支持的模型": "Please select the model supported by this channel",
"次": "Second-rate",
"次": "times",
"达到限速报错内容": "Error content when the speed limit is reached",
"不填则使用默认报错": "If not filled in, the default error will be reported.",
"Midjouney 设置 (可选)": "Midjouney settings (optional)",
@@ -1249,5 +1249,37 @@
"已注销": "Logged out",
"自动禁用关键词": "Automatic disable keywords",
"一行一个,不区分大小写": "One line per keyword, not case-sensitive",
"当上游通道返回错误中包含这些关键词时(不区分大小写),自动禁用通道": "When the upstream channel returns an error containing these keywords (not case-sensitive), automatically disable the channel"
}
"当上游通道返回错误中包含这些关键词时(不区分大小写),自动禁用通道": "When the upstream channel returns an error containing these keywords (not case-sensitive), automatically disable the channel",
"请求并计费模型": "Request and charge model",
"实际模型": "Actual model",
"渠道信息": "Channel information",
"通知设置": "Notification settings",
"Webhook地址": "Webhook URL",
"请输入Webhook地址,例如: https://example.com/webhook": "Please enter the Webhook URL, e.g.: https://example.com/webhook",
"邮件通知": "Email notification",
"Webhook通知": "Webhook notification",
"接口凭证(可选)": "Interface credentials (optional)",
"密钥将以 Bearer 方式添加到请求头中,用于验证webhook请求的合法性": "The secret will be added to the request header as a Bearer token to verify the legitimacy of the webhook request",
"Authorization: Bearer your-secret-key": "Authorization: Bearer your-secret-key",
"额度预警阈值": "Quota warning threshold",
"当剩余额度低于此数值时,系统将通过选择的方式发送通知": "When the remaining quota is lower than this value, the system will send a notification through the selected method",
"Webhook请求结构": "Webhook request structure",
"只支持https,系统将以 POST 方式发送通知,请确保地址可以接收 POST 请求": "Only https is supported, the system will send a notification through POST, please ensure the address can receive POST requests",
"保存设置": "Save settings",
"通知邮箱": "Notification email",
"设置用于接收额度预警的邮箱地址,不填则使用账号绑定的邮箱": "Set the email address for receiving quota warning notifications, if not set, the email address bound to the account will be used",
"留空则使用账号绑定的邮箱": "If left blank, the email address bound to the account will be used",
"代理站地址": "Base URL",
"对于官方渠道,new-api已经内置地址,除非是第三方代理站点或者Azure的特殊接入地址,否则不需要填写": "For official channels, the new-api has a built-in address. Unless it is a third-party proxy site or a special Azure access address, there is no need to fill it in",
"渠道额外设置": "Channel extra settings",
"模型请求速率限制": "Model request rate limit",
"启用用户模型请求速率限制(可能会影响高并发性能)": "Enable user model request rate limit (may affect high concurrency performance)",
"限制周期": "Limit period",
"用户每周期最多请求次数": "User max request times per period",
"用户每周期最多请求完成次数": "User max successful request times per period",
"包括失败请求的次数,0代表不限制": "Including failed request times, 0 means no limit",
"频率限制的周期(分钟)": "Rate limit period (minutes)",
"只包括请求成功的次数": "Only include successful request times",
"保存模型速率限制": "Save model rate limit settings",
"速率限制设置": "Rate limit settings"
}
+14 -12
View File
@@ -540,21 +540,23 @@ const EditChannel = (props) => {
value={inputs.name}
autoComplete="new-password"
/>
{inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && inputs.type !== 36 && (
{inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && inputs.type !== 36 && inputs.type !== 45 && (
<>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>{t('BaseURL')}</Typography.Text>
<Typography.Text strong>{t('代理站地址')}</Typography.Text>
</div>
<Input
label={t('BaseURL')}
name="base_url"
placeholder={t('此项可选,用于通过代理站来进行 API 调用,末尾不要带/v1和/')}
onChange={(value) => {
handleInputChange('base_url', value);
}}
value={inputs.base_url}
autoComplete="new-password"
/>
<Tooltip content={t('对于官方渠道,new-api已经内置地址,除非是第三方代理站点或者Azure的特殊接入地址,否则不需要填写')}>
<Input
label={t('代理站地址')}
name="base_url"
placeholder={t('此项可选,用于通过代理站来进行 API 调用,末尾不要带/v1和/')}
onChange={(value) => {
handleInputChange('base_url', value);
}}
value={inputs.base_url}
autoComplete="new-password"
/>
</Tooltip>
</>
)}
<div style={{ marginTop: 10 }}>
+4 -3
View File
@@ -7,7 +7,7 @@ import { SSE } from 'sse';
import { IconSetting } from '@douyinfe/semi-icons';
import { StyleContext } from '../../context/Style/index.js';
import { useTranslation } from 'react-i18next';
import { renderGroupOption } from '../../helpers/render.js';
import { renderGroupOption, truncateText } from '../../helpers/render.js';
const roleInfo = {
user: {
@@ -99,9 +99,10 @@ const Playground = () => {
const { success, message, data } = res.data;
if (success) {
let localGroupOptions = Object.entries(data).map(([group, info]) => ({
label: info.desc,
label: truncateText(info.desc, "50%"),
value: group,
ratio: info.ratio
ratio: info.ratio,
fullLabel: info.desc // 保存完整文本用于tooltip
}));
if (localGroupOptions.length === 0) {
@@ -0,0 +1,112 @@
import React, { useEffect, useState, useRef } from 'react';
import { Button, Col, Form, Row, Spin } from '@douyinfe/semi-ui';
import {
compareObjects,
API,
showError,
showSuccess,
showWarning, verifyJSON
} from '../../../helpers';
import { useTranslation } from 'react-i18next';
const GEMINI_SETTING_EXAMPLE = {
'default': 'OFF',
'HARM_CATEGORY_CIVIC_INTEGRITY': 'BLOCK_NONE',
};
export default function SettingGeminiModel(props) {
const { t } = useTranslation();
const [loading, setLoading] = useState(false);
const [inputs, setInputs] = useState({
GeminiSafetySettings: '',
});
const refForm = useRef();
const [inputsRow, setInputsRow] = useState(inputs);
function onSubmit() {
const updateArray = compareObjects(inputs, inputsRow);
if (!updateArray.length) return showWarning(t('你似乎并没有修改什么'));
const requestQueue = updateArray.map((item) => {
let value = '';
if (typeof inputs[item.key] === 'boolean') {
value = String(inputs[item.key]);
} else {
value = inputs[item.key];
}
return API.put('/api/option/', {
key: item.key,
value,
});
});
setLoading(true);
Promise.all(requestQueue)
.then((res) => {
if (requestQueue.length === 1) {
if (res.includes(undefined)) return;
} else if (requestQueue.length > 1) {
if (res.includes(undefined)) return showError(t('部分保存失败,请重试'));
}
showSuccess(t('保存成功'));
props.refresh();
})
.catch(() => {
showError(t('保存失败,请重试'));
})
.finally(() => {
setLoading(false);
});
}
useEffect(() => {
const currentInputs = {};
for (let key in props.options) {
if (Object.keys(inputs).includes(key)) {
currentInputs[key] = props.options[key];
}
}
setInputs(currentInputs);
setInputsRow(structuredClone(currentInputs));
refForm.current.setValues(currentInputs);
}, [props.options]);
return (
<>
<Spin spinning={loading}>
<Form
values={inputs}
getFormApi={(formAPI) => (refForm.current = formAPI)}
style={{ marginBottom: 15 }}
>
<Form.Section text={t('Gemini设置')}>
<Row>
<Col span={16}>
<Form.TextArea
label={t('Gemini安全设置')}
placeholder={t('为一个 JSON 文本,例如:') + '\n' + JSON.stringify(GEMINI_SETTING_EXAMPLE, null, 2)}
field={'GeminiSafetySettings'}
extraText={t('default为默认设置,可单独设置每个分类的安全等级')}
autosize={{ minRows: 6, maxRows: 12 }}
trigger='blur'
stopValidateWithError
rules={[
{
validator: (rule, value) => verifyJSON(value),
message: t('不是合法的 JSON 字符串')
}
]}
onChange={(value) => setInputs({ ...inputs, GeminiSafetySettings: value })}
/>
</Col>
</Row>
<Row>
<Button size='default' onClick={onSubmit}>
{t('保存')}
</Button>
</Row>
</Form.Section>
</Form>
</Spin>
</>
);
}
@@ -0,0 +1,159 @@
import React, { useEffect, useState, useRef } from 'react';
import { Button, Col, Form, Row, Spin } from '@douyinfe/semi-ui';
import {
compareObjects,
API,
showError,
showSuccess,
showWarning,
} from '../../../helpers';
import { useTranslation } from 'react-i18next';
export default function RequestRateLimit(props) {
const { t } = useTranslation();
const [loading, setLoading] = useState(false);
const [inputs, setInputs] = useState({
ModelRequestRateLimitEnabled: false,
ModelRequestRateLimitCount: -1,
ModelRequestRateLimitSuccessCount: 1000,
ModelRequestRateLimitDurationMinutes: 1
});
const refForm = useRef();
const [inputsRow, setInputsRow] = useState(inputs);
function onSubmit() {
const updateArray = compareObjects(inputs, inputsRow);
if (!updateArray.length) return showWarning(t('你似乎并没有修改什么'));
const requestQueue = updateArray.map((item) => {
let value = '';
if (typeof inputs[item.key] === 'boolean') {
value = String(inputs[item.key]);
} else {
value = inputs[item.key];
}
return API.put('/api/option/', {
key: item.key,
value,
});
});
setLoading(true);
Promise.all(requestQueue)
.then((res) => {
if (requestQueue.length === 1) {
if (res.includes(undefined)) return;
} else if (requestQueue.length > 1) {
if (res.includes(undefined)) return showError(t('部分保存失败,请重试'));
}
showSuccess(t('保存成功'));
props.refresh();
})
.catch(() => {
showError(t('保存失败,请重试'));
})
.finally(() => {
setLoading(false);
});
}
useEffect(() => {
const currentInputs = {};
for (let key in props.options) {
if (Object.keys(inputs).includes(key)) {
currentInputs[key] = props.options[key];
}
}
setInputs(currentInputs);
setInputsRow(structuredClone(currentInputs));
refForm.current.setValues(currentInputs);
}, [props.options]);
return (
<>
<Spin spinning={loading}>
<Form
values={inputs}
getFormApi={(formAPI) => (refForm.current = formAPI)}
style={{ marginBottom: 15 }}
>
<Form.Section text={t('模型请求速率限制')}>
<Row gutter={16}>
<Col span={8}>
<Form.Switch
field={'ModelRequestRateLimitEnabled'}
label={t('启用用户模型请求速率限制(可能会影响高并发性能)')}
size='default'
checkedText=''
uncheckedText=''
onChange={(value) => {
setInputs({
...inputs,
ModelRequestRateLimitEnabled: value,
});
}}
/>
</Col>
</Row>
<Row>
<Col span={8}>
<Form.InputNumber
label={t('限制周期')}
step={1}
min={0}
suffix={t('分钟')}
extraText={t('频率限制的周期(分钟)')}
field={'ModelRequestRateLimitDurationMinutes'}
onChange={(value) =>
setInputs({
...inputs,
ModelRequestRateLimitDurationMinutes: String(value),
})
}
/>
</Col>
</Row>
<Row>
<Col span={8}>
<Form.InputNumber
label={t('用户每周期最多请求次数')}
step={1}
min={0}
suffix={t('次')}
extraText={t('包括失败请求的次数,0代表不限制')}
field={'ModelRequestRateLimitCount'}
onChange={(value) =>
setInputs({
...inputs,
ModelRequestRateLimitCount: String(value),
})
}
/>
</Col>
<Col span={8}>
<Form.InputNumber
label={t('用户每周期最多请求完成次数')}
step={1}
min={1}
suffix={t('次')}
extraText={t('只包括请求成功的次数')}
field={'ModelRequestRateLimitSuccessCount'}
onChange={(value) =>
setInputs({
...inputs,
ModelRequestRateLimitSuccessCount: String(value),
})
}
/>
</Col>
</Row>
<Row>
<Button size='default' onClick={onSubmit}>
{t('保存模型速率限制')}
</Button>
</Row>
</Form.Section>
</Form>
</Spin>
</>
);
}
+12
View File
@@ -8,6 +8,8 @@ import { isRoot } from '../../helpers';
import OtherSetting from '../../components/OtherSetting';
import PersonalSetting from '../../components/PersonalSetting';
import OperationSetting from '../../components/OperationSetting';
import RateLimitSetting from '../../components/RateLimitSetting.js';
import ModelSetting from '../../components/ModelSetting.js';
const Setting = () => {
const { t } = useTranslation();
@@ -28,6 +30,16 @@ const Setting = () => {
content: <OperationSetting />,
itemKey: 'operation',
});
panes.push({
tab: t('速率限制设置'),
content: <RateLimitSetting />,
itemKey: 'ratelimit',
});
panes.push({
tab: t('模型相关设置'),
content: <ModelSetting />,
itemKey: 'models',
});
panes.push({
tab: t('系统设置'),
content: <SystemSetting />,