Files

257 lines
6.5 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package providers
import (
"context"
"crypto"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"net/http"
"net/url"
"sort"
"strings"
"time"
)
// AlipayProvider is the Alipay OAuth provider.
// Alipay gateway requests are signed with RSA2 (SHA256withRSA).
type AlipayProvider struct {
	// AppID is the Alipay application ID, sent as the app_id parameter.
	AppID string
	// PrivateKey is the RSA2 private key used to sign gateway requests: a PEM
	// block in PKCS#8 or PKCS#1 form, or a bare base64 body without PEM
	// headers. When empty, gateway requests are sent without a signature.
	PrivateKey string
	// RedirectURI is the OAuth callback URL passed as redirect_uri in the
	// authorization URL.
	RedirectURI string
	// IsSandbox selects the Alipay sandbox environment instead of production.
	IsSandbox bool
}
// AlipayTokenResponse holds the fields of the
// "alipay_system_oauth_token_response" object returned by the
// alipay.system.oauth.token gateway method.
type AlipayTokenResponse struct {
	UserID       string `json:"user_id"`
	AccessToken  string `json:"access_token"`
	ExpiresIn    int    `json:"expires_in"`
	RefreshToken string `json:"refresh_token"`
	// OpenID is the app-scoped user identifier. Per the Alipay API reference,
	// newer applications receive open_id in place of user_id, so callers
	// should fall back to it when UserID is empty.
	OpenID string `json:"open_id"`
	// ReExpiresIn is the refresh-token lifetime reported by Alipay (seconds,
	// per the Alipay API reference).
	ReExpiresIn int `json:"re_expires_in"`
	// AuthStart is the time the authorization began, as formatted by Alipay.
	AuthStart string `json:"auth_start"`
}
// AlipayUserInfo holds the user profile fields of the
// "alipay_user_info_share_response" object returned by the
// alipay.user.info.share gateway method.
type AlipayUserInfo struct {
	UserID   string `json:"user_id"`
	Nickname string `json:"nick_name"`
	Avatar   string `json:"avatar"`
	Gender   string `json:"gender"`
	// OpenID is the app-scoped user identifier; per the Alipay API reference,
	// newer applications receive open_id in place of user_id.
	OpenID   string `json:"open_id"`
	Province string `json:"province"`
	City     string `json:"city"`
}
// NewAlipayProvider returns an AlipayProvider configured with the given
// application ID, RSA2 private key, OAuth redirect URI and sandbox flag.
func NewAlipayProvider(appID, privateKey, redirectURI string, isSandbox bool) *AlipayProvider {
	p := new(AlipayProvider)
	p.AppID = appID
	p.PrivateKey = privateKey
	p.RedirectURI = redirectURI
	p.IsSandbox = isSandbox
	return p
}
// getGateway returns the Alipay OpenAPI gateway endpoint: the sandbox one when
// IsSandbox is set, otherwise the production one.
func (a *AlipayProvider) getGateway() string {
	const (
		productionGateway = "https://openapi.alipay.com/gateway.do"
		sandboxGateway    = "https://openapi-sandbox.dl.alipaydev.com/gateway.do"
	)
	if !a.IsSandbox {
		return productionGateway
	}
	return sandboxGateway
}
// GetAuthURL builds the Alipay web authorization URL (scope auth_user) that
// the user is redirected to in order to grant access.
//
// state is echoed back by Alipay to RedirectURI and should be an unguessable
// per-request value. All query values are URL-escaped. The authorization host
// follows IsSandbox so that it matches the gateway chosen by getGateway:
// sandbox application IDs are authorized on the sandbox page, not production.
// The returned error is always nil; it is kept for interface compatibility.
func (a *AlipayProvider) GetAuthURL(state string) (string, error) {
	authHost := "https://openauth.alipay.com"
	if a.IsSandbox {
		authHost = "https://openauth-sandbox.dl.alipaydev.com"
	}
	authURL := fmt.Sprintf(
		"%s/oauth2/publicAppAuthorize.htm?app_id=%s&scope=auth_user&redirect_uri=%s&state=%s",
		authHost,
		url.QueryEscape(a.AppID),
		url.QueryEscape(a.RedirectURI),
		url.QueryEscape(state),
	)
	return authURL, nil
}
// ExchangeCode exchanges an OAuth authorization code for an access token via
// the alipay.system.oauth.token gateway method.
//
// An error is returned if code is empty, if the request cannot be signed,
// built or sent, if the gateway replies with a non-2xx HTTP status or an
// "error_response" object, if the method response carries a result code other
// than "10000", or if the response contains no access token.
func (a *AlipayProvider) ExchangeCode(ctx context.Context, code string) (*AlipayTokenResponse, error) {
	if code == "" {
		return nil, fmt.Errorf("authorization code is empty")
	}
	envelope, err := a.callGateway(ctx, "alipay.system.oauth.token", map[string]string{
		"grant_type": "authorization_code",
		"code":       code,
	})
	if err != nil {
		return nil, err
	}
	tokenData, ok := envelope["alipay_system_oauth_token_response"]
	if !ok {
		return nil, fmt.Errorf("invalid alipay response structure")
	}
	var tokenResp AlipayTokenResponse
	if err := json.Unmarshal(tokenData, &tokenResp); err != nil {
		return nil, fmt.Errorf("parse token response failed: %w", err)
	}
	if err := alipayResultError(tokenData); err != nil {
		return nil, err
	}
	// A response object without a token is not a usable success.
	if tokenResp.AccessToken == "" {
		return nil, fmt.Errorf("alipay token response contains no access_token")
	}
	return &tokenResp, nil
}

