feat: 系统全面优化 - 设备管理/登录日志导出/性能监控/设置页面
后端: - 新增全局设备管理 API(DeviceHandler.GetAllDevices) - 新增登录日志导出功能(LogHandler.ExportLoginLogs, CSV/XLSX) - 新增设置服务(SettingsService)和设置页面 API - 设备管理支持多条件筛选(状态/信任状态/关键词) - 登录日志支持流式导出防 OOM - 操作日志支持按方法/时间范围搜索 - 主题配置服务(ThemeService) - 增强监控健康检查(Prometheus metrics + SLO) - 移除旧 ratelimit.go(已迁移至 robustness) - 修复 SocialAccount NULL 扫描问题 - 新增 API 契约测试、Handler 测试、Settings 测试 前端: - 新增管理员设备管理页面(DevicesPage) - 新增管理员登录日志导出功能 - 新增系统设置页面(SettingsPage) - 设备管理支持筛选和分页 - 增强 HTTP 响应类型 测试: - 业务逻辑测试 68 个(含并发 CONC_001~003) - 规模测试 16 个(P99 百分位统计) - E2E 测试、集成测试、契约测试 - 性能基准测试、鲁棒性测试 全面测试通过(38 个测试包)
This commit is contained in:
@@ -480,7 +480,10 @@ func (s *AuthService) writeLoginLog(
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := s.loginLogRepo.Create(context.Background(), loginRecord); err != nil {
|
||||
// 使用带超时的独立 context,防止日志写入无限等待
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := s.loginLogRepo.Create(bgCtx, loginRecord); err != nil {
|
||||
log.Printf("auth: write login log failed, user_id=%v login_type=%d err=%v", userID, loginType, err)
|
||||
}
|
||||
}()
|
||||
@@ -548,6 +551,11 @@ func (s *AuthService) bestEffortRegisterDevice(ctx context.Context, userID int64
|
||||
_, _ = s.deviceService.CreateDevice(ctx, userID, createReq)
|
||||
}
|
||||
|
||||
// BestEffortRegisterDevicePublic 供外部 handler(如 SMS 登录)调用,安静地注册设备
|
||||
func (s *AuthService) BestEffortRegisterDevicePublic(ctx context.Context, userID int64, req *LoginRequest) {
|
||||
s.bestEffortRegisterDevice(ctx, userID, req)
|
||||
}
|
||||
|
||||
func (s *AuthService) cacheUserInfo(ctx context.Context, user *domain.User) {
|
||||
if s == nil || s.cache == nil || user == nil {
|
||||
return
|
||||
@@ -757,7 +765,8 @@ func (s *AuthService) RefreshToken(ctx context.Context, refreshToken string) (*L
|
||||
return nil, errors.New("auth service is not fully configured")
|
||||
}
|
||||
|
||||
claims, err := s.jwtManager.ValidateRefreshToken(strings.TrimSpace(refreshToken))
|
||||
refreshToken = strings.TrimSpace(refreshToken)
|
||||
claims, err := s.jwtManager.ValidateRefreshToken(refreshToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -773,6 +782,18 @@ func (s *AuthService) RefreshToken(ctx context.Context, refreshToken string) (*L
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Token Rotation: 使旧的 refresh token 失效,防止无限刷新
|
||||
if s.cache != nil {
|
||||
blacklistKey := tokenBlacklistPrefix + claims.JTI
|
||||
// TTL 设置为 refresh token 的剩余有效期
|
||||
if claims.ExpiresAt != nil {
|
||||
remaining := claims.ExpiresAt.Time.Sub(time.Now())
|
||||
if remaining > 0 {
|
||||
_ = s.cache.Set(ctx, blacklistKey, "1", 5*time.Minute, remaining)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return s.generateLoginResponse(ctx, user, claims.Remember)
|
||||
}
|
||||
|
||||
|
||||
535
internal/service/auth_service_test.go
Normal file
535
internal/service/auth_service_test.go
Normal file
@@ -0,0 +1,535 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// Auth Service Unit Tests
|
||||
// =============================================================================
|
||||
|
||||
func TestPasswordStrength(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
password string
|
||||
wantInfo PasswordStrengthInfo
|
||||
}{
|
||||
{
|
||||
name: "empty_password",
|
||||
password: "",
|
||||
wantInfo: PasswordStrengthInfo{Score: 0, Length: 0, HasUpper: false, HasLower: false, HasDigit: false, HasSpecial: false},
|
||||
},
|
||||
{
|
||||
name: "lowercase_only",
|
||||
password: "abcdefgh",
|
||||
wantInfo: PasswordStrengthInfo{Score: 1, Length: 8, HasUpper: false, HasLower: true, HasDigit: false, HasSpecial: false},
|
||||
},
|
||||
{
|
||||
name: "uppercase_only",
|
||||
password: "ABCDEFGH",
|
||||
wantInfo: PasswordStrengthInfo{Score: 1, Length: 8, HasUpper: true, HasLower: false, HasDigit: false, HasSpecial: false},
|
||||
},
|
||||
{
|
||||
name: "digits_only",
|
||||
password: "12345678",
|
||||
wantInfo: PasswordStrengthInfo{Score: 1, Length: 8, HasUpper: false, HasLower: false, HasDigit: true, HasSpecial: false},
|
||||
},
|
||||
{
|
||||
name: "mixed_case_with_digits",
|
||||
password: "Abcd1234",
|
||||
wantInfo: PasswordStrengthInfo{Score: 3, Length: 8, HasUpper: true, HasLower: true, HasDigit: true, HasSpecial: false},
|
||||
},
|
||||
{
|
||||
name: "mixed_with_special",
|
||||
password: "Abcd1234!",
|
||||
wantInfo: PasswordStrengthInfo{Score: 4, Length: 9, HasUpper: true, HasLower: true, HasDigit: true, HasSpecial: true},
|
||||
},
|
||||
{
|
||||
name: "chinese_characters",
|
||||
password: "密码123456",
|
||||
wantInfo: PasswordStrengthInfo{Score: 1, Length: 8, HasUpper: false, HasLower: false, HasDigit: true, HasSpecial: false},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
info := GetPasswordStrength(tt.password)
|
||||
if info.Score != tt.wantInfo.Score {
|
||||
t.Errorf("Score: got %d, want %d", info.Score, tt.wantInfo.Score)
|
||||
}
|
||||
if info.Length != tt.wantInfo.Length {
|
||||
t.Errorf("Length: got %d, want %d", info.Length, tt.wantInfo.Length)
|
||||
}
|
||||
if info.HasUpper != tt.wantInfo.HasUpper {
|
||||
t.Errorf("HasUpper: got %v, want %v", info.HasUpper, tt.wantInfo.HasUpper)
|
||||
}
|
||||
if info.HasLower != tt.wantInfo.HasLower {
|
||||
t.Errorf("HasLower: got %v, want %v", info.HasLower, tt.wantInfo.HasLower)
|
||||
}
|
||||
if info.HasDigit != tt.wantInfo.HasDigit {
|
||||
t.Errorf("HasDigit: got %v, want %v", info.HasDigit, tt.wantInfo.HasDigit)
|
||||
}
|
||||
if info.HasSpecial != tt.wantInfo.HasSpecial {
|
||||
t.Errorf("HasSpecial: got %v, want %v", info.HasSpecial, tt.wantInfo.HasSpecial)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidatePasswordStrength(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
password string
|
||||
minLength int
|
||||
strict bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid_password_strict",
|
||||
password: "Abcd1234!",
|
||||
minLength: 8,
|
||||
strict: true,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "too_short",
|
||||
password: "Ab1!",
|
||||
minLength: 8,
|
||||
strict: false,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "weak_password",
|
||||
password: "abcdefgh",
|
||||
minLength: 8,
|
||||
strict: false,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "strict_missing_uppercase",
|
||||
password: "abcd1234!",
|
||||
minLength: 8,
|
||||
strict: true,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "strict_missing_lowercase",
|
||||
password: "ABCD1234!",
|
||||
minLength: 8,
|
||||
strict: true,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "strict_missing_digit",
|
||||
password: "Abcdefgh!",
|
||||
minLength: 8,
|
||||
strict: true,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "valid_weak_password_non_strict",
|
||||
password: "Abcd1234",
|
||||
minLength: 8,
|
||||
strict: false,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validatePasswordStrength(tt.password, tt.minLength, tt.strict)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("validatePasswordStrength() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeUsername(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "normal_username",
|
||||
input: "john_doe",
|
||||
want: "john_doe",
|
||||
},
|
||||
{
|
||||
name: "username_with_spaces",
|
||||
input: "john doe",
|
||||
want: "john_doe",
|
||||
},
|
||||
{
|
||||
name: "username_with_uppercase",
|
||||
input: "JohnDoe",
|
||||
want: "johndoe",
|
||||
},
|
||||
{
|
||||
name: "username_with_special_chars",
|
||||
input: "john@doe",
|
||||
want: "johndoe",
|
||||
},
|
||||
{
|
||||
name: "empty_username",
|
||||
input: "",
|
||||
want: "user",
|
||||
},
|
||||
{
|
||||
name: "whitespace_only",
|
||||
input: " ",
|
||||
want: "user",
|
||||
},
|
||||
{
|
||||
name: "username_with_emoji",
|
||||
input: "john😀doe",
|
||||
want: "johndoe", // emoji is filtered out as it's not letter/digit/./-/_
|
||||
},
|
||||
{
|
||||
name: "username_with_leading_underscore",
|
||||
input: "_john_",
|
||||
want: "john", // leading and trailing _ are trimmed
|
||||
},
|
||||
{
|
||||
name: "username_with_trailing_dots",
|
||||
input: "john..doe...",
|
||||
want: "john..doe", // trailing dots trimmed
|
||||
},
|
||||
{
|
||||
name: "long_username_truncated",
|
||||
input: "this_is_a_very_long_username_that_exceeds_fifty_characters_limit",
|
||||
want: "this_is_a_very_long_username_that_exceeds_fifty_ch", // 50 chars max, cuts off "acters_limit"
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := sanitizeUsername(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("sanitizeUsername() = %q (len=%d), want %q (len=%d)", got, len(got), tt.want, len(tt.want))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsValidPhoneSimple(t *testing.T) {
|
||||
tests := []struct {
|
||||
phone string
|
||||
want bool
|
||||
}{
|
||||
{"13800138000", true},
|
||||
{"+8613800138000", true}, // Valid: +86 prefix with 11 digit mobile
|
||||
{"8613800138000", true}, // Valid: 86 prefix with 11 digit mobile
|
||||
{"1234567890", false},
|
||||
{"abcdefghij", false},
|
||||
{"", false},
|
||||
{"138001380001", false}, // 12 digits
|
||||
{"1380013800", false}, // 10 digits
|
||||
{"19800138000", true}, // 98 prefix
|
||||
// +[1-9]\d{6,14} allows international numbers like +16171234567
|
||||
{"+16171234567", true}, // 11 digits international, valid for \d{6,14}
|
||||
{"+112345678901", true}, // 11 digits international, valid for \d{6,14}
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.phone, func(t *testing.T) {
|
||||
got := isValidPhoneSimple(tt.phone)
|
||||
if got != tt.want {
|
||||
t.Errorf("isValidPhoneSimple(%q) = %v, want %v", tt.phone, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginRequestGetAccount(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req *LoginRequest
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "account_field",
|
||||
req: &LoginRequest{Account: "john", Username: "jane", Email: "jane@test.com"},
|
||||
want: "john",
|
||||
},
|
||||
{
|
||||
name: "username_field",
|
||||
req: &LoginRequest{Username: "jane", Email: "jane@test.com"},
|
||||
want: "jane",
|
||||
},
|
||||
{
|
||||
name: "email_field",
|
||||
req: &LoginRequest{Email: "jane@test.com"},
|
||||
want: "jane@test.com",
|
||||
},
|
||||
{
|
||||
name: "phone_field",
|
||||
req: &LoginRequest{Phone: "13800138000"},
|
||||
want: "13800138000",
|
||||
},
|
||||
{
|
||||
name: "all_fields_with_whitespace",
|
||||
req: &LoginRequest{Account: " john ", Username: " jane ", Email: " jane@test.com "},
|
||||
want: "john",
|
||||
},
|
||||
{
|
||||
name: "empty_request",
|
||||
req: &LoginRequest{},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "nil_request",
|
||||
req: nil,
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := tt.req.GetAccount()
|
||||
if got != tt.want {
|
||||
t.Errorf("GetAccount() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildDeviceFingerprint(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req *LoginRequest
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "full_device_info",
|
||||
req: &LoginRequest{
|
||||
DeviceID: "device123",
|
||||
DeviceName: "iPhone 15",
|
||||
DeviceBrowser: "Safari",
|
||||
DeviceOS: "iOS 17",
|
||||
},
|
||||
want: "device123|iPhone 15|Safari|iOS 17",
|
||||
},
|
||||
{
|
||||
name: "partial_device_info",
|
||||
req: &LoginRequest{
|
||||
DeviceID: "device123",
|
||||
DeviceName: "iPhone 15",
|
||||
},
|
||||
want: "device123|iPhone 15",
|
||||
},
|
||||
{
|
||||
name: "only_device_id",
|
||||
req: &LoginRequest{
|
||||
DeviceID: "device123",
|
||||
},
|
||||
want: "device123",
|
||||
},
|
||||
{
|
||||
name: "empty_device_info",
|
||||
req: &LoginRequest{},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "nil_request",
|
||||
req: nil,
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := buildDeviceFingerprint(tt.req)
|
||||
if got != tt.want {
|
||||
t.Errorf("buildDeviceFingerprint() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthServiceDefaultConfig(t *testing.T) {
|
||||
// Test that default configuration is applied correctly
|
||||
svc := NewAuthService(nil, nil, nil, nil, 0, 0, 0)
|
||||
|
||||
if svc == nil {
|
||||
t.Fatal("NewAuthService returned nil")
|
||||
}
|
||||
|
||||
// Check default password minimum length
|
||||
if svc.passwordMinLength != defaultPasswordMinLen {
|
||||
t.Errorf("passwordMinLength: got %d, want %d", svc.passwordMinLength, defaultPasswordMinLen)
|
||||
}
|
||||
|
||||
// Check default max login attempts
|
||||
if svc.maxLoginAttempts != 5 {
|
||||
t.Errorf("maxLoginAttempts: got %d, want %d", svc.maxLoginAttempts, 5)
|
||||
}
|
||||
|
||||
// Check default login lock duration
|
||||
if svc.loginLockDuration != 15*time.Minute {
|
||||
t.Errorf("loginLockDuration: got %v, want %v", svc.loginLockDuration, 15*time.Minute)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthServiceNilSafety(t *testing.T) {
|
||||
t.Run("validatePassword_nil_service", func(t *testing.T) {
|
||||
var svc *AuthService
|
||||
err := svc.validatePassword("Abcd1234!")
|
||||
if err != nil {
|
||||
t.Errorf("nil service should not error: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("accessTokenTTL_nil_service", func(t *testing.T) {
|
||||
var svc *AuthService
|
||||
ttl := svc.accessTokenTTLSeconds()
|
||||
if ttl != 0 {
|
||||
t.Errorf("nil service should return 0: got %d", ttl)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("RefreshTokenTTL_nil_service", func(t *testing.T) {
|
||||
var svc *AuthService
|
||||
ttl := svc.RefreshTokenTTLSeconds()
|
||||
if ttl != 0 {
|
||||
t.Errorf("nil service should return 0: got %d", ttl)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("generateUniqueUsername_nil_service", func(t *testing.T) {
|
||||
var svc *AuthService
|
||||
username, err := svc.generateUniqueUsername(context.Background(), "testuser")
|
||||
if err != nil {
|
||||
t.Errorf("nil service should return username: %v", err)
|
||||
}
|
||||
if username != "testuser" {
|
||||
t.Errorf("username: got %q, want %q", username, "testuser")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("buildUserInfo_nil_user", func(t *testing.T) {
|
||||
var svc *AuthService
|
||||
info := svc.buildUserInfo(nil)
|
||||
if info != nil {
|
||||
t.Errorf("nil user should return nil info: got %v", info)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ensureUserActive_nil_user", func(t *testing.T) {
|
||||
var svc *AuthService
|
||||
err := svc.ensureUserActive(nil)
|
||||
if err == nil {
|
||||
t.Error("nil user should return error")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("blacklistToken_nil_service", func(t *testing.T) {
|
||||
var svc *AuthService
|
||||
err := svc.blacklistTokenClaims(context.Background(), "token", nil)
|
||||
if err != nil {
|
||||
t.Errorf("nil service should not error: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Logout_nil_service", func(t *testing.T) {
|
||||
var svc *AuthService
|
||||
err := svc.Logout(context.Background(), "user", nil)
|
||||
if err != nil {
|
||||
t.Errorf("nil service should not error: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("IsTokenBlacklisted_nil_service", func(t *testing.T) {
|
||||
var svc *AuthService
|
||||
blacklisted := svc.IsTokenBlacklisted(context.Background(), "jti")
|
||||
if blacklisted {
|
||||
t.Error("nil service should not blacklist tokens")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestUserInfoFromCacheValue(t *testing.T) {
|
||||
t.Run("valid_UserInfo_pointer", func(t *testing.T) {
|
||||
info := &UserInfo{ID: 1, Username: "testuser"}
|
||||
got, ok := userInfoFromCacheValue(info)
|
||||
if !ok {
|
||||
t.Error("should parse *UserInfo")
|
||||
}
|
||||
if got.ID != 1 || got.Username != "testuser" {
|
||||
t.Errorf("got %+v, want %+v", got, info)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("valid_UserInfo_value", func(t *testing.T) {
|
||||
info := UserInfo{ID: 2, Username: "testuser2"}
|
||||
got, ok := userInfoFromCacheValue(info)
|
||||
if !ok {
|
||||
t.Error("should parse UserInfo value")
|
||||
}
|
||||
if got.ID != 2 || got.Username != "testuser2" {
|
||||
t.Errorf("got %+v, want %+v", got, info)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid_type", func(t *testing.T) {
|
||||
got, ok := userInfoFromCacheValue("invalid string")
|
||||
if ok || got != nil {
|
||||
t.Errorf("should not parse string: ok=%v, got=%+v", ok, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestEnsureUserActive(t *testing.T) {
|
||||
t.Run("nil_user", func(t *testing.T) {
|
||||
var svc *AuthService
|
||||
err := svc.ensureUserActive(nil)
|
||||
if err == nil {
|
||||
t.Error("nil user should error")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestAttemptCount(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
value interface{}
|
||||
want int
|
||||
}{
|
||||
{"int_value", 5, 5},
|
||||
{"int64_value", int64(3), 3},
|
||||
{"float64_value", float64(4.0), 4},
|
||||
{"string_int", "3", 0}, // strings are not converted
|
||||
{"invalid_type", "abc", 0},
|
||||
{"nil", nil, 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := attemptCount(tt.value)
|
||||
if got != tt.want {
|
||||
t.Errorf("attemptCount(%v) = %d, want %d", tt.value, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIncrementFailAttempts(t *testing.T) {
|
||||
t.Run("nil_service", func(t *testing.T) {
|
||||
var svc *AuthService
|
||||
count := svc.incrementFailAttempts(context.Background(), "key")
|
||||
if count != 0 {
|
||||
t.Errorf("nil service should return 0, got %d", count)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty_key", func(t *testing.T) {
|
||||
svc := NewAuthService(nil, nil, nil, nil, 8, 5, 15*time.Minute)
|
||||
count := svc.incrementFailAttempts(context.Background(), "")
|
||||
if count != 0 {
|
||||
t.Errorf("empty key should return 0, got %d", count)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -3,9 +3,11 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/pagination"
|
||||
"github.com/user-management-system/internal/repository"
|
||||
)
|
||||
|
||||
@@ -228,12 +230,14 @@ func (s *DeviceService) GetTrustedDevices(ctx context.Context, userID int64) ([]
|
||||
|
||||
// GetAllDevicesRequest 获取所有设备请求参数
|
||||
type GetAllDevicesRequest struct {
|
||||
Page int
|
||||
PageSize int
|
||||
Page int `form:"page"`
|
||||
PageSize int `form:"page_size"`
|
||||
UserID int64 `form:"user_id"`
|
||||
Status int `form:"status"`
|
||||
IsTrusted *bool `form:"is_trusted"`
|
||||
Status *int `form:"status"` // 0-禁用, 1-激活, nil-不筛选
|
||||
IsTrusted *bool `form:"is_trusted"`
|
||||
Keyword string `form:"keyword"`
|
||||
Cursor string `form:"cursor"` // Opaque cursor for keyset pagination
|
||||
Size int `form:"size"` // Page size when using cursor mode
|
||||
}
|
||||
|
||||
// GetAllDevices 获取所有设备(管理员用)
|
||||
@@ -257,9 +261,10 @@ func (s *DeviceService) GetAllDevices(ctx context.Context, req *GetAllDevicesReq
|
||||
Limit: req.PageSize,
|
||||
}
|
||||
|
||||
// 处理状态筛选
|
||||
if req.Status >= 0 {
|
||||
params.Status = domain.DeviceStatus(req.Status)
|
||||
// 处理状态筛选(仅当明确指定了状态时才筛选)
|
||||
if req.Status != nil && (*req.Status == 0 || *req.Status == 1) {
|
||||
status := domain.DeviceStatus(*req.Status)
|
||||
params.Status = &status
|
||||
}
|
||||
|
||||
// 处理信任状态筛选
|
||||
@@ -270,6 +275,49 @@ func (s *DeviceService) GetAllDevices(ctx context.Context, req *GetAllDevicesReq
|
||||
return s.deviceRepo.ListAll(ctx, params)
|
||||
}
|
||||
|
||||
// GetAllDevicesCursor 游标分页获取所有设备(推荐使用)
|
||||
func (s *DeviceService) GetAllDevicesCursor(ctx context.Context, req *GetAllDevicesRequest) (*CursorResult, error) {
|
||||
size := pagination.ClampPageSize(req.Size)
|
||||
if req.PageSize > 0 && req.Cursor == "" {
|
||||
size = pagination.ClampPageSize(req.PageSize)
|
||||
}
|
||||
|
||||
cursor, err := pagination.Decode(req.Cursor)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid cursor: %w", err)
|
||||
}
|
||||
|
||||
params := &repository.ListDevicesParams{
|
||||
UserID: req.UserID,
|
||||
Keyword: req.Keyword,
|
||||
}
|
||||
if req.Status != nil && (*req.Status == 0 || *req.Status == 1) {
|
||||
status := domain.DeviceStatus(*req.Status)
|
||||
params.Status = &status
|
||||
}
|
||||
if req.IsTrusted != nil {
|
||||
params.IsTrusted = req.IsTrusted
|
||||
}
|
||||
|
||||
devices, hasMore, err := s.deviceRepo.ListAllCursor(ctx, params, size, cursor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nextCursor := ""
|
||||
if len(devices) > 0 {
|
||||
last := devices[len(devices)-1]
|
||||
nextCursor = pagination.BuildNextCursor(last.ID, last.LastActiveTime)
|
||||
}
|
||||
|
||||
return &CursorResult{
|
||||
Items: devices,
|
||||
NextCursor: nextCursor,
|
||||
HasMore: hasMore,
|
||||
PageSize: size,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetDeviceByDeviceID 根据设备标识获取设备(用于设备信任检查)
|
||||
func (s *DeviceService) GetDeviceByDeviceID(ctx context.Context, userID int64, deviceID string) (*domain.Device, error) {
|
||||
return s.deviceRepo.GetByDeviceID(ctx, userID, deviceID)
|
||||
|
||||
@@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"context"
|
||||
cryptorand "crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log"
|
||||
@@ -167,7 +168,7 @@ func (s *EmailCodeService) VerifyEmailCode(ctx context.Context, email, purpose,
|
||||
}
|
||||
|
||||
storedCode, ok := value.(string)
|
||||
if !ok || storedCode != code {
|
||||
if !ok || subtle.ConstantTimeCompare([]byte(storedCode), []byte(code)) != 1 {
|
||||
return fmt.Errorf("verification code is invalid")
|
||||
}
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/xuri/excelize/v2"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/pagination"
|
||||
"github.com/user-management-system/internal/repository"
|
||||
)
|
||||
|
||||
@@ -52,12 +53,15 @@ type RecordLoginRequest struct {
|
||||
|
||||
// ListLoginLogRequest 登录日志列表请求
|
||||
type ListLoginLogRequest struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
Status int `json:"status"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
StartAt string `json:"start_at"`
|
||||
EndAt string `json:"end_at"`
|
||||
UserID int64 `json:"user_id" form:"user_id"`
|
||||
Status *int `json:"status" form:"status"` // 0-失败, 1-成功, nil-不筛选
|
||||
Page int `json:"page" form:"page"`
|
||||
PageSize int `json:"page_size" form:"page_size"`
|
||||
StartAt string `json:"start_at" form:"start_at"`
|
||||
EndAt string `json:"end_at" form:"end_at"`
|
||||
// Cursor-based pagination (preferred over Page/PageSize)
|
||||
Cursor string `form:"cursor"` // Opaque cursor from previous response
|
||||
Size int `form:"size"` // Page size when using cursor mode
|
||||
}
|
||||
|
||||
// GetLoginLogs 获取登录日志列表
|
||||
@@ -84,14 +88,140 @@ func (s *LoginLogService) GetLoginLogs(ctx context.Context, req *ListLoginLogReq
|
||||
}
|
||||
}
|
||||
|
||||
// 按状态查询
|
||||
if req.Status == 0 || req.Status == 1 {
|
||||
return s.loginLogRepo.ListByStatus(ctx, req.Status, offset, req.PageSize)
|
||||
// 按状态查询(仅当明确指定了状态时才筛选)
|
||||
if req.Status != nil && (*req.Status == 0 || *req.Status == 1) {
|
||||
return s.loginLogRepo.ListByStatus(ctx, *req.Status, offset, req.PageSize)
|
||||
}
|
||||
|
||||
return s.loginLogRepo.List(ctx, offset, req.PageSize)
|
||||
}
|
||||
|
||||
// CursorResult wraps cursor-based pagination response
|
||||
type CursorResult struct {
|
||||
Items interface{} `json:"items"`
|
||||
NextCursor string `json:"next_cursor"`
|
||||
HasMore bool `json:"has_more"`
|
||||
PageSize int `json:"page_size"`
|
||||
}
|
||||
|
||||
// GetLoginLogsCursor 游标分页获取登录日志列表(推荐使用)
|
||||
func (s *LoginLogService) GetLoginLogsCursor(ctx context.Context, req *ListLoginLogRequest) (*CursorResult, error) {
|
||||
size := pagination.ClampPageSize(req.Size)
|
||||
if req.PageSize > 0 && req.Cursor == "" {
|
||||
size = pagination.ClampPageSize(req.PageSize)
|
||||
}
|
||||
|
||||
cursor, err := pagination.Decode(req.Cursor)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid cursor: %w", err)
|
||||
}
|
||||
|
||||
var items interface{}
|
||||
var nextCursor string
|
||||
var hasMore bool
|
||||
|
||||
// 按用户 ID 查询
|
||||
if req.UserID > 0 {
|
||||
logs, hm, err := s.loginLogRepo.ListByUserIDCursor(ctx, req.UserID, size, cursor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = logs
|
||||
hasMore = hm
|
||||
} else if req.StartAt != "" && req.EndAt != "" {
|
||||
// Time range: fall back to offset-based for now (cursor + time range is complex)
|
||||
start, err1 := time.Parse(time.RFC3339, req.StartAt)
|
||||
end, err2 := time.Parse(time.RFC3339, req.EndAt)
|
||||
if err1 == nil && err2 == nil {
|
||||
offset := 0
|
||||
logs, _, err := s.loginLogRepo.ListByTimeRange(ctx, start, end, offset, size)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = logs
|
||||
if len(logs) > 0 {
|
||||
last := logs[len(logs)-1]
|
||||
nextCursor = pagination.BuildNextCursor(last.ID, last.CreatedAt)
|
||||
hasMore = len(logs) == size
|
||||
}
|
||||
} else {
|
||||
items = []*domain.LoginLog{}
|
||||
}
|
||||
} else if req.Status != nil && (*req.Status == 0 || *req.Status == 1) {
|
||||
// Status filter: use ListCursor with manual status filter
|
||||
logs, hm, err := s.listByStatusCursor(ctx, *req.Status, size, cursor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = logs
|
||||
hasMore = hm
|
||||
} else {
|
||||
// Default: full table cursor scan
|
||||
logs, hm, err := s.loginLogRepo.ListCursor(ctx, size, cursor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = logs
|
||||
hasMore = hm
|
||||
}
|
||||
|
||||
// Build next cursor from the last item
|
||||
if nextCursor == "" {
|
||||
switch items := items.(type) {
|
||||
case []*domain.LoginLog:
|
||||
if len(items) > 0 {
|
||||
last := items[len(items)-1]
|
||||
nextCursor = pagination.BuildNextCursor(last.ID, last.CreatedAt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &CursorResult{
|
||||
Items: items,
|
||||
NextCursor: nextCursor,
|
||||
HasMore: hasMore,
|
||||
PageSize: size,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// listByStatusCursor 游标分页按状态查询(内部方法)
|
||||
// Uses iterative approach: fetch from ListCursor and post-filter by status.
|
||||
func (s *LoginLogService) listByStatusCursor(ctx context.Context, status int, limit int, cursor *pagination.Cursor) ([]*domain.LoginLog, bool, error) {
|
||||
var logs []*domain.LoginLog
|
||||
|
||||
// Since LoginLogRepository doesn't have status+cursor combined,
|
||||
// we use a larger batch from ListCursor and post-filter.
|
||||
batchSize := limit + 1
|
||||
for attempts := 0; attempts < 10; attempts++ { // max 10 pages of skipping
|
||||
batch, hm, err := s.loginLogRepo.ListCursor(ctx, batchSize, cursor)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
for _, log := range batch {
|
||||
if log.Status == status {
|
||||
logs = append(logs, log)
|
||||
if len(logs) >= limit+1 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(logs) >= limit+1 || !hm || len(batch) == 0 {
|
||||
break
|
||||
}
|
||||
// Advance cursor to end of this batch
|
||||
if len(batch) > 0 {
|
||||
last := batch[len(batch)-1]
|
||||
cursor = &pagination.Cursor{LastID: last.ID, LastValue: last.CreatedAt}
|
||||
}
|
||||
}
|
||||
|
||||
hasMore := len(logs) > limit
|
||||
if hasMore {
|
||||
logs = logs[:limit]
|
||||
}
|
||||
return logs, hasMore, nil
|
||||
}
|
||||
|
||||
// GetMyLoginLogs 获取当前用户的登录日志
|
||||
func (s *LoginLogService) GetMyLoginLogs(ctx context.Context, userID int64, page, pageSize int) ([]*domain.LoginLog, int64, error) {
|
||||
if page <= 0 {
|
||||
@@ -137,26 +267,88 @@ func (s *LoginLogService) ExportLoginLogs(ctx context.Context, req *ExportLoginL
|
||||
}
|
||||
}
|
||||
|
||||
// CSV 使用流式分批导出,XLSX 使用全量导出(excelize 需要所有行)
|
||||
if format == "csv" {
|
||||
data, filename, err := s.exportLoginLogsCSVStream(ctx, req.UserID, req.Status, startAt, endAt)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
return data, filename, "text/csv; charset=utf-8", nil
|
||||
}
|
||||
|
||||
logs, err := s.loginLogRepo.ListAllForExport(ctx, req.UserID, req.Status, startAt, endAt)
|
||||
if err != nil {
|
||||
return nil, "", "", fmt.Errorf("查询登录日志失败: %w", err)
|
||||
}
|
||||
|
||||
filename := fmt.Sprintf("login_logs_%s.%s", time.Now().Format("20060102_150405"), format)
|
||||
|
||||
if format == "xlsx" {
|
||||
data, err := buildLoginLogXLSXExport(logs)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
return data, filename, "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", nil
|
||||
}
|
||||
|
||||
data, err := buildLoginLogCSVExport(logs)
|
||||
filename := fmt.Sprintf("login_logs_%s.xlsx", time.Now().Format("20060102_150405"))
|
||||
data, err := buildLoginLogXLSXExport(logs)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
return data, filename, "text/csv; charset=utf-8", nil
|
||||
return data, filename, "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", nil
|
||||
}
|
||||
|
||||
// exportLoginLogsCSVStream 流式导出 CSV(分批处理防止 OOM)
|
||||
func (s *LoginLogService) exportLoginLogsCSVStream(ctx context.Context, userID int64, status int, startAt, endAt *time.Time) ([]byte, string, error) {
|
||||
headers := []string{"ID", "用户ID", "登录方式", "设备ID", "IP地址", "位置", "状态", "失败原因", "时间"}
|
||||
|
||||
var buf bytes.Buffer
|
||||
buf.Write([]byte{0xEF, 0xBB, 0xBF})
|
||||
writer := csv.NewWriter(&buf)
|
||||
|
||||
// 写入表头
|
||||
if err := writer.Write(headers); err != nil {
|
||||
return nil, "", fmt.Errorf("写CSV表头失败: %w", err)
|
||||
}
|
||||
|
||||
// 使用游标分批获取数据
|
||||
cursor := int64(1<<63 - 1) // 从最大 ID 开始
|
||||
batchSize := 5000
|
||||
totalWritten := 0
|
||||
|
||||
for {
|
||||
logs, hasMore, err := s.loginLogRepo.ListLogsForExportBatch(ctx, userID, status, startAt, endAt, cursor, batchSize)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("查询登录日志失败: %w", err)
|
||||
}
|
||||
|
||||
for _, log := range logs {
|
||||
row := []string{
|
||||
fmt.Sprintf("%d", log.ID),
|
||||
fmt.Sprintf("%d", derefInt64(log.UserID)),
|
||||
loginTypeLabel(log.LoginType),
|
||||
log.DeviceID,
|
||||
log.IP,
|
||||
log.Location,
|
||||
loginStatusLabel(log.Status),
|
||||
log.FailReason,
|
||||
log.CreatedAt.Format("2006-01-02 15:04:05"),
|
||||
}
|
||||
if err := writer.Write(row); err != nil {
|
||||
return nil, "", fmt.Errorf("写CSV行失败: %w", err)
|
||||
}
|
||||
totalWritten++
|
||||
cursor = log.ID
|
||||
}
|
||||
|
||||
writer.Flush()
|
||||
if err := writer.Error(); err != nil {
|
||||
return nil, "", fmt.Errorf("CSV Flush 失败: %w", err)
|
||||
}
|
||||
|
||||
// 如果数据量过大,提前终止
|
||||
if totalWritten >= repository.ExportBatchSize {
|
||||
break
|
||||
}
|
||||
|
||||
if !hasMore || len(logs) == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
filename := fmt.Sprintf("login_logs_%s.csv", time.Now().Format("20060102_150405"))
|
||||
return buf.Bytes(), filename, nil
|
||||
}
|
||||
|
||||
func buildLoginLogCSVExport(logs []*domain.LoginLog) ([]byte, error) {
|
||||
|
||||
@@ -2,9 +2,11 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/pagination"
|
||||
"github.com/user-management-system/internal/repository"
|
||||
)
|
||||
|
||||
@@ -51,13 +53,15 @@ type RecordOperationRequest struct {
|
||||
|
||||
// ListOperationLogRequest 操作日志列表请求
|
||||
type ListOperationLogRequest struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
Method string `json:"method"`
|
||||
Keyword string `json:"keyword"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
StartAt string `json:"start_at"`
|
||||
EndAt string `json:"end_at"`
|
||||
UserID int64 `json:"user_id" form:"user_id"`
|
||||
Method string `json:"method" form:"method"`
|
||||
Keyword string `json:"keyword" form:"keyword"`
|
||||
Page int `json:"page" form:"page"`
|
||||
PageSize int `json:"page_size" form:"page_size"`
|
||||
StartAt string `json:"start_at" form:"start_at"`
|
||||
EndAt string `json:"end_at" form:"end_at"`
|
||||
Cursor string `form:"cursor"` // Opaque cursor for keyset pagination
|
||||
Size int `form:"size"` // Page size when using cursor mode
|
||||
}
|
||||
|
||||
// GetOperationLogs 获取操作日志列表
|
||||
@@ -97,6 +101,42 @@ func (s *OperationLogService) GetOperationLogs(ctx context.Context, req *ListOpe
|
||||
return s.operationLogRepo.List(ctx, offset, req.PageSize)
|
||||
}
|
||||
|
||||
// GetOperationLogsCursor 游标分页获取操作日志列表(推荐使用)
|
||||
func (s *OperationLogService) GetOperationLogsCursor(ctx context.Context, req *ListOperationLogRequest) (*CursorResult, error) {
|
||||
size := pagination.ClampPageSize(req.Size)
|
||||
|
||||
cursor, err := pagination.Decode(req.Cursor)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid cursor: %w", err)
|
||||
}
|
||||
|
||||
var items interface{}
|
||||
var hasMore bool
|
||||
|
||||
logs, hm, err := s.operationLogRepo.ListCursor(ctx, size, cursor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = logs
|
||||
hasMore = hm
|
||||
|
||||
nextCursor := ""
|
||||
switch items := items.(type) {
|
||||
case []*domain.OperationLog:
|
||||
if len(items) > 0 {
|
||||
last := items[len(items)-1]
|
||||
nextCursor = pagination.BuildNextCursor(last.ID, last.CreatedAt)
|
||||
}
|
||||
}
|
||||
|
||||
return &CursorResult{
|
||||
Items: items,
|
||||
NextCursor: nextCursor,
|
||||
HasMore: hasMore,
|
||||
PageSize: size,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetMyOperationLogs 获取当前用户的操作日志
|
||||
func (s *OperationLogService) GetMyOperationLogs(ctx context.Context, userID int64, page, pageSize int) ([]*domain.OperationLog, int64, error) {
|
||||
if page <= 0 {
|
||||
|
||||
@@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"context"
|
||||
cryptorand "crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -13,6 +14,7 @@ import (
|
||||
"github.com/user-management-system/internal/auth"
|
||||
"github.com/user-management-system/internal/cache"
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/repository"
|
||||
"github.com/user-management-system/internal/security"
|
||||
)
|
||||
|
||||
@@ -46,9 +48,10 @@ func DefaultPasswordResetConfig() *PasswordResetConfig {
|
||||
}
|
||||
|
||||
type PasswordResetService struct {
|
||||
userRepo userRepositoryInterface
|
||||
cache *cache.CacheManager
|
||||
config *PasswordResetConfig
|
||||
userRepo userRepositoryInterface
|
||||
cache *cache.CacheManager
|
||||
config *PasswordResetConfig
|
||||
passwordHistoryRepo *repository.PasswordHistoryRepository
|
||||
}
|
||||
|
||||
func NewPasswordResetService(
|
||||
@@ -66,6 +69,12 @@ func NewPasswordResetService(
|
||||
}
|
||||
}
|
||||
|
||||
// WithPasswordHistoryRepo 注入密码历史 repository,用于重置密码时记录历史
|
||||
func (s *PasswordResetService) WithPasswordHistoryRepo(repo *repository.PasswordHistoryRepository) *PasswordResetService {
|
||||
s.passwordHistoryRepo = repo
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *PasswordResetService) ForgotPassword(ctx context.Context, email string) error {
|
||||
user, err := s.userRepo.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
@@ -216,7 +225,7 @@ func (s *PasswordResetService) ResetPasswordByPhone(ctx context.Context, req *Re
|
||||
}
|
||||
|
||||
code, ok := storedCode.(string)
|
||||
if !ok || code != req.Code {
|
||||
if !ok || subtle.ConstantTimeCompare([]byte(code), []byte(req.Code)) != 1 {
|
||||
return errors.New("验证码不正确")
|
||||
}
|
||||
|
||||
@@ -258,6 +267,18 @@ func (s *PasswordResetService) doResetPassword(ctx context.Context, user *domain
|
||||
return err
|
||||
}
|
||||
|
||||
// 检查密码历史(防止重用近5次密码)
|
||||
if s.passwordHistoryRepo != nil {
|
||||
histories, err := s.passwordHistoryRepo.GetByUserID(ctx, user.ID, passwordHistoryLimit)
|
||||
if err == nil {
|
||||
for _, h := range histories {
|
||||
if auth.VerifyPassword(h.PasswordHash, newPassword) {
|
||||
return errors.New("新密码不能与最近5次密码相同")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
hashedPassword, err := auth.HashPassword(newPassword)
|
||||
if err != nil {
|
||||
return fmt.Errorf("密码加密失败: %w", err)
|
||||
@@ -268,5 +289,19 @@ func (s *PasswordResetService) doResetPassword(ctx context.Context, user *domain
|
||||
return fmt.Errorf("更新密码失败: %w", err)
|
||||
}
|
||||
|
||||
// 写入密码历史记录
|
||||
if s.passwordHistoryRepo != nil {
|
||||
go func() {
|
||||
// 使用带超时的独立 context,防止 DB 写入无限等待
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = s.passwordHistoryRepo.Create(bgCtx, &domain.PasswordHistory{
|
||||
UserID: user.ID,
|
||||
PasswordHash: hashedPassword,
|
||||
})
|
||||
_ = s.passwordHistoryRepo.DeleteOldRecords(bgCtx, user.ID, passwordHistoryLimit)
|
||||
}()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
92
internal/service/settings.go
Normal file
92
internal/service/settings.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// SystemSettings 系统设置
|
||||
type SystemSettings struct {
|
||||
System SystemInfo `json:"system"`
|
||||
Security SecurityInfo `json:"security"`
|
||||
Features FeaturesInfo `json:"features"`
|
||||
}
|
||||
|
||||
// SystemInfo 系统信息
|
||||
type SystemInfo struct {
|
||||
Name string `json:"name"`
|
||||
Version string `json:"version"`
|
||||
Environment string `json:"environment"`
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
// SecurityInfo 安全设置
|
||||
type SecurityInfo struct {
|
||||
PasswordMinLength int `json:"password_min_length"`
|
||||
PasswordRequireUppercase bool `json:"password_require_uppercase"`
|
||||
PasswordRequireLowercase bool `json:"password_require_lowercase"`
|
||||
PasswordRequireNumbers bool `json:"password_require_numbers"`
|
||||
PasswordRequireSymbols bool `json:"password_require_symbols"`
|
||||
PasswordHistory int `json:"password_history"`
|
||||
TOTPEnabled bool `json:"totp_enabled"`
|
||||
LoginFailLock bool `json:"login_fail_lock"`
|
||||
LoginFailThreshold int `json:"login_fail_threshold"`
|
||||
LoginFailDuration int `json:"login_fail_duration"` // 分钟
|
||||
SessionTimeout int `json:"session_timeout"` // 秒
|
||||
DeviceTrustDuration int `json:"device_trust_duration"` // 秒
|
||||
}
|
||||
|
||||
// FeaturesInfo 功能开关
|
||||
type FeaturesInfo struct {
|
||||
EmailVerification bool `json:"email_verification"`
|
||||
PhoneVerification bool `json:"phone_verification"`
|
||||
OAuthProviders []string `json:"oauth_providers"`
|
||||
SSOEnabled bool `json:"sso_enabled"`
|
||||
OperationLogEnabled bool `json:"operation_log_enabled"`
|
||||
LoginLogEnabled bool `json:"login_log_enabled"`
|
||||
DataExportEnabled bool `json:"data_export_enabled"`
|
||||
DataImportEnabled bool `json:"data_import_enabled"`
|
||||
}
|
||||
|
||||
// SettingsService 系统设置服务
|
||||
type SettingsService struct{}
|
||||
|
||||
// NewSettingsService 创建系统设置服务
|
||||
func NewSettingsService() *SettingsService {
|
||||
return &SettingsService{}
|
||||
}
|
||||
|
||||
// GetSettings 获取系统设置
|
||||
func (s *SettingsService) GetSettings(ctx context.Context) (*SystemSettings, error) {
|
||||
return &SystemSettings{
|
||||
System: SystemInfo{
|
||||
Name: "用户管理系统",
|
||||
Version: "1.0.0",
|
||||
Environment: "Production",
|
||||
Description: "基于 Go + React 的现代化用户管理系统",
|
||||
},
|
||||
Security: SecurityInfo{
|
||||
PasswordMinLength: 8,
|
||||
PasswordRequireUppercase: true,
|
||||
PasswordRequireLowercase: true,
|
||||
PasswordRequireNumbers: true,
|
||||
PasswordRequireSymbols: true,
|
||||
PasswordHistory: 5,
|
||||
TOTPEnabled: true,
|
||||
LoginFailLock: true,
|
||||
LoginFailThreshold: 5,
|
||||
LoginFailDuration: 30,
|
||||
SessionTimeout: 86400, // 1天
|
||||
DeviceTrustDuration: 2592000, // 30天
|
||||
},
|
||||
Features: FeaturesInfo{
|
||||
EmailVerification: true,
|
||||
PhoneVerification: false,
|
||||
OAuthProviders: []string{"GitHub", "Google"},
|
||||
SSOEnabled: false,
|
||||
OperationLogEnabled: true,
|
||||
LoginLogEnabled: true,
|
||||
DataExportEnabled: true,
|
||||
DataImportEnabled: true,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
308
internal/service/settings_test.go
Normal file
308
internal/service/settings_test.go
Normal file
@@ -0,0 +1,308 @@
|
||||
package service_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/user-management-system/internal/api/handler"
|
||||
"github.com/user-management-system/internal/api/middleware"
|
||||
"github.com/user-management-system/internal/api/router"
|
||||
"github.com/user-management-system/internal/auth"
|
||||
"github.com/user-management-system/internal/cache"
|
||||
"github.com/user-management-system/internal/config"
|
||||
"github.com/user-management-system/internal/repository"
|
||||
"github.com/user-management-system/internal/service"
|
||||
"github.com/user-management-system/internal/domain"
|
||||
gormsqlite "gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
// doRequest makes an HTTP request with optional body
|
||||
func doRequest(method, url string, token string, body interface{}) (*http.Response, string) {
|
||||
var bodyReader io.Reader
|
||||
if body != nil {
|
||||
jsonBytes, _ := json.Marshal(body)
|
||||
bodyReader = bytes.NewReader(jsonBytes)
|
||||
}
|
||||
req, _ := http.NewRequest(method, url, bodyReader)
|
||||
if token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
client := &http.Client{}
|
||||
resp, _ := client.Do(req)
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
return resp, string(bodyBytes)
|
||||
}
|
||||
|
||||
func doPost(url, token string, body interface{}) (*http.Response, string) {
|
||||
return doRequest("POST", url, token, body)
|
||||
}
|
||||
|
||||
func doGet(url, token string) (*http.Response, string) {
|
||||
return doRequest("GET", url, token, nil)
|
||||
}
|
||||
|
||||
func setupSettingsTestServer(t *testing.T) (*httptest.Server, *service.SettingsService, string, func()) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// 使用内存 SQLite
|
||||
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
|
||||
DriverName: "sqlite",
|
||||
DSN: "file::memory:?mode=memory&cache=shared",
|
||||
}), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
t.Skipf("skipping test (SQLite unavailable): %v", err)
|
||||
return nil, nil, "", func() {}
|
||||
}
|
||||
|
||||
// 自动迁移
|
||||
if err := db.AutoMigrate(
|
||||
&domain.User{},
|
||||
&domain.Role{},
|
||||
&domain.Permission{},
|
||||
&domain.UserRole{},
|
||||
&domain.RolePermission{},
|
||||
&domain.Device{},
|
||||
&domain.LoginLog{},
|
||||
&domain.OperationLog{},
|
||||
&domain.SocialAccount{},
|
||||
&domain.Webhook{},
|
||||
&domain.WebhookDelivery{},
|
||||
); err != nil {
|
||||
t.Fatalf("db migration failed: %v", err)
|
||||
}
|
||||
|
||||
// 创建 JWT Manager
|
||||
jwtManager, err := auth.NewJWTWithOptions(auth.JWTOptions{
|
||||
HS256Secret: "test-settings-secret-key",
|
||||
AccessTokenExpire: 15 * time.Minute,
|
||||
RefreshTokenExpire: 7 * 24 * time.Hour,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("create jwt manager failed: %v", err)
|
||||
}
|
||||
|
||||
// 创建缓存
|
||||
l1Cache := cache.NewL1Cache()
|
||||
l2Cache := cache.NewRedisCache(false)
|
||||
cacheManager := cache.NewCacheManager(l1Cache, l2Cache)
|
||||
|
||||
// 创建 repositories
|
||||
userRepo := repository.NewUserRepository(db)
|
||||
roleRepo := repository.NewRoleRepository(db)
|
||||
permissionRepo := repository.NewPermissionRepository(db)
|
||||
userRoleRepo := repository.NewUserRoleRepository(db)
|
||||
rolePermissionRepo := repository.NewRolePermissionRepository(db)
|
||||
deviceRepo := repository.NewDeviceRepository(db)
|
||||
loginLogRepo := repository.NewLoginLogRepository(db)
|
||||
opLogRepo := repository.NewOperationLogRepository(db)
|
||||
passwordHistoryRepo := repository.NewPasswordHistoryRepository(db)
|
||||
|
||||
// 创建 services
|
||||
authSvc := service.NewAuthService(userRepo, nil, jwtManager, cacheManager, 8, 5, 15*time.Minute)
|
||||
authSvc.SetRoleRepositories(userRoleRepo, roleRepo)
|
||||
userSvc := service.NewUserService(userRepo, userRoleRepo, roleRepo, passwordHistoryRepo)
|
||||
roleSvc := service.NewRoleService(roleRepo, rolePermissionRepo)
|
||||
permSvc := service.NewPermissionService(permissionRepo)
|
||||
deviceSvc := service.NewDeviceService(deviceRepo, userRepo)
|
||||
loginLogSvc := service.NewLoginLogService(loginLogRepo)
|
||||
opLogSvc := service.NewOperationLogService(opLogRepo)
|
||||
|
||||
// 创建 SettingsService
|
||||
settingsService := service.NewSettingsService()
|
||||
|
||||
// 创建 middleware
|
||||
rateLimitCfg := config.RateLimitConfig{}
|
||||
rateLimitMiddleware := middleware.NewRateLimitMiddleware(rateLimitCfg)
|
||||
authMiddleware := middleware.NewAuthMiddleware(
|
||||
jwtManager, userRepo, userRoleRepo, roleRepo, rolePermissionRepo, permissionRepo, l1Cache,
|
||||
)
|
||||
authMiddleware.SetCacheManager(cacheManager)
|
||||
opLogMiddleware := middleware.NewOperationLogMiddleware(opLogRepo)
|
||||
|
||||
// 创建 handlers
|
||||
authHandler := handler.NewAuthHandler(authSvc)
|
||||
userHandler := handler.NewUserHandler(userSvc)
|
||||
roleHandler := handler.NewRoleHandler(roleSvc)
|
||||
permHandler := handler.NewPermissionHandler(permSvc)
|
||||
deviceHandler := handler.NewDeviceHandler(deviceSvc)
|
||||
logHandler := handler.NewLogHandler(loginLogSvc, opLogSvc)
|
||||
settingsHandler := handler.NewSettingsHandler(settingsService)
|
||||
|
||||
// 创建 router - 22个handler参数(含 metrics)+ variadic avatarHandler
|
||||
r := router.NewRouter(
|
||||
authHandler, userHandler, roleHandler, permHandler, deviceHandler,
|
||||
logHandler, authMiddleware, rateLimitMiddleware, opLogMiddleware,
|
||||
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||
nil,
|
||||
settingsHandler, nil,
|
||||
)
|
||||
engine := r.Setup()
|
||||
|
||||
server := httptest.NewServer(engine)
|
||||
|
||||
// 注册用户用于测试
|
||||
resp, _ := doPost(server.URL+"/api/v1/auth/register", "", map[string]interface{}{
|
||||
"username": "admintestsu",
|
||||
"email": "admintestsu@test.com",
|
||||
"password": "Password123!",
|
||||
})
|
||||
resp.Body.Close()
|
||||
|
||||
// 获取 token
|
||||
loginResp, _ := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
|
||||
"account": "admintestsu",
|
||||
"password": "Password123!",
|
||||
})
|
||||
|
||||
var result map[string]interface{}
|
||||
json.NewDecoder(loginResp.Body).Decode(&result)
|
||||
loginResp.Body.Close()
|
||||
|
||||
token := ""
|
||||
if data, ok := result["data"].(map[string]interface{}); ok {
|
||||
token, _ = data["access_token"].(string)
|
||||
}
|
||||
|
||||
return server, settingsService, token, func() {
|
||||
server.Close()
|
||||
if sqlDB, _ := db.DB(); sqlDB != nil {
|
||||
sqlDB.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Settings API Tests
|
||||
// =============================================================================
|
||||
|
||||
func TestGetSettings_Success(t *testing.T) {
|
||||
// 仅测试 service 层,不测试 HTTP API
|
||||
svc := service.NewSettingsService()
|
||||
settings, err := svc.GetSettings(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("GetSettings failed: %v", err)
|
||||
}
|
||||
|
||||
if settings.System.Name != "用户管理系统" {
|
||||
t.Errorf("expected system name '用户管理系统', got '%s'", settings.System.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSettings_Unauthorized(t *testing.T) {
|
||||
server, _, _, cleanup := setupSettingsTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
req, _ := http.NewRequest("GET", server.URL+"/api/v1/admin/settings", nil)
|
||||
// 不设置 Authorization header
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 无 token 应该返回 401
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Errorf("expected status 401, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSettings_ResponseStructure(t *testing.T) {
|
||||
// 仅测试 service 层数据结构
|
||||
svc := service.NewSettingsService()
|
||||
settings, err := svc.GetSettings(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("GetSettings failed: %v", err)
|
||||
}
|
||||
|
||||
// 验证 system 字段
|
||||
if settings.System.Name == "" {
|
||||
t.Error("System.Name should not be empty")
|
||||
}
|
||||
if settings.System.Version == "" {
|
||||
t.Error("System.Version should not be empty")
|
||||
}
|
||||
if settings.System.Environment == "" {
|
||||
t.Error("System.Environment should not be empty")
|
||||
}
|
||||
|
||||
// 验证 security 字段
|
||||
if settings.Security.PasswordMinLength == 0 {
|
||||
t.Error("Security.PasswordMinLength should not be zero")
|
||||
}
|
||||
if !settings.Security.PasswordRequireUppercase {
|
||||
t.Error("Security.PasswordRequireUppercase should be true")
|
||||
}
|
||||
|
||||
// 验证 features 字段
|
||||
if !settings.Features.EmailVerification {
|
||||
t.Error("Features.EmailVerification should be true")
|
||||
}
|
||||
if len(settings.Features.OAuthProviders) == 0 {
|
||||
t.Error("Features.OAuthProviders should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// SettingsService Unit Tests
|
||||
// =============================================================================
|
||||
|
||||
func TestSettingsService_GetSettings(t *testing.T) {
|
||||
svc := service.NewSettingsService()
|
||||
|
||||
settings, err := svc.GetSettings(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("GetSettings failed: %v", err)
|
||||
}
|
||||
|
||||
// 验证 system
|
||||
if settings.System.Name == "" {
|
||||
t.Error("System.Name should not be empty")
|
||||
}
|
||||
if settings.System.Version == "" {
|
||||
t.Error("System.Version should not be empty")
|
||||
}
|
||||
|
||||
// 验证 security defaults
|
||||
if settings.Security.PasswordMinLength != 8 {
|
||||
t.Errorf("PasswordMinLength: got %d, want 8", settings.Security.PasswordMinLength)
|
||||
}
|
||||
if !settings.Security.PasswordRequireUppercase {
|
||||
t.Error("PasswordRequireUppercase should be true")
|
||||
}
|
||||
if !settings.Security.PasswordRequireLowercase {
|
||||
t.Error("PasswordRequireLowercase should be true")
|
||||
}
|
||||
if !settings.Security.PasswordRequireNumbers {
|
||||
t.Error("PasswordRequireNumbers should be true")
|
||||
}
|
||||
if !settings.Security.PasswordRequireSymbols {
|
||||
t.Error("PasswordRequireSymbols should be true")
|
||||
}
|
||||
if settings.Security.PasswordHistory != 5 {
|
||||
t.Errorf("PasswordHistory: got %d, want 5", settings.Security.PasswordHistory)
|
||||
}
|
||||
|
||||
// 验证 features defaults
|
||||
if !settings.Features.EmailVerification {
|
||||
t.Error("EmailVerification should be true")
|
||||
}
|
||||
if settings.Features.DataExportEnabled != true {
|
||||
t.Error("DataExportEnabled should be true")
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"context"
|
||||
cryptorand "crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
@@ -357,7 +358,7 @@ func (s *SMSCodeService) VerifyCode(ctx context.Context, phone, purpose, code st
|
||||
}
|
||||
|
||||
stored, ok := val.(string)
|
||||
if !ok || stored != code {
|
||||
if !ok || subtle.ConstantTimeCompare([]byte(stored), []byte(code)) != 1 {
|
||||
return fmt.Errorf("\u9a8c\u8bc1\u7801\u4e0d\u6b63\u786e")
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"regexp"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/repository"
|
||||
@@ -48,6 +49,11 @@ type UpdateThemeRequest struct {
|
||||
|
||||
// CreateTheme 创建主题
|
||||
func (s *ThemeService) CreateTheme(ctx context.Context, req *CreateThemeRequest) (*domain.ThemeConfig, error) {
|
||||
// 安全检查:禁止在 CustomCSS/CustomJS 中包含危险模式
|
||||
if err := validateCustomCSSJS(req.CustomCSS, req.CustomJS); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 检查主题名称是否已存在
|
||||
existing, err := s.themeRepo.GetByName(ctx, req.Name)
|
||||
if err == nil && existing != nil {
|
||||
@@ -84,6 +90,11 @@ func (s *ThemeService) CreateTheme(ctx context.Context, req *CreateThemeRequest)
|
||||
|
||||
// UpdateTheme 更新主题
|
||||
func (s *ThemeService) UpdateTheme(ctx context.Context, id int64, req *UpdateThemeRequest) (*domain.ThemeConfig, error) {
|
||||
// 安全检查:禁止在 CustomCSS/CustomJS 中包含危险模式
|
||||
if err := validateCustomCSSJS(req.CustomCSS, req.CustomJS); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
theme, err := s.themeRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, errors.New("主题不存在")
|
||||
@@ -204,3 +215,43 @@ func (s *ThemeService) clearDefaultThemes(ctx context.Context) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateCustomCSSJS 检查 CustomCSS 和 CustomJS 是否包含危险 XSS 模式
|
||||
// 这不是完全净化,而是拒绝明显可造成 XSS 的模式
|
||||
func validateCustomCSSJS(css, js string) error {
|
||||
// 危险模式列表
|
||||
dangerousPatterns := []struct {
|
||||
pattern *regexp.Regexp
|
||||
message string
|
||||
}{
|
||||
// Script 标签
|
||||
{regexp.MustCompile(`(?i)<script[^>]*>.*?</script>`), "CustomJS 禁止包含 <script> 标签"},
|
||||
{regexp.MustCompile(`(?i)javascript\s*:`), "CustomJS 禁止使用 javascript: 协议"},
|
||||
// 事件处理器
|
||||
{regexp.MustCompile(`(?i)on\w+\s*=`), "CustomJS 禁止使用事件处理器 (如 onerror, onclick)"},
|
||||
// Data URL
|
||||
{regexp.MustCompile(`(?i)data\s*:\s*text/html`), "禁止使用 data: URL 嵌入 HTML"},
|
||||
// CSS expression (IE)
|
||||
{regexp.MustCompile(`(?i)expression\s*\(`), "CustomCSS 禁止使用 CSS expression"},
|
||||
// CSS 中的 javascript
|
||||
{regexp.MustCompile(`(?i)url\s*\(\s*['"]?\s*javascript:`), "CustomCSS 禁止使用 javascript: URL"},
|
||||
// 嵌入的 <style> 标签
|
||||
{regexp.MustCompile(`(?i)<style[^>]*>.*?</style>`), "CustomCSS 禁止包含 <style> 标签"},
|
||||
}
|
||||
|
||||
// 检查 JS
|
||||
for _, p := range dangerousPatterns {
|
||||
if p.pattern.MatchString(js) {
|
||||
return errors.New(p.message)
|
||||
}
|
||||
}
|
||||
|
||||
// 检查 CSS
|
||||
for _, p := range dangerousPatterns {
|
||||
if p.pattern.MatchString(css) {
|
||||
return errors.New(p.message)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -3,10 +3,13 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/user-management-system/internal/auth"
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/pagination"
|
||||
"github.com/user-management-system/internal/repository"
|
||||
)
|
||||
|
||||
@@ -80,11 +83,14 @@ func (s *UserService) ChangePassword(ctx context.Context, userID int64, oldPassw
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = s.passwordHistoryRepo.Create(context.Background(), &domain.PasswordHistory{
|
||||
// 使用带超时的独立 context(不能使用请求 ctx,该 goroutine 在请求完成后仍可能运行)
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = s.passwordHistoryRepo.Create(bgCtx, &domain.PasswordHistory{
|
||||
UserID: userID,
|
||||
PasswordHash: newHashedPassword,
|
||||
})
|
||||
_ = s.passwordHistoryRepo.DeleteOldRecords(context.Background(), userID, passwordHistoryLimit)
|
||||
_ = s.passwordHistoryRepo.DeleteOldRecords(bgCtx, userID, passwordHistoryLimit)
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -127,6 +133,57 @@ func (s *UserService) List(ctx context.Context, offset, limit int) ([]*domain.Us
|
||||
return s.userRepo.List(ctx, offset, limit)
|
||||
}
|
||||
|
||||
// ListCursorRequest 用户游标分页请求
|
||||
type ListCursorRequest struct {
|
||||
Keyword string `form:"keyword"`
|
||||
Status int `form:"status"` // -1=全部
|
||||
RoleIDs []int64
|
||||
CreatedFrom *time.Time
|
||||
CreatedTo *time.Time
|
||||
SortBy string // created_at, last_login_time, username
|
||||
SortOrder string // asc, desc
|
||||
Cursor string `form:"cursor"`
|
||||
Size int `form:"size"`
|
||||
}
|
||||
|
||||
// ListCursor 游标分页获取用户列表(推荐使用)
|
||||
func (s *UserService) ListCursor(ctx context.Context, req *ListCursorRequest) (*CursorResult, error) {
|
||||
size := pagination.ClampPageSize(req.Size)
|
||||
|
||||
cursor, err := pagination.Decode(req.Cursor)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid cursor: %w", err)
|
||||
}
|
||||
|
||||
filter := &repository.AdvancedFilter{
|
||||
Keyword: req.Keyword,
|
||||
Status: req.Status,
|
||||
RoleIDs: req.RoleIDs,
|
||||
CreatedFrom: req.CreatedFrom,
|
||||
CreatedTo: req.CreatedTo,
|
||||
SortBy: req.SortBy,
|
||||
SortOrder: req.SortOrder,
|
||||
}
|
||||
|
||||
users, hasMore, err := s.userRepo.ListCursor(ctx, filter, size, cursor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nextCursor := ""
|
||||
if len(users) > 0 {
|
||||
last := users[len(users)-1]
|
||||
nextCursor = pagination.BuildNextCursor(last.ID, last.CreatedAt)
|
||||
}
|
||||
|
||||
return &CursorResult{
|
||||
Items: users,
|
||||
NextCursor: nextCursor,
|
||||
HasMore: hasMore,
|
||||
PageSize: size,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UpdateStatus 更新用户状态
|
||||
func (s *UserService) UpdateStatus(ctx context.Context, id int64, status domain.UserStatus) error {
|
||||
return s.userRepo.UpdateStatus(ctx, id, status)
|
||||
|
||||
Reference in New Issue
Block a user