package repository

import (
	"context"
	"testing"
	"time"

	"github.com/user-management-system/internal/domain"
	"gorm.io/gorm"
)

// migrateRepositoryTables runs GORM AutoMigrate for the given models and
// aborts the calling test if the migration fails.
func migrateRepositoryTables(t *testing.T, db *gorm.DB, tables ...any) {
	t.Helper()
	if err := db.AutoMigrate(tables...); err != nil {
		t.Fatalf("migrate repository tables failed: %v", err)
	}
}

// int64Ptr returns a pointer to a copy of v; used to populate optional
// *int64 fields (such as UserID) in test fixtures.
func int64Ptr(v int64) *int64 {
	return &v
}

// TestDeviceRepositoryLifecycleAndQueries exercises the device repository end
// to end: creation, lookups, per-user/per-status listing, status and
// last-active updates, and deletion by user and by ID.
func TestDeviceRepositoryLifecycleAndQueries(t *testing.T) {
	db := openTestDB(t)
	migrateRepositoryTables(t, db, &domain.Device{})
	repo := NewDeviceRepository(db)
	ctx := context.Background()
	now := time.Now().UTC()

	// Two devices for user 1 (one active, one inactive) and one long-idle
	// device (40 days) for user 2.
	devices := []*domain.Device{
		{
			UserID:         1,
			DeviceID:       "device-alpha",
			DeviceName:     "Alpha",
			DeviceType:     domain.DeviceTypeDesktop,
			DeviceOS:       "Windows",
			DeviceBrowser:  "Chrome",
			IP:             "10.0.0.1",
			Location:       "Shanghai",
			Status:         domain.DeviceStatusActive,
			LastActiveTime: now.Add(-1 * time.Hour),
		},
		{
			UserID:         1,
			DeviceID:       "device-beta",
			DeviceName:     "Beta",
			DeviceType:     domain.DeviceTypeWeb,
			DeviceOS:       "macOS",
			DeviceBrowser:  "Safari",
			IP:             "10.0.0.2",
			Location:       "Hangzhou",
			Status:         domain.DeviceStatusInactive,
			LastActiveTime: now.Add(-2 * time.Hour),
		},
		{
			UserID:         2,
			DeviceID:       "device-gamma",
			DeviceName:     "Gamma",
			DeviceType:     domain.DeviceTypeMobile,
			DeviceOS:       "Android",
			DeviceBrowser:  "WebView",
			IP:             "10.0.0.3",
			Location:       "Beijing",
			Status:         domain.DeviceStatusActive,
			LastActiveTime: now.Add(-40 * 24 * time.Hour),
		},
	}
	for _, device := range devices {
		if err := repo.Create(ctx, device); err != nil {
			t.Fatalf("Create(%s) failed: %v", device.DeviceID, err)
		}
	}

	allDevices, total, err := repo.List(ctx, 0, 10)
	if err != nil {
		t.Fatalf("List failed: %v", err)
	}
	if total != 3 || len(allDevices) != 3 {
		t.Fatalf("expected 3 devices, got total=%d len=%d", total, len(allDevices))
	}

	loadedByDeviceID, err := repo.GetByDeviceID(ctx, 1, "device-beta")
	if err != nil {
		t.Fatalf("GetByDeviceID failed: %v", err)
	}
	if loadedByDeviceID.DeviceName != "Beta" {
		t.Fatalf("expected device name Beta, got %q", loadedByDeviceID.DeviceName)
	}

	exists, err := repo.Exists(ctx, 1, "device-alpha")
	if err != nil {
		t.Fatalf("Exists(device-alpha) failed: %v", err)
	}
	if !exists {
		t.Fatal("expected device-alpha to exist")
	}

	missing, err := repo.Exists(ctx, 1, "missing-device")
	if err != nil {
		t.Fatalf("Exists(missing-device) failed: %v", err)
	}
	if missing {
		t.Fatal("expected missing-device to be absent")
	}

	userDevices, total, err := repo.ListByUserID(ctx, 1, 0, 10)
	if err != nil {
		t.Fatalf("ListByUserID failed: %v", err)
	}
	if total != 2 || len(userDevices) != 2 {
		t.Fatalf("expected 2 devices for user 1, got total=%d len=%d", total, len(userDevices))
	}
	if userDevices[0].DeviceID != "device-alpha" {
		t.Fatalf("expected latest active device first, got %q", userDevices[0].DeviceID)
	}

	activeDevices, total, err := repo.ListByStatus(ctx, domain.DeviceStatusActive, 0, 10)
	if err != nil {
		t.Fatalf("ListByStatus failed: %v", err)
	}
	if total != 2 || len(activeDevices) != 2 {
		t.Fatalf("expected 2 active devices, got total=%d len=%d", total, len(activeDevices))
	}

	if err := repo.UpdateStatus(ctx, devices[1].ID, domain.DeviceStatusActive); err != nil {
		t.Fatalf("UpdateStatus failed: %v", err)
	}

	beforeTouch, err := repo.GetByID(ctx, devices[1].ID)
	if err != nil {
		t.Fatalf("GetByID before UpdateLastActiveTime failed: %v", err)
	}
	// NOTE(review): whether UpdateStatus also refreshes last_active_time is
	// not visible here; the short sleep keeps the strict "moved forward"
	// comparison below reliable even on coarse wall clocks.
	time.Sleep(10 * time.Millisecond)
	if err := repo.UpdateLastActiveTime(ctx, devices[1].ID); err != nil {
		t.Fatalf("UpdateLastActiveTime failed: %v", err)
	}
	afterTouch, err := repo.GetByID(ctx, devices[1].ID)
	if err != nil {
		t.Fatalf("GetByID after UpdateLastActiveTime failed: %v", err)
	}
	if !afterTouch.LastActiveTime.After(beforeTouch.LastActiveTime) {
		t.Fatal("expected last_active_time to move forward")
	}

	recentDevices, err := repo.GetActiveDevices(ctx, 1)
	if err != nil {
		t.Fatalf("GetActiveDevices failed: %v", err)
	}
	if len(recentDevices) != 2 {
		t.Fatalf("expected 2 recent devices for user 1, got %d", len(recentDevices))
	}

	if err := repo.DeleteByUserID(ctx, 1); err != nil {
		t.Fatalf("DeleteByUserID failed: %v", err)
	}
	remainingDevices, remainingTotal, err := repo.List(ctx, 0, 10)
	if err != nil {
		t.Fatalf("List after DeleteByUserID failed: %v", err)
	}
	if remainingTotal != 1 || len(remainingDevices) != 1 {
		t.Fatalf("expected 1 remaining device, got total=%d len=%d", remainingTotal, len(remainingDevices))
	}

	if err := repo.Delete(ctx, devices[2].ID); err != nil {
		t.Fatalf("Delete failed: %v", err)
	}
	if _, err := repo.GetByID(ctx, devices[2].ID); err == nil {
		t.Fatal("expected deleted device lookup to fail")
	}
}

