Files
AI-CS/backend/websocket/client.go
T
2026-04-02 14:55:06 +08:00

183 lines
4.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package websocket
import (
"encoding/json"
"log"
"strings"
"time"
"github.com/gorilla/websocket"
)
type inboundWSMessage struct {
Type string `json:"type"`
Data map[string]interface{} `json:"data"`
}
const (
// 客户端发送 ping 的最大等待时间
writeWait = 10 * time.Second
// 从客户端读取 pong 的最大等待时间
pongWait = 60 * time.Second
// 发送 ping 的频率(必须小于 pongWait
pingPeriod = (pongWait * 9) / 10
// 最大消息大小
maxMessageSize = 512 * 1024 // 512KB
)
// Client 是一个 WebSocket 客户端
type Client struct {
hub *Hub
// WebSocket 连接
conn *websocket.Conn
// 发送消息的通道
send chan *Message
// 对话ID(这个客户端属于哪个对话)
conversationID uint
// 是否是访客(true 表示访客,false 表示客服)
isVisitor bool
// 客服ID(如果是客服连接,存储客服的用户ID)
agentID uint
}
// NewClient 创建一个新的客户端
func NewClient(hub *Hub, conn *websocket.Conn, conversationID uint, isVisitor bool, agentID uint) *Client {
return &Client{
hub: hub,
conn: conn,
send: make(chan *Message, 256),
conversationID: conversationID,
isVisitor: isVisitor,
agentID: agentID,
}
}
// ReadPump 从 WebSocket 连接读取消息
func (c *Client) ReadPump() {
defer func() {
c.hub.unregister <- c
c.conn.Close()
}()
// 设置读取限制和超时
c.conn.SetReadDeadline(time.Now().Add(pongWait))
c.conn.SetReadLimit(maxMessageSize)
c.conn.SetPongHandler(func(string) error {
c.conn.SetReadDeadline(time.Now().Add(pongWait))
return nil
})
// 持续读取消息
for {
_, payload, err := c.conn.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
log.Printf("⚠️ WebSocket 读取错误: 对话ID=%d, 错误=%v", c.conversationID, err)
}
break
}
c.handleIncoming(payload)
}
}
// WritePump 向 WebSocket 连接写入消息
func (c *Client) WritePump() {
ticker := time.NewTicker(pingPeriod)
defer func() {
ticker.Stop()
c.conn.Close()
}()
for {
select {
case message, ok := <-c.send:
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
if !ok {
// Hub 关闭了通道
c.conn.WriteMessage(websocket.CloseMessage, []byte{})
return
}
// 发送消息
if err := c.conn.WriteJSON(message); err != nil {
log.Printf("❌ WebSocket 写入错误: 对话ID=%d, 类型=%s, 错误=%v",
c.conversationID, message.Type, err)
return
}
case <-ticker.C:
// 定期发送 ping 保持连接
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
log.Printf("❌ 发送 ping 失败: 对话ID=%d, 错误=%v", c.conversationID, err)
return
}
}
}
}
// SendMessage 发送消息给客户端(用于测试)
func (c *Client) SendMessage(messageType string, data interface{}) error {
message := &Message{
ConversationID: c.conversationID,
Type: messageType,
Data: data,
}
messageJSON, err := json.Marshal(message)
if err != nil {
return err
}
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
return c.conn.WriteMessage(websocket.TextMessage, messageJSON)
}
func (c *Client) handleIncoming(payload []byte) {
var in inboundWSMessage
if err := json.Unmarshal(payload, &in); err != nil {
return
}
switch in.Type {
case "typing_draft":
text, _ := in.Data["text"].(string)
text = strings.TrimSpace(text)
if text == "" {
c.hub.BroadcastMessage(c.conversationID, "typing_stop", map[string]interface{}{
"sender_id": c.agentID,
"sender_is_agent": !c.isVisitor,
})
return
}
// 控制草稿长度,避免超长输入导致 WS 事件过大。
if len(text) > 300 {
text = text[:300]
}
out := map[string]interface{}{
"sender_id": c.agentID,
"sender_is_agent": !c.isVisitor,
"text": text,
}
if seq, ok := in.Data["seq"]; ok {
out["seq"] = seq
}
c.hub.BroadcastMessage(c.conversationID, "typing_draft", out)
case "typing_stop":
c.hub.BroadcastMessage(c.conversationID, "typing_stop", map[string]interface{}{
"sender_id": c.agentID,
"sender_is_agent": !c.isVisitor,
})
default:
// 忽略未知客户端事件,避免污染服务端日志。
}
}