Files
user-system/internal/auth/cas.go

222 lines
5.6 KiB
Go
Raw Normal View History

package auth
import (
"context"
"crypto/rand"
"encoding/base64"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
)
// CASProvider is a CAS (Central Authentication Service) provider.
// CAS is a single sign-on protocol: a user logs in once and can then access
// multiple applications.
type CASProvider struct {
// serverURL is the base URL of the CAS server; NewCASProvider stores it
// without a trailing "/".
serverURL string
// serviceURL is sent as the "service" parameter on login and ticket
// validation requests.
serviceURL string
}
// CASServiceTicket is a CAS service ticket.
type CASServiceTicket struct {
// Ticket is the ticket string; GenerateCASServiceTicket produces values
// that start with "ST-".
Ticket string
// Service is the service value the ticket was created with.
Service string
// UserID is the numeric ID of the user the ticket was created for.
UserID int64
// Username is the name of the user the ticket was created for.
Username string
// IssuedAt is the creation time of the ticket.
IssuedAt time.Time
// Expiry is the instant after which IsExpired reports true.
Expiry time.Time
}
// NewCASProvider returns a CASProvider that talks to the CAS server at
// serverURL and presents serviceURL as this application's service identifier.
//
// A single trailing "/" is stripped from serverURL so endpoint paths can be
// appended directly; serviceURL is stored exactly as given.
func NewCASProvider(serverURL, serviceURL string) *CASProvider {
	provider := new(CASProvider)
	provider.serverURL = strings.TrimSuffix(serverURL, "/")
	provider.serviceURL = serviceURL
	return provider
}
// BuildLoginURL returns the CAS login URL to which a user should be
// redirected.
//
// The provider's service URL is always sent as the "service" parameter.
// renew and gateway each add the query parameter of the same name with the
// value "true" when set.
func (p *CASProvider) BuildLoginURL(renew, gateway bool) string {
	query := url.Values{"service": {p.serviceURL}}

	optional := []struct {
		name    string
		enabled bool
	}{
		{"renew", renew},
		{"gateway", gateway},
	}
	for _, opt := range optional {
		if opt.enabled {
			query.Set(opt.name, "true")
		}
	}

	// Encode sorts parameters by key, so the output is deterministic.
	return fmt.Sprintf("%s/login?%s", p.serverURL, query.Encode())
}
// BuildLogoutURL returns the CAS logout URL.
//
// When service is non-empty it is added as the "service" query parameter,
// i.e. the URL the CAS server may send the user to after logout. The value is
// query-escaped here, so callers must pass it unescaped; without escaping, a
// service URL that itself contains "&", "?" or "=" would corrupt the logout
// query string. When service is empty the bare logout endpoint is returned.
func (p *CASProvider) BuildLogoutURL(service string) string {
	if service == "" {
		return fmt.Sprintf("%s/logout", p.serverURL)
	}
	params := url.Values{}
	params.Set("service", service)
	return fmt.Sprintf("%s/logout?%s", p.serverURL, params.Encode())
}
// CASValidationResponse is the outcome of validating a CAS ticket.
type CASValidationResponse struct {
// Success is true when the server response contained an
// authenticationSuccess element.
Success bool
// UserID is parsed from the response's <userId> element; it stays 0 when
// that element is absent.
UserID int64
// Username is the text of the response's <user> element.
Username string
// ErrorCode is the failure code, e.g. "INVALID_REQUEST" when no ticket was
// supplied to ValidateTicket.
ErrorCode string
// ErrorMsg is the human-readable failure message.
ErrorMsg string
}
// ValidateTicket validates a CAS ticket by sending it, together with the
// provider's service URL, to the CAS server's /p3/serviceValidate endpoint.
//
// An empty ticket is rejected locally, without contacting the server, as an
// unsuccessful response with code "INVALID_REQUEST" and a nil error. A failed
// request to the server yields a nil response and a wrapped error. Otherwise
// the result is whatever parseServiceValidateResponse makes of the body.
func (p *CASProvider) ValidateTicket(ctx context.Context, ticket string) (*CASValidationResponse, error) {
	if len(ticket) == 0 {
		rejected := new(CASValidationResponse)
		rejected.ErrorCode = "INVALID_REQUEST"
		rejected.ErrorMsg = "ticket is required"
		return rejected, nil
	}

	query := url.Values{
		"service": {p.serviceURL},
		"ticket":  {ticket},
	}
	endpoint := fmt.Sprintf("%s/p3/serviceValidate?%s", p.serverURL, query.Encode())

	body, fetchErr := fetchCASResponse(ctx, endpoint)
	if fetchErr != nil {
		return nil, fmt.Errorf("CAS validation request failed: %w", fetchErr)
	}
	return p.parseServiceValidateResponse(body)
}
// parseServiceValidateResponse extracts the validation outcome from the XML
// body returned by a CAS serviceValidate endpoint. Only the XML response
// format is handled.
//
// Elements are recognised both with and without the "cas:" namespace prefix,
// and an opening tag may carry attributes. This matters because the CAS
// protocol's XML uses <cas:authenticationSuccess> and
// <cas:authenticationFailure code="...">, neither of which contains the bare
// "<authenticationSuccess>" / "<authenticationFailure>" text.
//
// On success, Username is the text of the first <user> element and UserID the
// integer in the first <userId> element (0 if absent or not numeric). On
// failure, ErrorCode is the first code="..." attribute at or after the
// authenticationFailure tag, and ErrorMsg is the first CDATA section inside
// the element or, if there is none, the element's trimmed text.
//
// A body containing neither element yields an unsuccessful response with no
// error details. The returned error is always nil.
//
// This is a lenient substring scan, not an XML parser: it does not validate
// document structure.
func (p *CASProvider) parseServiceValidateResponse(xml string) (*CASValidationResponse, error) {
	resp := &CASValidationResponse{}

	if tagStart, _ := casFindOpenTag(xml, "authenticationSuccess"); tagStart != -1 {
		resp.Success = true
		if user, ok := casElementText(xml, "user"); ok {
			resp.Username = user
		}
		if rawID, ok := casElementText(xml, "userId"); ok {
			// Best effort: a non-numeric userId leaves UserID at 0 rather
			// than failing an otherwise successful validation.
			var userID int64
			if _, err := fmt.Sscanf(rawID, "%d", &userID); err == nil {
				resp.UserID = userID
			}
		}
		return resp, nil
	}

	tagStart, bodyStart := casFindOpenTag(xml, "authenticationFailure")
	if tagStart == -1 {
		return resp, nil
	}

	// Search from the failure tag onwards so that the attribute on the tag
	// itself is found first, while a code on a nested element still works.
	const codeAttr = `code="`
	if i := strings.Index(xml[tagStart:], codeAttr); i != -1 {
		value := xml[tagStart+i+len(codeAttr):]
		if end := strings.IndexByte(value, '"'); end != -1 {
			resp.ErrorCode = value[:end]
		}
	}

	const cdataOpen, cdataClose = "<![CDATA[", "]]>"
	if i := strings.Index(xml[bodyStart:], cdataOpen); i != -1 {
		content := xml[bodyStart+i+len(cdataOpen):]
		if end := strings.Index(content, cdataClose); end != -1 {
			resp.ErrorMsg = content[:end]
		}
	} else if text, ok := casElementText(xml, "authenticationFailure"); ok {
		resp.ErrorMsg = strings.TrimSpace(text)
	}
	return resp, nil
}

