813 lines
23 KiB
Go
813 lines
23 KiB
Go
package controller
|
||
|
||
import (
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"regexp"
	"strconv"
	"strings"
	"time"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/constant"
	"github.com/QuantumNous/new-api/dto"
	"github.com/QuantumNous/new-api/model"
	"github.com/QuantumNous/new-api/service"
	"github.com/QuantumNous/new-api/setting/operation_setting"
	"github.com/QuantumNous/new-api/types"

	"github.com/gin-gonic/gin"
	"github.com/shopspring/decimal"
)
|
||
|
||
// Upstream billing / balance response payloads. Each type mirrors the subset of
// one provider's JSON response that this file decodes.
type (
	// OpenAISubscriptionResponse is decoded from an OpenAI-compatible
	// /v1/dashboard/billing/subscription response.
	// Background: https://github.com/songquanpeng/one-api/issues/79
	OpenAISubscriptionResponse struct {
		Object             string  `json:"object"`
		HasPaymentMethod   bool    `json:"has_payment_method"`
		SoftLimitUSD       float64 `json:"soft_limit_usd"`
		HardLimitUSD       float64 `json:"hard_limit_usd"`
		SystemHardLimitUSD float64 `json:"system_hard_limit_usd"`
		AccessUntil        int64   `json:"access_until"`
	}

	// OpenAIUsageDailyCost looks like a per-day cost entry of an OpenAI usage
	// report. It is currently unreferenced (the DailyCosts field of
	// OpenAIUsageResponse is commented out).
	OpenAIUsageDailyCost struct {
		Timestamp float64 `json:"timestamp"`
		LineItems []struct {
			Name string  `json:"name"`
			Cost float64 `json:"cost"`
		}
	}

	// OpenAICreditGrants is decoded from a /dashboard/billing/credit_grants response.
	OpenAICreditGrants struct {
		Object         string  `json:"object"`
		TotalGranted   float64 `json:"total_granted"`
		TotalUsed      float64 `json:"total_used"`
		TotalAvailable float64 `json:"total_available"`
	}

	// OpenAIUsageResponse is decoded from a /v1/dashboard/billing/usage response.
	OpenAIUsageResponse struct {
		Object string `json:"object"`
		// DailyCosts []OpenAIUsageDailyCost `json:"daily_costs"`
		TotalUsage float64 `json:"total_usage"` // unit: 0.01 dollar
	}

	// OpenAISBUsageResponse is decoded from the openai-sb.com user status
	// endpoint. Data is nil when the upstream reports an error in Msg.
	OpenAISBUsageResponse struct {
		Msg  string `json:"msg"`
		Data *struct {
			Credit string `json:"credit"`
		} `json:"data"`
	}

	// AIProxyUserOverviewResponse is decoded from the aiproxy.io user overview endpoint.
	AIProxyUserOverviewResponse struct {
		Success   bool   `json:"success"`
		Message   string `json:"message"`
		ErrorCode int    `json:"error_code"`
		Data      struct {
			TotalPoints float64 `json:"totalPoints"`
		} `json:"data"`
	}

	// API2GPTUsageResponse is decoded from the api2gpt.com credit_grants endpoint.
	API2GPTUsageResponse struct {
		Object         string  `json:"object"`
		TotalGranted   float64 `json:"total_granted"`
		TotalUsed      float64 `json:"total_used"`
		TotalRemaining float64 `json:"total_remaining"`
	}

	// APGC2DGPTUsageResponse is decoded from the aigc2d.com credit_grants endpoint.
	APGC2DGPTUsageResponse struct {
		// Grants interface{} `json:"grants"`
		Object         string  `json:"object"`
		TotalAvailable float64 `json:"total_available"`
		TotalGranted   float64 `json:"total_granted"`
		TotalUsed      float64 `json:"total_used"`
	}

	// SiliconFlowUsageResponse is decoded from the SiliconFlow /v1/user/info
	// endpoint. Balances are transported as decimal strings.
	SiliconFlowUsageResponse struct {
		Code    int    `json:"code"`
		Message string `json:"message"`
		Status  bool   `json:"status"`
		Data    struct {
			ID            string `json:"id"`
			Name          string `json:"name"`
			Image         string `json:"image"`
			Email         string `json:"email"`
			IsAdmin       bool   `json:"isAdmin"`
			Balance       string `json:"balance"`
			Status        string `json:"status"`
			Introduction  string `json:"introduction"`
			Role          string `json:"role"`
			ChargeBalance string `json:"chargeBalance"`
			TotalBalance  string `json:"totalBalance"`
			Category      string `json:"category"`
		} `json:"data"`
	}

	// DeepSeekUsageResponse is decoded from the DeepSeek /user/balance endpoint;
	// one BalanceInfos entry per currency, amounts as decimal strings.
	DeepSeekUsageResponse struct {
		IsAvailable  bool `json:"is_available"`
		BalanceInfos []struct {
			Currency        string `json:"currency"`
			TotalBalance    string `json:"total_balance"`
			GrantedBalance  string `json:"granted_balance"`
			ToppedUpBalance string `json:"topped_up_balance"`
		} `json:"balance_infos"`
	}

	// OpenRouterCreditResponse is decoded from the OpenRouter /api/v1/credits endpoint.
	OpenRouterCreditResponse struct {
		Data struct {
			TotalCredits float64 `json:"total_credits"`
			TotalUsage   float64 `json:"total_usage"`
		} `json:"data"`
	}
)
|
||
|
||
// GetAuthHeader returns a header set carrying token as an HTTP Bearer
// credential in the Authorization header.
func GetAuthHeader(token string) http.Header {
	return http.Header{
		"Authorization": {fmt.Sprintf("Bearer %s", token)},
	}
}
|
||
|
||
// GetClaudeAuthHeader returns the header set used to authenticate against the
// Anthropic API: the key in x-api-key plus a pinned anthropic-version.
func GetClaudeAuthHeader(token string) http.Header {
	// Keys are written in canonical MIME form, which is what Header.Add stores.
	return http.Header{
		"X-Api-Key":         {token},
		"Anthropic-Version": {"2023-06-01"},
	}
}
|
||
|
||
// GetResponseBody sends a body-less request with the given method to url, through
// the HTTP client built from the channel's proxy setting, with every value of
// every header in headers copied onto the request.
//
// It returns the complete response body when the upstream answers 200 OK. Any
// other status is reported as an error of the form "status code: N". The
// response body is closed on every path.
func GetResponseBody(method, url string, channel *model.Channel, headers http.Header) ([]byte, error) {
	req, err := http.NewRequest(method, url, nil)
	if err != nil {
		return nil, err
	}
	// Copy all values, not just the first one per key.
	for key, values := range headers {
		for _, value := range values {
			req.Header.Add(key, value)
		}
	}

	client, err := service.NewProxyHttpClient(channel.GetSetting().Proxy)
	if err != nil {
		return nil, err
	}
	res, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	// Close on every path, including non-200 responses, so the connection is
	// not leaked.
	defer res.Body.Close()

	if res.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("status code: %d", res.StatusCode)
	}
	body, err := io.ReadAll(res.Body)
	if err != nil {
		return nil, err
	}
	return body, nil
}
|
||
|
||
func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) {
|
||
url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.GetBaseURL())
|
||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
response := OpenAICreditGrants{}
|
||
err = json.Unmarshal(body, &response)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
channel.UpdateBalance(response.TotalAvailable)
|
||
return response.TotalAvailable, nil
|
||
}
|
||
|
||
// updateChannelOpenAISBBalance fetches the credit balance from api.openai-sb.com.
// The key is passed both as the api_key query parameter and as a Bearer token.
// The credit arrives as a decimal string. It is parsed, stored as the channel
// balance and returned.
func updateChannelOpenAISBBalance(channel *model.Channel) (float64, error) {
	// Escape the key so that reserved characters ('+', '&', '=', '#', ...) are
	// not interpreted as query-string syntax.
	endpoint := "https://api.openai-sb.com/sb-api/user/status?api_key=" + url.QueryEscape(channel.Key)
	body, err := GetResponseBody("GET", endpoint, channel, GetAuthHeader(channel.Key))
	if err != nil {
		return 0, err
	}

	response := OpenAISBUsageResponse{}
	if err := json.Unmarshal(body, &response); err != nil {
		return 0, err
	}
	// A missing data object means the upstream reported an error in msg.
	if response.Data == nil {
		return 0, errors.New(response.Msg)
	}

	balance, err := strconv.ParseFloat(response.Data.Credit, 64)
	if err != nil {
		return 0, err
	}
	channel.UpdateBalance(balance)
	return balance, nil
}
|
||
|
||
func updateChannelAIProxyBalance(channel *model.Channel) (float64, error) {
|
||
url := "https://aiproxy.io/api/report/getUserOverview"
|
||
headers := http.Header{}
|
||
headers.Add("Api-Key", channel.Key)
|
||
body, err := GetResponseBody("GET", url, channel, headers)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
response := AIProxyUserOverviewResponse{}
|
||
err = json.Unmarshal(body, &response)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
if !response.Success {
|
||
return 0, fmt.Errorf("code: %d, message: %s", response.ErrorCode, response.Message)
|
||
}
|
||
channel.UpdateBalance(response.Data.TotalPoints)
|
||
return response.Data.TotalPoints, nil
|
||
}
|
||
|
||
func updateChannelAPI2GPTBalance(channel *model.Channel) (float64, error) {
|
||
url := "https://api.api2gpt.com/dashboard/billing/credit_grants"
|
||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
response := API2GPTUsageResponse{}
|
||
err = json.Unmarshal(body, &response)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
channel.UpdateBalance(response.TotalRemaining)
|
||
return response.TotalRemaining, nil
|
||
}
|
||
|
||
func updateChannelSiliconFlowBalance(channel *model.Channel) (float64, error) {
|
||
url := "https://api.siliconflow.cn/v1/user/info"
|
||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
response := SiliconFlowUsageResponse{}
|
||
err = json.Unmarshal(body, &response)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
if response.Code != 20000 {
|
||
return 0, fmt.Errorf("code: %d, message: %s", response.Code, response.Message)
|
||
}
|
||
balance, err := strconv.ParseFloat(response.Data.TotalBalance, 64)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
channel.UpdateBalance(balance)
|
||
return balance, nil
|
||
}
|
||
|
||
func updateChannelDeepSeekBalance(channel *model.Channel) (float64, error) {
|
||
url := "https://api.deepseek.com/user/balance"
|
||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
response := DeepSeekUsageResponse{}
|
||
err = json.Unmarshal(body, &response)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
index := -1
|
||
for i, balanceInfo := range response.BalanceInfos {
|
||
if balanceInfo.Currency == "CNY" {
|
||
index = i
|
||
break
|
||
}
|
||
}
|
||
if index == -1 {
|
||
return 0, errors.New("currency CNY not found")
|
||
}
|
||
balance, err := strconv.ParseFloat(response.BalanceInfos[index].TotalBalance, 64)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
channel.UpdateBalance(balance)
|
||
return balance, nil
|
||
}
|
||
|
||
func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
|
||
url := "https://api.aigc2d.com/dashboard/billing/credit_grants"
|
||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
response := APGC2DGPTUsageResponse{}
|
||
err = json.Unmarshal(body, &response)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
channel.UpdateBalance(response.TotalAvailable)
|
||
return response.TotalAvailable, nil
|
||
}
|
||
|
||
func updateChannelOpenRouterBalance(channel *model.Channel) (float64, error) {
|
||
url := "https://openrouter.ai/api/v1/credits"
|
||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
response := OpenRouterCreditResponse{}
|
||
err = json.Unmarshal(body, &response)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
balance := response.Data.TotalCredits - response.Data.TotalUsage
|
||
channel.UpdateBalance(balance)
|
||
return balance, nil
|
||
}
|
||
|
||
// updateChannelMoonshotBalance fetches the account balance from Moonshot's
// /v1/users/me/balance endpoint. It converts data.available_balance (treated as
// CNY) to USD by dividing by operation_setting.Price, then stores the result as
// the channel balance and returns it.
//
// An error is returned when the upstream reports status=false or a non-zero
// code, or when operation_setting.Price is not positive.
func updateChannelMoonshotBalance(channel *model.Channel) (float64, error) {
	const endpoint = "https://api.moonshot.cn/v1/users/me/balance"

	body, err := GetResponseBody("GET", endpoint, channel, GetAuthHeader(channel.Key))
	if err != nil {
		return 0, err
	}

	type MoonshotBalanceData struct {
		AvailableBalance float64 `json:"available_balance"`
		VoucherBalance   float64 `json:"voucher_balance"`
		CashBalance      float64 `json:"cash_balance"`
	}

	type MoonshotBalanceResponse struct {
		Code   int                 `json:"code"`
		Data   MoonshotBalanceData `json:"data"`
		Scode  string              `json:"scode"`
		Status bool                `json:"status"`
	}

	response := MoonshotBalanceResponse{}
	if err := json.Unmarshal(body, &response); err != nil {
		return 0, err
	}
	if !response.Status || response.Code != 0 {
		return 0, fmt.Errorf("failed to update moonshot balance, status: %v, code: %d, scode: %s", response.Status, response.Code, response.Scode)
	}

	// decimal.Div panics on a zero divisor. This function is also reached from
	// the background refresh loop, where an unrecovered panic would terminate
	// the process, so a non-positive price is rejected as an error instead.
	price := float64(operation_setting.Price)
	if price <= 0 {
		return 0, fmt.Errorf("failed to update moonshot balance: invalid price setting %v", price)
	}

	availableBalanceCny := response.Data.AvailableBalance
	availableBalanceUsd := decimal.NewFromFloat(availableBalanceCny).Div(decimal.NewFromFloat(price)).InexactFloat64()
	channel.UpdateBalance(availableBalanceUsd)
	return availableBalanceUsd, nil
}
|
||
|
||
// updateChannelBalanceCustom obtains the upstream account balance using the
// channel's custom balance-query configuration.
//
// Flow: (optional) fetch a CSRF token → log in to obtain a token and/or cookies
// → query the balance endpoint with them → extract and convert the balance
// (done by queryBalance).
//
// Empty config fields are first filled from the selected preset template via
// applyBalanceTemplate, which mutates config.
func updateChannelBalanceCustom(channel *model.Channel, config *dto.BalanceConfig) (float64, error) {
	if config == nil || config.Template == dto.BalanceTemplateNone {
		return 0, fmt.Errorf("未配置余额查询模板")
	}

	// Apply the preset template defaults.
	applyBalanceTemplate(channel.GetBaseURL(), config)

	if config.LoginURL == "" || config.BalanceURL == "" {
		return 0, fmt.Errorf("余额查询配置不完整")
	}
	if config.Username == "" || config.Password == "" {
		return 0, fmt.Errorf("请填写登录账号和密码")
	}

	client, err := service.NewProxyHttpClient(channel.GetSetting().Proxy)
	if err != nil {
		return 0, fmt.Errorf("创建HTTP客户端失败: %w", err)
	}

	// Step 0 (optional): fetch the CSRF token and the cookies set with it.
	var csrfToken string
	var csrfCookies []*http.Cookie
	if config.CsrfURL != "" {
		csrfToken, csrfCookies, err = fetchBalanceCSRF(client, config.CsrfURL, config.CsrfRegex)
		if err != nil {
			return 0, err
		}
	}

	// Step 1: log in to obtain a token and/or cookies.
	contentType := config.LoginContentType
	if contentType == "" {
		contentType = "application/json"
	}
	// Normalise the method so that e.g. "post" still sends the body.
	loginMethod := strings.ToUpper(config.LoginMethod)
	if loginMethod == "" {
		loginMethod = http.MethodPost
	}
	loginBody := renderBalanceLoginBody(config.LoginBody, contentType, config.Username, config.Password, csrfToken)

	// A body is only sent for POST with a non-empty rendered template. The
	// interface stays nil otherwise, so no typed-nil reader is passed.
	var bodyReader io.Reader
	if loginMethod == http.MethodPost && loginBody != "" {
		bodyReader = strings.NewReader(loginBody)
	}
	loginReq, err := http.NewRequest(loginMethod, config.LoginURL, bodyReader)
	if err != nil {
		return 0, fmt.Errorf("创建登录请求失败: %w", err)
	}
	if bodyReader != nil {
		loginReq.Header.Set("Content-Type", contentType)
	}
	// Carry the cookies issued by the CSRF page (session binding).
	for _, cookie := range csrfCookies {
		loginReq.AddCookie(cookie)
	}

	loginResp, err := client.Do(loginReq)
	if err != nil {
		return 0, fmt.Errorf("登录请求失败: %w", err)
	}
	loginRespBody, err := io.ReadAll(loginResp.Body)
	// Close now rather than deferring, so the connection is released before the
	// balance request. Cookies() below only reads headers.
	_ = loginResp.Body.Close()
	if err != nil {
		return 0, fmt.Errorf("读取登录响应失败: %w", err)
	}

	if loginResp.StatusCode != http.StatusOK {
		if loginResp.StatusCode == http.StatusNotFound {
			return 0, fmt.Errorf("登录失败 (404): 上游未转发API路径,请确认上游反向代理是否转发了 /api/ 路径,或手动填写完整的登录URL")
		}
		return 0, fmt.Errorf("登录失败 (status: %d): %s", loginResp.StatusCode, string(loginRespBody))
	}

	// Extract the token, if a token path is configured.
	var token string
	if config.TokenPath != "" {
		token = extractJSONValue(loginRespBody, config.TokenPath)
	}

	// Collect every cookie: the CSRF page's plus those set by the login response.
	var allCookies []*http.Cookie
	allCookies = append(allCookies, csrfCookies...)
	allCookies = append(allCookies, loginResp.Cookies()...)

	// Step 2: query the balance.
	return queryBalance(channel, client, config, token, allCookies)
}

