Compare commits
43 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 9717af7dcd | |||
| 219ed0d1cf | |||
| 9625ee2d55 | |||
| 23d4100fe3 | |||
| 5b8382a6ab | |||
| 9ec326714e | |||
| a423ee3dd1 | |||
| f10d469b72 | |||
| fd6b6175bb | |||
| c59b331170 | |||
| 729593afae | |||
| 782ad01496 | |||
| 4e2e6e6e86 | |||
| 5478d2fb59 | |||
| 32978f9cd9 | |||
| ce74e94fc7 | |||
| 7d3bf7c5bb | |||
| 472377bf8c | |||
| 6c7a10ac7c | |||
| 375e25221d | |||
| 2b6e3e8010 | |||
| a80ecb8896 | |||
| 37f0383941 | |||
| c9b276c604 | |||
| 6193c547d9 | |||
| be77f3d763 | |||
| 337bb41588 | |||
| 8df5a45805 | |||
| ec006d3a67 | |||
| 067f3ce9ed | |||
| 649fa84231 | |||
| 46cd3f4071 | |||
| 8d54f86261 | |||
| 0bcd7388f4 | |||
| 3b6d0d0291 | |||
| 51f7f89190 | |||
| ff2e0dbc26 | |||
| e1d51fd169 | |||
| 85ecfd7062 | |||
| 7ffbff88f1 | |||
| 8217e694f8 | |||
| 73609f50d0 | |||
| 4cf1ffa801 |
+1
-1
@@ -12,4 +12,4 @@ var LogSqlType = DatabaseTypeSQLite // Default to SQLite for logging SQL queries
|
||||
var UsingMySQL = false
|
||||
var UsingClickHouse = false
|
||||
|
||||
var SQLitePath = "one-api.db?_busy_timeout=5000"
|
||||
var SQLitePath = "one-api.db?_busy_timeout=30000"
|
||||
|
||||
+49
-7
@@ -16,6 +16,7 @@ import (
|
||||
"one-api/relay/channel/moonshot"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/setting"
|
||||
"time"
|
||||
)
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/models/list
|
||||
@@ -102,7 +103,7 @@ func init() {
|
||||
})
|
||||
}
|
||||
|
||||
func ListModels(c *gin.Context) {
|
||||
func ListModels(c *gin.Context, modelType int) {
|
||||
userOpenAiModels := make([]dto.OpenAIModels, 0)
|
||||
|
||||
modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
|
||||
@@ -171,10 +172,41 @@ func ListModels(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"data": userOpenAiModels,
|
||||
})
|
||||
switch modelType {
|
||||
case constant.ChannelTypeAnthropic:
|
||||
useranthropicModels := make([]dto.AnthropicModel, len(userOpenAiModels))
|
||||
for i, model := range userOpenAiModels {
|
||||
useranthropicModels[i] = dto.AnthropicModel{
|
||||
ID: model.Id,
|
||||
CreatedAt: time.Unix(int64(model.Created), 0).UTC().Format(time.RFC3339),
|
||||
DisplayName: model.Id,
|
||||
Type: "model",
|
||||
}
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"data": useranthropicModels,
|
||||
"first_id": useranthropicModels[0].ID,
|
||||
"has_more": false,
|
||||
"last_id": useranthropicModels[len(useranthropicModels)-1].ID,
|
||||
})
|
||||
case constant.ChannelTypeGemini:
|
||||
userGeminiModels := make([]dto.GeminiModel, len(userOpenAiModels))
|
||||
for i, model := range userOpenAiModels {
|
||||
userGeminiModels[i] = dto.GeminiModel{
|
||||
Name: model.Id,
|
||||
DisplayName: model.Id,
|
||||
}
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"models": userGeminiModels,
|
||||
"nextPageToken": nil,
|
||||
})
|
||||
default:
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"data": userOpenAiModels,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func ChannelListModels(c *gin.Context) {
|
||||
@@ -198,10 +230,20 @@ func EnabledListModels(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
func RetrieveModel(c *gin.Context) {
|
||||
func RetrieveModel(c *gin.Context, modelType int) {
|
||||
modelId := c.Param("model")
|
||||
if aiModel, ok := openAIModelsMap[modelId]; ok {
|
||||
c.JSON(200, aiModel)
|
||||
switch modelType {
|
||||
case constant.ChannelTypeAnthropic:
|
||||
c.JSON(200, dto.AnthropicModel{
|
||||
ID: aiModel.Id,
|
||||
CreatedAt: time.Unix(int64(aiModel.Created), 0).UTC().Format(time.RFC3339),
|
||||
DisplayName: aiModel.Id,
|
||||
Type: "model",
|
||||
})
|
||||
default:
|
||||
c.JSON(200, aiModel)
|
||||
}
|
||||
} else {
|
||||
openAIError := dto.OpenAIError{
|
||||
Message: fmt.Sprintf("The model '%s' does not exist", modelId),
|
||||
|
||||
+172
-20
@@ -2,9 +2,12 @@ package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/model"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -19,10 +22,8 @@ func GetAllModelsMeta(c *gin.Context) {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
// 填充附加字段
|
||||
for _, m := range modelsMeta {
|
||||
fillModelExtra(m)
|
||||
}
|
||||
// 批量填充附加字段,提升列表接口性能
|
||||
enrichModels(modelsMeta)
|
||||
var total int64
|
||||
model.DB.Model(&model.Model{}).Count(&total)
|
||||
|
||||
@@ -52,9 +53,8 @@ func SearchModelsMeta(c *gin.Context) {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
for _, m := range modelsMeta {
|
||||
fillModelExtra(m)
|
||||
}
|
||||
// 批量填充附加字段,提升列表接口性能
|
||||
enrichModels(modelsMeta)
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(modelsMeta)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
@@ -73,7 +73,7 @@ func GetModelMeta(c *gin.Context) {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
fillModelExtra(&m)
|
||||
enrichModels([]*model.Model{&m})
|
||||
common.ApiSuccess(c, &m)
|
||||
}
|
||||
|
||||
@@ -160,19 +160,171 @@ func DeleteModelMeta(c *gin.Context) {
|
||||
common.ApiSuccess(c, nil)
|
||||
}
|
||||
|
||||
// 辅助函数:填充 Endpoints 和 BoundChannels 和 EnableGroups
|
||||
func fillModelExtra(m *model.Model) {
|
||||
if m.Endpoints == "" {
|
||||
eps := model.GetModelSupportEndpointTypes(m.ModelName)
|
||||
if b, err := json.Marshal(eps); err == nil {
|
||||
m.Endpoints = string(b)
|
||||
// enrichModels 批量填充附加信息:端点、渠道、分组、计费类型,避免 N+1 查询
|
||||
func enrichModels(models []*model.Model) {
|
||||
if len(models) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// 1) 拆分精确与规则匹配
|
||||
exactNames := make([]string, 0)
|
||||
exactIdx := make(map[string][]int) // modelName -> indices in models
|
||||
ruleIndices := make([]int, 0)
|
||||
for i, m := range models {
|
||||
if m == nil {
|
||||
continue
|
||||
}
|
||||
if m.NameRule == model.NameRuleExact {
|
||||
exactNames = append(exactNames, m.ModelName)
|
||||
exactIdx[m.ModelName] = append(exactIdx[m.ModelName], i)
|
||||
} else {
|
||||
ruleIndices = append(ruleIndices, i)
|
||||
}
|
||||
}
|
||||
if channels, err := model.GetBoundChannels(m.ModelName); err == nil {
|
||||
m.BoundChannels = channels
|
||||
|
||||
// 2) 批量查询精确模型的绑定渠道
|
||||
channelsByModel, _ := model.GetBoundChannelsByModelsMap(exactNames)
|
||||
|
||||
// 3) 精确模型:端点从缓存、渠道批量映射、分组/计费类型从缓存
|
||||
for name, indices := range exactIdx {
|
||||
chs := channelsByModel[name]
|
||||
for _, idx := range indices {
|
||||
mm := models[idx]
|
||||
if mm.Endpoints == "" {
|
||||
eps := model.GetModelSupportEndpointTypes(mm.ModelName)
|
||||
if b, err := json.Marshal(eps); err == nil {
|
||||
mm.Endpoints = string(b)
|
||||
}
|
||||
}
|
||||
mm.BoundChannels = chs
|
||||
mm.EnableGroups = model.GetModelEnableGroups(mm.ModelName)
|
||||
mm.QuotaTypes = model.GetModelQuotaTypes(mm.ModelName)
|
||||
}
|
||||
}
|
||||
|
||||
if len(ruleIndices) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// 4) 一次性读取定价缓存,内存匹配所有规则模型
|
||||
pricings := model.GetPricing()
|
||||
|
||||
// 为全部规则模型收集匹配名集合、端点并集、分组并集、配额集合
|
||||
matchedNamesByIdx := make(map[int][]string)
|
||||
endpointSetByIdx := make(map[int]map[constant.EndpointType]struct{})
|
||||
groupSetByIdx := make(map[int]map[string]struct{})
|
||||
quotaSetByIdx := make(map[int]map[int]struct{})
|
||||
|
||||
for _, p := range pricings {
|
||||
for _, idx := range ruleIndices {
|
||||
mm := models[idx]
|
||||
var matched bool
|
||||
switch mm.NameRule {
|
||||
case model.NameRulePrefix:
|
||||
matched = strings.HasPrefix(p.ModelName, mm.ModelName)
|
||||
case model.NameRuleSuffix:
|
||||
matched = strings.HasSuffix(p.ModelName, mm.ModelName)
|
||||
case model.NameRuleContains:
|
||||
matched = strings.Contains(p.ModelName, mm.ModelName)
|
||||
}
|
||||
if !matched {
|
||||
continue
|
||||
}
|
||||
matchedNamesByIdx[idx] = append(matchedNamesByIdx[idx], p.ModelName)
|
||||
|
||||
es := endpointSetByIdx[idx]
|
||||
if es == nil {
|
||||
es = make(map[constant.EndpointType]struct{})
|
||||
endpointSetByIdx[idx] = es
|
||||
}
|
||||
for _, et := range p.SupportedEndpointTypes {
|
||||
es[et] = struct{}{}
|
||||
}
|
||||
|
||||
gs := groupSetByIdx[idx]
|
||||
if gs == nil {
|
||||
gs = make(map[string]struct{})
|
||||
groupSetByIdx[idx] = gs
|
||||
}
|
||||
for _, g := range p.EnableGroup {
|
||||
gs[g] = struct{}{}
|
||||
}
|
||||
|
||||
qs := quotaSetByIdx[idx]
|
||||
if qs == nil {
|
||||
qs = make(map[int]struct{})
|
||||
quotaSetByIdx[idx] = qs
|
||||
}
|
||||
qs[p.QuotaType] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// 5) 汇总所有匹配到的模型名称,批量查询一次渠道
|
||||
allMatchedSet := make(map[string]struct{})
|
||||
for _, names := range matchedNamesByIdx {
|
||||
for _, n := range names {
|
||||
allMatchedSet[n] = struct{}{}
|
||||
}
|
||||
}
|
||||
allMatched := make([]string, 0, len(allMatchedSet))
|
||||
for n := range allMatchedSet {
|
||||
allMatched = append(allMatched, n)
|
||||
}
|
||||
matchedChannelsByModel, _ := model.GetBoundChannelsByModelsMap(allMatched)
|
||||
|
||||
// 6) 回填每个规则模型的并集信息
|
||||
for _, idx := range ruleIndices {
|
||||
mm := models[idx]
|
||||
|
||||
// 端点并集 -> 序列化
|
||||
if es, ok := endpointSetByIdx[idx]; ok && mm.Endpoints == "" {
|
||||
eps := make([]constant.EndpointType, 0, len(es))
|
||||
for et := range es {
|
||||
eps = append(eps, et)
|
||||
}
|
||||
if b, err := json.Marshal(eps); err == nil {
|
||||
mm.Endpoints = string(b)
|
||||
}
|
||||
}
|
||||
|
||||
// 分组并集
|
||||
if gs, ok := groupSetByIdx[idx]; ok {
|
||||
groups := make([]string, 0, len(gs))
|
||||
for g := range gs {
|
||||
groups = append(groups, g)
|
||||
}
|
||||
mm.EnableGroups = groups
|
||||
}
|
||||
|
||||
// 配额类型集合(保持去重并排序)
|
||||
if qs, ok := quotaSetByIdx[idx]; ok {
|
||||
arr := make([]int, 0, len(qs))
|
||||
for k := range qs {
|
||||
arr = append(arr, k)
|
||||
}
|
||||
sort.Ints(arr)
|
||||
mm.QuotaTypes = arr
|
||||
}
|
||||
|
||||
// 渠道并集
|
||||
names := matchedNamesByIdx[idx]
|
||||
channelSet := make(map[string]model.BoundChannel)
|
||||
for _, n := range names {
|
||||
for _, ch := range matchedChannelsByModel[n] {
|
||||
key := ch.Name + "_" + strconv.Itoa(ch.Type)
|
||||
channelSet[key] = ch
|
||||
}
|
||||
}
|
||||
if len(channelSet) > 0 {
|
||||
chs := make([]model.BoundChannel, 0, len(channelSet))
|
||||
for _, ch := range channelSet {
|
||||
chs = append(chs, ch)
|
||||
}
|
||||
mm.BoundChannels = chs
|
||||
}
|
||||
|
||||
// 匹配信息
|
||||
mm.MatchedModels = names
|
||||
mm.MatchedCount = len(names)
|
||||
}
|
||||
// 填充启用分组
|
||||
m.EnableGroups = model.GetModelEnableGroups(m.ModelName)
|
||||
// 填充计费类型
|
||||
m.QuotaType = model.GetModelQuotaType(m.ModelName)
|
||||
}
|
||||
|
||||
@@ -312,10 +312,6 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b
|
||||
return true
|
||||
}
|
||||
if openaiErr.StatusCode == http.StatusBadRequest {
|
||||
channelType := c.GetInt("channel_type")
|
||||
if channelType == constant.ChannelTypeAnthropic {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
if openaiErr.StatusCode == 408 {
|
||||
|
||||
+5
-3
@@ -1,6 +1,6 @@
|
||||
# One API – Web 界面后端接口文档
|
||||
# New API – Web 界面后端接口文档
|
||||
|
||||
> 本文档汇总了 **One API** 后端提供给前端 Web 界面的全部 REST 接口(不含 *Relay* 相关接口)。
|
||||
> 本文档汇总了 **New API** 后端提供给前端 Web 界面的全部 REST 接口(不含 *Relay* 相关接口)。
|
||||
>
|
||||
> 接口前缀统一为 `https://<your-domain>`,以下仅列出 **路径**、**HTTP 方法**、**鉴权要求** 与 **功能简介**。
|
||||
>
|
||||
@@ -62,6 +62,8 @@
|
||||
| GET | /api/user/groups | 公开 | 列出所有分组(无鉴权版) |
|
||||
|
||||
### 5.2 用户自身操作 (需登录)
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /api/user/self/groups | 用户 | 获取自己所在分组 |
|
||||
| GET | /api/user/self | 用户 | 获取个人资料 |
|
||||
| GET | /api/user/models | 用户 | 获取模型可见性 |
|
||||
@@ -192,4 +194,4 @@
|
||||
|
||||
---
|
||||
|
||||
> **更新日期**:2025.07.17
|
||||
> **更新日期**:2025.07.17
|
||||
|
||||
+37
-1
@@ -3,16 +3,52 @@ package dto
|
||||
import (
|
||||
"encoding/json"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type GeminiChatRequest struct {
|
||||
Contents []GeminiChatContent `json:"contents"`
|
||||
SafetySettings []GeminiChatSafetySettings `json:"safetySettings,omitempty"`
|
||||
GenerationConfig GeminiChatGenerationConfig `json:"generationConfig,omitempty"`
|
||||
Tools []GeminiChatTool `json:"tools,omitempty"`
|
||||
Tools json.RawMessage `json:"tools,omitempty"`
|
||||
SystemInstructions *GeminiChatContent `json:"systemInstruction,omitempty"`
|
||||
}
|
||||
|
||||
func (r *GeminiChatRequest) GetTools() []GeminiChatTool {
|
||||
var tools []GeminiChatTool
|
||||
if strings.HasSuffix(string(r.Tools), "[") {
|
||||
// is array
|
||||
if err := common.Unmarshal(r.Tools, &tools); err != nil {
|
||||
common.LogError(nil, "error_unmarshalling_tools: "+err.Error())
|
||||
return nil
|
||||
}
|
||||
} else if strings.HasPrefix(string(r.Tools), "{") {
|
||||
// is object
|
||||
singleTool := GeminiChatTool{}
|
||||
if err := common.Unmarshal(r.Tools, &singleTool); err != nil {
|
||||
common.LogError(nil, "error_unmarshalling_single_tool: "+err.Error())
|
||||
return nil
|
||||
}
|
||||
tools = []GeminiChatTool{singleTool}
|
||||
}
|
||||
return tools
|
||||
}
|
||||
|
||||
func (r *GeminiChatRequest) SetTools(tools []GeminiChatTool) {
|
||||
if len(tools) == 0 {
|
||||
r.Tools = json.RawMessage("[]")
|
||||
return
|
||||
}
|
||||
|
||||
// Marshal the tools to JSON
|
||||
data, err := common.Marshal(tools)
|
||||
if err != nil {
|
||||
common.LogError(nil, "error_marshalling_tools: "+err.Error())
|
||||
return
|
||||
}
|
||||
r.Tools = data
|
||||
}
|
||||
|
||||
type GeminiThinkingConfig struct {
|
||||
IncludeThoughts bool `json:"includeThoughts,omitempty"`
|
||||
ThinkingBudget *int `json:"thinkingBudget,omitempty"`
|
||||
|
||||
@@ -54,7 +54,7 @@ type GeneralOpenAIRequest struct {
|
||||
Modalities json.RawMessage `json:"modalities,omitempty"`
|
||||
Audio json.RawMessage `json:"audio,omitempty"`
|
||||
EnableThinking any `json:"enable_thinking,omitempty"` // ali
|
||||
THINKING json.RawMessage `json:"thinking,omitempty"` // doubao
|
||||
THINKING json.RawMessage `json:"thinking,omitempty"` // doubao,zhipu_v4
|
||||
ExtraBody json.RawMessage `json:"extra_body,omitempty"`
|
||||
SearchParameters any `json:"search_parameters,omitempty"` //xai
|
||||
WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"`
|
||||
|
||||
@@ -2,6 +2,7 @@ package dto
|
||||
|
||||
import "one-api/constant"
|
||||
|
||||
// 这里不好动就不动了,本来想独立出来的(
|
||||
type OpenAIModels struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
@@ -9,3 +10,26 @@ type OpenAIModels struct {
|
||||
OwnedBy string `json:"owned_by"`
|
||||
SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
|
||||
}
|
||||
|
||||
type AnthropicModel struct {
|
||||
ID string `json:"id"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
DisplayName string `json:"display_name"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
type GeminiModel struct {
|
||||
Name interface{} `json:"name"`
|
||||
BaseModelId interface{} `json:"baseModelId"`
|
||||
Version interface{} `json:"version"`
|
||||
DisplayName interface{} `json:"displayName"`
|
||||
Description interface{} `json:"description"`
|
||||
InputTokenLimit interface{} `json:"inputTokenLimit"`
|
||||
OutputTokenLimit interface{} `json:"outputTokenLimit"`
|
||||
SupportedGenerationMethods []interface{} `json:"supportedGenerationMethods"`
|
||||
Thinking interface{} `json:"thinking"`
|
||||
Temperature interface{} `json:"temperature"`
|
||||
MaxTemperature interface{} `json:"maxTemperature"`
|
||||
TopP interface{} `json:"topP"`
|
||||
TopK interface{} `json:"topK"`
|
||||
}
|
||||
|
||||
+9
-7
@@ -192,16 +192,18 @@ func TokenAuth() func(c *gin.Context) {
|
||||
}
|
||||
c.Request.Header.Set("Authorization", "Bearer "+key)
|
||||
}
|
||||
anthropicKey := c.Request.Header.Get("x-api-key")
|
||||
// 检查path包含/v1/messages
|
||||
if strings.Contains(c.Request.URL.Path, "/v1/messages") {
|
||||
// 从x-api-key中获取key
|
||||
key := c.Request.Header.Get("x-api-key")
|
||||
if key != "" {
|
||||
c.Request.Header.Set("Authorization", "Bearer "+key)
|
||||
}
|
||||
// 或者是否 x-api-key 不为空且存在anthropic-version
|
||||
// 谁知道有多少不符合规范没写anthropic-version的
|
||||
// 所以就这样随它去吧(
|
||||
if strings.Contains(c.Request.URL.Path, "/v1/messages") || (anthropicKey != "" && c.Request.Header.Get("anthropic-version") != "") {
|
||||
c.Request.Header.Set("Authorization", "Bearer "+anthropicKey)
|
||||
}
|
||||
// gemini api 从query中获取key
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") || strings.HasPrefix(c.Request.URL.Path, "/v1/models/") {
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models") ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/v1beta/openai/models") ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/v1/models/") {
|
||||
skKey := c.Query("key")
|
||||
if skKey != "" {
|
||||
c.Request.Header.Set("Authorization", "Bearer "+skKey)
|
||||
|
||||
@@ -174,7 +174,9 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||
relayMode = relayconstant.RelayModeVideoFetchByID
|
||||
shouldSelectChannel = false
|
||||
}
|
||||
c.Set("relay_mode", relayMode)
|
||||
if _, ok := c.Get("relay_mode"); !ok {
|
||||
c.Set("relay_mode", relayMode)
|
||||
}
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") || strings.HasPrefix(c.Request.URL.Path, "/v1/models/") {
|
||||
// Gemini API 路径处理: /v1beta/models/gemini-2.0-flash:generateContent
|
||||
relayMode := relayconstant.RelayModeGemini
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
EmailVerificationRateLimitMark = "EV"
|
||||
EmailVerificationMaxRequests = 2 // 30秒内最多2次
|
||||
EmailVerificationDuration = 30 // 30秒时间窗口
|
||||
)
|
||||
|
||||
func redisEmailVerificationRateLimiter(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
rdb := common.RDB
|
||||
key := "emailVerification:" + EmailVerificationRateLimitMark + ":" + c.ClientIP()
|
||||
|
||||
count, err := rdb.Incr(ctx, key).Result()
|
||||
if err != nil {
|
||||
// fallback
|
||||
memoryEmailVerificationRateLimiter(c)
|
||||
return
|
||||
}
|
||||
|
||||
// 第一次设置键时设置过期时间
|
||||
if count == 1 {
|
||||
_ = rdb.Expire(ctx, key, time.Duration(EmailVerificationDuration)*time.Second).Err()
|
||||
}
|
||||
|
||||
// 检查是否超出限制
|
||||
if count <= int64(EmailVerificationMaxRequests) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 获取剩余等待时间
|
||||
ttl, err := rdb.TTL(ctx, key).Result()
|
||||
waitSeconds := int64(EmailVerificationDuration)
|
||||
if err == nil && ttl > 0 {
|
||||
waitSeconds = int64(ttl.Seconds())
|
||||
}
|
||||
|
||||
c.JSON(http.StatusTooManyRequests, gin.H{
|
||||
"success": false,
|
||||
"message": fmt.Sprintf("发送过于频繁,请等待 %d 秒后再试", waitSeconds),
|
||||
})
|
||||
c.Abort()
|
||||
}
|
||||
|
||||
func memoryEmailVerificationRateLimiter(c *gin.Context) {
|
||||
key := EmailVerificationRateLimitMark + ":" + c.ClientIP()
|
||||
|
||||
if !inMemoryRateLimiter.Request(key, EmailVerificationMaxRequests, EmailVerificationDuration) {
|
||||
c.JSON(http.StatusTooManyRequests, gin.H{
|
||||
"success": false,
|
||||
"message": "发送过于频繁,请稍后再试",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
|
||||
func EmailVerificationRateLimit() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if common.RedisEnabled {
|
||||
redisEmailVerificationRateLimiter(c)
|
||||
} else {
|
||||
inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration)
|
||||
memoryEmailVerificationRateLimiter(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
relayconstant "one-api/relay/constant"
|
||||
)
|
||||
|
||||
func JimengRequestConvert() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
action := c.Query("Action")
|
||||
if action == "" {
|
||||
abortWithOpenAiMessage(c, http.StatusBadRequest, "Action query parameter is required")
|
||||
return
|
||||
}
|
||||
|
||||
// Handle Jimeng official API request
|
||||
var originalReq map[string]interface{}
|
||||
if err := common.UnmarshalBodyReusable(c, &originalReq); err != nil {
|
||||
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request body")
|
||||
return
|
||||
}
|
||||
model, _ := originalReq["req_key"].(string)
|
||||
prompt, _ := originalReq["prompt"].(string)
|
||||
|
||||
unifiedReq := map[string]interface{}{
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"metadata": originalReq,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(unifiedReq)
|
||||
if err != nil {
|
||||
abortWithOpenAiMessage(c, http.StatusInternalServerError, "Failed to marshal request body")
|
||||
return
|
||||
}
|
||||
|
||||
// Update request body
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(jsonData))
|
||||
c.Set(common.KeyRequestBody, jsonData)
|
||||
|
||||
if image, ok := originalReq["image"]; !ok || image == "" {
|
||||
c.Set("action", constant.TaskActionTextGenerate)
|
||||
}
|
||||
|
||||
c.Request.URL.Path = "/v1/video/generations"
|
||||
|
||||
if action == "CVSync2AsyncGetResult" {
|
||||
taskId, ok := originalReq["task_id"].(string)
|
||||
if !ok || taskId == "" {
|
||||
abortWithOpenAiMessage(c, http.StatusBadRequest, "task_id is required for CVSync2AsyncGetResult")
|
||||
return
|
||||
}
|
||||
c.Request.URL.Path = "/v1/video/generations/" + taskId
|
||||
c.Request.Method = http.MethodGet
|
||||
c.Set("task_id", taskId)
|
||||
c.Set("relay_mode", relayconstant.RelayModeVideoFetchByID)
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
+130
-18
@@ -66,18 +66,18 @@ var LOG_DB *gorm.DB
|
||||
|
||||
// dropIndexIfExists drops a MySQL index only if it exists to avoid noisy 1091 errors
|
||||
func dropIndexIfExists(tableName string, indexName string) {
|
||||
if !common.UsingMySQL {
|
||||
return
|
||||
}
|
||||
var count int64
|
||||
// Check index existence via information_schema
|
||||
err := DB.Raw(
|
||||
"SELECT COUNT(1) FROM information_schema.statistics WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?",
|
||||
tableName, indexName,
|
||||
).Scan(&count).Error
|
||||
if err == nil && count > 0 {
|
||||
_ = DB.Exec("ALTER TABLE " + tableName + " DROP INDEX " + indexName + ";").Error
|
||||
}
|
||||
if !common.UsingMySQL {
|
||||
return
|
||||
}
|
||||
var count int64
|
||||
// Check index existence via information_schema
|
||||
err := DB.Raw(
|
||||
"SELECT COUNT(1) FROM information_schema.statistics WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?",
|
||||
tableName, indexName,
|
||||
).Scan(&count).Error
|
||||
if err == nil && count > 0 {
|
||||
_ = DB.Exec("ALTER TABLE " + tableName + " DROP INDEX " + indexName + ";").Error
|
||||
}
|
||||
}
|
||||
|
||||
func createRootAccountIfNeed() error {
|
||||
@@ -196,6 +196,12 @@ func InitDB() (err error) {
|
||||
db = db.Debug()
|
||||
}
|
||||
DB = db
|
||||
// MySQL charset/collation startup check: ensure Chinese-capable charset
|
||||
if common.UsingMySQL {
|
||||
if err := checkMySQLChineseSupport(DB); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
sqlDB, err := DB.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -230,6 +236,12 @@ func InitLogDB() (err error) {
|
||||
db = db.Debug()
|
||||
}
|
||||
LOG_DB = db
|
||||
// If log DB is MySQL, also ensure Chinese-capable charset
|
||||
if common.LogSqlType == common.DatabaseTypeMySQL {
|
||||
if err := checkMySQLChineseSupport(LOG_DB); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
sqlDB, err := LOG_DB.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -252,11 +264,15 @@ func InitLogDB() (err error) {
|
||||
|
||||
func migrateDB() error {
|
||||
// 修复旧版本留下的唯一索引,允许软删除后重新插入同名记录
|
||||
dropIndexIfExists("models", "uk_model_name")
|
||||
dropIndexIfExists("vendors", "uk_vendor_name")
|
||||
if !common.UsingPostgreSQL {
|
||||
return migrateDBFast()
|
||||
}
|
||||
// 删除单列唯一索引(列级 UNIQUE)及早期命名方式,防止与新复合唯一索引 (model_name, deleted_at) 冲突
|
||||
dropIndexIfExists("models", "uk_model_name") // 新版复合索引名称(若已存在)
|
||||
dropIndexIfExists("models", "model_name") // 旧版列级唯一索引名称
|
||||
|
||||
dropIndexIfExists("vendors", "uk_vendor_name") // 新版复合索引名称(若已存在)
|
||||
dropIndexIfExists("vendors", "name") // 旧版列级唯一索引名称
|
||||
//if !common.UsingPostgreSQL {
|
||||
// return migrateDBFast()
|
||||
//}
|
||||
err := DB.AutoMigrate(
|
||||
&Channel{},
|
||||
&Token{},
|
||||
@@ -284,8 +300,12 @@ func migrateDB() error {
|
||||
|
||||
func migrateDBFast() error {
|
||||
// 修复旧版本留下的唯一索引,允许软删除后重新插入同名记录
|
||||
// 删除单列唯一索引(列级 UNIQUE)及早期命名方式,防止与新复合唯一索引冲突
|
||||
dropIndexIfExists("models", "uk_model_name")
|
||||
dropIndexIfExists("models", "model_name")
|
||||
|
||||
dropIndexIfExists("vendors", "uk_vendor_name")
|
||||
dropIndexIfExists("vendors", "name")
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
@@ -305,7 +325,7 @@ func migrateDBFast() error {
|
||||
{&QuotaData{}, "QuotaData"},
|
||||
{&Task{}, "Task"},
|
||||
{&Model{}, "Model"},
|
||||
{&Vendor{}, "Vendor"},
|
||||
{&Vendor{}, "Vendor"},
|
||||
{&PrefillGroup{}, "PrefillGroup"},
|
||||
{&Setup{}, "Setup"},
|
||||
{&TwoFA{}, "TwoFA"},
|
||||
@@ -365,6 +385,98 @@ func CloseDB() error {
|
||||
return closeDB(DB)
|
||||
}
|
||||
|
||||
// checkMySQLChineseSupport ensures the MySQL connection and current schema
|
||||
// default charset/collation can store Chinese characters. It allows common
|
||||
// Chinese-capable charsets (utf8mb4, utf8, gbk, big5, gb18030) and panics otherwise.
|
||||
func checkMySQLChineseSupport(db *gorm.DB) error {
|
||||
// 仅检测:当前库默认字符集/排序规则 + 各表的排序规则(隐含字符集)
|
||||
|
||||
// Read current schema defaults
|
||||
var schemaCharset, schemaCollation string
|
||||
err := db.Raw("SELECT DEFAULT_CHARACTER_SET_NAME, DEFAULT_COLLATION_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = DATABASE()").Row().Scan(&schemaCharset, &schemaCollation)
|
||||
if err != nil {
|
||||
return fmt.Errorf("读取当前库默认字符集/排序规则失败 / Failed to read schema default charset/collation: %v", err)
|
||||
}
|
||||
|
||||
toLower := func(s string) string { return strings.ToLower(s) }
|
||||
// Allowed charsets that can store Chinese text
|
||||
allowedCharsets := map[string]string{
|
||||
"utf8mb4": "utf8mb4_",
|
||||
"utf8": "utf8_",
|
||||
"gbk": "gbk_",
|
||||
"big5": "big5_",
|
||||
"gb18030": "gb18030_",
|
||||
}
|
||||
isChineseCapable := func(cs, cl string) bool {
|
||||
csLower := toLower(cs)
|
||||
clLower := toLower(cl)
|
||||
if prefix, ok := allowedCharsets[csLower]; ok {
|
||||
if clLower == "" {
|
||||
return true
|
||||
}
|
||||
return strings.HasPrefix(clLower, prefix)
|
||||
}
|
||||
// 如果仅提供了排序规则,尝试按排序规则前缀判断
|
||||
for _, prefix := range allowedCharsets {
|
||||
if strings.HasPrefix(clLower, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// 1) 当前库默认值必须支持中文
|
||||
if !isChineseCapable(schemaCharset, schemaCollation) {
|
||||
return fmt.Errorf("当前库默认字符集/排序规则不支持中文:schema(%s/%s)。请将库设置为 utf8mb4/utf8/gbk/big5/gb18030 / Schema default charset/collation is not Chinese-capable: schema(%s/%s). Please set to utf8mb4/utf8/gbk/big5/gb18030",
|
||||
schemaCharset, schemaCollation, schemaCharset, schemaCollation)
|
||||
}
|
||||
|
||||
// 2) 所有物理表的排序规则(隐含字符集)必须支持中文
|
||||
type tableInfo struct {
|
||||
Name string
|
||||
Collation *string
|
||||
}
|
||||
var tables []tableInfo
|
||||
if err := db.Raw("SELECT TABLE_NAME, TABLE_COLLATION FROM information_schema.TABLES WHERE TABLE_SCHEMA = DATABASE() AND TABLE_TYPE = 'BASE TABLE'").Scan(&tables).Error; err != nil {
|
||||
return fmt.Errorf("读取表排序规则失败 / Failed to read table collations: %v", err)
|
||||
}
|
||||
|
||||
var badTables []string
|
||||
for _, t := range tables {
|
||||
// NULL 或空表示继承库默认设置,已在上面校验库默认,视为通过
|
||||
if t.Collation == nil || *t.Collation == "" {
|
||||
continue
|
||||
}
|
||||
cl := *t.Collation
|
||||
// 仅凭排序规则判断是否中文可用
|
||||
ok := false
|
||||
lower := strings.ToLower(cl)
|
||||
for _, prefix := range allowedCharsets {
|
||||
if strings.HasPrefix(lower, prefix) {
|
||||
ok = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
badTables = append(badTables, fmt.Sprintf("%s(%s)", t.Name, cl))
|
||||
}
|
||||
}
|
||||
|
||||
if len(badTables) > 0 {
|
||||
// 限制输出数量以避免日志过长
|
||||
maxShow := 20
|
||||
shown := badTables
|
||||
if len(shown) > maxShow {
|
||||
shown = shown[:maxShow]
|
||||
}
|
||||
return fmt.Errorf(
|
||||
"存在不支持中文的表,请修复其排序规则/字符集。示例(最多展示 %d 项):%v / Found tables not Chinese-capable. Please fix their collation/charset. Examples (showing up to %d): %v",
|
||||
maxShow, shown, maxShow, shown,
|
||||
)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
lastPingTime time.Time
|
||||
pingMutex sync.Mutex
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
package model
|
||||
|
||||
// GetModelEnableGroups 返回指定模型名称可用的用户分组列表。
|
||||
// 使用在 updatePricing() 中维护的缓存映射,O(1) 读取,适合高并发场景。
|
||||
func GetModelEnableGroups(modelName string) []string {
|
||||
// 确保缓存最新
|
||||
GetPricing()
|
||||
@@ -19,16 +17,15 @@ func GetModelEnableGroups(modelName string) []string {
|
||||
return groups
|
||||
}
|
||||
|
||||
// GetModelQuotaType 返回指定模型的计费类型(quota_type)。
|
||||
// 同样使用缓存映射,避免每次遍历定价切片。
|
||||
func GetModelQuotaType(modelName string) int {
|
||||
// GetModelQuotaTypes 返回指定模型的计费类型集合(来自缓存)
|
||||
func GetModelQuotaTypes(modelName string) []int {
|
||||
GetPricing()
|
||||
|
||||
modelEnableGroupsLock.RLock()
|
||||
quota, ok := modelQuotaTypeMap[modelName]
|
||||
modelEnableGroupsLock.RUnlock()
|
||||
if !ok {
|
||||
return 0
|
||||
return []int{}
|
||||
}
|
||||
return quota
|
||||
return []int{quota}
|
||||
}
|
||||
|
||||
+34
-96
@@ -3,30 +3,15 @@ package model
|
||||
import (
|
||||
"one-api/common"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Model 用于存储模型的元数据,例如描述、标签等
|
||||
// ModelName 字段具有唯一性约束,确保每个模型只会出现一次
|
||||
// Tags 字段使用逗号分隔的字符串保存标签集合,后期可根据需要扩展为 JSON 类型
|
||||
// Status: 1 表示启用,0 表示禁用,保留以便后续功能扩展
|
||||
// CreatedTime 和 UpdatedTime 使用 Unix 时间戳(秒)保存方便跨数据库移植
|
||||
// DeletedAt 采用 GORM 的软删除特性,便于后续数据恢复
|
||||
//
|
||||
// 该表设计遵循第三范式(3NF):
|
||||
// 1. 每一列都与主键(Id 或 ModelName)直接相关
|
||||
// 2. 不存在部分依赖(ModelName 是唯一键)
|
||||
// 3. 不存在传递依赖(描述、标签等都依赖于 ModelName,而非依赖于其他非主键列)
|
||||
// 这样既保证了数据一致性,也方便后期扩展
|
||||
|
||||
// 模型名称匹配规则
|
||||
const (
|
||||
NameRuleExact = iota // 0 精确匹配
|
||||
NameRulePrefix // 1 前缀匹配
|
||||
NameRuleContains // 2 包含匹配
|
||||
NameRuleSuffix // 3 后缀匹配
|
||||
NameRuleExact = iota
|
||||
NameRulePrefix
|
||||
NameRuleContains
|
||||
NameRuleSuffix
|
||||
)
|
||||
|
||||
type BoundChannel struct {
|
||||
@@ -49,11 +34,13 @@ type Model struct {
|
||||
|
||||
BoundChannels []BoundChannel `json:"bound_channels,omitempty" gorm:"-"`
|
||||
EnableGroups []string `json:"enable_groups,omitempty" gorm:"-"`
|
||||
QuotaType int `json:"quota_type" gorm:"-"`
|
||||
QuotaTypes []int `json:"quota_types,omitempty" gorm:"-"`
|
||||
NameRule int `json:"name_rule" gorm:"default:0"`
|
||||
|
||||
MatchedModels []string `json:"matched_models,omitempty" gorm:"-"`
|
||||
MatchedCount int `json:"matched_count,omitempty" gorm:"-"`
|
||||
}
|
||||
|
||||
// Insert 创建新的模型元数据记录
|
||||
func (mi *Model) Insert() error {
|
||||
now := common.GetTimestamp()
|
||||
mi.CreatedTime = now
|
||||
@@ -61,7 +48,6 @@ func (mi *Model) Insert() error {
|
||||
return DB.Create(mi).Error
|
||||
}
|
||||
|
||||
// IsModelNameDuplicated 检查模型名称是否重复(排除自身 ID)
|
||||
func IsModelNameDuplicated(id int, name string) (bool, error) {
|
||||
if name == "" {
|
||||
return false, nil
|
||||
@@ -71,10 +57,8 @@ func IsModelNameDuplicated(id int, name string) (bool, error) {
|
||||
return cnt > 0, err
|
||||
}
|
||||
|
||||
// Update 更新现有模型记录
|
||||
func (mi *Model) Update() error {
|
||||
mi.UpdatedTime = common.GetTimestamp()
|
||||
// 使用 Session 配置并选择所有字段,允许零值(如空字符串)也能被更新
|
||||
return DB.Session(&gorm.Session{AllowGlobalUpdate: false, FullSaveAssociations: false}).
|
||||
Model(&Model{}).
|
||||
Where("id = ?", mi.Id).
|
||||
@@ -83,22 +67,10 @@ func (mi *Model) Update() error {
|
||||
Updates(mi).Error
|
||||
}
|
||||
|
||||
// Delete 软删除模型记录
|
||||
func (mi *Model) Delete() error {
|
||||
return DB.Delete(mi).Error
|
||||
}
|
||||
|
||||
// GetModelByName 根据模型名称查询元数据
|
||||
func GetModelByName(name string) (*Model, error) {
|
||||
var mi Model
|
||||
err := DB.Where("model_name = ?", name).First(&mi).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &mi, nil
|
||||
}
|
||||
|
||||
// GetVendorModelCounts 统计每个供应商下模型数量(不受分页影响)
|
||||
func GetVendorModelCounts() (map[int64]int64, error) {
|
||||
var stats []struct {
|
||||
VendorID int64
|
||||
@@ -117,72 +89,38 @@ func GetVendorModelCounts() (map[int64]int64, error) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// GetAllModels 分页获取所有模型元数据
|
||||
func GetAllModels(offset int, limit int) ([]*Model, error) {
|
||||
var models []*Model
|
||||
err := DB.Offset(offset).Limit(limit).Find(&models).Error
|
||||
err := DB.Order("id DESC").Offset(offset).Limit(limit).Find(&models).Error
|
||||
return models, err
|
||||
}
|
||||
|
||||
// GetBoundChannels 查询支持该模型的渠道(名称+类型)
|
||||
func GetBoundChannels(modelName string) ([]BoundChannel, error) {
|
||||
var channels []BoundChannel
|
||||
err := DB.Table("channels").
|
||||
Select("channels.name, channels.type").
|
||||
Joins("join abilities on abilities.channel_id = channels.id").
|
||||
Where("abilities.model = ? AND abilities.enabled = ?", modelName, true).
|
||||
Group("channels.id").
|
||||
Scan(&channels).Error
|
||||
return channels, err
|
||||
}
|
||||
|
||||
// FindModelByNameWithRule 根据模型名称和匹配规则查找模型元数据,优先级:精确 > 前缀 > 后缀 > 包含
|
||||
func FindModelByNameWithRule(name string) (*Model, error) {
|
||||
// 1. 精确匹配
|
||||
if m, err := GetModelByName(name); err == nil {
|
||||
return m, nil
|
||||
func GetBoundChannelsByModelsMap(modelNames []string) (map[string][]BoundChannel, error) {
|
||||
result := make(map[string][]BoundChannel)
|
||||
if len(modelNames) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
// 2. 规则匹配
|
||||
var models []*Model
|
||||
if err := DB.Where("name_rule <> ?", NameRuleExact).Find(&models).Error; err != nil {
|
||||
type row struct {
|
||||
Model string
|
||||
Name string
|
||||
Type int
|
||||
}
|
||||
var rows []row
|
||||
err := DB.Table("channels").
|
||||
Select("abilities.model as model, channels.name as name, channels.type as type").
|
||||
Joins("JOIN abilities ON abilities.channel_id = channels.id").
|
||||
Where("abilities.model IN ? AND abilities.enabled = ?", modelNames, true).
|
||||
Distinct().
|
||||
Scan(&rows).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var prefixMatch, suffixMatch, containsMatch *Model
|
||||
for _, m := range models {
|
||||
switch m.NameRule {
|
||||
case NameRulePrefix:
|
||||
if strings.HasPrefix(name, m.ModelName) {
|
||||
if prefixMatch == nil || len(m.ModelName) > len(prefixMatch.ModelName) {
|
||||
prefixMatch = m
|
||||
}
|
||||
}
|
||||
case NameRuleSuffix:
|
||||
if strings.HasSuffix(name, m.ModelName) {
|
||||
if suffixMatch == nil || len(m.ModelName) > len(suffixMatch.ModelName) {
|
||||
suffixMatch = m
|
||||
}
|
||||
}
|
||||
case NameRuleContains:
|
||||
if strings.Contains(name, m.ModelName) {
|
||||
if containsMatch == nil || len(m.ModelName) > len(containsMatch.ModelName) {
|
||||
containsMatch = m
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, r := range rows {
|
||||
result[r.Model] = append(result[r.Model], BoundChannel{Name: r.Name, Type: r.Type})
|
||||
}
|
||||
if prefixMatch != nil {
|
||||
return prefixMatch, nil
|
||||
}
|
||||
if suffixMatch != nil {
|
||||
return suffixMatch, nil
|
||||
}
|
||||
if containsMatch != nil {
|
||||
return containsMatch, nil
|
||||
}
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// SearchModels 根据关键词和供应商搜索模型,支持分页
|
||||
func SearchModels(keyword string, vendor string, offset int, limit int) ([]*Model, int64, error) {
|
||||
var models []*Model
|
||||
db := DB.Model(&Model{})
|
||||
@@ -191,7 +129,6 @@ func SearchModels(keyword string, vendor string, offset int, limit int) ([]*Mode
|
||||
db = db.Where("model_name LIKE ? OR description LIKE ? OR tags LIKE ?", like, like, like)
|
||||
}
|
||||
if vendor != "" {
|
||||
// 如果是数字,按供应商 ID 精确匹配;否则按名称模糊匹配
|
||||
if vid, err := strconv.Atoi(vendor); err == nil {
|
||||
db = db.Where("models.vendor_id = ?", vid)
|
||||
} else {
|
||||
@@ -199,10 +136,11 @@ func SearchModels(keyword string, vendor string, offset int, limit int) ([]*Mode
|
||||
}
|
||||
}
|
||||
var total int64
|
||||
err := db.Count(&total).Error
|
||||
if err != nil {
|
||||
if err := db.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
err = db.Offset(offset).Limit(limit).Order("models.id DESC").Find(&models).Error
|
||||
return models, total, err
|
||||
if err := db.Order("models.id DESC").Offset(offset).Limit(limit).Find(&models).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return models, total, nil
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/types"
|
||||
@@ -38,6 +39,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", info.BaseUrl, info.ApiVersion), nil
|
||||
case constant.RelayModeEmbeddings:
|
||||
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", info.BaseUrl, info.ApiVersion), nil
|
||||
case constant.RelayModeResponses:
|
||||
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/responses", info.BaseUrl, info.ApiVersion), nil
|
||||
default:
|
||||
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", info.BaseUrl, info.ApiVersion, info.UpstreamModelName), nil
|
||||
}
|
||||
@@ -62,8 +65,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
|
||||
// TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
@@ -110,6 +112,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
} else {
|
||||
err, usage = cfHandler(c, info, resp)
|
||||
}
|
||||
case constant.RelayModeResponses:
|
||||
if info.IsStream {
|
||||
usage, err = openai.OaiResponsesStreamHandler(c, info, resp)
|
||||
} else {
|
||||
usage, err = openai.OaiResponsesHandler(c, info, resp)
|
||||
}
|
||||
case constant.RelayModeAudioTranslation:
|
||||
fallthrough
|
||||
case constant.RelayModeAudioTranscription:
|
||||
|
||||
@@ -53,13 +53,7 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
|
||||
}
|
||||
}
|
||||
|
||||
// 直接返回 Gemini 原生格式的 JSON 响应
|
||||
jsonResponse, err := common.Marshal(geminiResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
common.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||
common.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
|
||||
return &usage, nil
|
||||
}
|
||||
|
||||
@@ -267,24 +267,23 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
tool.Function.Parameters = cleanedParams
|
||||
functions = append(functions, tool.Function)
|
||||
}
|
||||
geminiTools := geminiRequest.GetTools()
|
||||
if codeExecution {
|
||||
geminiRequest.Tools = append(geminiRequest.Tools, dto.GeminiChatTool{
|
||||
geminiTools = append(geminiTools, dto.GeminiChatTool{
|
||||
CodeExecution: make(map[string]string),
|
||||
})
|
||||
}
|
||||
if googleSearch {
|
||||
geminiRequest.Tools = append(geminiRequest.Tools, dto.GeminiChatTool{
|
||||
geminiTools = append(geminiTools, dto.GeminiChatTool{
|
||||
GoogleSearch: make(map[string]string),
|
||||
})
|
||||
}
|
||||
if len(functions) > 0 {
|
||||
geminiRequest.Tools = append(geminiRequest.Tools, dto.GeminiChatTool{
|
||||
geminiTools = append(geminiTools, dto.GeminiChatTool{
|
||||
FunctionDeclarations: functions,
|
||||
})
|
||||
}
|
||||
// common.SysLog("tools: " + fmt.Sprintf("%+v", geminiRequest.Tools))
|
||||
// json_data, _ := json.Marshal(geminiRequest.Tools)
|
||||
// common.SysLog("tools_json: " + string(json_data))
|
||||
geminiRequest.SetTools(geminiTools)
|
||||
}
|
||||
|
||||
if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
|
||||
|
||||
@@ -126,13 +126,26 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
|
||||
task := strings.TrimPrefix(requestURL, "/v1/")
|
||||
|
||||
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
||||
task = strings.TrimPrefix(task, "messages")
|
||||
task = "chat/completions" + task
|
||||
}
|
||||
|
||||
// 特殊处理 responses API
|
||||
if info.RelayMode == relayconstant.RelayModeResponses {
|
||||
responsesApiVersion := "preview"
|
||||
|
||||
subUrl := "/openai/v1/responses"
|
||||
if strings.Contains(info.BaseUrl, "cognitiveservices.azure.com") {
|
||||
subUrl = "/openai/responses"
|
||||
responsesApiVersion = apiVersion
|
||||
}
|
||||
|
||||
if info.ChannelOtherSettings.AzureResponsesVersion != "" {
|
||||
responsesApiVersion = info.ChannelOtherSettings.AzureResponsesVersion
|
||||
}
|
||||
requestURL = fmt.Sprintf("/openai/v1/responses?api-version=%s", responsesApiVersion)
|
||||
|
||||
requestURL = fmt.Sprintf("%s?api-version=%s", subUrl, responsesApiVersion)
|
||||
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
|
||||
}
|
||||
|
||||
@@ -243,34 +256,34 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
}
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(request.Model, "o") || strings.HasPrefix(request.Model, "gpt-5") {
|
||||
if strings.HasPrefix(info.UpstreamModelName, "o") || strings.HasPrefix(info.UpstreamModelName, "gpt-5") {
|
||||
if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
|
||||
request.MaxCompletionTokens = request.MaxTokens
|
||||
request.MaxTokens = 0
|
||||
}
|
||||
|
||||
if strings.HasPrefix(request.Model, "o") {
|
||||
if strings.HasPrefix(info.UpstreamModelName, "o") {
|
||||
request.Temperature = nil
|
||||
}
|
||||
|
||||
if strings.HasPrefix(request.Model, "gpt-5") {
|
||||
if request.Model != "gpt-5-chat-latest" {
|
||||
if strings.HasPrefix(info.UpstreamModelName, "gpt-5") {
|
||||
if info.UpstreamModelName != "gpt-5-chat-latest" {
|
||||
request.Temperature = nil
|
||||
}
|
||||
}
|
||||
|
||||
// 转换模型推理力度后缀
|
||||
effort, originModel := parseReasoningEffortFromModelSuffix(request.Model)
|
||||
effort, originModel := parseReasoningEffortFromModelSuffix(info.UpstreamModelName)
|
||||
if effort != "" {
|
||||
request.ReasoningEffort = effort
|
||||
info.UpstreamModelName = originModel
|
||||
request.Model = originModel
|
||||
}
|
||||
|
||||
info.ReasoningEffort = request.ReasoningEffort
|
||||
info.UpstreamModelName = request.Model
|
||||
|
||||
// o系列模型developer适配(o1-mini除外)
|
||||
if !strings.HasPrefix(request.Model, "o1-mini") && !strings.HasPrefix(request.Model, "o1-preview") {
|
||||
if !strings.HasPrefix(info.UpstreamModelName, "o1-mini") && !strings.HasPrefix(info.UpstreamModelName, "o1-preview") {
|
||||
//修改第一个Message的内容,将system改为developer
|
||||
if len(request.Messages) > 0 && request.Messages[0].Role == "system" {
|
||||
request.Messages[0].Role = "developer"
|
||||
|
||||
@@ -50,5 +50,6 @@ func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReq
|
||||
Stop: Stop,
|
||||
Tools: request.Tools,
|
||||
ToolChoice: request.ToolChoice,
|
||||
THINKING: request.THINKING,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -258,6 +258,9 @@ func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dt
|
||||
|
||||
func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
|
||||
taskId := c.Param("task_id")
|
||||
if taskId == "" {
|
||||
taskId = c.GetString("task_id")
|
||||
}
|
||||
userId := c.GetInt("id")
|
||||
|
||||
originTask, exist, err := model.GetByTaskId(userId, taskId)
|
||||
|
||||
+15
-15
@@ -24,7 +24,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
//apiRouter.GET("/midjourney", controller.GetMidjourney)
|
||||
apiRouter.GET("/home_page_content", controller.GetHomePageContent)
|
||||
apiRouter.GET("/pricing", middleware.TryUserAuth(), controller.GetPricing)
|
||||
apiRouter.GET("/verification", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification)
|
||||
apiRouter.GET("/verification", middleware.EmailVerificationRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification)
|
||||
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
|
||||
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
|
||||
apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth)
|
||||
@@ -67,7 +67,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
selfRoute.POST("/stripe/amount", controller.RequestStripeAmount)
|
||||
selfRoute.POST("/aff_transfer", controller.TransferAffQuota)
|
||||
selfRoute.PUT("/setting", controller.UpdateUserSetting)
|
||||
|
||||
|
||||
// 2FA routes
|
||||
selfRoute.GET("/2fa/status", controller.Get2FAStatus)
|
||||
selfRoute.POST("/2fa/setup", controller.Setup2FA)
|
||||
@@ -86,7 +86,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
adminRoute.POST("/manage", controller.ManageUser)
|
||||
adminRoute.PUT("/", controller.UpdateUser)
|
||||
adminRoute.DELETE("/:id", controller.DeleteUser)
|
||||
|
||||
|
||||
// Admin 2FA routes
|
||||
adminRoute.GET("/2fa/stats", controller.Admin2FAStats)
|
||||
adminRoute.DELETE("/:id/2fa", controller.AdminDisable2FA)
|
||||
@@ -200,22 +200,22 @@ func SetApiRouter(router *gin.Engine) {
|
||||
}
|
||||
|
||||
vendorRoute := apiRouter.Group("/vendors")
|
||||
vendorRoute.Use(middleware.AdminAuth())
|
||||
{
|
||||
vendorRoute.GET("/", controller.GetAllVendors)
|
||||
vendorRoute.GET("/search", controller.SearchVendors)
|
||||
vendorRoute.GET("/:id", controller.GetVendorMeta)
|
||||
vendorRoute.POST("/", controller.CreateVendorMeta)
|
||||
vendorRoute.PUT("/", controller.UpdateVendorMeta)
|
||||
vendorRoute.DELETE("/:id", controller.DeleteVendorMeta)
|
||||
}
|
||||
vendorRoute.Use(middleware.AdminAuth())
|
||||
{
|
||||
vendorRoute.GET("/", controller.GetAllVendors)
|
||||
vendorRoute.GET("/search", controller.SearchVendors)
|
||||
vendorRoute.GET("/:id", controller.GetVendorMeta)
|
||||
vendorRoute.POST("/", controller.CreateVendorMeta)
|
||||
vendorRoute.PUT("/", controller.UpdateVendorMeta)
|
||||
vendorRoute.DELETE("/:id", controller.DeleteVendorMeta)
|
||||
}
|
||||
|
||||
modelsRoute := apiRouter.Group("/models")
|
||||
modelsRoute := apiRouter.Group("/models")
|
||||
modelsRoute.Use(middleware.AdminAuth())
|
||||
{
|
||||
modelsRoute.GET("/missing", controller.GetMissingModels)
|
||||
modelsRoute.GET("/", controller.GetAllModelsMeta)
|
||||
modelsRoute.GET("/search", controller.SearchModelsMeta)
|
||||
modelsRoute.GET("/", controller.GetAllModelsMeta)
|
||||
modelsRoute.GET("/search", controller.SearchModelsMeta)
|
||||
modelsRoute.GET("/:id", controller.GetModelMeta)
|
||||
modelsRoute.POST("/", controller.CreateModelMeta)
|
||||
modelsRoute.PUT("/", controller.UpdateModelMeta)
|
||||
|
||||
+38
-4
@@ -1,11 +1,11 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/constant"
|
||||
"one-api/controller"
|
||||
"one-api/middleware"
|
||||
"one-api/relay"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func SetRelayRouter(router *gin.Engine) {
|
||||
@@ -16,9 +16,43 @@ func SetRelayRouter(router *gin.Engine) {
|
||||
modelsRouter := router.Group("/v1/models")
|
||||
modelsRouter.Use(middleware.TokenAuth())
|
||||
{
|
||||
modelsRouter.GET("", controller.ListModels)
|
||||
modelsRouter.GET("/:model", controller.RetrieveModel)
|
||||
modelsRouter.GET("", func(c *gin.Context) {
|
||||
switch {
|
||||
case c.GetHeader("x-api-key") != "" && c.GetHeader("anthropic-version") != "":
|
||||
controller.ListModels(c, constant.ChannelTypeAnthropic)
|
||||
case c.GetHeader("x-goog-api-key") != "" || c.Query("key") != "": // 单独的适配
|
||||
controller.RetrieveModel(c, constant.ChannelTypeGemini)
|
||||
default:
|
||||
controller.ListModels(c, constant.ChannelTypeOpenAI)
|
||||
}
|
||||
})
|
||||
|
||||
modelsRouter.GET("/:model", func(c *gin.Context) {
|
||||
switch {
|
||||
case c.GetHeader("x-api-key") != "" && c.GetHeader("anthropic-version") != "":
|
||||
controller.RetrieveModel(c, constant.ChannelTypeAnthropic)
|
||||
default:
|
||||
controller.RetrieveModel(c, constant.ChannelTypeOpenAI)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
geminiRouter := router.Group("/v1beta/models")
|
||||
geminiRouter.Use(middleware.TokenAuth())
|
||||
{
|
||||
geminiRouter.GET("", func(c *gin.Context) {
|
||||
controller.ListModels(c, constant.ChannelTypeGemini)
|
||||
})
|
||||
}
|
||||
|
||||
geminiCompatibleRouter := router.Group("/v1beta/openai/models")
|
||||
geminiCompatibleRouter.Use(middleware.TokenAuth())
|
||||
{
|
||||
geminiCompatibleRouter.GET("", func(c *gin.Context) {
|
||||
controller.ListModels(c, constant.ChannelTypeOpenAI)
|
||||
})
|
||||
}
|
||||
|
||||
playgroundRouter := router.Group("/pg")
|
||||
playgroundRouter.Use(middleware.UserAuth(), middleware.Distribute())
|
||||
{
|
||||
|
||||
@@ -23,4 +23,12 @@ func SetVideoRouter(router *gin.Engine) {
|
||||
klingV1Router.GET("/videos/text2video/:task_id", controller.RelayTask)
|
||||
klingV1Router.GET("/videos/image2video/:task_id", controller.RelayTask)
|
||||
}
|
||||
|
||||
// Jimeng official API routes - direct mapping to official API format
|
||||
jimengOfficialGroup := router.Group("jimeng")
|
||||
jimengOfficialGroup.Use(middleware.JimengRequestConvert(), middleware.TokenAuth(), middleware.Distribute())
|
||||
{
|
||||
// Maps to: /?Action=CVSync2AsyncSubmitTask&Version=2022-08-31 and /?Action=CVSync2AsyncGetResult&Version=2022-08-31
|
||||
jimengOfficialGroup.POST("/", controller.RelayTask)
|
||||
}
|
||||
}
|
||||
|
||||
+2
-2
@@ -569,9 +569,9 @@ func GeminiToOpenAIRequest(geminiRequest *dto.GeminiChatRequest, info *relaycomm
|
||||
}
|
||||
|
||||
// 转换工具调用
|
||||
if len(geminiRequest.Tools) > 0 {
|
||||
if len(geminiRequest.GetTools()) > 0 {
|
||||
var tools []dto.ToolCallRequest
|
||||
for _, tool := range geminiRequest.Tools {
|
||||
for _, tool := range geminiRequest.GetTools() {
|
||||
if tool.FunctionDeclarations != nil {
|
||||
// 将 Gemini 的 FunctionDeclarations 转换为 OpenAI 的 ToolCallRequest
|
||||
functionDeclarations, ok := tool.FunctionDeclarations.([]dto.FunctionRequest)
|
||||
|
||||
+16
-11
@@ -21,10 +21,11 @@ import React, { lazy, Suspense } from 'react';
|
||||
import { Route, Routes, useLocation } from 'react-router-dom';
|
||||
import Loading from './components/common/ui/Loading.js';
|
||||
import User from './pages/User';
|
||||
import { AuthRedirect, PrivateRoute } from './helpers';
|
||||
import { AuthRedirect, PrivateRoute, AdminRoute } from './helpers';
|
||||
import RegisterForm from './components/auth/RegisterForm.js';
|
||||
import LoginForm from './components/auth/LoginForm.js';
|
||||
import NotFound from './pages/NotFound';
|
||||
import Forbidden from './pages/Forbidden';
|
||||
import Setting from './pages/Setting';
|
||||
|
||||
import PasswordResetForm from './components/auth/PasswordResetForm.js';
|
||||
@@ -72,20 +73,24 @@ function App() {
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/forbidden'
|
||||
element={<Forbidden />}
|
||||
/>
|
||||
<Route
|
||||
path='/console/models'
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<AdminRoute>
|
||||
<ModelPage />
|
||||
</PrivateRoute>
|
||||
</AdminRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/console/channel'
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<AdminRoute>
|
||||
<Channel />
|
||||
</PrivateRoute>
|
||||
</AdminRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
@@ -107,17 +112,17 @@ function App() {
|
||||
<Route
|
||||
path='/console/redemption'
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<AdminRoute>
|
||||
<Redemption />
|
||||
</PrivateRoute>
|
||||
</AdminRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/console/user'
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<AdminRoute>
|
||||
<User />
|
||||
</PrivateRoute>
|
||||
</AdminRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
@@ -183,11 +188,11 @@ function App() {
|
||||
<Route
|
||||
path='/console/setting'
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<AdminRoute>
|
||||
<Suspense fallback={<Loading></Loading>} key={location.pathname}>
|
||||
<Setting />
|
||||
</Suspense>
|
||||
</PrivateRoute>
|
||||
</AdminRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
|
||||
@@ -80,6 +80,8 @@ const RegisterForm = () => {
|
||||
const [verificationCodeLoading, setVerificationCodeLoading] = useState(false);
|
||||
const [otherRegisterOptionsLoading, setOtherRegisterOptionsLoading] = useState(false);
|
||||
const [wechatCodeSubmitLoading, setWechatCodeSubmitLoading] = useState(false);
|
||||
const [disableButton, setDisableButton] = useState(false);
|
||||
const [countdown, setCountdown] = useState(30);
|
||||
|
||||
const logo = getLogo();
|
||||
const systemName = getSystemName();
|
||||
@@ -106,6 +108,19 @@ const RegisterForm = () => {
|
||||
}
|
||||
}, [status]);
|
||||
|
||||
useEffect(() => {
|
||||
let countdownInterval = null;
|
||||
if (disableButton && countdown > 0) {
|
||||
countdownInterval = setInterval(() => {
|
||||
setCountdown(countdown - 1);
|
||||
}, 1000);
|
||||
} else if (countdown === 0) {
|
||||
setDisableButton(false);
|
||||
setCountdown(30);
|
||||
}
|
||||
return () => clearInterval(countdownInterval); // Clean up on unmount
|
||||
}, [disableButton, countdown]);
|
||||
|
||||
const onWeChatLoginClicked = () => {
|
||||
setWechatLoading(true);
|
||||
setShowWeChatLoginModal(true);
|
||||
@@ -198,6 +213,7 @@ const RegisterForm = () => {
|
||||
const { success, message } = res.data;
|
||||
if (success) {
|
||||
showSuccess('验证码发送成功,请检查你的邮箱!');
|
||||
setDisableButton(true); // 发送成功后禁用按钮,开始倒计时
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
@@ -454,9 +470,10 @@ const RegisterForm = () => {
|
||||
<Button
|
||||
onClick={sendVerificationCode}
|
||||
loading={verificationCodeLoading}
|
||||
disabled={disableButton || verificationCodeLoading}
|
||||
size="small"
|
||||
>
|
||||
{t('获取验证码')}
|
||||
{disableButton ? `${t('重新发送')} (${countdown})` : t('获取验证码')}
|
||||
</Button>
|
||||
}
|
||||
/>
|
||||
|
||||
@@ -41,7 +41,7 @@ const CardTable = ({
|
||||
}) => {
|
||||
const isMobile = useIsMobile();
|
||||
const { t } = useTranslation();
|
||||
|
||||
|
||||
const showSkeleton = useMinimumLoadingTime(loading);
|
||||
|
||||
const getRowKey = (record, index) => {
|
||||
|
||||
@@ -17,10 +17,10 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import React, { useState, useRef } from 'react';
|
||||
import React, { useState } from 'react';
|
||||
import { useIsMobile } from '../../../hooks/common/useIsMobile';
|
||||
import { useMinimumLoadingTime } from '../../../hooks/common/useMinimumLoadingTime';
|
||||
import { Divider, Button, Tag, Row, Col, Collapsible, Checkbox, Skeleton } from '@douyinfe/semi-ui';
|
||||
import { Divider, Button, Tag, Row, Col, Collapsible, Checkbox, Skeleton, Tooltip } from '@douyinfe/semi-ui';
|
||||
import { IconChevronDown, IconChevronUp } from '@douyinfe/semi-icons';
|
||||
|
||||
/**
|
||||
@@ -57,8 +57,6 @@ const SelectableButtonGroup = ({
|
||||
const needCollapse = collapsible && items.length > perRow * maxVisibleRows;
|
||||
const showSkeleton = useMinimumLoadingTime(loading);
|
||||
|
||||
const contentRef = useRef(null);
|
||||
|
||||
const maskStyle = isOpen
|
||||
? {}
|
||||
: {
|
||||
@@ -131,7 +129,7 @@ const SelectableButtonGroup = ({
|
||||
};
|
||||
|
||||
const contentElement = showSkeleton ? renderSkeletonButtons() : (
|
||||
<Row gutter={[8, 8]} style={{ lineHeight: '32px', ...style }} ref={contentRef}>
|
||||
<Row gutter={[8, 8]} style={{ lineHeight: '32px', ...style }}>
|
||||
{items.map((item) => {
|
||||
const isDisabled = item.disabled || (typeof item.tagCount === 'number' && item.tagCount === 0);
|
||||
const isActive = Array.isArray(activeValue)
|
||||
@@ -152,6 +150,7 @@ const SelectableButtonGroup = ({
|
||||
theme={isActive ? 'light' : 'outline'}
|
||||
type={isActive ? 'primary' : 'tertiary'}
|
||||
disabled={isDisabled}
|
||||
className="sbg-button"
|
||||
icon={
|
||||
<Checkbox
|
||||
checked={isActive}
|
||||
@@ -162,19 +161,15 @@ const SelectableButtonGroup = ({
|
||||
}
|
||||
style={{ width: '100%', cursor: 'default' }}
|
||||
>
|
||||
{item.icon && (
|
||||
<span style={{ marginRight: 4 }}>{item.icon}</span>
|
||||
)}
|
||||
<span style={{ marginRight: item.tagCount !== undefined ? 4 : 0 }}>{item.label}</span>
|
||||
{item.tagCount !== undefined && (
|
||||
<Tag
|
||||
color='white'
|
||||
shape="circle"
|
||||
size="small"
|
||||
>
|
||||
{item.tagCount}
|
||||
</Tag>
|
||||
)}
|
||||
<div className="sbg-content">
|
||||
{item.icon && (<span className="sbg-icon">{item.icon}</span>)}
|
||||
<Tooltip content={item.label}>
|
||||
<span className="sbg-ellipsis">{item.label}</span>
|
||||
</Tooltip>
|
||||
{item.tagCount !== undefined && (
|
||||
<Tag className="sbg-tag" color='white' shape="circle" size="small">{item.tagCount}</Tag>
|
||||
)}
|
||||
</div>
|
||||
</Button>
|
||||
</Col>
|
||||
);
|
||||
@@ -192,20 +187,19 @@ const SelectableButtonGroup = ({
|
||||
onClick={() => onChange(item.value)}
|
||||
theme={isActive ? 'light' : 'outline'}
|
||||
type={isActive ? 'primary' : 'tertiary'}
|
||||
icon={item.icon}
|
||||
disabled={isDisabled}
|
||||
className="sbg-button"
|
||||
style={{ width: '100%' }}
|
||||
>
|
||||
<span style={{ marginRight: item.tagCount !== undefined ? 4 : 0 }}>{item.label}</span>
|
||||
{item.tagCount !== undefined && (
|
||||
<Tag
|
||||
color='white'
|
||||
shape="circle"
|
||||
size="small"
|
||||
>
|
||||
{item.tagCount}
|
||||
</Tag>
|
||||
)}
|
||||
<div className="sbg-content">
|
||||
{item.icon && (<span className="sbg-icon">{item.icon}</span>)}
|
||||
<Tooltip content={item.label}>
|
||||
<span className="sbg-ellipsis">{item.label}</span>
|
||||
</Tooltip>
|
||||
{item.tagCount !== undefined && (
|
||||
<Tag className="sbg-tag" color='white' shape="circle" size="small">{item.tagCount}</Tag>
|
||||
)}
|
||||
</div>
|
||||
</Button>
|
||||
</Col>
|
||||
);
|
||||
|
||||
@@ -135,7 +135,7 @@ const ModelSelectModal = ({ visible, models = [], selected = [], onConfirm, onCa
|
||||
const allActiveKeys = categoryEntries.map((_, index) => `${categoryKeyPrefix}_${index}`);
|
||||
|
||||
return (
|
||||
<Collapse activeKey={allActiveKeys}>
|
||||
<Collapse key={`${categoryKeyPrefix}_${categoryEntries.length}`} defaultActiveKey={[]}>
|
||||
{categoryEntries.map(([key, categoryData], index) => (
|
||||
<Collapse.Panel
|
||||
key={`${categoryKeyPrefix}_${index}`}
|
||||
|
||||
@@ -63,7 +63,7 @@ const ModelPricingTable = ({
|
||||
key: group,
|
||||
group: group,
|
||||
ratio: groupRatioValue,
|
||||
billingType: modelData?.quota_type === 0 ? t('按量计费') : t('按次计费'),
|
||||
billingType: modelData?.quota_type === 0 ? t('按量计费') : (modelData?.quota_type === 1 ? t('按次计费') : '-'),
|
||||
inputPrice: modelData?.quota_type === 0 ? priceData.inputPrice : '-',
|
||||
outputPrice: modelData?.quota_type === 0 ? (priceData.completionPrice || priceData.outputPrice) : '-',
|
||||
fixedPrice: modelData?.quota_type === 1 ? priceData.price : '-',
|
||||
@@ -100,11 +100,16 @@ const ModelPricingTable = ({
|
||||
columns.push({
|
||||
title: t('计费类型'),
|
||||
dataIndex: 'billingType',
|
||||
render: (text) => (
|
||||
<Tag color={text === t('按量计费') ? 'violet' : 'teal'} size="small" shape="circle">
|
||||
{text}
|
||||
</Tag>
|
||||
),
|
||||
render: (text) => {
|
||||
let color = 'white';
|
||||
if (text === t('按量计费')) color = 'violet';
|
||||
else if (text === t('按次计费')) color = 'teal';
|
||||
return (
|
||||
<Tag color={color} size="small" shape="circle">
|
||||
{text || '-'}
|
||||
</Tag>
|
||||
);
|
||||
},
|
||||
});
|
||||
|
||||
// 根据计费类型添加价格列
|
||||
|
||||
@@ -26,7 +26,7 @@ const PricingCardSkeleton = ({
|
||||
showRatio = false
|
||||
}) => {
|
||||
const placeholder = (
|
||||
<div className="p-4">
|
||||
<div className="px-4">
|
||||
<div className="grid grid-cols-1 xl:grid-cols-2 2xl:grid-cols-3 gap-4">
|
||||
{Array.from({ length: skeletonCount }).map((_, index) => (
|
||||
<Card
|
||||
@@ -123,7 +123,7 @@ const PricingCardSkeleton = ({
|
||||
</div>
|
||||
|
||||
{/* 分页骨架 */}
|
||||
<div className="flex justify-center mt-6 pt-4 border-t pricing-pagination-divider">
|
||||
<div className="flex justify-center mt-6 py-4 border-t pricing-pagination-divider">
|
||||
<Skeleton.Button style={{ width: 300, height: 32 }} />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -144,13 +144,24 @@ const PricingCardView = ({
|
||||
// 渲染标签
|
||||
const renderTags = (record) => {
|
||||
// 计费类型标签(左边)
|
||||
const billingType = record.quota_type === 1 ? 'teal' : 'violet';
|
||||
const billingText = record.quota_type === 1 ? t('按次计费') : t('按量计费');
|
||||
const billingTag = (
|
||||
<Tag key="billing" shape='circle' color={billingType} size='small'>
|
||||
{billingText}
|
||||
let billingTag = (
|
||||
<Tag key="billing" shape='circle' color='white' size='small'>
|
||||
-
|
||||
</Tag>
|
||||
);
|
||||
if (record.quota_type === 1) {
|
||||
billingTag = (
|
||||
<Tag key="billing" shape='circle' color='teal' size='small'>
|
||||
{t('按次计费')}
|
||||
</Tag>
|
||||
);
|
||||
} else if (record.quota_type === 0) {
|
||||
billingTag = (
|
||||
<Tag key="billing" shape='circle' color='violet' size='small'>
|
||||
{t('按量计费')}
|
||||
</Tag>
|
||||
);
|
||||
}
|
||||
|
||||
// 自定义标签(右边)
|
||||
const customTags = [];
|
||||
@@ -204,7 +215,7 @@ const PricingCardView = ({
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="p-4">
|
||||
<div className="px-4">
|
||||
<div className="grid grid-cols-1 xl:grid-cols-2 2xl:grid-cols-3 gap-4">
|
||||
{paginatedModels.map((model, index) => {
|
||||
const modelKey = getModelKey(model);
|
||||
@@ -316,7 +327,7 @@ const PricingCardView = ({
|
||||
|
||||
{/* 分页 */}
|
||||
{filteredModels.length > 0 && (
|
||||
<div className="flex justify-center mt-6 pt-4 border-t pricing-pagination-divider">
|
||||
<div className="flex justify-center mt-6 py-4 border-t pricing-pagination-divider">
|
||||
<Pagination
|
||||
currentPage={currentPage}
|
||||
pageSize={pageSize}
|
||||
|
||||
@@ -18,7 +18,7 @@ For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import React from 'react';
|
||||
import { Button, Space, Tag, Typography, Modal } from '@douyinfe/semi-ui';
|
||||
import { Button, Space, Tag, Typography, Modal, Tooltip } from '@douyinfe/semi-ui';
|
||||
import {
|
||||
timestamp2string,
|
||||
getLobeHubIcon,
|
||||
@@ -121,23 +121,34 @@ const renderEndpoints = (value) => {
|
||||
}
|
||||
};
|
||||
|
||||
// Render quota type
|
||||
const renderQuotaType = (qt, t) => {
|
||||
if (qt === 1) {
|
||||
return (
|
||||
<Tag color='teal' size='small' shape='circle'>
|
||||
{t('按次计费')}
|
||||
</Tag>
|
||||
);
|
||||
}
|
||||
if (qt === 0) {
|
||||
return (
|
||||
<Tag color='violet' size='small' shape='circle'>
|
||||
{t('按量计费')}
|
||||
</Tag>
|
||||
);
|
||||
}
|
||||
return qt ?? '-';
|
||||
// Render quota types (array) using common limited items renderer
|
||||
const renderQuotaTypes = (arr, t) => {
|
||||
if (!Array.isArray(arr) || arr.length === 0) return '-';
|
||||
return renderLimitedItems({
|
||||
items: arr,
|
||||
renderItem: (qt, idx) => {
|
||||
if (qt === 1) {
|
||||
return (
|
||||
<Tag key={`${qt}-${idx}`} color='teal' size='small' shape='circle'>
|
||||
{t('按次计费')}
|
||||
</Tag>
|
||||
);
|
||||
}
|
||||
if (qt === 0) {
|
||||
return (
|
||||
<Tag key={`${qt}-${idx}`} color='violet' size='small' shape='circle'>
|
||||
{t('按量计费')}
|
||||
</Tag>
|
||||
);
|
||||
}
|
||||
return (
|
||||
<Tag key={`${qt}-${idx}`} color='white' size='small' shape='circle'>
|
||||
{qt}
|
||||
</Tag>
|
||||
);
|
||||
},
|
||||
maxDisplay: 3,
|
||||
});
|
||||
};
|
||||
|
||||
// Render bound channels
|
||||
@@ -207,8 +218,8 @@ const renderOperations = (text, record, setEditingModel, setShowEdit, manageMode
|
||||
);
|
||||
};
|
||||
|
||||
// 名称匹配类型渲染
|
||||
const renderNameRule = (rule, t) => {
|
||||
// 名称匹配类型渲染(带匹配数量 Tooltip)
|
||||
const renderNameRule = (rule, record, t) => {
|
||||
const map = {
|
||||
0: { color: 'green', label: t('精确') },
|
||||
1: { color: 'blue', label: t('前缀') },
|
||||
@@ -217,11 +228,27 @@ const renderNameRule = (rule, t) => {
|
||||
};
|
||||
const cfg = map[rule];
|
||||
if (!cfg) return '-';
|
||||
return (
|
||||
|
||||
let label = cfg.label;
|
||||
if (rule !== 0 && record.matched_count) {
|
||||
label = `${cfg.label} ${record.matched_count}${t('个模型')}`;
|
||||
}
|
||||
|
||||
const tagElement = (
|
||||
<Tag color={cfg.color} size="small" shape='circle'>
|
||||
{cfg.label}
|
||||
{label}
|
||||
</Tag>
|
||||
);
|
||||
|
||||
if (rule === 0 || !record.matched_models || record.matched_models.length === 0) {
|
||||
return tagElement;
|
||||
}
|
||||
|
||||
return (
|
||||
<Tooltip content={record.matched_models.join(', ')} showArrow>
|
||||
{tagElement}
|
||||
</Tooltip>
|
||||
);
|
||||
};
|
||||
|
||||
export const getModelsColumns = ({
|
||||
@@ -252,7 +279,7 @@ export const getModelsColumns = ({
|
||||
{
|
||||
title: t('匹配类型'),
|
||||
dataIndex: 'name_rule',
|
||||
render: (val) => renderNameRule(val, t),
|
||||
render: (val, record) => renderNameRule(val, record, t),
|
||||
},
|
||||
{
|
||||
title: t('描述'),
|
||||
@@ -286,8 +313,8 @@ export const getModelsColumns = ({
|
||||
},
|
||||
{
|
||||
title: t('计费类型'),
|
||||
dataIndex: 'quota_type',
|
||||
render: (qt) => renderQuotaType(qt, t),
|
||||
dataIndex: 'quota_types',
|
||||
render: (qts) => renderQuotaTypes(qts, t),
|
||||
},
|
||||
{
|
||||
title: t('创建时间'),
|
||||
|
||||
@@ -49,4 +49,20 @@ function PrivateRoute({ children }) {
|
||||
return children;
|
||||
}
|
||||
|
||||
export function AdminRoute({ children }) {
|
||||
const raw = localStorage.getItem('user');
|
||||
if (!raw) {
|
||||
return <Navigate to='/login' state={{ from: history.location }} />;
|
||||
}
|
||||
try {
|
||||
const user = JSON.parse(raw);
|
||||
if (user && typeof user.role === 'number' && user.role >= 10) {
|
||||
return children;
|
||||
}
|
||||
} catch (e) {
|
||||
// ignore
|
||||
}
|
||||
return <Navigate to='/forbidden' replace />;
|
||||
}
|
||||
|
||||
export { PrivateRoute };
|
||||
|
||||
@@ -632,12 +632,22 @@ export const calculateModelPrice = ({
|
||||
};
|
||||
}
|
||||
|
||||
// 按次计费
|
||||
const priceUSD = parseFloat(record.model_price) * usedGroupRatio;
|
||||
const displayVal = displayPrice(priceUSD);
|
||||
if (record.quota_type === 1) {
|
||||
// 按次计费
|
||||
const priceUSD = parseFloat(record.model_price) * usedGroupRatio;
|
||||
const displayVal = displayPrice(priceUSD);
|
||||
|
||||
return {
|
||||
price: displayVal,
|
||||
isPerToken: false,
|
||||
usedGroup,
|
||||
usedGroupRatio,
|
||||
};
|
||||
}
|
||||
|
||||
// 未知计费类型,返回占位信息
|
||||
return {
|
||||
price: displayVal,
|
||||
price: '-',
|
||||
isPerToken: false,
|
||||
usedGroup,
|
||||
usedGroupRatio,
|
||||
|
||||
@@ -1459,6 +1459,7 @@
|
||||
"设计与开发由": "Designed & Developed with love by",
|
||||
"演示站点": "Demo Site",
|
||||
"页面未找到,请检查您的浏览器地址是否正确": "Page not found, please check if your browser address is correct",
|
||||
"您无权访问此页面,请联系管理员": "You do not have permission to access this page. Please contact the administrator.",
|
||||
"New API项目仓库地址:": "New API project repository address: ",
|
||||
"© {{currentYear}}": "© {{currentYear}}",
|
||||
"| 基于": " | Based on ",
|
||||
|
||||
+21
-2
@@ -289,6 +289,27 @@ code {
|
||||
}
|
||||
|
||||
/* ==================== 组件特定样式 ==================== */
|
||||
/* SelectableButtonGroup */
|
||||
.sbg-button .semi-button-content {
|
||||
min-width: 0 !important;
|
||||
}
|
||||
|
||||
.sbg-content {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 4px;
|
||||
width: 100%;
|
||||
min-width: 0;
|
||||
}
|
||||
|
||||
.sbg-ellipsis {
|
||||
flex: 1;
|
||||
min-width: 0;
|
||||
overflow: hidden;
|
||||
white-space: nowrap;
|
||||
text-overflow: ellipsis;
|
||||
}
|
||||
|
||||
/* Tabs组件样式 */
|
||||
.semi-tabs-content {
|
||||
padding: 0 !important;
|
||||
@@ -686,7 +707,6 @@ html.dark .with-pastel-balls::before {
|
||||
max-width: 460px;
|
||||
height: calc(100vh - 60px);
|
||||
background-color: var(--semi-color-bg-0);
|
||||
border-right: 1px solid var(--semi-color-border);
|
||||
overflow: auto;
|
||||
}
|
||||
|
||||
@@ -710,7 +730,6 @@ html.dark .with-pastel-balls::before {
|
||||
|
||||
.pricing-search-header {
|
||||
padding: 1rem;
|
||||
border-bottom: 1px solid var(--semi-color-border);
|
||||
background-color: var(--semi-color-bg-0);
|
||||
flex-shrink: 0;
|
||||
position: sticky;
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
/*
|
||||
Copyright (C) 2025 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import React from 'react';
|
||||
import { Empty } from '@douyinfe/semi-ui';
|
||||
import { IllustrationNoAccess, IllustrationNoAccessDark } from '@douyinfe/semi-illustrations';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const Forbidden = () => {
|
||||
const { t } = useTranslation();
|
||||
return (
|
||||
<div className="flex justify-center items-center h-screen p-8">
|
||||
<Empty
|
||||
image={<IllustrationNoAccess style={{ width: 250, height: 250 }} />}
|
||||
darkModeImage={<IllustrationNoAccessDark style={{ width: 250, height: 250 }} />}
|
||||
description={t('您无权访问此页面,请联系管理员')}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default Forbidden;
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ import { useTranslation } from 'react-i18next';
|
||||
const NotFound = () => {
|
||||
const { t } = useTranslation();
|
||||
return (
|
||||
<div className="flex justify-center items-center h-screen p-8 mt-[60px]">
|
||||
<div className="flex justify-center items-center h-screen p-8">
|
||||
<Empty
|
||||
image={<IllustrationNotFound style={{ width: 250, height: 250 }} />}
|
||||
darkModeImage={<IllustrationNotFoundDark style={{ width: 250, height: 250 }} />}
|
||||
|
||||
Reference in New Issue
Block a user