Files
lijiaoqiao/supply-api/internal/middleware/auth.go
Your Name d5b5a8ece0 fix: 系统性修复安全问题、性能问题和错误处理
安全问题修复:
- X-Forwarded-For越界检查(auth.go)
- checkTokenStatus Context参数传递(auth.go)
- Type Assertion安全检查(auth.go)

性能问题修复:
- TokenCache过期清理机制
- BruteForceProtection过期清理
- InMemoryIdempotencyStore过期清理

错误处理修复:
- AuditStore.Emit返回error
- domain层emitAudit辅助方法
- List方法返回空slice而非nil
- 金额/价格负数验证

架构一致性:
- 统一使用model.RoleHierarchyLevels

新增功能:
- Alert API完整实现(CRUD+Resolve)
- pkg/error错误码集中管理
2026-04-07 07:41:25 +08:00

711 lines
19 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package middleware
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"log"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/golang-jwt/jwt/v5"
"lijiaoqiao/supply-api/internal/iam/model"
)
// TokenClaims JWT token claims
type TokenClaims struct {
jwt.RegisteredClaims
SubjectID string `json:"subject_id"`
Role string `json:"role"`
Scope []string `json:"scope"`
TenantID int64 `json:"tenant_id"`
}
// AuthConfig 鉴权中间件配置
type AuthConfig struct {
SecretKey string
Issuer string
CacheTTL time.Duration // token状态缓存TTL
Enabled bool // 是否启用鉴权
}
// AuthMiddleware 鉴权中间件
type AuthMiddleware struct {
config AuthConfig
tokenCache *TokenCache
tokenBackend TokenStatusBackend
auditEmitter AuditEmitter
bruteForce *BruteForceProtection // 暴力破解保护
}
// TokenStatusBackend Token状态后端查询接口
type TokenStatusBackend interface {
CheckTokenStatus(ctx context.Context, tokenID string) (string, error)
}
// AuditEmitter 审计事件发射器
type AuditEmitter interface {
Emit(ctx context.Context, event AuditEvent) error
}
// AuditEvent 审计事件
type AuditEvent struct {
EventName string
RequestID string
TokenID string
SubjectID string
Route string
ResultCode string
ClientIP string
CreatedAt time.Time
}
// NewAuthMiddleware 创建鉴权中间件
func NewAuthMiddleware(config AuthConfig, tokenCache *TokenCache, tokenBackend TokenStatusBackend, auditEmitter AuditEmitter) *AuthMiddleware {
if config.CacheTTL == 0 {
config.CacheTTL = 30 * time.Second
}
return &AuthMiddleware{
config: config,
tokenCache: tokenCache,
tokenBackend: tokenBackend,
auditEmitter: auditEmitter,
}
}
// BruteForceProtection 暴力破解保护
// MED-12: 防止暴力破解攻击,限制登录尝试次数
type BruteForceProtection struct {
maxAttempts int
lockoutDuration time.Duration
attempts map[string]*attemptRecord
mu sync.Mutex
cleanupCounter int64 // 清理触发计数器
}
type attemptRecord struct {
count int
lockedUntil time.Time
lastAttempt time.Time // 最后尝试时间,用于过期清理
}
// NewBruteForceProtection 创建暴力破解保护
// maxAttempts: 最大失败尝试次数
// lockoutDuration: 锁定时长
func NewBruteForceProtection(maxAttempts int, lockoutDuration time.Duration) *BruteForceProtection {
return &BruteForceProtection{
maxAttempts: maxAttempts,
lockoutDuration: lockoutDuration,
attempts: make(map[string]*attemptRecord),
}
}
// RecordFailedAttempt 记录失败尝试
func (b *BruteForceProtection) RecordFailedAttempt(ip string) {
b.mu.Lock()
defer b.mu.Unlock()
record, exists := b.attempts[ip]
if !exists {
record = &attemptRecord{}
b.attempts[ip] = record
}
record.count++
record.lastAttempt = time.Now()
if record.count >= b.maxAttempts {
record.lockedUntil = time.Now().Add(b.lockoutDuration)
}
b.triggerCleanup()
}
// IsLocked 检查IP是否被锁定
func (b *BruteForceProtection) IsLocked(ip string) (bool, time.Duration) {
b.mu.Lock()
defer b.mu.Unlock()
record, exists := b.attempts[ip]
if !exists {
return false, 0
}
if record.count >= b.maxAttempts && record.lockedUntil.After(time.Now()) {
remaining := time.Until(record.lockedUntil)
return true, remaining
}
// 如果锁定已过期,重置计数
if record.lockedUntil.Before(time.Now()) {
record.count = 0
record.lockedUntil = time.Time{}
}
return false, 0
}
// Reset 重置IP的尝试记录
func (b *BruteForceProtection) Reset(ip string) {
b.mu.Lock()
defer b.mu.Unlock()
delete(b.attempts, ip)
}
// triggerCleanup 触发清理每100次操作清理一次过期记录
func (b *BruteForceProtection) triggerCleanup() {
b.cleanupCounter++
if b.cleanupCounter >= 100 {
b.cleanupCounter = 0
b.cleanupExpiredLocked()
}
}
// cleanupExpiredLocked 清理过期记录(需要持有锁)
// 清理条件锁定已过期且最后尝试时间超过lockoutDuration
func (b *BruteForceProtection) cleanupExpiredLocked() {
now := time.Now()
threshold := now.Add(-b.lockoutDuration * 2) // 超过两倍锁定时长未活动的记录清理
for ip, record := range b.attempts {
// 清理:锁定已过期且长时间无活动
if record.lockedUntil.Before(now) && record.lastAttempt.Before(threshold) {
delete(b.attempts, ip)
}
}
}
// CleanExpired 主动清理过期记录(可由外部定期调用)
func (b *BruteForceProtection) CleanExpired() {
b.mu.Lock()
defer b.mu.Unlock()
b.cleanupExpiredLocked()
}
// Len 返回当前记录数量(用于监控)
func (b *BruteForceProtection) Len() int {
b.mu.Lock()
defer b.mu.Unlock()
return len(b.attempts)
}
// QueryKeyRejectMiddleware 拒绝外部query key入站
// 对应M-016指标
func (m *AuthMiddleware) QueryKeyRejectMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 检查query string中的可疑参数
queryParams := r.URL.Query()
// 禁止的query参数名
blockedParams := []string{"key", "api_key", "token", "secret", "password", "credential"}
for _, param := range blockedParams {
if _, exists := queryParams[param]; exists {
// 触发M-016指标事件
if m.auditEmitter != nil {
m.auditEmitter.Emit(r.Context(), AuditEvent{
EventName: "token.query_key.rejected",
RequestID: getRequestID(r),
Route: sanitizeRoute(r.URL.Path),
ResultCode: "QUERY_KEY_NOT_ALLOWED",
ClientIP: getClientIP(r),
CreatedAt: time.Now(),
})
}
writeAuthError(w, http.StatusUnauthorized, "QUERY_KEY_NOT_ALLOWED",
"external query key is not allowed, use Authorization header")
return
}
}
// 检查是否有API Key在query中即使参数名不同
for param := range queryParams {
lowerParam := strings.ToLower(param)
if strings.Contains(lowerParam, "key") || strings.Contains(lowerParam, "token") || strings.Contains(lowerParam, "secret") {
// 可能是编码的API Key
if len(queryParams.Get(param)) > 20 {
if m.auditEmitter != nil {
m.auditEmitter.Emit(r.Context(), AuditEvent{
EventName: "token.query_key.rejected",
RequestID: getRequestID(r),
Route: sanitizeRoute(r.URL.Path),
ResultCode: "QUERY_KEY_NOT_ALLOWED",
ClientIP: getClientIP(r),
CreatedAt: time.Now(),
})
}
writeAuthError(w, http.StatusUnauthorized, "QUERY_KEY_NOT_ALLOWED",
"suspicious query parameter detected")
return
}
}
}
next.ServeHTTP(w, r)
})
}
// BearerExtractMiddleware 提取Bearer Token
func (m *AuthMiddleware) BearerExtractMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
if m.auditEmitter != nil {
m.auditEmitter.Emit(r.Context(), AuditEvent{
EventName: "token.authn.fail",
RequestID: getRequestID(r),
Route: sanitizeRoute(r.URL.Path),
ResultCode: "AUTH_MISSING_BEARER",
ClientIP: getClientIP(r),
CreatedAt: time.Now(),
})
}
writeAuthError(w, http.StatusUnauthorized, "AUTH_MISSING_BEARER",
"Authorization header with Bearer token is required")
return
}
if !strings.HasPrefix(authHeader, "Bearer ") {
writeAuthError(w, http.StatusUnauthorized, "AUTH_INVALID_FORMAT",
"Authorization header must be in format: Bearer <token>")
return
}
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
if tokenString == "" {
writeAuthError(w, http.StatusUnauthorized, "AUTH_MISSING_BEARER",
"Bearer token is empty")
return
}
// 将token存入context供后续使用
ctx := context.WithValue(r.Context(), bearerTokenKey, tokenString)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// TokenVerifyMiddleware 校验JWT Token
// MED-12: 添加暴力破解保护
func (m *AuthMiddleware) TokenVerifyMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// MED-12: 检查暴力破解保护
if m.bruteForce != nil {
clientIP := getClientIP(r)
if locked, remaining := m.bruteForce.IsLocked(clientIP); locked {
writeAuthError(w, http.StatusTooManyRequests, "AUTH_ACCOUNT_LOCKED",
fmt.Sprintf("too many failed attempts, try again in %v", remaining))
return
}
}
// 安全检查确保BearerExtractMiddleware已执行
tokenValue := r.Context().Value(bearerTokenKey)
if tokenValue == nil {
writeAuthError(w, http.StatusUnauthorized, "AUTH_TOKEN_MISSING",
"bearer token is missing")
return
}
tokenString, ok := tokenValue.(string)
if !ok || tokenString == "" {
writeAuthError(w, http.StatusUnauthorized, "AUTH_TOKEN_INVALID",
"bearer token is invalid")
return
}
claims, err := m.verifyToken(tokenString)
if err != nil {
// MED-12: 记录失败尝试
if m.bruteForce != nil {
m.bruteForce.RecordFailedAttempt(getClientIP(r))
}
if m.auditEmitter != nil {
m.auditEmitter.Emit(r.Context(), AuditEvent{
EventName: "token.authn.fail",
RequestID: getRequestID(r),
Route: sanitizeRoute(r.URL.Path),
ResultCode: "AUTH_INVALID_TOKEN",
ClientIP: getClientIP(r),
CreatedAt: time.Now(),
})
}
writeAuthError(w, http.StatusUnauthorized, "AUTH_INVALID_TOKEN",
"token verification failed: "+err.Error())
return
}
// 检查token状态是否被吊销
status, err := m.checkTokenStatus(r.Context(), claims.ID)
if err == nil && status != "active" {
if m.auditEmitter != nil {
m.auditEmitter.Emit(r.Context(), AuditEvent{
EventName: "token.authn.fail",
RequestID: getRequestID(r),
TokenID: claims.ID,
SubjectID: claims.SubjectID,
Route: sanitizeRoute(r.URL.Path),
ResultCode: "AUTH_TOKEN_INACTIVE",
ClientIP: getClientIP(r),
CreatedAt: time.Now(),
})
}
writeAuthError(w, http.StatusUnauthorized, "AUTH_TOKEN_INACTIVE",
"token is revoked or expired")
return
}
// 将claims存入context
ctx := context.WithValue(r.Context(), tokenClaimsKey, claims)
ctx = WithTenantID(ctx, claims.TenantID)
ctx = WithOperatorID(ctx, parseSubjectID(claims.SubjectID))
if m.auditEmitter != nil {
m.auditEmitter.Emit(r.Context(), AuditEvent{
EventName: "token.authn.success",
RequestID: getRequestID(r),
TokenID: claims.ID,
SubjectID: claims.SubjectID,
Route: sanitizeRoute(r.URL.Path),
ResultCode: "OK",
ClientIP: getClientIP(r),
CreatedAt: time.Now(),
})
}
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// ScopeRoleAuthzMiddleware 权限校验中间件
func (m *AuthMiddleware) ScopeRoleAuthzMiddleware(requiredScope string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
claims, ok := r.Context().Value(tokenClaimsKey).(*TokenClaims)
if !ok {
writeAuthError(w, http.StatusUnauthorized, "AUTH_CONTEXT_MISSING",
"authentication context is missing")
return
}
// 检查scope
if requiredScope != "" && !containsScope(claims.Scope, requiredScope) {
if m.auditEmitter != nil {
m.auditEmitter.Emit(r.Context(), AuditEvent{
EventName: "token.authz.denied",
RequestID: getRequestID(r),
TokenID: claims.ID,
SubjectID: claims.SubjectID,
Route: sanitizeRoute(r.URL.Path),
ResultCode: "AUTH_SCOPE_DENIED",
ClientIP: getClientIP(r),
CreatedAt: time.Now(),
})
}
writeAuthError(w, http.StatusForbidden, "AUTH_SCOPE_DENIED",
fmt.Sprintf("required scope '%s' is not granted", requiredScope))
return
}
// 检查role权限
// 使用model.GetRoleLevelByCode获取统一角色层级定义
// 路由权限要求(使用详细角色代码)
// viewer: level 10, operator: level 30, org_admin: level 50
routeRoles := map[string]string{
"/api/v1/supply/accounts": "org_admin",
"/api/v1/supply/packages": "org_admin",
"/api/v1/supply/settlements": "org_admin",
"/api/v1/supply/billing": "viewer",
"/api/v1/supplier/billing": "viewer",
}
for path, requiredRole := range routeRoles {
if strings.HasPrefix(r.URL.Path, path) {
if model.GetRoleLevelByCode(claims.Role) < model.GetRoleLevelByCode(requiredRole) {
writeAuthError(w, http.StatusForbidden, "AUTH_ROLE_DENIED",
fmt.Sprintf("required role '%s' is not granted, current role: '%s'", requiredRole, claims.Role))
return
}
}
}
next.ServeHTTP(w, r)
})
}
}
// verifyToken 校验JWT token
func (m *AuthMiddleware) verifyToken(tokenString string) (*TokenClaims, error) {
token, err := jwt.ParseWithClaims(tokenString, &TokenClaims{}, func(token *jwt.Token) (interface{}, error) {
// 严格验证算法只接受HS256
if token.Method.Alg() != jwt.SigningMethodHS256.Alg() {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return []byte(m.config.SecretKey), nil
})
if err != nil {
return nil, err
}
if claims, ok := token.Claims.(*TokenClaims); ok && token.Valid {
// 验证issuer
if claims.Issuer != m.config.Issuer {
return nil, errors.New("invalid token issuer")
}
// 验证expiration
if claims.ExpiresAt != nil && claims.ExpiresAt.Time.Before(time.Now()) {
return nil, errors.New("token has expired")
}
// 验证not before
if claims.NotBefore != nil && claims.NotBefore.Time.After(time.Now()) {
return nil, errors.New("token is not yet valid")
}
return claims, nil
}
return nil, errors.New("invalid token")
}
// checkTokenStatus 检查token状态从缓存或数据库
func (m *AuthMiddleware) checkTokenStatus(ctx context.Context, tokenID string) (string, error) {
if m.tokenCache != nil {
// 先从缓存检查
if status, found := m.tokenCache.Get(tokenID); found {
return status, nil
}
}
// 缓存未命中查询后端验证token状态
if m.tokenBackend != nil {
return m.tokenBackend.CheckTokenStatus(ctx, tokenID)
}
// 没有后端实现时应该拒绝访问而不是默认active
return "", errors.New("token status unknown: backend not configured")
}
// GetTokenClaims 从context获取token claims
func GetTokenClaims(ctx context.Context) *TokenClaims {
if claims, ok := ctx.Value(tokenClaimsKey).(*TokenClaims); ok {
return claims
}
return nil
}
// context keys
const (
bearerTokenKey contextKey = "bearer_token"
tokenClaimsKey contextKey = "token_claims"
)
// writeAuthError 写入鉴权错误
func writeAuthError(w http.ResponseWriter, status int, code, message string) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
resp := map[string]interface{}{
"request_id": "",
"error": map[string]string{
"code": code,
"message": message,
},
}
if err := json.NewEncoder(w).Encode(resp); err != nil {
// 记录编码错误(响应已经开始发送,无法回退)
log.Printf("[AUTH_ERROR] failed to encode error response: %v, code=%s", err, code)
}
}
// getRequestID 获取请求ID
func getRequestID(r *http.Request) string {
if id := r.Header.Get("X-Request-Id"); id != "" {
return id
}
return r.Header.Get("X-Request-ID")
}
// getClientIP 获取客户端IP
func getClientIP(r *http.Request) string {
// 优先从X-Forwarded-For获取
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
parts := strings.Split(xff, ",")
// 安全检查:空字符串已在上层判断,但防御性编程
if len(parts) > 0 {
return strings.TrimSpace(parts[0])
}
}
// X-Real-IP
if xri := r.Header.Get("X-Real-IP"); xri != "" {
return xri
}
// RemoteAddr
addr := r.RemoteAddr
if idx := strings.LastIndex(addr, ":"); idx != -1 {
return addr[:idx]
}
return addr
}
// sanitizeRoute 清理路由字符串,防止路径遍历和其他安全问题
// MED-04: 审计日志Route字段需要验证以防止路径遍历攻击
func sanitizeRoute(route string) string {
if route == "" {
return route
}
// 检查是否包含路径遍历模式
// 路径遍历通常包含 .. 或 . 后面跟着 / 或 \
for i := 0; i < len(route)-1; i++ {
if route[i] == '.' {
next := route[i+1]
if next == '.' || next == '/' || next == '\\' {
// 检测到路径遍历模式,返回安全的替代值
return "/sanitized"
}
}
// 检查反斜杠Windows路径遍历
if route[i] == '\\' {
return "/sanitized"
}
}
// 检查null字节
if strings.Contains(route, "\x00") {
return "/sanitized"
}
// 检查换行符
if strings.Contains(route, "\n") || strings.Contains(route, "\r") {
return "/sanitized"
}
return route
}
// containsScope 检查scope列表是否包含目标scope
func containsScope(scopes []string, target string) bool {
for _, scope := range scopes {
if scope == target || scope == "*" {
return true
}
}
return false
}
// parseSubjectID 解析subject ID
func parseSubjectID(subject string) int64 {
parts := strings.Split(subject, ":")
if len(parts) >= 2 {
id, _ := strconv.ParseInt(parts[1], 10, 64)
return id
}
return 0
}
// TokenCache Token状态缓存
type TokenCache struct {
data map[string]cacheEntry
mu sync.RWMutex
cleanup int64 // 清理触发计数器
}
type cacheEntry struct {
status string
expires time.Time
}
// NewTokenCache 创建token缓存
func NewTokenCache() *TokenCache {
return &TokenCache{
data: make(map[string]cacheEntry),
cleanup: 0,
}
}
// Get 获取token状态
func (c *TokenCache) Get(tokenID string) (string, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
if entry, ok := c.data[tokenID]; ok {
if time.Now().Before(entry.expires) {
return entry.status, true
}
}
return "", false
}
// Set 设置token状态
func (c *TokenCache) Set(tokenID, status string, ttl time.Duration) {
c.mu.Lock()
defer c.mu.Unlock()
c.data[tokenID] = cacheEntry{
status: status,
expires: time.Now().Add(ttl),
}
c.triggerCleanup()
}
// Invalidate 使token失效
func (c *TokenCache) Invalidate(tokenID string) {
c.mu.Lock()
defer c.mu.Unlock()
delete(c.data, tokenID)
}
// triggerCleanup 触发清理每100次操作清理一次过期条目
func (c *TokenCache) triggerCleanup() {
c.cleanup++
if c.cleanup >= 100 {
c.cleanup = 0
c.cleanupExpiredLocked()
}
}
// cleanupExpiredLocked 清理过期条目(需要持有锁)
func (c *TokenCache) cleanupExpiredLocked() {
now := time.Now()
for tokenID, entry := range c.data {
if now.After(entry.expires) {
delete(c.data, tokenID)
}
}
}
// CleanExpired 主动清理过期条目(可由外部定期调用)
func (c *TokenCache) CleanExpired() {
c.mu.Lock()
defer c.mu.Unlock()
c.cleanupExpiredLocked()
}
// Len 返回缓存条目数量(用于监控)
func (c *TokenCache) Len() int {
c.mu.RLock()
defer c.mu.RUnlock()
return len(c.data)
}
// ComputeFingerprint 计算凭证指纹(用于审计)
func ComputeFingerprint(credential string) string {
hash := sha256.Sum256([]byte(credential))
return hex.EncodeToString(hash[:])
}