345 lines
8.5 KiB
Go
345 lines
8.5 KiB
Go
package main
|
|
|
|
import (
	"context"
	"crypto/subtle"
	"fmt"
	"log"
	"net/http"
	"os"
	"os/signal"
	"sync"
	"syscall"
	"time"

	"github.com/gin-gonic/gin"
	"go.uber.org/zap"
	"gorm.io/driver/postgres"
	"gorm.io/gorm"

	"proxy-platform/internal/config"
	"proxy-platform/internal/handler"
	"proxy-platform/internal/models"
	"proxy-platform/internal/repository"
	"proxy-platform/internal/scheduler"
	"proxy-platform/internal/socks5"
)
|
|
|
|
func main() {
|
|
// 加载配置
|
|
cfg, err := config.Load("configs/scheduler.yaml")
|
|
if err != nil {
|
|
log.Fatalf("加载配置失败: %v", err)
|
|
}
|
|
|
|
// 初始化日志
|
|
logger, err := zap.NewProduction()
|
|
if err != nil {
|
|
log.Fatalf("初始化日志失败: %v", err)
|
|
}
|
|
defer logger.Sync()
|
|
|
|
// 连接数据库
|
|
db, err := gorm.Open(postgres.Open(cfg.Database.DSN()), &gorm.Config{})
|
|
if err != nil {
|
|
logger.Fatal("连接数据库失败", zap.Error(err))
|
|
}
|
|
|
|
// 自动迁移
|
|
if err := db.AutoMigrate(
|
|
&models.User{},
|
|
&models.Node{},
|
|
&models.NodeGroup{},
|
|
&models.UnlockStatus{},
|
|
&models.IPChangeLog{},
|
|
&models.ConnectionLog{},
|
|
&models.IPRefreshRule{},
|
|
); err != nil {
|
|
logger.Fatal("数据库迁移失败", zap.Error(err))
|
|
}
|
|
|
|
logger.Info("数据库迁移完成")
|
|
|
|
// 初始化仓库
|
|
repos := repository.NewRepositories(db)
|
|
|
|
// 初始化认证器
|
|
auth := NewSimpleAuthenticator(repos.User)
|
|
|
|
// 初始化节点缓存
|
|
cache := NewMemoryNodeCache()
|
|
|
|
// 初始化节点选择器
|
|
selector := scheduler.NewSelector(repos.Node, cache, scheduler.StrategyLeastLatency, logger)
|
|
|
|
// 初始化后端选择器
|
|
backendSelector := NewBackendSelector(repos.Node, selector, logger)
|
|
|
|
// 启动 SOCKS5 服务器
|
|
socks5Server := socks5.NewServer(
|
|
cfg.SOCKS5.Host,
|
|
cfg.SOCKS5.Port,
|
|
cfg.SOCKS5.MaxConnections,
|
|
cfg.SOCKS5.Timeout,
|
|
auth,
|
|
backendSelector,
|
|
logger,
|
|
)
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
// 启动 SOCKS5 服务器
|
|
go func() {
|
|
if err := socks5Server.Start(ctx); err != nil {
|
|
logger.Error("SOCKS5 服务器错误", zap.Error(err))
|
|
}
|
|
}()
|
|
|
|
// 启动健康检查
|
|
healthChecker := scheduler.NewHealthChecker(repos.Node, cache, logger)
|
|
go startHealthCheck(ctx, repos.Node, healthChecker, logger, time.Duration(cfg.Scheduler.HealthCheckInterval)*time.Second)
|
|
|
|
// 启动 API 服务器
|
|
apiServer := NewAPIServer(cfg, repos, logger)
|
|
go func() {
|
|
if err := apiServer.Start(); err != nil && err != http.ErrServerClosed {
|
|
logger.Error("API 服务器错误", zap.Error(err))
|
|
}
|
|
}()
|
|
|
|
logger.Info("调度中心启动完成",
|
|
zap.String("api", fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port)),
|
|
zap.String("socks5", fmt.Sprintf("%s:%d", cfg.SOCKS5.Host, cfg.SOCKS5.Port)),
|
|
)
|
|
|
|
// 等待中断信号
|
|
quit := make(chan os.Signal, 1)
|
|
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
|
<-quit
|
|
|
|
logger.Info("正在关闭服务器...")
|
|
cancel()
|
|
|
|
// 关闭 API 服务器
|
|
ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel2()
|
|
apiServer.Shutdown(ctx2)
|
|
|
|
logger.Info("服务器已关闭")
|
|
}
|
|
|
|
// SimpleAuthenticator 简单认证器
|
|
type SimpleAuthenticator struct {
|
|
userRepo *repository.UserRepository
|
|
}
|
|
|
|
func NewSimpleAuthenticator(userRepo *repository.UserRepository) *SimpleAuthenticator {
|
|
return &SimpleAuthenticator{userRepo: userRepo}
|
|
}
|
|
|
|
func (a *SimpleAuthenticator) Authenticate(username, password string) (uint, bool) {
|
|
user, err := a.userRepo.FindByUsername(username)
|
|
if err != nil {
|
|
return 0, false
|
|
}
|
|
|
|
// TODO: 实现密码验证
|
|
// 这里简化处理,实际应该使用 bcrypt 验证
|
|
if user.PasswordHash == password && user.Status == "active" {
|
|
return user.ID, true
|
|
}
|
|
|
|
return 0, false
|
|
}
|
|
|
|
// BackendSelector 后端节点选择器
|
|
type BackendSelector struct {
|
|
nodeRepo *repository.NodeRepository
|
|
selector *scheduler.Selector
|
|
logger *zap.Logger
|
|
}
|
|
|
|
func NewBackendSelector(nodeRepo *repository.NodeRepository, selector *scheduler.Selector, logger *zap.Logger) *BackendSelector {
|
|
return &BackendSelector{
|
|
nodeRepo: nodeRepo,
|
|
selector: selector,
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
func (s *BackendSelector) SelectBackend(ctx context.Context, targetHost string, targetPort int, services []string) (string, int, error) {
|
|
node, err := s.selector.Select(ctx, targetHost, targetPort, services)
|
|
if err != nil {
|
|
return "", 0, err
|
|
}
|
|
|
|
return node.Host, node.Port, nil
|
|
}
|
|
|
|
func (s *BackendSelector) ReleaseBackend(host string, port int, bytesIn, bytesOut int64) {
|
|
// TODO: 更新节点统计信息
|
|
s.logger.Info("释放后端节点",
|
|
zap.String("host", host),
|
|
zap.Int("port", port),
|
|
zap.Int64("bytes_in", bytesIn),
|
|
zap.Int64("bytes_out", bytesOut),
|
|
)
|
|
}
|
|
|
|
// MemoryNodeCache 内存节点缓存
|
|
type MemoryNodeCache struct {
|
|
stats map[string]*models.NodeStats
|
|
mu sync.RWMutex
|
|
}
|
|
|
|
func NewMemoryNodeCache() *MemoryNodeCache {
|
|
return &MemoryNodeCache{
|
|
stats: make(map[string]*models.NodeStats),
|
|
}
|
|
}
|
|
|
|
func (c *MemoryNodeCache) GetStats(nodeID string) (*models.NodeStats, bool) {
|
|
c.mu.RLock()
|
|
defer c.mu.RUnlock()
|
|
stats, ok := c.stats[nodeID]
|
|
return stats, ok
|
|
}
|
|
|
|
func (c *MemoryNodeCache) SetStats(nodeID string, stats *models.NodeStats) {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
c.stats[nodeID] = stats
|
|
}
|
|
|
|
func (c *MemoryNodeCache) GetAllStats() map[string]*models.NodeStats {
|
|
c.mu.RLock()
|
|
defer c.mu.RUnlock()
|
|
result := make(map[string]*models.NodeStats)
|
|
for k, v := range c.stats {
|
|
result[k] = v
|
|
}
|
|
return result
|
|
}
|
|
|
|
// startHealthCheck 启动健康检查
|
|
func startHealthCheck(ctx context.Context, nodeRepo *repository.NodeRepository, checker *scheduler.HealthChecker, logger *zap.Logger, interval time.Duration) {
|
|
ticker := time.NewTicker(interval)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
nodes, err := nodeRepo.ListOnline()
|
|
if err != nil {
|
|
logger.Error("获取节点列表失败", zap.Error(err))
|
|
continue
|
|
}
|
|
|
|
for _, node := range nodes {
|
|
if err := checker.Check(ctx, &node); err != nil {
|
|
logger.Warn("节点健康检查失败",
|
|
zap.String("node_id", node.NodeID),
|
|
zap.Error(err),
|
|
)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// APIServer API 服务器
|
|
type APIServer struct {
|
|
cfg *config.Config
|
|
repos *repository.Repositories
|
|
logger *zap.Logger
|
|
server *http.Server
|
|
router *gin.Engine
|
|
}
|
|
|
|
func NewAPIServer(cfg *config.Config, repos *repository.Repositories, logger *zap.Logger) *APIServer {
|
|
gin.SetMode(cfg.Server.Mode)
|
|
router := gin.New()
|
|
|
|
server := &APIServer{
|
|
cfg: cfg,
|
|
repos: repos,
|
|
logger: logger,
|
|
router: router,
|
|
server: &http.Server{
|
|
Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port),
|
|
Handler: router,
|
|
},
|
|
}
|
|
|
|
// 注册路由
|
|
server.setupRoutes()
|
|
|
|
return server
|
|
}
|
|
|
|
func (s *APIServer) setupRoutes() {
|
|
// 健康检查
|
|
s.router.GET("/health", func(c *gin.Context) {
|
|
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
|
})
|
|
|
|
// API 路由组
|
|
api := s.router.Group("/api/v1")
|
|
{
|
|
// 用户相关
|
|
users := api.Group("/users")
|
|
{
|
|
users.GET("", handler.ListUsers(s.repos.User))
|
|
users.POST("", handler.CreateUser(s.repos.User))
|
|
users.GET("/:id", handler.GetUser(s.repos.User))
|
|
users.PUT("/:id", handler.UpdateUser(s.repos.User))
|
|
users.DELETE("/:id", handler.DeleteUser(s.repos.User))
|
|
}
|
|
|
|
// 节点相关
|
|
nodes := api.Group("/nodes")
|
|
{
|
|
nodes.GET("", handler.ListNodes(s.repos.Node))
|
|
nodes.POST("", handler.CreateNode(s.repos.Node))
|
|
nodes.GET("/:id", handler.GetNode(s.repos.Node))
|
|
nodes.PUT("/:id", handler.UpdateNode(s.repos.Node))
|
|
nodes.DELETE("/:id", handler.DeleteNode(s.repos.Node))
|
|
nodes.POST("/:id/refresh-ip", handler.RefreshNodeIP(s.repos.Node, s.logger))
|
|
}
|
|
|
|
// Agent 相关
|
|
agent := api.Group("/agent")
|
|
{
|
|
agent.POST("/heartbeat", handler.AgentHeartbeat(s.repos.Node, s.logger))
|
|
agent.POST("/unlock/report", handler.ReportUnlockStatus(s.repos.UnlockStatus, s.logger))
|
|
agent.POST("/ip/change/result", handler.ReportIPChange(s.repos.Node, s.repos.IPChangeLog, s.logger))
|
|
}
|
|
|
|
// 规则相关
|
|
rules := api.Group("/rules")
|
|
{
|
|
rules.GET("", handler.ListRules(s.repos.IPRefreshRule))
|
|
rules.POST("", handler.CreateRule(s.repos.IPRefreshRule))
|
|
rules.GET("/:id", handler.GetRule(s.repos.IPRefreshRule))
|
|
rules.PUT("/:id", handler.UpdateRule(s.repos.IPRefreshRule))
|
|
rules.DELETE("/:id", handler.DeleteRule(s.repos.IPRefreshRule))
|
|
}
|
|
|
|
// 统计相关
|
|
stats := api.Group("/stats")
|
|
{
|
|
stats.GET("/overview", handler.GetOverview(s.repos))
|
|
stats.GET("/traffic", handler.GetTrafficStats(s.repos.ConnectionLog))
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *APIServer) Start() error {
|
|
return s.server.ListenAndServe()
|
|
}
|
|
|
|
func (s *APIServer) Shutdown(ctx context.Context) {
|
|
s.server.Shutdown(ctx)
|
|
} |