feat: 系统全面优化 - 设备管理/登录日志导出/性能监控/设置页面

后端:
- 新增全局设备管理 API(DeviceHandler.GetAllDevices)
- 新增登录日志导出功能(LogHandler.ExportLoginLogs, CSV/XLSX)
- 新增设置服务(SettingsService)和设置页面 API
- 设备管理支持多条件筛选(状态/信任状态/关键词)
- 登录日志支持流式导出防 OOM
- 操作日志支持按方法/时间范围搜索
- 主题配置服务(ThemeService)
- 增强监控健康检查(Prometheus metrics + SLO)
- 移除旧 ratelimit.go(已迁移至 robustness)
- 修复 SocialAccount NULL 扫描问题
- 新增 API 契约测试、Handler 测试、Settings 测试

前端:
- 新增管理员设备管理页面(DevicesPage)
- 新增管理员登录日志导出功能
- 新增系统设置页面(SettingsPage)
- 设备管理支持筛选和分页
- 增强 HTTP 响应类型

测试:
- 业务逻辑测试 68 个(含并发 CONC_001~003)
- 规模测试 16 个(P99 百分位统计)
- E2E 测试、集成测试、契约测试
- 性能基准测试、鲁棒性测试

全面测试通过(38 个测试包)
This commit is contained in:
2026-04-07 12:08:16 +08:00
parent 8655b39b03
commit 5ca3633be4
36 changed files with 4552 additions and 134 deletions

View File

