Compare commits

...

80 Commits

Author SHA1 Message Date
CaIon 2c391a0c6d Merge branch 'alpha' 2025-06-21 01:26:57 +08:00
Calcium-Ion 62bfb15278 Merge pull request #1252 from feitianbubu/pr/fix-playgroud-user-setting
fix: playground write user context to check acceptUnsetRatio
2025-06-21 01:25:22 +08:00
Calcium-Ion ccabfe56cb Merge pull request #1264 from feitianbubu/pr/uniq-channel-models
fix: unique channel models
2025-06-21 01:21:28 +08:00
Calcium-Ion 3523aafeec Merge pull request #1272 from QuantumNous/gemini-stream-completion-count-fix
fix: gemini 原生格式流模式中断请求未计费
2025-06-21 01:21:01 +08:00
CaIon 3164e86278 fix: remove unnecessary error handling in token counting functions 2025-06-21 01:16:54 +08:00
CaIon 2c0fd2915b fix: improve usage calculation in GeminiTextGenerationStreamHandler 2025-06-21 01:08:15 +08:00
CaIon 16997a695d refactor: token counter logic 2025-06-21 00:54:40 +08:00
creamlike1024 4f3024ad63 fix: gemini 原生格式流模式中断请求未计费 2025-06-20 23:01:10 +08:00
CaIon 8ceaaf7819 fix: update response handling in GeminiTextGenerationStreamHandler
- Changed response handling from ObjectData to StringData for improved data processing.
- Ensured proper error logging in case of response handling failure.
2025-06-20 21:55:28 +08:00
CaIon 11c3bff1e2 fix: update payment method handling in topup controller
- Refactored payment method validation to check against available methods.
- Changed payment method types from "zfb" to "alipay" and "wx" to "wxpay" for consistency.
- Updated the purchase request to use the validated payment method directly.
2025-06-20 17:48:55 +08:00
Calcium-Ion 0a78d10388 Merge pull request #1271 from RedwindA/feat/vertex-budgetControl
feat: vertex budget control in model name
2025-06-20 17:24:42 +08:00
RedwindA 4f84fd6a71 fix: update model name logic for vertex 2025-06-20 16:40:51 +08:00
Calcium-Ion c74a86d14d Merge pull request #1267 from t0ng7u/feature/upstream-ratio-sync
🔄 feat(ratio-sync): introduce upstream ratio synchronisation feature #1220
2025-06-20 16:22:00 +08:00
Calcium-Ion e8da9d7dd7 Merge pull request #1270 from QuantumNous/refactor_model_mapping
feat: implement new handlers for relay processing
2025-06-20 16:12:41 +08:00
Calcium-Ion 2e242159a1 Merge pull request #1244 from feitianbubu/feat/video
feat: 支持可灵视频渠道(异步任务)
2025-06-20 16:11:59 +08:00
CaIon 13277cf838 feat: implement new handlers for audio, image, embedding, and responses processing
- Added new handlers: AudioHelper, ImageHelper, EmbeddingHelper, and ResponsesHelper to manage respective requests.
- Updated ModelMappedHelper to accept request parameters for better model mapping.
- Enhanced error handling and validation across new handlers to ensure robust request processing.
- Introduced support for new relay formats in relay_info and updated relevant functions accordingly.
2025-06-20 16:02:23 +08:00
CaIon da067b7f90 refactor: update error handling in ClaudeHelper and GeminiHelper 2025-06-20 14:53:27 +08:00
Calcium-Ion 0b245ff4ee Merge pull request #1268 from QuantumNous/alpha
fix: gemini relay empty response
2025-06-20 02:31:11 +08:00
CaIon 2b2e0a4777 feat: enhance error handling in GeminiHelper and streamline response processing
- Added status code mapping handling in GeminiHelper to reset status codes based on response.
- Removed redundant candidate check in GeminiTextGenerationHandler to simplify response processing.
2025-06-20 01:42:19 +08:00
CaIon 2aae048295 Merge remote-tracking branch 'origin/alpha' into alpha 2025-06-20 01:07:52 +08:00
CaIon f9a72212e6 Merge branch 'main' into alpha 2025-06-20 01:07:44 +08:00
Apple\Apple cfeb0c2828 💄 style(TopUp): Optimize payment method buttons layout based on quantity
Enhance the UI of payment method selection area with responsive layouts:
- Use 2-column grid when exactly 2 payment methods are present
- Use 3-column grid for 3 payment methods
- Use compact card layout for more than 3 payment methods
- Full-width button for single payment method

This improves the visual balance across different device sizes and payment provider configurations, ensuring buttons fill their grid cells appropriately with the w-full class.
2025-06-20 00:52:45 +08:00
Apple\Apple cd4fc7a188 🚀 chore(controller, dto): elevate ratio-sync feature to production readiness
WHAT’S NEW
• controller/ratio_sync.go
  – Deleted unused local structs (TestResult, DifferenceItem, SyncableChannel).
  – Centralised config with constants: defaultTimeoutSeconds, defaultEndpoint, maxConcurrentFetches, ratioTypes.
  – Replaced magic numbers; added semaphore-based concurrency limit and shared http.Client (with TLS & Expect-Continue timeouts).
  – Added comprehensive error handling and context-aware logging via common.Log* helpers.
  – Checked DB errors from GetChannelsByIds; early-return on failures or empty upstream list.
  – Removed custom-channel support; logic now relies solely on ChannelIDs.
  – Minor clean-ups: import grouping, string trimming, endpoint normalisation.

• dto/ratio_sync.go
  – Simplified UpstreamRequest: dropped unused CustomChannels field.

WHY
These improvements harden the ratio-sync endpoint for production use by preventing silent failures, controlling resource usage, and making behaviour configurable and observable.

HOW
No business logic change—only structural refactor, logging, and safeguards—so existing API contracts (aside from removed custom_channels) remain intact.
2025-06-19 19:55:51 +08:00
CaIon 94a1aeb5c7 feat: add data presence check before batch update in utils 2025-06-19 19:34:57 +08:00
CaIon 2cd9fd4c45 refactor: streamline JSON response structure in channel API endpoints 2025-06-19 19:03:35 +08:00
Apple\Apple 9d7a571472 🔍 feat(ratio-sync): add fuzzy model search & enhance empty-state UX
Summary
1. Add model name search box
   • Introduce Semi UI `Input` with `IconSearch` prefix next to the “Apply Sync” button.
   • Support case-insensitive fuzzy matching of model names.
   • Real-time filtering, pagination and bulk-select logic now work on filtered data.

2. Improve empty state handling
   • Add `hasSynced` flag to distinguish “not synced yet” from “synced with no differences”.
   • Display messages:
     – “Please select sync channels” when no sync has been performed.
     – “No differences found” when a sync completed with zero discrepancies.
     – “No matching model found” when search yields no results.

3. UI tweaks
   • Replace lucide-react `Search` icon with Semi UI `IconSearch` for visual consistency.
   • Keep responsive width and clearable input for better usability.

Why
These changes allow admins to quickly locate specific models and provide accurate feedback on the sync status, greatly improving the usability of the Upstream Ratio Sync page.
2025-06-19 18:54:46 +08:00
Apple\Apple 53e8aa058b 🛠️ chore(ratio-sync): improve upstream ratio comparison & output cleanliness
Summary
1. Consider “both unset” as identical
   • When both localValue and upstreamValue are nil, mark upstreamValue as "same" to avoid showing “Not set”.

2. Exclude fully-synced upstream channels from result
   • Scan `differences` to detect channels that contain at least one divergent value.
   • Remove channels whose every ratio is either `"same"` or `nil`, so the frontend only receives actionable discrepancies.

Why
These changes reduce visual noise in the Upstream Ratio Sync table, making it easier for admins to focus on models requiring attention. No functional regressions or breaking API changes are introduced.
2025-06-19 18:38:43 +08:00
creamlike1024 4a725f8bb0 fix: 使用日志分组查询 2025-06-19 17:17:32 +08:00
CaIon 05aaf63337 Merge branch 'alpha' 2025-06-19 16:17:56 +08:00
Apple\Apple 73cfa5891d chore(ui): enhance channel selector with status avatars and UI improvements
Add visual status indicators and improve user experience for the upstream ratio sync channel selector modal.

Features:
- Add status-based avatar indicators for channels (enabled/disabled/auto-disabled)
- Implement search functionality with text highlighting
- Add endpoint configuration input for each channel
- Optimize component structure with reusable ChannelInfo component

UI Improvements:
- Custom styling for transfer component items
- Hide scrollbars for cleaner appearance in transfer lists
- Responsive layout adjustments for channel information display
- Color-coded avatars: green (enabled), red (disabled), amber (auto-disabled), grey (unknown)

Code Quality:
- Extract channel status configuration to constants
- Create reusable ChannelInfo component to reduce code duplication
- Implement proper search filtering for both channel names and URLs
- Add consistent styling classes for transfer demo components

Files modified:
- web/src/components/settings/ChannelSelectorModal.js
- web/src/pages/Setting/Ratio/UpstreamRatioSync.js
- web/src/index.css

This enhancement provides better visual feedback for channel status and improves the overall user experience when selecting channels for ratio synchronization.
2025-06-19 16:05:50 +08:00
CaIon e685279207 fix: ratio render 2025-06-19 15:36:06 +08:00
Apple\Apple 9db2886c18 🗑️ chore(custom channel): Remove custom channel support from upstream ratio sync
Remove all custom channel functionality from the upstream ratio sync feature to simplify the codebase and focus on database-stored channels only.

Changes:
- Remove custom channel UI components and related state management
- Remove custom channel testing and validation logic
- Simplify ChannelSelectorModal by removing custom channel input fields
- Update API payload to only include channel_ids, removing custom_channels
- Remove custom channel processing logic from backend controller
- Update import path for DEFAULT_ENDPOINT constant

Files modified:
- web/src/pages/Setting/Ratio/UpstreamRatioSync.js
- web/src/components/settings/ChannelSelectorModal.js
- controller/ratio_sync.go

This change streamlines the ratio synchronization workflow by focusing solely on pre-configured database channels, reducing complexity and potential maintenance overhead.
2025-06-19 15:17:05 +08:00
skynono 9034b2c9b2 fix: unique channel models 2025-06-19 15:16:26 +08:00
creamlike1024 bc1381ea5b Merge branch 'xqx121-main' 2025-06-19 14:51:15 +08:00
creamlike1024 1271f8f648 update relay-gemini-native.go 2025-06-19 14:50:50 +08:00
creamlike1024 24dda1f6fa Merge branch 'main' of github.com:xqx121/new-api into xqx121-main 2025-06-19 14:45:41 +08:00
CaIon fd943659c0 Merge remote-tracking branch 'origin/alpha' into alpha 2025-06-19 14:36:55 +08:00
CaIon bf30adcde0 Merge branch 'main' into alpha 2025-06-19 14:36:17 +08:00
Calcium-Ion b33b3c4f85 Merge pull request #1253 from TopickAI/main
Fix Vertex channel global region format for claude models
2025-06-19 14:35:18 +08:00
Calcium-Ion 0d51f07d8f Merge pull request #1260 from tbphp/fix-gemini-empty-content-error
fix: Gemini & Vertex empty content error
2025-06-19 14:34:27 +08:00
Calcium-Ion 93761f9948 Merge pull request #1261 from KamiPasi/new-api-pr
透传thinking参数, 豆包模型用来控制是否思考
2025-06-19 14:33:41 +08:00
Calcium-Ion 4ba5558094 Merge pull request #1262 from wans10/main
fix: 修复渠道界面模型选择下拉框模型重复显示
2025-06-19 14:32:49 +08:00
Calcium-Ion 4d5e79c9d3 Merge pull request #1257 from QuantumNous/pay_custom
feat: 自定义充值方式
2025-06-19 14:31:45 +08:00
wans10 e4a7f0c779 修复渠道界面模型选择下拉框模型重复显示 2025-06-19 13:34:11 +08:00
KamiPasi 3247d7a341 透传thinking参数, 豆包模型用来控制是否思考 2025-06-19 12:06:42 +08:00
skynono 5d9cf11466 feat: add unsupported test case for kling channel 2025-06-19 11:53:47 +08:00
skynono aaf7c88e0b feat: add video channel kling fix 2025-06-19 11:53:47 +08:00
skynono d7ed0214ad feat: add video channel kling 2025-06-19 11:53:42 +08:00
tbphp 173594446e fix: Gemini & Vertex empty content error 2025-06-19 11:25:59 +08:00
Apple\Apple 8748425aa2 🚀 feat(ratio-sync): major refactor & UX overhaul for Upstream Ratio Sync 2025-06-19 08:57:34 +08:00
Calcium-Ion 1ea674f3ff Merge pull request #1258 from feitianbubu/pr/fix-task-cost-time
fix: task cost time
2025-06-18 23:15:28 +08:00
skynono 6b01438ba5 fix: task cost time 2025-06-18 21:57:46 +08:00
creamlike1024 ef6c390a83 feat: 充值方式设置 2025-06-18 21:23:06 +08:00
CaIon 8d05e44f61 Merge remote-tracking branch 'origin/alpha' into alpha 2025-06-18 20:51:06 +08:00
CaIon 17f3832487 feat(relay): add debug logging for Gemini request body and introduce flexible speech configuration 2025-06-18 20:50:13 +08:00
Apple\Apple 175ea56fdd 🚚 Refactor(ratio_setting): refactor ratio management into standalone ratio_setting package
Summary
• Migrated all ratio-related sources into `setting/ratio_setting/`
  – `model_ratio.go` (renamed from model-ratio.go)
  – `cache_ratio.go`
  – `group_ratio.go`
• Changed package name to `ratio_setting` and relocated initialization (`ratio_setting.InitRatioSettings()` in main).
• Updated every import & call site:
  – Model / cache / completion / image ratio helpers
  – Group ratio helpers (`GetGroupRatio*`, `ContainsGroupRatio`, `CheckGroupRatio`, etc.)
  – JSON-serialization & update helpers (`*Ratio2JSONString`, `Update*RatioByJSONString`)
• Adjusted controllers, middleware, relay helpers, services and models to reference the new package.
• Removed obsolete `setting` / `operation_setting` imports; added missing `ratio_setting` imports.
• Adopted idiomatic map iteration (`for key := range m`) where value is unused.
• Ran static checks to ensure clean build.

This commit centralises all ratio configuration (model, cache and group) in one cohesive module, simplifying future maintenance and improving code clarity.
2025-06-18 18:00:49 +08:00
xqx121 2541defbf4 Update relay-gemini-native.go 2025-06-18 14:26:23 +08:00
sgyy 302ca0b847 fix: Vertex channel global region format 2025-06-18 11:21:56 +08:00
skynono d93204e8d8 fix: playground write user context to check acceptUnsetRatio 2025-06-18 11:06:15 +08:00
Apple\Apple ee793087de 🐛 fix(detail): explicitly set preventScroll={true} on Tabs to stop page jump
Problem
Semi UI’s Tabs calls `focus()` on the active tab during mount, causing the browser to scroll the page to that element.
Using the bare `preventScroll` shorthand was not picked up reliably, so the page still jumped to the Tabs’ position on first render.

Changes
• Updated both Tabs instances in `web/src/pages/Detail/index.js` to `preventScroll={true}` instead of the shorthand prop.
• Ensures the prop is explicitly interpreted as boolean `true`, converting the internal call to `focus({ preventScroll: true })`.

Result
The `Detail` page now stays at its original scroll position after load, eliminating the unexpected auto-scroll behavior.
2025-06-18 05:10:32 +08:00
Apple\Apple 1d16d3288d 🛠️ fix(detail): disable automatic page scroll caused by Tabs focus
The initial render of the `Detail` page was jumping to the first `Tabs` component because Semi UI calls `focus()` on the active tab, which triggers the browser’s default scroll-into-view behavior.

Changes made
• Added `preventScroll` to the chart-selector `Tabs` (type="button").
• Added `preventScroll` to the uptime-monitor `Tabs` (type="card").

These flags convert the internal `focus()` call to `focus({ preventScroll: true })`, allowing the page to stay at its current position after load.

No functional logic is changed other than disabling the unwanted scroll; UI and user interactions remain the same.
2025-06-18 04:36:12 +08:00
Apple\Apple cd94cc200b 🏷️ chore(ui): Hide Type Tabs in Tag-Aggregation Mode & Refine Query Logic
frontend(ChannelsTable):
• Do not render type-filter Tabs when `enableTagMode` is true, preventing UI/logic conflicts in tag aggregation view.
• Adjust API query construction:
  – Append `type=` param only when NOT in tag mode and selected tab ≠ 'all'.
  – Applies to both `loadChannels` and `searchChannels`.
• Result: UI stays clean in tag view, and backend receives correct parameters across modes.

No other functionality affected.
2025-06-18 02:59:34 +08:00
Apple\Apple 1800cfd1f6 🎨 style(EditChannel): replace fixed-height TextArea with autosize to eliminate unwanted scrollbar
The “Batch Create” secret-key input in Channel Edit previously used a TextArea
with a hard-coded `minHeight`, which caused an extra scrollbar and blank space
on the right side of the field.
This change:

• Removes the fixed `minHeight` in favour of `autosize={{ minRows: 6, maxRows: 6 }}`
• Keeps the field’s rounded appearance while letting it grow/shrink with
  content, improving usability on both desktop and mobile

No other components or global styles are affected.
2025-06-18 02:41:06 +08:00
Apple\Apple 44688fe6cc 🚀 feat(Channels): Enhance Channel Filtering & Performance
feat(api):
• Add optional `type` query param to `/api/channel` endpoint for type-specific pagination
• Return `type_counts` map with counts for each channel type
• Implement `GetChannelsByType`, `CountChannelsByType`, `CountChannelsGroupByType` in `model/channel.go`

feat(frontend):
• Introduce type Tabs in `ChannelsTable` to switch between channel types
• Tabs show dynamic counts using backend `type_counts`; “All” is computed from sum
• Persist active type, reload data on tab change (with proper query params)

perf(frontend):
• Use a request counter (`useRef`) to discard stale responses when tabs switch quickly
• Move all `useMemo` hooks to top level to satisfy React Hook rules
• Remove redundant local type counting fallback when backend data present

ui:
• Remove icons from response-time tags for cleaner look
• Use Semi-UI native arrow controls for Tabs; custom arrow code deleted

chore:
• Minor refactor & comments for clarity
• Ensure ESLint Hook rules pass

Result: Channel list now supports fast, accurate type filtering with correct counts, improved concurrency safety, and cleaner UI.
2025-06-18 02:33:18 +08:00
Apple\Apple 997d9901aa Merge remote-tracking branch 'origin/main' into alpha 2025-06-18 01:30:12 +08:00
Apple\Apple 19c3cb1248 🚀 feat(ui): isolate ratio configurations into dedicated “Ratio” tab and refactor settings components
Summary
• Added new Ratio tab in Settings for managing all ratio-related configurations (group & model multipliers).
• Created `RatioSetting` component to host GroupRatio, ModelRatio, Visual Editor and Unset-Models panels.
• Moved ratio components to `web/src/pages/Setting/Ratio/` directory:
  – `GroupRatioSettings.js`
  – `ModelRatioSettings.js`
  – `ModelSettingsVisualEditor.js`
  – `ModelRationNotSetEditor.js`
• Updated imports in `RatioSetting.js` to use the new path.
• Updated main Settings router (`web/src/pages/Setting/index.js`) to include the new “Ratio Settings” tab.
• Pruned `OperationSetting.js`:
  – Removed ratio-specific cards, tabs and unused imports.
  – Reduced state to only the keys required by its child components.
  – Deleted obsolete fields (`StreamCacheQueueLength`, `CheckSensitiveOnCompletionEnabled`, `StopOnSensitiveEnabled`).
• Added boolean handling simplification in `OperationSetting.js`.
• Adjusted helper import list and removed unused translation hook.

Why
Separating ratio-related settings improves UX clarity, reduces cognitive load in the Operation Settings panel and keeps the codebase modular and easier to maintain.

