package repository

import (
	"context"
	"fmt"
	"sync/atomic"
	"testing"
	"time"

	gormsqlite "gorm.io/driver/sqlite"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
	_ "modernc.org/sqlite"

	"github.com/user-management-system/internal/domain"
	"github.com/user-management-system/internal/pagination"
)

// deviceTestCounter makes every test database name unique.
var deviceTestCounter int64

// openDeviceTestDB opens an isolated in-memory SQLite database for one test,
// migrates the Device schema into it, and closes it when the test ends.
//
// With a "mode=memory&cache=private" URI, every connection opened by the
// database/sql pool gets its own, empty database. The pool is therefore
// pinned to a single connection. Otherwise a second pooled connection would
// not see the tables created by AutoMigrate.
func openDeviceTestDB(t *testing.T) *gorm.DB {
	t.Helper()

	id := atomic.AddInt64(&deviceTestCounter, 1)
	dsn := fmt.Sprintf("file:devtestdb%d?mode=memory&cache=private", id)

	db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
		DriverName: "sqlite",
		DSN:        dsn,
	}), &gorm.Config{
		Logger: logger.Default.LogMode(logger.Silent),
	})
	if err != nil {
		t.Fatalf("打开测试数据库失败: %v", err)
	}

	sqlDB, err := db.DB()
	if err != nil {
		t.Fatalf("db.DB() error = %v", err)
	}
	sqlDB.SetMaxOpenConns(1)
	// Best-effort release of the in-memory database once the test finishes.
	t.Cleanup(func() { _ = sqlDB.Close() })

	if err := db.AutoMigrate(&domain.Device{}); err != nil {
		t.Fatalf("数据库迁移失败: %v", err)
	}
	return db
}

// setupDeviceTestDB is a compatibility alias for openDeviceTestDB.
func setupDeviceTestDB(t *testing.T) *gorm.DB {
	t.Helper()
	return openDeviceTestDB(t)
}

// mustCreateDeviceRecord persists d through repo and aborts the test if the
// insert fails. Later assertions therefore never run against missing fixtures.
func mustCreateDeviceRecord(t *testing.T, repo *DeviceRepository, ctx context.Context, d *domain.Device) {
	t.Helper()
	if err := repo.Create(ctx, d); err != nil {
		t.Fatalf("Create(%q) error = %v", d.DeviceID, err)
	}
}

// TestDeviceRepository_Create checks that Create persists a device and
// assigns it a non-zero ID.
func TestDeviceRepository_Create(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()

	device := &domain.Device{
		UserID:     1,
		DeviceID:   "test-device-001",
		DeviceName: "测试手机",
		DeviceType: domain.DeviceTypeMobile,
		Status:     domain.DeviceStatusActive,
	}
	if err := repo.Create(ctx, device); err != nil {
		t.Fatalf("Create() error = %v", err)
	}
	if device.ID == 0 {
		t.Error("创建后设备ID不应为0")
	}
}

// TestDeviceRepository_GetByID checks lookup by primary key.
func TestDeviceRepository_GetByID(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()

	device := &domain.Device{
		UserID:     1,
		DeviceID:   "test-device-002",
		DeviceName: "测试平板",
		Status:     domain.DeviceStatusActive,
	}
	mustCreateDeviceRecord(t, repo, ctx, device)

	found, err := repo.GetByID(ctx, device.ID)
	if err != nil {
		t.Fatalf("GetByID() error = %v", err)
	}
	if found.DeviceID != "test-device-002" {
		t.Errorf("DeviceID = %v, want test-device-002", found.DeviceID)
	}
}

// TestDeviceRepository_GetByDeviceID checks lookup by (user ID, device identifier).
func TestDeviceRepository_GetByDeviceID(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()

	device := &domain.Device{
		UserID:     1,
		DeviceID:   "unique-device-id",
		DeviceName: "测试设备",
		Status:     domain.DeviceStatusActive,
	}
	mustCreateDeviceRecord(t, repo, ctx, device)

	found, err := repo.GetByDeviceID(ctx, 1, "unique-device-id")
	if err != nil {
		t.Fatalf("GetByDeviceID() error = %v", err)
	}
	if found.UserID != 1 {
		t.Errorf("UserID = %v, want 1", found.UserID)
	}
}

