fix: use actual user id for channel tests (#5109)

This commit is contained in:
Seefs
2026-05-26 12:32:20 +08:00
committed by GitHub
parent c91ba0c4eb
commit 30025aeba3
2 changed files with 44 additions and 7 deletions
+33 -7
View File
@@ -57,7 +57,24 @@ func normalizeChannelTestEndpoint(channel *model.Channel, modelName, endpointTyp
return normalized
}
func testChannel(channel *model.Channel, testModel string, endpointType string, isStream bool) testResult {
func resolveChannelTestUserID(c *gin.Context) (int, error) {
if c != nil {
if userID := c.GetInt("id"); userID > 0 {
return userID, nil
}
}
var rootUser model.User
if err := model.DB.Select("id").Where("role = ?", common.RoleRootUser).First(&rootUser).Error; err != nil {
return 0, fmt.Errorf("failed to resolve channel test user: %w", err)
}
if rootUser.Id == 0 {
return 0, errors.New("failed to resolve channel test user")
}
return rootUser.Id, nil
}
func testChannel(channel *model.Channel, testUserID int, testModel string, endpointType string, isStream bool) testResult {
tik := time.Now()
var unsupportedTestChannelTypes = []int{
constant.ChannelTypeMidjourney,
@@ -143,7 +160,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
Header: make(http.Header),
}
cache, err := model.GetUserCache(1)
cache, err := model.GetUserCache(testUserID)
if err != nil {
return testResult{
localErr: err,
@@ -151,13 +168,13 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
}
}
cache.WriteContext(c)
c.Set("id", 1)
c.Set("id", testUserID)
//c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
c.Request.Header.Set("Content-Type", "application/json")
c.Set("channel", channel.Type)
c.Set("base_url", channel.GetBaseURL())
group, _ := model.GetUserGroup(1, false)
group, _ := model.GetUserGroup(testUserID, false)
c.Set("group", group)
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, testModel)
@@ -484,7 +501,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
milliseconds := tok.Sub(tik).Milliseconds()
consumedTime := float64(milliseconds) / 1000.0
other := buildTestLogOther(c, info, priceData, usage, tieredResult)
model.RecordConsumeLog(c, 1, model.RecordConsumeLogParams{
model.RecordConsumeLog(c, testUserID, model.RecordConsumeLogParams{
ChannelId: channel.Id,
PromptTokens: usage.PromptTokens,
CompletionTokens: usage.CompletionTokens,
@@ -834,8 +851,13 @@ func TestChannel(c *gin.Context) {
testModel := c.Query("model")
endpointType := c.Query("endpoint_type")
isStream, _ := strconv.ParseBool(c.Query("stream"))
testUserID, err := resolveChannelTestUserID(c)
if err != nil {
common.ApiError(c, err)
return
}
tik := time.Now()
result := testChannel(channel, testModel, endpointType, isStream)
result := testChannel(channel, testUserID, testModel, endpointType, isStream)
if result.localErr != nil {
resp := gin.H{
"success": false,
@@ -872,6 +894,10 @@ var testAllChannelsLock sync.Mutex
var testAllChannelsRunning bool = false
func testAllChannels(notify bool) error {
testUserID, err := resolveChannelTestUserID(nil)
if err != nil {
return err
}
testAllChannelsLock.Lock()
if testAllChannelsRunning {
@@ -902,7 +928,7 @@ func testAllChannels(notify bool) error {
}
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
tik := time.Now()
result := testChannel(channel, "", "", shouldUseStreamForAutomaticChannelTest(channel))
result := testChannel(channel, testUserID, "", "", shouldUseStreamForAutomaticChannelTest(channel))
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
+11
View File
@@ -69,3 +69,14 @@ func TestBuildTestLogOtherInjectsTieredInfo(t *testing.T) {
require.Equal(t, "base", other["matched_tier"])
require.NotEmpty(t, other["expr_b64"])
}
func TestResolveChannelTestUserIDUsesRequestUser(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
ctx.Set("id", 2)
userID, err := resolveChannelTestUserID(ctx)
require.NoError(t, err)
require.Equal(t, 2, userID)
}