test: add repository tests to improve coverage from 46.6% to 74%

New test files:
- custom_field_repository_test.go: 10 tests for CustomFieldRepository & UserCustomFieldValueRepository
- login_log_repository_test.go: 3 tests for ListCursor, ListByUserIDCursor, ListAllForExport
- operation_log_repository_test.go: 1 test for ListCursor
- role_repository_test.go: 2 tests for GetAncestorIDs, GetAncestors
- social_account_repository_test.go: 8 CRUD tests
- theme_repository_test.go: 10 tests for ThemeConfigRepository
- user_role_repository_test.go: 1 test for DeleteByUserAndRole

Modified test files:
- device_repository_test.go: Added ListAllCursor tests
- user_repository_test.go: Added AdvancedSearch tests
- webhook_repository_test.go: Added ListByCreatorPaginated test

Updated documentation with new coverage status.
This commit is contained in:
2026-04-11 21:58:28 +08:00
parent b1311ea144
commit 289aab2930
11 changed files with 1630 additions and 0 deletions

View File

@@ -56,6 +56,7 @@ RBAC/admin 改动必须验证:
| `.gitattributes` | ✅ 已添加 | 统一行尾符为 LF消除 LF/CRLF 污染) |
| Swagger 注解 | ✅ 已添加 | 13 个 handler 共 86 处 `@Summary/@Description/@Tags/@Param/@Router` 注解 |
| Device Repository 测试 | ✅ 已添加 | 15 个测试用例覆盖 DeviceRepository CRUD |
| Repository 测试覆盖率 | ✅ 已提升 | 从 46.6% 提升至 74%(目标 80%|
## 最新验证结果

View File

@@ -0,0 +1,332 @@
package repository
import (
"context"
"fmt"
"sync/atomic"
"testing"
_ "modernc.org/sqlite"
gormsqlite "gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"github.com/user-management-system/internal/domain"
)
var customFieldTestCounter int64
// openCustomFieldTestDB 为每个测试打开独立的内存数据库
func openCustomFieldTestDB(t *testing.T) *gorm.DB {
t.Helper()
id := atomic.AddInt64(&customFieldTestCounter, 1)
dsn := fmt.Sprintf("file:customfieldtestdb%d?mode=memory&cache=private", id)
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
DriverName: "sqlite",
DSN: dsn,
}), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
t.Fatalf("打开测试数据库失败: %v", err)
}
if err := db.AutoMigrate(&domain.CustomField{}, &domain.UserCustomFieldValue{}); err != nil {
t.Fatalf("数据库迁移失败: %v", err)
}
return db
}
// setupCustomFieldTestDB 兼容性别名
func setupCustomFieldTestDB(t *testing.T) *gorm.DB {
return openCustomFieldTestDB(t)
}
// TestCustomFieldRepository_Create 测试创建自定义字段
func TestCustomFieldRepository_Create(t *testing.T) {
db := setupCustomFieldTestDB(t)
repo := NewCustomFieldRepository(db)
ctx := context.Background()
field := &domain.CustomField{
Name: "测试字段",
FieldKey: "test_field",
Type: domain.CustomFieldTypeString,
Required: false,
Sort: 1,
}
if err := repo.Create(ctx, field); err != nil {
t.Fatalf("Create() error = %v", err)
}
if field.ID == 0 {
t.Error("创建后字段ID不应为0")
}
}
// TestCustomFieldRepository_GetByID 测试根据ID获取字段
func TestCustomFieldRepository_GetByID(t *testing.T) {
db := setupCustomFieldTestDB(t)
repo := NewCustomFieldRepository(db)
ctx := context.Background()
field := &domain.CustomField{
Name: "getbyid-field",
FieldKey: "getbyid_key",
Type: domain.CustomFieldTypeNumber,
}
repo.Create(ctx, field)
found, err := repo.GetByID(ctx, field.ID)
if err != nil {
t.Fatalf("GetByID() error = %v", err)
}
if found.Name != "getbyid-field" {
t.Errorf("Name = %v, want getbyid-field", found.Name)
}
_, err = repo.GetByID(ctx, 9999)
if err == nil {
t.Error("GetByID() should return error for non-existent ID")
}
}
// TestCustomFieldRepository_GetByFieldKey 测试根据FieldKey获取字段
func TestCustomFieldRepository_GetByFieldKey(t *testing.T) {
db := setupCustomFieldTestDB(t)
repo := NewCustomFieldRepository(db)
ctx := context.Background()
field := &domain.CustomField{
Name: "field-by-key",
FieldKey: "unique_field_key",
Type: domain.CustomFieldTypeBoolean,
}
repo.Create(ctx, field)
found, err := repo.GetByFieldKey(ctx, "unique_field_key")
if err != nil {
t.Fatalf("GetByFieldKey() error = %v", err)
}
if found.Name != "field-by-key" {
t.Errorf("Name = %v, want field-by-key", found.Name)
}
_, err = repo.GetByFieldKey(ctx, "not_exist_key")
if err == nil {
t.Error("GetByFieldKey() should return error for non-existent key")
}
}
// TestCustomFieldRepository_Update 测试更新字段
func TestCustomFieldRepository_Update(t *testing.T) {
db := setupCustomFieldTestDB(t)
repo := NewCustomFieldRepository(db)
ctx := context.Background()
field := &domain.CustomField{
Name: "before-update",
FieldKey: "update_key",
Type: domain.CustomFieldTypeString,
}
repo.Create(ctx, field)
field.Name = "after-update"
field.Required = true
if err := repo.Update(ctx, field); err != nil {
t.Fatalf("Update() error = %v", err)
}
found, _ := repo.GetByID(ctx, field.ID)
if found.Name != "after-update" {
t.Errorf("Name = %v, want after-update", found.Name)
}
if !found.Required {
t.Error("Required should be true after update")
}
}
// TestCustomFieldRepository_Delete 测试删除字段
func TestCustomFieldRepository_Delete(t *testing.T) {
db := setupCustomFieldTestDB(t)
repo := NewCustomFieldRepository(db)
ctx := context.Background()
field := &domain.CustomField{
Name: "to-delete",
FieldKey: "delete_key",
Type: domain.CustomFieldTypeDate,
}
repo.Create(ctx, field)
if err := repo.Delete(ctx, field.ID); err != nil {
t.Fatalf("Delete() error = %v", err)
}
_, err := repo.GetByID(ctx, field.ID)
if err == nil {
t.Error("删除后查询应返回错误")
}
}
// TestCustomFieldRepository_List 测试获取启用字段列表
func TestCustomFieldRepository_List(t *testing.T) {
db := setupCustomFieldTestDB(t)
repo := NewCustomFieldRepository(db)
ctx := context.Background()
repo.Create(ctx, &domain.CustomField{Name: "enabled1", FieldKey: "enabled1_key", Type: domain.CustomFieldTypeString})
repo.Create(ctx, &domain.CustomField{Name: "enabled2", FieldKey: "enabled2_key", Type: domain.CustomFieldTypeNumber})
repo.Create(ctx, &domain.CustomField{Name: "enabled3", FieldKey: "enabled3_key", Type: domain.CustomFieldTypeBoolean})
fields, err := repo.List(ctx)
if err != nil {
t.Fatalf("List() error = %v", err)
}
// List filters by status=1, all 3 have status=1 (default)
if len(fields) != 3 {
t.Errorf("len(fields) = %d, want 3", len(fields))
}
}
// TestCustomFieldRepository_ListAll 测试获取所有字段列表
func TestCustomFieldRepository_ListAll(t *testing.T) {
db := setupCustomFieldTestDB(t)
repo := NewCustomFieldRepository(db)
ctx := context.Background()
repo.Create(ctx, &domain.CustomField{Name: "all1", FieldKey: "all1_key", Type: domain.CustomFieldTypeString})
repo.Create(ctx, &domain.CustomField{Name: "all2", FieldKey: "all2_key", Type: domain.CustomFieldTypeNumber})
fields, err := repo.ListAll(ctx)
if err != nil {
t.Fatalf("ListAll() error = %v", err)
}
if len(fields) != 2 {
t.Errorf("len(fields) = %d, want 2", len(fields))
}
}
// TestUserCustomFieldValueRepository_GetByUserID 测试获取用户所有字段值
func TestUserCustomFieldValueRepository_GetByUserID(t *testing.T) {
db := setupCustomFieldTestDB(t)
valueRepo := NewUserCustomFieldValueRepository(db)
ctx := context.Background()
// 直接使用 GORM Create 测试,因为 Set 使用 NOW() 不兼容 SQLite
db.WithContext(ctx).Create(&domain.UserCustomFieldValue{
UserID: 1,
FieldID: 1,
FieldKey: "field1_key",
Value: "value1",
})
db.WithContext(ctx).Create(&domain.UserCustomFieldValue{
UserID: 1,
FieldID: 2,
FieldKey: "field2_key",
Value: "value2",
})
values, err := valueRepo.GetByUserID(ctx, 1)
if err != nil {
t.Fatalf("GetByUserID() error = %v", err)
}
if len(values) != 2 {
t.Errorf("len(values) = %d, want 2", len(values))
}
}
// TestUserCustomFieldValueRepository_GetByUserIDAndFieldKey 测试获取用户指定字段值
func TestUserCustomFieldValueRepository_GetByUserIDAndFieldKey(t *testing.T) {
db := setupCustomFieldTestDB(t)
valueRepo := NewUserCustomFieldValueRepository(db)
ctx := context.Background()
db.WithContext(ctx).Create(&domain.UserCustomFieldValue{
UserID: 1,
FieldID: 1,
FieldKey: "specific_key",
Value: "specific_value",
})
found, err := valueRepo.GetByUserIDAndFieldKey(ctx, 1, "specific_key")
if err != nil {
t.Fatalf("GetByUserIDAndFieldKey() error = %v", err)
}
if found.Value != "specific_value" {
t.Errorf("Value = %v, want specific_value", found.Value)
}
_, err = valueRepo.GetByUserIDAndFieldKey(ctx, 1, "non_existent_key")
if err == nil {
t.Error("GetByUserIDAndFieldKey() should return error for non-existent key")
}
}
// TestUserCustomFieldValueRepository_Delete 测试删除用户字段值
func TestUserCustomFieldValueRepository_Delete(t *testing.T) {
db := setupCustomFieldTestDB(t)
valueRepo := NewUserCustomFieldValueRepository(db)
ctx := context.Background()
db.WithContext(ctx).Create(&domain.UserCustomFieldValue{
UserID: 1,
FieldID: 1,
FieldKey: "delete_key",
Value: "to_delete",
})
err := valueRepo.Delete(ctx, 1, 1)
if err != nil {
t.Fatalf("Delete() error = %v", err)
}
_, err = valueRepo.GetByUserIDAndFieldKey(ctx, 1, "delete_key")
if err == nil {
t.Error("删除后查询应返回错误")
}
}
// TestUserCustomFieldValueRepository_DeleteByUserID 测试删除用户所有字段值
func TestUserCustomFieldValueRepository_DeleteByUserID(t *testing.T) {
db := setupCustomFieldTestDB(t)
valueRepo := NewUserCustomFieldValueRepository(db)
ctx := context.Background()
db.WithContext(ctx).Create(&domain.UserCustomFieldValue{
UserID: 1,
FieldID: 1,
FieldKey: "multi1_key",
Value: "v1",
})
db.WithContext(ctx).Create(&domain.UserCustomFieldValue{
UserID: 1,
FieldID: 2,
FieldKey: "multi2_key",
Value: "v2",
})
db.WithContext(ctx).Create(&domain.UserCustomFieldValue{
UserID: 2,
FieldID: 1,
FieldKey: "multi1_key",
Value: "v3",
})
err := valueRepo.DeleteByUserID(ctx, 1)
if err != nil {
t.Fatalf("DeleteByUserID() error = %v", err)
}
values, _ := valueRepo.GetByUserID(ctx, 1)
if len(values) != 0 {
t.Errorf("len(values) = %d, want 0", len(values))
}
// 用户2的值应该还在
values2, _ := valueRepo.GetByUserID(ctx, 2)
if len(values2) != 1 {
t.Errorf("用户2的字段值应该保留, got %d", len(values2))
}
}