// casFindOpenTag locates the first opening tag of element name in doc, trying
// the unprefixed form first and then the "cas:"-prefixed form. It returns the
// index of the tag's "<" and the index just past its closing ">", or -1, -1
// when no such tag exists.
func casFindOpenTag(doc, name string) (tagStart, bodyStart int) {
	for _, open := range []string{"<" + name, "<cas:" + name} {
		from := 0
		for {
			i := strings.Index(doc[from:], open)
			if i == -1 {
				break
			}
			i += from
			after := i + len(open)
			// Require a name boundary so that "<user" does not match
			// "<userId>".
			if after < len(doc) && strings.IndexByte(">\t\n\r ", doc[after]) != -1 {
				if gt := strings.IndexByte(doc[after:], '>'); gt != -1 {
					return i, after + gt + 1
				}
			}
			from = after
		}
	}
	return -1, -1
}

// casElementText returns the raw text between the first opening tag of
// element name and the nearest following closing tag of that name (prefixed
// or not). ok is false when either tag is missing.
func casElementText(doc, name string) (text string, ok bool) {
	_, bodyStart := casFindOpenTag(doc, name)
	if bodyStart == -1 {
		return "", false
	}
	body := doc[bodyStart:]
	end := -1
	for _, closing := range []string{"</" + name + ">", "</cas:" + name + ">"} {
		if i := strings.Index(body, closing); i != -1 && (end == -1 || i < end) {
			end = i
		}
	}
	if end == -1 {
		return "", false
	}
	return body[:end], true
}
// GenerateProxyTicket requests a proxy ticket (CAS 2.0) that lets this
// service access targetService on the user's behalf.
//
// Both proxyGrantingTicket and targetService are query-escaped before being
// sent as the "pgt" and "targetService" parameters. The ticket is read from
// the response's <proxyTicket> element, with or without the "cas:" namespace
// prefix, and returned with surrounding whitespace removed.
//
// It returns a wrapped error when the request fails, and an error when the
// response contains no non-empty proxy ticket.
//
// NOTE(review): the CAS protocol specification names this endpoint /proxy;
// confirm that the target server also serves /p3/proxy.
func (p *CASProvider) GenerateProxyTicket(ctx context.Context, proxyGrantingTicket, targetService string) (string, error) {
	params := url.Values{}
	params.Set("targetService", targetService)
	params.Set("pgt", proxyGrantingTicket)
	proxyURL := fmt.Sprintf("%s/p3/proxy?%s", p.serverURL, params.Encode())

	resp, err := fetchCASResponse(ctx, proxyURL)
	if err != nil {
		return "", fmt.Errorf("CAS proxy request failed: %w", err)
	}

	for _, prefix := range []string{"", "cas:"} {
		openTag := "<" + prefix + "proxyTicket>"
		closeTag := "</" + prefix + "proxyTicket>"
		start := strings.Index(resp, openTag)
		if start == -1 {
			continue
		}
		// Skip the whole opening tag; using its real length avoids leaking
		// the tag's trailing ">" into the ticket.
		body := resp[start+len(openTag):]
		if end := strings.Index(body, closeTag); end != -1 {
			if ticket := strings.TrimSpace(body[:end]); ticket != "" {
				return ticket, nil
			}
		}
	}
	return "", fmt.Errorf("failed to parse proxy ticket from response")
}
// fetchCASResponse performs a GET request against url and returns the
// response body as a string.
//
// The request is bound to ctx, asks for "application/xml", and is subject to
// a 10-second client timeout. A status code outside the 2xx range is reported
// as an error so that server faults and misconfigured endpoints are not
// mistaken for a CAS reply. The body is capped at 1 MiB, and a larger body is
// rejected, so a misbehaving server cannot exhaust memory.
func fetchCASResponse(ctx context.Context, url string) (string, error) {
	const maxBodyBytes = 1 << 20

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return "", err
	}
	req.Header.Set("Accept", "application/xml")

	client := &http.Client{Timeout: 10 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return "", fmt.Errorf("CAS server returned unexpected status %s", resp.Status)
	}

	// Read one byte past the cap so an oversized body can be told apart from
	// one that is exactly at the limit.
	body, err := io.ReadAll(io.LimitReader(resp.Body, maxBodyBytes+1))
	if err != nil {
		return "", err
	}
	if len(body) > maxBodyBytes {
		return "", fmt.Errorf("CAS response exceeds %d bytes", maxBodyBytes)
	}
	return string(body), nil
}
// GenerateCASServiceTicket creates a service ticket for the given service and
// user. It is intended to be called by a CAS server implementation.
//
// The ticket string is "ST-" followed by 32 URL-safe base64 characters drawn
// from crypto/rand. The ticket is valid for 5 minutes: IssuedAt and Expiry
// are derived from a single clock reading, so GetDuration reports exactly
// that lifetime.
//
// It returns an error only when the random source fails.
func GenerateCASServiceTicket(service string, userID int64, username string) (*CASServiceTicket, error) {
	const ticketLifetime = 5 * time.Minute

	// 24 random bytes encode to exactly 32 base64 characters with no padding,
	// so every generated byte ends up in the ticket.
	ticketBytes := make([]byte, 24)
	if _, err := rand.Read(ticketBytes); err != nil {
		return nil, fmt.Errorf("failed to generate ticket: %w", err)
	}

	issuedAt := time.Now()
	return &CASServiceTicket{
		Ticket:   "ST-" + base64.URLEncoding.EncodeToString(ticketBytes),
		Service:  service,
		UserID:   userID,
		Username: username,
		IssuedAt: issuedAt,
		Expiry:   issuedAt.Add(ticketLifetime),
	}, nil
}
// IsExpired reports whether the ticket's Expiry lies strictly before the
// current time.
func (t *CASServiceTicket) IsExpired() bool {
	now := time.Now()
	return t.Expiry.Before(now)
}
// GetDuration returns the ticket's validity period: the time between IssuedAt
// and Expiry.
func (t *CASServiceTicket) GetDuration() time.Duration {
	issued, expires := t.IssuedAt, t.Expiry
	return expires.Sub(issued)
}