package auth

import (
	"context"
	"errors"
	"fmt"
	"net/url"
	"sort"

	"github.com/user-management-system/internal/auth/providers"
)

// OAuthProvider identifies a third-party OAuth provider.
type OAuthProvider string

// Supported OAuth provider identifiers.
const (
	OAuthProviderWeChat   OAuthProvider = "wechat"
	OAuthProviderQQ       OAuthProvider = "qq"
	OAuthProviderWeibo    OAuthProvider = "weibo"
	OAuthProviderGoogle   OAuthProvider = "google"
	OAuthProviderFacebook OAuthProvider = "facebook"
	OAuthProviderTwitter  OAuthProvider = "twitter"
	OAuthProviderGitHub   OAuthProvider = "github"
	OAuthProviderAlipay   OAuthProvider = "alipay"
	OAuthProviderDouyin   OAuthProvider = "douyin"
)

// OAuthUser is the user profile returned by a provider, normalized into a
// provider-independent shape. Extra carries provider-specific fields.
type OAuthUser struct {
	Provider OAuthProvider          `json:"provider"`
	OpenID   string                 `json:"open_id"`
	UnionID  string                 `json:"union_id,omitempty"`
	Nickname string                 `json:"nickname"`
	Avatar   string                 `json:"avatar"`
	Gender   string                 `json:"gender,omitempty"`
	Email    string                 `json:"email,omitempty"`
	Phone    string                 `json:"phone,omitempty"`
	Extra    map[string]interface{} `json:"extra,omitempty"`
}

// OAuthToken holds the tokens obtained by exchanging an authorization code.
type OAuthToken struct {
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token,omitempty"`
	ExpiresIn    int64  `json:"expires_in"`
	TokenType    string `json:"token_type"`
	OpenID       string `json:"open_id,omitempty"` // needed by WeChat-style providers for user-info calls
}

// OAuthConfig is the static configuration of a single OAuth provider.
type OAuthConfig struct {
	ClientID     string `json:"client_id"`
	ClientSecret string `json:"client_secret"`
	RedirectURI  string `json:"redirect_uri"`
	Scope        string `json:"scope"`
	AuthURL      string `json:"auth_url"`
	TokenURL     string `json:"token_url"`
	UserInfoURL  string `json:"user_info_url"`
}

// OAuthManager is the contract for driving OAuth login flows.
type OAuthManager interface {
	// GetAuthURL returns the authorization URL for provider, embedding state.
	GetAuthURL(provider OAuthProvider, state string) (string, error)
	// ExchangeCode exchanges an authorization code for an access token.
	ExchangeCode(provider OAuthProvider, code string) (*OAuthToken, error)
	// GetUserInfo fetches the user profile associated with token.
	GetUserInfo(provider OAuthProvider, token *OAuthToken) (*OAuthUser, error)
	// ValidateToken reports whether token is valid.
	ValidateToken(token string) (bool, error)
	// GetConfig returns the configuration registered for provider.
	GetConfig(provider OAuthProvider) (*OAuthConfig, bool)
	// GetEnabledProviders lists the registered OAuth providers.
	GetEnabledProviders() []OAuthProviderInfo
}

// OAuthProviderInfo describes a registered provider for display purposes.
type OAuthProviderInfo struct {
	Provider OAuthProvider `json:"provider"`
	Enabled  bool          `json:"enabled"`
	Name     string        `json:"name"`
}

// providerEntry is the internal per-provider record: the raw config plus, for
// providers with a dedicated implementation, the concrete client. At most one
// client field is set, chosen by RegisterProvider.
type providerEntry struct {
	config      *OAuthConfig
	google      *providers.GoogleProvider
	wechat      *providers.WeChatProvider
	wechatRedir string
	qq          *providers.QQProvider
	github      *providers.GitHubProvider
	alipay      *providers.AlipayProvider
	douyin      *providers.DouyinProvider
}

// DefaultOAuthManager is the default OAuthManager. It delegates to the
// concrete provider clients in the providers package, which make the real
// HTTP calls.
//
// The entries map is not guarded by a lock. Register all providers before
// the manager is used concurrently.
type DefaultOAuthManager struct {
	entries map[OAuthProvider]*providerEntry
}

// Compile-time check that DefaultOAuthManager satisfies OAuthManager.
var _ OAuthManager = (*DefaultOAuthManager)(nil)

