Files
user-system/internal/api/middleware/trace_id_test.go
Your Name 707d35fb74 test: add middleware tests for cache_control, security_headers, trace_id
Add comprehensive tests for three middleware components:
- cache_control: NoStoreSensitiveResponses, shouldDisableCaching
- security_headers: SecurityHeaders, shouldAttachCSP, isHTTPSRequest
- trace_id: TraceID, GetTraceID, generateTraceID

Coverage: middleware 35.7% → 36.4%
2026-05-29 20:11:26 +08:00

149 lines
3.5 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package middleware
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
// TestTraceID verifies that the TraceID middleware either generates a new
// trace ID (when the request carries none) or reuses the incoming one, and
// that the same value appears in the response header and is visible to the
// downstream handler through GetTraceID.
func TestTraceID(t *testing.T) {
	gin.SetMode(gin.TestMode)
	router := gin.New()
	router.Use(TraceID())
	router.GET("/test", func(c *gin.Context) {
		// Echo the trace ID seen by the handler so the test can compare it
		// with the response header.
		c.String(http.StatusOK, GetTraceID(c))
	})

	tests := []struct {
		name               string
		incomingTraceID    string
		expectNewGenerated bool
	}{
		{
			name:               "generate new trace ID",
			incomingTraceID:    "",
			expectNewGenerated: true,
		},
		{
			name:               "reuse incoming trace ID",
			incomingTraceID:    "abc123xyz",
			expectNewGenerated: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			w := httptest.NewRecorder()
			// httptest.NewRequest panics on invalid input instead of returning
			// an error that would otherwise be silently dropped.
			req := httptest.NewRequest(http.MethodGet, "/test", nil)
			if tt.incomingTraceID != "" {
				req.Header.Set(TraceIDHeader, tt.incomingTraceID)
			}
			router.ServeHTTP(w, req)

			assert.Equal(t, http.StatusOK, w.Code)

			// The middleware must always expose a trace ID in the response header.
			responseTraceID := w.Header().Get(TraceIDHeader)
			assert.NotEmpty(t, responseTraceID)

			if tt.expectNewGenerated {
				// Expected generated format: YYYYMMDD-<16 hex characters>.
				parts := strings.Split(responseTraceID, "-")
				// Only inspect the parts when there are exactly two; otherwise
				// indexing parts[1] would panic and mask the real failure.
				if assert.Len(t, parts, 2) {
					assert.Len(t, parts[0], 8)  // YYYYMMDD
					assert.Len(t, parts[1], 16) // hex
				}
			} else {
				assert.Equal(t, tt.incomingTraceID, responseTraceID)
			}

			// The handler must have seen the same trace ID that was written
			// to the response header.
			assert.Equal(t, responseTraceID, w.Body.String())
		})
	}
}
// TestTraceID_SetInContext verifies that a trace ID supplied in the request
// header is stored by the TraceID middleware where a handler can read it back
// through GetTraceID.
func TestTraceID_SetInContext(t *testing.T) {
	gin.SetMode(gin.TestMode)

	var capturedTraceID string
	router := gin.New()
	router.Use(TraceID())
	router.GET("/test", func(c *gin.Context) {
		capturedTraceID = GetTraceID(c)
		c.String(http.StatusOK, "OK")
	})

	w := httptest.NewRecorder()
	// httptest.NewRequest panics on invalid input instead of returning an
	// error that would otherwise be silently dropped.
	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	req.Header.Set(TraceIDHeader, "custom-trace-123")
	router.ServeHTTP(w, req)

	// Confirm the route was actually served before checking what the handler
	// captured; otherwise an empty capture would be a misleading failure.
	assert.Equal(t, http.StatusOK, w.Code)
	assert.Equal(t, "custom-trace-123", capturedTraceID)
}
// TestGetTraceID exercises GetTraceID on a bare gin test context. It expects
// GetTraceID to return the stored string when TraceIDKey holds one, and an
// empty string when the key is absent or holds a value of another type.
func TestGetTraceID(t *testing.T) {
	gin.SetMode(gin.TestMode)

	cases := []struct {
		name    string
		prepare func(c *gin.Context) // optional: seeds the context before lookup
		want    string
	}{
		{
			name:    "trace ID exists",
			prepare: func(c *gin.Context) { c.Set(TraceIDKey, "existing-trace") },
			want:    "existing-trace",
		},
		{
			// No prepare func: the key is never set.
			name: "trace ID not exists",
			want: "",
		},
		{
			name:    "trace ID is not string",
			prepare: func(c *gin.Context) { c.Set(TraceIDKey, 12345) },
			want:    "",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
			if tc.prepare != nil {
				tc.prepare(ctx)
			}
			assert.Equal(t, tc.want, GetTraceID(ctx))
		})
	}
}
// TestGenerateTraceID generates a batch of trace IDs and checks that each has
// the YYYYMMDD-<16 hex characters> shape and that no two IDs in the batch are
// equal.
func TestGenerateTraceID(t *testing.T) {
	const count = 100

	seen := make(map[string]struct{}, count)
	for i := 0; i < count; i++ {
		id := generateTraceID()
		seen[id] = struct{}{}

		// Validate the format.
		parts := strings.Split(id, "-")
		// Stop on the first malformed ID: indexing parts[1] below would panic
		// otherwise, and every later iteration would repeat the same failure.
		if !assert.Len(t, parts, 2, "trace ID should have 2 parts separated by -") {
			return
		}
		assert.Len(t, parts[0], 8, "date part should be 8 characters (YYYYMMDD)")
		assert.Len(t, parts[1], 16, "random part should be 16 hex characters")
	}

	// Uniqueness: every generated ID should be distinct.
	assert.Len(t, seen, count, "generated trace IDs should be unique")
}