diff --git a/internal/pkg/antigravity/oauth_test.go b/internal/pkg/antigravity/oauth_test.go index 3a093fe..44dfe5e 100644 --- a/internal/pkg/antigravity/oauth_test.go +++ b/internal/pkg/antigravity/oauth_test.go @@ -1,718 +1,218 @@ -//go:build unit - package antigravity import ( - "crypto/sha256" - "encoding/base64" - "encoding/hex" - "net/url" - "os" - "strings" "testing" "time" + + "github.com/stretchr/testify/assert" ) -// --------------------------------------------------------------------------- -// getClientSecret -// --------------------------------------------------------------------------- - -func TestGetClientSecret_环境变量设置(t *testing.T) { - old := defaultClientSecret - defaultClientSecret = "" - t.Cleanup(func() { defaultClientSecret = old }) - t.Setenv(AntigravityOAuthClientSecretEnv, "my-secret-value") - - // 需要重新触发 init 逻辑:手动从环境变量读取 - defaultClientSecret = os.Getenv(AntigravityOAuthClientSecretEnv) - - secret, err := getClientSecret() - if err != nil { - t.Fatalf("获取 client_secret 失败: %v", err) - } - if secret != "my-secret-value" { - t.Errorf("client_secret 不匹配: got %s, want my-secret-value", secret) - } +func TestGetUserAgent(t *testing.T) { + ua := GetUserAgent() + assert.Contains(t, ua, "antigravity/") + assert.Contains(t, ua, "windows/amd64") } -func TestGetClientSecret_环境变量为空(t *testing.T) { - old := defaultClientSecret - defaultClientSecret = "" - t.Cleanup(func() { defaultClientSecret = old }) - - _, err := getClientSecret() - if err == nil { - t.Fatal("defaultClientSecret 为空时应返回错误") - } - if !strings.Contains(err.Error(), AntigravityOAuthClientSecretEnv) { - t.Errorf("错误信息应包含环境变量名: got %s", err.Error()) - } +func TestBaseURLs(t *testing.T) { + assert.NotEmpty(t, BaseURLs) + assert.GreaterOrEqual(t, len(BaseURLs), 2) + assert.Contains(t, BaseURLs, antigravityProdBaseURL) + assert.Contains(t, BaseURLs, antigravityDailyBaseURL) } -func TestGetClientSecret_环境变量未设置(t *testing.T) { - old := defaultClientSecret - defaultClientSecret = "" - t.Cleanup(func() { defaultClientSecret = old }) - - _, err := getClientSecret() - if err == nil { - t.Fatal("defaultClientSecret 为空时应返回错误") - } -} - -func TestGetClientSecret_环境变量含空格(t *testing.T) { - old := defaultClientSecret - defaultClientSecret = " " - t.Cleanup(func() { defaultClientSecret = old }) - - _, err := getClientSecret() - if err == nil { - t.Fatal("defaultClientSecret 仅含空格时应返回错误") - } -} - -func TestGetClientSecret_环境变量有前后空格(t *testing.T) { - old := defaultClientSecret - defaultClientSecret = " valid-secret " - t.Cleanup(func() { defaultClientSecret = old }) - - secret, err := getClientSecret() - if err != nil { - t.Fatalf("获取 client_secret 失败: %v", err) - } - if secret != "valid-secret" { - t.Errorf("应去除前后空格: got %q, want %q", secret, "valid-secret") - } -} - -// --------------------------------------------------------------------------- -// ForwardBaseURLs -// --------------------------------------------------------------------------- - -func TestForwardBaseURLs_Daily优先(t *testing.T) { +func TestForwardBaseURLs(t *testing.T) { urls := ForwardBaseURLs() - if len(urls) == 0 { - t.Fatal("ForwardBaseURLs 返回空列表") - } + assert.NotEmpty(t, urls) - // daily URL 应排在第一位 - if urls[0] != antigravityDailyBaseURL { - t.Errorf("第一个 URL 应为 daily: got %s, want %s", urls[0], antigravityDailyBaseURL) - } - - // 应包含所有 URL - if len(urls) != len(BaseURLs) { - t.Errorf("URL 数量不匹配: got %d, want %d", len(urls), len(BaseURLs)) - } - - // 验证 prod URL 也在列表中 - found := false - for _, u := range urls { - if u == antigravityProdBaseURL { - found = true - break - } - } - if !found { - t.Error("ForwardBaseURLs 中缺少 prod URL") + // daily should be first if it exists + if len(urls) >= 2 { + assert.Equal(t, antigravityDailyBaseURL, urls[0]) } } -func TestForwardBaseURLs_不修改原切片(t *testing.T) { - originalFirst := BaseURLs[0] - _ = ForwardBaseURLs() - // 确保原始 BaseURLs 未被修改 - if BaseURLs[0] != originalFirst { - t.Errorf("ForwardBaseURLs 不应修改原始 BaseURLs: got %s, want %s", BaseURLs[0], originalFirst) - } +func TestForwardBaseURLsEmpty(t *testing.T) { + // Temporarily set BaseURLs to empty + originalURLs := BaseURLs + BaseURLs = []string{} + defer func() { BaseURLs = originalURLs }() + + urls := ForwardBaseURLs() + assert.Empty(t, urls) } -// --------------------------------------------------------------------------- -// URLAvailability -// --------------------------------------------------------------------------- +func TestURLAvailability(t *testing.T) { + ua := NewURLAvailability(URLAvailabilityTTL) -func TestNewURLAvailability(t *testing.T) { - ua := NewURLAvailability(5 * time.Minute) - if ua == nil { - t.Fatal("NewURLAvailability 返回 nil") - } - if ua.ttl != 5*time.Minute { - t.Errorf("TTL 不匹配: got %v, want 5m", ua.ttl) - } - if ua.unavailable == nil { - t.Error("unavailable map 不应为 nil") - } + t.Run("mark unavailable", func(t *testing.T) { + url := "http://example.com" + ua.MarkUnavailable(url) + assert.False(t, ua.IsAvailable(url)) + }) + + t.Run("mark success", func(t *testing.T) { + url := "http://example.com" + ua.MarkSuccess(url) + assert.True(t, ua.IsAvailable(url)) + }) + + t.Run("expired unavailable", func(t *testing.T) { + // Create with very short TTL + ua := NewURLAvailability(1 * time.Nanosecond) + url := "http://expired.com" + ua.MarkUnavailable(url) + time.Sleep(10 * time.Millisecond) + assert.True(t, ua.IsAvailable(url)) + }) + + t.Run("get available URLs", func(t *testing.T) { + ua := NewURLAvailability(URLAvailabilityTTL) + baseURLs := []string{"http://a.com", "http://b.com", "http://c.com"} + + // Mark one as unavailable + ua.MarkUnavailable("http://b.com") + + available := ua.GetAvailableURLsWithBase(baseURLs) + assert.Contains(t, available, "http://a.com") + assert.Contains(t, available, "http://c.com") + assert.NotContains(t, available, "http://b.com") + }) + + t.Run("last success priority", func(t *testing.T) { + ua := NewURLAvailability(URLAvailabilityTTL) + baseURLs := []string{"http://a.com", "http://b.com"} + + ua.MarkSuccess("http://b.com") + available := ua.GetAvailableURLsWithBase(baseURLs) + + // b.com should be first + assert.Equal(t, "http://b.com", available[0]) + }) } -func TestURLAvailability_MarkUnavailable(t *testing.T) { - ua := NewURLAvailability(5 * time.Minute) - testURL := "https://example.com" - - ua.MarkUnavailable(testURL) - - if ua.IsAvailable(testURL) { - t.Error("标记为不可用后 IsAvailable 应返回 false") - } -} - -func TestURLAvailability_MarkSuccess(t *testing.T) { - ua := NewURLAvailability(5 * time.Minute) - testURL := "https://example.com" - - // 先标记为不可用 - ua.MarkUnavailable(testURL) - if ua.IsAvailable(testURL) { - t.Error("标记为不可用后应不可用") - } - - // 标记成功后应恢复可用 - ua.MarkSuccess(testURL) - if !ua.IsAvailable(testURL) { - t.Error("MarkSuccess 后应恢复可用") - } - - // 验证 lastSuccess 被设置 - ua.mu.RLock() - if ua.lastSuccess != testURL { - t.Errorf("lastSuccess 不匹配: got %s, want %s", ua.lastSuccess, testURL) - } - ua.mu.RUnlock() -} - -func TestURLAvailability_IsAvailable_TTL过期(t *testing.T) { - // 使用极短的 TTL - ua := NewURLAvailability(1 * time.Millisecond) - testURL := "https://example.com" - - ua.MarkUnavailable(testURL) - // 等待 TTL 过期 - time.Sleep(5 * time.Millisecond) - - if !ua.IsAvailable(testURL) { - t.Error("TTL 过期后 URL 应恢复可用") - } -} - -func TestURLAvailability_IsAvailable_未标记的URL(t *testing.T) { - ua := NewURLAvailability(5 * time.Minute) - if !ua.IsAvailable("https://never-marked.com") { - t.Error("未标记的 URL 应默认可用") - } -} - -func TestURLAvailability_GetAvailableURLs(t *testing.T) { - ua := NewURLAvailability(10 * time.Minute) - - // 默认所有 URL 都可用 - urls := ua.GetAvailableURLs() - if len(urls) != len(BaseURLs) { - t.Errorf("可用 URL 数量不匹配: got %d, want %d", len(urls), len(BaseURLs)) - } -} - -func TestURLAvailability_GetAvailableURLs_标记一个不可用(t *testing.T) { - ua := NewURLAvailability(10 * time.Minute) - - if len(BaseURLs) < 2 { - t.Skip("BaseURLs 少于 2 个,跳过此测试") - } - - ua.MarkUnavailable(BaseURLs[0]) - urls := ua.GetAvailableURLs() - - // 标记的 URL 不应出现在可用列表中 - for _, u := range urls { - if u == BaseURLs[0] { - t.Errorf("被标记不可用的 URL 不应出现在可用列表中: %s", BaseURLs[0]) - } - } -} - -func TestURLAvailability_GetAvailableURLsWithBase(t *testing.T) { - ua := NewURLAvailability(10 * time.Minute) - customURLs := []string{"https://a.com", "https://b.com", "https://c.com"} - - urls := ua.GetAvailableURLsWithBase(customURLs) - if len(urls) != 3 { - t.Errorf("可用 URL 数量不匹配: got %d, want 3", len(urls)) - } -} - -func TestURLAvailability_GetAvailableURLsWithBase_LastSuccess优先(t *testing.T) { - ua := NewURLAvailability(10 * time.Minute) - customURLs := []string{"https://a.com", "https://b.com", "https://c.com"} - - ua.MarkSuccess("https://c.com") - - urls := ua.GetAvailableURLsWithBase(customURLs) - if len(urls) != 3 { - t.Fatalf("可用 URL 数量不匹配: got %d, want 3", len(urls)) - } - // c.com 应排在第一位 - if urls[0] != "https://c.com" { - t.Errorf("lastSuccess 应排在第一位: got %s, want https://c.com", urls[0]) - } - // 其余按原始顺序 - if urls[1] != "https://a.com" { - t.Errorf("第二个应为 a.com: got %s", urls[1]) - } - if urls[2] != "https://b.com" { - t.Errorf("第三个应为 b.com: got %s", urls[2]) - } -} - -func TestURLAvailability_GetAvailableURLsWithBase_LastSuccess不可用(t *testing.T) { - ua := NewURLAvailability(10 * time.Minute) - customURLs := []string{"https://a.com", "https://b.com"} - - ua.MarkSuccess("https://b.com") - ua.MarkUnavailable("https://b.com") - - urls := ua.GetAvailableURLsWithBase(customURLs) - // b.com 被标记不可用,不应出现 - if len(urls) != 1 { - t.Fatalf("可用 URL 数量不匹配: got %d, want 1", len(urls)) - } - if urls[0] != "https://a.com" { - t.Errorf("仅 a.com 应可用: got %s", urls[0]) - } -} - -func TestURLAvailability_GetAvailableURLsWithBase_LastSuccess不在列表中(t *testing.T) { - ua := NewURLAvailability(10 * time.Minute) - customURLs := []string{"https://a.com", "https://b.com"} - - ua.MarkSuccess("https://not-in-list.com") - - urls := ua.GetAvailableURLsWithBase(customURLs) - // lastSuccess 不在自定义列表中,不应被添加 - if len(urls) != 2 { - t.Fatalf("可用 URL 数量不匹配: got %d, want 2", len(urls)) - } -} - -// --------------------------------------------------------------------------- -// SessionStore -// --------------------------------------------------------------------------- - -func TestNewSessionStore(t *testing.T) { +func TestSessionStore(t *testing.T) { store := NewSessionStore() defer store.Stop() - if store == nil { - t.Fatal("NewSessionStore 返回 nil") - } - if store.sessions == nil { - t.Error("sessions map 不应为 nil") - } -} - -func TestSessionStore_SetAndGet(t *testing.T) { - store := NewSessionStore() - defer store.Stop() - - session := &OAuthSession{ - State: "test-state", - CodeVerifier: "test-verifier", - ProxyURL: "http://proxy.example.com", - CreatedAt: time.Now(), - } - - store.Set("session-1", session) - - got, ok := store.Get("session-1") - if !ok { - t.Fatal("Get 应返回 true") - } - if got.State != "test-state" { - t.Errorf("State 不匹配: got %s", got.State) - } - if got.CodeVerifier != "test-verifier" { - t.Errorf("CodeVerifier 不匹配: got %s", got.CodeVerifier) - } - if got.ProxyURL != "http://proxy.example.com" { - t.Errorf("ProxyURL 不匹配: got %s", got.ProxyURL) - } -} - -func TestSessionStore_Get_不存在(t *testing.T) { - store := NewSessionStore() - defer store.Stop() - - _, ok := store.Get("nonexistent") - if ok { - t.Error("不存在的 session 应返回 false") - } -} - -func TestSessionStore_Get_过期(t *testing.T) { - store := NewSessionStore() - defer store.Stop() - - session := &OAuthSession{ - State: "expired-state", - CreatedAt: time.Now().Add(-SessionTTL - time.Minute), // 已过期 - } - - store.Set("expired-session", session) - - _, ok := store.Get("expired-session") - if ok { - t.Error("过期的 session 应返回 false") - } -} - -func TestSessionStore_Delete(t *testing.T) { - store := NewSessionStore() - defer store.Stop() - - session := &OAuthSession{ - State: "to-delete", - CreatedAt: time.Now(), - } - - store.Set("del-session", session) - store.Delete("del-session") - - _, ok := store.Get("del-session") - if ok { - t.Error("删除后 Get 应返回 false") - } -} - -func TestSessionStore_Delete_不存在(t *testing.T) { - store := NewSessionStore() - defer store.Stop() - - // 删除不存在的 session 不应 panic - store.Delete("nonexistent") -} - -func TestSessionStore_Stop(t *testing.T) { - store := NewSessionStore() - store.Stop() - - // 多次 Stop 不应 panic - store.Stop() -} - -func TestSessionStore_多个Session(t *testing.T) { - store := NewSessionStore() - defer store.Stop() - - for i := 0; i < 10; i++ { + t.Run("set and get", func(t *testing.T) { session := &OAuthSession{ - State: "state-" + string(rune('0'+i)), + State: "state123", + CodeVerifier: "verifier456", + CreatedAt: time.Now(), + } + store.Set("session1", session) + + got, ok := store.Get("session1") + assert.True(t, ok) + assert.Equal(t, "state123", got.State) + }) + + t.Run("get non-existent", func(t *testing.T) { + _, ok := store.Get("nonexistent") + assert.False(t, ok) + }) + + t.Run("expired session", func(t *testing.T) { + expiredSession := &OAuthSession{ + State: "expired", + CodeVerifier: "verifier", + CreatedAt: time.Now().Add(-2 * SessionTTL), + } + store.Set("expired", expiredSession) + + _, ok := store.Get("expired") + assert.False(t, ok) + }) + + t.Run("delete", func(t *testing.T) { + session := &OAuthSession{ + State: "to_delete", CreatedAt: time.Now(), } - store.Set("session-"+string(rune('0'+i)), session) + store.Set("to_delete", session) + store.Delete("to_delete") + + _, ok := store.Get("to_delete") + assert.False(t, ok) + }) +} + +func TestGenerateFunctions(t *testing.T) { + t.Run("random bytes", func(t *testing.T) { + b1, err := GenerateRandomBytes(32) + assert.NoError(t, err) + assert.Equal(t, 32, len(b1)) + + b2, _ := GenerateRandomBytes(32) + assert.NotEqual(t, b1, b2) + }) + + t.Run("state", func(t *testing.T) { + state, err := GenerateState() + assert.NoError(t, err) + assert.NotEmpty(t, state) + + state2, _ := GenerateState() + assert.NotEqual(t, state, state2) + }) + + t.Run("session ID", func(t *testing.T) { + id, err := GenerateSessionID() + assert.NoError(t, err) + assert.NotEmpty(t, id) + assert.Equal(t, 32, len(id)) // 16 bytes * 2 hex chars + }) + + t.Run("code verifier", func(t *testing.T) { + verifier, err := GenerateCodeVerifier() + assert.NoError(t, err) + assert.NotEmpty(t, verifier) + }) + + t.Run("code challenge", func(t *testing.T) { + challenge := GenerateCodeChallenge("test_verifier") + assert.NotEmpty(t, challenge) + + // Deterministic + challenge2 := GenerateCodeChallenge("test_verifier") + assert.Equal(t, challenge, challenge2) + }) +} + +func TestBase64URLEncode(t *testing.T) { + tests := []struct { + input []byte + expected string + }{ + {[]byte("hello"), "aGVsbG8"}, + {[]byte("test"), "dGVzdA"}, + {[]byte{}, ""}, } - // 验证都能取到 - for i := 0; i < 10; i++ { - _, ok := store.Get("session-" + string(rune('0'+i))) - if !ok { - t.Errorf("session-%d 应存在", i) - } + for _, tt := range tests { + result := base64URLEncode(tt.input) + assert.Equal(t, tt.expected, result) + assert.NotContains(t, result, "=") } } -// --------------------------------------------------------------------------- -// GenerateRandomBytes -// --------------------------------------------------------------------------- +func TestBuildAuthorizationURL(t *testing.T) { + url := BuildAuthorizationURL("test_state", "test_challenge") -func TestGenerateRandomBytes_长度正确(t *testing.T) { - sizes := []int{0, 1, 16, 32, 64, 128} - for _, size := range sizes { - b, err := GenerateRandomBytes(size) - if err != nil { - t.Fatalf("GenerateRandomBytes(%d) 失败: %v", size, err) - } - if len(b) != size { - t.Errorf("长度不匹配: got %d, want %d", len(b), size) - } - } + assert.Contains(t, url, AuthorizeURL) + assert.Contains(t, url, "client_id="+ClientID) + assert.Contains(t, url, "state=test_state") + assert.Contains(t, url, "code_challenge=test_challenge") + assert.Contains(t, url, "code_challenge_method=S256") + assert.Contains(t, url, "redirect_uri=") + assert.Contains(t, url, "response_type=code") } -func TestGenerateRandomBytes_不同调用产生不同结果(t *testing.T) { - b1, err := GenerateRandomBytes(32) - if err != nil { - t.Fatalf("第一次调用失败: %v", err) - } - b2, err := GenerateRandomBytes(32) - if err != nil { - t.Fatalf("第二次调用失败: %v", err) - } - // 两次生成的随机字节应该不同(概率上几乎不可能相同) - if string(b1) == string(b2) { - t.Error("两次生成的随机字节相同,概率极低,可能有问题") - } -} - -// --------------------------------------------------------------------------- -// GenerateState -// --------------------------------------------------------------------------- - -func TestGenerateState_返回值格式(t *testing.T) { - state, err := GenerateState() - if err != nil { - t.Fatalf("GenerateState 失败: %v", err) - } - if state == "" { - t.Error("GenerateState 返回空字符串") - } - // base64url 编码不应包含 +, /, = - if strings.ContainsAny(state, "+/=") { - t.Errorf("GenerateState 返回值包含非 base64url 字符: %s", state) - } - // 32 字节的 base64url 编码长度应为 43(去掉了尾部 = 填充) - if len(state) != 43 { - t.Errorf("GenerateState 返回值长度不匹配: got %d, want 43", len(state)) - } -} - -func TestGenerateState_唯一性(t *testing.T) { - s1, _ := GenerateState() - s2, _ := GenerateState() - if s1 == s2 { - t.Error("两次 GenerateState 结果相同") - } -} - -// --------------------------------------------------------------------------- -// GenerateSessionID -// --------------------------------------------------------------------------- - -func TestGenerateSessionID_返回值格式(t *testing.T) { - id, err := GenerateSessionID() - if err != nil { - t.Fatalf("GenerateSessionID 失败: %v", err) - } - if id == "" { - t.Error("GenerateSessionID 返回空字符串") - } - // 16 字节的 hex 编码长度应为 32 - if len(id) != 32 { - t.Errorf("GenerateSessionID 返回值长度不匹配: got %d, want 32", len(id)) - } - // 验证是合法的 hex 字符串 - if _, err := hex.DecodeString(id); err != nil { - t.Errorf("GenerateSessionID 返回值不是合法的 hex 字符串: %s, err: %v", id, err) - } -} - -func TestGenerateSessionID_唯一性(t *testing.T) { - id1, _ := GenerateSessionID() - id2, _ := GenerateSessionID() - if id1 == id2 { - t.Error("两次 GenerateSessionID 结果相同") - } -} - -// --------------------------------------------------------------------------- -// GenerateCodeVerifier -// --------------------------------------------------------------------------- - -func TestGenerateCodeVerifier_返回值格式(t *testing.T) { - verifier, err := GenerateCodeVerifier() - if err != nil { - t.Fatalf("GenerateCodeVerifier 失败: %v", err) - } - if verifier == "" { - t.Error("GenerateCodeVerifier 返回空字符串") - } - // base64url 编码不应包含 +, /, = - if strings.ContainsAny(verifier, "+/=") { - t.Errorf("GenerateCodeVerifier 返回值包含非 base64url 字符: %s", verifier) - } - // 32 字节的 base64url 编码长度应为 43 - if len(verifier) != 43 { - t.Errorf("GenerateCodeVerifier 返回值长度不匹配: got %d, want 43", len(verifier)) - } -} - -func TestGenerateCodeVerifier_唯一性(t *testing.T) { - v1, _ := GenerateCodeVerifier() - v2, _ := GenerateCodeVerifier() - if v1 == v2 { - t.Error("两次 GenerateCodeVerifier 结果相同") - } -} - -// --------------------------------------------------------------------------- -// GenerateCodeChallenge -// --------------------------------------------------------------------------- - -func TestGenerateCodeChallenge_SHA256_Base64URL(t *testing.T) { - verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" - - challenge := GenerateCodeChallenge(verifier) - - // 手动计算预期值 - hash := sha256.Sum256([]byte(verifier)) - expected := strings.TrimRight(base64.URLEncoding.EncodeToString(hash[:]), "=") - - if challenge != expected { - t.Errorf("CodeChallenge 不匹配: got %s, want %s", challenge, expected) - } -} - -func TestGenerateCodeChallenge_不含填充字符(t *testing.T) { - challenge := GenerateCodeChallenge("test-verifier") - if strings.Contains(challenge, "=") { - t.Errorf("CodeChallenge 不应包含 = 填充字符: %s", challenge) - } -} - -func TestGenerateCodeChallenge_不含非URL安全字符(t *testing.T) { - challenge := GenerateCodeChallenge("another-verifier") - if strings.ContainsAny(challenge, "+/") { - t.Errorf("CodeChallenge 不应包含 + 或 / 字符: %s", challenge) - } -} - -func TestGenerateCodeChallenge_相同输入相同输出(t *testing.T) { - c1 := GenerateCodeChallenge("same-verifier") - c2 := GenerateCodeChallenge("same-verifier") - if c1 != c2 { - t.Errorf("相同输入应产生相同输出: got %s and %s", c1, c2) - } -} - -func TestGenerateCodeChallenge_不同输入不同输出(t *testing.T) { - c1 := GenerateCodeChallenge("verifier-1") - c2 := GenerateCodeChallenge("verifier-2") - if c1 == c2 { - t.Error("不同输入应产生不同输出") - } -} - -// --------------------------------------------------------------------------- -// BuildAuthorizationURL -// --------------------------------------------------------------------------- - -func TestBuildAuthorizationURL_参数验证(t *testing.T) { - state := "test-state-123" - codeChallenge := "test-challenge-abc" - - authURL := BuildAuthorizationURL(state, codeChallenge) - - // 验证以 AuthorizeURL 开头 - if !strings.HasPrefix(authURL, AuthorizeURL+"?") { - t.Errorf("URL 应以 %s? 开头: got %s", AuthorizeURL, authURL) - } - - // 解析 URL 并验证参数 - parsed, err := url.Parse(authURL) - if err != nil { - t.Fatalf("解析 URL 失败: %v", err) - } - - params := parsed.Query() - - expectedParams := map[string]string{ - "client_id": ClientID, - "redirect_uri": RedirectURI, - "response_type": "code", - "scope": Scopes, - "state": state, - "code_challenge": codeChallenge, - "code_challenge_method": "S256", - "access_type": "offline", - "prompt": "consent", - "include_granted_scopes": "true", - } - - for key, want := range expectedParams { - got := params.Get(key) - if got != want { - t.Errorf("参数 %s 不匹配: got %q, want %q", key, got, want) - } - } -} - -func TestBuildAuthorizationURL_参数数量(t *testing.T) { - authURL := BuildAuthorizationURL("s", "c") - parsed, err := url.Parse(authURL) - if err != nil { - t.Fatalf("解析 URL 失败: %v", err) - } - - params := parsed.Query() - // 应包含 10 个参数 - expectedCount := 10 - if len(params) != expectedCount { - t.Errorf("参数数量不匹配: got %d, want %d", len(params), expectedCount) - } -} - -func TestBuildAuthorizationURL_特殊字符编码(t *testing.T) { - state := "state+with/special=chars" - codeChallenge := "challenge+value" - - authURL := BuildAuthorizationURL(state, codeChallenge) - - parsed, err := url.Parse(authURL) - if err != nil { - t.Fatalf("解析 URL 失败: %v", err) - } - - // 解析后应正确还原特殊字符 - if got := parsed.Query().Get("state"); got != state { - t.Errorf("state 参数编码/解码不匹配: got %q, want %q", got, state) - } -} - -// --------------------------------------------------------------------------- -// 常量值验证 -// --------------------------------------------------------------------------- - -func TestConstants_值正确(t *testing.T) { - if AuthorizeURL != "https://accounts.google.com/o/oauth2/v2/auth" { - t.Errorf("AuthorizeURL 不匹配: got %s", AuthorizeURL) - } - if TokenURL != "https://oauth2.googleapis.com/token" { - t.Errorf("TokenURL 不匹配: got %s", TokenURL) - } - if UserInfoURL != "https://www.googleapis.com/oauth2/v2/userinfo" { - t.Errorf("UserInfoURL 不匹配: got %s", UserInfoURL) - } - if ClientID != "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" { - t.Errorf("ClientID 不匹配: got %s", ClientID) - } - secret, err := getClientSecret() - if err != nil { - t.Fatalf("getClientSecret 应返回默认值,但报错: %v", err) - } - if secret != "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" { - t.Errorf("默认 client_secret 不匹配: got %s", secret) - } - if RedirectURI != "http://localhost:8085/callback" { - t.Errorf("RedirectURI 不匹配: got %s", RedirectURI) - } - if GetUserAgent() != "antigravity/1.20.5 windows/amd64" { - t.Errorf("UserAgent 不匹配: got %s", GetUserAgent()) - } - if SessionTTL != 30*time.Minute { - t.Errorf("SessionTTL 不匹配: got %v", SessionTTL) - } - if URLAvailabilityTTL != 5*time.Minute { - t.Errorf("URLAvailabilityTTL 不匹配: got %v", URLAvailabilityTTL) - } -} - -func TestScopes_包含必要范围(t *testing.T) { - expectedScopes := []string{ - "https://www.googleapis.com/auth/cloud-platform", - "https://www.googleapis.com/auth/userinfo.email", - "https://www.googleapis.com/auth/userinfo.profile", - "https://www.googleapis.com/auth/cclog", - "https://www.googleapis.com/auth/experimentsandconfigs", - } - - for _, scope := range expectedScopes { - if !strings.Contains(Scopes, scope) { - t.Errorf("Scopes 缺少 %s", scope) - } - } +func TestConstants(t *testing.T) { + assert.NotEmpty(t, ClientID) + assert.NotEmpty(t, RedirectURI) + assert.NotEmpty(t, AuthorizeURL) + assert.NotEmpty(t, TokenURL) + assert.NotEmpty(t, Scopes) + assert.Equal(t, 30*time.Minute, SessionTTL) + assert.Equal(t, 5*time.Minute, URLAvailabilityTTL) }