Files
lijiaoqiao/gateway/internal/router/router_test.go
Your Name 89104bd0db feat(P1/P2): 完成TDD开发及P1/P2设计文档
## 设计文档
- multi_role_permission_design: 多角色权限设计 (CONDITIONAL GO)
- audit_log_enhancement_design: 审计日志增强 (CONDITIONAL GO)
- routing_strategy_template_design: 路由策略模板 (CONDITIONAL GO)
- sso_saml_technical_research: SSO/SAML调研 (CONDITIONAL GO)
- compliance_capability_package_design: 合规能力包设计 (CONDITIONAL GO)

## TDD开发成果
- IAM模块: supply-api/internal/iam/ (111个测试)
- 审计日志模块: supply-api/internal/audit/ (40+测试)
- 路由策略模块: gateway/internal/router/ (33+测试)
- 合规能力包: gateway/internal/compliance/ + scripts/ci/compliance/

## 规范文档
- parallel_agent_output_quality_standards: 并行Agent产出质量规范
- project_experience_summary: 项目经验总结 (v2)
- 2026-04-02-p1-p2-tdd-execution-plan: TDD执行计划

## 评审报告
- 5个CONDITIONAL GO设计文档评审报告
- fix_verification_report: 修复验证报告
- full_verification_report: 全面质量验证报告
- tdd_module_quality_verification: TDD模块质量验证
- tdd_execution_summary: TDD执行总结

依据: Superpowers执行框架 + TDD规范
2026-04-02 23:35:53 +08:00

