fix/status-review-sync-20260409 #1
486
internal/repository/device_repository_test.go
Normal file
486
internal/repository/device_repository_test.go
Normal file
@@ -0,0 +1,486 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
gormsqlite "gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
)
|
||||
|
||||
var deviceTestCounter int64
|
||||
|
||||
// openDeviceTestDB 为每个测试打开独立的内存数据库
|
||||
func openDeviceTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
|
||||
id := atomic.AddInt64(&deviceTestCounter, 1)
|
||||
dsn := fmt.Sprintf("file:devtestdb%d?mode=memory&cache=private", id)
|
||||
|
||||
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
|
||||
DriverName: "sqlite",
|
||||
DSN: dsn,
|
||||
}), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("打开测试数据库失败: %v", err)
|
||||
}
|
||||
|
||||
if err := db.AutoMigrate(&domain.Device{}); err != nil {
|
||||
t.Fatalf("数据库迁移失败: %v", err)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
// setupDeviceTestDB 兼容性别名
|
||||
func setupDeviceTestDB(t *testing.T) *gorm.DB {
|
||||
return openDeviceTestDB(t)
|
||||
}
|
||||
|
||||
// TestDeviceRepository_Create 测试创建设备
|
||||
func TestDeviceRepository_Create(t *testing.T) {
|
||||
db := setupDeviceTestDB(t)
|
||||
repo := NewDeviceRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
device := &domain.Device{
|
||||
UserID: 1,
|
||||
DeviceID: "test-device-001",
|
||||
DeviceName: "测试手机",
|
||||
DeviceType: domain.DeviceTypeMobile,
|
||||
Status: domain.DeviceStatusActive,
|
||||
}
|
||||
|
||||
if err := repo.Create(ctx, device); err != nil {
|
||||
t.Fatalf("Create() error = %v", err)
|
||||
}
|
||||
if device.ID == 0 {
|
||||
t.Error("创建后设备ID不应为0")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeviceRepository_GetByID 测试根据ID获取设备
|
||||
func TestDeviceRepository_GetByID(t *testing.T) {
|
||||
db := setupDeviceTestDB(t)
|
||||
repo := NewDeviceRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
device := &domain.Device{
|
||||
UserID: 1,
|
||||
DeviceID: "test-device-002",
|
||||
DeviceName: "测试平板",
|
||||
Status: domain.DeviceStatusActive,
|
||||
}
|
||||
repo.Create(ctx, device)
|
||||
|
||||
found, err := repo.GetByID(ctx, device.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetByID() error = %v", err)
|
||||
}
|
||||
if found.DeviceID != "test-device-002" {
|
||||
t.Errorf("DeviceID = %v, want test-device-002", found.DeviceID)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeviceRepository_GetByDeviceID 测试根据设备标识查询
|
||||
func TestDeviceRepository_GetByDeviceID(t *testing.T) {
|
||||
db := setupDeviceTestDB(t)
|
||||
repo := NewDeviceRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
device := &domain.Device{
|
||||
UserID: 1,
|
||||
DeviceID: "unique-device-id",
|
||||
DeviceName: "测试设备",
|
||||
Status: domain.DeviceStatusActive,
|
||||
}
|
||||
repo.Create(ctx, device)
|
||||
|
||||
found, err := repo.GetByDeviceID(ctx, 1, "unique-device-id")
|
||||
if err != nil {
|
||||
t.Fatalf("GetByDeviceID() error = %v", err)
|
||||
}
|
||||
if found.UserID != 1 {
|
||||
t.Errorf("UserID = %v, want 1", found.UserID)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeviceRepository_Update 测试更新设备
|
||||
func TestDeviceRepository_Update(t *testing.T) {
|
||||
db := setupDeviceTestDB(t)
|
||||
repo := NewDeviceRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
device := &domain.Device{
|
||||
UserID: 1,
|
||||
DeviceID: "update-test",
|
||||
DeviceName: "旧名称",
|
||||
Status: domain.DeviceStatusActive,
|
||||
}
|
||||
repo.Create(ctx, device)
|
||||
|
||||
device.DeviceName = "新名称"
|
||||
if err := repo.Update(ctx, device); err != nil {
|
||||
t.Fatalf("Update() error = %v", err)
|
||||
}
|
||||
|
||||
found, _ := repo.GetByID(ctx, device.ID)
|
||||
if found.DeviceName != "新名称" {
|
||||
t.Errorf("DeviceName = %v, want 新名称", found.DeviceName)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeviceRepository_Delete 测试删除设备
|
||||
func TestDeviceRepository_Delete(t *testing.T) {
|
||||
db := setupDeviceTestDB(t)
|
||||
repo := NewDeviceRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
device := &domain.Device{
|
||||
UserID: 1,
|
||||
DeviceID: "delete-test",
|
||||
DeviceName: "待删除",
|
||||
Status: domain.DeviceStatusActive,
|
||||
}
|
||||
repo.Create(ctx, device)
|
||||
|
||||
if err := repo.Delete(ctx, device.ID); err != nil {
|
||||
t.Fatalf("Delete() error = %v", err)
|
||||
}
|
||||
|
||||
_, err := repo.GetByID(ctx, device.ID)
|
||||
if err == nil {
|
||||
t.Error("删除后查询应返回错误")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeviceRepository_List 测试列表查询
|
||||
func TestDeviceRepository_List(t *testing.T) {
|
||||
db := setupDeviceTestDB(t)
|
||||
repo := NewDeviceRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
repo.Create(ctx, &domain.Device{
|
||||
UserID: int64(i + 1),
|
||||
DeviceID: "list-device-" + string(rune('a'+i)),
|
||||
Status: domain.DeviceStatusActive,
|
||||
})
|
||||
}
|
||||
|
||||
devices, total, err := repo.List(ctx, 0, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("List() error = %v", err)
|
||||
}
|
||||
if len(devices) != 3 {
|
||||
t.Errorf("len(devices) = %d, want 3", len(devices))
|
||||
}
|
||||
if total != 3 {
|
||||
t.Errorf("total = %d, want 3", total)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeviceRepository_ListByUserID 测试按用户ID查询设备列表
|
||||
func TestDeviceRepository_ListByUserID(t *testing.T) {
|
||||
db := setupDeviceTestDB(t)
|
||||
repo := NewDeviceRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
repo.Create(ctx, &domain.Device{UserID: 1, DeviceID: "user1-dev1", Status: domain.DeviceStatusActive})
|
||||
repo.Create(ctx, &domain.Device{UserID: 1, DeviceID: "user1-dev2", Status: domain.DeviceStatusActive})
|
||||
repo.Create(ctx, &domain.Device{UserID: 2, DeviceID: "user2-dev1", Status: domain.DeviceStatusActive})
|
||||
|
||||
devices, total, err := repo.ListByUserID(ctx, 1, 0, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("ListByUserID() error = %v", err)
|
||||
}
|
||||
if len(devices) != 2 {
|
||||
t.Errorf("len(devices) = %d, want 2", len(devices))
|
||||
}
|
||||
if total != 2 {
|
||||
t.Errorf("total = %d, want 2", total)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeviceRepository_ListByStatus 测试按状态查询设备列表
|
||||
func TestDeviceRepository_ListByStatus(t *testing.T) {
|
||||
db := setupDeviceTestDB(t)
|
||||
repo := NewDeviceRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
repo.Create(ctx, &domain.Device{UserID: 1, DeviceID: "active1", Status: domain.DeviceStatusActive})
|
||||
repo.Create(ctx, &domain.Device{UserID: 2, DeviceID: "active2", Status: domain.DeviceStatusActive})
|
||||
repo.Create(ctx, &domain.Device{UserID: 3, DeviceID: "inactive1", Status: domain.DeviceStatusInactive})
|
||||
|
||||
devices, total, err := repo.ListByStatus(ctx, domain.DeviceStatusActive, 0, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("ListByStatus() error = %v", err)
|
||||
}
|
||||
if len(devices) != 2 {
|
||||
t.Errorf("len(devices) = %d, want 2", len(devices))
|
||||
}
|
||||
if total != 2 {
|
||||
t.Errorf("total = %d, want 2", total)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeviceRepository_UpdateStatus 测试更新设备状态
|
||||
func TestDeviceRepository_UpdateStatus(t *testing.T) {
|
||||
db := setupDeviceTestDB(t)
|
||||
repo := NewDeviceRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
device := &domain.Device{
|
||||
UserID: 1,
|
||||
DeviceID: "status-test",
|
||||
DeviceName: "状态测试",
|
||||
Status: domain.DeviceStatusActive,
|
||||
}
|
||||
repo.Create(ctx, device)
|
||||
|
||||
err := repo.UpdateStatus(ctx, device.ID, domain.DeviceStatusInactive)
|
||||
if err != nil {
|
||||
t.Fatalf("UpdateStatus() error = %v", err)
|
||||
}
|
||||
|
||||
found, _ := repo.GetByID(ctx, device.ID)
|
||||
if found.Status != domain.DeviceStatusInactive {
|
||||
t.Errorf("Status = %v, want Inactive", found.Status)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeviceRepository_Exists 测试设备存在性检查
|
||||
func TestDeviceRepository_Exists(t *testing.T) {
|
||||
db := setupDeviceTestDB(t)
|
||||
repo := NewDeviceRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
device := &domain.Device{
|
||||
UserID: 1,
|
||||
DeviceID: "exists-test",
|
||||
DeviceName: "存在性测试",
|
||||
Status: domain.DeviceStatusActive,
|
||||
}
|
||||
repo.Create(ctx, device)
|
||||
|
||||
exists, err := repo.Exists(ctx, 1, "exists-test")
|
||||
if err != nil {
|
||||
t.Fatalf("Exists() error = %v", err)
|
||||
}
|
||||
if !exists {
|
||||
t.Error("Exists 应返回 true")
|
||||
}
|
||||
|
||||
exists, _ = repo.Exists(ctx, 1, "not-exists")
|
||||
if exists {
|
||||
t.Error("不存在的设备 Exists 应返回 false")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeviceRepository_DeleteByUserID 测试删除用户的所有设备
|
||||
func TestDeviceRepository_DeleteByUserID(t *testing.T) {
|
||||
db := setupDeviceTestDB(t)
|
||||
repo := NewDeviceRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
repo.Create(ctx, &domain.Device{UserID: 1, DeviceID: "user1-dev1", Status: domain.DeviceStatusActive})
|
||||
repo.Create(ctx, &domain.Device{UserID: 1, DeviceID: "user1-dev2", Status: domain.DeviceStatusActive})
|
||||
repo.Create(ctx, &domain.Device{UserID: 2, DeviceID: "user2-dev1", Status: domain.DeviceStatusActive})
|
||||
|
||||
err := repo.DeleteByUserID(ctx, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("DeleteByUserID() error = %v", err)
|
||||
}
|
||||
|
||||
devices, _, _ := repo.ListByUserID(ctx, 1, 0, 10)
|
||||
if len(devices) != 0 {
|
||||
t.Errorf("用户1设备数 = %d, want 0", len(devices))
|
||||
}
|
||||
|
||||
// 用户2的设备应该还在
|
||||
devices, _, _ = repo.ListByUserID(ctx, 2, 0, 10)
|
||||
if len(devices) != 1 {
|
||||
t.Errorf("用户2设备数 = %d, want 1", len(devices))
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeviceRepository_GetActiveDevices 测试获取活跃设备
|
||||
func TestDeviceRepository_GetActiveDevices(t *testing.T) {
|
||||
db := setupDeviceTestDB(t)
|
||||
repo := NewDeviceRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
now := time.Now()
|
||||
// 创建设备并设置 LastActiveTime(GetActiveDevices 不检查状态,只检查最近活跃时间)
|
||||
repo.Create(ctx, &domain.Device{UserID: 1, DeviceID: "active-dev1", Status: domain.DeviceStatusActive, LastActiveTime: now})
|
||||
repo.Create(ctx, &domain.Device{UserID: 1, DeviceID: "recent-dev", Status: domain.DeviceStatusInactive, LastActiveTime: now})
|
||||
|
||||
devices, err := repo.GetActiveDevices(ctx, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("GetActiveDevices() error = %v", err)
|
||||
}
|
||||
// GetActiveDevices 只检查 last_active_time > 30天前,不检查 status
|
||||
if len(devices) != 2 {
|
||||
t.Errorf("len(devices) = %d, want 2", len(devices))
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeviceRepository_TrustDevice 测试设置设备信任
|
||||
func TestDeviceRepository_TrustDevice(t *testing.T) {
|
||||
db := setupDeviceTestDB(t)
|
||||
repo := NewDeviceRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
device := &domain.Device{
|
||||
UserID: 1,
|
||||
DeviceID: "trust-test",
|
||||
DeviceName: "信任测试",
|
||||
Status: domain.DeviceStatusActive,
|
||||
}
|
||||
repo.Create(ctx, device)
|
||||
|
||||
expiresAt := time.Now().Add(30 * 24 * time.Hour)
|
||||
err := repo.TrustDevice(ctx, device.ID, &expiresAt)
|
||||
if err != nil {
|
||||
t.Fatalf("TrustDevice() error = %v", err)
|
||||
}
|
||||
|
||||
found, _ := repo.GetByID(ctx, device.ID)
|
||||
if !found.IsTrusted {
|
||||
t.Error("IsTrusted 应为 true")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeviceRepository_UntrustDevice 测试取消设备信任
|
||||
func TestDeviceRepository_UntrustDevice(t *testing.T) {
|
||||
db := setupDeviceTestDB(t)
|
||||
repo := NewDeviceRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
device := &domain.Device{
|
||||
UserID: 1,
|
||||
DeviceID: "untrust-test",
|
||||
DeviceName: "取消信任测试",
|
||||
IsTrusted: true,
|
||||
Status: domain.DeviceStatusActive,
|
||||
}
|
||||
repo.Create(ctx, device)
|
||||
|
||||
err := repo.UntrustDevice(ctx, device.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("UntrustDevice() error = %v", err)
|
||||
}
|
||||
|
||||
found, _ := repo.GetByID(ctx, device.ID)
|
||||
if found.IsTrusted {
|
||||
t.Error("IsTrusted 应为 false")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeviceRepository_DeleteAllByUserIDExcept 测试删除用户设备(保留指定设备)
|
||||
func TestDeviceRepository_DeleteAllByUserIDExcept(t *testing.T) {
|
||||
db := setupDeviceTestDB(t)
|
||||
repo := NewDeviceRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
d1, _ := createDevice(t, repo, ctx, 1, "keep-me")
|
||||
createDevice(t, repo, ctx, 1, "delete-me1")
|
||||
createDevice(t, repo, ctx, 1, "delete-me2")
|
||||
|
||||
err := repo.DeleteAllByUserIDExcept(ctx, 1, d1.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("DeleteAllByUserIDExcept() error = %v", err)
|
||||
}
|
||||
|
||||
devices, _, _ := repo.ListByUserID(ctx, 1, 0, 10)
|
||||
if len(devices) != 1 {
|
||||
t.Errorf("len(devices) = %d, want 1", len(devices))
|
||||
}
|
||||
if devices[0].ID != d1.ID {
|
||||
t.Error("应保留指定设备")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeviceRepository_GetTrustedDevices 测试获取信任设备列表
|
||||
func TestDeviceRepository_GetTrustedDevices(t *testing.T) {
|
||||
db := setupDeviceTestDB(t)
|
||||
repo := NewDeviceRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
trusted := &domain.Device{
|
||||
UserID: 1,
|
||||
DeviceID: "trusted-device",
|
||||
IsTrusted: true,
|
||||
Status: domain.DeviceStatusActive,
|
||||
}
|
||||
untrusted := &domain.Device{
|
||||
UserID: 1,
|
||||
DeviceID: "untrusted-device",
|
||||
IsTrusted: false,
|
||||
Status: domain.DeviceStatusActive,
|
||||
}
|
||||
repo.Create(ctx, trusted)
|
||||
repo.Create(ctx, untrusted)
|
||||
|
||||
devices, err := repo.GetTrustedDevices(ctx, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("GetTrustedDevices() error = %v", err)
|
||||
}
|
||||
if len(devices) != 1 {
|
||||
t.Errorf("len(devices) = %d, want 1", len(devices))
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeviceRepository_ListAll 测试带筛选条件的列表查询
|
||||
func TestDeviceRepository_ListAll(t *testing.T) {
|
||||
db := setupDeviceTestDB(t)
|
||||
repo := NewDeviceRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
repo.Create(ctx, &domain.Device{UserID: 1, DeviceID: "dev1", Status: domain.DeviceStatusActive})
|
||||
repo.Create(ctx, &domain.Device{UserID: 1, DeviceID: "dev2", Status: domain.DeviceStatusInactive})
|
||||
repo.Create(ctx, &domain.Device{UserID: 2, DeviceID: "dev3", Status: domain.DeviceStatusActive})
|
||||
|
||||
// 按用户筛选
|
||||
params := &ListDevicesParams{UserID: 1, Offset: 0, Limit: 10}
|
||||
_, total, err := repo.ListAll(ctx, params)
|
||||
if err != nil {
|
||||
t.Fatalf("ListAll() error = %v", err)
|
||||
}
|
||||
if total != 2 {
|
||||
t.Errorf("total = %d, want 2", total)
|
||||
}
|
||||
|
||||
// 按状态筛选
|
||||
status := domain.DeviceStatusActive
|
||||
params2 := &ListDevicesParams{Status: &status, Offset: 0, Limit: 10}
|
||||
_, total2, err := repo.ListAll(ctx, params2)
|
||||
if err != nil {
|
||||
t.Fatalf("ListAll() error = %v", err)
|
||||
}
|
||||
if total2 != 2 {
|
||||
t.Errorf("total = %d, want 2", total2)
|
||||
}
|
||||
}
|
||||
|
||||
// createDevice 辅助函数:创建设备
|
||||
func createDevice(t *testing.T, repo *DeviceRepository, ctx context.Context, userID int64, deviceID string) (*domain.Device, error) {
|
||||
d := &domain.Device{
|
||||
UserID: userID,
|
||||
DeviceID: deviceID,
|
||||
Status: domain.DeviceStatusActive,
|
||||
}
|
||||
err := repo.Create(ctx, d)
|
||||
if err != nil {
|
||||
t.Fatalf("createDevice() error = %v", err)
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
Reference in New Issue
Block a user