// proxy-platform/internal/repository/repository.go

package repository
import (
"context"
"time"
"proxy-platform/internal/models"
"gorm.io/gorm"
)
// UserRepository provides database access for models.User records through a
// *gorm.DB handle.
type UserRepository struct {
	db *gorm.DB // handle used by every UserRepository method
}

// NewUserRepository returns a UserRepository that issues its queries on db.
func NewUserRepository(db *gorm.DB) *UserRepository {
	repo := new(UserRepository)
	repo.db = db
	return repo
}
// Create inserts user as a new row and returns any database error.
func (r *UserRepository) Create(user *models.User) error {
	result := r.db.Create(user)
	return result.Error
}
// FindByUsername loads the first user whose username column equals username.
// On any error, including gorm.ErrRecordNotFound when no row matches, it
// returns a nil user together with that error.
func (r *UserRepository) FindByUsername(username string) (*models.User, error) {
	found := new(models.User)
	if err := r.db.Where("username = ?", username).First(found).Error; err != nil {
		return nil, err
	}
	return found, nil
}
// FindByID loads the user whose primary key is id. On any error, including
// gorm.ErrRecordNotFound, it returns a nil user together with that error.
func (r *UserRepository) FindByID(id uint) (*models.User, error) {
	found := new(models.User)
	if err := r.db.First(found, id).Error; err != nil {
		return nil, err
	}
	return found, nil
}
// Update writes all fields of user back to the database using GORM's Save.
func (r *UserRepository) Update(user *models.User) error {
	if err := r.db.Save(user).Error; err != nil {
		return err
	}
	return nil
}
// UpdateTraffic adds bytesIn+bytesOut to the traffic_used column of the user
// whose id is userID. The addition is evaluated by the database inside a single
// UPDATE statement ("traffic_used + ?") rather than read-modify-write in Go.
// No error is returned when no user matches userID.
func (r *UserRepository) UpdateTraffic(userID uint, bytesIn, bytesOut int64) error {
	delta := bytesIn + bytesOut
	changes := map[string]interface{}{
		"traffic_used": gorm.Expr("traffic_used + ?", delta),
	}
	return r.db.Model(&models.User{}).Where("id = ?", userID).Updates(changes).Error
}
// List returns one page of users (skipping offset rows, returning at most
// limit rows) together with the total number of users in the table.
//
// Rows are ordered by id: without an ORDER BY the database may return rows in
// any order, so consecutive Offset/Limit pages could overlap or skip users.
// If either the count or the page query fails, List returns nil, 0 and the
// error; a failed count is no longer silently reported as total == 0.
func (r *UserRepository) List(offset, limit int) ([]models.User, int64, error) {
	var total int64
	if err := r.db.Model(&models.User{}).Count(&total).Error; err != nil {
		return nil, 0, err
	}

	var users []models.User
	if err := r.db.Order("id").Offset(offset).Limit(limit).Find(&users).Error; err != nil {
		return nil, 0, err
	}
	return users, total, nil
}
// Delete removes the user whose primary key is id.
// NOTE(review): this is a soft delete if models.User carries a gorm.DeletedAt
// field — not visible from this file; confirm against the model.
func (r *UserRepository) Delete(id uint) error {
	result := r.db.Delete(&models.User{}, id)
	return result.Error
}
// NodeRepository provides database access for models.Node records and their
// UnlockStatuses association.
type NodeRepository struct {
	db *gorm.DB // handle used by every NodeRepository method
}