// GetUserInfo fetches the authorized user's profile via the
// alipay.user.info.share gateway method, passing accessToken as auth_token.
//
// Besides request and parsing failures, an error is returned when the gateway
// replies with a non-2xx status or an "error_response" object, or when the
// method response carries a result code other than "10000".
func (a *AlipayProvider) GetUserInfo(ctx context.Context, accessToken string) (*AlipayUserInfo, error) {
	if accessToken == "" {
		return nil, fmt.Errorf("access token is empty")
	}
	envelope, err := a.callGateway(ctx, "alipay.user.info.share", map[string]string{
		"auth_token": accessToken,
	})
	if err != nil {
		return nil, err
	}
	userData, ok := envelope["alipay_user_info_share_response"]
	if !ok {
		return nil, fmt.Errorf("invalid alipay user info response")
	}
	var userInfo AlipayUserInfo
	if err := json.Unmarshal(userData, &userInfo); err != nil {
		return nil, fmt.Errorf("parse user info failed: %w", err)
	}
	// Business failures are reported inside this object via code/sub_code.
	// Without this check the caller would receive an AlipayUserInfo with
	// empty fields and a nil error.
	if err := alipayResultError(userData); err != nil {
		return nil, err
	}
	return &userInfo, nil
}

// callGateway sends one Alipay OpenAPI request and returns the top-level JSON
// object of the reply, keyed by response name.
//
// method is the API method name; bizParams are merged over the common request
// parameters (app_id, method, charset, sign_type, timestamp, version) before
// signing. A non-2xx HTTP status or an "error_response" entry in the reply is
// turned into an error.
func (a *AlipayProvider) callGateway(ctx context.Context, method string, bizParams map[string]string) (map[string]json.RawMessage, error) {
	params := map[string]string{
		"app_id":    a.AppID,
		"method":    method,
		"charset":   "UTF-8",
		"sign_type": "RSA2",
		// The official Alipay SDKs format timestamp in Beijing time (UTC+8).
		// Using the host's local zone would be hours off on UTC servers.
		"timestamp": time.Now().In(time.FixedZone("UTC+8", 8*60*60)).Format("2006-01-02 15:04:05"),
		"version":   "1.0",
	}
	for k, v := range bizParams {
		params[k] = v
	}
	// NOTE(review): requests go out unsigned when no private key is
	// configured; presumably only useful for local testing — confirm.
	if a.PrivateKey != "" {
		sign, err := a.signParams(params)
		if err != nil {
			return nil, fmt.Errorf("sign failed: %w", err)
		}
		params["sign"] = sign
	}

	form := make(url.Values, len(params))
	for k, v := range params {
		form.Set(k, v)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, a.getGateway(), strings.NewReader(form.Encode()))
	if err != nil {
		return nil, fmt.Errorf("create request failed: %w", err)
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	client := &http.Client{Timeout: 10 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("request failed: %w", err)
	}
	defer resp.Body.Close()

	body, err := readOAuthResponseBody(resp)
	if err != nil {
		return nil, fmt.Errorf("read response failed: %w", err)
	}
	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		return nil, fmt.Errorf("alipay gateway returned HTTP status %d", resp.StatusCode)
	}

	var envelope map[string]json.RawMessage
	if err := json.Unmarshal(body, &envelope); err != nil {
		return nil, fmt.Errorf("parse response failed: %w", err)
	}
	// Gateway-level failures (e.g. an invalid code or signature) arrive under
	// "error_response" instead of the method's own response key.
	if errData, ok := envelope["error_response"]; ok {
		if err := alipayResultError(errData); err != nil {
			return nil, err
		}
		return nil, fmt.Errorf("alipay error response: %s", errData)
	}
	return envelope, nil
}

