// Package auth provides an in-memory SSO/OAuth2 credential store
// (authorization codes and access tokens) and a simple client registry.
package auth

import (
	"context"
	"crypto/rand"
	"encoding/base64"
	"errors"
	"fmt"
	"sync"
	"time"
)

const (
	// MaxSessions is the maximum number of entries (authorization codes plus
	// access tokens) an SSOManager keeps in memory.
	MaxSessions = 10000
	// CleanupInterval is the period of the background sweep started by
	// StartCleanup.
	CleanupInterval = 5 * time.Minute

	// authorizationCodeTTL is how long an authorization code stays valid.
	authorizationCodeTTL = 10 * time.Minute
	// accessTokenTTL is how long an access token stays valid.
	accessTokenTTL = 2 * time.Hour
)

var (
	// ErrInvalidAuthorizationCode is returned by ValidateAuthorizationCode
	// when the code is unknown or has already been used.
	ErrInvalidAuthorizationCode = errors.New("invalid authorization code")
	// ErrAuthorizationCodeExpired is returned by ValidateAuthorizationCode
	// when the code exists but is past its expiry time.
	ErrAuthorizationCodeExpired = errors.New("authorization code expired")
)

// SSOOAuth2Config holds the SSO OAuth2 configuration.
type SSOOAuth2Config struct {
	ClientID     string
	ClientSecret string
	RedirectURI  string
	Scope        string
}

// SSOProvider is the SSO provider interface.
type SSOProvider interface {
	// Authorize handles an authorization request.
	Authorize(ctx context.Context, req *SSOAuthorizeRequest) (*SSOAuthorizeResponse, error)
	// Introspect validates an access token.
	Introspect(ctx context.Context, token string) (*SSOTokenInfo, error)
	// Revoke revokes a token.
	Revoke(ctx context.Context, token string) error
}

// SSOAuthorizeRequest is an authorization request.
type SSOAuthorizeRequest struct {
	ClientID     string
	RedirectURI  string
	ResponseType string // "code" or "token"
	Scope        string
	State        string
	UserID       int64
}

// SSOAuthorizeResponse is an authorization response.
type SSOAuthorizeResponse struct {
	Code  string // authorization code (authorization_code mode)
	State string
}

// SSOTokenInfo describes the result of introspecting a token.
type SSOTokenInfo struct {
	Active    bool
	UserID    int64
	Username  string
	ExpiresAt time.Time
	Scope     string
	ClientID  string
}

// SSOSession is the data recorded for an authorization code or an access
// token. For entries created by GenerateAccessToken, SessionID is the token
// itself; for entries created by GenerateAuthorizationCode it is a separate
// random identifier.
type SSOSession struct {
	SessionID string
	UserID    int64
	Username  string
	ClientID  string
	CreatedAt time.Time
	ExpiresAt time.Time
	Scope     string
}

// credentialKind distinguishes the kinds of credential held by SSOManager.
type credentialKind uint8

const (
	kindAuthorizationCode credentialKind = iota + 1
	kindAccessToken
)

// sessionKey is the map key used by SSOManager. Including the kind keeps
// authorization codes and access tokens in separate namespaces, so a code can
// never be introspected as a token and a token can never be redeemed as a
// code.
type sessionKey struct {
	kind  credentialKind
	value string
}

// SSOManager stores authorization codes and access tokens in memory.
// Its methods guard the store with an internal RWMutex.
type SSOManager struct {
	mu       sync.RWMutex
	sessions map[sessionKey]*SSOSession
}

// NewSSOManager creates an empty SSOManager.
func NewSSOManager() *SSOManager {
	return &SSOManager{
		sessions: make(map[sessionKey]*SSOSession),
	}
}

// StartCleanup starts a background goroutine that calls CleanupExpired every
// CleanupInterval. The goroutine exits when ctx is done.
func (m *SSOManager) StartCleanup(ctx context.Context) {
	go func() {
		ticker := time.NewTicker(CleanupInterval)
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
				m.CleanupExpired()
			}
		}
	}()
}

