293 lines
7.5 KiB
Go
293 lines
7.5 KiB
Go
package repository
|
|
|
|
import (
|
|
"context"
|
|
"time"
|
|
|
|
"proxy-platform/internal/models"
|
|
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// UserRepository 用户仓库
|
|
type UserRepository struct {
|
|
db *gorm.DB
|
|
}
|
|
|
|
func NewUserRepository(db *gorm.DB) *UserRepository {
|
|
return &UserRepository{db: db}
|
|
}
|
|
|
|
func (r *UserRepository) Create(user *models.User) error {
|
|
return r.db.Create(user).Error
|
|
}
|
|
|
|
func (r *UserRepository) FindByUsername(username string) (*models.User, error) {
|
|
var user models.User
|
|
err := r.db.Where("username = ?", username).First(&user).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &user, nil
|
|
}
|
|
|
|
func (r *UserRepository) FindByID(id uint) (*models.User, error) {
|
|
var user models.User
|
|
err := r.db.First(&user, id).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &user, nil
|
|
}
|
|
|
|
func (r *UserRepository) Update(user *models.User) error {
|
|
return r.db.Save(user).Error
|
|
}
|
|
|
|
func (r *UserRepository) UpdateTraffic(userID uint, bytesIn, bytesOut int64) error {
|
|
return r.db.Model(&models.User{}).
|
|
Where("id = ?", userID).
|
|
Updates(map[string]interface{}{
|
|
"traffic_used": gorm.Expr("traffic_used + ?", bytesIn+bytesOut),
|
|
}).Error
|
|
}
|
|
|
|
func (r *UserRepository) List(offset, limit int) ([]models.User, int64, error) {
|
|
var users []models.User
|
|
var total int64
|
|
|
|
r.db.Model(&models.User{}).Count(&total)
|
|
err := r.db.Offset(offset).Limit(limit).Find(&users).Error
|
|
return users, total, err
|
|
}
|
|
|
|
func (r *UserRepository) Delete(id uint) error {
|
|
return r.db.Delete(&models.User{}, id).Error
|
|
}
|
|
|
|
// NodeRepository 节点仓库
|
|
type NodeRepository struct {
|
|
db *gorm.DB
|
|
}
|
|
|
|
func NewNodeRepository(db *gorm.DB) *NodeRepository {
|
|
return &NodeRepository{db: db}
|
|
}
|
|
|
|
func (r *NodeRepository) Create(node *models.Node) error {
|
|
return r.db.Create(node).Error
|
|
}
|
|
|
|
func (r *NodeRepository) FindByNodeID(nodeID string) (*models.Node, error) {
|
|
var node models.Node
|
|
err := r.db.Where("node_id = ?", nodeID).First(&node).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &node, nil
|
|
}
|
|
|
|
func (r *NodeRepository) FindByID(id uint) (*models.Node, error) {
|
|
var node models.Node
|
|
err := r.db.Preload("UnlockStatuses").First(&node, id).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &node, nil
|
|
}
|
|
|
|
func (r *NodeRepository) Update(node *models.Node) error {
|
|
return r.db.Save(node).Error
|
|
}
|
|
|
|
func (r *NodeRepository) UpdateStatus(nodeID string, status string, warpStatus string) error {
|
|
return r.db.Model(&models.Node{}).
|
|
Where("node_id = ?", nodeID).
|
|
Updates(map[string]interface{}{
|
|
"status": status,
|
|
"warp_status": warpStatus,
|
|
"last_heartbeat": time.Now(),
|
|
}).Error
|
|
}
|
|
|
|
func (r *NodeRepository) UpdateIP(nodeID string, newIP string, ipRegion string) error {
|
|
return r.db.Model(&models.Node{}).
|
|
Where("node_id = ?", nodeID).
|
|
Updates(map[string]interface{}{
|
|
"current_ip": newIP,
|
|
"ip_region": ipRegion,
|
|
}).Error
|
|
}
|
|
|
|
func (r *NodeRepository) UpdateConnections(nodeID string, connections int) error {
|
|
return r.db.Model(&models.Node{}).
|
|
Where("node_id = ?", nodeID).
|
|
Update("current_connections", connections).Error
|
|
}
|
|
|
|
func (r *NodeRepository) List() ([]models.Node, error) {
|
|
var nodes []models.Node
|
|
err := r.db.Preload("UnlockStatuses").Find(&nodes).Error
|
|
return nodes, err
|
|
}
|
|
|
|
func (r *NodeRepository) ListOnline() ([]models.Node, error) {
|
|
var nodes []models.Node
|
|
err := r.db.Where("status = ?", "online").
|
|
Preload("UnlockStatuses").
|
|
Find(&nodes).Error
|
|
return nodes, err
|
|
}
|
|
|
|
func (r *NodeRepository) Delete(id uint) error {
|
|
return r.db.Delete(&models.Node{}, id).Error
|
|
}
|
|
|
|
// UnlockStatusRepository 解锁状态仓库
|
|
type UnlockStatusRepository struct {
|
|
db *gorm.DB
|
|
}
|
|
|
|
func NewUnlockStatusRepository(db *gorm.DB) *UnlockStatusRepository {
|
|
return &UnlockStatusRepository{db: db}
|
|
}
|
|
|
|
func (r *UnlockStatusRepository) Upsert(nodeID uint, service string, unlocked bool, region string) error {
|
|
return r.db.Exec(`
|
|
INSERT INTO unlock_statuses (node_id, service, unlocked, region, detected_at)
|
|
VALUES (?, ?, ?, ?, NOW())
|
|
ON CONFLICT (node_id, service)
|
|
DO UPDATE SET unlocked = EXCLUDED.unlocked, region = EXCLUDED.region, detected_at = NOW()
|
|
`, nodeID, service, unlocked, region).Error
|
|
}
|
|
|
|
func (r *UnlockStatusRepository) FindByNodeID(nodeID uint) ([]models.UnlockStatus, error) {
|
|
var statuses []models.UnlockStatus
|
|
err := r.db.Where("node_id = ?", nodeID).Find(&statuses).Error
|
|
return statuses, err
|
|
}
|
|
|
|
// IPChangeLogRepository IP 变更日志仓库
|
|
type IPChangeLogRepository struct {
|
|
db *gorm.DB
|
|
}
|
|
|
|
func NewIPChangeLogRepository(db *gorm.DB) *IPChangeLogRepository {
|
|
return &IPChangeLogRepository{db: db}
|
|
}
|
|
|
|
func (r *IPChangeLogRepository) Create(log *models.IPChangeLog) error {
|
|
return r.db.Create(log).Error
|
|
}
|
|
|
|
func (r *IPChangeLogRepository) FindByNodeID(nodeID uint, limit int) ([]models.IPChangeLog, error) {
|
|
var logs []models.IPChangeLog
|
|
err := r.db.Where("node_id = ?", nodeID).
|
|
Order("created_at DESC").
|
|
Limit(limit).
|
|
Find(&logs).Error
|
|
return logs, err
|
|
}
|
|
|
|
// ConnectionLogRepository 连接日志仓库
|
|
type ConnectionLogRepository struct {
|
|
db *gorm.DB
|
|
}
|
|
|
|
func NewConnectionLogRepository(db *gorm.DB) *ConnectionLogRepository {
|
|
return &ConnectionLogRepository{db: db}
|
|
}
|
|
|
|
func (r *ConnectionLogRepository) Create(log *models.ConnectionLog) error {
|
|
return r.db.Create(log).Error
|
|
}
|
|
|
|
func (r *ConnectionLogRepository) GetStatsByUser(userID uint, startTime, endTime time.Time) (map[string]interface{}, error) {
|
|
var result struct {
|
|
TotalBytes int64
|
|
TotalSeconds int
|
|
Connections int64
|
|
}
|
|
|
|
err := r.db.Model(&models.ConnectionLog{}).
|
|
Where("user_id = ? AND created_at BETWEEN ? AND ?", userID, startTime, endTime).
|
|
Select("SUM(bytes_in + bytes_out) as total_bytes, SUM(duration) as total_seconds, COUNT(*) as connections").
|
|
Scan(&result).Error
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return map[string]interface{}{
|
|
"total_bytes": result.TotalBytes,
|
|
"total_seconds": result.TotalSeconds,
|
|
"connections": result.Connections,
|
|
}, nil
|
|
}
|
|
|
|
// IPRefreshRuleRepository IP 刷新规则仓库
|
|
type IPRefreshRuleRepository struct {
|
|
db *gorm.DB
|
|
}
|
|
|
|
func NewIPRefreshRuleRepository(db *gorm.DB) *IPRefreshRuleRepository {
|
|
return &IPRefreshRuleRepository{db: db}
|
|
}
|
|
|
|
func (r *IPRefreshRuleRepository) Create(rule *models.IPRefreshRule) error {
|
|
return r.db.Create(rule).Error
|
|
}
|
|
|
|
func (r *IPRefreshRuleRepository) FindByID(id uint) (*models.IPRefreshRule, error) {
|
|
var rule models.IPRefreshRule
|
|
err := r.db.First(&rule, id).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &rule, nil
|
|
}
|
|
|
|
func (r *IPRefreshRuleRepository) List() ([]models.IPRefreshRule, error) {
|
|
var rules []models.IPRefreshRule
|
|
err := r.db.Find(&rules).Error
|
|
return rules, err
|
|
}
|
|
|
|
func (r *IPRefreshRuleRepository) Update(rule *models.IPRefreshRule) error {
|
|
return r.db.Save(rule).Error
|
|
}
|
|
|
|
func (r *IPRefreshRuleRepository) Delete(id uint) error {
|
|
return r.db.Delete(&models.IPRefreshRule{}, id).Error
|
|
}
|
|
|
|
// Repositories 仓库集合
|
|
type Repositories struct {
|
|
User *UserRepository
|
|
Node *NodeRepository
|
|
UnlockStatus *UnlockStatusRepository
|
|
IPChangeLog *IPChangeLogRepository
|
|
ConnectionLog *ConnectionLogRepository
|
|
IPRefreshRule *IPRefreshRuleRepository
|
|
}
|
|
|
|
func NewRepositories(db *gorm.DB) *Repositories {
|
|
return &Repositories{
|
|
User: NewUserRepository(db),
|
|
Node: NewNodeRepository(db),
|
|
UnlockStatus: NewUnlockStatusRepository(db),
|
|
IPChangeLog: NewIPChangeLogRepository(db),
|
|
ConnectionLog: NewConnectionLogRepository(db),
|
|
IPRefreshRule: NewIPRefreshRuleRepository(db),
|
|
}
|
|
}
|
|
|
|
// HealthChecker is the health-check contract: Ping receives a context and
// reports a problem as a non-nil error. *UserRepository has a matching Ping
// method in this package.
type HealthChecker interface {
	// Ping probes the underlying dependency, honoring ctx.
	Ping(ctx context.Context) error
}
|
|
|
|
func (r *UserRepository) Ping(ctx context.Context) error {
|
|
return r.db.WithContext(ctx).Raw("SELECT 1").Error
|
|
}
|