feat: backend core - auth, user, role, permission, device, webhook, monitoring, cache, repository, service, middleware, API handlers
This commit is contained in:
506
internal/auth/oauth.go
Normal file
506
internal/auth/oauth.go
Normal file
@@ -0,0 +1,506 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
"github.com/user-management-system/internal/auth/providers"
|
||||
)
|
||||
|
||||
// OAuthProvider OAuth提供商类型
|
||||
type OAuthProvider string
|
||||
|
||||
const (
|
||||
OAuthProviderWeChat OAuthProvider = "wechat"
|
||||
OAuthProviderQQ OAuthProvider = "qq"
|
||||
OAuthProviderWeibo OAuthProvider = "weibo"
|
||||
OAuthProviderGoogle OAuthProvider = "google"
|
||||
OAuthProviderFacebook OAuthProvider = "facebook"
|
||||
OAuthProviderTwitter OAuthProvider = "twitter"
|
||||
OAuthProviderGitHub OAuthProvider = "github"
|
||||
OAuthProviderAlipay OAuthProvider = "alipay"
|
||||
OAuthProviderDouyin OAuthProvider = "douyin"
|
||||
)
|
||||
|
||||
// OAuthUser OAuth用户信息
|
||||
type OAuthUser struct {
|
||||
Provider OAuthProvider `json:"provider"`
|
||||
OpenID string `json:"open_id"`
|
||||
UnionID string `json:"union_id,omitempty"`
|
||||
Nickname string `json:"nickname"`
|
||||
Avatar string `json:"avatar"`
|
||||
Gender string `json:"gender,omitempty"`
|
||||
Email string `json:"email,omitempty"`
|
||||
Phone string `json:"phone,omitempty"`
|
||||
Extra map[string]interface{} `json:"extra,omitempty"`
|
||||
}
|
||||
|
||||
// OAuthToken OAuth令牌
|
||||
type OAuthToken struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
TokenType string `json:"token_type"`
|
||||
OpenID string `json:"open_id,omitempty"` // 微信等需要 openid
|
||||
}
|
||||
|
||||
// OAuthConfig OAuth配置
|
||||
type OAuthConfig struct {
|
||||
ClientID string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
Scope string `json:"scope"`
|
||||
AuthURL string `json:"auth_url"`
|
||||
TokenURL string `json:"token_url"`
|
||||
UserInfoURL string `json:"user_info_url"`
|
||||
}
|
||||
|
||||
// OAuthManager OAuth管理器接口
|
||||
type OAuthManager interface {
|
||||
// GetAuthURL 获取授权URL
|
||||
GetAuthURL(provider OAuthProvider, state string) (string, error)
|
||||
|
||||
// ExchangeCode 换取访问令牌
|
||||
ExchangeCode(provider OAuthProvider, code string) (*OAuthToken, error)
|
||||
|
||||
// GetUserInfo 获取用户信息
|
||||
GetUserInfo(provider OAuthProvider, token *OAuthToken) (*OAuthUser, error)
|
||||
|
||||
// ValidateToken 验证令牌
|
||||
ValidateToken(token string) (bool, error)
|
||||
|
||||
// GetConfig 获取OAuth配置
|
||||
GetConfig(provider OAuthProvider) (*OAuthConfig, bool)
|
||||
|
||||
// GetEnabledProviders 获取已启用的OAuth提供商
|
||||
GetEnabledProviders() []OAuthProviderInfo
|
||||
}
|
||||
|
||||
// OAuthProviderInfo OAuth提供商信息
|
||||
type OAuthProviderInfo struct {
|
||||
Provider OAuthProvider `json:"provider"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// providerEntry 内部 provider 条目
|
||||
type providerEntry struct {
|
||||
config *OAuthConfig
|
||||
google *providers.GoogleProvider
|
||||
wechat *providers.WeChatProvider
|
||||
wechatRedir string
|
||||
qq *providers.QQProvider
|
||||
github *providers.GitHubProvider
|
||||
alipay *providers.AlipayProvider
|
||||
douyin *providers.DouyinProvider
|
||||
}
|
||||
|
||||
// DefaultOAuthManager 默认OAuth管理器(集成真实 provider HTTP 调用)
|
||||
type DefaultOAuthManager struct {
|
||||
entries map[OAuthProvider]*providerEntry
|
||||
}
|
||||
|
||||
// NewOAuthManager 创建OAuth管理器
|
||||
func NewOAuthManager() *DefaultOAuthManager {
|
||||
return &DefaultOAuthManager{
|
||||
entries: make(map[OAuthProvider]*providerEntry),
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterProvider 注册OAuth提供商(保留旧接口,仅存储配置)
|
||||
func (m *DefaultOAuthManager) RegisterProvider(provider OAuthProvider, config *OAuthConfig) {
|
||||
entry := &providerEntry{config: config}
|
||||
|
||||
switch provider {
|
||||
case OAuthProviderGoogle:
|
||||
entry.google = providers.NewGoogleProvider(config.ClientID, config.ClientSecret, config.RedirectURI)
|
||||
case OAuthProviderWeChat:
|
||||
entry.wechat = providers.NewWeChatProvider(config.ClientID, config.ClientSecret, "web")
|
||||
entry.wechatRedir = config.RedirectURI
|
||||
case OAuthProviderQQ:
|
||||
entry.qq = providers.NewQQProvider(config.ClientID, config.ClientSecret, config.RedirectURI)
|
||||
case OAuthProviderGitHub:
|
||||
entry.github = providers.NewGitHubProvider(config.ClientID, config.ClientSecret, config.RedirectURI)
|
||||
case OAuthProviderAlipay:
|
||||
// 支付宝使用 ClientID 存储 AppID,ClientSecret 存储 RSA 私钥
|
||||
entry.alipay = providers.NewAlipayProvider(config.ClientID, config.ClientSecret, config.RedirectURI, false)
|
||||
case OAuthProviderDouyin:
|
||||
entry.douyin = providers.NewDouyinProvider(config.ClientID, config.ClientSecret, config.RedirectURI)
|
||||
}
|
||||
|
||||
m.entries[provider] = entry
|
||||
}
|
||||
|
||||
// GetConfig 获取OAuth配置
|
||||
func (m *DefaultOAuthManager) GetConfig(provider OAuthProvider) (*OAuthConfig, bool) {
|
||||
entry, ok := m.entries[provider]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return entry.config, true
|
||||
}
|
||||
|
||||
// GetAuthURL 获取授权URL(使用真实 provider 实现)
|
||||
func (m *DefaultOAuthManager) GetAuthURL(provider OAuthProvider, state string) (string, error) {
|
||||
entry, ok := m.entries[provider]
|
||||
if !ok {
|
||||
return "", ErrOAuthProviderNotSupported
|
||||
}
|
||||
|
||||
switch provider {
|
||||
case OAuthProviderGoogle:
|
||||
if entry.google != nil {
|
||||
resp, err := entry.google.GetAuthURL(state)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return resp.URL, nil
|
||||
}
|
||||
case OAuthProviderWeChat:
|
||||
if entry.wechat != nil {
|
||||
resp, err := entry.wechat.GetAuthURL(entry.wechatRedir, state)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return resp.URL, nil
|
||||
}
|
||||
case OAuthProviderQQ:
|
||||
if entry.qq != nil {
|
||||
resp, err := entry.qq.GetAuthURL(state)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return resp.URL, nil
|
||||
}
|
||||
case OAuthProviderGitHub:
|
||||
if entry.github != nil {
|
||||
return entry.github.GetAuthURL(state)
|
||||
}
|
||||
case OAuthProviderAlipay:
|
||||
if entry.alipay != nil {
|
||||
return entry.alipay.GetAuthURL(state)
|
||||
}
|
||||
case OAuthProviderDouyin:
|
||||
if entry.douyin != nil {
|
||||
return entry.douyin.GetAuthURL(state)
|
||||
}
|
||||
}
|
||||
|
||||
// 通用 fallback:按标准 OAuth2 拼接 URL(对 QQ/微博/Twitter/Facebook)
|
||||
config := entry.config
|
||||
if config == nil {
|
||||
return "", ErrOAuthProviderNotSupported
|
||||
}
|
||||
return fmt.Sprintf("%s?client_id=%s&redirect_uri=%s&response_type=code&scope=%s&state=%s",
|
||||
config.AuthURL,
|
||||
url.QueryEscape(config.ClientID),
|
||||
url.QueryEscape(config.RedirectURI),
|
||||
url.QueryEscape(config.Scope),
|
||||
url.QueryEscape(state),
|
||||
), nil
|
||||
}
|
||||
|
||||
// ExchangeCode 换取访问令牌(使用真实 provider 实现)
|
||||
func (m *DefaultOAuthManager) ExchangeCode(provider OAuthProvider, code string) (*OAuthToken, error) {
|
||||
entry, ok := m.entries[provider]
|
||||
if !ok {
|
||||
return nil, ErrOAuthProviderNotSupported
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
switch provider {
|
||||
case OAuthProviderGoogle:
|
||||
if entry.google != nil {
|
||||
resp, err := entry.google.ExchangeCode(ctx, code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &OAuthToken{
|
||||
AccessToken: resp.AccessToken,
|
||||
RefreshToken: resp.RefreshToken,
|
||||
ExpiresIn: int64(resp.ExpiresIn),
|
||||
TokenType: resp.TokenType,
|
||||
}, nil
|
||||
}
|
||||
case OAuthProviderWeChat:
|
||||
if entry.wechat != nil {
|
||||
resp, err := entry.wechat.ExchangeCode(ctx, code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &OAuthToken{
|
||||
AccessToken: resp.AccessToken,
|
||||
RefreshToken: resp.RefreshToken,
|
||||
ExpiresIn: int64(resp.ExpiresIn),
|
||||
TokenType: "Bearer",
|
||||
OpenID: resp.OpenID,
|
||||
}, nil
|
||||
}
|
||||
case OAuthProviderQQ:
|
||||
if entry.qq != nil {
|
||||
resp, err := entry.qq.ExchangeCode(ctx, code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
openIDResp, err := entry.qq.GetOpenID(ctx, resp.AccessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &OAuthToken{
|
||||
AccessToken: resp.AccessToken,
|
||||
RefreshToken: resp.RefreshToken,
|
||||
ExpiresIn: int64(resp.ExpiresIn),
|
||||
TokenType: "Bearer",
|
||||
OpenID: openIDResp.OpenID,
|
||||
}, nil
|
||||
}
|
||||
case OAuthProviderGitHub:
|
||||
if entry.github != nil {
|
||||
resp, err := entry.github.ExchangeCode(ctx, code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &OAuthToken{
|
||||
AccessToken: resp.AccessToken,
|
||||
TokenType: resp.TokenType,
|
||||
}, nil
|
||||
}
|
||||
case OAuthProviderAlipay:
|
||||
if entry.alipay != nil {
|
||||
resp, err := entry.alipay.ExchangeCode(ctx, code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &OAuthToken{
|
||||
AccessToken: resp.AccessToken,
|
||||
RefreshToken: resp.RefreshToken,
|
||||
ExpiresIn: int64(resp.ExpiresIn),
|
||||
TokenType: "Bearer",
|
||||
OpenID: resp.UserID,
|
||||
}, nil
|
||||
}
|
||||
case OAuthProviderDouyin:
|
||||
if entry.douyin != nil {
|
||||
resp, err := entry.douyin.ExchangeCode(ctx, code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &OAuthToken{
|
||||
AccessToken: resp.Data.AccessToken,
|
||||
RefreshToken: resp.Data.RefreshToken,
|
||||
ExpiresIn: int64(resp.Data.ExpiresIn),
|
||||
TokenType: "Bearer",
|
||||
OpenID: resp.Data.OpenID,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("provider %s: real HTTP exchange not implemented yet", provider)
|
||||
}
|
||||
|
||||
// GetUserInfo 获取用户信息(使用真实 provider 实现)
|
||||
func (m *DefaultOAuthManager) GetUserInfo(provider OAuthProvider, token *OAuthToken) (*OAuthUser, error) {
|
||||
entry, ok := m.entries[provider]
|
||||
if !ok {
|
||||
return nil, ErrOAuthProviderNotSupported
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
switch provider {
|
||||
case OAuthProviderGoogle:
|
||||
if entry.google != nil {
|
||||
info, err := entry.google.GetUserInfo(ctx, token.AccessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &OAuthUser{
|
||||
Provider: provider,
|
||||
OpenID: info.ID,
|
||||
Nickname: info.Name,
|
||||
Avatar: info.Picture,
|
||||
Email: info.Email,
|
||||
}, nil
|
||||
}
|
||||
case OAuthProviderWeChat:
|
||||
if entry.wechat != nil {
|
||||
openID := token.OpenID
|
||||
info, err := entry.wechat.GetUserInfo(ctx, token.AccessToken, openID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
gender := ""
|
||||
switch info.Sex {
|
||||
case 1:
|
||||
gender = "male"
|
||||
case 2:
|
||||
gender = "female"
|
||||
}
|
||||
return &OAuthUser{
|
||||
Provider: provider,
|
||||
OpenID: info.OpenID,
|
||||
UnionID: info.UnionID,
|
||||
Nickname: info.Nickname,
|
||||
Avatar: info.HeadImgURL,
|
||||
Gender: gender,
|
||||
}, nil
|
||||
}
|
||||
case OAuthProviderQQ:
|
||||
if entry.qq != nil {
|
||||
info, err := entry.qq.GetUserInfo(ctx, token.AccessToken, token.OpenID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
avatar := info.FigureURL2
|
||||
if avatar == "" {
|
||||
avatar = info.FigureURL1
|
||||
}
|
||||
if avatar == "" {
|
||||
avatar = info.FigureURL
|
||||
}
|
||||
return &OAuthUser{
|
||||
Provider: provider,
|
||||
OpenID: token.OpenID,
|
||||
Nickname: info.Nickname,
|
||||
Avatar: avatar,
|
||||
Gender: info.Gender,
|
||||
Extra: map[string]interface{}{
|
||||
"province": info.Province,
|
||||
"city": info.City,
|
||||
"year": info.Year,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
case OAuthProviderGitHub:
|
||||
if entry.github != nil {
|
||||
info, err := entry.github.GetUserInfo(ctx, token.AccessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nickname := info.Name
|
||||
if nickname == "" {
|
||||
nickname = info.Login
|
||||
}
|
||||
return &OAuthUser{
|
||||
Provider: provider,
|
||||
OpenID: fmt.Sprintf("%d", info.ID),
|
||||
Nickname: nickname,
|
||||
Email: info.Email,
|
||||
}, nil
|
||||
}
|
||||
case OAuthProviderAlipay:
|
||||
if entry.alipay != nil {
|
||||
info, err := entry.alipay.GetUserInfo(ctx, token.AccessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &OAuthUser{
|
||||
Provider: provider,
|
||||
OpenID: info.UserID,
|
||||
Nickname: info.Nickname,
|
||||
Avatar: info.Avatar,
|
||||
}, nil
|
||||
}
|
||||
case OAuthProviderDouyin:
|
||||
if entry.douyin != nil {
|
||||
info, err := entry.douyin.GetUserInfo(ctx, token.AccessToken, token.OpenID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
gender := ""
|
||||
switch info.Data.Gender {
|
||||
case 1:
|
||||
gender = "male"
|
||||
case 2:
|
||||
gender = "female"
|
||||
}
|
||||
return &OAuthUser{
|
||||
Provider: provider,
|
||||
OpenID: info.Data.OpenID,
|
||||
UnionID: info.Data.UnionID,
|
||||
Nickname: info.Data.Nickname,
|
||||
Avatar: info.Data.Avatar,
|
||||
Gender: gender,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("provider %s: real HTTP user info not implemented yet", provider)
|
||||
}
|
||||
|
||||
// ValidateToken 验证令牌
|
||||
// 注意:由于 ValidateToken 不持有 provider 上下文,无法进行真正的 token 验证
|
||||
// 对于需要验证 token 的场景,应使用 GetUserInfo 通过 provider 的 userinfo 端点验证
|
||||
// 如果没有可用的 provider,返回错误
|
||||
func (m *DefaultOAuthManager) ValidateToken(token string) (bool, error) {
|
||||
if len(token) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
// 由于缺乏 provider 上下文,无法进行有意义的验证
|
||||
// 遍历所有已启用的 provider,尝试通过 GetUserInfo 验证
|
||||
// 如果没有任何 provider 可用,返回错误而不是默认通过
|
||||
providers := m.GetEnabledProviders()
|
||||
if len(providers) == 0 {
|
||||
return false, errors.New("no OAuth providers configured")
|
||||
}
|
||||
// 尝试任一 provider 的 userinfo 端点验证
|
||||
tokenObj := &OAuthToken{AccessToken: token}
|
||||
for _, p := range providers {
|
||||
if _, err := m.GetUserInfo(p.Provider, tokenObj); err == nil {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// ValidateTokenWithProvider 通过指定 provider 验证令牌
|
||||
func (m *DefaultOAuthManager) ValidateTokenWithProvider(provider OAuthProvider, token string) (bool, error) {
|
||||
if token == "" {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
cfg, ok := m.GetConfig(provider)
|
||||
if !ok || cfg.ClientID == "" {
|
||||
return false, fmt.Errorf("provider %s not configured", provider)
|
||||
}
|
||||
|
||||
// 通过 provider 的 userinfo 端点验证 token
|
||||
tokenObj := &OAuthToken{AccessToken: token}
|
||||
_, err := m.GetUserInfo(provider, tokenObj)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// GetEnabledProviders 获取已启用的OAuth提供商
|
||||
func (m *DefaultOAuthManager) GetEnabledProviders() []OAuthProviderInfo {
|
||||
providerNames := map[OAuthProvider]string{
|
||||
OAuthProviderGoogle: "Google",
|
||||
OAuthProviderWeChat: "微信",
|
||||
OAuthProviderQQ: "QQ",
|
||||
OAuthProviderWeibo: "微博",
|
||||
OAuthProviderFacebook: "Facebook",
|
||||
OAuthProviderTwitter: "Twitter",
|
||||
OAuthProviderGitHub: "GitHub",
|
||||
OAuthProviderAlipay: "支付宝",
|
||||
OAuthProviderDouyin: "抖音",
|
||||
}
|
||||
|
||||
var result []OAuthProviderInfo
|
||||
for provider, entry := range m.entries {
|
||||
name := providerNames[provider]
|
||||
if name == "" {
|
||||
name = string(provider)
|
||||
}
|
||||
result = append(result, OAuthProviderInfo{
|
||||
Provider: provider,
|
||||
Enabled: entry.config != nil,
|
||||
Name: name,
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
Reference in New Issue
Block a user