// GenerateAuthorizationCode creates a single-use authorization code for the
// given client and user and stores it for authorizationCodeTTL.
//
// NOTE(review): redirectURI is accepted but not recorded anywhere, so it
// cannot be re-checked when the code is redeemed; SSOSession has no field for
// it. Confirm whether callers verify the redirect URI themselves.
func (m *SSOManager) GenerateAuthorizationCode(clientID, redirectURI, scope string, userID int64, username string) (string, error) {
	code, err := generateSecureToken(32)
	if err != nil {
		return "", err
	}
	sessionID, err := generateSecureToken(16)
	if err != nil {
		return "", err
	}

	// Sample the clock once so CreatedAt and ExpiresAt are exactly TTL apart.
	now := time.Now()
	m.store(sessionKey{kind: kindAuthorizationCode, value: code}, &SSOSession{
		SessionID: sessionID,
		UserID:    userID,
		Username:  username,
		ClientID:  clientID,
		CreatedAt: now,
		ExpiresAt: now.Add(authorizationCodeTTL),
		Scope:     scope,
	})
	return code, nil
}

// ValidateAuthorizationCode redeems an authorization code. The code is
// removed from the store whether or not it has expired, so it can be used at
// most once. It returns ErrInvalidAuthorizationCode for an unknown (or
// already used) code and ErrAuthorizationCodeExpired for an expired one.
// Access tokens are never accepted here.
func (m *SSOManager) ValidateAuthorizationCode(code string) (*SSOSession, error) {
	key := sessionKey{kind: kindAuthorizationCode, value: code}

	m.mu.Lock()
	session, ok := m.sessions[key]
	delete(m.sessions, key) // no-op when the key is absent
	m.mu.Unlock()

	if !ok {
		return nil, ErrInvalidAuthorizationCode
	}
	if time.Now().After(session.ExpiresAt) {
		return nil, ErrAuthorizationCodeExpired
	}
	return session, nil
}

// GenerateAccessToken issues an access token for clientID that carries the
// user and scope of session, and stores it for accessTokenTTL. It returns the
// token and its expiry time. A nil session is rejected with an error.
func (m *SSOManager) GenerateAccessToken(clientID string, session *SSOSession) (string, time.Time, error) {
	if session == nil {
		return "", time.Time{}, errors.New("session is required to generate an access token")
	}
	token, err := generateSecureToken(32)
	if err != nil {
		return "", time.Time{}, err
	}

	now := time.Now()
	expiresAt := now.Add(accessTokenTTL)
	m.store(sessionKey{kind: kindAccessToken, value: token}, &SSOSession{
		SessionID: token,
		UserID:    session.UserID,
		Username:  session.Username,
		ClientID:  clientID,
		CreatedAt: now,
		ExpiresAt: expiresAt,
		Scope:     session.Scope,
	})
	return token, expiresAt, nil
}

// IntrospectToken reports the state of an access token. Unknown, expired, or
// non-token values (including authorization codes) yield Active == false.
// An expired token is removed from the store. The returned error is always
// nil.
func (m *SSOManager) IntrospectToken(token string) (*SSOTokenInfo, error) {
	key := sessionKey{kind: kindAccessToken, value: token}

	m.mu.RLock()
	session, ok := m.sessions[key]
	m.mu.RUnlock()

	if !ok {
		return &SSOTokenInfo{Active: false}, nil
	}
	if time.Now().After(session.ExpiresAt) {
		m.mu.Lock()
		// The read lock was released above; only delete if the entry we
		// judged expired is still the one stored under this key.
		if current, still := m.sessions[key]; still && current == session {
			delete(m.sessions, key)
		}
		m.mu.Unlock()
		return &SSOTokenInfo{Active: false}, nil
	}
	return &SSOTokenInfo{
		Active:    true,
		UserID:    session.UserID,
		Username:  session.Username,
		ExpiresAt: session.ExpiresAt,
		Scope:     session.Scope,
		ClientID:  session.ClientID,
	}, nil
}

// RevokeToken removes the given credential from the store. The value is
// removed whether it is an access token or an outstanding authorization
// code. Revoking an unknown value is not an error; the result is always nil.
func (m *SSOManager) RevokeToken(token string) error {
	m.mu.Lock()
	defer m.mu.Unlock()
	delete(m.sessions, sessionKey{kind: kindAccessToken, value: token})
	delete(m.sessions, sessionKey{kind: kindAuthorizationCode, value: token})
	return nil
}

// CleanupExpired removes every expired entry from the store.
func (m *SSOManager) CleanupExpired() {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.cleanupExpiredLocked()
}

