Files
user-system/internal/repository/login_log_repository_test.go
long-agent 582ad7a069 test: add comprehensive test coverage and improve code quality
- Add new test files for auth, service, and handler modules
- Improve test organization and coverage
- Refactor code for better maintainability
- Add captcha, settings, stats, and theme handler tests
- Add auth module tests (CAS, OAuth, password, SSO, state)
- Add service layer tests for auth, export, permissions, roles
- All Go tests pass (exit code 0)
- All frontend tests pass (325 tests in 59 files)
2026-04-17 20:43:50 +08:00

208 lines
4.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package repository
import (
"context"
"fmt"
"sync/atomic"
"testing"
"time"
gormsqlite "gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
_ "modernc.org/sqlite"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/pagination"
)
// loginLogTestCounter is incremented once per opened test database so that
// each one gets a distinct DSN name.
var loginLogTestCounter int64

// openLoginLogTestDB opens a fresh in-memory SQLite database through the
// "sqlite" driver (registered by the blank modernc.org/sqlite import), runs
// AutoMigrate for domain.LoginLog, and registers a t.Cleanup that closes the
// underlying connection pool when the test finishes. Any failure aborts the
// calling test via t.Fatalf.
func openLoginLogTestDB(t *testing.T) *gorm.DB {
	t.Helper()

	id := atomic.AddInt64(&loginLogTestCounter, 1)
	// NOTE(review): with mode=memory&cache=private each new pooled connection
	// opens its own empty database; these tests rely on database/sql reusing a
	// single idle connection. Consider SetMaxOpenConns(1) if flakiness appears.
	dsn := fmt.Sprintf("file:loginlogtestdb%d?mode=memory&cache=private", id)
	db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
		DriverName: "sqlite",
		DSN:        dsn,
	}), &gorm.Config{
		Logger: logger.Default.LogMode(logger.Silent),
	})
	if err != nil {
		t.Fatalf("打开测试数据库失败: %v", err)
	}

	// Close the pool (and thereby drop the in-memory database) once the test
	// ends; previously every test leaked its connections.
	sqlDB, err := db.DB()
	if err != nil {
		t.Fatalf("获取底层数据库连接失败: %v", err)
	}
	t.Cleanup(func() {
		if err := sqlDB.Close(); err != nil {
			t.Errorf("关闭测试数据库失败: %v", err)
		}
	})

	if err := db.AutoMigrate(&domain.LoginLog{}); err != nil {
		t.Fatalf("数据库迁移失败: %v", err)
	}
	return db
}
// setupLoginLogTestDB returns a freshly migrated, per-test database for the
// LoginLog repository tests. It delegates to openLoginLogTestDB and is marked
// as a helper so that any fatal failure is reported at the calling test's line.
func setupLoginLogTestDB(t *testing.T) *gorm.DB {
	t.Helper()
	return openLoginLogTestDB(t)
}
// TestLoginLogRepository_ListCursor pages through five login logs three at a
// time. The first page must report more results; the cursor built from its
// last record must yield the remaining two, report no more results, and not
// repeat any record from the first page.
func TestLoginLogRepository_ListCursor(t *testing.T) {
	db := setupLoginLogTestDB(t)
	repo := NewLoginLogRepository(db)
	ctx := context.Background()
	now := time.Now()

	// Five logs, one minute apart, so every row has a distinct CreatedAt.
	for i := 0; i < 5; i++ {
		err := repo.Create(ctx, &domain.LoginLog{
			UserID:    int64Ptr(int64(i + 1)),
			LoginType: 1,
			IP:        fmt.Sprintf("192.168.1.%d", i),
			Status:    1,
			CreatedAt: now.Add(-time.Duration(i) * time.Minute),
		})
		if err != nil {
			t.Fatalf("Create() error = %v", err)
		}
	}

	// First page: the first 3 logs.
	logs, hasMore, err := repo.ListCursor(ctx, 3, nil)
	if err != nil {
		t.Fatalf("ListCursor() error = %v", err)
	}
	// Fatal, not Error: the cursor below indexes the last element.
	if len(logs) != 3 {
		t.Fatalf("len(logs) = %d, want 3", len(logs))
	}
	if !hasMore {
		t.Error("hasMore should be true when more logs exist")
	}

	// Continue from the last record of the first page.
	lastLog := logs[len(logs)-1]
	cursor := &pagination.Cursor{
		LastID:    lastLog.ID,
		LastValue: lastLog.CreatedAt,
	}
	logs2, hasMore2, err := repo.ListCursor(ctx, 3, cursor)
	if err != nil {
		t.Fatalf("ListCursor() error = %v", err)
	}
	if len(logs2) != 2 {
		t.Errorf("len(logs2) = %d, want 2", len(logs2))
	}
	if hasMore2 {
		t.Error("hasMore2 should be false")
	}

	// A cursor page must never repeat records from the previous page.
	for _, first := range logs {
		for _, second := range logs2 {
			if first.ID == second.ID {
				t.Errorf("log ID %v returned on both pages", first.ID)
			}
		}
	}
}
// TestLoginLogRepository_ListByUserIDCursor stores three logs for one user
// and one log for another. It checks that a cursor query for the first user
// returns exactly that user's three logs, with no further pages.
func TestLoginLogRepository_ListByUserIDCursor(t *testing.T) {
	db := setupLoginLogTestDB(t)
	repo := NewLoginLogRepository(db)
	ctx := context.Background()
	userID := int64(123)
	now := time.Now()

	for i := 0; i < 3; i++ {
		err := repo.Create(ctx, &domain.LoginLog{
			UserID:    int64Ptr(userID),
			LoginType: 1,
			IP:        fmt.Sprintf("192.168.1.%d", i),
			Status:    1,
			CreatedAt: now.Add(-time.Duration(i) * time.Minute),
		})
		if err != nil {
			t.Fatalf("Create() error = %v", err)
		}
	}
	// A log belonging to a different user, which must be filtered out.
	if err := repo.Create(ctx, &domain.LoginLog{
		UserID:    int64Ptr(999),
		LoginType: 1,
		IP:        "10.0.0.1",
		Status:    1,
	}); err != nil {
		t.Fatalf("Create() error = %v", err)
	}

	// Query only the target user's logs.
	logs, hasMore, err := repo.ListByUserIDCursor(ctx, userID, 10, nil)
	if err != nil {
		t.Fatalf("ListByUserIDCursor() error = %v", err)
	}
	if len(logs) != 3 {
		t.Errorf("len(logs) = %d, want 3", len(logs))
	}
	if hasMore {
		t.Error("hasMore should be false")
	}

	// The count alone would not catch a broken filter; verify ownership too.
	for _, l := range logs {
		if l.UserID == nil {
			t.Errorf("log %v has nil UserID, want %d", l.ID, userID)
			continue
		}
		if *l.UserID != userID {
			t.Errorf("log %v UserID = %d, want %d", l.ID, *l.UserID, userID)
		}
	}
}
// TestLoginLogRepository_ListAllForExport inserts one successful and one
// failed login. It checks that calling ListAllForExport with arguments
// (0, -1, nil, nil) returns both records.
func TestLoginLogRepository_ListAllForExport(t *testing.T) {
	db := setupLoginLogTestDB(t)
	repo := NewLoginLogRepository(db)
	ctx := context.Background()

	fixtures := []*domain.LoginLog{
		{
			UserID:    int64Ptr(1),
			LoginType: 1,
			IP:        "192.168.1.1",
			Status:    1,
		},
		{
			UserID:     int64Ptr(2),
			LoginType:  2,
			IP:         "192.168.1.2",
			Status:     0,
			FailReason: "invalid password",
		},
	}
	// Fail fast on insert errors instead of surfacing them as a count mismatch.
	for _, log := range fixtures {
		if err := repo.Create(ctx, log); err != nil {
			t.Fatalf("Create() error = %v", err)
		}
	}

	logs, err := repo.ListAllForExport(ctx, 0, -1, nil, nil)
	if err != nil {
		t.Fatalf("ListAllForExport() error = %v", err)
	}
	if len(logs) != 2 {
		t.Errorf("len(logs) = %d, want 2", len(logs))
	}
}
// TestLoginLogRepository_ListLogsForExportBatch exercises ID-cursor batching
// of the export query over five logs:
//   - a large starting cursor (999999) with batch size 3 returns 3 logs and
//     reports more results;
//   - continuing from the last returned ID yields the remaining 2 logs;
//   - filtering by user 1 with batch size 10 returns all 5 logs.
func TestLoginLogRepository_ListLogsForExportBatch(t *testing.T) {
	db := setupLoginLogTestDB(t)
	repo := NewLoginLogRepository(db)
	ctx := context.Background()

	// Create several logs, all belonging to user 1.
	for i := 0; i < 5; i++ {
		err := repo.Create(ctx, &domain.LoginLog{
			UserID:    int64Ptr(1),
			LoginType: 1,
			IP:        fmt.Sprintf("192.168.1.%d", i),
			Status:    1,
		})
		if err != nil {
			t.Fatalf("Create() error = %v", err)
		}
	}

	// Batched export with cursor pagination; the first call uses a very large
	// cursor so that every record is eligible.
	logs, hasMore, err := repo.ListLogsForExportBatch(ctx, 0, -1, nil, nil, 999999, 3)
	if err != nil {
		t.Fatalf("ListLogsForExportBatch() error = %v", err)
	}
	// Fatal, not Error: the next step indexes the last element.
	if len(logs) != 3 {
		t.Fatalf("len(logs) = %d, want 3", len(logs))
	}
	if !hasMore {
		t.Error("hasMore should be true when more logs exist")
	}

	// Continue using the ID of the last record as the cursor.
	lastID := logs[len(logs)-1].ID
	logs2, hasMore2, err := repo.ListLogsForExportBatch(ctx, 0, -1, nil, nil, lastID, 3)
	if err != nil {
		t.Fatalf("ListLogsForExportBatch() error = %v", err)
	}
	if len(logs2) != 2 {
		t.Errorf("len(logs2) = %d, want 2", len(logs2))
	}
	if hasMore2 {
		t.Error("hasMore2 should be false")
	}

	// Filter by user ID.
	logs3, _, err := repo.ListLogsForExportBatch(ctx, 1, -1, nil, nil, 999999, 10)
	if err != nil {
		t.Fatalf("ListLogsForExportBatch() error = %v", err)
	}
	if len(logs3) != 5 {
		t.Errorf("len(logs3) = %d, want 5", len(logs3))
	}
}