578 lines
16 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package router
import (
"context"
"math"
"testing"
"time"
"lijiaoqiao/gateway/internal/adapter"
)
// mockProvider is a test double that implements the adapter.ProviderAdapter
// interface. Its fields are simply echoed back by the accessor methods below.
type mockProvider struct {
name string // value returned by ProviderName
models []string // value returned by SupportedModels
healthy bool // value returned by HealthCheck
}
// ChatCompletion is a stub: it ignores all arguments and always returns
// (nil, nil).
func (m *mockProvider) ChatCompletion(ctx context.Context, model string, messages []adapter.Message, options adapter.CompletionOptions) (*adapter.CompletionResponse, error) {
return nil, nil
}
// ChatCompletionStream is a stub: it ignores all arguments and always returns
// a nil channel and a nil error.
func (m *mockProvider) ChatCompletionStream(ctx context.Context, model string, messages []adapter.Message, options adapter.CompletionOptions) (<-chan *adapter.StreamChunk, error) {
return nil, nil
}
// GetUsage is a stub: it ignores response and returns the zero-value
// adapter.Usage.
func (m *mockProvider) GetUsage(response *adapter.CompletionResponse) adapter.Usage {
return adapter.Usage{}
}
// MapError is a stub: it ignores err and returns the zero-value
// adapter.ProviderError.
func (m *mockProvider) MapError(err error) adapter.ProviderError {
return adapter.ProviderError{}
}
// HealthCheck reports the mock's configured healthy flag; ctx is ignored.
func (m *mockProvider) HealthCheck(ctx context.Context) bool {
return m.healthy
}
// ProviderName returns the mock's configured name.
func (m *mockProvider) ProviderName() string {
return m.name
}
// SupportedModels returns the mock's configured model list (the same slice,
// not a copy).
func (m *mockProvider) SupportedModels() []string {
return m.models
}
// TestNewRouter verifies that NewRouter returns a non-nil router that records
// the requested strategy and starts out with no registered providers.
func TestNewRouter(t *testing.T) {
	rt := NewRouter(StrategyLatency)
	if rt == nil {
		t.Fatal("expected non-nil router")
	}

	if got := rt.strategy; got != StrategyLatency {
		t.Errorf("expected strategy latency, got %s", got)
	}
	if count := len(rt.providers); count != 0 {
		t.Errorf("expected 0 providers, got %d", count)
	}
}
// TestRegisterProvider verifies that registering a provider adds it to the
// router and creates a health entry that carries the provider name and is
// marked available.
func TestRegisterProvider(t *testing.T) {
	rt := NewRouter(StrategyLatency)
	rt.RegisterProvider("test", &mockProvider{
		name:    "test",
		models:  []string{"gpt-4"},
		healthy: true,
	})

	if count := len(rt.providers); count != 1 {
		t.Errorf("expected 1 provider, got %d", count)
	}

	entry := rt.health["test"]
	if entry == nil {
		// Nothing further can be inspected without a health entry.
		t.Fatal("expected health to be registered")
	}
	if got := entry.Name; got != "test" {
		t.Errorf("expected name test, got %s", got)
	}
	if available := entry.Available; !available {
		t.Error("expected provider to be available")
	}
}
// TestSelectProvider_NoProviders verifies that SelectProvider returns an error
// when no provider has been registered.
func TestSelectProvider_NoProviders(t *testing.T) {
	rt := NewRouter(StrategyLatency)

	if _, err := rt.SelectProvider(context.Background(), "gpt-4"); err == nil {
		t.Fatal("expected error")
	}
}
// TestSelectProvider_BasicSelection verifies that the single registered
// provider is returned when it supports the requested model.
func TestSelectProvider_BasicSelection(t *testing.T) {
	rt := NewRouter(StrategyLatency)
	rt.RegisterProvider("test", &mockProvider{
		name:    "test",
		models:  []string{"gpt-4"},
		healthy: true,
	})

	chosen, err := rt.SelectProvider(context.Background(), "gpt-4")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if got := chosen.ProviderName(); got != "test" {
		t.Errorf("expected provider test, got %s", got)
	}
}
// TestSelectProvider_ModelNotSupported verifies that SelectProvider returns an
// error when the only registered provider does not list the requested model.
func TestSelectProvider_ModelNotSupported(t *testing.T) {
	rt := NewRouter(StrategyLatency)
	rt.RegisterProvider("test", &mockProvider{
		name:    "test",
		models:  []string{"gpt-3.5"},
		healthy: true,
	})

	if _, err := rt.SelectProvider(context.Background(), "gpt-4"); err == nil {
		t.Fatal("expected error")
	}
}
// TestSelectProvider_ProviderUnavailable verifies that SelectProvider returns
// an error once the only registered provider has been marked unavailable.
func TestSelectProvider_ProviderUnavailable(t *testing.T) {
	rt := NewRouter(StrategyLatency)
	rt.RegisterProvider("test", &mockProvider{
		name:    "test",
		models:  []string{"gpt-4"},
		healthy: true,
	})

	// Mark the provider as unavailable through UpdateHealth.
	rt.UpdateHealth("test", false)

	if _, err := rt.SelectProvider(context.Background(), "gpt-4"); err == nil {
		t.Fatal("expected error")
	}
}
// TestSelectProvider_WildcardModel verifies that a provider whose model list is
// the single entry "*" is selected for an arbitrary model name.
func TestSelectProvider_WildcardModel(t *testing.T) {
	rt := NewRouter(StrategyLatency)
	rt.RegisterProvider("test", &mockProvider{
		name:    "test",
		models:  []string{"*"},
		healthy: true,
	})

	chosen, err := rt.SelectProvider(context.Background(), "any-model")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if got := chosen.ProviderName(); got != "test" {
		t.Errorf("expected provider test, got %s", got)
	}
}
// TestSelectProvider_MultipleProviders verifies that, under the latency
// strategy, the provider with the lower recorded latency is selected.
func TestSelectProvider_MultipleProviders(t *testing.T) {
	rt := NewRouter(StrategyLatency)
	rt.RegisterProvider("fast", &mockProvider{
		name:    "fast",
		models:  []string{"gpt-4"},
		healthy: true,
	})
	rt.RegisterProvider("slow", &mockProvider{
		name:    "slow",
		models:  []string{"gpt-4"},
		healthy: true,
	})

	// Seed the initial latencies directly on the health entries.
	fastHealth, slowHealth := rt.health["fast"], rt.health["slow"]
	fastHealth.LatencyMs = 10
	slowHealth.LatencyMs = 100

	chosen, err := rt.SelectProvider(context.Background(), "gpt-4")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if got := chosen.ProviderName(); got != "fast" {
		t.Errorf("expected fastest provider, got %s", got)
	}
}
// TestRecordResult_Success verifies that recording a single successful result
// stores the reported latency and leaves the failure rate at zero.
//
// The previous version compared the latency against its initial value inside
// an empty if-body, so nothing about the latency was actually asserted. The
// expectation used here (the first sample is stored as-is) is the same one
// TestRecordResult_LatencyUpdate relies on for its first record.
func TestRecordResult_Success(t *testing.T) {
	r := NewRouter(StrategyLatency)
	prov := &mockProvider{name: "test", models: []string{"gpt-4"}, healthy: true}
	r.RegisterProvider("test", prov)

	r.RecordResult(context.Background(), "test", true, 50)

	health := r.health["test"]
	if health.LatencyMs != 50 {
		t.Errorf("expected latency 50 after first record, got %d", health.LatencyMs)
	}
	if health.FailureRate != 0 {
		t.Errorf("expected failure rate 0, got %f", health.FailureRate)
	}
}
// TestRecordResult_Failure verifies that recording one failed result leaves the
// provider with a non-zero failure rate.
func TestRecordResult_Failure(t *testing.T) {
	rt := NewRouter(StrategyLatency)
	rt.RegisterProvider("test", &mockProvider{
		name:    "test",
		models:  []string{"gpt-4"},
		healthy: true,
	})

	rt.RecordResult(context.Background(), "test", false, 100)

	if rate := rt.health["test"].FailureRate; rate == 0 {
		t.Error("expected failure rate to increase")
	}
}
// TestRecordResult_MultipleFailures verifies that a provider is marked
// unavailable after seven consecutive failed results.
//
// Per the original author's note (the implementation is not shown here), each
// failure updates the rate as newRate = oldRate*0.9 + 0.1, so seven failures
// are needed to cross 0.5 (0.469 -> 0.522), at which point the provider should
// be flagged unavailable.
func TestRecordResult_MultipleFailures(t *testing.T) {
	const failuresToCrossThreshold = 7

	rt := NewRouter(StrategyLatency)
	rt.RegisterProvider("test", &mockProvider{
		name:    "test",
		models:  []string{"gpt-4"},
		healthy: true,
	})

	ctx := context.Background()
	for remaining := failuresToCrossThreshold; remaining > 0; remaining-- {
		rt.RecordResult(ctx, "test", false, 100)
	}

	// A failure rate above 0.5 should have marked the provider unavailable.
	if stillAvailable := rt.health["test"].Available; stillAvailable {
		t.Error("expected provider to be marked unavailable")
	}
}
// TestUpdateHealth verifies that UpdateHealth(name, false) clears the
// Available flag on the provider's health entry.
func TestUpdateHealth(t *testing.T) {
	rt := NewRouter(StrategyLatency)
	rt.RegisterProvider("test", &mockProvider{
		name:    "test",
		models:  []string{"gpt-4"},
		healthy: true,
	})

	rt.UpdateHealth("test", false)

	if stillAvailable := rt.health["test"].Available; stillAvailable {
		t.Error("expected provider to be unavailable")
	}
}
// TestGetHealthStatus verifies that GetHealthStatus reports exactly one entry
// for a single registered provider and that the entry is marked available.
func TestGetHealthStatus(t *testing.T) {
	rt := NewRouter(StrategyLatency)
	rt.RegisterProvider("test", &mockProvider{
		name:    "test",
		models:  []string{"gpt-4"},
		healthy: true,
	})

	statuses := rt.GetHealthStatus()
	if count := len(statuses); count != 1 {
		t.Errorf("expected 1 health status, got %d", count)
	}

	entry := statuses["test"]
	if entry == nil {
		t.Fatal("expected health for test")
	}
	if !entry.Available {
		t.Error("expected available")
	}
}
// TestGetHealthStatus_Empty verifies that a router with no registered
// providers reports no health entries.
func TestGetHealthStatus_Empty(t *testing.T) {
	statuses := NewRouter(StrategyLatency).GetHealthStatus()

	if count := len(statuses); count != 0 {
		t.Errorf("expected 0 health statuses, got %d", count)
	}
}
// TestSelectByLatency_EqualLatency verifies that selectByLatency succeeds and
// returns one of the candidates when both have the same recorded latency. The
// test does not pin which of the two is chosen.
func TestSelectByLatency_EqualLatency(t *testing.T) {
	rt := NewRouter(StrategyLatency)
	rt.RegisterProvider("p1", &mockProvider{
		name:    "p1",
		models:  []string{"gpt-4"},
		healthy: true,
	})
	rt.RegisterProvider("p2", &mockProvider{
		name:    "p2",
		models:  []string{"gpt-4"},
		healthy: true,
	})

	// Give both providers identical latency.
	rt.health["p1"].LatencyMs = 50
	rt.health["p2"].LatencyMs = 50

	chosen, err := rt.selectByLatency([]string{"p1", "p2"})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// Either candidate is an acceptable answer.
	switch got := chosen.ProviderName(); got {
	case "p1", "p2":
	default:
		t.Errorf("unexpected provider: %s", got)
	}
}
// TestSelectByLatency_NoProviders verifies that selectByLatency returns an
// error when given an empty candidate list.
func TestSelectByLatency_NoProviders(t *testing.T) {
	rt := NewRouter(StrategyLatency)

	if _, err := rt.selectByLatency([]string{}); err == nil {
		t.Fatal("expected error")
	}
}
// TestSelectByWeight verifies that selectByWeight succeeds and returns one of
// the two weighted candidates. It does not check the selection distribution.
//
// Per the original author's note (the implementation is not shown here), the
// implementation derives its random value as
// time.Now().UnixNano()/MaxInt64 * totalWeight, which is small on most systems
// and may therefore always pick the first provider. The note calls this a
// known limitation of the implementation.
func TestSelectByWeight(t *testing.T) {
	rt := NewRouter(StrategyLatency)
	rt.RegisterProvider("p1", &mockProvider{
		name:    "p1",
		models:  []string{"gpt-4"},
		healthy: true,
	})
	rt.RegisterProvider("p2", &mockProvider{
		name:    "p2",
		models:  []string{"gpt-4"},
		healthy: true,
	})

	rt.health["p1"].Weight = 3.0
	rt.health["p2"].Weight = 1.0

	// Only check that a result is produced without error.
	chosen, err := rt.selectByWeight([]string{"p1", "p2"})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// Either candidate is an acceptable answer.
	switch got := chosen.ProviderName(); got {
	case "p1", "p2":
	default:
		t.Errorf("unexpected provider: %s", got)
	}
}
// TestSelectByWeight_SingleProvider verifies that selectByWeight returns the
// only candidate when the list has a single entry.
func TestSelectByWeight_SingleProvider(t *testing.T) {
	rt := NewRouter(StrategyLatency)
	rt.RegisterProvider("p1", &mockProvider{
		name:    "p1",
		models:  []string{"gpt-4"},
		healthy: true,
	})
	rt.health["p1"].Weight = 2.0

	chosen, err := rt.selectByWeight([]string{"p1"})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if got := chosen.ProviderName(); got != "p1" {
		t.Errorf("expected p1, got %s", got)
	}
}
// TestSelectByAvailability verifies that selectByAvailability picks the
// candidate with the lower failure rate.
func TestSelectByAvailability(t *testing.T) {
	rt := NewRouter(StrategyLatency)
	rt.RegisterProvider("p1", &mockProvider{
		name:    "p1",
		models:  []string{"gpt-4"},
		healthy: true,
	})
	rt.RegisterProvider("p2", &mockProvider{
		name:    "p2",
		models:  []string{"gpt-4"},
		healthy: true,
	})

	rt.health["p1"].FailureRate = 0.3
	rt.health["p2"].FailureRate = 0.1

	chosen, err := rt.selectByAvailability([]string{"p1", "p2"})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if got := chosen.ProviderName(); got != "p2" {
		t.Errorf("expected provider with lower failure rate, got %s", got)
	}
}
// TestGetFallbackProviders verifies that, with providers "primary" and
// "fallback" registered for the same model, GetFallbackProviders returns
// exactly one provider and that it is "fallback".
func TestGetFallbackProviders(t *testing.T) {
	r := NewRouter(StrategyLatency)
	prov1 := &mockProvider{name: "primary", models: []string{"gpt-4"}, healthy: true}
	prov2 := &mockProvider{name: "fallback", models: []string{"gpt-4"}, healthy: true}
	r.RegisterProvider("primary", prov1)
	r.RegisterProvider("fallback", prov2)

	fallbacks, err := r.GetFallbackProviders(context.Background(), "gpt-4")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// Stop here on a length mismatch: indexing fallbacks[0] below would panic
	// on an empty slice and hide the real failure.
	if len(fallbacks) != 1 {
		t.Fatalf("expected 1 fallback, got %d", len(fallbacks))
	}
	if fallbacks[0].ProviderName() != "fallback" {
		t.Errorf("expected fallback, got %s", fallbacks[0].ProviderName())
	}
}
// TestGetFallbackProviders_AllUnavailable verifies that GetFallbackProviders
// returns no fallbacks and no error when only a provider named "primary" is
// registered.
//
// NOTE(review): despite the test name, nothing here is marked unavailable; the
// single registered provider is left in its freshly registered state. Confirm
// whether an UpdateHealth call was intended.
func TestGetFallbackProviders_AllUnavailable(t *testing.T) {
	rt := NewRouter(StrategyLatency)
	rt.RegisterProvider("primary", &mockProvider{
		name:    "primary",
		models:  []string{"gpt-4"},
		healthy: true,
	})

	alternates, err := rt.GetFallbackProviders(context.Background(), "gpt-4")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if count := len(alternates); count != 0 {
		t.Errorf("expected 0 fallbacks, got %d", count)
	}
}
// TestRecordResult_LatencyUpdate verifies how recorded latencies accumulate:
// the first sample is stored as-is, and the second is blended with the stored
// value as (old*7 + new) / 8.
//
// The expected value is computed with integer division, so for samples 100 and
// 200 it is 112 (the exact weighted average would be 112.5).
func TestRecordResult_LatencyUpdate(t *testing.T) {
	const (
		firstSample  = 100
		secondSample = 200
	)

	rt := NewRouter(StrategyLatency)
	rt.RegisterProvider("test", &mockProvider{
		name:    "test",
		models:  []string{"gpt-4"},
		healthy: true,
	})
	ctx := context.Background()

	// First record: the sample becomes the stored latency.
	rt.RecordResult(ctx, "test", true, firstSample)
	if got := rt.health["test"].LatencyMs; got != firstSample {
		t.Errorf("expected latency 100, got %d", got)
	}

	// Second record: blended with the previous value using 7/8 and 1/8 weights.
	rt.RecordResult(ctx, "test", true, secondSample)
	want := int64((firstSample*7 + secondSample) / 8)
	if got := rt.health["test"].LatencyMs; got != want {
		t.Errorf("expected latency %d, got %d", want, got)
	}
}
// TestRecordResult_UnknownProvider verifies that recording a result for a
// provider that was never registered does not panic. There is no explicit
// assertion; the test passes as long as the call returns.
func TestRecordResult_UnknownProvider(t *testing.T) {
	NewRouter(StrategyLatency).RecordResult(context.Background(), "unknown", true, 100)
}
// TestUpdateHealth_UnknownProvider verifies that updating the health of a
// provider that was never registered does not panic. There is no explicit
// assertion; the test passes as long as the call returns.
func TestUpdateHealth_UnknownProvider(t *testing.T) {
	NewRouter(StrategyLatency).UpdateHealth("unknown", false)
}
// TestIsProviderAvailable verifies that isProviderAvailable reports true for
// models the provider lists and false for a model it does not list.
func TestIsProviderAvailable(t *testing.T) {
	rt := NewRouter(StrategyLatency)
	rt.RegisterProvider("test", &mockProvider{
		name:    "test",
		models:  []string{"gpt-4", "gpt-3.5"},
		healthy: true,
	})

	cases := []struct {
		model string
		want  bool
	}{
		{model: "gpt-4", want: true},
		{model: "gpt-3.5", want: true},
		{model: "claude", want: false},
	}

	for i := range cases {
		tc := cases[i]
		got := rt.isProviderAvailable("test", tc.model)
		if got == tc.want {
			continue
		}
		t.Errorf("isProviderAvailable(%s) = %v, want %v", tc.model, got, tc.want)
	}
}
// TestIsProviderAvailable_UnknownProvider verifies that isProviderAvailable
// reports false for a provider name that was never registered.
func TestIsProviderAvailable_UnknownProvider(t *testing.T) {
	rt := NewRouter(StrategyLatency)

	if available := rt.isProviderAvailable("unknown", "gpt-4"); available {
		t.Error("expected false for unknown provider")
	}
}
// TestIsProviderAvailable_Unhealthy verifies that isProviderAvailable reports
// false for a supported model once the provider has been marked unavailable.
func TestIsProviderAvailable_Unhealthy(t *testing.T) {
	rt := NewRouter(StrategyLatency)
	rt.RegisterProvider("test", &mockProvider{
		name:    "test",
		models:  []string{"gpt-4"},
		healthy: true,
	})

	// Mark the provider as unavailable through UpdateHealth.
	rt.UpdateHealth("test", false)

	if available := rt.isProviderAvailable("test", "gpt-4"); available {
		t.Error("expected false for unhealthy provider")
	}
}
// TestProviderHealth_Struct is a sanity check that the fields of a
// ProviderHealth literal read back with the values they were given.
// LastCheckTime is set but not asserted.
func TestProviderHealth_Struct(t *testing.T) {
	now := time.Now()
	h := &ProviderHealth{
		Name:          "test",
		Available:     true,
		LatencyMs:     50,
		FailureRate:   0.1,
		Weight:        1.0,
		LastCheckTime: now,
	}

	if got := h.Name; got != "test" {
		t.Errorf("expected name test, got %s", got)
	}
	if available := h.Available; !available {
		t.Error("expected available")
	}
	if got := h.LatencyMs; got != 50 {
		t.Errorf("expected latency 50, got %d", got)
	}
	// Exact float comparison is fine here: the values are read back unchanged.
	if got := h.FailureRate; got != 0.1 {
		t.Errorf("expected failure rate 0.1, got %f", got)
	}
	if got := h.Weight; got != 1.0 {
		t.Errorf("expected weight 1.0, got %f", got)
	}
}
// TestLoadBalancerStrategy_Constants pins the string values of the four
// load-balancer strategy constants.
func TestLoadBalancerStrategy_Constants(t *testing.T) {
	if got := StrategyLatency; got != "latency" {
		t.Errorf("expected latency, got %s", got)
	}

	if got := StrategyRoundRobin; got != "round_robin" {
		t.Errorf("expected round_robin, got %s", got)
	}

	if got := StrategyWeighted; got != "weighted" {
		t.Errorf("expected weighted, got %s", got)
	}

	if got := StrategyAvailability; got != "availability" {
		t.Errorf("expected availability, got %s", got)
	}
}
// TestSelectProvider_AllStrategies verifies that SelectProvider returns the
// only registered provider under the latency, weighted and availability
// strategies.
//
// NOTE(review): StrategyRoundRobin is not in the list below; confirm whether
// that omission is intentional.
func TestSelectProvider_AllStrategies(t *testing.T) {
	strategies := []LoadBalancerStrategy{StrategyLatency, StrategyWeighted, StrategyAvailability}
	for _, strategy := range strategies {
		r := NewRouter(strategy)
		prov := &mockProvider{name: "test", models: []string{"gpt-4"}, healthy: true}
		r.RegisterProvider("test", prov)

		selected, err := r.SelectProvider(context.Background(), "gpt-4")
		if err != nil {
			t.Errorf("strategy %s: unexpected error: %v", strategy, err)
			// selected is unusable when err is non-nil; calling a method on it
			// would panic and abort the remaining strategies.
			continue
		}
		if selected.ProviderName() != "test" {
			t.Errorf("strategy %s: expected provider test, got %s", strategy, selected.ProviderName())
		}
	}
}
// TestRecordResult_FailureRateCapped ensures that FailureRate never exceeds
// 1.0, even after twenty consecutive failed results.
func TestRecordResult_FailureRateCapped(t *testing.T) {
	const failureCount = 20

	rt := NewRouter(StrategyLatency)
	rt.RegisterProvider("test", &mockProvider{
		name:    "test",
		models:  []string{"gpt-4"},
		healthy: true,
	})

	// Record many failures in a row.
	ctx := context.Background()
	for remaining := failureCount; remaining > 0; remaining-- {
		rt.RecordResult(ctx, "test", false, 100)
	}

	if rate := rt.health["test"].FailureRate; rate > 1.0 {
		t.Errorf("failure rate should be capped at 1.0, got %f", rate)
	}
}
// TestRecordResult_LatencyNeverNegative ensures that LatencyMs never becomes
// negative, even when a negative latency sample is recorded.
func TestRecordResult_LatencyNeverNegative(t *testing.T) {
	rt := NewRouter(StrategyLatency)
	rt.RegisterProvider("test", &mockProvider{
		name:    "test",
		models:  []string{"gpt-4"},
		healthy: true,
	})

	// Feed in a negative latency sample.
	rt.RecordResult(context.Background(), "test", true, -100)

	if got := rt.health["test"].LatencyMs; got < 0 {
		t.Errorf("latency should never be negative, got %d", got)
	}
}
// TestSelectByLatency_MaxInt64 ensures that latencies at the top of the int64
// range do not overflow the comparison: with p1 at MaxInt64-1 and p2 at
// MaxInt64, selectByLatency must still pick p1.
func TestSelectByLatency_MaxInt64(t *testing.T) {
	rt := NewRouter(StrategyLatency)
	rt.RegisterProvider("p1", &mockProvider{
		name:    "p1",
		models:  []string{"gpt-4"},
		healthy: true,
	})
	rt.RegisterProvider("p2", &mockProvider{
		name:    "p2",
		models:  []string{"gpt-4"},
		healthy: true,
	})

	// p1 gets a very large latency; p2 gets the maximum representable value.
	rt.health["p1"].LatencyMs = math.MaxInt64 - 1
	rt.health["p2"].LatencyMs = math.MaxInt64

	chosen, err := rt.selectByLatency([]string{"p1", "p2"})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// p1 has the lower latency and should therefore be selected.
	if got := chosen.ProviderName(); got != "p1" {
		t.Errorf("expected provider p1 (lower latency), got %s", got)
	}
}