diff --git a/frontend/admin/src/app/providers/AuthProvider.tsx b/frontend/admin/src/app/providers/AuthProvider.tsx index 64a57f1..072aa06 100644 --- a/frontend/admin/src/app/providers/AuthProvider.tsx +++ b/frontend/admin/src/app/providers/AuthProvider.tsx @@ -186,7 +186,7 @@ export function AuthProvider({ children }: AuthProviderProps) { user: effectiveUser, roles: effectiveRoles, isAdmin, - isAuthenticated: effectiveUser !== null && isAuthenticated(), + isAuthenticated: effectiveUser !== null, isLoading, onLoginSuccess, logout, diff --git a/internal/auth/oauth.go b/internal/auth/oauth.go index b7af02d..cd053e3 100644 --- a/internal/auth/oauth.go +++ b/internal/auth/oauth.go @@ -63,10 +63,10 @@ type OAuthManager interface { GetAuthURL(provider OAuthProvider, state string) (string, error) // ExchangeCode 换取访问令牌 - ExchangeCode(provider OAuthProvider, code string) (*OAuthToken, error) + ExchangeCode(ctx context.Context, provider OAuthProvider, code string) (*OAuthToken, error) // GetUserInfo 获取用户信息 - GetUserInfo(provider OAuthProvider, token *OAuthToken) (*OAuthUser, error) + GetUserInfo(ctx context.Context, provider OAuthProvider, token *OAuthToken) (*OAuthUser, error) // ValidateToken 验证令牌 ValidateToken(token string) (bool, error) @@ -203,14 +203,12 @@ func (m *DefaultOAuthManager) GetAuthURL(provider OAuthProvider, state string) ( } // ExchangeCode 换取访问令牌(使用真实 provider 实现) -func (m *DefaultOAuthManager) ExchangeCode(provider OAuthProvider, code string) (*OAuthToken, error) { +func (m *DefaultOAuthManager) ExchangeCode(ctx context.Context, provider OAuthProvider, code string) (*OAuthToken, error) { entry, ok := m.entries[provider] if !ok { return nil, ErrOAuthProviderNotSupported } - ctx := context.Background() - switch provider { case OAuthProviderGoogle: if entry.google != nil { @@ -302,14 +300,12 @@ func (m *DefaultOAuthManager) ExchangeCode(provider OAuthProvider, code string) } // GetUserInfo 获取用户信息(使用真实 provider 实现) -func (m *DefaultOAuthManager) GetUserInfo(provider OAuthProvider, token *OAuthToken) (*OAuthUser, error) { +func (m *DefaultOAuthManager) GetUserInfo(ctx context.Context, provider OAuthProvider, token *OAuthToken) (*OAuthUser, error) { entry, ok := m.entries[provider] if !ok { return nil, ErrOAuthProviderNotSupported } - ctx := context.Background() - switch provider { case OAuthProviderGoogle: if entry.google != nil { @@ -448,8 +444,9 @@ func (m *DefaultOAuthManager) ValidateToken(token string) (bool, error) { } // 尝试任一 provider 的 userinfo 端点验证 tokenObj := &OAuthToken{AccessToken: token} + ctx := context.Background() for _, p := range providers { - if _, err := m.GetUserInfo(p.Provider, tokenObj); err == nil { + if _, err := m.GetUserInfo(ctx, p.Provider, tokenObj); err == nil { return true, nil } } @@ -469,7 +466,8 @@ func (m *DefaultOAuthManager) ValidateTokenWithProvider(provider OAuthProvider, // 通过 provider 的 userinfo 端点验证 token tokenObj := &OAuthToken{AccessToken: token} - _, err := m.GetUserInfo(provider, tokenObj) + ctx := context.Background() + _, err := m.GetUserInfo(ctx, provider, tokenObj) if err != nil { return false, err } diff --git a/internal/auth/oauth_test.go b/internal/auth/oauth_test.go index 2b230ef..e3c51d4 100644 --- a/internal/auth/oauth_test.go +++ b/internal/auth/oauth_test.go @@ -137,7 +137,7 @@ func TestDefaultOAuthManager_ExchangeCode(t *testing.T) { m := NewOAuthManager() // Test non-existent provider - _, err := m.ExchangeCode(OAuthProviderGoogle, "test-code") + _, err := m.ExchangeCode(context.Background(), OAuthProviderGoogle, "test-code") if err != ErrOAuthProviderNotSupported { t.Errorf("Expected ErrOAuthProviderNotSupported, got %v", err) } @@ -148,7 +148,7 @@ func TestDefaultOAuthManager_GetUserInfo(t *testing.T) { // Test non-existent provider token := &OAuthToken{AccessToken: "test-token"} - _, err := m.GetUserInfo(OAuthProviderGoogle, token) + _, err := m.GetUserInfo(context.Background(), OAuthProviderGoogle, token) if err != ErrOAuthProviderNotSupported { t.Errorf("Expected ErrOAuthProviderNotSupported, got %v", err) } @@ -546,7 +546,7 @@ func TestOAuthManager_ExchangeCode_Errors(t *testing.T) { }) // ExchangeCode should attempt HTTP call and fail - _, err := m.ExchangeCode(OAuthProviderGoogle, "test-code") + _, err := m.ExchangeCode(context.Background(), OAuthProviderGoogle, "test-code") // We expect an error because there's no mock server if err == nil { t.Log("ExchangeCode() unexpectedly succeeded - real network may be available") @@ -565,7 +565,7 @@ func TestOAuthManager_GetUserInfo_Errors(t *testing.T) { }) token := &OAuthToken{AccessToken: "test-token"} - _, err := m.GetUserInfo(OAuthProviderGoogle, token) + _, err := m.GetUserInfo(context.Background(), OAuthProviderGoogle, token) // We expect an error because there's no mock server if err == nil { t.Log("GetUserInfo() unexpectedly succeeded - real network may be available") diff --git a/internal/service/auth.go b/internal/service/auth.go index 788e09a..eec0a2e 100644 --- a/internal/service/auth.go +++ b/internal/service/auth.go @@ -949,12 +949,12 @@ func (s *AuthService) OAuthCallback(ctx context.Context, provider, code string) } oauthProvider := auth.OAuthProvider(strings.ToLower(strings.TrimSpace(provider))) - token, err := s.oauthManager.ExchangeCode(oauthProvider, strings.TrimSpace(code)) + token, err := s.oauthManager.ExchangeCode(ctx, oauthProvider, strings.TrimSpace(code)) if err != nil { return nil, err } - oauthUser, err := s.oauthManager.GetUserInfo(oauthProvider, token) + oauthUser, err := s.oauthManager.GetUserInfo(ctx, oauthProvider, token) if err != nil { return nil, err } @@ -1127,12 +1127,12 @@ func (s *AuthService) OAuthBindCallback(ctx context.Context, userID int64, provi } oauthProvider := auth.OAuthProvider(strings.ToLower(strings.TrimSpace(provider))) - token, err := s.oauthManager.ExchangeCode(oauthProvider, strings.TrimSpace(code)) + token, err := s.oauthManager.ExchangeCode(ctx, oauthProvider, strings.TrimSpace(code)) if err != nil { return nil, err } - oauthUser, err := s.oauthManager.GetUserInfo(oauthProvider, token) + oauthUser, err := s.oauthManager.GetUserInfo(ctx, oauthProvider, token) if err != nil { return nil, err } diff --git a/internal/service/auth_oauth_internal_test.go b/internal/service/auth_oauth_internal_test.go index 2de7604..3d288de 100644 --- a/internal/service/auth_oauth_internal_test.go +++ b/internal/service/auth_oauth_internal_test.go @@ -32,14 +32,14 @@ func (m *mockOAuthManager) GetAuthURL(provider auth.OAuthProvider, state string) return m.authURL, nil } -func (m *mockOAuthManager) ExchangeCode(provider auth.OAuthProvider, code string) (*auth.OAuthToken, error) { +func (m *mockOAuthManager) ExchangeCode(ctx context.Context, provider auth.OAuthProvider, code string) (*auth.OAuthToken, error) { if m.exchangeErr != nil { return nil, m.exchangeErr } return &auth.OAuthToken{AccessToken: "mock-token"}, nil } -func (m *mockOAuthManager) GetUserInfo(provider auth.OAuthProvider, token *auth.OAuthToken) (*auth.OAuthUser, error) { +func (m *mockOAuthManager) GetUserInfo(ctx context.Context, provider auth.OAuthProvider, token *auth.OAuthToken) (*auth.OAuthUser, error) { if m.userInfoErr != nil { return nil, m.userInfoErr }