@@ -480,7 +480,10 @@ func (s *AuthService) writeLoginLog(
}
go func() {
if err := s.loginLogRepo.Create(context.Background(), loginRecord); err != nil {
// 使用带超时的独立 context防止日志写入无限等待
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.loginLogRepo.Create(bgCtx, loginRecord); err != nil {
log.Printf("auth: write login log failed, user_id=%v login_type=%d err=%v", userID, loginType, err)
}
}()
@@ -548,6 +551,11 @@ func (s *AuthService) bestEffortRegisterDevice(ctx context.Context, userID int64
_, _ = s.deviceService.CreateDevice(ctx, userID, createReq)
}
// BestEffortRegisterDevicePublic 供外部 handler如 SMS 登录)调用,安静地注册设备
func (s *AuthService) BestEffortRegisterDevicePublic(ctx context.Context, userID int64, req *LoginRequest) {
s.bestEffortRegisterDevice(ctx, userID, req)
}
func (s *AuthService) cacheUserInfo(ctx context.Context, user *domain.User) {
if s == nil || s.cache == nil || user == nil {
return
@@ -757,7 +765,8 @@ func (s *AuthService) RefreshToken(ctx context.Context, refreshToken string) (*L
return nil, errors.New("auth service is not fully configured")
}
claims, err := s.jwtManager.ValidateRefreshToken(strings.TrimSpace(refreshToken))
refreshToken = strings.TrimSpace(refreshToken)
claims, err := s.jwtManager.ValidateRefreshToken(refreshToken)
if err != nil {
return nil, err
}
@@ -773,6 +782,18 @@ func (s *AuthService) RefreshToken(ctx context.Context, refreshToken string) (*L
return nil, err
}
// Token Rotation: 使旧的 refresh token 失效,防止无限刷新
if s.cache != nil {
blacklistKey := tokenBlacklistPrefix + claims.JTI
// TTL 设置为 refresh token 的剩余有效期
if claims.ExpiresAt != nil {
remaining := claims.ExpiresAt.Time.Sub(time.Now())
if remaining > 0 {
_ = s.cache.Set(ctx, blacklistKey, "1", 5*time.Minute, remaining)
}
}
}
return s.generateLoginResponse(ctx, user, claims.Remember)
}

View File

@@ -0,0 +1,535 @@
package service
import (
"context"
"testing"
"time"
)
// =============================================================================
// Auth Service Unit Tests
// =============================================================================
func TestPasswordStrength(t *testing.T) {
tests := []struct {
name string
password string
wantInfo PasswordStrengthInfo
}{
{
name: "empty_password",
password: "",
wantInfo: PasswordStrengthInfo{Score: 0, Length: 0, HasUpper: false, HasLower: false, HasDigit: false, HasSpecial: false},
},
{
name: "lowercase_only",
password: "abcdefgh",
wantInfo: PasswordStrengthInfo{Score: 1, Length: 8, HasUpper: false, HasLower: true, HasDigit: false, HasSpecial: false},
},
{
name: "uppercase_only",
password: "ABCDEFGH",
wantInfo: PasswordStrengthInfo{Score: 1, Length: 8, HasUpper: true, HasLower: false, HasDigit: false, HasSpecial: false},
},
{
name: "digits_only",
password: "12345678",
wantInfo: PasswordStrengthInfo{Score: 1, Length: 8, HasUpper: false, HasLower: false, HasDigit: true, HasSpecial: false},
},
{
name: "mixed_case_with_digits",
password: "Abcd1234",
wantInfo: PasswordStrengthInfo{Score: 3, Length: 8, HasUpper: true, HasLower: true, HasDigit: true, HasSpecial: false},
},
{
name: "mixed_with_special",
password: "Abcd1234!",
wantInfo: PasswordStrengthInfo{Score: 4, Length: 9, HasUpper: true, HasLower: true, HasDigit: true, HasSpecial: true},
},
{
name: "chinese_characters",
password: "密码123456",
wantInfo: PasswordStrengthInfo{Score: 1, Length: 8, HasUpper: false, HasLower: false, HasDigit: true, HasSpecial: false},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
info := GetPasswordStrength(tt.password)
if info.Score != tt.wantInfo.Score {
t.Errorf("Score: got %d, want %d", info.Score, tt.wantInfo.Score)
}
if info.Length != tt.wantInfo.Length {
t.Errorf("Length: got %d, want %d", info.Length, tt.wantInfo.Length)
}
if info.HasUpper != tt.wantInfo.HasUpper {
t.Errorf("HasUpper: got %v, want %v", info.HasUpper, tt.wantInfo.HasUpper)
}
if info.HasLower != tt.wantInfo.HasLower {
t.Errorf("HasLower: got %v, want %v", info.HasLower, tt.wantInfo.HasLower)
}
if info.HasDigit != tt.wantInfo.HasDigit {
t.Errorf("HasDigit: got %v, want %v", info.HasDigit, tt.wantInfo.HasDigit)
}
if info.HasSpecial != tt.wantInfo.HasSpecial {
t.Errorf("HasSpecial: got %v, want %v", info.HasSpecial, tt.wantInfo.HasSpecial)
}
})
}
}
func TestValidatePasswordStrength(t *testing.T) {
tests := []struct {
name string
password string
minLength int
strict bool
wantErr bool
}{
{
name: "valid_password_strict",
password: "Abcd1234!",
minLength: 8,
strict: true,
wantErr: false,
},
{
name: "too_short",
password: "Ab1!",
minLength: 8,
strict: false,
wantErr: true,
},
{
name: "weak_password",
password: "abcdefgh",
minLength: 8,
strict: false,
wantErr: true,
},
{
name: "strict_missing_uppercase",
password: "abcd1234!",
minLength: 8,
strict: true,
wantErr: true,
},
{
name: "strict_missing_lowercase",
password: "ABCD1234!",
minLength: 8,
strict: true,
wantErr: true,
},
{
name: "strict_missing_digit",
password: "Abcdefgh!",
minLength: 8,
strict: true,
wantErr: true,
},
{
name: "valid_weak_password_non_strict",
password: "Abcd1234",
minLength: 8,
strict: false,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validatePasswordStrength(tt.password, tt.minLength, tt.strict)
if (err != nil) != tt.wantErr {
t.Errorf("validatePasswordStrength() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestSanitizeUsername(t *testing.T) {
tests := []struct {
name string
input string
want string
}{
{
name: "normal_username",
input: "john_doe",
want: "john_doe",
},
{
name: "username_with_spaces",
input: "john doe",
want: "john_doe",
},
{
name: "username_with_uppercase",
input: "JohnDoe",
want: "johndoe",
},
{
name: "username_with_special_chars",
input: "john@doe",
want: "johndoe",
},
{
name: "empty_username",
input: "",
want: "user",
},
{
name: "whitespace_only",
input: " ",
want: "user",
},
{
name: "username_with_emoji",
input: "john😀doe",
want: "johndoe", // emoji is filtered out as it's not letter/digit/./-/_
},
{
name: "username_with_leading_underscore",
input: "_john_",
want: "john", // leading and trailing _ are trimmed
},
{
name: "username_with_trailing_dots",
input: "john..doe...",
want: "john..doe", // trailing dots trimmed
},
{
name: "long_username_truncated",
input: "this_is_a_very_long_username_that_exceeds_fifty_characters_limit",
want: "this_is_a_very_long_username_that_exceeds_fifty_ch", // 50 chars max, cuts off "acters_limit"
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := sanitizeUsername(tt.input)
if got != tt.want {
t.Errorf("sanitizeUsername() = %q (len=%d), want %q (len=%d)", got, len(got), tt.want, len(tt.want))
}
})
}
}
func TestIsValidPhoneSimple(t *testing.T) {
tests := []struct {
phone string
want bool
}{
{"13800138000", true},
{"+8613800138000", true}, // Valid: +86 prefix with 11 digit mobile
{"8613800138000", true}, // Valid: 86 prefix with 11 digit mobile
{"1234567890", false},
{"abcdefghij", false},
{"", false},
{"138001380001", false}, // 12 digits
{"1380013800", false}, // 10 digits
{"19800138000", true}, // 98 prefix
// +[1-9]\d{6,14} allows international numbers like +16171234567
{"+16171234567", true}, // 11 digits international, valid for \d{6,14}
{"+112345678901", true}, // 11 digits international, valid for \d{6,14}
}
for _, tt := range tests {
t.Run(tt.phone, func(t *testing.T) {
got := isValidPhoneSimple(tt.phone)
if got != tt.want {
t.Errorf("isValidPhoneSimple(%q) = %v, want %v", tt.phone, got, tt.want)
}
})
}
}
func TestLoginRequestGetAccount(t *testing.T) {
tests := []struct {
name string
req *LoginRequest
want string
}{
{
name: "account_field",
req: &LoginRequest{Account: "john", Username: "jane", Email: "jane@test.com"},
want: "john",
},
{
name: "username_field",
req: &LoginRequest{Username: "jane", Email: "jane@test.com"},
want: "jane",
},
{
name: "email_field",
req: &LoginRequest{Email: "jane@test.com"},
want: "jane@test.com",
},
{
name: "phone_field",
req: &LoginRequest{Phone: "13800138000"},
want: "13800138000",
},
{
name: "all_fields_with_whitespace",
req: &LoginRequest{Account: " john ", Username: " jane ", Email: " jane@test.com "},
want: "john",
},
{
name: "empty_request",
req: &LoginRequest{},
want: "",
},
{
name: "nil_request",
req: nil,
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.req.GetAccount()
if got != tt.want {
t.Errorf("GetAccount() = %q, want %q", got, tt.want)
}
})
}
}
func TestBuildDeviceFingerprint(t *testing.T) {
tests := []struct {
name string
req *LoginRequest
want string
}{
{
name: "full_device_info",
req: &LoginRequest{
DeviceID: "device123",
DeviceName: "iPhone 15",
DeviceBrowser: "Safari",
DeviceOS: "iOS 17",
},
want: "device123|iPhone 15|Safari|iOS 17",
},
{
name: "partial_device_info",
req: &LoginRequest{
DeviceID: "device123",
DeviceName: "iPhone 15",
},
want: "device123|iPhone 15",
},
{
name: "only_device_id",
req: &LoginRequest{
DeviceID: "device123",
},
want: "device123",
},
{
name: "empty_device_info",
req: &LoginRequest{},
want: "",
},
{
name: "nil_request",
req: nil,
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := buildDeviceFingerprint(tt.req)
if got != tt.want {
t.Errorf("buildDeviceFingerprint() = %q, want %q", got, tt.want)
}
})
}
}
func TestAuthServiceDefaultConfig(t *testing.T) {
// Test that default configuration is applied correctly
svc := NewAuthService(nil, nil, nil, nil, 0, 0, 0)
if svc == nil {
t.Fatal("NewAuthService returned nil")
}
// Check default password minimum length
if svc.passwordMinLength != defaultPasswordMinLen {
t.Errorf("passwordMinLength: got %d, want %d", svc.passwordMinLength, defaultPasswordMinLen)
}
// Check default max login attempts
if svc.maxLoginAttempts != 5 {
t.Errorf("maxLoginAttempts: got %d, want %d", svc.maxLoginAttempts, 5)
}
// Check default login lock duration
if svc.loginLockDuration != 15*time.Minute {
t.Errorf("loginLockDuration: got %v, want %v", svc.loginLockDuration, 15*time.Minute)
}
}
func TestAuthServiceNilSafety(t *testing.T) {
t.Run("validatePassword_nil_service", func(t *testing.T) {
var svc *AuthService
err := svc.validatePassword("Abcd1234!")
if err != nil {
t.Errorf("nil service should not error: %v", err)
}
})
t.Run("accessTokenTTL_nil_service", func(t *testing.T) {
var svc *AuthService
ttl := svc.accessTokenTTLSeconds()
if ttl != 0 {
t.Errorf("nil service should return 0: got %d", ttl)
}
})
t.Run("RefreshTokenTTL_nil_service", func(t *testing.T) {
var svc *AuthService
ttl := svc.RefreshTokenTTLSeconds()
if ttl != 0 {
t.Errorf("nil service should return 0: got %d", ttl)
}
})
t.Run("generateUniqueUsername_nil_service", func(t *testing.T) {
var svc *AuthService
username, err := svc.generateUniqueUsername(context.Background(), "testuser")
if err != nil {
t.Errorf("nil service should return username: %v", err)
}
if username != "testuser" {
t.Errorf("username: got %q, want %q", username, "testuser")
}
})
t.Run("buildUserInfo_nil_user", func(t *testing.T) {
var svc *AuthService
info := svc.buildUserInfo(nil)
if info != nil {
t.Errorf("nil user should return nil info: got %v", info)
}
})
t.Run("ensureUserActive_nil_user", func(t *testing.T) {
var svc *AuthService
err := svc.ensureUserActive(nil)
if err == nil {
t.Error("nil user should return error")
}
})
t.Run("blacklistToken_nil_service", func(t *testing.T) {
var svc *AuthService
err := svc.blacklistTokenClaims(context.Background(), "token", nil)
if err != nil {
t.Errorf("nil service should not error: %v", err)
}
})
t.Run("Logout_nil_service", func(t *testing.T) {
var svc *AuthService
err := svc.Logout(context.Background(), "user", nil)
if err != nil {
t.Errorf("nil service should not error: %v", err)
}
})
t.Run("IsTokenBlacklisted_nil_service", func(t *testing.T) {
var svc *AuthService
blacklisted := svc.IsTokenBlacklisted(context.Background(), "jti")
if blacklisted {
t.Error("nil service should not blacklist tokens")
}
})
}
func TestUserInfoFromCacheValue(t *testing.T) {
t.Run("valid_UserInfo_pointer", func(t *testing.T) {
info := &UserInfo{ID: 1, Username: "testuser"}
got, ok := userInfoFromCacheValue(info)
if !ok {
t.Error("should parse *UserInfo")
}
if got.ID != 1 || got.Username != "testuser" {
t.Errorf("got %+v, want %+v", got, info)
}
})
t.Run("valid_UserInfo_value", func(t *testing.T) {
info := UserInfo{ID: 2, Username: "testuser2"}
got, ok := userInfoFromCacheValue(info)
if !ok {
t.Error("should parse UserInfo value")
}
if got.ID != 2 || got.Username != "testuser2" {
t.Errorf("got %+v, want %+v", got, info)
}
})
t.Run("invalid_type", func(t *testing.T) {
got, ok := userInfoFromCacheValue("invalid string")
if ok || got != nil {
t.Errorf("should not parse string: ok=%v, got=%+v", ok, got)
}
})
}
func TestEnsureUserActive(t *testing.T) {
t.Run("nil_user", func(t *testing.T) {
var svc *AuthService
err := svc.ensureUserActive(nil)
if err == nil {
t.Error("nil user should error")
}
})
}
func TestAttemptCount(t *testing.T) {
tests := []struct {
name string
value interface{}
want int
}{
{"int_value", 5, 5},
{"int64_value", int64(3), 3},
{"float64_value", float64(4.0), 4},
{"string_int", "3", 0}, // strings are not converted
{"invalid_type", "abc", 0},
{"nil", nil, 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := attemptCount(tt.value)
if got != tt.want {
t.Errorf("attemptCount(%v) = %d, want %d", tt.value, got, tt.want)
}
})
}
}
func TestIncrementFailAttempts(t *testing.T) {
t.Run("nil_service", func(t *testing.T) {
var svc *AuthService
count := svc.incrementFailAttempts(context.Background(), "key")
if count != 0 {
t.Errorf("nil service should return 0, got %d", count)
}
})
t.Run("empty_key", func(t *testing.T) {
svc := NewAuthService(nil, nil, nil, nil, 8, 5, 15*time.Minute)
count := svc.incrementFailAttempts(context.Background(), "")
if count != 0 {
t.Errorf("empty key should return 0, got %d", count)
}
})
}

View File

@@ -3,9 +3,11 @@ package service
import (
"context"
"errors"
"fmt"
"time"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/pagination"
"github.com/user-management-system/internal/repository"
)
@@ -228,12 +230,14 @@ func (s *DeviceService) GetTrustedDevices(ctx context.Context, userID int64) ([]
// GetAllDevicesRequest 获取所有设备请求参数
type GetAllDevicesRequest struct {
Page int
PageSize int
Page int `form:"page"`
PageSize int `form:"page_size"`
UserID int64 `form:"user_id"`
Status int `form:"status"`
IsTrusted *bool `form:"is_trusted"`
Status *int `form:"status"` // 0-禁用, 1-激活, nil-不筛选
IsTrusted *bool `form:"is_trusted"`
Keyword string `form:"keyword"`
Cursor string `form:"cursor"` // Opaque cursor for keyset pagination
Size int `form:"size"` // Page size when using cursor mode
}
// GetAllDevices 获取所有设备(管理员用)
@@ -257,9 +261,10 @@ func (s *DeviceService) GetAllDevices(ctx context.Context, req *GetAllDevicesReq
Limit: req.PageSize,
}
// 处理状态筛选
if req.Status >= 0 {
params.Status = domain.DeviceStatus(req.Status)
// 处理状态筛选(仅当明确指定了状态时才筛选)
if req.Status != nil && (*req.Status == 0 || *req.Status == 1) {
status := domain.DeviceStatus(*req.Status)
params.Status = &status
}
// 处理信任状态筛选
@@ -270,6 +275,49 @@ func (s *DeviceService) GetAllDevices(ctx context.Context, req *GetAllDevicesReq
return s.deviceRepo.ListAll(ctx, params)
}
// GetAllDevicesCursor 游标分页获取所有设备(推荐使用)
func (s *DeviceService) GetAllDevicesCursor(ctx context.Context, req *GetAllDevicesRequest) (*CursorResult, error) {
size := pagination.ClampPageSize(req.Size)
if req.PageSize > 0 && req.Cursor == "" {
size = pagination.ClampPageSize(req.PageSize)
}
cursor, err := pagination.Decode(req.Cursor)
if err != nil {
return nil, fmt.Errorf("invalid cursor: %w", err)
}
params := &repository.ListDevicesParams{
UserID: req.UserID,
Keyword: req.Keyword,
}
if req.Status != nil && (*req.Status == 0 || *req.Status == 1) {
status := domain.DeviceStatus(*req.Status)
params.Status = &status
}
if req.IsTrusted != nil {
params.IsTrusted = req.IsTrusted
}
devices, hasMore, err := s.deviceRepo.ListAllCursor(ctx, params, size, cursor)
if err != nil {
return nil, err
}
nextCursor := ""
if len(devices) > 0 {
last := devices[len(devices)-1]
nextCursor = pagination.BuildNextCursor(last.ID, last.LastActiveTime)
}
return &CursorResult{
Items: devices,
NextCursor: nextCursor,
HasMore: hasMore,
PageSize: size,
}, nil
}
// GetDeviceByDeviceID 根据设备标识获取设备(用于设备信任检查)
func (s *DeviceService) GetDeviceByDeviceID(ctx context.Context, userID int64, deviceID string) (*domain.Device, error) {
return s.deviceRepo.GetByDeviceID(ctx, userID, deviceID)

View File

@@ -3,6 +3,7 @@ package service
import (
"context"
cryptorand "crypto/rand"
"crypto/subtle"
"encoding/hex"
"fmt"
"log"
@@ -167,7 +168,7 @@ func (s *EmailCodeService) VerifyEmailCode(ctx context.Context, email, purpose,
}
storedCode, ok := value.(string)
if !ok || storedCode != code {
if !ok || subtle.ConstantTimeCompare([]byte(storedCode), []byte(code)) != 1 {
return fmt.Errorf("verification code is invalid")
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/xuri/excelize/v2"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/pagination"
"github.com/user-management-system/internal/repository"
)
@@ -52,12 +53,15 @@ type RecordLoginRequest struct {
// ListLoginLogRequest 登录日志列表请求
type ListLoginLogRequest struct {
UserID int64 `json:"user_id"`
Status int `json:"status"`
Page int `json:"page"`
PageSize int `json:"page_size"`
StartAt string `json:"start_at"`
EndAt string `json:"end_at"`
UserID int64 `json:"user_id" form:"user_id"`
Status *int `json:"status" form:"status"` // 0-失败, 1-成功, nil-不筛选
Page int `json:"page" form:"page"`
PageSize int `json:"page_size" form:"page_size"`
StartAt string `json:"start_at" form:"start_at"`
EndAt string `json:"end_at" form:"end_at"`
// Cursor-based pagination (preferred over Page/PageSize)
Cursor string `form:"cursor"` // Opaque cursor from previous response
Size int `form:"size"` // Page size when using cursor mode
}
// GetLoginLogs 获取登录日志列表
@@ -84,14 +88,140 @@ func (s *LoginLogService) GetLoginLogs(ctx context.Context, req *ListLoginLogReq
}
}
// 按状态查询
if req.Status == 0 || req.Status == 1 {
return s.loginLogRepo.ListByStatus(ctx, req.Status, offset, req.PageSize)
// 按状态查询(仅当明确指定了状态时才筛选)
if req.Status != nil && (*req.Status == 0 || *req.Status == 1) {
return s.loginLogRepo.ListByStatus(ctx, *req.Status, offset, req.PageSize)
}
return s.loginLogRepo.List(ctx, offset, req.PageSize)
}
// CursorResult wraps cursor-based pagination response
type CursorResult struct {
Items interface{} `json:"items"`
NextCursor string `json:"next_cursor"`
HasMore bool `json:"has_more"`
PageSize int `json:"page_size"`
}
// GetLoginLogsCursor 游标分页获取登录日志列表(推荐使用)
func (s *LoginLogService) GetLoginLogsCursor(ctx context.Context, req *ListLoginLogRequest) (*CursorResult, error) {
size := pagination.ClampPageSize(req.Size)
if req.PageSize > 0 && req.Cursor == "" {
size = pagination.ClampPageSize(req.PageSize)
}
cursor, err := pagination.Decode(req.Cursor)
if err != nil {
return nil, fmt.Errorf("invalid cursor: %w", err)
}
var items interface{}
var nextCursor string
var hasMore bool
// 按用户 ID 查询
if req.UserID > 0 {
logs, hm, err := s.loginLogRepo.ListByUserIDCursor(ctx, req.UserID, size, cursor)
if err != nil {
return nil, err
}
items = logs
hasMore = hm
} else if req.StartAt != "" && req.EndAt != "" {
// Time range: fall back to offset-based for now (cursor + time range is complex)
start, err1 := time.Parse(time.RFC3339, req.StartAt)
end, err2 := time.Parse(time.RFC3339, req.EndAt)
if err1 == nil && err2 == nil {
offset := 0
logs, _, err := s.loginLogRepo.ListByTimeRange(ctx, start, end, offset, size)
if err != nil {
return nil, err
}
items = logs
if len(logs) > 0 {
last := logs[len(logs)-1]
nextCursor = pagination.BuildNextCursor(last.ID, last.CreatedAt)
hasMore = len(logs) == size
}
} else {
items = []*domain.LoginLog{}
}
} else if req.Status != nil && (*req.Status == 0 || *req.Status == 1) {
// Status filter: use ListCursor with manual status filter
logs, hm, err := s.listByStatusCursor(ctx, *req.Status, size, cursor)
if err != nil {
return nil, err
}
items = logs
hasMore = hm
} else {
// Default: full table cursor scan
logs, hm, err := s.loginLogRepo.ListCursor(ctx, size, cursor)
if err != nil {
return nil, err
}
items = logs
hasMore = hm
}
// Build next cursor from the last item
if nextCursor == "" {
switch items := items.(type) {
case []*domain.LoginLog:
if len(items) > 0 {
last := items[len(items)-1]
nextCursor = pagination.BuildNextCursor(last.ID, last.CreatedAt)
}
}
}
return &CursorResult{
Items: items,
NextCursor: nextCursor,
HasMore: hasMore,
PageSize: size,
}, nil
}
// listByStatusCursor 游标分页按状态查询(内部方法)
// Uses iterative approach: fetch from ListCursor and post-filter by status.
func (s *LoginLogService) listByStatusCursor(ctx context.Context, status int, limit int, cursor *pagination.Cursor) ([]*domain.LoginLog, bool, error) {
var logs []*domain.LoginLog
// Since LoginLogRepository doesn't have status+cursor combined,
// we use a larger batch from ListCursor and post-filter.
batchSize := limit + 1
for attempts := 0; attempts < 10; attempts++ { // max 10 pages of skipping
batch, hm, err := s.loginLogRepo.ListCursor(ctx, batchSize, cursor)
if err != nil {
return nil, false, err
}
for _, log := range batch {
if log.Status == status {
logs = append(logs, log)
if len(logs) >= limit+1 {
break
}
}
}
if len(logs) >= limit+1 || !hm || len(batch) == 0 {
break
}
// Advance cursor to end of this batch
if len(batch) > 0 {
last := batch[len(batch)-1]
cursor = &pagination.Cursor{LastID: last.ID, LastValue: last.CreatedAt}
}
}
hasMore := len(logs) > limit
if hasMore {
logs = logs[:limit]
}
return logs, hasMore, nil
}
// GetMyLoginLogs 获取当前用户的登录日志
func (s *LoginLogService) GetMyLoginLogs(ctx context.Context, userID int64, page, pageSize int) ([]*domain.LoginLog, int64, error) {
if page <= 0 {
@@ -137,26 +267,88 @@ func (s *LoginLogService) ExportLoginLogs(ctx context.Context, req *ExportLoginL
}
}
// CSV 使用流式分批导出XLSX 使用全量导出excelize 需要所有行)
if format == "csv" {
data, filename, err := s.exportLoginLogsCSVStream(ctx, req.UserID, req.Status, startAt, endAt)
if err != nil {
return nil, "", "", err
}
return data, filename, "text/csv; charset=utf-8", nil
}
logs, err := s.loginLogRepo.ListAllForExport(ctx, req.UserID, req.Status, startAt, endAt)
if err != nil {
return nil, "", "", fmt.Errorf("查询登录日志失败: %w", err)
}
filename := fmt.Sprintf("login_logs_%s.%s", time.Now().Format("20060102_150405"), format)
if format == "xlsx" {
data, err := buildLoginLogXLSXExport(logs)
if err != nil {
return nil, "", "", err
}
return data, filename, "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", nil
}
data, err := buildLoginLogCSVExport(logs)
filename := fmt.Sprintf("login_logs_%s.xlsx", time.Now().Format("20060102_150405"))
data, err := buildLoginLogXLSXExport(logs)
if err != nil {
return nil, "", "", err
}
return data, filename, "text/csv; charset=utf-8", nil
return data, filename, "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", nil
}
// exportLoginLogsCSVStream 流式导出 CSV分批处理防止 OOM
func (s *LoginLogService) exportLoginLogsCSVStream(ctx context.Context, userID int64, status int, startAt, endAt *time.Time) ([]byte, string, error) {
headers := []string{"ID", "用户ID", "登录方式", "设备ID", "IP地址", "位置", "状态", "失败原因", "时间"}
var buf bytes.Buffer
buf.Write([]byte{0xEF, 0xBB, 0xBF})
writer := csv.NewWriter(&buf)
// 写入表头
if err := writer.Write(headers); err != nil {
return nil, "", fmt.Errorf("写CSV表头失败: %w", err)
}
// 使用游标分批获取数据
cursor := int64(1<<63 - 1) // 从最大 ID 开始
batchSize := 5000
totalWritten := 0
for {
logs, hasMore, err := s.loginLogRepo.ListLogsForExportBatch(ctx, userID, status, startAt, endAt, cursor, batchSize)
if err != nil {
return nil, "", fmt.Errorf("查询登录日志失败: %w", err)
}
for _, log := range logs {
row := []string{
fmt.Sprintf("%d", log.ID),
fmt.Sprintf("%d", derefInt64(log.UserID)),
loginTypeLabel(log.LoginType),
log.DeviceID,
log.IP,
log.Location,
loginStatusLabel(log.Status),
log.FailReason,
log.CreatedAt.Format("2006-01-02 15:04:05"),
}
if err := writer.Write(row); err != nil {
return nil, "", fmt.Errorf("写CSV行失败: %w", err)
}
totalWritten++
cursor = log.ID
}
writer.Flush()
if err := writer.Error(); err != nil {
return nil, "", fmt.Errorf("CSV Flush 失败: %w", err)
}
// 如果数据量过大,提前终止
if totalWritten >= repository.ExportBatchSize {
break
}
if !hasMore || len(logs) == 0 {
break
}
}
filename := fmt.Sprintf("login_logs_%s.csv", time.Now().Format("20060102_150405"))
return buf.Bytes(), filename, nil
}
func buildLoginLogCSVExport(logs []*domain.LoginLog) ([]byte, error) {

View File

@@ -2,9 +2,11 @@ package service
import (
"context"
"fmt"
"time"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/pagination"
"github.com/user-management-system/internal/repository"
)
@@ -51,13 +53,15 @@ type RecordOperationRequest struct {
// ListOperationLogRequest 操作日志列表请求
type ListOperationLogRequest struct {
UserID int64 `json:"user_id"`
Method string `json:"method"`
Keyword string `json:"keyword"`
Page int `json:"page"`
PageSize int `json:"page_size"`
StartAt string `json:"start_at"`
EndAt string `json:"end_at"`
UserID int64 `json:"user_id" form:"user_id"`
Method string `json:"method" form:"method"`
Keyword string `json:"keyword" form:"keyword"`
Page int `json:"page" form:"page"`
PageSize int `json:"page_size" form:"page_size"`
StartAt string `json:"start_at" form:"start_at"`
EndAt string `json:"end_at" form:"end_at"`
Cursor string `form:"cursor"` // Opaque cursor for keyset pagination
Size int `form:"size"` // Page size when using cursor mode
}
// GetOperationLogs 获取操作日志列表
@@ -97,6 +101,42 @@ func (s *OperationLogService) GetOperationLogs(ctx context.Context, req *ListOpe
return s.operationLogRepo.List(ctx, offset, req.PageSize)
}
// GetOperationLogsCursor 游标分页获取操作日志列表(推荐使用)
func (s *OperationLogService) GetOperationLogsCursor(ctx context.Context, req *ListOperationLogRequest) (*CursorResult, error) {
size := pagination.ClampPageSize(req.Size)
cursor, err := pagination.Decode(req.Cursor)
if err != nil {
return nil, fmt.Errorf("invalid cursor: %w", err)
}
var items interface{}
var hasMore bool
logs, hm, err := s.operationLogRepo.ListCursor(ctx, size, cursor)
if err != nil {
return nil, err
}
items = logs
hasMore = hm
nextCursor := ""
switch items := items.(type) {
case []*domain.OperationLog:
if len(items) > 0 {
last := items[len(items)-1]
nextCursor = pagination.BuildNextCursor(last.ID, last.CreatedAt)
}
}
return &CursorResult{
Items: items,
NextCursor: nextCursor,
HasMore: hasMore,
PageSize: size,
}, nil
}
// GetMyOperationLogs 获取当前用户的操作日志
func (s *OperationLogService) GetMyOperationLogs(ctx context.Context, userID int64, page, pageSize int) ([]*domain.OperationLog, int64, error) {
if page <= 0 {

View File

@@ -3,6 +3,7 @@ package service
import (
"context"
cryptorand "crypto/rand"
"crypto/subtle"
"encoding/hex"
"errors"
"fmt"
@@ -13,6 +14,7 @@ import (
"github.com/user-management-system/internal/auth"
"github.com/user-management-system/internal/cache"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/repository"
"github.com/user-management-system/internal/security"
)
@@ -46,9 +48,10 @@ func DefaultPasswordResetConfig() *PasswordResetConfig {
}
type PasswordResetService struct {
userRepo userRepositoryInterface
cache *cache.CacheManager
config *PasswordResetConfig
userRepo userRepositoryInterface
cache *cache.CacheManager
config *PasswordResetConfig
passwordHistoryRepo *repository.PasswordHistoryRepository
}
func NewPasswordResetService(
@@ -66,6 +69,12 @@ func NewPasswordResetService(
}
}
// WithPasswordHistoryRepo 注入密码历史 repository用于重置密码时记录历史
func (s *PasswordResetService) WithPasswordHistoryRepo(repo *repository.PasswordHistoryRepository) *PasswordResetService {
s.passwordHistoryRepo = repo
return s
}
func (s *PasswordResetService) ForgotPassword(ctx context.Context, email string) error {
user, err := s.userRepo.GetByEmail(ctx, email)
if err != nil {
@@ -216,7 +225,7 @@ func (s *PasswordResetService) ResetPasswordByPhone(ctx context.Context, req *Re
}
code, ok := storedCode.(string)
if !ok || code != req.Code {
if !ok || subtle.ConstantTimeCompare([]byte(code), []byte(req.Code)) != 1 {
return errors.New("验证码不正确")
}
@@ -258,6 +267,18 @@ func (s *PasswordResetService) doResetPassword(ctx context.Context, user *domain
return err
}
// 检查密码历史防止重用近5次密码
if s.passwordHistoryRepo != nil {
histories, err := s.passwordHistoryRepo.GetByUserID(ctx, user.ID, passwordHistoryLimit)
if err == nil {
for _, h := range histories {
if auth.VerifyPassword(h.PasswordHash, newPassword) {
return errors.New("新密码不能与最近5次密码相同")
}
}
}
}
hashedPassword, err := auth.HashPassword(newPassword)
if err != nil {
return fmt.Errorf("密码加密失败: %w", err)
@@ -268,5 +289,19 @@ func (s *PasswordResetService) doResetPassword(ctx context.Context, user *domain
return fmt.Errorf("更新密码失败: %w", err)
}
// 写入密码历史记录
if s.passwordHistoryRepo != nil {
go func() {
// 使用带超时的独立 context防止 DB 写入无限等待
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = s.passwordHistoryRepo.Create(bgCtx, &domain.PasswordHistory{
UserID: user.ID,
PasswordHash: hashedPassword,
})
_ = s.passwordHistoryRepo.DeleteOldRecords(bgCtx, user.ID, passwordHistoryLimit)
}()
}
return nil
}

View File

@@ -0,0 +1,92 @@
package service
import (
"context"
)
// SystemSettings 系统设置
type SystemSettings struct {
System SystemInfo `json:"system"`
Security SecurityInfo `json:"security"`
Features FeaturesInfo `json:"features"`
}
// SystemInfo 系统信息
type SystemInfo struct {
Name string `json:"name"`
Version string `json:"version"`
Environment string `json:"environment"`
Description string `json:"description"`
}
// SecurityInfo 安全设置
type SecurityInfo struct {
PasswordMinLength int `json:"password_min_length"`
PasswordRequireUppercase bool `json:"password_require_uppercase"`
PasswordRequireLowercase bool `json:"password_require_lowercase"`
PasswordRequireNumbers bool `json:"password_require_numbers"`
PasswordRequireSymbols bool `json:"password_require_symbols"`
PasswordHistory int `json:"password_history"`
TOTPEnabled bool `json:"totp_enabled"`
LoginFailLock bool `json:"login_fail_lock"`
LoginFailThreshold int `json:"login_fail_threshold"`
LoginFailDuration int `json:"login_fail_duration"` // 分钟
SessionTimeout int `json:"session_timeout"` // 秒
DeviceTrustDuration int `json:"device_trust_duration"` // 秒
}
// FeaturesInfo 功能开关
type FeaturesInfo struct {
EmailVerification bool `json:"email_verification"`
PhoneVerification bool `json:"phone_verification"`
OAuthProviders []string `json:"oauth_providers"`
SSOEnabled bool `json:"sso_enabled"`
OperationLogEnabled bool `json:"operation_log_enabled"`
LoginLogEnabled bool `json:"login_log_enabled"`
DataExportEnabled bool `json:"data_export_enabled"`
DataImportEnabled bool `json:"data_import_enabled"`
}
// SettingsService 系统设置服务
type SettingsService struct{}
// NewSettingsService 创建系统设置服务
func NewSettingsService() *SettingsService {
return &SettingsService{}
}
// GetSettings 获取系统设置
func (s *SettingsService) GetSettings(ctx context.Context) (*SystemSettings, error) {
return &SystemSettings{
System: SystemInfo{
Name: "用户管理系统",
Version: "1.0.0",
Environment: "Production",
Description: "基于 Go + React 的现代化用户管理系统",
},
Security: SecurityInfo{
PasswordMinLength: 8,
PasswordRequireUppercase: true,
PasswordRequireLowercase: true,
PasswordRequireNumbers: true,
PasswordRequireSymbols: true,
PasswordHistory: 5,
TOTPEnabled: true,
LoginFailLock: true,
LoginFailThreshold: 5,
LoginFailDuration: 30,
SessionTimeout: 86400, // 1天
DeviceTrustDuration: 2592000, // 30天
},
Features: FeaturesInfo{
EmailVerification: true,
PhoneVerification: false,
OAuthProviders: []string{"GitHub", "Google"},
SSOEnabled: false,
OperationLogEnabled: true,
LoginLogEnabled: true,
DataExportEnabled: true,
DataImportEnabled: true,
},
}, nil
}

View File

@@ -0,0 +1,308 @@
package service_test
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/api/handler"
"github.com/user-management-system/internal/api/middleware"
"github.com/user-management-system/internal/api/router"
"github.com/user-management-system/internal/auth"
"github.com/user-management-system/internal/cache"
"github.com/user-management-system/internal/config"
"github.com/user-management-system/internal/repository"
"github.com/user-management-system/internal/service"
"github.com/user-management-system/internal/domain"
gormsqlite "gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
_ "modernc.org/sqlite"
)
// doRequest makes an HTTP request with optional body
func doRequest(method, url string, token string, body interface{}) (*http.Response, string) {
var bodyReader io.Reader
if body != nil {
jsonBytes, _ := json.Marshal(body)
bodyReader = bytes.NewReader(jsonBytes)
}
req, _ := http.NewRequest(method, url, bodyReader)
if token != "" {
req.Header.Set("Authorization", "Bearer "+token)
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, _ := client.Do(req)
bodyBytes, _ := io.ReadAll(resp.Body)
resp.Body.Close()
return resp, string(bodyBytes)
}
func doPost(url, token string, body interface{}) (*http.Response, string) {
return doRequest("POST", url, token, body)
}
func doGet(url, token string) (*http.Response, string) {
return doRequest("GET", url, token, nil)
}
func setupSettingsTestServer(t *testing.T) (*httptest.Server, *service.SettingsService, string, func()) {
gin.SetMode(gin.TestMode)
// 使用内存 SQLite
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
DriverName: "sqlite",
DSN: "file::memory:?mode=memory&cache=shared",
}), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
t.Skipf("skipping test (SQLite unavailable): %v", err)
return nil, nil, "", func() {}
}
// 自动迁移
if err := db.AutoMigrate(
&domain.User{},
&domain.Role{},
&domain.Permission{},
&domain.UserRole{},
&domain.RolePermission{},
&domain.Device{},
&domain.LoginLog{},
&domain.OperationLog{},
&domain.SocialAccount{},
&domain.Webhook{},
&domain.WebhookDelivery{},
); err != nil {
t.Fatalf("db migration failed: %v", err)
}
// 创建 JWT Manager
jwtManager, err := auth.NewJWTWithOptions(auth.JWTOptions{
HS256Secret: "test-settings-secret-key",
AccessTokenExpire: 15 * time.Minute,
RefreshTokenExpire: 7 * 24 * time.Hour,
})
if err != nil {
t.Fatalf("create jwt manager failed: %v", err)
}
// 创建缓存
l1Cache := cache.NewL1Cache()
l2Cache := cache.NewRedisCache(false)
cacheManager := cache.NewCacheManager(l1Cache, l2Cache)
// 创建 repositories
userRepo := repository.NewUserRepository(db)
roleRepo := repository.NewRoleRepository(db)
permissionRepo := repository.NewPermissionRepository(db)
userRoleRepo := repository.NewUserRoleRepository(db)
rolePermissionRepo := repository.NewRolePermissionRepository(db)
deviceRepo := repository.NewDeviceRepository(db)
loginLogRepo := repository.NewLoginLogRepository(db)
opLogRepo := repository.NewOperationLogRepository(db)
passwordHistoryRepo := repository.NewPasswordHistoryRepository(db)
// 创建 services
authSvc := service.NewAuthService(userRepo, nil, jwtManager, cacheManager, 8, 5, 15*time.Minute)
authSvc.SetRoleRepositories(userRoleRepo, roleRepo)
userSvc := service.NewUserService(userRepo, userRoleRepo, roleRepo, passwordHistoryRepo)
roleSvc := service.NewRoleService(roleRepo, rolePermissionRepo)
permSvc := service.NewPermissionService(permissionRepo)
deviceSvc := service.NewDeviceService(deviceRepo, userRepo)
loginLogSvc := service.NewLoginLogService(loginLogRepo)
opLogSvc := service.NewOperationLogService(opLogRepo)
// 创建 SettingsService
settingsService := service.NewSettingsService()
// 创建 middleware
rateLimitCfg := config.RateLimitConfig{}
rateLimitMiddleware := middleware.NewRateLimitMiddleware(rateLimitCfg)
authMiddleware := middleware.NewAuthMiddleware(
jwtManager, userRepo, userRoleRepo, roleRepo, rolePermissionRepo, permissionRepo, l1Cache,
)
authMiddleware.SetCacheManager(cacheManager)
opLogMiddleware := middleware.NewOperationLogMiddleware(opLogRepo)
// 创建 handlers
authHandler := handler.NewAuthHandler(authSvc)
userHandler := handler.NewUserHandler(userSvc)
roleHandler := handler.NewRoleHandler(roleSvc)
permHandler := handler.NewPermissionHandler(permSvc)
deviceHandler := handler.NewDeviceHandler(deviceSvc)
logHandler := handler.NewLogHandler(loginLogSvc, opLogSvc)
settingsHandler := handler.NewSettingsHandler(settingsService)
// 创建 router - 22个handler参数含 metrics+ variadic avatarHandler
r := router.NewRouter(
authHandler, userHandler, roleHandler, permHandler, deviceHandler,
logHandler, authMiddleware, rateLimitMiddleware, opLogMiddleware,
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
nil,
settingsHandler, nil,
)
engine := r.Setup()
server := httptest.NewServer(engine)
// 注册用户用于测试
resp, _ := doPost(server.URL+"/api/v1/auth/register", "", map[string]interface{}{
"username": "admintestsu",
"email": "admintestsu@test.com",
"password": "Password123!",
})
resp.Body.Close()
// 获取 token
loginResp, _ := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
"account": "admintestsu",
"password": "Password123!",
})
var result map[string]interface{}
json.NewDecoder(loginResp.Body).Decode(&result)
loginResp.Body.Close()
token := ""
if data, ok := result["data"].(map[string]interface{}); ok {
token, _ = data["access_token"].(string)
}
return server, settingsService, token, func() {
server.Close()
if sqlDB, _ := db.DB(); sqlDB != nil {
sqlDB.Close()
}
}
}
// =============================================================================
// Settings API Tests
// =============================================================================
func TestGetSettings_Success(t *testing.T) {
// 仅测试 service 层,不测试 HTTP API
svc := service.NewSettingsService()
settings, err := svc.GetSettings(context.Background())
if err != nil {
t.Fatalf("GetSettings failed: %v", err)
}
if settings.System.Name != "用户管理系统" {
t.Errorf("expected system name '用户管理系统', got '%s'", settings.System.Name)
}
}
func TestGetSettings_Unauthorized(t *testing.T) {
server, _, _, cleanup := setupSettingsTestServer(t)
defer cleanup()
req, _ := http.NewRequest("GET", server.URL+"/api/v1/admin/settings", nil)
// 不设置 Authorization header
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
// 无 token 应该返回 401
if resp.StatusCode != http.StatusUnauthorized {
t.Errorf("expected status 401, got %d", resp.StatusCode)
}
}
func TestGetSettings_ResponseStructure(t *testing.T) {
// 仅测试 service 层数据结构
svc := service.NewSettingsService()
settings, err := svc.GetSettings(context.Background())
if err != nil {
t.Fatalf("GetSettings failed: %v", err)
}
// 验证 system 字段
if settings.System.Name == "" {
t.Error("System.Name should not be empty")
}
if settings.System.Version == "" {
t.Error("System.Version should not be empty")
}
if settings.System.Environment == "" {
t.Error("System.Environment should not be empty")
}
// 验证 security 字段
if settings.Security.PasswordMinLength == 0 {
t.Error("Security.PasswordMinLength should not be zero")
}
if !settings.Security.PasswordRequireUppercase {
t.Error("Security.PasswordRequireUppercase should be true")
}
// 验证 features 字段
if !settings.Features.EmailVerification {
t.Error("Features.EmailVerification should be true")
}
if len(settings.Features.OAuthProviders) == 0 {
t.Error("Features.OAuthProviders should not be empty")
}
}
// =============================================================================
// SettingsService Unit Tests
// =============================================================================
func TestSettingsService_GetSettings(t *testing.T) {
svc := service.NewSettingsService()
settings, err := svc.GetSettings(context.Background())
if err != nil {
t.Fatalf("GetSettings failed: %v", err)
}
// 验证 system
if settings.System.Name == "" {
t.Error("System.Name should not be empty")
}
if settings.System.Version == "" {
t.Error("System.Version should not be empty")
}
// 验证 security defaults
if settings.Security.PasswordMinLength != 8 {
t.Errorf("PasswordMinLength: got %d, want 8", settings.Security.PasswordMinLength)
}
if !settings.Security.PasswordRequireUppercase {
t.Error("PasswordRequireUppercase should be true")
}
if !settings.Security.PasswordRequireLowercase {
t.Error("PasswordRequireLowercase should be true")
}
if !settings.Security.PasswordRequireNumbers {
t.Error("PasswordRequireNumbers should be true")
}
if !settings.Security.PasswordRequireSymbols {
t.Error("PasswordRequireSymbols should be true")
}
if settings.Security.PasswordHistory != 5 {
t.Errorf("PasswordHistory: got %d, want 5", settings.Security.PasswordHistory)
}
// 验证 features defaults
if !settings.Features.EmailVerification {
t.Error("EmailVerification should be true")
}
if settings.Features.DataExportEnabled != true {
t.Error("DataExportEnabled should be true")
}
}

View File

@@ -3,6 +3,7 @@ package service
import (
"context"
cryptorand "crypto/rand"
"crypto/subtle"
"encoding/json"
"fmt"
"log"
@@ -357,7 +358,7 @@ func (s *SMSCodeService) VerifyCode(ctx context.Context, phone, purpose, code st
}
stored, ok := val.(string)
if !ok || stored != code {
if !ok || subtle.ConstantTimeCompare([]byte(stored), []byte(code)) != 1 {
return fmt.Errorf("\u9a8c\u8bc1\u7801\u4e0d\u6b63\u786e")
}

View File

@@ -3,6 +3,7 @@ package service
import (
"context"
"errors"
"regexp"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/repository"
@@ -48,6 +49,11 @@ type UpdateThemeRequest struct {
// CreateTheme 创建主题
func (s *ThemeService) CreateTheme(ctx context.Context, req *CreateThemeRequest) (*domain.ThemeConfig, error) {
// 安全检查:禁止在 CustomCSS/CustomJS 中包含危险模式
if err := validateCustomCSSJS(req.CustomCSS, req.CustomJS); err != nil {
return nil, err
}
// 检查主题名称是否已存在
existing, err := s.themeRepo.GetByName(ctx, req.Name)
if err == nil && existing != nil {
@@ -84,6 +90,11 @@ func (s *ThemeService) CreateTheme(ctx context.Context, req *CreateThemeRequest)
// UpdateTheme 更新主题
func (s *ThemeService) UpdateTheme(ctx context.Context, id int64, req *UpdateThemeRequest) (*domain.ThemeConfig, error) {
// 安全检查:禁止在 CustomCSS/CustomJS 中包含危险模式
if err := validateCustomCSSJS(req.CustomCSS, req.CustomJS); err != nil {
return nil, err
}
theme, err := s.themeRepo.GetByID(ctx, id)
if err != nil {
return nil, errors.New("主题不存在")
@@ -204,3 +215,43 @@ func (s *ThemeService) clearDefaultThemes(ctx context.Context) error {
}
return nil
}
// validateCustomCSSJS 检查 CustomCSS 和 CustomJS 是否包含危险 XSS 模式
// 这不是完全净化,而是拒绝明显可造成 XSS 的模式
func validateCustomCSSJS(css, js string) error {
// 危险模式列表
dangerousPatterns := []struct {
pattern *regexp.Regexp
message string
}{
// Script 标签
{regexp.MustCompile(`(?i)<script[^>]*>.*?</script>`), "CustomJS 禁止包含 <script> 标签"},
{regexp.MustCompile(`(?i)javascript\s*:`), "CustomJS 禁止使用 javascript: 协议"},
// 事件处理器
{regexp.MustCompile(`(?i)on\w+\s*=`), "CustomJS 禁止使用事件处理器 (如 onerror, onclick)"},
// Data URL
{regexp.MustCompile(`(?i)data\s*:\s*text/html`), "禁止使用 data: URL 嵌入 HTML"},
// CSS expression (IE)
{regexp.MustCompile(`(?i)expression\s*\(`), "CustomCSS 禁止使用 CSS expression"},
// CSS 中的 javascript
{regexp.MustCompile(`(?i)url\s*\(\s*['"]?\s*javascript:`), "CustomCSS 禁止使用 javascript: URL"},
// 嵌入的 <style> 标签
{regexp.MustCompile(`(?i)<style[^>]*>.*?</style>`), "CustomCSS 禁止包含 <style> 标签"},
}
// 检查 JS
for _, p := range dangerousPatterns {
if p.pattern.MatchString(js) {
return errors.New(p.message)
}
}
// 检查 CSS
for _, p := range dangerousPatterns {
if p.pattern.MatchString(css) {
return errors.New(p.message)
}
}
return nil
}

View File

@@ -3,10 +3,13 @@ package service
import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/user-management-system/internal/auth"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/pagination"
"github.com/user-management-system/internal/repository"
)
@@ -80,11 +83,14 @@ func (s *UserService) ChangePassword(ctx context.Context, userID int64, oldPassw
}
go func() {
_ = s.passwordHistoryRepo.Create(context.Background(), &domain.PasswordHistory{
// 使用带超时的独立 context不能使用请求 ctx该 goroutine 在请求完成后仍可能运行)
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = s.passwordHistoryRepo.Create(bgCtx, &domain.PasswordHistory{
UserID: userID,
PasswordHash: newHashedPassword,
})
_ = s.passwordHistoryRepo.DeleteOldRecords(context.Background(), userID, passwordHistoryLimit)
_ = s.passwordHistoryRepo.DeleteOldRecords(bgCtx, userID, passwordHistoryLimit)
}()
}
@@ -127,6 +133,57 @@ func (s *UserService) List(ctx context.Context, offset, limit int) ([]*domain.Us
return s.userRepo.List(ctx, offset, limit)
}
// ListCursorRequest 用户游标分页请求
type ListCursorRequest struct {
Keyword string `form:"keyword"`
Status int `form:"status"` // -1=全部
RoleIDs []int64
CreatedFrom *time.Time
CreatedTo *time.Time
SortBy string // created_at, last_login_time, username
SortOrder string // asc, desc
Cursor string `form:"cursor"`
Size int `form:"size"`
}
// ListCursor 游标分页获取用户列表(推荐使用)
func (s *UserService) ListCursor(ctx context.Context, req *ListCursorRequest) (*CursorResult, error) {
size := pagination.ClampPageSize(req.Size)
cursor, err := pagination.Decode(req.Cursor)
if err != nil {
return nil, fmt.Errorf("invalid cursor: %w", err)
}
filter := &repository.AdvancedFilter{
Keyword: req.Keyword,
Status: req.Status,
RoleIDs: req.RoleIDs,
CreatedFrom: req.CreatedFrom,
CreatedTo: req.CreatedTo,
SortBy: req.SortBy,
SortOrder: req.SortOrder,
}
users, hasMore, err := s.userRepo.ListCursor(ctx, filter, size, cursor)
if err != nil {
return nil, err
}
nextCursor := ""
if len(users) > 0 {
last := users[len(users)-1]
nextCursor = pagination.BuildNextCursor(last.ID, last.CreatedAt)
}
return &CursorResult{
Items: users,
NextCursor: nextCursor,
HasMore: hasMore,
PageSize: size,
}, nil
}
// UpdateStatus 更新用户状态
func (s *UserService) UpdateStatus(ctx context.Context, id int64, status domain.UserStatus) error {
return s.userRepo.UpdateStatus(ctx, id, status)