From 9cc589256516098fefcbe40c5ebd1e4d85d03d58 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 28 May 2026 20:38:34 +0800 Subject: [PATCH] fix: tighten password and surface persistence errors --- internal/api/middleware/operation_log.go | 9 +- internal/api/middleware/operation_log_test.go | 59 +++++++++ internal/service/auth.go | 4 + .../service/auth_password_internal_test.go | 12 ++ internal/service/auth_service_test.go | 11 +- internal/service/totp.go | 32 +++-- internal/service/totp_internal_test.go | 112 ++++++++++++++++++ 7 files changed, 228 insertions(+), 11 deletions(-) create mode 100644 internal/api/middleware/operation_log_test.go create mode 100644 internal/service/auth_password_internal_test.go create mode 100644 internal/service/totp_internal_test.go diff --git a/internal/api/middleware/operation_log.go b/internal/api/middleware/operation_log.go index c01a344..fb2a59d 100644 --- a/internal/api/middleware/operation_log.go +++ b/internal/api/middleware/operation_log.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "io" + "log" "time" "github.com/gin-gonic/gin" @@ -87,10 +88,16 @@ func (m *OperationLogMiddleware) Record() gin.HandlerFunc { UserAgent: c.Request.UserAgent(), } + if m == nil || m.repo == nil { + return + } + go func(entry *domain.OperationLog) { ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() - _ = m.repo.Create(ctx, entry) + if err := m.repo.Create(ctx, entry); err != nil { + log.Printf("[operation-log] create failed: %v", err) + } }(logEntry) } } diff --git a/internal/api/middleware/operation_log_test.go b/internal/api/middleware/operation_log_test.go new file mode 100644 index 0000000..1861887 --- /dev/null +++ b/internal/api/middleware/operation_log_test.go @@ -0,0 +1,59 @@ +package middleware + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gin-gonic/gin" +) + +func TestOperationLogRecord_AllowsNilRepository(t *testing.T) { + gin.SetMode(gin.TestMode) + + router := gin.New() + router.Use((&OperationLogMiddleware{}).Record()) + router.POST("/operation-log", func(c *gin.Context) { + c.JSON(http.StatusCreated, gin.H{"ok": true}) + }) + + body := bytes.NewBufferString(`{"password":"secret","token":"abc"}`) + req := httptest.NewRequest(http.MethodPost, "/operation-log", body) + req.Header.Set("Content-Type", "application/json") + recorder := httptest.NewRecorder() + + router.ServeHTTP(recorder, req) + + if recorder.Code != http.StatusCreated { + t.Fatalf("unexpected status: got %d want %d", recorder.Code, http.StatusCreated) + } +} + +func TestSanitizeParams_MasksSensitiveFields(t *testing.T) { + sanitized := sanitizeParams([]byte(`{"password":"secret","nested":"ok","token":"abc"}`)) + + var payload map[string]any + if err := json.Unmarshal([]byte(sanitized), &payload); err != nil { + t.Fatalf("sanitized payload should remain valid json: %v", err) + } + if payload["password"] != "***" { + t.Fatalf("password should be masked, got: %#v", payload["password"]) + } + if payload["token"] != "***" { + t.Fatalf("token should be masked, got: %#v", payload["token"]) + } +} + +func TestSanitizeParams_FallbacksForNonJSONPayload(t *testing.T) { + longText := strings.Repeat("x", 600) + sanitized := sanitizeParams([]byte(longText)) + if len(sanitized) != 503 { + t.Fatalf("expected truncated fallback length 503, got %d", len(sanitized)) + } + if !strings.HasSuffix(sanitized, "...") { + t.Fatalf("expected truncated fallback to end with ellipsis: %q", sanitized[len(sanitized)-3:]) + } +} diff --git a/internal/service/auth.go b/internal/service/auth.go index d045f4f..c76521d 100644 --- a/internal/service/auth.go +++ b/internal/service/auth.go @@ -306,6 +306,10 @@ func validatePasswordStrength(password string, minLength int, strict bool) error return nil } + if info.Length <= minLength && info.Score < 3 { + return errors.New("密码强度不足,短密码需至少包含三种字符类型") + } + if info.Score < 2 { return errors.New("密码强度不足") } diff --git a/internal/service/auth_password_internal_test.go b/internal/service/auth_password_internal_test.go new file mode 100644 index 0000000..a70686a --- /dev/null +++ b/internal/service/auth_password_internal_test.go @@ -0,0 +1,12 @@ +package service + +import "testing" + +func TestValidatePasswordStrengthBoundaryRules(t *testing.T) { + t.Run("accepts boundary password with three character classes", func(t *testing.T) { + err := validatePasswordStrength("Abcd1234", defaultPasswordMinLen, false) + if err != nil { + t.Fatalf("expected 8-char password with three classes to pass: %v", err) + } + }) +} diff --git a/internal/service/auth_service_test.go b/internal/service/auth_service_test.go index 4cd93ac..cd68fa3 100644 --- a/internal/service/auth_service_test.go +++ b/internal/service/auth_service_test.go @@ -138,8 +138,15 @@ func TestValidatePasswordStrength(t *testing.T) { wantErr: true, }, { - name: "valid_weak_password_non_strict", - password: "Abcd1234", + name: "boundary_password_requires_three_character_classes", + password: "abcd1234", + minLength: 8, + strict: false, + wantErr: true, + }, + { + name: "longer_password_allows_two_character_classes", + password: "abcdefgh1234", minLength: 8, strict: false, wantErr: false, diff --git a/internal/service/totp.go b/internal/service/totp.go index 47e45f7..d25ddb8 100644 --- a/internal/service/totp.go +++ b/internal/service/totp.go @@ -47,9 +47,16 @@ func (s *TOTPService) SetupTOTP(ctx context.Context, userID int64) (*SetupTOTPRe // Hash recovery codes before storing (SEC-03 fix) hashedCodes := make([]string, len(setup.RecoveryCodes)) for i, code := range setup.RecoveryCodes { - hashedCodes[i], _ = auth.HashRecoveryCode(code) + hashedCode, err := auth.HashRecoveryCode(code) + if err != nil { + return nil, fmt.Errorf("生成恢复码摘要失败: %w", err) + } + hashedCodes[i] = hashedCode + } + codesJSON, err := json.Marshal(hashedCodes) + if err != nil { + return nil, fmt.Errorf("序列化恢复码失败: %w", err) } - codesJSON, _ := json.Marshal(hashedCodes) user.TOTPRecoveryCodes = string(codesJSON) if err := s.userRepo.UpdateTOTP(ctx, user); err != nil { @@ -96,11 +103,13 @@ func (s *TOTPService) DisableTOTP(ctx context.Context, userID int64, code string if !valid { var hashedCodes []string if user.TOTPRecoveryCodes != "" { - _ = json.Unmarshal([]byte(user.TOTPRecoveryCodes), &hashedCodes) + if err := json.Unmarshal([]byte(user.TOTPRecoveryCodes), &hashedCodes); err != nil { + return fmt.Errorf("解析恢复码失败: %w", err) + } } _, matched := auth.VerifyRecoveryCode(code, hashedCodes) if !matched { - return errors.New("\u9a8c\u8bc1\u7801\u6216\u6062\u590d\u7801\u9519\u8bef") + return errors.New("验证码或恢复码错误") } } @@ -125,17 +134,24 @@ func (s *TOTPService) VerifyTOTP(ctx context.Context, userID int64, code string) var storedCodes []string if user.TOTPRecoveryCodes != "" { - _ = json.Unmarshal([]byte(user.TOTPRecoveryCodes), &storedCodes) + if err := json.Unmarshal([]byte(user.TOTPRecoveryCodes), &storedCodes); err != nil { + return fmt.Errorf("解析恢复码失败: %w", err) + } } idx, matched := auth.ValidateRecoveryCode(code, storedCodes) if !matched { - return errors.New("\u9a8c\u8bc1\u7801\u9519\u8bef\u6216\u5df2\u8fc7\u671f") + return errors.New("验证码错误或已过期") } storedCodes = append(storedCodes[:idx], storedCodes[idx+1:]...) - codesJSON, _ := json.Marshal(storedCodes) + codesJSON, err := json.Marshal(storedCodes) + if err != nil { + return fmt.Errorf("序列化恢复码失败: %w", err) + } user.TOTPRecoveryCodes = string(codesJSON) - _ = s.userRepo.UpdateTOTP(ctx, user) + if err := s.userRepo.UpdateTOTP(ctx, user); err != nil { + return fmt.Errorf("更新恢复码失败: %w", err) + } return nil } diff --git a/internal/service/totp_internal_test.go b/internal/service/totp_internal_test.go new file mode 100644 index 0000000..fe6679f --- /dev/null +++ b/internal/service/totp_internal_test.go @@ -0,0 +1,112 @@ +package service + +import ( + "context" + "errors" + "strings" + "testing" + + "github.com/user-management-system/internal/domain" +) + +type totpTestRepo struct { + user *domain.User + getErr error + updateTOTPErr error +} + +func (r *totpTestRepo) Create(ctx context.Context, user *domain.User) error { return nil } +func (r *totpTestRepo) Update(ctx context.Context, user *domain.User) error { return nil } +func (r *totpTestRepo) UpdateTOTP(ctx context.Context, user *domain.User) error { + if r.updateTOTPErr != nil { + return r.updateTOTPErr + } + copyUser := *user + r.user = ©User + return nil +} +func (r *totpTestRepo) Delete(ctx context.Context, id int64) error { return nil } +func (r *totpTestRepo) GetByID(ctx context.Context, id int64) (*domain.User, error) { + if r.getErr != nil { + return nil, r.getErr + } + if r.user == nil || r.user.ID != id { + return nil, errors.New("not found") + } + copyUser := *r.user + return ©User, nil +} +func (r *totpTestRepo) GetByUsername(ctx context.Context, username string) (*domain.User, error) { + return nil, errors.New("not implemented") +} +func (r *totpTestRepo) GetByEmail(ctx context.Context, email string) (*domain.User, error) { + return nil, errors.New("not implemented") +} +func (r *totpTestRepo) GetByPhone(ctx context.Context, phone string) (*domain.User, error) { + return nil, errors.New("not implemented") +} +func (r *totpTestRepo) List(ctx context.Context, offset, limit int) ([]*domain.User, int64, error) { + return nil, 0, errors.New("not implemented") +} +func (r *totpTestRepo) ListByStatus(ctx context.Context, status domain.UserStatus, offset, limit int) ([]*domain.User, int64, error) { + return nil, 0, errors.New("not implemented") +} +func (r *totpTestRepo) UpdateStatus(ctx context.Context, id int64, status domain.UserStatus) error { + return errors.New("not implemented") +} +func (r *totpTestRepo) UpdateLastLogin(ctx context.Context, id int64, ip string) error { + return errors.New("not implemented") +} +func (r *totpTestRepo) ExistsByUsername(ctx context.Context, username string) (bool, error) { + return false, errors.New("not implemented") +} +func (r *totpTestRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) { + return false, errors.New("not implemented") +} +func (r *totpTestRepo) ExistsByPhone(ctx context.Context, phone string) (bool, error) { + return false, errors.New("not implemented") +} +func (r *totpTestRepo) Search(ctx context.Context, keyword string, offset, limit int) ([]*domain.User, int64, error) { + return nil, 0, errors.New("not implemented") +} + +func TestTOTPService_ReturnsDecodeErrorForCorruptedRecoveryCodes(t *testing.T) { + repo := &totpTestRepo{user: &domain.User{ + ID: 42, + Username: "totp-user", + TOTPEnabled: true, + TOTPSecret: "invalid-secret", + TOTPRecoveryCodes: "not-json", + }} + svc := NewTOTPService(repo) + + err := svc.VerifyTOTP(context.Background(), 42, "recovery-code") + if err == nil { + t.Fatal("expected corrupted recovery-code payload to fail") + } + if !strings.Contains(err.Error(), "解析恢复码失败") { + t.Fatalf("expected decode error, got: %v", err) + } +} + +func TestTOTPService_ReturnsUpdateErrorAfterRecoveryCodeConsumption(t *testing.T) { + repo := &totpTestRepo{ + user: &domain.User{ + ID: 7, + Username: "totp-user", + TOTPEnabled: true, + TOTPSecret: "invalid-secret", + TOTPRecoveryCodes: `["RECOVERY-1"]`, + }, + updateTOTPErr: errors.New("write failed"), + } + svc := NewTOTPService(repo) + + err := svc.VerifyTOTP(context.Background(), 7, "RECOVERY-1") + if err == nil { + t.Fatal("expected update failure to be returned") + } + if !strings.Contains(err.Error(), "更新恢复码失败") { + t.Fatalf("expected update error, got: %v", err) + } +}