fix: apply group filter to channel list queries (#4885)
This commit is contained in:
+63
-54
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type OpenAIModel struct {
|
||||
@@ -68,12 +69,33 @@ func clearChannelInfo(channel *model.Channel) {
|
||||
}
|
||||
}
|
||||
|
||||
func applyChannelStatusFilter(query *gorm.DB, statusFilter int) *gorm.DB {
|
||||
if statusFilter == common.ChannelStatusEnabled {
|
||||
return query.Where("status = ?", common.ChannelStatusEnabled)
|
||||
}
|
||||
if statusFilter == 0 {
|
||||
return query.Where("status != ?", common.ChannelStatusEnabled)
|
||||
}
|
||||
return query
|
||||
}
|
||||
|
||||
func buildChannelListQuery(group string, statusFilter int, typeFilter int) *gorm.DB {
|
||||
query := model.DB.Model(&model.Channel{})
|
||||
query = model.ApplyChannelGroupFilter(query, group)
|
||||
query = applyChannelStatusFilter(query, statusFilter)
|
||||
if typeFilter >= 0 {
|
||||
query = query.Where("type = ?", typeFilter)
|
||||
}
|
||||
return query
|
||||
}
|
||||
|
||||
func GetAllChannels(c *gin.Context) {
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
channelData := make([]*model.Channel, 0)
|
||||
idSort, _ := strconv.ParseBool(c.Query("id_sort"))
|
||||
sortOptions := model.NewChannelSortOptions(c.Query("sort_by"), c.Query("sort_order"), idSort)
|
||||
enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
|
||||
groupFilter := model.NormalizeChannelGroupFilter(c.Query("group"))
|
||||
statusParam := c.Query("status")
|
||||
// statusFilter: -1 all, 1 enabled, 0 disabled (include auto & manual)
|
||||
statusFilter := parseStatusFilter(statusParam)
|
||||
@@ -85,69 +107,49 @@ func GetAllChannels(c *gin.Context) {
|
||||
typeFilter = t
|
||||
}
|
||||
}
|
||||
// group filter
|
||||
groupFilter := c.Query("group")
|
||||
|
||||
var total int64
|
||||
|
||||
if enableTagMode {
|
||||
tags, err := model.GetPaginatedTags(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||
tags, err := model.GetPaginatedChannelTags(buildChannelListQuery(groupFilter, statusFilter, typeFilter), pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||
if err != nil {
|
||||
common.SysError("failed to get paginated tags: " + err.Error())
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取标签失败,请稍后重试"})
|
||||
return
|
||||
}
|
||||
total, err = model.CountChannelTags(buildChannelListQuery(groupFilter, statusFilter, typeFilter))
|
||||
if err != nil {
|
||||
common.SysError("failed to count tags: " + err.Error())
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取标签数量失败,请稍后重试"})
|
||||
return
|
||||
}
|
||||
for _, tag := range tags {
|
||||
if tag == nil || *tag == "" {
|
||||
continue
|
||||
}
|
||||
tagChannels, err := model.GetChannelsByTag(*tag, idSort, false, sortOptions)
|
||||
var tagChannels []*model.Channel
|
||||
err := sortOptions.Apply(buildChannelListQuery(groupFilter, statusFilter, typeFilter).Where("tag = ?", *tag)).
|
||||
Omit("key").
|
||||
Find(&tagChannels).Error
|
||||
if err != nil {
|
||||
continue
|
||||
common.SysError("failed to get channels by tag: " + err.Error())
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取标签渠道失败,请稍后重试"})
|
||||
return
|
||||
}
|
||||
filtered := make([]*model.Channel, 0)
|
||||
for _, ch := range tagChannels {
|
||||
if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
|
||||
continue
|
||||
}
|
||||
if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
|
||||
continue
|
||||
}
|
||||
if typeFilter >= 0 && ch.Type != typeFilter {
|
||||
continue
|
||||
}
|
||||
if groupFilter != "" && groupFilter != "null" {
|
||||
if !strings.Contains(","+ch.Group+",", ","+groupFilter+",") {
|
||||
continue
|
||||
}
|
||||
}
|
||||
filtered = append(filtered, ch)
|
||||
}
|
||||
channelData = append(channelData, filtered...)
|
||||
channelData = append(channelData, tagChannels...)
|
||||
}
|
||||
total, _ = model.CountAllTags()
|
||||
} else {
|
||||
baseQuery := model.DB.Model(&model.Channel{})
|
||||
if typeFilter >= 0 {
|
||||
baseQuery = baseQuery.Where("type = ?", typeFilter)
|
||||
}
|
||||
if statusFilter == common.ChannelStatusEnabled {
|
||||
baseQuery = baseQuery.Where("status = ?", common.ChannelStatusEnabled)
|
||||
} else if statusFilter == 0 {
|
||||
baseQuery = baseQuery.Where("status != ?", common.ChannelStatusEnabled)
|
||||
}
|
||||
if groupFilter != "" && groupFilter != "null" {
|
||||
if common.UsingMySQL {
|
||||
baseQuery = baseQuery.Where("CONCAT(',', `group`, ',') LIKE ?", "%,"+groupFilter+",%")
|
||||
} else {
|
||||
// SQLite, PostgreSQL
|
||||
baseQuery = baseQuery.Where("(',' || \"group\" || ',') LIKE ?", "%,"+groupFilter+",%")
|
||||
}
|
||||
if err := buildChannelListQuery(groupFilter, statusFilter, typeFilter).Count(&total).Error; err != nil {
|
||||
common.SysError("failed to count channels: " + err.Error())
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取渠道数量失败,请稍后重试"})
|
||||
return
|
||||
}
|
||||
|
||||
baseQuery.Count(&total)
|
||||
|
||||
err := sortOptions.Apply(baseQuery).Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("key").Find(&channelData).Error
|
||||
err := sortOptions.Apply(buildChannelListQuery(groupFilter, statusFilter, typeFilter)).
|
||||
Limit(pageInfo.GetPageSize()).
|
||||
Offset(pageInfo.GetStartIdx()).
|
||||
Omit("key").
|
||||
Find(&channelData).Error
|
||||
if err != nil {
|
||||
common.SysError("failed to get channels: " + err.Error())
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取渠道列表失败,请稍后重试"})
|
||||
@@ -159,17 +161,16 @@ func GetAllChannels(c *gin.Context) {
|
||||
clearChannelInfo(datum)
|
||||
}
|
||||
|
||||
countQuery := model.DB.Model(&model.Channel{})
|
||||
if statusFilter == common.ChannelStatusEnabled {
|
||||
countQuery = countQuery.Where("status = ?", common.ChannelStatusEnabled)
|
||||
} else if statusFilter == 0 {
|
||||
countQuery = countQuery.Where("status != ?", common.ChannelStatusEnabled)
|
||||
}
|
||||
countQuery := buildChannelListQuery(groupFilter, statusFilter, -1)
|
||||
var results []struct {
|
||||
Type int64
|
||||
Count int64
|
||||
}
|
||||
_ = countQuery.Select("type, count(*) as count").Group("type").Find(&results).Error
|
||||
if err := countQuery.Select("type, count(*) as count").Group("type").Find(&results).Error; err != nil {
|
||||
common.SysError("failed to count channel types: " + err.Error())
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取渠道类型统计失败,请稍后重试"})
|
||||
return
|
||||
}
|
||||
typeCounts := make(map[int64]int64)
|
||||
for _, r := range results {
|
||||
typeCounts[r.Type] = r.Count
|
||||
@@ -277,10 +278,18 @@ func SearchChannels(c *gin.Context) {
|
||||
}
|
||||
for _, tag := range tags {
|
||||
if tag != nil && *tag != "" {
|
||||
tagChannel, err := model.GetChannelsByTag(*tag, idSort, false, sortOptions)
|
||||
if err == nil {
|
||||
channelData = append(channelData, tagChannel...)
|
||||
var tagChannels []*model.Channel
|
||||
err := sortOptions.Apply(buildChannelListQuery(group, -1, -1).Where("tag = ?", *tag)).
|
||||
Omit("key").
|
||||
Find(&tagChannels).Error
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
channelData = append(channelData, tagChannels...)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user