fix: 生产安全修复 + Go SDK + CAS SSO框架

安全修复:
- CRITICAL: SSO重定向URL注入漏洞 - 修复redirect_uri白名单验证
- HIGH: SSO ClientSecret未验证 - 使用crypto/subtle.ConstantTimeCompare验证
- HIGH: 邮件验证码熵值过低(3字节) - 提升到6字节(48位熵)
- HIGH: 短信验证码熵值过低(4字节) - 提升到6字节
- HIGH: Goroutine使用已取消上下文 - auth_email.go使用独立context+超时
- HIGH: SQL LIKE查询注入风险 - permission/role仓库使用escapeLikePattern

新功能:
- Go SDK: sdk/go/user-management/ 完整SDK实现
- CAS SSO框架: internal/auth/cas.go CAS协议支持

其他:
- L1Cache实例问题修复 - AuthMiddleware共享l1Cache
- 设备指纹XSS防护 - 内存存储替代localStorage
- 响应格式协议中间件
- 导出无界查询修复
This commit is contained in:
2026-04-03 17:38:31 +08:00
parent 44e60be918
commit 765a50b7d4
22 changed files with 2318 additions and 71 deletions

221
internal/auth/cas.go Normal file
View File

@@ -0,0 +1,221 @@
package auth
import (
"context"
"crypto/rand"
"encoding/base64"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
)
// CASProvider CAS (Central Authentication Service) 提供者
// CAS 是一种单点登录协议,用户只需登录一次即可访问多个应用
type CASProvider struct {
serverURL string
serviceURL string
}
// CASServiceTicket CAS 服务票据
type CASServiceTicket struct {
Ticket string
Service string
UserID int64
Username string
IssuedAt time.Time
Expiry time.Time
}
// NewCASProvider 创建 CAS 提供者
func NewCASProvider(serverURL, serviceURL string) *CASProvider {
return &CASProvider{
serverURL: strings.TrimSuffix(serverURL, "/"),
serviceURL: serviceURL,
}
}
// BuildLoginURL 构建 CAS 登录 URL
// 用于重定向用户到 CAS 登录页面
func (p *CASProvider) BuildLoginURL(renew, gateway bool) string {
params := url.Values{}
params.Set("service", p.serviceURL)
if renew {
params.Set("renew", "true")
}
if gateway {
params.Set("gateway", "true")
}
return fmt.Sprintf("%s/login?%s", p.serverURL, params.Encode())
}
// BuildLogoutURL 构建 CAS 登出 URL
func (p *CASProvider) BuildLogoutURL(url string) string {
if url != "" {
return fmt.Sprintf("%s/logout?service=%s", p.serverURL, url)
}
return fmt.Sprintf("%s/logout", p.serverURL)
}
// CASValidationResponse CAS 票据验证响应
type CASValidationResponse struct {
Success bool
UserID int64
Username string
ErrorCode string
ErrorMsg string
}
// ValidateTicket 验证 CAS 票据
// 向 CAS 服务器发送 ticket 验证请求
func (p *CASProvider) ValidateTicket(ctx context.Context, ticket string) (*CASValidationResponse, error) {
if ticket == "" {
return &CASValidationResponse{
Success: false,
ErrorCode: "INVALID_REQUEST",
ErrorMsg: "ticket is required",
}, nil
}
params := url.Values{}
params.Set("service", p.serviceURL)
params.Set("ticket", ticket)
validateURL := fmt.Sprintf("%s/p3/serviceValidate?%s", p.serverURL, params.Encode())
resp, err := fetchCASResponse(ctx, validateURL)
if err != nil {
return nil, fmt.Errorf("CAS validation request failed: %w", err)
}
return p.parseServiceValidateResponse(resp)
}
// parseServiceValidateResponse 解析 CAS serviceValidate 响应
// CAS 1.0 和 CAS 2.0 使用不同的响应格式
func (p *CASProvider) parseServiceValidateResponse(xml string) (*CASValidationResponse, error) {
resp := &CASValidationResponse{Success: false}
// 检查是否包含 authenticationSuccess 元素
if strings.Contains(xml, "<authenticationSuccess>") {
resp.Success = true
// 解析用户名
if start := strings.Index(xml, "<user>"); start != -1 {
end := strings.Index(xml[start:], "</user>")
if end != -1 {
resp.Username = xml[start+6 : start+end]
}
}
// 解析用户 ID (CAS 2.0)
if start := strings.Index(xml, "<userId>"); start != -1 {
end := strings.Index(xml[start:], "</userId>")
if end != -1 {
userIDStr := xml[start+8 : start+end]
var userID int64
fmt.Sscanf(userIDStr, "%d", &userID)
resp.UserID = userID
}
}
} else if strings.Contains(xml, "<authenticationFailure>") {
resp.Success = false
// 解析错误码
if start := strings.Index(xml, "code=\""); start != -1 {
start += 6
end := strings.Index(xml[start:], "\"")
if end != -1 {
resp.ErrorCode = xml[start : start+end]
}
}
// 解析错误消息
if start := strings.Index(xml, "<![CDATA["); start != -1 {
end := strings.Index(xml[start:], "]]>")
if end != -1 {
resp.ErrorMsg = xml[start+9 : start+end]
}
}
}
return resp, nil
}
// GenerateProxyTicket 生成代理票据 (CAS 2.0)
// 用于服务代理用户访问其他服务
func (p *CASProvider) GenerateProxyTicket(ctx context.Context, proxyGrantingTicket, targetService string) (string, error) {
params := url.Values{}
params.Set("targetService", targetService)
proxyURL := fmt.Sprintf("%s/p3/proxy?%s&pgt=%s",
p.serverURL, params.Encode(), proxyGrantingTicket)
resp, err := fetchCASResponse(ctx, proxyURL)
if err != nil {
return "", err
}
// 解析代理票据
if start := strings.Index(resp, "<proxyTicket>"); start != -1 {
end := strings.Index(resp[start:], "</proxyTicket>")
if end != -1 {
return resp[start+12 : start+end], nil
}
}
return "", fmt.Errorf("failed to parse proxy ticket from response")
}
// fetchCASResponse 从 CAS 服务器获取响应
func fetchCASResponse(ctx context.Context, url string) (string, error) {
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return "", err
}
req.Header.Set("Accept", "application/xml")
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}
return string(body), nil
}
// GenerateCASServiceTicket 生成 CAS 服务票据 (供 CAS 服务器使用)
// 这个方法供实际的 CAS 服务器实现调用
func GenerateCASServiceTicket(service string, userID int64, username string) (*CASServiceTicket, error) {
ticketBytes := make([]byte, 32)
if _, err := rand.Read(ticketBytes); err != nil {
return nil, fmt.Errorf("failed to generate ticket: %w", err)
}
return &CASServiceTicket{
Ticket: "ST-" + base64.URLEncoding.EncodeToString(ticketBytes)[:32],
Service: service,
UserID: userID,
Username: username,
IssuedAt: time.Now(),
Expiry: time.Now().Add(5 * time.Minute),
}, nil
}
// IsExpired 检查票据是否过期
func (t *CASServiceTicket) IsExpired() bool {
return time.Now().After(t.Expiry)
}
// GetDuration 返回票据有效时长
func (t *CASServiceTicket) GetDuration() time.Duration {
return t.Expiry.Sub(t.IssuedAt)
}

