fix: add keepalive and timeout to OpenaiImageStreamHandler to prevent client_gone on long image generation
Docker Build / Build and Push Docker Image (push) Successful in 4m1s
Docker Build / Build and Push Docker Image (push) Successful in 4m1s
This commit is contained in:
@@ -10,6 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
@@ -123,80 +124,128 @@ func OpenaiImageStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
|
||||
|
||||
usage := &dto.Usage{}
|
||||
var lastStreamData []byte
|
||||
|
||||
// 使用 channel 异步读取上游 SSE 行,避免 bufio.Scanner 阻塞导致无法发送 keepalive
|
||||
type scanResult struct {
|
||||
line string
|
||||
err error
|
||||
}
|
||||
lineCh := make(chan scanResult, 64)
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(make([]byte, helper.InitialScannerBufferSize), helper.DefaultMaxScannerBufferSize)
|
||||
|
||||
for scanner.Scan() {
|
||||
go func() {
|
||||
defer close(lineCh)
|
||||
for scanner.Scan() {
|
||||
lineCh <- scanResult{line: scanner.Text()}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
lineCh <- scanResult{err: err}
|
||||
}
|
||||
}()
|
||||
|
||||
// keepalive 定时器:在上游长时间无数据时向客户端发送心跳,防止反向代理超时
|
||||
keepaliveInterval := 15 * time.Second
|
||||
keepaliveTicker := time.NewTicker(keepaliveInterval)
|
||||
defer keepaliveTicker.Stop()
|
||||
|
||||
// 流式超时:如果超过 StreamingTimeout 秒未收到任何上游数据,终止请求
|
||||
streamingTimeout := time.Duration(constant.StreamingTimeout) * time.Second
|
||||
if streamingTimeout <= 0 {
|
||||
streamingTimeout = 300 * time.Second
|
||||
}
|
||||
timeoutTimer := time.NewTimer(streamingTimeout)
|
||||
defer timeoutTimer.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
if info != nil && info.StreamStatus != nil {
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, c.Request.Context().Err())
|
||||
}
|
||||
return usage, nil
|
||||
default:
|
||||
}
|
||||
|
||||
line := scanner.Text()
|
||||
if len(line) < 5 || line[:5] != "data:" {
|
||||
continue
|
||||
}
|
||||
case <-keepaliveTicker.C:
|
||||
// 上游长时间无数据时发送 keepalive 心跳,防止 Nginx 等反向代理超时关闭连接
|
||||
fmt.Fprint(c.Writer, ": keepalive\n\n")
|
||||
_ = helper.FlushWriter(c)
|
||||
|
||||
data := strings.TrimSpace(line[5:])
|
||||
if data == "" {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(data, "[DONE]") {
|
||||
case <-timeoutTimer.C:
|
||||
// 超过 StreamingTimeout 未收到上游数据,终止请求
|
||||
if info != nil && info.StreamStatus != nil {
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonDone, nil)
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonHandlerStop, fmt.Errorf("streaming timeout after %d seconds", int64(streamingTimeout.Seconds())))
|
||||
}
|
||||
if err := writeOpenaiImageStreamDone(c); err != nil {
|
||||
return usage, types.NewOpenAIError(fmt.Errorf("streaming timeout after %d seconds", int64(streamingTimeout.Seconds())), types.ErrorCodeReadResponseBodyFailed, http.StatusGatewayTimeout)
|
||||
|
||||
case result, ok := <-lineCh:
|
||||
if !ok {
|
||||
// 上游关闭连接
|
||||
if info != nil {
|
||||
if info.StreamStatus == nil {
|
||||
info.StreamStatus = relaycommon.NewStreamStatus()
|
||||
}
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonEOF, nil)
|
||||
}
|
||||
applyUsagePostProcessing(info, usage, lastStreamData)
|
||||
return usage, nil
|
||||
}
|
||||
if result.err != nil {
|
||||
if info != nil && info.StreamStatus != nil {
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, err)
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonScannerErr, result.err)
|
||||
}
|
||||
return usage, types.NewOpenAIError(result.err, types.ErrorCodeReadResponseBodyFailed, http.StatusBadGateway)
|
||||
}
|
||||
|
||||
// 收到上游数据,重置 keepalive 和超时计时器
|
||||
keepaliveTicker.Reset(keepaliveInterval)
|
||||
timeoutTimer.Reset(streamingTimeout)
|
||||
|
||||
line := result.line
|
||||
if len(line) < 5 || line[:5] != "data:" {
|
||||
continue
|
||||
}
|
||||
|
||||
data := strings.TrimSpace(line[5:])
|
||||
if data == "" {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(data, "[DONE]") {
|
||||
if info != nil && info.StreamStatus != nil {
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonDone, nil)
|
||||
}
|
||||
if err := writeOpenaiImageStreamDone(c); err != nil {
|
||||
if info != nil && info.StreamStatus != nil {
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, err)
|
||||
}
|
||||
}
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
raw := common.StringToByteSlice(data)
|
||||
lastStreamData = raw
|
||||
if info != nil {
|
||||
info.SetFirstResponseTime()
|
||||
info.ReceivedResponseCount++
|
||||
}
|
||||
if isOpenAIImageStreamErrorEvent(raw) {
|
||||
message := extractOpenAIImageStreamErrorMessage(raw)
|
||||
if info != nil && info.StreamStatus != nil {
|
||||
info.StreamStatus.RecordError(message)
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonHandlerStop, fmt.Errorf("%s", message))
|
||||
}
|
||||
return usage, types.NewOpenAIError(fmt.Errorf("%s", message), types.ErrorCodeBadResponseBody, http.StatusBadGateway)
|
||||
}
|
||||
|
||||
var usageResp dto.SimpleResponse
|
||||
if err := common.Unmarshal(raw, &usageResp); err == nil {
|
||||
normalizeOpenAIUsage(&usageResp.Usage)
|
||||
if service.ValidUsage(&usageResp.Usage) {
|
||||
usage = &usageResp.Usage
|
||||
}
|
||||
}
|
||||
return usage, nil
|
||||
writeOpenaiImageStreamChunk(c, raw)
|
||||
}
|
||||
|
||||
raw := common.StringToByteSlice(data)
|
||||
lastStreamData = raw
|
||||
if info != nil {
|
||||
info.SetFirstResponseTime()
|
||||
info.ReceivedResponseCount++
|
||||
}
|
||||
if isOpenAIImageStreamErrorEvent(raw) {
|
||||
message := extractOpenAIImageStreamErrorMessage(raw)
|
||||
if info != nil && info.StreamStatus != nil {
|
||||
info.StreamStatus.RecordError(message)
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonHandlerStop, fmt.Errorf("%s", message))
|
||||
}
|
||||
return usage, types.NewOpenAIError(fmt.Errorf("%s", message), types.ErrorCodeBadResponseBody, http.StatusBadGateway)
|
||||
}
|
||||
|
||||
var usageResp dto.SimpleResponse
|
||||
if err := common.Unmarshal(raw, &usageResp); err == nil {
|
||||
normalizeOpenAIUsage(&usageResp.Usage)
|
||||
if service.ValidUsage(&usageResp.Usage) {
|
||||
usage = &usageResp.Usage
|
||||
}
|
||||
}
|
||||
writeOpenaiImageStreamChunk(c, raw)
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
if info != nil && info.StreamStatus != nil {
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonScannerErr, err)
|
||||
}
|
||||
return usage, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusBadGateway)
|
||||
}
|
||||
|
||||
if info != nil {
|
||||
if info.StreamStatus == nil {
|
||||
info.StreamStatus = relaycommon.NewStreamStatus()
|
||||
}
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonEOF, nil)
|
||||
}
|
||||
applyUsagePostProcessing(info, usage, lastStreamData)
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
// writeOpenaiImageStreamChunk rebuilds the SSE frame for an image stream chunk:
|
||||
|
||||
Reference in New Issue
Block a user