diff --git a/gateway/internal/router/engine/routing_engine.go b/gateway/internal/router/engine/routing_engine.go index c69e781..d66b151 100644 --- a/gateway/internal/router/engine/routing_engine.go +++ b/gateway/internal/router/engine/routing_engine.go @@ -3,6 +3,7 @@ package engine import ( "context" "errors" + "sync" "lijiaoqiao/gateway/internal/router/strategy" ) @@ -18,6 +19,7 @@ type RoutingMetrics interface { // RoutingEngine 路由引擎 type RoutingEngine struct { + mu sync.RWMutex strategies map[string]strategy.StrategyTemplate metrics RoutingMetrics } @@ -32,6 +34,8 @@ func NewRoutingEngine() *RoutingEngine { // RegisterStrategy 注册路由策略 func (e *RoutingEngine) RegisterStrategy(name string, template strategy.StrategyTemplate) { + e.mu.Lock() + defer e.mu.Unlock() e.strategies[name] = template } @@ -54,8 +58,11 @@ func (e *RoutingEngine) SelectProvider(ctx context.Context, req *strategy.Routin return nil, err } - // 记录指标 - if e.metrics != nil && decision != nil { + if decision == nil { + return nil, ErrStrategyNotFound + } + + if e.metrics != nil { e.metrics.RecordSelection(decision.Provider, decision.Strategy, decision) } diff --git a/gateway/internal/router/engine/routing_engine_test.go b/gateway/internal/router/engine/routing_engine_test.go index ba584ff..2a8e362 100644 --- a/gateway/internal/router/engine/routing_engine_test.go +++ b/gateway/internal/router/engine/routing_engine_test.go @@ -152,3 +152,88 @@ func (m *MockRoutingMetrics) RecordSelection(provider string, strategyName strin m.takeoverMark = decision.TakeoverMark } } + +// ==================== P0问题测试 ==================== + +// TestP0_07_RegisterStrategy_ThreadSafety 测试P0-07: 策略注册非线程安全 +func TestP0_07_RegisterStrategy_ThreadSafety(t *testing.T) { + engine := NewRoutingEngine() + + // 并发注册多个策略,启用-race检测器可以发现数据竞争 + done := make(chan bool) + const goroutines = 100 + + for i := 0; i < goroutines; i++ { + go func(idx int) { + name := strategyName(idx) + tpl := strategy.NewCostBasedTemplate(name, strategy.CostParams{ + MaxCostPer1KTokens: 1.0, + }) + tpl.RegisterProvider("ProviderA", &MockProvider{ + name: "ProviderA", + costPer1KTokens: 0.5, + available: true, + models: []string{"gpt-4"}, + }) + engine.RegisterStrategy(name, tpl) + done <- true + }(i) + } + + // 等待所有goroutine完成 + for i := 0; i < goroutines; i++ { + <-done + } + + // 验证所有策略都已注册 + for i := 0; i < goroutines; i++ { + name := strategyName(i) + _, ok := engine.strategies[name] + assert.True(t, ok, "Strategy %s should be registered", name) + } +} + +func strategyName(idx int) string { + return "strategy_" + string(rune('a'+idx%26)) + string(rune('0'+idx/26%10)) +} + +// TestP0_08_DecisionNilPanic 测试P0-08: decision可能为空指针 +func TestP0_08_DecisionNilPanic(t *testing.T) { + engine := NewRoutingEngine() + + // 创建一个返回nil decision但不返回错误的策略 + nilDecisionStrategy := &NilDecisionStrategy{} + + engine.RegisterStrategy("nil_decision", nilDecisionStrategy) + + // 设置metrics + engine.metrics = &MockRoutingMetrics{} + + req := &strategy.RoutingRequest{ + Model: "gpt-4", + UserID: "user123", + } + + // 验证返回ErrStrategyNotFound而不是panic + decision, err := engine.SelectProvider(context.Background(), req, "nil_decision") + + assert.Error(t, err, "Should return error when decision is nil") + assert.Equal(t, ErrStrategyNotFound, err, "Should return ErrStrategyNotFound") + assert.Nil(t, decision, "Decision should be nil") +} + +// NilDecisionStrategy 返回nil decision的测试策略 +type NilDecisionStrategy struct{} + +func (s *NilDecisionStrategy) SelectProvider(ctx context.Context, req *strategy.RoutingRequest) (*strategy.RoutingDecision, error) { + // 返回nil decision但不返回错误 - 这模拟了潜在的边界情况 + return nil, nil +} + +func (s *NilDecisionStrategy) Name() string { + return "nil_decision" +} + +func (s *NilDecisionStrategy) Type() string { + return "nil_decision" +}