Files
user-system/internal/repository/user_repository_test.go
long-agent 289aab2930 test: add repository tests to improve coverage from 46.6% to 74%
New test files:
- custom_field_repository_test.go: 10 tests for CustomFieldRepository & UserCustomFieldValueRepository
- login_log_repository_test.go: 3 tests for ListCursor, ListByUserIDCursor, ListAllForExport
- operation_log_repository_test.go: 1 test for ListCursor
- role_repository_test.go: 2 tests for GetAncestorIDs, GetAncestors
- social_account_repository_test.go: 8 CRUD tests
- theme_repository_test.go: 10 tests for ThemeConfigRepository
- user_role_repository_test.go: 1 test for DeleteByUserAndRole

Modified test files:
- device_repository_test.go: Added ListAllCursor tests
- user_repository_test.go: Added AdvancedSearch tests
- webhook_repository_test.go: Added ListByCreatorPaginated test

Updated documentation with new coverage status.
2026-04-11 21:58:28 +08:00

661 lines
17 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package repository
import (
"context"
"testing"
"time"
"gorm.io/gorm"
"github.com/user-management-system/internal/domain"
)
// setupTestDB returns the database used by a single test, delegating to
// openTestDB. The tests in this file assert exact row counts, so they rely on
// the returned database starting empty.
//
// t.Helper marks this function as a test helper so any failure raised while
// opening the database is reported at the calling test's line.
func setupTestDB(t *testing.T) *gorm.DB {
	t.Helper()
	return openTestDB(t)
}
// TestUserRepository_Create verifies that Create persists a user and
// populates a non-zero ID on the passed-in struct.
func TestUserRepository_Create(t *testing.T) {
	repo := NewUserRepository(setupTestDB(t))
	ctx := context.Background()

	u := &domain.User{
		Username: "testuser",
		Email:    domain.StrPtr("test@example.com"),
		Phone:    domain.StrPtr("13800138000"),
		Password: "hashedpassword",
		Status:   domain.UserStatusActive,
	}

	err := repo.Create(ctx, u)
	if err != nil {
		t.Fatalf("Create() error = %v", err)
	}
	if u.ID == 0 {
		t.Error("创建后用户ID不应为0")
	}
}
// TestUserRepository_GetByUsername verifies lookup by username, and that an
// unknown username yields an error.
func TestUserRepository_GetByUsername(t *testing.T) {
	db := setupTestDB(t)
	repo := NewUserRepository(db)
	ctx := context.Background()
	user := &domain.User{
		Username: "findme",
		Email:    domain.StrPtr("findme@example.com"),
		Password: "hash",
		Status:   domain.UserStatusActive,
	}
	// Fail fast if seeding fails; otherwise the lookup below would report a
	// misleading "not found".
	if err := repo.Create(ctx, user); err != nil {
		t.Fatalf("Create() error = %v", err)
	}
	found, err := repo.GetByUsername(ctx, "findme")
	if err != nil {
		t.Fatalf("GetByUsername() error = %v", err)
	}
	if found.Username != "findme" {
		t.Errorf("Username = %v, want findme", found.Username)
	}
	// Negative case: a username that was never created must not be found.
	_, err = repo.GetByUsername(ctx, "notexist")
	if err == nil {
		t.Error("查找不存在的用户应返回错误")
	}
}
// TestUserRepository_GetByEmail verifies lookup by email address.
func TestUserRepository_GetByEmail(t *testing.T) {
	db := setupTestDB(t)
	repo := NewUserRepository(db)
	ctx := context.Background()
	user := &domain.User{
		Username: "emailuser",
		Email:    domain.StrPtr("email@example.com"),
		Password: "hash",
		Status:   domain.UserStatusActive,
	}
	// Fail fast if seeding fails so the lookup error is not misattributed.
	if err := repo.Create(ctx, user); err != nil {
		t.Fatalf("Create() error = %v", err)
	}
	found, err := repo.GetByEmail(ctx, "email@example.com")
	if err != nil {
		t.Fatalf("GetByEmail() error = %v", err)
	}
	if domain.DerefStr(found.Email) != "email@example.com" {
		t.Errorf("Email = %v, want email@example.com", domain.DerefStr(found.Email))
	}
}
// TestUserRepository_Update verifies that Update persists a changed nickname.
func TestUserRepository_Update(t *testing.T) {
	db := setupTestDB(t)
	repo := NewUserRepository(db)
	ctx := context.Background()
	user := &domain.User{
		Username: "updateme",
		Email:    domain.StrPtr("update@example.com"),
		Password: "hash",
		Status:   domain.UserStatusActive,
	}
	if err := repo.Create(ctx, user); err != nil {
		t.Fatalf("Create() error = %v", err)
	}
	user.Nickname = "已更新"
	if err := repo.Update(ctx, user); err != nil {
		t.Fatalf("Update() error = %v", err)
	}
	// Check the read-back error: on failure found may be nil, and
	// dereferencing it would panic instead of reporting a test failure.
	found, err := repo.GetByID(ctx, user.ID)
	if err != nil {
		t.Fatalf("GetByID() error = %v", err)
	}
	if found.Nickname != "已更新" {
		t.Errorf("Nickname = %v, want 已更新", found.Nickname)
	}
}
// TestUserRepository_Delete verifies that a deleted user can no longer be
// fetched by ID.
func TestUserRepository_Delete(t *testing.T) {
	db := setupTestDB(t)
	repo := NewUserRepository(db)
	ctx := context.Background()
	user := &domain.User{
		Username: "deleteme",
		Email:    domain.StrPtr("delete@example.com"),
		Password: "hash",
		Status:   domain.UserStatusActive,
	}
	// If Create failed silently, user.ID would stay 0. Delete(0) followed by a
	// failing GetByID(0) would then make this test pass without deleting
	// anything, so seeding must be checked.
	if err := repo.Create(ctx, user); err != nil {
		t.Fatalf("Create() error = %v", err)
	}
	if err := repo.Delete(ctx, user.ID); err != nil {
		t.Fatalf("Delete() error = %v", err)
	}
	_, err := repo.GetByID(ctx, user.ID)
	if err == nil {
		t.Error("删除后查询应返回错误")
	}
}
// TestUserRepository_ExistsBy verifies the ExistsByUsername, ExistsByEmail and
// ExistsByPhone checks for a seeded user, plus a negative username check.
func TestUserRepository_ExistsBy(t *testing.T) {
	db := setupTestDB(t)
	repo := NewUserRepository(db)
	ctx := context.Background()
	user := &domain.User{
		Username: "existsuser",
		Email:    domain.StrPtr("exists@example.com"),
		Phone:    domain.StrPtr("13900139000"),
		Password: "hash",
		Status:   domain.UserStatusActive,
	}
	if err := repo.Create(ctx, user); err != nil {
		t.Fatalf("Create() error = %v", err)
	}
	// Each existence check's error is inspected. Otherwise a query failure
	// returning false would silently satisfy the negative case at the end.
	ok, err := repo.ExistsByUsername(ctx, "existsuser")
	if err != nil {
		t.Fatalf("ExistsByUsername() error = %v", err)
	}
	if !ok {
		t.Error("ExistsByUsername 应返回 true")
	}
	ok, err = repo.ExistsByEmail(ctx, "exists@example.com")
	if err != nil {
		t.Fatalf("ExistsByEmail() error = %v", err)
	}
	if !ok {
		t.Error("ExistsByEmail 应返回 true")
	}
	ok, err = repo.ExistsByPhone(ctx, "13900139000")
	if err != nil {
		t.Fatalf("ExistsByPhone() error = %v", err)
	}
	if !ok {
		t.Error("ExistsByPhone 应返回 true")
	}
	ok, err = repo.ExistsByUsername(ctx, "notexist")
	if err != nil {
		t.Fatalf("ExistsByUsername() error = %v", err)
	}
	if ok {
		t.Error("不存在的用户 ExistsByUsername 应返回 false")
	}
}
// TestUserRepository_List verifies that List returns every seeded user and
// the matching total count.
func TestUserRepository_List(t *testing.T) {
	db := setupTestDB(t)
	repo := NewUserRepository(db)
	ctx := context.Background()
	// Seeds listuser0 .. listuser4 (the single-digit rune trick is valid only
	// while i < 10).
	for i := 0; i < 5; i++ {
		u := &domain.User{
			Username: "listuser" + string(rune('0'+i)),
			Password: "hash",
			Status:   domain.UserStatusActive,
		}
		if err := repo.Create(ctx, u); err != nil {
			t.Fatalf("Create(%q) error = %v", u.Username, err)
		}
	}
	users, total, err := repo.List(ctx, 0, 10)
	if err != nil {
		t.Fatalf("List() error = %v", err)
	}
	if len(users) != 5 {
		t.Errorf("len(users) = %d, want 5", len(users))
	}
	if total != 5 {
		t.Errorf("total = %d, want 5", total)
	}
}
// TestUserRepository_GetByPhone verifies lookup by phone number.
func TestUserRepository_GetByPhone(t *testing.T) {
	db := setupTestDB(t)
	repo := NewUserRepository(db)
	ctx := context.Background()
	user := &domain.User{
		Username: "phoneuser",
		Email:    domain.StrPtr("phone@example.com"),
		Phone:    domain.StrPtr("13700137000"),
		Password: "hash",
		Status:   domain.UserStatusActive,
	}
	if err := repo.Create(ctx, user); err != nil {
		t.Fatalf("Create() error = %v", err)
	}
	found, err := repo.GetByPhone(ctx, "13700137000")
	if err != nil {
		t.Fatalf("GetByPhone() error = %v", err)
	}
	if found.Username != "phoneuser" {
		t.Errorf("Username = %v, want phoneuser", found.Username)
	}
}
// TestUserRepository_ListByStatus verifies that ListByStatus returns only
// users with the requested status, along with a matching total.
func TestUserRepository_ListByStatus(t *testing.T) {
	db := setupTestDB(t)
	repo := NewUserRepository(db)
	ctx := context.Background()
	// Two active users and one inactive user. Every insert is checked: if the
	// inactive row silently failed to insert, the filter assertion below would
	// pass without exercising the filter.
	seed := []*domain.User{
		{Username: "active1", Password: "hash", Status: domain.UserStatusActive},
		{Username: "active2", Password: "hash", Status: domain.UserStatusActive},
		{Username: "inactive1", Password: "hash", Status: domain.UserStatusInactive},
	}
	for _, u := range seed {
		if err := repo.Create(ctx, u); err != nil {
			t.Fatalf("Create(%q) error = %v", u.Username, err)
		}
	}
	users, total, err := repo.ListByStatus(ctx, domain.UserStatusActive, 0, 10)
	if err != nil {
		t.Fatalf("ListByStatus() error = %v", err)
	}
	if len(users) != 2 {
		t.Errorf("len(users) = %d, want 2", len(users))
	}
	if total != 2 {
		t.Errorf("total = %d, want 2", total)
	}
	for _, u := range users {
		if u.Status != domain.UserStatusActive {
			t.Errorf("user %q Status = %v, want Active", u.Username, u.Status)
		}
	}
}
// TestUserRepository_UpdateStatus verifies that UpdateStatus persists a new
// status for a single user.
func TestUserRepository_UpdateStatus(t *testing.T) {
	db := setupTestDB(t)
	repo := NewUserRepository(db)
	ctx := context.Background()
	user := &domain.User{
		Username: "statususer",
		Email:    domain.StrPtr("status@example.com"),
		Password: "hash",
		Status:   domain.UserStatusActive,
	}
	if err := repo.Create(ctx, user); err != nil {
		t.Fatalf("Create() error = %v", err)
	}
	err := repo.UpdateStatus(ctx, user.ID, domain.UserStatusInactive)
	if err != nil {
		t.Fatalf("UpdateStatus() error = %v", err)
	}
	// Check the read-back error so a failed fetch reports cleanly instead of
	// panicking on a nil result.
	found, err := repo.GetByID(ctx, user.ID)
	if err != nil {
		t.Fatalf("GetByID() error = %v", err)
	}
	if found.Status != domain.UserStatusInactive {
		t.Errorf("Status = %v, want Inactive", found.Status)
	}
}
// TestUserRepository_BatchUpdateStatus verifies that BatchUpdateStatus applies
// the new status to every listed user.
func TestUserRepository_BatchUpdateStatus(t *testing.T) {
	db := setupTestDB(t)
	repo := NewUserRepository(db)
	ctx := context.Background()
	user1 := &domain.User{
		Username: "batch1",
		Email:    domain.StrPtr("batch1@example.com"),
		Password: "hash",
		Status:   domain.UserStatusActive,
	}
	user2 := &domain.User{
		Username: "batch2",
		Email:    domain.StrPtr("batch2@example.com"),
		Password: "hash",
		Status:   domain.UserStatusActive,
	}
	targets := []*domain.User{user1, user2}
	for _, u := range targets {
		if err := repo.Create(ctx, u); err != nil {
			t.Fatalf("Create(%q) error = %v", u.Username, err)
		}
	}
	err := repo.BatchUpdateStatus(ctx, []int64{user1.ID, user2.ID}, domain.UserStatusInactive)
	if err != nil {
		t.Fatalf("BatchUpdateStatus() error = %v", err)
	}
	// Verify each user separately so a failure names the user that was not
	// updated, and a failed fetch cannot dereference a nil result.
	for _, u := range targets {
		found, err := repo.GetByID(ctx, u.ID)
		if err != nil {
			t.Fatalf("GetByID(%d) error = %v", u.ID, err)
		}
		if found.Status != domain.UserStatusInactive {
			t.Errorf("BatchUpdateStatus failed: user %d Status = %v, want Inactive", u.ID, found.Status)
		}
	}
}
// TestUserRepository_BatchDelete verifies that BatchDelete removes every
// listed user.
func TestUserRepository_BatchDelete(t *testing.T) {
	db := setupTestDB(t)
	repo := NewUserRepository(db)
	ctx := context.Background()
	user1 := &domain.User{
		Username: "del1",
		Email:    domain.StrPtr("del1@example.com"),
		Password: "hash",
		Status:   domain.UserStatusActive,
	}
	user2 := &domain.User{
		Username: "del2",
		Email:    domain.StrPtr("del2@example.com"),
		Password: "hash",
		Status:   domain.UserStatusActive,
	}
	// Seeding must succeed. With unchecked failures both IDs would stay 0 and
	// the "not found" checks below would pass without anything being deleted.
	for _, u := range []*domain.User{user1, user2} {
		if err := repo.Create(ctx, u); err != nil {
			t.Fatalf("Create(%q) error = %v", u.Username, err)
		}
	}
	ids := []int64{user1.ID, user2.ID}
	if err := repo.BatchDelete(ctx, ids); err != nil {
		t.Fatalf("BatchDelete() error = %v", err)
	}
	for _, id := range ids {
		if _, err := repo.GetByID(ctx, id); err == nil {
			t.Errorf("BatchDelete should have deleted user %d", id)
		}
	}
}
// TestUserRepository_Search verifies keyword search. "zhang" appears only in
// the first user's email, so exactly that user must be returned.
func TestUserRepository_Search(t *testing.T) {
	db := setupTestDB(t)
	repo := NewUserRepository(db)
	ctx := context.Background()
	seed := []*domain.User{
		{
			Username: "searchuser1",
			Nickname: "张三",
			Email:    domain.StrPtr("zhangsan@example.com"),
			Password: "hash",
			Status:   domain.UserStatusActive,
		},
		{
			Username: "searchuser2",
			Nickname: "李四",
			Email:    domain.StrPtr("lisi@example.com"),
			Password: "hash",
			Status:   domain.UserStatusActive,
		},
	}
	for _, u := range seed {
		if err := repo.Create(ctx, u); err != nil {
			t.Fatalf("Create(%q) error = %v", u.Username, err)
		}
	}
	users, total, err := repo.Search(ctx, "zhang", 0, 10)
	if err != nil {
		t.Fatalf("Search() error = %v", err)
	}
	if len(users) != 1 {
		t.Errorf("len(users) = %d, want 1", len(users))
	} else if users[0].Username != "searchuser1" {
		t.Errorf("users[0].Username = %q, want searchuser1", users[0].Username)
	}
	if total != 1 {
		t.Errorf("total = %d, want 1", total)
	}
}
// TestUserRepository_Search_LikePattern verifies that LIKE metacharacters in
// the search keyword are matched literally. Searching for "%" must return only
// the user whose fields contain a literal '%'. It must not act as a match-all
// wildcard.
func TestUserRepository_Search_LikePattern(t *testing.T) {
	db := setupTestDB(t)
	repo := NewUserRepository(db)
	ctx := context.Background()
	seed := []*domain.User{
		{
			Username: "user%with%percent",
			Nickname: "测试用户",
			Email:    domain.StrPtr("percent@example.com"),
			Password: "hash",
			Status:   domain.UserStatusActive,
		},
		// Control row with no LIKE metacharacters in any field. If "%" reached
		// the query unescaped it would match everything, and this row would be
		// returned as well.
		{
			Username: "plainuser",
			Nickname: "普通用户",
			Email:    domain.StrPtr("plain@example.com"),
			Password: "hash",
			Status:   domain.UserStatusActive,
		},
	}
	for _, u := range seed {
		if err := repo.Create(ctx, u); err != nil {
			t.Fatalf("Create(%q) error = %v", u.Username, err)
		}
	}
	users, total, err := repo.Search(ctx, "%", 0, 10)
	if err != nil {
		t.Fatalf("Search() error = %v", err)
	}
	names := make([]string, 0, len(users))
	for _, u := range users {
		names = append(names, u.Username)
	}
	if len(names) != 1 || names[0] != "user%with%percent" {
		t.Errorf("Search(%q) returned %q, want only %q", "%", names, "user%with%percent")
	}
	if total != 1 {
		t.Errorf("total = %d, want 1", total)
	}
}
// TestUserRepository_GetByIDs verifies that GetByIDs returns exactly the
// requested users and skips an unrequested one.
func TestUserRepository_GetByIDs(t *testing.T) {
	db := setupTestDB(t)
	repo := NewUserRepository(db)
	ctx := context.Background()
	u1 := &domain.User{Username: "batchuser1", Password: "hash", Status: domain.UserStatusActive}
	u2 := &domain.User{Username: "batchuser2", Password: "hash", Status: domain.UserStatusActive}
	u3 := &domain.User{Username: "batchuser3", Password: "hash", Status: domain.UserStatusActive}
	for _, u := range []*domain.User{u1, u2, u3} {
		if err := repo.Create(ctx, u); err != nil {
			t.Fatalf("Create(%q) error = %v", u.Username, err)
		}
	}
	users, err := repo.GetByIDs(ctx, []int64{u1.ID, u3.ID})
	if err != nil {
		t.Fatalf("GetByIDs() error = %v", err)
	}
	if len(users) != 2 {
		t.Errorf("len(users) = %d, want 2", len(users))
	}
	// A count alone would also accept {u1, u2}. Check the actual IDs.
	got := make(map[int64]bool, len(users))
	for _, u := range users {
		got[u.ID] = true
	}
	if !got[u1.ID] || !got[u3.ID] || got[u2.ID] {
		t.Errorf("GetByIDs() returned IDs %v, want exactly {%d, %d}", got, u1.ID, u3.ID)
	}
}
// TestUserRepository_GetByIDs_Empty checks that GetByIDs with an empty ID
// slice succeeds and returns no users.
func TestUserRepository_GetByIDs_Empty(t *testing.T) {
	repo := NewUserRepository(setupTestDB(t))

	got, err := repo.GetByIDs(context.Background(), []int64{})
	switch {
	case err != nil:
		t.Fatalf("GetByIDs() error = %v", err)
	case len(got) != 0:
		t.Errorf("len(users) = %d, want 0", len(got))
	}
}
// TestUserRepository_UpdatePassword verifies that UpdatePassword stores the
// given value in the user's Password field.
func TestUserRepository_UpdatePassword(t *testing.T) {
	db := setupTestDB(t)
	repo := NewUserRepository(db)
	ctx := context.Background()
	user := &domain.User{
		Username: "pwduser",
		Password: "oldpassword",
		Status:   domain.UserStatusActive,
	}
	if err := repo.Create(ctx, user); err != nil {
		t.Fatalf("Create() error = %v", err)
	}
	err := repo.UpdatePassword(ctx, user.ID, "newpasswordhash")
	if err != nil {
		t.Fatalf("UpdatePassword() error = %v", err)
	}
	// Check the read-back error so a failed fetch cannot dereference a nil
	// result.
	found, err := repo.GetByID(ctx, user.ID)
	if err != nil {
		t.Fatalf("GetByID() error = %v", err)
	}
	if found.Password != "newpasswordhash" {
		t.Errorf("Password = %v, want newpasswordhash", found.Password)
	}
}
// TestUserRepository_UpdateTOTP verifies that UpdateTOTP persists both the
// TOTPEnabled flag and the TOTPSecret value.
func TestUserRepository_UpdateTOTP(t *testing.T) {
	db := setupTestDB(t)
	repo := NewUserRepository(db)
	ctx := context.Background()
	user := &domain.User{
		Username: "totpuser",
		Password: "hash",
		Status:   domain.UserStatusActive,
	}
	if err := repo.Create(ctx, user); err != nil {
		t.Fatalf("Create() error = %v", err)
	}
	user.TOTPEnabled = true
	user.TOTPSecret = "JBSWY3DPEHPK3PXP"
	err := repo.UpdateTOTP(ctx, user)
	if err != nil {
		t.Fatalf("UpdateTOTP() error = %v", err)
	}
	// Check the read-back error so a failed fetch cannot dereference a nil
	// result.
	found, err := repo.GetByID(ctx, user.ID)
	if err != nil {
		t.Fatalf("GetByID() error = %v", err)
	}
	if !found.TOTPEnabled {
		t.Error("TOTPEnabled should be true")
	}
	if found.TOTPSecret != "JBSWY3DPEHPK3PXP" {
		t.Errorf("TOTPSecret = %v, want JBSWY3DPEHPK3PXP", found.TOTPSecret)
	}
}
// TestUserRepository_ListCreatedAfter verifies that a user created after the
// cutoff time is included in ListCreatedAfter's results.
func TestUserRepository_ListCreatedAfter(t *testing.T) {
	db := setupTestDB(t)
	repo := NewUserRepository(db)
	ctx := context.Background()
	user := &domain.User{
		Username: "afteruser",
		Password: "hash",
		Status:   domain.UserStatusActive,
	}
	// The cutoff is derived from user.CreatedAt, which is only meaningful if
	// Create succeeded.
	if err := repo.Create(ctx, user); err != nil {
		t.Fatalf("Create() error = %v", err)
	}
	since := user.CreatedAt.Add(-time.Hour)
	users, total, err := repo.ListCreatedAfter(ctx, since, 0, 10)
	if err != nil {
		t.Fatalf("ListCreatedAfter() error = %v", err)
	}
	if total < 1 {
		t.Errorf("total = %d, want at least 1", total)
	}
	found := false
	for _, u := range users {
		if u.Username == "afteruser" {
			found = true
			break
		}
	}
	if !found {
		t.Errorf("ListCreatedAfter() did not return afteruser; got %d users", len(users))
	}
}
// TestUserRepository_ListCreatedAfter_Limited verifies that the limit caps the
// returned page while total still counts every matching user.
func TestUserRepository_ListCreatedAfter_Limited(t *testing.T) {
	db := setupTestDB(t)
	repo := NewUserRepository(db)
	ctx := context.Background()
	// Seeds limituser0 .. limituser4 (the single-digit rune trick is valid
	// only while i < 10).
	for i := 0; i < 5; i++ {
		u := &domain.User{
			Username: "limituser" + string(rune('0'+i)),
			Password: "hash",
			Status:   domain.UserStatusActive,
		}
		if err := repo.Create(ctx, u); err != nil {
			t.Fatalf("Create(%q) error = %v", u.Username, err)
		}
	}
	since := time.Now().Add(-time.Hour)
	users, total, err := repo.ListCreatedAfter(ctx, since, 0, 3)
	if err != nil {
		t.Fatalf("ListCreatedAfter() error = %v", err)
	}
	if len(users) != 3 {
		t.Errorf("len(users) = %d, want 3", len(users))
	}
	if total < 5 {
		t.Errorf("total = %d, want at least 5", total)
	}
}
// TestUserRepository_AdvancedSearch exercises AdvancedSearch with a keyword
// filter (across all statuses) and with active-only and inactive-only status
// filters.
func TestUserRepository_AdvancedSearch(t *testing.T) {
	db := setupTestDB(t)
	repo := NewUserRepository(db)
	ctx := context.Background()
	seed := []*domain.User{
		{
			Username: "searchuser1",
			Nickname: "张三",
			Email:    domain.StrPtr("zhangsan@example.com"),
			Password: "hash",
			Status:   domain.UserStatusActive,
		},
		{
			Username: "searchuser2",
			Nickname: "李四",
			Email:    domain.StrPtr("lisi@example.com"),
			Password: "hash",
			Status:   domain.UserStatusActive,
		},
		{
			Username: "searchuser3",
			Nickname: "王五",
			Email:    domain.StrPtr("wangwu@example.com"),
			Password: "hash",
			Status:   domain.UserStatusInactive,
		},
	}
	for _, u := range seed {
		if err := repo.Create(ctx, u); err != nil {
			t.Fatalf("Create(%q) error = %v", u.Username, err)
		}
	}

	// Keyword search. Status -1 means "all statuses".
	filter := &AdvancedFilter{Keyword: "searchuser1", Status: -1, Offset: 0, Limit: 10}
	users, total, err := repo.AdvancedSearch(ctx, filter)
	if err != nil {
		t.Fatalf("AdvancedSearch() error = %v", err)
	}
	if len(users) != 1 {
		t.Errorf("len(users) = %d, want 1", len(users))
	} else if users[0].Username != "searchuser1" {
		t.Errorf("users[0].Username = %q, want searchuser1", users[0].Username)
	}
	if total != 1 {
		t.Errorf("total = %d, want 1", total)
	}

	// Status filter: active users only.
	filter2 := &AdvancedFilter{Status: int(domain.UserStatusActive), Offset: 0, Limit: 10}
	users2, total2, err := repo.AdvancedSearch(ctx, filter2)
	if err != nil {
		t.Fatalf("AdvancedSearch() error = %v", err)
	}
	if len(users2) != 2 {
		t.Errorf("len(users2) = %d, want 2", len(users2))
	}
	if total2 != 2 {
		t.Errorf("total2 = %d, want 2", total2)
	}
	for _, u := range users2 {
		if u.Status != domain.UserStatusActive {
			t.Errorf("user %q Status = %v, want Active", u.Username, u.Status)
		}
	}

	// Status filter: inactive (disabled) users only.
	filter3 := &AdvancedFilter{Status: int(domain.UserStatusInactive), Offset: 0, Limit: 10}
	users3, total3, err := repo.AdvancedSearch(ctx, filter3)
	if err != nil {
		t.Fatalf("AdvancedSearch() error = %v", err)
	}
	if len(users3) != 1 {
		t.Errorf("len(users3) = %d, want 1", len(users3))
	}
	if total3 != 1 {
		t.Errorf("total3 = %d, want 1", total3)
	}
	for _, u := range users3 {
		if u.Status != domain.UserStatusInactive {
			t.Errorf("user %q Status = %v, want Inactive", u.Username, u.Status)
		}
	}
}
// TestUserRepository_AdvancedSearch_AllStatus verifies that Status -1 disables
// status filtering, so both the active and the inactive user are returned.
func TestUserRepository_AdvancedSearch_AllStatus(t *testing.T) {
	db := setupTestDB(t)
	repo := NewUserRepository(db)
	ctx := context.Background()
	seed := []*domain.User{
		{Username: "active", Password: "hash", Status: domain.UserStatusActive},
		{Username: "inactive", Password: "hash", Status: domain.UserStatusInactive},
	}
	for _, u := range seed {
		if err := repo.Create(ctx, u); err != nil {
			t.Fatalf("Create(%q) error = %v", u.Username, err)
		}
	}
	filter := &AdvancedFilter{Status: -1, Offset: 0, Limit: 10}
	users, total, err := repo.AdvancedSearch(ctx, filter)
	if err != nil {
		t.Fatalf("AdvancedSearch() error = %v", err)
	}
	if len(users) != 2 {
		t.Errorf("len(users) = %d, want 2", len(users))
	}
	if total != 2 {
		t.Errorf("total = %d, want 2", total)
	}
}
// TestUserRepository_AdvancedSearch_LikeSpecialChars verifies that LIKE
// metacharacters in the keyword are escaped. With "%" escaped, the keyword
// matches only rows containing a literal '%'. Unescaped, it would match every
// row.
func TestUserRepository_AdvancedSearch_LikeSpecialChars(t *testing.T) {
	db := setupTestDB(t)
	repo := NewUserRepository(db)
	ctx := context.Background()
	seed := []*domain.User{
		{
			Username: "user%with%percent",
			Nickname: "测试用户",
			Password: "hash",
			Status:   domain.UserStatusActive,
		},
		// Control row with no LIKE metacharacters in any field. An unescaped
		// "%" would return it as well.
		{
			Username: "plainuser",
			Nickname: "普通用户",
			Password: "hash",
			Status:   domain.UserStatusActive,
		},
	}
	for _, u := range seed {
		if err := repo.Create(ctx, u); err != nil {
			t.Fatalf("Create(%q) error = %v", u.Username, err)
		}
	}
	// Status must be -1 ("all statuses"). Zero is not that sentinel, so a
	// zero-valued Status would apply a status filter. Presumably that filter
	// excludes these active users, which would let a "no results" expectation
	// pass without ever exercising the keyword escaping.
	filter := &AdvancedFilter{Keyword: "%", Status: -1, Offset: 0, Limit: 10}
	users, total, err := repo.AdvancedSearch(ctx, filter)
	if err != nil {
		t.Fatalf("AdvancedSearch() error = %v", err)
	}
	names := make([]string, 0, len(users))
	for _, u := range users {
		names = append(names, u.Username)
	}
	if len(names) != 1 || names[0] != "user%with%percent" {
		t.Errorf("AdvancedSearch(Keyword=%q) returned %q, want only %q", "%", names, "user%with%percent")
	}
	if total != 1 {
		t.Errorf("total = %d, want 1", total)
	}
}