Compare commits

...

14 Commits

Author SHA1 Message Date
CaIon ac06e9a46f feat: Add CheckSetup function call in main to ensure proper initialization #942 2025-04-08 18:14:36 +08:00
Calcium-Ion 97be874cf5 Merge pull request #930 from Yiffyi/main
fix: save OIDC settings
2025-04-08 17:39:42 +08:00
CaIon c40404a8c1 Update MaxTokens for gemini model to 300 in test request 2025-04-08 17:37:25 +08:00
Calcium-Ion ba9cd9314b Merge pull request #936 from lamcodes/main
fix: gemini test MaxTokens
2025-04-08 17:33:31 +08:00
Calcium-Ion 950572b02e Set MaxTokens to 50 for gemini 2025-04-08 17:33:10 +08:00
CaIon d85e25fc46 feat: Integrate SetupCheck component for improved setup validation in routing 2025-04-08 17:31:46 +08:00
CaIon 9dc851d6ef feat: Initialize model settings and improve concurrency control in operation settings 2025-04-07 22:20:47 +08:00
CaIon 9be721ee29 feat: Add concurrency control to group ratio management with mutexes 2025-04-07 21:55:54 +08:00
zkp 93d9f141fe fix: gemini test MaxTokens 2025-04-06 23:24:47 +08:00
Yiffyi Jia 3d5e93d6a2 fix: cannot save OIDC settings 2025-04-05 04:24:38 +00:00
CaIon 69d790b47a Update model-ratio.go 2025-04-04 23:43:14 +08:00
CaIon 1e95160293 Update model-ratio.go 2025-04-04 23:41:41 +08:00
CaIon eca18ce25c fix: Improve setup check logic and logging for system initialization 2025-04-04 21:27:24 +08:00
CaIon ad7a64e585 Update model-ratio.go 2025-04-04 00:31:24 +08:00
11 changed files with 166 additions and 93 deletions
+2
View File
@@ -192,6 +192,8 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
if !strings.Contains(model, "claude") {
testRequest.MaxTokens = 50
}
} else if strings.Contains(model, "gemini") {
testRequest.MaxTokens = 300
} else {
testRequest.MaxTokens = 10
}
+7
View File
@@ -12,6 +12,7 @@ import (
"one-api/model"
"one-api/router"
"one-api/service"
"one-api/setting/operation_setting"
"os"
"strconv"
@@ -51,6 +52,9 @@ func main() {
if err != nil {
common.FatalLog("failed to initialize database: " + err.Error())
}
model.CheckSetup()
// Initialize SQL Database
err = model.InitLogDB()
if err != nil {
@@ -73,6 +77,9 @@ func main() {
constant.InitEnv()
// Initialize options
model.InitOptionMap()
// Initialize model settings
operation_setting.InitModelSettings()
if common.RedisEnabled {
// for compatibility with old versions
common.MemoryCacheEnabled = true
+11 -5
View File
@@ -56,23 +56,30 @@ func createRootAccountIfNeed() error {
return nil
}
func checkSetup() {
if GetSetup() == nil {
func CheckSetup() {
setup := GetSetup()
if setup == nil {
// No setup record exists, check if we have a root user
if RootUserExists() {
common.SysLog("system is not initialized, but root user exists")
// Create setup record
setup := Setup{
newSetup := Setup{
Version: common.Version,
InitializedAt: time.Now().Unix(),
}
err := DB.Create(&setup).Error
err := DB.Create(&newSetup).Error
if err != nil {
common.SysLog("failed to create setup record: " + err.Error())
}
constant.Setup = true
} else {
common.SysLog("system is not initialized and no root user exists")
constant.Setup = false
}
} else {
// Setup record exists, system is initialized
common.SysLog("system is already initialized at: " + time.Unix(setup.InitializedAt, 0).String())
constant.Setup = true
}
}
@@ -237,7 +244,6 @@ func migrateDB() error {
}
err = DB.AutoMigrate(&Setup{})
common.SysLog("database migrated")
checkSetup()
//err = createRootAccountIfNeed()
return err
}
+2
View File
@@ -16,6 +16,8 @@ var ModelList = []string{
"gemini-2.0-pro-exp",
// thinking exp
"gemini-2.0-flash-thinking-exp",
"gemini-2.5-pro-exp-03-25",
"gemini-2.5-pro-preview-03-25",
// imagen models
"imagen-3.0-generate-002",
// embedding models
+17
View File
@@ -4,6 +4,7 @@ import (
"encoding/json"
"errors"
"one-api/common"
"sync"
)
var groupRatio = map[string]float64{
@@ -11,8 +12,12 @@ var groupRatio = map[string]float64{
"vip": 1,
"svip": 1,
}
var groupRatioMutex sync.RWMutex
func GetGroupRatioCopy() map[string]float64 {
groupRatioMutex.RLock()
defer groupRatioMutex.RUnlock()
groupRatioCopy := make(map[string]float64)
for k, v := range groupRatio {
groupRatioCopy[k] = v
@@ -21,11 +26,17 @@ func GetGroupRatioCopy() map[string]float64 {
}
func ContainsGroupRatio(name string) bool {
groupRatioMutex.RLock()
defer groupRatioMutex.RUnlock()
_, ok := groupRatio[name]
return ok
}
func GroupRatio2JSONString() string {
groupRatioMutex.RLock()
defer groupRatioMutex.RUnlock()
jsonBytes, err := json.Marshal(groupRatio)
if err != nil {
common.SysError("error marshalling model ratio: " + err.Error())
@@ -34,11 +45,17 @@ func GroupRatio2JSONString() string {
}
func UpdateGroupRatioByJSONString(jsonStr string) error {
groupRatioMutex.Lock()
defer groupRatioMutex.Unlock()
groupRatio = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &groupRatio)
}
func GetGroupRatio(name string) float64 {
groupRatioMutex.RLock()
defer groupRatioMutex.RUnlock()
ratio, ok := groupRatio[name]
if !ok {
common.SysError("group ratio not found: " + name)
+7 -8
View File
@@ -56,17 +56,15 @@ var cacheRatioMapMutex sync.RWMutex
// GetCacheRatioMap returns the cache ratio map
func GetCacheRatioMap() map[string]float64 {
cacheRatioMapMutex.Lock()
defer cacheRatioMapMutex.Unlock()
if cacheRatioMap == nil {
cacheRatioMap = defaultCacheRatio
}
cacheRatioMapMutex.RLock()
defer cacheRatioMapMutex.RUnlock()
return cacheRatioMap
}
// CacheRatio2JSONString converts the cache ratio map to a JSON string
func CacheRatio2JSONString() string {
GetCacheRatioMap()
cacheRatioMapMutex.RLock()
defer cacheRatioMapMutex.RUnlock()
jsonBytes, err := json.Marshal(cacheRatioMap)
if err != nil {
common.SysError("error marshalling cache ratio: " + err.Error())
@@ -84,10 +82,11 @@ func UpdateCacheRatioByJSONString(jsonStr string) error {
// GetCacheRatio returns the cache ratio for a model
func GetCacheRatio(name string) (float64, bool) {
GetCacheRatioMap()
cacheRatioMapMutex.RLock()
defer cacheRatioMapMutex.RUnlock()
ratio, ok := cacheRatioMap[name]
if !ok {
return 1, false // Default to 0.5 if not found
return 1, false // Default to 1 if not found
}
return ratio, true
}
+90 -66
View File
@@ -131,17 +131,12 @@ var defaultModelRatio = map[string]float64{
"bge-large-en": 0.002 * RMB,
"tao-8k": 0.002 * RMB,
"PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-1.0-pro-vision-001": 1,
"gemini-1.0-pro-001": 1,
"gemini-1.5-pro-latest": 1.75, // $3.5 / 1M tokens
"gemini-1.5-pro-exp-0827": 1.75, // $3.5 / 1M tokens
"gemini-1.5-flash-latest": 1,
"gemini-1.5-flash-exp-0827": 1,
"gemini-1.0-pro-latest": 1,
"gemini-1.0-pro-vision-latest": 1,
"gemini-ultra": 1,
"gemini-1.5-pro-latest": 1.25, // $3.5 / 1M tokens
"gemini-1.5-flash-latest": 0.075,
"gemini-2.0-flash": 0.05,
"gemini-2.5-pro-exp-03-25": 1.25,
"gemini-2.5-pro-preview-03-25": 1.25,
"text-embedding-004": 0.001,
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
@@ -207,26 +202,27 @@ var defaultModelRatio = map[string]float64{
}
var defaultModelPrice = map[string]float64{
"suno_music": 0.1,
"suno_lyrics": 0.01,
"dall-e-3": 0.04,
"gpt-4-gizmo-*": 0.1,
"mj_imagine": 0.1,
"mj_variation": 0.1,
"mj_reroll": 0.1,
"mj_blend": 0.1,
"mj_modal": 0.1,
"mj_zoom": 0.1,
"mj_shorten": 0.1,
"mj_high_variation": 0.1,
"mj_low_variation": 0.1,
"mj_pan": 0.1,
"mj_inpaint": 0,
"mj_custom_zoom": 0,
"mj_describe": 0.05,
"mj_upscale": 0.05,
"swap_face": 0.05,
"mj_upload": 0.05,
"suno_music": 0.1,
"suno_lyrics": 0.01,
"dall-e-3": 0.04,
"imagen-3.0-generate-002": 0.03,
"gpt-4-gizmo-*": 0.1,
"mj_imagine": 0.1,
"mj_variation": 0.1,
"mj_reroll": 0.1,
"mj_blend": 0.1,
"mj_modal": 0.1,
"mj_zoom": 0.1,
"mj_shorten": 0.1,
"mj_high_variation": 0.1,
"mj_low_variation": 0.1,
"mj_pan": 0.1,
"mj_inpaint": 0,
"mj_custom_zoom": 0,
"mj_describe": 0.05,
"mj_upscale": 0.05,
"swap_face": 0.05,
"mj_upload": 0.05,
}
var (
@@ -249,17 +245,41 @@ var defaultCompletionRatio = map[string]float64{
"gpt-4-all": 2,
}
func GetModelPriceMap() map[string]float64 {
// InitModelSettings initializes all model related settings maps
func InitModelSettings() {
// Initialize modelPriceMap
modelPriceMapMutex.Lock()
defer modelPriceMapMutex.Unlock()
if modelPriceMap == nil {
modelPriceMap = defaultModelPrice
}
modelPriceMap = defaultModelPrice
modelPriceMapMutex.Unlock()
// Initialize modelRatioMap
modelRatioMapMutex.Lock()
modelRatioMap = defaultModelRatio
modelRatioMapMutex.Unlock()
// Initialize CompletionRatio
CompletionRatioMutex.Lock()
CompletionRatio = defaultCompletionRatio
CompletionRatioMutex.Unlock()
// Initialize cacheRatioMap
cacheRatioMapMutex.Lock()
cacheRatioMap = defaultCacheRatio
cacheRatioMapMutex.Unlock()
common.SysLog("model settings initialized")
}
func GetModelPriceMap() map[string]float64 {
modelPriceMapMutex.RLock()
defer modelPriceMapMutex.RUnlock()
return modelPriceMap
}
func ModelPrice2JSONString() string {
GetModelPriceMap()
modelPriceMapMutex.RLock()
defer modelPriceMapMutex.RUnlock()
jsonBytes, err := json.Marshal(modelPriceMap)
if err != nil {
common.SysError("error marshalling model price: " + err.Error())
@@ -276,7 +296,9 @@ func UpdateModelPriceByJSONString(jsonStr string) error {
// GetModelPrice 返回模型的价格,如果模型不存在则返回-1,false
func GetModelPrice(name string, printErr bool) (float64, bool) {
GetModelPriceMap()
modelPriceMapMutex.RLock()
defer modelPriceMapMutex.RUnlock()
if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*"
}
@@ -293,24 +315,6 @@ func GetModelPrice(name string, printErr bool) (float64, bool) {
return price, true
}
func GetModelRatioMap() map[string]float64 {
modelRatioMapMutex.Lock()
defer modelRatioMapMutex.Unlock()
if modelRatioMap == nil {
modelRatioMap = defaultModelRatio
}
return modelRatioMap
}
func ModelRatio2JSONString() string {
GetModelRatioMap()
jsonBytes, err := json.Marshal(modelRatioMap)
if err != nil {
common.SysError("error marshalling model ratio: " + err.Error())
}
return string(jsonBytes)
}
func UpdateModelRatioByJSONString(jsonStr string) error {
modelRatioMapMutex.Lock()
defer modelRatioMapMutex.Unlock()
@@ -319,7 +323,9 @@ func UpdateModelRatioByJSONString(jsonStr string) error {
}
func GetModelRatio(name string) (float64, bool) {
GetModelRatioMap()
modelRatioMapMutex.RLock()
defer modelRatioMapMutex.RUnlock()
if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*"
}
@@ -343,16 +349,15 @@ func GetDefaultModelRatioMap() map[string]float64 {
}
func GetCompletionRatioMap() map[string]float64 {
CompletionRatioMutex.Lock()
defer CompletionRatioMutex.Unlock()
if CompletionRatio == nil {
CompletionRatio = defaultCompletionRatio
}
CompletionRatioMutex.RLock()
defer CompletionRatioMutex.RUnlock()
return CompletionRatio
}
func CompletionRatio2JSONString() string {
GetCompletionRatioMap()
CompletionRatioMutex.RLock()
defer CompletionRatioMutex.RUnlock()
jsonBytes, err := json.Marshal(CompletionRatio)
if err != nil {
common.SysError("error marshalling completion ratio: " + err.Error())
@@ -368,7 +373,8 @@ func UpdateCompletionRatioByJSONString(jsonStr string) error {
}
func GetCompletionRatio(name string) float64 {
GetCompletionRatioMap()
CompletionRatioMutex.RLock()
defer CompletionRatioMutex.RUnlock()
if strings.Contains(name, "/") {
if ratio, ok := CompletionRatio[name]; ok {
@@ -438,7 +444,14 @@ func getHardcodedCompletionModelRatio(name string) (float64, bool) {
return 3, true
}
if strings.HasPrefix(name, "gemini-") {
return 4, true
if strings.HasPrefix(name, "gemini-1.5") {
return 4, true
} else if strings.HasPrefix(name, "gemini-2.0") {
return 4, true
} else if strings.HasPrefix(name, "gemini-2.5-pro-preview") {
return 6, true
}
return 4, false
}
if strings.HasPrefix(name, "command") {
switch name {
@@ -451,7 +464,7 @@ func getHardcodedCompletionModelRatio(name string) (float64, bool) {
case "command-r-plus-08-2024":
return 4, true
default:
return 4, true
return 4, false
}
}
// hint 只给官方上4倍率,由于开源模型供应商自行定价,不对其进行补全倍率进行强制对齐
@@ -508,3 +521,14 @@ func GetAudioCompletionRatio(name string) float64 {
}
return 2
}
func ModelRatio2JSONString() string {
modelRatioMapMutex.RLock()
defer modelRatioMapMutex.RUnlock()
jsonBytes, err := json.Marshal(modelRatioMap)
if err != nil {
common.SysError("error marshalling model ratio: " + err.Error())
}
return string(jsonBytes)
}
+3 -2
View File
@@ -26,6 +26,7 @@ import Playground from './pages/Playground/Playground.js';
import OAuth2Callback from "./components/OAuth2Callback.js";
import PersonalSetting from './components/PersonalSetting.js';
import Setup from './pages/Setup/index.js';
import SetupCheck from './components/SetupCheck';
const Home = lazy(() => import('./pages/Home'));
const Detail = lazy(() => import('./pages/Detail'));
@@ -35,7 +36,7 @@ function App() {
const location = useLocation();
return (
<>
<SetupCheck>
<Routes>
<Route
path='/'
@@ -286,7 +287,7 @@ function App() {
/>
<Route path='*' element={<NotFound />} />
</Routes>
</>
</SetupCheck>
);
}
+18
View File
@@ -0,0 +1,18 @@
import React, { useContext, useEffect } from 'react';
import { Navigate, useLocation } from 'react-router-dom';
import { StatusContext } from '../context/Status';
const SetupCheck = ({ children }) => {
const [statusState] = useContext(StatusContext);
const location = useLocation();
useEffect(() => {
if (statusState?.status?.setup === false && location.pathname !== '/setup') {
window.location.href = '/setup';
}
}, [statusState?.status?.setup, location.pathname]);
return children;
};
export default SetupCheck;
+7 -7
View File
@@ -619,7 +619,7 @@ const SystemSetting = () => {
允许通过 Telegram 进行登录
</Form.Checkbox>
<Form.Checkbox
field='oidc.enabled'
field="['oidc.enabled']"
noLabel
onChange={(e) => handleCheckboxChange('oidc.enabled', e)}
>
@@ -721,14 +721,14 @@ const SystemSetting = () => {
<Row gutter={{ xs: 8, sm: 16, md: 24, lg: 24, xl: 24, xxl: 24 }}>
<Col xs={24} sm={24} md={12} lg={12} xl={12}>
<Form.Input
field='oidc.well_known'
field="['oidc.well_known']"
label='Well-Known URL'
placeholder='请输入 OIDC 的 Well-Known URL'
/>
</Col>
<Col xs={24} sm={24} md={12} lg={12} xl={12}>
<Form.Input
field='oidc.client_id'
field="['oidc.client_id']"
label='Client ID'
placeholder='输入 OIDC 的 Client ID'
/>
@@ -737,7 +737,7 @@ const SystemSetting = () => {
<Row gutter={{ xs: 8, sm: 16, md: 24, lg: 24, xl: 24, xxl: 24 }}>
<Col xs={24} sm={24} md={12} lg={12} xl={12}>
<Form.Input
field='oidc.client_secret'
field="['oidc.client_secret']"
label='Client Secret'
type='password'
placeholder='敏感信息不会发送到前端显示'
@@ -745,7 +745,7 @@ const SystemSetting = () => {
</Col>
<Col xs={24} sm={24} md={12} lg={12} xl={12}>
<Form.Input
field='oidc.authorization_endpoint'
field="['oidc.authorization_endpoint']"
label='Authorization Endpoint'
placeholder='输入 OIDC 的 Authorization Endpoint'
/>
@@ -754,14 +754,14 @@ const SystemSetting = () => {
<Row gutter={{ xs: 8, sm: 16, md: 24, lg: 24, xl: 24, xxl: 24 }}>
<Col xs={24} sm={24} md={12} lg={12} xl={12}>
<Form.Input
field='oidc.token_endpoint'
field="['oidc.token_endpoint']"
label='Token Endpoint'
placeholder='输入 OIDC 的 Token Endpoint'
/>
</Col>
<Col xs={24} sm={24} md={12} lg={12} xl={12}>
<Form.Input
field='oidc.user_info_endpoint'
field="['oidc.user_info_endpoint']"
label='User Info Endpoint'
placeholder='输入 OIDC 的 Userinfo Endpoint'
/>
+2 -5
View File
@@ -66,13 +66,9 @@ const Home = () => {
};
useEffect(() => {
if (statusState.status?.setup === false) {
window.location.href = '/setup';
return;
}
displayNotice().then();
displayHomePageContent().then();
});
}, []);
return (
<>
@@ -116,6 +112,7 @@ const Home = () => {
https://github.com/Calcium-Ion/new-api
</a>
</p>
<p>
{t('协议')}
<a