// TestDeviceRepository_Update checks that Update persists a changed field.
func TestDeviceRepository_Update(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()

	device := &domain.Device{
		UserID:     1,
		DeviceID:   "update-test",
		DeviceName: "旧名称",
		Status:     domain.DeviceStatusActive,
	}
	mustCreateDeviceRecord(t, repo, ctx, device)

	device.DeviceName = "新名称"
	if err := repo.Update(ctx, device); err != nil {
		t.Fatalf("Update() error = %v", err)
	}

	found, err := repo.GetByID(ctx, device.ID)
	if err != nil {
		t.Fatalf("GetByID() error = %v", err)
	}
	if found.DeviceName != "新名称" {
		t.Errorf("DeviceName = %v, want 新名称", found.DeviceName)
	}
}

// TestDeviceRepository_Delete checks that a deleted device can no longer be
// fetched by ID.
func TestDeviceRepository_Delete(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()

	device := &domain.Device{
		UserID:     1,
		DeviceID:   "delete-test",
		DeviceName: "待删除",
		Status:     domain.DeviceStatusActive,
	}
	mustCreateDeviceRecord(t, repo, ctx, device)

	if err := repo.Delete(ctx, device.ID); err != nil {
		t.Fatalf("Delete() error = %v", err)
	}

	if _, err := repo.GetByID(ctx, device.ID); err == nil {
		t.Error("删除后查询应返回错误")
	}
}

// TestDeviceRepository_List checks unfiltered paging and the total count.
func TestDeviceRepository_List(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()

	for i := 0; i < 3; i++ {
		mustCreateDeviceRecord(t, repo, ctx, &domain.Device{
			UserID:   int64(i + 1),
			DeviceID: "list-device-" + string(rune('a'+i)),
			Status:   domain.DeviceStatusActive,
		})
	}

	devices, total, err := repo.List(ctx, 0, 10)
	if err != nil {
		t.Fatalf("List() error = %v", err)
	}
	if len(devices) != 3 {
		t.Errorf("len(devices) = %d, want 3", len(devices))
	}
	if total != 3 {
		t.Errorf("total = %d, want 3", total)
	}
}

// TestDeviceRepository_ListByUserID checks that listing is scoped to one user.
func TestDeviceRepository_ListByUserID(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()

	mustCreateDeviceRecord(t, repo, ctx, &domain.Device{UserID: 1, DeviceID: "user1-dev1", Status: domain.DeviceStatusActive})
	mustCreateDeviceRecord(t, repo, ctx, &domain.Device{UserID: 1, DeviceID: "user1-dev2", Status: domain.DeviceStatusActive})
	mustCreateDeviceRecord(t, repo, ctx, &domain.Device{UserID: 2, DeviceID: "user2-dev1", Status: domain.DeviceStatusActive})

	devices, total, err := repo.ListByUserID(ctx, 1, 0, 10)
	if err != nil {
		t.Fatalf("ListByUserID() error = %v", err)
	}
	if len(devices) != 2 {
		t.Errorf("len(devices) = %d, want 2", len(devices))
	}
	if total != 2 {
		t.Errorf("total = %d, want 2", total)
	}
}

// TestDeviceRepository_ListByStatus checks that listing is scoped to one status.
func TestDeviceRepository_ListByStatus(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()

	mustCreateDeviceRecord(t, repo, ctx, &domain.Device{UserID: 1, DeviceID: "active1", Status: domain.DeviceStatusActive})
	mustCreateDeviceRecord(t, repo, ctx, &domain.Device{UserID: 2, DeviceID: "active2", Status: domain.DeviceStatusActive})
	mustCreateDeviceRecord(t, repo, ctx, &domain.Device{UserID: 3, DeviceID: "inactive1", Status: domain.DeviceStatusInactive})

	devices, total, err := repo.ListByStatus(ctx, domain.DeviceStatusActive, 0, 10)
	if err != nil {
		t.Fatalf("ListByStatus() error = %v", err)
	}
	if len(devices) != 2 {
		t.Errorf("len(devices) = %d, want 2", len(devices))
	}
	if total != 2 {
		t.Errorf("total = %d, want 2", total)
	}
}

