From 11232177d99d9ebeec2656c492f7cef0e4655946 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 28 May 2026 17:28:08 +0800 Subject: [PATCH] fix: enforce resource ownership checks --- internal/api/handler/device_handler.go | 64 +++++++--- internal/api/handler/handler_test.go | 127 ++++++++++++++++++- internal/api/handler/webhook_handler.go | 34 +++++ internal/api/handler/webhook_handler_test.go | 6 +- 4 files changed, 209 insertions(+), 22 deletions(-) diff --git a/internal/api/handler/device_handler.go b/internal/api/handler/device_handler.go index 321dd9e..7c80fc6 100644 --- a/internal/api/handler/device_handler.go +++ b/internal/api/handler/device_handler.go @@ -8,6 +8,7 @@ import ( "github.com/gin-gonic/gin" + apimiddleware "github.com/user-management-system/internal/api/middleware" "github.com/user-management-system/internal/domain" "github.com/user-management-system/internal/service" ) @@ -118,9 +119,8 @@ func (h *DeviceHandler) GetDevice(c *gin.Context) { return } - device, err := h.deviceService.GetDevice(c.Request.Context(), id) - if err != nil { - handleError(c, err) + device, ok := h.authorizeDeviceAccess(c, id) + if !ok { return } @@ -151,6 +151,10 @@ func (h *DeviceHandler) UpdateDevice(c *gin.Context) { return } + if _, ok := h.authorizeDeviceAccess(c, id); !ok { + return + } + var req service.UpdateDeviceRequest if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": err.Error()}) @@ -187,6 +191,10 @@ func (h *DeviceHandler) DeleteDevice(c *gin.Context) { return } + if _, ok := h.authorizeDeviceAccess(c, id); !ok { + return + } + if err := h.deviceService.DeleteDevice(c.Request.Context(), id); err != nil { handleError(c, err) return @@ -218,6 +226,10 @@ func (h *DeviceHandler) UpdateDeviceStatus(c *gin.Context) { return } + if _, ok := h.authorizeDeviceAccess(c, id); !ok { + return + } + var req struct { Status string `json:"status" binding:"required"` } @@ -269,27 +281,14 @@ func (h *DeviceHandler) GetUserDevices(c *gin.Context) { return } - // 检查是否为管理员 - roleCodes, _ := c.Get("role_codes") - isAdmin := false - if roles, ok := roleCodes.([]string); ok { - for _, role := range roles { - if role == "admin" { - isAdmin = true - break - } - } - } - - userIDParam := c.Param("id") - userID, err := strconv.ParseInt(userIDParam, 10, 64) + userID, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": "invalid user id"}) return } // 非管理员只能查看自己的设备 - if !isAdmin && userID != currentUserID { + if !apimiddleware.IsAdmin(c) && userID != currentUserID { c.JSON(http.StatusForbidden, gin.H{"code": 403, "message": "无权访问该用户的设备列表"}) return } @@ -396,6 +395,10 @@ func (h *DeviceHandler) TrustDevice(c *gin.Context) { return } + if _, ok := h.authorizeDeviceAccess(c, id); !ok { + return + } + var req TrustDeviceRequest if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": err.Error()}) @@ -478,6 +481,10 @@ func (h *DeviceHandler) UntrustDevice(c *gin.Context) { return } + if _, ok := h.authorizeDeviceAccess(c, id); !ok { + return + } + if err := h.deviceService.UntrustDevice(c.Request.Context(), id); err != nil { handleError(c, err) return @@ -555,6 +562,27 @@ func (h *DeviceHandler) LogoutAllOtherDevices(c *gin.Context) { }) } +func (h *DeviceHandler) authorizeDeviceAccess(c *gin.Context, deviceID int64) (*domain.Device, bool) { + currentUserID, ok := getUserIDFromContext(c) + if !ok { + c.JSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "unauthorized"}) + return nil, false + } + + device, err := h.deviceService.GetDevice(c.Request.Context(), deviceID) + if err != nil { + handleError(c, err) + return nil, false + } + + if device.UserID != currentUserID && !apimiddleware.IsAdmin(c) { + c.JSON(http.StatusForbidden, gin.H{"code": 403, "message": "permission denied"}) + return nil, false + } + + return device, true +} + // parseDuration 解析duration字符串,如 "30d" -> 30天的time.Duration func parseDuration(s string) time.Duration { if s == "" { diff --git a/internal/api/handler/handler_test.go b/internal/api/handler/handler_test.go index b8fa711..57e872e 100644 --- a/internal/api/handler/handler_test.go +++ b/internal/api/handler/handler_test.go @@ -118,6 +118,7 @@ func setupHandlerTestServer(t *testing.T) (*httptest.Server, func()) { deviceSvc := service.NewDeviceService(deviceRepo, userRepo) loginLogSvc := service.NewLoginLogService(loginLogRepo) opLogSvc := service.NewOperationLogService(opLogRepo) + webhookSvc := service.NewWebhookService(db) captchaSvc := service.NewCaptchaService(cacheManager) totpSvc := service.NewTOTPService(userRepo) pwdResetCfg := service.DefaultPasswordResetConfig() @@ -141,6 +142,7 @@ func setupHandlerTestServer(t *testing.T) (*httptest.Server, func()) { permHandler := handler.NewPermissionHandler(permSvc) deviceHandler := handler.NewDeviceHandler(deviceSvc) logHandler := handler.NewLogHandler(loginLogSvc, opLogSvc) + webhookHandler := handler.NewWebhookHandler(webhookSvc) captchaHandler := handler.NewCaptchaHandler(captchaSvc) totpHandler := handler.NewTOTPHandler(authSvc, totpSvc) pwdResetHandler := handler.NewPasswordResetHandler(pwdResetSvc) @@ -149,7 +151,7 @@ func setupHandlerTestServer(t *testing.T) (*httptest.Server, func()) { r := router.NewRouter( authHandler, userHandler, roleHandler, permHandler, deviceHandler, logHandler, authMiddleware, rateLimitMiddleware, opLogMiddleware, - pwdResetHandler, captchaHandler, totpHandler, nil, + pwdResetHandler, captchaHandler, totpHandler, webhookHandler, nil, nil, nil, nil, nil, themeHandler, nil, nil, nil, avatarH, ) engine := r.Setup() @@ -233,6 +235,62 @@ func registerUser(baseURL, username, email, password string) bool { return resp.StatusCode == http.StatusCreated } +func createDeviceAndGetID(t *testing.T, baseURL, token, deviceID string) int64 { + t.Helper() + + resp, body := doPost(baseURL+"/api/v1/devices", token, map[string]interface{}{ + "device_id": deviceID, + "device_name": "Owned Device", + "device_type": 3, + "device_os": "Linux", + "device_browser": "Chrome", + }) + defer resp.Body.Close() + if resp.StatusCode != http.StatusCreated { + t.Fatalf("create device failed: status=%d body=%s", resp.StatusCode, body) + } + + var result struct { + Data struct { + ID int64 `json:"id"` + } `json:"data"` + } + if err := json.Unmarshal([]byte(body), &result); err != nil { + t.Fatalf("decode create device response failed: %v body=%s", err, body) + } + if result.Data.ID == 0 { + t.Fatalf("expected non-zero device id, body=%s", body) + } + return result.Data.ID +} + +func createWebhookAndGetID(t *testing.T, baseURL, token, name string) int64 { + t.Helper() + + resp, body := doPost(baseURL+"/api/v1/webhooks", token, map[string]interface{}{ + "name": name, + "url": "https://example.com/webhook", + "events": []string{"user.created"}, + }) + defer resp.Body.Close() + if resp.StatusCode != http.StatusCreated { + t.Fatalf("create webhook failed: status=%d body=%s", resp.StatusCode, body) + } + + var result struct { + Data struct { + ID int64 `json:"id"` + } `json:"data"` + } + if err := json.Unmarshal([]byte(body), &result); err != nil { + t.Fatalf("decode create webhook response failed: %v body=%s", err, body) + } + if result.Data.ID == 0 { + t.Fatalf("expected non-zero webhook id, body=%s", body) + } + return result.Data.ID +} + func bootstrapAdminToken(baseURL, username, email, password string) string { payload, _ := json.Marshal(map[string]interface{}{ "username": username, @@ -876,6 +934,73 @@ func TestDeviceHandler_CreateDevice_Success(t *testing.T) { } } +func TestDeviceHandler_DeviceByIDRoutes_ForbiddenForOtherUser(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + registerUser(server.URL, "device-owner", "device-owner@test.com", "UserPass123!") + registerUser(server.URL, "device-attacker", "device-attacker@test.com", "UserPass123!") + ownerToken := getToken(server.URL, "device-owner", "UserPass123!") + attackerToken := getToken(server.URL, "device-attacker", "UserPass123!") + deviceID := createDeviceAndGetID(t, server.URL, ownerToken, "device-owner-001") + + tests := []struct { + name string + method string + url string + body map[string]interface{} + }{ + {name: "get", method: http.MethodGet, url: fmt.Sprintf("%s/api/v1/devices/%d", server.URL, deviceID)}, + {name: "update", method: http.MethodPut, url: fmt.Sprintf("%s/api/v1/devices/%d", server.URL, deviceID), body: map[string]interface{}{"device_name": "hijacked"}}, + {name: "delete", method: http.MethodDelete, url: fmt.Sprintf("%s/api/v1/devices/%d", server.URL, deviceID)}, + {name: "status", method: http.MethodPut, url: fmt.Sprintf("%s/api/v1/devices/%d/status", server.URL, deviceID), body: map[string]interface{}{"status": "inactive"}}, + {name: "trust", method: http.MethodPost, url: fmt.Sprintf("%s/api/v1/devices/%d/trust", server.URL, deviceID), body: map[string]interface{}{"trust_duration": "30d"}}, + {name: "untrust", method: http.MethodDelete, url: fmt.Sprintf("%s/api/v1/devices/%d/trust", server.URL, deviceID)}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + resp, body := doRequest(tc.method, tc.url, attackerToken, tc.body) + defer resp.Body.Close() + if resp.StatusCode != http.StatusForbidden { + t.Fatalf("expected 403 for %s, got %d body=%s", tc.name, resp.StatusCode, body) + } + }) + } +} + +func TestWebhookHandler_OtherUserCannotManageForeignWebhook(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + registerUser(server.URL, "webhook-owner", "webhook-owner@test.com", "UserPass123!") + registerUser(server.URL, "webhook-attacker", "webhook-attacker@test.com", "UserPass123!") + ownerToken := getToken(server.URL, "webhook-owner", "UserPass123!") + attackerToken := getToken(server.URL, "webhook-attacker", "UserPass123!") + webhookID := createWebhookAndGetID(t, server.URL, ownerToken, "owner-webhook") + + tests := []struct { + name string + method string + url string + body map[string]interface{} + }{ + {name: "update", method: http.MethodPut, url: fmt.Sprintf("%s/api/v1/webhooks/%d", server.URL, webhookID), body: map[string]interface{}{"name": "hijacked"}}, + {name: "delete", method: http.MethodDelete, url: fmt.Sprintf("%s/api/v1/webhooks/%d", server.URL, webhookID)}, + {name: "deliveries", method: http.MethodGet, url: fmt.Sprintf("%s/api/v1/webhooks/%d/deliveries", server.URL, webhookID)}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + resp, body := doRequest(tc.method, tc.url, attackerToken, tc.body) + defer resp.Body.Close() + if resp.StatusCode != http.StatusForbidden { + t.Fatalf("expected 403 for webhook %s, got %d body=%s", tc.name, resp.StatusCode, body) + } + }) + } +} + // ============================================================================= // Role Handler Tests // ============================================================================= diff --git a/internal/api/handler/webhook_handler.go b/internal/api/handler/webhook_handler.go index c7ec067..7c22f67 100644 --- a/internal/api/handler/webhook_handler.go +++ b/internal/api/handler/webhook_handler.go @@ -6,6 +6,7 @@ import ( "github.com/gin-gonic/gin" + apimiddleware "github.com/user-management-system/internal/api/middleware" "github.com/user-management-system/internal/service" ) @@ -117,6 +118,10 @@ func (h *WebhookHandler) UpdateWebhook(c *gin.Context) { return } + if _, ok := h.authorizeWebhookAccess(c, id); !ok { + return + } + var req service.UpdateWebhookRequest if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": err.Error()}) @@ -150,6 +155,10 @@ func (h *WebhookHandler) DeleteWebhook(c *gin.Context) { return } + if _, ok := h.authorizeWebhookAccess(c, id); !ok { + return + } + if err := h.webhookService.DeleteWebhook(c.Request.Context(), id); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"code": 500, "message": "删除 Webhook 失败"}) return @@ -178,6 +187,10 @@ func (h *WebhookHandler) GetWebhookDeliveries(c *gin.Context) { return } + if _, ok := h.authorizeWebhookAccess(c, id); !ok { + return + } + limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20")) if limit < 1 || limit > 100 { limit = 20 @@ -191,3 +204,24 @@ func (h *WebhookHandler) GetWebhookDeliveries(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"code": 0, "message": "success", "data": gin.H{"deliveries": deliveries}}) } + +func (h *WebhookHandler) authorizeWebhookAccess(c *gin.Context, webhookID int64) (int64, bool) { + userID, ok := getUserIDFromContext(c) + if !ok { + c.JSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "unauthorized"}) + return 0, false + } + + webhook, err := h.webhookService.GetWebhook(c.Request.Context(), webhookID) + if err != nil { + handleError(c, err) + return 0, false + } + + if webhook.CreatedBy != userID && !apimiddleware.IsAdmin(c) { + c.JSON(http.StatusForbidden, gin.H{"code": 403, "message": "permission denied"}) + return 0, false + } + + return userID, true +} diff --git a/internal/api/handler/webhook_handler_test.go b/internal/api/handler/webhook_handler_test.go index fe29944..a00067a 100644 --- a/internal/api/handler/webhook_handler_test.go +++ b/internal/api/handler/webhook_handler_test.go @@ -359,9 +359,9 @@ func TestWebhookHandler_DeleteWebhook_NotFound(t *testing.T) { resp := doRequestWithCheck(t, "DELETE", server.URL+"/api/v1/webhooks/99999", token, nil) defer resp.Body.Close() - // Delete is idempotent - returns 200 even if not found - if resp.StatusCode != http.StatusOK { - t.Fatalf("expected status 200, got %d", resp.StatusCode) + // 先做归属/存在性校验,不存在的 webhook 返回 404 + if resp.StatusCode != http.StatusNotFound { + t.Fatalf("expected status 404, got %d", resp.StatusCode) } }