Compare commits

...

9 Commits

10 changed files with 120 additions and 29 deletions
+1
View File
@@ -14,6 +14,7 @@ type ImageRequest struct {
ExtraFields json.RawMessage `json:"extra_fields,omitempty"`
Background string `json:"background,omitempty"`
Moderation string `json:"moderation,omitempty"`
OutputFormat string `json:"output_format,omitempty"`
}
type ImageResponse struct {
+16 -3
View File
@@ -89,9 +89,22 @@ func main() {
if common.MemoryCacheEnabled {
common.SysLog("memory cache enabled")
common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
model.InitChannelCache()
}
if common.MemoryCacheEnabled {
// Add panic recovery and retry for InitChannelCache
func() {
defer func() {
if r := recover(); r != nil {
common.SysError(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r))
// Retry once
_, fixErr := model.FixAbility()
if fixErr != nil {
common.SysError(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error()))
}
}
}()
model.InitChannelCache()
}()
go model.SyncOptions(common.SyncFrequency)
go model.SyncChannelCache(common.SyncFrequency)
}
+36 -11
View File
@@ -50,7 +50,7 @@ func getPriority(group string, model string, retry int) (int, error) {
err := DB.Model(&Ability{}).
Select("DISTINCT(priority)").
Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model).
Order("priority DESC"). // 按优先级降序排序
Order("priority DESC"). // 按优先级降序排序
Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中
if err != nil {
@@ -261,12 +261,28 @@ func FixAbility() (int, error) {
common.SysError(fmt.Sprintf("Get channel ids from channel table failed: %s", err.Error()))
return 0, err
}
// Delete abilities of channels that are not in channel table
err = DB.Where("channel_id NOT IN (?)", channelIds).Delete(&Ability{}).Error
if err != nil {
common.SysError(fmt.Sprintf("Delete abilities of channels that are not in channel table failed: %s", err.Error()))
return 0, err
// Delete abilities of channels that are not in channel table - in batches to avoid too many placeholders
if len(channelIds) > 0 {
// Process deletion in chunks to avoid "too many placeholders" error
for _, chunk := range lo.Chunk(channelIds, 100) {
err = DB.Where("channel_id NOT IN (?)", chunk).Delete(&Ability{}).Error
if err != nil {
common.SysError(fmt.Sprintf("Delete abilities of channels (batch) that are not in channel table failed: %s", err.Error()))
return 0, err
}
}
} else {
// If no channels exist, delete all abilities
err = DB.Delete(&Ability{}).Error
if err != nil {
common.SysError(fmt.Sprintf("Delete all abilities failed: %s", err.Error()))
return 0, err
}
common.SysLog("Delete all abilities successfully")
return 0, nil
}
common.SysLog(fmt.Sprintf("Delete abilities of channels that are not in channel table successfully, ids: %v", channelIds))
count += len(channelIds)
@@ -275,17 +291,26 @@ func FixAbility() (int, error) {
err = DB.Table("abilities").Distinct("channel_id").Pluck("channel_id", &abilityChannelIds).Error
if err != nil {
common.SysError(fmt.Sprintf("Get channel ids from abilities table failed: %s", err.Error()))
return 0, err
return count, err
}
var channels []Channel
if len(abilityChannelIds) == 0 {
err = DB.Find(&channels).Error
} else {
err = DB.Where("id NOT IN (?)", abilityChannelIds).Find(&channels).Error
}
if err != nil {
return 0, err
// Process query in chunks to avoid "too many placeholders" error
err = nil
for _, chunk := range lo.Chunk(abilityChannelIds, 100) {
var channelsChunk []Channel
err = DB.Where("id NOT IN (?)", chunk).Find(&channelsChunk).Error
if err != nil {
common.SysError(fmt.Sprintf("Find channels not in abilities table failed: %s", err.Error()))
return count, err
}
channels = append(channels, channelsChunk...)
}
}
for _, channel := range channels {
err := channel.UpdateAbilities(nil)
if err != nil {
+5 -2
View File
@@ -16,6 +16,9 @@ var channelsIDM map[int]*Channel
var channelSyncLock sync.RWMutex
func InitChannelCache() {
if !common.MemoryCacheEnabled {
return
}
newChannelId2channel := make(map[int]*Channel)
var channels []*Channel
DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels)
@@ -84,11 +87,11 @@ func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Cha
if !common.MemoryCacheEnabled {
return GetRandomSatisfiedChannel(group, model, retry)
}
channelSyncLock.RLock()
channels := group2model2channels[group][model]
channelSyncLock.RUnlock()
if len(channels) == 0 {
return nil, errors.New("channel not found")
}
+11
View File
@@ -46,6 +46,17 @@ func (channel *Channel) GetModels() []string {
return strings.Split(strings.Trim(channel.Models, ","), ",")
}
func (channel *Channel) GetGroups() []string {
if channel.Group == "" {
return []string{}
}
groups := strings.Split(strings.Trim(channel.Group, ","), ",")
for i, group := range groups {
groups[i] = strings.TrimSpace(group)
}
return groups
}
func (channel *Channel) GetOtherInfo() map[string]interface{} {
otherInfo := make(map[string]interface{})
if channel.OtherInfo != "" {
+3 -3
View File
@@ -38,10 +38,10 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
a.RequestMode = RequestModeMessage
} else {
if strings.HasPrefix(info.UpstreamModelName, "claude-2") || strings.HasPrefix(info.UpstreamModelName, "claude-instant") {
a.RequestMode = RequestModeCompletion
} else {
a.RequestMode = RequestModeMessage
}
}
+4 -3
View File
@@ -2,10 +2,10 @@ package gemini
type GeminiChatRequest struct {
Contents []GeminiChatContent `json:"contents"`
SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
SafetySettings []GeminiChatSafetySettings `json:"safetySettings,omitempty"`
GenerationConfig GeminiChatGenerationConfig `json:"generationConfig,omitempty"`
Tools []GeminiChatTool `json:"tools,omitempty"`
SystemInstructions *GeminiChatContent `json:"system_instruction,omitempty"`
SystemInstructions *GeminiChatContent `json:"systemInstruction,omitempty"`
}
type GeminiThinkingConfig struct {
@@ -54,6 +54,7 @@ type GeminiFileData struct {
type GeminiPart struct {
Text string `json:"text,omitempty"`
Thought bool `json:"thought,omitempty"`
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
+15 -3
View File
@@ -539,6 +539,8 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
if call := getResponseToolCall(&part); call != nil {
toolCalls = append(toolCalls, *call)
}
} else if part.Thought {
choice.Message.ReasoningContent = part.Text
} else {
if part.ExecutableCode != nil {
texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```")
@@ -556,7 +558,6 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
choice.Message.SetToolCalls(toolCalls)
isToolCall = true
}
choice.Message.SetStringContent(strings.Join(texts, "\n"))
}
@@ -596,6 +597,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
}
var texts []string
isTools := false
isThought := false
if candidate.FinishReason != nil {
// p := GeminiConvertFinishReason(*candidate.FinishReason)
switch *candidate.FinishReason {
@@ -620,6 +622,9 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
call.SetIndex(len(choice.Delta.ToolCalls))
choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call)
}
} else if part.Thought {
isThought = true
texts = append(texts, part.Text)
} else {
if part.ExecutableCode != nil {
texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```\n")
@@ -632,7 +637,11 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
}
}
}
choice.Delta.SetContentString(strings.Join(texts, "\n"))
if isThought {
choice.Delta.SetReasoningContent(strings.Join(texts, "\n"))
} else {
choice.Delta.SetContentString(strings.Join(texts, "\n"))
}
if isTools {
choice.FinishReason = &constant.FinishReasonToolCalls
}
@@ -716,8 +725,11 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
if common.DebugEnabled {
println(string(responseBody))
}
var geminiResponse GeminiChatResponse
err = json.Unmarshal(responseBody, &geminiResponse)
err = common.DecodeJson(responseBody, &geminiResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
+17 -1
View File
@@ -32,7 +32,23 @@ func RelayMidjourneyImage(c *gin.Context) {
})
return
}
resp, err := http.Get(midjourneyTask.ImageUrl)
var httpClient *http.Client
if channel, err := model.CacheGetChannel(midjourneyTask.ChannelId); err == nil {
if proxy, ok := channel.GetSetting()["proxy"]; ok {
if proxyURL, ok := proxy.(string); ok && proxyURL != "" {
if httpClient, err = service.NewProxyHttpClient(proxyURL); err != nil {
c.JSON(400, gin.H{
"error": "proxy_url_invalid",
})
return
}
}
}
}
if httpClient == nil {
httpClient = service.GetHttpClient()
}
resp, err := httpClient.Get(midjourneyTask.ImageUrl)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "http_get_image_failed",
+12 -3
View File
@@ -871,7 +871,16 @@ const ChannelsTable = () => {
};
const refresh = async () => {
await loadChannels(activePage - 1, pageSize, idSort, enableTagMode);
if (searchKeyword === '' && searchGroup === '' && searchModel === '') {
await loadChannels(activePage - 1, pageSize, idSort, enableTagMode);
} else {
await searchChannels(
searchKeyword,
searchGroup,
searchModel,
enableTagMode,
);
}
};
useEffect(() => {
@@ -979,8 +988,8 @@ const ChannelsTable = () => {
enableTagMode,
) => {
if (searchKeyword === '' && searchGroup === '' && searchModel === '') {
await loadChannels(0, pageSize, idSort, enableTagMode);
setActivePage(1);
await loadChannels(activePage - 1, pageSize, idSort, enableTagMode);
// setActivePage(1);
return;
}
setSearching(true);