// TestDeviceRepository_UpdateStatus checks that UpdateStatus persists the new status.
func TestDeviceRepository_UpdateStatus(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()

	device := &domain.Device{
		UserID:     1,
		DeviceID:   "status-test",
		DeviceName: "状态测试",
		Status:     domain.DeviceStatusActive,
	}
	mustCreateDeviceRecord(t, repo, ctx, device)

	if err := repo.UpdateStatus(ctx, device.ID, domain.DeviceStatusInactive); err != nil {
		t.Fatalf("UpdateStatus() error = %v", err)
	}

	found, err := repo.GetByID(ctx, device.ID)
	if err != nil {
		t.Fatalf("GetByID() error = %v", err)
	}
	if found.Status != domain.DeviceStatusInactive {
		t.Errorf("Status = %v, want Inactive", found.Status)
	}
}

// TestDeviceRepository_Exists checks both the positive and negative existence cases.
func TestDeviceRepository_Exists(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()

	device := &domain.Device{
		UserID:     1,
		DeviceID:   "exists-test",
		DeviceName: "存在性测试",
		Status:     domain.DeviceStatusActive,
	}
	mustCreateDeviceRecord(t, repo, ctx, device)

	exists, err := repo.Exists(ctx, 1, "exists-test")
	if err != nil {
		t.Fatalf("Exists() error = %v", err)
	}
	if !exists {
		t.Error("Exists 应返回 true")
	}

	exists, err = repo.Exists(ctx, 1, "not-exists")
	if err != nil {
		t.Fatalf("Exists() error = %v", err)
	}
	if exists {
		t.Error("不存在的设备 Exists 应返回 false")
	}
}

// TestDeviceRepository_DeleteByUserID checks that only the target user's
// devices are removed.
func TestDeviceRepository_DeleteByUserID(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()

	mustCreateDeviceRecord(t, repo, ctx, &domain.Device{UserID: 1, DeviceID: "user1-dev1", Status: domain.DeviceStatusActive})
	mustCreateDeviceRecord(t, repo, ctx, &domain.Device{UserID: 1, DeviceID: "user1-dev2", Status: domain.DeviceStatusActive})
	mustCreateDeviceRecord(t, repo, ctx, &domain.Device{UserID: 2, DeviceID: "user2-dev1", Status: domain.DeviceStatusActive})

	if err := repo.DeleteByUserID(ctx, 1); err != nil {
		t.Fatalf("DeleteByUserID() error = %v", err)
	}

	devices, _, err := repo.ListByUserID(ctx, 1, 0, 10)
	if err != nil {
		t.Fatalf("ListByUserID() error = %v", err)
	}
	if len(devices) != 0 {
		t.Errorf("用户1设备数 = %d, want 0", len(devices))
	}

	// User 2's device must survive.
	devices, _, err = repo.ListByUserID(ctx, 2, 0, 10)
	if err != nil {
		t.Fatalf("ListByUserID() error = %v", err)
	}
	if len(devices) != 1 {
		t.Errorf("用户2设备数 = %d, want 1", len(devices))
	}
}

// TestDeviceRepository_GetActiveDevices checks retrieval of recently active devices.
func TestDeviceRepository_GetActiveDevices(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()

	now := time.Now()
	// Both devices are created with a current LastActiveTime. GetActiveDevices
	// is expected to filter only on recent activity, not on status.
	mustCreateDeviceRecord(t, repo, ctx, &domain.Device{UserID: 1, DeviceID: "active-dev1", Status: domain.DeviceStatusActive, LastActiveTime: now})
	mustCreateDeviceRecord(t, repo, ctx, &domain.Device{UserID: 1, DeviceID: "recent-dev", Status: domain.DeviceStatusInactive, LastActiveTime: now})

	devices, err := repo.GetActiveDevices(ctx, 1)
	if err != nil {
		t.Fatalf("GetActiveDevices() error = %v", err)
	}
	// Expected contract: last_active_time within the past 30 days qualifies,
	// regardless of status, so both devices are returned.
	if len(devices) != 2 {
		t.Errorf("len(devices) = %d, want 2", len(devices))
	}
}

// TestDeviceRepository_TrustDevice checks that TrustDevice marks a device as trusted.
func TestDeviceRepository_TrustDevice(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()

	device := &domain.Device{
		UserID:     1,
		DeviceID:   "trust-test",
		DeviceName: "信任测试",
		Status:     domain.DeviceStatusActive,
	}
	mustCreateDeviceRecord(t, repo, ctx, device)

	expiresAt := time.Now().Add(30 * 24 * time.Hour)
	if err := repo.TrustDevice(ctx, device.ID, &expiresAt); err != nil {
		t.Fatalf("TrustDevice() error = %v", err)
	}

	found, err := repo.GetByID(ctx, device.ID)
	if err != nil {
		t.Fatalf("GetByID() error = %v", err)
	}
	if !found.IsTrusted {
		t.Error("IsTrusted 应为 true")
	}
}

