Files
user-system/internal/auth/oauth_test.go
long-agent 61c19e54ac fix: P1-02 OAuth context propagation and P1-16 AuthProvider double-check
P1-02: OAuth ExchangeCode and GetUserInfo now accept context parameter
       to properly propagate request context to HTTP calls
P1-16: AuthProvider isAuthenticated now uses single source of truth
       (effectiveUser !== null) instead of double-checking both
       React state and module-level function
2026-04-18 19:40:54 +08:00

619 lines
18 KiB
Go

package auth
import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
// TestNewOAuthManager verifies that NewOAuthManager returns a non-nil manager
// whose entries map has been initialized.
func TestNewOAuthManager(t *testing.T) {
	manager := NewOAuthManager()
	switch {
	case manager == nil:
		// Fatal: nothing else can be inspected on a nil manager.
		t.Fatal("NewOAuthManager() returned nil")
	case manager.entries == nil:
		t.Error("NewOAuthManager() did not initialize entries map")
	}
}
// TestDefaultOAuthManager_RegisterProvider registers a fully populated Google
// configuration and inspects the manager's entries map: exactly one entry must
// exist, keyed by OAuthProviderGoogle, with both its config and its google
// provider instance set.
func TestDefaultOAuthManager_RegisterProvider(t *testing.T) {
	manager := NewOAuthManager()
	manager.RegisterProvider(OAuthProviderGoogle, &OAuthConfig{
		ClientID:     "test-client-id",
		ClientSecret: "test-client-secret",
		RedirectURI:  "https://example.com/callback",
		Scope:        "openid email",
		AuthURL:      "https://example.com/auth",
		TokenURL:     "https://example.com/token",
		UserInfoURL:  "https://example.com/userinfo",
	})

	// Exactly one provider was registered, so exactly one entry is expected.
	if count := len(manager.entries); count != 1 {
		t.Errorf("Expected 1 entry, got %d", count)
	}

	registered, found := manager.entries[OAuthProviderGoogle]
	if !found {
		// The remaining checks dereference the entry, so stop here.
		t.Fatal("Google provider not found in entries")
	}
	if registered.config == nil {
		t.Error("Config not set for Google provider")
	}
	if registered.google == nil {
		t.Error("Google provider instance not created")
	}
}
// TestDefaultOAuthManager_GetConfig checks both outcomes of GetConfig: a
// provider that was never registered must report ok == false, and a registered
// provider must return a config carrying the registered ClientID.
func TestDefaultOAuthManager_GetConfig(t *testing.T) {
	manager := NewOAuthManager()

	// Lookup before registration must miss.
	if _, found := manager.GetConfig(OAuthProviderGoogle); found {
		t.Error("GetConfig() should return false for non-existent provider")
	}

	manager.RegisterProvider(OAuthProviderGoogle, &OAuthConfig{
		ClientID:    "test-id",
		Scope:       "openid",
		AuthURL:     "https://example.com/auth",
		TokenURL:    "https://example.com/token",
		UserInfoURL: "https://example.com/userinfo",
	})

	// Lookup after registration must hit and echo the stored ClientID.
	got, found := manager.GetConfig(OAuthProviderGoogle)
	if !found {
		t.Fatal("GetConfig() should return true for registered provider")
	}
	if id := got.ClientID; id != "test-id" {
		t.Errorf("ClientID = %s, want test-id", id)
	}
}
// TestDefaultOAuthManager_GetAuthURL checks that GetAuthURL reports
// ErrOAuthProviderNotSupported for a provider that was never registered and
// that, once Google is registered, the call either reports an error or yields
// a non-empty URL.
func TestDefaultOAuthManager_GetAuthURL(t *testing.T) {
	m := NewOAuthManager()

	// An unregistered provider must be rejected with the sentinel error.
	if _, err := m.GetAuthURL(OAuthProviderGoogle, "test-state"); err != ErrOAuthProviderNotSupported {
		t.Errorf("Expected ErrOAuthProviderNotSupported, got %v", err)
	}

	m.RegisterProvider(OAuthProviderGoogle, &OAuthConfig{
		ClientID:     "google-client-id",
		ClientSecret: "google-secret",
		RedirectURI:  "https://example.com/callback",
		Scope:        "openid email",
	})

	// The config deliberately carries no AuthURL, so whether this call succeeds
	// depends on the Google provider implementation, which is not visible from
	// this test. An error is therefore tolerated, but it is logged rather than
	// silently discarded so it shows up in verbose test output.
	authURL, err := m.GetAuthURL(OAuthProviderGoogle, "test-state")
	if err != nil {
		t.Logf("GetAuthURL() for registered Google provider returned error: %v", err)
		return
	}
	// A nil error with an empty URL would be unusable to callers.
	if authURL == "" {
		t.Error("GetAuthURL() returned empty URL without an error")
	}
}
// TestDefaultOAuthManager_GetAuthURL_Fallback registers Facebook, which this
// test treats as a provider without a dedicated implementation, and checks
// that GetAuthURL succeeds and that the generated URL is built on the
// configured AuthURL endpoint.
func TestDefaultOAuthManager_GetAuthURL_Fallback(t *testing.T) {
	// Kept as a constant so the assertion does not depend on the config value
	// after it has been handed to RegisterProvider.
	const authEndpoint = "https://facebook.com/dialog/oauth"

	m := NewOAuthManager()
	m.RegisterProvider(OAuthProviderFacebook, &OAuthConfig{
		ClientID:     "facebook-id",
		ClientSecret: "facebook-secret",
		RedirectURI:  "https://example.com/callback",
		Scope:        "email",
		AuthURL:      authEndpoint,
	})

	authURL, err := m.GetAuthURL(OAuthProviderFacebook, "test-state")
	if err != nil {
		t.Fatalf("GetAuthURL() error = %v", err)
	}
	if authURL == "" {
		// Fatal: the endpoint check below would only repeat this failure.
		t.Fatal("GetAuthURL() returned empty URL")
	}
	// The fallback URL must actually contain the configured auth endpoint; a
	// bare length check would accept any sufficiently long string.
	if !strings.Contains(authURL, authEndpoint) {
		t.Errorf("GetAuthURL() URL should contain auth endpoint %s, got %s", authEndpoint, authURL)
	}
}
// TestDefaultOAuthManager_ExchangeCode checks that ExchangeCode on a manager
// with no registered providers reports ErrOAuthProviderNotSupported.
func TestDefaultOAuthManager_ExchangeCode(t *testing.T) {
	manager := NewOAuthManager()
	ctx := context.Background()

	// Google was never registered on this manager.
	if _, exchangeErr := manager.ExchangeCode(ctx, OAuthProviderGoogle, "test-code"); exchangeErr != ErrOAuthProviderNotSupported {
		t.Errorf("Expected ErrOAuthProviderNotSupported, got %v", exchangeErr)
	}
}
// TestDefaultOAuthManager_GetUserInfo checks that GetUserInfo on a manager
// with no registered providers reports ErrOAuthProviderNotSupported.
func TestDefaultOAuthManager_GetUserInfo(t *testing.T) {
	manager := NewOAuthManager()
	ctx := context.Background()

	// Google was never registered on this manager; the token content is
	// irrelevant to the lookup failure being tested.
	if _, infoErr := manager.GetUserInfo(ctx, OAuthProviderGoogle, &OAuthToken{AccessToken: "test-token"}); infoErr != ErrOAuthProviderNotSupported {
		t.Errorf("Expected ErrOAuthProviderNotSupported, got %v", infoErr)
	}
}
// TestDefaultOAuthManager_ValidateToken covers ValidateToken on a manager
// without providers: an empty token must yield (false, nil), while a non-empty
// token must yield false together with a non-nil error.
func TestDefaultOAuthManager_ValidateToken(t *testing.T) {
	manager := NewOAuthManager()

	// Empty token: expected to be rejected without an error.
	if ok, err := manager.ValidateToken(""); ok || err != nil {
		t.Errorf("ValidateToken('') = %v, %v, want false, nil", ok, err)
	}

	// Non-empty token with nothing configured to validate it against.
	ok, err := manager.ValidateToken("some-token")
	if ok {
		t.Error("ValidateToken() should return false with no providers")
	}
	if err == nil {
		t.Error("ValidateToken() should return error with no providers")
	}
}
// TestDefaultOAuthManager_ValidateTokenWithProvider covers
// ValidateTokenWithProvider on a manager without providers: an empty token
// must yield (false, nil), while a non-empty token for the unconfigured Google
// provider must yield false together with a non-nil error.
func TestDefaultOAuthManager_ValidateTokenWithProvider(t *testing.T) {
	manager := NewOAuthManager()

	// Empty token: expected to be rejected without an error.
	if ok, err := manager.ValidateTokenWithProvider(OAuthProviderGoogle, ""); ok || err != nil {
		t.Errorf("ValidateTokenWithProvider('') = %v, %v, want false, nil", ok, err)
	}

	// Non-empty token, but Google was never registered.
	ok, err := manager.ValidateTokenWithProvider(OAuthProviderGoogle, "some-token")
	if ok {
		t.Error("ValidateTokenWithProvider() should return false for unconfigured provider")
	}
	if err == nil {
		t.Error("ValidateTokenWithProvider() should return error for unconfigured provider")
	}
}
// TestDefaultOAuthManager_GetEnabledProviders checks that GetEnabledProviders
// is empty for a fresh manager and, after Google and GitHub are registered,
// returns two entries whose Name fields are "Google" and "GitHub".
func TestDefaultOAuthManager_GetEnabledProviders(t *testing.T) {
	manager := NewOAuthManager()

	// Nothing registered yet.
	if n := len(manager.GetEnabledProviders()); n != 0 {
		t.Errorf("GetEnabledProviders() = %d, want 0", n)
	}

	manager.RegisterProvider(OAuthProviderGoogle, &OAuthConfig{ClientID: "google"})
	manager.RegisterProvider(OAuthProviderGitHub, &OAuthConfig{ClientID: "github"})

	enabled := manager.GetEnabledProviders()
	if n := len(enabled); n != 2 {
		t.Errorf("GetEnabledProviders() = %d, want 2", n)
	}

	// Index the result by provider so the checks do not depend on the order
	// in which providers are returned.
	byProvider := make(map[OAuthProvider]OAuthProviderInfo)
	for _, info := range enabled {
		byProvider[info.Provider] = info
	}

	if info, found := byProvider[OAuthProviderGoogle]; !(found && info.Name == "Google") {
		t.Error("Google provider info incorrect")
	}
	if info, found := byProvider[OAuthProviderGitHub]; !(found && info.Name == "GitHub") {
		t.Error("GitHub provider info incorrect")
	}
}
// TestDefaultOAuthManager_RegisterAllProviders registers six providers
// (Google, WeChat, QQ, GitHub, Alipay, Douyin) and checks that the manager
// holds one entry per registration and that each entry has its
// provider-specific instance field populated.
func TestDefaultOAuthManager_RegisterAllProviders(t *testing.T) {
	manager := NewOAuthManager()

	registrations := []struct {
		provider OAuthProvider
		config   *OAuthConfig
	}{
		{OAuthProviderGoogle, &OAuthConfig{ClientID: "google", ClientSecret: "secret"}},
		{OAuthProviderWeChat, &OAuthConfig{ClientID: "wechat", ClientSecret: "secret"}},
		{OAuthProviderQQ, &OAuthConfig{ClientID: "qq", ClientSecret: "secret"}},
		{OAuthProviderGitHub, &OAuthConfig{ClientID: "github", ClientSecret: "secret"}},
		{OAuthProviderAlipay, &OAuthConfig{ClientID: "alipay", ClientSecret: "secret"}},
		{OAuthProviderDouyin, &OAuthConfig{ClientID: "douyin", ClientSecret: "secret"}},
	}
	for i := range registrations {
		manager.RegisterProvider(registrations[i].provider, registrations[i].config)
	}

	// One entry per distinct provider registered above.
	if want, got := len(registrations), len(manager.entries); got != want {
		t.Errorf("Expected %d entries, got %d", want, got)
	}

	// Each provider stores its instance in its own field of the entry.
	if manager.entries[OAuthProviderGoogle].google == nil {
		t.Error("Google provider instance not created")
	}
	if manager.entries[OAuthProviderWeChat].wechat == nil {
		t.Error("WeChat provider instance not created")
	}
	if manager.entries[OAuthProviderQQ].qq == nil {
		t.Error("QQ provider instance not created")
	}
	if manager.entries[OAuthProviderGitHub].github == nil {
		t.Error("GitHub provider instance not created")
	}
	if manager.entries[OAuthProviderAlipay].alipay == nil {
		t.Error("Alipay provider instance not created")
	}
	if manager.entries[OAuthProviderDouyin].douyin == nil {
		t.Error("Douyin provider instance not created")
	}
}
// TestOAuthProviderConstants checks that none of the listed OAuthProvider
// constants converts to an empty string.
func TestOAuthProviderConstants(t *testing.T) {
	for _, provider := range []OAuthProvider{
		OAuthProviderWeChat,
		OAuthProviderQQ,
		OAuthProviderWeibo,
		OAuthProviderGoogle,
		OAuthProviderFacebook,
		OAuthProviderTwitter,
		OAuthProviderGitHub,
		OAuthProviderAlipay,
		OAuthProviderDouyin,
	} {
		if name := string(provider); name != "" {
			continue
		}
		t.Errorf("OAuthProvider constant %v has empty string value", provider)
	}
}
// TestOAuthUser_Struct populates every OAuthUser field used in this file and
// reads back Provider and OpenID to confirm they hold the assigned values.
func TestOAuthUser_Struct(t *testing.T) {
	extra := map[string]interface{}{
		"custom_field": "value",
	}
	u := &OAuthUser{
		Provider: OAuthProviderGoogle,
		OpenID:   "12345",
		UnionID:  "union-123",
		Nickname: "Test User",
		Avatar:   "https://example.com/avatar.jpg",
		Gender:   "male",
		Email:    "test@example.com",
		Phone:    "+1234567890",
		Extra:    extra,
	}

	if got := u.Provider; got != OAuthProviderGoogle {
		t.Errorf("Provider = %s, want google", got)
	}
	if got := u.OpenID; got != "12345" {
		t.Errorf("OpenID = %s, want 12345", got)
	}
}
// TestOAuthToken_Struct populates an OAuthToken and reads back AccessToken and
// ExpiresIn to confirm they hold the assigned values.
func TestOAuthToken_Struct(t *testing.T) {
	tok := &OAuthToken{
		AccessToken:  "access-123",
		RefreshToken: "refresh-456",
		ExpiresIn:    3600,
		TokenType:    "Bearer",
		OpenID:       "openid-789",
	}

	if got := tok.AccessToken; got != "access-123" {
		t.Errorf("AccessToken = %s, want access-123", got)
	}
	if got := tok.ExpiresIn; got != 3600 {
		t.Errorf("ExpiresIn = %d, want 3600", got)
	}
}
// TestOAuthConfig_Struct populates an OAuthConfig and reads back ClientID to
// confirm it holds the assigned value.
func TestOAuthConfig_Struct(t *testing.T) {
	cfg := &OAuthConfig{
		ClientID:     "client-id",
		ClientSecret: "client-secret",
		RedirectURI:  "https://example.com/callback",
		Scope:        "openid email",
		AuthURL:      "https://example.com/auth",
		TokenURL:     "https://example.com/token",
		UserInfoURL:  "https://example.com/userinfo",
	}

	if got := cfg.ClientID; got != "client-id" {
		t.Errorf("ClientID = %s, want client-id", got)
	}
}
// TestDefaultOAuthManager_ValidateToken_ContextCancellation is a placeholder
// for a context-cancellation test of ValidateToken.
//
// As called throughout this file, ValidateToken takes only the token string,
// so there is no context to cancel. The test registers a provider with a
// minimal config (the only behaviour it exercises) and then skips explicitly,
// so the missing coverage is visible in test output instead of being reported
// as a pass.
func TestDefaultOAuthManager_ValidateToken_ContextCancellation(t *testing.T) {
	m := NewOAuthManager()

	// Registration with a partial config is the only code path run here.
	m.RegisterProvider(OAuthProviderGoogle, &OAuthConfig{
		ClientID:     "test",
		ClientSecret: "test",
		RedirectURI:  "http://localhost",
	})

	t.Skip("ValidateToken does not accept a context.Context; cancellation cannot be tested without changing the implementation")
}
// TestOAuthManager_Integration groups per-provider subtests.
//
// The Google subtest starts stub token and user-info servers, registers Google
// with their URLs, and only verifies that the manager entry and its google
// instance exist; it does not call ExchangeCode or GetUserInfo (the original
// author notes that the implementation uses Google's real endpoints).
//
// The remaining subtests register one provider each and check that GetAuthURL
// succeeds with a non-empty URL; the GitHub subtest additionally requires the
// URL to contain "github.com".
func TestOAuthManager_Integration(t *testing.T) {
	t.Run("Google ExchangeCode and GetUserInfo", func(t *testing.T) {
		// serveJSON starts a stub server that answers every request with the
		// given body and a JSON content type.
		serveJSON := func(body string) *httptest.Server {
			return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
				w.Header().Set("Content-Type", "application/json")
				w.Write([]byte(body))
			}))
		}

		// Stub token endpoint.
		tokenServer := serveJSON(`{
"access_token": "test-access-token",
"refresh_token": "test-refresh-token",
"expires_in": 3600,
"token_type": "Bearer"
}`)
		defer tokenServer.Close()

		// Stub user-info endpoint.
		userInfoServer := serveJSON(`{
"id": "12345",
"name": "Test User",
"email": "test@example.com",
"picture": "https://example.com/avatar.jpg"
}`)
		defer userInfoServer.Close()

		manager := NewOAuthManager()
		manager.RegisterProvider(OAuthProviderGoogle, &OAuthConfig{
			ClientID:     "test-client-id",
			ClientSecret: "test-client-secret",
			RedirectURI:  "http://localhost/callback",
			Scope:        "openid email",
			AuthURL:      tokenServer.URL + "/auth",
			TokenURL:     tokenServer.URL + "/token",
			UserInfoURL:  userInfoServer.URL,
		})

		// Only the registration result is inspected; no HTTP call is made to
		// the stub servers from this subtest.
		if entry, found := manager.entries[OAuthProviderGoogle]; !found || entry.google == nil {
			t.Fatal("Google provider not configured properly")
		}
	})

	t.Run("GitHub GetAuthURL", func(t *testing.T) {
		manager := NewOAuthManager()
		manager.RegisterProvider(OAuthProviderGitHub, &OAuthConfig{
			ClientID:     "github-client-id",
			ClientSecret: "github-secret",
			RedirectURI:  "http://localhost/callback",
			Scope:        "user:email",
		})

		authURL, err := manager.GetAuthURL(OAuthProviderGitHub, "test-state")
		if err != nil {
			t.Fatalf("GetAuthURL() error = %v", err)
		}
		if authURL == "" {
			t.Error("GetAuthURL() returned empty URL")
		}
		if !strings.Contains(authURL, "github.com") {
			t.Errorf("GetAuthURL() URL should contain github.com, got %s", authURL)
		}
	})

	// The providers below share one check: GetAuthURL must succeed and return
	// a non-empty URL. Subtest names and order match the original sequence.
	authURLCases := []struct {
		name     string
		provider OAuthProvider
		config   *OAuthConfig
	}{
		{"WeChat GetAuthURL", OAuthProviderWeChat, &OAuthConfig{
			ClientID:     "wechat-appid",
			ClientSecret: "wechat-secret",
			RedirectURI:  "http://localhost/callback",
			Scope:        "snsapi_login",
		}},
		{"QQ GetAuthURL", OAuthProviderQQ, &OAuthConfig{
			ClientID:     "qq-appid",
			ClientSecret: "qq-secret",
			RedirectURI:  "http://localhost/callback",
			Scope:        "get_user_info",
		}},
		{"Alipay GetAuthURL", OAuthProviderAlipay, &OAuthConfig{
			ClientID:     "alipay-appid",
			ClientSecret: "alipay-private-key",
			RedirectURI:  "http://localhost/callback",
			Scope:        "auth_user",
		}},
		{"Douyin GetAuthURL", OAuthProviderDouyin, &OAuthConfig{
			ClientID:     "douyin-client-key",
			ClientSecret: "douyin-secret",
			RedirectURI:  "http://localhost/callback",
			Scope:        "user_info",
		}},
	}
	for _, tc := range authURLCases {
		tc := tc // per-iteration copy for toolchains before Go 1.22
		t.Run(tc.name, func(t *testing.T) {
			manager := NewOAuthManager()
			manager.RegisterProvider(tc.provider, tc.config)

			authURL, err := manager.GetAuthURL(tc.provider, "test-state")
			if err != nil {
				t.Fatalf("GetAuthURL() error = %v", err)
			}
			if authURL == "" {
				t.Error("GetAuthURL() returned empty URL")
			}
		})
	}
}
// TestOAuthManager_FallbackURL registers Twitter, which this test treats as a
// provider without a dedicated implementation, and checks that the URL from
// GetAuthURL carries the client_id, a redirect_uri parameter, and the state.
func TestOAuthManager_FallbackURL(t *testing.T) {
	manager := NewOAuthManager()
	manager.RegisterProvider(OAuthProviderTwitter, &OAuthConfig{
		ClientID:     "twitter-client-id",
		ClientSecret: "twitter-secret",
		RedirectURI:  "http://localhost/callback",
		Scope:        "tweet.read",
		AuthURL:      "https://twitter.com/i/oauth2/authorize",
	})

	authURL, err := manager.GetAuthURL(OAuthProviderTwitter, "test-state")
	if err != nil {
		t.Fatalf("GetAuthURL() error = %v", err)
	}

	// has reports whether the generated URL contains the given fragment.
	has := func(fragment string) bool { return strings.Contains(authURL, fragment) }

	if !has("client_id=twitter-client-id") {
		t.Errorf("Fallback URL should contain client_id, got %s", authURL)
	}
	if !has("redirect_uri=") {
		t.Errorf("Fallback URL should contain redirect_uri, got %s", authURL)
	}
	if !has("state=test-state") {
		t.Errorf("Fallback URL should contain state, got %s", authURL)
	}
}
// TestOAuthManager_ExchangeCode_Errors calls ExchangeCode for a registered
// Google provider with no stub server in place. An error is the anticipated
// outcome (presumably the call reaches for a real endpoint and fails), but a
// success is only logged, never treated as a test failure.
func TestOAuthManager_ExchangeCode_Errors(t *testing.T) {
	manager := NewOAuthManager()
	manager.RegisterProvider(OAuthProviderGoogle, &OAuthConfig{
		ClientID:     "test-id",
		ClientSecret: "test-secret",
		RedirectURI:  "http://localhost",
	})

	// The outcome is environment-dependent, so the test is deliberately lenient.
	if _, exchangeErr := manager.ExchangeCode(context.Background(), OAuthProviderGoogle, "test-code"); exchangeErr == nil {
		t.Log("ExchangeCode() unexpectedly succeeded - real network may be available")
	}
}
// TestOAuthManager_GetUserInfo_Errors calls GetUserInfo for a registered
// Google provider with no stub server in place. An error is the anticipated
// outcome (presumably the call reaches for a real endpoint and fails), but a
// success is only logged, never treated as a test failure.
func TestOAuthManager_GetUserInfo_Errors(t *testing.T) {
	manager := NewOAuthManager()
	manager.RegisterProvider(OAuthProviderGoogle, &OAuthConfig{
		ClientID:     "test-id",
		ClientSecret: "test-secret",
		RedirectURI:  "http://localhost",
	})

	accessToken := &OAuthToken{AccessToken: "test-token"}

	// The outcome is environment-dependent, so the test is deliberately lenient.
	if _, infoErr := manager.GetUserInfo(context.Background(), OAuthProviderGoogle, accessToken); infoErr == nil {
		t.Log("GetUserInfo() unexpectedly succeeded - real network may be available")
	}
}
// TestOAuthManager_ValidateToken_WithProviders calls ValidateToken with an
// arbitrary token while Google is registered. The token must not be reported
// as valid; any returned error is logged but does not fail the test.
func TestOAuthManager_ValidateToken_WithProviders(t *testing.T) {
	manager := NewOAuthManager()
	manager.RegisterProvider(OAuthProviderGoogle, &OAuthConfig{
		ClientID:     "test-id",
		ClientSecret: "test-secret",
		RedirectURI:  "http://localhost",
	})

	// Presumably validation goes through a user-info lookup that fails for
	// this made-up token; only the boolean result is asserted.
	ok, validateErr := manager.ValidateToken("some-token")
	if ok {
		t.Error("ValidateToken() should return false for invalid token")
	}
	// An error here is informational only.
	if validateErr != nil {
		t.Logf("ValidateToken() returned error: %v", validateErr)
	}
}
// TestOAuthManager_ValidateTokenWithProvider_WithConfig calls
// ValidateTokenWithProvider for a registered Google provider with an arbitrary
// token. The token must not be reported as valid; a nil error is logged but
// does not fail the test.
func TestOAuthManager_ValidateTokenWithProvider_WithConfig(t *testing.T) {
	manager := NewOAuthManager()
	manager.RegisterProvider(OAuthProviderGoogle, &OAuthConfig{
		ClientID:     "test-id",
		ClientSecret: "test-secret",
		RedirectURI:  "http://localhost",
	})

	// Presumably validation goes through a user-info lookup that fails for
	// this made-up token; only the boolean result is asserted.
	ok, validateErr := manager.ValidateTokenWithProvider(OAuthProviderGoogle, "some-token")
	if ok {
		t.Error("ValidateTokenWithProvider() should return false for invalid token")
	}
	// A missing error is informational only.
	if validateErr == nil {
		t.Log("ValidateTokenWithProvider() returned no error - graceful failure")
	}
}