190 lines
4.8 KiB
Go
190 lines
4.8 KiB
Go
|
|
package strategy
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"errors"
|
||
|
|
|
||
|
|
"lijiaoqiao/gateway/internal/adapter"
|
||
|
|
"lijiaoqiao/gateway/internal/router/scoring"
|
||
|
|
gwerror "lijiaoqiao/gateway/pkg/error"
|
||
|
|
)
|
||
|
|
|
||
|
|
// ErrNoQualifiedProvider 没有符合条件的Provider
|
||
|
|
var ErrNoQualifiedProvider = errors.New("no qualified provider available")
|
||
|
|
|
||
|
|
// CostAwareTemplate 成本感知策略模板
|
||
|
|
// 综合考虑成本、质量、延迟进行权衡
|
||
|
|
type CostAwareTemplate struct {
|
||
|
|
name string
|
||
|
|
maxCostPer1KTokens float64
|
||
|
|
maxLatencyMs int64
|
||
|
|
minQualityScore float64
|
||
|
|
providers map[string]adapter.ProviderAdapter
|
||
|
|
scoringModel *scoring.ScoringModel
|
||
|
|
}
|
||
|
|
|
||
|
|
// CostAwareParams carries the limits used to configure a CostAwareTemplate.
type CostAwareParams struct {
	MaxCostPer1KTokens float64 // highest acceptable cost per 1K tokens
	MaxLatencyMs       int64   // highest acceptable latency, in milliseconds
	MinQualityScore    float64 // lowest acceptable quality score
}
|
||
|
|
|
||
|
|
// NewCostAwareTemplate 创建成本感知策略模板
|
||
|
|
func NewCostAwareTemplate(name string, params CostAwareParams) *CostAwareTemplate {
|
||
|
|
return &CostAwareTemplate{
|
||
|
|
name: name,
|
||
|
|
maxCostPer1KTokens: params.MaxCostPer1KTokens,
|
||
|
|
maxLatencyMs: params.MaxLatencyMs,
|
||
|
|
minQualityScore: params.MinQualityScore,
|
||
|
|
providers: make(map[string]adapter.ProviderAdapter),
|
||
|
|
scoringModel: scoring.NewScoringModel(scoring.DefaultWeights),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// RegisterProvider 注册Provider
|
||
|
|
func (t *CostAwareTemplate) RegisterProvider(name string, provider adapter.ProviderAdapter) {
|
||
|
|
t.providers[name] = provider
|
||
|
|
}
|
||
|
|
|
||
|
|
// Name 获取策略名称
|
||
|
|
func (t *CostAwareTemplate) Name() string {
|
||
|
|
return t.name
|
||
|
|
}
|
||
|
|
|
||
|
|
// Type 获取策略类型
|
||
|
|
func (t *CostAwareTemplate) Type() string {
|
||
|
|
return "cost_aware"
|
||
|
|
}
|
||
|
|
|
||
|
|
// SelectProvider 选择最佳平衡的Provider
|
||
|
|
func (t *CostAwareTemplate) SelectProvider(ctx context.Context, req *RoutingRequest) (*RoutingDecision, error) {
|
||
|
|
if len(t.providers) == 0 {
|
||
|
|
return nil, gwerror.NewGatewayError(gwerror.ROUTER_NO_PROVIDER_AVAILABLE, "no provider registered")
|
||
|
|
}
|
||
|
|
|
||
|
|
type candidate struct {
|
||
|
|
name string
|
||
|
|
cost float64
|
||
|
|
quality float64
|
||
|
|
latency int64
|
||
|
|
score float64
|
||
|
|
}
|
||
|
|
|
||
|
|
var candidates []candidate
|
||
|
|
maxCost := t.maxCostPer1KTokens
|
||
|
|
if req.MaxCost > 0 && req.MaxCost < maxCost {
|
||
|
|
maxCost = req.MaxCost
|
||
|
|
}
|
||
|
|
maxLatency := t.maxLatencyMs
|
||
|
|
if req.MaxLatency > 0 && req.MaxLatency < maxLatency {
|
||
|
|
maxLatency = req.MaxLatency
|
||
|
|
}
|
||
|
|
minQuality := t.minQualityScore
|
||
|
|
if req.MinQuality > 0 && req.MinQuality > minQuality {
|
||
|
|
minQuality = req.MinQuality
|
||
|
|
}
|
||
|
|
|
||
|
|
for name, provider := range t.providers {
|
||
|
|
// 检查provider是否支持该模型
|
||
|
|
supported := false
|
||
|
|
for _, m := range provider.SupportedModels() {
|
||
|
|
if m == req.Model || m == "*" {
|
||
|
|
supported = true
|
||
|
|
break
|
||
|
|
}
|
||
|
|
}
|
||
|
|
if !supported {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
|
||
|
|
// 检查健康状态
|
||
|
|
if !provider.HealthCheck(ctx) {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
|
||
|
|
// 获取provider指标
|
||
|
|
cost := t.getProviderCost(provider)
|
||
|
|
quality := t.getProviderQuality(provider)
|
||
|
|
latency := t.getProviderLatency(provider)
|
||
|
|
|
||
|
|
// 过滤不满足基本条件的provider
|
||
|
|
if cost > maxCost || latency > maxLatency || quality < minQuality {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
|
||
|
|
// 计算综合评分
|
||
|
|
metrics := scoring.ProviderMetrics{
|
||
|
|
Name: name,
|
||
|
|
LatencyMs: latency,
|
||
|
|
Availability: 1.0, // 假设可用
|
||
|
|
CostPer1KTokens: cost,
|
||
|
|
QualityScore: quality,
|
||
|
|
}
|
||
|
|
score := t.scoringModel.CalculateScore(metrics)
|
||
|
|
|
||
|
|
candidates = append(candidates, candidate{
|
||
|
|
name: name,
|
||
|
|
cost: cost,
|
||
|
|
quality: quality,
|
||
|
|
latency: latency,
|
||
|
|
score: score,
|
||
|
|
})
|
||
|
|
}
|
||
|
|
|
||
|
|
if len(candidates) == 0 {
|
||
|
|
return nil, ErrNoQualifiedProvider
|
||
|
|
}
|
||
|
|
|
||
|
|
// 选择评分最高的provider
|
||
|
|
best := &candidates[0]
|
||
|
|
for i := 1; i < len(candidates); i++ {
|
||
|
|
if candidates[i].score > best.score {
|
||
|
|
best = &candidates[i]
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
return &RoutingDecision{
|
||
|
|
Provider: best.name,
|
||
|
|
Strategy: t.Type(),
|
||
|
|
CostPer1KTokens: best.cost,
|
||
|
|
EstimatedLatency: best.latency,
|
||
|
|
QualityScore: best.quality,
|
||
|
|
TakeoverMark: true, // M-008: 标记为接管
|
||
|
|
}, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// getProviderCost 获取Provider的成本
|
||
|
|
func (t *CostAwareTemplate) getProviderCost(provider adapter.ProviderAdapter) float64 {
|
||
|
|
if cp, ok := provider.(CostAwareProvider); ok {
|
||
|
|
return cp.GetCostPer1KTokens()
|
||
|
|
}
|
||
|
|
return 0.5
|
||
|
|
}
|
||
|
|
|
||
|
|
// getProviderQuality 获取Provider的质量分数
|
||
|
|
func (t *CostAwareTemplate) getProviderQuality(provider adapter.ProviderAdapter) float64 {
|
||
|
|
if qp, ok := provider.(QualityProvider); ok {
|
||
|
|
return qp.GetQualityScore()
|
||
|
|
}
|
||
|
|
return 0.8 // 默认质量分数
|
||
|
|
}
|
||
|
|
|
||
|
|
// getProviderLatency 获取Provider的延迟
|
||
|
|
func (t *CostAwareTemplate) getProviderLatency(provider adapter.ProviderAdapter) int64 {
|
||
|
|
if lp, ok := provider.(LatencyProvider); ok {
|
||
|
|
return lp.GetLatencyMs()
|
||
|
|
}
|
||
|
|
return 100 // 默认延迟100ms
|
||
|
|
}
|
||
|
|
|
||
|
|
type (
	// QualityProvider is an optional interface for provider adapters that can
	// report their own quality score. Adapters that do not implement it are
	// given a default score by CostAwareTemplate.
	QualityProvider interface {
		GetQualityScore() float64
	}

	// LatencyProvider is an optional interface for provider adapters that can
	// report their own latency in milliseconds. Adapters that do not
	// implement it are given a default latency by CostAwareTemplate.
	LatencyProvider interface {
		GetLatencyMs() int64
	}
)
|