// TestDeviceRepository_UntrustDevice checks that UntrustDevice clears the trusted flag.
func TestDeviceRepository_UntrustDevice(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()

	device := &domain.Device{
		UserID:     1,
		DeviceID:   "untrust-test",
		DeviceName: "取消信任测试",
		IsTrusted:  true,
		Status:     domain.DeviceStatusActive,
	}
	mustCreateDeviceRecord(t, repo, ctx, device)

	if err := repo.UntrustDevice(ctx, device.ID); err != nil {
		t.Fatalf("UntrustDevice() error = %v", err)
	}

	found, err := repo.GetByID(ctx, device.ID)
	if err != nil {
		t.Fatalf("GetByID() error = %v", err)
	}
	if found.IsTrusted {
		t.Error("IsTrusted 应为 false")
	}
}

// TestDeviceRepository_DeleteAllByUserIDExcept checks that every device of a
// user except the given one is deleted.
func TestDeviceRepository_DeleteAllByUserIDExcept(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()

	d1, _ := createDevice(t, repo, ctx, 1, "keep-me")
	createDevice(t, repo, ctx, 1, "delete-me1")
	createDevice(t, repo, ctx, 1, "delete-me2")

	if err := repo.DeleteAllByUserIDExcept(ctx, 1, d1.ID); err != nil {
		t.Fatalf("DeleteAllByUserIDExcept() error = %v", err)
	}

	devices, _, err := repo.ListByUserID(ctx, 1, 0, 10)
	if err != nil {
		t.Fatalf("ListByUserID() error = %v", err)
	}
	// Fatal rather than Error: devices[0] is indexed below.
	if len(devices) != 1 {
		t.Fatalf("len(devices) = %d, want 1", len(devices))
	}
	if devices[0].ID != d1.ID {
		t.Error("应保留指定设备")
	}
}

// TestDeviceRepository_GetTrustedDevices checks that only trusted devices are returned.
func TestDeviceRepository_GetTrustedDevices(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()

	trusted := &domain.Device{
		UserID:    1,
		DeviceID:  "trusted-device",
		IsTrusted: true,
		Status:    domain.DeviceStatusActive,
	}
	untrusted := &domain.Device{
		UserID:    1,
		DeviceID:  "untrusted-device",
		IsTrusted: false,
		Status:    domain.DeviceStatusActive,
	}
	mustCreateDeviceRecord(t, repo, ctx, trusted)
	mustCreateDeviceRecord(t, repo, ctx, untrusted)

	devices, err := repo.GetTrustedDevices(ctx, 1)
	if err != nil {
		t.Fatalf("GetTrustedDevices() error = %v", err)
	}
	if len(devices) != 1 {
		t.Errorf("len(devices) = %d, want 1", len(devices))
	}
}

// TestDeviceRepository_ListAll checks ListAll with a user filter and,
// separately, with a status filter.
func TestDeviceRepository_ListAll(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()

	mustCreateDeviceRecord(t, repo, ctx, &domain.Device{UserID: 1, DeviceID: "dev1", Status: domain.DeviceStatusActive})
	mustCreateDeviceRecord(t, repo, ctx, &domain.Device{UserID: 1, DeviceID: "dev2", Status: domain.DeviceStatusInactive})
	mustCreateDeviceRecord(t, repo, ctx, &domain.Device{UserID: 2, DeviceID: "dev3", Status: domain.DeviceStatusActive})

	// Filter by user.
	params := &ListDevicesParams{UserID: 1, Offset: 0, Limit: 10}
	_, total, err := repo.ListAll(ctx, params)
	if err != nil {
		t.Fatalf("ListAll() error = %v", err)
	}
	if total != 2 {
		t.Errorf("total = %d, want 2", total)
	}

	// Filter by status.
	status := domain.DeviceStatusActive
	params2 := &ListDevicesParams{Status: &status, Offset: 0, Limit: 10}
	_, total2, err := repo.ListAll(ctx, params2)
	if err != nil {
		t.Fatalf("ListAll() error = %v", err)
	}
	if total2 != 2 {
		t.Errorf("total = %d, want 2", total2)
	}
}

