Compare commits

...

38 Commits

Author SHA1 Message Date
Calcium-Ion 705e5edd80 Merge pull request #1522 from QuantumNous/support-deepseek-claude
feat: support deepseek claude format (convert)
2025-08-07 19:04:05 +08:00
Calcium-Ion 65b12d7755 Merge pull request #1521 from QuantumNous/support-qwen-claude
feat: support qwen claude format
2025-08-07 19:03:40 +08:00
CaIon b8b59a134e feat: support deepseek claude format (convert) 2025-08-07 19:01:49 +08:00
CaIon d37af13b33 feat: support qwen claude format 2025-08-07 18:32:31 +08:00
IcedTangerine 76753cea7d Merge pull request #1519 from feitianbubu/pr/fix-qwen3-thinking-test
feat: enable thinking mode on ali thinking model
2025-08-07 17:39:27 +08:00
CaIon c4666934be Revert "feat: update Usage struct to support dynamic token handling with ceil function #1503"
This reverts commit 97b8d7de9e.
2025-08-07 16:22:40 +08:00
CaIon a4b02107dd feat: update MaxTokens handling 2025-08-07 16:15:59 +08:00
CaIon 97b8d7de9e feat: update Usage struct to support dynamic token handling with ceil function #1503 2025-08-07 15:40:12 +08:00
feitianbubu ea7fd9875b feat: enable thinking mode on ali thinking model 2025-08-07 11:59:54 +08:00
Xyfacai 15c11bfe51 refactor: 调整模型匹配 2025-08-06 20:09:22 +08:00
Xyfacai 423ceae515 fix: error code 显示问题 2025-08-06 19:40:26 +08:00
CaIon 2e41362f2e fix: update budget calculation logic in relay-gemini to use clamping function 2025-08-06 16:25:48 +08:00
CaIon 6960a06322 feat: enhance ThinkingAdaptor with effort-based budget clamping and extra body handling 2025-08-06 16:20:38 +08:00
CaIon 4f6d16e365 feat: add reasoning support for Openrouter requests with "-thinking" suffix 2025-08-06 12:50:26 +08:00
Calcium-Ion f1faa08c1e Merge pull request #1508 from wzxjohn/feature/aws_new_apikey_support
feat: support aws bedrock apikey
2025-08-06 12:04:28 +08:00
Calcium-Ion 76072de685 Merge pull request #1510 from RedwindA/fix/manual-price-edit-modelName-check
fix:修复添加模型倍率时的输入框锁定
2025-08-06 12:03:44 +08:00
Calcium-Ion e7c657ef87 Merge pull request #1511 from neotf/feat-05
feat: add support for claude-opus-4-1 model and update ratios
2025-08-06 12:03:33 +08:00
Calcium-Ion 421752497a Merge pull request #1509 from QuantumNous/responses-input-cache-token
fix: responses cache token 未计费
2025-08-06 11:22:14 +08:00
neotf c9bcdc89f0 feat: add support for claude-opus-4-1 model and update ratios 2025-08-06 00:58:46 +08:00
RedwindA b0dc31c414 fix(web): 修复模型倍率设置中添加新模型时输入框锁定的问题 2025-08-05 23:18:42 +08:00
creamlike1024 02fccf0330 fix: responses 流 cache token 未计费 2025-08-05 23:08:08 +08:00
wzxjohn d31027d5c7 feat: support aws bedrock apikey 2025-08-05 23:01:30 +08:00
creamlike1024 2d226a813e fix: responses cache token 未计费 2025-08-05 22:56:27 +08:00
Calcium-Ion 2286ec0641 Merge pull request #1507 from QuantumNous/multi-key-manage
feat: implement channel-specific locking for thread-safe polling
2025-08-05 20:40:26 +08:00
CaIon f3a961f071 fix: reorder request URL handling for relay formats in Adaptor 2025-08-05 20:40:00 +08:00
CaIon 755acc6191 feat: implement channel-specific locking for thread-safe polling 2025-08-04 20:44:19 +08:00
Calcium-Ion 3c5128a671 Merge pull request #1499 from QuantumNous/multi-key-manage
feat: improve layout and pagination handling in MultiKeyManageModal
2025-08-04 20:17:22 +08:00
CaIon 5e47da1a8e feat: improve layout and pagination handling in MultiKeyManageModal 2025-08-04 20:16:51 +08:00
Calcium-Ion 0914d5ec36 Merge pull request #1487 from seefs001/feature/2fa
feat: implement two-factor authentication (2FA) support with user login and settings integration
2025-08-04 19:54:31 +08:00
Calcium-Ion 711a7e34a6 Merge pull request #1498 from QuantumNous/multi-key-manage
feat: add multi-key management
2025-08-04 19:53:52 +08:00
CaIon 808c8c5bcb feat: add status filtering and bulk enable/disable functionality in multi-key management 2025-08-04 19:51:58 +08:00
CaIon 8fe43be3eb fix: correct option value for pagination in MultiKeyManageModal 2025-08-04 19:33:24 +08:00
CaIon 1ecd8b41d8 feat: enhance multi-key management with pagination and statistics 2025-08-04 17:15:32 +08:00
CaIon aeec39a198 feat: add multi-key management 2025-08-04 16:52:31 +08:00
Xyfacai bb08de0b11 fix: 修复gemini2openai 没有返回 usage 2025-08-04 09:06:57 +08:00
Seefs 2f2546dffb refactor: improve error handling and database transactions in 2FA model methods 2025-08-03 10:49:55 +08:00
Seefs b56a75cbe3 fix: coderabbit review 2025-08-03 10:41:00 +08:00
Seefs b48d3a6b40 feat: implement two-factor authentication (2FA) support with user login and settings integration 2025-08-02 14:53:28 +08:00
57 changed files with 3350 additions and 269 deletions
+150
View File
@@ -0,0 +1,150 @@
package common
import (
"crypto/rand"
"fmt"
"os"
"strconv"
"strings"
"github.com/pquerna/otp"
"github.com/pquerna/otp/totp"
)
const (
// 备用码配置
BackupCodeLength = 8 // 备用码长度
BackupCodeCount = 4 // 生成备用码数量
// 限制配置
MaxFailAttempts = 5 // 最大失败尝试次数
LockoutDuration = 300 // 锁定时间(秒)
)
// GenerateTOTPSecret 生成TOTP密钥和配置
func GenerateTOTPSecret(accountName string) (*otp.Key, error) {
issuer := Get2FAIssuer()
return totp.Generate(totp.GenerateOpts{
Issuer: issuer,
AccountName: accountName,
Period: 30,
Digits: otp.DigitsSix,
Algorithm: otp.AlgorithmSHA1,
})
}
// ValidateTOTPCode 验证TOTP验证码
func ValidateTOTPCode(secret, code string) bool {
// 清理验证码格式
cleanCode := strings.ReplaceAll(code, " ", "")
if len(cleanCode) != 6 {
return false
}
// 验证验证码
return totp.Validate(cleanCode, secret)
}
// GenerateBackupCodes 生成备用恢复码
func GenerateBackupCodes() ([]string, error) {
codes := make([]string, BackupCodeCount)
for i := 0; i < BackupCodeCount; i++ {
code, err := generateRandomBackupCode()
if err != nil {
return nil, err
}
codes[i] = code
}
return codes, nil
}
// generateRandomBackupCode 生成单个备用码
func generateRandomBackupCode() (string, error) {
const charset = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
code := make([]byte, BackupCodeLength)
for i := range code {
randomBytes := make([]byte, 1)
_, err := rand.Read(randomBytes)
if err != nil {
return "", err
}
code[i] = charset[int(randomBytes[0])%len(charset)]
}
// 格式化为 XXXX-XXXX 格式
return fmt.Sprintf("%s-%s", string(code[:4]), string(code[4:])), nil
}
// ValidateBackupCode 验证备用码格式
func ValidateBackupCode(code string) bool {
// 移除所有分隔符并转为大写
cleanCode := strings.ToUpper(strings.ReplaceAll(code, "-", ""))
if len(cleanCode) != BackupCodeLength {
return false
}
// 检查字符是否合法
for _, char := range cleanCode {
if !((char >= 'A' && char <= 'Z') || (char >= '0' && char <= '9')) {
return false
}
}
return true
}
// NormalizeBackupCode 标准化备用码格式
func NormalizeBackupCode(code string) string {
cleanCode := strings.ToUpper(strings.ReplaceAll(code, "-", ""))
if len(cleanCode) == BackupCodeLength {
return fmt.Sprintf("%s-%s", cleanCode[:4], cleanCode[4:])
}
return code
}
// HashBackupCode 对备用码进行哈希
func HashBackupCode(code string) (string, error) {
normalizedCode := NormalizeBackupCode(code)
return Password2Hash(normalizedCode)
}
// Get2FAIssuer 获取2FA发行者名称
func Get2FAIssuer() string {
return SystemName
}
// getEnvOrDefault 获取环境变量或默认值
func getEnvOrDefault(key, defaultValue string) string {
if value, exists := os.LookupEnv(key); exists {
return value
}
return defaultValue
}
// ValidateNumericCode 验证数字验证码格式
func ValidateNumericCode(code string) (string, error) {
// 移除空格
code = strings.ReplaceAll(code, " ", "")
if len(code) != 6 {
return "", fmt.Errorf("验证码必须是6位数字")
}
// 检查是否为纯数字
if _, err := strconv.Atoi(code); err != nil {
return "", fmt.Errorf("验证码只能包含数字")
}
return code, nil
}
// GenerateQRCodeData 生成二维码数据
func GenerateQRCodeData(secret, username string) string {
issuer := Get2FAIssuer()
accountName := fmt.Sprintf("%s (%s)", username, issuer)
return fmt.Sprintf("otpauth://totp/%s:%s?secret=%s&issuer=%s&digits=6&period=30",
issuer, accountName, secret, issuer)
}
-1
View File
@@ -11,7 +11,6 @@ const (
ContextKeyTokenKey ContextKey = "token_key"
ContextKeyTokenId ContextKey = "token_id"
ContextKeyTokenGroup ContextKey = "token_group"
ContextKeyTokenAllowIps ContextKey = "allow_ips"
ContextKeyTokenSpecificChannelId ContextKey = "specific_channel_id"
ContextKeyTokenModelLimitEnabled ContextKey = "token_model_limit_enabled"
ContextKeyTokenModelLimit ContextKey = "token_model_limit"
+2 -2
View File
@@ -161,7 +161,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
logInfo.ApiKey = ""
common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo))
priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens))
priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.GetMaxTokens()))
if err != nil {
return testResult{
context: c,
@@ -275,7 +275,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
Quota: quota,
Content: "模型测试",
UseTimeSeconds: int(consumedTime),
IsStream: false,
IsStream: info.IsStream,
Group: info.UsingGroup,
Other: other,
})
+429
View File
@@ -71,6 +71,13 @@ func parseStatusFilter(statusParam string) int {
}
}
func clearChannelInfo(channel *model.Channel) {
if channel.ChannelInfo.IsMultiKey {
channel.ChannelInfo.MultiKeyDisabledReason = nil
channel.ChannelInfo.MultiKeyDisabledTime = nil
}
}
func GetAllChannels(c *gin.Context) {
pageInfo := common.GetPageQuery(c)
channelData := make([]*model.Channel, 0)
@@ -145,6 +152,10 @@ func GetAllChannels(c *gin.Context) {
}
}
for _, datum := range channelData {
clearChannelInfo(datum)
}
countQuery := model.DB.Model(&model.Channel{})
if statusFilter == common.ChannelStatusEnabled {
countQuery = countQuery.Where("status = ?", common.ChannelStatusEnabled)
@@ -371,6 +382,10 @@ func SearchChannels(c *gin.Context) {
pagedData := channelData[startIdx:endIdx]
for _, datum := range pagedData {
clearChannelInfo(datum)
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
@@ -394,6 +409,9 @@ func GetChannel(c *gin.Context) {
common.ApiError(c, err)
return
}
if channel != nil {
clearChannelInfo(channel)
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
@@ -827,6 +845,7 @@ func UpdateChannel(c *gin.Context) {
}
model.InitChannelCache()
channel.Key = ""
clearChannelInfo(&channel.Channel)
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
@@ -1030,3 +1049,413 @@ func CopyChannel(c *gin.Context) {
// success
c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": gin.H{"id": clone.Id}})
}
// MultiKeyManageRequest represents the request for multi-key management operations
type MultiKeyManageRequest struct {
ChannelId int `json:"channel_id"`
Action string `json:"action"` // "disable_key", "enable_key", "delete_disabled_keys", "get_key_status"
KeyIndex *int `json:"key_index,omitempty"` // for disable_key and enable_key actions
Page int `json:"page,omitempty"` // for get_key_status pagination
PageSize int `json:"page_size,omitempty"` // for get_key_status pagination
Status *int `json:"status,omitempty"` // for get_key_status filtering: 1=enabled, 2=manual_disabled, 3=auto_disabled, nil=all
}
// MultiKeyStatusResponse represents the response for key status query
type MultiKeyStatusResponse struct {
Keys []KeyStatus `json:"keys"`
Total int `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
TotalPages int `json:"total_pages"`
// Statistics
EnabledCount int `json:"enabled_count"`
ManualDisabledCount int `json:"manual_disabled_count"`
AutoDisabledCount int `json:"auto_disabled_count"`
}
type KeyStatus struct {
Index int `json:"index"`
Status int `json:"status"` // 1: enabled, 2: disabled
DisabledTime int64 `json:"disabled_time,omitempty"`
Reason string `json:"reason,omitempty"`
KeyPreview string `json:"key_preview"` // first 10 chars of key for identification
}
// ManageMultiKeys handles multi-key management operations
func ManageMultiKeys(c *gin.Context) {
request := MultiKeyManageRequest{}
err := c.ShouldBindJSON(&request)
if err != nil {
common.ApiError(c, err)
return
}
channel, err := model.GetChannelById(request.ChannelId, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "渠道不存在",
})
return
}
if !channel.ChannelInfo.IsMultiKey {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该渠道不是多密钥模式",
})
return
}
lock := model.GetChannelPollingLock(channel.Id)
lock.Lock()
defer lock.Unlock()
switch request.Action {
case "get_key_status":
keys := channel.GetKeys()
// Default pagination parameters
page := request.Page
pageSize := request.PageSize
if page <= 0 {
page = 1
}
if pageSize <= 0 {
pageSize = 50 // Default page size
}
// Statistics for all keys (unchanged by filtering)
var enabledCount, manualDisabledCount, autoDisabledCount int
// Build all key status data first
var allKeyStatusList []KeyStatus
for i, key := range keys {
status := 1 // default enabled
var disabledTime int64
var reason string
if channel.ChannelInfo.MultiKeyStatusList != nil {
if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
status = s
}
}
// Count for statistics (all keys)
switch status {
case 1:
enabledCount++
case 2:
manualDisabledCount++
case 3:
autoDisabledCount++
}
if status != 1 {
if channel.ChannelInfo.MultiKeyDisabledTime != nil {
disabledTime = channel.ChannelInfo.MultiKeyDisabledTime[i]
}
if channel.ChannelInfo.MultiKeyDisabledReason != nil {
reason = channel.ChannelInfo.MultiKeyDisabledReason[i]
}
}
// Create key preview (first 10 chars)
keyPreview := key
if len(key) > 10 {
keyPreview = key[:10] + "..."
}
allKeyStatusList = append(allKeyStatusList, KeyStatus{
Index: i,
Status: status,
DisabledTime: disabledTime,
Reason: reason,
KeyPreview: keyPreview,
})
}
// Apply status filter if specified
var filteredKeyStatusList []KeyStatus
if request.Status != nil {
for _, keyStatus := range allKeyStatusList {
if keyStatus.Status == *request.Status {
filteredKeyStatusList = append(filteredKeyStatusList, keyStatus)
}
}
} else {
filteredKeyStatusList = allKeyStatusList
}
// Calculate pagination based on filtered results
filteredTotal := len(filteredKeyStatusList)
totalPages := (filteredTotal + pageSize - 1) / pageSize
if totalPages == 0 {
totalPages = 1
}
if page > totalPages {
page = totalPages
}
// Calculate range for current page
start := (page - 1) * pageSize
end := start + pageSize
if end > filteredTotal {
end = filteredTotal
}
// Get the page data
var pageKeyStatusList []KeyStatus
if start < filteredTotal {
pageKeyStatusList = filteredKeyStatusList[start:end]
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": MultiKeyStatusResponse{
Keys: pageKeyStatusList,
Total: filteredTotal, // Total of filtered results
Page: page,
PageSize: pageSize,
TotalPages: totalPages,
EnabledCount: enabledCount, // Overall statistics
ManualDisabledCount: manualDisabledCount, // Overall statistics
AutoDisabledCount: autoDisabledCount, // Overall statistics
},
})
return
case "disable_key":
if request.KeyIndex == nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "未指定要禁用的密钥索引",
})
return
}
keyIndex := *request.KeyIndex
if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "密钥索引超出范围",
})
return
}
if channel.ChannelInfo.MultiKeyStatusList == nil {
channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
}
if channel.ChannelInfo.MultiKeyDisabledTime == nil {
channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
}
if channel.ChannelInfo.MultiKeyDisabledReason == nil {
channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
}
channel.ChannelInfo.MultiKeyStatusList[keyIndex] = 2 // disabled
err = channel.Update()
if err != nil {
common.ApiError(c, err)
return
}
model.InitChannelCache()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "密钥已禁用",
})
return
case "enable_key":
if request.KeyIndex == nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "未指定要启用的密钥索引",
})
return
}
keyIndex := *request.KeyIndex
if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "密钥索引超出范围",
})
return
}
// 从状态列表中删除该密钥的记录,使其回到默认启用状态
if channel.ChannelInfo.MultiKeyStatusList != nil {
delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex)
}
if channel.ChannelInfo.MultiKeyDisabledTime != nil {
delete(channel.ChannelInfo.MultiKeyDisabledTime, keyIndex)
}
if channel.ChannelInfo.MultiKeyDisabledReason != nil {
delete(channel.ChannelInfo.MultiKeyDisabledReason, keyIndex)
}
err = channel.Update()
if err != nil {
common.ApiError(c, err)
return
}
model.InitChannelCache()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "密钥已启用",
})
return
case "enable_all_keys":
// 清空所有禁用状态,使所有密钥回到默认启用状态
var enabledCount int
if channel.ChannelInfo.MultiKeyStatusList != nil {
enabledCount = len(channel.ChannelInfo.MultiKeyStatusList)
}
channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
err = channel.Update()
if err != nil {
common.ApiError(c, err)
return
}
model.InitChannelCache()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": fmt.Sprintf("已启用 %d 个密钥", enabledCount),
})
return
case "disable_all_keys":
// 禁用所有启用的密钥
if channel.ChannelInfo.MultiKeyStatusList == nil {
channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
}
if channel.ChannelInfo.MultiKeyDisabledTime == nil {
channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
}
if channel.ChannelInfo.MultiKeyDisabledReason == nil {
channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
}
var disabledCount int
for i := 0; i < channel.ChannelInfo.MultiKeySize; i++ {
status := 1 // default enabled
if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
status = s
}
// 只禁用当前启用的密钥
if status == 1 {
channel.ChannelInfo.MultiKeyStatusList[i] = 2 // disabled
disabledCount++
}
}
if disabledCount == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "没有可禁用的密钥",
})
return
}
err = channel.Update()
if err != nil {
common.ApiError(c, err)
return
}
model.InitChannelCache()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": fmt.Sprintf("已禁用 %d 个密钥", disabledCount),
})
return
case "delete_disabled_keys":
keys := channel.GetKeys()
var remainingKeys []string
var deletedCount int
var newStatusList = make(map[int]int)
var newDisabledTime = make(map[int]int64)
var newDisabledReason = make(map[int]string)
newIndex := 0
for i, key := range keys {
status := 1 // default enabled
if channel.ChannelInfo.MultiKeyStatusList != nil {
if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
status = s
}
}
// 只删除自动禁用(status == 3)的密钥,保留启用(status == 1)和手动禁用(status == 2)的密钥
if status == 3 {
deletedCount++
} else {
remainingKeys = append(remainingKeys, key)
// 保留非自动禁用密钥的状态信息,重新索引
if status != 1 {
newStatusList[newIndex] = status
if channel.ChannelInfo.MultiKeyDisabledTime != nil {
if t, exists := channel.ChannelInfo.MultiKeyDisabledTime[i]; exists {
newDisabledTime[newIndex] = t
}
}
if channel.ChannelInfo.MultiKeyDisabledReason != nil {
if r, exists := channel.ChannelInfo.MultiKeyDisabledReason[i]; exists {
newDisabledReason[newIndex] = r
}
}
}
newIndex++
}
}
if deletedCount == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "没有需要删除的自动禁用密钥",
})
return
}
// Update channel with remaining keys
channel.Key = strings.Join(remainingKeys, "\n")
channel.ChannelInfo.MultiKeySize = len(remainingKeys)
channel.ChannelInfo.MultiKeyStatusList = newStatusList
channel.ChannelInfo.MultiKeyDisabledTime = newDisabledTime
channel.ChannelInfo.MultiKeyDisabledReason = newDisabledReason
err = channel.Update()
if err != nil {
common.ApiError(c, err)
return
}
model.InitChannelCache()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": fmt.Sprintf("已删除 %d 个自动禁用的密钥", deletedCount),
"data": deletedCount,
})
return
default:
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "不支持的操作",
})
return
}
}
+553
View File
@@ -0,0 +1,553 @@
package controller
import (
"errors"
"fmt"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
// Setup2FARequest 设置2FA请求结构
type Setup2FARequest struct {
Code string `json:"code" binding:"required"`
}
// Verify2FARequest 验证2FA请求结构
type Verify2FARequest struct {
Code string `json:"code" binding:"required"`
}
// Setup2FAResponse 设置2FA响应结构
type Setup2FAResponse struct {
Secret string `json:"secret"`
QRCodeData string `json:"qr_code_data"`
BackupCodes []string `json:"backup_codes"`
}
// Setup2FA 初始化2FA设置
func Setup2FA(c *gin.Context) {
userId := c.GetInt("id")
// 检查用户是否已经启用2FA
existing, err := model.GetTwoFAByUserId(userId)
if err != nil {
common.ApiError(c, err)
return
}
if existing != nil && existing.IsEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户已启用2FA,请先禁用后重新设置",
})
return
}
// 如果存在已禁用的2FA记录,先删除它
if existing != nil && !existing.IsEnabled {
if err := existing.Delete(); err != nil {
common.ApiError(c, err)
return
}
existing = nil // 重置为nil,后续将创建新记录
}
// 获取用户信息
user, err := model.GetUserById(userId, false)
if err != nil {
common.ApiError(c, err)
return
}
// 生成TOTP密钥
key, err := common.GenerateTOTPSecret(user.Username)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "生成2FA密钥失败",
})
common.SysError("生成TOTP密钥失败: " + err.Error())
return
}
// 生成备用码
backupCodes, err := common.GenerateBackupCodes()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "生成备用码失败",
})
common.SysError("生成备用码失败: " + err.Error())
return
}
// 生成二维码数据
qrCodeData := common.GenerateQRCodeData(key.Secret(), user.Username)
// 创建或更新2FA记录(暂未启用)
twoFA := &model.TwoFA{
UserId: userId,
Secret: key.Secret(),
IsEnabled: false,
}
if existing != nil {
// 更新现有记录
twoFA.Id = existing.Id
err = twoFA.Update()
} else {
// 创建新记录
err = twoFA.Create()
}
if err != nil {
common.ApiError(c, err)
return
}
// 创建备用码记录
if err := model.CreateBackupCodes(userId, backupCodes); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "保存备用码失败",
})
common.SysError("保存备用码失败: " + err.Error())
return
}
// 记录操作日志
model.RecordLog(userId, model.LogTypeSystem, "开始设置两步验证")
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "2FA设置初始化成功,请使用认证器扫描二维码并输入验证码完成设置",
"data": Setup2FAResponse{
Secret: key.Secret(),
QRCodeData: qrCodeData,
BackupCodes: backupCodes,
},
})
}
// Enable2FA 启用2FA
func Enable2FA(c *gin.Context) {
var req Setup2FARequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "参数错误",
})
return
}
userId := c.GetInt("id")
// 获取2FA记录
twoFA, err := model.GetTwoFAByUserId(userId)
if err != nil {
common.ApiError(c, err)
return
}
if twoFA == nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "请先完成2FA初始化设置",
})
return
}
if twoFA.IsEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "2FA已经启用",
})
return
}
// 验证TOTP验证码
cleanCode, err := common.ValidateNumericCode(req.Code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
if !common.ValidateTOTPCode(twoFA.Secret, cleanCode) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "验证码或备用码错误,请重试",
})
return
}
// 启用2FA
if err := twoFA.Enable(); err != nil {
common.ApiError(c, err)
return
}
// 记录操作日志
model.RecordLog(userId, model.LogTypeSystem, "成功启用两步验证")
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "两步验证启用成功",
})
}
// Disable2FA 禁用2FA
func Disable2FA(c *gin.Context) {
var req Verify2FARequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "参数错误",
})
return
}
userId := c.GetInt("id")
// 获取2FA记录
twoFA, err := model.GetTwoFAByUserId(userId)
if err != nil {
common.ApiError(c, err)
return
}
if twoFA == nil || !twoFA.IsEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户未启用2FA",
})
return
}
// 验证TOTP验证码或备用码
cleanCode, err := common.ValidateNumericCode(req.Code)
isValidTOTP := false
isValidBackup := false
if err == nil {
// 尝试验证TOTP
isValidTOTP, _ = twoFA.ValidateTOTPAndUpdateUsage(cleanCode)
}
if !isValidTOTP {
// 尝试验证备用码
isValidBackup, err = twoFA.ValidateBackupCodeAndUpdateUsage(req.Code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
}
if !isValidTOTP && !isValidBackup {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "验证码或备用码错误,请重试",
})
return
}
// 禁用2FA
if err := model.DisableTwoFA(userId); err != nil {
common.ApiError(c, err)
return
}
// 记录操作日志
model.RecordLog(userId, model.LogTypeSystem, "禁用两步验证")
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "两步验证已禁用",
})
}
// Get2FAStatus 获取用户2FA状态
func Get2FAStatus(c *gin.Context) {
userId := c.GetInt("id")
twoFA, err := model.GetTwoFAByUserId(userId)
if err != nil {
common.ApiError(c, err)
return
}
status := map[string]interface{}{
"enabled": false,
"locked": false,
}
if twoFA != nil {
status["enabled"] = twoFA.IsEnabled
status["locked"] = twoFA.IsLocked()
if twoFA.IsEnabled {
// 获取剩余备用码数量
backupCount, err := model.GetUnusedBackupCodeCount(userId)
if err != nil {
common.SysError("获取备用码数量失败: " + err.Error())
} else {
status["backup_codes_remaining"] = backupCount
}
}
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": status,
})
}
// RegenerateBackupCodes 重新生成备用码
func RegenerateBackupCodes(c *gin.Context) {
var req Verify2FARequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "参数错误",
})
return
}
userId := c.GetInt("id")
// 获取2FA记录
twoFA, err := model.GetTwoFAByUserId(userId)
if err != nil {
common.ApiError(c, err)
return
}
if twoFA == nil || !twoFA.IsEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户未启用2FA",
})
return
}
// 验证TOTP验证码
cleanCode, err := common.ValidateNumericCode(req.Code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
valid, err := twoFA.ValidateTOTPAndUpdateUsage(cleanCode)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
if !valid {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "验证码或备用码错误,请重试",
})
return
}
// 生成新的备用码
backupCodes, err := common.GenerateBackupCodes()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "生成备用码失败",
})
common.SysError("生成备用码失败: " + err.Error())
return
}
// 保存新的备用码
if err := model.CreateBackupCodes(userId, backupCodes); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "保存备用码失败",
})
common.SysError("保存备用码失败: " + err.Error())
return
}
// 记录操作日志
model.RecordLog(userId, model.LogTypeSystem, "重新生成两步验证备用码")
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "备用码重新生成成功",
"data": map[string]interface{}{
"backup_codes": backupCodes,
},
})
}
// Verify2FALogin 登录时验证2FA
func Verify2FALogin(c *gin.Context) {
var req Verify2FARequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "参数错误",
})
return
}
// 从会话中获取pending用户信息
session := sessions.Default(c)
pendingUserId := session.Get("pending_user_id")
if pendingUserId == nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "会话已过期,请重新登录",
})
return
}
userId, ok := pendingUserId.(int)
if !ok {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "会话数据无效,请重新登录",
})
return
}
// 获取用户信息
user, err := model.GetUserById(userId, false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户不存在",
})
return
}
// 获取2FA记录
twoFA, err := model.GetTwoFAByUserId(user.Id)
if err != nil {
common.ApiError(c, err)
return
}
if twoFA == nil || !twoFA.IsEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户未启用2FA",
})
return
}
// 验证TOTP验证码或备用码
cleanCode, err := common.ValidateNumericCode(req.Code)
isValidTOTP := false
isValidBackup := false
if err == nil {
// 尝试验证TOTP
isValidTOTP, _ = twoFA.ValidateTOTPAndUpdateUsage(cleanCode)
}
if !isValidTOTP {
// 尝试验证备用码
isValidBackup, err = twoFA.ValidateBackupCodeAndUpdateUsage(req.Code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
}
if !isValidTOTP && !isValidBackup {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "验证码或备用码错误,请重试",
})
return
}
// 2FA验证成功,清理pending会话信息并完成登录
session.Delete("pending_username")
session.Delete("pending_user_id")
session.Save()
setupLogin(user, c)
}
// Admin2FAStats 管理员获取2FA统计信息
func Admin2FAStats(c *gin.Context) {
stats, err := model.GetTwoFAStats()
if err != nil {
common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": stats,
})
}
// AdminDisable2FA 管理员强制禁用用户2FA
func AdminDisable2FA(c *gin.Context) {
userIdStr := c.Param("id")
userId, err := strconv.Atoi(userIdStr)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户ID格式错误",
})
return
}
// 检查目标用户权限
targetUser, err := model.GetUserById(userId, false)
if err != nil {
common.ApiError(c, err)
return
}
myRole := c.GetInt("role")
if myRole <= targetUser.Role && myRole != common.RoleRootUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无权操作同级或更高级用户的2FA设置",
})
return
}
// 禁用2FA
if err := model.DisableTwoFA(userId); err != nil {
if errors.Is(err, model.ErrTwoFANotEnabled) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户未启用2FA",
})
return
}
common.ApiError(c, err)
return
}
// 记录操作日志
adminId := c.GetInt("id")
model.RecordLog(userId, model.LogTypeManage,
fmt.Sprintf("管理员(ID:%d)强制禁用了用户的两步验证", adminId))
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "用户2FA已被强制禁用",
})
}
+26
View File
@@ -62,6 +62,32 @@ func Login(c *gin.Context) {
})
return
}
// 检查是否启用2FA
if model.IsTwoFAEnabled(user.Id) {
// 设置pending session,等待2FA验证
session := sessions.Default(c)
session.Set("pending_username", user.Username)
session.Set("pending_user_id", user.Id)
err := session.Save()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"message": "无法保存会话信息,请重试",
"success": false,
})
return
}
c.JSON(http.StatusOK, gin.H{
"message": "请输入两步验证码",
"success": true,
"data": map[string]interface{}{
"require_2fa": true,
},
})
return
}
setupLogin(&user, c)
}
+1 -1
View File
@@ -361,7 +361,7 @@ type ClaudeUsage struct {
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
CacheReadInputTokens int `json:"cache_read_input_tokens"`
OutputTokens int `json:"output_tokens"`
ServerToolUse *ClaudeServerToolUse `json:"server_tool_use"`
ServerToolUse *ClaudeServerToolUse `json:"server_tool_use,omitempty"`
}
type ClaudeServerToolUse struct {
+5 -2
View File
@@ -99,8 +99,11 @@ type StreamOptions struct {
IncludeUsage bool `json:"include_usage,omitempty"`
}
func (r *GeneralOpenAIRequest) GetMaxTokens() int {
return int(r.MaxTokens)
func (r *GeneralOpenAIRequest) GetMaxTokens() uint {
if r.MaxCompletionTokens != 0 {
return r.MaxCompletionTokens
}
return r.MaxTokens
}
func (r *GeneralOpenAIRequest) ParseInput() []string {
+8 -6
View File
@@ -7,9 +7,10 @@ require (
github.com/Calcium-Ion/go-epay v0.0.4
github.com/andybalholm/brotli v1.1.1
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0
github.com/aws/aws-sdk-go-v2 v1.26.1
github.com/aws/aws-sdk-go-v2 v1.37.2
github.com/aws/aws-sdk-go-v2/credentials v1.17.11
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.33.0
github.com/aws/smithy-go v1.22.5
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b
github.com/gin-contrib/cors v1.7.2
github.com/gin-contrib/gzip v0.0.6
@@ -24,6 +25,7 @@ require (
github.com/gorilla/websocket v1.5.0
github.com/joho/godotenv v1.5.1
github.com/pkg/errors v0.9.1
github.com/pquerna/otp v1.5.0
github.com/samber/lo v1.39.0
github.com/shirou/gopsutil v3.21.11+incompatible
github.com/shopspring/decimal v1.4.0
@@ -41,10 +43,10 @@ require (
require (
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect
github.com/aws/smithy-go v1.20.2 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2 // indirect
github.com/boombuler/barcode v1.1.0 // indirect
github.com/bytedance/sonic v1.11.6 // indirect
github.com/bytedance/sonic/loader v0.1.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
+17 -12
View File
@@ -6,20 +6,23 @@ github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+Kc
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI=
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI=
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6/go.mod h1:pbiaLIeYLUbgMY1kwEAdwO6UKD5ZNwdPGQlwokS9fe8=
github.com/aws/aws-sdk-go-v2 v1.26.1 h1:5554eUqIYVWpU0YmeeYZ0wU64H2VLBs8TlhRB2L+EkA=
github.com/aws/aws-sdk-go-v2 v1.26.1/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 h1:x6xsQXGSmW6frevwDA+vi/wqhp1ct18mVXYN08/93to=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2/go.mod h1:lPprDr1e6cJdyYeGXnRaJoP4Md+cDBvi2eOj00BlGmg=
github.com/aws/aws-sdk-go-v2 v1.37.2 h1:xkW1iMYawzcmYFYEV0UCMxc8gSsjCGEhBXQkdQywVbo=
github.com/aws/aws-sdk-go-v2 v1.37.2/go.mod h1:9Q0OoGQoboYIAJyslFyF1f5K1Ryddop8gqMhWx/n4Wg=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0 h1:6GMWV6CNpA/6fbFHnoAjrv4+LGfyTqZz2LtCHnspgDg=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0/go.mod h1:/mXlTIVG9jbxkqDnr5UQNQxW1HRYxeGklkM9vAFeabg=
github.com/aws/aws-sdk-go-v2/credentials v1.17.11 h1:YuIB1dJNf1Re822rriUOTxopaHHvIq0l/pX3fwO+Tzs=
github.com/aws/aws-sdk-go-v2/credentials v1.17.11/go.mod h1:AQtFPsDH9bI2O+71anW6EKL+NcD7LG3dpKGMV4SShgo=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 h1:aw39xVGeRWlWx9EzGVnhOR4yOjQDHPQ6o6NmBlscyQg=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5/go.mod h1:FSaRudD0dXiMPK2UjknVwwTYyZMRsHv3TtkabsZih5I=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 h1:PG1F3OD1szkuQPzDw3CIQsRIrtTlUC3lP84taWzHlq0=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5/go.mod h1:jU1li6RFryMz+so64PpKtudI+QzbKoIEivqdf6LNpOc=
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 h1:JgHnonzbnA3pbqj76wYsSZIZZQYBxkmMEjvL6GHy8XU=
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4/go.mod h1:nZspkhg+9p8iApLFoyAqfyuMP0F38acy2Hm3r5r95Cg=
github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q=
github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2 h1:sPiRHLVUIIQcoVZTNwqQcdtjkqkPopyYmIX0M5ElRf4=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2/go.mod h1:ik86P3sgV+Bk7c1tBFCwI3VxMoSEwl4YkRB9xn1s340=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2 h1:ZdzDAg075H6stMZtbD2o+PyB933M/f20e9WmCBC17wA=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2/go.mod h1:eE1IIzXG9sdZCB0pNNpMpsYTLl4YdOQD3njiVN1e/E4=
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.33.0 h1:JzidOz4Hcn2RbP5fvIS1iAP+DcRv5VJtgixbEYDsI5g=
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.33.0/go.mod h1:9A4/PJYlWjvjEzzoOLGQjkLt4bYK9fRWi7uz1GSsAcA=
github.com/aws/smithy-go v1.22.5 h1:P9ATCXPMb2mPjYBgueqJNCA5S9UfktsW0tTxi+a7eqw=
github.com/aws/smithy-go v1.22.5/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI=
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/boombuler/barcode v1.1.0 h1:ChaYjBR63fr4LFyGn8E8nt7dBSt3MiU3zMOZqFvVkHo=
github.com/boombuler/barcode v1.1.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b h1:LTGVFpNmNHhj0vhOlfgWueFJ32eK9blaIlHR2ciXOT0=
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q=
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
@@ -169,6 +172,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs=
github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg=
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
+32 -1
View File
@@ -4,7 +4,10 @@ import (
"fmt"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/model"
"one-api/setting"
"one-api/setting/ratio_setting"
"strconv"
"strings"
@@ -234,6 +237,16 @@ func TokenAuth() func(c *gin.Context) {
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
return
}
allowIpsMap := token.GetIpLimitsMap()
if len(allowIpsMap) != 0 {
clientIp := c.ClientIP()
if _, ok := allowIpsMap[clientIp]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, "您的 IP 不在令牌允许访问的列表中")
return
}
}
userCache, err := model.GetUserCache(token.UserId)
if err != nil {
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
@@ -247,6 +260,25 @@ func TokenAuth() func(c *gin.Context) {
userCache.WriteContext(c)
userGroup := userCache.Group
tokenGroup := token.Group
if tokenGroup != "" {
// check common.UserUsableGroups[userGroup]
if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup))
return
}
// check group in common.GroupRatio
if !ratio_setting.ContainsGroupRatio(tokenGroup) {
if tokenGroup != "auto" {
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
return
}
}
userGroup = tokenGroup
}
common.SetContextKey(c, constant.ContextKeyUsingGroup, userGroup)
err = SetupContextForToken(c, token, parts...)
if err != nil {
return
@@ -273,7 +305,6 @@ func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) e
} else {
c.Set("token_model_limit_enabled", false)
}
c.Set("allow_ips", token.GetIpLimitsMap())
c.Set("token_group", token.Group)
if len(parts) > 1 {
if model.IsAdmin(token.UserId) {
+12 -39
View File
@@ -10,7 +10,6 @@ import (
"one-api/model"
relayconstant "one-api/relay/constant"
"one-api/service"
"one-api/setting"
"one-api/setting/ratio_setting"
"one-api/types"
"strconv"
@@ -27,14 +26,6 @@ type ModelRequest struct {
func Distribute() func(c *gin.Context) {
return func(c *gin.Context) {
allowIpsMap := common.GetContextKeyStringMap(c, constant.ContextKeyTokenAllowIps)
if len(allowIpsMap) != 0 {
clientIp := c.ClientIP()
if _, ok := allowIpsMap[clientIp]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, "您的 IP 不在令牌允许访问的列表中")
return
}
}
var channel *model.Channel
channelId, ok := common.GetContextKey(c, constant.ContextKeyTokenSpecificChannelId)
modelRequest, shouldSelectChannel, err := getModelRequest(c)
@@ -42,24 +33,6 @@ func Distribute() func(c *gin.Context) {
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
return
}
userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
if tokenGroup != "" {
// check common.UserUsableGroups[userGroup]
if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup))
return
}
// check group in common.GroupRatio
if !ratio_setting.ContainsGroupRatio(tokenGroup) {
if tokenGroup != "auto" {
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
return
}
}
userGroup = tokenGroup
}
common.SetContextKey(c, constant.ContextKeyUsingGroup, userGroup)
if ok {
id, err := strconv.Atoi(channelId.(string))
if err != nil {
@@ -81,22 +54,21 @@ func Distribute() func(c *gin.Context) {
modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
if modelLimitEnable {
s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
var tokenModelLimit map[string]bool
if ok {
tokenModelLimit = s.(map[string]bool)
} else {
tokenModelLimit = map[string]bool{}
}
if tokenModelLimit != nil {
if _, ok := tokenModelLimit[modelRequest.Model]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
return
}
} else {
if !ok {
// token model limit is empty, all models are not allowed
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
return
}
var tokenModelLimit map[string]bool
tokenModelLimit, ok = s.(map[string]bool)
if !ok {
tokenModelLimit = map[string]bool{}
}
matchName := ratio_setting.FormatMatchingModelName(modelRequest.Model) // match gpts & thinking-*
if _, ok := tokenModelLimit[matchName]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
return
}
}
if shouldSelectChannel {
@@ -105,6 +77,7 @@ func Distribute() func(c *gin.Context) {
return
}
var selectGroup string
userGroup := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0)
if err != nil {
showGroup := userGroup
+25 -14
View File
@@ -41,6 +41,7 @@ type Channel struct {
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
AutoBan *int `json:"auto_ban" gorm:"default:1"`
OtherInfo string `json:"other_info"`
Settings string `json:"settings"`
Tag *string `json:"tag" gorm:"index"`
Setting *string `json:"setting" gorm:"type:text"` // 渠道额外设置
ParamOverride *string `json:"param_override" gorm:"type:text"`
@@ -52,11 +53,13 @@ type Channel struct {
}
type ChannelInfo struct {
IsMultiKey bool `json:"is_multi_key"` // 是否多Key模式
MultiKeySize int `json:"multi_key_size"` // 多Key模式下的Key数量
MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表,key index -> status
MultiKeyPollingIndex int `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
IsMultiKey bool `json:"is_multi_key"` // 是否多Key模式
MultiKeySize int `json:"multi_key_size"` // 多Key模式下的Key数量
MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表,key index -> status
MultiKeyDisabledReason map[int]string `json:"multi_key_disabled_reason,omitempty"` // key禁用原因列表,key index -> reason
MultiKeyDisabledTime map[int]int64 `json:"multi_key_disabled_time,omitempty"` // key禁用时间列表,key index -> time
MultiKeyPollingIndex int `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
}
// Value implements driver.Valuer interface
@@ -70,7 +73,7 @@ func (c *ChannelInfo) Scan(value interface{}) error {
return common.Unmarshal(bytesValue, c)
}
func (channel *Channel) getKeys() []string {
func (channel *Channel) GetKeys() []string {
if channel.Key == "" {
return []string{}
}
@@ -101,7 +104,7 @@ func (channel *Channel) GetNextEnabledKey() (string, int, *types.NewAPIError) {
}
// Obtain all keys (split by \n)
keys := channel.getKeys()
keys := channel.GetKeys()
if len(keys) == 0 {
// No keys available, return error, should disable the channel
return "", 0, types.NewError(errors.New("no keys available"), types.ErrorCodeChannelNoAvailableKey)
@@ -138,7 +141,7 @@ func (channel *Channel) GetNextEnabledKey() (string, int, *types.NewAPIError) {
return keys[selectedIdx], selectedIdx, nil
case constant.MultiKeyModePolling:
// Use channel-specific lock to ensure thread-safe polling
lock := getChannelPollingLock(channel.Id)
lock := GetChannelPollingLock(channel.Id)
lock.Lock()
defer lock.Unlock()
@@ -497,8 +500,8 @@ var channelStatusLock sync.Mutex
// channelPollingLocks stores locks for each channel.id to ensure thread-safe polling
var channelPollingLocks sync.Map
// getChannelPollingLock returns or creates a mutex for the given channel ID
func getChannelPollingLock(channelId int) *sync.Mutex {
// GetChannelPollingLock returns or creates a mutex for the given channel ID
func GetChannelPollingLock(channelId int) *sync.Mutex {
if lock, exists := channelPollingLocks.Load(channelId); exists {
return lock.(*sync.Mutex)
}
@@ -528,8 +531,8 @@ func CleanupChannelPollingLocks() {
})
}
func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int) {
keys := channel.getKeys()
func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int, reason string) {
keys := channel.GetKeys()
if len(keys) == 0 {
channel.Status = status
} else {
@@ -547,6 +550,14 @@ func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int) {
delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex)
} else {
channel.ChannelInfo.MultiKeyStatusList[keyIndex] = status
if channel.ChannelInfo.MultiKeyDisabledReason == nil {
channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
}
if channel.ChannelInfo.MultiKeyDisabledTime == nil {
channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
}
channel.ChannelInfo.MultiKeyDisabledReason[keyIndex] = reason
channel.ChannelInfo.MultiKeyDisabledTime[keyIndex] = common.GetTimestamp()
}
if len(channel.ChannelInfo.MultiKeyStatusList) >= channel.ChannelInfo.MultiKeySize {
channel.Status = common.ChannelStatusAutoDisabled
@@ -569,7 +580,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
}
if channelCache.ChannelInfo.IsMultiKey {
// 如果是多Key模式,更新缓存中的状态
handlerMultiKeyUpdate(channelCache, usingKey, status)
handlerMultiKeyUpdate(channelCache, usingKey, status, reason)
//CacheUpdateChannel(channelCache)
//return true
} else {
@@ -600,7 +611,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
if channel.ChannelInfo.IsMultiKey {
beforeStatus := channel.Status
handlerMultiKeyUpdate(channel, usingKey, status)
handlerMultiKeyUpdate(channel, usingKey, status, reason)
if beforeStatus != channel.Status {
shouldUpdateAbilities = true
}
+3 -7
View File
@@ -7,6 +7,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/setting"
"one-api/setting/ratio_setting"
"sort"
"strings"
"sync"
@@ -70,7 +71,7 @@ func InitChannelCache() {
//channelsIDM = newChannelId2channel
for i, channel := range newChannelId2channel {
if channel.ChannelInfo.IsMultiKey {
channel.Keys = channel.getKeys()
channel.Keys = channel.GetKeys()
if channel.ChannelInfo.MultiKeyMode == constant.MultiKeyModePolling {
if oldChannel, ok := channelsIDM[i]; ok {
// 存在旧的渠道,如果是多key且轮询,保留轮询索引信息
@@ -128,12 +129,7 @@ func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, model string,
}
func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
if strings.HasPrefix(model, "gpt-4-gizmo") {
model = "gpt-4-gizmo-*"
}
if strings.HasPrefix(model, "gpt-4o-gizmo") {
model = "gpt-4o-gizmo-*"
}
model = ratio_setting.FormatMatchingModelName(model)
// if memory cache is disabled, get channel directly from database
if !common.MemoryCacheEnabled {
+4
View File
@@ -251,6 +251,8 @@ func migrateDB() error {
&QuotaData{},
&Task{},
&Setup{},
&TwoFA{},
&TwoFABackupCode{},
)
if err != nil {
return err
@@ -277,6 +279,8 @@ func migrateDBFast() error {
{&QuotaData{}, "QuotaData"},
{&Task{}, "Task"},
{&Setup{}, "Setup"},
{&TwoFA{}, "TwoFA"},
{&TwoFABackupCode{}, "TwoFABackupCode"},
}
// 动态计算migration数量,确保errChan缓冲区足够大
errChan := make(chan error, len(migrations))
+322
View File
@@ -0,0 +1,322 @@
package model
import (
"errors"
"fmt"
"one-api/common"
"time"
"gorm.io/gorm"
)
var ErrTwoFANotEnabled = errors.New("用户未启用2FA")
// TwoFA 用户2FA设置表
type TwoFA struct {
Id int `json:"id" gorm:"primaryKey"`
UserId int `json:"user_id" gorm:"unique;not null;index"`
Secret string `json:"-" gorm:"type:varchar(255);not null"` // TOTP密钥,不返回给前端
IsEnabled bool `json:"is_enabled" gorm:"default:false"`
FailedAttempts int `json:"failed_attempts" gorm:"default:0"`
LockedUntil *time.Time `json:"locked_until,omitempty"`
LastUsedAt *time.Time `json:"last_used_at,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
}
// TwoFABackupCode 备用码使用记录表
type TwoFABackupCode struct {
Id int `json:"id" gorm:"primaryKey"`
UserId int `json:"user_id" gorm:"not null;index"`
CodeHash string `json:"-" gorm:"type:varchar(255);not null"` // 备用码哈希
IsUsed bool `json:"is_used" gorm:"default:false"`
UsedAt *time.Time `json:"used_at,omitempty"`
CreatedAt time.Time `json:"created_at"`
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
}
// GetTwoFAByUserId 根据用户ID获取2FA设置
func GetTwoFAByUserId(userId int) (*TwoFA, error) {
if userId == 0 {
return nil, errors.New("用户ID不能为空")
}
var twoFA TwoFA
err := DB.Where("user_id = ?", userId).First(&twoFA).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil // 返回nil表示未设置2FA
}
return nil, err
}
return &twoFA, nil
}
// IsTwoFAEnabled 检查用户是否启用了2FA
func IsTwoFAEnabled(userId int) bool {
twoFA, err := GetTwoFAByUserId(userId)
if err != nil || twoFA == nil {
return false
}
return twoFA.IsEnabled
}
// CreateTwoFA 创建2FA设置
func (t *TwoFA) Create() error {
// 检查用户是否已存在2FA设置
existing, err := GetTwoFAByUserId(t.UserId)
if err != nil {
return err
}
if existing != nil {
return errors.New("用户已存在2FA设置")
}
// 验证用户存在
var user User
if err := DB.First(&user, t.UserId).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("用户不存在")
}
return err
}
return DB.Create(t).Error
}
// Update 更新2FA设置
func (t *TwoFA) Update() error {
if t.Id == 0 {
return errors.New("2FA记录ID不能为空")
}
return DB.Save(t).Error
}
// Delete 删除2FA设置
func (t *TwoFA) Delete() error {
if t.Id == 0 {
return errors.New("2FA记录ID不能为空")
}
// 使用事务确保原子性
return DB.Transaction(func(tx *gorm.DB) error {
// 同时删除相关的备用码记录(硬删除)
if err := tx.Unscoped().Where("user_id = ?", t.UserId).Delete(&TwoFABackupCode{}).Error; err != nil {
return err
}
// 硬删除2FA记录
return tx.Unscoped().Delete(t).Error
})
}
// ResetFailedAttempts 重置失败尝试次数
func (t *TwoFA) ResetFailedAttempts() error {
t.FailedAttempts = 0
t.LockedUntil = nil
return t.Update()
}
// IncrementFailedAttempts 增加失败尝试次数
func (t *TwoFA) IncrementFailedAttempts() error {
t.FailedAttempts++
// 检查是否需要锁定
if t.FailedAttempts >= common.MaxFailAttempts {
lockUntil := time.Now().Add(time.Duration(common.LockoutDuration) * time.Second)
t.LockedUntil = &lockUntil
}
return t.Update()
}
// IsLocked 检查账户是否被锁定
func (t *TwoFA) IsLocked() bool {
if t.LockedUntil == nil {
return false
}
return time.Now().Before(*t.LockedUntil)
}
// CreateBackupCodes 创建备用码
func CreateBackupCodes(userId int, codes []string) error {
return DB.Transaction(func(tx *gorm.DB) error {
// 先删除现有的备用码
if err := tx.Where("user_id = ?", userId).Delete(&TwoFABackupCode{}).Error; err != nil {
return err
}
// 创建新的备用码记录
for _, code := range codes {
hashedCode, err := common.HashBackupCode(code)
if err != nil {
return err
}
backupCode := TwoFABackupCode{
UserId: userId,
CodeHash: hashedCode,
IsUsed: false,
}
if err := tx.Create(&backupCode).Error; err != nil {
return err
}
}
return nil
})
}
// ValidateBackupCode 验证并使用备用码
func ValidateBackupCode(userId int, code string) (bool, error) {
if !common.ValidateBackupCode(code) {
return false, errors.New("验证码或备用码不正确")
}
normalizedCode := common.NormalizeBackupCode(code)
// 查找未使用的备用码
var backupCodes []TwoFABackupCode
if err := DB.Where("user_id = ? AND is_used = false", userId).Find(&backupCodes).Error; err != nil {
return false, err
}
// 验证备用码
for _, bc := range backupCodes {
if common.ValidatePasswordAndHash(normalizedCode, bc.CodeHash) {
// 标记为已使用
now := time.Now()
bc.IsUsed = true
bc.UsedAt = &now
if err := DB.Save(&bc).Error; err != nil {
return false, err
}
return true, nil
}
}
return false, nil
}
// GetUnusedBackupCodeCount 获取未使用的备用码数量
func GetUnusedBackupCodeCount(userId int) (int, error) {
var count int64
err := DB.Model(&TwoFABackupCode{}).Where("user_id = ? AND is_used = false", userId).Count(&count).Error
return int(count), err
}
// DisableTwoFA 禁用用户的2FA
func DisableTwoFA(userId int) error {
twoFA, err := GetTwoFAByUserId(userId)
if err != nil {
return err
}
if twoFA == nil {
return ErrTwoFANotEnabled
}
// 删除2FA设置和备用码
return twoFA.Delete()
}
// EnableTwoFA 启用2FA
func (t *TwoFA) Enable() error {
t.IsEnabled = true
t.FailedAttempts = 0
t.LockedUntil = nil
return t.Update()
}
// ValidateTOTPAndUpdateUsage 验证TOTP并更新使用记录
func (t *TwoFA) ValidateTOTPAndUpdateUsage(code string) (bool, error) {
// 检查是否被锁定
if t.IsLocked() {
return false, fmt.Errorf("账户已被锁定,请在%v后重试", t.LockedUntil.Format("2006-01-02 15:04:05"))
}
// 验证TOTP码
if !common.ValidateTOTPCode(t.Secret, code) {
// 增加失败次数
if err := t.IncrementFailedAttempts(); err != nil {
common.SysError("更新2FA失败次数失败: " + err.Error())
}
return false, nil
}
// 验证成功,重置失败次数并更新最后使用时间
now := time.Now()
t.FailedAttempts = 0
t.LockedUntil = nil
t.LastUsedAt = &now
if err := t.Update(); err != nil {
common.SysError("更新2FA使用记录失败: " + err.Error())
}
return true, nil
}
// ValidateBackupCodeAndUpdateUsage 验证备用码并更新使用记录
func (t *TwoFA) ValidateBackupCodeAndUpdateUsage(code string) (bool, error) {
// 检查是否被锁定
if t.IsLocked() {
return false, fmt.Errorf("账户已被锁定,请在%v后重试", t.LockedUntil.Format("2006-01-02 15:04:05"))
}
// 验证备用码
valid, err := ValidateBackupCode(t.UserId, code)
if err != nil {
return false, err
}
if !valid {
// 增加失败次数
if err := t.IncrementFailedAttempts(); err != nil {
common.SysError("更新2FA失败次数失败: " + err.Error())
}
return false, nil
}
// 验证成功,重置失败次数并更新最后使用时间
now := time.Now()
t.FailedAttempts = 0
t.LockedUntil = nil
t.LastUsedAt = &now
if err := t.Update(); err != nil {
common.SysError("更新2FA使用记录失败: " + err.Error())
}
return true, nil
}
// GetTwoFAStats 获取2FA统计信息(管理员使用)
func GetTwoFAStats() (map[string]interface{}, error) {
var totalUsers, enabledUsers int64
// 总用户数
if err := DB.Model(&User{}).Count(&totalUsers).Error; err != nil {
return nil, err
}
// 启用2FA的用户数
if err := DB.Model(&TwoFA{}).Where("is_enabled = true").Count(&enabledUsers).Error; err != nil {
return nil, err
}
enabledRate := float64(0)
if totalUsers > 0 {
enabledRate = float64(enabledUsers) / float64(totalUsers) * 100
}
return map[string]interface{}{
"total_users": totalUsers,
"enabled_users": enabledUsers,
"enabled_rate": fmt.Sprintf("%.1f%%", enabledRate),
}, nil
}
+47 -27
View File
@@ -3,16 +3,17 @@ package ali
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/claude"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"one-api/types"
"github.com/gin-gonic/gin"
"strings"
)
type Adaptor struct {
@@ -23,10 +24,8 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
return req, nil
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
@@ -34,18 +33,24 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
var fullRequestURL string
switch info.RelayMode {
case constant.RelayModeEmbeddings:
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/embeddings", info.BaseUrl)
case constant.RelayModeRerank:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.BaseUrl)
case constant.RelayModeImagesGenerations:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl)
case constant.RelayModeCompletions:
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.BaseUrl)
switch info.RelayFormat {
case relaycommon.RelayFormatClaude:
fullRequestURL = fmt.Sprintf("%s/api/v2/apps/claude-code-proxy/v1/messages", info.BaseUrl)
default:
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.BaseUrl)
switch info.RelayMode {
case constant.RelayModeEmbeddings:
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/embeddings", info.BaseUrl)
case constant.RelayModeRerank:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.BaseUrl)
case constant.RelayModeImagesGenerations:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl)
case constant.RelayModeCompletions:
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.BaseUrl)
default:
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.BaseUrl)
}
}
return fullRequestURL, nil
}
@@ -65,7 +70,13 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
if request == nil {
return nil, errors.New("request is nil")
}
// docs: https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2712216
// fix: InternalError.Algo.InvalidParameter: The value of the enable_thinking parameter is restricted to True.
if strings.Contains(request.Model, "thinking") {
request.EnableThinking = true
request.Stream = true
info.IsStream = true
}
// fix: ali parameter.enable_thinking must be set to false for non-streaming calls
if !info.IsStream {
request.EnableThinking = false
@@ -106,18 +117,27 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
switch info.RelayMode {
case constant.RelayModeImagesGenerations:
err, usage = aliImageHandler(c, resp, info)
case constant.RelayModeEmbeddings:
err, usage = aliEmbeddingHandler(c, resp)
case constant.RelayModeRerank:
err, usage = RerankHandler(c, resp, info)
default:
switch info.RelayFormat {
case relaycommon.RelayFormatClaude:
if info.IsStream {
usage, err = openai.OaiStreamHandler(c, info, resp)
err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
} else {
usage, err = openai.OpenaiHandler(c, info, resp)
err, usage = claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
}
default:
switch info.RelayMode {
case constant.RelayModeImagesGenerations:
err, usage = aliImageHandler(c, resp, info)
case constant.RelayModeEmbeddings:
err, usage = aliEmbeddingHandler(c, resp)
case constant.RelayModeRerank:
err, usage = RerankHandler(c, resp, info)
default:
if info.IsStream {
usage, err = openai.OaiStreamHandler(c, info, resp)
} else {
usage, err = openai.OpenaiHandler(c, info, resp)
}
}
}
return
+4
View File
@@ -13,6 +13,7 @@ var awsModelIDMap = map[string]string{
"claude-3-7-sonnet-20250219": "anthropic.claude-3-7-sonnet-20250219-v1:0",
"claude-sonnet-4-20250514": "anthropic.claude-sonnet-4-20250514-v1:0",
"claude-opus-4-20250514": "anthropic.claude-opus-4-20250514-v1:0",
"claude-opus-4-1-20250805": "anthropic.claude-opus-4-1-20250805-v1:0",
}
var awsModelCanCrossRegionMap = map[string]map[string]bool{
@@ -54,6 +55,9 @@ var awsModelCanCrossRegionMap = map[string]map[string]bool{
"anthropic.claude-opus-4-20250514-v1:0": {
"us": true,
},
"anthropic.claude-opus-4-1-20250805-v1:0": {
"us": true,
},
}
var awsRegionCrossModelPrefixMap = map[string]string{
+19 -8
View File
@@ -19,20 +19,31 @@ import (
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
bedrockruntimeTypes "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
"github.com/aws/smithy-go/auth/bearer"
)
func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.Client, error) {
awsSecret := strings.Split(info.ApiKey, "|")
if len(awsSecret) != 3 {
var client *bedrockruntime.Client
switch len(awsSecret) {
case 2:
apiKey := awsSecret[0]
region := awsSecret[1]
client = bedrockruntime.New(bedrockruntime.Options{
Region: region,
BearerAuthTokenProvider: bearer.StaticTokenProvider{Token: bearer.Token{Value: apiKey}},
})
case 3:
ak := awsSecret[0]
sk := awsSecret[1]
region := awsSecret[2]
client = bedrockruntime.New(bedrockruntime.Options{
Region: region,
Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(ak, sk, "")),
})
default:
return nil, errors.New("invalid aws secret key")
}
ak := awsSecret[0]
sk := awsSecret[1]
region := awsSecret[2]
client := bedrockruntime.New(bedrockruntime.Options{
Region: region,
Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(ak, sk, "")),
})
return client, nil
}
+3 -3
View File
@@ -34,9 +34,9 @@ func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
EnableCitation: false,
UserId: request.User,
}
if request.MaxTokens != 0 {
maxTokens := int(request.MaxTokens)
if request.MaxTokens == 1 {
if request.GetMaxTokens() != 0 {
maxTokens := int(request.GetMaxTokens())
if request.GetMaxTokens() == 1 {
maxTokens = 2
}
baiduRequest.MaxOutputTokens = &maxTokens
+1 -1
View File
@@ -104,7 +104,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream {
err, usage = ClaudeStreamHandler(c, resp, info, a.RequestMode)
} else {
err, usage = ClaudeHandler(c, resp, a.RequestMode, info)
err, usage = ClaudeHandler(c, resp, info, a.RequestMode)
}
return
}
+2
View File
@@ -17,6 +17,8 @@ var ModelList = []string{
"claude-sonnet-4-20250514-thinking",
"claude-opus-4-20250514",
"claude-opus-4-20250514-thinking",
"claude-opus-4-1-20250805",
"claude-opus-4-1-20250805-thinking",
}
var ChannelName = "claude"
+2 -2
View File
@@ -149,7 +149,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
claudeRequest := dto.ClaudeRequest{
Model: textRequest.Model,
MaxTokens: textRequest.MaxTokens,
MaxTokens: textRequest.GetMaxTokens(),
StopSequences: nil,
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
@@ -740,7 +740,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
return nil
}
func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
func ClaudeHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
defer common.CloseResponseBodyGracefully(resp)
claudeInfo := &ClaudeResponseInfo{
+1 -1
View File
@@ -5,7 +5,7 @@ import "one-api/dto"
type CfRequest struct {
Messages []dto.Message `json:"messages,omitempty"`
Lora string `json:"lora,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
Prompt string `json:"prompt,omitempty"`
Raw bool `json:"raw,omitempty"`
Stream bool `json:"stream,omitempty"`
+1 -1
View File
@@ -7,7 +7,7 @@ type CohereRequest struct {
ChatHistory []ChatHistory `json:"chat_history"`
Message string `json:"message"`
Stream bool `json:"stream"`
MaxTokens int `json:"max_tokens"`
MaxTokens uint `json:"max_tokens"`
SafetyMode string `json:"safety_mode,omitempty"`
}
+3 -4
View File
@@ -24,10 +24,9 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
adaptor := openai.Adaptor{}
return adaptor.ConvertClaudeRequest(c, info, req)
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+78 -8
View File
@@ -49,12 +49,20 @@ const (
flash25LiteMaxBudget = 24576
)
// clampThinkingBudget 根据模型名称将预算限制在允许的范围内
func clampThinkingBudget(modelName string, budget int) int {
isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
func isNew25ProModel(modelName string) bool {
return strings.HasPrefix(modelName, "gemini-2.5-pro") &&
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
is25FlashLite := strings.HasPrefix(modelName, "gemini-2.5-flash-lite")
}
func is25FlashLiteModel(modelName string) bool {
return strings.HasPrefix(modelName, "gemini-2.5-flash-lite")
}
// clampThinkingBudget 根据模型名称将预算限制在允许的范围内
func clampThinkingBudget(modelName string, budget int) int {
isNew25Pro := isNew25ProModel(modelName)
is25FlashLite := is25FlashLiteModel(modelName)
if is25FlashLite {
if budget < flash25LiteMinBudget {
@@ -81,7 +89,34 @@ func clampThinkingBudget(modelName string, budget int) int {
return budget
}
func ThinkingAdaptor(geminiRequest *dto.GeminiChatRequest, info *relaycommon.RelayInfo) {
// "effort": "high" - Allocates a large portion of tokens for reasoning (approximately 80% of max_tokens)
// "effort": "medium" - Allocates a moderate portion of tokens (approximately 50% of max_tokens)
// "effort": "low" - Allocates a smaller portion of tokens (approximately 20% of max_tokens)
func clampThinkingBudgetByEffort(modelName string, effort string) int {
isNew25Pro := isNew25ProModel(modelName)
is25FlashLite := is25FlashLiteModel(modelName)
maxBudget := 0
if is25FlashLite {
maxBudget = flash25LiteMaxBudget
}
if isNew25Pro {
maxBudget = pro25MaxBudget
} else {
maxBudget = flash25MaxBudget
}
switch effort {
case "high":
maxBudget = maxBudget * 80 / 100
case "medium":
maxBudget = maxBudget * 50 / 100
case "low":
maxBudget = maxBudget * 20 / 100
}
return clampThinkingBudget(modelName, maxBudget)
}
func ThinkingAdaptor(geminiRequest *dto.GeminiChatRequest, info *relaycommon.RelayInfo, oaiRequest ...dto.GeneralOpenAIRequest) {
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
modelName := info.UpstreamModelName
isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
@@ -124,6 +159,11 @@ func ThinkingAdaptor(geminiRequest *dto.GeminiChatRequest, info *relaycommon.Rel
budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens)
clampedBudget := clampThinkingBudget(modelName, int(budgetTokens))
geminiRequest.GenerationConfig.ThinkingConfig.ThinkingBudget = common.GetPointer(clampedBudget)
} else {
if len(oaiRequest) > 0 {
// 如果有reasoningEffort参数,则根据其值设置思考预算
geminiRequest.GenerationConfig.ThinkingConfig.ThinkingBudget = common.GetPointer(clampThinkingBudgetByEffort(modelName, oaiRequest[0].ReasoningEffort))
}
}
}
} else if strings.HasSuffix(modelName, "-nothinking") {
@@ -144,7 +184,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
GenerationConfig: dto.GeminiChatGenerationConfig{
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
MaxOutputTokens: textRequest.MaxTokens,
MaxOutputTokens: textRequest.GetMaxTokens(),
Seed: int64(textRequest.Seed),
},
}
@@ -156,7 +196,37 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
}
}
ThinkingAdaptor(&geminiRequest, info)
adaptorWithExtraBody := false
if len(textRequest.ExtraBody) > 0 {
if !strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
var extraBody map[string]interface{}
if err := common.Unmarshal(textRequest.ExtraBody, &extraBody); err != nil {
return nil, fmt.Errorf("invalid extra body: %w", err)
}
// eg. {"google":{"thinking_config":{"thinking_budget":5324,"include_thoughts":true}}}
if googleBody, ok := extraBody["google"].(map[string]interface{}); ok {
adaptorWithExtraBody = true
if thinkingConfig, ok := googleBody["thinking_config"].(map[string]interface{}); ok {
if budget, ok := thinkingConfig["thinking_budget"].(float64); ok {
budgetInt := int(budget)
geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
ThinkingBudget: common.GetPointer(budgetInt),
IncludeThoughts: true,
}
} else {
geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
IncludeThoughts: true,
}
}
}
}
}
}
if !adaptorWithExtraBody {
ThinkingAdaptor(&geminiRequest, info, textRequest)
}
safetySettings := make([]dto.GeminiChatSafetySettings, 0, len(SafetySettingList))
for _, category := range SafetySettingList {
@@ -814,7 +884,7 @@ func handleFinalStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.Ch
if err != nil {
return fmt.Errorf("failed to marshal stream response: %w", err)
}
openai.HandleFinalResponse(c, info, string(streamData), resp.Id, resp.Created, resp.Model, resp.GetSystemFingerprint(), resp.Usage, info.ShouldIncludeUsage)
openai.HandleFinalResponse(c, info, string(streamData), resp.Id, resp.Created, resp.Model, resp.GetSystemFingerprint(), resp.Usage, false)
return nil
}
+1 -1
View File
@@ -71,7 +71,7 @@ func requestOpenAI2Mistral(request *dto.GeneralOpenAIRequest) *dto.GeneralOpenAI
Messages: messages,
Temperature: request.Temperature,
TopP: request.TopP,
MaxTokens: request.MaxTokens,
MaxTokens: request.GetMaxTokens(),
Tools: request.Tools,
ToolChoice: request.ToolChoice,
}
+1 -1
View File
@@ -60,7 +60,7 @@ func requestOpenAI2Ollama(request *dto.GeneralOpenAIRequest) (*OllamaRequest, er
TopK: request.TopK,
Stop: Stop,
Tools: request.Tools,
MaxTokens: request.MaxTokens,
MaxTokens: request.GetMaxTokens(),
ResponseFormat: request.ResponseFormat,
FrequencyPenalty: request.FrequencyPenalty,
PresencePenalty: request.PresencePenalty,
+21 -3
View File
@@ -9,6 +9,7 @@ import (
"mime/multipart"
"net/http"
"net/textproto"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/relay/channel"
@@ -73,9 +74,6 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayFormat == relaycommon.RelayFormatClaude || info.RelayFormat == relaycommon.RelayFormatGemini {
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
}
if info.RelayMode == relayconstant.RelayModeRealtime {
if strings.HasPrefix(info.BaseUrl, "https://") {
baseUrl := strings.TrimPrefix(info.BaseUrl, "https://")
@@ -122,6 +120,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
url = strings.Replace(url, "{model}", info.UpstreamModelName, -1)
return url, nil
default:
if info.RelayFormat == relaycommon.RelayFormatClaude || info.RelayFormat == relaycommon.RelayFormatGemini {
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
}
return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
}
}
@@ -172,6 +173,23 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
if len(request.Usage) == 0 {
request.Usage = json.RawMessage(`{"include":true}`)
}
if strings.HasSuffix(info.UpstreamModelName, "-thinking") {
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
request.Model = info.UpstreamModelName
if len(request.Reasoning) == 0 {
reasoning := map[string]any{
"enabled": true,
}
if request.ReasoningEffort != "" {
reasoning["effort"] = request.ReasoningEffort
}
marshal, err := common.Marshal(reasoning)
if err != nil {
return nil, fmt.Errorf("error marshalling reasoning: %w", err)
}
request.Reasoning = marshal
}
}
}
if strings.HasPrefix(request.Model, "o") {
if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
+16 -6
View File
@@ -37,9 +37,14 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
// compute usage
usage := dto.Usage{}
usage.PromptTokens = responsesResponse.Usage.InputTokens
usage.CompletionTokens = responsesResponse.Usage.OutputTokens
usage.TotalTokens = responsesResponse.Usage.TotalTokens
if responsesResponse.Usage != nil {
usage.PromptTokens = responsesResponse.Usage.InputTokens
usage.CompletionTokens = responsesResponse.Usage.OutputTokens
usage.TotalTokens = responsesResponse.Usage.TotalTokens
if responsesResponse.Usage.InputTokensDetails != nil {
usage.PromptTokensDetails.CachedTokens = responsesResponse.Usage.InputTokensDetails.CachedTokens
}
}
// 解析 Tools 用量
for _, tool := range responsesResponse.Tools {
info.ResponsesUsageInfo.BuiltInTools[common.Interface2String(tool["type"])].CallCount++
@@ -64,9 +69,14 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
sendResponsesStreamData(c, streamResponse, data)
switch streamResponse.Type {
case "response.completed":
usage.PromptTokens = streamResponse.Response.Usage.InputTokens
usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
if streamResponse.Response.Usage != nil {
usage.PromptTokens = streamResponse.Response.Usage.InputTokens
usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
if streamResponse.Response.Usage.InputTokensDetails != nil {
usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens
}
}
case "response.output_text.delta":
// 处理输出文本
responseTextBuilder.WriteString(streamResponse.Delta)
-24
View File
@@ -18,30 +18,6 @@ import (
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
func requestOpenAI2PaLM(textRequest dto.GeneralOpenAIRequest) *PaLMChatRequest {
palmRequest := PaLMChatRequest{
Prompt: PaLMPrompt{
Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)),
},
Temperature: textRequest.Temperature,
CandidateCount: textRequest.N,
TopP: textRequest.TopP,
TopK: textRequest.MaxTokens,
}
for _, message := range textRequest.Messages {
palmMessage := PaLMChatMessage{
Content: message.StringContent(),
}
if message.Role == "user" {
palmMessage.Author = "0"
} else {
palmMessage.Author = "1"
}
palmRequest.Prompt.Messages = append(palmRequest.Prompt.Messages, palmMessage)
}
return &palmRequest
}
func responsePaLM2OpenAI(response *PaLMChatResponse) *dto.OpenAITextResponse {
fullTextResponse := dto.OpenAITextResponse{
Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
+1 -1
View File
@@ -16,6 +16,6 @@ func requestOpenAI2Perplexity(request dto.GeneralOpenAIRequest) *dto.GeneralOpen
Messages: messages,
Temperature: request.Temperature,
TopP: request.TopP,
MaxTokens: request.MaxTokens,
MaxTokens: request.GetMaxTokens(),
}
}
+2 -1
View File
@@ -35,6 +35,7 @@ var claudeModelMap = map[string]string{
"claude-3-7-sonnet-20250219": "claude-3-7-sonnet@20250219",
"claude-sonnet-4-20250514": "claude-sonnet-4@20250514",
"claude-opus-4-20250514": "claude-opus-4@20250514",
"claude-opus-4-1-20250805": "claude-opus-4-1@20250805",
}
const anthropicVersion = "vertex-2023-10-16"
@@ -237,7 +238,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
} else {
switch a.RequestMode {
case RequestModeClaude:
err, usage = claude.ClaudeHandler(c, resp, claude.RequestModeMessage, info)
err, usage = claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
case RequestModeGemini:
if info.RelayMode == constant.RelayModeGemini {
usage, err = gemini.GeminiTextGenerationHandler(c, info, resp)
+1 -1
View File
@@ -48,7 +48,7 @@ func requestOpenAI2Xunfei(request dto.GeneralOpenAIRequest, xunfeiAppId string,
xunfeiRequest.Parameter.Chat.Domain = domain
xunfeiRequest.Parameter.Chat.Temperature = request.Temperature
xunfeiRequest.Parameter.Chat.TopK = request.N
xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
xunfeiRequest.Parameter.Chat.MaxTokens = request.GetMaxTokens()
xunfeiRequest.Payload.Message.Text = messages
return &xunfeiRequest
}
+1 -1
View File
@@ -105,7 +105,7 @@ func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReq
Messages: messages,
Temperature: request.Temperature,
TopP: request.TopP,
MaxTokens: request.MaxTokens,
MaxTokens: request.GetMaxTokens(),
Stop: Stop,
Tools: request.Tools,
ToolChoice: request.ToolChoice,
+3
View File
@@ -225,6 +225,9 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
userId := common.GetContextKeyInt(c, constant.ContextKeyUserId)
tokenUnlimited := common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited)
startTime := common.GetContextKeyTime(c, constant.ContextKeyRequestStartTime)
if startTime.IsZero() {
startTime = time.Now()
}
// firstResponseTime = time.Now() - 1 second
apiType, _ := common.ChannelType2APIType(channelType)
+13
View File
@@ -44,6 +44,7 @@ func SetApiRouter(router *gin.Engine) {
{
userRoute.POST("/register", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Register)
userRoute.POST("/login", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Login)
userRoute.POST("/login/2fa", middleware.CriticalRateLimit(), controller.Verify2FALogin)
//userRoute.POST("/tokenlog", middleware.CriticalRateLimit(), controller.TokenLog)
userRoute.GET("/logout", controller.Logout)
userRoute.GET("/epay/notify", controller.EpayNotify)
@@ -66,6 +67,13 @@ func SetApiRouter(router *gin.Engine) {
selfRoute.POST("/stripe/amount", controller.RequestStripeAmount)
selfRoute.POST("/aff_transfer", controller.TransferAffQuota)
selfRoute.PUT("/setting", controller.UpdateUserSetting)
// 2FA routes
selfRoute.GET("/2fa/status", controller.Get2FAStatus)
selfRoute.POST("/2fa/setup", controller.Setup2FA)
selfRoute.POST("/2fa/enable", controller.Enable2FA)
selfRoute.POST("/2fa/disable", controller.Disable2FA)
selfRoute.POST("/2fa/backup_codes", controller.RegenerateBackupCodes)
}
adminRoute := userRoute.Group("/")
@@ -78,6 +86,10 @@ func SetApiRouter(router *gin.Engine) {
adminRoute.POST("/manage", controller.ManageUser)
adminRoute.PUT("/", controller.UpdateUser)
adminRoute.DELETE("/:id", controller.DeleteUser)
// Admin 2FA routes
adminRoute.GET("/2fa/stats", controller.Admin2FAStats)
adminRoute.DELETE("/:id/2fa", controller.AdminDisable2FA)
}
}
optionRoute := apiRouter.Group("/option")
@@ -120,6 +132,7 @@ func SetApiRouter(router *gin.Engine) {
channelRoute.POST("/batch/tag", controller.BatchSetChannelTag)
channelRoute.GET("/tag/models", controller.GetTagModels)
channelRoute.POST("/copy/:id", controller.CopyChannel)
channelRoute.POST("/multi_key/manage", controller.ManageMultiKeys)
}
tokenRoute := apiRouter.Group("/token")
tokenRoute.Use(middleware.UserAuth())
+5 -1
View File
@@ -283,7 +283,9 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
if chosenChoice.FinishReason != nil && *chosenChoice.FinishReason != "" {
// should be done
info.FinishReason = *chosenChoice.FinishReason
return claudeResponses
if !info.Done {
return claudeResponses
}
}
if info.Done {
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
@@ -432,6 +434,8 @@ func stopReasonOpenAI2Claude(reason string) string {
return "end_turn"
case "stop_sequence":
return "stop_sequence"
case "length":
fallthrough
case "max_tokens":
return "max_tokens"
case "tool_calls":
+3
View File
@@ -93,6 +93,9 @@ func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *t
if showBodyWhenFail {
newApiErr.Err = fmt.Errorf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody))
} else {
if common.DebugEnabled {
println(fmt.Sprintf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody)))
}
newApiErr.Err = fmt.Errorf("bad response status code %d", resp.StatusCode)
}
return
+4
View File
@@ -40,6 +40,8 @@ var defaultCacheRatio = map[string]float64{
"claude-sonnet-4-20250514-thinking": 0.1,
"claude-opus-4-20250514": 0.1,
"claude-opus-4-20250514-thinking": 0.1,
"claude-opus-4-1-20250805": 0.1,
"claude-opus-4-1-20250805-thinking": 0.1,
}
var defaultCreateCacheRatio = map[string]float64{
@@ -55,6 +57,8 @@ var defaultCreateCacheRatio = map[string]float64{
"claude-sonnet-4-20250514-thinking": 1.25,
"claude-opus-4-20250514": 1.25,
"claude-opus-4-20250514-thinking": 1.25,
"claude-opus-4-1-20250805": 1.25,
"claude-opus-4-1-20250805-thinking": 1.25,
}
//var defaultCreateCacheRatio = map[string]float64{}
+21 -17
View File
@@ -118,6 +118,7 @@ var defaultModelRatio = map[string]float64{
"claude-sonnet-4-20250514": 1.5,
"claude-3-opus-20240229": 7.5, // $15 / 1M tokens
"claude-opus-4-20250514": 7.5,
"claude-opus-4-1-20250805": 7.5,
"ERNIE-4.0-8K": 0.120 * RMB,
"ERNIE-3.5-8K": 0.012 * RMB,
"ERNIE-3.5-8K-0205": 0.024 * RMB,
@@ -334,12 +335,8 @@ func GetModelPrice(name string, printErr bool) (float64, bool) {
modelPriceMapMutex.RLock()
defer modelPriceMapMutex.RUnlock()
if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*"
}
if strings.HasPrefix(name, "gpt-4o-gizmo") {
name = "gpt-4o-gizmo-*"
}
name = FormatMatchingModelName(name)
price, ok := modelPriceMap[name]
if !ok {
if printErr {
@@ -373,11 +370,8 @@ func GetModelRatio(name string) (float64, bool, string) {
modelRatioMapMutex.RLock()
defer modelRatioMapMutex.RUnlock()
name = handleThinkingBudgetModel(name, "gemini-2.5-flash", "gemini-2.5-flash-thinking-*")
name = handleThinkingBudgetModel(name, "gemini-2.5-pro", "gemini-2.5-pro-thinking-*")
if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*"
}
name = FormatMatchingModelName(name)
ratio, ok := modelRatioMap[name]
if !ok {
return 37.5, operation_setting.SelfUseModeEnabled, name
@@ -428,12 +422,9 @@ func UpdateCompletionRatioByJSONString(jsonStr string) error {
func GetCompletionRatio(name string) float64 {
CompletionRatioMutex.RLock()
defer CompletionRatioMutex.RUnlock()
if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*"
}
if strings.HasPrefix(name, "gpt-4o-gizmo") {
name = "gpt-4o-gizmo-*"
}
name = FormatMatchingModelName(name)
if strings.Contains(name, "/") {
if ratio, ok := CompletionRatio[name]; ok {
return ratio
@@ -663,3 +654,16 @@ func GetCompletionRatioCopy() map[string]float64 {
}
return copyMap
}
// 转换模型名,减少渠道必须配置各种带参数模型
func FormatMatchingModelName(name string) string {
name = handleThinkingBudgetModel(name, "gemini-2.5-flash", "gemini-2.5-flash-thinking-*")
name = handleThinkingBudgetModel(name, "gemini-2.5-pro", "gemini-2.5-pro-thinking-*")
if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*"
}
if strings.HasPrefix(name, "gpt-4o-gizmo") {
name = "gpt-4o-gizmo-*"
}
return name
}
+10 -1
View File
@@ -189,9 +189,13 @@ func NewError(err error, errorCode ErrorCode, ops ...NewAPIErrorOptions) *NewAPI
}
func NewOpenAIError(err error, errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError {
if errorCode == ErrorCodeDoRequestFailed {
err = errors.New("upstream error: do request failed")
}
openaiError := OpenAIError{
Message: err.Error(),
Type: string(errorCode),
Code: errorCode,
}
return WithOpenAIError(openaiError, statusCode, ops...)
}
@@ -199,6 +203,7 @@ func NewOpenAIError(err error, errorCode ErrorCode, statusCode int, ops ...NewAP
func InitOpenAIError(errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError {
openaiError := OpenAIError{
Type: string(errorCode),
Code: errorCode,
}
return WithOpenAIError(openaiError, statusCode, ops...)
}
@@ -224,7 +229,11 @@ func NewErrorWithStatusCode(err error, errorCode ErrorCode, statusCode int, ops
func WithOpenAIError(openAIError OpenAIError, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError {
code, ok := openAIError.Code.(string)
if !ok {
code = fmt.Sprintf("%v", openAIError.Code)
if openAIError.Code == nil {
code = fmt.Sprintf("%v", openAIError.Code)
} else {
code = "unknown_error"
}
}
if openAIError.Type == "" {
openAIError.Type = "upstream_error"
+6 -3
View File
@@ -21,6 +21,7 @@
"lucide-react": "^0.511.0",
"marked": "^4.1.1",
"mermaid": "^11.6.0",
"qrcode.react": "^4.2.0",
"react": "^18.2.0",
"react-dom": "^18.2.0",
"react-dropzone": "^14.2.3",
@@ -1492,6 +1493,8 @@
"punycode": ["punycode@2.3.1", "", {}, "sha512-vYt7UD1U9Wg6138shLtLOvdAu+8DsC/ilFtEVHcH+wydcSpNE20AfSOduf6MkRFahL5FY7X1oU7nKVZFtfq8Fg=="],
"qrcode.react": ["qrcode.react@4.2.0", "", { "peerDependencies": { "react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0" } }, "sha512-QpgqWi8rD9DsS9EP3z7BT+5lY5SFhsqGjpgW5DY/i3mK4M9DTBNz3ErMi8BWYEfI3L0d8GIbGmcdFAS1uIRGjA=="],
"quansync": ["quansync@0.2.10", "", {}, "sha512-t41VRkMYbkHyCYmOvx/6URnN80H7k4X0lLdBMGsz+maAwrJQYB1djpV6vHrQIBE0WBSGqhtEHrK9U3DWWH8v7A=="],
"query-string": ["query-string@9.2.0", "", { "dependencies": { "decode-uri-component": "^0.4.1", "filter-obj": "^5.1.0", "split-on-first": "^3.0.0" } }, "sha512-YIRhrHujoQxhexwRLxfy3VSjOXmvZRd2nyw1PwL1UUqZ/ys1dEZd1+NSgXkne2l/4X/7OXkigEAuhTX0g/ivJQ=="],
@@ -1502,7 +1505,7 @@
"rc-checkbox": ["rc-checkbox@3.5.0", "", { "dependencies": { "@babel/runtime": "^7.10.1", "classnames": "^2.3.2", "rc-util": "^5.25.2" }, "peerDependencies": { "react": ">=16.9.0", "react-dom": ">=16.9.0" } }, "sha512-aOAQc3E98HteIIsSqm6Xk2FPKIER6+5vyEFMZfo73TqM+VVAIqOkHoPjgKLqSNtVLWScoaM7vY2ZrGEheI79yg=="],
"rc-collapse": ["rc-collapse@3.9.0", "", { "dependencies": { "@babel/runtime": "^7.10.1", "classnames": "2.x", "rc-motion": "^2.3.4", "rc-util": "^5.27.0" }, "peerDependencies": { "react": ">=16.9.0", "react-dom": ">=16.9.0" } }, "sha512-swDdz4QZ4dFTo4RAUMLL50qP0EY62N2kvmk2We5xYdRwcRn8WcYtuetCJpwpaCbUfUt5+huLpVxhvmnK+PHrkA=="],
"rc-collapse": ["rc-collapse@4.0.0", "", { "dependencies": { "@babel/runtime": "^7.10.1", "classnames": "2.x", "rc-motion": "^2.3.4", "rc-util": "^5.27.0" }, "peerDependencies": { "react": ">=16.9.0", "react-dom": ">=16.9.0" } }, "sha512-SwoOByE39/3oIokDs/BnkqI+ltwirZbP8HZdq1/3SkPSBi7xDdvWHTp7cpNI9ullozkR6mwTWQi6/E/9huQVrA=="],
"rc-dialog": ["rc-dialog@9.6.0", "", { "dependencies": { "@babel/runtime": "^7.10.1", "@rc-component/portal": "^1.0.0-8", "classnames": "^2.2.6", "rc-motion": "^2.3.0", "rc-util": "^5.21.0" }, "peerDependencies": { "react": ">=16.9.0", "react-dom": ">=16.9.0" } }, "sha512-ApoVi9Z8PaCQg6FsUzS8yvBEQy0ZL2PkuvAgrmohPkN3okps5WZ5WQWPc1RNuiOKaAYv8B97ACdsFU5LizzCqg=="],
@@ -1946,8 +1949,6 @@
"@lobehub/ui/lucide-react": ["lucide-react@0.484.0", "", { "peerDependencies": { "react": "^16.5.1 || ^17.0.0 || ^18.0.0 || ^19.0.0" } }, "sha512-oZy8coK9kZzvqhSgfbGkPtTgyjpBvs3ukLgDPv14dSOZtBtboryWF5o8i3qen7QbGg7JhiJBz5mK1p8YoMZTLQ=="],
"@lobehub/ui/rc-collapse": ["rc-collapse@4.0.0", "", { "dependencies": { "@babel/runtime": "^7.10.1", "classnames": "2.x", "rc-motion": "^2.3.4", "rc-util": "^5.27.0" }, "peerDependencies": { "react": ">=16.9.0", "react-dom": ">=16.9.0" } }, "sha512-SwoOByE39/3oIokDs/BnkqI+ltwirZbP8HZdq1/3SkPSBi7xDdvWHTp7cpNI9ullozkR6mwTWQi6/E/9huQVrA=="],
"@radix-ui/react-dismissable-layer/@radix-ui/react-compose-refs": ["@radix-ui/react-compose-refs@1.0.0", "", { "dependencies": { "@babel/runtime": "^7.13.10" }, "peerDependencies": { "react": "^16.8 || ^17.0 || ^18.0" } }, "sha512-0KaSv6sx787/hK3eF53iOkiSLwAGlFMx5lotrqD2pTjB18KbybKoEIgkNZTKC60YECDQTKGTRcDBILwZVqVKvA=="],
"@radix-ui/react-popper/@floating-ui/react-dom": ["@floating-ui/react-dom@0.7.2", "", { "dependencies": { "@floating-ui/dom": "^0.5.3", "use-isomorphic-layout-effect": "^1.1.1" }, "peerDependencies": { "react": ">=16.8.0", "react-dom": ">=16.8.0" } }, "sha512-1T0sJcpHgX/u4I1OzIEhlcrvkUN8ln39nz7fMoE/2HDHrPiMFoOGR7++GYyfUmIQHkkrTinaeQsO3XWubjSvGg=="],
@@ -1964,6 +1965,8 @@
"@visactor/vrender-kits/roughjs": ["roughjs@4.5.2", "", { "dependencies": { "path-data-parser": "^0.1.0", "points-on-curve": "^0.2.0", "points-on-path": "^0.2.1" } }, "sha512-2xSlLDKdsWyFxrveYWk9YQ/Y9UfK38EAMRNkYkMqYBJvPX8abCa9PN0x3w02H8Oa6/0bcZICJU+U95VumPqseg=="],
"antd/rc-collapse": ["rc-collapse@3.9.0", "", { "dependencies": { "@babel/runtime": "^7.10.1", "classnames": "2.x", "rc-motion": "^2.3.4", "rc-util": "^5.27.0" }, "peerDependencies": { "react": ">=16.9.0", "react-dom": ">=16.9.0" } }, "sha512-swDdz4QZ4dFTo4RAUMLL50qP0EY62N2kvmk2We5xYdRwcRn8WcYtuetCJpwpaCbUfUt5+huLpVxhvmnK+PHrkA=="],
"antd/scroll-into-view-if-needed": ["scroll-into-view-if-needed@3.1.0", "", { "dependencies": { "compute-scroll-into-view": "^3.0.2" } }, "sha512-49oNpRjWRvnU8NyGVmUaYG4jtTkNonFZI86MmGRDqBphEK2EXT9gdEUoQPZhuBM8yWHxCWbobltqYO5M4XrUvQ=="],
"chokidar/glob-parent": ["glob-parent@5.1.2", "", { "dependencies": { "is-glob": "^4.0.1" } }, "sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow=="],
+1
View File
@@ -21,6 +21,7 @@
"lucide-react": "^0.511.0",
"marked": "^4.1.1",
"mermaid": "^11.6.0",
"qrcode.react": "^4.2.0",
"react": "^18.2.0",
"react-dom": "^18.2.0",
"react-dropzone": "^14.2.3",
+54
View File
@@ -50,6 +50,7 @@ import { IconGithubLogo, IconMail, IconLock } from '@douyinfe/semi-icons';
import OIDCIcon from '../common/logo/OIDCIcon.js';
import WeChatIcon from '../common/logo/WeChatIcon.js';
import LinuxDoIcon from '../common/logo/LinuxDoIcon.js';
import TwoFAVerification from './TwoFAVerification.js';
import { useTranslation } from 'react-i18next';
const LoginForm = () => {
@@ -78,6 +79,7 @@ const LoginForm = () => {
const [resetPasswordLoading, setResetPasswordLoading] = useState(false);
const [otherLoginOptionsLoading, setOtherLoginOptionsLoading] = useState(false);
const [wechatCodeSubmitLoading, setWechatCodeSubmitLoading] = useState(false);
const [showTwoFA, setShowTwoFA] = useState(false);
const logo = getLogo();
const systemName = getSystemName();
@@ -162,6 +164,13 @@ const LoginForm = () => {
);
const { success, message, data } = res.data;
if (success) {
// 检查是否需要2FA验证
if (data && data.require_2fa) {
setShowTwoFA(true);
setLoginLoading(false);
return;
}
userDispatch({ type: 'login', payload: data });
setUserData(data);
updateAPI();
@@ -280,6 +289,21 @@ const LoginForm = () => {
setOtherLoginOptionsLoading(false);
};
// 2FA验证成功处理
const handle2FASuccess = (data) => {
userDispatch({ type: 'login', payload: data });
setUserData(data);
updateAPI();
showSuccess('登录成功!');
navigate('/console');
};
// 返回登录页面
const handleBackToLogin = () => {
setShowTwoFA(false);
setInputs({ username: '', password: '', wechat_verification_code: '' });
};
const renderOAuthOptions = () => {
return (
<div className="flex flex-col items-center">
@@ -537,6 +561,35 @@ const LoginForm = () => {
);
};
// 2FA验证弹窗
const render2FAModal = () => {
return (
<Modal
title={
<div className="flex items-center">
<div className="w-8 h-8 rounded-full bg-green-100 dark:bg-green-900 flex items-center justify-center mr-3">
<svg className="w-4 h-4 text-green-600 dark:text-green-400" fill="currentColor" viewBox="0 0 20 20">
<path fillRule="evenodd" d="M6 8a2 2 0 11-4 0 2 2 0 014 0zM8 7a1 1 0 100 2h8a1 1 0 100-2H8zM6 14a2 2 0 11-4 0 2 2 0 014 0zM8 13a1 1 0 100 2h8a1 1 0 100-2H8z" clipRule="evenodd" />
</svg>
</div>
两步验证
</div>
}
visible={showTwoFA}
onCancel={handleBackToLogin}
footer={null}
width={450}
centered
>
<TwoFAVerification
onSuccess={handle2FASuccess}
onBack={handleBackToLogin}
isModal={true}
/>
</Modal>
);
};
return (
<div className="relative overflow-hidden bg-gray-100 flex items-center justify-center py-12 px-4 sm:px-6 lg:px-8">
{/* 背景模糊晕染球 */}
@@ -547,6 +600,7 @@ const LoginForm = () => {
? renderEmailLoginForm()
: renderOAuthOptions()}
{renderWeChatLoginModal()}
{render2FAModal()}
{turnstileEnabled && (
<div className="flex justify-center mt-6">
@@ -0,0 +1,230 @@
/*
Copyright (C) 2025 QuantumNous
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
For commercial licensing, please contact support@quantumnous.com
*/
import { API, showError, showSuccess } from '../../helpers';
import { Button, Card, Divider, Form, Input, Typography } from '@douyinfe/semi-ui';
import React, { useState } from 'react';
const { Title, Text, Paragraph } = Typography;
const TwoFAVerification = ({ onSuccess, onBack, isModal = false }) => {
const [loading, setLoading] = useState(false);
const [useBackupCode, setUseBackupCode] = useState(false);
const [verificationCode, setVerificationCode] = useState('');
const handleSubmit = async () => {
if (!verificationCode) {
showError('请输入验证码');
return;
}
// Validate code format
if (useBackupCode && verificationCode.length !== 8) {
showError('备用码必须是8位');
return;
} else if (!useBackupCode && !/^\d{6}$/.test(verificationCode)) {
showError('验证码必须是6位数字');
return;
}
setLoading(true);
try {
const res = await API.post('/api/user/login/2fa', {
code: verificationCode
});
if (res.data.success) {
showSuccess('登录成功');
// 保存用户信息到本地存储
localStorage.setItem('user', JSON.stringify(res.data.data));
if (onSuccess) {
onSuccess(res.data.data);
}
} else {
showError(res.data.message);
}
} catch (error) {
showError('验证失败,请重试');
} finally {
setLoading(false);
}
};
const handleKeyPress = (e) => {
if (e.key === 'Enter') {
handleSubmit();
}
};
if (isModal) {
return (
<div className="space-y-4">
<Paragraph className="text-gray-600 dark:text-gray-300">
请输入认证器应用显示的验证码完成登录
</Paragraph>
<Form onSubmit={handleSubmit}>
<Form.Input
field="code"
label={useBackupCode ? "备用码" : "验证码"}
placeholder={useBackupCode ? "请输入8位备用码" : "请输入6位验证码"}
value={verificationCode}
onChange={setVerificationCode}
onKeyPress={handleKeyPress}
size="large"
style={{ marginBottom: 16 }}
autoFocus
/>
<Button
htmlType="submit"
type="primary"
loading={loading}
block
size="large"
style={{ marginBottom: 16 }}
>
验证并登录
</Button>
</Form>
<Divider />
<div style={{ textAlign: 'center' }}>
<Button
theme="borderless"
type="tertiary"
onClick={() => {
setUseBackupCode(!useBackupCode);
setVerificationCode('');
}}
style={{ marginRight: 16, color: '#1890ff', padding: 0 }}
>
{useBackupCode ? '使用认证器验证码' : '使用备用码'}
</Button>
{onBack && (
<Button
theme="borderless"
type="tertiary"
onClick={onBack}
style={{ color: '#1890ff', padding: 0 }}
>
返回登录
</Button>
)}
</div>
<div className="bg-gray-50 dark:bg-gray-800 rounded-lg p-3">
<Text size="small" type="secondary">
<strong>提示</strong>
<br />
验证码每30秒更新一次
<br />
如果无法获取验证码请使用备用码
<br />
每个备用码只能使用一次
</Text>
</div>
</div>
);
}
return (
<div style={{
display: 'flex',
justifyContent: 'center',
alignItems: 'center',
minHeight: '60vh'
}}>
<Card style={{ width: 400, padding: 24 }}>
<div style={{ textAlign: 'center', marginBottom: 24 }}>
<Title heading={3}>两步验证</Title>
<Paragraph type="secondary">
请输入认证器应用显示的验证码完成登录
</Paragraph>
</div>
<Form onSubmit={handleSubmit}>
<Form.Input
field="code"
label={useBackupCode ? "备用码" : "验证码"}
placeholder={useBackupCode ? "请输入8位备用码" : "请输入6位验证码"}
value={verificationCode}
onChange={setVerificationCode}
onKeyPress={handleKeyPress}
size="large"
style={{ marginBottom: 16 }}
autoFocus
/>
<Button
htmlType="submit"
type="primary"
loading={loading}
block
size="large"
style={{ marginBottom: 16 }}
>
验证并登录
</Button>
</Form>
<Divider />
<div style={{ textAlign: 'center' }}>
<Button
theme="borderless"
type="tertiary"
onClick={() => {
setUseBackupCode(!useBackupCode);
setVerificationCode('');
}}
style={{ marginRight: 16, color: '#1890ff', padding: 0 }}
>
{useBackupCode ? '使用认证器验证码' : '使用备用码'}
</Button>
{onBack && (
<Button
theme="borderless"
type="tertiary"
onClick={onBack}
style={{ color: '#1890ff', padding: 0 }}
>
返回登录
</Button>
)}
</div>
<div style={{ marginTop: 24, padding: 16, background: '#f6f8fa', borderRadius: 6 }}>
<Text size="small" type="secondary">
<strong>提示</strong>
<br />
验证码每30秒更新一次
<br />
如果无法获取验证码请使用备用码
<br />
每个备用码只能使用一次
</Text>
</div>
</Card>
</div>
);
};
export default TwoFAVerification;
@@ -36,6 +36,7 @@ import {
renderModelTag,
getModelCategories
} from '../../helpers';
import TwoFASetting from './TwoFASetting';
import Turnstile from 'react-turnstile';
import { UserContext } from '../../context/User';
import { useTheme } from '../../context/Theme';
@@ -1041,6 +1042,9 @@ const PersonalSetting = () => {
</div>
</Card>
{/* 两步验证设置 */}
<TwoFASetting />
{/* 危险区域 */}
<Card
className="!rounded-xl border-red-200 w-full"
+524
View File
@@ -0,0 +1,524 @@
/*
Copyright (C) 2025 QuantumNous
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
For commercial licensing, please contact support@quantumnous.com
*/
import { API, showError, showSuccess, showWarning } from '../../helpers';
import { Banner, Button, Card, Checkbox, Divider, Form, Input, Modal, Tag, Typography } from '@douyinfe/semi-ui';
import React, { useEffect, useState } from 'react';
import { QRCodeSVG } from 'qrcode.react';
const { Text, Paragraph } = Typography;
const TwoFASetting = () => {
const [loading, setLoading] = useState(false);
const [status, setStatus] = useState({
enabled: false,
locked: false,
backup_codes_remaining: 0
});
// 模态框状态
const [setupModalVisible, setSetupModalVisible] = useState(false);
const [enableModalVisible, setEnableModalVisible] = useState(false);
const [disableModalVisible, setDisableModalVisible] = useState(false);
const [backupModalVisible, setBackupModalVisible] = useState(false);
// 表单数据
const [setupData, setSetupData] = useState(null);
const [verificationCode, setVerificationCode] = useState('');
const [backupCodes, setBackupCodes] = useState([]);
const [confirmDisable, setConfirmDisable] = useState(false);
// 获取2FA状态
const fetchStatus = async () => {
try {
const res = await API.get('/api/user/2fa/status');
if (res.data.success) {
setStatus(res.data.data);
}
} catch (error) {
showError('获取2FA状态失败');
}
};
useEffect(() => {
fetchStatus();
}, []);
// 初始化2FA设置
const handleSetup2FA = async () => {
setLoading(true);
try {
const res = await API.post('/api/user/2fa/setup');
if (res.data.success) {
setSetupData(res.data.data);
setSetupModalVisible(true);
} else {
showError(res.data.message);
}
} catch (error) {
showError('设置2FA失败');
} finally {
setLoading(false);
}
};
// 启用2FA
const handleEnable2FA = async () => {
if (!verificationCode) {
showWarning('请输入验证码');
return;
}
setLoading(true);
try {
const res = await API.post('/api/user/2fa/enable', {
code: verificationCode
});
if (res.data.success) {
showSuccess('两步验证启用成功!');
setEnableModalVisible(false);
setSetupModalVisible(false);
setVerificationCode('');
fetchStatus();
} else {
showError(res.data.message);
}
} catch (error) {
showError('启用2FA失败');
} finally {
setLoading(false);
}
};
// 禁用2FA
const handleDisable2FA = async () => {
if (!verificationCode) {
showWarning('请输入验证码或备用码');
return;
}
if (!confirmDisable) {
showWarning('请确认您已了解禁用两步验证的后果');
return;
}
setLoading(true);
try {
const res = await API.post('/api/user/2fa/disable', {
code: verificationCode
});
if (res.data.success) {
showSuccess('两步验证已禁用');
setDisableModalVisible(false);
setVerificationCode('');
setConfirmDisable(false);
fetchStatus();
} else {
showError(res.data.message);
}
} catch (error) {
showError('禁用2FA失败');
} finally {
setLoading(false);
}
};
// 重新生成备用码
const handleRegenerateBackupCodes = async () => {
if (!verificationCode) {
showWarning('请输入验证码');
return;
}
setLoading(true);
try {
const res = await API.post('/api/user/2fa/backup_codes', {
code: verificationCode
});
if (res.data.success) {
setBackupCodes(res.data.data.backup_codes);
showSuccess('备用码重新生成成功');
setVerificationCode('');
fetchStatus();
} else {
showError(res.data.message);
}
} catch (error) {
showError('重新生成备用码失败');
} finally {
setLoading(false);
}
};
const copyBackupCodes = () => {
const codesText = backupCodes.join('\n');
navigator.clipboard.writeText(codesText).then(() => {
showSuccess('备用码已复制到剪贴板');
}).catch(() => {
showError('复制失败,请手动复制');
});
};
return (
<div>
<Card
className="!rounded-xl transition-shadow w-full"
bodyStyle={{ padding: '20px' }}
shadows='hover'
style={{ marginBottom: 16 }}
>
<div className="flex items-center justify-between">
<div className="flex items-center flex-1">
<div className="w-10 h-10 rounded-full bg-green-100 dark:bg-green-900 flex items-center justify-center mr-3">
<svg className="w-5 h-5 text-green-600 dark:text-green-400" fill="currentColor" viewBox="0 0 20 20">
<path fillRule="evenodd" d="M6 8a2 2 0 11-4 0 2 2 0 014 0zM8 7a1 1 0 100 2h8a1 1 0 100-2H8zM6 14a2 2 0 11-4 0 2 2 0 014 0zM8 13a1 1 0 100 2h8a1 1 0 100-2H8z" clipRule="evenodd" />
</svg>
</div>
<div className="flex-1 min-w-0">
<div className="font-medium text-gray-900 dark:text-gray-100">两步验证设置</div>
<div className="text-sm text-gray-500 dark:text-gray-400 mt-1">
两步验证2FA为您的账户提供额外的安全保护启用后登录时需要输入密码和验证器应用生成的验证码
</div>
<div className="flex items-center mt-2 space-x-2">
<Text strong>当前状态</Text>
{status.enabled ? (
<Tag color="green" size="small">已启用</Tag>
) : (
<Tag color="red" size="small">未启用</Tag>
)}
{status.locked && (
<Tag color="orange" size="small">账户已锁定</Tag>
)}
</div>
{status.enabled && (
<div className="mt-1">
<Text size="small" type="secondary">剩余备用码{status.backup_codes_remaining || 0} </Text>
</div>
)}
</div>
</div>
<div className="flex flex-col space-y-2">
{!status.enabled ? (
<Button
type="primary"
size="default"
onClick={handleSetup2FA}
loading={loading}
>
启用两步验证
</Button>
) : (
<div className="flex flex-col space-y-2">
<Button
type="danger"
size="default"
onClick={() => setDisableModalVisible(true)}
>
禁用两步验证
</Button>
<Button
size="default"
onClick={() => setBackupModalVisible(true)}
>
重新生成备用码
</Button>
</div>
)}
</div>
</div>
</Card>
{/* 2FA设置模态框 */}
<Modal
title={
<div className="flex items-center">
<div className="w-8 h-8 rounded-full bg-green-100 dark:bg-green-900 flex items-center justify-center mr-3">
<svg className="w-4 h-4 text-green-600 dark:text-green-400" fill="currentColor" viewBox="0 0 20 20">
<path fillRule="evenodd" d="M6 8a2 2 0 11-4 0 2 2 0 014 0zM8 7a1 1 0 100 2h8a1 1 0 100-2H8zM6 14a2 2 0 11-4 0 2 2 0 014 0zM8 13a1 1 0 100 2h8a1 1 0 100-2H8z" clipRule="evenodd" />
</svg>
</div>
设置两步验证
</div>
}
visible={setupModalVisible}
onCancel={() => {
setSetupModalVisible(false);
setSetupData(null);
}}
footer={null}
width={650}
style={{ maxWidth: '90vw' }}
>
{setupData && (
<div className="space-y-6">
{/* 步骤 1:扫描二维码 */}
<div className="bg-gray-50 dark:bg-gray-800 rounded-lg p-4">
<div className="flex items-center mb-3">
<div className="w-6 h-6 rounded-full bg-blue-500 text-white flex items-center justify-center text-sm font-medium mr-2">
1
</div>
<Text strong className="text-gray-900 dark:text-gray-100">扫描二维码</Text>
</div>
<Paragraph className="text-gray-600 dark:text-gray-300 mb-4">
使用认证器应用 Google AuthenticatorMicrosoft Authenticator扫描下方二维码
</Paragraph>
<div className="flex justify-center mb-4">
<div className="bg-white p-4 rounded-lg shadow-sm">
<QRCodeSVG value={setupData.qr_code_data} size={180} />
</div>
</div>
<div className="bg-blue-50 dark:bg-blue-900 rounded-lg p-3">
<Text className="text-blue-800 dark:text-blue-200 text-sm">
或手动输入密钥<Text code copyable className="ml-2">{setupData.secret}</Text>
</Text>
</div>
</div>
{/* 步骤 2:保存备用码 */}
<div className="bg-gray-50 dark:bg-gray-800 rounded-lg p-4">
<div className="flex items-center mb-3">
<div className="w-6 h-6 rounded-full bg-orange-500 text-white flex items-center justify-center text-sm font-medium mr-2">
2
</div>
<Text strong className="text-gray-900 dark:text-gray-100">保存备用码</Text>
</div>
<Paragraph className="text-gray-600 dark:text-gray-300 mb-4">
请将以下备用码保存在安全的地方如果丢失手机可以使用这些备用码登录
</Paragraph>
<div className="bg-yellow-50 dark:bg-yellow-900 border border-yellow-200 dark:border-yellow-700 rounded-lg p-4">
<div className="grid grid-cols-2 gap-2 mb-3">
{setupData.backup_codes.map((code, index) => (
<div key={index} className="bg-white dark:bg-gray-700 p-2 rounded text-center">
<Text code className="text-sm">{code}</Text>
</div>
))}
</div>
<Button
size="small"
type="primary"
onClick={() => {
const codesText = setupData.backup_codes.join('\n');
navigator.clipboard.writeText(codesText);
showSuccess('备用码已复制');
}}
className="w-full"
>
复制所有备用码
</Button>
</div>
</div>
{/* 步骤 3:验证设置 */}
<div className="bg-gray-50 dark:bg-gray-800 rounded-lg p-4">
<div className="flex items-center mb-3">
<div className="w-6 h-6 rounded-full bg-green-500 text-white flex items-center justify-center text-sm font-medium mr-2">
3
</div>
<Text strong className="text-gray-900 dark:text-gray-100">验证设置</Text>
</div>
<Paragraph className="text-gray-600 dark:text-gray-300 mb-4">
输入认证器应用显示的6位数字验证码
</Paragraph>
<Form onSubmit={handleEnable2FA}>
<Form.Input
field="code"
placeholder="请输入6位验证码"
value={verificationCode}
onChange={setVerificationCode}
size="large"
style={{ marginBottom: 16 }}
maxLength={6}
/>
<Button
htmlType="submit"
type="primary"
loading={loading}
size="large"
block
>
完成设置并启用两步验证
</Button>
</Form>
</div>
</div>
)}
</Modal>
{/* 禁用2FA模态框 */}
<Modal
title={
<div className="flex items-center">
<div className="w-8 h-8 rounded-full bg-red-100 dark:bg-red-900 flex items-center justify-center mr-3">
<svg className="w-4 h-4 text-red-600 dark:text-red-400" fill="currentColor" viewBox="0 0 20 20">
<path fillRule="evenodd" d="M8.257 3.099c.765-1.36 2.722-1.36 3.486 0l5.58 9.92c.75 1.334-.213 2.98-1.742 2.98H4.42c-1.53 0-2.493-1.646-1.743-2.98l5.58-9.92zM11 13a1 1 0 11-2 0 1 1 0 012 0zm-1-8a1 1 0 00-1 1v3a1 1 0 002 0V6a1 1 0 00-1-1z" clipRule="evenodd" />
</svg>
</div>
禁用两步验证
</div>
}
visible={disableModalVisible}
onCancel={() => {
setDisableModalVisible(false);
setVerificationCode('');
setConfirmDisable(false);
}}
footer={null}
width={550}
>
<div className="space-y-4">
<Banner
type="warning"
description={
<div className="space-y-2">
<div className="font-medium">警告禁用两步验证将会</div>
<ul className="list-disc list-inside space-y-1 text-sm">
<li>降低您账户的安全性</li>
<li>永久删除您的两步验证设置</li>
<li>永久删除所有备用码包括未使用的</li>
<li>需要重新完整设置才能再次启用</li>
</ul>
<div className="text-sm text-red-600 dark:text-red-400 font-medium mt-2">
此操作不可撤销请谨慎操作
</div>
</div>
}
className="rounded-lg"
/>
<Form onSubmit={handleDisable2FA}>
<Form.Input
field="code"
label="验证码"
placeholder="请输入认证器验证码或备用码"
value={verificationCode}
onChange={setVerificationCode}
size="large"
style={{ marginBottom: 16 }}
/>
<div className="mb-4">
<Checkbox
checked={confirmDisable}
onChange={(e) => setConfirmDisable(e.target.checked)}
className="text-sm"
>
我已了解禁用两步验证将永久删除所有相关设置和备用码此操作不可撤销
</Checkbox>
</div>
<Button
htmlType="submit"
type="danger"
loading={loading}
size="large"
block
disabled={!confirmDisable}
>
确认禁用两步验证
</Button>
</Form>
</div>
</Modal>
{/* 重新生成备用码模态框 */}
<Modal
title={
<div className="flex items-center">
<div className="w-8 h-8 rounded-full bg-blue-100 dark:bg-blue-900 flex items-center justify-center mr-3">
<svg className="w-4 h-4 text-blue-600 dark:text-blue-400" fill="currentColor" viewBox="0 0 20 20">
<path fillRule="evenodd" d="M4 2a1 1 0 011 1v2.101a7.002 7.002 0 0111.601 2.566 1 1 0 11-1.885.666A5.002 5.002 0 005.999 7H9a1 1 0 010 2H4a1 1 0 01-1-1V3a1 1 0 011-1zm.008 9.057a1 1 0 011.276.61A5.002 5.002 0 0014.001 13H11a1 1 0 110-2h5a1 1 0 011 1v5a1 1 0 11-2 0v-2.101a7.002 7.002 0 01-11.601-2.566 1 1 0 01.61-1.276z" clipRule="evenodd" />
</svg>
</div>
重新生成备用码
</div>
}
visible={backupModalVisible}
onCancel={() => {
setBackupModalVisible(false);
setVerificationCode('');
setBackupCodes([]);
}}
footer={null}
width={500}
>
<div className="space-y-4">
{backupCodes.length === 0 ? (
<>
<Banner
type="warning"
description="重新生成备用码将使现有的备用码失效,请确保您已保存了当前的备用码。"
className="rounded-lg"
/>
<Form onSubmit={handleRegenerateBackupCodes}>
<Form.Input
field="code"
label="验证码"
placeholder="请输入认证器验证码"
value={verificationCode}
onChange={setVerificationCode}
size="large"
style={{ marginBottom: 16 }}
/>
<Button
htmlType="submit"
type="primary"
loading={loading}
size="large"
block
>
生成新的备用码
</Button>
</Form>
</>
) : (
<>
<div className="text-center mb-4">
<div className="w-12 h-12 rounded-full bg-green-100 dark:bg-green-900 flex items-center justify-center mx-auto mb-2">
<svg className="w-6 h-6 text-green-600 dark:text-green-400" fill="currentColor" viewBox="0 0 20 20">
<path fillRule="evenodd" d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z" clipRule="evenodd" />
</svg>
</div>
<Text strong className="text-lg">新的备用码已生成</Text>
<Paragraph className="text-gray-600 dark:text-gray-300 mt-2">
请将以下备用码保存在安全的地方
</Paragraph>
</div>
<div className="bg-yellow-50 dark:bg-yellow-900 border border-yellow-200 dark:border-yellow-700 rounded-lg p-4">
<div className="grid grid-cols-2 gap-2 mb-3">
{backupCodes.map((code, index) => (
<div key={index} className="bg-white dark:bg-gray-700 p-2 rounded text-center">
<Text code className="text-sm">{code}</Text>
</div>
))}
</div>
<Button
onClick={copyBackupCodes}
type="primary"
size="large"
block
>
复制所有备用码
</Button>
</div>
</>
)}
</div>
</Modal>
</div>
);
};
export default TwoFASetting;
@@ -210,7 +210,9 @@ export const getChannelsColumns = ({
copySelectedChannel,
refresh,
activePage,
channels
channels,
setShowMultiKeyManageModal,
setCurrentMultiKeyChannel
}) => {
return [
{
@@ -503,47 +505,7 @@ export const getChannelsColumns = ({
/>
</SplitButtonGroup>
{record.channel_info?.is_multi_key ? (
<SplitButtonGroup
aria-label={t('多密钥渠道操作项目组')}
>
{
record.status === 1 ? (
<Button
type='danger'
size="small"
onClick={() => manageChannel(record.id, 'disable', record)}
>
{t('禁用')}
</Button>
) : (
<Button
size="small"
onClick={() => manageChannel(record.id, 'enable', record)}
>
{t('启用')}
</Button>
)
}
<Dropdown
trigger='click'
position='bottomRight'
menu={[
{
node: 'item',
name: t('启用全部密钥'),
onClick: () => manageChannel(record.id, 'enable_all', record),
}
]}
>
<Button
type='tertiary'
size="small"
icon={<IconTreeTriangleDown />}
/>
</Dropdown>
</SplitButtonGroup>
) : (
{
record.status === 1 ? (
<Button
type='danger'
@@ -560,18 +522,55 @@ export const getChannelsColumns = ({
{t('启用')}
</Button>
)
)}
}
<Button
type='tertiary'
size="small"
onClick={() => {
setEditingChannel(record);
setShowEdit(true);
}}
>
{t('编辑')}
</Button>
{record.channel_info?.is_multi_key ? (
<SplitButtonGroup
aria-label={t('多密钥渠道操作项目组')}
>
<Button
type='tertiary'
size="small"
onClick={() => {
setEditingChannel(record);
setShowEdit(true);
}}
>
{t('编辑')}
</Button>
<Dropdown
trigger='click'
position='bottomRight'
menu={[
{
node: 'item',
name: t('多key管理'),
onClick: () => {
setCurrentMultiKeyChannel(record);
setShowMultiKeyManageModal(true);
},
}
]}
>
<Button
type='tertiary'
size="small"
icon={<IconTreeTriangleDown />}
/>
</Dropdown>
</SplitButtonGroup>
) : (
<Button
type='tertiary'
size="small"
onClick={() => {
setEditingChannel(record);
setShowEdit(true);
}}
>
{t('编辑')}
</Button>
)}
<Dropdown
trigger='click'
@@ -57,6 +57,9 @@ const ChannelsTable = (channelsData) => {
setEditingTag,
copySelectedChannel,
refresh,
// Multi-key management
setShowMultiKeyManageModal,
setCurrentMultiKeyChannel,
} = channelsData;
// Get all columns
@@ -79,6 +82,8 @@ const ChannelsTable = (channelsData) => {
refresh,
activePage,
channels,
setShowMultiKeyManageModal,
setCurrentMultiKeyChannel,
});
}, [
t,
@@ -98,6 +103,8 @@ const ChannelsTable = (channelsData) => {
refresh,
activePage,
channels,
setShowMultiKeyManageModal,
setCurrentMultiKeyChannel,
]);
// Filter columns based on visibility settings
@@ -30,6 +30,7 @@ import ModelTestModal from './modals/ModelTestModal.jsx';
import ColumnSelectorModal from './modals/ColumnSelectorModal.jsx';
import EditChannelModal from './modals/EditChannelModal.jsx';
import EditTagModal from './modals/EditTagModal.jsx';
import MultiKeyManageModal from './modals/MultiKeyManageModal.jsx';
import { createCardProPagination } from '../../../helpers/utils';
const ChannelsPage = () => {
@@ -54,6 +55,12 @@ const ChannelsPage = () => {
/>
<BatchTagModal {...channelsData} />
<ModelTestModal {...channelsData} />
<MultiKeyManageModal
visible={channelsData.showMultiKeyManageModal}
onCancel={() => channelsData.setShowMultiKeyManageModal(false)}
channel={channelsData.currentMultiKeyChannel}
onRefresh={channelsData.refresh}
/>
{/* Main Content */}
<CardPro
@@ -0,0 +1,589 @@
/*
Copyright (C) 2025 QuantumNous
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
For commercial licensing, please contact support@quantumnous.com
*/
import React, { useState, useEffect } from 'react';
import { useTranslation } from 'react-i18next';
import {
Modal,
Button,
Table,
Tag,
Typography,
Space,
Tooltip,
Popconfirm,
Empty,
Spin,
Banner,
Select,
Pagination
} from '@douyinfe/semi-ui';
import {
IconRefresh,
IconDelete,
IconClose,
IconSave,
IconSetting
} from '@douyinfe/semi-icons';
import { API, showError, showSuccess, timestamp2string } from '../../../../helpers/index.js';
const { Text, Title } = Typography;
const MultiKeyManageModal = ({
visible,
onCancel,
channel,
onRefresh
}) => {
const { t } = useTranslation();
const [loading, setLoading] = useState(false);
const [keyStatusList, setKeyStatusList] = useState([]);
const [operationLoading, setOperationLoading] = useState({});
// Pagination states
const [currentPage, setCurrentPage] = useState(1);
const [pageSize, setPageSize] = useState(50);
const [total, setTotal] = useState(0);
const [totalPages, setTotalPages] = useState(0);
// Statistics states
const [enabledCount, setEnabledCount] = useState(0);
const [manualDisabledCount, setManualDisabledCount] = useState(0);
const [autoDisabledCount, setAutoDisabledCount] = useState(0);
// Filter states
const [statusFilter, setStatusFilter] = useState(null); // null=all, 1=enabled, 2=manual_disabled, 3=auto_disabled
// Load key status data
const loadKeyStatus = async (page = currentPage, size = pageSize, status = statusFilter) => {
if (!channel?.id) return;
setLoading(true);
try {
const requestData = {
channel_id: channel.id,
action: 'get_key_status',
page: page,
page_size: size
};
// Add status filter if specified
if (status !== null) {
requestData.status = status;
}
const res = await API.post('/api/channel/multi_key/manage', requestData);
if (res.data.success) {
const data = res.data.data;
setKeyStatusList(data.keys || []);
setTotal(data.total || 0);
setCurrentPage(data.page || 1);
setPageSize(data.page_size || 50);
setTotalPages(data.total_pages || 0);
// Update statistics (these are always the overall statistics)
setEnabledCount(data.enabled_count || 0);
setManualDisabledCount(data.manual_disabled_count || 0);
setAutoDisabledCount(data.auto_disabled_count || 0);
} else {
showError(res.data.message);
}
} catch (error) {
console.error(error);
showError(t('获取密钥状态失败'));
} finally {
setLoading(false);
}
};
// Disable a specific key
const handleDisableKey = async (keyIndex) => {
const operationId = `disable_${keyIndex}`;
setOperationLoading(prev => ({ ...prev, [operationId]: true }));
try {
const res = await API.post('/api/channel/multi_key/manage', {
channel_id: channel.id,
action: 'disable_key',
key_index: keyIndex
});
if (res.data.success) {
showSuccess(t('密钥已禁用'));
await loadKeyStatus(currentPage, pageSize); // Reload current page
onRefresh && onRefresh(); // Refresh parent component
} else {
showError(res.data.message);
}
} catch (error) {
showError(t('禁用密钥失败'));
} finally {
setOperationLoading(prev => ({ ...prev, [operationId]: false }));
}
};
// Enable a specific key
const handleEnableKey = async (keyIndex) => {
const operationId = `enable_${keyIndex}`;
setOperationLoading(prev => ({ ...prev, [operationId]: true }));
try {
const res = await API.post('/api/channel/multi_key/manage', {
channel_id: channel.id,
action: 'enable_key',
key_index: keyIndex
});
if (res.data.success) {
showSuccess(t('密钥已启用'));
await loadKeyStatus(currentPage, pageSize); // Reload current page
onRefresh && onRefresh(); // Refresh parent component
} else {
showError(res.data.message);
}
} catch (error) {
showError(t('启用密钥失败'));
} finally {
setOperationLoading(prev => ({ ...prev, [operationId]: false }));
}
};
// Enable all disabled keys
const handleEnableAll = async () => {
setOperationLoading(prev => ({ ...prev, enable_all: true }));
try {
const res = await API.post('/api/channel/multi_key/manage', {
channel_id: channel.id,
action: 'enable_all_keys'
});
if (res.data.success) {
showSuccess(res.data.message || t('已启用所有密钥'));
// Reset to first page after bulk operation
setCurrentPage(1);
await loadKeyStatus(1, pageSize);
onRefresh && onRefresh(); // Refresh parent component
} else {
showError(res.data.message);
}
} catch (error) {
showError(t('启用所有密钥失败'));
} finally {
setOperationLoading(prev => ({ ...prev, enable_all: false }));
}
};
// Disable all enabled keys
const handleDisableAll = async () => {
setOperationLoading(prev => ({ ...prev, disable_all: true }));
try {
const res = await API.post('/api/channel/multi_key/manage', {
channel_id: channel.id,
action: 'disable_all_keys'
});
if (res.data.success) {
showSuccess(res.data.message || t('已禁用所有密钥'));
// Reset to first page after bulk operation
setCurrentPage(1);
await loadKeyStatus(1, pageSize);
onRefresh && onRefresh(); // Refresh parent component
} else {
showError(res.data.message);
}
} catch (error) {
showError(t('禁用所有密钥失败'));
} finally {
setOperationLoading(prev => ({ ...prev, disable_all: false }));
}
};
// Delete all disabled keys
const handleDeleteDisabledKeys = async () => {
setOperationLoading(prev => ({ ...prev, delete_disabled: true }));
try {
const res = await API.post('/api/channel/multi_key/manage', {
channel_id: channel.id,
action: 'delete_disabled_keys'
});
if (res.data.success) {
showSuccess(res.data.message);
// Reset to first page after deletion as data structure might change
setCurrentPage(1);
await loadKeyStatus(1, pageSize);
onRefresh && onRefresh(); // Refresh parent component
} else {
showError(res.data.message);
}
} catch (error) {
showError(t('删除禁用密钥失败'));
} finally {
setOperationLoading(prev => ({ ...prev, delete_disabled: false }));
}
};
// Handle page change
const handlePageChange = (page) => {
setCurrentPage(page);
loadKeyStatus(page, pageSize);
};
// Handle page size change
const handlePageSizeChange = (size) => {
setPageSize(size);
setCurrentPage(1); // Reset to first page
loadKeyStatus(1, size);
};
// Handle status filter change
const handleStatusFilterChange = (status) => {
setStatusFilter(status);
setCurrentPage(1); // Reset to first page when filter changes
loadKeyStatus(1, pageSize, status);
};
// Effect to load data when modal opens
useEffect(() => {
if (visible && channel?.id) {
setCurrentPage(1); // Reset to first page when opening
loadKeyStatus(1, pageSize);
}
}, [visible, channel?.id]);
// Reset pagination when modal closes
useEffect(() => {
if (!visible) {
setCurrentPage(1);
setKeyStatusList([]);
setTotal(0);
setTotalPages(0);
setEnabledCount(0);
setManualDisabledCount(0);
setAutoDisabledCount(0);
setStatusFilter(null); // Reset filter
}
}, [visible]);
// Get status tag component
const renderStatusTag = (status) => {
switch (status) {
case 1:
return <Tag color='green' shape='circle'>{t('已启用')}</Tag>;
case 2:
return <Tag color='red' shape='circle'>{t('已禁用')}</Tag>;
case 3:
return <Tag color='orange' shape='circle'>{t('自动禁用')}</Tag>;
default:
return <Tag color='grey' shape='circle'>{t('未知状态')}</Tag>;
}
};
// Table columns definition
const columns = [
{
title: t('索引'),
dataIndex: 'index',
render: (text) => `#${text}`,
},
// {
// title: t(''),
// dataIndex: 'key_preview',
// render: (text) => (
// <Text code style={{ fontSize: '12px' }}>
// {text}
// </Text>
// ),
// },
{
title: t('状态'),
dataIndex: 'status',
width: 100,
render: (status) => renderStatusTag(status),
},
{
title: t('禁用原因'),
dataIndex: 'reason',
width: 220,
render: (reason, record) => {
if (record.status === 1 || !reason) {
return <Text type='quaternary'>-</Text>;
}
return (
<Tooltip content={reason}>
<Text style={{ maxWidth: '200px', display: 'block' }} ellipsis>
{reason}
</Text>
</Tooltip>
);
},
},
{
title: t('禁用时间'),
dataIndex: 'disabled_time',
width: 150,
render: (time, record) => {
if (record.status === 1 || !time) {
return <Text type='quaternary'>-</Text>;
}
return (
<Tooltip content={timestamp2string(time)}>
<Text style={{ fontSize: '12px' }}>
{timestamp2string(time)}
</Text>
</Tooltip>
);
},
},
{
title: t('操作'),
key: 'action',
width: 120,
render: (_, record) => (
<Space>
{record.status === 1 ? (
<Button
type='danger'
size='small'
loading={operationLoading[`disable_${record.index}`]}
onClick={() => handleDisableKey(record.index)}
>
{t('禁用')}
</Button>
) : (
<Button
type='primary'
size='small'
loading={operationLoading[`enable_${record.index}`]}
onClick={() => handleEnableKey(record.index)}
>
{t('启用')}
</Button>
)}
</Space>
),
},
];
return (
<Modal
title={
<Space>
<IconSetting />
<span>{t('多密钥管理')} - {channel?.name}</span>
</Space>
}
visible={visible}
onCancel={onCancel}
width={900}
footer={
<Space>
<Button onClick={onCancel}>{t('关闭')}</Button>
<Button
icon={<IconRefresh />}
onClick={() => loadKeyStatus(currentPage, pageSize)}
loading={loading}
>
{t('刷新')}
</Button>
<Popconfirm
title={t('确定要启用所有密钥吗?')}
onConfirm={handleEnableAll}
position={'topRight'}
>
<Button
type='primary'
loading={operationLoading.enable_all}
>
{t('启用全部')}
</Button>
</Popconfirm>
{enabledCount > 0 && (
<Popconfirm
title={t('确定要禁用所有的密钥吗?')}
onConfirm={handleDisableAll}
okType={'danger'}
position={'topRight'}
>
<Button
type='danger'
loading={operationLoading.disable_all}
>
{t('禁用全部')}
</Button>
</Popconfirm>
)}
<Popconfirm
title={t('确定要删除所有已自动禁用的密钥吗?')}
content={t('此操作不可撤销,将永久删除已自动禁用的密钥')}
onConfirm={handleDeleteDisabledKeys}
okType={'danger'}
position={'topRight'}
>
<Button
type='danger'
icon={<IconDelete />}
loading={operationLoading.delete_disabled}
>
{t('删除自动禁用密钥')}
</Button>
</Popconfirm>
</Space>
}
>
<div style={{ height: '100%', display: 'flex', flexDirection: 'column' }}>
{/* Statistics Banner */}
<Banner
type='info'
style={{ marginBottom: '16px', flexShrink: 0 }}
description={
<div>
<Text>
{t('总共 {{total}} 个密钥,{{enabled}} 个已启用,{{manual}} 个手动禁用,{{auto}} 个自动禁用', {
total: total,
enabled: enabledCount,
manual: manualDisabledCount,
auto: autoDisabledCount
})}
</Text>
{channel?.channel_info?.multi_key_mode && (
<div style={{ marginTop: '4px' }}>
<Text type='quaternary' style={{ fontSize: '12px' }}>
{t('多密钥模式')}: {channel.channel_info.multi_key_mode === 'random' ? t('随机') : t('轮询')}
</Text>
</div>
)}
</div>
}
/>
{/* Filter Controls */}
<div style={{ marginBottom: '16px', display: 'flex', alignItems: 'center', gap: '12px', flexShrink: 0 }}>
<Text style={{ fontSize: '14px', fontWeight: '500' }}>{t('状态筛选')}:</Text>
<Select
value={statusFilter}
onChange={handleStatusFilterChange}
style={{ width: '120px' }}
size='small'
placeholder={t('全部状态')}
>
<Select.Option value={null}>{t('全部状态')}</Select.Option>
<Select.Option value={1}>{t('已启用')}</Select.Option>
<Select.Option value={2}>{t('手动禁用')}</Select.Option>
<Select.Option value={3}>{t('自动禁用')}</Select.Option>
</Select>
{statusFilter !== null && (
<Text type='quaternary' style={{ fontSize: '12px' }}>
{t('当前显示 {{count}} 条筛选结果', { count: total })}
</Text>
)}
</div>
{/* Key Status Table */}
<div style={{ flex: 1, display: 'flex', flexDirection: 'column', minHeight: 0 }}>
<Spin spinning={loading}>
{keyStatusList.length > 0 ? (
<div style={{ height: '100%', display: 'flex', flexDirection: 'column' }}>
<div style={{ flex: 1, overflow: 'auto', marginBottom: '16px' }}>
<Table
columns={columns}
dataSource={keyStatusList}
pagination={false}
size='small'
bordered
rowKey='index'
scroll={{ y: 'calc(100vh - 400px)' }}
/>
</div>
{/* Pagination */}
{total > 0 && (
<div style={{
display: 'flex',
justifyContent: 'space-between',
alignItems: 'center',
flexShrink: 0,
padding: '12px 0',
borderTop: '1px solid var(--semi-color-border)',
backgroundColor: 'var(--semi-color-bg-1)'
}}>
<Text type='quaternary' style={{ fontSize: '12px' }}>
{t('显示第 {{start}}-{{end}} 条,共 {{total}} 条', {
start: (currentPage - 1) * pageSize + 1,
end: Math.min(currentPage * pageSize, total),
total: total
})}
</Text>
<div style={{ display: 'flex', alignItems: 'center', gap: '12px' }}>
<Text type='quaternary' style={{ fontSize: '12px' }}>
{t('每页显示')}:
</Text>
<Select
value={pageSize}
onChange={handlePageSizeChange}
size='small'
style={{ width: '80px' }}
>
<Select.Option value={50}>50</Select.Option>
<Select.Option value={100}>100</Select.Option>
<Select.Option value={500}>500</Select.Option>
<Select.Option value={1000}>1000</Select.Option>
</Select>
<Pagination
current={currentPage}
total={total}
pageSize={pageSize}
showSizeChanger={false}
showQuickJumper
size='small'
onChange={handlePageChange}
showTotal={(total, range) =>
t('第 {{current}} / {{total}} 页', {
current: currentPage,
total: totalPages
})
}
/>
</div>
</div>
)}
</div>
) : (
!loading && (
<Empty
image={Empty.PRESENTED_IMAGE_SIMPLE}
title={t('暂无密钥数据')}
description={t('请检查渠道配置或刷新重试')}
/>
)
)}
</Spin>
</div>
</div>
</Modal>
);
};
export default MultiKeyManageModal;
+7 -3
View File
@@ -1156,6 +1156,7 @@ export function renderLogContent(
modelPrice = -1,
groupRatio,
user_group_ratio,
cacheRatio = 1.0,
image = false,
imageRatio = 1.0,
webSearch = false,
@@ -1174,9 +1175,10 @@ export function renderLogContent(
} else {
if (image) {
return i18next.t(
'模型倍率 {{modelRatio}},输出倍率 {{completionRatio}},图片输入倍率 {{imageRatio}}{{ratioType}} {{ratio}}',
'模型倍率 {{modelRatio}}缓存倍率 {{cacheRatio}}输出倍率 {{completionRatio}},图片输入倍率 {{imageRatio}}{{ratioType}} {{ratio}}',
{
modelRatio: modelRatio,
cacheRatio: cacheRatio,
completionRatio: completionRatio,
imageRatio: imageRatio,
ratioType: ratioLabel,
@@ -1185,9 +1187,10 @@ export function renderLogContent(
);
} else if (webSearch) {
return i18next.t(
'模型倍率 {{modelRatio}},输出倍率 {{completionRatio}}{{ratioType}} {{ratio}}Web 搜索调用 {{webSearchCallCount}} 次',
'模型倍率 {{modelRatio}}缓存倍率 {{cacheRatio}}输出倍率 {{completionRatio}}{{ratioType}} {{ratio}}Web 搜索调用 {{webSearchCallCount}} 次',
{
modelRatio: modelRatio,
cacheRatio: cacheRatio,
completionRatio: completionRatio,
ratioType: ratioLabel,
ratio,
@@ -1196,9 +1199,10 @@ export function renderLogContent(
);
} else {
return i18next.t(
'模型倍率 {{modelRatio}},输出倍率 {{completionRatio}}{{ratioType}} {{ratio}}',
'模型倍率 {{modelRatio}}缓存倍率 {{cacheRatio}}输出倍率 {{completionRatio}}{{ratioType}} {{ratio}}',
{
modelRatio: modelRatio,
cacheRatio: cacheRatio,
completionRatio: completionRatio,
ratioType: ratioLabel,
ratio,
+10
View File
@@ -83,6 +83,10 @@ export const useChannelsData = () => {
const [isProcessingQueue, setIsProcessingQueue] = useState(false);
const [modelTablePage, setModelTablePage] = useState(1);
// Multi-key management states
const [showMultiKeyManageModal, setShowMultiKeyManageModal] = useState(false);
const [currentMultiKeyChannel, setCurrentMultiKeyChannel] = useState(null);
// Refs
const requestCounter = useRef(0);
const allSelectingRef = useRef(false);
@@ -885,6 +889,12 @@ export const useChannelsData = () => {
setModelTablePage,
allSelectingRef,
// Multi-key management states
showMultiKeyManageModal,
setShowMultiKeyManageModal,
currentMultiKeyChannel,
setCurrentMultiKeyChannel,
// Form
formApi,
setFormApi,
@@ -366,6 +366,7 @@ export const useLogsData = () => {
other.model_price,
other.group_ratio,
other?.user_group_ratio,
other.cache_ratio || 1.0,
false,
1.0,
other.web_search || false,
@@ -44,6 +44,7 @@ export default function ModelSettingsVisualEditor(props) {
const { t } = useTranslation();
const [models, setModels] = useState([]);
const [visible, setVisible] = useState(false);
const [isEditMode, setIsEditMode] = useState(false);
const [currentModel, setCurrentModel] = useState(null);
const [searchText, setSearchText] = useState('');
const [currentPage, setCurrentPage] = useState(1);
@@ -386,9 +387,11 @@ export default function ModelSettingsVisualEditor(props) {
setCurrentModel(null);
setPricingMode('per-token');
setPricingSubMode('ratio');
setIsEditMode(false);
};
const editModel = (record) => {
setIsEditMode(true);
// Determine which pricing mode to use based on the model's current configuration
let initialPricingMode = 'per-token';
let initialPricingSubMode = 'ratio';
@@ -500,13 +503,7 @@ export default function ModelSettingsVisualEditor(props) {
</Space>
<Modal
title={
currentModel &&
currentModel.name &&
models.some((model) => model.name === currentModel.name)
? t('编辑模型')
: t('添加模型')
}
title={isEditMode ? t('编辑模型') : t('添加模型')}
visible={visible}
onCancel={() => {
resetModalState();
@@ -562,11 +559,7 @@ export default function ModelSettingsVisualEditor(props) {
label={t('模型名称')}
placeholder='strawberry'
required
disabled={
currentModel &&
currentModel.name &&
models.some((model) => model.name === currentModel.name)
}
disabled={isEditMode}
onChange={(value) =>
setCurrentModel((prev) => ({ ...prev, name: value }))
}