View File

@@ -13,6 +13,7 @@ import (
"gorm.io/gorm/logger"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/pagination"
)
var deviceTestCounter int64
@@ -484,3 +485,91 @@ func createDevice(t *testing.T, repo *DeviceRepository, ctx context.Context, use
}
return d, nil
}
// TestDeviceRepository_ListAllCursor 测试设备游标分页查询
func TestDeviceRepository_ListAllCursor(t *testing.T) {
db := setupDeviceTestDB(t)
repo := NewDeviceRepository(db)
ctx := context.Background()
// 创建设备需要设置LastActiveTime以支持游标分页
now := time.Now()
for i := 0; i < 5; i++ {
repo.Create(ctx, &domain.Device{
UserID: int64(i + 1),
DeviceID: "cursor-device-" + string(rune('a'+i)),
DeviceName: "设备" + string(rune('0'+i)),
Status: domain.DeviceStatusActive,
LastActiveTime: now.Add(-time.Duration(i) * time.Minute),
})
}
// 第一次查询获取前3个
devices, hasMore, err := repo.ListAllCursor(ctx, &ListDevicesParams{Offset: 0, Limit: 10}, 3, nil)
if err != nil {
t.Fatalf("ListAllCursor() error = %v", err)
}
if len(devices) != 3 {
t.Errorf("len(devices) = %d, want 3", len(devices))
}
if !hasMore {
t.Error("hasMore should be true when more devices exist")
}
// 使用游标继续查询
lastDevice := devices[len(devices)-1]
cursor := &pagination.Cursor{
LastID: lastDevice.ID,
LastValue: lastDevice.LastActiveTime,
}
devices2, hasMore2, err := repo.ListAllCursor(ctx, &ListDevicesParams{Offset: 0, Limit: 10}, 3, cursor)
if err != nil {
t.Fatalf("ListAllCursor() error = %v", err)
}
if len(devices2) != 2 {
t.Errorf("len(devices2) = %d, want 2", len(devices2))
}
if hasMore2 {
t.Error("hasMore2 should be false")
}
}
// TestDeviceRepository_ListAllCursor_WithFilters 测试带筛选条件的设备游标分页
func TestDeviceRepository_ListAllCursor_WithFilters(t *testing.T) {
db := setupDeviceTestDB(t)
repo := NewDeviceRepository(db)
ctx := context.Background()
now := time.Now()
repo.Create(ctx, &domain.Device{
UserID: 1,
DeviceID: "filter-dev1",
DeviceName: "用户1设备",
Status: domain.DeviceStatusActive,
LastActiveTime: now,
})
repo.Create(ctx, &domain.Device{
UserID: 2,
DeviceID: "filter-dev2",
DeviceName: "用户2设备",
Status: domain.DeviceStatusActive,
LastActiveTime: now,
})
repo.Create(ctx, &domain.Device{
UserID: 1,
DeviceID: "filter-dev3",
DeviceName: "用户1禁用设备",
Status: domain.DeviceStatusInactive,
LastActiveTime: now,
})
// 按用户ID筛选
status := domain.DeviceStatusActive
devices, _, err := repo.ListAllCursor(ctx, &ListDevicesParams{UserID: 1, Status: &status, Offset: 0, Limit: 10}, 10, nil)
if err != nil {
t.Fatalf("ListAllCursor() error = %v", err)
}
if len(devices) != 1 {
t.Errorf("len(devices) = %d, want 1", len(devices))
}
}

