Compare commits

...

13 Commits

Author SHA1 Message Date
Calcium-Ion 4a02c5e219 Merge pull request #1346 from QuantumNous/fix-ability
 feat(ability): enhance FixAbility function
2025-07-08 18:38:35 +08:00
CaIon beb6ea96d2 feat(ability): enhance FixAbility function 2025-07-08 18:33:32 +08:00
Calcium-Ion 4e974daa8d Merge pull request #1334 from duyazhe/fix-baidu-bug
修复了百度请求时候需要传appid的bug
2025-07-07 14:51:23 +08:00
Xyfacai ffc823b11c Merge pull request #1341 from QuantumNous/refactor/log-params
refactor: log params and channel params
2025-07-07 14:29:16 +08:00
Xiangyuan-liu db98c0f4b1 refactor: log params and channel params
refactor: log params and channel params
2025-07-07 14:26:37 +08:00
CaIon 46ebea917e 🔧 refactor(adaptor): update HTTP referer to new API domain 2025-07-07 12:36:04 +08:00
duyazhe 78bb751e91 Update adaptor.go 2025-07-07 09:57:20 +08:00
duyazhe 3303cd1e6c 修复了百度请求时候需要传appid的bug 2025-07-06 23:09:49 +08:00
CaIon d7e15a9677 feat(tokens): add cherryConfig support for URL generation and base64 encoding 2025-07-06 20:56:09 +08:00
CaIon b67a4fda9b 🔧 refactor(model): change user group retrieval to non-strict mode 2025-07-06 10:23:38 +08:00
CaIon 6c9369a2c9 🔧 refactor(model): update context key retrieval to use token group instead of user group 2025-07-05 16:40:49 +08:00
Calcium-Ion aa0edd8dce Merge pull request #1321 from iszcz/main
支持Midjourney视频任务和图片编辑
2025-07-05 15:28:33 +08:00
iszcz d4f2f4dbbe 支持Midjourney视频任务和图片编辑 2025-06-30 22:31:12 +08:00
44 changed files with 1506 additions and 250 deletions
+10
View File
@@ -76,3 +76,13 @@ func GetContextKeyStringMap(c *gin.Context, key constant.ContextKey) map[string]
func GetContextKeyTime(c *gin.Context, key constant.ContextKey) time.Time {
return c.GetTime(string(key))
}
func GetContextKeyType[T any](c *gin.Context, key constant.ContextKey) (T, bool) {
if value, ok := c.Get(string(key)); ok {
if v, ok := value.(T); ok {
return v, true
}
}
var t T
return t, false
}
+13
View File
@@ -1,6 +1,7 @@
package common
import (
"encoding/base64"
"encoding/json"
"math/rand"
"strconv"
@@ -68,3 +69,15 @@ func StringToByteSlice(s string) []byte {
tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]}
return *(*[]byte)(unsafe.Pointer(&tmp2))
}
func EncodeBase64(str string) string {
return base64.StdEncoding.EncodeToString([]byte(str))
}
func GetJsonString(data any) string {
if data == nil {
return ""
}
b, _ := json.Marshal(data)
return string(b)
}
-7
View File
@@ -1,7 +0,0 @@
package constant
var (
ForceFormat = "force_format" // ForceFormat 强制格式化为OpenAI格式
ChanelSettingProxy = "proxy" // Proxy 代理
ChannelSettingThinkingToContent = "thinking_to_content" // ThinkingToContent
)
+4
View File
@@ -22,6 +22,8 @@ const (
MjActionPan = "PAN"
MjActionSwapFace = "SWAP_FACE"
MjActionUpload = "UPLOAD"
MjActionVideo = "VIDEO"
MjActionEdits = "EDITS"
)
var MidjourneyModel2Action = map[string]string{
@@ -41,4 +43,6 @@ var MidjourneyModel2Action = map[string]string{
"mj_pan": MjActionPan,
"swap_face": MjActionSwapFace,
"mj_upload": MjActionUpload,
"mj_video": MjActionVideo,
"mj_edits": MjActionEdits,
}
-16
View File
@@ -1,16 +0,0 @@
package constant
var (
UserSettingNotifyType = "notify_type" // QuotaWarningType 额度预警类型
UserSettingQuotaWarningThreshold = "quota_warning_threshold" // QuotaWarningThreshold 额度预警阈值
UserSettingWebhookUrl = "webhook_url" // WebhookUrl webhook地址
UserSettingWebhookSecret = "webhook_secret" // WebhookSecret webhook密钥
UserSettingNotificationEmail = "notification_email" // NotificationEmail 通知邮箱地址
UserAcceptUnsetRatioModel = "accept_unset_model_ratio_model" // AcceptUnsetRatioModel 是否接受未设置价格的模型
UserSettingRecordIpLog = "record_ip_log" // 是否记录请求和错误日志IP
)
var (
NotifyTypeEmail = "email" // Email 邮件
NotifyTypeWebhook = "webhook" // Webhook
)
+13 -2
View File
@@ -173,8 +173,19 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
consumedTime := float64(milliseconds) / 1000.0
other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio,
usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, info.OriginModelName, "模型测试",
quota, "模型测试", 0, quota, int(consumedTime), false, info.UsingGroup, other)
model.RecordConsumeLog(c, 1, model.RecordConsumeLogParams{
ChannelId: channel.Id,
PromptTokens: usage.PromptTokens,
CompletionTokens: usage.CompletionTokens,
ModelName: info.OriginModelName,
TokenName: "模型测试",
Quota: quota,
Content: "模型测试",
UseTimeSeconds: int(consumedTime),
IsStream: false,
Group: info.UsingGroup,
Other: other,
})
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
return nil, nil
}
+21 -2
View File
@@ -228,7 +228,7 @@ func FetchUpstreamModels(c *gin.Context) {
}
func FixChannelsAbilities(c *gin.Context) {
count, err := model.FixAbility()
success, fails, err := model.FixAbility()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -239,7 +239,10 @@ func FixChannelsAbilities(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": count,
"data": gin.H{
"success": success,
"fails": fails,
},
})
}
@@ -387,6 +390,14 @@ func AddChannel(c *gin.Context) {
})
return
}
err = channel.ValidateSettings()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "channel setting 格式错误:" + err.Error(),
})
return
}
channel.CreatedTime = common.GetTimestamp()
keys := strings.Split(channel.Key, "\n")
if channel.Type == constant.ChannelTypeVertexAi {
@@ -614,6 +625,14 @@ func UpdateChannel(c *gin.Context) {
})
return
}
err = channel.ValidateSettings()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "channel setting 格式错误:" + err.Error(),
})
return
}
if channel.Type == constant.ChannelTypeVertexAi {
if channel.Other == "" {
c.JSON(http.StatusOK, gin.H{
+2 -2
View File
@@ -130,7 +130,7 @@ func ListModels(c *gin.Context) {
}
} else {
userId := c.GetInt("id")
userGroup, err := model.GetUserGroup(userId, true)
userGroup, err := model.GetUserGroup(userId, false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -139,7 +139,7 @@ func ListModels(c *gin.Context) {
return
}
group := userGroup
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
if tokenGroup != "" {
group = tokenGroup
}
+14 -13
View File
@@ -6,6 +6,7 @@ import (
"net/http"
"net/url"
"one-api/common"
"one-api/dto"
"one-api/model"
"one-api/setting"
"strconv"
@@ -961,7 +962,7 @@ func UpdateUserSetting(c *gin.Context) {
}
// 验证预警类型
if req.QuotaWarningType != constant.NotifyTypeEmail && req.QuotaWarningType != constant.NotifyTypeWebhook {
if req.QuotaWarningType != dto.NotifyTypeEmail && req.QuotaWarningType != dto.NotifyTypeWebhook {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的预警类型",
@@ -979,7 +980,7 @@ func UpdateUserSetting(c *gin.Context) {
}
// 如果是webhook类型,验证webhook地址
if req.QuotaWarningType == constant.NotifyTypeWebhook {
if req.QuotaWarningType == dto.NotifyTypeWebhook {
if req.WebhookUrl == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -998,7 +999,7 @@ func UpdateUserSetting(c *gin.Context) {
}
// 如果是邮件类型,验证邮箱地址
if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" {
if req.QuotaWarningType == dto.NotifyTypeEmail && req.NotificationEmail != "" {
// 验证邮箱格式
if !strings.Contains(req.NotificationEmail, "@") {
c.JSON(http.StatusOK, gin.H{
@@ -1020,24 +1021,24 @@ func UpdateUserSetting(c *gin.Context) {
}
// 构建设置
settings := map[string]interface{}{
constant.UserSettingNotifyType: req.QuotaWarningType,
constant.UserSettingQuotaWarningThreshold: req.QuotaWarningThreshold,
"accept_unset_model_ratio_model": req.AcceptUnsetModelRatioModel,
constant.UserSettingRecordIpLog: req.RecordIpLog,
settings := dto.UserSetting{
NotifyType: req.QuotaWarningType,
QuotaWarningThreshold: req.QuotaWarningThreshold,
AcceptUnsetRatioModel: req.AcceptUnsetModelRatioModel,
RecordIpLog: req.RecordIpLog,
}
// 如果是webhook类型,添加webhook相关设置
if req.QuotaWarningType == constant.NotifyTypeWebhook {
settings[constant.UserSettingWebhookUrl] = req.WebhookUrl
if req.QuotaWarningType == dto.NotifyTypeWebhook {
settings.WebhookUrl = req.WebhookUrl
if req.WebhookSecret != "" {
settings[constant.UserSettingWebhookSecret] = req.WebhookSecret
settings.WebhookSecret = req.WebhookSecret
}
}
// 如果提供了通知邮箱,添加到设置中
if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" {
settings[constant.UserSettingNotificationEmail] = req.NotificationEmail
if req.QuotaWarningType == dto.NotifyTypeEmail && req.NotificationEmail != "" {
settings.NotificationEmail = req.NotificationEmail
}
// 更新用户设置
+7
View File
@@ -0,0 +1,7 @@
package dto
type ChannelSettings struct {
ForceFormat bool `json:"force_format,omitempty"`
ThinkingToContent bool `json:"thinking_to_content,omitempty"`
Proxy string `json:"proxy"`
}
+6
View File
@@ -57,6 +57,8 @@ type MidjourneyDto struct {
StartTime int64 `json:"startTime"`
FinishTime int64 `json:"finishTime"`
ImageUrl string `json:"imageUrl"`
VideoUrl string `json:"videoUrl"`
VideoUrls []ImgUrls `json:"videoUrls"`
Status string `json:"status"`
Progress string `json:"progress"`
FailReason string `json:"failReason"`
@@ -65,6 +67,10 @@ type MidjourneyDto struct {
Properties *Properties `json:"properties"`
}
type ImgUrls struct {
Url string `json:"url"`
}
type MidjourneyStatus struct {
Status int `json:"status"`
}
+16
View File
@@ -0,0 +1,16 @@
package dto
type UserSetting struct {
NotifyType string `json:"notify_type,omitempty"` // QuotaWarningType 额度预警类型
QuotaWarningThreshold float64 `json:"quota_warning_threshold,omitempty"` // QuotaWarningThreshold 额度预警阈值
WebhookUrl string `json:"webhook_url,omitempty"` // WebhookUrl webhook地址
WebhookSecret string `json:"webhook_secret,omitempty"` // WebhookSecret webhook密钥
NotificationEmail string `json:"notification_email,omitempty"` // NotificationEmail 通知邮箱地址
AcceptUnsetRatioModel bool `json:"accept_unset_model_ratio_model,omitempty"` // AcceptUnsetRatioModel 是否接受未设置价格的模型
RecordIpLog bool `json:"record_ip_log,omitempty"` // 是否记录请求和错误日志IP
}
var (
NotifyTypeEmail = "email" // Email 邮件
NotifyTypeWebhook = "webhook" // Webhook
)
+1041
View File
File diff suppressed because it is too large Load Diff
+4 -3
View File
@@ -39,7 +39,6 @@ func main() {
return
}
common.SetupLogger()
common.SysLog("New API " + common.Version + " started")
if os.Getenv("GIN_MODE") != "debug" {
gin.SetMode(gin.ReleaseMode)
@@ -69,9 +68,9 @@ func main() {
if r := recover(); r != nil {
common.SysError(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r))
// Retry once
_, fixErr := model.FixAbility()
_, _, fixErr := model.FixAbility()
if fixErr != nil {
common.SysError(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error()))
common.FatalLog(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error()))
}
}
}()
@@ -172,6 +171,8 @@ func InitResources() error {
// 加载环境变量
common.InitEnv()
common.SetupLogger()
// Initialize model settings
ratio_setting.InitRatioSettings()
+3 -3
View File
@@ -247,9 +247,9 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
}
c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name)
c.Set("channel_type", channel.Type)
common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type)
c.Set("channel_create_time", channel.CreatedTime)
c.Set("channel_setting", channel.GetSetting())
common.SetContextKey(c, constant.ContextKeyChannelSetting, channel.GetSetting())
c.Set("param_override", channel.GetParamOverride())
if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization {
c.Set("channel_organization", *channel.OpenAIOrganization)
@@ -258,7 +258,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set("model_mapping", channel.GetModelMapping())
c.Set("status_code_mapping", channel.GetStatusCodeMapping())
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
c.Set("base_url", channel.GetBaseURL())
common.SetContextKey(c, constant.ContextKeyBaseUrl, channel.GetBaseURL())
// TODO: api_version统一
switch channel.Type {
case constant.ChannelTypeAzure:
+33 -61
View File
@@ -5,6 +5,7 @@ import (
"fmt"
"one-api/common"
"strings"
"sync"
"github.com/samber/lo"
"gorm.io/gorm"
@@ -272,74 +273,45 @@ func UpdateAbilityByTag(tag string, newTag *string, priority *int64, weight *uin
return DB.Model(&Ability{}).Where("tag = ?", tag).Updates(ability).Error
}
func FixAbility() (int, error) {
var channelIds []int
count := 0
// Find all channel ids from channel table
err := DB.Model(&Channel{}).Pluck("id", &channelIds).Error
var fixLock = sync.Mutex{}
func FixAbility() (int, int, error) {
lock := fixLock.TryLock()
if !lock {
return 0, 0, errors.New("已经有一个修复任务在运行中,请稍后再试")
}
defer fixLock.Unlock()
var channels []*Channel
// Find all channels
err := DB.Model(&Channel{}).Find(&channels).Error
if err != nil {
common.SysError(fmt.Sprintf("Get channel ids from channel table failed: %s", err.Error()))
return 0, err
return 0, 0, err
}
// Delete abilities of channels that are not in channel table - in batches to avoid too many placeholders
if len(channelIds) > 0 {
// Process deletion in chunks to avoid "too many placeholders" error
for _, chunk := range lo.Chunk(channelIds, 100) {
err = DB.Where("channel_id NOT IN (?)", chunk).Delete(&Ability{}).Error
if err != nil {
common.SysError(fmt.Sprintf("Delete abilities of channels (batch) that are not in channel table failed: %s", err.Error()))
return 0, err
}
}
} else {
// If no channels exist, delete all abilities
err = DB.Delete(&Ability{}).Error
if len(channels) == 0 {
return 0, 0, nil
}
successCount := 0
failCount := 0
for _, chunk := range lo.Chunk(channels, 50) {
ids := lo.Map(chunk, func(c *Channel, _ int) int { return c.Id })
// Delete all abilities of this channel
err = DB.Where("channel_id IN ?", ids).Delete(&Ability{}).Error
if err != nil {
common.SysError(fmt.Sprintf("Delete all abilities failed: %s", err.Error()))
return 0, err
common.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
failCount += len(chunk)
continue
}
common.SysLog("Delete all abilities successfully")
return 0, nil
}
common.SysLog(fmt.Sprintf("Delete abilities of channels that are not in channel table successfully, ids: %v", channelIds))
count += len(channelIds)
// Use channelIds to find channel not in abilities table
var abilityChannelIds []int
err = DB.Table("abilities").Distinct("channel_id").Pluck("channel_id", &abilityChannelIds).Error
if err != nil {
common.SysError(fmt.Sprintf("Get channel ids from abilities table failed: %s", err.Error()))
return count, err
}
var channels []Channel
if len(abilityChannelIds) == 0 {
err = DB.Find(&channels).Error
} else {
// Process query in chunks to avoid "too many placeholders" error
err = nil
for _, chunk := range lo.Chunk(abilityChannelIds, 100) {
var channelsChunk []Channel
err = DB.Where("id NOT IN (?)", chunk).Find(&channelsChunk).Error
// Then add new abilities
for _, channel := range chunk {
err = channel.AddAbilities()
if err != nil {
common.SysError(fmt.Sprintf("Find channels not in abilities table failed: %s", err.Error()))
return count, err
common.SysError(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error()))
failCount++
} else {
successCount++
}
channels = append(channels, channelsChunk...)
}
}
for _, channel := range channels {
err := channel.UpdateAbilities(nil)
if err != nil {
common.SysError(fmt.Sprintf("Update abilities of channel %d failed: %s", channel.Id, err.Error()))
} else {
common.SysLog(fmt.Sprintf("Update abilities of channel %d successfully", channel.Id))
count++
}
}
InitChannelCache()
return count, nil
return successCount, failCount, nil
}
+15 -3
View File
@@ -3,6 +3,7 @@ package model
import (
"encoding/json"
"one-api/common"
"one-api/dto"
"strings"
"sync"
@@ -514,8 +515,19 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
return tags, nil
}
func (channel *Channel) GetSetting() map[string]interface{} {
setting := make(map[string]interface{})
func (channel *Channel) ValidateSettings() error {
channelParams := &dto.ChannelSettings{}
if channel.Setting != nil && *channel.Setting != "" {
err := json.Unmarshal([]byte(*channel.Setting), channelParams)
if err != nil {
return err
}
}
return nil
}
func (channel *Channel) GetSetting() dto.ChannelSettings {
setting := dto.ChannelSettings{}
if channel.Setting != nil && *channel.Setting != "" {
err := json.Unmarshal([]byte(*channel.Setting), &setting)
if err != nil {
@@ -525,7 +537,7 @@ func (channel *Channel) GetSetting() map[string]interface{} {
return setting
}
func (channel *Channel) SetSetting(setting map[string]interface{}) {
func (channel *Channel) SetSetting(setting dto.ChannelSettings) {
settingBytes, err := json.Marshal(setting)
if err != nil {
common.SysError("failed to marshal setting: " + err.Error())
+35 -26
View File
@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"one-api/common"
"one-api/constant"
"os"
"strings"
"time"
@@ -100,10 +99,8 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string,
// 判断是否需要记录 IP
needRecordIp := false
if settingMap, err := GetUserSetting(userId, false); err == nil {
if v, ok := settingMap[constant.UserSettingRecordIpLog]; ok {
if vb, ok := v.(bool); ok && vb {
needRecordIp = true
}
if settingMap.RecordIpLog {
needRecordIp = true
}
}
log := &Log{
@@ -136,22 +133,34 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string,
}
}
func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens int, completionTokens int,
modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int,
isStream bool, group string, other map[string]interface{}) {
common.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
type RecordConsumeLogParams struct {
ChannelId int `json:"channel_id"`
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
ModelName string `json:"model_name"`
TokenName string `json:"token_name"`
Quota int `json:"quota"`
Content string `json:"content"`
TokenId int `json:"token_id"`
UserQuota int `json:"user_quota"`
UseTimeSeconds int `json:"use_time_seconds"`
IsStream bool `json:"is_stream"`
Group string `json:"group"`
Other map[string]interface{} `json:"other"`
}
func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams) {
common.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params)))
if !common.LogConsumeEnabled {
return
}
username := c.GetString("username")
otherStr := common.MapToJsonStr(other)
otherStr := common.MapToJsonStr(params.Other)
// 判断是否需要记录 IP
needRecordIp := false
if settingMap, err := GetUserSetting(userId, false); err == nil {
if v, ok := settingMap[constant.UserSettingRecordIpLog]; ok {
if vb, ok := v.(bool); ok && vb {
needRecordIp = true
}
if settingMap.RecordIpLog {
needRecordIp = true
}
}
log := &Log{
@@ -159,17 +168,17 @@ func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens in
Username: username,
CreatedAt: common.GetTimestamp(),
Type: LogTypeConsume,
Content: content,
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TokenName: tokenName,
ModelName: modelName,
Quota: quota,
ChannelId: channelId,
TokenId: tokenId,
UseTime: useTimeSeconds,
IsStream: isStream,
Group: group,
Content: params.Content,
PromptTokens: params.PromptTokens,
CompletionTokens: params.CompletionTokens,
TokenName: params.TokenName,
ModelName: params.ModelName,
Quota: params.Quota,
ChannelId: params.ChannelId,
TokenId: params.TokenId,
UseTime: params.UseTimeSeconds,
IsStream: params.IsStream,
Group: params.Group,
Ip: func() string {
if needRecordIp {
return c.ClientIP()
@@ -184,7 +193,7 @@ func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens in
}
if common.DataExportEnabled {
gopool.Go(func() {
LogQuotaData(userId, username, modelName, quota, common.GetTimestamp(), promptTokens+completionTokens)
LogQuotaData(userId, username, params.ModelName, params.Quota, common.GetTimestamp(), params.PromptTokens+params.CompletionTokens)
})
}
}
+2
View File
@@ -14,6 +14,8 @@ type Midjourney struct {
StartTime int64 `json:"start_time" gorm:"index"`
FinishTime int64 `json:"finish_time" gorm:"index"`
ImageUrl string `json:"image_url"`
VideoUrl string `json:"video_url"`
VideoUrls string `json:"video_urls"`
Status string `json:"status" gorm:"type:varchar(20);index"`
Progress string `json:"progress" gorm:"type:varchar(30);index"`
FailReason string `json:"fail_reason"`
+16 -9
View File
@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"one-api/common"
"one-api/dto"
"strconv"
"strings"
@@ -68,14 +69,18 @@ func (user *User) SetAccessToken(token string) {
user.AccessToken = &token
}
func (user *User) GetSetting() map[string]interface{} {
if user.Setting == "" {
return nil
func (user *User) GetSetting() dto.UserSetting {
setting := dto.UserSetting{}
if user.Setting != "" {
err := json.Unmarshal([]byte(user.Setting), &setting)
if err != nil {
common.SysError("failed to unmarshal setting: " + err.Error())
}
}
return common.StrToMap(user.Setting)
return setting
}
func (user *User) SetSetting(setting map[string]interface{}) {
func (user *User) SetSetting(setting dto.UserSetting) {
settingBytes, err := json.Marshal(setting)
if err != nil {
common.SysError("failed to marshal setting: " + err.Error())
@@ -626,7 +631,7 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) {
}
// GetUserSetting gets setting from Redis first, falls back to DB if needed
func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err error) {
func GetUserSetting(id int, fromDB bool) (settingMap dto.UserSetting, err error) {
var setting string
defer func() {
// Update Redis cache asynchronously on successful DB read
@@ -648,10 +653,12 @@ func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err
fromDB = true
err = DB.Model(&User{}).Where("id = ?", id).Select("setting").Find(&setting).Error
if err != nil {
return map[string]interface{}{}, err
return settingMap, err
}
return common.StrToMap(setting), nil
userBase := &UserBase{
Setting: setting,
}
return userBase.GetSetting(), nil
}
func IncreaseUserQuota(id int, quota int, db bool) (err error) {
+11 -16
View File
@@ -5,6 +5,7 @@ import (
"fmt"
"one-api/common"
"one-api/constant"
"one-api/dto"
"time"
"github.com/gin-gonic/gin"
@@ -32,20 +33,15 @@ func (user *UserBase) WriteContext(c *gin.Context) {
common.SetContextKey(c, constant.ContextKeyUserSetting, user.GetSetting())
}
func (user *UserBase) GetSetting() map[string]interface{} {
if user.Setting == "" {
return nil
func (user *UserBase) GetSetting() dto.UserSetting {
setting := dto.UserSetting{}
if user.Setting != "" {
err := json.Unmarshal([]byte(user.Setting), &setting)
if err != nil {
common.SysError("failed to unmarshal setting: " + err.Error())
}
}
return common.StrToMap(user.Setting)
}
func (user *UserBase) SetSetting(setting map[string]interface{}) {
settingBytes, err := json.Marshal(setting)
if err != nil {
common.SysError("failed to marshal setting: " + err.Error())
return
}
user.Setting = string(settingBytes)
return setting
}
// getUserCacheKey returns the key for user cache
@@ -174,11 +170,10 @@ func getUserNameCache(userId int) (string, error) {
return cache.Username, nil
}
func getUserSettingCache(userId int) (map[string]interface{}, error) {
setting := make(map[string]interface{})
func getUserSettingCache(userId int) (dto.UserSetting, error) {
cache, err := GetUserCache(userId)
if err != nil {
return setting, err
return dto.UserSetting{}, err
}
return cache.GetSetting(), nil
}
+2 -2
View File
@@ -206,8 +206,8 @@ func sendPingData(c *gin.Context, mutex *sync.Mutex) error {
func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) {
var client *http.Client
var err error
if proxyURL, ok := info.ChannelSetting["proxy"]; ok {
client, err = service.NewProxyHttpClient(proxyURL.(string))
if info.ChannelSetting.Proxy != "" {
client, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy)
if err != nil {
return nil, fmt.Errorf("new proxy http client failed: %w", err)
}
+10 -1
View File
@@ -42,7 +42,16 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", "Bearer "+info.ApiKey)
keyParts := strings.Split(info.ApiKey, "|")
if len(keyParts) == 0 || keyParts[0] == "" {
return errors.New("invalid API key: authorization token is required")
}
if len(keyParts) > 1 {
if keyParts[1] != "" {
req.Set("appid", keyParts[1])
}
}
req.Set("Authorization", "Bearer "+keyParts[0])
return nil
}
+2 -2
View File
@@ -278,8 +278,8 @@ func getChatDetail(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (*ht
func doRequest(req *http.Request, info *relaycommon.RelayInfo) (*http.Response, error) {
var client *http.Client
var err error // 声明 err 变量
if proxyURL, ok := info.ChannelSetting["proxy"]; ok {
client, err = service.NewProxyHttpClient(proxyURL.(string))
if info.ChannelSetting.Proxy != "" {
client, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy)
if err != nil {
return nil, fmt.Errorf("new proxy http client failed: %w", err)
}
+2 -2
View File
@@ -53,7 +53,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
a.ChannelType = info.ChannelType
// initialize ThinkingContentInfo when thinking_to_content is enabled
if think2Content, ok := info.ChannelSetting[constant.ChannelSettingThinkingToContent].(bool); ok && think2Content {
if info.ChannelSetting.ThinkingToContent {
info.ThinkingContentInfo = relaycommon.ThinkingContentInfo{
IsFirstThinkingContent: true,
SendLastThinkingContent: false,
@@ -145,7 +145,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *
header.Set("Authorization", "Bearer "+info.ApiKey)
}
if info.ChannelType == constant.ChannelTypeOpenRouter {
header.Set("HTTP-Referer", "https://github.com/Calcium-Ion/new-api")
header.Set("HTTP-Referer", "https://www.newapi.ai")
header.Set("X-Title", "New API")
}
return nil
+6 -6
View File
@@ -124,12 +124,12 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
var forceFormat bool
var thinkToContent bool
if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
forceFormat = forceFmt
if info.ChannelSetting.ForceFormat {
forceFormat = true
}
if think2Content, ok := info.ChannelSetting[constant.ChannelSettingThinkingToContent].(bool); ok {
thinkToContent = think2Content
if info.ChannelSetting.ThinkingToContent {
thinkToContent = true
}
var (
@@ -200,8 +200,8 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
}
forceFormat := false
if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
forceFormat = forceFmt
if info.ChannelSetting.ForceFormat {
forceFormat = true
}
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
+2 -2
View File
@@ -106,8 +106,8 @@ func exchangeJwtForAccessToken(signedJWT string, info *relaycommon.RelayInfo) (s
var client *http.Client
var err error
if proxyURL, ok := info.ChannelSetting["proxy"]; ok {
client, err = service.NewProxyHttpClient(proxyURL.(string))
if info.ChannelSetting.Proxy != "" {
client, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy)
if err != nil {
return "", fmt.Errorf("new proxy http client failed: %w", err)
}
+18 -10
View File
@@ -97,9 +97,9 @@ type RelayInfo struct {
IsFirstRequest bool
AudioUsage bool
ReasoningEffort string
ChannelSetting map[string]interface{}
ChannelSetting dto.ChannelSettings
ParamOverride map[string]interface{}
UserSetting map[string]interface{}
UserSetting dto.UserSetting
UserEmail string
UserQuota int
RelayFormat string
@@ -213,7 +213,6 @@ func GenRelayInfoImage(c *gin.Context) *RelayInfo {
func GenRelayInfo(c *gin.Context) *RelayInfo {
channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId)
channelSetting := common.GetContextKeyStringMap(c, constant.ContextKeyChannelSetting)
paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyParamOverride)
tokenId := common.GetContextKeyInt(c, constant.ContextKeyTokenId)
@@ -227,7 +226,6 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
info := &RelayInfo{
UserQuota: common.GetContextKeyInt(c, constant.ContextKeyUserQuota),
UserSetting: common.GetContextKeyStringMap(c, constant.ContextKeyUserSetting),
UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail),
isFirstResponse: true,
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
@@ -246,12 +244,12 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
OriginModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
//RecodeModelName: c.GetString("original_model"),
IsModelMapped: false,
ApiType: apiType,
ApiVersion: c.GetString("api_version"),
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
Organization: c.GetString("channel_organization"),
ChannelSetting: channelSetting,
IsModelMapped: false,
ApiType: apiType,
ApiVersion: c.GetString("api_version"),
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
Organization: c.GetString("channel_organization"),
ChannelCreateTime: c.GetInt64("channel_create_time"),
ParamOverride: paramOverride,
RelayFormat: RelayFormatOpenAI,
@@ -277,6 +275,16 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
if streamSupportedChannels[info.ChannelType] {
info.SupportStreamOptions = true
}
channelSetting, ok := common.GetContextKeyType[dto.ChannelSettings](c, constant.ContextKeyChannelSetting)
if ok {
info.ChannelSetting = channelSetting
}
userSetting, ok := common.GetContextKeyType[dto.UserSetting](c, constant.ContextKeyUserSetting)
if ok {
info.UserSetting = userSetting
}
return info
}
+6
View File
@@ -29,6 +29,8 @@ const (
RelayModeMidjourneyShorten
RelayModeSwapFace
RelayModeMidjourneyUpload
RelayModeMidjourneyVideo
RelayModeMidjourneyEdits
RelayModeAudioSpeech // tts
RelayModeAudioTranscription // whisper
@@ -108,6 +110,10 @@ func Path2RelayModeMidjourney(path string) int {
relayMode = RelayModeMidjourneyUpload
} else if strings.HasSuffix(path, "/mj/submit/imagine") {
relayMode = RelayModeMidjourneyImagine
} else if strings.HasSuffix(path, "/mj/submit/video") {
relayMode = RelayModeMidjourneyVideo
} else if strings.HasSuffix(path, "/mj/submit/edits") {
relayMode = RelayModeMidjourneyEdits
} else if strings.HasSuffix(path, "/mj/submit/blend") {
relayMode = RelayModeMidjourneyBlend
} else if strings.HasSuffix(path, "/mj/submit/describe") {
+2 -6
View File
@@ -3,7 +3,6 @@ package helper
import (
"fmt"
"one-api/common"
constant2 "one-api/constant"
relaycommon "one-api/relay/common"
"one-api/setting/ratio_setting"
@@ -83,11 +82,8 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
modelRatio, success = ratio_setting.GetModelRatio(info.OriginModelName)
if !success {
acceptUnsetRatio := false
if accept, ok := info.UserSetting[constant2.UserAcceptUnsetRatioModel]; ok {
b, ok := accept.(bool)
if ok {
acceptUnsetRatio = b
}
if info.UserSetting.AcceptUnsetRatioModel {
acceptUnsetRatio = true
}
if !acceptUnsetRatio {
return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置,请联系管理员设置或开始自用模式;Model %s ratio or price not set, please set or start self-use mode", info.OriginModelName, info.OriginModelName)
+57 -14
View File
@@ -34,14 +34,13 @@ func RelayMidjourneyImage(c *gin.Context) {
}
var httpClient *http.Client
if channel, err := model.CacheGetChannel(midjourneyTask.ChannelId); err == nil {
if proxy, ok := channel.GetSetting()["proxy"]; ok {
if proxyURL, ok := proxy.(string); ok && proxyURL != "" {
if httpClient, err = service.NewProxyHttpClient(proxyURL); err != nil {
c.JSON(400, gin.H{
"error": "proxy_url_invalid",
})
return
}
proxy := channel.GetSetting().Proxy
if proxy != "" {
if httpClient, err = service.NewProxyHttpClient(proxy); err != nil {
c.JSON(400, gin.H{
"error": "proxy_url_invalid",
})
return
}
}
}
@@ -106,6 +105,9 @@ func RelayMidjourneyNotify(c *gin.Context) *dto.MidjourneyResponse {
midjourneyTask.StartTime = midjRequest.StartTime
midjourneyTask.FinishTime = midjRequest.FinishTime
midjourneyTask.ImageUrl = midjRequest.ImageUrl
midjourneyTask.VideoUrl = midjRequest.VideoUrl
videoUrlsStr, _ := json.Marshal(midjRequest.VideoUrls)
midjourneyTask.VideoUrls = string(videoUrlsStr)
midjourneyTask.Status = midjRequest.Status
midjourneyTask.FailReason = midjRequest.FailReason
err = midjourneyTask.Update()
@@ -136,6 +138,9 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
} else {
midjourneyTask.ImageUrl = originTask.ImageUrl
}
if originTask.VideoUrl != "" {
midjourneyTask.VideoUrl = originTask.VideoUrl
}
midjourneyTask.Status = originTask.Status
midjourneyTask.FailReason = originTask.FailReason
midjourneyTask.Action = originTask.Action
@@ -148,6 +153,13 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
midjourneyTask.Buttons = buttons
}
}
if originTask.VideoUrls != "" {
var videoUrls []dto.ImgUrls
err := json.Unmarshal([]byte(originTask.VideoUrls), &videoUrls)
if err == nil {
midjourneyTask.VideoUrls = videoUrls
}
}
if originTask.Properties != "" {
var properties dto.Properties
err := json.Unmarshal([]byte(originTask.Properties), &properties)
@@ -162,7 +174,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
startTime := time.Now().UnixNano() / int64(time.Millisecond)
tokenId := c.GetInt("token_id")
userId := c.GetInt("id")
group := c.GetString("group")
//group := c.GetString("group")
channelId := c.GetInt("channel_id")
relayInfo := relaycommon.GenRelayInfo(c)
var swapFaceRequest dto.SwapFaceRequest
@@ -208,8 +220,17 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, constant.MjActionSwapFace)
other := service.GenerateMjOtherInfo(priceData)
model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
priceData.Quota, logContent, tokenId, userQuota, 0, false, group, other)
model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{
ChannelId: channelId,
ModelName: modelName,
TokenName: tokenName,
Quota: priceData.Quota,
Content: logContent,
TokenId: tokenId,
UserQuota: userQuota,
Group: relayInfo.UsingGroup,
Other: other,
})
model.UpdateUserUsedQuotaAndRequestCount(userId, priceData.Quota)
model.UpdateChannelUsedQuota(channelId, priceData.Quota)
}
@@ -350,7 +371,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse
func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
tokenId := c.GetInt("token_id")
//tokenId := c.GetInt("token_id")
//channelType := c.GetInt("channel")
userId := c.GetInt("id")
group := c.GetString("group")
@@ -370,6 +391,9 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
}
relayMode = relayconstant.RelayModeMidjourneyChange
}
if relayMode == relayconstant.RelayModeMidjourneyVideo {
midjRequest.Action = constant.MjActionVideo
}
if relayMode == relayconstant.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复
if midjRequest.Prompt == "" {
@@ -378,6 +402,8 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
midjRequest.Action = constant.MjActionImagine
} else if relayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复
midjRequest.Action = constant.MjActionDescribe
} else if relayMode == relayconstant.RelayModeMidjourneyEdits { //编辑任务,此类任务可重复
midjRequest.Action = constant.MjActionEdits
} else if relayMode == relayconstant.RelayModeMidjourneyShorten { //缩短任务,此类任务可重复,plus only
midjRequest.Action = constant.MjActionShorten
} else if relayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
@@ -412,6 +438,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
//}
mjId = midjRequest.TaskId
midjRequest.Action = constant.MjActionModal
} else if relayMode == relayconstant.RelayModeMidjourneyVideo {
midjRequest.Action = constant.MjActionVideo
if midjRequest.TaskId == "" {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required")
} else if midjRequest.Action == "" {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "action_is_required")
}
mjId = midjRequest.TaskId
}
originTask := model.GetByMJId(userId, mjId)
@@ -492,8 +526,17 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %sID %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, midjRequest.Action, midjResponse.Result)
other := service.GenerateMjOtherInfo(priceData)
model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
priceData.Quota, logContent, tokenId, userQuota, 0, false, group, other)
model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{
ChannelId: channelId,
ModelName: modelName,
TokenName: tokenName,
Quota: priceData.Quota,
Content: logContent,
TokenId: relayInfo.TokenId,
UserQuota: userQuota,
Group: group,
Other: other,
})
model.UpdateUserUsedQuotaAndRequestCount(userId, priceData.Quota)
model.UpdateChannelUsedQuota(channelId, priceData.Quota)
}
+15 -2
View File
@@ -540,6 +540,19 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
other["audio_input_token_count"] = audioTokens
other["audio_input_price"] = audioInputPrice
}
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other)
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
ChannelId: relayInfo.ChannelId,
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
ModelName: logModel,
TokenName: tokenName,
Quota: quota,
Content: logContent,
TokenId: relayInfo.TokenId,
UserQuota: userQuota,
UseTimeSeconds: int(useTimeSeconds),
IsStream: relayInfo.IsStream,
Group: relayInfo.UsingGroup,
Other: other,
})
}
+11 -2
View File
@@ -139,8 +139,17 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
if hasUserGroupRatio {
other["user_group_ratio"] = userGroupRatio
}
model.RecordConsumeLog(c, relayInfo.UserId, relayInfo.ChannelId, 0, 0,
modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, relayInfo.UsingGroup, other)
model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{
ChannelId: relayInfo.ChannelId,
ModelName: modelName,
TokenName: tokenName,
Quota: quota,
Content: logContent,
TokenId: relayInfo.TokenId,
UserQuota: userQuota,
Group: relayInfo.UsingGroup,
Other: other,
})
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
}
+2
View File
@@ -103,6 +103,8 @@ func registerMjRouterGroup(relayMjRouter *gin.RouterGroup) {
relayMjRouter.POST("/submit/simple-change", controller.RelayMidjourney)
relayMjRouter.POST("/submit/describe", controller.RelayMidjourney)
relayMjRouter.POST("/submit/blend", controller.RelayMidjourney)
relayMjRouter.POST("/submit/edits", controller.RelayMidjourney)
relayMjRouter.POST("/submit/video", controller.RelayMidjourney)
relayMjRouter.POST("/notify", controller.RelayMidjourney)
relayMjRouter.GET("/task/:id/fetch", controller.RelayMidjourney)
relayMjRouter.GET("/task/:id/image-seed", controller.RelayMidjourney)
+6 -1
View File
@@ -3,7 +3,6 @@ package service
import (
"context"
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"log"
"net/http"
@@ -15,6 +14,8 @@ import (
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
)
func CoverActionToModelName(mjAction string) string {
@@ -38,6 +39,10 @@ func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (strin
switch relayMode {
case relayconstant.RelayModeMidjourneyImagine:
action = constant.MjActionImagine
case relayconstant.RelayModeMidjourneyVideo:
action = constant.MjActionVideo
case relayconstant.RelayModeMidjourneyEdits:
action = constant.MjActionEdits
case relayconstant.RelayModeMidjourneyDescribe:
action = constant.MjActionDescribe
case relayconstant.RelayModeMidjourneyBlend:
+48 -8
View File
@@ -209,8 +209,21 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
}
other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other)
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
ChannelId: relayInfo.ChannelId,
PromptTokens: usage.InputTokens,
CompletionTokens: usage.OutputTokens,
ModelName: logModel,
TokenName: tokenName,
Quota: quota,
Content: logContent,
TokenId: relayInfo.TokenId,
UserQuota: userQuota,
UseTimeSeconds: int(useTimeSeconds),
IsStream: relayInfo.IsStream,
Group: relayInfo.UsingGroup,
Other: other,
})
}
func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
@@ -286,8 +299,22 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio,
cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, modelName,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other)
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
ChannelId: relayInfo.ChannelId,
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
ModelName: modelName,
TokenName: tokenName,
Quota: quota,
Content: logContent,
TokenId: relayInfo.TokenId,
UserQuota: userQuota,
UseTimeSeconds: int(useTimeSeconds),
IsStream: relayInfo.IsStream,
Group: relayInfo.UsingGroup,
Other: other,
})
}
func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData helper.PriceData) int {
@@ -384,8 +411,21 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
}
other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other)
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
ChannelId: relayInfo.ChannelId,
PromptTokens: usage.PromptTokens,
CompletionTokens: usage.CompletionTokens,
ModelName: logModel,
TokenName: tokenName,
Quota: quota,
Content: logContent,
TokenId: relayInfo.TokenId,
UserQuota: userQuota,
UseTimeSeconds: int(useTimeSeconds),
IsStream: relayInfo.IsStream,
Group: relayInfo.UsingGroup,
Other: other,
})
}
func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
@@ -447,8 +487,8 @@ func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preCon
gopool.Go(func() {
userSetting := relayInfo.UserSetting
threshold := common.QuotaRemindThreshold
if userCustomThreshold, ok := userSetting[constant.UserSettingQuotaWarningThreshold]; ok {
threshold = int(userCustomThreshold.(float64))
if userSetting.QuotaWarningThreshold != 0 {
threshold = int(userSetting.QuotaWarningThreshold)
}
//noMoreQuota := userCache.Quota-(quota+preConsumedQuota) <= 0
+10 -22
View File
@@ -3,7 +3,6 @@ package service
import (
"fmt"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/model"
"strings"
@@ -17,10 +16,10 @@ func NotifyRootUser(t string, subject string, content string) {
}
}
func NotifyUser(userId int, userEmail string, userSetting map[string]interface{}, data dto.Notify) error {
notifyType, ok := userSetting[constant.UserSettingNotifyType]
if !ok {
notifyType = constant.NotifyTypeEmail
func NotifyUser(userId int, userEmail string, userSetting dto.UserSetting, data dto.Notify) error {
notifyType := userSetting.NotifyType
if notifyType == "" {
notifyType = dto.NotifyTypeEmail
}
// Check notification limit
@@ -34,34 +33,23 @@ func NotifyUser(userId int, userEmail string, userSetting map[string]interface{}
}
switch notifyType {
case constant.NotifyTypeEmail:
case dto.NotifyTypeEmail:
// check setting email
if settingEmail, ok := userSetting[constant.UserSettingNotificationEmail]; ok {
userEmail = settingEmail.(string)
}
userEmail = userSetting.NotificationEmail
if userEmail == "" {
common.SysLog(fmt.Sprintf("user %d has no email, skip sending email", userId))
return nil
}
return sendEmailNotify(userEmail, data)
case constant.NotifyTypeWebhook:
webhookURL, ok := userSetting[constant.UserSettingWebhookUrl]
if !ok {
case dto.NotifyTypeWebhook:
webhookURLStr := userSetting.WebhookUrl
if webhookURLStr == "" {
common.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", userId))
return nil
}
webhookURLStr, ok := webhookURL.(string)
if !ok {
common.SysError(fmt.Sprintf("user %d webhook url is not string type", userId))
return nil
}
// 获取 webhook secret
var webhookSecret string
if secret, ok := userSetting[constant.UserSettingWebhookSecret]; ok {
webhookSecret, _ = secret.(string)
}
webhookSecret := userSetting.WebhookSecret
return SendWebhookNotify(webhookURLStr, webhookSecret, data)
}
return nil
+4 -1
View File
@@ -6,8 +6,11 @@ import (
)
var Chats = []map[string]string{
//{
// "ChatGPT Next Web 官方示例": "https://app.nextchat.dev/#/?settings={\"key\":\"{key}\",\"url\":\"{address}\"}",
//},
{
"ChatGPT Next Web 官方示例": "https://app.nextchat.dev/#/?settings={\"key\":\"{key}\",\"url\":\"{address}\"}",
"Cherry Studio": "cherrystudio://providers/api-keys?v=1&data={cherryConfig}",
},
{
"Lobe Chat 官方示例": "https://chat-preview.lobehub.com/?settings={\"keyVaults\":{\"openai\":{\"apiKey\":\"{key}\",\"baseURL\":\"{address}/v1\"}}}",
+2
View File
@@ -231,7 +231,9 @@ var defaultModelPrice = map[string]float64{
"dall-e-3": 0.04,
"imagen-3.0-generate-002": 0.03,
"gpt-4-gizmo-*": 0.1,
"mj_video": 0.8,
"mj_imagine": 0.1,
"mj_edits": 0.1,
"mj_variation": 0.1,
"mj_reroll": 0.1,
"mj_blend": 0.1,
+2 -2
View File
@@ -1461,9 +1461,9 @@ const ChannelsTable = () => {
const fixChannelsAbilities = async () => {
const res = await API.post(`/api/channel/fix`);
const { success, message, data } = res.data;
const { success, message, data } = res.data;
if (success) {
showSuccess(t('已修复 ${data} 个通道!').replace('${data}', data));
showSuccess(t('已修复 ${success} 个通道,失败 ${fails} 个通道。').replace('${success}', data.success).replace('${fails}', data.fails));
await refresh();
} else {
showError(message);
+12
View File
@@ -195,6 +195,18 @@ const LogsTable = () => {
{t('放大')}
</Tag>
);
case 'VIDEO':
return (
<Tag color='orange' size='large' shape='circle' prefixIcon={<Video size={14} />}>
{t('视频')}
</Tag>
);
case 'EDITS':
return (
<Tag color='orange' size='large' shape='circle' prefixIcon={<Video size={14} />}>
{t('编辑')}
</Tag>
);
case 'VARIATION':
return (
<Tag color='purple' size='large' shape='circle' prefixIcon={<Shuffle size={14} />}>
+16 -3
View File
@@ -432,9 +432,22 @@ const TokensTable = () => {
if (serverAddress === '') {
serverAddress = window.location.origin;
}
let encodedServerAddress = encodeURIComponent(serverAddress);
url = url.replaceAll('{address}', encodedServerAddress);
url = url.replaceAll('{key}', 'sk-' + record.key);
if (url.includes('{cherryConfig}') === true) {
let cherryConfig = {
id: 'new-api',
baseUrl: serverAddress,
apiKey: 'sk-' + record.key,
}
// 替换 {cherryConfig} 为base64编码的JSON字符串
let encodedConfig = encodeURIComponent(
btoa(JSON.stringify(cherryConfig))
);
url = url.replaceAll('{cherryConfig}', encodedConfig);
} else {
let encodedServerAddress = encodeURIComponent(serverAddress);
url = url.replaceAll('{address}', encodedServerAddress);
url = url.replaceAll('{key}', 'sk-' + record.key);
}
window.open(url, '_blank');
};
+3 -1
View File
@@ -153,6 +153,8 @@ const EditChannel = (props) => {
localModels = [
'swap_face',
'mj_imagine',
'mj_video',
'mj_edits',
'mj_variation',
'mj_reroll',
'mj_blend',
@@ -238,7 +240,7 @@ const EditChannel = (props) => {
if (isEdit) {
// 如果是编辑模式,使用已有的channel id获取模型列表
const res = await API.get('/api/channel/fetch_models/' + channelId);
if (res.data && res.data?.success) {
if (res.data && res.data.success) {
models.push(...res.data.data);
} else {
err = true;
+2
View File
@@ -77,6 +77,8 @@ const EditTagModal = (props) => {
localModels = [
'swap_face',
'mj_imagine',
'mj_video',
'mj_edits',
'mj_variation',
'mj_reroll',
'mj_blend',