Compare commits

...

163 Commits

Author SHA1 Message Date
1808837298@qq.com 8101cd3ce3 feat: Add reasoning content support in OpenAI response handling 2025-02-21 18:52:51 +08:00
1808837298@qq.com 4a49d6c795 refactor: Improve message content parsing with robust type handling 2025-02-21 18:27:43 +08:00
1808837298@qq.com 4194f4bd21 refactor: Improve message content handling and quota error responses 2025-02-21 18:18:21 +08:00
1808837298@qq.com e1784f8981 refactor: Optimize sensitive word detection and text processing 2025-02-21 17:05:35 +08:00
1808837298@qq.com 78f9a30c39 feat: Enhance sensitive word detection with detailed logging 2025-02-21 16:57:30 +08:00
1808837298@qq.com 009333da8b refactor: Improve quota error messages with formatted quota display 2025-02-21 16:42:48 +08:00
1808837298@qq.com 23bfc06fd8 feat: Add base URL input with localized tooltip for channel configuration 2025-02-21 16:17:59 +08:00
1808837298@qq.com f64540cd1c feat: Add localization for notification and webhook settings 2025-02-21 15:36:24 +08:00
Calcium-Ion 5020091714 Merge pull request #775 from Calcium-Ion/model_mappping
refactor: Simplify model mapping and pricing logic across relay modules
2025-02-20 16:42:23 +08:00
1808837298@qq.com 1e0b414fc0 refactor: Simplify model mapping and pricing logic across relay modules 2025-02-20 16:41:46 +08:00
1808837298@qq.com 8dbac87e92 fix: Correct Ollama channel authentication header setting 2025-02-20 01:28:15 +08:00
Calcium-Ion db81b21f06 Merge pull request #773 from wellcoming/patch-1
fix: Fix Ollama channel authentication
2025-02-20 01:26:12 +08:00
Coming 368b5fbaa1 fix: Fix Ollama channel authentication 2025-02-20 00:52:30 +08:00
CalciumIon 1f6cb4bd5d feat: Improve mobile text truncation and sidebar visibility 2025-02-19 23:25:42 +08:00
1808837298@qq.com 5ef44f6898 feat: Improve image handling for Ollama channels 2025-02-19 20:45:42 +08:00
1808837298@qq.com 604a5b1ad5 feat: Enhance Ollama channel support with additional request parameters #771 2025-02-19 19:58:34 +08:00
1808837298@qq.com cc61f85bdd fix: Remove redundant error handling in distributor and relay modules 2025-02-19 18:47:28 +08:00
1808837298@qq.com 4ea2556b29 refactor: Replace manual goroutine creation with gopool.Go 2025-02-19 18:38:29 +08:00
Calcium-Ion ff66890dd2 Merge pull request #770 from Calcium-Ion/refactor_notify
feat: Add user notification settings and multiple notification methods
2025-02-19 14:54:54 +07:00
1808837298@qq.com 8a34d61654 chore: update env name and README 2025-02-19 15:54:33 +08:00
1808837298@qq.com 67f02e0a6a docs: Add proxy usage information note in SystemSetting component 2025-02-19 15:45:09 +08:00
1808837298@qq.com 097d02eb8b feat: Implement comprehensive webhook notification system 2025-02-19 15:40:54 +08:00
1808837298@qq.com 1ae0a38485 refactor: Optimize user caching and token retrieval methods 2025-02-19 15:12:26 +08:00
Calcium-Ion 45dd96e717 Merge pull request #768 from lgphone/main
bugfix: 配置文件 .env.example 示例配置错误
2025-02-18 19:35:08 +07:00
lgphone 10311f60e1 Update .env.example
修复示例配置中MySQL的DSN错误问题
2025-02-18 19:18:54 +08:00
Calcium-Ion 869fe957ae Merge pull request #763 from Sh1n3zZ/support-imagen-3.0-generate-002
feat: add Gemini Imagen image generation support
2025-02-18 15:32:32 +07:00
1808837298@qq.com af9d03140c fix: Extend temperature handling for OpenAI-like models
- Add support for suppressing temperature for o1 models
- Expand model prefix check to include 'o1' alongside 'o3' models
2025-02-18 16:00:56 +08:00
1808837298@qq.com 83e161a1d4 refactor: Simplify root user notification and remove global email variable
- Remove global `RootUserEmail` variable
- Modify channel testing and user notification methods to use `GetRootUser()`
- Update user cache and notification service to use more consistent user base type
- Add new channel test notification type
2025-02-18 15:59:17 +08:00
1808837298@qq.com 9a41c04416 feat: Implement notification rate limiting mechanism
- Add in-memory and Redis-based notification rate limiting
- Create configurable hourly notification limits
- Implement notification limit checking for user notifications
- Add environment variables for customizing notification limits
2025-02-18 15:30:43 +08:00
1808837298@qq.com e0ce8bf2b3 refactor: Improve CompletionRatio handling with thread-safe access and initialization 2025-02-18 15:01:43 +08:00
1808837298@qq.com 0fcd243f56 feat: Add user notification settings with quota warning and multiple notification methods
- Implement user notification settings with email and webhook options
- Add new user settings for quota warning threshold and notification preferences
- Create backend API and database support for user notification configuration
- Enhance frontend personal settings with notification configuration UI
- Support custom notification email and webhook URL
- Add service layer for sending user notifications
2025-02-18 14:54:21 +08:00
Sh1n3zZ 873a79f28f feat: add Gemini Imagen image generation support 2025-02-18 01:41:58 +08:00
1808837298@qq.com a353ea3eee Merge remote-tracking branch 'origin/main' 2025-02-17 18:15:13 +08:00
1808837298@qq.com 1cfe06c3d8 feat: Add support for DeepSeek completions endpoint 2025-02-17 18:15:01 +08:00
Calcium-Ion e661578515 Merge pull request #735 from jyc001/main
feat:Add Supoorts to FIM
2025-02-17 14:37:06 +07:00
1808837298@qq.com 65faa158b6 refactor: Optimize channel testing and model menu generation (fix #761) 2025-02-15 19:12:28 +08:00
1808837298@qq.com d43c5d6003 refactor: Improve channel property update mechanism (fix #761) 2025-02-15 15:30:55 +08:00
Calcium-Ion 27a1e49366 Merge pull request #759 from nightcoffee/patch-1
feat: add 火山引擎 support stream options
2025-02-15 14:22:04 +07:00
nightcoffee 3688d8f968 feat: add 火山引擎 support stream options 2025-02-15 04:55:57 +08:00
1808837298@qq.com 2869dbf60a feat: Enhance VolcEngine channel support with bot model routing (fix #757) 2025-02-15 00:10:58 +08:00
1808837298@qq.com 9e2433e2a6 fix: Improve OpenAI stream data parsing and handling 2025-02-14 23:52:25 +08:00
1808837298@qq.com 025b6d1226 feat: Add automatic channel disabling based on configurable keywords
- Introduce AutomaticDisableKeywords setting to dynamically control channel disabling
- Implement AC search for matching error messages against disable keywords
- Add frontend UI for configuring automatic disable keywords
- Update localization with new keyword-based channel disabling feature
- Refactor sensitive word and AC search logic to support multiple keyword lists
2025-02-13 16:39:17 +08:00
1808837298@qq.com 5665b1a58b refactor: Optimize log retrieval with separate channel name fetching (fix #751)
- Remove inline channel join in log queries
- Implement separate channel name lookup for logs
- Improve performance by fetching channel names in a single query
- Ensure channel names are correctly associated with logs
2025-02-12 19:19:13 +08:00
1808837298@qq.com ed888fb717 feat: Add invite link banner for specific channel type 2025-02-12 17:48:48 +08:00
1808837298@qq.com e84d602e17 refactor: Optimize Dockerfile for Go build process
- Use alpine-based Golang image for smaller build size
- Simplify Go build command by removing static linking flag
- Improve Docker multi-stage build configuration
2025-02-12 17:18:23 +08:00
1808837298@qq.com cc5e683784 docs: Update README with detailed Docker deployment and update instructions 2025-02-12 16:54:53 +08:00
1808837298@qq.com 6fe7b15cd0 fix: Update BaseURL placeholder text and label in channel edit page 2025-02-12 15:39:18 +08:00
1808837298@qq.com 61ad1adbda feat: Improve embedding request handling and support across channels
- Update EmbeddingRequest DTO to support more flexible input types
- Add input parsing method to handle various input formats
- Implement ConvertEmbeddingRequest for multiple channel adaptors
- Remove relayMode parameter from EmbeddingHelper
- Add input validation for embedding requests
- Simplify embedding request conversion for different channels
2025-02-12 14:39:36 +08:00
1808837298@qq.com babbbfb346 feat: Add Baidu Qianfan V2 channel support #725
- Update channel constants to include Baidu V2 channel
- Create new Baidu V2 adaptor for relay
- Add Baidu V2 models and channel configuration
- Update relay adaptor to support Baidu V2 channel
- Modify web channel constants to include Baidu V2 option
2025-02-12 00:07:02 +08:00
1808837298@qq.com 4d6bce1ddd feat: Add support for VolcEngine (Doubao) channel #313 #734 2025-02-11 23:47:15 +08:00
Calcium-Ion e514764816 Merge pull request #714 from NitroRCr/main
feat:  添加 AIaW 的聊天链接
2025-02-11 22:17:49 +07:00
Calcium-Ion b21ba28167 Merge pull request #723 from kuwork/main
Support for MokaAI M3E
2025-02-11 22:16:18 +07:00
1808837298@qq.com 86d93c4694 fix: adjust max tokens configuration in test request builder
- Update max tokens default value to 10
2025-02-11 20:00:05 +08:00
1808837298@qq.com 93511cfbf9 feat: enhance OpenAI request and response DTOs
- Add `Prefix` and `ReasoningContent` fields to Message struct
- Add getter and setter methods for `Prefix`
- Make `ToolCall.ID` field optional (fix #749)
2025-02-11 19:54:54 +08:00
1808837298@qq.com 305be19bb2 chore: disable cgo 2025-02-11 18:51:27 +08:00
1808837298@qq.com cb458caefd chore: disable cgo 2025-02-11 18:51:09 +08:00
1808837298@qq.com 0be6c5ee69 chore: replace sqlite lib with prue go lib 2025-02-11 18:34:34 +08:00
1808837298@qq.com 51b13ce038 chore: update CI 2025-02-11 18:23:20 +08:00
1808837298@qq.com 0486041952 update CI 2025-02-11 17:44:54 +08:00
1808837298@qq.com 8f3752c423 feat: enhance session store security and configuration
- Add 30-day max age for session cookies
- Enable HttpOnly flag
- Set SameSite to strict mode
2025-02-11 17:06:51 +08:00
1808837298@qq.com 94c10d8def fix: update session store configuration
- Change session cookie path from "/api" to "/"
- Remove HttpOnly flag
2025-02-11 15:53:15 +08:00
1808837298@qq.com 6a9c0fd2c0 feat: configure session store options for API routes
- Set session cookie path to "/api"
- Disable secure flag for local development
- Enable HttpOnly flag for improved security
2025-02-11 15:45:24 +08:00
Calcium-Ion 14399d6176 Merge pull request #746 from zjjxwhh/main
fix: always use modelMapping in channel test
2025-02-11 12:21:06 +07:00
1808837298@qq.com d4855da092 chore: update CI 2025-02-11 13:14:38 +08:00
zjjxwhh 5bebd3760a fix: always use modelMapping in channel test 2025-02-10 22:39:56 +08:00
1808837298@qq.com 147d0f211b chore: update CI 2025-02-10 21:59:41 +08:00
1808837298@qq.com 6d021b20b7 fix: replace context-based user ID with session-based retrieval #741
- Update user and wechat controllers to use sessions for user ID
- Modify ID retrieval to use `session.Get("id")` instead of `c.GetInt("id")`
- Cast session ID to int when creating user object
2025-02-10 20:52:33 +08:00
1808837298@qq.com e65238a327 fix: CI #744 2025-02-10 20:39:04 +08:00
1808837298@qq.com 451d23a129 Merge remote-tracking branch 'origin/main' 2025-02-10 20:34:11 +08:00
1808837298@qq.com bb1358ca47 refactor: improve SSE response handling in Playground
- Simplify event listener logic for streaming responses
- Add null-safe checks for payload content
- Optimize message generation and completion flow
2025-02-10 20:24:14 +08:00
Calcium-Ion 220946cfc3 Merge pull request #736 from xy3xy3/main
更正硅基流动的SenseVoiceSmall模型名字
2025-02-09 12:23:34 +07:00
Calcium-Ion b7a947ee78 Merge pull request #742 from HynoR/chore/ds
chore: 同步deepseek价格
2025-02-09 12:23:10 +07:00
HynoR efe01aee42 chore: 同步deepseek价格 2025-02-09 12:35:37 +08:00
xy3 30cbcac15a 更正硅基流动的SenseVoiceSmall模型名字 2025-02-08 11:54:08 +08:00
e. 2c255b6598 Merge pull request #2 from jyc001/dev
fix: correct JSON tags for `Prompt` and `Suffix` in `GeneralOpenAIReq…
2025-02-08 00:37:37 +08:00
e. bec1b752d6 fix: correct JSON tags for Prompt and Suffix in GeneralOpenAIRequest 2025-02-08 00:36:42 +08:00
e. ca10ec1a0c Merge pull request #1 from jyc001/dev
Dev
2025-02-08 00:25:49 +08:00
e. cbca378e61 feat: add Suffix to GeneralOpenAIRequest in order to support FIM 2025-02-08 00:25:08 +08:00
e. 02bb1c6580 feat add FIM support for siliconflow 2025-02-08 00:23:35 +08:00
1808837298@qq.com 276f9991c5 fix: channels model_mapping 2025-02-06 19:51:33 +08:00
1808837298@qq.com df848c3a1b fix: update logs table total count display
- Replace `logs.length` with `logCount` in pagination information
- Ensure accurate total log count is displayed in the logs table
2025-02-06 14:56:23 +08:00
Calcium-Ion 4eefb778cf Merge pull request #727 from HynoR/feat/autogemini
chore: 同步gemini模型
2025-02-06 13:43:13 +07:00
1808837298@qq.com c8a8526ff3 feat: modify channel model_mapping column type to TEXT
- Change `ModelMapping` column type from varchar(1024) to TEXT in channels table
- Add MySQL migration script to alter column type during database initialization
- Improve database schema flexibility for storing complex model mappings
2025-02-06 14:35:14 +08:00
HynoR f062c3596d chore: sync gemini aistudio model 2025-02-06 13:32:19 +08:00
kuwork fdebb6e6e8 Merge branch 'main' into main 2025-02-04 22:52:37 +08:00
1808837298@qq.com 906516fb90 feat: add SOCKS5 proxy authentication support
- Enhance `NewProxyHttpClient` to handle SOCKS5 proxy authentication
- Extract username and password from proxy URL for SOCKS5 proxy configuration
- Provide optional authentication for SOCKS5 proxy connections
2025-02-04 18:10:25 +08:00
1808837298@qq.com 881be9a3ec feat: add demo site configuration flag
- Introduce `DemoSiteEnabled` variable in operation settings
- Provide a configurable flag to enable/disable demo site functionality
2025-02-04 14:15:01 +08:00
1808837298@qq.com 80f60109cf feat: add Azure default API version configuration
- Introduce `AZURE_DEFAULT_API_VERSION` environment variable
- Set default Azure API version to `2024-12-01-preview`
- Update README documentation for new environment configuration
- Modify Azure channel relay to use default API version when not specified
2025-02-03 22:38:23 +08:00
1808837298@qq.com c2702a7125 feat: enhance model name handling and logging
- Add `RecodeModelName` to `RelayInfo` struct for more flexible model name tracking
- Update text relay and quota consumption to use `RecodeModelName`
- Move reasoning effort from admin info to other info in log generation
- Ensure consistent model name handling across relay components
2025-02-03 15:06:46 +08:00
1808837298@qq.com e641fb346e feat: add reasoning effort logging and display
- Add `ReasoningEffort` field to `RelayInfo` struct
- Update log generation to include reasoning effort in admin info
- Modify logs table component to display reasoning effort when available
- Preserve reasoning effort information during request processing
2025-02-03 14:44:40 +08:00
1808837298@qq.com 587ed2afae fix: improve reasoning effort model suffix handling
- Remove model name suffixes after extracting reasoning effort
- Update upstream model name to reflect the base model
- Ensure clean model name is passed to the upstream service
2025-02-03 14:34:00 +08:00
1808837298@qq.com b010700391 fix: update reasoning effort model suffix parsing
- Modify model suffix parsing to use hyphen-separated suffixes
- Ensure consistent parsing of `-high`, `-medium`, and `-low` reasoning effort indicators
2025-02-03 14:23:26 +08:00
1808837298@qq.com 0b2585ed52 feat: add reasoning effort configuration for models
- Support setting reasoning effort via model name suffix
- Add `-high`, `-medium`, and `-low` suffixes to control reasoning effort
- Update README with new model configuration option
- Modify OpenAI adaptor to handle reasoning effort settings
2025-02-03 14:22:34 +08:00
1808837298@qq.com 14f6ba91eb feat: add other_setting docs link 2025-02-02 22:18:37 +08:00
1808837298@qq.com 5c8af189eb feat: support channel request proxy 2025-02-02 22:15:06 +08:00
1808837298@qq.com c81e23b0e4 f*** o3-mini 2025-02-01 14:11:34 +08:00
1808837298@qq.com 709dd8fc40 Merge remote-tracking branch 'origin/main' 2025-02-01 13:41:38 +08:00
1808837298@qq.com 4131304c1f feat: add support for o3-mini models in model ratio and request handling 2025-02-01 13:41:25 +08:00
Calcium-Ion 4d862f10af Merge pull request #694 from yinuan-i/main
feat: 新增渠道管理与模型列表获取
2025-01-27 12:34:57 +07:00
1808837298@qq.com 4470ef97f2 fix: clear channel name in user logs 2025-01-27 13:31:24 +08:00
1808837298@qq.com b6bae9dc3a feat: enhance model ratio lookup with case-insensitive and direct matching 2025-01-26 16:07:41 +08:00
1808837298@qq.com 2993ce37a9 fix: update DeepSeek reasoner model ratio check 2025-01-25 23:09:14 +08:00
Calcium-Ion 56c39ef135 Merge pull request #715 from seefs001/main 2025-01-25 13:00:23 +07:00
Seefs ae94a74e87 Merge remote-tracking branch 'origin/main' 2025-01-25 12:59:28 +07:00
Seefs 68bdb7002d fix: display docker build error 2025-01-25 12:58:08 +07:00
Seefs 761e2fb7cc Merge branch 'Calcium-Ion:main' into main 2025-01-25 12:56:08 +07:00
Seefs aab25415ab fix: remove ffmpeg-tools 2025-01-25 12:55:40 +07:00
Calcium-Ion 01f4a5ff9b Merge pull request #713 from seefs001/main 2025-01-25 12:55:01 +07:00
NitroRCr 259e1ba3e2 feat: add chat link for AIaW 2025-01-25 11:57:54 +08:00
Seefs b96133eaf7 fix: log filename format 2025-01-24 21:09:54 +07:00
Jerry fcc32ffbc9 Fix M3E not working 2025-01-23 05:54:39 +08:00
1808837298@qq.com 7d402cff8c chore: update Node.js version in CI workflows from 16 to 18 2025-01-22 13:47:41 +08:00
Calcium-Ion f2e1447cc6 Merge pull request #710 from hubutui/main
Fix temperature not being set to 0 due to json omitempty
2025-01-22 12:44:48 +07:00
1808837298@qq.com a2867cceb8 chore: add ffmpeg-tools to Dockerfile for enhanced multimedia processing 2025-01-22 13:41:46 +08:00
1808837298@qq.com 1ae20bdd89 refactor: update log queries to explicitly reference 'logs' table for clarity and consistency 2025-01-22 13:37:32 +08:00
Jerry 215846bf73 fix : chanel test did not refresh 2025-01-22 13:16:06 +08:00
H. 675cb9b7f3 Merge branch 'Calcium-Ion:main' into main 2025-01-22 13:12:14 +08:00
Jerry 8e2059b898 Support for MokaAI M3E 2025-01-22 04:21:08 +08:00
1808837298@qq.com 71c75b4ea2 CI: update workflows 2025-01-21 16:24:07 +08:00
Butui Hu 6e710f3210 Fix temperature not being set to 0 due to json omitempty
The issue was caused by the `omitempty` tag in the Go struct, which prevented the `temperature` field from being included in the JSON output when it was set to 0.

Signed-off-by: Butui Hu <hot123tea123@gmail.com>
2025-01-21 12:54:09 +08:00
Calcium-Ion 7df2e3c80b Merge pull request #705 from maranello-o/main
fix: incorrect whisper audio usage
2025-01-21 11:21:04 +07:00
Calcium-Ion dd9ddbe75a Merge pull request #699 from detecti1/feat/show-log-with-channel-name
Feat: 日志查询增加渠道名称显示
2025-01-21 11:17:13 +07:00
Calcium-Ion 57d867f0b7 Merge pull request #709 from HynoR/feat/update-ratio
feat: 更新模型和模型倍率
2025-01-21 11:16:04 +07:00
HynoR 09bd33f0a5 feat: 更新模型和模型倍率 2025-01-21 00:53:10 +08:00
沈浩 dab21d5263 fix: incorrect whisper audio usage 2025-01-17 18:12:05 +08:00
H. 67da837763 Merge branch 'Calcium-Ion:main' into main 2025-01-13 13:42:30 +08:00
Lilo 37226d6589 Fix JSON parsing error when record.other is empty string 2025-01-09 17:07:28 +08:00
Lilo a335d9e1f4 Add channel name (tooltip / detail) to logs 2025-01-09 17:07:28 +08:00
1808837298@qq.com b604cab599 Update IP restriction messages for clarity in English localization and placeholder text in EditToken component. Enhanced user guidance by specifying that leaving the IP field blank means no restrictions. 2025-01-08 16:52:31 +08:00
mango 37aa5ec3b4 fix(batch add model list): fix the issue of fetching model list failure in batch add channel 2025-01-07 12:42:37 +08:00
mango 6fddeb6d2f feat(channel model list): modify fetching model list in add channel to fetch by type 2025-01-07 12:40:36 +08:00
mango d5e5b856dc feat(channel balance): add channel balance for siliconflow and deepseek 2025-01-07 12:15:55 +08:00
Calcium-Ion 6f47c4d47f Merge pull request #693 from Calcium-Ion/refactor-auth
refactor: access_token auth
2025-01-06 17:55:20 +08:00
1808837298@qq.com a3fe6c03b8 Adjust streaming timeout for OpenAI models in OaiStreamHandler
- Implemented conditional logic to double the streaming timeout for models starting with "o1" or "o3".
- Improved handling of streaming timeout configuration to enhance performance based on model type.
2025-01-06 17:52:33 +08:00
1808837298@qq.com e10531670e Update Dockerfile to use Bun for package management and build process
- Changed base image from Node.js to Bun for improved performance.
- Replaced npm install with bun install for dependency management.
- Updated build command to use Bun for building the application.
- Added new bun.lockb file to track Bun dependencies.
2025-01-06 16:37:21 +08:00
1808837298@qq.com aa2ac4766e Enhance user search functionality to support ID and keyword searches. Updated query conditions to allow searching by user ID alongside username, email, and display name. Improved handling of numeric and string keywords in search queries. 2025-01-06 15:20:38 +08:00
Calcium-Ion f599aede05 Merge pull request #692 from Calcium-Ion/fix-channel-model-length
Fix channel model length issue
2025-01-05 22:13:04 +08:00
1808837298@qq.com f11148bccf revert cache.go 2025-01-05 22:12:39 +08:00
1808837298@qq.com 530e846ac1 refactor: access_token auth 2025-01-05 22:08:23 +08:00
Calcium-Ion 7baa204e9c Fix model name length validation limit 2025-01-05 22:02:46 +08:00
Calcium-Ion 3dff0d498f 2025-01-05 22:01:36 +08:00
Calcium-Ion e275229597 Fix channel model length issue
Fixes #691

---

For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/Calcium-Ion/new-api/issues/691?shareId=XXXX-XXXX-XXXX-XXXX).
2025-01-05 21:55:25 +08:00
1808837298@qq.com cb5a13920f fix: update linux-release workflow to install gcc-aarch64-linux-gnu non-interactively 2025-01-05 17:59:29 +08:00
1808837298@qq.com 4aaa99a3d4 chore: change workflow runners to self-hosted for Docker and release jobs 2025-01-05 17:17:57 +08:00
1808837298@qq.com aea201ca77 fix: update iframe styling and permissions in ChatPage component 2025-01-04 22:24:47 +08:00
1808837298@qq.com e954913f19 refactor: realtime log render 2025-01-04 17:54:02 +08:00
1808837298@qq.com 44a097243e refactor: realtime i18n 2025-01-04 17:46:06 +08:00
1808837298@qq.com 1798595afc refactor: realtime quota 2025-01-04 15:46:35 +08:00
Calcium-Ion c073debcf5 Merge pull request #689 from iszcz/new512 2025-01-03 20:47:56 +08:00
iszcz 4fc2e49518 Update model-ratio.go 2025-01-03 20:42:46 +08:00
1808837298@qq.com d953b98417 feat: support gpt-4o-mini-realtime-preview 2025-01-03 18:51:09 +08:00
1808837298@qq.com 16afc5b3d6 fix: retry prompt tokens 2025-01-02 16:33:00 +08:00
Calcium-Ion a835920f83 Merge pull request #686 from delph1s/main
fix: try to fix pgsql #685
2025-01-02 00:17:02 +08:00
delph1s 7244adada5 fix: try to fix pgsql #685 2025-01-02 00:14:16 +08:00
Calcium-Ion eec3aa3bbd Update README.md 2024-12-31 22:19:37 +08:00
Calcium-Ion b664a2208b Merge pull request #683 from iszcz/new512
Update channel-test.go
2024-12-31 20:22:57 +08:00
CalciumIon d15074c860 fix: error page size opts 2024-12-31 15:51:15 +08:00
CalciumIon 97d0e6d2cd feat: implement pagination and total count for redemptions API #386
- Updated GetAllRedemptions and SearchRedemptions functions to return total count along with paginated results.
- Modified API endpoints to accept page size as a parameter, enhancing flexibility in data retrieval.
- Adjusted RedemptionsTable component to support pagination and display total count, improving user experience.
- Ensured consistent handling of pagination across related components, including LogsTable and UsersTable.
2024-12-31 15:28:25 +08:00
CalciumIon 7fec5fa1b3 feat: enhance user search functionality with pagination support
- Updated SearchUsers function to include pagination parameters (startIdx and num) for improved user search results.
- Modified API response structure to return paginated data, including total user count and current page information.
- Adjusted UsersTable component to handle pagination and search parameters, ensuring a seamless user experience.
- Added internationalization support for new search functionality in the UI.
2024-12-31 15:02:59 +08:00
CalciumIon 5a39d2e171 feat: enhance user management and pagination features #518
- Updated GetAllUsers function to return total user count along with paginated results, improving data handling in user retrieval.
- Modified GetAllUsers API endpoint to accept page size as a parameter, allowing for dynamic pagination.
- Enhanced UsersTable component to support customizable page sizes and improved pagination logic.
- Added error handling for empty username and password in AddUser component.
- Updated LogsTable component to display pagination information in a user-friendly format.
2024-12-31 14:52:55 +08:00
iszcz 60a3d1b066 Update channel-test.go 2024-12-31 12:49:13 +08:00
CalciumIon 0a845ae69f fix: try to fix pgsql #682 2024-12-31 02:10:19 +08:00
CalciumIon e0b7300239 fix: try to fix pgsql #682 2024-12-31 02:06:30 +08:00
149 changed files with 6609 additions and 5723 deletions
+2 -1
View File
@@ -3,4 +3,5 @@
*.md
.vscode
.gitignore
Makefile
Makefile
docs
+2 -2
View File
@@ -10,9 +10,9 @@
# 数据库相关配置
# 数据库连接字符串
# SQL_DSN=mysql://user:password@tcp(127.0.0.1:3306)/dbname?parseTime=true
# SQL_DSN=user:password@tcp(127.0.0.1:3306)/dbname?parseTime=true
# 日志数据库连接字符串
# LOG_SQL_DSN=mysql://user:password@tcp(127.0.0.1:3306)/logdb?parseTime=true
# LOG_SQL_DSN=user:password@tcp(127.0.0.1:3306)/logdb?parseTime=true
# SQLite数据库路径
# SQLITE_PATH=/path/to/sqlite.db
# 数据库最大空闲连接数
+2 -2
View File
@@ -17,7 +17,7 @@ jobs:
fetch-depth: 0
- uses: actions/setup-node@v3
with:
node-version: 16
node-version: 18
- name: Build Frontend
env:
CI: ""
@@ -38,7 +38,7 @@ jobs:
- name: Build Backend (arm64)
run: |
sudo apt-get update
sudo apt-get install gcc-aarch64-linux-gnu
DEBIAN_FRONTEND=noninteractive sudo apt-get install -y gcc-aarch64-linux-gnu
CC=aarch64-linux-gnu-gcc CGO_ENABLED=1 GOOS=linux GOARCH=arm64 go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api-arm64
- name: Release
+1 -1
View File
@@ -17,7 +17,7 @@ jobs:
fetch-depth: 0
- uses: actions/setup-node@v3
with:
node-version: 16
node-version: 18
- name: Build Frontend
env:
CI: ""
+1 -1
View File
@@ -20,7 +20,7 @@ jobs:
fetch-depth: 0
- uses: actions/setup-node@v3
with:
node-version: 16
node-version: 18
- name: Build Frontend
env:
CI: ""
+10 -8
View File
@@ -1,31 +1,33 @@
FROM node:16 as builder
FROM oven/bun:latest AS builder
WORKDIR /build
COPY web/package.json .
RUN npm install
RUN bun install
COPY ./web .
COPY ./VERSION .
RUN DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) npm run build
RUN DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) bun run build
FROM golang AS builder2
FROM golang:alpine AS builder2
ENV GO111MODULE=on \
CGO_ENABLED=1 \
CGO_ENABLED=0 \
GOOS=linux
WORKDIR /build
ADD go.mod go.sum ./
RUN go mod download
COPY . .
COPY --from=builder /build/dist ./web/dist
RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api
RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)'" -o one-api
FROM alpine
RUN apk update \
&& apk upgrade \
&& apk add --no-cache ca-certificates tzdata \
&& update-ca-certificates 2>/dev/null || true
&& apk add --no-cache ca-certificates tzdata ffmpeg \
&& update-ca-certificates
COPY --from=builder2 /build/one-api /
EXPOSE 3000
+41 -5
View File
@@ -59,6 +59,10 @@
13. 🎵 Added [Suno API](https://github.com/Suno-API/Suno-API) interface support, [Integration Guide](Suno.md)
14. 🔄 Support for Rerank models, compatible with Cohere and Jina, can integrate with Dify, [Integration Guide](Rerank.md)
15.**[OpenAI Realtime API](https://platform.openai.com/docs/guides/realtime/integration)** - Support for OpenAI's Realtime API, including Azure channels
16. 🧠 Support for setting reasoning effort through model name suffix:
- Add suffix `-high` to set high reasoning effort (e.g., `o3-mini-high`)
- Add suffix `-medium` to set medium reasoning effort
- Add suffix `-low` to set low reasoning effort
## Model Support
This version additionally supports:
@@ -84,15 +88,15 @@ You can add custom models gpt-4-gizmo-* in channels. These are third-party model
- `GEMINI_VISION_MAX_IMAGE_NUM`: Gemini model maximum image number, default `16`, set to `-1` to disable
- `MAX_FILE_DOWNLOAD_MB`: Maximum file download size in MB, default `20`
- `CRYPTO_SECRET`: Encryption key for encrypting database content
- `AZURE_DEFAULT_API_VERSION`: Azure channel default API version, if not specified in channel settings, use this version, default `2024-12-01-preview`
- `NOTIFICATION_LIMIT_DURATION_MINUTE`: Duration of notification limit in minutes, default `10`
- `NOTIFY_LIMIT_COUNT`: Maximum number of user notifications in the specified duration, default `2`
## Deployment
> [!TIP]
> Latest Docker image: `calciumion/new-api:latest`
> Default account: root, password: 123456
> Update command:
> ```
> docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR
> ```
> Default account: root, password: 123456
### Multi-Server Deployment
- Must set `SESSION_SECRET` environment variable, otherwise login state will not be consistent across multiple servers.
@@ -102,26 +106,58 @@ You can add custom models gpt-4-gizmo-* in channels. These are third-party model
- Local database (default): SQLite (Docker deployment must mount `/data` directory)
- Remote database: MySQL >= 5.7.8, PgSQL >= 9.6
### Deployment with BT Panel
Install BT Panel (**version 9.2.0** or above) from [BT Panel Official Website](https://www.bt.cn/new/download.html), choose the stable version script to download and install.
After installation, log in to BT Panel and click Docker in the menu bar. First-time access will prompt to install Docker service. Click Install Now and follow the prompts to complete installation.
After installation, find **New-API** in the app store, click install, configure basic options to complete installation.
[Pictorial Guide](BT.md)
### Docker Deployment
### Using Docker Compose (Recommended)
```shell
# Clone project
git clone https://github.com/Calcium-Ion/new-api.git
cd new-api
# Edit docker-compose.yml as needed
# nano docker-compose.yml
# vim docker-compose.yml
# Start
docker-compose up -d
```
#### Update Version
```shell
docker-compose pull
docker-compose up -d
```
### Direct Docker Image Usage
```shell
# SQLite deployment:
docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
# MySQL deployment (add -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi"), modify database connection parameters as needed
# Example:
docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
```
#### Update Version
```shell
# Pull the latest image
docker pull calciumion/new-api:latest
# Stop and remove the old container
docker stop new-api
docker rm new-api
# Run the new container with the same parameters as before
docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
```
Alternatively, you can use Watchtower for automatic updates (not recommended, may cause database incompatibility):
```shell
docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR
```
## Channel Retry
Channel retry is implemented, configurable in `Settings->Operation Settings->General Settings`. **Cache recommended**.
First retry uses same priority, second retry uses next priority, and so on.
+46 -13
View File
@@ -65,6 +65,10 @@
14. 🔄 支持Rerank模型,目前兼容Cohere和Jina,可接入Dify[对接文档](Rerank.md)
15.**[OpenAI Realtime API](https://platform.openai.com/docs/guides/realtime/integration)** - 支持OpenAI的Realtime API,支持Azure渠道
16. 支持使用路由/chat2link 进入聊天界面
17. 🧠 支持通过模型名称后缀设置 reasoning effort
- 添加后缀 `-high` 设置为 high reasoning effort (例如: `o3-mini-high`)
- 添加后缀 `-medium` 设置为 medium reasoning effort (例如: `o3-mini-medium`)
- 添加后缀 `-low` 设置为 low reasoning effort (例如: `o3-mini-low`)
## 模型支持
此版本额外支持以下模型:
@@ -85,19 +89,20 @@
- `GET_MEDIA_TOKEN`:是否统计图片token,默认为 `true`,关闭后将不再在本地计算图片token,可能会导致和上游计费不同,此项覆盖 `GET_MEDIA_TOKEN_NOT_STREAM` 选项作用。
- `GET_MEDIA_TOKEN_NOT_STREAM`:是否在非流(`stream=false`)情况下统计图片token,默认为 `true`
- `UPDATE_TASK`:是否更新异步任务(Midjourney、Suno),默认为 `true`,关闭后将不会更新任务进度。
- `GEMINI_MODEL_MAP`Gemini模型指定版本(v1/v1beta),使用模型:版本指定,","分隔,例如:-e GEMINI_MODEL_MAP="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta",为空则使用默认配置(v1beta)
- `COHERE_SAFETY_SETTING`Cohere模型[安全设置](https://docs.cohere.com/docs/safety-modes#overview),可选值为 `NONE`, `CONTEXTUAL``STRICT`,默认为 `NONE`
- `GEMINI_MODEL_MAP`Gemini模型指定版本(v1/v1beta),使用"模型:版本"指定,","分隔,例如:-e GEMINI_MODEL_MAP="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta",为空则使用默认配置(v1beta)
- `COHERE_SAFETY_SETTING`Cohere模型[安全设置](https://docs.cohere.com/docs/safety-modes#overview),可选值为 `NONE`, `CONTEXTUAL`, `STRICT`,默认为 `NONE`
- `GEMINI_VISION_MAX_IMAGE_NUM`:Gemini模型最大图片数量,默认为 `16`,设置为 `-1` 则不限制。
- `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小,单位 MB,默认为 `20`
- `CRYPTO_SECRET`:加密密钥,用于加密数据库内容。
- `AZURE_DEFAULT_API_VERSION`:Azure渠道默认API版本,如果渠道设置中未指定API版本,则使用此版本,默认为 `2024-12-01-preview`
- `NOTIFICATION_LIMIT_DURATION_MINUTE`:通知限制的持续时间(分钟),默认为 `10`
- `NOTIFY_LIMIT_COUNT`:用户通知在指定持续时间内的最大数量,默认为 `2`
## 部署
> [!TIP]
> 最新版Docker镜像:`calciumion/new-api:latest`
> 默认账号root 密码123456
> 更新指令:
> ```
> docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR
> ```
> 默认账号root 密码123456
### 多机部署
- 必须设置环境变量 `SESSION_SECRET`,否则会导致多机部署时登录状态不一致。
@@ -114,25 +119,54 @@
[图文教程](BT.md)
### 基于 Docker 进行部署
> [!TIP]
> 默认管理员账号root 密码123456
### 使用 Docker Compose 部署(推荐)
```shell
# 下载项目
git clone https://github.com/Calcium-Ion/new-api.git
cd new-api
# 按需编辑 docker-compose.yml
# nano docker-compose.yml
# vim docker-compose.yml
# 启动
docker-compose up -d
```
#### 更新版本
```shell
docker-compose pull
docker-compose up -d
```
### 直接使用 Docker 镜像
```shell
# 使用 SQLite 的部署命令:
docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
# 使用 MySQL 的部署命令,在上面的基础上添加 `-e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi"`,请自行修改数据库连接参数。
# 例如:
docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
```
#### 更新版本
```shell
# 拉取最新镜像
docker pull calciumion/new-api:latest
# 停止并删除旧容器
docker stop new-api
docker rm new-api
# 使用相同参数运行新容器
docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
```
或者使用 Watchtower 自动更新(不推荐,可能会导致数据库不兼容):
```shell
docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR
```
## 渠道重试
渠道重试功能已经实现,可以在`设置->运营设置->通用设置`设置重试次数,**建议开启缓存**功能。
如果开启了重试功能,第一次重试使用同优先级,第二次重试使用下一个优先级,以此类推。
@@ -159,15 +193,14 @@ docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:1234
[对接文档](Suno.md)
## 界面截图
![796df8d287b7b7bd7853b2497e7df511](https://github.com/user-attachments/assets/255b5e97-2d3a-4434-b4fa-e922ad88ff5a)
![image](https://github.com/user-attachments/assets/a0dcd349-5df8-4dc8-9acf-ca272b239919)
![image](https://github.com/user-attachments/assets/c7d0f7e1-729c-43e2-ac7c-2cb73b0afc8e)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/ad0e7aae-0203-471c-9716-2d83768927d4)
![image](https://github.com/user-attachments/assets/29f81de5-33fc-4fc5-a5ff-f9b54b653c7c)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/3ca0b282-00ff-4c96-bf9d-e29ef615c605)
夜间模式
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/1c66b593-bb9e-4757-9720-ff2759539242)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/af9a07ee-5101-4b3d-8bd9-ae21a4fd7e9e)
![image](https://github.com/user-attachments/assets/4fa53e18-d2c5-477a-9b26-b86e44c71e35)
## 交流群
<img src="https://github.com/user-attachments/assets/9ca0bc82-e057-4230-a28d-9f198fa022e3" width="200">
+8 -3
View File
@@ -101,7 +101,7 @@ var PreConsumedQuota = 500
var RetryTimes = 0
var RootUserEmail = ""
//var RootUserEmail = ""
var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
@@ -231,8 +231,10 @@ const (
ChannelTypeVertexAi = 41
ChannelTypeMistral = 42
ChannelTypeDeepSeek = 43
ChannelTypeDummy // this one is only for count, do not add any channel after this
ChannelTypeMokaAI = 44
ChannelTypeVolcEngine = 45
ChannelTypeBaiduV2 = 46
ChannelTypeDummy // this one is only for count, do not add any channel after this
)
@@ -281,4 +283,7 @@ var ChannelBaseURLs = []string{
"", //41
"https://api.mistral.ai", //42
"https://api.deepseek.com", //43
"https://api.moka.ai", //44
"https://ark.cn-beijing.volces.com", //45
"https://qianfan.baidubce.com", //46
}
+1
View File
@@ -3,5 +3,6 @@ package common
var UsingSQLite = false
var UsingPostgreSQL = false
var UsingMySQL = false
var UsingClickHouse = false
var SQLitePath = "one-api.db?_busy_timeout=5000"
-13
View File
@@ -1,22 +1,9 @@
package common
import (
"fmt"
"runtime/debug"
"time"
)
func SafeGoroutine(f func()) {
go func() {
defer func() {
if r := recover(); r != nil {
SysError(fmt.Sprintf("child goroutine panic occured: error: %v, stack: %s", r, string(debug.Stack())))
}
}()
f()
}()
}
func SafeSendBool(ch chan bool, value bool) (closed bool) {
defer func() {
// Recover from panic if one occured. A panic would mean the channel was closed.
+12 -3
View File
@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"io"
"log"
@@ -36,7 +37,7 @@ func SetupLogger() {
setupLogLock.Unlock()
setupLogWorking = false
}()
logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405")))
fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
log.Fatal("failed to open log file")
@@ -80,9 +81,9 @@ func logHelper(ctx context.Context, level string, msg string) {
if logCount > maxLogCount && !setupLogWorking {
logCount = 0
setupLogWorking = true
go func() {
gopool.Go(func() {
SetupLogger()
}()
})
}
}
@@ -100,6 +101,14 @@ func LogQuota(quota int) string {
}
}
func FormatQuota(quota int) string {
if DisplayInCurrencyEnabled {
return fmt.Sprintf("%.6f", float64(quota)/QuotaPerUnit)
} else {
return fmt.Sprintf("%d", quota)
}
}
// LogJson 仅供测试使用 only for test
func LogJson(ctx context.Context, msg string, obj any) {
jsonStr, err := json.Marshal(obj)
+86 -57
View File
@@ -32,30 +32,42 @@ var defaultModelRatio = map[string]float64{
"gpt-4-0613": 15,
"gpt-4-32k": 30,
//"gpt-4-32k-0314": 30, //deprecated
"gpt-4-32k-0613": 30,
"gpt-4-1106-preview": 5, // $10 / 1M tokens
"gpt-4-0125-preview": 5, // $10 / 1M tokens
"gpt-4-turbo-preview": 5, // $10 / 1M tokens
"gpt-4-vision-preview": 5, // $10 / 1M tokens
"gpt-4-1106-vision-preview": 5, // $10 / 1M tokens
"chatgpt-4o-latest": 2.5, // $5 / 1M tokens
"gpt-4o": 1.25, // $2.5 / 1M tokens
"gpt-4o-audio-preview": 1.25, // $2.5 / 1M tokens
"gpt-4o-audio-preview-2024-10-01": 1.25, // $2.5 / 1M tokens
"gpt-4o-2024-05-13": 2.5, // $5 / 1M tokens
"gpt-4o-2024-08-06": 1.25, // $2.5 / 1M tokens
"gpt-4o-2024-11-20": 1.25, // $2.5 / 1M tokens
"gpt-4o-realtime-preview": 2.5,
"o1": 7.5,
"o1-2024-12-17": 7.5,
"o1-preview": 7.5,
"o1-preview-2024-09-12": 7.5,
"o1-mini": 1.5,
"o1-mini-2024-09-12": 1.5,
"gpt-4o-mini": 0.075,
"gpt-4o-mini-2024-07-18": 0.075,
"gpt-4-turbo": 5, // $0.01 / 1K tokens
"gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens
"gpt-4-32k-0613": 30,
"gpt-4-1106-preview": 5, // $10 / 1M tokens
"gpt-4-0125-preview": 5, // $10 / 1M tokens
"gpt-4-turbo-preview": 5, // $10 / 1M tokens
"gpt-4-vision-preview": 5, // $10 / 1M tokens
"gpt-4-1106-vision-preview": 5, // $10 / 1M tokens
"chatgpt-4o-latest": 2.5, // $5 / 1M tokens
"gpt-4o": 1.25, // $2.5 / 1M tokens
"gpt-4o-audio-preview": 1.25, // $2.5 / 1M tokens
"gpt-4o-audio-preview-2024-10-01": 1.25, // $2.5 / 1M tokens
"gpt-4o-2024-05-13": 2.5, // $5 / 1M tokens
"gpt-4o-2024-08-06": 1.25, // $2.5 / 1M tokens
"gpt-4o-2024-11-20": 1.25, // $2.5 / 1M tokens
"gpt-4o-realtime-preview": 2.5,
"gpt-4o-realtime-preview-2024-10-01": 2.5,
"gpt-4o-realtime-preview-2024-12-17": 2.5,
"gpt-4o-mini-realtime-preview": 0.3,
"gpt-4o-mini-realtime-preview-2024-12-17": 0.3,
"o1": 7.5,
"o1-2024-12-17": 7.5,
"o1-preview": 7.5,
"o1-preview-2024-09-12": 7.5,
"o1-mini": 0.55,
"o1-mini-2024-09-12": 0.55,
"o3-mini": 0.55,
"o3-mini-2025-01-31": 0.55,
"o3-mini-high": 0.55,
"o3-mini-2025-01-31-high": 0.55,
"o3-mini-low": 0.55,
"o3-mini-2025-01-31-low": 0.55,
"o3-mini-medium": 0.55,
"o3-mini-2025-01-31-medium": 0.55,
"gpt-4o-mini": 0.075,
"gpt-4o-mini-2024-07-18": 0.075,
"gpt-4-turbo": 5, // $0.01 / 1K tokens
"gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens
//"gpt-3.5-turbo-0301": 0.75, //deprecated
"gpt-3.5-turbo": 0.25,
"gpt-3.5-turbo-0613": 0.75,
@@ -179,8 +191,9 @@ var defaultModelRatio = map[string]float64{
"command-r-plus": 1.5,
"command-r-08-2024": 0.075,
"command-r-plus-08-2024": 1.25,
"deepseek-chat": 0.07,
"deepseek-coder": 0.07,
"deepseek-chat": 0.27 / 2,
"deepseek-coder": 0.27 / 2,
"deepseek-reasoner": 0.55 / 2, // 0.55 / 1k tokens
// Perplexity online 模型对搜索额外收费,有需要应自行调整,此处不计入搜索费用
"llama-3-sonar-small-32k-chat": 0.2 / 1000 * USD,
"llama-3-sonar-small-32k-online": 0.2 / 1000 * USD,
@@ -220,7 +233,11 @@ var (
modelRatioMapMutex = sync.RWMutex{}
)
var CompletionRatio map[string]float64 = nil
var (
CompletionRatio map[string]float64 = nil
CompletionRatioMutex = sync.RWMutex{}
)
var defaultCompletionRatio = map[string]float64{
"gpt-4-gizmo-*": 2,
"gpt-4o-gizmo-*": 3,
@@ -321,10 +338,17 @@ func GetDefaultModelRatioMap() map[string]float64 {
return defaultModelRatio
}
func CompletionRatio2JSONString() string {
func GetCompletionRatioMap() map[string]float64 {
CompletionRatioMutex.Lock()
defer CompletionRatioMutex.Unlock()
if CompletionRatio == nil {
CompletionRatio = defaultCompletionRatio
}
return CompletionRatio
}
func CompletionRatio2JSONString() string {
GetCompletionRatioMap()
jsonBytes, err := json.Marshal(CompletionRatio)
if err != nil {
SysError("error marshalling completion ratio: " + err.Error())
@@ -333,11 +357,21 @@ func CompletionRatio2JSONString() string {
}
func UpdateCompletionRatioByJSONString(jsonStr string) error {
CompletionRatioMutex.Lock()
defer CompletionRatioMutex.Unlock()
CompletionRatio = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &CompletionRatio)
}
func GetCompletionRatio(name string) float64 {
GetCompletionRatioMap()
if strings.Contains(name, "/") {
if ratio, ok := CompletionRatio[name]; ok {
return ratio
}
}
lowercaseName := strings.ToLower(name)
if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*"
}
@@ -356,7 +390,7 @@ func GetCompletionRatio(name string) float64 {
}
return 2
}
if strings.HasPrefix(name, "o1") {
if strings.HasPrefix(name, "o1") || strings.HasPrefix(name, "o3") {
return 4
}
if name == "chatgpt-4o-latest" {
@@ -397,11 +431,12 @@ func GetCompletionRatio(name string) float64 {
case "command-r-plus-08-2024":
return 4
default:
return 2
return 4
}
}
if strings.HasPrefix(name, "deepseek") {
return 2
// hint 只给官方上4倍率,由于开源模型供应商自行定价,不对其进行补全倍率进行强制对齐
if lowercaseName == "deepseek-chat" || lowercaseName == "deepseek-reasoner" {
return 4
}
if strings.HasPrefix(name, "ERNIE-Speed-") {
return 2
@@ -427,10 +462,23 @@ func GetCompletionRatio(name string) float64 {
}
func GetAudioRatio(name string) float64 {
if strings.HasPrefix(name, "gpt-4o-realtime") {
return 20
} else if strings.HasPrefix(name, "gpt-4o-audio") {
return 40
if strings.Contains(name, "-realtime") {
if strings.HasSuffix(name, "gpt-4o-realtime-preview-2024-12-17") {
return 8
} else if strings.Contains(name, "mini") {
return 10 / 0.6
} else {
return 20
}
}
if strings.Contains(name, "-audio") {
if strings.HasSuffix(name, "gpt-4o-audio-preview-2024-12-17") {
return 16
} else if strings.Contains(name, "mini") {
return 10 / 0.15
} else {
return 40
}
}
return 20
}
@@ -438,27 +486,8 @@ func GetAudioRatio(name string) float64 {
func GetAudioCompletionRatio(name string) float64 {
if strings.HasPrefix(name, "gpt-4o-realtime") {
return 2
} else if strings.HasPrefix(name, "gpt-4o-mini-realtime") {
return 2
}
return 2
}
//func GetAudioPricePerMinute(name string) float64 {
// if strings.HasPrefix(name, "gpt-4o-realtime") {
// return 0.06
// }
// return 0.06
//}
//
//func GetAudioCompletionPricePerMinute(name string) float64 {
// if strings.HasPrefix(name, "gpt-4o-realtime") {
// return 0.24
// }
// return 0.24
//}
func GetCompletionRatioMap() map[string]float64 {
if CompletionRatio == nil {
CompletionRatio = defaultCompletionRatio
}
return CompletionRatio
}
+33
View File
@@ -1,14 +1,19 @@
package common
import (
"bytes"
"context"
crand "crypto/rand"
"encoding/base64"
"fmt"
"github.com/pkg/errors"
"html/template"
"io"
"log"
"math/big"
"math/rand"
"net"
"os"
"os/exec"
"runtime"
"strconv"
@@ -207,3 +212,31 @@ func RandomSleep() {
// Sleep for 0-3000 ms
time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond)
}
// SaveTmpFile saves data to a temporary file. The filename would be apppended with a random string.
func SaveTmpFile(filename string, data io.Reader) (string, error) {
f, err := os.CreateTemp(os.TempDir(), filename)
if err != nil {
return "", errors.Wrapf(err, "failed to create temporary file %s", filename)
}
defer f.Close()
_, err = io.Copy(f, data)
if err != nil {
return "", errors.Wrapf(err, "failed to copy data to temporary file %s", filename)
}
return f.Name(), nil
}
// GetAudioDuration returns the duration of an audio file in seconds.
func GetAudioDuration(ctx context.Context, filename string) (float64, error) {
// ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}}
c := exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename)
output, err := c.Output()
if err != nil {
return 0, errors.Wrap(err, "failed to get audio duration")
}
return strconv.ParseFloat(string(bytes.TrimSpace(output)), 64)
}
+2 -1
View File
@@ -1,5 +1,6 @@
package constant
var (
ForceFormat = "force_format" // ForceFormat 强制格式化为OpenAI格式
ForceFormat = "force_format" // ForceFormat 强制格式化为OpenAI格式
ChanelSettingProxy = "proxy" // Proxy 代理
)
+6 -1
View File
@@ -21,12 +21,17 @@ var GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STR
var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
var AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2024-12-01-preview")
var GeminiModelMap = map[string]string{
"gemini-1.0-pro": "v1",
}
var GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
var NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
var NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
func InitEnv() {
modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
if modelVersionMapStr == "" {
@@ -42,5 +47,5 @@ func InitEnv() {
}
}
// 是否生成初始令牌,默认关闭。
// GenerateDefaultToken 是否生成初始令牌,默认关闭。
var GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
+14
View File
@@ -0,0 +1,14 @@
package constant
var (
UserSettingNotifyType = "notify_type" // QuotaWarningType 额度预警类型
UserSettingQuotaWarningThreshold = "quota_warning_threshold" // QuotaWarningThreshold 额度预警阈值
UserSettingWebhookUrl = "webhook_url" // WebhookUrl webhook地址
UserSettingWebhookSecret = "webhook_secret" // WebhookSecret webhook密钥
UserSettingNotificationEmail = "notification_email" // NotificationEmail 通知邮箱地址
)
var (
NotifyTypeEmail = "email" // Email 邮件
NotifyTypeWebhook = "webhook" // Webhook
)
+88 -3
View File
@@ -78,6 +78,36 @@ type APGC2DGPTUsageResponse struct {
TotalUsed float64 `json:"total_used"`
}
type SiliconFlowUsageResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Status bool `json:"status"`
Data struct {
ID string `json:"id"`
Name string `json:"name"`
Image string `json:"image"`
Email string `json:"email"`
IsAdmin bool `json:"isAdmin"`
Balance string `json:"balance"`
Status string `json:"status"`
Introduction string `json:"introduction"`
Role string `json:"role"`
ChargeBalance string `json:"chargeBalance"`
TotalBalance string `json:"totalBalance"`
Category string `json:"category"`
} `json:"data"`
}
type DeepSeekUsageResponse struct {
IsAvailable bool `json:"is_available"`
BalanceInfos []struct {
Currency string `json:"currency"`
TotalBalance string `json:"total_balance"`
GrantedBalance string `json:"granted_balance"`
ToppedUpBalance string `json:"topped_up_balance"`
} `json:"balance_infos"`
}
// GetAuthHeader get auth header
func GetAuthHeader(token string) http.Header {
h := http.Header{}
@@ -185,6 +215,57 @@ func updateChannelAPI2GPTBalance(channel *model.Channel) (float64, error) {
return response.TotalRemaining, nil
}
func updateChannelSiliconFlowBalance(channel *model.Channel) (float64, error) {
url := "https://api.siliconflow.cn/v1/user/info"
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil {
return 0, err
}
response := SiliconFlowUsageResponse{}
err = json.Unmarshal(body, &response)
if err != nil {
return 0, err
}
if response.Code != 20000 {
return 0, fmt.Errorf("code: %d, message: %s", response.Code, response.Message)
}
balance, err := strconv.ParseFloat(response.Data.TotalBalance, 64)
if err != nil {
return 0, err
}
channel.UpdateBalance(balance)
return balance, nil
}
func updateChannelDeepSeekBalance(channel *model.Channel) (float64, error) {
url := "https://api.deepseek.com/user/balance"
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil {
return 0, err
}
response := DeepSeekUsageResponse{}
err = json.Unmarshal(body, &response)
if err != nil {
return 0, err
}
index := -1
for i, balanceInfo := range response.BalanceInfos {
if balanceInfo.Currency == "CNY" {
index = i
break
}
}
if index == -1 {
return 0, errors.New("currency CNY not found")
}
balance, err := strconv.ParseFloat(response.BalanceInfos[index].TotalBalance, 64)
if err != nil {
return 0, err
}
channel.UpdateBalance(balance)
return balance, nil
}
func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
url := "https://api.aigc2d.com/dashboard/billing/credit_grants"
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
@@ -222,6 +303,10 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
return updateChannelAPI2GPTBalance(channel)
case common.ChannelTypeAIGC2D:
return updateChannelAIGC2DBalance(channel)
case common.ChannelTypeSiliconFlow:
return updateChannelSiliconFlowBalance(channel)
case common.ChannelTypeDeepSeek:
return updateChannelDeepSeekBalance(channel)
default:
return 0, errors.New("尚未实现")
}
@@ -300,9 +385,9 @@ func updateAllChannelsBalance() error {
continue
}
// TODO: support Azure
if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom {
continue
}
//if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom {
// continue
//}
balance, err := updateChannelBalance(channel)
if err != nil {
continue
+46 -26
View File
@@ -5,7 +5,6 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"io"
"math"
"net/http"
@@ -24,6 +23,8 @@ import (
"sync"
"time"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
)
@@ -32,14 +33,29 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
if channel.Type == common.ChannelTypeMidjourney {
return errors.New("midjourney channel test is not supported"), nil
}
if channel.Type == common.ChannelTypeMidjourneyPlus {
return errors.New("midjourney plus channel test is not supported!!!"), nil
}
if channel.Type == common.ChannelTypeSunoAPI {
return errors.New("suno channel test is not supported"), nil
}
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
requestPath := "/v1/chat/completions"
// 先判断是否为 Embedding 模型
if strings.Contains(strings.ToLower(testModel), "embedding") ||
strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
strings.Contains(testModel, "bge-") || // bge 系列模型
testModel == "text-embedding-v1" ||
channel.Type == common.ChannelTypeMokaAI { // 其他 embedding 模型
requestPath = "/v1/embeddings" // 修改请求路径
}
c.Request = &http.Request{
Method: "POST",
URL: &url.URL{Path: "/v1/chat/completions"},
URL: &url.URL{Path: requestPath}, // 使用动态路径
Body: nil,
Header: make(http.Header),
}
@@ -51,20 +67,20 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
if len(channel.GetModels()) > 0 {
testModel = channel.GetModels()[0]
} else {
testModel = "gpt-3.5-turbo"
testModel = "gpt-4o-mini"
}
}
} else {
modelMapping := *channel.ModelMapping
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return err, service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[testModel] != "" {
testModel = modelMap[testModel]
}
}
modelMapping := *channel.ModelMapping
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return err, service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[testModel] != "" {
testModel = modelMap[testModel]
}
}
@@ -84,7 +100,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
request := buildTestRequest(testModel)
meta.UpstreamModelName = testModel
common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
common.SysLog(fmt.Sprintf("testing channel %d with model %s , meta %v ", channel.Id, testModel, meta))
adaptor.Init(meta)
@@ -152,12 +168,21 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
Model: "", // this will be set later
Stream: false,
}
if strings.HasPrefix(model, "o1") {
// 先判断是否为 Embedding 模型
if strings.Contains(strings.ToLower(model), "embedding") ||
strings.HasPrefix(model, "m3e") || // m3e 系列模型
strings.Contains(model, "bge-") || // bge 系列模型
model == "text-embedding-v1" { // 其他 embedding 模型
// Embedding 请求
testRequest.Input = []string{"hello world"}
return testRequest
}
// 并非Embedding 模型
if strings.HasPrefix(model, "o1") || strings.HasPrefix(model, "o3") {
testRequest.MaxCompletionTokens = 10
} else if strings.HasPrefix(model, "gemini-2.0-flash-thinking") {
testRequest.MaxTokens = 2
} else {
testRequest.MaxTokens = 1
testRequest.MaxTokens = 10
}
content, _ := json.Marshal("hi")
testMessage := dto.Message{
@@ -213,9 +238,7 @@ var testAllChannelsLock sync.Mutex
var testAllChannelsRunning bool = false
func testAllChannels(notify bool) error {
if common.RootUserEmail == "" {
common.RootUserEmail = model.GetRootUserEmail()
}
testAllChannelsLock.Lock()
if testAllChannelsRunning {
testAllChannelsLock.Unlock()
@@ -270,10 +293,7 @@ func testAllChannels(notify bool) error {
testAllChannelsRunning = false
testAllChannelsLock.Unlock()
if notify {
err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
if err != nil {
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
}
service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试完成")
}
})
return nil
+18 -2
View File
@@ -274,6 +274,17 @@ func AddChannel(c *gin.Context) {
}
localChannel := channel
localChannel.Key = key
// Validate the length of the model name
models := strings.Split(localChannel.Models, ",")
for _, model := range models {
if len(model) > 255 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": fmt.Sprintf("模型名称过长: %s", model),
})
return
}
}
channels = append(channels, localChannel)
}
err = model.BatchInsertChannels(channels)
@@ -499,6 +510,7 @@ func UpdateChannel(c *gin.Context) {
func FetchModels(c *gin.Context) {
var req struct {
BaseURL string `json:"base_url"`
Type int `json:"type"`
Key string `json:"key"`
}
@@ -512,7 +524,7 @@ func FetchModels(c *gin.Context) {
baseURL := req.BaseURL
if baseURL == "" {
baseURL = "https://api.openai.com"
baseURL = common.ChannelBaseURLs[req.Type]
}
client := &http.Client{}
@@ -527,7 +539,11 @@ func FetchModels(c *gin.Context) {
return
}
request.Header.Set("Authorization", "Bearer "+req.Key)
// remove line breaks and extra spaces.
key := strings.TrimSpace(req.Key)
// If the key contains a line break, only take the first part.
key = strings.Split(key, "\n")[0]
request.Header.Set("Authorization", "Bearer "+key)
response, err := client.Do(request)
if err != nil {
+1
View File
@@ -66,6 +66,7 @@ func GetStatus(c *gin.Context) {
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
"mj_notify_enabled": setting.MjNotifyEnabled,
"chats": setting.Chats,
"demo_site_enabled": setting.DemoSiteEnabled,
},
})
return
+1 -1
View File
@@ -17,7 +17,7 @@ func GetPricing(c *gin.Context) {
}
var group string
if exists {
user, err := model.GetUserById(userId.(int), false)
user, err := model.GetUserCache(userId.(int))
if err == nil {
group = user.Group
}
+28 -5
View File
@@ -1,19 +1,24 @@
package controller
import (
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"github.com/gin-gonic/gin"
)
func GetAllRedemptions(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
pageSize, _ := strconv.Atoi(c.Query("page_size"))
if p < 0 {
p = 0
}
redemptions, err := model.GetAllRedemptions(p*common.ItemsPerPage, common.ItemsPerPage)
if pageSize < 1 {
pageSize = common.ItemsPerPage
}
redemptions, total, err := model.GetAllRedemptions((p-1)*pageSize, pageSize)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -24,14 +29,27 @@ func GetAllRedemptions(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": redemptions,
"data": gin.H{
"items": redemptions,
"total": total,
"page": p,
"page_size": pageSize,
},
})
return
}
func SearchRedemptions(c *gin.Context) {
keyword := c.Query("keyword")
redemptions, err := model.SearchRedemptions(keyword)
p, _ := strconv.Atoi(c.Query("p"))
pageSize, _ := strconv.Atoi(c.Query("page_size"))
if p < 0 {
p = 0
}
if pageSize < 1 {
pageSize = common.ItemsPerPage
}
redemptions, total, err := model.SearchRedemptions(keyword, (p-1)*pageSize, pageSize)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -42,7 +60,12 @@ func SearchRedemptions(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": redemptions,
"data": gin.H{
"items": redemptions,
"total": total,
"page": p,
"page_size": pageSize,
},
})
return
}
+3 -1
View File
@@ -24,7 +24,7 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
var err *dto.OpenAIErrorWithStatusCode
switch relayMode {
case relayconstant.RelayModeImagesGenerations:
err = relay.ImageHelper(c, relayMode)
err = relay.ImageHelper(c)
case relayconstant.RelayModeAudioSpeech:
fallthrough
case relayconstant.RelayModeAudioTranslation:
@@ -33,6 +33,8 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
err = relay.AudioHelper(c)
case relayconstant.RelayModeRerank:
err = relay.RerankHelper(c, relayMode)
case relayconstant.RelayModeEmbeddings:
err = relay.EmbeddingHelper(c)
default:
err = relay.TextHelper(c)
}
+148 -13
View File
@@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"one-api/common"
"one-api/model"
"one-api/setting"
@@ -11,9 +12,10 @@ import (
"strings"
"sync"
"one-api/constant"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"one-api/constant"
)
type LoginRequest struct {
@@ -242,10 +244,14 @@ func Register(c *gin.Context) {
func GetAllUsers(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
if p < 0 {
p = 0
pageSize, _ := strconv.Atoi(c.Query("page_size"))
if p < 1 {
p = 1
}
users, err := model.GetAllUsers(p*common.ItemsPerPage, common.ItemsPerPage)
if pageSize < 0 {
pageSize = common.ItemsPerPage
}
users, total, err := model.GetAllUsers((p-1)*pageSize, pageSize)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -256,7 +262,12 @@ func GetAllUsers(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": users,
"data": gin.H{
"items": users,
"total": total,
"page": p,
"page_size": pageSize,
},
})
return
}
@@ -264,7 +275,16 @@ func GetAllUsers(c *gin.Context) {
func SearchUsers(c *gin.Context) {
keyword := c.Query("keyword")
group := c.Query("group")
users, err := model.SearchUsers(keyword, group)
p, _ := strconv.Atoi(c.Query("p"))
pageSize, _ := strconv.Atoi(c.Query("page_size"))
if p < 1 {
p = 1
}
if pageSize < 0 {
pageSize = common.ItemsPerPage
}
startIdx := (p - 1) * pageSize
users, total, err := model.SearchUsers(keyword, group, startIdx, pageSize)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -275,7 +295,12 @@ func SearchUsers(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": users,
"data": gin.H{
"items": users,
"total": total,
"page": p,
"page_size": pageSize,
},
})
return
}
@@ -447,7 +472,7 @@ func GetUserModels(c *gin.Context) {
if err != nil {
id = c.GetInt("id")
}
user, err := model.GetUserById(id, true)
user, err := model.GetUserCache(id)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -822,9 +847,10 @@ func EmailBind(c *gin.Context) {
})
return
}
id := c.GetInt("id")
session := sessions.Default(c)
id := session.Get("id")
user := model.User{
Id: id,
Id: id.(int),
}
err := user.FillUserById()
if err != nil {
@@ -844,9 +870,6 @@ func EmailBind(c *gin.Context) {
})
return
}
if user.Role == common.RoleRootUser {
common.RootUserEmail = email
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
@@ -888,3 +911,115 @@ func TopUp(c *gin.Context) {
})
return
}
type UpdateUserSettingRequest struct {
QuotaWarningType string `json:"notify_type"`
QuotaWarningThreshold int `json:"quota_warning_threshold"`
WebhookUrl string `json:"webhook_url,omitempty"`
WebhookSecret string `json:"webhook_secret,omitempty"`
NotificationEmail string `json:"notification_email,omitempty"`
}
func UpdateUserSetting(c *gin.Context) {
var req UpdateUserSettingRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的参数",
})
return
}
// 验证预警类型
if req.QuotaWarningType != constant.NotifyTypeEmail && req.QuotaWarningType != constant.NotifyTypeWebhook {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的预警类型",
})
return
}
// 验证预警阈值
if req.QuotaWarningThreshold <= 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "预警阈值必须大于0",
})
return
}
// 如果是webhook类型,验证webhook地址
if req.QuotaWarningType == constant.NotifyTypeWebhook {
if req.WebhookUrl == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "Webhook地址不能为空",
})
return
}
// 验证URL格式
if _, err := url.ParseRequestURI(req.WebhookUrl); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的Webhook地址",
})
return
}
}
// 如果是邮件类型,验证邮箱地址
if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" {
// 验证邮箱格式
if !strings.Contains(req.NotificationEmail, "@") {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的邮箱地址",
})
return
}
}
userId := c.GetInt("id")
user, err := model.GetUserById(userId, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
// 构建设置
settings := map[string]interface{}{
constant.UserSettingNotifyType: req.QuotaWarningType,
constant.UserSettingQuotaWarningThreshold: req.QuotaWarningThreshold,
}
// 如果是webhook类型,添加webhook相关设置
if req.QuotaWarningType == constant.NotifyTypeWebhook {
settings[constant.UserSettingWebhookUrl] = req.WebhookUrl
if req.WebhookSecret != "" {
settings[constant.UserSettingWebhookSecret] = req.WebhookSecret
}
}
// 如果提供了通知邮箱,添加到设置中
if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" {
settings[constant.UserSettingNotificationEmail] = req.NotificationEmail
}
// 更新用户设置
user.SetSetting(settings)
if err := user.Update(false); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "更新设置失败: " + err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "设置已更新",
})
}
+4 -2
View File
@@ -4,6 +4,7 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
@@ -142,9 +143,10 @@ func WeChatBind(c *gin.Context) {
})
return
}
id := c.GetInt("id")
session := sessions.Default(c)
id := session.Get("id")
user := model.User{
Id: id,
Id: id.(int),
}
err = user.FillUserById()
if err != nil {
+1 -1
View File
@@ -24,7 +24,7 @@ services:
- redis
- mysql
healthcheck:
test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ]
test: ["CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $$2}'"]
interval: 30s
timeout: 10s
retries: 3
+53
View File
@@ -0,0 +1,53 @@
# API 鉴权文档
## 认证方式
### Access Token
对于需要鉴权的 API 接口,必须同时提供以下两个请求头来进行 Access Token 认证:
1. **请求头中的 `Authorization` 字段**
将 Access Token 放置于 HTTP 请求头部的 `Authorization` 字段中,格式如下:
```
Authorization: <your_access_token>
```
其中 `<your_access_token>` 需要替换为实际的 Access Token 值。
2. **请求头中的 `New-Api-User` 字段**
将用户 ID 放置于 HTTP 请求头部的 `New-Api-User` 字段中,格式如下:
```
New-Api-User: <your_user_id>
```
其中 `<your_user_id>` 需要替换为实际的用户 ID。
**注意:**
* **必须同时提供 `Authorization` 和 `New-Api-User` 两个请求头才能通过鉴权。**
* 如果只提供其中一个请求头,或者两个请求头都未提供,则会返回 `401 Unauthorized` 错误。
* 如果 `Authorization` 中的 Access Token 无效,则会返回 `401 Unauthorized` 错误,并提示“无权进行此操作,access token 无效”。
* 如果 `New-Api-User` 中的用户 ID 与 Access Token 不匹配,则会返回 `401 Unauthorized` 错误,并提示“无权进行此操作,与登录用户不匹配,请重新登录”。
* 如果没有提供 `New-Api-User` 请求头,则会返回 `401 Unauthorized` 错误,并提示“无权进行此操作,未提供 New-Api-User”。
* 如果 `New-Api-User` 请求头格式错误,则会返回 `401 Unauthorized` 错误,并提示“无权进行此操作,New-Api-User 格式错误”。
* 如果用户已被禁用,则会返回 `403 Forbidden` 错误,并提示“用户已被封禁”。
* 如果用户权限不足,则会返回 `403 Forbidden` 错误,并提示“无权进行此操作,权限不足”。
* 如果用户信息无效,则会返回 `403 Forbidden` 错误,并提示“无权进行此操作,用户信息无效”。
## Curl 示例
假设您的 Access Token 为 `access_token`,用户 ID 为 `123`,要访问的 API 接口为 `/api/user/self`,则可以使用以下 curl 命令:
```bash
curl -X GET \
-H "Authorization: access_token" \
-H "New-Api-User: 123" \
https://your-domain.com/api/user/self
```
请将 `access_token`、`123` 和 `https://your-domain.com` 替换为实际的值。
View File
+28
View File
@@ -0,0 +1,28 @@
# 渠道而外设置说明
该配置用于设置一些额外的渠道参数,可以通过 JSON 对象进行配置。主要包含以下两个设置项:
1. force_format
- 用于标识是否对数据进行强制格式化为 OpenAI 格式
- 类型为布尔值,设置为 true 时启用强制格式化
2. proxy
- 用于配置网络代理
- 类型为字符串,填写代理地址(例如 socks5 协议的代理地址)
--------------------------------------------------------------
## JSON 格式示例
以下是一个示例配置,启用强制格式化并设置了代理地址:
```json
{
"force_format": true,
"proxy": "socks5://xxxxxxx"
}
```
--------------------------------------------------------------
通过调整上述 JSON 配置中的值,可以灵活控制渠道的额外行为,比如是否进行格式化以及使用特定的网络代理。
+57
View File
@@ -0,0 +1,57 @@
package dto
type EmbeddingOptions struct {
Seed int `json:"seed,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopK int `json:"top_k,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
NumPredict int `json:"num_predict,omitempty"`
NumCtx int `json:"num_ctx,omitempty"`
}
type EmbeddingRequest struct {
Model string `json:"model"`
Input any `json:"input"`
EncodingFormat string `json:"encoding_format,omitempty"`
Dimensions int `json:"dimensions,omitempty"`
User string `json:"user,omitempty"`
Seed float64 `json:"seed,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
}
func (r EmbeddingRequest) ParseInput() []string {
if r.Input == nil {
return nil
}
var input []string
switch r.Input.(type) {
case string:
input = []string{r.Input.(string)}
case []any:
input = make([]string, 0, len(r.Input.([]any)))
for _, item := range r.Input.([]any) {
if str, ok := item.(string); ok {
input = append(input, str)
}
}
}
return input
}
type EmbeddingResponseItem struct {
Object string `json:"object"`
Index int `json:"index"`
Embedding []float64 `json:"embedding"`
}
type EmbeddingResponse struct {
Object string `json:"object"`
Data []EmbeddingResponseItem `json:"data"`
Model string `json:"model"`
Usage `json:"usage"`
}
+25
View File
@@ -0,0 +1,25 @@
package dto
type Notify struct {
Type string `json:"type"`
Title string `json:"title"`
Content string `json:"content"`
Values []interface{} `json:"values"`
}
const ContentValueParam = "{{value}}"
const (
NotifyTypeQuotaExceed = "quota_exceed"
NotifyTypeChannelUpdate = "channel_update"
NotifyTypeChannelTest = "channel_test"
)
func NewNotify(t string, title string, content string, values []interface{}) Notify {
return Notify{
Type: t,
Title: title,
Content: content,
Values: values,
}
}
+95 -40
View File
@@ -18,12 +18,14 @@ type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Prefix any `json:"prefix,omitempty"`
Suffix any `json:"suffix,omitempty"`
Stream bool `json:"stream,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
ReasoningEffort string `json:"reasoning_effort,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
@@ -86,11 +88,15 @@ func (r GeneralOpenAIRequest) ParseInput() []string {
}
type Message struct {
Role string `json:"role"`
Content json.RawMessage `json:"content"`
Name *string `json:"name,omitempty"`
ToolCalls json.RawMessage `json:"tool_calls,omitempty"`
ToolCallId string `json:"tool_call_id,omitempty"`
Role string `json:"role"`
Content json.RawMessage `json:"content"`
Name *string `json:"name,omitempty"`
Prefix *bool `json:"prefix,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
ToolCalls json.RawMessage `json:"tool_calls,omitempty"`
ToolCallId string `json:"tool_call_id,omitempty"`
parsedContent []MediaContent
parsedStringContent *string
}
type MediaContent struct {
@@ -116,6 +122,17 @@ const (
ContentTypeInputAudio = "input_audio"
)
func (m *Message) GetPrefix() bool {
if m.Prefix == nil {
return false
}
return *m.Prefix
}
func (m *Message) SetPrefix(prefix bool) {
m.Prefix = &prefix
}
func (m *Message) ParseToolCalls() []ToolCall {
if m.ToolCalls == nil {
return nil
@@ -133,6 +150,9 @@ func (m *Message) SetToolCalls(toolCalls any) {
}
func (m *Message) StringContent() string {
if m.parsedStringContent != nil {
return *m.parsedStringContent
}
var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
return stringContent
@@ -143,78 +163,113 @@ func (m *Message) StringContent() string {
func (m *Message) SetStringContent(content string) {
jsonContent, _ := json.Marshal(content)
m.Content = jsonContent
m.parsedStringContent = &content
m.parsedContent = nil
}
func (m *Message) SetMediaContent(content []MediaContent) {
jsonContent, _ := json.Marshal(content)
m.Content = jsonContent
m.parsedContent = nil
m.parsedStringContent = nil
}
func (m *Message) IsStringContent() bool {
if m.parsedStringContent != nil {
return true
}
var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
m.parsedStringContent = &stringContent
return true
}
return false
}
func (m *Message) ParseContent() []MediaContent {
if m.parsedContent != nil {
return m.parsedContent
}
var contentList []MediaContent
// 先尝试解析为字符串
var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
contentList = append(contentList, MediaContent{
contentList = []MediaContent{{
Type: ContentTypeText,
Text: stringContent,
})
}}
m.parsedContent = contentList
return contentList
}
var arrayContent []json.RawMessage
// 尝试解析为数组
var arrayContent []map[string]interface{}
if err := json.Unmarshal(m.Content, &arrayContent); err == nil {
for _, contentItem := range arrayContent {
var contentMap map[string]any
if err := json.Unmarshal(contentItem, &contentMap); err != nil {
contentType, ok := contentItem["type"].(string)
if !ok {
continue
}
switch contentMap["type"] {
switch contentType {
case ContentTypeText:
if subStr, ok := contentMap["text"].(string); ok {
if text, ok := contentItem["text"].(string); ok {
contentList = append(contentList, MediaContent{
Type: ContentTypeText,
Text: subStr,
Text: text,
})
}
case ContentTypeImageURL:
if subObj, ok := contentMap["image_url"].(map[string]any); ok {
detail, ok := subObj["detail"]
if ok {
subObj["detail"] = detail.(string)
} else {
subObj["detail"] = "high"
}
imageUrl := contentItem["image_url"]
switch v := imageUrl.(type) {
case string:
contentList = append(contentList, MediaContent{
Type: ContentTypeImageURL,
ImageUrl: MessageImageUrl{
Url: subObj["url"].(string),
Detail: subObj["detail"].(string),
},
})
} else if url, ok := contentMap["image_url"].(string); ok {
contentList = append(contentList, MediaContent{
Type: ContentTypeImageURL,
ImageUrl: MessageImageUrl{
Url: url,
Url: v,
Detail: "high",
},
})
case map[string]interface{}:
url, ok1 := v["url"].(string)
detail, ok2 := v["detail"].(string)
if !ok2 {
detail = "high"
}
if ok1 {
contentList = append(contentList, MediaContent{
Type: ContentTypeImageURL,
ImageUrl: MessageImageUrl{
Url: url,
Detail: detail,
},
})
}
}
case ContentTypeInputAudio:
if subObj, ok := contentMap["input_audio"].(map[string]any); ok {
contentList = append(contentList, MediaContent{
Type: ContentTypeInputAudio,
InputAudio: MessageInputAudio{
Data: subObj["data"].(string),
Format: subObj["format"].(string),
},
})
if audioData, ok := contentItem["input_audio"].(map[string]interface{}); ok {
data, ok1 := audioData["data"].(string)
format, ok2 := audioData["format"].(string)
if ok1 && ok2 {
contentList = append(contentList, MediaContent{
Type: ContentTypeInputAudio,
InputAudio: MessageInputAudio{
Data: data,
Format: format,
},
})
}
}
}
}
return contentList
}
return nil
if len(contentList) > 0 {
m.parsedContent = contentList
}
return contentList
}
+12 -4
View File
@@ -62,9 +62,10 @@ type ChatCompletionsStreamResponseChoice struct {
}
type ChatCompletionsStreamResponseChoiceDelta struct {
Content *string `json:"content,omitempty"`
Role string `json:"role,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
Content *string `json:"content,omitempty"`
ReasoningContent *string `json:"reasoning_content,omitempty"`
Role string `json:"role,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}
func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) {
@@ -78,10 +79,17 @@ func (c *ChatCompletionsStreamResponseChoiceDelta) GetContentString() string {
return *c.Content
}
func (c *ChatCompletionsStreamResponseChoiceDelta) GetReasoningContent() string {
if c.ReasoningContent == nil {
return ""
}
return *c.ReasoningContent
}
type ToolCall struct {
// Index is not nil only in chat completion chunk object
Index *int `json:"index,omitempty"`
ID string `json:"id"`
ID string `json:"id,omitempty"`
Type any `json:"type"`
Function FunctionCall `json:"function"`
}
+11 -5
View File
@@ -16,6 +16,7 @@ require (
github.com/gin-contrib/sessions v0.0.5
github.com/gin-contrib/static v0.0.1
github.com/gin-gonic/gin v1.9.1
github.com/glebarez/sqlite v1.9.0
github.com/go-playground/validator/v10 v10.20.0
github.com/go-redis/redis/v8 v8.11.5
github.com/golang-jwt/jwt v3.2.2+incompatible
@@ -29,10 +30,10 @@ require (
github.com/shirou/gopsutil v3.21.11+incompatible
golang.org/x/crypto v0.27.0
golang.org/x/image v0.23.0
golang.org/x/net v0.28.0
gorm.io/driver/mysql v1.4.3
gorm.io/driver/postgres v1.5.2
gorm.io/driver/sqlite v1.4.3
gorm.io/gorm v1.25.0
gorm.io/gorm v1.25.2
)
require (
@@ -48,12 +49,14 @@ require (
github.com/cloudwego/iasm v0.2.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dlclark/regexp2 v1.11.0 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/glebarez/go-sqlite v1.21.2 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-sql-driver/mysql v1.6.0 // indirect
github.com/go-sql-driver/mysql v1.7.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/gorilla/context v1.1.1 // indirect
@@ -69,11 +72,11 @@ require (
github.com/klauspost/cpuid/v2 v2.2.9 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.2.1 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/tklauser/go-sysconf v0.3.12 // indirect
github.com/tklauser/numcpus v0.6.1 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
@@ -81,10 +84,13 @@ require (
github.com/yusufpapurcu/wmi v1.2.3 // indirect
golang.org/x/arch v0.12.0 // indirect
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
golang.org/x/net v0.28.0 // indirect
golang.org/x/sync v0.10.0 // indirect
golang.org/x/sys v0.27.0 // indirect
golang.org/x/text v0.21.0 // indirect
google.golang.org/protobuf v1.34.2 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
modernc.org/libc v1.22.5 // indirect
modernc.org/mathutil v1.5.0 // indirect
modernc.org/memory v1.5.0 // indirect
modernc.org/sqlite v1.23.1 // indirect
)
+23 -9
View File
@@ -40,6 +40,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI=
github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
@@ -58,6 +60,10 @@ github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwv
github.com/gin-gonic/gin v1.8.1/go.mod h1:ji8BvRH1azfM+SYow9zQ6SZMvR8qOMZHmsCuWR9tTTk=
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9gAXWo=
github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k=
github.com/glebarez/sqlite v1.9.0 h1:Aj6bPA12ZEx5GbSF6XADmCkYXlljPNUY+Zf1EQxynXs=
github.com/glebarez/sqlite v1.9.0/go.mod h1:YBYCoyupOao60lzp1MVBLEjZfgkq0tdB1voAQ09K9zw=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
@@ -77,8 +83,9 @@ github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBEx
github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
@@ -90,6 +97,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ=
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
@@ -140,9 +149,6 @@ github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Ky
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -167,6 +173,9 @@ github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQ
github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
@@ -263,11 +272,16 @@ gorm.io/driver/mysql v1.4.3 h1:/JhWJhO2v17d8hjApTltKNADm7K7YI2ogkR7avJUL3k=
gorm.io/driver/mysql v1.4.3/go.mod h1:sSIebwZAVPiT+27jK9HIwvsqOGKx3YMPmrA3mBJR10c=
gorm.io/driver/postgres v1.5.2 h1:ytTDxxEv+MplXOfFe3Lzm7SjG09fcdb3Z/c056DTBx0=
gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBpCgl8=
gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU=
gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI=
gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA=
gorm.io/gorm v1.25.0 h1:+KtYtb2roDz14EQe4bla8CbQlmb9dN3VejSai3lprfU=
gorm.io/gorm v1.25.0/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
gorm.io/gorm v1.25.2 h1:gs1o6Vsa+oVKG/a9ElL3XgyGfghFfkKA2SInQaCyMho=
gorm.io/gorm v1.25.2/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
modernc.org/libc v1.22.5 h1:91BNch/e5B0uPbJFgqbxXuOnxBQjlS//icfQEGmvyjE=
modernc.org/libc v1.22.5/go.mod h1:jj+Z7dTNX8fBScMVNRAYZ/jF91K8fdT2hYMThc3YjBY=
modernc.org/mathutil v1.5.0 h1:rV0Ko/6SfM+8G+yKiyI830l3Wuz1zRutdslNoQ0kfiQ=
modernc.org/mathutil v1.5.0/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E=
modernc.org/memory v1.5.0 h1:N+/8c5rE6EqugZwHii4IFsaJ7MUhoWX07J5tC/iI5Ds=
modernc.org/memory v1.5.0/go.mod h1:PkUhL0Mugw21sHPeskwZW4D6VscE/GQJOnIpCnW6pSU=
modernc.org/sqlite v1.23.1 h1:nrSBg4aRQQwq59JpvGEQ15tNxoO5pX/kUjcRNwSAGQM=
modernc.org/sqlite v1.23.1/go.mod h1:OrDj17Mggn6MhE+iPbBNf7RGKODDE9NFT0f3EwDzJqk=
nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
+9 -2
View File
@@ -119,9 +119,9 @@ func main() {
}
if os.Getenv("ENABLE_PPROF") == "true" {
go func() {
gopool.Go(func() {
log.Println(http.ListenAndServe("0.0.0.0:8005", nil))
}()
})
go common.Monitor()
common.SysLog("pprof enabled")
}
@@ -145,6 +145,13 @@ func main() {
middleware.SetUpLogger(server)
// Initialize session store
store := cookie.NewStore([]byte(common.SessionSecret))
store.Options(sessions.Options{
Path: "/",
MaxAge: 2592000, // 30 days
HttpOnly: true,
Secure: false,
SameSite: http.SameSiteStrictMode,
})
server.Use(sessions.Sessions("session", store))
router.SetRouter(server, buildFS, indexPage)
+26 -28
View File
@@ -64,35 +64,33 @@ func authHelper(c *gin.Context, minRole int) {
return
}
}
if !useAccessToken {
// get header New-Api-User
apiUserIdStr := c.Request.Header.Get("New-Api-User")
if apiUserIdStr == "" {
c.JSON(http.StatusUnauthorized, gin.H{
"success": false,
"message": "无权进行此操作,请刷新页面或清空缓存后重试",
})
c.Abort()
return
}
apiUserId, err := strconv.Atoi(apiUserIdStr)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{
"success": false,
"message": "无权进行此操作,登录信息无效,请重新登录",
})
c.Abort()
return
// get header New-Api-User
apiUserIdStr := c.Request.Header.Get("New-Api-User")
if apiUserIdStr == "" {
c.JSON(http.StatusUnauthorized, gin.H{
"success": false,
"message": "无权进行此操作,未提供 New-Api-User",
})
c.Abort()
return
}
apiUserId, err := strconv.Atoi(apiUserIdStr)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{
"success": false,
"message": "无权进行此操作,New-Api-User 格式错误",
})
c.Abort()
return
}
if id != apiUserId {
c.JSON(http.StatusUnauthorized, gin.H{
"success": false,
"message": "无权进行此操作,与登录用户不匹配,请重新登录",
})
c.Abort()
return
}
}
if id != apiUserId {
c.JSON(http.StatusUnauthorized, gin.H{
"success": false,
"message": "无权进行此操作,New-Api-User 与登录用户不匹配",
})
c.Abort()
return
}
if status.(int) == common.UserStatusDisabled {
c.JSON(http.StatusOK, gin.H{
+2 -4
View File
@@ -135,17 +135,14 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
midjourneyRequest := dto.MidjourneyRequest{}
err = common.UnmarshalBodyReusable(c, &midjourneyRequest)
if err != nil {
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error())
return nil, false, err
}
midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
if mjErr != nil {
abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description)
return nil, false, fmt.Errorf(mjErr.Description)
}
if midjourneyModel == "" {
if !success {
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型")
return nil, false, fmt.Errorf("无效的请求, 无法解析模型")
} else {
// task fetch, task fetch by condition, notify
@@ -170,7 +167,6 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
err = common.UnmarshalBodyReusable(c, &modelRequest)
}
if err != nil {
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
return nil, false, errors.New("无效的请求, " + err.Error())
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/realtime") {
@@ -239,5 +235,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set("plugin", channel.Other)
case common.ChannelCloudflare:
c.Set("api_version", channel.Other)
case common.ChannelTypeMokaAI:
c.Set("api_version", channel.Other)
}
}
+1 -10
View File
@@ -12,7 +12,7 @@ import (
type Ability struct {
Group string `json:"group" gorm:"type:varchar(64);primaryKey;autoIncrement:false"`
Model string `json:"model" gorm:"type:varchar(64);primaryKey;autoIncrement:false"`
Model string `json:"model" gorm:"type:varchar(255);primaryKey;autoIncrement:false"`
ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
Enabled bool `json:"enabled"`
Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
@@ -23,10 +23,6 @@ type Ability struct {
func GetGroupModels(group string) []string {
var models []string
// Find distinct models
groupCol := "`group`"
if common.UsingPostgreSQL {
groupCol = `"group"`
}
DB.Table("abilities").Where(groupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models)
return models
}
@@ -45,10 +41,8 @@ func GetAllEnableAbilities() []Ability {
}
func getPriority(group string, model string, retry int) (int, error) {
groupCol := "`group`"
trueVal := "1"
if common.UsingPostgreSQL {
groupCol = `"group"`
trueVal = "true"
}
@@ -81,10 +75,8 @@ func getPriority(group string, model string, retry int) (int, error) {
}
func getChannelQuery(group string, model string, retry int) *gorm.DB {
groupCol := "`group`"
trueVal := "1"
if common.UsingPostgreSQL {
groupCol = `"group"`
trueVal = "true"
}
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
@@ -286,7 +278,6 @@ func FixAbility() (int, error) {
return 0, err
}
var channels []Channel
if len(abilityChannelIds) == 0 {
err = DB.Find(&channels).Error
} else {
-100
View File
@@ -11,106 +11,6 @@ import (
"time"
)
//func CacheGetUserGroup(id int) (group string, err error) {
// if !common.RedisEnabled {
// return GetUserGroup(id)
// }
// group, err = common.RedisGet(fmt.Sprintf("user_group:%d", id))
// if err != nil {
// group, err = GetUserGroup(id)
// if err != nil {
// return "", err
// }
// err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(constant.UserId2GroupCacheSeconds)*time.Second)
// if err != nil {
// common.SysError("Redis set user group error: " + err.Error())
// }
// }
// return group, err
//}
//
//func CacheGetUsername(id int) (username string, err error) {
// if !common.RedisEnabled {
// return GetUsernameById(id)
// }
// username, err = common.RedisGet(fmt.Sprintf("user_name:%d", id))
// if err != nil {
// username, err = GetUsernameById(id)
// if err != nil {
// return "", err
// }
// err = common.RedisSet(fmt.Sprintf("user_name:%d", id), username, time.Duration(constant.UserId2GroupCacheSeconds)*time.Second)
// if err != nil {
// common.SysError("Redis set user group error: " + err.Error())
// }
// }
// return username, err
//}
//
//func CacheGetUserQuota(id int) (quota int, err error) {
// if !common.RedisEnabled {
// return GetUserQuota(id)
// }
// quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id))
// if err != nil {
// quota, err = GetUserQuota(id)
// if err != nil {
// return 0, err
// }
// return quota, nil
// }
// quota, err = strconv.Atoi(quotaString)
// return quota, nil
//}
//
//func CacheUpdateUserQuota(id int) error {
// if !common.RedisEnabled {
// return nil
// }
// quota, err := GetUserQuota(id)
// if err != nil {
// return err
// }
// return cacheSetUserQuota(id, quota)
//}
//
//func cacheSetUserQuota(id int, quota int) error {
// err := common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second)
// return err
//}
//
//func CacheDecreaseUserQuota(id int, quota int) error {
// if !common.RedisEnabled {
// return nil
// }
// err := common.RedisDecrease(fmt.Sprintf("user_quota:%d", id), int64(quota))
// return err
//}
//
//func CacheIsUserEnabled(userId int) (bool, error) {
// if !common.RedisEnabled {
// return IsUserEnabled(userId)
// }
// enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId))
// if err == nil {
// return enabled == "1", nil
// }
//
// userEnabled, err := IsUserEnabled(userId)
// if err != nil {
// return false, err
// }
// enabled = "0"
// if userEnabled {
// enabled = "1"
// }
// err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(constant.UserId2StatusCacheSeconds)*time.Second)
// if err != nil {
// common.SysError("Redis set user enabled error: " + err.Error())
// }
// return userEnabled, err
//}
var group2model2channels map[string]map[string][]*Channel
var channelsIDM map[int]*Channel
var channelSyncLock sync.RWMutex
+1 -8
View File
@@ -28,7 +28,7 @@ type Channel struct {
Models string `json:"models"`
Group string `json:"group" gorm:"type:varchar(64);default:'default'"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
ModelMapping *string `json:"model_mapping" gorm:"type:text"`
//MaxInputTokens *int `json:"max_input_tokens" gorm:"default:0"`
StatusCodeMapping *string `json:"status_code_mapping" gorm:"type:varchar(1024);default:''"`
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
@@ -114,14 +114,11 @@ func GetChannelsByTag(tag string, idSort bool) ([]*Channel, error) {
func SearchChannels(keyword string, group string, model string, idSort bool) ([]*Channel, error) {
var channels []*Channel
keyCol := "`key`"
groupCol := "`group`"
modelsCol := "`models`"
// 如果是 PostgreSQL,使用双引号
if common.UsingPostgreSQL {
keyCol = `"key"`
groupCol = `"group"`
modelsCol = `"models"`
}
@@ -437,14 +434,10 @@ func GetPaginatedTags(offset int, limit int) ([]*string, error) {
func SearchTags(keyword string, group string, model string, idSort bool) ([]*string, error) {
var tags []*string
keyCol := "`key`"
groupCol := "`group`"
modelsCol := "`models`"
// 如果是 PostgreSQL,使用双引号
if common.UsingPostgreSQL {
keyCol = `"key"`
groupCol = `"group"`
modelsCol = `"models"`
}
+50 -18
View File
@@ -27,6 +27,7 @@ type Log struct {
UseTime int `json:"use_time" gorm:"default:0"`
IsStream bool `json:"is_stream" gorm:"default:false"`
ChannelId int `json:"channel" gorm:"index"`
ChannelName string `json:"channel_name" gorm:"->"`
TokenId int `json:"token_id" gorm:"default:0;index"`
Group string `json:"group" gorm:"index"`
Other string `json:"other"`
@@ -42,6 +43,7 @@ const (
func formatUserLogs(logs []*Log) {
for i := range logs {
logs[i].ChannelName = ""
var otherMap map[string]interface{}
otherMap = common.StrToMap(logs[i].Other)
if otherMap != nil {
@@ -56,7 +58,7 @@ func formatUserLogs(logs []*Log) {
func GetLogByKey(key string) (logs []*Log, err error) {
if os.Getenv("LOG_SQL_DSN") != "" {
var tk Token
if err = DB.Model(&Token{}).Where("`key`=?", strings.TrimPrefix(key, "sk-")).First(&tk).Error; err != nil {
if err = DB.Model(&Token{}).Where(keyCol+"=?", strings.TrimPrefix(key, "sk-")).First(&tk).Error; err != nil {
return nil, err
}
err = LOG_DB.Model(&Log{}).Where("token_id=?", tk.Id).Find(&logs).Error
@@ -128,67 +130,97 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
if logType == LogTypeUnknown {
tx = LOG_DB
} else {
tx = LOG_DB.Where("type = ?", logType)
tx = LOG_DB.Where("logs.type = ?", logType)
}
if modelName != "" {
tx = tx.Where("model_name like ?", modelName)
tx = tx.Where("logs.model_name like ?", modelName)
}
if username != "" {
tx = tx.Where("username = ?", username)
tx = tx.Where("logs.username = ?", username)
}
if tokenName != "" {
tx = tx.Where("token_name = ?", tokenName)
tx = tx.Where("logs.token_name = ?", tokenName)
}
if startTimestamp != 0 {
tx = tx.Where("created_at >= ?", startTimestamp)
tx = tx.Where("logs.created_at >= ?", startTimestamp)
}
if endTimestamp != 0 {
tx = tx.Where("created_at <= ?", endTimestamp)
tx = tx.Where("logs.created_at <= ?", endTimestamp)
}
if channel != 0 {
tx = tx.Where("channel_id = ?", channel)
tx = tx.Where("logs.channel_id = ?", channel)
}
if group != "" {
tx = tx.Where(groupCol+" = ?", group)
tx = tx.Where("logs."+groupCol+" = ?", group)
}
err = tx.Model(&Log{}).Count(&total).Error
if err != nil {
return nil, 0, err
}
err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error
err = tx.Order("logs.id desc").Limit(num).Offset(startIdx).Find(&logs).Error
if err != nil {
return nil, 0, err
}
channelIds := make([]int, 0)
channelMap := make(map[int]string)
for _, log := range logs {
if log.ChannelId != 0 {
channelIds = append(channelIds, log.ChannelId)
}
}
if len(channelIds) > 0 {
var channels []struct {
Id int `gorm:"column:id"`
Name string `gorm:"column:name"`
}
if err = DB.Table("channels").Select("id, name").Where("id IN ?", channelIds).Find(&channels).Error; err != nil {
return logs, total, err
}
for _, channel := range channels {
channelMap[channel.Id] = channel.Name
}
for i := range logs {
logs[i].ChannelName = channelMap[logs[i].ChannelId]
}
}
return logs, total, err
}
func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int, group string) (logs []*Log, total int64, err error) {
var tx *gorm.DB
if logType == LogTypeUnknown {
tx = LOG_DB.Where("user_id = ?", userId)
tx = LOG_DB.Where("logs.user_id = ?", userId)
} else {
tx = LOG_DB.Where("user_id = ? and type = ?", userId, logType)
tx = LOG_DB.Where("logs.user_id = ? and logs.type = ?", userId, logType)
}
if modelName != "" {
tx = tx.Where("model_name like ?", modelName)
tx = tx.Where("logs.model_name like ?", modelName)
}
if tokenName != "" {
tx = tx.Where("token_name = ?", tokenName)
tx = tx.Where("logs.token_name = ?", tokenName)
}
if startTimestamp != 0 {
tx = tx.Where("created_at >= ?", startTimestamp)
tx = tx.Where("logs.created_at >= ?", startTimestamp)
}
if endTimestamp != 0 {
tx = tx.Where("created_at <= ?", endTimestamp)
tx = tx.Where("logs.created_at <= ?", endTimestamp)
}
if group != "" {
tx = tx.Where(groupCol+" = ?", group)
tx = tx.Where("logs."+groupCol+" = ?", group)
}
err = tx.Model(&Log{}).Count(&total).Error
if err != nil {
return nil, 0, err
}
err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error
err = tx.Order("logs.id desc").Limit(num).Offset(startIdx).Find(&logs).Error
if err != nil {
return nil, 0, err
}
formatUserLogs(logs)
return logs, total, err
}
+8 -8
View File
@@ -1,9 +1,9 @@
package model
import (
"github.com/glebarez/sqlite"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"log"
"one-api/common"
@@ -16,7 +16,7 @@ import (
var groupCol string
var keyCol string
func init() {
func initCol() {
if common.UsingPostgreSQL {
groupCol = `"group"`
keyCol = `"key"`
@@ -55,6 +55,9 @@ func createRootAccountIfNeed() error {
}
func chooseDB(envName string) (*gorm.DB, error) {
defer func() {
initCol()
}()
dsn := os.Getenv(envName)
if dsn != "" {
if strings.HasPrefix(dsn, "postgres://") {
@@ -116,12 +119,9 @@ func InitDB() (err error) {
if !common.IsMasterNode {
return nil
}
//if common.UsingMySQL {
// _, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY action VARCHAR(40);") // TODO: delete this line when most users have upgraded
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY progress VARCHAR(30);") // TODO: delete this line when most users have upgraded
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY status VARCHAR(20);") // TODO: delete this line when most users have upgraded
//}
if common.UsingMySQL {
_, _ = sqlDB.Exec("ALTER TABLE channels MODIFY model_mapping TEXT;") // TODO: delete this line when most users have upgraded
}
common.SysLog("database migration started")
err = migrateDB()
return err
+8 -2
View File
@@ -84,7 +84,7 @@ func InitOptionMap() {
common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter)
common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee)
common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
common.OptionMap["ShouldPreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
@@ -104,11 +104,13 @@ func InitOptionMap() {
common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(setting.MjForwardUrlEnabled)
common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(setting.MjActionCheckSuccessEnabled)
common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(setting.CheckSensitiveEnabled)
common.OptionMap["DemoSiteEnabled"] = strconv.FormatBool(setting.DemoSiteEnabled)
common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(setting.CheckSensitiveOnPromptEnabled)
//common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(setting.StopOnSensitiveEnabled)
common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
common.OptionMap["AutomaticDisableKeywords"] = setting.AutomaticDisableKeywordsToString()
common.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase()
@@ -220,6 +222,8 @@ func updateOptionMap(key string, value string) (err error) {
setting.MjActionCheckSuccessEnabled = boolValue
case "CheckSensitiveEnabled":
setting.CheckSensitiveEnabled = boolValue
case "DemoSiteEnabled":
setting.DemoSiteEnabled = boolValue
case "CheckSensitiveOnPromptEnabled":
setting.CheckSensitiveOnPromptEnabled = boolValue
//case "CheckSensitiveOnCompletionEnabled":
@@ -302,7 +306,7 @@ func updateOptionMap(key string, value string) (err error) {
common.QuotaForInvitee, _ = strconv.Atoi(value)
case "QuotaRemindThreshold":
common.QuotaRemindThreshold, _ = strconv.Atoi(value)
case "PreConsumedQuota":
case "ShouldPreConsumedQuota":
common.PreConsumedQuota, _ = strconv.Atoi(value)
case "RetryTimes":
common.RetryTimes, _ = strconv.Atoi(value)
@@ -332,6 +336,8 @@ func updateOptionMap(key string, value string) (err error) {
common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64)
case "SensitiveWords":
setting.SensitiveWordsFromString(value)
case "AutomaticDisableKeywords":
setting.AutomaticDisableKeywordsFromString(value)
case "StreamCacheQueueLength":
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
}
+75 -9
View File
@@ -3,8 +3,10 @@ package model
import (
"errors"
"fmt"
"gorm.io/gorm"
"one-api/common"
"strconv"
"gorm.io/gorm"
)
type Redemption struct {
@@ -21,16 +23,80 @@ type Redemption struct {
DeletedAt gorm.DeletedAt `gorm:"index"`
}
func GetAllRedemptions(startIdx int, num int) ([]*Redemption, error) {
var redemptions []*Redemption
var err error
err = DB.Order("id desc").Limit(num).Offset(startIdx).Find(&redemptions).Error
return redemptions, err
func GetAllRedemptions(startIdx int, num int) (redemptions []*Redemption, total int64, err error) {
// 开始事务
tx := DB.Begin()
if tx.Error != nil {
return nil, 0, tx.Error
}
defer func() {
if r := recover(); r != nil {
tx.Rollback()
}
}()
// 获取总数
err = tx.Model(&Redemption{}).Count(&total).Error
if err != nil {
tx.Rollback()
return nil, 0, err
}
// 获取分页数据
err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&redemptions).Error
if err != nil {
tx.Rollback()
return nil, 0, err
}
// 提交事务
if err = tx.Commit().Error; err != nil {
return nil, 0, err
}
return redemptions, total, nil
}
func SearchRedemptions(keyword string) (redemptions []*Redemption, err error) {
err = DB.Where("id = ? or name LIKE ?", keyword, keyword+"%").Find(&redemptions).Error
return redemptions, err
func SearchRedemptions(keyword string, startIdx int, num int) (redemptions []*Redemption, total int64, err error) {
tx := DB.Begin()
if tx.Error != nil {
return nil, 0, tx.Error
}
defer func() {
if r := recover(); r != nil {
tx.Rollback()
}
}()
// Build query based on keyword type
query := tx.Model(&Redemption{})
// Only try to convert to ID if the string represents a valid integer
if id, err := strconv.Atoi(keyword); err == nil {
query = query.Where("id = ? OR name LIKE ?", id, keyword+"%")
} else {
query = query.Where("name LIKE ?", keyword+"%")
}
// Get total count
err = query.Count(&total).Error
if err != nil {
tx.Rollback()
return nil, 0, err
}
// Get paginated data
err = query.Order("id desc").Limit(num).Offset(startIdx).Find(&redemptions).Error
if err != nil {
tx.Rollback()
return nil, 0, err
}
if err = tx.Commit().Error; err != nil {
return nil, 0, err
}
return redemptions, total, nil
}
func GetRedemptionById(id int) (*Redemption, error) {
+4 -83
View File
@@ -3,13 +3,11 @@ package model
import (
"errors"
"fmt"
"one-api/common"
"strings"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm"
"one-api/common"
relaycommon "one-api/relay/common"
"one-api/setting"
"strconv"
"strings"
)
type Token struct {
@@ -68,7 +66,7 @@ func SearchUserTokens(userId int, keyword string, token string) (tokens []*Token
if token != "" {
token = strings.Trim(token, "sk-")
}
err = DB.Where("user_id = ?", userId).Where("name LIKE ?", "%"+keyword+"%").Where("`key` LIKE ?", "%"+token+"%").Find(&tokens).Error
err = DB.Where("user_id = ?", userId).Where("name LIKE ?", "%"+keyword+"%").Where(keyCol+" LIKE ?", "%"+token+"%").Find(&tokens).Error
return tokens, err
}
@@ -322,80 +320,3 @@ func decreaseTokenQuota(id int, quota int) (err error) {
).Error
return err
}
func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
if relayInfo.IsPlayground {
return nil
}
//if relayInfo.TokenUnlimited {
// return nil
//}
token, err := GetTokenById(relayInfo.TokenId)
if err != nil {
return err
}
if !relayInfo.TokenUnlimited && token.RemainQuota < quota {
return errors.New("令牌额度不足")
}
err = DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
if err != nil {
return err
}
return nil
}
func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) {
if quota > 0 {
err = DecreaseUserQuota(relayInfo.UserId, quota)
} else {
err = IncreaseUserQuota(relayInfo.UserId, -quota)
}
if err != nil {
return err
}
if !relayInfo.IsPlayground {
if quota > 0 {
err = DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
} else {
err = IncreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, -quota)
}
if err != nil {
return err
}
}
if sendEmail {
if (quota + preConsumedQuota) != 0 {
quotaTooLow := userQuota >= common.QuotaRemindThreshold && userQuota-(quota+preConsumedQuota) < common.QuotaRemindThreshold
noMoreQuota := userQuota-(quota+preConsumedQuota) <= 0
if quotaTooLow || noMoreQuota {
go func() {
email, err := GetUserEmail(relayInfo.UserId)
if err != nil {
common.SysError("failed to fetch user email: " + err.Error())
}
prompt := "您的额度即将用尽"
if noMoreQuota {
prompt = "您的额度已用尽"
}
if email != "" {
topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
err = common.SendEmail(prompt, email,
fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink))
if err != nil {
common.SysError("failed to send email" + err.Error())
}
common.SysLog("user quota is low, consumed quota: " + strconv.Itoa(quota) + ", user quota: " + strconv.Itoa(userQuota))
}
}()
}
}
}
return nil
}
+1 -1
View File
@@ -52,7 +52,7 @@ func cacheSetTokenField(key string, field string, value string) error {
func cacheGetTokenByKey(key string) (*Token, error) {
hmacKey := common.GenerateHMAC(key)
if !common.RedisEnabled {
return nil, nil
return nil, fmt.Errorf("redis is not enabled")
}
var token Token
err := common.RedisHGetObj(fmt.Sprintf("token:%s", hmacKey), &token)
+163 -35
View File
@@ -1,6 +1,7 @@
package model
import (
"encoding/json"
"errors"
"fmt"
"one-api/common"
@@ -38,6 +39,20 @@ type User struct {
InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
DeletedAt gorm.DeletedAt `gorm:"index"`
LinuxDOId string `json:"linux_do_id" gorm:"column:linux_do_id;index"`
Setting string `json:"setting" gorm:"type:text;column:setting"`
}
func (user *User) ToBaseUser() *UserBase {
cache := &UserBase{
Id: user.Id,
Group: user.Group,
Quota: user.Quota,
Status: user.Status,
Username: user.Username,
Setting: user.Setting,
Email: user.Email,
}
return cache
}
func (user *User) GetAccessToken() string {
@@ -51,6 +66,22 @@ func (user *User) SetAccessToken(token string) {
user.AccessToken = &token
}
func (user *User) GetSetting() map[string]interface{} {
if user.Setting == "" {
return nil
}
return common.StrToMap(user.Setting)
}
func (user *User) SetSetting(setting map[string]interface{}) {
settingBytes, err := json.Marshal(setting)
if err != nil {
common.SysError("failed to marshal setting: " + err.Error())
return
}
user.Setting = string(settingBytes)
}
// CheckUserExistOrDeleted check if user exist or deleted, if not exist, return false, nil, if deleted or exist, return true, nil
func CheckUserExistOrDeleted(username string, email string) (bool, error) {
var user User
@@ -81,41 +112,105 @@ func GetMaxUserId() int {
return user.Id
}
func GetAllUsers(startIdx int, num int) (users []*User, err error) {
err = DB.Unscoped().Order("id desc").Limit(num).Offset(startIdx).Omit("password").Find(&users).Error
return users, err
func GetAllUsers(startIdx int, num int) (users []*User, total int64, err error) {
// Start transaction
tx := DB.Begin()
if tx.Error != nil {
return nil, 0, tx.Error
}
defer func() {
if r := recover(); r != nil {
tx.Rollback()
}
}()
// Get total count within transaction
err = tx.Unscoped().Model(&User{}).Count(&total).Error
if err != nil {
tx.Rollback()
return nil, 0, err
}
// Get paginated users within same transaction
err = tx.Unscoped().Order("id desc").Limit(num).Offset(startIdx).Omit("password").Find(&users).Error
if err != nil {
tx.Rollback()
return nil, 0, err
}
// Commit transaction
if err = tx.Commit().Error; err != nil {
return nil, 0, err
}
return users, total, nil
}
func SearchUsers(keyword string, group string) ([]*User, error) {
func SearchUsers(keyword string, group string, startIdx int, num int) ([]*User, int64, error) {
var users []*User
var total int64
var err error
// 开始事务
tx := DB.Begin()
if tx.Error != nil {
return nil, 0, tx.Error
}
defer func() {
if r := recover(); r != nil {
tx.Rollback()
}
}()
// 构建基础查询
query := tx.Unscoped().Model(&User{})
// 构建搜索条件
likeCondition := "username LIKE ? OR email LIKE ? OR display_name LIKE ?"
// 尝试将关键字转换为整数ID
keywordInt, err := strconv.Atoi(keyword)
if err == nil {
// 如果转换成功,按照ID和可选的组别搜索用户
query := DB.Unscoped().Omit("password").Where("id = ?", keywordInt)
// 如果是数字,同时搜索ID和其他字段
likeCondition = "id = ? OR " + likeCondition
if group != "" {
query = query.Where(groupCol+" = ?", group) // 使用反引号包围group
query = query.Where("("+likeCondition+") AND "+groupCol+" = ?",
keywordInt, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
} else {
query = query.Where(likeCondition,
keywordInt, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
}
err = query.Find(&users).Error
if err != nil || len(users) > 0 {
return users, err
}
}
err = nil
query := DB.Unscoped().Omit("password")
likeCondition := "username LIKE ? OR email LIKE ? OR display_name LIKE ?"
if group != "" {
query = query.Where("("+likeCondition+") AND "+groupCol+" = ?", "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
} else {
query = query.Where(likeCondition, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
// 非数字关键字,只搜索字符串字段
if group != "" {
query = query.Where("("+likeCondition+") AND "+groupCol+" = ?",
"%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
} else {
query = query.Where(likeCondition,
"%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
}
}
err = query.Find(&users).Error
return users, err
// 获取总数
err = query.Count(&total).Error
if err != nil {
tx.Rollback()
return nil, 0, err
}
// 获取分页数据
err = query.Omit("password").Order("id desc").Limit(num).Offset(startIdx).Find(&users).Error
if err != nil {
tx.Rollback()
return nil, 0, err
}
// 提交事务
if err = tx.Commit().Error; err != nil {
return nil, 0, err
}
return users, total, nil
}
func GetUserById(id int, selectAll bool) (*User, error) {
@@ -251,8 +346,8 @@ func (user *User) Update(updatePassword bool) error {
return err
}
// 更新缓存
return updateUserCache(user.Id, user.Username, user.Group, user.Quota, user.Status)
// Update cache
return updateUserCache(*user)
}
func (user *User) Edit(updatePassword bool) error {
@@ -280,8 +375,8 @@ func (user *User) Edit(updatePassword bool) error {
return err
}
// 更新缓存
return updateUserCache(user.Id, user.Username, user.Group, user.Quota, user.Status)
// Update cache
return updateUserCache(*user)
}
func (user *User) Delete() error {
@@ -307,8 +402,8 @@ func (user *User) HardDelete() error {
// ValidateAndFill check password & user status
func (user *User) ValidateAndFill() (err error) {
// When querying with struct, GORM will only query with non-zero fields,
// that means if your fields value is 0, '', false or other zero values,
// it wont be used to build query conditions
// that means if your field's value is 0, '', false or other zero values,
// it won't be used to build query conditions
password := user.Password
username := strings.TrimSpace(user.Username)
if username == "" || password == "" {
@@ -467,7 +562,6 @@ func GetUserQuota(id int, fromDB bool) (quota int, err error) {
return quota, nil
}
// Don't return error - fall through to DB
//common.SysError("failed to get user quota from cache: " + err.Error())
}
fromDB = true
err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find(&quota).Error
@@ -516,6 +610,35 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) {
return group, nil
}
// GetUserSetting gets setting from Redis first, falls back to DB if needed
func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err error) {
var setting string
defer func() {
// Update Redis cache asynchronously on successful DB read
if shouldUpdateRedis(fromDB, err) {
gopool.Go(func() {
if err := updateUserSettingCache(id, setting); err != nil {
common.SysError("failed to update user setting cache: " + err.Error())
}
})
}
}()
if !fromDB && common.RedisEnabled {
setting, err := getUserSettingCache(id)
if err == nil {
return setting, nil
}
// Don't return error - fall through to DB
}
fromDB = true
err = DB.Model(&User{}).Where("id = ?", id).Select("setting").Find(&setting).Error
if err != nil {
return map[string]interface{}{}, err
}
return common.StrToMap(setting), nil
}
func IncreaseUserQuota(id int, quota int) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
@@ -577,9 +700,14 @@ func DeltaUpdateUserQuota(id int, delta int) (err error) {
}
}
func GetRootUserEmail() (email string) {
DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
return email
//func GetRootUserEmail() (email string) {
// DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
// return email
//}
func GetRootUser() (user *User) {
DB.Where("role = ?", common.RoleRootUser).First(&user)
return user
}
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
@@ -661,10 +789,10 @@ func IsLinuxDOIdAlreadyTaken(linuxDOId string) bool {
return !errors.Is(err, gorm.ErrRecordNotFound)
}
func (u *User) FillUserByLinuxDOId() error {
if u.LinuxDOId == "" {
func (user *User) FillUserByLinuxDOId() error {
if user.LinuxDOId == "" {
return errors.New("linux do id is empty")
}
err := DB.Where("linux_do_id = ?", u.LinuxDOId).First(u).Error
err := DB.Where("linux_do_id = ?", user.LinuxDOId).First(user).Error
return err
}
+167 -160
View File
@@ -1,206 +1,213 @@
package model
import (
"encoding/json"
"fmt"
"one-api/common"
"one-api/constant"
"strconv"
"time"
"github.com/bytedance/gopkg/util/gopool"
)
// Change UserCache struct to userCache
type userCache struct {
// UserBase struct remains the same as it represents the cached data structure
type UserBase struct {
Id int `json:"id"`
Group string `json:"group"`
Email string `json:"email"`
Quota int `json:"quota"`
Status int `json:"status"`
Role int `json:"role"`
Username string `json:"username"`
Setting string `json:"setting"`
}
// Rename all exported functions to private ones
// invalidateUserCache clears all user related cache
func (user *UserBase) GetSetting() map[string]interface{} {
if user.Setting == "" {
return nil
}
return common.StrToMap(user.Setting)
}
func (user *UserBase) SetSetting(setting map[string]interface{}) {
settingBytes, err := json.Marshal(setting)
if err != nil {
common.SysError("failed to marshal setting: " + err.Error())
return
}
user.Setting = string(settingBytes)
}
// getUserCacheKey returns the key for user cache
func getUserCacheKey(userId int) string {
return fmt.Sprintf("user:%d", userId)
}
// invalidateUserCache clears user cache
func invalidateUserCache(userId int) error {
if !common.RedisEnabled {
return nil
}
return common.RedisHDelObj(getUserCacheKey(userId))
}
keys := []string{
fmt.Sprintf(constant.UserGroupKeyFmt, userId),
fmt.Sprintf(constant.UserQuotaKeyFmt, userId),
fmt.Sprintf(constant.UserEnabledKeyFmt, userId),
fmt.Sprintf(constant.UserUsernameKeyFmt, userId),
// updateUserCache updates all user cache fields using hash
func updateUserCache(user User) error {
if !common.RedisEnabled {
return nil
}
for _, key := range keys {
if err := common.RedisDel(key); err != nil {
return fmt.Errorf("failed to delete cache key %s: %w", key, err)
return common.RedisHSetObj(
getUserCacheKey(user.Id),
user.ToBaseUser(),
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
)
}
// GetUserCache gets complete user cache from hash
func GetUserCache(userId int) (userCache *UserBase, err error) {
var user *User
var fromDB bool
defer func() {
// Update Redis cache asynchronously on successful DB read
if shouldUpdateRedis(fromDB, err) && user != nil {
gopool.Go(func() {
if err := updateUserCache(*user); err != nil {
common.SysError("failed to update user status cache: " + err.Error())
}
})
}
}
return nil
}
}()
// updateUserGroupCache updates user group cache
func updateUserGroupCache(userId int, group string) error {
if !common.RedisEnabled {
return nil
}
return common.RedisSet(
fmt.Sprintf(constant.UserGroupKeyFmt, userId),
group,
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
)
}
// updateUserQuotaCache updates user quota cache
func updateUserQuotaCache(userId int, quota int) error {
if !common.RedisEnabled {
return nil
}
return common.RedisSet(
fmt.Sprintf(constant.UserQuotaKeyFmt, userId),
fmt.Sprintf("%d", quota),
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
)
}
// updateUserStatusCache updates user status cache
func updateUserStatusCache(userId int, userEnabled bool) error {
if !common.RedisEnabled {
return nil
}
enabled := "0"
if userEnabled {
enabled = "1"
}
return common.RedisSet(
fmt.Sprintf(constant.UserEnabledKeyFmt, userId),
enabled,
time.Duration(constant.UserId2StatusCacheSeconds)*time.Second,
)
}
// updateUserNameCache updates username cache
func updateUserNameCache(userId int, username string) error {
if !common.RedisEnabled {
return nil
}
return common.RedisSet(
fmt.Sprintf(constant.UserUsernameKeyFmt, userId),
username,
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
)
}
// updateUserCache updates all user cache fields
func updateUserCache(userId int, username string, userGroup string, quota int, status int) error {
if !common.RedisEnabled {
return nil
// Try getting from Redis first
userCache, err = cacheGetUserBase(userId)
if err == nil {
return userCache, nil
}
if err := updateUserGroupCache(userId, userGroup); err != nil {
return fmt.Errorf("update group cache: %w", err)
}
if err := updateUserQuotaCache(userId, quota); err != nil {
return fmt.Errorf("update quota cache: %w", err)
}
if err := updateUserStatusCache(userId, status == common.UserStatusEnabled); err != nil {
return fmt.Errorf("update status cache: %w", err)
}
if err := updateUserNameCache(userId, username); err != nil {
return fmt.Errorf("update username cache: %w", err)
}
return nil
}
// getUserGroupCache gets user group from cache
func getUserGroupCache(userId int) (string, error) {
if !common.RedisEnabled {
return "", nil
}
return common.RedisGet(fmt.Sprintf(constant.UserGroupKeyFmt, userId))
}
// getUserQuotaCache gets user quota from cache
func getUserQuotaCache(userId int) (int, error) {
if !common.RedisEnabled {
return 0, nil
}
quotaStr, err := common.RedisGet(fmt.Sprintf(constant.UserQuotaKeyFmt, userId))
// If Redis fails, get from DB
fromDB = true
user, err = GetUserById(userId, false)
if err != nil {
return 0, err
return nil, err // Return nil and error if DB lookup fails
}
return strconv.Atoi(quotaStr)
// Create cache object from user data
userCache = &UserBase{
Id: user.Id,
Group: user.Group,
Quota: user.Quota,
Status: user.Status,
Username: user.Username,
Setting: user.Setting,
Email: user.Email,
}
return userCache, nil
}
// getUserStatusCache gets user status from cache
func getUserStatusCache(userId int) (int, error) {
func cacheGetUserBase(userId int) (*UserBase, error) {
if !common.RedisEnabled {
return 0, nil
return nil, fmt.Errorf("redis is not enabled")
}
statusStr, err := common.RedisGet(fmt.Sprintf(constant.UserEnabledKeyFmt, userId))
var userCache UserBase
// Try getting from Redis first
err := common.RedisHGetObj(getUserCacheKey(userId), &userCache)
if err != nil {
return 0, err
return nil, err
}
return strconv.Atoi(statusStr)
return &userCache, nil
}
// getUserNameCache gets username from cache
func getUserNameCache(userId int) (string, error) {
if !common.RedisEnabled {
return "", nil
}
return common.RedisGet(fmt.Sprintf(constant.UserUsernameKeyFmt, userId))
}
// getUserCache gets complete user cache
func getUserCache(userId int) (*userCache, error) {
if !common.RedisEnabled {
return nil, nil
}
group, err := getUserGroupCache(userId)
if err != nil {
return nil, fmt.Errorf("get group cache: %w", err)
}
quota, err := getUserQuotaCache(userId)
if err != nil {
return nil, fmt.Errorf("get quota cache: %w", err)
}
status, err := getUserStatusCache(userId)
if err != nil {
return nil, fmt.Errorf("get status cache: %w", err)
}
username, err := getUserNameCache(userId)
if err != nil {
return nil, fmt.Errorf("get username cache: %w", err)
}
return &userCache{
Id: userId,
Group: group,
Quota: quota,
Status: status,
Username: username,
}, nil
}
// Add atomic quota operations
// Add atomic quota operations using hash fields
func cacheIncrUserQuota(userId int, delta int64) error {
if !common.RedisEnabled {
return nil
}
key := fmt.Sprintf(constant.UserQuotaKeyFmt, userId)
return common.RedisIncr(key, delta)
return common.RedisHIncrBy(getUserCacheKey(userId), "Quota", delta)
}
func cacheDecrUserQuota(userId int, delta int64) error {
return cacheIncrUserQuota(userId, -delta)
}
// Helper functions to get individual fields if needed
func getUserGroupCache(userId int) (string, error) {
cache, err := GetUserCache(userId)
if err != nil {
return "", err
}
return cache.Group, nil
}
func getUserQuotaCache(userId int) (int, error) {
cache, err := GetUserCache(userId)
if err != nil {
return 0, err
}
return cache.Quota, nil
}
func getUserStatusCache(userId int) (int, error) {
cache, err := GetUserCache(userId)
if err != nil {
return 0, err
}
return cache.Status, nil
}
func getUserNameCache(userId int) (string, error) {
cache, err := GetUserCache(userId)
if err != nil {
return "", err
}
return cache.Username, nil
}
func getUserSettingCache(userId int) (map[string]interface{}, error) {
setting := make(map[string]interface{})
cache, err := GetUserCache(userId)
if err != nil {
return setting, err
}
return cache.GetSetting(), nil
}
// New functions for individual field updates
func updateUserStatusCache(userId int, status bool) error {
if !common.RedisEnabled {
return nil
}
statusInt := common.UserStatusEnabled
if !status {
statusInt = common.UserStatusDisabled
}
return common.RedisHSetField(getUserCacheKey(userId), "Status", fmt.Sprintf("%d", statusInt))
}
func updateUserQuotaCache(userId int, quota int) error {
if !common.RedisEnabled {
return nil
}
return common.RedisHSetField(getUserCacheKey(userId), "Quota", fmt.Sprintf("%d", quota))
}
func updateUserGroupCache(userId int, group string) error {
if !common.RedisEnabled {
return nil
}
return common.RedisHSetField(getUserCacheKey(userId), "Group", group)
}
func updateUserNameCache(userId int, username string) error {
if !common.RedisEnabled {
return nil
}
return common.RedisHSetField(getUserCacheKey(userId), "Username", username)
}
func updateUserSettingCache(userId int, setting string) error {
if !common.RedisEnabled {
return nil
}
return common.RedisHSetField(getUserCacheKey(userId), "Setting", setting)
}
+1
View File
@@ -15,6 +15,7 @@ type Adaptor interface {
SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error
ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error)
ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error)
ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error)
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error)
+4 -3
View File
@@ -49,9 +49,6 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
return nil, errors.New("request is nil")
}
switch info.RelayMode {
case constant.RelayModeEmbeddings:
baiduEmbeddingRequest := embeddingRequestOpenAI2Ali(*request)
return baiduEmbeddingRequest, nil
default:
aliReq := requestOpenAI2Ali(*request)
return aliReq, nil
@@ -67,6 +64,10 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return embeddingRequestOpenAI2Ali(request), nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
+5 -2
View File
@@ -25,9 +25,12 @@ func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReque
return &request
}
func embeddingRequestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliEmbeddingRequest {
func embeddingRequestOpenAI2Ali(request dto.EmbeddingRequest) *AliEmbeddingRequest {
if request.Model == "" {
request.Model = "text-embedding-v1"
}
return &AliEmbeddingRequest{
Model: "text-embedding-v1",
Model: request.Model,
Input: struct {
Texts []string `json:"texts"`
}{
+15 -5
View File
@@ -39,7 +39,7 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
if err != nil {
return nil, fmt.Errorf("setup request header failed: %w", err)
}
resp, err := doRequest(c, req)
resp, err := doRequest(c, req, info)
if err != nil {
return nil, fmt.Errorf("do request failed: %w", err)
}
@@ -62,7 +62,7 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod
if err != nil {
return nil, fmt.Errorf("setup request header failed: %w", err)
}
resp, err := doRequest(c, req)
resp, err := doRequest(c, req, info)
if err != nil {
return nil, fmt.Errorf("do request failed: %w", err)
}
@@ -90,8 +90,18 @@ func DoWssRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
return targetConn, nil
}
func doRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
resp, err := service.GetHttpClient().Do(req)
func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) {
var client *http.Client
var err error
if proxyURL, ok := info.ChannelSetting["proxy"]; ok {
client, err = service.NewProxyHttpClient(proxyURL.(string))
if err != nil {
return nil, fmt.Errorf("new proxy http client failed: %w", err)
}
} else {
client = service.GetHttpClient()
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
@@ -120,7 +130,7 @@ func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.TaskRelayInfo,
if err != nil {
return nil, fmt.Errorf("setup request header failed: %w", err)
}
resp, err := doRequest(c, req)
resp, err := doRequest(c, req, info.ToRelayInfo())
if err != nil {
return nil, fmt.Errorf("do request failed: %w", err)
}
+6
View File
@@ -59,6 +59,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return nil, nil
}
+1 -1
View File
@@ -10,7 +10,7 @@ type AwsClaudeRequest struct {
System string `json:"system,omitempty"`
Messages []claude.ClaudeMessage `json:"messages"`
MaxTokens uint `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
+5 -3
View File
@@ -109,9 +109,6 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
return nil, errors.New("request is nil")
}
switch info.RelayMode {
case constant.RelayModeEmbeddings:
baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(*request)
return baiduEmbeddingRequest, nil
default:
baiduRequest := requestOpenAI2Baidu(*request)
return baiduRequest, nil
@@ -122,6 +119,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(request)
return baiduEmbeddingRequest, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
+1 -1
View File
@@ -12,7 +12,7 @@ type BaiduMessage struct {
type BaiduChatRequest struct {
Messages []BaiduMessage `json:"messages"`
Temperature float64 `json:"temperature,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
PenaltyScore float64 `json:"penalty_score,omitempty"`
Stream bool `json:"stream,omitempty"`
+1 -1
View File
@@ -87,7 +87,7 @@ func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *dto.Cha
return &response
}
func embeddingRequestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduEmbeddingRequest {
func embeddingRequestOpenAI2Baidu(request dto.EmbeddingRequest) *BaiduEmbeddingRequest {
return &BaiduEmbeddingRequest{
Input: request.ParseInput(),
}
+76
View File
@@ -0,0 +1,76 @@
package baidu_v2
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v2/chat/completions", info.BaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}
+29
View File
@@ -0,0 +1,29 @@
package baidu_v2
var ModelList = []string{
"ernie-4.0-8k-latest",
"ernie-4.0-8k-preview",
"ernie-4.0-8k",
"ernie-4.0-turbo-8k-latest",
"ernie-4.0-turbo-8k-preview",
"ernie-4.0-turbo-8k",
"ernie-4.0-turbo-128k",
"ernie-3.5-8k-preview",
"ernie-3.5-8k",
"ernie-3.5-128k",
"ernie-speed-8k",
"ernie-speed-128k",
"ernie-speed-pro-128k",
"ernie-lite-8k",
"ernie-lite-pro-128k",
"ernie-tiny-8k",
"ernie-char-8k",
"ernie-char-fiction-8k",
"ernie-novel-8k",
"deepseek-v3",
"deepseek-r1",
"deepseek-r1-distill-qwen-32b",
"deepseek-r1-distill-qwen-14b",
}
var ChannelName = "volcengine"
+5
View File
@@ -73,6 +73,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
+1 -1
View File
@@ -50,7 +50,7 @@ type ClaudeRequest struct {
MaxTokens uint `json:"max_tokens,omitempty"`
MaxTokensToSample uint `json:"max_tokens_to_sample,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
//ClaudeMetadata `json:"metadata,omitempty"`
+6 -1
View File
@@ -4,13 +4,14 @@ import (
"bytes"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
@@ -56,6 +57,10 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return request, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return request, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
// 添加文件字段
file, _, err := c.Request.FormFile("file")
+1 -1
View File
@@ -9,7 +9,7 @@ type CfRequest struct {
Prompt string `json:"prompt,omitempty"`
Raw bool `json:"raw,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
}
type CfAudioResponse struct {
+6
View File
@@ -54,6 +54,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return requestConvertRerank2Cohere(request), nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.RelayMode == constant.RelayModeRerank {
err, usage = cohereRerankHandler(c, resp, info)
+12 -1
View File
@@ -10,6 +10,7 @@ import (
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
)
type Adaptor struct {
@@ -29,7 +30,12 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/chat/completions", info.BaseUrl), nil
switch info.RelayMode {
case constant.RelayModeCompletions:
return fmt.Sprintf("%s/beta/completions", info.BaseUrl), nil
default:
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
}
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
@@ -49,6 +55,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
+1 -1
View File
@@ -1,7 +1,7 @@
package deepseek
var ModelList = []string{
"deepseek-chat", "deepseek-coder",
"deepseek-chat", "deepseek-reasoner",
}
var ChannelName = "deepseek"
+6
View File
@@ -48,6 +48,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
+104 -3
View File
@@ -1,15 +1,21 @@
package gemini
import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
@@ -21,8 +27,36 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
if !strings.HasPrefix(info.UpstreamModelName, "imagen") {
return nil, errors.New("not supported model for image generation")
}
// convert size to aspect ratio
aspectRatio := "1:1" // default aspect ratio
switch request.Size {
case "1024x1024":
aspectRatio = "1:1"
case "1024x1792":
aspectRatio = "9:16"
case "1792x1024":
aspectRatio = "16:9"
}
// build gemini imagen request
geminiRequest := GeminiImageRequest{
Instances: []GeminiImageInstance{
{
Prompt: request.Prompt,
},
},
Parameters: GeminiImageParameters{
SampleCount: request.N,
AspectRatio: aspectRatio,
PersonGeneration: "allow_adult", // default allow adult
},
}
return geminiRequest, nil
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
@@ -40,6 +74,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
}
}
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
return fmt.Sprintf("%s/%s/models/%s:predict", info.BaseUrl, version, info.UpstreamModelName), nil
}
action := "generateContent"
if info.IsStream {
action = "streamGenerateContent?alt=sse"
@@ -68,11 +106,20 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
return GeminiImageHandler(c, resp, info)
}
if info.IsStream {
err, usage = GeminiChatStreamHandler(c, resp, info)
} else {
@@ -81,6 +128,60 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
return
}
func GeminiImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
responseBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return nil, service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError)
}
_ = resp.Body.Close()
var geminiResponse GeminiImageResponse
if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
return nil, service.OpenAIErrorWrapper(jsonErr, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
if len(geminiResponse.Predictions) == 0 {
return nil, service.OpenAIErrorWrapper(errors.New("no images generated"), "no_images", http.StatusBadRequest)
}
// convert to openai format response
openAIResponse := dto.ImageResponse{
Created: common.GetTimestamp(),
Data: make([]dto.ImageData, 0, len(geminiResponse.Predictions)),
}
for _, prediction := range geminiResponse.Predictions {
if prediction.RaiFilteredReason != "" {
continue // skip filtered image
}
openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{
B64Json: prediction.BytesBase64Encoded,
})
}
jsonResponse, jsonErr := json.Marshal(openAIResponse)
if jsonErr != nil {
return nil, service.OpenAIErrorWrapper(jsonErr, "marshal_response_failed", http.StatusInternalServerError)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, _ = c.Writer.Write(jsonResponse)
// https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb
// each image has fixed 258 tokens
const imageTokens = 258
generatedImages := len(openAIResponse.Data)
usage = &dto.Usage{
PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens
CompletionTokens: 0, // image generation does not calculate completion tokens
TotalTokens: imageTokens * generatedImages,
}
return usage, nil
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
+9 -5
View File
@@ -3,17 +3,21 @@ package gemini
var ModelList = []string{
// stable version
"gemini-1.5-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b",
"gemini-2.0-flash",
// latest version
"gemini-1.5-pro-latest", "gemini-1.5-flash-latest",
// legacy version
"gemini-1.5-pro-exp-0827", "gemini-1.5-flash-exp-0827",
// exp
"gemini-exp-1114", "gemini-exp-1121", "gemini-exp-1206",
// preview version
"gemini-2.0-flash-lite-preview",
// gemini exp
"gemini-exp-1206",
// flash exp
"gemini-2.0-flash-exp",
// pro exp
"gemini-2.0-pro-exp",
// thinking exp
"gemini-2.0-flash-thinking-exp",
"gemini-2.0-flash-thinking-exp-1219",
// imagen models
"imagen-3.0-generate-002",
}
var ChannelName = "google gemini"
+28 -1
View File
@@ -71,7 +71,7 @@ type GeminiChatTool struct {
}
type GeminiChatGenerationConfig struct {
Temperature float64 `json:"temperature,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK float64 `json:"topK,omitempty"`
MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
@@ -109,3 +109,30 @@ type GeminiUsageMetadata struct {
CandidatesTokenCount int `json:"candidatesTokenCount"`
TotalTokenCount int `json:"totalTokenCount"`
}
// Imagen related structs
type GeminiImageRequest struct {
Instances []GeminiImageInstance `json:"instances"`
Parameters GeminiImageParameters `json:"parameters"`
}
type GeminiImageInstance struct {
Prompt string `json:"prompt"`
}
type GeminiImageParameters struct {
SampleCount int `json:"sampleCount,omitempty"`
AspectRatio string `json:"aspectRatio,omitempty"`
PersonGeneration string `json:"personGeneration,omitempty"`
}
type GeminiImageResponse struct {
Predictions []GeminiImagePrediction `json:"predictions"`
}
type GeminiImagePrediction struct {
MimeType string `json:"mimeType"`
BytesBase64Encoded string `json:"bytesBase64Encoded"`
RaiFilteredReason string `json:"raiFilteredReason,omitempty"`
SafetyAttributes any `json:"safetyAttributes,omitempty"`
}
+4
View File
@@ -55,6 +55,10 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return request, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return request, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.RelayMode == constant.RelayModeRerank {
err, usage = jinaRerankHandler(c, resp)
+6
View File
@@ -50,6 +50,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
+93
View File
@@ -0,0 +1,93 @@
package mokaai
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"strings"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return request, nil
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
suffix := "chat/"
if strings.HasPrefix(info.UpstreamModelName, "m3e") {
suffix = "embeddings"
}
fullRequestURL := fmt.Sprintf("%s/%s", info.BaseUrl, suffix)
return fullRequestURL, nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
switch info.RelayMode {
case constant.RelayModeEmbeddings:
baiduEmbeddingRequest := embeddingRequestOpenAI2Moka(*request)
return baiduEmbeddingRequest, nil
default:
return nil, errors.New("not implemented")
}
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
switch info.RelayMode {
case constant.RelayModeEmbeddings:
err, usage = mokaEmbeddingHandler(c, resp)
default:
// err, usage = mokaHandler(c, resp)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}
+9
View File
@@ -0,0 +1,9 @@
package mokaai
var ModelList = []string{
"m3e-large",
"m3e-base",
"m3e-small",
}
var ChannelName = "mokaai"
+83
View File
@@ -0,0 +1,83 @@
package mokaai
import (
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/service"
)
func embeddingRequestOpenAI2Moka(request dto.GeneralOpenAIRequest) *dto.EmbeddingRequest {
var input []string // Change input to []string
switch v := request.Input.(type) {
case string:
input = []string{v} // Convert string to []string
case []string:
input = v // Already a []string, no conversion needed
case []interface{}:
for _, part := range v {
if str, ok := part.(string); ok {
input = append(input, str) // Append each string to the slice
}
}
}
return &dto.EmbeddingRequest{
Input: input,
Model: request.Model,
}
}
func embeddingResponseMoka2OpenAI(response *dto.EmbeddingResponse) *dto.OpenAIEmbeddingResponse {
openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
Object: "list",
Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Data)),
Model: "baidu-embedding",
Usage: response.Usage,
}
for _, item := range response.Data {
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
Object: item.Object,
Index: item.Index,
Embedding: item.Embedding,
})
}
return &openAIEmbeddingResponse
}
func mokaEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var baiduResponse dto.EmbeddingResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
// if baiduResponse.ErrorMsg != "" {
// return &dto.OpenAIErrorWithStatusCode{
// Error: dto.OpenAIError{
// Type: "baidu_error",
// Param: "",
// },
// StatusCode: resp.StatusCode,
// }, nil
// }
fullTextResponse := embeddingResponseMoka2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
+6 -6
View File
@@ -39,6 +39,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}
@@ -46,18 +47,17 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
if request == nil {
return nil, errors.New("request is nil")
}
switch info.RelayMode {
case relayconstant.RelayModeEmbeddings:
return requestOpenAI2Embeddings(*request), nil
default:
return requestOpenAI2Ollama(*request), nil
}
return requestOpenAI2Ollama(*request)
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return requestOpenAI2Embeddings(request), nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
+25 -22
View File
@@ -3,29 +3,32 @@ package ollama
import "one-api/dto"
type OllamaRequest struct {
Model string `json:"model,omitempty"`
Messages []dto.Message `json:"messages,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
Seed float64 `json:"seed,omitempty"`
Topp float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
Tools []dto.ToolCall `json:"tools,omitempty"`
ResponseFormat any `json:"response_format,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
Model string `json:"model,omitempty"`
Messages []dto.Message `json:"messages,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
Seed float64 `json:"seed,omitempty"`
Topp float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
Tools []dto.ToolCall `json:"tools,omitempty"`
ResponseFormat any `json:"response_format,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
Suffix any `json:"suffix,omitempty"`
StreamOptions *dto.StreamOptions `json:"stream_options,omitempty"`
Prompt any `json:"prompt,omitempty"`
}
type Options struct {
Seed int `json:"seed,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopK int `json:"top_k,omitempty"`
TopP float64 `json:"top_p,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
NumPredict int `json:"num_predict,omitempty"`
NumCtx int `json:"num_ctx,omitempty"`
Seed int `json:"seed,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopK int `json:"top_k,omitempty"`
TopP float64 `json:"top_p,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
NumPredict int `json:"num_predict,omitempty"`
NumCtx int `json:"num_ctx,omitempty"`
}
type OllamaEmbeddingRequest struct {
@@ -35,7 +38,7 @@ type OllamaEmbeddingRequest struct {
}
type OllamaEmbeddingResponse struct {
Error string `json:"error,omitempty"`
Model string `json:"model"`
Error string `json:"error,omitempty"`
Model string `json:"model"`
Embedding [][]float64 `json:"embeddings,omitempty"`
}
+35 -10
View File
@@ -9,14 +9,36 @@ import (
"net/http"
"one-api/dto"
"one-api/service"
"strings"
)
func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
messages := make([]dto.Message, 0, len(request.Messages))
for _, message := range request.Messages {
if !message.IsStringContent() {
mediaMessages := message.ParseContent()
for j, mediaMessage := range mediaMessages {
if mediaMessage.Type == dto.ContentTypeImageURL {
imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
// check if not base64
if strings.HasPrefix(imageUrl.Url, "http") {
fileData, err := service.GetFileBase64FromUrl(imageUrl.Url)
if err != nil {
return nil, err
}
imageUrl.Url = fmt.Sprintf("data:%s;base64,%s", fileData.MimeType, fileData.Base64Data)
}
mediaMessage.ImageUrl = imageUrl
mediaMessages[j] = mediaMessage
}
}
message.SetMediaContent(mediaMessages)
}
messages = append(messages, dto.Message{
Role: message.Role,
Content: message.Content,
Role: message.Role,
Content: message.Content,
ToolCalls: message.ToolCalls,
ToolCallId: message.ToolCallId,
})
}
str, ok := request.Stop.(string)
@@ -39,10 +61,13 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
ResponseFormat: request.ResponseFormat,
FrequencyPenalty: request.FrequencyPenalty,
PresencePenalty: request.PresencePenalty,
}
Prompt: request.Prompt,
StreamOptions: request.StreamOptions,
Suffix: request.Suffix,
}, nil
}
func requestOpenAI2Embeddings(request dto.GeneralOpenAIRequest) *OllamaEmbeddingRequest {
func requestOpenAI2Embeddings(request dto.EmbeddingRequest) *OllamaEmbeddingRequest {
return &OllamaEmbeddingRequest{
Model: request.Model,
Input: request.ParseInput(),
@@ -123,9 +148,9 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in
}
func flattenEmbeddings(embeddings [][]float64) []float64 {
flattened := []float64{}
for _, row := range embeddings {
flattened = append(flattened, row...)
flattened := []float64{}
for _, row := range embeddings {
flattened = append(flattened, row...)
}
return flattened
}
return flattened
}
+28 -4
View File
@@ -10,6 +10,7 @@ import (
"mime/multipart"
"net/http"
"one-api/common"
constant2 "one-api/constant"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/ai360"
@@ -44,16 +45,20 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
}
switch info.ChannelType {
case common.ChannelTypeAzure:
apiVersion := info.ApiVersion
if apiVersion == "" {
apiVersion = constant2.AzureDefaultAPIVersion
}
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
requestURL := strings.Split(info.RequestURLPath, "?")[0]
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, info.ApiVersion)
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
task := strings.TrimPrefix(requestURL, "/v1/")
model_ := info.UpstreamModelName
model_ = strings.Replace(model_, ".", "", -1)
// https://github.com/songquanpeng/one-api/issues/67
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
if info.RelayMode == constant.RelayModeRealtime {
requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, info.ApiVersion)
requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, apiVersion)
}
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
case common.ChannelTypeMiniMax:
@@ -109,13 +114,28 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
if info.ChannelType != common.ChannelTypeOpenAI && info.ChannelType != common.ChannelTypeAzure {
request.StreamOptions = nil
}
if strings.HasPrefix(request.Model, "o1") {
if strings.HasPrefix(request.Model, "o1") || strings.HasPrefix(request.Model, "o3") {
if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
request.MaxCompletionTokens = request.MaxTokens
request.MaxTokens = 0
}
if strings.HasPrefix(request.Model, "o3") || strings.HasPrefix(request.Model, "o1") {
request.Temperature = nil
}
if strings.HasSuffix(request.Model, "-high") {
request.ReasoningEffort = "high"
request.Model = strings.TrimSuffix(request.Model, "-high")
} else if strings.HasSuffix(request.Model, "-low") {
request.ReasoningEffort = "low"
request.Model = strings.TrimSuffix(request.Model, "-low")
} else if strings.HasSuffix(request.Model, "-medium") {
request.ReasoningEffort = "medium"
request.Model = strings.TrimSuffix(request.Model, "-medium")
}
info.ReasoningEffort = request.ReasoningEffort
info.UpstreamModelName = request.Model
}
if request.Model == "o1" || request.Model == "o1-2024-12-17" {
if request.Model == "o1" || request.Model == "o1-2024-12-17" || strings.HasPrefix(request.Model, "o3") {
//修改第一个Message的内容,将system改为developer
if len(request.Messages) > 0 && request.Messages[0].Role == "system" {
request.Messages[0].Role = "developer"
@@ -129,6 +149,10 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return request, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
a.ResponseFormat = request.ResponseFormat
if info.RelayMode == constant.RelayModeAudioSpeech {
+6 -1
View File
@@ -13,9 +13,14 @@ var ModelList = []string{
"gpt-4o-mini", "gpt-4o-mini-2024-07-18",
"o1-preview", "o1-preview-2024-09-12",
"o1-mini", "o1-mini-2024-09-12",
"o3-mini", "o3-mini-2025-01-31",
"o3-mini-high", "o3-mini-2025-01-31-high",
"o3-mini-low", "o3-mini-2025-01-31-low",
"o3-mini-medium", "o3-mini-2025-01-31-medium",
"o1", "o1-2024-12-17",
"gpt-4o-audio-preview", "gpt-4o-audio-preview-2024-10-01",
"gpt-4o-realtime-preview", "gpt-4o-realtime-preview-2024-10-01",
"gpt-4o-realtime-preview", "gpt-4o-realtime-preview-2024-10-01", "gpt-4o-realtime-preview-2024-12-17",
"gpt-4o-mini-realtime-preview", "gpt-4o-mini-realtime-preview-2024-12-17",
"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
"text-curie-001", "text-babbage-001", "text-ada-001",
"text-moderation-latest", "text-moderation-stable",
+59 -62
View File
@@ -5,7 +5,13 @@ import (
"bytes"
"encoding/json"
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
"io"
"math"
"mime/multipart"
"net/http"
"one-api/common"
"one-api/constant"
@@ -13,13 +19,10 @@ import (
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/service"
"os"
"strings"
"sync"
"time"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
)
func sendStreamData(c *gin.Context, data string, forceFormat bool) error {
@@ -65,8 +68,12 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
scanner.Split(bufio.ScanLines)
service.SetEventStreamHeaders(c)
ticker := time.NewTicker(time.Duration(constant.StreamingTimeout) * time.Second)
streamingTimeout := time.Duration(constant.StreamingTimeout) * time.Second
if strings.HasPrefix(info.UpstreamModelName, "o1") || strings.HasPrefix(info.UpstreamModelName, "o3") {
// twice timeout for o1 model
streamingTimeout *= 2
}
ticker := time.NewTicker(streamingTimeout)
defer ticker.Stop()
stopChan := make(chan bool)
@@ -83,11 +90,12 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
if len(data) < 6 { // ignore blank line or wrong format
continue
}
if data[:6] != "data: " && data[:6] != "[DONE]" {
if data[:5] != "data:" && data[:6] != "[DONE]" {
continue
}
mu.Lock()
data = data[6:]
data = data[5:]
data = strings.TrimSpace(data)
if !strings.HasPrefix(data, "[DONE]") {
if lastStreamData != "" {
err := sendStreamData(c, lastStreamData, forceFormat)
@@ -154,6 +162,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
//}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > toolCount {
toolCount = len(choice.Delta.ToolCalls)
@@ -174,6 +183,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
//}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > toolCount {
toolCount = len(choice.Delta.ToolCalls)
@@ -265,7 +275,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
completionTokens := 0
for _, choice := range simpleResponse.Choices {
ctkm, _ := service.CountTextToken(string(choice.Message.Content), model)
ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent, model)
completionTokens += ctkm
}
simpleResponse.Usage = dto.Usage{
@@ -312,6 +322,11 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
}
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
// count tokens by audio file duration
audioTokens, err := countAudioTokens(c)
if err != nil {
return service.OpenAIErrorWrapper(err, "count_audio_tokens_failed", http.StatusInternalServerError), nil
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
@@ -336,70 +351,52 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
}
resp.Body.Close()
var text string
switch responseFormat {
case "json":
text, err = getTextFromJSON(responseBody)
case "text":
text, err = getTextFromText(responseBody)
case "srt":
text, err = getTextFromSRT(responseBody)
case "verbose_json":
text, err = getTextFromVerboseJSON(responseBody)
case "vtt":
text, err = getTextFromVTT(responseBody)
}
usage := &dto.Usage{}
usage.PromptTokens = info.PromptTokens
usage.CompletionTokens, _ = service.CountTextToken(text, info.UpstreamModelName)
usage.PromptTokens = audioTokens
usage.CompletionTokens = 0
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return nil, usage
}
func getTextFromVTT(body []byte) (string, error) {
return getTextFromSRT(body)
}
func getTextFromVerboseJSON(body []byte) (string, error) {
var whisperResponse dto.WhisperVerboseJSONResponse
if err := json.Unmarshal(body, &whisperResponse); err != nil {
return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
func countAudioTokens(c *gin.Context) (int, error) {
body, err := common.GetRequestBody(c)
if err != nil {
return 0, errors.WithStack(err)
}
return whisperResponse.Text, nil
}
func getTextFromSRT(body []byte) (string, error) {
scanner := bufio.NewScanner(strings.NewReader(string(body)))
var builder strings.Builder
var textLine bool
for scanner.Scan() {
line := scanner.Text()
if textLine {
builder.WriteString(line)
textLine = false
continue
} else if strings.Contains(line, "-->") {
textLine = true
continue
}
var reqBody struct {
File *multipart.FileHeader `form:"file" binding:"required"`
}
if err := scanner.Err(); err != nil {
return "", err
c.Request.Body = io.NopCloser(bytes.NewReader(body))
if err = c.ShouldBind(&reqBody); err != nil {
return 0, errors.WithStack(err)
}
return builder.String(), nil
}
func getTextFromText(body []byte) (string, error) {
return strings.TrimSuffix(string(body), "\n"), nil
}
func getTextFromJSON(body []byte) (string, error) {
var whisperResponse dto.AudioResponse
if err := json.Unmarshal(body, &whisperResponse); err != nil {
return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
reqFp, err := reqBody.File.Open()
if err != nil {
return 0, errors.WithStack(err)
}
return whisperResponse.Text, nil
tmpFp, err := os.CreateTemp("", "audio-*")
if err != nil {
return 0, errors.WithStack(err)
}
defer os.Remove(tmpFp.Name())
_, err = io.Copy(tmpFp, reqFp)
if err != nil {
return 0, errors.WithStack(err)
}
if err = tmpFp.Close(); err != nil {
return 0, errors.WithStack(err)
}
duration, err := common.GetAudioDuration(c.Request.Context(), tmpFp.Name())
if err != nil {
return 0, errors.WithStack(err)
}
return int(math.Round(math.Ceil(duration) / 60.0 * 1000)), nil // 1 minute 相当于 1k tokens
}
func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.RealtimeUsage) {
+6
View File
@@ -49,6 +49,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
+1 -1
View File
@@ -18,7 +18,7 @@ type PaLMPrompt struct {
type PaLMChatRequest struct {
Prompt PaLMPrompt `json:"prompt"`
Temperature float64 `json:"temperature,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK uint `json:"topK,omitempty"`
+6
View File
@@ -52,6 +52,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
+12
View File
@@ -36,6 +36,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil
} else if info.RelayMode == constant.RelayModeChatCompletions {
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
} else if info.RelayMode == constant.RelayModeCompletions {
return fmt.Sprintf("%s/v1/completions", info.BaseUrl), nil
}
return "", errors.New("invalid relay mode")
}
@@ -58,6 +60,10 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return request, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return request, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
switch info.RelayMode {
case constant.RelayModeRerank:
@@ -68,6 +74,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
case constant.RelayModeCompletions:
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
case constant.RelayModeEmbeddings:
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
+1 -1
View File
@@ -40,7 +40,7 @@ var ModelList = []string{
"Pro/meta-llama/Meta-Llama-3-8B-Instruct",
"Pro/mistralai/Mistral-7B-Instruct-v0.2",
"black-forest-labs/FLUX.1-schnell",
"iic/SenseVoiceSmall",
"FunAudioLLM/SenseVoiceSmall",
"netease-youdao/bce-embedding-base_v1",
"BAAI/bge-m3",
"internlm/internlm2_5-20b-chat",
+6
View File
@@ -73,6 +73,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
+1 -3
View File
@@ -39,9 +39,7 @@ func requestOpenAI2Tencent(a *Adaptor, request dto.GeneralOpenAIRequest) *Tencen
if request.TopP != 0 {
req.TopP = &request.TopP
}
if request.Temperature != 0 {
req.Temperature = &request.Temperature
}
req.Temperature = request.Temperature
return &req
}
+6
View File
@@ -151,6 +151,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
+1 -1
View File
@@ -9,7 +9,7 @@ type VertexAIClaudeRequest struct {
MaxTokens int `json:"max_tokens,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Tools []claude.Tool `json:"tools,omitempty"`
+92
View File
@@ -0,0 +1,92 @@
package volcengine
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"strings"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
switch info.RelayMode {
case constant.RelayModeChatCompletions:
if strings.HasPrefix(info.UpstreamModelName, "bot") {
return fmt.Sprintf("%s/api/v3/bots/chat/completions", info.BaseUrl), nil
}
return fmt.Sprintf("%s/api/v3/chat/completions", info.BaseUrl), nil
case constant.RelayModeEmbeddings:
return fmt.Sprintf("%s/api/v3/embeddings", info.BaseUrl), nil
default:
}
return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
switch info.RelayMode {
case constant.RelayModeChatCompletions:
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
case constant.RelayModeEmbeddings:
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}
+13
View File
@@ -0,0 +1,13 @@
package volcengine
var ModelList = []string{
"Doubao-pro-128k",
"Doubao-pro-32k",
"Doubao-pro-4k",
"Doubao-lite-128k",
"Doubao-lite-32k",
"Doubao-lite-4k",
"Doubao-embedding",
}
var ChannelName = "volcengine"
+6
View File
@@ -50,6 +50,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
// xunfei's request is not http request, so we don't need to do anything here
dummyResp := &http.Response{}
+5 -5
View File
@@ -13,11 +13,11 @@ type XunfeiChatRequest struct {
} `json:"header"`
Parameter struct {
Chat struct {
Domain string `json:"domain,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopK int `json:"top_k,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
Auditing bool `json:"auditing,omitempty"`
Domain string `json:"domain,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopK int `json:"top_k,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
Auditing bool `json:"auditing,omitempty"`
} `json:"chat"`
} `json:"parameter"`
Payload struct {
+6
View File
@@ -56,6 +56,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
+1 -1
View File
@@ -12,7 +12,7 @@ type ZhipuMessage struct {
type ZhipuRequest struct {
Prompt []ZhipuMessage `json:"prompt"`
Temperature float64 `json:"temperature,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
RequestId string `json:"request_id,omitempty"`
Incremental bool `json:"incremental,omitempty"`
+6
View File
@@ -53,6 +53,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
+1 -2
View File
@@ -90,8 +90,7 @@ func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReq
mediaMessages[j] = mediaMessage
}
}
messageRaw, _ := json.Marshal(mediaMessages)
message.Content = messageRaw
message.SetMediaContent(mediaMessages)
}
messages = append(messages, dto.Message{
Role: message.Role,

Some files were not shown because too many files have changed in this diff Show More