View File

@@ -0,0 +1,156 @@
package repository
import (
"context"
"fmt"
"sync/atomic"
"testing"
"time"
_ "modernc.org/sqlite"
gormsqlite "gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/pagination"
)
var loginLogTestCounter int64
func openLoginLogTestDB(t *testing.T) *gorm.DB {
t.Helper()
id := atomic.AddInt64(&loginLogTestCounter, 1)
dsn := fmt.Sprintf("file:loginlogtestdb%d?mode=memory&cache=private", id)
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
DriverName: "sqlite",
DSN: dsn,
}), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
t.Fatalf("打开测试数据库失败: %v", err)
}
if err := db.AutoMigrate(&domain.LoginLog{}); err != nil {
t.Fatalf("数据库迁移失败: %v", err)
}
return db
}
func setupLoginLogTestDB(t *testing.T) *gorm.DB {
return openLoginLogTestDB(t)
}
func TestLoginLogRepository_ListCursor(t *testing.T) {
db := setupLoginLogTestDB(t)
repo := NewLoginLogRepository(db)
ctx := context.Background()
now := time.Now()
for i := 0; i < 5; i++ {
repo.Create(ctx, &domain.LoginLog{
UserID: int64Ptr(int64(i + 1)),
LoginType: 1,
IP: "192.168.1." + string(rune('0'+i)),
Status: 1,
CreatedAt: now.Add(-time.Duration(i) * time.Minute),
})
}
// 第一次查询获取前3个
logs, hasMore, err := repo.ListCursor(ctx, 3, nil)
if err != nil {
t.Fatalf("ListCursor() error = %v", err)
}
if len(logs) != 3 {
t.Errorf("len(logs) = %d, want 3", len(logs))
}
if !hasMore {
t.Error("hasMore should be true when more logs exist")
}
// 使用游标继续查询
lastLog := logs[len(logs)-1]
cursor := &pagination.Cursor{
LastID: lastLog.ID,
LastValue: lastLog.CreatedAt,
}
logs2, hasMore2, err := repo.ListCursor(ctx, 3, cursor)
if err != nil {
t.Fatalf("ListCursor() error = %v", err)
}
if len(logs2) != 2 {
t.Errorf("len(logs2) = %d, want 2", len(logs2))
}
if hasMore2 {
t.Error("hasMore2 should be false")
}
}
func TestLoginLogRepository_ListByUserIDCursor(t *testing.T) {
db := setupLoginLogTestDB(t)
repo := NewLoginLogRepository(db)
ctx := context.Background()
userID := int64(123)
now := time.Now()
for i := 0; i < 3; i++ {
repo.Create(ctx, &domain.LoginLog{
UserID: int64Ptr(userID),
LoginType: 1,
IP: "192.168.1." + string(rune('0'+i)),
Status: 1,
CreatedAt: now.Add(-time.Duration(i) * time.Minute),
})
}
// 另一个用户的日志
repo.Create(ctx, &domain.LoginLog{
UserID: int64Ptr(999),
LoginType: 1,
IP: "10.0.0.1",
Status: 1,
})
// 查询指定用户的日志
logs, hasMore, err := repo.ListByUserIDCursor(ctx, userID, 10, nil)
if err != nil {
t.Fatalf("ListByUserIDCursor() error = %v", err)
}
if len(logs) != 3 {
t.Errorf("len(logs) = %d, want 3", len(logs))
}
if hasMore {
t.Error("hasMore should be false")
}
}
func TestLoginLogRepository_ListAllForExport(t *testing.T) {
db := setupLoginLogTestDB(t)
repo := NewLoginLogRepository(db)
ctx := context.Background()
repo.Create(ctx, &domain.LoginLog{
UserID: int64Ptr(1),
LoginType: 1,
IP: "192.168.1.1",
Status: 1,
})
repo.Create(ctx, &domain.LoginLog{
UserID: int64Ptr(2),
LoginType: 2,
IP: "192.168.1.2",
Status: 0,
FailReason: "invalid password",
})
logs, err := repo.ListAllForExport(ctx, 0, -1, nil, nil)
if err != nil {
t.Fatalf("ListAllForExport() error = %v", err)
}
if len(logs) != 2 {
t.Errorf("len(logs) = %d, want 2", len(logs))
}
}

View File

