Compare commits

..

20 Commits

Author SHA1 Message Date
Calcium-Ion e5e73a33f0 Merge pull request #781 from zeyugao/main
feat: Pass extra_body in OpenAI request to the backend
2025-02-24 16:29:48 +08:00
Calcium-Ion 2d15f63eaa Merge pull request #783 from Calcium-Ion/rate-limit
feat: Add model request rate limiting functionality
2025-02-24 16:29:23 +08:00
1808837298@qq.com 6f3072895a feat: Add model rate limit settings in system configuration 2025-02-24 16:27:20 +08:00
1808837298@qq.com 1763145fea feat: Add model request rate limiting functionality 2025-02-24 16:20:55 +08:00
1808837298@qq.com 66831a1bde feat: Add support for different Dify bot types and request URLs 2025-02-24 14:18:30 +08:00
1808837298@qq.com fd44ac7c0c feat: Enhance token counting and content parsing for messages 2025-02-24 14:18:15 +08:00
Elsa f5bf67c636 Pass extra_body to the backend 2025-02-24 10:52:55 +08:00
1808837298@qq.com 7becf62a7a fix: Improve 429 error logging with detailed message 2025-02-23 21:26:31 +08:00
1808837298@qq.com 40c0333eaa fix typo 2025-02-23 17:27:33 +08:00
1808837298@qq.com 65021e2e0e feat: Add thinking-to-content option in channel extra settings #780 2025-02-23 17:13:08 +08:00
1808837298@qq.com 4597816a14 feat: Add thinking-to-content conversion for stream responses 2025-02-23 17:05:57 +08:00
1808837298@qq.com 991b6f8bb0 fix: mistral 2025-02-22 16:29:48 +08:00
1808837298@qq.com 0333576bee fix: fix image ratio calculation 2025-02-22 15:50:18 +08:00
Calcium-Ion 1c35a2fd0b Merge pull request #778 from utopeadia/main
美化日志界面刷新图标
2025-02-22 15:21:28 +08:00
1808837298@qq.com 7a13f9c99a fix: Ensure correct quota warning threshold type conversion 2025-02-22 15:19:55 +08:00
1808837298@qq.com 1c20d16c49 chore: update rerank.md 2025-02-22 15:13:26 +08:00
HowieWood 4f2d44187d 进一步美化刷新图标 2025-02-22 14:18:25 +08:00
HowieWood d3b8647019 优化日志刷新图标显示 2025-02-22 14:12:49 +08:00
1808837298@qq.com 3eb186825d fix: ShouldDisableChannel 2025-02-22 02:02:03 +08:00
1808837298@qq.com d71315cfa5 fix: mistral adaptor (close #774) 2025-02-21 22:21:19 +08:00
28 changed files with 659 additions and 97 deletions
+2
View File
@@ -63,6 +63,8 @@
- Add suffix `-high` to set high reasoning effort (e.g., `o3-mini-high`)
- Add suffix `-medium` to set medium reasoning effort
- Add suffix `-low` to set low reasoning effort
17. 🔄 Thinking to content option `thinking_to_content` in `Channel->Edit->Channel Extra Settings`, default is `false`, when `true`, the `reasoning_content` of the thinking content will be converted to `<think>` tags and concatenated to the content returned.
18. 🔄 Model rate limit, support setting total request limit and successful request limit in `System Settings->Rate Limit Settings`
## Model Support
This version additionally supports:
+2
View File
@@ -69,6 +69,8 @@
- 添加后缀 `-high` 设置为 high reasoning effort (例如: `o3-mini-high`)
- 添加后缀 `-medium` 设置为 medium reasoning effort (例如: `o3-mini-medium`)
- 添加后缀 `-low` 设置为 low reasoning effort (例如: `o3-mini-low`)
18. 🔄 思考转内容,支持在 `渠道-编辑-渠道额外设置` 中设置 `thinking_to_content` 选项,默认`false`,开启后会将思考内容`reasoning_content`转换为`<think>`标签拼接到内容中返回。
19. 🔄 模型限流,支持在 `系统设置-速率限制设置` 中设置模型限流,支持设置总请求数限制和成功请求数限制
## 模型支持
此版本额外支持以下模型:
+1 -1
View File
@@ -13,7 +13,7 @@ Request:
```json
{
"model": "rerank-multilingual-v3.0",
"model": "jina-reranker-v2-base-multilingual",
"query": "What is the capital of the United States?",
"top_n": 3,
"documents": [
+1 -1
View File
@@ -276,7 +276,7 @@ var ChannelBaseURLs = []string{
"https://api.cohere.ai", //34
"https://api.minimax.chat", //35
"", //36
"", //37
"https://api.dify.ai", //37
"https://api.jina.ai", //38
"https://api.cloudflare.com", //39
"https://api.siliconflow.cn", //40
+3 -2
View File
@@ -1,6 +1,7 @@
package constant
var (
ForceFormat = "force_format" // ForceFormat 强制格式化为OpenAI格式
ChanelSettingProxy = "proxy" // Proxy 代理
ForceFormat = "force_format" // ForceFormat 强制格式化为OpenAI格式
ChanelSettingProxy = "proxy" // Proxy 代理
ChannelSettingThinkingToContent = "thinking_to_content" // ThinkingToContent
)
+1
View File
@@ -85,6 +85,7 @@ func Relay(c *gin.Context) {
if openaiErr != nil {
if openaiErr.StatusCode == http.StatusTooManyRequests {
common.LogError(c, fmt.Sprintf("origin 429 error: %s", openaiErr.Error.Message))
openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
}
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
+5 -5
View File
@@ -913,11 +913,11 @@ func TopUp(c *gin.Context) {
}
type UpdateUserSettingRequest struct {
QuotaWarningType string `json:"notify_type"`
QuotaWarningThreshold int `json:"quota_warning_threshold"`
WebhookUrl string `json:"webhook_url,omitempty"`
WebhookSecret string `json:"webhook_secret,omitempty"`
NotificationEmail string `json:"notification_email,omitempty"`
QuotaWarningType string `json:"notify_type"`
QuotaWarningThreshold float64 `json:"quota_warning_threshold"`
WebhookUrl string `json:"webhook_url,omitempty"`
WebhookSecret string `json:"webhook_secret,omitempty"`
NotificationEmail string `json:"notification_email,omitempty"`
}
func UpdateUserSetting(c *gin.Context) {
+5
View File
@@ -10,6 +10,10 @@
- 用于配置网络代理
- 类型为字符串,填写代理地址(例如 socks5 协议的代理地址)
3. thinking_to_content
- 用于标识是否将思考内容`reasoning_conetnt`转换为`<think>`标签拼接到内容中返回
- 类型为布尔值,设置为 true 时启用思考内容转换
--------------------------------------------------------------
## JSON 格式示例
@@ -19,6 +23,7 @@
```json
{
"force_format": true,
"thinking_to_content": true,
"proxy": "socks5://xxxxxxx"
}
```
+20 -3
View File
@@ -1,6 +1,9 @@
package dto
import "encoding/json"
import (
"encoding/json"
"strings"
)
type ResponseFormat struct {
Type string `json:"type,omitempty"`
@@ -47,6 +50,7 @@ type GeneralOpenAIRequest struct {
Dimensions int `json:"dimensions,omitempty"`
Modalities any `json:"modalities,omitempty"`
Audio any `json:"audio,omitempty"`
ExtraBody any `json:"extra_body,omitempty"`
}
type OpenAITools struct {
@@ -101,7 +105,7 @@ type Message struct {
type MediaContent struct {
Type string `json:"type"`
Text string `json:"text"`
Text string `json:"text,omitempty"`
ImageUrl any `json:"image_url,omitempty"`
InputAudio any `json:"input_audio,omitempty"`
}
@@ -153,11 +157,24 @@ func (m *Message) StringContent() string {
if m.parsedStringContent != nil {
return *m.parsedStringContent
}
var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
m.parsedStringContent = &stringContent
return stringContent
}
return string(m.Content)
contentStr := new(strings.Builder)
arrayContent := m.ParseContent()
for _, content := range arrayContent {
if content.Type == ContentTypeText {
contentStr.WriteString(content.Text)
}
}
stringContent = contentStr.String()
m.parsedStringContent = &stringContent
return stringContent
}
func (m *Message) SetStringContent(content string) {
+18
View File
@@ -86,6 +86,10 @@ func (c *ChatCompletionsStreamResponseChoiceDelta) GetReasoningContent() string
return *c.ReasoningContent
}
func (c *ChatCompletionsStreamResponseChoiceDelta) SetReasoningContent(s string) {
c.ReasoningContent = &s
}
type ToolCall struct {
// Index is not nil only in chat completion chunk object
Index *int `json:"index,omitempty"`
@@ -116,6 +120,20 @@ type ChatCompletionsStreamResponse struct {
Usage *Usage `json:"usage"`
}
func (c *ChatCompletionsStreamResponse) Copy() *ChatCompletionsStreamResponse {
choices := make([]ChatCompletionsStreamResponseChoice, len(c.Choices))
copy(choices, c.Choices)
return &ChatCompletionsStreamResponse{
Id: c.Id,
Object: c.Object,
Created: c.Created,
Model: c.Model,
SystemFingerprint: c.SystemFingerprint,
Choices: choices,
Usage: c.Usage,
}
}
func (c *ChatCompletionsStreamResponse) GetSystemFingerprint() string {
if c.SystemFingerprint == nil {
return ""
+172
View File
@@ -0,0 +1,172 @@
package middleware
import (
"context"
"fmt"
"net/http"
"one-api/common"
"one-api/setting"
"strconv"
"time"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
)
const (
ModelRequestRateLimitCountMark = "MRRL"
ModelRequestRateLimitSuccessCountMark = "MRRLS"
)
// 检查Redis中的请求限制
func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, maxCount int, duration int64) (bool, error) {
// 如果maxCount为0,表示不限制
if maxCount == 0 {
return true, nil
}
// 获取当前计数
length, err := rdb.LLen(ctx, key).Result()
if err != nil {
return false, err
}
// 如果未达到限制,允许请求
if length < int64(maxCount) {
return true, nil
}
// 检查时间窗口
oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
oldTime, err := time.Parse(timeFormat, oldTimeStr)
if err != nil {
return false, err
}
nowTimeStr := time.Now().Format(timeFormat)
nowTime, err := time.Parse(timeFormat, nowTimeStr)
if err != nil {
return false, err
}
// 如果在时间窗口内已达到限制,拒绝请求
subTime := nowTime.Sub(oldTime).Seconds()
if int64(subTime) < duration {
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
return false, nil
}
return true, nil
}
// 记录Redis请求
func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxCount int) {
// 如果maxCount为0,不记录请求
if maxCount == 0 {
return
}
now := time.Now().Format(timeFormat)
rdb.LPush(ctx, key, now)
rdb.LTrim(ctx, key, 0, int64(maxCount-1))
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
}
// Redis限流处理器
func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
return func(c *gin.Context) {
userId := strconv.Itoa(c.GetInt("id"))
ctx := context.Background()
rdb := common.RDB
// 1. 检查总请求数限制(当totalMaxCount为0时会自动跳过)
totalKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitCountMark, userId)
allowed, err := checkRedisRateLimit(ctx, rdb, totalKey, totalMaxCount, duration)
if err != nil {
fmt.Println("检查总请求数限制失败:", err.Error())
abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
return
}
if !allowed {
abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
}
// 2. 检查成功请求数限制
successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId)
allowed, err = checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration)
if err != nil {
fmt.Println("检查成功请求数限制失败:", err.Error())
abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
return
}
if !allowed {
abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到请求数限制:%d分钟内最多请求%d次", setting.ModelRequestRateLimitDurationMinutes, successMaxCount))
return
}
// 3. 记录总请求(当totalMaxCount为0时会自动跳过)
recordRedisRequest(ctx, rdb, totalKey, totalMaxCount)
// 4. 处理请求
c.Next()
// 5. 如果请求成功,记录成功请求
if c.Writer.Status() < 400 {
recordRedisRequest(ctx, rdb, successKey, successMaxCount)
}
}
}
// 内存限流处理器
func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration)
return func(c *gin.Context) {
userId := strconv.Itoa(c.GetInt("id"))
totalKey := ModelRequestRateLimitCountMark + userId
successKey := ModelRequestRateLimitSuccessCountMark + userId
// 1. 检查总请求数限制(当totalMaxCount为0时跳过)
if totalMaxCount > 0 && !inMemoryRateLimiter.Request(totalKey, totalMaxCount, duration) {
c.Status(http.StatusTooManyRequests)
c.Abort()
return
}
// 2. 检查成功请求数限制
// 使用一个临时key来检查限制,这样可以避免实际记录
checkKey := successKey + "_check"
if !inMemoryRateLimiter.Request(checkKey, successMaxCount, duration) {
c.Status(http.StatusTooManyRequests)
c.Abort()
return
}
// 3. 处理请求
c.Next()
// 4. 如果请求成功,记录到实际的成功请求计数中
if c.Writer.Status() < 400 {
inMemoryRateLimiter.Request(successKey, successMaxCount, duration)
}
}
}
// ModelRequestRateLimit 模型请求限流中间件
func ModelRequestRateLimit() func(c *gin.Context) {
// 如果未启用限流,直接放行
if !setting.ModelRequestRateLimitEnabled {
return defNext
}
// 计算限流参数
duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60)
totalMaxCount := setting.ModelRequestRateLimitCount
successMaxCount := setting.ModelRequestRateLimitSuccessCount
// 根据存储类型选择限流处理器
if common.RedisEnabled {
return redisRateLimitHandler(duration, totalMaxCount, successMaxCount)
} else {
return memoryRateLimitHandler(duration, totalMaxCount, successMaxCount)
}
}
+13
View File
@@ -85,6 +85,9 @@ func InitOptionMap() {
common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee)
common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
common.OptionMap["ShouldPreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount)
common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
@@ -105,6 +108,7 @@ func InitOptionMap() {
common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(setting.MjActionCheckSuccessEnabled)
common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(setting.CheckSensitiveEnabled)
common.OptionMap["DemoSiteEnabled"] = strconv.FormatBool(setting.DemoSiteEnabled)
common.OptionMap["ModelRequestRateLimitEnabled"] = strconv.FormatBool(setting.ModelRequestRateLimitEnabled)
common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(setting.CheckSensitiveOnPromptEnabled)
//common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(setting.StopOnSensitiveEnabled)
@@ -226,6 +230,9 @@ func updateOptionMap(key string, value string) (err error) {
setting.DemoSiteEnabled = boolValue
case "CheckSensitiveOnPromptEnabled":
setting.CheckSensitiveOnPromptEnabled = boolValue
case "ModelRequestRateLimitEnabled":
setting.ModelRequestRateLimitEnabled = boolValue
//case "CheckSensitiveOnCompletionEnabled":
// constant.CheckSensitiveOnCompletionEnabled = boolValue
case "StopOnSensitiveEnabled":
@@ -308,6 +315,12 @@ func updateOptionMap(key string, value string) (err error) {
common.QuotaRemindThreshold, _ = strconv.Atoi(value)
case "ShouldPreConsumedQuota":
common.PreConsumedQuota, _ = strconv.Atoi(value)
case "ModelRequestRateLimitCount":
setting.ModelRequestRateLimitCount, _ = strconv.Atoi(value)
case "ModelRequestRateLimitDurationMinutes":
setting.ModelRequestRateLimitDurationMinutes, _ = strconv.Atoi(value)
case "ModelRequestRateLimitSuccessCount":
setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value)
case "RetryTimes":
common.RetryTimes, _ = strconv.Atoi(value)
case "DataExportInterval":
+28 -2
View File
@@ -9,9 +9,18 @@ import (
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"strings"
)
const (
BotTypeChatFlow = 1 // chatflow default
BotTypeAgent = 2
BotTypeWorkFlow = 3
BotTypeCompletion = 4
)
type Adaptor struct {
BotType int
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
@@ -25,10 +34,28 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
if strings.HasPrefix(info.UpstreamModelName, "agent") {
a.BotType = BotTypeAgent
} else if strings.HasPrefix(info.UpstreamModelName, "workflow") {
a.BotType = BotTypeWorkFlow
} else if strings.HasPrefix(info.UpstreamModelName, "chat") {
a.BotType = BotTypeCompletion
} else {
a.BotType = BotTypeChatFlow
}
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil
switch a.BotType {
case BotTypeWorkFlow:
return fmt.Sprintf("%s/v1/workflows/run", info.BaseUrl), nil
case BotTypeCompletion:
return fmt.Sprintf("%s/v1/completion-messages", info.BaseUrl), nil
case BotTypeAgent:
fallthrough
default:
return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil
}
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
@@ -53,7 +80,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
+1 -4
View File
@@ -41,9 +41,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
if request == nil {
return nil, errors.New("request is nil")
}
mistralReq := requestOpenAI2Mistral(*request)
//common.LogJson(c, "body", mistralReq)
return mistralReq, nil
return requestOpenAI2Mistral(request), nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
@@ -55,7 +53,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
+8 -12
View File
@@ -1,25 +1,21 @@
package mistral
import (
"encoding/json"
"one-api/dto"
)
func requestOpenAI2Mistral(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
func requestOpenAI2Mistral(request *dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
messages := make([]dto.Message, 0, len(request.Messages))
for _, message := range request.Messages {
if !message.IsStringContent() {
mediaMessages := message.ParseContent()
for j, mediaMessage := range mediaMessages {
if mediaMessage.Type == dto.ContentTypeImageURL {
imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
mediaMessage.ImageUrl = imageUrl.Url
mediaMessages[j] = mediaMessage
}
mediaMessages := message.ParseContent()
for j, mediaMessage := range mediaMessages {
if mediaMessage.Type == dto.ContentTypeImageURL {
imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
mediaMessage.ImageUrl = imageUrl.Url
mediaMessages[j] = mediaMessage
}
messageRaw, _ := json.Marshal(mediaMessages)
message.Content = messageRaw
}
message.SetMediaContent(mediaMessages)
messages = append(messages, dto.Message{
Role: message.Role,
Content: message.Content,
+66 -18
View File
@@ -5,10 +5,6 @@ import (
"bytes"
"encoding/json"
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
"io"
"math"
"mime/multipart"
@@ -23,21 +19,66 @@ import (
"strings"
"sync"
"time"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
)
func sendStreamData(c *gin.Context, data string, forceFormat bool) error {
func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
if data == "" {
return nil
}
if forceFormat {
var lastStreamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(data), &lastStreamResponse); err != nil {
return err
}
if !forceFormat && !thinkToContent {
return service.StringData(c, data)
}
var lastStreamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(data), &lastStreamResponse); err != nil {
return err
}
if !thinkToContent {
return service.ObjectData(c, lastStreamResponse)
}
return service.StringData(c, data)
// Handle think to content conversion
if info.IsFirstResponse {
response := lastStreamResponse.Copy()
for i := range response.Choices {
response.Choices[i].Delta.SetContentString("<think>\n")
response.Choices[i].Delta.SetReasoningContent("")
}
service.ObjectData(c, response)
}
if lastStreamResponse.Choices == nil || len(lastStreamResponse.Choices) == 0 {
return service.ObjectData(c, lastStreamResponse)
}
// Process each choice
for i, choice := range lastStreamResponse.Choices {
// Handle transition from thinking to content
if len(choice.Delta.GetContentString()) > 0 && !info.SendLastReasoningResponse {
response := lastStreamResponse.Copy()
for j := range response.Choices {
response.Choices[j].Delta.SetContentString("\n</think>")
response.Choices[j].Delta.SetReasoningContent("")
}
info.SendLastReasoningResponse = true
service.ObjectData(c, response)
}
// Convert reasoning content to regular content
if len(choice.Delta.GetReasoningContent()) > 0 {
lastStreamResponse.Choices[i].Delta.SetContentString(choice.Delta.GetReasoningContent())
lastStreamResponse.Choices[i].Delta.SetReasoningContent("")
}
}
return service.ObjectData(c, lastStreamResponse)
}
func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
@@ -56,11 +97,14 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
var usage = &dto.Usage{}
var streamItems []string // store stream items
var forceFormat bool
var thinkToContent bool
if info.ChannelType == common.ChannelTypeCustom {
if forceFmt, ok := info.ChannelSetting["force_format"].(bool); ok {
forceFormat = forceFmt
}
if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
forceFormat = forceFmt
}
if think2Content, ok := info.ChannelSetting[constant.ChannelSettingThinkingToContent].(bool); ok {
thinkToContent = think2Content
}
toolCount := 0
@@ -84,9 +128,12 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
)
gopool.Go(func() {
for scanner.Scan() {
info.SetFirstResponseTime()
//info.SetFirstResponseTime()
ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
data := scanner.Text()
if common.DebugEnabled {
println(data)
}
if len(data) < 6 { // ignore blank line or wrong format
continue
}
@@ -98,10 +145,11 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
data = strings.TrimSpace(data)
if !strings.HasPrefix(data, "[DONE]") {
if lastStreamData != "" {
err := sendStreamData(c, lastStreamData, forceFormat)
err := sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
if err != nil {
common.LogError(c, "streaming error: "+err.Error())
}
info.SetFirstResponseTime()
}
lastStreamData = data
streamItems = append(streamItems, data)
@@ -141,7 +189,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
}
}
if shouldSendLastResp {
sendStreamData(c, lastStreamData, forceFormat)
sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
}
// 计算token
+21 -19
View File
@@ -13,23 +13,24 @@ import (
)
type RelayInfo struct {
ChannelType int
ChannelId int
TokenId int
TokenKey string
UserId int
Group string
TokenUnlimited bool
StartTime time.Time
FirstResponseTime time.Time
setFirstResponse bool
ApiType int
IsStream bool
IsPlayground bool
UsePrice bool
RelayMode int
UpstreamModelName string
OriginModelName string
ChannelType int
ChannelId int
TokenId int
TokenKey string
UserId int
Group string
TokenUnlimited bool
StartTime time.Time
FirstResponseTime time.Time
IsFirstResponse bool
SendLastReasoningResponse bool
ApiType int
IsStream bool
IsPlayground bool
UsePrice bool
RelayMode int
UpstreamModelName string
OriginModelName string
//RecodeModelName string
RequestURLPath string
ApiVersion string
@@ -88,6 +89,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
apiType, _ := relayconstant.ChannelType2APIType(channelType)
info := &RelayInfo{
IsFirstResponse: true,
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
BaseUrl: c.GetString("base_url"),
RequestURLPath: c.Request.URL.String(),
@@ -139,9 +141,9 @@ func (info *RelayInfo) SetIsStream(isStream bool) {
}
func (info *RelayInfo) SetFirstResponseTime() {
if !info.setFirstResponse {
if info.IsFirstResponse {
info.FirstResponseTime = time.Now()
info.setFirstResponse = true
info.IsFirstResponse = false
}
}
+2 -2
View File
@@ -115,8 +115,8 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
}
}
imageRatio := priceData.ModelPrice * sizeRatio * qualityRatio * float64(imageRequest.N)
quota := int(imageRatio * priceData.GroupRatio * common.QuotaPerUnit)
priceData.ModelPrice *= sizeRatio * qualityRatio * float64(imageRequest.N)
quota := int(priceData.ModelPrice * priceData.GroupRatio * common.QuotaPerUnit)
if userQuota-quota < 0 {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), "insufficient_user_quota", http.StatusForbidden)
+1
View File
@@ -24,6 +24,7 @@ func SetRelayRouter(router *gin.Engine) {
}
relayV1Router := router.Group("/v1")
relayV1Router.Use(middleware.TokenAuth())
relayV1Router.Use(middleware.ModelRequestRateLimit())
{
// WebSocket 路由
wsRouter := relayV1Router.Group("")
+21 -23
View File
@@ -78,6 +78,9 @@ func getTokenEncoder(model string) *tiktoken.Tiktoken {
}
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
if text == "" {
return 0
}
return len(tokenEncoder.Encode(text, nil, nil))
}
@@ -282,30 +285,25 @@ func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, mod
tokenNum += tokensPerMessage
tokenNum += getTokenNum(tokenEncoder, message.Role)
if len(message.Content) > 0 {
if message.IsStringContent() {
stringContent := message.StringContent()
tokenNum += getTokenNum(tokenEncoder, stringContent)
if message.Name != nil {
tokenNum += tokensPerName
tokenNum += getTokenNum(tokenEncoder, *message.Name)
}
} else {
arrayContent := message.ParseContent()
for _, m := range arrayContent {
if m.Type == dto.ContentTypeImageURL {
imageUrl := m.ImageUrl.(dto.MessageImageUrl)
imageTokenNum, err := getImageToken(info, &imageUrl, model, stream)
if err != nil {
return 0, err
}
tokenNum += imageTokenNum
log.Printf("image token num: %d", imageTokenNum)
} else if m.Type == dto.ContentTypeInputAudio {
// TODO: 音频token数量计算
tokenNum += 100
} else {
tokenNum += getTokenNum(tokenEncoder, m.Text)
if message.Name != nil {
tokenNum += tokensPerName
tokenNum += getTokenNum(tokenEncoder, *message.Name)
}
arrayContent := message.ParseContent()
for _, m := range arrayContent {
if m.Type == dto.ContentTypeImageURL {
imageUrl := m.ImageUrl.(dto.MessageImageUrl)
imageTokenNum, err := getImageToken(info, &imageUrl, model, stream)
if err != nil {
return 0, err
}
tokenNum += imageTokenNum
log.Printf("image token num: %d", imageTokenNum)
} else if m.Type == dto.ContentTypeInputAudio {
// TODO: 音频token数量计算
tokenNum += 100
} else {
tokenNum += getTokenNum(tokenEncoder, m.Text)
}
}
}
+1
View File
@@ -23,6 +23,7 @@ func AutomaticDisableKeywordsFromString(s string) {
ak := strings.Split(s, "\n")
for _, k := range ak {
k = strings.TrimSpace(k)
k = strings.ToLower(k)
if k != "" {
AutomaticDisableKeywords = append(AutomaticDisableKeywords, k)
}
+6
View File
@@ -0,0 +1,6 @@
package setting
var ModelRequestRateLimitEnabled = false
var ModelRequestRateLimitDurationMinutes = 1
var ModelRequestRateLimitCount = 0
var ModelRequestRateLimitSuccessCount = 1000
+1 -1
View File
@@ -192,7 +192,7 @@ const LogsTable = () => {
onClick={(event) => {
copyText(event, record.model_name).then(r => {});
}}
suffixIcon={<IconRefresh />}
suffixIcon={<IconRefresh style={{width: '0.8em', height: '0.8em', opacity: 0.6}} />}
>
{' '}{record.model_name}{' '}
</Tag>
+1 -1
View File
@@ -330,7 +330,7 @@ const PersonalSetting = () => {
try {
const res = await API.put('/api/user/setting', {
notify_type: notificationSettings.warningType,
quota_warning_threshold: notificationSettings.warningThreshold,
quota_warning_threshold: parseFloat(notificationSettings.warningThreshold),
webhook_url: notificationSettings.webhookUrl,
webhook_secret: notificationSettings.webhookSecret,
notification_email: notificationSettings.notificationEmail
+80
View File
@@ -0,0 +1,80 @@
import React, { useEffect, useState } from 'react';
import { Card, Spin, Tabs } from '@douyinfe/semi-ui';
import SettingsGeneral from '../pages/Setting/Operation/SettingsGeneral.js';
import SettingsDrawing from '../pages/Setting/Operation/SettingsDrawing.js';
import SettingsSensitiveWords from '../pages/Setting/Operation/SettingsSensitiveWords.js';
import SettingsLog from '../pages/Setting/Operation/SettingsLog.js';
import SettingsDataDashboard from '../pages/Setting/Operation/SettingsDataDashboard.js';
import SettingsMonitoring from '../pages/Setting/Operation/SettingsMonitoring.js';
import SettingsCreditLimit from '../pages/Setting/Operation/SettingsCreditLimit.js';
import SettingsMagnification from '../pages/Setting/Operation/SettingsMagnification.js';
import ModelSettingsVisualEditor from '../pages/Setting/Operation/ModelSettingsVisualEditor.js';
import GroupRatioSettings from '../pages/Setting/Operation/GroupRatioSettings.js';
import ModelRatioSettings from '../pages/Setting/Operation/ModelRatioSettings.js';
import { API, showError, showSuccess } from '../helpers';
import SettingsChats from '../pages/Setting/Operation/SettingsChats.js';
import { useTranslation } from 'react-i18next';
import RequestRateLimit from '../pages/Setting/RateLimit/SettingsRequestRateLimit.js';
const RateLimitSetting = () => {
const { t } = useTranslation();
let [inputs, setInputs] = useState({
ModelRequestRateLimitEnabled: false,
ModelRequestRateLimitCount: 0,
ModelRequestRateLimitSuccessCount: 1000,
ModelRequestRateLimitDurationMinutes: 1,
});
let [loading, setLoading] = useState(false);
const getOptions = async () => {
const res = await API.get('/api/option/');
const { success, message, data } = res.data;
if (success) {
let newInputs = {};
data.forEach((item) => {
if (
item.key.endsWith('Enabled')
) {
newInputs[item.key] = item.value === 'true' ? true : false;
} else {
newInputs[item.key] = item.value;
}
});
setInputs(newInputs);
} else {
showError(message);
}
};
async function onRefresh() {
try {
setLoading(true);
await getOptions();
// showSuccess('刷新成功');
} catch (error) {
showError('刷新失败');
} finally {
setLoading(false);
}
}
useEffect(() => {
onRefresh();
}, []);
return (
<>
<Spin spinning={loading} size='large'>
{/* AI请求速率限制 */}
<Card style={{ marginTop: '10px' }}>
<RequestRateLimit options={inputs} refresh={onRefresh} />
</Card>
</Spin>
</>
);
};
export default RateLimitSetting;
+14 -3
View File
@@ -856,7 +856,7 @@
"IP黑名单": "IP blacklist",
"不允许的IP,一行一个": "IPs not allowed, one per line",
"请选择该渠道所支持的模型": "Please select the model supported by this channel",
"次": "Second-rate",
"次": "times",
"达到限速报错内容": "Error content when the speed limit is reached",
"不填则使用默认报错": "If not filled in, the default error will be reported.",
"Midjouney 设置 (可选)": "Midjouney settings (optional)",
@@ -1270,5 +1270,16 @@
"设置用于接收额度预警的邮箱地址,不填则使用账号绑定的邮箱": "Set the email address for receiving quota warning notifications, if not set, the email address bound to the account will be used",
"留空则使用账号绑定的邮箱": "If left blank, the email address bound to the account will be used",
"代理站地址": "Base URL",
"对于官方渠道,new-api已经内置地址,除非是第三方代理站点或者Azure的特殊接入地址,否则不需要填写": "For official channels, the new-api has a built-in address. Unless it is a third-party proxy site or a special Azure access address, there is no need to fill it in"
}
"对于官方渠道,new-api已经内置地址,除非是第三方代理站点或者Azure的特殊接入地址,否则不需要填写": "For official channels, the new-api has a built-in address. Unless it is a third-party proxy site or a special Azure access address, there is no need to fill it in",
"渠道额外设置": "Channel extra settings",
"模型请求速率限制": "Model request rate limit",
"启用用户模型请求速率限制(可能会影响高并发性能)": "Enable user model request rate limit (may affect high concurrency performance)",
"限制周期": "Limit period",
"用户每周期最多请求次数": "User max request times per period",
"用户每周期最多请求完成次数": "User max successful request times per period",
"包括失败请求的次数,0代表不限制": "Including failed request times, 0 means no limit",
"频率限制的周期(分钟)": "Rate limit period (minutes)",
"只包括请求成功的次数": "Only include successful request times",
"保存模型速率限制": "Save model rate limit settings",
"速率限制设置": "Rate limit settings"
}
@@ -0,0 +1,159 @@
import React, { useEffect, useState, useRef } from 'react';
import { Button, Col, Form, Row, Spin } from '@douyinfe/semi-ui';
import {
compareObjects,
API,
showError,
showSuccess,
showWarning,
} from '../../../helpers';
import { useTranslation } from 'react-i18next';
export default function RequestRateLimit(props) {
const { t } = useTranslation();
const [loading, setLoading] = useState(false);
const [inputs, setInputs] = useState({
ModelRequestRateLimitEnabled: false,
ModelRequestRateLimitCount: -1,
ModelRequestRateLimitSuccessCount: 1000,
ModelRequestRateLimitDurationMinutes: 1
});
const refForm = useRef();
const [inputsRow, setInputsRow] = useState(inputs);
function onSubmit() {
const updateArray = compareObjects(inputs, inputsRow);
if (!updateArray.length) return showWarning(t('你似乎并没有修改什么'));
const requestQueue = updateArray.map((item) => {
let value = '';
if (typeof inputs[item.key] === 'boolean') {
value = String(inputs[item.key]);
} else {
value = inputs[item.key];
}
return API.put('/api/option/', {
key: item.key,
value,
});
});
setLoading(true);
Promise.all(requestQueue)
.then((res) => {
if (requestQueue.length === 1) {
if (res.includes(undefined)) return;
} else if (requestQueue.length > 1) {
if (res.includes(undefined)) return showError(t('部分保存失败,请重试'));
}
showSuccess(t('保存成功'));
props.refresh();
})
.catch(() => {
showError(t('保存失败,请重试'));
})
.finally(() => {
setLoading(false);
});
}
useEffect(() => {
const currentInputs = {};
for (let key in props.options) {
if (Object.keys(inputs).includes(key)) {
currentInputs[key] = props.options[key];
}
}
setInputs(currentInputs);
setInputsRow(structuredClone(currentInputs));
refForm.current.setValues(currentInputs);
}, [props.options]);
return (
<>
<Spin spinning={loading}>
<Form
values={inputs}
getFormApi={(formAPI) => (refForm.current = formAPI)}
style={{ marginBottom: 15 }}
>
<Form.Section text={t('模型请求速率限制')}>
<Row gutter={16}>
<Col span={8}>
<Form.Switch
field={'ModelRequestRateLimitEnabled'}
label={t('启用用户模型请求速率限制(可能会影响高并发性能)')}
size='default'
checkedText=''
uncheckedText=''
onChange={(value) => {
setInputs({
...inputs,
ModelRequestRateLimitEnabled: value,
});
}}
/>
</Col>
</Row>
<Row>
<Col span={8}>
<Form.InputNumber
label={t('限制周期')}
step={1}
min={0}
suffix={t('分钟')}
extraText={t('频率限制的周期(分钟)')}
field={'ModelRequestRateLimitDurationMinutes'}
onChange={(value) =>
setInputs({
...inputs,
ModelRequestRateLimitDurationMinutes: String(value),
})
}
/>
</Col>
</Row>
<Row>
<Col span={8}>
<Form.InputNumber
label={t('用户每周期最多请求次数')}
step={1}
min={0}
suffix={t('次')}
extraText={t('包括失败请求的次数,0代表不限制')}
field={'ModelRequestRateLimitCount'}
onChange={(value) =>
setInputs({
...inputs,
ModelRequestRateLimitCount: String(value),
})
}
/>
</Col>
<Col span={8}>
<Form.InputNumber
label={t('用户每周期最多请求完成次数')}
step={1}
min={1}
suffix={t('次')}
extraText={t('只包括请求成功的次数')}
field={'ModelRequestRateLimitSuccessCount'}
onChange={(value) =>
setInputs({
...inputs,
ModelRequestRateLimitSuccessCount: String(value),
})
}
/>
</Col>
</Row>
<Row>
<Button size='default' onClick={onSubmit}>
{t('保存模型速率限制')}
</Button>
</Row>
</Form.Section>
</Form>
</Spin>
</>
);
}
+6
View File
@@ -8,6 +8,7 @@ import { isRoot } from '../../helpers';
import OtherSetting from '../../components/OtherSetting';
import PersonalSetting from '../../components/PersonalSetting';
import OperationSetting from '../../components/OperationSetting';
import RateLimitSetting from '../../components/RateLimitSetting.js';
const Setting = () => {
const { t } = useTranslation();
@@ -28,6 +29,11 @@ const Setting = () => {
content: <OperationSetting />,
itemKey: 'operation',
});
panes.push({
tab: t('速率限制设置'),
content: <RateLimitSetting />,
itemKey: 'ratelimit',
});
panes.push({
tab: t('系统设置'),
content: <SystemSetting />,