// TestLoginLogRepositoryQueriesAndRetention covers login-log lookup, listing
// (by user, status and time range), success/failure counting, age-based
// retention and per-user deletion.
func TestLoginLogRepositoryQueriesAndRetention(t *testing.T) {
	db := openTestDB(t)
	migrateRepositoryTables(t, db, &domain.LoginLog{})
	repo := NewLoginLogRepository(db)
	ctx := context.Background()
	now := time.Now().UTC()

	// Two recent logs for user 1 (one success, one failure) and one 45-day-old
	// log for user 2 that the 30-day retention below should remove.
	logs := []*domain.LoginLog{
		{
			UserID:    int64Ptr(1),
			LoginType: int(domain.LoginTypePassword),
			DeviceID:  "device-alpha",
			IP:        "10.0.0.1",
			Location:  "Shanghai",
			Status:    1,
			CreatedAt: now.Add(-1 * time.Hour),
		},
		{
			UserID:     int64Ptr(1),
			LoginType:  int(domain.LoginTypeSMSCode),
			DeviceID:   "device-beta",
			IP:         "10.0.0.2",
			Location:   "Hangzhou",
			Status:     0,
			FailReason: "code expired",
			CreatedAt:  now.Add(-30 * time.Minute),
		},
		{
			UserID:    int64Ptr(2),
			LoginType: int(domain.LoginTypeOAuth),
			DeviceID:  "device-gamma",
			IP:        "10.0.0.3",
			Location:  "Beijing",
			Status:    1,
			CreatedAt: now.Add(-45 * 24 * time.Hour),
		},
	}
	for _, log := range logs {
		if err := repo.Create(ctx, log); err != nil {
			t.Fatalf("Create login log failed: %v", err)
		}
	}

	loaded, err := repo.GetByID(ctx, logs[0].ID)
	if err != nil {
		t.Fatalf("GetByID failed: %v", err)
	}
	if loaded.DeviceID != "device-alpha" {
		t.Fatalf("expected device-alpha, got %q", loaded.DeviceID)
	}

	userLogs, total, err := repo.ListByUserID(ctx, 1, 0, 10)
	if err != nil {
		t.Fatalf("ListByUserID failed: %v", err)
	}
	if total != 2 || len(userLogs) != 2 {
		t.Fatalf("expected 2 user logs, got total=%d len=%d", total, len(userLogs))
	}
	if userLogs[0].DeviceID != "device-beta" {
		t.Fatalf("expected newest login log first, got %q", userLogs[0].DeviceID)
	}

	allLogs, total, err := repo.List(ctx, 0, 10)
	if err != nil {
		t.Fatalf("List failed: %v", err)
	}
	if total != 3 || len(allLogs) != 3 {
		t.Fatalf("expected 3 total logs, got total=%d len=%d", total, len(allLogs))
	}

	successLogs, total, err := repo.ListByStatus(ctx, 1, 0, 10)
	if err != nil {
		t.Fatalf("ListByStatus failed: %v", err)
	}
	if total != 2 || len(successLogs) != 2 {
		t.Fatalf("expected 2 success logs, got total=%d len=%d", total, len(successLogs))
	}

	recentLogs, total, err := repo.ListByTimeRange(ctx, now.Add(-2*time.Hour), now, 0, 10)
	if err != nil {
		t.Fatalf("ListByTimeRange failed: %v", err)
	}
	if total != 2 || len(recentLogs) != 2 {
		t.Fatalf("expected 2 recent logs, got total=%d len=%d", total, len(recentLogs))
	}

	if count := repo.CountByResultSince(ctx, true, now.Add(-2*time.Hour)); count != 1 {
		t.Fatalf("expected 1 recent success login, got %d", count)
	}
	if count := repo.CountByResultSince(ctx, false, now.Add(-2*time.Hour)); count != 1 {
		t.Fatalf("expected 1 recent failed login, got %d", count)
	}

	if err := repo.DeleteOlderThan(ctx, 30); err != nil {
		t.Fatalf("DeleteOlderThan failed: %v", err)
	}
	retainedLogs, retainedTotal, err := repo.List(ctx, 0, 10)
	if err != nil {
		t.Fatalf("List after DeleteOlderThan failed: %v", err)
	}
	if retainedTotal != 2 || len(retainedLogs) != 2 {
		t.Fatalf("expected 2 retained logs, got total=%d len=%d", retainedTotal, len(retainedLogs))
	}

	if err := repo.DeleteByUserID(ctx, 1); err != nil {
		t.Fatalf("DeleteByUserID failed: %v", err)
	}
	finalLogs, finalTotal, err := repo.List(ctx, 0, 10)
	if err != nil {
		t.Fatalf("List after DeleteByUserID failed: %v", err)
	}
	if finalTotal != 0 || len(finalLogs) != 0 {
		t.Fatalf("expected all logs removed, got total=%d len=%d", finalTotal, len(finalLogs))
	}
}