// NewOAuthManager creates an empty OAuth manager.
func NewOAuthManager() *DefaultOAuthManager {
	return &DefaultOAuthManager{
		entries: make(map[OAuthProvider]*providerEntry),
	}
}

// RegisterProvider registers (or replaces) provider with config. It keeps the
// legacy signature. For providers with a dedicated implementation it also
// builds the corresponding client.
//
// A nil config is accepted: the provider is recorded but has no client, and
// GetEnabledProviders reports it with Enabled == false.
func (m *DefaultOAuthManager) RegisterProvider(provider OAuthProvider, config *OAuthConfig) {
	if m.entries == nil {
		// Allow a zero-value DefaultOAuthManager; writing to a nil map panics.
		m.entries = make(map[OAuthProvider]*providerEntry)
	}

	entry := &providerEntry{config: config}
	if config == nil {
		m.entries[provider] = entry
		return
	}

	switch provider {
	case OAuthProviderGoogle:
		entry.google = providers.NewGoogleProvider(config.ClientID, config.ClientSecret, config.RedirectURI)
	case OAuthProviderWeChat:
		entry.wechat = providers.NewWeChatProvider(config.ClientID, config.ClientSecret, "web")
		// The WeChat client receives the redirect URI per GetAuthURL call instead of at construction.
		entry.wechatRedir = config.RedirectURI
	case OAuthProviderQQ:
		entry.qq = providers.NewQQProvider(config.ClientID, config.ClientSecret, config.RedirectURI)
	case OAuthProviderGitHub:
		entry.github = providers.NewGitHubProvider(config.ClientID, config.ClientSecret, config.RedirectURI)
	case OAuthProviderAlipay:
		// Alipay stores the AppID in ClientID and the RSA private key in ClientSecret.
		// NOTE(review): the trailing `false` is presumably a sandbox/production switch — confirm.
		entry.alipay = providers.NewAlipayProvider(config.ClientID, config.ClientSecret, config.RedirectURI, false)
	case OAuthProviderDouyin:
		entry.douyin = providers.NewDouyinProvider(config.ClientID, config.ClientSecret, config.RedirectURI)
	}
	m.entries[provider] = entry
}

// GetConfig returns the configuration registered for provider. The boolean
// reports whether the provider is registered. The returned config may still
// be nil if the provider was registered with a nil config.
func (m *DefaultOAuthManager) GetConfig(provider OAuthProvider) (*OAuthConfig, bool) {
	entry, ok := m.entries[provider]
	if !ok {
		return nil, false
	}
	return entry.config, true
}

// GetAuthURL returns the authorization URL the user should be redirected to
// for provider, carrying state.
//
// Providers with a dedicated client (Google, WeChat, QQ, GitHub, Alipay,
// Douyin) build the URL themselves. Any other registered provider falls back
// to a standard OAuth2 authorization-code URL built from its OAuthConfig.
// ErrOAuthProviderNotSupported is returned when provider is not registered,
// or when the fallback has no config or no AuthURL.
func (m *DefaultOAuthManager) GetAuthURL(provider OAuthProvider, state string) (string, error) {
	entry, ok := m.entries[provider]
	if !ok {
		return "", ErrOAuthProviderNotSupported
	}

	switch provider {
	case OAuthProviderGoogle:
		if entry.google != nil {
			resp, err := entry.google.GetAuthURL(state)
			if err != nil {
				return "", err
			}
			return resp.URL, nil
		}
	case OAuthProviderWeChat:
		if entry.wechat != nil {
			resp, err := entry.wechat.GetAuthURL(entry.wechatRedir, state)
			if err != nil {
				return "", err
			}
			return resp.URL, nil
		}
	case OAuthProviderQQ:
		if entry.qq != nil {
			resp, err := entry.qq.GetAuthURL(state)
			if err != nil {
				return "", err
			}
			return resp.URL, nil
		}
	case OAuthProviderGitHub:
		if entry.github != nil {
			return entry.github.GetAuthURL(state)
		}
	case OAuthProviderAlipay:
		if entry.alipay != nil {
			return entry.alipay.GetAuthURL(state)
		}
	case OAuthProviderDouyin:
		if entry.douyin != nil {
			return entry.douyin.GetAuthURL(state)
		}
	}

	// Generic OAuth2 fallback for providers without a dedicated client
	// (e.g. Weibo, Facebook, Twitter).
	config := entry.config
	if config == nil || config.AuthURL == "" {
		// Without an AuthURL the result would be a bare "?client_id=..." string.
		return "", ErrOAuthProviderNotSupported
	}
	authURL, err := url.Parse(config.AuthURL)
	if err != nil {
		return "", fmt.Errorf("provider %s: invalid auth URL: %w", provider, err)
	}
	// Merge into any query already present on AuthURL instead of appending a
	// second "?". Values.Encode sorts keys. When AuthURL has no query of its
	// own, this yields the same parameter order as before (client_id,
	// redirect_uri, response_type, scope, state).
	q := authURL.Query()
	q.Set("client_id", config.ClientID)
	q.Set("redirect_uri", config.RedirectURI)
	q.Set("response_type", "code")
	q.Set("scope", config.Scope)
	q.Set("state", state)
	authURL.RawQuery = q.Encode()
	return authURL.String(), nil
}

