108 lines
3.3 KiB
Go
108 lines
3.3 KiB
Go
package middleware
|
|
|
|
import (
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/user-management-system/internal/config"
|
|
)
|
|
|
|
func init() {
|
|
gin.SetMode(gin.TestMode)
|
|
}
|
|
|
|
func newRateLimitTestEngine(mw gin.HandlerFunc) *gin.Engine {
|
|
engine := gin.New()
|
|
engine.Use(mw)
|
|
engine.GET("/ping", func(c *gin.Context) {
|
|
c.JSON(http.StatusOK, gin.H{"ok": true})
|
|
})
|
|
return engine
|
|
}
|
|
|
|
func performRateLimitRequest(engine *gin.Engine, remoteAddr string, setup func(*http.Request)) int {
|
|
req := httptest.NewRequest(http.MethodGet, "/ping", nil)
|
|
req.RemoteAddr = remoteAddr
|
|
if setup != nil {
|
|
setup(req)
|
|
}
|
|
w := httptest.NewRecorder()
|
|
engine.ServeHTTP(w, req)
|
|
return w.Code
|
|
}
|
|
|
|
func TestRateLimitMiddleware_LoginUsesIndependentIPBuckets(t *testing.T) {
|
|
mw := NewRateLimitMiddleware(config.RateLimitConfig{})
|
|
engine := newRateLimitTestEngine(mw.Login())
|
|
|
|
for i := 0; i < 5; i++ {
|
|
if code := performRateLimitRequest(engine, "1.1.1.1:1234", nil); code != http.StatusOK {
|
|
t.Fatalf("ip1 request %d expected 200, got %d", i+1, code)
|
|
}
|
|
}
|
|
if code := performRateLimitRequest(engine, "1.1.1.1:1234", nil); code != http.StatusTooManyRequests {
|
|
t.Fatalf("ip1 sixth request expected 429, got %d", code)
|
|
}
|
|
|
|
if code := performRateLimitRequest(engine, "2.2.2.2:1234", nil); code != http.StatusOK {
|
|
t.Fatalf("independent ip should not be throttled, got %d", code)
|
|
}
|
|
}
|
|
|
|
func TestRateLimitMiddleware_APIPrefersUserIDOverSharedIP(t *testing.T) {
|
|
mw := NewRateLimitMiddleware(config.RateLimitConfig{})
|
|
engine := gin.New()
|
|
engine.Use(func(c *gin.Context) {
|
|
if userID := c.GetHeader("X-Test-User-ID"); userID != "" {
|
|
c.Set("user_id", userID)
|
|
}
|
|
c.Next()
|
|
})
|
|
engine.Use(mw.limitForKey("api-test", 60, 1))
|
|
engine.GET("/ping", func(c *gin.Context) {
|
|
c.JSON(http.StatusOK, gin.H{"ok": true})
|
|
})
|
|
|
|
setupUser1 := func(req *http.Request) {
|
|
req.Header.Set("X-Test-User-ID", "101")
|
|
}
|
|
setupUser2 := func(req *http.Request) {
|
|
req.Header.Set("X-Test-User-ID", "202")
|
|
}
|
|
|
|
if code := performRateLimitRequest(engine, "9.9.9.9:1234", setupUser1); code != http.StatusOK {
|
|
t.Fatalf("user1 first request expected 200, got %d", code)
|
|
}
|
|
if code := performRateLimitRequest(engine, "9.9.9.9:1234", setupUser1); code != http.StatusTooManyRequests {
|
|
t.Fatalf("user1 second request expected 429, got %d", code)
|
|
}
|
|
if code := performRateLimitRequest(engine, "9.9.9.9:1234", setupUser2); code != http.StatusOK {
|
|
t.Fatalf("user2 should have independent bucket on shared ip, got %d", code)
|
|
}
|
|
}
|
|
|
|
func TestRateLimitMiddleware_CleansUpIdleLimiters(t *testing.T) {
|
|
mw := NewRateLimitMiddleware(config.RateLimitConfig{})
|
|
mw.cleanupInt = 10 * time.Millisecond
|
|
engine := newRateLimitTestEngine(mw.limitForKey("cleanup", 1, 2))
|
|
|
|
if code := performRateLimitRequest(engine, "3.3.3.3:1234", nil); code != http.StatusOK {
|
|
t.Fatalf("seed request expected 200, got %d", code)
|
|
}
|
|
if got := len(mw.limiters); got != 1 {
|
|
t.Fatalf("expected 1 limiter after seed request, got %d", got)
|
|
}
|
|
|
|
time.Sleep(1100 * time.Millisecond)
|
|
if code := performRateLimitRequest(engine, "4.4.4.4:1234", nil); code != http.StatusOK {
|
|
t.Fatalf("cleanup trigger request expected 200, got %d", code)
|
|
}
|
|
|
|
if got := len(mw.limiters); got != 1 {
|
|
t.Fatalf("expected stale limiter to be cleaned up, got %d entries", got)
|
|
}
|
|
}
|