// TestPasswordHistoryRepositoryKeepsNewestRecords verifies that password
// history is returned newest-first, that DeleteOldRecords trims a user's
// history down to the requested count, and that trimming one user does not
// touch another user's records.
func TestPasswordHistoryRepositoryKeepsNewestRecords(t *testing.T) {
	db := openTestDB(t)
	migrateRepositoryTables(t, db, &domain.PasswordHistory{})
	repo := NewPasswordHistoryRepository(db)
	ctx := context.Background()
	now := time.Now().UTC()

	histories := []*domain.PasswordHistory{
		{UserID: 1, PasswordHash: "hash-1", CreatedAt: now.Add(-4 * time.Hour)},
		{UserID: 1, PasswordHash: "hash-2", CreatedAt: now.Add(-3 * time.Hour)},
		{UserID: 1, PasswordHash: "hash-3", CreatedAt: now.Add(-2 * time.Hour)},
		{UserID: 1, PasswordHash: "hash-4", CreatedAt: now.Add(-1 * time.Hour)},
		// A record for another user, used to check that trimming is per-user.
		{UserID: 2, PasswordHash: "hash-foreign", CreatedAt: now.Add(-30 * time.Minute)},
	}
	for _, history := range histories {
		if err := repo.Create(ctx, history); err != nil {
			t.Fatalf("Create password history failed: %v", err)
		}
	}

	latestTwo, err := repo.GetByUserID(ctx, 1, 2)
	if err != nil {
		t.Fatalf("GetByUserID(limit=2) failed: %v", err)
	}
	if len(latestTwo) != 2 {
		t.Fatalf("expected 2 latest password histories, got %d", len(latestTwo))
	}
	if latestTwo[0].PasswordHash != "hash-4" || latestTwo[1].PasswordHash != "hash-3" {
		t.Fatalf("expected newest password hashes to be retained, got %q and %q", latestTwo[0].PasswordHash, latestTwo[1].PasswordHash)
	}

	if err := repo.DeleteOldRecords(ctx, 1, 2); err != nil {
		t.Fatalf("DeleteOldRecords failed: %v", err)
	}
	remainingHistories, err := repo.GetByUserID(ctx, 1, 10)
	if err != nil {
		t.Fatalf("GetByUserID after DeleteOldRecords failed: %v", err)
	}
	if len(remainingHistories) != 2 {
		t.Fatalf("expected 2 remaining histories, got %d", len(remainingHistories))
	}
	if remainingHistories[0].PasswordHash != "hash-4" || remainingHistories[1].PasswordHash != "hash-3" {
		t.Fatalf("unexpected remaining password hashes: %q and %q", remainingHistories[0].PasswordHash, remainingHistories[1].PasswordHash)
	}

	// Trimming a user with no history must be a harmless no-op.
	if err := repo.DeleteOldRecords(ctx, 999, 3); err != nil {
		t.Fatalf("DeleteOldRecords for missing user failed: %v", err)
	}

	// Neither DeleteOldRecords call may have removed user 2's record.
	foreignHistories, err := repo.GetByUserID(ctx, 2, 10)
	if err != nil {
		t.Fatalf("GetByUserID for user 2 failed: %v", err)
	}
	if len(foreignHistories) != 1 {
		t.Fatalf("expected user 2 password history to be untouched, got %d records", len(foreignHistories))
	}
	if foreignHistories[0].PasswordHash != "hash-foreign" {
		t.Fatalf("expected user 2 password hash %q, got %q", "hash-foreign", foreignHistories[0].PasswordHash)
	}
}