// ExchangeCode exchanges an authorization code for an access token. It uses
// context.Background(); use ExchangeCodeWithContext to bound or cancel the
// underlying HTTP calls.
func (m *DefaultOAuthManager) ExchangeCode(provider OAuthProvider, code string) (*OAuthToken, error) {
	return m.ExchangeCodeWithContext(context.Background(), provider, code)
}

// ExchangeCodeWithContext exchanges an authorization code for an access token
// via the provider's dedicated client, passing ctx to every outbound call.
//
// It returns ErrOAuthProviderNotSupported when provider is not registered. It
// returns an error if the provider has no dedicated client.
func (m *DefaultOAuthManager) ExchangeCodeWithContext(ctx context.Context, provider OAuthProvider, code string) (*OAuthToken, error) {
	entry, ok := m.entries[provider]
	if !ok {
		return nil, ErrOAuthProviderNotSupported
	}

	switch provider {
	case OAuthProviderGoogle:
		if entry.google != nil {
			resp, err := entry.google.ExchangeCode(ctx, code)
			if err != nil {
				return nil, err
			}
			return &OAuthToken{
				AccessToken:  resp.AccessToken,
				RefreshToken: resp.RefreshToken,
				ExpiresIn:    int64(resp.ExpiresIn),
				TokenType:    resp.TokenType,
			}, nil
		}
	case OAuthProviderWeChat:
		if entry.wechat != nil {
			resp, err := entry.wechat.ExchangeCode(ctx, code)
			if err != nil {
				return nil, err
			}
			return &OAuthToken{
				AccessToken:  resp.AccessToken,
				RefreshToken: resp.RefreshToken,
				ExpiresIn:    int64(resp.ExpiresIn),
				TokenType:    "Bearer",
				OpenID:       resp.OpenID,
			}, nil
		}
	case OAuthProviderQQ:
		if entry.qq != nil {
			resp, err := entry.qq.ExchangeCode(ctx, code)
			if err != nil {
				return nil, err
			}
			// QQ does not include the OpenID in the token response; fetch it separately.
			openIDResp, err := entry.qq.GetOpenID(ctx, resp.AccessToken)
			if err != nil {
				return nil, err
			}
			return &OAuthToken{
				AccessToken:  resp.AccessToken,
				RefreshToken: resp.RefreshToken,
				ExpiresIn:    int64(resp.ExpiresIn),
				TokenType:    "Bearer",
				OpenID:       openIDResp.OpenID,
			}, nil
		}
	case OAuthProviderGitHub:
		if entry.github != nil {
			resp, err := entry.github.ExchangeCode(ctx, code)
			if err != nil {
				return nil, err
			}
			return &OAuthToken{
				AccessToken: resp.AccessToken,
				TokenType:   resp.TokenType,
			}, nil
		}
	case OAuthProviderAlipay:
		if entry.alipay != nil {
			resp, err := entry.alipay.ExchangeCode(ctx, code)
			if err != nil {
				return nil, err
			}
			return &OAuthToken{
				AccessToken:  resp.AccessToken,
				RefreshToken: resp.RefreshToken,
				ExpiresIn:    int64(resp.ExpiresIn),
				TokenType:    "Bearer",
				OpenID:       resp.UserID,
			}, nil
		}
	case OAuthProviderDouyin:
		if entry.douyin != nil {
			resp, err := entry.douyin.ExchangeCode(ctx, code)
			if err != nil {
				return nil, err
			}
			return &OAuthToken{
				AccessToken:  resp.Data.AccessToken,
				RefreshToken: resp.Data.RefreshToken,
				ExpiresIn:    int64(resp.Data.ExpiresIn),
				TokenType:    "Bearer",
				OpenID:       resp.Data.OpenID,
			}, nil
		}
	}
	return nil, fmt.Errorf("provider %s: real HTTP exchange not implemented yet", provider)
}