// NewNodeRepository returns a NodeRepository that issues its queries on db.
func NewNodeRepository(db *gorm.DB) *NodeRepository {
	repo := new(NodeRepository)
	repo.db = db
	return repo
}
// Create inserts node as a new row and returns any database error.
func (r *NodeRepository) Create(node *models.Node) error {
	result := r.db.Create(node)
	return result.Error
}
// FindByNodeID loads the first node whose string node_id column equals nodeID.
// Unlike FindByID it does not preload UnlockStatuses. On any error, including
// gorm.ErrRecordNotFound, it returns a nil node together with that error.
func (r *NodeRepository) FindByNodeID(nodeID string) (*models.Node, error) {
	found := new(models.Node)
	if err := r.db.Where("node_id = ?", nodeID).First(found).Error; err != nil {
		return nil, err
	}
	return found, nil
}
// FindByID loads the node whose primary key is id, with its UnlockStatuses
// association preloaded. On any error it returns a nil node and that error.
func (r *NodeRepository) FindByID(id uint) (*models.Node, error) {
	found := new(models.Node)
	query := r.db.Preload("UnlockStatuses")
	if err := query.First(found, id).Error; err != nil {
		return nil, err
	}
	return found, nil
}
// Update writes all fields of node back to the database using GORM's Save.
func (r *NodeRepository) Update(node *models.Node) error {
	if err := r.db.Save(node).Error; err != nil {
		return err
	}
	return nil
}
// UpdateStatus sets status and warp_status for the node whose node_id equals
// nodeID and stamps last_heartbeat with time.Now() (the application clock, not
// the database clock). No error is returned when no node matches.
func (r *NodeRepository) UpdateStatus(nodeID string, status string, warpStatus string) error {
	changes := map[string]interface{}{
		"status":         status,
		"warp_status":    warpStatus,
		"last_heartbeat": time.Now(),
	}
	return r.db.Model(&models.Node{}).Where("node_id = ?", nodeID).Updates(changes).Error
}
// UpdateIP stores newIP in current_ip and ipRegion in ip_region for the node
// whose node_id equals nodeID.
func (r *NodeRepository) UpdateIP(nodeID string, newIP string, ipRegion string) error {
	changes := map[string]interface{}{
		"current_ip": newIP,
		"ip_region":  ipRegion,
	}
	return r.db.Model(&models.Node{}).Where("node_id = ?", nodeID).Updates(changes).Error
}
// UpdateConnections overwrites current_connections with connections for the
// node whose node_id equals nodeID.
func (r *NodeRepository) UpdateConnections(nodeID string, connections int) error {
	scoped := r.db.Model(&models.Node{}).Where("node_id = ?", nodeID)
	return scoped.Update("current_connections", connections).Error
}
// List returns every node with its UnlockStatuses association preloaded.
func (r *NodeRepository) List() ([]models.Node, error) {
	var all []models.Node
	query := r.db.Preload("UnlockStatuses")
	if err := query.Find(&all).Error; err != nil {
		return all, err
	}
	return all, nil
}
// ListOnline returns the nodes whose status column is "online", with their
// UnlockStatuses association preloaded.
func (r *NodeRepository) ListOnline() ([]models.Node, error) {
	const onlineStatus = "online"
	var online []models.Node
	query := r.db.Where("status = ?", onlineStatus).Preload("UnlockStatuses")
	if err := query.Find(&online).Error; err != nil {
		return online, err
	}
	return online, nil
}
// Delete removes the node whose primary key is id.
// NOTE(review): associated UnlockStatuses rows are not deleted here; whether
// the schema cascades is not visible from this file.
func (r *NodeRepository) Delete(id uint) error {
	result := r.db.Delete(&models.Node{}, id)
	return result.Error
}
// UnlockStatusRepository provides database access for models.UnlockStatus
// records (the unlock_statuses table).
type UnlockStatusRepository struct {
	db *gorm.DB // handle used by every UnlockStatusRepository method
}

