Files
user-system/internal/api/middleware/auth.go
long-agent 09beb173cc feat: complete production readiness improvements
- Fix DIP violations in service layer (device, stats, auth middleware)
- Add ReplaceUserRoles interface method for transaction safety
- Implement Magic Bytes validation for avatar uploads
- Standardize OAuth error handling with ErrOAuthProviderNotSupported
- Use crypto/rand for JWT secret generation instead of weak fixed key
- Apply code formatting with gofumpt and goimports
- Fix staticcheck issues (S1024, S1008, ST1005)
- Add comprehensive quality and functional test reports
- Achieve 36.3% test coverage (up from 16.3%)
- All E2E, integration, and business logic tests passing
2026-04-12 16:15:32 +08:00

216 lines
5.4 KiB
Go

package middleware
import (
"context"
"fmt"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"golang.org/x/sync/singleflight"
"github.com/user-management-system/internal/auth"
"github.com/user-management-system/internal/cache"
"github.com/user-management-system/internal/domain"
apierrors "github.com/user-management-system/internal/pkg/errors"
)
// Repository abstractions the middleware depends on (dependency inversion):
// the middleware sees only these narrow, consumer-side interfaces rather than
// concrete repository types.
type (
	// authUserRepository loads a single user by ID; the middleware uses it
	// for the active-account check.
	authUserRepository interface {
		GetByID(ctx context.Context, userID int64) (*domain.User, error)
	}

	// authUserRoleRepository returns a user's roles together with their
	// permissions in one call.
	authUserRoleRepository interface {
		GetUserRolesAndPermissions(ctx context.Context, userID int64) ([]*domain.Role, []*domain.Permission, error)
	}
)
// AuthMiddleware provides gin handlers that authenticate requests with JWT
// access tokens. It checks token revocation (a JTI blacklist) and account
// status, and loads the caller's role and permission codes through a cache.
//
// It holds a singleflight.Group, which contains a mutex, so an AuthMiddleware
// must not be copied after first use; share it by pointer.
type AuthMiddleware struct {
	jwt          *auth.JWT              // validates access tokens and yields their claims
	userRepo     authUserRepository     // optional; nil skips the active-account check
	userRoleRepo authUserRoleRepository // optional; nil skips role/permission loading
	l1Cache      *cache.L1Cache         // local cache for blacklist entries and permission sets
	cacheManager *cache.CacheManager    // optional L2 cache for blacklist lookups; set via SetCacheManager
	sfGroup      singleflight.Group     // collapses concurrent L2 blacklist lookups for the same key
}
// NewAuthMiddleware builds an AuthMiddleware from its collaborators.
//
// userRepo and userRoleRepo may be nil:
//   - a nil userRepo makes every token's user count as active;
//   - a nil userRoleRepo yields nil role and permission codes.
//
// Pass an untyped nil, not a nil pointer of a concrete repository type. A
// typed nil stored in an interface does not compare equal to nil, so those
// checks would not be skipped.
//
// NOTE(review): l1Cache is used unconditionally by every lookup; presumably it
// must be non-nil — confirm against cache.L1Cache's nil-receiver behaviour.
//
// The optional L2 blacklist cache is not a constructor argument; attach it
// with SetCacheManager.
func NewAuthMiddleware(
	jwt *auth.JWT,
	userRepo authUserRepository,
	userRoleRepo authUserRoleRepository,
	l1Cache *cache.L1Cache,
) *AuthMiddleware {
	mw := new(AuthMiddleware)
	mw.jwt = jwt
	mw.l1Cache = l1Cache
	mw.userRepo = userRepo
	mw.userRoleRepo = userRoleRepo
	return mw
}
// SetCacheManager attaches the second-level (L2) cache that the blacklist
// check consults when the local L1 cache misses. Passing nil disables that
// L2 lookup.
//
// NOTE(review): the field is read on every request without synchronization,
// so call this during setup, before the middleware starts serving traffic.
func (m *AuthMiddleware) SetCacheManager(cm *cache.CacheManager) { m.cacheManager = cm }
// Required returns a handler that rejects the request with HTTP 401 and aborts
// the chain unless all of the following hold:
//   - the request carries a bearer token;
//   - the token validates as an access token;
//   - the token's JTI is not blacklisted;
//   - its user is active.
//
// On success it stores the caller's identity in the gin context under
// "user_id", "username", "token_jti", "role_codes" and "permission_codes",
// then continues the chain.
func (m *AuthMiddleware) Required() gin.HandlerFunc {
	return func(c *gin.Context) {
		// deny writes the standard 401 body and stops further handlers.
		deny := func(message string) {
			c.JSON(http.StatusUnauthorized, apierrors.New(http.StatusUnauthorized, "UNAUTHORIZED", message))
			c.Abort()
		}

		token := m.extractToken(c)
		if token == "" {
			deny("未提供认证令牌")
			return
		}

		claims, err := m.jwt.ValidateAccessToken(token)
		if err != nil {
			deny("无效的认证令牌")
			return
		}

		reqCtx := c.Request.Context()
		// Checked in order: a revoked token is reported before an inactive account.
		switch {
		case m.isJTIBlacklisted(reqCtx, claims.JTI):
			deny("令牌已失效,请重新登录")
			return
		case !m.isUserActive(reqCtx, claims.UserID):
			deny("账号不可用,请重新登录")
			return
		}

		c.Set("user_id", claims.UserID)
		c.Set("username", claims.Username)
		c.Set("token_jti", claims.JTI)

		roleCodes, permCodes := m.loadUserRolesAndPerms(reqCtx, claims.UserID)
		c.Set("role_codes", roleCodes)
		c.Set("permission_codes", permCodes)

		c.Next()
	}
}
// Optional returns a handler that attaches the caller's identity when the
// request carries a valid, non-blacklisted token for an active user.
// Otherwise it lets the request continue anonymously; it never aborts the
// chain.
//
// On success it sets "user_id", "username", "token_jti", "role_codes" and
// "permission_codes" in the gin context.
func (m *AuthMiddleware) Optional() gin.HandlerFunc {
	return func(c *gin.Context) {
		// identify populates the context keys, bailing out silently at the
		// first failed check.
		identify := func() {
			token := m.extractToken(c)
			if token == "" {
				return
			}
			claims, err := m.jwt.ValidateAccessToken(token)
			if err != nil {
				return
			}
			reqCtx := c.Request.Context()
			if m.isJTIBlacklisted(reqCtx, claims.JTI) || !m.isUserActive(reqCtx, claims.UserID) {
				return
			}

			c.Set("user_id", claims.UserID)
			c.Set("username", claims.Username)
			c.Set("token_jti", claims.JTI)

			roleCodes, permCodes := m.loadUserRolesAndPerms(reqCtx, claims.UserID)
			c.Set("role_codes", roleCodes)
			c.Set("permission_codes", permCodes)
		}

		identify()
		c.Next()
	}
}
// isJTIBlacklisted reports whether the token identified by jti has been
// revoked. An empty jti is never considered blacklisted.
//
// Lookup order:
//  1. The local L1 cache.
//  2. On an L1 miss, and only if a cache manager is attached, the L2 cache.
//     Concurrent lookups for the same key are collapsed by singleflight, so
//     they trigger a single L2 query. An L2 hit is written back to L1 for
//     five minutes.
//
// NOTE(review): singleflight runs the closure with the ctx of whichever caller
// arrived first. If that request is cancelled, the shared result may be a miss
// for every waiter (fail-open). The second return value of cacheManager.Get is
// also discarded. Confirm both against CacheManager.Get's contract.
func (m *AuthMiddleware) isJTIBlacklisted(ctx context.Context, jti string) bool {
	if jti == "" {
		return false
	}

	blacklistKey := "jwt_blacklist:" + jti
	if _, hit := m.l1Cache.Get(blacklistKey); hit {
		return true
	}

	if m.cacheManager == nil {
		return false
	}

	result, sfErr, _ := m.sfGroup.Do(blacklistKey, func() (interface{}, error) {
		value, _ := m.cacheManager.Get(ctx, blacklistKey)
		return value, nil
	})
	if sfErr != nil || result == nil {
		return false
	}

	// Backfill L1 so subsequent checks skip the L2 round-trip.
	m.l1Cache.Set(blacklistKey, true, 5*time.Minute)
	return true
}
// loadUserRolesAndPerms returns the role codes and permission codes of userID.
//
// It returns (nil, nil) when:
//   - no role repository is configured;
//   - the repository lookup fails;
//   - the user has no roles.
//
// Successful results are cached in L1 under "user_perms:<id>" for 30 minutes.
// Empty or failed lookups are not cached, so they hit the repository on every
// call.
//
// NOTE(review): the returned slices are the cached ones, shared across
// requests; callers must treat them as read-only.
func (m *AuthMiddleware) loadUserRolesAndPerms(ctx context.Context, userID int64) ([]string, []string) {
	if m.userRoleRepo == nil {
		return nil, nil
	}

	key := fmt.Sprintf("user_perms:%d", userID)
	if hit, found := m.l1Cache.Get(key); found {
		if entry, isEntry := hit.(userPermEntry); isEntry {
			return entry.roles, entry.perms
		}
	}

	// One repository call (a single JOIN query) yields both roles and permissions.
	roles, perms, err := m.userRoleRepo.GetUserRolesAndPermissions(ctx, userID)
	if err != nil || len(roles) == 0 {
		return nil, nil
	}

	roleCodes := make([]string, len(roles))
	for i, r := range roles {
		roleCodes[i] = r.Code
	}
	permCodes := make([]string, len(perms))
	for i, p := range perms {
		permCodes[i] = p.Code
	}

	entry := userPermEntry{roles: roleCodes, perms: permCodes}
	m.l1Cache.Set(key, entry, 30*time.Minute)
	return entry.roles, entry.perms
}
// InvalidateUserPermCache drops the cached role/permission codes for userID,
// so the next request reloads them from the repository.
//
// NOTE(review): only this process's L1 cache is cleared; other instances keep
// their cached entry until it expires.
func (m *AuthMiddleware) InvalidateUserPermCache(userID int64) {
	key := fmt.Sprintf("user_perms:%d", userID)
	m.l1Cache.Delete(key)
}
// AddToBlacklist marks the token identified by jti as revoked for ttl. Calls
// with an empty jti or a non-positive ttl are ignored.
//
// NOTE(review): this writes only the local L1 cache. The L2 entries that
// isJTIBlacklisted reads through the cache manager must be written elsewhere
// for revocation to reach other instances — confirm.
func (m *AuthMiddleware) AddToBlacklist(jti string, ttl time.Duration) {
	if jti == "" || ttl <= 0 {
		return
	}
	m.l1Cache.Set("jwt_blacklist:"+jti, true, ttl)
}
// isUserActive reports whether userID refers to an account whose status is
// domain.UserStatusActive.
//
// With no user repository configured, the check is skipped and every user is
// treated as active. A lookup error, or a nil user returned without an error,
// counts as "not active", so callers fail closed.
//
// NOTE(review): this performs a repository lookup on every authenticated
// request (no caching here); consider caching if it shows up in profiles.
func (m *AuthMiddleware) isUserActive(ctx context.Context, userID int64) bool {
	if m.userRepo == nil {
		return true
	}
	user, err := m.userRepo.GetByID(ctx, userID)
	if err != nil || user == nil {
		// Some repositories report "not found" as (nil, nil). Dereferencing
		// a nil user below would panic, and a missing account must never be
		// mistaken for an active one.
		return false
	}
	return user.Status == domain.UserStatusActive
}
// extractToken returns the token from an "Authorization: Bearer <token>"
// header. It returns "" when the header is missing, uses another scheme, or
// carries an empty token.
//
// The scheme is compared case-insensitively, because RFC 7235 §2.1 defines
// authentication schemes as case-insensitive ("bearer", "BEARER" and "Bearer"
// are the same scheme). Whitespace around the token is trimmed, so that e.g.
// "Bearer  <token>" does not yield a token with a leading space.
func (m *AuthMiddleware) extractToken(c *gin.Context) string {
	authHeader := strings.TrimSpace(c.GetHeader("Authorization"))
	if authHeader == "" {
		return ""
	}
	parts := strings.SplitN(authHeader, " ", 2)
	if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
		return ""
	}
	return strings.TrimSpace(parts[1])
}
// userPermEntry is the value cached per user: the codes of the user's roles
// and of the permissions those roles grant.
type userPermEntry struct {
	roles, perms []string
}