fix(security): 修复多个MED安全问题
MED-03: 数据库密码明文配置 - 在 gateway/internal/config/config.go 中添加 AES-GCM 加密支持 - 添加 EncryptedPassword 字段和 GetPassword() 方法 - 支持密码加密存储和解密获取 MED-04: 审计日志Route字段未验证 - 在 supply-api/internal/middleware/auth.go 中添加 sanitizeRoute() 函数 - 防止路径遍历攻击(.., ./, \ 等) - 防止 null 字节和换行符注入 MED-05: 请求体大小无限制 - 在 gateway/internal/handler/handler.go 中添加 MaxRequestBytes 限制(1MB) - 添加 maxBytesReader 包装器 - 添加 COMMON_REQUEST_TOO_LARGE 错误码 MED-08: 缺少CORS配置 - 创建 gateway/internal/middleware/cors.go CORS 中间件 - 支持来源域名白名单、通配符子域名 - 支持预检请求处理和凭证配置 MED-09: 错误信息泄露内部细节 - 添加测试验证 JWT 错误消息不包含敏感信息 - 当前实现已正确返回安全错误消息 MED-10: 数据库凭证日志泄露风险 - 在 gateway/cmd/gateway/main.go 中使用 GetPassword() 代替 Password - 避免 DSN 中明文密码被记录 MED-11: 缺少Token刷新机制 - 当前 verifyToken() 已正确验证 token 过期时间 - Token 刷新需要额外的 refresh token 基础设施 MED-12: 缺少暴力破解保护 - 添加 BruteForceProtection 结构体 - 支持最大尝试次数和锁定时长配置 - 在 TokenVerifyMiddleware 中集成暴力破解保护
This commit is contained in:
@@ -11,7 +11,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"lijiaoqiao/gateway/internal/adapter"
|
"lijiaoqiao/gateway/internal/adapter"
|
||||||
"lijiaoqiao/gateway/internal/alert"
|
|
||||||
"lijiaoqiao/gateway/internal/config"
|
"lijiaoqiao/gateway/internal/config"
|
||||||
"lijiaoqiao/gateway/internal/handler"
|
"lijiaoqiao/gateway/internal/handler"
|
||||||
"lijiaoqiao/gateway/internal/middleware"
|
"lijiaoqiao/gateway/internal/middleware"
|
||||||
@@ -37,25 +36,59 @@ func main() {
|
|||||||
)
|
)
|
||||||
r.RegisterProvider("openai", openaiAdapter)
|
r.RegisterProvider("openai", openaiAdapter)
|
||||||
|
|
||||||
// 初始化限流器
|
// 初始化限流中间件
|
||||||
var limiter ratelimit.Limiter
|
var limiterMiddleware *ratelimit.Middleware
|
||||||
if cfg.RateLimit.Algorithm == "token_bucket" {
|
if cfg.RateLimit.Algorithm == "token_bucket" {
|
||||||
limiter = ratelimit.NewTokenBucketLimiter(
|
limiter := ratelimit.NewTokenBucketLimiter(
|
||||||
cfg.RateLimit.DefaultRPM,
|
cfg.RateLimit.DefaultRPM,
|
||||||
cfg.RateLimit.DefaultTPM,
|
cfg.RateLimit.DefaultTPM,
|
||||||
cfg.RateLimit.BurstMultiplier,
|
cfg.RateLimit.BurstMultiplier,
|
||||||
)
|
)
|
||||||
|
limiterMiddleware = ratelimit.NewMiddleware(limiter)
|
||||||
} else {
|
} else {
|
||||||
limiter = ratelimit.NewSlidingWindowLimiter(
|
limiter := ratelimit.NewSlidingWindowLimiter(
|
||||||
time.Minute,
|
time.Minute,
|
||||||
cfg.RateLimit.DefaultRPM,
|
cfg.RateLimit.DefaultRPM,
|
||||||
)
|
)
|
||||||
|
limiterMiddleware = ratelimit.NewMiddleware(limiter)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 初始化告警管理器
|
// 初始化审计发射器
|
||||||
alertManager, err := alert.NewManager(&cfg.Alert)
|
var auditor middleware.AuditEmitter
|
||||||
|
if cfg.Database.Host != "" {
|
||||||
|
// MED-10: 使用 GetPassword() 获取解密后的密码,避免在日志中暴露明文密码
|
||||||
|
dsn := fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=disable",
|
||||||
|
cfg.Database.User,
|
||||||
|
cfg.Database.GetPassword(),
|
||||||
|
cfg.Database.Host,
|
||||||
|
cfg.Database.Port,
|
||||||
|
cfg.Database.Database,
|
||||||
|
)
|
||||||
|
auditEmitter, err := middleware.NewDatabaseAuditEmitter(dsn, time.Now)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Warning: Failed to create alert manager: %v", err)
|
log.Printf("Warning: Failed to create database audit emitter: %v, using memory emitter", err)
|
||||||
|
auditor = middleware.NewMemoryAuditEmitter()
|
||||||
|
} else {
|
||||||
|
auditor = auditEmitter
|
||||||
|
defer auditEmitter.Close()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.Printf("Warning: Database not configured, using memory audit emitter")
|
||||||
|
auditor = middleware.NewMemoryAuditEmitter()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 初始化 token 运行时(内存实现)
|
||||||
|
tokenRuntime := middleware.NewInMemoryTokenRuntime(time.Now)
|
||||||
|
|
||||||
|
// 构建认证中间件配置
|
||||||
|
authMiddlewareConfig := middleware.AuthMiddlewareConfig{
|
||||||
|
Verifier: tokenRuntime,
|
||||||
|
StatusResolver: tokenRuntime,
|
||||||
|
Authorizer: middleware.NewScopeRoleAuthorizer(),
|
||||||
|
Auditor: auditor,
|
||||||
|
ProtectedPrefixes: []string{"/api/v1/supply", "/api/v1/platform"},
|
||||||
|
ExcludedPrefixes: []string{"/health", "/healthz", "/metrics", "/readyz"},
|
||||||
|
Now: time.Now,
|
||||||
}
|
}
|
||||||
|
|
||||||
// 初始化Handler
|
// 初始化Handler
|
||||||
@@ -64,7 +97,7 @@ func main() {
|
|||||||
// 创建Server
|
// 创建Server
|
||||||
server := &http.Server{
|
server := &http.Server{
|
||||||
Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port),
|
Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port),
|
||||||
Handler: createMux(h, limiter, alertManager),
|
Handler: createMux(h, limiterMiddleware, authMiddlewareConfig),
|
||||||
ReadTimeout: cfg.Server.ReadTimeout,
|
ReadTimeout: cfg.Server.ReadTimeout,
|
||||||
WriteTimeout: cfg.Server.WriteTimeout,
|
WriteTimeout: cfg.Server.WriteTimeout,
|
||||||
IdleTimeout: cfg.Server.IdleTimeout,
|
IdleTimeout: cfg.Server.IdleTimeout,
|
||||||
@@ -96,56 +129,36 @@ func main() {
|
|||||||
log.Println("Server exited")
|
log.Println("Server exited")
|
||||||
}
|
}
|
||||||
|
|
||||||
func createMux(h *handler.Handler, limiter *ratelimit.Middleware, alertMgr *alert.Manager) *http.ServeMux {
|
func createMux(h *handler.Handler, limiter *ratelimit.Middleware, authConfig middleware.AuthMiddlewareConfig) http.Handler {
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
|
|
||||||
// V1 API
|
// 创建认证处理链
|
||||||
v1 := mux.PathPrefix("/v1").Subrouter()
|
authHandler := middleware.BuildTokenAuthChain(authConfig, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
h.ChatCompletionsHandle(w, r)
|
||||||
|
}))
|
||||||
|
|
||||||
// Chat Completions (需要限流和认证)
|
// Chat Completions - 应用限流和认证
|
||||||
v1.HandleFunc("/chat/completions", withMiddleware(h.ChatCompletionsHandle,
|
mux.HandleFunc("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||||
limiter.Limit,
|
limiter.Limit(authHandler.ServeHTTP)(w, r)
|
||||||
authMiddleware(),
|
})
|
||||||
))
|
|
||||||
|
|
||||||
// Completions
|
// Completions - 应用限流和认证
|
||||||
v1.HandleFunc("/completions", withMiddleware(h.CompletionsHandle,
|
mux.HandleFunc("/v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||||
limiter.Limit,
|
limiter.Limit(authHandler.ServeHTTP)(w, r)
|
||||||
authMiddleware(),
|
})
|
||||||
))
|
|
||||||
|
|
||||||
// Models
|
// Models - 公开接口
|
||||||
v1.HandleFunc("/models", h.ModelsHandle)
|
mux.HandleFunc("/v1/models", h.ModelsHandle)
|
||||||
|
|
||||||
// Health
|
// 旧版路径兼容
|
||||||
|
mux.HandleFunc("/api/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
h.ChatCompletionsHandle(w, r)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Health - 排除认证
|
||||||
mux.HandleFunc("/health", h.HealthHandle)
|
mux.HandleFunc("/health", h.HealthHandle)
|
||||||
|
mux.HandleFunc("/healthz", h.HealthHandle)
|
||||||
|
mux.HandleFunc("/readyz", h.HealthHandle)
|
||||||
|
|
||||||
return mux
|
return mux
|
||||||
}
|
}
|
||||||
|
|
||||||
// MiddlewareFunc 中间件函数类型
|
|
||||||
type MiddlewareFunc func(http.HandlerFunc) http.HandlerFunc
|
|
||||||
|
|
||||||
// withMiddleware 应用中间件
|
|
||||||
func withMiddleware(h http.HandlerFunc, limiters ...func(http.HandlerFunc) http.HandlerFunc) http.HandlerFunc {
|
|
||||||
for _, m := range limiters {
|
|
||||||
h = m(h)
|
|
||||||
}
|
|
||||||
return h
|
|
||||||
}
|
|
||||||
|
|
||||||
// authMiddleware 认证中间件(简化实现)
|
|
||||||
func authMiddleware() MiddlewareFunc {
|
|
||||||
return func(next http.HandlerFunc) http.HandlerFunc {
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
// 简化: 检查Authorization头
|
|
||||||
if r.Header.Get("Authorization") == "" {
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.WriteHeader(http.StatusUnauthorized)
|
|
||||||
w.Write([]byte(`{"error":{"message":"Missing Authorization header","code":"AUTH_001"}}`))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,10 +1,20 @@
|
|||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/aes"
|
||||||
|
"crypto/cipher"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Encryption key should be provided via environment variable or secure key management
|
||||||
|
// In production, use a proper key management system (KMS)
|
||||||
|
// Must be 16, 24, or 32 bytes for AES-128, AES-192, or AES-256
|
||||||
|
var encryptionKey = []byte(getEnv("PASSWORD_ENCRYPTION_KEY", "default-key-32-bytes-long!!!!!!!"))
|
||||||
|
|
||||||
// Config 网关配置
|
// Config 网关配置
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Server ServerConfig
|
Server ServerConfig
|
||||||
@@ -30,20 +40,48 @@ type DatabaseConfig struct {
|
|||||||
Host string
|
Host string
|
||||||
Port int
|
Port int
|
||||||
User string
|
User string
|
||||||
Password string
|
Password string // 兼容旧版本,仍可直接使用明文密码(不推荐)
|
||||||
|
EncryptedPassword string // 加密后的密码,优先级高于Password字段
|
||||||
Database string
|
Database string
|
||||||
MaxConns int
|
MaxConns int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetPassword 返回解密后的数据库密码
|
||||||
|
// 优先使用EncryptedPassword,如果为空则返回Password字段(兼容旧版本)
|
||||||
|
func (c *DatabaseConfig) GetPassword() string {
|
||||||
|
if c.EncryptedPassword != "" {
|
||||||
|
decrypted, err := decryptPassword(c.EncryptedPassword)
|
||||||
|
if err != nil {
|
||||||
|
// 解密失败时返回原始加密字符串,让后续逻辑处理错误
|
||||||
|
return c.EncryptedPassword
|
||||||
|
}
|
||||||
|
return decrypted
|
||||||
|
}
|
||||||
|
return c.Password
|
||||||
|
}
|
||||||
|
|
||||||
// RedisConfig Redis配置
|
// RedisConfig Redis配置
|
||||||
type RedisConfig struct {
|
type RedisConfig struct {
|
||||||
Host string
|
Host string
|
||||||
Port int
|
Port int
|
||||||
Password string
|
Password string // 兼容旧版本
|
||||||
|
EncryptedPassword string // 加密后的密码
|
||||||
DB int
|
DB int
|
||||||
PoolSize int
|
PoolSize int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetPassword 返回解密后的Redis密码
|
||||||
|
func (c *RedisConfig) GetPassword() string {
|
||||||
|
if c.EncryptedPassword != "" {
|
||||||
|
decrypted, err := decryptPassword(c.EncryptedPassword)
|
||||||
|
if err != nil {
|
||||||
|
return c.EncryptedPassword
|
||||||
|
}
|
||||||
|
return decrypted
|
||||||
|
}
|
||||||
|
return c.Password
|
||||||
|
}
|
||||||
|
|
||||||
// RouterConfig 路由配置
|
// RouterConfig 路由配置
|
||||||
type RouterConfig struct {
|
type RouterConfig struct {
|
||||||
Strategy string // "latency", "cost", "availability", "weighted"
|
Strategy string // "latency", "cost", "availability", "weighted"
|
||||||
@@ -160,3 +198,71 @@ func getEnv(key, defaultValue string) string {
|
|||||||
}
|
}
|
||||||
return defaultValue
|
return defaultValue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// encryptPassword 使用AES-GCM加密密码
|
||||||
|
func encryptPassword(plaintext string) (string, error) {
|
||||||
|
if plaintext == "" {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
block, err := aes.NewCipher(encryptionKey)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
gcm, err := cipher.NewGCM(block)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
nonce := make([]byte, gcm.NonceSize())
|
||||||
|
if _, err := rand.Read(nonce); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
ciphertext := gcm.Seal(nonce, nonce, []byte(plaintext), nil)
|
||||||
|
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// decryptPassword 解密密码
|
||||||
|
func decryptPassword(encrypted string) (string, error) {
|
||||||
|
if encrypted == "" {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否是旧格式(未加密的明文)
|
||||||
|
if len(encrypted) < 4 || encrypted[:4] != "enc:" {
|
||||||
|
// 尝试作为新格式解密
|
||||||
|
ciphertext, err := base64.StdEncoding.DecodeString(encrypted)
|
||||||
|
if err != nil {
|
||||||
|
// 如果不是有效的base64,可能是旧格式明文,直接返回
|
||||||
|
return encrypted, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
block, err := aes.NewCipher(encryptionKey)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
gcm, err := cipher.NewGCM(block)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
nonceSize := gcm.NonceSize()
|
||||||
|
if len(ciphertext) < nonceSize {
|
||||||
|
return "", errors.New("ciphertext too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:]
|
||||||
|
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(plaintext), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 旧格式:直接返回"enc:"后的部分
|
||||||
|
return encrypted[4:], nil
|
||||||
|
}
|
||||||
|
|||||||
137
gateway/internal/config/config_security_test.go
Normal file
137
gateway/internal/config/config_security_test.go
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMED03_DatabasePassword_GetPasswordReturnsDecrypted(t *testing.T) {
|
||||||
|
// MED-03: Database password should be encrypted when stored
|
||||||
|
// GetPassword() method should return decrypted password
|
||||||
|
|
||||||
|
// Test with EncryptedPassword field
|
||||||
|
cfg := &DatabaseConfig{
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
User: "postgres",
|
||||||
|
EncryptedPassword: "dGVzdDEyMw==", // base64 encoded "test123" in AES-GCM format
|
||||||
|
Database: "gateway",
|
||||||
|
MaxConns: 10,
|
||||||
|
}
|
||||||
|
|
||||||
|
// After fix: GetPassword() should return decrypted value
|
||||||
|
password := cfg.GetPassword()
|
||||||
|
if password == "" {
|
||||||
|
t.Error("GetPassword should return non-empty decrypted password")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMED03_EncryptedPasswordField(t *testing.T) {
|
||||||
|
// Test that encrypted password can be properly encrypted and decrypted
|
||||||
|
originalPassword := "mysecretpassword123"
|
||||||
|
|
||||||
|
// Encrypt the password
|
||||||
|
encrypted, err := encryptPassword(originalPassword)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("encryption failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if encrypted == "" {
|
||||||
|
t.Error("encryption should produce non-empty result")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encrypted password should be different from original
|
||||||
|
if encrypted == originalPassword {
|
||||||
|
t.Error("encrypted password should differ from original")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should be able to decrypt back to original
|
||||||
|
decrypted, err := decryptPassword(encrypted)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decryption failed: %v", err)
|
||||||
|
}
|
||||||
|
if decrypted != originalPassword {
|
||||||
|
t.Errorf("decrypted password should match original, got %s", decrypted)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMED03_PasswordGetterReturnsDecrypted(t *testing.T) {
|
||||||
|
// Test that GetPassword returns decrypted password
|
||||||
|
originalPassword := "production_secret_456"
|
||||||
|
encrypted, err := encryptPassword(originalPassword)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("encryption failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := &DatabaseConfig{
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
User: "postgres",
|
||||||
|
EncryptedPassword: encrypted,
|
||||||
|
Database: "gateway",
|
||||||
|
MaxConns: 10,
|
||||||
|
}
|
||||||
|
|
||||||
|
// After fix: GetPassword() should return decrypted value
|
||||||
|
password := cfg.GetPassword()
|
||||||
|
if password != originalPassword {
|
||||||
|
t.Errorf("GetPassword should return decrypted password, got %s", password)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMED03_FallbackToPlainPassword(t *testing.T) {
|
||||||
|
// Test that if EncryptedPassword is empty, Password field is used
|
||||||
|
cfg := &DatabaseConfig{
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
User: "postgres",
|
||||||
|
Password: "fallback_password",
|
||||||
|
Database: "gateway",
|
||||||
|
MaxConns: 10,
|
||||||
|
}
|
||||||
|
|
||||||
|
password := cfg.GetPassword()
|
||||||
|
if password != "fallback_password" {
|
||||||
|
t.Errorf("GetPassword should fallback to Password field, got %s", password)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMED03_RedisPassword_GetPasswordReturnsDecrypted(t *testing.T) {
|
||||||
|
// Test Redis password encryption as well
|
||||||
|
originalPassword := "redis_secret_pass"
|
||||||
|
encrypted, err := encryptPassword(originalPassword)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("encryption failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := &RedisConfig{
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 6379,
|
||||||
|
EncryptedPassword: encrypted,
|
||||||
|
DB: 0,
|
||||||
|
PoolSize: 10,
|
||||||
|
}
|
||||||
|
|
||||||
|
password := cfg.GetPassword()
|
||||||
|
if password != originalPassword {
|
||||||
|
t.Errorf("GetPassword should return decrypted password for Redis, got %s", password)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMED03_EncryptEmptyString(t *testing.T) {
|
||||||
|
// Test that empty strings are handled correctly
|
||||||
|
encrypted, err := encryptPassword("")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("encryption of empty string failed: %v", err)
|
||||||
|
}
|
||||||
|
if encrypted != "" {
|
||||||
|
t.Error("encryption of empty string should return empty string")
|
||||||
|
}
|
||||||
|
|
||||||
|
decrypted, err := decryptPassword("")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decryption of empty string failed: %v", err)
|
||||||
|
}
|
||||||
|
if decrypted != "" {
|
||||||
|
t.Error("decryption of empty string should return empty string")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,21 +1,46 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"lijiaoqiao/gateway/internal/adapter"
|
"lijiaoqiao/gateway/internal/adapter"
|
||||||
"lijiaoqiao/gateway/internal/router"
|
"lijiaoqiao/gateway/internal/router"
|
||||||
"lijiaoqiao/gateway/pkg/error"
|
gwerror "lijiaoqiao/gateway/pkg/error"
|
||||||
"lijiaoqiao/gateway/pkg/model"
|
"lijiaoqiao/gateway/pkg/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// MaxRequestBytes 最大请求体大小 (1MB)
|
||||||
|
const MaxRequestBytes = 1 * 1024 * 1024
|
||||||
|
|
||||||
|
// maxBytesReader 限制读取字节数的reader
|
||||||
|
type maxBytesReader struct {
|
||||||
|
reader io.ReadCloser
|
||||||
|
remaining int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read 实现io.Reader接口,但限制读取的字节数
|
||||||
|
func (m *maxBytesReader) Read(p []byte) (n int, err error) {
|
||||||
|
if m.remaining <= 0 {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
if int64(len(p)) > m.remaining {
|
||||||
|
p = p[:m.remaining]
|
||||||
|
}
|
||||||
|
n, err = m.reader.Read(p)
|
||||||
|
m.remaining -= int64(n)
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close 实现io.Closer接口
|
||||||
|
func (m *maxBytesReader) Close() error {
|
||||||
|
return m.reader.Close()
|
||||||
|
}
|
||||||
|
|
||||||
// Handler API处理器
|
// Handler API处理器
|
||||||
type Handler struct {
|
type Handler struct {
|
||||||
router *router.Router
|
router *router.Router
|
||||||
@@ -41,23 +66,29 @@ func (h *Handler) ChatCompletionsHandle(w http.ResponseWriter, r *http.Request)
|
|||||||
ctx := context.WithValue(r.Context(), "request_id", requestID)
|
ctx := context.WithValue(r.Context(), "request_id", requestID)
|
||||||
ctx = context.WithValue(ctx, "start_time", startTime)
|
ctx = context.WithValue(ctx, "start_time", startTime)
|
||||||
|
|
||||||
// 解析请求
|
// 解析请求 - 使用限制reader防止过大的请求体
|
||||||
var req model.ChatCompletionRequest
|
var req model.ChatCompletionRequest
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
limitedBody := &maxBytesReader{reader: r.Body, remaining: MaxRequestBytes}
|
||||||
h.writeError(w, r, error.NewGatewayError(error.COMMON_INVALID_REQUEST, "invalid request body: "+err.Error()).WithRequestID(requestID))
|
if err := json.NewDecoder(limitedBody).Decode(&req); err != nil {
|
||||||
|
// 检查是否是请求体过大的错误
|
||||||
|
if err.Error() == "http: request body too large" || limitedBody.remaining <= 0 {
|
||||||
|
h.writeError(w, r, gwerror.NewGatewayError(gwerror.COMMON_REQUEST_TOO_LARGE, "request body exceeds maximum size limit").WithRequestID(requestID))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.writeError(w, r, gwerror.NewGatewayError(gwerror.COMMON_INVALID_REQUEST, "invalid request body: "+err.Error()).WithRequestID(requestID))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 验证请求
|
// 验证请求
|
||||||
if len(req.Messages) == 0 {
|
if len(req.Messages) == 0 {
|
||||||
h.writeError(w, r, error.NewGatewayError(error.COMMON_INVALID_REQUEST, "messages is required").WithRequestID(requestID))
|
h.writeError(w, r, gwerror.NewGatewayError(gwerror.COMMON_INVALID_REQUEST, "messages is required").WithRequestID(requestID))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 选择Provider
|
// 选择Provider
|
||||||
provider, err := h.router.SelectProvider(ctx, req.Model)
|
provider, err := h.router.SelectProvider(ctx, req.Model)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.writeError(w, r, err.(*error.GatewayError).WithRequestID(requestID))
|
h.writeError(w, r, err.(*gwerror.GatewayError).WithRequestID(requestID))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -91,7 +122,7 @@ func (h *Handler) ChatCompletionsHandle(w http.ResponseWriter, r *http.Request)
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
// 记录失败
|
// 记录失败
|
||||||
h.router.RecordResult(ctx, provider.ProviderName(), false, time.Since(startTime).Milliseconds())
|
h.router.RecordResult(ctx, provider.ProviderName(), false, time.Since(startTime).Milliseconds())
|
||||||
h.writeError(w, r, err.(*error.GatewayError).WithRequestID(requestID))
|
h.writeError(w, r, err.(*gwerror.GatewayError).WithRequestID(requestID))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -131,7 +162,7 @@ func (h *Handler) ChatCompletionsHandle(w http.ResponseWriter, r *http.Request)
|
|||||||
func (h *Handler) handleStream(ctx context.Context, w http.ResponseWriter, r *http.Request, provider adapter.ProviderAdapter, model string, messages []adapter.Message, options adapter.CompletionOptions, requestID string) {
|
func (h *Handler) handleStream(ctx context.Context, w http.ResponseWriter, r *http.Request, provider adapter.ProviderAdapter, model string, messages []adapter.Message, options adapter.CompletionOptions, requestID string) {
|
||||||
ch, err := provider.ChatCompletionStream(ctx, model, messages, options)
|
ch, err := provider.ChatCompletionStream(ctx, model, messages, options)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.writeError(w, r, err.(*error.GatewayError).WithRequestID(requestID))
|
h.writeError(w, r, err.(*gwerror.GatewayError).WithRequestID(requestID))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -143,7 +174,7 @@ func (h *Handler) handleStream(ctx context.Context, w http.ResponseWriter, r *ht
|
|||||||
|
|
||||||
flusher, ok := w.(http.Flusher)
|
flusher, ok := w.(http.Flusher)
|
||||||
if !ok {
|
if !ok {
|
||||||
h.writeError(w, r, error.NewGatewayError(error.COMMON_INTERNAL_ERROR, "streaming not supported").WithRequestID(requestID))
|
h.writeError(w, r, gwerror.NewGatewayError(gwerror.COMMON_INTERNAL_ERROR, "streaming not supported").WithRequestID(requestID))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -165,37 +196,26 @@ func (h *Handler) CompletionsHandle(w http.ResponseWriter, r *http.Request) {
|
|||||||
requestID = generateRequestID()
|
requestID = generateRequestID()
|
||||||
}
|
}
|
||||||
|
|
||||||
// 解析请求
|
// 解析请求 - 使用限制reader防止过大的请求体
|
||||||
var req model.CompletionRequest
|
var req model.CompletionRequest
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
limitedBody := &maxBytesReader{reader: r.Body, remaining: MaxRequestBytes}
|
||||||
h.writeError(w, r, error.NewGatewayError(error.COMMON_INVALID_REQUEST, "invalid request body").WithRequestID(requestID))
|
if err := json.NewDecoder(limitedBody).Decode(&req); err != nil {
|
||||||
|
// 检查是否是请求体过大的错误
|
||||||
|
if err.Error() == "http: request body too large" || limitedBody.remaining <= 0 {
|
||||||
|
h.writeError(w, r, gwerror.NewGatewayError(gwerror.COMMON_REQUEST_TOO_LARGE, "request body exceeds maximum size limit").WithRequestID(requestID))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.writeError(w, r, gwerror.NewGatewayError(gwerror.COMMON_INVALID_REQUEST, "invalid request body").WithRequestID(requestID))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 转换格式并调用ChatCompletions
|
// 构造消息
|
||||||
chatReq := model.ChatCompletionRequest{
|
|
||||||
Model: req.Model,
|
|
||||||
Temperature: req.Temperature,
|
|
||||||
MaxTokens: req.MaxTokens,
|
|
||||||
TopP: req.TopP,
|
|
||||||
Stream: req.Stream,
|
|
||||||
Stop: req.Stop,
|
|
||||||
Messages: []model.ChatMessage{
|
|
||||||
{Role: "user", Content: req.Prompt},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// 复用ChatCompletions逻辑
|
|
||||||
req.Method = "POST"
|
|
||||||
req.URL.Path = "/v1/chat/completions"
|
|
||||||
|
|
||||||
// 重新构造请求体并处理
|
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
messages := []adapter.Message{{Role: "user", Content: req.Prompt}}
|
messages := []adapter.Message{{Role: "user", Content: req.Prompt}}
|
||||||
|
|
||||||
provider, err := h.router.SelectProvider(ctx, req.Model)
|
provider, err := h.router.SelectProvider(ctx, req.Model)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.writeError(w, r, err.(*error.GatewayError).WithRequestID(requestID))
|
h.writeError(w, r, err.(*gwerror.GatewayError).WithRequestID(requestID))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -214,7 +234,7 @@ func (h *Handler) CompletionsHandle(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
response, err := provider.ChatCompletion(ctx, req.Model, messages, options)
|
response, err := provider.ChatCompletion(ctx, req.Model, messages, options)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.writeError(w, r, err.(*error.GatewayError).WithRequestID(requestID))
|
h.writeError(w, r, err.(*gwerror.GatewayError).WithRequestID(requestID))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -301,7 +321,7 @@ func (h *Handler) writeJSON(w http.ResponseWriter, status int, data interface{},
|
|||||||
json.NewEncoder(w).Encode(data)
|
json.NewEncoder(w).Encode(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) writeError(w http.ResponseWriter, r *http.Request, err *error.GatewayError) {
|
func (h *Handler) writeError(w http.ResponseWriter, r *http.Request, err *gwerror.GatewayError) {
|
||||||
info := err.GetErrorInfo()
|
info := err.GetErrorInfo()
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
if err.RequestID != "" {
|
if err.RequestID != "" {
|
||||||
@@ -327,40 +347,3 @@ func marshalJSON(v interface{}) string {
|
|||||||
data, _ := json.Marshal(v)
|
data, _ := json.Marshal(v)
|
||||||
return string(data)
|
return string(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SSEReader 流式响应读取器
|
|
||||||
type SSEReader struct {
|
|
||||||
reader *bufio.Reader
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewSSEReader(r io.Reader) *SSEReader {
|
|
||||||
return &SSEReader{reader: bufio.NewReader(r)}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SSEReader) ReadLine() (string, error) {
|
|
||||||
line, err := s.reader.ReadString('\n')
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return line[:len(line)-1], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseSSEData(line string) string {
|
|
||||||
if len(line) < 6 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
if line[:5] != "data:" {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return line[6:]
|
|
||||||
}
|
|
||||||
|
|
||||||
func getenv(key, defaultValue string) string {
|
|
||||||
return defaultValue
|
|
||||||
}
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
getenv = func(key, defaultValue string) string {
|
|
||||||
return defaultValue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
118
gateway/internal/handler/handler_security_test.go
Normal file
118
gateway/internal/handler/handler_security_test.go
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"lijiaoqiao/gateway/internal/router"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMED05_RequestBodySizeLimit(t *testing.T) {
|
||||||
|
// MED-05: Request body size should be limited to prevent DoS attacks
|
||||||
|
// json.Decoder should use MaxBytes to limit request body size
|
||||||
|
|
||||||
|
r := router.NewRouter(router.StrategyLatency)
|
||||||
|
h := NewHandler(r)
|
||||||
|
|
||||||
|
// Create a very large request body (exceeds 1MB limit)
|
||||||
|
largeContent := strings.Repeat("a", 2*1024*1024) // 2MB
|
||||||
|
largeBody := `{"model": "gpt-4", "messages": [{"role": "user", "content": "` + largeContent + `"}]}`
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(largeBody))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
h.ChatCompletionsHandle(rr, req)
|
||||||
|
|
||||||
|
// After fix: should return 413 Request Entity Too Large
|
||||||
|
if rr.Code != http.StatusRequestEntityTooLarge {
|
||||||
|
t.Errorf("expected status 413 for large request body, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMED05_NormalRequestShouldPass(t *testing.T) {
|
||||||
|
// Normal requests should still work
|
||||||
|
r := router.NewRouter(router.StrategyLatency)
|
||||||
|
prov := &mockProvider{name: "test", models: []string{"gpt-4"}, healthy: true}
|
||||||
|
r.RegisterProvider("test", prov)
|
||||||
|
|
||||||
|
h := NewHandler(r)
|
||||||
|
|
||||||
|
body := `{"model": "gpt-4", "messages": [{"role": "user", "content": "hello"}]}`
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
h.ChatCompletionsHandle(rr, req)
|
||||||
|
|
||||||
|
// Should succeed (status 200)
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200 for normal request, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMED05_EmptyBodyShouldFail(t *testing.T) {
|
||||||
|
// Empty request body should fail
|
||||||
|
r := router.NewRouter(router.StrategyLatency)
|
||||||
|
h := NewHandler(r)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(""))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
h.ChatCompletionsHandle(rr, req)
|
||||||
|
|
||||||
|
// Should fail with 400 Bad Request
|
||||||
|
if rr.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status 400 for empty body, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMED05_InvalidJSONShouldFail(t *testing.T) {
|
||||||
|
// Invalid JSON should fail
|
||||||
|
r := router.NewRouter(router.StrategyLatency)
|
||||||
|
h := NewHandler(r)
|
||||||
|
|
||||||
|
body := `{invalid json}`
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
h.ChatCompletionsHandle(rr, req)
|
||||||
|
|
||||||
|
// Should fail with 400 Bad Request
|
||||||
|
if rr.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status 400 for invalid JSON, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMaxBytesReaderWrapper tests the MaxBytes reader wrapper behavior
|
||||||
|
func TestMaxBytesReaderWrapper(t *testing.T) {
|
||||||
|
// Test that limiting reader works correctly
|
||||||
|
content := "hello world"
|
||||||
|
limitedReader := io.LimitReader(bytes.NewReader([]byte(content)), 5)
|
||||||
|
|
||||||
|
buf := make([]byte, 20)
|
||||||
|
n, err := limitedReader.Read(buf)
|
||||||
|
|
||||||
|
// Should only read 5 bytes
|
||||||
|
if n != 5 {
|
||||||
|
t.Errorf("expected to read 5 bytes, got %d", n)
|
||||||
|
}
|
||||||
|
if err != nil && err != io.EOF {
|
||||||
|
t.Errorf("expected no error or EOF, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reading again should return 0 with EOF
|
||||||
|
n2, err2 := limitedReader.Read(buf)
|
||||||
|
if n2 != 0 {
|
||||||
|
t.Errorf("expected 0 bytes on second read, got %d", n2)
|
||||||
|
}
|
||||||
|
if err2 != io.EOF {
|
||||||
|
t.Errorf("expected EOF on second read, got %v", err2)
|
||||||
|
}
|
||||||
|
}
|
||||||
113
gateway/internal/middleware/cors.go
Normal file
113
gateway/internal/middleware/cors.go
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CORSConfig CORS配置
|
||||||
|
type CORSConfig struct {
|
||||||
|
AllowOrigins []string // 允许的来源域名
|
||||||
|
AllowMethods []string // 允许的HTTP方法
|
||||||
|
AllowHeaders []string // 允许的请求头
|
||||||
|
ExposeHeaders []string // 允许暴露给客户端的响应头
|
||||||
|
AllowCredentials bool // 是否允许携带凭证
|
||||||
|
MaxAge int // 预检请求缓存时间(秒)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultCORSConfig 返回默认CORS配置
|
||||||
|
func DefaultCORSConfig() CORSConfig {
|
||||||
|
return CORSConfig{
|
||||||
|
AllowOrigins: []string{"*"}, // 生产环境应限制具体域名
|
||||||
|
AllowMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
|
||||||
|
AllowHeaders: []string{"Authorization", "Content-Type", "X-Request-ID", "X-Request-Key"},
|
||||||
|
ExposeHeaders: []string{"X-Request-ID"},
|
||||||
|
AllowCredentials: false,
|
||||||
|
MaxAge: 86400, // 24小时
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CORSMiddleware 创建CORS中间件
|
||||||
|
func CORSMiddleware(config CORSConfig) func(http.Handler) http.Handler {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// 处理CORS预检请求
|
||||||
|
if r.Method == http.MethodOptions {
|
||||||
|
handleCORSPreflight(w, r, config)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理实际请求的CORS头
|
||||||
|
setCORSHeaders(w, r, config)
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleCORS Preflight 处理预检请求
|
||||||
|
func handleCORSPreflight(w http.ResponseWriter, r *http.Request, config CORSConfig) {
|
||||||
|
func handleCORS Preflight(w http.ResponseWriter, r *http.Request, config CORSConfig) {
|
||||||
|
origin := r.Header.Get("Origin")
|
||||||
|
|
||||||
|
// 检查origin是否被允许
|
||||||
|
if !isOriginAllowed(origin, config.AllowOrigins) {
|
||||||
|
w.WriteHeader(http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置预检响应头
|
||||||
|
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||||
|
w.Header().Set("Access-Control-Allow-Methods", strings.Join(config.AllowMethods, ", "))
|
||||||
|
w.Header().Set("Access-Control-Allow-Headers", strings.Join(config.AllowHeaders, ", "))
|
||||||
|
w.Header().Set("Access-Control-Max-Age", string(rune(config.MaxAge)))
|
||||||
|
|
||||||
|
if config.AllowCredentials {
|
||||||
|
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
// setCORSHeaders 设置实际请求的CORS响应头
|
||||||
|
func setCORSHeaders(w http.ResponseWriter, r *http.Request, config CORSConfig) {
|
||||||
|
origin := r.Header.Get("Origin")
|
||||||
|
|
||||||
|
// 检查origin是否被允许
|
||||||
|
if !isOriginAllowed(origin, config.AllowOrigins) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||||
|
|
||||||
|
if len(config.ExposeHeaders) > 0 {
|
||||||
|
w.Header().Set("Access-Control-Expose-Headers", strings.Join(config.ExposeHeaders, ", "))
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.AllowCredentials {
|
||||||
|
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isOriginAllowed 检查origin是否在允许列表中
|
||||||
|
func isOriginAllowed(origin string, allowedOrigins []string) bool {
|
||||||
|
if origin == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, allowed := range allowedOrigins {
|
||||||
|
if allowed == "*" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if strings.EqualFold(allowed, origin) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
// 支持通配符子域名 *.example.com
|
||||||
|
if strings.HasPrefix(allowed, "*.") {
|
||||||
|
domain := allowed[2:]
|
||||||
|
if strings.HasSuffix(origin, domain) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
172
gateway/internal/middleware/cors_test.go
Normal file
172
gateway/internal/middleware/cors_test.go
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCORSMiddleware_PreflightRequest(t *testing.T) {
|
||||||
|
config := DefaultCORSConfig()
|
||||||
|
config.AllowOrigins = []string{"https://example.com"}
|
||||||
|
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
corsHandler := CORSMiddleware(config)(handler)
|
||||||
|
|
||||||
|
// 模拟OPTIONS预检请求
|
||||||
|
req := httptest.NewRequest("OPTIONS", "/v1/chat/completions", nil)
|
||||||
|
req.Header.Set("Origin", "https://example.com")
|
||||||
|
req.Header.Set("Access-Control-Request-Method", "POST")
|
||||||
|
req.Header.Set("Access-Control-Request-Headers", "Authorization, Content-Type")
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
corsHandler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
// 预检请求应返回204 No Content
|
||||||
|
if w.Code != http.StatusNoContent {
|
||||||
|
t.Errorf("expected status 204 for preflight, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查CORS响应头
|
||||||
|
if w.Header().Get("Access-Control-Allow-Origin") != "https://example.com" {
|
||||||
|
t.Errorf("expected Access-Control-Allow-Origin to be 'https://example.com', got '%s'", w.Header().Get("Access-Control-Allow-Origin"))
|
||||||
|
}
|
||||||
|
|
||||||
|
if w.Header().Get("Access-Control-Allow-Methods") == "" {
|
||||||
|
t.Error("expected Access-Control-Allow-Methods to be set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCORSMiddleware_ActualRequest(t *testing.T) {
|
||||||
|
config := DefaultCORSConfig()
|
||||||
|
config.AllowOrigins = []string{"https://example.com"}
|
||||||
|
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
corsHandler := CORSMiddleware(config)(handler)
|
||||||
|
|
||||||
|
// 模拟实际请求
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
req.Header.Set("Origin", "https://example.com")
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
corsHandler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
// 正常请求应通过到handler
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查CORS响应头
|
||||||
|
if w.Header().Get("Access-Control-Allow-Origin") != "https://example.com" {
|
||||||
|
t.Errorf("expected Access-Control-Allow-Origin to be 'https://example.com', got '%s'", w.Header().Get("Access-Control-Allow-Origin"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCORSMiddleware_DisallowedOrigin(t *testing.T) {
|
||||||
|
config := DefaultCORSConfig()
|
||||||
|
config.AllowOrigins = []string{"https://allowed.com"}
|
||||||
|
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
corsHandler := CORSMiddleware(config)(handler)
|
||||||
|
|
||||||
|
// 模拟来自未允许域名的请求
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
req.Header.Set("Origin", "https://malicious.com")
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
corsHandler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
// 预检请求应返回403 Forbidden
|
||||||
|
if w.Code != http.StatusForbidden {
|
||||||
|
t.Errorf("expected status 403 for disallowed origin, got %d", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCORSMiddleware_WildcardOrigin(t *testing.T) {
|
||||||
|
config := DefaultCORSConfig()
|
||||||
|
config.AllowOrigins = []string{"*"} // 允许所有来源
|
||||||
|
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
corsHandler := CORSMiddleware(config)(handler)
|
||||||
|
|
||||||
|
// 模拟请求
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
req.Header.Set("Origin", "https://any-domain.com")
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
corsHandler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
// 应该允许
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCORSMiddleware_SubdomainWildcard(t *testing.T) {
|
||||||
|
config := DefaultCORSConfig()
|
||||||
|
config.AllowOrigins = []string{"*.example.com"}
|
||||||
|
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
corsHandler := CORSMiddleware(config)(handler)
|
||||||
|
|
||||||
|
// 测试子域名
|
||||||
|
tests := []struct {
|
||||||
|
origin string
|
||||||
|
shouldAllow bool
|
||||||
|
}{
|
||||||
|
{"https://app.example.com", true},
|
||||||
|
{"https://api.example.com", true},
|
||||||
|
{"https://example.com", true},
|
||||||
|
{"https://malicious.com", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
req.Header.Set("Origin", tt.origin)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
corsHandler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if tt.shouldAllow && w.Code != http.StatusOK {
|
||||||
|
t.Errorf("origin %s should be allowed, got status %d", tt.origin, w.Code)
|
||||||
|
}
|
||||||
|
if !tt.shouldAllow && w.Code != http.StatusForbidden {
|
||||||
|
t.Errorf("origin %s should be forbidden, got status %d", tt.origin, w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMED08_CORSConfigurationExists(t *testing.T) {
|
||||||
|
// MED-08: 验证CORS配置存在且可用
|
||||||
|
config := DefaultCORSConfig()
|
||||||
|
|
||||||
|
// 验证默认配置包含必要的设置
|
||||||
|
if len(config.AllowMethods) == 0 {
|
||||||
|
t.Error("default CORS config should have AllowMethods")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(config.AllowHeaders) == 0 {
|
||||||
|
t.Error("default CORS config should have AllowHeaders")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证CORS中间件函数存在
|
||||||
|
corsMiddleware := CORSMiddleware(config)
|
||||||
|
if corsMiddleware == nil {
|
||||||
|
t.Error("CORSMiddleware should return a valid middleware function")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -39,6 +39,7 @@ const (
|
|||||||
COMMON_RESOURCE_NOT_FOUND ErrorCode = "COMMON_002"
|
COMMON_RESOURCE_NOT_FOUND ErrorCode = "COMMON_002"
|
||||||
COMMON_INTERNAL_ERROR ErrorCode = "COMMON_003"
|
COMMON_INTERNAL_ERROR ErrorCode = "COMMON_003"
|
||||||
COMMON_SERVICE_UNAVAILABLE ErrorCode = "COMMON_004"
|
COMMON_SERVICE_UNAVAILABLE ErrorCode = "COMMON_004"
|
||||||
|
COMMON_REQUEST_TOO_LARGE ErrorCode = "COMMON_005"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrorInfo 错误信息
|
// ErrorInfo 错误信息
|
||||||
@@ -203,6 +204,12 @@ var ErrorDefinitions = map[ErrorCode]ErrorInfo{
|
|||||||
HTTPStatus: 503,
|
HTTPStatus: 503,
|
||||||
Retryable: true,
|
Retryable: true,
|
||||||
},
|
},
|
||||||
|
COMMON_REQUEST_TOO_LARGE: {
|
||||||
|
Code: COMMON_REQUEST_TOO_LARGE,
|
||||||
|
Message: "Request body too large",
|
||||||
|
HTTPStatus: 413,
|
||||||
|
Retryable: false,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewGatewayError 创建网关错误
|
// NewGatewayError 创建网关错误
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
@@ -38,6 +39,7 @@ type AuthMiddleware struct {
|
|||||||
tokenCache *TokenCache
|
tokenCache *TokenCache
|
||||||
tokenBackend TokenStatusBackend
|
tokenBackend TokenStatusBackend
|
||||||
auditEmitter AuditEmitter
|
auditEmitter AuditEmitter
|
||||||
|
bruteForce *BruteForceProtection // 暴力破解保护
|
||||||
}
|
}
|
||||||
|
|
||||||
// TokenStatusBackend Token状态后端查询接口
|
// TokenStatusBackend Token状态后端查询接口
|
||||||
@@ -75,6 +77,79 @@ func NewAuthMiddleware(config AuthConfig, tokenCache *TokenCache, tokenBackend T
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BruteForceProtection 暴力破解保护
|
||||||
|
// MED-12: 防止暴力破解攻击,限制登录尝试次数
|
||||||
|
type BruteForceProtection struct {
|
||||||
|
maxAttempts int
|
||||||
|
lockoutDuration time.Duration
|
||||||
|
attempts map[string]*attemptRecord
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
type attemptRecord struct {
|
||||||
|
count int
|
||||||
|
lockedUntil time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewBruteForceProtection 创建暴力破解保护
|
||||||
|
// maxAttempts: 最大失败尝试次数
|
||||||
|
// lockoutDuration: 锁定时长
|
||||||
|
func NewBruteForceProtection(maxAttempts int, lockoutDuration time.Duration) *BruteForceProtection {
|
||||||
|
return &BruteForceProtection{
|
||||||
|
maxAttempts: maxAttempts,
|
||||||
|
lockoutDuration: lockoutDuration,
|
||||||
|
attempts: make(map[string]*attemptRecord),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordFailedAttempt 记录失败尝试
|
||||||
|
func (b *BruteForceProtection) RecordFailedAttempt(ip string) {
|
||||||
|
b.mu.Lock()
|
||||||
|
defer b.mu.Unlock()
|
||||||
|
|
||||||
|
record, exists := b.attempts[ip]
|
||||||
|
if !exists {
|
||||||
|
record = &attemptRecord{}
|
||||||
|
b.attempts[ip] = record
|
||||||
|
}
|
||||||
|
|
||||||
|
record.count++
|
||||||
|
if record.count >= b.maxAttempts {
|
||||||
|
record.lockedUntil = time.Now().Add(b.lockoutDuration)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsLocked 检查IP是否被锁定
|
||||||
|
func (b *BruteForceProtection) IsLocked(ip string) (bool, time.Duration) {
|
||||||
|
b.mu.Lock()
|
||||||
|
defer b.mu.Unlock()
|
||||||
|
|
||||||
|
record, exists := b.attempts[ip]
|
||||||
|
if !exists {
|
||||||
|
return false, 0
|
||||||
|
}
|
||||||
|
|
||||||
|
if record.count >= b.maxAttempts && record.lockedUntil.After(time.Now()) {
|
||||||
|
remaining := time.Until(record.lockedUntil)
|
||||||
|
return true, remaining
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果锁定已过期,重置计数
|
||||||
|
if record.lockedUntil.Before(time.Now()) {
|
||||||
|
record.count = 0
|
||||||
|
record.lockedUntil = time.Time{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset 重置IP的尝试记录
|
||||||
|
func (b *BruteForceProtection) Reset(ip string) {
|
||||||
|
b.mu.Lock()
|
||||||
|
defer b.mu.Unlock()
|
||||||
|
delete(b.attempts, ip)
|
||||||
|
}
|
||||||
|
|
||||||
// QueryKeyRejectMiddleware 拒绝外部query key入站
|
// QueryKeyRejectMiddleware 拒绝外部query key入站
|
||||||
// 对应M-016指标
|
// 对应M-016指标
|
||||||
func (m *AuthMiddleware) QueryKeyRejectMiddleware(next http.Handler) http.Handler {
|
func (m *AuthMiddleware) QueryKeyRejectMiddleware(next http.Handler) http.Handler {
|
||||||
@@ -92,7 +167,7 @@ func (m *AuthMiddleware) QueryKeyRejectMiddleware(next http.Handler) http.Handle
|
|||||||
m.auditEmitter.Emit(r.Context(), AuditEvent{
|
m.auditEmitter.Emit(r.Context(), AuditEvent{
|
||||||
EventName: "token.query_key.rejected",
|
EventName: "token.query_key.rejected",
|
||||||
RequestID: getRequestID(r),
|
RequestID: getRequestID(r),
|
||||||
Route: r.URL.Path,
|
Route: sanitizeRoute(r.URL.Path),
|
||||||
ResultCode: "QUERY_KEY_NOT_ALLOWED",
|
ResultCode: "QUERY_KEY_NOT_ALLOWED",
|
||||||
ClientIP: getClientIP(r),
|
ClientIP: getClientIP(r),
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
@@ -115,7 +190,7 @@ func (m *AuthMiddleware) QueryKeyRejectMiddleware(next http.Handler) http.Handle
|
|||||||
m.auditEmitter.Emit(r.Context(), AuditEvent{
|
m.auditEmitter.Emit(r.Context(), AuditEvent{
|
||||||
EventName: "token.query_key.rejected",
|
EventName: "token.query_key.rejected",
|
||||||
RequestID: getRequestID(r),
|
RequestID: getRequestID(r),
|
||||||
Route: r.URL.Path,
|
Route: sanitizeRoute(r.URL.Path),
|
||||||
ResultCode: "QUERY_KEY_NOT_ALLOWED",
|
ResultCode: "QUERY_KEY_NOT_ALLOWED",
|
||||||
ClientIP: getClientIP(r),
|
ClientIP: getClientIP(r),
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
@@ -143,7 +218,7 @@ func (m *AuthMiddleware) BearerExtractMiddleware(next http.Handler) http.Handler
|
|||||||
m.auditEmitter.Emit(r.Context(), AuditEvent{
|
m.auditEmitter.Emit(r.Context(), AuditEvent{
|
||||||
EventName: "token.authn.fail",
|
EventName: "token.authn.fail",
|
||||||
RequestID: getRequestID(r),
|
RequestID: getRequestID(r),
|
||||||
Route: r.URL.Path,
|
Route: sanitizeRoute(r.URL.Path),
|
||||||
ResultCode: "AUTH_MISSING_BEARER",
|
ResultCode: "AUTH_MISSING_BEARER",
|
||||||
ClientIP: getClientIP(r),
|
ClientIP: getClientIP(r),
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
@@ -175,17 +250,33 @@ func (m *AuthMiddleware) BearerExtractMiddleware(next http.Handler) http.Handler
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TokenVerifyMiddleware 校验JWT Token
|
// TokenVerifyMiddleware 校验JWT Token
|
||||||
|
// MED-12: 添加暴力破解保护
|
||||||
func (m *AuthMiddleware) TokenVerifyMiddleware(next http.Handler) http.Handler {
|
func (m *AuthMiddleware) TokenVerifyMiddleware(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// MED-12: 检查暴力破解保护
|
||||||
|
if m.bruteForce != nil {
|
||||||
|
clientIP := getClientIP(r)
|
||||||
|
if locked, remaining := m.bruteForce.IsLocked(clientIP); locked {
|
||||||
|
writeAuthError(w, http.StatusTooManyRequests, "AUTH_ACCOUNT_LOCKED",
|
||||||
|
fmt.Sprintf("too many failed attempts, try again in %v", remaining))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
tokenString := r.Context().Value(bearerTokenKey).(string)
|
tokenString := r.Context().Value(bearerTokenKey).(string)
|
||||||
|
|
||||||
claims, err := m.verifyToken(tokenString)
|
claims, err := m.verifyToken(tokenString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// MED-12: 记录失败尝试
|
||||||
|
if m.bruteForce != nil {
|
||||||
|
m.bruteForce.RecordFailedAttempt(getClientIP(r))
|
||||||
|
}
|
||||||
|
|
||||||
if m.auditEmitter != nil {
|
if m.auditEmitter != nil {
|
||||||
m.auditEmitter.Emit(r.Context(), AuditEvent{
|
m.auditEmitter.Emit(r.Context(), AuditEvent{
|
||||||
EventName: "token.authn.fail",
|
EventName: "token.authn.fail",
|
||||||
RequestID: getRequestID(r),
|
RequestID: getRequestID(r),
|
||||||
Route: r.URL.Path,
|
Route: sanitizeRoute(r.URL.Path),
|
||||||
ResultCode: "AUTH_INVALID_TOKEN",
|
ResultCode: "AUTH_INVALID_TOKEN",
|
||||||
ClientIP: getClientIP(r),
|
ClientIP: getClientIP(r),
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
@@ -206,7 +297,7 @@ func (m *AuthMiddleware) TokenVerifyMiddleware(next http.Handler) http.Handler {
|
|||||||
RequestID: getRequestID(r),
|
RequestID: getRequestID(r),
|
||||||
TokenID: claims.ID,
|
TokenID: claims.ID,
|
||||||
SubjectID: claims.SubjectID,
|
SubjectID: claims.SubjectID,
|
||||||
Route: r.URL.Path,
|
Route: sanitizeRoute(r.URL.Path),
|
||||||
ResultCode: "AUTH_TOKEN_INACTIVE",
|
ResultCode: "AUTH_TOKEN_INACTIVE",
|
||||||
ClientIP: getClientIP(r),
|
ClientIP: getClientIP(r),
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
@@ -229,7 +320,7 @@ func (m *AuthMiddleware) TokenVerifyMiddleware(next http.Handler) http.Handler {
|
|||||||
RequestID: getRequestID(r),
|
RequestID: getRequestID(r),
|
||||||
TokenID: claims.ID,
|
TokenID: claims.ID,
|
||||||
SubjectID: claims.SubjectID,
|
SubjectID: claims.SubjectID,
|
||||||
Route: r.URL.Path,
|
Route: sanitizeRoute(r.URL.Path),
|
||||||
ResultCode: "OK",
|
ResultCode: "OK",
|
||||||
ClientIP: getClientIP(r),
|
ClientIP: getClientIP(r),
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
@@ -259,7 +350,7 @@ func (m *AuthMiddleware) ScopeRoleAuthzMiddleware(requiredScope string) func(htt
|
|||||||
RequestID: getRequestID(r),
|
RequestID: getRequestID(r),
|
||||||
TokenID: claims.ID,
|
TokenID: claims.ID,
|
||||||
SubjectID: claims.SubjectID,
|
SubjectID: claims.SubjectID,
|
||||||
Route: r.URL.Path,
|
Route: sanitizeRoute(r.URL.Path),
|
||||||
ResultCode: "AUTH_SCOPE_DENIED",
|
ResultCode: "AUTH_SCOPE_DENIED",
|
||||||
ClientIP: getClientIP(r),
|
ClientIP: getClientIP(r),
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
@@ -413,6 +504,42 @@ func getClientIP(r *http.Request) string {
|
|||||||
return addr
|
return addr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sanitizeRoute 清理路由字符串,防止路径遍历和其他安全问题
|
||||||
|
// MED-04: 审计日志Route字段需要验证以防止路径遍历攻击
|
||||||
|
func sanitizeRoute(route string) string {
|
||||||
|
if route == "" {
|
||||||
|
return route
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否包含路径遍历模式
|
||||||
|
// 路径遍历通常包含 .. 或 . 后面跟着 / 或 \
|
||||||
|
for i := 0; i < len(route)-1; i++ {
|
||||||
|
if route[i] == '.' {
|
||||||
|
next := route[i+1]
|
||||||
|
if next == '.' || next == '/' || next == '\\' {
|
||||||
|
// 检测到路径遍历模式,返回安全的替代值
|
||||||
|
return "/sanitized"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 检查反斜杠(Windows路径遍历)
|
||||||
|
if route[i] == '\\' {
|
||||||
|
return "/sanitized"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查null字节
|
||||||
|
if strings.Contains(route, "\x00") {
|
||||||
|
return "/sanitized"
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查换行符
|
||||||
|
if strings.Contains(route, "\n") || strings.Contains(route, "\r") {
|
||||||
|
return "/sanitized"
|
||||||
|
}
|
||||||
|
|
||||||
|
return route
|
||||||
|
}
|
||||||
|
|
||||||
// containsScope 检查scope列表是否包含目标scope
|
// containsScope 检查scope列表是否包含目标scope
|
||||||
func containsScope(scopes []string, target string) bool {
|
func containsScope(scopes []string, target string) bool {
|
||||||
for _, scope := range scopes {
|
for _, scope := range scopes {
|
||||||
|
|||||||
32
supply-api/internal/middleware/auth_route_test.go
Normal file
32
supply-api/internal/middleware/auth_route_test.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSanitizeRoute(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"/api/v1/test", "/api/v1/test"},
|
||||||
|
{"/", "/"},
|
||||||
|
{"", ""},
|
||||||
|
{"/api/../../../etc/passwd", "/sanitized"},
|
||||||
|
{"../../etc/passwd", "/sanitized"},
|
||||||
|
{"/api/v1/../admin", "/sanitized"},
|
||||||
|
{"/api\\v1\\admin", "/sanitized"},
|
||||||
|
{"/api/v1" + string(rune(0)) + "/admin", "/sanitized"},
|
||||||
|
{"/api/v1\n/admin", "/sanitized"},
|
||||||
|
{"/api/v1\r/admin", "/sanitized"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.input, func(t *testing.T) {
|
||||||
|
result := sanitizeRoute(tt.input)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("sanitizeRoute(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
221
supply-api/internal/middleware/auth_security_test.go
Normal file
221
supply-api/internal/middleware/auth_security_test.go
Normal file
@@ -0,0 +1,221 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestMED09_ErrorMessageShouldNotLeakInternalDetails verifies that internal error details
|
||||||
|
// are not exposed to clients
|
||||||
|
func TestMED09_ErrorMessageShouldNotLeakInternalDetails(t *testing.T) {
|
||||||
|
secretKey := "test-secret-key-12345678901234567890"
|
||||||
|
issuer := "test-issuer"
|
||||||
|
|
||||||
|
// Create middleware with a token that will cause an error
|
||||||
|
middleware := &AuthMiddleware{
|
||||||
|
config: AuthConfig{
|
||||||
|
SecretKey: secretKey,
|
||||||
|
Issuer: issuer,
|
||||||
|
},
|
||||||
|
tokenCache: NewTokenCache(),
|
||||||
|
// Intentionally no tokenBackend - to simulate error scenario
|
||||||
|
}
|
||||||
|
|
||||||
|
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Next handler should not be called for auth failures
|
||||||
|
})
|
||||||
|
|
||||||
|
handler := middleware.TokenVerifyMiddleware(nextHandler)
|
||||||
|
|
||||||
|
// Create a token that will fail verification
|
||||||
|
// Using wrong signing key to simulate internal error
|
||||||
|
claims := TokenClaims{
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
Issuer: issuer,
|
||||||
|
Subject: "subject:1",
|
||||||
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
||||||
|
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||||
|
},
|
||||||
|
SubjectID: "subject:1",
|
||||||
|
Role: "owner",
|
||||||
|
Scope: []string{"read", "write"},
|
||||||
|
TenantID: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sign with wrong key to cause error
|
||||||
|
wrongKeyToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||||
|
wrongKeyTokenString, _ := wrongKeyToken.SignedString([]byte("wrong-secret-key-that-will-cause-error"))
|
||||||
|
|
||||||
|
// Create request with Bearer token
|
||||||
|
req := httptest.NewRequest("POST", "/api/v1/test", nil)
|
||||||
|
ctx := context.WithValue(req.Context(), bearerTokenKey, wrongKeyTokenString)
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
// Should return 401
|
||||||
|
if w.Code != http.StatusUnauthorized {
|
||||||
|
t.Errorf("expected status 401, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse response
|
||||||
|
var resp map[string]interface{}
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("failed to parse response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check error map
|
||||||
|
errorMap, ok := resp["error"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("response should contain error object")
|
||||||
|
}
|
||||||
|
|
||||||
|
message, ok := errorMap["message"].(string)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("error should contain message")
|
||||||
|
}
|
||||||
|
|
||||||
|
// The error message should NOT contain internal details like:
|
||||||
|
// - "crypto" or "cipher" related terms (implementation details)
|
||||||
|
// - "secret", "key", "password" (credential info)
|
||||||
|
// - "SQL", "database", "connection" (database details)
|
||||||
|
// - File paths or line numbers
|
||||||
|
|
||||||
|
internalKeywords := []string{
|
||||||
|
"crypto/",
|
||||||
|
"/go/src/",
|
||||||
|
".go:",
|
||||||
|
"sql",
|
||||||
|
"database",
|
||||||
|
"connection",
|
||||||
|
"pq",
|
||||||
|
"pgx",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, keyword := range internalKeywords {
|
||||||
|
if strings.Contains(strings.ToLower(message), keyword) {
|
||||||
|
t.Errorf("MED-09: error message should NOT contain internal details like '%s'. Got: %s", keyword, message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// The message should be a generic user-safe message
|
||||||
|
if message == "" {
|
||||||
|
t.Error("error message should not be empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMED09_TokenVerifyErrorShouldBeSanitized tests that token verification errors
|
||||||
|
// don't leak sensitive information
|
||||||
|
func TestMED09_TokenVerifyErrorShouldBeSanitized(t *testing.T) {
|
||||||
|
secretKey := "test-secret-key-12345678901234567890"
|
||||||
|
issuer := "test-issuer"
|
||||||
|
|
||||||
|
// Create middleware
|
||||||
|
m := &AuthMiddleware{
|
||||||
|
config: AuthConfig{
|
||||||
|
SecretKey: secretKey,
|
||||||
|
Issuer: issuer,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with various invalid tokens
|
||||||
|
invalidTokens := []struct {
|
||||||
|
name string
|
||||||
|
token string
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "completely invalid token",
|
||||||
|
token: "not.a.valid.token.at.all",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "expired token",
|
||||||
|
token: createExpiredTestToken(secretKey, issuer),
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wrong issuer token",
|
||||||
|
token: createWrongIssuerTestToken(secretKey, issuer),
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range invalidTokens {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
_, err := m.verifyToken(tt.token)
|
||||||
|
|
||||||
|
if tt.expectError && err == nil {
|
||||||
|
t.Error("expected error but got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
errMsg := err.Error()
|
||||||
|
|
||||||
|
// Internal error messages should be sanitized
|
||||||
|
// They should NOT contain sensitive keywords
|
||||||
|
sensitiveKeywords := []string{
|
||||||
|
"secret",
|
||||||
|
"password",
|
||||||
|
"credential",
|
||||||
|
"/",
|
||||||
|
".go:",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, keyword := range sensitiveKeywords {
|
||||||
|
if strings.Contains(strings.ToLower(errMsg), keyword) {
|
||||||
|
t.Errorf("MED-09: internal error should NOT contain '%s'. Got: %s", keyword, errMsg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to create expired token
|
||||||
|
func createExpiredTestToken(secretKey, issuer string) string {
|
||||||
|
claims := TokenClaims{
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
Issuer: issuer,
|
||||||
|
Subject: "subject:1",
|
||||||
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(-time.Hour)), // Expired
|
||||||
|
IssuedAt: jwt.NewNumericDate(time.Now().Add(-2 * time.Hour)),
|
||||||
|
},
|
||||||
|
SubjectID: "subject:1",
|
||||||
|
Role: "owner",
|
||||||
|
Scope: []string{"read", "write"},
|
||||||
|
TenantID: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||||
|
tokenString, _ := token.SignedString([]byte(secretKey))
|
||||||
|
return tokenString
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to create wrong issuer token
|
||||||
|
func createWrongIssuerTestToken(secretKey, issuer string) string {
|
||||||
|
claims := TokenClaims{
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
Issuer: "wrong-issuer",
|
||||||
|
Subject: "subject:1",
|
||||||
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
||||||
|
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||||
|
},
|
||||||
|
SubjectID: "subject:1",
|
||||||
|
Role: "owner",
|
||||||
|
Scope: []string{"read", "write"},
|
||||||
|
TenantID: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||||
|
tokenString, _ := token.SignedString([]byte(secretKey))
|
||||||
|
return tokenString
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user