// NewUnlockStatusRepository returns an UnlockStatusRepository that issues its
// queries on db.
func NewUnlockStatusRepository(db *gorm.DB) *UnlockStatusRepository {
	repo := new(UnlockStatusRepository)
	repo.db = db
	return repo
}
// Upsert inserts the unlock result for (nodeID, service), or overwrites
// unlocked, region and detected_at if a row for that pair already exists.
// detected_at is set by the database's NOW(), not the application clock.
//
// The statement uses PostgreSQL syntax (ON CONFLICT ... EXCLUDED, NOW()) and
// requires a unique constraint on (node_id, service).
// NOTE(review): confirm the migration creates that constraint; nodeID here is
// a numeric ID (presumably models.Node's primary key), not the string node_id.
func (r *UnlockStatusRepository) Upsert(nodeID uint, service string, unlocked bool, region string) error {
	const upsertSQL = `
INSERT INTO unlock_statuses (node_id, service, unlocked, region, detected_at)
VALUES (?, ?, ?, ?, NOW())
ON CONFLICT (node_id, service)
DO UPDATE SET unlocked = EXCLUDED.unlocked, region = EXCLUDED.region, detected_at = NOW()
`
	result := r.db.Exec(upsertSQL, nodeID, service, unlocked, region)
	return result.Error
}
// FindByNodeID returns every unlock status row whose node_id equals nodeID.
func (r *UnlockStatusRepository) FindByNodeID(nodeID uint) ([]models.UnlockStatus, error) {
	var rows []models.UnlockStatus
	if err := r.db.Where("node_id = ?", nodeID).Find(&rows).Error; err != nil {
		return rows, err
	}
	return rows, nil
}
// IPChangeLogRepository provides database access for models.IPChangeLog
// records (the history of node IP changes).
type IPChangeLogRepository struct {
	db *gorm.DB // handle used by every IPChangeLogRepository method
}

// NewIPChangeLogRepository returns an IPChangeLogRepository that issues its
// queries on db.
func NewIPChangeLogRepository(db *gorm.DB) *IPChangeLogRepository {
	repo := new(IPChangeLogRepository)
	repo.db = db
	return repo
}
// Create inserts log as a new row and returns any database error.
func (r *IPChangeLogRepository) Create(log *models.IPChangeLog) error {
	result := r.db.Create(log)
	return result.Error
}
// FindByNodeID returns up to limit IP change log entries for nodeID, newest
// first (ordered by created_at descending).
// NOTE(review): GORM's handling of limit <= 0 differs between versions (-1
// disables the limit); callers should pass a positive value.
func (r *IPChangeLogRepository) FindByNodeID(nodeID uint, limit int) ([]models.IPChangeLog, error) {
	var entries []models.IPChangeLog
	query := r.db.Where("node_id = ?", nodeID).Order("created_at DESC").Limit(limit)
	if err := query.Find(&entries).Error; err != nil {
		return entries, err
	}
	return entries, nil
}
// ConnectionLogRepository provides database access for models.ConnectionLog
// records and aggregate statistics over them.
type ConnectionLogRepository struct {
	db *gorm.DB // handle used by every ConnectionLogRepository method
}

// NewConnectionLogRepository returns a ConnectionLogRepository that issues its
// queries on db.
func NewConnectionLogRepository(db *gorm.DB) *ConnectionLogRepository {
	repo := new(ConnectionLogRepository)
	repo.db = db
	return repo
}
// Create inserts log as a new row and returns any database error.
func (r *ConnectionLogRepository) Create(log *models.ConnectionLog) error {
	result := r.db.Create(log)
	return result.Error
}
// GetStatsByUser aggregates the connection logs of userID whose created_at lies
// between startTime and endTime (SQL BETWEEN, inclusive on both ends).
//
// The returned map has the keys:
//   - "total_bytes"   (int64): sum of bytes_in + bytes_out
//   - "total_seconds" (int):   sum of the duration column
//   - "connections"   (int64): number of matching rows
//
// A user with no logs in the window gets zeros for every key.
func (r *ConnectionLogRepository) GetStatsByUser(userID uint, startTime, endTime time.Time) (map[string]interface{}, error) {
	var result struct {
		TotalBytes   int64
		TotalSeconds int
		Connections  int64
	}
	// SUM over zero rows yields NULL, not 0. COALESCE guarantees the integer
	// scan targets always receive a number instead of relying on how the
	// driver/GORM treats NULL for non-nullable fields.
	err := r.db.Model(&models.ConnectionLog{}).
		Where("user_id = ? AND created_at BETWEEN ? AND ?", userID, startTime, endTime).
		Select("COALESCE(SUM(bytes_in + bytes_out), 0) as total_bytes, COALESCE(SUM(duration), 0) as total_seconds, COUNT(*) as connections").
		Scan(&result).Error
	if err != nil {
		return nil, err
	}
	return map[string]interface{}{
		"total_bytes":   result.TotalBytes,
		"total_seconds": result.TotalSeconds,
		"connections":   result.Connections,
	}, nil
}
// IPRefreshRuleRepository provides database access for models.IPRefreshRule
// records (rules for refreshing node IPs).
type IPRefreshRuleRepository struct {
	db *gorm.DB // handle used by every IPRefreshRuleRepository method
}

