fix: use actual user id for channel tests (#5109)
This commit is contained in:
@@ -57,7 +57,24 @@ func normalizeChannelTestEndpoint(channel *model.Channel, modelName, endpointTyp
|
||||
return normalized
|
||||
}
|
||||
|
||||
func testChannel(channel *model.Channel, testModel string, endpointType string, isStream bool) testResult {
|
||||
func resolveChannelTestUserID(c *gin.Context) (int, error) {
|
||||
if c != nil {
|
||||
if userID := c.GetInt("id"); userID > 0 {
|
||||
return userID, nil
|
||||
}
|
||||
}
|
||||
|
||||
var rootUser model.User
|
||||
if err := model.DB.Select("id").Where("role = ?", common.RoleRootUser).First(&rootUser).Error; err != nil {
|
||||
return 0, fmt.Errorf("failed to resolve channel test user: %w", err)
|
||||
}
|
||||
if rootUser.Id == 0 {
|
||||
return 0, errors.New("failed to resolve channel test user")
|
||||
}
|
||||
return rootUser.Id, nil
|
||||
}
|
||||
|
||||
func testChannel(channel *model.Channel, testUserID int, testModel string, endpointType string, isStream bool) testResult {
|
||||
tik := time.Now()
|
||||
var unsupportedTestChannelTypes = []int{
|
||||
constant.ChannelTypeMidjourney,
|
||||
@@ -143,7 +160,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
|
||||
Header: make(http.Header),
|
||||
}
|
||||
|
||||
cache, err := model.GetUserCache(1)
|
||||
cache, err := model.GetUserCache(testUserID)
|
||||
if err != nil {
|
||||
return testResult{
|
||||
localErr: err,
|
||||
@@ -151,13 +168,13 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
|
||||
}
|
||||
}
|
||||
cache.WriteContext(c)
|
||||
c.Set("id", 1)
|
||||
c.Set("id", testUserID)
|
||||
|
||||
//c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Set("channel", channel.Type)
|
||||
c.Set("base_url", channel.GetBaseURL())
|
||||
group, _ := model.GetUserGroup(1, false)
|
||||
group, _ := model.GetUserGroup(testUserID, false)
|
||||
c.Set("group", group)
|
||||
|
||||
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, testModel)
|
||||
@@ -484,7 +501,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
consumedTime := float64(milliseconds) / 1000.0
|
||||
other := buildTestLogOther(c, info, priceData, usage, tieredResult)
|
||||
model.RecordConsumeLog(c, 1, model.RecordConsumeLogParams{
|
||||
model.RecordConsumeLog(c, testUserID, model.RecordConsumeLogParams{
|
||||
ChannelId: channel.Id,
|
||||
PromptTokens: usage.PromptTokens,
|
||||
CompletionTokens: usage.CompletionTokens,
|
||||
@@ -834,8 +851,13 @@ func TestChannel(c *gin.Context) {
|
||||
testModel := c.Query("model")
|
||||
endpointType := c.Query("endpoint_type")
|
||||
isStream, _ := strconv.ParseBool(c.Query("stream"))
|
||||
testUserID, err := resolveChannelTestUserID(c)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
tik := time.Now()
|
||||
result := testChannel(channel, testModel, endpointType, isStream)
|
||||
result := testChannel(channel, testUserID, testModel, endpointType, isStream)
|
||||
if result.localErr != nil {
|
||||
resp := gin.H{
|
||||
"success": false,
|
||||
@@ -872,6 +894,10 @@ var testAllChannelsLock sync.Mutex
|
||||
var testAllChannelsRunning bool = false
|
||||
|
||||
func testAllChannels(notify bool) error {
|
||||
testUserID, err := resolveChannelTestUserID(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
testAllChannelsLock.Lock()
|
||||
if testAllChannelsRunning {
|
||||
@@ -902,7 +928,7 @@ func testAllChannels(notify bool) error {
|
||||
}
|
||||
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
|
||||
tik := time.Now()
|
||||
result := testChannel(channel, "", "", shouldUseStreamForAutomaticChannelTest(channel))
|
||||
result := testChannel(channel, testUserID, "", "", shouldUseStreamForAutomaticChannelTest(channel))
|
||||
tok := time.Now()
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
|
||||
|
||||
@@ -69,3 +69,14 @@ func TestBuildTestLogOtherInjectsTieredInfo(t *testing.T) {
|
||||
require.Equal(t, "base", other["matched_tier"])
|
||||
require.NotEmpty(t, other["expr_b64"])
|
||||
}
|
||||
|
||||
func TestResolveChannelTestUserIDUsesRequestUser(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||
ctx.Set("id", 2)
|
||||
|
||||
userID, err := resolveChannelTestUserID(ctx)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 2, userID)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user