230 lines
7.5 KiB
Go
230 lines
7.5 KiB
Go
|
|
package main
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"fmt"
|
||
|
|
"log"
|
||
|
|
"net/http"
|
||
|
|
"os"
|
||
|
|
"os/signal"
|
||
|
|
"syscall"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
"github.com/gin-gonic/gin"
|
||
|
|
|
||
|
|
"github.com/user-management-system/internal/api/handler"
|
||
|
|
"github.com/user-management-system/internal/api/middleware"
|
||
|
|
"github.com/user-management-system/internal/api/router"
|
||
|
|
"github.com/user-management-system/internal/auth"
|
||
|
|
"github.com/user-management-system/internal/cache"
|
||
|
|
"github.com/user-management-system/internal/config"
|
||
|
|
"github.com/user-management-system/internal/database"
|
||
|
|
"github.com/user-management-system/internal/repository"
|
||
|
|
"github.com/user-management-system/internal/security"
|
||
|
|
"github.com/user-management-system/internal/service"
|
||
|
|
)
|
||
|
|
|
||
|
|
func main() {
|
||
|
|
// 加载配置
|
||
|
|
cfg, err := config.Load()
|
||
|
|
if err != nil {
|
||
|
|
log.Fatalf("load config failed: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
// 设置 Gin 模式
|
||
|
|
gin.SetMode(resolveGinMode(cfg.Server.Mode))
|
||
|
|
|
||
|
|
// 初始化数据库
|
||
|
|
db, err := database.NewDB(cfg)
|
||
|
|
if err != nil {
|
||
|
|
log.Fatalf("connect database failed: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
// 执行数据库迁移
|
||
|
|
if err := db.AutoMigrate(cfg); err != nil {
|
||
|
|
log.Fatalf("auto migrate failed: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
// 初始化 JWT 管理器
|
||
|
|
jwtManager, err := auth.NewJWTWithOptions(auth.JWTOptions{
|
||
|
|
HS256Secret: cfg.JWT.Secret,
|
||
|
|
AccessTokenExpire: time.Duration(cfg.JWT.AccessTokenExpireMinutes) * time.Minute,
|
||
|
|
RefreshTokenExpire: time.Duration(cfg.JWT.RefreshTokenExpireDays) * 24 * time.Hour,
|
||
|
|
})
|
||
|
|
if err != nil {
|
||
|
|
log.Fatalf("create jwt manager failed: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
// 初始化缓存
|
||
|
|
l1Cache := cache.NewL1Cache()
|
||
|
|
l2Cache := cache.NewRedisCacheWithConfig(cache.RedisCacheConfig{
|
||
|
|
Addr: fmt.Sprintf("%s:%d", cfg.Redis.Host, cfg.Redis.Port),
|
||
|
|
Password: cfg.Redis.Password,
|
||
|
|
DB: cfg.Redis.DB,
|
||
|
|
})
|
||
|
|
defer l2Cache.Close()
|
||
|
|
cacheManager := cache.NewCacheManager(l1Cache, l2Cache)
|
||
|
|
|
||
|
|
// 初始化 Repository
|
||
|
|
userRepo := repository.NewUserRepository(db.DB)
|
||
|
|
roleRepo := repository.NewRoleRepository(db.DB)
|
||
|
|
permissionRepo := repository.NewPermissionRepository(db.DB)
|
||
|
|
userRoleRepo := repository.NewUserRoleRepository(db.DB)
|
||
|
|
rolePermissionRepo := repository.NewRolePermissionRepository(db.DB)
|
||
|
|
deviceRepo := repository.NewDeviceRepository(db.DB)
|
||
|
|
loginLogRepo := repository.NewLoginLogRepository(db.DB)
|
||
|
|
operationLogRepo := repository.NewOperationLogRepository(db.DB)
|
||
|
|
customFieldRepo := repository.NewCustomFieldRepository(db.DB)
|
||
|
|
userCustomFieldValueRepo := repository.NewUserCustomFieldValueRepository(db.DB)
|
||
|
|
themeRepo := repository.NewThemeConfigRepository(db.DB)
|
||
|
|
socialRepo, err := repository.NewSocialAccountRepository(db.DB)
|
||
|
|
if err != nil {
|
||
|
|
log.Fatalf("initialize social account repository failed: %v", err)
|
||
|
|
}
|
||
|
|
passwordHistoryRepo := repository.NewPasswordHistoryRepository(db.DB)
|
||
|
|
|
||
|
|
// 初始化 Service
|
||
|
|
deviceService := service.NewDeviceService(deviceRepo, userRepo)
|
||
|
|
authService := service.NewAuthService(
|
||
|
|
userRepo,
|
||
|
|
socialRepo,
|
||
|
|
jwtManager,
|
||
|
|
cacheManager,
|
||
|
|
8, // passwordMinLength
|
||
|
|
5, // maxLoginAttempts
|
||
|
|
15*time.Minute, // loginLockDuration
|
||
|
|
)
|
||
|
|
authService.SetRoleRepositories(userRoleRepo, roleRepo)
|
||
|
|
authService.SetLoginLogRepository(loginLogRepo)
|
||
|
|
authService.SetDeviceService(deviceService)
|
||
|
|
|
||
|
|
// IP 过滤中间件
|
||
|
|
var ipFilterMiddleware *middleware.IPFilterMiddleware
|
||
|
|
ipFilter := security.NewIPFilter()
|
||
|
|
if ipFilter != nil {
|
||
|
|
ipFilterMiddleware = middleware.NewIPFilterMiddleware(ipFilter, middleware.IPFilterConfig{
|
||
|
|
TrustProxy: cfg.CORS.AllowCredentials,
|
||
|
|
})
|
||
|
|
}
|
||
|
|
|
||
|
|
// 初始化异常检测器并注入
|
||
|
|
anomalyDetector := security.NewAnomalyDetector(security.DefaultAnomalyConfig, ipFilter)
|
||
|
|
authService.SetAnomalyDetector(anomalyDetector)
|
||
|
|
log.Println("anomaly detector initialized")
|
||
|
|
|
||
|
|
userService := service.NewUserService(userRepo, userRoleRepo, roleRepo, passwordHistoryRepo)
|
||
|
|
roleService := service.NewRoleService(roleRepo, rolePermissionRepo)
|
||
|
|
permissionService := service.NewPermissionService(permissionRepo)
|
||
|
|
loginLogService := service.NewLoginLogService(loginLogRepo)
|
||
|
|
operationLogService := service.NewOperationLogService(operationLogRepo)
|
||
|
|
captchaService := service.NewCaptchaService(cacheManager)
|
||
|
|
totpService := service.NewTOTPService(userRepo)
|
||
|
|
|
||
|
|
passwordResetConfig := service.DefaultPasswordResetConfig()
|
||
|
|
passwordResetService := service.NewPasswordResetService(userRepo, cacheManager, passwordResetConfig)
|
||
|
|
|
||
|
|
webhookService := service.NewWebhookService(db.DB, service.WebhookServiceConfig{
|
||
|
|
Enabled: false,
|
||
|
|
})
|
||
|
|
exportService := service.NewExportService(userRepo, roleRepo)
|
||
|
|
statsService := service.NewStatsService(userRepo, loginLogRepo)
|
||
|
|
customFieldService := service.NewCustomFieldService(customFieldRepo, userCustomFieldValueRepo)
|
||
|
|
themeService := service.NewThemeService(themeRepo)
|
||
|
|
|
||
|
|
// 设置 CORS 配置
|
||
|
|
middleware.SetCORSConfig(cfg.CORS)
|
||
|
|
|
||
|
|
// 初始化中间件
|
||
|
|
rateLimitMiddleware := middleware.NewRateLimitMiddleware(cfg.RateLimit)
|
||
|
|
authMiddleware := middleware.NewAuthMiddleware(
|
||
|
|
jwtManager,
|
||
|
|
userRepo,
|
||
|
|
userRoleRepo,
|
||
|
|
roleRepo,
|
||
|
|
rolePermissionRepo,
|
||
|
|
permissionRepo,
|
||
|
|
)
|
||
|
|
authMiddleware.SetCacheManager(cacheManager)
|
||
|
|
|
||
|
|
opLogMiddleware := middleware.NewOperationLogMiddleware(operationLogRepo)
|
||
|
|
|
||
|
|
// 初始化 Handler
|
||
|
|
authHandler := handler.NewAuthHandler(authService)
|
||
|
|
userHandler := handler.NewUserHandler(userService)
|
||
|
|
roleHandler := handler.NewRoleHandler(roleService)
|
||
|
|
permissionHandler := handler.NewPermissionHandler(permissionService)
|
||
|
|
deviceHandler := handler.NewDeviceHandler(deviceService)
|
||
|
|
logHandler := handler.NewLogHandler(loginLogService, operationLogService)
|
||
|
|
captchaHandler := handler.NewCaptchaHandler(captchaService)
|
||
|
|
totpHandler := handler.NewTOTPHandler(authService, totpService)
|
||
|
|
webhookHandler := handler.NewWebhookHandler(webhookService)
|
||
|
|
exportHandler := handler.NewExportHandler(exportService)
|
||
|
|
statsHandler := handler.NewStatsHandler(statsService)
|
||
|
|
passwordResetHandler := handler.NewPasswordResetHandler(passwordResetService)
|
||
|
|
smsHandler := handler.NewSMSHandler()
|
||
|
|
avatarHandler := handler.NewAvatarHandler()
|
||
|
|
customFieldHandler := handler.NewCustomFieldHandler(customFieldService)
|
||
|
|
themeHandler := handler.NewThemeHandler(themeService)
|
||
|
|
|
||
|
|
// 初始化 SSO 管理器
|
||
|
|
ssoManager := auth.NewSSOManager()
|
||
|
|
ssoHandler := handler.NewSSOHandler(ssoManager)
|
||
|
|
|
||
|
|
// 设置路由
|
||
|
|
r := router.NewRouter(
|
||
|
|
authHandler, userHandler, roleHandler, permissionHandler, deviceHandler,
|
||
|
|
logHandler, authMiddleware, rateLimitMiddleware, opLogMiddleware,
|
||
|
|
passwordResetHandler, captchaHandler, totpHandler, webhookHandler,
|
||
|
|
ipFilterMiddleware, exportHandler, statsHandler, smsHandler, customFieldHandler, themeHandler, ssoHandler, avatarHandler,
|
||
|
|
)
|
||
|
|
engine := r.Setup()
|
||
|
|
|
||
|
|
// 健康检查
|
||
|
|
engine.GET("/health", func(c *gin.Context) {
|
||
|
|
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||
|
|
})
|
||
|
|
|
||
|
|
// 启动服务器
|
||
|
|
addr := fmt.Sprintf(":%d", cfg.Server.Port)
|
||
|
|
srv := &http.Server{
|
||
|
|
Addr: addr,
|
||
|
|
Handler: engine,
|
||
|
|
ReadTimeout: 30 * time.Second,
|
||
|
|
WriteTimeout: 30 * time.Second,
|
||
|
|
}
|
||
|
|
|
||
|
|
go func() {
|
||
|
|
log.Printf("server listening on %s", addr)
|
||
|
|
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||
|
|
log.Fatalf("listen failed: %v", err)
|
||
|
|
}
|
||
|
|
}()
|
||
|
|
|
||
|
|
// 等待中断信号
|
||
|
|
quit := make(chan os.Signal, 1)
|
||
|
|
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||
|
|
<-quit
|
||
|
|
|
||
|
|
log.Println("shutting down server...")
|
||
|
|
|
||
|
|
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||
|
|
defer cancel()
|
||
|
|
|
||
|
|
if err := srv.Shutdown(ctx); err != nil {
|
||
|
|
log.Fatalf("server forced to shutdown: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
log.Println("server exited")
|
||
|
|
}
|
||
|
|
|
||
|
|
func resolveGinMode(mode string) string {
|
||
|
|
switch mode {
|
||
|
|
case "debug":
|
||
|
|
return gin.DebugMode
|
||
|
|
case "test":
|
||
|
|
return gin.TestMode
|
||
|
|
default:
|
||
|
|
return gin.ReleaseMode
|
||
|
|
}
|
||
|
|
}
|