Files
lijiaoqiao/gateway/internal/router/strategy/cost_based.go
Your Name 89104bd0db feat(P1/P2): 完成TDD开发及P1/P2设计文档
## 设计文档
- multi_role_permission_design: 多角色权限设计 (CONDITIONAL GO)
- audit_log_enhancement_design: 审计日志增强 (CONDITIONAL GO)
- routing_strategy_template_design: 路由策略模板 (CONDITIONAL GO)
- sso_saml_technical_research: SSO/SAML调研 (CONDITIONAL GO)
- compliance_capability_package_design: 合规能力包设计 (CONDITIONAL GO)

## TDD开发成果
- IAM模块: supply-api/internal/iam/ (111个测试)
- 审计日志模块: supply-api/internal/audit/ (40+测试)
- 路由策略模块: gateway/internal/router/ (33+测试)
- 合规能力包: gateway/internal/compliance/ + scripts/ci/compliance/

## 规范文档
- parallel_agent_output_quality_standards: 并行Agent产出质量规范
- project_experience_summary: 项目经验总结 (v2)
- 2026-04-02-p1-p2-tdd-execution-plan: TDD执行计划

## 评审报告
- 5个CONDITIONAL GO设计文档评审报告
- fix_verification_report: 修复验证报告
- full_verification_report: 全面质量验证报告
- tdd_module_quality_verification: TDD模块质量验证
- tdd_execution_summary: TDD执行总结

依据: Superpowers执行框架 + TDD规范
2026-04-02 23:35:53 +08:00

