From 289aab2930d18f3c413807df3cb3a1b08eb0eef9 Mon Sep 17 00:00:00 2001 From: long-agent Date: Sat, 11 Apr 2026 21:58:28 +0800 Subject: [PATCH] test: add repository tests to improve coverage from 46.6% to 74% New test files: - custom_field_repository_test.go: 10 tests for CustomFieldRepository & UserCustomFieldValueRepository - login_log_repository_test.go: 3 tests for ListCursor, ListByUserIDCursor, ListAllForExport - operation_log_repository_test.go: 1 test for ListCursor - role_repository_test.go: 2 tests for GetAncestorIDs, GetAncestors - social_account_repository_test.go: 8 CRUD tests - theme_repository_test.go: 10 tests for ThemeConfigRepository - user_role_repository_test.go: 1 test for DeleteByUserAndRole Modified test files: - device_repository_test.go: Added ListAllCursor tests - user_repository_test.go: Added AdvancedSearch tests - webhook_repository_test.go: Added ListByCreatorPaginated test Updated documentation with new coverage status. --- ...OJECT_REAL_COMPLETION_REVIEW_2026-04-10.md | 1 + .../custom_field_repository_test.go | 332 ++++++++++++++++++ internal/repository/device_repository_test.go | 89 +++++ .../repository/login_log_repository_test.go | 156 ++++++++ .../operation_log_repository_test.go | 94 +++++ internal/repository/role_repository_test.go | 90 +++++ .../social_account_repository_test.go | 263 ++++++++++++++ internal/repository/theme_repository_test.go | 275 +++++++++++++++ internal/repository/user_repository_test.go | 257 ++++++++++++++ .../repository/user_role_repository_test.go | 36 ++ .../repository/webhook_repository_test.go | 37 ++ 11 files changed, 1630 insertions(+) create mode 100644 internal/repository/custom_field_repository_test.go create mode 100644 internal/repository/login_log_repository_test.go create mode 100644 internal/repository/operation_log_repository_test.go create mode 100644 internal/repository/role_repository_test.go create mode 100644 internal/repository/social_account_repository_test.go create mode 100644 internal/repository/theme_repository_test.go create mode 100644 internal/repository/user_role_repository_test.go diff --git a/docs/code-review/PROJECT_REAL_COMPLETION_REVIEW_2026-04-10.md b/docs/code-review/PROJECT_REAL_COMPLETION_REVIEW_2026-04-10.md index cc677c7..76cecc6 100644 --- a/docs/code-review/PROJECT_REAL_COMPLETION_REVIEW_2026-04-10.md +++ b/docs/code-review/PROJECT_REAL_COMPLETION_REVIEW_2026-04-10.md @@ -56,6 +56,7 @@ RBAC/admin 改动必须验证: | `.gitattributes` | ✅ 已添加 | 统一行尾符为 LF(消除 LF/CRLF 污染) | | Swagger 注解 | ✅ 已添加 | 13 个 handler 共 86 处 `@Summary/@Description/@Tags/@Param/@Router` 注解 | | Device Repository 测试 | ✅ 已添加 | 15 个测试用例覆盖 DeviceRepository CRUD | +| Repository 测试覆盖率 | ✅ 已提升 | 从 46.6% 提升至 74%(目标 80%)| ## 最新验证结果 diff --git a/internal/repository/custom_field_repository_test.go b/internal/repository/custom_field_repository_test.go new file mode 100644 index 0000000..596ec7f --- /dev/null +++ b/internal/repository/custom_field_repository_test.go @@ -0,0 +1,332 @@ +package repository + +import ( + "context" + "fmt" + "sync/atomic" + "testing" + + _ "modernc.org/sqlite" + gormsqlite "gorm.io/driver/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/logger" + + "github.com/user-management-system/internal/domain" +) + +var customFieldTestCounter int64 + +// openCustomFieldTestDB 为每个测试打开独立的内存数据库 +func openCustomFieldTestDB(t *testing.T) *gorm.DB { + t.Helper() + + id := atomic.AddInt64(&customFieldTestCounter, 1) + dsn := fmt.Sprintf("file:customfieldtestdb%d?mode=memory&cache=private", id) + + db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{ + DriverName: "sqlite", + DSN: dsn, + }), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + }) + if err != nil { + t.Fatalf("打开测试数据库失败: %v", err) + } + + if err := db.AutoMigrate(&domain.CustomField{}, &domain.UserCustomFieldValue{}); err != nil { + t.Fatalf("数据库迁移失败: %v", err) + } + return db +} + +// setupCustomFieldTestDB 兼容性别名 +func setupCustomFieldTestDB(t *testing.T) *gorm.DB { + return openCustomFieldTestDB(t) +} + +// TestCustomFieldRepository_Create 测试创建自定义字段 +func TestCustomFieldRepository_Create(t *testing.T) { + db := setupCustomFieldTestDB(t) + repo := NewCustomFieldRepository(db) + ctx := context.Background() + + field := &domain.CustomField{ + Name: "测试字段", + FieldKey: "test_field", + Type: domain.CustomFieldTypeString, + Required: false, + Sort: 1, + } + + if err := repo.Create(ctx, field); err != nil { + t.Fatalf("Create() error = %v", err) + } + if field.ID == 0 { + t.Error("创建后字段ID不应为0") + } +} + +// TestCustomFieldRepository_GetByID 测试根据ID获取字段 +func TestCustomFieldRepository_GetByID(t *testing.T) { + db := setupCustomFieldTestDB(t) + repo := NewCustomFieldRepository(db) + ctx := context.Background() + + field := &domain.CustomField{ + Name: "getbyid-field", + FieldKey: "getbyid_key", + Type: domain.CustomFieldTypeNumber, + } + repo.Create(ctx, field) + + found, err := repo.GetByID(ctx, field.ID) + if err != nil { + t.Fatalf("GetByID() error = %v", err) + } + if found.Name != "getbyid-field" { + t.Errorf("Name = %v, want getbyid-field", found.Name) + } + + _, err = repo.GetByID(ctx, 9999) + if err == nil { + t.Error("GetByID() should return error for non-existent ID") + } +} + +// TestCustomFieldRepository_GetByFieldKey 测试根据FieldKey获取字段 +func TestCustomFieldRepository_GetByFieldKey(t *testing.T) { + db := setupCustomFieldTestDB(t) + repo := NewCustomFieldRepository(db) + ctx := context.Background() + + field := &domain.CustomField{ + Name: "field-by-key", + FieldKey: "unique_field_key", + Type: domain.CustomFieldTypeBoolean, + } + repo.Create(ctx, field) + + found, err := repo.GetByFieldKey(ctx, "unique_field_key") + if err != nil { + t.Fatalf("GetByFieldKey() error = %v", err) + } + if found.Name != "field-by-key" { + t.Errorf("Name = %v, want field-by-key", found.Name) + } + + _, err = repo.GetByFieldKey(ctx, "not_exist_key") + if err == nil { + t.Error("GetByFieldKey() should return error for non-existent key") + } +} + +// TestCustomFieldRepository_Update 测试更新字段 +func TestCustomFieldRepository_Update(t *testing.T) { + db := setupCustomFieldTestDB(t) + repo := NewCustomFieldRepository(db) + ctx := context.Background() + + field := &domain.CustomField{ + Name: "before-update", + FieldKey: "update_key", + Type: domain.CustomFieldTypeString, + } + repo.Create(ctx, field) + + field.Name = "after-update" + field.Required = true + if err := repo.Update(ctx, field); err != nil { + t.Fatalf("Update() error = %v", err) + } + + found, _ := repo.GetByID(ctx, field.ID) + if found.Name != "after-update" { + t.Errorf("Name = %v, want after-update", found.Name) + } + if !found.Required { + t.Error("Required should be true after update") + } +} + +// TestCustomFieldRepository_Delete 测试删除字段 +func TestCustomFieldRepository_Delete(t *testing.T) { + db := setupCustomFieldTestDB(t) + repo := NewCustomFieldRepository(db) + ctx := context.Background() + + field := &domain.CustomField{ + Name: "to-delete", + FieldKey: "delete_key", + Type: domain.CustomFieldTypeDate, + } + repo.Create(ctx, field) + + if err := repo.Delete(ctx, field.ID); err != nil { + t.Fatalf("Delete() error = %v", err) + } + + _, err := repo.GetByID(ctx, field.ID) + if err == nil { + t.Error("删除后查询应返回错误") + } +} + +// TestCustomFieldRepository_List 测试获取启用字段列表 +func TestCustomFieldRepository_List(t *testing.T) { + db := setupCustomFieldTestDB(t) + repo := NewCustomFieldRepository(db) + ctx := context.Background() + + repo.Create(ctx, &domain.CustomField{Name: "enabled1", FieldKey: "enabled1_key", Type: domain.CustomFieldTypeString}) + repo.Create(ctx, &domain.CustomField{Name: "enabled2", FieldKey: "enabled2_key", Type: domain.CustomFieldTypeNumber}) + repo.Create(ctx, &domain.CustomField{Name: "enabled3", FieldKey: "enabled3_key", Type: domain.CustomFieldTypeBoolean}) + + fields, err := repo.List(ctx) + if err != nil { + t.Fatalf("List() error = %v", err) + } + // List filters by status=1, all 3 have status=1 (default) + if len(fields) != 3 { + t.Errorf("len(fields) = %d, want 3", len(fields)) + } +} + +// TestCustomFieldRepository_ListAll 测试获取所有字段列表 +func TestCustomFieldRepository_ListAll(t *testing.T) { + db := setupCustomFieldTestDB(t) + repo := NewCustomFieldRepository(db) + ctx := context.Background() + + repo.Create(ctx, &domain.CustomField{Name: "all1", FieldKey: "all1_key", Type: domain.CustomFieldTypeString}) + repo.Create(ctx, &domain.CustomField{Name: "all2", FieldKey: "all2_key", Type: domain.CustomFieldTypeNumber}) + + fields, err := repo.ListAll(ctx) + if err != nil { + t.Fatalf("ListAll() error = %v", err) + } + if len(fields) != 2 { + t.Errorf("len(fields) = %d, want 2", len(fields)) + } +} + +// TestUserCustomFieldValueRepository_GetByUserID 测试获取用户所有字段值 +func TestUserCustomFieldValueRepository_GetByUserID(t *testing.T) { + db := setupCustomFieldTestDB(t) + valueRepo := NewUserCustomFieldValueRepository(db) + ctx := context.Background() + + // 直接使用 GORM Create 测试,因为 Set 使用 NOW() 不兼容 SQLite + db.WithContext(ctx).Create(&domain.UserCustomFieldValue{ + UserID: 1, + FieldID: 1, + FieldKey: "field1_key", + Value: "value1", + }) + db.WithContext(ctx).Create(&domain.UserCustomFieldValue{ + UserID: 1, + FieldID: 2, + FieldKey: "field2_key", + Value: "value2", + }) + + values, err := valueRepo.GetByUserID(ctx, 1) + if err != nil { + t.Fatalf("GetByUserID() error = %v", err) + } + if len(values) != 2 { + t.Errorf("len(values) = %d, want 2", len(values)) + } +} + +// TestUserCustomFieldValueRepository_GetByUserIDAndFieldKey 测试获取用户指定字段值 +func TestUserCustomFieldValueRepository_GetByUserIDAndFieldKey(t *testing.T) { + db := setupCustomFieldTestDB(t) + valueRepo := NewUserCustomFieldValueRepository(db) + ctx := context.Background() + + db.WithContext(ctx).Create(&domain.UserCustomFieldValue{ + UserID: 1, + FieldID: 1, + FieldKey: "specific_key", + Value: "specific_value", + }) + + found, err := valueRepo.GetByUserIDAndFieldKey(ctx, 1, "specific_key") + if err != nil { + t.Fatalf("GetByUserIDAndFieldKey() error = %v", err) + } + if found.Value != "specific_value" { + t.Errorf("Value = %v, want specific_value", found.Value) + } + + _, err = valueRepo.GetByUserIDAndFieldKey(ctx, 1, "non_existent_key") + if err == nil { + t.Error("GetByUserIDAndFieldKey() should return error for non-existent key") + } +} + +// TestUserCustomFieldValueRepository_Delete 测试删除用户字段值 +func TestUserCustomFieldValueRepository_Delete(t *testing.T) { + db := setupCustomFieldTestDB(t) + valueRepo := NewUserCustomFieldValueRepository(db) + ctx := context.Background() + + db.WithContext(ctx).Create(&domain.UserCustomFieldValue{ + UserID: 1, + FieldID: 1, + FieldKey: "delete_key", + Value: "to_delete", + }) + + err := valueRepo.Delete(ctx, 1, 1) + if err != nil { + t.Fatalf("Delete() error = %v", err) + } + + _, err = valueRepo.GetByUserIDAndFieldKey(ctx, 1, "delete_key") + if err == nil { + t.Error("删除后查询应返回错误") + } +} + +// TestUserCustomFieldValueRepository_DeleteByUserID 测试删除用户所有字段值 +func TestUserCustomFieldValueRepository_DeleteByUserID(t *testing.T) { + db := setupCustomFieldTestDB(t) + valueRepo := NewUserCustomFieldValueRepository(db) + ctx := context.Background() + + db.WithContext(ctx).Create(&domain.UserCustomFieldValue{ + UserID: 1, + FieldID: 1, + FieldKey: "multi1_key", + Value: "v1", + }) + db.WithContext(ctx).Create(&domain.UserCustomFieldValue{ + UserID: 1, + FieldID: 2, + FieldKey: "multi2_key", + Value: "v2", + }) + db.WithContext(ctx).Create(&domain.UserCustomFieldValue{ + UserID: 2, + FieldID: 1, + FieldKey: "multi1_key", + Value: "v3", + }) + + err := valueRepo.DeleteByUserID(ctx, 1) + if err != nil { + t.Fatalf("DeleteByUserID() error = %v", err) + } + + values, _ := valueRepo.GetByUserID(ctx, 1) + if len(values) != 0 { + t.Errorf("len(values) = %d, want 0", len(values)) + } + + // 用户2的值应该还在 + values2, _ := valueRepo.GetByUserID(ctx, 2) + if len(values2) != 1 { + t.Errorf("用户2的字段值应该保留, got %d", len(values2)) + } +} diff --git a/internal/repository/device_repository_test.go b/internal/repository/device_repository_test.go index 843742b..75f8d7c 100644 --- a/internal/repository/device_repository_test.go +++ b/internal/repository/device_repository_test.go @@ -13,6 +13,7 @@ import ( "gorm.io/gorm/logger" "github.com/user-management-system/internal/domain" + "github.com/user-management-system/internal/pagination" ) var deviceTestCounter int64 @@ -484,3 +485,91 @@ func createDevice(t *testing.T, repo *DeviceRepository, ctx context.Context, use } return d, nil } + +// TestDeviceRepository_ListAllCursor 测试设备游标分页查询 +func TestDeviceRepository_ListAllCursor(t *testing.T) { + db := setupDeviceTestDB(t) + repo := NewDeviceRepository(db) + ctx := context.Background() + + // 创建设备(需要设置LastActiveTime以支持游标分页) + now := time.Now() + for i := 0; i < 5; i++ { + repo.Create(ctx, &domain.Device{ + UserID: int64(i + 1), + DeviceID: "cursor-device-" + string(rune('a'+i)), + DeviceName: "设备" + string(rune('0'+i)), + Status: domain.DeviceStatusActive, + LastActiveTime: now.Add(-time.Duration(i) * time.Minute), + }) + } + + // 第一次查询,获取前3个 + devices, hasMore, err := repo.ListAllCursor(ctx, &ListDevicesParams{Offset: 0, Limit: 10}, 3, nil) + if err != nil { + t.Fatalf("ListAllCursor() error = %v", err) + } + if len(devices) != 3 { + t.Errorf("len(devices) = %d, want 3", len(devices)) + } + if !hasMore { + t.Error("hasMore should be true when more devices exist") + } + + // 使用游标继续查询 + lastDevice := devices[len(devices)-1] + cursor := &pagination.Cursor{ + LastID: lastDevice.ID, + LastValue: lastDevice.LastActiveTime, + } + devices2, hasMore2, err := repo.ListAllCursor(ctx, &ListDevicesParams{Offset: 0, Limit: 10}, 3, cursor) + if err != nil { + t.Fatalf("ListAllCursor() error = %v", err) + } + if len(devices2) != 2 { + t.Errorf("len(devices2) = %d, want 2", len(devices2)) + } + if hasMore2 { + t.Error("hasMore2 should be false") + } +} + +// TestDeviceRepository_ListAllCursor_WithFilters 测试带筛选条件的设备游标分页 +func TestDeviceRepository_ListAllCursor_WithFilters(t *testing.T) { + db := setupDeviceTestDB(t) + repo := NewDeviceRepository(db) + ctx := context.Background() + + now := time.Now() + repo.Create(ctx, &domain.Device{ + UserID: 1, + DeviceID: "filter-dev1", + DeviceName: "用户1设备", + Status: domain.DeviceStatusActive, + LastActiveTime: now, + }) + repo.Create(ctx, &domain.Device{ + UserID: 2, + DeviceID: "filter-dev2", + DeviceName: "用户2设备", + Status: domain.DeviceStatusActive, + LastActiveTime: now, + }) + repo.Create(ctx, &domain.Device{ + UserID: 1, + DeviceID: "filter-dev3", + DeviceName: "用户1禁用设备", + Status: domain.DeviceStatusInactive, + LastActiveTime: now, + }) + + // 按用户ID筛选 + status := domain.DeviceStatusActive + devices, _, err := repo.ListAllCursor(ctx, &ListDevicesParams{UserID: 1, Status: &status, Offset: 0, Limit: 10}, 10, nil) + if err != nil { + t.Fatalf("ListAllCursor() error = %v", err) + } + if len(devices) != 1 { + t.Errorf("len(devices) = %d, want 1", len(devices)) + } +} diff --git a/internal/repository/login_log_repository_test.go b/internal/repository/login_log_repository_test.go new file mode 100644 index 0000000..4e62fa9 --- /dev/null +++ b/internal/repository/login_log_repository_test.go @@ -0,0 +1,156 @@ +package repository + +import ( + "context" + "fmt" + "sync/atomic" + "testing" + "time" + + _ "modernc.org/sqlite" + gormsqlite "gorm.io/driver/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/logger" + + "github.com/user-management-system/internal/domain" + "github.com/user-management-system/internal/pagination" +) + +var loginLogTestCounter int64 + +func openLoginLogTestDB(t *testing.T) *gorm.DB { + t.Helper() + + id := atomic.AddInt64(&loginLogTestCounter, 1) + dsn := fmt.Sprintf("file:loginlogtestdb%d?mode=memory&cache=private", id) + + db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{ + DriverName: "sqlite", + DSN: dsn, + }), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + }) + if err != nil { + t.Fatalf("打开测试数据库失败: %v", err) + } + + if err := db.AutoMigrate(&domain.LoginLog{}); err != nil { + t.Fatalf("数据库迁移失败: %v", err) + } + return db +} + +func setupLoginLogTestDB(t *testing.T) *gorm.DB { + return openLoginLogTestDB(t) +} + +func TestLoginLogRepository_ListCursor(t *testing.T) { + db := setupLoginLogTestDB(t) + repo := NewLoginLogRepository(db) + ctx := context.Background() + + now := time.Now() + for i := 0; i < 5; i++ { + repo.Create(ctx, &domain.LoginLog{ + UserID: int64Ptr(int64(i + 1)), + LoginType: 1, + IP: "192.168.1." + string(rune('0'+i)), + Status: 1, + CreatedAt: now.Add(-time.Duration(i) * time.Minute), + }) + } + + // 第一次查询,获取前3个 + logs, hasMore, err := repo.ListCursor(ctx, 3, nil) + if err != nil { + t.Fatalf("ListCursor() error = %v", err) + } + if len(logs) != 3 { + t.Errorf("len(logs) = %d, want 3", len(logs)) + } + if !hasMore { + t.Error("hasMore should be true when more logs exist") + } + + // 使用游标继续查询 + lastLog := logs[len(logs)-1] + cursor := &pagination.Cursor{ + LastID: lastLog.ID, + LastValue: lastLog.CreatedAt, + } + logs2, hasMore2, err := repo.ListCursor(ctx, 3, cursor) + if err != nil { + t.Fatalf("ListCursor() error = %v", err) + } + if len(logs2) != 2 { + t.Errorf("len(logs2) = %d, want 2", len(logs2)) + } + if hasMore2 { + t.Error("hasMore2 should be false") + } +} + +func TestLoginLogRepository_ListByUserIDCursor(t *testing.T) { + db := setupLoginLogTestDB(t) + repo := NewLoginLogRepository(db) + ctx := context.Background() + + userID := int64(123) + now := time.Now() + for i := 0; i < 3; i++ { + repo.Create(ctx, &domain.LoginLog{ + UserID: int64Ptr(userID), + LoginType: 1, + IP: "192.168.1." + string(rune('0'+i)), + Status: 1, + CreatedAt: now.Add(-time.Duration(i) * time.Minute), + }) + } + // 另一个用户的日志 + repo.Create(ctx, &domain.LoginLog{ + UserID: int64Ptr(999), + LoginType: 1, + IP: "10.0.0.1", + Status: 1, + }) + + // 查询指定用户的日志 + logs, hasMore, err := repo.ListByUserIDCursor(ctx, userID, 10, nil) + if err != nil { + t.Fatalf("ListByUserIDCursor() error = %v", err) + } + if len(logs) != 3 { + t.Errorf("len(logs) = %d, want 3", len(logs)) + } + if hasMore { + t.Error("hasMore should be false") + } +} + +func TestLoginLogRepository_ListAllForExport(t *testing.T) { + db := setupLoginLogTestDB(t) + repo := NewLoginLogRepository(db) + ctx := context.Background() + + repo.Create(ctx, &domain.LoginLog{ + UserID: int64Ptr(1), + LoginType: 1, + IP: "192.168.1.1", + Status: 1, + }) + repo.Create(ctx, &domain.LoginLog{ + UserID: int64Ptr(2), + LoginType: 2, + IP: "192.168.1.2", + Status: 0, + FailReason: "invalid password", + }) + + logs, err := repo.ListAllForExport(ctx, 0, -1, nil, nil) + if err != nil { + t.Fatalf("ListAllForExport() error = %v", err) + } + if len(logs) != 2 { + t.Errorf("len(logs) = %d, want 2", len(logs)) + } +} diff --git a/internal/repository/operation_log_repository_test.go b/internal/repository/operation_log_repository_test.go new file mode 100644 index 0000000..02fe112 --- /dev/null +++ b/internal/repository/operation_log_repository_test.go @@ -0,0 +1,94 @@ +package repository + +import ( + "context" + "fmt" + "sync/atomic" + "testing" + "time" + + _ "modernc.org/sqlite" + gormsqlite "gorm.io/driver/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/logger" + + "github.com/user-management-system/internal/domain" + "github.com/user-management-system/internal/pagination" +) + +var operationLogTestCounter int64 + +func openOperationLogTestDB(t *testing.T) *gorm.DB { + t.Helper() + + id := atomic.AddInt64(&operationLogTestCounter, 1) + dsn := fmt.Sprintf("file:operationlogtestdb%d?mode=memory&cache=private", id) + + db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{ + DriverName: "sqlite", + DSN: dsn, + }), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + }) + if err != nil { + t.Fatalf("打开测试数据库失败: %v", err) + } + + if err := db.AutoMigrate(&domain.OperationLog{}); err != nil { + t.Fatalf("数据库迁移失败: %v", err) + } + return db +} + +func setupOperationLogTestDB(t *testing.T) *gorm.DB { + return openOperationLogTestDB(t) +} + +func TestOperationLogRepository_ListCursor(t *testing.T) { + db := setupOperationLogTestDB(t) + repo := NewOperationLogRepository(db) + ctx := context.Background() + + now := time.Now() + for i := 0; i < 5; i++ { + repo.Create(ctx, &domain.OperationLog{ + UserID: nil, + OperationType: "test", + OperationName: "测试操作" + string(rune('0'+i)), + RequestMethod: "GET", + RequestPath: "/api/test", + ResponseStatus: 200, + IP: "192.168.1." + string(rune('0'+i)), + CreatedAt: now.Add(-time.Duration(i) * time.Minute), + }) + } + + // 第一次查询,获取前3个 + logs, hasMore, err := repo.ListCursor(ctx, 3, nil) + if err != nil { + t.Fatalf("ListCursor() error = %v", err) + } + if len(logs) != 3 { + t.Errorf("len(logs) = %d, want 3", len(logs)) + } + if !hasMore { + t.Error("hasMore should be true when more logs exist") + } + + // 使用游标继续查询 + lastLog := logs[len(logs)-1] + cursor := &pagination.Cursor{ + LastID: lastLog.ID, + LastValue: lastLog.CreatedAt, + } + logs2, hasMore2, err := repo.ListCursor(ctx, 3, cursor) + if err != nil { + t.Fatalf("ListCursor() error = %v", err) + } + if len(logs2) != 2 { + t.Errorf("len(logs2) = %d, want 2", len(logs2)) + } + if hasMore2 { + t.Error("hasMore2 should be false") + } +} diff --git a/internal/repository/role_repository_test.go b/internal/repository/role_repository_test.go new file mode 100644 index 0000000..e9ca410 --- /dev/null +++ b/internal/repository/role_repository_test.go @@ -0,0 +1,90 @@ +package repository + +import ( + "context" + "testing" + + "github.com/user-management-system/internal/domain" +) + +func TestRoleRepository_GetAncestorIDs(t *testing.T) { + db := setupTestDB(t) + repo := NewRoleRepository(db) + ctx := context.Background() + + // 创建角色层级: grandchild -> child -> parent + parentID := int64(0) + parent := &domain.Role{Name: "parent", Code: "parent", ParentID: nil} + if err := repo.Create(ctx, parent); err != nil { + t.Fatalf("Create parent failed: %v", err) + } + parentID = parent.ID + + child := &domain.Role{Name: "child", Code: "child", ParentID: &parentID} + if err := repo.Create(ctx, child); err != nil { + t.Fatalf("Create child failed: %v", err) + } + childID := child.ID + + grandchild := &domain.Role{Name: "grandchild", Code: "grandchild", ParentID: &childID} + if err := repo.Create(ctx, grandchild); err != nil { + t.Fatalf("Create grandchild failed: %v", err) + } + + // 获取grandchild的祖先ID列表 + ancestorIDs, err := repo.GetAncestorIDs(ctx, grandchild.ID) + if err != nil { + t.Fatalf("GetAncestorIDs failed: %v", err) + } + if len(ancestorIDs) != 2 { + t.Errorf("len(ancestorIDs) = %d, want 2", len(ancestorIDs)) + } + if ancestorIDs[0] != childID { + t.Errorf("ancestorIDs[0] = %d, want %d", ancestorIDs[0], childID) + } + if ancestorIDs[1] != parentID { + t.Errorf("ancestorIDs[1] = %d, want %d", ancestorIDs[1], parentID) + } +} + +func TestRoleRepository_GetAncestors(t *testing.T) { + db := setupTestDB(t) + repo := NewRoleRepository(db) + ctx := context.Background() + + // 创建角色层级 + parentID := int64(0) + parent := &domain.Role{Name: "parent-role", Code: "parent-role", Status: domain.RoleStatusEnabled} + if err := repo.Create(ctx, parent); err != nil { + t.Fatalf("Create parent failed: %v", err) + } + parentID = parent.ID + + child := &domain.Role{Name: "child-role", Code: "child-role", ParentID: &parentID, Status: domain.RoleStatusEnabled} + if err := repo.Create(ctx, child); err != nil { + t.Fatalf("Create child failed: %v", err) + } + childID := child.ID + + grandchild := &domain.Role{Name: "grandchild-role", Code: "grandchild-role", ParentID: &childID, Status: domain.RoleStatusEnabled} + if err := repo.Create(ctx, grandchild); err != nil { + t.Fatalf("Create grandchild failed: %v", err) + } + + // 获取grandchild的完整继承链 + ancestors, err := repo.GetAncestors(ctx, grandchild.ID) + if err != nil { + t.Fatalf("GetAncestors failed: %v", err) + } + if len(ancestors) != 2 { + t.Errorf("len(ancestors) = %d, want 2", len(ancestors)) + } + // 第一个应该是parent + if ancestors[0].Code != "parent-role" { + t.Errorf("ancestors[0].Code = %s, want parent-role", ancestors[0].Code) + } + // 第二个应该是child + if ancestors[1].Code != "child-role" { + t.Errorf("ancestors[1].Code = %s, want child-role", ancestors[1].Code) + } +} diff --git a/internal/repository/social_account_repository_test.go b/internal/repository/social_account_repository_test.go new file mode 100644 index 0000000..2783691 --- /dev/null +++ b/internal/repository/social_account_repository_test.go @@ -0,0 +1,263 @@ +package repository + +import ( + "context" + "fmt" + "sync/atomic" + "testing" + + _ "modernc.org/sqlite" + gormsqlite "gorm.io/driver/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/logger" + + "github.com/user-management-system/internal/domain" +) + +var socialAccountTestCounter int64 + +func openSocialAccountTestDB(t *testing.T) *gorm.DB { + t.Helper() + + id := atomic.AddInt64(&socialAccountTestCounter, 1) + dsn := fmt.Sprintf("file:socialaccounttestdb%d?mode=memory&cache=private", id) + + db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{ + DriverName: "sqlite", + DSN: dsn, + }), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + }) + if err != nil { + t.Fatalf("打开测试数据库失败: %v", err) + } + + if err := db.AutoMigrate(&domain.SocialAccount{}); err != nil { + t.Fatalf("数据库迁移失败: %v", err) + } + return db +} + +func setupSocialAccountTestDB(t *testing.T) *gorm.DB { + return openSocialAccountTestDB(t) +} + +func TestSocialAccountRepository_Create(t *testing.T) { + db := setupSocialAccountTestDB(t) + repo, err := NewSocialAccountRepository(db) + if err != nil { + t.Fatalf("NewSocialAccountRepository() error = %v", err) + } + ctx := context.Background() + + account := &domain.SocialAccount{ + UserID: 1, + Provider: "github", + OpenID: "openid-123", + Nickname: "testuser", + Status: domain.SocialAccountStatusActive, + } + + if err := repo.Create(ctx, account); err != nil { + t.Fatalf("Create() error = %v", err) + } + if account.ID == 0 { + t.Error("创建后账户ID不应为0") + } +} + +func TestSocialAccountRepository_GetByID(t *testing.T) { + db := setupSocialAccountTestDB(t) + repo, err := NewSocialAccountRepository(db) + if err != nil { + t.Fatalf("NewSocialAccountRepository() error = %v", err) + } + ctx := context.Background() + + account := &domain.SocialAccount{ + UserID: 1, + Provider: "github", + OpenID: "openid-getbyid", + Nickname: "getbyid-user", + Status: domain.SocialAccountStatusActive, + } + repo.Create(ctx, account) + + found, err := repo.GetByID(ctx, account.ID) + if err != nil { + t.Fatalf("GetByID() error = %v", err) + } + if found.Nickname != "getbyid-user" { + t.Errorf("Nickname = %v, want getbyid-user", found.Nickname) + } +} + +func TestSocialAccountRepository_GetByUserID(t *testing.T) { + db := setupSocialAccountTestDB(t) + repo, err := NewSocialAccountRepository(db) + if err != nil { + t.Fatalf("NewSocialAccountRepository() error = %v", err) + } + ctx := context.Background() + + repo.Create(ctx, &domain.SocialAccount{ + UserID: 1, + Provider: "github", + OpenID: "openid-user1-1", + Status: domain.SocialAccountStatusActive, + }) + repo.Create(ctx, &domain.SocialAccount{ + UserID: 1, + Provider: "wechat", + OpenID: "openid-user1-2", + Status: domain.SocialAccountStatusActive, + }) + repo.Create(ctx, &domain.SocialAccount{ + UserID: 2, + Provider: "github", + OpenID: "openid-user2", + Status: domain.SocialAccountStatusActive, + }) + + accounts, err := repo.GetByUserID(ctx, 1) + if err != nil { + t.Fatalf("GetByUserID() error = %v", err) + } + if len(accounts) != 2 { + t.Errorf("len(accounts) = %d, want 2", len(accounts)) + } +} + +func TestSocialAccountRepository_GetByProviderAndOpenID(t *testing.T) { + db := setupSocialAccountTestDB(t) + repo, err := NewSocialAccountRepository(db) + if err != nil { + t.Fatalf("NewSocialAccountRepository() error = %v", err) + } + ctx := context.Background() + + account := &domain.SocialAccount{ + UserID: 1, + Provider: "github", + OpenID: "unique-openid-123", + Nickname: "github-user", + Status: domain.SocialAccountStatusActive, + } + repo.Create(ctx, account) + + found, err := repo.GetByProviderAndOpenID(ctx, "github", "unique-openid-123") + if err != nil { + t.Fatalf("GetByProviderAndOpenID() error = %v", err) + } + if found.UserID != 1 { + t.Errorf("UserID = %d, want 1", found.UserID) + } +} + +func TestSocialAccountRepository_Update(t *testing.T) { + db := setupSocialAccountTestDB(t) + repo, err := NewSocialAccountRepository(db) + if err != nil { + t.Fatalf("NewSocialAccountRepository() error = %v", err) + } + ctx := context.Background() + + account := &domain.SocialAccount{ + UserID: 1, + Provider: "github", + OpenID: "openid-update", + Nickname: "before-update", + Status: domain.SocialAccountStatusActive, + } + repo.Create(ctx, account) + + account.Nickname = "after-update" + if err := repo.Update(ctx, account); err != nil { + t.Fatalf("Update() error = %v", err) + } + + found, _ := repo.GetByID(ctx, account.ID) + if found.Nickname != "after-update" { + t.Errorf("Nickname = %v, want after-update", found.Nickname) + } +} + +func TestSocialAccountRepository_Delete(t *testing.T) { + db := setupSocialAccountTestDB(t) + repo, err := NewSocialAccountRepository(db) + if err != nil { + t.Fatalf("NewSocialAccountRepository() error = %v", err) + } + ctx := context.Background() + + account := &domain.SocialAccount{ + UserID: 1, + Provider: "github", + OpenID: "openid-delete", + Status: domain.SocialAccountStatusActive, + } + repo.Create(ctx, account) + + if err := repo.Delete(ctx, account.ID); err != nil { + t.Fatalf("Delete() error = %v", err) + } +} + +func TestSocialAccountRepository_DeleteByProviderAndUserID(t *testing.T) { + db := setupSocialAccountTestDB(t) + repo, err := NewSocialAccountRepository(db) + if err != nil { + t.Fatalf("NewSocialAccountRepository() error = %v", err) + } + ctx := context.Background() + + repo.Create(ctx, &domain.SocialAccount{ + UserID: 1, + Provider: "github", + OpenID: "openid-del-provider", + Status: domain.SocialAccountStatusActive, + }) + + err = repo.DeleteByProviderAndUserID(ctx, "github", 1) + if err != nil { + t.Fatalf("DeleteByProviderAndUserID() error = %v", err) + } + + accounts, _ := repo.GetByUserID(ctx, 1) + if len(accounts) != 0 { + t.Errorf("len(accounts) = %d, want 0 after delete", len(accounts)) + } +} + +func TestSocialAccountRepository_List(t *testing.T) { + db := setupSocialAccountTestDB(t) + repo, err := NewSocialAccountRepository(db) + if err != nil { + t.Fatalf("NewSocialAccountRepository() error = %v", err) + } + ctx := context.Background() + + repo.Create(ctx, &domain.SocialAccount{ + UserID: 1, + Provider: "github", + OpenID: "openid-list-1", + Status: domain.SocialAccountStatusActive, + }) + repo.Create(ctx, &domain.SocialAccount{ + UserID: 2, + Provider: "wechat", + OpenID: "openid-list-2", + Status: domain.SocialAccountStatusActive, + }) + + accounts, total, err := repo.List(ctx, 0, 10) + if err != nil { + t.Fatalf("List() error = %v", err) + } + if len(accounts) != 2 { + t.Errorf("len(accounts) = %d, want 2", len(accounts)) + } + if total != 2 { + t.Errorf("total = %d, want 2", total) + } +} diff --git a/internal/repository/theme_repository_test.go b/internal/repository/theme_repository_test.go new file mode 100644 index 0000000..1552c92 --- /dev/null +++ b/internal/repository/theme_repository_test.go @@ -0,0 +1,275 @@ +package repository + +import ( + "context" + "fmt" + "sync/atomic" + "testing" + + _ "modernc.org/sqlite" + gormsqlite "gorm.io/driver/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/logger" + + "github.com/user-management-system/internal/domain" +) + +var themeTestCounter int64 + +// openThemeTestDB 为每个测试打开独立的内存数据库 +func openThemeTestDB(t *testing.T) *gorm.DB { + t.Helper() + + id := atomic.AddInt64(&themeTestCounter, 1) + dsn := fmt.Sprintf("file:themetestdb%d?mode=memory&cache=private", id) + + db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{ + DriverName: "sqlite", + DSN: dsn, + }), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + }) + if err != nil { + t.Fatalf("打开测试数据库失败: %v", err) + } + + if err := db.AutoMigrate(&domain.ThemeConfig{}); err != nil { + t.Fatalf("数据库迁移失败: %v", err) + } + return db +} + +// setupThemeTestDB 兼容性别名 +func setupThemeTestDB(t *testing.T) *gorm.DB { + return openThemeTestDB(t) +} + +// TestThemeConfigRepository_Create 测试创建主题 +func TestThemeConfigRepository_Create(t *testing.T) { + db := setupThemeTestDB(t) + repo := NewThemeConfigRepository(db) + ctx := context.Background() + + theme := &domain.ThemeConfig{ + Name: "test-theme", + PrimaryColor: "#ff0000", + SecondaryColor: "#00ff00", + Enabled: true, + } + + if err := repo.Create(ctx, theme); err != nil { + t.Fatalf("Create() error = %v", err) + } + if theme.ID == 0 { + t.Error("创建后主题ID不应为0") + } +} + +// TestThemeConfigRepository_GetByID 测试根据ID获取主题 +func TestThemeConfigRepository_GetByID(t *testing.T) { + db := setupThemeTestDB(t) + repo := NewThemeConfigRepository(db) + ctx := context.Background() + + theme := &domain.ThemeConfig{ + Name: "getbyid-theme", + PrimaryColor: "#0000ff", + Enabled: true, + } + repo.Create(ctx, theme) + + found, err := repo.GetByID(ctx, theme.ID) + if err != nil { + t.Fatalf("GetByID() error = %v", err) + } + if found.Name != "getbyid-theme" { + t.Errorf("Name = %v, want getbyid-theme", found.Name) + } +} + +// TestThemeConfigRepository_GetByName 测试根据名称获取主题 +func TestThemeConfigRepository_GetByName(t *testing.T) { + db := setupThemeTestDB(t) + repo := NewThemeConfigRepository(db) + ctx := context.Background() + + theme := &domain.ThemeConfig{ + Name: "unique-theme-name", + PrimaryColor: "#ffff00", + Enabled: true, + } + repo.Create(ctx, theme) + + found, err := repo.GetByName(ctx, "unique-theme-name") + if err != nil { + t.Fatalf("GetByName() error = %v", err) + } + if found.ID != theme.ID { + t.Errorf("ID = %v, want %v", found.ID, theme.ID) + } +} + +// TestThemeConfigRepository_GetByName_NotFound 测试名称不存在 +func TestThemeConfigRepository_GetByName_NotFound(t *testing.T) { + db := setupThemeTestDB(t) + repo := NewThemeConfigRepository(db) + ctx := context.Background() + + _, err := repo.GetByName(ctx, "not-exist-theme") + if err == nil { + t.Error("GetByName() should return error for non-existent theme") + } +} + +// TestThemeConfigRepository_Update 测试更新主题 +func TestThemeConfigRepository_Update(t *testing.T) { + db := setupThemeTestDB(t) + repo := NewThemeConfigRepository(db) + ctx := context.Background() + + theme := &domain.ThemeConfig{ + Name: "update-test", + PrimaryColor: "#000000", + Enabled: true, + } + repo.Create(ctx, theme) + + theme.PrimaryColor = "#ffffff" + if err := repo.Update(ctx, theme); err != nil { + t.Fatalf("Update() error = %v", err) + } + + found, _ := repo.GetByID(ctx, theme.ID) + if found.PrimaryColor != "#ffffff" { + t.Errorf("PrimaryColor = %v, want #ffffff", found.PrimaryColor) + } +} + +// TestThemeConfigRepository_Delete 测试删除主题 +func TestThemeConfigRepository_Delete(t *testing.T) { + db := setupThemeTestDB(t) + repo := NewThemeConfigRepository(db) + ctx := context.Background() + + theme := &domain.ThemeConfig{ + Name: "delete-test", + Enabled: true, + } + repo.Create(ctx, theme) + + if err := repo.Delete(ctx, theme.ID); err != nil { + t.Fatalf("Delete() error = %v", err) + } + + _, err := repo.GetByID(ctx, theme.ID) + if err == nil { + t.Error("删除后查询应返回错误") + } +} + +// TestThemeConfigRepository_List 测试获取已启用主题列表 +func TestThemeConfigRepository_List(t *testing.T) { + db := setupThemeTestDB(t) + repo := NewThemeConfigRepository(db) + ctx := context.Background() + + repo.Create(ctx, &domain.ThemeConfig{Name: "enabled1", Enabled: true}) + repo.Create(ctx, &domain.ThemeConfig{Name: "enabled2", Enabled: true}) + repo.Create(ctx, &domain.ThemeConfig{Name: "disabled1", Enabled: false}) + + themes, err := repo.List(ctx) + if err != nil { + t.Fatalf("List() error = %v", err) + } + // List filters by enabled=true + if len(themes) < 2 { + t.Errorf("len(themes) = %d, want at least 2", len(themes)) + } +} + +// TestThemeConfigRepository_ListAll 测试获取所有主题列表 +func TestThemeConfigRepository_ListAll(t *testing.T) { + db := setupThemeTestDB(t) + repo := NewThemeConfigRepository(db) + ctx := context.Background() + + repo.Create(ctx, &domain.ThemeConfig{Name: "all1", Enabled: true}) + repo.Create(ctx, &domain.ThemeConfig{Name: "all2", Enabled: false}) + + themes, err := repo.ListAll(ctx) + if err != nil { + t.Fatalf("ListAll() error = %v", err) + } + if len(themes) != 2 { + t.Errorf("len(themes) = %d, want 2", len(themes)) + } +} + +// TestThemeConfigRepository_GetDefault 测试获取默认主题 +func TestThemeConfigRepository_GetDefault(t *testing.T) { + db := setupThemeTestDB(t) + repo := NewThemeConfigRepository(db) + ctx := context.Background() + + // 创建一个默认主题 + repo.Create(ctx, &domain.ThemeConfig{ + Name: "default-theme", + IsDefault: true, + Enabled: true, + }) + + defaultTheme, err := repo.GetDefault(ctx) + if err != nil { + t.Fatalf("GetDefault() error = %v", err) + } + if defaultTheme.Name != "default-theme" { + t.Errorf("Name = %v, want default-theme", defaultTheme.Name) + } +} + +// TestThemeConfigRepository_GetDefault_NoDefault 测试无默认主题时返回默认配置 +func TestThemeConfigRepository_GetDefault_NoDefault(t *testing.T) { + db := setupThemeTestDB(t) + repo := NewThemeConfigRepository(db) + ctx := context.Background() + + // 不创建任何主题 + defaultTheme, err := repo.GetDefault(ctx) + if err != nil { + t.Fatalf("GetDefault() error = %v", err) + } + // 应该返回内置默认配置 + if defaultTheme.Name != "default" { + t.Errorf("Name = %v, want default", defaultTheme.Name) + } +} + +// TestThemeConfigRepository_SetDefault 测试设置默认主题 +func TestThemeConfigRepository_SetDefault(t *testing.T) { + db := setupThemeTestDB(t) + repo := NewThemeConfigRepository(db) + ctx := context.Background() + + // 创建两个主题 + theme1 := &domain.ThemeConfig{Name: "theme1", IsDefault: true, Enabled: true} + theme2 := &domain.ThemeConfig{Name: "theme2", IsDefault: false, Enabled: true} + repo.Create(ctx, theme1) + repo.Create(ctx, theme2) + + // 设置 theme2 为默认 + if err := repo.SetDefault(ctx, theme2.ID); err != nil { + t.Fatalf("SetDefault() error = %v", err) + } + + // 验证 theme1 不再是默认 + t1, _ := repo.GetByID(ctx, theme1.ID) + if t1.IsDefault { + t.Error("theme1 should not be default anymore") + } + + // 验证 theme2 现在是默认 + t2, _ := repo.GetByID(ctx, theme2.ID) + if !t2.IsDefault { + t.Error("theme2 should be default") + } +} diff --git a/internal/repository/user_repository_test.go b/internal/repository/user_repository_test.go index 0ed891d..7d9b9b7 100644 --- a/internal/repository/user_repository_test.go +++ b/internal/repository/user_repository_test.go @@ -3,6 +3,7 @@ package repository import ( "context" "testing" + "time" "gorm.io/gorm" @@ -401,3 +402,259 @@ func TestUserRepository_Search_LikePattern(t *testing.T) { // Should not error and should escape properly _ = users } + +// TestUserRepository_GetByIDs 测试批量获取用户 +func TestUserRepository_GetByIDs(t *testing.T) { + db := setupTestDB(t) + repo := NewUserRepository(db) + ctx := context.Background() + + u1 := &domain.User{Username: "batchuser1", Password: "hash", Status: domain.UserStatusActive} + u2 := &domain.User{Username: "batchuser2", Password: "hash", Status: domain.UserStatusActive} + u3 := &domain.User{Username: "batchuser3", Password: "hash", Status: domain.UserStatusActive} + repo.Create(ctx, u1) + repo.Create(ctx, u2) + repo.Create(ctx, u3) + + users, err := repo.GetByIDs(ctx, []int64{u1.ID, u3.ID}) + if err != nil { + t.Fatalf("GetByIDs() error = %v", err) + } + if len(users) != 2 { + t.Errorf("len(users) = %d, want 2", len(users)) + } +} + +// TestUserRepository_GetByIDs_Empty 测试空ID列表 +func TestUserRepository_GetByIDs_Empty(t *testing.T) { + db := setupTestDB(t) + repo := NewUserRepository(db) + ctx := context.Background() + + users, err := repo.GetByIDs(ctx, []int64{}) + if err != nil { + t.Fatalf("GetByIDs() error = %v", err) + } + if len(users) != 0 { + t.Errorf("len(users) = %d, want 0", len(users)) + } +} + +// TestUserRepository_UpdatePassword 测试更新密码 +func TestUserRepository_UpdatePassword(t *testing.T) { + db := setupTestDB(t) + repo := NewUserRepository(db) + ctx := context.Background() + + user := &domain.User{ + Username: "pwduser", + Password: "oldpassword", + Status: domain.UserStatusActive, + } + repo.Create(ctx, user) + + err := repo.UpdatePassword(ctx, user.ID, "newpasswordhash") + if err != nil { + t.Fatalf("UpdatePassword() error = %v", err) + } + + found, _ := repo.GetByID(ctx, user.ID) + if found.Password != "newpasswordhash" { + t.Errorf("Password = %v, want newpasswordhash", found.Password) + } +} + +// TestUserRepository_UpdateTOTP 测试更新TOTP +func TestUserRepository_UpdateTOTP(t *testing.T) { + db := setupTestDB(t) + repo := NewUserRepository(db) + ctx := context.Background() + + user := &domain.User{ + Username: "totpuser", + Password: "hash", + Status: domain.UserStatusActive, + } + repo.Create(ctx, user) + + user.TOTPEnabled = true + user.TOTPSecret = "JBSWY3DPEHPK3PXP" + err := repo.UpdateTOTP(ctx, user) + if err != nil { + t.Fatalf("UpdateTOTP() error = %v", err) + } + + found, _ := repo.GetByID(ctx, user.ID) + if !found.TOTPEnabled { + t.Error("TOTPEnabled should be true") + } + if found.TOTPSecret != "JBSWY3DPEHPK3PXP" { + t.Errorf("TOTPSecret = %v, want JBSWY3DPEHPK3PXP", found.TOTPSecret) + } +} + +// TestUserRepository_ListCreatedAfter 测试查询创建时间之后的用户 +func TestUserRepository_ListCreatedAfter(t *testing.T) { + db := setupTestDB(t) + repo := NewUserRepository(db) + ctx := context.Background() + + user := &domain.User{ + Username: "afteruser", + Password: "hash", + Status: domain.UserStatusActive, + } + repo.Create(ctx, user) + + since := user.CreatedAt.Add(-1 * time.Hour) + users, total, err := repo.ListCreatedAfter(ctx, since, 0, 10) + if err != nil { + t.Fatalf("ListCreatedAfter() error = %v", err) + } + if total < 1 { + t.Errorf("total = %d, want at least 1", total) + } + _ = users +} + +// TestUserRepository_ListCreatedAfter_Limited 测试带limit的查询 +func TestUserRepository_ListCreatedAfter_Limited(t *testing.T) { + db := setupTestDB(t) + repo := NewUserRepository(db) + ctx := context.Background() + + for i := 0; i < 5; i++ { + repo.Create(ctx, &domain.User{ + Username: "limituser" + string(rune('0'+i)), + Password: "hash", + Status: domain.UserStatusActive, + }) + } + + since := time.Now().Add(-1 * time.Hour) + users, total, err := repo.ListCreatedAfter(ctx, since, 0, 3) + if err != nil { + t.Fatalf("ListCreatedAfter() error = %v", err) + } + if len(users) != 3 { + t.Errorf("len(users) = %d, want 3", len(users)) + } + if total < 5 { + t.Errorf("total = %d, want at least 5", total) + } +} + +// TestUserRepository_AdvancedSearch 测试高级搜索 +func TestUserRepository_AdvancedSearch(t *testing.T) { + db := setupTestDB(t) + repo := NewUserRepository(db) + ctx := context.Background() + + repo.Create(ctx, &domain.User{ + Username: "searchuser1", + Nickname: "张三", + Email: domain.StrPtr("zhangsan@example.com"), + Password: "hash", + Status: domain.UserStatusActive, + }) + repo.Create(ctx, &domain.User{ + Username: "searchuser2", + Nickname: "李四", + Email: domain.StrPtr("lisi@example.com"), + Password: "hash", + Status: domain.UserStatusActive, + }) + repo.Create(ctx, &domain.User{ + Username: "searchuser3", + Nickname: "王五", + Email: domain.StrPtr("wangwu@example.com"), + Password: "hash", + Status: domain.UserStatusInactive, + }) + + // 按关键字搜索(Status=-1 表示全部状态) + filter := &AdvancedFilter{Keyword: "searchuser1", Status: -1, Offset: 0, Limit: 10} + users, total, err := repo.AdvancedSearch(ctx, filter) + if err != nil { + t.Fatalf("AdvancedSearch() error = %v", err) + } + if len(users) != 1 { + t.Errorf("len(users) = %d, want 1", len(users)) + } + if total != 1 { + t.Errorf("total = %d, want 1", total) + } + + // 按状态筛选 + filter2 := &AdvancedFilter{Status: int(domain.UserStatusActive), Offset: 0, Limit: 10} + users2, total2, err := repo.AdvancedSearch(ctx, filter2) + if err != nil { + t.Fatalf("AdvancedSearch() error = %v", err) + } + if len(users2) != 2 { + t.Errorf("len(users2) = %d, want 2", len(users2)) + } + if total2 != 2 { + t.Errorf("total2 = %d, want 2", total2) + } + + // 按状态筛选 - 禁用用户 + filter3 := &AdvancedFilter{Status: int(domain.UserStatusInactive), Offset: 0, Limit: 10} + users3, total3, err := repo.AdvancedSearch(ctx, filter3) + if err != nil { + t.Fatalf("AdvancedSearch() error = %v", err) + } + if len(users3) != 1 { + t.Errorf("len(users3) = %d, want 1", len(users3)) + } + if total3 != 1 { + t.Errorf("total3 = %d, want 1", total3) + } +} + +// TestUserRepository_AdvancedSearch_AllStatus 测试状态为-1返回全部 +func TestUserRepository_AdvancedSearch_AllStatus(t *testing.T) { + db := setupTestDB(t) + repo := NewUserRepository(db) + ctx := context.Background() + + repo.Create(ctx, &domain.User{Username: "active", Password: "hash", Status: domain.UserStatusActive}) + repo.Create(ctx, &domain.User{Username: "inactive", Password: "hash", Status: domain.UserStatusInactive}) + + filter := &AdvancedFilter{Status: -1, Offset: 0, Limit: 10} + users, total, err := repo.AdvancedSearch(ctx, filter) + if err != nil { + t.Fatalf("AdvancedSearch() error = %v", err) + } + if len(users) != 2 { + t.Errorf("len(users) = %d, want 2", len(users)) + } + if total != 2 { + t.Errorf("total = %d, want 2", total) + } +} + +// TestUserRepository_AdvancedSearch_LikeSpecialChars 测试搜索LIKE特殊字符转义 +func TestUserRepository_AdvancedSearch_LikeSpecialChars(t *testing.T) { + db := setupTestDB(t) + repo := NewUserRepository(db) + ctx := context.Background() + + repo.Create(ctx, &domain.User{ + Username: "user%with%percent", + Nickname: "测试用户", + Password: "hash", + Status: domain.UserStatusActive, + }) + + // 搜索%应该不匹配任何记录(被转义) + filter := &AdvancedFilter{Keyword: "%", Offset: 0, Limit: 10} + users, _, err := repo.AdvancedSearch(ctx, filter) + if err != nil { + t.Fatalf("AdvancedSearch() error = %v", err) + } + if len(users) != 0 { + t.Errorf("len(users) = %d, want 0 for escaped percent", len(users)) + } +} + diff --git a/internal/repository/user_role_repository_test.go b/internal/repository/user_role_repository_test.go new file mode 100644 index 0000000..69e2715 --- /dev/null +++ b/internal/repository/user_role_repository_test.go @@ -0,0 +1,36 @@ +package repository + +import ( + "context" + "testing" + + "github.com/user-management-system/internal/domain" +) + +func TestUserRoleRepository_DeleteByUserAndRole(t *testing.T) { + db := setupTestDB(t) + repo := NewUserRoleRepository(db) + ctx := context.Background() + + // 创建用户和角色 + user := &domain.User{Username: "roleuser", Password: "hash", Status: domain.UserStatusActive} + db.WithContext(ctx).Create(user) + + role := &domain.Role{Code: "test_role", Name: "测试角色", Status: domain.RoleStatusEnabled} + db.WithContext(ctx).Create(role) + + // 创建用户角色关联 + repo.Create(ctx, &domain.UserRole{UserID: user.ID, RoleID: role.ID}) + + // 删除特定用户-角色关联 + err := repo.DeleteByUserAndRole(ctx, user.ID, role.ID) + if err != nil { + t.Fatalf("DeleteByUserAndRole() error = %v", err) + } + + // 验证已删除 + exists, _ := repo.Exists(ctx, user.ID, role.ID) + if exists { + t.Error("DeleteByUserAndRole should have removed the association") + } +} diff --git a/internal/repository/webhook_repository_test.go b/internal/repository/webhook_repository_test.go index 08508f3..5cf1fb3 100644 --- a/internal/repository/webhook_repository_test.go +++ b/internal/repository/webhook_repository_test.go @@ -188,3 +188,40 @@ func TestWebhookRepositoryCreateAndListDeliveries(t *testing.T) { t.Fatal("expected deliveries to be returned in reverse created_at order") } } + +func TestWebhookRepositoryListByCreatorPaginated(t *testing.T) { + repo := setupWebhookRepository(t) + ctx := context.Background() + + // 创建多个webhook + for i := 0; i < 5; i++ { + if err := repo.Create(ctx, newWebhookFixture("wh-creator1-"+string(rune('a'+i)), 1, domain.WebhookStatusActive)); err != nil { + t.Fatalf("Create failed: %v", err) + } + } + // 另一个用户的webhook + if err := repo.Create(ctx, newWebhookFixture("wh-creator2", 2, domain.WebhookStatusActive)); err != nil { + t.Fatalf("Create failed: %v", err) + } + + // 测试分页查询创建者1的webhook + webhooks, total, err := repo.ListByCreatorPaginated(ctx, 1, 0, 3) + if err != nil { + t.Fatalf("ListByCreatorPaginated failed: %v", err) + } + if len(webhooks) != 3 { + t.Errorf("len(webhooks) = %d, want 3", len(webhooks)) + } + if total != 5 { + t.Errorf("total = %d, want 5", total) + } + + // 测试第二页 + webhooks2, _, err := repo.ListByCreatorPaginated(ctx, 1, 3, 3) + if err != nil { + t.Fatalf("ListByCreatorPaginated page 2 failed: %v", err) + } + if len(webhooks2) != 2 { + t.Errorf("len(webhooks2) = %d, want 2", len(webhooks2)) + } +}