View File

@@ -6,9 +6,17 @@ import (
"encoding/base64"
"errors"
"fmt"
"sync"
"time"
)
const (
// MaxSessions 最大 session 数量限制
MaxSessions = 10000
// CleanupInterval 清理间隔
CleanupInterval = 5 * time.Minute
)
// SSOOAuth2Config SSO OAuth2 配置
type SSOOAuth2Config struct {
ClientID string
@@ -66,6 +74,7 @@ type SSOSession struct {
// SSOManager SSO 管理器
type SSOManager struct {
mu sync.RWMutex
sessions map[string]*SSOSession
}
@@ -76,12 +85,35 @@ func NewSSOManager() *SSOManager {
}
}
// StartCleanup 启动后台清理 goroutine
func (m *SSOManager) StartCleanup(ctx context.Context) {
go func() {
ticker := time.NewTicker(CleanupInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
m.CleanupExpired()
}
}
}()
}
// GenerateAuthorizationCode 生成授权码
func (m *SSOManager) GenerateAuthorizationCode(clientID, redirectURI, scope string, userID int64, username string) (string, error) {
code := generateSecureToken(32)
code, err := generateSecureToken(32)
if err != nil {
return "", err
}
sessionID, err := generateSecureToken(16)
if err != nil {
return "", err
}
session := &SSOSession{
SessionID: generateSecureToken(16),
SessionID: sessionID,
UserID: userID,
Username: username,
ClientID: clientID,
@@ -90,13 +122,26 @@ func (m *SSOManager) GenerateAuthorizationCode(clientID, redirectURI, scope stri
Scope: scope,
}
m.mu.Lock()
// 检查并清理过期 session如果超过限制则淘汰最旧的
if len(m.sessions) >= MaxSessions {
m.cleanupExpiredLocked()
// 如果仍然满,淘汰最早的
if len(m.sessions) >= MaxSessions {
m.evictOldest()
}
}
m.sessions[code] = session
m.mu.Unlock()
return code, nil
}
// ValidateAuthorizationCode 验证授权码
func (m *SSOManager) ValidateAuthorizationCode(code string) (*SSOSession, error) {
m.mu.Lock()
defer m.mu.Unlock()
session, ok := m.sessions[code]
if !ok {
return nil, errors.New("invalid authorization code")
@@ -114,8 +159,11 @@ func (m *SSOManager) ValidateAuthorizationCode(code string) (*SSOSession, error)
}
// GenerateAccessToken 生成访问令牌
func (m *SSOManager) GenerateAccessToken(clientID string, session *SSOSession) (string, time.Time) {
token := generateSecureToken(32)
func (m *SSOManager) GenerateAccessToken(clientID string, session *SSOSession) (string, time.Time, error) {
token, err := generateSecureToken(32)
if err != nil {
return "", time.Time{}, err
}
expiresAt := time.Now().Add(2 * time.Hour) // Access token 2 小时有效期
accessSession := &SSOSession{
@@ -128,22 +176,37 @@ func (m *SSOManager) GenerateAccessToken(clientID string, session *SSOSession) (
Scope: session.Scope,
}
m.mu.Lock()
// 检查并清理过期 session如果超过限制则淘汰最旧的
if len(m.sessions) >= MaxSessions {
m.cleanupExpiredLocked()
if len(m.sessions) >= MaxSessions {
m.evictOldest()
}
}
m.sessions[token] = accessSession
m.mu.Unlock()
return token, expiresAt
return token, expiresAt, nil
}
// IntrospectToken 验证 token
func (m *SSOManager) IntrospectToken(token string) (*SSOTokenInfo, error) {
m.mu.RLock()
session, ok := m.sessions[token]
if !ok {
m.mu.RUnlock()
return &SSOTokenInfo{Active: false}, nil
}
if time.Now().After(session.ExpiresAt) {
m.mu.RUnlock()
m.mu.Lock()
delete(m.sessions, token)
m.mu.Unlock()
return &SSOTokenInfo{Active: false}, nil
}
m.mu.RUnlock()
return &SSOTokenInfo{
Active: true,
@@ -157,12 +220,21 @@ func (m *SSOManager) IntrospectToken(token string) (*SSOTokenInfo, error) {
// RevokeToken 撤销 token
func (m *SSOManager) RevokeToken(token string) error {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.sessions, token)
return nil
}
// CleanupExpired 清理过期的 session(可由后台 goroutine 定期调用)
// CleanupExpired 清理过期的 session
func (m *SSOManager) CleanupExpired() {
m.mu.Lock()
defer m.mu.Unlock()
m.cleanupExpiredLocked()
}
// cleanupExpiredLocked 内部清理方法(假设已持有锁)
func (m *SSOManager) cleanupExpiredLocked() {
now := time.Now()
for key, session := range m.sessions {
if now.After(session.ExpiresAt) {
@@ -171,11 +243,38 @@ func (m *SSOManager) CleanupExpired() {
}
}
// evictOldest 淘汰最早的 session假设已持有锁
func (m *SSOManager) evictOldest() {
if len(m.sessions) == 0 {
return
}
var oldestKey string
var oldestTime time.Time
for key, session := range m.sessions {
if oldestTime.IsZero() || session.CreatedAt.Before(oldestTime) {
oldestTime = session.CreatedAt
oldestKey = key
}
}
if oldestKey != "" {
delete(m.sessions, oldestKey)
}
}
// SessionCount 返回当前 session 数量(用于监控)
func (m *SSOManager) SessionCount() int {
m.mu.RLock()
defer m.mu.RUnlock()
return len(m.sessions)
}
// generateSecureToken 生成安全随机 token
func generateSecureToken(length int) string {
func generateSecureToken(length int) (string, error) {
bytes := make([]byte, length)
rand.Read(bytes)
return base64.URLEncoding.EncodeToString(bytes)[:length]
if _, err := rand.Read(bytes); err != nil {
return "", fmt.Errorf("failed to generate secure token: %w", err)
}
return base64.URLEncoding.EncodeToString(bytes)[:length], nil
}
// SSOClient SSO 客户端配置存储
@@ -189,10 +288,12 @@ type SSOClient struct {
// SSOClientsStore SSO 客户端存储接口
type SSOClientsStore interface {
GetByClientID(clientID string) (*SSOClient, error)
ValidateClientRedirectURI(clientID, redirectURI string) bool
}
// DefaultSSOClientsStore 默认内存存储
type DefaultSSOClientsStore struct {
mu sync.RWMutex
clients map[string]*SSOClient
}
@@ -205,11 +306,15 @@ func NewDefaultSSOClientsStore() *DefaultSSOClientsStore {
// RegisterClient 注册客户端
func (s *DefaultSSOClientsStore) RegisterClient(client *SSOClient) {
s.mu.Lock()
defer s.mu.Unlock()
s.clients[client.ClientID] = client
}
// GetByClientID 根据 ClientID 获取客户端
func (s *DefaultSSOClientsStore) GetByClientID(clientID string) (*SSOClient, error) {
s.mu.RLock()
defer s.mu.RUnlock()
client, ok := s.clients[clientID]
if !ok {
return nil, fmt.Errorf("client not found: %s", clientID)