197 lines
4.5 KiB
Go
197 lines
4.5 KiB
Go
|
|
package auth
|
|||
|
|
|
|||
|
|
import (
	"crypto/rand"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"
	"sync"
	"time"

	"golang.org/x/oauth2"
)
|
|||
|
|
|
|||
|
|
// oauthStateTTL is how long a state value issued by GenerateState stays valid.
const oauthStateTTL = 10 * time.Minute

// StateStore keeps the OAuth state values issued by GenerateState until they
// are consumed by ValidateState or expire.
type StateStore struct {
	// states maps each issued state value to the instant it expires.
	states map[string]time.Time
	// lastSweep records when expired entries were last pruned, so that
	// GenerateState prunes at most once per oauthStateTTL.
	lastSweep time.Time
	mu        sync.RWMutex
}

// stateStore is the package-wide store shared by GenerateState,
// ValidateState and CleanupStates.
var stateStore = &StateStore{
	states: make(map[string]time.Time),
}

// pruneExpiredLocked deletes every state whose expiry is before now and
// records now as the last sweep time. The caller must hold s.mu for writing.
func (s *StateStore) pruneExpiredLocked(now time.Time) {
	for state, expiresAt := range s.states {
		if now.After(expiresAt) {
			delete(s.states, state) // deleting during range is safe in Go
		}
	}
	s.lastSweep = now
}

// GenerateState creates a random OAuth state value (32 bytes from
// crypto/rand, base64 URL-encoded) and records it as valid for
// oauthStateTTL.
//
// Expired entries are also pruned here, at most once per oauthStateTTL.
// States from logins that are never completed therefore do not accumulate
// even if nothing schedules CleanupStates. The cost is one O(n) sweep per
// TTL window.
func GenerateState() (string, error) {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		return "", fmt.Errorf("generate state failed: %w", err)
	}
	state := base64.URLEncoding.EncodeToString(b)

	now := time.Now()
	stateStore.mu.Lock()
	defer stateStore.mu.Unlock()
	if now.Sub(stateStore.lastSweep) >= oauthStateTTL {
		stateStore.pruneExpiredLocked(now)
	}
	stateStore.states[state] = now.Add(oauthStateTTL)
	return state, nil
}

// ValidateState reports whether state was issued by GenerateState and has
// not yet expired.
//
// States are single-use. Any lookup that finds the state removes it from the
// store, whether or not it had expired, so a replayed state is rejected.
func ValidateState(state string) bool {
	stateStore.mu.Lock()
	defer stateStore.mu.Unlock()

	expiresAt, ok := stateStore.states[state]
	if !ok {
		return false
	}
	delete(stateStore.states, state)
	return !time.Now().After(expiresAt)
}

// CleanupStates removes every expired state from the store. It takes the
// store's write lock, so it may run concurrently with GenerateState and
// ValidateState.
func CleanupStates() {
	stateStore.mu.Lock()
	defer stateStore.mu.Unlock()
	stateStore.pruneExpiredLocked(time.Now())
}
|
|||
|
|
|
|||
|
|
// HTTPClient is the shared client used by the request helpers below. Every
// request it makes is bounded by a 30-second overall timeout.
var HTTPClient = &http.Client{
	Timeout: 30 * time.Second,
}

const (
	// maxErrorBodySnippet caps how many bytes of a non-200 response body are
	// quoted in the returned error.
	maxErrorBodySnippet = 512
	// maxDrainBytes caps how much unread body is discarded before closing.
	// Small leftovers are consumed so the connection can be reused, but a
	// huge body cannot stall the caller.
	maxDrainBytes = 4 << 10
)

// Get issues a GET request to rawURL using HTTPClient. As with
// http.Client.Get, the caller must close the returned response body.
func Get(rawURL string) (*http.Response, error) {
	return HTTPClient.Get(rawURL)
}

// PostForm POSTs data, form-encoded, to rawURL using HTTPClient. As with
// http.Client.PostForm, the caller must close the returned response body.
func PostForm(rawURL string, data url.Values) (*http.Response, error) {
	return HTTPClient.PostForm(rawURL, data)
}

// GetJSON sends a GET request to rawURL and decodes a 200 OK JSON response
// body into result, which must be a pointer. Any other status is returned as
// an error carrying the status code and a bounded snippet of the body.
func GetJSON(rawURL string, result interface{}) error {
	resp, err := Get(rawURL)
	if err != nil {
		return err
	}
	return decodeJSONResponse(resp, result)
}

// PostFormJSON POSTs data, form-encoded, to rawURL and decodes a 200 OK JSON
// response body into result, which must be a pointer. Any other status is
// returned as an error carrying the status code and a bounded snippet of the
// body.
func PostFormJSON(rawURL string, data url.Values, result interface{}) error {
	resp, err := PostForm(rawURL, data)
	if err != nil {
		return err
	}
	return decodeJSONResponse(resp, result)
}

