Files
AI-CS/backend/websocket/handler.go
T
2026-04-01 13:40:42 +08:00

97 lines
2.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package websocket
import (
"log"
"net/http"
"strconv"
"github.com/2930134478/AI-CS/backend/repository"
"github.com/2930134478/AI-CS/backend/utils"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
)
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
// 允许跨域连接
CheckOrigin: func(r *http.Request) bool {
return true
},
}
// HandleWebSocket 处理 WebSocket 连接
func HandleWebSocket(hub *Hub, userRepo *repository.UserRepository) gin.HandlerFunc {
return func(c *gin.Context) {
// 从查询参数获取对话ID
conversationIDStr := c.Query("conversation_id")
if conversationIDStr == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "conversation_id 不能为空"})
return
}
conversationID, err := strconv.ParseUint(conversationIDStr, 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的 conversation_id"})
return
}
// 从查询参数获取是否是访客(默认为 true,因为默认是访客连接)
isVisitorStr := c.DefaultQuery("is_visitor", "true")
isVisitor := isVisitorStr == "true" || isVisitorStr == "1"
// 从查询参数获取客服ID(如果是客服连接,需要传递 agent_id)
var agentID uint
if !isVisitor {
agentIDStr := c.Query("agent_id")
if agentIDStr == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "agent_id 不能为空"})
return
}
parsed, parseErr := strconv.ParseUint(agentIDStr, 10, 32)
if parseErr != nil || parsed == 0 {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的 agent_id"})
return
}
agentID = uint(parsed)
wsToken := c.Query("ws_token")
if !utils.ValidateWSToken(wsToken, agentID) {
c.JSON(http.StatusUnauthorized, gin.H{"error": "ws_token 无效或已过期"})
return
}
if userRepo != nil {
user, userErr := userRepo.GetByID(agentID)
if userErr != nil || user == nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "客服身份无效"})
return
}
if user.Role != "admin" && user.Role != "agent" {
c.JSON(http.StatusForbidden, gin.H{"error": "仅客服账号允许建立该连接"})
return
}
}
}
// 升级 HTTP 连接为 WebSocket 连接
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
log.Printf("WebSocket 升级失败: %v", err)
return
}
// 创建客户端
client := NewClient(hub, conn, uint(conversationID), isVisitor, agentID)
// 注册客户端到 Hub
client.hub.register <- client
// 启动两个 goroutine
// 1. ReadPump:从客户端读取消息(主要是心跳包)
// 2. WritePump:向客户端发送消息
go client.WritePump()
go client.ReadPump()
log.Printf("✅ WebSocket 连接已建立: 对话ID=%d, 是访客=%v", conversationID, isVisitor)
}
}