package auth import ( "context" "crypto/rand" "encoding/base64" "fmt" "io" "net/http" "net/url" "strings" "time" ) // CASProvider CAS (Central Authentication Service) 提供者 // CAS 是一种单点登录协议,用户只需登录一次即可访问多个应用 type CASProvider struct { serverURL string serviceURL string } // CASServiceTicket CAS 服务票据 type CASServiceTicket struct { Ticket string Service string UserID int64 Username string IssuedAt time.Time Expiry time.Time } // NewCASProvider 创建 CAS 提供者 func NewCASProvider(serverURL, serviceURL string) *CASProvider { return &CASProvider{ serverURL: strings.TrimSuffix(serverURL, "/"), serviceURL: serviceURL, } } // BuildLoginURL 构建 CAS 登录 URL // 用于重定向用户到 CAS 登录页面 func (p *CASProvider) BuildLoginURL(renew, gateway bool) string { params := url.Values{} params.Set("service", p.serviceURL) if renew { params.Set("renew", "true") } if gateway { params.Set("gateway", "true") } return fmt.Sprintf("%s/login?%s", p.serverURL, params.Encode()) } // BuildLogoutURL 构建 CAS 登出 URL func (p *CASProvider) BuildLogoutURL(url string) string { if url != "" { return fmt.Sprintf("%s/logout?service=%s", p.serverURL, url) } return fmt.Sprintf("%s/logout", p.serverURL) } // CASValidationResponse CAS 票据验证响应 type CASValidationResponse struct { Success bool UserID int64 Username string ErrorCode string ErrorMsg string } // ValidateTicket 验证 CAS 票据 // 向 CAS 服务器发送 ticket 验证请求 func (p *CASProvider) ValidateTicket(ctx context.Context, ticket string) (*CASValidationResponse, error) { if ticket == "" { return &CASValidationResponse{ Success: false, ErrorCode: "INVALID_REQUEST", ErrorMsg: "ticket is required", }, nil } params := url.Values{} params.Set("service", p.serviceURL) params.Set("ticket", ticket) validateURL := fmt.Sprintf("%s/p3/serviceValidate?%s", p.serverURL, params.Encode()) resp, err := fetchCASResponse(ctx, validateURL) if err != nil { return nil, fmt.Errorf("CAS validation request failed: %w", err) } return p.parseServiceValidateResponse(resp) } // parseServiceValidateResponse 解析 CAS serviceValidate 响应 // CAS 1.0 和 CAS 2.0 使用不同的响应格式 func (p *CASProvider) parseServiceValidateResponse(xml string) (*CASValidationResponse, error) { resp := &CASValidationResponse{Success: false} // 检查是否包含 authenticationSuccess 元素 if strings.Contains(xml, "") { resp.Success = true // 解析用户名 if start := strings.Index(xml, ""); start != -1 { end := strings.Index(xml[start:], "") if end != -1 { resp.Username = xml[start+6 : start+end] } } // 解析用户 ID (CAS 2.0) if start := strings.Index(xml, ""); start != -1 { end := strings.Index(xml[start:], "") if end != -1 { userIDStr := xml[start+8 : start+end] var userID int64 fmt.Sscanf(userIDStr, "%d", &userID) resp.UserID = userID } } } else if strings.Contains(xml, "") { resp.Success = false // 解析错误码 if start := strings.Index(xml, "code=\""); start != -1 { start += 6 end := strings.Index(xml[start:], "\"") if end != -1 { resp.ErrorCode = xml[start : start+end] } } // 解析错误消息 if start := strings.Index(xml, "") if end != -1 { resp.ErrorMsg = xml[start+9 : start+end] } } } return resp, nil } // GenerateProxyTicket 生成代理票据 (CAS 2.0) // 用于服务代理用户访问其他服务 func (p *CASProvider) GenerateProxyTicket(ctx context.Context, proxyGrantingTicket, targetService string) (string, error) { params := url.Values{} params.Set("targetService", targetService) proxyURL := fmt.Sprintf("%s/p3/proxy?%s&pgt=%s", p.serverURL, params.Encode(), proxyGrantingTicket) resp, err := fetchCASResponse(ctx, proxyURL) if err != nil { return "", err } // 解析代理票据 if start := strings.Index(resp, ""); start != -1 { end := strings.Index(resp[start:], "") if end != -1 { return resp[start+12 : start+end], nil } } return "", fmt.Errorf("failed to parse proxy ticket from response") } // fetchCASResponse 从 CAS 服务器获取响应 func fetchCASResponse(ctx context.Context, url string) (string, error) { req, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { return "", err } req.Header.Set("Accept", "application/xml") client := &http.Client{Timeout: 10 * time.Second} resp, err := client.Do(req) if err != nil { return "", err } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { return "", err } return string(body), nil } // GenerateCASServiceTicket 生成 CAS 服务票据 (供 CAS 服务器使用) // 这个方法供实际的 CAS 服务器实现调用 func GenerateCASServiceTicket(service string, userID int64, username string) (*CASServiceTicket, error) { ticketBytes := make([]byte, 32) if _, err := rand.Read(ticketBytes); err != nil { return nil, fmt.Errorf("failed to generate ticket: %w", err) } return &CASServiceTicket{ Ticket: "ST-" + base64.URLEncoding.EncodeToString(ticketBytes)[:32], Service: service, UserID: userID, Username: username, IssuedAt: time.Now(), Expiry: time.Now().Add(5 * time.Minute), }, nil } // IsExpired 检查票据是否过期 func (t *CASServiceTicket) IsExpired() bool { return time.Now().After(t.Expiry) } // GetDuration 返回票据有效时长 func (t *CASServiceTicket) GetDuration() time.Duration { return t.Expiry.Sub(t.IssuedAt) }