// store inserts session under key. When the store already holds MaxSessions
// entries it first drops expired entries and, if that frees nothing, evicts
// the entry with the earliest CreatedAt.
func (m *SSOManager) store(key sessionKey, session *SSOSession) {
	m.mu.Lock()
	defer m.mu.Unlock()

	if len(m.sessions) >= MaxSessions {
		m.cleanupExpiredLocked()
		if len(m.sessions) >= MaxSessions {
			m.evictOldest()
		}
	}
	m.sessions[key] = session
}

// cleanupExpiredLocked removes expired entries. The caller must hold m.mu
// for writing.
func (m *SSOManager) cleanupExpiredLocked() {
	now := time.Now()
	for key, session := range m.sessions {
		if now.After(session.ExpiresAt) {
			delete(m.sessions, key)
		}
	}
}

// evictOldest removes the entry with the earliest CreatedAt, if any. The
// caller must hold m.mu for writing.
func (m *SSOManager) evictOldest() {
	var (
		oldestKey sessionKey
		oldestAt  time.Time
		found     bool
	)
	for key, session := range m.sessions {
		if !found || session.CreatedAt.Before(oldestAt) {
			oldestKey, oldestAt, found = key, session.CreatedAt, true
		}
	}
	if found {
		delete(m.sessions, oldestKey)
	}
}

// SessionCount returns the number of stored entries (authorization codes
// plus access tokens); intended for monitoring.
func (m *SSOManager) SessionCount() int {
	m.mu.RLock()
	defer m.mu.RUnlock()
	return len(m.sessions)
}

// generateSecureToken returns a random URL-safe string of exactly length
// characters. It reads length bytes from crypto/rand, base64url-encodes them
// and keeps the first length characters of the encoding. A non-positive
// length is rejected, because it would otherwise panic or yield an empty
// credential.
func generateSecureToken(length int) (string, error) {
	if length <= 0 {
		return "", fmt.Errorf("invalid token length: %d", length)
	}
	buf := make([]byte, length)
	if _, err := rand.Read(buf); err != nil {
		return "", fmt.Errorf("failed to generate secure token: %w", err)
	}
	return base64.URLEncoding.EncodeToString(buf)[:length], nil
}

// SSOClient is a registered SSO client.
type SSOClient struct {
	ClientID     string
	ClientSecret string
	Name         string
	RedirectURIs []string
}

// SSOClientsStore is the SSO client storage interface.
type SSOClientsStore interface {
	GetByClientID(clientID string) (*SSOClient, error)
	ValidateClientRedirectURI(clientID, redirectURI string) bool
}

// DefaultSSOClientsStore is the default, in-memory SSOClientsStore.
type DefaultSSOClientsStore struct {
	mu      sync.RWMutex
	clients map[string]*SSOClient
}

// Compile-time check that DefaultSSOClientsStore satisfies SSOClientsStore.
var _ SSOClientsStore = (*DefaultSSOClientsStore)(nil)

// NewDefaultSSOClientsStore creates an empty in-memory client store.
func NewDefaultSSOClientsStore() *DefaultSSOClientsStore {
	return &DefaultSSOClientsStore{
		clients: make(map[string]*SSOClient),
	}
}

// RegisterClient stores client under its ClientID, replacing any client
// previously registered with the same ID. The pointer itself is stored, not
// a copy. A nil client is ignored.
func (s *DefaultSSOClientsStore) RegisterClient(client *SSOClient) {
	if client == nil {
		return
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	s.clients[client.ClientID] = client
}

// GetByClientID returns the client registered under clientID, or an error if
// there is none.
func (s *DefaultSSOClientsStore) GetByClientID(clientID string) (*SSOClient, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	client, ok := s.clients[clientID]
	if !ok {
		return nil, fmt.Errorf("client not found: %s", clientID)
	}
	return client, nil
}

// ValidateClientRedirectURI reports whether redirectURI exactly matches one
// of the redirect URIs registered for clientID. It returns false for an
// unknown client.
func (s *DefaultSSOClientsStore) ValidateClientRedirectURI(clientID, redirectURI string) bool {
	client, err := s.GetByClientID(clientID)
	if err != nil {
		return false
	}
	for _, uri := range client.RedirectURIs {
		if uri == redirectURI {
			return true
		}
	}
	return false
}