222 lines
5.6 KiB
Go
222 lines
5.6 KiB
Go
|
|
package auth
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"crypto/rand"
|
||
|
|
"encoding/base64"
|
||
|
|
"fmt"
|
||
|
|
"io"
|
||
|
|
"net/http"
|
||
|
|
"net/url"
|
||
|
|
"strings"
|
||
|
|
"time"
|
||
|
|
)
|
||
|
|
|
||
|
|
// CASProvider CAS (Central Authentication Service) 提供者
|
||
|
|
// CAS 是一种单点登录协议,用户只需登录一次即可访问多个应用
|
||
|
|
type CASProvider struct {
|
||
|
|
serverURL string
|
||
|
|
serviceURL string
|
||
|
|
}
|
||
|
|
|
||
|
|
// CASServiceTicket CAS 服务票据
|
||
|
|
type CASServiceTicket struct {
|
||
|
|
Ticket string
|
||
|
|
Service string
|
||
|
|
UserID int64
|
||
|
|
Username string
|
||
|
|
IssuedAt time.Time
|
||
|
|
Expiry time.Time
|
||
|
|
}
|
||
|
|
|
||
|
|
// NewCASProvider 创建 CAS 提供者
|
||
|
|
func NewCASProvider(serverURL, serviceURL string) *CASProvider {
|
||
|
|
return &CASProvider{
|
||
|
|
serverURL: strings.TrimSuffix(serverURL, "/"),
|
||
|
|
serviceURL: serviceURL,
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// BuildLoginURL 构建 CAS 登录 URL
|
||
|
|
// 用于重定向用户到 CAS 登录页面
|
||
|
|
func (p *CASProvider) BuildLoginURL(renew, gateway bool) string {
|
||
|
|
params := url.Values{}
|
||
|
|
params.Set("service", p.serviceURL)
|
||
|
|
if renew {
|
||
|
|
params.Set("renew", "true")
|
||
|
|
}
|
||
|
|
if gateway {
|
||
|
|
params.Set("gateway", "true")
|
||
|
|
}
|
||
|
|
return fmt.Sprintf("%s/login?%s", p.serverURL, params.Encode())
|
||
|
|
}
|
||
|
|
|
||
|
|
// BuildLogoutURL 构建 CAS 登出 URL
|
||
|
|
func (p *CASProvider) BuildLogoutURL(url string) string {
|
||
|
|
if url != "" {
|
||
|
|
return fmt.Sprintf("%s/logout?service=%s", p.serverURL, url)
|
||
|
|
}
|
||
|
|
return fmt.Sprintf("%s/logout", p.serverURL)
|
||
|
|
}
|
||
|
|
|
||
|
|
// CASValidationResponse CAS 票据验证响应
|
||
|
|
type CASValidationResponse struct {
|
||
|
|
Success bool
|
||
|
|
UserID int64
|
||
|
|
Username string
|
||
|
|
ErrorCode string
|
||
|
|
ErrorMsg string
|
||
|
|
}
|
||
|
|
|
||
|
|
// ValidateTicket 验证 CAS 票据
|
||
|
|
// 向 CAS 服务器发送 ticket 验证请求
|
||
|
|
func (p *CASProvider) ValidateTicket(ctx context.Context, ticket string) (*CASValidationResponse, error) {
|
||
|
|
if ticket == "" {
|
||
|
|
return &CASValidationResponse{
|
||
|
|
Success: false,
|
||
|
|
ErrorCode: "INVALID_REQUEST",
|
||
|
|
ErrorMsg: "ticket is required",
|
||
|
|
}, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
params := url.Values{}
|
||
|
|
params.Set("service", p.serviceURL)
|
||
|
|
params.Set("ticket", ticket)
|
||
|
|
|
||
|
|
validateURL := fmt.Sprintf("%s/p3/serviceValidate?%s", p.serverURL, params.Encode())
|
||
|
|
|
||
|
|
resp, err := fetchCASResponse(ctx, validateURL)
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("CAS validation request failed: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
return p.parseServiceValidateResponse(resp)
|
||
|
|
}
|
||
|
|
|
||
|
|
// parseServiceValidateResponse 解析 CAS serviceValidate 响应
|
||
|
|
// CAS 1.0 和 CAS 2.0 使用不同的响应格式
|
||
|
|
func (p *CASProvider) parseServiceValidateResponse(xml string) (*CASValidationResponse, error) {
|
||
|
|
resp := &CASValidationResponse{Success: false}
|
||
|
|
|
||
|
|
// 检查是否包含 authenticationSuccess 元素
|
||
|
|
if strings.Contains(xml, "<authenticationSuccess>") {
|
||
|
|
resp.Success = true
|
||
|
|
|
||
|
|
// 解析用户名
|
||
|
|
if start := strings.Index(xml, "<user>"); start != -1 {
|
||
|
|
end := strings.Index(xml[start:], "</user>")
|
||
|
|
if end != -1 {
|
||
|
|
resp.Username = xml[start+6 : start+end]
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// 解析用户 ID (CAS 2.0)
|
||
|
|
if start := strings.Index(xml, "<userId>"); start != -1 {
|
||
|
|
end := strings.Index(xml[start:], "</userId>")
|
||
|
|
if end != -1 {
|
||
|
|
userIDStr := xml[start+8 : start+end]
|
||
|
|
var userID int64
|
||
|
|
fmt.Sscanf(userIDStr, "%d", &userID)
|
||
|
|
resp.UserID = userID
|
||
|
|
}
|
||
|
|
}
|
||
|
|
} else if strings.Contains(xml, "<authenticationFailure>") {
|
||
|
|
resp.Success = false
|
||
|
|
|
||
|
|
// 解析错误码
|
||
|
|
if start := strings.Index(xml, "code=\""); start != -1 {
|
||
|
|
start += 6
|
||
|
|
end := strings.Index(xml[start:], "\"")
|
||
|
|
if end != -1 {
|
||
|
|
resp.ErrorCode = xml[start : start+end]
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// 解析错误消息
|
||
|
|
if start := strings.Index(xml, "<![CDATA["); start != -1 {
|
||
|
|
end := strings.Index(xml[start:], "]]>")
|
||
|
|
if end != -1 {
|
||
|
|
resp.ErrorMsg = xml[start+9 : start+end]
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
return resp, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// GenerateProxyTicket 生成代理票据 (CAS 2.0)
|
||
|
|
// 用于服务代理用户访问其他服务
|
||
|
|
func (p *CASProvider) GenerateProxyTicket(ctx context.Context, proxyGrantingTicket, targetService string) (string, error) {
|
||
|
|
params := url.Values{}
|
||
|
|
params.Set("targetService", targetService)
|
||
|
|
|
||
|
|
proxyURL := fmt.Sprintf("%s/p3/proxy?%s&pgt=%s",
|
||
|
|
p.serverURL, params.Encode(), proxyGrantingTicket)
|
||
|
|
|
||
|
|
resp, err := fetchCASResponse(ctx, proxyURL)
|
||
|
|
if err != nil {
|
||
|
|
return "", err
|
||
|
|
}
|
||
|
|
|
||
|
|
// 解析代理票据
|
||
|
|
if start := strings.Index(resp, "<proxyTicket>"); start != -1 {
|
||
|
|
end := strings.Index(resp[start:], "</proxyTicket>")
|
||
|
|
if end != -1 {
|
||
|
|
return resp[start+12 : start+end], nil
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
return "", fmt.Errorf("failed to parse proxy ticket from response")
|
||
|
|
}
|
||
|
|
|
||
|
|
// fetchCASResponse 从 CAS 服务器获取响应
|
||
|
|
func fetchCASResponse(ctx context.Context, url string) (string, error) {
|
||
|
|
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||
|
|
if err != nil {
|
||
|
|
return "", err
|
||
|
|
}
|
||
|
|
req.Header.Set("Accept", "application/xml")
|
||
|
|
|
||
|
|
client := &http.Client{Timeout: 10 * time.Second}
|
||
|
|
resp, err := client.Do(req)
|
||
|
|
if err != nil {
|
||
|
|
return "", err
|
||
|
|
}
|
||
|
|
defer resp.Body.Close()
|
||
|
|
|
||
|
|
body, err := io.ReadAll(resp.Body)
|
||
|
|
if err != nil {
|
||
|
|
return "", err
|
||
|
|
}
|
||
|
|
|
||
|
|
return string(body), nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// GenerateCASServiceTicket 生成 CAS 服务票据 (供 CAS 服务器使用)
|
||
|
|
// 这个方法供实际的 CAS 服务器实现调用
|
||
|
|
func GenerateCASServiceTicket(service string, userID int64, username string) (*CASServiceTicket, error) {
|
||
|
|
ticketBytes := make([]byte, 32)
|
||
|
|
if _, err := rand.Read(ticketBytes); err != nil {
|
||
|
|
return nil, fmt.Errorf("failed to generate ticket: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
return &CASServiceTicket{
|
||
|
|
Ticket: "ST-" + base64.URLEncoding.EncodeToString(ticketBytes)[:32],
|
||
|
|
Service: service,
|
||
|
|
UserID: userID,
|
||
|
|
Username: username,
|
||
|
|
IssuedAt: time.Now(),
|
||
|
|
Expiry: time.Now().Add(5 * time.Minute),
|
||
|
|
}, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// IsExpired 检查票据是否过期
|
||
|
|
func (t *CASServiceTicket) IsExpired() bool {
|
||
|
|
return time.Now().After(t.Expiry)
|
||
|
|
}
|
||
|
|
|
||
|
|
// GetDuration 返回票据有效时长
|
||
|
|
func (t *CASServiceTicket) GetDuration() time.Duration {
|
||
|
|
return t.Expiry.Sub(t.IssuedAt)
|
||
|
|
}
|