feat: sync lijiaoqiao implementation and staging validation artifacts
This commit is contained in:
@@ -0,0 +1,51 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"lijiaoqiao/platform-token-runtime/internal/auth/service"
|
||||
)
|
||||
|
||||
var disallowedQueryKeys = []string{"key", "api_key", "token"}
|
||||
|
||||
func QueryKeyRejectMiddleware(next http.Handler, auditor service.AuditEmitter, now func() time.Time) http.Handler {
|
||||
if next == nil {
|
||||
next = http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
|
||||
}
|
||||
if now == nil {
|
||||
now = defaultNowFunc
|
||||
}
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, exists := externalQueryKey(r)
|
||||
if !exists {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
requestID := ensureRequestID(r, now)
|
||||
emitAuditEvent(r.Context(), auditor, service.AuditEvent{
|
||||
EventName: service.EventTokenQueryKeyRejected,
|
||||
RequestID: requestID,
|
||||
Route: r.URL.Path,
|
||||
ResultCode: service.CodeQueryKeyNotAllowed,
|
||||
ClientIP: extractClientIP(r),
|
||||
CreatedAt: now(),
|
||||
})
|
||||
writeError(w, http.StatusUnauthorized, requestID, service.CodeQueryKeyNotAllowed, "query key ingress is not allowed")
|
||||
})
|
||||
}
|
||||
|
||||
func externalQueryKey(r *http.Request) (string, bool) {
|
||||
values := r.URL.Query()
|
||||
for key := range values {
|
||||
lowered := strings.ToLower(key)
|
||||
for _, disallowed := range disallowedQueryKeys {
|
||||
if lowered == disallowed {
|
||||
return key, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
@@ -0,0 +1,270 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"lijiaoqiao/platform-token-runtime/internal/auth/model"
|
||||
"lijiaoqiao/platform-token-runtime/internal/auth/service"
|
||||
)
|
||||
|
||||
const requestIDHeader = "X-Request-Id"
|
||||
|
||||
var defaultNowFunc = time.Now
|
||||
|
||||
type contextKey string
|
||||
|
||||
const (
|
||||
requestIDKey contextKey = "request_id"
|
||||
principalKey contextKey = "principal"
|
||||
)
|
||||
|
||||
type AuthMiddlewareConfig struct {
|
||||
Verifier service.TokenVerifier
|
||||
StatusResolver service.TokenStatusResolver
|
||||
Authorizer service.RouteAuthorizer
|
||||
Auditor service.AuditEmitter
|
||||
ProtectedPrefixes []string
|
||||
ExcludedPrefixes []string
|
||||
Now func() time.Time
|
||||
}
|
||||
|
||||
func BuildTokenAuthChain(cfg AuthMiddlewareConfig, next http.Handler) http.Handler {
|
||||
handler := TokenAuthMiddleware(cfg)(next)
|
||||
handler = QueryKeyRejectMiddleware(handler, cfg.Auditor, cfg.Now)
|
||||
handler = RequestIDMiddleware(handler, cfg.Now)
|
||||
return handler
|
||||
}
|
||||
|
||||
func RequestIDMiddleware(next http.Handler, now func() time.Time) http.Handler {
|
||||
if next == nil {
|
||||
return http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
|
||||
}
|
||||
if now == nil {
|
||||
now = defaultNowFunc
|
||||
}
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestID := ensureRequestID(r, now)
|
||||
w.Header().Set(requestIDHeader, requestID)
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func TokenAuthMiddleware(cfg AuthMiddlewareConfig) func(http.Handler) http.Handler {
|
||||
cfg = cfg.withDefaults()
|
||||
return func(next http.Handler) http.Handler {
|
||||
if next == nil {
|
||||
next = http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
|
||||
}
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !cfg.shouldProtect(r.URL.Path) {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
requestID := ensureRequestID(r, cfg.Now)
|
||||
if cfg.Verifier == nil || cfg.StatusResolver == nil || cfg.Authorizer == nil {
|
||||
writeError(w, http.StatusServiceUnavailable, requestID, service.CodeAuthNotReady, "auth middleware dependencies are not ready")
|
||||
return
|
||||
}
|
||||
|
||||
rawToken, ok := extractBearerToken(r.Header.Get("Authorization"))
|
||||
if !ok {
|
||||
emitAuditEvent(r.Context(), cfg.Auditor, service.AuditEvent{
|
||||
EventName: service.EventTokenAuthnFail,
|
||||
RequestID: requestID,
|
||||
Route: r.URL.Path,
|
||||
ResultCode: service.CodeAuthMissingBearer,
|
||||
ClientIP: extractClientIP(r),
|
||||
CreatedAt: cfg.Now(),
|
||||
})
|
||||
writeError(w, http.StatusUnauthorized, requestID, service.CodeAuthMissingBearer, "missing bearer token")
|
||||
return
|
||||
}
|
||||
|
||||
claims, err := cfg.Verifier.Verify(r.Context(), rawToken)
|
||||
if err != nil {
|
||||
emitAuditEvent(r.Context(), cfg.Auditor, service.AuditEvent{
|
||||
EventName: service.EventTokenAuthnFail,
|
||||
RequestID: requestID,
|
||||
Route: r.URL.Path,
|
||||
ResultCode: service.CodeAuthInvalidToken,
|
||||
ClientIP: extractClientIP(r),
|
||||
CreatedAt: cfg.Now(),
|
||||
})
|
||||
writeError(w, http.StatusUnauthorized, requestID, service.CodeAuthInvalidToken, "invalid bearer token")
|
||||
return
|
||||
}
|
||||
|
||||
tokenStatus, err := cfg.StatusResolver.Resolve(r.Context(), claims.TokenID)
|
||||
if err != nil || tokenStatus != service.TokenStatusActive {
|
||||
emitAuditEvent(r.Context(), cfg.Auditor, service.AuditEvent{
|
||||
EventName: service.EventTokenAuthnFail,
|
||||
RequestID: requestID,
|
||||
TokenID: claims.TokenID,
|
||||
SubjectID: claims.SubjectID,
|
||||
Route: r.URL.Path,
|
||||
ResultCode: service.CodeAuthTokenInactive,
|
||||
ClientIP: extractClientIP(r),
|
||||
CreatedAt: cfg.Now(),
|
||||
})
|
||||
writeError(w, http.StatusUnauthorized, requestID, service.CodeAuthTokenInactive, "token is inactive")
|
||||
return
|
||||
}
|
||||
|
||||
if !cfg.Authorizer.Authorize(r.URL.Path, r.Method, claims.Scope, claims.Role) {
|
||||
emitAuditEvent(r.Context(), cfg.Auditor, service.AuditEvent{
|
||||
EventName: service.EventTokenAuthzDenied,
|
||||
RequestID: requestID,
|
||||
TokenID: claims.TokenID,
|
||||
SubjectID: claims.SubjectID,
|
||||
Route: r.URL.Path,
|
||||
ResultCode: service.CodeAuthScopeDenied,
|
||||
ClientIP: extractClientIP(r),
|
||||
CreatedAt: cfg.Now(),
|
||||
})
|
||||
writeError(w, http.StatusForbidden, requestID, service.CodeAuthScopeDenied, "scope denied")
|
||||
return
|
||||
}
|
||||
|
||||
principal := model.Principal{
|
||||
RequestID: requestID,
|
||||
TokenID: claims.TokenID,
|
||||
SubjectID: claims.SubjectID,
|
||||
Role: claims.Role,
|
||||
Scope: append([]string(nil), claims.Scope...),
|
||||
}
|
||||
ctx := context.WithValue(r.Context(), principalKey, principal)
|
||||
ctx = context.WithValue(ctx, requestIDKey, requestID)
|
||||
|
||||
emitAuditEvent(ctx, cfg.Auditor, service.AuditEvent{
|
||||
EventName: service.EventTokenAuthnSuccess,
|
||||
RequestID: requestID,
|
||||
TokenID: claims.TokenID,
|
||||
SubjectID: claims.SubjectID,
|
||||
Route: r.URL.Path,
|
||||
ResultCode: "OK",
|
||||
ClientIP: extractClientIP(r),
|
||||
CreatedAt: cfg.Now(),
|
||||
})
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func RequestIDFromContext(ctx context.Context) (string, bool) {
|
||||
if ctx == nil {
|
||||
return "", false
|
||||
}
|
||||
value, ok := ctx.Value(requestIDKey).(string)
|
||||
return value, ok
|
||||
}
|
||||
|
||||
func PrincipalFromContext(ctx context.Context) (model.Principal, bool) {
|
||||
if ctx == nil {
|
||||
return model.Principal{}, false
|
||||
}
|
||||
value, ok := ctx.Value(principalKey).(model.Principal)
|
||||
return value, ok
|
||||
}
|
||||
|
||||
func (cfg AuthMiddlewareConfig) withDefaults() AuthMiddlewareConfig {
|
||||
if cfg.Now == nil {
|
||||
cfg.Now = defaultNowFunc
|
||||
}
|
||||
if len(cfg.ProtectedPrefixes) == 0 {
|
||||
cfg.ProtectedPrefixes = []string{"/api/v1/supply", "/api/v1/platform"}
|
||||
}
|
||||
if len(cfg.ExcludedPrefixes) == 0 {
|
||||
cfg.ExcludedPrefixes = []string{"/healthz", "/metrics", "/readyz"}
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
func (cfg AuthMiddlewareConfig) shouldProtect(path string) bool {
|
||||
for _, prefix := range cfg.ExcludedPrefixes {
|
||||
if strings.HasPrefix(path, prefix) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
for _, prefix := range cfg.ProtectedPrefixes {
|
||||
if strings.HasPrefix(path, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func ensureRequestID(r *http.Request, now func() time.Time) string {
|
||||
if now == nil {
|
||||
now = defaultNowFunc
|
||||
}
|
||||
if requestID, ok := RequestIDFromContext(r.Context()); ok && requestID != "" {
|
||||
return requestID
|
||||
}
|
||||
requestID := strings.TrimSpace(r.Header.Get(requestIDHeader))
|
||||
if requestID == "" {
|
||||
requestID = fmt.Sprintf("req-%d", now().UnixNano())
|
||||
}
|
||||
ctx := context.WithValue(r.Context(), requestIDKey, requestID)
|
||||
*r = *r.WithContext(ctx)
|
||||
return requestID
|
||||
}
|
||||
|
||||
func extractBearerToken(authHeader string) (string, bool) {
|
||||
const bearerPrefix = "Bearer "
|
||||
if !strings.HasPrefix(authHeader, bearerPrefix) {
|
||||
return "", false
|
||||
}
|
||||
token := strings.TrimSpace(strings.TrimPrefix(authHeader, bearerPrefix))
|
||||
return token, token != ""
|
||||
}
|
||||
|
||||
func emitAuditEvent(ctx context.Context, auditor service.AuditEmitter, event service.AuditEvent) {
|
||||
if auditor == nil {
|
||||
return
|
||||
}
|
||||
_ = auditor.Emit(ctx, event)
|
||||
}
|
||||
|
||||
type errorResponse struct {
|
||||
RequestID string `json:"request_id"`
|
||||
Error errorPayload `json:"error"`
|
||||
}
|
||||
|
||||
type errorPayload struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Details map[string]any `json:"details,omitempty"`
|
||||
}
|
||||
|
||||
func writeError(w http.ResponseWriter, status int, requestID, code, message string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
payload := errorResponse{
|
||||
RequestID: requestID,
|
||||
Error: errorPayload{
|
||||
Code: code,
|
||||
Message: message,
|
||||
},
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(payload)
|
||||
}
|
||||
|
||||
func extractClientIP(r *http.Request) string {
|
||||
xForwardedFor := strings.TrimSpace(r.Header.Get("X-Forwarded-For"))
|
||||
if xForwardedFor != "" {
|
||||
parts := strings.Split(xForwardedFor, ",")
|
||||
return strings.TrimSpace(parts[0])
|
||||
}
|
||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err == nil {
|
||||
return host
|
||||
}
|
||||
return r.RemoteAddr
|
||||
}
|
||||
@@ -0,0 +1,244 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"lijiaoqiao/platform-token-runtime/internal/auth/model"
|
||||
"lijiaoqiao/platform-token-runtime/internal/auth/service"
|
||||
)
|
||||
|
||||
var fixedNow = func() time.Time {
|
||||
return time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC)
|
||||
}
|
||||
|
||||
type fakeVerifier struct {
|
||||
token service.VerifiedToken
|
||||
err error
|
||||
}
|
||||
|
||||
func (f *fakeVerifier) Verify(context.Context, string) (service.VerifiedToken, error) {
|
||||
return f.token, f.err
|
||||
}
|
||||
|
||||
type fakeStatusResolver struct {
|
||||
status service.TokenStatus
|
||||
err error
|
||||
}
|
||||
|
||||
func (f *fakeStatusResolver) Resolve(context.Context, string) (service.TokenStatus, error) {
|
||||
return f.status, f.err
|
||||
}
|
||||
|
||||
type fakeAuthorizer struct {
|
||||
allowed bool
|
||||
}
|
||||
|
||||
func (f *fakeAuthorizer) Authorize(string, string, []string, string) bool {
|
||||
return f.allowed
|
||||
}
|
||||
|
||||
type fakeAuditor struct {
|
||||
events []service.AuditEvent
|
||||
}
|
||||
|
||||
func (f *fakeAuditor) Emit(_ context.Context, event service.AuditEvent) error {
|
||||
f.events = append(f.events, event)
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestQueryKeyRejectMiddleware(t *testing.T) {
|
||||
auditor := &fakeAuditor{}
|
||||
nextCalled := false
|
||||
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
|
||||
nextCalled = true
|
||||
})
|
||||
handler := QueryKeyRejectMiddleware(next, auditor, fixedNow)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/supply/accounts?api_key=secret", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
if nextCalled {
|
||||
t.Fatalf("next handler should not be called when query key exists")
|
||||
}
|
||||
if rec.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("unexpected status code: got=%d want=%d", rec.Code, http.StatusUnauthorized)
|
||||
}
|
||||
if got := decodeErrorCode(t, rec); got != service.CodeQueryKeyNotAllowed {
|
||||
t.Fatalf("unexpected error code: got=%s want=%s", got, service.CodeQueryKeyNotAllowed)
|
||||
}
|
||||
if len(auditor.events) != 1 {
|
||||
t.Fatalf("unexpected audit event count: got=%d want=1", len(auditor.events))
|
||||
}
|
||||
if auditor.events[0].EventName != service.EventTokenQueryKeyRejected {
|
||||
t.Fatalf("unexpected event name: got=%s want=%s", auditor.events[0].EventName, service.EventTokenQueryKeyRejected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenAuthMiddleware(t *testing.T) {
|
||||
baseToken := service.VerifiedToken{
|
||||
TokenID: "tok-001",
|
||||
SubjectID: "subject-001",
|
||||
Role: model.RoleOwner,
|
||||
Scope: []string{"supply:*"},
|
||||
IssuedAt: fixedNow(),
|
||||
ExpiresAt: fixedNow().Add(time.Hour),
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
path string
|
||||
authHeader string
|
||||
verifierErr error
|
||||
status service.TokenStatus
|
||||
statusErr error
|
||||
allowed bool
|
||||
wantStatus int
|
||||
wantErrorCode string
|
||||
wantEvent string
|
||||
wantNext bool
|
||||
}{
|
||||
{
|
||||
name: "missing bearer",
|
||||
path: "/api/v1/supply/packages",
|
||||
wantStatus: http.StatusUnauthorized,
|
||||
wantErrorCode: service.CodeAuthMissingBearer,
|
||||
wantEvent: service.EventTokenAuthnFail,
|
||||
},
|
||||
{
|
||||
name: "invalid token",
|
||||
path: "/api/v1/supply/packages",
|
||||
authHeader: "Bearer invalid-token",
|
||||
verifierErr: errors.New("invalid signature"),
|
||||
wantStatus: http.StatusUnauthorized,
|
||||
wantErrorCode: service.CodeAuthInvalidToken,
|
||||
wantEvent: service.EventTokenAuthnFail,
|
||||
},
|
||||
{
|
||||
name: "inactive token",
|
||||
path: "/api/v1/supply/packages",
|
||||
authHeader: "Bearer active-token",
|
||||
status: service.TokenStatusRevoked,
|
||||
wantStatus: http.StatusUnauthorized,
|
||||
wantErrorCode: service.CodeAuthTokenInactive,
|
||||
wantEvent: service.EventTokenAuthnFail,
|
||||
},
|
||||
{
|
||||
name: "scope denied",
|
||||
path: "/api/v1/supply/packages",
|
||||
authHeader: "Bearer active-token",
|
||||
status: service.TokenStatusActive,
|
||||
allowed: false,
|
||||
wantStatus: http.StatusForbidden,
|
||||
wantErrorCode: service.CodeAuthScopeDenied,
|
||||
wantEvent: service.EventTokenAuthzDenied,
|
||||
},
|
||||
{
|
||||
name: "authn success",
|
||||
path: "/api/v1/supply/packages",
|
||||
authHeader: "Bearer active-token",
|
||||
status: service.TokenStatusActive,
|
||||
allowed: true,
|
||||
wantStatus: http.StatusNoContent,
|
||||
wantEvent: service.EventTokenAuthnSuccess,
|
||||
wantNext: true,
|
||||
},
|
||||
{
|
||||
name: "excluded path bypasses auth",
|
||||
path: "/healthz",
|
||||
wantStatus: http.StatusNoContent,
|
||||
wantNext: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
auditor := &fakeAuditor{}
|
||||
verifier := &fakeVerifier{
|
||||
token: baseToken,
|
||||
err: tc.verifierErr,
|
||||
}
|
||||
resolver := &fakeStatusResolver{
|
||||
status: tc.status,
|
||||
err: tc.statusErr,
|
||||
}
|
||||
authorizer := &fakeAuthorizer{allowed: tc.allowed}
|
||||
nextCalled := false
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextCalled = true
|
||||
if tc.wantNext && strings.HasPrefix(tc.path, "/api/v1/") {
|
||||
principal, ok := PrincipalFromContext(r.Context())
|
||||
if !ok {
|
||||
t.Fatalf("principal should be attached when auth succeeded")
|
||||
}
|
||||
if principal.TokenID != baseToken.TokenID {
|
||||
t.Fatalf("unexpected principal token id: got=%s want=%s", principal.TokenID, baseToken.TokenID)
|
||||
}
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
})
|
||||
|
||||
handler := TokenAuthMiddleware(AuthMiddlewareConfig{
|
||||
Verifier: verifier,
|
||||
StatusResolver: resolver,
|
||||
Authorizer: authorizer,
|
||||
Auditor: auditor,
|
||||
ProtectedPrefixes: []string{"/api/v1/supply/", "/api/v1/platform/"},
|
||||
ExcludedPrefixes: []string{"/healthz"},
|
||||
Now: fixedNow,
|
||||
})(next)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, tc.path, nil)
|
||||
if tc.authHeader != "" {
|
||||
req.Header.Set("Authorization", tc.authHeader)
|
||||
}
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != tc.wantStatus {
|
||||
t.Fatalf("unexpected status code: got=%d want=%d", rec.Code, tc.wantStatus)
|
||||
}
|
||||
if tc.wantErrorCode != "" {
|
||||
if got := decodeErrorCode(t, rec); got != tc.wantErrorCode {
|
||||
t.Fatalf("unexpected error code: got=%s want=%s", got, tc.wantErrorCode)
|
||||
}
|
||||
}
|
||||
if nextCalled != tc.wantNext {
|
||||
t.Fatalf("unexpected next call state: got=%v want=%v", nextCalled, tc.wantNext)
|
||||
}
|
||||
if tc.wantEvent == "" {
|
||||
return
|
||||
}
|
||||
if len(auditor.events) == 0 {
|
||||
t.Fatalf("audit event should be emitted")
|
||||
}
|
||||
lastEvent := auditor.events[len(auditor.events)-1]
|
||||
if lastEvent.EventName != tc.wantEvent {
|
||||
t.Fatalf("unexpected event name: got=%s want=%s", lastEvent.EventName, tc.wantEvent)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type errorEnvelope struct {
|
||||
Error struct {
|
||||
Code string `json:"code"`
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
func decodeErrorCode(t *testing.T, rec *httptest.ResponseRecorder) string {
|
||||
t.Helper()
|
||||
var envelope errorEnvelope
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &envelope); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
return envelope.Error.Code
|
||||
}
|
||||
Reference in New Issue
Block a user