fix(security): /uploads 目录路径遍历防护

- 替换 Static 为受控文件服务 handler (serveUploads)
- 添加 filepath.Clean 路径清理 + .. 检测
- 使用 Abs + HasPrefix 限制访问范围在上传目录内
- 添加安全响应头(CSP default-src 'none', X-Content-Type-Options nosniff)
This commit is contained in:
2026-05-08 12:28:03 +08:00
parent e49865df11
commit 61692e4c1a
4 changed files with 60 additions and 8 deletions

View File

@@ -1,6 +1,11 @@
package router package router
import ( import (
"net/http"
"os"
"path/filepath"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus/promhttp"
swaggerFiles "github.com/swaggo/files" swaggerFiles "github.com/swaggo/files"
@@ -122,9 +127,9 @@ func (r *Router) Setup() *gin.Engine {
) )
} }
// P0 安全修复:/uploads 目录不再公开暴露,改为需要认证后才能访问 // P0 安全修复:/uploads 目录使用受控文件服务,防止路径遍历
uploadsGroup := r.engine.Group("/uploads", r.authMiddleware.Required()) uploadsGroup := r.engine.Group("/uploads", r.authMiddleware.Required())
uploadsGroup.Static("", "./uploads") uploadsGroup.GET("/*filepath", r.serveUploads)
r.engine.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler)) r.engine.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler))
@@ -408,3 +413,37 @@ func (r *Router) Setup() *gin.Engine {
func (r *Router) GetEngine() *gin.Engine { func (r *Router) GetEngine() *gin.Engine {
return r.engine return r.engine
} }
// serveUploads 提供受控的上传文件访问,防止路径遍历攻击
func (r *Router) serveUploads(c *gin.Context) {
filePath := c.Param("filepath")
// 1. 清理路径,阻止路径遍历
filePath = filepath.Clean("/" + filePath)
if strings.Contains(filePath, "..") {
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"code": 403, "message": "invalid path"})
return
}
// 2. 限制在上传目录内
fullPath := filepath.Join("./uploads", filePath)
absUploads, _ := filepath.Abs("./uploads")
absPath, _ := filepath.Abs(fullPath)
if !strings.HasPrefix(absPath, absUploads) {
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"code": 403, "message": "access denied"})
return
}
// 3. 检查文件存在
if _, err := os.Stat(fullPath); os.IsNotExist(err) {
c.AbortWithStatus(http.StatusNotFound)
return
}
// 4. 设置安全响应头(禁止浏览器执行)
c.Header("Content-Security-Policy", "default-src 'none'")
c.Header("X-Content-Type-Options", "nosniff")
// 5. 提供文件
c.File(fullPath)
}

View File

