Files
proxy-platform/internal/scheduler/scheduler.go
T

351 lines
7.7 KiB
Go

package scheduler
import (
"context"
"errors"
"math/rand"
"sync"
"time"
"proxy-platform/internal/models"
"go.uber.org/zap"
)
// ErrNoAvailableNode is returned by Selector.Select when the online node
// list is empty or when filtering leaves no candidate node.
var (
ErrNoAvailableNode = errors.New("没有可用的节点")
)
// Strategy names a load-balancing strategy.
type Strategy string
// Strategies recognised by Selector.Select.
const (
// StrategyLeastLatency dispatches to selectLeastLatency. Select also falls
// back to it for any unrecognised Strategy value.
StrategyLeastLatency Strategy = "least_latency"
// StrategyLeastConnections dispatches to selectLeastConnections.
StrategyLeastConnections Strategy = "least_connections"
// StrategyWeightedRoundRobin dispatches to selectWeightedRoundRobin.
StrategyWeightedRoundRobin Strategy = "weighted_round_robin"
// StrategyRandom dispatches to selectRandom.
StrategyRandom Strategy = "random"
)
// Selector is the node selector: it picks one node out of the online nodes
// according to its configured Strategy.
type Selector struct {
// repo supplies the list of online nodes.
repo NodeRepository
// cache supplies per-node stats used for filtering and ranking.
cache NodeCache
logger *zap.Logger
// strategy is the dispatch key used by Select.
strategy Strategy
// randPool holds *rand.Rand instances used by selectRandom.
randPool *sync.Pool
// mu is locked by selectWeightedRoundRobin while it reads and updates rrIndex.
mu sync.RWMutex
// rrIndex is the round-robin cursor advanced by selectWeightedRoundRobin.
rrIndex int
}
// NodeRepository is the node data-access interface.
type NodeRepository interface {
// ListOnline returns the online nodes; Select uses it as the candidate list.
ListOnline() ([]models.Node, error)
// FindByNodeID is not called in the visible part of this file; presumably it
// looks a node up by its NodeID — confirm against the implementation.
FindByNodeID(nodeID string) (*models.Node, error)
// UpdateConnections is not called in the visible part of this file;
// presumably it persists a node's connection count — confirm against the
// implementation.
UpdateConnections(nodeID string, connections int) error
}
// NodeCache is the node stats cache interface.
type NodeCache interface {
// GetStats returns the cached stats for nodeID; callers in this file treat a
// false second result as "no cached data for this node".
GetStats(nodeID string) (*models.NodeStats, bool)
// SetStats is not called in the visible part of this file; presumably it
// stores stats for nodeID — confirm against the implementation.
SetStats(nodeID string, stats *models.NodeStats)
// GetAllStats is not called in the visible part of this file; presumably the
// map is keyed by node ID — confirm against the implementation.
GetAllStats() map[string]*models.NodeStats
}
// NewSelector builds a Selector that reads nodes from repo, stats from cache,
// and picks nodes according to strategy.
//
// The selector owns a pool of *rand.Rand generators; each generator is seeded
// from the wall clock (nanoseconds) at the moment the pool creates it.
func NewSelector(repo NodeRepository, cache NodeCache, strategy Strategy, logger *zap.Logger) *Selector {
	newGenerator := func() interface{} {
		seed := time.Now().UnixNano()
		return rand.New(rand.NewSource(seed))
	}

	sel := &Selector{randPool: &sync.Pool{New: newGenerator}}
	sel.repo = repo
	sel.cache = cache
	sel.logger = logger
	sel.strategy = strategy
	return sel
}
// Select picks the best node for a request.
//
// Pipeline:
//  1. load the online nodes from the repository;
//  2. if requiredServices is non-empty, keep only nodes that unlock all of them;
//  3. keep only nodes below their connection limit;
//  4. pick one node using the configured strategy (unknown strategies fall
//     back to least-latency).
//
// It returns the repository error unchanged, or ErrNoAvailableNode when any
// stage leaves no candidates. ctx, targetHost and targetPort are currently
// not consulted.
func (s *Selector) Select(ctx context.Context, targetHost string, targetPort int, requiredServices []string) (*models.Node, error) {
	candidates, err := s.repo.ListOnline()
	if err != nil {
		return nil, err
	}
	if len(candidates) == 0 {
		return nil, ErrNoAvailableNode
	}

	// Unlock filtering is skipped entirely when nothing is required.
	if len(requiredServices) > 0 {
		if candidates = s.filterByUnlockStatus(candidates, requiredServices); len(candidates) == 0 {
			return nil, ErrNoAvailableNode
		}
	}

	if candidates = s.filterByConnectionLimit(candidates); len(candidates) == 0 {
		return nil, ErrNoAvailableNode
	}

	switch s.strategy {
	case StrategyLeastConnections:
		return s.selectLeastConnections(candidates), nil
	case StrategyWeightedRoundRobin:
		return s.selectWeightedRoundRobin(candidates), nil
	case StrategyRandom:
		return s.selectRandom(candidates), nil
	default:
		// Covers StrategyLeastLatency and any unrecognised value.
		return s.selectLeastLatency(candidates), nil
	}
}
// filterByUnlockStatus returns the nodes that unlock every service in
// services. A service counts as unlocked only if the node reports it with
// Unlocked == true; when a node lists the same service more than once, the
// last entry wins. The result is nil when no node qualifies.
func (s *Selector) filterByUnlockStatus(nodes []models.Node, services []string) []models.Node {
	var matched []models.Node

nextNode:
	for idx := range nodes {
		candidate := nodes[idx]

		// Index this node's unlock results by service name.
		unlocked := make(map[string]bool)
		for _, entry := range candidate.UnlockStatuses {
			unlocked[entry.Service] = entry.Unlocked
		}

		// A single missing or locked service disqualifies the node.
		for _, required := range services {
			if !unlocked[required] {
				continue nextNode
			}
		}
		matched = append(matched, candidate)
	}
	return matched
}
// filterByConnectionLimit returns the nodes whose connection count is
// strictly below their MaxConnections. The count is taken from the cached
// stats when the cache has an entry for the node, and from the node's own
// CurrentConnections field otherwise.
func (s *Selector) filterByConnectionLimit(nodes []models.Node) []models.Node {
	var available []models.Node
	for idx := range nodes {
		candidate := nodes[idx]

		// Cached stats take precedence over the value stored on the node.
		active := candidate.CurrentConnections
		if cached, found := s.cache.GetStats(candidate.NodeID); found {
			active = cached.Connections
		}

		if active < candidate.MaxConnections {
			available = append(available, candidate)
		}
	}
	return available
}
// selectLeastLatency returns the node with the lowest estimated latency.
//
// Latency is a simplification: it is derived from the cached CPU usage
// (CPUUsage * 100), not from a real measurement.
//
// Nodes with cached stats always take precedence over nodes without. Only
// when no node has cached stats does the highest Weight win among the
// uncached nodes. Ties keep the earliest node in the slice. The previous
// implementation mixed both criteria in a single variable, which made the
// result depend on the order of nodes.
//
// It returns nil when nodes is empty. The returned pointer aliases an element
// of nodes.
func (s *Selector) selectLeastLatency(nodes []models.Node) *models.Node {
	var (
		bestMeasured   *models.Node // lowest latency among nodes with stats
		bestUnmeasured *models.Node // highest weight among nodes without stats
		minLatency     time.Duration
	)

	for i := range nodes {
		node := &nodes[i]

		stats, ok := s.cache.GetStats(node.NodeID)
		if !ok {
			// No cached data: rank by weight, used only as a fallback.
			if bestUnmeasured == nil || node.Weight > bestUnmeasured.Weight {
				bestUnmeasured = node
			}
			continue
		}

		latency := time.Duration(stats.CPUUsage * 100)
		if bestMeasured == nil || latency < minLatency {
			bestMeasured, minLatency = node, latency
		}
	}

	if bestMeasured != nil {
		return bestMeasured
	}
	return bestUnmeasured
}
// selectLeastConnections returns the node with the fewest connections. The
// count comes from the cached stats when available and from the node's
// CurrentConnections field otherwise. Ties keep the earliest node. The
// returned pointer aliases an element of nodes; it is nil when nodes is empty.
func (s *Selector) selectLeastConnections(nodes []models.Node) *models.Node {
	const maxInt = int(^uint(0) >> 1)

	var best *models.Node
	fewest := maxInt
	for idx := range nodes {
		candidate := &nodes[idx]

		load := candidate.CurrentConnections
		if cached, found := s.cache.GetStats(candidate.NodeID); found {
			load = cached.Connections
		}

		if load >= fewest {
			continue
		}
		fewest, best = load, candidate
	}
	return best
}
// selectWeightedRoundRobin picks a node by weighted round robin.
//
// A shared cursor (rrIndex) is advanced by one on every call and mapped onto
// the cumulative weights of nodes, so over a full cycle each node is chosen
// in proportion to its Weight. Negative weights are treated as zero. When no
// node has a positive weight the method degrades to plain round robin over
// nodes instead of always returning the first one.
//
// It returns nil when nodes is empty. The returned pointer aliases an element
// of nodes. The cursor is protected by s.mu, so concurrent calls are safe.
func (s *Selector) selectWeightedRoundRobin(nodes []models.Node) *models.Node {
	if len(nodes) == 0 {
		return nil
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	// Sum only positive weights so a negative weight cannot shrink the cycle.
	totalWeight := 0
	for i := range nodes {
		if w := nodes[i].Weight; w > 0 {
			totalWeight += w
		}
	}

	if totalWeight == 0 {
		// No usable weights: rotate evenly rather than pinning nodes[0].
		s.rrIndex = (s.rrIndex + 1) % len(nodes)
		return &nodes[s.rrIndex]
	}

	s.rrIndex = (s.rrIndex + 1) % totalWeight

	// Walk the cumulative weights until the cursor falls inside a node's span.
	cumulative := 0
	for i := range nodes {
		if w := nodes[i].Weight; w > 0 {
			cumulative += w
		}
		if s.rrIndex < cumulative {
			return &nodes[i]
		}
	}

	// Unreachable while rrIndex < totalWeight; kept as a defensive fallback.
	return &nodes[0]
}
// selectRandom returns a uniformly random element of nodes, using a
// generator borrowed from the selector's pool for the duration of the call.
// nodes must be non-empty (rand.Intn panics on a zero bound).
func (s *Selector) selectRandom(nodes []models.Node) *models.Node {
	rng := s.randPool.Get().(*rand.Rand)
	defer s.randPool.Put(rng)

	return &nodes[rng.Intn(len(nodes))]
}
// HealthChecker is the node health checker. Its Check method in this file
// decides health from the cached stats only; repo and logger are stored but
// not used by the visible code.
type HealthChecker struct {
repo NodeRepository
// cache supplies the per-node stats inspected by Check.
cache NodeCache
logger *zap.Logger
}
// NewHealthChecker returns a HealthChecker wired to the given repository,
// stats cache and logger.
func NewHealthChecker(repo NodeRepository, cache NodeCache, logger *zap.Logger) *HealthChecker {
	checker := new(HealthChecker)
	checker.repo = repo
	checker.cache = cache
	checker.logger = logger
	return checker
}
// Check reports whether node is healthy according to its cached stats.
//
// It returns an error when the node has no cached stats, when the stats are
// older than 30 seconds (missed heartbeat), or when the cached connection
// count is negative (inconsistent state). It returns nil otherwise. ctx is
// currently not consulted.
func (h *HealthChecker) Check(ctx context.Context, node *models.Node) error {
	const heartbeatTimeout = 30 * time.Second

	stats, found := h.cache.GetStats(node.NodeID)
	switch {
	case !found:
		// The node has never reported (or its entry is gone).
		return errors.New("节点未上报状态")
	case time.Since(stats.LastUpdate) > heartbeatTimeout:
		// The last report is too old.
		return errors.New("节点心跳超时")
	case stats.Connections < 0:
		// A negative connection count cannot be valid.
		return errors.New("节点状态异常")
	}
	return nil
}
// RuleEngine is the rule engine: it holds a list of IP-refresh rules and a
// registry of named actions, and matches incoming events against the rules.
type RuleEngine struct {
// rules is appended to by AddRule and scanned by Evaluate.
rules []IPRefreshRule
// actions maps an action name to its callback; filled by RegisterAction.
actions map[string]ActionFunc
// mu guards rules and actions.
mu sync.RWMutex
logger *zap.Logger
}
// IPRefreshRule describes one rule held by a RuleEngine. Evaluate in this
// file reads only ID, TriggerType and Enabled; NodeGroupID, TriggerValue and
// Cooldown are stored but not consulted by the visible code.
type IPRefreshRule struct {
// ID is logged as "rule_id" when the rule fires.
ID uint
NodeGroupID *uint
// TriggerType is compared for equality with the event name given to Evaluate.
TriggerType string // unlock_failure, usage_count, usage_traffic, scheduled, anomaly
TriggerValue map[string]interface{}
Cooldown time.Duration
// Enabled must be true for Evaluate to consider the rule.
Enabled bool
}
// ActionFunc is the callback signature for actions registered on a
// RuleEngine. Evaluate passes the triggering event name as reason.
type ActionFunc func(ctx context.Context, node *models.Node, reason string) error
// NewRuleEngine returns a RuleEngine with an empty (non-nil) rule list and an
// empty action registry, logging through logger.
func NewRuleEngine(logger *zap.Logger) *RuleEngine {
	engine := new(RuleEngine)
	engine.rules = []IPRefreshRule{}
	engine.actions = map[string]ActionFunc{}
	engine.logger = logger
	return engine
}
// RegisterAction registers action under name, replacing any action already
// registered with that name. It is safe for concurrent use.
//
// The registry is allocated on first use, so registering on a RuleEngine that
// was not built by NewRuleEngine no longer panics on a nil map.
func (e *RuleEngine) RegisterAction(name string, action ActionFunc) {
	e.mu.Lock()
	defer e.mu.Unlock()

	if e.actions == nil {
		e.actions = make(map[string]ActionFunc)
	}
	e.actions[name] = action
}
// AddRule appends rule to the engine's rule list while holding the write
// lock, so it is safe for concurrent use.
func (e *RuleEngine) AddRule(rule IPRefreshRule) {
	e.mu.Lock()
	e.rules = append(e.rules, rule)
	e.mu.Unlock()
}
// Evaluate matches event against the enabled rules and runs the "refresh_ip"
// action for node.
//
// A rule matches when it is enabled and its TriggerType equals event. Every
// match is logged. If a "refresh_ip" action is registered, it is invoked once
// for the first matching rule and its result is returned; later rules are not
// evaluated. If no such action is registered, all matching rules are logged
// and nil is returned. data is currently not consulted.
//
// The rules and the action are snapshotted under the read lock, and the lock
// is released before logging and before the action runs. A slow action
// therefore does not block AddRule/RegisterAction, and an action may call
// them itself without deadlocking.
func (e *RuleEngine) Evaluate(ctx context.Context, node *models.Node, event string, data map[string]interface{}) error {
	e.mu.RLock()
	action, hasAction := e.actions["refresh_ip"]
	var matched []IPRefreshRule
	for _, rule := range e.rules {
		if rule.Enabled && rule.TriggerType == event {
			matched = append(matched, rule)
		}
	}
	e.mu.RUnlock()

	for _, rule := range matched {
		e.logger.Info("规则触发",
			zap.Uint("rule_id", rule.ID),
			zap.String("trigger", event),
			zap.String("node_id", node.NodeID),
		)

		// The first matching rule triggers the action and ends evaluation.
		if hasAction {
			return action(ctx, node, event)
		}
	}
	return nil
}