Files
user-system/internal/repository/device_repository_test.go
long-agent 582ad7a069 test: add comprehensive test coverage and improve code quality
- Add new test files for auth, service, and handler modules
- Improve test organization and coverage
- Refactor code for better maintainability
- Add captcha, settings, stats, and theme handler tests
- Add auth module tests (CAS, OAuth, password, SSO, state)
- Add service layer tests for auth, export, permissions, roles
- All Go tests pass (exit code 0)
- All frontend tests pass (325 tests in 59 files)
2026-04-17 20:43:50 +08:00

576 lines
16 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package repository
import (
"context"
"fmt"
"sync/atomic"
"testing"
"time"
gormsqlite "gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
_ "modernc.org/sqlite"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/pagination"
)
// deviceTestCounter gives every test database a distinct name.
var deviceTestCounter int64

// openDeviceTestDB opens an isolated in-memory SQLite database for a single
// test and migrates the domain.Device schema into it.
//
// The underlying connection pool is closed automatically when the test (and
// its subtests) finish, via t.Cleanup.
func openDeviceTestDB(t *testing.T) *gorm.DB {
	t.Helper()
	id := atomic.AddInt64(&deviceTestCounter, 1)
	dsn := fmt.Sprintf("file:devtestdb%d?mode=memory&cache=private", id)
	db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
		DriverName: "sqlite",
		DSN:        dsn,
	}), &gorm.Config{
		Logger: logger.Default.LogMode(logger.Silent),
	})
	if err != nil {
		t.Fatalf("打开测试数据库失败: %v", err)
	}
	sqlDB, err := db.DB()
	if err != nil {
		t.Fatalf("获取底层数据库连接失败: %v", err)
	}
	// With cache=private, each new connection to this DSN opens its own,
	// empty in-memory database. Pin the pool to a single connection so every
	// query sees the tables created by AutoMigrate below.
	sqlDB.SetMaxOpenConns(1)
	// Release the pool (and with it the in-memory database) after the test.
	t.Cleanup(func() { _ = sqlDB.Close() })
	if err := db.AutoMigrate(&domain.Device{}); err != nil {
		t.Fatalf("数据库迁移失败: %v", err)
	}
	return db
}
// setupDeviceTestDB is a compatibility alias for openDeviceTestDB.
//
// It is marked as a test helper so that failures raised during setup are
// reported at the calling test's line rather than inside this wrapper.
func setupDeviceTestDB(t *testing.T) *gorm.DB {
	t.Helper()
	return openDeviceTestDB(t)
}
// TestDeviceRepository_Create 测试创建设备
func TestDeviceRepository_Create(t *testing.T) {
db := setupDeviceTestDB(t)
repo := NewDeviceRepository(db)
ctx := context.Background()
device := &domain.Device{
UserID: 1,
DeviceID: "test-device-001",
DeviceName: "测试手机",
DeviceType: domain.DeviceTypeMobile,
Status: domain.DeviceStatusActive,
}
if err := repo.Create(ctx, device); err != nil {
t.Fatalf("Create() error = %v", err)
}
if device.ID == 0 {
t.Error("创建后设备ID不应为0")
}
}
// TestDeviceRepository_GetByID verifies that a device can be read back by its
// primary key.
func TestDeviceRepository_GetByID(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()
	device := &domain.Device{
		UserID:     1,
		DeviceID:   "test-device-002",
		DeviceName: "测试平板",
		Status:     domain.DeviceStatusActive,
	}
	// Fail fast if seeding fails; otherwise GetByID would be called with ID 0.
	if err := repo.Create(ctx, device); err != nil {
		t.Fatalf("Create() error = %v", err)
	}
	found, err := repo.GetByID(ctx, device.ID)
	if err != nil {
		t.Fatalf("GetByID() error = %v", err)
	}
	if found.DeviceID != "test-device-002" {
		t.Errorf("DeviceID = %v, want test-device-002", found.DeviceID)
	}
}
// TestDeviceRepository_GetByDeviceID verifies lookup by (user ID, device
// identifier).
func TestDeviceRepository_GetByDeviceID(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()
	device := &domain.Device{
		UserID:     1,
		DeviceID:   "unique-device-id",
		DeviceName: "测试设备",
		Status:     domain.DeviceStatusActive,
	}
	if err := repo.Create(ctx, device); err != nil {
		t.Fatalf("Create() error = %v", err)
	}
	found, err := repo.GetByDeviceID(ctx, 1, "unique-device-id")
	if err != nil {
		t.Fatalf("GetByDeviceID() error = %v", err)
	}
	if found.UserID != 1 {
		t.Errorf("UserID = %v, want 1", found.UserID)
	}
}
// TestDeviceRepository_Update verifies that Update persists a changed field.
func TestDeviceRepository_Update(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()
	device := &domain.Device{
		UserID:     1,
		DeviceID:   "update-test",
		DeviceName: "旧名称",
		Status:     domain.DeviceStatusActive,
	}
	if err := repo.Create(ctx, device); err != nil {
		t.Fatalf("Create() error = %v", err)
	}
	device.DeviceName = "新名称"
	if err := repo.Update(ctx, device); err != nil {
		t.Fatalf("Update() error = %v", err)
	}
	// Check the error so a failed read reports cleanly instead of panicking
	// on a nil result.
	found, err := repo.GetByID(ctx, device.ID)
	if err != nil {
		t.Fatalf("GetByID() error = %v", err)
	}
	if found.DeviceName != "新名称" {
		t.Errorf("DeviceName = %v, want 新名称", found.DeviceName)
	}
}
// TestDeviceRepository_Delete verifies that a deleted device can no longer be
// fetched by ID.
func TestDeviceRepository_Delete(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()
	device := &domain.Device{
		UserID:     1,
		DeviceID:   "delete-test",
		DeviceName: "待删除",
		Status:     domain.DeviceStatusActive,
	}
	// If Create failed silently, device.ID would stay 0 and the "not found"
	// check below would pass vacuously.
	if err := repo.Create(ctx, device); err != nil {
		t.Fatalf("Create() error = %v", err)
	}
	if err := repo.Delete(ctx, device.ID); err != nil {
		t.Fatalf("Delete() error = %v", err)
	}
	_, err := repo.GetByID(ctx, device.ID)
	if err == nil {
		t.Error("删除后查询应返回错误")
	}
}
// TestDeviceRepository_List verifies unfiltered paging returns every device
// and the correct total count.
func TestDeviceRepository_List(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()
	for i := 0; i < 3; i++ {
		d := &domain.Device{
			UserID:   int64(i + 1),
			DeviceID: "list-device-" + string(rune('a'+i)),
			Status:   domain.DeviceStatusActive,
		}
		if err := repo.Create(ctx, d); err != nil {
			t.Fatalf("Create(%s) error = %v", d.DeviceID, err)
		}
	}
	devices, total, err := repo.List(ctx, 0, 10)
	if err != nil {
		t.Fatalf("List() error = %v", err)
	}
	if len(devices) != 3 {
		t.Errorf("len(devices) = %d, want 3", len(devices))
	}
	if total != 3 {
		t.Errorf("total = %d, want 3", total)
	}
}
// TestDeviceRepository_ListByUserID verifies that only the given user's
// devices are returned and counted.
func TestDeviceRepository_ListByUserID(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()
	seed := []*domain.Device{
		{UserID: 1, DeviceID: "user1-dev1", Status: domain.DeviceStatusActive},
		{UserID: 1, DeviceID: "user1-dev2", Status: domain.DeviceStatusActive},
		{UserID: 2, DeviceID: "user2-dev1", Status: domain.DeviceStatusActive},
	}
	for _, d := range seed {
		if err := repo.Create(ctx, d); err != nil {
			t.Fatalf("Create(%s) error = %v", d.DeviceID, err)
		}
	}
	devices, total, err := repo.ListByUserID(ctx, 1, 0, 10)
	if err != nil {
		t.Fatalf("ListByUserID() error = %v", err)
	}
	if len(devices) != 2 {
		t.Errorf("len(devices) = %d, want 2", len(devices))
	}
	if total != 2 {
		t.Errorf("total = %d, want 2", total)
	}
}
// TestDeviceRepository_ListByStatus verifies filtering by device status.
func TestDeviceRepository_ListByStatus(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()
	seed := []*domain.Device{
		{UserID: 1, DeviceID: "active1", Status: domain.DeviceStatusActive},
		{UserID: 2, DeviceID: "active2", Status: domain.DeviceStatusActive},
		{UserID: 3, DeviceID: "inactive1", Status: domain.DeviceStatusInactive},
	}
	for _, d := range seed {
		if err := repo.Create(ctx, d); err != nil {
			t.Fatalf("Create(%s) error = %v", d.DeviceID, err)
		}
	}
	devices, total, err := repo.ListByStatus(ctx, domain.DeviceStatusActive, 0, 10)
	if err != nil {
		t.Fatalf("ListByStatus() error = %v", err)
	}
	if len(devices) != 2 {
		t.Errorf("len(devices) = %d, want 2", len(devices))
	}
	if total != 2 {
		t.Errorf("total = %d, want 2", total)
	}
	// The count alone does not prove the filter; check each returned row.
	for _, d := range devices {
		if d.Status != domain.DeviceStatusActive {
			t.Errorf("device %s Status = %v, want Active", d.DeviceID, d.Status)
		}
	}
}
// TestDeviceRepository_UpdateStatus verifies that UpdateStatus changes only
// the status column as observed via GetByID.
func TestDeviceRepository_UpdateStatus(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()
	device := &domain.Device{
		UserID:     1,
		DeviceID:   "status-test",
		DeviceName: "状态测试",
		Status:     domain.DeviceStatusActive,
	}
	if err := repo.Create(ctx, device); err != nil {
		t.Fatalf("Create() error = %v", err)
	}
	if err := repo.UpdateStatus(ctx, device.ID, domain.DeviceStatusInactive); err != nil {
		t.Fatalf("UpdateStatus() error = %v", err)
	}
	found, err := repo.GetByID(ctx, device.ID)
	if err != nil {
		t.Fatalf("GetByID() error = %v", err)
	}
	if found.Status != domain.DeviceStatusInactive {
		t.Errorf("Status = %v, want Inactive", found.Status)
	}
}
// TestDeviceRepository_Exists verifies Exists for both a present and an
// absent device identifier.
func TestDeviceRepository_Exists(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()
	device := &domain.Device{
		UserID:     1,
		DeviceID:   "exists-test",
		DeviceName: "存在性测试",
		Status:     domain.DeviceStatusActive,
	}
	if err := repo.Create(ctx, device); err != nil {
		t.Fatalf("Create() error = %v", err)
	}
	exists, err := repo.Exists(ctx, 1, "exists-test")
	if err != nil {
		t.Fatalf("Exists() error = %v", err)
	}
	if !exists {
		t.Error("Exists 应返回 true")
	}
	// A query error would also yield false, so check it explicitly.
	exists, err = repo.Exists(ctx, 1, "not-exists")
	if err != nil {
		t.Fatalf("Exists() error = %v", err)
	}
	if exists {
		t.Error("不存在的设备 Exists 应返回 false")
	}
}
// TestDeviceRepository_DeleteByUserID verifies that all of one user's devices
// are removed while other users' devices are left intact.
func TestDeviceRepository_DeleteByUserID(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()
	seed := []*domain.Device{
		{UserID: 1, DeviceID: "user1-dev1", Status: domain.DeviceStatusActive},
		{UserID: 1, DeviceID: "user1-dev2", Status: domain.DeviceStatusActive},
		{UserID: 2, DeviceID: "user2-dev1", Status: domain.DeviceStatusActive},
	}
	for _, d := range seed {
		if err := repo.Create(ctx, d); err != nil {
			t.Fatalf("Create(%s) error = %v", d.DeviceID, err)
		}
	}
	if err := repo.DeleteByUserID(ctx, 1); err != nil {
		t.Fatalf("DeleteByUserID() error = %v", err)
	}
	// A failed list would also return zero devices, so check the error.
	devices, _, err := repo.ListByUserID(ctx, 1, 0, 10)
	if err != nil {
		t.Fatalf("ListByUserID(1) error = %v", err)
	}
	if len(devices) != 0 {
		t.Errorf("用户1设备数 = %d, want 0", len(devices))
	}
	// User 2's device must still be present.
	devices, _, err = repo.ListByUserID(ctx, 2, 0, 10)
	if err != nil {
		t.Fatalf("ListByUserID(2) error = %v", err)
	}
	if len(devices) != 1 {
		t.Errorf("用户2设备数 = %d, want 1", len(devices))
	}
}
// TestDeviceRepository_GetActiveDevices verifies that recently active devices
// are returned regardless of their status field.
func TestDeviceRepository_GetActiveDevices(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()
	now := time.Now()
	// Both devices get a fresh LastActiveTime. The expectation below assumes
	// GetActiveDevices filters only on last_active_time (a 30-day window, per
	// the original author's note) and ignores status — TODO confirm against
	// the repository implementation.
	seed := []*domain.Device{
		{UserID: 1, DeviceID: "active-dev1", Status: domain.DeviceStatusActive, LastActiveTime: now},
		{UserID: 1, DeviceID: "recent-dev", Status: domain.DeviceStatusInactive, LastActiveTime: now},
	}
	for _, d := range seed {
		if err := repo.Create(ctx, d); err != nil {
			t.Fatalf("Create(%s) error = %v", d.DeviceID, err)
		}
	}
	devices, err := repo.GetActiveDevices(ctx, 1)
	if err != nil {
		t.Fatalf("GetActiveDevices() error = %v", err)
	}
	if len(devices) != 2 {
		t.Errorf("len(devices) = %d, want 2", len(devices))
	}
}
// TestDeviceRepository_TrustDevice verifies that TrustDevice marks a device as
// trusted.
func TestDeviceRepository_TrustDevice(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()
	device := &domain.Device{
		UserID:     1,
		DeviceID:   "trust-test",
		DeviceName: "信任测试",
		Status:     domain.DeviceStatusActive,
	}
	if err := repo.Create(ctx, device); err != nil {
		t.Fatalf("Create() error = %v", err)
	}
	expiresAt := time.Now().Add(30 * 24 * time.Hour)
	if err := repo.TrustDevice(ctx, device.ID, &expiresAt); err != nil {
		t.Fatalf("TrustDevice() error = %v", err)
	}
	found, err := repo.GetByID(ctx, device.ID)
	if err != nil {
		t.Fatalf("GetByID() error = %v", err)
	}
	if !found.IsTrusted {
		t.Error("IsTrusted 应为 true")
	}
}
// TestDeviceRepository_UntrustDevice verifies that UntrustDevice clears the
// trusted flag of a previously trusted device.
func TestDeviceRepository_UntrustDevice(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()
	device := &domain.Device{
		UserID:     1,
		DeviceID:   "untrust-test",
		DeviceName: "取消信任测试",
		IsTrusted:  true,
		Status:     domain.DeviceStatusActive,
	}
	if err := repo.Create(ctx, device); err != nil {
		t.Fatalf("Create() error = %v", err)
	}
	// Precondition: without this check, the test would pass even if
	// UntrustDevice were a no-op and the device was never stored as trusted.
	before, err := repo.GetByID(ctx, device.ID)
	if err != nil {
		t.Fatalf("GetByID() before untrust error = %v", err)
	}
	if !before.IsTrusted {
		t.Fatal("precondition failed: device should be trusted before UntrustDevice")
	}
	if err := repo.UntrustDevice(ctx, device.ID); err != nil {
		t.Fatalf("UntrustDevice() error = %v", err)
	}
	found, err := repo.GetByID(ctx, device.ID)
	if err != nil {
		t.Fatalf("GetByID() error = %v", err)
	}
	if found.IsTrusted {
		t.Error("IsTrusted 应为 false")
	}
}
// TestDeviceRepository_DeleteAllByUserIDExcept verifies that all of a user's
// devices except the specified one are deleted.
func TestDeviceRepository_DeleteAllByUserIDExcept(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()
	keep, _ := createDevice(t, repo, ctx, 1, "keep-me")
	createDevice(t, repo, ctx, 1, "delete-me1")
	createDevice(t, repo, ctx, 1, "delete-me2")
	if err := repo.DeleteAllByUserIDExcept(ctx, 1, keep.ID); err != nil {
		t.Fatalf("DeleteAllByUserIDExcept() error = %v", err)
	}
	devices, _, err := repo.ListByUserID(ctx, 1, 0, 10)
	if err != nil {
		t.Fatalf("ListByUserID() error = %v", err)
	}
	// Fatal, not Error: devices[0] below would panic on an empty result.
	if len(devices) != 1 {
		t.Fatalf("len(devices) = %d, want 1", len(devices))
	}
	if devices[0].ID != keep.ID {
		t.Error("应保留指定设备")
	}
}
// TestDeviceRepository_GetTrustedDevices verifies that only trusted devices
// are returned.
func TestDeviceRepository_GetTrustedDevices(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()
	trusted := &domain.Device{
		UserID:    1,
		DeviceID:  "trusted-device",
		IsTrusted: true,
		Status:    domain.DeviceStatusActive,
	}
	untrusted := &domain.Device{
		UserID:    1,
		DeviceID:  "untrusted-device",
		IsTrusted: false,
		Status:    domain.DeviceStatusActive,
	}
	for _, d := range []*domain.Device{trusted, untrusted} {
		if err := repo.Create(ctx, d); err != nil {
			t.Fatalf("Create(%s) error = %v", d.DeviceID, err)
		}
	}
	devices, err := repo.GetTrustedDevices(ctx, 1)
	if err != nil {
		t.Fatalf("GetTrustedDevices() error = %v", err)
	}
	if len(devices) != 1 {
		t.Fatalf("len(devices) = %d, want 1", len(devices))
	}
	// Make sure the single result is the trusted device, not the other one.
	if devices[0].DeviceID != "trusted-device" {
		t.Errorf("DeviceID = %v, want trusted-device", devices[0].DeviceID)
	}
}
// TestDeviceRepository_ListAll verifies ListAll filtering by user ID and by
// status.
func TestDeviceRepository_ListAll(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()
	seed := []*domain.Device{
		{UserID: 1, DeviceID: "dev1", Status: domain.DeviceStatusActive},
		{UserID: 1, DeviceID: "dev2", Status: domain.DeviceStatusInactive},
		{UserID: 2, DeviceID: "dev3", Status: domain.DeviceStatusActive},
	}
	for _, d := range seed {
		if err := repo.Create(ctx, d); err != nil {
			t.Fatalf("Create(%s) error = %v", d.DeviceID, err)
		}
	}
	// Filter by user: user 1 owns dev1 and dev2.
	params := &ListDevicesParams{UserID: 1, Offset: 0, Limit: 10}
	_, total, err := repo.ListAll(ctx, params)
	if err != nil {
		t.Fatalf("ListAll() error = %v", err)
	}
	if total != 2 {
		t.Errorf("total = %d, want 2", total)
	}
	// Filter by status: dev1 and dev3 are active.
	status := domain.DeviceStatusActive
	params2 := &ListDevicesParams{Status: &status, Offset: 0, Limit: 10}
	_, total2, err := repo.ListAll(ctx, params2)
	if err != nil {
		t.Fatalf("ListAll() error = %v", err)
	}
	if total2 != 2 {
		t.Errorf("total = %d, want 2", total2)
	}
}
// createDevice is a test helper that inserts an active device for userID with
// the given deviceID and returns it with its generated ID populated.
//
// Any Create failure aborts the test via t.Fatalf, so the returned error is
// always nil; it is kept only for compatibility with existing callers.
func createDevice(t *testing.T, repo *DeviceRepository, ctx context.Context, userID int64, deviceID string) (*domain.Device, error) {
	// Attribute failures to the calling test line, not to this helper.
	t.Helper()
	d := &domain.Device{
		UserID:   userID,
		DeviceID: deviceID,
		Status:   domain.DeviceStatusActive,
	}
	if err := repo.Create(ctx, d); err != nil {
		t.Fatalf("createDevice() error = %v", err)
	}
	return d, nil
}
// TestDeviceRepository_ListAllCursor verifies cursor-based paging across two
// pages: a full first page with hasMore=true, then the remainder with
// hasMore=false.
func TestDeviceRepository_ListAllCursor(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()
	// Each device gets a distinct LastActiveTime, which the cursor carries as
	// its LastValue.
	now := time.Now()
	for i := 0; i < 5; i++ {
		d := &domain.Device{
			UserID:         int64(i + 1),
			DeviceID:       "cursor-device-" + string(rune('a'+i)),
			DeviceName:     "设备" + string(rune('0'+i)),
			Status:         domain.DeviceStatusActive,
			LastActiveTime: now.Add(-time.Duration(i) * time.Minute),
		}
		if err := repo.Create(ctx, d); err != nil {
			t.Fatalf("Create(%s) error = %v", d.DeviceID, err)
		}
	}
	// First page: 3 of the 5 devices.
	devices, hasMore, err := repo.ListAllCursor(ctx, &ListDevicesParams{Offset: 0, Limit: 10}, 3, nil)
	if err != nil {
		t.Fatalf("ListAllCursor() error = %v", err)
	}
	// Fatal, not Error: the cursor below indexes the last element.
	if len(devices) != 3 {
		t.Fatalf("len(devices) = %d, want 3", len(devices))
	}
	if !hasMore {
		t.Error("hasMore should be true when more devices exist")
	}
	// Continue from the last device of the first page.
	lastDevice := devices[len(devices)-1]
	cursor := &pagination.Cursor{
		LastID:    lastDevice.ID,
		LastValue: lastDevice.LastActiveTime,
	}
	devices2, hasMore2, err := repo.ListAllCursor(ctx, &ListDevicesParams{Offset: 0, Limit: 10}, 3, cursor)
	if err != nil {
		t.Fatalf("ListAllCursor() error = %v", err)
	}
	if len(devices2) != 2 {
		t.Errorf("len(devices2) = %d, want 2", len(devices2))
	}
	if hasMore2 {
		t.Error("hasMore2 should be false")
	}
}
// TestDeviceRepository_ListAllCursor_WithFilters verifies that cursor paging
// honours combined user-ID and status filters.
func TestDeviceRepository_ListAllCursor_WithFilters(t *testing.T) {
	db := setupDeviceTestDB(t)
	repo := NewDeviceRepository(db)
	ctx := context.Background()
	now := time.Now()
	seed := []*domain.Device{
		{
			UserID:         1,
			DeviceID:       "filter-dev1",
			DeviceName:     "用户1设备",
			Status:         domain.DeviceStatusActive,
			LastActiveTime: now,
		},
		{
			UserID:         2,
			DeviceID:       "filter-dev2",
			DeviceName:     "用户2设备",
			Status:         domain.DeviceStatusActive,
			LastActiveTime: now,
		},
		{
			UserID:         1,
			DeviceID:       "filter-dev3",
			DeviceName:     "用户1禁用设备",
			Status:         domain.DeviceStatusInactive,
			LastActiveTime: now,
		},
	}
	for _, d := range seed {
		if err := repo.Create(ctx, d); err != nil {
			t.Fatalf("Create(%s) error = %v", d.DeviceID, err)
		}
	}
	// Filter by user ID 1 AND active status: only filter-dev1 matches.
	status := domain.DeviceStatusActive
	devices, _, err := repo.ListAllCursor(ctx, &ListDevicesParams{UserID: 1, Status: &status, Offset: 0, Limit: 10}, 10, nil)
	if err != nil {
		t.Fatalf("ListAllCursor() error = %v", err)
	}
	if len(devices) != 1 {
		t.Fatalf("len(devices) = %d, want 1", len(devices))
	}
	if devices[0].DeviceID != "filter-dev1" {
		t.Errorf("DeviceID = %v, want filter-dev1", devices[0].DeviceID)
	}
}