Compare commits
21 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| bee339d279 | |||
| 4e93148d9e | |||
| e36d191c2e | |||
| 34afe9b426 | |||
| d604f48c06 | |||
| 86cfb3920e | |||
| 097a50ebdc | |||
| f424f906d8 | |||
| cc4ad6c39e | |||
| 4c21c4c43b | |||
| db89b57e1c | |||
| 62d4b63fc3 | |||
| 355307223a | |||
| f2f3410dcf | |||
| 02aacb38a2 | |||
| a7c38ec851 | |||
| 095e1920f1 | |||
| 8993386743 | |||
| 435d7ae0dd | |||
| 3a2138ba61 | |||
| e3d64cb76d |
@@ -43,3 +43,19 @@ func GetJsonType(data json.RawMessage) string {
|
||||
return "number"
|
||||
}
|
||||
}
|
||||
|
||||
// JsonRawMessageToString returns JSON strings as their decoded value and other JSON values as raw text.
|
||||
func JsonRawMessageToString(data json.RawMessage) string {
|
||||
trimmed := bytes.TrimSpace(data)
|
||||
if len(trimmed) == 0 || bytes.Equal(trimmed, []byte("null")) {
|
||||
return ""
|
||||
}
|
||||
if trimmed[0] != '"' {
|
||||
return string(trimmed)
|
||||
}
|
||||
var value string
|
||||
if err := Unmarshal(trimmed, &value); err != nil {
|
||||
return string(trimmed)
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestJsonRawMessageToString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data json.RawMessage
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "object",
|
||||
data: json.RawMessage(`{"city":"Paris","days":0,"strict":false}`),
|
||||
want: `{"city":"Paris","days":0,"strict":false}`,
|
||||
},
|
||||
{
|
||||
name: "string",
|
||||
data: json.RawMessage(`"{\"city\":\"Paris\",\"days\":0,\"strict\":false}"`),
|
||||
want: `{"city":"Paris","days":0,"strict":false}`,
|
||||
},
|
||||
{
|
||||
name: "null",
|
||||
data: json.RawMessage(`null`),
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
data: nil,
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, JsonRawMessageToString(tt.data))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -32,6 +32,26 @@ const (
|
||||
channelUpstreamModelUpdateNotifyMaxFailedChannelIDs = 10
|
||||
)
|
||||
|
||||
var channelUpstreamModelUpdateSelectFields = []string{
|
||||
"id",
|
||||
"name",
|
||||
"type",
|
||||
"key",
|
||||
"status",
|
||||
"base_url",
|
||||
"models",
|
||||
"model_mapping",
|
||||
"settings",
|
||||
"setting",
|
||||
"other",
|
||||
"group",
|
||||
"priority",
|
||||
"weight",
|
||||
"tag",
|
||||
"channel_info",
|
||||
"header_override",
|
||||
}
|
||||
|
||||
var (
|
||||
channelUpstreamModelUpdateTaskOnce sync.Once
|
||||
channelUpstreamModelUpdateTaskRunning atomic.Bool
|
||||
@@ -521,7 +541,7 @@ func runChannelUpstreamModelUpdateTaskOnce() {
|
||||
for {
|
||||
var channels []*model.Channel
|
||||
query := model.DB.
|
||||
Select("id", "name", "type", "key", "status", "base_url", "models", "settings", "setting", "other", "group", "priority", "weight", "tag", "channel_info", "header_override").
|
||||
Select(channelUpstreamModelUpdateSelectFields).
|
||||
Where("status = ?", common.ChannelStatusEnabled).
|
||||
Order("id asc").
|
||||
Limit(channelUpstreamModelUpdateTaskBatchSize)
|
||||
@@ -814,7 +834,7 @@ func collectPendingApplyUpstreamModelChanges(settings dto.ChannelOtherSettings)
|
||||
func findEnabledChannelsAfterID(lastID int, batchSize int) ([]*model.Channel, error) {
|
||||
var channels []*model.Channel
|
||||
query := model.DB.
|
||||
Select("id", "name", "type", "key", "status", "base_url", "models", "settings", "setting", "other", "group", "priority", "weight", "tag", "channel_info", "header_override").
|
||||
Select(channelUpstreamModelUpdateSelectFields).
|
||||
Where("status = ?", common.ChannelStatusEnabled).
|
||||
Order("id asc").
|
||||
Limit(batchSize)
|
||||
|
||||
@@ -81,6 +81,10 @@ func TestCollectPendingApplyUpstreamModelChanges(t *testing.T) {
|
||||
require.Equal(t, []string{"old-model"}, pendingRemoveModels)
|
||||
}
|
||||
|
||||
func TestChannelUpstreamModelUpdateSelectFieldsIncludeModelMapping(t *testing.T) {
|
||||
require.Contains(t, channelUpstreamModelUpdateSelectFields, "model_mapping")
|
||||
}
|
||||
|
||||
func TestNormalizeChannelModelMapping(t *testing.T) {
|
||||
modelMapping := `{
|
||||
" alias-model ": " upstream-model ",
|
||||
|
||||
+3
-5
@@ -15,9 +15,9 @@ import (
|
||||
"github.com/QuantumNous/new-api/relay/channel/minimax"
|
||||
"github.com/QuantumNous/new-api/relay/channel/moonshot"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/relay/helper"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
"github.com/QuantumNous/new-api/setting/ratio_setting"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/samber/lo"
|
||||
@@ -134,8 +134,7 @@ func ListModels(c *gin.Context, modelType int) {
|
||||
}
|
||||
for allowModel, _ := range tokenModelLimit {
|
||||
if !acceptUnsetRatioModel {
|
||||
_, _, exist := ratio_setting.GetModelRatioOrPrice(allowModel)
|
||||
if !exist {
|
||||
if !helper.HasModelBillingConfig(allowModel) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -182,8 +181,7 @@ func ListModels(c *gin.Context, modelType int) {
|
||||
}
|
||||
for _, modelName := range models {
|
||||
if !acceptUnsetRatioModel {
|
||||
_, _, exist := ratio_setting.GetModelRatioOrPrice(modelName)
|
||||
if !exist {
|
||||
if !helper.HasModelBillingConfig(modelName) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,242 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/setting/config"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/glebarez/sqlite"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type listModelsResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Data []dto.OpenAIModels `json:"data"`
|
||||
Object string `json:"object"`
|
||||
}
|
||||
|
||||
func setupModelListControllerTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
|
||||
initModelListColumnNames(t)
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
common.UsingSQLite = true
|
||||
common.UsingMySQL = false
|
||||
common.UsingPostgreSQL = false
|
||||
common.RedisEnabled = false
|
||||
|
||||
dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", strings.ReplaceAll(t.Name(), "/", "_"))
|
||||
db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
|
||||
require.NoError(t, err)
|
||||
model.DB = db
|
||||
model.LOG_DB = db
|
||||
|
||||
require.NoError(t, db.AutoMigrate(&model.User{}, &model.Channel{}, &model.Ability{}, &model.Model{}, &model.Vendor{}))
|
||||
|
||||
t.Cleanup(func() {
|
||||
sqlDB, err := db.DB()
|
||||
if err == nil {
|
||||
_ = sqlDB.Close()
|
||||
}
|
||||
})
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func initModelListColumnNames(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
originalIsMasterNode := common.IsMasterNode
|
||||
originalSQLitePath := common.SQLitePath
|
||||
originalUsingSQLite := common.UsingSQLite
|
||||
originalUsingMySQL := common.UsingMySQL
|
||||
originalUsingPostgreSQL := common.UsingPostgreSQL
|
||||
originalSQLDSN, hadSQLDSN := os.LookupEnv("SQL_DSN")
|
||||
defer func() {
|
||||
common.IsMasterNode = originalIsMasterNode
|
||||
common.SQLitePath = originalSQLitePath
|
||||
common.UsingSQLite = originalUsingSQLite
|
||||
common.UsingMySQL = originalUsingMySQL
|
||||
common.UsingPostgreSQL = originalUsingPostgreSQL
|
||||
if hadSQLDSN {
|
||||
require.NoError(t, os.Setenv("SQL_DSN", originalSQLDSN))
|
||||
} else {
|
||||
require.NoError(t, os.Unsetenv("SQL_DSN"))
|
||||
}
|
||||
}()
|
||||
|
||||
common.IsMasterNode = false
|
||||
common.SQLitePath = fmt.Sprintf("file:%s_init?mode=memory&cache=shared", strings.ReplaceAll(t.Name(), "/", "_"))
|
||||
common.UsingSQLite = false
|
||||
common.UsingMySQL = false
|
||||
common.UsingPostgreSQL = false
|
||||
require.NoError(t, os.Setenv("SQL_DSN", "local"))
|
||||
|
||||
require.NoError(t, model.InitDB())
|
||||
if model.DB != nil {
|
||||
sqlDB, err := model.DB.DB()
|
||||
if err == nil {
|
||||
_ = sqlDB.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func withTieredBillingConfig(t *testing.T, modes map[string]string, exprs map[string]string) {
|
||||
t.Helper()
|
||||
|
||||
saved := map[string]string{}
|
||||
require.NoError(t, config.GlobalConfig.SaveToDB(func(key, value string) error {
|
||||
if strings.HasPrefix(key, "billing_setting.") {
|
||||
saved[key] = value
|
||||
}
|
||||
return nil
|
||||
}))
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, config.GlobalConfig.LoadFromDB(saved))
|
||||
model.InvalidatePricingCache()
|
||||
})
|
||||
|
||||
modeBytes, err := common.Marshal(modes)
|
||||
require.NoError(t, err)
|
||||
exprBytes, err := common.Marshal(exprs)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, config.GlobalConfig.LoadFromDB(map[string]string{
|
||||
"billing_setting.billing_mode": string(modeBytes),
|
||||
"billing_setting.billing_expr": string(exprBytes),
|
||||
}))
|
||||
model.InvalidatePricingCache()
|
||||
}
|
||||
|
||||
func withSelfUseModeDisabled(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
original := operation_setting.SelfUseModeEnabled
|
||||
operation_setting.SelfUseModeEnabled = false
|
||||
t.Cleanup(func() {
|
||||
operation_setting.SelfUseModeEnabled = original
|
||||
})
|
||||
}
|
||||
|
||||
func decodeListModelsResponse(t *testing.T, recorder *httptest.ResponseRecorder) map[string]struct{} {
|
||||
t.Helper()
|
||||
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
var payload listModelsResponse
|
||||
require.NoError(t, common.Unmarshal(recorder.Body.Bytes(), &payload))
|
||||
require.True(t, payload.Success)
|
||||
require.Equal(t, "list", payload.Object)
|
||||
|
||||
ids := make(map[string]struct{}, len(payload.Data))
|
||||
for _, item := range payload.Data {
|
||||
ids[item.Id] = struct{}{}
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
func pricingByModelName(pricings []model.Pricing) map[string]model.Pricing {
|
||||
byName := make(map[string]model.Pricing, len(pricings))
|
||||
for _, pricing := range pricings {
|
||||
byName[pricing.ModelName] = pricing
|
||||
}
|
||||
return byName
|
||||
}
|
||||
|
||||
func TestListModelsIncludesTieredBillingModel(t *testing.T) {
|
||||
withSelfUseModeDisabled(t)
|
||||
withTieredBillingConfig(t, map[string]string{
|
||||
"zz-tiered-visible-model": "tiered_expr",
|
||||
"zz-tiered-empty-expr-model": "tiered_expr",
|
||||
"zz-tiered-missing-expr-model": "tiered_expr",
|
||||
}, map[string]string{
|
||||
"zz-tiered-visible-model": `tier("base", p * 1 + c * 2)`,
|
||||
"zz-tiered-empty-expr-model": " ",
|
||||
})
|
||||
|
||||
db := setupModelListControllerTestDB(t)
|
||||
require.NoError(t, db.Create(&model.User{
|
||||
Id: 1001,
|
||||
Username: "model-list-user",
|
||||
Password: "password",
|
||||
Group: "default",
|
||||
Status: common.UserStatusEnabled,
|
||||
}).Error)
|
||||
require.NoError(t, db.Create(&[]model.Ability{
|
||||
{Group: "default", Model: "zz-tiered-visible-model", ChannelId: 1, Enabled: true},
|
||||
{Group: "default", Model: "zz-tiered-empty-expr-model", ChannelId: 1, Enabled: true},
|
||||
{Group: "default", Model: "zz-tiered-missing-expr-model", ChannelId: 1, Enabled: true},
|
||||
{Group: "default", Model: "zz-unpriced-model", ChannelId: 1, Enabled: true},
|
||||
}).Error)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(recorder)
|
||||
ctx.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
|
||||
ctx.Set("id", 1001)
|
||||
|
||||
ListModels(ctx, constant.ChannelTypeOpenAI)
|
||||
|
||||
ids := decodeListModelsResponse(t, recorder)
|
||||
require.Contains(t, ids, "zz-tiered-visible-model")
|
||||
require.NotContains(t, ids, "zz-tiered-empty-expr-model")
|
||||
require.NotContains(t, ids, "zz-tiered-missing-expr-model")
|
||||
require.NotContains(t, ids, "zz-unpriced-model")
|
||||
|
||||
pricingByName := pricingByModelName(model.GetPricing())
|
||||
visiblePricing, ok := pricingByName["zz-tiered-visible-model"]
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "tiered_expr", visiblePricing.BillingMode)
|
||||
require.NotEmpty(t, visiblePricing.BillingExpr)
|
||||
|
||||
emptyExprPricing, ok := pricingByName["zz-tiered-empty-expr-model"]
|
||||
require.True(t, ok)
|
||||
require.Empty(t, emptyExprPricing.BillingMode)
|
||||
require.Empty(t, emptyExprPricing.BillingExpr)
|
||||
|
||||
missingExprPricing, ok := pricingByName["zz-tiered-missing-expr-model"]
|
||||
require.True(t, ok)
|
||||
require.Empty(t, missingExprPricing.BillingMode)
|
||||
require.Empty(t, missingExprPricing.BillingExpr)
|
||||
}
|
||||
|
||||
func TestListModelsTokenLimitIncludesTieredBillingModel(t *testing.T) {
|
||||
withSelfUseModeDisabled(t)
|
||||
withTieredBillingConfig(t, map[string]string{
|
||||
"zz-token-tiered-visible-model": "tiered_expr",
|
||||
"zz-token-tiered-empty-expr-model": "tiered_expr",
|
||||
"zz-token-tiered-missing-expr-model": "tiered_expr",
|
||||
}, map[string]string{
|
||||
"zz-token-tiered-visible-model": `tier("base", p * 1 + c * 2)`,
|
||||
"zz-token-tiered-empty-expr-model": "",
|
||||
})
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(recorder)
|
||||
ctx.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
|
||||
common.SetContextKey(ctx, constant.ContextKeyTokenModelLimitEnabled, true)
|
||||
common.SetContextKey(ctx, constant.ContextKeyTokenModelLimit, map[string]bool{
|
||||
"zz-token-tiered-visible-model": true,
|
||||
"zz-token-tiered-empty-expr-model": true,
|
||||
"zz-token-tiered-missing-expr-model": true,
|
||||
"zz-token-unpriced-model": true,
|
||||
})
|
||||
|
||||
ListModels(ctx, constant.ChannelTypeOpenAI)
|
||||
|
||||
ids := decodeListModelsResponse(t, recorder)
|
||||
require.Contains(t, ids, "zz-token-tiered-visible-model")
|
||||
require.NotContains(t, ids, "zz-token-tiered-empty-expr-model")
|
||||
require.NotContains(t, ids, "zz-token-tiered-missing-expr-model")
|
||||
require.NotContains(t, ids, "zz-token-unpriced-model")
|
||||
}
|
||||
+161
-46
@@ -21,14 +21,16 @@ import (
|
||||
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/setting/billing_setting"
|
||||
"github.com/QuantumNous/new-api/setting/ratio_setting"
|
||||
"github.com/samber/lo"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultTimeoutSeconds = 10
|
||||
defaultEndpoint = "/api/ratio_config"
|
||||
defaultEndpoint = "/api/pricing"
|
||||
maxConcurrentFetches = 8
|
||||
maxRatioConfigBytes = 10 << 20 // 10MB
|
||||
floatEpsilon = 1e-9
|
||||
@@ -59,7 +61,29 @@ func valuesEqual(a, b interface{}) bool {
|
||||
return a == b
|
||||
}
|
||||
|
||||
var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"}
|
||||
var pricingSyncFields = []string{
|
||||
"model_ratio",
|
||||
"completion_ratio",
|
||||
"cache_ratio",
|
||||
"create_cache_ratio",
|
||||
"image_ratio",
|
||||
"audio_ratio",
|
||||
"audio_completion_ratio",
|
||||
"model_price",
|
||||
billing_setting.BillingModeField,
|
||||
billing_setting.BillingExprField,
|
||||
}
|
||||
|
||||
var numericPricingSyncFields = map[string]bool{
|
||||
"model_ratio": true,
|
||||
"completion_ratio": true,
|
||||
"cache_ratio": true,
|
||||
"create_cache_ratio": true,
|
||||
"image_ratio": true,
|
||||
"audio_ratio": true,
|
||||
"audio_completion_ratio": true,
|
||||
"model_price": true,
|
||||
}
|
||||
|
||||
type upstreamResult struct {
|
||||
Name string `json:"name"`
|
||||
@@ -67,6 +91,54 @@ type upstreamResult struct {
|
||||
Err string `json:"err,omitempty"`
|
||||
}
|
||||
|
||||
func valueMap(value any) map[string]any {
|
||||
switch typed := value.(type) {
|
||||
case map[string]any:
|
||||
return typed
|
||||
case map[string]float64:
|
||||
return lo.MapValues(typed, func(value float64, _ string) any { return value })
|
||||
case map[string]string:
|
||||
return lo.MapValues(typed, func(value string, _ string) any { return value })
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func asFloat64(value any) (float64, bool) {
|
||||
switch typed := value.(type) {
|
||||
case float64:
|
||||
return typed, true
|
||||
case float32:
|
||||
return float64(typed), true
|
||||
case int:
|
||||
return float64(typed), true
|
||||
case int64:
|
||||
return float64(typed), true
|
||||
case json.Number:
|
||||
parsed, err := typed.Float64()
|
||||
return parsed, err == nil
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeSyncValue(field string, value any) any {
|
||||
if numericPricingSyncFields[field] {
|
||||
if parsed, ok := asFloat64(value); ok {
|
||||
return parsed
|
||||
}
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func getLocalPricingSyncData() map[string]any {
|
||||
data := billing_setting.GetPricingSyncData(map[string]any(ratio_setting.GetExposedData()))
|
||||
data["image_ratio"] = ratio_setting.GetImageRatioCopy()
|
||||
data["audio_ratio"] = ratio_setting.GetAudioRatioCopy()
|
||||
data["audio_completion_ratio"] = ratio_setting.GetAudioCompletionRatioCopy()
|
||||
return data
|
||||
}
|
||||
|
||||
func FetchUpstreamRatios(c *gin.Context) {
|
||||
var req dto.UpstreamRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
@@ -293,7 +365,7 @@ func FetchUpstreamRatios(c *gin.Context) {
|
||||
if err := common.Unmarshal(body.Data, &type1Data); err == nil {
|
||||
// 如果包含至少一个 ratioTypes 字段,则认为是 type1
|
||||
isType1 := false
|
||||
for _, rt := range ratioTypes {
|
||||
for _, rt := range pricingSyncFields {
|
||||
if _, ok := type1Data[rt]; ok {
|
||||
isType1 = true
|
||||
break
|
||||
@@ -307,11 +379,18 @@ func FetchUpstreamRatios(c *gin.Context) {
|
||||
|
||||
// 如果不是 type1,则尝试按 type2 (/api/pricing) 解析
|
||||
var pricingItems []struct {
|
||||
ModelName string `json:"model_name"`
|
||||
QuotaType int `json:"quota_type"`
|
||||
ModelRatio float64 `json:"model_ratio"`
|
||||
ModelPrice float64 `json:"model_price"`
|
||||
CompletionRatio float64 `json:"completion_ratio"`
|
||||
ModelName string `json:"model_name"`
|
||||
QuotaType int `json:"quota_type"`
|
||||
ModelRatio float64 `json:"model_ratio"`
|
||||
ModelPrice float64 `json:"model_price"`
|
||||
CompletionRatio float64 `json:"completion_ratio"`
|
||||
CacheRatio *float64 `json:"cache_ratio"`
|
||||
CreateCacheRatio *float64 `json:"create_cache_ratio"`
|
||||
ImageRatio *float64 `json:"image_ratio"`
|
||||
AudioRatio *float64 `json:"audio_ratio"`
|
||||
AudioCompletionRatio *float64 `json:"audio_completion_ratio"`
|
||||
BillingMode string `json:"billing_mode"`
|
||||
BillingExpr string `json:"billing_expr"`
|
||||
}
|
||||
if err := common.Unmarshal(body.Data, &pricingItems); err != nil {
|
||||
logger.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error())
|
||||
@@ -321,9 +400,23 @@ func FetchUpstreamRatios(c *gin.Context) {
|
||||
|
||||
modelRatioMap := make(map[string]float64)
|
||||
completionRatioMap := make(map[string]float64)
|
||||
cacheRatioMap := make(map[string]float64)
|
||||
createCacheRatioMap := make(map[string]float64)
|
||||
imageRatioMap := make(map[string]float64)
|
||||
audioRatioMap := make(map[string]float64)
|
||||
audioCompletionRatioMap := make(map[string]float64)
|
||||
modelPriceMap := make(map[string]float64)
|
||||
billingModeMap := make(map[string]string)
|
||||
billingExprMap := make(map[string]string)
|
||||
|
||||
for _, item := range pricingItems {
|
||||
if item.ModelName == "" {
|
||||
continue
|
||||
}
|
||||
if item.BillingMode == billing_setting.BillingModeTieredExpr && strings.TrimSpace(item.BillingExpr) != "" {
|
||||
billingModeMap[item.ModelName] = billing_setting.BillingModeTieredExpr
|
||||
billingExprMap[item.ModelName] = item.BillingExpr
|
||||
}
|
||||
if item.QuotaType == 1 {
|
||||
modelPriceMap[item.ModelName] = item.ModelPrice
|
||||
} else {
|
||||
@@ -331,6 +424,21 @@ func FetchUpstreamRatios(c *gin.Context) {
|
||||
// completionRatio 可能为 0,此时也直接赋值,保持与上游一致
|
||||
completionRatioMap[item.ModelName] = item.CompletionRatio
|
||||
}
|
||||
if item.CacheRatio != nil {
|
||||
cacheRatioMap[item.ModelName] = *item.CacheRatio
|
||||
}
|
||||
if item.CreateCacheRatio != nil {
|
||||
createCacheRatioMap[item.ModelName] = *item.CreateCacheRatio
|
||||
}
|
||||
if item.ImageRatio != nil {
|
||||
imageRatioMap[item.ModelName] = *item.ImageRatio
|
||||
}
|
||||
if item.AudioRatio != nil {
|
||||
audioRatioMap[item.ModelName] = *item.AudioRatio
|
||||
}
|
||||
if item.AudioCompletionRatio != nil {
|
||||
audioCompletionRatioMap[item.ModelName] = *item.AudioCompletionRatio
|
||||
}
|
||||
}
|
||||
|
||||
converted := make(map[string]any)
|
||||
@@ -350,6 +458,21 @@ func FetchUpstreamRatios(c *gin.Context) {
|
||||
}
|
||||
converted["completion_ratio"] = compAny
|
||||
}
|
||||
if len(cacheRatioMap) > 0 {
|
||||
converted["cache_ratio"] = valueMap(cacheRatioMap)
|
||||
}
|
||||
if len(createCacheRatioMap) > 0 {
|
||||
converted["create_cache_ratio"] = valueMap(createCacheRatioMap)
|
||||
}
|
||||
if len(imageRatioMap) > 0 {
|
||||
converted["image_ratio"] = valueMap(imageRatioMap)
|
||||
}
|
||||
if len(audioRatioMap) > 0 {
|
||||
converted["audio_ratio"] = valueMap(audioRatioMap)
|
||||
}
|
||||
if len(audioCompletionRatioMap) > 0 {
|
||||
converted["audio_completion_ratio"] = valueMap(audioCompletionRatioMap)
|
||||
}
|
||||
|
||||
if len(modelPriceMap) > 0 {
|
||||
priceAny := make(map[string]any, len(modelPriceMap))
|
||||
@@ -358,6 +481,12 @@ func FetchUpstreamRatios(c *gin.Context) {
|
||||
}
|
||||
converted["model_price"] = priceAny
|
||||
}
|
||||
if len(billingModeMap) > 0 {
|
||||
converted[billing_setting.BillingModeField] = valueMap(billingModeMap)
|
||||
}
|
||||
if len(billingExprMap) > 0 {
|
||||
converted[billing_setting.BillingExprField] = valueMap(billingExprMap)
|
||||
}
|
||||
|
||||
ch <- upstreamResult{Name: uniqueName, Data: converted}
|
||||
}(chn)
|
||||
@@ -366,7 +495,7 @@ func FetchUpstreamRatios(c *gin.Context) {
|
||||
wg.Wait()
|
||||
close(ch)
|
||||
|
||||
localData := ratio_setting.GetExposedData()
|
||||
localData := getLocalPricingSyncData()
|
||||
|
||||
var testResults []dto.TestResult
|
||||
var successfulChannels []struct {
|
||||
@@ -412,22 +541,16 @@ func buildDifferences(localData map[string]any, successfulChannels []struct {
|
||||
|
||||
allModels := make(map[string]struct{})
|
||||
|
||||
for _, ratioType := range ratioTypes {
|
||||
if localRatioAny, ok := localData[ratioType]; ok {
|
||||
if localRatio, ok := localRatioAny.(map[string]float64); ok {
|
||||
for modelName := range localRatio {
|
||||
allModels[modelName] = struct{}{}
|
||||
}
|
||||
}
|
||||
for _, field := range pricingSyncFields {
|
||||
for modelName := range valueMap(localData[field]) {
|
||||
allModels[modelName] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
for _, channel := range successfulChannels {
|
||||
for _, ratioType := range ratioTypes {
|
||||
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
|
||||
for modelName := range upstreamRatio {
|
||||
allModels[modelName] = struct{}{}
|
||||
}
|
||||
for _, field := range pricingSyncFields {
|
||||
for modelName := range valueMap(channel.data[field]) {
|
||||
allModels[modelName] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -438,10 +561,10 @@ func buildDifferences(localData map[string]any, successfulChannels []struct {
|
||||
for _, channel := range successfulChannels {
|
||||
confidenceMap[channel.name] = make(map[string]bool)
|
||||
|
||||
modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any)
|
||||
completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any)
|
||||
modelRatios := valueMap(channel.data["model_ratio"])
|
||||
completionRatios := valueMap(channel.data["completion_ratio"])
|
||||
|
||||
if hasModelRatio && hasCompletionRatio {
|
||||
if len(modelRatios) > 0 && len(completionRatios) > 0 {
|
||||
// 遍历所有模型,检查是否满足不可信条件
|
||||
for modelName := range allModels {
|
||||
// 默认为可信
|
||||
@@ -451,12 +574,10 @@ func buildDifferences(localData map[string]any, successfulChannels []struct {
|
||||
if modelRatioVal, ok := modelRatios[modelName]; ok {
|
||||
if completionRatioVal, ok := completionRatios[modelName]; ok {
|
||||
// 转换为float64进行比较
|
||||
if modelRatioFloat, ok := modelRatioVal.(float64); ok {
|
||||
if completionRatioFloat, ok := completionRatioVal.(float64); ok {
|
||||
if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 {
|
||||
confidenceMap[channel.name][modelName] = false
|
||||
}
|
||||
}
|
||||
modelRatioFloat, modelRatioOK := asFloat64(modelRatioVal)
|
||||
completionRatioFloat, completionRatioOK := asFloat64(completionRatioVal)
|
||||
if modelRatioOK && completionRatioOK && nearlyEqual(modelRatioFloat, 37.5) && nearlyEqual(completionRatioFloat, 1.0) {
|
||||
confidenceMap[channel.name][modelName] = false
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -470,14 +591,10 @@ func buildDifferences(localData map[string]any, successfulChannels []struct {
|
||||
}
|
||||
|
||||
for modelName := range allModels {
|
||||
for _, ratioType := range ratioTypes {
|
||||
for _, ratioType := range pricingSyncFields {
|
||||
var localValue interface{} = nil
|
||||
if localRatioAny, ok := localData[ratioType]; ok {
|
||||
if localRatio, ok := localRatioAny.(map[string]float64); ok {
|
||||
if val, exists := localRatio[modelName]; exists {
|
||||
localValue = val
|
||||
}
|
||||
}
|
||||
if val, exists := valueMap(localData[ratioType])[modelName]; exists {
|
||||
localValue = normalizeSyncValue(ratioType, val)
|
||||
}
|
||||
|
||||
upstreamValues := make(map[string]interface{})
|
||||
@@ -488,16 +605,14 @@ func buildDifferences(localData map[string]any, successfulChannels []struct {
|
||||
for _, channel := range successfulChannels {
|
||||
var upstreamValue interface{} = nil
|
||||
|
||||
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
|
||||
if val, exists := upstreamRatio[modelName]; exists {
|
||||
upstreamValue = val
|
||||
hasUpstreamValue = true
|
||||
if val, exists := valueMap(channel.data[ratioType])[modelName]; exists {
|
||||
upstreamValue = normalizeSyncValue(ratioType, val)
|
||||
hasUpstreamValue = true
|
||||
|
||||
if localValue != nil && !valuesEqual(localValue, val) {
|
||||
hasDifference = true
|
||||
} else if valuesEqual(localValue, val) {
|
||||
upstreamValue = "same"
|
||||
}
|
||||
if localValue != nil && !valuesEqual(localValue, upstreamValue) {
|
||||
hasDifference = true
|
||||
} else if valuesEqual(localValue, upstreamValue) {
|
||||
upstreamValue = "same"
|
||||
}
|
||||
}
|
||||
if upstreamValue == nil && localValue == nil {
|
||||
|
||||
@@ -83,13 +83,14 @@ func SubscriptionRequestCreemPay(c *gin.Context) {
|
||||
|
||||
// create pending order first
|
||||
order := &model.SubscriptionOrder{
|
||||
UserId: userId,
|
||||
PlanId: plan.Id,
|
||||
Money: plan.PriceAmount,
|
||||
TradeNo: referenceId,
|
||||
PaymentMethod: model.PaymentMethodCreem,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
UserId: userId,
|
||||
PlanId: plan.Id,
|
||||
Money: plan.PriceAmount,
|
||||
TradeNo: referenceId,
|
||||
PaymentMethod: model.PaymentMethodCreem,
|
||||
PaymentProvider: model.PaymentProviderCreem,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
if err := order.Insert(); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"})
|
||||
|
||||
@@ -82,13 +82,14 @@ func SubscriptionRequestEpay(c *gin.Context) {
|
||||
}
|
||||
|
||||
order := &model.SubscriptionOrder{
|
||||
UserId: userId,
|
||||
PlanId: plan.Id,
|
||||
Money: plan.PriceAmount,
|
||||
TradeNo: tradeNo,
|
||||
PaymentMethod: req.PaymentMethod,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
UserId: userId,
|
||||
PlanId: plan.Id,
|
||||
Money: plan.PriceAmount,
|
||||
TradeNo: tradeNo,
|
||||
PaymentMethod: req.PaymentMethod,
|
||||
PaymentProvider: model.PaymentProviderEpay,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
if err := order.Insert(); err != nil {
|
||||
common.ApiErrorMsg(c, "创建订单失败")
|
||||
@@ -104,7 +105,7 @@ func SubscriptionRequestEpay(c *gin.Context) {
|
||||
ReturnUrl: returnUrl,
|
||||
})
|
||||
if err != nil {
|
||||
_ = model.ExpireSubscriptionOrder(tradeNo, req.PaymentMethod)
|
||||
_ = model.ExpireSubscriptionOrder(tradeNo, model.PaymentProviderEpay)
|
||||
common.ApiErrorMsg(c, "拉起支付失败")
|
||||
return
|
||||
}
|
||||
@@ -156,7 +157,7 @@ func SubscriptionEpayNotify(c *gin.Context) {
|
||||
LockOrder(verifyInfo.ServiceTradeNo)
|
||||
defer UnlockOrder(verifyInfo.ServiceTradeNo)
|
||||
|
||||
if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo), verifyInfo.Type); err != nil {
|
||||
if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo), model.PaymentProviderEpay, verifyInfo.Type); err != nil {
|
||||
_, _ = c.Writer.Write([]byte("fail"))
|
||||
return
|
||||
}
|
||||
@@ -205,7 +206,7 @@ func SubscriptionEpayReturn(c *gin.Context) {
|
||||
if verifyInfo.TradeStatus == epay.StatusTradeSuccess {
|
||||
LockOrder(verifyInfo.ServiceTradeNo)
|
||||
defer UnlockOrder(verifyInfo.ServiceTradeNo)
|
||||
if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo), verifyInfo.Type); err != nil {
|
||||
if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo), model.PaymentProviderEpay, verifyInfo.Type); err != nil {
|
||||
c.Redirect(http.StatusFound, system_setting.ServerAddress+"/console/topup?pay=fail")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -84,13 +84,14 @@ func SubscriptionRequestStripePay(c *gin.Context) {
|
||||
}
|
||||
|
||||
order := &model.SubscriptionOrder{
|
||||
UserId: userId,
|
||||
PlanId: plan.Id,
|
||||
Money: plan.PriceAmount,
|
||||
TradeNo: referenceId,
|
||||
PaymentMethod: model.PaymentMethodStripe,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
UserId: userId,
|
||||
PlanId: plan.Id,
|
||||
Money: plan.PriceAmount,
|
||||
TradeNo: referenceId,
|
||||
PaymentMethod: model.PaymentMethodStripe,
|
||||
PaymentProvider: model.PaymentProviderStripe,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
if err := order.Insert(); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"})
|
||||
|
||||
+14
-24
@@ -123,17 +123,6 @@ type AmountRequest struct {
|
||||
Amount int64 `json:"amount"`
|
||||
}
|
||||
|
||||
var nonEpayPaymentMethodsForCallback = []string{
|
||||
model.PaymentMethodStripe,
|
||||
model.PaymentMethodCreem,
|
||||
model.PaymentMethodWaffo,
|
||||
model.PaymentMethodWaffoPancake,
|
||||
}
|
||||
|
||||
func isNonEpayPaymentMethodForEpayCallback(paymentMethod string) bool {
|
||||
return lo.Contains(nonEpayPaymentMethodsForCallback, paymentMethod)
|
||||
}
|
||||
|
||||
func GetEpayClient() *epay.Client {
|
||||
if operation_setting.PayAddress == "" || operation_setting.EpayId == "" || operation_setting.EpayKey == "" {
|
||||
return nil
|
||||
@@ -248,13 +237,14 @@ func RequestEpay(c *gin.Context) {
|
||||
amount = dAmount.Div(dQuotaPerUnit).IntPart()
|
||||
}
|
||||
topUp := &model.TopUp{
|
||||
UserId: id,
|
||||
Amount: amount,
|
||||
Money: payMoney,
|
||||
TradeNo: tradeNo,
|
||||
PaymentMethod: req.PaymentMethod,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
UserId: id,
|
||||
Amount: amount,
|
||||
Money: payMoney,
|
||||
TradeNo: tradeNo,
|
||||
PaymentMethod: req.PaymentMethod,
|
||||
PaymentProvider: model.PaymentProviderEpay,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
err = topUp.Insert()
|
||||
if err != nil {
|
||||
@@ -379,15 +369,15 @@ func EpayNotify(c *gin.Context) {
|
||||
logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 回调订单不存在 trade_no=%s callback_type=%s client_ip=%s verify_info=%q", verifyInfo.ServiceTradeNo, verifyInfo.Type, c.ClientIP(), common.GetJsonString(verifyInfo)))
|
||||
return
|
||||
}
|
||||
if isNonEpayPaymentMethodForEpayCallback(topUp.PaymentMethod) {
|
||||
logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 订单支付方式不匹配 trade_no=%s order_payment_method=%s callback_type=%s client_ip=%s", verifyInfo.ServiceTradeNo, topUp.PaymentMethod, verifyInfo.Type, c.ClientIP()))
|
||||
return
|
||||
}
|
||||
if topUp.PaymentMethod != verifyInfo.Type {
|
||||
logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 订单支付方式不匹配 trade_no=%s order_payment_method=%s callback_type=%s client_ip=%s", verifyInfo.ServiceTradeNo, topUp.PaymentMethod, verifyInfo.Type, c.ClientIP()))
|
||||
if topUp.PaymentProvider != model.PaymentProviderEpay {
|
||||
logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 订单支付网关不匹配 trade_no=%s order_provider=%s callback_type=%s client_ip=%s", verifyInfo.ServiceTradeNo, topUp.PaymentProvider, verifyInfo.Type, c.ClientIP()))
|
||||
return
|
||||
}
|
||||
if topUp.Status == common.TopUpStatusPending {
|
||||
if topUp.PaymentMethod != verifyInfo.Type {
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("易支付 实际支付方式与订单不同 trade_no=%s order_payment_method=%s actual_type=%s client_ip=%s", verifyInfo.ServiceTradeNo, topUp.PaymentMethod, verifyInfo.Type, c.ClientIP()))
|
||||
topUp.PaymentMethod = verifyInfo.Type
|
||||
}
|
||||
topUp.Status = common.TopUpStatusSuccess
|
||||
err := topUp.Update()
|
||||
if err != nil {
|
||||
|
||||
@@ -106,13 +106,14 @@ func (*CreemAdaptor) RequestPay(c *gin.Context, req *CreemPayRequest) {
|
||||
|
||||
// 先创建订单记录,使用产品配置的金额和充值额度
|
||||
topUp := &model.TopUp{
|
||||
UserId: id,
|
||||
Amount: selectedProduct.Quota, // 充值额度
|
||||
Money: selectedProduct.Price, // 支付金额
|
||||
TradeNo: referenceId,
|
||||
PaymentMethod: model.PaymentMethodCreem,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
UserId: id,
|
||||
Amount: selectedProduct.Quota, // 充值额度
|
||||
Money: selectedProduct.Price, // 支付金额
|
||||
TradeNo: referenceId,
|
||||
PaymentMethod: model.PaymentMethodCreem,
|
||||
PaymentProvider: model.PaymentProviderCreem,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
err = topUp.Insert()
|
||||
if err != nil {
|
||||
@@ -301,7 +302,7 @@ func handleCheckoutCompleted(c *gin.Context, event *CreemWebhookEvent) {
|
||||
// Try complete subscription order first
|
||||
LockOrder(referenceId)
|
||||
defer UnlockOrder(referenceId)
|
||||
if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(event), model.PaymentMethodCreem); err == nil {
|
||||
if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(event), model.PaymentProviderCreem, ""); err == nil {
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem 订阅订单处理成功 trade_no=%s creem_order_id=%s", referenceId, event.Object.Order.Id))
|
||||
c.Status(http.StatusOK)
|
||||
return
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
)
|
||||
|
||||
func TestIsNonEpayPaymentMethodForEpayCallback(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
paymentMethod string
|
||||
expectedBlocked bool
|
||||
}{
|
||||
{name: "stripe", paymentMethod: model.PaymentMethodStripe, expectedBlocked: true},
|
||||
{name: "creem", paymentMethod: model.PaymentMethodCreem, expectedBlocked: true},
|
||||
{name: "waffo", paymentMethod: model.PaymentMethodWaffo, expectedBlocked: true},
|
||||
{name: "waffo pancake", paymentMethod: model.PaymentMethodWaffoPancake, expectedBlocked: true},
|
||||
{name: "alipay", paymentMethod: "alipay", expectedBlocked: false},
|
||||
{name: "wxpay", paymentMethod: "wxpay", expectedBlocked: false},
|
||||
{name: "custom epay type", paymentMethod: "custom1", expectedBlocked: false},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if actual := isNonEpayPaymentMethodForEpayCallback(tc.paymentMethod); actual != tc.expectedBlocked {
|
||||
t.Fatalf("expected blocked=%v, got %v for payment method %q", tc.expectedBlocked, actual, tc.paymentMethod)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+13
-12
@@ -101,13 +101,14 @@ func (*StripeAdaptor) RequestPay(c *gin.Context, req *StripePayRequest) {
|
||||
}
|
||||
|
||||
topUp := &model.TopUp{
|
||||
UserId: id,
|
||||
Amount: req.Amount,
|
||||
Money: chargedMoney,
|
||||
TradeNo: referenceId,
|
||||
PaymentMethod: model.PaymentMethodStripe,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
UserId: id,
|
||||
Amount: req.Amount,
|
||||
Money: chargedMoney,
|
||||
TradeNo: referenceId,
|
||||
PaymentMethod: model.PaymentMethodStripe,
|
||||
PaymentProvider: model.PaymentProviderStripe,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
err = topUp.Insert()
|
||||
if err != nil {
|
||||
@@ -237,8 +238,8 @@ func sessionAsyncPaymentFailed(ctx context.Context, event stripe.Event, callerIp
|
||||
return
|
||||
}
|
||||
|
||||
if topUp.PaymentMethod != model.PaymentMethodStripe {
|
||||
logger.LogWarn(ctx, fmt.Sprintf("Stripe 异步支付失败但订单支付方式不匹配 trade_no=%s payment_method=%s client_ip=%s", referenceId, topUp.PaymentMethod, callerIp))
|
||||
if topUp.PaymentProvider != model.PaymentProviderStripe {
|
||||
logger.LogWarn(ctx, fmt.Sprintf("Stripe 异步支付失败但订单支付网关不匹配 trade_no=%s payment_provider=%s client_ip=%s", referenceId, topUp.PaymentProvider, callerIp))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -270,7 +271,7 @@ func fulfillOrder(ctx context.Context, event stripe.Event, referenceId string, c
|
||||
"currency": strings.ToUpper(event.GetObjectValue("currency")),
|
||||
"event_type": string(event.Type),
|
||||
}
|
||||
if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(payload), model.PaymentMethodStripe); err == nil {
|
||||
if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(payload), model.PaymentProviderStripe, ""); err == nil {
|
||||
logger.LogInfo(ctx, fmt.Sprintf("Stripe 订阅订单处理成功 trade_no=%s event_type=%s client_ip=%s", referenceId, string(event.Type), callerIp))
|
||||
return
|
||||
} else if err != nil && !errors.Is(err, model.ErrSubscriptionOrderNotFound) {
|
||||
@@ -305,7 +306,7 @@ func sessionExpired(ctx context.Context, event stripe.Event) {
|
||||
// Subscription order expiration
|
||||
LockOrder(referenceId)
|
||||
defer UnlockOrder(referenceId)
|
||||
if err := model.ExpireSubscriptionOrder(referenceId, model.PaymentMethodStripe); err == nil {
|
||||
if err := model.ExpireSubscriptionOrder(referenceId, model.PaymentProviderStripe); err == nil {
|
||||
logger.LogInfo(ctx, fmt.Sprintf("Stripe 订阅订单已过期 trade_no=%s", referenceId))
|
||||
return
|
||||
} else if err != nil && !errors.Is(err, model.ErrSubscriptionOrderNotFound) {
|
||||
@@ -313,7 +314,7 @@ func sessionExpired(ctx context.Context, event stripe.Event) {
|
||||
return
|
||||
}
|
||||
|
||||
err := model.UpdatePendingTopUpStatus(referenceId, model.PaymentMethodStripe, common.TopUpStatusExpired)
|
||||
err := model.UpdatePendingTopUpStatus(referenceId, model.PaymentProviderStripe, common.TopUpStatusExpired)
|
||||
if errors.Is(err, model.ErrTopUpNotFound) {
|
||||
logger.LogWarn(ctx, fmt.Sprintf("Stripe 充值订单不存在,无法标记过期 trade_no=%s", referenceId))
|
||||
return
|
||||
|
||||
@@ -208,13 +208,14 @@ func RequestWaffoPay(c *gin.Context) {
|
||||
|
||||
// 创建本地订单
|
||||
topUp := &model.TopUp{
|
||||
UserId: id,
|
||||
Amount: amount,
|
||||
Money: payMoney,
|
||||
TradeNo: merchantOrderId,
|
||||
PaymentMethod: model.PaymentMethodWaffo,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
UserId: id,
|
||||
Amount: amount,
|
||||
Money: payMoney,
|
||||
TradeNo: merchantOrderId,
|
||||
PaymentMethod: model.PaymentMethodWaffo,
|
||||
PaymentProvider: model.PaymentProviderWaffo,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
if err := topUp.Insert(); err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo 创建充值订单失败 user_id=%d trade_no=%s amount=%d error=%q", id, merchantOrderId, req.Amount, err.Error()))
|
||||
@@ -379,7 +380,7 @@ func handleWaffoPayment(c *gin.Context, wh *core.WebhookHandler, result *core.Pa
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo 订单状态非成功,忽略充值 trade_no=%s order_status=%s client_ip=%s", result.MerchantOrderID, result.OrderStatus, c.ClientIP()))
|
||||
// 终态失败订单标记为 failed,避免永远停在 pending
|
||||
if result.MerchantOrderID != "" {
|
||||
if err := model.UpdatePendingTopUpStatus(result.MerchantOrderID, model.PaymentMethodWaffo, common.TopUpStatusFailed); err != nil &&
|
||||
if err := model.UpdatePendingTopUpStatus(result.MerchantOrderID, model.PaymentProviderWaffo, common.TopUpStatusFailed); err != nil &&
|
||||
!errors.Is(err, model.ErrTopUpNotFound) &&
|
||||
!errors.Is(err, model.ErrTopUpStatusInvalid) {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo 标记失败订单状态失败 trade_no=%s error=%q", result.MerchantOrderID, err.Error()))
|
||||
|
||||
@@ -159,13 +159,14 @@ func RequestWaffoPancakePay(c *gin.Context) {
|
||||
|
||||
tradeNo := fmt.Sprintf("WAFFO_PANCAKE-%d-%d-%s", id, time.Now().UnixMilli(), randstr.String(6))
|
||||
topUp := &model.TopUp{
|
||||
UserId: id,
|
||||
Amount: normalizeWaffoPancakeTopUpAmount(req.Amount),
|
||||
Money: payMoney,
|
||||
TradeNo: tradeNo,
|
||||
PaymentMethod: model.PaymentMethodWaffoPancake,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
UserId: id,
|
||||
Amount: normalizeWaffoPancakeTopUpAmount(req.Amount),
|
||||
Money: payMoney,
|
||||
TradeNo: tradeNo,
|
||||
PaymentMethod: model.PaymentMethodWaffoPancake,
|
||||
PaymentProvider: model.PaymentProviderWaffoPancake,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
if err := topUp.Insert(); err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake 创建充值订单失败 user_id=%d trade_no=%s amount=%d error=%q", id, tradeNo, req.Amount, err.Error()))
|
||||
|
||||
@@ -91,6 +91,7 @@ func Login(c *gin.Context) {
|
||||
|
||||
// setup session & cookies and then return user info
|
||||
func setupLogin(user *model.User, c *gin.Context) {
|
||||
model.UpdateUserLastLoginAt(user.Id)
|
||||
session := sessions.Default(c)
|
||||
session.Set("id", user.Id)
|
||||
session.Set("username", user.Username)
|
||||
|
||||
+15
-1
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
)
|
||||
|
||||
@@ -346,7 +347,20 @@ type ResponsesOutput struct {
|
||||
Size string `json:"size"`
|
||||
CallId string `json:"call_id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Arguments string `json:"arguments,omitempty"`
|
||||
Arguments json.RawMessage `json:"arguments,omitempty"`
|
||||
}
|
||||
|
||||
// ArgumentsString returns function call arguments in the string form expected by Chat Completions.
|
||||
func (r *ResponsesOutput) ArgumentsString() string {
|
||||
if r == nil {
|
||||
return ""
|
||||
}
|
||||
return ResponsesArgumentsString(r.Arguments)
|
||||
}
|
||||
|
||||
// ResponsesArgumentsString returns function call arguments in the string form expected by Chat Completions.
|
||||
func ResponsesArgumentsString(arguments json.RawMessage) string {
|
||||
return common.JsonRawMessageToString(arguments)
|
||||
}
|
||||
|
||||
type ResponsesOutputContent struct {
|
||||
|
||||
+13
-12
@@ -304,18 +304,19 @@ const (
|
||||
|
||||
// Distributor related messages
|
||||
const (
|
||||
MsgDistributorInvalidRequest = "distributor.invalid_request"
|
||||
MsgDistributorInvalidChannelId = "distributor.invalid_channel_id"
|
||||
MsgDistributorChannelDisabled = "distributor.channel_disabled"
|
||||
MsgDistributorTokenNoModelAccess = "distributor.token_no_model_access"
|
||||
MsgDistributorTokenModelForbidden = "distributor.token_model_forbidden"
|
||||
MsgDistributorModelNameRequired = "distributor.model_name_required"
|
||||
MsgDistributorInvalidPlayground = "distributor.invalid_playground_request"
|
||||
MsgDistributorGroupAccessDenied = "distributor.group_access_denied"
|
||||
MsgDistributorGetChannelFailed = "distributor.get_channel_failed"
|
||||
MsgDistributorNoAvailableChannel = "distributor.no_available_channel"
|
||||
MsgDistributorInvalidMidjourney = "distributor.invalid_midjourney_request"
|
||||
MsgDistributorInvalidParseModel = "distributor.invalid_request_parse_model"
|
||||
MsgDistributorInvalidRequest = "distributor.invalid_request"
|
||||
MsgDistributorInvalidChannelId = "distributor.invalid_channel_id"
|
||||
MsgDistributorChannelDisabled = "distributor.channel_disabled"
|
||||
MsgDistributorAffinityChannelDisabled = "distributor.affinity_channel_disabled"
|
||||
MsgDistributorTokenNoModelAccess = "distributor.token_no_model_access"
|
||||
MsgDistributorTokenModelForbidden = "distributor.token_model_forbidden"
|
||||
MsgDistributorModelNameRequired = "distributor.model_name_required"
|
||||
MsgDistributorInvalidPlayground = "distributor.invalid_playground_request"
|
||||
MsgDistributorGroupAccessDenied = "distributor.group_access_denied"
|
||||
MsgDistributorGetChannelFailed = "distributor.get_channel_failed"
|
||||
MsgDistributorNoAvailableChannel = "distributor.no_available_channel"
|
||||
MsgDistributorInvalidMidjourney = "distributor.invalid_midjourney_request"
|
||||
MsgDistributorInvalidParseModel = "distributor.invalid_request_parse_model"
|
||||
)
|
||||
|
||||
// Custom OAuth provider related messages
|
||||
|
||||
@@ -257,6 +257,7 @@ common.invalid_input: "Invalid input"
|
||||
distributor.invalid_request: "Invalid request: {{.Error}}"
|
||||
distributor.invalid_channel_id: "Invalid channel ID"
|
||||
distributor.channel_disabled: "This channel has been disabled"
|
||||
distributor.affinity_channel_disabled: "The channel selected by channel affinity has been disabled, and retry was stopped by rule. Please contact the administrator"
|
||||
distributor.token_no_model_access: "This token has no access to any models"
|
||||
distributor.token_model_forbidden: "This token has no access to model {{.Model}}"
|
||||
distributor.model_name_required: "Model name not specified, model name cannot be empty"
|
||||
|
||||
@@ -258,6 +258,7 @@ common.invalid_input: "输入不合法"
|
||||
distributor.invalid_request: "无效的请求,{{.Error}}"
|
||||
distributor.invalid_channel_id: "无效的渠道 Id"
|
||||
distributor.channel_disabled: "该渠道已被禁用"
|
||||
distributor.affinity_channel_disabled: "渠道亲和性命中的渠道已被禁用,已按规则停止重试,请联系管理员处理"
|
||||
distributor.token_no_model_access: "该令牌无权访问任何模型"
|
||||
distributor.token_model_forbidden: "该令牌无权访问模型 {{.Model}}"
|
||||
distributor.model_name_required: "未指定模型名称,模型名称不能为空"
|
||||
|
||||
@@ -258,6 +258,7 @@ common.invalid_input: "輸入不合法"
|
||||
distributor.invalid_request: "無效的請求,{{.Error}}"
|
||||
distributor.invalid_channel_id: "無效的管道 Id"
|
||||
distributor.channel_disabled: "該管道已被禁用"
|
||||
distributor.affinity_channel_disabled: "管道親和性命中的管道已被禁用,已按規則停止重試,請聯絡管理員處理"
|
||||
distributor.token_no_model_access: "該令牌無權存取任何模型"
|
||||
distributor.token_model_forbidden: "該令牌無權存取模型 {{.Model}}"
|
||||
distributor.model_name_required: "未指定模型名稱,模型名稱不能為空"
|
||||
|
||||
@@ -104,7 +104,7 @@ func Distribute() func(c *gin.Context) {
|
||||
if err == nil && preferred != nil {
|
||||
if preferred.Status != common.ChannelStatusEnabled {
|
||||
if service.ShouldSkipRetryAfterChannelAffinityFailure(c) {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, i18n.T(c, i18n.MsgDistributorChannelDisabled))
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, i18n.T(c, i18n.MsgDistributorAffinityChannelDisabled))
|
||||
return
|
||||
}
|
||||
} else if usingGroup == "auto" {
|
||||
|
||||
@@ -578,6 +578,9 @@ func handleConfigUpdate(key, value string) bool {
|
||||
performance_setting.UpdateAndSync()
|
||||
} else if configName == "tool_price_setting" {
|
||||
operation_setting.RebuildToolPriceIndex()
|
||||
} else if configName == "billing_setting" {
|
||||
InvalidatePricingCache()
|
||||
ratio_setting.InvalidateExposedDataCache()
|
||||
}
|
||||
|
||||
return true // 已处理
|
||||
|
||||
@@ -36,30 +36,32 @@ func insertSubscriptionPlanForPaymentGuardTest(t *testing.T, id int) *Subscripti
|
||||
return plan
|
||||
}
|
||||
|
||||
func insertSubscriptionOrderForPaymentGuardTest(t *testing.T, tradeNo string, userID int, planID int, paymentMethod string) {
|
||||
func insertSubscriptionOrderForPaymentGuardTest(t *testing.T, tradeNo string, userID int, planID int, paymentProvider string) {
|
||||
t.Helper()
|
||||
order := &SubscriptionOrder{
|
||||
UserId: userID,
|
||||
PlanId: planID,
|
||||
Money: 9.99,
|
||||
TradeNo: tradeNo,
|
||||
PaymentMethod: paymentMethod,
|
||||
Status: common.TopUpStatusPending,
|
||||
CreateTime: time.Now().Unix(),
|
||||
UserId: userID,
|
||||
PlanId: planID,
|
||||
Money: 9.99,
|
||||
TradeNo: tradeNo,
|
||||
PaymentMethod: paymentProvider,
|
||||
PaymentProvider: paymentProvider,
|
||||
Status: common.TopUpStatusPending,
|
||||
CreateTime: time.Now().Unix(),
|
||||
}
|
||||
require.NoError(t, order.Insert())
|
||||
}
|
||||
|
||||
func insertTopUpForPaymentGuardTest(t *testing.T, tradeNo string, userID int, paymentMethod string) {
|
||||
func insertTopUpForPaymentGuardTest(t *testing.T, tradeNo string, userID int, paymentProvider string) {
|
||||
t.Helper()
|
||||
topUp := &TopUp{
|
||||
UserId: userID,
|
||||
Amount: 2,
|
||||
Money: 9.99,
|
||||
TradeNo: tradeNo,
|
||||
PaymentMethod: paymentMethod,
|
||||
Status: common.TopUpStatusPending,
|
||||
CreateTime: time.Now().Unix(),
|
||||
UserId: userID,
|
||||
Amount: 2,
|
||||
Money: 9.99,
|
||||
TradeNo: tradeNo,
|
||||
PaymentMethod: paymentProvider,
|
||||
PaymentProvider: paymentProvider,
|
||||
Status: common.TopUpStatusPending,
|
||||
CreateTime: time.Now().Unix(),
|
||||
}
|
||||
require.NoError(t, topUp.Insert())
|
||||
}
|
||||
@@ -89,7 +91,7 @@ func TestRechargeWaffoPancake_RejectsMismatchedPaymentMethod(t *testing.T) {
|
||||
truncateTables(t)
|
||||
|
||||
insertUserForPaymentGuardTest(t, 101, 0)
|
||||
insertTopUpForPaymentGuardTest(t, "waffo-pancake-guard", 101, PaymentMethodStripe)
|
||||
insertTopUpForPaymentGuardTest(t, "waffo-pancake-guard", 101, PaymentProviderStripe)
|
||||
|
||||
err := RechargeWaffoPancake("waffo-pancake-guard")
|
||||
require.Error(t, err)
|
||||
@@ -100,27 +102,27 @@ func TestRechargeWaffoPancake_RejectsMismatchedPaymentMethod(t *testing.T) {
|
||||
assert.Equal(t, 0, getUserQuotaForPaymentGuardTest(t, 101))
|
||||
}
|
||||
|
||||
func TestUpdatePendingTopUpStatus_RejectsMismatchedPaymentMethod(t *testing.T) {
|
||||
func TestUpdatePendingTopUpStatus_RejectsMismatchedPaymentProvider(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
tradeNo string
|
||||
storedPaymentMethod string
|
||||
expectedPaymentMethod string
|
||||
targetStatus string
|
||||
name string
|
||||
tradeNo string
|
||||
storedPaymentProvider string
|
||||
expectedPaymentProvider string
|
||||
targetStatus string
|
||||
}{
|
||||
{
|
||||
name: "stripe expire",
|
||||
tradeNo: "stripe-expire-guard",
|
||||
storedPaymentMethod: PaymentMethodCreem,
|
||||
expectedPaymentMethod: PaymentMethodStripe,
|
||||
targetStatus: common.TopUpStatusExpired,
|
||||
name: "stripe expire",
|
||||
tradeNo: "stripe-expire-guard",
|
||||
storedPaymentProvider: PaymentProviderCreem,
|
||||
expectedPaymentProvider: PaymentProviderStripe,
|
||||
targetStatus: common.TopUpStatusExpired,
|
||||
},
|
||||
{
|
||||
name: "waffo failed",
|
||||
tradeNo: "waffo-failed-guard",
|
||||
storedPaymentMethod: PaymentMethodStripe,
|
||||
expectedPaymentMethod: PaymentMethodWaffo,
|
||||
targetStatus: common.TopUpStatusFailed,
|
||||
name: "waffo failed",
|
||||
tradeNo: "waffo-failed-guard",
|
||||
storedPaymentProvider: PaymentProviderStripe,
|
||||
expectedPaymentProvider: PaymentProviderWaffo,
|
||||
targetStatus: common.TopUpStatusFailed,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -128,23 +130,23 @@ func TestUpdatePendingTopUpStatus_RejectsMismatchedPaymentMethod(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
truncateTables(t)
|
||||
insertUserForPaymentGuardTest(t, 150, 0)
|
||||
insertTopUpForPaymentGuardTest(t, tc.tradeNo, 150, tc.storedPaymentMethod)
|
||||
insertTopUpForPaymentGuardTest(t, tc.tradeNo, 150, tc.storedPaymentProvider)
|
||||
|
||||
err := UpdatePendingTopUpStatus(tc.tradeNo, tc.expectedPaymentMethod, tc.targetStatus)
|
||||
err := UpdatePendingTopUpStatus(tc.tradeNo, tc.expectedPaymentProvider, tc.targetStatus)
|
||||
require.ErrorIs(t, err, ErrPaymentMethodMismatch)
|
||||
assert.Equal(t, common.TopUpStatusPending, getTopUpStatusForPaymentGuardTest(t, tc.tradeNo))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompleteSubscriptionOrder_RejectsMismatchedPaymentMethod(t *testing.T) {
|
||||
func TestCompleteSubscriptionOrder_RejectsMismatchedPaymentProvider(t *testing.T) {
|
||||
truncateTables(t)
|
||||
|
||||
insertUserForPaymentGuardTest(t, 202, 0)
|
||||
plan := insertSubscriptionPlanForPaymentGuardTest(t, 301)
|
||||
insertSubscriptionOrderForPaymentGuardTest(t, "sub-guard-order", 202, plan.Id, PaymentMethodStripe)
|
||||
insertSubscriptionOrderForPaymentGuardTest(t, "sub-guard-order", 202, plan.Id, PaymentProviderStripe)
|
||||
|
||||
err := CompleteSubscriptionOrder("sub-guard-order", `{"provider":"epay"}`, "alipay")
|
||||
err := CompleteSubscriptionOrder("sub-guard-order", `{"provider":"epay"}`, PaymentProviderEpay, "alipay")
|
||||
require.ErrorIs(t, err, ErrPaymentMethodMismatch)
|
||||
|
||||
order := GetSubscriptionOrderByTradeNo("sub-guard-order")
|
||||
@@ -156,14 +158,14 @@ func TestCompleteSubscriptionOrder_RejectsMismatchedPaymentMethod(t *testing.T)
|
||||
assert.Nil(t, topUp)
|
||||
}
|
||||
|
||||
func TestExpireSubscriptionOrder_RejectsMismatchedPaymentMethod(t *testing.T) {
|
||||
func TestExpireSubscriptionOrder_RejectsMismatchedPaymentProvider(t *testing.T) {
|
||||
truncateTables(t)
|
||||
|
||||
insertUserForPaymentGuardTest(t, 303, 0)
|
||||
plan := insertSubscriptionPlanForPaymentGuardTest(t, 401)
|
||||
insertSubscriptionOrderForPaymentGuardTest(t, "sub-expire-guard", 303, plan.Id, PaymentMethodStripe)
|
||||
insertSubscriptionOrderForPaymentGuardTest(t, "sub-expire-guard", 303, plan.Id, PaymentProviderStripe)
|
||||
|
||||
err := ExpireSubscriptionOrder("sub-expire-guard", PaymentMethodCreem)
|
||||
err := ExpireSubscriptionOrder("sub-expire-guard", PaymentProviderCreem)
|
||||
require.ErrorIs(t, err, ErrPaymentMethodMismatch)
|
||||
|
||||
order := GetSubscriptionOrderByTradeNo("sub-expire-guard")
|
||||
|
||||
+10
-1
@@ -77,6 +77,15 @@ func GetPricing() []Pricing {
|
||||
return pricingMap
|
||||
}
|
||||
|
||||
func InvalidatePricingCache() {
|
||||
updatePricingLock.Lock()
|
||||
defer updatePricingLock.Unlock()
|
||||
|
||||
pricingMap = nil
|
||||
vendorsList = nil
|
||||
lastGetPricingTime = time.Time{}
|
||||
}
|
||||
|
||||
// GetVendors 返回当前定价接口使用到的供应商信息
|
||||
func GetVendors() []PricingVendor {
|
||||
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
|
||||
@@ -323,7 +332,7 @@ func updatePricing() {
|
||||
pricing.AudioCompletionRatio = &audioCompletionRatio
|
||||
}
|
||||
if billingMode := billing_setting.GetBillingMode(model); billingMode == "tiered_expr" {
|
||||
if expr, ok := billing_setting.GetBillingExpr(model); ok && expr != "" {
|
||||
if expr, ok := billing_setting.GetBillingExpr(model); ok && strings.TrimSpace(expr) != "" {
|
||||
pricing.BillingMode = billingMode
|
||||
pricing.BillingExpr = expr
|
||||
}
|
||||
|
||||
+15
-9
@@ -198,11 +198,12 @@ type SubscriptionOrder struct {
|
||||
PlanId int `json:"plan_id" gorm:"index"`
|
||||
Money float64 `json:"money"`
|
||||
|
||||
TradeNo string `json:"trade_no" gorm:"unique;type:varchar(255);index"`
|
||||
PaymentMethod string `json:"payment_method" gorm:"type:varchar(50)"`
|
||||
Status string `json:"status"`
|
||||
CreateTime int64 `json:"create_time"`
|
||||
CompleteTime int64 `json:"complete_time"`
|
||||
TradeNo string `json:"trade_no" gorm:"unique;type:varchar(255);index"`
|
||||
PaymentMethod string `json:"payment_method" gorm:"type:varchar(50)"`
|
||||
PaymentProvider string `json:"payment_provider" gorm:"type:varchar(50);default:''"`
|
||||
Status string `json:"status"`
|
||||
CreateTime int64 `json:"create_time"`
|
||||
CompleteTime int64 `json:"complete_time"`
|
||||
|
||||
ProviderPayload string `json:"provider_payload" gorm:"type:text"`
|
||||
}
|
||||
@@ -505,7 +506,9 @@ func CreateUserSubscriptionFromPlanTx(tx *gorm.DB, userId int, plan *Subscriptio
|
||||
}
|
||||
|
||||
// Complete a subscription order (idempotent). Creates a UserSubscription snapshot from the plan.
|
||||
func CompleteSubscriptionOrder(tradeNo string, providerPayload string, expectedPaymentMethod string) error {
|
||||
// expectedPaymentProvider guards against cross-gateway callback attacks (empty skips the check).
|
||||
// actualPaymentMethod updates the order's PaymentMethod to reflect the real payment type used (empty skips update).
|
||||
func CompleteSubscriptionOrder(tradeNo string, providerPayload string, expectedPaymentProvider string, actualPaymentMethod string) error {
|
||||
if tradeNo == "" {
|
||||
return errors.New("tradeNo is empty")
|
||||
}
|
||||
@@ -523,7 +526,7 @@ func CompleteSubscriptionOrder(tradeNo string, providerPayload string, expectedP
|
||||
if err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", tradeNo).First(&order).Error; err != nil {
|
||||
return ErrSubscriptionOrderNotFound
|
||||
}
|
||||
if expectedPaymentMethod != "" && order.PaymentMethod != expectedPaymentMethod {
|
||||
if expectedPaymentProvider != "" && order.PaymentProvider != expectedPaymentProvider {
|
||||
return ErrPaymentMethodMismatch
|
||||
}
|
||||
if order.Status == common.TopUpStatusSuccess {
|
||||
@@ -552,6 +555,9 @@ func CompleteSubscriptionOrder(tradeNo string, providerPayload string, expectedP
|
||||
if providerPayload != "" {
|
||||
order.ProviderPayload = providerPayload
|
||||
}
|
||||
if actualPaymentMethod != "" && order.PaymentMethod != actualPaymentMethod {
|
||||
order.PaymentMethod = actualPaymentMethod
|
||||
}
|
||||
if err := tx.Save(&order).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -610,7 +616,7 @@ func upsertSubscriptionTopUpTx(tx *gorm.DB, order *SubscriptionOrder) error {
|
||||
return tx.Save(&topup).Error
|
||||
}
|
||||
|
||||
func ExpireSubscriptionOrder(tradeNo string, expectedPaymentMethod string) error {
|
||||
func ExpireSubscriptionOrder(tradeNo string, expectedPaymentProvider string) error {
|
||||
if tradeNo == "" {
|
||||
return errors.New("tradeNo is empty")
|
||||
}
|
||||
@@ -623,7 +629,7 @@ func ExpireSubscriptionOrder(tradeNo string, expectedPaymentMethod string) error
|
||||
if err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", tradeNo).First(&order).Error; err != nil {
|
||||
return ErrSubscriptionOrderNotFound
|
||||
}
|
||||
if expectedPaymentMethod != "" && order.PaymentMethod != expectedPaymentMethod {
|
||||
if expectedPaymentProvider != "" && order.PaymentProvider != expectedPaymentProvider {
|
||||
return ErrPaymentMethodMismatch
|
||||
}
|
||||
if order.Status != common.TopUpStatusPending {
|
||||
|
||||
+25
-16
@@ -12,15 +12,16 @@ import (
|
||||
)
|
||||
|
||||
type TopUp struct {
|
||||
Id int `json:"id"`
|
||||
UserId int `json:"user_id" gorm:"index"`
|
||||
Amount int64 `json:"amount"`
|
||||
Money float64 `json:"money"`
|
||||
TradeNo string `json:"trade_no" gorm:"unique;type:varchar(255);index"`
|
||||
PaymentMethod string `json:"payment_method" gorm:"type:varchar(50)"`
|
||||
CreateTime int64 `json:"create_time"`
|
||||
CompleteTime int64 `json:"complete_time"`
|
||||
Status string `json:"status"`
|
||||
Id int `json:"id"`
|
||||
UserId int `json:"user_id" gorm:"index"`
|
||||
Amount int64 `json:"amount"`
|
||||
Money float64 `json:"money"`
|
||||
TradeNo string `json:"trade_no" gorm:"unique;type:varchar(255);index"`
|
||||
PaymentMethod string `json:"payment_method" gorm:"type:varchar(50)"`
|
||||
PaymentProvider string `json:"payment_provider" gorm:"type:varchar(50);default:''"`
|
||||
CreateTime int64 `json:"create_time"`
|
||||
CompleteTime int64 `json:"complete_time"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -30,6 +31,14 @@ const (
|
||||
PaymentMethodWaffoPancake = "waffo_pancake"
|
||||
)
|
||||
|
||||
const (
|
||||
PaymentProviderEpay = "epay"
|
||||
PaymentProviderStripe = "stripe"
|
||||
PaymentProviderCreem = "creem"
|
||||
PaymentProviderWaffo = "waffo"
|
||||
PaymentProviderWaffoPancake = "waffo_pancake"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrPaymentMethodMismatch = errors.New("payment method mismatch")
|
||||
ErrTopUpNotFound = errors.New("topup not found")
|
||||
@@ -68,7 +77,7 @@ func GetTopUpByTradeNo(tradeNo string) *TopUp {
|
||||
return topUp
|
||||
}
|
||||
|
||||
func UpdatePendingTopUpStatus(tradeNo string, expectedPaymentMethod string, targetStatus string) error {
|
||||
func UpdatePendingTopUpStatus(tradeNo string, expectedPaymentProvider string, targetStatus string) error {
|
||||
if tradeNo == "" {
|
||||
return errors.New("未提供支付单号")
|
||||
}
|
||||
@@ -83,7 +92,7 @@ func UpdatePendingTopUpStatus(tradeNo string, expectedPaymentMethod string, targ
|
||||
if err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", tradeNo).First(topUp).Error; err != nil {
|
||||
return ErrTopUpNotFound
|
||||
}
|
||||
if expectedPaymentMethod != "" && topUp.PaymentMethod != expectedPaymentMethod {
|
||||
if expectedPaymentProvider != "" && topUp.PaymentProvider != expectedPaymentProvider {
|
||||
return ErrPaymentMethodMismatch
|
||||
}
|
||||
if topUp.Status != common.TopUpStatusPending {
|
||||
@@ -114,7 +123,7 @@ func Recharge(referenceId string, customerId string, callerIp string) (err error
|
||||
return errors.New("充值订单不存在")
|
||||
}
|
||||
|
||||
if topUp.PaymentMethod != PaymentMethodStripe {
|
||||
if topUp.PaymentProvider != PaymentProviderStripe {
|
||||
return ErrPaymentMethodMismatch
|
||||
}
|
||||
|
||||
@@ -340,7 +349,7 @@ func ManualCompleteTopUp(tradeNo string, callerIp string) error {
|
||||
// 计算应充值额度:
|
||||
// - Stripe 订单:Money 代表经分组倍率换算后的美元数量,直接 * QuotaPerUnit
|
||||
// - 其他订单(如易支付):Amount 为美元数量,* QuotaPerUnit
|
||||
if topUp.PaymentMethod == PaymentMethodStripe {
|
||||
if topUp.PaymentProvider == PaymentProviderStripe {
|
||||
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
|
||||
quotaToAdd = int(decimal.NewFromFloat(topUp.Money).Mul(dQuotaPerUnit).IntPart())
|
||||
} else {
|
||||
@@ -397,7 +406,7 @@ func RechargeCreem(referenceId string, customerEmail string, customerName string
|
||||
return errors.New("充值订单不存在")
|
||||
}
|
||||
|
||||
if topUp.PaymentMethod != PaymentMethodCreem {
|
||||
if topUp.PaymentProvider != PaymentProviderCreem {
|
||||
return ErrPaymentMethodMismatch
|
||||
}
|
||||
|
||||
@@ -472,7 +481,7 @@ func RechargeWaffo(tradeNo string, callerIp string) (err error) {
|
||||
return errors.New("充值订单不存在")
|
||||
}
|
||||
|
||||
if topUp.PaymentMethod != PaymentMethodWaffo {
|
||||
if topUp.PaymentProvider != PaymentProviderWaffo {
|
||||
return ErrPaymentMethodMismatch
|
||||
}
|
||||
|
||||
@@ -535,7 +544,7 @@ func RechargeWaffoPancake(tradeNo string) (err error) {
|
||||
return errors.New("充值订单不存在")
|
||||
}
|
||||
|
||||
if topUp.PaymentMethod != PaymentMethodWaffoPancake {
|
||||
if topUp.PaymentProvider != PaymentProviderWaffoPancake {
|
||||
return ErrPaymentMethodMismatch
|
||||
}
|
||||
|
||||
|
||||
@@ -50,6 +50,8 @@ type User struct {
|
||||
Setting string `json:"setting" gorm:"type:text;column:setting"`
|
||||
Remark string `json:"remark,omitempty" gorm:"type:varchar(255)" validate:"max=255"`
|
||||
StripeCustomer string `json:"stripe_customer" gorm:"type:varchar(64);column:stripe_customer;index"`
|
||||
CreatedAt int64 `json:"created_at" gorm:"autoCreateTime;column:created_at"`
|
||||
LastLoginAt int64 `json:"last_login_at" gorm:"default:0;column:last_login_at"`
|
||||
}
|
||||
|
||||
func (user *User) ToBaseUser() *UserBase {
|
||||
@@ -951,6 +953,12 @@ func GetRootUser() (user *User) {
|
||||
return user
|
||||
}
|
||||
|
||||
func UpdateUserLastLoginAt(id int) {
|
||||
if err := DB.Model(&User{}).Where("id = ?", id).Update("last_login_at", common.GetTimestamp()).Error; err != nil {
|
||||
common.SysLog("failed to update user last_login_at: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
|
||||
if common.BatchUpdateEnabled {
|
||||
addNewRecord(BatchUpdateTypeUsedQuota, id, quota)
|
||||
|
||||
@@ -1000,11 +1000,82 @@ func TestImageAudioZero(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// len variable tests — tier conditions based on context length
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const lenTieredExpr = `len <= 200000 ? tier("standard", p * 3 + c * 15 + cr * 0.3) : tier("long_context", p * 6 + c * 22.5 + cr * 0.6)`
|
||||
|
||||
func TestLen_StandardTier(t *testing.T) {
|
||||
params := billingexpr.TokenParams{P: 80000, C: 5000, Len: 100000, CR: 20000}
|
||||
cost, trace, err := billingexpr.RunExpr(lenTieredExpr, params)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
want := 80000*3 + 5000*15 + 20000*0.3
|
||||
if math.Abs(cost-want) > 1e-6 {
|
||||
t.Errorf("cost = %f, want %f", cost, want)
|
||||
}
|
||||
if trace.MatchedTier != "standard" {
|
||||
t.Errorf("tier = %q, want standard", trace.MatchedTier)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLen_LongContextTier(t *testing.T) {
|
||||
// p is low (cache subtracted), but len is high (full context)
|
||||
params := billingexpr.TokenParams{P: 50000, C: 5000, Len: 300000, CR: 250000}
|
||||
cost, trace, err := billingexpr.RunExpr(lenTieredExpr, params)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
want := 50000*6 + 5000*22.5 + 250000*0.6
|
||||
if math.Abs(cost-want) > 1e-6 {
|
||||
t.Errorf("cost = %f, want %f", cost, want)
|
||||
}
|
||||
if trace.MatchedTier != "long_context" {
|
||||
t.Errorf("tier = %q, want long_context (len=300000 > 200000)", trace.MatchedTier)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLen_BoundaryExact(t *testing.T) {
|
||||
params := billingexpr.TokenParams{P: 100000, C: 1000, Len: 200000, CR: 100000}
|
||||
_, trace, err := billingexpr.RunExpr(lenTieredExpr, params)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if trace.MatchedTier != "standard" {
|
||||
t.Errorf("tier = %q, want standard (len=200000 <= 200000)", trace.MatchedTier)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLen_BoundaryPlusOne(t *testing.T) {
|
||||
params := billingexpr.TokenParams{P: 100000, C: 1000, Len: 200001, CR: 100001}
|
||||
_, trace, err := billingexpr.RunExpr(lenTieredExpr, params)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if trace.MatchedTier != "long_context" {
|
||||
t.Errorf("tier = %q, want long_context (len=200001 > 200000)", trace.MatchedTier)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLen_ZeroDefaultsToZero(t *testing.T) {
|
||||
// len defaults to 0 when not set
|
||||
params := billingexpr.TokenParams{P: 1000, C: 500}
|
||||
_, trace, err := billingexpr.RunExpr(lenTieredExpr, params)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if trace.MatchedTier != "standard" {
|
||||
t.Errorf("tier = %q, want standard (len=0 <= 200000)", trace.MatchedTier)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Benchmarks: compile vs cached execution
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const benchComplexExpr = `p <= 200000 ? tier("standard", p * 3 + c * 15 + cr * 0.3 + cc * 3.75 + cc1h * 6 + img * 3 + img_o * 30 + ai * 10 + ao * 40) : tier("long_context", p * 6 + c * 22.5 + cr * 0.6 + cc * 7.5 + cc1h * 12 + img * 6 + img_o * 60 + ai * 20 + ao * 80)`
|
||||
const benchComplexExpr = `len <= 200000 ? tier("standard", p * 3 + c * 15 + cr * 0.3 + cc * 3.75 + cc1h * 6 + img * 3 + img_o * 30 + ai * 10 + ao * 40) : tier("long_context", p * 6 + c * 22.5 + cr * 0.6 + cc * 7.5 + cc1h * 12 + img * 6 + img_o * 60 + ai * 20 + ao * 80)`
|
||||
|
||||
func BenchmarkExprCompile(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
@@ -1015,7 +1086,7 @@ func BenchmarkExprCompile(b *testing.B) {
|
||||
|
||||
func BenchmarkExprRunCached(b *testing.B) {
|
||||
billingexpr.CompileFromCache(benchComplexExpr)
|
||||
params := billingexpr.TokenParams{P: 150000, C: 10000, CR: 30000, CC: 5000, Img: 2000, AI: 1000, AO: 500}
|
||||
params := billingexpr.TokenParams{P: 150000, C: 10000, Len: 188000, CR: 30000, CC: 5000, Img: 2000, AI: 1000, AO: 500}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
billingexpr.RunExpr(benchComplexExpr, params)
|
||||
|
||||
@@ -41,6 +41,7 @@ var (
|
||||
var compileEnvPrototypeV1 = map[string]interface{}{
|
||||
"p": float64(0),
|
||||
"c": float64(0),
|
||||
"len": float64(0),
|
||||
"cr": float64(0),
|
||||
"cc": float64(0),
|
||||
"cc1h": float64(0),
|
||||
|
||||
+16
-3
@@ -30,7 +30,8 @@ Powered by [expr-lang/expr](https://github.com/expr-lang/expr). Expressions are
|
||||
|
||||
| 变量 | 含义 |
|
||||
|------|------|
|
||||
| `p` | 输入 token 数。**自动排除**表达式中单独计价的子类别(见下方说明) |
|
||||
| `p` | 输入 token 数(**计价用**)。**自动排除**表达式中单独计价的子类别(见下方说明) |
|
||||
| `len` | 输入上下文总长度(**条件判断用**)。不受自动排除影响,始终反映完整输入长度。非 Claude:等于原始 `prompt_tokens`;Claude:等于文本输入 + 缓存读取 + 缓存创建 |
|
||||
| `cr` | 缓存命中(读取)token 数 |
|
||||
| `cc` | 缓存创建 token 数(Claude 5分钟 TTL / 通用) |
|
||||
| `cc1h` | 缓存创建 token 数 — 1小时 TTL(Claude 专用) |
|
||||
@@ -51,6 +52,8 @@ Powered by [expr-lang/expr](https://github.com/expr-lang/expr). Expressions are
|
||||
|
||||
**规则:如果表达式使用了某个子类别变量,对应的 token 就从 `p` 或 `c` 中扣除;如果没使用,那些 token 就留在 `p` 或 `c` 里按基础价格计费。**
|
||||
|
||||
> **重要:`len` 不受自动排除影响。** `len` 始终代表完整的输入上下文长度,不管表达式是否单独对缓存/图片/音频定价。因此**阶梯条件应使用 `len` 而非 `p`**,以避免缓存命中导致 `p` 降低而误判档位。
|
||||
|
||||
举例说明(假设上游返回的原始数据:prompt_tokens=1000,其中包含 200 cache read、100 image):
|
||||
|
||||
| 表达式 | `p` 的值 | 说明 |
|
||||
@@ -93,8 +96,8 @@ Powered by [expr-lang/expr](https://github.com/expr-lang/expr). Expressions are
|
||||
# Simple flat pricing
|
||||
tier("base", p * 2.5 + c * 15 + cr * 0.25)
|
||||
|
||||
# Multi-tier (Claude Sonnet style)
|
||||
p <= 200000
|
||||
# Multi-tier (Claude Sonnet style) — use len for tier conditions
|
||||
len <= 200000
|
||||
? tier("standard", p * 3 + c * 15 + cr * 0.3 + cc * 3.75 + cc1h * 6)
|
||||
: tier("long_context", p * 6 + c * 22.5 + cr * 0.6 + cc * 7.5 + cc1h * 12)
|
||||
|
||||
@@ -199,6 +202,16 @@ Example: `p * 2.5 + c * 15 + cr * 0.25`
|
||||
- Expression uses `cr` → cache read tokens subtracted from `p`
|
||||
- Expression doesn't use `img` → image tokens stay in `p`, priced at $2.50
|
||||
|
||||
### `len` — Context Length Variable
|
||||
|
||||
`len` represents the total input context length, designed for **tier condition evaluation** (e.g. `len <= 200000 ? ...`). Unlike `p`, `len` is never reduced by sub-category exclusion.
|
||||
|
||||
**Computation rules:**
|
||||
- **Non-Claude (GPT/OpenAI format)**: `len = prompt_tokens` (the raw total from the upstream response)
|
||||
- **Claude format**: `len = input_tokens + cache_read_tokens + cache_creation_tokens` (since Claude's `input_tokens` is text-only, cache must be added back to reflect full context length)
|
||||
|
||||
This ensures that heavy cache usage doesn't cause the tier condition to incorrectly evaluate to a lower tier. For example, if a request has 300K total context but 250K is cached, `p` with cache subtracted would be only 50K (standard tier), while `len` correctly reports 300K (long-context tier).
|
||||
|
||||
### Quota Conversion
|
||||
|
||||
Expression coefficients are $/1M tokens. Conversion to internal quota:
|
||||
|
||||
@@ -13,7 +13,8 @@ import (
|
||||
|
||||
// RunExpr compiles (with cache) and executes an expression string.
|
||||
// The environment exposes:
|
||||
// - p, c — prompt / completion tokens
|
||||
// - p, c — prompt / completion tokens (auto-excluding separately-priced sub-categories)
|
||||
// - len — total input context length for tier conditions (never reduced by sub-category exclusion)
|
||||
// - cr, cc, cc1h — cache read / creation / creation-1h tokens
|
||||
// - tier(name, value) — trace callback that records which tier matched
|
||||
// - max, min, abs, ceil, floor — standard math helpers
|
||||
@@ -54,6 +55,7 @@ func runProgram(prog *vm.Program, params TokenParams, request RequestInput) (flo
|
||||
env := map[string]interface{}{
|
||||
"p": params.P,
|
||||
"c": params.C,
|
||||
"len": params.Len,
|
||||
"cr": params.CR,
|
||||
"cc": params.CC,
|
||||
"cc1h": params.CC1h,
|
||||
|
||||
@@ -14,8 +14,9 @@ type RequestInput struct {
|
||||
// Fields beyond P and C are optional — when absent they default to 0,
|
||||
// which means cache-unaware expressions keep working unchanged.
|
||||
type TokenParams struct {
|
||||
P float64 // prompt tokens (text)
|
||||
C float64 // completion tokens (text)
|
||||
P float64 // prompt tokens (text) — auto-excludes sub-categories priced separately
|
||||
C float64 // completion tokens (text) — auto-excludes sub-categories priced separately
|
||||
Len float64 // total input context length for tier conditions (non-Claude: raw prompt_tokens; Claude: text + cache read + cache creation)
|
||||
CR float64 // cache read (hit) tokens
|
||||
CC float64 // cache creation tokens (5-min TTL for Claude, generic for others)
|
||||
CC1h float64 // cache creation tokens — 1-hour TTL (Claude only)
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
"github.com/QuantumNous/new-api/relay/channel/claude"
|
||||
@@ -18,12 +19,16 @@ import (
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
IsSyncImageModel bool
|
||||
}
|
||||
|
||||
const aliAnthropicMessagesModelsEnv = "ALI_ANTHROPIC_MESSAGES_MODELS"
|
||||
const defaultAliAnthropicMessagesModels = "qwen,deepseek-v4,kimi,glm,minimax-m"
|
||||
|
||||
/*
|
||||
var syncModels = []string{
|
||||
"z-image",
|
||||
@@ -32,8 +37,22 @@ type Adaptor struct {
|
||||
}
|
||||
*/
|
||||
func supportsAliAnthropicMessages(modelName string) bool {
|
||||
// Only models with the "qwen" designation can use the Claude-compatible interface; others require conversion.
|
||||
return strings.Contains(strings.ToLower(modelName), "qwen")
|
||||
normalizedModelName := strings.ToLower(strings.TrimSpace(modelName))
|
||||
if normalizedModelName == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
return lo.SomeBy(aliAnthropicMessagesModelPatterns(), func(pattern string) bool {
|
||||
return strings.Contains(normalizedModelName, pattern)
|
||||
})
|
||||
}
|
||||
|
||||
func aliAnthropicMessagesModelPatterns() []string {
|
||||
configuredModels := common.GetEnvOrDefaultString(aliAnthropicMessagesModelsEnv, defaultAliAnthropicMessagesModels)
|
||||
return lo.FilterMap(strings.Split(configuredModels, ","), func(item string, _ int) (string, bool) {
|
||||
pattern := strings.ToLower(strings.TrimSpace(item))
|
||||
return pattern, pattern != ""
|
||||
})
|
||||
}
|
||||
|
||||
var syncModels = []string{
|
||||
|
||||
@@ -7,12 +7,14 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
"github.com/QuantumNous/new-api/relay/channel/claude"
|
||||
"github.com/QuantumNous/new-api/relay/channel/openai"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/relay/constant"
|
||||
"github.com/QuantumNous/new-api/setting/reasoning"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -27,7 +29,18 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
|
||||
adaptor := claude.Adaptor{}
|
||||
return adaptor.ConvertClaudeRequest(c, info, req)
|
||||
convertedRequest, err := adaptor.ConvertClaudeRequest(c, info, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
claudeRequest, ok := convertedRequest.(*dto.ClaudeRequest)
|
||||
if !ok {
|
||||
return convertedRequest, nil
|
||||
}
|
||||
if err := applyDeepSeekV4ClaudeThinkingSuffix(info, claudeRequest); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return claudeRequest, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
@@ -71,9 +84,71 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
if err := applyDeepSeekV4OpenAIThinkingSuffix(info, request); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func applyDeepSeekV4OpenAIThinkingSuffix(info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) error {
|
||||
modelName := request.Model
|
||||
if info != nil && info.ChannelMeta != nil && info.UpstreamModelName != "" {
|
||||
modelName = info.UpstreamModelName
|
||||
}
|
||||
baseModel, thinkingType, effort, ok := reasoning.ParseDeepSeekV4ThinkingSuffix(modelName)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
thinking, err := common.Marshal(map[string]string{
|
||||
"type": thinkingType,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("error marshalling thinking: %w", err)
|
||||
}
|
||||
request.Model = baseModel
|
||||
request.THINKING = thinking
|
||||
request.ReasoningEffort = effort
|
||||
if info != nil {
|
||||
if info.ChannelMeta != nil {
|
||||
info.UpstreamModelName = baseModel
|
||||
}
|
||||
info.ReasoningEffort = effort
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func applyDeepSeekV4ClaudeThinkingSuffix(info *relaycommon.RelayInfo, request *dto.ClaudeRequest) error {
|
||||
modelName := request.Model
|
||||
if info != nil && info.ChannelMeta != nil && info.UpstreamModelName != "" {
|
||||
modelName = info.UpstreamModelName
|
||||
}
|
||||
baseModel, thinkingType, effort, ok := reasoning.ParseDeepSeekV4ThinkingSuffix(modelName)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
request.Model = baseModel
|
||||
request.Thinking = &dto.Thinking{Type: thinkingType}
|
||||
if effort == "" {
|
||||
request.OutputConfig = nil
|
||||
} else {
|
||||
outputConfig, err := common.Marshal(map[string]string{
|
||||
"effort": effort,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("error marshalling output_config: %w", err)
|
||||
}
|
||||
request.OutputConfig = outputConfig
|
||||
}
|
||||
if info != nil {
|
||||
if info.ChannelMeta != nil {
|
||||
info.UpstreamModelName = baseModel
|
||||
}
|
||||
info.ReasoningEffort = effort
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@ package deepseek
|
||||
|
||||
var ModelList = []string{
|
||||
"deepseek-chat", "deepseek-reasoner",
|
||||
"deepseek-v4-flash", "deepseek-v4-flash-none", "deepseek-v4-flash-max",
|
||||
"deepseek-v4-pro", "deepseek-v4-pro-none", "deepseek-v4-pro-max",
|
||||
}
|
||||
|
||||
var ChannelName = "deepseek"
|
||||
|
||||
@@ -28,6 +28,7 @@ import (
|
||||
relayconstant "github.com/QuantumNous/new-api/relay/constant"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/setting/model_setting"
|
||||
"github.com/QuantumNous/new-api/setting/reasoning"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
"github.com/samber/lo"
|
||||
|
||||
@@ -39,21 +40,6 @@ type Adaptor struct {
|
||||
ResponseFormat string
|
||||
}
|
||||
|
||||
// parseReasoningEffortFromModelSuffix 从模型名称中解析推理级别
|
||||
// support OAI models: o1-mini/o3-mini/o4-mini/o1/o3 etc...
|
||||
// minimal effort only available in gpt-5
|
||||
func parseReasoningEffortFromModelSuffix(model string) (string, string) {
|
||||
effortSuffixes := []string{"-high", "-minimal", "-low", "-medium", "-none", "-xhigh"}
|
||||
for _, suffix := range effortSuffixes {
|
||||
if strings.HasSuffix(model, suffix) {
|
||||
effort := strings.TrimPrefix(suffix, "-")
|
||||
originModel := strings.TrimSuffix(model, suffix)
|
||||
return effort, originModel
|
||||
}
|
||||
}
|
||||
return "", model
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
|
||||
// 使用 service.GeminiToOpenAIRequest 转换请求格式
|
||||
openaiRequest, err := service.GeminiToOpenAIRequest(request, info)
|
||||
@@ -342,7 +328,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
}
|
||||
|
||||
// 转换模型推理力度后缀
|
||||
effort, originModel := parseReasoningEffortFromModelSuffix(info.UpstreamModelName)
|
||||
effort, originModel := reasoning.ParseOpenAIReasoningEffortFromModelSuffix(info.UpstreamModelName)
|
||||
if effort != "" {
|
||||
request.ReasoningEffort = effort
|
||||
info.UpstreamModelName = originModel
|
||||
@@ -587,7 +573,7 @@ func detectImageMimeType(filename string) string {
|
||||
|
||||
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
|
||||
// 转换模型推理力度后缀
|
||||
effort, originModel := parseReasoningEffortFromModelSuffix(request.Model)
|
||||
effort, originModel := reasoning.ParseOpenAIReasoningEffortFromModelSuffix(request.Model)
|
||||
if effort != "" {
|
||||
if request.Reasoning == nil {
|
||||
request.Reasoning = &dto.Reasoning{
|
||||
|
||||
@@ -408,7 +408,7 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
|
||||
toolCallNameByID[callID] = name
|
||||
}
|
||||
|
||||
newArgs := streamResp.Item.Arguments
|
||||
newArgs := streamResp.Item.ArgumentsString()
|
||||
prevArgs := toolCallArgsByID[callID]
|
||||
argsDelta := ""
|
||||
if newArgs != "" {
|
||||
|
||||
@@ -77,7 +77,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
if !strings.Contains(info.OriginModelName, "-nothinking") {
|
||||
// try to get no thinking model price
|
||||
noThinkingModelName := info.OriginModelName + "-nothinking"
|
||||
containPrice := helper.ContainPriceOrRatio(noThinkingModelName)
|
||||
containPrice := helper.HasModelBillingConfig(noThinkingModelName)
|
||||
if containPrice {
|
||||
info.OriginModelName = noThinkingModelName
|
||||
info.UpstreamModelName = noThinkingModelName
|
||||
|
||||
+11
-11
@@ -2,6 +2,7 @@ package helper
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
@@ -223,20 +224,18 @@ func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) (types
|
||||
return priceData, nil
|
||||
}
|
||||
|
||||
func ContainPriceOrRatio(modelName string) bool {
|
||||
_, ok := ratio_setting.GetModelPrice(modelName, false)
|
||||
if ok {
|
||||
func HasModelBillingConfig(modelName string) bool {
|
||||
if _, ok := ratio_setting.GetModelPrice(modelName, false); ok {
|
||||
return true
|
||||
}
|
||||
_, ok, _ = ratio_setting.GetModelRatio(modelName)
|
||||
if ok {
|
||||
if _, ok, _ := ratio_setting.GetModelRatio(modelName); ok {
|
||||
return true
|
||||
}
|
||||
if billing_setting.GetBillingMode(modelName) == billing_setting.BillingModeTieredExpr {
|
||||
_, ok = billing_setting.GetBillingExpr(modelName)
|
||||
return ok
|
||||
if billing_setting.GetBillingMode(modelName) != billing_setting.BillingModeTieredExpr {
|
||||
return false
|
||||
}
|
||||
return false
|
||||
expr, ok := billing_setting.GetBillingExpr(modelName)
|
||||
return ok && strings.TrimSpace(expr) != ""
|
||||
}
|
||||
|
||||
func modelPriceHelperTiered(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, meta *types.TokenCountMeta, groupRatioInfo types.GroupRatioInfo) (types.PriceData, error) {
|
||||
@@ -256,8 +255,9 @@ func modelPriceHelperTiered(c *gin.Context, info *relaycommon.RelayInfo, promptT
|
||||
}
|
||||
|
||||
rawCost, trace, err := billingexpr.RunExprWithRequest(exprStr, billingexpr.TokenParams{
|
||||
P: float64(promptTokens),
|
||||
C: float64(estimatedCompletionTokens),
|
||||
P: float64(promptTokens),
|
||||
C: float64(estimatedCompletionTokens),
|
||||
Len: float64(promptTokens),
|
||||
}, requestInput)
|
||||
if err != nil {
|
||||
return types.PriceData{}, fmt.Errorf("model %s tiered expr run failed: %w", info.OriginModelName, err)
|
||||
|
||||
@@ -60,7 +60,7 @@ func ResponsesResponseToChatCompletionsResponse(resp *dto.OpenAIResponsesRespons
|
||||
Type: "function",
|
||||
Function: dto.FunctionResponse{
|
||||
Name: name,
|
||||
Arguments: out.Arguments,
|
||||
Arguments: out.ArgumentsString(),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
+3
-2
@@ -160,8 +160,9 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
|
||||
|
||||
var tieredResult *billingexpr.TieredResult
|
||||
tieredOk, tieredQuota, tieredRes := TryTieredSettle(relayInfo, billingexpr.TokenParams{
|
||||
P: float64(usage.InputTokens),
|
||||
C: float64(usage.OutputTokens),
|
||||
P: float64(usage.InputTokens),
|
||||
C: float64(usage.OutputTokens),
|
||||
Len: float64(usage.InputTokens),
|
||||
})
|
||||
if tieredOk {
|
||||
tieredResult = tieredRes
|
||||
|
||||
@@ -35,6 +35,14 @@ func BuildTieredTokenParams(usage *dto.Usage, isClaudeUsageSemantic bool, usedVa
|
||||
imgO := float64(usage.CompletionTokenDetails.ImageTokens)
|
||||
ao := float64(usage.CompletionTokenDetails.AudioTokens)
|
||||
|
||||
// len = total input context length for tier condition evaluation.
|
||||
// Non-Claude: prompt_tokens already includes everything.
|
||||
// Claude: input_tokens is text-only, so add cache read + cache creation.
|
||||
inputLen := p
|
||||
if isClaudeUsageSemantic {
|
||||
inputLen = p + cr + cc5m + cc1h
|
||||
}
|
||||
|
||||
if !isClaudeUsageSemantic {
|
||||
if usedVars["cr"] {
|
||||
p -= cr
|
||||
@@ -69,6 +77,7 @@ func BuildTieredTokenParams(usage *dto.Usage, isClaudeUsageSemantic bool, usedVa
|
||||
return billingexpr.TokenParams{
|
||||
P: p,
|
||||
C: c,
|
||||
Len: inputLen,
|
||||
CR: cr,
|
||||
CC: cc5m,
|
||||
CC1h: cc1h,
|
||||
|
||||
@@ -604,6 +604,97 @@ func TestBuildTieredTokenParams_ParityWithRatio_Image(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// BuildTieredTokenParams: Len computation tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestBuildTieredTokenParams_Len_GPT(t *testing.T) {
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: 10000,
|
||||
CompletionTokens: 2000,
|
||||
PromptTokensDetails: dto.InputTokenDetails{
|
||||
CachedTokens: 3000,
|
||||
TextTokens: 7000,
|
||||
},
|
||||
}
|
||||
expr := `tier("base", p * 2.5 + c * 15 + cr * 0.25)`
|
||||
usedVars := billingexpr.UsedVars(expr)
|
||||
params := BuildTieredTokenParams(usage, false, usedVars)
|
||||
|
||||
// Non-Claude: Len = raw PromptTokens
|
||||
if params.Len != 10000 {
|
||||
t.Fatalf("Len = %f, want 10000 (raw PromptTokens)", params.Len)
|
||||
}
|
||||
// P should be reduced by cache
|
||||
if params.P != 7000 {
|
||||
t.Fatalf("P = %f, want 7000 (PromptTokens - CachedTokens)", params.P)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildTieredTokenParams_Len_Claude(t *testing.T) {
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: 5000,
|
||||
CompletionTokens: 2000,
|
||||
UsageSemantic: "anthropic",
|
||||
PromptTokensDetails: dto.InputTokenDetails{
|
||||
CachedTokens: 3000,
|
||||
TextTokens: 5000,
|
||||
},
|
||||
ClaudeCacheCreation5mTokens: 1000,
|
||||
ClaudeCacheCreation1hTokens: 500,
|
||||
}
|
||||
expr := `tier("base", p * 3 + c * 15 + cr * 0.3 + cc * 3.75 + cc1h * 6)`
|
||||
usedVars := billingexpr.UsedVars(expr)
|
||||
params := BuildTieredTokenParams(usage, true, usedVars)
|
||||
|
||||
// Claude: Len = PromptTokens + CachedTokens + CacheCreation5m + CacheCreation1h
|
||||
wantLen := float64(5000 + 3000 + 1000 + 500)
|
||||
if params.Len != wantLen {
|
||||
t.Fatalf("Len = %f, want %f (text + cache read + cache creation)", params.Len, wantLen)
|
||||
}
|
||||
// Claude: P is not reduced (isClaudeUsageSemantic = true)
|
||||
if params.P != 5000 {
|
||||
t.Fatalf("P = %f, want 5000 (no subtraction for Claude)", params.P)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildTieredTokenParams_Len_TierCondition(t *testing.T) {
|
||||
// Test that len-based tier conditions work correctly when p is reduced by cache
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: 300000,
|
||||
CompletionTokens: 5000,
|
||||
PromptTokensDetails: dto.InputTokenDetails{
|
||||
CachedTokens: 250000,
|
||||
TextTokens: 50000,
|
||||
},
|
||||
}
|
||||
expr := `len <= 200000 ? tier("standard", p * 3 + c * 15 + cr * 0.3) : tier("long_context", p * 6 + c * 22.5 + cr * 0.6)`
|
||||
usedVars := billingexpr.UsedVars(expr)
|
||||
params := BuildTieredTokenParams(usage, false, usedVars)
|
||||
|
||||
// Len = 300000 (raw prompt), P = 50000 (300000 - 250000 cache)
|
||||
if params.Len != 300000 {
|
||||
t.Fatalf("Len = %f, want 300000", params.Len)
|
||||
}
|
||||
if params.P != 50000 {
|
||||
t.Fatalf("P = %f, want 50000", params.P)
|
||||
}
|
||||
|
||||
// Run expression: len=300000 > 200000, so long_context tier
|
||||
cost, trace, err := billingexpr.RunExpr(expr, params)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if trace.MatchedTier != "long_context" {
|
||||
t.Fatalf("tier = %s, want long_context (len=300000 but p=50000)", trace.MatchedTier)
|
||||
}
|
||||
// long_context: 50000*6 + 5000*22.5 + 250000*0.6
|
||||
wantCost := 50000.0*6 + 5000*22.5 + 250000*0.6
|
||||
if math.Abs(cost-wantCost) > 1e-6 {
|
||||
t.Fatalf("cost = %f, want %f", cost, wantCost)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Stress test: 1000 concurrent goroutines, complex tiered expr vs ratio,
|
||||
// random token counts, verify correctness and measure performance
|
||||
|
||||
@@ -5,11 +5,14 @@ import (
|
||||
|
||||
"github.com/QuantumNous/new-api/pkg/billingexpr"
|
||||
"github.com/QuantumNous/new-api/setting/config"
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
const (
|
||||
BillingModeRatio = "ratio"
|
||||
BillingModeTieredExpr = "tiered_expr"
|
||||
BillingModeField = "billing_mode"
|
||||
BillingExprField = "billing_expr"
|
||||
)
|
||||
|
||||
// BillingSetting is managed by config.GlobalConfig.Register.
|
||||
@@ -44,6 +47,25 @@ func GetBillingExpr(model string) (string, bool) {
|
||||
return expr, ok
|
||||
}
|
||||
|
||||
func GetBillingModeCopy() map[string]string {
|
||||
return lo.Assign(billingSetting.BillingMode)
|
||||
}
|
||||
|
||||
func GetBillingExprCopy() map[string]string {
|
||||
return lo.Assign(billingSetting.BillingExpr)
|
||||
}
|
||||
|
||||
func GetPricingSyncData(base map[string]any) map[string]any {
|
||||
extra := make(map[string]any, 2)
|
||||
if modes := GetBillingModeCopy(); len(modes) > 0 {
|
||||
extra[BillingModeField] = modes
|
||||
}
|
||||
if exprs := GetBillingExprCopy(); len(exprs) > 0 {
|
||||
extra[BillingExprField] = exprs
|
||||
}
|
||||
return lo.Assign(base, extra)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Smoke test (called externally for validation before save)
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -54,10 +76,10 @@ func SmokeTestExpr(exprStr string) error {
|
||||
|
||||
func smokeTestExpr(exprStr string) error {
|
||||
vectors := []billingexpr.TokenParams{
|
||||
{P: 0, C: 0},
|
||||
{P: 1000, C: 1000},
|
||||
{P: 100000, C: 100000},
|
||||
{P: 1000000, C: 1000000},
|
||||
{P: 0, C: 0, Len: 0},
|
||||
{P: 1000, C: 1000, Len: 1000},
|
||||
{P: 100000, C: 100000, Len: 100000},
|
||||
{P: 1000000, C: 1000000, Len: 1000000},
|
||||
}
|
||||
requests := []billingexpr.RequestInput{
|
||||
{},
|
||||
|
||||
@@ -252,8 +252,16 @@ func updateConfigFromMap(config interface{}, configMap map[string]string) error
|
||||
continue
|
||||
}
|
||||
}
|
||||
case reflect.Map, reflect.Slice, reflect.Struct:
|
||||
// 复杂类型使用JSON反序列化
|
||||
case reflect.Map:
|
||||
// json.Unmarshal merges into existing maps (keeps old keys that are
|
||||
// absent from the new JSON). Allocate a fresh map so removed keys
|
||||
// are properly cleared.
|
||||
fresh := reflect.New(field.Type())
|
||||
if err := json.Unmarshal([]byte(strValue), fresh.Interface()); err != nil {
|
||||
continue
|
||||
}
|
||||
field.Set(fresh.Elem())
|
||||
case reflect.Slice, reflect.Struct:
|
||||
err := json.Unmarshal([]byte(strValue), field.Addr().Interface())
|
||||
if err != nil {
|
||||
continue
|
||||
|
||||
@@ -0,0 +1,96 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
type testConfigWithMap struct {
|
||||
Modes map[string]string `json:"modes"`
|
||||
Exprs map[string]string `json:"exprs"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
func TestUpdateConfigFromMap_MapReplacement(t *testing.T) {
|
||||
cfg := &testConfigWithMap{
|
||||
Modes: map[string]string{
|
||||
"model-a": "tiered_expr",
|
||||
"model-b": "tiered_expr",
|
||||
},
|
||||
Exprs: map[string]string{
|
||||
"model-a": "p * 5 + c * 25",
|
||||
"model-b": "p * 10 + c * 50",
|
||||
},
|
||||
Name: "billing",
|
||||
}
|
||||
|
||||
// Simulate removing model-a: new value only has model-b
|
||||
err := UpdateConfigFromMap(cfg, map[string]string{
|
||||
"modes": `{"model-b": "tiered_expr"}`,
|
||||
"exprs": `{"model-b": "p * 10 + c * 50"}`,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("UpdateConfigFromMap failed: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := cfg.Modes["model-a"]; ok {
|
||||
t.Errorf("Modes still contains model-a after it was removed from the update; got %v", cfg.Modes)
|
||||
}
|
||||
if _, ok := cfg.Exprs["model-a"]; ok {
|
||||
t.Errorf("Exprs still contains model-a after it was removed from the update; got %v", cfg.Exprs)
|
||||
}
|
||||
|
||||
if cfg.Modes["model-b"] != "tiered_expr" {
|
||||
t.Errorf("Modes[model-b] = %q, want %q", cfg.Modes["model-b"], "tiered_expr")
|
||||
}
|
||||
if cfg.Exprs["model-b"] != "p * 10 + c * 50" {
|
||||
t.Errorf("Exprs[model-b] = %q, want %q", cfg.Exprs["model-b"], "p * 10 + c * 50")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateConfigFromMap_EmptyMapClearsAll(t *testing.T) {
|
||||
cfg := &testConfigWithMap{
|
||||
Modes: map[string]string{
|
||||
"model-a": "tiered_expr",
|
||||
},
|
||||
Exprs: map[string]string{
|
||||
"model-a": "p * 5 + c * 25",
|
||||
},
|
||||
}
|
||||
|
||||
err := UpdateConfigFromMap(cfg, map[string]string{
|
||||
"modes": `{}`,
|
||||
"exprs": `{}`,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("UpdateConfigFromMap failed: %v", err)
|
||||
}
|
||||
|
||||
if len(cfg.Modes) != 0 {
|
||||
t.Errorf("Modes should be empty after updating with {}, got %v", cfg.Modes)
|
||||
}
|
||||
if len(cfg.Exprs) != 0 {
|
||||
t.Errorf("Exprs should be empty after updating with {}, got %v", cfg.Exprs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateConfigFromMap_ScalarFieldsUnchanged(t *testing.T) {
|
||||
cfg := &testConfigWithMap{
|
||||
Modes: map[string]string{"m": "v"},
|
||||
Name: "old",
|
||||
}
|
||||
|
||||
err := UpdateConfigFromMap(cfg, map[string]string{
|
||||
"name": "new",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("UpdateConfigFromMap failed: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Name != "new" {
|
||||
t.Errorf("Name = %q, want %q", cfg.Name, "new")
|
||||
}
|
||||
// modes was not in configMap, should remain unchanged
|
||||
if cfg.Modes["m"] != "v" {
|
||||
t.Errorf("Modes should be unchanged, got %v", cfg.Modes)
|
||||
}
|
||||
}
|
||||
@@ -709,6 +709,18 @@ func GetCompletionRatioCopy() map[string]float64 {
|
||||
return completionRatioMap.ReadAll()
|
||||
}
|
||||
|
||||
func GetImageRatioCopy() map[string]float64 {
|
||||
return imageRatioMap.ReadAll()
|
||||
}
|
||||
|
||||
func GetAudioRatioCopy() map[string]float64 {
|
||||
return audioRatioMap.ReadAll()
|
||||
}
|
||||
|
||||
func GetAudioCompletionRatioCopy() map[string]float64 {
|
||||
return audioCompletionRatioMap.ReadAll()
|
||||
}
|
||||
|
||||
// 转换模型名,减少渠道必须配置各种带参数模型
|
||||
func FormatMatchingModelName(name string) string {
|
||||
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
package ratio_setting
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestGetCompletionRatioInfoGPT55UsesOfficialOutputMultiplier(t *testing.T) {
|
||||
info := GetCompletionRatioInfo("gpt-5.5")
|
||||
|
||||
if info.Ratio != 6 {
|
||||
t.Fatalf("gpt-5.5 completion ratio = %v, want 6", info.Ratio)
|
||||
}
|
||||
if !info.Locked {
|
||||
t.Fatal("gpt-5.5 completion ratio should be locked to the official multiplier")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCompletionRatioGPT55DatedVariant(t *testing.T) {
|
||||
got := GetCompletionRatio("gpt-5.5-2026-04-24")
|
||||
|
||||
if got != 6 {
|
||||
t.Fatalf("gpt-5.5 dated variant completion ratio = %v, want 6", got)
|
||||
}
|
||||
}
|
||||
@@ -8,9 +8,17 @@ import (
|
||||
|
||||
var EffortSuffixes = []string{"-max", "-xhigh", "-high", "-medium", "-low", "-minimal"}
|
||||
|
||||
var OpenAIEffortSuffixes = []string{"-high", "-minimal", "-low", "-medium", "-none", "-xhigh"}
|
||||
|
||||
var DeepSeekV4EffortSuffixes = []string{"-none", "-max"}
|
||||
|
||||
// TrimEffortSuffix -> modelName level(low) exists
|
||||
func TrimEffortSuffix(modelName string) (string, string, bool) {
|
||||
suffix, found := lo.Find(EffortSuffixes, func(s string) bool {
|
||||
return TrimEffortSuffixWithSuffixes(modelName, EffortSuffixes)
|
||||
}
|
||||
|
||||
func TrimEffortSuffixWithSuffixes(modelName string, suffixes []string) (string, string, bool) {
|
||||
suffix, found := lo.Find(suffixes, func(s string) bool {
|
||||
return strings.HasSuffix(modelName, s)
|
||||
})
|
||||
if !found {
|
||||
@@ -18,3 +26,26 @@ func TrimEffortSuffix(modelName string) (string, string, bool) {
|
||||
}
|
||||
return strings.TrimSuffix(modelName, suffix), strings.TrimPrefix(suffix, "-"), true
|
||||
}
|
||||
|
||||
func ParseOpenAIReasoningEffortFromModelSuffix(modelName string) (string, string) {
|
||||
baseModel, effort, ok := TrimEffortSuffixWithSuffixes(modelName, OpenAIEffortSuffixes)
|
||||
if !ok {
|
||||
return "", modelName
|
||||
}
|
||||
return effort, baseModel
|
||||
}
|
||||
|
||||
func ParseDeepSeekV4ThinkingSuffix(modelName string) (baseModel string, thinkingType string, effort string, ok bool) {
|
||||
baseModel, suffix, ok := TrimEffortSuffixWithSuffixes(modelName, DeepSeekV4EffortSuffixes)
|
||||
if !ok || !strings.HasPrefix(baseModel, "deepseek-v4-") {
|
||||
return modelName, "", "", false
|
||||
}
|
||||
switch suffix {
|
||||
case "none":
|
||||
return baseModel, "disabled", "", true
|
||||
case "max":
|
||||
return baseModel, "enabled", "max", true
|
||||
default:
|
||||
return modelName, "", "", false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -155,8 +155,8 @@ const ChannelSelectorModal = forwardRef(
|
||||
onChange={handleTypeChange}
|
||||
style={{ width: 120 }}
|
||||
optionList={[
|
||||
{ label: 'ratio_config', value: 'ratio_config' },
|
||||
{ label: 'pricing', value: 'pricing' },
|
||||
{ label: 'ratio_config', value: 'ratio_config' },
|
||||
{ label: 'OpenRouter', value: 'openrouter' },
|
||||
{ label: 'custom', value: 'custom' },
|
||||
]}
|
||||
|
||||
@@ -106,7 +106,7 @@ const RatioSetting = () => {
|
||||
<Tabs.TabPane tab={t('未设置价格模型')} itemKey='unset_models'>
|
||||
<ModelRatioNotSetEditor options={inputs} refresh={onRefresh} />
|
||||
</Tabs.TabPane>
|
||||
<Tabs.TabPane tab={t('上游倍率同步')} itemKey='upstream_sync'>
|
||||
<Tabs.TabPane tab={t('上游价格同步')} itemKey='upstream_sync'>
|
||||
<UpstreamRatioSync options={inputs} refresh={onRefresh} />
|
||||
</Tabs.TabPane>
|
||||
<Tabs.TabPane tab={t('工具调用定价')} itemKey='tool_price'>
|
||||
|
||||
@@ -269,6 +269,24 @@ const EditChannelModal = (props) => {
|
||||
return [];
|
||||
}
|
||||
}, [inputs.model_mapping]);
|
||||
const redirectModelKeyList = useMemo(() => {
|
||||
const mapping = inputs.model_mapping;
|
||||
if (typeof mapping !== 'string') return [];
|
||||
const trimmed = mapping.trim();
|
||||
if (!trimmed) return [];
|
||||
try {
|
||||
const parsed = JSON.parse(trimmed);
|
||||
if (!parsed || typeof parsed !== 'object' || Array.isArray(parsed)) {
|
||||
return [];
|
||||
}
|
||||
const keys = Object.keys(parsed)
|
||||
.map((key) => key.trim())
|
||||
.filter((key) => key);
|
||||
return Array.from(new Set(keys));
|
||||
} catch (error) {
|
||||
return [];
|
||||
}
|
||||
}, [inputs.model_mapping]);
|
||||
const upstreamDetectedModels = useMemo(
|
||||
() =>
|
||||
Array.from(
|
||||
@@ -3842,6 +3860,7 @@ const EditChannelModal = (props) => {
|
||||
models={fetchedModels}
|
||||
selected={inputs.models}
|
||||
redirectModels={redirectModelList}
|
||||
redirectSourceModels={redirectModelKeyList}
|
||||
onConfirm={(selectedModels) => {
|
||||
handleInputChange('models', selectedModels);
|
||||
showSuccess(t('模型列表已更新'));
|
||||
|
||||
@@ -43,6 +43,7 @@ const ModelSelectModal = ({
|
||||
models = [],
|
||||
selected = [],
|
||||
redirectModels = [],
|
||||
redirectSourceModels = [],
|
||||
onConfirm,
|
||||
onCancel,
|
||||
}) => {
|
||||
@@ -54,6 +55,14 @@ const ModelSelectModal = ({
|
||||
if (typeof model === 'object' && model.model_name) return model.model_name;
|
||||
return String(model ?? '');
|
||||
};
|
||||
const normalizeModelList = (modelList = []) =>
|
||||
Array.from(
|
||||
new Set(
|
||||
(modelList || [])
|
||||
.map((model) => getModelName(model).trim())
|
||||
.filter(Boolean),
|
||||
),
|
||||
);
|
||||
|
||||
const normalizedSelected = useMemo(
|
||||
() => (selected || []).map(getModelName),
|
||||
@@ -78,6 +87,10 @@ const ModelSelectModal = ({
|
||||
),
|
||||
[redirectModels],
|
||||
);
|
||||
const normalizedRedirectSourceSet = useMemo(
|
||||
() => new Set(normalizeModelList(redirectSourceModels)),
|
||||
[redirectSourceModels],
|
||||
);
|
||||
const normalizedSelectedSet = useMemo(() => {
|
||||
const set = new Set();
|
||||
(selected || []).forEach((model) => {
|
||||
@@ -116,6 +129,16 @@ const ModelSelectModal = ({
|
||||
const existingModels = filteredModels.filter((model) =>
|
||||
isExistingModel(model),
|
||||
);
|
||||
const fetchedModelSet = useMemo(
|
||||
() => new Set(normalizeModelList(models)),
|
||||
[models],
|
||||
);
|
||||
const removedModels = normalizeModelList(selected).filter(
|
||||
(model) =>
|
||||
!fetchedModelSet.has(model) &&
|
||||
!normalizedRedirectSourceSet.has(model) &&
|
||||
model.toLowerCase().includes(keyword.toLowerCase()),
|
||||
);
|
||||
|
||||
// 同步外部选中值
|
||||
useEffect(() => {
|
||||
@@ -127,11 +150,15 @@ const ModelSelectModal = ({
|
||||
// 当模型列表变化时,设置默认tab
|
||||
useEffect(() => {
|
||||
if (visible) {
|
||||
// 默认显示新获取模型tab,如果没有新模型则显示已有模型
|
||||
const hasNewModels = newModels.length > 0;
|
||||
setActiveTab(hasNewModels ? 'new' : 'existing');
|
||||
if (newModels.length > 0) {
|
||||
setActiveTab('new');
|
||||
} else if (removedModels.length > 0) {
|
||||
setActiveTab('removed');
|
||||
} else {
|
||||
setActiveTab('existing');
|
||||
}
|
||||
}
|
||||
}, [visible, newModels.length, selected]);
|
||||
}, [visible, newModels.length, removedModels.length, selected]);
|
||||
|
||||
const handleOk = () => {
|
||||
onConfirm && onConfirm(checkedList);
|
||||
@@ -197,6 +224,14 @@ const ModelSelectModal = ({
|
||||
},
|
||||
]
|
||||
: []),
|
||||
...(removedModels.length > 0
|
||||
? [
|
||||
{
|
||||
tab: `${t('上游已删除的模型')} (${removedModels.length})`,
|
||||
itemKey: 'removed',
|
||||
},
|
||||
]
|
||||
: []),
|
||||
];
|
||||
|
||||
// 处理分类全选/取消全选
|
||||
@@ -343,9 +378,11 @@ const ModelSelectModal = ({
|
||||
showClear
|
||||
/>
|
||||
|
||||
<Spin spinning={!models || models.length === 0}>
|
||||
<Spin
|
||||
spinning={!models || (models.length === 0 && removedModels.length === 0)}
|
||||
>
|
||||
<div style={{ maxHeight: 400, overflowY: 'auto', paddingRight: 8 }}>
|
||||
{filteredModels.length === 0 ? (
|
||||
{filteredModels.length === 0 && removedModels.length === 0 ? (
|
||||
<Empty
|
||||
image={
|
||||
<IllustrationNoResult style={{ width: 150, height: 150 }} />
|
||||
@@ -369,6 +406,14 @@ const ModelSelectModal = ({
|
||||
{renderModelsByCategory(existingModelsByCategory, 'existing')}
|
||||
</div>
|
||||
)}
|
||||
{activeTab === 'removed' && removedModels.length > 0 && (
|
||||
<div>
|
||||
{renderModelsByCategory(
|
||||
categorizeModels(removedModels),
|
||||
'removed',
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</Checkbox.Group>
|
||||
)}
|
||||
</div>
|
||||
@@ -382,7 +427,11 @@ const ModelSelectModal = ({
|
||||
<div className='flex items-center justify-end gap-2'>
|
||||
{(() => {
|
||||
const currentModels =
|
||||
activeTab === 'new' ? newModels : existingModels;
|
||||
activeTab === 'new'
|
||||
? newModels
|
||||
: activeTab === 'removed'
|
||||
? removedModels
|
||||
: existingModels;
|
||||
const currentSelected = currentModels.filter((model) =>
|
||||
checkedList.includes(model),
|
||||
).length;
|
||||
|
||||
@@ -21,7 +21,7 @@ import React from 'react';
|
||||
import { Avatar, Tag, Table, Typography } from '@douyinfe/semi-ui';
|
||||
import { IconPriceTag } from '@douyinfe/semi-icons';
|
||||
import { parseTiersFromExpr, getCurrencyConfig } from '../../../../../helpers';
|
||||
import { BILLING_VARS } from '../../../../../constants';
|
||||
import { BILLING_PRICING_VARS } from '../../../../../constants';
|
||||
import {
|
||||
splitBillingExprAndRequestRules,
|
||||
tryParseRequestRuleExpr,
|
||||
@@ -113,7 +113,7 @@ export default function DynamicPricingBreakdown({ billingExpr, t }) {
|
||||
);
|
||||
}
|
||||
|
||||
const priceFields = BILLING_VARS.map((v) => [v.field, v.shortLabel]);
|
||||
const priceFields = BILLING_PRICING_VARS.map((v) => [v.field, v.shortLabel]);
|
||||
|
||||
const tierColumns = [
|
||||
{
|
||||
|
||||
@@ -29,7 +29,14 @@ import {
|
||||
Dropdown,
|
||||
} from '@douyinfe/semi-ui';
|
||||
import { IconMore } from '@douyinfe/semi-icons';
|
||||
import { renderGroup, renderNumber, renderQuota } from '../../../helpers';
|
||||
import {
|
||||
renderGroup,
|
||||
renderNumber,
|
||||
renderQuota,
|
||||
timestamp2string,
|
||||
} from '../../../helpers';
|
||||
|
||||
const renderTimestamp = (text) => (text ? timestamp2string(text) : '-');
|
||||
|
||||
/**
|
||||
* Render user role
|
||||
@@ -350,6 +357,16 @@ export const getUsersColumns = ({
|
||||
dataIndex: 'invite',
|
||||
render: (text, record, index) => renderInviteInfo(text, record, t),
|
||||
},
|
||||
{
|
||||
title: t('创建时间'),
|
||||
dataIndex: 'created_at',
|
||||
render: renderTimestamp,
|
||||
},
|
||||
{
|
||||
title: t('最后登录'),
|
||||
dataIndex: 'last_login_at',
|
||||
render: renderTimestamp,
|
||||
},
|
||||
{
|
||||
title: '',
|
||||
dataIndex: 'operate',
|
||||
|
||||
+12
-5
@@ -13,6 +13,7 @@
|
||||
export const BILLING_VARS = [
|
||||
{ key: 'p', field: 'inputPrice', tierField: 'input_unit_cost', label: '输入价格', shortLabel: '输入', side: 'input', isBase: true },
|
||||
{ key: 'c', field: 'outputPrice', tierField: 'output_unit_cost', label: '补全价格', shortLabel: '补全', side: 'output', isBase: true },
|
||||
{ key: 'len', field: null, tierField: null, label: '输入长度', shortLabel: '长度', side: 'condition', isConditionOnly: true },
|
||||
{ key: 'cr', field: 'cacheReadPrice', tierField: 'cache_read_unit_cost', label: '缓存读取价格', shortLabel: '缓存读', side: 'input', group: 'cache' },
|
||||
{ key: 'cc', field: 'cacheCreatePrice', tierField: 'cache_create_unit_cost', label: '缓存创建价格', shortLabel: '缓存创建', side: 'input', group: 'cache' },
|
||||
{ key: 'cc1h', field: 'cacheCreate1hPrice', tierField: 'cache_create_1h_unit_cost', label: '1h缓存创建价格', shortLabel: '1h缓存创建', side: 'input', group: 'cache' },
|
||||
@@ -24,18 +25,20 @@ export const BILLING_VARS = [
|
||||
|
||||
export const BILLING_VAR_KEYS = BILLING_VARS.map((v) => v.key);
|
||||
|
||||
export const BILLING_EXTRA_VARS = BILLING_VARS.filter((v) => !v.isBase);
|
||||
export const BILLING_PRICING_VARS = BILLING_VARS.filter((v) => !v.isConditionOnly);
|
||||
|
||||
export const BILLING_EXTRA_VARS = BILLING_VARS.filter((v) => !v.isBase && !v.isConditionOnly);
|
||||
|
||||
export const BILLING_VAR_KEY_TO_FIELD = Object.fromEntries(
|
||||
BILLING_VARS.map((v) => [v.key, v.field]),
|
||||
BILLING_PRICING_VARS.map((v) => [v.key, v.field]),
|
||||
);
|
||||
|
||||
export const BILLING_VAR_FIELD_TO_LABEL = Object.fromEntries(
|
||||
BILLING_VARS.map((v) => [v.field, v.label]),
|
||||
BILLING_PRICING_VARS.map((v) => [v.field, v.label]),
|
||||
);
|
||||
|
||||
export const BILLING_VAR_FIELD_TO_SHORT_LABEL = Object.fromEntries(
|
||||
BILLING_VARS.map((v) => [v.field, v.shortLabel]),
|
||||
BILLING_PRICING_VARS.map((v) => [v.field, v.shortLabel]),
|
||||
);
|
||||
|
||||
export const BILLING_CACHE_VAR_MAP = BILLING_EXTRA_VARS.map((v) => ({
|
||||
@@ -44,6 +47,10 @@ export const BILLING_CACHE_VAR_MAP = BILLING_EXTRA_VARS.map((v) => ({
|
||||
}));
|
||||
|
||||
export const BILLING_VAR_REGEX = new RegExp(
|
||||
`\\b(${BILLING_VAR_KEYS.join('|')})\\s*\\*\\s*([\\d.eE+-]+)`,
|
||||
`\\b(${BILLING_PRICING_VARS.map((v) => v.key).join('|')})\\s*\\*\\s*([\\d.eE+-]+)`,
|
||||
'g',
|
||||
);
|
||||
|
||||
export const BILLING_CONDITION_VARS = BILLING_VARS.filter(
|
||||
(v) => v.isBase || v.isConditionOnly,
|
||||
).map((v) => v.key);
|
||||
|
||||
Vendored
+1
-1
@@ -19,7 +19,7 @@ For commercial licensing, please contact support@quantumnous.com
|
||||
|
||||
export const ITEMS_PER_PAGE = 10; // this value must keep same as the one defined in backend!
|
||||
|
||||
export const DEFAULT_ENDPOINT = '/api/ratio_config';
|
||||
export const DEFAULT_ENDPOINT = '/api/pricing';
|
||||
|
||||
export const TABLE_COMPACT_MODES_KEY = 'table_compact_modes';
|
||||
|
||||
|
||||
Vendored
+5
-5
@@ -22,7 +22,7 @@ import { Modal, Tag, Typography, Avatar } from '@douyinfe/semi-ui';
|
||||
import { copy, showSuccess } from './utils';
|
||||
import { MOBILE_BREAKPOINT } from '../hooks/common/useIsMobile';
|
||||
import {
|
||||
BILLING_VARS,
|
||||
BILLING_PRICING_VARS,
|
||||
BILLING_VAR_KEY_TO_FIELD,
|
||||
BILLING_VAR_REGEX,
|
||||
} from '../constants';
|
||||
@@ -2246,7 +2246,7 @@ export function parseTiersFromExpr(exprStr) {
|
||||
if (!exprStr) return [];
|
||||
try {
|
||||
const { body } = stripExprVersion(exprStr);
|
||||
const condGroup = `((?:(?:p|c)\\s*(?:<|<=|>|>=)\\s*[\\d.eE+]+)(?:\\s*&&\\s*(?:p|c)\\s*(?:<|<=|>|>=)\\s*[\\d.eE+]+)*)`;
|
||||
const condGroup = `((?:(?:p|c|len)\\s*(?:<|<=|>|>=)\\s*[\\d.eE+]+)(?:\\s*&&\\s*(?:p|c|len)\\s*(?:<|<=|>|>=)\\s*[\\d.eE+]+)*)`;
|
||||
const tierRe = new RegExp(`(?:${condGroup}\\s*\\?\\s*)?tier\\("([^"]*)",\\s*([^)]+)\\)`, 'g');
|
||||
const tiers = [];
|
||||
let m;
|
||||
@@ -2255,7 +2255,7 @@ export function parseTiersFromExpr(exprStr) {
|
||||
const conditions = [];
|
||||
if (condStr) {
|
||||
for (const cp of condStr.split(/\s*&&\s*/)) {
|
||||
const cm = cp.trim().match(/^(p|c)\s*(<|<=|>|>=)\s*([\d.eE+]+)$/);
|
||||
const cm = cp.trim().match(/^(p|c|len)\s*(<|<=|>|>=)\s*([\d.eE+]+)$/);
|
||||
if (cm) conditions.push({ var: cm[1], op: cm[2], value: Number(cm[3]) });
|
||||
}
|
||||
}
|
||||
@@ -2293,7 +2293,7 @@ export function renderTieredModelPrice(opts) {
|
||||
const { symbol, rate } = getCurrencyConfig();
|
||||
const gr = groupRatio || 1;
|
||||
|
||||
const priceLines = BILLING_VARS.map((v) => [v.field, v.label]);
|
||||
const priceLines = BILLING_PRICING_VARS.map((v) => [v.field, v.label]);
|
||||
|
||||
const lines = [
|
||||
buildBillingText('命中档位:{{tier}}', { tier: matchedTier || tier.label }),
|
||||
@@ -2334,7 +2334,7 @@ export function renderTieredModelPriceSimple(opts) {
|
||||
];
|
||||
|
||||
if (tier && isPriceDisplayMode(displayMode)) {
|
||||
const priceSegments = BILLING_VARS.map((v) => [v.field, v.shortLabel]);
|
||||
const priceSegments = BILLING_PRICING_VARS.map((v) => [v.field, v.shortLabel]);
|
||||
for (const [field, label] of priceSegments) {
|
||||
if (tier[field] > 0) {
|
||||
segments.push({
|
||||
|
||||
Vendored
+2
-2
@@ -18,7 +18,7 @@ For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import { Toast, Pagination } from '@douyinfe/semi-ui';
|
||||
import { toastConstants, BILLING_VARS, BILLING_VAR_REGEX } from '../constants';
|
||||
import { toastConstants, BILLING_PRICING_VARS, BILLING_VAR_REGEX } from '../constants';
|
||||
import React from 'react';
|
||||
import { toast } from 'react-toastify';
|
||||
import {
|
||||
@@ -927,7 +927,7 @@ export const formatDynamicPriceSummary = (billingExpr, t, groupRatio = 1) => {
|
||||
}
|
||||
const hasCoeffs = 'p' in varCoeffs || 'c' in varCoeffs;
|
||||
|
||||
const varLabels = BILLING_VARS.map((v) => [v.key, v.label]);
|
||||
const varLabels = BILLING_PRICING_VARS.map((v) => [v.key, v.label]);
|
||||
|
||||
const hasTimeCondition = /\b(?:hour|minute|weekday|month|day)\(/.test(exprBody);
|
||||
const hasRequestCondition = /\b(?:param|header)\(/.test(exprBody);
|
||||
|
||||
@@ -29,17 +29,14 @@ import {
|
||||
Tooltip,
|
||||
Select,
|
||||
Modal,
|
||||
Spin,
|
||||
} from '@douyinfe/semi-ui';
|
||||
import { IconSearch } from '@douyinfe/semi-icons';
|
||||
import {
|
||||
RefreshCcw,
|
||||
CheckSquare,
|
||||
AlertTriangle,
|
||||
CheckCircle,
|
||||
} from 'lucide-react';
|
||||
import { RefreshCcw, CheckSquare, AlertTriangle } from 'lucide-react';
|
||||
import {
|
||||
API,
|
||||
showError,
|
||||
showInfo,
|
||||
showSuccess,
|
||||
showWarning,
|
||||
stringToColor,
|
||||
@@ -63,7 +60,7 @@ const MODELS_DEV_PRESET_NAME = 'models.dev 价格预设';
|
||||
const MODELS_DEV_PRESET_BASE_URL = 'https://models.dev';
|
||||
const MODELS_DEV_PRESET_ENDPOINT = 'https://models.dev/api.json';
|
||||
|
||||
function ConflictConfirmModal({ t, visible, items, onOk, onCancel }) {
|
||||
function ConflictConfirmModal({ t, visible, items, loading, onOk, onCancel }) {
|
||||
const isMobile = useIsMobile();
|
||||
const columns = [
|
||||
{ title: t('渠道'), dataIndex: 'channel' },
|
||||
@@ -84,7 +81,10 @@ function ConflictConfirmModal({ t, visible, items, onOk, onCancel }) {
|
||||
<Modal
|
||||
title={t('确认冲突项修改')}
|
||||
visible={visible}
|
||||
onCancel={onCancel}
|
||||
confirmLoading={loading}
|
||||
cancelButtonProps={{ disabled: loading }}
|
||||
maskClosable={!loading}
|
||||
onCancel={loading ? undefined : onCancel}
|
||||
onOk={onOk}
|
||||
size={isMobile ? 'full-width' : 'large'}
|
||||
>
|
||||
@@ -103,6 +103,7 @@ export default function UpstreamRatioSync(props) {
|
||||
const [modalVisible, setModalVisible] = useState(false);
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [syncLoading, setSyncLoading] = useState(false);
|
||||
const [confirmLoading, setConfirmLoading] = useState(false);
|
||||
const isMobile = useIsMobile();
|
||||
|
||||
// 渠道选择相关
|
||||
@@ -251,7 +252,7 @@ export default function UpstreamRatioSync(props) {
|
||||
setHasSynced(true);
|
||||
|
||||
if (Object.keys(differences).length === 0) {
|
||||
showSuccess(t('未找到差异化倍率,无需同步'));
|
||||
showSuccess(t('未找到差异化价格,无需同步'));
|
||||
}
|
||||
} catch (e) {
|
||||
showError(t('请求后端接口失败:') + e.message);
|
||||
@@ -260,32 +261,165 @@ export default function UpstreamRatioSync(props) {
|
||||
}
|
||||
};
|
||||
|
||||
const ratioSyncFields = [
|
||||
'model_ratio',
|
||||
'completion_ratio',
|
||||
'cache_ratio',
|
||||
'create_cache_ratio',
|
||||
'image_ratio',
|
||||
'audio_ratio',
|
||||
'audio_completion_ratio',
|
||||
];
|
||||
|
||||
const numericSyncFields = new Set([...ratioSyncFields, 'model_price']);
|
||||
const syncFieldOrder = [
|
||||
...ratioSyncFields,
|
||||
'model_price',
|
||||
'billing_mode',
|
||||
'billing_expr',
|
||||
];
|
||||
|
||||
function getSyncFieldLabel(ratioType) {
|
||||
const typeMap = {
|
||||
model_ratio: t('模型倍率'),
|
||||
completion_ratio: t('补全倍率'),
|
||||
cache_ratio: t('缓存倍率'),
|
||||
create_cache_ratio: t('缓存创建倍率'),
|
||||
image_ratio: t('图片倍率'),
|
||||
audio_ratio: t('音频倍率'),
|
||||
audio_completion_ratio: t('音频补全倍率'),
|
||||
model_price: t('固定价格'),
|
||||
billing_mode: t('计费模式'),
|
||||
billing_expr: t('表达式计费'),
|
||||
};
|
||||
return typeMap[ratioType] || ratioType;
|
||||
}
|
||||
|
||||
function getOrderedRatioTypes(ratioTypes) {
|
||||
const keys = Object.keys(ratioTypes || {});
|
||||
const ordered = [
|
||||
...syncFieldOrder.filter((field) => keys.includes(field)),
|
||||
...keys.filter((field) => !syncFieldOrder.includes(field)),
|
||||
];
|
||||
return ratioTypeFilter
|
||||
? ordered.filter((field) => field === ratioTypeFilter)
|
||||
: ordered;
|
||||
}
|
||||
|
||||
function deleteResolutionField(newRes, model, ratioType) {
|
||||
if (!newRes[model]) return;
|
||||
delete newRes[model][ratioType];
|
||||
if (ratioType === 'billing_expr') {
|
||||
delete newRes[model].billing_mode;
|
||||
}
|
||||
if (ratioType === 'billing_mode') {
|
||||
delete newRes[model].billing_expr;
|
||||
}
|
||||
if (Object.keys(newRes[model]).length === 0) {
|
||||
delete newRes[model];
|
||||
}
|
||||
}
|
||||
|
||||
function getBillingCategory(ratioType) {
|
||||
return ratioType === 'model_price' ? 'price' : 'ratio';
|
||||
if (ratioType === 'model_price') return 'price';
|
||||
if (ratioType === 'billing_mode' || ratioType === 'billing_expr') {
|
||||
return 'tiered';
|
||||
}
|
||||
return 'ratio';
|
||||
}
|
||||
|
||||
function optionKeyBySyncField(ratioType) {
|
||||
const explicit = {
|
||||
billing_mode: 'billing_setting.billing_mode',
|
||||
billing_expr: 'billing_setting.billing_expr',
|
||||
};
|
||||
if (explicit[ratioType]) return explicit[ratioType];
|
||||
return ratioType
|
||||
.split('_')
|
||||
.map((word) => word.charAt(0).toUpperCase() + word.slice(1))
|
||||
.join('');
|
||||
}
|
||||
|
||||
function getUpstreamValue(model, ratioType, sourceName) {
|
||||
return differences[model]?.[ratioType]?.upstreams?.[sourceName];
|
||||
}
|
||||
|
||||
function isSelectableUpstreamValue(value) {
|
||||
return value !== null && value !== undefined && value !== 'same';
|
||||
}
|
||||
|
||||
function getPreferredSyncField(model, ratioType, sourceName) {
|
||||
const exprValue = getUpstreamValue(model, 'billing_expr', sourceName);
|
||||
if (ratioType !== 'billing_expr' && isSelectableUpstreamValue(exprValue)) {
|
||||
return 'billing_expr';
|
||||
}
|
||||
return ratioType;
|
||||
}
|
||||
|
||||
function shouldShowSyncField(model, ratioType, sourceName) {
|
||||
if (!sourceName) return true;
|
||||
return getPreferredSyncField(model, ratioType, sourceName) === ratioType;
|
||||
}
|
||||
|
||||
const selectValue = useCallback(
|
||||
(model, ratioType, value) => {
|
||||
(model, ratioType, value, sourceName) => {
|
||||
const preferredRatioType = sourceName
|
||||
? getPreferredSyncField(model, ratioType, sourceName)
|
||||
: ratioType;
|
||||
const preferredValue =
|
||||
preferredRatioType === ratioType
|
||||
? value
|
||||
: getUpstreamValue(model, preferredRatioType, sourceName);
|
||||
ratioType = preferredRatioType;
|
||||
value = preferredValue;
|
||||
|
||||
const category = getBillingCategory(ratioType);
|
||||
|
||||
setResolutions((prev) => {
|
||||
const newModelRes = { ...(prev[model] || {}) };
|
||||
|
||||
Object.keys(newModelRes).forEach((rt) => {
|
||||
if (getBillingCategory(rt) !== category) {
|
||||
if (
|
||||
category !== 'tiered' &&
|
||||
getBillingCategory(rt) !== 'tiered' &&
|
||||
getBillingCategory(rt) !== category
|
||||
) {
|
||||
delete newModelRes[rt];
|
||||
}
|
||||
});
|
||||
|
||||
newModelRes[ratioType] = value;
|
||||
|
||||
if (category === 'tiered' && sourceName) {
|
||||
const modeValue =
|
||||
differences[model]?.billing_mode?.upstreams?.[sourceName];
|
||||
const exprValue =
|
||||
differences[model]?.billing_expr?.upstreams?.[sourceName];
|
||||
if (
|
||||
modeValue !== undefined &&
|
||||
modeValue !== null &&
|
||||
modeValue !== 'same'
|
||||
) {
|
||||
newModelRes.billing_mode = modeValue;
|
||||
} else if (ratioType === 'billing_expr') {
|
||||
newModelRes.billing_mode = 'tiered_expr';
|
||||
}
|
||||
if (
|
||||
exprValue !== undefined &&
|
||||
exprValue !== null &&
|
||||
exprValue !== 'same'
|
||||
) {
|
||||
newModelRes.billing_expr = exprValue;
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
...prev,
|
||||
[model]: newModelRes,
|
||||
};
|
||||
});
|
||||
},
|
||||
[setResolutions],
|
||||
[setResolutions, differences],
|
||||
);
|
||||
|
||||
const applySync = async () => {
|
||||
@@ -293,7 +427,19 @@ export default function UpstreamRatioSync(props) {
|
||||
ModelRatio: JSON.parse(props.options.ModelRatio || '{}'),
|
||||
CompletionRatio: JSON.parse(props.options.CompletionRatio || '{}'),
|
||||
CacheRatio: JSON.parse(props.options.CacheRatio || '{}'),
|
||||
CreateCacheRatio: JSON.parse(props.options.CreateCacheRatio || '{}'),
|
||||
ImageRatio: JSON.parse(props.options.ImageRatio || '{}'),
|
||||
AudioRatio: JSON.parse(props.options.AudioRatio || '{}'),
|
||||
AudioCompletionRatio: JSON.parse(
|
||||
props.options.AudioCompletionRatio || '{}',
|
||||
),
|
||||
ModelPrice: JSON.parse(props.options.ModelPrice || '{}'),
|
||||
'billing_setting.billing_mode': JSON.parse(
|
||||
props.options['billing_setting.billing_mode'] || '{}',
|
||||
),
|
||||
'billing_setting.billing_expr': JSON.parse(
|
||||
props.options['billing_setting.billing_expr'] || '{}',
|
||||
),
|
||||
};
|
||||
|
||||
const conflicts = [];
|
||||
@@ -303,7 +449,11 @@ export default function UpstreamRatioSync(props) {
|
||||
if (
|
||||
currentRatios.ModelRatio[model] !== undefined ||
|
||||
currentRatios.CompletionRatio[model] !== undefined ||
|
||||
currentRatios.CacheRatio[model] !== undefined
|
||||
currentRatios.CacheRatio[model] !== undefined ||
|
||||
currentRatios.CreateCacheRatio[model] !== undefined ||
|
||||
currentRatios.ImageRatio[model] !== undefined ||
|
||||
currentRatios.AudioRatio[model] !== undefined ||
|
||||
currentRatios.AudioCompletionRatio[model] !== undefined
|
||||
)
|
||||
return 'ratio';
|
||||
return null;
|
||||
@@ -320,9 +470,14 @@ export default function UpstreamRatioSync(props) {
|
||||
|
||||
Object.entries(resolutions).forEach(([model, ratios]) => {
|
||||
const localCat = getLocalBillingCategory(model);
|
||||
const newCat = 'model_price' in ratios ? 'price' : 'ratio';
|
||||
const newCat =
|
||||
'model_price' in ratios
|
||||
? 'price'
|
||||
: ratioSyncFields.some((rt) => rt in ratios)
|
||||
? 'ratio'
|
||||
: 'tiered';
|
||||
|
||||
if (localCat && localCat !== newCat) {
|
||||
if (localCat && newCat !== 'tiered' && localCat !== newCat) {
|
||||
const currentDesc =
|
||||
localCat === 'price'
|
||||
? `${t('固定价格')} : ${currentRatios.ModelPrice[model]}`
|
||||
@@ -366,33 +521,50 @@ export default function UpstreamRatioSync(props) {
|
||||
ModelRatio: { ...currentRatios.ModelRatio },
|
||||
CompletionRatio: { ...currentRatios.CompletionRatio },
|
||||
CacheRatio: { ...currentRatios.CacheRatio },
|
||||
CreateCacheRatio: { ...currentRatios.CreateCacheRatio },
|
||||
ImageRatio: { ...currentRatios.ImageRatio },
|
||||
AudioRatio: { ...currentRatios.AudioRatio },
|
||||
AudioCompletionRatio: { ...currentRatios.AudioCompletionRatio },
|
||||
ModelPrice: { ...currentRatios.ModelPrice },
|
||||
'billing_setting.billing_mode': {
|
||||
...currentRatios['billing_setting.billing_mode'],
|
||||
},
|
||||
'billing_setting.billing_expr': {
|
||||
...currentRatios['billing_setting.billing_expr'],
|
||||
},
|
||||
};
|
||||
|
||||
Object.entries(resolutions).forEach(([model, ratios]) => {
|
||||
const selectedTypes = Object.keys(ratios);
|
||||
const hasPrice = selectedTypes.includes('model_price');
|
||||
const hasRatio = selectedTypes.some((rt) => rt !== 'model_price');
|
||||
const hasRatio = selectedTypes.some((rt) =>
|
||||
ratioSyncFields.includes(rt),
|
||||
);
|
||||
|
||||
if (hasPrice) {
|
||||
delete finalRatios.ModelRatio[model];
|
||||
delete finalRatios.CompletionRatio[model];
|
||||
delete finalRatios.CacheRatio[model];
|
||||
delete finalRatios.CreateCacheRatio[model];
|
||||
delete finalRatios.ImageRatio[model];
|
||||
delete finalRatios.AudioRatio[model];
|
||||
delete finalRatios.AudioCompletionRatio[model];
|
||||
}
|
||||
if (hasRatio) {
|
||||
delete finalRatios.ModelPrice[model];
|
||||
}
|
||||
|
||||
Object.entries(ratios).forEach(([ratioType, value]) => {
|
||||
const optionKey = ratioType
|
||||
.split('_')
|
||||
.map((word) => word.charAt(0).toUpperCase() + word.slice(1))
|
||||
.join('');
|
||||
finalRatios[optionKey][model] = parseFloat(value);
|
||||
const optionKey = optionKeyBySyncField(ratioType);
|
||||
finalRatios[optionKey][model] = numericSyncFields.has(ratioType)
|
||||
? parseFloat(value)
|
||||
: value;
|
||||
});
|
||||
});
|
||||
|
||||
setLoading(true);
|
||||
showInfo(t('正在同步价格,请稍候'));
|
||||
let success = false;
|
||||
try {
|
||||
const updates = Object.entries(finalRatios).map(([key, value]) =>
|
||||
API.put('/api/option/', {
|
||||
@@ -426,6 +598,7 @@ export default function UpstreamRatioSync(props) {
|
||||
});
|
||||
|
||||
setResolutions({});
|
||||
success = true;
|
||||
} else {
|
||||
showError(t('部分保存失败'));
|
||||
}
|
||||
@@ -434,6 +607,7 @@ export default function UpstreamRatioSync(props) {
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
return success;
|
||||
},
|
||||
[resolutions, props.options, props.refresh],
|
||||
);
|
||||
@@ -451,6 +625,7 @@ export default function UpstreamRatioSync(props) {
|
||||
<Button
|
||||
icon={<RefreshCcw size={14} />}
|
||||
className='w-full md:w-auto mt-2'
|
||||
disabled={loading || syncLoading || confirmLoading}
|
||||
onClick={() => {
|
||||
setModalVisible(true);
|
||||
if (allChannels.length === 0) {
|
||||
@@ -469,7 +644,10 @@ export default function UpstreamRatioSync(props) {
|
||||
icon={<CheckSquare size={14} />}
|
||||
type='secondary'
|
||||
onClick={applySync}
|
||||
disabled={!hasSelections}
|
||||
loading={loading || confirmLoading}
|
||||
disabled={
|
||||
!hasSelections || loading || syncLoading || confirmLoading
|
||||
}
|
||||
className='w-full md:w-auto mt-2'
|
||||
>
|
||||
{t('应用同步')}
|
||||
@@ -484,14 +662,16 @@ export default function UpstreamRatioSync(props) {
|
||||
value={searchKeyword}
|
||||
onChange={setSearchKeyword}
|
||||
className='w-full sm:w-64'
|
||||
disabled={loading || syncLoading || confirmLoading}
|
||||
showClear
|
||||
/>
|
||||
|
||||
<Select
|
||||
placeholder={t('按倍率类型筛选')}
|
||||
placeholder={t('按价格字段筛选')}
|
||||
value={ratioTypeFilter}
|
||||
onChange={setRatioTypeFilter}
|
||||
className='w-full sm:w-48'
|
||||
disabled={loading || syncLoading || confirmLoading}
|
||||
showClear
|
||||
onClear={() => setRatioTypeFilter('')}
|
||||
>
|
||||
@@ -500,7 +680,18 @@ export default function UpstreamRatioSync(props) {
|
||||
{t('补全倍率')}
|
||||
</Select.Option>
|
||||
<Select.Option value='cache_ratio'>{t('缓存倍率')}</Select.Option>
|
||||
<Select.Option value='create_cache_ratio'>
|
||||
{t('缓存创建倍率')}
|
||||
</Select.Option>
|
||||
<Select.Option value='image_ratio'>{t('图片倍率')}</Select.Option>
|
||||
<Select.Option value='audio_ratio'>{t('音频倍率')}</Select.Option>
|
||||
<Select.Option value='audio_completion_ratio'>
|
||||
{t('音频补全倍率')}
|
||||
</Select.Option>
|
||||
<Select.Option value='model_price'>{t('固定价格')}</Select.Option>
|
||||
<Select.Option value='billing_expr'>
|
||||
{t('表达式计费')}
|
||||
</Select.Option>
|
||||
</Select>
|
||||
</div>
|
||||
</div>
|
||||
@@ -510,31 +701,17 @@ export default function UpstreamRatioSync(props) {
|
||||
|
||||
const renderDifferenceTable = () => {
|
||||
const dataSource = useMemo(() => {
|
||||
const tmp = [];
|
||||
|
||||
Object.entries(differences).forEach(([model, ratioTypes]) => {
|
||||
return Object.entries(differences).map(([model, ratioTypes]) => {
|
||||
const hasPrice = 'model_price' in ratioTypes;
|
||||
const hasOtherRatio = [
|
||||
'model_ratio',
|
||||
'completion_ratio',
|
||||
'cache_ratio',
|
||||
].some((rt) => rt in ratioTypes);
|
||||
const billingConflict = hasPrice && hasOtherRatio;
|
||||
const hasOtherRatio = ratioSyncFields.some((rt) => rt in ratioTypes);
|
||||
|
||||
Object.entries(ratioTypes).forEach(([ratioType, diff]) => {
|
||||
tmp.push({
|
||||
key: `${model}_${ratioType}`,
|
||||
model,
|
||||
ratioType,
|
||||
current: diff.current,
|
||||
upstreams: diff.upstreams,
|
||||
confidence: diff.confidence || {},
|
||||
billingConflict,
|
||||
});
|
||||
});
|
||||
return {
|
||||
key: model,
|
||||
model,
|
||||
ratioTypes,
|
||||
billingConflict: hasPrice && hasOtherRatio,
|
||||
};
|
||||
});
|
||||
|
||||
return tmp;
|
||||
}, [differences]);
|
||||
|
||||
const filteredDataSource = useMemo(() => {
|
||||
@@ -548,7 +725,7 @@ export default function UpstreamRatioSync(props) {
|
||||
item.model.toLowerCase().includes(searchKeyword.toLowerCase().trim());
|
||||
|
||||
const matchesRatioType =
|
||||
!ratioTypeFilter || item.ratioType === ratioTypeFilter;
|
||||
!ratioTypeFilter || ratioTypeFilter in item.ratioTypes;
|
||||
|
||||
return matchesKeyword && matchesRatioType;
|
||||
});
|
||||
@@ -557,12 +734,162 @@ export default function UpstreamRatioSync(props) {
|
||||
const upstreamNames = useMemo(() => {
|
||||
const set = new Set();
|
||||
filteredDataSource.forEach((row) => {
|
||||
Object.keys(row.upstreams || {}).forEach((name) => set.add(name));
|
||||
getOrderedRatioTypes(row.ratioTypes).forEach((ratioType) => {
|
||||
Object.keys(row.ratioTypes[ratioType]?.upstreams || {}).forEach(
|
||||
(name) => set.add(name),
|
||||
);
|
||||
});
|
||||
});
|
||||
return Array.from(set);
|
||||
}, [filteredDataSource]);
|
||||
}, [filteredDataSource, ratioTypeFilter]);
|
||||
|
||||
const renderValueTag = (value, color = 'default') => {
|
||||
if (value === null || value === undefined) {
|
||||
return (
|
||||
<Tag color='default' shape='circle'>
|
||||
{t('未设置')}
|
||||
</Tag>
|
||||
);
|
||||
}
|
||||
|
||||
const text = String(value);
|
||||
return (
|
||||
<Tooltip content={text}>
|
||||
<Tag color={color} shape='circle'>
|
||||
<span className='inline-block max-w-[360px] truncate align-bottom'>
|
||||
{text}
|
||||
</span>
|
||||
</Tag>
|
||||
</Tooltip>
|
||||
);
|
||||
};
|
||||
|
||||
const renderCurrentFields = (record) => {
|
||||
const fields = getOrderedRatioTypes(record.ratioTypes);
|
||||
return (
|
||||
<div className='flex min-w-[260px] flex-col gap-2'>
|
||||
{fields.map((ratioType) => (
|
||||
<div
|
||||
key={ratioType}
|
||||
className='flex min-w-0 flex-wrap items-center gap-2'
|
||||
>
|
||||
<Tag color={stringToColor(ratioType)} shape='circle'>
|
||||
{getSyncFieldLabel(ratioType)}
|
||||
</Tag>
|
||||
{renderValueTag(record.ratioTypes[ratioType]?.current, 'blue')}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
const renderUpstreamField = (record, ratioType, upName) => {
|
||||
const diff = record.ratioTypes[ratioType] || {};
|
||||
const upstreamVal = diff.upstreams?.[upName];
|
||||
const isConfident = diff.confidence?.[upName] !== false;
|
||||
const isPreferredField =
|
||||
getPreferredSyncField(record.model, ratioType, upName) === ratioType;
|
||||
|
||||
if (upstreamVal === null || upstreamVal === undefined) {
|
||||
return renderValueTag(undefined);
|
||||
}
|
||||
|
||||
if (upstreamVal === 'same') {
|
||||
return (
|
||||
<Tag color='blue' shape='circle'>
|
||||
{t('与本地相同')}
|
||||
</Tag>
|
||||
);
|
||||
}
|
||||
|
||||
const text = String(upstreamVal);
|
||||
const isSelected =
|
||||
isPreferredField &&
|
||||
resolutions[record.model]?.[ratioType] === upstreamVal;
|
||||
const valueNode = isPreferredField ? (
|
||||
<Checkbox
|
||||
checked={isSelected}
|
||||
disabled={loading || syncLoading || confirmLoading}
|
||||
onChange={(e) => {
|
||||
const isChecked = e.target.checked;
|
||||
if (isChecked) {
|
||||
selectValue(record.model, ratioType, upstreamVal, upName);
|
||||
} else {
|
||||
setResolutions((prev) => {
|
||||
const newRes = { ...prev };
|
||||
deleteResolutionField(newRes, record.model, ratioType);
|
||||
return newRes;
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
<Tooltip content={text}>
|
||||
<span className='inline-block max-w-[360px] truncate align-bottom'>
|
||||
{text}
|
||||
</span>
|
||||
</Tooltip>
|
||||
</Checkbox>
|
||||
) : (
|
||||
<Tooltip content={text}>
|
||||
<Tag color='default' shape='circle' type='light'>
|
||||
<span className='inline-block max-w-[360px] truncate align-bottom'>
|
||||
{text}
|
||||
</span>
|
||||
</Tag>
|
||||
</Tooltip>
|
||||
);
|
||||
|
||||
return (
|
||||
<div className='flex min-w-0 items-center gap-2'>
|
||||
{valueNode}
|
||||
{!isConfident && (
|
||||
<Tooltip
|
||||
position='left'
|
||||
content={t('该数据可能不可信,请谨慎使用')}
|
||||
>
|
||||
<AlertTriangle size={16} className='shrink-0 text-yellow-500' />
|
||||
</Tooltip>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
const renderUpstreamFields = (record, upName) => {
|
||||
const fields = getOrderedRatioTypes(record.ratioTypes).filter(
|
||||
(ratioType) => shouldShowSyncField(record.model, ratioType, upName),
|
||||
);
|
||||
return (
|
||||
<div className='flex min-w-[280px] flex-col gap-2'>
|
||||
{fields.map((ratioType) => (
|
||||
<div key={ratioType} className='flex min-w-0 items-start gap-2'>
|
||||
<Tag
|
||||
color={stringToColor(ratioType)}
|
||||
shape='circle'
|
||||
className='shrink-0'
|
||||
>
|
||||
{getSyncFieldLabel(ratioType)}
|
||||
</Tag>
|
||||
<div className='min-w-0 flex-1'>
|
||||
{renderUpstreamField(record, ratioType, upName)}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
if (filteredDataSource.length === 0) {
|
||||
if (syncLoading) {
|
||||
return (
|
||||
<div className='flex min-h-[260px] flex-col items-center justify-center gap-3'>
|
||||
<Spin size='large' />
|
||||
<div className='text-sm text-gray-500'>
|
||||
{t('正在同步上游价格,请稍候')}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<Empty
|
||||
image={<IllustrationNoResult style={{ width: 150, height: 150 }} />}
|
||||
@@ -574,7 +901,7 @@ export default function UpstreamRatioSync(props) {
|
||||
? t('未找到匹配的模型')
|
||||
: Object.keys(differences).length === 0
|
||||
? hasSynced
|
||||
? t('暂无差异化倍率显示')
|
||||
? t('暂无差异化价格显示')
|
||||
: t('请先选择同步渠道')
|
||||
: t('请先选择同步渠道')
|
||||
}
|
||||
@@ -588,95 +915,24 @@ export default function UpstreamRatioSync(props) {
|
||||
title: t('模型'),
|
||||
dataIndex: 'model',
|
||||
fixed: 'left',
|
||||
},
|
||||
{
|
||||
title: t('倍率类型'),
|
||||
dataIndex: 'ratioType',
|
||||
render: (text, record) => {
|
||||
const typeMap = {
|
||||
model_ratio: t('模型倍率'),
|
||||
completion_ratio: t('补全倍率'),
|
||||
cache_ratio: t('缓存倍率'),
|
||||
model_price: t('固定价格'),
|
||||
};
|
||||
const baseTag = (
|
||||
<Tag color={stringToColor(text)} shape='circle'>
|
||||
{typeMap[text] || text}
|
||||
</Tag>
|
||||
);
|
||||
if (record?.billingConflict) {
|
||||
return (
|
||||
<div className='flex items-center gap-1'>
|
||||
{baseTag}
|
||||
<Tooltip
|
||||
position='top'
|
||||
content={t(
|
||||
'该模型存在固定价格与倍率计费方式冲突,请确认选择',
|
||||
)}
|
||||
>
|
||||
<AlertTriangle size={14} className='text-yellow-500' />
|
||||
</Tooltip>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
return baseTag;
|
||||
},
|
||||
},
|
||||
{
|
||||
title: t('置信度'),
|
||||
dataIndex: 'confidence',
|
||||
render: (_, record) => {
|
||||
const allConfident = Object.values(record.confidence || {}).every(
|
||||
(v) => v !== false,
|
||||
);
|
||||
|
||||
if (allConfident) {
|
||||
return (
|
||||
<Tooltip content={t('所有上游数据均可信')}>
|
||||
<Tag
|
||||
color='green'
|
||||
shape='circle'
|
||||
type='light'
|
||||
prefixIcon={<CheckCircle size={14} />}
|
||||
>
|
||||
{t('可信')}
|
||||
</Tag>
|
||||
</Tooltip>
|
||||
);
|
||||
} else {
|
||||
const untrustedSources = Object.entries(record.confidence || {})
|
||||
.filter(([_, isConfident]) => isConfident === false)
|
||||
.map(([name]) => name)
|
||||
.join(', ');
|
||||
|
||||
return (
|
||||
render: (text, record) => (
|
||||
<div className='flex min-w-[180px] items-center gap-2'>
|
||||
<span className='font-medium'>{text}</span>
|
||||
{record.billingConflict && (
|
||||
<Tooltip
|
||||
content={t('以下上游数据可能不可信:') + untrustedSources}
|
||||
position='top'
|
||||
content={t('该模型存在固定价格与倍率计费方式冲突,请确认选择')}
|
||||
>
|
||||
<Tag
|
||||
color='yellow'
|
||||
shape='circle'
|
||||
type='light'
|
||||
prefixIcon={<AlertTriangle size={14} />}
|
||||
>
|
||||
{t('谨慎')}
|
||||
</Tag>
|
||||
<AlertTriangle size={14} className='shrink-0 text-yellow-500' />
|
||||
</Tooltip>
|
||||
);
|
||||
}
|
||||
},
|
||||
)}
|
||||
</div>
|
||||
),
|
||||
},
|
||||
{
|
||||
title: t('当前值'),
|
||||
title: t('当前价格'),
|
||||
dataIndex: 'current',
|
||||
render: (text) => (
|
||||
<Tag
|
||||
color={text !== null && text !== undefined ? 'blue' : 'default'}
|
||||
shape='circle'
|
||||
>
|
||||
{text !== null && text !== undefined ? String(text) : t('未设置')}
|
||||
</Tag>
|
||||
),
|
||||
render: (_, record) => renderCurrentFields(record),
|
||||
},
|
||||
...upstreamNames.map((upName) => {
|
||||
const channelStats = (() => {
|
||||
@@ -684,19 +940,20 @@ export default function UpstreamRatioSync(props) {
|
||||
let selectedCount = 0;
|
||||
|
||||
filteredDataSource.forEach((row) => {
|
||||
const upstreamVal = row.upstreams?.[upName];
|
||||
if (
|
||||
upstreamVal !== null &&
|
||||
upstreamVal !== undefined &&
|
||||
upstreamVal !== 'same'
|
||||
) {
|
||||
selectableCount++;
|
||||
const isSelected =
|
||||
resolutions[row.model]?.[row.ratioType] === upstreamVal;
|
||||
if (isSelected) {
|
||||
selectedCount++;
|
||||
getOrderedRatioTypes(row.ratioTypes).forEach((ratioType) => {
|
||||
const upstreamVal =
|
||||
row.ratioTypes[ratioType]?.upstreams?.[upName];
|
||||
if (
|
||||
getPreferredSyncField(row.model, ratioType, upName) ===
|
||||
ratioType &&
|
||||
isSelectableUpstreamValue(upstreamVal)
|
||||
) {
|
||||
selectableCount++;
|
||||
if (resolutions[row.model]?.[ratioType] === upstreamVal) {
|
||||
selectedCount++;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
return {
|
||||
@@ -713,25 +970,29 @@ export default function UpstreamRatioSync(props) {
|
||||
const handleBulkSelect = (checked) => {
|
||||
if (checked) {
|
||||
filteredDataSource.forEach((row) => {
|
||||
const upstreamVal = row.upstreams?.[upName];
|
||||
if (
|
||||
upstreamVal !== null &&
|
||||
upstreamVal !== undefined &&
|
||||
upstreamVal !== 'same'
|
||||
) {
|
||||
selectValue(row.model, row.ratioType, upstreamVal);
|
||||
}
|
||||
getOrderedRatioTypes(row.ratioTypes).forEach((ratioType) => {
|
||||
const upstreamVal =
|
||||
row.ratioTypes[ratioType]?.upstreams?.[upName];
|
||||
if (
|
||||
getPreferredSyncField(row.model, ratioType, upName) ===
|
||||
ratioType &&
|
||||
isSelectableUpstreamValue(upstreamVal)
|
||||
) {
|
||||
selectValue(row.model, ratioType, upstreamVal, upName);
|
||||
}
|
||||
});
|
||||
});
|
||||
} else {
|
||||
setResolutions((prev) => {
|
||||
const newRes = { ...prev };
|
||||
filteredDataSource.forEach((row) => {
|
||||
if (newRes[row.model]) {
|
||||
delete newRes[row.model][row.ratioType];
|
||||
if (Object.keys(newRes[row.model]).length === 0) {
|
||||
delete newRes[row.model];
|
||||
getOrderedRatioTypes(row.ratioTypes).forEach((ratioType) => {
|
||||
if (
|
||||
row.ratioTypes[ratioType]?.upstreams?.[upName] !== undefined
|
||||
) {
|
||||
deleteResolutionField(newRes, row.model, ratioType);
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
return newRes;
|
||||
});
|
||||
@@ -743,6 +1004,7 @@ export default function UpstreamRatioSync(props) {
|
||||
<Checkbox
|
||||
checked={channelStats.allSelected}
|
||||
indeterminate={channelStats.partiallySelected}
|
||||
disabled={loading || syncLoading || confirmLoading}
|
||||
onChange={(e) => handleBulkSelect(e.target.checked)}
|
||||
>
|
||||
{upName}
|
||||
@@ -751,64 +1013,7 @@ export default function UpstreamRatioSync(props) {
|
||||
<span>{upName}</span>
|
||||
),
|
||||
dataIndex: upName,
|
||||
render: (_, record) => {
|
||||
const upstreamVal = record.upstreams?.[upName];
|
||||
const isConfident = record.confidence?.[upName] !== false;
|
||||
|
||||
if (upstreamVal === null || upstreamVal === undefined) {
|
||||
return (
|
||||
<Tag color='default' shape='circle'>
|
||||
{t('未设置')}
|
||||
</Tag>
|
||||
);
|
||||
}
|
||||
|
||||
if (upstreamVal === 'same') {
|
||||
return (
|
||||
<Tag color='blue' shape='circle'>
|
||||
{t('与本地相同')}
|
||||
</Tag>
|
||||
);
|
||||
}
|
||||
|
||||
const isSelected =
|
||||
resolutions[record.model]?.[record.ratioType] === upstreamVal;
|
||||
|
||||
return (
|
||||
<div className='flex items-center gap-2'>
|
||||
<Checkbox
|
||||
checked={isSelected}
|
||||
onChange={(e) => {
|
||||
const isChecked = e.target.checked;
|
||||
if (isChecked) {
|
||||
selectValue(record.model, record.ratioType, upstreamVal);
|
||||
} else {
|
||||
setResolutions((prev) => {
|
||||
const newRes = { ...prev };
|
||||
if (newRes[record.model]) {
|
||||
delete newRes[record.model][record.ratioType];
|
||||
if (Object.keys(newRes[record.model]).length === 0) {
|
||||
delete newRes[record.model];
|
||||
}
|
||||
}
|
||||
return newRes;
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
{String(upstreamVal)}
|
||||
</Checkbox>
|
||||
{!isConfident && (
|
||||
<Tooltip
|
||||
position='left'
|
||||
content={t('该数据可能不可信,请谨慎使用')}
|
||||
>
|
||||
<AlertTriangle size={16} className='text-yellow-500' />
|
||||
</Tooltip>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
},
|
||||
render: (_, record) => renderUpstreamFields(record, upName),
|
||||
};
|
||||
}),
|
||||
];
|
||||
@@ -874,15 +1079,37 @@ export default function UpstreamRatioSync(props) {
|
||||
t={t}
|
||||
visible={confirmVisible}
|
||||
items={conflictItems}
|
||||
loading={confirmLoading}
|
||||
onOk={async () => {
|
||||
setConfirmVisible(false);
|
||||
setConfirmLoading(true);
|
||||
const curRatios = {
|
||||
ModelRatio: JSON.parse(props.options.ModelRatio || '{}'),
|
||||
CompletionRatio: JSON.parse(props.options.CompletionRatio || '{}'),
|
||||
CacheRatio: JSON.parse(props.options.CacheRatio || '{}'),
|
||||
CreateCacheRatio: JSON.parse(
|
||||
props.options.CreateCacheRatio || '{}',
|
||||
),
|
||||
ImageRatio: JSON.parse(props.options.ImageRatio || '{}'),
|
||||
AudioRatio: JSON.parse(props.options.AudioRatio || '{}'),
|
||||
AudioCompletionRatio: JSON.parse(
|
||||
props.options.AudioCompletionRatio || '{}',
|
||||
),
|
||||
ModelPrice: JSON.parse(props.options.ModelPrice || '{}'),
|
||||
'billing_setting.billing_mode': JSON.parse(
|
||||
props.options['billing_setting.billing_mode'] || '{}',
|
||||
),
|
||||
'billing_setting.billing_expr': JSON.parse(
|
||||
props.options['billing_setting.billing_expr'] || '{}',
|
||||
),
|
||||
};
|
||||
await performSync(curRatios);
|
||||
try {
|
||||
const success = await performSync(curRatios);
|
||||
if (success) {
|
||||
setConfirmVisible(false);
|
||||
}
|
||||
} finally {
|
||||
setConfirmLoading(false);
|
||||
}
|
||||
}}
|
||||
onCancel={() => setConfirmVisible(false)}
|
||||
/>
|
||||
|
||||
@@ -31,9 +31,10 @@ import {
|
||||
TextArea,
|
||||
Typography,
|
||||
} from '@douyinfe/semi-ui';
|
||||
import { IconDelete, IconPlus } from '@douyinfe/semi-icons';
|
||||
import { IconCopy, IconDelete, IconPlus } from '@douyinfe/semi-icons';
|
||||
import { renderQuota } from '../../../../helpers/render';
|
||||
import { BILLING_EXTRA_VARS, BILLING_CACHE_VAR_MAP } from '../../../../constants';
|
||||
import { copy, showSuccess } from '../../../../helpers';
|
||||
import { BILLING_EXTRA_VARS, BILLING_CACHE_VAR_MAP, BILLING_CONDITION_VARS } from '../../../../constants';
|
||||
import {
|
||||
createEmptyCondition,
|
||||
createEmptyTimeCondition,
|
||||
@@ -70,6 +71,7 @@ function priceToUnitCost(price) {
|
||||
|
||||
const OPS = ['<', '<=', '>', '>='];
|
||||
const VAR_OPTIONS = [
|
||||
{ value: 'len', label: 'len (长度)' },
|
||||
{ value: 'p', label: 'p (输入)' },
|
||||
{ value: 'c', label: 'c (输出)' },
|
||||
];
|
||||
@@ -224,7 +226,7 @@ function tryParseVisualConfig(exprStr) {
|
||||
}
|
||||
|
||||
// Multi-tier: cond1 ? tier(body) : cond2 ? tier(body) : tier(body)
|
||||
const condGroup = `((?:(?:p|c)\\s*(?:<|<=|>|>=)\\s*[\\d.eE+]+)(?:\\s*&&\\s*(?:p|c)\\s*(?:<|<=|>|>=)\\s*[\\d.eE+]+)*)`;
|
||||
const condGroup = `((?:(?:p|c|len)\\s*(?:<|<=|>|>=)\\s*[\\d.eE+]+)(?:\\s*&&\\s*(?:p|c|len)\\s*(?:<|<=|>|>=)\\s*[\\d.eE+]+)*)`;
|
||||
const tierRe = new RegExp(
|
||||
`(?:${condGroup}\\s*\\?\\s*)?tier\\("([^"]*)",\\s*${bodyPat}\\)`,
|
||||
'g',
|
||||
@@ -237,7 +239,7 @@ function tryParseVisualConfig(exprStr) {
|
||||
if (condStr) {
|
||||
const condParts = condStr.split(/\s*&&\s*/);
|
||||
for (const cp of condParts) {
|
||||
const cm = cp.trim().match(/^(p|c)\s*(<|<=|>|>=)\s*([\d.eE+]+)$/);
|
||||
const cm = cp.trim().match(/^(p|c|len)\s*(<|<=|>|>=)\s*([\d.eE+]+)$/);
|
||||
if (cm) {
|
||||
conditions.push({ var: cm[1], op: cm[2], value: Number(cm[3]) });
|
||||
}
|
||||
@@ -283,7 +285,7 @@ function ConditionRow({ cond, onChange, onRemove, t }) {
|
||||
}}>
|
||||
<Select
|
||||
size='small'
|
||||
value={cond.var || 'p'}
|
||||
value={cond.var || 'len'}
|
||||
onChange={(val) => onChange({ ...cond, var: val })}
|
||||
>
|
||||
{VAR_OPTIONS.map((v) => (
|
||||
@@ -500,7 +502,7 @@ function ExtendedPriceBlock({ tier, index, onUpdate, t }) {
|
||||
function VisualTierCard({ tier, index, isLast, isOnly, onUpdate, onRemove, t }) {
|
||||
const conditions = tier.conditions || [];
|
||||
|
||||
const varLabel = { p: t('输入'), c: t('输出') };
|
||||
const varLabel = { len: t('长度'), p: t('输入'), c: t('输出') };
|
||||
const condSummary = useMemo(() => {
|
||||
if (conditions.length === 0) return t('无条件(兜底档)');
|
||||
return conditions
|
||||
@@ -525,7 +527,7 @@ function VisualTierCard({ tier, index, isLast, isOnly, onUpdate, onRemove, t })
|
||||
const addCondition = () => {
|
||||
if (conditions.length >= 2) return;
|
||||
const usedVars = conditions.map((c) => c.var);
|
||||
const nextVar = usedVars.includes('p') ? 'c' : 'p';
|
||||
const nextVar = usedVars.includes('len') ? 'c' : 'len';
|
||||
onUpdate(index, 'conditions', [
|
||||
...conditions,
|
||||
{ var: nextVar, op: '<', value: 200000 },
|
||||
@@ -694,7 +696,7 @@ function VisualEditor({ visualConfig, onChange, t }) {
|
||||
) {
|
||||
newTiers[newTiers.length - 1] = {
|
||||
...newTiers[newTiers.length - 1],
|
||||
conditions: [{ var: 'p', op: '<', value: 200000 }],
|
||||
conditions: [{ var: 'len', op: '<', value: 200000 }],
|
||||
};
|
||||
}
|
||||
newTiers.push({
|
||||
@@ -723,7 +725,7 @@ function VisualEditor({ visualConfig, onChange, t }) {
|
||||
<div>
|
||||
<Banner
|
||||
type='info'
|
||||
description={t('每个档位可设置 0~2 个条件(对 p 和 c),最后一档为兜底档无需条件。')}
|
||||
description={t('每个档位可设置 0~2 个条件(对 len、p 和 c),最后一档为兜底档无需条件。len 为输入上下文总长度(含缓存),推荐用于阶梯条件。')}
|
||||
style={{ marginBottom: 12 }}
|
||||
/>
|
||||
|
||||
@@ -762,16 +764,16 @@ const PRESET_GROUPS = [
|
||||
presets: [
|
||||
{ key: 'flat', label: 'Flat', expr: 'tier("base", p * 2 + c * 4)' },
|
||||
{ key: 'claude-opus', label: 'Claude Opus 4.6', expr: 'tier("base", p * 5 + c * 25 + cr * 0.5 + cc * 6.25 + cc1h * 10)' },
|
||||
{ key: 'gpt-5.4', label: 'GPT-5.4', expr: 'p <= 272000 ? tier("standard", p * 2.5 + c * 15 + cr * 0.25) : tier("long_context", p * 5 + c * 22.5 + cr * 0.5)' },
|
||||
{ key: 'gpt-5.4', label: 'GPT-5.4', expr: 'len <= 272000 ? tier("standard", p * 2.5 + c * 15 + cr * 0.25) : tier("long_context", p * 5 + c * 22.5 + cr * 0.5)' },
|
||||
],
|
||||
},
|
||||
{
|
||||
group: '阶梯计费',
|
||||
presets: [
|
||||
{ key: 'claude-sonnet', label: 'Claude Sonnet 4.5', expr: 'p <= 200000 ? tier("standard", p * 3 + c * 15 + cr * 0.3 + cc * 3.75 + cc1h * 6) : tier("long_context", p * 6 + c * 22.5 + cr * 0.6 + cc * 7.5 + cc1h * 12)' },
|
||||
{ key: 'qwen3-max', label: 'Qwen3 Max', expr: 'p <= 32000 ? tier("short", p * 1.2 + c * 6 + cr * 0.24 + cc * 1.5) : p <= 128000 ? tier("mid", p * 2.4 + c * 12 + cr * 0.48 + cc * 3) : tier("long", p * 3 + c * 15 + cr * 0.6 + cc * 3.75)' },
|
||||
{ key: 'glm-4.5-air', label: 'GLM-4.5 Air', expr: 'p < 32000 && c < 200 ? tier("short_output", p * 0.8 + c * 2 + cr * 0.16) : p < 32000 && c >= 200 ? tier("long_output", p * 0.8 + c * 6 + cr * 0.16) : tier("mid_context", p * 1.2 + c * 8 + cr * 0.24)' },
|
||||
{ key: 'doubao-seed-1.8', label: 'Doubao Seed 1.8', expr: 'p <= 32000 && c <= 200 ? tier("discount", p * 0.8 + c * 2 + cr * 0.16 + cc * 0.17) : p <= 32000 ? tier("short", p * 0.8 + c * 8 + cr * 0.16 + cc * 0.17) : p <= 128000 ? tier("mid", p * 1.2 + c * 16 + cr * 0.16 + cc * 0.17) : tier("long", p * 2.4 + c * 24 + cr * 0.16 + cc * 0.17)' },
|
||||
{ key: 'claude-sonnet', label: 'Claude Sonnet 4.5', expr: 'len <= 200000 ? tier("standard", p * 3 + c * 15 + cr * 0.3 + cc * 3.75 + cc1h * 6) : tier("long_context", p * 6 + c * 22.5 + cr * 0.6 + cc * 7.5 + cc1h * 12)' },
|
||||
{ key: 'qwen3-max', label: 'Qwen3 Max', expr: 'len <= 32000 ? tier("short", p * 1.2 + c * 6 + cr * 0.24 + cc * 1.5) : len <= 128000 ? tier("mid", p * 2.4 + c * 12 + cr * 0.48 + cc * 3) : tier("long", p * 3 + c * 15 + cr * 0.6 + cc * 3.75)' },
|
||||
{ key: 'glm-4.5-air', label: 'GLM-4.5 Air', expr: 'len < 32000 && c < 200 ? tier("short_output", p * 0.8 + c * 2 + cr * 0.16) : len < 32000 && c >= 200 ? tier("long_output", p * 0.8 + c * 6 + cr * 0.16) : tier("mid_context", p * 1.2 + c * 8 + cr * 0.24)' },
|
||||
{ key: 'doubao-seed-1.8', label: 'Doubao Seed 1.8', expr: 'len <= 32000 && c <= 200 ? tier("discount", p * 0.8 + c * 2 + cr * 0.16 + cc * 0.17) : len <= 32000 ? tier("short", p * 0.8 + c * 8 + cr * 0.16 + cc * 0.17) : len <= 128000 ? tier("mid", p * 1.2 + c * 16 + cr * 0.16 + cc * 0.17) : tier("long", p * 2.4 + c * 24 + cr * 0.16 + cc * 0.17)' },
|
||||
],
|
||||
},
|
||||
{
|
||||
@@ -793,7 +795,7 @@ const PRESET_GROUPS = [
|
||||
},
|
||||
{
|
||||
key: 'gpt-5.4-tiers', label: 'GPT-5.4 Priority/Flex',
|
||||
expr: 'p <= 272000 ? tier("standard", p * 2.5 + c * 15 + cr * 0.25) : tier("long_context", p * 5 + c * 22.5 + cr * 0.5)',
|
||||
expr: 'len <= 272000 ? tier("standard", p * 2.5 + c * 15 + cr * 0.25) : tier("long_context", p * 5 + c * 22.5 + cr * 0.5)',
|
||||
requestRules: [
|
||||
{ conditions: [{ source: SOURCE_PARAM, path: 'service_tier', mode: MATCH_EQ, value: 'priority' }], multiplier: '2' },
|
||||
{ conditions: [{ source: SOURCE_PARAM, path: 'service_tier', mode: MATCH_EQ, value: 'flex' }], multiplier: '0.5' },
|
||||
@@ -880,7 +882,8 @@ function RawExprEditor({ exprString, onChange, t }) {
|
||||
<div>
|
||||
<div>
|
||||
{t('变量')}: <code>p</code> ({t('输入 Token')}), <code>c</code> (
|
||||
{t('输出 Token')}), <code>cr</code> ({t('缓存读取')}),{' '}
|
||||
{t('输出 Token')}), <code>len</code> ({t('输入长度')}),{' '}
|
||||
<code>cr</code> ({t('缓存读取')}),{' '}
|
||||
<code>cc</code> ({t('缓存创建')}),{' '}
|
||||
<code>cc1h</code> ({t('缓存创建-1小时')})
|
||||
</div>
|
||||
@@ -968,7 +971,11 @@ function evalExprLocally(exprStr, p, c, extraTokenValues) {
|
||||
matchedTier = name;
|
||||
return value;
|
||||
};
|
||||
const env = { p, c, tier: tierFn, max: Math.max, min: Math.min, abs: Math.abs, ceil: Math.ceil, floor: Math.floor };
|
||||
const cacheReadTokens = extraTokenValues.cacheReadTokens || 0;
|
||||
const cacheCreateTokens = extraTokenValues.cacheCreateTokens || 0;
|
||||
const cacheCreate1hTokens = extraTokenValues.cacheCreate1hTokens || 0;
|
||||
const len = p + cacheReadTokens + cacheCreateTokens + cacheCreate1hTokens;
|
||||
const env = { p, c, len, tier: tierFn, max: Math.max, min: Math.min, abs: Math.abs, ceil: Math.ceil, floor: Math.floor };
|
||||
for (const field of EXTRA_ESTIMATOR_FIELDS) {
|
||||
env[field.var] = extraTokenValues[field.stateKey] || 0;
|
||||
}
|
||||
@@ -1220,6 +1227,146 @@ function RuleGroupCard({ group, index, onChange, onRemove, t }) {
|
||||
);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// LLM prompt helper — copyable prompt for LLM-assisted expression design
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const LLM_PROMPT_TEMPLATE = `你是一个 AI API 计费表达式设计助手。用户需要你帮忙设计一个计费表达式(billing expression),用于 AI API 网关的模型计费。
|
||||
|
||||
## 表达式语言
|
||||
|
||||
表达式基于 expr-lang/expr,支持标准算术运算和三元运算符。
|
||||
|
||||
### Token 变量
|
||||
|
||||
输入侧:
|
||||
- p — 输入 token 数(计价用)。系统会自动排除表达式中单独计价的子类别(如用了 cr,缓存 token 就从 p 中扣除)
|
||||
- len — 输入上下文总长度(条件判断用)。不受自动排除影响,始终反映完整输入长度。用于阶梯条件判断
|
||||
- cr — 缓存命中(读取)token 数
|
||||
- cc — 缓存创建 token 数(5分钟 TTL)
|
||||
- cc1h — 缓存创建 token 数(1小时 TTL,Claude 专用)
|
||||
- img — 图片输入 token 数
|
||||
- ai — 音频输入 token 数
|
||||
|
||||
输出侧:
|
||||
- c — 输出 token 数。同样会自动排除单独计价的子类别
|
||||
- img_o — 图片输出 token 数
|
||||
- ao — 音频输出 token 数
|
||||
|
||||
### p/c 自动排除机制
|
||||
|
||||
p 和 c 是兜底变量,代表所有没有被表达式单独定价的 token。如果表达式使用了某个子类别变量(如 cr),对应 token 就从 p 中扣除,避免重复计费。没用到的子类别 token 则留在 p/c 中按基础价格计费。
|
||||
|
||||
重要:len 不受自动排除影响。阶梯条件应使用 len 而非 p,以避免缓存命中导致 p 降低而误判档位。
|
||||
|
||||
### 内置函数
|
||||
|
||||
- tier(name, value) — 标记计费档位名称,必须包裹费用表达式
|
||||
- max(a, b)、min(a, b) — 取大/小值
|
||||
- ceil(x)、floor(x)、abs(x) — 向上取整、向下取整、绝对值
|
||||
- header(name) — 读取请求头
|
||||
- param(path) — 读取请求体 JSON 路径(gjson 语法)
|
||||
- has(source, substr) — 子字符串检查
|
||||
- hour(tz)、minute(tz)、weekday(tz)、month(tz)、day(tz) — 时间函数,tz 为时区如 "Asia/Shanghai"
|
||||
|
||||
### 价格系数
|
||||
|
||||
表达式中的数字系数是 $/1M tokens 的价格。例如 p * 2.5 表示输入 $2.50/1M tokens。
|
||||
|
||||
## 表达式示例
|
||||
|
||||
简单定价:
|
||||
tier("base", p * 2.5 + c * 15)
|
||||
|
||||
带缓存的定价:
|
||||
tier("base", p * 2.5 + c * 15 + cr * 0.25)
|
||||
|
||||
多档阶梯(用 len 做条件):
|
||||
len <= 200000
|
||||
? tier("standard", p * 3 + c * 15 + cr * 0.3 + cc * 3.75 + cc1h * 6)
|
||||
: tier("long_context", p * 6 + c * 22.5 + cr * 0.6 + cc * 7.5 + cc1h * 12)
|
||||
|
||||
图片模型:
|
||||
tier("base", p * 2 + c * 8 + img * 2.5)
|
||||
|
||||
多模态含音频:
|
||||
tier("base", p * 0.43 + c * 3.06 + img * 0.78 + ai * 3.81 + ao * 15.11)
|
||||
|
||||
三档阶梯示例:
|
||||
len <= 128000
|
||||
? tier("standard", p * 1.1 + c * 4.4)
|
||||
: (len <= 1000000
|
||||
? tier("medium", p * 2.2 + c * 8.8)
|
||||
: tier("long", p * 4.4 + c * 17.6))
|
||||
|
||||
## 规则
|
||||
|
||||
1. 每个叶子分支必须用 tier("名称", 费用表达式) 包裹
|
||||
2. tier 名称用英文,如 "base"、"standard"、"long_context"
|
||||
3. 阶梯条件用 len(不要用 p),支持 <、<=、>、>=
|
||||
4. 多档用嵌套三元运算符:条件1 ? tier(...) : (条件2 ? tier(...) : tier(...))
|
||||
5. 价格系数直接写供应商官方 $/1M tokens 价格
|
||||
6. 不需要缓存/图片/音频单独定价时可以不写对应变量,它们的 token 会自动包含在 p/c 中
|
||||
|
||||
请根据用户提供的模型信息和定价需求,生成计费表达式。`;
|
||||
|
||||
function LlmPromptHelper({ t, model }) {
|
||||
const [open, setOpen] = useState(false);
|
||||
|
||||
const modelName = model?.name || '';
|
||||
const prompt = useMemo(() => {
|
||||
if (modelName) {
|
||||
return LLM_PROMPT_TEMPLATE + `\n\n当前模型:${modelName}`;
|
||||
}
|
||||
return LLM_PROMPT_TEMPLATE;
|
||||
}, [modelName]);
|
||||
|
||||
const handleCopy = useCallback(async () => {
|
||||
const ok = await copy(prompt);
|
||||
if (ok) showSuccess(t('已复制到剪贴板'));
|
||||
}, [prompt, t]);
|
||||
|
||||
return (
|
||||
<div style={{ marginBottom: 12 }}>
|
||||
<Button
|
||||
theme='borderless'
|
||||
size='small'
|
||||
icon={<IconCopy />}
|
||||
onClick={() => setOpen(!open)}
|
||||
style={{ color: 'var(--semi-color-tertiary)' }}
|
||||
>
|
||||
{t('LLM 辅助设计提示词')}
|
||||
</Button>
|
||||
<Collapsible isOpen={open}>
|
||||
<Card
|
||||
bodyStyle={{ padding: 12 }}
|
||||
style={{ marginTop: 8, background: 'var(--semi-color-fill-0)' }}
|
||||
>
|
||||
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: 8 }}>
|
||||
<Text size='small' type='secondary'>
|
||||
{t('复制以下提示词发送给 LLM(如 ChatGPT / Claude),让它帮你设计计费表达式')}
|
||||
</Text>
|
||||
<Button
|
||||
icon={<IconCopy />}
|
||||
size='small'
|
||||
theme='light'
|
||||
onClick={handleCopy}
|
||||
>
|
||||
{t('复制提示词')}
|
||||
</Button>
|
||||
</div>
|
||||
<TextArea
|
||||
value={prompt}
|
||||
readonly
|
||||
autosize={{ minRows: 6, maxRows: 20 }}
|
||||
style={{ fontFamily: 'monospace', fontSize: 12 }}
|
||||
/>
|
||||
</Card>
|
||||
</Collapsible>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Main component
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -1543,6 +1690,8 @@ export default function TieredPricingEditor({ model, onExprChange, requestRuleEx
|
||||
</div>
|
||||
</Card>
|
||||
|
||||
<LlmPromptHelper t={t} model={model} />
|
||||
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1050,16 +1050,23 @@ export function useModelPricingEditorState({
|
||||
tieredOutput['billing_setting.billing_expr'][model.name] = finalBillingExpr;
|
||||
}
|
||||
}
|
||||
if (model.billingMode === 'tiered_expr') {
|
||||
continue;
|
||||
}
|
||||
|
||||
const serialized = serializeModel(model, t);
|
||||
Object.entries(serialized).forEach(([key, value]) => {
|
||||
if (value !== null) {
|
||||
output[key][model.name] = value;
|
||||
// Always serialize ratio/price values for all models (including
|
||||
// tiered_expr) so they serve as fallback during multi-instance sync
|
||||
// delay. ModelPriceHelper checks billing_mode first, so these values
|
||||
// are only used when billing_setting hasn't propagated yet.
|
||||
try {
|
||||
const serialized = serializeModel(model, t);
|
||||
Object.entries(serialized).forEach(([key, value]) => {
|
||||
if (value !== null) {
|
||||
output[key][model.name] = value;
|
||||
}
|
||||
});
|
||||
} catch (e) {
|
||||
if (model.billingMode !== 'tiered_expr') {
|
||||
throw e;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const requestQueue = [
|
||||
|
||||
Reference in New Issue
Block a user