// createDevice is a test helper that creates an active device for userID.
// It aborts the test on failure, so the returned error is always nil; the
// error result is kept only for signature compatibility with existing callers.
func createDevice(t *testing.T, repo *DeviceRepository, ctx context.Context, userID int64, deviceID string) (*domain.Device, error) {
	t.Helper()
	d := &domain.Device{
		UserID:   userID,
		DeviceID: deviceID,
		Status:   domain.DeviceStatusActive,
	}
	if err := repo.Create(ctx, d); err != nil {
		t.Fatalf("createDevice() error = %v", err)
	}
	return d, nil
}

// TestDeviceRepository_ListAllCursor checks cursor paging across two pages
// of five devices with a page size of 3.
func TestDeviceRepository_ListAllCursor(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()

	// Distinct LastActiveTime values so the cursor ordering is well defined.
	now := time.Now()
	for i := 0; i < 5; i++ {
		mustCreateDeviceRecord(t, repo, ctx, &domain.Device{
			UserID:         int64(i + 1),
			DeviceID:       "cursor-device-" + string(rune('a'+i)),
			DeviceName:     "设备" + string(rune('0'+i)),
			Status:         domain.DeviceStatusActive,
			LastActiveTime: now.Add(-time.Duration(i) * time.Minute),
		})
	}

	// First page: 3 of 5.
	devices, hasMore, err := repo.ListAllCursor(ctx, &ListDevicesParams{Offset: 0, Limit: 10}, 3, nil)
	if err != nil {
		t.Fatalf("ListAllCursor() error = %v", err)
	}
	// Fatal rather than Error: the last element is indexed below.
	if len(devices) != 3 {
		t.Fatalf("len(devices) = %d, want 3", len(devices))
	}
	if !hasMore {
		t.Error("hasMore should be true when more devices exist")
	}

	// Second page, continuing from the last device of the first page.
	lastDevice := devices[len(devices)-1]
	cursor := &pagination.Cursor{
		LastID:    lastDevice.ID,
		LastValue: lastDevice.LastActiveTime,
	}
	devices2, hasMore2, err := repo.ListAllCursor(ctx, &ListDevicesParams{Offset: 0, Limit: 10}, 3, cursor)
	if err != nil {
		t.Fatalf("ListAllCursor() error = %v", err)
	}
	if len(devices2) != 2 {
		t.Errorf("len(devices2) = %d, want 2", len(devices2))
	}
	if hasMore2 {
		t.Error("hasMore2 should be false")
	}
}

// TestDeviceRepository_ListAllCursor_WithFilters checks cursor paging
// combined with user-ID and status filters.
func TestDeviceRepository_ListAllCursor_WithFilters(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()

	now := time.Now()
	mustCreateDeviceRecord(t, repo, ctx, &domain.Device{
		UserID:         1,
		DeviceID:       "filter-dev1",
		DeviceName:     "用户1设备",
		Status:         domain.DeviceStatusActive,
		LastActiveTime: now,
	})
	mustCreateDeviceRecord(t, repo, ctx, &domain.Device{
		UserID:         2,
		DeviceID:       "filter-dev2",
		DeviceName:     "用户2设备",
		Status:         domain.DeviceStatusActive,
		LastActiveTime: now,
	})
	mustCreateDeviceRecord(t, repo, ctx, &domain.Device{
		UserID:         1,
		DeviceID:       "filter-dev3",
		DeviceName:     "用户1禁用设备",
		Status:         domain.DeviceStatusInactive,
		LastActiveTime: now,
	})

	// Filter by user ID AND status: only filter-dev1 matches.
	status := domain.DeviceStatusActive
	devices, hasMore, err := repo.ListAllCursor(ctx, &ListDevicesParams{UserID: 1, Status: &status, Offset: 0, Limit: 10}, 10, nil)
	if err != nil {
		t.Fatalf("ListAllCursor() error = %v", err)
	}
	if len(devices) != 1 {
		t.Fatalf("len(devices) = %d, want 1", len(devices))
	}
	if devices[0].DeviceID != "filter-dev1" {
		t.Errorf("DeviceID = %v, want filter-dev1", devices[0].DeviceID)
	}
	if hasMore {
		t.Error("hasMore should be false")
	}
}