// GetUserInfo fetches the user profile for token. It uses
// context.Background(); use GetUserInfoWithContext to bound or cancel the
// underlying HTTP calls.
func (m *DefaultOAuthManager) GetUserInfo(provider OAuthProvider, token *OAuthToken) (*OAuthUser, error) {
	return m.GetUserInfoWithContext(context.Background(), provider, token)
}

// GetUserInfoWithContext fetches the user profile for token via the
// provider's dedicated client, passing ctx to the outbound call. The result
// is normalized into an OAuthUser.
//
// WeChat, QQ and Douyin also read token.OpenID. It returns
// ErrOAuthProviderNotSupported when provider is not registered, an error when
// token is nil, and an error if the provider has no dedicated client.
func (m *DefaultOAuthManager) GetUserInfoWithContext(ctx context.Context, provider OAuthProvider, token *OAuthToken) (*OAuthUser, error) {
	entry, ok := m.entries[provider]
	if !ok {
		return nil, ErrOAuthProviderNotSupported
	}
	if token == nil {
		return nil, errors.New("oauth token is nil")
	}

	switch provider {
	case OAuthProviderGoogle:
		if entry.google != nil {
			info, err := entry.google.GetUserInfo(ctx, token.AccessToken)
			if err != nil {
				return nil, err
			}
			return &OAuthUser{
				Provider: provider,
				OpenID:   info.ID,
				Nickname: info.Name,
				Avatar:   info.Picture,
				Email:    info.Email,
			}, nil
		}
	case OAuthProviderWeChat:
		if entry.wechat != nil {
			info, err := entry.wechat.GetUserInfo(ctx, token.AccessToken, token.OpenID)
			if err != nil {
				return nil, err
			}
			// Sex codes: 1 = male, 2 = female; anything else is left empty.
			gender := ""
			switch info.Sex {
			case 1:
				gender = "male"
			case 2:
				gender = "female"
			}
			return &OAuthUser{
				Provider: provider,
				OpenID:   info.OpenID,
				UnionID:  info.UnionID,
				Nickname: info.Nickname,
				Avatar:   info.HeadImgURL,
				Gender:   gender,
			}, nil
		}
	case OAuthProviderQQ:
		if entry.qq != nil {
			info, err := entry.qq.GetUserInfo(ctx, token.AccessToken, token.OpenID)
			if err != nil {
				return nil, err
			}
			// Prefer FigureURL2, then FigureURL1, then FigureURL.
			avatar := info.FigureURL2
			if avatar == "" {
				avatar = info.FigureURL1
			}
			if avatar == "" {
				avatar = info.FigureURL
			}
			return &OAuthUser{
				Provider: provider,
				OpenID:   token.OpenID,
				Nickname: info.Nickname,
				Avatar:   avatar,
				Gender:   info.Gender,
				Extra: map[string]interface{}{
					"province": info.Province,
					"city":     info.City,
					"year":     info.Year,
				},
			}, nil
		}
	case OAuthProviderGitHub:
		if entry.github != nil {
			info, err := entry.github.GetUserInfo(ctx, token.AccessToken)
			if err != nil {
				return nil, err
			}
			// Fall back to the login handle when no display name is set.
			nickname := info.Name
			if nickname == "" {
				nickname = info.Login
			}
			return &OAuthUser{
				Provider: provider,
				OpenID:   fmt.Sprintf("%d", info.ID),
				Nickname: nickname,
				Email:    info.Email,
			}, nil
		}
	case OAuthProviderAlipay:
		if entry.alipay != nil {
			info, err := entry.alipay.GetUserInfo(ctx, token.AccessToken)
			if err != nil {
				return nil, err
			}
			return &OAuthUser{
				Provider: provider,
				OpenID:   info.UserID,
				Nickname: info.Nickname,
				Avatar:   info.Avatar,
			}, nil
		}
	case OAuthProviderDouyin:
		if entry.douyin != nil {
			info, err := entry.douyin.GetUserInfo(ctx, token.AccessToken, token.OpenID)
			if err != nil {
				return nil, err
			}
			// Gender codes: 1 = male, 2 = female; anything else is left empty.
			gender := ""
			switch info.Data.Gender {
			case 1:
				gender = "male"
			case 2:
				gender = "female"
			}
			return &OAuthUser{
				Provider: provider,
				OpenID:   info.Data.OpenID,
				UnionID:  info.Data.UnionID,
				Nickname: info.Data.Nickname,
				Avatar:   info.Data.Avatar,
				Gender:   gender,
			}, nil
		}
	}
	return nil, fmt.Errorf("provider %s: real HTTP user info not implemented yet", provider)
}

