Files
lijiaoqiao/gateway/internal/handler/handler.go

350 lines
9.6 KiB
Go
Raw Permalink Normal View History

package handler
import (
	"context"
	"crypto/rand"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"time"

	"lijiaoqiao/gateway/internal/adapter"
	"lijiaoqiao/gateway/internal/router"
	gwerror "lijiaoqiao/gateway/pkg/error"
	"lijiaoqiao/gateway/pkg/model"
)
// MaxRequestBytes is the maximum number of request-body bytes the handlers
// in this file will read: 1 MiB.
const MaxRequestBytes = 1 << 20
// maxBytesReader wraps a ReadCloser and caps the total number of bytes that
// can be read through it. remaining is the unread budget; the handlers in
// this file inspect it after a failed decode to tell whether the budget was
// used up.
type maxBytesReader struct {
	reader    io.ReadCloser
	remaining int64
}

// Read implements io.Reader. It never hands the wrapped reader a buffer
// larger than the remaining budget, and reports io.EOF once the budget is
// spent, even if the wrapped reader still has data.
func (m *maxBytesReader) Read(p []byte) (n int, err error) {
	budget := m.remaining
	if budget <= 0 {
		return 0, io.EOF
	}

	// Shrink the destination so a single read cannot overshoot the budget.
	window := p
	if int64(len(window)) > budget {
		window = window[:budget]
	}

	n, err = m.reader.Read(window)
	m.remaining = budget - int64(n)
	return n, err
}