133 lines
3.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package strategy
import (
"context"
"errors"
"sort"
"lijiaoqiao/gateway/internal/adapter"
gwerror "lijiaoqiao/gateway/pkg/error"
)
// ErrNoAffordableProvider is returned by SelectProvider when candidates exist
// but none of them has a cost within the effective budget.
var ErrNoAffordableProvider = errors.New("no affordable provider available")
// CostBasedTemplate is the cost-first routing strategy template.
// It selects the provider with the lowest cost.
type CostBasedTemplate struct {
	// name is the strategy name reported by Name.
	name string
	// maxCostPer1KTokens is the template-level cost ceiling used by SelectProvider.
	maxCostPer1KTokens float64
	// providers holds the registered providers keyed by their registration name.
	providers map[string]adapter.ProviderAdapter
}
// CostParams holds the cost parameters passed to NewCostBasedTemplate.
type CostParams struct {
	// MaxCostPer1KTokens is the maximum cost ($/1K tokens).
	MaxCostPer1KTokens float64
}
// NewCostBasedTemplate builds a cost-first strategy template with the given
// name and the cost ceiling taken from params. The returned template starts
// with an empty, non-nil provider registry.
func NewCostBasedTemplate(name string, params CostParams) *CostBasedTemplate {
	tpl := new(CostBasedTemplate)
	tpl.name = name
	tpl.maxCostPer1KTokens = params.MaxCostPer1KTokens
	tpl.providers = map[string]adapter.ProviderAdapter{}
	return tpl
}
// RegisterProvider stores provider in the template under name, replacing any
// provider previously registered with the same name.
//
// The registry map is created on first use so that a zero-value
// CostBasedTemplate (one not built by NewCostBasedTemplate) does not panic on
// a nil-map write.
func (t *CostBasedTemplate) RegisterProvider(name string, provider adapter.ProviderAdapter) {
	if t.providers == nil {
		t.providers = make(map[string]adapter.ProviderAdapter)
	}
	t.providers[name] = provider
}
// Name returns the strategy name given to the template at construction.
func (t *CostBasedTemplate) Name() string {
	return t.name
}
// Type returns the strategy type identifier, which is always "cost_based".
func (t *CostBasedTemplate) Type() string {
	const strategyType = "cost_based"
	return strategyType
}
// SelectProvider picks the cheapest registered provider that supports
// req.Model (or advertises the wildcard model "*"), passes its health check,
// and fits within the cost budget.
//
// The budget is the template's maxCostPer1KTokens, lowered to req.MaxCost when
// req.MaxCost is positive and smaller. Note that a template limit of 0 is
// applied as a hard cap of 0, so only zero-cost providers qualify in that case.
//
// Candidates with equal cost are ordered by their registration name, so the
// decision does not depend on map iteration order.
//
// Errors:
//   - a gateway error with code ROUTER_NO_PROVIDER_AVAILABLE when no provider
//     is registered, or when none both supports the model and is healthy;
//   - a plain error when req is nil;
//   - ErrNoAffordableProvider when every candidate exceeds the budget.
func (t *CostBasedTemplate) SelectProvider(ctx context.Context, req *RoutingRequest) (*RoutingDecision, error) {
	if len(t.providers) == 0 {
		return nil, gwerror.NewGatewayError(gwerror.ROUTER_NO_PROVIDER_AVAILABLE, "no provider registered")
	}
	if req == nil {
		// Report an error instead of panicking on req.Model below.
		return nil, errors.New("routing request is nil")
	}

	// Collect every usable provider together with its cost.
	type candidate struct {
		name string
		cost float64
	}
	candidates := make([]candidate, 0, len(t.providers))
	for name, provider := range t.providers {
		if provider == nil {
			// A nil registration cannot serve traffic; skip it rather than panic.
			continue
		}

		// Check whether the provider supports the requested model.
		supported := false
		for _, m := range provider.SupportedModels() {
			if m == req.Model || m == "*" {
				supported = true
				break
			}
		}
		if !supported {
			continue
		}

		// Skip providers that fail their health check.
		if !provider.HealthCheck(ctx) {
			continue
		}

		candidates = append(candidates, candidate{name: name, cost: t.getProviderCost(provider)})
	}
	if len(candidates) == 0 {
		return nil, gwerror.NewGatewayError(gwerror.ROUTER_NO_PROVIDER_AVAILABLE, "no available provider for model: "+req.Model)
	}

	// Sort by ascending cost. sort.Slice is not stable and the candidates were
	// gathered in random map order, so break ties by name to keep the routing
	// decision reproducible.
	sort.Slice(candidates, func(i, j int) bool {
		if candidates[i].cost != candidates[j].cost {
			return candidates[i].cost < candidates[j].cost
		}
		return candidates[i].name < candidates[j].name
	})

	// Effective budget: the template limit, tightened by a positive per-request limit.
	maxCost := t.maxCostPer1KTokens
	if req.MaxCost > 0 && req.MaxCost < maxCost {
		maxCost = req.MaxCost
	}

	// Return the cheapest candidate that is within budget.
	for _, c := range candidates {
		if c.cost <= maxCost {
			return &RoutingDecision{
				Provider:        c.name,
				Strategy:        t.Type(),
				CostPer1KTokens: c.cost,
				TakeoverMark:    true, // M-008: mark the decision as a takeover
			}, nil
		}
	}
	return nil, ErrNoAffordableProvider
}
// CostAwareProvider is the cost-aware provider interface. A provider that
// implements it has its cost read through GetCostPer1KTokens by
// getProviderCost instead of receiving the default cost.
type CostAwareProvider interface {
	// GetCostPer1KTokens returns the provider's cost per 1K tokens
	// (presumably in the same $/1K-token unit as CostParams — confirm with implementers).
	GetCostPer1KTokens() float64
}
// getProviderCost returns the cost of provider. Providers that implement
// CostAwareProvider report their own cost; every other provider gets the
// placeholder default of 0.5, which should eventually come from provider
// metadata.
func (t *CostBasedTemplate) getProviderCost(provider adapter.ProviderAdapter) float64 {
	costAware, ok := provider.(CostAwareProvider)
	if !ok {
		// Not cost-aware: fall back to the placeholder default.
		return 0.5
	}
	return costAware.GetCostPer1KTokens()
}