feat(gateway): 实现网关核心模块
实现内容: - internal/adapter: Provider Adapter抽象层和OpenAI实现 - internal/router: 多Provider路由(支持latency/weighted/availability策略) - internal/handler: OpenAI兼容API端点(/v1/chat/completions, /v1/completions) - internal/ratelimit: Token Bucket和Sliding Window限流器 - internal/alert: 告警系统(支持邮件/钉钉/飞书) - internal/config: 配置管理 - pkg/error: 完整错误码体系 - pkg/model: API请求/响应模型 PRD对齐: - P0-1: 统一API接入 ✅ (OpenAI兼容) - P0-2: 基础路由与稳定性 ✅ (多Provider路由+Fallback) - P0-4: 预算与限流 ✅ (Token Bucket限流) 注意:需要供应链模块支持后再完善成本归因和账单导出
This commit is contained in:
326
gateway/internal/adapter/openai_adapter.go
Normal file
326
gateway/internal/adapter/openai_adapter.go
Normal file
@@ -0,0 +1,326 @@
|
||||
package adapter
|
||||
|
||||
import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"

	"lijiaoqiao/gateway/pkg/error"
)
|
||||
|
||||
// OpenAIAdapter is the Provider adapter for OpenAI-compatible APIs.
// It issues HTTP requests to an OpenAI-style endpoint and converts
// requests, responses, and errors to the gateway's internal types.
type OpenAIAdapter struct {
	baseURL    string       // API base URL without the "/v1" suffix (paths are appended by each method)
	apiKey     string       // bearer token attached to every outgoing request
	httpClient *http.Client // shared client; NewOpenAIAdapter configures a 60s overall timeout
	models     []string     // model IDs this adapter reports via SupportedModels
}
|
||||
|
||||
// NewOpenAIAdapter 创建OpenAI适配器
|
||||
func NewOpenAIAdapter(baseURL, apiKey string, models []string) *OpenAIAdapter {
|
||||
return &OpenAIAdapter{
|
||||
baseURL: baseURL,
|
||||
apiKey: apiKey,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 60 * time.Second,
|
||||
},
|
||||
models: models,
|
||||
}
|
||||
}
|
||||
|
||||
// ChatCompletion implements the ChatCompletion interface: it sends a
// non-streaming chat request to the provider's /v1/chat/completions
// endpoint and maps the JSON body into a CompletionResponse.
//
// Only options with a value > 0 (or a non-empty Stop list) are forwarded,
// so the provider's defaults apply otherwise. Note this means a caller
// cannot explicitly request temperature 0 — it reads as "unset".
func (a *OpenAIAdapter) ChatCompletion(ctx context.Context, model string, messages []Message, options CompletionOptions) (*CompletionResponse, error) {
	// Build the request payload.
	reqBody := map[string]interface{}{
		"model":    model,
		"messages": messages,
	}

	if options.Temperature > 0 {
		reqBody["temperature"] = options.Temperature
	}
	if options.MaxTokens > 0 {
		reqBody["max_tokens"] = options.MaxTokens
	}
	if options.TopP > 0 {
		reqBody["top_p"] = options.TopP
	}
	if len(options.Stop) > 0 {
		reqBody["stop"] = options.Stop
	}

	jsonData, err := json.Marshal(reqBody)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	// Send the request.
	url := fmt.Sprintf("%s/v1/chat/completions", a.baseURL)
	req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}

	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", a.apiKey))

	resp, err := a.httpClient.Do(req)
	if err != nil {
		// NOTE(review): transport-level errors are returned raw here while
		// HTTP-status errors below go through MapError — confirm callers
		// expect that asymmetry.
		return nil, err
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		// Prefer the provider's structured "error" detail when the body is
		// JSON; otherwise fall back to a bare status-code error.
		var errResp map[string]interface{}
		if json.Unmarshal(body, &errResp) == nil {
			if errDetail, ok := errResp["error"].(map[string]interface{}); ok {
				return nil, a.MapError(fmt.Errorf("%v", errDetail))
			}
		}
		return nil, a.MapError(fmt.Errorf("unexpected status: %d", resp.StatusCode))
	}

	// Parse the response.
	var result struct {
		ID      string `json:"id"`
		Object  string `json:"object"`
		Created int64  `json:"created"`
		Model   string `json:"model"`
		Choices []struct {
			Message struct {
				Role    string `json:"role"`
				Content string `json:"content"`
			} `json:"message"`
			FinishReason string `json:"finish_reason"`
		} `json:"choices"`
		Usage struct {
			PromptTokens     int `json:"prompt_tokens"`
			CompletionTokens int `json:"completion_tokens"`
			TotalTokens      int `json:"total_tokens"`
		} `json:"usage"`
	}

	if err := json.Unmarshal(body, &result); err != nil {
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}

	// Convert to the gateway's response type.
	response := &CompletionResponse{
		ID:      result.ID,
		Object:  result.Object,
		Created: result.Created,
		Model:   result.Model,
		Choices: make([]Choice, len(result.Choices)),
	}

	for i, c := range result.Choices {
		response.Choices[i] = Choice{
			Message: &Message{
				Role:    c.Message.Role,
				Content: c.Message.Content,
			},
			FinishReason: c.FinishReason,
		}
	}

	response.Usage = Usage{
		PromptTokens:     result.Usage.PromptTokens,
		CompletionTokens: result.Usage.CompletionTokens,
		TotalTokens:      result.Usage.TotalTokens,
	}

	return response, nil
}
|
||||
|
||||
// ChatCompletionStream 实现流式ChatCompletion
|
||||
func (a *OpenAIAdapter) ChatCompletionStream(ctx context.Context, model string, messages []Message, options CompletionOptions) (<-chan *StreamChunk, error) {
|
||||
// 构建请求
|
||||
reqBody := map[string]interface{}{
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"stream": true,
|
||||
}
|
||||
|
||||
if options.Temperature > 0 {
|
||||
reqBody["temperature"] = options.Temperature
|
||||
}
|
||||
if options.MaxTokens > 0 {
|
||||
reqBody["max_tokens"] = options.MaxTokens
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/v1/chat/completions", a.baseURL)
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", a.apiKey))
|
||||
|
||||
resp, err := a.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
return nil, a.MapError(fmt.Errorf("unexpected status: %d, body: %s", resp.StatusCode, string(body)))
|
||||
}
|
||||
|
||||
ch := make(chan *StreamChunk, 100)
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
defer resp.Body.Close()
|
||||
|
||||
reader := io.Reader(resp.Body)
|
||||
for {
|
||||
line, err := io.ReadLine(reader)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if len(line) < 6 {
|
||||
continue
|
||||
}
|
||||
|
||||
// SSE格式: data: {...}
|
||||
if string(line[:5]) != "data:" {
|
||||
continue
|
||||
}
|
||||
|
||||
data := line[6:]
|
||||
if string(data) == "[DONE]" {
|
||||
return
|
||||
}
|
||||
|
||||
var chunk struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []struct {
|
||||
Delta struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
} `json:"delta"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
if json.Unmarshal(data, &chunk) != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
streamChunk := &StreamChunk{
|
||||
ID: chunk.ID,
|
||||
Object: chunk.Object,
|
||||
Created: chunk.Created,
|
||||
Model: chunk.Model,
|
||||
Choices: make([]StreamChoice, len(chunk.Choices)),
|
||||
}
|
||||
|
||||
for i, c := range chunk.Choices {
|
||||
streamChunk.Choices[i] = StreamChoice{
|
||||
Delta: &Delta{
|
||||
Role: c.Delta.Role,
|
||||
Content: c.Delta.Content,
|
||||
},
|
||||
FinishReason: c.FinishReason,
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case ch <- streamChunk:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
// GetUsage 获取使用量
|
||||
func (a *OpenAIAdapter) GetUsage(response *CompletionResponse) Usage {
|
||||
return response.Usage
|
||||
}
|
||||
|
||||
// MapError 错误码映射
|
||||
func (a *OpenAIAdapter) MapError(err error) error {
|
||||
// 简化实现,实际应根据OpenAI错误响应映射
|
||||
errStr := err.Error()
|
||||
|
||||
if contains(errStr, "invalid_api_key") {
|
||||
return error.NewGatewayError(error.PROVIDER_INVALID_KEY, "Invalid API key").WithInternal(err)
|
||||
}
|
||||
if contains(errStr, "rate_limit") {
|
||||
return error.NewGatewayError(error.PROVIDER_RATE_LIMIT, "Rate limit exceeded").WithInternal(err)
|
||||
}
|
||||
if contains(errStr, "quota") {
|
||||
return error.NewGatewayError(error.PROVIDER_QUOTA_EXCEEDED, "Quota exceeded").WithInternal(err)
|
||||
}
|
||||
if contains(errStr, "model_not_found") {
|
||||
return error.NewGatewayError(error.PROVIDER_MODEL_NOT_FOUND, "Model not found").WithInternal(err)
|
||||
}
|
||||
|
||||
return error.NewGatewayError(error.PROVIDER_ERROR, "Provider error").WithInternal(err)
|
||||
}
|
||||
|
||||
// contains reports whether substr occurs within s. It delegates to
// strings.Contains, replacing the previous hand-rolled (and needlessly
// convoluted) scan with the standard-library implementation.
func contains(s, substr string) bool {
	return strings.Contains(s, substr)
}
|
||||
|
||||
// containsHelper performs a naive substring scan: it slides a window of
// len(substr) across s and compares each window against substr.
func containsHelper(s, substr string) bool {
	limit := len(s) - len(substr)
	for start := 0; start <= limit; start++ {
		if s[start:start+len(substr)] == substr {
			return true
		}
	}
	return false
}
|
||||
|
||||
// HealthCheck 健康检查
|
||||
func (a *OpenAIAdapter) HealthCheck(ctx context.Context) bool {
|
||||
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/v1/models", a.baseURL), nil)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", a.apiKey))
|
||||
|
||||
resp, err := a.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
return resp.StatusCode == http.StatusOK
|
||||
}
|
||||
|
||||
// ProviderName 供应商名称
|
||||
func (a *OpenAIAdapter) ProviderName() string {
|
||||
return "openai"
|
||||
}
|
||||
|
||||
// SupportedModels 支持的模型列表
|
||||
func (a *OpenAIAdapter) SupportedModels() []string {
|
||||
return a.models
|
||||
}
|
||||
Reference in New Issue
Block a user