package providers import ( "crypto" "crypto/rand" "crypto/rsa" "crypto/sha256" "crypto/x509" "encoding/base64" "encoding/pem" "net/url" "strings" "testing" ) func generateRSAKeyForTest(t *testing.T) *rsa.PrivateKey { t.Helper() key, err := rsa.GenerateKey(rand.Reader, 1024) if err != nil { t.Fatalf("generate rsa key failed: %v", err) } return key } func marshalPKCS8PEMForTest(t *testing.T, key *rsa.PrivateKey) string { t.Helper() der, err := x509.MarshalPKCS8PrivateKey(key) if err != nil { t.Fatalf("marshal PKCS#8 failed: %v", err) } return string(pem.EncodeToMemory(&pem.Block{ Type: "PRIVATE KEY", Bytes: der, })) } func TestParseAlipayPrivateKeySupportsRawPKCS8AndPKCS1(t *testing.T) { key := generateRSAKeyForTest(t) pkcs8DER, err := x509.MarshalPKCS8PrivateKey(key) if err != nil { t.Fatalf("marshal PKCS#8 failed: %v", err) } rawPKCS8 := base64.StdEncoding.EncodeToString(pkcs8DER) parsedPKCS8, err := parseAlipayPrivateKey(rawPKCS8) if err != nil { t.Fatalf("parse raw PKCS#8 key failed: %v", err) } if parsedPKCS8.N.Cmp(key.N) != 0 || parsedPKCS8.D.Cmp(key.D) != 0 { t.Fatal("parsed raw PKCS#8 key does not match original key") } pkcs1PEM := string(pem.EncodeToMemory(&pem.Block{ Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key), })) parsedPKCS1, err := parseAlipayPrivateKey(pkcs1PEM) if err != nil { t.Fatalf("parse PKCS#1 key failed: %v", err) } if parsedPKCS1.N.Cmp(key.N) != 0 || parsedPKCS1.D.Cmp(key.D) != 0 { t.Fatal("parsed PKCS#1 key does not match original key") } } func TestParseAlipayPrivateKeyRejectsInvalidPEM(t *testing.T) { if _, err := parseAlipayPrivateKey("not-a-valid-private-key"); err == nil { t.Fatal("expected invalid private key parsing to fail") } } func TestAlipayProviderSignParamsProducesVerifiableSignature(t *testing.T) { key := generateRSAKeyForTest(t) provider := NewAlipayProvider( "app-id", marshalPKCS8PEMForTest(t, key), "https://admin.example.com/login/oauth/callback", false, ) params := map[string]string{ "method": "alipay.system.oauth.token", "app_id": "app-id", "code": "auth-code", "sign": "should-be-ignored", } signature, err := provider.signParams(params) if err != nil { t.Fatalf("signParams failed: %v", err) } if signature == "" { t.Fatal("expected non-empty signature") } signatureBytes, err := base64.StdEncoding.DecodeString(signature) if err != nil { t.Fatalf("decode signature failed: %v", err) } signContent := "app_id=app-id&code=auth-code&method=alipay.system.oauth.token" hash := sha256.Sum256([]byte(signContent)) if err := rsa.VerifyPKCS1v15(&key.PublicKey, crypto.SHA256, hash[:], signatureBytes); err != nil { t.Fatalf("signature verification failed: %v", err) } } func TestTwitterProviderPKCEHelpersAndAuthURL(t *testing.T) { provider := NewTwitterProvider("twitter-client", "https://admin.example.com/login/oauth/callback") verifierA, err := provider.GenerateCodeVerifier() if err != nil { t.Fatalf("GenerateCodeVerifier(first) failed: %v", err) } verifierB, err := provider.GenerateCodeVerifier() if err != nil { t.Fatalf("GenerateCodeVerifier(second) failed: %v", err) } if verifierA == "" || verifierB == "" { t.Fatal("expected non-empty code verifiers") } if verifierA == verifierB { t.Fatal("expected code verifiers to differ across calls") } if strings.Contains(verifierA, "=") || strings.Contains(verifierB, "=") { t.Fatal("expected code verifiers to be base64url values without padding") } if provider.GenerateCodeChallenge(verifierA) != verifierA { t.Fatal("expected current code challenge implementation to mirror the verifier") } authURL, err := provider.GetAuthURL() if err != nil { t.Fatalf("GetAuthURL failed: %v", err) } if authURL.CodeVerifier == "" || authURL.State == "" { t.Fatal("expected auth url response to include verifier and state") } if authURL.Redirect != provider.RedirectURI { t.Fatalf("expected redirect %q, got %q", provider.RedirectURI, authURL.Redirect) } parsed, err := url.Parse(authURL.URL) if err != nil { t.Fatalf("parse auth url failed: %v", err) } query := parsed.Query() if query.Get("client_id") != "twitter-client" { t.Fatalf("expected twitter client_id, got %q", query.Get("client_id")) } if query.Get("redirect_uri") != provider.RedirectURI { t.Fatalf("expected redirect_uri %q, got %q", provider.RedirectURI, query.Get("redirect_uri")) } if query.Get("code_challenge") != authURL.CodeVerifier { t.Fatalf("expected code challenge to equal verifier, got %q", query.Get("code_challenge")) } if query.Get("code_challenge_method") != "plain" { t.Fatalf("expected code_challenge_method plain, got %q", query.Get("code_challenge_method")) } if query.Get("state") != authURL.State { t.Fatalf("expected state %q, got %q", authURL.State, query.Get("state")) } }