Compare commits
22 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c17439a14f | |||
| 95775a5a52 | |||
| 8bb1522de6 | |||
| 3d8a72d15c | |||
| 83dc498268 | |||
| d03c1f4bea | |||
| 14cca53ec8 | |||
| f45caf29af | |||
| 56dce5a745 | |||
| dd9a1534e0 | |||
| 5cb4c4e644 | |||
| 12977d9a52 | |||
| 13eb2ec1f1 | |||
| 62e82256ef | |||
| bb00bd0ffe | |||
| f14f4eaf1a | |||
| 448184a250 | |||
| b4e3147124 | |||
| 0359dfb8c3 | |||
| c7b3b8d0b1 | |||
| 400a7e6e5b | |||
| 111dc18d26 |
@@ -0,0 +1,19 @@
|
||||
### PR 类型
|
||||
|
||||
- [ ] Bug 修复
|
||||
- [ ] 新功能
|
||||
- [ ] 文档更新
|
||||
- [ ] 其他
|
||||
|
||||
### PR 是否包含破坏性更新?
|
||||
|
||||
- [ ] 是
|
||||
- [ ] 否
|
||||
|
||||
### PR 描述
|
||||
|
||||
**请在下方详细描述您的 PR,包括目的、实现细节等。**
|
||||
|
||||
### **重要提示**
|
||||
|
||||
**所有 PR 都必须提交到 `alpha` 分支。请确保您的 PR 目标分支是 `alpha`。**
|
||||
@@ -26,6 +26,7 @@ jobs:
|
||||
- name: Build Frontend
|
||||
env:
|
||||
CI: ""
|
||||
NODE_OPTIONS: "--max-old-space-size=4096"
|
||||
run: |
|
||||
cd web
|
||||
bun install
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
name: Check PR Branching Strategy
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened, edited]
|
||||
|
||||
jobs:
|
||||
check-branching-strategy:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Enforce branching strategy
|
||||
run: |
|
||||
if [[ "${{ github.base_ref }}" == "main" ]]; then
|
||||
if [[ "${{ github.head_ref }}" != "alpha" ]]; then
|
||||
echo "Error: Pull requests to 'main' are only allowed from the 'alpha' branch."
|
||||
exit 1
|
||||
fi
|
||||
elif [[ "${{ github.base_ref }}" != "alpha" ]]; then
|
||||
echo "Error: Pull requests must be targeted to the 'alpha' or 'main' branch."
|
||||
exit 1
|
||||
fi
|
||||
echo "Branching strategy check passed."
|
||||
@@ -27,9 +27,6 @@
|
||||
<a href="https://goreportcard.com/report/github.com/Calcium-Ion/new-api">
|
||||
<img src="https://goreportcard.com/badge/github.com/Calcium-Ion/new-api" alt="GoReportCard">
|
||||
</a>
|
||||
<a href="https://coderabbit.ai">
|
||||
<img src="https://img.shields.io/coderabbit/prs/github/QuantumNous/new-api?utm_source=oss&utm_medium=github&utm_campaign=QuantumNous%2Fnew-api&labelColor=171717&color=FF570A&link=https%3A%2F%2Fcoderabbit.ai&label=CodeRabbit+Reviews" alt="CodeRabbit Pull Request Reviews">
|
||||
</a>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
package common
|
||||
|
||||
import "one-api/constant"
|
||||
|
||||
func ChannelType2APIType(channelType int) (int, bool) {
|
||||
apiType := -1
|
||||
switch channelType {
|
||||
case constant.ChannelTypeOpenAI:
|
||||
apiType = constant.APITypeOpenAI
|
||||
case constant.ChannelTypeAnthropic:
|
||||
apiType = constant.APITypeAnthropic
|
||||
case constant.ChannelTypeBaidu:
|
||||
apiType = constant.APITypeBaidu
|
||||
case constant.ChannelTypePaLM:
|
||||
apiType = constant.APITypePaLM
|
||||
case constant.ChannelTypeZhipu:
|
||||
apiType = constant.APITypeZhipu
|
||||
case constant.ChannelTypeAli:
|
||||
apiType = constant.APITypeAli
|
||||
case constant.ChannelTypeXunfei:
|
||||
apiType = constant.APITypeXunfei
|
||||
case constant.ChannelTypeAIProxyLibrary:
|
||||
apiType = constant.APITypeAIProxyLibrary
|
||||
case constant.ChannelTypeTencent:
|
||||
apiType = constant.APITypeTencent
|
||||
case constant.ChannelTypeGemini:
|
||||
apiType = constant.APITypeGemini
|
||||
case constant.ChannelTypeZhipu_v4:
|
||||
apiType = constant.APITypeZhipuV4
|
||||
case constant.ChannelTypeOllama:
|
||||
apiType = constant.APITypeOllama
|
||||
case constant.ChannelTypePerplexity:
|
||||
apiType = constant.APITypePerplexity
|
||||
case constant.ChannelTypeAws:
|
||||
apiType = constant.APITypeAws
|
||||
case constant.ChannelTypeCohere:
|
||||
apiType = constant.APITypeCohere
|
||||
case constant.ChannelTypeDify:
|
||||
apiType = constant.APITypeDify
|
||||
case constant.ChannelTypeJina:
|
||||
apiType = constant.APITypeJina
|
||||
case constant.ChannelCloudflare:
|
||||
apiType = constant.APITypeCloudflare
|
||||
case constant.ChannelTypeSiliconFlow:
|
||||
apiType = constant.APITypeSiliconFlow
|
||||
case constant.ChannelTypeVertexAi:
|
||||
apiType = constant.APITypeVertexAi
|
||||
case constant.ChannelTypeMistral:
|
||||
apiType = constant.APITypeMistral
|
||||
case constant.ChannelTypeDeepSeek:
|
||||
apiType = constant.APITypeDeepSeek
|
||||
case constant.ChannelTypeMokaAI:
|
||||
apiType = constant.APITypeMokaAI
|
||||
case constant.ChannelTypeVolcEngine:
|
||||
apiType = constant.APITypeVolcEngine
|
||||
case constant.ChannelTypeBaiduV2:
|
||||
apiType = constant.APITypeBaiduV2
|
||||
case constant.ChannelTypeOpenRouter:
|
||||
apiType = constant.APITypeOpenRouter
|
||||
case constant.ChannelTypeXinference:
|
||||
apiType = constant.APITypeXinference
|
||||
case constant.ChannelTypeXai:
|
||||
apiType = constant.APITypeXai
|
||||
case constant.ChannelTypeCoze:
|
||||
apiType = constant.APITypeCoze
|
||||
}
|
||||
if apiType == -1 {
|
||||
return constant.APITypeOpenAI, false
|
||||
}
|
||||
return apiType, true
|
||||
}
|
||||
@@ -193,111 +193,3 @@ const (
|
||||
ChannelStatusManuallyDisabled = 2 // also don't use 0
|
||||
ChannelStatusAutoDisabled = 3
|
||||
)
|
||||
|
||||
const (
|
||||
ChannelTypeUnknown = 0
|
||||
ChannelTypeOpenAI = 1
|
||||
ChannelTypeMidjourney = 2
|
||||
ChannelTypeAzure = 3
|
||||
ChannelTypeOllama = 4
|
||||
ChannelTypeMidjourneyPlus = 5
|
||||
ChannelTypeOpenAIMax = 6
|
||||
ChannelTypeOhMyGPT = 7
|
||||
ChannelTypeCustom = 8
|
||||
ChannelTypeAILS = 9
|
||||
ChannelTypeAIProxy = 10
|
||||
ChannelTypePaLM = 11
|
||||
ChannelTypeAPI2GPT = 12
|
||||
ChannelTypeAIGC2D = 13
|
||||
ChannelTypeAnthropic = 14
|
||||
ChannelTypeBaidu = 15
|
||||
ChannelTypeZhipu = 16
|
||||
ChannelTypeAli = 17
|
||||
ChannelTypeXunfei = 18
|
||||
ChannelType360 = 19
|
||||
ChannelTypeOpenRouter = 20
|
||||
ChannelTypeAIProxyLibrary = 21
|
||||
ChannelTypeFastGPT = 22
|
||||
ChannelTypeTencent = 23
|
||||
ChannelTypeGemini = 24
|
||||
ChannelTypeMoonshot = 25
|
||||
ChannelTypeZhipu_v4 = 26
|
||||
ChannelTypePerplexity = 27
|
||||
ChannelTypeLingYiWanWu = 31
|
||||
ChannelTypeAws = 33
|
||||
ChannelTypeCohere = 34
|
||||
ChannelTypeMiniMax = 35
|
||||
ChannelTypeSunoAPI = 36
|
||||
ChannelTypeDify = 37
|
||||
ChannelTypeJina = 38
|
||||
ChannelCloudflare = 39
|
||||
ChannelTypeSiliconFlow = 40
|
||||
ChannelTypeVertexAi = 41
|
||||
ChannelTypeMistral = 42
|
||||
ChannelTypeDeepSeek = 43
|
||||
ChannelTypeMokaAI = 44
|
||||
ChannelTypeVolcEngine = 45
|
||||
ChannelTypeBaiduV2 = 46
|
||||
ChannelTypeXinference = 47
|
||||
ChannelTypeXai = 48
|
||||
ChannelTypeCoze = 49
|
||||
ChannelTypeKling = 50
|
||||
ChannelTypeJimeng = 51
|
||||
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
||||
|
||||
)
|
||||
|
||||
var ChannelBaseURLs = []string{
|
||||
"", // 0
|
||||
"https://api.openai.com", // 1
|
||||
"https://oa.api2d.net", // 2
|
||||
"", // 3
|
||||
"http://localhost:11434", // 4
|
||||
"https://api.openai-sb.com", // 5
|
||||
"https://api.openaimax.com", // 6
|
||||
"https://api.ohmygpt.com", // 7
|
||||
"", // 8
|
||||
"https://api.caipacity.com", // 9
|
||||
"https://api.aiproxy.io", // 10
|
||||
"", // 11
|
||||
"https://api.api2gpt.com", // 12
|
||||
"https://api.aigc2d.com", // 13
|
||||
"https://api.anthropic.com", // 14
|
||||
"https://aip.baidubce.com", // 15
|
||||
"https://open.bigmodel.cn", // 16
|
||||
"https://dashscope.aliyuncs.com", // 17
|
||||
"", // 18
|
||||
"https://api.360.cn", // 19
|
||||
"https://openrouter.ai/api", // 20
|
||||
"https://api.aiproxy.io", // 21
|
||||
"https://fastgpt.run/api/openapi", // 22
|
||||
"https://hunyuan.tencentcloudapi.com", //23
|
||||
"https://generativelanguage.googleapis.com", //24
|
||||
"https://api.moonshot.cn", //25
|
||||
"https://open.bigmodel.cn", //26
|
||||
"https://api.perplexity.ai", //27
|
||||
"", //28
|
||||
"", //29
|
||||
"", //30
|
||||
"https://api.lingyiwanwu.com", //31
|
||||
"", //32
|
||||
"", //33
|
||||
"https://api.cohere.ai", //34
|
||||
"https://api.minimax.chat", //35
|
||||
"", //36
|
||||
"https://api.dify.ai", //37
|
||||
"https://api.jina.ai", //38
|
||||
"https://api.cloudflare.com", //39
|
||||
"https://api.siliconflow.cn", //40
|
||||
"", //41
|
||||
"https://api.mistral.ai", //42
|
||||
"https://api.deepseek.com", //43
|
||||
"https://api.moka.ai", //44
|
||||
"https://ark.cn-beijing.volces.com", //45
|
||||
"https://qianfan.baidubce.com", //46
|
||||
"", //47
|
||||
"https://api.x.ai", //48
|
||||
"https://api.coze.cn", //49
|
||||
"https://api.klingai.com", //50
|
||||
"https://visual.volcengineapi.com", //51
|
||||
}
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
package common
|
||||
|
||||
import "one-api/constant"
|
||||
|
||||
// GetEndpointTypesByChannelType 获取渠道最优先端点类型(所有的渠道都支持 OpenAI 端点)
|
||||
func GetEndpointTypesByChannelType(channelType int, modelName string) []constant.EndpointType {
|
||||
var endpointTypes []constant.EndpointType
|
||||
switch channelType {
|
||||
case constant.ChannelTypeJina:
|
||||
endpointTypes = []constant.EndpointType{constant.EndpointTypeJinaRerank}
|
||||
//case constant.ChannelTypeMidjourney, constant.ChannelTypeMidjourneyPlus:
|
||||
// endpointTypes = []constant.EndpointType{constant.EndpointTypeMidjourney}
|
||||
//case constant.ChannelTypeSunoAPI:
|
||||
// endpointTypes = []constant.EndpointType{constant.EndpointTypeSuno}
|
||||
//case constant.ChannelTypeKling:
|
||||
// endpointTypes = []constant.EndpointType{constant.EndpointTypeKling}
|
||||
//case constant.ChannelTypeJimeng:
|
||||
// endpointTypes = []constant.EndpointType{constant.EndpointTypeJimeng}
|
||||
case constant.ChannelTypeAws:
|
||||
fallthrough
|
||||
case constant.ChannelTypeAnthropic:
|
||||
endpointTypes = []constant.EndpointType{constant.EndpointTypeAnthropic, constant.EndpointTypeOpenAI}
|
||||
case constant.ChannelTypeVertexAi:
|
||||
fallthrough
|
||||
case constant.ChannelTypeGemini:
|
||||
endpointTypes = []constant.EndpointType{constant.EndpointTypeGemini, constant.EndpointTypeOpenAI}
|
||||
case constant.ChannelTypeOpenRouter: // OpenRouter 只支持 OpenAI 端点
|
||||
endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI}
|
||||
default:
|
||||
if IsOpenAIResponseOnlyModel(modelName) {
|
||||
endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAIResponse}
|
||||
} else {
|
||||
endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI}
|
||||
}
|
||||
}
|
||||
if IsImageGenerationModel(modelName) {
|
||||
// add to first
|
||||
endpointTypes = append([]constant.EndpointType{constant.EndpointTypeImageGeneration}, endpointTypes...)
|
||||
}
|
||||
return endpointTypes
|
||||
}
|
||||
@@ -4,7 +4,9 @@ import (
|
||||
"bytes"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"one-api/constant"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const KeyRequestBody = "key_request_body"
|
||||
@@ -42,3 +44,35 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
return nil
|
||||
}
|
||||
|
||||
func SetContextKey(c *gin.Context, key constant.ContextKey, value any) {
|
||||
c.Set(string(key), value)
|
||||
}
|
||||
|
||||
func GetContextKey(c *gin.Context, key constant.ContextKey) (any, bool) {
|
||||
return c.Get(string(key))
|
||||
}
|
||||
|
||||
func GetContextKeyString(c *gin.Context, key constant.ContextKey) string {
|
||||
return c.GetString(string(key))
|
||||
}
|
||||
|
||||
func GetContextKeyInt(c *gin.Context, key constant.ContextKey) int {
|
||||
return c.GetInt(string(key))
|
||||
}
|
||||
|
||||
func GetContextKeyBool(c *gin.Context, key constant.ContextKey) bool {
|
||||
return c.GetBool(string(key))
|
||||
}
|
||||
|
||||
func GetContextKeyStringSlice(c *gin.Context, key constant.ContextKey) []string {
|
||||
return c.GetStringSlice(string(key))
|
||||
}
|
||||
|
||||
func GetContextKeyStringMap(c *gin.Context, key constant.ContextKey) map[string]any {
|
||||
return c.GetStringMap(string(key))
|
||||
}
|
||||
|
||||
func GetContextKeyTime(c *gin.Context, key constant.ContextKey) time.Time {
|
||||
return c.GetTime(string(key))
|
||||
}
|
||||
|
||||
+23
-1
@@ -4,6 +4,7 @@ import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"one-api/constant"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
@@ -24,7 +25,7 @@ func printHelp() {
|
||||
fmt.Println("Usage: one-api [--port <port>] [--log-dir <log directory>] [--version] [--help]")
|
||||
}
|
||||
|
||||
func InitCommonEnv() {
|
||||
func InitEnv() {
|
||||
flag.Parse()
|
||||
|
||||
if *PrintVersion {
|
||||
@@ -95,4 +96,25 @@ func InitCommonEnv() {
|
||||
GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true)
|
||||
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
|
||||
GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180))
|
||||
|
||||
initConstantEnv()
|
||||
}
|
||||
|
||||
func initConstantEnv() {
|
||||
constant.StreamingTimeout = GetEnvOrDefault("STREAMING_TIMEOUT", 120)
|
||||
constant.DifyDebug = GetEnvOrDefaultBool("DIFY_DEBUG", true)
|
||||
constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
|
||||
// ForceStreamOption 覆盖请求参数,强制返回usage信息
|
||||
constant.ForceStreamOption = GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
|
||||
constant.GetMediaToken = GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
|
||||
constant.GetMediaTokenNotStream = GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
|
||||
constant.UpdateTask = GetEnvOrDefaultBool("UPDATE_TASK", true)
|
||||
constant.AzureDefaultAPIVersion = GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview")
|
||||
constant.GeminiVisionMaxImageNum = GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
|
||||
constant.NotifyLimitCount = GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
|
||||
constant.NotificationLimitDurationMinute = GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
|
||||
// GenerateDefaultToken 是否生成初始令牌,默认关闭。
|
||||
constant.GenerateDefaultToken = GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
|
||||
// 是否启用错误日志
|
||||
constant.ErrorLogEnabled = GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
package common
|
||||
|
||||
import "strings"
|
||||
|
||||
var (
|
||||
// OpenAIResponseOnlyModels is a list of models that are only available for OpenAI responses.
|
||||
OpenAIResponseOnlyModels = []string{
|
||||
"o3-pro",
|
||||
"o3-deep-research",
|
||||
"o4-mini-deep-research",
|
||||
}
|
||||
ImageGenerationModels = []string{
|
||||
"dall-e-3",
|
||||
"dall-e-2",
|
||||
"gpt-image-1",
|
||||
"prefix:imagen-",
|
||||
"flux-",
|
||||
"flux.1-",
|
||||
}
|
||||
)
|
||||
|
||||
func IsOpenAIResponseOnlyModel(modelName string) bool {
|
||||
for _, m := range OpenAIResponseOnlyModels {
|
||||
if strings.Contains(modelName, m) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func IsImageGenerationModel(modelName string) bool {
|
||||
modelName = strings.ToLower(modelName)
|
||||
for _, m := range ImageGenerationModels {
|
||||
if strings.Contains(modelName, m) {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(m, "prefix:") && strings.HasPrefix(modelName, strings.TrimPrefix(m, "prefix:")) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -16,6 +16,10 @@ import (
|
||||
var RDB *redis.Client
|
||||
var RedisEnabled = true
|
||||
|
||||
func RedisKeyCacheSeconds() int {
|
||||
return SyncFrequency
|
||||
}
|
||||
|
||||
// InitRedisClient This function is called after init()
|
||||
func InitRedisClient() (err error) {
|
||||
if os.Getenv("REDIS_CONN_STRING") == "" {
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
# constant 包 (`/constant`)
|
||||
|
||||
该目录仅用于放置全局可复用的**常量定义**,不包含任何业务逻辑或依赖关系。
|
||||
|
||||
## 当前文件
|
||||
|
||||
| 文件 | 说明 |
|
||||
|----------------------|---------------------------------------------------------------------|
|
||||
| `azure.go` | 定义与 Azure 相关的全局常量,如 `AzureNoRemoveDotTime`(控制删除 `.` 的截止时间)。 |
|
||||
| `cache_key.go` | 缓存键格式字符串及 Token 相关字段常量,统一缓存命名规则。 |
|
||||
| `channel_setting.go` | Channel 级别的设置键,如 `proxy`、`force_format` 等。 |
|
||||
| `context_key.go` | 定义 `ContextKey` 类型以及在整个项目中使用的上下文键常量(请求时间、Token/Channel/User 相关信息等)。 |
|
||||
| `env.go` | 环境配置相关的全局变量,在启动阶段根据配置文件或环境变量注入。 |
|
||||
| `finish_reason.go` | OpenAI/GPT 请求返回的 `finish_reason` 字符串常量集合。 |
|
||||
| `midjourney.go` | Midjourney 相关错误码及动作(Action)常量与模型到动作的映射表。 |
|
||||
| `setup.go` | 标识项目是否已完成初始化安装 (`Setup` 布尔值)。 |
|
||||
| `task.go` | 各种任务(Task)平台、动作常量及模型与动作映射表,如 Suno、Midjourney 等。 |
|
||||
| `user_setting.go` | 用户设置相关键常量以及通知类型(Email/Webhook)等。 |
|
||||
|
||||
## 使用约定
|
||||
|
||||
1. `constant` 包**只能被其他包引用**(import),**禁止在此包中引用项目内的其他自定义包**。如确有需要,仅允许引用 **Go 标准库**。
|
||||
2. 不允许在此目录内编写任何与业务流程、数据库操作、第三方服务调用等相关的逻辑代码。
|
||||
3. 新增类型时,请保持命名语义清晰,并在本 README 的 **当前文件** 表格中补充说明,确保团队成员能够快速了解其用途。
|
||||
|
||||
> ⚠️ 违反以上约定将导致包之间产生不必要的耦合,影响代码可维护性与可测试性。请在提交代码前自行检查。
|
||||
@@ -0,0 +1,34 @@
|
||||
package constant
|
||||
|
||||
const (
|
||||
APITypeOpenAI = iota
|
||||
APITypeAnthropic
|
||||
APITypePaLM
|
||||
APITypeBaidu
|
||||
APITypeZhipu
|
||||
APITypeAli
|
||||
APITypeXunfei
|
||||
APITypeAIProxyLibrary
|
||||
APITypeTencent
|
||||
APITypeGemini
|
||||
APITypeZhipuV4
|
||||
APITypeOllama
|
||||
APITypePerplexity
|
||||
APITypeAws
|
||||
APITypeCohere
|
||||
APITypeDify
|
||||
APITypeJina
|
||||
APITypeCloudflare
|
||||
APITypeSiliconFlow
|
||||
APITypeVertexAi
|
||||
APITypeMistral
|
||||
APITypeDeepSeek
|
||||
APITypeMokaAI
|
||||
APITypeVolcEngine
|
||||
APITypeBaiduV2
|
||||
APITypeOpenRouter
|
||||
APITypeXinference
|
||||
APITypeXai
|
||||
APITypeCoze
|
||||
APITypeDummy // this one is only for count, do not add any channel after this
|
||||
)
|
||||
@@ -1,12 +1,5 @@
|
||||
package constant
|
||||
|
||||
import "one-api/common"
|
||||
|
||||
// 使用函数来避免初始化顺序带来的赋值问题
|
||||
func RedisKeyCacheSeconds() int {
|
||||
return common.SyncFrequency
|
||||
}
|
||||
|
||||
// Cache keys
|
||||
const (
|
||||
UserGroupKeyFmt = "user_group:%d"
|
||||
|
||||
@@ -0,0 +1,109 @@
|
||||
package constant
|
||||
|
||||
const (
|
||||
ChannelTypeUnknown = 0
|
||||
ChannelTypeOpenAI = 1
|
||||
ChannelTypeMidjourney = 2
|
||||
ChannelTypeAzure = 3
|
||||
ChannelTypeOllama = 4
|
||||
ChannelTypeMidjourneyPlus = 5
|
||||
ChannelTypeOpenAIMax = 6
|
||||
ChannelTypeOhMyGPT = 7
|
||||
ChannelTypeCustom = 8
|
||||
ChannelTypeAILS = 9
|
||||
ChannelTypeAIProxy = 10
|
||||
ChannelTypePaLM = 11
|
||||
ChannelTypeAPI2GPT = 12
|
||||
ChannelTypeAIGC2D = 13
|
||||
ChannelTypeAnthropic = 14
|
||||
ChannelTypeBaidu = 15
|
||||
ChannelTypeZhipu = 16
|
||||
ChannelTypeAli = 17
|
||||
ChannelTypeXunfei = 18
|
||||
ChannelType360 = 19
|
||||
ChannelTypeOpenRouter = 20
|
||||
ChannelTypeAIProxyLibrary = 21
|
||||
ChannelTypeFastGPT = 22
|
||||
ChannelTypeTencent = 23
|
||||
ChannelTypeGemini = 24
|
||||
ChannelTypeMoonshot = 25
|
||||
ChannelTypeZhipu_v4 = 26
|
||||
ChannelTypePerplexity = 27
|
||||
ChannelTypeLingYiWanWu = 31
|
||||
ChannelTypeAws = 33
|
||||
ChannelTypeCohere = 34
|
||||
ChannelTypeMiniMax = 35
|
||||
ChannelTypeSunoAPI = 36
|
||||
ChannelTypeDify = 37
|
||||
ChannelTypeJina = 38
|
||||
ChannelCloudflare = 39
|
||||
ChannelTypeSiliconFlow = 40
|
||||
ChannelTypeVertexAi = 41
|
||||
ChannelTypeMistral = 42
|
||||
ChannelTypeDeepSeek = 43
|
||||
ChannelTypeMokaAI = 44
|
||||
ChannelTypeVolcEngine = 45
|
||||
ChannelTypeBaiduV2 = 46
|
||||
ChannelTypeXinference = 47
|
||||
ChannelTypeXai = 48
|
||||
ChannelTypeCoze = 49
|
||||
ChannelTypeKling = 50
|
||||
ChannelTypeJimeng = 51
|
||||
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
||||
|
||||
)
|
||||
|
||||
var ChannelBaseURLs = []string{
|
||||
"", // 0
|
||||
"https://api.openai.com", // 1
|
||||
"https://oa.api2d.net", // 2
|
||||
"", // 3
|
||||
"http://localhost:11434", // 4
|
||||
"https://api.openai-sb.com", // 5
|
||||
"https://api.openaimax.com", // 6
|
||||
"https://api.ohmygpt.com", // 7
|
||||
"", // 8
|
||||
"https://api.caipacity.com", // 9
|
||||
"https://api.aiproxy.io", // 10
|
||||
"", // 11
|
||||
"https://api.api2gpt.com", // 12
|
||||
"https://api.aigc2d.com", // 13
|
||||
"https://api.anthropic.com", // 14
|
||||
"https://aip.baidubce.com", // 15
|
||||
"https://open.bigmodel.cn", // 16
|
||||
"https://dashscope.aliyuncs.com", // 17
|
||||
"", // 18
|
||||
"https://api.360.cn", // 19
|
||||
"https://openrouter.ai/api", // 20
|
||||
"https://api.aiproxy.io", // 21
|
||||
"https://fastgpt.run/api/openapi", // 22
|
||||
"https://hunyuan.tencentcloudapi.com", //23
|
||||
"https://generativelanguage.googleapis.com", //24
|
||||
"https://api.moonshot.cn", //25
|
||||
"https://open.bigmodel.cn", //26
|
||||
"https://api.perplexity.ai", //27
|
||||
"", //28
|
||||
"", //29
|
||||
"", //30
|
||||
"https://api.lingyiwanwu.com", //31
|
||||
"", //32
|
||||
"", //33
|
||||
"https://api.cohere.ai", //34
|
||||
"https://api.minimax.chat", //35
|
||||
"", //36
|
||||
"https://api.dify.ai", //37
|
||||
"https://api.jina.ai", //38
|
||||
"https://api.cloudflare.com", //39
|
||||
"https://api.siliconflow.cn", //40
|
||||
"", //41
|
||||
"https://api.mistral.ai", //42
|
||||
"https://api.deepseek.com", //43
|
||||
"https://api.moka.ai", //44
|
||||
"https://ark.cn-beijing.volces.com", //45
|
||||
"https://qianfan.baidubce.com", //46
|
||||
"", //47
|
||||
"https://api.x.ai", //48
|
||||
"https://api.coze.cn", //49
|
||||
"https://api.klingai.com", //50
|
||||
"https://visual.volcengineapi.com", //51
|
||||
}
|
||||
+31
-7
@@ -1,11 +1,35 @@
|
||||
package constant
|
||||
|
||||
type ContextKey string
|
||||
|
||||
const (
|
||||
ContextKeyRequestStartTime = "request_start_time"
|
||||
ContextKeyUserSetting = "user_setting"
|
||||
ContextKeyUserQuota = "user_quota"
|
||||
ContextKeyUserStatus = "user_status"
|
||||
ContextKeyUserEmail = "user_email"
|
||||
ContextKeyUserGroup = "user_group"
|
||||
ContextKeyUsingGroup = "group"
|
||||
ContextKeyOriginalModel ContextKey = "original_model"
|
||||
ContextKeyRequestStartTime ContextKey = "request_start_time"
|
||||
|
||||
/* token related keys */
|
||||
ContextKeyTokenUnlimited ContextKey = "token_unlimited_quota"
|
||||
ContextKeyTokenKey ContextKey = "token_key"
|
||||
ContextKeyTokenId ContextKey = "token_id"
|
||||
ContextKeyTokenGroup ContextKey = "token_group"
|
||||
ContextKeyTokenAllowIps ContextKey = "allow_ips"
|
||||
ContextKeyTokenSpecificChannelId ContextKey = "specific_channel_id"
|
||||
ContextKeyTokenModelLimitEnabled ContextKey = "token_model_limit_enabled"
|
||||
ContextKeyTokenModelLimit ContextKey = "token_model_limit"
|
||||
|
||||
/* channel related keys */
|
||||
ContextKeyBaseUrl ContextKey = "base_url"
|
||||
ContextKeyChannelType ContextKey = "channel_type"
|
||||
ContextKeyChannelId ContextKey = "channel_id"
|
||||
ContextKeyChannelSetting ContextKey = "channel_setting"
|
||||
ContextKeyParamOverride ContextKey = "param_override"
|
||||
|
||||
/* user related keys */
|
||||
ContextKeyUserId ContextKey = "id"
|
||||
ContextKeyUserSetting ContextKey = "user_setting"
|
||||
ContextKeyUserQuota ContextKey = "user_quota"
|
||||
ContextKeyUserStatus ContextKey = "user_status"
|
||||
ContextKeyUserEmail ContextKey = "user_email"
|
||||
ContextKeyUserGroup ContextKey = "user_group"
|
||||
ContextKeyUsingGroup ContextKey = "group"
|
||||
ContextKeyUserName ContextKey = "username"
|
||||
)
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
package constant
|
||||
|
||||
type EndpointType string
|
||||
|
||||
const (
|
||||
EndpointTypeOpenAI EndpointType = "openai"
|
||||
EndpointTypeOpenAIResponse EndpointType = "openai-response"
|
||||
EndpointTypeAnthropic EndpointType = "anthropic"
|
||||
EndpointTypeGemini EndpointType = "gemini"
|
||||
EndpointTypeJinaRerank EndpointType = "jina-rerank"
|
||||
EndpointTypeImageGeneration EndpointType = "image-generation"
|
||||
//EndpointTypeMidjourney EndpointType = "midjourney-proxy"
|
||||
//EndpointTypeSuno EndpointType = "suno-proxy"
|
||||
//EndpointTypeKling EndpointType = "kling"
|
||||
//EndpointTypeJimeng EndpointType = "jimeng"
|
||||
)
|
||||
@@ -1,9 +1,5 @@
|
||||
package constant
|
||||
|
||||
import (
|
||||
"one-api/common"
|
||||
)
|
||||
|
||||
var StreamingTimeout int
|
||||
var DifyDebug bool
|
||||
var MaxFileDownloadMB int
|
||||
@@ -17,39 +13,3 @@ var NotifyLimitCount int
|
||||
var NotificationLimitDurationMinute int
|
||||
var GenerateDefaultToken bool
|
||||
var ErrorLogEnabled bool
|
||||
|
||||
//var GeminiModelMap = map[string]string{
|
||||
// "gemini-1.0-pro": "v1",
|
||||
//}
|
||||
|
||||
func InitEnv() {
|
||||
StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 120)
|
||||
DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
|
||||
MaxFileDownloadMB = common.GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
|
||||
// ForceStreamOption 覆盖请求参数,强制返回usage信息
|
||||
ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
|
||||
GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
|
||||
GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
|
||||
UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
|
||||
AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview")
|
||||
GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
|
||||
NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
|
||||
NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
|
||||
// GenerateDefaultToken 是否生成初始令牌,默认关闭。
|
||||
GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
|
||||
// 是否启用错误日志
|
||||
ErrorLogEnabled = common.GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
|
||||
|
||||
//modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
|
||||
//if modelVersionMapStr == "" {
|
||||
// return
|
||||
//}
|
||||
//for _, pair := range strings.Split(modelVersionMapStr, ",") {
|
||||
// parts := strings.Split(pair, ":")
|
||||
// if len(parts) == 2 {
|
||||
// GeminiModelMap[parts[0]] = parts[1]
|
||||
// } else {
|
||||
// common.SysError(fmt.Sprintf("invalid model version map: %s", pair))
|
||||
// }
|
||||
//}
|
||||
}
|
||||
|
||||
@@ -12,6 +12,9 @@ const (
|
||||
const (
|
||||
SunoActionMusic = "MUSIC"
|
||||
SunoActionLyrics = "LYRICS"
|
||||
|
||||
TaskActionGenerate = "generate"
|
||||
TaskActionTextGenerate = "textGenerate"
|
||||
)
|
||||
|
||||
var SunoModel2Action = map[string]string{
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
@@ -341,34 +342,34 @@ func updateChannelMoonshotBalance(channel *model.Channel) (float64, error) {
|
||||
}
|
||||
|
||||
func updateChannelBalance(channel *model.Channel) (float64, error) {
|
||||
baseURL := common.ChannelBaseURLs[channel.Type]
|
||||
baseURL := constant.ChannelBaseURLs[channel.Type]
|
||||
if channel.GetBaseURL() == "" {
|
||||
channel.BaseURL = &baseURL
|
||||
}
|
||||
switch channel.Type {
|
||||
case common.ChannelTypeOpenAI:
|
||||
case constant.ChannelTypeOpenAI:
|
||||
if channel.GetBaseURL() != "" {
|
||||
baseURL = channel.GetBaseURL()
|
||||
}
|
||||
case common.ChannelTypeAzure:
|
||||
case constant.ChannelTypeAzure:
|
||||
return 0, errors.New("尚未实现")
|
||||
case common.ChannelTypeCustom:
|
||||
case constant.ChannelTypeCustom:
|
||||
baseURL = channel.GetBaseURL()
|
||||
//case common.ChannelTypeOpenAISB:
|
||||
// return updateChannelOpenAISBBalance(channel)
|
||||
case common.ChannelTypeAIProxy:
|
||||
case constant.ChannelTypeAIProxy:
|
||||
return updateChannelAIProxyBalance(channel)
|
||||
case common.ChannelTypeAPI2GPT:
|
||||
case constant.ChannelTypeAPI2GPT:
|
||||
return updateChannelAPI2GPTBalance(channel)
|
||||
case common.ChannelTypeAIGC2D:
|
||||
case constant.ChannelTypeAIGC2D:
|
||||
return updateChannelAIGC2DBalance(channel)
|
||||
case common.ChannelTypeSiliconFlow:
|
||||
case constant.ChannelTypeSiliconFlow:
|
||||
return updateChannelSiliconFlowBalance(channel)
|
||||
case common.ChannelTypeDeepSeek:
|
||||
case constant.ChannelTypeDeepSeek:
|
||||
return updateChannelDeepSeekBalance(channel)
|
||||
case common.ChannelTypeOpenRouter:
|
||||
case constant.ChannelTypeOpenRouter:
|
||||
return updateChannelOpenRouterBalance(channel)
|
||||
case common.ChannelTypeMoonshot:
|
||||
case constant.ChannelTypeMoonshot:
|
||||
return updateChannelMoonshotBalance(channel)
|
||||
default:
|
||||
return 0, errors.New("尚未实现")
|
||||
|
||||
+10
-10
@@ -11,12 +11,12 @@ import (
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/middleware"
|
||||
"one-api/model"
|
||||
"one-api/relay"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"strconv"
|
||||
@@ -31,19 +31,19 @@ import (
|
||||
|
||||
func testChannel(channel *model.Channel, testModel string) (err error, openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
|
||||
tik := time.Now()
|
||||
if channel.Type == common.ChannelTypeMidjourney {
|
||||
if channel.Type == constant.ChannelTypeMidjourney {
|
||||
return errors.New("midjourney channel test is not supported"), nil
|
||||
}
|
||||
if channel.Type == common.ChannelTypeMidjourneyPlus {
|
||||
return errors.New("midjourney plus channel test is not supported!!!"), nil
|
||||
if channel.Type == constant.ChannelTypeMidjourneyPlus {
|
||||
return errors.New("midjourney plus channel test is not supported"), nil
|
||||
}
|
||||
if channel.Type == common.ChannelTypeSunoAPI {
|
||||
if channel.Type == constant.ChannelTypeSunoAPI {
|
||||
return errors.New("suno channel test is not supported"), nil
|
||||
}
|
||||
if channel.Type == common.ChannelTypeKling {
|
||||
if channel.Type == constant.ChannelTypeKling {
|
||||
return errors.New("kling channel test is not supported"), nil
|
||||
}
|
||||
if channel.Type == common.ChannelTypeJimeng {
|
||||
if channel.Type == constant.ChannelTypeJimeng {
|
||||
return errors.New("jimeng channel test is not supported"), nil
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
@@ -56,7 +56,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
|
||||
strings.Contains(testModel, "bge-") || // bge 系列模型
|
||||
strings.Contains(testModel, "embed") ||
|
||||
channel.Type == common.ChannelTypeMokaAI { // 其他 embedding 模型
|
||||
channel.Type == constant.ChannelTypeMokaAI { // 其他 embedding 模型
|
||||
requestPath = "/v1/embeddings" // 修改请求路径
|
||||
}
|
||||
|
||||
@@ -102,7 +102,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
}
|
||||
testModel = info.UpstreamModelName
|
||||
|
||||
apiType, _ := constant.ChannelType2APIType(channel.Type)
|
||||
apiType, _ := common.ChannelType2APIType(channel.Type)
|
||||
adaptor := relay.GetAdaptor(apiType)
|
||||
if adaptor == nil {
|
||||
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
|
||||
@@ -202,7 +202,7 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
|
||||
testRequest.MaxTokens = 50
|
||||
}
|
||||
} else if strings.Contains(model, "gemini") {
|
||||
testRequest.MaxTokens = 300
|
||||
testRequest.MaxTokens = 3000
|
||||
} else {
|
||||
testRequest.MaxTokens = 10
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -125,7 +126,7 @@ func GetAllChannels(c *gin.Context) {
|
||||
order = "id desc"
|
||||
}
|
||||
|
||||
err := baseQuery.Order(order).Limit(pageSize).Offset((p-1)*pageSize).Omit("key").Find(&channelData).Error
|
||||
err := baseQuery.Order(order).Limit(pageSize).Offset((p - 1) * pageSize).Omit("key").Find(&channelData).Error
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
@@ -181,15 +182,15 @@ func FetchUpstreamModels(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
baseURL := common.ChannelBaseURLs[channel.Type]
|
||||
baseURL := constant.ChannelBaseURLs[channel.Type]
|
||||
if channel.GetBaseURL() != "" {
|
||||
baseURL = channel.GetBaseURL()
|
||||
}
|
||||
url := fmt.Sprintf("%s/v1/models", baseURL)
|
||||
switch channel.Type {
|
||||
case common.ChannelTypeGemini:
|
||||
case constant.ChannelTypeGemini:
|
||||
url = fmt.Sprintf("%s/v1beta/openai/models", baseURL)
|
||||
case common.ChannelTypeAli:
|
||||
case constant.ChannelTypeAli:
|
||||
url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
|
||||
}
|
||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||
@@ -213,7 +214,7 @@ func FetchUpstreamModels(c *gin.Context) {
|
||||
var ids []string
|
||||
for _, model := range result.Data {
|
||||
id := model.ID
|
||||
if channel.Type == common.ChannelTypeGemini {
|
||||
if channel.Type == constant.ChannelTypeGemini {
|
||||
id = strings.TrimPrefix(id, "models/")
|
||||
}
|
||||
ids = append(ids, id)
|
||||
@@ -388,7 +389,7 @@ func AddChannel(c *gin.Context) {
|
||||
}
|
||||
channel.CreatedTime = common.GetTimestamp()
|
||||
keys := strings.Split(channel.Key, "\n")
|
||||
if channel.Type == common.ChannelTypeVertexAi {
|
||||
if channel.Type == constant.ChannelTypeVertexAi {
|
||||
if channel.Other == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -613,7 +614,7 @@ func UpdateChannel(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if channel.Type == common.ChannelTypeVertexAi {
|
||||
if channel.Type == constant.ChannelTypeVertexAi {
|
||||
if channel.Other == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -668,7 +669,7 @@ func FetchModels(c *gin.Context) {
|
||||
|
||||
baseURL := req.BaseURL
|
||||
if baseURL == "" {
|
||||
baseURL = common.ChannelBaseURLs[req.Type]
|
||||
baseURL = constant.ChannelBaseURLs[req.Type]
|
||||
}
|
||||
|
||||
client := &http.Client{}
|
||||
|
||||
+52
-95
@@ -2,6 +2,7 @@ package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/samber/lo"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
@@ -14,10 +15,7 @@ import (
|
||||
"one-api/relay/channel/minimax"
|
||||
"one-api/relay/channel/moonshot"
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/models/list
|
||||
@@ -26,30 +24,10 @@ var openAIModels []dto.OpenAIModels
|
||||
var openAIModelsMap map[string]dto.OpenAIModels
|
||||
var channelId2Models map[int][]string
|
||||
|
||||
func getPermission() []dto.OpenAIModelPermission {
|
||||
var permission []dto.OpenAIModelPermission
|
||||
permission = append(permission, dto.OpenAIModelPermission{
|
||||
Id: "modelperm-LwHkVFn8AcMItP432fKKDIKJ",
|
||||
Object: "model_permission",
|
||||
Created: 1626777600,
|
||||
AllowCreateEngine: true,
|
||||
AllowSampling: true,
|
||||
AllowLogprobs: true,
|
||||
AllowSearchIndices: false,
|
||||
AllowView: true,
|
||||
AllowFineTuning: false,
|
||||
Organization: "*",
|
||||
Group: nil,
|
||||
IsBlocking: false,
|
||||
})
|
||||
return permission
|
||||
}
|
||||
|
||||
func init() {
|
||||
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
||||
permission := getPermission()
|
||||
for i := 0; i < relayconstant.APITypeDummy; i++ {
|
||||
if i == relayconstant.APITypeAIProxyLibrary {
|
||||
for i := 0; i < constant.APITypeDummy; i++ {
|
||||
if i == constant.APITypeAIProxyLibrary {
|
||||
continue
|
||||
}
|
||||
adaptor := relay.GetAdaptor(i)
|
||||
@@ -57,69 +35,51 @@ func init() {
|
||||
modelNames := adaptor.GetModelList()
|
||||
for _, modelName := range modelNames {
|
||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: channelName,
|
||||
Permission: permission,
|
||||
Root: modelName,
|
||||
Parent: nil,
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: channelName,
|
||||
})
|
||||
}
|
||||
}
|
||||
for _, modelName := range ai360.ModelList {
|
||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: ai360.ChannelName,
|
||||
Permission: permission,
|
||||
Root: modelName,
|
||||
Parent: nil,
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: ai360.ChannelName,
|
||||
})
|
||||
}
|
||||
for _, modelName := range moonshot.ModelList {
|
||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: moonshot.ChannelName,
|
||||
Permission: permission,
|
||||
Root: modelName,
|
||||
Parent: nil,
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: moonshot.ChannelName,
|
||||
})
|
||||
}
|
||||
for _, modelName := range lingyiwanwu.ModelList {
|
||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: lingyiwanwu.ChannelName,
|
||||
Permission: permission,
|
||||
Root: modelName,
|
||||
Parent: nil,
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: lingyiwanwu.ChannelName,
|
||||
})
|
||||
}
|
||||
for _, modelName := range minimax.ModelList {
|
||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: minimax.ChannelName,
|
||||
Permission: permission,
|
||||
Root: modelName,
|
||||
Parent: nil,
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: minimax.ChannelName,
|
||||
})
|
||||
}
|
||||
for modelName, _ := range constant.MidjourneyModel2Action {
|
||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "midjourney",
|
||||
Permission: permission,
|
||||
Root: modelName,
|
||||
Parent: nil,
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "midjourney",
|
||||
})
|
||||
}
|
||||
openAIModelsMap = make(map[string]dto.OpenAIModels)
|
||||
@@ -127,9 +87,9 @@ func init() {
|
||||
openAIModelsMap[aiModel.Id] = aiModel
|
||||
}
|
||||
channelId2Models = make(map[int][]string)
|
||||
for i := 1; i <= common.ChannelTypeDummy; i++ {
|
||||
apiType, success := relayconstant.ChannelType2APIType(i)
|
||||
if !success || apiType == relayconstant.APITypeAIProxyLibrary {
|
||||
for i := 1; i <= constant.ChannelTypeDummy; i++ {
|
||||
apiType, success := common.ChannelType2APIType(i)
|
||||
if !success || apiType == constant.APITypeAIProxyLibrary {
|
||||
continue
|
||||
}
|
||||
meta := &relaycommon.RelayInfo{ChannelType: i}
|
||||
@@ -144,11 +104,10 @@ func init() {
|
||||
|
||||
func ListModels(c *gin.Context) {
|
||||
userOpenAiModels := make([]dto.OpenAIModels, 0)
|
||||
permission := getPermission()
|
||||
|
||||
modelLimitEnable := c.GetBool("token_model_limit_enabled")
|
||||
modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
|
||||
if modelLimitEnable {
|
||||
s, ok := c.Get("token_model_limit")
|
||||
s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
|
||||
var tokenModelLimit map[string]bool
|
||||
if ok {
|
||||
tokenModelLimit = s.(map[string]bool)
|
||||
@@ -156,17 +115,16 @@ func ListModels(c *gin.Context) {
|
||||
tokenModelLimit = map[string]bool{}
|
||||
}
|
||||
for allowModel, _ := range tokenModelLimit {
|
||||
if _, ok := openAIModelsMap[allowModel]; ok {
|
||||
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[allowModel])
|
||||
if oaiModel, ok := openAIModelsMap[allowModel]; ok {
|
||||
oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(allowModel)
|
||||
userOpenAiModels = append(userOpenAiModels, oaiModel)
|
||||
} else {
|
||||
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
|
||||
Id: allowModel,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "custom",
|
||||
Permission: permission,
|
||||
Root: allowModel,
|
||||
Parent: nil,
|
||||
Id: allowModel,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "custom",
|
||||
SupportedEndpointTypes: model.GetModelSupportEndpointTypes(allowModel),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -181,14 +139,14 @@ func ListModels(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
group := userGroup
|
||||
tokenGroup := c.GetString("token_group")
|
||||
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
|
||||
if tokenGroup != "" {
|
||||
group = tokenGroup
|
||||
}
|
||||
var models []string
|
||||
if tokenGroup == "auto" {
|
||||
for _, autoGroup := range setting.AutoGroups {
|
||||
groupModels := model.GetGroupModels(autoGroup)
|
||||
groupModels := model.GetGroupEnabledModels(autoGroup)
|
||||
for _, g := range groupModels {
|
||||
if !common.StringsContains(models, g) {
|
||||
models = append(models, g)
|
||||
@@ -196,20 +154,19 @@ func ListModels(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
} else {
|
||||
models = model.GetGroupModels(group)
|
||||
models = model.GetGroupEnabledModels(group)
|
||||
}
|
||||
for _, s := range models {
|
||||
if _, ok := openAIModelsMap[s]; ok {
|
||||
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s])
|
||||
for _, modelName := range models {
|
||||
if oaiModel, ok := openAIModelsMap[modelName]; ok {
|
||||
oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(modelName)
|
||||
userOpenAiModels = append(userOpenAiModels, oaiModel)
|
||||
} else {
|
||||
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
|
||||
Id: s,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "custom",
|
||||
Permission: permission,
|
||||
Root: s,
|
||||
Parent: nil,
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "custom",
|
||||
SupportedEndpointTypes: model.GetModelSupportEndpointTypes(modelName),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,7 +65,7 @@ func Playground(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
|
||||
c.Set(constant.ContextKeyRequestStartTime, time.Now())
|
||||
common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
|
||||
|
||||
// Write user context to ensure acceptUnsetRatio is available
|
||||
userId := c.GetInt("id")
|
||||
|
||||
+4
-4
@@ -8,12 +8,12 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
constant2 "one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/middleware"
|
||||
"one-api/model"
|
||||
"one-api/relay"
|
||||
"one-api/relay/constant"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
@@ -69,7 +69,7 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
|
||||
}
|
||||
|
||||
func Relay(c *gin.Context) {
|
||||
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
||||
relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
group := c.GetString("group")
|
||||
originalModel := c.GetString("original_model")
|
||||
@@ -132,7 +132,7 @@ func WssRelay(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
||||
relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
group := c.GetString("group")
|
||||
//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
|
||||
@@ -295,7 +295,7 @@ func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retry
|
||||
}
|
||||
if openaiErr.StatusCode == http.StatusBadRequest {
|
||||
channelType := c.GetInt("channel_type")
|
||||
if channelType == common.ChannelTypeAnthropic {
|
||||
if channelType == constant.ChannelTypeAnthropic {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
||||
@@ -51,7 +51,7 @@ func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, cha
|
||||
}
|
||||
|
||||
func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
|
||||
baseURL := common.ChannelBaseURLs[channel.Type]
|
||||
baseURL := constant.ChannelBaseURLs[channel.Type]
|
||||
if channel.GetBaseURL() != "" {
|
||||
baseURL = channel.GetBaseURL()
|
||||
}
|
||||
|
||||
+1
-1
@@ -487,7 +487,7 @@ func GetUserModels(c *gin.Context) {
|
||||
groups := setting.GetUserUsableGroups(user.Group)
|
||||
var models []string
|
||||
for group := range groups {
|
||||
for _, g := range model.GetGroupModels(group) {
|
||||
for _, g := range model.GetGroupEnabledModels(group) {
|
||||
if !common.StringsContains(models, g) {
|
||||
models = append(models, g)
|
||||
}
|
||||
|
||||
+6
-21
@@ -1,26 +1,11 @@
|
||||
package dto
|
||||
|
||||
type OpenAIModelPermission struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int `json:"created"`
|
||||
AllowCreateEngine bool `json:"allow_create_engine"`
|
||||
AllowSampling bool `json:"allow_sampling"`
|
||||
AllowLogprobs bool `json:"allow_logprobs"`
|
||||
AllowSearchIndices bool `json:"allow_search_indices"`
|
||||
AllowView bool `json:"allow_view"`
|
||||
AllowFineTuning bool `json:"allow_fine_tuning"`
|
||||
Organization string `json:"organization"`
|
||||
Group *string `json:"group"`
|
||||
IsBlocking bool `json:"is_blocking"`
|
||||
}
|
||||
import "one-api/constant"
|
||||
|
||||
type OpenAIModels struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int `json:"created"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
Permission []OpenAIModelPermission `json:"permission"`
|
||||
Root string `json:"root"`
|
||||
Parent *string `json:"parent"`
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int `json:"created"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
|
||||
}
|
||||
|
||||
+1
-1
@@ -4,7 +4,7 @@ type RerankRequest struct {
|
||||
Documents []any `json:"documents"`
|
||||
Query string `json:"query"`
|
||||
Model string `json:"model"`
|
||||
TopN int `json:"top_n"`
|
||||
TopN int `json:"top_n,omitempty"`
|
||||
ReturnDocuments *bool `json:"return_documents,omitempty"`
|
||||
MaxChunkPerDoc int `json:"max_chunk_per_doc,omitempty"`
|
||||
OverLapTokens int `json:"overlap_tokens,omitempty"`
|
||||
|
||||
@@ -169,10 +169,8 @@ func InitResources() error {
|
||||
common.SysLog("No .env file found, using default environment variables. If needed, please create a .env file and set the relevant variables.")
|
||||
}
|
||||
|
||||
// 加载旧的(common)环境变量
|
||||
common.InitCommonEnv()
|
||||
// 加载constants的环境变量
|
||||
constant.InitEnv()
|
||||
// 加载环境变量
|
||||
common.InitEnv()
|
||||
|
||||
// Initialize model settings
|
||||
ratio_setting.InitRatioSettings()
|
||||
@@ -193,6 +191,9 @@ func InitResources() error {
|
||||
// Initialize options, should after model.InitDB()
|
||||
model.InitOptionMap()
|
||||
|
||||
// 初始化模型
|
||||
model.GetPricing()
|
||||
|
||||
// Initialize SQL Database
|
||||
err = model.InitLogDB()
|
||||
if err != nil {
|
||||
|
||||
+16
-16
@@ -25,7 +25,7 @@ type ModelRequest struct {
|
||||
|
||||
func Distribute() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
allowIpsMap := c.GetStringMap("allow_ips")
|
||||
allowIpsMap := common.GetContextKeyStringMap(c, constant.ContextKeyTokenAllowIps)
|
||||
if len(allowIpsMap) != 0 {
|
||||
clientIp := c.ClientIP()
|
||||
if _, ok := allowIpsMap[clientIp]; !ok {
|
||||
@@ -34,14 +34,14 @@ func Distribute() func(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
var channel *model.Channel
|
||||
channelId, ok := c.Get("specific_channel_id")
|
||||
channelId, ok := common.GetContextKey(c, constant.ContextKeyTokenSpecificChannelId)
|
||||
modelRequest, shouldSelectChannel, err := getModelRequest(c)
|
||||
if err != nil {
|
||||
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
|
||||
return
|
||||
}
|
||||
userGroup := c.GetString(constant.ContextKeyUserGroup)
|
||||
tokenGroup := c.GetString("token_group")
|
||||
userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
|
||||
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
|
||||
if tokenGroup != "" {
|
||||
// check common.UserUsableGroups[userGroup]
|
||||
if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
|
||||
@@ -57,7 +57,7 @@ func Distribute() func(c *gin.Context) {
|
||||
}
|
||||
userGroup = tokenGroup
|
||||
}
|
||||
c.Set(constant.ContextKeyUsingGroup, userGroup)
|
||||
common.SetContextKey(c, constant.ContextKeyUsingGroup, userGroup)
|
||||
if ok {
|
||||
id, err := strconv.Atoi(channelId.(string))
|
||||
if err != nil {
|
||||
@@ -76,9 +76,9 @@ func Distribute() func(c *gin.Context) {
|
||||
} else {
|
||||
// Select a channel for the user
|
||||
// check token model mapping
|
||||
modelLimitEnable := c.GetBool("token_model_limit_enabled")
|
||||
modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
|
||||
if modelLimitEnable {
|
||||
s, ok := c.Get("token_model_limit")
|
||||
s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
|
||||
var tokenModelLimit map[string]bool
|
||||
if ok {
|
||||
tokenModelLimit = s.(map[string]bool)
|
||||
@@ -121,7 +121,7 @@ func Distribute() func(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
}
|
||||
c.Set(constant.ContextKeyRequestStartTime, time.Now())
|
||||
common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
|
||||
SetupContextForSelectedChannel(c, channel, modelRequest.Model)
|
||||
c.Next()
|
||||
}
|
||||
@@ -261,21 +261,21 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
|
||||
c.Set("base_url", channel.GetBaseURL())
|
||||
// TODO: api_version统一
|
||||
switch channel.Type {
|
||||
case common.ChannelTypeAzure:
|
||||
case constant.ChannelTypeAzure:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeVertexAi:
|
||||
case constant.ChannelTypeVertexAi:
|
||||
c.Set("region", channel.Other)
|
||||
case common.ChannelTypeXunfei:
|
||||
case constant.ChannelTypeXunfei:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeGemini:
|
||||
case constant.ChannelTypeGemini:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeAli:
|
||||
case constant.ChannelTypeAli:
|
||||
c.Set("plugin", channel.Other)
|
||||
case common.ChannelCloudflare:
|
||||
case constant.ChannelCloudflare:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeMokaAI:
|
||||
case constant.ChannelTypeMokaAI:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeCoze:
|
||||
case constant.ChannelTypeCoze:
|
||||
c.Set("bot_id", channel.Other)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,9 +3,11 @@ package middleware
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func KlingRequestConvert() func(c *gin.Context) {
|
||||
@@ -35,7 +37,7 @@ func KlingRequestConvert() func(c *gin.Context) {
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(jsonData))
|
||||
c.Request.URL.Path = "/v1/video/generations"
|
||||
if image := originalReq["image"]; image == "" {
|
||||
c.Set("action", "textGenerate")
|
||||
c.Set("action", constant.TaskActionTextGenerate)
|
||||
}
|
||||
|
||||
// We have to reset the request body for the next handlers
|
||||
|
||||
@@ -177,9 +177,9 @@ func ModelRequestRateLimit() func(c *gin.Context) {
|
||||
successMaxCount := setting.ModelRequestRateLimitSuccessCount
|
||||
|
||||
// 获取分组
|
||||
group := c.GetString("token_group")
|
||||
group := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
|
||||
if group == "" {
|
||||
group = c.GetString(constant.ContextKeyUserGroup)
|
||||
group = common.GetContextKeyString(c, constant.ContextKeyUserGroup)
|
||||
}
|
||||
|
||||
//获取分组的限流配置
|
||||
|
||||
+20
-5
@@ -21,7 +21,22 @@ type Ability struct {
|
||||
Tag *string `json:"tag" gorm:"index"`
|
||||
}
|
||||
|
||||
func GetGroupModels(group string) []string {
|
||||
type AbilityWithChannel struct {
|
||||
Ability
|
||||
ChannelType int `json:"channel_type"`
|
||||
}
|
||||
|
||||
func GetAllEnableAbilityWithChannels() ([]AbilityWithChannel, error) {
|
||||
var abilities []AbilityWithChannel
|
||||
err := DB.Table("abilities").
|
||||
Select("abilities.*, channels.type as channel_type").
|
||||
Joins("left join channels on abilities.channel_id = channels.id").
|
||||
Where("abilities.enabled = ?", true).
|
||||
Scan(&abilities).Error
|
||||
return abilities, err
|
||||
}
|
||||
|
||||
func GetGroupEnabledModels(group string) []string {
|
||||
var models []string
|
||||
// Find distinct models
|
||||
DB.Table("abilities").Where(commonGroupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models)
|
||||
@@ -46,7 +61,7 @@ func getPriority(group string, model string, retry int) (int, error) {
|
||||
var priorities []int
|
||||
err := DB.Model(&Ability{}).
|
||||
Select("DISTINCT(priority)").
|
||||
Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, commonTrueVal).
|
||||
Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, true).
|
||||
Order("priority DESC"). // 按优先级降序排序
|
||||
Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中
|
||||
|
||||
@@ -72,14 +87,14 @@ func getPriority(group string, model string, retry int) (int, error) {
|
||||
}
|
||||
|
||||
func getChannelQuery(group string, model string, retry int) *gorm.DB {
|
||||
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, commonTrueVal)
|
||||
channelQuery := DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = (?)", group, model, commonTrueVal, maxPrioritySubQuery)
|
||||
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, true)
|
||||
channelQuery := DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = (?)", group, model, true, maxPrioritySubQuery)
|
||||
if retry != 0 {
|
||||
priority, err := getPriority(group, model, retry)
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Get priority failed: %s", err.Error()))
|
||||
} else {
|
||||
channelQuery = DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = ?", group, model, commonTrueVal, priority)
|
||||
channelQuery = DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = ?", group, model, true, priority)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+78
-32
@@ -1,20 +1,24 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/setting/ratio_setting"
|
||||
"one-api/types"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Pricing struct {
|
||||
ModelName string `json:"model_name"`
|
||||
QuotaType int `json:"quota_type"`
|
||||
ModelRatio float64 `json:"model_ratio"`
|
||||
ModelPrice float64 `json:"model_price"`
|
||||
OwnerBy string `json:"owner_by"`
|
||||
CompletionRatio float64 `json:"completion_ratio"`
|
||||
EnableGroup []string `json:"enable_groups,omitempty"`
|
||||
ModelName string `json:"model_name"`
|
||||
QuotaType int `json:"quota_type"`
|
||||
ModelRatio float64 `json:"model_ratio"`
|
||||
ModelPrice float64 `json:"model_price"`
|
||||
OwnerBy string `json:"owner_by"`
|
||||
CompletionRatio float64 `json:"completion_ratio"`
|
||||
EnableGroup []string `json:"enable_groups"`
|
||||
SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -23,47 +27,89 @@ var (
|
||||
updatePricingLock sync.Mutex
|
||||
)
|
||||
|
||||
func GetPricing() []Pricing {
|
||||
updatePricingLock.Lock()
|
||||
defer updatePricingLock.Unlock()
|
||||
var (
|
||||
modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
|
||||
modelSupportEndpointsLock = sync.RWMutex{}
|
||||
)
|
||||
|
||||
func GetPricing() []Pricing {
|
||||
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
|
||||
updatePricing()
|
||||
updatePricingLock.Lock()
|
||||
defer updatePricingLock.Unlock()
|
||||
// Double check after acquiring the lock
|
||||
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
|
||||
modelSupportEndpointsLock.Lock()
|
||||
defer modelSupportEndpointsLock.Unlock()
|
||||
updatePricing()
|
||||
}
|
||||
}
|
||||
//if group != "" {
|
||||
// userPricingMap := make([]Pricing, 0)
|
||||
// models := GetGroupModels(group)
|
||||
// for _, pricing := range pricingMap {
|
||||
// if !common.StringsContains(models, pricing.ModelName) {
|
||||
// pricing.Available = false
|
||||
// }
|
||||
// userPricingMap = append(userPricingMap, pricing)
|
||||
// }
|
||||
// return userPricingMap
|
||||
//}
|
||||
return pricingMap
|
||||
}
|
||||
|
||||
func GetModelSupportEndpointTypes(model string) []constant.EndpointType {
|
||||
if model == "" {
|
||||
return make([]constant.EndpointType, 0)
|
||||
}
|
||||
modelSupportEndpointsLock.RLock()
|
||||
defer modelSupportEndpointsLock.RUnlock()
|
||||
if endpoints, ok := modelSupportEndpointTypes[model]; ok {
|
||||
return endpoints
|
||||
}
|
||||
return make([]constant.EndpointType, 0)
|
||||
}
|
||||
|
||||
func updatePricing() {
|
||||
//modelRatios := common.GetModelRatios()
|
||||
enableAbilities := GetAllEnableAbilities()
|
||||
modelGroupsMap := make(map[string][]string)
|
||||
enableAbilities, err := GetAllEnableAbilityWithChannels()
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err))
|
||||
return
|
||||
}
|
||||
modelGroupsMap := make(map[string]*types.Set[string])
|
||||
|
||||
for _, ability := range enableAbilities {
|
||||
groups := modelGroupsMap[ability.Model]
|
||||
if groups == nil {
|
||||
groups = make([]string, 0)
|
||||
groups, ok := modelGroupsMap[ability.Model]
|
||||
if !ok {
|
||||
groups = types.NewSet[string]()
|
||||
modelGroupsMap[ability.Model] = groups
|
||||
}
|
||||
if !common.StringsContains(groups, ability.Group) {
|
||||
groups = append(groups, ability.Group)
|
||||
groups.Add(ability.Group)
|
||||
}
|
||||
|
||||
//这里使用切片而不是Set,因为一个模型可能支持多个端点类型,并且第一个端点是优先使用端点
|
||||
modelSupportEndpointsStr := make(map[string][]string)
|
||||
|
||||
for _, ability := range enableAbilities {
|
||||
endpoints, ok := modelSupportEndpointsStr[ability.Model]
|
||||
if !ok {
|
||||
endpoints = make([]string, 0)
|
||||
modelSupportEndpointsStr[ability.Model] = endpoints
|
||||
}
|
||||
modelGroupsMap[ability.Model] = groups
|
||||
channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model)
|
||||
for _, channelType := range channelTypes {
|
||||
if !common.StringsContains(endpoints, string(channelType)) {
|
||||
endpoints = append(endpoints, string(channelType))
|
||||
}
|
||||
}
|
||||
modelSupportEndpointsStr[ability.Model] = endpoints
|
||||
}
|
||||
|
||||
modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
|
||||
for model, endpoints := range modelSupportEndpointsStr {
|
||||
supportedEndpoints := make([]constant.EndpointType, 0)
|
||||
for _, endpointStr := range endpoints {
|
||||
endpointType := constant.EndpointType(endpointStr)
|
||||
supportedEndpoints = append(supportedEndpoints, endpointType)
|
||||
}
|
||||
modelSupportEndpointTypes[model] = supportedEndpoints
|
||||
}
|
||||
|
||||
pricingMap = make([]Pricing, 0)
|
||||
for model, groups := range modelGroupsMap {
|
||||
pricing := Pricing{
|
||||
ModelName: model,
|
||||
EnableGroup: groups,
|
||||
ModelName: model,
|
||||
EnableGroup: groups.Items(),
|
||||
SupportedEndpointTypes: modelSupportEndpointTypes[model],
|
||||
}
|
||||
modelPrice, findPrice := ratio_setting.GetModelPrice(model, false)
|
||||
if findPrice {
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
func cacheSetToken(token Token) error {
|
||||
key := common.GenerateHMAC(token.Key)
|
||||
token.Clean()
|
||||
err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(constant.RedisKeyCacheSeconds())*time.Second)
|
||||
err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(common.RedisKeyCacheSeconds())*time.Second)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
+7
-7
@@ -24,12 +24,12 @@ type UserBase struct {
|
||||
}
|
||||
|
||||
func (user *UserBase) WriteContext(c *gin.Context) {
|
||||
c.Set(constant.ContextKeyUserGroup, user.Group)
|
||||
c.Set(constant.ContextKeyUserQuota, user.Quota)
|
||||
c.Set(constant.ContextKeyUserStatus, user.Status)
|
||||
c.Set(constant.ContextKeyUserEmail, user.Email)
|
||||
c.Set("username", user.Username)
|
||||
c.Set(constant.ContextKeyUserSetting, user.GetSetting())
|
||||
common.SetContextKey(c, constant.ContextKeyUserGroup, user.Group)
|
||||
common.SetContextKey(c, constant.ContextKeyUserQuota, user.Quota)
|
||||
common.SetContextKey(c, constant.ContextKeyUserStatus, user.Status)
|
||||
common.SetContextKey(c, constant.ContextKeyUserEmail, user.Email)
|
||||
common.SetContextKey(c, constant.ContextKeyUserName, user.Username)
|
||||
common.SetContextKey(c, constant.ContextKeyUserSetting, user.GetSetting())
|
||||
}
|
||||
|
||||
func (user *UserBase) GetSetting() map[string]interface{} {
|
||||
@@ -70,7 +70,7 @@ func updateUserCache(user User) error {
|
||||
return common.RedisHSetObj(
|
||||
getUserCacheKey(user.Id),
|
||||
user.ToBaseUser(),
|
||||
time.Duration(constant.RedisKeyCacheSeconds())*time.Second,
|
||||
time.Duration(common.RedisKeyCacheSeconds())*time.Second,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
var fullRequestURL string
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeEmbeddings:
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", info.BaseUrl)
|
||||
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/embeddings", info.BaseUrl)
|
||||
case constant.RelayModeRerank:
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.BaseUrl)
|
||||
case constant.RelayModeImagesGenerations:
|
||||
@@ -82,7 +82,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||
return embeddingRequestOpenAI2Ali(request), nil
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
|
||||
@@ -39,31 +39,18 @@ func embeddingRequestOpenAI2Ali(request dto.EmbeddingRequest) *AliEmbeddingReque
|
||||
}
|
||||
|
||||
func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
var aliResponse AliEmbeddingResponse
|
||||
err := json.NewDecoder(resp.Body).Decode(&aliResponse)
|
||||
var fullTextResponse dto.OpenAIEmbeddingResponse
|
||||
err := json.NewDecoder(resp.Body).Decode(&fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
|
||||
if aliResponse.Code != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Message: aliResponse.Message,
|
||||
Type: aliResponse.Code,
|
||||
Param: aliResponse.RequestId,
|
||||
Code: aliResponse.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
|
||||
model := c.GetString("model")
|
||||
if model == "" {
|
||||
model = "text-embedding-v4"
|
||||
}
|
||||
fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse, model)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
|
||||
@@ -3,6 +3,7 @@ package jina
|
||||
var ModelList = []string{
|
||||
"jina-clip-v1",
|
||||
"jina-reranker-v2-base-multilingual",
|
||||
"jina-reranker-m0",
|
||||
}
|
||||
|
||||
var ChannelName = "jina"
|
||||
|
||||
@@ -9,8 +9,7 @@ import (
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"one-api/common"
|
||||
constant2 "one-api/constant"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/ai360"
|
||||
@@ -21,7 +20,7 @@ import (
|
||||
"one-api/relay/channel/xinference"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/common_handler"
|
||||
"one-api/relay/constant"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -54,7 +53,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
a.ChannelType = info.ChannelType
|
||||
|
||||
// initialize ThinkingContentInfo when thinking_to_content is enabled
|
||||
if think2Content, ok := info.ChannelSetting[constant2.ChannelSettingThinkingToContent].(bool); ok && think2Content {
|
||||
if think2Content, ok := info.ChannelSetting[constant.ChannelSettingThinkingToContent].(bool); ok && think2Content {
|
||||
info.ThinkingContentInfo = relaycommon.ThinkingContentInfo{
|
||||
IsFirstThinkingContent: true,
|
||||
SendLastThinkingContent: false,
|
||||
@@ -67,7 +66,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
||||
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
|
||||
}
|
||||
if info.RelayMode == constant.RelayModeRealtime {
|
||||
if info.RelayMode == relayconstant.RelayModeRealtime {
|
||||
if strings.HasPrefix(info.BaseUrl, "https://") {
|
||||
baseUrl := strings.TrimPrefix(info.BaseUrl, "https://")
|
||||
baseUrl = "wss://" + baseUrl
|
||||
@@ -79,10 +78,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
}
|
||||
}
|
||||
switch info.ChannelType {
|
||||
case common.ChannelTypeAzure:
|
||||
case constant.ChannelTypeAzure:
|
||||
apiVersion := info.ApiVersion
|
||||
if apiVersion == "" {
|
||||
apiVersion = constant2.AzureDefaultAPIVersion
|
||||
apiVersion = constant.AzureDefaultAPIVersion
|
||||
}
|
||||
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
|
||||
requestURL := strings.Split(info.RequestURLPath, "?")[0]
|
||||
@@ -90,25 +89,25 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
task := strings.TrimPrefix(requestURL, "/v1/")
|
||||
|
||||
// 特殊处理 responses API
|
||||
if info.RelayMode == constant.RelayModeResponses {
|
||||
if info.RelayMode == relayconstant.RelayModeResponses {
|
||||
requestURL = fmt.Sprintf("/openai/v1/responses?api-version=preview")
|
||||
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
|
||||
}
|
||||
|
||||
model_ := info.UpstreamModelName
|
||||
// 2025年5月10日后创建的渠道不移除.
|
||||
if info.ChannelCreateTime < constant2.AzureNoRemoveDotTime {
|
||||
if info.ChannelCreateTime < constant.AzureNoRemoveDotTime {
|
||||
model_ = strings.Replace(model_, ".", "", -1)
|
||||
}
|
||||
// https://github.com/songquanpeng/one-api/issues/67
|
||||
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
|
||||
if info.RelayMode == constant.RelayModeRealtime {
|
||||
if info.RelayMode == relayconstant.RelayModeRealtime {
|
||||
requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, apiVersion)
|
||||
}
|
||||
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
|
||||
case common.ChannelTypeMiniMax:
|
||||
case constant.ChannelTypeMiniMax:
|
||||
return minimax.GetRequestURL(info)
|
||||
case common.ChannelTypeCustom:
|
||||
case constant.ChannelTypeCustom:
|
||||
url := info.BaseUrl
|
||||
url = strings.Replace(url, "{model}", info.UpstreamModelName, -1)
|
||||
return url, nil
|
||||
@@ -119,14 +118,14 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, header)
|
||||
if info.ChannelType == common.ChannelTypeAzure {
|
||||
if info.ChannelType == constant.ChannelTypeAzure {
|
||||
header.Set("api-key", info.ApiKey)
|
||||
return nil
|
||||
}
|
||||
if info.ChannelType == common.ChannelTypeOpenAI && "" != info.Organization {
|
||||
if info.ChannelType == constant.ChannelTypeOpenAI && "" != info.Organization {
|
||||
header.Set("OpenAI-Organization", info.Organization)
|
||||
}
|
||||
if info.RelayMode == constant.RelayModeRealtime {
|
||||
if info.RelayMode == relayconstant.RelayModeRealtime {
|
||||
swp := c.Request.Header.Get("Sec-WebSocket-Protocol")
|
||||
if swp != "" {
|
||||
items := []string{
|
||||
@@ -145,7 +144,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *
|
||||
} else {
|
||||
header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
}
|
||||
if info.ChannelType == common.ChannelTypeOpenRouter {
|
||||
if info.ChannelType == constant.ChannelTypeOpenRouter {
|
||||
header.Set("HTTP-Referer", "https://github.com/Calcium-Ion/new-api")
|
||||
header.Set("X-Title", "New API")
|
||||
}
|
||||
@@ -156,10 +155,10 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
if info.ChannelType != common.ChannelTypeOpenAI && info.ChannelType != common.ChannelTypeAzure {
|
||||
if info.ChannelType != constant.ChannelTypeOpenAI && info.ChannelType != constant.ChannelTypeAzure {
|
||||
request.StreamOptions = nil
|
||||
}
|
||||
if info.ChannelType == common.ChannelTypeOpenRouter {
|
||||
if info.ChannelType == constant.ChannelTypeOpenRouter {
|
||||
if len(request.Usage) == 0 {
|
||||
request.Usage = json.RawMessage(`{"include":true}`)
|
||||
}
|
||||
@@ -205,7 +204,7 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
a.ResponseFormat = request.ResponseFormat
|
||||
if info.RelayMode == constant.RelayModeAudioSpeech {
|
||||
if info.RelayMode == relayconstant.RelayModeAudioSpeech {
|
||||
jsonData, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshalling object: %w", err)
|
||||
@@ -254,7 +253,7 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeImagesEdits:
|
||||
case relayconstant.RelayModeImagesEdits:
|
||||
|
||||
var requestBody bytes.Buffer
|
||||
writer := multipart.NewWriter(&requestBody)
|
||||
@@ -411,11 +410,11 @@ func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommo
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
if info.RelayMode == constant.RelayModeAudioTranscription ||
|
||||
info.RelayMode == constant.RelayModeAudioTranslation ||
|
||||
info.RelayMode == constant.RelayModeImagesEdits {
|
||||
if info.RelayMode == relayconstant.RelayModeAudioTranscription ||
|
||||
info.RelayMode == relayconstant.RelayModeAudioTranslation ||
|
||||
info.RelayMode == relayconstant.RelayModeImagesEdits {
|
||||
return channel.DoFormRequest(a, c, info, requestBody)
|
||||
} else if info.RelayMode == constant.RelayModeRealtime {
|
||||
} else if info.RelayMode == relayconstant.RelayModeRealtime {
|
||||
return channel.DoWssRequest(a, c, info, requestBody)
|
||||
} else {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
@@ -424,19 +423,19 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeRealtime:
|
||||
case relayconstant.RelayModeRealtime:
|
||||
err, usage = OpenaiRealtimeHandler(c, info)
|
||||
case constant.RelayModeAudioSpeech:
|
||||
case relayconstant.RelayModeAudioSpeech:
|
||||
err, usage = OpenaiTTSHandler(c, resp, info)
|
||||
case constant.RelayModeAudioTranslation:
|
||||
case relayconstant.RelayModeAudioTranslation:
|
||||
fallthrough
|
||||
case constant.RelayModeAudioTranscription:
|
||||
case relayconstant.RelayModeAudioTranscription:
|
||||
err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
|
||||
case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
|
||||
case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
|
||||
err, usage = OpenaiHandlerWithUsage(c, resp, info)
|
||||
case constant.RelayModeRerank:
|
||||
case relayconstant.RelayModeRerank:
|
||||
err, usage = common_handler.RerankHandler(c, info, resp)
|
||||
case constant.RelayModeResponses:
|
||||
case relayconstant.RelayModeResponses:
|
||||
if info.IsStream {
|
||||
err, usage = OaiResponsesStreamHandler(c, resp, info)
|
||||
} else {
|
||||
@@ -454,17 +453,17 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
switch a.ChannelType {
|
||||
case common.ChannelType360:
|
||||
case constant.ChannelType360:
|
||||
return ai360.ModelList
|
||||
case common.ChannelTypeMoonshot:
|
||||
case constant.ChannelTypeMoonshot:
|
||||
return moonshot.ModelList
|
||||
case common.ChannelTypeLingYiWanWu:
|
||||
case constant.ChannelTypeLingYiWanWu:
|
||||
return lingyiwanwu.ModelList
|
||||
case common.ChannelTypeMiniMax:
|
||||
case constant.ChannelTypeMiniMax:
|
||||
return minimax.ModelList
|
||||
case common.ChannelTypeXinference:
|
||||
case constant.ChannelTypeXinference:
|
||||
return xinference.ModelList
|
||||
case common.ChannelTypeOpenRouter:
|
||||
case constant.ChannelTypeOpenRouter:
|
||||
return openrouter.ModelList
|
||||
default:
|
||||
return ModelList
|
||||
@@ -473,17 +472,17 @@ func (a *Adaptor) GetModelList() []string {
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
switch a.ChannelType {
|
||||
case common.ChannelType360:
|
||||
case constant.ChannelType360:
|
||||
return ai360.ChannelName
|
||||
case common.ChannelTypeMoonshot:
|
||||
case constant.ChannelTypeMoonshot:
|
||||
return moonshot.ChannelName
|
||||
case common.ChannelTypeLingYiWanWu:
|
||||
case constant.ChannelTypeLingYiWanWu:
|
||||
return lingyiwanwu.ChannelName
|
||||
case common.ChannelTypeMiniMax:
|
||||
case constant.ChannelTypeMiniMax:
|
||||
return minimax.ChannelName
|
||||
case common.ChannelTypeXinference:
|
||||
case constant.ChannelTypeXinference:
|
||||
return xinference.ChannelName
|
||||
case common.ChannelTypeOpenRouter:
|
||||
case constant.ChannelTypeOpenRouter:
|
||||
return openrouter.ChannelName
|
||||
default:
|
||||
return ChannelName
|
||||
|
||||
@@ -168,7 +168,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
usage.CompletionTokens += toolCount * 7
|
||||
} else {
|
||||
if info.ChannelType == common.ChannelTypeDeepSeek {
|
||||
if info.ChannelType == constant.ChannelTypeDeepSeek {
|
||||
if usage.PromptCacheHitTokens != 0 {
|
||||
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
@@ -88,7 +89,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
||||
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
|
||||
// Accept only POST /v1/video/generations as "generate" action.
|
||||
action := "generate"
|
||||
action := constant.TaskActionGenerate
|
||||
info.Action = action
|
||||
|
||||
req := relaycommon.TaskSubmitReq{}
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
@@ -92,7 +93,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
||||
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
|
||||
// Accept only POST /v1/video/generations as "generate" action.
|
||||
action := "generate"
|
||||
action := constant.TaskActionGenerate
|
||||
info.Action = action
|
||||
|
||||
var req SubmitReq
|
||||
@@ -112,7 +113,7 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
|
||||
|
||||
// BuildRequestURL constructs the upstream URL.
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
|
||||
path := lo.Ternary(info.Action == "generate", "/v1/videos/image2video", "/v1/videos/text2video")
|
||||
path := lo.Ternary(info.Action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video")
|
||||
return fmt.Sprintf("%s%s", a.baseURL, path), nil
|
||||
}
|
||||
|
||||
@@ -198,7 +199,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid action")
|
||||
}
|
||||
path := lo.Ternary(action == "generate", "/v1/videos/image2video", "/v1/videos/text2video")
|
||||
path := lo.Ternary(action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video")
|
||||
url := fmt.Sprintf("%s%s/%s", baseUrl, path, taskID)
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
|
||||
+32
-32
@@ -113,17 +113,17 @@ type RelayInfo struct {
|
||||
|
||||
// 定义支持流式选项的通道类型
|
||||
var streamSupportedChannels = map[int]bool{
|
||||
common.ChannelTypeOpenAI: true,
|
||||
common.ChannelTypeAnthropic: true,
|
||||
common.ChannelTypeAws: true,
|
||||
common.ChannelTypeGemini: true,
|
||||
common.ChannelCloudflare: true,
|
||||
common.ChannelTypeAzure: true,
|
||||
common.ChannelTypeVolcEngine: true,
|
||||
common.ChannelTypeOllama: true,
|
||||
common.ChannelTypeXai: true,
|
||||
common.ChannelTypeDeepSeek: true,
|
||||
common.ChannelTypeBaiduV2: true,
|
||||
constant.ChannelTypeOpenAI: true,
|
||||
constant.ChannelTypeAnthropic: true,
|
||||
constant.ChannelTypeAws: true,
|
||||
constant.ChannelTypeGemini: true,
|
||||
constant.ChannelCloudflare: true,
|
||||
constant.ChannelTypeAzure: true,
|
||||
constant.ChannelTypeVolcEngine: true,
|
||||
constant.ChannelTypeOllama: true,
|
||||
constant.ChannelTypeXai: true,
|
||||
constant.ChannelTypeDeepSeek: true,
|
||||
constant.ChannelTypeBaiduV2: true,
|
||||
}
|
||||
|
||||
func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
|
||||
@@ -211,40 +211,40 @@ func GenRelayInfoImage(c *gin.Context) *RelayInfo {
|
||||
}
|
||||
|
||||
func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
channelType := c.GetInt("channel_type")
|
||||
channelId := c.GetInt("channel_id")
|
||||
channelSetting := c.GetStringMap("channel_setting")
|
||||
paramOverride := c.GetStringMap("param_override")
|
||||
channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
|
||||
channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId)
|
||||
channelSetting := common.GetContextKeyStringMap(c, constant.ContextKeyChannelSetting)
|
||||
paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyParamOverride)
|
||||
|
||||
tokenId := c.GetInt("token_id")
|
||||
tokenKey := c.GetString("token_key")
|
||||
userId := c.GetInt("id")
|
||||
tokenUnlimited := c.GetBool("token_unlimited_quota")
|
||||
startTime := c.GetTime(constant.ContextKeyRequestStartTime)
|
||||
tokenId := common.GetContextKeyInt(c, constant.ContextKeyTokenId)
|
||||
tokenKey := common.GetContextKeyString(c, constant.ContextKeyTokenKey)
|
||||
userId := common.GetContextKeyInt(c, constant.ContextKeyUserId)
|
||||
tokenUnlimited := common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited)
|
||||
startTime := common.GetContextKeyTime(c, constant.ContextKeyRequestStartTime)
|
||||
// firstResponseTime = time.Now() - 1 second
|
||||
|
||||
apiType, _ := relayconstant.ChannelType2APIType(channelType)
|
||||
apiType, _ := common.ChannelType2APIType(channelType)
|
||||
|
||||
info := &RelayInfo{
|
||||
UserQuota: c.GetInt(constant.ContextKeyUserQuota),
|
||||
UserSetting: c.GetStringMap(constant.ContextKeyUserSetting),
|
||||
UserEmail: c.GetString(constant.ContextKeyUserEmail),
|
||||
UserQuota: common.GetContextKeyInt(c, constant.ContextKeyUserQuota),
|
||||
UserSetting: common.GetContextKeyStringMap(c, constant.ContextKeyUserSetting),
|
||||
UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail),
|
||||
isFirstResponse: true,
|
||||
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
|
||||
BaseUrl: c.GetString("base_url"),
|
||||
BaseUrl: common.GetContextKeyString(c, constant.ContextKeyBaseUrl),
|
||||
RequestURLPath: c.Request.URL.String(),
|
||||
ChannelType: channelType,
|
||||
ChannelId: channelId,
|
||||
TokenId: tokenId,
|
||||
TokenKey: tokenKey,
|
||||
UserId: userId,
|
||||
UsingGroup: c.GetString(constant.ContextKeyUsingGroup),
|
||||
UserGroup: c.GetString(constant.ContextKeyUserGroup),
|
||||
UsingGroup: common.GetContextKeyString(c, constant.ContextKeyUsingGroup),
|
||||
UserGroup: common.GetContextKeyString(c, constant.ContextKeyUserGroup),
|
||||
TokenUnlimited: tokenUnlimited,
|
||||
StartTime: startTime,
|
||||
FirstResponseTime: startTime.Add(-time.Second),
|
||||
OriginModelName: c.GetString("original_model"),
|
||||
UpstreamModelName: c.GetString("original_model"),
|
||||
OriginModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
|
||||
UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
|
||||
//RecodeModelName: c.GetString("original_model"),
|
||||
IsModelMapped: false,
|
||||
ApiType: apiType,
|
||||
@@ -266,12 +266,12 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
info.RequestURLPath = "/v1" + info.RequestURLPath
|
||||
}
|
||||
if info.BaseUrl == "" {
|
||||
info.BaseUrl = common.ChannelBaseURLs[channelType]
|
||||
info.BaseUrl = constant.ChannelBaseURLs[channelType]
|
||||
}
|
||||
if info.ChannelType == common.ChannelTypeAzure {
|
||||
if info.ChannelType == constant.ChannelTypeAzure {
|
||||
info.ApiVersion = GetAPIVersion(c)
|
||||
}
|
||||
if info.ChannelType == common.ChannelTypeVertexAi {
|
||||
if info.ChannelType == constant.ChannelTypeVertexAi {
|
||||
info.ApiVersion = c.GetString("region")
|
||||
}
|
||||
if streamSupportedChannels[info.ChannelType] {
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
_ "image/gif"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -15,9 +15,9 @@ func GetFullRequestURL(baseURL string, requestURL string, channelType int) strin
|
||||
|
||||
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
|
||||
switch channelType {
|
||||
case common.ChannelTypeOpenAI:
|
||||
case constant.ChannelTypeOpenAI:
|
||||
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
|
||||
case common.ChannelTypeAzure:
|
||||
case constant.ChannelTypeAzure:
|
||||
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel/xinference"
|
||||
relaycommon "one-api/relay/common"
|
||||
@@ -21,7 +22,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
||||
println("reranker response body: ", string(responseBody))
|
||||
}
|
||||
var jinaResp dto.RerankResponse
|
||||
if info.ChannelType == common.ChannelTypeXinference {
|
||||
if info.ChannelType == constant.ChannelTypeXinference {
|
||||
var xinRerankResponse xinference.XinRerankResponse
|
||||
err = common.UnmarshalJson(responseBody, &xinRerankResponse)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,106 +0,0 @@
|
||||
package constant
|
||||
|
||||
import (
|
||||
"one-api/common"
|
||||
)
|
||||
|
||||
const (
|
||||
APITypeOpenAI = iota
|
||||
APITypeAnthropic
|
||||
APITypePaLM
|
||||
APITypeBaidu
|
||||
APITypeZhipu
|
||||
APITypeAli
|
||||
APITypeXunfei
|
||||
APITypeAIProxyLibrary
|
||||
APITypeTencent
|
||||
APITypeGemini
|
||||
APITypeZhipuV4
|
||||
APITypeOllama
|
||||
APITypePerplexity
|
||||
APITypeAws
|
||||
APITypeCohere
|
||||
APITypeDify
|
||||
APITypeJina
|
||||
APITypeCloudflare
|
||||
APITypeSiliconFlow
|
||||
APITypeVertexAi
|
||||
APITypeMistral
|
||||
APITypeDeepSeek
|
||||
APITypeMokaAI
|
||||
APITypeVolcEngine
|
||||
APITypeBaiduV2
|
||||
APITypeOpenRouter
|
||||
APITypeXinference
|
||||
APITypeXai
|
||||
APITypeCoze
|
||||
APITypeDummy // this one is only for count, do not add any channel after this
|
||||
)
|
||||
|
||||
func ChannelType2APIType(channelType int) (int, bool) {
|
||||
apiType := -1
|
||||
switch channelType {
|
||||
case common.ChannelTypeOpenAI:
|
||||
apiType = APITypeOpenAI
|
||||
case common.ChannelTypeAnthropic:
|
||||
apiType = APITypeAnthropic
|
||||
case common.ChannelTypeBaidu:
|
||||
apiType = APITypeBaidu
|
||||
case common.ChannelTypePaLM:
|
||||
apiType = APITypePaLM
|
||||
case common.ChannelTypeZhipu:
|
||||
apiType = APITypeZhipu
|
||||
case common.ChannelTypeAli:
|
||||
apiType = APITypeAli
|
||||
case common.ChannelTypeXunfei:
|
||||
apiType = APITypeXunfei
|
||||
case common.ChannelTypeAIProxyLibrary:
|
||||
apiType = APITypeAIProxyLibrary
|
||||
case common.ChannelTypeTencent:
|
||||
apiType = APITypeTencent
|
||||
case common.ChannelTypeGemini:
|
||||
apiType = APITypeGemini
|
||||
case common.ChannelTypeZhipu_v4:
|
||||
apiType = APITypeZhipuV4
|
||||
case common.ChannelTypeOllama:
|
||||
apiType = APITypeOllama
|
||||
case common.ChannelTypePerplexity:
|
||||
apiType = APITypePerplexity
|
||||
case common.ChannelTypeAws:
|
||||
apiType = APITypeAws
|
||||
case common.ChannelTypeCohere:
|
||||
apiType = APITypeCohere
|
||||
case common.ChannelTypeDify:
|
||||
apiType = APITypeDify
|
||||
case common.ChannelTypeJina:
|
||||
apiType = APITypeJina
|
||||
case common.ChannelCloudflare:
|
||||
apiType = APITypeCloudflare
|
||||
case common.ChannelTypeSiliconFlow:
|
||||
apiType = APITypeSiliconFlow
|
||||
case common.ChannelTypeVertexAi:
|
||||
apiType = APITypeVertexAi
|
||||
case common.ChannelTypeMistral:
|
||||
apiType = APITypeMistral
|
||||
case common.ChannelTypeDeepSeek:
|
||||
apiType = APITypeDeepSeek
|
||||
case common.ChannelTypeMokaAI:
|
||||
apiType = APITypeMokaAI
|
||||
case common.ChannelTypeVolcEngine:
|
||||
apiType = APITypeVolcEngine
|
||||
case common.ChannelTypeBaiduV2:
|
||||
apiType = APITypeBaiduV2
|
||||
case common.ChannelTypeOpenRouter:
|
||||
apiType = APITypeOpenRouter
|
||||
case common.ChannelTypeXinference:
|
||||
apiType = APITypeXinference
|
||||
case common.ChannelTypeXai:
|
||||
apiType = APITypeXai
|
||||
case common.ChannelTypeCoze:
|
||||
apiType = APITypeCoze
|
||||
}
|
||||
if apiType == -1 {
|
||||
return APITypeOpenAI, false
|
||||
}
|
||||
return apiType, true
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
relaycommon "one-api/relay/common"
|
||||
@@ -17,8 +18,6 @@ import (
|
||||
"one-api/setting"
|
||||
"strings"
|
||||
|
||||
"one-api/relay/constant"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"one-api/constant"
|
||||
commonconstant "one-api/constant"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/ali"
|
||||
@@ -32,7 +33,6 @@ import (
|
||||
"one-api/relay/channel/xunfei"
|
||||
"one-api/relay/channel/zhipu"
|
||||
"one-api/relay/channel/zhipu_4v"
|
||||
"one-api/relay/constant"
|
||||
)
|
||||
|
||||
func GetAdaptor(apiType int) channel.Adaptor {
|
||||
|
||||
@@ -78,12 +78,15 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
|
||||
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody := bytes.NewBuffer(jsonData)
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
if common.DebugEnabled {
|
||||
println(fmt.Sprintf("Rerank request body: %s", requestBody.String()))
|
||||
}
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
var httpResp *http.Response
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
|
||||
+2
-1
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
"one-api/setting/operation_setting"
|
||||
@@ -48,7 +49,7 @@ func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) b
|
||||
}
|
||||
if err.StatusCode == http.StatusForbidden {
|
||||
switch channelType {
|
||||
case common.ChannelTypeGemini:
|
||||
case constant.ChannelTypeGemini:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
+2
-1
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel/openrouter"
|
||||
relaycommon "one-api/relay/common"
|
||||
@@ -19,7 +20,7 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.Re
|
||||
Stream: claudeRequest.Stream,
|
||||
}
|
||||
|
||||
isOpenRouter := info.ChannelType == common.ChannelTypeOpenRouter
|
||||
isOpenRouter := info.ChannelType == constant.ChannelTypeOpenRouter
|
||||
|
||||
if claudeRequest.Thinking != nil && claudeRequest.Thinking.Type == "enabled" {
|
||||
if isOpenRouter {
|
||||
|
||||
+3
-3
@@ -6,7 +6,7 @@ import (
|
||||
"log"
|
||||
"math"
|
||||
"one-api/common"
|
||||
constant2 "one-api/constant"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
relaycommon "one-api/relay/common"
|
||||
@@ -232,7 +232,7 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
cacheCreationRatio := priceData.CacheCreationRatio
|
||||
cacheCreationTokens := usage.PromptTokensDetails.CachedCreationTokens
|
||||
|
||||
if relayInfo.ChannelType == common.ChannelTypeOpenRouter {
|
||||
if relayInfo.ChannelType == constant.ChannelTypeOpenRouter {
|
||||
promptTokens -= cacheTokens
|
||||
if cacheCreationTokens == 0 && priceData.CacheCreationRatio != 1 && usage.Cost != 0 {
|
||||
maybeCacheCreationTokens := CalcOpenRouterCacheCreateTokens(*usage, priceData)
|
||||
@@ -447,7 +447,7 @@ func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preCon
|
||||
gopool.Go(func() {
|
||||
userSetting := relayInfo.UserSetting
|
||||
threshold := common.QuotaRemindThreshold
|
||||
if userCustomThreshold, ok := userSetting[constant2.UserSettingQuotaWarningThreshold]; ok {
|
||||
if userCustomThreshold, ok := userSetting[constant.UserSettingQuotaWarningThreshold]; ok {
|
||||
threshold = int(userCustomThreshold.(float64))
|
||||
}
|
||||
|
||||
|
||||
@@ -101,7 +101,7 @@ func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, m
|
||||
if !constant.GetMediaToken {
|
||||
return 3 * baseTokens, nil
|
||||
}
|
||||
if info.ChannelType == common.ChannelTypeGemini || info.ChannelType == common.ChannelTypeVertexAi || info.ChannelType == common.ChannelTypeAnthropic {
|
||||
if info.ChannelType == constant.ChannelTypeGemini || info.ChannelType == constant.ChannelTypeVertexAi || info.ChannelType == constant.ChannelTypeAnthropic {
|
||||
return 3 * baseTokens, nil
|
||||
}
|
||||
var config image.Config
|
||||
@@ -172,9 +172,6 @@ func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenA
|
||||
}
|
||||
}
|
||||
toolTokens := CountTokenInput(countStr, request.Model)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
tkm += 8
|
||||
tkm += toolTokens
|
||||
}
|
||||
@@ -195,9 +192,6 @@ func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, erro
|
||||
// Count tokens in system message
|
||||
if request.System != "" {
|
||||
systemTokens := CountTokenInput(request.System, model)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
tkm += systemTokens
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
package types
|
||||
|
||||
type Set[T comparable] struct {
|
||||
items map[T]struct{}
|
||||
}
|
||||
|
||||
// NewSet 创建并返回一个新的 Set
|
||||
func NewSet[T comparable]() *Set[T] {
|
||||
return &Set[T]{
|
||||
items: make(map[T]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Set[T]) Add(item T) {
|
||||
s.items[item] = struct{}{}
|
||||
}
|
||||
|
||||
// Remove 从 Set 中移除一个元素
|
||||
func (s *Set[T]) Remove(item T) {
|
||||
delete(s.items, item)
|
||||
}
|
||||
|
||||
// Contains 检查 Set 是否包含某个元素
|
||||
func (s *Set[T]) Contains(item T) bool {
|
||||
_, exists := s.items[item]
|
||||
return exists
|
||||
}
|
||||
|
||||
// Len 返回 Set 中元素的数量
|
||||
func (s *Set[T]) Len() int {
|
||||
return len(s.items)
|
||||
}
|
||||
|
||||
// Items 返回 Set 中所有元素组成的切片
|
||||
// 注意:由于 map 的无序性,返回的切片元素顺序是随机的
|
||||
func (s *Set[T]) Items() []T {
|
||||
items := make([]T, 0, s.Len())
|
||||
for item := range s.items {
|
||||
items = append(items, item)
|
||||
}
|
||||
return items
|
||||
}
|
||||
@@ -197,7 +197,6 @@ const ChannelSelectorModal = forwardRef(({
|
||||
value={searchText}
|
||||
onChange={setSearchText}
|
||||
showClear
|
||||
className="!rounded-full"
|
||||
/>
|
||||
|
||||
<Table
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import React, { useContext, useEffect, useRef, useMemo, useState } from 'react';
|
||||
import { API, copy, showError, showInfo, showSuccess, getModelCategories, renderModelTag } from '../../helpers';
|
||||
import { API, copy, showError, showInfo, showSuccess, getModelCategories, renderModelTag, stringToColor } from '../../helpers';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import {
|
||||
@@ -106,6 +106,26 @@ const ModelPricing = () => {
|
||||
) : null;
|
||||
}
|
||||
|
||||
function renderSupportedEndpoints(endpoints) {
|
||||
if (!endpoints || endpoints.length === 0) {
|
||||
return null;
|
||||
}
|
||||
return (
|
||||
<Space wrap>
|
||||
{endpoints.map((endpoint, idx) => (
|
||||
<Tag
|
||||
key={endpoint}
|
||||
color={stringToColor(endpoint)}
|
||||
size='large'
|
||||
shape='circle'
|
||||
>
|
||||
{endpoint}
|
||||
</Tag>
|
||||
))}
|
||||
</Space>
|
||||
);
|
||||
}
|
||||
|
||||
const columns = [
|
||||
{
|
||||
title: t('可用性'),
|
||||
@@ -120,6 +140,13 @@ const ModelPricing = () => {
|
||||
},
|
||||
defaultSortOrder: 'descend',
|
||||
},
|
||||
{
|
||||
title: t('可用端点类型'),
|
||||
dataIndex: 'supported_endpoint_types',
|
||||
render: (text, record, index) => {
|
||||
return renderSupportedEndpoints(text);
|
||||
},
|
||||
},
|
||||
{
|
||||
title: t('模型名称'),
|
||||
dataIndex: 'model_name',
|
||||
@@ -499,7 +526,7 @@ const ModelPricing = () => {
|
||||
<div className="flex items-center">
|
||||
<AlertCircle size={14} className="mr-1.5 flex-shrink-0" />
|
||||
<span className="truncate">
|
||||
{t('未登录,使用默认分组倍率')}: {groupRatio['default']}
|
||||
{t('未登录,使用默认分组倍率:')}{groupRatio['default']}
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
@@ -49,6 +49,7 @@ import {
|
||||
IconSearch,
|
||||
} from '@douyinfe/semi-icons';
|
||||
import { useTableCompactMode } from '../../hooks/useTableCompactMode';
|
||||
import { TASK_ACTION_GENERATE, TASK_ACTION_TEXT_GENERATE } from '../../constants/common.constant';
|
||||
|
||||
const { Text } = Typography;
|
||||
|
||||
@@ -207,13 +208,13 @@ const LogsTable = () => {
|
||||
{t('生成歌词')}
|
||||
</Tag>
|
||||
);
|
||||
case 'generate':
|
||||
case TASK_ACTION_GENERATE:
|
||||
return (
|
||||
<Tag color='blue' size='large' shape='circle' prefixIcon={<Sparkles size={14} />}>
|
||||
{t('图生视频')}
|
||||
</Tag>
|
||||
);
|
||||
case 'textGenerate':
|
||||
case TASK_ACTION_TEXT_GENERATE:
|
||||
return (
|
||||
<Tag color='blue' size='large' shape='circle' prefixIcon={<Sparkles size={14} />}>
|
||||
{t('文生视频')}
|
||||
@@ -444,7 +445,7 @@ const LogsTable = () => {
|
||||
fixed: 'right',
|
||||
render: (text, record, index) => {
|
||||
// 仅当为视频生成任务且成功,且 fail_reason 是 URL 时显示可点击链接
|
||||
const isVideoTask = record.action === 'generate' || record.action === 'textGenerate';
|
||||
const isVideoTask = record.action === TASK_ACTION_GENERATE || record.action === TASK_ACTION_TEXT_GENERATE;
|
||||
const isSuccess = record.status === 'SUCCESS';
|
||||
const isUrl = typeof text === 'string' && /^https?:\/\//.test(text);
|
||||
if (isSuccess && isVideoTask && isUrl) {
|
||||
|
||||
@@ -119,7 +119,7 @@ const UsersTable = () => {
|
||||
<Tooltip content={remark} position="top" showArrow>
|
||||
<Tag color='white' size='large' shape='circle' className="!text-xs">
|
||||
<div className="flex items-center gap-1">
|
||||
<div className="w-2 h-2 flex-shrink-0" style={{ backgroundColor: '#10b981' }} />
|
||||
<div className="w-2 h-2 flex-shrink-0 rounded-full" style={{ backgroundColor: '#10b981' }} />
|
||||
{displayRemark}
|
||||
</div>
|
||||
</Tag>
|
||||
|
||||
@@ -17,4 +17,7 @@ export const API_ENDPOINTS = [
|
||||
'/v1/audio/speech',
|
||||
'/v1/audio/transcriptions',
|
||||
'/v1/audio/translations'
|
||||
];
|
||||
];
|
||||
|
||||
export const TASK_ACTION_GENERATE = 'generate';
|
||||
export const TASK_ACTION_TEXT_GENERATE = 'textGenerate';
|
||||
@@ -876,7 +876,7 @@
|
||||
"加载token失败": "Failed to load token",
|
||||
"配置聊天": "Configure chat",
|
||||
"模型消耗分布": "Model consumption distribution",
|
||||
"模型调用次数占比": "Proportion of model calls",
|
||||
"模型调用次数占比": "Model call ratio",
|
||||
"用户消耗分布": "User consumption distribution",
|
||||
"时间粒度": "Time granularity",
|
||||
"天": "day",
|
||||
@@ -1119,6 +1119,10 @@
|
||||
"平均TPM": "Average TPM",
|
||||
"消耗分布": "Consumption distribution",
|
||||
"调用次数分布": "Models call distribution",
|
||||
"消耗趋势": "Consumption trend",
|
||||
"模型消耗趋势": "Model consumption trend",
|
||||
"调用次数排行": "Models call ranking",
|
||||
"模型调用次数排行": "Model call ranking",
|
||||
"添加渠道": "Add channel",
|
||||
"测试所有通道": "Test all channels",
|
||||
"删除禁用通道": "Delete disabled channels",
|
||||
@@ -1143,8 +1147,8 @@
|
||||
"默认测试模型": "Default Test Model",
|
||||
"不填则为模型列表第一个": "First model in list if empty",
|
||||
"是否自动禁用(仅当自动禁用开启时有效),关闭后不会自动禁用该渠道": "Auto-disable (only effective when auto-disable is enabled). When turned off, this channel will not be automatically disabled",
|
||||
"状态码复写(仅影响本地判断,不修改返回到上游的状态码)": "Status Code Override (only affects local judgment, does not modify status code returned upstream)",
|
||||
"此项可选,用于复写返回的状态码,比如将claude渠道的400错误复写为500(用于重试),请勿滥用该功能,例如:": "Optional, used to override returned status codes, e.g. rewriting Claude channel's 400 error to 500 (for retry). Do not abuse this feature. Example:",
|
||||
"状态码复写": "Status Code Override",
|
||||
"此项可选,用于复写返回的状态码,仅影响本地判断,不修改返回到上游的状态码,比如将claude渠道的400错误复写为500(用于重试),请勿滥用该功能,例如:": "Optional, used to override returned status codes, only affects local judgment, does not modify status code returned upstream, e.g. rewriting Claude channel's 400 error to 500 (for retry). Do not abuse this feature. Example:",
|
||||
"渠道标签": "Channel Tag",
|
||||
"渠道优先级": "Channel Priority",
|
||||
"渠道权重": "Channel Weight",
|
||||
@@ -1199,7 +1203,7 @@
|
||||
"添加用户": "Add user",
|
||||
"角色": "Role",
|
||||
"已绑定的 Telegram 账户": "Bound Telegram account",
|
||||
"新额度": "New quota",
|
||||
"新额度:": "New quota: ",
|
||||
"需要添加的额度(支持负数)": "Need to add quota (supports negative numbers)",
|
||||
"此项只读,需要用户通过个人设置页面的相关绑定按钮进行绑定,不可直接修改": "Read-only, user's personal settings, and cannot be modified directly",
|
||||
"请输入新的密码,最短 8 位": "Please enter a new password, at least 8 characterss",
|
||||
@@ -1750,5 +1754,7 @@
|
||||
"批量创建时会在名称后自动添加随机后缀": "When creating in batches, a random suffix will be automatically added to the name",
|
||||
"额度必须大于0": "Quota must be greater than 0",
|
||||
"生成数量必须大于0": "Generation quantity must be greater than 0",
|
||||
"创建后可在编辑渠道时获取上游模型列表": "After creation, you can get the upstream model list when editing the channel"
|
||||
"创建后可在编辑渠道时获取上游模型列表": "After creation, you can get the upstream model list when editing the channel",
|
||||
"可用端点类型": "Supported endpoint types",
|
||||
"未登录,使用默认分组倍率:": "Not logged in, using default group ratio: "
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,4 @@
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import React, { useState, useEffect, useRef } from 'react';
|
||||
import {
|
||||
API,
|
||||
showError,
|
||||
@@ -11,15 +11,13 @@ import {
|
||||
SideSheet,
|
||||
Space,
|
||||
Button,
|
||||
Input,
|
||||
Typography,
|
||||
Spin,
|
||||
Select,
|
||||
Banner,
|
||||
TextArea,
|
||||
Card,
|
||||
Tag,
|
||||
Avatar,
|
||||
Form,
|
||||
} from '@douyinfe/semi-ui';
|
||||
import {
|
||||
IconSave,
|
||||
@@ -53,9 +51,14 @@ const EditTagModal = (props) => {
|
||||
models: [],
|
||||
};
|
||||
const [inputs, setInputs] = useState(originInputs);
|
||||
const formApiRef = useRef(null);
|
||||
const getInitValues = () => ({ ...originInputs });
|
||||
|
||||
const handleInputChange = (name, value) => {
|
||||
setInputs((inputs) => ({ ...inputs, [name]: value }));
|
||||
if (formApiRef.current) {
|
||||
formApiRef.current.setValue(name, value);
|
||||
}
|
||||
if (name === 'type') {
|
||||
let localModels = [];
|
||||
switch (value) {
|
||||
@@ -133,27 +136,25 @@ const EditTagModal = (props) => {
|
||||
}
|
||||
};
|
||||
|
||||
const handleSave = async () => {
|
||||
const handleSave = async (values) => {
|
||||
setLoading(true);
|
||||
let data = {
|
||||
tag: tag,
|
||||
};
|
||||
if (inputs.model_mapping !== null && inputs.model_mapping !== '') {
|
||||
if (inputs.model_mapping !== '' && !verifyJSON(inputs.model_mapping)) {
|
||||
const formVals = values || formApiRef.current?.getValues() || {};
|
||||
let data = { tag };
|
||||
if (formVals.model_mapping) {
|
||||
if (!verifyJSON(formVals.model_mapping)) {
|
||||
showInfo('模型映射必须是合法的 JSON 格式!');
|
||||
setLoading(false);
|
||||
return;
|
||||
}
|
||||
data.model_mapping = inputs.model_mapping;
|
||||
data.model_mapping = formVals.model_mapping;
|
||||
}
|
||||
if (inputs.groups.length > 0) {
|
||||
data.groups = inputs.groups.join(',');
|
||||
if (formVals.groups && formVals.groups.length > 0) {
|
||||
data.groups = formVals.groups.join(',');
|
||||
}
|
||||
if (inputs.models.length > 0) {
|
||||
data.models = inputs.models.join(',');
|
||||
if (formVals.models && formVals.models.length > 0) {
|
||||
data.models = formVals.models.join(',');
|
||||
}
|
||||
data.new_tag = inputs.new_tag;
|
||||
// check have any change
|
||||
data.new_tag = formVals.new_tag;
|
||||
if (
|
||||
data.model_mapping === undefined &&
|
||||
data.groups === undefined &&
|
||||
@@ -202,7 +203,7 @@ const EditTagModal = (props) => {
|
||||
const res = await API.get(`/api/channel/tag/models?tag=${tag}`);
|
||||
if (res?.data?.success) {
|
||||
const models = res.data.data ? res.data.data.split(',') : [];
|
||||
setInputs((inputs) => ({ ...inputs, models: models }));
|
||||
handleInputChange('models', models);
|
||||
} else {
|
||||
showError(res.data.message);
|
||||
}
|
||||
@@ -213,19 +214,32 @@ const EditTagModal = (props) => {
|
||||
}
|
||||
};
|
||||
|
||||
fetchModels().then();
|
||||
fetchGroups().then();
|
||||
fetchTagModels().then();
|
||||
if (formApiRef.current) {
|
||||
formApiRef.current.setValues({
|
||||
...getInitValues(),
|
||||
tag: tag,
|
||||
new_tag: tag,
|
||||
});
|
||||
}
|
||||
|
||||
setInputs({
|
||||
...originInputs,
|
||||
tag: tag,
|
||||
new_tag: tag,
|
||||
});
|
||||
fetchModels().then();
|
||||
fetchGroups().then();
|
||||
fetchTagModels().then(); // Call the new function
|
||||
}, [visible, tag]); // Add tag to dependency array
|
||||
}, [visible, tag]);
|
||||
|
||||
useEffect(() => {
|
||||
if (formApiRef.current) {
|
||||
formApiRef.current.setValues(inputs);
|
||||
}
|
||||
}, [inputs]);
|
||||
|
||||
const addCustomModels = () => {
|
||||
if (customModel.trim() === '') return;
|
||||
// 使用逗号分隔字符串,然后去除每个模型名称前后的空格
|
||||
const modelArray = customModel.split(',').map((model) => model.trim());
|
||||
|
||||
let localModels = [...inputs.models];
|
||||
@@ -233,11 +247,9 @@ const EditTagModal = (props) => {
|
||||
const addedModels = [];
|
||||
|
||||
modelArray.forEach((model) => {
|
||||
// 检查模型是否已存在,且模型名称非空
|
||||
if (model && !localModels.includes(model)) {
|
||||
localModels.push(model); // 添加到模型列表
|
||||
localModels.push(model);
|
||||
localModelOptions.push({
|
||||
// 添加到下拉选项
|
||||
key: model,
|
||||
text: model,
|
||||
value: model,
|
||||
@@ -246,7 +258,6 @@ const EditTagModal = (props) => {
|
||||
}
|
||||
});
|
||||
|
||||
// 更新状态值
|
||||
setModelOptions(localModelOptions);
|
||||
setCustomModel('');
|
||||
handleInputChange('models', localModels);
|
||||
@@ -283,7 +294,7 @@ const EditTagModal = (props) => {
|
||||
<Space>
|
||||
<Button
|
||||
theme="solid"
|
||||
onClick={handleSave}
|
||||
onClick={() => formApiRef.current?.submitForm()}
|
||||
loading={loading}
|
||||
icon={<IconSave />}
|
||||
>
|
||||
@@ -302,146 +313,128 @@ const EditTagModal = (props) => {
|
||||
}
|
||||
closeIcon={null}
|
||||
>
|
||||
<Spin spinning={loading}>
|
||||
<div className="p-2">
|
||||
<Card className="!rounded-2xl shadow-sm border-0 mb-6">
|
||||
{/* Header: Tag Info */}
|
||||
<div className="flex items-center mb-2">
|
||||
<Avatar size="small" color="blue" className="mr-2 shadow-md">
|
||||
<IconBookmark size={16} />
|
||||
</Avatar>
|
||||
<div>
|
||||
<Text className="text-lg font-medium">{t('标签信息')}</Text>
|
||||
<div className="text-xs text-gray-600">{t('标签的基本配置')}</div>
|
||||
</div>
|
||||
</div>
|
||||
<Form
|
||||
key={tag || 'edit'}
|
||||
initValues={getInitValues()}
|
||||
getFormApi={(api) => (formApiRef.current = api)}
|
||||
onSubmit={handleSave}
|
||||
>
|
||||
{() => (
|
||||
<Spin spinning={loading}>
|
||||
<div className="p-2">
|
||||
<Card className="!rounded-2xl shadow-sm border-0 mb-6">
|
||||
{/* Header: Tag Info */}
|
||||
<div className="flex items-center mb-2">
|
||||
<Avatar size="small" color="blue" className="mr-2 shadow-md">
|
||||
<IconBookmark size={16} />
|
||||
</Avatar>
|
||||
<div>
|
||||
<Text className="text-lg font-medium">{t('标签信息')}</Text>
|
||||
<div className="text-xs text-gray-600">{t('标签的基本配置')}</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<Banner
|
||||
type="warning"
|
||||
description={t('所有编辑均为覆盖操作,留空则不更改')}
|
||||
className="!rounded-lg mb-4"
|
||||
/>
|
||||
|
||||
<div className="space-y-4">
|
||||
<div>
|
||||
<Text strong className="block mb-2">{t('标签名称')}</Text>
|
||||
<Input
|
||||
value={inputs.new_tag}
|
||||
onChange={(value) => setInputs({ ...inputs, new_tag: value })}
|
||||
placeholder={t('请输入新标签,留空则解散标签')}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</Card>
|
||||
|
||||
<Card className="!rounded-2xl shadow-sm border-0 mb-6">
|
||||
{/* Header: Model Config */}
|
||||
<div className="flex items-center mb-2">
|
||||
<Avatar size="small" color="purple" className="mr-2 shadow-md">
|
||||
<IconCode size={16} />
|
||||
</Avatar>
|
||||
<div>
|
||||
<Text className="text-lg font-medium">{t('模型配置')}</Text>
|
||||
<div className="text-xs text-gray-600">{t('模型选择和映射设置')}</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="space-y-4">
|
||||
<div>
|
||||
<Text strong className="block mb-2">{t('模型')}</Text>
|
||||
<Banner
|
||||
type="info"
|
||||
description={t('当前模型列表为该标签下所有渠道模型列表最长的一个,并非所有渠道的并集,请注意可能导致某些渠道模型丢失。')}
|
||||
type="warning"
|
||||
description={t('所有编辑均为覆盖操作,留空则不更改')}
|
||||
className="!rounded-lg mb-4"
|
||||
/>
|
||||
<Select
|
||||
placeholder={t('请选择该渠道所支持的模型,留空则不更改')}
|
||||
name='models'
|
||||
multiple
|
||||
filter
|
||||
searchPosition='dropdown'
|
||||
onChange={(value) => handleInputChange('models', value)}
|
||||
value={inputs.models}
|
||||
optionList={modelOptions}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<Input
|
||||
addonAfter={
|
||||
<Button type='primary' onClick={addCustomModels} className="!rounded-r-lg">
|
||||
{t('填入')}
|
||||
</Button>
|
||||
}
|
||||
placeholder={t('输入自定义模型名称')}
|
||||
value={customModel}
|
||||
onChange={(value) => setCustomModel(value.trim())}
|
||||
/>
|
||||
</div>
|
||||
<div className="space-y-4">
|
||||
<Form.Input
|
||||
field='new_tag'
|
||||
label={t('标签名称')}
|
||||
placeholder={t('请输入新标签,留空则解散标签')}
|
||||
onChange={(value) => handleInputChange('new_tag', value)}
|
||||
/>
|
||||
</div>
|
||||
</Card>
|
||||
|
||||
<div>
|
||||
<Text strong className="block mb-2">{t('模型重定向')}</Text>
|
||||
<TextArea
|
||||
placeholder={t('此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,留空则不更改')}
|
||||
name='model_mapping'
|
||||
onChange={(value) => handleInputChange('model_mapping', value)}
|
||||
autosize
|
||||
value={inputs.model_mapping}
|
||||
/>
|
||||
<Space className="mt-2">
|
||||
<Text
|
||||
className="!text-semi-color-primary cursor-pointer"
|
||||
onClick={() => handleInputChange('model_mapping', JSON.stringify(MODEL_MAPPING_EXAMPLE, null, 2))}
|
||||
>
|
||||
{t('填入模板')}
|
||||
</Text>
|
||||
<Text
|
||||
className="!text-semi-color-primary cursor-pointer"
|
||||
onClick={() => handleInputChange('model_mapping', JSON.stringify({}, null, 2))}
|
||||
>
|
||||
{t('清空重定向')}
|
||||
</Text>
|
||||
<Text
|
||||
className="!text-semi-color-primary cursor-pointer"
|
||||
onClick={() => handleInputChange('model_mapping', '')}
|
||||
>
|
||||
{t('不更改')}
|
||||
</Text>
|
||||
</Space>
|
||||
</div>
|
||||
<Card className="!rounded-2xl shadow-sm border-0 mb-6">
|
||||
{/* Header: Model Config */}
|
||||
<div className="flex items-center mb-2">
|
||||
<Avatar size="small" color="purple" className="mr-2 shadow-md">
|
||||
<IconCode size={16} />
|
||||
</Avatar>
|
||||
<div>
|
||||
<Text className="text-lg font-medium">{t('模型配置')}</Text>
|
||||
<div className="text-xs text-gray-600">{t('模型选择和映射设置')}</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="space-y-4">
|
||||
<Banner
|
||||
type="info"
|
||||
description={t('当前模型列表为该标签下所有渠道模型列表最长的一个,并非所有渠道的并集,请注意可能导致某些渠道模型丢失。')}
|
||||
className="!rounded-lg mb-4"
|
||||
/>
|
||||
<Form.Select
|
||||
field='models'
|
||||
label={t('模型')}
|
||||
placeholder={t('请选择该渠道所支持的模型,留空则不更改')}
|
||||
multiple
|
||||
filter
|
||||
searchPosition='dropdown'
|
||||
optionList={modelOptions}
|
||||
style={{ width: '100%' }}
|
||||
onChange={(value) => handleInputChange('models', value)}
|
||||
/>
|
||||
|
||||
<Form.Input
|
||||
field='custom_model'
|
||||
label={t('自定义模型名称')}
|
||||
placeholder={t('输入自定义模型名称')}
|
||||
onChange={(value) => setCustomModel(value.trim())}
|
||||
suffix={<Button size='small' type='primary' onClick={addCustomModels}>{t('填入')}</Button>}
|
||||
/>
|
||||
|
||||
<Form.TextArea
|
||||
field='model_mapping'
|
||||
label={t('模型重定向')}
|
||||
placeholder={t('此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,留空则不更改')}
|
||||
autosize
|
||||
onChange={(value) => handleInputChange('model_mapping', value)}
|
||||
extraText={(
|
||||
<Space>
|
||||
<Text className="!text-semi-color-primary cursor-pointer" onClick={() => handleInputChange('model_mapping', JSON.stringify(MODEL_MAPPING_EXAMPLE, null, 2))}>{t('填入模板')}</Text>
|
||||
<Text className="!text-semi-color-primary cursor-pointer" onClick={() => handleInputChange('model_mapping', JSON.stringify({}, null, 2))}>{t('清空重定向')}</Text>
|
||||
<Text className="!text-semi-color-primary cursor-pointer" onClick={() => handleInputChange('model_mapping', '')}>{t('不更改')}</Text>
|
||||
</Space>
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
</Card>
|
||||
|
||||
<Card className="!rounded-2xl shadow-sm border-0">
|
||||
{/* Header: Group Settings */}
|
||||
<div className="flex items-center mb-2">
|
||||
<Avatar size="small" color="green" className="mr-2 shadow-md">
|
||||
<IconUser size={16} />
|
||||
</Avatar>
|
||||
<div>
|
||||
<Text className="text-lg font-medium">{t('分组设置')}</Text>
|
||||
<div className="text-xs text-gray-600">{t('用户分组配置')}</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="space-y-4">
|
||||
<Form.Select
|
||||
field='groups'
|
||||
label={t('分组')}
|
||||
placeholder={t('请选择可以使用该渠道的分组,留空则不更改')}
|
||||
multiple
|
||||
allowAdditions
|
||||
additionLabel={t('请在系统设置页面编辑分组倍率以添加新的分组:')}
|
||||
optionList={groupOptions}
|
||||
style={{ width: '100%' }}
|
||||
onChange={(value) => handleInputChange('groups', value)}
|
||||
/>
|
||||
</div>
|
||||
</Card>
|
||||
</div>
|
||||
</Card>
|
||||
|
||||
<Card className="!rounded-2xl shadow-sm border-0">
|
||||
{/* Header: Group Settings */}
|
||||
<div className="flex items-center mb-2">
|
||||
<Avatar size="small" color="green" className="mr-2 shadow-md">
|
||||
<IconUser size={16} />
|
||||
</Avatar>
|
||||
<div>
|
||||
<Text className="text-lg font-medium">{t('分组设置')}</Text>
|
||||
<div className="text-xs text-gray-600">{t('用户分组配置')}</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="space-y-4">
|
||||
<div>
|
||||
<Text strong className="block mb-2">{t('分组')}</Text>
|
||||
<Select
|
||||
placeholder={t('请选择可以使用该渠道的分组,留空则不更改')}
|
||||
name='groups'
|
||||
multiple
|
||||
allowAdditions
|
||||
additionLabel={t('请在系统设置页面编辑分组倍率以添加新的分组:')}
|
||||
onChange={(value) => handleInputChange('groups', value)}
|
||||
value={inputs.groups}
|
||||
optionList={groupOptions}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</Card>
|
||||
</div>
|
||||
</Spin>
|
||||
</Spin>
|
||||
)}
|
||||
</Form>
|
||||
</SideSheet>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -366,6 +366,86 @@ const Detail = (props) => {
|
||||
},
|
||||
});
|
||||
|
||||
// 模型消耗趋势折线图
|
||||
const [spec_model_line, setSpecModelLine] = useState({
|
||||
type: 'line',
|
||||
data: [
|
||||
{
|
||||
id: 'lineData',
|
||||
values: [],
|
||||
},
|
||||
],
|
||||
xField: 'Time',
|
||||
yField: 'Count',
|
||||
seriesField: 'Model',
|
||||
legends: {
|
||||
visible: true,
|
||||
selectMode: 'single',
|
||||
},
|
||||
title: {
|
||||
visible: true,
|
||||
text: t('模型消耗趋势'),
|
||||
subtext: '',
|
||||
},
|
||||
tooltip: {
|
||||
mark: {
|
||||
content: [
|
||||
{
|
||||
key: (datum) => datum['Model'],
|
||||
value: (datum) => renderNumber(datum['Count']),
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
color: {
|
||||
specified: modelColorMap,
|
||||
},
|
||||
});
|
||||
|
||||
// 模型调用次数排行柱状图
|
||||
const [spec_rank_bar, setSpecRankBar] = useState({
|
||||
type: 'bar',
|
||||
data: [
|
||||
{
|
||||
id: 'rankData',
|
||||
values: [],
|
||||
},
|
||||
],
|
||||
xField: 'Model',
|
||||
yField: 'Count',
|
||||
seriesField: 'Model',
|
||||
legends: {
|
||||
visible: true,
|
||||
selectMode: 'single',
|
||||
},
|
||||
title: {
|
||||
visible: true,
|
||||
text: t('模型调用次数排行'),
|
||||
subtext: '',
|
||||
},
|
||||
bar: {
|
||||
state: {
|
||||
hover: {
|
||||
stroke: '#000',
|
||||
lineWidth: 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
tooltip: {
|
||||
mark: {
|
||||
content: [
|
||||
{
|
||||
key: (datum) => datum['Model'],
|
||||
value: (datum) => renderNumber(datum['Count']),
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
color: {
|
||||
specified: modelColorMap,
|
||||
},
|
||||
});
|
||||
|
||||
// ========== Hooks - Memoized Values ==========
|
||||
const performanceMetrics = useMemo(() => {
|
||||
const timeDiff = (Date.parse(end_timestamp) - Date.parse(start_timestamp)) / 60000;
|
||||
@@ -853,6 +933,46 @@ const Detail = (props) => {
|
||||
'barData'
|
||||
);
|
||||
|
||||
// ===== 模型调用次数折线图 =====
|
||||
let modelLineData = [];
|
||||
chartTimePoints.forEach((time) => {
|
||||
const timeData = Array.from(uniqueModels).map((model) => {
|
||||
const key = `${time}-${model}`;
|
||||
const aggregated = aggregatedData.get(key);
|
||||
return {
|
||||
Time: time,
|
||||
Model: model,
|
||||
Count: aggregated?.count || 0,
|
||||
};
|
||||
});
|
||||
modelLineData.push(...timeData);
|
||||
});
|
||||
modelLineData.sort((a, b) => a.Time.localeCompare(b.Time));
|
||||
|
||||
// ===== 模型调用次数排行柱状图 =====
|
||||
const rankData = Array.from(modelTotals)
|
||||
.map(([model, count]) => ({
|
||||
Model: model,
|
||||
Count: count,
|
||||
}))
|
||||
.sort((a, b) => b.Count - a.Count);
|
||||
|
||||
updateChartSpec(
|
||||
setSpecModelLine,
|
||||
modelLineData,
|
||||
`${t('总计')}:${renderNumber(totalTimes)}`,
|
||||
newModelColors,
|
||||
'lineData'
|
||||
);
|
||||
|
||||
updateChartSpec(
|
||||
setSpecRankBar,
|
||||
rankData,
|
||||
`${t('总计')}:${renderNumber(totalTimes)}`,
|
||||
newModelColors,
|
||||
'rankData'
|
||||
);
|
||||
|
||||
setPieData(newPieData);
|
||||
setLineData(newLineData);
|
||||
setConsumeQuota(totalQuota);
|
||||
@@ -1122,28 +1242,53 @@ const Detail = (props) => {
|
||||
{t('消耗分布')}
|
||||
</span>
|
||||
} itemKey="1" />
|
||||
<TabPane tab={
|
||||
<span>
|
||||
<IconPulse />
|
||||
{t('消耗趋势')}
|
||||
</span>
|
||||
} itemKey="2" />
|
||||
<TabPane tab={
|
||||
<span>
|
||||
<IconPieChart2Stroked />
|
||||
{t('调用次数分布')}
|
||||
</span>
|
||||
} itemKey="2" />
|
||||
} itemKey="3" />
|
||||
<TabPane tab={
|
||||
<span>
|
||||
<IconHistogram />
|
||||
{t('调用次数排行')}
|
||||
</span>
|
||||
} itemKey="4" />
|
||||
</Tabs>
|
||||
</div>
|
||||
}
|
||||
>
|
||||
<div style={{ height: 400 }}>
|
||||
{activeChartTab === '1' ? (
|
||||
{activeChartTab === '1' && (
|
||||
<VChart
|
||||
spec={spec_line}
|
||||
option={CHART_CONFIG}
|
||||
/>
|
||||
) : (
|
||||
)}
|
||||
{activeChartTab === '2' && (
|
||||
<VChart
|
||||
spec={spec_model_line}
|
||||
option={CHART_CONFIG}
|
||||
/>
|
||||
)}
|
||||
{activeChartTab === '3' && (
|
||||
<VChart
|
||||
spec={spec_pie}
|
||||
option={CHART_CONFIG}
|
||||
/>
|
||||
)}
|
||||
{activeChartTab === '4' && (
|
||||
<VChart
|
||||
spec={spec_rank_bar}
|
||||
option={CHART_CONFIG}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</Card>
|
||||
|
||||
|
||||
@@ -272,10 +272,7 @@ const Home = () => {
|
||||
className="w-full h-screen border-none"
|
||||
/>
|
||||
) : (
|
||||
<div
|
||||
className="text-base md:text-lg p-4 md:p-6 lg:p-8 overflow-x-hidden max-w-6xl mx-auto"
|
||||
dangerouslySetInnerHTML={{ __html: homePageContent }}
|
||||
></div>
|
||||
<div className="mt-[64px]" dangerouslySetInnerHTML={{ __html: homePageContent }} />
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
@@ -78,8 +78,7 @@ const EditRedemption = (props) => {
|
||||
|
||||
const submit = async (values) => {
|
||||
let name = values.name;
|
||||
if (!isEdit && values.name === '') {
|
||||
|
||||
if (!isEdit && (!name || name === '')) {
|
||||
name = renderQuota(values.quota);
|
||||
}
|
||||
setLoading(true);
|
||||
@@ -209,7 +208,7 @@ const EditRedemption = (props) => {
|
||||
label={t('名称')}
|
||||
placeholder={t('请输入名称')}
|
||||
style={{ width: '100%' }}
|
||||
rules={isEdit ? [] : [{ required: true, message: t('请输入名称') }]}
|
||||
rules={!isEdit ? [] : [{ required: true, message: t('请输入名称') }]}
|
||||
showClear
|
||||
/>
|
||||
</Col>
|
||||
|
||||
@@ -373,7 +373,7 @@ export default function UpstreamRatioSync(props) {
|
||||
<div className="flex flex-col md:flex-row gap-2 w-full md:w-auto order-2 md:order-1">
|
||||
<Button
|
||||
icon={<RefreshCcw size={14} />}
|
||||
className="!rounded-full w-full md:w-auto mt-2"
|
||||
className="w-full md:w-auto mt-2"
|
||||
onClick={() => {
|
||||
setModalVisible(true);
|
||||
if (allChannels.length === 0) {
|
||||
@@ -393,7 +393,7 @@ export default function UpstreamRatioSync(props) {
|
||||
type='secondary'
|
||||
onClick={applySync}
|
||||
disabled={!hasSelections}
|
||||
className="!rounded-full w-full md:w-auto mt-2"
|
||||
className="w-full md:w-auto mt-2"
|
||||
>
|
||||
{t('应用同步')}
|
||||
</Button>
|
||||
@@ -406,7 +406,7 @@ export default function UpstreamRatioSync(props) {
|
||||
placeholder={t('搜索模型名称')}
|
||||
value={searchKeyword}
|
||||
onChange={setSearchKeyword}
|
||||
className="!rounded-full w-full sm:w-64"
|
||||
className="w-full sm:w-64"
|
||||
showClear
|
||||
/>
|
||||
|
||||
@@ -414,7 +414,7 @@ export default function UpstreamRatioSync(props) {
|
||||
placeholder={t('按倍率类型筛选')}
|
||||
value={ratioTypeFilter}
|
||||
onChange={setRatioTypeFilter}
|
||||
className="!rounded-full w-full sm:w-48"
|
||||
className="w-full sm:w-48"
|
||||
showClear
|
||||
onClear={() => setRatioTypeFilter('')}
|
||||
>
|
||||
@@ -704,7 +704,6 @@ export default function UpstreamRatioSync(props) {
|
||||
scroll={{ x: 'max-content' }}
|
||||
size='middle'
|
||||
loading={loading || syncLoading}
|
||||
className="rounded-xl overflow-hidden"
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -139,14 +139,24 @@ const EditToken = (props) => {
|
||||
if (formApiRef.current) {
|
||||
if (!isEdit) {
|
||||
formApiRef.current.setValues(getInitValues());
|
||||
} else {
|
||||
loadToken();
|
||||
}
|
||||
}
|
||||
loadModels();
|
||||
loadGroups();
|
||||
}, [props.editingToken.id]);
|
||||
|
||||
useEffect(() => {
|
||||
if (props.visiable) {
|
||||
if (isEdit) {
|
||||
loadToken();
|
||||
} else {
|
||||
formApiRef.current?.setValues(getInitValues());
|
||||
}
|
||||
} else {
|
||||
formApiRef.current?.reset();
|
||||
}
|
||||
}, [props.visiable, props.editingToken.id]);
|
||||
|
||||
const generateRandomSuffix = () => {
|
||||
const characters =
|
||||
'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789';
|
||||
|
||||
@@ -22,6 +22,7 @@ import {
|
||||
Row,
|
||||
Col,
|
||||
Input,
|
||||
InputNumber,
|
||||
} from '@douyinfe/semi-ui';
|
||||
import {
|
||||
IconUser,
|
||||
@@ -39,7 +40,7 @@ const EditUser = (props) => {
|
||||
const userId = props.editingUser.id;
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [addQuotaModalOpen, setIsModalOpen] = useState(false);
|
||||
const [addQuotaLocal, setAddQuotaLocal] = useState('0');
|
||||
const [addQuotaLocal, setAddQuotaLocal] = useState('');
|
||||
const [groupOptions, setGroupOptions] = useState([]);
|
||||
const formApiRef = useRef(null);
|
||||
|
||||
@@ -254,7 +255,6 @@ const EditUser = (props) => {
|
||||
field='quota'
|
||||
label={t('剩余额度')}
|
||||
placeholder={t('请输入新的剩余额度')}
|
||||
min={0}
|
||||
step={500000}
|
||||
extraText={renderQuotaWithPrompt(values.quota || 0)}
|
||||
rules={[{ required: true, message: t('请输入额度') }]}
|
||||
@@ -328,18 +328,19 @@ const EditUser = (props) => {
|
||||
const current = formApiRef.current?.getValue('quota') || 0;
|
||||
return (
|
||||
<Text type='secondary' className='block mb-2'>
|
||||
{`${t('新额度')}${renderQuota(current)} + ${renderQuota(addQuotaLocal)} = ${renderQuota(current + parseInt(addQuotaLocal || 0))}`}
|
||||
{`${t('新额度:')}${renderQuota(current)} + ${renderQuota(addQuotaLocal)} = ${renderQuota(current + parseInt(addQuotaLocal || 0))}`}
|
||||
</Text>
|
||||
);
|
||||
})()
|
||||
}
|
||||
</div>
|
||||
<Input
|
||||
<InputNumber
|
||||
placeholder={t('需要添加的额度(支持负数)')}
|
||||
type='number'
|
||||
value={addQuotaLocal}
|
||||
onChange={setAddQuotaLocal}
|
||||
style={{ width: '100%' }}
|
||||
showClear
|
||||
step={500000}
|
||||
/>
|
||||
</Modal>
|
||||
</>
|
||||
|
||||
Reference in New Issue
Block a user