Files
user-system/internal/service/device_service_test.go
long-agent 2a18a6fb47 fix(n+1): 批量查询替代循环单查
- IsAdminBootstrapRequired: userRepo.GetByID 循环 → GetByIDs 批量
- AssignRoles: roleRepo.GetByID 循环 → GetByIDs 批量
- 在 userRepositoryInterface 补充 GetByIDs 方法签名
2026-05-08 08:05:26 +08:00

601 lines
16 KiB
Go

package service_test
import (
"context"
"testing"
"time"
"github.com/user-management-system/internal/domain"
apierrors "github.com/user-management-system/internal/pkg/errors"
"github.com/user-management-system/internal/repository"
"github.com/user-management-system/internal/service"
gormsqlite "gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// =============================================================================
// Device Service Tests
// =============================================================================
// setupDeviceTestEnv builds a DeviceService backed by a fresh in-memory SQLite
// database and seeds one active user named "deviceuser".
//
// Each calling test gets its own database. The DSN is derived from t.Name(),
// so the shared-cache in-memory databases of different tests never alias each
// other. (A single fixed name let rows leak between tests, because the gorm
// connection pool keeps the shared database alive.) The underlying *sql.DB is
// closed on test cleanup, which drops the in-memory database.
//
// Tests in this file address the seeded user as ID 1, so the helper fails fast
// if that assumption does not hold.
func setupDeviceTestEnv(t *testing.T) (*service.DeviceService, *gorm.DB) {
	t.Helper()

	// cache=shared lets every pooled connection see the same in-memory DB.
	// The per-test name keeps that DB private to this test. This helper is
	// called from top-level tests, whose names are plain Go identifiers, so
	// they are safe inside the URI.
	dsn := "file:" + t.Name() + "?mode=memory&cache=shared"
	db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
		DriverName: "sqlite",
		DSN:        dsn,
	}), &gorm.Config{
		Logger: logger.Default.LogMode(logger.Silent),
	})
	if err != nil {
		t.Fatalf("failed to connect database: %v", err)
	}
	sqlDB, err := db.DB()
	if err != nil {
		t.Fatalf("failed to get underlying sql.DB: %v", err)
	}
	// Closing the last connection discards the in-memory database.
	t.Cleanup(func() { _ = sqlDB.Close() })

	if err := db.AutoMigrate(&domain.User{}, &domain.Device{}); err != nil {
		t.Fatalf("failed to migrate: %v", err)
	}

	// Seed the user that the tests refer to by the literal ID 1.
	user := &domain.User{Username: "deviceuser", Status: domain.UserStatusActive}
	if err := db.Create(user).Error; err != nil {
		t.Fatalf("failed to create test user: %v", err)
	}
	if user.ID != 1 {
		t.Fatalf("expected seeded test user to have ID 1, got %d", user.ID)
	}

	deviceRepo := repository.NewDeviceRepository(db)
	userRepo := repository.NewUserRepository(db)
	deviceSvc := service.NewDeviceService(deviceRepo, userRepo)
	return deviceSvc, db
}
// TestDeviceService_CreateDevice covers device registration: the happy path,
// rejection for an unknown user, and re-registration of an existing device ID.
func TestDeviceService_CreateDevice(t *testing.T) {
	svc, _ := setupDeviceTestEnv(t)
	ctx := context.Background()

	t.Run("Create device success", func(t *testing.T) {
		req := &service.CreateDeviceRequest{
			DeviceID:      "device001",
			DeviceName:    "Test Device",
			DeviceType:    int(domain.DeviceTypeDesktop),
			DeviceOS:      "Windows",
			DeviceBrowser: "Chrome",
			IP:            "192.168.1.1",
			Location:      "Beijing",
		}
		device, err := svc.CreateDevice(ctx, 1, req)
		if err != nil {
			t.Fatalf("CreateDevice failed: %v", err)
		}
		if device.DeviceID != "device001" {
			t.Errorf("Expected device ID 'device001', got %s", device.DeviceID)
		}
	})

	t.Run("Create device for non-existent user", func(t *testing.T) {
		req := &service.CreateDeviceRequest{
			DeviceID: "device002",
		}
		_, err := svc.CreateDevice(ctx, 9999, req)
		if err == nil {
			t.Error("Expected error for non-existent user")
		}
	})

	t.Run("Create duplicate device updates last active time", func(t *testing.T) {
		// The first registration must succeed. Otherwise the second call is
		// not actually testing the duplicate path.
		firstReq := &service.CreateDeviceRequest{
			DeviceID:   "device003",
			DeviceName: "First",
		}
		if _, err := svc.CreateDevice(ctx, 1, firstReq); err != nil {
			t.Fatalf("CreateDevice (first registration) failed: %v", err)
		}

		// Register the same device ID again, with a different name.
		secondReq := &service.CreateDeviceRequest{
			DeviceID:   "device003",
			DeviceName: "Second",
		}
		device, err := svc.CreateDevice(ctx, 1, secondReq)
		if err != nil {
			t.Fatalf("CreateDevice failed: %v", err)
		}
		// The existing device is expected back with its first name. This is
		// only logged, not enforced. NOTE(review): confirm the intended
		// contract and tighten this to t.Errorf.
		if device.DeviceName != "First" {
			t.Logf("Device name: %s", device.DeviceName)
		}
	})
}
// TestDeviceService_UpdateDevice checks that UpdateDevice changes a device's
// fields and reports an error for an unknown device ID.
func TestDeviceService_UpdateDevice(t *testing.T) {
	svc, _ := setupDeviceTestEnv(t)
	ctx := context.Background()

	// Seed the device to update. Fail fast rather than dereference a nil
	// device later.
	device, err := svc.CreateDevice(ctx, 1, &service.CreateDeviceRequest{
		DeviceID:   "update_device",
		DeviceName: "Original",
	})
	if err != nil {
		t.Fatalf("CreateDevice failed: %v", err)
	}

	t.Run("Update device success", func(t *testing.T) {
		updateReq := &service.UpdateDeviceRequest{
			DeviceName: "Updated",
			DeviceOS:   "macOS",
		}
		updated, err := svc.UpdateDevice(ctx, device.ID, updateReq)
		if err != nil {
			t.Fatalf("UpdateDevice failed: %v", err)
		}
		if updated.DeviceName != "Updated" {
			t.Errorf("Expected name 'Updated', got %s", updated.DeviceName)
		}
	})

	t.Run("Update non-existent device", func(t *testing.T) {
		updateReq := &service.UpdateDeviceRequest{
			DeviceName: "NotExist",
		}
		_, err := svc.UpdateDevice(ctx, 9999, updateReq)
		if err == nil {
			t.Error("Expected error for non-existent device")
		}
	})
}
// TestDeviceService_GetDevice checks that GetDevice returns a device by its
// primary key.
func TestDeviceService_GetDevice(t *testing.T) {
	svc, _ := setupDeviceTestEnv(t)
	ctx := context.Background()

	device, err := svc.CreateDevice(ctx, 1, &service.CreateDeviceRequest{
		DeviceID: "get_device",
	})
	if err != nil {
		t.Fatalf("CreateDevice failed: %v", err)
	}

	t.Run("Get device success", func(t *testing.T) {
		got, err := svc.GetDevice(ctx, device.ID)
		if err != nil {
			t.Fatalf("GetDevice failed: %v", err)
		}
		if got.DeviceID != "get_device" {
			t.Errorf("Expected device ID 'get_device', got %s", got.DeviceID)
		}
	})
}
// TestDeviceService_DeviceOwnershipAuthorization checks that the *ForActor
// methods reject a non-admin actor operating on another user's device with a
// forbidden error, and that each rejected mutation leaves the device unchanged.
// The final subtest shows the same read succeeding when the admin flag is set.
func TestDeviceService_DeviceOwnershipAuthorization(t *testing.T) {
	svc, db := setupDeviceTestEnv(t)
	ctx := context.Background()

	owner := &domain.User{Username: "device_owner", Status: domain.UserStatusActive}
	if err := db.Create(owner).Error; err != nil {
		t.Fatalf("create owner failed: %v", err)
	}
	actor := &domain.User{Username: "device_actor", Status: domain.UserStatusActive}
	if err := db.Create(actor).Error; err != nil {
		t.Fatalf("create actor failed: %v", err)
	}

	device, err := svc.CreateDevice(ctx, owner.ID, &service.CreateDeviceRequest{
		DeviceID:   "ownership_device",
		DeviceName: "Owner Device",
	})
	if err != nil {
		t.Fatalf("CreateDevice failed: %v", err)
	}

	// Each case makes one non-admin call as the actor against the owner's
	// device, and that call must fail as forbidden. The optional verify hook
	// then re-reads the device to prove nothing changed.
	forbidden := []struct {
		name   string
		call   func() error
		verify func(t *testing.T)
	}{
		{
			name: "GetDeviceForActor forbids cross-user access",
			call: func() error {
				_, callErr := svc.GetDeviceForActor(ctx, actor.ID, device.ID, false)
				return callErr
			},
		},
		{
			name: "UpdateDeviceForActor forbids cross-user access",
			call: func() error {
				_, callErr := svc.UpdateDeviceForActor(ctx, actor.ID, device.ID, false, &service.UpdateDeviceRequest{
					DeviceName: "Hacked Name",
				})
				return callErr
			},
			verify: func(t *testing.T) {
				current, getErr := svc.GetDevice(ctx, device.ID)
				if getErr != nil {
					t.Fatalf("GetDevice failed: %v", getErr)
				}
				if current.DeviceName != "Owner Device" {
					t.Fatalf("expected device name to remain unchanged, got %q", current.DeviceName)
				}
			},
		},
		{
			name: "DeleteDeviceForActor forbids cross-user access",
			call: func() error {
				return svc.DeleteDeviceForActor(ctx, actor.ID, device.ID, false)
			},
			verify: func(t *testing.T) {
				if _, getErr := svc.GetDevice(ctx, device.ID); getErr != nil {
					t.Fatalf("expected device to remain after forbidden delete, got %v", getErr)
				}
			},
		},
		{
			name: "TrustDeviceForActor forbids cross-user access",
			call: func() error {
				return svc.TrustDeviceForActor(ctx, actor.ID, device.ID, false, time.Hour)
			},
			verify: func(t *testing.T) {
				current, getErr := svc.GetDevice(ctx, device.ID)
				if getErr != nil {
					t.Fatalf("GetDevice failed: %v", getErr)
				}
				if current.IsTrusted {
					t.Fatal("expected device to remain untrusted")
				}
			},
		},
		{
			name: "UpdateDeviceStatusForActor forbids cross-user access",
			call: func() error {
				return svc.UpdateDeviceStatusForActor(ctx, actor.ID, device.ID, false, domain.DeviceStatusInactive)
			},
			verify: func(t *testing.T) {
				current, getErr := svc.GetDevice(ctx, device.ID)
				if getErr != nil {
					t.Fatalf("GetDevice failed: %v", getErr)
				}
				if current.Status != domain.DeviceStatusActive {
					t.Fatalf("expected device to remain active, got %d", current.Status)
				}
			},
		},
	}

	// Subtests run sequentially (no t.Parallel), so capturing tc is safe.
	for _, tc := range forbidden {
		t.Run(tc.name, func(t *testing.T) {
			if callErr := tc.call(); !apierrors.IsForbidden(callErr) {
				t.Fatalf("expected forbidden error, got %v", callErr)
			}
			if tc.verify != nil {
				tc.verify(t)
			}
		})
	}

	t.Run("Admin can manage another users device", func(t *testing.T) {
		got, adminErr := svc.GetDeviceForActor(ctx, actor.ID, device.ID, true)
		if adminErr != nil {
			t.Fatalf("expected admin access, got %v", adminErr)
		}
		if got.ID != device.ID {
			t.Fatalf("expected device id %d, got %d", device.ID, got.ID)
		}
	})
}
// TestDeviceService_GetUserDevices lists the seeded user's devices, with explicit
// pagination and with zero values. The zero-value call is presumably where the
// service applies default paging (TODO confirm).
func TestDeviceService_GetUserDevices(t *testing.T) {
	svc, _ := setupDeviceTestEnv(t)
	ctx := context.Background()

	// Seed three devices with IDs "a", "b", "c".
	for i := 0; i < 3; i++ {
		deviceID := string(rune('a' + i))
		if _, err := svc.CreateDevice(ctx, 1, &service.CreateDeviceRequest{DeviceID: deviceID}); err != nil {
			t.Fatalf("CreateDevice(%q) failed: %v", deviceID, err)
		}
	}

	t.Run("Get user devices", func(t *testing.T) {
		devices, total, err := svc.GetUserDevices(ctx, 1, 1, 10)
		if err != nil {
			t.Fatalf("GetUserDevices failed: %v", err)
		}
		if total < 3 {
			t.Errorf("Expected total >= 3, got %d", total)
		}
		// With a page size of 10 and at least three devices, the first page
		// must contain at least three entries.
		if len(devices) < 3 {
			t.Errorf("Expected at least 3 devices on the first page, got %d", len(devices))
		}
	})

	t.Run("Get user devices with default pagination", func(t *testing.T) {
		_, _, err := svc.GetUserDevices(ctx, 1, 0, 0)
		if err != nil {
			t.Fatalf("GetUserDevices failed: %v", err)
		}
	})
}
// TestDeviceService_TrustDevice covers trusting a device, trusting an unknown
// device (which must error), and removing trust again.
func TestDeviceService_TrustDevice(t *testing.T) {
	svc, _ := setupDeviceTestEnv(t)
	ctx := context.Background()

	device, err := svc.CreateDevice(ctx, 1, &service.CreateDeviceRequest{
		DeviceID: "trust_device",
	})
	if err != nil {
		t.Fatalf("CreateDevice failed: %v", err)
	}

	t.Run("Trust device success", func(t *testing.T) {
		if err := svc.TrustDevice(ctx, device.ID, 24*time.Hour); err != nil {
			t.Fatalf("TrustDevice failed: %v", err)
		}
	})

	t.Run("Trust non-existent device", func(t *testing.T) {
		if err := svc.TrustDevice(ctx, 9999, time.Hour); err == nil {
			t.Error("Expected error for non-existent device")
		}
	})

	t.Run("Untrust device", func(t *testing.T) {
		if err := svc.UntrustDevice(ctx, device.ID); err != nil {
			t.Fatalf("UntrustDevice failed: %v", err)
		}
	})
}
// TestDeviceService_TrustDeviceByDeviceID checks trusting a device looked up by
// (user ID, client device ID), and that an unknown device ID errors.
func TestDeviceService_TrustDeviceByDeviceID(t *testing.T) {
	svc, _ := setupDeviceTestEnv(t)
	ctx := context.Background()

	// Without this device the success case below would fail for the wrong reason.
	if _, err := svc.CreateDevice(ctx, 1, &service.CreateDeviceRequest{DeviceID: "trust_by_id"}); err != nil {
		t.Fatalf("CreateDevice failed: %v", err)
	}

	t.Run("Trust device by device ID", func(t *testing.T) {
		if err := svc.TrustDeviceByDeviceID(ctx, 1, "trust_by_id", time.Hour); err != nil {
			t.Fatalf("TrustDeviceByDeviceID failed: %v", err)
		}
	})

	t.Run("Trust non-existent device by device ID", func(t *testing.T) {
		if err := svc.TrustDeviceByDeviceID(ctx, 1, "not_exist", time.Hour); err == nil {
			t.Error("Expected error for non-existent device")
		}
	})
}
// TestDeviceService_GetActiveDevices smoke-tests GetActiveDevices after seeding
// one device. An empty result is only logged, because the active-device
// criteria are not visible from this test.
func TestDeviceService_GetActiveDevices(t *testing.T) {
	svc, _ := setupDeviceTestEnv(t)
	ctx := context.Background()

	if _, err := svc.CreateDevice(ctx, 1, &service.CreateDeviceRequest{DeviceID: "active_device"}); err != nil {
		t.Fatalf("CreateDevice failed: %v", err)
	}

	t.Run("Get active devices", func(t *testing.T) {
		// NOTE(review): the meaning of the (1, 10) arguments is not visible
		// here. Presumably it is page/size or user/limit; confirm against the
		// service signature.
		devices, _, err := svc.GetActiveDevices(ctx, 1, 10)
		if err != nil {
			t.Fatalf("GetActiveDevices failed: %v", err)
		}
		if len(devices) == 0 {
			t.Log("No active devices")
		}
	})
}
// TestDeviceService_GetAllDevices exercises the paginated all-devices listing
// unfiltered, with a status filter, and with a trusted-flag filter.
func TestDeviceService_GetAllDevices(t *testing.T) {
	svc, _ := setupDeviceTestEnv(t)
	ctx := context.Background()

	if _, err := svc.CreateDevice(ctx, 1, &service.CreateDeviceRequest{DeviceID: "all_device"}); err != nil {
		t.Fatalf("CreateDevice failed: %v", err)
	}

	t.Run("Get all devices", func(t *testing.T) {
		listReq := &service.GetAllDevicesRequest{
			Page:     1,
			PageSize: 10,
		}
		_, total, err := svc.GetAllDevices(ctx, listReq)
		if err != nil {
			t.Fatalf("GetAllDevices failed: %v", err)
		}
		if total < 1 {
			t.Error("Expected at least 1 device")
		}
	})

	t.Run("Get all devices with status filter", func(t *testing.T) {
		status := int(domain.DeviceStatusActive)
		listReq := &service.GetAllDevicesRequest{
			Page:     1,
			PageSize: 10,
			Status:   &status,
		}
		if _, _, err := svc.GetAllDevices(ctx, listReq); err != nil {
			t.Fatalf("GetAllDevices failed: %v", err)
		}
	})

	t.Run("Get all devices with trusted filter", func(t *testing.T) {
		isTrusted := true
		listReq := &service.GetAllDevicesRequest{
			Page:      1,
			PageSize:  10,
			IsTrusted: &isTrusted,
		}
		if _, _, err := svc.GetAllDevices(ctx, listReq); err != nil {
			t.Fatalf("GetAllDevices failed: %v", err)
		}
	})
}
// TestDeviceService_DeleteDevice checks that deleting an existing device
// succeeds.
func TestDeviceService_DeleteDevice(t *testing.T) {
	svc, _ := setupDeviceTestEnv(t)
	ctx := context.Background()

	device, err := svc.CreateDevice(ctx, 1, &service.CreateDeviceRequest{
		DeviceID: "delete_device",
	})
	if err != nil {
		t.Fatalf("CreateDevice failed: %v", err)
	}

	t.Run("Delete device", func(t *testing.T) {
		if err := svc.DeleteDevice(ctx, device.ID); err != nil {
			t.Fatalf("DeleteDevice failed: %v", err)
		}
	})
}
// TestDeviceService_UpdateDeviceStatus checks that an existing device can be
// switched to the inactive status without error.
func TestDeviceService_UpdateDeviceStatus(t *testing.T) {
	svc, _ := setupDeviceTestEnv(t)
	ctx := context.Background()

	device, err := svc.CreateDevice(ctx, 1, &service.CreateDeviceRequest{
		DeviceID: "status_device",
	})
	if err != nil {
		t.Fatalf("CreateDevice failed: %v", err)
	}

	t.Run("Update device status", func(t *testing.T) {
		if err := svc.UpdateDeviceStatus(ctx, device.ID, domain.DeviceStatusInactive); err != nil {
			t.Fatalf("UpdateDeviceStatus failed: %v", err)
		}
	})
}
// TestDeviceService_GetTrustedDevices trusts a freshly created device for one
// hour, then lists user 1's trusted devices.
func TestDeviceService_GetTrustedDevices(t *testing.T) {
	svc, _ := setupDeviceTestEnv(t)
	ctx := context.Background()

	device, err := svc.CreateDevice(ctx, 1, &service.CreateDeviceRequest{
		DeviceID: "trusted_device",
	})
	if err != nil {
		t.Fatalf("CreateDevice failed: %v", err)
	}
	// If trusting fails, the listing below would be meaningless.
	if err := svc.TrustDevice(ctx, device.ID, time.Hour); err != nil {
		t.Fatalf("TrustDevice failed: %v", err)
	}

	t.Run("Get trusted devices", func(t *testing.T) {
		devices, err := svc.GetTrustedDevices(ctx, 1)
		if err != nil {
			t.Fatalf("GetTrustedDevices failed: %v", err)
		}
		if len(devices) == 0 {
			t.Log("No trusted devices")
		}
	})
}
// TestDeviceService_UpdateLastActiveTime checks that an existing device's
// last-active timestamp can be refreshed. It also exercises an unknown device
// ID, whose outcome is implementation-defined and therefore only logged.
func TestDeviceService_UpdateLastActiveTime(t *testing.T) {
	svc, _ := setupDeviceTestEnv(t)
	ctx := context.Background()

	device, err := svc.CreateDevice(ctx, 1, &service.CreateDeviceRequest{
		DeviceID: "last_active_device",
	})
	if err != nil {
		t.Fatalf("CreateDevice failed: %v", err)
	}

	t.Run("Update last active time", func(t *testing.T) {
		if err := svc.UpdateLastActiveTime(ctx, device.ID); err != nil {
			t.Fatalf("UpdateLastActiveTime failed: %v", err)
		}
	})

	t.Run("Update last active time for non-existent device", func(t *testing.T) {
		// May or may not return an error depending on the implementation,
		// so the result is reported rather than asserted.
		err := svc.UpdateLastActiveTime(ctx, 9999)
		t.Logf("UpdateLastActiveTime(9999) returned: %v", err)
	})
}
// TestDeviceService_LogoutAllOtherDevices creates three devices for user 1 and
// calls LogoutAllOtherDevices, keeping the first one.
func TestDeviceService_LogoutAllOtherDevices(t *testing.T) {
	svc, _ := setupDeviceTestEnv(t)
	ctx := context.Background()

	// Seed logout_device_a..c and remember the first one's primary key.
	var firstDeviceID int64
	for i := 0; i < 3; i++ {
		deviceID := "logout_device_" + string(rune('a'+i))
		device, err := svc.CreateDevice(ctx, 1, &service.CreateDeviceRequest{DeviceID: deviceID})
		if err != nil {
			t.Fatalf("CreateDevice(%q) failed: %v", deviceID, err)
		}
		if i == 0 {
			firstDeviceID = device.ID
		}
	}

	t.Run("Logout all other devices", func(t *testing.T) {
		// Whether an error is expected here is implementation-defined, so it
		// is only reported. NOTE(review): pin down the contract and assert it.
		err := svc.LogoutAllOtherDevices(ctx, 1, firstDeviceID)
		t.Logf("LogoutAllOtherDevices returned: %v", err)
	})
}
// TestDeviceService_GetAllDevicesCursor seeds five devices and requests the
// first cursor page, using an empty cursor and a size of 3.
func TestDeviceService_GetAllDevicesCursor(t *testing.T) {
	svc, _ := setupDeviceTestEnv(t)
	ctx := context.Background()

	for i := 0; i < 5; i++ {
		deviceID := "cursor_device_" + string(rune('a'+i))
		if _, err := svc.CreateDevice(ctx, 1, &service.CreateDeviceRequest{DeviceID: deviceID}); err != nil {
			t.Fatalf("CreateDevice(%q) failed: %v", deviceID, err)
		}
	}

	t.Run("Get all devices with cursor", func(t *testing.T) {
		req := &service.GetAllDevicesRequest{
			Cursor: "",
			Size:   3,
		}
		resp, err := svc.GetAllDevicesCursor(ctx, req)
		if err != nil {
			t.Fatalf("GetAllDevicesCursor failed: %v", err)
		}
		if resp == nil {
			t.Error("Expected response")
		}
	})
}
// TestDeviceService_GetDeviceByDeviceID checks lookup by (user ID, client device
// ID), and that an unknown device ID errors.
func TestDeviceService_GetDeviceByDeviceID(t *testing.T) {
	svc, _ := setupDeviceTestEnv(t)
	ctx := context.Background()

	if _, err := svc.CreateDevice(ctx, 1, &service.CreateDeviceRequest{DeviceID: "get_by_device_id"}); err != nil {
		t.Fatalf("CreateDevice failed: %v", err)
	}

	t.Run("Get device by device ID", func(t *testing.T) {
		device, err := svc.GetDeviceByDeviceID(ctx, 1, "get_by_device_id")
		if err != nil {
			t.Fatalf("GetDeviceByDeviceID failed: %v", err)
		}
		if device.DeviceID != "get_by_device_id" {
			t.Errorf("Expected device ID 'get_by_device_id', got %s", device.DeviceID)
		}
	})

	t.Run("Get non-existent device by device ID", func(t *testing.T) {
		if _, err := svc.GetDeviceByDeviceID(ctx, 1, "not_exist"); err == nil {
			t.Error("Expected error for non-existent device")
		}
	})
}
// =============================================================================
// Get Active Devices Extended Tests
// =============================================================================
// TestDeviceService_GetActiveDevices_Extended seeds five devices and checks that
// GetActiveDevices returns no more than the requested 3 results.
func TestDeviceService_GetActiveDevices_Extended(t *testing.T) {
	svc, _ := setupDeviceTestEnv(t)
	ctx := context.Background()

	t.Run("Get active devices with pagination", func(t *testing.T) {
		// Seed active_device_paged_0..4.
		for i := 0; i < 5; i++ {
			suffix := string(rune('0' + i))
			req := &service.CreateDeviceRequest{
				DeviceID:   "active_device_paged_" + suffix,
				DeviceName: "Device " + suffix,
			}
			if _, err := svc.CreateDevice(ctx, 1, req); err != nil {
				t.Fatalf("CreateDevice(%q) failed: %v", req.DeviceID, err)
			}
		}

		// NOTE(review): the last argument is treated as the page size/limit,
		// as the assertion below implies; confirm against the service signature.
		devices, _, err := svc.GetActiveDevices(ctx, 1, 3)
		if err != nil {
			t.Fatalf("GetActiveDevices failed: %v", err)
		}
		if len(devices) > 3 {
			t.Errorf("Expected at most 3 devices, got %d", len(devices))
		}
	})
}