package middleware import ( "net/http" "net/http/httptest" "testing" "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" "github.com/user-management-system/internal/config" ) func TestValidateCORSConfig(t *testing.T) { tests := []struct { name string cfg config.CORSConfig wantErr bool }{ { name: "valid config with specific origins", cfg: config.CORSConfig{ AllowedOrigins: []string{"https://example.com"}, AllowCredentials: true, }, wantErr: false, }, { name: "valid config with wildcard no credentials", cfg: config.CORSConfig{ AllowedOrigins: []string{"*"}, AllowCredentials: false, }, wantErr: false, }, { name: "invalid config with wildcard and credentials", cfg: config.CORSConfig{ AllowedOrigins: []string{"*"}, AllowCredentials: true, }, wantErr: true, }, { name: "empty origins", cfg: config.CORSConfig{ AllowedOrigins: []string{}, AllowCredentials: false, }, wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := validateCORSConfig(tt.cfg) if tt.wantErr { assert.Error(t, err) } else { assert.NoError(t, err) } }) } } func TestSetCORSConfig(t *testing.T) { // Save original config originalConfig := corsConfig defer func() { corsConfig = originalConfig }() t.Run("valid config", func(t *testing.T) { cfg := config.CORSConfig{ AllowedOrigins: []string{"https://example.com"}, AllowCredentials: true, } err := SetCORSConfig(cfg) assert.NoError(t, err) assert.Equal(t, cfg, corsConfig) }) t.Run("invalid config", func(t *testing.T) { cfg := config.CORSConfig{ AllowedOrigins: []string{"*"}, AllowCredentials: true, } err := SetCORSConfig(cfg) assert.Error(t, err) }) } func TestResolveAllowedOrigin(t *testing.T) { tests := []struct { name string origin string allowedOrigins []string allowCredentials bool wantOrigin string wantAllowed bool }{ { name: "exact match", origin: "https://example.com", allowedOrigins: []string{"https://example.com"}, allowCredentials: true, wantOrigin: "https://example.com", wantAllowed: true, }, { name: "wildcard without credentials", origin: "https://any.com", allowedOrigins: []string{"*"}, allowCredentials: false, wantOrigin: "*", wantAllowed: true, }, { name: "wildcard with credentials returns origin", origin: "https://any.com", allowedOrigins: []string{"*"}, allowCredentials: true, wantOrigin: "https://any.com", wantAllowed: true, }, { name: "no match", origin: "https://evil.com", allowedOrigins: []string{"https://example.com"}, allowCredentials: false, wantOrigin: "", wantAllowed: false, }, { name: "case insensitive match", origin: "HTTPS://EXAMPLE.COM", allowedOrigins: []string{"https://example.com"}, allowCredentials: false, wantOrigin: "HTTPS://EXAMPLE.COM", wantAllowed: true, }, { name: "empty origins list", origin: "https://example.com", allowedOrigins: []string{}, allowCredentials: false, wantOrigin: "", wantAllowed: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { gotOrigin, gotAllowed := resolveAllowedOrigin(tt.origin, tt.allowedOrigins, tt.allowCredentials) assert.Equal(t, tt.wantOrigin, gotOrigin) assert.Equal(t, tt.wantAllowed, gotAllowed) }) } } func TestCORS(t *testing.T) { gin.SetMode(gin.TestMode) // Save and restore original config originalConfig := corsConfig defer func() { corsConfig = originalConfig }() // Set test config corsConfig = config.CORSConfig{ AllowedOrigins: []string{"https://example.com"}, AllowCredentials: true, } router := gin.New() router.Use(CORS()) router.GET("/test", func(c *gin.Context) { c.String(200, "OK") }) t.Run("allow valid origin", func(t *testing.T) { w := httptest.NewRecorder() req, _ := http.NewRequest("GET", "/test", nil) req.Header.Set("Origin", "https://example.com") router.ServeHTTP(w, req) assert.Equal(t, 200, w.Code) assert.Equal(t, "https://example.com", w.Header().Get("Access-Control-Allow-Origin")) assert.Equal(t, "true", w.Header().Get("Access-Control-Allow-Credentials")) }) t.Run("forbid invalid origin", func(t *testing.T) { w := httptest.NewRecorder() req, _ := http.NewRequest("GET", "/test", nil) req.Header.Set("Origin", "https://evil.com") router.ServeHTTP(w, req) assert.Equal(t, 403, w.Code) }) t.Run("handle OPTIONS request", func(t *testing.T) { w := httptest.NewRecorder() req, _ := http.NewRequest("OPTIONS", "/test", nil) req.Header.Set("Origin", "https://example.com") router.ServeHTTP(w, req) assert.Equal(t, 204, w.Code) assert.Equal(t, "GET, POST, PUT, DELETE, OPTIONS", w.Header().Get("Access-Control-Allow-Methods")) assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers")) }) t.Run("no origin header", func(t *testing.T) { w := httptest.NewRecorder() req, _ := http.NewRequest("GET", "/test", nil) router.ServeHTTP(w, req) assert.Equal(t, 200, w.Code) }) }