fix: add keepalive and timeout to OpenaiImageStreamHandler to prevent client_gone on long image generation
Docker Build / Build and Push Docker Image (push) Successful in 4m1s

This commit is contained in:
2026-06-22 09:50:29 +08:00
parent 15a3699d21
commit 9df4b577df
+105 -56
View File
@@ -10,6 +10,7 @@ import (
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/logger"
relaycommon "github.com/QuantumNous/new-api/relay/common"
@@ -123,80 +124,128 @@ func OpenaiImageStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
usage := &dto.Usage{}
var lastStreamData []byte
// 使用 channel 异步读取上游 SSE 行,避免 bufio.Scanner 阻塞导致无法发送 keepalive
type scanResult struct {
line string
err error
}
lineCh := make(chan scanResult, 64)
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, helper.InitialScannerBufferSize), helper.DefaultMaxScannerBufferSize)
for scanner.Scan() {
go func() {
defer close(lineCh)
for scanner.Scan() {
lineCh <- scanResult{line: scanner.Text()}
}
if err := scanner.Err(); err != nil {
lineCh <- scanResult{err: err}
}
}()
// keepalive 定时器:在上游长时间无数据时向客户端发送心跳,防止反向代理超时
keepaliveInterval := 15 * time.Second
keepaliveTicker := time.NewTicker(keepaliveInterval)
defer keepaliveTicker.Stop()
// 流式超时:如果超过 StreamingTimeout 秒未收到任何上游数据,终止请求
streamingTimeout := time.Duration(constant.StreamingTimeout) * time.Second
if streamingTimeout <= 0 {
streamingTimeout = 300 * time.Second
}
timeoutTimer := time.NewTimer(streamingTimeout)
defer timeoutTimer.Stop()
for {
select {
case <-c.Request.Context().Done():
if info != nil && info.StreamStatus != nil {
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, c.Request.Context().Err())
}
return usage, nil
default:
}
line := scanner.Text()
if len(line) < 5 || line[:5] != "data:" {
continue
}
case <-keepaliveTicker.C:
// 上游长时间无数据时发送 keepalive 心跳,防止 Nginx 等反向代理超时关闭连接
fmt.Fprint(c.Writer, ": keepalive\n\n")
_ = helper.FlushWriter(c)
data := strings.TrimSpace(line[5:])
if data == "" {
continue
}
if strings.HasPrefix(data, "[DONE]") {
case <-timeoutTimer.C:
// 超过 StreamingTimeout 未收到上游数据,终止请求
if info != nil && info.StreamStatus != nil {
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonDone, nil)
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonHandlerStop, fmt.Errorf("streaming timeout after %d seconds", int64(streamingTimeout.Seconds())))
}
if err := writeOpenaiImageStreamDone(c); err != nil {
return usage, types.NewOpenAIError(fmt.Errorf("streaming timeout after %d seconds", int64(streamingTimeout.Seconds())), types.ErrorCodeReadResponseBodyFailed, http.StatusGatewayTimeout)
case result, ok := <-lineCh:
if !ok {
// 上游关闭连接
if info != nil {
if info.StreamStatus == nil {
info.StreamStatus = relaycommon.NewStreamStatus()
}
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonEOF, nil)
}
applyUsagePostProcessing(info, usage, lastStreamData)
return usage, nil
}
if result.err != nil {
if info != nil && info.StreamStatus != nil {
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, err)
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonScannerErr, result.err)
}
return usage, types.NewOpenAIError(result.err, types.ErrorCodeReadResponseBodyFailed, http.StatusBadGateway)
}
// 收到上游数据,重置 keepalive 和超时计时器
keepaliveTicker.Reset(keepaliveInterval)
timeoutTimer.Reset(streamingTimeout)
line := result.line
if len(line) < 5 || line[:5] != "data:" {
continue
}
data := strings.TrimSpace(line[5:])
if data == "" {
continue
}
if strings.HasPrefix(data, "[DONE]") {
if info != nil && info.StreamStatus != nil {
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonDone, nil)
}
if err := writeOpenaiImageStreamDone(c); err != nil {
if info != nil && info.StreamStatus != nil {
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, err)
}
}
return usage, nil
}
raw := common.StringToByteSlice(data)
lastStreamData = raw
if info != nil {
info.SetFirstResponseTime()
info.ReceivedResponseCount++
}
if isOpenAIImageStreamErrorEvent(raw) {
message := extractOpenAIImageStreamErrorMessage(raw)
if info != nil && info.StreamStatus != nil {
info.StreamStatus.RecordError(message)
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonHandlerStop, fmt.Errorf("%s", message))
}
return usage, types.NewOpenAIError(fmt.Errorf("%s", message), types.ErrorCodeBadResponseBody, http.StatusBadGateway)
}
var usageResp dto.SimpleResponse
if err := common.Unmarshal(raw, &usageResp); err == nil {
normalizeOpenAIUsage(&usageResp.Usage)
if service.ValidUsage(&usageResp.Usage) {
usage = &usageResp.Usage
}
}
return usage, nil
writeOpenaiImageStreamChunk(c, raw)
}
raw := common.StringToByteSlice(data)
lastStreamData = raw
if info != nil {
info.SetFirstResponseTime()
info.ReceivedResponseCount++
}
if isOpenAIImageStreamErrorEvent(raw) {
message := extractOpenAIImageStreamErrorMessage(raw)
if info != nil && info.StreamStatus != nil {
info.StreamStatus.RecordError(message)
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonHandlerStop, fmt.Errorf("%s", message))
}
return usage, types.NewOpenAIError(fmt.Errorf("%s", message), types.ErrorCodeBadResponseBody, http.StatusBadGateway)
}
var usageResp dto.SimpleResponse
if err := common.Unmarshal(raw, &usageResp); err == nil {
normalizeOpenAIUsage(&usageResp.Usage)
if service.ValidUsage(&usageResp.Usage) {
usage = &usageResp.Usage
}
}
writeOpenaiImageStreamChunk(c, raw)
}
if err := scanner.Err(); err != nil {
if info != nil && info.StreamStatus != nil {
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonScannerErr, err)
}
return usage, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusBadGateway)
}
if info != nil {
if info.StreamStatus == nil {
info.StreamStatus = relaycommon.NewStreamStatus()
}
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonEOF, nil)
}
applyUsagePostProcessing(info, usage, lastStreamData)
return usage, nil
}
// writeOpenaiImageStreamChunk rebuilds the SSE frame for an image stream chunk: