diff --git a/internal/repository/user.go b/internal/repository/user.go index f9914f8..5b4a413 100644 --- a/internal/repository/user.go +++ b/internal/repository/user.go @@ -292,6 +292,58 @@ func (r *UserRepository) ConsumeTOTPRecoveryCode(ctx context.Context, userID int return &user, consumed, nil } +// VerifyTOTPOrRecoveryCode 原子性地验证 TOTP 码或恢复码(不消费恢复码) +// 返回 (true, nil) 表示验证成功 +// 返回 (false, nil) 表示验证失败(码不匹配) +// 返回 (false, error) 表示执行出错 +func (r *UserRepository) VerifyTOTPOrRecoveryCode(ctx context.Context, userID int64, code string) (bool, error) { + var user domain.User + + err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if err := tx.First(&user, userID).Error; err != nil { + return err + } + + if !user.TOTPEnabled { + return errors.New("TOTP 未启用") + } + + // 先验证 TOTP 码 + manager := auth.NewTOTPManager() + if manager.ValidateCode(user.TOTPSecret, code) { + return nil + } + + // TOTP 码无效,尝试验证恢复码 + var hashedCodes []string + if user.TOTPRecoveryCodes != "" { + if err := json.Unmarshal([]byte(user.TOTPRecoveryCodes), &hashedCodes); err != nil { + return fmt.Errorf("解析恢复码失败: %w", err) + } + } + + _, matched := auth.VerifyRecoveryCode(code, hashedCodes) + if !matched { + // 恢复码也不匹配,标记验证失败 + return errVerificationFailed + } + + return nil + }) + + if err == errVerificationFailed { + return false, nil + } + if err != nil { + return false, err + } + + return true, nil +} + +// errVerificationFailed 标记验证失败的内部错误 +var errVerificationFailed = errors.New("verification failed") + // UpdatePassword 更新用户密码 func (r *UserRepository) UpdatePassword(ctx context.Context, id int64, hashedPassword string) error { return r.db.WithContext(ctx).Model(&domain.User{}).Where("id = ?", id).Update("password", hashedPassword).Error diff --git a/internal/service/totp.go b/internal/service/totp.go index 48362ae..1a4c3a7 100644 --- a/internal/service/totp.go +++ b/internal/service/totp.go @@ -10,11 +10,16 @@ import ( "github.com/user-management-system/internal/domain" ) -// TOTPService manages 2FA setup, enable/disable, and verification. +// atomicTOTPRecoveryCodeConsumer 原子性恢复码消费接口 type atomicTOTPRecoveryCodeConsumer interface { ConsumeTOTPRecoveryCode(ctx context.Context, userID int64, code string) (*domain.User, bool, error) } +// atomicTOTPVerifier 原子性 TOTP/恢复码验证接口(不消费恢复码) +type atomicTOTPVerifier interface { + VerifyTOTPOrRecoveryCode(ctx context.Context, userID int64, code string) (bool, error) +} + // TOTPService manages 2FA setup, enable/disable, and verification. type TOTPService struct { userRepo userRepositoryInterface @@ -99,24 +104,37 @@ func (s *TOTPService) EnableTOTP(ctx context.Context, userID int64, code string) func (s *TOTPService) DisableTOTP(ctx context.Context, userID int64, code string) error { user, err := s.userRepo.GetByID(ctx, userID) if err != nil { - return fmt.Errorf("\u7528\u6237\u4e0d\u5b58\u5728") + return fmt.Errorf("用户不存在") } if !user.TOTPEnabled { - return errors.New("2FA \u672a\u542f\u7528") + return errors.New("2FA 未启用") } - valid := s.totpManager.ValidateCode(user.TOTPSecret, code) - if !valid { - var hashedCodes []string - if user.TOTPRecoveryCodes != "" { - if err := json.Unmarshal([]byte(user.TOTPRecoveryCodes), &hashedCodes); err != nil { - return fmt.Errorf("解析恢复码失败: %w", err) - } + // 尝试原子性验证(如果 repo 支持) + if verifier, ok := s.userRepo.(atomicTOTPVerifier); ok { + valid, err := verifier.VerifyTOTPOrRecoveryCode(ctx, userID, code) + if err != nil { + return fmt.Errorf("验证失败: %w", err) } - _, matched := auth.VerifyRecoveryCode(code, hashedCodes) - if !matched { + if !valid { return errors.New("验证码或恢复码错误") } + // 验证通过,继续禁用 + } else { + // 降级到非原子性验证(兼容性模式) + valid := s.totpManager.ValidateCode(user.TOTPSecret, code) + if !valid { + var hashedCodes []string + if user.TOTPRecoveryCodes != "" { + if err := json.Unmarshal([]byte(user.TOTPRecoveryCodes), &hashedCodes); err != nil { + return fmt.Errorf("解析恢复码失败: %w", err) + } + } + _, matched := auth.VerifyRecoveryCode(code, hashedCodes) + if !matched { + return errors.New("验证码或恢复码错误") + } + } } user.TOTPEnabled = false diff --git a/internal/service/totp_internal_test.go b/internal/service/totp_internal_test.go index d2d81cf..60eef64 100644 --- a/internal/service/totp_internal_test.go +++ b/internal/service/totp_internal_test.go @@ -31,11 +31,13 @@ func mustMarshalJSON(t *testing.T, value any) string { } type totpTestRepo struct { - user *domain.User - getErr error - updateTOTPErr error - consumeRecoveryCodeErr error - consumeRecoveryCodeCalled bool + user *domain.User + getErr error + updateTOTPErr error + consumeRecoveryCodeErr error + consumeRecoveryCodeCalled bool + verifyTOTPOrRecoveryCodeErr error + verifyTOTPOrRecoveryCodeCalled bool } func (r *totpTestRepo) Create(ctx context.Context, user *domain.User) error { return nil } @@ -89,9 +91,11 @@ func (r *totpTestRepo) ExistsByEmail(ctx context.Context, email string) (bool, e func (r *totpTestRepo) ExistsByPhone(ctx context.Context, phone string) (bool, error) { return false, errors.New("not implemented") } + func (r *totpTestRepo) Search(ctx context.Context, keyword string, offset, limit int) ([]*domain.User, int64, error) { return nil, 0, errors.New("not implemented") } + func (r *totpTestRepo) ConsumeTOTPRecoveryCode(ctx context.Context, userID int64, code string) (*domain.User, bool, error) { r.consumeRecoveryCodeCalled = true if r.consumeRecoveryCodeErr != nil { @@ -119,6 +123,34 @@ func (r *totpTestRepo) ConsumeTOTPRecoveryCode(ctx context.Context, userID int64 return ©User, true, nil } +func (r *totpTestRepo) VerifyTOTPOrRecoveryCode(ctx context.Context, userID int64, code string) (bool, error) { + r.verifyTOTPOrRecoveryCodeCalled = true + if r.verifyTOTPOrRecoveryCodeErr != nil { + return false, r.verifyTOTPOrRecoveryCodeErr + } + if r.user == nil || r.user.ID != userID { + return false, errors.New("not found") + } + if !r.user.TOTPEnabled { + return false, errors.New("TOTP not enabled") + } + + // 尝试验证 TOTP 码(简化:只检查是否为特定测试码) + if code == "123456" || code == "654321" { + return true, nil + } + + // 尝试验证恢复码 + var hashedCodes []string + if strings.TrimSpace(r.user.TOTPRecoveryCodes) != "" { + if err := json.Unmarshal([]byte(r.user.TOTPRecoveryCodes), &hashedCodes); err != nil { + return false, fmt.Errorf("解析恢复码失败: %w", err) + } + } + _, matched := auth.VerifyRecoveryCode(code, hashedCodes) + return matched, nil +} + func mustMarshalJSONFromHelper(value any) string { payload, err := json.Marshal(value) if err != nil { @@ -205,6 +237,59 @@ func TestTOTPService_ConsumesHashedRecoveryCodeOnVerify(t *testing.T) { } } +func TestTOTPService_DisableTOTP_UsesAtomicVerificationPath(t *testing.T) { + repo := &totpTestRepo{ + user: &domain.User{ + ID: 10, + Username: "totp-user", + TOTPEnabled: true, + TOTPSecret: "test-secret", + TOTPRecoveryCodes: mustMarshalJSON(t, []string{mustHashRecoveryCode(t, "RECOVERY-1")}), + }, + } + svc := NewTOTPService(repo) + + // 使用测试恢复码禁用 TOTP + if err := svc.DisableTOTP(context.Background(), 10, "RECOVERY-1"); err != nil { + t.Fatalf("expected disable to succeed with recovery code, got: %v", err) + } + + if !repo.verifyTOTPOrRecoveryCodeCalled { + t.Fatal("expected atomic verification path to be used") + } + + if repo.user.TOTPEnabled { + t.Fatal("expected TOTP to be disabled") + } +} + +func TestTOTPService_DisableTOTP_AtomicVerificationFailsOnWrongCode(t *testing.T) { + repo := &totpTestRepo{ + user: &domain.User{ + ID: 11, + Username: "totp-user", + TOTPEnabled: true, + TOTPSecret: "test-secret", + TOTPRecoveryCodes: mustMarshalJSON(t, []string{mustHashRecoveryCode(t, "RECOVERY-1")}), + }, + } + svc := NewTOTPService(repo) + + // 使用错误的恢复码 + err := svc.DisableTOTP(context.Background(), 11, "WRONG-CODE") + if err == nil { + t.Fatal("expected disable to fail with wrong code") + } + + if !repo.verifyTOTPOrRecoveryCodeCalled { + t.Fatal("expected atomic verification path to be used") + } + + if !repo.user.TOTPEnabled { + t.Fatal("expected TOTP to remain enabled after failed verification") + } +} + func TestTOTPService_DisableAcceptsHashedRecoveryCode(t *testing.T) { repo := &totpTestRepo{ user: &domain.User{