@@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/url" "net/url"
"time"
"github.com/user-management-system/internal/auth/providers" "github.com/user-management-system/internal/auth/providers"
) )
@@ -71,6 +72,9 @@ type OAuthManager interface {
// ValidateToken 验证令牌 // ValidateToken 验证令牌
ValidateToken(token string) (bool, error) ValidateToken(token string) (bool, error)
// ValidateTokenWithProvider 通过指定 provider 验证令牌
ValidateTokenWithProvider(ctx context.Context, provider OAuthProvider, token string) (bool, error)
// GetConfig 获取OAuth配置 // GetConfig 获取OAuth配置
GetConfig(provider OAuthProvider) (*OAuthConfig, bool) GetConfig(provider OAuthProvider) (*OAuthConfig, bool)
@@ -442,9 +446,11 @@ func (m *DefaultOAuthManager) ValidateToken(token string) (bool, error) {
if len(providers) == 0 { if len(providers) == 0 {
return false, errors.New("no OAuth providers configured") return false, errors.New("no OAuth providers configured")
} }
// 添加 5 秒超时,防止 provider API 无响应导致阻塞
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// 尝试任一 provider 的 userinfo 端点验证 // 尝试任一 provider 的 userinfo 端点验证
tokenObj := &OAuthToken{AccessToken: token} tokenObj := &OAuthToken{AccessToken: token}
ctx := context.Background()
for _, p := range providers { for _, p := range providers {
if _, err := m.GetUserInfo(ctx, p.Provider, tokenObj); err == nil { if _, err := m.GetUserInfo(ctx, p.Provider, tokenObj); err == nil {
return true, nil return true, nil
@@ -454,10 +460,13 @@ func (m *DefaultOAuthManager) ValidateToken(token string) (bool, error) {
} }
// ValidateTokenWithProvider 通过指定 provider 验证令牌 // ValidateTokenWithProvider 通过指定 provider 验证令牌
func (m *DefaultOAuthManager) ValidateTokenWithProvider(provider OAuthProvider, token string) (bool, error) { func (m *DefaultOAuthManager) ValidateTokenWithProvider(ctx context.Context, provider OAuthProvider, token string) (bool, error) {
if token == "" { if token == "" {
return false, nil return false, nil
} }
if ctx == nil {
ctx = context.Background()
}
cfg, ok := m.GetConfig(provider) cfg, ok := m.GetConfig(provider)
if !ok || cfg.ClientID == "" { if !ok || cfg.ClientID == "" {
@@ -466,7 +475,6 @@ func (m *DefaultOAuthManager) ValidateTokenWithProvider(provider OAuthProvider,
// 通过 provider 的 userinfo 端点验证 token // 通过 provider 的 userinfo 端点验证 token
tokenObj := &OAuthToken{AccessToken: token} tokenObj := &OAuthToken{AccessToken: token}
ctx := context.Background()
_, err := m.GetUserInfo(ctx, provider, tokenObj) _, err := m.GetUserInfo(ctx, provider, tokenObj)
if err != nil { if err != nil {
return false, err return false, err

View File

@@ -175,15 +175,16 @@ func TestDefaultOAuthManager_ValidateToken(t *testing.T) {
func TestDefaultOAuthManager_ValidateTokenWithProvider(t *testing.T) { func TestDefaultOAuthManager_ValidateTokenWithProvider(t *testing.T) {
m := NewOAuthManager() m := NewOAuthManager()
ctx := context.Background()
// Test empty token // Test empty token
valid, err := m.ValidateTokenWithProvider(OAuthProviderGoogle, "") valid, err := m.ValidateTokenWithProvider(ctx, OAuthProviderGoogle, "")
if valid || err != nil { if valid || err != nil {
t.Errorf("ValidateTokenWithProvider('') = %v, %v, want false, nil", valid, err) t.Errorf("ValidateTokenWithProvider('') = %v, %v, want false, nil", valid, err)
} }
// Test non-existent provider // Test non-existent provider
valid, err = m.ValidateTokenWithProvider(OAuthProviderGoogle, "some-token") valid, err = m.ValidateTokenWithProvider(ctx, OAuthProviderGoogle, "some-token")
if valid { if valid {
t.Error("ValidateTokenWithProvider() should return false for unconfigured provider") t.Error("ValidateTokenWithProvider() should return false for unconfigured provider")
} }
@@ -607,7 +608,7 @@ func TestOAuthManager_ValidateTokenWithProvider_WithConfig(t *testing.T) {
}) })
// ValidateTokenWithProvider will try GetUserInfo which will fail // ValidateTokenWithProvider will try GetUserInfo which will fail
valid, err := m.ValidateTokenWithProvider(OAuthProviderGoogle, "some-token") valid, err := m.ValidateTokenWithProvider(context.Background(), OAuthProviderGoogle, "some-token")
// Should return false // Should return false
if valid { if valid {
t.Error("ValidateTokenWithProvider() should return false for invalid token") t.Error("ValidateTokenWithProvider() should return false for invalid token")

View File

@@ -59,6 +59,10 @@ func (m *mockOAuthManager) ValidateToken(token string) (bool, error) {
return token != "", nil return token != "", nil
} }
func (m *mockOAuthManager) ValidateTokenWithProvider(ctx context.Context, provider auth.OAuthProvider, token string) (bool, error) {
return token != "", nil
}
func (m *mockOAuthManager) GetConfig(provider auth.OAuthProvider) (*auth.OAuthConfig, bool) { func (m *mockOAuthManager) GetConfig(provider auth.OAuthProvider) (*auth.OAuthConfig, bool) {
if m.config != nil { if m.config != nil {
return m.config, true return m.config, true