// TestOperationLogRepositorySearchAndRetention covers operation-log lookup,
// listing (by user, method and time range), keyword search and age-based
// retention.
func TestOperationLogRepositorySearchAndRetention(t *testing.T) {
	db := openTestDB(t)
	migrateRepositoryTables(t, db, &domain.OperationLog{})
	repo := NewOperationLogRepository(db)
	ctx := context.Background()
	now := time.Now().UTC()

	// Two recent logs for user 1 and one 40-day-old log for user 2 that the
	// 30-day retention below should remove.
	logs := []*domain.OperationLog{
		{
			UserID:         int64Ptr(1),
			OperationType:  "user",
			OperationName:  "create user",
			RequestMethod:  "POST",
			RequestPath:    "/api/v1/users",
			RequestParams:  `{"username":"alice"}`,
			ResponseStatus: 201,
			IP:             "10.0.0.1",
			UserAgent:      "Chrome",
			CreatedAt:      now.Add(-20 * time.Minute),
		},
		{
			UserID:         int64Ptr(1),
			OperationType:  "dashboard",
			OperationName:  "view dashboard",
			RequestMethod:  "GET",
			RequestPath:    "/dashboard",
			RequestParams:  "{}",
			ResponseStatus: 200,
			IP:             "10.0.0.2",
			UserAgent:      "Chrome",
			CreatedAt:      now.Add(-10 * time.Minute),
		},
		{
			UserID:         int64Ptr(2),
			OperationType:  "user",
			OperationName:  "delete user",
			RequestMethod:  "DELETE",
			RequestPath:    "/api/v1/users/7",
			RequestParams:  "{}",
			ResponseStatus: 204,
			IP:             "10.0.0.3",
			UserAgent:      "Firefox",
			CreatedAt:      now.Add(-40 * 24 * time.Hour),
		},
	}
	for _, log := range logs {
		if err := repo.Create(ctx, log); err != nil {
			t.Fatalf("Create operation log failed: %v", err)
		}
	}

	loaded, err := repo.GetByID(ctx, logs[0].ID)
	if err != nil {
		t.Fatalf("GetByID failed: %v", err)
	}
	if loaded.OperationName != "create user" {
		t.Fatalf("expected create user log, got %q", loaded.OperationName)
	}

	userLogs, total, err := repo.ListByUserID(ctx, 1, 0, 10)
	if err != nil {
		t.Fatalf("ListByUserID failed: %v", err)
	}
	if total != 2 || len(userLogs) != 2 {
		t.Fatalf("expected 2 user operation logs, got total=%d len=%d", total, len(userLogs))
	}
	if userLogs[0].OperationName != "view dashboard" {
		t.Fatalf("expected newest operation log first, got %q", userLogs[0].OperationName)
	}

	allLogs, total, err := repo.List(ctx, 0, 10)
	if err != nil {
		t.Fatalf("List failed: %v", err)
	}
	if total != 3 || len(allLogs) != 3 {
		t.Fatalf("expected 3 total operation logs, got total=%d len=%d", total, len(allLogs))
	}

	postLogs, total, err := repo.ListByMethod(ctx, "POST", 0, 10)
	if err != nil {
		t.Fatalf("ListByMethod failed: %v", err)
	}
	if total != 1 || len(postLogs) != 1 {
		t.Fatalf("expected a single POST operation log, got total=%d len=%d", total, len(postLogs))
	}
	// Checked separately so a wrong record is reported by name rather than
	// as a misleading "total=1 len=1" count failure.
	if postLogs[0].OperationName != "create user" {
		t.Fatalf("expected POST operation log %q, got %q", "create user", postLogs[0].OperationName)
	}

	recentLogs, total, err := repo.ListByTimeRange(ctx, now.Add(-1*time.Hour), now, 0, 10)
	if err != nil {
		t.Fatalf("ListByTimeRange failed: %v", err)
	}
	if total != 2 || len(recentLogs) != 2 {
		t.Fatalf("expected 2 recent operation logs, got total=%d len=%d", total, len(recentLogs))
	}

	searchResults, total, err := repo.Search(ctx, "user", 0, 10)
	if err != nil {
		t.Fatalf("Search failed: %v", err)
	}
	if total != 2 || len(searchResults) != 2 {
		t.Fatalf("expected 2 operation logs matching user, got total=%d len=%d", total, len(searchResults))
	}

	if err := repo.DeleteOlderThan(ctx, 30); err != nil {
		t.Fatalf("DeleteOlderThan failed: %v", err)
	}
	retainedLogs, retainedTotal, err := repo.List(ctx, 0, 10)
	if err != nil {
		t.Fatalf("List after DeleteOlderThan failed: %v", err)
	}
	if retainedTotal != 2 || len(retainedLogs) != 2 {
		t.Fatalf("expected 2 retained operation logs, got total=%d len=%d", retainedTotal, len(retainedLogs))
	}
}