Compare commits

...

8 Commits

Author SHA1 Message Date
1808837298@qq.com 8101cd3ce3 feat: Add reasoning content support in OpenAI response handling 2025-02-21 18:52:51 +08:00
1808837298@qq.com 4a49d6c795 refactor: Improve message content parsing with robust type handling 2025-02-21 18:27:43 +08:00
1808837298@qq.com 4194f4bd21 refactor: Improve message content handling and quota error responses 2025-02-21 18:18:21 +08:00
1808837298@qq.com e1784f8981 refactor: Optimize sensitive word detection and text processing 2025-02-21 17:05:35 +08:00
1808837298@qq.com 78f9a30c39 feat: Enhance sensitive word detection with detailed logging 2025-02-21 16:57:30 +08:00
1808837298@qq.com 009333da8b refactor: Improve quota error messages with formatted quota display 2025-02-21 16:42:48 +08:00
1808837298@qq.com 23bfc06fd8 feat: Add base URL input with localized tooltip for channel configuration 2025-02-21 16:17:59 +08:00
1808837298@qq.com f64540cd1c feat: Add localization for notification and webhook settings 2025-02-21 15:36:24 +08:00
10 changed files with 185 additions and 123 deletions
+73 -48
View File
@@ -88,15 +88,15 @@ func (r GeneralOpenAIRequest) ParseInput() []string {
}
type Message struct {
Role string `json:"role"`
Content json.RawMessage `json:"content"`
// parsedContent not json field
parsedContent []MediaContent
Name *string `json:"name,omitempty"`
Prefix *bool `json:"prefix,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
ToolCalls json.RawMessage `json:"tool_calls,omitempty"`
ToolCallId string `json:"tool_call_id,omitempty"`
Role string `json:"role"`
Content json.RawMessage `json:"content"`
Name *string `json:"name,omitempty"`
Prefix *bool `json:"prefix,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
ToolCalls json.RawMessage `json:"tool_calls,omitempty"`
ToolCallId string `json:"tool_call_id,omitempty"`
parsedContent []MediaContent
parsedStringContent *string
}
type MediaContent struct {
@@ -150,6 +150,9 @@ func (m *Message) SetToolCalls(toolCalls any) {
}
func (m *Message) StringContent() string {
if m.parsedStringContent != nil {
return *m.parsedStringContent
}
var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
return stringContent
@@ -160,16 +163,24 @@ func (m *Message) StringContent() string {
func (m *Message) SetStringContent(content string) {
jsonContent, _ := json.Marshal(content)
m.Content = jsonContent
m.parsedStringContent = &content
m.parsedContent = nil
}
func (m *Message) SetMediaContent(content []MediaContent) {
jsonContent, _ := json.Marshal(content)
m.Content = jsonContent
m.parsedContent = nil
m.parsedStringContent = nil
}
func (m *Message) IsStringContent() bool {
if m.parsedStringContent != nil {
return true
}
var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
m.parsedStringContent = &stringContent
return true
}
return false
@@ -179,72 +190,86 @@ func (m *Message) ParseContent() []MediaContent {
if m.parsedContent != nil {
return m.parsedContent
}
var contentList []MediaContent
defer func() {
if len(contentList) > 0 {
m.parsedContent = contentList
}
}()
// 先尝试解析为字符串
var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
contentList = append(contentList, MediaContent{
contentList = []MediaContent{{
Type: ContentTypeText,
Text: stringContent,
})
}}
m.parsedContent = contentList
return contentList
}
var arrayContent []json.RawMessage
// 尝试解析为数组
var arrayContent []map[string]interface{}
if err := json.Unmarshal(m.Content, &arrayContent); err == nil {
for _, contentItem := range arrayContent {
var contentMap map[string]any
if err := json.Unmarshal(contentItem, &contentMap); err != nil {
contentType, ok := contentItem["type"].(string)
if !ok {
continue
}
switch contentMap["type"] {
switch contentType {
case ContentTypeText:
if subStr, ok := contentMap["text"].(string); ok {
if text, ok := contentItem["text"].(string); ok {
contentList = append(contentList, MediaContent{
Type: ContentTypeText,
Text: subStr,
Text: text,
})
}
case ContentTypeImageURL:
if subObj, ok := contentMap["image_url"].(map[string]any); ok {
detail, ok := subObj["detail"]
if ok {
subObj["detail"] = detail.(string)
} else {
subObj["detail"] = "high"
}
imageUrl := contentItem["image_url"]
switch v := imageUrl.(type) {
case string:
contentList = append(contentList, MediaContent{
Type: ContentTypeImageURL,
ImageUrl: MessageImageUrl{
Url: subObj["url"].(string),
Detail: subObj["detail"].(string),
},
})
} else if url, ok := contentMap["image_url"].(string); ok {
contentList = append(contentList, MediaContent{
Type: ContentTypeImageURL,
ImageUrl: MessageImageUrl{
Url: url,
Url: v,
Detail: "high",
},
})
case map[string]interface{}:
url, ok1 := v["url"].(string)
detail, ok2 := v["detail"].(string)
if !ok2 {
detail = "high"
}
if ok1 {
contentList = append(contentList, MediaContent{
Type: ContentTypeImageURL,
ImageUrl: MessageImageUrl{
Url: url,
Detail: detail,
},
})
}
}
case ContentTypeInputAudio:
if subObj, ok := contentMap["input_audio"].(map[string]any); ok {
contentList = append(contentList, MediaContent{
Type: ContentTypeInputAudio,
InputAudio: MessageInputAudio{
Data: subObj["data"].(string),
Format: subObj["format"].(string),
},
})
if audioData, ok := contentItem["input_audio"].(map[string]interface{}); ok {
data, ok1 := audioData["data"].(string)
format, ok2 := audioData["format"].(string)
if ok1 && ok2 {
contentList = append(contentList, MediaContent{
Type: ContentTypeInputAudio,
InputAudio: MessageInputAudio{
Data: data,
Format: format,
},
})
}
}
}
}
return contentList
}
return nil
if len(contentList) > 0 {
m.parsedContent = contentList
}
return contentList
}
+11 -3
View File
@@ -62,9 +62,10 @@ type ChatCompletionsStreamResponseChoice struct {
}
type ChatCompletionsStreamResponseChoiceDelta struct {
Content *string `json:"content,omitempty"`
Role string `json:"role,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
Content *string `json:"content,omitempty"`
ReasoningContent *string `json:"reasoning_content,omitempty"`
Role string `json:"role,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}
func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) {
@@ -78,6 +79,13 @@ func (c *ChatCompletionsStreamResponseChoiceDelta) GetContentString() string {
return *c.Content
}
func (c *ChatCompletionsStreamResponseChoiceDelta) GetReasoningContent() string {
if c.ReasoningContent == nil {
return ""
}
return *c.ReasoningContent
}
type ToolCall struct {
// Index is not nil only in chat completion chunk object
Index *int `json:"index,omitempty"`
+3 -1
View File
@@ -162,6 +162,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
//}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > toolCount {
toolCount = len(choice.Delta.ToolCalls)
@@ -182,6 +183,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
//}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > toolCount {
toolCount = len(choice.Delta.ToolCalls)
@@ -273,7 +275,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
completionTokens := 0
for _, choice := range simpleResponse.Choices {
ctkm, _ := service.CountTextToken(string(choice.Message.Content), model)
ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent, model)
completionTokens += ctkm
}
simpleResponse.Usage = dto.Usage{
+3 -1
View File
@@ -13,6 +13,7 @@ import (
"one-api/relay/helper"
"one-api/service"
"one-api/setting"
"strings"
)
func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) {
@@ -27,8 +28,9 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
return nil, errors.New("model is required")
}
if setting.ShouldCheckPromptSensitive() {
err := service.CheckSensitiveInput(audioRequest.Input)
words, err := service.CheckSensitiveInput(audioRequest.Input)
if err != nil {
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ",")))
return nil, err
}
}
+8 -17
View File
@@ -61,8 +61,9 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
// return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
//}
if setting.ShouldCheckPromptSensitive() {
err := service.CheckSensitiveInput(imageRequest.Prompt)
words, err := service.CheckSensitiveInput(imageRequest.Prompt)
if err != nil {
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ",")))
return nil, err
}
}
@@ -85,15 +86,13 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
imageRequest.Model = relayInfo.UpstreamModelName
modelPrice, success := common.GetModelPrice(imageRequest.Model, true)
if !success {
modelRatio := common.GetModelRatio(imageRequest.Model)
priceData := helper.ModelPriceHelper(c, relayInfo, 0, 0)
if !priceData.UsePrice {
// modelRatio 16 = modelPrice $0.04
// per 1 modelRatio = $0.04 / 16
modelPrice = 0.0025 * modelRatio
priceData.ModelPrice = 0.0025 * priceData.ModelRatio
}
groupRatio := setting.GetGroupRatio(relayInfo.Group)
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
sizeRatio := 1.0
@@ -116,11 +115,11 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
}
}
imageRatio := modelPrice * sizeRatio * qualityRatio * float64(imageRequest.N)
quota := int(imageRatio * groupRatio * common.QuotaPerUnit)
imageRatio := priceData.ModelPrice * sizeRatio * qualityRatio * float64(imageRequest.N)
quota := int(imageRatio * priceData.GroupRatio * common.QuotaPerUnit)
if userQuota-quota < 0 {
return service.OpenAIErrorWrapperLocal(errors.New(fmt.Sprintf("image pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, quota)), "insufficient_user_quota", http.StatusBadRequest)
return service.OpenAIErrorWrapperLocal(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), "insufficient_user_quota", http.StatusForbidden)
}
adaptor := GetAdaptor(relayInfo.ApiType)
@@ -177,14 +176,6 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
quality = "hd"
}
priceData := helper.PriceData{
UsePrice: true,
GroupRatio: groupRatio,
ModelPrice: modelPrice,
ModelRatio: 0,
ShouldPreConsumedQuota: 0,
}
logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality)
postConsumeQuota(c, relayInfo, usage, 0, userQuota, priceData, logContent)
return nil
+10 -8
View File
@@ -78,8 +78,9 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
}
if setting.ShouldCheckPromptSensitive() {
err = checkRequestSensitive(textRequest, relayInfo)
words, err := checkRequestSensitive(textRequest, relayInfo)
if err != nil {
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
}
}
@@ -219,19 +220,20 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
return promptTokens, err
}
func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) error {
func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) ([]string, error) {
var err error
var words []string
switch info.RelayMode {
case relayconstant.RelayModeChatCompletions:
err = service.CheckSensitiveMessages(textRequest.Messages)
words, err = service.CheckSensitiveMessages(textRequest.Messages)
case relayconstant.RelayModeCompletions:
err = service.CheckSensitiveInput(textRequest.Prompt)
words, err = service.CheckSensitiveInput(textRequest.Prompt)
case relayconstant.RelayModeModerations:
err = service.CheckSensitiveInput(textRequest.Input)
words, err = service.CheckSensitiveInput(textRequest.Input)
case relayconstant.RelayModeEmbeddings:
err = service.CheckSensitiveInput(textRequest.Input)
words, err = service.CheckSensitiveInput(textRequest.Input)
}
return err
return words, err
}
// 预扣费并返回用户剩余配额
@@ -244,7 +246,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
return 0, 0, service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
if userQuota-preConsumedQuota < 0 {
return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), "insufficient_user_quota", http.StatusBadRequest)
return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), "insufficient_user_quota", http.StatusForbidden)
}
if userQuota > 100*preConsumedQuota {
// 用户额度充足,判断令牌额度是否充足
+3 -3
View File
@@ -95,11 +95,11 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
quota := calculateAudioQuota(quotaInfo)
if userQuota < quota {
return errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota))
return fmt.Errorf("user quota is not enough, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota))
}
if !token.UnlimitedQuota && token.RemainQuota < quota {
return errors.New(fmt.Sprintf("令牌额度不足,剩余额度为 %d", token.RemainQuota))
return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", common.FormatQuota(token.RemainQuota), common.FormatQuota(quota))
}
err = PostConsumeQuota(relayInfo, quota, 0, false)
@@ -262,7 +262,7 @@ func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
return err
}
if !relayInfo.TokenUnlimited && token.RemainQuota < quota {
return errors.New("令牌额度不足")
return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", common.FormatQuota(token.RemainQuota), common.FormatQuota(quota))
}
err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
if err != nil {
+38 -29
View File
@@ -8,48 +8,47 @@ import (
"strings"
)
func CheckSensitiveMessages(messages []dto.Message) error {
func CheckSensitiveMessages(messages []dto.Message) ([]string, error) {
if len(messages) == 0 {
return nil, nil
}
for _, message := range messages {
if len(message.Content) > 0 {
if message.IsStringContent() {
stringContent := message.StringContent()
if ok, words := SensitiveWordContains(stringContent); ok {
return errors.New("sensitive words: " + strings.Join(words, ","))
}
arrayContent := message.ParseContent()
for _, m := range arrayContent {
if m.Type == "image_url" {
// TODO: check image url
continue
}
} else {
arrayContent := message.ParseContent()
for _, m := range arrayContent {
if m.Type == "image_url" {
// TODO: check image url
} else {
if ok, words := SensitiveWordContains(m.Text); ok {
return errors.New("sensitive words: " + strings.Join(words, ","))
}
}
// 检查 text 是否为空
if m.Text == "" {
continue
}
if ok, words := SensitiveWordContains(m.Text); ok {
return words, errors.New("sensitive words detected")
}
}
}
return nil
return nil, nil
}
func CheckSensitiveText(text string) error {
func CheckSensitiveText(text string) ([]string, error) {
if ok, words := SensitiveWordContains(text); ok {
return errors.New("sensitive words: " + strings.Join(words, ","))
return words, errors.New("sensitive words detected")
}
return nil
return nil, nil
}
func CheckSensitiveInput(input any) error {
func CheckSensitiveInput(input any) ([]string, error) {
switch v := input.(type) {
case string:
return CheckSensitiveText(v)
case []string:
text := ""
var builder strings.Builder
for _, s := range v {
text += s
builder.WriteString(s)
}
return CheckSensitiveText(text)
return CheckSensitiveText(builder.String())
}
return CheckSensitiveText(fmt.Sprintf("%v", input))
}
@@ -59,8 +58,11 @@ func SensitiveWordContains(text string) (bool, []string) {
if len(setting.SensitiveWords) == 0 {
return false, nil
}
if len(text) == 0 {
return false, nil
}
checkText := strings.ToLower(text)
return AcSearch(checkText, setting.SensitiveWords, false)
return AcSearch(checkText, setting.SensitiveWords, true)
}
// SensitiveWordReplace 敏感词替换,返回是否包含敏感词和替换后的文本
@@ -72,14 +74,21 @@ func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string,
m := InitAc(setting.SensitiveWords)
hits := m.MultiPatternSearch([]rune(checkText), returnImmediately)
if len(hits) > 0 {
words := make([]string, 0)
words := make([]string, 0, len(hits))
var builder strings.Builder
builder.Grow(len(text))
lastPos := 0
for _, hit := range hits {
pos := hit.Pos
word := string(hit.Word)
text = text[:pos] + "**###**" + text[pos+len(word):]
builder.WriteString(text[lastPos:pos])
builder.WriteString("**###**")
lastPos = pos + len(word)
words = append(words, word)
}
return true, words, text
builder.WriteString(text[lastPos:])
return true, words, builder.String()
}
return false, nil, text
}
+22 -1
View File
@@ -1249,5 +1249,26 @@
"已注销": "Logged out",
"自动禁用关键词": "Automatic disable keywords",
"一行一个,不区分大小写": "One line per keyword, not case-sensitive",
"当上游通道返回错误中包含这些关键词时(不区分大小写),自动禁用通道": "When the upstream channel returns an error containing these keywords (not case-sensitive), automatically disable the channel"
"当上游通道返回错误中包含这些关键词时(不区分大小写),自动禁用通道": "When the upstream channel returns an error containing these keywords (not case-sensitive), automatically disable the channel",
"请求并计费模型": "Request and charge model",
"实际模型": "Actual model",
"渠道信息": "Channel information",
"通知设置": "Notification settings",
"Webhook地址": "Webhook URL",
"请输入Webhook地址,例如: https://example.com/webhook": "Please enter the Webhook URL, e.g.: https://example.com/webhook",
"邮件通知": "Email notification",
"Webhook通知": "Webhook notification",
"接口凭证(可选)": "Interface credentials (optional)",
"密钥将以 Bearer 方式添加到请求头中,用于验证webhook请求的合法性": "The secret will be added to the request header as a Bearer token to verify the legitimacy of the webhook request",
"Authorization: Bearer your-secret-key": "Authorization: Bearer your-secret-key",
"额度预警阈值": "Quota warning threshold",
"当剩余额度低于此数值时,系统将通过选择的方式发送通知": "When the remaining quota is lower than this value, the system will send a notification through the selected method",
"Webhook请求结构": "Webhook request structure",
"只支持https,系统将以 POST 方式发送通知,请确保地址可以接收 POST 请求": "Only https is supported, the system will send a notification through POST, please ensure the address can receive POST requests",
"保存设置": "Save settings",
"通知邮箱": "Notification email",
"设置用于接收额度预警的邮箱地址,不填则使用账号绑定的邮箱": "Set the email address for receiving quota warning notifications, if not set, the email address bound to the account will be used",
"留空则使用账号绑定的邮箱": "If left blank, the email address bound to the account will be used",
"代理站地址": "Base URL",
"对于官方渠道,new-api已经内置地址,除非是第三方代理站点或者Azure的特殊接入地址,否则不需要填写": "For official channels, the new-api has a built-in address. Unless it is a third-party proxy site or a special Azure access address, there is no need to fill it in"
}
+14 -12
View File
@@ -540,21 +540,23 @@ const EditChannel = (props) => {
value={inputs.name}
autoComplete="new-password"
/>
{inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && inputs.type !== 36 && (
{inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && inputs.type !== 36 && inputs.type !== 45 && (
<>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>{t('BaseURL')}</Typography.Text>
<Typography.Text strong>{t('代理站地址')}</Typography.Text>
</div>
<Input
label={t('BaseURL')}
name="base_url"
placeholder={t('此项可选,用于通过代理站来进行 API 调用,末尾不要带/v1和/')}
onChange={(value) => {
handleInputChange('base_url', value);
}}
value={inputs.base_url}
autoComplete="new-password"
/>
<Tooltip content={t('对于官方渠道,new-api已经内置地址,除非是第三方代理站点或者Azure的特殊接入地址,否则不需要填写')}>
<Input
label={t('代理站地址')}
name="base_url"
placeholder={t('此项可选,用于通过代理站来进行 API 调用,末尾不要带/v1和/')}
onChange={(value) => {
handleInputChange('base_url', value);
}}
value={inputs.base_url}
autoComplete="new-password"
/>
</Tooltip>
</>
)}
<div style={{ marginTop: 10 }}>