fix: keep usage log filters exact unless wildcard is explicit (#5097)
This commit is contained in:
+40
-27
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
@@ -16,6 +17,20 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func applyExplicitLogTextFilter(tx *gorm.DB, column string, value string) (*gorm.DB, error) {
|
||||
if value == "" {
|
||||
return tx, nil
|
||||
}
|
||||
if strings.Contains(value, "%") {
|
||||
pattern, err := sanitizeLikePattern(value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tx.Where(column+" LIKE ? ESCAPE '!'", pattern), nil
|
||||
}
|
||||
return tx.Where(column+" = ?", value), nil
|
||||
}
|
||||
|
||||
type Log struct {
|
||||
Id int `json:"id" gorm:"index:idx_created_at_id,priority:1;index:idx_user_id_id,priority:2"`
|
||||
UserId int `json:"user_id" gorm:"index;index:idx_user_id_id,priority:1"`
|
||||
@@ -308,14 +323,14 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
|
||||
tx = LOG_DB.Where("logs.type = ?", logType)
|
||||
}
|
||||
|
||||
if modelName != "" {
|
||||
tx = tx.Where("logs.model_name like ?", modelName)
|
||||
if tx, err = applyExplicitLogTextFilter(tx, "logs.model_name", modelName); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if username != "" {
|
||||
tx = tx.Where("logs.username = ?", username)
|
||||
if tx, err = applyExplicitLogTextFilter(tx, "logs.username", username); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if tokenName != "" {
|
||||
tx = tx.Where("logs.token_name = ?", tokenName)
|
||||
if tx, err = applyExplicitLogTextFilter(tx, "logs.token_name", tokenName); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if requestId != "" {
|
||||
tx = tx.Where("logs.request_id = ?", requestId)
|
||||
@@ -397,15 +412,11 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
|
||||
tx = LOG_DB.Where("logs.user_id = ? and logs.type = ?", userId, logType)
|
||||
}
|
||||
|
||||
if modelName != "" {
|
||||
modelNamePattern, err := sanitizeLikePattern(modelName)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
tx = tx.Where("logs.model_name LIKE ? ESCAPE '!'", modelNamePattern)
|
||||
if tx, err = applyExplicitLogTextFilter(tx, "logs.model_name", modelName); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if tokenName != "" {
|
||||
tx = tx.Where("logs.token_name = ?", tokenName)
|
||||
if tx, err = applyExplicitLogTextFilter(tx, "logs.token_name", tokenName); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if requestId != "" {
|
||||
tx = tx.Where("logs.request_id = ?", requestId)
|
||||
@@ -449,13 +460,17 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
|
||||
// 为rpm和tpm创建单独的查询
|
||||
rpmTpmQuery := LOG_DB.Table("logs").Select("count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm")
|
||||
|
||||
if username != "" {
|
||||
tx = tx.Where("username = ?", username)
|
||||
rpmTpmQuery = rpmTpmQuery.Where("username = ?", username)
|
||||
if tx, err = applyExplicitLogTextFilter(tx, "username", username); err != nil {
|
||||
return stat, err
|
||||
}
|
||||
if tokenName != "" {
|
||||
tx = tx.Where("token_name = ?", tokenName)
|
||||
rpmTpmQuery = rpmTpmQuery.Where("token_name = ?", tokenName)
|
||||
if rpmTpmQuery, err = applyExplicitLogTextFilter(rpmTpmQuery, "username", username); err != nil {
|
||||
return stat, err
|
||||
}
|
||||
if tx, err = applyExplicitLogTextFilter(tx, "token_name", tokenName); err != nil {
|
||||
return stat, err
|
||||
}
|
||||
if rpmTpmQuery, err = applyExplicitLogTextFilter(rpmTpmQuery, "token_name", tokenName); err != nil {
|
||||
return stat, err
|
||||
}
|
||||
if startTimestamp != 0 {
|
||||
tx = tx.Where("created_at >= ?", startTimestamp)
|
||||
@@ -463,13 +478,11 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
|
||||
if endTimestamp != 0 {
|
||||
tx = tx.Where("created_at <= ?", endTimestamp)
|
||||
}
|
||||
if modelName != "" {
|
||||
modelNamePattern, err := sanitizeLikePattern(modelName)
|
||||
if err != nil {
|
||||
return stat, err
|
||||
}
|
||||
tx = tx.Where("model_name LIKE ? ESCAPE '!'", modelNamePattern)
|
||||
rpmTpmQuery = rpmTpmQuery.Where("model_name LIKE ? ESCAPE '!'", modelNamePattern)
|
||||
if tx, err = applyExplicitLogTextFilter(tx, "model_name", modelName); err != nil {
|
||||
return stat, err
|
||||
}
|
||||
if rpmTpmQuery, err = applyExplicitLogTextFilter(rpmTpmQuery, "model_name", modelName); err != nil {
|
||||
return stat, err
|
||||
}
|
||||
if channel != 0 {
|
||||
tx = tx.Where("channel_id = ?", channel)
|
||||
|
||||
@@ -165,6 +165,7 @@ export function UsageLogsTable({ logCategory }: UsageLogsTableProps) {
|
||||
getFacetedRowModel: getFacetedRowModel(),
|
||||
getFacetedUniqueValues: getFacetedUniqueValues(),
|
||||
manualPagination: true,
|
||||
manualFiltering: true,
|
||||
pageCount: Math.ceil((data?.total || 0) / pagination.pageSize),
|
||||
})
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ const logTypeSearchSchema = z
|
||||
const usageLogsSearchSchema = z.object({
|
||||
page: z.number().optional().catch(1),
|
||||
pageSize: z.number().optional().catch(undefined),
|
||||
type: logTypeSearchSchema,
|
||||
type: logTypeSearchSchema.optional(),
|
||||
filter: z.string().optional().catch(''),
|
||||
model: z.string().optional().catch(''),
|
||||
token: z.string().optional().catch(''),
|
||||
|
||||
Reference in New Issue
Block a user