fix(openai): align image streaming relay governance
Route OpenAI image streaming through shared stream handling, split image/realtime/usage helpers for maintainability, and include the related image request and rate limit updates.
This commit is contained in:
+2
-2
@@ -112,11 +112,11 @@ func InitEnv() {
|
|||||||
|
|
||||||
// Initialize rate limit variables
|
// Initialize rate limit variables
|
||||||
GlobalApiRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_API_RATE_LIMIT_ENABLE", true)
|
GlobalApiRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_API_RATE_LIMIT_ENABLE", true)
|
||||||
GlobalApiRateLimitNum = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 180)
|
GlobalApiRateLimitNum = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 360)
|
||||||
GlobalApiRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_API_RATE_LIMIT_DURATION", 180))
|
GlobalApiRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_API_RATE_LIMIT_DURATION", 180))
|
||||||
|
|
||||||
GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true)
|
GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true)
|
||||||
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
|
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 120)
|
||||||
GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180))
|
GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180))
|
||||||
|
|
||||||
CriticalRateLimitEnable = GetEnvOrDefaultBool("CRITICAL_RATE_LIMIT_ENABLE", true)
|
CriticalRateLimitEnable = GetEnvOrDefaultBool("CRITICAL_RATE_LIMIT_ENABLE", true)
|
||||||
|
|||||||
+2
-2
@@ -26,7 +26,7 @@ type ImageRequest struct {
|
|||||||
OutputFormat json.RawMessage `json:"output_format,omitempty"`
|
OutputFormat json.RawMessage `json:"output_format,omitempty"`
|
||||||
OutputCompression json.RawMessage `json:"output_compression,omitempty"`
|
OutputCompression json.RawMessage `json:"output_compression,omitempty"`
|
||||||
PartialImages json.RawMessage `json:"partial_images,omitempty"`
|
PartialImages json.RawMessage `json:"partial_images,omitempty"`
|
||||||
Stream bool `json:"stream,omitempty"`
|
Stream *bool `json:"stream,omitempty"`
|
||||||
Images json.RawMessage `json:"images,omitempty"`
|
Images json.RawMessage `json:"images,omitempty"`
|
||||||
Mask json.RawMessage `json:"mask,omitempty"`
|
Mask json.RawMessage `json:"mask,omitempty"`
|
||||||
InputFidelity json.RawMessage `json:"input_fidelity,omitempty"`
|
InputFidelity json.RawMessage `json:"input_fidelity,omitempty"`
|
||||||
@@ -163,7 +163,7 @@ func (i *ImageRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (i *ImageRequest) IsStream(c *gin.Context) bool {
|
func (i *ImageRequest) IsStream(c *gin.Context) bool {
|
||||||
return i.Stream
|
return i.Stream != nil && *i.Stream
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *ImageRequest) SetModelName(modelName string) {
|
func (i *ImageRequest) SetModelName(modelName string) {
|
||||||
|
|||||||
@@ -1,16 +0,0 @@
|
|||||||
package dto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TestImageRequestStreamJSON verifies that image requests preserve stream=true.
|
|
||||||
func TestImageRequestStreamJSON(t *testing.T) {
|
|
||||||
var req ImageRequest
|
|
||||||
require.NoError(t, req.UnmarshalJSON([]byte(`{"model":"gpt-image-1","prompt":"draw a cat","stream":true}`)))
|
|
||||||
|
|
||||||
require.True(t, req.Stream)
|
|
||||||
require.True(t, req.IsStream(nil))
|
|
||||||
}
|
|
||||||
@@ -632,7 +632,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
|||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
usage, err = OpenaiImageStreamHandler(c, info, resp)
|
usage, err = OpenaiImageStreamHandler(c, info, resp)
|
||||||
} else {
|
} else {
|
||||||
usage, err = OpenaiHandlerWithUsage(c, info, resp)
|
usage, err = OpenaiImageHandler(c, info, resp)
|
||||||
}
|
}
|
||||||
case relayconstant.RelayModeRerank:
|
case relayconstant.RelayModeRerank:
|
||||||
usage, err = common_handler.RerankHandler(c, info, resp)
|
usage, err = common_handler.RerankHandler(c, info, resp)
|
||||||
|
|||||||
@@ -16,14 +16,18 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestConvertImageEditRequestKeepsValidMultipartStreamFields verifies multipart replay.
|
// TestConvertImageEditRequestMultipart verifies that ConvertImageRequest
|
||||||
func TestConvertImageEditRequestKeepsValidMultipartStreamFields(t *testing.T) {
|
// re-serializes multipart image edit requests with all fields (including
|
||||||
|
// stream) and the file intact, both when the form was already parsed and when
|
||||||
|
// it must be re-parsed from the reusable body.
|
||||||
|
func TestConvertImageEditRequestMultipart(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
newMultipartContext := func(t *testing.T, prompt string) *gin.Context {
|
||||||
var body bytes.Buffer
|
var body bytes.Buffer
|
||||||
writer := multipart.NewWriter(&body)
|
writer := multipart.NewWriter(&body)
|
||||||
require.NoError(t, writer.WriteField("model", "gpt-image-1"))
|
require.NoError(t, writer.WriteField("model", "gpt-image-1"))
|
||||||
require.NoError(t, writer.WriteField("prompt", "edit this image"))
|
require.NoError(t, writer.WriteField("prompt", prompt))
|
||||||
require.NoError(t, writer.WriteField("stream", "true"))
|
require.NoError(t, writer.WriteField("stream", "true"))
|
||||||
require.NoError(t, writer.WriteField("partial_images", "3"))
|
require.NoError(t, writer.WriteField("partial_images", "3"))
|
||||||
part, err := writer.CreateFormFile("image", "input.png")
|
part, err := writer.CreateFormFile("image", "input.png")
|
||||||
@@ -32,34 +36,33 @@ func TestConvertImageEditRequestKeepsValidMultipartStreamFields(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NoError(t, writer.Close())
|
require.NoError(t, writer.Close())
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||||
c, _ := gin.CreateTestContext(recorder)
|
|
||||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/edits", &body)
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/edits", &body)
|
||||||
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
|
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
|
||||||
require.NoError(t, c.Request.ParseMultipartForm(32<<20))
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
convertAndReplay := func(t *testing.T, c *gin.Context, prompt string) {
|
||||||
info := &relaycommon.RelayInfo{
|
info := &relaycommon.RelayInfo{
|
||||||
RelayMode: relayconstant.RelayModeImagesEdits,
|
RelayMode: relayconstant.RelayModeImagesEdits,
|
||||||
}
|
}
|
||||||
request := dto.ImageRequest{
|
request := dto.ImageRequest{
|
||||||
Model: "gpt-image-1",
|
Model: "gpt-image-1",
|
||||||
Prompt: "edit this image",
|
Prompt: prompt,
|
||||||
Stream: true,
|
Stream: common.GetPointer(true),
|
||||||
}
|
}
|
||||||
|
|
||||||
converted, err := (&Adaptor{}).ConvertImageRequest(c, info, request)
|
converted, err := (&Adaptor{}).ConvertImageRequest(c, info, request)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
convertedBody, ok := converted.(*bytes.Buffer)
|
convertedBody, ok := converted.(*bytes.Buffer)
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
|
|
||||||
contentType := c.Request.Header.Get("Content-Type")
|
|
||||||
replayedRequest := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(convertedBody.Bytes()))
|
replayedRequest := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(convertedBody.Bytes()))
|
||||||
replayedRequest.Header.Set("Content-Type", contentType)
|
replayedRequest.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||||
require.NoError(t, replayedRequest.ParseMultipartForm(32<<20))
|
require.NoError(t, replayedRequest.ParseMultipartForm(32<<20))
|
||||||
|
|
||||||
require.Equal(t, "gpt-image-1", replayedRequest.PostForm.Get("model"))
|
require.Equal(t, "gpt-image-1", replayedRequest.PostForm.Get("model"))
|
||||||
require.Equal(t, "edit this image", replayedRequest.PostForm.Get("prompt"))
|
require.Equal(t, prompt, replayedRequest.PostForm.Get("prompt"))
|
||||||
require.Equal(t, "true", replayedRequest.PostForm.Get("stream"))
|
require.Equal(t, "true", replayedRequest.PostForm.Get("stream"))
|
||||||
require.Equal(t, "3", replayedRequest.PostForm.Get("partial_images"))
|
require.Equal(t, "3", replayedRequest.PostForm.Get("partial_images"))
|
||||||
require.Len(t, replayedRequest.MultipartForm.File["image"], 1)
|
require.Len(t, replayedRequest.MultipartForm.File["image"], 1)
|
||||||
@@ -70,27 +73,19 @@ func TestConvertImageEditRequestKeepsValidMultipartStreamFields(t *testing.T) {
|
|||||||
fileBytes, err := io.ReadAll(file)
|
fileBytes, err := io.ReadAll(file)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, []byte("fake image"), fileBytes)
|
require.Equal(t, []byte("fake image"), fileBytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestConvertImageEditRequestParsesReusableMultipartWhenFormIsMissing verifies fallback parsing.
|
t.Run("with pre-parsed form", func(t *testing.T) {
|
||||||
func TestConvertImageEditRequestParsesReusableMultipartWhenFormIsMissing(t *testing.T) {
|
prompt := "edit this image"
|
||||||
gin.SetMode(gin.TestMode)
|
c := newMultipartContext(t, prompt)
|
||||||
|
require.NoError(t, c.Request.ParseMultipartForm(32<<20))
|
||||||
|
|
||||||
var body bytes.Buffer
|
convertAndReplay(t, c, prompt)
|
||||||
writer := multipart.NewWriter(&body)
|
})
|
||||||
require.NoError(t, writer.WriteField("model", "gpt-image-1"))
|
|
||||||
require.NoError(t, writer.WriteField("prompt", "edit without pre-parsed form"))
|
|
||||||
require.NoError(t, writer.WriteField("stream", "true"))
|
|
||||||
part, err := writer.CreateFormFile("image", "input.png")
|
|
||||||
require.NoError(t, err)
|
|
||||||
_, err = part.Write([]byte("fake image"))
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NoError(t, writer.Close())
|
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
t.Run("re-parses reusable body when form is missing", func(t *testing.T) {
|
||||||
c, _ := gin.CreateTestContext(recorder)
|
prompt := "edit without pre-parsed form"
|
||||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/edits", &body)
|
c := newMultipartContext(t, prompt)
|
||||||
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
|
|
||||||
|
|
||||||
storage, err := common.GetBodyStorage(c)
|
storage, err := common.GetBodyStorage(c)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -98,24 +93,6 @@ func TestConvertImageEditRequestParsesReusableMultipartWhenFormIsMissing(t *test
|
|||||||
c.Request.MultipartForm = nil
|
c.Request.MultipartForm = nil
|
||||||
c.Request.PostForm = nil
|
c.Request.PostForm = nil
|
||||||
|
|
||||||
info := &relaycommon.RelayInfo{
|
convertAndReplay(t, c, prompt)
|
||||||
RelayMode: relayconstant.RelayModeImagesEdits,
|
})
|
||||||
}
|
|
||||||
request := dto.ImageRequest{
|
|
||||||
Model: "gpt-image-1",
|
|
||||||
Prompt: "edit without pre-parsed form",
|
|
||||||
Stream: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
converted, err := (&Adaptor{}).ConvertImageRequest(c, info, request)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
convertedBody, ok := converted.(*bytes.Buffer)
|
|
||||||
require.True(t, ok)
|
|
||||||
replayedRequest := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(convertedBody.Bytes()))
|
|
||||||
replayedRequest.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
|
||||||
require.NoError(t, replayedRequest.ParseMultipartForm(32<<20))
|
|
||||||
require.Equal(t, "edit without pre-parsed form", replayedRequest.PostForm.Get("prompt"))
|
|
||||||
require.Equal(t, "true", replayedRequest.PostForm.Get("stream"))
|
|
||||||
require.Len(t, replayedRequest.MultipartForm.File["image"], 1)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,13 +8,34 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/QuantumNous/new-api/constant"
|
"github.com/QuantumNous/new-api/constant"
|
||||||
"github.com/QuantumNous/new-api/dto"
|
|
||||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||||
"github.com/QuantumNous/new-api/relay/helper"
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func newImageTestContext(t *testing.T, body, contentType string, isStream bool) (*gin.Context, *httptest.ResponseRecorder, *http.Response, *relaycommon.RelayInfo) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
|
||||||
|
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: io.NopCloser(strings.NewReader(body)),
|
||||||
|
Header: http.Header{"Content-Type": []string{contentType}},
|
||||||
|
}
|
||||||
|
info := &relaycommon.RelayInfo{
|
||||||
|
ChannelMeta: &relaycommon.ChannelMeta{},
|
||||||
|
IsStream: isStream,
|
||||||
|
}
|
||||||
|
return c, recorder, resp, info
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestOpenaiImageStreamHandlerForwardsSSEAndUsage covers the core SSE path:
|
||||||
|
// chunks are forwarded with rebuilt event lines, usage is extracted and
|
||||||
|
// normalized (input_tokens -> prompt_tokens with details), and [DONE] is
|
||||||
|
// re-emitted to the client.
|
||||||
func TestOpenaiImageStreamHandlerForwardsSSEAndUsage(t *testing.T) {
|
func TestOpenaiImageStreamHandlerForwardsSSEAndUsage(t *testing.T) {
|
||||||
oldMode := gin.Mode()
|
oldMode := gin.Mode()
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
@@ -34,19 +55,7 @@ func TestOpenaiImageStreamHandlerForwardsSSEAndUsage(t *testing.T) {
|
|||||||
``,
|
``,
|
||||||
}, "\n")
|
}, "\n")
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
c, recorder, resp, info := newImageTestContext(t, body, "text/event-stream", true)
|
||||||
c, _ := gin.CreateTestContext(recorder)
|
|
||||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
|
|
||||||
|
|
||||||
resp := &http.Response{
|
|
||||||
StatusCode: http.StatusOK,
|
|
||||||
Body: io.NopCloser(strings.NewReader(body)),
|
|
||||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
|
||||||
}
|
|
||||||
info := &relaycommon.RelayInfo{
|
|
||||||
ChannelMeta: &relaycommon.ChannelMeta{},
|
|
||||||
IsStream: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
usage, err := OpenaiImageStreamHandler(c, info, resp)
|
usage, err := OpenaiImageStreamHandler(c, info, resp)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
@@ -62,36 +71,8 @@ func TestOpenaiImageStreamHandlerForwardsSSEAndUsage(t *testing.T) {
|
|||||||
require.Equal(t, "text/event-stream", recorder.Header().Get("Content-Type"))
|
require.Equal(t, "text/event-stream", recorder.Header().Get("Content-Type"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOpenaiImageStreamHandlerForwardsLargeSSELine(t *testing.T) {
|
// TestOpenaiImageStreamHandlerWrapsJSONResponse covers the non-SSE fallback:
|
||||||
oldMode := gin.Mode()
|
// a JSON upstream response is wrapped into pseudo-SSE completed events.
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
t.Cleanup(func() { gin.SetMode(oldMode) })
|
|
||||||
|
|
||||||
payload := strings.Repeat("x", helper.DefaultMaxScannerBufferSize+1)
|
|
||||||
body := "data: " + payload + "\n\n"
|
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
|
||||||
c, _ := gin.CreateTestContext(recorder)
|
|
||||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
|
|
||||||
|
|
||||||
resp := &http.Response{
|
|
||||||
StatusCode: http.StatusOK,
|
|
||||||
Body: io.NopCloser(strings.NewReader(body)),
|
|
||||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
|
||||||
}
|
|
||||||
info := &relaycommon.RelayInfo{
|
|
||||||
ChannelMeta: &relaycommon.ChannelMeta{},
|
|
||||||
IsStream: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
usage, err := OpenaiImageStreamHandler(c, info, resp)
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.NotNil(t, usage)
|
|
||||||
require.Contains(t, recorder.Body.String(), payload)
|
|
||||||
require.NotNil(t, info.StreamStatus)
|
|
||||||
require.Equal(t, relaycommon.StreamEndReasonEOF, info.StreamStatus.EndReason)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenaiImageStreamHandlerWrapsJSONResponse(t *testing.T) {
|
func TestOpenaiImageStreamHandlerWrapsJSONResponse(t *testing.T) {
|
||||||
oldMode := gin.Mode()
|
oldMode := gin.Mode()
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
@@ -99,19 +80,7 @@ func TestOpenaiImageStreamHandlerWrapsJSONResponse(t *testing.T) {
|
|||||||
|
|
||||||
body := `{"created":1710000000,"data":[{"b64_json":"final","revised_prompt":"draw a cat"}],"usage":{"input_tokens":3,"output_tokens":4,"total_tokens":7,"input_tokens_details":{"image_tokens":2,"text_tokens":1}}}`
|
body := `{"created":1710000000,"data":[{"b64_json":"final","revised_prompt":"draw a cat"}],"usage":{"input_tokens":3,"output_tokens":4,"total_tokens":7,"input_tokens_details":{"image_tokens":2,"text_tokens":1}}}`
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
c, recorder, resp, info := newImageTestContext(t, body, "application/json", true)
|
||||||
c, _ := gin.CreateTestContext(recorder)
|
|
||||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
|
|
||||||
|
|
||||||
resp := &http.Response{
|
|
||||||
StatusCode: http.StatusOK,
|
|
||||||
Body: io.NopCloser(strings.NewReader(body)),
|
|
||||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
|
||||||
}
|
|
||||||
info := &relaycommon.RelayInfo{
|
|
||||||
ChannelMeta: &relaycommon.ChannelMeta{},
|
|
||||||
IsStream: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
usage, err := OpenaiImageStreamHandler(c, info, resp)
|
usage, err := OpenaiImageStreamHandler(c, info, resp)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
@@ -129,28 +98,20 @@ func TestOpenaiImageStreamHandlerWrapsJSONResponse(t *testing.T) {
|
|||||||
require.Contains(t, recorder.Body.String(), `data: [DONE]`)
|
require.Contains(t, recorder.Body.String(), `data: [DONE]`)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOpenaiHandlerWithUsageReturnsImageJSONError(t *testing.T) {
|
// TestOpenaiImageHandlersReturnJSONError covers JSON error responses for both
|
||||||
|
// entry points: the non-streaming handler and the stream handler's non-SSE
|
||||||
|
// fallback. Neither must leak the error body to the client.
|
||||||
|
func TestOpenaiImageHandlersReturnJSONError(t *testing.T) {
|
||||||
oldMode := gin.Mode()
|
oldMode := gin.Mode()
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
t.Cleanup(func() { gin.SetMode(oldMode) })
|
t.Cleanup(func() { gin.SetMode(oldMode) })
|
||||||
|
|
||||||
body := `{"error":{"message":"content moderation failed","type":"upstream_error","code":"content_moderation_failed","status":502}}`
|
body := `{"error":{"message":"content moderation failed","type":"upstream_error","code":"content_moderation_failed","status":502}}`
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
t.Run("non-streaming handler", func(t *testing.T) {
|
||||||
c, _ := gin.CreateTestContext(recorder)
|
c, recorder, resp, info := newImageTestContext(t, body, "application/json", false)
|
||||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
|
|
||||||
|
|
||||||
resp := &http.Response{
|
usage, err := OpenaiImageHandler(c, info, resp)
|
||||||
StatusCode: http.StatusOK,
|
|
||||||
Body: io.NopCloser(strings.NewReader(body)),
|
|
||||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
|
||||||
}
|
|
||||||
info := &relaycommon.RelayInfo{
|
|
||||||
ChannelMeta: &relaycommon.ChannelMeta{},
|
|
||||||
IsStream: false,
|
|
||||||
}
|
|
||||||
|
|
||||||
usage, err := OpenaiHandlerWithUsage(c, info, resp)
|
|
||||||
require.Nil(t, usage)
|
require.Nil(t, usage)
|
||||||
require.NotNil(t, err)
|
require.NotNil(t, err)
|
||||||
require.Equal(t, http.StatusOK, err.StatusCode)
|
require.Equal(t, http.StatusOK, err.StatusCode)
|
||||||
@@ -159,43 +120,32 @@ func TestOpenaiHandlerWithUsageReturnsImageJSONError(t *testing.T) {
|
|||||||
require.Equal(t, "upstream_error", oaiError.Type)
|
require.Equal(t, "upstream_error", oaiError.Type)
|
||||||
require.Equal(t, "content_moderation_failed", oaiError.Code)
|
require.Equal(t, "content_moderation_failed", oaiError.Code)
|
||||||
require.Empty(t, recorder.Body.String())
|
require.Empty(t, recorder.Body.String())
|
||||||
}
|
})
|
||||||
|
|
||||||
func TestOpenaiImageStreamHandlerReturnsJSONErrorFallback(t *testing.T) {
|
t.Run("stream handler JSON fallback", func(t *testing.T) {
|
||||||
oldMode := gin.Mode()
|
c, recorder, resp, info := newImageTestContext(t, body, "application/json", true)
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
t.Cleanup(func() { gin.SetMode(oldMode) })
|
|
||||||
|
|
||||||
body := `{"error":{"message":"image edit failed","type":"upstream_error","code":"content_moderation_failed","status":502}}`
|
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
|
||||||
c, _ := gin.CreateTestContext(recorder)
|
|
||||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
|
|
||||||
|
|
||||||
resp := &http.Response{
|
|
||||||
StatusCode: http.StatusOK,
|
|
||||||
Body: io.NopCloser(strings.NewReader(body)),
|
|
||||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
|
||||||
}
|
|
||||||
info := &relaycommon.RelayInfo{
|
|
||||||
ChannelMeta: &relaycommon.ChannelMeta{},
|
|
||||||
IsStream: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
usage, err := OpenaiImageStreamHandler(c, info, resp)
|
usage, err := OpenaiImageStreamHandler(c, info, resp)
|
||||||
require.Nil(t, usage)
|
require.Nil(t, usage)
|
||||||
require.NotNil(t, err)
|
require.NotNil(t, err)
|
||||||
require.Equal(t, http.StatusOK, err.StatusCode)
|
require.Equal(t, http.StatusOK, err.StatusCode)
|
||||||
oaiError := err.ToOpenAIError()
|
require.Equal(t, "content moderation failed", err.ToOpenAIError().Message)
|
||||||
require.Equal(t, "image edit failed", oaiError.Message)
|
|
||||||
require.Empty(t, recorder.Body.String())
|
require.Empty(t, recorder.Body.String())
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestOpenaiImageStreamHandlerRecordsUpstreamErrorEvent verifies that an error
|
||||||
|
// event inside the SSE stream is recorded as a soft error while the payload is
|
||||||
|
// still forwarded to the client.
|
||||||
func TestOpenaiImageStreamHandlerRecordsUpstreamErrorEvent(t *testing.T) {
|
func TestOpenaiImageStreamHandlerRecordsUpstreamErrorEvent(t *testing.T) {
|
||||||
oldMode := gin.Mode()
|
oldMode := gin.Mode()
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
t.Cleanup(func() { gin.SetMode(oldMode) })
|
t.Cleanup(func() { gin.SetMode(oldMode) })
|
||||||
|
|
||||||
|
oldTimeout := constant.StreamingTimeout
|
||||||
|
constant.StreamingTimeout = 30
|
||||||
|
t.Cleanup(func() { constant.StreamingTimeout = oldTimeout })
|
||||||
|
|
||||||
body := strings.Join([]string{
|
body := strings.Join([]string{
|
||||||
`event: image_generation.partial_image`,
|
`event: image_generation.partial_image`,
|
||||||
`data: {"type":"image_generation.partial_image","b64_json":"partial"}`,
|
`data: {"type":"image_generation.partial_image","b64_json":"partial"}`,
|
||||||
@@ -205,49 +155,19 @@ func TestOpenaiImageStreamHandlerRecordsUpstreamErrorEvent(t *testing.T) {
|
|||||||
``,
|
``,
|
||||||
}, "\n")
|
}, "\n")
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
c, recorder, resp, info := newImageTestContext(t, body, "text/event-stream", true)
|
||||||
c, _ := gin.CreateTestContext(recorder)
|
|
||||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
|
|
||||||
|
|
||||||
resp := &http.Response{
|
|
||||||
StatusCode: http.StatusOK,
|
|
||||||
Body: io.NopCloser(strings.NewReader(body)),
|
|
||||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
|
||||||
}
|
|
||||||
info := &relaycommon.RelayInfo{
|
|
||||||
ChannelMeta: &relaycommon.ChannelMeta{},
|
|
||||||
IsStream: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
usage, err := OpenaiImageStreamHandler(c, info, resp)
|
usage, err := OpenaiImageStreamHandler(c, info, resp)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.NotNil(t, usage)
|
require.NotNil(t, usage)
|
||||||
require.NotNil(t, info.StreamStatus)
|
require.NotNil(t, info.StreamStatus)
|
||||||
require.Equal(t, relaycommon.StreamEndReasonHandlerStop, info.StreamStatus.EndReason)
|
require.Equal(t, relaycommon.StreamEndReasonEOF, info.StreamStatus.EndReason)
|
||||||
require.True(t, info.StreamStatus.HasErrors())
|
require.True(t, info.StreamStatus.HasErrors())
|
||||||
require.Equal(t, 1, info.StreamStatus.TotalErrorCount())
|
require.Equal(t, 1, info.StreamStatus.TotalErrorCount())
|
||||||
require.Contains(t, info.StreamStatus.Errors[0].Message, "INTERNAL_ERROR")
|
require.Contains(t, info.StreamStatus.Errors[0].Message, "INTERNAL_ERROR")
|
||||||
require.Contains(t, recorder.Body.String(), `event: error`)
|
// The scanner strips the upstream "event: error" line; the event name is
|
||||||
|
// rebuilt from the JSON "type" field (upstream_error). The error message
|
||||||
|
// is still forwarded in the data: payload (stream ID 77).
|
||||||
|
require.Contains(t, recorder.Body.String(), `event: upstream_error`)
|
||||||
require.Contains(t, recorder.Body.String(), `stream ID 77`)
|
require.Contains(t, recorder.Body.String(), `stream ID 77`)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNormalizeOpenAIUsageMapsImageTokenDetailsWithoutDoubleCounting(t *testing.T) {
|
|
||||||
usage := &dto.Usage{
|
|
||||||
InputTokens: 5000,
|
|
||||||
OutputTokens: 4000,
|
|
||||||
InputTokensDetails: &dto.InputTokenDetails{
|
|
||||||
CachedCreationTokens: 200,
|
|
||||||
ImageTokens: 1000,
|
|
||||||
TextTokens: 4000,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
normalizeOpenAIUsage(usage)
|
|
||||||
|
|
||||||
require.Equal(t, 5000, usage.PromptTokens)
|
|
||||||
require.Equal(t, 4000, usage.CompletionTokens)
|
|
||||||
require.Equal(t, 9000, usage.TotalTokens)
|
|
||||||
require.Equal(t, 200, usage.PromptTokensDetails.CachedCreationTokens)
|
|
||||||
require.Equal(t, 1000, usage.PromptTokensDetails.ImageTokens)
|
|
||||||
require.Equal(t, 4000, usage.PromptTokensDetails.TextTokens)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,13 +1,10 @@
|
|||||||
package openai
|
package openai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/QuantumNous/new-api/common"
|
"github.com/QuantumNous/new-api/common"
|
||||||
"github.com/QuantumNous/new-api/constant"
|
"github.com/QuantumNous/new-api/constant"
|
||||||
@@ -17,12 +14,9 @@ import (
|
|||||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||||
"github.com/QuantumNous/new-api/relay/helper"
|
"github.com/QuantumNous/new-api/relay/helper"
|
||||||
"github.com/QuantumNous/new-api/service"
|
"github.com/QuantumNous/new-api/service"
|
||||||
|
|
||||||
"github.com/QuantumNous/new-api/types"
|
"github.com/QuantumNous/new-api/types"
|
||||||
|
|
||||||
"github.com/bytedance/gopkg/util/gopool"
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
|
func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
|
||||||
@@ -296,672 +290,3 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
|||||||
|
|
||||||
return &simpleResponse.Usage, nil
|
return &simpleResponse.Usage, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func streamTTSResponse(c *gin.Context, resp *http.Response) {
|
|
||||||
c.Writer.WriteHeaderNow()
|
|
||||||
|
|
||||||
flusher, ok := c.Writer.(http.Flusher)
|
|
||||||
if !ok {
|
|
||||||
logger.LogWarn(c, "streaming not supported")
|
|
||||||
_, err := io.Copy(c.Writer, resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
logger.LogWarn(c, err.Error())
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
buffer := make([]byte, 4096)
|
|
||||||
for {
|
|
||||||
n, err := resp.Body.Read(buffer)
|
|
||||||
//logger.LogInfo(c, fmt.Sprintf("streamTTSResponse read %d bytes", n))
|
|
||||||
if n > 0 {
|
|
||||||
if _, writeErr := c.Writer.Write(buffer[:n]); writeErr != nil {
|
|
||||||
logger.LogError(c, writeErr.Error())
|
|
||||||
break
|
|
||||||
}
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
if err != io.EOF {
|
|
||||||
logger.LogError(c, err.Error())
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.RealtimeUsage) {
|
|
||||||
if info == nil || info.ClientWs == nil || info.TargetWs == nil {
|
|
||||||
return types.NewError(fmt.Errorf("invalid websocket connection"), types.ErrorCodeBadResponse), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
info.IsStream = true
|
|
||||||
clientConn := info.ClientWs
|
|
||||||
targetConn := info.TargetWs
|
|
||||||
|
|
||||||
clientClosed := make(chan struct{})
|
|
||||||
targetClosed := make(chan struct{})
|
|
||||||
sendChan := make(chan []byte, 100)
|
|
||||||
receiveChan := make(chan []byte, 100)
|
|
||||||
errChan := make(chan error, 2)
|
|
||||||
|
|
||||||
usage := &dto.RealtimeUsage{}
|
|
||||||
localUsage := &dto.RealtimeUsage{}
|
|
||||||
sumUsage := &dto.RealtimeUsage{}
|
|
||||||
|
|
||||||
gopool.Go(func() {
|
|
||||||
defer func() {
|
|
||||||
if r := recover(); r != nil {
|
|
||||||
errChan <- fmt.Errorf("panic in client reader: %v", r)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-c.Done():
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
_, message, err := clientConn.ReadMessage()
|
|
||||||
if err != nil {
|
|
||||||
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
|
||||||
errChan <- fmt.Errorf("error reading from client: %v", err)
|
|
||||||
}
|
|
||||||
close(clientClosed)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
realtimeEvent := &dto.RealtimeEvent{}
|
|
||||||
err = common.Unmarshal(message, realtimeEvent)
|
|
||||||
if err != nil {
|
|
||||||
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdate {
|
|
||||||
if realtimeEvent.Session != nil {
|
|
||||||
if realtimeEvent.Session.Tools != nil {
|
|
||||||
info.RealtimeTools = realtimeEvent.Session.Tools
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
|
|
||||||
if err != nil {
|
|
||||||
errChan <- fmt.Errorf("error counting text token: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
|
|
||||||
localUsage.TotalTokens += textToken + audioToken
|
|
||||||
localUsage.InputTokens += textToken + audioToken
|
|
||||||
localUsage.InputTokenDetails.TextTokens += textToken
|
|
||||||
localUsage.InputTokenDetails.AudioTokens += audioToken
|
|
||||||
|
|
||||||
err = helper.WssString(c, targetConn, string(message))
|
|
||||||
if err != nil {
|
|
||||||
errChan <- fmt.Errorf("error writing to target: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case sendChan <- message:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
gopool.Go(func() {
|
|
||||||
defer func() {
|
|
||||||
if r := recover(); r != nil {
|
|
||||||
errChan <- fmt.Errorf("panic in target reader: %v", r)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-c.Done():
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
_, message, err := targetConn.ReadMessage()
|
|
||||||
if err != nil {
|
|
||||||
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
|
||||||
errChan <- fmt.Errorf("error reading from target: %v", err)
|
|
||||||
}
|
|
||||||
close(targetClosed)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
info.SetFirstResponseTime()
|
|
||||||
realtimeEvent := &dto.RealtimeEvent{}
|
|
||||||
err = common.Unmarshal(message, realtimeEvent)
|
|
||||||
if err != nil {
|
|
||||||
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone {
|
|
||||||
realtimeUsage := realtimeEvent.Response.Usage
|
|
||||||
if realtimeUsage != nil {
|
|
||||||
usage.TotalTokens += realtimeUsage.TotalTokens
|
|
||||||
usage.InputTokens += realtimeUsage.InputTokens
|
|
||||||
usage.OutputTokens += realtimeUsage.OutputTokens
|
|
||||||
usage.InputTokenDetails.AudioTokens += realtimeUsage.InputTokenDetails.AudioTokens
|
|
||||||
usage.InputTokenDetails.CachedTokens += realtimeUsage.InputTokenDetails.CachedTokens
|
|
||||||
usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens
|
|
||||||
usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens
|
|
||||||
usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens
|
|
||||||
err := preConsumeUsage(c, info, usage, sumUsage)
|
|
||||||
if err != nil {
|
|
||||||
errChan <- fmt.Errorf("error consume usage: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// 本次计费完成,清除
|
|
||||||
usage = &dto.RealtimeUsage{}
|
|
||||||
|
|
||||||
localUsage = &dto.RealtimeUsage{}
|
|
||||||
} else {
|
|
||||||
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
|
|
||||||
if err != nil {
|
|
||||||
errChan <- fmt.Errorf("error counting text token: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
|
|
||||||
localUsage.TotalTokens += textToken + audioToken
|
|
||||||
info.IsFirstRequest = false
|
|
||||||
localUsage.InputTokens += textToken + audioToken
|
|
||||||
localUsage.InputTokenDetails.TextTokens += textToken
|
|
||||||
localUsage.InputTokenDetails.AudioTokens += audioToken
|
|
||||||
err = preConsumeUsage(c, info, localUsage, sumUsage)
|
|
||||||
if err != nil {
|
|
||||||
errChan <- fmt.Errorf("error consume usage: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// 本次计费完成,清除
|
|
||||||
localUsage = &dto.RealtimeUsage{}
|
|
||||||
// print now usage
|
|
||||||
}
|
|
||||||
logger.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
|
|
||||||
logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
|
|
||||||
logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
|
|
||||||
|
|
||||||
} else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
|
|
||||||
realtimeSession := realtimeEvent.Session
|
|
||||||
if realtimeSession != nil {
|
|
||||||
// update audio format
|
|
||||||
info.InputAudioFormat = common.GetStringIfEmpty(realtimeSession.InputAudioFormat, info.InputAudioFormat)
|
|
||||||
info.OutputAudioFormat = common.GetStringIfEmpty(realtimeSession.OutputAudioFormat, info.OutputAudioFormat)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
|
|
||||||
if err != nil {
|
|
||||||
errChan <- fmt.Errorf("error counting text token: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
|
|
||||||
localUsage.TotalTokens += textToken + audioToken
|
|
||||||
localUsage.OutputTokens += textToken + audioToken
|
|
||||||
localUsage.OutputTokenDetails.TextTokens += textToken
|
|
||||||
localUsage.OutputTokenDetails.AudioTokens += audioToken
|
|
||||||
}
|
|
||||||
|
|
||||||
err = helper.WssString(c, clientConn, string(message))
|
|
||||||
if err != nil {
|
|
||||||
errChan <- fmt.Errorf("error writing to client: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case receiveChan <- message:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-clientClosed:
|
|
||||||
case <-targetClosed:
|
|
||||||
case err := <-errChan:
|
|
||||||
//return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
|
|
||||||
logger.LogError(c, "realtime error: "+err.Error())
|
|
||||||
case <-c.Done():
|
|
||||||
}
|
|
||||||
|
|
||||||
if usage.TotalTokens != 0 {
|
|
||||||
_ = preConsumeUsage(c, info, usage, sumUsage)
|
|
||||||
}
|
|
||||||
|
|
||||||
if localUsage.TotalTokens != 0 {
|
|
||||||
_ = preConsumeUsage(c, info, localUsage, sumUsage)
|
|
||||||
}
|
|
||||||
|
|
||||||
// check usage total tokens, if 0, use local usage
|
|
||||||
|
|
||||||
return nil, sumUsage
|
|
||||||
}
|
|
||||||
|
|
||||||
func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.RealtimeUsage, totalUsage *dto.RealtimeUsage) error {
|
|
||||||
if usage == nil || totalUsage == nil {
|
|
||||||
return fmt.Errorf("invalid usage pointer")
|
|
||||||
}
|
|
||||||
|
|
||||||
totalUsage.TotalTokens += usage.TotalTokens
|
|
||||||
totalUsage.InputTokens += usage.InputTokens
|
|
||||||
totalUsage.OutputTokens += usage.OutputTokens
|
|
||||||
totalUsage.InputTokenDetails.CachedTokens += usage.InputTokenDetails.CachedTokens
|
|
||||||
totalUsage.InputTokenDetails.TextTokens += usage.InputTokenDetails.TextTokens
|
|
||||||
totalUsage.InputTokenDetails.AudioTokens += usage.InputTokenDetails.AudioTokens
|
|
||||||
totalUsage.OutputTokenDetails.TextTokens += usage.OutputTokenDetails.TextTokens
|
|
||||||
totalUsage.OutputTokenDetails.AudioTokens += usage.OutputTokenDetails.AudioTokens
|
|
||||||
// clear usage
|
|
||||||
err := service.PreWssConsumeQuota(ctx, info, usage)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
|
||||||
defer service.CloseResponseBodyGracefully(resp)
|
|
||||||
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
|
|
||||||
var usageResp dto.SimpleResponse
|
|
||||||
err = common.Unmarshal(responseBody, &usageResp)
|
|
||||||
if err != nil {
|
|
||||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
|
|
||||||
if oaiError := usageResp.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
|
|
||||||
return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 写入新的 response body
|
|
||||||
service.IOCopyBytesGracefully(c, resp, responseBody)
|
|
||||||
|
|
||||||
normalizeOpenAIUsage(&usageResp.Usage)
|
|
||||||
applyUsagePostProcessing(info, &usageResp.Usage, responseBody)
|
|
||||||
return &usageResp.Usage, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func normalizeOpenAIUsage(usage *dto.Usage) {
|
|
||||||
if usage == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if usage.InputTokens != 0 {
|
|
||||||
usage.PromptTokens = usage.InputTokens
|
|
||||||
}
|
|
||||||
if usage.OutputTokens != 0 {
|
|
||||||
usage.CompletionTokens = usage.OutputTokens
|
|
||||||
}
|
|
||||||
if usage.InputTokensDetails != nil {
|
|
||||||
usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens
|
|
||||||
usage.PromptTokensDetails.CachedCreationTokens = usage.InputTokensDetails.CachedCreationTokens
|
|
||||||
usage.PromptTokensDetails.ImageTokens = usage.InputTokensDetails.ImageTokens
|
|
||||||
usage.PromptTokensDetails.TextTokens = usage.InputTokensDetails.TextTokens
|
|
||||||
usage.PromptTokensDetails.AudioTokens = usage.InputTokensDetails.AudioTokens
|
|
||||||
}
|
|
||||||
if usage.TotalTokens == 0 {
|
|
||||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func OpenaiImageStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
|
||||||
if resp == nil || resp.Body == nil {
|
|
||||||
logger.LogError(c, "invalid image stream response")
|
|
||||||
return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
|
|
||||||
contentType := strings.ToLower(resp.Header.Get("Content-Type"))
|
|
||||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
|
||||||
return OpenaiHandlerWithUsage(c, info, resp)
|
|
||||||
}
|
|
||||||
if !strings.Contains(contentType, "text/event-stream") {
|
|
||||||
return OpenaiImageJSONAsStreamHandler(c, info, resp)
|
|
||||||
}
|
|
||||||
defer service.CloseResponseBodyGracefully(resp)
|
|
||||||
|
|
||||||
usage := &dto.Usage{}
|
|
||||||
var lastStreamData []byte
|
|
||||||
|
|
||||||
helper.SetEventStreamHeaders(c)
|
|
||||||
if info != nil && info.StreamStatus == nil {
|
|
||||||
info.StreamStatus = relaycommon.NewStreamStatus()
|
|
||||||
}
|
|
||||||
|
|
||||||
reader := bufio.NewReader(resp.Body)
|
|
||||||
currentEvent := ""
|
|
||||||
var readErr error
|
|
||||||
for {
|
|
||||||
line, err := reader.ReadString('\n')
|
|
||||||
if err != nil {
|
|
||||||
readErr = err
|
|
||||||
if len(line) == 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
line = strings.TrimSuffix(line, "\n")
|
|
||||||
line = strings.TrimSuffix(line, "\r")
|
|
||||||
if strings.HasPrefix(line, "event:") {
|
|
||||||
currentEvent = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
|
|
||||||
} else if strings.HasPrefix(line, "data:") {
|
|
||||||
data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
|
||||||
if data == "[DONE]" {
|
|
||||||
if info != nil && info.StreamStatus != nil {
|
|
||||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonDone, nil)
|
|
||||||
}
|
|
||||||
} else if data != "" {
|
|
||||||
if info != nil {
|
|
||||||
info.SetFirstResponseTime()
|
|
||||||
info.ReceivedResponseCount++
|
|
||||||
}
|
|
||||||
lastStreamData = common.StringToByteSlice(data)
|
|
||||||
if info != nil && info.StreamStatus != nil && isOpenAIImageStreamErrorEvent(currentEvent, lastStreamData) {
|
|
||||||
info.StreamStatus.RecordError(extractOpenAIImageStreamErrorMessage(lastStreamData))
|
|
||||||
}
|
|
||||||
var usageResp dto.SimpleResponse
|
|
||||||
if err := common.Unmarshal(lastStreamData, &usageResp); err == nil {
|
|
||||||
normalizeOpenAIUsage(&usageResp.Usage)
|
|
||||||
if service.ValidUsage(&usageResp.Usage) {
|
|
||||||
usage = &usageResp.Usage
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if _, err := c.Writer.Write(append([]byte(line), '\n')); err != nil {
|
|
||||||
if info != nil && info.StreamStatus != nil {
|
|
||||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, err)
|
|
||||||
}
|
|
||||||
return usage, nil
|
|
||||||
}
|
|
||||||
if line == "" {
|
|
||||||
if err := helper.FlushWriter(c); err != nil {
|
|
||||||
if info != nil && info.StreamStatus != nil {
|
|
||||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, err)
|
|
||||||
}
|
|
||||||
return usage, nil
|
|
||||||
}
|
|
||||||
currentEvent = ""
|
|
||||||
}
|
|
||||||
if readErr != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if info != nil && info.StreamStatus != nil {
|
|
||||||
if readErr != nil && readErr != io.EOF {
|
|
||||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonScannerErr, readErr)
|
|
||||||
} else if info.StreamStatus.HasErrors() {
|
|
||||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonHandlerStop, fmt.Errorf("upstream image stream returned error event"))
|
|
||||||
} else if info.StreamStatus.EndReason == relaycommon.StreamEndReasonNone {
|
|
||||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonEOF, nil)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ = helper.FlushWriter(c)
|
|
||||||
|
|
||||||
applyUsagePostProcessing(info, usage, lastStreamData)
|
|
||||||
return usage, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isOpenAIImageStreamErrorEvent(eventName string, data []byte) bool {
|
|
||||||
if strings.EqualFold(strings.TrimSpace(eventName), "error") {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if !json.Valid(data) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
var payload struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Error json.RawMessage `json:"error"`
|
|
||||||
}
|
|
||||||
if err := common.Unmarshal(data, &payload); err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
payloadType := strings.ToLower(strings.TrimSpace(payload.Type))
|
|
||||||
return payloadType == "error" || payloadType == "upstream_error" || len(payload.Error) > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func extractOpenAIImageStreamErrorMessage(data []byte) string {
|
|
||||||
if len(data) == 0 || !json.Valid(data) {
|
|
||||||
return "upstream image stream returned error event"
|
|
||||||
}
|
|
||||||
var payload struct {
|
|
||||||
Message string `json:"message"`
|
|
||||||
Error json.RawMessage `json:"error"`
|
|
||||||
}
|
|
||||||
if err := common.Unmarshal(data, &payload); err != nil {
|
|
||||||
return "upstream image stream returned error event"
|
|
||||||
}
|
|
||||||
if msg := strings.TrimSpace(payload.Message); msg != "" {
|
|
||||||
return msg
|
|
||||||
}
|
|
||||||
if len(payload.Error) > 0 {
|
|
||||||
var nested struct {
|
|
||||||
Message string `json:"message"`
|
|
||||||
}
|
|
||||||
if err := common.Unmarshal(payload.Error, &nested); err == nil {
|
|
||||||
if msg := strings.TrimSpace(nested.Message); msg != "" {
|
|
||||||
return msg
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if msg := strings.TrimSpace(common.JsonRawMessageToString(payload.Error)); msg != "" {
|
|
||||||
return msg
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return "upstream image stream returned error event"
|
|
||||||
}
|
|
||||||
|
|
||||||
func OpenaiImageJSONAsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
|
||||||
defer service.CloseResponseBodyGracefully(resp)
|
|
||||||
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
|
|
||||||
var imageResp dto.ImageResponse
|
|
||||||
if err := common.Unmarshal(responseBody, &imageResp); err != nil {
|
|
||||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
|
|
||||||
var usageResp dto.SimpleResponse
|
|
||||||
_ = common.Unmarshal(responseBody, &usageResp)
|
|
||||||
if oaiError := usageResp.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
|
|
||||||
return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
|
|
||||||
}
|
|
||||||
normalizeOpenAIUsage(&usageResp.Usage)
|
|
||||||
applyUsagePostProcessing(info, &usageResp.Usage, responseBody)
|
|
||||||
|
|
||||||
helper.SetEventStreamHeaders(c)
|
|
||||||
c.Status(http.StatusOK)
|
|
||||||
|
|
||||||
created := imageResp.Created
|
|
||||||
if created == 0 {
|
|
||||||
created = time.Now().Unix()
|
|
||||||
}
|
|
||||||
if info != nil {
|
|
||||||
info.SetFirstResponseTime()
|
|
||||||
}
|
|
||||||
for _, image := range imageResp.Data {
|
|
||||||
payload := map[string]any{
|
|
||||||
"type": "image_generation.completed",
|
|
||||||
"created_at": created,
|
|
||||||
}
|
|
||||||
if image.Url != "" {
|
|
||||||
payload["url"] = image.Url
|
|
||||||
}
|
|
||||||
if image.B64Json != "" {
|
|
||||||
payload["b64_json"] = image.B64Json
|
|
||||||
}
|
|
||||||
if image.RevisedPrompt != "" {
|
|
||||||
payload["revised_prompt"] = image.RevisedPrompt
|
|
||||||
}
|
|
||||||
if service.ValidUsage(&usageResp.Usage) {
|
|
||||||
payload["usage"] = usageResp.Usage
|
|
||||||
}
|
|
||||||
if err := writeOpenaiImageStreamPayload(c, "image_generation.completed", payload); err != nil {
|
|
||||||
if info != nil && info.StreamStatus != nil {
|
|
||||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, err)
|
|
||||||
}
|
|
||||||
return &usageResp.Usage, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := writeOpenaiImageStreamDone(c); err != nil {
|
|
||||||
if info != nil && info.StreamStatus != nil {
|
|
||||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, err)
|
|
||||||
}
|
|
||||||
return &usageResp.Usage, nil
|
|
||||||
}
|
|
||||||
if info != nil {
|
|
||||||
info.ReceivedResponseCount += len(imageResp.Data)
|
|
||||||
if info.StreamStatus == nil {
|
|
||||||
info.StreamStatus = relaycommon.NewStreamStatus()
|
|
||||||
}
|
|
||||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonDone, nil)
|
|
||||||
}
|
|
||||||
return &usageResp.Usage, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func writeOpenaiImageStreamPayload(c *gin.Context, eventName string, payload any) error {
|
|
||||||
data, err := common.Marshal(payload)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if eventName != "" {
|
|
||||||
if _, err := fmt.Fprintf(c.Writer, "event: %s\n", eventName); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", data); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return helper.FlushWriter(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
func writeOpenaiImageStreamDone(c *gin.Context) error {
|
|
||||||
if _, err := fmt.Fprint(c.Writer, "data: [DONE]\n\n"); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return helper.FlushWriter(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
func applyUsagePostProcessing(info *relaycommon.RelayInfo, usage *dto.Usage, responseBody []byte) {
|
|
||||||
if info == nil || usage == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
switch info.ChannelType {
|
|
||||||
case constant.ChannelTypeDeepSeek:
|
|
||||||
if usage.PromptTokensDetails.CachedTokens == 0 && usage.PromptCacheHitTokens != 0 {
|
|
||||||
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
|
|
||||||
}
|
|
||||||
case constant.ChannelTypeZhipu_v4:
|
|
||||||
// 智普的cached_tokens在标准位置: usage.prompt_tokens_details.cached_tokens
|
|
||||||
if usage.PromptTokensDetails.CachedTokens == 0 {
|
|
||||||
if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 0 {
|
|
||||||
usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens
|
|
||||||
} else if cachedTokens, ok := extractCachedTokensFromBody(responseBody); ok {
|
|
||||||
usage.PromptTokensDetails.CachedTokens = cachedTokens
|
|
||||||
} else if usage.PromptCacheHitTokens > 0 {
|
|
||||||
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case constant.ChannelTypeMoonshot:
|
|
||||||
// Moonshot的cached_tokens在非标准位置: choices[].usage.cached_tokens
|
|
||||||
if usage.PromptTokensDetails.CachedTokens == 0 {
|
|
||||||
if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 0 {
|
|
||||||
usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens
|
|
||||||
} else if cachedTokens, ok := extractMoonshotCachedTokensFromBody(responseBody); ok {
|
|
||||||
usage.PromptTokensDetails.CachedTokens = cachedTokens
|
|
||||||
} else if cachedTokens, ok := extractCachedTokensFromBody(responseBody); ok {
|
|
||||||
usage.PromptTokensDetails.CachedTokens = cachedTokens
|
|
||||||
} else if usage.PromptCacheHitTokens > 0 {
|
|
||||||
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case constant.ChannelTypeOpenAI:
|
|
||||||
if usage.PromptTokensDetails.CachedTokens == 0 {
|
|
||||||
if cachedTokens, ok := extractLlamaCachedTokensFromBody(responseBody); ok {
|
|
||||||
usage.PromptTokensDetails.CachedTokens = cachedTokens
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func extractCachedTokensFromBody(body []byte) (int, bool) {
|
|
||||||
if len(body) == 0 {
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
var payload struct {
|
|
||||||
Usage struct {
|
|
||||||
PromptTokensDetails struct {
|
|
||||||
CachedTokens *int `json:"cached_tokens"`
|
|
||||||
} `json:"prompt_tokens_details"`
|
|
||||||
CachedTokens *int `json:"cached_tokens"`
|
|
||||||
PromptCacheHitTokens *int `json:"prompt_cache_hit_tokens"`
|
|
||||||
} `json:"usage"`
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := common.Unmarshal(body, &payload); err != nil {
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
if payload.Usage.PromptTokensDetails.CachedTokens != nil {
|
|
||||||
return *payload.Usage.PromptTokensDetails.CachedTokens, true
|
|
||||||
}
|
|
||||||
if payload.Usage.CachedTokens != nil {
|
|
||||||
return *payload.Usage.CachedTokens, true
|
|
||||||
}
|
|
||||||
if payload.Usage.PromptCacheHitTokens != nil {
|
|
||||||
return *payload.Usage.PromptCacheHitTokens, true
|
|
||||||
}
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// extractMoonshotCachedTokensFromBody 从Moonshot的非标准位置提取cached_tokens
|
|
||||||
// Moonshot的流式响应格式: {"choices":[{"usage":{"cached_tokens":111}}]}
|
|
||||||
func extractMoonshotCachedTokensFromBody(body []byte) (int, bool) {
|
|
||||||
if len(body) == 0 {
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
var payload struct {
|
|
||||||
Choices []struct {
|
|
||||||
Usage struct {
|
|
||||||
CachedTokens *int `json:"cached_tokens"`
|
|
||||||
} `json:"usage"`
|
|
||||||
} `json:"choices"`
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := common.Unmarshal(body, &payload); err != nil {
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// 遍历choices查找cached_tokens
|
|
||||||
for _, choice := range payload.Choices {
|
|
||||||
if choice.Usage.CachedTokens != nil && *choice.Usage.CachedTokens > 0 {
|
|
||||||
return *choice.Usage.CachedTokens, true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// extractLlamaCachedTokensFromBody 从llama.cpp的非标准位置提取cache_n
|
|
||||||
func extractLlamaCachedTokensFromBody(body []byte) (int, bool) {
|
|
||||||
if len(body) == 0 {
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
var payload struct {
|
|
||||||
Timings struct {
|
|
||||||
CachedTokens *int `json:"cache_n"`
|
|
||||||
} `json:"timings"`
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := common.Unmarshal(body, &payload); err != nil {
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
if payload.Timings.CachedTokens == nil {
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
return *payload.Timings.CachedTokens, true
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -0,0 +1,287 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/QuantumNous/new-api/common"
|
||||||
|
"github.com/QuantumNous/new-api/dto"
|
||||||
|
"github.com/QuantumNous/new-api/logger"
|
||||||
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||||
|
"github.com/QuantumNous/new-api/relay/helper"
|
||||||
|
"github.com/QuantumNous/new-api/service"
|
||||||
|
"github.com/QuantumNous/new-api/types"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OpenaiImageHandler handles non-streaming OpenAI image responses
|
||||||
|
// (generations/edits), returning the parsed usage for billing.
|
||||||
|
func OpenaiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||||
|
defer service.CloseResponseBodyGracefully(resp)
|
||||||
|
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
var usageResp dto.SimpleResponse
|
||||||
|
err = common.Unmarshal(responseBody, &usageResp)
|
||||||
|
if err != nil {
|
||||||
|
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
if oaiError := usageResp.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
|
||||||
|
return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 写入新的 response body
|
||||||
|
service.IOCopyBytesGracefully(c, resp, responseBody)
|
||||||
|
|
||||||
|
normalizeOpenAIUsage(&usageResp.Usage)
|
||||||
|
applyUsagePostProcessing(info, &usageResp.Usage, responseBody)
|
||||||
|
return &usageResp.Usage, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalizeOpenAIUsage maps the OpenAI Images usage shape (input_tokens /
|
||||||
|
// output_tokens / input_tokens_details) onto the canonical prompt/completion
|
||||||
|
// fields. It is used only on the OpenAI image relay paths (generations/edits,
|
||||||
|
// streaming and non-streaming): the image API never returns prompt_tokens /
|
||||||
|
// completion_tokens, so the overwrite (=) semantics here are equivalent to the
|
||||||
|
// previous additive (+=) behavior while avoiding any future double-counting if
|
||||||
|
// both field sets are ever populated. Do not reuse this on chat/embedding paths
|
||||||
|
// without revisiting the overwrite semantics.
|
||||||
|
func normalizeOpenAIUsage(usage *dto.Usage) {
|
||||||
|
if usage == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if usage.InputTokens != 0 {
|
||||||
|
usage.PromptTokens = usage.InputTokens
|
||||||
|
}
|
||||||
|
if usage.OutputTokens != 0 {
|
||||||
|
usage.CompletionTokens = usage.OutputTokens
|
||||||
|
}
|
||||||
|
if usage.InputTokensDetails != nil {
|
||||||
|
usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens
|
||||||
|
usage.PromptTokensDetails.CachedCreationTokens = usage.InputTokensDetails.CachedCreationTokens
|
||||||
|
usage.PromptTokensDetails.ImageTokens = usage.InputTokensDetails.ImageTokens
|
||||||
|
usage.PromptTokensDetails.TextTokens = usage.InputTokensDetails.TextTokens
|
||||||
|
usage.PromptTokensDetails.AudioTokens = usage.InputTokensDetails.AudioTokens
|
||||||
|
}
|
||||||
|
if usage.TotalTokens == 0 {
|
||||||
|
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func OpenaiImageStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||||
|
if resp == nil || resp.Body == nil {
|
||||||
|
logger.LogError(c, "invalid image stream response")
|
||||||
|
return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
contentType := strings.ToLower(resp.Header.Get("Content-Type"))
|
||||||
|
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||||
|
return OpenaiImageHandler(c, info, resp)
|
||||||
|
}
|
||||||
|
if !strings.Contains(contentType, "text/event-stream") {
|
||||||
|
return OpenaiImageJSONAsStreamHandler(c, info, resp)
|
||||||
|
}
|
||||||
|
// Reuse the shared streaming engine (helper.StreamScannerHandler) so the
|
||||||
|
// image streaming path gets the same ping keepalive, streaming-timeout
|
||||||
|
// watchdog, client-disconnect detection, panic recovery and goroutine
|
||||||
|
// cleanup as every other relay stream. The scanner delivers only the
|
||||||
|
// "data:" payload, so the SSE "event:" line is rebuilt from the JSON "type"
|
||||||
|
// field (real OpenAI image events keep event == type).
|
||||||
|
usage := &dto.Usage{}
|
||||||
|
var lastStreamData []byte
|
||||||
|
|
||||||
|
helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) {
|
||||||
|
raw := common.StringToByteSlice(data)
|
||||||
|
lastStreamData = raw
|
||||||
|
if isOpenAIImageStreamErrorEvent(raw) {
|
||||||
|
// Record the error as a soft error; the scanner drives the final
|
||||||
|
// EndReason. HasErrors() flags the failure for logging/handling.
|
||||||
|
sr.Error(fmt.Errorf("%s", extractOpenAIImageStreamErrorMessage(raw)))
|
||||||
|
}
|
||||||
|
var usageResp dto.SimpleResponse
|
||||||
|
if err := common.Unmarshal(raw, &usageResp); err == nil {
|
||||||
|
normalizeOpenAIUsage(&usageResp.Usage)
|
||||||
|
if service.ValidUsage(&usageResp.Usage) {
|
||||||
|
usage = &usageResp.Usage
|
||||||
|
}
|
||||||
|
}
|
||||||
|
writeOpenaiImageStreamChunk(c, raw)
|
||||||
|
})
|
||||||
|
|
||||||
|
// StreamScannerHandler consumes the upstream [DONE]; re-emit it so the
|
||||||
|
// client still receives a terminal data: [DONE].
|
||||||
|
if info != nil && info.StreamStatus != nil && info.StreamStatus.EndReason == relaycommon.StreamEndReasonDone {
|
||||||
|
helper.Done(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
applyUsagePostProcessing(info, usage, lastStreamData)
|
||||||
|
return usage, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeOpenaiImageStreamChunk rebuilds the SSE frame for an image stream chunk:
|
||||||
|
// it emits an "event:" line derived from the JSON "type" field (when present)
|
||||||
|
// followed by the verbatim "data:" payload, mirroring helper.ResponseChunkData.
|
||||||
|
func writeOpenaiImageStreamChunk(c *gin.Context, data []byte) {
|
||||||
|
var payload struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
}
|
||||||
|
_ = common.Unmarshal(data, &payload)
|
||||||
|
if eventName := strings.TrimSpace(payload.Type); eventName != "" {
|
||||||
|
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", eventName)})
|
||||||
|
}
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: " + string(data)})
|
||||||
|
_ = helper.FlushWriter(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// isOpenAIImageStreamErrorEvent detects upstream error chunks by JSON content
|
||||||
|
// only ("type" of error/upstream_error, or a non-empty "error" field). The SSE
|
||||||
|
// "event:" line is not available here: StreamScannerHandler delivers only the
|
||||||
|
// "data:" payload. A payload carrying just a "message" key is deliberately NOT
|
||||||
|
// treated as an error to avoid false positives.
|
||||||
|
func isOpenAIImageStreamErrorEvent(data []byte) bool {
|
||||||
|
if !json.Valid(data) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
var payload struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Error json.RawMessage `json:"error"`
|
||||||
|
}
|
||||||
|
if err := common.Unmarshal(data, &payload); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
payloadType := strings.ToLower(strings.TrimSpace(payload.Type))
|
||||||
|
return payloadType == "error" || payloadType == "upstream_error" || len(payload.Error) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractOpenAIImageStreamErrorMessage(data []byte) string {
|
||||||
|
if len(data) == 0 || !json.Valid(data) {
|
||||||
|
return "upstream image stream returned error event"
|
||||||
|
}
|
||||||
|
var payload struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
Error json.RawMessage `json:"error"`
|
||||||
|
}
|
||||||
|
if err := common.Unmarshal(data, &payload); err != nil {
|
||||||
|
return "upstream image stream returned error event"
|
||||||
|
}
|
||||||
|
if msg := strings.TrimSpace(payload.Message); msg != "" {
|
||||||
|
return msg
|
||||||
|
}
|
||||||
|
if len(payload.Error) > 0 {
|
||||||
|
var nested struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
if err := common.Unmarshal(payload.Error, &nested); err == nil {
|
||||||
|
if msg := strings.TrimSpace(nested.Message); msg != "" {
|
||||||
|
return msg
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if msg := strings.TrimSpace(common.JsonRawMessageToString(payload.Error)); msg != "" {
|
||||||
|
return msg
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "upstream image stream returned error event"
|
||||||
|
}
|
||||||
|
|
||||||
|
func OpenaiImageJSONAsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||||
|
defer service.CloseResponseBodyGracefully(resp)
|
||||||
|
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
var imageResp dto.ImageResponse
|
||||||
|
if err := common.Unmarshal(responseBody, &imageResp); err != nil {
|
||||||
|
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
var usageResp dto.SimpleResponse
|
||||||
|
_ = common.Unmarshal(responseBody, &usageResp)
|
||||||
|
if oaiError := usageResp.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
|
||||||
|
return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
|
||||||
|
}
|
||||||
|
normalizeOpenAIUsage(&usageResp.Usage)
|
||||||
|
applyUsagePostProcessing(info, &usageResp.Usage, responseBody)
|
||||||
|
|
||||||
|
helper.SetEventStreamHeaders(c)
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
|
||||||
|
created := imageResp.Created
|
||||||
|
if created == 0 {
|
||||||
|
created = time.Now().Unix()
|
||||||
|
}
|
||||||
|
if info != nil {
|
||||||
|
info.SetFirstResponseTime()
|
||||||
|
}
|
||||||
|
for _, image := range imageResp.Data {
|
||||||
|
payload := map[string]any{
|
||||||
|
"type": "image_generation.completed",
|
||||||
|
"created_at": created,
|
||||||
|
}
|
||||||
|
if image.Url != "" {
|
||||||
|
payload["url"] = image.Url
|
||||||
|
}
|
||||||
|
if image.B64Json != "" {
|
||||||
|
payload["b64_json"] = image.B64Json
|
||||||
|
}
|
||||||
|
if image.RevisedPrompt != "" {
|
||||||
|
payload["revised_prompt"] = image.RevisedPrompt
|
||||||
|
}
|
||||||
|
if service.ValidUsage(&usageResp.Usage) {
|
||||||
|
payload["usage"] = usageResp.Usage
|
||||||
|
}
|
||||||
|
if err := writeOpenaiImageStreamPayload(c, "image_generation.completed", payload); err != nil {
|
||||||
|
if info != nil && info.StreamStatus != nil {
|
||||||
|
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, err)
|
||||||
|
}
|
||||||
|
return &usageResp.Usage, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := writeOpenaiImageStreamDone(c); err != nil {
|
||||||
|
if info != nil && info.StreamStatus != nil {
|
||||||
|
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, err)
|
||||||
|
}
|
||||||
|
return &usageResp.Usage, nil
|
||||||
|
}
|
||||||
|
if info != nil {
|
||||||
|
info.ReceivedResponseCount += len(imageResp.Data)
|
||||||
|
if info.StreamStatus == nil {
|
||||||
|
info.StreamStatus = relaycommon.NewStreamStatus()
|
||||||
|
}
|
||||||
|
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonDone, nil)
|
||||||
|
}
|
||||||
|
return &usageResp.Usage, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeOpenaiImageStreamPayload(c *gin.Context, eventName string, payload any) error {
|
||||||
|
data, err := common.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if eventName != "" {
|
||||||
|
if _, err := fmt.Fprintf(c.Writer, "event: %s\n", eventName); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", data); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return helper.FlushWriter(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeOpenaiImageStreamDone(c *gin.Context) error {
|
||||||
|
if _, err := fmt.Fprint(c.Writer, "data: [DONE]\n\n"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return helper.FlushWriter(c)
|
||||||
|
}
|
||||||
@@ -0,0 +1,242 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/QuantumNous/new-api/common"
|
||||||
|
"github.com/QuantumNous/new-api/dto"
|
||||||
|
"github.com/QuantumNous/new-api/logger"
|
||||||
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||||
|
"github.com/QuantumNous/new-api/relay/helper"
|
||||||
|
"github.com/QuantumNous/new-api/service"
|
||||||
|
"github.com/QuantumNous/new-api/types"
|
||||||
|
|
||||||
|
"github.com/bytedance/gopkg/util/gopool"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.RealtimeUsage) {
|
||||||
|
if info == nil || info.ClientWs == nil || info.TargetWs == nil {
|
||||||
|
return types.NewError(fmt.Errorf("invalid websocket connection"), types.ErrorCodeBadResponse), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
info.IsStream = true
|
||||||
|
clientConn := info.ClientWs
|
||||||
|
targetConn := info.TargetWs
|
||||||
|
|
||||||
|
clientClosed := make(chan struct{})
|
||||||
|
targetClosed := make(chan struct{})
|
||||||
|
sendChan := make(chan []byte, 100)
|
||||||
|
receiveChan := make(chan []byte, 100)
|
||||||
|
errChan := make(chan error, 2)
|
||||||
|
|
||||||
|
usage := &dto.RealtimeUsage{}
|
||||||
|
localUsage := &dto.RealtimeUsage{}
|
||||||
|
sumUsage := &dto.RealtimeUsage{}
|
||||||
|
|
||||||
|
gopool.Go(func() {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
errChan <- fmt.Errorf("panic in client reader: %v", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
_, message, err := clientConn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
||||||
|
errChan <- fmt.Errorf("error reading from client: %v", err)
|
||||||
|
}
|
||||||
|
close(clientClosed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
realtimeEvent := &dto.RealtimeEvent{}
|
||||||
|
err = common.Unmarshal(message, realtimeEvent)
|
||||||
|
if err != nil {
|
||||||
|
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdate {
|
||||||
|
if realtimeEvent.Session != nil {
|
||||||
|
if realtimeEvent.Session.Tools != nil {
|
||||||
|
info.RealtimeTools = realtimeEvent.Session.Tools
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
|
||||||
|
if err != nil {
|
||||||
|
errChan <- fmt.Errorf("error counting text token: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
|
||||||
|
localUsage.TotalTokens += textToken + audioToken
|
||||||
|
localUsage.InputTokens += textToken + audioToken
|
||||||
|
localUsage.InputTokenDetails.TextTokens += textToken
|
||||||
|
localUsage.InputTokenDetails.AudioTokens += audioToken
|
||||||
|
|
||||||
|
err = helper.WssString(c, targetConn, string(message))
|
||||||
|
if err != nil {
|
||||||
|
errChan <- fmt.Errorf("error writing to target: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case sendChan <- message:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
gopool.Go(func() {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
errChan <- fmt.Errorf("panic in target reader: %v", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
_, message, err := targetConn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
||||||
|
errChan <- fmt.Errorf("error reading from target: %v", err)
|
||||||
|
}
|
||||||
|
close(targetClosed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
info.SetFirstResponseTime()
|
||||||
|
realtimeEvent := &dto.RealtimeEvent{}
|
||||||
|
err = common.Unmarshal(message, realtimeEvent)
|
||||||
|
if err != nil {
|
||||||
|
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone {
|
||||||
|
realtimeUsage := realtimeEvent.Response.Usage
|
||||||
|
if realtimeUsage != nil {
|
||||||
|
usage.TotalTokens += realtimeUsage.TotalTokens
|
||||||
|
usage.InputTokens += realtimeUsage.InputTokens
|
||||||
|
usage.OutputTokens += realtimeUsage.OutputTokens
|
||||||
|
usage.InputTokenDetails.AudioTokens += realtimeUsage.InputTokenDetails.AudioTokens
|
||||||
|
usage.InputTokenDetails.CachedTokens += realtimeUsage.InputTokenDetails.CachedTokens
|
||||||
|
usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens
|
||||||
|
usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens
|
||||||
|
usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens
|
||||||
|
err := preConsumeUsage(c, info, usage, sumUsage)
|
||||||
|
if err != nil {
|
||||||
|
errChan <- fmt.Errorf("error consume usage: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 本次计费完成,清除
|
||||||
|
usage = &dto.RealtimeUsage{}
|
||||||
|
|
||||||
|
localUsage = &dto.RealtimeUsage{}
|
||||||
|
} else {
|
||||||
|
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
|
||||||
|
if err != nil {
|
||||||
|
errChan <- fmt.Errorf("error counting text token: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
|
||||||
|
localUsage.TotalTokens += textToken + audioToken
|
||||||
|
info.IsFirstRequest = false
|
||||||
|
localUsage.InputTokens += textToken + audioToken
|
||||||
|
localUsage.InputTokenDetails.TextTokens += textToken
|
||||||
|
localUsage.InputTokenDetails.AudioTokens += audioToken
|
||||||
|
err = preConsumeUsage(c, info, localUsage, sumUsage)
|
||||||
|
if err != nil {
|
||||||
|
errChan <- fmt.Errorf("error consume usage: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 本次计费完成,清除
|
||||||
|
localUsage = &dto.RealtimeUsage{}
|
||||||
|
// print now usage
|
||||||
|
}
|
||||||
|
logger.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
|
||||||
|
logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
|
||||||
|
logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
|
||||||
|
|
||||||
|
} else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
|
||||||
|
realtimeSession := realtimeEvent.Session
|
||||||
|
if realtimeSession != nil {
|
||||||
|
// update audio format
|
||||||
|
info.InputAudioFormat = common.GetStringIfEmpty(realtimeSession.InputAudioFormat, info.InputAudioFormat)
|
||||||
|
info.OutputAudioFormat = common.GetStringIfEmpty(realtimeSession.OutputAudioFormat, info.OutputAudioFormat)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
|
||||||
|
if err != nil {
|
||||||
|
errChan <- fmt.Errorf("error counting text token: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
|
||||||
|
localUsage.TotalTokens += textToken + audioToken
|
||||||
|
localUsage.OutputTokens += textToken + audioToken
|
||||||
|
localUsage.OutputTokenDetails.TextTokens += textToken
|
||||||
|
localUsage.OutputTokenDetails.AudioTokens += audioToken
|
||||||
|
}
|
||||||
|
|
||||||
|
err = helper.WssString(c, clientConn, string(message))
|
||||||
|
if err != nil {
|
||||||
|
errChan <- fmt.Errorf("error writing to client: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case receiveChan <- message:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-clientClosed:
|
||||||
|
case <-targetClosed:
|
||||||
|
case err := <-errChan:
|
||||||
|
//return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
|
||||||
|
logger.LogError(c, "realtime error: "+err.Error())
|
||||||
|
case <-c.Done():
|
||||||
|
}
|
||||||
|
|
||||||
|
if usage.TotalTokens != 0 {
|
||||||
|
_ = preConsumeUsage(c, info, usage, sumUsage)
|
||||||
|
}
|
||||||
|
|
||||||
|
if localUsage.TotalTokens != 0 {
|
||||||
|
_ = preConsumeUsage(c, info, localUsage, sumUsage)
|
||||||
|
}
|
||||||
|
|
||||||
|
// check usage total tokens, if 0, use local usage
|
||||||
|
|
||||||
|
return nil, sumUsage
|
||||||
|
}
|
||||||
|
|
||||||
|
func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.RealtimeUsage, totalUsage *dto.RealtimeUsage) error {
|
||||||
|
if usage == nil || totalUsage == nil {
|
||||||
|
return fmt.Errorf("invalid usage pointer")
|
||||||
|
}
|
||||||
|
|
||||||
|
totalUsage.TotalTokens += usage.TotalTokens
|
||||||
|
totalUsage.InputTokens += usage.InputTokens
|
||||||
|
totalUsage.OutputTokens += usage.OutputTokens
|
||||||
|
totalUsage.InputTokenDetails.CachedTokens += usage.InputTokenDetails.CachedTokens
|
||||||
|
totalUsage.InputTokenDetails.TextTokens += usage.InputTokenDetails.TextTokens
|
||||||
|
totalUsage.InputTokenDetails.AudioTokens += usage.InputTokenDetails.AudioTokens
|
||||||
|
totalUsage.OutputTokenDetails.TextTokens += usage.OutputTokenDetails.TextTokens
|
||||||
|
totalUsage.OutputTokenDetails.AudioTokens += usage.OutputTokenDetails.AudioTokens
|
||||||
|
// clear usage
|
||||||
|
err := service.PreWssConsumeQuota(ctx, info, usage)
|
||||||
|
return err
|
||||||
|
}
|
||||||
@@ -0,0 +1,133 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/QuantumNous/new-api/common"
|
||||||
|
"github.com/QuantumNous/new-api/constant"
|
||||||
|
"github.com/QuantumNous/new-api/dto"
|
||||||
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
func applyUsagePostProcessing(info *relaycommon.RelayInfo, usage *dto.Usage, responseBody []byte) {
|
||||||
|
if info == nil || usage == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch info.ChannelType {
|
||||||
|
case constant.ChannelTypeDeepSeek:
|
||||||
|
if usage.PromptTokensDetails.CachedTokens == 0 && usage.PromptCacheHitTokens != 0 {
|
||||||
|
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
|
||||||
|
}
|
||||||
|
case constant.ChannelTypeZhipu_v4:
|
||||||
|
// 智普的cached_tokens在标准位置: usage.prompt_tokens_details.cached_tokens
|
||||||
|
if usage.PromptTokensDetails.CachedTokens == 0 {
|
||||||
|
if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 0 {
|
||||||
|
usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens
|
||||||
|
} else if cachedTokens, ok := extractCachedTokensFromBody(responseBody); ok {
|
||||||
|
usage.PromptTokensDetails.CachedTokens = cachedTokens
|
||||||
|
} else if usage.PromptCacheHitTokens > 0 {
|
||||||
|
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case constant.ChannelTypeMoonshot:
|
||||||
|
// Moonshot的cached_tokens在非标准位置: choices[].usage.cached_tokens
|
||||||
|
if usage.PromptTokensDetails.CachedTokens == 0 {
|
||||||
|
if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 0 {
|
||||||
|
usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens
|
||||||
|
} else if cachedTokens, ok := extractMoonshotCachedTokensFromBody(responseBody); ok {
|
||||||
|
usage.PromptTokensDetails.CachedTokens = cachedTokens
|
||||||
|
} else if cachedTokens, ok := extractCachedTokensFromBody(responseBody); ok {
|
||||||
|
usage.PromptTokensDetails.CachedTokens = cachedTokens
|
||||||
|
} else if usage.PromptCacheHitTokens > 0 {
|
||||||
|
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case constant.ChannelTypeOpenAI:
|
||||||
|
if usage.PromptTokensDetails.CachedTokens == 0 {
|
||||||
|
if cachedTokens, ok := extractLlamaCachedTokensFromBody(responseBody); ok {
|
||||||
|
usage.PromptTokensDetails.CachedTokens = cachedTokens
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractCachedTokensFromBody(body []byte) (int, bool) {
|
||||||
|
if len(body) == 0 {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
var payload struct {
|
||||||
|
Usage struct {
|
||||||
|
PromptTokensDetails struct {
|
||||||
|
CachedTokens *int `json:"cached_tokens"`
|
||||||
|
} `json:"prompt_tokens_details"`
|
||||||
|
CachedTokens *int `json:"cached_tokens"`
|
||||||
|
PromptCacheHitTokens *int `json:"prompt_cache_hit_tokens"`
|
||||||
|
} `json:"usage"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := common.Unmarshal(body, &payload); err != nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
if payload.Usage.PromptTokensDetails.CachedTokens != nil {
|
||||||
|
return *payload.Usage.PromptTokensDetails.CachedTokens, true
|
||||||
|
}
|
||||||
|
if payload.Usage.CachedTokens != nil {
|
||||||
|
return *payload.Usage.CachedTokens, true
|
||||||
|
}
|
||||||
|
if payload.Usage.PromptCacheHitTokens != nil {
|
||||||
|
return *payload.Usage.PromptCacheHitTokens, true
|
||||||
|
}
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractMoonshotCachedTokensFromBody 从Moonshot的非标准位置提取cached_tokens
|
||||||
|
// Moonshot的流式响应格式: {"choices":[{"usage":{"cached_tokens":111}}]}
|
||||||
|
func extractMoonshotCachedTokensFromBody(body []byte) (int, bool) {
|
||||||
|
if len(body) == 0 {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
var payload struct {
|
||||||
|
Choices []struct {
|
||||||
|
Usage struct {
|
||||||
|
CachedTokens *int `json:"cached_tokens"`
|
||||||
|
} `json:"usage"`
|
||||||
|
} `json:"choices"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := common.Unmarshal(body, &payload); err != nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 遍历choices查找cached_tokens
|
||||||
|
for _, choice := range payload.Choices {
|
||||||
|
if choice.Usage.CachedTokens != nil && *choice.Usage.CachedTokens > 0 {
|
||||||
|
return *choice.Usage.CachedTokens, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractLlamaCachedTokensFromBody 从llama.cpp的非标准位置提取cache_n
|
||||||
|
func extractLlamaCachedTokensFromBody(body []byte) (int, bool) {
|
||||||
|
if len(body) == 0 {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
var payload struct {
|
||||||
|
Timings struct {
|
||||||
|
CachedTokens *int `json:"cache_n"`
|
||||||
|
} `json:"timings"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := common.Unmarshal(body, &payload); err != nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
if payload.Timings.CachedTokens == nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
return *payload.Timings.CachedTokens, true
|
||||||
|
}
|
||||||
@@ -114,7 +114,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||||
switch info.RelayMode {
|
switch info.RelayMode {
|
||||||
case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
|
case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
|
||||||
usage, err = openai.OpenaiHandlerWithUsage(c, info, resp)
|
usage, err = openai.OpenaiImageHandler(c, info, resp)
|
||||||
case constant.RelayModeResponses:
|
case constant.RelayModeResponses:
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
usage, err = openai.OaiResponsesStreamHandler(c, info, resp)
|
usage, err = openai.OaiResponsesStreamHandler(c, info, resp)
|
||||||
|
|||||||
@@ -15,31 +15,40 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestGetAndValidOpenAIImageRequestMultipartStream verifies reusable image edit parsing.
|
// TestGetAndValidOpenAIImageRequestMultipartStream verifies multipart image
|
||||||
|
// edit parsing: the stream field is parsed and validated, and the request body
|
||||||
|
// stays replayable for the upstream request.
|
||||||
func TestGetAndValidOpenAIImageRequestMultipartStream(t *testing.T) {
|
func TestGetAndValidOpenAIImageRequestMultipartStream(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
newContext := func(t *testing.T, streamValue string, withImage bool) (*gin.Context, string) {
|
||||||
var body bytes.Buffer
|
var body bytes.Buffer
|
||||||
writer := multipart.NewWriter(&body)
|
writer := multipart.NewWriter(&body)
|
||||||
require.NoError(t, writer.WriteField("model", "gpt-image-1"))
|
require.NoError(t, writer.WriteField("model", "gpt-image-1"))
|
||||||
require.NoError(t, writer.WriteField("prompt", "edit this image"))
|
require.NoError(t, writer.WriteField("prompt", "edit this image"))
|
||||||
require.NoError(t, writer.WriteField("stream", "true"))
|
require.NoError(t, writer.WriteField("stream", streamValue))
|
||||||
require.NoError(t, writer.WriteField("n", "1"))
|
if withImage {
|
||||||
part, err := writer.CreateFormFile("image", "input.png")
|
part, err := writer.CreateFormFile("image", "input.png")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = part.Write([]byte("fake image"))
|
_, err = part.Write([]byte("fake image"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
require.NoError(t, writer.Close())
|
require.NoError(t, writer.Close())
|
||||||
originalBody := body.String()
|
originalBody := body.String()
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||||
c, _ := gin.CreateTestContext(recorder)
|
|
||||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/edits", &body)
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/edits", &body)
|
||||||
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
|
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
|
||||||
|
return c, originalBody
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("valid stream value keeps body replayable", func(t *testing.T) {
|
||||||
|
c, originalBody := newContext(t, "true", true)
|
||||||
|
|
||||||
req, err := GetAndValidOpenAIImageRequest(c, relayconstant.RelayModeImagesEdits)
|
req, err := GetAndValidOpenAIImageRequest(c, relayconstant.RelayModeImagesEdits)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.True(t, req.Stream)
|
require.NotNil(t, req.Stream)
|
||||||
|
require.True(t, *req.Stream)
|
||||||
require.True(t, req.IsStream(c))
|
require.True(t, req.IsStream(c))
|
||||||
|
|
||||||
bodyAfterValidation, err := io.ReadAll(c.Request.Body)
|
bodyAfterValidation, err := io.ReadAll(c.Request.Body)
|
||||||
@@ -50,24 +59,13 @@ func TestGetAndValidOpenAIImageRequestMultipartStream(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, "true", url.Values(form.Value).Get("stream"))
|
require.Equal(t, "true", url.Values(form.Value).Get("stream"))
|
||||||
require.Len(t, form.File["image"], 1)
|
require.Len(t, form.File["image"], 1)
|
||||||
}
|
})
|
||||||
|
|
||||||
// TestGetAndValidOpenAIImageRequestMultipartStreamInvalidValue verifies stream validation.
|
t.Run("invalid stream value is rejected", func(t *testing.T) {
|
||||||
func TestGetAndValidOpenAIImageRequestMultipartStreamInvalidValue(t *testing.T) {
|
c, _ := newContext(t, "notabool", false)
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
|
|
||||||
var body bytes.Buffer
|
|
||||||
writer := multipart.NewWriter(&body)
|
|
||||||
require.NoError(t, writer.WriteField("model", "gpt-image-1"))
|
|
||||||
require.NoError(t, writer.WriteField("stream", "notabool"))
|
|
||||||
require.NoError(t, writer.Close())
|
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
|
||||||
c, _ := gin.CreateTestContext(recorder)
|
|
||||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/edits", &body)
|
|
||||||
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
|
|
||||||
|
|
||||||
_, err := GetAndValidOpenAIImageRequest(c, relayconstant.RelayModeImagesEdits)
|
_, err := GetAndValidOpenAIImageRequest(c, relayconstant.RelayModeImagesEdits)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), "invalid stream value")
|
require.Contains(t, err.Error(), "invalid stream value")
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
InitialScannerBufferSize = 64 << 10 // 64KB (64*1024)
|
InitialScannerBufferSize = 64 << 10 // 64KB (64*1024)
|
||||||
DefaultMaxScannerBufferSize = 64 << 20 // 64MB (64*1024*1024) default SSE buffer size
|
DefaultMaxScannerBufferSize = 128 << 20 // 128MB (128*1024*1024) default SSE buffer size
|
||||||
DefaultPingInterval = 10 * time.Second
|
DefaultPingInterval = 10 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -163,7 +163,7 @@ func GetAndValidOpenAIImageRequest(c *gin.Context, relayMode int) (*dto.ImageReq
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid stream value: %w", err)
|
return nil, fmt.Errorf("invalid stream value: %w", err)
|
||||||
}
|
}
|
||||||
imageRequest.Stream = stream
|
imageRequest.Stream = common.GetPointer(stream)
|
||||||
}
|
}
|
||||||
if imageValue := formData.Get("image"); imageValue != "" {
|
if imageValue := formData.Get("image"); imageValue != "" {
|
||||||
imageRequest.Image, _ = common.Marshal(imageValue)
|
imageRequest.Image, _ = common.Marshal(imageValue)
|
||||||
|
|||||||
Reference in New Issue
Block a user