// decodeJSONResponse checks for 200 OK, decodes resp.Body into result, and
// always drains (bounded) and closes the body.
func decodeJSONResponse(resp *http.Response, result interface{}) error {
	defer func() {
		// net/http can only reuse the connection once the body has been read
		// to EOF and closed. json.Decoder may stop before EOF, so drain the
		// remainder, up to a limit.
		_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, maxDrainBytes))
		resp.Body.Close()
	}()

	if resp.StatusCode != http.StatusOK {
		// OAuth providers usually explain failures in the body
		// (e.g. {"error":"invalid_grant"}), so surface part of it.
		snippet, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodySnippet))
		if s := strings.TrimSpace(string(snippet)); s != "" {
			return fmt.Errorf("HTTP request failed with status %d: %q", resp.StatusCode, s)
		}
		return fmt.Errorf("HTTP request failed with status %d", resp.StatusCode)
	}

	if err := json.NewDecoder(resp.Body).Decode(result); err != nil {
		return fmt.Errorf("decode JSON response: %w", err)
	}
	return nil
}
|
|||
|
|
|
|||
|
|
// BuildAuthURL builds a standard OAuth 2.0 authorization-code request URL
// (response_type=code) from baseURL and the given parameters.
//
// Query parameters already present in baseURL are kept. Note that
// url.URL.Query silently drops malformed pairs. client_id, redirect_uri,
// scope, state and response_type overwrite any existing values with the same
// name, and all values are URL-encoded.
//
// If baseURL cannot be parsed at all, the encoded parameters are appended to
// it verbatim rather than panicking on a nil *url.URL.
func BuildAuthURL(baseURL, clientID, redirectURI, scope, state string) string {
	setParams := func(q url.Values) {
		q.Set("client_id", clientID)
		q.Set("redirect_uri", redirectURI)
		q.Set("scope", scope)
		q.Set("state", state)
		q.Set("response_type", "code")
	}

	u, err := url.Parse(baseURL)
	if err != nil {
		q := url.Values{}
		setParams(q)
		sep := "?"
		if strings.Contains(baseURL, "?") {
			sep = "&"
		}
		return baseURL + sep + q.Encode()
	}

	q := u.Query()
	setParams(q)
	u.RawQuery = q.Encode()
	return u.String()
}
|
|||
|
|
|
|||
|
|
// ParseAccessTokenResponse 解析访问令牌响应
|
|||
|
|
func ParseAccessTokenResponse(resp []byte) (*OAuthToken, error) {
|
|||
|
|
var result struct {
|
|||
|
|
AccessToken string `json:"access_token"`
|
|||
|
|
RefreshToken string `json:"refresh_token"`
|
|||
|
|
ExpiresIn int64 `json:"expires_in"`
|
|||
|
|
TokenType string `json:"token_type"`
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if err := json.Unmarshal(resp, &result); err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return &OAuthToken{
|
|||
|
|
AccessToken: result.AccessToken,
|
|||
|
|
RefreshToken: result.RefreshToken,
|
|||
|
|
ExpiresIn: result.ExpiresIn,
|
|||
|
|
TokenType: result.TokenType,
|
|||
|
|
}, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ParseQueryAccessToken extracts access_token from a form-encoded response
// body such as "access_token=abc&expires_in=7776000". Some APIs return this
// format with a text/plain content type.
//
// It returns an error if the body cannot be parsed or carries no
// access_token, so that an error payload is not mistaken for success. If the
// body contains error / error_description fields, they are included in the
// error.
func ParseQueryAccessToken(body string) (accessToken string, err error) {
	values, err := url.ParseQuery(body)
	if err != nil {
		return "", fmt.Errorf("parse access token response: %w", err)
	}

	accessToken = values.Get("access_token")
	if accessToken != "" {
		return accessToken, nil
	}

	if code := values.Get("error"); code != "" {
		if desc := values.Get("error_description"); desc != "" {
			return "", fmt.Errorf("access token response error %s: %s", code, desc)
		}
		return "", fmt.Errorf("access token response error %s", code)
	}
	return "", fmt.Errorf("access token response has no access_token")
}
|
|||
|
|
|
|||
|
|
// ParseJSONPResponse decodes a JSONP response such as
// `callback( {"client_id":"1","openid":"X"} );` (used by QQ and similar
// platforms) into a generic map.
//
// The payload is the text between the first "(" and the last ")". An error
// is returned when either parenthesis is missing or they are out of order,
// or when the payload is not a JSON object.
func ParseJSONPResponse(jsonp string) (map[string]interface{}, error) {
	start := strings.Index(jsonp, "(")
	end := strings.LastIndex(jsonp, ")")
	// end <= start also covers ")...(" inputs, which would otherwise make
	// the slice below panic.
	if start == -1 || end <= start {
		return nil, fmt.Errorf("invalid JSONP format")
	}

	var result map[string]interface{}
	if err := json.Unmarshal([]byte(jsonp[start+1:end]), &result); err != nil {
		return nil, fmt.Errorf("decode JSONP payload: %w", err)
	}
	return result, nil
}
|
|||
|
|
|
|||
|
|
// ToOAuth2Config 转换为oauth2.Config
|
|||
|
|
func ToOAuth2Config(config *OAuthConfig) *oauth2.Config {
|
|||
|
|
return &oauth2.Config{
|
|||
|
|
ClientID: config.ClientID,
|
|||
|
|
ClientSecret: config.ClientSecret,
|
|||
|
|
RedirectURL: config.RedirectURI,
|
|||
|
|
Scopes: strings.Split(config.Scope, ","),
|
|||
|
|
Endpoint: oauth2.Endpoint{
|
|||
|
|
AuthURL: config.AuthURL,
|
|||
|
|
TokenURL: config.TokenURL,
|
|||
|
|
},
|
|||
|
|
}
|
|||
|
|
}
|