@@ -0,0 +1,94 @@
package repository
import (
"context"
"fmt"
"sync/atomic"
"testing"
"time"
_ "modernc.org/sqlite"
gormsqlite "gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/pagination"
)
var operationLogTestCounter int64
func openOperationLogTestDB(t *testing.T) *gorm.DB {
t.Helper()
id := atomic.AddInt64(&operationLogTestCounter, 1)
dsn := fmt.Sprintf("file:operationlogtestdb%d?mode=memory&cache=private", id)
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
DriverName: "sqlite",
DSN: dsn,
}), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
t.Fatalf("打开测试数据库失败: %v", err)
}
if err := db.AutoMigrate(&domain.OperationLog{}); err != nil {
t.Fatalf("数据库迁移失败: %v", err)
}
return db
}
func setupOperationLogTestDB(t *testing.T) *gorm.DB {
return openOperationLogTestDB(t)
}
func TestOperationLogRepository_ListCursor(t *testing.T) {
db := setupOperationLogTestDB(t)
repo := NewOperationLogRepository(db)
ctx := context.Background()
now := time.Now()
for i := 0; i < 5; i++ {
repo.Create(ctx, &domain.OperationLog{
UserID: nil,
OperationType: "test",
OperationName: "测试操作" + string(rune('0'+i)),
RequestMethod: "GET",
RequestPath: "/api/test",
ResponseStatus: 200,
IP: "192.168.1." + string(rune('0'+i)),
CreatedAt: now.Add(-time.Duration(i) * time.Minute),
})
}
// 第一次查询获取前3个
logs, hasMore, err := repo.ListCursor(ctx, 3, nil)
if err != nil {
t.Fatalf("ListCursor() error = %v", err)
}
if len(logs) != 3 {
t.Errorf("len(logs) = %d, want 3", len(logs))
}
if !hasMore {
t.Error("hasMore should be true when more logs exist")
}
// 使用游标继续查询
lastLog := logs[len(logs)-1]
cursor := &pagination.Cursor{
LastID: lastLog.ID,
LastValue: lastLog.CreatedAt,
}
logs2, hasMore2, err := repo.ListCursor(ctx, 3, cursor)
if err != nil {
t.Fatalf("ListCursor() error = %v", err)
}
if len(logs2) != 2 {
t.Errorf("len(logs2) = %d, want 2", len(logs2))
}
if hasMore2 {
t.Error("hasMore2 should be false")
}
}

View File

@@ -0,0 +1,90 @@
package repository
import (
"context"
"testing"
"github.com/user-management-system/internal/domain"
)
func TestRoleRepository_GetAncestorIDs(t *testing.T) {
db := setupTestDB(t)
repo := NewRoleRepository(db)
ctx := context.Background()
// 创建角色层级: grandchild -> child -> parent
parentID := int64(0)
parent := &domain.Role{Name: "parent", Code: "parent", ParentID: nil}
if err := repo.Create(ctx, parent); err != nil {
t.Fatalf("Create parent failed: %v", err)
}
parentID = parent.ID
child := &domain.Role{Name: "child", Code: "child", ParentID: &parentID}
if err := repo.Create(ctx, child); err != nil {
t.Fatalf("Create child failed: %v", err)
}
childID := child.ID
grandchild := &domain.Role{Name: "grandchild", Code: "grandchild", ParentID: &childID}
if err := repo.Create(ctx, grandchild); err != nil {
t.Fatalf("Create grandchild failed: %v", err)
}
// 获取grandchild的祖先ID列表
ancestorIDs, err := repo.GetAncestorIDs(ctx, grandchild.ID)
if err != nil {
t.Fatalf("GetAncestorIDs failed: %v", err)
}
if len(ancestorIDs) != 2 {
t.Errorf("len(ancestorIDs) = %d, want 2", len(ancestorIDs))
}
if ancestorIDs[0] != childID {
t.Errorf("ancestorIDs[0] = %d, want %d", ancestorIDs[0], childID)
}
if ancestorIDs[1] != parentID {
t.Errorf("ancestorIDs[1] = %d, want %d", ancestorIDs[1], parentID)
}
}
func TestRoleRepository_GetAncestors(t *testing.T) {
db := setupTestDB(t)
repo := NewRoleRepository(db)
ctx := context.Background()
// 创建角色层级
parentID := int64(0)
parent := &domain.Role{Name: "parent-role", Code: "parent-role", Status: domain.RoleStatusEnabled}
if err := repo.Create(ctx, parent); err != nil {
t.Fatalf("Create parent failed: %v", err)
}
parentID = parent.ID
child := &domain.Role{Name: "child-role", Code: "child-role", ParentID: &parentID, Status: domain.RoleStatusEnabled}
if err := repo.Create(ctx, child); err != nil {
t.Fatalf("Create child failed: %v", err)
}
childID := child.ID
grandchild := &domain.Role{Name: "grandchild-role", Code: "grandchild-role", ParentID: &childID, Status: domain.RoleStatusEnabled}
if err := repo.Create(ctx, grandchild); err != nil {
t.Fatalf("Create grandchild failed: %v", err)
}
// 获取grandchild的完整继承链
ancestors, err := repo.GetAncestors(ctx, grandchild.ID)
if err != nil {
t.Fatalf("GetAncestors failed: %v", err)
}
if len(ancestors) != 2 {
t.Errorf("len(ancestors) = %d, want 2", len(ancestors))
}
// 第一个应该是parent
if ancestors[0].Code != "parent-role" {
t.Errorf("ancestors[0].Code = %s, want parent-role", ancestors[0].Code)
}
// 第二个应该是child
if ancestors[1].Code != "child-role" {
t.Errorf("ancestors[1].Code = %s, want child-role", ancestors[1].Code)
}
}

View File