// fetchBalanceCSRF GETs csrfURL and returns the cookies it set. When csrfRegex
// is non-empty, it also returns the token captured by the regex's first
// submatch group in the response body.
func fetchBalanceCSRF(client *http.Client, csrfURL, csrfRegex string) (string, []*http.Cookie, error) {
	// Validate the pattern before doing any network I/O.
	var re *regexp.Regexp
	if csrfRegex != "" {
		compiled, err := regexp.Compile(csrfRegex)
		if err != nil {
			return "", nil, fmt.Errorf("CSRF正则表达式无效: %w", err)
		}
		re = compiled
	}

	req, err := http.NewRequest(http.MethodGet, csrfURL, nil)
	if err != nil {
		return "", nil, fmt.Errorf("创建CSRF请求失败: %w", err)
	}
	resp, err := client.Do(req)
	if err != nil {
		return "", nil, fmt.Errorf("获取CSRF页面失败: %w", err)
	}
	page, err := io.ReadAll(resp.Body)
	_ = resp.Body.Close()
	if err != nil {
		return "", nil, fmt.Errorf("读取CSRF页面失败: %w", err)
	}
	cookies := resp.Cookies()

	if re == nil {
		return "", cookies, nil
	}
	matches := re.FindSubmatch(page)
	if len(matches) < 2 {
		return "", nil, fmt.Errorf("未能从页面中提取CSRF token")
	}
	return string(matches[1]), cookies, nil
}

