diff --git a/internal/api/handler/device_handler.go b/internal/api/handler/device_handler.go index fc0cf1f..321dd9e 100644 --- a/internal/api/handler/device_handler.go +++ b/internal/api/handler/device_handler.go @@ -79,6 +79,9 @@ func (h *DeviceHandler) GetMyDevices(c *gin.Context) { page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) + if pageSize < 1 || pageSize > 100 { + pageSize = 20 + } devices, total, err := h.deviceService.GetUserDevices(c.Request.Context(), userID, page, pageSize) if err != nil { @@ -293,6 +296,9 @@ func (h *DeviceHandler) GetUserDevices(c *gin.Context) { page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) + if pageSize < 1 || pageSize > 100 { + pageSize = 20 + } devices, total, err := h.deviceService.GetUserDevices(c.Request.Context(), userID, page, pageSize) if err != nil { diff --git a/internal/api/handler/export_handler.go b/internal/api/handler/export_handler.go index e6bbedd..ee76458 100644 --- a/internal/api/handler/export_handler.go +++ b/internal/api/handler/export_handler.go @@ -63,7 +63,8 @@ func (h *ExportHandler) ExportUsers(c *gin.Context) { data, filename, contentType, err := h.exportService.ExportUsers(c.Request.Context(), req) if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"code": 500, "message": "导出失败: " + err.Error()}) + // 安全修复:不泄露内部错误详情 + c.JSON(http.StatusInternalServerError, gin.H{"code": 500, "message": "导出失败"}) return } diff --git a/internal/api/handler/log_handler.go b/internal/api/handler/log_handler.go index 5867ab1..65c2092 100644 --- a/internal/api/handler/log_handler.go +++ b/internal/api/handler/log_handler.go @@ -44,6 +44,9 @@ func (h *LogHandler) GetMyLoginLogs(c *gin.Context) { page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) + if pageSize < 1 || pageSize > 100 { + pageSize = 20 + } logs, total, err := h.loginLogService.GetMyLoginLogs(c.Request.Context(), userID, page, pageSize) if err != nil { @@ -83,6 +86,9 @@ func (h *LogHandler) GetMyOperationLogs(c *gin.Context) { page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) + if pageSize < 1 || pageSize > 100 { + pageSize = 20 + } logs, total, err := h.operationLogService.GetMyOperationLogs(c.Request.Context(), userID, page, pageSize) if err != nil { diff --git a/internal/api/middleware/auth.go b/internal/api/middleware/auth.go index 7cd6f6e..e499648 100644 --- a/internal/api/middleware/auth.go +++ b/internal/api/middleware/auth.go @@ -74,6 +74,12 @@ func (m *AuthMiddleware) Required() gin.HandlerFunc { return } + if m.isPasswordChangedSinceTokenIssued(c.Request.Context(), claims.UserID, claims.PCE) { + c.JSON(http.StatusUnauthorized, apierrors.New(http.StatusUnauthorized, "UNAUTHORIZED", "密码已更新,请重新登录")) + c.Abort() + return + } + if !m.isUserActive(c.Request.Context(), claims.UserID) { c.JSON(http.StatusUnauthorized, apierrors.New(http.StatusUnauthorized, "UNAUTHORIZED", "账号不可用,请重新登录")) c.Abort() @@ -97,7 +103,7 @@ func (m *AuthMiddleware) Optional() gin.HandlerFunc { token := m.extractToken(c) if token != "" { claims, err := m.jwt.ValidateAccessToken(token) - if err == nil && !m.isJTIBlacklisted(c.Request.Context(), claims.JTI) && m.isUserActive(c.Request.Context(), claims.UserID) { + if err == nil && !m.isJTIBlacklisted(c.Request.Context(), claims.JTI) && !m.isPasswordChangedSinceTokenIssued(c.Request.Context(), claims.UserID, claims.PCE) && m.isUserActive(c.Request.Context(), claims.UserID) { c.Set("user_id", claims.UserID) c.Set("username", claims.Username) c.Set("token_jti", claims.JTI) @@ -140,6 +146,27 @@ func (m *AuthMiddleware) isJTIBlacklisted(ctx context.Context, jti string) bool return false } +// isPasswordChangedSinceTokenIssued 检查用户密码是否在令牌发放后已更改 +// 如果 tokenPCE 为 0(旧令牌),则不检查(向后兼容) +func (m *AuthMiddleware) isPasswordChangedSinceTokenIssued(ctx context.Context, userID int64, tokenPCE int64) bool { + if tokenPCE == 0 { + // 旧令牌没有密码变更时间戳,不拦截 + return false + } + + if m.userRepo == nil { + return false + } + + user, err := m.userRepo.GetByID(ctx, userID) + if err != nil || user.PasswordChangedAt.IsZero() { + return false + } + + // 如果令牌的 PCE < 用户密码变更时间,说明密码在令牌发放后已更改 + return tokenPCE < user.PasswordChangedAt.Unix() +} + func (m *AuthMiddleware) loadUserRolesAndPerms(ctx context.Context, userID int64) ([]string, []string) { if m.userRoleRepo == nil { return nil, nil diff --git a/internal/api/middleware/error.go b/internal/api/middleware/error.go index d1f86b9..47ee0d5 100644 --- a/internal/api/middleware/error.go +++ b/internal/api/middleware/error.go @@ -22,7 +22,9 @@ func ErrorHandler() gin.HandlerFunc { if appErr, ok := err.Err.(*apierrors.ApplicationError); ok { c.JSON(int(appErr.Code), appErr) } else { - c.JSON(http.StatusInternalServerError, apierrors.New(http.StatusInternalServerError, "", err.Err.Error())) + // 安全修复:未知错误不泄露内部详情,只返回通用消息 + // 详细错误记录到日志,供调试使用 + c.JSON(http.StatusInternalServerError, apierrors.New(http.StatusInternalServerError, "", "服务器内部错误")) } return } diff --git a/internal/auth/jwt.go b/internal/auth/jwt.go index bc9bbd6..f9d661d 100644 --- a/internal/auth/jwt.go +++ b/internal/auth/jwt.go @@ -53,6 +53,7 @@ type Claims struct { Type string `json:"type"` // access, refresh Remember bool `json:"remember,omitempty"` // 记住登录标记 JTI string `json:"jti"` // JWT ID,用于黑名单 + PCE int64 `json:"pce,omitempty"` // Password Changed Epoch,密码变更时间戳,用于 token 失效机制 jwt.RegisteredClaims } @@ -318,8 +319,8 @@ func (j *JWT) GetAlgorithm() string { return j.algorithm } -// GenerateAccessToken 生成访问令牌(含JTI) -func (j *JWT) GenerateAccessToken(userID int64, username string) (string, error) { +// GenerateAccessToken 生成访问令牌(含JTI和密码变更时间戳) +func (j *JWT) GenerateAccessToken(userID int64, username string, pce int64) (string, error) { if err := j.ensureReady(); err != nil { return "", err } @@ -334,6 +335,7 @@ func (j *JWT) GenerateAccessToken(userID int64, username string) (string, error) Username: username, Type: "access", JTI: jti, + PCE: pce, RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(now.Add(j.accessTokenExpire)), IssuedAt: jwt.NewNumericDate(now), @@ -345,8 +347,8 @@ func (j *JWT) GenerateAccessToken(userID int64, username string) (string, error) return token.SignedString(j.signingKey()) } -// GenerateRefreshToken 生成刷新令牌(含JTI) -func (j *JWT) GenerateRefreshToken(userID int64, username string) (string, error) { +// GenerateRefreshToken 生成刷新令牌(含JTI和密码变更时间戳) +func (j *JWT) GenerateRefreshToken(userID int64, username string, pce int64) (string, error) { if err := j.ensureReady(); err != nil { return "", err } @@ -361,6 +363,7 @@ func (j *JWT) GenerateRefreshToken(userID int64, username string) (string, error Username: username, Type: "refresh", JTI: jti, + PCE: pce, RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(now.Add(j.refreshTokenExpire)), IssuedAt: jwt.NewNumericDate(now), @@ -382,14 +385,14 @@ func (j *JWT) GetRefreshTokenExpire() time.Duration { return j.refreshTokenExpire } -// GenerateTokenPair 生成令牌对 -func (j *JWT) GenerateTokenPair(userID int64, username string) (accessToken, refreshToken string, err error) { - accessToken, err = j.GenerateAccessToken(userID, username) +// GenerateTokenPair 生成令牌对(含密码变更时间戳) +func (j *JWT) GenerateTokenPair(userID int64, username string, pce int64) (accessToken, refreshToken string, err error) { + accessToken, err = j.GenerateAccessToken(userID, username, pce) if err != nil { return "", "", err } - refreshToken, err = j.GenerateRefreshToken(userID, username) + refreshToken, err = j.GenerateRefreshToken(userID, username, pce) if err != nil { return "", "", err } @@ -397,17 +400,17 @@ func (j *JWT) GenerateTokenPair(userID int64, username string) (accessToken, ref return accessToken, refreshToken, nil } -// GenerateTokenPairWithRemember 生成令牌对(支持记住登录) -func (j *JWT) GenerateTokenPairWithRemember(userID int64, username string, remember bool) (accessToken, refreshToken string, err error) { - accessToken, err = j.GenerateAccessToken(userID, username) +// GenerateTokenPairWithRemember 生成令牌对(支持记住登录,含密码变更时间戳) +func (j *JWT) GenerateTokenPairWithRemember(userID int64, username string, remember bool, pce int64) (accessToken, refreshToken string, err error) { + accessToken, err = j.GenerateAccessToken(userID, username, pce) if err != nil { return "", "", err } if remember { - refreshToken, err = j.GenerateLongLivedRefreshToken(userID, username) + refreshToken, err = j.GenerateLongLivedRefreshToken(userID, username, pce) } else { - refreshToken, err = j.GenerateRefreshToken(userID, username) + refreshToken, err = j.GenerateRefreshToken(userID, username, pce) } if err != nil { return "", "", err @@ -416,8 +419,8 @@ func (j *JWT) GenerateTokenPairWithRemember(userID int64, username string, remem return accessToken, refreshToken, nil } -// GenerateLongLivedRefreshToken 生成长期刷新令牌(记住登录时使用) -func (j *JWT) GenerateLongLivedRefreshToken(userID int64, username string) (string, error) { +// GenerateLongLivedRefreshToken 生成长期刷新令牌(记住登录时使用,含密码变更时间戳) +func (j *JWT) GenerateLongLivedRefreshToken(userID int64, username string, pce int64) (string, error) { if err := j.ensureReady(); err != nil { return "", err } @@ -440,6 +443,7 @@ func (j *JWT) GenerateLongLivedRefreshToken(userID int64, username string) (stri Type: "refresh", Remember: true, // 长期会话标记 JTI: jti, + PCE: pce, RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(now.Add(expireDuration)), IssuedAt: jwt.NewNumericDate(now), @@ -506,5 +510,5 @@ func (j *JWT) RefreshAccessToken(refreshTokenString string) (string, error) { return "", err } - return j.GenerateAccessToken(claims.UserID, claims.Username) + return j.GenerateAccessToken(claims.UserID, claims.Username, claims.PCE) } diff --git a/internal/auth/jwt_closure_test.go b/internal/auth/jwt_closure_test.go index be20fad..0d3558f 100644 --- a/internal/auth/jwt_closure_test.go +++ b/internal/auth/jwt_closure_test.go @@ -15,7 +15,7 @@ func TestNewJWT_DoesNotPanicOnInvalidLegacyConfig(t *testing.T) { t.Fatal("expected manager instance") } - if _, err := manager.GenerateAccessToken(1, "tester"); err == nil { + if _, err := manager.GenerateAccessToken(1, "tester", 0); err == nil { t.Fatal("expected invalid legacy manager to return error") } } diff --git a/internal/auth/jwt_password_test.go b/internal/auth/jwt_password_test.go index 306ed1c..87f2b8f 100644 --- a/internal/auth/jwt_password_test.go +++ b/internal/auth/jwt_password_test.go @@ -43,7 +43,7 @@ func TestNewJWTWithOptions_RS256(t *testing.T) { t.Fatalf("create rs256 jwt manager failed: %v", err) } - accessToken, refreshToken, err := jwtManager.GenerateTokenPair(42, "rs256-user") + accessToken, refreshToken, err := jwtManager.GenerateTokenPair(42, "rs256-user", 0) if err != nil { t.Fatalf("generate token pair failed: %v", err) } @@ -136,7 +136,7 @@ func TestGenerateAccessToken_Success(t *testing.T) { t.Fatalf("create jwt manager failed: %v", err) } - token, err := jwtManager.GenerateAccessToken(123, "testuser") + token, err := jwtManager.GenerateAccessToken(123, "testuser", 0) if err != nil { t.Fatalf("generate access token failed: %v", err) } @@ -170,7 +170,7 @@ func TestGenerateRefreshToken_Success(t *testing.T) { t.Fatalf("create jwt manager failed: %v", err) } - token, err := jwtManager.GenerateRefreshToken(456, "refreshuser") + token, err := jwtManager.GenerateRefreshToken(456, "refreshuser", 0) if err != nil { t.Fatalf("generate refresh token failed: %v", err) } @@ -201,7 +201,7 @@ func TestGenerateTokenPair_Success(t *testing.T) { t.Fatalf("create jwt manager failed: %v", err) } - accessToken, refreshToken, err := jwtManager.GenerateTokenPair(789, "pairuser") + accessToken, refreshToken, err := jwtManager.GenerateTokenPair(789, "pairuser", 0) if err != nil { t.Fatalf("generate token pair failed: %v", err) } @@ -238,7 +238,7 @@ func TestGenerateTokenPairWithRemember_Success(t *testing.T) { t.Fatalf("create jwt manager failed: %v", err) } - accessToken, refreshToken, err := jwtManager.GenerateTokenPairWithRemember(999, "rememberuser", true) + accessToken, refreshToken, err := jwtManager.GenerateTokenPairWithRemember(999, "rememberuser", true, 0) if err != nil { t.Fatalf("generate token pair with remember failed: %v", err) } @@ -275,7 +275,7 @@ func TestValidateAccessToken_WrongType(t *testing.T) { } // Use a refresh token as if it were an access token - refreshToken, err := jwtManager.GenerateRefreshToken(123, "testuser") + refreshToken, err := jwtManager.GenerateRefreshToken(123, "testuser", 0) if err != nil { t.Fatalf("generate refresh token failed: %v", err) } @@ -298,7 +298,7 @@ func TestValidateRefreshToken_WrongType(t *testing.T) { } // Use an access token as if it were a refresh token - accessToken, err := jwtManager.GenerateAccessToken(123, "testuser") + accessToken, err := jwtManager.GenerateAccessToken(123, "testuser", 0) if err != nil { t.Fatalf("generate access token failed: %v", err) } @@ -389,7 +389,7 @@ func TestGenerateLongLivedRefreshToken_Success(t *testing.T) { t.Fatalf("create jwt manager failed: %v", err) } - token, err := jwtManager.GenerateLongLivedRefreshToken(123, "longliveuser") + token, err := jwtManager.GenerateLongLivedRefreshToken(123, "longliveuser", 0) if err != nil { t.Fatalf("generate long lived refresh token failed: %v", err) } @@ -446,7 +446,7 @@ func TestRefreshAccessToken_Success(t *testing.T) { } // Generate a valid refresh token first - refreshToken, err := jwtManager.GenerateRefreshToken(123, "testuser") + refreshToken, err := jwtManager.GenerateRefreshToken(123, "testuser", 0) if err != nil { t.Fatalf("generate refresh token failed: %v", err) } @@ -498,7 +498,7 @@ func TestRefreshAccessToken_AccessTokenProvided(t *testing.T) { } // Generate an access token and try to use it as refresh - accessToken, err := jwtManager.GenerateAccessToken(123, "testuser") + accessToken, err := jwtManager.GenerateAccessToken(123, "testuser", 0) if err != nil { t.Fatalf("generate access token failed: %v", err) } @@ -521,7 +521,7 @@ func TestGenerateTokenPairWithRemember_RememberFalse(t *testing.T) { t.Fatalf("create jwt manager failed: %v", err) } - accessToken, refreshToken, err := jwtManager.GenerateTokenPairWithRemember(123, "testuser", false) + accessToken, refreshToken, err := jwtManager.GenerateTokenPairWithRemember(123, "testuser", false, 0) if err != nil { t.Fatalf("GenerateTokenPairWithRemember failed: %v", err) } @@ -553,7 +553,7 @@ func TestGenerateTokenPairWithRemember_NoRememberExpireConfig(t *testing.T) { } // Should use RefreshTokenExpire when RememberLoginExpire is not set - accessToken, refreshToken, err := jwtManager.GenerateTokenPairWithRemember(123, "testuser", true) + accessToken, refreshToken, err := jwtManager.GenerateTokenPairWithRemember(123, "testuser", true, 0) if err != nil { t.Fatalf("GenerateTokenPairWithRemember failed: %v", err) } @@ -583,7 +583,7 @@ func TestGenerateLongLivedRefreshToken_NoRememberExpire(t *testing.T) { t.Fatalf("create jwt manager failed: %v", err) } - token, err := jwtManager.GenerateLongLivedRefreshToken(123, "testuser") + token, err := jwtManager.GenerateLongLivedRefreshToken(123, "testuser", 0) if err != nil { t.Fatalf("GenerateLongLivedRefreshToken failed: %v", err) } diff --git a/internal/database/db.go b/internal/database/db.go index cbccf29..e7acb14 100644 --- a/internal/database/db.go +++ b/internal/database/db.go @@ -59,11 +59,29 @@ func NewDB(cfg *config.Config) (*DB, error) { log.Printf("warn: set busy_timeout failed: %v", err) } - // 连接池配置:SQLite 本身不支持真正的并发写,但需要控制连接数量 - sqlDB.SetMaxOpenConns(10) - sqlDB.SetMaxIdleConns(5) - sqlDB.SetConnMaxLifetime(30 * time.Minute) - sqlDB.SetConnMaxIdleTime(10 * time.Minute) + // 连接池配置:使用配置文件中的参数 + maxOpenConns := 10 + maxIdleConns := 5 + connMaxLifetime := 30 * time.Minute + connMaxIdleTime := 10 * time.Minute + if cfg != nil { + if cfg.Database.MaxOpenConns > 0 { + maxOpenConns = cfg.Database.MaxOpenConns + } + if cfg.Database.MaxIdleConns > 0 { + maxIdleConns = cfg.Database.MaxIdleConns + } + if cfg.Database.ConnMaxLifetimeMinutes > 0 { + connMaxLifetime = time.Duration(cfg.Database.ConnMaxLifetimeMinutes) * time.Minute + } + if cfg.Database.ConnMaxIdleTimeMinutes > 0 { + connMaxIdleTime = time.Duration(cfg.Database.ConnMaxIdleTimeMinutes) * time.Minute + } + } + sqlDB.SetMaxOpenConns(maxOpenConns) + sqlDB.SetMaxIdleConns(maxIdleConns) + sqlDB.SetConnMaxLifetime(connMaxLifetime) + sqlDB.SetConnMaxIdleTime(connMaxIdleTime) log.Println("database: SQLite WAL mode enabled, connection pool configured") diff --git a/internal/domain/announcement.go b/internal/domain/announcement.go index cbf4e3b..6e50d9a 100644 --- a/internal/domain/announcement.go +++ b/internal/domain/announcement.go @@ -200,18 +200,18 @@ func (c AnnouncementCondition) validate() error { } type Announcement struct { - ID int64 - Title string - Content string - Status string - NotifyMode string - Targeting AnnouncementTargeting - StartsAt *time.Time - EndsAt *time.Time - CreatedBy *int64 - UpdatedBy *int64 - CreatedAt time.Time - UpdatedAt time.Time + ID int64 `gorm:"primaryKey;autoIncrement" json:"id"` + Title string `gorm:"type:varchar(255);not null" json:"title"` + Content string `gorm:"type:text;not null" json:"content"` + Status string `gorm:"type:varchar(20);default:draft;index" json:"status"` + NotifyMode string `gorm:"type:varchar(20);default:silent" json:"notify_mode"` + Targeting AnnouncementTargeting `gorm:"type:text" json:"targeting"` + StartsAt *time.Time `gorm:"type:datetime" json:"starts_at,omitempty"` + EndsAt *time.Time `gorm:"type:datetime" json:"ends_at,omitempty"` + CreatedBy *int64 `json:"created_by,omitempty"` + UpdatedBy *int64 `json:"updated_by,omitempty"` + CreatedAt time.Time `gorm:"autoCreateTime" json:"created_at"` + UpdatedAt time.Time `gorm:"autoUpdateTime" json:"updated_at"` } func (a *Announcement) IsActiveAt(now time.Time) bool { diff --git a/internal/domain/user.go b/internal/domain/user.go index c2ce76a..073d807 100644 --- a/internal/domain/user.go +++ b/internal/domain/user.go @@ -62,6 +62,9 @@ type User struct { TOTPEnabled bool `gorm:"default:false" json:"totp_enabled"` TOTPSecret string `gorm:"type:varchar(64)" json:"-"` // Base32 密钥,不返回给前端 TOTPRecoveryCodes string `gorm:"type:text" json:"-"` // JSON 编码的恢复码列表 + + // PasswordChangedAt 密码更新时间,用于 token 失效机制 + PasswordChangedAt time.Time `gorm:"type:timestamp;index" json:"password_changed_at,omitempty"` } // TableName 指定表名 diff --git a/internal/repository/login_log.go b/internal/repository/login_log.go index d2a6bdb..8d9bff1 100644 --- a/internal/repository/login_log.go +++ b/internal/repository/login_log.go @@ -104,16 +104,19 @@ func (r *LoginLogRepository) DeleteOlderThan(ctx context.Context, days int) erro // CountByResultSince 统计指定时间之后特定结果的登录次数 // success=true 统计成功次数,false 统计失败次数 -func (r *LoginLogRepository) CountByResultSince(ctx context.Context, success bool, since time.Time) int64 { +func (r *LoginLogRepository) CountByResultSince(ctx context.Context, success bool, since time.Time) (int64, error) { status := 0 // 失败 if success { status = 1 // 成功 } var count int64 - r.db.WithContext(ctx).Model(&domain.LoginLog{}). + err := r.db.WithContext(ctx).Model(&domain.LoginLog{}). Where("status = ? AND created_at >= ?", status, since). - Count(&count) - return count + Count(&count).Error + if err != nil { + return 0, err + } + return count, nil } // ListAllForExport 获取所有登录日志(用于导出,无分页) diff --git a/internal/repository/repository_additional_test.go b/internal/repository/repository_additional_test.go index 1369ed0..9ada003 100644 --- a/internal/repository/repository_additional_test.go +++ b/internal/repository/repository_additional_test.go @@ -263,10 +263,14 @@ func TestLoginLogRepositoryQueriesAndRetention(t *testing.T) { t.Fatalf("expected 2 recent logs, got total=%d len=%d", total, len(recentLogs)) } - if count := repo.CountByResultSince(ctx, true, now.Add(-2*time.Hour)); count != 1 { + if count, err := repo.CountByResultSince(ctx, true, now.Add(-2*time.Hour)); err != nil { + t.Fatalf("CountByResultSince failed: %v", err) + } else if count != 1 { t.Fatalf("expected 1 recent success login, got %d", count) } - if count := repo.CountByResultSince(ctx, false, now.Add(-2*time.Hour)); count != 1 { + if count, err := repo.CountByResultSince(ctx, false, now.Add(-2*time.Hour)); err != nil { + t.Fatalf("CountByResultSince failed: %v", err) + } else if count != 1 { t.Fatalf("expected 1 recent failed login, got %d", count) } diff --git a/internal/repository/role.go b/internal/repository/role.go index 7d3bced..cbb5d6a 100644 --- a/internal/repository/role.go +++ b/internal/repository/role.go @@ -48,6 +48,18 @@ func (r *RoleRepository) Delete(ctx context.Context, id int64) error { return r.db.WithContext(ctx).Delete(&domain.Role{}, id).Error } +// DeleteCascade 级联删除角色(同时删除角色权限关联) +func (r *RoleRepository) DeleteCascade(ctx context.Context, id int64) error { + return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + // 先删除角色权限关联 + if err := tx.Where("role_id = ?", id).Delete(&domain.RolePermission{}).Error; err != nil { + return err + } + // 再删除角色 + return tx.Delete(&domain.Role{}, id).Error + }) +} + // GetByID 根据ID获取角色 func (r *RoleRepository) GetByID(ctx context.Context, id int64) (*domain.Role, error) { var role domain.Role diff --git a/internal/repository/social_account_repo.go b/internal/repository/social_account_repo.go index 88ec785..1cbc6c7 100644 --- a/internal/repository/social_account_repo.go +++ b/internal/repository/social_account_repo.go @@ -204,6 +204,9 @@ func (r *SocialAccountRepositoryImpl) GetByUserID(ctx context.Context, userID in } accounts = append(accounts, &account) } + if err := rows.Err(); err != nil { + return nil, err + } return accounts, nil } @@ -290,6 +293,9 @@ func (r *SocialAccountRepositoryImpl) List(ctx context.Context, offset, limit in } accounts = append(accounts, &account) } + if err := rows.Err(); err != nil { + return nil, 0, err + } return accounts, total, nil } diff --git a/internal/repository/theme.go b/internal/repository/theme.go index e6492fd..6855a7f 100644 --- a/internal/repository/theme.go +++ b/internal/repository/theme.go @@ -89,11 +89,13 @@ func (r *ThemeConfigRepository) ListAll(ctx context.Context) ([]*domain.ThemeCon // SetDefault 设置默认主题 func (r *ThemeConfigRepository) SetDefault(ctx context.Context, id int64) error { - // 先清除所有默认标记 - if err := r.db.WithContext(ctx).Model(&domain.ThemeConfig{}).Where("is_default = ?", true).Update("is_default", false).Error; err != nil { - return err - } - - // 设置新的默认主题 - return r.db.WithContext(ctx).Model(&domain.ThemeConfig{}).Where("id = ?", id).Update("is_default", true).Error + // 使用事务确保原子性:先清除所有默认标记,再设置新默认 + return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + // 先清除所有默认标记 + if err := tx.Model(&domain.ThemeConfig{}).Where("is_default = ?", true).Update("is_default", false).Error; err != nil { + return err + } + // 设置新的默认主题 + return tx.Model(&domain.ThemeConfig{}).Where("id = ?", id).Update("is_default", true).Error + }) } diff --git a/internal/repository/user.go b/internal/repository/user.go index 386561e..2183c75 100644 --- a/internal/repository/user.go +++ b/internal/repository/user.go @@ -326,8 +326,9 @@ func (r *UserRepository) AdvancedSearch(ctx context.Context, filter *AdvancedFil sortBy = filter.SortBy } } - if filter.SortOrder == "asc" { - sortOrder = "ASC" + allowedSortOrders := map[string]bool{"asc": true, "desc": true} + if allowedSortOrders[strings.ToLower(filter.SortOrder)] { + sortOrder = strings.ToUpper(filter.SortOrder) } query = query.Order(sortBy + " " + sortOrder) @@ -404,8 +405,9 @@ func (r *UserRepository) ListCursor(ctx context.Context, filter *AdvancedFilter, } sortOrder := "DESC" - if filter.SortOrder == "asc" { - sortOrder = "ASC" + allowedSortOrders := map[string]bool{"asc": true, "desc": true} + if allowedSortOrders[strings.ToLower(filter.SortOrder)] { + sortOrder = strings.ToUpper(filter.SortOrder) } orderClause := sortBy + " " + sortOrder + ", id " + sortOrder diff --git a/internal/service/auth.go b/internal/service/auth.go index 4f3c387..788e09a 100644 --- a/internal/service/auth.go +++ b/internal/service/auth.go @@ -1369,10 +1369,12 @@ func (s *AuthService) generateLoginResponse(ctx context.Context, user *domain.Us var accessToken, refreshToken string var err error + pce := user.PasswordChangedAt.Unix() + if remember { - accessToken, refreshToken, err = s.jwtManager.GenerateTokenPairWithRemember(user.ID, user.Username, remember) + accessToken, refreshToken, err = s.jwtManager.GenerateTokenPairWithRemember(user.ID, user.Username, remember, pce) } else { - accessToken, refreshToken, err = s.jwtManager.GenerateTokenPair(user.ID, user.Username) + accessToken, refreshToken, err = s.jwtManager.GenerateTokenPair(user.ID, user.Username, pce) } if err != nil { return nil, err diff --git a/internal/service/role.go b/internal/service/role.go index f19da7d..f0fd77f 100644 --- a/internal/service/role.go +++ b/internal/service/role.go @@ -181,13 +181,8 @@ func (s *RoleService) DeleteRole(ctx context.Context, roleID int64) error { return errors.New("存在子角色,无法删除") } - // 删除角色权限关联 - if err := s.rolePermissionRepo.DeleteByRoleID(ctx, roleID); err != nil { - return err - } - - // 删除角色 - return s.roleRepo.Delete(ctx, roleID) + // 级联删除角色及其权限关联(在事务中执行) + return s.roleRepo.DeleteCascade(ctx, roleID) } // GetRole 获取角色信息 diff --git a/internal/service/stats.go b/internal/service/stats.go index 4d73b0e..4b53507 100644 --- a/internal/service/stats.go +++ b/internal/service/stats.go @@ -15,7 +15,7 @@ type statsUserRepository interface { } type statsLoginLogRepository interface { - CountByResultSince(ctx context.Context, success bool, since time.Time) int64 + CountByResultSince(ctx context.Context, success bool, since time.Time) (int64, error) } // StatsService 统计服务 @@ -115,9 +115,15 @@ func (s *StatsService) GetDashboardStats(ctx context.Context) (*DashboardStats, // 今日登录成功/失败 today := daysAgo(0) if s.loginLogRepo != nil { - loginStats.LoginsTodaySuccess = s.loginLogRepo.CountByResultSince(ctx, true, today) - loginStats.LoginsTodayFailed = s.loginLogRepo.CountByResultSince(ctx, false, today) - loginStats.LoginsWeek = s.loginLogRepo.CountByResultSince(ctx, true, daysAgo(7)) + if successCount, err := s.loginLogRepo.CountByResultSince(ctx, true, today); err == nil { + loginStats.LoginsTodaySuccess = successCount + } + if failedCount, err := s.loginLogRepo.CountByResultSince(ctx, false, today); err == nil { + loginStats.LoginsTodayFailed = failedCount + } + if weekCount, err := s.loginLogRepo.CountByResultSince(ctx, true, daysAgo(7)); err == nil { + loginStats.LoginsWeek = weekCount + } } return &DashboardStats{ diff --git a/internal/service/stats_internal_test.go b/internal/service/stats_internal_test.go index a7b629f..1254d95 100644 --- a/internal/service/stats_internal_test.go +++ b/internal/service/stats_internal_test.go @@ -51,11 +51,11 @@ type mockStatsLoginLogRepoInternal struct { weekCount int64 } -func (m *mockStatsLoginLogRepoInternal) CountByResultSince(ctx context.Context, success bool, since time.Time) int64 { +func (m *mockStatsLoginLogRepoInternal) CountByResultSince(ctx context.Context, success bool, since time.Time) (int64, error) { if success { - return m.successCount + return m.successCount, nil } - return m.failedCount + return m.failedCount, nil } func TestStatsService_GetDashboardStats_Internal(t *testing.T) { diff --git a/internal/service/stats_test.go b/internal/service/stats_test.go index d1cadf9..a379500 100644 --- a/internal/service/stats_test.go +++ b/internal/service/stats_test.go @@ -52,11 +52,11 @@ type mockStatsLoginLogRepo struct { weekCount int64 } -func (m *mockStatsLoginLogRepo) CountByResultSince(ctx context.Context, success bool, since time.Time) int64 { +func (m *mockStatsLoginLogRepo) CountByResultSince(ctx context.Context, success bool, since time.Time) (int64, error) { if success { - return m.successCount + return m.successCount, nil } - return m.failedCount + return m.failedCount, nil } func TestStatsService_GetUserStats(t *testing.T) { diff --git a/internal/service/user_service.go b/internal/service/user_service.go index 5abed7f..dee3ab6 100644 --- a/internal/service/user_service.go +++ b/internal/service/user_service.go @@ -141,6 +141,7 @@ func (s *UserService) ChangePassword(ctx context.Context, userID int64, oldPassw // 更新密码(使用同一哈希值) user.Password = newHashedPassword + user.PasswordChangedAt = time.Now() return s.userRepo.Update(ctx, user) }