BREAKING CHANGE
The file paths for ratio components have changed. Any external imports referencing the old `Operation` directory must update to the new `Ratio` path.
2025-06-18 01:29:35 +08:00
Calcium-Ion 423796e790 Merge pull request #1247 from RedwindA/feat/25lite-thinking
feat: improve gemini thinking budget adaption
2025-06-18 01:00:08 +08:00
RedwindA a004db93c9 feat(Gemini): enhance budget clamping logic for Gemini models 2025-06-18 00:49:35 +08:00
CaIon 36a0d4d7ae fix(relay): ensure consistent setting of web search context size in TextHelper function 2025-06-18 00:37:22 +08:00
CaIon d4e20df7a6 fix(relay): refine error message for unsupported MIME types and enhance error handling in OpenAI wrapper 2025-06-17 22:44:57 +08:00
CaIon 8d0b54bf6e fix(relay): improve error handling for unsupported MIME types by sanitizing URLs 2025-06-17 22:40:41 +08:00
CaIon 5506e4feed feat(file_decoder): expand MIME type detection to include additional file extensions 2025-06-17 22:20:19 +08:00
CaIon 16bffe05d4 feat(file_decoder): add debug logging for MIME type detection when handling application/octet-stream 2025-06-17 22:18:51 +08:00
CaIon 680ff8e8eb feat(file_decoder): enhance MIME type detection based on URL and Content-Disposition header 2025-06-17 21:49:13 +08:00
Calcium-Ion 129d1b081f Merge pull request #1239 from QuantumNous/auto_group
feat: auto分组
2025-06-17 21:14:09 +08:00
CaIon 4ec7012974 feat: enhance group ratio handling in pricing calculations 2025-06-17 21:05:35 +08:00
CaIon 26c6087d80 feat(GroupRatioSettings): enhance JSON validation for group ratios 2025-06-17 21:05:24 +08:00
CaIon 4238068c65 feat(channel): enhance Claude response handling with new Done flag and improved usage tracking 2025-06-17 20:08:25 +08:00
CaIon 822ed681de feat(channel): add handling for pre_consume_token_quota_failed error type 2025-06-17 16:46:52 +08:00
creamlike1024 c6a9df67b1 feat: auto分组 2025-06-16 22:15:12 +08:00
101 changed files with 3868 additions and 925 deletions
+2
View File
@@ -241,6 +241,7 @@ const (
ChannelTypeXinference = 47
ChannelTypeXai = 48
ChannelTypeCoze = 49
ChannelTypeKling = 50
ChannelTypeDummy // this one is only for count, do not add any channel after this
)
@@ -296,4 +297,5 @@ var ChannelBaseURLs = []string{
"", //47
"https://api.x.ai", //48
"https://api.coze.cn", //49
"https://api.klingai.com", //50
}
+18
View File
@@ -13,6 +13,7 @@ import (
"math/big"
"math/rand"
"net"
"net/url"
"os"
"os/exec"
"runtime"
@@ -284,3 +285,20 @@ func GetAudioDuration(ctx context.Context, filename string, ext string) (float64
}
return strconv.ParseFloat(durationStr, 64)
}
// BuildURL concatenates base and endpoint, returns the complete url string
func BuildURL(base string, endpoint string) string {
u, err := url.Parse(base)
if err != nil {
return base + endpoint
}
end := endpoint
if end == "" {
end = "/"
}
ref, err := url.Parse(end)
if err != nil {
return base + endpoint
}
return u.ResolveReference(ref).String()
}
+1
View File
@@ -5,6 +5,7 @@ type TaskPlatform string
const (
TaskPlatformSuno TaskPlatform = "suno"
TaskPlatformMidjourney = "mj"
TaskPlatformKling TaskPlatform = "kling"
)
const (
+7 -4
View File
@@ -40,6 +40,9 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
if channel.Type == common.ChannelTypeSunoAPI {
return errors.New("suno channel test is not supported"), nil
}
if channel.Type == common.ChannelTypeKling {
return errors.New("kling channel test is not supported"), nil
}
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
@@ -90,7 +93,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
info := relaycommon.GenRelayInfo(c)
err = helper.ModelMappedHelper(c, info)
err = helper.ModelMappedHelper(c, info, nil)
if err != nil {
return err, nil
}
@@ -165,8 +168,8 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
consumedTime := float64(milliseconds) / 1000.0
other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatio, priceData.CompletionRatio,
usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.UserGroupRatio)
other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio,
usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, info.OriginModelName, "模型测试",
quota, "模型测试", 0, quota, int(consumedTime), false, info.Group, other)
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
@@ -312,7 +315,7 @@ func testAllChannels(notify bool) error {
channel.UpdateResponseTime(milliseconds)
time.Sleep(common.RequestInterval)
}
if notify {
service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成")
}
+35 -5
View File
@@ -52,6 +52,14 @@ func GetAllChannels(c *gin.Context) {
channelData := make([]*model.Channel, 0)
idSort, _ := strconv.ParseBool(c.Query("id_sort"))
enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
// type filter
typeStr := c.Query("type")
typeFilter := -1
if typeStr != "" {
if t, err := strconv.Atoi(typeStr); err == nil {
typeFilter = t
}
}
var total int64
@@ -72,6 +80,14 @@ func GetAllChannels(c *gin.Context) {
}
// 计算 tag 总数用于分页
total, _ = model.CountAllTags()
} else if typeFilter >= 0 {
channels, err := model.GetChannelsByType((p-1)*pageSize, pageSize, idSort, typeFilter)
if err != nil {
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
return
}
channelData = channels
total, _ = model.CountChannelsByType(typeFilter)
} else {
channels, err := model.GetAllChannels((p-1)*pageSize, pageSize, false, idSort)
if err != nil {
@@ -82,14 +98,18 @@ func GetAllChannels(c *gin.Context) {
total, _ = model.CountAllChannels()
}
// calculate type counts
typeCounts, _ := model.CountChannelsGroupByType()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": gin.H{
"items": channelData,
"total": total,
"page": p,
"page_size": pageSize,
"items": channelData,
"total": total,
"page": p,
"page_size": pageSize,
"type_counts": typeCounts,
},
})
return
@@ -217,10 +237,20 @@ func SearchChannels(c *gin.Context) {
}
channelData = channels
}
// calculate type counts for search results
typeCounts := make(map[int64]int64)
for _, channel := range channelData {
typeCounts[int64(channel.Type)]++
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": channelData,
"data": gin.H{
"items": channelData,
"type_counts": typeCounts,
},
})
return
}
+11 -3
View File
@@ -1,15 +1,17 @@
package controller
import (
"github.com/gin-gonic/gin"
"net/http"
"one-api/model"
"one-api/setting"
"one-api/setting/ratio_setting"
"github.com/gin-gonic/gin"
)
func GetGroups(c *gin.Context) {
groupNames := make([]string, 0)
for groupName, _ := range setting.GetGroupRatioCopy() {
for groupName := range ratio_setting.GetGroupRatioCopy() {
groupNames = append(groupNames, groupName)
}
c.JSON(http.StatusOK, gin.H{
@@ -24,7 +26,7 @@ func GetUserGroups(c *gin.Context) {
userGroup := ""
userId := c.GetInt("id")
userGroup, _ = model.GetUserGroup(userId, false)
for groupName, ratio := range setting.GetGroupRatioCopy() {
for groupName, ratio := range ratio_setting.GetGroupRatioCopy() {
// UserUsableGroups contains the groups that the user can use
userUsableGroups := setting.GetUserUsableGroups(userGroup)
if desc, ok := userUsableGroups[groupName]; ok {
@@ -34,6 +36,12 @@ func GetUserGroups(c *gin.Context) {
}
}
}
if setting.GroupInUserUsableGroups("auto") {
usableGroups["auto"] = map[string]interface{}{
"ratio": "自动",
"desc": setting.GetUsableGroupDescription("auto"),
}
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
+41 -39
View File
@@ -9,9 +9,9 @@ import (
"one-api/middleware"
"one-api/model"
"one-api/setting"
"one-api/setting/console_setting"
"one-api/setting/operation_setting"
"one-api/setting/system_setting"
"one-api/setting/console_setting"
"strings"
"github.com/gin-gonic/gin"
@@ -41,46 +41,48 @@ func GetStatus(c *gin.Context) {
cs := console_setting.GetConsoleSetting()
data := gin.H{
"version": common.Version,
"start_time": common.StartTime,
"email_verification": common.EmailVerificationEnabled,
"github_oauth": common.GitHubOAuthEnabled,
"github_client_id": common.GitHubClientId,
"linuxdo_oauth": common.LinuxDOOAuthEnabled,
"linuxdo_client_id": common.LinuxDOClientId,
"telegram_oauth": common.TelegramOAuthEnabled,
"telegram_bot_name": common.TelegramBotName,
"system_name": common.SystemName,
"logo": common.Logo,
"footer_html": common.Footer,
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
"wechat_login": common.WeChatAuthEnabled,
"server_address": setting.ServerAddress,
"price": setting.Price,
"min_topup": setting.MinTopUp,
"turnstile_check": common.TurnstileCheckEnabled,
"turnstile_site_key": common.TurnstileSiteKey,
"top_up_link": common.TopUpLink,
"docs_link": operation_setting.GetGeneralSetting().DocsLink,
"quota_per_unit": common.QuotaPerUnit,
"display_in_currency": common.DisplayInCurrencyEnabled,
"enable_batch_update": common.BatchUpdateEnabled,
"enable_drawing": common.DrawingEnabled,
"enable_task": common.TaskEnabled,
"enable_data_export": common.DataExportEnabled,
"data_export_default_time": common.DataExportDefaultTime,
"default_collapse_sidebar": common.DefaultCollapseSidebar,
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
"mj_notify_enabled": setting.MjNotifyEnabled,
"chats": setting.Chats,
"demo_site_enabled": operation_setting.DemoSiteEnabled,
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
"version": common.Version,
"start_time": common.StartTime,
"email_verification": common.EmailVerificationEnabled,
"github_oauth": common.GitHubOAuthEnabled,
"github_client_id": common.GitHubClientId,
"linuxdo_oauth": common.LinuxDOOAuthEnabled,
"linuxdo_client_id": common.LinuxDOClientId,
"telegram_oauth": common.TelegramOAuthEnabled,
"telegram_bot_name": common.TelegramBotName,
"system_name": common.SystemName,
"logo": common.Logo,
"footer_html": common.Footer,
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
"wechat_login": common.WeChatAuthEnabled,
"server_address": setting.ServerAddress,
"price": setting.Price,
"min_topup": setting.MinTopUp,
"turnstile_check": common.TurnstileCheckEnabled,
"turnstile_site_key": common.TurnstileSiteKey,
"top_up_link": common.TopUpLink,
"docs_link": operation_setting.GetGeneralSetting().DocsLink,
"quota_per_unit": common.QuotaPerUnit,
"display_in_currency": common.DisplayInCurrencyEnabled,
"enable_batch_update": common.BatchUpdateEnabled,
"enable_drawing": common.DrawingEnabled,
"enable_task": common.TaskEnabled,
"enable_data_export": common.DataExportEnabled,
"data_export_default_time": common.DataExportDefaultTime,
"default_collapse_sidebar": common.DefaultCollapseSidebar,
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
"mj_notify_enabled": setting.MjNotifyEnabled,
"chats": setting.Chats,
"demo_site_enabled": operation_setting.DemoSiteEnabled,
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
"default_use_auto_group": setting.DefaultUseAutoGroup,
"pay_methods": setting.PayMethods,
// 面板启用开关
"api_info_enabled": cs.ApiInfoEnabled,
"uptime_kuma_enabled": cs.UptimeKumaEnabled,
"announcements_enabled": cs.AnnouncementsEnabled,
"faq_enabled": cs.FAQEnabled,
"api_info_enabled": cs.ApiInfoEnabled,
"uptime_kuma_enabled": cs.UptimeKumaEnabled,
"announcements_enabled": cs.AnnouncementsEnabled,
"faq_enabled": cs.FAQEnabled,
"oidc_enabled": system_setting.GetOIDCSettings().Enabled,
"oidc_client_id": system_setting.GetOIDCSettings().ClientId,
+20 -2
View File
@@ -2,7 +2,7 @@ package controller
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/samber/lo"
"net/http"
"one-api/common"
"one-api/constant"
@@ -15,6 +15,9 @@ import (
"one-api/relay/channel/moonshot"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/setting"
"github.com/gin-gonic/gin"
)
// https://platform.openai.com/docs/api-reference/models/list
@@ -134,6 +137,9 @@ func init() {
adaptor.Init(meta)
channelId2Models[i] = adaptor.GetModelList()
}
openAIModels = lo.UniqBy(openAIModels, func(m dto.OpenAIModels) string {
return m.Id
})
}
func ListModels(c *gin.Context) {
@@ -179,7 +185,19 @@ func ListModels(c *gin.Context) {
if tokenGroup != "" {
group = tokenGroup
}
models := model.GetGroupModels(group)
var models []string
if tokenGroup == "auto" {
for _, autoGroup := range setting.AutoGroups {
groupModels := model.GetGroupModels(autoGroup)
for _, g := range groupModels {
if !common.StringsContains(models, g) {
models = append(models, g)
}
}
}
} else {
models = model.GetGroupModels(group)
}
for _, s := range models {
if _, ok := openAIModelsMap[s]; ok {
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s])
+2 -1
View File
@@ -7,6 +7,7 @@ import (
"one-api/model"
"one-api/setting"
"one-api/setting/console_setting"
"one-api/setting/ratio_setting"
"one-api/setting/system_setting"
"strings"
@@ -103,7 +104,7 @@ func UpdateOption(c *gin.Context) {
return
}
case "GroupRatio":
err = setting.CheckGroupRatio(option.Value)
err = ratio_setting.CheckGroupRatio(option.Value)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
+13 -3
View File
@@ -3,7 +3,6 @@ package controller
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/constant"
@@ -13,6 +12,8 @@ import (
"one-api/service"
"one-api/setting"
"time"
"github.com/gin-gonic/gin"
)
func Playground(c *gin.Context) {
@@ -57,13 +58,22 @@ func Playground(c *gin.Context) {
c.Set("group", group)
}
c.Set("token_name", "playground-"+group)
channel, err := model.CacheGetRandomSatisfiedChannel(group, playgroundRequest.Model, 0)
channel, finalGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, playgroundRequest.Model, 0)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, playgroundRequest.Model)
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", finalGroup, playgroundRequest.Model)
openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
return
}
middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
c.Set(constant.ContextKeyRequestStartTime, time.Now())
// Write user context to ensure acceptUnsetRatio is available
userId := c.GetInt("id")
userCache, err := model.GetUserCache(userId)
if err != nil {
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_user_cache_failed", http.StatusInternalServerError)
return
}
userCache.WriteContext(c)
Relay(c)
}
+6 -6
View File
@@ -3,7 +3,7 @@ package controller
import (
"one-api/model"
"one-api/setting"
"one-api/setting/operation_setting"
"one-api/setting/ratio_setting"
"github.com/gin-gonic/gin"
)
@@ -13,7 +13,7 @@ func GetPricing(c *gin.Context) {
userId, exists := c.Get("id")
usableGroup := map[string]string{}
groupRatio := map[string]float64{}
for s, f := range setting.GetGroupRatioCopy() {
for s, f := range ratio_setting.GetGroupRatioCopy() {
groupRatio[s] = f
}
var group string
@@ -22,7 +22,7 @@ func GetPricing(c *gin.Context) {
if err == nil {
group = user.Group
for g := range groupRatio {
ratio, ok := setting.GetGroupGroupRatio(group, g)
ratio, ok := ratio_setting.GetGroupGroupRatio(group, g)
if ok {
groupRatio[g] = ratio
}
@@ -32,7 +32,7 @@ func GetPricing(c *gin.Context) {
usableGroup = setting.GetUserUsableGroups(group)
// check groupRatio contains usableGroup
for group := range setting.GetGroupRatioCopy() {
for group := range ratio_setting.GetGroupRatioCopy() {
if _, ok := usableGroup[group]; !ok {
delete(groupRatio, group)
}
@@ -47,7 +47,7 @@ func GetPricing(c *gin.Context) {
}
func ResetModelRatio(c *gin.Context) {
defaultStr := operation_setting.DefaultModelRatio2JSONString()
defaultStr := ratio_setting.DefaultModelRatio2JSONString()
err := model.UpdateOption("ModelRatio", defaultStr)
if err != nil {
c.JSON(200, gin.H{
@@ -56,7 +56,7 @@ func ResetModelRatio(c *gin.Context) {
})
return
}
err = operation_setting.UpdateModelRatioByJSONString(defaultStr)
err = ratio_setting.UpdateModelRatioByJSONString(defaultStr)
if err != nil {
c.JSON(200, gin.H{
"success": false,
+24
View File
@@ -0,0 +1,24 @@
package controller
import (
"net/http"
"one-api/setting/ratio_setting"
"github.com/gin-gonic/gin"
)
func GetRatioConfig(c *gin.Context) {
if !ratio_setting.IsExposeRatioEnabled() {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "倍率配置接口未启用",
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": ratio_setting.GetExposedData(),
})
}
+322
View File
@@ -0,0 +1,322 @@
package controller
import (
"context"
"encoding/json"
"net/http"
"strings"
"sync"
"time"
"one-api/common"
"one-api/dto"
"one-api/model"
"one-api/setting/ratio_setting"
"github.com/gin-gonic/gin"
)
const (
defaultTimeoutSeconds = 10
defaultEndpoint = "/api/ratio_config"
maxConcurrentFetches = 8
)
var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"}
type upstreamResult struct {
Name string `json:"name"`
Data map[string]any `json:"data,omitempty"`
Err string `json:"err,omitempty"`
}
func FetchUpstreamRatios(c *gin.Context) {
var req dto.UpstreamRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
return
}
if req.Timeout <= 0 {
req.Timeout = defaultTimeoutSeconds
}
var upstreams []dto.UpstreamDTO
if len(req.ChannelIDs) > 0 {
intIds := make([]int, 0, len(req.ChannelIDs))
for _, id64 := range req.ChannelIDs {
intIds = append(intIds, int(id64))
}
dbChannels, err := model.GetChannelsByIds(intIds)
if err != nil {
common.LogError(c.Request.Context(), "failed to query channels: "+err.Error())
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"})
return
}
for _, ch := range dbChannels {
if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") {
upstreams = append(upstreams, dto.UpstreamDTO{
Name: ch.Name,
BaseURL: strings.TrimRight(base, "/"),
Endpoint: "",
})
}
}
}
if len(upstreams) == 0 {
c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"})
return
}
var wg sync.WaitGroup
ch := make(chan upstreamResult, len(upstreams))
sem := make(chan struct{}, maxConcurrentFetches)
client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}}
for _, chn := range upstreams {
wg.Add(1)
go func(chItem dto.UpstreamDTO) {
defer wg.Done()
sem <- struct{}{}
defer func() { <-sem }()
endpoint := chItem.Endpoint
if endpoint == "" {
endpoint = defaultEndpoint
} else if !strings.HasPrefix(endpoint, "/") {
endpoint = "/" + endpoint
}
fullURL := chItem.BaseURL + endpoint
ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second)
defer cancel()
httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
if err != nil {
common.LogWarn(c.Request.Context(), "build request failed: "+err.Error())
ch <- upstreamResult{Name: chItem.Name, Err: err.Error()}
return
}
resp, err := client.Do(httpReq)
if err != nil {
common.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error())
ch <- upstreamResult{Name: chItem.Name, Err: err.Error()}
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
common.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status)
ch <- upstreamResult{Name: chItem.Name, Err: resp.Status}
return
}
var body struct {
Success bool `json:"success"`
Data map[string]any `json:"data"`
Message string `json:"message"`
}
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
common.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error())
ch <- upstreamResult{Name: chItem.Name, Err: err.Error()}
return
}
if !body.Success {
ch <- upstreamResult{Name: chItem.Name, Err: body.Message}
return
}
ch <- upstreamResult{Name: chItem.Name, Data: body.Data}
}(chn)
}
wg.Wait()
close(ch)
localData := ratio_setting.GetExposedData()
var testResults []dto.TestResult
var successfulChannels []struct {
name string
data map[string]any
}
for r := range ch {
if r.Err != "" {
testResults = append(testResults, dto.TestResult{
Name: r.Name,
Status: "error",
Error: r.Err,
})
} else {
testResults = append(testResults, dto.TestResult{
Name: r.Name,
Status: "success",
})
successfulChannels = append(successfulChannels, struct {
name string
data map[string]any
}{name: r.Name, data: r.Data})
}
}
differences := buildDifferences(localData, successfulChannels)
c.JSON(http.StatusOK, gin.H{
"success": true,
"data": gin.H{
"differences": differences,
"test_results": testResults,
},
})
}
func buildDifferences(localData map[string]any, successfulChannels []struct {
name string
data map[string]any
}) map[string]map[string]dto.DifferenceItem {
differences := make(map[string]map[string]dto.DifferenceItem)
allModels := make(map[string]struct{})
for _, ratioType := range ratioTypes {
if localRatioAny, ok := localData[ratioType]; ok {
if localRatio, ok := localRatioAny.(map[string]float64); ok {
for modelName := range localRatio {
allModels[modelName] = struct{}{}
}
}
}
}
for _, channel := range successfulChannels {
for _, ratioType := range ratioTypes {
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
for modelName := range upstreamRatio {
allModels[modelName] = struct{}{}
}
}
}
}
for modelName := range allModels {
for _, ratioType := range ratioTypes {
var localValue interface{} = nil
if localRatioAny, ok := localData[ratioType]; ok {
if localRatio, ok := localRatioAny.(map[string]float64); ok {
if val, exists := localRatio[modelName]; exists {
localValue = val
}
}
}
upstreamValues := make(map[string]interface{})
hasUpstreamValue := false
hasDifference := false
for _, channel := range successfulChannels {
var upstreamValue interface{} = nil
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
if val, exists := upstreamRatio[modelName]; exists {
upstreamValue = val
hasUpstreamValue = true
if localValue != nil && localValue != val {
hasDifference = true
} else if localValue == val {
upstreamValue = "same"
}
}
}
if upstreamValue == nil && localValue == nil {
upstreamValue = "same"
}
if localValue == nil && upstreamValue != nil && upstreamValue != "same" {
hasDifference = true
}
upstreamValues[channel.name] = upstreamValue
}
shouldInclude := false
if localValue != nil {
if hasDifference {
shouldInclude = true
}
} else {
if hasUpstreamValue {
shouldInclude = true
}
}
if shouldInclude {
if differences[modelName] == nil {
differences[modelName] = make(map[string]dto.DifferenceItem)
}
differences[modelName][ratioType] = dto.DifferenceItem{
Current: localValue,
Upstreams: upstreamValues,
}
}
}
}
channelHasDiff := make(map[string]bool)
for _, ratioMap := range differences {
for _, item := range ratioMap {
for chName, val := range item.Upstreams {
if val != nil && val != "same" {
channelHasDiff[chName] = true
}
}
}
}
for modelName, ratioMap := range differences {
for ratioType, item := range ratioMap {
for chName := range item.Upstreams {
if !channelHasDiff[chName] {
delete(item.Upstreams, chName)
}
}
differences[modelName][ratioType] = item
}
}
return differences
}
func GetSyncableChannels(c *gin.Context) {
channels, err := model.GetAllChannels(0, 0, true, false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
var syncableChannels []dto.SyncableChannel
for _, channel := range channels {
if channel.GetBaseURL() != "" {
syncableChannels = append(syncableChannels, dto.SyncableChannel{
ID: channel.Id,
Name: channel.Name,
BaseURL: channel.GetBaseURL(),
Status: channel.Status,
})
}
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": syncableChannels,
})
}
+3 -3
View File
@@ -259,7 +259,7 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
AutoBan: &autoBanInt,
}, nil
}
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, retryCount)
channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
if err != nil {
return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error()))
}
@@ -388,7 +388,7 @@ func RelayTask(c *gin.Context) {
retryTimes = 0
}
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, i)
if err != nil {
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
break
@@ -420,7 +420,7 @@ func RelayTask(c *gin.Context) {
func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
var err *dto.TaskError
switch relayMode {
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID:
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeKlingFetchByID:
err = relay.RelayTaskFetch(c, relayMode)
default:
err = relay.RelayTaskSubmit(c, relayMode)
+2
View File
@@ -74,6 +74,8 @@ func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][
//_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
case constant.TaskPlatformSuno:
_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
case constant.TaskPlatformKling:
_ = UpdateVideoTaskAll(context.Background(), taskChannelM, taskM)
default:
common.SysLog("未知平台")
}
+140
View File
@@ -0,0 +1,140 @@
package controller
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/model"
"one-api/relay"
"one-api/relay/channel"
)
func UpdateVideoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
for channelId, taskIds := range taskChannelM {
if err := updateVideoTaskAll(ctx, channelId, taskIds, taskM); err != nil {
common.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
}
}
return nil
}
func updateVideoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
common.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
if len(taskIds) == 0 {
return nil
}
cacheGetChannel, err := model.CacheGetChannel(channelId)
if err != nil {
errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{
"fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
"status": "FAILURE",
"progress": "100%",
})
if errUpdate != nil {
common.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
}
return fmt.Errorf("CacheGetChannel failed: %w", err)
}
adaptor := relay.GetTaskAdaptor(constant.TaskPlatformKling)
if adaptor == nil {
return fmt.Errorf("video adaptor not found")
}
for _, taskId := range taskIds {
if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
common.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
}
}
return nil
}
func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
baseURL := common.ChannelBaseURLs[channel.Type]
if channel.GetBaseURL() != "" {
baseURL = channel.GetBaseURL()
}
resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
"task_id": taskId,
})
if err != nil {
return fmt.Errorf("FetchTask failed for task %s: %w", taskId, err)
}
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("Get Video Task status code: %d", resp.StatusCode)
}
defer resp.Body.Close()
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("ReadAll failed for task %s: %w", taskId, err)
}
var responseItem map[string]interface{}
err = json.Unmarshal(responseBody, &responseItem)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Failed to parse video task response body: %v, body: %s", err, string(responseBody)))
return fmt.Errorf("Unmarshal failed for task %s: %w", taskId, err)
}
code, _ := responseItem["code"].(float64)
if code != 0 {
return fmt.Errorf("video task fetch failed for task %s", taskId)
}
data, ok := responseItem["data"].(map[string]interface{})
if !ok {
common.LogError(ctx, fmt.Sprintf("Video task data format error: %s", string(responseBody)))
return fmt.Errorf("video task data format error for task %s", taskId)
}
task := taskM[taskId]
if task == nil {
common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
return fmt.Errorf("task %s not found", taskId)
}
if status, ok := data["task_status"].(string); ok {
switch status {
case "submitted", "queued":
task.Status = model.TaskStatusSubmitted
case "processing":
task.Status = model.TaskStatusInProgress
case "succeed":
task.Status = model.TaskStatusSuccess
task.Progress = "100%"
if url, err := adaptor.ParseResultUrl(responseItem); err == nil {
task.FailReason = url
} else {
common.LogWarn(ctx, fmt.Sprintf("Failed to get url from body for task %s: %s", task.TaskID, err.Error()))
}
case "failed":
task.Status = model.TaskStatusFailure
task.Progress = "100%"
if reason, ok := data["fail_reason"].(string); ok {
task.FailReason = reason
}
}
}
// If task failed, refund quota
if task.Status == model.TaskStatusFailure {
common.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
quota := task.Quota
if quota != 0 {
if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
common.LogError(ctx, "Failed to increase user quota: "+err.Error())
}
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, common.LogQuota(quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}
}
task.Data = responseBody
if err := task.Update(); err != nil {
common.SysError("UpdateVideoTask task error: " + err.Error())
}
return nil
}
+6 -8
View File
@@ -97,14 +97,12 @@ func RequestEpay(c *gin.Context) {
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
return
}
payType := "wxpay"
if req.PaymentMethod == "zfb" {
payType = "alipay"
}
if req.PaymentMethod == "wx" {
req.PaymentMethod = "wxpay"
payType = "wxpay"
if !setting.ContainsPayMethod(req.PaymentMethod) {
c.JSON(200, gin.H{"message": "error", "data": "支付方式不存在"})
return
}
callBackAddress := service.GetCallbackAddress()
returnUrl, _ := url.Parse(setting.ServerAddress + "/console/log")
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
@@ -116,7 +114,7 @@ func RequestEpay(c *gin.Context) {
return
}
uri, params, err := client.Purchase(&epay.PurchaseArgs{
Type: payType,
Type: req.PaymentMethod,
ServiceTradeNo: tradeNo,
Name: fmt.Sprintf("TUC%d", req.Amount),
Money: strconv.FormatFloat(payMoney, 'f', 2, 64),
+3
View File
@@ -226,6 +226,9 @@ func Register(c *gin.Context) {
UnlimitedQuota: true,
ModelLimitsEnabled: false,
}
if setting.DefaultUseAutoGroup {
token.Group = "auto"
}
if err := token.Insert(); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
+1
View File
@@ -53,6 +53,7 @@ type GeneralOpenAIRequest struct {
Modalities json.RawMessage `json:"modalities,omitempty"`
Audio json.RawMessage `json:"audio,omitempty"`
EnableThinking any `json:"enable_thinking,omitempty"` // ali
THINKING json.RawMessage `json:"thinking,omitempty"` // doubao
ExtraBody json.RawMessage `json:"extra_body,omitempty"`
WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"`
// OpenRouter Params
+49
View File
@@ -0,0 +1,49 @@
package dto
// UpstreamDTO 提交到后端同步倍率的上游渠道信息
// Endpoint 可以为空,后端会默认使用 /api/ratio_config
// BaseURL 必须以 http/https 开头,不要以 / 结尾
// 例如: https://api.example.com
// Endpoint: /api/ratio_config
// 提交示例:
// {
// "name": "openai",
// "base_url": "https://api.openai.com",
// "endpoint": "/ratio_config"
// }
type UpstreamDTO struct {
Name string `json:"name" binding:"required"`
BaseURL string `json:"base_url" binding:"required"`
Endpoint string `json:"endpoint"`
}
type UpstreamRequest struct {
ChannelIDs []int64 `json:"channel_ids"`
Timeout int `json:"timeout"`
}
// TestResult 上游测试连通性结果
type TestResult struct {
Name string `json:"name"`
Status string `json:"status"`
Error string `json:"error,omitempty"`
}
// DifferenceItem 差异项
// Current 为本地值,可能为 nil
// Upstreams 为各渠道的上游值,具体数值 / "same" / nil
type DifferenceItem struct {
Current interface{} `json:"current"`
Upstreams map[string]interface{} `json:"upstreams"`
}
// SyncableChannel 可同步的渠道信息(base_url 不为空)
type SyncableChannel struct {
ID int `json:"id"`
Name string `json:"name"`
BaseURL string `json:"base_url"`
Status int `json:"status"`
}
+47
View File
@@ -0,0 +1,47 @@
package dto
type VideoRequest struct {
Model string `json:"model,omitempty" example:"kling-v1"` // Model/style ID
Prompt string `json:"prompt,omitempty" example:"宇航员站起身走了"` // Text prompt
Image string `json:"image,omitempty" example:"https://h2.inkwai.com/bs2/upload-ylab-stunt/se/ai_portal_queue_mmu_image_upscale_aiweb/3214b798-e1b4-4b00-b7af-72b5b0417420_raw_image_0.jpg"` // Image input (URL/Base64)
Duration float64 `json:"duration" example:"5.0"` // Video duration (seconds)
Width int `json:"width" example:"512"` // Video width
Height int `json:"height" example:"512"` // Video height
Fps int `json:"fps,omitempty" example:"30"` // Video frame rate
Seed int `json:"seed,omitempty" example:"20231234"` // Random seed
N int `json:"n,omitempty" example:"1"` // Number of videos to generate
ResponseFormat string `json:"response_format,omitempty" example:"url"` // Response format
User string `json:"user,omitempty" example:"user-1234"` // User identifier
Metadata map[string]any `json:"metadata,omitempty"` // Vendor-specific/custom params (e.g. negative_prompt, style, quality_level, etc.)
}
// VideoResponse 视频生成提交任务后的响应
type VideoResponse struct {
TaskId string `json:"task_id"`
Status string `json:"status"`
}
// VideoTaskResponse 查询视频生成任务状态的响应
type VideoTaskResponse struct {
TaskId string `json:"task_id" example:"abcd1234efgh"` // 任务ID
Status string `json:"status" example:"succeeded"` // 任务状态
Url string `json:"url,omitempty"` // 视频资源URL(成功时)
Format string `json:"format,omitempty" example:"mp4"` // 视频格式
Metadata *VideoTaskMetadata `json:"metadata,omitempty"` // 结果元数据
Error *VideoTaskError `json:"error,omitempty"` // 错误信息(失败时)
}
// VideoTaskMetadata 视频任务元数据
type VideoTaskMetadata struct {
Duration float64 `json:"duration" example:"5.0"` // 实际生成的视频时长
Fps int `json:"fps" example:"30"` // 实际帧率
Width int `json:"width" example:"512"` // 实际宽度
Height int `json:"height" example:"512"` // 实际高度
Seed int `json:"seed" example:"20231234"` // 使用的随机种子
}
// VideoTaskError 视频任务错误信息
type VideoTaskError struct {
Code int `json:"code"`
Message string `json:"message"`
}
+2 -2
View File
@@ -12,7 +12,7 @@ import (
"one-api/model"
"one-api/router"
"one-api/service"
"one-api/setting/operation_setting"
"one-api/setting/ratio_setting"
"os"
"strconv"
@@ -74,7 +74,7 @@ func main() {
}
// Initialize model settings
operation_setting.InitRatioSettings()
ratio_setting.InitRatioSettings()
// Initialize constants
constant.InitEnv()
// Initialize options
+22 -5
View File
@@ -11,6 +11,7 @@ import (
relayconstant "one-api/relay/constant"
"one-api/service"
"one-api/setting"
"one-api/setting/ratio_setting"
"strconv"
"strings"
"time"
@@ -48,9 +49,11 @@ func Distribute() func(c *gin.Context) {
return
}
// check group in common.GroupRatio
if !setting.ContainsGroupRatio(tokenGroup) {
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
return
if !ratio_setting.ContainsGroupRatio(tokenGroup) {
if tokenGroup != "auto" {
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
return
}
}
userGroup = tokenGroup
}
@@ -95,9 +98,14 @@ func Distribute() func(c *gin.Context) {
}
if shouldSelectChannel {
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, 0)
var selectGroup string
channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
showGroup := userGroup
if userGroup == "auto" {
showGroup = fmt.Sprintf("auto(%s)", selectGroup)
}
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", showGroup, modelRequest.Model)
// 如果错误,但是渠道不为空,说明是数据库一致性问题
if channel != nil {
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
@@ -162,6 +170,15 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
}
c.Set("platform", string(constant.TaskPlatformSuno))
c.Set("relay_mode", relayMode)
} else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") {
relayMode := relayconstant.Path2RelayKling(c.Request.Method, c.Request.URL.Path)
if relayMode == relayconstant.RelayModeKlingFetchByID {
shouldSelectChannel = false
} else {
err = common.UnmarshalBodyReusable(c, &modelRequest)
}
c.Set("platform", string(constant.TaskPlatformKling))
c.Set("relay_mode", relayMode)
} else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") {
// Gemini API 路径处理: /v1beta/models/gemini-2.0-flash:generateContent
relayMode := relayconstant.RelayModeGemini
+40 -1
View File
@@ -5,10 +5,13 @@ import (
"fmt"
"math/rand"
"one-api/common"
"one-api/setting"
"sort"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
)
var group2model2channels map[string]map[string][]*Channel
@@ -75,7 +78,43 @@ func SyncChannelCache(frequency int) {
}
}
func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, model string, retry int) (*Channel, string, error) {
var channel *Channel
var err error
selectGroup := group
if group == "auto" {
if len(setting.AutoGroups) == 0 {
return nil, selectGroup, errors.New("auto groups is not enabled")
}
for _, autoGroup := range setting.AutoGroups {
if common.DebugEnabled {
println("autoGroup:", autoGroup)
}
channel, _ = getRandomSatisfiedChannel(autoGroup, model, retry)
if channel == nil {
continue
} else {
c.Set("auto_group", autoGroup)
selectGroup = autoGroup
if common.DebugEnabled {
println("selectGroup:", selectGroup)
}
break
}
}
} else {
channel, err = getRandomSatisfiedChannel(group, model, retry)
if err != nil {
return nil, group, err
}
}
if channel == nil {
return nil, group, errors.New("channel not found")
}
return channel, selectGroup, nil
}
func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
if strings.HasPrefix(model, "gpt-4-gizmo") {
model = "gpt-4-gizmo-*"
}
+36
View File
@@ -597,3 +597,39 @@ func CountAllTags() (int64, error) {
err := DB.Model(&Channel{}).Where("tag is not null AND tag != ''").Distinct("tag").Count(&total).Error
return total, err
}
// Get channels of specified type with pagination
func GetChannelsByType(startIdx int, num int, idSort bool, channelType int) ([]*Channel, error) {
var channels []*Channel
order := "priority desc"
if idSort {
order = "id desc"
}
err := DB.Where("type = ?", channelType).Order(order).Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
return channels, err
}
// Count channels of specific type
func CountChannelsByType(channelType int) (int64, error) {
var count int64
err := DB.Model(&Channel{}).Where("type = ?", channelType).Count(&count).Error
return count, err
}
// Return map[type]count for all channels
func CountChannelsGroupByType() (map[int64]int64, error) {
type result struct {
Type int64 `gorm:"column:type"`
Count int64 `gorm:"column:count"`
}
var results []result
err := DB.Model(&Channel{}).Select("type, count(*) as count").Group("type").Find(&results).Error
if err != nil {
return nil, err
}
counts := make(map[int64]int64)
for _, r := range results {
counts[r.Type] = r.Count
}
return counts, nil
}
+9
View File
@@ -46,6 +46,15 @@ func initCol() {
logGroupCol = commonGroupCol
logKeyCol = commonKeyCol
}
} else {
// LOG_SQL_DSN 为空时,日志数据库与主数据库相同
if common.UsingPostgreSQL {
logGroupCol = `"group"`
logKeyCol = `"key"`
} else {
logGroupCol = commonGroupCol
logKeyCol = commonKeyCol
}
}
// log sql type and database type
common.SysLog("Using Log SQL Type: " + common.LogSqlType)
+26 -13
View File
@@ -5,6 +5,7 @@ import (
"one-api/setting"
"one-api/setting/config"
"one-api/setting/operation_setting"
"one-api/setting/ratio_setting"
"strconv"
"strings"
"time"
@@ -76,6 +77,9 @@ func InitOptionMap() {
common.OptionMap["MinTopUp"] = strconv.Itoa(setting.MinTopUp)
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
common.OptionMap["Chats"] = setting.Chats2JsonString()
common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString()
common.OptionMap["DefaultUseAutoGroup"] = strconv.FormatBool(setting.DefaultUseAutoGroup)
common.OptionMap["PayMethods"] = setting.PayMethods2JsonString()
common.OptionMap["GitHubClientId"] = ""
common.OptionMap["GitHubClientSecret"] = ""
common.OptionMap["TelegramBotToken"] = ""
@@ -94,13 +98,13 @@ func InitOptionMap() {
common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
common.OptionMap["ModelRequestRateLimitGroup"] = setting.ModelRequestRateLimitGroup2JSONString()
common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString()
common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString()
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
common.OptionMap["GroupGroupRatio"] = setting.GroupGroupRatio2JSONString()
common.OptionMap["ModelRatio"] = ratio_setting.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = ratio_setting.ModelPrice2JSONString()
common.OptionMap["CacheRatio"] = ratio_setting.CacheRatio2JSONString()
common.OptionMap["GroupRatio"] = ratio_setting.GroupRatio2JSONString()
common.OptionMap["GroupGroupRatio"] = ratio_setting.GroupGroupRatio2JSONString()
common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString()
common.OptionMap["CompletionRatio"] = ratio_setting.CompletionRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink
//common.OptionMap["ChatLink"] = common.ChatLink
//common.OptionMap["ChatLink2"] = common.ChatLink2
@@ -123,6 +127,7 @@ func InitOptionMap() {
common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString()
common.OptionMap["ExposeRatioEnabled"] = strconv.FormatBool(ratio_setting.IsExposeRatioEnabled())
// 自动添加所有注册的模型配置
modelConfigs := config.GlobalConfig.ExportAllConfigs()
@@ -192,7 +197,7 @@ func updateOptionMap(key string, value string) (err error) {
common.ImageDownloadPermission = intValue
}
}
if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" {
if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" || key == "DefaultUseAutoGroup" {
boolValue := value == "true"
switch key {
case "PasswordRegisterEnabled":
@@ -261,6 +266,10 @@ func updateOptionMap(key string, value string) (err error) {
common.SMTPSSLEnabled = boolValue
case "WorkerAllowHttpImageRequestEnabled":
setting.WorkerAllowHttpImageRequestEnabled = boolValue
case "DefaultUseAutoGroup":
setting.DefaultUseAutoGroup = boolValue
case "ExposeRatioEnabled":
ratio_setting.SetExposeRatioEnabled(boolValue)
}
}
switch key {
@@ -287,6 +296,8 @@ func updateOptionMap(key string, value string) (err error) {
setting.PayAddress = value
case "Chats":
err = setting.UpdateChatsByJsonString(value)
case "AutoGroups":
err = setting.UpdateAutoGroupsByJsonString(value)
case "CustomCallbackAddress":
setting.CustomCallbackAddress = value
case "EpayId":
@@ -352,19 +363,19 @@ func updateOptionMap(key string, value string) (err error) {
case "DataExportDefaultTime":
common.DataExportDefaultTime = value
case "ModelRatio":
err = operation_setting.UpdateModelRatioByJSONString(value)
err = ratio_setting.UpdateModelRatioByJSONString(value)
case "GroupRatio":
err = setting.UpdateGroupRatioByJSONString(value)
err = ratio_setting.UpdateGroupRatioByJSONString(value)
case "GroupGroupRatio":
err = setting.UpdateGroupGroupRatioByJSONString(value)
err = ratio_setting.UpdateGroupGroupRatioByJSONString(value)
case "UserUsableGroups":
err = setting.UpdateUserUsableGroupsByJSONString(value)
case "CompletionRatio":
err = operation_setting.UpdateCompletionRatioByJSONString(value)
err = ratio_setting.UpdateCompletionRatioByJSONString(value)
case "ModelPrice":
err = operation_setting.UpdateModelPriceByJSONString(value)
err = ratio_setting.UpdateModelPriceByJSONString(value)
case "CacheRatio":
err = operation_setting.UpdateCacheRatioByJSONString(value)
err = ratio_setting.UpdateCacheRatioByJSONString(value)
case "TopUpLink":
common.TopUpLink = value
//case "ChatLink":
@@ -381,6 +392,8 @@ func updateOptionMap(key string, value string) (err error) {
operation_setting.AutomaticDisableKeywordsFromString(value)
case "StreamCacheQueueLength":
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
case "PayMethods":
err = setting.UpdatePayMethodsByJsonString(value)
}
return err
}
+4 -4
View File
@@ -2,7 +2,7 @@ package model
import (
"one-api/common"
"one-api/setting/operation_setting"
"one-api/setting/ratio_setting"
"sync"
"time"
)
@@ -65,14 +65,14 @@ func updatePricing() {
ModelName: model,
EnableGroup: groups,
}
modelPrice, findPrice := operation_setting.GetModelPrice(model, false)
modelPrice, findPrice := ratio_setting.GetModelPrice(model, false)
if findPrice {
pricing.ModelPrice = modelPrice
pricing.QuotaType = 1
} else {
modelRatio, _ := operation_setting.GetModelRatio(model)
modelRatio, _ := ratio_setting.GetModelRatio(model)
pricing.ModelRatio = modelRatio
pricing.CompletionRatio = operation_setting.GetCompletionRatio(model)
pricing.CompletionRatio = ratio_setting.GetCompletionRatio(model)
pricing.QuotaType = 0
}
pricingMap = append(pricingMap, pricing)
+19 -2
View File
@@ -2,11 +2,12 @@ package model
import (
"errors"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm"
"one-api/common"
"sync"
"time"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm"
)
const (
@@ -48,6 +49,22 @@ func addNewRecord(type_ int, id int, value int) {
}
func batchUpdate() {
// check if there's any data to update
hasData := false
for i := 0; i < BatchUpdateTypeCount; i++ {
batchUpdateLocks[i].Lock()
if len(batchUpdateStores[i]) > 0 {
hasData = true
batchUpdateLocks[i].Unlock()
break
}
batchUpdateLocks[i].Unlock()
}
if !hasData {
return
}
common.SysLog("batch update started")
for i := 0; i < BatchUpdateTypeCount; i++ {
batchUpdateLocks[i].Lock()
@@ -55,7 +55,7 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
}
func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
relayInfo := relaycommon.GenRelayInfo(c)
relayInfo := relaycommon.GenRelayInfoOpenAIAudio(c)
audioRequest, err := getAndValidAudioRequest(c, relayInfo)
if err != nil {
@@ -66,10 +66,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
promptTokens := 0
preConsumedTokens := common.PreConsumedQuota
if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
promptTokens, err = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
if err != nil {
return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
}
promptTokens = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
preConsumedTokens = promptTokens
relayInfo.PromptTokens = promptTokens
}
@@ -89,13 +86,11 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
}
}()
err = helper.ModelMappedHelper(c, relayInfo)
err = helper.ModelMappedHelper(c, relayInfo, audioRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
}
audioRequest.Model = relayInfo.UpstreamModelName
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
+2
View File
@@ -44,4 +44,6 @@ type TaskAdaptor interface {
// FetchTask
FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error)
ParseResultUrl(resp map[string]any) (string, error)
}
+36 -40
View File
@@ -454,6 +454,7 @@ type ClaudeResponseInfo struct {
Model string
ResponseText strings.Builder
Usage *dto.Usage
Done bool
}
func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool {
@@ -461,20 +462,32 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons
claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
} else {
if claudeResponse.Type == "message_start" {
// message_start, 获取usage
claudeInfo.ResponseId = claudeResponse.Message.Id
claudeInfo.Model = claudeResponse.Message.Model
// message_start, 获取usage
claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
} else if claudeResponse.Type == "content_block_delta" {
if claudeResponse.Delta.Text != nil {
claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Text)
}
if claudeResponse.Delta.Thinking != "" {
claudeInfo.ResponseText.WriteString(claudeResponse.Delta.Thinking)
}
} else if claudeResponse.Type == "message_delta" {
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
// 最终的usage获取
if claudeResponse.Usage.InputTokens > 0 {
// 不叠加,只取最新的
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
}
claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeResponse.Usage.OutputTokens
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
// 判断是否完整
claudeInfo.Done = true
} else if claudeResponse.Type == "content_block_start" {
} else {
return false
@@ -506,25 +519,15 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
}
}
if info.RelayFormat == relaycommon.RelayFormatClaude {
FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo)
if requestMode == RequestModeCompletion {
claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
} else {
if claudeResponse.Type == "message_start" {
// message_start, 获取usage
info.UpstreamModelName = claudeResponse.Message.Model
claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
} else if claudeResponse.Type == "content_block_delta" {
claudeInfo.ResponseText.WriteString(claudeResponse.Delta.GetText())
} else if claudeResponse.Type == "message_delta" {
if claudeResponse.Usage.InputTokens > 0 {
// 不叠加,只取最新的
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
}
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
}
}
helper.ClaudeChunkData(c, claudeResponse, data)
@@ -544,29 +547,25 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
}
func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
if requestMode == RequestModeCompletion {
claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
} else {
if claudeInfo.Usage.PromptTokens == 0 {
//上游出错
}
if claudeInfo.Usage.CompletionTokens == 0 || !claudeInfo.Done {
if common.DebugEnabled {
common.SysError("claude response usage is not complete, maybe upstream error")
}
claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
}
}
if info.RelayFormat == relaycommon.RelayFormatClaude {
if requestMode == RequestModeCompletion {
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
} else {
// 说明流模式建立失败,可能为官方出错
if claudeInfo.Usage.PromptTokens == 0 {
//usage.PromptTokens = info.PromptTokens
}
if claudeInfo.Usage.CompletionTokens == 0 {
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
}
}
//
} else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
if requestMode == RequestModeCompletion {
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
} else {
if claudeInfo.Usage.PromptTokens == 0 {
//上游出错
}
if claudeInfo.Usage.CompletionTokens == 0 {
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
}
}
if info.ShouldIncludeUsage {
response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
err := helper.ObjectData(c, response)
@@ -619,10 +618,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
}
}
if requestMode == RequestModeCompletion {
completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
if err != nil {
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError)
}
completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
claudeInfo.Usage.PromptTokens = info.PromptTokens
claudeInfo.Usage.CompletionTokens = completionTokens
claudeInfo.Usage.TotalTokens = info.PromptTokens + completionTokens
+3 -3
View File
@@ -71,7 +71,7 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
if err := scanner.Err(); err != nil {
common.LogError(c, "error_scanning_stream_response: "+err.Error())
}
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
if info.ShouldIncludeUsage {
response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
err := helper.ObjectData(c, response)
@@ -108,7 +108,7 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo)
for _, choice := range response.Choices {
responseText += choice.Message.StringContent()
}
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
response.Usage = *usage
response.Id = helper.GetResponseID(c)
jsonResponse, err := json.Marshal(response)
@@ -150,7 +150,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn
usage := &dto.Usage{}
usage.PromptTokens = info.PromptTokens
usage.CompletionTokens, _ = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
usage.CompletionTokens = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return nil, usage
+1 -1
View File
@@ -162,7 +162,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
}
})
if usage.PromptTokens == 0 {
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
}
return nil, usage
}
+5 -7
View File
@@ -106,7 +106,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
var currentEvent string
var currentData string
var usage dto.Usage
var usage = &dto.Usage{}
for scanner.Scan() {
line := scanner.Text()
@@ -114,7 +114,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
if line == "" {
if currentEvent != "" && currentData != "" {
// handle last event
handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info)
handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info)
currentEvent = ""
currentData = ""
}
@@ -134,7 +134,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
// Last event
if currentEvent != "" && currentData != "" {
handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info)
handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info)
}
if err := scanner.Err(); err != nil {
@@ -143,12 +143,10 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
helper.Done(c)
if usage.TotalTokens == 0 {
usage.PromptTokens = info.PromptTokens
usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, c.GetInt("coze_input_count"))
}
return nil, &usage
return nil, usage
}
func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) {
+1 -8
View File
@@ -243,15 +243,8 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
return true
})
helper.Done(c)
err := resp.Body.Close()
if err != nil {
// return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
common.SysError("close_response_body_failed: " + err.Error())
}
if usage.TotalTokens == 0 {
usage.PromptTokens = info.PromptTokens
usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
}
usage.CompletionTokens += nodeToken
return nil, usage
+3 -3
View File
@@ -73,12 +73,12 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
// 新增逻辑:处理 -thinking-<budget> 格式
if strings.Contains(info.OriginModelName, "-thinking-") {
if strings.Contains(info.UpstreamModelName, "-thinking-") {
parts := strings.Split(info.UpstreamModelName, "-thinking-")
info.UpstreamModelName = parts[0]
} else if strings.HasSuffix(info.OriginModelName, "-thinking") { // 旧的适配
} else if strings.HasSuffix(info.UpstreamModelName, "-thinking") { // 旧的适配
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
} else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
} else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
}
}
+1
View File
@@ -140,6 +140,7 @@ type GeminiChatGenerationConfig struct {
Seed int64 `json:"seed,omitempty"`
ResponseModalities []string `json:"responseModalities,omitempty"`
ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
SpeechConfig json.RawMessage `json:"speechConfig,omitempty"` // RawMessage to allow flexible speech config
}
type GeminiChatCandidate struct {
+19 -18
View File
@@ -9,6 +9,7 @@ import (
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"strings"
"github.com/gin-gonic/gin"
)
@@ -35,23 +36,10 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
return nil, service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
// 检查是否有候选响应
if len(geminiResponse.Candidates) == 0 {
return nil, &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{
Message: "No candidates returned",
Type: "server_error",
Param: "",
Code: 500,
},
StatusCode: resp.StatusCode,
}
}
// 计算使用量(基于 UsageMetadata
usage := dto.Usage{
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount,
TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
}
@@ -88,6 +76,8 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
helper.SetEventStreamHeaders(c)
responseText := strings.Builder{}
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
var geminiResponse GeminiChatResponse
err := common.DecodeJsonStr(data, &geminiResponse)
@@ -102,13 +92,16 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
if part.InlineData != nil && part.InlineData.MimeType != "" {
imageCount++
}
if part.Text != "" {
responseText.WriteString(part.Text)
}
}
}
// 更新使用量统计
if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount
usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
@@ -121,7 +114,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
}
// 直接发送 GeminiChatResponse 响应
err = helper.ObjectData(c, geminiResponse)
err = helper.StringData(c, data)
if err != nil {
common.LogError(c, err.Error())
}
@@ -135,8 +128,16 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
}
}
// 计算最终使用量
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
// 如果usage.CompletionTokens为0,则使用本地统计的completion tokens
if usage.CompletionTokens == 0 {
str := responseText.String()
if len(str) > 0 {
usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
} else {
// 空补全,不需要使用量
usage = &dto.Usage{}
}
}
// 移除流式响应结尾的[Done],因为Gemini API没有发送Done的行为
//helper.Done(c)
+58 -64
View File
@@ -39,11 +39,45 @@ var geminiSupportedMimeTypes = map[string]bool{
// Gemini 允许的思考预算范围
const (
pro25MinBudget = 128
pro25MaxBudget = 32768
flash25MaxBudget = 24576
pro25MinBudget = 128
pro25MaxBudget = 32768
flash25MaxBudget = 24576
flash25LiteMinBudget = 512
flash25LiteMaxBudget = 24576
)
// clampThinkingBudget 根据模型名称将预算限制在允许的范围内
func clampThinkingBudget(modelName string, budget int) int {
isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
is25FlashLite := strings.HasPrefix(modelName, "gemini-2.5-flash-lite")
if is25FlashLite {
if budget < flash25LiteMinBudget {
return flash25LiteMinBudget
}
if budget > flash25LiteMaxBudget {
return flash25LiteMaxBudget
}
} else if isNew25Pro {
if budget < pro25MinBudget {
return pro25MinBudget
}
if budget > pro25MaxBudget {
return pro25MaxBudget
}
} else { // 其他模型
if budget < 0 {
return 0
}
if budget > flash25MaxBudget {
return flash25MaxBudget
}
}
return budget
}
// Setting safety to the lowest possible values since Gemini is already powerless enough
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*GeminiChatRequest, error) {
@@ -65,49 +99,31 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
}
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
// 新增逻辑:处理 -thinking-<budget> 格式
if strings.Contains(info.OriginModelName, "-thinking-") {
parts := strings.SplitN(info.OriginModelName, "-thinking-", 2)
modelName := info.UpstreamModelName
isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
is25FlashLite := strings.HasPrefix(modelName, "gemini-2.5-flash-lite")
if strings.Contains(modelName, "-thinking-") {
parts := strings.SplitN(modelName, "-thinking-", 2)
if len(parts) == 2 && parts[1] != "" {
if budgetTokens, err := strconv.Atoi(parts[1]); err == nil {
// 从模型名称成功解析预算
isNew25Pro := strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro") &&
!strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro-preview-05-06") &&
!strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro-preview-03-25")
if isNew25Pro {
// 新的2.5pro模型:ThinkingBudget范围为128-32768
if budgetTokens < pro25MinBudget {
budgetTokens = pro25MinBudget
} else if budgetTokens > pro25MaxBudget {
budgetTokens = pro25MaxBudget
}
} else {
// 其他模型:ThinkingBudget范围为0-24576
if budgetTokens < 0 {
budgetTokens = 0
} else if budgetTokens > flash25MaxBudget {
budgetTokens = flash25MaxBudget
}
}
clampedBudget := clampThinkingBudget(modelName, budgetTokens)
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
ThinkingBudget: common.GetPointer(budgetTokens),
ThinkingBudget: common.GetPointer(clampedBudget),
IncludeThoughts: true,
}
}
// 如果解析失败,则不设置ThinkingConfig,静默处理
}
} else if strings.HasSuffix(info.OriginModelName, "-thinking") { // 保留旧逻辑以兼容
// 硬编码不支持 ThinkingBudget 的旧模型
} else if strings.HasSuffix(modelName, "-thinking") {
unsupportedModels := []string{
"gemini-2.5-pro-preview-05-06",
"gemini-2.5-pro-preview-03-25",
}
isUnsupported := false
for _, unsupportedModel := range unsupportedModels {
if strings.HasPrefix(info.OriginModelName, unsupportedModel) {
if strings.HasPrefix(modelName, unsupportedModel) {
isUnsupported = true
break
}
@@ -119,39 +135,14 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
}
} else {
budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens)
// 检查是否为新的2.5pro模型(支持ThinkingBudget但有特殊范围)
isNew25Pro := strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro") &&
!strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro-preview-05-06") &&
!strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro-preview-03-25")
if isNew25Pro {
// 新的2.5pro模型:ThinkingBudget范围为128-32768
if budgetTokens == 0 || budgetTokens < 128 {
budgetTokens = 128
} else if budgetTokens > 32768 {
budgetTokens = 32768
}
} else {
// 其他模型:ThinkingBudget范围为0-24576
if budgetTokens == 0 || budgetTokens > 24576 {
budgetTokens = 24576
}
}
clampedBudget := clampThinkingBudget(modelName, int(budgetTokens))
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
ThinkingBudget: common.GetPointer(int(budgetTokens)),
ThinkingBudget: common.GetPointer(clampedBudget),
IncludeThoughts: true,
}
}
} else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
// 检查是否为新的2.5pro模型(不支持-nothinking,因为最低值只能为128
isNew25Pro := strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro") &&
!strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro-preview-05-06") &&
!strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro-preview-03-25")
if !isNew25Pro {
// 只有非新2.5pro模型才支持-nothinking
} else if strings.HasSuffix(modelName, "-nothinking") {
if !isNew25Pro && !is25FlashLite {
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
ThinkingBudget: common.GetPointer(0),
}
@@ -324,7 +315,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
// 校验 MimeType 是否在 Gemini 支持的白名单中
if _, ok := geminiSupportedMimeTypes[strings.ToLower(fileData.MimeType)]; !ok {
return nil, fmt.Errorf("MIME type '%s' from URL '%s' is not supported by Gemini. Supported types are: %v", fileData.MimeType, part.GetImageMedia().Url, getSupportedMimeTypesList())
url := part.GetImageMedia().Url
return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", fileData.MimeType, url, getSupportedMimeTypesList())
}
parts = append(parts, GeminiPart{
@@ -382,7 +374,9 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
if content.Role == "assistant" {
content.Role = "model"
}
geminiRequest.Contents = append(geminiRequest.Contents, content)
if len(content.Parts) > 0 {
geminiRequest.Contents = append(geminiRequest.Contents, content)
}
}
if len(system_content) > 0 {
+9 -9
View File
@@ -8,7 +8,6 @@ import (
"math"
"mime/multipart"
"net/http"
"path/filepath"
"one-api/common"
"one-api/constant"
"one-api/dto"
@@ -16,6 +15,7 @@ import (
"one-api/relay/helper"
"one-api/service"
"os"
"path/filepath"
"strings"
"github.com/bytedance/gopkg/util/gopool"
@@ -181,7 +181,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
}
if !containStreamUsage {
usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
} else {
if info.ChannelType == common.ChannelTypeDeepSeek {
@@ -216,7 +216,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
StatusCode: resp.StatusCode,
}, nil
}
forceFormat := false
if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
forceFormat = forceFmt
@@ -225,7 +225,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
completionTokens := 0
for _, choice := range simpleResponse.Choices {
ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
completionTokens += ctkm
}
simpleResponse.Usage = dto.Usage{
@@ -276,9 +276,9 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
// the status code has been judged before, if there is a body reading failure,
// it should be regarded as a non-recoverable error, so it should not return err for external retry.
// Analogous to nginx's load balancing, it will only retry if it can't be requested or
// if the upstream returns a specific status code, once the upstream has already written the header,
// the subsequent failure of the response body should be regarded as a non-recoverable error,
// Analogous to nginx's load balancing, it will only retry if it can't be requested or
// if the upstream returns a specific status code, once the upstream has already written the header,
// the subsequent failure of the response body should be regarded as a non-recoverable error,
// and can be terminated directly.
defer resp.Body.Close()
usage := &dto.Usage{}
@@ -346,12 +346,12 @@ func countAudioTokens(c *gin.Context) (int, error) {
if err = c.ShouldBind(&reqBody); err != nil {
return 0, errors.WithStack(err)
}
ext := filepath.Ext(reqBody.File.Filename) // 获取文件扩展名
ext := filepath.Ext(reqBody.File.Filename) // 获取文件扩展名
reqFp, err := reqBody.File.Open()
if err != nil {
return 0, errors.WithStack(err)
}
defer reqFp.Close()
defer reqFp.Close()
tmpFp, err := os.CreateTemp("", "audio-*"+ext)
if err != nil {
+1 -1
View File
@@ -110,7 +110,7 @@ func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relayc
tempStr := responseTextBuilder.String()
if len(tempStr) > 0 {
// 非正常结束,使用输出文本的 token 数量
completionTokens, _ := service.CountTextToken(tempStr, info.UpstreamModelName)
completionTokens := service.CountTextToken(tempStr, info.UpstreamModelName)
usage.CompletionTokens = completionTokens
}
}
+1 -1
View File
@@ -74,7 +74,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream {
var responseText string
err, responseText = palmStreamHandler(c, resp)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage = palmHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
+1 -1
View File
@@ -155,7 +155,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
}, nil
}
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
completionTokens, _ := service.CountTextToken(palmResponse.Candidates[0].Content, model)
completionTokens := service.CountTextToken(palmResponse.Candidates[0].Content, model)
usage := dto.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
+312
View File
@@ -0,0 +1,312 @@
package kling
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt"
"github.com/pkg/errors"
"one-api/common"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
)
// ============================
// Request / Response structures
// ============================
type SubmitReq struct {
Prompt string `json:"prompt"`
Model string `json:"model,omitempty"`
Mode string `json:"mode,omitempty"`
Image string `json:"image,omitempty"`
Size string `json:"size,omitempty"`
Duration int `json:"duration,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
}
type requestPayload struct {
Prompt string `json:"prompt,omitempty"`
Image string `json:"image,omitempty"`
Mode string `json:"mode,omitempty"`
Duration string `json:"duration,omitempty"`
AspectRatio string `json:"aspect_ratio,omitempty"`
Model string `json:"model,omitempty"`
ModelName string `json:"model_name,omitempty"`
CfgScale float64 `json:"cfg_scale,omitempty"`
}
type responsePayload struct {
Code int `json:"code"`
Message string `json:"message"`
Data struct {
TaskID string `json:"task_id"`
} `json:"data"`
}
// ============================
// Adaptor implementation
// ============================
type TaskAdaptor struct {
ChannelType int
accessKey string
secretKey string
baseURL string
}
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
a.ChannelType = info.ChannelType
a.baseURL = info.BaseUrl
// apiKey format: "access_key,secret_key"
keyParts := strings.Split(info.ApiKey, ",")
if len(keyParts) == 2 {
a.accessKey = strings.TrimSpace(keyParts[0])
a.secretKey = strings.TrimSpace(keyParts[1])
}
}
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
// Accept only POST /v1/video/generations as "generate" action.
action := "generate"
info.Action = action
var req SubmitReq
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
return
}
if strings.TrimSpace(req.Prompt) == "" {
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
return
}
// Store into context for later usage
c.Set("kling_request", req)
return nil
}
// BuildRequestURL constructs the upstream URL.
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/videos/image2video", a.baseURL), nil
}
// BuildRequestHeader sets required headers.
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
token, err := a.createJWTToken()
if err != nil {
return fmt.Errorf("failed to create JWT token: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("User-Agent", "kling-sdk/1.0")
return nil
}
// BuildRequestBody converts request into Kling specific format.
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) {
v, exists := c.Get("kling_request")
if !exists {
return nil, fmt.Errorf("request not found in context")
}
req := v.(SubmitReq)
body := a.convertToRequestPayload(&req)
data, err := json.Marshal(body)
if err != nil {
return nil, err
}
return bytes.NewReader(data), nil
}
// DoRequest delegates to common helper.
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoTaskApiRequest(a, c, info, requestBody)
}
// DoResponse handles upstream response, returns taskID etc.
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
return
}
// Attempt Kling response parse first.
var kResp responsePayload
if err := json.Unmarshal(responseBody, &kResp); err == nil && kResp.Code == 0 {
c.JSON(http.StatusOK, gin.H{"task_id": kResp.Data.TaskID})
return kResp.Data.TaskID, responseBody, nil
}
// Fallback generic task response.
var generic dto.TaskResponse[string]
if err := json.Unmarshal(responseBody, &generic); err != nil {
taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
return
}
if !generic.IsSuccess() {
taskErr = service.TaskErrorWrapper(fmt.Errorf(generic.Message), generic.Code, http.StatusInternalServerError)
return
}
c.JSON(http.StatusOK, gin.H{"task_id": generic.Data})
return generic.Data, responseBody, nil
}
// FetchTask fetch task status
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
taskID, ok := body["task_id"].(string)
if !ok {
return nil, fmt.Errorf("invalid task_id")
}
url := fmt.Sprintf("%s/v1/videos/image2video/%s", baseUrl, taskID)
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, err
}
token, err := a.createJWTTokenWithKey(key)
if err != nil {
token = key
}
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
req = req.WithContext(ctx)
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("User-Agent", "kling-sdk/1.0")
return service.GetHttpClient().Do(req)
}
func (a *TaskAdaptor) GetModelList() []string {
return []string{"kling-v1", "kling-v1-6", "kling-v2-master"}
}
func (a *TaskAdaptor) GetChannelName() string {
return "kling"
}
// ============================
// helpers
// ============================
func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) *requestPayload {
r := &requestPayload{
Prompt: req.Prompt,
Image: req.Image,
Mode: defaultString(req.Mode, "std"),
Duration: fmt.Sprintf("%d", defaultInt(req.Duration, 5)),
AspectRatio: a.getAspectRatio(req.Size),
Model: req.Model,
ModelName: req.Model,
CfgScale: 0.5,
}
if r.Model == "" {
r.Model = "kling-v1"
r.ModelName = "kling-v1"
}
return r
}
func (a *TaskAdaptor) getAspectRatio(size string) string {
switch size {
case "1024x1024", "512x512":
return "1:1"
case "1280x720", "1920x1080":
return "16:9"
case "720x1280", "1080x1920":
return "9:16"
default:
return "1:1"
}
}
func defaultString(s, def string) string {
if strings.TrimSpace(s) == "" {
return def
}
return s
}
func defaultInt(v int, def int) int {
if v == 0 {
return def
}
return v
}
// ============================
// JWT helpers
// ============================
func (a *TaskAdaptor) createJWTToken() (string, error) {
return a.createJWTTokenWithKeys(a.accessKey, a.secretKey)
}
func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) {
parts := strings.Split(apiKey, ",")
if len(parts) != 2 {
return "", fmt.Errorf("invalid API key format, expected 'access_key,secret_key'")
}
return a.createJWTTokenWithKeys(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
}
func (a *TaskAdaptor) createJWTTokenWithKeys(accessKey, secretKey string) (string, error) {
if accessKey == "" || secretKey == "" {
return "", fmt.Errorf("access key and secret key are required")
}
now := time.Now().Unix()
claims := jwt.MapClaims{
"iss": accessKey,
"exp": now + 1800, // 30 minutes
"nbf": now - 5,
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
token.Header["typ"] = "JWT"
return token.SignedString([]byte(secretKey))
}
// ParseResultUrl 提取视频任务结果的 url
func (a *TaskAdaptor) ParseResultUrl(resp map[string]any) (string, error) {
data, ok := resp["data"].(map[string]any)
if !ok {
return "", fmt.Errorf("data field not found or invalid")
}
taskResult, ok := data["task_result"].(map[string]any)
if !ok {
return "", fmt.Errorf("task_result field not found or invalid")
}
videos, ok := taskResult["videos"].([]interface{})
if !ok || len(videos) == 0 {
return "", fmt.Errorf("videos field not found or empty")
}
video, ok := videos[0].(map[string]interface{})
if !ok {
return "", fmt.Errorf("video item invalid")
}
url, ok := video["url"].(string)
if !ok || url == "" {
return "", fmt.Errorf("url field not found or invalid")
}
return url, nil
}
+4
View File
@@ -22,6 +22,10 @@ type TaskAdaptor struct {
ChannelType int
}
func (a *TaskAdaptor) ParseResultUrl(resp map[string]any) (string, error) {
return "", nil // todo implement this method if needed
}
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
a.ChannelType = info.ChannelType
}
+1 -1
View File
@@ -98,7 +98,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream {
var responseText string
err, responseText = tencentStreamHandler(c, resp)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage = tencentHandler(c, resp)
}
+23 -11
View File
@@ -83,10 +83,13 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
suffix := ""
if a.RequestMode == RequestModeGemini {
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
// suffix -thinking and -nothinking
if strings.HasSuffix(info.OriginModelName, "-thinking") {
// 新增逻辑:处理 -thinking-<budget> 格式
if strings.Contains(info.UpstreamModelName, "-thinking-") {
parts := strings.Split(info.UpstreamModelName, "-thinking-")
info.UpstreamModelName = parts[0]
} else if strings.HasSuffix(info.UpstreamModelName, "-thinking") { // 旧的适配
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
} else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
} else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
}
}
@@ -123,14 +126,23 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
model = v
}
return fmt.Sprintf(
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
region,
adc.ProjectID,
region,
model,
suffix,
), nil
if region == "global" {
return fmt.Sprintf(
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:%s",
adc.ProjectID,
model,
suffix,
), nil
} else {
return fmt.Sprintf(
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
region,
adc.ProjectID,
region,
model,
suffix,
), nil
}
} else if a.RequestMode == RequestModeLlama {
return fmt.Sprintf(
"https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
+1 -1
View File
@@ -68,7 +68,7 @@ func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
})
if !containStreamUsage {
usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
}
+2 -4
View File
@@ -46,13 +46,11 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) {
relayInfo.IsStream = true
}
err = helper.ModelMappedHelper(c, relayInfo)
err = helper.ModelMappedHelper(c, relayInfo, textRequest)
if err != nil {
return service.ClaudeErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
}
textRequest.Model = relayInfo.UpstreamModelName
promptTokens, err := getClaudePromptTokens(textRequest, relayInfo)
// count messages token error 计算promptTokens错误
if err != nil {
@@ -126,7 +124,7 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) {
var httpResp *http.Response
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
if err != nil {
return service.ClaudeErrorWrapperLocal(err, "do_request_failed", http.StatusInternalServerError)
return service.ClaudeErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
if resp != nil {
+38 -7
View File
@@ -34,9 +34,14 @@ type ClaudeConvertInfo struct {
}
const (
RelayFormatOpenAI = "openai"
RelayFormatClaude = "claude"
RelayFormatGemini = "gemini"
RelayFormatOpenAI = "openai"
RelayFormatClaude = "claude"
RelayFormatGemini = "gemini"
RelayFormatOpenAIResponses = "openai_responses"
RelayFormatOpenAIAudio = "openai_audio"
RelayFormatOpenAIImage = "openai_image"
RelayFormatRerank = "rerank"
RelayFormatEmbedding = "embedding"
)
type RerankerInfo struct {
@@ -143,6 +148,7 @@ func GenRelayInfoClaude(c *gin.Context) *RelayInfo {
func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo {
info := GenRelayInfo(c)
info.RelayMode = relayconstant.RelayModeRerank
info.RelayFormat = RelayFormatRerank
info.RerankerInfo = &RerankerInfo{
Documents: req.Documents,
ReturnDocuments: req.GetReturnDocuments(),
@@ -150,9 +156,25 @@ func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo {
return info
}
func GenRelayInfoOpenAIAudio(c *gin.Context) *RelayInfo {
info := GenRelayInfo(c)
info.RelayFormat = RelayFormatOpenAIAudio
return info
}
func GenRelayInfoEmbedding(c *gin.Context) *RelayInfo {
info := GenRelayInfo(c)
info.RelayFormat = RelayFormatEmbedding
return info
}
func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *RelayInfo {
info := GenRelayInfo(c)
info.RelayMode = relayconstant.RelayModeResponses
info.RelayFormat = RelayFormatOpenAIResponses
info.SupportStreamOptions = false
info.ResponsesUsageInfo = &ResponsesUsageInfo{
BuiltInTools: make(map[string]*BuildInToolInfo),
}
@@ -175,6 +197,19 @@ func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *Rel
return info
}
func GenRelayInfoGemini(c *gin.Context) *RelayInfo {
info := GenRelayInfo(c)
info.RelayFormat = RelayFormatGemini
info.ShouldIncludeUsage = false
return info
}
func GenRelayInfoImage(c *gin.Context) *RelayInfo {
info := GenRelayInfo(c)
info.RelayFormat = RelayFormatOpenAIImage
return info
}
func GenRelayInfo(c *gin.Context) *RelayInfo {
channelType := c.GetInt("channel_type")
channelId := c.GetInt("channel_id")
@@ -243,10 +278,6 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
if streamSupportedChannels[info.ChannelType] {
info.SupportStreamOptions = true
}
// responses 模式不支持 StreamOptions
if relayconstant.RelayModeResponses == info.RelayMode {
info.SupportStreamOptions = false
}
return info
}
+13
View File
@@ -38,6 +38,9 @@ const (
RelayModeSunoFetchByID
RelayModeSunoSubmit
RelayModeKlingFetchByID
RelayModeKlingSubmit
RelayModeRerank
RelayModeResponses
@@ -133,3 +136,13 @@ func Path2RelaySuno(method, path string) int {
}
return relayMode
}
func Path2RelayKling(method, path string) int {
relayMode := RelayModeUnknown
if method == http.MethodPost && strings.HasSuffix(path, "/video/generations") {
relayMode = RelayModeKlingSubmit
} else if method == http.MethodGet && strings.Contains(path, "/video/generations/") {
relayMode = RelayModeKlingFetchByID
}
return relayMode
}
@@ -15,7 +15,7 @@ import (
)
func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
token, _ := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model)
token := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model)
return token
}
@@ -33,7 +33,7 @@ func validateEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, embed
}
func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
relayInfo := relaycommon.GenRelayInfo(c)
relayInfo := relaycommon.GenRelayInfoEmbedding(c)
var embeddingRequest *dto.EmbeddingRequest
err := common.UnmarshalBodyReusable(c, &embeddingRequest)
@@ -47,13 +47,11 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest)
}
err = helper.ModelMappedHelper(c, relayInfo)
err = helper.ModelMappedHelper(c, relayInfo, embeddingRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
}
embeddingRequest.Model = relayInfo.UpstreamModelName
promptToken := getEmbeddingPromptToken(*embeddingRequest)
relayInfo.PromptTokens = promptToken
@@ -59,7 +59,7 @@ func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string,
return sensitiveWords, err
}
func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) (int, error) {
func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) int {
// 计算输入 token 数量
var inputTexts []string
for _, content := range req.Contents {
@@ -71,9 +71,9 @@ func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.Relay
}
inputText := strings.Join(inputTexts, "\n")
inputTokens, err := service.CountTokenInput(inputText, info.UpstreamModelName)
inputTokens := service.CountTokenInput(inputText, info.UpstreamModelName)
info.PromptTokens = inputTokens
return inputTokens, err
return inputTokens
}
func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
@@ -83,7 +83,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
return service.OpenAIErrorWrapperLocal(err, "invalid_gemini_request", http.StatusBadRequest)
}
relayInfo := relaycommon.GenRelayInfo(c)
relayInfo := relaycommon.GenRelayInfoGemini(c)
// 检查 Gemini 流式模式
checkGeminiStreamMode(c, relayInfo)
@@ -97,7 +97,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
}
// model mapped 模型映射
err = helper.ModelMappedHelper(c, relayInfo)
err = helper.ModelMappedHelper(c, relayInfo, req)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest)
}
@@ -106,7 +106,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
promptTokens := value.(int)
relayInfo.SetPromptTokens(promptTokens)
} else {
promptTokens, err := getGeminiInputTokens(req, relayInfo)
promptTokens := getGeminiInputTokens(req, relayInfo)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest)
}
@@ -155,14 +155,33 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
return service.OpenAIErrorWrapperLocal(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
if common.DebugEnabled {
println("Gemini request body: %s", string(requestBody))
}
resp, err := adaptor.DoRequest(c, relayInfo, bytes.NewReader(requestBody))
if err != nil {
common.LogError(c, "Do gemini request failed: "+err.Error())
return service.OpenAIErrorWrapperLocal(err, "do_request_failed", http.StatusInternalServerError)
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
statusCodeMappingStr := c.GetString("status_code_mapping")
var httpResp *http.Response
if resp != nil {
httpResp = resp.(*http.Response)
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
if httpResp.StatusCode != http.StatusOK {
openaiErr = service.RelayErrorHandler(httpResp, false)
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
}
}
usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), relayInfo)
if openaiErr != nil {
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
}
+39 -1
View File
@@ -4,12 +4,14 @@ import (
"encoding/json"
"errors"
"fmt"
common2 "one-api/common"
"one-api/dto"
"one-api/relay/common"
"github.com/gin-gonic/gin"
)
func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error {
func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request any) error {
// map model name
modelMapping := c.GetString("model_mapping")
if modelMapping != "" && modelMapping != "{}" {
@@ -50,5 +52,41 @@ func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error {
info.UpstreamModelName = currentModel
}
}
if request != nil {
switch info.RelayFormat {
case common.RelayFormatGemini:
// Gemini 模型映射
case common.RelayFormatClaude:
if claudeRequest, ok := request.(*dto.ClaudeRequest); ok {
claudeRequest.Model = info.UpstreamModelName
}
case common.RelayFormatOpenAIResponses:
if openAIResponsesRequest, ok := request.(*dto.OpenAIResponsesRequest); ok {
openAIResponsesRequest.Model = info.UpstreamModelName
}
case common.RelayFormatOpenAIAudio:
if openAIAudioRequest, ok := request.(*dto.AudioRequest); ok {
openAIAudioRequest.Model = info.UpstreamModelName
}
case common.RelayFormatOpenAIImage:
if imageRequest, ok := request.(*dto.ImageRequest); ok {
imageRequest.Model = info.UpstreamModelName
}
case common.RelayFormatRerank:
if rerankRequest, ok := request.(*dto.RerankRequest); ok {
rerankRequest.Model = info.UpstreamModelName
}
case common.RelayFormatEmbedding:
if embeddingRequest, ok := request.(*dto.EmbeddingRequest); ok {
embeddingRequest.Model = info.UpstreamModelName
}
default:
if openAIRequest, ok := request.(*dto.GeneralOpenAIRequest); ok {
openAIRequest.Model = info.UpstreamModelName
} else {
common2.LogWarn(c, fmt.Sprintf("model mapped but request type %T not supported", request))
}
}
}
return nil
}
+52 -22
View File
@@ -5,12 +5,16 @@ import (
"one-api/common"
constant2 "one-api/constant"
relaycommon "one-api/relay/common"
"one-api/setting"
"one-api/setting/operation_setting"
"one-api/setting/ratio_setting"
"github.com/gin-gonic/gin"
)
type GroupRatioInfo struct {
GroupRatio float64
GroupSpecialRatio float64
}
type PriceData struct {
ModelPrice float64
ModelRatio float64
@@ -18,23 +22,50 @@ type PriceData struct {
CacheRatio float64
CacheCreationRatio float64
ImageRatio float64
GroupRatio float64
UserGroupRatio float64
UsePrice bool
ShouldPreConsumedQuota int
GroupRatioInfo GroupRatioInfo
}
func (p PriceData) ToSetting() string {
return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio)
return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio)
}
// HandleGroupRatio checks for "auto_group" in the context and updates the group ratio and relayInfo.Group if present
func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) GroupRatioInfo {
groupRatioInfo := GroupRatioInfo{
GroupRatio: 1.0, // default ratio
GroupSpecialRatio: -1,
}
// check auto group
autoGroup, exists := ctx.Get("auto_group")
if exists {
if common.DebugEnabled {
println(fmt.Sprintf("final group: %s", autoGroup))
}
relayInfo.Group = autoGroup.(string)
}
// check user group special ratio
userGroupRatio, ok := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group)
if ok {
// user group special ratio
groupRatioInfo.GroupSpecialRatio = userGroupRatio
groupRatioInfo.GroupRatio = userGroupRatio
} else {
// normal group ratio
groupRatioInfo.GroupRatio = ratio_setting.GetGroupRatio(relayInfo.Group)
}
return groupRatioInfo
}
func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) {
modelPrice, usePrice := operation_setting.GetModelPrice(info.OriginModelName, false)
groupRatio := setting.GetGroupRatio(info.Group)
userGroupRatio, ok := setting.GetGroupGroupRatio(info.UserGroup, info.Group)
if ok {
groupRatio = userGroupRatio
}
modelPrice, usePrice := ratio_setting.GetModelPrice(info.OriginModelName, false)
groupRatioInfo := HandleGroupRatio(c, info)
var preConsumedQuota int
var modelRatio float64
var completionRatio float64
@@ -47,7 +78,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
preConsumedTokens = promptTokens + maxTokens
}
var success bool
modelRatio, success = operation_setting.GetModelRatio(info.OriginModelName)
modelRatio, success = ratio_setting.GetModelRatio(info.OriginModelName)
if !success {
acceptUnsetRatio := false
if accept, ok := info.UserSetting[constant2.UserAcceptUnsetRatioModel]; ok {
@@ -60,22 +91,21 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置,请联系管理员设置或开始自用模式;Model %s ratio or price not set, please set or start self-use mode", info.OriginModelName, info.OriginModelName)
}
}
completionRatio = operation_setting.GetCompletionRatio(info.OriginModelName)
cacheRatio, _ = operation_setting.GetCacheRatio(info.OriginModelName)
cacheCreationRatio, _ = operation_setting.GetCreateCacheRatio(info.OriginModelName)
imageRatio, _ = operation_setting.GetImageRatio(info.OriginModelName)
ratio := modelRatio * groupRatio
completionRatio = ratio_setting.GetCompletionRatio(info.OriginModelName)
cacheRatio, _ = ratio_setting.GetCacheRatio(info.OriginModelName)
cacheCreationRatio, _ = ratio_setting.GetCreateCacheRatio(info.OriginModelName)
imageRatio, _ = ratio_setting.GetImageRatio(info.OriginModelName)
ratio := modelRatio * groupRatioInfo.GroupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
}
priceData := PriceData{
ModelPrice: modelPrice,
ModelRatio: modelRatio,
CompletionRatio: completionRatio,
GroupRatio: groupRatio,
UserGroupRatio: userGroupRatio,
GroupRatioInfo: groupRatioInfo,
UsePrice: usePrice,
CacheRatio: cacheRatio,
ImageRatio: imageRatio,
@@ -91,11 +121,11 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
}
func ContainPriceOrRatio(modelName string) bool {
_, ok := operation_setting.GetModelPrice(modelName, false)
_, ok := ratio_setting.GetModelPrice(modelName, false)
if ok {
return true
}
_, ok = operation_setting.GetModelRatio(modelName)
_, ok = ratio_setting.GetModelRatio(modelName)
if ok {
return true
}
@@ -102,7 +102,7 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
}
func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
relayInfo := relaycommon.GenRelayInfo(c)
relayInfo := relaycommon.GenRelayInfoImage(c)
imageRequest, err := getAndValidImageRequest(c, relayInfo)
if err != nil {
@@ -110,13 +110,11 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
return service.OpenAIErrorWrapper(err, "invalid_image_request", http.StatusBadRequest)
}
err = helper.ModelMappedHelper(c, relayInfo)
err = helper.ModelMappedHelper(c, relayInfo, imageRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
}
imageRequest.Model = relayInfo.UpstreamModelName
priceData, err := helper.ModelPriceHelper(c, relayInfo, len(imageRequest.Prompt), 0)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
@@ -162,7 +160,7 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
// reset model price
priceData.ModelPrice *= sizeRatio * qualityRatio * float64(imageRequest.N)
quota = int(priceData.ModelPrice * priceData.GroupRatio * common.QuotaPerUnit)
quota = int(priceData.ModelPrice * priceData.GroupRatioInfo.GroupRatio * common.QuotaPerUnit)
userQuota, err = model.GetUserQuota(relayInfo.UserId, false)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
+7 -7
View File
@@ -15,7 +15,7 @@ import (
relayconstant "one-api/relay/constant"
"one-api/service"
"one-api/setting"
"one-api/setting/operation_setting"
"one-api/setting/ratio_setting"
"strconv"
"strings"
"time"
@@ -174,17 +174,17 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required")
}
modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
modelPrice, success := operation_setting.GetModelPrice(modelName, true)
modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
// 如果没有配置价格,则使用默认价格
if !success {
defaultPrice, ok := operation_setting.GetDefaultModelRatioMap()[modelName]
defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName]
if !ok {
modelPrice = 0.1
} else {
modelPrice = defaultPrice
}
}
groupRatio := setting.GetGroupRatio(group)
groupRatio := ratio_setting.GetGroupRatio(group)
ratio := modelPrice * groupRatio
userQuota, err := model.GetUserQuota(userId, false)
if err != nil {
@@ -480,17 +480,17 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
modelName := service.CoverActionToModelName(midjRequest.Action)
modelPrice, success := operation_setting.GetModelPrice(modelName, true)
modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
// 如果没有配置价格,则使用默认价格
if !success {
defaultPrice, ok := operation_setting.GetDefaultModelRatioMap()[modelName]
defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName]
if !ok {
modelPrice = 0.1
} else {
modelPrice = defaultPrice
}
}
groupRatio := setting.GetGroupRatio(group)
groupRatio := ratio_setting.GetGroupRatio(group)
ratio := modelPrice * groupRatio
userQuota, err := model.GetUserQuota(userId, false)
if err != nil {
+10 -12
View File
@@ -90,15 +90,16 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
// get & validate textRequest 获取并验证文本请求
textRequest, err := getAndValidateTextRequest(c, relayInfo)
if textRequest.WebSearchOptions != nil {
c.Set("chat_completion_web_search_context_size", textRequest.WebSearchOptions.SearchContextSize)
}
if err != nil {
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
}
if textRequest.WebSearchOptions != nil {
c.Set("chat_completion_web_search_context_size", textRequest.WebSearchOptions.SearchContextSize)
}
if setting.ShouldCheckPromptSensitive() {
words, err := checkRequestSensitive(textRequest, relayInfo)
if err != nil {
@@ -107,13 +108,11 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
}
}
err = helper.ModelMappedHelper(c, relayInfo)
err = helper.ModelMappedHelper(c, relayInfo, textRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
}
textRequest.Model = relayInfo.UpstreamModelName
// 获取 promptTokens,如果上下文中已经存在,则直接使用
var promptTokens int
if value, exists := c.Get("prompt_tokens"); exists {
@@ -252,11 +251,11 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
case relayconstant.RelayModeChatCompletions:
promptTokens, err = service.CountTokenChatRequest(info, *textRequest)
case relayconstant.RelayModeCompletions:
promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
promptTokens = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
case relayconstant.RelayModeModerations:
promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model)
promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model)
case relayconstant.RelayModeEmbeddings:
promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model)
promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model)
default:
err = errors.New("unknown relay mode")
promptTokens = 0
@@ -361,9 +360,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
cacheRatio := priceData.CacheRatio
imageRatio := priceData.ImageRatio
modelRatio := priceData.ModelRatio
groupRatio := priceData.GroupRatio
groupRatio := priceData.GroupRatioInfo.GroupRatio
modelPrice := priceData.ModelPrice
userGroupRatio := priceData.UserGroupRatio
// Convert values to decimal for precise calculation
dPromptTokens := decimal.NewFromInt(int64(promptTokens))
@@ -511,7 +509,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
if extraContent != "" {
logContent += ", " + extraContent
}
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, userGroupRatio)
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
if imageTokens != 0 {
other["image"] = true
other["image_ratio"] = imageRatio
+3
View File
@@ -22,6 +22,7 @@ import (
"one-api/relay/channel/palm"
"one-api/relay/channel/perplexity"
"one-api/relay/channel/siliconflow"
"one-api/relay/channel/task/kling"
"one-api/relay/channel/task/suno"
"one-api/relay/channel/tencent"
"one-api/relay/channel/vertex"
@@ -101,6 +102,8 @@ func GetTaskAdaptor(platform commonconstant.TaskPlatform) channel.TaskAdaptor {
// return &aiproxy.Adaptor{}
case commonconstant.TaskPlatformSuno:
return &suno.TaskAdaptor{}
case commonconstant.TaskPlatformKling:
return &kling.TaskAdaptor{}
}
return nil
}
+33 -8
View File
@@ -15,8 +15,7 @@ import (
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/service"
"one-api/setting"
"one-api/setting/operation_setting"
"one-api/setting/ratio_setting"
)
/*
@@ -38,9 +37,12 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
}
modelName := service.CoverTaskActionToModelName(platform, relayInfo.Action)
modelPrice, success := operation_setting.GetModelPrice(modelName, true)
if platform == constant.TaskPlatformKling {
modelName = relayInfo.OriginModelName
}
modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
if !success {
defaultPrice, ok := operation_setting.GetDefaultModelRatioMap()[modelName]
defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName]
if !ok {
modelPrice = 0.1
} else {
@@ -49,7 +51,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
}
// 预扣
groupRatio := setting.GetGroupRatio(relayInfo.Group)
groupRatio := ratio_setting.GetGroupRatio(relayInfo.Group)
ratio := modelPrice * groupRatio
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
if err != nil {
@@ -137,10 +139,11 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
}
relayInfo.ConsumeQuota = true
// insert task
task := model.InitTask(constant.TaskPlatformSuno, relayInfo)
task := model.InitTask(platform, relayInfo)
task.TaskID = taskID
task.Quota = quota
task.Data = taskData
task.Action = relayInfo.Action
err = task.Insert()
if err != nil {
taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
@@ -150,8 +153,9 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
}
var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
relayconstant.RelayModeKlingFetchByID: videoFetchByIDRespBodyBuilder,
}
func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
@@ -226,6 +230,27 @@ func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dt
return
}
func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
taskId := c.Param("id")
userId := c.GetInt("id")
originTask, exist, err := model.GetByTaskId(userId, taskId)
if err != nil {
taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
return
}
if !exist {
taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
return
}
respBody, err = json.Marshal(dto.TaskResponse[any]{
Code: "success",
Data: TaskModel2Dto(originTask),
})
return
}
func TaskModel2Dto(task *model.Task) *dto.TaskDto {
return &dto.TaskDto{
TaskID: task.TaskID,
@@ -14,12 +14,10 @@ import (
)
func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
token, _ := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model)
token := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model)
for _, document := range rerankRequest.Documents {
tkm, err := service.CountTokenInput(document, rerankRequest.Model)
if err == nil {
token += tkm
}
tkm := service.CountTokenInput(document, rerankRequest.Model)
token += tkm
}
return token
}
@@ -42,13 +40,11 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest)
}
err = helper.ModelMappedHelper(c, relayInfo)
err = helper.ModelMappedHelper(c, relayInfo, rerankRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
}
rerankRequest.Model = relayInfo.UpstreamModelName
promptToken := getRerankPromptToken(*rerankRequest)
relayInfo.PromptTokens = promptToken
@@ -40,10 +40,10 @@ func checkInputSensitive(textRequest *dto.OpenAIResponsesRequest, info *relaycom
return sensitiveWords, err
}
func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) (int, error) {
inputTokens, err := service.CountTokenInput(req.Input, req.Model)
func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) int {
inputTokens := service.CountTokenInput(req.Input, req.Model)
info.PromptTokens = inputTokens
return inputTokens, err
return inputTokens
}
func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
@@ -63,19 +63,16 @@ func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
}
}
err = helper.ModelMappedHelper(c, relayInfo)
err = helper.ModelMappedHelper(c, relayInfo, req)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest)
}
req.Model = relayInfo.UpstreamModelName
if value, exists := c.Get("prompt_tokens"); exists {
promptTokens := value.(int)
relayInfo.SetPromptTokens(promptTokens)
} else {
promptTokens, err := getInputTokens(req, relayInfo)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest)
}
promptTokens := getInputTokens(req, relayInfo)
c.Set("prompt_tokens", promptTokens)
}
+6 -37
View File
@@ -6,12 +6,10 @@ import (
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/setting"
"one-api/setting/operation_setting"
)
func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWithStatusCode) {
@@ -39,43 +37,14 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
//isModelMapped = true
}
}
//relayInfo.UpstreamModelName = textRequest.Model
modelPrice, getModelPriceSuccess := operation_setting.GetModelPrice(relayInfo.UpstreamModelName, false)
groupRatio := setting.GetGroupRatio(relayInfo.Group)
var preConsumedQuota int
var ratio float64
var modelRatio float64
//err := service.SensitiveWordsCheck(textRequest)
//if constant.ShouldCheckPromptSensitive() {
// err = checkRequestSensitive(textRequest, relayInfo)
// if err != nil {
// return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
// }
//}
//promptTokens, err := getWssPromptTokens(realtimeEvent, relayInfo)
//// count messages token error 计算promptTokens错误
//if err != nil {
// return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
//}
//
if !getModelPriceSuccess {
preConsumedTokens := common.PreConsumedQuota
//if realtimeEvent.Session.MaxResponseOutputTokens != 0 {
// preConsumedTokens = promptTokens + int(realtimeEvent.Session.MaxResponseOutputTokens)
//}
modelRatio, _ = operation_setting.GetModelRatio(relayInfo.UpstreamModelName)
ratio = modelRatio * groupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
relayInfo.UsePrice = true
priceData, err := helper.ModelPriceHelper(c, relayInfo, 0, 0)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
}
// pre-consume quota 预消耗配额
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil {
return openaiErr
}
@@ -113,6 +82,6 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
return openaiErr
}
service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), preConsumedQuota,
userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
userQuota, priceData, "")
return nil
}
+7
View File
@@ -36,6 +36,7 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), controller.EmailBind)
apiRouter.GET("/oauth/telegram/login", middleware.CriticalRateLimit(), controller.TelegramLogin)
apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), controller.TelegramBind)
apiRouter.GET("/ratio_config", middleware.CriticalRateLimit(), controller.GetRatioConfig)
userRoute := apiRouter.Group("/user")
{
@@ -83,6 +84,12 @@ func SetApiRouter(router *gin.Engine) {
optionRoute.POST("/rest_model_ratio", controller.ResetModelRatio)
optionRoute.POST("/migrate_console_setting", controller.MigrateConsoleSetting) // 用于迁移检测的旧键,下个版本会删除
}
ratioSyncRoute := apiRouter.Group("/ratio_sync")
ratioSyncRoute.Use(middleware.RootAuth())
{
ratioSyncRoute.GET("/channels", controller.GetSyncableChannels)
ratioSyncRoute.POST("/fetch", controller.FetchUpstreamRatios)
}
channelRoute := apiRouter.Group("/channel")
channelRoute.Use(middleware.AdminAuth())
{
+1
View File
@@ -14,6 +14,7 @@ func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
SetApiRouter(router)
SetDashboardRouter(router)
SetRelayRouter(router)
SetVideoRouter(router)
frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL")
if common.IsMasterNode && frontendBaseUrl != "" {
frontendBaseUrl = ""
+17
View File
@@ -0,0 +1,17 @@
package router
import (
"one-api/controller"
"one-api/middleware"
"github.com/gin-gonic/gin"
)
func SetVideoRouter(router *gin.Engine) {
videoV1Router := router.Group("/v1")
videoV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
{
videoV1Router.POST("/video/generations", controller.RelayTask)
videoV1Router.GET("/video/generations/:task_id", controller.RelayTask)
}
}
+2
View File
@@ -59,6 +59,8 @@ func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) b
return true
case "billing_not_active":
return true
case "pre_consume_token_quota_failed":
return true
}
switch err.Error.Type {
case "insufficient_quota":
+10 -6
View File
@@ -29,9 +29,11 @@ func MidjourneyErrorWithStatusCodeWrapper(code int, desc string, statusCode int)
func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
text := err.Error()
lowerText := strings.ToLower(text)
if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
common.SysLog(fmt.Sprintf("error: %s", text))
text = "请求上游地址失败"
if !strings.HasPrefix(lowerText, "get file base64 from url") && !strings.HasPrefix(lowerText, "mime type is not supported") {
if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
common.SysLog(fmt.Sprintf("error: %s", text))
text = "请求上游地址失败"
}
}
openAIError := dto.OpenAIError{
Message: text,
@@ -53,9 +55,11 @@ func OpenAIErrorWrapperLocal(err error, code string, statusCode int) *dto.OpenAI
func ClaudeErrorWrapper(err error, code string, statusCode int) *dto.ClaudeErrorWithStatusCode {
text := err.Error()
lowerText := strings.ToLower(text)
if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
common.SysLog(fmt.Sprintf("error: %s", text))
text = "请求上游地址失败"
if !strings.HasPrefix(lowerText, "get file base64 from url") {
if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
common.SysLog(fmt.Sprintf("error: %s", text))
text = "请求上游地址失败"
}
}
claudeError := dto.ClaudeError{
Message: text,
+98 -1
View File
@@ -4,8 +4,10 @@ import (
"encoding/base64"
"fmt"
"io"
"one-api/common"
"one-api/constant"
"one-api/dto"
"strings"
)
func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) {
@@ -30,9 +32,104 @@ func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) {
// Convert to base64
base64Data := base64.StdEncoding.EncodeToString(fileBytes)
mimeType := resp.Header.Get("Content-Type")
if len(strings.Split(mimeType, ";")) > 1 {
// If Content-Type has parameters, take the first part
mimeType = strings.Split(mimeType, ";")[0]
}
if mimeType == "application/octet-stream" {
if common.DebugEnabled {
println("MIME type is application/octet-stream, trying to guess from URL or filename")
}
// try to guess the MIME type from the url last segment
urlParts := strings.Split(url, "/")
if len(urlParts) > 0 {
lastSegment := urlParts[len(urlParts)-1]
if strings.Contains(lastSegment, ".") {
// Extract the file extension
filename := strings.Split(lastSegment, ".")
if len(filename) > 1 {
ext := strings.ToLower(filename[len(filename)-1])
// Guess MIME type based on file extension
mimeType = GetMimeTypeByExtension(ext)
}
}
} else {
// try to guess the MIME type from the file extension
fileName := resp.Header.Get("Content-Disposition")
if fileName != "" {
// Extract the filename from the Content-Disposition header
parts := strings.Split(fileName, ";")
for _, part := range parts {
if strings.HasPrefix(strings.TrimSpace(part), "filename=") {
fileName = strings.TrimSpace(strings.TrimPrefix(part, "filename="))
// Remove quotes if present
if len(fileName) > 2 && fileName[0] == '"' && fileName[len(fileName)-1] == '"' {
fileName = fileName[1 : len(fileName)-1]
}
// Guess MIME type based on file extension
if ext := strings.ToLower(strings.TrimPrefix(fileName, ".")); ext != "" {
mimeType = GetMimeTypeByExtension(ext)
}
break
}
}
}
}
}
return &dto.LocalFileData{
Base64Data: base64Data,
MimeType: resp.Header.Get("Content-Type"),
MimeType: mimeType,
Size: int64(len(fileBytes)),
}, nil
}
func GetMimeTypeByExtension(ext string) string {
// Convert to lowercase for case-insensitive comparison
ext = strings.ToLower(ext)
switch ext {
// Text files
case "txt", "md", "markdown", "csv", "json", "xml", "html", "htm":
return "text/plain"
// Image files
case "jpg", "jpeg":
return "image/jpeg"
case "png":
return "image/png"
case "gif":
return "image/gif"
// Audio files
case "mp3":
return "audio/mp3"
case "wav":
return "audio/wav"
case "mpeg":
return "audio/mpeg"
// Video files
case "mp4":
return "video/mp4"
case "wmv":
return "video/wmv"
case "flv":
return "video/flv"
case "mov":
return "video/mov"
case "mpg":
return "video/mpg"
case "avi":
return "video/avi"
case "mpegps":
return "video/mpegps"
// Document files
case "pdf":
return "application/pdf"
default:
return "application/octet-stream" // Default for unknown types
}
}
+39 -37
View File
@@ -3,6 +3,7 @@ package service
import (
"errors"
"fmt"
"log"
"one-api/common"
constant2 "one-api/constant"
"one-api/dto"
@@ -10,7 +11,7 @@ import (
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/setting"
"one-api/setting/operation_setting"
"one-api/setting/ratio_setting"
"strings"
"time"
@@ -45,9 +46,9 @@ func calculateAudioQuota(info QuotaInfo) int {
return int(quota.IntPart())
}
completionRatio := decimal.NewFromFloat(operation_setting.GetCompletionRatio(info.ModelName))
audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(info.ModelName))
audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(info.ModelName))
completionRatio := decimal.NewFromFloat(ratio_setting.GetCompletionRatio(info.ModelName))
audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(info.ModelName))
audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(info.ModelName))
groupRatio := decimal.NewFromFloat(info.GroupRatio)
modelRatio := decimal.NewFromFloat(info.ModelRatio)
@@ -93,12 +94,21 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
textOutTokens := usage.OutputTokenDetails.TextTokens
audioInputTokens := usage.InputTokenDetails.AudioTokens
audioOutTokens := usage.OutputTokenDetails.AudioTokens
groupRatio := setting.GetGroupRatio(relayInfo.Group)
userGroupRatio, ok := setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group)
if ok {
groupRatio = userGroupRatio
groupRatio := ratio_setting.GetGroupRatio(relayInfo.Group)
modelRatio, _ := ratio_setting.GetModelRatio(modelName)
autoGroup, exists := ctx.Get("auto_group")
if exists {
groupRatio = ratio_setting.GetGroupRatio(autoGroup.(string))
log.Printf("final group ratio: %f", groupRatio)
relayInfo.Group = autoGroup.(string)
}
actualGroupRatio := groupRatio
userGroupRatio, ok := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group)
if ok {
actualGroupRatio = userGroupRatio
}
modelRatio, _ := operation_setting.GetModelRatio(modelName)
quotaInfo := QuotaInfo{
InputDetails: TokenDetails{
@@ -112,7 +122,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
ModelName: modelName,
UsePrice: relayInfo.UsePrice,
ModelRatio: modelRatio,
GroupRatio: groupRatio,
GroupRatio: actualGroupRatio,
}
quota := calculateAudioQuota(quotaInfo)
@@ -134,8 +144,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
}
func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
usage *dto.RealtimeUsage, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
modelPrice float64, usePrice bool, extraContent string) {
usage *dto.RealtimeUsage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
textInputTokens := usage.InputTokenDetails.TextTokens
@@ -145,15 +154,15 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
audioOutTokens := usage.OutputTokenDetails.AudioTokens
tokenName := ctx.GetString("token_name")
completionRatio := decimal.NewFromFloat(operation_setting.GetCompletionRatio(modelName))
audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(relayInfo.OriginModelName))
audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(modelName))
completionRatio := decimal.NewFromFloat(ratio_setting.GetCompletionRatio(modelName))
audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(relayInfo.OriginModelName))
audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(modelName))
modelRatio := priceData.ModelRatio
groupRatio := priceData.GroupRatioInfo.GroupRatio
modelPrice := priceData.ModelPrice
usePrice := priceData.UsePrice
actualGroupRatio := groupRatio
userGroupRatio, ok := setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group)
if ok {
actualGroupRatio = userGroupRatio
}
quotaInfo := QuotaInfo{
InputDetails: TokenDetails{
TextTokens: textInputTokens,
@@ -166,7 +175,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
ModelName: modelName,
UsePrice: usePrice,
ModelRatio: modelRatio,
GroupRatio: actualGroupRatio,
GroupRatio: groupRatio,
}
quota := calculateAudioQuota(quotaInfo)
@@ -198,7 +207,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
logContent += ", " + extraContent
}
other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, userGroupRatio)
completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
}
@@ -214,9 +223,8 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
tokenName := ctx.GetString("token_name")
completionRatio := priceData.CompletionRatio
modelRatio := priceData.ModelRatio
groupRatio := priceData.GroupRatio
groupRatio := priceData.GroupRatioInfo.GroupRatio
modelPrice := priceData.ModelPrice
userGroupRatio := priceData.UserGroupRatio
cacheRatio := priceData.CacheRatio
cacheTokens := usage.PromptTokensDetails.CachedTokens
@@ -265,7 +273,7 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
}
other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio,
cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, userGroupRatio)
cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, modelName,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
}
@@ -281,21 +289,15 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
audioOutTokens := usage.CompletionTokenDetails.AudioTokens
tokenName := ctx.GetString("token_name")
completionRatio := decimal.NewFromFloat(operation_setting.GetCompletionRatio(relayInfo.OriginModelName))
audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(relayInfo.OriginModelName))
audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(relayInfo.OriginModelName))
completionRatio := decimal.NewFromFloat(ratio_setting.GetCompletionRatio(relayInfo.OriginModelName))
audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(relayInfo.OriginModelName))
audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(relayInfo.OriginModelName))
modelRatio := priceData.ModelRatio
groupRatio := priceData.GroupRatio
groupRatio := priceData.GroupRatioInfo.GroupRatio
modelPrice := priceData.ModelPrice
usePrice := priceData.UsePrice
actualGroupRatio := groupRatio
userGroupRatio, ok := setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group)
if ok {
actualGroupRatio = userGroupRatio
}
quotaInfo := QuotaInfo{
InputDetails: TokenDetails{
TextTokens: textInputTokens,
@@ -308,7 +310,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
ModelName: relayInfo.OriginModelName,
UsePrice: usePrice,
ModelRatio: modelRatio,
GroupRatio: actualGroupRatio,
GroupRatio: groupRatio,
}
quota := calculateAudioQuota(quotaInfo)
@@ -348,7 +350,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
logContent += ", " + extraContent
}
other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, userGroupRatio)
completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
}
+17 -27
View File
@@ -171,7 +171,7 @@ func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenA
countStr += fmt.Sprintf("%v", tool.Function.Parameters)
}
}
toolTokens, err := CountTokenInput(countStr, request.Model)
toolTokens := CountTokenInput(countStr, request.Model)
if err != nil {
return 0, err
}
@@ -194,7 +194,7 @@ func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, erro
// Count tokens in system message
if request.System != "" {
systemTokens, err := CountTokenInput(request.System, model)
systemTokens := CountTokenInput(request.System, model)
if err != nil {
return 0, err
}
@@ -296,10 +296,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
switch request.Type {
case dto.RealtimeEventTypeSessionUpdate:
if request.Session != nil {
msgTokens, err := CountTextToken(request.Session.Instructions, model)
if err != nil {
return 0, 0, err
}
msgTokens := CountTextToken(request.Session.Instructions, model)
textToken += msgTokens
}
case dto.RealtimeEventResponseAudioDelta:
@@ -311,10 +308,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
audioToken += atk
case dto.RealtimeEventResponseAudioTranscriptionDelta, dto.RealtimeEventResponseFunctionCallArgumentsDelta:
// count text token
tkm, err := CountTextToken(request.Delta, model)
if err != nil {
return 0, 0, fmt.Errorf("error counting text token: %v", err)
}
tkm := CountTextToken(request.Delta, model)
textToken += tkm
case dto.RealtimeEventInputAudioBufferAppend:
// count audio token
@@ -329,10 +323,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
case "message":
for _, content := range request.Item.Content {
if content.Type == "input_text" {
tokens, err := CountTextToken(content.Text, model)
if err != nil {
return 0, 0, err
}
tokens := CountTextToken(content.Text, model)
textToken += tokens
}
}
@@ -343,10 +334,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
if !info.IsFirstRequest {
if info.RealtimeTools != nil && len(info.RealtimeTools) > 0 {
for _, tool := range info.RealtimeTools {
toolTokens, err := CountTokenInput(tool, model)
if err != nil {
return 0, 0, err
}
toolTokens := CountTokenInput(tool, model)
textToken += 8
textToken += toolTokens
}
@@ -409,7 +397,7 @@ func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, mod
return tokenNum, nil
}
func CountTokenInput(input any, model string) (int, error) {
func CountTokenInput(input any, model string) int {
switch v := input.(type) {
case string:
return CountTextToken(v, model)
@@ -432,13 +420,13 @@ func CountTokenInput(input any, model string) (int, error) {
func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int {
tokens := 0
for _, message := range messages {
tkm, _ := CountTokenInput(message.Delta.GetContentString(), model)
tkm := CountTokenInput(message.Delta.GetContentString(), model)
tokens += tkm
if message.Delta.ToolCalls != nil {
for _, tool := range message.Delta.ToolCalls {
tkm, _ := CountTokenInput(tool.Function.Name, model)
tkm := CountTokenInput(tool.Function.Name, model)
tokens += tkm
tkm, _ = CountTokenInput(tool.Function.Arguments, model)
tkm = CountTokenInput(tool.Function.Arguments, model)
tokens += tkm
}
}
@@ -446,9 +434,9 @@ func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice,
return tokens
}
func CountTTSToken(text string, model string) (int, error) {
func CountTTSToken(text string, model string) int {
if strings.HasPrefix(model, "tts") {
return utf8.RuneCountInString(text), nil
return utf8.RuneCountInString(text)
} else {
return CountTextToken(text, model)
}
@@ -483,8 +471,10 @@ func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error)
//}
// CountTextToken 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量
func CountTextToken(text string, model string) (int, error) {
var err error
func CountTextToken(text string, model string) int {
if text == "" {
return 0
}
tokenEncoder := getTokenEncoder(model)
return getTokenNum(tokenEncoder, text), err
return getTokenNum(tokenEncoder, text)
}
+3 -3
View File
@@ -16,13 +16,13 @@ import (
// return 0, errors.New("unknown relay mode")
//}
func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) {
func ResponseText2Usage(responseText string, modeName string, promptTokens int) *dto.Usage {
usage := &dto.Usage{}
usage.PromptTokens = promptTokens
ctkm, err := CountTextToken(responseText, modeName)
ctkm := CountTextToken(responseText, modeName)
usage.CompletionTokens = ctkm
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return usage, err
return usage
}
func ValidUsage(usage *dto.Usage) bool {
+31
View File
@@ -0,0 +1,31 @@
package setting
import "encoding/json"
var AutoGroups = []string{
"default",
}
var DefaultUseAutoGroup = false
func ContainsAutoGroup(group string) bool {
for _, autoGroup := range AutoGroups {
if autoGroup == group {
return true
}
}
return false
}
func UpdateAutoGroupsByJsonString(jsonString string) error {
AutoGroups = make([]string, 0)
return json.Unmarshal([]byte(jsonString), &AutoGroups)
}
func AutoGroups2JsonString() string {
jsonBytes, err := json.Marshal(AutoGroups)
if err != nil {
return "[]"
}
return string(jsonBytes)
}
+37
View File
@@ -1,8 +1,45 @@
package setting
import "encoding/json"
var PayAddress = ""
var CustomCallbackAddress = ""
var EpayId = ""
var EpayKey = ""
var Price = 7.3
var MinTopUp = 1
var PayMethods = []map[string]string{
{
"name": "支付宝",
"color": "rgba(var(--semi-blue-5), 1)",
"type": "alipay",
},
{
"name": "微信",
"color": "rgba(var(--semi-green-5), 1)",
"type": "wxpay",
},
}
func UpdatePayMethodsByJsonString(jsonString string) error {
PayMethods = make([]map[string]string, 0)
return json.Unmarshal([]byte(jsonString), &PayMethods)
}
func PayMethods2JsonString() string {
jsonBytes, err := json.Marshal(PayMethods)
if err != nil {
return "[]"
}
return string(jsonBytes)
}
func ContainsPayMethod(method string) bool {
for _, payMethod := range PayMethods {
if payMethod["type"] == method {
return true
}
}
return false
}
@@ -1,4 +1,4 @@
package operation_setting
package ratio_setting
import (
"encoding/json"
@@ -85,7 +85,11 @@ func UpdateCacheRatioByJSONString(jsonStr string) error {
cacheRatioMapMutex.Lock()
defer cacheRatioMapMutex.Unlock()
cacheRatioMap = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &cacheRatioMap)
err := json.Unmarshal([]byte(jsonStr), &cacheRatioMap)
if err == nil {
InvalidateExposedDataCache()
}
return err
}
// GetCacheRatio returns the cache ratio for a model
@@ -106,3 +110,13 @@ func GetCreateCacheRatio(name string) (float64, bool) {
}
return ratio, true
}
func GetCacheRatioCopy() map[string]float64 {
cacheRatioMapMutex.RLock()
defer cacheRatioMapMutex.RUnlock()
copyMap := make(map[string]float64, len(cacheRatioMap))
for k, v := range cacheRatioMap {
copyMap[k] = v
}
return copyMap
}
+17
View File
@@ -0,0 +1,17 @@
package ratio_setting
import "sync/atomic"
var exposeRatioEnabled atomic.Bool
func init() {
exposeRatioEnabled.Store(false)
}
func SetExposeRatioEnabled(enabled bool) {
exposeRatioEnabled.Store(enabled)
}
func IsExposeRatioEnabled() bool {
return exposeRatioEnabled.Load()
}
+55
View File
@@ -0,0 +1,55 @@
package ratio_setting
import (
"sync"
"sync/atomic"
"time"
"github.com/gin-gonic/gin"
)
const exposedDataTTL = 30 * time.Second
type exposedCache struct {
data gin.H
expiresAt time.Time
}
var (
exposedData atomic.Value
rebuildMu sync.Mutex
)
func InvalidateExposedDataCache() {
exposedData.Store((*exposedCache)(nil))
}
func cloneGinH(src gin.H) gin.H {
dst := make(gin.H, len(src))
for k, v := range src {
dst[k] = v
}
return dst
}
func GetExposedData() gin.H {
if c, ok := exposedData.Load().(*exposedCache); ok && c != nil && time.Now().Before(c.expiresAt) {
return cloneGinH(c.data)
}
rebuildMu.Lock()
defer rebuildMu.Unlock()
if c, ok := exposedData.Load().(*exposedCache); ok && c != nil && time.Now().Before(c.expiresAt) {
return cloneGinH(c.data)
}
newData := gin.H{
"model_ratio": GetModelRatioCopy(),
"completion_ratio": GetCompletionRatioCopy(),
"cache_ratio": GetCacheRatioCopy(),
"model_price": GetModelPriceCopy(),
}
exposedData.Store(&exposedCache{
data: newData,
expiresAt: time.Now().Add(exposedDataTTL),
})
return cloneGinH(newData)
}
@@ -1,4 +1,4 @@
package setting
package ratio_setting
import (
"encoding/json"
@@ -1,8 +1,9 @@
package operation_setting
package ratio_setting
import (
"encoding/json"
"one-api/common"
"one-api/setting/operation_setting"
"strings"
"sync"
)
@@ -316,7 +317,11 @@ func UpdateModelPriceByJSONString(jsonStr string) error {
modelPriceMapMutex.Lock()
defer modelPriceMapMutex.Unlock()
modelPriceMap = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &modelPriceMap)
err := json.Unmarshal([]byte(jsonStr), &modelPriceMap)
if err == nil {
InvalidateExposedDataCache()
}
return err
}
// GetModelPrice 返回模型的价格,如果模型不存在则返回-1,false
@@ -344,7 +349,11 @@ func UpdateModelRatioByJSONString(jsonStr string) error {
modelRatioMapMutex.Lock()
defer modelRatioMapMutex.Unlock()
modelRatioMap = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &modelRatioMap)
err := json.Unmarshal([]byte(jsonStr), &modelRatioMap)
if err == nil {
InvalidateExposedDataCache()
}
return err
}
// 处理带有思考预算的模型名称,方便统一定价
@@ -366,7 +375,7 @@ func GetModelRatio(name string) (float64, bool) {
}
ratio, ok := modelRatioMap[name]
if !ok {
return 37.5, SelfUseModeEnabled
return 37.5, operation_setting.SelfUseModeEnabled
}
return ratio, true
}
@@ -404,7 +413,11 @@ func UpdateCompletionRatioByJSONString(jsonStr string) error {
CompletionRatioMutex.Lock()
defer CompletionRatioMutex.Unlock()
CompletionRatio = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &CompletionRatio)
err := json.Unmarshal([]byte(jsonStr), &CompletionRatio)
if err == nil {
InvalidateExposedDataCache()
}
return err
}
func GetCompletionRatio(name string) float64 {
@@ -608,3 +621,33 @@ func GetImageRatio(name string) (float64, bool) {
}
return ratio, true
}
func GetModelRatioCopy() map[string]float64 {
modelRatioMapMutex.RLock()
defer modelRatioMapMutex.RUnlock()
copyMap := make(map[string]float64, len(modelRatioMap))
for k, v := range modelRatioMap {
copyMap[k] = v
}
return copyMap
}
func GetModelPriceCopy() map[string]float64 {
modelPriceMapMutex.RLock()
defer modelPriceMapMutex.RUnlock()
copyMap := make(map[string]float64, len(modelPriceMap))
for k, v := range modelPriceMap {
copyMap[k] = v
}
return copyMap
}
func GetCompletionRatioCopy() map[string]float64 {
CompletionRatioMutex.RLock()
defer CompletionRatioMutex.RUnlock()
copyMap := make(map[string]float64, len(CompletionRatio))
for k, v := range CompletionRatio {
copyMap[k] = v
}
return copyMap
}
+7
View File
@@ -50,3 +50,10 @@ func GroupInUserUsableGroups(groupName string) bool {
_, ok := userUsableGroups[groupName]
return ok
}
func GetUsableGroupDescription(groupName string) string {
if desc, ok := userUsableGroups[groupName]; ok {
return desc
}
return groupName
}
@@ -0,0 +1,143 @@
import React, { useState } from 'react';
import {
Modal,
Transfer,
Input,
Space,
Checkbox,
Avatar,
Highlight,
} from '@douyinfe/semi-ui';
import { IconClose } from '@douyinfe/semi-icons';
const CHANNEL_STATUS_CONFIG = {
1: { color: 'green', text: '启用' },
2: { color: 'red', text: '禁用' },
3: { color: 'amber', text: '自禁' },
default: { color: 'grey', text: '未知' }
};
const getChannelStatusConfig = (status) => {
return CHANNEL_STATUS_CONFIG[status] || CHANNEL_STATUS_CONFIG.default;
};
export default function ChannelSelectorModal({
t,
visible,
onCancel,
onOk,
allChannels = [],
selectedChannelIds = [],
setSelectedChannelIds,
channelEndpoints,
updateChannelEndpoint,
}) {
const [searchText, setSearchText] = useState('');
const ChannelInfo = ({ item, showEndpoint = false, isSelected = false }) => {
const channelId = item.key || item.value;
const currentEndpoint = channelEndpoints[channelId];
const baseUrl = item._originalData?.base_url || '';
const status = item._originalData?.status || 0;
const statusConfig = getChannelStatusConfig(status);
return (
<>
<Avatar color={statusConfig.color} size="small">
{statusConfig.text}
</Avatar>
<div className="info">
<div className="name">
{isSelected ? (
item.label
) : (
<Highlight sourceString={item.label} searchWords={[searchText]} />
)}
</div>
<div className="email" style={showEndpoint ? { display: 'flex', alignItems: 'center', gap: '4px' } : {}}>
<span className="text-xs text-gray-500 truncate max-w-[200px]" title={baseUrl}>
{isSelected ? (
baseUrl
) : (
<Highlight sourceString={baseUrl} searchWords={[searchText]} />
)}
</span>
{showEndpoint && (
<Input
size="small"
value={currentEndpoint}
onChange={(value) => updateChannelEndpoint(channelId, value)}
placeholder="/api/ratio_config"
className="flex-1 text-xs"
style={{ fontSize: '12px' }}
/>
)}
{isSelected && !showEndpoint && (
<span className="text-xs text-gray-700 font-mono bg-gray-100 px-2 py-1 rounded ml-2">
{currentEndpoint}
</span>
)}
</div>
</div>
</>
);
};
const renderSourceItem = (item) => {
return (
<div className="components-transfer-source-item" key={item.key}>
<Checkbox
onChange={item.onChange}
checked={item.checked}
style={{ height: 52, alignItems: 'center' }}
>
<ChannelInfo item={item} showEndpoint={true} />
</Checkbox>
</div>
);
};
const renderSelectedItem = (item) => {
return (
<div className="components-transfer-selected-item" key={item.key}>
<ChannelInfo item={item} isSelected={true} />
<IconClose style={{ cursor: 'pointer' }} onClick={item.onRemove} />
</div>
);
};
const channelFilter = (input, item) => {
const searchLower = input.toLowerCase();
return item.label.toLowerCase().includes(searchLower) ||
(item._originalData?.base_url || '').toLowerCase().includes(searchLower);
};
return (
<Modal
visible={visible}
onCancel={onCancel}
onOk={onOk}
title={<span className="text-lg font-semibold">{t('选择同步渠道')}</span>}
width={1000}
>
<Space vertical style={{ width: '100%' }}>
<Transfer
style={{ width: '100%' }}
dataSource={allChannels}
value={selectedChannelIds}
onChange={setSelectedChannelIds}
renderSourceItem={renderSourceItem}
renderSelectedItem={renderSelectedItem}
filter={channelFilter}
inputProps={{ placeholder: t('搜索渠道名称或地址') }}
onSearch={setSearchText}
emptyContent={{
left: t('暂无渠道'),
right: t('暂无选择'),
search: t('无搜索结果'),
}}
/>
</Space>
</Modal>
);
}
+33 -65
View File
@@ -1,5 +1,5 @@
import React, { useEffect, useState } from 'react';
import { Card, Spin, Tabs } from '@douyinfe/semi-ui';
import { Card, Spin } from '@douyinfe/semi-ui';
import SettingsGeneral from '../../pages/Setting/Operation/SettingsGeneral.js';
import SettingsDrawing from '../../pages/Setting/Operation/SettingsDrawing.js';
import SettingsSensitiveWords from '../../pages/Setting/Operation/SettingsSensitiveWords.js';
@@ -7,61 +7,58 @@ import SettingsLog from '../../pages/Setting/Operation/SettingsLog.js';
import SettingsDataDashboard from '../../pages/Setting/Operation/SettingsDataDashboard.js';
import SettingsMonitoring from '../../pages/Setting/Operation/SettingsMonitoring.js';
import SettingsCreditLimit from '../../pages/Setting/Operation/SettingsCreditLimit.js';
import ModelSettingsVisualEditor from '../../pages/Setting/Operation/ModelSettingsVisualEditor.js';
import GroupRatioSettings from '../../pages/Setting/Operation/GroupRatioSettings.js';
import ModelRatioSettings from '../../pages/Setting/Operation/ModelRatioSettings.js';
import { API, showError, showSuccess } from '../../helpers';
import SettingsChats from '../../pages/Setting/Operation/SettingsChats.js';
import { useTranslation } from 'react-i18next';
import ModelRatioNotSetEditor from '../../pages/Setting/Operation/ModelRationNotSetEditor.js';
import { API, showError } from '../../helpers';
const OperationSetting = () => {
const { t } = useTranslation();
let [inputs, setInputs] = useState({
/* 额度相关 */
QuotaForNewUser: 0,
PreConsumedQuota: 0,
QuotaForInviter: 0,
QuotaForInvitee: 0,
QuotaRemindThreshold: 0,
PreConsumedQuota: 0,
StreamCacheQueueLength: 0,
ModelRatio: '',
CacheRatio: '',
CompletionRatio: '',
ModelPrice: '',
GroupRatio: '',
GroupGroupRatio: '',
UserUsableGroups: '',
/* 通用设置 */
TopUpLink: '',
'general_setting.docs_link': '',
// ChatLink2: '', // 添加的新状态变量
QuotaPerUnit: 0,
AutomaticDisableChannelEnabled: false,
AutomaticEnableChannelEnabled: false,
ChannelDisableThreshold: 0,
LogConsumeEnabled: false,
RetryTimes: 0,
DisplayInCurrencyEnabled: false,
DisplayTokenStatEnabled: false,
CheckSensitiveEnabled: false,
CheckSensitiveOnPromptEnabled: false,
CheckSensitiveOnCompletionEnabled: '',
StopOnSensitiveEnabled: '',
SensitiveWords: '',
DefaultCollapseSidebar: false,
DemoSiteEnabled: false,
SelfUseModeEnabled: false,
/* 绘图设置 */
DrawingEnabled: false,
MjNotifyEnabled: false,
MjAccountFilterEnabled: false,
MjModeClearEnabled: false,
MjForwardUrlEnabled: false,
MjModeClearEnabled: false,
MjActionCheckSuccessEnabled: false,
DrawingEnabled: false,
/* 敏感词设置 */
CheckSensitiveEnabled: false,
CheckSensitiveOnPromptEnabled: false,
SensitiveWords: '',
/* 日志设置 */
LogConsumeEnabled: false,
/* 数据看板 */
DataExportEnabled: false,
DataExportDefaultTime: 'hour',
DataExportInterval: 5,
DefaultCollapseSidebar: false, // 默认折叠侧边栏
RetryTimes: 0,
Chats: '[]',
DemoSiteEnabled: false,
SelfUseModeEnabled: false,
/* 监控设置 */
ChannelDisableThreshold: 0,
QuotaRemindThreshold: 0,
AutomaticDisableChannelEnabled: false,
AutomaticEnableChannelEnabled: false,
AutomaticDisableKeywords: '',
/* 聊天设置 */
Chats: '[]',
});
let [loading, setLoading] = useState(false);
@@ -72,17 +69,6 @@ const OperationSetting = () => {
if (success) {
let newInputs = {};
data.forEach((item) => {
if (
item.key === 'ModelRatio' ||
item.key === 'GroupRatio' ||
item.key === 'GroupGroupRatio' ||
item.key === 'UserUsableGroups' ||
item.key === 'CompletionRatio' ||
item.key === 'ModelPrice' ||
item.key === 'CacheRatio'
) {
item.value = JSON.stringify(JSON.parse(item.value), null, 2);
}
if (
item.key.endsWith('Enabled') ||
['DefaultCollapseSidebar'].includes(item.key)
@@ -149,24 +135,6 @@ const OperationSetting = () => {
<Card style={{ marginTop: '10px' }}>
<SettingsChats options={inputs} refresh={onRefresh} />
</Card>
{/* 分组倍率设置 */}
<Card style={{ marginTop: '10px' }}>
<GroupRatioSettings options={inputs} refresh={onRefresh} />
</Card>
{/* 合并模型倍率设置和可视化倍率设置 */}
<Card style={{ marginTop: '10px' }}>
<Tabs type='line'>
<Tabs.TabPane tab={t('模型倍率设置')} itemKey='model'>
<ModelRatioSettings options={inputs} refresh={onRefresh} />
</Tabs.TabPane>
<Tabs.TabPane tab={t('可视化倍率设置')} itemKey='visual'>
<ModelSettingsVisualEditor options={inputs} refresh={onRefresh} />
</Tabs.TabPane>
<Tabs.TabPane tab={t('未设置倍率模型')} itemKey='unset_models'>
<ModelRatioNotSetEditor options={inputs} refresh={onRefresh} />
</Tabs.TabPane>
</Tabs>
</Card>
</Spin>
</>
);
+117
View File
@@ -0,0 +1,117 @@
import React, { useEffect, useState } from 'react';
import { Card, Spin, Tabs } from '@douyinfe/semi-ui';
import { useTranslation } from 'react-i18next';
import GroupRatioSettings from '../../pages/Setting/Ratio/GroupRatioSettings.js';
import ModelRatioSettings from '../../pages/Setting/Ratio/ModelRatioSettings.js';
import ModelSettingsVisualEditor from '../../pages/Setting/Ratio/ModelSettingsVisualEditor.js';
import ModelRatioNotSetEditor from '../../pages/Setting/Ratio/ModelRationNotSetEditor.js';
import UpstreamRatioSync from '../../pages/Setting/Ratio/UpstreamRatioSync.js';
import { API, showError } from '../../helpers';
const RatioSetting = () => {
const { t } = useTranslation();
let [inputs, setInputs] = useState({
ModelPrice: '',
ModelRatio: '',
CacheRatio: '',
CompletionRatio: '',
GroupRatio: '',
GroupGroupRatio: '',
AutoGroups: '',
DefaultUseAutoGroup: false,
ExposeRatioEnabled: false,
UserUsableGroups: '',
});
const [loading, setLoading] = useState(false);
const getOptions = async () => {
const res = await API.get('/api/option/');
const { success, message, data } = res.data;
if (success) {
let newInputs = {};
data.forEach((item) => {
if (
item.key === 'ModelRatio' ||
item.key === 'GroupRatio' ||
item.key === 'GroupGroupRatio' ||
item.key === 'AutoGroups' ||
item.key === 'UserUsableGroups' ||
item.key === 'CompletionRatio' ||
item.key === 'ModelPrice' ||
item.key === 'CacheRatio'
) {
try {
item.value = JSON.stringify(JSON.parse(item.value), null, 2);
} catch (e) {
// 如果后端返回的不是合法 JSON,直接展示
}
}
if (['DefaultUseAutoGroup', 'ExposeRatioEnabled'].includes(item.key)) {
newInputs[item.key] = item.value === 'true' ? true : false;
} else {
newInputs[item.key] = item.value;
}
});
setInputs(newInputs);
} else {
showError(message);
}
};
const onRefresh = async () => {
try {
setLoading(true);
await getOptions();
} catch (error) {
showError('刷新失败');
} finally {
setLoading(false);
}
};
useEffect(() => {
onRefresh();
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);
return (
<Spin spinning={loading} size='large'>
{/* 模型倍率设置以及可视化编辑器 */}
<Card style={{ marginTop: '10px' }}>
<Tabs type='line'>
<Tabs.TabPane tab={t('模型倍率设置')} itemKey='model'>
<ModelRatioSettings options={inputs} refresh={onRefresh} />
</Tabs.TabPane>
<Tabs.TabPane tab={t('可视化倍率设置')} itemKey='visual'>
<ModelSettingsVisualEditor
options={inputs}
refresh={onRefresh}
/>
</Tabs.TabPane>
<Tabs.TabPane tab={t('未设置倍率模型')} itemKey='unset_models'>
<ModelRatioNotSetEditor
options={inputs}
refresh={onRefresh}
/>
</Tabs.TabPane>
<Tabs.TabPane tab={t('上游倍率同步')} itemKey='upstream_sync'>
<UpstreamRatioSync
options={inputs}
refresh={onRefresh}
/>
</Tabs.TabPane>
</Tabs>
</Card>
{/* 分组倍率设置 */}
<Card style={{ marginTop: '10px' }}>
<GroupRatioSettings options={inputs} refresh={onRefresh} />
</Card>
</Spin>
);
};
export default RatioSetting;
+17 -1
View File
@@ -17,7 +17,7 @@ import {
removeTrailingSlash,
showError,
showSuccess,
verifyJSON
verifyJSON,
} from '../../helpers';
import axios from 'axios';
@@ -73,6 +73,7 @@ const SystemSetting = () => {
LinuxDOOAuthEnabled: '',
LinuxDOClientId: '',
LinuxDOClientSecret: '',
PayMethods: '',
});
const [originInputs, setOriginInputs] = useState({});
@@ -230,6 +231,12 @@ const SystemSetting = () => {
return;
}
}
if (originInputs['PayMethods'] !== inputs.PayMethods) {
if (!verifyJSON(inputs.PayMethods)) {
showError('充值方式设置不是合法的 JSON 字符串');
return;
}
}
const options = [
{ key: 'PayAddress', value: removeTrailingSlash(inputs.PayAddress) },
@@ -256,6 +263,9 @@ const SystemSetting = () => {
if (originInputs['TopupGroupRatio'] !== inputs.TopupGroupRatio) {
options.push({ key: 'TopupGroupRatio', value: inputs.TopupGroupRatio });
}
if (originInputs['PayMethods'] !== inputs.PayMethods) {
options.push({ key: 'PayMethods', value: inputs.PayMethods });
}
await updateOptions(options);
};
@@ -658,6 +668,12 @@ const SystemSetting = () => {
placeholder='为一个 JSON 文本,键为组名称,值为倍率'
autosize
/>
<Form.TextArea
field='PayMethods'
label='充值方式设置'
placeholder='为一个 JSON 文本'
autosize
/>
<Button onClick={submitPayAddress}>更新支付设置</Button>
</Form.Section>
</Card>
+116 -22
View File
@@ -1,4 +1,4 @@
import React, { useEffect, useState } from 'react';
import React, { useEffect, useState, useMemo, useRef } from 'react';
import {
API,
showError,
@@ -16,11 +16,6 @@ import {
XCircle,
AlertCircle,
HelpCircle,
TestTube,
Zap,
Timer,
Clock,
AlertTriangle,
Coins,
Tags
} from 'lucide-react';
@@ -43,7 +38,9 @@ import {
Typography,
Checkbox,
Card,
Form
Form,
Tabs,
TabPane
} from '@douyinfe/semi-ui';
import {
IllustrationNoResult,
@@ -141,31 +138,31 @@ const ChannelsTable = () => {
time = time.toFixed(2) + t(' 秒');
if (responseTime === 0) {
return (
<Tag size='large' color='grey' shape='circle' prefixIcon={<TestTube size={14} />}>
<Tag size='large' color='grey' shape='circle'>
{t('未测试')}
</Tag>
);
} else if (responseTime <= 1000) {
return (
<Tag size='large' color='green' shape='circle' prefixIcon={<Zap size={14} />}>
<Tag size='large' color='green' shape='circle'>
{time}
</Tag>
);
} else if (responseTime <= 3000) {
return (
<Tag size='large' color='lime' shape='circle' prefixIcon={<Timer size={14} />}>
<Tag size='large' color='lime' shape='circle'>
{time}
</Tag>
);
} else if (responseTime <= 5000) {
return (
<Tag size='large' color='yellow' shape='circle' prefixIcon={<Clock size={14} />}>
<Tag size='large' color='yellow' shape='circle'>
{time}
</Tag>
);
} else {
return (
<Tag size='large' color='red' shape='circle' prefixIcon={<AlertTriangle size={14} />}>
<Tag size='large' color='red' shape='circle'>
{time}
</Tag>
);
@@ -682,11 +679,10 @@ const ChannelsTable = () => {
const [isBatchTesting, setIsBatchTesting] = useState(false);
const [testQueue, setTestQueue] = useState([]);
const [isProcessingQueue, setIsProcessingQueue] = useState(false);
// Form API 引用
const [activeTypeKey, setActiveTypeKey] = useState('all');
const [typeCounts, setTypeCounts] = useState({});
const requestCounter = useRef(0);
const [formApi, setFormApi] = useState(null);
// Form 初始值
const formInitValues = {
searchKeyword: '',
searchGroup: '',
@@ -868,17 +864,23 @@ const ChannelsTable = () => {
setChannels(channelDates);
};
const loadChannels = async (page, pageSize, idSort, enableTagMode) => {
const loadChannels = async (page, pageSize, idSort, enableTagMode, typeKey = activeTypeKey) => {
const reqId = ++requestCounter.current; // 记录当前请求序号
setLoading(true);
const typeParam = (!enableTagMode && typeKey !== 'all') ? `&type=${typeKey}` : '';
const res = await API.get(
`/api/channel/?p=${page}&page_size=${pageSize}&id_sort=${idSort}&tag_mode=${enableTagMode}`,
`/api/channel/?p=${page}&page_size=${pageSize}&id_sort=${idSort}&tag_mode=${enableTagMode}${typeParam}`,
);
if (res === undefined) {
if (res === undefined || reqId !== requestCounter.current) {
return;
}
const { success, message, data } = res.data;
if (success) {
const { items, total } = data;
const { items, total, type_counts } = data;
if (type_counts) {
const sumAll = Object.values(type_counts).reduce((acc, v) => acc + v, 0);
setTypeCounts({ ...type_counts, all: sumAll });
}
setChannelFormat(items, enableTagMode);
setChannelCount(total);
} else {
@@ -1044,12 +1046,16 @@ const ChannelsTable = () => {
return;
}
const typeParam = (!enableTagMode && activeTypeKey !== 'all') ? `&type=${activeTypeKey}` : '';
const res = await API.get(
`/api/channel/search?keyword=${searchKeyword}&group=${searchGroup}&model=${searchModel}&id_sort=${idSort}&tag_mode=${enableTagMode}`,
`/api/channel/search?keyword=${searchKeyword}&group=${searchGroup}&model=${searchModel}&id_sort=${idSort}&tag_mode=${enableTagMode}${typeParam}`,
);
const { success, message, data } = res.data;
if (success) {
setChannelFormat(data, enableTagMode);
const { items = [], type_counts = {} } = data;
const sumAll = Object.values(type_counts).reduce((acc, v) => acc + v, 0);
setTypeCounts({ ...type_counts, all: sumAll });
setChannelFormat(items, enableTagMode);
setActivePage(1);
} else {
showError(message);
@@ -1179,7 +1185,94 @@ const ChannelsTable = () => {
}
};
const channelTypeCounts = useMemo(() => {
if (Object.keys(typeCounts).length > 0) return typeCounts;
// fallback 本地计算
const counts = { all: channels.length };
channels.forEach((channel) => {
const collect = (ch) => {
const type = ch.type;
counts[type] = (counts[type] || 0) + 1;
};
if (channel.children !== undefined) {
channel.children.forEach(collect);
} else {
collect(channel);
}
});
return counts;
}, [typeCounts, channels]);
const availableTypeKeys = useMemo(() => {
const keys = ['all'];
Object.entries(channelTypeCounts).forEach(([k, v]) => {
if (k !== 'all' && v > 0) keys.push(String(k));
});
return keys;
}, [channelTypeCounts]);
const renderTypeTabs = () => {
if (enableTagMode) return null;
return (
<Tabs
activeKey={activeTypeKey}
type="card"
collapsible
onChange={(key) => {
setActiveTypeKey(key);
setActivePage(1);
loadChannels(1, pageSize, idSort, enableTagMode, key);
}}
className="mb-4"
>
<TabPane
itemKey="all"
tab={
<span className="flex items-center gap-2">
{t('全部')}
<Tag color={activeTypeKey === 'all' ? 'red' : 'grey'} size='small' shape='circle'>
{channelTypeCounts['all'] || 0}
</Tag>
</span>
}
/>
{CHANNEL_OPTIONS.filter((opt) => availableTypeKeys.includes(String(opt.value))).map((option) => {
const key = String(option.value);
const count = channelTypeCounts[option.value] || 0;
return (
<TabPane
key={key}
itemKey={key}
tab={
<span className="flex items-center gap-2">
{getChannelIcon(option.value)}
{option.label}
<Tag color={activeTypeKey === key ? 'red' : 'grey'} size='small' shape='circle'>
{count}
</Tag>
</span>
}
/>
);
})}
</Tabs>
);
};
let pageData = channels;
if (activeTypeKey !== 'all') {
const typeVal = parseInt(activeTypeKey);
if (!isNaN(typeVal)) {
pageData = pageData.filter((ch) => {
if (ch.children !== undefined) {
return ch.children.some((c) => c.type === typeVal);
}
return ch.type === typeVal;
});
}
}
const handlePageChange = (page) => {
setActivePage(page);
@@ -1371,6 +1464,7 @@ const ChannelsTable = () => {
const renderHeader = () => (
<div className="flex flex-col w-full">
{renderTypeTabs()}
<div className="flex flex-col md:flex-row justify-between gap-4">
<div className="flex flex-wrap md:flex-nowrap items-center gap-2 w-full md:w-auto order-2 md:order-1">
<Button
+30 -15
View File
@@ -11,7 +11,9 @@ import {
XCircle,
Loader,
List,
Hash
Hash,
Video,
Sparkles
} from 'lucide-react';
import {
API,
@@ -80,6 +82,7 @@ const COLUMN_KEYS = {
TASK_STATUS: 'task_status',
PROGRESS: 'progress',
FAIL_REASON: 'fail_reason',
RESULT_URL: 'result_url',
};
const renderTimestamp = (timestampInSeconds) => {
@@ -96,20 +99,8 @@ const renderTimestamp = (timestampInSeconds) => {
};
function renderDuration(submit_time, finishTime) {
// 确保startTime和finishTime都是有效的时间戳
if (!submit_time || !finishTime) return 'N/A';
// 将时间戳转换为Date对象
const start = new Date(submit_time);
const finish = new Date(finishTime);
// 计算时间差(毫秒)
const durationMs = finish - start;
// 将时间差转换为秒,并保留一位小数
const durationSec = (durationMs / 1000).toFixed(1);
// 设置颜色:大于60秒则为红色,小于等于60秒则为绿色
const durationSec = finishTime - submit_time;
const color = durationSec > 60 ? 'red' : 'green';
// 返回带有样式的颜色标签
@@ -162,6 +153,7 @@ const LogsTable = () => {
[COLUMN_KEYS.TASK_STATUS]: true,
[COLUMN_KEYS.PROGRESS]: true,
[COLUMN_KEYS.FAIL_REASON]: true,
[COLUMN_KEYS.RESULT_URL]: true,
};
};
@@ -215,6 +207,12 @@ const LogsTable = () => {
{t('生成歌词')}
</Tag>
);
case 'generate':
return (
<Tag color='blue' size='large' shape='circle' prefixIcon={<Sparkles size={14} />}>
{t('生成视频')}
</Tag>
);
default:
return (
<Tag color='white' size='large' shape='circle' prefixIcon={<HelpCircle size={14} />}>
@@ -232,6 +230,12 @@ const LogsTable = () => {
Suno
</Tag>
);
case 'kling':
return (
<Tag color='blue' size='large' shape='circle' prefixIcon={<Video size={14} />}>
Kling
</Tag>
);
default:
return (
<Tag color='white' size='large' shape='circle' prefixIcon={<HelpCircle size={14} />}>
@@ -423,10 +427,21 @@ const LogsTable = () => {
},
{
key: COLUMN_KEYS.FAIL_REASON,
title: t('失败原因'),
title: t('详情'),
dataIndex: 'fail_reason',
fixed: 'right',
render: (text, record, index) => {
// 仅当为视频生成任务且成功,且 fail_reason 是 URL 时显示可点击链接
const isVideoTask = record.action === 'generate';
const isSuccess = record.status === 'SUCCESS';
const isUrl = typeof text === 'string' && /^https?:\/\//.test(text);
if (isSuccess && isVideoTask && isUrl) {
return (
<a href={text} target="_blank" rel="noopener noreferrer">
{t('点击预览视频')}
</a>
);
}
if (!text) {
return t('无');
}
+5
View File
@@ -125,4 +125,9 @@ export const CHANNEL_OPTIONS = [
color: 'blue',
label: 'Coze',
},
{
value: 50,
color: 'green',
label: '可灵',
},
];
+2
View File
@@ -1 +1,3 @@
export const ITEMS_PER_PAGE = 10; // this value must keep same as the one defined in backend!
export const DEFAULT_ENDPOINT = '/api/ratio_config';
+1
View File
@@ -9,6 +9,7 @@ export function setStatusData(data) {
localStorage.setItem('enable_task', data.enable_task);
localStorage.setItem('enable_data_export', data.enable_data_export);
localStorage.setItem('chats', JSON.stringify(data.chats));
localStorage.setItem('pay_methods', JSON.stringify(data.pay_methods));
localStorage.setItem(
'data_export_default_time',
data.data_export_default_time,
+25 -2
View File
@@ -1588,7 +1588,7 @@
"性能指标": "Performance Indicators",
"模型数据分析": "Model Data Analysis",
"搜索无结果": "No results found",
"仪表盘置": "Dashboard Configuration",
"仪表盘置": "Dashboard Settings",
"API信息管理,可以配置多个API地址用于状态展示和负载均衡(最多50个)": "API information management, you can configure multiple API addresses for status display and load balancing (maximum 50)",
"线路描述": "Route description",
"颜色": "Color",
@@ -1665,5 +1665,28 @@
"确定清除所有失效兑换码?": "Are you sure you want to clear all invalid redemption codes?",
"将删除已使用、已禁用及过期的兑换码,此操作不可撤销。": "This will delete all used, disabled, and expired redemption codes, this operation cannot be undone.",
"选择过期时间(可选,留空为永久)": "Select expiration time (optional, leave blank for permanent)",
"请输入备注(仅管理员可见)": "Please enter a remark (only visible to administrators)"
"请输入备注(仅管理员可见)": "Please enter a remark (only visible to administrators)",
"上游倍率同步": "Upstream ratio synchronization",
"获取渠道失败:": "Failed to get channels: ",
"请至少选择一个渠道": "Please select at least one channel",
"获取倍率失败:": "Failed to get ratios: ",
"后端请求失败": "Backend request failed",
"部分渠道测试失败:": "Some channels failed to test: ",
"已与上游倍率完全一致,无需同步": "The upstream ratio is completely consistent, no synchronization is required",
"请求后端接口失败:": "Failed to request the backend interface: ",
"同步成功": "Synchronization successful",
"部分保存失败": "Some settings failed to save",
"保存失败": "Save failed",
"选择同步渠道": "Select synchronization channel",
"应用同步": "Apply synchronization",
"倍率类型": "Ratio type",
"当前值": "Current value",
"上游值": "Upstream value",
"差异": "Difference",
"搜索渠道名称或地址": "Search channel name or address",
"缓存倍率": "Cache ratio",
"暂无差异化倍率显示": "No differential ratio display",
"请先选择同步渠道": "Please select the synchronization channel first",
"与本地相同": "Same as local",
"未找到匹配的模型": "No matching model found"
}
+68
View File
@@ -432,4 +432,72 @@ code {
.semi-table-tbody>.semi-table-row {
border-bottom: 1px solid rgba(0, 0, 0, 0.1);
}
}
/* ==================== 同步倍率 - 渠道选择器 ==================== */
.components-transfer-source-item,
.components-transfer-selected-item {
display: flex;
align-items: center;
padding: 8px;
}
.semi-transfer-left-list,
.semi-transfer-right-list {
-ms-overflow-style: none;
scrollbar-width: none;
}
.semi-transfer-left-list::-webkit-scrollbar,
.semi-transfer-right-list::-webkit-scrollbar {
display: none;
}
.components-transfer-source-item .semi-checkbox,
.components-transfer-selected-item .semi-checkbox {
display: flex;
align-items: center;
width: 100%;
}
.components-transfer-source-item .semi-avatar,
.components-transfer-selected-item .semi-avatar {
margin-right: 12px;
flex-shrink: 0;
}
.components-transfer-source-item .info,
.components-transfer-selected-item .info {
flex: 1;
overflow: hidden;
display: flex;
flex-direction: column;
justify-content: center;
}
.components-transfer-source-item .name,
.components-transfer-selected-item .name {
font-weight: 500;
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
}
.components-transfer-source-item .email,
.components-transfer-selected-item .email {
font-size: 12px;
color: var(--semi-color-text-2);
display: flex;
align-items: center;
}
.components-transfer-selected-item .semi-icon-close {
margin-left: 8px;
cursor: pointer;
color: var(--semi-color-text-2);
}
.components-transfer-selected-item .semi-icon-close:hover {
color: var(--semi-color-text-0);
}
+22 -13
View File
@@ -298,18 +298,27 @@ const EditChannel = (props) => {
}
};
useEffect(() => {
let localModelOptions = [...originModelOptions];
inputs.models.forEach((model) => {
if (!localModelOptions.find((option) => option.label === model)) {
localModelOptions.push({
label: model,
value: model,
});
}
});
setModelOptions(localModelOptions);
}, [originModelOptions, inputs.models]);
useEffect(() => {
// 使用 Map 来避免重复,以 value 为键
const modelMap = new Map();
// 先添加原始模型选项
originModelOptions.forEach(option => {
modelMap.set(option.value, option);
});
// 再添加当前选中的模型(如果不存在)
inputs.models.forEach(model => {
if (!modelMap.has(model)) {
modelMap.set(model, {
label: model,
value: model,
});
}
});
setModelOptions(Array.from(modelMap.values()));
}, [originModelOptions, inputs.models]);
useEffect(() => {
fetchModels().then();
@@ -530,7 +539,7 @@ const EditChannel = (props) => {
handleInputChange('key', value);
}}
value={inputs.key}
style={{ minHeight: 150, fontFamily: 'JetBrains Mono, Consolas' }}
autosize={{ minRows: 6, maxRows: 6 }}
autoComplete='new-password'
className="!rounded-lg"
/>
@@ -17,6 +17,8 @@ export default function GroupRatioSettings(props) {
GroupRatio: '',
UserUsableGroups: '',
GroupGroupRatio: '',
AutoGroups: '',
DefaultUseAutoGroup: false,
});
const refForm = useRef();
const [inputsRow, setInputsRow] = useState(inputs);
@@ -167,6 +169,59 @@ export default function GroupRatioSettings(props) {
/>
</Col>
</Row>
<Row gutter={16}>
<Col xs={24} sm={16}>
<Form.TextArea
label={t('自动分组auto,从第一个开始选择')}
placeholder={t('为一个 JSON 文本')}
field={'AutoGroups'}
autosize={{ minRows: 6, maxRows: 12 }}
trigger='blur'
stopValidateWithError
rules={[
{
validator: (rule, value) => {
if (!value || value.trim() === '') {
return true; // Allow empty values
}
// First check if it's valid JSON
try {
const parsed = JSON.parse(value);
// Check if it's an array
if (!Array.isArray(parsed)) {
return false;
}
// Check if every element is a string
return parsed.every(item => typeof item === 'string');
} catch (error) {
return false;
}
},
message: t('必须是有效的 JSON 字符串数组,例如:["g1","g2"]'),
},
]}
onChange={(value) =>
setInputs({ ...inputs, AutoGroups: value })
}
/>
</Col>
</Row>
<Row gutter={16}>
<Col span={16}>
<Form.Switch
label={t(
'创建令牌默认选择auto分组,初始令牌也将设为auto(否则留空,为用户默认分组)',
)}
field={'DefaultUseAutoGroup'}
onChange={(value) =>
setInputs({ ...inputs, DefaultUseAutoGroup: value })
}
/>
</Col>
</Row>
</Form.Section>
</Form>
<Button onClick={onSubmit}>{t('保存分组倍率设置')}</Button>
@@ -25,6 +25,7 @@ export default function ModelRatioSettings(props) {
ModelRatio: '',
CacheRatio: '',
CompletionRatio: '',
ExposeRatioEnabled: false,
});
const refForm = useRef();
const [inputsRow, setInputsRow] = useState(inputs);
@@ -206,6 +207,17 @@ export default function ModelRatioSettings(props) {
/>
</Col>
</Row>
<Row gutter={16}>
<Col span={16}>
<Form.Switch
label={t('暴露倍率接口')}
field={'ExposeRatioEnabled'}
onChange={(value) =>
setInputs({ ...inputs, ExposeRatioEnabled: value })
}
/>
</Col>
</Row>
</Form.Section>
</Form>
<Space>
@@ -0,0 +1,503 @@
import React, { useState, useCallback, useMemo } from 'react';
import {
Button,
Table,
Tag,
Empty,
Checkbox,
Form,
Input,
} from '@douyinfe/semi-ui';
import { IconSearch } from '@douyinfe/semi-icons';
import {
RefreshCcw,
CheckSquare,
} from 'lucide-react';
import { API, showError, showSuccess, showWarning, stringToColor } from '../../../helpers';
import { DEFAULT_ENDPOINT } from '../../../constants';
import { useTranslation } from 'react-i18next';
import {
IllustrationNoResult,
IllustrationNoResultDark
} from '@douyinfe/semi-illustrations';
import ChannelSelectorModal from '../../../components/settings/ChannelSelectorModal';
export default function UpstreamRatioSync(props) {
const { t } = useTranslation();
const [modalVisible, setModalVisible] = useState(false);
const [loading, setLoading] = useState(false);
const [syncLoading, setSyncLoading] = useState(false);
// 渠道选择相关
const [allChannels, setAllChannels] = useState([]);
const [selectedChannelIds, setSelectedChannelIds] = useState([]);
// 渠道端点配置
const [channelEndpoints, setChannelEndpoints] = useState({}); // { channelId: endpoint }
// 差异数据和测试结果
const [differences, setDifferences] = useState({});
const [resolutions, setResolutions] = useState({});
// 是否已经执行过同步
const [hasSynced, setHasSynced] = useState(false);
// 分页相关状态
const [currentPage, setCurrentPage] = useState(1);
const [pageSize, setPageSize] = useState(10);
// 搜索相关状态
const [searchKeyword, setSearchKeyword] = useState('');
const fetchAllChannels = async () => {
setLoading(true);
try {
const res = await API.get('/api/ratio_sync/channels');
if (res.data.success) {
const channels = res.data.data || [];
const transferData = channels.map(channel => ({
key: channel.id,
label: channel.name,
value: channel.id,
disabled: false,
_originalData: channel,
}));
setAllChannels(transferData);
const initialEndpoints = {};
transferData.forEach(channel => {
initialEndpoints[channel.key] = DEFAULT_ENDPOINT;
});
setChannelEndpoints(initialEndpoints);
} else {
showError(res.data.message);
}
} catch (error) {
showError(t('获取渠道失败:') + error.message);
} finally {
setLoading(false);
}
};
const confirmChannelSelection = () => {
const selected = allChannels
.filter(ch => selectedChannelIds.includes(ch.value))
.map(ch => ch._originalData);
if (selected.length === 0) {
showWarning(t('请至少选择一个渠道'));
return;
}
setModalVisible(false);
fetchRatiosFromChannels(selected);
};
const fetchRatiosFromChannels = async (channelList) => {
setSyncLoading(true);
const payload = {
channel_ids: channelList.map(ch => parseInt(ch.id)),
timeout: 10,
};
try {
const res = await API.post('/api/ratio_sync/fetch', payload);
if (!res.data.success) {
showError(res.data.message || t('后端请求失败'));
setSyncLoading(false);
return;
}
const { differences = {}, test_results = [] } = res.data.data;
const errorResults = test_results.filter(r => r.status === 'error');
if (errorResults.length > 0) {
showWarning(t('部分渠道测试失败:') + errorResults.map(r => `${r.name}: ${r.error}`).join(', '));
}
setDifferences(differences);
setResolutions({});
setHasSynced(true);
if (Object.keys(differences).length === 0) {
showSuccess(t('已与上游倍率完全一致,无需同步'));
}
} catch (e) {
showError(t('请求后端接口失败:') + e.message);
} finally {
setSyncLoading(false);
}
};
const selectValue = (model, ratioType, value) => {
setResolutions(prev => ({
...prev,
[model]: {
...prev[model],
[ratioType]: value,
},
}));
};
const applySync = async () => {
const currentRatios = {
ModelRatio: JSON.parse(props.options.ModelRatio || '{}'),
CompletionRatio: JSON.parse(props.options.CompletionRatio || '{}'),
CacheRatio: JSON.parse(props.options.CacheRatio || '{}'),
ModelPrice: JSON.parse(props.options.ModelPrice || '{}'),
};
Object.entries(resolutions).forEach(([model, ratios]) => {
Object.entries(ratios).forEach(([ratioType, value]) => {
const optionKey = ratioType
.split('_')
.map(word => word.charAt(0).toUpperCase() + word.slice(1))
.join('');
currentRatios[optionKey][model] = parseFloat(value);
});
});
setLoading(true);
try {
const updates = Object.entries(currentRatios).map(([key, value]) =>
API.put('/api/option/', {
key,
value: JSON.stringify(value, null, 2),
})
);
const results = await Promise.all(updates);
if (results.every(res => res.data.success)) {
showSuccess(t('同步成功'));
props.refresh();
setDifferences(prevDifferences => {
const newDifferences = { ...prevDifferences };
Object.entries(resolutions).forEach(([model, ratios]) => {
Object.keys(ratios).forEach(ratioType => {
if (newDifferences[model] && newDifferences[model][ratioType]) {
delete newDifferences[model][ratioType];
if (Object.keys(newDifferences[model]).length === 0) {
delete newDifferences[model];
}
}
});
});
return newDifferences;
});
setResolutions({});
} else {
showError(t('部分保存失败'));
}
} catch (error) {
showError(t('保存失败'));
} finally {
setLoading(false);
}
};
const getCurrentPageData = (dataSource) => {
const startIndex = (currentPage - 1) * pageSize;
const endIndex = startIndex + pageSize;
return dataSource.slice(startIndex, endIndex);
};
const renderHeader = () => (
<div className="flex flex-col w-full">
<div className="flex flex-col md:flex-row justify-between items-center gap-4 w-full">
<div className="flex gap-2 w-full md:w-auto order-2 md:order-1">
<Button
icon={<RefreshCcw size={14} />}
className="!rounded-full w-full md:w-auto mt-2"
onClick={() => {
setModalVisible(true);
fetchAllChannels();
}}
>
{t('选择同步渠道')}
</Button>
{(() => {
const hasSelections = Object.keys(resolutions).length > 0;
return (
<Button
icon={<CheckSquare size={14} />}
type='secondary'
onClick={applySync}
disabled={!hasSelections}
className="!rounded-full w-full md:w-auto mt-2"
>
{t('应用同步')}
</Button>
);
})()}
<Input
prefix={<IconSearch size={14} />}
placeholder={t('搜索模型名称')}
value={searchKeyword}
onChange={setSearchKeyword}
className="!rounded-full w-full md:w-64 mt-2"
showClear
/>
</div>
</div>
</div>
);
const renderDifferenceTable = () => {
const dataSource = useMemo(() => {
const tmp = [];
Object.entries(differences).forEach(([model, ratioTypes]) => {
Object.entries(ratioTypes).forEach(([ratioType, diff]) => {
tmp.push({
key: `${model}_${ratioType}`,
model,
ratioType,
current: diff.current,
upstreams: diff.upstreams,
});
});
});
return tmp;
}, [differences]);
const filteredDataSource = useMemo(() => {
if (!searchKeyword.trim()) {
return dataSource;
}
const keyword = searchKeyword.toLowerCase().trim();
return dataSource.filter(item =>
item.model.toLowerCase().includes(keyword)
);
}, [dataSource, searchKeyword]);
const upstreamNames = useMemo(() => {
const set = new Set();
filteredDataSource.forEach((row) => {
Object.keys(row.upstreams || {}).forEach((name) => set.add(name));
});
return Array.from(set);
}, [filteredDataSource]);
if (filteredDataSource.length === 0) {
return (
<Empty
image={<IllustrationNoResult style={{ width: 150, height: 150 }} />}
darkModeImage={<IllustrationNoResultDark style={{ width: 150, height: 150 }} />}
description={
searchKeyword.trim()
? t('未找到匹配的模型')
: (Object.keys(differences).length === 0 ?
(hasSynced ? t('暂无差异化倍率显示') : t('请先选择同步渠道'))
: t('请先选择同步渠道'))
}
style={{ padding: 30 }}
/>
);
}
const columns = [
{
title: t('模型'),
dataIndex: 'model',
fixed: 'left',
},
{
title: t('倍率类型'),
dataIndex: 'ratioType',
render: (text) => {
const typeMap = {
model_ratio: t('模型倍率'),
completion_ratio: t('补全倍率'),
cache_ratio: t('缓存倍率'),
model_price: t('固定价格'),
};
return <Tag color={stringToColor(text)} shape="circle">{typeMap[text] || text}</Tag>;
},
},
{
title: t('当前值'),
dataIndex: 'current',
render: (text) => (
<Tag color={text !== null && text !== undefined ? 'blue' : 'default'} shape="circle">
{text !== null && text !== undefined ? text : t('未设置')}
</Tag>
),
},
...upstreamNames.map((upName) => {
const channelStats = (() => {
let selectableCount = 0;
let selectedCount = 0;
filteredDataSource.forEach((row) => {
const upstreamVal = row.upstreams?.[upName];
if (upstreamVal !== null && upstreamVal !== undefined && upstreamVal !== 'same') {
selectableCount++;
const isSelected = resolutions[row.model]?.[row.ratioType] === upstreamVal;
if (isSelected) {
selectedCount++;
}
}
});
return {
selectableCount,
selectedCount,
allSelected: selectableCount > 0 && selectedCount === selectableCount,
partiallySelected: selectedCount > 0 && selectedCount < selectableCount,
hasSelectableItems: selectableCount > 0
};
})();
const handleBulkSelect = (checked) => {
setResolutions((prev) => {
const newRes = { ...prev };
filteredDataSource.forEach((row) => {
const upstreamVal = row.upstreams?.[upName];
if (upstreamVal !== null && upstreamVal !== undefined && upstreamVal !== 'same') {
if (checked) {
if (!newRes[row.model]) newRes[row.model] = {};
newRes[row.model][row.ratioType] = upstreamVal;
} else {
if (newRes[row.model]) {
delete newRes[row.model][row.ratioType];
if (Object.keys(newRes[row.model]).length === 0) {
delete newRes[row.model];
}
}
}
}
});
return newRes;
});
};
return {
title: channelStats.hasSelectableItems ? (
<Checkbox
checked={channelStats.allSelected}
indeterminate={channelStats.partiallySelected}
onChange={(e) => handleBulkSelect(e.target.checked)}
>
{upName}
</Checkbox>
) : (
<span>{upName}</span>
),
dataIndex: upName,
render: (_, record) => {
const upstreamVal = record.upstreams?.[upName];
if (upstreamVal === null || upstreamVal === undefined) {
return <Tag color="default" shape="circle">{t('未设置')}</Tag>;
}
if (upstreamVal === 'same') {
return <Tag color="blue" shape="circle">{t('与本地相同')}</Tag>;
}
const isSelected = resolutions[record.model]?.[record.ratioType] === upstreamVal;
return (
<Checkbox
checked={isSelected}
onChange={(e) => {
const isChecked = e.target.checked;
if (isChecked) {
selectValue(record.model, record.ratioType, upstreamVal);
} else {
setResolutions((prev) => {
const newRes = { ...prev };
if (newRes[record.model]) {
delete newRes[record.model][record.ratioType];
if (Object.keys(newRes[record.model]).length === 0) {
delete newRes[record.model];
}
}
return newRes;
});
}
}}
>
{upstreamVal}
</Checkbox>
);
},
};
}),
];
return (
<Table
columns={columns}
dataSource={getCurrentPageData(filteredDataSource)}
pagination={{
currentPage: currentPage,
pageSize: pageSize,
total: filteredDataSource.length,
showSizeChanger: true,
showQuickJumper: true,
formatPageText: (page) => t('第 {{start}} - {{end}} 条,共 {{total}} 条', {
start: page.currentStart,
end: page.currentEnd,
total: filteredDataSource.length,
}),
pageSizeOptions: ['5', '10', '20', '50'],
onChange: (page, size) => {
setCurrentPage(page);
setPageSize(size);
},
onShowSizeChange: (current, size) => {
setCurrentPage(1);
setPageSize(size);
}
}}
scroll={{ x: 'max-content' }}
size='middle'
loading={loading || syncLoading}
className="rounded-xl overflow-hidden"
/>
);
};
const updateChannelEndpoint = useCallback((channelId, endpoint) => {
setChannelEndpoints(prev => ({ ...prev, [channelId]: endpoint }));
}, []);
return (
<>
<Form.Section text={renderHeader()}>
{renderDifferenceTable()}
</Form.Section>
<ChannelSelectorModal
t={t}
visible={modalVisible}
onCancel={() => setModalVisible(false)}
onOk={confirmChannelSelection}
allChannels={allChannels}
selectedChannelIds={selectedChannelIds}
setSelectedChannelIds={setSelectedChannelIds}
channelEndpoints={channelEndpoints}
updateChannelEndpoint={updateChannelEndpoint}
/>
</>
);
}
+7 -1
View File
@@ -10,6 +10,7 @@ import OperationSetting from '../../components/settings/OperationSetting.js';
import RateLimitSetting from '../../components/settings/RateLimitSetting.js';
import ModelSetting from '../../components/settings/ModelSetting.js';
import DashboardSetting from '../../components/settings/DashboardSetting.js';
import RatioSetting from '../../components/settings/RatioSetting.js';
const Setting = () => {
const { t } = useTranslation();
@@ -24,6 +25,11 @@ const Setting = () => {
content: <OperationSetting />,
itemKey: 'operation',
});
panes.push({
tab: t('倍率设置'),
content: <RatioSetting />,
itemKey: 'ratio',
});
panes.push({
tab: t('速率限制设置'),
content: <RateLimitSetting />,
@@ -45,7 +51,7 @@ const Setting = () => {
itemKey: 'other',
});
panes.push({
tab: t('仪表盘置'),
tab: t('仪表盘置'),
content: <DashboardSetting />,
itemKey: 'dashboard',
});
+233 -129
View File
@@ -1,4 +1,4 @@
import React, { useEffect, useState } from 'react';
import React, { useEffect, useState, useContext } from 'react';
import { useNavigate } from 'react-router-dom';
import {
API,
@@ -7,7 +7,7 @@ import {
showSuccess,
timestamp2string,
renderGroupOption,
renderQuotaWithPrompt
renderQuotaWithPrompt,
} from '../../helpers';
import {
AutoComplete,
@@ -37,11 +37,13 @@ import {
IconPlusCircle,
} from '@douyinfe/semi-icons';
import { useTranslation } from 'react-i18next';
import { StatusContext } from '../../context/Status';
const { Text, Title } = Typography;
const EditToken = (props) => {
const { t } = useTranslation();
const [statusState, statusDispatch] = useContext(StatusContext);
const [isEdit, setIsEdit] = useState(false);
const [loading, setLoading] = useState(isEdit);
const originInputs = {
@@ -119,7 +121,19 @@ const EditToken = (props) => {
value: group,
ratio: info.ratio,
}));
if (statusState?.status?.default_use_auto_group) {
// if contain auto, add it to the first position
if (localGroupOptions.some((group) => group.value === 'auto')) {
// 排序
localGroupOptions.sort((a, b) => (a.value === 'auto' ? -1 : 1));
} else {
localGroupOptions.unshift({ label: t('自动选择'), value: 'auto' });
}
}
setGroups(localGroupOptions);
if (statusState?.status?.default_use_auto_group) {
setInputs({ ...inputs, group: 'auto' });
}
} else {
showError(t(message));
}
@@ -268,32 +282,37 @@ const EditToken = (props) => {
placement={isEdit ? 'right' : 'left'}
title={
<Space>
{isEdit ?
<Tag color="blue" shape="circle">{t('更新')}</Tag> :
<Tag color="green" shape="circle">{t('新')}</Tag>
}
<Title heading={4} className="m-0">
{isEdit ? (
<Tag color='blue' shape='circle'>
{t('新')}
</Tag>
) : (
<Tag color='green' shape='circle'>
{t('新建')}
</Tag>
)}
<Title heading={4} className='m-0'>
{isEdit ? t('更新令牌信息') : t('创建新的令牌')}
</Title>
</Space>
}
headerStyle={{
borderBottom: '1px solid var(--semi-color-border)',
padding: '24px'
padding: '24px',
}}
bodyStyle={{
backgroundColor: 'var(--semi-color-bg-0)',
padding: '0'
padding: '0',
}}
visible={props.visiable}
width={isMobile() ? '100%' : 600}
footer={
<div className="flex justify-end bg-white">
<div className='flex justify-end bg-white'>
<Space>
<Button
theme="solid"
size="large"
className="!rounded-full"
theme='solid'
size='large'
className='!rounded-full'
onClick={submit}
icon={<IconSave />}
loading={loading}
@@ -301,10 +320,10 @@ const EditToken = (props) => {
{t('提交')}
</Button>
<Button
theme="light"
size="large"
className="!rounded-full"
type="primary"
theme='light'
size='large'
className='!rounded-full'
type='primary'
onClick={handleCancel}
icon={<IconClose />}
>
@@ -317,87 +336,107 @@ const EditToken = (props) => {
onCancel={() => handleCancel()}
>
<Spin spinning={loading}>
<div className="p-6">
<Card className="!rounded-2xl shadow-sm border-0 mb-6">
<div className="flex items-center mb-4 p-6 rounded-xl" style={{
background: 'linear-gradient(135deg, #1e3a8a 0%, #2563eb 50%, #3b82f6 100%)',
position: 'relative'
}}>
<div className="absolute inset-0 overflow-hidden">
<div className="absolute -top-10 -right-10 w-40 h-40 bg-white opacity-5 rounded-full"></div>
<div className="absolute -bottom-8 -left-8 w-24 h-24 bg-white opacity-10 rounded-full"></div>
<div className='p-6'>
<Card className='!rounded-2xl shadow-sm border-0 mb-6'>
<div
className='flex items-center mb-4 p-6 rounded-xl'
style={{
background:
'linear-gradient(135deg, #1e3a8a 0%, #2563eb 50%, #3b82f6 100%)',
position: 'relative',
}}
>
<div className='absolute inset-0 overflow-hidden'>
<div className='absolute -top-10 -right-10 w-40 h-40 bg-white opacity-5 rounded-full'></div>
<div className='absolute -bottom-8 -left-8 w-24 h-24 bg-white opacity-10 rounded-full'></div>
</div>
<div className="w-10 h-10 rounded-full bg-white/20 flex items-center justify-center mr-4 relative">
<IconPlusCircle size="large" style={{ color: '#ffffff' }} />
<div className='w-10 h-10 rounded-full bg-white/20 flex items-center justify-center mr-4 relative'>
<IconPlusCircle size='large' style={{ color: '#ffffff' }} />
</div>
<div className="relative">
<Text style={{ color: '#ffffff' }} className="text-lg font-medium">{t('基本信息')}</Text>
<div style={{ color: '#ffffff' }} className="text-sm opacity-80">{t('设置令牌的基本信息')}</div>
<div className='relative'>
<Text
style={{ color: '#ffffff' }}
className='text-lg font-medium'
>
{t('基本信息')}
</Text>
<div
style={{ color: '#ffffff' }}
className='text-sm opacity-80'
>
{t('设置令牌的基本信息')}
</div>
</div>
</div>
<div className="space-y-4">
<div className='space-y-4'>
<div>
<Text strong className="block mb-2">{t('名称')}</Text>
<Text strong className='block mb-2'>
{t('名称')}
</Text>
<Input
placeholder={t('请输入名称')}
onChange={(value) => handleInputChange('name', value)}
value={name}
autoComplete="new-password"
size="large"
className="!rounded-lg"
autoComplete='new-password'
size='large'
className='!rounded-lg'
showClear
required
/>
</div>
<div>
<Text strong className="block mb-2">{t('过期时间')}</Text>
<div className="mb-2">
<Text strong className='block mb-2'>
{t('过期时间')}
</Text>
<div className='mb-2'>
<DatePicker
placeholder={t('请选择过期时间')}
onChange={(value) => handleInputChange('expired_time', value)}
onChange={(value) =>
handleInputChange('expired_time', value)
}
value={expired_time}
autoComplete="new-password"
type="dateTime"
className="w-full !rounded-lg"
size="large"
autoComplete='new-password'
type='dateTime'
className='w-full !rounded-lg'
size='large'
prefix={<IconCalendar />}
/>
</div>
<div className="flex flex-wrap gap-2">
<div className='flex flex-wrap gap-2'>
<Button
theme="light"
type="primary"
theme='light'
type='primary'
onClick={() => setExpiredTime(0, 0, 0, 0)}
className="!rounded-full"
className='!rounded-full'
>
{t('永不过期')}
</Button>
<Button
theme="light"
type="tertiary"
theme='light'
type='tertiary'
onClick={() => setExpiredTime(0, 0, 1, 0)}
className="!rounded-full"
className='!rounded-full'
icon={<IconClock />}
>
{t('一小时')}
</Button>
<Button
theme="light"
type="tertiary"
theme='light'
type='tertiary'
onClick={() => setExpiredTime(0, 1, 0, 0)}
className="!rounded-full"
className='!rounded-full'
icon={<IconCalendar />}
>
{t('一天')}
</Button>
<Button
theme="light"
type="tertiary"
theme='light'
type='tertiary'
onClick={() => setExpiredTime(1, 0, 0, 0)}
className="!rounded-full"
className='!rounded-full'
icon={<IconCalendar />}
>
{t('一个月')}
@@ -407,44 +446,62 @@ const EditToken = (props) => {
</div>
</Card>
<Card className="!rounded-2xl shadow-sm border-0 mb-6">
<div className="flex items-center mb-4 p-6 rounded-xl" style={{
background: 'linear-gradient(135deg, #065f46 0%, #059669 50%, #10b981 100%)',
position: 'relative'
}}>
<div className="absolute inset-0 overflow-hidden">
<div className="absolute -top-10 -right-10 w-40 h-40 bg-white opacity-5 rounded-full"></div>
<div className="absolute -bottom-8 -left-8 w-24 h-24 bg-white opacity-10 rounded-full"></div>
<Card className='!rounded-2xl shadow-sm border-0 mb-6'>
<div
className='flex items-center mb-4 p-6 rounded-xl'
style={{
background:
'linear-gradient(135deg, #065f46 0%, #059669 50%, #10b981 100%)',
position: 'relative',
}}
>
<div className='absolute inset-0 overflow-hidden'>
<div className='absolute -top-10 -right-10 w-40 h-40 bg-white opacity-5 rounded-full'></div>
<div className='absolute -bottom-8 -left-8 w-24 h-24 bg-white opacity-10 rounded-full'></div>
</div>
<div className="w-10 h-10 rounded-full bg-white/20 flex items-center justify-center mr-4 relative">
<IconCreditCard size="large" style={{ color: '#ffffff' }} />
<div className='w-10 h-10 rounded-full bg-white/20 flex items-center justify-center mr-4 relative'>
<IconCreditCard size='large' style={{ color: '#ffffff' }} />
</div>
<div className="relative">
<Text style={{ color: '#ffffff' }} className="text-lg font-medium">{t('额度设置')}</Text>
<div style={{ color: '#ffffff' }} className="text-sm opacity-80">{t('设置令牌可用额度和数量')}</div>
<div className='relative'>
<Text
style={{ color: '#ffffff' }}
className='text-lg font-medium'
>
{t('额度设置')}
</Text>
<div
style={{ color: '#ffffff' }}
className='text-sm opacity-80'
>
{t('设置令牌可用额度和数量')}
</div>
</div>
</div>
<Banner
type="warning"
description={t('注意,令牌的额度仅用于限制令牌本身的最大额度使用量,实际的使用受到账户的剩余额度限制。')}
className="mb-4 !rounded-lg"
type='warning'
description={t(
'注意,令牌的额度仅用于限制令牌本身的最大额度使用量,实际的使用受到账户的剩余额度限制。',
)}
className='mb-4 !rounded-lg'
/>
<div className="space-y-4">
<div className='space-y-4'>
<div>
<div className="flex justify-between mb-2">
<div className='flex justify-between mb-2'>
<Text strong>{t('额度')}</Text>
<Text type="tertiary">{renderQuotaWithPrompt(remain_quota)}</Text>
<Text type='tertiary'>
{renderQuotaWithPrompt(remain_quota)}
</Text>
</div>
<AutoComplete
placeholder={t('请输入额度')}
onChange={(value) => handleInputChange('remain_quota', value)}
value={remain_quota}
autoComplete="new-password"
type="number"
size="large"
className="w-full !rounded-lg"
autoComplete='new-password'
type='number'
size='large'
className='w-full !rounded-lg'
prefix={<IconCreditCard />}
data={[
{ value: 500000, label: '1$' },
@@ -460,16 +517,18 @@ const EditToken = (props) => {
{!isEdit && (
<div>
<Text strong className="block mb-2">{t('新建数量')}</Text>
<Text strong className='block mb-2'>
{t('新建数量')}
</Text>
<AutoComplete
placeholder={t('请选择或输入创建令牌的数量')}
onChange={(value) => handleTokenCountChange(value)}
onSelect={(value) => handleTokenCountChange(value)}
value={tokenCount.toString()}
autoComplete="off"
type="number"
className="w-full !rounded-lg"
size="large"
autoComplete='off'
type='number'
className='w-full !rounded-lg'
size='large'
prefix={<IconPlusCircle />}
data={[
{ value: 10, label: t('10个') },
@@ -482,12 +541,12 @@ const EditToken = (props) => {
</div>
)}
<div className="flex justify-end">
<div className='flex justify-end'>
<Button
theme="light"
type={unlimited_quota ? "danger" : "warning"}
theme='light'
type={unlimited_quota ? 'danger' : 'warning'}
onClick={setUnlimitedQuota}
className="!rounded-full"
className='!rounded-full'
>
{unlimited_quota ? t('取消无限额度') : t('设为无限额度')}
</Button>
@@ -495,92 +554,137 @@ const EditToken = (props) => {
</div>
</Card>
<Card className="!rounded-2xl shadow-sm border-0 mb-6">
<div className="flex items-center mb-4 p-6 rounded-xl" style={{
background: 'linear-gradient(135deg, #4c1d95 0%, #6d28d9 50%, #7c3aed 100%)',
position: 'relative'
}}>
<div className="absolute inset-0 overflow-hidden">
<div className="absolute -top-10 -right-10 w-40 h-40 bg-white opacity-5 rounded-full"></div>
<div className="absolute -bottom-8 -left-8 w-24 h-24 bg-white opacity-10 rounded-full"></div>
<Card className='!rounded-2xl shadow-sm border-0 mb-6'>
<div
className='flex items-center mb-4 p-6 rounded-xl'
style={{
background:
'linear-gradient(135deg, #4c1d95 0%, #6d28d9 50%, #7c3aed 100%)',
position: 'relative',
}}
>
<div className='absolute inset-0 overflow-hidden'>
<div className='absolute -top-10 -right-10 w-40 h-40 bg-white opacity-5 rounded-full'></div>
<div className='absolute -bottom-8 -left-8 w-24 h-24 bg-white opacity-10 rounded-full'></div>
</div>
<div className="w-10 h-10 rounded-full bg-white/20 flex items-center justify-center mr-4 relative">
<IconLink size="large" style={{ color: '#ffffff' }} />
<div className='w-10 h-10 rounded-full bg-white/20 flex items-center justify-center mr-4 relative'>
<IconLink size='large' style={{ color: '#ffffff' }} />
</div>
<div className="relative">
<Text style={{ color: '#ffffff' }} className="text-lg font-medium">{t('访问限制')}</Text>
<div style={{ color: '#ffffff' }} className="text-sm opacity-80">{t('设置令牌的访问限制')}</div>
<div className='relative'>
<Text
style={{ color: '#ffffff' }}
className='text-lg font-medium'
>
{t('访问限制')}
</Text>
<div
style={{ color: '#ffffff' }}
className='text-sm opacity-80'
>
{t('设置令牌的访问限制')}
</div>
</div>
</div>
<div className="space-y-4">
<div className='space-y-4'>
<div>
<Text strong className="block mb-2">{t('IP白名单')}</Text>
<Text strong className='block mb-2'>
{t('IP白名单')}
</Text>
<TextArea
placeholder={t('允许的IP,一行一个,不填写则不限制')}
onChange={(value) => handleInputChange('allow_ips', value)}
value={inputs.allow_ips}
style={{ fontFamily: 'JetBrains Mono, Consolas' }}
className="!rounded-lg"
className='!rounded-lg'
rows={4}
/>
<Text type="tertiary" className="mt-1 block text-xs">{t('请勿过度信任此功能,IP可能被伪造')}</Text>
<Text type='tertiary' className='mt-1 block text-xs'>
{t('请勿过度信任此功能,IP可能被伪造')}
</Text>
</div>
<div>
<div className="flex items-center mb-2">
<div className='flex items-center mb-2'>
<Checkbox
checked={model_limits_enabled}
onChange={(e) => handleInputChange('model_limits_enabled', e.target.checked)}
onChange={(e) =>
handleInputChange(
'model_limits_enabled',
e.target.checked,
)
}
>
<Text strong>{t('模型限制')}</Text>
</Checkbox>
</div>
<Select
placeholder={model_limits_enabled ? t('请选择该渠道所支持的模型') : t('勾选启用模型限制后可选择')}
placeholder={
model_limits_enabled
? t('请选择该渠道所支持的模型')
: t('勾选启用模型限制后可选择')
}
onChange={(value) => handleInputChange('model_limits', value)}
value={inputs.model_limits}
multiple
size="large"
className="w-full !rounded-lg"
size='large'
className='w-full !rounded-lg'
prefix={<IconServer />}
optionList={models}
disabled={!model_limits_enabled}
maxTagCount={3}
/>
<Text type="tertiary" className="mt-1 block text-xs">{t('非必要,不建议启用模型限制')}</Text>
<Text type='tertiary' className='mt-1 block text-xs'>
{t('非必要,不建议启用模型限制')}
</Text>
</div>
</div>
</Card>
<Card className="!rounded-2xl shadow-sm border-0">
<div className="flex items-center mb-4 p-6 rounded-xl" style={{
background: 'linear-gradient(135deg, #92400e 0%, #d97706 50%, #f59e0b 100%)',
position: 'relative'
}}>
<div className="absolute inset-0 overflow-hidden">
<div className="absolute -top-10 -right-10 w-40 h-40 bg-white opacity-5 rounded-full"></div>
<div className="absolute -bottom-8 -left-8 w-24 h-24 bg-white opacity-10 rounded-full"></div>
<Card className='!rounded-2xl shadow-sm border-0'>
<div
className='flex items-center mb-4 p-6 rounded-xl'
style={{
background:
'linear-gradient(135deg, #92400e 0%, #d97706 50%, #f59e0b 100%)',
position: 'relative',
}}
>
<div className='absolute inset-0 overflow-hidden'>
<div className='absolute -top-10 -right-10 w-40 h-40 bg-white opacity-5 rounded-full'></div>
<div className='absolute -bottom-8 -left-8 w-24 h-24 bg-white opacity-10 rounded-full'></div>
</div>
<div className="w-10 h-10 rounded-full bg-white/20 flex items-center justify-center mr-4 relative">
<IconUserGroup size="large" style={{ color: '#ffffff' }} />
<div className='w-10 h-10 rounded-full bg-white/20 flex items-center justify-center mr-4 relative'>
<IconUserGroup size='large' style={{ color: '#ffffff' }} />
</div>
<div className="relative">
<Text style={{ color: '#ffffff' }} className="text-lg font-medium">{t('分组信息')}</Text>
<div style={{ color: '#ffffff' }} className="text-sm opacity-80">{t('设置令牌的分组')}</div>
<div className='relative'>
<Text
style={{ color: '#ffffff' }}
className='text-lg font-medium'
>
{t('分组信息')}
</Text>
<div
style={{ color: '#ffffff' }}
className='text-sm opacity-80'
>
{t('设置令牌的分组')}
</div>
</div>
</div>
<div>
<Text strong className="block mb-2">{t('令牌分组')}</Text>
<Text strong className='block mb-2'>
{t('令牌分组')}
</Text>
{groups.length > 0 ? (
<Select
placeholder={t('令牌分组,默认为用户的分组')}
onChange={(value) => handleInputChange('group', value)}
renderOptionItem={renderGroupOption}
value={inputs.group}
size="large"
className="w-full !rounded-lg"
size='large'
className='w-full !rounded-lg'
prefix={<IconUserGroup />}
optionList={groups}
/>
@@ -588,8 +692,8 @@ const EditToken = (props) => {
<Select
placeholder={t('管理员未设置用户可选分组')}
disabled={true}
size="large"
className="w-full !rounded-lg"
size='large'
className='w-full !rounded-lg'
prefix={<IconUserGroup />}
/>
)}

Some files were not shown because too many files have changed in this diff Show More