// NewIPRefreshRuleRepository returns an IPRefreshRuleRepository that issues its
// queries on db.
func NewIPRefreshRuleRepository(db *gorm.DB) *IPRefreshRuleRepository {
	repo := new(IPRefreshRuleRepository)
	repo.db = db
	return repo
}
// Create inserts rule as a new row and returns any database error.
func (r *IPRefreshRuleRepository) Create(rule *models.IPRefreshRule) error {
	result := r.db.Create(rule)
	return result.Error
}
// FindByID loads the rule whose primary key is id. On any error, including
// gorm.ErrRecordNotFound, it returns a nil rule together with that error.
func (r *IPRefreshRuleRepository) FindByID(id uint) (*models.IPRefreshRule, error) {
	found := new(models.IPRefreshRule)
	if err := r.db.First(found, id).Error; err != nil {
		return nil, err
	}
	return found, nil
}
// List returns every IP refresh rule, in the order the database yields them.
func (r *IPRefreshRuleRepository) List() ([]models.IPRefreshRule, error) {
	var all []models.IPRefreshRule
	if err := r.db.Find(&all).Error; err != nil {
		return all, err
	}
	return all, nil
}
// Update writes all fields of rule back to the database using GORM's Save.
func (r *IPRefreshRuleRepository) Update(rule *models.IPRefreshRule) error {
	if err := r.db.Save(rule).Error; err != nil {
		return err
	}
	return nil
}
// Delete removes the IP refresh rule whose primary key is id.
func (r *IPRefreshRuleRepository) Delete(id uint) error {
	result := r.db.Delete(&models.IPRefreshRule{}, id)
	return result.Error
}
// Repositories bundles one instance of every repository in this package so
// callers can pass a single value around.
type Repositories struct {
	User          *UserRepository
	Node          *NodeRepository
	UnlockStatus  *UnlockStatusRepository
	IPChangeLog   *IPChangeLogRepository
	ConnectionLog *ConnectionLogRepository
	IPRefreshRule *IPRefreshRuleRepository
}

// NewRepositories builds a Repositories whose members all share the same db
// handle.
func NewRepositories(db *gorm.DB) *Repositories {
	repos := new(Repositories)
	repos.User = NewUserRepository(db)
	repos.Node = NewNodeRepository(db)
	repos.UnlockStatus = NewUnlockStatusRepository(db)
	repos.IPChangeLog = NewIPChangeLogRepository(db)
	repos.ConnectionLog = NewConnectionLogRepository(db)
	repos.IPRefreshRule = NewIPRefreshRuleRepository(db)
	return repos
}
// HealthChecker abstracts a context-aware health probe. Within this package,
// *UserRepository implements it via its Ping method.
type HealthChecker interface {
	// Ping probes the underlying dependency; ctx carries cancellation and
	// deadlines for the probe.
	Ping(ctx context.Context) error
}
// Ping verifies database connectivity by executing "SELECT 1" under ctx,
// returning the database/driver error (or ctx's error) if it fails.
//
// Exec is required here. GORM's Raw(...) only builds the statement and does
// not send it until Scan/Row/Rows is called, so Raw("SELECT 1").Error never
// touched the database and always reported healthy.
func (r *UserRepository) Ping(ctx context.Context) error {
	return r.db.WithContext(ctx).Exec("SELECT 1").Error
}