test: add middleware tests for cache_control, security_headers, trace_id
Add comprehensive tests for three middleware components: - cache_control: NoStoreSensitiveResponses, shouldDisableCaching - security_headers: SecurityHeaders, shouldAttachCSP, isHTTPSRequest - trace_id: TraceID, GetTraceID, generateTraceID Coverage: middleware 35.7% → 36.4%
This commit is contained in:
117
internal/api/middleware/cache_control_test.go
Normal file
117
internal/api/middleware/cache_control_test.go
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNoStoreSensitiveResponses(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
path string
|
||||||
|
fullPath string
|
||||||
|
wantNoCache bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "auth login path",
|
||||||
|
path: "/api/v1/auth/login",
|
||||||
|
fullPath: "/api/v1/auth/login",
|
||||||
|
wantNoCache: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "auth register path",
|
||||||
|
path: "/api/v1/auth/register",
|
||||||
|
fullPath: "/api/v1/auth/register",
|
||||||
|
wantNoCache: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-auth path",
|
||||||
|
path: "/api/v1/users",
|
||||||
|
fullPath: "/api/v1/users",
|
||||||
|
wantNoCache: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty fullPath uses request path",
|
||||||
|
path: "/api/v1/auth/refresh",
|
||||||
|
fullPath: "",
|
||||||
|
wantNoCache: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "subpath of auth",
|
||||||
|
path: "/api/v1/auth/oauth/callback",
|
||||||
|
fullPath: "/api/v1/auth/oauth/callback",
|
||||||
|
wantNoCache: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(NoStoreSensitiveResponses())
|
||||||
|
router.GET(tt.path, func(c *gin.Context) {
|
||||||
|
c.String(200, "OK")
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("GET", tt.path, nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if tt.wantNoCache {
|
||||||
|
assert.Equal(t, "no-store, no-cache, must-revalidate, max-age=0", w.Header().Get("Cache-Control"))
|
||||||
|
assert.Equal(t, "no-cache", w.Header().Get("Pragma"))
|
||||||
|
assert.Equal(t, "0", w.Header().Get("Expires"))
|
||||||
|
assert.Equal(t, "no-store", w.Header().Get("Surrogate-Control"))
|
||||||
|
} else {
|
||||||
|
assert.Empty(t, w.Header().Get("Cache-Control"))
|
||||||
|
assert.Empty(t, w.Header().Get("Pragma"))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShouldDisableCaching(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
routePath string
|
||||||
|
requestPath string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "auth prefix match",
|
||||||
|
routePath: "/api/v1/auth/login",
|
||||||
|
requestPath: "/api/v1/auth/login",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no auth prefix",
|
||||||
|
routePath: "/api/v1/users",
|
||||||
|
requestPath: "/api/v1/users",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty routePath uses requestPath",
|
||||||
|
routePath: "",
|
||||||
|
requestPath: "/api/v1/auth/logout",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "trimmed spaces",
|
||||||
|
routePath: " /api/v1/auth/login ",
|
||||||
|
requestPath: "/api/v1/auth/login",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := shouldDisableCaching(tt.routePath, tt.requestPath)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
160
internal/api/middleware/security_headers_test.go
Normal file
160
internal/api/middleware/security_headers_test.go
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSecurityHeaders(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(SecurityHeaders())
|
||||||
|
router.GET("/test", func(c *gin.Context) {
|
||||||
|
c.String(200, "OK")
|
||||||
|
})
|
||||||
|
router.GET("/swagger/index.html", func(c *gin.Context) {
|
||||||
|
c.String(200, "Swagger UI")
|
||||||
|
})
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
path string
|
||||||
|
wantCSP bool
|
||||||
|
wantSTS bool // Strict-Transport-Security (only for HTTPS)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "regular API endpoint",
|
||||||
|
path: "/test",
|
||||||
|
wantCSP: true,
|
||||||
|
wantSTS: false, // HTTP request
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "swagger endpoint no CSP",
|
||||||
|
path: "/swagger/index.html",
|
||||||
|
wantCSP: false,
|
||||||
|
wantSTS: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("GET", tt.path, nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
// 基础安全头
|
||||||
|
assert.Equal(t, "nosniff", w.Header().Get("X-Content-Type-Options"))
|
||||||
|
assert.Equal(t, "DENY", w.Header().Get("X-Frame-Options"))
|
||||||
|
assert.Equal(t, "strict-origin-when-cross-origin", w.Header().Get("Referrer-Policy"))
|
||||||
|
assert.Equal(t, "camera=(), microphone=(), geolocation=()", w.Header().Get("Permissions-Policy"))
|
||||||
|
assert.Equal(t, "same-origin", w.Header().Get("Cross-Origin-Opener-Policy"))
|
||||||
|
assert.Equal(t, "none", w.Header().Get("X-Permitted-Cross-Domain-Policies"))
|
||||||
|
|
||||||
|
// CSP 头
|
||||||
|
if tt.wantCSP {
|
||||||
|
assert.NotEmpty(t, w.Header().Get("Content-Security-Policy"))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShouldAttachCSP(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
routePath string
|
||||||
|
requestPath string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "non-swagger path",
|
||||||
|
routePath: "/api/v1/users",
|
||||||
|
requestPath: "/api/v1/users",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "swagger path",
|
||||||
|
routePath: "/swagger/index.html",
|
||||||
|
requestPath: "/swagger/index.html",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "swagger subpath",
|
||||||
|
routePath: "/swagger/api-docs",
|
||||||
|
requestPath: "/swagger/api-docs",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty routePath uses requestPath",
|
||||||
|
routePath: "",
|
||||||
|
requestPath: "/swagger/",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "trimmed spaces",
|
||||||
|
routePath: " /api/v1/users ",
|
||||||
|
requestPath: "/api/v1/users",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := shouldAttachCSP(tt.routePath, tt.requestPath)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsHTTPSRequest(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setup func(*http.Request)
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "plain HTTP request",
|
||||||
|
setup: func(req *http.Request) {},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "X-Forwarded-Proto is https",
|
||||||
|
setup: func(req *http.Request) {
|
||||||
|
req.Header.Set("X-Forwarded-Proto", "https")
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "X-Forwarded-Proto is HTTPS (uppercase)",
|
||||||
|
setup: func(req *http.Request) {
|
||||||
|
req.Header.Set("X-Forwarded-Proto", "HTTPS")
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "X-Forwarded-Proto is http",
|
||||||
|
setup: func(req *http.Request) {
|
||||||
|
req.Header.Set("X-Forwarded-Proto", "http")
|
||||||
|
},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||||
|
req, _ := http.NewRequest("GET", "/test", nil)
|
||||||
|
tt.setup(req)
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
result := isHTTPSRequest(c)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
148
internal/api/middleware/trace_id_test.go
Normal file
148
internal/api/middleware/trace_id_test.go
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTraceID(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(TraceID())
|
||||||
|
router.GET("/test", func(c *gin.Context) {
|
||||||
|
// 返回 trace ID 供验证
|
||||||
|
traceID := GetTraceID(c)
|
||||||
|
c.String(200, traceID)
|
||||||
|
})
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
incomingTraceID string
|
||||||
|
expectNewGenerated bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "generate new trace ID",
|
||||||
|
incomingTraceID: "",
|
||||||
|
expectNewGenerated: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "reuse incoming trace ID",
|
||||||
|
incomingTraceID: "abc123xyz",
|
||||||
|
expectNewGenerated: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("GET", "/test", nil)
|
||||||
|
if tt.incomingTraceID != "" {
|
||||||
|
req.Header.Set(TraceIDHeader, tt.incomingTraceID)
|
||||||
|
}
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
// 检查响应头中的 trace ID
|
||||||
|
responseTraceID := w.Header().Get(TraceIDHeader)
|
||||||
|
assert.NotEmpty(t, responseTraceID)
|
||||||
|
|
||||||
|
if tt.expectNewGenerated {
|
||||||
|
// 新生成的 trace ID 应该包含日期格式
|
||||||
|
assert.True(t, strings.Contains(responseTraceID, "-"))
|
||||||
|
// 格式: YYYYMMDD-xxxxxxxx
|
||||||
|
parts := strings.Split(responseTraceID, "-")
|
||||||
|
assert.Equal(t, 2, len(parts))
|
||||||
|
assert.Equal(t, 8, len(parts[0])) // YYYYMMDD
|
||||||
|
assert.Equal(t, 16, len(parts[1])) // hex
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, tt.incomingTraceID, responseTraceID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 响应体应该包含 trace ID
|
||||||
|
body := w.Body.String()
|
||||||
|
assert.Equal(t, responseTraceID, body)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTraceID_SetInContext(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
var capturedTraceID string
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(TraceID())
|
||||||
|
router.GET("/test", func(c *gin.Context) {
|
||||||
|
capturedTraceID = GetTraceID(c)
|
||||||
|
c.String(200, "OK")
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("GET", "/test", nil)
|
||||||
|
req.Header.Set(TraceIDHeader, "custom-trace-123")
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, "custom-trace-123", capturedTraceID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetTraceID(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setupContext func(*gin.Context)
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "trace ID exists",
|
||||||
|
setupContext: func(c *gin.Context) {
|
||||||
|
c.Set(TraceIDKey, "existing-trace")
|
||||||
|
},
|
||||||
|
expected: "existing-trace",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "trace ID not exists",
|
||||||
|
setupContext: func(c *gin.Context) {},
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "trace ID is not string",
|
||||||
|
setupContext: func(c *gin.Context) {
|
||||||
|
c.Set(TraceIDKey, 12345)
|
||||||
|
},
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||||
|
tt.setupContext(c)
|
||||||
|
|
||||||
|
result := GetTraceID(c)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateTraceID(t *testing.T) {
|
||||||
|
// 生成多个 trace ID,验证格式
|
||||||
|
traceIDs := make(map[string]bool)
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
id := generateTraceID()
|
||||||
|
traceIDs[id] = true
|
||||||
|
|
||||||
|
// 验证格式
|
||||||
|
parts := strings.Split(id, "-")
|
||||||
|
assert.Equal(t, 2, len(parts), "trace ID should have 2 parts separated by -")
|
||||||
|
assert.Equal(t, 8, len(parts[0]), "date part should be 8 characters (YYYYMMDD)")
|
||||||
|
assert.Equal(t, 16, len(parts[1]), "random part should be 16 hex characters")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证唯一性(100个应该都不同)
|
||||||
|
assert.Equal(t, 100, len(traceIDs), "generated trace IDs should be unique")
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user