fix(security): /uploads 目录路径遍历防护
- 替换 Static 为受控文件服务 handler (serveUploads) - 添加 filepath.Clean 路径清理 + .. 检测 - 使用 Abs + HasPrefix 限制访问范围在上传目录内 - 添加安全响应头(CSP default-src 'none', X-Content-Type-Options nosniff)
This commit is contained in:
@@ -1,6 +1,11 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
swaggerFiles "github.com/swaggo/files"
|
||||
@@ -122,9 +127,9 @@ func (r *Router) Setup() *gin.Engine {
|
||||
)
|
||||
}
|
||||
|
||||
// P0 安全修复:/uploads 目录不再公开暴露,改为需要认证后才能访问
|
||||
// P0 安全修复:/uploads 目录使用受控文件服务,防止路径遍历
|
||||
uploadsGroup := r.engine.Group("/uploads", r.authMiddleware.Required())
|
||||
uploadsGroup.Static("", "./uploads")
|
||||
uploadsGroup.GET("/*filepath", r.serveUploads)
|
||||
|
||||
r.engine.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler))
|
||||
|
||||
@@ -408,3 +413,37 @@ func (r *Router) Setup() *gin.Engine {
|
||||
func (r *Router) GetEngine() *gin.Engine {
|
||||
return r.engine
|
||||
}
|
||||
|
||||
// serveUploads 提供受控的上传文件访问,防止路径遍历攻击
|
||||
func (r *Router) serveUploads(c *gin.Context) {
|
||||
filePath := c.Param("filepath")
|
||||
|
||||
// 1. 清理路径,阻止路径遍历
|
||||
filePath = filepath.Clean("/" + filePath)
|
||||
if strings.Contains(filePath, "..") {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"code": 403, "message": "invalid path"})
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 限制在上传目录内
|
||||
fullPath := filepath.Join("./uploads", filePath)
|
||||
absUploads, _ := filepath.Abs("./uploads")
|
||||
absPath, _ := filepath.Abs(fullPath)
|
||||
if !strings.HasPrefix(absPath, absUploads) {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"code": 403, "message": "access denied"})
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 检查文件存在
|
||||
if _, err := os.Stat(fullPath); os.IsNotExist(err) {
|
||||
c.AbortWithStatus(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// 4. 设置安全响应头(禁止浏览器执行)
|
||||
c.Header("Content-Security-Policy", "default-src 'none'")
|
||||
c.Header("X-Content-Type-Options", "nosniff")
|
||||
|
||||
// 5. 提供文件
|
||||
c.File(fullPath)
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/user-management-system/internal/auth/providers"
|
||||
)
|
||||
@@ -71,6 +72,9 @@ type OAuthManager interface {
|
||||
// ValidateToken 验证令牌
|
||||
ValidateToken(token string) (bool, error)
|
||||
|
||||
// ValidateTokenWithProvider 通过指定 provider 验证令牌
|
||||
ValidateTokenWithProvider(ctx context.Context, provider OAuthProvider, token string) (bool, error)
|
||||
|
||||
// GetConfig 获取OAuth配置
|
||||
GetConfig(provider OAuthProvider) (*OAuthConfig, bool)
|
||||
|
||||
@@ -442,9 +446,11 @@ func (m *DefaultOAuthManager) ValidateToken(token string) (bool, error) {
|
||||
if len(providers) == 0 {
|
||||
return false, errors.New("no OAuth providers configured")
|
||||
}
|
||||
// 添加 5 秒超时,防止 provider API 无响应导致阻塞
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
// 尝试任一 provider 的 userinfo 端点验证
|
||||
tokenObj := &OAuthToken{AccessToken: token}
|
||||
ctx := context.Background()
|
||||
for _, p := range providers {
|
||||
if _, err := m.GetUserInfo(ctx, p.Provider, tokenObj); err == nil {
|
||||
return true, nil
|
||||
@@ -454,10 +460,13 @@ func (m *DefaultOAuthManager) ValidateToken(token string) (bool, error) {
|
||||
}
|
||||
|
||||
// ValidateTokenWithProvider 通过指定 provider 验证令牌
|
||||
func (m *DefaultOAuthManager) ValidateTokenWithProvider(provider OAuthProvider, token string) (bool, error) {
|
||||
func (m *DefaultOAuthManager) ValidateTokenWithProvider(ctx context.Context, provider OAuthProvider, token string) (bool, error) {
|
||||
if token == "" {
|
||||
return false, nil
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
cfg, ok := m.GetConfig(provider)
|
||||
if !ok || cfg.ClientID == "" {
|
||||
@@ -466,7 +475,6 @@ func (m *DefaultOAuthManager) ValidateTokenWithProvider(provider OAuthProvider,
|
||||
|
||||
// 通过 provider 的 userinfo 端点验证 token
|
||||
tokenObj := &OAuthToken{AccessToken: token}
|
||||
ctx := context.Background()
|
||||
_, err := m.GetUserInfo(ctx, provider, tokenObj)
|
||||
if err != nil {
|
||||
return false, err
|
||||
|
||||
@@ -175,15 +175,16 @@ func TestDefaultOAuthManager_ValidateToken(t *testing.T) {
|
||||
|
||||
func TestDefaultOAuthManager_ValidateTokenWithProvider(t *testing.T) {
|
||||
m := NewOAuthManager()
|
||||
ctx := context.Background()
|
||||
|
||||
// Test empty token
|
||||
valid, err := m.ValidateTokenWithProvider(OAuthProviderGoogle, "")
|
||||
valid, err := m.ValidateTokenWithProvider(ctx, OAuthProviderGoogle, "")
|
||||
if valid || err != nil {
|
||||
t.Errorf("ValidateTokenWithProvider('') = %v, %v, want false, nil", valid, err)
|
||||
}
|
||||
|
||||
// Test non-existent provider
|
||||
valid, err = m.ValidateTokenWithProvider(OAuthProviderGoogle, "some-token")
|
||||
valid, err = m.ValidateTokenWithProvider(ctx, OAuthProviderGoogle, "some-token")
|
||||
if valid {
|
||||
t.Error("ValidateTokenWithProvider() should return false for unconfigured provider")
|
||||
}
|
||||
@@ -607,7 +608,7 @@ func TestOAuthManager_ValidateTokenWithProvider_WithConfig(t *testing.T) {
|
||||
})
|
||||
|
||||
// ValidateTokenWithProvider will try GetUserInfo which will fail
|
||||
valid, err := m.ValidateTokenWithProvider(OAuthProviderGoogle, "some-token")
|
||||
valid, err := m.ValidateTokenWithProvider(context.Background(), OAuthProviderGoogle, "some-token")
|
||||
// Should return false
|
||||
if valid {
|
||||
t.Error("ValidateTokenWithProvider() should return false for invalid token")
|
||||
|
||||
@@ -59,6 +59,10 @@ func (m *mockOAuthManager) ValidateToken(token string) (bool, error) {
|
||||
return token != "", nil
|
||||
}
|
||||
|
||||
func (m *mockOAuthManager) ValidateTokenWithProvider(ctx context.Context, provider auth.OAuthProvider, token string) (bool, error) {
|
||||
return token != "", nil
|
||||
}
|
||||
|
||||
func (m *mockOAuthManager) GetConfig(provider auth.OAuthProvider) (*auth.OAuthConfig, bool) {
|
||||
if m.config != nil {
|
||||
return m.config, true
|
||||
|
||||
Reference in New Issue
Block a user