Files
lijiaoqiao/supply-api/internal/middleware/auth_test.go
Your Name 0196ee5d47 feat(supply-api): 完成核心模块实现
新增/修改内容:
- config: 添加配置管理(config.example.yaml, config.go)
- cache: 添加Redis缓存层(redis.go)
- domain: 添加invariants不变量验证及测试
- middleware: 添加auth认证和idempotency幂等性中间件及测试
- repository: 添加完整数据访问层(account, package, settlement, idempotency, db)
- sql: 添加幂等性表DDL脚本

代码覆盖:
- auth middleware实现凭证边界验证
- idempotency middleware实现请求幂等性
- invariants实现业务不变量检查
- repository层实现完整的数据访问逻辑

关联issue: Round-1 R1-ISSUE-006 凭证边界硬门禁
2026-04-01 08:53:28 +08:00

344 lines
7.6 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package middleware
import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
)
// TestTokenVerify exercises AuthMiddleware.verifyToken against a table of
// signed and malformed tokens.
//
// Every case is checked by a middleware configured with the same secret key
// and issuer. A case either expects verification to succeed, or expects an
// error; when errorContains is non-empty the error text must additionally
// contain that substring.
func TestTokenVerify(t *testing.T) {
	secretKey := "test-secret-key-12345678901234567890"
	issuer := "test-issuer"

	tests := []struct {
		name          string
		token         string
		expectError   bool
		errorContains string
	}{
		{
			name:        "valid token",
			token:       createTestToken(secretKey, issuer, "subject:1", "owner", time.Now().Add(time.Hour)),
			expectError: false,
		},
		{
			name:          "expired token",
			token:         createTestToken(secretKey, issuer, "subject:1", "owner", time.Now().Add(-time.Hour)),
			expectError:   true,
			errorContains: "expired",
		},
		{
			name:          "wrong issuer",
			token:         createTestToken(secretKey, "wrong-issuer", "subject:1", "owner", time.Now().Add(time.Hour)),
			expectError:   true,
			errorContains: "issuer",
		},
		{
			// Well-formed and unexpired, but signed with a key the middleware
			// does not trust. Accepting it would let anyone mint credentials.
			name:        "wrong signing key",
			token:       createTestToken("another-secret-key-09876543210987654321", issuer, "subject:1", "owner", time.Now().Add(time.Hour)),
			expectError: true,
		},
		{
			name:        "invalid token",
			token:       "invalid.token.string",
			expectError: true,
		},
		{
			name:        "empty token",
			token:       "",
			expectError: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			mw := &AuthMiddleware{
				config: AuthConfig{
					SecretKey: secretKey,
					Issuer:    issuer,
				},
			}

			// Only the error is inspected; the first return value is ignored.
			_, err := mw.verifyToken(tt.token)

			switch {
			case !tt.expectError:
				if err != nil {
					t.Errorf("unexpected error: %v", err)
				}
			case err == nil:
				t.Errorf("expected error but got nil")
			case tt.errorContains != "" && !strings.Contains(err.Error(), tt.errorContains):
				t.Errorf("error = %v, want contains %v", err, tt.errorContains)
			}
		})
	}
}
// TestQueryKeyRejectMiddleware checks that QueryKeyRejectMiddleware lets
// requests with harmless query strings through and answers 401 for requests
// that carry credential-like query parameters.
//
// For every case both the recorded status code and whether the wrapped
// handler ran are asserted: a rejected request must not reach the next
// handler, and a passed request must reach it.
func TestQueryKeyRejectMiddleware(t *testing.T) {
	tests := []struct {
		name         string
		query        string
		expectStatus int
	}{
		{
			name:         "no query params",
			query:        "",
			expectStatus: http.StatusOK,
		},
		{
			name:         "normal params",
			query:        "?page=1&size=10",
			expectStatus: http.StatusOK,
		},
		{
			name:         "blocked key param",
			query:        "?key=abc123",
			expectStatus: http.StatusUnauthorized,
		},
		{
			name:         "blocked api_key param",
			query:        "?api_key=secret123",
			expectStatus: http.StatusUnauthorized,
		},
		{
			name:         "blocked token param",
			query:        "?token=bearer123",
			expectStatus: http.StatusUnauthorized,
		},
		{
			name:         "suspicious long param",
			query:        "?apikey=verylongparamvalueexceeding20chars",
			expectStatus: http.StatusUnauthorized,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			mw := &AuthMiddleware{
				auditEmitter: nil,
			}

			nextCalled := false
			next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
				nextCalled = true
			})

			req := httptest.NewRequest(http.MethodPost, "/api/v1/supply/accounts"+tt.query, nil)
			rec := httptest.NewRecorder()
			mw.QueryKeyRejectMiddleware(next).ServeHTTP(rec, req)

			// The next handler writes nothing, so a passed-through request
			// leaves the recorder at its default 200.
			if rec.Code != tt.expectStatus {
				t.Errorf("expected status %d, got %d", tt.expectStatus, rec.Code)
			}

			wantNext := tt.expectStatus == http.StatusOK
			switch {
			case wantNext && !nextCalled:
				t.Errorf("expected next handler to be called")
			case !wantNext && nextCalled:
				t.Errorf("expected next handler not to be called for a rejected request")
			}
		})
	}
}
// TestBearerExtractMiddleware checks how BearerExtractMiddleware reacts to
// different Authorization header values: a "Bearer <token>" header must reach
// the next handler, while a missing header, a non-Bearer scheme, or an empty
// token must be answered with 401.
//
// For every case both the recorded status code and whether the wrapped
// handler ran are asserted, so a rejected request that still reaches the next
// handler fails the test.
func TestBearerExtractMiddleware(t *testing.T) {
	tests := []struct {
		name         string
		authHeader   string
		expectStatus int
	}{
		{
			name:         "valid bearer",
			authHeader:   "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c",
			expectStatus: http.StatusOK,
		},
		{
			name:         "missing header",
			authHeader:   "",
			expectStatus: http.StatusUnauthorized,
		},
		{
			name:         "wrong prefix",
			authHeader:   "Basic abc123",
			expectStatus: http.StatusUnauthorized,
		},
		{
			name:         "empty token",
			authHeader:   "Bearer ",
			expectStatus: http.StatusUnauthorized,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			mw := &AuthMiddleware{}

			nextCalled := false
			// NOTE(review): an earlier version of this handler read
			// r.Context().Value(bearerTokenKey) but asserted nothing about it.
			// Whether the middleware stores the token under that key cannot be
			// confirmed from this file; add an assertion here once it is.
			next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
				nextCalled = true
			})

			req := httptest.NewRequest(http.MethodPost, "/api/v1/supply/accounts", nil)
			if tt.authHeader != "" {
				req.Header.Set("Authorization", tt.authHeader)
			}
			rec := httptest.NewRecorder()
			mw.BearerExtractMiddleware(next).ServeHTTP(rec, req)

			// The next handler writes nothing, so a passed-through request
			// leaves the recorder at its default 200.
			if rec.Code != tt.expectStatus {
				t.Errorf("expected status %d, got %d", tt.expectStatus, rec.Code)
			}

			wantNext := tt.expectStatus == http.StatusOK
			switch {
			case wantNext && !nextCalled:
				t.Errorf("expected next handler to be called")
			case !wantNext && nextCalled:
				t.Errorf("expected next handler not to be called for a rejected request")
			}
		})
	}
}
// TestContainsScope runs containsScope over a small table of scope lists and
// compares the boolean result with the expectation for each: an exact match,
// a lone "*" entry, a target that is absent, and an empty scope list.
func TestContainsScope(t *testing.T) {
	type scopeCase struct {
		label   string
		granted []string
		wanted  string
		want    bool
	}

	cases := []scopeCase{
		{label: "exact match", granted: []string{"read", "write", "delete"}, wanted: "write", want: true},
		{label: "wildcard", granted: []string{"*"}, wanted: "anything", want: true},
		{label: "no match", granted: []string{"read", "write"}, wanted: "admin", want: false},
		{label: "empty scopes", granted: []string{}, wanted: "read", want: false},
	}

	for _, c := range cases {
		t.Run(c.label, func(t *testing.T) {
			if got := containsScope(c.granted, c.wanted); got != c.want {
				t.Errorf("containsScope(%v, %s) = %v, want %v", c.granted, c.wanted, got, c.want)
			}
		})
	}
}
// TestRoleLevel looks up each role in a fixed three-entry hierarchy via
// roleLevel and compares the result with the expected level. The role
// "unknown" is not in the hierarchy and is expected to yield 0.
func TestRoleLevel(t *testing.T) {
	levels := map[string]int{
		"admin":  3,
		"owner":  2,
		"viewer": 1,
	}

	cases := []struct {
		role string
		want int
	}{
		{role: "admin", want: 3},
		{role: "owner", want: 2},
		{role: "viewer", want: 1},
		{role: "unknown", want: 0},
	}

	for _, c := range cases {
		// The role doubles as the subtest name.
		t.Run(c.role, func(t *testing.T) {
			got := roleLevel(c.role, levels)
			if got == c.want {
				return
			}
			t.Errorf("roleLevel(%s) = %d, want %d", c.role, got, c.want)
		})
	}
}
// TestTokenCache drives one cache returned by NewTokenCache through four
// subtests that share it: looking up a key that was never set, a set followed
// by a get, a set followed by Invalidate, and an entry stored with a
// one-nanosecond TTL that is read back after a short sleep.
func TestTokenCache(t *testing.T) {
	tokens := NewTokenCache()

	t.Run("get empty", func(t *testing.T) {
		got, ok := tokens.Get("nonexistent")
		if ok {
			t.Errorf("expected not found")
		}
		if got != "" {
			t.Errorf("expected empty status")
		}
	})

	t.Run("set and get", func(t *testing.T) {
		tokens.Set("token1", "active", time.Hour)

		got, ok := tokens.Get("token1")
		if !ok {
			t.Errorf("expected to find token1")
		}
		if got != "active" {
			t.Errorf("expected status 'active', got '%s'", got)
		}
	})

	t.Run("invalidate", func(t *testing.T) {
		tokens.Set("token2", "revoked", time.Hour)
		tokens.Invalidate("token2")

		if _, ok := tokens.Get("token2"); ok {
			t.Errorf("expected token2 to be invalidated")
		}
	})

	t.Run("expiration", func(t *testing.T) {
		tokens.Set("token3", "active", time.Nanosecond)
		// Wait well past the one-nanosecond TTL before reading the entry back.
		time.Sleep(time.Millisecond)

		if _, ok := tokens.Get("token3"); ok {
			t.Errorf("expected token3 to be expired")
		}
	})
}
// Helper functions
// createTestToken builds an HS256-signed JWT string for use in tests.
//
// The token carries the given issuer, subject and expiry in its registered
// claims, the current time as IssuedAt, subject again as SubjectID, the given
// role, a fixed scope of {"read", "write"} and TenantID 1. It is signed with
// secretKey.
//
// The helper has no *testing.T, so a signing failure panics rather than
// returning an empty token that would make callers fail, or pass, for the
// wrong reason.
func createTestToken(secretKey, issuer, subject, role string, expiresAt time.Time) string {
	claims := TokenClaims{
		RegisteredClaims: jwt.RegisteredClaims{
			Issuer:    issuer,
			Subject:   subject,
			ExpiresAt: jwt.NewNumericDate(expiresAt),
			IssuedAt:  jwt.NewNumericDate(time.Now()),
		},
		SubjectID: subject,
		Role:      role,
		Scope:     []string{"read", "write"},
		TenantID:  1,
	}

	signed, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(secretKey))
	if err != nil {
		panic("createTestToken: signing token: " + err.Error())
	}
	return signed
}