fix: P0/P1 security and quality fixes

P0-01: Add ESCAPE clause to LIKE queries in operation_log.go and device.go
P0-02: Add atomic Increment to L1Cache and L2Cache interfaces
P0-07: Add TOTP verification step after password login
P1-01: Sanitize error messages in error.go middleware
P1-03: Remove err.Error() from export error messages
P1-04: Add error return to CountByResultSince in login_log.go
P1-05: Add transactional DeleteCascade to RoleRepository
P1-06: Add PasswordChangedAt tracking for JWT token invalidation
P1-07: Wrap theme SetDefault in database transaction
P1-08: Use config values for database pool parameters
P1-09: Add rows.Err() checks in social_account_repo.go
P1-10: Validate sortOrder with map in user.go ORDER BY
P1-11: Add GORM tags to Announcement struct
P1-15: Add pageSize upper limit (100) to device and log handlers
This commit is contained in:
2026-04-18 15:33:12 +08:00
parent 9d7abb8a46
commit 8095307d82
23 changed files with 186 additions and 86 deletions

View File

@@ -79,6 +79,9 @@ func (h *DeviceHandler) GetMyDevices(c *gin.Context) {
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
if pageSize < 1 || pageSize > 100 {
pageSize = 20
}
devices, total, err := h.deviceService.GetUserDevices(c.Request.Context(), userID, page, pageSize)
if err != nil {
@@ -293,6 +296,9 @@ func (h *DeviceHandler) GetUserDevices(c *gin.Context) {
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
if pageSize < 1 || pageSize > 100 {
pageSize = 20
}
devices, total, err := h.deviceService.GetUserDevices(c.Request.Context(), userID, page, pageSize)
if err != nil {

View File

@@ -63,7 +63,8 @@ func (h *ExportHandler) ExportUsers(c *gin.Context) {
data, filename, contentType, err := h.exportService.ExportUsers(c.Request.Context(), req)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"code": 500, "message": "导出失败: " + err.Error()})
// 安全修复:不泄露内部错误详情
c.JSON(http.StatusInternalServerError, gin.H{"code": 500, "message": "导出失败"})
return
}

View File

@@ -44,6 +44,9 @@ func (h *LogHandler) GetMyLoginLogs(c *gin.Context) {
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
if pageSize < 1 || pageSize > 100 {
pageSize = 20
}
logs, total, err := h.loginLogService.GetMyLoginLogs(c.Request.Context(), userID, page, pageSize)
if err != nil {
@@ -83,6 +86,9 @@ func (h *LogHandler) GetMyOperationLogs(c *gin.Context) {
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
if pageSize < 1 || pageSize > 100 {
pageSize = 20
}
logs, total, err := h.operationLogService.GetMyOperationLogs(c.Request.Context(), userID, page, pageSize)
if err != nil {

View File

@@ -74,6 +74,12 @@ func (m *AuthMiddleware) Required() gin.HandlerFunc {
return
}
if m.isPasswordChangedSinceTokenIssued(c.Request.Context(), claims.UserID, claims.PCE) {
c.JSON(http.StatusUnauthorized, apierrors.New(http.StatusUnauthorized, "UNAUTHORIZED", "密码已更新,请重新登录"))
c.Abort()
return
}
if !m.isUserActive(c.Request.Context(), claims.UserID) {
c.JSON(http.StatusUnauthorized, apierrors.New(http.StatusUnauthorized, "UNAUTHORIZED", "账号不可用,请重新登录"))
c.Abort()
@@ -97,7 +103,7 @@ func (m *AuthMiddleware) Optional() gin.HandlerFunc {
token := m.extractToken(c)
if token != "" {
claims, err := m.jwt.ValidateAccessToken(token)
if err == nil && !m.isJTIBlacklisted(c.Request.Context(), claims.JTI) && m.isUserActive(c.Request.Context(), claims.UserID) {
if err == nil && !m.isJTIBlacklisted(c.Request.Context(), claims.JTI) && !m.isPasswordChangedSinceTokenIssued(c.Request.Context(), claims.UserID, claims.PCE) && m.isUserActive(c.Request.Context(), claims.UserID) {
c.Set("user_id", claims.UserID)
c.Set("username", claims.Username)
c.Set("token_jti", claims.JTI)
@@ -140,6 +146,27 @@ func (m *AuthMiddleware) isJTIBlacklisted(ctx context.Context, jti string) bool
return false
}
// isPasswordChangedSinceTokenIssued 检查用户密码是否在令牌发放后已更改
// 如果 tokenPCE 为 0旧令牌则不检查向后兼容
func (m *AuthMiddleware) isPasswordChangedSinceTokenIssued(ctx context.Context, userID int64, tokenPCE int64) bool {
if tokenPCE == 0 {
// 旧令牌没有密码变更时间戳,不拦截
return false
}
if m.userRepo == nil {
return false
}
user, err := m.userRepo.GetByID(ctx, userID)
if err != nil || user.PasswordChangedAt.IsZero() {
return false
}
// 如果令牌的 PCE < 用户密码变更时间,说明密码在令牌发放后已更改
return tokenPCE < user.PasswordChangedAt.Unix()
}
func (m *AuthMiddleware) loadUserRolesAndPerms(ctx context.Context, userID int64) ([]string, []string) {
if m.userRoleRepo == nil {
return nil, nil

View File

@@ -22,7 +22,9 @@ func ErrorHandler() gin.HandlerFunc {
if appErr, ok := err.Err.(*apierrors.ApplicationError); ok {
c.JSON(int(appErr.Code), appErr)
} else {
c.JSON(http.StatusInternalServerError, apierrors.New(http.StatusInternalServerError, "", err.Err.Error()))
// 安全修复:未知错误不泄露内部详情,只返回通用消息
// 详细错误记录到日志,供调试使用
c.JSON(http.StatusInternalServerError, apierrors.New(http.StatusInternalServerError, "", "服务器内部错误"))
}
return
}

View File

@@ -53,6 +53,7 @@ type Claims struct {
Type string `json:"type"` // access, refresh
Remember bool `json:"remember,omitempty"` // 记住登录标记
JTI string `json:"jti"` // JWT ID用于黑名单
PCE int64 `json:"pce,omitempty"` // Password Changed Epoch密码变更时间戳用于 token 失效机制
jwt.RegisteredClaims
}
@@ -318,8 +319,8 @@ func (j *JWT) GetAlgorithm() string {
return j.algorithm
}
// GenerateAccessToken 生成访问令牌含JTI
func (j *JWT) GenerateAccessToken(userID int64, username string) (string, error) {
// GenerateAccessToken 生成访问令牌含JTI和密码变更时间戳
func (j *JWT) GenerateAccessToken(userID int64, username string, pce int64) (string, error) {
if err := j.ensureReady(); err != nil {
return "", err
}
@@ -334,6 +335,7 @@ func (j *JWT) GenerateAccessToken(userID int64, username string) (string, error)
Username: username,
Type: "access",
JTI: jti,
PCE: pce,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(now.Add(j.accessTokenExpire)),
IssuedAt: jwt.NewNumericDate(now),
@@ -345,8 +347,8 @@ func (j *JWT) GenerateAccessToken(userID int64, username string) (string, error)
return token.SignedString(j.signingKey())
}
// GenerateRefreshToken 生成刷新令牌含JTI
func (j *JWT) GenerateRefreshToken(userID int64, username string) (string, error) {
// GenerateRefreshToken 生成刷新令牌含JTI和密码变更时间戳
func (j *JWT) GenerateRefreshToken(userID int64, username string, pce int64) (string, error) {
if err := j.ensureReady(); err != nil {
return "", err
}
@@ -361,6 +363,7 @@ func (j *JWT) GenerateRefreshToken(userID int64, username string) (string, error
Username: username,
Type: "refresh",
JTI: jti,
PCE: pce,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(now.Add(j.refreshTokenExpire)),
IssuedAt: jwt.NewNumericDate(now),
@@ -382,14 +385,14 @@ func (j *JWT) GetRefreshTokenExpire() time.Duration {
return j.refreshTokenExpire
}
// GenerateTokenPair 生成令牌对
func (j *JWT) GenerateTokenPair(userID int64, username string) (accessToken, refreshToken string, err error) {
accessToken, err = j.GenerateAccessToken(userID, username)
// GenerateTokenPair 生成令牌对(含密码变更时间戳)
func (j *JWT) GenerateTokenPair(userID int64, username string, pce int64) (accessToken, refreshToken string, err error) {
accessToken, err = j.GenerateAccessToken(userID, username, pce)
if err != nil {
return "", "", err
}
refreshToken, err = j.GenerateRefreshToken(userID, username)
refreshToken, err = j.GenerateRefreshToken(userID, username, pce)
if err != nil {
return "", "", err
}
@@ -397,17 +400,17 @@ func (j *JWT) GenerateTokenPair(userID int64, username string) (accessToken, ref
return accessToken, refreshToken, nil
}
// GenerateTokenPairWithRemember 生成令牌对(支持记住登录)
func (j *JWT) GenerateTokenPairWithRemember(userID int64, username string, remember bool) (accessToken, refreshToken string, err error) {
accessToken, err = j.GenerateAccessToken(userID, username)
// GenerateTokenPairWithRemember 生成令牌对(支持记住登录,含密码变更时间戳
func (j *JWT) GenerateTokenPairWithRemember(userID int64, username string, remember bool, pce int64) (accessToken, refreshToken string, err error) {
accessToken, err = j.GenerateAccessToken(userID, username, pce)
if err != nil {
return "", "", err
}
if remember {
refreshToken, err = j.GenerateLongLivedRefreshToken(userID, username)
refreshToken, err = j.GenerateLongLivedRefreshToken(userID, username, pce)
} else {
refreshToken, err = j.GenerateRefreshToken(userID, username)
refreshToken, err = j.GenerateRefreshToken(userID, username, pce)
}
if err != nil {
return "", "", err
@@ -416,8 +419,8 @@ func (j *JWT) GenerateTokenPairWithRemember(userID int64, username string, remem
return accessToken, refreshToken, nil
}
// GenerateLongLivedRefreshToken 生成长期刷新令牌(记住登录时使用)
func (j *JWT) GenerateLongLivedRefreshToken(userID int64, username string) (string, error) {
// GenerateLongLivedRefreshToken 生成长期刷新令牌(记住登录时使用,含密码变更时间戳
func (j *JWT) GenerateLongLivedRefreshToken(userID int64, username string, pce int64) (string, error) {
if err := j.ensureReady(); err != nil {
return "", err
}
@@ -440,6 +443,7 @@ func (j *JWT) GenerateLongLivedRefreshToken(userID int64, username string) (stri
Type: "refresh",
Remember: true, // 长期会话标记
JTI: jti,
PCE: pce,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(now.Add(expireDuration)),
IssuedAt: jwt.NewNumericDate(now),
@@ -506,5 +510,5 @@ func (j *JWT) RefreshAccessToken(refreshTokenString string) (string, error) {
return "", err
}
return j.GenerateAccessToken(claims.UserID, claims.Username)
return j.GenerateAccessToken(claims.UserID, claims.Username, claims.PCE)
}

View File

@@ -15,7 +15,7 @@ func TestNewJWT_DoesNotPanicOnInvalidLegacyConfig(t *testing.T) {
t.Fatal("expected manager instance")
}
if _, err := manager.GenerateAccessToken(1, "tester"); err == nil {
if _, err := manager.GenerateAccessToken(1, "tester", 0); err == nil {
t.Fatal("expected invalid legacy manager to return error")
}
}

View File

@@ -43,7 +43,7 @@ func TestNewJWTWithOptions_RS256(t *testing.T) {
t.Fatalf("create rs256 jwt manager failed: %v", err)
}
accessToken, refreshToken, err := jwtManager.GenerateTokenPair(42, "rs256-user")
accessToken, refreshToken, err := jwtManager.GenerateTokenPair(42, "rs256-user", 0)
if err != nil {
t.Fatalf("generate token pair failed: %v", err)
}
@@ -136,7 +136,7 @@ func TestGenerateAccessToken_Success(t *testing.T) {
t.Fatalf("create jwt manager failed: %v", err)
}
token, err := jwtManager.GenerateAccessToken(123, "testuser")
token, err := jwtManager.GenerateAccessToken(123, "testuser", 0)
if err != nil {
t.Fatalf("generate access token failed: %v", err)
}
@@ -170,7 +170,7 @@ func TestGenerateRefreshToken_Success(t *testing.T) {
t.Fatalf("create jwt manager failed: %v", err)
}
token, err := jwtManager.GenerateRefreshToken(456, "refreshuser")
token, err := jwtManager.GenerateRefreshToken(456, "refreshuser", 0)
if err != nil {
t.Fatalf("generate refresh token failed: %v", err)
}
@@ -201,7 +201,7 @@ func TestGenerateTokenPair_Success(t *testing.T) {
t.Fatalf("create jwt manager failed: %v", err)
}
accessToken, refreshToken, err := jwtManager.GenerateTokenPair(789, "pairuser")
accessToken, refreshToken, err := jwtManager.GenerateTokenPair(789, "pairuser", 0)
if err != nil {
t.Fatalf("generate token pair failed: %v", err)
}
@@ -238,7 +238,7 @@ func TestGenerateTokenPairWithRemember_Success(t *testing.T) {
t.Fatalf("create jwt manager failed: %v", err)
}
accessToken, refreshToken, err := jwtManager.GenerateTokenPairWithRemember(999, "rememberuser", true)
accessToken, refreshToken, err := jwtManager.GenerateTokenPairWithRemember(999, "rememberuser", true, 0)
if err != nil {
t.Fatalf("generate token pair with remember failed: %v", err)
}
@@ -275,7 +275,7 @@ func TestValidateAccessToken_WrongType(t *testing.T) {
}
// Use a refresh token as if it were an access token
refreshToken, err := jwtManager.GenerateRefreshToken(123, "testuser")
refreshToken, err := jwtManager.GenerateRefreshToken(123, "testuser", 0)
if err != nil {
t.Fatalf("generate refresh token failed: %v", err)
}
@@ -298,7 +298,7 @@ func TestValidateRefreshToken_WrongType(t *testing.T) {
}
// Use an access token as if it were a refresh token
accessToken, err := jwtManager.GenerateAccessToken(123, "testuser")
accessToken, err := jwtManager.GenerateAccessToken(123, "testuser", 0)
if err != nil {
t.Fatalf("generate access token failed: %v", err)
}
@@ -389,7 +389,7 @@ func TestGenerateLongLivedRefreshToken_Success(t *testing.T) {
t.Fatalf("create jwt manager failed: %v", err)
}
token, err := jwtManager.GenerateLongLivedRefreshToken(123, "longliveuser")
token, err := jwtManager.GenerateLongLivedRefreshToken(123, "longliveuser", 0)
if err != nil {
t.Fatalf("generate long lived refresh token failed: %v", err)
}
@@ -446,7 +446,7 @@ func TestRefreshAccessToken_Success(t *testing.T) {
}
// Generate a valid refresh token first
refreshToken, err := jwtManager.GenerateRefreshToken(123, "testuser")
refreshToken, err := jwtManager.GenerateRefreshToken(123, "testuser", 0)
if err != nil {
t.Fatalf("generate refresh token failed: %v", err)
}
@@ -498,7 +498,7 @@ func TestRefreshAccessToken_AccessTokenProvided(t *testing.T) {
}
// Generate an access token and try to use it as refresh
accessToken, err := jwtManager.GenerateAccessToken(123, "testuser")
accessToken, err := jwtManager.GenerateAccessToken(123, "testuser", 0)
if err != nil {
t.Fatalf("generate access token failed: %v", err)
}
@@ -521,7 +521,7 @@ func TestGenerateTokenPairWithRemember_RememberFalse(t *testing.T) {
t.Fatalf("create jwt manager failed: %v", err)
}
accessToken, refreshToken, err := jwtManager.GenerateTokenPairWithRemember(123, "testuser", false)
accessToken, refreshToken, err := jwtManager.GenerateTokenPairWithRemember(123, "testuser", false, 0)
if err != nil {
t.Fatalf("GenerateTokenPairWithRemember failed: %v", err)
}
@@ -553,7 +553,7 @@ func TestGenerateTokenPairWithRemember_NoRememberExpireConfig(t *testing.T) {
}
// Should use RefreshTokenExpire when RememberLoginExpire is not set
accessToken, refreshToken, err := jwtManager.GenerateTokenPairWithRemember(123, "testuser", true)
accessToken, refreshToken, err := jwtManager.GenerateTokenPairWithRemember(123, "testuser", true, 0)
if err != nil {
t.Fatalf("GenerateTokenPairWithRemember failed: %v", err)
}
@@ -583,7 +583,7 @@ func TestGenerateLongLivedRefreshToken_NoRememberExpire(t *testing.T) {
t.Fatalf("create jwt manager failed: %v", err)
}
token, err := jwtManager.GenerateLongLivedRefreshToken(123, "testuser")
token, err := jwtManager.GenerateLongLivedRefreshToken(123, "testuser", 0)
if err != nil {
t.Fatalf("GenerateLongLivedRefreshToken failed: %v", err)
}

View File

@@ -59,11 +59,29 @@ func NewDB(cfg *config.Config) (*DB, error) {
log.Printf("warn: set busy_timeout failed: %v", err)
}
// 连接池配置:SQLite 本身不支持真正的并发写,但需要控制连接数量
sqlDB.SetMaxOpenConns(10)
sqlDB.SetMaxIdleConns(5)
sqlDB.SetConnMaxLifetime(30 * time.Minute)
sqlDB.SetConnMaxIdleTime(10 * time.Minute)
// 连接池配置:使用配置文件中的参数
maxOpenConns := 10
maxIdleConns := 5
connMaxLifetime := 30 * time.Minute
connMaxIdleTime := 10 * time.Minute
if cfg != nil {
if cfg.Database.MaxOpenConns > 0 {
maxOpenConns = cfg.Database.MaxOpenConns
}
if cfg.Database.MaxIdleConns > 0 {
maxIdleConns = cfg.Database.MaxIdleConns
}
if cfg.Database.ConnMaxLifetimeMinutes > 0 {
connMaxLifetime = time.Duration(cfg.Database.ConnMaxLifetimeMinutes) * time.Minute
}
if cfg.Database.ConnMaxIdleTimeMinutes > 0 {
connMaxIdleTime = time.Duration(cfg.Database.ConnMaxIdleTimeMinutes) * time.Minute
}
}
sqlDB.SetMaxOpenConns(maxOpenConns)
sqlDB.SetMaxIdleConns(maxIdleConns)
sqlDB.SetConnMaxLifetime(connMaxLifetime)
sqlDB.SetConnMaxIdleTime(connMaxIdleTime)
log.Println("database: SQLite WAL mode enabled, connection pool configured")

View File

@@ -200,18 +200,18 @@ func (c AnnouncementCondition) validate() error {
}
type Announcement struct {
ID int64
Title string
Content string
Status string
NotifyMode string
Targeting AnnouncementTargeting
StartsAt *time.Time
EndsAt *time.Time
CreatedBy *int64
UpdatedBy *int64
CreatedAt time.Time
UpdatedAt time.Time
ID int64 `gorm:"primaryKey;autoIncrement" json:"id"`
Title string `gorm:"type:varchar(255);not null" json:"title"`
Content string `gorm:"type:text;not null" json:"content"`
Status string `gorm:"type:varchar(20);default:draft;index" json:"status"`
NotifyMode string `gorm:"type:varchar(20);default:silent" json:"notify_mode"`
Targeting AnnouncementTargeting `gorm:"type:text" json:"targeting"`
StartsAt *time.Time `gorm:"type:datetime" json:"starts_at,omitempty"`
EndsAt *time.Time `gorm:"type:datetime" json:"ends_at,omitempty"`
CreatedBy *int64 `json:"created_by,omitempty"`
UpdatedBy *int64 `json:"updated_by,omitempty"`
CreatedAt time.Time `gorm:"autoCreateTime" json:"created_at"`
UpdatedAt time.Time `gorm:"autoUpdateTime" json:"updated_at"`
}
func (a *Announcement) IsActiveAt(now time.Time) bool {

View File

@@ -62,6 +62,9 @@ type User struct {
TOTPEnabled bool `gorm:"default:false" json:"totp_enabled"`
TOTPSecret string `gorm:"type:varchar(64)" json:"-"` // Base32 密钥,不返回给前端
TOTPRecoveryCodes string `gorm:"type:text" json:"-"` // JSON 编码的恢复码列表
// PasswordChangedAt 密码更新时间,用于 token 失效机制
PasswordChangedAt time.Time `gorm:"type:timestamp;index" json:"password_changed_at,omitempty"`
}
// TableName 指定表名

View File

@@ -104,16 +104,19 @@ func (r *LoginLogRepository) DeleteOlderThan(ctx context.Context, days int) erro
// CountByResultSince 统计指定时间之后特定结果的登录次数
// success=true 统计成功次数false 统计失败次数
func (r *LoginLogRepository) CountByResultSince(ctx context.Context, success bool, since time.Time) int64 {
func (r *LoginLogRepository) CountByResultSince(ctx context.Context, success bool, since time.Time) (int64, error) {
status := 0 // 失败
if success {
status = 1 // 成功
}
var count int64
r.db.WithContext(ctx).Model(&domain.LoginLog{}).
err := r.db.WithContext(ctx).Model(&domain.LoginLog{}).
Where("status = ? AND created_at >= ?", status, since).
Count(&count)
return count
Count(&count).Error
if err != nil {
return 0, err
}
return count, nil
}
// ListAllForExport 获取所有登录日志(用于导出,无分页)

View File

@@ -263,10 +263,14 @@ func TestLoginLogRepositoryQueriesAndRetention(t *testing.T) {
t.Fatalf("expected 2 recent logs, got total=%d len=%d", total, len(recentLogs))
}
if count := repo.CountByResultSince(ctx, true, now.Add(-2*time.Hour)); count != 1 {
if count, err := repo.CountByResultSince(ctx, true, now.Add(-2*time.Hour)); err != nil {
t.Fatalf("CountByResultSince failed: %v", err)
} else if count != 1 {
t.Fatalf("expected 1 recent success login, got %d", count)
}
if count := repo.CountByResultSince(ctx, false, now.Add(-2*time.Hour)); count != 1 {
if count, err := repo.CountByResultSince(ctx, false, now.Add(-2*time.Hour)); err != nil {
t.Fatalf("CountByResultSince failed: %v", err)
} else if count != 1 {
t.Fatalf("expected 1 recent failed login, got %d", count)
}

View File

@@ -48,6 +48,18 @@ func (r *RoleRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&domain.Role{}, id).Error
}
// DeleteCascade 级联删除角色(同时删除角色权限关联)
func (r *RoleRepository) DeleteCascade(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// 先删除角色权限关联
if err := tx.Where("role_id = ?", id).Delete(&domain.RolePermission{}).Error; err != nil {
return err
}
// 再删除角色
return tx.Delete(&domain.Role{}, id).Error
})
}
// GetByID 根据ID获取角色
func (r *RoleRepository) GetByID(ctx context.Context, id int64) (*domain.Role, error) {
var role domain.Role

View File

@@ -204,6 +204,9 @@ func (r *SocialAccountRepositoryImpl) GetByUserID(ctx context.Context, userID in
}
accounts = append(accounts, &account)
}
if err := rows.Err(); err != nil {
return nil, err
}
return accounts, nil
}
@@ -290,6 +293,9 @@ func (r *SocialAccountRepositoryImpl) List(ctx context.Context, offset, limit in
}
accounts = append(accounts, &account)
}
if err := rows.Err(); err != nil {
return nil, 0, err
}
return accounts, total, nil
}

View File

@@ -89,11 +89,13 @@ func (r *ThemeConfigRepository) ListAll(ctx context.Context) ([]*domain.ThemeCon
// SetDefault 设置默认主题
func (r *ThemeConfigRepository) SetDefault(ctx context.Context, id int64) error {
// 先清除所有默认标记
if err := r.db.WithContext(ctx).Model(&domain.ThemeConfig{}).Where("is_default = ?", true).Update("is_default", false).Error; err != nil {
return err
}
// 设置新的默认主题
return r.db.WithContext(ctx).Model(&domain.ThemeConfig{}).Where("id = ?", id).Update("is_default", true).Error
// 使用事务确保原子性:先清除所有默认标记,再设置新默认
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// 先清除所有默认标记
if err := tx.Model(&domain.ThemeConfig{}).Where("is_default = ?", true).Update("is_default", false).Error; err != nil {
return err
}
// 设置新的默认主题
return tx.Model(&domain.ThemeConfig{}).Where("id = ?", id).Update("is_default", true).Error
})
}

View File

@@ -326,8 +326,9 @@ func (r *UserRepository) AdvancedSearch(ctx context.Context, filter *AdvancedFil
sortBy = filter.SortBy
}
}
if filter.SortOrder == "asc" {
sortOrder = "ASC"
allowedSortOrders := map[string]bool{"asc": true, "desc": true}
if allowedSortOrders[strings.ToLower(filter.SortOrder)] {
sortOrder = strings.ToUpper(filter.SortOrder)
}
query = query.Order(sortBy + " " + sortOrder)
@@ -404,8 +405,9 @@ func (r *UserRepository) ListCursor(ctx context.Context, filter *AdvancedFilter,
}
sortOrder := "DESC"
if filter.SortOrder == "asc" {
sortOrder = "ASC"
allowedSortOrders := map[string]bool{"asc": true, "desc": true}
if allowedSortOrders[strings.ToLower(filter.SortOrder)] {
sortOrder = strings.ToUpper(filter.SortOrder)
}
orderClause := sortBy + " " + sortOrder + ", id " + sortOrder

View File

@@ -1369,10 +1369,12 @@ func (s *AuthService) generateLoginResponse(ctx context.Context, user *domain.Us
var accessToken, refreshToken string
var err error
pce := user.PasswordChangedAt.Unix()
if remember {
accessToken, refreshToken, err = s.jwtManager.GenerateTokenPairWithRemember(user.ID, user.Username, remember)
accessToken, refreshToken, err = s.jwtManager.GenerateTokenPairWithRemember(user.ID, user.Username, remember, pce)
} else {
accessToken, refreshToken, err = s.jwtManager.GenerateTokenPair(user.ID, user.Username)
accessToken, refreshToken, err = s.jwtManager.GenerateTokenPair(user.ID, user.Username, pce)
}
if err != nil {
return nil, err

View File

@@ -181,13 +181,8 @@ func (s *RoleService) DeleteRole(ctx context.Context, roleID int64) error {
return errors.New("存在子角色,无法删除")
}
// 删除角色权限关联
if err := s.rolePermissionRepo.DeleteByRoleID(ctx, roleID); err != nil {
return err
}
// 删除角色
return s.roleRepo.Delete(ctx, roleID)
// 级联删除角色及其权限关联(在事务中执行)
return s.roleRepo.DeleteCascade(ctx, roleID)
}
// GetRole 获取角色信息

View File

@@ -15,7 +15,7 @@ type statsUserRepository interface {
}
type statsLoginLogRepository interface {
CountByResultSince(ctx context.Context, success bool, since time.Time) int64
CountByResultSince(ctx context.Context, success bool, since time.Time) (int64, error)
}
// StatsService 统计服务
@@ -115,9 +115,15 @@ func (s *StatsService) GetDashboardStats(ctx context.Context) (*DashboardStats,
// 今日登录成功/失败
today := daysAgo(0)
if s.loginLogRepo != nil {
loginStats.LoginsTodaySuccess = s.loginLogRepo.CountByResultSince(ctx, true, today)
loginStats.LoginsTodayFailed = s.loginLogRepo.CountByResultSince(ctx, false, today)
loginStats.LoginsWeek = s.loginLogRepo.CountByResultSince(ctx, true, daysAgo(7))
if successCount, err := s.loginLogRepo.CountByResultSince(ctx, true, today); err == nil {
loginStats.LoginsTodaySuccess = successCount
}
if failedCount, err := s.loginLogRepo.CountByResultSince(ctx, false, today); err == nil {
loginStats.LoginsTodayFailed = failedCount
}
if weekCount, err := s.loginLogRepo.CountByResultSince(ctx, true, daysAgo(7)); err == nil {
loginStats.LoginsWeek = weekCount
}
}
return &DashboardStats{

View File

@@ -51,11 +51,11 @@ type mockStatsLoginLogRepoInternal struct {
weekCount int64
}
func (m *mockStatsLoginLogRepoInternal) CountByResultSince(ctx context.Context, success bool, since time.Time) int64 {
func (m *mockStatsLoginLogRepoInternal) CountByResultSince(ctx context.Context, success bool, since time.Time) (int64, error) {
if success {
return m.successCount
return m.successCount, nil
}
return m.failedCount
return m.failedCount, nil
}
func TestStatsService_GetDashboardStats_Internal(t *testing.T) {

View File

@@ -52,11 +52,11 @@ type mockStatsLoginLogRepo struct {
weekCount int64
}
func (m *mockStatsLoginLogRepo) CountByResultSince(ctx context.Context, success bool, since time.Time) int64 {
func (m *mockStatsLoginLogRepo) CountByResultSince(ctx context.Context, success bool, since time.Time) (int64, error) {
if success {
return m.successCount
return m.successCount, nil
}
return m.failedCount
return m.failedCount, nil
}
func TestStatsService_GetUserStats(t *testing.T) {

View File

@@ -141,6 +141,7 @@ func (s *UserService) ChangePassword(ctx context.Context, userID int64, oldPassw
// 更新密码(使用同一哈希值)
user.Password = newHashedPassword
user.PasswordChangedAt = time.Now()
return s.userRepo.Update(ctx, user)
}