Files
proxy-platform/cmd/scheduler/main.go
T

345 lines
8.5 KiB
Go

package main
import (
	"context"
	"crypto/subtle"
	"fmt"
	"log"
	"net/http"
	"os"
	"os/signal"
	"sync"
	"syscall"
	"time"

	"github.com/gin-gonic/gin"
	"go.uber.org/zap"
	"gorm.io/driver/postgres"
	"gorm.io/gorm"

	"proxy-platform/internal/config"
	"proxy-platform/internal/handler"
	"proxy-platform/internal/models"
	"proxy-platform/internal/repository"
	"proxy-platform/internal/scheduler"
	"proxy-platform/internal/socks5"
)
// main wires up the scheduler process. It loads configs/scheduler.yaml,
// connects to PostgreSQL and auto-migrates the schema, then runs three
// background tasks:
//   - the SOCKS5 proxy,
//   - the periodic node health check,
//   - the HTTP management API.
//
// It runs until SIGINT/SIGTERM arrives or one of the two servers stops with an
// error. On a server failure the process exits with status 1 after shutting
// the rest down, so a supervisor can tell a crash from a clean stop.
func main() {
	// Load configuration.
	cfg, err := config.Load("configs/scheduler.yaml")
	if err != nil {
		log.Fatalf("加载配置失败: %v", err)
	}

	// Initialize structured logging.
	logger, err := zap.NewProduction()
	if err != nil {
		log.Fatalf("初始化日志失败: %v", err)
	}
	// Sync frequently reports harmless errors for stdout/stderr sinks; there is
	// nothing useful to do with them at process exit.
	defer func() { _ = logger.Sync() }()

	// Connect to the database.
	db, err := gorm.Open(postgres.Open(cfg.Database.DSN()), &gorm.Config{})
	if err != nil {
		logger.Fatal("连接数据库失败", zap.Error(err))
	}

	// Auto-migrate the schema.
	if err := db.AutoMigrate(
		&models.User{},
		&models.Node{},
		&models.NodeGroup{},
		&models.UnlockStatus{},
		&models.IPChangeLog{},
		&models.ConnectionLog{},
		&models.IPRefreshRule{},
	); err != nil {
		logger.Fatal("数据库迁移失败", zap.Error(err))
	}
	logger.Info("数据库迁移完成")

	// Core dependencies: repositories, SOCKS5 authenticator, in-memory node
	// stats cache, node selector and the backend selector the proxy uses.
	repos := repository.NewRepositories(db)
	auth := NewSimpleAuthenticator(repos.User)
	cache := NewMemoryNodeCache()
	selector := scheduler.NewSelector(repos.Node, cache, scheduler.StrategyLeastLatency, logger)
	backendSelector := NewBackendSelector(repos.Node, selector, logger)

	socks5Server := socks5.NewServer(
		cfg.SOCKS5.Host,
		cfg.SOCKS5.Port,
		cfg.SOCKS5.MaxConnections,
		cfg.SOCKS5.Timeout,
		auth,
		backendSelector,
		logger,
	)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// serverErr carries the first fatal error from either server so that a
	// failed listener (e.g. port already in use) brings the whole process down
	// instead of leaving it running half-alive. Capacity 2 = one slot per
	// sender, so neither goroutine can ever block on the send.
	serverErr := make(chan error, 2)

	// wg tracks every background goroutine so shutdown can wait for them.
	var wg sync.WaitGroup

	// SOCKS5 server.
	wg.Add(1)
	go func() {
		defer wg.Done()
		if err := socks5Server.Start(ctx); err != nil {
			logger.Error("SOCKS5 服务器错误", zap.Error(err))
			serverErr <- err
		}
	}()

	// Periodic node health check; returns once ctx is cancelled.
	healthChecker := scheduler.NewHealthChecker(repos.Node, cache, logger)
	healthInterval := time.Duration(cfg.Scheduler.HealthCheckInterval) * time.Second
	wg.Add(1)
	go func() {
		defer wg.Done()
		startHealthCheck(ctx, repos.Node, healthChecker, logger, healthInterval)
	}()

	// HTTP management API. ListenAndServe returns http.ErrServerClosed after a
	// graceful Shutdown, which is not a failure.
	apiServer := NewAPIServer(cfg, repos, logger)
	wg.Add(1)
	go func() {
		defer wg.Done()
		if err := apiServer.Start(); err != nil && err != http.ErrServerClosed {
			logger.Error("API 服务器错误", zap.Error(err))
			serverErr <- err
		}
	}()

	logger.Info("调度中心启动完成",
		zap.String("api", fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port)),
		zap.String("socks5", fmt.Sprintf("%s:%d", cfg.SOCKS5.Host, cfg.SOCKS5.Port)),
	)

	// Block until an interrupt/termination signal or a server failure.
	quit := make(chan os.Signal, 1)
	signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
	defer signal.Stop(quit)

	var runErr error
	select {
	case <-quit:
	case runErr = <-serverErr:
		// The failing server goroutine has already logged the error.
	}

	logger.Info("正在关闭服务器...")
	cancel()

	// Gracefully stop the API server, bounded by a 5-second deadline.
	shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer shutdownCancel()
	apiServer.Shutdown(shutdownCtx)

	// Wait for background goroutines, but only until the same deadline.
	// Whether socks5Server.Start returns promptly on cancellation is not
	// visible from here, so an unbounded wait could hang shutdown.
	done := make(chan struct{})
	go func() {
		wg.Wait()
		close(done)
	}()
	select {
	case <-done:
	case <-shutdownCtx.Done():
		logger.Warn("等待后台任务退出超时")
	}

	logger.Info("服务器已关闭")

	if runErr != nil {
		// Exit non-zero after cleanup. logger.Fatal calls os.Exit(1), which skips
		// the remaining defers; those only release already-finished resources.
		logger.Fatal("服务异常退出", zap.Error(runErr))
	}
}
// SimpleAuthenticator authenticates SOCKS5 clients against the user table
// through a UserRepository.
type SimpleAuthenticator struct {
	userRepo *repository.UserRepository
}

