Files
new-api/middleware/auth.go
T

499 lines
14 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package middleware
import (
"errors"
"fmt"
"net"
"net/http"
"strconv"
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/i18n"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting/ratio_setting"
"github.com/QuantumNous/new-api/types"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
// validUserInfo reports whether the given identity is usable: the username
// must contain at least one non-whitespace character and the role must be
// accepted by common.IsValidateRole.
func validUserInfo(username string, role int) bool {
	hasName := strings.TrimSpace(username) != ""
	// Short-circuit keeps the role check from running for blank usernames.
	return hasName && common.IsValidateRole(role)
}
// toInt safely converts session values to int.
//
// Session stores may return int, int64, float64, etc. depending on the
// backend. The conversion succeeds only when the value is representable as an
// int without loss: fractional floats, NaN, infinities and out-of-range
// values are rejected instead of being silently truncated, so that e.g. a
// stored id of 5.7 can never be mistaken for user 5.
//
// It returns (0, false) for nil, for unsupported types and for any value that
// cannot be converted exactly.
func toInt(v interface{}) (int, bool) {
	var f float64
	switch n := v.(type) {
	case int:
		return n, true
	case int32:
		return int(n), true
	case int64:
		// Round-trip check guards against wrap-around where int is 32 bits.
		if int64(int(n)) != n {
			return 0, false
		}
		return int(n), true
	case float32:
		f = float64(n)
	case float64:
		f = n
	default:
		// Covers nil as well as every unsupported dynamic type.
		return 0, false
	}

	// Both bounds are powers of two and therefore exact in float64. Written
	// as a negated conjunction so that NaN (which fails every comparison) is
	// rejected too.
	const lower, upper = -(1 << 63), 1 << 63
	if !(f >= lower && f < upper) {
		return 0, false
	}
	whole := int64(f)
	if float64(whole) != f {
		// The value has a fractional part.
		return 0, false
	}
	result := int(whole)
	if int64(result) != whole {
		return 0, false
	}
	return result, true
}
// authHelper authenticates a request and enforces a minimum role.
//
// Identity is read from the session. When the session carries no username,
// the Authorization header is validated as an access token instead. In both
// cases the request must also send a New-Api-User header whose integer value
// equals the authenticated user id.
//
// On any failure a JSON body of the form {"success": false, "message": ...}
// is written and the handler chain is aborted. On success the identity is
// stored on the gin context ("username", "role", "id", "group", "user_group",
// "use_access_token") and the remaining handlers are run.
func authHelper(c *gin.Context, minRole int) {
	// deny writes the standard failure payload and stops the handler chain.
	deny := func(code int, message string) {
		c.JSON(code, gin.H{
			"success": false,
			"message": message,
		})
		c.Abort()
	}

	session := sessions.Default(c)
	var (
		username       = session.Get("username")
		role           = session.Get("role")
		id             = session.Get("id")
		status         = session.Get("status")
		useAccessToken = false
	)

	if username == nil {
		// No session identity: fall back to the access token.
		accessToken := c.Request.Header.Get("Authorization")
		if accessToken == "" {
			deny(http.StatusUnauthorized, common.TranslateMessage(c, i18n.MsgAuthNotLoggedIn))
			return
		}

		user, authErr := model.ValidateAccessToken(accessToken)
		switch {
		case authErr != nil && errors.Is(authErr, model.ErrDatabase):
			common.SysLog("ValidateAccessToken database error: " + authErr.Error())
			deny(http.StatusInternalServerError, common.TranslateMessage(c, i18n.MsgDatabaseError))
			return
		case authErr != nil:
			deny(http.StatusOK, common.TranslateMessage(c, i18n.MsgAuthAccessTokenInvalid))
			return
		case user == nil || user.Username == "":
			deny(http.StatusOK, common.TranslateMessage(c, i18n.MsgAuthAccessTokenInvalid))
			return
		case !validUserInfo(user.Username, user.Role):
			deny(http.StatusOK, common.TranslateMessage(c, i18n.MsgAuthUserInfoInvalid))
			return
		}

		// Token is valid: the token's owner replaces the (absent) session identity.
		username, role, id, status = user.Username, user.Role, user.Id, user.Status
		useAccessToken = true
	}

	// The caller must state which user it is acting as via New-Api-User.
	rawUserId := c.Request.Header.Get("New-Api-User")
	if rawUserId == "" {
		deny(http.StatusUnauthorized, common.TranslateMessage(c, i18n.MsgAuthUserIdNotProvided))
		return
	}
	claimedId, err := strconv.Atoi(rawUserId)
	if err != nil {
		deny(http.StatusUnauthorized, common.TranslateMessage(c, i18n.MsgAuthUserIdFormatError))
		return
	}
	userId, idOK := toInt(id)
	if !idOK || userId != claimedId {
		deny(http.StatusUnauthorized, common.TranslateMessage(c, i18n.MsgAuthUserIdMismatch))
		return
	}

	// An unreadable status is logged and then treated like a disabled account.
	statusInt, statusOK := toInt(status)
	if !statusOK {
		common.SysLog(fmt.Sprintf("authHelper: invalid status type %T: %v", status, status))
	}
	if !statusOK || statusInt == common.UserStatusDisabled {
		deny(http.StatusOK, common.TranslateMessage(c, i18n.MsgAuthUserBanned))
		return
	}

	// An unreadable role is logged and then treated like an insufficient one.
	roleInt, roleOK := toInt(role)
	if !roleOK {
		common.SysLog(fmt.Sprintf("authHelper: invalid role type %T: %v", role, role))
	}
	if !roleOK || roleInt < minRole {
		deny(http.StatusOK, common.TranslateMessage(c, i18n.MsgAuthInsufficientPrivilege))
		return
	}

	name, nameOK := username.(string)
	if !nameOK {
		common.SysLog(fmt.Sprintf("authHelper: invalid username type %T: %v", username, username))
		deny(http.StatusOK, common.TranslateMessage(c, i18n.MsgAuthUserInfoInvalid))
		return
	}
	if !validUserInfo(name, roleInt) {
		deny(http.StatusOK, common.TranslateMessage(c, i18n.MsgAuthUserInfoInvalid))
		return
	}

	// Guards against conflicts between different new-api versions, which
	// would make their data incompatible.
	c.Header("Auth-Version", "864b7076dbcd0a3c01b5520316720ebf")

	group := session.Get("group")
	c.Set("username", name)
	c.Set("role", roleInt)
	c.Set("id", userId)
	c.Set("group", group)
	c.Set("user_group", group)
	c.Set("use_access_token", useAccessToken)

	// Fallback audit for admin/root write operations. It lives inside the
	// auth chain so that every write endpoint passing through
	// AdminAuth/RootAuth is recorded automatically, with no separate audit
	// middleware to mount on routes (and therefore none to forget).
	// Handlers that record their own audit entry set ContextKeyAuditLogged,
	// and finishAdminAudit skips based on that.
	var auditWriter *auditResponseWriter
	if minRole >= common.RoleAdminUser {
		auditWriter = beginAdminAudit(c)
	}
	c.Next()
	finishAdminAudit(c, auditWriter)
}
// TryUserAuth returns middleware that opportunistically copies the session's
// "id" and "role" onto the gin context when they convert to int via toInt.
// It never rejects a request: the next handler always runs.
func TryUserAuth() func(c *gin.Context) {
	return func(c *gin.Context) {
		session := sessions.Default(c)
		// toInt already reports false for missing (nil) values, so absent
		// keys are simply left unset on the context.
		for _, key := range []string{"id", "role"} {
			if value, ok := toInt(session.Get(key)); ok {
				c.Set(key, value)
			}
		}
		c.Next()
	}
}
// UserAuth returns middleware that runs authHelper with
// common.RoleCommonUser as the minimum required role.
func UserAuth() func(c *gin.Context) {
	handler := func(ctx *gin.Context) {
		authHelper(ctx, common.RoleCommonUser)
	}
	return handler
}
// AdminAuth returns middleware that runs authHelper with
// common.RoleAdminUser as the minimum required role.
func AdminAuth() func(c *gin.Context) {
	handler := func(ctx *gin.Context) {
		authHelper(ctx, common.RoleAdminUser)
	}
	return handler
}
// RootAuth returns middleware that runs authHelper with
// common.RoleRootUser as the minimum required role.
func RootAuth() func(c *gin.Context) {
	handler := func(ctx *gin.Context) {
		authHelper(ctx, common.RoleRootUser)
	}
	return handler
}
// WssAuth is an empty handler: it performs no checks and writes no response.
func WssAuth(c *gin.Context) {
}
// TokenOrUserAuth allows either session-based user auth or API token auth.
// Used for endpoints that need to be accessible from both the dashboard and API clients.
//
// Session values are read through toInt, consistent with authHelper and
// TryUserAuth, because session backends may hand back int64/float64 rather
// than int. The session path is taken only when both the id and the status
// convert to int and the status equals common.UserStatusEnabled; the id is
// then stored on the context as an int. In every other case the request is
// delegated to TokenAuth.
func TokenOrUserAuth() func(c *gin.Context) {
	return func(c *gin.Context) {
		// Try session auth first (dashboard users).
		session := sessions.Default(c)
		if id, ok := toInt(session.Get("id")); ok {
			if status, ok := toInt(session.Get("status")); ok && status == common.UserStatusEnabled {
				c.Set("id", id)
				c.Next()
				return
			}
		}
		// Fall back to token auth (API clients).
		TokenAuth()(c)
	}
}
// TokenAuthReadOnly is a relaxed token-auth middleware intended for
// read-only query endpoints.
// It only verifies that the token key exists; token status, expiry and quota
// are not checked, so expired, exhausted or disabled tokens are still let
// through.
// Whether the owning user is banned is still checked.
//
// On success it stores "id", "token_id" and "token_key" on the context and
// runs the next handler; on failure it writes a JSON error and aborts.
func TokenAuthReadOnly() func(c *gin.Context) {
	return func(c *gin.Context) {
		// fail writes the standard failure payload and stops the chain.
		fail := func(code int, message string) {
			c.JSON(code, gin.H{
				"success": false,
				"message": message,
			})
			c.Abort()
		}

		raw := c.Request.Header.Get("Authorization")
		if raw == "" {
			fail(http.StatusUnauthorized, common.TranslateMessage(c, i18n.MsgTokenNotProvided))
			return
		}

		// Normalise "Bearer sk-<key>-<suffix>" down to "<key>".
		if strings.HasPrefix(raw, "Bearer ") || strings.HasPrefix(raw, "bearer ") {
			raw = strings.TrimSpace(raw[7:])
		}
		key := strings.TrimPrefix(raw, "sk-")
		if dash := strings.IndexByte(key, '-'); dash >= 0 {
			key = key[:dash]
		}

		token, err := model.GetTokenByKey(key, false)
		if err != nil {
			if !errors.Is(err, gorm.ErrRecordNotFound) {
				common.SysLog("TokenAuthReadOnly GetTokenByKey database error: " + err.Error())
				fail(http.StatusInternalServerError, common.TranslateMessage(c, i18n.MsgDatabaseError))
				return
			}
			fail(http.StatusUnauthorized, common.TranslateMessage(c, i18n.MsgTokenInvalid))
			return
		}

		owner, cacheErr := model.GetUserCache(token.UserId)
		if cacheErr != nil {
			common.SysLog(fmt.Sprintf("TokenAuthReadOnly GetUserCache error for user %d: %v", token.UserId, cacheErr))
			fail(http.StatusInternalServerError, common.TranslateMessage(c, i18n.MsgDatabaseError))
			return
		}
		if owner.Status != common.UserStatusEnabled {
			fail(http.StatusForbidden, common.TranslateMessage(c, i18n.MsgAuthUserBanned))
			return
		}

		c.Set("id", token.UserId)
		c.Set("token_id", token.Id)
		c.Set("token_key", token.Key)
		c.Next()
	}
}
// TokenAuth returns the API-token authentication middleware.
//
// The token key is gathered from several provider-specific locations and
// normalised into the Authorization header before validation:
//   - the "openai-insecure-api-key.<key>" entry of Sec-WebSocket-Protocol;
//   - the x-api-key header, for paths containing /v1/messages or /v1/models;
//   - the "key" query parameter or the x-goog-api-key header, for Gemini-style
//     paths;
//   - the mj-api-secret header, when Authorization is empty or equals
//     "midjourney-proxy".
//
// After validating the token it enforces the token's IP allow-list, rejects
// requests whose owning user is not enabled, resolves the group to use and
// populates the context through SetupContextForToken. Failures are reported
// through abortWithOpenAiMessage.
func TokenAuth() func(c *gin.Context) {
	return func(c *gin.Context) {
		// stripBearer drops a leading "Bearer "/"bearer " scheme and trims
		// the remainder; other values are returned unchanged.
		stripBearer := func(value string) string {
			if strings.HasPrefix(value, "Bearer ") || strings.HasPrefix(value, "bearer ") {
				return strings.TrimSpace(value[7:])
			}
			return value
		}

		// WebSocket clients may carry the key in the subprotocol list, e.g.
		//   Sec-WebSocket-Protocol: realtime, openai-insecure-api-key.sk-xxx, openai-beta.realtime-v1
		// Authorization is only replaced when such an entry is present, so a
		// client that sends both a real Authorization header and unrelated
		// subprotocols keeps its credentials.
		if protocols := c.Request.Header.Get("Sec-WebSocket-Protocol"); protocols != "" {
			const keyPrefix = "openai-insecure-api-key."
			for _, part := range strings.Split(protocols, ",") {
				part = strings.TrimSpace(part)
				if strings.HasPrefix(part, keyPrefix) {
					c.Request.Header.Set("Authorization", "Bearer "+part[len(keyPrefix):])
					break
				}
			}
		}

		path := c.Request.URL.Path

		// Paths containing /v1/messages or /v1/models: accept x-api-key.
		if strings.Contains(path, "/v1/messages") || strings.Contains(path, "/v1/models") {
			if anthropicKey := c.Request.Header.Get("x-api-key"); anthropicKey != "" {
				c.Request.Header.Set("Authorization", "Bearer "+anthropicKey)
			}
		}

		// Gemini API: read the key from the query string, then from the
		// x-goog-api-key header (the header wins when both are present).
		if strings.HasPrefix(path, "/v1beta/models") ||
			strings.HasPrefix(path, "/v1beta/openai/models") ||
			strings.HasPrefix(path, "/v1/models/") {
			if queryKey := c.Query("key"); queryKey != "" {
				c.Request.Header.Set("Authorization", "Bearer "+queryKey)
			}
			if googKey := c.Request.Header.Get("x-goog-api-key"); googKey != "" {
				c.Request.Header.Set("Authorization", "Bearer "+googKey)
			}
		}

		key := stripBearer(c.Request.Header.Get("Authorization"))
		if key == "" || key == "midjourney-proxy" {
			key = stripBearer(c.Request.Header.Get("mj-api-secret"))
		}
		// "sk-<key>-<extra>": the first segment is the token key; the
		// remaining segments are handed to SetupContextForToken.
		key = strings.TrimPrefix(key, "sk-")
		parts := strings.Split(key, "-")
		key = parts[0]

		token, err := model.ValidateUserToken(key)
		if token != nil {
			// A token may be returned together with an error; record its
			// owner on the context unless an id is already set.
			if c.GetInt("id") == 0 {
				c.Set("id", token.UserId)
			}
		}
		if err != nil {
			if errors.Is(err, model.ErrDatabase) {
				common.SysLog("TokenAuth ValidateUserToken database error: " + err.Error())
				abortWithOpenAiMessage(c, http.StatusInternalServerError,
					common.TranslateMessage(c, i18n.MsgDatabaseError))
			} else {
				abortWithOpenAiMessage(c, http.StatusUnauthorized,
					common.TranslateMessage(c, i18n.MsgTokenInvalid))
			}
			return
		}

		if allowIps := token.GetIpLimits(); len(allowIps) > 0 {
			clientIp := c.ClientIP()
			logger.LogDebug(c, "Token has IP restrictions, checking client IP %s", clientIp)
			ip := net.ParseIP(clientIp)
			if ip == nil {
				abortWithOpenAiMessage(c, http.StatusForbidden, "无法解析客户端 IP 地址")
				return
			}
			if !common.IsIpInCIDRList(ip, allowIps) {
				abortWithOpenAiMessage(c, http.StatusForbidden, "您的 IP 不在令牌允许访问的列表中", types.ErrorCodeAccessDenied)
				return
			}
			logger.LogDebug(c, "Client IP %s passed the token IP restrictions check", clientIp)
		}

		userCache, err := model.GetUserCache(token.UserId)
		if err != nil {
			common.SysLog(fmt.Sprintf("TokenAuth GetUserCache error for user %d: %v", token.UserId, err))
			abortWithOpenAiMessage(c, http.StatusInternalServerError,
				common.TranslateMessage(c, i18n.MsgDatabaseError))
			return
		}
		if userCache.Status != common.UserStatusEnabled {
			abortWithOpenAiMessage(c, http.StatusForbidden, common.TranslateMessage(c, i18n.MsgAuthUserBanned))
			return
		}
		userCache.WriteContext(c)

		// A group set on the token overrides the user's own group, provided
		// the user may use it and the group still has a ratio configured
		// ("auto" is exempt from the ratio check).
		usingGroup := userCache.Group
		if tokenGroup := token.Group; tokenGroup != "" {
			if _, ok := service.GetUserUsableGroups(userCache.Group)[tokenGroup]; !ok {
				abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("无权访问 %s 分组", tokenGroup))
				return
			}
			if !ratio_setting.ContainsGroupRatio(tokenGroup) && tokenGroup != "auto" {
				abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
				return
			}
			usingGroup = tokenGroup
		}
		common.SetContextKey(c, constant.ContextKeyUsingGroup, usingGroup)

		if err := SetupContextForToken(c, token, parts...); err != nil {
			return
		}
		c.Next()
	}
}
// SetupContextForToken copies the token's properties onto the gin context.
//
// It sets "id", "token_id", "token_key", "token_name",
// "token_unlimited_quota", "token_quota" (only for limited tokens) and the
// model-limit keys, plus the token group and cross-group-retry context keys.
//
// parts holds the dash-separated segments of the presented key. When more
// than one segment is present, the second is stored as "specific_channel_id",
// but only if model.IsAdmin reports the token's owner as an admin; otherwise
// a 403 is sent through abortWithOpenAiMessage and an error is returned.
//
// An error is also returned, without any response being written here, when
// token is nil.
func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) error {
	if token == nil {
		return fmt.Errorf("token is nil")
	}

	c.Set("id", token.UserId)
	c.Set("token_id", token.Id)
	c.Set("token_key", token.Key)
	c.Set("token_name", token.Name)
	c.Set("token_unlimited_quota", token.UnlimitedQuota)
	if !token.UnlimitedQuota {
		c.Set("token_quota", token.RemainQuota)
	}

	if !token.ModelLimitsEnabled {
		c.Set("token_model_limit_enabled", false)
	} else {
		c.Set("token_model_limit_enabled", true)
		c.Set("token_model_limit", token.GetModelLimitsMap())
	}

	common.SetContextKey(c, constant.ContextKeyTokenGroup, token.Group)
	common.SetContextKey(c, constant.ContextKeyTokenCrossGroupRetry, token.CrossGroupRetry)

	// No extra key segment means no channel was requested.
	if len(parts) <= 1 {
		return nil
	}
	if !model.IsAdmin(token.UserId) {
		c.Header("specific_channel_version", "701e3ae1dc3f7975556d354e0675168d004891c8")
		abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
		return fmt.Errorf("普通用户不支持指定渠道")
	}
	c.Set("specific_channel_id", parts[1])
	return nil
}