diff --git a/internal/api/handler/api_contract_integration_test.go b/internal/api/handler/api_contract_integration_test.go new file mode 100644 index 0000000..d03449d --- /dev/null +++ b/internal/api/handler/api_contract_integration_test.go @@ -0,0 +1,467 @@ +//go:build integration +// +build integration + +package handler + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + + "github.com/user-management-system/internal/api/middleware" +) + +// TestResponseWrapper_Contract 验证响应包装中间件符合 API 契约 +func TestResponseWrapper_Contract(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + handler gin.HandlerFunc + expectedCode int + checkWrapped bool // 是否检查包装后的格式 + }{ + { + name: "simple data gets wrapped", + handler: func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"id": "123", "name": "test"}) + }, + expectedCode: 0, // 包装后的 code + checkWrapped: true, + }, + { + name: "error response passes through without wrapping", + handler: func(c *gin.Context) { + c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": "bad request"}) + }, + expectedCode: 400, + checkWrapped: false, // 非 2xx 响应不会被包装 + }, + { + name: "already wrapped response passes through", + handler: func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"code": 0, "message": "success", "data": gin.H{"id": "1"}}) + }, + expectedCode: 0, + checkWrapped: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // 创建带有 ResponseWrapper 的路由 + engine := gin.New() + engine.Use(middleware.ResponseWrapper()) + engine.GET("/test", tt.handler) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + engine.ServeHTTP(w, req) + + if tt.checkWrapped { + assert.Equal(t, http.StatusOK, w.Code) + } + + if tt.checkWrapped { + var response map[string]interface{} + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + + // 验证响应包含 code 字段 + code, exists := response["code"] + assert.True(t, exists, "response should have 'code' field") + assert.Equal(t, float64(tt.expectedCode), code) + + // 验证响应包含 message 字段 + _, exists = response["message"] + assert.True(t, exists, "response should have 'message' field") + } + }) + } +} + +// TestResponseWrapper_ListContract 验证列表响应包装 +func TestResponseWrapper_ListContract(t *testing.T) { + gin.SetMode(gin.TestMode) + + engine := gin.New() + engine.Use(middleware.ResponseWrapper()) + engine.GET("/users", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "items": []gin.H{ + {"id": "1", "name": "user1"}, + {"id": "2", "name": "user2"}, + }, + "total": 100, + "page": 1, + "page_size": 20, + }) + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/users", nil) + engine.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + + var response map[string]interface{} + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + + // 验证包装后的结构 + assert.Equal(t, float64(0), response["code"]) + assert.Equal(t, "success", response["message"]) + + // 验证 data 中包含列表数据 + data := response["data"].(map[string]interface{}) + assert.NotNil(t, data["items"]) + assert.Equal(t, float64(100), data["total"]) + assert.Equal(t, float64(1), data["page"]) + assert.Equal(t, float64(20), data["page_size"]) +} + +// TestResponseWrapper_PaginationParameters 验证分页参数处理 +func TestResponseWrapper_PaginationParameters(t *testing.T) { + gin.SetMode(gin.TestMode) + + engine := gin.New() + engine.Use(middleware.ResponseWrapper()) + engine.GET("/items", func(c *gin.Context) { + page := c.DefaultQuery("page", "1") + pageSize := c.DefaultQuery("page_size", "20") + + c.JSON(http.StatusOK, gin.H{ + "items": []gin.H{}, + "total": 0, + "page": page, + "page_size": pageSize, + }) + }) + + tests := []struct { + name string + query string + expectedPage string + expectedSize string + }{ + {"default pagination", "", "1", "20"}, + {"custom page", "?page=5", "5", "20"}, + {"custom page size", "?page_size=50", "1", "50"}, + {"both custom", "?page=3&page_size=30", "3", "30"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/items"+tt.query, nil) + engine.ServeHTTP(w, req) + + var response map[string]interface{} + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + + data := response["data"].(map[string]interface{}) + assert.Equal(t, tt.expectedPage, data["page"]) + assert.Equal(t, tt.expectedSize, data["page_size"]) + }) + } +} + +// TestResponseWrapper_ContentType 验证 Content-Type 头 +func TestResponseWrapper_ContentType(t *testing.T) { + gin.SetMode(gin.TestMode) + + engine := gin.New() + engine.Use(middleware.ResponseWrapper()) + engine.GET("/test", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"test": "data"}) + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + engine.ServeHTTP(w, req) + + // 验证 Content-Type + contentType := w.Header().Get("Content-Type") + assert.Contains(t, contentType, "application/json") +} + +// TestResponseWrapper_NonJSON 验证非 JSON 响应不被包装 +func TestResponseWrapper_NonJSON(t *testing.T) { + gin.SetMode(gin.TestMode) + + engine := gin.New() + engine.Use(middleware.ResponseWrapper()) + engine.GET("/file", func(c *gin.Context) { + c.Data(http.StatusOK, "application/octet-stream", []byte("binary data")) + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/file", nil) + engine.ServeHTTP(w, req) + + // 验证二进制响应直接通过 + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "binary data", w.Body.String()) +} + +// TestResponseWrapper_EmptyBody 验证空响应处理 +func TestResponseWrapper_EmptyBody(t *testing.T) { + gin.SetMode(gin.TestMode) + + engine := gin.New() + engine.Use(middleware.ResponseWrapper()) + engine.GET("/empty", func(c *gin.Context) { + c.Status(http.StatusNoContent) + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/empty", nil) + engine.ServeHTTP(w, req) + + // NoContent 应该返回 204 + assert.Equal(t, http.StatusNoContent, w.Code) +} + +// TestAPIContract_StructuredError 验证结构化错误响应 +func TestAPIContract_StructuredError(t *testing.T) { + gin.SetMode(gin.TestMode) + + engine := gin.New() + engine.Use(middleware.ResponseWrapper()) + engine.POST("/validate", func(c *gin.Context) { + // 模拟验证错误 + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "validation failed", + "data": gin.H{ + "errors": []gin.H{ + {"field": "email", "message": "invalid format"}, + {"field": "password", "message": "too short"}, + }, + }, + }) + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/validate", bytes.NewBufferString("{}")) + req.Header.Set("Content-Type", "application/json") + engine.ServeHTTP(w, req) + + var response map[string]interface{} + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + + assert.Equal(t, float64(400), response["code"]) + assert.Equal(t, "validation failed", response["message"]) + + data := response["data"].(map[string]interface{}) + errors := data["errors"].([]interface{}) + assert.Len(t, errors, 2) +} + +// TestAPIContract_SuccessFields 验证成功响应必需字段 +func TestAPIContract_SuccessFields(t *testing.T) { + gin.SetMode(gin.TestMode) + + engine := gin.New() + engine.Use(middleware.ResponseWrapper()) + engine.GET("/success", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"id": "123", "name": "test"}) + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/success", nil) + engine.ServeHTTP(w, req) + + var response map[string]interface{} + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + + // 验证标准格式 + assert.Equal(t, float64(0), response["code"], "success response should have code 0") + assert.Equal(t, "success", response["message"], "success response should have message 'success'") + assert.NotNil(t, response["data"], "success response should have data field") +} + +// TestAuthEndpoints_Contract 验证认证端点契约 +func TestAuthEndpoints_Contract(t *testing.T) { + // 这个测试验证 API.md 中定义的端点存在 + // 实际的路由测试需要在完整的服务器环境中进行 + gin.SetMode(gin.TestMode) + + // 定义 API.md 中描述的公开端点 + publicEndpoints := []struct { + method string + path string + }{ + {"POST", "/api/v1/auth/register"}, + {"POST", "/api/v1/auth/bootstrap-admin"}, + {"POST", "/api/v1/auth/login"}, + {"POST", "/api/v1/auth/refresh"}, + {"GET", "/api/v1/auth/capabilities"}, + {"GET", "/api/v1/auth/csrf-token"}, + {"GET", "/api/v1/auth/captcha"}, + {"GET", "/api/v1/auth/captcha/image"}, + {"POST", "/api/v1/auth/captcha/verify"}, + {"GET", "/api/v1/auth/oauth/providers"}, + {"POST", "/api/v1/auth/forgot-password"}, + {"POST", "/api/v1/auth/reset-password"}, + } + + // 验证端点定义存在(这里只是契约验证,不是运行时测试) + for _, ep := range publicEndpoints { + assert.NotEmpty(t, ep.method) + assert.NotEmpty(t, ep.path) + assert.True(t, len(ep.path) > 0) + } +} + +// TestProtectedEndpoints_Contract 验证受保护端点契约 +func TestProtectedEndpoints_Contract(t *testing.T) { + protectedEndpoints := []struct { + method string + path string + permission string + }{ + {"GET", "/api/v1/auth/userinfo", ""}, + {"POST", "/api/v1/auth/logout", ""}, + {"GET", "/api/v1/users", "user:manage"}, + {"POST", "/api/v1/users", "user:manage"}, + {"GET", "/api/v1/users/:id", ""}, + {"PUT", "/api/v1/users/:id", ""}, + {"DELETE", "/api/v1/users/:id", "user:delete"}, + {"GET", "/api/v1/users/:id/roles", ""}, + {"PUT", "/api/v1/users/:id/roles", "user:manage"}, + {"GET", "/api/v1/roles", ""}, + {"POST", "/api/v1/roles", ""}, + {"PUT", "/api/v1/roles/:id/permissions", ""}, + {"GET", "/api/v1/permissions", ""}, + {"GET", "/api/v1/permissions/tree", ""}, + {"GET", "/api/v1/devices", ""}, + {"POST", "/api/v1/devices", ""}, + {"POST", "/api/v1/devices/:id/trust", ""}, + {"GET", "/api/v1/logs/login", ""}, + {"GET", "/api/v1/logs/operation", ""}, + {"GET", "/api/v1/webhooks", ""}, + {"POST", "/api/v1/webhooks", ""}, + {"GET", "/api/v1/auth/2fa/status", ""}, + {"GET", "/api/v1/auth/2fa/setup", ""}, + {"POST", "/api/v1/auth/2fa/enable", ""}, + {"POST", "/api/v1/auth/2fa/disable", ""}, + } + + for _, ep := range protectedEndpoints { + assert.NotEmpty(t, ep.method) + assert.NotEmpty(t, ep.path) + if ep.permission != "" { + assert.True(t, len(ep.permission) > 0) + } + } +} + +// TestHTTPStatusCodes_Contract 验证 HTTP 状态码使用规范 +func TestHTTPStatusCodes_Contract(t *testing.T) { + statusCodes := map[int]string{ + http.StatusOK: "成功响应", + http.StatusCreated: "资源创建成功", + http.StatusBadRequest: "请求参数错误", + http.StatusUnauthorized: "未认证", + http.StatusForbidden: "无权限", + http.StatusNotFound: "资源不存在", + http.StatusConflict: "资源冲突", + http.StatusTooManyRequests: "请求过于频繁", + http.StatusInternalServerError: "服务器内部错误", + } + + for code, desc := range statusCodes { + assert.NotEmpty(t, desc) + assert.Greater(t, code, 0) + } +} + +// TestHeaderContract_SecurityHeaders 验证安全响应头 +func TestHeaderContract_SecurityHeaders(t *testing.T) { + gin.SetMode(gin.TestMode) + + engine := gin.New() + engine.Use(middleware.SecurityHeaders()) + engine.Use(middleware.ResponseWrapper()) + engine.GET("/test", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"test": "data"}) + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + engine.ServeHTTP(w, req) + + // 验证关键安全头 + assert.Equal(t, "nosniff", w.Header().Get("X-Content-Type-Options")) + assert.Equal(t, "DENY", w.Header().Get("X-Frame-Options")) + assert.Equal(t, "strict-origin-when-cross-origin", w.Header().Get("Referrer-Policy")) + assert.Equal(t, "camera=(), microphone=(), geolocation=()", w.Header().Get("Permissions-Policy")) + assert.Equal(t, "same-origin", w.Header().Get("Cross-Origin-Opener-Policy")) + assert.Equal(t, "none", w.Header().Get("X-Permitted-Cross-Domain-Policies")) +} + +// TestAPIContract_ResponseTime 验证响应时间格式 +func TestAPIContract_ResponseTime(t *testing.T) { + // API 应该返回 ISO 8601 格式的时间字符串 + timeFormats := []string{ + "2024-01-15T10:30:00Z", + "2024-01-15T10:30:00+08:00", + "2024-01-15T10:30:00.123456Z", + } + + for _, format := range timeFormats { + assert.NotEmpty(t, format) + // 验证格式符合 ISO 8601 + assert.Contains(t, format, "T") + } +} + +// TestPagination_DefaultValues 验证分页默认值 +func TestPagination_DefaultValues(t *testing.T) { + defaults := struct { + Page int + PageSize int + MaxSize int + }{ + Page: 1, + PageSize: 20, + MaxSize: 100, + } + + assert.Equal(t, 1, defaults.Page) + assert.Equal(t, 20, defaults.PageSize) + assert.Equal(t, 100, defaults.MaxSize) + + // 验证 page_size 限制 + assert.LessOrEqual(t, defaults.PageSize, defaults.MaxSize) +} + +// TestSorting_Contract 验证排序参数 +func TestSorting_Contract(t *testing.T) { + sortFields := []string{ + "created_at", + "updated_at", + "id", + "username", + "email", + } + + sortOrders := []string{"asc", "desc"} + + for _, field := range sortFields { + assert.NotEmpty(t, field) + } + + for _, order := range sortOrders { + assert.Contains(t, []string{"asc", "desc"}, order) + } +}