fix: 系统性修复安全问题、性能问题和错误处理

安全问题修复:
- X-Forwarded-For越界检查(auth.go)
- checkTokenStatus Context参数传递(auth.go)
- Type Assertion安全检查(auth.go)

性能问题修复:
- TokenCache过期清理机制
- BruteForceProtection过期清理
- InMemoryIdempotencyStore过期清理

错误处理修复:
- AuditStore.Emit返回error
- domain层emitAudit辅助方法
- List方法返回空slice而非nil
- 金额/价格负数验证

架构一致性:
- 统一使用model.RoleHierarchyLevels

新增功能:
- Alert API完整实现(CRUD+Resolve)
- pkg/error错误码集中管理
This commit is contained in:
Your Name
2026-04-07 07:41:25 +08:00
parent 12ce4913cd
commit d5b5a8ece0
21 changed files with 2321 additions and 83 deletions

View File

@@ -7,6 +7,7 @@ import (
"encoding/json"
"errors"
"fmt"
"log"
"net/http"
"strconv"
"strings"
@@ -14,6 +15,8 @@ import (
"time"
"github.com/golang-jwt/jwt/v5"
"lijiaoqiao/supply-api/internal/iam/model"
)
// TokenClaims JWT token claims
@@ -84,11 +87,13 @@ type BruteForceProtection struct {
lockoutDuration time.Duration
attempts map[string]*attemptRecord
mu sync.Mutex
cleanupCounter int64 // 清理触发计数器
}
type attemptRecord struct {
count int
lockedUntil time.Time
lastAttempt time.Time // 最后尝试时间,用于过期清理
}
// NewBruteForceProtection 创建暴力破解保护
@@ -114,9 +119,11 @@ func (b *BruteForceProtection) RecordFailedAttempt(ip string) {
}
record.count++
record.lastAttempt = time.Now()
if record.count >= b.maxAttempts {
record.lockedUntil = time.Now().Add(b.lockoutDuration)
}
b.triggerCleanup()
}
// IsLocked 检查IP是否被锁定
@@ -150,6 +157,42 @@ func (b *BruteForceProtection) Reset(ip string) {
delete(b.attempts, ip)
}
// triggerCleanup 触发清理每100次操作清理一次过期记录
func (b *BruteForceProtection) triggerCleanup() {
b.cleanupCounter++
if b.cleanupCounter >= 100 {
b.cleanupCounter = 0
b.cleanupExpiredLocked()
}
}
// cleanupExpiredLocked 清理过期记录(需要持有锁)
// 清理条件锁定已过期且最后尝试时间超过lockoutDuration
func (b *BruteForceProtection) cleanupExpiredLocked() {
now := time.Now()
threshold := now.Add(-b.lockoutDuration * 2) // 超过两倍锁定时长未活动的记录清理
for ip, record := range b.attempts {
// 清理:锁定已过期且长时间无活动
if record.lockedUntil.Before(now) && record.lastAttempt.Before(threshold) {
delete(b.attempts, ip)
}
}
}
// CleanExpired 主动清理过期记录(可由外部定期调用)
func (b *BruteForceProtection) CleanExpired() {
b.mu.Lock()
defer b.mu.Unlock()
b.cleanupExpiredLocked()
}
// Len 返回当前记录数量(用于监控)
func (b *BruteForceProtection) Len() int {
b.mu.Lock()
defer b.mu.Unlock()
return len(b.attempts)
}
// QueryKeyRejectMiddleware 拒绝外部query key入站
// 对应M-016指标
func (m *AuthMiddleware) QueryKeyRejectMiddleware(next http.Handler) http.Handler {
@@ -263,7 +306,19 @@ func (m *AuthMiddleware) TokenVerifyMiddleware(next http.Handler) http.Handler {
}
}
tokenString := r.Context().Value(bearerTokenKey).(string)
// 安全检查确保BearerExtractMiddleware已执行
tokenValue := r.Context().Value(bearerTokenKey)
if tokenValue == nil {
writeAuthError(w, http.StatusUnauthorized, "AUTH_TOKEN_MISSING",
"bearer token is missing")
return
}
tokenString, ok := tokenValue.(string)
if !ok || tokenString == "" {
writeAuthError(w, http.StatusUnauthorized, "AUTH_TOKEN_INVALID",
"bearer token is invalid")
return
}
claims, err := m.verifyToken(tokenString)
if err != nil {
@@ -289,7 +344,7 @@ func (m *AuthMiddleware) TokenVerifyMiddleware(next http.Handler) http.Handler {
}
// 检查token状态是否被吊销
status, err := m.checkTokenStatus(claims.ID)
status, err := m.checkTokenStatus(r.Context(), claims.ID)
if err == nil && status != "active" {
if m.auditEmitter != nil {
m.auditEmitter.Emit(r.Context(), AuditEvent{
@@ -363,24 +418,21 @@ func (m *AuthMiddleware) ScopeRoleAuthzMiddleware(requiredScope string) func(htt
}
// 检查role权限
roleHierarchy := map[string]int{
"admin": 3,
"owner": 2,
"viewer": 1,
}
// 使用model.GetRoleLevelByCode获取统一角色层级定义
// 路由权限要求
// 路由权限要求(使用详细角色代码)
// viewer: level 10, operator: level 30, org_admin: level 50
routeRoles := map[string]string{
"/api/v1/supply/accounts": "owner",
"/api/v1/supply/packages": "owner",
"/api/v1/supply/settlements": "owner",
"/api/v1/supply/billing": "viewer",
"/api/v1/supplier/billing": "viewer",
"/api/v1/supply/accounts": "org_admin",
"/api/v1/supply/packages": "org_admin",
"/api/v1/supply/settlements": "org_admin",
"/api/v1/supply/billing": "viewer",
"/api/v1/supplier/billing": "viewer",
}
for path, requiredRole := range routeRoles {
if strings.HasPrefix(r.URL.Path, path) {
if roleLevel(claims.Role, roleHierarchy) < roleLevel(requiredRole, roleHierarchy) {
if model.GetRoleLevelByCode(claims.Role) < model.GetRoleLevelByCode(requiredRole) {
writeAuthError(w, http.StatusForbidden, "AUTH_ROLE_DENIED",
fmt.Sprintf("required role '%s' is not granted, current role: '%s'", requiredRole, claims.Role))
return
@@ -430,7 +482,7 @@ func (m *AuthMiddleware) verifyToken(tokenString string) (*TokenClaims, error) {
}
// checkTokenStatus 检查token状态从缓存或数据库
func (m *AuthMiddleware) checkTokenStatus(tokenID string) (string, error) {
func (m *AuthMiddleware) checkTokenStatus(ctx context.Context, tokenID string) (string, error) {
if m.tokenCache != nil {
// 先从缓存检查
if status, found := m.tokenCache.Get(tokenID); found {
@@ -440,7 +492,7 @@ func (m *AuthMiddleware) checkTokenStatus(tokenID string) (string, error) {
// 缓存未命中查询后端验证token状态
if m.tokenBackend != nil {
return m.tokenBackend.CheckTokenStatus(context.Background(), tokenID)
return m.tokenBackend.CheckTokenStatus(ctx, tokenID)
}
// 没有后端实现时应该拒绝访问而不是默认active
@@ -472,7 +524,10 @@ func writeAuthError(w http.ResponseWriter, status int, code, message string) {
"message": message,
},
}
json.NewEncoder(w).Encode(resp)
if err := json.NewEncoder(w).Encode(resp); err != nil {
// 记录编码错误(响应已经开始发送,无法回退)
log.Printf("[AUTH_ERROR] failed to encode error response: %v, code=%s", err, code)
}
}
// getRequestID 获取请求ID
@@ -488,7 +543,10 @@ func getClientIP(r *http.Request) string {
// 优先从X-Forwarded-For获取
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
parts := strings.Split(xff, ",")
return strings.TrimSpace(parts[0])
// 安全检查:空字符串已在上层判断,但防御性编程
if len(parts) > 0 {
return strings.TrimSpace(parts[0])
}
}
// X-Real-IP
@@ -550,14 +608,6 @@ func containsScope(scopes []string, target string) bool {
return false
}
// roleLevel 获取角色等级
func roleLevel(role string, hierarchy map[string]int) int {
if level, ok := hierarchy[role]; ok {
return level
}
return 0
}
// parseSubjectID 解析subject ID
func parseSubjectID(subject string) int64 {
parts := strings.Split(subject, ":")
@@ -570,7 +620,9 @@ func parseSubjectID(subject string) int64 {
// TokenCache Token状态缓存
type TokenCache struct {
data map[string]cacheEntry
data map[string]cacheEntry
mu sync.RWMutex
cleanup int64 // 清理触发计数器
}
type cacheEntry struct {
@@ -581,34 +633,76 @@ type cacheEntry struct {
// NewTokenCache 创建token缓存
func NewTokenCache() *TokenCache {
return &TokenCache{
data: make(map[string]cacheEntry),
data: make(map[string]cacheEntry),
cleanup: 0,
}
}
// Get 获取token状态
func (c *TokenCache) Get(tokenID string) (string, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
if entry, ok := c.data[tokenID]; ok {
if time.Now().Before(entry.expires) {
return entry.status, true
}
delete(c.data, tokenID)
}
return "", false
}
// Set 设置token状态
func (c *TokenCache) Set(tokenID, status string, ttl time.Duration) {
c.mu.Lock()
defer c.mu.Unlock()
c.data[tokenID] = cacheEntry{
status: status,
expires: time.Now().Add(ttl),
}
c.triggerCleanup()
}
// Invalidate 使token失效
func (c *TokenCache) Invalidate(tokenID string) {
c.mu.Lock()
defer c.mu.Unlock()
delete(c.data, tokenID)
}
// triggerCleanup 触发清理每100次操作清理一次过期条目
func (c *TokenCache) triggerCleanup() {
c.cleanup++
if c.cleanup >= 100 {
c.cleanup = 0
c.cleanupExpiredLocked()
}
}
// cleanupExpiredLocked 清理过期条目(需要持有锁)
func (c *TokenCache) cleanupExpiredLocked() {
now := time.Now()
for tokenID, entry := range c.data {
if now.After(entry.expires) {
delete(c.data, tokenID)
}
}
}
// CleanExpired 主动清理过期条目(可由外部定期调用)
func (c *TokenCache) CleanExpired() {
c.mu.Lock()
defer c.mu.Unlock()
c.cleanupExpiredLocked()
}
// Len 返回缓存条目数量(用于监控)
func (c *TokenCache) Len() int {
c.mu.RLock()
defer c.mu.RUnlock()
return len(c.data)
}
// ComputeFingerprint 计算凭证指纹(用于审计)
func ComputeFingerprint(credential string) string {
hash := sha256.Sum256([]byte(credential))

View File

@@ -1,6 +1,7 @@
package middleware
import (
"context"
"net/http"
"net/http/httptest"
"strings"
@@ -8,6 +9,8 @@ import (
"time"
"github.com/golang-jwt/jwt/v5"
"lijiaoqiao/supply-api/internal/iam/model"
)
func TestTokenVerify(t *testing.T) {
@@ -248,27 +251,25 @@ func TestContainsScope(t *testing.T) {
}
func TestRoleLevel(t *testing.T) {
hierarchy := map[string]int{
"admin": 3,
"owner": 2,
"viewer": 1,
}
tests := []struct {
role string
expected int
}{
{"admin", 3},
{"owner", 2},
{"viewer", 1},
{"super_admin", 100},
{"org_admin", 50},
{"supply_admin", 40},
{"operator", 30},
{"developer", 20},
{"finops", 20},
{"viewer", 10},
{"unknown", 0},
}
for _, tt := range tests {
t.Run(tt.role, func(t *testing.T) {
result := roleLevel(tt.role, hierarchy)
result := model.GetRoleLevelByCode(tt.role)
if result != tt.expected {
t.Errorf("roleLevel(%s) = %d, want %d", tt.role, result, tt.expected)
t.Errorf("GetRoleLevelByCode(%s) = %d, want %d", tt.role, result, tt.expected)
}
})
}
@@ -411,7 +412,7 @@ func TestMED02_TokenCacheMiss_ShouldNotAssumeActive(t *testing.T) {
}
// act - 查询一个不在缓存中的token
status, err := middleware.checkTokenStatus("nonexistent-token-id")
status, err := middleware.checkTokenStatus(context.Background(), "nonexistent-token-id")
// assert - 缓存未命中且没有后端时应该返回错误(安全修复)
// 修复前bug缓存未命中时默认返回"active"