// NewSimpleAuthenticator returns an authenticator that looks users up in
// userRepo.
func NewSimpleAuthenticator(userRepo *repository.UserRepository) *SimpleAuthenticator {
	return &SimpleAuthenticator{userRepo: userRepo}
}

// Authenticate looks up username and returns (user ID, true) when all of the
// following hold:
//   - the supplied password is non-empty,
//   - it equals the stored credential,
//   - the account status is "active".
//
// A lookup error, an empty password, a mismatch or an inactive account all
// yield (0, false).
//
// TODO: despite the field name PasswordHash, the stored value is still
// compared as plaintext. It should be verified with bcrypt; that needs
// golang.org/x/crypto, which this file does not import yet.
func (a *SimpleAuthenticator) Authenticate(username, password string) (uint, bool) {
	// Never accept an empty credential. Otherwise an account stored with an
	// empty password would be open to anyone who knows the username.
	if password == "" {
		return 0, false
	}

	user, err := a.userRepo.FindByUsername(username)
	if err != nil {
		return 0, false
	}

	// Constant-time comparison, so response timing does not reveal how many
	// leading bytes of the stored secret matched.
	matched := subtle.ConstantTimeCompare([]byte(user.PasswordHash), []byte(password)) == 1
	if !matched || user.Status != "active" {
		return 0, false
	}
	return user.ID, true
}
// BackendSelector is handed to the SOCKS5 server to pick the upstream node for
// each client request. It delegates the choice to scheduler.Selector.
type BackendSelector struct {
	// nodeRepo is stored but not used by the methods below; presumably
	// reserved for the statistics update noted in ReleaseBackend.
	nodeRepo *repository.NodeRepository
	selector *scheduler.Selector
	logger   *zap.Logger
}

// NewBackendSelector returns a BackendSelector that delegates node choice to
// selector.
func NewBackendSelector(nodeRepo *repository.NodeRepository, selector *scheduler.Selector, logger *zap.Logger) *BackendSelector {
	return &BackendSelector{
		nodeRepo: nodeRepo,
		selector: selector,
		logger:   logger,
	}
}

// SelectBackend asks the scheduler for a node suitable for
// targetHost:targetPort and the requested services. It returns the chosen
// node's host and port. A selector error is returned wrapped with the target,
// so errors.Is/As still see the original cause.
func (s *BackendSelector) SelectBackend(ctx context.Context, targetHost string, targetPort int, services []string) (string, int, error) {
	node, err := s.selector.Select(ctx, targetHost, targetPort, services)
	if err != nil {
		return "", 0, fmt.Errorf("select backend for %s:%d: %w", targetHost, targetPort, err)
	}
	return node.Host, node.Port, nil
}

