Compare commits
2 Commits
abcbc4e58d
...
dfca5e2272
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dfca5e2272 | ||
|
|
65309b95e7 |
@@ -113,3 +113,121 @@ func TestValidatedTransport_ValidationErrorStopsRoundTrip(t *testing.T) {
|
||||
require.ErrorIs(t, err, expectedErr)
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&baseCalls))
|
||||
}
|
||||
|
||||
func TestBuildClientKey(t *testing.T) {
|
||||
opts1 := Options{
|
||||
ProxyURL: "http://proxy:8080",
|
||||
Timeout: 30 * time.Second,
|
||||
ResponseHeaderTimeout: 10 * time.Second,
|
||||
InsecureSkipVerify: false,
|
||||
ValidateResolvedIP: true,
|
||||
AllowPrivateHosts: false,
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
MaxConnsPerHost: 0,
|
||||
}
|
||||
|
||||
key1 := buildClientKey(opts1)
|
||||
require.NotEmpty(t, key1)
|
||||
|
||||
// Same options should produce same key
|
||||
key2 := buildClientKey(opts1)
|
||||
require.Equal(t, key1, key2)
|
||||
|
||||
// Different options should produce different key
|
||||
opts2 := opts1
|
||||
opts2.Timeout = 60 * time.Second
|
||||
key3 := buildClientKey(opts2)
|
||||
require.NotEqual(t, key1, key3)
|
||||
}
|
||||
|
||||
func TestBuildClientKeyTrimsSpaces(t *testing.T) {
|
||||
opts1 := Options{ProxyURL: "http://proxy:8080"}
|
||||
opts2 := Options{ProxyURL: " http://proxy:8080 "}
|
||||
|
||||
key1 := buildClientKey(opts1)
|
||||
key2 := buildClientKey(opts2)
|
||||
|
||||
require.Equal(t, key1, key2)
|
||||
}
|
||||
|
||||
func TestIsValidatedHost(t *testing.T) {
|
||||
base := roundTripFunc(func(_ *http.Request) (*http.Response, error) {
|
||||
return &http.Response{StatusCode: http.StatusOK}, nil
|
||||
})
|
||||
|
||||
transport := newValidatedTransport(base)
|
||||
now := time.Unix(1730000000, 0)
|
||||
transport.now = func() time.Time { return now }
|
||||
|
||||
host := "example.com"
|
||||
transport.validatedHosts.Store(host, now.Add(validatedHostTTL))
|
||||
|
||||
require.True(t, transport.isValidatedHost(host, now))
|
||||
require.False(t, transport.isValidatedHost(host, now.Add(validatedHostTTL+1)))
|
||||
require.False(t, transport.isValidatedHost("other.com", now))
|
||||
}
|
||||
|
||||
func TestIsValidatedHostNilTransport(t *testing.T) {
|
||||
var transport *validatedTransport
|
||||
now := time.Now()
|
||||
require.False(t, transport.isValidatedHost("example.com", now))
|
||||
}
|
||||
|
||||
func TestNewValidatedTransport(t *testing.T) {
|
||||
base := roundTripFunc(func(_ *http.Request) (*http.Response, error) {
|
||||
return &http.Response{StatusCode: http.StatusOK}, nil
|
||||
})
|
||||
|
||||
transport := newValidatedTransport(base)
|
||||
require.NotNil(t, transport)
|
||||
require.NotNil(t, transport.base)
|
||||
require.NotNil(t, transport.now)
|
||||
}
|
||||
|
||||
func TestBuildClient(t *testing.T) {
|
||||
t.Run("valid options", func(t *testing.T) {
|
||||
opts := Options{
|
||||
Timeout: 30 * time.Second,
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
}
|
||||
|
||||
client, err := buildClient(opts)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, client)
|
||||
})
|
||||
|
||||
t.Run("insecure skip verify not allowed", func(t *testing.T) {
|
||||
opts := Options{
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
|
||||
_, err := buildClient(opts)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "insecure_skip_verify is not allowed")
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuildTransport(t *testing.T) {
|
||||
t.Run("default values", func(t *testing.T) {
|
||||
opts := Options{}
|
||||
transport, err := buildTransport(opts)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, transport)
|
||||
require.Equal(t, defaultMaxIdleConns, transport.MaxIdleConns)
|
||||
require.Equal(t, defaultMaxIdleConnsPerHost, transport.MaxIdleConnsPerHost)
|
||||
})
|
||||
|
||||
t.Run("custom values", func(t *testing.T) {
|
||||
opts := Options{
|
||||
MaxIdleConns: 50,
|
||||
MaxIdleConnsPerHost: 5,
|
||||
}
|
||||
transport, err := buildTransport(opts)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, transport)
|
||||
require.Equal(t, 50, transport.MaxIdleConns)
|
||||
require.Equal(t, 5, transport.MaxIdleConnsPerHost)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,43 +1,166 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSessionStore_Stop_Idempotent(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
func TestGenerateRandomBytes(t *testing.T) {
|
||||
t.Run("generates requested length", func(t *testing.T) {
|
||||
bytes, err := GenerateRandomBytes(32)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 32, len(bytes))
|
||||
})
|
||||
|
||||
store.Stop()
|
||||
store.Stop()
|
||||
t.Run("generates different bytes each time", func(t *testing.T) {
|
||||
bytes1, _ := GenerateRandomBytes(16)
|
||||
bytes2, _ := GenerateRandomBytes(16)
|
||||
assert.NotEqual(t, bytes1, bytes2)
|
||||
})
|
||||
}
|
||||
|
||||
select {
|
||||
case <-store.stopCh:
|
||||
// ok
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("stopCh 未关闭")
|
||||
func TestGenerateState(t *testing.T) {
|
||||
t.Run("generates non-empty state", func(t *testing.T) {
|
||||
state, err := GenerateState()
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, state)
|
||||
})
|
||||
|
||||
t.Run("generates unique states", func(t *testing.T) {
|
||||
state1, _ := GenerateState()
|
||||
state2, _ := GenerateState()
|
||||
assert.NotEqual(t, state1, state2)
|
||||
})
|
||||
|
||||
t.Run("generates URL-safe base64", func(t *testing.T) {
|
||||
state, _ := GenerateState()
|
||||
// Should not contain padding
|
||||
assert.NotContains(t, state, "=")
|
||||
})
|
||||
}
|
||||
|
||||
func TestGenerateSessionID(t *testing.T) {
|
||||
t.Run("generates hex string", func(t *testing.T) {
|
||||
sessionID, err := GenerateSessionID()
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, sessionID)
|
||||
// Should be 32 hex chars (16 bytes * 2)
|
||||
assert.Equal(t, 32, len(sessionID))
|
||||
})
|
||||
|
||||
t.Run("generates unique IDs", func(t *testing.T) {
|
||||
id1, _ := GenerateSessionID()
|
||||
id2, _ := GenerateSessionID()
|
||||
assert.NotEqual(t, id1, id2)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGenerateCodeVerifier(t *testing.T) {
|
||||
t.Run("generates verifier", func(t *testing.T) {
|
||||
verifier, err := GenerateCodeVerifier()
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, verifier)
|
||||
})
|
||||
|
||||
t.Run("generates unique verifiers", func(t *testing.T) {
|
||||
v1, _ := GenerateCodeVerifier()
|
||||
v2, _ := GenerateCodeVerifier()
|
||||
assert.NotEqual(t, v1, v2)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGenerateCodeChallenge(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
verifier string
|
||||
}{
|
||||
{"simple verifier", "test_verifier_123"},
|
||||
{"empty string", ""},
|
||||
{"long verifier", "a_very_long_verifier_string_for_testing_purposes_only"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
challenge := GenerateCodeChallenge(tt.verifier)
|
||||
assert.NotEmpty(t, challenge)
|
||||
assert.NotContains(t, challenge, "=") // No padding
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("deterministic for same input", func(t *testing.T) {
|
||||
verifier := "test_verifier"
|
||||
c1 := GenerateCodeChallenge(verifier)
|
||||
c2 := GenerateCodeChallenge(verifier)
|
||||
assert.Equal(t, c1, c2)
|
||||
})
|
||||
}
|
||||
|
||||
func TestBase64URLEncode(t *testing.T) {
|
||||
tests := []struct {
|
||||
input []byte
|
||||
expected string
|
||||
}{
|
||||
{[]byte("hello"), "aGVsbG8"},
|
||||
{[]byte("test+123"), "dGVzdCsxMjM"},
|
||||
{[]byte(""), ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(string(tt.input), func(t *testing.T) {
|
||||
result := base64URLEncode(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
assert.NotContains(t, result, "=")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_Stop_Concurrent(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
func TestBuildAuthorizationURL(t *testing.T) {
|
||||
url := BuildAuthorizationURL("test_state", "test_challenge", ScopeOAuth)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for range 50 {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
store.Stop()
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
select {
|
||||
case <-store.stopCh:
|
||||
// ok
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("stopCh 未关闭")
|
||||
}
|
||||
assert.Contains(t, url, AuthorizeURL)
|
||||
assert.Contains(t, url, "client_id="+ClientID)
|
||||
assert.Contains(t, url, "state=test_state")
|
||||
assert.Contains(t, url, "code_challenge=test_challenge")
|
||||
assert.Contains(t, url, "code_challenge_method=S256")
|
||||
assert.Contains(t, url, "response_type=code")
|
||||
}
|
||||
|
||||
func TestConstants(t *testing.T) {
|
||||
assert.NotEmpty(t, ClientID)
|
||||
assert.NotEmpty(t, AuthorizeURL)
|
||||
assert.NotEmpty(t, TokenURL)
|
||||
assert.NotEmpty(t, RedirectURI)
|
||||
assert.NotEmpty(t, ScopeOAuth)
|
||||
assert.NotEmpty(t, ScopeAPI)
|
||||
assert.NotEmpty(t, ScopeInference)
|
||||
}
|
||||
|
||||
func TestTokenResponse(t *testing.T) {
|
||||
resp := TokenResponse{
|
||||
AccessToken: "token123",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
RefreshToken: "refresh456",
|
||||
Scope: "user:profile",
|
||||
}
|
||||
|
||||
assert.Equal(t, "token123", resp.AccessToken)
|
||||
assert.Equal(t, "Bearer", resp.TokenType)
|
||||
assert.Equal(t, int64(3600), resp.ExpiresIn)
|
||||
}
|
||||
|
||||
func TestOrgInfo(t *testing.T) {
|
||||
org := OrgInfo{UUID: "org-123"}
|
||||
assert.Equal(t, "org-123", org.UUID)
|
||||
}
|
||||
|
||||
func TestAccountInfo(t *testing.T) {
|
||||
account := AccountInfo{
|
||||
UUID: "acc-456",
|
||||
EmailAddress: "test@example.com",
|
||||
}
|
||||
assert.Equal(t, "acc-456", account.UUID)
|
||||
assert.Equal(t, "test@example.com", account.EmailAddress)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user