// alipayResultError inspects the common result fields of an Alipay response
// object. It returns nil when "code" is absent (the OAuth token method omits
// it on success) or equals "10000". Otherwise it returns an error carrying
// code, msg, sub_code and sub_msg.
func alipayResultError(data json.RawMessage) error {
	var result struct {
		Code    string `json:"code"`
		Msg     string `json:"msg"`
		SubCode string `json:"sub_code"`
		SubMsg  string `json:"sub_msg"`
	}
	if err := json.Unmarshal(data, &result); err != nil {
		return fmt.Errorf("parse alipay result failed: %w", err)
	}
	if result.Code == "" || result.Code == "10000" {
		return nil
	}
	return fmt.Errorf("alipay error: code=%s msg=%s sub_code=%s sub_msg=%s",
		result.Code, result.Msg, result.SubCode, result.SubMsg)
}
// signParams computes the RSA2 (SHA256withRSA) signature for an Alipay gateway
// request.
//
// The content to sign is built from params by dropping the "sign" entry and
// every entry with an empty value. The remaining keys are sorted in ascending
// order and joined as "k1=v1&k2=v2..." with raw (not URL-encoded) values. The
// result is the standard-base64 PKCS#1 v1.5 signature of the SHA-256 digest,
// made with the key parsed from a.PrivateKey.
func (a *AlipayProvider) signParams(params map[string]string) (string, error) {
	keys := make([]string, 0, len(params))
	for k, v := range params {
		// Alipay leaves empty-valued parameters out of the signed content.
		// Including "k=" would yield a signature the gateway cannot reproduce.
		if k == "sign" || v == "" {
			continue
		}
		keys = append(keys, k)
	}
	sort.Strings(keys)

	var content strings.Builder
	for i, k := range keys {
		if i > 0 {
			content.WriteByte('&')
		}
		content.WriteString(k)
		content.WriteByte('=')
		content.WriteString(params[k])
	}

	privKey, err := parseAlipayPrivateKey(a.PrivateKey)
	if err != nil {
		return "", fmt.Errorf("parse private key: %w", err)
	}

	digest := sha256.Sum256([]byte(content.String()))
	signature, err := rsa.SignPKCS1v15(rand.Reader, privKey, crypto.SHA256, digest[:])
	if err != nil {
		return "", fmt.Errorf("rsa sign: %w", err)
	}
	return base64.StdEncoding.EncodeToString(signature), nil
}
// parseAlipayPrivateKey decodes the RSA private key used for Alipay RSA2
// signing.
//
// pemStr may be a PEM block (PKCS#8 "PRIVATE KEY" or PKCS#1 "RSA PRIVATE KEY")
// or just its base64 body. A headerless body is wrapped in a PKCS#8 header
// before decoding. Surrounding whitespace is ignored. The DER payload is
// parsed as PKCS#8 first and PKCS#1 second. A PKCS#8 key that is not RSA is
// rejected.
func parseAlipayPrivateKey(pemStr string) (*rsa.PrivateKey, error) {
	// Keys pasted into config files or environment variables often carry
	// indentation or trailing newlines. pem.Decode only recognizes
	// "-----BEGIN" at the very start of its input or right after a newline.
	pemStr = strings.TrimSpace(pemStr)
	if pemStr == "" {
		return nil, fmt.Errorf("private key is empty")
	}
	if !strings.Contains(pemStr, "-----BEGIN") {
		pemStr = "-----BEGIN PRIVATE KEY-----\n" + pemStr + "\n-----END PRIVATE KEY-----"
	}
	block, _ := pem.Decode([]byte(pemStr))
	if block == nil {
		return nil, fmt.Errorf("failed to decode PEM block")
	}

	key, pkcs8Err := x509.ParsePKCS8PrivateKey(block.Bytes)
	if pkcs8Err == nil {
		rsaKey, ok := key.(*rsa.PrivateKey)
		if !ok {
			return nil, fmt.Errorf("not an RSA private key")
		}
		return rsaKey, nil
	}

	rsaKey, pkcs1Err := x509.ParsePKCS1PrivateKey(block.Bytes)
	if pkcs1Err != nil {
		// Report both attempts; the PKCS#1 error alone hides why PKCS#8 failed.
		return nil, fmt.Errorf("parse as PKCS#8: %v; parse as PKCS#1: %w", pkcs8Err, pkcs1Err)
	}
	return rsaKey, nil
}