package main import ( "context" "fmt" "log" "net/http" "os" "os/signal" "sync" "syscall" "time" "proxy-platform/internal/config" "proxy-platform/internal/handler" "proxy-platform/internal/models" "proxy-platform/internal/repository" "proxy-platform/internal/socks5" "proxy-platform/internal/scheduler" "github.com/gin-gonic/gin" "go.uber.org/zap" "gorm.io/driver/postgres" "gorm.io/driver/sqlite" "gorm.io/gorm" ) func main() { // 加载配置 cfg, err := config.Load("configs/scheduler.yaml") if err != nil { log.Fatalf("加载配置失败: %v", err) } // 初始化日志 logger, err := zap.NewProduction() if err != nil { log.Fatalf("初始化日志失败: %v", err) } defer logger.Sync() // 检查是否已安装 if !cfg.Install.Installed { logger.Info("系统未安装,启动安装模式") startInstallMode(cfg, logger) return } // 连接数据库 var db *gorm.DB if cfg.Database.Type == "sqlite" { // SQLite 连接 db, err = gorm.Open(sqlite.Open(cfg.Database.DSN()), &gorm.Config{}) } else { // PostgreSQL 连接 db, err = gorm.Open(postgres.Open(cfg.Database.DSN()), &gorm.Config{}) } if err != nil { logger.Fatal("连接数据库失败", zap.Error(err)) } // 自动迁移 if err := db.AutoMigrate( &models.User{}, &models.Node{}, &models.NodeGroup{}, &models.UnlockStatus{}, &models.IPChangeLog{}, &models.ConnectionLog{}, &models.IPRefreshRule{}, &models.InstallStatus{}, ); err != nil { logger.Fatal("数据库迁移失败", zap.Error(err)) } logger.Info("数据库迁移完成") // 初始化仓库 repos := repository.NewRepositories(db) // 初始化认证器 auth := NewSimpleAuthenticator(repos.User) // 初始化节点缓存 cache := NewMemoryNodeCache() // 初始化节点选择器 selector := scheduler.NewSelector(repos.Node, cache, scheduler.StrategyLeastLatency, logger) // 初始化后端选择器 backendSelector := NewBackendSelector(repos.Node, selector, logger) // 启动 SOCKS5 服务器 socks5Server := socks5.NewServer( cfg.SOCKS5.Host, cfg.SOCKS5.Port, cfg.SOCKS5.MaxConnections, cfg.SOCKS5.Timeout, auth, backendSelector, logger, ) ctx, cancel := context.WithCancel(context.Background()) defer cancel() // 启动 SOCKS5 服务器 go func() { if err := socks5Server.Start(ctx); err != nil { logger.Error("SOCKS5 服务器错误", zap.Error(err)) } }() // 启动健康检查 healthChecker := scheduler.NewHealthChecker(repos.Node, cache, logger) go startHealthCheck(ctx, repos.Node, healthChecker, logger, time.Duration(cfg.Scheduler.HealthCheckInterval)*time.Second) // 启动 API 服务器 apiServer := NewAPIServer(cfg, repos, db, logger) go func() { if err := apiServer.Start(); err != nil && err != http.ErrServerClosed { logger.Error("API 服务器错误", zap.Error(err)) } }() logger.Info("调度中心启动完成", zap.String("api", fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port)), zap.String("socks5", fmt.Sprintf("%s:%d", cfg.SOCKS5.Host, cfg.SOCKS5.Port)), ) // 等待中断信号 quit := make(chan os.Signal, 1) signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) <-quit logger.Info("正在关闭服务器...") cancel() // 关闭 API 服务器 ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second) defer cancel2() apiServer.Shutdown(ctx2) logger.Info("服务器已关闭") } // startInstallMode 启动安装模式 func startInstallMode(cfg *config.Config, logger *zap.Logger) { gin.SetMode("release") router := gin.New() router.Use(gin.Recovery()) // 安装相关路由 api := router.Group("/api/v1") { api.GET("/install/check", handler.CheckInstall()) api.POST("/install", handler.DoInstall(logger)) api.GET("/install/status", handler.GetInstallStatus(nil)) } // 静态文件服务(前端) router.Static("/assets", "web/dist/assets") router.StaticFile("/", "web/dist/index.html") router.StaticFile("/favicon.ico", "web/dist/favicon.ico") // 前端路由回退 router.NoRoute(func(c *gin.Context) { c.File("web/dist/index.html") }) server := &http.Server{ Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port), Handler: router, } logger.Info("安装模式启动", zap.String("addr", server.Addr)) // 等待中断信号 quit := make(chan os.Signal, 1) signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) go func() { if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { logger.Error("服务器错误", zap.Error(err)) } }() <-quit logger.Info("安装模式关闭") } // SimpleAuthenticator 简单认证器 type SimpleAuthenticator struct { userRepo *repository.UserRepository } func NewSimpleAuthenticator(userRepo *repository.UserRepository) *SimpleAuthenticator { return &SimpleAuthenticator{userRepo: userRepo} } func (a *SimpleAuthenticator) Authenticate(username, password string) (uint, bool) { user, err := a.userRepo.FindByUsername(username) if err != nil { return 0, false } // TODO: 实现密码验证 // 这里简化处理,实际应该使用 bcrypt 验证 if user.PasswordHash == password && user.Status == "active" { return user.ID, true } return 0, false } // BackendSelector 后端节点选择器 type BackendSelector struct { nodeRepo *repository.NodeRepository selector *scheduler.Selector logger *zap.Logger } func NewBackendSelector(nodeRepo *repository.NodeRepository, selector *scheduler.Selector, logger *zap.Logger) *BackendSelector { return &BackendSelector{ nodeRepo: nodeRepo, selector: selector, logger: logger, } } func (s *BackendSelector) SelectBackend(ctx context.Context, targetHost string, targetPort int, services []string) (string, int, error) { node, err := s.selector.Select(ctx, targetHost, targetPort, services) if err != nil { return "", 0, err } return node.Host, node.Port, nil } func (s *BackendSelector) ReleaseBackend(host string, port int, bytesIn, bytesOut int64) { // TODO: 更新节点统计信息 s.logger.Info("释放后端节点", zap.String("host", host), zap.Int("port", port), zap.Int64("bytes_in", bytesIn), zap.Int64("bytes_out", bytesOut), ) } // MemoryNodeCache 内存节点缓存 type MemoryNodeCache struct { stats map[string]*models.NodeStats mu sync.RWMutex } func NewMemoryNodeCache() *MemoryNodeCache { return &MemoryNodeCache{ stats: make(map[string]*models.NodeStats), } } func (c *MemoryNodeCache) GetStats(nodeID string) (*models.NodeStats, bool) { c.mu.RLock() defer c.mu.RUnlock() stats, ok := c.stats[nodeID] return stats, ok } func (c *MemoryNodeCache) SetStats(nodeID string, stats *models.NodeStats) { c.mu.Lock() defer c.mu.Unlock() c.stats[nodeID] = stats } func (c *MemoryNodeCache) GetAllStats() map[string]*models.NodeStats { c.mu.RLock() defer c.mu.RUnlock() result := make(map[string]*models.NodeStats) for k, v := range c.stats { result[k] = v } return result } // startHealthCheck 启动健康检查 func startHealthCheck(ctx context.Context, nodeRepo *repository.NodeRepository, checker *scheduler.HealthChecker, logger *zap.Logger, interval time.Duration) { ticker := time.NewTicker(interval) defer ticker.Stop() for { select { case <-ctx.Done(): return case <-ticker.C: nodes, err := nodeRepo.ListOnline() if err != nil { logger.Error("获取节点列表失败", zap.Error(err)) continue } for _, node := range nodes { if err := checker.Check(ctx, &node); err != nil { logger.Warn("节点健康检查失败", zap.String("node_id", node.NodeID), zap.Error(err), ) } } } } } // APIServer API 服务器 type APIServer struct { cfg *config.Config repos *repository.Repositories db *gorm.DB logger *zap.Logger server *http.Server router *gin.Engine } func NewAPIServer(cfg *config.Config, repos *repository.Repositories, db *gorm.DB, logger *zap.Logger) *APIServer { gin.SetMode(cfg.Server.Mode) router := gin.New() server := &APIServer{ cfg: cfg, repos: repos, db: db, logger: logger, router: router, server: &http.Server{ Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port), Handler: router, }, } // 注册路由 server.setupRoutes() return server } func (s *APIServer) setupRoutes() { // 健康检查 s.router.GET("/health", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"status": "ok"}) }) // API 路由组 api := s.router.Group("/api/v1") { // 安装相关 install := api.Group("/install") { install.GET("/check", handler.CheckInstall()) install.GET("/status", handler.GetInstallStatus(s.db)) } // 用户相关 users := api.Group("/users") { users.GET("", handler.ListUsers(s.repos.User)) users.POST("", handler.CreateUser(s.repos.User)) users.GET("/:id", handler.GetUser(s.repos.User)) users.PUT("/:id", handler.UpdateUser(s.repos.User)) users.DELETE("/:id", handler.DeleteUser(s.repos.User)) } // 节点相关 nodes := api.Group("/nodes") { nodes.GET("", handler.ListNodes(s.repos.Node)) nodes.POST("", handler.CreateNode(s.repos.Node)) nodes.GET("/:id", handler.GetNode(s.repos.Node)) nodes.PUT("/:id", handler.UpdateNode(s.repos.Node)) nodes.DELETE("/:id", handler.DeleteNode(s.repos.Node)) nodes.POST("/:id/refresh-ip", handler.RefreshNodeIP(s.repos.Node, s.logger)) } // Agent 相关 agent := api.Group("/agent") { agent.POST("/heartbeat", handler.AgentHeartbeat(s.repos.Node, s.logger)) agent.POST("/unlock/report", handler.ReportUnlockStatus(s.repos.UnlockStatus, s.logger)) agent.POST("/ip/change/result", handler.ReportIPChange(s.repos.Node, s.repos.IPChangeLog, s.logger)) } // 规则相关 rules := api.Group("/rules") { rules.GET("", handler.ListRules(s.repos.IPRefreshRule)) rules.POST("", handler.CreateRule(s.repos.IPRefreshRule)) rules.GET("/:id", handler.GetRule(s.repos.IPRefreshRule)) rules.PUT("/:id", handler.UpdateRule(s.repos.IPRefreshRule)) rules.DELETE("/:id", handler.DeleteRule(s.repos.IPRefreshRule)) } // 统计相关 stats := api.Group("/stats") { stats.GET("/overview", handler.GetOverview(s.repos)) stats.GET("/traffic", handler.GetTrafficStats(s.repos.ConnectionLog)) } } } func (s *APIServer) Start() error { return s.server.ListenAndServe() } func (s *APIServer) Shutdown(ctx context.Context) { s.server.Shutdown(ctx) }