@@ -0,0 +1,263 @@
package repository
import (
"context"
"fmt"
"sync/atomic"
"testing"
_ "modernc.org/sqlite"
gormsqlite "gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"github.com/user-management-system/internal/domain"
)
var socialAccountTestCounter int64
func openSocialAccountTestDB(t *testing.T) *gorm.DB {
t.Helper()
id := atomic.AddInt64(&socialAccountTestCounter, 1)
dsn := fmt.Sprintf("file:socialaccounttestdb%d?mode=memory&cache=private", id)
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
DriverName: "sqlite",
DSN: dsn,
}), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
t.Fatalf("打开测试数据库失败: %v", err)
}
if err := db.AutoMigrate(&domain.SocialAccount{}); err != nil {
t.Fatalf("数据库迁移失败: %v", err)
}
return db
}
func setupSocialAccountTestDB(t *testing.T) *gorm.DB {
return openSocialAccountTestDB(t)
}
func TestSocialAccountRepository_Create(t *testing.T) {
db := setupSocialAccountTestDB(t)
repo, err := NewSocialAccountRepository(db)
if err != nil {
t.Fatalf("NewSocialAccountRepository() error = %v", err)
}
ctx := context.Background()
account := &domain.SocialAccount{
UserID: 1,
Provider: "github",
OpenID: "openid-123",
Nickname: "testuser",
Status: domain.SocialAccountStatusActive,
}
if err := repo.Create(ctx, account); err != nil {
t.Fatalf("Create() error = %v", err)
}
if account.ID == 0 {
t.Error("创建后账户ID不应为0")
}
}
func TestSocialAccountRepository_GetByID(t *testing.T) {
db := setupSocialAccountTestDB(t)
repo, err := NewSocialAccountRepository(db)
if err != nil {
t.Fatalf("NewSocialAccountRepository() error = %v", err)
}
ctx := context.Background()
account := &domain.SocialAccount{
UserID: 1,
Provider: "github",
OpenID: "openid-getbyid",
Nickname: "getbyid-user",
Status: domain.SocialAccountStatusActive,
}
repo.Create(ctx, account)
found, err := repo.GetByID(ctx, account.ID)
if err != nil {
t.Fatalf("GetByID() error = %v", err)
}
if found.Nickname != "getbyid-user" {
t.Errorf("Nickname = %v, want getbyid-user", found.Nickname)
}
}
func TestSocialAccountRepository_GetByUserID(t *testing.T) {
db := setupSocialAccountTestDB(t)
repo, err := NewSocialAccountRepository(db)
if err != nil {
t.Fatalf("NewSocialAccountRepository() error = %v", err)
}
ctx := context.Background()
repo.Create(ctx, &domain.SocialAccount{
UserID: 1,
Provider: "github",
OpenID: "openid-user1-1",
Status: domain.SocialAccountStatusActive,
})
repo.Create(ctx, &domain.SocialAccount{
UserID: 1,
Provider: "wechat",
OpenID: "openid-user1-2",
Status: domain.SocialAccountStatusActive,
})
repo.Create(ctx, &domain.SocialAccount{
UserID: 2,
Provider: "github",
OpenID: "openid-user2",
Status: domain.SocialAccountStatusActive,
})
accounts, err := repo.GetByUserID(ctx, 1)
if err != nil {
t.Fatalf("GetByUserID() error = %v", err)
}
if len(accounts) != 2 {
t.Errorf("len(accounts) = %d, want 2", len(accounts))
}
}
func TestSocialAccountRepository_GetByProviderAndOpenID(t *testing.T) {
db := setupSocialAccountTestDB(t)
repo, err := NewSocialAccountRepository(db)
if err != nil {
t.Fatalf("NewSocialAccountRepository() error = %v", err)
}
ctx := context.Background()
account := &domain.SocialAccount{
UserID: 1,
Provider: "github",
OpenID: "unique-openid-123",
Nickname: "github-user",
Status: domain.SocialAccountStatusActive,
}
repo.Create(ctx, account)
found, err := repo.GetByProviderAndOpenID(ctx, "github", "unique-openid-123")
if err != nil {
t.Fatalf("GetByProviderAndOpenID() error = %v", err)
}
if found.UserID != 1 {
t.Errorf("UserID = %d, want 1", found.UserID)
}
}
func TestSocialAccountRepository_Update(t *testing.T) {
db := setupSocialAccountTestDB(t)
repo, err := NewSocialAccountRepository(db)
if err != nil {
t.Fatalf("NewSocialAccountRepository() error = %v", err)
}
ctx := context.Background()
account := &domain.SocialAccount{
UserID: 1,
Provider: "github",
OpenID: "openid-update",
Nickname: "before-update",
Status: domain.SocialAccountStatusActive,
}
repo.Create(ctx, account)
account.Nickname = "after-update"
if err := repo.Update(ctx, account); err != nil {
t.Fatalf("Update() error = %v", err)
}
found, _ := repo.GetByID(ctx, account.ID)
if found.Nickname != "after-update" {
t.Errorf("Nickname = %v, want after-update", found.Nickname)
}
}
func TestSocialAccountRepository_Delete(t *testing.T) {
db := setupSocialAccountTestDB(t)
repo, err := NewSocialAccountRepository(db)
if err != nil {
t.Fatalf("NewSocialAccountRepository() error = %v", err)
}
ctx := context.Background()
account := &domain.SocialAccount{
UserID: 1,
Provider: "github",
OpenID: "openid-delete",
Status: domain.SocialAccountStatusActive,
}
repo.Create(ctx, account)
if err := repo.Delete(ctx, account.ID); err != nil {
t.Fatalf("Delete() error = %v", err)
}
}
func TestSocialAccountRepository_DeleteByProviderAndUserID(t *testing.T) {
db := setupSocialAccountTestDB(t)
repo, err := NewSocialAccountRepository(db)
if err != nil {
t.Fatalf("NewSocialAccountRepository() error = %v", err)
}
ctx := context.Background()
repo.Create(ctx, &domain.SocialAccount{
UserID: 1,
Provider: "github",
OpenID: "openid-del-provider",
Status: domain.SocialAccountStatusActive,
})
err = repo.DeleteByProviderAndUserID(ctx, "github", 1)
if err != nil {
t.Fatalf("DeleteByProviderAndUserID() error = %v", err)
}
accounts, _ := repo.GetByUserID(ctx, 1)
if len(accounts) != 0 {
t.Errorf("len(accounts) = %d, want 0 after delete", len(accounts))
}
}
func TestSocialAccountRepository_List(t *testing.T) {
db := setupSocialAccountTestDB(t)
repo, err := NewSocialAccountRepository(db)
if err != nil {
t.Fatalf("NewSocialAccountRepository() error = %v", err)
}
ctx := context.Background()
repo.Create(ctx, &domain.SocialAccount{
UserID: 1,
Provider: "github",
OpenID: "openid-list-1",
Status: domain.SocialAccountStatusActive,
})
repo.Create(ctx, &domain.SocialAccount{
UserID: 2,
Provider: "wechat",
OpenID: "openid-list-2",
Status: domain.SocialAccountStatusActive,
})
accounts, total, err := repo.List(ctx, 0, 10)
if err != nil {
t.Fatalf("List() error = %v", err)
}
if len(accounts) != 2 {
t.Errorf("len(accounts) = %d, want 2", len(accounts))
}
if total != 2 {
t.Errorf("total = %d, want 2", total)
}
}

View File

