fix: truncate oversized upstream error logs (#5083)
This commit is contained in:
@@ -3,6 +3,7 @@ package common
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
@@ -20,6 +21,16 @@ var (
|
||||
maskApiKeyPattern = regexp.MustCompile(`(['"]?)api_key:([^\s'"]+)(['"]?)`)
|
||||
)
|
||||
|
||||
const LocalLogContentLimit = 2048
|
||||
|
||||
// LocalLogPreview limits log-only content unless debug logging is enabled.
|
||||
func LocalLogPreview(content string) string {
|
||||
if DebugEnabled || len(content) <= LocalLogContentLimit {
|
||||
return content
|
||||
}
|
||||
return fmt.Sprintf("%s... [truncated, original_length=%d, limit=%d]", content[:LocalLogContentLimit], len(content), LocalLogContentLimit)
|
||||
}
|
||||
|
||||
func GetStringIfEmpty(str string, defaultValue string) string {
|
||||
if str == "" {
|
||||
return defaultValue
|
||||
|
||||
+2
-2
@@ -88,7 +88,7 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
||||
|
||||
defer func() {
|
||||
if newAPIError != nil {
|
||||
logger.LogError(c, fmt.Sprintf("relay error: %s", newAPIError.Error()))
|
||||
logger.LogError(c, fmt.Sprintf("relay error: %s", common.LocalLogPreview(newAPIError.Error())))
|
||||
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
|
||||
switch relayFormat {
|
||||
case types.RelayFormatOpenAIRealtime:
|
||||
@@ -354,7 +354,7 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b
|
||||
}
|
||||
|
||||
func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
|
||||
logger.LogError(c, fmt.Sprintf("channel error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
|
||||
logger.LogError(c, fmt.Sprintf("channel error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, common.LocalLogPreview(err.Error())))
|
||||
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
|
||||
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
|
||||
if service.ShouldDisableChannel(err) && channelError.AutoBan {
|
||||
|
||||
+19
-19
@@ -17,24 +17,24 @@ import (
|
||||
)
|
||||
|
||||
type Log struct {
|
||||
Id int `json:"id" gorm:"index:idx_created_at_id,priority:1;index:idx_user_id_id,priority:2"`
|
||||
UserId int `json:"user_id" gorm:"index;index:idx_user_id_id,priority:1"`
|
||||
CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:2;index:idx_created_at_type"`
|
||||
Type int `json:"type" gorm:"index:idx_created_at_type"`
|
||||
Content string `json:"content"`
|
||||
Username string `json:"username" gorm:"index;index:index_username_model_name,priority:2;default:''"`
|
||||
TokenName string `json:"token_name" gorm:"index;default:''"`
|
||||
ModelName string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"`
|
||||
Quota int `json:"quota" gorm:"default:0"`
|
||||
PromptTokens int `json:"prompt_tokens" gorm:"default:0"`
|
||||
CompletionTokens int `json:"completion_tokens" gorm:"default:0"`
|
||||
UseTime int `json:"use_time" gorm:"default:0"`
|
||||
IsStream bool `json:"is_stream"`
|
||||
ChannelId int `json:"channel" gorm:"index"`
|
||||
ChannelName string `json:"channel_name" gorm:"->"`
|
||||
TokenId int `json:"token_id" gorm:"default:0;index"`
|
||||
Group string `json:"group" gorm:"index"`
|
||||
Ip string `json:"ip" gorm:"index;default:''"`
|
||||
Id int `json:"id" gorm:"index:idx_created_at_id,priority:1;index:idx_user_id_id,priority:2"`
|
||||
UserId int `json:"user_id" gorm:"index;index:idx_user_id_id,priority:1"`
|
||||
CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:2;index:idx_created_at_type"`
|
||||
Type int `json:"type" gorm:"index:idx_created_at_type"`
|
||||
Content string `json:"content"`
|
||||
Username string `json:"username" gorm:"index;index:index_username_model_name,priority:2;default:''"`
|
||||
TokenName string `json:"token_name" gorm:"index;default:''"`
|
||||
ModelName string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"`
|
||||
Quota int `json:"quota" gorm:"default:0"`
|
||||
PromptTokens int `json:"prompt_tokens" gorm:"default:0"`
|
||||
CompletionTokens int `json:"completion_tokens" gorm:"default:0"`
|
||||
UseTime int `json:"use_time" gorm:"default:0"`
|
||||
IsStream bool `json:"is_stream"`
|
||||
ChannelId int `json:"channel" gorm:"index"`
|
||||
ChannelName string `json:"channel_name" gorm:"->"`
|
||||
TokenId int `json:"token_id" gorm:"default:0;index"`
|
||||
Group string `json:"group" gorm:"index"`
|
||||
Ip string `json:"ip" gorm:"index;default:''"`
|
||||
RequestId string `json:"request_id,omitempty" gorm:"type:varchar(64);index:idx_logs_request_id;default:''"`
|
||||
UpstreamRequestId string `json:"upstream_request_id,omitempty" gorm:"type:varchar(128);index:idx_logs_upstream_request_id;default:''"`
|
||||
Other string `json:"other"`
|
||||
@@ -145,7 +145,7 @@ func RecordTopupLog(userId int, content string, callerIp string, paymentMethod s
|
||||
|
||||
func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string, tokenName string, content string, tokenId int, useTimeSeconds int,
|
||||
isStream bool, group string, other map[string]interface{}) {
|
||||
logger.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content))
|
||||
logger.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, common.LocalLogPreview(content)))
|
||||
username := c.GetString("username")
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
upstreamRequestId := c.GetString(common.UpstreamRequestIdKey)
|
||||
|
||||
+1
-1
@@ -17,7 +17,7 @@ func formatNotifyType(channelId int, status int) string {
|
||||
|
||||
// disable & notify
|
||||
func DisableChannel(channelError types.ChannelError, reason string) {
|
||||
common.SysLog(fmt.Sprintf("通道「%s」(#%d)发生错误,准备禁用,原因:%s", channelError.ChannelName, channelError.ChannelId, reason))
|
||||
common.SysLog(fmt.Sprintf("通道「%s」(#%d)发生错误,准备禁用,原因:%s", channelError.ChannelName, channelError.ChannelId, common.LocalLogPreview(reason)))
|
||||
|
||||
// 检查是否启用自动禁用功能
|
||||
if !channelError.AutoBan {
|
||||
|
||||
+5
-3
@@ -92,11 +92,13 @@ func RelayErrorHandler(ctx context.Context, resp *http.Response, showBodyWhenFai
|
||||
}
|
||||
CloseResponseBodyGracefully(resp)
|
||||
var errResponse dto.GeneralErrorResponse
|
||||
responseBodyText := string(responseBody)
|
||||
responseBodyPreview := common.LocalLogPreview(responseBodyText)
|
||||
buildErrWithBody := func(message string) error {
|
||||
if message == "" {
|
||||
return fmt.Errorf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody))
|
||||
return fmt.Errorf("bad response status code %d, body: %s", resp.StatusCode, responseBodyText)
|
||||
}
|
||||
return fmt.Errorf("bad response status code %d, message: %s, body: %s", resp.StatusCode, message, string(responseBody))
|
||||
return fmt.Errorf("bad response status code %d, message: %s, body: %s", resp.StatusCode, message, responseBodyText)
|
||||
}
|
||||
|
||||
err = common.Unmarshal(responseBody, &errResponse)
|
||||
@@ -104,7 +106,7 @@ func RelayErrorHandler(ctx context.Context, resp *http.Response, showBodyWhenFai
|
||||
if showBodyWhenFail {
|
||||
newApiErr.Err = buildErrWithBody("")
|
||||
} else {
|
||||
logger.LogError(ctx, fmt.Sprintf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody)))
|
||||
logger.LogError(ctx, fmt.Sprintf("bad response status code %d, body: %s", resp.StatusCode, responseBodyPreview))
|
||||
newApiErr.Err = fmt.Errorf("bad response status code %d", resp.StatusCode)
|
||||
}
|
||||
return
|
||||
|
||||
@@ -1,9 +1,17 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -55,3 +63,99 @@ func TestResetStatusCode(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRelayErrorHandlerTruncatesInvalidJSONBodyInLog(t *testing.T) {
|
||||
withDebugEnabled(t, false)
|
||||
|
||||
body := strings.Repeat("b", common.LocalLogContentLimit+256)
|
||||
var logBuffer bytes.Buffer
|
||||
|
||||
common.LogWriterMu.Lock()
|
||||
oldWriter := gin.DefaultErrorWriter
|
||||
gin.DefaultErrorWriter = &logBuffer
|
||||
common.LogWriterMu.Unlock()
|
||||
t.Cleanup(func() {
|
||||
common.LogWriterMu.Lock()
|
||||
gin.DefaultErrorWriter = oldWriter
|
||||
common.LogWriterMu.Unlock()
|
||||
})
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
}
|
||||
|
||||
newAPIError := RelayErrorHandler(context.Background(), resp, false)
|
||||
|
||||
require.NotNil(t, newAPIError)
|
||||
require.Equal(t, "bad response status code 500", newAPIError.Error())
|
||||
require.Contains(t, logBuffer.String(), "[truncated")
|
||||
require.Contains(t, logBuffer.String(), fmt.Sprintf("original_length=%d", len(body)))
|
||||
require.NotContains(t, logBuffer.String(), strings.Repeat("b", common.LocalLogContentLimit+1))
|
||||
}
|
||||
|
||||
func TestRelayErrorHandlerKeepsStructuredErrorMessage(t *testing.T) {
|
||||
message := strings.Repeat("c", common.LocalLogContentLimit+256)
|
||||
body := `{"message":"` + message + `"}`
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
}
|
||||
|
||||
newAPIError := RelayErrorHandler(context.Background(), resp, false)
|
||||
|
||||
require.NotNil(t, newAPIError)
|
||||
require.Equal(t, message, newAPIError.Error())
|
||||
}
|
||||
|
||||
func TestRelayErrorHandlerKeepsOpenAIErrorMessage(t *testing.T) {
|
||||
message := strings.Repeat("d", common.LocalLogContentLimit+256)
|
||||
body := `{"error":{"message":"` + message + `","type":"server_error","code":"server_error"}}`
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
}
|
||||
|
||||
newAPIError := RelayErrorHandler(context.Background(), resp, false)
|
||||
|
||||
require.NotNil(t, newAPIError)
|
||||
require.Equal(t, message, newAPIError.Error())
|
||||
}
|
||||
|
||||
func TestRelayErrorHandlerKeepsInvalidJSONBodyInDebugLog(t *testing.T) {
|
||||
withDebugEnabled(t, true)
|
||||
|
||||
body := strings.Repeat("e", common.LocalLogContentLimit+256)
|
||||
var logBuffer bytes.Buffer
|
||||
|
||||
common.LogWriterMu.Lock()
|
||||
oldWriter := gin.DefaultErrorWriter
|
||||
gin.DefaultErrorWriter = &logBuffer
|
||||
common.LogWriterMu.Unlock()
|
||||
t.Cleanup(func() {
|
||||
common.LogWriterMu.Lock()
|
||||
gin.DefaultErrorWriter = oldWriter
|
||||
common.LogWriterMu.Unlock()
|
||||
})
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
}
|
||||
|
||||
newAPIError := RelayErrorHandler(context.Background(), resp, false)
|
||||
|
||||
require.NotNil(t, newAPIError)
|
||||
require.NotContains(t, logBuffer.String(), "[truncated")
|
||||
require.Contains(t, logBuffer.String(), body)
|
||||
}
|
||||
|
||||
func withDebugEnabled(t *testing.T, enabled bool) {
|
||||
t.Helper()
|
||||
|
||||
oldDebug := common.DebugEnabled
|
||||
common.DebugEnabled = enabled
|
||||
t.Cleanup(func() {
|
||||
common.DebugEnabled = oldDebug
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user