Files
lijiaoqiao/gateway/internal/adapter/openai_adapter.go
Your Name 0484c7be74 feat(gateway): 实现网关核心模块
实现内容:
- internal/adapter: Provider Adapter抽象层和OpenAI实现
- internal/router: 多Provider路由(支持latency/weighted/availability策略)
- internal/handler: OpenAI兼容API端点(/v1/chat/completions, /v1/completions)
- internal/ratelimit: Token Bucket和Sliding Window限流器
- internal/alert: 告警系统(支持邮件/钉钉/飞书)
- internal/config: 配置管理
- pkg/error: 完整错误码体系
- pkg/model: API请求/响应模型

PRD对齐:
- P0-1: 统一API接入  (OpenAI兼容)
- P0-2: 基础路由与稳定性  (多Provider路由+Fallback)
- P0-4: 预算与限流  (Token Bucket限流)

注意:需要供应链模块支持后再完善成本归因和账单导出
2026-04-01 10:04:52 +08:00

327 lines
7.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package adapter
import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"

	// NOTE(review): this package is named "error", which shadows Go's builtin
	// error identifier for the whole file; consider renaming it (e.g. gwerror).
	"lijiaoqiao/gateway/pkg/error"
)
// OpenAIAdapter adapts the OpenAI-compatible HTTP API to the gateway's
// provider adapter interface (chat completion, streaming, health check).
type OpenAIAdapter struct {
	baseURL    string       // provider base URL, e.g. "https://api.openai.com" (no trailing slash)
	apiKey     string       // bearer token sent in the Authorization header
	httpClient *http.Client // shared client; configured with a 60s overall timeout in NewOpenAIAdapter
	models     []string     // model IDs this adapter advertises via SupportedModels
}
// NewOpenAIAdapter constructs an adapter for the given endpoint, API key and
// supported model list. The embedded HTTP client applies a 60-second overall
// timeout to every request.
func NewOpenAIAdapter(baseURL, apiKey string, models []string) *OpenAIAdapter {
	client := &http.Client{Timeout: 60 * time.Second}
	adapter := &OpenAIAdapter{
		baseURL:    baseURL,
		apiKey:     apiKey,
		httpClient: client,
		models:     models,
	}
	return adapter
}
// ChatCompletion performs a non-streaming chat completion against the
// OpenAI-compatible /v1/chat/completions endpoint and converts the provider
// payload into the gateway's neutral CompletionResponse model.
//
// Zero-valued options (Temperature/MaxTokens/TopP == 0, empty Stop) are
// omitted from the request so the provider applies its own defaults.
func (a *OpenAIAdapter) ChatCompletion(ctx context.Context, model string, messages []Message, options CompletionOptions) (*CompletionResponse, error) {
	// Assemble the JSON payload, including only the options the caller set.
	payload := map[string]interface{}{
		"model":    model,
		"messages": messages,
	}
	if options.Temperature > 0 {
		payload["temperature"] = options.Temperature
	}
	if options.MaxTokens > 0 {
		payload["max_tokens"] = options.MaxTokens
	}
	if options.TopP > 0 {
		payload["top_p"] = options.TopP
	}
	if len(options.Stop) > 0 {
		payload["stop"] = options.Stop
	}
	encoded, err := json.Marshal(payload)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	// Issue the upstream request with the adapter's credentials.
	endpoint := fmt.Sprintf("%s/v1/chat/completions", a.baseURL)
	httpReq, err := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewBuffer(encoded))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", a.apiKey))

	httpResp, err := a.httpClient.Do(httpReq)
	if err != nil {
		return nil, err
	}
	defer httpResp.Body.Close()

	raw, err := io.ReadAll(httpResp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response: %w", err)
	}

	// Non-200: surface the provider's structured error (when parseable)
	// through MapError so it lands in the gateway's error-code space.
	if httpResp.StatusCode != http.StatusOK {
		var upstream map[string]interface{}
		if json.Unmarshal(raw, &upstream) == nil {
			if detail, ok := upstream["error"].(map[string]interface{}); ok {
				return nil, a.MapError(fmt.Errorf("%v", detail))
			}
		}
		return nil, a.MapError(fmt.Errorf("unexpected status: %d", httpResp.StatusCode))
	}

	// Decode only the fields the gateway cares about.
	var parsed struct {
		ID      string `json:"id"`
		Object  string `json:"object"`
		Created int64  `json:"created"`
		Model   string `json:"model"`
		Choices []struct {
			Message struct {
				Role    string `json:"role"`
				Content string `json:"content"`
			} `json:"message"`
			FinishReason string `json:"finish_reason"`
		} `json:"choices"`
		Usage struct {
			PromptTokens     int `json:"prompt_tokens"`
			CompletionTokens int `json:"completion_tokens"`
			TotalTokens      int `json:"total_tokens"`
		} `json:"usage"`
	}
	if err := json.Unmarshal(raw, &parsed); err != nil {
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}

	// Map the provider payload onto the gateway's response model.
	out := &CompletionResponse{
		ID:      parsed.ID,
		Object:  parsed.Object,
		Created: parsed.Created,
		Model:   parsed.Model,
		Choices: make([]Choice, len(parsed.Choices)),
	}
	for i, choice := range parsed.Choices {
		out.Choices[i] = Choice{
			Message: &Message{
				Role:    choice.Message.Role,
				Content: choice.Message.Content,
			},
			FinishReason: choice.FinishReason,
		}
	}
	out.Usage = Usage{
		PromptTokens:     parsed.Usage.PromptTokens,
		CompletionTokens: parsed.Usage.CompletionTokens,
		TotalTokens:      parsed.Usage.TotalTokens,
	}
	return out, nil
}
// ChatCompletionStream performs a streaming chat completion.
//
// It POSTs to /v1/chat/completions with "stream": true and decodes the SSE
// response ("data: {...}" frames terminated by "data: [DONE]") into a channel
// of StreamChunk. The channel is closed when the stream finishes, a read
// error occurs, or ctx is cancelled; malformed frames are skipped.
func (a *OpenAIAdapter) ChatCompletionStream(ctx context.Context, model string, messages []Message, options CompletionOptions) (<-chan *StreamChunk, error) {
	// Build the request body; zero-valued options are omitted (same policy as
	// ChatCompletion — top_p and stop are now forwarded too, for consistency).
	reqBody := map[string]interface{}{
		"model":    model,
		"messages": messages,
		"stream":   true,
	}
	if options.Temperature > 0 {
		reqBody["temperature"] = options.Temperature
	}
	if options.MaxTokens > 0 {
		reqBody["max_tokens"] = options.MaxTokens
	}
	if options.TopP > 0 {
		reqBody["top_p"] = options.TopP
	}
	if len(options.Stop) > 0 {
		reqBody["stop"] = options.Stop
	}
	jsonData, err := json.Marshal(reqBody)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}
	url := fmt.Sprintf("%s/v1/chat/completions", a.baseURL)
	req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", a.apiKey))
	resp, err := a.httpClient.Do(req)
	if err != nil {
		return nil, err
	}
	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		resp.Body.Close()
		return nil, a.MapError(fmt.Errorf("unexpected status: %d, body: %s", resp.StatusCode, string(body)))
	}
	ch := make(chan *StreamChunk, 100)
	go func() {
		defer close(ch)
		defer resp.Body.Close()
		// BUGFIX: the previous code called io.ReadLine, which does not exist
		// in the standard library (the file could not compile). Read the SSE
		// stream line by line with a bufio.Scanner instead.
		scanner := bufio.NewScanner(resp.Body)
		for scanner.Scan() {
			line := bytes.TrimSpace(scanner.Bytes())
			// SSE frame format: "data: {...}" (blank keep-alive lines and
			// other fields are ignored).
			if !bytes.HasPrefix(line, []byte("data:")) {
				continue
			}
			data := bytes.TrimSpace(line[len("data:"):])
			if string(data) == "[DONE]" {
				return
			}
			var chunk struct {
				ID      string `json:"id"`
				Object  string `json:"object"`
				Created int64  `json:"created"`
				Model   string `json:"model"`
				Choices []struct {
					Delta struct {
						Role    string `json:"role"`
						Content string `json:"content"`
					} `json:"delta"`
					FinishReason string `json:"finish_reason"`
				} `json:"choices"`
			}
			if json.Unmarshal(data, &chunk) != nil {
				// Skip malformed frames rather than aborting the stream.
				continue
			}
			streamChunk := &StreamChunk{
				ID:      chunk.ID,
				Object:  chunk.Object,
				Created: chunk.Created,
				Model:   chunk.Model,
				Choices: make([]StreamChoice, len(chunk.Choices)),
			}
			for i, c := range chunk.Choices {
				streamChunk.Choices[i] = StreamChoice{
					Delta: &Delta{
						Role:    c.Delta.Role,
						Content: c.Delta.Content,
					},
					FinishReason: c.FinishReason,
				}
			}
			select {
			case ch <- streamChunk:
			case <-ctx.Done():
				return
			}
		}
		// A scanner error (network failure or ctx cancellation closing the
		// body) simply ends the loop; the closed channel signals completion.
	}()
	return ch, nil
}
// GetUsage returns the token usage recorded on a completed response.
// A nil response yields a zero-valued Usage instead of panicking, which keeps
// best-effort accounting paths safe when an upstream call failed.
func (a *OpenAIAdapter) GetUsage(response *CompletionResponse) Usage {
	if response == nil {
		return Usage{}
	}
	return response.Usage
}
// MapError maps a provider-side failure onto the gateway's error-code space.
//
// Simplified implementation: it matches well-known markers in the error text.
// A complete implementation should parse OpenAI's structured error response
// (error.type / error.code) instead of substring matching.
//
// NOTE(review): the imported package is named "error", which shadows the
// builtin error type in this file — the package should be renamed.
func (a *OpenAIAdapter) MapError(err error) error {
	msg := err.Error()
	switch {
	case strings.Contains(msg, "invalid_api_key"):
		return error.NewGatewayError(error.PROVIDER_INVALID_KEY, "Invalid API key").WithInternal(err)
	case strings.Contains(msg, "rate_limit"):
		return error.NewGatewayError(error.PROVIDER_RATE_LIMIT, "Rate limit exceeded").WithInternal(err)
	case strings.Contains(msg, "quota"):
		return error.NewGatewayError(error.PROVIDER_QUOTA_EXCEEDED, "Quota exceeded").WithInternal(err)
	case strings.Contains(msg, "model_not_found"):
		return error.NewGatewayError(error.PROVIDER_MODEL_NOT_FOUND, "Model not found").WithInternal(err)
	default:
		return error.NewGatewayError(error.PROVIDER_ERROR, "Provider error").WithInternal(err)
	}
}
// contains reports whether substr occurs within s.
// Delegates to strings.Contains instead of the previous hand-rolled scan;
// behavior is identical (empty substr matches any s).
func contains(s, substr string) bool {
	return strings.Contains(s, substr)
}
// containsHelper performs a brute-force substring scan, checking every
// possible start offset of substr within s.
func containsHelper(s, substr string) bool {
	limit := len(s) - len(substr)
	for start := 0; start <= limit; start++ {
		if s[start:start+len(substr)] == substr {
			return true
		}
	}
	return false
}
// HealthCheck probes the provider by listing models on /v1/models and
// reports true only when it answers HTTP 200 within a 5-second deadline.
func (a *OpenAIAdapter) HealthCheck(ctx context.Context) bool {
	probeCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()
	endpoint := fmt.Sprintf("%s/v1/models", a.baseURL)
	req, err := http.NewRequestWithContext(probeCtx, "GET", endpoint, nil)
	if err != nil {
		return false
	}
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", a.apiKey))
	resp, err := a.httpClient.Do(req)
	if err != nil {
		return false
	}
	defer resp.Body.Close()
	return resp.StatusCode == http.StatusOK
}
// ProviderName identifies this adapter's upstream provider.
func (a *OpenAIAdapter) ProviderName() string {
	const providerName = "openai"
	return providerName
}
// SupportedModels returns the model IDs this adapter can serve.
// A copy of the internal slice is returned so callers cannot mutate the
// adapter's configuration through the shared backing array.
func (a *OpenAIAdapter) SupportedModels() []string {
	return append([]string(nil), a.models...)
}