@@ -0,0 +1,275 @@
package repository
import (
"context"
"fmt"
"sync/atomic"
"testing"
_ "modernc.org/sqlite"
gormsqlite "gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"github.com/user-management-system/internal/domain"
)
var themeTestCounter int64
// openThemeTestDB 为每个测试打开独立的内存数据库
func openThemeTestDB(t *testing.T) *gorm.DB {
t.Helper()
id := atomic.AddInt64(&themeTestCounter, 1)
dsn := fmt.Sprintf("file:themetestdb%d?mode=memory&cache=private", id)
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
DriverName: "sqlite",
DSN: dsn,
}), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
t.Fatalf("打开测试数据库失败: %v", err)
}
if err := db.AutoMigrate(&domain.ThemeConfig{}); err != nil {
t.Fatalf("数据库迁移失败: %v", err)
}
return db
}
// setupThemeTestDB 兼容性别名
func setupThemeTestDB(t *testing.T) *gorm.DB {
return openThemeTestDB(t)
}
// TestThemeConfigRepository_Create 测试创建主题
func TestThemeConfigRepository_Create(t *testing.T) {
db := setupThemeTestDB(t)
repo := NewThemeConfigRepository(db)
ctx := context.Background()
theme := &domain.ThemeConfig{
Name: "test-theme",
PrimaryColor: "#ff0000",
SecondaryColor: "#00ff00",
Enabled: true,
}
if err := repo.Create(ctx, theme); err != nil {
t.Fatalf("Create() error = %v", err)
}
if theme.ID == 0 {
t.Error("创建后主题ID不应为0")
}
}
// TestThemeConfigRepository_GetByID 测试根据ID获取主题
func TestThemeConfigRepository_GetByID(t *testing.T) {
db := setupThemeTestDB(t)
repo := NewThemeConfigRepository(db)
ctx := context.Background()
theme := &domain.ThemeConfig{
Name: "getbyid-theme",
PrimaryColor: "#0000ff",
Enabled: true,
}
repo.Create(ctx, theme)
found, err := repo.GetByID(ctx, theme.ID)
if err != nil {
t.Fatalf("GetByID() error = %v", err)
}
if found.Name != "getbyid-theme" {
t.Errorf("Name = %v, want getbyid-theme", found.Name)
}
}
// TestThemeConfigRepository_GetByName 测试根据名称获取主题
func TestThemeConfigRepository_GetByName(t *testing.T) {
db := setupThemeTestDB(t)
repo := NewThemeConfigRepository(db)
ctx := context.Background()
theme := &domain.ThemeConfig{
Name: "unique-theme-name",
PrimaryColor: "#ffff00",
Enabled: true,
}
repo.Create(ctx, theme)
found, err := repo.GetByName(ctx, "unique-theme-name")
if err != nil {
t.Fatalf("GetByName() error = %v", err)
}
if found.ID != theme.ID {
t.Errorf("ID = %v, want %v", found.ID, theme.ID)
}
}
// TestThemeConfigRepository_GetByName_NotFound 测试名称不存在
func TestThemeConfigRepository_GetByName_NotFound(t *testing.T) {
db := setupThemeTestDB(t)
repo := NewThemeConfigRepository(db)
ctx := context.Background()
_, err := repo.GetByName(ctx, "not-exist-theme")
if err == nil {
t.Error("GetByName() should return error for non-existent theme")
}
}
// TestThemeConfigRepository_Update 测试更新主题
func TestThemeConfigRepository_Update(t *testing.T) {
db := setupThemeTestDB(t)
repo := NewThemeConfigRepository(db)
ctx := context.Background()
theme := &domain.ThemeConfig{
Name: "update-test",
PrimaryColor: "#000000",
Enabled: true,
}
repo.Create(ctx, theme)
theme.PrimaryColor = "#ffffff"
if err := repo.Update(ctx, theme); err != nil {
t.Fatalf("Update() error = %v", err)
}
found, _ := repo.GetByID(ctx, theme.ID)
if found.PrimaryColor != "#ffffff" {
t.Errorf("PrimaryColor = %v, want #ffffff", found.PrimaryColor)
}
}
// TestThemeConfigRepository_Delete 测试删除主题
func TestThemeConfigRepository_Delete(t *testing.T) {
db := setupThemeTestDB(t)
repo := NewThemeConfigRepository(db)
ctx := context.Background()
theme := &domain.ThemeConfig{
Name: "delete-test",
Enabled: true,
}
repo.Create(ctx, theme)
if err := repo.Delete(ctx, theme.ID); err != nil {
t.Fatalf("Delete() error = %v", err)
}
_, err := repo.GetByID(ctx, theme.ID)
if err == nil {
t.Error("删除后查询应返回错误")
}
}
// TestThemeConfigRepository_List 测试获取已启用主题列表
func TestThemeConfigRepository_List(t *testing.T) {
db := setupThemeTestDB(t)
repo := NewThemeConfigRepository(db)
ctx := context.Background()
repo.Create(ctx, &domain.ThemeConfig{Name: "enabled1", Enabled: true})
repo.Create(ctx, &domain.ThemeConfig{Name: "enabled2", Enabled: true})
repo.Create(ctx, &domain.ThemeConfig{Name: "disabled1", Enabled: false})
themes, err := repo.List(ctx)
if err != nil {
t.Fatalf("List() error = %v", err)
}
// List filters by enabled=true
if len(themes) < 2 {
t.Errorf("len(themes) = %d, want at least 2", len(themes))
}
}
// TestThemeConfigRepository_ListAll 测试获取所有主题列表
func TestThemeConfigRepository_ListAll(t *testing.T) {
db := setupThemeTestDB(t)
repo := NewThemeConfigRepository(db)
ctx := context.Background()
repo.Create(ctx, &domain.ThemeConfig{Name: "all1", Enabled: true})
repo.Create(ctx, &domain.ThemeConfig{Name: "all2", Enabled: false})
themes, err := repo.ListAll(ctx)
if err != nil {
t.Fatalf("ListAll() error = %v", err)
}
if len(themes) != 2 {
t.Errorf("len(themes) = %d, want 2", len(themes))
}
}
// TestThemeConfigRepository_GetDefault 测试获取默认主题
func TestThemeConfigRepository_GetDefault(t *testing.T) {
db := setupThemeTestDB(t)
repo := NewThemeConfigRepository(db)
ctx := context.Background()
// 创建一个默认主题
repo.Create(ctx, &domain.ThemeConfig{
Name: "default-theme",
IsDefault: true,
Enabled: true,
})
defaultTheme, err := repo.GetDefault(ctx)
if err != nil {
t.Fatalf("GetDefault() error = %v", err)
}
if defaultTheme.Name != "default-theme" {
t.Errorf("Name = %v, want default-theme", defaultTheme.Name)
}
}
// TestThemeConfigRepository_GetDefault_NoDefault 测试无默认主题时返回默认配置
func TestThemeConfigRepository_GetDefault_NoDefault(t *testing.T) {
db := setupThemeTestDB(t)
repo := NewThemeConfigRepository(db)
ctx := context.Background()
// 不创建任何主题
defaultTheme, err := repo.GetDefault(ctx)
if err != nil {
t.Fatalf("GetDefault() error = %v", err)
}
// 应该返回内置默认配置
if defaultTheme.Name != "default" {
t.Errorf("Name = %v, want default", defaultTheme.Name)
}
}
// TestThemeConfigRepository_SetDefault 测试设置默认主题
func TestThemeConfigRepository_SetDefault(t *testing.T) {
db := setupThemeTestDB(t)
repo := NewThemeConfigRepository(db)
ctx := context.Background()
// 创建两个主题
theme1 := &domain.ThemeConfig{Name: "theme1", IsDefault: true, Enabled: true}
theme2 := &domain.ThemeConfig{Name: "theme2", IsDefault: false, Enabled: true}
repo.Create(ctx, theme1)
repo.Create(ctx, theme2)
// 设置 theme2 为默认
if err := repo.SetDefault(ctx, theme2.ID); err != nil {
t.Fatalf("SetDefault() error = %v", err)
}
// 验证 theme1 不再是默认
t1, _ := repo.GetByID(ctx, theme1.ID)
if t1.IsDefault {
t.Error("theme1 should not be default anymore")
}
// 验证 theme2 现在是默认
t2, _ := repo.GetByID(ctx, theme2.ID)
if !t2.IsDefault {
t.Error("theme2 should be default")
}
}

