Files
proxy-platform/internal/socks5/socks5.go
T

445 lines
10 KiB
Go

package socks5
import (
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"strconv"
"sync"
"time"
"go.uber.org/zap"
)
// Sentinel errors for SOCKS5 protocol failures; compare them with errors.Is.
var (
	// ErrUnsupportedVersion is returned when a peer's VER byte is not SOCKS5Version.
	ErrUnsupportedVersion = errors.New("unsupported SOCKS version")

	// ErrUnsupportedMethod signals a failed authentication-method negotiation.
	ErrUnsupportedMethod = errors.New("unsupported authentication method")

	// ErrAuthenticationFailed is returned when username/password authentication is rejected.
	ErrAuthenticationFailed = errors.New("authentication failed")

	// ErrUnsupportedCommand is returned for any request command other than CONNECT.
	ErrUnsupportedCommand = errors.New("unsupported command")

	// ErrUnsupportedAddrType is returned for an unknown ATYP value.
	ErrUnsupportedAddrType = errors.New("unsupported address type")
)
// Wire-level constants from RFC 1928 (SOCKS5) used by the server and dialer.
const (
	// SOCKS5Version is the VER byte carried by every SOCKS5 message.
	SOCKS5Version = 0x05

	// Authentication methods (METHOD field of the method-selection exchange).
	NoAuth       = 0x00
	UserPassAuth = 0x02
	NoAcceptable = 0xFF

	// Request commands (CMD field).
	ConnectCommand   = 0x01
	BindCommand      = 0x02
	AssociateCommand = 0x03

	// Address types (ATYP field).
	IPv4Address = 0x01
	FQDNAddress = 0x03
	IPv6Address = 0x04
)
// Authenticator checks username/password credentials presented by a client.
//
// Authenticate returns the numeric user ID that belongs to the credentials and
// true when they are accepted. When the boolean is false, the ID is ignored.
type Authenticator interface {
	Authenticate(username string, password string) (uint, bool)
}
// BackendSelector chooses the upstream node that a CONNECT request is forwarded
// to, and is told when the server is done with a node it handed out.
//
// SelectBackend returns the host and port to dial for targetHost:targetPort.
// The server in this file passes services as nil; its meaning is up to the
// implementation. ReleaseBackend receives the selected host/port together with
// the bytes relayed in each direction.
type BackendSelector interface {
	SelectBackend(ctx context.Context, targetHost string, targetPort int, services []string) (string, int, error)
	ReleaseBackend(host string, port int, bytesIn int64, bytesOut int64)
}
// Server is a SOCKS5 proxy front end. It accepts client connections,
// authenticates them, and forwards CONNECT requests to a backend chosen by a
// BackendSelector.
type Server struct {
	// Listen address, admission limit and I/O timeout.
	host           string
	port           int
	maxConnections int
	timeout        time.Duration

	// Collaborators supplied to NewServer.
	auth     Authenticator
	selector BackendSelector
	logger   *zap.Logger

	// connections counts the client connections currently being handled. It
	// is read and written while holding connMutex.
	connections int64
	connMutex   sync.Mutex

	// connSem has capacity maxConnections. Start puts a token in before
	// spawning a handler, and the handler takes it out when it exits.
	connSem chan struct{}
}
// NewServer returns a Server configured to listen on host:port. It does not
// start listening; call Start for that.
//
// maxConnections is the capacity of the admission semaphore. timeout is given
// in whole seconds and converted to a time.Duration. auth, selector and logger
// are stored as-is.
func NewServer(host string, port int, maxConnections int, timeout int, auth Authenticator, selector BackendSelector, logger *zap.Logger) *Server {
	srv := new(Server)
	srv.host, srv.port = host, port
	srv.maxConnections = maxConnections
	srv.timeout = time.Second * time.Duration(timeout)
	srv.auth, srv.selector, srv.logger = auth, selector, logger
	srv.connSem = make(chan struct{}, maxConnections)
	return srv
}
// Start listens on s.host:s.port and serves SOCKS5 clients until ctx is
// cancelled or the listener is closed.
//
// Each accepted connection takes a slot in s.connSem and is handled in its own
// goroutine. When all maxConnections slots are busy, the new connection is
// closed immediately.
//
// Start returns:
//   - ctx.Err() after cancellation;
//   - a wrapped error if listening fails;
//   - the listener error if the listener is closed from elsewhere.
func (s *Server) Start(ctx context.Context) error {
	// JoinHostPort brackets IPv6 literals, which "%s:%d" would not.
	addr := net.JoinHostPort(s.host, strconv.Itoa(s.port))
	listener, err := net.Listen("tcp", addr)
	if err != nil {
		return fmt.Errorf("监听失败: %w", err)
	}
	defer listener.Close()

	// Accept blocks and does not observe ctx, so close the listener on
	// cancellation to unblock it. stop ends this watcher when Start returns.
	stop := make(chan struct{})
	defer close(stop)
	go func() {
		select {
		case <-ctx.Done():
			listener.Close()
		case <-stop:
		}
	}()

	s.logger.Info("SOCKS5 服务器启动", zap.String("addr", addr), zap.Int("max_connections", s.maxConnections))

	var backoff time.Duration
	for {
		conn, err := listener.Accept()
		if err != nil {
			if ctx.Err() != nil {
				return ctx.Err()
			}
			if errors.Is(err, net.ErrClosed) {
				// Retrying a closed listener would spin forever.
				return err
			}
			s.logger.Error("接受连接失败", zap.Error(err))
			// Transient failures (e.g. running out of file descriptors) tend to
			// repeat immediately; back off exponentially instead of busy-looping.
			backoff *= 2
			if backoff == 0 {
				backoff = 5 * time.Millisecond
			}
			if backoff > time.Second {
				backoff = time.Second
			}
			select {
			case <-ctx.Done():
				return ctx.Err()
			case <-time.After(backoff):
			}
			continue
		}
		backoff = 0

		// Admission control: a non-blocking send fails when every slot is taken.
		select {
		case s.connSem <- struct{}{}:
			go s.handleConnection(ctx, conn)
		default:
			// Read the counter under its mutex; reading s.connections directly races
			// with the handler goroutines.
			s.logger.Warn("达到最大连接数限制", zap.Int64("connections", s.GetConnections()))
			conn.Close()
		}
	}
}
// handleConnection serves a single accepted client connection. It runs the
// SOCKS5 negotiation and parses the CONNECT request. It then selects and dials
// a backend and relays traffic between the two.
//
// It always closes clientConn, returns the connSem slot taken in Start, and
// keeps s.connections in sync. Every backend obtained from SelectBackend is
// handed back through ReleaseBackend, including when the dial fails.
func (s *Server) handleConnection(ctx context.Context, clientConn net.Conn) {
	defer func() {
		clientConn.Close()
		<-s.connSem
		s.connMutex.Lock()
		s.connections--
		s.connMutex.Unlock()
	}()
	s.connMutex.Lock()
	s.connections++
	s.connMutex.Unlock()

	// Bound the negotiation phase so a silent client cannot hold a slot forever.
	// A non-positive timeout means "no deadline" rather than one that has
	// already expired (time.Now().Add(0) would fail every read).
	if s.timeout > 0 {
		_ = clientConn.SetDeadline(time.Now().Add(s.timeout))
	}

	// 1. Method negotiation / authentication.
	username, userID, err := s.handshake(clientConn)
	if err != nil {
		s.logger.Error("握手失败", zap.Error(err), zap.String("client", clientConn.RemoteAddr().String()))
		return
	}

	// 2. CONNECT request.
	targetHost, targetPort, err := s.readRequest(clientConn)
	if err != nil {
		s.logger.Error("读取请求失败", zap.Error(err))
		// RFC 1928 §6: tell the client why instead of silently hanging up.
		switch {
		case errors.Is(err, ErrUnsupportedCommand):
			s.sendReply(clientConn, 0x07, net.IPv4zero, 0) // Command not supported
		case errors.Is(err, ErrUnsupportedAddrType):
			s.sendReply(clientConn, 0x08, net.IPv4zero, 0) // Address type not supported
		}
		return
	}
	s.logger.Info("连接请求",
		zap.String("username", username),
		zap.Uint("user_id", userID),
		zap.String("target", net.JoinHostPort(targetHost, strconv.Itoa(targetPort))),
	)

	// 3. Backend selection.
	backendHost, backendPort, err := s.selector.SelectBackend(ctx, targetHost, targetPort, nil)
	if err != nil {
		s.logger.Error("选择节点失败", zap.Error(err))
		s.sendReply(clientConn, 0x04, net.IPv4zero, 0) // Host unreachable
		return
	}

	// 4. Dial the backend. JoinHostPort handles IPv6 backends, and DialContext
	// honours server shutdown as well as s.timeout (zero means no dial timeout,
	// as with net.DialTimeout).
	backendAddr := net.JoinHostPort(backendHost, strconv.Itoa(backendPort))
	dialer := net.Dialer{Timeout: s.timeout}
	backendConn, err := dialer.DialContext(ctx, "tcp", backendAddr)
	if err != nil {
		s.logger.Error("连接后端失败", zap.Error(err), zap.String("backend", backendAddr))
		s.sendReply(clientConn, 0x04, net.IPv4zero, 0)
		// The selector handed this backend out, so give it back even though no
		// traffic flowed; otherwise its bookkeeping leaks on every failed dial.
		s.selector.ReleaseBackend(backendHost, backendPort, 0, 0)
		return
	}
	defer backendConn.Close()

	// 5. Success reply.
	s.sendReply(clientConn, 0x00, net.IPv4zero, 0)

	// The negotiation deadline must not cap the lifetime of the tunnel itself.
	// Clear it so long-lived transfers are not cut off after s.timeout.
	_ = clientConn.SetDeadline(time.Time{})

	// Tear the tunnel down if the server is shutting down. done stops this
	// watcher once the relay has finished normally.
	done := make(chan struct{})
	defer close(done)
	go func() {
		select {
		case <-ctx.Done():
			clientConn.Close()
			backendConn.Close()
		case <-done:
		}
	}()

	// 6. Relay traffic in both directions.
	bytesIn, bytesOut := s.relay(clientConn, backendConn)

	// 7. Return the backend to the selector with the traffic counters.
	s.selector.ReleaseBackend(backendHost, backendPort, bytesIn, bytesOut)
	s.logger.Info("连接结束",
		zap.String("username", username),
		zap.Int64("bytes_in", bytesIn),
		zap.Int64("bytes_out", bytesOut),
	)
}
// handshake performs the SOCKS5 method negotiation (RFC 1928 §3). When
// username/password is selected, it also runs that sub-negotiation (RFC 1929 §2).
// It returns the authenticated username and the user ID reported by s.auth.
//
// Method selection:
//   - With an Authenticator configured, the client must offer
//     username/password (0x02). Offering only "no authentication" is rejected,
//     so it can no longer bypass auth.
//   - Without an Authenticator (s.auth == nil), "no authentication" (0x00) is
//     selected when offered, and ("", 0, nil) is returned.
//   - Otherwise 0xFF (NoAcceptable) is sent and ErrUnsupportedMethod returned.
func (s *Server) handshake(conn net.Conn) (string, uint, error) {
	// Greeting: VER | NMETHODS | METHODS...
	var hello [2]byte
	if _, err := io.ReadFull(conn, hello[:]); err != nil {
		return "", 0, err
	}
	if hello[0] != SOCKS5Version {
		return "", 0, ErrUnsupportedVersion
	}
	methods := make([]byte, int(hello[1]))
	if _, err := io.ReadFull(conn, methods); err != nil {
		return "", 0, err
	}
	offers := func(method byte) bool {
		for _, m := range methods {
			if m == method {
				return true
			}
		}
		return false
	}

	switch {
	case s.auth != nil && offers(UserPassAuth):
		if _, err := conn.Write([]byte{SOCKS5Version, UserPassAuth}); err != nil {
			return "", 0, err
		}
	case s.auth == nil && offers(NoAuth):
		if _, err := conn.Write([]byte{SOCKS5Version, NoAuth}); err != nil {
			return "", 0, err
		}
		return "", 0, nil // unauthenticated session
	default:
		// Best effort: the connection is dropped right after this.
		_, _ = conn.Write([]byte{SOCKS5Version, NoAcceptable})
		return "", 0, ErrUnsupportedMethod
	}

	// RFC 1929 request: VER(0x01) | ULEN | UNAME | PLEN | PASSWD.
	// PLEN follows the username, so it must be read after UNAME.
	var hdr [2]byte
	if _, err := io.ReadFull(conn, hdr[:]); err != nil {
		return "", 0, err
	}
	if hdr[0] != 0x01 {
		return "", 0, fmt.Errorf("auth sub-negotiation version %#x: %w", hdr[0], ErrUnsupportedVersion)
	}
	usernameBuf := make([]byte, int(hdr[1]))
	if _, err := io.ReadFull(conn, usernameBuf); err != nil {
		return "", 0, err
	}
	var plen [1]byte
	if _, err := io.ReadFull(conn, plen[:]); err != nil {
		return "", 0, err
	}
	passwordBuf := make([]byte, int(plen[0]))
	if _, err := io.ReadFull(conn, passwordBuf); err != nil {
		return "", 0, err
	}

	username := string(usernameBuf)
	userID, ok := s.auth.Authenticate(username, string(passwordBuf))
	if !ok {
		_, _ = conn.Write([]byte{0x01, 0x01}) // failure status; connection is closed by the caller
		return "", 0, ErrAuthenticationFailed
	}
	if _, err := conn.Write([]byte{0x01, 0x00}); err != nil { // success status
		return "", 0, err
	}
	return username, userID, nil
}
// readRequest parses the client's SOCKS5 request (RFC 1928 §4) and returns the
// destination host and port. Only CONNECT is accepted.
//
// IPv4/IPv6 destinations are returned in net.IP string form. Domain names are
// returned verbatim.
func (s *Server) readRequest(conn net.Conn) (string, int, error) {
	// Fixed header: VER | CMD | RSV | ATYP
	var header [4]byte
	if _, err := io.ReadFull(conn, header[:]); err != nil {
		return "", 0, err
	}
	switch {
	case header[0] != SOCKS5Version:
		return "", 0, ErrUnsupportedVersion
	case header[1] != ConnectCommand:
		return "", 0, ErrUnsupportedCommand
	}

	// DST.ADDR, whose encoding depends on ATYP.
	var host string
	switch atyp := header[3]; atyp {
	case IPv4Address, IPv6Address:
		size := net.IPv4len
		if atyp == IPv6Address {
			size = net.IPv6len
		}
		raw := make([]byte, size)
		if _, err := io.ReadFull(conn, raw); err != nil {
			return "", 0, err
		}
		host = net.IP(raw).String()
	case FQDNAddress:
		var nameLen [1]byte
		if _, err := io.ReadFull(conn, nameLen[:]); err != nil {
			return "", 0, err
		}
		name := make([]byte, int(nameLen[0]))
		if _, err := io.ReadFull(conn, name); err != nil {
			return "", 0, err
		}
		host = string(name)
	default:
		return "", 0, ErrUnsupportedAddrType
	}

	// DST.PORT in network byte order.
	var portBytes [2]byte
	if _, err := io.ReadFull(conn, portBytes[:]); err != nil {
		return "", 0, err
	}
	return host, int(binary.BigEndian.Uint16(portBytes[:])), nil
}
// sendReply writes a SOCKS5 reply (RFC 1928 §6): VER | REP | RSV | ATYP |
// BND.ADDR | BND.PORT.
//
// IPv4 addresses (including IPv4-mapped IPv6) are encoded with ATYP 0x01, and
// other 16-byte addresses with ATYP 0x04. A nil or invalid ip falls back to
// 0.0.0.0, so the reply is always well-formed. Previously a non-IPv4 ip
// produced a reply with no address bytes at all.
func (s *Server) sendReply(conn net.Conn, status byte, ip net.IP, port int) {
	atyp := byte(IPv4Address)
	addr := ip.To4()
	if addr == nil {
		if ip16 := ip.To16(); ip16 != nil {
			atyp, addr = IPv6Address, ip16
		} else {
			addr = net.IPv4zero.To4()
		}
	}

	reply := make([]byte, 0, 4+len(addr)+2)
	reply = append(reply, SOCKS5Version, status, 0x00 /* RSV */, atyp)
	reply = append(reply, addr...)
	reply = append(reply, byte(port>>8), byte(port))

	// Best effort: a failed write surfaces on the caller's next read/write or
	// close of conn, and there is nothing useful to do with it here.
	_, _ = conn.Write(reply)
}
// relay copies data in both directions between client and backend until both
// directions have finished. It returns (bytesIn, bytesOut): bytesIn is the
// number of bytes copied backend→client, and bytesOut the number copied
// client→backend.
//
// When one direction reaches EOF, the write side of its destination is shut
// down (TCP half-close), so the peer sees EOF too. When a copy fails, or the
// destination cannot half-close, both ends are closed. Either way, the
// opposite io.Copy cannot block forever after one peer has gone away.
func (s *Server) relay(client, backend net.Conn) (int64, int64) {
	var bytesIn, bytesOut int64
	var wg sync.WaitGroup
	wg.Add(2)

	// pipe copies src→dst, stores the byte count in *n, then propagates the end
	// of the stream to dst.
	pipe := func(dst, src net.Conn, n *int64) {
		defer wg.Done()
		written, err := io.Copy(dst, src)
		*n = written
		if hc, ok := dst.(interface{ CloseWrite() error }); ok && err == nil {
			_ = hc.CloseWrite()
			return
		}
		_ = dst.Close()
		_ = src.Close()
	}

	go pipe(backend, client, &bytesOut) // client -> backend
	go pipe(client, backend, &bytesIn)  // backend -> client

	// wg.Wait establishes happens-before for the counters written above.
	wg.Wait()
	return bytesIn, bytesOut
}
// GetConnections reports how many client connections are currently being
// handled. The counter is read while holding connMutex.
func (s *Server) GetConnections() int64 {
	s.connMutex.Lock()
	current := s.connections
	s.connMutex.Unlock()
	return current
}
// Dialer is a minimal SOCKS5 client. It tunnels TCP connections through the
// proxy at Host:Port, authenticating with Username/Password when the proxy
// asks for it.
type Dialer struct {
	// Proxy endpoint.
	Host string
	Port int

	// Credentials sent if the proxy selects username/password authentication.
	Username string
	Password string

	// Timeout applies when connecting to the proxy (net.DialTimeout semantics:
	// zero means no timeout).
	Timeout time.Duration
}
// NewDialer returns a Dialer for the SOCKS5 proxy at host:port. The Dialer
// uses the given credentials and connect timeout.
func NewDialer(host string, port int, username, password string, timeout time.Duration) *Dialer {
	d := new(Dialer)
	d.Host, d.Port = host, port
	d.Username, d.Password = username, password
	d.Timeout = timeout
	return d
}
// Dial opens a TCP connection to the SOCKS5 proxy at d.Host:d.Port and asks it
// to CONNECT to targetHost:targetPort. The target is always sent as a domain
// name (ATYP 0x03), even when it is an IP literal.
//
// Method negotiation offers "no authentication" and username/password
// (RFC 1929). The credentials are sent only if the proxy selects the latter.
// d.Timeout bounds both the TCP connect and the SOCKS negotiation (zero means
// no limit). The deadline is cleared before the connection is returned, and
// the full reply (including a variable-length BND.ADDR) has been consumed, so
// the stream starts at the tunnelled payload.
//
// On any error the proxy connection is closed. Protocol failures are reported
// as one of:
//   - ErrUnsupportedVersion;
//   - ErrUnsupportedMethod (the proxy accepted none of the offered methods);
//   - ErrAuthenticationFailed;
//   - ErrUnsupportedAddrType.
func (d *Dialer) Dial(targetHost string, targetPort int) (net.Conn, error) {
	// Length fields are single bytes: reject values that would be silently
	// truncated and produce a malformed request.
	if n := len(targetHost); n == 0 || n > 255 {
		return nil, fmt.Errorf("socks5: target host length %d out of range 1..255", n)
	}
	if targetPort < 0 || targetPort > 0xFFFF {
		return nil, fmt.Errorf("socks5: target port %d out of range", targetPort)
	}

	proxyAddr := net.JoinHostPort(d.Host, strconv.Itoa(d.Port))
	conn, err := net.DialTimeout("tcp", proxyAddr, d.Timeout)
	if err != nil {
		return nil, err
	}
	success := false
	defer func() {
		if !success {
			conn.Close()
		}
	}()

	// Without a deadline, a stalled proxy would block the reads below forever.
	if d.Timeout > 0 {
		if err := conn.SetDeadline(time.Now().Add(d.Timeout)); err != nil {
			return nil, err
		}
	}

	// Greeting: offer no-auth and username/password.
	if _, err := conn.Write([]byte{SOCKS5Version, 2, NoAuth, UserPassAuth}); err != nil {
		return nil, err
	}
	var choice [2]byte
	if _, err := io.ReadFull(conn, choice[:]); err != nil {
		return nil, err
	}
	if choice[0] != SOCKS5Version {
		return nil, ErrUnsupportedVersion
	}

	switch choice[1] {
	case NoAuth:
		// Nothing further to negotiate.
	case UserPassAuth:
		if len(d.Username) > 255 || len(d.Password) > 255 {
			return nil, errors.New("socks5: username and password must be at most 255 bytes")
		}
		// RFC 1929: VER(0x01) | ULEN | UNAME | PLEN | PASSWD
		auth := make([]byte, 0, 3+len(d.Username)+len(d.Password))
		auth = append(auth, 0x01, byte(len(d.Username)))
		auth = append(auth, d.Username...)
		auth = append(auth, byte(len(d.Password)))
		auth = append(auth, d.Password...)
		if _, err := conn.Write(auth); err != nil {
			return nil, err
		}
		var status [2]byte
		if _, err := io.ReadFull(conn, status[:]); err != nil {
			return nil, err
		}
		if status[1] != 0x00 {
			return nil, ErrAuthenticationFailed
		}
	default:
		// 0xFF (NoAcceptable) or a method we never offered.
		return nil, ErrUnsupportedMethod
	}

	// CONNECT request: VER | CMD | RSV | ATYP=FQDN | LEN | HOST | PORT
	req := make([]byte, 0, 7+len(targetHost))
	req = append(req, SOCKS5Version, ConnectCommand, 0x00, FQDNAddress, byte(len(targetHost)))
	req = append(req, targetHost...)
	req = append(req, byte(targetPort>>8), byte(targetPort))
	if _, err := conn.Write(req); err != nil {
		return nil, err
	}

	// Reply: VER | REP | RSV | ATYP, then BND.ADDR (length depends on ATYP) and
	// BND.PORT. All of it must be consumed, or leftover bytes would be handed
	// to the caller as tunnel data.
	var reply [4]byte
	if _, err := io.ReadFull(conn, reply[:]); err != nil {
		return nil, err
	}
	if reply[0] != SOCKS5Version {
		return nil, ErrUnsupportedVersion
	}
	if reply[1] != 0x00 {
		return nil, fmt.Errorf("SOCKS5 连接失败: status %d", reply[1])
	}
	var rest int
	switch reply[3] {
	case IPv4Address:
		rest = net.IPv4len + 2
	case IPv6Address:
		rest = net.IPv6len + 2
	case FQDNAddress:
		var n [1]byte
		if _, err := io.ReadFull(conn, n[:]); err != nil {
			return nil, err
		}
		rest = int(n[0]) + 2
	default:
		return nil, ErrUnsupportedAddrType
	}
	if _, err := io.ReadFull(conn, make([]byte, rest)); err != nil {
		return nil, err
	}

	// The negotiation deadline must not limit the caller's use of the tunnel.
	if d.Timeout > 0 {
		if err := conn.SetDeadline(time.Time{}); err != nil {
			return nil, err
		}
	}
	success = true
	return conn, nil
}