857 lines
23 KiB
Go
857 lines
23 KiB
Go
|
|
package middleware
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"context"
|
|||
|
|
"net/http"
|
|||
|
|
"net/http/httptest"
|
|||
|
|
"strings"
|
|||
|
|
"testing"
|
|||
|
|
"time"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
func TestExtractBearerToken(t *testing.T) {
|
|||
|
|
tests := []struct {
|
|||
|
|
name string
|
|||
|
|
authHeader string
|
|||
|
|
wantToken string
|
|||
|
|
wantOK bool
|
|||
|
|
}{
|
|||
|
|
{
|
|||
|
|
name: "valid bearer token",
|
|||
|
|
authHeader: "Bearer test-token-123",
|
|||
|
|
wantToken: "test-token-123",
|
|||
|
|
wantOK: true,
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
name: "valid bearer token with extra spaces",
|
|||
|
|
authHeader: "Bearer test-token-456 ",
|
|||
|
|
wantToken: "test-token-456",
|
|||
|
|
wantOK: true,
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
name: "missing bearer prefix",
|
|||
|
|
authHeader: "test-token-123",
|
|||
|
|
wantToken: "",
|
|||
|
|
wantOK: false,
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
name: "empty bearer token",
|
|||
|
|
authHeader: "Bearer ",
|
|||
|
|
wantToken: "",
|
|||
|
|
wantOK: false,
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
name: "empty header",
|
|||
|
|
authHeader: "",
|
|||
|
|
wantToken: "",
|
|||
|
|
wantOK: false,
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
name: "case sensitive bearer",
|
|||
|
|
authHeader: "bearer test-token",
|
|||
|
|
wantToken: "",
|
|||
|
|
wantOK: false,
|
|||
|
|
},
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
for _, tt := range tests {
|
|||
|
|
t.Run(tt.name, func(t *testing.T) {
|
|||
|
|
token, ok := extractBearerToken(tt.authHeader)
|
|||
|
|
if token != tt.wantToken {
|
|||
|
|
t.Errorf("extractBearerToken() token = %v, want %v", token, tt.wantToken)
|
|||
|
|
}
|
|||
|
|
if ok != tt.wantOK {
|
|||
|
|
t.Errorf("extractBearerToken() ok = %v, want %v", ok, tt.wantOK)
|
|||
|
|
}
|
|||
|
|
})
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestHasExternalQueryKey(t *testing.T) {
|
|||
|
|
tests := []struct {
|
|||
|
|
name string
|
|||
|
|
query string
|
|||
|
|
want bool
|
|||
|
|
}{
|
|||
|
|
{
|
|||
|
|
name: "has key param",
|
|||
|
|
query: "?key=abc123",
|
|||
|
|
want: true,
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
name: "has api_key param",
|
|||
|
|
query: "?api_key=abc123",
|
|||
|
|
want: true,
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
name: "has token param",
|
|||
|
|
query: "?token=abc123",
|
|||
|
|
want: true,
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
name: "has access_token param",
|
|||
|
|
query: "?access_token=abc123",
|
|||
|
|
want: true,
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
name: "has other param",
|
|||
|
|
query: "?name=test&value=123",
|
|||
|
|
want: false,
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
name: "no params",
|
|||
|
|
query: "",
|
|||
|
|
want: false,
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
name: "case insensitive key",
|
|||
|
|
query: "?KEY=abc123",
|
|||
|
|
want: true,
|
|||
|
|
},
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
for _, tt := range tests {
|
|||
|
|
t.Run(tt.name, func(t *testing.T) {
|
|||
|
|
req := httptest.NewRequest("GET", "/test"+tt.query, nil)
|
|||
|
|
if got := hasExternalQueryKey(req); got != tt.want {
|
|||
|
|
t.Errorf("hasExternalQueryKey() = %v, want %v", got, tt.want)
|
|||
|
|
}
|
|||
|
|
})
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestRequestIDMiddleware(t *testing.T) {
|
|||
|
|
t.Run("generates request ID when not present", func(t *testing.T) {
|
|||
|
|
var capturedReqID string
|
|||
|
|
handler := requestIDMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|||
|
|
capturedReqID, _ = RequestIDFromContext(r.Context())
|
|||
|
|
}), time.Now)
|
|||
|
|
|
|||
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|||
|
|
rr := httptest.NewRecorder()
|
|||
|
|
handler.ServeHTTP(rr, req)
|
|||
|
|
|
|||
|
|
if capturedReqID == "" {
|
|||
|
|
t.Error("expected request ID to be set in context")
|
|||
|
|
}
|
|||
|
|
if rr.Header().Get("X-Request-Id") == "" {
|
|||
|
|
t.Error("expected X-Request-Id header to be set in response")
|
|||
|
|
}
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
t.Run("uses existing request ID from header", func(t *testing.T) {
|
|||
|
|
existingID := "existing-req-id-123"
|
|||
|
|
var capturedID string
|
|||
|
|
handler := requestIDMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|||
|
|
capturedID = r.Header.Get("X-Request-Id")
|
|||
|
|
}), time.Now)
|
|||
|
|
|
|||
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|||
|
|
req.Header.Set("X-Request-Id", existingID)
|
|||
|
|
rr := httptest.NewRecorder()
|
|||
|
|
handler.ServeHTTP(rr, req)
|
|||
|
|
|
|||
|
|
if capturedID != existingID {
|
|||
|
|
t.Errorf("expected request ID %q, got %q", existingID, capturedID)
|
|||
|
|
}
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
t.Run("nil next handler does not panic", func(t *testing.T) {
|
|||
|
|
defer func() {
|
|||
|
|
if r := recover(); r != nil {
|
|||
|
|
t.Errorf("panic with nil next handler: %v", r)
|
|||
|
|
}
|
|||
|
|
}()
|
|||
|
|
handler := requestIDMiddleware(nil, time.Now)
|
|||
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|||
|
|
rr := httptest.NewRecorder()
|
|||
|
|
handler.ServeHTTP(rr, req)
|
|||
|
|
})
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestQueryKeyRejectMiddleware(t *testing.T) {
|
|||
|
|
t.Run("rejects request with query key", func(t *testing.T) {
|
|||
|
|
auditCalled := false
|
|||
|
|
auditor := &mockAuditEmitter{
|
|||
|
|
onEmit: func(ctx context.Context, event AuditEvent) error {
|
|||
|
|
auditCalled = true
|
|||
|
|
if event.EventName != EventTokenQueryKeyRejected {
|
|||
|
|
t.Errorf("expected event %s, got %s", EventTokenQueryKeyRejected, event.EventName)
|
|||
|
|
}
|
|||
|
|
return nil
|
|||
|
|
},
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
handler := queryKeyRejectMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|||
|
|
t.Error("next handler should not be called")
|
|||
|
|
}), auditor, time.Now)
|
|||
|
|
|
|||
|
|
req := httptest.NewRequest("GET", "/api/v1/supply?key=abc123", nil)
|
|||
|
|
rr := httptest.NewRecorder()
|
|||
|
|
handler.ServeHTTP(rr, req)
|
|||
|
|
|
|||
|
|
if !auditCalled {
|
|||
|
|
t.Error("expected audit to be called")
|
|||
|
|
}
|
|||
|
|
if rr.Code != http.StatusUnauthorized {
|
|||
|
|
t.Errorf("expected status 401, got %d", rr.Code)
|
|||
|
|
}
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
t.Run("allows request without query key", func(t *testing.T) {
|
|||
|
|
nextCalled := false
|
|||
|
|
handler := queryKeyRejectMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|||
|
|
nextCalled = true
|
|||
|
|
}), nil, time.Now)
|
|||
|
|
|
|||
|
|
req := httptest.NewRequest("GET", "/api/v1/supply?name=test", nil)
|
|||
|
|
rr := httptest.NewRecorder()
|
|||
|
|
handler.ServeHTTP(rr, req)
|
|||
|
|
|
|||
|
|
if !nextCalled {
|
|||
|
|
t.Error("expected next handler to be called")
|
|||
|
|
}
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
t.Run("rejects api_key parameter", func(t *testing.T) {
|
|||
|
|
handler := queryKeyRejectMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|||
|
|
t.Error("next handler should not be called")
|
|||
|
|
}), nil, time.Now)
|
|||
|
|
|
|||
|
|
req := httptest.NewRequest("GET", "/api/v1/supply?api_key=secret", nil)
|
|||
|
|
rr := httptest.NewRecorder()
|
|||
|
|
handler.ServeHTTP(rr, req)
|
|||
|
|
|
|||
|
|
if rr.Code != http.StatusUnauthorized {
|
|||
|
|
t.Errorf("expected status 401, got %d", rr.Code)
|
|||
|
|
}
|
|||
|
|
})
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestTokenAuthMiddleware(t *testing.T) {
|
|||
|
|
t.Run("allows request when all checks pass", func(t *testing.T) {
|
|||
|
|
now := time.Now()
|
|||
|
|
tokenRuntime := NewInMemoryTokenRuntime(func() time.Time { return now })
|
|||
|
|
|
|||
|
|
// Issue a valid token
|
|||
|
|
token, err := tokenRuntime.Issue(context.Background(), "user1", "admin", []string{"supply:read", "supply:write"}, time.Hour)
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("failed to issue token: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
cfg := AuthMiddlewareConfig{
|
|||
|
|
Verifier: tokenRuntime,
|
|||
|
|
StatusResolver: tokenRuntime,
|
|||
|
|
Authorizer: NewScopeRoleAuthorizer(),
|
|||
|
|
ProtectedPrefixes: []string{"/api/v1/supply"},
|
|||
|
|
ExcludedPrefixes: []string{"/health"},
|
|||
|
|
Now: func() time.Time { return now },
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
nextCalled := false
|
|||
|
|
handler := tokenAuthMiddleware(cfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|||
|
|
nextCalled = true
|
|||
|
|
// Verify principal is set in context
|
|||
|
|
principal, ok := PrincipalFromContext(r.Context())
|
|||
|
|
if !ok {
|
|||
|
|
t.Error("expected principal in context")
|
|||
|
|
}
|
|||
|
|
if principal.SubjectID != "user1" {
|
|||
|
|
t.Errorf("expected subject user1, got %s", principal.SubjectID)
|
|||
|
|
}
|
|||
|
|
}))
|
|||
|
|
|
|||
|
|
req := httptest.NewRequest("GET", "/api/v1/supply", nil)
|
|||
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|||
|
|
rr := httptest.NewRecorder()
|
|||
|
|
handler.ServeHTTP(rr, req)
|
|||
|
|
|
|||
|
|
if !nextCalled {
|
|||
|
|
t.Error("expected next handler to be called")
|
|||
|
|
}
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
t.Run("rejects request without bearer token", func(t *testing.T) {
|
|||
|
|
cfg := AuthMiddlewareConfig{
|
|||
|
|
Verifier: &mockVerifier{},
|
|||
|
|
StatusResolver: &mockStatusResolver{},
|
|||
|
|
Authorizer: NewScopeRoleAuthorizer(),
|
|||
|
|
ProtectedPrefixes: []string{"/api/v1/supply"},
|
|||
|
|
Now: time.Now,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
handler := tokenAuthMiddleware(cfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|||
|
|
t.Error("next handler should not be called")
|
|||
|
|
}))
|
|||
|
|
|
|||
|
|
req := httptest.NewRequest("GET", "/api/v1/supply", nil)
|
|||
|
|
rr := httptest.NewRecorder()
|
|||
|
|
handler.ServeHTTP(rr, req)
|
|||
|
|
|
|||
|
|
if rr.Code != http.StatusUnauthorized {
|
|||
|
|
t.Errorf("expected status 401, got %d", rr.Code)
|
|||
|
|
}
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
t.Run("rejects request to excluded path", func(t *testing.T) {
|
|||
|
|
cfg := AuthMiddlewareConfig{
|
|||
|
|
Verifier: &mockVerifier{},
|
|||
|
|
StatusResolver: &mockStatusResolver{},
|
|||
|
|
Authorizer: NewScopeRoleAuthorizer(),
|
|||
|
|
ProtectedPrefixes: []string{"/api/v1/supply"},
|
|||
|
|
ExcludedPrefixes: []string{"/health"},
|
|||
|
|
Now: time.Now,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
nextCalled := false
|
|||
|
|
handler := tokenAuthMiddleware(cfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|||
|
|
nextCalled = true
|
|||
|
|
}))
|
|||
|
|
|
|||
|
|
req := httptest.NewRequest("GET", "/health", nil)
|
|||
|
|
rr := httptest.NewRecorder()
|
|||
|
|
handler.ServeHTTP(rr, req)
|
|||
|
|
|
|||
|
|
if !nextCalled {
|
|||
|
|
t.Error("expected next handler to be called for excluded path")
|
|||
|
|
}
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
t.Run("returns 503 when dependencies not ready", func(t *testing.T) {
|
|||
|
|
cfg := AuthMiddlewareConfig{
|
|||
|
|
Verifier: nil,
|
|||
|
|
StatusResolver: nil,
|
|||
|
|
Authorizer: nil,
|
|||
|
|
ProtectedPrefixes: []string{"/api/v1/supply"},
|
|||
|
|
Now: time.Now,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
handler := tokenAuthMiddleware(cfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|||
|
|
t.Error("next handler should not be called")
|
|||
|
|
}))
|
|||
|
|
|
|||
|
|
req := httptest.NewRequest("GET", "/api/v1/supply", nil)
|
|||
|
|
req.Header.Set("Authorization", "Bearer test-token")
|
|||
|
|
rr := httptest.NewRecorder()
|
|||
|
|
handler.ServeHTTP(rr, req)
|
|||
|
|
|
|||
|
|
if rr.Code != http.StatusServiceUnavailable {
|
|||
|
|
t.Errorf("expected status 503, got %d", rr.Code)
|
|||
|
|
}
|
|||
|
|
})
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestScopeRoleAuthorizer(t *testing.T) {
|
|||
|
|
authorizer := NewScopeRoleAuthorizer()
|
|||
|
|
|
|||
|
|
t.Run("admin role has access to all", func(t *testing.T) {
|
|||
|
|
if !authorizer.Authorize("/api/v1/supply", "POST", []string{}, "admin") {
|
|||
|
|
t.Error("expected admin to have access")
|
|||
|
|
}
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
t.Run("supply read scope for GET", func(t *testing.T) {
|
|||
|
|
if !authorizer.Authorize("/api/v1/supply", "GET", []string{"supply:read"}, "user") {
|
|||
|
|
t.Error("expected supply:read to have access to GET")
|
|||
|
|
}
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
t.Run("supply write scope for POST", func(t *testing.T) {
|
|||
|
|
if !authorizer.Authorize("/api/v1/supply", "POST", []string{"supply:write"}, "user") {
|
|||
|
|
t.Error("expected supply:write to have access to POST")
|
|||
|
|
}
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
t.Run("supply:read scope is denied for POST", func(t *testing.T) {
|
|||
|
|
// supply:read only allows GET, POST should be denied
|
|||
|
|
if authorizer.Authorize("/api/v1/supply", "POST", []string{"supply:read"}, "user") {
|
|||
|
|
t.Error("expected supply:read to be denied for POST")
|
|||
|
|
}
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
t.Run("wildcard scope works", func(t *testing.T) {
|
|||
|
|
if !authorizer.Authorize("/api/v1/supply", "POST", []string{"supply:*"}, "user") {
|
|||
|
|
t.Error("expected supply:* to have access")
|
|||
|
|
}
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
t.Run("platform admin scope", func(t *testing.T) {
|
|||
|
|
if !authorizer.Authorize("/api/v1/platform/users", "GET", []string{"platform:admin"}, "user") {
|
|||
|
|
t.Error("expected platform:admin to have access")
|
|||
|
|
}
|
|||
|
|
})
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestInMemoryTokenRuntime(t *testing.T) {
|
|||
|
|
now := time.Now()
|
|||
|
|
runtime := NewInMemoryTokenRuntime(func() time.Time { return now })
|
|||
|
|
|
|||
|
|
t.Run("issue and verify token", func(t *testing.T) {
|
|||
|
|
token, err := runtime.Issue(context.Background(), "user1", "admin", []string{"supply:read"}, time.Hour)
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("failed to issue token: %v", err)
|
|||
|
|
}
|
|||
|
|
if token == "" {
|
|||
|
|
t.Error("expected non-empty token")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
claims, err := runtime.Verify(context.Background(), token)
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("failed to verify token: %v", err)
|
|||
|
|
}
|
|||
|
|
if claims.SubjectID != "user1" {
|
|||
|
|
t.Errorf("expected subject user1, got %s", claims.SubjectID)
|
|||
|
|
}
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
t.Run("resolve token status", func(t *testing.T) {
|
|||
|
|
token, err := runtime.Issue(context.Background(), "user1", "admin", []string{"supply:read"}, time.Hour)
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("failed to issue token: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Get token ID first
|
|||
|
|
claims, _ := runtime.Verify(context.Background(), token)
|
|||
|
|
|
|||
|
|
status, err := runtime.Resolve(context.Background(), claims.TokenID)
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("failed to resolve status: %v", err)
|
|||
|
|
}
|
|||
|
|
if status != TokenStatusActive {
|
|||
|
|
t.Errorf("expected status active, got %s", status)
|
|||
|
|
}
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
t.Run("revoke token", func(t *testing.T) {
|
|||
|
|
token, _ := runtime.Issue(context.Background(), "user1", "admin", []string{"supply:read"}, time.Hour)
|
|||
|
|
claims, _ := runtime.Verify(context.Background(), token)
|
|||
|
|
|
|||
|
|
err := runtime.Revoke(context.Background(), claims.TokenID)
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("failed to revoke token: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
status, _ := runtime.Resolve(context.Background(), claims.TokenID)
|
|||
|
|
if status != TokenStatusRevoked {
|
|||
|
|
t.Errorf("expected status revoked, got %s", status)
|
|||
|
|
}
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
t.Run("verify invalid token", func(t *testing.T) {
|
|||
|
|
_, err := runtime.Verify(context.Background(), "invalid-token")
|
|||
|
|
if err == nil {
|
|||
|
|
t.Error("expected error for invalid token")
|
|||
|
|
}
|
|||
|
|
})
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestBuildTokenAuthChain(t *testing.T) {
|
|||
|
|
now := time.Now()
|
|||
|
|
runtime := NewInMemoryTokenRuntime(func() time.Time { return now })
|
|||
|
|
token, _ := runtime.Issue(context.Background(), "user1", "admin", []string{"supply:read", "supply:write"}, time.Hour)
|
|||
|
|
|
|||
|
|
cfg := AuthMiddlewareConfig{
|
|||
|
|
Verifier: runtime,
|
|||
|
|
StatusResolver: runtime,
|
|||
|
|
Authorizer: NewScopeRoleAuthorizer(),
|
|||
|
|
ProtectedPrefixes: []string{"/api/v1/supply", "/api/v1/platform"},
|
|||
|
|
ExcludedPrefixes: []string{"/health", "/healthz"},
|
|||
|
|
Now: func() time.Time { return now },
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
t.Run("full chain with valid token", func(t *testing.T) {
|
|||
|
|
nextCalled := false
|
|||
|
|
handler := BuildTokenAuthChain(cfg, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|||
|
|
nextCalled = true
|
|||
|
|
}))
|
|||
|
|
|
|||
|
|
req := httptest.NewRequest("GET", "/api/v1/supply", nil)
|
|||
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|||
|
|
recorder := httptest.NewRecorder()
|
|||
|
|
handler.ServeHTTP(recorder, req)
|
|||
|
|
|
|||
|
|
if !nextCalled {
|
|||
|
|
t.Error("expected chain to complete successfully")
|
|||
|
|
}
|
|||
|
|
if recorder.Header().Get("X-Request-Id") == "" {
|
|||
|
|
t.Error("expected X-Request-Id header to be set by chain")
|
|||
|
|
}
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
t.Run("full chain rejects query key", func(t *testing.T) {
|
|||
|
|
handler := BuildTokenAuthChain(cfg, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|||
|
|
t.Error("next handler should not be called")
|
|||
|
|
}))
|
|||
|
|
|
|||
|
|
req := httptest.NewRequest("GET", "/api/v1/supply?key=blocked", nil)
|
|||
|
|
rr := httptest.NewRecorder()
|
|||
|
|
handler.ServeHTTP(rr, req)
|
|||
|
|
|
|||
|
|
if rr.Code != http.StatusUnauthorized {
|
|||
|
|
t.Errorf("expected status 401, got %d", rr.Code)
|
|||
|
|
}
|
|||
|
|
})
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Mock implementations
|
|||
|
|
type mockVerifier struct{}
|
|||
|
|
|
|||
|
|
func (m *mockVerifier) Verify(ctx context.Context, rawToken string) (VerifiedToken, error) {
|
|||
|
|
return VerifiedToken{}, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
type mockStatusResolver struct{}
|
|||
|
|
|
|||
|
|
func (m *mockStatusResolver) Resolve(ctx context.Context, tokenID string) (TokenStatus, error) {
|
|||
|
|
return TokenStatusActive, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
type mockAuditEmitter struct {
|
|||
|
|
onEmit func(ctx context.Context, event AuditEvent) error
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (m *mockAuditEmitter) Emit(ctx context.Context, event AuditEvent) error {
|
|||
|
|
if m.onEmit != nil {
|
|||
|
|
return m.onEmit(ctx, event)
|
|||
|
|
}
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestHasScope(t *testing.T) {
|
|||
|
|
tests := []struct {
|
|||
|
|
name string
|
|||
|
|
scopes []string
|
|||
|
|
required string
|
|||
|
|
want bool
|
|||
|
|
}{
|
|||
|
|
{
|
|||
|
|
name: "exact match",
|
|||
|
|
scopes: []string{"supply:read", "supply:write"},
|
|||
|
|
required: "supply:read",
|
|||
|
|
want: true,
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
name: "no match",
|
|||
|
|
scopes: []string{"supply:read"},
|
|||
|
|
required: "supply:write",
|
|||
|
|
want: false,
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
name: "wildcard match",
|
|||
|
|
scopes: []string{"supply:*"},
|
|||
|
|
required: "supply:read",
|
|||
|
|
want: true,
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
name: "wildcard match write",
|
|||
|
|
scopes: []string{"supply:*"},
|
|||
|
|
required: "supply:write",
|
|||
|
|
want: true,
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
name: "empty scopes",
|
|||
|
|
scopes: []string{},
|
|||
|
|
required: "supply:read",
|
|||
|
|
want: false,
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
name: "partial wildcard no match",
|
|||
|
|
scopes: []string{"supply:read"},
|
|||
|
|
required: "platform:admin",
|
|||
|
|
want: false,
|
|||
|
|
},
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
for _, tt := range tests {
|
|||
|
|
t.Run(tt.name, func(t *testing.T) {
|
|||
|
|
got := hasScope(tt.scopes, tt.required)
|
|||
|
|
if got != tt.want {
|
|||
|
|
t.Errorf("hasScope(%v, %s) = %v, want %v", tt.scopes, tt.required, got, tt.want)
|
|||
|
|
}
|
|||
|
|
})
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestRequiredScopeForRoute(t *testing.T) {
|
|||
|
|
tests := []struct {
|
|||
|
|
path string
|
|||
|
|
method string
|
|||
|
|
want string
|
|||
|
|
}{
|
|||
|
|
{"/api/v1/supply", "GET", "supply:read"},
|
|||
|
|
{"/api/v1/supply", "HEAD", "supply:read"},
|
|||
|
|
{"/api/v1/supply", "OPTIONS", "supply:read"},
|
|||
|
|
{"/api/v1/supply", "POST", "supply:write"},
|
|||
|
|
{"/api/v1/supply", "PUT", "supply:write"},
|
|||
|
|
{"/api/v1/supply", "DELETE", "supply:write"},
|
|||
|
|
{"/api/v1/supply/", "GET", "supply:read"},
|
|||
|
|
{"/api/v1/supply/123", "GET", "supply:read"},
|
|||
|
|
{"/api/v1/platform", "GET", "platform:admin"},
|
|||
|
|
{"/api/v1/platform", "POST", "platform:admin"},
|
|||
|
|
{"/api/v1/platform/", "DELETE", "platform:admin"},
|
|||
|
|
{"/api/v1/platform/users", "GET", "platform:admin"},
|
|||
|
|
{"/unknown", "GET", ""},
|
|||
|
|
{"/api/v1/other", "GET", ""},
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
for _, tt := range tests {
|
|||
|
|
t.Run(tt.path+"_"+tt.method, func(t *testing.T) {
|
|||
|
|
got := requiredScopeForRoute(tt.path, tt.method)
|
|||
|
|
if got != tt.want {
|
|||
|
|
t.Errorf("requiredScopeForRoute(%s, %s) = %s, want %s", tt.path, tt.method, got, tt.want)
|
|||
|
|
}
|
|||
|
|
})
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestGenerateAccessToken(t *testing.T) {
|
|||
|
|
token, err := generateAccessToken()
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("generateAccessToken() error = %v", err)
|
|||
|
|
}
|
|||
|
|
if !strings.HasPrefix(token, "ptk_") {
|
|||
|
|
t.Errorf("expected token to start with ptk_, got %s", token)
|
|||
|
|
}
|
|||
|
|
if len(token) < 10 {
|
|||
|
|
t.Errorf("expected token length >= 10, got %d", len(token))
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 生成多个token应该不同
|
|||
|
|
token2, _ := generateAccessToken()
|
|||
|
|
if token == token2 {
|
|||
|
|
t.Error("expected different tokens")
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestGenerateTokenID(t *testing.T) {
|
|||
|
|
tokenID, err := generateTokenID()
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("generateTokenID() error = %v", err)
|
|||
|
|
}
|
|||
|
|
if !strings.HasPrefix(tokenID, "tok_") {
|
|||
|
|
t.Errorf("expected token ID to start with tok_, got %s", tokenID)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
tokenID2, _ := generateTokenID()
|
|||
|
|
if tokenID == tokenID2 {
|
|||
|
|
t.Error("expected different token IDs")
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestGenerateEventID(t *testing.T) {
|
|||
|
|
eventID, err := generateEventID()
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("generateEventID() error = %v", err)
|
|||
|
|
}
|
|||
|
|
if !strings.HasPrefix(eventID, "evt_") {
|
|||
|
|
t.Errorf("expected event ID to start with evt_, got %s", eventID)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
eventID2, _ := generateEventID()
|
|||
|
|
if eventID == eventID2 {
|
|||
|
|
t.Error("expected different event IDs")
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestNullString(t *testing.T) {
|
|||
|
|
tests := []struct {
|
|||
|
|
input string
|
|||
|
|
wantStr string
|
|||
|
|
wantValid bool
|
|||
|
|
}{
|
|||
|
|
{"hello", "hello", true},
|
|||
|
|
{"", "", false},
|
|||
|
|
{"world", "world", true},
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
for _, tt := range tests {
|
|||
|
|
got := nullString(tt.input)
|
|||
|
|
if got.String != tt.wantStr {
|
|||
|
|
t.Errorf("nullString(%q).String = %q, want %q", tt.input, got.String, tt.wantStr)
|
|||
|
|
}
|
|||
|
|
if got.Valid != tt.wantValid {
|
|||
|
|
t.Errorf("nullString(%q).Valid = %v, want %v", tt.input, got.Valid, tt.wantValid)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestInMemoryTokenRuntime_Issue_Errors(t *testing.T) {
|
|||
|
|
now := time.Now()
|
|||
|
|
runtime := NewInMemoryTokenRuntime(func() time.Time { return now })
|
|||
|
|
|
|||
|
|
tests := []struct {
|
|||
|
|
name string
|
|||
|
|
subjectID string
|
|||
|
|
role string
|
|||
|
|
scopes []string
|
|||
|
|
ttl time.Duration
|
|||
|
|
wantErr string
|
|||
|
|
}{
|
|||
|
|
{
|
|||
|
|
name: "empty subject_id",
|
|||
|
|
subjectID: "",
|
|||
|
|
role: "admin",
|
|||
|
|
scopes: []string{"supply:read"},
|
|||
|
|
ttl: time.Hour,
|
|||
|
|
wantErr: "subject_id is required",
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
name: "whitespace subject_id",
|
|||
|
|
subjectID: " ",
|
|||
|
|
role: "admin",
|
|||
|
|
scopes: []string{"supply:read"},
|
|||
|
|
ttl: time.Hour,
|
|||
|
|
wantErr: "subject_id is required",
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
name: "empty role",
|
|||
|
|
subjectID: "user1",
|
|||
|
|
role: "",
|
|||
|
|
scopes: []string{"supply:read"},
|
|||
|
|
ttl: time.Hour,
|
|||
|
|
wantErr: "role is required",
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
name: "empty scopes",
|
|||
|
|
subjectID: "user1",
|
|||
|
|
role: "admin",
|
|||
|
|
scopes: []string{},
|
|||
|
|
ttl: time.Hour,
|
|||
|
|
wantErr: "scope must not be empty",
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
name: "zero ttl",
|
|||
|
|
subjectID: "user1",
|
|||
|
|
role: "admin",
|
|||
|
|
scopes: []string{"supply:read"},
|
|||
|
|
ttl: 0,
|
|||
|
|
wantErr: "ttl must be positive",
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
name: "negative ttl",
|
|||
|
|
subjectID: "user1",
|
|||
|
|
role: "admin",
|
|||
|
|
scopes: []string{"supply:read"},
|
|||
|
|
ttl: -time.Second,
|
|||
|
|
wantErr: "ttl must be positive",
|
|||
|
|
},
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
for _, tt := range tests {
|
|||
|
|
t.Run(tt.name, func(t *testing.T) {
|
|||
|
|
_, err := runtime.Issue(context.Background(), tt.subjectID, tt.role, tt.scopes, tt.ttl)
|
|||
|
|
if err == nil {
|
|||
|
|
t.Fatal("expected error")
|
|||
|
|
}
|
|||
|
|
if err.Error() != tt.wantErr {
|
|||
|
|
t.Errorf("error = %q, want %q", err.Error(), tt.wantErr)
|
|||
|
|
}
|
|||
|
|
})
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestInMemoryTokenRuntime_Verify_Expired(t *testing.T) {
|
|||
|
|
now := time.Now()
|
|||
|
|
runtime := NewInMemoryTokenRuntime(func() time.Time { return now })
|
|||
|
|
|
|||
|
|
token, _ := runtime.Issue(context.Background(), "user1", "admin", []string{"supply:read"}, time.Hour)
|
|||
|
|
|
|||
|
|
// 验证token仍然有效
|
|||
|
|
claims, err := runtime.Verify(context.Background(), token)
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("Verify failed: %v", err)
|
|||
|
|
}
|
|||
|
|
if claims.SubjectID != "user1" {
|
|||
|
|
t.Errorf("SubjectID = %s, want user1", claims.SubjectID)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestInMemoryTokenRuntime_ApplyExpiry(t *testing.T) {
|
|||
|
|
now := time.Now()
|
|||
|
|
runtime := NewInMemoryTokenRuntime(func() time.Time { return now })
|
|||
|
|
|
|||
|
|
token, _ := runtime.Issue(context.Background(), "user1", "admin", []string{"supply:read"}, time.Hour)
|
|||
|
|
claims, _ := runtime.Verify(context.Background(), token)
|
|||
|
|
|
|||
|
|
// 手动设置过期
|
|||
|
|
runtime.mu.Lock()
|
|||
|
|
record := runtime.records[claims.TokenID]
|
|||
|
|
record.ExpiresAt = now.Add(-time.Hour) // 1小时前过期
|
|||
|
|
runtime.mu.Unlock()
|
|||
|
|
|
|||
|
|
// Resolve应该检测到过期
|
|||
|
|
status, _ := runtime.Resolve(context.Background(), claims.TokenID)
|
|||
|
|
if status != TokenStatusExpired {
|
|||
|
|
t.Errorf("status = %s, want Expired", status)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestScopeRoleAuthorizer_Authorize(t *testing.T) {
|
|||
|
|
authorizer := NewScopeRoleAuthorizer()
|
|||
|
|
|
|||
|
|
tests := []struct {
|
|||
|
|
path string
|
|||
|
|
method string
|
|||
|
|
scopes []string
|
|||
|
|
role string
|
|||
|
|
want bool
|
|||
|
|
}{
|
|||
|
|
{"/api/v1/supply", "GET", []string{"supply:read"}, "user", true},
|
|||
|
|
{"/api/v1/supply", "POST", []string{"supply:write"}, "user", true},
|
|||
|
|
{"/api/v1/supply", "DELETE", []string{"supply:read"}, "user", false},
|
|||
|
|
{"/api/v1/supply", "GET", []string{}, "admin", true},
|
|||
|
|
{"/api/v1/supply", "POST", []string{}, "admin", true},
|
|||
|
|
{"/api/v1/other", "GET", []string{}, "user", true}, // 无需权限
|
|||
|
|
{"/api/v1/platform/users", "GET", []string{"platform:admin"}, "user", true},
|
|||
|
|
{"/api/v1/platform/users", "POST", []string{"platform:admin"}, "user", true},
|
|||
|
|
{"/api/v1/platform/users", "DELETE", []string{"supply:read"}, "user", false},
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
for _, tt := range tests {
|
|||
|
|
t.Run(tt.path+"_"+tt.method, func(t *testing.T) {
|
|||
|
|
got := authorizer.Authorize(tt.path, tt.method, tt.scopes, tt.role)
|
|||
|
|
if got != tt.want {
|
|||
|
|
t.Errorf("Authorize(%s, %s, %v, %s) = %v, want %v", tt.path, tt.method, tt.scopes, tt.role, got, tt.want)
|
|||
|
|
}
|
|||
|
|
})
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestMemoryAuditEmitter(t *testing.T) {
|
|||
|
|
emitter := NewMemoryAuditEmitter()
|
|||
|
|
|
|||
|
|
event := AuditEvent{
|
|||
|
|
EventName: EventTokenQueryKeyRejected,
|
|||
|
|
RequestID: "req-123",
|
|||
|
|
Route: "/api/v1/supply",
|
|||
|
|
ResultCode: "401",
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
err := emitter.Emit(context.Background(), event)
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("Emit failed: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if len(emitter.events) != 1 {
|
|||
|
|
t.Errorf("expected 1 event, got %d", len(emitter.events))
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if emitter.events[0].EventID == "" {
|
|||
|
|
t.Error("expected EventID to be set")
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestNewInMemoryTokenRuntime_NilNow(t *testing.T) {
|
|||
|
|
// 不传入now函数,应该使用默认的time.Now
|
|||
|
|
runtime := NewInMemoryTokenRuntime(nil)
|
|||
|
|
if runtime == nil {
|
|||
|
|
t.Fatal("expected non-nil runtime")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 验证基本功能
|
|||
|
|
_, err := runtime.Issue(context.Background(), "user1", "admin", []string{"supply:read"}, time.Hour)
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("Issue failed: %v", err)
|
|||
|
|
}
|
|||
|
|
}
|