351 lines
7.7 KiB
Go
351 lines
7.7 KiB
Go
package scheduler
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"math/rand"
|
|
"sync"
|
|
"time"
|
|
|
|
"proxy-platform/internal/models"
|
|
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// ErrNoAvailableNode is returned by Selector.Select when no online node
// survives filtering (unlock requirements and connection limits).
var ErrNoAvailableNode = errors.New("没有可用的节点")
|
|
|
|
// Strategy names the load-balancing strategy a Selector uses to pick a node.
type Strategy string

// Known strategies. Selector.Select treats any other value like
// StrategyLeastLatency.
const (
	// StrategyLeastLatency prefers the node with the lowest latency score.
	StrategyLeastLatency Strategy = "least_latency"
	// StrategyLeastConnections prefers the node with the fewest connections.
	StrategyLeastConnections Strategy = "least_connections"
	// StrategyWeightedRoundRobin cycles through nodes in proportion to Weight.
	StrategyWeightedRoundRobin Strategy = "weighted_round_robin"
	// StrategyRandom picks a node uniformly at random.
	StrategyRandom Strategy = "random"
)
|
|
|
|
// Selector 节点选择器
|
|
type Selector struct {
|
|
repo NodeRepository
|
|
cache NodeCache
|
|
logger *zap.Logger
|
|
strategy Strategy
|
|
randPool *sync.Pool
|
|
mu sync.RWMutex
|
|
rrIndex int
|
|
}
|
|
|
|
// NodeRepository 节点数据访问接口
|
|
type NodeRepository interface {
|
|
ListOnline() ([]models.Node, error)
|
|
FindByNodeID(nodeID string) (*models.Node, error)
|
|
UpdateConnections(nodeID string, connections int) error
|
|
}
|
|
|
|
// NodeCache 节点缓存接口
|
|
type NodeCache interface {
|
|
GetStats(nodeID string) (*models.NodeStats, bool)
|
|
SetStats(nodeID string, stats *models.NodeStats)
|
|
GetAllStats() map[string]*models.NodeStats
|
|
}
|
|
|
|
// NewSelector 创建节点选择器
|
|
func NewSelector(repo NodeRepository, cache NodeCache, strategy Strategy, logger *zap.Logger) *Selector {
|
|
return &Selector{
|
|
repo: repo,
|
|
cache: cache,
|
|
logger: logger,
|
|
strategy: strategy,
|
|
randPool: &sync.Pool{
|
|
New: func() interface{} {
|
|
return rand.New(rand.NewSource(time.Now().UnixNano()))
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
// Select 选择最优节点
|
|
func (s *Selector) Select(ctx context.Context, targetHost string, targetPort int, requiredServices []string) (*models.Node, error) {
|
|
// 1. 获取在线节点列表
|
|
nodes, err := s.repo.ListOnline()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(nodes) == 0 {
|
|
return nil, ErrNoAvailableNode
|
|
}
|
|
|
|
// 2. 过滤满足解锁需求的节点
|
|
if len(requiredServices) > 0 {
|
|
nodes = s.filterByUnlockStatus(nodes, requiredServices)
|
|
if len(nodes) == 0 {
|
|
return nil, ErrNoAvailableNode
|
|
}
|
|
}
|
|
|
|
// 3. 过滤未达到连接上限的节点
|
|
nodes = s.filterByConnectionLimit(nodes)
|
|
if len(nodes) == 0 {
|
|
return nil, ErrNoAvailableNode
|
|
}
|
|
|
|
// 4. 根据策略选择节点
|
|
var selected *models.Node
|
|
switch s.strategy {
|
|
case StrategyLeastLatency:
|
|
selected = s.selectLeastLatency(nodes)
|
|
case StrategyLeastConnections:
|
|
selected = s.selectLeastConnections(nodes)
|
|
case StrategyWeightedRoundRobin:
|
|
selected = s.selectWeightedRoundRobin(nodes)
|
|
case StrategyRandom:
|
|
selected = s.selectRandom(nodes)
|
|
default:
|
|
selected = s.selectLeastLatency(nodes)
|
|
}
|
|
|
|
return selected, nil
|
|
}
|
|
|
|
// filterByUnlockStatus 过滤满足解锁需求的节点
|
|
func (s *Selector) filterByUnlockStatus(nodes []models.Node, services []string) []models.Node {
|
|
var result []models.Node
|
|
|
|
for _, node := range nodes {
|
|
// 构建节点的解锁服务映射
|
|
unlockMap := make(map[string]bool)
|
|
for _, status := range node.UnlockStatuses {
|
|
unlockMap[status.Service] = status.Unlocked
|
|
}
|
|
|
|
// 检查是否满足所有需求
|
|
allMatched := true
|
|
for _, service := range services {
|
|
if !unlockMap[service] {
|
|
allMatched = false
|
|
break
|
|
}
|
|
}
|
|
|
|
if allMatched {
|
|
result = append(result, node)
|
|
}
|
|
}
|
|
|
|
return result
|
|
}
|
|
|
|
// filterByConnectionLimit 过滤未达到连接上限的节点
|
|
func (s *Selector) filterByConnectionLimit(nodes []models.Node) []models.Node {
|
|
var result []models.Node
|
|
|
|
for _, node := range nodes {
|
|
stats, ok := s.cache.GetStats(node.NodeID)
|
|
if !ok {
|
|
// 没有缓存数据,使用数据库中的连接数
|
|
if node.CurrentConnections < node.MaxConnections {
|
|
result = append(result, node)
|
|
}
|
|
continue
|
|
}
|
|
|
|
if stats.Connections < node.MaxConnections {
|
|
result = append(result, node)
|
|
}
|
|
}
|
|
|
|
return result
|
|
}
|
|
|
|
// selectLeastLatency 选择延迟最低的节点
|
|
func (s *Selector) selectLeastLatency(nodes []models.Node) *models.Node {
|
|
var selected *models.Node
|
|
minLatency := time.Duration(1<<63 - 1)
|
|
|
|
for i := range nodes {
|
|
stats, ok := s.cache.GetStats(nodes[i].NodeID)
|
|
if ok {
|
|
latency := time.Duration(stats.CPUUsage * 100) // 简化:用 CPU 使用率模拟延迟
|
|
if latency < minLatency {
|
|
minLatency = latency
|
|
selected = &nodes[i]
|
|
}
|
|
} else {
|
|
// 没有缓存数据,使用权重作为候选
|
|
if selected == nil || nodes[i].Weight > selected.Weight {
|
|
selected = &nodes[i]
|
|
}
|
|
}
|
|
}
|
|
|
|
return selected
|
|
}
|
|
|
|
// selectLeastConnections 选择连接数最少的节点
|
|
func (s *Selector) selectLeastConnections(nodes []models.Node) *models.Node {
|
|
var selected *models.Node
|
|
minConnections := int(^uint(0) >> 1) // Max int
|
|
|
|
for i := range nodes {
|
|
stats, ok := s.cache.GetStats(nodes[i].NodeID)
|
|
connCount := nodes[i].CurrentConnections
|
|
if ok {
|
|
connCount = stats.Connections
|
|
}
|
|
|
|
if connCount < minConnections {
|
|
minConnections = connCount
|
|
selected = &nodes[i]
|
|
}
|
|
}
|
|
|
|
return selected
|
|
}
|
|
|
|
// selectWeightedRoundRobin 加权轮询选择
|
|
func (s *Selector) selectWeightedRoundRobin(nodes []models.Node) *models.Node {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
// 计算总权重
|
|
totalWeight := 0
|
|
for _, node := range nodes {
|
|
totalWeight += node.Weight
|
|
}
|
|
|
|
if totalWeight == 0 {
|
|
return &nodes[0]
|
|
}
|
|
|
|
// 轮询选择
|
|
s.rrIndex = (s.rrIndex + 1) % totalWeight
|
|
|
|
currentWeight := 0
|
|
for i := range nodes {
|
|
currentWeight += nodes[i].Weight
|
|
if s.rrIndex < currentWeight {
|
|
return &nodes[i]
|
|
}
|
|
}
|
|
|
|
return &nodes[0]
|
|
}
|
|
|
|
// selectRandom 随机选择
|
|
func (s *Selector) selectRandom(nodes []models.Node) *models.Node {
|
|
r := s.randPool.Get().(*rand.Rand)
|
|
defer s.randPool.Put(r)
|
|
|
|
idx := r.Intn(len(nodes))
|
|
return &nodes[idx]
|
|
}
|
|
|
|
// HealthChecker 健康检查器
|
|
type HealthChecker struct {
|
|
repo NodeRepository
|
|
cache NodeCache
|
|
logger *zap.Logger
|
|
}
|
|
|
|
func NewHealthChecker(repo NodeRepository, cache NodeCache, logger *zap.Logger) *HealthChecker {
|
|
return &HealthChecker{
|
|
repo: repo,
|
|
cache: cache,
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
// Check 检查节点健康状态
|
|
func (h *HealthChecker) Check(ctx context.Context, node *models.Node) error {
|
|
stats, ok := h.cache.GetStats(node.NodeID)
|
|
if !ok {
|
|
return errors.New("节点未上报状态")
|
|
}
|
|
|
|
// 检查是否超时
|
|
if time.Since(stats.LastUpdate) > 30*time.Second {
|
|
return errors.New("节点心跳超时")
|
|
}
|
|
|
|
// 检查 WARP 状态
|
|
if stats.Connections < 0 {
|
|
return errors.New("节点状态异常")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// RuleEngine 规则引擎
|
|
type RuleEngine struct {
|
|
rules []IPRefreshRule
|
|
actions map[string]ActionFunc
|
|
mu sync.RWMutex
|
|
logger *zap.Logger
|
|
}
|
|
|
|
// IPRefreshRule describes an event that should make RuleEngine trigger an IP
// refresh for a node.
//
// NOTE(review): RuleEngine.Evaluate currently reads only ID, Enabled and
// TriggerType. NodeGroupID, TriggerValue and Cooldown are carried but not
// enforced there.
type IPRefreshRule struct {
	ID uint
	// NodeGroupID optionally scopes the rule to a node group; nil presumably
	// means "all groups" — TODO confirm.
	NodeGroupID *uint
	TriggerType string // unlock_failure, usage_count, usage_traffic, scheduled, anomaly; matched against the event name
	// TriggerValue holds trigger-specific parameters; its schema per
	// TriggerType is not defined in this file.
	TriggerValue map[string]interface{}
	// Cooldown is presumably the minimum interval between firings of this
	// rule — TODO confirm; not enforced by Evaluate.
	Cooldown time.Duration
	Enabled  bool
}
|
|
|
|
// ActionFunc is an action RuleEngine can run against a node. RuleEngine.Evaluate
// passes the triggering event name as reason and returns the action's error.
type ActionFunc func(ctx context.Context, node *models.Node, reason string) error
|
|
|
|
func NewRuleEngine(logger *zap.Logger) *RuleEngine {
|
|
return &RuleEngine{
|
|
rules: make([]IPRefreshRule, 0),
|
|
actions: make(map[string]ActionFunc),
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
// RegisterAction 注册动作
|
|
func (e *RuleEngine) RegisterAction(name string, action ActionFunc) {
|
|
e.mu.Lock()
|
|
defer e.mu.Unlock()
|
|
e.actions[name] = action
|
|
}
|
|
|
|
// AddRule 添加规则
|
|
func (e *RuleEngine) AddRule(rule IPRefreshRule) {
|
|
e.mu.Lock()
|
|
defer e.mu.Unlock()
|
|
e.rules = append(e.rules, rule)
|
|
}
|
|
|
|
// Evaluate 评估规则
|
|
func (e *RuleEngine) Evaluate(ctx context.Context, node *models.Node, event string, data map[string]interface{}) error {
|
|
e.mu.RLock()
|
|
defer e.mu.RUnlock()
|
|
|
|
for _, rule := range e.rules {
|
|
if !rule.Enabled {
|
|
continue
|
|
}
|
|
|
|
if rule.TriggerType != event {
|
|
continue
|
|
}
|
|
|
|
// 触发规则
|
|
e.logger.Info("规则触发",
|
|
zap.Uint("rule_id", rule.ID),
|
|
zap.String("trigger", event),
|
|
zap.String("node_id", node.NodeID),
|
|
)
|
|
|
|
// 执行动作
|
|
if action, ok := e.actions["refresh_ip"]; ok {
|
|
return action(ctx, node, event)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|