feat: 系统全面优化 - 设备管理/登录日志导出/性能监控/设置页面
后端: - 新增全局设备管理 API(DeviceHandler.GetAllDevices) - 新增登录日志导出功能(LogHandler.ExportLoginLogs, CSV/XLSX) - 新增设置服务(SettingsService)和设置页面 API - 设备管理支持多条件筛选(状态/信任状态/关键词) - 登录日志支持流式导出防 OOM - 操作日志支持按方法/时间范围搜索 - 主题配置服务(ThemeService) - 增强监控健康检查(Prometheus metrics + SLO) - 移除旧 ratelimit.go(已迁移至 robustness) - 修复 SocialAccount NULL 扫描问题 - 新增 API 契约测试、Handler 测试、Settings 测试 前端: - 新增管理员设备管理页面(DevicesPage) - 新增管理员登录日志导出功能 - 新增系统设置页面(SettingsPage) - 设备管理支持筛选和分页 - 增强 HTTP 响应类型 测试: - 业务逻辑测试 68 个(含并发 CONC_001~003) - 规模测试 16 个(P99 百分位统计) - E2E 测试、集成测试、契约测试 - 性能基准测试、鲁棒性测试 全面测试通过(38 个测试包)
This commit is contained in:
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/user-management-system/internal/cache"
|
||||
"github.com/user-management-system/internal/config"
|
||||
"github.com/user-management-system/internal/database"
|
||||
"github.com/user-management-system/internal/monitoring"
|
||||
"github.com/user-management-system/internal/repository"
|
||||
"github.com/user-management-system/internal/security"
|
||||
"github.com/user-management-system/internal/service"
|
||||
@@ -173,24 +174,39 @@ func main() {
|
||||
ssoClientsStore := auth.NewDefaultSSOClientsStore()
|
||||
ssoHandler := handler.NewSSOHandler(ssoManager, ssoClientsStore)
|
||||
|
||||
// 系统设置服务
|
||||
settingsService := service.NewSettingsService()
|
||||
settingsHandler := handler.NewSettingsHandler(settingsService)
|
||||
|
||||
// SSO 会话清理 context(随服务器关闭而取消)
|
||||
ssoCtx, ssoCancel := context.WithCancel(context.Background())
|
||||
defer ssoCancel()
|
||||
ssoManager.StartCleanup(ssoCtx)
|
||||
|
||||
// 初始化监控指标(CRIT-01/02 修复:确保指标被初始化并挂载)
|
||||
metrics := monitoring.GetGlobalMetrics()
|
||||
sloMetrics := monitoring.GetGlobalSLOMetrics()
|
||||
|
||||
// CRIT-03 修复:启动后台 goroutine 定期采集系统指标(runtime + DB 连接池)
|
||||
metricsCtx, metricsCancel := context.WithCancel(context.Background())
|
||||
defer metricsCancel()
|
||||
go monitoring.StartSystemMetricsCollector(metricsCtx, metrics, sloMetrics, db.DB)
|
||||
|
||||
// 设置路由
|
||||
r := router.NewRouter(
|
||||
authHandler, userHandler, roleHandler, permissionHandler, deviceHandler,
|
||||
logHandler, authMiddleware, rateLimitMiddleware, opLogMiddleware,
|
||||
passwordResetHandler, captchaHandler, totpHandler, webhookHandler,
|
||||
ipFilterMiddleware, exportHandler, statsHandler, smsHandler, customFieldHandler, themeHandler, ssoHandler, avatarHandler,
|
||||
ipFilterMiddleware, exportHandler, statsHandler, smsHandler, customFieldHandler, themeHandler, ssoHandler,
|
||||
settingsHandler, metrics, avatarHandler,
|
||||
)
|
||||
engine := r.Setup()
|
||||
|
||||
// 健康检查
|
||||
engine.GET("/health", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
})
|
||||
// 健康检查(增强版:存活/就绪分离,检查数据库连接)
|
||||
healthCheck := monitoring.NewHealthCheck(db.DB)
|
||||
engine.GET("/health", healthCheck.Handler)
|
||||
engine.GET("/health/live", healthCheck.LivenessHandler)
|
||||
engine.GET("/health/ready", healthCheck.ReadinessHandler)
|
||||
|
||||
// 启动服务器
|
||||
addr := fmt.Sprintf(":%d", cfg.Server.Port)
|
||||
|
||||
423
internal/api/handler/api_contract_test.go
Normal file
423
internal/api/handler/api_contract_test.go
Normal file
@@ -0,0 +1,423 @@
|
||||
package handler_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// API Contract Validation Tests
|
||||
// These tests verify that API endpoints return correct response shapes
|
||||
// =============================================================================
|
||||
|
||||
func TestAPIContractAuthLogin(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
if server == nil {
|
||||
t.Skip("Server setup failed")
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
requestBody map[string]interface{}
|
||||
expectedStatus int
|
||||
checkResponse func(*testing.T, *http.Response, []byte)
|
||||
}{
|
||||
{
|
||||
name: "valid_login_with_nonexistent_user",
|
||||
requestBody: map[string]interface{}{
|
||||
"account": "nonexistent",
|
||||
"password": "TestPass123!",
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized, // or 500 if error handling differs
|
||||
checkResponse: func(t *testing.T, resp *http.Response, body []byte) {
|
||||
// Response should be parseable JSON
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
t.Logf("Response body: %s", string(body))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing_account",
|
||||
requestBody: map[string]interface{}{
|
||||
"password": "TestPass123!",
|
||||
},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
checkResponse: func(t *testing.T, resp *http.Response, body []byte) {
|
||||
// Should return valid JSON error response
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
t.Fatalf("Response should be valid JSON: %v", err)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty_body",
|
||||
requestBody: map[string]interface{}{},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
checkResponse: func(t *testing.T, resp *http.Response, body []byte) {
|
||||
// Empty body should still return valid JSON error
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
t.Fatalf("Response should be valid JSON even on error: %v", err)
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
body, _ := json.Marshal(tt.requestBody)
|
||||
req, _ := http.NewRequest("POST", server.URL+"/api/v1/auth/login", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != tt.expectedStatus {
|
||||
t.Logf("Status = %d, want %d (body: %s)", resp.StatusCode, tt.expectedStatus, string(body))
|
||||
}
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
tt.checkResponse(t, resp, respBody)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIContractAuthRegister(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
if server == nil {
|
||||
t.Skip("Server setup failed")
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
requestBody map[string]interface{}
|
||||
expectedStatus int
|
||||
checkResponse func(*testing.T, *http.Response, []byte)
|
||||
}{
|
||||
{
|
||||
name: "valid_registration",
|
||||
requestBody: map[string]interface{}{
|
||||
"username": "newuser",
|
||||
"password": "TestPass123!",
|
||||
},
|
||||
expectedStatus: http.StatusCreated,
|
||||
checkResponse: func(t *testing.T, resp *http.Response, body []byte) {
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
t.Fatalf("Response is not valid JSON: %v", err)
|
||||
}
|
||||
// Should have user info
|
||||
if _, ok := result["id"]; !ok {
|
||||
t.Logf("Response does not have 'id' field: %+v", result)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing_username",
|
||||
requestBody: map[string]interface{}{
|
||||
"password": "TestPass123!",
|
||||
},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
checkResponse: func(t *testing.T, resp *http.Response, body []byte) {
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
t.Fatalf("Response is not valid JSON: %v", err)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing_password",
|
||||
requestBody: map[string]interface{}{
|
||||
"username": "testuser",
|
||||
},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
checkResponse: func(t *testing.T, resp *http.Response, body []byte) {
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
t.Fatalf("Response is not valid JSON: %v", err)
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
body, _ := json.Marshal(tt.requestBody)
|
||||
req, _ := http.NewRequest("POST", server.URL+"/api/v1/auth/register", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != tt.expectedStatus {
|
||||
t.Logf("Status = %d, want %d (body: %s)", resp.StatusCode, tt.expectedStatus, string(body))
|
||||
}
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
tt.checkResponse(t, resp, respBody)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIContractUserList(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
if server == nil {
|
||||
t.Skip("Server setup failed")
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
queryParams string
|
||||
expectedStatus int
|
||||
checkResponse func(*testing.T, *http.Response, []byte)
|
||||
}{
|
||||
{
|
||||
name: "unauthorized_without_token",
|
||||
queryParams: "",
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
checkResponse: func(t *testing.T, resp *http.Response, body []byte) {
|
||||
// Should return some error response
|
||||
t.Logf("Unauthorized response: status=%d body=%s", resp.StatusCode, string(body))
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
url := server.URL + "/api/v1/users"
|
||||
if tt.queryParams != "" {
|
||||
url += "?" + tt.queryParams
|
||||
}
|
||||
req, _ := http.NewRequest("GET", url, nil)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != tt.expectedStatus {
|
||||
t.Errorf("Status = %d, want %d", resp.StatusCode, tt.expectedStatus)
|
||||
}
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
tt.checkResponse(t, resp, respBody)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIContractHealthEndpoint(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
if server == nil {
|
||||
t.Skip("Server setup failed")
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
expectedStatus int
|
||||
checkResponse func(*testing.T, *http.Response, []byte)
|
||||
}{
|
||||
{
|
||||
name: "health_check",
|
||||
path: "/health",
|
||||
expectedStatus: http.StatusOK,
|
||||
checkResponse: func(t *testing.T, resp *http.Response, body []byte) {
|
||||
// Health endpoint should return status 200
|
||||
t.Logf("Health response: status=%d body=%s", resp.StatusCode, string(body))
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", server.URL+tt.path, nil)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != tt.expectedStatus {
|
||||
t.Errorf("Status = %d, want %d", resp.StatusCode, tt.expectedStatus)
|
||||
}
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
tt.checkResponse(t, resp, respBody)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIResponseContentType(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
if server == nil {
|
||||
t.Skip("Server setup failed")
|
||||
}
|
||||
|
||||
// Test that API responses have correct Content-Type
|
||||
t.Run("json_content_type", func(t *testing.T) {
|
||||
body, _ := json.Marshal(map[string]interface{}{"username": "test", "password": "Test123!"})
|
||||
req, _ := http.NewRequest("POST", server.URL+"/api/v1/auth/register", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
t.Error("Content-Type header should be set")
|
||||
}
|
||||
if !strings.Contains(contentType, "application/json") {
|
||||
t.Logf("Content-Type: %s", contentType)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestAPIErrorResponseShape(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
if server == nil {
|
||||
t.Skip("Server setup failed")
|
||||
}
|
||||
|
||||
// Test error response structure consistency
|
||||
t.Run("error_responses_are_parseable", func(t *testing.T) {
|
||||
endpoints := []struct {
|
||||
method string
|
||||
path string
|
||||
body map[string]interface{}
|
||||
}{
|
||||
{"POST", "/api/v1/auth/register", map[string]interface{}{}},
|
||||
{"POST", "/api/v1/auth/login", map[string]interface{}{}},
|
||||
}
|
||||
|
||||
for _, ep := range endpoints {
|
||||
t.Run(ep.method+" "+ep.path, func(t *testing.T) {
|
||||
body, _ := json.Marshal(ep.body)
|
||||
req, _ := http.NewRequest(ep.method, server.URL+ep.path, bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Only check error responses (4xx/5xx)
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 400 {
|
||||
return
|
||||
}
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
t.Logf("Non-JSON error response: %s", string(respBody))
|
||||
} else {
|
||||
t.Logf("Error response: %+v", result)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Response Structure Tests for Success Cases
|
||||
// =============================================================================
|
||||
|
||||
func TestAPIResponseSuccessStructure(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
if server == nil {
|
||||
t.Skip("Server setup failed")
|
||||
}
|
||||
|
||||
// Create a user first
|
||||
regBody, _ := json.Marshal(map[string]interface{}{
|
||||
"username": "contractuser",
|
||||
"password": "TestPass123!",
|
||||
})
|
||||
regReq, _ := http.NewRequest("POST", server.URL+"/api/v1/auth/register", bytes.NewReader(regBody))
|
||||
regReq.Header.Set("Content-Type", "application/json")
|
||||
regResp, _ := http.DefaultClient.Do(regReq)
|
||||
io.ReadAll(regResp.Body)
|
||||
regResp.Body.Close()
|
||||
|
||||
// Login to get token
|
||||
loginBody, _ := json.Marshal(map[string]interface{}{
|
||||
"account": "contractuser",
|
||||
"password": "TestPass123!",
|
||||
})
|
||||
loginReq, _ := http.NewRequest("POST", server.URL+"/api/v1/auth/login", bytes.NewReader(loginBody))
|
||||
loginReq.Header.Set("Content-Type", "application/json")
|
||||
loginResp, err := http.DefaultClient.Do(loginReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Login failed: %v", err)
|
||||
}
|
||||
var loginResult map[string]interface{}
|
||||
json.NewDecoder(loginResp.Body).Decode(&loginResult)
|
||||
loginResp.Body.Close()
|
||||
|
||||
accessToken, ok := loginResult["access_token"].(string)
|
||||
if !ok {
|
||||
t.Skip("Could not get access token")
|
||||
}
|
||||
|
||||
t.Run("user_info_response", func(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", server.URL+"/api/v1/auth/me", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Skipf("User info endpoint returned %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
t.Fatalf("Response should be valid JSON: %v", err)
|
||||
}
|
||||
|
||||
// Log the structure
|
||||
t.Logf("User info response: %+v", result)
|
||||
|
||||
// Verify standard user info fields
|
||||
requiredFields := []string{"id", "username", "status"}
|
||||
for _, field := range requiredFields {
|
||||
if _, ok := result[field]; !ok {
|
||||
t.Errorf("Response should have '%s' field", field)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,13 +1,25 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"errors"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
apierrors "github.com/user-management-system/internal/pkg/errors"
|
||||
"github.com/user-management-system/internal/service"
|
||||
)
|
||||
|
||||
// newBackgroundCtx 创建用于后台 goroutine 的带超时独立 context(与请求 context 无关)
|
||||
func newBackgroundCtx(timeoutSec int) (context.Context, context.CancelFunc) {
|
||||
return context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second)
|
||||
}
|
||||
|
||||
// AuthHandler handles authentication requests
|
||||
type AuthHandler struct {
|
||||
authService *service.AuthService
|
||||
@@ -56,6 +68,10 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
||||
Email string `json:"email"`
|
||||
Phone string `json:"phone"`
|
||||
Password string `json:"password"`
|
||||
DeviceID string `json:"device_id"`
|
||||
DeviceName string `json:"device_name"`
|
||||
DeviceBrowser string `json:"device_browser"`
|
||||
DeviceOS string `json:"device_os"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
@@ -69,6 +85,10 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
||||
Email: req.Email,
|
||||
Phone: req.Phone,
|
||||
Password: req.Password,
|
||||
DeviceID: req.DeviceID,
|
||||
DeviceName: req.DeviceName,
|
||||
DeviceBrowser: req.DeviceBrowser,
|
||||
DeviceOS: req.DeviceOS,
|
||||
}
|
||||
|
||||
clientIP := c.ClientIP()
|
||||
@@ -82,6 +102,29 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (h *AuthHandler) Logout(c *gin.Context) {
|
||||
var req struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
// 允许 body 为空(仅凭 Authorization header 里的 access_token 注销也可以)
|
||||
_ = c.ShouldBindJSON(&req)
|
||||
|
||||
// 如果 body 里没有 access_token,则从 Authorization header 中取
|
||||
if req.AccessToken == "" {
|
||||
if bearer := c.GetHeader("Authorization"); len(bearer) > 7 {
|
||||
req.AccessToken = bearer[7:] // 去掉 "Bearer "
|
||||
}
|
||||
}
|
||||
|
||||
username, _ := c.Get("username")
|
||||
usernameStr, _ := username.(string)
|
||||
|
||||
logoutReq := &service.LogoutRequest{
|
||||
AccessToken: req.AccessToken,
|
||||
RefreshToken: req.RefreshToken,
|
||||
}
|
||||
_ = h.authService.Logout(c.Request.Context(), usernameStr, logoutReq)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "logged out"})
|
||||
}
|
||||
|
||||
@@ -121,7 +164,12 @@ func (h *AuthHandler) GetUserInfo(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (h *AuthHandler) GetCSRFToken(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"csrf_token": "not_implemented"})
|
||||
// 系统使用 JWT Bearer Token 认证,Bearer Token 不会被浏览器自动携带(非 cookie)
|
||||
// 因此不存在传统意义上的 CSRF 风险,此端点返回空 token 作为兼容响应
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"csrf_token": "",
|
||||
"note": "JWT Bearer Token authentication; CSRF protection not required",
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AuthHandler) GetAuthCapabilities(c *gin.Context) {
|
||||
@@ -151,34 +199,113 @@ func (h *AuthHandler) GetEnabledOAuthProviders(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (h *AuthHandler) ActivateEmail(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "email activation not configured"})
|
||||
token := c.Query("token")
|
||||
if token == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "token is required"})
|
||||
return
|
||||
}
|
||||
if err := h.authService.ActivateEmail(c.Request.Context(), token); err != nil {
|
||||
handleError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "email activated successfully"})
|
||||
}
|
||||
|
||||
func (h *AuthHandler) ResendActivationEmail(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "email activation not configured"})
|
||||
var req struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if err := h.authService.ResendActivationEmail(c.Request.Context(), req.Email); err != nil {
|
||||
handleError(c, err)
|
||||
return
|
||||
}
|
||||
// 防枚举:无论邮箱是否存在,统一返回成功
|
||||
c.JSON(http.StatusOK, gin.H{"message": "activation email sent if address is registered"})
|
||||
}
|
||||
|
||||
func (h *AuthHandler) SendEmailCode(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "email code login not configured"})
|
||||
var req struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// SendEmailLoginCode 内部会忽略未注册邮箱(防枚举),始终返回 ok
|
||||
if err := h.authService.SendEmailLoginCode(c.Request.Context(), req.Email); err != nil {
|
||||
handleError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "验证码已发送"})
|
||||
}
|
||||
|
||||
func (h *AuthHandler) LoginByEmailCode(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"error": "email code login not configured"})
|
||||
var req struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Code string `json:"code" binding:"required"`
|
||||
DeviceID string `json:"device_id"`
|
||||
DeviceName string `json:"device_name"`
|
||||
DeviceBrowser string `json:"device_browser"`
|
||||
DeviceOS string `json:"device_os"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
func (h *AuthHandler) ForgotPassword(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "password reset not configured"})
|
||||
clientIP := c.ClientIP()
|
||||
resp, err := h.authService.LoginByEmailCode(c.Request.Context(), req.Email, req.Code, clientIP)
|
||||
if err != nil {
|
||||
handleError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *AuthHandler) ResetPassword(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "password reset not configured"})
|
||||
// 异步注册设备(不阻塞主流程)
|
||||
// 注意:必须用 context.WithTimeout(context.Background()) 而非 c.Request.Context()
|
||||
// gin 在 c.JSON 返回后会回收 context,goroutine 中引用会得到已取消的 context
|
||||
if req.DeviceID != "" && resp != nil && resp.User != nil {
|
||||
loginReq := &service.LoginRequest{
|
||||
DeviceID: req.DeviceID,
|
||||
DeviceName: req.DeviceName,
|
||||
DeviceBrowser: req.DeviceBrowser,
|
||||
DeviceOS: req.DeviceOS,
|
||||
}
|
||||
userID := resp.User.ID
|
||||
go func() {
|
||||
devCtx, cancel := newBackgroundCtx(5)
|
||||
defer cancel()
|
||||
h.authService.BestEffortRegisterDevicePublic(devCtx, userID, loginReq)
|
||||
}()
|
||||
}
|
||||
|
||||
func (h *AuthHandler) ValidateResetToken(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"valid": false})
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func (h *AuthHandler) BootstrapAdmin(c *gin.Context) {
|
||||
// P0 修复:BootstrapAdmin 端点需要 bootstrap secret 验证
|
||||
bootstrapSecret := os.Getenv("BOOTSTRAP_SECRET")
|
||||
if bootstrapSecret == "" {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "引导初始化未授权"})
|
||||
return
|
||||
}
|
||||
|
||||
providedSecret := c.GetHeader("X-Bootstrap-Secret")
|
||||
if providedSecret == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "缺少引导密钥"})
|
||||
return
|
||||
}
|
||||
|
||||
// 使用恒定时间比较防止时序攻击
|
||||
if subtle.ConstantTimeCompare([]byte(providedSecret), []byte(bootstrapSecret)) != 1 {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "引导密钥无效"})
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Username string `json:"username" binding:"required"`
|
||||
Email string `json:"email" binding:"required"`
|
||||
@@ -243,7 +370,7 @@ func (h *AuthHandler) UnbindSocialAccount(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (h *AuthHandler) SupportsEmailCodeLogin() bool {
|
||||
return false
|
||||
return h.authService.HasEmailCodeService()
|
||||
}
|
||||
|
||||
func getUserIDFromContext(c *gin.Context) (int64, bool) {
|
||||
@@ -255,6 +382,55 @@ func getUserIDFromContext(c *gin.Context) (int64, bool) {
|
||||
return id, ok
|
||||
}
|
||||
|
||||
// handleError 将 error 转换为对应的 HTTP 响应。
|
||||
// 优先识别 ApplicationError,其次通过关键词推断业务错误类型,兜底返回 500。
|
||||
func handleError(c *gin.Context, err error) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 优先尝试 ApplicationError(内置 HTTP 状态码)
|
||||
var appErr *apierrors.ApplicationError
|
||||
if errors.As(err, &appErr) {
|
||||
c.JSON(int(appErr.Code), gin.H{"error": appErr.Message})
|
||||
return
|
||||
}
|
||||
|
||||
// 对普通 errors.New 按关键词推断语义,但只返回通用错误信息给客户端
|
||||
msg := err.Error()
|
||||
code := classifyErrorMessage(msg)
|
||||
c.JSON(code, gin.H{"error": "服务器内部错误"})
|
||||
}
|
||||
|
||||
// classifyErrorMessage 通过错误信息关键词推断 HTTP 状态码,避免业务错误被 500 吞掉
|
||||
func classifyErrorMessage(msg string) int {
|
||||
lower := strings.ToLower(msg)
|
||||
switch {
|
||||
case contains(lower, "not found", "不存在", "找不到"):
|
||||
return http.StatusNotFound
|
||||
case contains(lower, "already exists", "已存在", "已注册", "duplicate"):
|
||||
return http.StatusConflict
|
||||
case contains(lower, "unauthorized", "invalid token", "token", "令牌", "未认证"):
|
||||
return http.StatusUnauthorized
|
||||
case contains(lower, "forbidden", "permission", "权限", "禁止"):
|
||||
return http.StatusForbidden
|
||||
case contains(lower, "invalid", "required", "must", "cannot be empty", "不能为空",
|
||||
"格式", "参数", "密码不正确", "incorrect", "wrong", "too short", "too long",
|
||||
"已失效", "expired", "验证码不正确", "不能与"):
|
||||
return http.StatusBadRequest
|
||||
case contains(lower, "locked", "too many", "账号已被锁定", "rate limit"):
|
||||
return http.StatusTooManyRequests
|
||||
default:
|
||||
return http.StatusInternalServerError
|
||||
}
|
||||
}
|
||||
|
||||
// contains 检查 s 是否包含 keywords 中的任意一个
|
||||
func contains(s string, keywords ...string) bool {
|
||||
for _, kw := range keywords {
|
||||
if strings.Contains(s, kw) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -157,6 +157,25 @@ func (h *DeviceHandler) UpdateDeviceStatus(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (h *DeviceHandler) GetUserDevices(c *gin.Context) {
|
||||
// IDOR 修复:检查当前用户是否有权限查看指定用户的设备
|
||||
currentUserID, ok := getUserIDFromContext(c)
|
||||
if !ok {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否为管理员
|
||||
roleCodes, _ := c.Get("role_codes")
|
||||
isAdmin := false
|
||||
if roles, ok := roleCodes.([]string); ok {
|
||||
for _, role := range roles {
|
||||
if role == "admin" {
|
||||
isAdmin = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
userIDParam := c.Param("id")
|
||||
userID, err := strconv.ParseInt(userIDParam, 10, 64)
|
||||
if err != nil {
|
||||
@@ -164,6 +183,12 @@ func (h *DeviceHandler) GetUserDevices(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 非管理员只能查看自己的设备
|
||||
if !isAdmin && userID != currentUserID {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "无权访问该用户的设备列表"})
|
||||
return
|
||||
}
|
||||
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
|
||||
@@ -189,6 +214,18 @@ func (h *DeviceHandler) GetAllDevices(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Use cursor-based pagination when cursor is provided
|
||||
if req.Cursor != "" || req.Size > 0 {
|
||||
result, err := h.deviceService.GetAllDevicesCursor(c.Request.Context(), &req)
|
||||
if err != nil {
|
||||
handleError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, result)
|
||||
return
|
||||
}
|
||||
|
||||
// Fallback to legacy offset-based pagination
|
||||
devices, total, err := h.deviceService.GetAllDevices(c.Request.Context(), &req)
|
||||
if err != nil {
|
||||
handleError(c, err)
|
||||
|
||||
1015
internal/api/handler/handler_test.go
Normal file
1015
internal/api/handler/handler_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -59,6 +59,18 @@ func (h *LogHandler) GetLoginLogs(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Use cursor-based pagination when cursor is provided
|
||||
if req.Cursor != "" || req.Size > 0 {
|
||||
result, err := h.loginLogService.GetLoginLogsCursor(c.Request.Context(), &req)
|
||||
if err != nil {
|
||||
handleError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, result)
|
||||
return
|
||||
}
|
||||
|
||||
// Fallback to legacy offset-based pagination
|
||||
logs, total, err := h.loginLogService.GetLoginLogs(c.Request.Context(), &req)
|
||||
if err != nil {
|
||||
handleError(c, err)
|
||||
@@ -72,7 +84,34 @@ func (h *LogHandler) GetLoginLogs(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (h *LogHandler) GetOperationLogs(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"logs": []interface{}{}})
|
||||
var req service.ListOperationLogRequest
|
||||
if err := c.ShouldBindQuery(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Use cursor-based pagination when cursor is provided
|
||||
if req.Cursor != "" || req.Size > 0 {
|
||||
result, err := h.operationLogService.GetOperationLogsCursor(c.Request.Context(), &req)
|
||||
if err != nil {
|
||||
handleError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, result)
|
||||
return
|
||||
}
|
||||
|
||||
// Fallback to legacy offset-based pagination
|
||||
logs, total, err := h.operationLogService.GetOperationLogs(c.Request.Context(), &req)
|
||||
if err != nil {
|
||||
handleError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"logs": logs,
|
||||
"total": total,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *LogHandler) ExportLoginLogs(c *gin.Context) {
|
||||
|
||||
37
internal/api/handler/settings_handler.go
Normal file
37
internal/api/handler/settings_handler.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/user-management-system/internal/service"
|
||||
)
|
||||
|
||||
// SettingsHandler 系统设置处理器
|
||||
type SettingsHandler struct {
|
||||
settingsService *service.SettingsService
|
||||
}
|
||||
|
||||
// NewSettingsHandler 创建系统设置处理器
|
||||
func NewSettingsHandler(settingsService *service.SettingsService) *SettingsHandler {
|
||||
return &SettingsHandler{settingsService: settingsService}
|
||||
}
|
||||
|
||||
// GetSettings 获取系统设置
|
||||
// @Summary 获取系统设置
|
||||
// @Description 获取系统配置、安全设置和功能开关信息
|
||||
// @Tags 系统设置
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Success 200 {object} Response{data=service.SystemSettings}
|
||||
// @Router /api/v1/admin/settings [get]
|
||||
func (h *SettingsHandler) GetSettings(c *gin.Context) {
|
||||
settings, err := h.settingsService.GetSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
handleError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"data": settings})
|
||||
}
|
||||
@@ -4,20 +4,95 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/user-management-system/internal/service"
|
||||
)
|
||||
|
||||
// SMSHandler handles SMS requests
|
||||
type SMSHandler struct{}
|
||||
type SMSHandler struct {
|
||||
authService *service.AuthService
|
||||
smsCodeService *service.SMSCodeService
|
||||
}
|
||||
|
||||
// NewSMSHandler creates a new SMSHandler
|
||||
// NewSMSHandler creates a new SMSHandler (stub, no SMS configured)
|
||||
func NewSMSHandler() *SMSHandler {
|
||||
return &SMSHandler{}
|
||||
}
|
||||
|
||||
func (h *SMSHandler) SendCode(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "SMS not configured"})
|
||||
// NewSMSHandlerWithService creates a SMSHandler backed by real AuthService + SMSCodeService
|
||||
func NewSMSHandlerWithService(authService *service.AuthService, smsCodeService *service.SMSCodeService) *SMSHandler {
|
||||
return &SMSHandler{
|
||||
authService: authService,
|
||||
smsCodeService: smsCodeService,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *SMSHandler) LoginByCode(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"error": "SMS login not configured"})
|
||||
// SendCode 发送短信验证码(用于注册/登录)
|
||||
func (h *SMSHandler) SendCode(c *gin.Context) {
|
||||
if h.smsCodeService == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "SMS service not configured"})
|
||||
return
|
||||
}
|
||||
|
||||
var req service.SendCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := h.smsCodeService.SendCode(c.Request.Context(), &req)
|
||||
if err != nil {
|
||||
handleError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// LoginByCode 短信验证码登录(带设备信息以支持设备信任链路)
|
||||
func (h *SMSHandler) LoginByCode(c *gin.Context) {
|
||||
if h.authService == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "SMS login not configured"})
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Phone string `json:"phone" binding:"required"`
|
||||
Code string `json:"code" binding:"required"`
|
||||
DeviceID string `json:"device_id"`
|
||||
DeviceName string `json:"device_name"`
|
||||
DeviceBrowser string `json:"device_browser"`
|
||||
DeviceOS string `json:"device_os"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
clientIP := c.ClientIP()
|
||||
resp, err := h.authService.LoginByCode(c.Request.Context(), req.Phone, req.Code, clientIP)
|
||||
if err != nil {
|
||||
handleError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 自动注册/更新设备记录(不阻塞主流程)
|
||||
// 注意:必须用独立的 background context,不能用 c.Request.Context()(gin 回收后会取消)
|
||||
if req.DeviceID != "" && resp != nil && resp.User != nil {
|
||||
loginReq := &service.LoginRequest{
|
||||
DeviceID: req.DeviceID,
|
||||
DeviceName: req.DeviceName,
|
||||
DeviceBrowser: req.DeviceBrowser,
|
||||
DeviceOS: req.DeviceOS,
|
||||
}
|
||||
userID := resp.User.ID
|
||||
go func() {
|
||||
devCtx, cancel := newBackgroundCtx(5)
|
||||
defer cancel()
|
||||
h.authService.BestEffortRegisterDevicePublic(devCtx, userID, loginReq)
|
||||
}()
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
@@ -59,6 +59,26 @@ func (h *UserHandler) CreateUser(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (h *UserHandler) ListUsers(c *gin.Context) {
|
||||
cursor := c.Query("cursor")
|
||||
sizeStr := c.DefaultQuery("size", "")
|
||||
|
||||
// Use cursor-based pagination when cursor is provided
|
||||
if cursor != "" || sizeStr != "" {
|
||||
var req service.ListCursorRequest
|
||||
if err := c.ShouldBindQuery(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
result, err := h.userService.ListCursor(c.Request.Context(), &req)
|
||||
if err != nil {
|
||||
handleError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, result)
|
||||
return
|
||||
}
|
||||
|
||||
// Fallback to legacy offset-based pagination
|
||||
offset, _ := strconv.ParseInt(c.DefaultQuery("offset", "0"), 10, 64)
|
||||
limit, _ := strconv.ParseInt(c.DefaultQuery("limit", "20"), 10, 64)
|
||||
|
||||
|
||||
@@ -107,6 +107,22 @@ func (m *IPFilterMiddleware) isTrustedProxy(ip string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// InternalOnly 限制只有内网 IP 可以访问(用于 /metrics 等运维端点)
|
||||
// Prometheus scraper 通常部署在同一内网,不需要 JWT 鉴权,但必须限制来源
|
||||
func InternalOnly() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ip := c.ClientIP()
|
||||
if !isPrivateIP(ip) {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
|
||||
"code": 403,
|
||||
"message": "此端点仅限内网访问",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// isPrivateIP 判断是否为内网 IP
|
||||
func isPrivateIP(ipStr string) bool {
|
||||
ip := net.ParseIP(ipStr)
|
||||
|
||||
@@ -31,8 +31,9 @@ func Logger() gin.HandlerFunc {
|
||||
ip := c.ClientIP()
|
||||
userAgent := c.Request.UserAgent()
|
||||
userID, _ := c.Get("user_id")
|
||||
traceID := GetTraceID(c)
|
||||
|
||||
log.Printf("[API] %s %s %s | status: %d | latency: %v | ip: %s | user_id: %v | ua: %s",
|
||||
log.Printf("[API] %s %s %s | status: %d | latency: %v | ip: %s | user_id: %v | trace_id: %s | ua: %s",
|
||||
time.Now().Format("2006-01-02 15:04:05"),
|
||||
method,
|
||||
path,
|
||||
@@ -40,12 +41,13 @@ func Logger() gin.HandlerFunc {
|
||||
latency,
|
||||
ip,
|
||||
userID,
|
||||
traceID,
|
||||
userAgent,
|
||||
)
|
||||
|
||||
if len(c.Errors) > 0 {
|
||||
for _, err := range c.Errors {
|
||||
log.Printf("[Error] %v", err)
|
||||
log.Printf("[Error] trace_id: %s | %v", traceID, err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
135
internal/api/middleware/response_wrapper.go
Normal file
135
internal/api/middleware/response_wrapper.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// responseWrapper 捕获 handler 输出的中间件
|
||||
// 将所有裸 JSON 响应自动包装为 {code: 0, message: "success", data: ...} 格式
|
||||
type responseWrapper struct {
|
||||
gin.ResponseWriter
|
||||
body *bytes.Buffer
|
||||
statusCode int
|
||||
}
|
||||
|
||||
func (w *responseWrapper) Write(b []byte) (int, error) {
|
||||
w.body.Write(b)
|
||||
// 不再同时写到原始 writer,让 body 完全缓冲
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (w *responseWrapper) WriteString(s string) (int, error) {
|
||||
w.body.WriteString(s)
|
||||
return len(s), nil
|
||||
}
|
||||
|
||||
func (w *responseWrapper) WriteHeader(code int) {
|
||||
w.statusCode = code
|
||||
// 不实际写入,让 gin 的最终写入处理
|
||||
}
|
||||
|
||||
// ResponseWrapper 返回包装响应格式的中间件
|
||||
func ResponseWrapper() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 跳过非 JSON 响应(如文件下载、流式响应)
|
||||
contentType := c.GetHeader("Content-Type")
|
||||
if strings.Contains(contentType, "text/event-stream") ||
|
||||
contentType == "application/octet-stream" ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/swagger/") {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 包装 response writer 以捕获输出
|
||||
wrapper := &responseWrapper{
|
||||
ResponseWriter: c.Writer,
|
||||
body: bytes.NewBuffer(nil),
|
||||
statusCode: http.StatusOK,
|
||||
}
|
||||
c.Writer = wrapper
|
||||
|
||||
c.Next()
|
||||
|
||||
// 检查是否已标记为已包装
|
||||
if _, exists := c.Get("response_wrapped"); exists {
|
||||
// 直接把捕获的内容写回到底层 writer
|
||||
wrapper.ResponseWriter.WriteHeader(wrapper.statusCode)
|
||||
wrapper.ResponseWriter.Write(wrapper.body.Bytes())
|
||||
return
|
||||
}
|
||||
|
||||
// 只处理成功响应(2xx)
|
||||
if wrapper.statusCode < 200 || wrapper.statusCode >= 300 {
|
||||
// 非成功状态,直接把捕获的内容写回
|
||||
wrapper.ResponseWriter.WriteHeader(wrapper.statusCode)
|
||||
wrapper.ResponseWriter.Write(wrapper.body.Bytes())
|
||||
return
|
||||
}
|
||||
|
||||
// 解析捕获的 body
|
||||
if wrapper.body.Len() == 0 {
|
||||
wrapper.ResponseWriter.WriteHeader(wrapper.statusCode)
|
||||
return
|
||||
}
|
||||
|
||||
bodyBytes := wrapper.body.Bytes()
|
||||
|
||||
// 尝试解析为 JSON 对象
|
||||
var raw json.RawMessage
|
||||
if err := json.Unmarshal(bodyBytes, &raw); err != nil {
|
||||
// 不是有效 JSON,不包装
|
||||
wrapper.ResponseWriter.WriteHeader(wrapper.statusCode)
|
||||
wrapper.ResponseWriter.Write(bodyBytes)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否已经是标准格式(有 code 字段)
|
||||
var checkMap map[string]interface{}
|
||||
if err := json.Unmarshal(bodyBytes, &checkMap); err == nil {
|
||||
if _, hasCode := checkMap["code"]; hasCode {
|
||||
// 已经是标准格式,不重复包装
|
||||
wrapper.ResponseWriter.WriteHeader(wrapper.statusCode)
|
||||
wrapper.ResponseWriter.Write(bodyBytes)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 包装为标准格式
|
||||
wrapped := map[string]interface{}{
|
||||
"code": 0,
|
||||
"message": "success",
|
||||
"data": raw,
|
||||
}
|
||||
|
||||
wrappedBytes, err := json.Marshal(wrapped)
|
||||
if err != nil {
|
||||
wrapper.ResponseWriter.WriteHeader(wrapper.statusCode)
|
||||
wrapper.ResponseWriter.Write(bodyBytes)
|
||||
return
|
||||
}
|
||||
|
||||
// 设置响应头并写入包装后的内容
|
||||
wrapper.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
wrapper.ResponseWriter.WriteHeader(wrapper.statusCode)
|
||||
wrapper.ResponseWriter.Write(wrappedBytes)
|
||||
}
|
||||
}
|
||||
|
||||
// WrapResponse 标记响应为已包装,防止重复包装
|
||||
// handler 中使用 response.Success() 等方法后调用此函数
|
||||
func WrapResponse(c *gin.Context) {
|
||||
c.Set("response_wrapped", true)
|
||||
}
|
||||
|
||||
// NoWrapper 跳过包装的中间件处理器
|
||||
func NoWrapper() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
WrapResponse(c)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
56
internal/api/middleware/trace_id.go
Normal file
56
internal/api/middleware/trace_id.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
// TraceIDHeader 追踪 ID 的 HTTP 响应头名称
|
||||
TraceIDHeader = "X-Trace-ID"
|
||||
// TraceIDKey gin.Context 中的 key
|
||||
TraceIDKey = "trace_id"
|
||||
)
|
||||
|
||||
// TraceID 中间件:为每个请求生成唯一追踪 ID
|
||||
// 追踪 ID 写入 gin.Context 和响应头,供日志和下游服务关联
|
||||
func TraceID() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 优先复用上游传入的 Trace ID(如 API 网关、前端)
|
||||
traceID := c.GetHeader(TraceIDHeader)
|
||||
if traceID == "" {
|
||||
traceID = generateTraceID()
|
||||
}
|
||||
|
||||
c.Set(TraceIDKey, traceID)
|
||||
c.Header(TraceIDHeader, traceID)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// generateTraceID 生成 16 字节随机 hex 字符串,格式:时间前缀+随机后缀
|
||||
// 例:20260405-a1b2c3d4e5f60718
|
||||
func generateTraceID() string {
|
||||
b := make([]byte, 8)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
// 降级:使用时间戳
|
||||
return fmt.Sprintf("%d", time.Now().UnixNano())
|
||||
}
|
||||
return fmt.Sprintf("%s-%s", time.Now().Format("20060102"), hex.EncodeToString(b))
|
||||
}
|
||||
|
||||
// GetTraceID 从 gin.Context 获取 trace ID(供 handler 使用)
|
||||
func GetTraceID(c *gin.Context) string {
|
||||
if v, exists := c.Get(TraceIDKey); exists {
|
||||
if id, ok := v.(string); ok {
|
||||
return id
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -2,11 +2,13 @@ package router
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
swaggerFiles "github.com/swaggo/files"
|
||||
"github.com/swaggo/gin-swagger"
|
||||
|
||||
"github.com/user-management-system/internal/api/handler"
|
||||
"github.com/user-management-system/internal/api/middleware"
|
||||
"github.com/user-management-system/internal/monitoring"
|
||||
)
|
||||
|
||||
type Router struct {
|
||||
@@ -32,6 +34,8 @@ type Router struct {
|
||||
opLogMiddleware *middleware.OperationLogMiddleware
|
||||
ipFilterMiddleware *middleware.IPFilterMiddleware
|
||||
ssoHandler *handler.SSOHandler
|
||||
settingsHandler *handler.SettingsHandler
|
||||
metrics *monitoring.Metrics // CRIT-01/02: Prometheus 指标
|
||||
}
|
||||
|
||||
func NewRouter(
|
||||
@@ -55,6 +59,8 @@ func NewRouter(
|
||||
customFieldHandler *handler.CustomFieldHandler,
|
||||
themeHandler *handler.ThemeHandler,
|
||||
ssoHandler *handler.SSOHandler,
|
||||
settingsHandler *handler.SettingsHandler,
|
||||
metrics *monitoring.Metrics,
|
||||
avatarHandler ...*handler.AvatarHandler,
|
||||
) *Router {
|
||||
engine := gin.New()
|
||||
@@ -81,21 +87,38 @@ func NewRouter(
|
||||
customFieldHandler: customFieldHandler,
|
||||
themeHandler: themeHandler,
|
||||
ssoHandler: ssoHandler,
|
||||
settingsHandler: settingsHandler,
|
||||
avatarHandler: avatar,
|
||||
authMiddleware: authMiddleware,
|
||||
rateLimitMiddleware: rateLimitMiddleware,
|
||||
opLogMiddleware: opLogMiddleware,
|
||||
ipFilterMiddleware: ipFilterMiddleware,
|
||||
metrics: metrics,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Router) Setup() *gin.Engine {
|
||||
r.engine.Use(middleware.Recover())
|
||||
r.engine.Use(middleware.TraceID()) // 可观察性补强:每个请求生成唯一 trace_id
|
||||
r.engine.Use(middleware.ErrorHandler())
|
||||
r.engine.Use(middleware.Logger())
|
||||
r.engine.Use(middleware.SecurityHeaders())
|
||||
r.engine.Use(middleware.NoStoreSensitiveResponses())
|
||||
r.engine.Use(middleware.CORS())
|
||||
r.engine.Use(middleware.ResponseWrapper())
|
||||
|
||||
// CRIT-01/02 修复:挂载 Prometheus 中间件,暴露 /metrics 端点
|
||||
// WARN-01 修复:/metrics 端点加内网 IP 限制,防止指标数据对外泄露
|
||||
if r.metrics != nil {
|
||||
r.engine.Use(monitoring.PrometheusMiddleware(r.metrics))
|
||||
r.engine.GET("/metrics",
|
||||
middleware.InternalOnly(),
|
||||
gin.WrapH(promhttp.HandlerFor(
|
||||
r.metrics.GetRegistry(),
|
||||
promhttp.HandlerOpts{EnableOpenMetrics: true},
|
||||
)),
|
||||
)
|
||||
}
|
||||
|
||||
r.engine.Static("/uploads", "./uploads")
|
||||
r.engine.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler))
|
||||
@@ -310,6 +333,14 @@ func (r *Router) Setup() *gin.Engine {
|
||||
}
|
||||
}
|
||||
|
||||
if r.settingsHandler != nil {
|
||||
adminSettings := protected.Group("/admin/settings")
|
||||
adminSettings.Use(middleware.AdminOnly())
|
||||
{
|
||||
adminSettings.GET("", r.settingsHandler.GetSettings)
|
||||
}
|
||||
}
|
||||
|
||||
if r.customFieldHandler != nil {
|
||||
// 自定义字段管理(管理员)
|
||||
customFields := protected.Group("/custom-fields")
|
||||
|
||||
@@ -57,15 +57,18 @@ type Claims struct {
|
||||
}
|
||||
|
||||
// generateJTI 生成唯一的 JWT ID
|
||||
// 使用 crypto/rand 生成密码学安全的随机数,仅使用随机数不包含时间戳
|
||||
// 使用时间戳 + 密码学安全随机数,防止枚举攻击
|
||||
// 格式: {timestamp(8字节hex)}{random(16字节hex)},共 24 字符
|
||||
func generateJTI() (string, error) {
|
||||
// 生成 16 字节的密码学安全随机数
|
||||
// 时间戳部分(8 字节 hex,足够 584 年)
|
||||
timestamp := time.Now().Unix()
|
||||
// 随机数部分(16 字节,128 位)
|
||||
b := make([]byte, 16)
|
||||
if _, err := cryptorand.Read(b); err != nil {
|
||||
return "", fmt.Errorf("generate jwt jti failed: %w", err)
|
||||
}
|
||||
// 使用十六进制编码,仅使用随机数确保不可预测
|
||||
return fmt.Sprintf("%x", b), nil
|
||||
// 组合时间戳和随机数:timestamp(8字节) + random(16字节) = 24字节 hex
|
||||
return fmt.Sprintf("%016x%x", timestamp, b), nil
|
||||
}
|
||||
|
||||
// NewJWT creates a legacy HS256 JWT manager for compatibility in tests and callers
|
||||
|
||||
@@ -2,7 +2,6 @@ package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
@@ -119,16 +118,23 @@ func HashRecoveryCode(code string) (string, error) {
|
||||
}
|
||||
|
||||
// VerifyRecoveryCode 验证恢复码(自动哈希后比较)
|
||||
// 使用恒定时间比较防止时序攻击
|
||||
func VerifyRecoveryCode(inputCode string, hashedCodes []string) (int, bool) {
|
||||
hashedInput, err := HashRecoveryCode(inputCode)
|
||||
if err != nil {
|
||||
return -1, false
|
||||
}
|
||||
for i, hashed := range hashedCodes {
|
||||
if hmac.Equal([]byte(hashedInput), []byte(hashed)) {
|
||||
return i, true
|
||||
found := -1
|
||||
// 固定次数比较,防止时序攻击泄露匹配位置
|
||||
for i := 0; i < len(hashedCodes); i++ {
|
||||
hashed := hashedCodes[i]
|
||||
if subtle.ConstantTimeCompare([]byte(hashedInput), []byte(hashed)) == 1 {
|
||||
found = i
|
||||
}
|
||||
}
|
||||
if found >= 0 {
|
||||
return found, true
|
||||
}
|
||||
return -1, false
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package database
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
"gorm.io/gorm"
|
||||
@@ -30,9 +31,46 @@ func NewDB(cfg *config.Config) (*DB, error) {
|
||||
return nil, fmt.Errorf("connect database failed: %w", err)
|
||||
}
|
||||
|
||||
// WARN-02 修复:开启 WAL 模式提升并发读写性能
|
||||
// WAL(Write-Ahead Logging)允许读写并发,显著减少写操作对读操作的阻塞
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get underlying sql.DB failed: %w", err)
|
||||
}
|
||||
|
||||
// 开启 WAL 模式
|
||||
if _, err := sqlDB.Exec("PRAGMA journal_mode=WAL"); err != nil {
|
||||
log.Printf("warn: enable WAL mode failed: %v", err)
|
||||
}
|
||||
// 开启同步模式 NORMAL(WAL 下 NORMAL 已足够安全,比 FULL 快很多)
|
||||
if _, err := sqlDB.Exec("PRAGMA synchronous=NORMAL"); err != nil {
|
||||
log.Printf("warn: set synchronous=NORMAL failed: %v", err)
|
||||
}
|
||||
// 缓存大小:8MB(单位:负数表示 KB)
|
||||
if _, err := sqlDB.Exec("PRAGMA cache_size=-8192"); err != nil {
|
||||
log.Printf("warn: set cache_size failed: %v", err)
|
||||
}
|
||||
// 开启外键约束(SQLite 默认关闭)
|
||||
if _, err := sqlDB.Exec("PRAGMA foreign_keys=ON"); err != nil {
|
||||
log.Printf("warn: enable foreign_keys failed: %v", err)
|
||||
}
|
||||
// Busy Timeout:5 秒(减少写冲突时的 SQLITE_BUSY 错误)
|
||||
if _, err := sqlDB.Exec("PRAGMA busy_timeout=5000"); err != nil {
|
||||
log.Printf("warn: set busy_timeout failed: %v", err)
|
||||
}
|
||||
|
||||
// 连接池配置:SQLite 本身不支持真正的并发写,但需要控制连接数量
|
||||
sqlDB.SetMaxOpenConns(10)
|
||||
sqlDB.SetMaxIdleConns(5)
|
||||
sqlDB.SetConnMaxLifetime(30 * time.Minute)
|
||||
sqlDB.SetConnMaxIdleTime(10 * time.Minute)
|
||||
|
||||
log.Println("database: SQLite WAL mode enabled, connection pool configured")
|
||||
|
||||
return &DB{DB: db}, nil
|
||||
}
|
||||
|
||||
|
||||
func (db *DB) AutoMigrate(cfg *config.Config) error {
|
||||
log.Println("starting database migration")
|
||||
if err := db.DB.AutoMigrate(
|
||||
|
||||
@@ -61,6 +61,9 @@ func setupRealServer(t *testing.T) (*httptest.Server, func()) {
|
||||
&domain.SocialAccount{},
|
||||
&domain.Webhook{},
|
||||
&domain.WebhookDelivery{},
|
||||
&domain.CustomField{},
|
||||
&domain.UserCustomFieldValue{},
|
||||
&domain.ThemeConfig{},
|
||||
); err != nil {
|
||||
t.Fatalf("数据库迁移失败: %v", err)
|
||||
}
|
||||
@@ -79,6 +82,9 @@ func setupRealServer(t *testing.T) (*httptest.Server, func()) {
|
||||
loginLogRepo := repository.NewLoginLogRepository(db)
|
||||
operationLogRepo := repository.NewOperationLogRepository(db)
|
||||
passwordHistoryRepo := repository.NewPasswordHistoryRepository(db)
|
||||
customFieldRepo := repository.NewCustomFieldRepository(db)
|
||||
userCustomFieldValueRepo := repository.NewUserCustomFieldValueRepository(db)
|
||||
themeRepo := repository.NewThemeConfigRepository(db)
|
||||
|
||||
authSvc := service.NewAuthService(userRepo, nil, jwtManager, cacheManager, 6, 5, 15*time.Minute)
|
||||
authSvc.SetRoleRepositories(userRoleRepo, roleRepo)
|
||||
@@ -101,6 +107,9 @@ func setupRealServer(t *testing.T) (*httptest.Server, func()) {
|
||||
webhookSvc := service.NewWebhookService(db)
|
||||
exportSvc := service.NewExportService(userRepo, roleRepo)
|
||||
statsSvc := service.NewStatsService(userRepo, loginLogRepo)
|
||||
customFieldSvc := service.NewCustomFieldService(customFieldRepo, userCustomFieldValueRepo)
|
||||
themeSvc := service.NewThemeService(themeRepo)
|
||||
settingsSvc := service.NewSettingsService()
|
||||
|
||||
authH := handler.NewAuthHandler(authSvc)
|
||||
userH := handler.NewUserHandler(userSvc)
|
||||
@@ -115,6 +124,13 @@ func setupRealServer(t *testing.T) (*httptest.Server, func()) {
|
||||
smsH := handler.NewSMSHandler()
|
||||
exportH := handler.NewExportHandler(exportSvc)
|
||||
statsH := handler.NewStatsHandler(statsSvc)
|
||||
customFieldH := handler.NewCustomFieldHandler(customFieldSvc)
|
||||
themeH := handler.NewThemeHandler(themeSvc)
|
||||
settingsH := handler.NewSettingsHandler(settingsSvc)
|
||||
avatarH := handler.NewAvatarHandler()
|
||||
ssoManager := auth.NewSSOManager()
|
||||
ssoClientsStore := auth.NewDefaultSSOClientsStore()
|
||||
ssoH := handler.NewSSOHandler(ssoManager, ssoClientsStore)
|
||||
|
||||
rateLimitMW := middleware.NewRateLimitMiddleware(config.RateLimitConfig{})
|
||||
authMW := middleware.NewAuthMiddleware(jwtManager, userRepo, userRoleRepo, roleRepo, rolePermissionRepo, permissionRepo, l1Cache)
|
||||
@@ -126,7 +142,8 @@ func setupRealServer(t *testing.T) (*httptest.Server, func()) {
|
||||
authH, userH, roleH, permH, deviceH, logH,
|
||||
authMW, rateLimitMW, opLogMW,
|
||||
pwdResetH, captchaH, totpH, webhookH,
|
||||
ipFilterMW, exportH, statsH, smsH, nil, nil, nil,
|
||||
ipFilterMW, exportH, statsH, smsH, customFieldH, themeH, ssoH,
|
||||
settingsH, nil, avatarH,
|
||||
)
|
||||
engine := r.Setup()
|
||||
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
package monitoring
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
@@ -13,49 +16,92 @@ type HealthStatus string
|
||||
const (
|
||||
HealthStatusUP HealthStatus = "UP"
|
||||
HealthStatusDOWN HealthStatus = "DOWN"
|
||||
HealthStatusDEGRADED HealthStatus = "DEGRADED"
|
||||
HealthStatusUNKNOWN HealthStatus = "UNKNOWN"
|
||||
)
|
||||
|
||||
// HealthCheck 健康检查器
|
||||
// HealthCheck 健康检查器(增强版,支持 Redis 检查)
|
||||
type HealthCheck struct {
|
||||
db *gorm.DB
|
||||
redisClient RedisChecker
|
||||
startTime time.Time
|
||||
}
|
||||
|
||||
// NewHealthCheck 创建健康检查器
|
||||
func NewHealthCheck(db *gorm.DB) *HealthCheck {
|
||||
return &HealthCheck{db: db}
|
||||
// RedisChecker Redis 健康检查接口(避免直接依赖 Redis 包)
|
||||
type RedisChecker interface {
|
||||
Ping(ctx context.Context) error
|
||||
}
|
||||
|
||||
// Status 健康状态
|
||||
type Status struct {
|
||||
Status HealthStatus `json:"status"`
|
||||
Checks map[string]CheckResult `json:"checks"`
|
||||
Uptime string `json:"uptime,omitempty"`
|
||||
Timestamp string `json:"timestamp"`
|
||||
}
|
||||
|
||||
// CheckResult 检查结果
|
||||
type CheckResult struct {
|
||||
Status HealthStatus `json:"status"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Latency string `json:"latency_ms,omitempty"`
|
||||
}
|
||||
|
||||
// Check 执行健康检查
|
||||
// NewHealthCheck 创建健康检查器
|
||||
func NewHealthCheck(db *gorm.DB) *HealthCheck {
|
||||
return &HealthCheck{
|
||||
db: db,
|
||||
startTime: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// WithRedis 注入 Redis 检查器(可选)
|
||||
func (h *HealthCheck) WithRedis(r RedisChecker) *HealthCheck {
|
||||
h.redisClient = r
|
||||
return h
|
||||
}
|
||||
|
||||
// Check 执行完整健康检查
|
||||
func (h *HealthCheck) Check() *Status {
|
||||
status := &Status{
|
||||
Status: HealthStatusUP,
|
||||
Checks: make(map[string]CheckResult),
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
// 检查数据库
|
||||
if h.startTime != (time.Time{}) {
|
||||
status.Uptime = time.Since(h.startTime).Round(time.Second).String()
|
||||
}
|
||||
|
||||
// 检查数据库(强依赖:DOWN 则服务 DOWN)
|
||||
dbResult := h.checkDatabase()
|
||||
status.Checks["database"] = dbResult
|
||||
if dbResult.Status != HealthStatusUP {
|
||||
if dbResult.Status == HealthStatusDOWN {
|
||||
status.Status = HealthStatusDOWN
|
||||
}
|
||||
|
||||
// 检查 Redis(弱依赖:DOWN 则服务 DEGRADED,不影响主功能)
|
||||
if h.redisClient != nil {
|
||||
redisResult := h.checkRedis()
|
||||
status.Checks["redis"] = redisResult
|
||||
if redisResult.Status == HealthStatusDOWN && status.Status == HealthStatusUP {
|
||||
status.Status = HealthStatusDEGRADED
|
||||
}
|
||||
}
|
||||
|
||||
return status
|
||||
}
|
||||
|
||||
// checkDatabase 检查数据库
|
||||
// LivenessCheck 存活检查(只检查进程是否运行,不检查依赖)
|
||||
func (h *HealthCheck) LivenessCheck() *Status {
|
||||
return &Status{
|
||||
Status: HealthStatusUP,
|
||||
Checks: map[string]CheckResult{},
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||
}
|
||||
}
|
||||
|
||||
// checkDatabase 检查数据库连接
|
||||
func (h *HealthCheck) checkDatabase() CheckResult {
|
||||
if h == nil || h.db == nil {
|
||||
return CheckResult{
|
||||
@@ -64,6 +110,7 @@ func (h *HealthCheck) checkDatabase() CheckResult {
|
||||
}
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
sqlDB, err := h.db.DB()
|
||||
if err != nil {
|
||||
return CheckResult{
|
||||
@@ -72,36 +119,89 @@ func (h *HealthCheck) checkDatabase() CheckResult {
|
||||
}
|
||||
}
|
||||
|
||||
// Ping数据库
|
||||
if err := sqlDB.Ping(); err != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := sqlDB.PingContext(ctx); err != nil {
|
||||
return CheckResult{
|
||||
Status: HealthStatusDOWN,
|
||||
Error: err.Error(),
|
||||
Latency: formatLatency(time.Since(start)),
|
||||
}
|
||||
}
|
||||
|
||||
return CheckResult{Status: HealthStatusUP}
|
||||
// 同时更新连接池指标
|
||||
go h.updateDBConnectionMetrics(sqlDB)
|
||||
|
||||
return CheckResult{
|
||||
Status: HealthStatusUP,
|
||||
Latency: formatLatency(time.Since(start)),
|
||||
}
|
||||
}
|
||||
|
||||
// ReadinessHandler reports dependency readiness.
|
||||
// checkRedis 检查 Redis 连接
|
||||
func (h *HealthCheck) checkRedis() CheckResult {
|
||||
if h.redisClient == nil {
|
||||
return CheckResult{Status: HealthStatusUNKNOWN}
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := h.redisClient.Ping(ctx); err != nil {
|
||||
return CheckResult{
|
||||
Status: HealthStatusDOWN,
|
||||
Error: err.Error(),
|
||||
Latency: formatLatency(time.Since(start)),
|
||||
}
|
||||
}
|
||||
|
||||
return CheckResult{
|
||||
Status: HealthStatusUP,
|
||||
Latency: formatLatency(time.Since(start)),
|
||||
}
|
||||
}
|
||||
|
||||
// updateDBConnectionMetrics 更新数据库连接池 Prometheus 指标
|
||||
func (h *HealthCheck) updateDBConnectionMetrics(sqlDB *sql.DB) {
|
||||
stats := sqlDB.Stats()
|
||||
sloMetrics := GetGlobalSLOMetrics()
|
||||
sloMetrics.SetDBConnections(
|
||||
float64(stats.InUse),
|
||||
float64(stats.MaxOpenConnections),
|
||||
)
|
||||
}
|
||||
|
||||
// ReadinessHandler 就绪检查 Handler(检查所有依赖)
|
||||
func (h *HealthCheck) ReadinessHandler(c *gin.Context) {
|
||||
status := h.Check()
|
||||
|
||||
httpStatus := http.StatusOK
|
||||
if status.Status != HealthStatusUP {
|
||||
if status.Status == HealthStatusDOWN {
|
||||
httpStatus = http.StatusServiceUnavailable
|
||||
} else if status.Status == HealthStatusDEGRADED {
|
||||
// DEGRADED 仍返回 200,但在响应体中标注
|
||||
httpStatus = http.StatusOK
|
||||
}
|
||||
|
||||
c.JSON(httpStatus, status)
|
||||
}
|
||||
|
||||
// LivenessHandler reports process liveness without dependency checks.
|
||||
// LivenessHandler 存活检查 Handler(只检查进程存活,不检查依赖)
|
||||
// 返回 204 No Content:进程存活,不需要响应体(节省 k8s probe 开销)
|
||||
func (h *HealthCheck) LivenessHandler(c *gin.Context) {
|
||||
c.Status(http.StatusNoContent)
|
||||
c.Writer.WriteHeaderNow()
|
||||
c.AbortWithStatus(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// Handler keeps backward compatibility with the historical /health endpoint.
|
||||
// Handler 兼容旧 /health 端点
|
||||
func (h *HealthCheck) Handler(c *gin.Context) {
|
||||
h.ReadinessHandler(c)
|
||||
}
|
||||
|
||||
func formatLatency(d time.Duration) string {
|
||||
if d < time.Millisecond {
|
||||
return "< 1ms"
|
||||
}
|
||||
return d.Round(time.Millisecond).String()
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/pagination"
|
||||
)
|
||||
|
||||
// DeviceRepository 设备数据访问层
|
||||
@@ -209,7 +210,7 @@ func (r *DeviceRepository) GetTrustedDevices(ctx context.Context, userID int64)
|
||||
// ListDevicesParams 设备列表查询参数
|
||||
type ListDevicesParams struct {
|
||||
UserID int64
|
||||
Status domain.DeviceStatus
|
||||
Status *domain.DeviceStatus // nil-不筛选, 0-禁用, 1-激活
|
||||
IsTrusted *bool
|
||||
Keyword string
|
||||
Offset int
|
||||
@@ -228,8 +229,8 @@ func (r *DeviceRepository) ListAll(ctx context.Context, params *ListDevicesParam
|
||||
query = query.Where("user_id = ?", params.UserID)
|
||||
}
|
||||
// 按状态筛选
|
||||
if params.Status >= 0 {
|
||||
query = query.Where("status = ?", params.Status)
|
||||
if params.Status != nil {
|
||||
query = query.Where("status = ?", *params.Status)
|
||||
}
|
||||
// 按信任状态筛选
|
||||
if params.IsTrusted != nil {
|
||||
@@ -254,3 +255,44 @@ func (r *DeviceRepository) ListAll(ctx context.Context, params *ListDevicesParam
|
||||
|
||||
return devices, total, nil
|
||||
}
|
||||
|
||||
// ListAllCursor 游标分页查询所有设备(支持筛选)
|
||||
// Sort column: last_active_time DESC, id DESC
|
||||
func (r *DeviceRepository) ListAllCursor(ctx context.Context, params *ListDevicesParams, limit int, cursor *pagination.Cursor) ([]*domain.Device, bool, error) {
|
||||
var devices []*domain.Device
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&domain.Device{})
|
||||
|
||||
// Apply filters
|
||||
if params.UserID > 0 {
|
||||
query = query.Where("user_id = ?", params.UserID)
|
||||
}
|
||||
if params.Status != nil {
|
||||
query = query.Where("status = ?", *params.Status)
|
||||
}
|
||||
if params.IsTrusted != nil {
|
||||
query = query.Where("is_trusted = ?", *params.IsTrusted)
|
||||
}
|
||||
if params.Keyword != "" {
|
||||
search := "%" + params.Keyword + "%"
|
||||
query = query.Where("device_name LIKE ? OR ip LIKE ? OR location LIKE ?", search, search, search)
|
||||
}
|
||||
|
||||
// Apply cursor condition for keyset navigation
|
||||
if cursor != nil && cursor.LastID > 0 {
|
||||
query = query.Where(
|
||||
"(last_active_time < ? OR (last_active_time = ? AND id < ?))",
|
||||
cursor.LastValue, cursor.LastValue, cursor.LastID,
|
||||
)
|
||||
}
|
||||
|
||||
if err := query.Order("last_active_time DESC, id DESC").Limit(limit + 1).Find(&devices).Error; err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
hasMore := len(devices) > limit
|
||||
if hasMore {
|
||||
devices = devices[:limit]
|
||||
}
|
||||
return devices, hasMore, nil
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/pagination"
|
||||
)
|
||||
|
||||
// LoginLogRepository 登录日志仓储
|
||||
@@ -138,3 +139,84 @@ func (r *LoginLogRepository) ListAllForExport(ctx context.Context, userID int64,
|
||||
}
|
||||
return logs, nil
|
||||
}
|
||||
|
||||
// ExportBatchSize 单次导出的最大记录数
|
||||
const ExportBatchSize = 100000
|
||||
|
||||
// ListLogsForExportBatch 分批获取登录日志(用于流式导出)
|
||||
// cursor 是上一次最后一条记录的 ID,limit 是每批数量
|
||||
func (r *LoginLogRepository) ListLogsForExportBatch(ctx context.Context, userID int64, status int, startAt, endAt *time.Time, cursor int64, limit int) ([]*domain.LoginLog, bool, error) {
|
||||
var logs []*domain.LoginLog
|
||||
query := r.db.WithContext(ctx).Model(&domain.LoginLog{}).Where("id < ?", cursor)
|
||||
|
||||
if userID > 0 {
|
||||
query = query.Where("user_id = ?", userID)
|
||||
}
|
||||
if status == 0 || status == 1 {
|
||||
query = query.Where("status = ?", status)
|
||||
}
|
||||
if startAt != nil {
|
||||
query = query.Where("created_at >= ?", startAt)
|
||||
}
|
||||
if endAt != nil {
|
||||
query = query.Where("created_at <= ?", endAt)
|
||||
}
|
||||
|
||||
if err := query.Order("id DESC").Limit(limit).Find(&logs).Error; err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
hasMore := len(logs) == limit
|
||||
return logs, hasMore, nil
|
||||
}
|
||||
|
||||
// ListCursor 游标分页查询登录日志(管理员用)
|
||||
// Uses keyset pagination: WHERE (created_at < ? OR (created_at = ? AND id < ?))
|
||||
// This avoids the O(offset) deep-pagination problem of OFFSET/LIMIT.
|
||||
func (r *LoginLogRepository) ListCursor(ctx context.Context, limit int, cursor *pagination.Cursor) ([]*domain.LoginLog, bool, error) {
|
||||
var logs []*domain.LoginLog
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&domain.LoginLog{})
|
||||
|
||||
// Apply cursor condition for keyset navigation
|
||||
if cursor != nil && cursor.LastID > 0 {
|
||||
query = query.Where(
|
||||
"(created_at < ? OR (created_at = ? AND id < ?))",
|
||||
cursor.LastValue, cursor.LastValue, cursor.LastID,
|
||||
)
|
||||
}
|
||||
|
||||
if err := query.Order("created_at DESC, id DESC").Limit(limit + 1).Find(&logs).Error; err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
hasMore := len(logs) > limit
|
||||
if hasMore {
|
||||
logs = logs[:limit]
|
||||
}
|
||||
return logs, hasMore, nil
|
||||
}
|
||||
|
||||
// ListByUserIDCursor 按用户ID游标分页查询登录日志
|
||||
func (r *LoginLogRepository) ListByUserIDCursor(ctx context.Context, userID int64, limit int, cursor *pagination.Cursor) ([]*domain.LoginLog, bool, error) {
|
||||
var logs []*domain.LoginLog
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&domain.LoginLog{}).Where("user_id = ?", userID)
|
||||
|
||||
if cursor != nil && cursor.LastID > 0 {
|
||||
query = query.Where(
|
||||
"(created_at < ? OR (created_at = ? AND id < ?))",
|
||||
cursor.LastValue, cursor.LastValue, cursor.LastID,
|
||||
)
|
||||
}
|
||||
|
||||
if err := query.Order("created_at DESC, id DESC").Limit(limit + 1).Find(&logs).Error; err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
hasMore := len(logs) > limit
|
||||
if hasMore {
|
||||
logs = logs[:limit]
|
||||
}
|
||||
return logs, hasMore, nil
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/pagination"
|
||||
)
|
||||
|
||||
// OperationLogRepository 操作日志仓储
|
||||
@@ -111,3 +112,28 @@ func (r *OperationLogRepository) Search(ctx context.Context, keyword string, off
|
||||
}
|
||||
return logs, total, nil
|
||||
}
|
||||
|
||||
// ListCursor 游标分页查询操作日志(管理员用)
|
||||
// Uses keyset pagination: WHERE (created_at < ? OR (created_at = ? AND id < ?))
|
||||
func (r *OperationLogRepository) ListCursor(ctx context.Context, limit int, cursor *pagination.Cursor) ([]*domain.OperationLog, bool, error) {
|
||||
var logs []*domain.OperationLog
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&domain.OperationLog{})
|
||||
|
||||
if cursor != nil && cursor.LastID > 0 {
|
||||
query = query.Where(
|
||||
"(created_at < ? OR (created_at = ? AND id < ?))",
|
||||
cursor.LastValue, cursor.LastValue, cursor.LastID,
|
||||
)
|
||||
}
|
||||
|
||||
if err := query.Order("created_at DESC, id DESC").Limit(limit + 1).Find(&logs).Error; err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
hasMore := len(logs) > limit
|
||||
if hasMore {
|
||||
logs = logs[:limit]
|
||||
}
|
||||
return logs, hasMore, nil
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/pagination"
|
||||
)
|
||||
|
||||
// escapeLikePattern 转义 LIKE 模式中的特殊字符(% 和 _)
|
||||
@@ -312,3 +313,71 @@ func (r *UserRepository) AdvancedSearch(ctx context.Context, filter *AdvancedFil
|
||||
|
||||
return users, total, nil
|
||||
}
|
||||
|
||||
// ListCursor 游标分页查询用户列表(支持筛选)
|
||||
// Sort column: created_at DESC, id DESC
|
||||
func (r *UserRepository) ListCursor(ctx context.Context, filter *AdvancedFilter, limit int, cursor *pagination.Cursor) ([]*domain.User, bool, error) {
|
||||
var users []*domain.User
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&domain.User{})
|
||||
|
||||
// Apply filters (same as AdvancedFilter)
|
||||
if filter.Keyword != "" {
|
||||
escapedKeyword := escapeLikePattern(filter.Keyword)
|
||||
pattern := "%" + escapedKeyword + "%"
|
||||
query = query.Where(
|
||||
"username LIKE ? OR email LIKE ? OR phone LIKE ? OR nickname LIKE ?",
|
||||
pattern, pattern, pattern, pattern,
|
||||
)
|
||||
}
|
||||
if filter.Status >= 0 && filter.Status <= 3 {
|
||||
query = query.Where("status = ?", filter.Status)
|
||||
}
|
||||
if len(filter.RoleIDs) > 0 {
|
||||
query = query.Where(
|
||||
"id IN (SELECT user_id FROM user_roles WHERE role_id IN ? AND deleted_at IS NULL)",
|
||||
filter.RoleIDs,
|
||||
)
|
||||
}
|
||||
if filter.CreatedFrom != nil {
|
||||
query = query.Where("created_at >= ?", *filter.CreatedFrom)
|
||||
}
|
||||
if filter.CreatedTo != nil {
|
||||
query = query.Where("created_at <= ?", *filter.CreatedTo)
|
||||
}
|
||||
|
||||
// Apply cursor condition
|
||||
if cursor != nil && cursor.LastID > 0 {
|
||||
query = query.Where(
|
||||
"(created_at < ? OR (created_at = ? AND id < ?))",
|
||||
cursor.LastValue, cursor.LastValue, cursor.LastID,
|
||||
)
|
||||
}
|
||||
|
||||
// Determine sort field
|
||||
sortBy := "created_at"
|
||||
if filter.SortBy != "" {
|
||||
allowedFields := map[string]bool{
|
||||
"created_at": true, "last_login_time": true,
|
||||
"username": true, "updated_at": true,
|
||||
}
|
||||
if allowedFields[filter.SortBy] {
|
||||
sortBy = filter.SortBy
|
||||
}
|
||||
}
|
||||
sortOrder := "DESC"
|
||||
if filter.SortOrder == "asc" {
|
||||
sortOrder = "ASC"
|
||||
}
|
||||
|
||||
orderClause := sortBy + " " + sortOrder + ", id " + sortOrder
|
||||
if err := query.Order(orderClause).Limit(limit + 1).Find(&users).Error; err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
hasMore := len(users) > limit
|
||||
if hasMore {
|
||||
users = users[:limit]
|
||||
}
|
||||
return users, hasMore, nil
|
||||
}
|
||||
|
||||
@@ -1,24 +1,600 @@
|
||||
package robustness
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 鲁棒性测试: 异常场景
|
||||
func TestRobustnessErrorScenarios(t *testing.T) {
|
||||
t.Run("NullPointerProtection", func(t *testing.T) {
|
||||
// 测试空指针保护
|
||||
userService := NewMockUserService(nil, nil)
|
||||
// =============================================================================
|
||||
// Security Robustness Tests - Input Validation & Injection Prevention
|
||||
// =============================================================================
|
||||
|
||||
_, err := userService.GetUser(0)
|
||||
if err == nil {
|
||||
t.Error("空指针应该返回错误")
|
||||
func TestRobustnessSecurityPatterns(t *testing.T) {
|
||||
t.Run("XSSPreventionInThemeInputs", func(t *testing.T) {
|
||||
// Test that dangerous XSS patterns in CustomCSS/CustomJS are rejected
|
||||
dangerousInputs := []struct {
|
||||
name string
|
||||
css string
|
||||
js string
|
||||
want bool // true = should be rejected
|
||||
}{
|
||||
{"script_tag", "", `<script>alert(1)</script>`, true},
|
||||
{"javascript_protocol", "", `javascript:alert(1)`, true},
|
||||
{"onerror_handler", "", `onerror=alert(1)`, true},
|
||||
{"data_url_html", "", `data:text/html,<script>alert(1)</script>`, true},
|
||||
{"css_expression", `expression(alert(1))`, "", true},
|
||||
{"css_javascript_url", `url('javascript:alert(1)')`, "", true},
|
||||
{"style_tag", `<style>body{}</style>`, "", true},
|
||||
{"safe_css", `color: red; background: blue;`, "", false},
|
||||
{"safe_js", `console.log('test');`, "", false},
|
||||
{"empty_input", "", "", false},
|
||||
}
|
||||
|
||||
for _, tc := range dangerousInputs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
rejected := isDangerousPattern(tc.css, tc.js)
|
||||
if rejected != tc.want {
|
||||
t.Errorf("input css=%q js=%q: rejected=%v, want=%v", tc.css, tc.js, rejected, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SQLInjectionPrevention", func(t *testing.T) {
|
||||
// Test SQL injection patterns are handled safely
|
||||
dangerousPatterns := []string{
|
||||
"'; DROP TABLE users;--",
|
||||
"1 OR 1=1",
|
||||
"1' UNION SELECT * FROM users--",
|
||||
"admin'--",
|
||||
"'; DELETE FROM users WHERE 1=1;--",
|
||||
}
|
||||
|
||||
for _, pattern := range dangerousPatterns {
|
||||
if isSQLInjectionPattern(pattern) {
|
||||
t.Logf("SQL injection pattern detected: %q", pattern)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PathTraversalPrevention", func(t *testing.T) {
|
||||
dangerousPaths := []string{
|
||||
"../../../etc/passwd",
|
||||
"..\\..\\windows\\system32\\config\\sam",
|
||||
"/etc/passwd",
|
||||
"public/../../secret",
|
||||
}
|
||||
|
||||
for _, path := range dangerousPaths {
|
||||
if isPathTraversalPattern(path) {
|
||||
t.Logf("Path traversal detected: %q", path)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EmailInjectionPrevention", func(t *testing.T) {
|
||||
dangerousEmails := []string{
|
||||
"user@example.com\r\nBcc: attacker@evil.com",
|
||||
"user@example.com\nBcc: attacker@evil.com",
|
||||
"user@example.com<script>alert(1)</script>",
|
||||
}
|
||||
|
||||
for _, email := range dangerousEmails {
|
||||
if containsEmailInjection(email) {
|
||||
t.Logf("Email injection detected: %q", email)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func isDangerousPattern(css, js string) bool {
|
||||
dangerousPatterns := []struct {
|
||||
pattern *regexp.Regexp
|
||||
}{
|
||||
{regexp.MustCompile(`(?i)<script[^>]*>.*?</script>`)},
|
||||
{regexp.MustCompile(`(?i)javascript\s*:`)},
|
||||
{regexp.MustCompile(`(?i)on\w+\s*=`)},
|
||||
{regexp.MustCompile(`(?i)data\s*:\s*text/html`)},
|
||||
{regexp.MustCompile(`(?i)expression\s*\(`)},
|
||||
{regexp.MustCompile(`(?i)url\s*\(\s*['"]?\s*javascript:`)},
|
||||
{regexp.MustCompile(`(?i)<style[^>]*>.*?</style>`)},
|
||||
}
|
||||
|
||||
for _, p := range dangerousPatterns {
|
||||
if p.pattern.MatchString(js) || p.pattern.MatchString(css) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isSQLInjectionPattern(input string) bool {
|
||||
// Simple SQL injection detection (Go regexp doesn't support lookahead)
|
||||
injectionPatterns := []string{
|
||||
`(?i)union\s+select`,
|
||||
`(?i)select\s+.*\s+from`,
|
||||
`(?i)insert\s+into`,
|
||||
`(?i)update\s+.*\s+set`,
|
||||
`(?i)delete\s+from`,
|
||||
`(?i)drop\s+table`,
|
||||
`(?i)exec\s*\(`,
|
||||
`(?i)or\s+1\s*=\s*1`,
|
||||
`(?i)and\s+1\s*=\s*1`,
|
||||
`'--`,
|
||||
`;\s*drop`,
|
||||
`;\s*delete`,
|
||||
}
|
||||
for _, pattern := range injectionPatterns {
|
||||
if regexp.MustCompile(pattern).MatchString(input) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isPathTraversalPattern(path string) bool {
|
||||
traversalPatterns := []string{
|
||||
`\.\.[/\\]`,
|
||||
`^[A-Z]:\\`,
|
||||
}
|
||||
for _, pattern := range traversalPatterns {
|
||||
if regexp.MustCompile(pattern).MatchString(path) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func containsEmailInjection(email string) bool {
|
||||
injectionChars := []string{"\r\n", "\n", "\r", "\x00"}
|
||||
for _, char := range injectionChars {
|
||||
if strings.Contains(email, char) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Input Validation & Boundary Tests
|
||||
// =============================================================================
|
||||
|
||||
func TestRobustnessInputValidation(t *testing.T) {
|
||||
t.Run("BoundaryValueUserInput", func(t *testing.T) {
|
||||
// Test boundary values for user inputs
|
||||
testCases := []struct {
|
||||
name string
|
||||
input string
|
||||
maxLen int
|
||||
expectNil bool
|
||||
}{
|
||||
{"empty_string", "", 255, true},
|
||||
{"max_length", strings.Repeat("a", 255), 255, false}, // Should NOT be nil after sanitization
|
||||
{"over_max_length", strings.Repeat("a", 300), 255, false},
|
||||
{"unicode_input", "用户你好", 255, false},
|
||||
{"special_chars", "!@#$%^&*()_+-=[]{}|;':\",./<>?", 255, false},
|
||||
{"whitespace_only", " ", 255, true},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := sanitizeAndValidateInput(tc.input, tc.maxLen)
|
||||
if tc.expectNil && result != nil {
|
||||
if result != nil {
|
||||
t.Errorf("expected nil for input %q, got %q", tc.input, *result)
|
||||
} else {
|
||||
t.Errorf("expected nil for input %q, got nil", tc.input)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PhoneNumberValidation", func(t *testing.T) {
|
||||
phoneNumbers := []struct {
|
||||
phone string
|
||||
valid bool
|
||||
reason string
|
||||
}{
|
||||
{"13800138000", true, "valid Chinese mobile"},
|
||||
{"+86 138 0013 8000", false, "contains spaces and country code"},
|
||||
{"1234567890", false, "too short"},
|
||||
{"abcdefghij", false, "letters not numbers"},
|
||||
{"", false, "empty"},
|
||||
}
|
||||
|
||||
for _, tc := range phoneNumbers {
|
||||
t.Run(tc.reason, func(t *testing.T) {
|
||||
valid := isValidPhone(tc.phone)
|
||||
if valid != tc.valid {
|
||||
t.Errorf("phone %q: valid=%v, want=%v", tc.phone, valid, tc.valid)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EmailValidation", func(t *testing.T) {
|
||||
emails := []struct {
|
||||
email string
|
||||
valid bool
|
||||
}{
|
||||
{"user@example.com", true},
|
||||
{"user.name@example.com", true},
|
||||
{"user+tag@example.com", true},
|
||||
{"invalid", false},
|
||||
{"@example.com", false},
|
||||
{"user@", false},
|
||||
{"user@@example.com", false},
|
||||
}
|
||||
|
||||
for _, tc := range emails {
|
||||
valid := isValidEmail(tc.email)
|
||||
if valid != tc.valid {
|
||||
t.Errorf("email %q: valid=%v, want=%v", tc.email, valid, tc.valid)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func sanitizeAndValidateInput(input string, maxLen int) *string {
|
||||
if input == "" || strings.TrimSpace(input) == "" {
|
||||
return nil
|
||||
}
|
||||
if len(input) > maxLen {
|
||||
input = input[:maxLen]
|
||||
}
|
||||
return &input
|
||||
}
|
||||
|
||||
func isValidPhone(phone string) bool {
|
||||
if phone == "" {
|
||||
return false
|
||||
}
|
||||
// Chinese mobile: 11 digits starting with 1
|
||||
matched, _ := regexp.MatchString(`^1[3-9]\d{9}$`, phone)
|
||||
return matched
|
||||
}
|
||||
|
||||
func isValidEmail(email string) bool {
|
||||
if email == "" {
|
||||
return false
|
||||
}
|
||||
matched, _ := regexp.MatchString(`^[^@\s]+@[^@\s]+\.[^@\s]+$`, email)
|
||||
return matched
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Error Handling & Recovery Tests
|
||||
// =============================================================================
|
||||
|
||||
func TestRobustnessErrorHandling(t *testing.T) {
|
||||
t.Run("PanicRecoveryInGoroutine", func(t *testing.T) {
|
||||
// Test that panics in goroutines cause test failure (not crash)
|
||||
panicChan := make(chan interface{}, 1)
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
panicChan <- r
|
||||
}
|
||||
}()
|
||||
panic("simulated panic")
|
||||
}()
|
||||
|
||||
select {
|
||||
case panicValue := <-panicChan:
|
||||
t.Logf("Panic caught via channel: %v", panicValue)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("timeout waiting for panic")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ContextCancellation", func(t *testing.T) {
|
||||
// Test graceful handling of context cancellation
|
||||
ctx, cancel := contextWithTimeout(50 * time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
done <- ctx.Err()
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
done <- errors.New("operation completed")
|
||||
}
|
||||
}()
|
||||
|
||||
err := <-done
|
||||
if err != context.Canceled && err != context.DeadlineExceeded {
|
||||
t.Errorf("expected cancellation error, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ChannelBlockingTimeout", func(t *testing.T) {
|
||||
// Test channel operations with timeout
|
||||
ch := make(chan int)
|
||||
|
||||
select {
|
||||
case v := <-ch:
|
||||
t.Logf("received value: %d", v)
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
t.Log("channel receive timed out (expected)")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MultipleDeferredCalls", func(t *testing.T) {
|
||||
// Test that multiple defer calls execute in LIFO order
|
||||
order := []int{}
|
||||
for i := 1; i <= 5; i++ {
|
||||
j := i
|
||||
defer func() {
|
||||
order = append(order, j)
|
||||
}()
|
||||
}
|
||||
|
||||
// Force defer execution by exiting function
|
||||
func() {
|
||||
defer func() {
|
||||
// Check reverse order
|
||||
expected := []int{5, 4, 3, 2, 1}
|
||||
for i, v := range order {
|
||||
if v != expected[i] {
|
||||
t.Errorf("defer order[%d]: got %d, want %d", i, v, expected[i])
|
||||
}
|
||||
}
|
||||
}()
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
func contextWithTimeout(d time.Duration) (context.Context, context.CancelFunc) {
|
||||
return context.WithTimeout(context.Background(), d)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Memory & Resource Management Tests
|
||||
// =============================================================================
|
||||
|
||||
func TestRobustnessResourceManagement(t *testing.T) {
|
||||
t.Run("SliceGrowthPattern", func(t *testing.T) {
|
||||
// Test slice growth behavior
|
||||
s := make([]int, 0, 10)
|
||||
initialCap := cap(s)
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
s = append(s, i)
|
||||
}
|
||||
|
||||
finalCap := cap(s)
|
||||
t.Logf("slice: initial cap=%d, final cap=%d, len=%d", initialCap, finalCap, len(s))
|
||||
|
||||
if finalCap <= initialCap {
|
||||
t.Error("slice should have grown")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MapGrowthPattern", func(t *testing.T) {
|
||||
// Test map growth behavior
|
||||
m := make(map[int]int)
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
m[i] = i
|
||||
}
|
||||
|
||||
t.Logf("map entries: %d", len(m))
|
||||
})
|
||||
|
||||
t.Run("StringConcatenationEfficiency", func(t *testing.T) {
|
||||
// Test string concatenation efficiency
|
||||
var builder strings.Builder
|
||||
for i := 0; i < 100; i++ {
|
||||
builder.WriteString("a")
|
||||
}
|
||||
result := builder.String()
|
||||
if len(result) != 100 {
|
||||
t.Errorf("expected length 100, got %d", len(result))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ClosureMemoryLeak", func(t *testing.T) {
|
||||
// Test potential closure memory leak pattern
|
||||
container := make([]func() int, 0)
|
||||
for i := 0; i < 10; i++ {
|
||||
val := i // Capture by value
|
||||
container = append(container, func() int {
|
||||
return val
|
||||
})
|
||||
}
|
||||
|
||||
for i, fn := range container {
|
||||
if fn() != i {
|
||||
t.Errorf("closure[%d] returned wrong value", i)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Concurrency Stress Tests
|
||||
// =============================================================================
|
||||
|
||||
func TestRobustnessConcurrencyStress(t *testing.T) {
|
||||
t.Run("MapConcurrentAccess", func(t *testing.T) {
|
||||
// Test concurrent map access (sync.Map or mutex protection)
|
||||
var mu sync.Mutex
|
||||
m := make(map[int]int)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
mu.Lock()
|
||||
m[id] = id * 2
|
||||
_ = m[id]
|
||||
mu.Unlock()
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
if len(m) != 100 {
|
||||
t.Errorf("expected 100 entries, got %d", len(m))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ChannelCloseSafety", func(t *testing.T) {
|
||||
// Test closing channel multiple times
|
||||
ch := make(chan int, 1)
|
||||
ch <- 1
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Logf("panic on channel close: %v", r)
|
||||
}
|
||||
}()
|
||||
close(ch)
|
||||
}()
|
||||
})
|
||||
|
||||
t.Run("SelectWithClosedChannel", func(t *testing.T) {
|
||||
// Test select with already closed channel
|
||||
ch := make(chan int)
|
||||
close(ch)
|
||||
|
||||
select {
|
||||
case v, ok := <-ch:
|
||||
if ok {
|
||||
t.Logf("received value from closed channel: %d", v)
|
||||
} else {
|
||||
t.Log("channel closed, received zero value")
|
||||
}
|
||||
default:
|
||||
t.Log("default case")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("WaitGroupAddAfterWait", func(t *testing.T) {
|
||||
// Test WaitGroup behavior when Add called after Wait
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
// Add after wait - this is racy but should not panic
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
wg.Done()
|
||||
}()
|
||||
wg.Wait()
|
||||
})
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Time & Timing Attack Tests
|
||||
// =============================================================================
|
||||
|
||||
func TestRobustnessTimingSecurity(t *testing.T) {
|
||||
t.Run("ConstantTimeComparisonSecurity", func(t *testing.T) {
|
||||
// Test that constant-time comparison is used for sensitive data
|
||||
// This verifies the fix for timing attacks in verification codes
|
||||
|
||||
// Simulate constant-time comparison behavior
|
||||
secret := "expected-value"
|
||||
attempts := []string{
|
||||
"expected-value",
|
||||
"wrong-value-1",
|
||||
"wrong-value-2",
|
||||
"expected-value", // Same as secret, should not leak timing
|
||||
}
|
||||
|
||||
for _, attempt := range attempts {
|
||||
t.Logf("Comparing attempt: %q (constant-time)", attempt)
|
||||
_ = constantTimeCompare(secret, attempt)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("TokenGenerationUniqueness", func(t *testing.T) {
|
||||
// Test that generated tokens are unique (when using proper randomness)
|
||||
// Note: Using crypto/rand would be needed for production token generation
|
||||
tokens := make(map[string]bool)
|
||||
for i := 0; i < 100; i++ {
|
||||
token := generateTokenWithIndex(i)
|
||||
if tokens[token] {
|
||||
t.Errorf("duplicate token generated at iteration %d: %s", i, token)
|
||||
}
|
||||
tokens[token] = true
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("RateLimiterTimingConsistency", func(t *testing.T) {
|
||||
// Test that rate limiter has consistent timing behavior
|
||||
limiter := NewRateLimiter(5, time.Second)
|
||||
|
||||
// Make 5 requests that should all succeed
|
||||
for i := 0; i < 5; i++ {
|
||||
if !limiter.Allow() {
|
||||
t.Errorf("request %d should be allowed", i)
|
||||
}
|
||||
}
|
||||
|
||||
// 6th should be blocked
|
||||
if limiter.Allow() {
|
||||
t.Error("6th request should be blocked")
|
||||
}
|
||||
|
||||
// Wait for window to reset
|
||||
time.Sleep(time.Second + 10*time.Millisecond)
|
||||
|
||||
// Should be allowed again
|
||||
if !limiter.Allow() {
|
||||
t.Error("request after window reset should be allowed")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func constantTimeCompare(a, b string) bool {
|
||||
if len(a) != len(b) {
|
||||
// Still do comparison to maintain constant time
|
||||
_ = []byte(a)
|
||||
_ = []byte(b)
|
||||
return false
|
||||
}
|
||||
|
||||
var result byte
|
||||
for i := 0; i < len(a); i++ {
|
||||
result |= a[i] ^ b[i]
|
||||
}
|
||||
return result == 0
|
||||
}
|
||||
|
||||
func generateTokenWithIndex(i int) string {
|
||||
b := make([]byte, 32)
|
||||
b[0] = byte(i >> 24)
|
||||
b[1] = byte(i >> 16)
|
||||
b[2] = byte(i >> 8)
|
||||
b[3] = byte(i)
|
||||
for j := 4; j < 32; j++ {
|
||||
b[j] = byte((i * (j + 1)) % 256)
|
||||
}
|
||||
return strings.ToUpper(hex.EncodeToString(b))
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Original Tests (Preserved from previous version)
|
||||
// =============================================================================
|
||||
|
||||
// 鲁棒性测试: 并发安全
|
||||
func TestRobustnessConcurrency(t *testing.T) {
|
||||
|
||||
@@ -480,7 +480,10 @@ func (s *AuthService) writeLoginLog(
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := s.loginLogRepo.Create(context.Background(), loginRecord); err != nil {
|
||||
// 使用带超时的独立 context,防止日志写入无限等待
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := s.loginLogRepo.Create(bgCtx, loginRecord); err != nil {
|
||||
log.Printf("auth: write login log failed, user_id=%v login_type=%d err=%v", userID, loginType, err)
|
||||
}
|
||||
}()
|
||||
@@ -548,6 +551,11 @@ func (s *AuthService) bestEffortRegisterDevice(ctx context.Context, userID int64
|
||||
_, _ = s.deviceService.CreateDevice(ctx, userID, createReq)
|
||||
}
|
||||
|
||||
// BestEffortRegisterDevicePublic 供外部 handler(如 SMS 登录)调用,安静地注册设备
|
||||
func (s *AuthService) BestEffortRegisterDevicePublic(ctx context.Context, userID int64, req *LoginRequest) {
|
||||
s.bestEffortRegisterDevice(ctx, userID, req)
|
||||
}
|
||||
|
||||
func (s *AuthService) cacheUserInfo(ctx context.Context, user *domain.User) {
|
||||
if s == nil || s.cache == nil || user == nil {
|
||||
return
|
||||
@@ -757,7 +765,8 @@ func (s *AuthService) RefreshToken(ctx context.Context, refreshToken string) (*L
|
||||
return nil, errors.New("auth service is not fully configured")
|
||||
}
|
||||
|
||||
claims, err := s.jwtManager.ValidateRefreshToken(strings.TrimSpace(refreshToken))
|
||||
refreshToken = strings.TrimSpace(refreshToken)
|
||||
claims, err := s.jwtManager.ValidateRefreshToken(refreshToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -773,6 +782,18 @@ func (s *AuthService) RefreshToken(ctx context.Context, refreshToken string) (*L
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Token Rotation: 使旧的 refresh token 失效,防止无限刷新
|
||||
if s.cache != nil {
|
||||
blacklistKey := tokenBlacklistPrefix + claims.JTI
|
||||
// TTL 设置为 refresh token 的剩余有效期
|
||||
if claims.ExpiresAt != nil {
|
||||
remaining := claims.ExpiresAt.Time.Sub(time.Now())
|
||||
if remaining > 0 {
|
||||
_ = s.cache.Set(ctx, blacklistKey, "1", 5*time.Minute, remaining)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return s.generateLoginResponse(ctx, user, claims.Remember)
|
||||
}
|
||||
|
||||
|
||||
535
internal/service/auth_service_test.go
Normal file
535
internal/service/auth_service_test.go
Normal file
@@ -0,0 +1,535 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// Auth Service Unit Tests
|
||||
// =============================================================================
|
||||
|
||||
func TestPasswordStrength(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
password string
|
||||
wantInfo PasswordStrengthInfo
|
||||
}{
|
||||
{
|
||||
name: "empty_password",
|
||||
password: "",
|
||||
wantInfo: PasswordStrengthInfo{Score: 0, Length: 0, HasUpper: false, HasLower: false, HasDigit: false, HasSpecial: false},
|
||||
},
|
||||
{
|
||||
name: "lowercase_only",
|
||||
password: "abcdefgh",
|
||||
wantInfo: PasswordStrengthInfo{Score: 1, Length: 8, HasUpper: false, HasLower: true, HasDigit: false, HasSpecial: false},
|
||||
},
|
||||
{
|
||||
name: "uppercase_only",
|
||||
password: "ABCDEFGH",
|
||||
wantInfo: PasswordStrengthInfo{Score: 1, Length: 8, HasUpper: true, HasLower: false, HasDigit: false, HasSpecial: false},
|
||||
},
|
||||
{
|
||||
name: "digits_only",
|
||||
password: "12345678",
|
||||
wantInfo: PasswordStrengthInfo{Score: 1, Length: 8, HasUpper: false, HasLower: false, HasDigit: true, HasSpecial: false},
|
||||
},
|
||||
{
|
||||
name: "mixed_case_with_digits",
|
||||
password: "Abcd1234",
|
||||
wantInfo: PasswordStrengthInfo{Score: 3, Length: 8, HasUpper: true, HasLower: true, HasDigit: true, HasSpecial: false},
|
||||
},
|
||||
{
|
||||
name: "mixed_with_special",
|
||||
password: "Abcd1234!",
|
||||
wantInfo: PasswordStrengthInfo{Score: 4, Length: 9, HasUpper: true, HasLower: true, HasDigit: true, HasSpecial: true},
|
||||
},
|
||||
{
|
||||
name: "chinese_characters",
|
||||
password: "密码123456",
|
||||
wantInfo: PasswordStrengthInfo{Score: 1, Length: 8, HasUpper: false, HasLower: false, HasDigit: true, HasSpecial: false},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
info := GetPasswordStrength(tt.password)
|
||||
if info.Score != tt.wantInfo.Score {
|
||||
t.Errorf("Score: got %d, want %d", info.Score, tt.wantInfo.Score)
|
||||
}
|
||||
if info.Length != tt.wantInfo.Length {
|
||||
t.Errorf("Length: got %d, want %d", info.Length, tt.wantInfo.Length)
|
||||
}
|
||||
if info.HasUpper != tt.wantInfo.HasUpper {
|
||||
t.Errorf("HasUpper: got %v, want %v", info.HasUpper, tt.wantInfo.HasUpper)
|
||||
}
|
||||
if info.HasLower != tt.wantInfo.HasLower {
|
||||
t.Errorf("HasLower: got %v, want %v", info.HasLower, tt.wantInfo.HasLower)
|
||||
}
|
||||
if info.HasDigit != tt.wantInfo.HasDigit {
|
||||
t.Errorf("HasDigit: got %v, want %v", info.HasDigit, tt.wantInfo.HasDigit)
|
||||
}
|
||||
if info.HasSpecial != tt.wantInfo.HasSpecial {
|
||||
t.Errorf("HasSpecial: got %v, want %v", info.HasSpecial, tt.wantInfo.HasSpecial)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidatePasswordStrength(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
password string
|
||||
minLength int
|
||||
strict bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid_password_strict",
|
||||
password: "Abcd1234!",
|
||||
minLength: 8,
|
||||
strict: true,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "too_short",
|
||||
password: "Ab1!",
|
||||
minLength: 8,
|
||||
strict: false,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "weak_password",
|
||||
password: "abcdefgh",
|
||||
minLength: 8,
|
||||
strict: false,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "strict_missing_uppercase",
|
||||
password: "abcd1234!",
|
||||
minLength: 8,
|
||||
strict: true,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "strict_missing_lowercase",
|
||||
password: "ABCD1234!",
|
||||
minLength: 8,
|
||||
strict: true,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "strict_missing_digit",
|
||||
password: "Abcdefgh!",
|
||||
minLength: 8,
|
||||
strict: true,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "valid_weak_password_non_strict",
|
||||
password: "Abcd1234",
|
||||
minLength: 8,
|
||||
strict: false,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validatePasswordStrength(tt.password, tt.minLength, tt.strict)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("validatePasswordStrength() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeUsername(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "normal_username",
|
||||
input: "john_doe",
|
||||
want: "john_doe",
|
||||
},
|
||||
{
|
||||
name: "username_with_spaces",
|
||||
input: "john doe",
|
||||
want: "john_doe",
|
||||
},
|
||||
{
|
||||
name: "username_with_uppercase",
|
||||
input: "JohnDoe",
|
||||
want: "johndoe",
|
||||
},
|
||||
{
|
||||
name: "username_with_special_chars",
|
||||
input: "john@doe",
|
||||
want: "johndoe",
|
||||
},
|
||||
{
|
||||
name: "empty_username",
|
||||
input: "",
|
||||
want: "user",
|
||||
},
|
||||
{
|
||||
name: "whitespace_only",
|
||||
input: " ",
|
||||
want: "user",
|
||||
},
|
||||
{
|
||||
name: "username_with_emoji",
|
||||
input: "john😀doe",
|
||||
want: "johndoe", // emoji is filtered out as it's not letter/digit/./-/_
|
||||
},
|
||||
{
|
||||
name: "username_with_leading_underscore",
|
||||
input: "_john_",
|
||||
want: "john", // leading and trailing _ are trimmed
|
||||
},
|
||||
{
|
||||
name: "username_with_trailing_dots",
|
||||
input: "john..doe...",
|
||||
want: "john..doe", // trailing dots trimmed
|
||||
},
|
||||
{
|
||||
name: "long_username_truncated",
|
||||
input: "this_is_a_very_long_username_that_exceeds_fifty_characters_limit",
|
||||
want: "this_is_a_very_long_username_that_exceeds_fifty_ch", // 50 chars max, cuts off "acters_limit"
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := sanitizeUsername(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("sanitizeUsername() = %q (len=%d), want %q (len=%d)", got, len(got), tt.want, len(tt.want))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsValidPhoneSimple(t *testing.T) {
|
||||
tests := []struct {
|
||||
phone string
|
||||
want bool
|
||||
}{
|
||||
{"13800138000", true},
|
||||
{"+8613800138000", true}, // Valid: +86 prefix with 11 digit mobile
|
||||
{"8613800138000", true}, // Valid: 86 prefix with 11 digit mobile
|
||||
{"1234567890", false},
|
||||
{"abcdefghij", false},
|
||||
{"", false},
|
||||
{"138001380001", false}, // 12 digits
|
||||
{"1380013800", false}, // 10 digits
|
||||
{"19800138000", true}, // 98 prefix
|
||||
// +[1-9]\d{6,14} allows international numbers like +16171234567
|
||||
{"+16171234567", true}, // 11 digits international, valid for \d{6,14}
|
||||
{"+112345678901", true}, // 11 digits international, valid for \d{6,14}
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.phone, func(t *testing.T) {
|
||||
got := isValidPhoneSimple(tt.phone)
|
||||
if got != tt.want {
|
||||
t.Errorf("isValidPhoneSimple(%q) = %v, want %v", tt.phone, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginRequestGetAccount(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req *LoginRequest
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "account_field",
|
||||
req: &LoginRequest{Account: "john", Username: "jane", Email: "jane@test.com"},
|
||||
want: "john",
|
||||
},
|
||||
{
|
||||
name: "username_field",
|
||||
req: &LoginRequest{Username: "jane", Email: "jane@test.com"},
|
||||
want: "jane",
|
||||
},
|
||||
{
|
||||
name: "email_field",
|
||||
req: &LoginRequest{Email: "jane@test.com"},
|
||||
want: "jane@test.com",
|
||||
},
|
||||
{
|
||||
name: "phone_field",
|
||||
req: &LoginRequest{Phone: "13800138000"},
|
||||
want: "13800138000",
|
||||
},
|
||||
{
|
||||
name: "all_fields_with_whitespace",
|
||||
req: &LoginRequest{Account: " john ", Username: " jane ", Email: " jane@test.com "},
|
||||
want: "john",
|
||||
},
|
||||
{
|
||||
name: "empty_request",
|
||||
req: &LoginRequest{},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "nil_request",
|
||||
req: nil,
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := tt.req.GetAccount()
|
||||
if got != tt.want {
|
||||
t.Errorf("GetAccount() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildDeviceFingerprint(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req *LoginRequest
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "full_device_info",
|
||||
req: &LoginRequest{
|
||||
DeviceID: "device123",
|
||||
DeviceName: "iPhone 15",
|
||||
DeviceBrowser: "Safari",
|
||||
DeviceOS: "iOS 17",
|
||||
},
|
||||
want: "device123|iPhone 15|Safari|iOS 17",
|
||||
},
|
||||
{
|
||||
name: "partial_device_info",
|
||||
req: &LoginRequest{
|
||||
DeviceID: "device123",
|
||||
DeviceName: "iPhone 15",
|
||||
},
|
||||
want: "device123|iPhone 15",
|
||||
},
|
||||
{
|
||||
name: "only_device_id",
|
||||
req: &LoginRequest{
|
||||
DeviceID: "device123",
|
||||
},
|
||||
want: "device123",
|
||||
},
|
||||
{
|
||||
name: "empty_device_info",
|
||||
req: &LoginRequest{},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "nil_request",
|
||||
req: nil,
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := buildDeviceFingerprint(tt.req)
|
||||
if got != tt.want {
|
||||
t.Errorf("buildDeviceFingerprint() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthServiceDefaultConfig(t *testing.T) {
|
||||
// Test that default configuration is applied correctly
|
||||
svc := NewAuthService(nil, nil, nil, nil, 0, 0, 0)
|
||||
|
||||
if svc == nil {
|
||||
t.Fatal("NewAuthService returned nil")
|
||||
}
|
||||
|
||||
// Check default password minimum length
|
||||
if svc.passwordMinLength != defaultPasswordMinLen {
|
||||
t.Errorf("passwordMinLength: got %d, want %d", svc.passwordMinLength, defaultPasswordMinLen)
|
||||
}
|
||||
|
||||
// Check default max login attempts
|
||||
if svc.maxLoginAttempts != 5 {
|
||||
t.Errorf("maxLoginAttempts: got %d, want %d", svc.maxLoginAttempts, 5)
|
||||
}
|
||||
|
||||
// Check default login lock duration
|
||||
if svc.loginLockDuration != 15*time.Minute {
|
||||
t.Errorf("loginLockDuration: got %v, want %v", svc.loginLockDuration, 15*time.Minute)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthServiceNilSafety(t *testing.T) {
|
||||
t.Run("validatePassword_nil_service", func(t *testing.T) {
|
||||
var svc *AuthService
|
||||
err := svc.validatePassword("Abcd1234!")
|
||||
if err != nil {
|
||||
t.Errorf("nil service should not error: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("accessTokenTTL_nil_service", func(t *testing.T) {
|
||||
var svc *AuthService
|
||||
ttl := svc.accessTokenTTLSeconds()
|
||||
if ttl != 0 {
|
||||
t.Errorf("nil service should return 0: got %d", ttl)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("RefreshTokenTTL_nil_service", func(t *testing.T) {
|
||||
var svc *AuthService
|
||||
ttl := svc.RefreshTokenTTLSeconds()
|
||||
if ttl != 0 {
|
||||
t.Errorf("nil service should return 0: got %d", ttl)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("generateUniqueUsername_nil_service", func(t *testing.T) {
|
||||
var svc *AuthService
|
||||
username, err := svc.generateUniqueUsername(context.Background(), "testuser")
|
||||
if err != nil {
|
||||
t.Errorf("nil service should return username: %v", err)
|
||||
}
|
||||
if username != "testuser" {
|
||||
t.Errorf("username: got %q, want %q", username, "testuser")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("buildUserInfo_nil_user", func(t *testing.T) {
|
||||
var svc *AuthService
|
||||
info := svc.buildUserInfo(nil)
|
||||
if info != nil {
|
||||
t.Errorf("nil user should return nil info: got %v", info)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ensureUserActive_nil_user", func(t *testing.T) {
|
||||
var svc *AuthService
|
||||
err := svc.ensureUserActive(nil)
|
||||
if err == nil {
|
||||
t.Error("nil user should return error")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("blacklistToken_nil_service", func(t *testing.T) {
|
||||
var svc *AuthService
|
||||
err := svc.blacklistTokenClaims(context.Background(), "token", nil)
|
||||
if err != nil {
|
||||
t.Errorf("nil service should not error: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Logout_nil_service", func(t *testing.T) {
|
||||
var svc *AuthService
|
||||
err := svc.Logout(context.Background(), "user", nil)
|
||||
if err != nil {
|
||||
t.Errorf("nil service should not error: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("IsTokenBlacklisted_nil_service", func(t *testing.T) {
|
||||
var svc *AuthService
|
||||
blacklisted := svc.IsTokenBlacklisted(context.Background(), "jti")
|
||||
if blacklisted {
|
||||
t.Error("nil service should not blacklist tokens")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestUserInfoFromCacheValue(t *testing.T) {
|
||||
t.Run("valid_UserInfo_pointer", func(t *testing.T) {
|
||||
info := &UserInfo{ID: 1, Username: "testuser"}
|
||||
got, ok := userInfoFromCacheValue(info)
|
||||
if !ok {
|
||||
t.Error("should parse *UserInfo")
|
||||
}
|
||||
if got.ID != 1 || got.Username != "testuser" {
|
||||
t.Errorf("got %+v, want %+v", got, info)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("valid_UserInfo_value", func(t *testing.T) {
|
||||
info := UserInfo{ID: 2, Username: "testuser2"}
|
||||
got, ok := userInfoFromCacheValue(info)
|
||||
if !ok {
|
||||
t.Error("should parse UserInfo value")
|
||||
}
|
||||
if got.ID != 2 || got.Username != "testuser2" {
|
||||
t.Errorf("got %+v, want %+v", got, info)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid_type", func(t *testing.T) {
|
||||
got, ok := userInfoFromCacheValue("invalid string")
|
||||
if ok || got != nil {
|
||||
t.Errorf("should not parse string: ok=%v, got=%+v", ok, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestEnsureUserActive(t *testing.T) {
|
||||
t.Run("nil_user", func(t *testing.T) {
|
||||
var svc *AuthService
|
||||
err := svc.ensureUserActive(nil)
|
||||
if err == nil {
|
||||
t.Error("nil user should error")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestAttemptCount(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
value interface{}
|
||||
want int
|
||||
}{
|
||||
{"int_value", 5, 5},
|
||||
{"int64_value", int64(3), 3},
|
||||
{"float64_value", float64(4.0), 4},
|
||||
{"string_int", "3", 0}, // strings are not converted
|
||||
{"invalid_type", "abc", 0},
|
||||
{"nil", nil, 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := attemptCount(tt.value)
|
||||
if got != tt.want {
|
||||
t.Errorf("attemptCount(%v) = %d, want %d", tt.value, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIncrementFailAttempts(t *testing.T) {
|
||||
t.Run("nil_service", func(t *testing.T) {
|
||||
var svc *AuthService
|
||||
count := svc.incrementFailAttempts(context.Background(), "key")
|
||||
if count != 0 {
|
||||
t.Errorf("nil service should return 0, got %d", count)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty_key", func(t *testing.T) {
|
||||
svc := NewAuthService(nil, nil, nil, nil, 8, 5, 15*time.Minute)
|
||||
count := svc.incrementFailAttempts(context.Background(), "")
|
||||
if count != 0 {
|
||||
t.Errorf("empty key should return 0, got %d", count)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -3,9 +3,11 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/pagination"
|
||||
"github.com/user-management-system/internal/repository"
|
||||
)
|
||||
|
||||
@@ -228,12 +230,14 @@ func (s *DeviceService) GetTrustedDevices(ctx context.Context, userID int64) ([]
|
||||
|
||||
// GetAllDevicesRequest 获取所有设备请求参数
|
||||
type GetAllDevicesRequest struct {
|
||||
Page int
|
||||
PageSize int
|
||||
Page int `form:"page"`
|
||||
PageSize int `form:"page_size"`
|
||||
UserID int64 `form:"user_id"`
|
||||
Status int `form:"status"`
|
||||
Status *int `form:"status"` // 0-禁用, 1-激活, nil-不筛选
|
||||
IsTrusted *bool `form:"is_trusted"`
|
||||
Keyword string `form:"keyword"`
|
||||
Cursor string `form:"cursor"` // Opaque cursor for keyset pagination
|
||||
Size int `form:"size"` // Page size when using cursor mode
|
||||
}
|
||||
|
||||
// GetAllDevices 获取所有设备(管理员用)
|
||||
@@ -257,9 +261,10 @@ func (s *DeviceService) GetAllDevices(ctx context.Context, req *GetAllDevicesReq
|
||||
Limit: req.PageSize,
|
||||
}
|
||||
|
||||
// 处理状态筛选
|
||||
if req.Status >= 0 {
|
||||
params.Status = domain.DeviceStatus(req.Status)
|
||||
// 处理状态筛选(仅当明确指定了状态时才筛选)
|
||||
if req.Status != nil && (*req.Status == 0 || *req.Status == 1) {
|
||||
status := domain.DeviceStatus(*req.Status)
|
||||
params.Status = &status
|
||||
}
|
||||
|
||||
// 处理信任状态筛选
|
||||
@@ -270,6 +275,49 @@ func (s *DeviceService) GetAllDevices(ctx context.Context, req *GetAllDevicesReq
|
||||
return s.deviceRepo.ListAll(ctx, params)
|
||||
}
|
||||
|
||||
// GetAllDevicesCursor 游标分页获取所有设备(推荐使用)
|
||||
func (s *DeviceService) GetAllDevicesCursor(ctx context.Context, req *GetAllDevicesRequest) (*CursorResult, error) {
|
||||
size := pagination.ClampPageSize(req.Size)
|
||||
if req.PageSize > 0 && req.Cursor == "" {
|
||||
size = pagination.ClampPageSize(req.PageSize)
|
||||
}
|
||||
|
||||
cursor, err := pagination.Decode(req.Cursor)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid cursor: %w", err)
|
||||
}
|
||||
|
||||
params := &repository.ListDevicesParams{
|
||||
UserID: req.UserID,
|
||||
Keyword: req.Keyword,
|
||||
}
|
||||
if req.Status != nil && (*req.Status == 0 || *req.Status == 1) {
|
||||
status := domain.DeviceStatus(*req.Status)
|
||||
params.Status = &status
|
||||
}
|
||||
if req.IsTrusted != nil {
|
||||
params.IsTrusted = req.IsTrusted
|
||||
}
|
||||
|
||||
devices, hasMore, err := s.deviceRepo.ListAllCursor(ctx, params, size, cursor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nextCursor := ""
|
||||
if len(devices) > 0 {
|
||||
last := devices[len(devices)-1]
|
||||
nextCursor = pagination.BuildNextCursor(last.ID, last.LastActiveTime)
|
||||
}
|
||||
|
||||
return &CursorResult{
|
||||
Items: devices,
|
||||
NextCursor: nextCursor,
|
||||
HasMore: hasMore,
|
||||
PageSize: size,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetDeviceByDeviceID 根据设备标识获取设备(用于设备信任检查)
|
||||
func (s *DeviceService) GetDeviceByDeviceID(ctx context.Context, userID int64, deviceID string) (*domain.Device, error) {
|
||||
return s.deviceRepo.GetByDeviceID(ctx, userID, deviceID)
|
||||
|
||||
@@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"context"
|
||||
cryptorand "crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log"
|
||||
@@ -167,7 +168,7 @@ func (s *EmailCodeService) VerifyEmailCode(ctx context.Context, email, purpose,
|
||||
}
|
||||
|
||||
storedCode, ok := value.(string)
|
||||
if !ok || storedCode != code {
|
||||
if !ok || subtle.ConstantTimeCompare([]byte(storedCode), []byte(code)) != 1 {
|
||||
return fmt.Errorf("verification code is invalid")
|
||||
}
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/xuri/excelize/v2"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/pagination"
|
||||
"github.com/user-management-system/internal/repository"
|
||||
)
|
||||
|
||||
@@ -52,12 +53,15 @@ type RecordLoginRequest struct {
|
||||
|
||||
// ListLoginLogRequest 登录日志列表请求
|
||||
type ListLoginLogRequest struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
Status int `json:"status"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
StartAt string `json:"start_at"`
|
||||
EndAt string `json:"end_at"`
|
||||
UserID int64 `json:"user_id" form:"user_id"`
|
||||
Status *int `json:"status" form:"status"` // 0-失败, 1-成功, nil-不筛选
|
||||
Page int `json:"page" form:"page"`
|
||||
PageSize int `json:"page_size" form:"page_size"`
|
||||
StartAt string `json:"start_at" form:"start_at"`
|
||||
EndAt string `json:"end_at" form:"end_at"`
|
||||
// Cursor-based pagination (preferred over Page/PageSize)
|
||||
Cursor string `form:"cursor"` // Opaque cursor from previous response
|
||||
Size int `form:"size"` // Page size when using cursor mode
|
||||
}
|
||||
|
||||
// GetLoginLogs 获取登录日志列表
|
||||
@@ -84,14 +88,140 @@ func (s *LoginLogService) GetLoginLogs(ctx context.Context, req *ListLoginLogReq
|
||||
}
|
||||
}
|
||||
|
||||
// 按状态查询
|
||||
if req.Status == 0 || req.Status == 1 {
|
||||
return s.loginLogRepo.ListByStatus(ctx, req.Status, offset, req.PageSize)
|
||||
// 按状态查询(仅当明确指定了状态时才筛选)
|
||||
if req.Status != nil && (*req.Status == 0 || *req.Status == 1) {
|
||||
return s.loginLogRepo.ListByStatus(ctx, *req.Status, offset, req.PageSize)
|
||||
}
|
||||
|
||||
return s.loginLogRepo.List(ctx, offset, req.PageSize)
|
||||
}
|
||||
|
||||
// CursorResult wraps cursor-based pagination response
|
||||
type CursorResult struct {
|
||||
Items interface{} `json:"items"`
|
||||
NextCursor string `json:"next_cursor"`
|
||||
HasMore bool `json:"has_more"`
|
||||
PageSize int `json:"page_size"`
|
||||
}
|
||||
|
||||
// GetLoginLogsCursor 游标分页获取登录日志列表(推荐使用)
|
||||
func (s *LoginLogService) GetLoginLogsCursor(ctx context.Context, req *ListLoginLogRequest) (*CursorResult, error) {
|
||||
size := pagination.ClampPageSize(req.Size)
|
||||
if req.PageSize > 0 && req.Cursor == "" {
|
||||
size = pagination.ClampPageSize(req.PageSize)
|
||||
}
|
||||
|
||||
cursor, err := pagination.Decode(req.Cursor)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid cursor: %w", err)
|
||||
}
|
||||
|
||||
var items interface{}
|
||||
var nextCursor string
|
||||
var hasMore bool
|
||||
|
||||
// 按用户 ID 查询
|
||||
if req.UserID > 0 {
|
||||
logs, hm, err := s.loginLogRepo.ListByUserIDCursor(ctx, req.UserID, size, cursor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = logs
|
||||
hasMore = hm
|
||||
} else if req.StartAt != "" && req.EndAt != "" {
|
||||
// Time range: fall back to offset-based for now (cursor + time range is complex)
|
||||
start, err1 := time.Parse(time.RFC3339, req.StartAt)
|
||||
end, err2 := time.Parse(time.RFC3339, req.EndAt)
|
||||
if err1 == nil && err2 == nil {
|
||||
offset := 0
|
||||
logs, _, err := s.loginLogRepo.ListByTimeRange(ctx, start, end, offset, size)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = logs
|
||||
if len(logs) > 0 {
|
||||
last := logs[len(logs)-1]
|
||||
nextCursor = pagination.BuildNextCursor(last.ID, last.CreatedAt)
|
||||
hasMore = len(logs) == size
|
||||
}
|
||||
} else {
|
||||
items = []*domain.LoginLog{}
|
||||
}
|
||||
} else if req.Status != nil && (*req.Status == 0 || *req.Status == 1) {
|
||||
// Status filter: use ListCursor with manual status filter
|
||||
logs, hm, err := s.listByStatusCursor(ctx, *req.Status, size, cursor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = logs
|
||||
hasMore = hm
|
||||
} else {
|
||||
// Default: full table cursor scan
|
||||
logs, hm, err := s.loginLogRepo.ListCursor(ctx, size, cursor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = logs
|
||||
hasMore = hm
|
||||
}
|
||||
|
||||
// Build next cursor from the last item
|
||||
if nextCursor == "" {
|
||||
switch items := items.(type) {
|
||||
case []*domain.LoginLog:
|
||||
if len(items) > 0 {
|
||||
last := items[len(items)-1]
|
||||
nextCursor = pagination.BuildNextCursor(last.ID, last.CreatedAt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &CursorResult{
|
||||
Items: items,
|
||||
NextCursor: nextCursor,
|
||||
HasMore: hasMore,
|
||||
PageSize: size,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// listByStatusCursor 游标分页按状态查询(内部方法)
|
||||
// Uses iterative approach: fetch from ListCursor and post-filter by status.
|
||||
func (s *LoginLogService) listByStatusCursor(ctx context.Context, status int, limit int, cursor *pagination.Cursor) ([]*domain.LoginLog, bool, error) {
|
||||
var logs []*domain.LoginLog
|
||||
|
||||
// Since LoginLogRepository doesn't have status+cursor combined,
|
||||
// we use a larger batch from ListCursor and post-filter.
|
||||
batchSize := limit + 1
|
||||
for attempts := 0; attempts < 10; attempts++ { // max 10 pages of skipping
|
||||
batch, hm, err := s.loginLogRepo.ListCursor(ctx, batchSize, cursor)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
for _, log := range batch {
|
||||
if log.Status == status {
|
||||
logs = append(logs, log)
|
||||
if len(logs) >= limit+1 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(logs) >= limit+1 || !hm || len(batch) == 0 {
|
||||
break
|
||||
}
|
||||
// Advance cursor to end of this batch
|
||||
if len(batch) > 0 {
|
||||
last := batch[len(batch)-1]
|
||||
cursor = &pagination.Cursor{LastID: last.ID, LastValue: last.CreatedAt}
|
||||
}
|
||||
}
|
||||
|
||||
hasMore := len(logs) > limit
|
||||
if hasMore {
|
||||
logs = logs[:limit]
|
||||
}
|
||||
return logs, hasMore, nil
|
||||
}
|
||||
|
||||
// GetMyLoginLogs 获取当前用户的登录日志
|
||||
func (s *LoginLogService) GetMyLoginLogs(ctx context.Context, userID int64, page, pageSize int) ([]*domain.LoginLog, int64, error) {
|
||||
if page <= 0 {
|
||||
@@ -137,14 +267,21 @@ func (s *LoginLogService) ExportLoginLogs(ctx context.Context, req *ExportLoginL
|
||||
}
|
||||
}
|
||||
|
||||
// CSV 使用流式分批导出,XLSX 使用全量导出(excelize 需要所有行)
|
||||
if format == "csv" {
|
||||
data, filename, err := s.exportLoginLogsCSVStream(ctx, req.UserID, req.Status, startAt, endAt)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
return data, filename, "text/csv; charset=utf-8", nil
|
||||
}
|
||||
|
||||
logs, err := s.loginLogRepo.ListAllForExport(ctx, req.UserID, req.Status, startAt, endAt)
|
||||
if err != nil {
|
||||
return nil, "", "", fmt.Errorf("查询登录日志失败: %w", err)
|
||||
}
|
||||
|
||||
filename := fmt.Sprintf("login_logs_%s.%s", time.Now().Format("20060102_150405"), format)
|
||||
|
||||
if format == "xlsx" {
|
||||
filename := fmt.Sprintf("login_logs_%s.xlsx", time.Now().Format("20060102_150405"))
|
||||
data, err := buildLoginLogXLSXExport(logs)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
@@ -152,11 +289,66 @@ func (s *LoginLogService) ExportLoginLogs(ctx context.Context, req *ExportLoginL
|
||||
return data, filename, "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", nil
|
||||
}
|
||||
|
||||
data, err := buildLoginLogCSVExport(logs)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
// exportLoginLogsCSVStream 流式导出 CSV(分批处理防止 OOM)
|
||||
func (s *LoginLogService) exportLoginLogsCSVStream(ctx context.Context, userID int64, status int, startAt, endAt *time.Time) ([]byte, string, error) {
|
||||
headers := []string{"ID", "用户ID", "登录方式", "设备ID", "IP地址", "位置", "状态", "失败原因", "时间"}
|
||||
|
||||
var buf bytes.Buffer
|
||||
buf.Write([]byte{0xEF, 0xBB, 0xBF})
|
||||
writer := csv.NewWriter(&buf)
|
||||
|
||||
// 写入表头
|
||||
if err := writer.Write(headers); err != nil {
|
||||
return nil, "", fmt.Errorf("写CSV表头失败: %w", err)
|
||||
}
|
||||
return data, filename, "text/csv; charset=utf-8", nil
|
||||
|
||||
// 使用游标分批获取数据
|
||||
cursor := int64(1<<63 - 1) // 从最大 ID 开始
|
||||
batchSize := 5000
|
||||
totalWritten := 0
|
||||
|
||||
for {
|
||||
logs, hasMore, err := s.loginLogRepo.ListLogsForExportBatch(ctx, userID, status, startAt, endAt, cursor, batchSize)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("查询登录日志失败: %w", err)
|
||||
}
|
||||
|
||||
for _, log := range logs {
|
||||
row := []string{
|
||||
fmt.Sprintf("%d", log.ID),
|
||||
fmt.Sprintf("%d", derefInt64(log.UserID)),
|
||||
loginTypeLabel(log.LoginType),
|
||||
log.DeviceID,
|
||||
log.IP,
|
||||
log.Location,
|
||||
loginStatusLabel(log.Status),
|
||||
log.FailReason,
|
||||
log.CreatedAt.Format("2006-01-02 15:04:05"),
|
||||
}
|
||||
if err := writer.Write(row); err != nil {
|
||||
return nil, "", fmt.Errorf("写CSV行失败: %w", err)
|
||||
}
|
||||
totalWritten++
|
||||
cursor = log.ID
|
||||
}
|
||||
|
||||
writer.Flush()
|
||||
if err := writer.Error(); err != nil {
|
||||
return nil, "", fmt.Errorf("CSV Flush 失败: %w", err)
|
||||
}
|
||||
|
||||
// 如果数据量过大,提前终止
|
||||
if totalWritten >= repository.ExportBatchSize {
|
||||
break
|
||||
}
|
||||
|
||||
if !hasMore || len(logs) == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
filename := fmt.Sprintf("login_logs_%s.csv", time.Now().Format("20060102_150405"))
|
||||
return buf.Bytes(), filename, nil
|
||||
}
|
||||
|
||||
func buildLoginLogCSVExport(logs []*domain.LoginLog) ([]byte, error) {
|
||||
|
||||
@@ -2,9 +2,11 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/pagination"
|
||||
"github.com/user-management-system/internal/repository"
|
||||
)
|
||||
|
||||
@@ -51,13 +53,15 @@ type RecordOperationRequest struct {
|
||||
|
||||
// ListOperationLogRequest 操作日志列表请求
|
||||
type ListOperationLogRequest struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
Method string `json:"method"`
|
||||
Keyword string `json:"keyword"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
StartAt string `json:"start_at"`
|
||||
EndAt string `json:"end_at"`
|
||||
UserID int64 `json:"user_id" form:"user_id"`
|
||||
Method string `json:"method" form:"method"`
|
||||
Keyword string `json:"keyword" form:"keyword"`
|
||||
Page int `json:"page" form:"page"`
|
||||
PageSize int `json:"page_size" form:"page_size"`
|
||||
StartAt string `json:"start_at" form:"start_at"`
|
||||
EndAt string `json:"end_at" form:"end_at"`
|
||||
Cursor string `form:"cursor"` // Opaque cursor for keyset pagination
|
||||
Size int `form:"size"` // Page size when using cursor mode
|
||||
}
|
||||
|
||||
// GetOperationLogs 获取操作日志列表
|
||||
@@ -97,6 +101,42 @@ func (s *OperationLogService) GetOperationLogs(ctx context.Context, req *ListOpe
|
||||
return s.operationLogRepo.List(ctx, offset, req.PageSize)
|
||||
}
|
||||
|
||||
// GetOperationLogsCursor 游标分页获取操作日志列表(推荐使用)
|
||||
func (s *OperationLogService) GetOperationLogsCursor(ctx context.Context, req *ListOperationLogRequest) (*CursorResult, error) {
|
||||
size := pagination.ClampPageSize(req.Size)
|
||||
|
||||
cursor, err := pagination.Decode(req.Cursor)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid cursor: %w", err)
|
||||
}
|
||||
|
||||
var items interface{}
|
||||
var hasMore bool
|
||||
|
||||
logs, hm, err := s.operationLogRepo.ListCursor(ctx, size, cursor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = logs
|
||||
hasMore = hm
|
||||
|
||||
nextCursor := ""
|
||||
switch items := items.(type) {
|
||||
case []*domain.OperationLog:
|
||||
if len(items) > 0 {
|
||||
last := items[len(items)-1]
|
||||
nextCursor = pagination.BuildNextCursor(last.ID, last.CreatedAt)
|
||||
}
|
||||
}
|
||||
|
||||
return &CursorResult{
|
||||
Items: items,
|
||||
NextCursor: nextCursor,
|
||||
HasMore: hasMore,
|
||||
PageSize: size,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetMyOperationLogs 获取当前用户的操作日志
|
||||
func (s *OperationLogService) GetMyOperationLogs(ctx context.Context, userID int64, page, pageSize int) ([]*domain.OperationLog, int64, error) {
|
||||
if page <= 0 {
|
||||
|
||||
@@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"context"
|
||||
cryptorand "crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -13,6 +14,7 @@ import (
|
||||
"github.com/user-management-system/internal/auth"
|
||||
"github.com/user-management-system/internal/cache"
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/repository"
|
||||
"github.com/user-management-system/internal/security"
|
||||
)
|
||||
|
||||
@@ -49,6 +51,7 @@ type PasswordResetService struct {
|
||||
userRepo userRepositoryInterface
|
||||
cache *cache.CacheManager
|
||||
config *PasswordResetConfig
|
||||
passwordHistoryRepo *repository.PasswordHistoryRepository
|
||||
}
|
||||
|
||||
func NewPasswordResetService(
|
||||
@@ -66,6 +69,12 @@ func NewPasswordResetService(
|
||||
}
|
||||
}
|
||||
|
||||
// WithPasswordHistoryRepo 注入密码历史 repository,用于重置密码时记录历史
|
||||
func (s *PasswordResetService) WithPasswordHistoryRepo(repo *repository.PasswordHistoryRepository) *PasswordResetService {
|
||||
s.passwordHistoryRepo = repo
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *PasswordResetService) ForgotPassword(ctx context.Context, email string) error {
|
||||
user, err := s.userRepo.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
@@ -216,7 +225,7 @@ func (s *PasswordResetService) ResetPasswordByPhone(ctx context.Context, req *Re
|
||||
}
|
||||
|
||||
code, ok := storedCode.(string)
|
||||
if !ok || code != req.Code {
|
||||
if !ok || subtle.ConstantTimeCompare([]byte(code), []byte(req.Code)) != 1 {
|
||||
return errors.New("验证码不正确")
|
||||
}
|
||||
|
||||
@@ -258,6 +267,18 @@ func (s *PasswordResetService) doResetPassword(ctx context.Context, user *domain
|
||||
return err
|
||||
}
|
||||
|
||||
// 检查密码历史(防止重用近5次密码)
|
||||
if s.passwordHistoryRepo != nil {
|
||||
histories, err := s.passwordHistoryRepo.GetByUserID(ctx, user.ID, passwordHistoryLimit)
|
||||
if err == nil {
|
||||
for _, h := range histories {
|
||||
if auth.VerifyPassword(h.PasswordHash, newPassword) {
|
||||
return errors.New("新密码不能与最近5次密码相同")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
hashedPassword, err := auth.HashPassword(newPassword)
|
||||
if err != nil {
|
||||
return fmt.Errorf("密码加密失败: %w", err)
|
||||
@@ -268,5 +289,19 @@ func (s *PasswordResetService) doResetPassword(ctx context.Context, user *domain
|
||||
return fmt.Errorf("更新密码失败: %w", err)
|
||||
}
|
||||
|
||||
// 写入密码历史记录
|
||||
if s.passwordHistoryRepo != nil {
|
||||
go func() {
|
||||
// 使用带超时的独立 context,防止 DB 写入无限等待
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = s.passwordHistoryRepo.Create(bgCtx, &domain.PasswordHistory{
|
||||
UserID: user.ID,
|
||||
PasswordHash: hashedPassword,
|
||||
})
|
||||
_ = s.passwordHistoryRepo.DeleteOldRecords(bgCtx, user.ID, passwordHistoryLimit)
|
||||
}()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
92
internal/service/settings.go
Normal file
92
internal/service/settings.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// SystemSettings 系统设置
|
||||
type SystemSettings struct {
|
||||
System SystemInfo `json:"system"`
|
||||
Security SecurityInfo `json:"security"`
|
||||
Features FeaturesInfo `json:"features"`
|
||||
}
|
||||
|
||||
// SystemInfo 系统信息
|
||||
type SystemInfo struct {
|
||||
Name string `json:"name"`
|
||||
Version string `json:"version"`
|
||||
Environment string `json:"environment"`
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
// SecurityInfo 安全设置
|
||||
type SecurityInfo struct {
|
||||
PasswordMinLength int `json:"password_min_length"`
|
||||
PasswordRequireUppercase bool `json:"password_require_uppercase"`
|
||||
PasswordRequireLowercase bool `json:"password_require_lowercase"`
|
||||
PasswordRequireNumbers bool `json:"password_require_numbers"`
|
||||
PasswordRequireSymbols bool `json:"password_require_symbols"`
|
||||
PasswordHistory int `json:"password_history"`
|
||||
TOTPEnabled bool `json:"totp_enabled"`
|
||||
LoginFailLock bool `json:"login_fail_lock"`
|
||||
LoginFailThreshold int `json:"login_fail_threshold"`
|
||||
LoginFailDuration int `json:"login_fail_duration"` // 分钟
|
||||
SessionTimeout int `json:"session_timeout"` // 秒
|
||||
DeviceTrustDuration int `json:"device_trust_duration"` // 秒
|
||||
}
|
||||
|
||||
// FeaturesInfo 功能开关
|
||||
type FeaturesInfo struct {
|
||||
EmailVerification bool `json:"email_verification"`
|
||||
PhoneVerification bool `json:"phone_verification"`
|
||||
OAuthProviders []string `json:"oauth_providers"`
|
||||
SSOEnabled bool `json:"sso_enabled"`
|
||||
OperationLogEnabled bool `json:"operation_log_enabled"`
|
||||
LoginLogEnabled bool `json:"login_log_enabled"`
|
||||
DataExportEnabled bool `json:"data_export_enabled"`
|
||||
DataImportEnabled bool `json:"data_import_enabled"`
|
||||
}
|
||||
|
||||
// SettingsService 系统设置服务
|
||||
type SettingsService struct{}
|
||||
|
||||
// NewSettingsService 创建系统设置服务
|
||||
func NewSettingsService() *SettingsService {
|
||||
return &SettingsService{}
|
||||
}
|
||||
|
||||
// GetSettings 获取系统设置
|
||||
func (s *SettingsService) GetSettings(ctx context.Context) (*SystemSettings, error) {
|
||||
return &SystemSettings{
|
||||
System: SystemInfo{
|
||||
Name: "用户管理系统",
|
||||
Version: "1.0.0",
|
||||
Environment: "Production",
|
||||
Description: "基于 Go + React 的现代化用户管理系统",
|
||||
},
|
||||
Security: SecurityInfo{
|
||||
PasswordMinLength: 8,
|
||||
PasswordRequireUppercase: true,
|
||||
PasswordRequireLowercase: true,
|
||||
PasswordRequireNumbers: true,
|
||||
PasswordRequireSymbols: true,
|
||||
PasswordHistory: 5,
|
||||
TOTPEnabled: true,
|
||||
LoginFailLock: true,
|
||||
LoginFailThreshold: 5,
|
||||
LoginFailDuration: 30,
|
||||
SessionTimeout: 86400, // 1天
|
||||
DeviceTrustDuration: 2592000, // 30天
|
||||
},
|
||||
Features: FeaturesInfo{
|
||||
EmailVerification: true,
|
||||
PhoneVerification: false,
|
||||
OAuthProviders: []string{"GitHub", "Google"},
|
||||
SSOEnabled: false,
|
||||
OperationLogEnabled: true,
|
||||
LoginLogEnabled: true,
|
||||
DataExportEnabled: true,
|
||||
DataImportEnabled: true,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
308
internal/service/settings_test.go
Normal file
308
internal/service/settings_test.go
Normal file
@@ -0,0 +1,308 @@
|
||||
package service_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/user-management-system/internal/api/handler"
|
||||
"github.com/user-management-system/internal/api/middleware"
|
||||
"github.com/user-management-system/internal/api/router"
|
||||
"github.com/user-management-system/internal/auth"
|
||||
"github.com/user-management-system/internal/cache"
|
||||
"github.com/user-management-system/internal/config"
|
||||
"github.com/user-management-system/internal/repository"
|
||||
"github.com/user-management-system/internal/service"
|
||||
"github.com/user-management-system/internal/domain"
|
||||
gormsqlite "gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
// doRequest makes an HTTP request with optional body
|
||||
func doRequest(method, url string, token string, body interface{}) (*http.Response, string) {
|
||||
var bodyReader io.Reader
|
||||
if body != nil {
|
||||
jsonBytes, _ := json.Marshal(body)
|
||||
bodyReader = bytes.NewReader(jsonBytes)
|
||||
}
|
||||
req, _ := http.NewRequest(method, url, bodyReader)
|
||||
if token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
client := &http.Client{}
|
||||
resp, _ := client.Do(req)
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
return resp, string(bodyBytes)
|
||||
}
|
||||
|
||||
func doPost(url, token string, body interface{}) (*http.Response, string) {
|
||||
return doRequest("POST", url, token, body)
|
||||
}
|
||||
|
||||
func doGet(url, token string) (*http.Response, string) {
|
||||
return doRequest("GET", url, token, nil)
|
||||
}
|
||||
|
||||
func setupSettingsTestServer(t *testing.T) (*httptest.Server, *service.SettingsService, string, func()) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// 使用内存 SQLite
|
||||
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
|
||||
DriverName: "sqlite",
|
||||
DSN: "file::memory:?mode=memory&cache=shared",
|
||||
}), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
t.Skipf("skipping test (SQLite unavailable): %v", err)
|
||||
return nil, nil, "", func() {}
|
||||
}
|
||||
|
||||
// 自动迁移
|
||||
if err := db.AutoMigrate(
|
||||
&domain.User{},
|
||||
&domain.Role{},
|
||||
&domain.Permission{},
|
||||
&domain.UserRole{},
|
||||
&domain.RolePermission{},
|
||||
&domain.Device{},
|
||||
&domain.LoginLog{},
|
||||
&domain.OperationLog{},
|
||||
&domain.SocialAccount{},
|
||||
&domain.Webhook{},
|
||||
&domain.WebhookDelivery{},
|
||||
); err != nil {
|
||||
t.Fatalf("db migration failed: %v", err)
|
||||
}
|
||||
|
||||
// 创建 JWT Manager
|
||||
jwtManager, err := auth.NewJWTWithOptions(auth.JWTOptions{
|
||||
HS256Secret: "test-settings-secret-key",
|
||||
AccessTokenExpire: 15 * time.Minute,
|
||||
RefreshTokenExpire: 7 * 24 * time.Hour,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("create jwt manager failed: %v", err)
|
||||
}
|
||||
|
||||
// 创建缓存
|
||||
l1Cache := cache.NewL1Cache()
|
||||
l2Cache := cache.NewRedisCache(false)
|
||||
cacheManager := cache.NewCacheManager(l1Cache, l2Cache)
|
||||
|
||||
// 创建 repositories
|
||||
userRepo := repository.NewUserRepository(db)
|
||||
roleRepo := repository.NewRoleRepository(db)
|
||||
permissionRepo := repository.NewPermissionRepository(db)
|
||||
userRoleRepo := repository.NewUserRoleRepository(db)
|
||||
rolePermissionRepo := repository.NewRolePermissionRepository(db)
|
||||
deviceRepo := repository.NewDeviceRepository(db)
|
||||
loginLogRepo := repository.NewLoginLogRepository(db)
|
||||
opLogRepo := repository.NewOperationLogRepository(db)
|
||||
passwordHistoryRepo := repository.NewPasswordHistoryRepository(db)
|
||||
|
||||
// 创建 services
|
||||
authSvc := service.NewAuthService(userRepo, nil, jwtManager, cacheManager, 8, 5, 15*time.Minute)
|
||||
authSvc.SetRoleRepositories(userRoleRepo, roleRepo)
|
||||
userSvc := service.NewUserService(userRepo, userRoleRepo, roleRepo, passwordHistoryRepo)
|
||||
roleSvc := service.NewRoleService(roleRepo, rolePermissionRepo)
|
||||
permSvc := service.NewPermissionService(permissionRepo)
|
||||
deviceSvc := service.NewDeviceService(deviceRepo, userRepo)
|
||||
loginLogSvc := service.NewLoginLogService(loginLogRepo)
|
||||
opLogSvc := service.NewOperationLogService(opLogRepo)
|
||||
|
||||
// 创建 SettingsService
|
||||
settingsService := service.NewSettingsService()
|
||||
|
||||
// 创建 middleware
|
||||
rateLimitCfg := config.RateLimitConfig{}
|
||||
rateLimitMiddleware := middleware.NewRateLimitMiddleware(rateLimitCfg)
|
||||
authMiddleware := middleware.NewAuthMiddleware(
|
||||
jwtManager, userRepo, userRoleRepo, roleRepo, rolePermissionRepo, permissionRepo, l1Cache,
|
||||
)
|
||||
authMiddleware.SetCacheManager(cacheManager)
|
||||
opLogMiddleware := middleware.NewOperationLogMiddleware(opLogRepo)
|
||||
|
||||
// 创建 handlers
|
||||
authHandler := handler.NewAuthHandler(authSvc)
|
||||
userHandler := handler.NewUserHandler(userSvc)
|
||||
roleHandler := handler.NewRoleHandler(roleSvc)
|
||||
permHandler := handler.NewPermissionHandler(permSvc)
|
||||
deviceHandler := handler.NewDeviceHandler(deviceSvc)
|
||||
logHandler := handler.NewLogHandler(loginLogSvc, opLogSvc)
|
||||
settingsHandler := handler.NewSettingsHandler(settingsService)
|
||||
|
||||
// 创建 router - 22个handler参数(含 metrics)+ variadic avatarHandler
|
||||
r := router.NewRouter(
|
||||
authHandler, userHandler, roleHandler, permHandler, deviceHandler,
|
||||
logHandler, authMiddleware, rateLimitMiddleware, opLogMiddleware,
|
||||
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||
nil,
|
||||
settingsHandler, nil,
|
||||
)
|
||||
engine := r.Setup()
|
||||
|
||||
server := httptest.NewServer(engine)
|
||||
|
||||
// 注册用户用于测试
|
||||
resp, _ := doPost(server.URL+"/api/v1/auth/register", "", map[string]interface{}{
|
||||
"username": "admintestsu",
|
||||
"email": "admintestsu@test.com",
|
||||
"password": "Password123!",
|
||||
})
|
||||
resp.Body.Close()
|
||||
|
||||
// 获取 token
|
||||
loginResp, _ := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
|
||||
"account": "admintestsu",
|
||||
"password": "Password123!",
|
||||
})
|
||||
|
||||
var result map[string]interface{}
|
||||
json.NewDecoder(loginResp.Body).Decode(&result)
|
||||
loginResp.Body.Close()
|
||||
|
||||
token := ""
|
||||
if data, ok := result["data"].(map[string]interface{}); ok {
|
||||
token, _ = data["access_token"].(string)
|
||||
}
|
||||
|
||||
return server, settingsService, token, func() {
|
||||
server.Close()
|
||||
if sqlDB, _ := db.DB(); sqlDB != nil {
|
||||
sqlDB.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Settings API Tests
|
||||
// =============================================================================
|
||||
|
||||
func TestGetSettings_Success(t *testing.T) {
|
||||
// 仅测试 service 层,不测试 HTTP API
|
||||
svc := service.NewSettingsService()
|
||||
settings, err := svc.GetSettings(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("GetSettings failed: %v", err)
|
||||
}
|
||||
|
||||
if settings.System.Name != "用户管理系统" {
|
||||
t.Errorf("expected system name '用户管理系统', got '%s'", settings.System.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSettings_Unauthorized(t *testing.T) {
|
||||
server, _, _, cleanup := setupSettingsTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
req, _ := http.NewRequest("GET", server.URL+"/api/v1/admin/settings", nil)
|
||||
// 不设置 Authorization header
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 无 token 应该返回 401
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Errorf("expected status 401, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSettings_ResponseStructure(t *testing.T) {
|
||||
// 仅测试 service 层数据结构
|
||||
svc := service.NewSettingsService()
|
||||
settings, err := svc.GetSettings(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("GetSettings failed: %v", err)
|
||||
}
|
||||
|
||||
// 验证 system 字段
|
||||
if settings.System.Name == "" {
|
||||
t.Error("System.Name should not be empty")
|
||||
}
|
||||
if settings.System.Version == "" {
|
||||
t.Error("System.Version should not be empty")
|
||||
}
|
||||
if settings.System.Environment == "" {
|
||||
t.Error("System.Environment should not be empty")
|
||||
}
|
||||
|
||||
// 验证 security 字段
|
||||
if settings.Security.PasswordMinLength == 0 {
|
||||
t.Error("Security.PasswordMinLength should not be zero")
|
||||
}
|
||||
if !settings.Security.PasswordRequireUppercase {
|
||||
t.Error("Security.PasswordRequireUppercase should be true")
|
||||
}
|
||||
|
||||
// 验证 features 字段
|
||||
if !settings.Features.EmailVerification {
|
||||
t.Error("Features.EmailVerification should be true")
|
||||
}
|
||||
if len(settings.Features.OAuthProviders) == 0 {
|
||||
t.Error("Features.OAuthProviders should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// SettingsService Unit Tests
|
||||
// =============================================================================
|
||||
|
||||
func TestSettingsService_GetSettings(t *testing.T) {
|
||||
svc := service.NewSettingsService()
|
||||
|
||||
settings, err := svc.GetSettings(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("GetSettings failed: %v", err)
|
||||
}
|
||||
|
||||
// 验证 system
|
||||
if settings.System.Name == "" {
|
||||
t.Error("System.Name should not be empty")
|
||||
}
|
||||
if settings.System.Version == "" {
|
||||
t.Error("System.Version should not be empty")
|
||||
}
|
||||
|
||||
// 验证 security defaults
|
||||
if settings.Security.PasswordMinLength != 8 {
|
||||
t.Errorf("PasswordMinLength: got %d, want 8", settings.Security.PasswordMinLength)
|
||||
}
|
||||
if !settings.Security.PasswordRequireUppercase {
|
||||
t.Error("PasswordRequireUppercase should be true")
|
||||
}
|
||||
if !settings.Security.PasswordRequireLowercase {
|
||||
t.Error("PasswordRequireLowercase should be true")
|
||||
}
|
||||
if !settings.Security.PasswordRequireNumbers {
|
||||
t.Error("PasswordRequireNumbers should be true")
|
||||
}
|
||||
if !settings.Security.PasswordRequireSymbols {
|
||||
t.Error("PasswordRequireSymbols should be true")
|
||||
}
|
||||
if settings.Security.PasswordHistory != 5 {
|
||||
t.Errorf("PasswordHistory: got %d, want 5", settings.Security.PasswordHistory)
|
||||
}
|
||||
|
||||
// 验证 features defaults
|
||||
if !settings.Features.EmailVerification {
|
||||
t.Error("EmailVerification should be true")
|
||||
}
|
||||
if settings.Features.DataExportEnabled != true {
|
||||
t.Error("DataExportEnabled should be true")
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"context"
|
||||
cryptorand "crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
@@ -357,7 +358,7 @@ func (s *SMSCodeService) VerifyCode(ctx context.Context, phone, purpose, code st
|
||||
}
|
||||
|
||||
stored, ok := val.(string)
|
||||
if !ok || stored != code {
|
||||
if !ok || subtle.ConstantTimeCompare([]byte(stored), []byte(code)) != 1 {
|
||||
return fmt.Errorf("\u9a8c\u8bc1\u7801\u4e0d\u6b63\u786e")
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"regexp"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/repository"
|
||||
@@ -48,6 +49,11 @@ type UpdateThemeRequest struct {
|
||||
|
||||
// CreateTheme 创建主题
|
||||
func (s *ThemeService) CreateTheme(ctx context.Context, req *CreateThemeRequest) (*domain.ThemeConfig, error) {
|
||||
// 安全检查:禁止在 CustomCSS/CustomJS 中包含危险模式
|
||||
if err := validateCustomCSSJS(req.CustomCSS, req.CustomJS); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 检查主题名称是否已存在
|
||||
existing, err := s.themeRepo.GetByName(ctx, req.Name)
|
||||
if err == nil && existing != nil {
|
||||
@@ -84,6 +90,11 @@ func (s *ThemeService) CreateTheme(ctx context.Context, req *CreateThemeRequest)
|
||||
|
||||
// UpdateTheme 更新主题
|
||||
func (s *ThemeService) UpdateTheme(ctx context.Context, id int64, req *UpdateThemeRequest) (*domain.ThemeConfig, error) {
|
||||
// 安全检查:禁止在 CustomCSS/CustomJS 中包含危险模式
|
||||
if err := validateCustomCSSJS(req.CustomCSS, req.CustomJS); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
theme, err := s.themeRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, errors.New("主题不存在")
|
||||
@@ -204,3 +215,43 @@ func (s *ThemeService) clearDefaultThemes(ctx context.Context) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateCustomCSSJS 检查 CustomCSS 和 CustomJS 是否包含危险 XSS 模式
|
||||
// 这不是完全净化,而是拒绝明显可造成 XSS 的模式
|
||||
func validateCustomCSSJS(css, js string) error {
|
||||
// 危险模式列表
|
||||
dangerousPatterns := []struct {
|
||||
pattern *regexp.Regexp
|
||||
message string
|
||||
}{
|
||||
// Script 标签
|
||||
{regexp.MustCompile(`(?i)<script[^>]*>.*?</script>`), "CustomJS 禁止包含 <script> 标签"},
|
||||
{regexp.MustCompile(`(?i)javascript\s*:`), "CustomJS 禁止使用 javascript: 协议"},
|
||||
// 事件处理器
|
||||
{regexp.MustCompile(`(?i)on\w+\s*=`), "CustomJS 禁止使用事件处理器 (如 onerror, onclick)"},
|
||||
// Data URL
|
||||
{regexp.MustCompile(`(?i)data\s*:\s*text/html`), "禁止使用 data: URL 嵌入 HTML"},
|
||||
// CSS expression (IE)
|
||||
{regexp.MustCompile(`(?i)expression\s*\(`), "CustomCSS 禁止使用 CSS expression"},
|
||||
// CSS 中的 javascript
|
||||
{regexp.MustCompile(`(?i)url\s*\(\s*['"]?\s*javascript:`), "CustomCSS 禁止使用 javascript: URL"},
|
||||
// 嵌入的 <style> 标签
|
||||
{regexp.MustCompile(`(?i)<style[^>]*>.*?</style>`), "CustomCSS 禁止包含 <style> 标签"},
|
||||
}
|
||||
|
||||
// 检查 JS
|
||||
for _, p := range dangerousPatterns {
|
||||
if p.pattern.MatchString(js) {
|
||||
return errors.New(p.message)
|
||||
}
|
||||
}
|
||||
|
||||
// 检查 CSS
|
||||
for _, p := range dangerousPatterns {
|
||||
if p.pattern.MatchString(css) {
|
||||
return errors.New(p.message)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -3,10 +3,13 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/user-management-system/internal/auth"
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/pagination"
|
||||
"github.com/user-management-system/internal/repository"
|
||||
)
|
||||
|
||||
@@ -80,11 +83,14 @@ func (s *UserService) ChangePassword(ctx context.Context, userID int64, oldPassw
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = s.passwordHistoryRepo.Create(context.Background(), &domain.PasswordHistory{
|
||||
// 使用带超时的独立 context(不能使用请求 ctx,该 goroutine 在请求完成后仍可能运行)
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = s.passwordHistoryRepo.Create(bgCtx, &domain.PasswordHistory{
|
||||
UserID: userID,
|
||||
PasswordHash: newHashedPassword,
|
||||
})
|
||||
_ = s.passwordHistoryRepo.DeleteOldRecords(context.Background(), userID, passwordHistoryLimit)
|
||||
_ = s.passwordHistoryRepo.DeleteOldRecords(bgCtx, userID, passwordHistoryLimit)
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -127,6 +133,57 @@ func (s *UserService) List(ctx context.Context, offset, limit int) ([]*domain.Us
|
||||
return s.userRepo.List(ctx, offset, limit)
|
||||
}
|
||||
|
||||
// ListCursorRequest 用户游标分页请求
|
||||
type ListCursorRequest struct {
|
||||
Keyword string `form:"keyword"`
|
||||
Status int `form:"status"` // -1=全部
|
||||
RoleIDs []int64
|
||||
CreatedFrom *time.Time
|
||||
CreatedTo *time.Time
|
||||
SortBy string // created_at, last_login_time, username
|
||||
SortOrder string // asc, desc
|
||||
Cursor string `form:"cursor"`
|
||||
Size int `form:"size"`
|
||||
}
|
||||
|
||||
// ListCursor 游标分页获取用户列表(推荐使用)
|
||||
func (s *UserService) ListCursor(ctx context.Context, req *ListCursorRequest) (*CursorResult, error) {
|
||||
size := pagination.ClampPageSize(req.Size)
|
||||
|
||||
cursor, err := pagination.Decode(req.Cursor)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid cursor: %w", err)
|
||||
}
|
||||
|
||||
filter := &repository.AdvancedFilter{
|
||||
Keyword: req.Keyword,
|
||||
Status: req.Status,
|
||||
RoleIDs: req.RoleIDs,
|
||||
CreatedFrom: req.CreatedFrom,
|
||||
CreatedTo: req.CreatedTo,
|
||||
SortBy: req.SortBy,
|
||||
SortOrder: req.SortOrder,
|
||||
}
|
||||
|
||||
users, hasMore, err := s.userRepo.ListCursor(ctx, filter, size, cursor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nextCursor := ""
|
||||
if len(users) > 0 {
|
||||
last := users[len(users)-1]
|
||||
nextCursor = pagination.BuildNextCursor(last.ID, last.CreatedAt)
|
||||
}
|
||||
|
||||
return &CursorResult{
|
||||
Items: users,
|
||||
NextCursor: nextCursor,
|
||||
HasMore: hasMore,
|
||||
PageSize: size,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UpdateStatus 更新用户状态
|
||||
func (s *UserService) UpdateStatus(ctx context.Context, id int64, status domain.UserStatus) error {
|
||||
return s.userRepo.UpdateStatus(ctx, id, status)
|
||||
|
||||
Reference in New Issue
Block a user