// ReleaseBackend is the hook for a finished connection to host:port that
// transferred bytesIn/bytesOut. For now it only writes a log line.
//
// It logs at Debug rather than Info. The hook presumably fires once per
// proxied connection (verify against the socks5 package), so an Info line
// each time would flood production logs.
//
// TODO: update per-node traffic statistics.
func (s *BackendSelector) ReleaseBackend(host string, port int, bytesIn, bytesOut int64) {
	s.logger.Debug("释放后端节点",
		zap.String("host", host),
		zap.Int("port", port),
		zap.Int64("bytes_in", bytesIn),
		zap.Int64("bytes_out", bytesOut),
	)
}
// MemoryNodeCache is an in-process map from node ID to that node's latest
// statistics. The map is guarded by an RWMutex, so the methods are safe for
// concurrent use.
//
// The *models.NodeStats values themselves are shared, not copied. Code that
// mutates a value after SetStats, or after receiving it from GetStats or
// GetAllStats, must synchronize that mutation itself.
type MemoryNodeCache struct {
	stats map[string]*models.NodeStats
	mu    sync.RWMutex
}

// NewMemoryNodeCache returns an empty, ready-to-use cache.
func NewMemoryNodeCache() *MemoryNodeCache {
	return &MemoryNodeCache{
		stats: make(map[string]*models.NodeStats),
	}
}

// GetStats returns the stats stored for nodeID and whether an entry exists.
func (c *MemoryNodeCache) GetStats(nodeID string) (*models.NodeStats, bool) {
	c.mu.RLock()
	defer c.mu.RUnlock()
	stats, ok := c.stats[nodeID]
	return stats, ok
}

// SetStats stores stats for nodeID, replacing any previous entry.
func (c *MemoryNodeCache) SetStats(nodeID string, stats *models.NodeStats) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.stats[nodeID] = stats
}

// GetAllStats returns a snapshot copy of the map. The caller may add or remove
// keys freely; the pointed-to values are still shared (see type doc).
func (c *MemoryNodeCache) GetAllStats() map[string]*models.NodeStats {
	c.mu.RLock()
	defer c.mu.RUnlock()
	// Pre-size: the final length is known, so avoid incremental map growth.
	result := make(map[string]*models.NodeStats, len(c.stats))
	for k, v := range c.stats {
		result[k] = v
	}
	return result
}
// startHealthCheck calls checker.Check on every node returned by
// nodeRepo.ListOnline, once per interval, until ctx is cancelled. The first
// round runs one interval after start.
//
// A non-positive interval (e.g. a missing config value) would make
// time.NewTicker panic. It is replaced by a 30-second fallback and a warning
// is logged. Errors listing nodes or checking an individual node are logged
// and do not stop the loop.
func startHealthCheck(ctx context.Context, nodeRepo *repository.NodeRepository, checker *scheduler.HealthChecker, logger *zap.Logger, interval time.Duration) {
	const fallbackInterval = 30 * time.Second
	if interval <= 0 {
		logger.Warn("健康检查间隔无效, 使用默认值",
			zap.Duration("configured", interval),
			zap.Duration("fallback", fallbackInterval),
		)
		interval = fallbackInterval
	}

	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
		}

		nodes, err := nodeRepo.ListOnline()
		if err != nil {
			logger.Error("获取节点列表失败", zap.Error(err))
			continue
		}

		for i := range nodes {
			// Stop promptly on shutdown instead of finishing the whole round.
			if ctx.Err() != nil {
				return
			}
			// Point at the slice element rather than a range variable. Before
			// Go 1.22 the range variable is reused across iterations, so a
			// pointer retained by Check would be overwritten.
			node := &nodes[i]
			if err := checker.Check(ctx, node); err != nil {
				logger.Warn("节点健康检查失败",
					zap.String("node_id", node.NodeID),
					zap.Error(err),
				)
			}
		}
	}
}
// APIServer is the HTTP management API: a gin router served by a net/http
// server on cfg.Server.Host:cfg.Server.Port.
type APIServer struct {
	cfg    *config.Config
	repos  *repository.Repositories
	logger *zap.Logger
	server *http.Server
	router *gin.Engine
}

