Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b46efdc8c2 | |||
| 1e22f40518 | |||
| a78f363af9 | |||
| 7e8adb5b34 | |||
| 87692e606f | |||
| 6288b26f35 | |||
| 72be58523c | |||
| dbd412f852 |
@@ -8,3 +8,4 @@ build
|
||||
logs
|
||||
web/dist
|
||||
.env
|
||||
one-api
|
||||
+3
-1
@@ -35,7 +35,9 @@ func StrToMap(str string) map[string]interface{} {
|
||||
m := make(map[string]interface{})
|
||||
err := json.Unmarshal([]byte(str), &m)
|
||||
if err != nil {
|
||||
return nil
|
||||
return map[string]interface{}{
|
||||
"result": str,
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
+2
-1
@@ -4,7 +4,6 @@ import (
|
||||
crand "crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"github.com/google/uuid"
|
||||
"html/template"
|
||||
"log"
|
||||
"math/big"
|
||||
@@ -15,6 +14,8 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func OpenBrowser(url string) {
|
||||
|
||||
@@ -1 +1,5 @@
|
||||
package constant
|
||||
|
||||
const (
|
||||
ContextKeyRequestStartTime = "request_start_time"
|
||||
)
|
||||
|
||||
@@ -141,7 +141,8 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
consumedTime := float64(milliseconds) / 1000.0
|
||||
other := service.GenerateTextOtherInfo(c, meta, modelRatio, 1, completionRatio, modelPrice)
|
||||
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, testModel, "模型测试", quota, "模型测试", 0, quota, int(consumedTime), false, other)
|
||||
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, testModel, "模型测试",
|
||||
quota, "模型测试", 0, quota, int(consumedTime), false, "default", other)
|
||||
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
+95
-17
@@ -97,6 +97,7 @@ func FetchUpstreamModels(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
channel, err := model.GetChannelById(id, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -105,34 +106,35 @@ func FetchUpstreamModels(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if channel.Type != common.ChannelTypeOpenAI {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "仅支持 OpenAI 类型渠道",
|
||||
})
|
||||
return
|
||||
|
||||
//if channel.Type != common.ChannelTypeOpenAI {
|
||||
// c.JSON(http.StatusOK, gin.H{
|
||||
// "success": false,
|
||||
// "message": "仅支持 OpenAI 类型渠道",
|
||||
// })
|
||||
// return
|
||||
//}
|
||||
baseURL := common.ChannelBaseURLs[channel.Type]
|
||||
if channel.GetBaseURL() == "" {
|
||||
channel.BaseURL = &baseURL
|
||||
}
|
||||
url := fmt.Sprintf("%s/v1/models", *channel.BaseURL)
|
||||
url := fmt.Sprintf("%s/v1/models", baseURL)
|
||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
result := OpenAIModelsResponse{}
|
||||
err = json.Unmarshal(body, &result)
|
||||
if err != nil {
|
||||
|
||||
var result OpenAIModelsResponse
|
||||
if err = json.Unmarshal(body, &result); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
}
|
||||
if !result.Success {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "上游返回错误",
|
||||
"message": fmt.Sprintf("解析响应失败: %s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var ids []string
|
||||
@@ -492,3 +494,79 @@ func UpdateChannel(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func FetchModels(c *gin.Context) {
|
||||
var req struct {
|
||||
BaseURL string `json:"base_url"`
|
||||
Key string `json:"key"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": "Invalid request",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
baseURL := req.BaseURL
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.openai.com"
|
||||
}
|
||||
|
||||
client := &http.Client{}
|
||||
url := fmt.Sprintf("%s/v1/models", baseURL)
|
||||
|
||||
request, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
request.Header.Set("Authorization", "Bearer "+req.Key)
|
||||
|
||||
response, err := client.Do(request)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
//check status code
|
||||
if response.StatusCode != http.StatusOK {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
"message": "Failed to fetch models",
|
||||
})
|
||||
return
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
var result struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(response.Body).Decode(&result); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var models []string
|
||||
for _, model := range result.Data {
|
||||
models = append(models, model.ID)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": models,
|
||||
})
|
||||
}
|
||||
|
||||
+8
-4
@@ -25,7 +25,8 @@ func GetAllLogs(c *gin.Context) {
|
||||
tokenName := c.Query("token_name")
|
||||
modelName := c.Query("model_name")
|
||||
channel, _ := strconv.Atoi(c.Query("channel"))
|
||||
logs, total, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, (p-1)*pageSize, pageSize, channel)
|
||||
group := c.Query("group")
|
||||
logs, total, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, (p-1)*pageSize, pageSize, channel, group)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -63,7 +64,8 @@ func GetUserLogs(c *gin.Context) {
|
||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||
tokenName := c.Query("token_name")
|
||||
modelName := c.Query("model_name")
|
||||
logs, total, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, (p-1)*pageSize, pageSize)
|
||||
group := c.Query("group")
|
||||
logs, total, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, (p-1)*pageSize, pageSize, group)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -146,7 +148,8 @@ func GetLogsStat(c *gin.Context) {
|
||||
username := c.Query("username")
|
||||
modelName := c.Query("model_name")
|
||||
channel, _ := strconv.Atoi(c.Query("channel"))
|
||||
stat := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel)
|
||||
group := c.Query("group")
|
||||
stat := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel, group)
|
||||
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "")
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
@@ -168,7 +171,8 @@ func GetLogsSelfStat(c *gin.Context) {
|
||||
tokenName := c.Query("token_name")
|
||||
modelName := c.Query("model_name")
|
||||
channel, _ := strconv.Atoi(c.Query("channel"))
|
||||
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel)
|
||||
group := c.Query("group")
|
||||
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel, group)
|
||||
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"one-api/service"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -112,6 +113,7 @@ func Distribute() func(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
}
|
||||
c.Set(constant.ContextKeyRequestStartTime, time.Now())
|
||||
SetupContextForSelectedChannel(c, channel, modelRequest.Model)
|
||||
c.Next()
|
||||
}
|
||||
|
||||
+28
-4
@@ -12,6 +12,16 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var groupCol string
|
||||
|
||||
func init() {
|
||||
if common.UsingPostgreSQL {
|
||||
groupCol = `"group"`
|
||||
} else {
|
||||
groupCol = "`group`"
|
||||
}
|
||||
}
|
||||
|
||||
type Log struct {
|
||||
Id int `json:"id" gorm:"index:idx_created_at_id,priority:1"`
|
||||
UserId int `json:"user_id" gorm:"index"`
|
||||
@@ -28,6 +38,7 @@ type Log struct {
|
||||
IsStream bool `json:"is_stream" gorm:"default:false"`
|
||||
ChannelId int `json:"channel" gorm:"index"`
|
||||
TokenId int `json:"token_id" gorm:"default:0;index"`
|
||||
Group string `json:"group" gorm:"index"`
|
||||
Other string `json:"other"`
|
||||
}
|
||||
|
||||
@@ -70,7 +81,9 @@ func RecordLog(userId int, logType int, content string) {
|
||||
}
|
||||
}
|
||||
|
||||
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int, isStream bool, other map[string]interface{}) {
|
||||
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int,
|
||||
modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int,
|
||||
isStream bool, group string, other map[string]interface{}) {
|
||||
common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
|
||||
if !common.LogConsumeEnabled {
|
||||
return
|
||||
@@ -92,6 +105,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
|
||||
TokenId: tokenId,
|
||||
UseTime: useTimeSeconds,
|
||||
IsStream: isStream,
|
||||
Group: group,
|
||||
Other: otherStr,
|
||||
}
|
||||
err := LOG_DB.Create(log).Error
|
||||
@@ -105,7 +119,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
|
||||
}
|
||||
}
|
||||
|
||||
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, total int64, err error) {
|
||||
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int, group string) (logs []*Log, total int64, err error) {
|
||||
var tx *gorm.DB
|
||||
if logType == LogTypeUnknown {
|
||||
tx = LOG_DB
|
||||
@@ -130,6 +144,9 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
|
||||
if channel != 0 {
|
||||
tx = tx.Where("channel_id = ?", channel)
|
||||
}
|
||||
if group != "" {
|
||||
tx = tx.Where(groupCol+" = ?", group)
|
||||
}
|
||||
err = tx.Model(&Log{}).Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
@@ -141,7 +158,7 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
|
||||
return logs, total, err
|
||||
}
|
||||
|
||||
func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, total int64, err error) {
|
||||
func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int, group string) (logs []*Log, total int64, err error) {
|
||||
var tx *gorm.DB
|
||||
if logType == LogTypeUnknown {
|
||||
tx = LOG_DB.Where("user_id = ?", userId)
|
||||
@@ -160,6 +177,9 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
|
||||
if endTimestamp != 0 {
|
||||
tx = tx.Where("created_at <= ?", endTimestamp)
|
||||
}
|
||||
if group != "" {
|
||||
tx = tx.Where(groupCol+" = ?", group)
|
||||
}
|
||||
err = tx.Model(&Log{}).Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
@@ -193,7 +213,7 @@ type Stat struct {
|
||||
Tpm int `json:"tpm"`
|
||||
}
|
||||
|
||||
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (stat Stat) {
|
||||
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int, group string) (stat Stat) {
|
||||
tx := LOG_DB.Table("logs").Select("sum(quota) quota")
|
||||
|
||||
// 为rpm和tpm创建单独的查询
|
||||
@@ -221,6 +241,10 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
|
||||
tx = tx.Where("channel_id = ?", channel)
|
||||
rpmTpmQuery = rpmTpmQuery.Where("channel_id = ?", channel)
|
||||
}
|
||||
if group != "" {
|
||||
tx = tx.Where(groupCol+" = ?", group)
|
||||
rpmTpmQuery = rpmTpmQuery.Where(groupCol+" = ?", group)
|
||||
}
|
||||
|
||||
tx = tx.Where("type = ?", LogTypeConsume)
|
||||
rpmTpmQuery = rpmTpmQuery.Where("type = ?", LogTypeConsume)
|
||||
|
||||
@@ -95,7 +95,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
|
||||
geminiRequest.GenerationConfig.ResponseSchema = cleanedSchema
|
||||
}
|
||||
}
|
||||
|
||||
tool_call_ids := make(map[string]string)
|
||||
//shouldAddDummyModelMessage := false
|
||||
for _, message := range textRequest.Messages {
|
||||
|
||||
@@ -108,6 +108,27 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
|
||||
},
|
||||
}
|
||||
continue
|
||||
} else if message.Role == "tool" {
|
||||
if len(geminiRequest.Contents) == 0 || geminiRequest.Contents[len(geminiRequest.Contents)-1].Role != "user" {
|
||||
geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
|
||||
Role: "user",
|
||||
})
|
||||
}
|
||||
var parts = &geminiRequest.Contents[len(geminiRequest.Contents)-1].Parts
|
||||
name := ""
|
||||
if message.Name != nil {
|
||||
name = *message.Name
|
||||
} else if val, exists := tool_call_ids[message.ToolCallId]; exists {
|
||||
name = val
|
||||
}
|
||||
functionResp := &FunctionResponse{
|
||||
Name: name,
|
||||
Response: common.StrToMap(message.StringContent()),
|
||||
}
|
||||
*parts = append(*parts, GeminiPart{
|
||||
FunctionResponse: functionResp,
|
||||
})
|
||||
continue
|
||||
}
|
||||
var parts []GeminiPart
|
||||
content := GeminiChatContent{
|
||||
@@ -125,62 +146,49 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
|
||||
},
|
||||
}
|
||||
parts = append(parts, toolCall)
|
||||
tool_call_ids[call.ID] = call.Function.Name
|
||||
}
|
||||
}
|
||||
if !isToolCall {
|
||||
if message.Role == "tool" {
|
||||
content.Role = "user"
|
||||
name := ""
|
||||
if message.Name != nil {
|
||||
name = *message.Name
|
||||
}
|
||||
functionResp := &FunctionResponse{
|
||||
Name: name,
|
||||
Response: common.StrToMap(message.StringContent()),
|
||||
}
|
||||
parts = append(parts, GeminiPart{
|
||||
FunctionResponse: functionResp,
|
||||
})
|
||||
} else {
|
||||
openaiContent := message.ParseContent()
|
||||
imageNum := 0
|
||||
for _, part := range openaiContent {
|
||||
if part.Type == dto.ContentTypeText {
|
||||
parts = append(parts, GeminiPart{
|
||||
Text: part.Text,
|
||||
})
|
||||
} else if part.Type == dto.ContentTypeImageURL {
|
||||
imageNum += 1
|
||||
openaiContent := message.ParseContent()
|
||||
imageNum := 0
|
||||
for _, part := range openaiContent {
|
||||
if part.Type == dto.ContentTypeText {
|
||||
parts = append(parts, GeminiPart{
|
||||
Text: part.Text,
|
||||
})
|
||||
} else if part.Type == dto.ContentTypeImageURL {
|
||||
imageNum += 1
|
||||
|
||||
if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
|
||||
return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
|
||||
}
|
||||
// 判断是否是url
|
||||
if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
|
||||
// 是url,获取图片的类型和base64编码的数据
|
||||
mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
MimeType: mimeType,
|
||||
Data: data,
|
||||
},
|
||||
})
|
||||
} else {
|
||||
_, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
|
||||
}
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
MimeType: "image/" + format,
|
||||
Data: base64String,
|
||||
},
|
||||
})
|
||||
if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
|
||||
return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
|
||||
}
|
||||
// 判断是否是url
|
||||
if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
|
||||
// 是url,获取图片的类型和base64编码的数据
|
||||
mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
MimeType: mimeType,
|
||||
Data: data,
|
||||
},
|
||||
})
|
||||
} else {
|
||||
_, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
|
||||
}
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
MimeType: "image/" + format,
|
||||
Data: base64String,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
content.Parts = parts
|
||||
|
||||
// there's no assistant role in gemini and API shall vomit if Role is not user or model
|
||||
@@ -242,19 +250,13 @@ func (g *GeminiChatResponse) GetResponseText() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func getToolCalls(candidate *GeminiChatCandidate) []dto.ToolCall {
|
||||
var toolCalls []dto.ToolCall
|
||||
|
||||
item := candidate.Content.Parts[0]
|
||||
if item.FunctionCall == nil {
|
||||
return toolCalls
|
||||
}
|
||||
func getToolCall(item *GeminiPart) *dto.ToolCall {
|
||||
argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
|
||||
if err != nil {
|
||||
//common.SysError("getToolCalls failed: " + err.Error())
|
||||
return toolCalls
|
||||
//common.SysError("getToolCall failed: " + err.Error())
|
||||
return nil
|
||||
}
|
||||
toolCall := dto.ToolCall{
|
||||
return &dto.ToolCall{
|
||||
ID: fmt.Sprintf("call_%s", common.GetUUID()),
|
||||
Type: "function",
|
||||
Function: dto.FunctionCall{
|
||||
@@ -262,10 +264,32 @@ func getToolCalls(candidate *GeminiChatCandidate) []dto.ToolCall {
|
||||
Name: item.FunctionCall.FunctionName,
|
||||
},
|
||||
}
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
return toolCalls
|
||||
}
|
||||
|
||||
// func getToolCalls(candidate *GeminiChatCandidate, index int) []dto.ToolCall {
|
||||
// var toolCalls []dto.ToolCall
|
||||
|
||||
// item := candidate.Content.Parts[index]
|
||||
// if item.FunctionCall == nil {
|
||||
// return toolCalls
|
||||
// }
|
||||
// argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
|
||||
// if err != nil {
|
||||
// //common.SysError("getToolCalls failed: " + err.Error())
|
||||
// return toolCalls
|
||||
// }
|
||||
// toolCall := dto.ToolCall{
|
||||
// ID: fmt.Sprintf("call_%s", common.GetUUID()),
|
||||
// Type: "function",
|
||||
// Function: dto.FunctionCall{
|
||||
// Arguments: string(argsBytes),
|
||||
// Name: item.FunctionCall.FunctionName,
|
||||
// },
|
||||
// }
|
||||
// toolCalls = append(toolCalls, toolCall)
|
||||
// return toolCalls
|
||||
// }
|
||||
|
||||
func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
|
||||
fullTextResponse := dto.OpenAITextResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
@@ -275,6 +299,8 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
|
||||
}
|
||||
content, _ := json.Marshal("")
|
||||
for i, candidate := range response.Candidates {
|
||||
// jsonData, _ := json.MarshalIndent(candidate, "", " ")
|
||||
// common.SysLog(fmt.Sprintf("candidate: %v", string(jsonData)))
|
||||
choice := dto.OpenAITextResponseChoice{
|
||||
Index: i,
|
||||
Message: dto.Message{
|
||||
@@ -284,16 +310,20 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
|
||||
FinishReason: constant.FinishReasonStop,
|
||||
}
|
||||
if len(candidate.Content.Parts) > 0 {
|
||||
if candidate.Content.Parts[0].FunctionCall != nil {
|
||||
choice.FinishReason = constant.FinishReasonToolCalls
|
||||
choice.Message.SetToolCalls(getToolCalls(&candidate))
|
||||
} else {
|
||||
var texts []string
|
||||
for _, part := range candidate.Content.Parts {
|
||||
var texts []string
|
||||
var tool_calls []dto.ToolCall
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.FunctionCall != nil {
|
||||
choice.FinishReason = constant.FinishReasonToolCalls
|
||||
if call := getToolCall(&part); call != nil {
|
||||
tool_calls = append(tool_calls, *call)
|
||||
}
|
||||
} else {
|
||||
texts = append(texts, part.Text)
|
||||
}
|
||||
choice.Message.SetStringContent(strings.Join(texts, "\n"))
|
||||
}
|
||||
choice.Message.SetStringContent(strings.Join(texts, "\n"))
|
||||
choice.Message.SetToolCalls(tool_calls)
|
||||
}
|
||||
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
||||
}
|
||||
@@ -304,18 +334,23 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
//choice.Delta.SetContentString(geminiResponse.GetResponseText())
|
||||
if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 {
|
||||
respFirstParts := geminiResponse.Candidates[0].Content.Parts
|
||||
if respFirstParts[0].FunctionCall != nil {
|
||||
// function response
|
||||
choice.Delta.ToolCalls = getToolCalls(&geminiResponse.Candidates[0])
|
||||
} else {
|
||||
// text response
|
||||
var texts []string
|
||||
for _, part := range respFirstParts {
|
||||
var texts []string
|
||||
var tool_calls []dto.ToolCall
|
||||
for _, part := range geminiResponse.Candidates[0].Content.Parts {
|
||||
if part.FunctionCall != nil {
|
||||
if call := getToolCall(&part); call != nil {
|
||||
tool_calls = append(tool_calls, *call)
|
||||
}
|
||||
} else {
|
||||
texts = append(texts, part.Text)
|
||||
}
|
||||
}
|
||||
if len(texts) > 0 {
|
||||
choice.Delta.SetContentString(strings.Join(texts, "\n"))
|
||||
}
|
||||
if len(tool_calls) > 0 {
|
||||
choice.Delta.ToolCalls = tool_calls
|
||||
}
|
||||
}
|
||||
var response dto.ChatCompletionsStreamResponse
|
||||
response.Object = "chat.completion.chunk"
|
||||
|
||||
@@ -2,8 +2,9 @@ package common
|
||||
|
||||
import (
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/relay/constant"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -66,13 +67,13 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
userId := c.GetInt("id")
|
||||
group := c.GetString("group")
|
||||
tokenUnlimited := c.GetBool("token_unlimited_quota")
|
||||
startTime := time.Now()
|
||||
startTime := c.GetTime(constant.ContextKeyRequestStartTime)
|
||||
// firstResponseTime = time.Now() - 1 second
|
||||
|
||||
apiType, _ := constant.ChannelType2APIType(channelType)
|
||||
apiType, _ := relayconstant.ChannelType2APIType(channelType)
|
||||
|
||||
info := &RelayInfo{
|
||||
RelayMode: constant.Path2RelayMode(c.Request.URL.Path),
|
||||
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
|
||||
BaseUrl: c.GetString("base_url"),
|
||||
RequestURLPath: c.Request.URL.String(),
|
||||
ChannelType: channelType,
|
||||
@@ -158,10 +159,10 @@ func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
|
||||
group := c.GetString("group")
|
||||
startTime := time.Now()
|
||||
|
||||
apiType, _ := constant.ChannelType2APIType(channelType)
|
||||
apiType, _ := relayconstant.ChannelType2APIType(channelType)
|
||||
|
||||
info := &TaskRelayInfo{
|
||||
RelayMode: constant.Path2RelayMode(c.Request.URL.Path),
|
||||
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
|
||||
BaseUrl: c.GetString("base_url"),
|
||||
RequestURLPath: c.Request.URL.String(),
|
||||
ChannelType: channelType,
|
||||
|
||||
+4
-2
@@ -208,7 +208,8 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
||||
other := make(map[string]interface{})
|
||||
other["model_price"] = modelPrice
|
||||
other["group_ratio"] = groupRatio
|
||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false, other)
|
||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName,
|
||||
quota, logContent, tokenId, userQuota, 0, false, group, other)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
channelId := c.GetInt("channel_id")
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
@@ -513,7 +514,8 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
other := make(map[string]interface{})
|
||||
other["model_price"] = modelPrice
|
||||
other["group_ratio"] = groupRatio
|
||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false, other)
|
||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName,
|
||||
quota, logContent, tokenId, userQuota, 0, false, group, other)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
channelId := c.GetInt("channel_id")
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
|
||||
+1
-1
@@ -385,7 +385,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
|
||||
}
|
||||
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, modelPrice)
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel,
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other)
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
|
||||
|
||||
//if quota != 0 {
|
||||
//
|
||||
|
||||
+2
-1
@@ -126,7 +126,8 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
||||
other := make(map[string]interface{})
|
||||
other["model_price"] = modelPrice
|
||||
other["group_ratio"] = groupRatio
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, 0, 0, modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, other)
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, 0, 0,
|
||||
modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, relayInfo.Group, other)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
|
||||
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
|
||||
}
|
||||
|
||||
@@ -98,6 +98,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
channelRoute.POST("/batch", controller.DeleteChannelBatch)
|
||||
channelRoute.POST("/fix", controller.FixChannelsAbilities)
|
||||
channelRoute.GET("/fetch_models/:id", controller.FetchUpstreamModels)
|
||||
channelRoute.POST("/fetch_models", controller.FetchModels)
|
||||
|
||||
}
|
||||
tokenRoute := apiRouter.Group("/token")
|
||||
|
||||
@@ -12,7 +12,6 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m
|
||||
other["group_ratio"] = groupRatio
|
||||
other["completion_ratio"] = completionRatio
|
||||
other["model_price"] = modelPrice
|
||||
other["group"] = relayInfo.Group
|
||||
other["frt"] = float64(relayInfo.FirstResponseTime.UnixMilli() - relayInfo.StartTime.UnixMilli())
|
||||
adminInfo := make(map[string]interface{})
|
||||
adminInfo["use_channel"] = ctx.GetStringSlice("use_channel")
|
||||
|
||||
+2
-2
@@ -139,7 +139,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
|
||||
}
|
||||
other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice)
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel,
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other)
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
|
||||
}
|
||||
|
||||
func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
@@ -208,5 +208,5 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
}
|
||||
other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice)
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel,
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other)
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import React, { useContext, useEffect, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
API,
|
||||
@@ -33,6 +33,7 @@ import {
|
||||
} from '../helpers/render';
|
||||
import Paragraph from '@douyinfe/semi-ui/lib/es/typography/paragraph';
|
||||
import { getLogOther } from '../helpers/other.js';
|
||||
import { StyleContext } from '../context/Style/index.js';
|
||||
|
||||
const { Header } = Layout;
|
||||
|
||||
@@ -222,19 +223,27 @@ const LogsTable = () => {
|
||||
dataIndex: 'group',
|
||||
render: (text, record, index) => {
|
||||
if (record.type === 0 || record.type === 2) {
|
||||
let other = JSON.parse(record.other);
|
||||
if (other === null) {
|
||||
return <></>;
|
||||
}
|
||||
if (other.group !== undefined) {
|
||||
if (record.group) {
|
||||
return (
|
||||
<>
|
||||
{renderGroup(other.group)}
|
||||
{renderGroup(record.group)}
|
||||
</>
|
||||
);
|
||||
} else {
|
||||
return <></>;
|
||||
}
|
||||
} else {
|
||||
let other = JSON.parse(record.other);
|
||||
if (other === null) {
|
||||
return <></>;
|
||||
}
|
||||
if (other.group !== undefined) {
|
||||
return (
|
||||
<>
|
||||
{renderGroup(other.group)}
|
||||
</>
|
||||
);
|
||||
} else {
|
||||
return <></>;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return <></>;
|
||||
}
|
||||
@@ -398,6 +407,7 @@ const LogsTable = () => {
|
||||
},
|
||||
];
|
||||
|
||||
const [styleState, styleDispatch] = useContext(StyleContext);
|
||||
const [logs, setLogs] = useState([]);
|
||||
const [expandData, setExpandData] = useState({});
|
||||
const [showStat, setShowStat] = useState(false);
|
||||
@@ -417,6 +427,7 @@ const LogsTable = () => {
|
||||
start_timestamp: timestamp2string(getTodayStartTimestamp()),
|
||||
end_timestamp: timestamp2string(now.getTime() / 1000 + 3600),
|
||||
channel: '',
|
||||
group: '',
|
||||
});
|
||||
const {
|
||||
username,
|
||||
@@ -425,6 +436,7 @@ const LogsTable = () => {
|
||||
start_timestamp,
|
||||
end_timestamp,
|
||||
channel,
|
||||
group,
|
||||
} = inputs;
|
||||
|
||||
const [stat, setStat] = useState({
|
||||
@@ -433,13 +445,19 @@ const LogsTable = () => {
|
||||
});
|
||||
|
||||
const handleInputChange = (value, name) => {
|
||||
setInputs((inputs) => ({ ...inputs, [name]: value }));
|
||||
if (value && (name === 'start_timestamp' || name === 'end_timestamp')) {
|
||||
// 确保日期值是有效的
|
||||
const dateValue = typeof value === 'string' ? value : timestamp2string(value);
|
||||
setInputs(inputs => ({ ...inputs, [name]: dateValue }));
|
||||
} else {
|
||||
setInputs(inputs => ({ ...inputs, [name]: value }));
|
||||
}
|
||||
};
|
||||
|
||||
const getLogSelfStat = async () => {
|
||||
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
|
||||
let localStartTimestamp = Date.parse(3) / 1000;
|
||||
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
|
||||
let url = `/api/log/self/stat?type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
|
||||
let url = `/api/log/self/stat?type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&group=${group}`;
|
||||
url = encodeURI(url);
|
||||
let res = await API.get(url);
|
||||
const { success, message, data } = res.data;
|
||||
@@ -453,7 +471,7 @@ const LogsTable = () => {
|
||||
const getLogStat = async () => {
|
||||
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
|
||||
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
|
||||
let url = `/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`;
|
||||
let url = `/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}&group=${group}`;
|
||||
url = encodeURI(url);
|
||||
let res = await API.get(url);
|
||||
const { success, message, data } = res.data;
|
||||
@@ -596,9 +614,9 @@ const LogsTable = () => {
|
||||
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
|
||||
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
|
||||
if (isAdminUser) {
|
||||
url = `/api/log/?p=${startIdx}&page_size=${pageSize}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`;
|
||||
url = `/api/log/?p=${startIdx}&page_size=${pageSize}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}&group=${group}`;
|
||||
} else {
|
||||
url = `/api/log/self/?p=${startIdx}&page_size=${pageSize}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
|
||||
url = `/api/log/self/?p=${startIdx}&page_size=${pageSize}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&group=${group}`;
|
||||
}
|
||||
url = encodeURI(url);
|
||||
const res = await API.get(url);
|
||||
@@ -682,10 +700,53 @@ const LogsTable = () => {
|
||||
</Header>
|
||||
<Form layout='horizontal' style={{ marginTop: 10 }}>
|
||||
<>
|
||||
<Form.Section>
|
||||
<div style={{ marginBottom: 10 }}>
|
||||
{
|
||||
styleState.isMobile ? (
|
||||
<div>
|
||||
<Form.DatePicker
|
||||
field='start_timestamp'
|
||||
label={t('起始时间')}
|
||||
style={{ width: 272 }}
|
||||
initValue={start_timestamp}
|
||||
type='dateTime'
|
||||
onChange={(value) => {
|
||||
console.log(value);
|
||||
handleInputChange(value, 'start_timestamp')
|
||||
}}
|
||||
/>
|
||||
<Form.DatePicker
|
||||
field='end_timestamp'
|
||||
fluid
|
||||
label={t('结束时间')}
|
||||
style={{ width: 272 }}
|
||||
initValue={end_timestamp}
|
||||
type='dateTime'
|
||||
onChange={(value) => handleInputChange(value, 'end_timestamp')}
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
<Form.DatePicker
|
||||
field="range_timestamp"
|
||||
label={t('时间范围')}
|
||||
initValue={[start_timestamp, end_timestamp]}
|
||||
type="dateTimeRange"
|
||||
name="range_timestamp"
|
||||
onChange={(value) => {
|
||||
if (Array.isArray(value) && value.length === 2) {
|
||||
handleInputChange(value[0], 'start_timestamp');
|
||||
handleInputChange(value[1], 'end_timestamp');
|
||||
}
|
||||
}}
|
||||
/>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
</Form.Section>
|
||||
<Form.Input
|
||||
field='token_name'
|
||||
label={t('令牌名称')}
|
||||
style={{ width: 176 }}
|
||||
value={token_name}
|
||||
placeholder={t('可选值')}
|
||||
name='token_name'
|
||||
@@ -694,39 +755,24 @@ const LogsTable = () => {
|
||||
<Form.Input
|
||||
field='model_name'
|
||||
label={t('模型名称')}
|
||||
style={{ width: 176 }}
|
||||
value={model_name}
|
||||
placeholder={t('可选值')}
|
||||
name='model_name'
|
||||
onChange={(value) => handleInputChange(value, 'model_name')}
|
||||
/>
|
||||
<Form.DatePicker
|
||||
field='start_timestamp'
|
||||
label={t('起始时间')}
|
||||
style={{ width: 272 }}
|
||||
initValue={start_timestamp}
|
||||
value={start_timestamp}
|
||||
type='dateTime'
|
||||
name='start_timestamp'
|
||||
onChange={(value) => handleInputChange(value, 'start_timestamp')}
|
||||
/>
|
||||
<Form.DatePicker
|
||||
field='end_timestamp'
|
||||
fluid
|
||||
label={t('结束时间')}
|
||||
style={{ width: 272 }}
|
||||
initValue={end_timestamp}
|
||||
value={end_timestamp}
|
||||
type='dateTime'
|
||||
name='end_timestamp'
|
||||
onChange={(value) => handleInputChange(value, 'end_timestamp')}
|
||||
<Form.Input
|
||||
field='group'
|
||||
label={t('分组')}
|
||||
value={group}
|
||||
placeholder={t('可选值')}
|
||||
name='group'
|
||||
onChange={(value) => handleInputChange(value, 'group')}
|
||||
/>
|
||||
{isAdminUser && (
|
||||
<>
|
||||
<Form.Input
|
||||
field='channel'
|
||||
label={t('渠道 ID')}
|
||||
style={{ width: 176 }}
|
||||
value={channel}
|
||||
placeholder={t('可选值')}
|
||||
name='channel'
|
||||
@@ -735,7 +781,6 @@ const LogsTable = () => {
|
||||
<Form.Input
|
||||
field='username'
|
||||
label={t('用户名称')}
|
||||
style={{ width: 176 }}
|
||||
value={username}
|
||||
placeholder={t('可选值')}
|
||||
name='username'
|
||||
|
||||
@@ -1234,5 +1234,6 @@
|
||||
"应用更改": "Apply changes",
|
||||
"更多": "Expand more",
|
||||
"个模型": "models",
|
||||
"可用模型": "Available models"
|
||||
"可用模型": "Available models",
|
||||
"时间范围": "Time range"
|
||||
}
|
||||
@@ -193,14 +193,16 @@ const EditChannel = (props) => {
|
||||
|
||||
|
||||
const fetchUpstreamModelList = async (name) => {
|
||||
if (inputs['type'] !== 1) {
|
||||
showError(t('仅支持 OpenAI 接口格式'));
|
||||
return;
|
||||
}
|
||||
// if (inputs['type'] !== 1) {
|
||||
// showError(t('仅支持 OpenAI 接口格式'));
|
||||
// return;
|
||||
// }
|
||||
setLoading(true);
|
||||
const models = inputs['models'] || [];
|
||||
let err = false;
|
||||
|
||||
if (isEdit) {
|
||||
// 如果是编辑模式,使用已有的channel id获取模型列表
|
||||
const res = await API.get('/api/channel/fetch_models/' + channelId);
|
||||
if (res.data && res.data?.success) {
|
||||
models.push(...res.data.data);
|
||||
@@ -208,30 +210,29 @@ const EditChannel = (props) => {
|
||||
err = true;
|
||||
}
|
||||
} else {
|
||||
// 如果是新建模式,通过后端代理获取模型列表
|
||||
if (!inputs?.['key']) {
|
||||
showError(t('请填写密钥'));
|
||||
err = true;
|
||||
} else {
|
||||
try {
|
||||
const host = new URL((inputs['base_url'] || 'https://api.openai.com'));
|
||||
|
||||
const url = `https://${host.hostname}/v1/models`;
|
||||
const key = inputs['key'];
|
||||
const res = await axios.get(url, {
|
||||
headers: {
|
||||
'Authorization': `Bearer ${key}`
|
||||
}
|
||||
const res = await API.post('/api/channel/fetch_models', {
|
||||
base_url: inputs['base_url'],
|
||||
key: inputs['key']
|
||||
});
|
||||
if (res.data) {
|
||||
models.push(...res.data.data.map((model) => model.id));
|
||||
|
||||
if (res.data && res.data.success) {
|
||||
models.push(...res.data.data);
|
||||
} else {
|
||||
err = true;
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error fetching models:', error);
|
||||
err = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!err) {
|
||||
handleInputChange(name, Array.from(new Set(models)));
|
||||
showSuccess(t('获取模型列表成功'));
|
||||
@@ -638,7 +639,7 @@ const EditChannel = (props) => {
|
||||
{inputs.type === 21 && (
|
||||
<>
|
||||
<div style={{ marginTop: 10 }}>
|
||||
<Typography.Text strong>知识库 ID:</Typography.Text>
|
||||
<Typography.Text strong>��识库 ID:</Typography.Text>
|
||||
</div>
|
||||
<Input
|
||||
label="知识库 ID"
|
||||
|
||||
Reference in New Issue
Block a user