// Close implements io.Closer by closing the wrapped reader.
func (m *maxBytesReader) Close() error {
	return m.reader.Close()
}
// Handler is the HTTP API handler; its methods serve the gateway endpoints
// defined in this file.
type Handler struct {
// router is used to select a provider, record call results and report
// provider health.
router *router.Router
// version is set to "v1" by NewHandler; no code visible in this file reads it.
version string
}
// NewHandler returns a Handler that uses r for provider routing and has its
// version set to "v1".
func NewHandler(r *router.Router) *Handler {
	h := new(Handler)
	h.router = r
	h.version = "v1"
	return h
}
// ChatCompletionsHandle serves the /v1/chat/completions endpoint.
//
// It takes the request ID from the X-Request-ID header (generating one when
// absent), decodes at most MaxRequestBytes of JSON body, requires at least
// one message, asks the router for a provider for the requested model and
// then either streams the completion (req.Stream) or returns it as a single
// JSON document. For non-streaming calls the outcome and elapsed time are
// reported to the router via RecordResult.
//
// Errors from the router or provider that are not *gwerror.GatewayError are
// reported as COMMON_INTERNAL_ERROR instead of panicking on a failed type
// assertion.
func (h *Handler) ChatCompletionsHandle(w http.ResponseWriter, r *http.Request) {
	startTime := time.Now()
	requestID := r.Header.Get("X-Request-ID")
	if requestID == "" {
		requestID = generateRequestID()
	}

	// NOTE(review): plain string context keys are flagged by staticcheck
	// (SA1029). They are kept because whatever reads these values is not
	// visible here and presumably looks them up under these exact keys.
	ctx := context.WithValue(r.Context(), "request_id", requestID)
	ctx = context.WithValue(ctx, "start_time", startTime)

	// writeUpstreamError reports an error returned by the router or a
	// provider. A checked assertion is used so that a non-gateway error
	// (for example a bare context or network error) yields an error
	// response rather than a panic.
	writeUpstreamError := func(err error) {
		gwErr, ok := err.(*gwerror.GatewayError)
		if !ok {
			gwErr = gwerror.NewGatewayError(gwerror.COMMON_INTERNAL_ERROR, err.Error()).WithRequestID(requestID)
		}
		h.writeError(w, r, gwErr.WithRequestID(requestID))
	}

	// Decode through a size-limited reader so an oversized body cannot be
	// read in full.
	var req model.ChatCompletionRequest
	limitedBody := &maxBytesReader{reader: r.Body, remaining: MaxRequestBytes}
	if err := json.NewDecoder(limitedBody).Decode(&req); err != nil {
		// An exhausted budget means the body hit the size limit; the string
		// comparison additionally covers a body already wrapped by
		// http.MaxBytesReader upstream of this handler.
		if err.Error() == "http: request body too large" || limitedBody.remaining <= 0 {
			h.writeError(w, r, gwerror.NewGatewayError(gwerror.COMMON_REQUEST_TOO_LARGE, "request body exceeds maximum size limit").WithRequestID(requestID))
			return
		}
		h.writeError(w, r, gwerror.NewGatewayError(gwerror.COMMON_INVALID_REQUEST, "invalid request body: "+err.Error()).WithRequestID(requestID))
		return
	}

	// Validate the request.
	if len(req.Messages) == 0 {
		h.writeError(w, r, gwerror.NewGatewayError(gwerror.COMMON_INVALID_REQUEST, "messages is required").WithRequestID(requestID))
		return
	}

	// Select a provider for the requested model.
	provider, err := h.router.SelectProvider(ctx, req.Model)
	if err != nil {
		writeUpstreamError(err)
		return
	}

	// Convert the API messages to the adapter representation.
	messages := make([]adapter.Message, len(req.Messages))
	for i, m := range req.Messages {
		messages[i] = adapter.Message{
			Role:    m.Role,
			Content: m.Content,
			Name:    m.Name,
		}
	}

	options := adapter.CompletionOptions{
		Temperature: req.Temperature,
		MaxTokens:   req.MaxTokens,
		TopP:        req.TopP,
		Stream:      req.Stream,
		Stop:        req.Stop,
	}

	// Streaming requests are handed off entirely; no result is recorded
	// with the router on this path.
	if req.Stream {
		h.handleStream(ctx, w, r, provider, req.Model, messages, options, requestID)
		return
	}

	// Non-streaming request.
	response, err := provider.ChatCompletion(ctx, req.Model, messages, options)
	if err != nil {
		// Record the failure before responding.
		h.router.RecordResult(ctx, provider.ProviderName(), false, time.Since(startTime).Milliseconds())
		writeUpstreamError(err)
		return
	}
	h.router.RecordResult(ctx, provider.ProviderName(), true, time.Since(startTime).Milliseconds())

	// Convert the adapter response to the API response.
	chatResp := model.ChatCompletionResponse{
		ID:      response.ID,
		Object:  "chat.completion",
		Created: response.Created,
		Model:   response.Model,
		Choices: make([]model.Choice, len(response.Choices)),
	}
	for i, c := range response.Choices {
		chatResp.Choices[i] = model.Choice{
			Index: c.Index,
			Message: model.ChatMessage{
				Role:    c.Message.Role,
				Content: c.Message.Content,
			},
			FinishReason: c.FinishReason,
		}
	}
	chatResp.Usage = model.Usage{
		PromptTokens:     response.Usage.PromptTokens,
		CompletionTokens: response.Usage.CompletionTokens,
		TotalTokens:      response.Usage.TotalTokens,
	}
	h.writeJSON(w, http.StatusOK, chatResp, requestID)
}
// handleStream serves a streaming completion as server-sent events.
//
// It first verifies that w supports flushing, then opens the provider stream,
// sets the SSE response headers and writes each chunk received from the
// stream as a "data: <json>" event, flushing after every event. When the
// stream channel is closed it sends the terminating "data: [DONE]" event.
//
// Provider errors that are not *gwerror.GatewayError are reported as
// COMMON_INTERNAL_ERROR instead of panicking on a failed type assertion.
func (h *Handler) handleStream(ctx context.Context, w http.ResponseWriter, r *http.Request, provider adapter.ProviderAdapter, model string, messages []adapter.Message, options adapter.CompletionOptions, requestID string) {
	// Check for flushing support before opening the upstream stream, so a
	// stream is never started that nobody will read from.
	flusher, ok := w.(http.Flusher)
	if !ok {
		h.writeError(w, r, gwerror.NewGatewayError(gwerror.COMMON_INTERNAL_ERROR, "streaming not supported").WithRequestID(requestID))
		return
	}

	ch, err := provider.ChatCompletionStream(ctx, model, messages, options)
	if err != nil {
		gwErr, isGatewayErr := err.(*gwerror.GatewayError)
		if !isGatewayErr {
			gwErr = gwerror.NewGatewayError(gwerror.COMMON_INTERNAL_ERROR, err.Error()).WithRequestID(requestID)
		}
		h.writeError(w, r, gwErr.WithRequestID(requestID))
		return
	}

	// SSE response headers.
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")
	w.Header().Set("X-Request-ID", requestID)

	// Forward chunks as they arrive. After the first failed write the
	// client is treated as gone: no further chunks are marshalled or
	// written, but the channel is still drained so that a producer which
	// does not watch ctx is not left blocked on a send.
	clientGone := false
	for chunk := range ch {
		if clientGone {
			continue
		}
		data := fmt.Sprintf("data: %s\n\n", marshalJSON(chunk))
		if _, err := w.Write([]byte(data)); err != nil {
			clientGone = true
			continue
		}
		flusher.Flush()
	}
	if clientGone {
		return
	}

	if _, err := w.Write([]byte("data: [DONE]\n\n")); err != nil {
		return
	}
	flusher.Flush()
}
// CompletionsHandle serves the /v1/completions endpoint.
//
// It takes the request ID from the X-Request-ID header (generating one when
// absent), decodes at most MaxRequestBytes of JSON body, wraps the prompt in
// a single "user" message and runs it through the provider chosen by the
// router. The result is either streamed (req.Stream) or returned as a
// "text_completion" JSON document.
//
// Errors from the router or provider that are not *gwerror.GatewayError are
// reported as COMMON_INTERNAL_ERROR instead of panicking on a failed type
// assertion.
//
// NOTE(review): unlike ChatCompletionsHandle, this handler does not call
// router.RecordResult and does not attach the request ID to the context.
// Confirm whether that difference is intended.
func (h *Handler) CompletionsHandle(w http.ResponseWriter, r *http.Request) {
	requestID := r.Header.Get("X-Request-ID")
	if requestID == "" {
		requestID = generateRequestID()
	}

	// writeUpstreamError reports an error returned by the router or a
	// provider. A checked assertion is used so that a non-gateway error
	// yields an error response rather than a panic.
	writeUpstreamError := func(err error) {
		gwErr, ok := err.(*gwerror.GatewayError)
		if !ok {
			gwErr = gwerror.NewGatewayError(gwerror.COMMON_INTERNAL_ERROR, err.Error()).WithRequestID(requestID)
		}
		h.writeError(w, r, gwErr.WithRequestID(requestID))
	}

	// Decode through a size-limited reader so an oversized body cannot be
	// read in full.
	var req model.CompletionRequest
	limitedBody := &maxBytesReader{reader: r.Body, remaining: MaxRequestBytes}
	if err := json.NewDecoder(limitedBody).Decode(&req); err != nil {
		// An exhausted budget means the body hit the size limit; the string
		// comparison additionally covers a body already wrapped by
		// http.MaxBytesReader upstream of this handler.
		if err.Error() == "http: request body too large" || limitedBody.remaining <= 0 {
			h.writeError(w, r, gwerror.NewGatewayError(gwerror.COMMON_REQUEST_TOO_LARGE, "request body exceeds maximum size limit").WithRequestID(requestID))
			return
		}
		h.writeError(w, r, gwerror.NewGatewayError(gwerror.COMMON_INVALID_REQUEST, "invalid request body").WithRequestID(requestID))
		return
	}

	// Wrap the prompt as a single user message.
	ctx := r.Context()
	messages := []adapter.Message{{Role: "user", Content: req.Prompt}}

	provider, err := h.router.SelectProvider(ctx, req.Model)
	if err != nil {
		writeUpstreamError(err)
		return
	}

	options := adapter.CompletionOptions{
		Temperature: req.Temperature,
		MaxTokens:   req.MaxTokens,
		TopP:        req.TopP,
		Stream:      req.Stream,
		Stop:        req.Stop,
	}

	if req.Stream {
		h.handleStream(ctx, w, r, provider, req.Model, messages, options, requestID)
		return
	}

	response, err := provider.ChatCompletion(ctx, req.Model, messages, options)
	if err != nil {
		writeUpstreamError(err)
		return
	}

	// Convert the chat-style provider response to the completion format.
	compResp := model.CompletionResponse{
		ID:      response.ID,
		Object:  "text_completion",
		Created: response.Created,
		Model:   response.Model,
		Choices: make([]model.Choice1, len(response.Choices)),
	}
	for i, c := range response.Choices {
		compResp.Choices[i] = model.Choice1{
			Text: c.Message.Content,
			// The position in the slice is used as the index here, not
			// the provider's own c.Index.
			Index:        i,
			FinishReason: c.FinishReason,
		}
	}
	compResp.Usage = model.Usage{
		PromptTokens:     response.Usage.PromptTokens,
		CompletionTokens: response.Usage.CompletionTokens,
		TotalTokens:      response.Usage.TotalTokens,
	}
	h.writeJSON(w, http.StatusOK, compResp, requestID)
}
// ModelsHandle serves the /v1/models endpoint. It responds with a fixed,
// hard-coded list of model descriptors wrapped in a {"object": "list"}
// envelope, and echoes (or generates) the request ID.
func (h *Handler) ModelsHandle(w http.ResponseWriter, r *http.Request) {
	reqID := r.Header.Get("X-Request-ID")
	if reqID == "" {
		reqID = generateRequestID()
	}

	// Static catalogue of advertised models.
	catalogue := []struct {
		id      string
		created int
		owner   string
	}{
		{"gpt-4", 1687882411, "openai"},
		{"gpt-3.5-turbo", 1677610602, "openai"},
		{"claude-3-opus", 1709598254, "anthropic"},
		{"claude-3-sonnet", 1709598255, "anthropic"},
	}

	entries := make([]map[string]interface{}, 0, len(catalogue))
	for _, item := range catalogue {
		entries = append(entries, map[string]interface{}{
			"id":       item.id,
			"object":   "model",
			"created":  item.created,
			"owned_by": item.owner,
		})
	}

	envelope := map[string]interface{}{
		"object": "list",
		"data":   entries,
	}
	h.writeJSON(w, http.StatusOK, envelope, reqID)
}
// HealthHandle serves the /health endpoint. It reports each service's
// availability as returned by the router. The overall status is "healthy"
// with HTTP 200 when every service is available, otherwise "degraded" with
// HTTP 503. No request ID header is set on this response.
func (h *Handler) HealthHandle(w http.ResponseWriter, r *http.Request) {
	services := make(map[string]bool)
	degraded := false
	for name, health := range h.router.GetHealthStatus() {
		services[name] = health.Available
		if !health.Available {
			degraded = true
		}
	}

	report := model.HealthStatus{
		Status:    "healthy",
		Timestamp: time.Now(),
		Services:  services,
	}
	code := http.StatusOK
	if degraded {
		report.Status = "degraded"
		code = http.StatusServiceUnavailable
	}
	h.writeJSON(w, code, report, "")
}
// writeJSON writes data as a JSON response with the given status code. The
// X-Request-ID header is set only when requestID is non-empty.
//
// The payload is marshalled before the status line is committed. If data
// cannot be encoded, the client receives a 500 with a small error body
// rather than the success status followed by an empty body.
func (h *Handler) writeJSON(w http.ResponseWriter, status int, data interface{}, requestID string) {
	w.Header().Set("Content-Type", "application/json")
	if requestID != "" {
		w.Header().Set("X-Request-ID", requestID)
	}

	payload, err := json.Marshal(data)
	if err != nil {
		status = http.StatusInternalServerError
		// NOTE(review): this shape is meant to mirror what writeError emits;
		// it assumes model.ErrorResponse serialises with these lowercase
		// keys. Confirm against the model package.
		payload = []byte(`{"error":{"message":"failed to encode response","type":"gateway_error"}}`)
	}

	w.WriteHeader(status)
	// The trailing newline matches json.Encoder output. A write failure means
	// the client has gone away and there is nothing left to report it to, so
	// the error is deliberately ignored.
	_, _ = w.Write(append(payload, '\n'))
}
// writeError writes err to the client as a JSON error response. The HTTP
// status comes from err.GetErrorInfo(); the body carries the error message,
// the fixed type "gateway_error" and the error code as a string. The
// X-Request-ID header is set when err carries a request ID. The r parameter
// is not used.
func (h *Handler) writeError(w http.ResponseWriter, r *http.Request, err *gwerror.GatewayError) {
	info := err.GetErrorInfo()

	header := w.Header()
	header.Set("Content-Type", "application/json")
	if id := err.RequestID; id != "" {
		header.Set("X-Request-ID", id)
	}
	w.WriteHeader(info.HTTPStatus)

	detail := model.ErrorDetail{
		Message: err.Message,
		Type:    "gateway_error",
		Code:    string(err.Code),
	}
	json.NewEncoder(w).Encode(model.ErrorResponse{Error: detail})
}
// generateRequestID returns a new request identifier of the form
// "chatcmpl-<unix-nanoseconds>-<16 hex digits>".
//
// The random suffix keeps IDs distinct when two requests observe the same
// clock reading (concurrent requests, a coarse clock, or several gateway
// instances). If the system random source fails, the timestamp-only form
// "chatcmpl-<unix-nanoseconds>" is returned instead.
func generateRequestID() string {
	now := time.Now().UnixNano()

	var suffix [8]byte
	if _, err := rand.Read(suffix[:]); err != nil {
		// Best effort: an ID without the random part is still usable.
		return fmt.Sprintf("chatcmpl-%d", now)
	}
	return fmt.Sprintf("chatcmpl-%d-%x", now, suffix[:])
}
// marshalJSON returns the JSON encoding of v as a string. If v cannot be
// encoded, the error is dropped and the empty string is returned.
func marshalJSON(v interface{}) string {
	encoded, err := json.Marshal(v)
	if err != nil {
		return ""
	}
	return string(encoded)
}