Merge origin/main into nightly

Resolve conflicts:
- .gitignore: keep nightly additions (.test, skills-lock.json)
- relay/helper/price.go: keep both billingexpr and model imports
- en.json / zh-CN.json: keep nightly's superset of i18n entries
- service/billing_session.go: add missing 3rd arg to DecreaseUserQuota
- en.json / zh-CN.json: deduplicate 129+320 duplicate i18n keys
This commit is contained in:
CaIon
2026-04-23 21:23:40 +08:00
117 changed files with 9031 additions and 2910 deletions
+51 -9
View File
@@ -151,6 +151,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
}
}
cache.WriteContext(c)
c.Set("id", 1)
//c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
c.Request.Header.Set("Content-Type", "application/json")
@@ -284,7 +285,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
return testResult{
context: c,
localErr: err,
newAPIError: types.NewError(err, types.ErrorCodeModelPriceError),
newAPIError: types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithStatusCode(http.StatusBadRequest)),
}
}
@@ -469,7 +470,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
newAPIError: types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError),
}
}
if bodyErr := detectErrorFromTestResponseBody(respBody); bodyErr != nil {
if bodyErr := validateTestResponseBody(respBody, isStream); bodyErr != nil {
return testResult{
context: c,
localErr: bodyErr,
@@ -613,6 +614,42 @@ func detectErrorFromTestResponseBody(respBody []byte) error {
return nil
}
func validateStreamTestResponseBody(respBody []byte) error {
b := bytes.TrimSpace(respBody)
if len(b) == 0 {
return errors.New("stream response body is empty")
}
for _, line := range bytes.Split(b, []byte{'\n'}) {
line = bytes.TrimSpace(line)
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data:")) {
continue
}
payload := bytes.TrimSpace(bytes.TrimPrefix(line, []byte("data:")))
if len(payload) == 0 || bytes.Equal(payload, []byte("[DONE]")) {
continue
}
return nil
}
return errors.New("stream response body does not contain a valid stream event")
}
func validateTestResponseBody(respBody []byte, isStream bool) error {
if bodyErr := detectErrorFromTestResponseBody(respBody); bodyErr != nil {
return bodyErr
}
if isStream {
return validateStreamTestResponseBody(respBody)
}
return nil
}
func shouldUseStreamForAutomaticChannelTest(channel *model.Channel) bool {
return channel != nil && channel.Type == constant.ChannelTypeCodex
}
func detectErrorMessageFromJSONBytes(jsonBytes []byte) string {
if len(jsonBytes) == 0 {
return ""
@@ -800,11 +837,15 @@ func TestChannel(c *gin.Context) {
tik := time.Now()
result := testChannel(channel, testModel, endpointType, isStream)
if result.localErr != nil {
c.JSON(http.StatusOK, gin.H{
resp := gin.H{
"success": false,
"message": result.localErr.Error(),
"time": 0.0,
})
}
if result.newAPIError != nil {
resp["error_code"] = result.newAPIError.GetErrorCode()
}
c.JSON(http.StatusOK, resp)
return
}
tok := time.Now()
@@ -813,9 +854,10 @@ func TestChannel(c *gin.Context) {
consumedTime := float64(milliseconds) / 1000.0
if result.newAPIError != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": result.newAPIError.Error(),
"time": consumedTime,
"success": false,
"message": result.newAPIError.Error(),
"time": consumedTime,
"error_code": result.newAPIError.GetErrorCode(),
})
return
}
@@ -860,7 +902,7 @@ func testAllChannels(notify bool) error {
}
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
tik := time.Now()
result := testChannel(channel, "", "", false)
result := testChannel(channel, "", "", shouldUseStreamForAutomaticChannelTest(channel))
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
@@ -868,7 +910,7 @@ func testAllChannels(notify bool) error {
newAPIError := result.newAPIError
// request error disables the channel
if newAPIError != nil {
shouldBanChannel = service.ShouldDisableChannel(channel.Type, result.newAPIError)
shouldBanChannel = service.ShouldDisableChannel(result.newAPIError)
}
// 当错误检查通过,才检查响应时间
+12 -2
View File
@@ -27,6 +27,15 @@ var completionRatioMetaOptionKeys = []string{
"AudioCompletionRatio",
}
func isVisiblePublicKeyOption(key string) bool {
switch key {
case "WaffoPancakeWebhookPublicKey", "WaffoPancakeWebhookTestKey":
return true
default:
return false
}
}
func collectModelNamesFromOptionValue(raw string, modelNames map[string]struct{}) {
if strings.TrimSpace(raw) == "" {
return
@@ -66,11 +75,12 @@ func GetOptions(c *gin.Context) {
common.OptionMapRWMutex.Lock()
for k, v := range common.OptionMap {
value := common.Interface2String(v)
if strings.HasSuffix(k, "Token") ||
isSensitiveKey := strings.HasSuffix(k, "Token") ||
strings.HasSuffix(k, "Secret") ||
strings.HasSuffix(k, "Key") ||
strings.HasSuffix(k, "secret") ||
strings.HasSuffix(k, "api_key") {
strings.HasSuffix(k, "api_key")
if isSensitiveKey && !isVisiblePublicKeyOption(k) {
continue
}
options = append(options, &model.Option{
+70
View File
@@ -36,6 +36,10 @@ func PasskeyRegisterBegin(c *gin.Context) {
return
}
if !requirePasskeyRegistrationVerification(c, user.Id) {
return
}
credential, err := model.GetPasskeyByUserID(user.Id)
if err != nil && !errors.Is(err, model.ErrPasskeyNotFound) {
common.ApiError(c, err)
@@ -96,6 +100,10 @@ func PasskeyRegisterFinish(c *gin.Context) {
return
}
if !requirePasskeyRegistrationVerification(c, user.Id) {
return
}
wa, err := passkeysvc.BuildWebAuthn(c.Request)
if err != nil {
common.ApiError(c, err)
@@ -151,6 +159,10 @@ func PasskeyDelete(c *gin.Context) {
return
}
if !requirePasskeyDeleteVerification(c, user.Id) {
return
}
if err := model.DeletePasskeyByUserID(user.Id); err != nil {
common.ApiError(c, err)
return
@@ -474,6 +486,7 @@ func PasskeyVerifyFinish(c *gin.Context) {
// Mark passkey as ready; /api/verify will convert this into the final secure verification session.
session.Set(PasskeyReadySessionKey, time.Now().Unix())
session.Delete(SecureVerificationSessionKey)
session.Delete(secureVerificationMethodSessionKey)
if err := session.Save(); err != nil {
common.ApiError(c, fmt.Errorf("保存验证状态失败: %v", err))
return
@@ -504,3 +517,60 @@ func getSessionUser(c *gin.Context) (*model.User, error) {
}
return user, nil
}
func requirePasskeyRegistrationVerification(c *gin.Context, userID int) bool {
twoFA, err := model.GetTwoFAByUserId(userID)
if err != nil {
common.ApiError(c, err)
return false
}
if twoFA == nil || !twoFA.IsEnabled {
return true
}
return requireSecureVerificationMethod(c, secureVerificationMethod2FA)
}
func requirePasskeyDeleteVerification(c *gin.Context, userID int) bool {
twoFA, err := model.GetTwoFAByUserId(userID)
if err != nil {
common.ApiError(c, err)
return false
}
if twoFA != nil && twoFA.IsEnabled {
return requireSecureVerificationMethod(c, secureVerificationMethod2FA)
}
_, err = model.GetPasskeyByUserID(userID)
if err != nil {
if errors.Is(err, model.ErrPasskeyNotFound) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该用户尚未绑定 Passkey",
})
return false
}
common.ApiError(c, err)
return false
}
return requireSecureVerificationMethod(c, secureVerificationMethodPasskey)
}
func requireSecureVerificationMethod(c *gin.Context, method string) bool {
session := sessions.Default(c)
verifiedAt, ok := session.Get(SecureVerificationSessionKey).(int64)
if !ok || time.Now().Unix()-verifiedAt >= SecureVerificationTimeout {
session.Delete(SecureVerificationSessionKey)
session.Delete(secureVerificationMethodSessionKey)
_ = session.Save()
common.ApiErrorMsg(c, "请先完成安全验证")
return false
}
if verifiedMethod, ok := session.Get(secureVerificationMethodSessionKey).(string); !ok || verifiedMethod != method {
common.ApiErrorMsg(c, "请先完成对应的安全验证")
return false
}
return true
}
+100
View File
@@ -0,0 +1,100 @@
package controller
import (
"strings"
"github.com/QuantumNous/new-api/setting"
"github.com/QuantumNous/new-api/setting/operation_setting"
)
func isStripeTopUpEnabled() bool {
return strings.TrimSpace(setting.StripeApiSecret) != "" &&
strings.TrimSpace(setting.StripeWebhookSecret) != "" &&
strings.TrimSpace(setting.StripePriceId) != ""
}
func isStripeWebhookConfigured() bool {
return strings.TrimSpace(setting.StripeWebhookSecret) != ""
}
func isStripeWebhookEnabled() bool {
return isStripeTopUpEnabled()
}
func isCreemTopUpEnabled() bool {
products := strings.TrimSpace(setting.CreemProducts)
return strings.TrimSpace(setting.CreemApiKey) != "" &&
products != "" &&
products != "[]"
}
func isCreemWebhookConfigured() bool {
return strings.TrimSpace(setting.CreemWebhookSecret) != ""
}
func isCreemWebhookEnabled() bool {
return isCreemTopUpEnabled() && isCreemWebhookConfigured()
}
func isWaffoTopUpEnabled() bool {
if !setting.WaffoEnabled {
return false
}
return isWaffoWebhookConfigured()
}
func isWaffoWebhookConfigured() bool {
if setting.WaffoSandbox {
return strings.TrimSpace(setting.WaffoSandboxApiKey) != "" &&
strings.TrimSpace(setting.WaffoSandboxPrivateKey) != "" &&
strings.TrimSpace(setting.WaffoSandboxPublicCert) != ""
}
return strings.TrimSpace(setting.WaffoApiKey) != "" &&
strings.TrimSpace(setting.WaffoPrivateKey) != "" &&
strings.TrimSpace(setting.WaffoPublicCert) != ""
}
func isWaffoWebhookEnabled() bool {
return isWaffoTopUpEnabled()
}
func isWaffoPancakeTopUpEnabled() bool {
if !setting.WaffoPancakeEnabled {
return false
}
return isWaffoPancakeWebhookConfigured() &&
strings.TrimSpace(setting.WaffoPancakeMerchantID) != "" &&
strings.TrimSpace(setting.WaffoPancakePrivateKey) != "" &&
strings.TrimSpace(setting.WaffoPancakeStoreID) != "" &&
strings.TrimSpace(setting.WaffoPancakeProductID) != ""
}
func isWaffoPancakeWebhookConfigured() bool {
currentWebhookKey := strings.TrimSpace(setting.WaffoPancakeWebhookPublicKey)
if setting.WaffoPancakeSandbox {
currentWebhookKey = strings.TrimSpace(setting.WaffoPancakeWebhookTestKey)
}
return currentWebhookKey != ""
}
func isWaffoPancakeWebhookEnabled() bool {
return isWaffoPancakeTopUpEnabled()
}
func isEpayTopUpEnabled() bool {
return isEpayWebhookConfigured() && len(operation_setting.PayMethods) > 0
}
func isEpayWebhookConfigured() bool {
return strings.TrimSpace(operation_setting.PayAddress) != "" &&
strings.TrimSpace(operation_setting.EpayId) != "" &&
strings.TrimSpace(operation_setting.EpayKey) != ""
}
func isEpayWebhookEnabled() bool {
return isEpayTopUpEnabled()
}
@@ -0,0 +1,166 @@
package controller
import (
"testing"
"github.com/QuantumNous/new-api/setting"
"github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/stretchr/testify/require"
)
func TestStripeWebhookEnabledRequiresTopUpAndWebhookConfig(t *testing.T) {
originalAPISecret := setting.StripeApiSecret
originalWebhookSecret := setting.StripeWebhookSecret
originalPriceID := setting.StripePriceId
t.Cleanup(func() {
setting.StripeApiSecret = originalAPISecret
setting.StripeWebhookSecret = originalWebhookSecret
setting.StripePriceId = originalPriceID
})
setting.StripeWebhookSecret = ""
setting.StripeApiSecret = "sk_test_123"
setting.StripePriceId = "price_123"
require.False(t, isStripeWebhookEnabled())
setting.StripeWebhookSecret = "whsec_test"
require.True(t, isStripeWebhookEnabled())
setting.StripePriceId = ""
require.False(t, isStripeWebhookEnabled())
}
func TestCreemWebhookEnabledRequiresTopUpAndWebhookConfig(t *testing.T) {
originalAPIKey := setting.CreemApiKey
originalProducts := setting.CreemProducts
originalWebhookSecret := setting.CreemWebhookSecret
t.Cleanup(func() {
setting.CreemApiKey = originalAPIKey
setting.CreemProducts = originalProducts
setting.CreemWebhookSecret = originalWebhookSecret
})
setting.CreemWebhookSecret = ""
setting.CreemApiKey = "creem_api_key"
setting.CreemProducts = `[{"productId":"prod_123"}]`
require.False(t, isCreemWebhookEnabled())
setting.CreemWebhookSecret = "creem_secret"
require.True(t, isCreemWebhookEnabled())
setting.CreemProducts = "[]"
require.False(t, isCreemWebhookEnabled())
}
func TestWaffoWebhookEnabledRequiresTopUpAndWebhookConfig(t *testing.T) {
originalEnabled := setting.WaffoEnabled
originalSandbox := setting.WaffoSandbox
originalAPIKey := setting.WaffoApiKey
originalPrivateKey := setting.WaffoPrivateKey
originalPublicCert := setting.WaffoPublicCert
originalSandboxAPIKey := setting.WaffoSandboxApiKey
originalSandboxPrivateKey := setting.WaffoSandboxPrivateKey
originalSandboxPublicCert := setting.WaffoSandboxPublicCert
t.Cleanup(func() {
setting.WaffoEnabled = originalEnabled
setting.WaffoSandbox = originalSandbox
setting.WaffoApiKey = originalAPIKey
setting.WaffoPrivateKey = originalPrivateKey
setting.WaffoPublicCert = originalPublicCert
setting.WaffoSandboxApiKey = originalSandboxAPIKey
setting.WaffoSandboxPrivateKey = originalSandboxPrivateKey
setting.WaffoSandboxPublicCert = originalSandboxPublicCert
})
setting.WaffoEnabled = true
setting.WaffoSandbox = false
setting.WaffoApiKey = ""
setting.WaffoPrivateKey = "private"
setting.WaffoPublicCert = "public"
require.False(t, isWaffoWebhookEnabled())
setting.WaffoApiKey = "api"
require.True(t, isWaffoWebhookEnabled())
setting.WaffoEnabled = false
require.False(t, isWaffoWebhookEnabled())
setting.WaffoEnabled = true
setting.WaffoSandbox = true
setting.WaffoSandboxApiKey = ""
setting.WaffoSandboxPrivateKey = "sandbox_private"
setting.WaffoSandboxPublicCert = "sandbox_public"
require.False(t, isWaffoWebhookEnabled())
setting.WaffoSandboxApiKey = "sandbox_api"
require.True(t, isWaffoWebhookEnabled())
}
func TestWaffoPancakeWebhookEnabledRequiresTopUpAndWebhookConfig(t *testing.T) {
originalEnabled := setting.WaffoPancakeEnabled
originalSandbox := setting.WaffoPancakeSandbox
originalMerchantID := setting.WaffoPancakeMerchantID
originalPrivateKey := setting.WaffoPancakePrivateKey
originalWebhookPublicKey := setting.WaffoPancakeWebhookPublicKey
originalWebhookTestKey := setting.WaffoPancakeWebhookTestKey
originalStoreID := setting.WaffoPancakeStoreID
originalProductID := setting.WaffoPancakeProductID
t.Cleanup(func() {
setting.WaffoPancakeEnabled = originalEnabled
setting.WaffoPancakeSandbox = originalSandbox
setting.WaffoPancakeMerchantID = originalMerchantID
setting.WaffoPancakePrivateKey = originalPrivateKey
setting.WaffoPancakeWebhookPublicKey = originalWebhookPublicKey
setting.WaffoPancakeWebhookTestKey = originalWebhookTestKey
setting.WaffoPancakeStoreID = originalStoreID
setting.WaffoPancakeProductID = originalProductID
})
setting.WaffoPancakeEnabled = true
setting.WaffoPancakeSandbox = false
setting.WaffoPancakeMerchantID = "merchant"
setting.WaffoPancakePrivateKey = "private"
setting.WaffoPancakeStoreID = "store"
setting.WaffoPancakeProductID = "product"
setting.WaffoPancakeWebhookPublicKey = ""
require.False(t, isWaffoPancakeWebhookEnabled())
setting.WaffoPancakeWebhookPublicKey = "public"
require.True(t, isWaffoPancakeWebhookEnabled())
setting.WaffoPancakeEnabled = false
require.False(t, isWaffoPancakeWebhookEnabled())
setting.WaffoPancakeEnabled = true
setting.WaffoPancakeSandbox = true
setting.WaffoPancakeWebhookTestKey = ""
require.False(t, isWaffoPancakeWebhookEnabled())
setting.WaffoPancakeWebhookTestKey = "test_public"
require.True(t, isWaffoPancakeWebhookEnabled())
}
func TestEpayWebhookEnabledRequiresTopUpAndWebhookConfig(t *testing.T) {
originalPayAddress := operation_setting.PayAddress
originalEpayID := operation_setting.EpayId
originalEpayKey := operation_setting.EpayKey
originalPayMethods := operation_setting.PayMethods
t.Cleanup(func() {
operation_setting.PayAddress = originalPayAddress
operation_setting.EpayId = originalEpayID
operation_setting.EpayKey = originalEpayKey
operation_setting.PayMethods = originalPayMethods
})
operation_setting.PayAddress = "https://pay.example.com"
operation_setting.EpayId = "epay_id"
operation_setting.EpayKey = ""
operation_setting.PayMethods = []map[string]string{{"type": "alipay"}}
require.False(t, isEpayWebhookEnabled())
operation_setting.EpayKey = "epay_key"
require.True(t, isEpayWebhookEnabled())
operation_setting.PayMethods = nil
require.False(t, isEpayWebhookEnabled())
}
+3 -3
View File
@@ -151,7 +151,7 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
priceData, err := helper.ModelPriceHelper(c, relayInfo, tokens, meta)
if err != nil {
newAPIError = types.NewError(err, types.ErrorCodeModelPriceError)
newAPIError = types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithStatusCode(http.StatusBadRequest))
return
}
@@ -351,7 +351,7 @@ func processChannelError(c *gin.Context, channelError types.ChannelError, err *t
logger.LogError(c, fmt.Sprintf("channel error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
if service.ShouldDisableChannel(channelError.ChannelType, err) && channelError.AutoBan {
if service.ShouldDisableChannel(err) && channelError.AutoBan {
gopool.Go(func() {
service.DisableChannel(channelError, err.ErrorWithStatusCode())
})
@@ -389,7 +389,7 @@ func processChannelError(c *gin.Context, channelError types.ChannelError, err *t
startTime = time.Now()
}
useTimeSeconds := int(time.Since(startTime).Seconds())
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveErrorWithStatusCode(), tokenId, useTimeSeconds, false, userGroup, other)
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveErrorWithStatusCode(), tokenId, useTimeSeconds, common.GetContextKeyBool(c, constant.ContextKeyIsStream), userGroup, other)
}
}
+7 -3
View File
@@ -13,7 +13,10 @@ import (
const (
// SecureVerificationSessionKey means the user has fully passed secure verification.
SecureVerificationSessionKey = "secure_verified_at"
SecureVerificationSessionKey = "secure_verified_at"
secureVerificationMethodSessionKey = "secure_verified_method"
secureVerificationMethod2FA = "2fa"
secureVerificationMethodPasskey = "passkey"
// PasskeyReadySessionKey means WebAuthn finished and /api/verify can finalize step-up verification.
PasskeyReadySessionKey = "secure_passkey_ready_at"
// SecureVerificationTimeout 验证有效期(秒)
@@ -120,7 +123,7 @@ func UniversalVerify(c *gin.Context) {
}
// 验证成功,在 session 中记录时间戳
now, err := setSecureVerificationSession(c)
now, err := setSecureVerificationSession(c, req.Method)
if err != nil {
common.ApiError(c, fmt.Errorf("保存验证状态失败: %v", err))
return
@@ -139,11 +142,12 @@ func UniversalVerify(c *gin.Context) {
})
}
func setSecureVerificationSession(c *gin.Context) (int64, error) {
func setSecureVerificationSession(c *gin.Context, method string) (int64, error) {
session := sessions.Default(c)
session.Delete(PasskeyReadySessionKey)
now := time.Now().Unix()
session.Set(SecureVerificationSessionKey, now)
session.Set(secureVerificationMethodSessionKey, method)
if err := session.Save(); err != nil {
return 0, err
}
+12 -10
View File
@@ -2,11 +2,13 @@ package controller
import (
"bytes"
"fmt"
"io"
"log"
"net/http"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting"
"github.com/QuantumNous/new-api/setting/operation_setting"
@@ -24,14 +26,14 @@ func SubscriptionRequestCreemPay(c *gin.Context) {
// Keep body for debugging consistency (like RequestCreemPay)
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
log.Printf("read subscription creem pay req body err: %v", err)
c.JSON(200, gin.H{"message": "error", "data": "read query error"})
logger.LogError(c.Request.Context(), fmt.Sprintf("Creem 订阅支付请求读取失败 error=%q", err.Error()))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "read query error"})
return
}
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
if err := c.ShouldBindJSON(&req); err != nil || req.PlanId <= 0 {
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
return
}
@@ -85,12 +87,12 @@ func SubscriptionRequestCreemPay(c *gin.Context) {
PlanId: plan.Id,
Money: plan.PriceAmount,
TradeNo: referenceId,
PaymentMethod: PaymentMethodCreem,
PaymentMethod: model.PaymentMethodCreem,
CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending,
}
if err := order.Insert(); err != nil {
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"})
return
}
@@ -112,14 +114,14 @@ func SubscriptionRequestCreemPay(c *gin.Context) {
Quota: 0,
}
checkoutUrl, err := genCreemLink(referenceId, product, user.Email, user.Username)
checkoutUrl, err := genCreemLink(c.Request.Context(), referenceId, product, user.Email, user.Username)
if err != nil {
log.Printf("获取Creem支付链接失败: %v", err)
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
logger.LogError(c.Request.Context(), fmt.Sprintf("Creem 订阅支付链接创建失败 trade_no=%s product_id=%s error=%q", referenceId, product.ProductId, err.Error()))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉起支付失败"})
return
}
c.JSON(200, gin.H{
c.JSON(http.StatusOK, gin.H{
"message": "success",
"data": gin.H{
"checkout_url": checkoutUrl,
+3 -3
View File
@@ -104,7 +104,7 @@ func SubscriptionRequestEpay(c *gin.Context) {
ReturnUrl: returnUrl,
})
if err != nil {
_ = model.ExpireSubscriptionOrder(tradeNo)
_ = model.ExpireSubscriptionOrder(tradeNo, req.PaymentMethod)
common.ApiErrorMsg(c, "拉起支付失败")
return
}
@@ -156,7 +156,7 @@ func SubscriptionEpayNotify(c *gin.Context) {
LockOrder(verifyInfo.ServiceTradeNo)
defer UnlockOrder(verifyInfo.ServiceTradeNo)
if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo)); err != nil {
if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo), verifyInfo.Type); err != nil {
_, _ = c.Writer.Write([]byte("fail"))
return
}
@@ -205,7 +205,7 @@ func SubscriptionEpayReturn(c *gin.Context) {
if verifyInfo.TradeStatus == epay.StatusTradeSuccess {
LockOrder(verifyInfo.ServiceTradeNo)
defer UnlockOrder(verifyInfo.ServiceTradeNo)
if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo)); err != nil {
if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo), verifyInfo.Type); err != nil {
c.Redirect(http.StatusFound, system_setting.ServerAddress+"/console/topup?pay=fail")
return
}
+3 -3
View File
@@ -2,12 +2,12 @@ package controller
import (
"fmt"
"log"
"net/http"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting"
"github.com/QuantumNous/new-api/setting/system_setting"
@@ -78,7 +78,7 @@ func SubscriptionRequestStripePay(c *gin.Context) {
payLink, err := genStripeSubscriptionLink(referenceId, user.StripeCustomer, user.Email, plan.StripePriceId)
if err != nil {
log.Println("获取Stripe Checkout支付链接失败", err)
logger.LogError(c.Request.Context(), fmt.Sprintf("Stripe 订阅支付链接创建失败 trade_no=%s plan_id=%d error=%q", referenceId, plan.Id, err.Error()))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉起支付失败"})
return
}
@@ -88,7 +88,7 @@ func SubscriptionRequestStripePay(c *gin.Context) {
PlanId: plan.Id,
Money: plan.PriceAmount,
TradeNo: referenceId,
PaymentMethod: PaymentMethodStripe,
PaymentMethod: model.PaymentMethodStripe,
CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending,
}
+271 -5
View File
@@ -2,10 +2,12 @@ package controller
import (
"bytes"
"database/sql"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"strconv"
"strings"
"testing"
@@ -14,6 +16,8 @@ import (
"github.com/QuantumNous/new-api/model"
"github.com/gin-gonic/gin"
"github.com/glebarez/sqlite"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/gorm"
)
@@ -38,7 +42,36 @@ type tokenKeyResponse struct {
Key string `json:"key"`
}
func setupTokenControllerTestDB(t *testing.T) *gorm.DB {
type sqliteColumnInfo struct {
Name string `gorm:"column:name"`
Type string `gorm:"column:type"`
}
type legacyToken struct {
Id int `gorm:"primaryKey"`
UserId int `gorm:"index"`
Key string `gorm:"column:key;type:char(48);uniqueIndex"`
Status int `gorm:"default:1"`
Name string `gorm:"index"`
CreatedTime int64 `gorm:"bigint"`
AccessedTime int64 `gorm:"bigint"`
ExpiredTime int64 `gorm:"bigint;default:-1"`
RemainQuota int `gorm:"default:0"`
UnlimitedQuota bool
ModelLimitsEnabled bool
ModelLimits string `gorm:"type:text"`
AllowIps *string `gorm:"default:''"`
UsedQuota int `gorm:"default:0"`
Group string `gorm:"column:group;default:''"`
CrossGroupRetry bool
DeletedAt gorm.DeletedAt `gorm:"index"`
}
func (legacyToken) TableName() string {
return "tokens"
}
func openTokenControllerTestDB(t *testing.T) *gorm.DB {
t.Helper()
gin.SetMode(gin.TestMode)
@@ -55,10 +88,6 @@ func setupTokenControllerTestDB(t *testing.T) *gorm.DB {
model.DB = db
model.LOG_DB = db
if err := db.AutoMigrate(&model.Token{}); err != nil {
t.Fatalf("failed to migrate token table: %v", err)
}
t.Cleanup(func() {
sqlDB, err := db.DB()
if err == nil {
@@ -69,6 +98,69 @@ func setupTokenControllerTestDB(t *testing.T) *gorm.DB {
return db
}
func migrateTokenControllerTestDB(t *testing.T, db *gorm.DB) {
t.Helper()
if err := db.AutoMigrate(&model.Token{}); err != nil {
t.Fatalf("failed to migrate token table: %v", err)
}
}
func setupTokenControllerTestDB(t *testing.T) *gorm.DB {
t.Helper()
db := openTokenControllerTestDB(t)
migrateTokenControllerTestDB(t, db)
return db
}
func openTokenControllerExternalDB(t *testing.T, dialect string, dsn string) (*gorm.DB, *bool) {
t.Helper()
gin.SetMode(gin.TestMode)
common.RedisEnabled = false
common.UsingSQLite = false
common.UsingMySQL = dialect == "mysql"
common.UsingPostgreSQL = dialect == "postgres"
var (
db *gorm.DB
err error
)
switch dialect {
case "mysql":
db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{})
case "postgres":
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
default:
t.Fatalf("unsupported dialect %q", dialect)
}
if err != nil {
t.Fatalf("failed to open %s db: %v", dialect, err)
}
model.DB = db
model.LOG_DB = db
if db.Migrator().HasTable("tokens") {
t.Skipf("refusing to run %s migration compatibility test against external database because tokens table already exists", dialect)
}
managedTokensTable := new(bool)
t.Cleanup(func() {
if *managedTokensTable && db.Migrator().HasTable("tokens") {
_ = db.Migrator().DropTable("tokens")
}
sqlDB, err := db.DB()
if err == nil {
_ = sqlDB.Close()
}
})
return db, managedTokensTable
}
func seedToken(t *testing.T, db *gorm.DB, userID int, name string, rawKey string) *model.Token {
t.Helper()
@@ -124,6 +216,180 @@ func decodeAPIResponse(t *testing.T, recorder *httptest.ResponseRecorder) tokenA
return response
}
func getSQLiteColumnType(t *testing.T, db *gorm.DB, tableName string, columnName string) string {
t.Helper()
var columns []sqliteColumnInfo
if err := db.Raw("PRAGMA table_info(" + tableName + ")").Scan(&columns).Error; err != nil {
t.Fatalf("failed to inspect %s schema: %v", tableName, err)
}
for _, column := range columns {
if column.Name == columnName {
return strings.ToLower(column.Type)
}
}
t.Fatalf("column %s not found in %s schema", columnName, tableName)
return ""
}
func getTokenKeyColumnType(t *testing.T, db *gorm.DB, dialect string) string {
t.Helper()
switch dialect {
case "sqlite":
return getSQLiteColumnType(t, db, "tokens", "key")
case "mysql":
var columnType string
if err := db.Raw(`SELECT COLUMN_TYPE FROM information_schema.columns
WHERE table_schema = DATABASE() AND table_name = ? AND column_name = ?`,
"tokens", "key").Scan(&columnType).Error; err != nil {
t.Fatalf("failed to inspect mysql token key column: %v", err)
}
return strings.ToLower(columnType)
case "postgres":
var dataType string
var maxLength sql.NullInt64
if err := db.Raw(`SELECT data_type, character_maximum_length
FROM information_schema.columns
WHERE table_schema = current_schema() AND table_name = ? AND column_name = ?`,
"tokens", "key").Row().Scan(&dataType, &maxLength); err != nil {
t.Fatalf("failed to inspect postgres token key column: %v", err)
}
switch strings.ToLower(dataType) {
case "character varying":
return fmt.Sprintf("varchar(%d)", maxLength.Int64)
case "character":
return fmt.Sprintf("char(%d)", maxLength.Int64)
default:
if maxLength.Valid {
return fmt.Sprintf("%s(%d)", strings.ToLower(dataType), maxLength.Int64)
}
return strings.ToLower(dataType)
}
default:
t.Fatalf("unsupported dialect %q", dialect)
return ""
}
}
func runTokenMigrationCompatibilityTest(t *testing.T, db *gorm.DB, dialect string, managedTokensTable *bool) {
t.Helper()
legacyKey := strings.Repeat("a", 48)
longKey := strings.Repeat("b", 64)
if err := db.AutoMigrate(&legacyToken{}); err != nil {
t.Fatalf("failed to create legacy token schema: %v", err)
}
if managedTokensTable != nil {
*managedTokensTable = true
}
if err := db.Create(&legacyToken{
UserId: 7,
Key: legacyKey,
Status: common.TokenStatusEnabled,
Name: "legacy-token",
CreatedTime: 1,
AccessedTime: 1,
ExpiredTime: -1,
RemainQuota: 100,
UnlimitedQuota: true,
ModelLimitsEnabled: false,
ModelLimits: "",
AllowIps: common.GetPointer(""),
UsedQuota: 0,
Group: "default",
CrossGroupRetry: false,
}).Error; err != nil {
t.Fatalf("failed to seed legacy token row: %v", err)
}
if got := getTokenKeyColumnType(t, db, dialect); got != "char(48)" {
t.Fatalf("expected legacy key column type char(48), got %q", got)
}
migrateTokenControllerTestDB(t, db)
if got := getTokenKeyColumnType(t, db, dialect); got != "varchar(128)" {
t.Fatalf("expected migrated key column type varchar(128), got %q", got)
}
var migratedToken model.Token
if err := db.First(&migratedToken, "name = ?", "legacy-token").Error; err != nil {
t.Fatalf("failed to load migrated token row: %v", err)
}
if migratedToken.Key != legacyKey {
t.Fatalf("expected migrated token key %q, got %q", legacyKey, migratedToken.Key)
}
if migratedToken.Name != "legacy-token" {
t.Fatalf("expected migrated token name to be preserved, got %q", migratedToken.Name)
}
inserted := model.Token{
UserId: 8,
Name: "long-token",
Key: longKey,
Status: common.TokenStatusEnabled,
CreatedTime: 1,
AccessedTime: 1,
ExpiredTime: -1,
RemainQuota: 200,
UnlimitedQuota: true,
ModelLimitsEnabled: false,
ModelLimits: "",
AllowIps: common.GetPointer(""),
UsedQuota: 0,
Group: "default",
CrossGroupRetry: false,
}
if err := db.Create(&inserted).Error; err != nil {
t.Fatalf("failed to insert long token after migration: %v", err)
}
var fetched model.Token
if err := db.First(&fetched, "id = ?", inserted.Id).Error; err != nil {
t.Fatalf("failed to fetch long token after migration: %v", err)
}
if fetched.Key != longKey {
t.Fatalf("expected long token key %q, got %q", longKey, fetched.Key)
}
}
func TestTokenAutoMigrateUsesVarchar128KeyColumn(t *testing.T) {
db := setupTokenControllerTestDB(t)
if got := getTokenKeyColumnType(t, db, "sqlite"); got != "varchar(128)" {
t.Fatalf("expected key column type varchar(128), got %q", got)
}
}
func TestTokenMigrationFromChar48ToVarchar128(t *testing.T) {
db := openTokenControllerTestDB(t)
runTokenMigrationCompatibilityTest(t, db, "sqlite", nil)
}
func TestTokenMigrationFromChar48ToVarchar128MySQL(t *testing.T) {
dsn := os.Getenv("TEST_MYSQL_DSN")
if dsn == "" {
t.Skip("set TEST_MYSQL_DSN to run mysql migration compatibility test")
}
db, managedTokensTable := openTokenControllerExternalDB(t, "mysql", dsn)
runTokenMigrationCompatibilityTest(t, db, "mysql", managedTokensTable)
}
func TestTokenMigrationFromChar48ToVarchar128Postgres(t *testing.T) {
dsn := os.Getenv("TEST_POSTGRES_DSN")
if dsn == "" {
t.Skip("set TEST_POSTGRES_DSN to run postgres migration compatibility test")
}
db, managedTokensTable := openTokenControllerExternalDB(t, "postgres", dsn)
runTokenMigrationCompatibilityTest(t, db, "postgres", managedTokensTable)
}
func TestGetAllTokensMasksKeyInResponse(t *testing.T) {
db := setupTokenControllerTestDB(t)
token := seedToken(t, db, 1, "list-token", "abcd1234efgh5678")
+103 -57
View File
@@ -2,7 +2,7 @@ package controller
import (
"fmt"
"log"
"net/http"
"net/url"
"strconv"
"sync"
@@ -27,7 +27,7 @@ func GetTopUpInfo(c *gin.Context) {
payMethods := operation_setting.PayMethods
// 如果启用了 Stripe 支付,添加到支付方法列表
if setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "" {
if isStripeTopUpEnabled() {
// 检查是否已经包含 Stripe
hasStripe := false
for _, method := range payMethods {
@@ -49,19 +49,11 @@ func GetTopUpInfo(c *gin.Context) {
}
// 如果启用了 Waffo 支付,添加到支付方法列表
enableWaffo := setting.WaffoEnabled &&
((!setting.WaffoSandbox &&
setting.WaffoApiKey != "" &&
setting.WaffoPrivateKey != "" &&
setting.WaffoPublicCert != "") ||
(setting.WaffoSandbox &&
setting.WaffoSandboxApiKey != "" &&
setting.WaffoSandboxPrivateKey != "" &&
setting.WaffoSandboxPublicCert != ""))
enableWaffo := isWaffoTopUpEnabled()
if enableWaffo {
hasWaffo := false
for _, method := range payMethods {
if method["type"] == "waffo" {
if method["type"] == model.PaymentMethodWaffo {
hasWaffo = true
break
}
@@ -70,7 +62,7 @@ func GetTopUpInfo(c *gin.Context) {
if !hasWaffo {
waffoMethod := map[string]string{
"name": "Waffo (Global Payment)",
"type": "waffo",
"type": model.PaymentMethodWaffo,
"color": "rgba(var(--semi-blue-5), 1)",
"min_topup": strconv.Itoa(setting.WaffoMinTopUp),
}
@@ -78,24 +70,46 @@ func GetTopUpInfo(c *gin.Context) {
}
}
enableWaffoPancake := isWaffoPancakeTopUpEnabled()
if enableWaffoPancake {
hasWaffoPancake := false
for _, method := range payMethods {
if method["type"] == model.PaymentMethodWaffoPancake {
hasWaffoPancake = true
break
}
}
if !hasWaffoPancake {
payMethods = append(payMethods, map[string]string{
"name": "Waffo Pancake",
"type": model.PaymentMethodWaffoPancake,
"color": "rgba(var(--semi-orange-5), 1)",
"min_topup": strconv.Itoa(setting.WaffoPancakeMinTopUp),
})
}
}
data := gin.H{
"enable_online_topup": operation_setting.PayAddress != "" && operation_setting.EpayId != "" && operation_setting.EpayKey != "",
"enable_stripe_topup": setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "",
"enable_creem_topup": setting.CreemApiKey != "" && setting.CreemProducts != "[]",
"enable_waffo_topup": enableWaffo,
"enable_online_topup": isEpayTopUpEnabled(),
"enable_stripe_topup": isStripeTopUpEnabled(),
"enable_creem_topup": isCreemTopUpEnabled(),
"enable_waffo_topup": enableWaffo,
"enable_waffo_pancake_topup": enableWaffoPancake,
"waffo_pay_methods": func() interface{} {
if enableWaffo {
return setting.GetWaffoPayMethods()
}
return nil
}(),
"creem_products": setting.CreemProducts,
"pay_methods": payMethods,
"min_topup": operation_setting.MinTopUp,
"stripe_min_topup": setting.StripeMinTopUp,
"waffo_min_topup": setting.WaffoMinTopUp,
"amount_options": operation_setting.GetPaymentSetting().AmountOptions,
"discount": operation_setting.GetPaymentSetting().AmountDiscount,
"creem_products": setting.CreemProducts,
"pay_methods": payMethods,
"min_topup": operation_setting.MinTopUp,
"stripe_min_topup": setting.StripeMinTopUp,
"waffo_min_topup": setting.WaffoMinTopUp,
"waffo_pancake_min_topup": setting.WaffoPancakeMinTopUp,
"amount_options": operation_setting.GetPaymentSetting().AmountOptions,
"discount": operation_setting.GetPaymentSetting().AmountDiscount,
}
common.ApiSuccess(c, data)
}
@@ -109,6 +123,17 @@ type AmountRequest struct {
Amount int64 `json:"amount"`
}
var nonEpayPaymentMethodsForCallback = []string{
model.PaymentMethodStripe,
model.PaymentMethodCreem,
model.PaymentMethodWaffo,
model.PaymentMethodWaffoPancake,
}
func isNonEpayPaymentMethodForEpayCallback(paymentMethod string) bool {
return lo.Contains(nonEpayPaymentMethodsForCallback, paymentMethod)
}
func GetEpayClient() *epay.Client {
if operation_setting.PayAddress == "" || operation_setting.EpayId == "" || operation_setting.EpayKey == "" {
return nil
@@ -167,28 +192,28 @@ func RequestEpay(c *gin.Context) {
var req EpayRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
return
}
if req.Amount < getMinTopup() {
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())})
return
}
id := c.GetInt("id")
group, err := model.GetUserGroup(id, true)
if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "获取用户分组失败"})
return
}
payMoney := getPayMoney(req.Amount, group)
if payMoney < 0.01 {
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "充值金额过低"})
return
}
if !operation_setting.ContainsPayMethod(req.PaymentMethod) {
c.JSON(200, gin.H{"message": "error", "data": "支付方式不存在"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "支付方式不存在"})
return
}
@@ -199,7 +224,7 @@ func RequestEpay(c *gin.Context) {
tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo)
client := GetEpayClient()
if client == nil {
c.JSON(200, gin.H{"message": "error", "data": "当前管理员未配置支付信息"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "当前管理员未配置支付信息"})
return
}
uri, params, err := client.Purchase(&epay.PurchaseArgs{
@@ -212,7 +237,8 @@ func RequestEpay(c *gin.Context) {
ReturnUrl: returnUrl,
})
if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
logger.LogError(c.Request.Context(), fmt.Sprintf("易支付 拉起支付失败 user_id=%d trade_no=%s payment_method=%s amount=%d error=%q", id, tradeNo, req.PaymentMethod, req.Amount, err.Error()))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉起支付失败"})
return
}
amount := req.Amount
@@ -228,14 +254,16 @@ func RequestEpay(c *gin.Context) {
TradeNo: tradeNo,
PaymentMethod: req.PaymentMethod,
CreateTime: time.Now().Unix(),
Status: "pending",
Status: common.TopUpStatusPending,
}
err = topUp.Insert()
if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
logger.LogError(c.Request.Context(), fmt.Sprintf("易支付 创建充值订单失败 user_id=%d trade_no=%s payment_method=%s amount=%d error=%q", id, tradeNo, req.PaymentMethod, req.Amount, err.Error()))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"})
return
}
c.JSON(200, gin.H{"message": "success", "data": params, "url": uri})
logger.LogInfo(c.Request.Context(), fmt.Sprintf("易支付 充值订单创建成功 user_id=%d trade_no=%s payment_method=%s amount=%d money=%.2f uri=%q params=%q", id, tradeNo, req.PaymentMethod, req.Amount, payMoney, uri, common.GetJsonString(params)))
c.JSON(http.StatusOK, gin.H{"message": "success", "data": params, "url": uri})
}
// tradeNo lock
@@ -281,12 +309,18 @@ func UnlockOrder(tradeNo string) {
}
func EpayNotify(c *gin.Context) {
if !isEpayWebhookEnabled() {
logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 webhook 被拒绝 reason=webhook_disabled path=%q client_ip=%s", c.Request.RequestURI, c.ClientIP()))
_, _ = c.Writer.Write([]byte("fail"))
return
}
var params map[string]string
if c.Request.Method == "POST" {
// POST 请求:从 POST body 解析参数
if err := c.Request.ParseForm(); err != nil {
log.Println("易支付回调POST解析失败:", err)
logger.LogError(c.Request.Context(), fmt.Sprintf("易支付 webhook POST 表单解析失败 path=%q client_ip=%s error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
_, _ = c.Writer.Write([]byte("fail"))
return
}
@@ -301,50 +335,63 @@ func EpayNotify(c *gin.Context) {
return r
}, map[string]string{})
}
logger.LogInfo(c.Request.Context(), fmt.Sprintf("易支付 webhook 收到请求 path=%q client_ip=%s method=%s params=%q", c.Request.RequestURI, c.ClientIP(), c.Request.Method, common.GetJsonString(params)))
if len(params) == 0 {
log.Println("易支付回调参数为空")
logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 webhook 参数为空 path=%q client_ip=%s", c.Request.RequestURI, c.ClientIP()))
_, _ = c.Writer.Write([]byte("fail"))
return
}
client := GetEpayClient()
if client == nil {
log.Println("易支付回调失败 未找到配置信息")
logger.LogError(c.Request.Context(), fmt.Sprintf("易支付 client 未初始化 path=%q client_ip=%s", c.Request.RequestURI, c.ClientIP()))
_, err := c.Writer.Write([]byte("fail"))
if err != nil {
log.Println("易支付回调写入失败")
logger.LogError(c.Request.Context(), fmt.Sprintf("易支付 webhook 响应写入失败 path=%q client_ip=%s error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
}
return
}
verifyInfo, err := client.Verify(params)
if err == nil && verifyInfo.VerifyStatus {
logger.LogInfo(c.Request.Context(), fmt.Sprintf("易支付 webhook 验签成功 trade_no=%s callback_type=%s trade_status=%s client_ip=%s verify_info=%q", verifyInfo.ServiceTradeNo, verifyInfo.Type, verifyInfo.TradeStatus, c.ClientIP(), common.GetJsonString(verifyInfo)))
_, err := c.Writer.Write([]byte("success"))
if err != nil {
log.Println("易支付回调写入失败")
logger.LogError(c.Request.Context(), fmt.Sprintf("易支付 webhook 响应写入失败 trade_no=%s client_ip=%s error=%q", verifyInfo.ServiceTradeNo, c.ClientIP(), err.Error()))
}
} else {
_, err := c.Writer.Write([]byte("fail"))
if err != nil {
log.Println("易支付回调写入失败")
logger.LogError(c.Request.Context(), fmt.Sprintf("易支付 webhook 响应写入失败 path=%q client_ip=%s error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
}
if err != nil {
logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 webhook 验签失败 path=%q client_ip=%s verify_error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
} else {
logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 webhook 验签失败 path=%q client_ip=%s verify_status=false", c.Request.RequestURI, c.ClientIP()))
}
log.Println("易支付回调签名验证失败")
return
}
if verifyInfo.TradeStatus == epay.StatusTradeSuccess {
log.Println(verifyInfo)
LockOrder(verifyInfo.ServiceTradeNo)
defer UnlockOrder(verifyInfo.ServiceTradeNo)
topUp := model.GetTopUpByTradeNo(verifyInfo.ServiceTradeNo)
if topUp == nil {
log.Printf("易支付回调未找到订单: %v", verifyInfo)
logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 回调订单不存在 trade_no=%s callback_type=%s client_ip=%s verify_info=%q", verifyInfo.ServiceTradeNo, verifyInfo.Type, c.ClientIP(), common.GetJsonString(verifyInfo)))
return
}
if topUp.Status == "pending" {
topUp.Status = "success"
if isNonEpayPaymentMethodForEpayCallback(topUp.PaymentMethod) {
logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 订单支付方式不匹配 trade_no=%s order_payment_method=%s callback_type=%s client_ip=%s", verifyInfo.ServiceTradeNo, topUp.PaymentMethod, verifyInfo.Type, c.ClientIP()))
return
}
if topUp.PaymentMethod != verifyInfo.Type {
logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 订单支付方式不匹配 trade_no=%s order_payment_method=%s callback_type=%s client_ip=%s", verifyInfo.ServiceTradeNo, topUp.PaymentMethod, verifyInfo.Type, c.ClientIP()))
return
}
if topUp.Status == common.TopUpStatusPending {
topUp.Status = common.TopUpStatusSuccess
err := topUp.Update()
if err != nil {
log.Printf("易支付回调更新订单失败: %v", topUp)
logger.LogError(c.Request.Context(), fmt.Sprintf("易支付 更新充值订单失败 trade_no=%s user_id=%d client_ip=%s error=%q topup=%q", topUp.TradeNo, topUp.UserId, c.ClientIP(), err.Error(), common.GetJsonString(topUp)))
return
}
//user, _ := model.GetUserById(topUp.UserId, false)
@@ -354,14 +401,14 @@ func EpayNotify(c *gin.Context) {
quotaToAdd := int(dAmount.Mul(dQuotaPerUnit).IntPart())
err = model.IncreaseUserQuota(topUp.UserId, quotaToAdd, true)
if err != nil {
log.Printf("易支付回调更新用户失败: %v", topUp)
logger.LogError(c.Request.Context(), fmt.Sprintf("易支付 更新用户额度失败 trade_no=%s user_id=%d client_ip=%s quota_to_add=%d error=%q topup=%q", topUp.TradeNo, topUp.UserId, c.ClientIP(), quotaToAdd, err.Error(), common.GetJsonString(topUp)))
return
}
log.Printf("易支付回调更新用户成功 %v", topUp)
model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", logger.LogQuota(quotaToAdd), topUp.Money))
logger.LogInfo(c.Request.Context(), fmt.Sprintf("易支付 充值成功 trade_no=%s user_id=%d client_ip=%s quota_to_add=%d money=%.2f topup=%q", topUp.TradeNo, topUp.UserId, c.ClientIP(), quotaToAdd, topUp.Money, common.GetJsonString(topUp)))
model.RecordTopupLog(topUp.UserId, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", logger.LogQuota(quotaToAdd), topUp.Money), c.ClientIP(), topUp.PaymentMethod, "epay")
}
} else {
log.Printf("易支付异常回调: %v", verifyInfo)
logger.LogInfo(c.Request.Context(), fmt.Sprintf("易支付 webhook 忽略事件 trade_no=%s callback_type=%s trade_status=%s client_ip=%s verify_info=%q", verifyInfo.ServiceTradeNo, verifyInfo.Type, verifyInfo.TradeStatus, c.ClientIP(), common.GetJsonString(verifyInfo)))
}
}
@@ -369,26 +416,26 @@ func RequestAmount(c *gin.Context) {
var req AmountRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
return
}
if req.Amount < getMinTopup() {
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())})
return
}
id := c.GetInt("id")
group, err := model.GetUserGroup(id, true)
if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "获取用户分组失败"})
return
}
payMoney := getPayMoney(req.Amount, group)
if payMoney <= 0.01 {
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "充值金额过低"})
return
}
c.JSON(200, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
c.JSON(http.StatusOK, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
}
func GetUserTopUps(c *gin.Context) {
@@ -457,10 +504,9 @@ func AdminCompleteTopUp(c *gin.Context) {
LockOrder(req.TradeNo)
defer UnlockOrder(req.TradeNo)
if err := model.ManualCompleteTopUp(req.TradeNo); err != nil {
if err := model.ManualCompleteTopUp(req.TradeNo, c.ClientIP()); err != nil {
common.ApiError(c, err)
return
}
common.ApiSuccess(c, nil)
}
+63 -71
View File
@@ -2,6 +2,7 @@ package controller
import (
"bytes"
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
@@ -9,10 +10,10 @@ import (
"errors"
"fmt"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting"
"io"
"log"
"net/http"
"time"
@@ -20,10 +21,7 @@ import (
"github.com/thanhpk/randstr"
)
const (
PaymentMethodCreem = "creem"
CreemSignatureHeader = "creem-signature"
)
const CreemSignatureHeader = "creem-signature"
var creemAdaptor = &CreemAdaptor{}
@@ -37,9 +35,9 @@ func generateCreemSignature(payload string, secret string) string {
// 验证Creem webhook签名
func verifyCreemSignature(payload string, signature string, secret string) bool {
if secret == "" {
log.Printf("Creem webhook secret not set")
logger.LogWarn(context.Background(), fmt.Sprintf("Creem webhook secret 未配置 test_mode=%t signature=%q body=%q", setting.CreemTestMode, signature, payload))
if setting.CreemTestMode {
log.Printf("Skip Creem webhook sign verify in test mode")
logger.LogInfo(context.Background(), fmt.Sprintf("Creem webhook 验签已跳过 reason=test_mode signature=%q body=%q", signature, payload))
return true
}
return false
@@ -66,13 +64,13 @@ type CreemAdaptor struct {
}
func (*CreemAdaptor) RequestPay(c *gin.Context, req *CreemPayRequest) {
if req.PaymentMethod != PaymentMethodCreem {
c.JSON(200, gin.H{"message": "error", "data": "不支持的支付渠道"})
if req.PaymentMethod != model.PaymentMethodCreem {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "不支持的支付渠道"})
return
}
if req.ProductId == "" {
c.JSON(200, gin.H{"message": "error", "data": "请选择产品"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "请选择产品"})
return
}
@@ -80,8 +78,8 @@ func (*CreemAdaptor) RequestPay(c *gin.Context, req *CreemPayRequest) {
var products []CreemProduct
err := json.Unmarshal([]byte(setting.CreemProducts), &products)
if err != nil {
log.Println("解析Creem产品列表失败", err)
c.JSON(200, gin.H{"message": "error", "data": "产品配置错误"})
logger.LogError(c.Request.Context(), fmt.Sprintf("Creem 产品配置解析失败 user_id=%d error=%q", c.GetInt("id"), err.Error()))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "产品配置错误"})
return
}
@@ -95,7 +93,7 @@ func (*CreemAdaptor) RequestPay(c *gin.Context, req *CreemPayRequest) {
}
if selectedProduct == nil {
c.JSON(200, gin.H{"message": "error", "data": "产品不存在"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "产品不存在"})
return
}
@@ -108,32 +106,32 @@ func (*CreemAdaptor) RequestPay(c *gin.Context, req *CreemPayRequest) {
// 先创建订单记录,使用产品配置的金额和充值额度
topUp := &model.TopUp{
UserId: id,
Amount: selectedProduct.Quota, // 充值额度
Money: selectedProduct.Price, // 支付金额
TradeNo: referenceId,
CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending,
UserId: id,
Amount: selectedProduct.Quota, // 充值额度
Money: selectedProduct.Price, // 支付金额
TradeNo: referenceId,
PaymentMethod: model.PaymentMethodCreem,
CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending,
}
err = topUp.Insert()
if err != nil {
log.Printf("创建Creem订单失败: %v", err)
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
logger.LogError(c.Request.Context(), fmt.Sprintf("Creem 创建充值订单失败 user_id=%d trade_no=%s product_id=%s error=%q", id, referenceId, selectedProduct.ProductId, err.Error()))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"})
return
}
// 创建支付链接,传入用户邮箱
checkoutUrl, err := genCreemLink(referenceId, selectedProduct, user.Email, user.Username)
checkoutUrl, err := genCreemLink(c.Request.Context(), referenceId, selectedProduct, user.Email, user.Username)
if err != nil {
log.Printf("获取Creem支付链接失败: %v", err)
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
logger.LogError(c.Request.Context(), fmt.Sprintf("Creem 创建支付链接失败 user_id=%d trade_no=%s product_id=%s error=%q", id, referenceId, selectedProduct.ProductId, err.Error()))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉起支付失败"})
return
}
log.Printf("Creem订单创建成功 - 用户ID: %d, 订单号: %s, 产品: %s, 充值额度: %d, 支付金额: %.2f",
id, referenceId, selectedProduct.Name, selectedProduct.Quota, selectedProduct.Price)
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem 充值订单创建成功 user_id=%d trade_no=%s product_id=%s product_name=%q quota=%d money=%.2f", id, referenceId, selectedProduct.ProductId, selectedProduct.Name, selectedProduct.Quota, selectedProduct.Price))
c.JSON(200, gin.H{
c.JSON(http.StatusOK, gin.H{
"message": "success",
"data": gin.H{
"checkout_url": checkoutUrl,
@@ -148,20 +146,19 @@ func RequestCreemPay(c *gin.Context) {
// 读取body内容用于打印,同时保留原始数据供后续使用
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
log.Printf("read creem pay req body err: %v", err)
c.JSON(200, gin.H{"message": "error", "data": "read query error"})
logger.LogError(c.Request.Context(), fmt.Sprintf("Creem 支付请求读取失败 error=%q", err.Error()))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "read query error"})
return
}
// 打印body内容
log.Printf("creem pay request body: %s", string(bodyBytes))
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem 支付请求已收到 user_id=%d body=%q", c.GetInt("id"), string(bodyBytes)))
// 重新设置body供后续的ShouldBindJSON使用
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
err = c.ShouldBindJSON(&req)
if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
return
}
creemAdaptor.RequestPay(c, &req)
@@ -229,35 +226,37 @@ type CreemWebhookEvent struct {
}
func CreemWebhook(c *gin.Context) {
if !isCreemWebhookEnabled() {
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Creem webhook 被拒绝 reason=webhook_disabled path=%q client_ip=%s", c.Request.RequestURI, c.ClientIP()))
c.AbortWithStatus(http.StatusForbidden)
return
}
// 读取body内容用于打印,同时保留原始数据供后续使用
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
log.Printf("读取Creem Webhook请求body失败: %v", err)
logger.LogError(c.Request.Context(), fmt.Sprintf("Creem webhook 读取请求体失败 path=%q client_ip=%s error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
c.AbortWithStatus(http.StatusBadRequest)
return
}
// 获取签名头
signature := c.GetHeader(CreemSignatureHeader)
// 打印关键信息(避免输出完整敏感payload)
log.Printf("Creem Webhook - URI: %s", c.Request.RequestURI)
if setting.CreemTestMode {
log.Printf("Creem Webhook - Signature: %s , Body: %s", signature, bodyBytes)
} else if signature == "" {
log.Printf("Creem Webhook缺少签名头")
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem webhook 收到请求 path=%q client_ip=%s signature=%q body=%q", c.Request.RequestURI, c.ClientIP(), signature, string(bodyBytes)))
if signature == "" {
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Creem webhook 缺少签名 path=%q client_ip=%s body=%q", c.Request.RequestURI, c.ClientIP(), string(bodyBytes)))
c.AbortWithStatus(http.StatusUnauthorized)
return
}
// 验证签名
if !verifyCreemSignature(string(bodyBytes), signature, setting.CreemWebhookSecret) {
log.Printf("Creem Webhook签名验证失败")
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Creem webhook 验签失败 path=%q client_ip=%s signature=%q body=%q", c.Request.RequestURI, c.ClientIP(), signature, string(bodyBytes)))
c.AbortWithStatus(http.StatusUnauthorized)
return
}
log.Printf("Creem Webhook签名验证成功")
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem webhook 验签成功 path=%q client_ip=%s", c.Request.RequestURI, c.ClientIP()))
// 重新设置body供后续的ShouldBindJSON使用
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
@@ -265,19 +264,19 @@ func CreemWebhook(c *gin.Context) {
// 解析新格式的webhook数据
var webhookEvent CreemWebhookEvent
if err := c.ShouldBindJSON(&webhookEvent); err != nil {
log.Printf("解析Creem Webhook参数失败: %v", err)
logger.LogError(c.Request.Context(), fmt.Sprintf("Creem webhook 解析失败 path=%q client_ip=%s error=%q body=%q", c.Request.RequestURI, c.ClientIP(), err.Error(), string(bodyBytes)))
c.AbortWithStatus(http.StatusBadRequest)
return
}
log.Printf("Creem Webhook解析成功 - EventType: %s, EventId: %s", webhookEvent.EventType, webhookEvent.Id)
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem webhook 解析成功 event_type=%s event_id=%s request_id=%s order_id=%s order_status=%s", webhookEvent.EventType, webhookEvent.Id, webhookEvent.Object.RequestId, webhookEvent.Object.Order.Id, webhookEvent.Object.Order.Status))
// 根据事件类型处理不同的webhook
switch webhookEvent.EventType {
case "checkout.completed":
handleCheckoutCompleted(c, &webhookEvent)
default:
log.Printf("忽略Creem Webhook事件类型: %s", webhookEvent.EventType)
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem webhook 忽略事件 event_type=%s event_id=%s", webhookEvent.EventType, webhookEvent.Id))
c.Status(http.StatusOK)
}
}
@@ -286,7 +285,7 @@ func CreemWebhook(c *gin.Context) {
func handleCheckoutCompleted(c *gin.Context, event *CreemWebhookEvent) {
// 验证订单状态
if event.Object.Order.Status != "paid" {
log.Printf("订单状态不是已支付: %s, 跳过处理", event.Object.Order.Status)
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem 订单状态未支付,忽略处理 request_id=%s order_id=%s order_status=%s", event.Object.RequestId, event.Object.Order.Id, event.Object.Order.Status))
c.Status(http.StatusOK)
return
}
@@ -294,7 +293,7 @@ func handleCheckoutCompleted(c *gin.Context, event *CreemWebhookEvent) {
// 获取引用ID(这是我们创建订单时传递的request_id)
referenceId := event.Object.RequestId
if referenceId == "" {
log.Println("Creem Webhook缺少request_id字段")
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Creem webhook 缺少 request_id event_id=%s order_id=%s", event.Id, event.Object.Order.Id))
c.AbortWithStatus(http.StatusBadRequest)
return
}
@@ -302,40 +301,35 @@ func handleCheckoutCompleted(c *gin.Context, event *CreemWebhookEvent) {
// Try complete subscription order first
LockOrder(referenceId)
defer UnlockOrder(referenceId)
if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(event)); err == nil {
if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(event), model.PaymentMethodCreem); err == nil {
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem 订阅订单处理成功 trade_no=%s creem_order_id=%s", referenceId, event.Object.Order.Id))
c.Status(http.StatusOK)
return
} else if err != nil && !errors.Is(err, model.ErrSubscriptionOrderNotFound) {
log.Printf("Creem订阅订单处理失败: %s, 订单号: %s", err.Error(), referenceId)
logger.LogError(c.Request.Context(), fmt.Sprintf("Creem 订阅订单处理失败 trade_no=%s creem_order_id=%s error=%q", referenceId, event.Object.Order.Id, err.Error()))
c.AbortWithStatus(http.StatusInternalServerError)
return
}
// 验证订单类型,目前只处理一次性付款(充值)
if event.Object.Order.Type != "onetime" {
log.Printf("暂不支持订单类型: %s, 跳过处理", event.Object.Order.Type)
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem 暂不支持订单类型,忽略处理 request_id=%s creem_order_id=%s order_type=%s", referenceId, event.Object.Order.Id, event.Object.Order.Type))
c.Status(http.StatusOK)
return
}
// 记录详细的支付信息
log.Printf("处理Creem支付完成 - 订单号: %s, Creem订单ID: %s, 支付金额: %d %s, 客户邮箱: <redacted>, 产品: %s",
referenceId,
event.Object.Order.Id,
event.Object.Order.AmountPaid,
event.Object.Order.Currency,
event.Object.Product.Name)
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem 支付完成回调 trade_no=%s creem_order_id=%s amount_paid=%d currency=%s product_name=%q customer_email=%q customer_name=%q", referenceId, event.Object.Order.Id, event.Object.Order.AmountPaid, event.Object.Order.Currency, event.Object.Product.Name, event.Object.Customer.Email, event.Object.Customer.Name))
// 查询本地订单确认存在
topUp := model.GetTopUpByTradeNo(referenceId)
if topUp == nil {
log.Printf("Creem充值订单不存在: %s", referenceId)
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Creem 充值订单不存在 trade_no=%s creem_order_id=%s", referenceId, event.Object.Order.Id))
c.AbortWithStatus(http.StatusBadRequest)
return
}
if topUp.Status != common.TopUpStatusPending {
log.Printf("Creem充值订单状态错误: %s, 当前状态: %s", referenceId, topUp.Status)
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem 充值订单状态非 pending,忽略处理 trade_no=%s status=%s creem_order_id=%s", referenceId, topUp.Status, event.Object.Order.Id))
c.Status(http.StatusOK) // 已处理过的订单,返回成功避免重复处理
return
}
@@ -346,21 +340,20 @@ func handleCheckoutCompleted(c *gin.Context, event *CreemWebhookEvent) {
// 防护性检查,确保邮箱和姓名不为空字符串
if customerEmail == "" {
log.Printf("警告:Creem回调客户邮箱为空 - 订单号: %s", referenceId)
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Creem 回调客户邮箱为空 trade_no=%s creem_order_id=%s", referenceId, event.Object.Order.Id))
}
if customerName == "" {
log.Printf("警告:Creem回调客户姓名为空 - 订单号: %s", referenceId)
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Creem 回调客户姓名为空 trade_no=%s creem_order_id=%s", referenceId, event.Object.Order.Id))
}
err := model.RechargeCreem(referenceId, customerEmail, customerName)
err := model.RechargeCreem(referenceId, customerEmail, customerName, c.ClientIP())
if err != nil {
log.Printf("Creem充值处理失败: %s, 订单号: %s", err.Error(), referenceId)
logger.LogError(c.Request.Context(), fmt.Sprintf("Creem 充值处理失败 trade_no=%s creem_order_id=%s client_ip=%s error=%q", referenceId, event.Object.Order.Id, c.ClientIP(), err.Error()))
c.AbortWithStatus(http.StatusInternalServerError)
return
}
log.Printf("Creem充值成功 - 订单号: %s, 充值额度: %d, 支付金额: %.2f",
referenceId, topUp.Amount, topUp.Money)
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem 充值成功 trade_no=%s creem_order_id=%s quota=%d money=%.2f client_ip=%s", referenceId, event.Object.Order.Id, topUp.Amount, topUp.Money, c.ClientIP()))
c.Status(http.StatusOK)
}
@@ -378,7 +371,7 @@ type CreemCheckoutResponse struct {
Id string `json:"id"`
}
func genCreemLink(referenceId string, product *CreemProduct, email string, username string) (string, error) {
func genCreemLink(ctx context.Context, referenceId string, product *CreemProduct, email string, username string) (string, error) {
if setting.CreemApiKey == "" {
return "", fmt.Errorf("未配置Creem API密钥")
}
@@ -387,7 +380,7 @@ func genCreemLink(referenceId string, product *CreemProduct, email string, usern
apiUrl := "https://api.creem.io/v1/checkouts"
if setting.CreemTestMode {
apiUrl = "https://test-api.creem.io/v1/checkouts"
log.Printf("使用Creem测试环境: %s", apiUrl)
logger.LogInfo(ctx, fmt.Sprintf("Creem 使用测试环境 api_url=%s", apiUrl))
}
// 构建请求数据,确保包含用户邮箱
@@ -423,8 +416,7 @@ func genCreemLink(referenceId string, product *CreemProduct, email string, usern
req.Header.Set("Content-Type", "application/json")
req.Header.Set("x-api-key", setting.CreemApiKey)
log.Printf("发送Creem支付请求 - URL: %s, 产品ID: %s, 用户邮箱: %s, 订单号: %s",
apiUrl, product.ProductId, email, referenceId)
logger.LogInfo(ctx, fmt.Sprintf("Creem 支付请求已发送 api_url=%s product_id=%s email=%q trade_no=%s", apiUrl, product.ProductId, email, referenceId))
// 发送请求
client := &http.Client{
@@ -442,7 +434,7 @@ func genCreemLink(referenceId string, product *CreemProduct, email string, usern
return "", fmt.Errorf("读取响应失败: %v", err)
}
log.Printf("Creem API resp - status code: %d, resp: %s", resp.StatusCode, string(body))
logger.LogInfo(ctx, fmt.Sprintf("Creem API 响应已收到 trade_no=%s status_code=%d body=%q", referenceId, resp.StatusCode, string(body)))
// 检查响应状态
if resp.StatusCode/100 != 2 {
@@ -459,6 +451,6 @@ func genCreemLink(referenceId string, product *CreemProduct, email string, usern
return "", fmt.Errorf("Creem API resp no checkout url ")
}
log.Printf("Creem 支付链接创建成功 - 订单号: %s, 支付链接: %s", referenceId, checkoutResp.CheckoutUrl)
logger.LogInfo(ctx, fmt.Sprintf("Creem 支付链接创建成功 trade_no=%s response_id=%s checkout_url=%q", referenceId, checkoutResp.Id, checkoutResp.CheckoutUrl))
return checkoutResp.CheckoutUrl, nil
}
+31
View File
@@ -0,0 +1,31 @@
package controller
import (
"testing"
"github.com/QuantumNous/new-api/model"
)
func TestIsNonEpayPaymentMethodForEpayCallback(t *testing.T) {
testCases := []struct {
name string
paymentMethod string
expectedBlocked bool
}{
{name: "stripe", paymentMethod: model.PaymentMethodStripe, expectedBlocked: true},
{name: "creem", paymentMethod: model.PaymentMethodCreem, expectedBlocked: true},
{name: "waffo", paymentMethod: model.PaymentMethodWaffo, expectedBlocked: true},
{name: "waffo pancake", paymentMethod: model.PaymentMethodWaffoPancake, expectedBlocked: true},
{name: "alipay", paymentMethod: "alipay", expectedBlocked: false},
{name: "wxpay", paymentMethod: "wxpay", expectedBlocked: false},
{name: "custom epay type", paymentMethod: "custom1", expectedBlocked: false},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if actual := isNonEpayPaymentMethodForEpayCallback(tc.paymentMethod); actual != tc.expectedBlocked {
t.Fatalf("expected blocked=%v, got %v for payment method %q", tc.expectedBlocked, actual, tc.paymentMethod)
}
})
}
}
+122 -52
View File
@@ -1,16 +1,17 @@
package controller
import (
"context"
"errors"
"fmt"
"io"
"log"
"net/http"
"strconv"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting"
"github.com/QuantumNous/new-api/setting/operation_setting"
@@ -23,10 +24,6 @@ import (
"github.com/thanhpk/randstr"
)
const (
PaymentMethodStripe = "stripe"
)
var stripeAdaptor = &StripeAdaptor{}
// StripePayRequest represents a payment request for Stripe checkout.
@@ -48,34 +45,34 @@ type StripeAdaptor struct {
func (*StripeAdaptor) RequestAmount(c *gin.Context, req *StripePayRequest) {
if req.Amount < getStripeMinTopup() {
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getStripeMinTopup())})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getStripeMinTopup())})
return
}
id := c.GetInt("id")
group, err := model.GetUserGroup(id, true)
if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "获取用户分组失败"})
return
}
payMoney := getStripePayMoney(float64(req.Amount), group)
if payMoney <= 0.01 {
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "充值金额过低"})
return
}
c.JSON(200, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
c.JSON(http.StatusOK, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
}
func (*StripeAdaptor) RequestPay(c *gin.Context, req *StripePayRequest) {
if req.PaymentMethod != PaymentMethodStripe {
c.JSON(200, gin.H{"message": "error", "data": "不支持的支付渠道"})
if req.PaymentMethod != model.PaymentMethodStripe {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "不支持的支付渠道"})
return
}
if req.Amount < getStripeMinTopup() {
c.JSON(200, gin.H{"message": fmt.Sprintf("充值数量不能小于 %d", getStripeMinTopup()), "data": 10})
c.JSON(http.StatusOK, gin.H{"message": fmt.Sprintf("充值数量不能小于 %d", getStripeMinTopup()), "data": 10})
return
}
if req.Amount > 10000 {
c.JSON(200, gin.H{"message": "充值数量不能大于 10000", "data": 10})
c.JSON(http.StatusOK, gin.H{"message": "充值数量不能大于 10000", "data": 10})
return
}
@@ -98,8 +95,8 @@ func (*StripeAdaptor) RequestPay(c *gin.Context, req *StripePayRequest) {
payLink, err := genStripeLink(referenceId, user.StripeCustomer, user.Email, req.Amount, req.SuccessURL, req.CancelURL)
if err != nil {
log.Println("获取Stripe Checkout支付链接失败", err)
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
logger.LogError(c.Request.Context(), fmt.Sprintf("Stripe 创建 Checkout Session 失败 user_id=%d trade_no=%s amount=%d error=%q", id, referenceId, req.Amount, err.Error()))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉起支付失败"})
return
}
@@ -108,16 +105,18 @@ func (*StripeAdaptor) RequestPay(c *gin.Context, req *StripePayRequest) {
Amount: req.Amount,
Money: chargedMoney,
TradeNo: referenceId,
PaymentMethod: PaymentMethodStripe,
PaymentMethod: model.PaymentMethodStripe,
CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending,
}
err = topUp.Insert()
if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
logger.LogError(c.Request.Context(), fmt.Sprintf("Stripe 创建充值订单失败 user_id=%d trade_no=%s amount=%d error=%q", id, referenceId, req.Amount, err.Error()))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"})
return
}
c.JSON(200, gin.H{
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Stripe 充值订单创建成功 user_id=%d trade_no=%s amount=%d money=%.2f", id, referenceId, req.Amount, chargedMoney))
c.JSON(http.StatusOK, gin.H{
"message": "success",
"data": gin.H{
"pay_link": payLink,
@@ -129,7 +128,7 @@ func RequestStripeAmount(c *gin.Context) {
var req StripePayRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
return
}
stripeAdaptor.RequestAmount(c, &req)
@@ -139,54 +138,130 @@ func RequestStripePay(c *gin.Context) {
var req StripePayRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
return
}
stripeAdaptor.RequestPay(c, &req)
}
func StripeWebhook(c *gin.Context) {
ctx := c.Request.Context()
if !isStripeWebhookEnabled() {
logger.LogWarn(ctx, fmt.Sprintf("Stripe webhook 被拒绝 reason=webhook_disabled path=%q client_ip=%s", c.Request.RequestURI, c.ClientIP()))
c.AbortWithStatus(http.StatusForbidden)
return
}
payload, err := io.ReadAll(c.Request.Body)
if err != nil {
log.Printf("解析Stripe Webhook参数失败: %v\n", err)
logger.LogError(ctx, fmt.Sprintf("Stripe webhook 读取请求体失败 path=%q client_ip=%s error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
c.AbortWithStatus(http.StatusServiceUnavailable)
return
}
signature := c.GetHeader("Stripe-Signature")
endpointSecret := setting.StripeWebhookSecret
event, err := webhook.ConstructEventWithOptions(payload, signature, endpointSecret, webhook.ConstructEventOptions{
logger.LogInfo(ctx, fmt.Sprintf("Stripe webhook 收到请求 path=%q client_ip=%s signature=%q body=%q", c.Request.RequestURI, c.ClientIP(), signature, string(payload)))
event, err := webhook.ConstructEventWithOptions(payload, signature, setting.StripeWebhookSecret, webhook.ConstructEventOptions{
IgnoreAPIVersionMismatch: true,
})
if err != nil {
log.Printf("Stripe Webhook验签失败: %v\n", err)
logger.LogWarn(ctx, fmt.Sprintf("Stripe webhook 验签失败 path=%q client_ip=%s error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
c.AbortWithStatus(http.StatusBadRequest)
return
}
callerIp := c.ClientIP()
logger.LogInfo(ctx, fmt.Sprintf("Stripe webhook 验签成功 event_type=%s client_ip=%s path=%q", string(event.Type), callerIp, c.Request.RequestURI))
switch event.Type {
case stripe.EventTypeCheckoutSessionCompleted:
sessionCompleted(event)
sessionCompleted(ctx, event, callerIp)
case stripe.EventTypeCheckoutSessionExpired:
sessionExpired(event)
sessionExpired(ctx, event)
case stripe.EventTypeCheckoutSessionAsyncPaymentSucceeded:
sessionAsyncPaymentSucceeded(ctx, event, callerIp)
case stripe.EventTypeCheckoutSessionAsyncPaymentFailed:
sessionAsyncPaymentFailed(ctx, event, callerIp)
default:
log.Printf("不支持的Stripe Webhook事件类型: %s\n", event.Type)
logger.LogInfo(ctx, fmt.Sprintf("Stripe webhook 忽略事件 event_type=%s client_ip=%s", string(event.Type), callerIp))
}
c.Status(http.StatusOK)
}
func sessionCompleted(event stripe.Event) {
func sessionCompleted(ctx context.Context, event stripe.Event, callerIp string) {
customerId := event.GetObjectValue("customer")
referenceId := event.GetObjectValue("client_reference_id")
status := event.GetObjectValue("status")
if "complete" != status {
log.Println("错误的Stripe Checkout完成状态:", status, ",", referenceId)
logger.LogWarn(ctx, fmt.Sprintf("Stripe checkout.completed 状态异常,忽略处理 trade_no=%s status=%s client_ip=%s", referenceId, status, callerIp))
return
}
paymentStatus := event.GetObjectValue("payment_status")
if paymentStatus != "paid" {
logger.LogInfo(ctx, fmt.Sprintf("Stripe Checkout 支付未完成,等待异步结果 trade_no=%s payment_status=%s client_ip=%s", referenceId, paymentStatus, callerIp))
return
}
fulfillOrder(ctx, event, referenceId, customerId, callerIp)
}
// sessionAsyncPaymentSucceeded handles delayed payment methods (bank transfer, SEPA, etc.)
// that confirm payment after the checkout session completes.
func sessionAsyncPaymentSucceeded(ctx context.Context, event stripe.Event, callerIp string) {
customerId := event.GetObjectValue("customer")
referenceId := event.GetObjectValue("client_reference_id")
logger.LogInfo(ctx, fmt.Sprintf("Stripe 异步支付成功 trade_no=%s client_ip=%s", referenceId, callerIp))
fulfillOrder(ctx, event, referenceId, customerId, callerIp)
}
// sessionAsyncPaymentFailed marks orders as failed when delayed payment methods
// ultimately fail (e.g. bank transfer not received, SEPA rejected).
func sessionAsyncPaymentFailed(ctx context.Context, event stripe.Event, callerIp string) {
referenceId := event.GetObjectValue("client_reference_id")
logger.LogWarn(ctx, fmt.Sprintf("Stripe 异步支付失败 trade_no=%s client_ip=%s", referenceId, callerIp))
if len(referenceId) == 0 {
logger.LogWarn(ctx, fmt.Sprintf("Stripe 异步支付失败事件缺少订单号 client_ip=%s", callerIp))
return
}
LockOrder(referenceId)
defer UnlockOrder(referenceId)
topUp := model.GetTopUpByTradeNo(referenceId)
if topUp == nil {
logger.LogWarn(ctx, fmt.Sprintf("Stripe 异步支付失败但本地订单不存在 trade_no=%s client_ip=%s", referenceId, callerIp))
return
}
if topUp.PaymentMethod != model.PaymentMethodStripe {
logger.LogWarn(ctx, fmt.Sprintf("Stripe 异步支付失败但订单支付方式不匹配 trade_no=%s payment_method=%s client_ip=%s", referenceId, topUp.PaymentMethod, callerIp))
return
}
if topUp.Status != common.TopUpStatusPending {
logger.LogInfo(ctx, fmt.Sprintf("Stripe 异步支付失败但订单状态非 pending,忽略处理 trade_no=%s status=%s client_ip=%s", referenceId, topUp.Status, callerIp))
return
}
topUp.Status = common.TopUpStatusFailed
if err := topUp.Update(); err != nil {
logger.LogError(ctx, fmt.Sprintf("Stripe 标记充值订单失败状态失败 trade_no=%s client_ip=%s error=%q", referenceId, callerIp, err.Error()))
return
}
logger.LogInfo(ctx, fmt.Sprintf("Stripe 充值订单已标记为失败 trade_no=%s client_ip=%s", referenceId, callerIp))
}
// fulfillOrder is the shared logic for crediting quota after payment is confirmed.
func fulfillOrder(ctx context.Context, event stripe.Event, referenceId string, customerId string, callerIp string) {
if len(referenceId) == 0 {
logger.LogWarn(ctx, fmt.Sprintf("Stripe 完成订单时缺少订单号 client_ip=%s", callerIp))
return
}
// Try complete subscription order first
LockOrder(referenceId)
defer UnlockOrder(referenceId)
payload := map[string]any{
@@ -195,65 +270,60 @@ func sessionCompleted(event stripe.Event) {
"currency": strings.ToUpper(event.GetObjectValue("currency")),
"event_type": string(event.Type),
}
if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(payload)); err == nil {
if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(payload), model.PaymentMethodStripe); err == nil {
logger.LogInfo(ctx, fmt.Sprintf("Stripe 订阅订单处理成功 trade_no=%s event_type=%s client_ip=%s", referenceId, string(event.Type), callerIp))
return
} else if err != nil && !errors.Is(err, model.ErrSubscriptionOrderNotFound) {
log.Println("complete subscription order failed:", err.Error(), referenceId)
logger.LogError(ctx, fmt.Sprintf("Stripe 订阅订单处理失败 trade_no=%s event_type=%s client_ip=%s error=%q", referenceId, string(event.Type), callerIp, err.Error()))
return
}
err := model.Recharge(referenceId, customerId)
err := model.Recharge(referenceId, customerId, callerIp)
if err != nil {
log.Println(err.Error(), referenceId)
logger.LogError(ctx, fmt.Sprintf("Stripe 充值处理失败 trade_no=%s event_type=%s client_ip=%s error=%q", referenceId, string(event.Type), callerIp, err.Error()))
return
}
total, _ := strconv.ParseFloat(event.GetObjectValue("amount_total"), 64)
currency := strings.ToUpper(event.GetObjectValue("currency"))
log.Printf("收到款项:%s, %.2f(%s)", referenceId, total/100, currency)
logger.LogInfo(ctx, fmt.Sprintf("Stripe 充值成功 trade_no=%s amount_total=%.2f currency=%s event_type=%s client_ip=%s", referenceId, total/100, currency, string(event.Type), callerIp))
}
func sessionExpired(event stripe.Event) {
func sessionExpired(ctx context.Context, event stripe.Event) {
referenceId := event.GetObjectValue("client_reference_id")
status := event.GetObjectValue("status")
if "expired" != status {
log.Println("错误的Stripe Checkout过期状态:", status, ",", referenceId)
logger.LogWarn(ctx, fmt.Sprintf("Stripe checkout.expired 状态异常,忽略处理 trade_no=%s status=%s", referenceId, status))
return
}
if len(referenceId) == 0 {
log.Println("未提供支付单号")
logger.LogWarn(ctx, "Stripe checkout.expired 缺少订单号")
return
}
// Subscription order expiration
LockOrder(referenceId)
defer UnlockOrder(referenceId)
if err := model.ExpireSubscriptionOrder(referenceId); err == nil {
if err := model.ExpireSubscriptionOrder(referenceId, model.PaymentMethodStripe); err == nil {
logger.LogInfo(ctx, fmt.Sprintf("Stripe 订阅订单已过期 trade_no=%s", referenceId))
return
} else if err != nil && !errors.Is(err, model.ErrSubscriptionOrderNotFound) {
log.Println("过期订阅订单失败", referenceId, ", err:", err.Error())
logger.LogError(ctx, fmt.Sprintf("Stripe 订阅订单过期处理失败 trade_no=%s error=%q", referenceId, err.Error()))
return
}
topUp := model.GetTopUpByTradeNo(referenceId)
if topUp == nil {
log.Println("充值订单不存在", referenceId)
err := model.UpdatePendingTopUpStatus(referenceId, model.PaymentMethodStripe, common.TopUpStatusExpired)
if errors.Is(err, model.ErrTopUpNotFound) {
logger.LogWarn(ctx, fmt.Sprintf("Stripe 充值订单不存在,无法标记过期 trade_no=%s", referenceId))
return
}
if topUp.Status != common.TopUpStatusPending {
log.Println("充值订单状态错误", referenceId)
}
topUp.Status = common.TopUpStatusExpired
err := topUp.Update()
if err != nil {
log.Println("过期充值订单失败", referenceId, ", err:", err.Error())
logger.LogError(ctx, fmt.Sprintf("Stripe 充值订单过期处理失败 trade_no=%s error=%q", referenceId, err.Error()))
return
}
log.Println("充值订单已过期", referenceId)
logger.LogInfo(ctx, fmt.Sprintf("Stripe 充值订单已过期 trade_no=%s", referenceId))
}
// genStripeLink generates a Stripe Checkout session URL for payment.
+73 -36
View File
@@ -1,14 +1,15 @@
package controller
import (
"errors"
"fmt"
"io"
"log"
"net/http"
"strconv"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting"
@@ -99,28 +100,57 @@ type WaffoPayRequest struct {
PayMethodName string `json:"pay_method_name"` // Deprecated: 兼容旧前端,优先使用 pay_method_index
}
func RequestWaffoAmount(c *gin.Context) {
var req WaffoPayRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
return
}
waffoMinTopup := int64(setting.WaffoMinTopUp)
if req.Amount < waffoMinTopup {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", waffoMinTopup)})
return
}
id := c.GetInt("id")
group, err := model.GetUserGroup(id, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "获取用户分组失败"})
return
}
payMoney := getWaffoPayMoney(float64(req.Amount), group)
if payMoney <= 0.01 {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "充值金额过低"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
}
// RequestWaffoPay 创建 Waffo 支付订单
func RequestWaffoPay(c *gin.Context) {
if !setting.WaffoEnabled {
c.JSON(200, gin.H{"message": "error", "data": "Waffo 支付未启用"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "Waffo 支付未启用"})
return
}
var req WaffoPayRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
return
}
waffoMinTopup := int64(setting.WaffoMinTopUp)
if req.Amount < waffoMinTopup {
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", waffoMinTopup)})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", waffoMinTopup)})
return
}
id := c.GetInt("id")
user, err := model.GetUserById(id, false)
if err != nil || user == nil {
c.JSON(200, gin.H{"message": "error", "data": "用户不存在"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "用户不存在"})
return
}
@@ -131,8 +161,8 @@ func RequestWaffoPay(c *gin.Context) {
// 新协议:按索引查找
idx := *req.PayMethodIndex
if idx < 0 || idx >= len(methods) {
log.Printf("Waffo 无效的支付方式索引: %d, UserId=%d, 可用范围: [0, %d)", idx, id, len(methods))
c.JSON(200, gin.H{"message": "error", "data": "不支持的支付方式"})
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Waffo 支付方式索引无效 user_id=%d pay_method_index=%d method_count=%d", id, idx, len(methods)))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "不支持的支付方式"})
return
}
resolvedPayMethodType = methods[idx].PayMethodType
@@ -149,8 +179,8 @@ func RequestWaffoPay(c *gin.Context) {
}
}
if !valid {
log.Printf("Waffo 无效的支付方式: PayMethodType=%s, PayMethodName=%s, UserId=%d", req.PayMethodType, req.PayMethodName, id)
c.JSON(200, gin.H{"message": "error", "data": "不支持的支付方式"})
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Waffo 支付方式无效 user_id=%d pay_method_type=%s pay_method_name=%q", id, req.PayMethodType, req.PayMethodName))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "不支持的支付方式"})
return
}
}
@@ -159,7 +189,7 @@ func RequestWaffoPay(c *gin.Context) {
group, _ := model.GetUserGroup(id, true)
payMoney := getWaffoPayMoney(float64(req.Amount), group)
if payMoney < 0.01 {
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "充值金额过低"})
return
}
@@ -182,22 +212,22 @@ func RequestWaffoPay(c *gin.Context) {
Amount: amount,
Money: payMoney,
TradeNo: merchantOrderId,
PaymentMethod: "waffo",
PaymentMethod: model.PaymentMethodWaffo,
CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending,
}
if err := topUp.Insert(); err != nil {
log.Printf("Waffo 创建本地订单失败: %v", err)
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo 创建充值订单失败 user_id=%d trade_no=%s amount=%d error=%q", id, merchantOrderId, req.Amount, err.Error()))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"})
return
}
sdk, err := getWaffoSDK()
if err != nil {
log.Printf("Waffo SDK 初始化失败: %v", err)
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo SDK 初始化失败 user_id=%d trade_no=%s error=%q", id, merchantOrderId, err.Error()))
topUp.Status = common.TopUpStatusFailed
_ = topUp.Update()
c.JSON(200, gin.H{"message": "error", "data": "支付配置错误"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "支付配置错误"})
return
}
@@ -238,29 +268,29 @@ func RequestWaffoPay(c *gin.Context) {
}
resp, err := sdk.Order().Create(c.Request.Context(), createParams, nil)
if err != nil {
log.Printf("Waffo 创建订单失败: %v", err)
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo 创建订单失败 user_id=%d trade_no=%s error=%q", id, merchantOrderId, err.Error()))
topUp.Status = common.TopUpStatusFailed
_ = topUp.Update()
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉起支付失败"})
return
}
if !resp.IsSuccess() {
log.Printf("Waffo 创建订单业务失败: [%s] %s, 完整响应: %+v", resp.Code, resp.Message, resp)
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Waffo 创建订单业务失败 user_id=%d trade_no=%s code=%s message=%q response=%q", id, merchantOrderId, resp.Code, resp.Message, common.GetJsonString(resp)))
topUp.Status = common.TopUpStatusFailed
_ = topUp.Update()
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉起支付失败"})
return
}
orderData := resp.GetData()
log.Printf("Waffo 订单创建成功 - 用户: %d, 订单: %s, 金额: %.2f", id, merchantOrderId, payMoney)
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo 充值订单创建成功 user_id=%d trade_no=%s amount=%d money=%.2f pay_method_type=%s pay_method_name=%q", id, merchantOrderId, req.Amount, payMoney, resolvedPayMethodType, resolvedPayMethodName))
paymentUrl := orderData.FetchRedirectURL()
if paymentUrl == "" {
paymentUrl = orderData.OrderAction
}
c.JSON(200, gin.H{
c.JSON(http.StatusOK, gin.H{
"message": "success",
"data": gin.H{
"payment_url": paymentUrl,
@@ -287,16 +317,22 @@ type webhookSubscriptionInfo struct {
// WaffoWebhook 处理 Waffo 回调通知(支付/退款/订阅)
func WaffoWebhook(c *gin.Context) {
if !isWaffoWebhookEnabled() {
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Waffo webhook 被拒绝 reason=webhook_disabled path=%q client_ip=%s", c.Request.RequestURI, c.ClientIP()))
c.AbortWithStatus(http.StatusForbidden)
return
}
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
log.Printf("Waffo Webhook 读取 body 失败: %v", err)
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo webhook 读取请求体失败 path=%q client_ip=%s error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
c.AbortWithStatus(http.StatusBadRequest)
return
}
sdk, err := getWaffoSDK()
if err != nil {
log.Printf("Waffo Webhook SDK 初始化失败: %v", err)
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo webhook SDK 初始化失败 path=%q client_ip=%s error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
c.AbortWithStatus(http.StatusInternalServerError)
return
}
@@ -304,17 +340,18 @@ func WaffoWebhook(c *gin.Context) {
wh := sdk.Webhook()
bodyStr := string(bodyBytes)
signature := c.GetHeader("X-SIGNATURE")
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo webhook 收到请求 path=%q client_ip=%s signature=%q body=%q", c.Request.RequestURI, c.ClientIP(), signature, bodyStr))
// 验证请求签名
if !wh.VerifySignature(bodyStr, signature) {
log.Printf("Waffo webhook 签名验证失败")
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Waffo webhook 验签失败 path=%q client_ip=%s signature=%q body=%q", c.Request.RequestURI, c.ClientIP(), signature, bodyStr))
c.AbortWithStatus(http.StatusBadRequest)
return
}
var event core.WebhookEvent
if err := common.Unmarshal(bodyBytes, &event); err != nil {
log.Printf("Waffo Webhook 解析失败: %v", err)
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo webhook 解析失败 path=%q client_ip=%s error=%q body=%q", c.Request.RequestURI, c.ClientIP(), err.Error(), bodyStr))
sendWaffoWebhookResponse(c, wh, false, "invalid payload")
return
}
@@ -324,14 +361,14 @@ func WaffoWebhook(c *gin.Context) {
// 解析为扩展类型,区分普通支付和订阅支付
var payload webhookPayloadWithSubInfo
if err := common.Unmarshal(bodyBytes, &payload); err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo 支付回调载荷解析失败 event_type=%s client_ip=%s error=%q body=%q", event.EventType, c.ClientIP(), err.Error(), bodyStr))
sendWaffoWebhookResponse(c, wh, false, "invalid payment payload")
return
}
log.Printf("Waffo Webhook - EventType: %s, MerchantOrderId: %s, OrderStatus: %s",
event.EventType, payload.Result.MerchantOrderID, payload.Result.OrderStatus)
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo webhook 验签并解析成功 event_type=%s merchant_order_id=%s order_status=%s client_ip=%s", event.EventType, payload.Result.MerchantOrderID, payload.Result.OrderStatus, c.ClientIP()))
handleWaffoPayment(c, wh, &payload.Result.PaymentNotificationResult)
default:
log.Printf("Waffo Webhook 未知事件: %s", event.EventType)
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo webhook 忽略事件 event_type=%s client_ip=%s", event.EventType, c.ClientIP()))
sendWaffoWebhookResponse(c, wh, true, "")
}
}
@@ -339,13 +376,13 @@ func WaffoWebhook(c *gin.Context) {
// handleWaffoPayment 处理支付完成通知
func handleWaffoPayment(c *gin.Context, wh *core.WebhookHandler, result *core.PaymentNotificationResult) {
if result.OrderStatus != "PAY_SUCCESS" {
log.Printf("Waffo 订单状态非成功: %s, 订单: %s", result.OrderStatus, result.MerchantOrderID)
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo 订单状态非成功,忽略充值 trade_no=%s order_status=%s client_ip=%s", result.MerchantOrderID, result.OrderStatus, c.ClientIP()))
// 终态失败订单标记为 failed,避免永远停在 pending
if result.MerchantOrderID != "" {
if topUp := model.GetTopUpByTradeNo(result.MerchantOrderID); topUp != nil &&
topUp.Status == common.TopUpStatusPending {
topUp.Status = common.TopUpStatusFailed
_ = topUp.Update()
if err := model.UpdatePendingTopUpStatus(result.MerchantOrderID, model.PaymentMethodWaffo, common.TopUpStatusFailed); err != nil &&
!errors.Is(err, model.ErrTopUpNotFound) &&
!errors.Is(err, model.ErrTopUpStatusInvalid) {
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo 标记失败订单状态失败 trade_no=%s error=%q", result.MerchantOrderID, err.Error()))
}
}
sendWaffoWebhookResponse(c, wh, true, "")
@@ -357,13 +394,13 @@ func handleWaffoPayment(c *gin.Context, wh *core.WebhookHandler, result *core.Pa
LockOrder(merchantOrderId)
defer UnlockOrder(merchantOrderId)
if err := model.RechargeWaffo(merchantOrderId); err != nil {
log.Printf("Waffo 充值处理失败: %v, 订单: %s", err, merchantOrderId)
if err := model.RechargeWaffo(merchantOrderId, c.ClientIP()); err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo 充值处理失败 trade_no=%s client_ip=%s error=%q", merchantOrderId, c.ClientIP(), err.Error()))
sendWaffoWebhookResponse(c, wh, false, err.Error())
return
}
log.Printf("Waffo 充值成功 - 订单: %s", merchantOrderId)
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo 充值成功 trade_no=%s client_ip=%s", merchantOrderId, c.ClientIP()))
sendWaffoWebhookResponse(c, wh, true, "")
}
+259
View File
@@ -0,0 +1,259 @@
package controller
import (
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting"
"github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/QuantumNous/new-api/setting/system_setting"
"github.com/gin-gonic/gin"
"github.com/shopspring/decimal"
"github.com/thanhpk/randstr"
)
type WaffoPancakePayRequest struct {
Amount int64 `json:"amount"`
}
func RequestWaffoPancakeAmount(c *gin.Context) {
var req WaffoPancakePayRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
return
}
if req.Amount < int64(setting.WaffoPancakeMinTopUp) {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", setting.WaffoPancakeMinTopUp)})
return
}
id := c.GetInt("id")
group, err := model.GetUserGroup(id, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "获取用户分组失败"})
return
}
payMoney := getWaffoPancakePayMoney(req.Amount, group)
if payMoney <= 0.01 {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "充值金额过低"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "success", "data": fmt.Sprintf("%.2f", payMoney)})
}
func getWaffoPancakePayMoney(amount int64, group string) float64 {
dAmount := decimal.NewFromInt(amount)
if operation_setting.GetQuotaDisplayType() == operation_setting.QuotaDisplayTypeTokens {
dAmount = dAmount.Div(decimal.NewFromFloat(common.QuotaPerUnit))
}
topupGroupRatio := common.GetTopupGroupRatio(group)
if topupGroupRatio == 0 {
topupGroupRatio = 1
}
discount := 1.0
if ds, ok := operation_setting.GetPaymentSetting().AmountDiscount[int(amount)]; ok && ds > 0 {
discount = ds
}
payMoney := dAmount.
Mul(decimal.NewFromFloat(setting.WaffoPancakeUnitPrice)).
Mul(decimal.NewFromFloat(topupGroupRatio)).
Mul(decimal.NewFromFloat(discount))
return payMoney.InexactFloat64()
}
func normalizeWaffoPancakeTopUpAmount(amount int64) int64 {
if operation_setting.GetQuotaDisplayType() != operation_setting.QuotaDisplayTypeTokens {
return amount
}
normalized := decimal.NewFromInt(amount).
Div(decimal.NewFromFloat(common.QuotaPerUnit)).
IntPart()
if normalized < 1 {
return 1
}
return normalized
}
func formatWaffoPancakeAmount(payMoney float64) string {
return decimal.NewFromFloat(payMoney).StringFixed(2)
}
func getWaffoPancakeBuyerEmail(user *model.User) string {
if user != nil && strings.TrimSpace(user.Email) != "" {
return user.Email
}
if user != nil {
return fmt.Sprintf("%d@new-api.local", user.Id)
}
return ""
}
func getWaffoPancakeReturnURL() string {
if strings.TrimSpace(setting.WaffoPancakeReturnURL) != "" {
return setting.WaffoPancakeReturnURL
}
return strings.TrimRight(system_setting.ServerAddress, "/") + "/console/topup?show_history=true"
}
func RequestWaffoPancakePay(c *gin.Context) {
if !setting.WaffoPancakeEnabled {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "Waffo Pancake 支付未启用"})
return
}
currentWebhookKey := setting.WaffoPancakeWebhookPublicKey
if setting.WaffoPancakeSandbox {
currentWebhookKey = setting.WaffoPancakeWebhookTestKey
}
if strings.TrimSpace(setting.WaffoPancakeMerchantID) == "" ||
strings.TrimSpace(setting.WaffoPancakePrivateKey) == "" ||
strings.TrimSpace(currentWebhookKey) == "" ||
strings.TrimSpace(setting.WaffoPancakeStoreID) == "" ||
strings.TrimSpace(setting.WaffoPancakeProductID) == "" {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "Waffo Pancake 配置不完整"})
return
}
var req WaffoPancakePayRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "参数错误"})
return
}
if req.Amount < int64(setting.WaffoPancakeMinTopUp) {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", setting.WaffoPancakeMinTopUp)})
return
}
id := c.GetInt("id")
user, err := model.GetUserById(id, false)
if err != nil || user == nil {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "用户不存在"})
return
}
group, err := model.GetUserGroup(id, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "获取用户分组失败"})
return
}
payMoney := getWaffoPancakePayMoney(req.Amount, group)
if payMoney < 0.01 {
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "充值金额过低"})
return
}
tradeNo := fmt.Sprintf("WAFFO_PANCAKE-%d-%d-%s", id, time.Now().UnixMilli(), randstr.String(6))
topUp := &model.TopUp{
UserId: id,
Amount: normalizeWaffoPancakeTopUpAmount(req.Amount),
Money: payMoney,
TradeNo: tradeNo,
PaymentMethod: model.PaymentMethodWaffoPancake,
CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending,
}
if err := topUp.Insert(); err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake 创建充值订单失败 user_id=%d trade_no=%s amount=%d error=%q", id, tradeNo, req.Amount, err.Error()))
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"})
return
}
expiresInSeconds := 45 * 60
session, err := service.CreateWaffoPancakeCheckoutSession(c.Request.Context(), &service.WaffoPancakeCreateSessionParams{
StoreID: setting.WaffoPancakeStoreID,
ProductID: setting.WaffoPancakeProductID,
ProductType: "onetime",
Currency: strings.ToUpper(strings.TrimSpace(setting.WaffoPancakeCurrency)),
PriceSnapshot: &service.WaffoPancakePriceSnapshot{
Amount: formatWaffoPancakeAmount(payMoney),
TaxIncluded: false,
TaxCategory: "saas",
},
BuyerEmail: getWaffoPancakeBuyerEmail(user),
SuccessURL: getWaffoPancakeReturnURL(),
ExpiresInSeconds: &expiresInSeconds,
})
if err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake 创建结账会话失败 user_id=%d trade_no=%s error=%q", id, tradeNo, err.Error()))
topUp.Status = common.TopUpStatusFailed
_ = topUp.Update()
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "拉起支付失败"})
return
}
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo Pancake 充值订单创建成功 user_id=%d trade_no=%s session_id=%s amount=%d money=%.2f", id, tradeNo, session.SessionID, req.Amount, payMoney))
c.JSON(http.StatusOK, gin.H{
"message": "success",
"data": gin.H{
"checkout_url": session.CheckoutURL,
"session_id": session.SessionID,
"expires_at": session.ExpiresAt,
"order_id": tradeNo,
},
})
}
func WaffoPancakeWebhook(c *gin.Context) {
if !isWaffoPancakeWebhookEnabled() {
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Waffo Pancake webhook 被拒绝 reason=webhook_disabled path=%q client_ip=%s", c.Request.RequestURI, c.ClientIP()))
c.String(http.StatusForbidden, "webhook disabled")
return
}
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake webhook 读取请求体失败 path=%q client_ip=%s error=%q", c.Request.RequestURI, c.ClientIP(), err.Error()))
c.String(http.StatusBadRequest, "bad request")
return
}
signature := c.GetHeader("X-Waffo-Signature")
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo Pancake webhook 收到请求 path=%q client_ip=%s signature=%q body=%q", c.Request.RequestURI, c.ClientIP(), signature, string(bodyBytes)))
event, err := service.VerifyConfiguredWaffoPancakeWebhook(string(bodyBytes), signature)
if err != nil {
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Waffo Pancake webhook 验签失败 path=%q client_ip=%s signature=%q body=%q error=%q", c.Request.RequestURI, c.ClientIP(), signature, string(bodyBytes), err.Error()))
c.String(http.StatusUnauthorized, "invalid signature")
return
}
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo Pancake webhook 验签成功 event_type=%s event_id=%s order_id=%s client_ip=%s", event.NormalizedEventType(), event.ID, event.Data.OrderID, c.ClientIP()))
if event.NormalizedEventType() != "order.completed" {
c.String(http.StatusOK, "OK")
return
}
tradeNo, err := service.ResolveWaffoPancakeTradeNo(event)
if err != nil {
logger.LogWarn(c.Request.Context(), fmt.Sprintf("Waffo Pancake webhook 订单号映射失败 event_id=%s order_id=%s error=%q", event.ID, event.Data.OrderID, err.Error()))
c.String(http.StatusOK, "OK")
return
}
LockOrder(tradeNo)
defer UnlockOrder(tradeNo)
if err := model.RechargeWaffoPancake(tradeNo); err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake 充值处理失败 trade_no=%s event_id=%s order_id=%s client_ip=%s error=%q", tradeNo, event.ID, event.Data.OrderID, c.ClientIP(), err.Error()))
c.String(http.StatusInternalServerError, "retry")
return
}
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo Pancake 充值成功 trade_no=%s event_id=%s order_id=%s client_ip=%s", tradeNo, event.ID, event.Data.OrderID, c.ClientIP()))
c.String(http.StatusOK, "OK")
}
+91
View File
@@ -0,0 +1,91 @@
package controller
import (
"testing"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/setting"
"github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/stretchr/testify/require"
)
func TestFormatWaffoPancakeAmount_UsesDisplayPriceString(t *testing.T) {
testCases := []struct {
name string
amount float64
expected string
}{
{name: "whole amount", amount: 29, expected: "29.00"},
{name: "decimal amount", amount: 29.9, expected: "29.90"},
{name: "round half up to cents", amount: 29.999, expected: "30.00"},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
require.Equal(t, tc.expected, formatWaffoPancakeAmount(tc.amount))
})
}
}
func TestGetWaffoPancakePayMoney(t *testing.T) {
originalUnitPrice := setting.WaffoPancakeUnitPrice
originalQuotaDisplayType := operation_setting.GetGeneralSetting().QuotaDisplayType
originalDiscounts := make(map[int]float64, len(operation_setting.GetPaymentSetting().AmountDiscount))
for k, v := range operation_setting.GetPaymentSetting().AmountDiscount {
originalDiscounts[k] = v
}
originalTopupGroupRatio := common.TopupGroupRatio2JSONString()
t.Cleanup(func() {
setting.WaffoPancakeUnitPrice = originalUnitPrice
operation_setting.GetGeneralSetting().QuotaDisplayType = originalQuotaDisplayType
operation_setting.GetPaymentSetting().AmountDiscount = originalDiscounts
require.NoError(t, common.UpdateTopupGroupRatioByJSONString(originalTopupGroupRatio))
})
setting.WaffoPancakeUnitPrice = 2.5
operation_setting.GetPaymentSetting().AmountDiscount = map[int]float64{
10: 0.8,
int(common.QuotaPerUnit * 3): 0.5,
20: 0,
}
require.NoError(t, common.UpdateTopupGroupRatioByJSONString(`{"default":1,"vip":1.2}`))
testCases := []struct {
name string
amount int64
group string
quotaDisplayType string
expected float64
}{
{
name: "currency display applies unit price group ratio and discount",
amount: 10,
group: "vip",
quotaDisplayType: operation_setting.QuotaDisplayTypeUSD,
expected: 24,
},
{
name: "tokens display converts quota to display units before pricing",
amount: int64(common.QuotaPerUnit * 3),
group: "vip",
quotaDisplayType: operation_setting.QuotaDisplayTypeTokens,
expected: 4.5,
},
{
name: "non-positive discount falls back to no discount",
amount: 20,
group: "default",
quotaDisplayType: operation_setting.QuotaDisplayTypeUSD,
expected: 50,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
operation_setting.GetGeneralSetting().QuotaDisplayType = tc.quotaDisplayType
actual := getWaffoPancakePayMoney(tc.amount, tc.group)
require.InDelta(t, tc.expected, actual, 0.000001)
})
}
}
+8 -4
View File
@@ -2,7 +2,6 @@ package controller
import (
"errors"
"fmt"
"net/http"
"strconv"
@@ -542,10 +541,15 @@ func AdminDisable2FA(c *gin.Context) {
return
}
// 记录操作日志
// 记录操作日志:管理员身份通过 admin_info 传递,避免在非管理员可见的日志内容中泄露。
adminId := c.GetInt("id")
model.RecordLog(userId, model.LogTypeManage,
fmt.Sprintf("管理员(ID:%d)强制禁用了用户的两步验证", adminId))
adminName := c.GetString("username")
adminInfo := map[string]interface{}{
"admin_id": adminId,
"admin_username": adminName,
}
model.RecordLogWithAdminInfo(userId, model.LogTypeManage,
"管理员强制禁用了用户的两步验证", adminInfo)
c.JSON(http.StatusOK, gin.H{
"success": true,
+75 -7
View File
@@ -52,10 +52,15 @@ func Login(c *gin.Context) {
}
err = user.ValidateAndFill()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"message": err.Error(),
"success": false,
})
switch {
case errors.Is(err, model.ErrDatabase):
common.SysLog(fmt.Sprintf("Login database error for user %s: %v", username, err))
common.ApiErrorI18n(c, i18n.MsgDatabaseError)
case errors.Is(err, model.ErrUserEmptyCredentials):
common.ApiErrorI18n(c, i18n.MsgInvalidParams)
default:
common.ApiErrorI18n(c, i18n.MsgUserUsernameOrPasswordError)
}
return
}
@@ -572,9 +577,6 @@ func UpdateUser(c *gin.Context) {
common.ApiError(c, err)
return
}
if originUser.Quota != updatedUser.Quota {
model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", logger.LogQuota(originUser.Quota), logger.LogQuota(updatedUser.Quota)))
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
@@ -841,6 +843,8 @@ func CreateUser(c *gin.Context) {
type ManageRequest struct {
Id int `json:"id"`
Action string `json:"action"`
Value int `json:"value"`
Mode string `json:"mode"`
}
// ManageUser Only admin user can do this
@@ -887,6 +891,11 @@ func ManageUser(c *gin.Context) {
})
return
}
// 删除用户后,强制清理 Redis 中所有该用户令牌的缓存,
// 避免已缓存的令牌在 TTL 过期前仍能通过 TokenAuth 校验。
if err := model.InvalidateUserTokensCache(user.Id); err != nil {
common.SysLog(fmt.Sprintf("failed to invalidate tokens cache for user %d: %s", user.Id, err.Error()))
}
case "promote":
if myRole != common.RoleRootUser {
common.ApiErrorI18n(c, i18n.MsgUserAdminCannotPromote)
@@ -907,12 +916,71 @@ func ManageUser(c *gin.Context) {
return
}
user.Role = common.RoleCommonUser
case "add_quota":
adminName := c.GetString("username")
adminId := c.GetInt("id")
adminInfo := map[string]interface{}{
"admin_id": adminId,
"admin_username": adminName,
}
switch req.Mode {
case "add":
if req.Value <= 0 {
common.ApiErrorI18n(c, i18n.MsgUserQuotaChangeZero)
return
}
if err := model.IncreaseUserQuota(user.Id, req.Value, true); err != nil {
common.ApiError(c, err)
return
}
model.RecordLogWithAdminInfo(user.Id, model.LogTypeManage,
fmt.Sprintf("管理员增加用户额度 %s", logger.LogQuota(req.Value)), adminInfo)
case "subtract":
if req.Value <= 0 {
common.ApiErrorI18n(c, i18n.MsgUserQuotaChangeZero)
return
}
if err := model.DecreaseUserQuota(user.Id, req.Value, true); err != nil {
common.ApiError(c, err)
return
}
model.RecordLogWithAdminInfo(user.Id, model.LogTypeManage,
fmt.Sprintf("管理员减少用户额度 %s", logger.LogQuota(req.Value)), adminInfo)
case "override":
oldQuota := user.Quota
if err := model.DB.Model(&model.User{}).Where("id = ?", user.Id).Update("quota", req.Value).Error; err != nil {
common.ApiError(c, err)
return
}
model.RecordLogWithAdminInfo(user.Id, model.LogTypeManage,
fmt.Sprintf("管理员覆盖用户额度从 %s 为 %s", logger.LogQuota(oldQuota), logger.LogQuota(req.Value)), adminInfo)
default:
common.ApiErrorI18n(c, i18n.MsgInvalidParams)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
return
}
if err := user.Update(false); err != nil {
common.ApiError(c, err)
return
}
// 禁用 / 角色调整后,强制失效用户缓存与其全部令牌缓存,
// 避免在 Redis TTL 过期前仍使用旧状态(尤其是禁用后仍可发起请求的问题)。
// InvalidateUserCache 会让下一次 GetUserCache 从数据库重新加载,
// InvalidateUserTokensCache 则确保令牌侧的缓存也同步刷新。
if req.Action == "disable" || req.Action == "promote" || req.Action == "demote" {
if err := model.InvalidateUserCache(user.Id); err != nil {
common.SysLog(fmt.Sprintf("failed to invalidate user cache for user %d: %s", user.Id, err.Error()))
}
if err := model.InvalidateUserTokensCache(user.Id); err != nil {
common.SysLog(fmt.Sprintf("failed to invalidate tokens cache for user %d: %s", user.Id, err.Error()))
}
}
clearUser := model.User{
Role: user.Role,
Status: user.Status,