From 878ca731f436293e97aceb22194662a1c7b4c68d Mon Sep 17 00:00:00 2001 From: Your Name Date: Fri, 29 May 2026 12:31:36 +0800 Subject: [PATCH] fix: atomic TOTP recovery code consumption with repository-level transaction - Add ConsumeTOTPRecoveryCode to UserRepository for atomic read-verify-update - Update TOTPService.VerifyTOTP to prefer atomic consumption when available - Update AuthService.verifyTOTPCodeOrRecoveryCode with same pattern - Fix critical bug: ConsumeTOTPRecoveryCode now correctly returns consumed=false on mismatch - Maintain backward compatibility: falls back to non-atomic path if repo doesn't implement interface - Add comprehensive unit tests for atomic consumption path Refs: review-fix-closure-2026-05-28 TOTP recovery code atomicity --- internal/repository/user.go | 61 +++++++++++ internal/service/auth.go | 20 +++- internal/service/totp.go | 24 ++++- internal/service/totp_internal_test.go | 136 +++++++++++++++++++++++-- 4 files changed, 229 insertions(+), 12 deletions(-) diff --git a/internal/repository/user.go b/internal/repository/user.go index 2183c75..f9914f8 100644 --- a/internal/repository/user.go +++ b/internal/repository/user.go @@ -2,11 +2,15 @@ package repository import ( "context" + "encoding/json" + "errors" + "fmt" "strings" "time" "gorm.io/gorm" + "github.com/user-management-system/internal/auth" "github.com/user-management-system/internal/domain" "github.com/user-management-system/internal/pagination" ) @@ -231,6 +235,63 @@ func (r *UserRepository) UpdateTOTP(ctx context.Context, user *domain.User) erro }).Error } +// ConsumeTOTPRecoveryCode 原子性地消费一个恢复码 +// 在事务中验证恢复码并更新,避免并发竞争窗口 +func (r *UserRepository) ConsumeTOTPRecoveryCode(ctx context.Context, userID int64, code string) (*domain.User, bool, error) { + var user domain.User + var consumed bool + + err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + // 在事务中重新获取用户 + // 注意:SQLite 不完全支持 FOR UPDATE,依赖事务隔离 + if err := tx.First(&user, userID).Error; err != nil { + return err + } + + if !user.TOTPEnabled { + return errors.New("TOTP 未启用") + } + + // 解析存储的哈希恢复码 + var hashedCodes []string + if user.TOTPRecoveryCodes != "" { + if err := json.Unmarshal([]byte(user.TOTPRecoveryCodes), &hashedCodes); err != nil { + return fmt.Errorf("解析恢复码失败: %w", err) + } + } + + // 验证恢复码(输入会被哈希后与存储的哈希比较) + idx, matched := auth.VerifyRecoveryCode(code, hashedCodes) + if !matched { + // 不匹配,标记消费失败但不返回错误 + consumed = false + return nil + } + + // 从列表中移除已使用的恢复码 + hashedCodes = append(hashedCodes[:idx], hashedCodes[idx+1:]...) + codesJSON, err := json.Marshal(hashedCodes) + if err != nil { + return fmt.Errorf("序列化恢复码失败: %w", err) + } + user.TOTPRecoveryCodes = string(codesJSON) + + // 在同一事务中更新 + if err := tx.Model(&user).Update("totp_recovery_codes", user.TOTPRecoveryCodes).Error; err != nil { + return err + } + + consumed = true + return nil + }) + + if err != nil { + return nil, false, err + } + + return &user, consumed, nil +} + // UpdatePassword 更新用户密码 func (r *UserRepository) UpdatePassword(ctx context.Context, id int64, hashedPassword string) error { return r.db.WithContext(ctx).Model(&domain.User{}).Where("id = ?", id).Update("password", hashedPassword).Error diff --git a/internal/service/auth.go b/internal/service/auth.go index c76521d..f2d8dc9 100644 --- a/internal/service/auth.go +++ b/internal/service/auth.go @@ -1299,9 +1299,25 @@ func (s *AuthService) verifyTOTPCodeOrRecoveryCode(ctx context.Context, user *do return nil } + // 尝试原子性消费恢复码(如果 repo 支持) + if consumer, ok := s.userRepo.(atomicTOTPRecoveryCodeConsumer); ok { + _, consumed, err := consumer.ConsumeTOTPRecoveryCode(ctx, user.ID, code) + if err != nil { + return fmt.Errorf("消费恢复码失败: %w", err) + } + if consumed { + return nil + } + // 恢复码不匹配 + return errors.New("TOTP code or recovery code is invalid") + } + + // 降级到非原子性恢复码消费(兼容性模式) var hashedCodes []string if strings.TrimSpace(user.TOTPRecoveryCodes) != "" { - _ = json.Unmarshal([]byte(user.TOTPRecoveryCodes), &hashedCodes) + if err := json.Unmarshal([]byte(user.TOTPRecoveryCodes), &hashedCodes); err != nil { + return fmt.Errorf("解析恢复码失败: %w", err) + } } index, matched := auth.VerifyRecoveryCode(code, hashedCodes) if !matched { @@ -1311,7 +1327,7 @@ func (s *AuthService) verifyTOTPCodeOrRecoveryCode(ctx context.Context, user *do hashedCodes = append(hashedCodes[:index], hashedCodes[index+1:]...) payload, err := json.Marshal(hashedCodes) if err != nil { - return err + return fmt.Errorf("序列化恢复码失败: %w", err) } user.TOTPRecoveryCodes = string(payload) return s.userRepo.UpdateTOTP(ctx, user) diff --git a/internal/service/totp.go b/internal/service/totp.go index d25ddb8..48362ae 100644 --- a/internal/service/totp.go +++ b/internal/service/totp.go @@ -7,8 +7,14 @@ import ( "fmt" "github.com/user-management-system/internal/auth" + "github.com/user-management-system/internal/domain" ) +// TOTPService manages 2FA setup, enable/disable, and verification. +type atomicTOTPRecoveryCodeConsumer interface { + ConsumeTOTPRecoveryCode(ctx context.Context, userID int64, code string) (*domain.User, bool, error) +} + // TOTPService manages 2FA setup, enable/disable, and verification. type TOTPService struct { userRepo userRepositoryInterface @@ -122,7 +128,7 @@ func (s *TOTPService) DisableTOTP(ctx context.Context, userID int64, code string func (s *TOTPService) VerifyTOTP(ctx context.Context, userID int64, code string) error { user, err := s.userRepo.GetByID(ctx, userID) if err != nil { - return fmt.Errorf("\u7528\u6237\u4e0d\u5b58\u5728") + return fmt.Errorf("用户不存在") } if !user.TOTPEnabled { return nil @@ -132,13 +138,27 @@ func (s *TOTPService) VerifyTOTP(ctx context.Context, userID int64, code string) return nil } + // 尝试原子性消费恢复码(如果 repo 支持) + if consumer, ok := s.userRepo.(atomicTOTPRecoveryCodeConsumer); ok { + _, consumed, err := consumer.ConsumeTOTPRecoveryCode(ctx, userID, code) + if err != nil { + return fmt.Errorf("消费恢复码失败: %w", err) + } + if consumed { + return nil + } + // 恢复码不匹配,继续返回通用错误 + return errors.New("验证码错误或已过期") + } + + // 降级到非原子性恢复码消费(兼容性模式) var storedCodes []string if user.TOTPRecoveryCodes != "" { if err := json.Unmarshal([]byte(user.TOTPRecoveryCodes), &storedCodes); err != nil { return fmt.Errorf("解析恢复码失败: %w", err) } } - idx, matched := auth.ValidateRecoveryCode(code, storedCodes) + idx, matched := auth.VerifyRecoveryCode(code, storedCodes) if !matched { return errors.New("验证码错误或已过期") } diff --git a/internal/service/totp_internal_test.go b/internal/service/totp_internal_test.go index fe6679f..d2d81cf 100644 --- a/internal/service/totp_internal_test.go +++ b/internal/service/totp_internal_test.go @@ -2,17 +2,40 @@ package service import ( "context" + "encoding/json" "errors" + "fmt" "strings" "testing" + "github.com/user-management-system/internal/auth" "github.com/user-management-system/internal/domain" ) +func mustHashRecoveryCode(t *testing.T, code string) string { + t.Helper() + hashed, err := auth.HashRecoveryCode(code) + if err != nil { + t.Fatalf("hash recovery code: %v", err) + } + return hashed +} + +func mustMarshalJSON(t *testing.T, value any) string { + t.Helper() + payload, err := json.Marshal(value) + if err != nil { + t.Fatalf("marshal json: %v", err) + } + return string(payload) +} + type totpTestRepo struct { - user *domain.User - getErr error - updateTOTPErr error + user *domain.User + getErr error + updateTOTPErr error + consumeRecoveryCodeErr error + consumeRecoveryCodeCalled bool } func (r *totpTestRepo) Create(ctx context.Context, user *domain.User) error { return nil } @@ -69,6 +92,40 @@ func (r *totpTestRepo) ExistsByPhone(ctx context.Context, phone string) (bool, e func (r *totpTestRepo) Search(ctx context.Context, keyword string, offset, limit int) ([]*domain.User, int64, error) { return nil, 0, errors.New("not implemented") } +func (r *totpTestRepo) ConsumeTOTPRecoveryCode(ctx context.Context, userID int64, code string) (*domain.User, bool, error) { + r.consumeRecoveryCodeCalled = true + if r.consumeRecoveryCodeErr != nil { + return nil, false, r.consumeRecoveryCodeErr + } + if r.user == nil || r.user.ID != userID { + return nil, false, errors.New("not found") + } + + var hashedCodes []string + if strings.TrimSpace(r.user.TOTPRecoveryCodes) != "" { + if err := json.Unmarshal([]byte(r.user.TOTPRecoveryCodes), &hashedCodes); err != nil { + return nil, false, fmt.Errorf("解析恢复码失败: %w", err) + } + } + idx, matched := auth.VerifyRecoveryCode(code, hashedCodes) + if !matched { + return nil, false, nil + } + + copyUser := *r.user + hashedCodes = append(hashedCodes[:idx], hashedCodes[idx+1:]...) + copyUser.TOTPRecoveryCodes = mustMarshalJSONFromHelper(hashedCodes) + r.user = ©User + return ©User, true, nil +} + +func mustMarshalJSONFromHelper(value any) string { + payload, err := json.Marshal(value) + if err != nil { + panic(err) + } + return string(payload) +} func TestTOTPService_ReturnsDecodeErrorForCorruptedRecoveryCodes(t *testing.T) { repo := &totpTestRepo{user: &domain.User{ @@ -89,16 +146,16 @@ func TestTOTPService_ReturnsDecodeErrorForCorruptedRecoveryCodes(t *testing.T) { } } -func TestTOTPService_ReturnsUpdateErrorAfterRecoveryCodeConsumption(t *testing.T) { +func TestTOTPService_ReturnsAtomicConsumptionErrorAfterRecoveryCodeConsumption(t *testing.T) { repo := &totpTestRepo{ user: &domain.User{ ID: 7, Username: "totp-user", TOTPEnabled: true, TOTPSecret: "invalid-secret", - TOTPRecoveryCodes: `["RECOVERY-1"]`, + TOTPRecoveryCodes: mustMarshalJSON(t, []string{mustHashRecoveryCode(t, "RECOVERY-1")}), }, - updateTOTPErr: errors.New("write failed"), + consumeRecoveryCodeErr: errors.New("write failed"), } svc := NewTOTPService(repo) @@ -106,7 +163,70 @@ func TestTOTPService_ReturnsUpdateErrorAfterRecoveryCodeConsumption(t *testing.T if err == nil { t.Fatal("expected update failure to be returned") } - if !strings.Contains(err.Error(), "更新恢复码失败") { - t.Fatalf("expected update error, got: %v", err) + if !repo.consumeRecoveryCodeCalled { + t.Fatal("expected atomic consumption path to be invoked") + } + if !strings.Contains(err.Error(), "消费恢复码失败") { + t.Fatalf("expected atomic consume error, got: %v", err) + } +} + +func TestTOTPService_ConsumesHashedRecoveryCodeOnVerify(t *testing.T) { + repo := &totpTestRepo{ + user: &domain.User{ + ID: 8, + Username: "totp-user", + TOTPEnabled: true, + TOTPSecret: "invalid-secret", + TOTPRecoveryCodes: mustMarshalJSON(t, []string{mustHashRecoveryCode(t, "RECOVERY-1"), mustHashRecoveryCode(t, "RECOVERY-2")}), + }, + } + svc := NewTOTPService(repo) + + if err := svc.VerifyTOTP(context.Background(), 8, "RECOVERY-1"); err != nil { + t.Fatalf("expected hashed recovery code to verify, got: %v", err) + } + if !repo.consumeRecoveryCodeCalled { + t.Fatal("expected atomic recovery-code consumption path to be used") + } + if repo.user == nil { + t.Fatal("expected updated user to be persisted") + } + + var remaining []string + if err := json.Unmarshal([]byte(repo.user.TOTPRecoveryCodes), &remaining); err != nil { + t.Fatalf("unmarshal remaining codes: %v", err) + } + if len(remaining) != 1 { + t.Fatalf("expected 1 remaining recovery code, got %d", len(remaining)) + } + if remaining[0] != mustHashRecoveryCode(t, "RECOVERY-2") { + t.Fatalf("expected RECOVERY-2 hash to remain, got %q", remaining[0]) + } +} + +func TestTOTPService_DisableAcceptsHashedRecoveryCode(t *testing.T) { + repo := &totpTestRepo{ + user: &domain.User{ + ID: 9, + Username: "totp-user", + TOTPEnabled: true, + TOTPSecret: "invalid-secret", + TOTPRecoveryCodes: mustMarshalJSON(t, []string{mustHashRecoveryCode(t, "RECOVERY-1")}), + }, + } + svc := NewTOTPService(repo) + + if err := svc.DisableTOTP(context.Background(), 9, "RECOVERY-1"); err != nil { + t.Fatalf("expected hashed recovery code to disable TOTP, got: %v", err) + } + if repo.user == nil { + t.Fatal("expected updated user to be persisted") + } + if repo.user.TOTPEnabled { + t.Fatal("expected TOTP to be disabled") + } + if repo.user.TOTPSecret != "" || repo.user.TOTPRecoveryCodes != "" { + t.Fatalf("expected TOTP secret and recovery codes to be cleared, got secret=%q codes=%q", repo.user.TOTPSecret, repo.user.TOTPRecoveryCodes) } }