View File

@@ -3,6 +3,7 @@ package repository
import (
"context"
"testing"
"time"
"gorm.io/gorm"
@@ -401,3 +402,259 @@ func TestUserRepository_Search_LikePattern(t *testing.T) {
// Should not error and should escape properly
_ = users
}
// TestUserRepository_GetByIDs 测试批量获取用户
func TestUserRepository_GetByIDs(t *testing.T) {
db := setupTestDB(t)
repo := NewUserRepository(db)
ctx := context.Background()
u1 := &domain.User{Username: "batchuser1", Password: "hash", Status: domain.UserStatusActive}
u2 := &domain.User{Username: "batchuser2", Password: "hash", Status: domain.UserStatusActive}
u3 := &domain.User{Username: "batchuser3", Password: "hash", Status: domain.UserStatusActive}
repo.Create(ctx, u1)
repo.Create(ctx, u2)
repo.Create(ctx, u3)
users, err := repo.GetByIDs(ctx, []int64{u1.ID, u3.ID})
if err != nil {
t.Fatalf("GetByIDs() error = %v", err)
}
if len(users) != 2 {
t.Errorf("len(users) = %d, want 2", len(users))
}
}
// TestUserRepository_GetByIDs_Empty 测试空ID列表
func TestUserRepository_GetByIDs_Empty(t *testing.T) {
db := setupTestDB(t)
repo := NewUserRepository(db)
ctx := context.Background()
users, err := repo.GetByIDs(ctx, []int64{})
if err != nil {
t.Fatalf("GetByIDs() error = %v", err)
}
if len(users) != 0 {
t.Errorf("len(users) = %d, want 0", len(users))
}
}
// TestUserRepository_UpdatePassword 测试更新密码
func TestUserRepository_UpdatePassword(t *testing.T) {
db := setupTestDB(t)
repo := NewUserRepository(db)
ctx := context.Background()
user := &domain.User{
Username: "pwduser",
Password: "oldpassword",
Status: domain.UserStatusActive,
}
repo.Create(ctx, user)
err := repo.UpdatePassword(ctx, user.ID, "newpasswordhash")
if err != nil {
t.Fatalf("UpdatePassword() error = %v", err)
}
found, _ := repo.GetByID(ctx, user.ID)
if found.Password != "newpasswordhash" {
t.Errorf("Password = %v, want newpasswordhash", found.Password)
}
}
// TestUserRepository_UpdateTOTP 测试更新TOTP
func TestUserRepository_UpdateTOTP(t *testing.T) {
db := setupTestDB(t)
repo := NewUserRepository(db)
ctx := context.Background()
user := &domain.User{
Username: "totpuser",
Password: "hash",
Status: domain.UserStatusActive,
}
repo.Create(ctx, user)
user.TOTPEnabled = true
user.TOTPSecret = "JBSWY3DPEHPK3PXP"
err := repo.UpdateTOTP(ctx, user)
if err != nil {
t.Fatalf("UpdateTOTP() error = %v", err)
}
found, _ := repo.GetByID(ctx, user.ID)
if !found.TOTPEnabled {
t.Error("TOTPEnabled should be true")
}
if found.TOTPSecret != "JBSWY3DPEHPK3PXP" {
t.Errorf("TOTPSecret = %v, want JBSWY3DPEHPK3PXP", found.TOTPSecret)
}
}
// TestUserRepository_ListCreatedAfter 测试查询创建时间之后的用户
func TestUserRepository_ListCreatedAfter(t *testing.T) {
db := setupTestDB(t)
repo := NewUserRepository(db)
ctx := context.Background()
user := &domain.User{
Username: "afteruser",
Password: "hash",
Status: domain.UserStatusActive,
}
repo.Create(ctx, user)
since := user.CreatedAt.Add(-1 * time.Hour)
users, total, err := repo.ListCreatedAfter(ctx, since, 0, 10)
if err != nil {
t.Fatalf("ListCreatedAfter() error = %v", err)
}
if total < 1 {
t.Errorf("total = %d, want at least 1", total)
}
_ = users
}
// TestUserRepository_ListCreatedAfter_Limited 测试带limit的查询
func TestUserRepository_ListCreatedAfter_Limited(t *testing.T) {
db := setupTestDB(t)
repo := NewUserRepository(db)
ctx := context.Background()
for i := 0; i < 5; i++ {
repo.Create(ctx, &domain.User{
Username: "limituser" + string(rune('0'+i)),
Password: "hash",
Status: domain.UserStatusActive,
})
}
since := time.Now().Add(-1 * time.Hour)
users, total, err := repo.ListCreatedAfter(ctx, since, 0, 3)
if err != nil {
t.Fatalf("ListCreatedAfter() error = %v", err)
}
if len(users) != 3 {
t.Errorf("len(users) = %d, want 3", len(users))
}
if total < 5 {
t.Errorf("total = %d, want at least 5", total)
}
}
// TestUserRepository_AdvancedSearch 测试高级搜索
func TestUserRepository_AdvancedSearch(t *testing.T) {
db := setupTestDB(t)
repo := NewUserRepository(db)
ctx := context.Background()
repo.Create(ctx, &domain.User{
Username: "searchuser1",
Nickname: "张三",
Email: domain.StrPtr("zhangsan@example.com"),
Password: "hash",
Status: domain.UserStatusActive,
})
repo.Create(ctx, &domain.User{
Username: "searchuser2",
Nickname: "李四",
Email: domain.StrPtr("lisi@example.com"),
Password: "hash",
Status: domain.UserStatusActive,
})
repo.Create(ctx, &domain.User{
Username: "searchuser3",
Nickname: "王五",
Email: domain.StrPtr("wangwu@example.com"),
Password: "hash",
Status: domain.UserStatusInactive,
})
// 按关键字搜索Status=-1 表示全部状态)
filter := &AdvancedFilter{Keyword: "searchuser1", Status: -1, Offset: 0, Limit: 10}
users, total, err := repo.AdvancedSearch(ctx, filter)
if err != nil {
t.Fatalf("AdvancedSearch() error = %v", err)
}
if len(users) != 1 {
t.Errorf("len(users) = %d, want 1", len(users))
}
if total != 1 {
t.Errorf("total = %d, want 1", total)
}
// 按状态筛选
filter2 := &AdvancedFilter{Status: int(domain.UserStatusActive), Offset: 0, Limit: 10}
users2, total2, err := repo.AdvancedSearch(ctx, filter2)
if err != nil {
t.Fatalf("AdvancedSearch() error = %v", err)
}
if len(users2) != 2 {
t.Errorf("len(users2) = %d, want 2", len(users2))
}
if total2 != 2 {
t.Errorf("total2 = %d, want 2", total2)
}
// 按状态筛选 - 禁用用户
filter3 := &AdvancedFilter{Status: int(domain.UserStatusInactive), Offset: 0, Limit: 10}
users3, total3, err := repo.AdvancedSearch(ctx, filter3)
if err != nil {
t.Fatalf("AdvancedSearch() error = %v", err)
}
if len(users3) != 1 {
t.Errorf("len(users3) = %d, want 1", len(users3))
}
if total3 != 1 {
t.Errorf("total3 = %d, want 1", total3)
}
}
// TestUserRepository_AdvancedSearch_AllStatus 测试状态为-1返回全部
func TestUserRepository_AdvancedSearch_AllStatus(t *testing.T) {
db := setupTestDB(t)
repo := NewUserRepository(db)
ctx := context.Background()
repo.Create(ctx, &domain.User{Username: "active", Password: "hash", Status: domain.UserStatusActive})
repo.Create(ctx, &domain.User{Username: "inactive", Password: "hash", Status: domain.UserStatusInactive})
filter := &AdvancedFilter{Status: -1, Offset: 0, Limit: 10}
users, total, err := repo.AdvancedSearch(ctx, filter)
if err != nil {
t.Fatalf("AdvancedSearch() error = %v", err)
}
if len(users) != 2 {
t.Errorf("len(users) = %d, want 2", len(users))
}
if total != 2 {
t.Errorf("total = %d, want 2", total)
}
}
// TestUserRepository_AdvancedSearch_LikeSpecialChars 测试搜索LIKE特殊字符转义
func TestUserRepository_AdvancedSearch_LikeSpecialChars(t *testing.T) {
db := setupTestDB(t)
repo := NewUserRepository(db)
ctx := context.Background()
repo.Create(ctx, &domain.User{
Username: "user%with%percent",
Nickname: "测试用户",
Password: "hash",
Status: domain.UserStatusActive,
})
// 搜索%应该不匹配任何记录(被转义)
filter := &AdvancedFilter{Keyword: "%", Offset: 0, Limit: 10}
users, _, err := repo.AdvancedSearch(ctx, filter)
if err != nil {
t.Fatalf("AdvancedSearch() error = %v", err)
}
if len(users) != 0 {
t.Errorf("len(users) = %d, want 0 for escaped percent", len(users))
}
}

View File

@@ -0,0 +1,36 @@
package repository
import (
"context"
"testing"
"github.com/user-management-system/internal/domain"
)
func TestUserRoleRepository_DeleteByUserAndRole(t *testing.T) {
db := setupTestDB(t)
repo := NewUserRoleRepository(db)
ctx := context.Background()
// 创建用户和角色
user := &domain.User{Username: "roleuser", Password: "hash", Status: domain.UserStatusActive}
db.WithContext(ctx).Create(user)
role := &domain.Role{Code: "test_role", Name: "测试角色", Status: domain.RoleStatusEnabled}
db.WithContext(ctx).Create(role)
// 创建用户角色关联
repo.Create(ctx, &domain.UserRole{UserID: user.ID, RoleID: role.ID})
// 删除特定用户-角色关联
err := repo.DeleteByUserAndRole(ctx, user.ID, role.ID)
if err != nil {
t.Fatalf("DeleteByUserAndRole() error = %v", err)
}
// 验证已删除
exists, _ := repo.Exists(ctx, user.ID, role.ID)
if exists {
t.Error("DeleteByUserAndRole should have removed the association")
}
}

View File

@@ -188,3 +188,40 @@ func TestWebhookRepositoryCreateAndListDeliveries(t *testing.T) {
t.Fatal("expected deliveries to be returned in reverse created_at order")
}
}
func TestWebhookRepositoryListByCreatorPaginated(t *testing.T) {
repo := setupWebhookRepository(t)
ctx := context.Background()
// 创建多个webhook
for i := 0; i < 5; i++ {
if err := repo.Create(ctx, newWebhookFixture("wh-creator1-"+string(rune('a'+i)), 1, domain.WebhookStatusActive)); err != nil {
t.Fatalf("Create failed: %v", err)
}
}
// 另一个用户的webhook
if err := repo.Create(ctx, newWebhookFixture("wh-creator2", 2, domain.WebhookStatusActive)); err != nil {
t.Fatalf("Create failed: %v", err)
}
// 测试分页查询创建者1的webhook
webhooks, total, err := repo.ListByCreatorPaginated(ctx, 1, 0, 3)
if err != nil {
t.Fatalf("ListByCreatorPaginated failed: %v", err)
}
if len(webhooks) != 3 {
t.Errorf("len(webhooks) = %d, want 3", len(webhooks))
}
if total != 5 {
t.Errorf("total = %d, want 5", total)
}
// 测试第二页
webhooks2, _, err := repo.ListByCreatorPaginated(ctx, 1, 3, 3)
if err != nil {
t.Fatalf("ListByCreatorPaginated page 2 failed: %v", err)
}
if len(webhooks2) != 2 {
t.Errorf("len(webhooks2) = %d, want 2", len(webhooks2))
}
}