Compare commits

...

11 Commits

Author SHA1 Message Date
1808837298@qq.com 6cb9001ff3 fix: claude to openai tools use 2025-03-12 19:29:15 +08:00
1808837298@qq.com d9a6a2db87 fix: claude to openai tools use 2025-03-12 18:53:38 +08:00
1808837298@qq.com 03de7e2ea1 Merge remote-tracking branch 'origin/main' 2025-03-12 17:53:52 +08:00
1808837298@qq.com 1800e0ae9e feat(relay): Add Xinference channel support 2025-03-12 17:53:46 +08:00
Calcium-Ion f4cf7c8d43 Merge pull request #848 from wzxjohn/feature/oidc
feat: add oidc support
2025-03-11 23:20:55 +08:00
1808837298@qq.com a280feeae0 fix: Add error logging for OIDC configuration retrieval 2025-03-11 23:20:27 +08:00
1808837298@qq.com 30d9f433f1 refactor: Update OIDC status check to use oidc_enabled flag 2025-03-11 22:36:31 +08:00
1808837298@qq.com 3ede51a9a7 refactor: Remove OIDC configuration from option initialization 2025-03-11 22:03:20 +08:00
1808837298@qq.com 9f3cc03508 refactor: Migrate OIDC configuration to system settings 2025-03-11 22:00:31 +08:00
1808837298@qq.com 215e768caf feat(ui): Improve model testing button layout and styling 2025-03-11 21:22:10 +08:00
wzxjohn bdb1a2fcb9 feat: add oidc support 2025-03-11 15:52:03 +08:00
26 changed files with 716 additions and 159 deletions
+2 -1
View File
@@ -77,7 +77,6 @@ var SMTPToken = ""
var GitHubClientId = ""
var GitHubClientSecret = ""
var LinuxDOClientId = ""
var LinuxDOClientSecret = ""
@@ -235,6 +234,7 @@ const (
ChannelTypeMokaAI = 44
ChannelTypeVolcEngine = 45
ChannelTypeBaiduV2 = 46
ChannelTypeXinference = 47
ChannelTypeDummy // this one is only for count, do not add any channel after this
)
@@ -287,4 +287,5 @@ var ChannelBaseURLs = []string{
"https://api.moka.ai", //44
"https://ark.cn-beijing.volces.com", //45
"https://qianfan.baidubce.com", //46
"", //47
}
+38 -34
View File
@@ -8,6 +8,7 @@ import (
"one-api/model"
"one-api/setting"
"one-api/setting/operation_setting"
"one-api/setting/system_setting"
"strings"
"github.com/gin-gonic/gin"
@@ -34,40 +35,43 @@ func GetStatus(c *gin.Context) {
"success": true,
"message": "",
"data": gin.H{
"version": common.Version,
"start_time": common.StartTime,
"email_verification": common.EmailVerificationEnabled,
"github_oauth": common.GitHubOAuthEnabled,
"github_client_id": common.GitHubClientId,
"linuxdo_oauth": common.LinuxDOOAuthEnabled,
"linuxdo_client_id": common.LinuxDOClientId,
"telegram_oauth": common.TelegramOAuthEnabled,
"telegram_bot_name": common.TelegramBotName,
"system_name": common.SystemName,
"logo": common.Logo,
"footer_html": common.Footer,
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
"wechat_login": common.WeChatAuthEnabled,
"server_address": setting.ServerAddress,
"price": setting.Price,
"min_topup": setting.MinTopUp,
"turnstile_check": common.TurnstileCheckEnabled,
"turnstile_site_key": common.TurnstileSiteKey,
"top_up_link": common.TopUpLink,
"docs_link": operation_setting.GetGeneralSetting().DocsLink,
"quota_per_unit": common.QuotaPerUnit,
"display_in_currency": common.DisplayInCurrencyEnabled,
"enable_batch_update": common.BatchUpdateEnabled,
"enable_drawing": common.DrawingEnabled,
"enable_task": common.TaskEnabled,
"enable_data_export": common.DataExportEnabled,
"data_export_default_time": common.DataExportDefaultTime,
"default_collapse_sidebar": common.DefaultCollapseSidebar,
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
"mj_notify_enabled": setting.MjNotifyEnabled,
"chats": setting.Chats,
"demo_site_enabled": operation_setting.DemoSiteEnabled,
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
"version": common.Version,
"start_time": common.StartTime,
"email_verification": common.EmailVerificationEnabled,
"github_oauth": common.GitHubOAuthEnabled,
"github_client_id": common.GitHubClientId,
"linuxdo_oauth": common.LinuxDOOAuthEnabled,
"linuxdo_client_id": common.LinuxDOClientId,
"telegram_oauth": common.TelegramOAuthEnabled,
"telegram_bot_name": common.TelegramBotName,
"system_name": common.SystemName,
"logo": common.Logo,
"footer_html": common.Footer,
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
"wechat_login": common.WeChatAuthEnabled,
"server_address": setting.ServerAddress,
"price": setting.Price,
"min_topup": setting.MinTopUp,
"turnstile_check": common.TurnstileCheckEnabled,
"turnstile_site_key": common.TurnstileSiteKey,
"top_up_link": common.TopUpLink,
"docs_link": operation_setting.GetGeneralSetting().DocsLink,
"quota_per_unit": common.QuotaPerUnit,
"display_in_currency": common.DisplayInCurrencyEnabled,
"enable_batch_update": common.BatchUpdateEnabled,
"enable_drawing": common.DrawingEnabled,
"enable_task": common.TaskEnabled,
"enable_data_export": common.DataExportEnabled,
"data_export_default_time": common.DataExportDefaultTime,
"default_collapse_sidebar": common.DefaultCollapseSidebar,
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
"mj_notify_enabled": setting.MjNotifyEnabled,
"chats": setting.Chats,
"demo_site_enabled": operation_setting.DemoSiteEnabled,
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
"oidc_enabled": system_setting.GetOIDCSettings().Enabled,
"oidc_client_id": system_setting.GetOIDCSettings().ClientId,
"oidc_authorization_endpoint": system_setting.GetOIDCSettings().AuthorizationEndpoint,
},
})
return
+240
View File
@@ -0,0 +1,240 @@
package controller
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"one-api/common"
"one-api/model"
"one-api/setting"
"one-api/setting/system_setting"
"strconv"
"strings"
"time"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
type OidcResponse struct {
AccessToken string `json:"access_token"`
IDToken string `json:"id_token"`
RefreshToken string `json:"refresh_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
}
type OidcUser struct {
OpenID string `json:"sub"`
Email string `json:"email"`
Name string `json:"name"`
PreferredUsername string `json:"preferred_username"`
Picture string `json:"picture"`
}
func getOidcUserInfoByCode(code string) (*OidcUser, error) {
if code == "" {
return nil, errors.New("无效的参数")
}
values := url.Values{}
values.Set("client_id", system_setting.GetOIDCSettings().ClientId)
values.Set("client_secret", system_setting.GetOIDCSettings().ClientSecret)
values.Set("code", code)
values.Set("grant_type", "authorization_code")
values.Set("redirect_uri", fmt.Sprintf("%s/oauth/oidc", setting.ServerAddress))
formData := values.Encode()
req, err := http.NewRequest("POST", system_setting.GetOIDCSettings().TokenEndpoint, strings.NewReader(formData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
client := http.Client{
Timeout: 5 * time.Second,
}
res, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
}
defer res.Body.Close()
var oidcResponse OidcResponse
err = json.NewDecoder(res.Body).Decode(&oidcResponse)
if err != nil {
return nil, err
}
if oidcResponse.AccessToken == "" {
common.SysError("OIDC 获取 Token 失败,请检查设置!")
return nil, errors.New("OIDC 获取 Token 失败,请检查设置!")
}
req, err = http.NewRequest("GET", system_setting.GetOIDCSettings().UserInfoEndpoint, nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken)
res2, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
}
defer res2.Body.Close()
if res2.StatusCode != http.StatusOK {
common.SysError("OIDC 获取用户信息失败!请检查设置!")
return nil, errors.New("OIDC 获取用户信息失败!请检查设置!")
}
var oidcUser OidcUser
err = json.NewDecoder(res2.Body).Decode(&oidcUser)
if err != nil {
return nil, err
}
if oidcUser.OpenID == "" || oidcUser.Email == "" {
common.SysError("OIDC 获取用户信息为空!请检查设置!")
return nil, errors.New("OIDC 获取用户信息为空!请检查设置!")
}
return &oidcUser, nil
}
func OidcAuth(c *gin.Context) {
session := sessions.Default(c)
state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "state is empty or not same",
})
return
}
username := session.Get("username")
if username != nil {
OidcBind(c)
return
}
if !system_setting.GetOIDCSettings().Enabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 OIDC 登录以及注册",
})
return
}
code := c.Query("code")
oidcUser, err := getOidcUserInfoByCode(code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
OidcId: oidcUser.OpenID,
}
if model.IsOidcIdAlreadyTaken(user.OidcId) {
err := user.FillUserByOidcId()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
if common.RegisterEnabled {
user.Email = oidcUser.Email
if oidcUser.PreferredUsername != "" {
user.Username = oidcUser.PreferredUsername
} else {
user.Username = "oidc_" + strconv.Itoa(model.GetMaxUserId()+1)
}
if oidcUser.Name != "" {
user.DisplayName = oidcUser.Name
} else {
user.DisplayName = "OIDC User"
}
err := user.Insert(0)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员关闭了新用户注册",
})
return
}
}
if user.Status != common.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
})
return
}
setupLogin(&user, c)
}
func OidcBind(c *gin.Context) {
if !system_setting.GetOIDCSettings().Enabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 OIDC 登录以及注册",
})
return
}
code := c.Query("code")
oidcUser, err := getOidcUserInfoByCode(code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
OidcId: oidcUser.OpenID,
}
if model.IsOidcIdAlreadyTaken(user.OidcId) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该 OIDC 账户已被绑定",
})
return
}
session := sessions.Default(c)
id := session.Get("id")
// id := c.GetInt("id") // critical bug!
user.Id = id.(int)
err = user.FillUserById()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user.OidcId = oidcUser.OpenID
err = user.Update(false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "bind",
})
return
}
+8
View File
@@ -6,6 +6,7 @@ import (
"one-api/common"
"one-api/model"
"one-api/setting"
"one-api/setting/system_setting"
"strings"
"github.com/gin-gonic/gin"
@@ -51,6 +52,13 @@ func UpdateOption(c *gin.Context) {
})
return
}
case "oidc.enabled":
if option.Value == "true" && system_setting.GetOIDCSettings().Enabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法启用 OIDC 登录,请先填入 OIDC Client Id 以及 OIDC Client Secret",
})
}
case "LinuxDOOAuthEnabled":
if option.Value == "true" && common.LinuxDOClientId == "" {
c.JSON(http.StatusOK, gin.H{
+5 -5
View File
@@ -28,9 +28,9 @@ require (
github.com/samber/lo v1.39.0
github.com/shirou/gopsutil v3.21.11+incompatible
github.com/shopspring/decimal v1.4.0
golang.org/x/crypto v0.27.0
golang.org/x/crypto v0.35.0
golang.org/x/image v0.23.0
golang.org/x/net v0.28.0
golang.org/x/net v0.35.0
gorm.io/driver/mysql v1.4.3
gorm.io/driver/postgres v1.5.2
gorm.io/gorm v1.25.2
@@ -84,9 +84,9 @@ require (
github.com/yusufpapurcu/wmi v1.2.3 // indirect
golang.org/x/arch v0.12.0 // indirect
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
golang.org/x/sync v0.10.0 // indirect
golang.org/x/sys v0.27.0 // indirect
golang.org/x/text v0.21.0 // indirect
golang.org/x/sync v0.11.0 // indirect
golang.org/x/sys v0.30.0 // indirect
golang.org/x/text v0.22.0 // indirect
google.golang.org/protobuf v1.34.2 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
modernc.org/libc v1.22.5 // indirect
+10 -10
View File
@@ -217,18 +217,18 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu
golang.org/x/arch v0.12.0 h1:UsYJhbzPYGsT0HbEdmYcqtCv8UNGvnaL561NnIUvaKg=
golang.org/x/arch v0.12.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A=
golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70=
golang.org/x/crypto v0.35.0 h1:b15kiHdrGCHrP6LvwaQ3c03kgNhhiMgvlhxHQhmg2Xs=
golang.org/x/crypto v0.35.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ=
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 h1:985EYyeCOxTpcgOTJpflJUwOeEz0CQOdPt73OzpE9F8=
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI=
golang.org/x/image v0.23.0 h1:HseQ7c2OpPKTPVzNjG5fwJsOTCiiwS4QdsYi5XU6H68=
golang.org/x/image v0.23.0/go.mod h1:wJJBTdLfCCf3tiHa1fNxpZmUI4mmoZvwMCPP0ddoNKY=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE=
golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg=
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -239,14 +239,14 @@ golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s=
golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
+13 -1
View File
@@ -9,7 +9,6 @@ import (
"strings"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm"
)
@@ -24,6 +23,7 @@ type User struct {
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
Email string `json:"email" gorm:"index" validate:"max=50"`
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
OidcId string `json:"oidc_id" gorm:"column:oidc_id;index"`
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
TelegramId string `json:"telegram_id" gorm:"column:telegram_id;index"`
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
@@ -442,6 +442,14 @@ func (user *User) FillUserByGitHubId() error {
return nil
}
func (user *User) FillUserByOidcId() error {
if user.OidcId == "" {
return errors.New("oidc id 为空!")
}
DB.Where(User{OidcId: user.OidcId}).First(user)
return nil
}
func (user *User) FillUserByWeChatId() error {
if user.WeChatId == "" {
return errors.New("WeChat id 为空!")
@@ -473,6 +481,10 @@ func IsGitHubIdAlreadyTaken(githubId string) bool {
return DB.Unscoped().Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
}
func IsOidcIdAlreadyTaken(oidcId string) bool {
return DB.Where("oidc_id = ?", oidcId).Find(&User{}).RowsAffected == 1
}
func IsTelegramIdAlreadyTaken(telegramId string) bool {
return DB.Unscoped().Where("telegram_id = ?", telegramId).Find(&User{}).RowsAffected == 1
}
+21 -24
View File
@@ -144,11 +144,14 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
defer stream.Close()
c.Writer.Header().Set("Content-Type", "text/event-stream")
var usage relaymodel.Usage
var id string
var model string
claudeInfo := &claude.ClaudeResponseInfo{
ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Created: common.GetTimestamp(),
Model: info.UpstreamModelName,
ResponseText: strings.Builder{},
Usage: &relaymodel.Usage{},
}
isFirst := true
createdTime := common.GetTimestamp()
c.Stream(func(w io.Writer) bool {
event, ok := <-stream.Events()
if !ok {
@@ -161,33 +164,19 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
isFirst = false
info.FirstResponseTime = time.Now()
}
claudeResp := new(claude.ClaudeResponse)
err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResp)
claudeResponse := new(claude.ClaudeResponse)
err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return false
}
response, claudeUsage := claude.StreamResponseClaude2OpenAI(requestMode, claudeResp)
if claudeUsage != nil {
usage.PromptTokens += claudeUsage.InputTokens
usage.CompletionTokens += claudeUsage.OutputTokens
}
response := claude.StreamResponseClaude2OpenAI(requestMode, claudeResponse)
if response == nil {
if !claude.FormatClaudeResponseInfo(RequestModeMessage, claudeResponse, response, claudeInfo) {
return true
}
if response.Id != "" {
id = response.Id
}
if response.Model != "" {
model = response.Model
}
response.Created = createdTime
response.Id = id
response.Model = model
jsonStr, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
@@ -203,8 +192,16 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
return false
}
})
if claudeInfo.Usage.PromptTokens == 0 {
//上游出错
}
if claudeInfo.Usage.CompletionTokens == 0 {
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
}
if info.ShouldIncludeUsage {
response := helper.GenerateFinalUsageResponse(id, createdTime, info.UpstreamModelName, usage)
response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
err := helper.ObjectData(c, response)
if err != nil {
common.SysError("send final response failed: " + err.Error())
@@ -217,5 +214,5 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
}
return nil, &usage
return nil, claudeInfo.Usage
}
+62 -50
View File
@@ -1,6 +1,7 @@
package claude
import (
"bytes"
"encoding/json"
"fmt"
"io"
@@ -290,9 +291,8 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
return &claudeRequest, nil
}
func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*dto.ChatCompletionsStreamResponse, *ClaudeUsage) {
func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.ChatCompletionsStreamResponse {
var response dto.ChatCompletionsStreamResponse
var claudeUsage *ClaudeUsage
response.Object = "chat.completion.chunk"
response.Model = claudeResponse.Model
response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
@@ -308,7 +308,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
if claudeResponse.Type == "message_start" {
response.Id = claudeResponse.Message.Id
response.Model = claudeResponse.Message.Model
claudeUsage = &claudeResponse.Message.Usage
//claudeUsage = &claudeResponse.Message.Usage
choice.Delta.SetContentString("")
choice.Delta.Role = "assistant"
} else if claudeResponse.Type == "content_block_start" {
@@ -325,7 +325,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
})
}
} else {
return nil, nil
return nil
}
} else if claudeResponse.Type == "content_block_delta" {
if claudeResponse.Delta != nil {
@@ -352,23 +352,20 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
if finishReason != "null" {
choice.FinishReason = &finishReason
}
claudeUsage = &claudeResponse.Usage
//claudeUsage = &claudeResponse.Usage
} else if claudeResponse.Type == "message_stop" {
return nil, nil
return nil
} else {
return nil, nil
return nil
}
}
if claudeUsage == nil {
claudeUsage = &ClaudeUsage{}
}
if len(tools) > 0 {
choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ...
choice.Delta.ToolCalls = tools
}
response.Choices = append(response.Choices, choice)
return &response, claudeUsage
return &response
}
func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.OpenAITextResponse {
@@ -437,49 +434,65 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
return &fullTextResponse
}
type ClaudeResponseInfo struct {
ResponseId string
Created int64
Model string
ResponseText strings.Builder
Usage *dto.Usage
}
func FormatClaudeResponseInfo(requestMode int, claudeResponse *ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool {
if oaiResponse == nil {
return false
}
if requestMode == RequestModeCompletion {
claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
} else {
if claudeResponse.Type == "message_start" {
// message_start, 获取usage
claudeInfo.ResponseId = claudeResponse.Message.Id
claudeInfo.Model = claudeResponse.Message.Model
claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
} else if claudeResponse.Type == "content_block_delta" {
claudeInfo.ResponseText.WriteString(claudeResponse.Delta.Text)
} else if claudeResponse.Type == "message_delta" {
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
claudeInfo.Usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
} else if claudeResponse.Type == "content_block_start" {
} else {
return false
}
}
oaiResponse.Id = claudeInfo.ResponseId
oaiResponse.Created = claudeInfo.Created
oaiResponse.Model = claudeInfo.Model
return true
}
func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
var usage *dto.Usage
usage = &dto.Usage{}
responseText := ""
createdTime := common.GetTimestamp()
claudeInfo := &ClaudeResponseInfo{
ResponseId: responseId,
Created: common.GetTimestamp(),
Model: info.UpstreamModelName,
ResponseText: strings.Builder{},
Usage: &dto.Usage{},
}
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
var claudeResponse ClaudeResponse
err := json.Unmarshal([]byte(data), &claudeResponse)
err := json.NewDecoder(bytes.NewReader(common.StringToByteSlice(data))).Decode(&claudeResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
if response == nil {
response := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
if !FormatClaudeResponseInfo(requestMode, &claudeResponse, response, claudeInfo) {
return true
}
if requestMode == RequestModeCompletion {
responseText += claudeResponse.Completion
responseId = response.Id
} else {
if claudeResponse.Type == "message_start" {
// message_start, 获取usage
responseId = claudeResponse.Message.Id
info.UpstreamModelName = claudeResponse.Message.Model
usage.PromptTokens = claudeUsage.InputTokens
} else if claudeResponse.Type == "content_block_delta" {
responseText += claudeResponse.Delta.Text
} else if claudeResponse.Type == "message_delta" {
usage.CompletionTokens = claudeUsage.OutputTokens
usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens
} else if claudeResponse.Type == "content_block_start" {
return true
} else {
return true
}
}
//response.Id = responseId
response.Id = responseId
response.Created = createdTime
response.Model = info.UpstreamModelName
err = helper.ObjectData(c, response)
if err != nil {
@@ -489,25 +502,24 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
})
if requestMode == RequestModeCompletion {
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
} else {
if usage.PromptTokens == 0 {
usage.PromptTokens = info.PromptTokens
if claudeInfo.Usage.PromptTokens == 0 {
//上游出错
}
if usage.CompletionTokens == 0 {
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, usage.PromptTokens)
if claudeInfo.Usage.CompletionTokens == 0 {
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
}
}
if info.ShouldIncludeUsage {
response := helper.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage)
response := helper.GenerateFinalUsageResponse(responseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
err := helper.ObjectData(c, response)
if err != nil {
common.SysError("send final response failed: " + err.Error())
}
}
helper.Done(c)
//resp.Body.Close()
return nil, usage
return nil, claudeInfo.Usage
}
func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+5
View File
@@ -18,6 +18,7 @@ import (
"one-api/relay/channel/lingyiwanwu"
"one-api/relay/channel/minimax"
"one-api/relay/channel/moonshot"
"one-api/relay/channel/xinference"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"strings"
@@ -251,6 +252,8 @@ func (a *Adaptor) GetModelList() []string {
return lingyiwanwu.ModelList
case common.ChannelTypeMiniMax:
return minimax.ModelList
case common.ChannelTypeXinference:
return xinference.ModelList
default:
return ModelList
}
@@ -266,6 +269,8 @@ func (a *Adaptor) GetChannelName() string {
return lingyiwanwu.ChannelName
case common.ChannelTypeMiniMax:
return minimax.ChannelName
case common.ChannelTypeXinference:
return xinference.ChannelName
default:
return ChannelName
}
+7
View File
@@ -0,0 +1,7 @@
package xinference
var ModelList = []string{
"bge-reranker-v2-m3",
}
var ChannelName = "xinference"
+3
View File
@@ -31,6 +31,7 @@ const (
APITypeVolcEngine
APITypeBaiduV2
APITypeOpenRouter
APITypeXinference
APITypeDummy // this one is only for count, do not add any channel after this
)
@@ -89,6 +90,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
apiType = APITypeBaiduV2
case common.ChannelTypeOpenRouter:
apiType = APITypeOpenRouter
case common.ChannelTypeXinference:
apiType = APITypeXinference
}
if apiType == -1 {
return APITypeOpenAI, false
+2 -2
View File
@@ -34,8 +34,6 @@ import (
func GetAdaptor(apiType int) channel.Adaptor {
switch apiType {
//case constant.APITypeAIProxyLibrary:
// return &aiproxy.Adaptor{}
case constant.APITypeAli:
return &ali.Adaptor{}
case constant.APITypeAnthropic:
@@ -86,6 +84,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
return &baidu_v2.Adaptor{}
case constant.APITypeOpenRouter:
return &openrouter.Adaptor{}
case constant.APITypeXinference:
return &openai.Adaptor{}
}
return nil
}
+1
View File
@@ -25,6 +25,7 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth)
apiRouter.GET("/oauth/oidc", middleware.CriticalRateLimit(), controller.OidcAuth)
apiRouter.GET("/oauth/linuxdo", middleware.CriticalRateLimit(), controller.LinuxdoOAuth)
apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode)
apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)
+25
View File
@@ -0,0 +1,25 @@
package system_setting
import "one-api/setting/config"
type OIDCSettings struct {
Enabled bool `json:"enabled"`
ClientId string `json:"client_id"`
ClientSecret string `json:"client_secret"`
WellKnown string `json:"well_known"`
AuthorizationEndpoint string `json:"authorization_endpoint"`
TokenEndpoint string `json:"token_endpoint"`
UserInfoEndpoint string `json:"user_info_endpoint"`
}
// 默认配置
var defaultOIDCSettings = OIDCSettings{}
func init() {
// 注册到全局配置管理器
config.GlobalConfig.Register("oidc", &defaultOIDCSettings)
}
func GetOIDCSettings() *OIDCSettings {
return &defaultOIDCSettings
}
+8
View File
@@ -156,6 +156,14 @@ function App() {
</Suspense>
}
/>
<Route
path='/oauth/oidc'
element={
<Suspense fallback={<Loading></Loading>}>
<OAuth2Callback type='oidc'></OAuth2Callback>
</Suspense>
}
/>
<Route
path='/oauth/linuxdo'
element={
+22 -24
View File
@@ -450,13 +450,6 @@ const ChannelsTable = () => {
dataIndex: 'operate',
render: (text, record, index) => {
if (record.children === undefined) {
// 构建模型测试菜单
const modelMenuItems = record.models.split(',').map(model => ({
node: 'item',
name: model,
onClick: () => testChannel(record, model)
}));
return (
<div>
<SplitButtonGroup
@@ -1566,8 +1559,9 @@ const ChannelsTable = () => {
<div style={{
display: 'grid',
gridTemplateColumns: 'repeat(auto-fill, minmax(180px, 1fr))',
gap: '10px'
gridTemplateColumns: 'repeat(auto-fill, minmax(200px, 1fr))',
gap: '12px',
marginBottom: '16px'
}}>
{currentTestChannel.models.split(',')
.filter(model => model.toLowerCase().includes(modelSearchKeyword.toLowerCase()))
@@ -1575,27 +1569,31 @@ const ChannelsTable = () => {
return (
<Button
key={index}
theme="light"
type="tertiary"
style={{
height: 'auto',
padding: '8px 12px',
textAlign: 'center',
}}
onClick={() => {
testChannel(currentTestChannel, model);
}}
>
{model}
</Button>
theme="light"
type="tertiary"
style={{
height: 'auto',
padding: '10px 12px',
textAlign: 'center',
whiteSpace: 'nowrap',
overflow: 'hidden',
textOverflow: 'ellipsis',
width: '100%',
borderRadius: '6px'
}}
onClick={() => {
testChannel(currentTestChannel, model);
}}
>
{model}
</Button>
);
})}
</div>
{/* 显示搜索结果数量 */}
{modelSearchKeyword && (
<Typography.Text type="secondary" style={{ marginTop: '16px', display: 'block' }}>
<Typography.Text type="secondary" style={{ display: 'block' }}>
{t('找到')} {currentTestChannel.models.split(',').filter(model =>
model.toLowerCase().includes(modelSearchKeyword.toLowerCase())
).length} {t('个模型')}
+14 -1
View File
@@ -9,7 +9,7 @@ import {
showSuccess,
updateAPI,
} from '../helpers';
import { onGitHubOAuthClicked, onLinuxDOOAuthClicked } from './utils';
import {onGitHubOAuthClicked, onOIDCClicked, onLinuxDOOAuthClicked} from './utils';
import Turnstile from 'react-turnstile';
import {
Button,
@@ -25,6 +25,7 @@ import Text from '@douyinfe/semi-ui/lib/es/typography/text';
import TelegramLoginButton from 'react-telegram-login';
import { IconGithubLogo, IconAlarm } from '@douyinfe/semi-icons';
import OIDCIcon from './OIDCIcon.js';
import WeChatIcon from './WeChatIcon';
import { setUserData } from '../helpers/data.js';
import LinuxDoIcon from './LinuxDoIcon.js';
@@ -229,6 +230,7 @@ const LoginForm = () => {
</Text>
</div>
{status.github_oauth ||
status.oidc_enabled ||
status.wechat_login ||
status.telegram_oauth ||
status.linuxdo_oauth ? (
@@ -254,6 +256,17 @@ const LoginForm = () => {
) : (
<></>
)}
{status.oidc_enabled ? (
<Button
type='primary'
icon={<OIDCIcon />}
onClick={() =>
onOIDCClicked(status.oidc_authorization_endpoint, status.oidc_client_id)
}
/>
) : (
<></>
)}
{status.linuxdo_oauth ? (
<Button
icon={<LinuxDoIcon />}
+22
View File
@@ -0,0 +1,22 @@
import React from 'react';
import { Icon } from '@douyinfe/semi-ui';
const OIDCIcon = (props) => {
function CustomIcon() {
return (
<svg t="1723135116886" className="icon" viewBox="0 0 1024 1024" version="1.1" xmlns="http://www.w3.org/2000/svg"
p-id="10969" width="1em" height="1em">
<path
d="M512 960C265 960 64 759 64 512S265 64 512 64s448 201 448 448-201 448-448 448z m0-882.6c-239.7 0-434.6 195-434.6 434.6s195 434.6 434.6 434.6 434.6-195 434.6-434.6S751.7 77.4 512 77.4z"
p-id="10970" fill="#2c2c2c" stroke="#2c2c2c" stroke-width="60"></path>
<path
d="M197.7 512c0-78.3 31.6-98.8 87.2-98.8 56.2 0 87.2 20.5 87.2 98.8s-31 98.8-87.2 98.8c-55.7 0-87.2-20.5-87.2-98.8z m130.4 0c0-46.8-7.8-64.5-43.2-64.5-35.2 0-42.9 17.7-42.9 64.5 0 47.1 7.8 63.7 42.9 63.7 35.4 0 43.2-16.6 43.2-63.7zM409.7 415.9h42.1V608h-42.1V415.9zM653.9 512c0 74.2-37.1 96.1-93.6 96.1h-65.9V415.9h65.9c56.5 0 93.6 16.1 93.6 96.1z m-43.5 0c0-49.3-17.7-60.6-52.3-60.6h-21.6v120.7h21.6c35.4 0 52.3-13.3 52.3-60.1zM686.5 512c0-74.2 36.3-98.8 92.7-98.8 18.3 0 33.2 2.2 44.8 6.4v36.3c-11.9-4.2-26-6.6-42.1-6.6-34.6 0-49.8 15.5-49.8 62.6 0 50.1 15.2 62.6 49.3 62.6 15.8 0 30.2-2.2 44.8-7.5v36c-11.3 4.7-28.5 8-46.8 8-56.1-0.2-92.9-18.7-92.9-99z"
p-id="10971" fill="#2c2c2c" stroke="#2c2c2c" stroke-width="20"></path>
</svg>
);
}
return <Icon svg={<CustomIcon />} />;
};
export default OIDCIcon;
+31 -1
View File
@@ -10,7 +10,7 @@ import {
} from '../helpers';
import Turnstile from 'react-turnstile';
import {UserContext} from '../context/User';
import {onGitHubOAuthClicked, onLinuxDOOAuthClicked} from './utils';
import {onGitHubOAuthClicked, onOIDCClicked, onLinuxDOOAuthClicked} from './utils';
import {
Avatar,
Banner,
@@ -640,6 +640,36 @@ const PersonalSetting = () => {
</div>
</div>
</div>
<div style={{marginTop: 10}}>
<Typography.Text strong>{t('OIDC')}</Typography.Text>
<div
style={{display: 'flex', justifyContent: 'space-between'}}
>
<div>
<Input
value={
userState.user && userState.user.oidc_id !== ''
? userState.user.oidc_id
: t('未绑定')
}
readonly={true}
></Input>
</div>
<div>
<Button
onClick={() => {
onOIDCClicked(status.oidc_authorization_endpoint, status.oidc_client_id);
}}
disabled={
(userState.user && userState.user.oidc_id !== '') ||
!status.oidc_enabled
}
>
{status.oidc_enabled ? t('绑定') : t('未启用')}
</Button>
</div>
</div>
</div>
<div style={{marginTop: 10}}>
<Typography.Text strong>{t('Telegram')}</Typography.Text>
<div
+14 -1
View File
@@ -6,7 +6,8 @@ import { Button, Card, Divider, Form, Icon, Layout, Modal } from '@douyinfe/semi
import Title from '@douyinfe/semi-ui/lib/es/typography/title';
import Text from '@douyinfe/semi-ui/lib/es/typography/text';
import { IconGithubLogo } from '@douyinfe/semi-icons';
import { onGitHubOAuthClicked, onLinuxDOOAuthClicked } from './utils.js';
import {onGitHubOAuthClicked, onLinuxDOOAuthClicked, onOIDCClicked} from './utils.js';
import OIDCIcon from "./OIDCIcon.js";
import LinuxDoIcon from './LinuxDoIcon.js';
import WeChatIcon from './WeChatIcon.js';
import TelegramLoginButton from 'react-telegram-login/src';
@@ -262,6 +263,7 @@ const RegisterForm = () => {
</Text>
</div>
{status.github_oauth ||
status.oidc_enabled ||
status.wechat_login ||
status.telegram_oauth ||
status.linuxdo_oauth ? (
@@ -287,6 +289,17 @@ const RegisterForm = () => {
) : (
<></>
)}
{status.oidc_enabled ? (
<Button
type='primary'
icon={<OIDCIcon />}
onClick={() =>
onOIDCClicked(status.oidc_authorization_endpoint, status.oidc_client_id)
}
/>
) : (
<></>
)}
{status.linuxdo_oauth ? (
<Button
icon={<LinuxDoIcon />}
+122 -2
View File
@@ -8,7 +8,7 @@ import {
Message,
Modal,
} from 'semantic-ui-react';
import { API, removeTrailingSlash, showError, verifyJSON } from '../helpers';
import { API, removeTrailingSlash, showError, showSuccess, verifyJSON } from '../helpers';
import { useTheme } from '../context/Theme';
@@ -20,6 +20,13 @@ const SystemSetting = () => {
GitHubOAuthEnabled: '',
GitHubClientId: '',
GitHubClientSecret: '',
'oidc.enabled': '',
'oidc.client_id': '',
'oidc.client_secret': '',
'oidc.well_known': '',
'oidc.authorization_endpoint': '',
'oidc.token_endpoint': '',
'oidc.user_info_endpoint': '',
Notice: '',
SMTPServer: '',
SMTPPort: '',
@@ -106,6 +113,7 @@ const SystemSetting = () => {
case 'PasswordRegisterEnabled':
case 'EmailVerificationEnabled':
case 'GitHubOAuthEnabled':
case 'oidc.enabled':
case 'LinuxDOOAuthEnabled':
case 'WeChatAuthEnabled':
case 'TelegramOAuthEnabled':
@@ -159,6 +167,12 @@ const SystemSetting = () => {
name === 'PayAddress' ||
name === 'GitHubClientId' ||
name === 'GitHubClientSecret' ||
name === 'oidc.well_known' ||
name === 'oidc.client_id' ||
name === 'oidc.client_secret' ||
name === 'oidc.authorization_endpoint' ||
name === 'oidc.token_endpoint' ||
name === 'oidc.user_info_endpoint' ||
name === 'WeChatServerAddress' ||
name === 'WeChatServerToken' ||
name === 'WeChatAccountQRCodeImageURL' ||
@@ -286,6 +300,44 @@ const SystemSetting = () => {
}
};
const submitOIDCSettings = async () => {
if (inputs['oidc.well_known'] !== '') {
if (!inputs['oidc.well_known'].startsWith('http://') && !inputs['oidc.well_known'].startsWith('https://')) {
showError('Well-Known URL 必须以 http:// 或 https:// 开头');
return;
}
try {
const res = await API.get(inputs['oidc.well_known']);
inputs['oidc.authorization_endpoint'] = res.data['authorization_endpoint'];
inputs['oidc.token_endpoint'] = res.data['token_endpoint'];
inputs['oidc.user_info_endpoint'] = res.data['userinfo_endpoint'];
showSuccess('获取 OIDC 配置成功!');
} catch (err) {
console.error(err);
showError("获取 OIDC 配置失败,请检查网络状况和 Well-Known URL 是否正确");
}
}
if (originInputs['oidc.well_known'] !== inputs['oidc.well_known']) {
await updateOption('oidc.well_known', inputs['oidc.well_known']);
}
if (originInputs['oidc.client_id'] !== inputs['oidc.client_id']) {
await updateOption('oidc.client_id', inputs['oidc.client_id']);
}
if (originInputs['oidc.client_secret'] !== inputs['oidc.client_secret'] && inputs['oidc.client_secret'] !== '') {
await updateOption('oidc.client_secret', inputs['oidc.client_secret']);
}
if (originInputs['oidc.authorization_endpoint'] !== inputs['oidc.authorization_endpoint']) {
await updateOption('oidc.authorization_endpoint', inputs['oidc.authorization_endpoint']);
}
if (originInputs['oidc.token_endpoint'] !== inputs['oidc.token_endpoint']) {
await updateOption('oidc.token_endpoint', inputs['oidc.token_endpoint']);
}
if (originInputs['oidc.user_info_endpoint'] !== inputs['oidc.user_info_endpoint']) {
await updateOption('oidc.user_info_endpoint', inputs['oidc.user_info_endpoint']);
}
}
const submitTelegramSettings = async () => {
// await updateOption('TelegramOAuthEnabled', inputs.TelegramOAuthEnabled);
await updateOption('TelegramBotToken', inputs.TelegramBotToken);
@@ -370,7 +422,7 @@ const SystemSetting = () => {
</Header>
<Message info>
注意代理功能仅对图片请求和 Webhook 请求生效不会影响其他 API 请求如需配置 API 请求代理请参考
<a
<a
href='https://github.com/Calcium-Ion/new-api/blob/main/docs/channel/other_setting.md'
target='_blank'
rel='noreferrer'
@@ -518,6 +570,12 @@ const SystemSetting = () => {
name='GitHubOAuthEnabled'
onChange={handleInputChange}
/>
<Form.Checkbox
checked={inputs['oidc.enabled'] === 'true'}
label='允许通过 OIDC 登录 & 注册'
name='oidc.enabled'
onChange={handleInputChange}
/>
<Form.Checkbox
checked={inputs.LinuxDOOAuthEnabled === 'true'}
label='允许通过 LinuxDO 账户登录 & 注册'
@@ -864,6 +922,68 @@ const SystemSetting = () => {
<Form.Button onClick={submitLinuxDOOAuth}>
保存 LinuxDO OAuth 设置
</Form.Button>
<Divider />
<Header as='h3' inverted={isDark}>
配置 OIDC
<Header.Subheader>
用以支持通过 OIDC 登录例如 OktaAuth0 等兼容 OIDC 协议的 IdP
</Header.Subheader>
</Header>
<Message>
主页链接填 <code>{ inputs.ServerAddress }</code>
重定向 URL <code>{ `${ inputs.ServerAddress }/oauth/oidc` }</code>
</Message>
<Message>
若你的 OIDC Provider 支持 Discovery Endpoint你可以仅填写 OIDC Well-Known URL系统会自动获取 OIDC 配置
</Message>
<Form.Group widths={3}>
<Form.Input
label='Client ID'
name='oidc.client_id'
onChange={handleInputChange}
value={inputs['oidc.client_id']}
placeholder='输入 OIDC 的 Client ID'
/>
<Form.Input
label='Client Secret'
name='oidc.client_secret'
onChange={handleInputChange}
type='password'
value={inputs['oidc.client_secret']}
placeholder='敏感信息不会发送到前端显示'
/>
<Form.Input
label='Well-Known URL'
name='oidc.well_known'
onChange={handleInputChange}
value={inputs['oidc.well_known']}
placeholder='请输入 OIDC 的 Well-Known URL'
/>
<Form.Input
label='Authorization Endpoint'
name='oidc.authorization_endpoint'
onChange={handleInputChange}
value={inputs['oidc.authorization_endpoint']}
placeholder='输入 OIDC 的 Authorization Endpoint'
/>
<Form.Input
label='Token Endpoint'
name='oidc.token_endpoint'
onChange={handleInputChange}
value={inputs['oidc.token_endpoint']}
placeholder='输入 OIDC 的 Token Endpoint'
/>
<Form.Input
label='Userinfo Endpoint'
name='oidc.user_info_endpoint'
onChange={handleInputChange}
value={inputs['oidc.user_info_endpoint']}
placeholder='输入 OIDC 的 Userinfo Endpoint'
/>
</Form.Group>
<Form.Button onClick={submitOIDCSettings}>
保存 OIDC 设置
</Form.Button>
</Form>
</Grid.Column>
</Grid>
+15
View File
@@ -16,6 +16,21 @@ export async function getOAuthState() {
}
}
export async function onOIDCClicked(auth_url, client_id, openInNewTab = false) {
const state = await getOAuthState();
if (!state) return;
const redirect_uri = `${window.location.origin}/oauth/oidc`;
const response_type = "code";
const scope = "openid profile email";
const url = `${auth_url}?client_id=${client_id}&redirect_uri=${redirect_uri}&response_type=${response_type}&scope=${scope}&state=${state}`;
if (openInNewTab) {
window.open(url);
} else
{
window.location.href = url;
}
}
export async function onGitHubOAuthClicked(github_client_id) {
const state = await getOAuthState();
if (!state) return;
+9 -3
View File
@@ -80,11 +80,12 @@ export const CHANNEL_OPTIONS = [
label: 'Google PaLM2'
},
{
value: 45,
value: 47,
color: 'blue',
label: '字节火山方舟、豆包、DeepSeek通用'
label: 'Xinference'
},
{ value: 25, color: 'green', label: 'Moonshot' },
{ value: 20, color: 'green', label: 'OpenRouter' },
{ value: 19, color: 'blue', label: '360 智脑' },
{ value: 23, color: 'teal', label: '腾讯混元' },
{ value: 31, color: 'green', label: '零一万物' },
@@ -108,5 +109,10 @@ export const CHANNEL_OPTIONS = [
value: 44,
color: 'purple',
label: '嵌入模型:MokaAI M3E'
}
},
{
value: 45,
color: 'blue',
label: '字节火山方舟、豆包、DeepSeek通用'
},
];
+6
View File
@@ -151,6 +151,12 @@ const Home = () => {
? t('已启用')
: t('未启用')}
</p>
<p>
{t('OIDC 身份验证')}
{statusState?.status?.oidc === true
? t('已启用')
: t('未启用')}
</p>
<p>
{t('微信身份验证')}
{statusState?.status?.wechat_login === true
+11
View File
@@ -26,6 +26,7 @@ const EditUser = (props) => {
display_name: '',
password: '',
github_id: '',
oidc_id: '',
wechat_id: '',
email: '',
quota: 0,
@@ -37,6 +38,7 @@ const EditUser = (props) => {
display_name,
password,
github_id,
oidc_id,
wechat_id,
telegram_id,
email,
@@ -232,6 +234,15 @@ const EditUser = (props) => {
placeholder={t('此项只读,需要用户通过个人设置页面的相关绑定按钮进行绑定,不可直接修改')}
readonly
/>
<div style={{ marginTop: 20 }}>
<Typography.Text>{t('`已绑定的 OIDC 账户')}</Typography.Text>
</div>
<Input
name='oidc_id'
value={oidc_id}
placeholder={t('此项只读,需要用户通过个人设置页面的相关绑定按钮进行绑定,不可直接修改')}
readonly
/>
<div style={{ marginTop: 20 }}>
<Typography.Text>{t('已绑定的微信账户')}</Typography.Text>
</div>