// ValidateToken reports whether token is accepted by any enabled provider.
//
// The method has no provider context, so it probes each enabled provider's
// user-info endpoint in turn (sorted order) and succeeds on the first one
// that accepts the token. Prefer ValidateTokenWithProvider when the issuing
// provider is known.
//
// An empty token is reported as invalid with no error. If no enabled provider
// is registered, an error is returned instead of silently succeeding. Errors
// from individual probes are treated as "not valid for that provider".
//
// NOTE(review): this sends the bearer token to every enabled provider, not
// just the one that issued it. Consider whether that exposure is acceptable.
func (m *DefaultOAuthManager) ValidateToken(token string) (bool, error) {
	if token == "" {
		return false, nil
	}

	tokenObj := &OAuthToken{AccessToken: token}
	enabled := 0
	for _, info := range m.GetEnabledProviders() {
		if !info.Enabled {
			continue
		}
		enabled++
		if _, err := m.GetUserInfo(info.Provider, tokenObj); err == nil {
			return true, nil
		}
	}
	if enabled == 0 {
		return false, errors.New("no OAuth providers configured")
	}
	return false, nil
}

// ValidateTokenWithProvider validates token against provider's user-info
// endpoint. An empty token is reported as invalid with no error. An error is
// returned when the provider is unregistered, has no config or no ClientID,
// or when the user-info call fails.
func (m *DefaultOAuthManager) ValidateTokenWithProvider(provider OAuthProvider, token string) (bool, error) {
	if token == "" {
		return false, nil
	}
	cfg, ok := m.GetConfig(provider)
	// cfg can be nil for a provider registered with a nil config.
	if !ok || cfg == nil || cfg.ClientID == "" {
		return false, fmt.Errorf("provider %s not configured", provider)
	}

	tokenObj := &OAuthToken{AccessToken: token}
	if _, err := m.GetUserInfo(provider, tokenObj); err != nil {
		return false, err
	}
	return true, nil
}

// GetEnabledProviders lists every registered provider with its display name.
// The list is sorted by provider identifier so output is stable across calls.
// Enabled is true when the provider was registered with a non-nil config.
// Unknown providers use their identifier as the display name.
func (m *DefaultOAuthManager) GetEnabledProviders() []OAuthProviderInfo {
	providerNames := map[OAuthProvider]string{
		OAuthProviderGoogle:   "Google",
		OAuthProviderWeChat:   "微信",
		OAuthProviderQQ:       "QQ",
		OAuthProviderWeibo:    "微博",
		OAuthProviderFacebook: "Facebook",
		OAuthProviderTwitter:  "Twitter",
		OAuthProviderGitHub:   "GitHub",
		OAuthProviderAlipay:   "支付宝",
		OAuthProviderDouyin:   "抖音",
	}

	result := make([]OAuthProviderInfo, 0, len(m.entries))
	for provider, entry := range m.entries {
		name := providerNames[provider]
		if name == "" {
			name = string(provider)
		}
		result = append(result, OAuthProviderInfo{
			Provider: provider,
			Enabled:  entry.config != nil,
			Name:     name,
		})
	}
	// Map iteration order is random; sort for deterministic output.
	sort.Slice(result, func(i, j int) bool {
		return result[i].Provider < result[j].Provider
	})
	return result
}