// renderBalanceLoginBody fills the {{username}}, {{password}} and {{csrf_token}}
// placeholders in tmpl in a single pass.
//
// Each value is escaped for the request's content type: JSON-string escaping
// for JSON bodies, and query escaping for application/x-www-form-urlencoded.
// This stops credentials containing quotes, backslashes, '&', '+', '=' and so
// on from corrupting the body. Other content types get the raw values.
func renderBalanceLoginBody(tmpl, contentType, username, password, csrfToken string) string {
	escape := func(s string) string { return s }
	mediaType := strings.ToLower(contentType)
	switch {
	case strings.Contains(mediaType, "json"):
		escape = func(s string) string {
			quoted, err := json.Marshal(s)
			if err != nil {
				return s
			}
			// Strip the surrounding quotes: the template supplies its own.
			return string(quoted[1 : len(quoted)-1])
		}
	case strings.Contains(mediaType, "x-www-form-urlencoded"):
		escape = url.QueryEscape
	}
	return strings.NewReplacer(
		"{{username}}", escape(username),
		"{{password}}", escape(password),
		"{{csrf_token}}", escape(csrfToken),
	).Replace(tmpl)
}
|
||
|
||
// queryBalance GETs config.BalanceURL, authenticating with token and/or
// cookies. It extracts the value at config.BalancePath from the JSON response
// and converts it to USD according to config.BalanceUnit:
//   - "usd": taken as-is
//   - "cny": divided by operation_setting.USDExchangeRate
//   - anything else: treated as internal quota, divided by common.QuotaPerUnit
//
// The USD value is stored as the channel balance and returned.
func queryBalance(channel *model.Channel, client *http.Client, config *dto.BalanceConfig, token string, cookies []*http.Cookie) (float64, error) {
	req, err := http.NewRequest("GET", config.BalanceURL, nil)
	if err != nil {
		return 0, fmt.Errorf("创建余额查询请求失败: %w", err)
	}

	// Header-token auth, used only when both a token and a header name exist.
	if token != "" && config.TokenHeader != "" {
		credential := token
		if config.TokenPrefix != "" {
			credential = config.TokenPrefix + " " + token
		}
		req.Header.Set(config.TokenHeader, credential)
	}
	// Cookie auth.
	for _, c := range cookies {
		req.AddCookie(c)
	}

	resp, err := client.Do(req)
	if err != nil {
		return 0, fmt.Errorf("余额查询请求失败: %w", err)
	}
	defer resp.Body.Close()

	payload, err := io.ReadAll(resp.Body)
	if err != nil {
		return 0, fmt.Errorf("读取余额响应失败: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return 0, fmt.Errorf("余额查询失败 (status: %d): %s", resp.StatusCode, string(payload))
	}

	if config.BalancePath == "" {
		return 0, fmt.Errorf("未配置余额提取路径")
	}
	raw := extractJSONValue(payload, config.BalancePath)
	if raw == "" {
		return 0, fmt.Errorf("未能从响应中提取余额值 (path: %s)", config.BalancePath)
	}
	amount, err := strconv.ParseFloat(raw, 64)
	if err != nil {
		return 0, fmt.Errorf("余额值解析失败: %s (raw: %s)", err.Error(), raw)
	}

	var divisor float64
	switch config.BalanceUnit {
	case "usd":
		divisor = 1
	case "cny":
		divisor = operation_setting.USDExchangeRate
	default:
		// "quota" (and unrecognised units): internal quota units per USD.
		divisor = common.QuotaPerUnit
	}
	balanceUSD := amount / divisor

	channel.UpdateBalance(balanceUSD)
	return balanceUSD, nil
}
|
||
|
||
// apiVersionSuffixRe matches one or more trailing API version segments such as
// "/v1" or "/v1/v2" at the end of a base URL. It is compiled once at package
// scope rather than on every applyBalanceTemplate call.
var apiVersionSuffixRe = regexp.MustCompile(`(/v\d+)+/?$`)

// setBalanceConfigDefault assigns value to *field only when the field is still
// empty, so user-supplied configuration always takes precedence over template
// defaults.
func setBalanceConfigDefault[T ~string](field *T, value T) {
	if *field == "" {
		*field = value
	}
}

// applyBalanceTemplate fills empty fields of config with the defaults of the
// preset selected by config.Template. baseURL is the channel base URL. Trailing
// slashes and API version segments (e.g. /v1) are stripped from it, because
// login and balance queries use the site's management API paths. Unknown
// templates leave config untouched.
func applyBalanceTemplate(baseURL string, config *dto.BalanceConfig) {
	cleanBaseURL := apiVersionSuffixRe.ReplaceAllString(strings.TrimRight(baseURL, "/"), "")

	switch config.Template {
	case dto.BalanceTemplateNewAPI:
		// new-api authenticates with the session cookie, so no token settings are needed.
		setBalanceConfigDefault(&config.LoginURL, cleanBaseURL+"/api/user/login")
		setBalanceConfigDefault(&config.LoginBody, `{"username":"{{username}}","password":"{{password}}"}`)
		setBalanceConfigDefault(&config.BalanceURL, cleanBaseURL+"/api/user/self")
		setBalanceConfigDefault(&config.BalancePath, "data.quota")
		setBalanceConfigDefault(&config.BalanceUnit, "quota")

	case dto.BalanceTemplateSub2API:
		// sub2api returns a token at login that is sent as a Bearer header.
		setBalanceConfigDefault(&config.LoginURL, cleanBaseURL+"/api/v1/auth/login")
		setBalanceConfigDefault(&config.LoginBody, `{"email":"{{username}}","password":"{{password}}"}`)
		setBalanceConfigDefault(&config.TokenPath, "data.token")
		setBalanceConfigDefault(&config.TokenHeader, "Authorization")
		setBalanceConfigDefault(&config.TokenPrefix, "Bearer")
		setBalanceConfigDefault(&config.BalanceURL, cleanBaseURL+"/api/v1/user/profile")
		setBalanceConfigDefault(&config.BalancePath, "data.balance")
		setBalanceConfigDefault(&config.BalanceUnit, "usd")

	case dto.BalanceTemplateAuthGateway:
		// auth_gateway uses an HTML form login guarded by a CSRF token, followed
		// by cookie authentication; no token settings are needed.
		setBalanceConfigDefault(&config.CsrfURL, cleanBaseURL+"/login")
		setBalanceConfigDefault(&config.CsrfRegex, `name="csrf_token"\s+value="([^"]+)"`)
		setBalanceConfigDefault(&config.LoginURL, cleanBaseURL+"/login")
		setBalanceConfigDefault(&config.LoginBody, `csrf_token={{csrf_token}}&next=%2F&username={{username}}&password={{password}}`)
		setBalanceConfigDefault(&config.LoginContentType, "application/x-www-form-urlencoded")
		setBalanceConfigDefault(&config.BalanceURL, cleanBaseURL+"/console/events/state")
		setBalanceConfigDefault(&config.BalancePath, "balance")
		setBalanceConfigDefault(&config.BalanceUnit, "usd")
	}
}
|
||
|
||
// extractJSONValue decodes data as JSON and returns the value at the
// dot-separated path (e.g. "data.quota"), rendered as a string. A path segment
// that is a non-negative integer indexes into a JSON array, e.g.
// "balance_infos.0.total_balance".
//
// Numbers are formatted without an exponent, and booleans as "true"/"false".
// Objects and arrays fall back to Go's %v formatting. The result is "" when
// data is not valid JSON, when any segment cannot be resolved, or when the
// value is JSON null. The null case matters because callers treat "" as "not
// found", and must not send "<nil>" as a token.
func extractJSONValue(data []byte, path string) string {
	var current any
	if err := json.Unmarshal(data, &current); err != nil {
		return ""
	}

	for _, part := range strings.Split(path, ".") {
		switch node := current.(type) {
		case map[string]any:
			next, ok := node[part]
			if !ok {
				return ""
			}
			current = next
		case []any:
			idx, err := strconv.Atoi(part)
			if err != nil || idx < 0 || idx >= len(node) {
				return ""
			}
			current = node[idx]
		default:
			// Cannot descend into a scalar.
			return ""
		}
	}

	// json.Unmarshal into `any` yields only nil, string, float64, bool,
	// map[string]any and []any.
	switch v := current.(type) {
	case nil:
		return ""
	case string:
		return v
	case float64:
		return strconv.FormatFloat(v, 'f', -1, 64)
	case bool:
		return strconv.FormatBool(v)
	default:
		return fmt.Sprintf("%v", v)
	}
}
|
||
|
||
// updateChannelBalance refreshes the upstream balance of channel, stores it via
// channel.UpdateBalance and returns it.
//
// Resolution order:
//  1. If the channel's other settings carry a BalanceConfig whose Template is not
//     BalanceTemplateNone, delegate to updateChannelBalanceCustom.
//  2. Otherwise dispatch on channel.Type to a provider-specific helper. Azure and
//     any type without a helper return the "尚未实现" (not implemented) error.
//  3. ChannelTypeOpenAI and ChannelTypeCustom fall through to the OpenAI dashboard
//     billing API: balance = subscription hard limit - usage over the period.
func updateChannelBalance(channel *model.Channel) (float64, error) {
	// Prefer the user-configured custom balance query when one is set.
	if balanceConfig := channel.GetOtherSettings().BalanceConfig; balanceConfig != nil && balanceConfig.Template != dto.BalanceTemplateNone {
		return updateChannelBalanceCustom(channel, balanceConfig)
	}

	// Default base URL for this channel type.
	baseURL := constant.ChannelBaseURLs[channel.Type]
	// NOTE(review): when the channel has no base URL, this points channel.BaseURL at
	// the local baseURL variable, i.e. it mutates the *model.Channel passed in. The
	// HTTP handler obtains that object from model.CacheGetChannel; confirm whether
	// persisting the default into it is intended.
	if channel.GetBaseURL() == "" {
		channel.BaseURL = &baseURL
	}
	switch channel.Type {
	case constant.ChannelTypeOpenAI:
		if channel.GetBaseURL() != "" {
			baseURL = channel.GetBaseURL()
		}
	case constant.ChannelTypeAzure:
		return 0, errors.New("尚未实现")
	case constant.ChannelTypeCustom:
		baseURL = channel.GetBaseURL()
	//case common.ChannelTypeOpenAISB:
	//	return updateChannelOpenAISBBalance(channel)
	case constant.ChannelTypeAIProxy:
		return updateChannelAIProxyBalance(channel)
	case constant.ChannelTypeAPI2GPT:
		return updateChannelAPI2GPTBalance(channel)
	case constant.ChannelTypeAIGC2D:
		return updateChannelAIGC2DBalance(channel)
	case constant.ChannelTypeSiliconFlow:
		return updateChannelSiliconFlowBalance(channel)
	case constant.ChannelTypeDeepSeek:
		return updateChannelDeepSeekBalance(channel)
	case constant.ChannelTypeOpenRouter:
		return updateChannelOpenRouterBalance(channel)
	case constant.ChannelTypeMoonshot:
		return updateChannelMoonshotBalance(channel)
	default:
		return 0, errors.New("尚未实现")
	}
	// Only ChannelTypeOpenAI and ChannelTypeCustom reach this point; every other
	// case above returns.
	url := fmt.Sprintf("%s/v1/dashboard/billing/subscription", baseURL)

	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
	if err != nil {
		return 0, err
	}
	subscription := OpenAISubscriptionResponse{}
	err = json.Unmarshal(body, &subscription)
	if err != nil {
		return 0, err
	}
	now := time.Now()
	// Usage window: from the first day of the current month up to today.
	startDate := fmt.Sprintf("%s-01", now.Format("2006-01"))
	endDate := now.Format("2006-01-02")
	// Without a payment method, look back 100 days instead of only the current month.
	if !subscription.HasPaymentMethod {
		startDate = now.AddDate(0, 0, -100).Format("2006-01-02")
	}
	url = fmt.Sprintf("%s/v1/dashboard/billing/usage?start_date=%s&end_date=%s", baseURL, startDate, endDate)
	body, err = GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
	if err != nil {
		return 0, err
	}
	usage := OpenAIUsageResponse{}
	err = json.Unmarshal(body, &usage)
	if err != nil {
		return 0, err
	}
	// usage.TotalUsage is in units of 0.01 USD (see OpenAIUsageResponse), hence /100.
	balance := subscription.HardLimitUSD - usage.TotalUsage/100
	channel.UpdateBalance(balance)
	return balance, nil
}
|
||
|
||
// UpdateChannelBalance is the HTTP handler that refreshes the balance of the
// channel identified by the ":id" path parameter and returns it in the JSON body.
// Multi-key channels are rejected with success=false. Lookup and upstream
// errors are reported through common.ApiError.
func UpdateChannelBalance(c *gin.Context) {
	channelID, parseErr := strconv.Atoi(c.Param("id"))
	if parseErr != nil {
		common.ApiError(c, parseErr)
		return
	}

	channel, lookupErr := model.CacheGetChannel(channelID)
	if lookupErr != nil {
		common.ApiError(c, lookupErr)
		return
	}

	if channel.ChannelInfo.IsMultiKey {
		c.JSON(http.StatusOK, gin.H{"success": false, "message": "多密钥渠道不支持余额查询"})
		return
	}

	balance, queryErr := updateChannelBalance(channel)
	if queryErr != nil {
		common.ApiError(c, queryErr)
		return
	}

	c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "balance": balance})
}
|
||
|
||
// updateAllChannelsBalance refreshes the balance of every enabled, single-key
// channel. A channel whose balance query succeeds with a value <= 0 is
// disabled with reason "余额不足" (insufficient balance).
//
// Per-channel query errors are skipped (best effort). The only error returned
// comes from listing the channels. common.RequestInterval is slept after each
// successful query; a failed query moves straight to the next channel.
func updateAllChannelsBalance() error {
	channels, err := model.GetAllChannels(0, 0, true, false)
	if err != nil {
		return err
	}

	for _, ch := range channels {
		if ch.Status != common.ChannelStatusEnabled || ch.ChannelInfo.IsMultiKey {
			continue
		}
		// TODO: support Azure

		balance, queryErr := updateChannelBalance(ch)
		if queryErr != nil {
			continue
		}
		// A successful query reporting a non-positive balance means the quota is used up.
		if balance <= 0 {
			service.DisableChannel(*types.NewChannelError(ch.Id, ch.Type, ch.Name, ch.ChannelInfo.IsMultiKey, "", ch.GetAutoBan()), "余额不足")
		}
		time.Sleep(common.RequestInterval)
	}
	return nil
}
|
||
|
||
// UpdateAllChannelsBalance is the HTTP handler that synchronously refreshes the
// balance of all eligible channels. It responds with success=true, or reports a
// channel-listing failure through common.ApiError.
func UpdateAllChannelsBalance(c *gin.Context) {
	// TODO: make it async
	if err := updateAllChannelsBalance(); err != nil {
		common.ApiError(c, err)
		return
	}
	c.JSON(http.StatusOK, gin.H{"success": true, "message": ""})
}
|
||
|
||
// AutomaticallyUpdateChannels loops forever. It waits `frequency` minutes,
// then refreshes the balance of all channels, logging before and after each
// round. It never returns, so presumably the caller runs it in its own
// goroutine.
func AutomaticallyUpdateChannels(frequency int) {
	interval := time.Duration(frequency) * time.Minute
	for {
		time.Sleep(interval)
		common.SysLog("updating all channels")
		// The only possible error is a channel-listing failure; it is ignored and
		// the next round simply tries again.
		_ = updateAllChannelsBalance()
		common.SysLog("channels update done")
	}
}
|