- IsAdminBootstrapRequired: userRepo.GetByID 循环 → GetByIDs 批量 - AssignRoles: roleRepo.GetByID 循环 → GetByIDs 批量 - 在 userRepositoryInterface 补充 GetByIDs 方法签名
369 lines
12 KiB
Go
369 lines
12 KiB
Go
package server
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"log"
|
||
"net/http"
|
||
"os"
|
||
"os/signal"
|
||
"strconv"
|
||
"strings"
|
||
"syscall"
|
||
"time"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
|
||
"github.com/user-management-system/internal/api/handler"
|
||
"github.com/user-management-system/internal/api/middleware"
|
||
"github.com/user-management-system/internal/api/router"
|
||
"github.com/user-management-system/internal/auth"
|
||
"github.com/user-management-system/internal/cache"
|
||
"github.com/user-management-system/internal/config"
|
||
"github.com/user-management-system/internal/database"
|
||
"github.com/user-management-system/internal/monitoring"
|
||
"github.com/user-management-system/internal/repository"
|
||
"github.com/user-management-system/internal/security"
|
||
"github.com/user-management-system/internal/service"
|
||
)
|
||
|
||
func Serve(cfg *config.Config) error {
|
||
// 设置 Gin 模式
|
||
gin.SetMode(resolveGinMode(cfg.Server.Mode))
|
||
|
||
// 初始化数据库
|
||
db, err := database.NewDB(cfg)
|
||
if err != nil {
|
||
return fmt.Errorf("connect database failed: %w", err)
|
||
}
|
||
|
||
// 执行数据库迁移
|
||
if err := db.AutoMigrate(cfg); err != nil {
|
||
return fmt.Errorf("auto migrate failed: %w", err)
|
||
}
|
||
|
||
// P1-3:Argon2id 启动时自适应校准
|
||
auth.CalibrateArgon2id(500 * time.Millisecond)
|
||
|
||
accessTokenExpire := resolveJWTAccessTokenExpire(cfg)
|
||
|
||
// 初始化 JWT 管理器
|
||
jwtManager, err := auth.NewJWTWithOptions(auth.JWTOptions{
|
||
HS256Secret: cfg.JWT.Secret,
|
||
AccessTokenExpire: accessTokenExpire,
|
||
RefreshTokenExpire: time.Duration(cfg.JWT.RefreshTokenExpireDays) * 24 * time.Hour,
|
||
})
|
||
if err != nil {
|
||
return fmt.Errorf("create jwt manager failed: %w", err)
|
||
}
|
||
|
||
// 初始化缓存
|
||
l1Cache := cache.NewL1Cache()
|
||
redisAddr := fmt.Sprintf("%s:%d", cfg.Redis.Host, cfg.Redis.Port)
|
||
redisEnabled := cfg.Redis.Host != "" && cache.ProbeRedis(redisAddr, cfg.Redis.Password, cfg.Redis.DB)
|
||
if !redisEnabled {
|
||
log.Printf("cache: running in memory-only mode (Redis unreachable or not configured)")
|
||
}
|
||
l2Cache := cache.NewRedisCacheWithConfig(cache.RedisCacheConfig{
|
||
Enabled: redisEnabled,
|
||
Addr: redisAddr,
|
||
Password: cfg.Redis.Password,
|
||
DB: cfg.Redis.DB,
|
||
})
|
||
defer l2Cache.Close()
|
||
cacheManager := cache.NewCacheManager(l1Cache, l2Cache)
|
||
|
||
// 初始化 Repository
|
||
userRepo := repository.NewUserRepository(db.DB)
|
||
roleRepo := repository.NewRoleRepository(db.DB)
|
||
permissionRepo := repository.NewPermissionRepository(db.DB)
|
||
userRoleRepo := repository.NewUserRoleRepository(db.DB)
|
||
rolePermissionRepo := repository.NewRolePermissionRepository(db.DB)
|
||
deviceRepo := repository.NewDeviceRepository(db.DB)
|
||
loginLogRepo := repository.NewLoginLogRepository(db.DB)
|
||
operationLogRepo := repository.NewOperationLogRepository(db.DB)
|
||
customFieldRepo := repository.NewCustomFieldRepository(db.DB)
|
||
userCustomFieldValueRepo := repository.NewUserCustomFieldValueRepository(db.DB)
|
||
themeRepo := repository.NewThemeConfigRepository(db.DB)
|
||
socialRepo, err := repository.NewSocialAccountRepository(db.DB)
|
||
if err != nil {
|
||
return fmt.Errorf("initialize social account repository failed: %w", err)
|
||
}
|
||
passwordHistoryRepo := repository.NewPasswordHistoryRepository(db.DB)
|
||
|
||
// 初始化 Service
|
||
deviceService := service.NewDeviceService(deviceRepo, userRepo)
|
||
authService := service.NewAuthService(
|
||
userRepo,
|
||
socialRepo,
|
||
jwtManager,
|
||
cacheManager,
|
||
8, // passwordMinLength
|
||
5, // maxLoginAttempts
|
||
15*time.Minute, // loginLockDuration
|
||
)
|
||
authService.SetRoleRepositories(userRoleRepo, roleRepo)
|
||
authService.SetLoginLogRepository(loginLogRepo)
|
||
authService.SetDeviceService(deviceService)
|
||
|
||
// IP 过滤中间件
|
||
var ipFilterMiddleware *middleware.IPFilterMiddleware
|
||
ipFilter := security.NewIPFilter()
|
||
if ipFilter != nil {
|
||
ipFilterMiddleware = middleware.NewIPFilterMiddleware(ipFilter, middleware.IPFilterConfig{
|
||
TrustProxy: cfg.CORS.AllowCredentials,
|
||
})
|
||
}
|
||
|
||
// 初始化异常检测器并注入
|
||
anomalyDetector := security.NewAnomalyDetector(security.DefaultAnomalyConfig, ipFilter)
|
||
authService.SetAnomalyDetector(anomalyDetector)
|
||
log.Println("anomaly detector initialized")
|
||
|
||
userService := service.NewUserService(userRepo, userRoleRepo, roleRepo, passwordHistoryRepo)
|
||
roleService := service.NewRoleService(roleRepo, rolePermissionRepo)
|
||
permissionService := service.NewPermissionService(permissionRepo)
|
||
loginLogService := service.NewLoginLogService(loginLogRepo)
|
||
operationLogService := service.NewOperationLogService(operationLogRepo)
|
||
captchaService := service.NewCaptchaService(cacheManager)
|
||
totpService := service.NewTOTPService(userRepo)
|
||
|
||
passwordResetConfig := service.DefaultPasswordResetConfig()
|
||
if err := configureAuthEmailServices(cfg, cacheManager, authService, passwordResetConfig); err != nil {
|
||
return fmt.Errorf("configure auth email services failed: %w", err)
|
||
}
|
||
passwordResetService := service.NewPasswordResetService(userRepo, cacheManager, passwordResetConfig).
|
||
WithPasswordHistoryRepo(passwordHistoryRepo)
|
||
|
||
webhookService := service.NewWebhookService(db.DB, service.WebhookServiceConfig{
|
||
Enabled: false,
|
||
})
|
||
exportService := service.NewExportService(userRepo, roleRepo)
|
||
statsService := service.NewStatsService(userRepo, loginLogRepo)
|
||
customFieldService := service.NewCustomFieldService(customFieldRepo, userCustomFieldValueRepo)
|
||
themeService := service.NewThemeService(themeRepo)
|
||
|
||
// 设置 CORS 配置
|
||
middleware.SetCORSConfig(cfg.CORS)
|
||
|
||
// 初始化中间件
|
||
rateLimitMiddleware := middleware.NewRateLimitMiddleware(cfg.RateLimit)
|
||
stopRateLimitCleanup := rateLimitMiddleware.StartCleanup()
|
||
defer stopRateLimitCleanup()
|
||
|
||
authMiddleware := middleware.NewAuthMiddleware(
|
||
jwtManager,
|
||
userRepo,
|
||
userRoleRepo,
|
||
l1Cache,
|
||
)
|
||
authMiddleware.SetCacheManager(cacheManager)
|
||
|
||
opLogMiddleware := middleware.NewOperationLogMiddleware(operationLogRepo)
|
||
|
||
// 初始化 Handler
|
||
authHandler := handler.NewAuthHandler(authService)
|
||
userHandler := handler.NewUserHandler(userService)
|
||
roleHandler := handler.NewRoleHandler(roleService)
|
||
permissionHandler := handler.NewPermissionHandler(permissionService)
|
||
deviceHandler := handler.NewDeviceHandler(deviceService)
|
||
logHandler := handler.NewLogHandler(loginLogService, operationLogService)
|
||
captchaHandler := handler.NewCaptchaHandler(captchaService)
|
||
totpHandler := handler.NewTOTPHandler(authService, totpService)
|
||
webhookHandler := handler.NewWebhookHandler(webhookService)
|
||
exportHandler := handler.NewExportHandler(exportService)
|
||
statsHandler := handler.NewStatsHandler(statsService)
|
||
passwordResetHandler := handler.NewPasswordResetHandler(passwordResetService)
|
||
smsHandler := handler.NewSMSHandler(authService, nil)
|
||
avatarHandler := handler.NewAvatarHandler(userRepo)
|
||
customFieldHandler := handler.NewCustomFieldHandler(customFieldService)
|
||
themeHandler := handler.NewThemeHandler(themeService)
|
||
|
||
// 初始化 SSO 管理器
|
||
ssoManager := auth.NewSSOManager()
|
||
ssoClientsStore := auth.NewDefaultSSOClientsStore()
|
||
ssoHandler := handler.NewSSOHandler(ssoManager, ssoClientsStore)
|
||
|
||
// 系统设置服务
|
||
settingsService := service.NewSettingsService()
|
||
settingsHandler := handler.NewSettingsHandler(settingsService)
|
||
|
||
// SSO 会话清理 context(随服务器关闭而取消)
|
||
ssoCtx, ssoCancel := context.WithCancel(context.Background())
|
||
defer ssoCancel()
|
||
ssoManager.StartCleanup(ssoCtx)
|
||
|
||
// 初始化监控指标
|
||
metrics := monitoring.GetGlobalMetrics()
|
||
sloMetrics := monitoring.GetGlobalSLOMetrics()
|
||
|
||
// 启动后台 goroutine 定期采集系统指标
|
||
metricsCtx, metricsCancel := context.WithCancel(context.Background())
|
||
defer metricsCancel()
|
||
go monitoring.StartSystemMetricsCollector(metricsCtx, metrics, sloMetrics, db.DB)
|
||
|
||
// 设置路由
|
||
r := router.NewRouter(
|
||
authHandler, userHandler, roleHandler, permissionHandler, deviceHandler,
|
||
logHandler, authMiddleware, rateLimitMiddleware, opLogMiddleware,
|
||
passwordResetHandler, captchaHandler, totpHandler, webhookHandler,
|
||
ipFilterMiddleware, exportHandler, statsHandler, smsHandler, customFieldHandler, themeHandler, ssoHandler,
|
||
settingsHandler, metrics, avatarHandler,
|
||
)
|
||
engine := r.Setup()
|
||
|
||
// 健康检查
|
||
healthCheck := monitoring.NewHealthCheck(db.DB)
|
||
engine.GET("/health", healthCheck.Handler)
|
||
engine.GET("/health/live", healthCheck.LivenessHandler)
|
||
engine.GET("/health/ready", healthCheck.ReadinessHandler)
|
||
|
||
// 启动服务器
|
||
addr := fmt.Sprintf(":%d", cfg.Server.Port)
|
||
srv := &http.Server{
|
||
Addr: addr,
|
||
Handler: engine,
|
||
ReadTimeout: 30 * time.Second,
|
||
WriteTimeout: 30 * time.Second,
|
||
}
|
||
|
||
go func() {
|
||
log.Printf("server listening on %s", addr)
|
||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||
log.Fatalf("listen failed: %v", err)
|
||
}
|
||
}()
|
||
|
||
// 等待中断信号
|
||
quit := make(chan os.Signal, 1)
|
||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||
<-quit
|
||
|
||
log.Println("shutting down server...")
|
||
|
||
// 关闭 Webhook 服务
|
||
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||
defer shutdownCancel()
|
||
if err := webhookService.Shutdown(shutdownCtx); err != nil {
|
||
log.Printf("webhook service shutdown: %v", err)
|
||
}
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||
defer cancel()
|
||
|
||
if err := srv.Shutdown(ctx); err != nil {
|
||
return fmt.Errorf("server forced to shutdown: %w", err)
|
||
}
|
||
|
||
log.Println("server exited")
|
||
return nil
|
||
}
|
||
|
||
func resolveGinMode(mode string) string {
|
||
switch mode {
|
||
case "debug":
|
||
return gin.DebugMode
|
||
case "test":
|
||
return gin.TestMode
|
||
default:
|
||
return gin.ReleaseMode
|
||
}
|
||
}
|
||
|
||
func configureAuthEmailServices(
|
||
cfg *config.Config,
|
||
cacheManager *cache.CacheManager,
|
||
authService *service.AuthService,
|
||
passwordResetConfig *service.PasswordResetConfig,
|
||
) error {
|
||
smtpConfig, enabled, err := resolveSMTPEmailConfigFromEnv()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if !enabled || cacheManager == nil || authService == nil {
|
||
return nil
|
||
}
|
||
|
||
siteURL := resolveAuthEmailSiteURL(cfg)
|
||
siteName := resolveAuthEmailSiteName(cfg)
|
||
provider := service.NewSMTPEmailProvider(smtpConfig)
|
||
|
||
authService.SetEmailActivationService(
|
||
service.NewEmailActivationService(provider, cacheManager, siteURL, siteName),
|
||
)
|
||
|
||
emailCodeConfig := service.DefaultEmailCodeConfig()
|
||
emailCodeConfig.SiteURL = siteURL
|
||
emailCodeConfig.SiteName = siteName
|
||
authService.SetEmailCodeService(service.NewEmailCodeService(provider, cacheManager, emailCodeConfig))
|
||
|
||
if passwordResetConfig != nil {
|
||
passwordResetConfig.SMTPHost = smtpConfig.Host
|
||
passwordResetConfig.SMTPPort = smtpConfig.Port
|
||
passwordResetConfig.SMTPUser = smtpConfig.Username
|
||
passwordResetConfig.SMTPPass = smtpConfig.Password
|
||
passwordResetConfig.FromEmail = smtpConfig.FromEmail
|
||
passwordResetConfig.SiteURL = siteURL
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func resolveSMTPEmailConfigFromEnv() (service.SMTPEmailConfig, bool, error) {
|
||
host := strings.TrimSpace(os.Getenv("EMAIL_HOST"))
|
||
if host == "" {
|
||
return service.SMTPEmailConfig{}, false, nil
|
||
}
|
||
|
||
port := 587
|
||
if rawPort := strings.TrimSpace(os.Getenv("EMAIL_PORT")); rawPort != "" {
|
||
parsedPort, err := strconv.Atoi(rawPort)
|
||
if err != nil || parsedPort <= 0 {
|
||
return service.SMTPEmailConfig{}, false, fmt.Errorf("invalid EMAIL_PORT %q", rawPort)
|
||
}
|
||
port = parsedPort
|
||
}
|
||
|
||
fromEmail := strings.TrimSpace(os.Getenv("EMAIL_FROM_EMAIL"))
|
||
if fromEmail == "" {
|
||
fromEmail = service.DefaultPasswordResetConfig().FromEmail
|
||
}
|
||
|
||
return service.SMTPEmailConfig{
|
||
Host: host,
|
||
Port: port,
|
||
Username: strings.TrimSpace(os.Getenv("EMAIL_USER")),
|
||
Password: os.Getenv("EMAIL_PASS"),
|
||
FromEmail: fromEmail,
|
||
FromName: strings.TrimSpace(os.Getenv("EMAIL_FROM_NAME")),
|
||
}, true, nil
|
||
}
|
||
|
||
func resolveAuthEmailSiteURL(cfg *config.Config) string {
|
||
if cfg != nil {
|
||
if siteURL := strings.TrimSpace(cfg.Server.FrontendURL); siteURL != "" {
|
||
return siteURL
|
||
}
|
||
}
|
||
return service.DefaultEmailCodeConfig().SiteURL
|
||
}
|
||
|
||
func resolveAuthEmailSiteName(cfg *config.Config) string {
|
||
if cfg != nil {
|
||
if siteName := strings.TrimSpace(cfg.Log.ServiceName); siteName != "" {
|
||
return siteName
|
||
}
|
||
}
|
||
return service.DefaultEmailCodeConfig().SiteName
|
||
}
|
||
|
||
func resolveJWTAccessTokenExpire(cfg *config.Config) time.Duration {
|
||
if cfg == nil {
|
||
return 0
|
||
}
|
||
if cfg.JWT.AccessTokenExpireMinutes > 0 {
|
||
return time.Duration(cfg.JWT.AccessTokenExpireMinutes) * time.Minute
|
||
}
|
||
return time.Duration(cfg.JWT.ExpireHour) * time.Hour
|
||
}
|