// NewAPIServer builds the router, registers all routes and prepares the
// http.Server without starting it. It also sets gin's process-global mode from
// cfg.Server.Mode.
func NewAPIServer(cfg *config.Config, repos *repository.Repositories, logger *zap.Logger) *APIServer {
	gin.SetMode(cfg.Server.Mode)
	router := gin.New()
	// gin.New installs no middleware. Without Recovery, a panicking handler
	// drops the client connection instead of producing a 500 response.
	router.Use(gin.Recovery())

	server := &APIServer{
		cfg:    cfg,
		repos:  repos,
		logger: logger,
		router: router,
		server: &http.Server{
			Addr:    fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port),
			Handler: router,
			// Limit how long a client may take to send headers and body, and how
			// long idle keep-alive connections are kept. This stops slow or
			// stalled clients from holding connections forever.
			// NOTE(review): WriteTimeout is deliberately left unset. Handler run
			// times (e.g. refresh-ip) are not visible here; pick a value once
			// they are known.
			ReadHeaderTimeout: 10 * time.Second,
			ReadTimeout:       30 * time.Second,
			IdleTimeout:       120 * time.Second,
		},
	}

	server.setupRoutes()
	return server
}
// setupRoutes registers every HTTP route on s.router:
//   - a liveness probe at GET /health that always answers {"status":"ok"};
//   - the versioned REST API under /api/v1: users, nodes, agent callbacks,
//     IP refresh rules and statistics.
//
// Each endpoint handler comes from a factory in the handler package. The
// factory is given the repository it needs and, for some endpoints, the logger.
//
// NOTE(review): this method attaches no authentication middleware to any
// group, including the agent callbacks. Confirm that access is restricted
// elsewhere before exposing this port.
func (s *APIServer) setupRoutes() {
	repos, lg := s.repos, s.logger

	// Liveness probe.
	s.router.GET("/health", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"status": "ok"})
	})

	v1 := s.router.Group("/api/v1")

	// User endpoints.
	userGroup := v1.Group("/users")
	userGroup.GET("", handler.ListUsers(repos.User))
	userGroup.POST("", handler.CreateUser(repos.User))
	userGroup.GET("/:id", handler.GetUser(repos.User))
	userGroup.PUT("/:id", handler.UpdateUser(repos.User))
	userGroup.DELETE("/:id", handler.DeleteUser(repos.User))

	// Node endpoints, plus a per-node IP refresh action.
	nodeGroup := v1.Group("/nodes")
	nodeGroup.GET("", handler.ListNodes(repos.Node))
	nodeGroup.POST("", handler.CreateNode(repos.Node))
	nodeGroup.GET("/:id", handler.GetNode(repos.Node))
	nodeGroup.PUT("/:id", handler.UpdateNode(repos.Node))
	nodeGroup.DELETE("/:id", handler.DeleteNode(repos.Node))
	nodeGroup.POST("/:id/refresh-ip", handler.RefreshNodeIP(repos.Node, lg))

	// Agent callbacks: heartbeat, unlock-status report, IP-change result.
	agentGroup := v1.Group("/agent")
	agentGroup.POST("/heartbeat", handler.AgentHeartbeat(repos.Node, lg))
	agentGroup.POST("/unlock/report", handler.ReportUnlockStatus(repos.UnlockStatus, lg))
	agentGroup.POST("/ip/change/result", handler.ReportIPChange(repos.Node, repos.IPChangeLog, lg))

	// IP refresh rule endpoints.
	ruleGroup := v1.Group("/rules")
	ruleGroup.GET("", handler.ListRules(repos.IPRefreshRule))
	ruleGroup.POST("", handler.CreateRule(repos.IPRefreshRule))
	ruleGroup.GET("/:id", handler.GetRule(repos.IPRefreshRule))
	ruleGroup.PUT("/:id", handler.UpdateRule(repos.IPRefreshRule))
	ruleGroup.DELETE("/:id", handler.DeleteRule(repos.IPRefreshRule))

	// Statistics.
	statsGroup := v1.Group("/stats")
	statsGroup.GET("/overview", handler.GetOverview(repos))
	statsGroup.GET("/traffic", handler.GetTrafficStats(repos.ConnectionLog))
}
// Start listens on the configured address and serves the API. It blocks until
// the server stops. Per net/http, it always returns a non-nil error, and that
// error is http.ErrServerClosed after Shutdown or Close.
func (s *APIServer) Start() error {
	return s.server.ListenAndServe()
}

// Shutdown gracefully stops the server, waiting for in-flight requests until
// ctx expires. The signature returns nothing, so failures are logged rather
// than dropped. If graceful shutdown fails (e.g. the deadline passes), the
// remaining connections are closed forcibly so the process can exit.
func (s *APIServer) Shutdown(ctx context.Context) {
	if err := s.server.Shutdown(ctx); err != nil {
		s.logger.Error("API 服务器关闭失败", zap.Error(err))
		if cerr := s.server.Close(); cerr != nil {
			s.logger.Error("API 服务器强制关闭失败", zap.Error(cerr))
		}
	}
}