diff --git a/internal/api/handler/auth_handler.go b/internal/api/handler/auth_handler.go index 3a0ae7b..81c00d2 100644 --- a/internal/api/handler/auth_handler.go +++ b/internal/api/handler/auth_handler.go @@ -759,6 +759,15 @@ func getUserIDFromContext(c *gin.Context) (int64, bool) { return id, ok } +func getUsernameFromContext(c *gin.Context) (string, bool) { + username, exists := c.Get("username") + if !exists { + return "", false + } + usernameStr, ok := username.(string) + return usernameStr, ok +} + // handleError 将 error 转换为对应的 HTTP 响应。 // 优先识别 ApplicationError,其次通过关键词推断业务错误类型,兜底返回 500。 func handleError(c *gin.Context, err error) { diff --git a/internal/api/handler/context_guard_test.go b/internal/api/handler/context_guard_test.go new file mode 100644 index 0000000..b5352a2 --- /dev/null +++ b/internal/api/handler/context_guard_test.go @@ -0,0 +1,95 @@ +package handler + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" +) + +func init() { + gin.SetMode(gin.TestMode) +} + +func TestSSOHandlerAuthorize_InvalidContextTypes_ReturnsUnauthorized(t *testing.T) { + h := &SSOHandler{} + engine := gin.New() + engine.GET("/authorize", func(c *gin.Context) { + c.Set("user_id", "not-int64") + c.Set("username", 123) + h.Authorize(c) + }) + + req := httptest.NewRequest(http.MethodGet, "/authorize?client_id=test-client&redirect_uri=https://example.com/callback&response_type=code", nil) + w := httptest.NewRecorder() + engine.ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", w.Code) + } +} + +func TestSSOHandlerUserInfo_InvalidContextTypes_ReturnsUnauthorized(t *testing.T) { + h := &SSOHandler{} + engine := gin.New() + engine.GET("/userinfo", func(c *gin.Context) { + c.Set("user_id", "not-int64") + c.Set("username", 123) + h.UserInfo(c) + }) + + req := httptest.NewRequest(http.MethodGet, "/userinfo", nil) + w := httptest.NewRecorder() + engine.ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", w.Code) + } +} + +func TestWebhookHandlerCreateWebhook_InvalidContextType_ReturnsUnauthorized(t *testing.T) { + h := &WebhookHandler{} + engine := gin.New() + engine.POST("/webhooks", func(c *gin.Context) { + c.Set("user_id", "not-int64") + h.CreateWebhook(c) + }) + + body, err := json.Marshal(map[string]any{ + "name": "test", + "url": "https://example.com/webhook", + "events": []string{"user.created"}, + }) + if err != nil { + t.Fatalf("marshal request: %v", err) + } + + req := httptest.NewRequest(http.MethodPost, "/webhooks", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + engine.ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", w.Code) + } +} + +func TestWebhookHandlerListWebhooks_InvalidContextType_ReturnsUnauthorized(t *testing.T) { + h := &WebhookHandler{} + engine := gin.New() + engine.GET("/webhooks", func(c *gin.Context) { + c.Set("user_id", "not-int64") + h.ListWebhooks(c) + }) + + req := httptest.NewRequest(http.MethodGet, "/webhooks?page=1&page_size=20", nil) + w := httptest.NewRecorder() + engine.ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", w.Code) + } +} diff --git a/internal/api/handler/sso_handler.go b/internal/api/handler/sso_handler.go index 81c3076..a510628 100644 --- a/internal/api/handler/sso_handler.go +++ b/internal/api/handler/sso_handler.go @@ -72,13 +72,17 @@ func (h *SSOHandler) Authorize(c *gin.Context) { } // 获取当前登录用户(从 auth middleware 设置的 context) - userID, exists := c.Get("user_id") - if !exists { + userID, ok := getUserIDFromContext(c) + if !ok { c.JSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "unauthorized"}) return } - username, _ := c.Get("username") + username, ok := getUsernameFromContext(c) + if !ok { + c.JSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "unauthorized"}) + return + } // 生成授权码或 access token if req.ResponseType == "code" { @@ -86,8 +90,8 @@ func (h *SSOHandler) Authorize(c *gin.Context) { req.ClientID, req.RedirectURI, req.Scope, - userID.(int64), - username.(string), + userID, + username, ) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"code": 500, "message": "failed to generate code"}) @@ -106,8 +110,8 @@ func (h *SSOHandler) Authorize(c *gin.Context) { req.ClientID, req.RedirectURI, req.Scope, - userID.(int64), - username.(string), + userID, + username, ) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"code": 500, "message": "failed to generate code"}) @@ -312,20 +316,24 @@ type UserInfoResponse struct { // @Failure 500 {object} Response "服务器错误" // @Router /api/v1/sso/userinfo [get] func (h *SSOHandler) UserInfo(c *gin.Context) { - userID, exists := c.Get("user_id") - if !exists { + userID, ok := getUserIDFromContext(c) + if !ok { c.JSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "unauthorized"}) return } - username, _ := c.Get("username") + username, ok := getUsernameFromContext(c) + if !ok { + c.JSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "unauthorized"}) + return + } c.JSON(http.StatusOK, gin.H{ "code": 0, "message": "success", "data": UserInfoResponse{ - UserID: userID.(int64), - Username: username.(string), + UserID: userID, + Username: username, }, }) } diff --git a/internal/api/handler/webhook_handler.go b/internal/api/handler/webhook_handler.go index 7c22f67..27c498a 100644 --- a/internal/api/handler/webhook_handler.go +++ b/internal/api/handler/webhook_handler.go @@ -40,8 +40,11 @@ func (h *WebhookHandler) CreateWebhook(c *gin.Context) { return } - userID, _ := c.Get("user_id") - creatorID, _ := userID.(int64) + creatorID, ok := getUserIDFromContext(c) + if !ok { + c.JSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "unauthorized"}) + return + } webhook, err := h.webhookService.CreateWebhook(c.Request.Context(), &req, creatorID) if err != nil { @@ -76,8 +79,11 @@ func (h *WebhookHandler) ListWebhooks(c *gin.Context) { } offset := (page - 1) * pageSize - userID, _ := c.Get("user_id") - creatorID, _ := userID.(int64) + creatorID, ok := getUserIDFromContext(c) + if !ok { + c.JSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "unauthorized"}) + return + } webhooks, total, err := h.webhookService.ListWebhooksPaginated(c.Request.Context(), creatorID, offset, pageSize) if err != nil { diff --git a/internal/api/middleware/ratelimit.go b/internal/api/middleware/ratelimit.go index 2e22420..e3b8cad 100644 --- a/internal/api/middleware/ratelimit.go +++ b/internal/api/middleware/ratelimit.go @@ -1,6 +1,7 @@ package middleware import ( + "fmt" "os" "sync" "time" @@ -10,11 +11,20 @@ import ( ) // RateLimitMiddleware 限流中间件 +// 使用 endpoint + subject(IP 或 user_id) 作为限流键,并对空闲条目做 TTL 清理, +// 避免单一全局限流器误伤所有用户,也避免历史客户端条目无限增长。 type RateLimitMiddleware struct { - cfg config.RateLimitConfig - limiters map[string]*SlidingWindowLimiter - mu sync.RWMutex - cleanupInt time.Duration + cfg config.RateLimitConfig + limiters map[string]*limiterEntry + mu sync.RWMutex + cleanupInt time.Duration + lastCleanup time.Time +} + +type limiterEntry struct { + limiter *SlidingWindowLimiter + window time.Duration + lastSeen time.Time } // SlidingWindowLimiter 滑动窗口限流器 @@ -43,7 +53,7 @@ func (l *SlidingWindowLimiter) Allow() bool { cutoff := now - l.window.Milliseconds() // 清理过期请求 - var validRequests []int64 + validRequests := l.requests[:0] for _, t := range l.requests { if t > cutoff { validRequests = append(validRequests, t) @@ -63,9 +73,10 @@ func (l *SlidingWindowLimiter) Allow() bool { // NewRateLimitMiddleware 创建限流中间件 func NewRateLimitMiddleware(cfg config.RateLimitConfig) *RateLimitMiddleware { return &RateLimitMiddleware{ - cfg: cfg, - limiters: make(map[string]*SlidingWindowLimiter), - cleanupInt: 5 * time.Minute, + cfg: cfg, + limiters: make(map[string]*limiterEntry), + cleanupInt: 5 * time.Minute, + lastCleanup: time.Now(), } } @@ -89,16 +100,18 @@ func (m *RateLimitMiddleware) Refresh() gin.HandlerFunc { return m.limitForKey("refresh", 60, 10) } -func (m *RateLimitMiddleware) limitForKey(key string, windowSeconds int, capacity int64) gin.HandlerFunc { +func (m *RateLimitMiddleware) limitForKey(scope string, windowSeconds int, capacity int64) gin.HandlerFunc { if os.Getenv("DISABLE_RATE_LIMIT") == "1" { return func(c *gin.Context) { c.Next() } } - limiter := m.getOrCreateLimiter(key, time.Duration(windowSeconds)*time.Second, capacity) + window := time.Duration(windowSeconds) * time.Second return func(c *gin.Context) { + limiterKey := m.buildLimiterKey(scope, c) + limiter := m.getOrCreateLimiter(limiterKey, window, capacity) if !limiter.Allow() { c.JSON(429, gin.H{ "code": 429, @@ -111,24 +124,60 @@ func (m *RateLimitMiddleware) limitForKey(key string, windowSeconds int, capacit } } -func (m *RateLimitMiddleware) getOrCreateLimiter(key string, window time.Duration, capacity int64) *SlidingWindowLimiter { - m.mu.RLock() - limiter, exists := m.limiters[key] - m.mu.RUnlock() +func (m *RateLimitMiddleware) buildLimiterKey(scope string, c *gin.Context) string { + if userID, ok := c.Get("user_id"); ok { + return fmt.Sprintf("%s:user:%v", scope, userID) + } + return fmt.Sprintf("%s:ip:%s", scope, c.ClientIP()) +} +func (m *RateLimitMiddleware) getOrCreateLimiter(key string, window time.Duration, capacity int64) *SlidingWindowLimiter { + now := time.Now() + m.maybeCleanup(now) + + m.mu.RLock() + entry, exists := m.limiters[key] + m.mu.RUnlock() if exists { - return limiter + m.mu.Lock() + entry.lastSeen = now + m.mu.Unlock() + return entry.limiter } m.mu.Lock() defer m.mu.Unlock() - // 双重检查 - if limiter, exists = m.limiters[key]; exists { - return limiter + if entry, exists = m.limiters[key]; exists { + entry.lastSeen = now + return entry.limiter } - limiter = NewSlidingWindowLimiter(window, capacity) - m.limiters[key] = limiter - return limiter + entry = &limiterEntry{ + limiter: NewSlidingWindowLimiter(window, capacity), + window: window, + lastSeen: now, + } + m.limiters[key] = entry + return entry.limiter +} + +func (m *RateLimitMiddleware) maybeCleanup(now time.Time) { + m.mu.Lock() + defer m.mu.Unlock() + + if now.Sub(m.lastCleanup) < m.cleanupInt { + return + } + + for key, entry := range m.limiters { + idleTTL := entry.window + if idleTTL < m.cleanupInt { + idleTTL = m.cleanupInt + } + if now.Sub(entry.lastSeen) > idleTTL { + delete(m.limiters, key) + } + } + m.lastCleanup = now } diff --git a/internal/api/middleware/ratelimit_test.go b/internal/api/middleware/ratelimit_test.go new file mode 100644 index 0000000..1646ad4 --- /dev/null +++ b/internal/api/middleware/ratelimit_test.go @@ -0,0 +1,107 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/user-management-system/internal/config" +) + +func init() { + gin.SetMode(gin.TestMode) +} + +func newRateLimitTestEngine(mw gin.HandlerFunc) *gin.Engine { + engine := gin.New() + engine.Use(mw) + engine.GET("/ping", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"ok": true}) + }) + return engine +} + +func performRateLimitRequest(engine *gin.Engine, remoteAddr string, setup func(*http.Request)) int { + req := httptest.NewRequest(http.MethodGet, "/ping", nil) + req.RemoteAddr = remoteAddr + if setup != nil { + setup(req) + } + w := httptest.NewRecorder() + engine.ServeHTTP(w, req) + return w.Code +} + +func TestRateLimitMiddleware_LoginUsesIndependentIPBuckets(t *testing.T) { + mw := NewRateLimitMiddleware(config.RateLimitConfig{}) + engine := newRateLimitTestEngine(mw.Login()) + + for i := 0; i < 5; i++ { + if code := performRateLimitRequest(engine, "1.1.1.1:1234", nil); code != http.StatusOK { + t.Fatalf("ip1 request %d expected 200, got %d", i+1, code) + } + } + if code := performRateLimitRequest(engine, "1.1.1.1:1234", nil); code != http.StatusTooManyRequests { + t.Fatalf("ip1 sixth request expected 429, got %d", code) + } + + if code := performRateLimitRequest(engine, "2.2.2.2:1234", nil); code != http.StatusOK { + t.Fatalf("independent ip should not be throttled, got %d", code) + } +} + +func TestRateLimitMiddleware_APIPrefersUserIDOverSharedIP(t *testing.T) { + mw := NewRateLimitMiddleware(config.RateLimitConfig{}) + engine := gin.New() + engine.Use(func(c *gin.Context) { + if userID := c.GetHeader("X-Test-User-ID"); userID != "" { + c.Set("user_id", userID) + } + c.Next() + }) + engine.Use(mw.limitForKey("api-test", 60, 1)) + engine.GET("/ping", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"ok": true}) + }) + + setupUser1 := func(req *http.Request) { + req.Header.Set("X-Test-User-ID", "101") + } + setupUser2 := func(req *http.Request) { + req.Header.Set("X-Test-User-ID", "202") + } + + if code := performRateLimitRequest(engine, "9.9.9.9:1234", setupUser1); code != http.StatusOK { + t.Fatalf("user1 first request expected 200, got %d", code) + } + if code := performRateLimitRequest(engine, "9.9.9.9:1234", setupUser1); code != http.StatusTooManyRequests { + t.Fatalf("user1 second request expected 429, got %d", code) + } + if code := performRateLimitRequest(engine, "9.9.9.9:1234", setupUser2); code != http.StatusOK { + t.Fatalf("user2 should have independent bucket on shared ip, got %d", code) + } +} + +func TestRateLimitMiddleware_CleansUpIdleLimiters(t *testing.T) { + mw := NewRateLimitMiddleware(config.RateLimitConfig{}) + mw.cleanupInt = 10 * time.Millisecond + engine := newRateLimitTestEngine(mw.limitForKey("cleanup", 1, 2)) + + if code := performRateLimitRequest(engine, "3.3.3.3:1234", nil); code != http.StatusOK { + t.Fatalf("seed request expected 200, got %d", code) + } + if got := len(mw.limiters); got != 1 { + t.Fatalf("expected 1 limiter after seed request, got %d", got) + } + + time.Sleep(1100 * time.Millisecond) + if code := performRateLimitRequest(engine, "4.4.4.4:1234", nil); code != http.StatusOK { + t.Fatalf("cleanup trigger request expected 200, got %d", code) + } + + if got := len(mw.limiters); got != 1 { + t.Fatalf("expected stale limiter to be cleaned up, got %d entries", got) + } +}