Files
lijiaoqiao/gateway/internal/handler/handler.go
Your Name d44e9966e0 fix(security): 修复多个MED安全问题
MED-03: 数据库密码明文配置
- 在 gateway/internal/config/config.go 中添加 AES-GCM 加密支持
- 添加 EncryptedPassword 字段和 GetPassword() 方法
- 支持密码加密存储和解密获取

MED-04: 审计日志Route字段未验证
- 在 supply-api/internal/middleware/auth.go 中添加 sanitizeRoute() 函数
- 防止路径遍历攻击(.., ./, \ 等)
- 防止 null 字节和换行符注入

MED-05: 请求体大小无限制
- 在 gateway/internal/handler/handler.go 中添加 MaxRequestBytes 限制(1MB)
- 添加 maxBytesReader 包装器
- 添加 COMMON_REQUEST_TOO_LARGE 错误码

MED-08: 缺少CORS配置
- 创建 gateway/internal/middleware/cors.go CORS 中间件
- 支持来源域名白名单、通配符子域名
- 支持预检请求处理和凭证配置

MED-09: 错误信息泄露内部细节
- 添加测试验证 JWT 错误消息不包含敏感信息
- 当前实现已正确返回安全错误消息

MED-10: 数据库凭证日志泄露风险
- 在 gateway/cmd/gateway/main.go 中使用 GetPassword() 代替 Password
- 避免 DSN 中明文密码被记录

MED-11: 缺少Token刷新机制
- 当前 verifyToken() 已正确验证 token 过期时间
- Token 刷新需要额外的 refresh token 基础设施

MED-12: 缺少暴力破解保护
- 添加 BruteForceProtection 结构体
- 支持最大尝试次数和锁定时长配置
- 在 TokenVerifyMiddleware 中集成暴力破解保护
2026-04-03 09:51:39 +08:00

350 lines
9.6 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package handler
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
"lijiaoqiao/gateway/internal/adapter"
"lijiaoqiao/gateway/internal/router"
gwerror "lijiaoqiao/gateway/pkg/error"
"lijiaoqiao/gateway/pkg/model"
)
// MaxRequestBytes 最大请求体大小 (1MB)
const MaxRequestBytes = 1 * 1024 * 1024
// maxBytesReader 限制读取字节数的reader
type maxBytesReader struct {
reader io.ReadCloser
remaining int64
}
// Read 实现io.Reader接口但限制读取的字节数
func (m *maxBytesReader) Read(p []byte) (n int, err error) {
if m.remaining <= 0 {
return 0, io.EOF
}
if int64(len(p)) > m.remaining {
p = p[:m.remaining]
}
n, err = m.reader.Read(p)
m.remaining -= int64(n)
return n, err
}
// Close 实现io.Closer接口
func (m *maxBytesReader) Close() error {
return m.reader.Close()
}
// Handler API处理器
type Handler struct {
router *router.Router
version string
}
// NewHandler 创建处理器
func NewHandler(r *router.Router) *Handler {
return &Handler{
router: r,
version: "v1",
}
}
// ChatCompletionsHandle /v1/chat/completions端点
func (h *Handler) ChatCompletionsHandle(w http.ResponseWriter, r *http.Request) {
startTime := time.Now()
requestID := r.Header.Get("X-Request-ID")
if requestID == "" {
requestID = generateRequestID()
}
ctx := context.WithValue(r.Context(), "request_id", requestID)
ctx = context.WithValue(ctx, "start_time", startTime)
// 解析请求 - 使用限制reader防止过大的请求体
var req model.ChatCompletionRequest
limitedBody := &maxBytesReader{reader: r.Body, remaining: MaxRequestBytes}
if err := json.NewDecoder(limitedBody).Decode(&req); err != nil {
// 检查是否是请求体过大的错误
if err.Error() == "http: request body too large" || limitedBody.remaining <= 0 {
h.writeError(w, r, gwerror.NewGatewayError(gwerror.COMMON_REQUEST_TOO_LARGE, "request body exceeds maximum size limit").WithRequestID(requestID))
return
}
h.writeError(w, r, gwerror.NewGatewayError(gwerror.COMMON_INVALID_REQUEST, "invalid request body: "+err.Error()).WithRequestID(requestID))
return
}
// 验证请求
if len(req.Messages) == 0 {
h.writeError(w, r, gwerror.NewGatewayError(gwerror.COMMON_INVALID_REQUEST, "messages is required").WithRequestID(requestID))
return
}
// 选择Provider
provider, err := h.router.SelectProvider(ctx, req.Model)
if err != nil {
h.writeError(w, r, err.(*gwerror.GatewayError).WithRequestID(requestID))
return
}
// 转换消息格式
messages := make([]adapter.Message, len(req.Messages))
for i, m := range req.Messages {
messages[i] = adapter.Message{
Role: m.Role,
Content: m.Content,
Name: m.Name,
}
}
// 构建选项
options := adapter.CompletionOptions{
Temperature: req.Temperature,
MaxTokens: req.MaxTokens,
TopP: req.TopP,
Stream: req.Stream,
Stop: req.Stop,
}
// 处理流式请求
if req.Stream {
h.handleStream(ctx, w, r, provider, req.Model, messages, options, requestID)
return
}
// 处理非流式请求
response, err := provider.ChatCompletion(ctx, req.Model, messages, options)
if err != nil {
// 记录失败
h.router.RecordResult(ctx, provider.ProviderName(), false, time.Since(startTime).Milliseconds())
h.writeError(w, r, err.(*gwerror.GatewayError).WithRequestID(requestID))
return
}
// 记录成功
h.router.RecordResult(ctx, provider.ProviderName(), true, time.Since(startTime).Milliseconds())
// 转换响应
chatResp := model.ChatCompletionResponse{
ID: response.ID,
Object: "chat.completion",
Created: response.Created,
Model: response.Model,
Choices: make([]model.Choice, len(response.Choices)),
}
for i, c := range response.Choices {
chatResp.Choices[i] = model.Choice{
Index: c.Index,
Message: model.ChatMessage{
Role: c.Message.Role,
Content: c.Message.Content,
},
FinishReason: c.FinishReason,
}
}
chatResp.Usage = model.Usage{
PromptTokens: response.Usage.PromptTokens,
CompletionTokens: response.Usage.CompletionTokens,
TotalTokens: response.Usage.TotalTokens,
}
h.writeJSON(w, http.StatusOK, chatResp, requestID)
}
// handleStream 处理流式请求
func (h *Handler) handleStream(ctx context.Context, w http.ResponseWriter, r *http.Request, provider adapter.ProviderAdapter, model string, messages []adapter.Message, options adapter.CompletionOptions, requestID string) {
ch, err := provider.ChatCompletionStream(ctx, model, messages, options)
if err != nil {
h.writeError(w, r, err.(*gwerror.GatewayError).WithRequestID(requestID))
return
}
// 设置SSE头
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Request-ID", requestID)
flusher, ok := w.(http.Flusher)
if !ok {
h.writeError(w, r, gwerror.NewGatewayError(gwerror.COMMON_INTERNAL_ERROR, "streaming not supported").WithRequestID(requestID))
return
}
// 流式发送响应
for chunk := range ch {
data := fmt.Sprintf("data: %s\n\n", marshalJSON(chunk))
w.Write([]byte(data))
flusher.Flush()
}
w.Write([]byte("data: [DONE]\n\n"))
flusher.Flush()
}
// CompletionsHandle /v1/completions端点
func (h *Handler) CompletionsHandle(w http.ResponseWriter, r *http.Request) {
requestID := r.Header.Get("X-Request-ID")
if requestID == "" {
requestID = generateRequestID()
}
// 解析请求 - 使用限制reader防止过大的请求体
var req model.CompletionRequest
limitedBody := &maxBytesReader{reader: r.Body, remaining: MaxRequestBytes}
if err := json.NewDecoder(limitedBody).Decode(&req); err != nil {
// 检查是否是请求体过大的错误
if err.Error() == "http: request body too large" || limitedBody.remaining <= 0 {
h.writeError(w, r, gwerror.NewGatewayError(gwerror.COMMON_REQUEST_TOO_LARGE, "request body exceeds maximum size limit").WithRequestID(requestID))
return
}
h.writeError(w, r, gwerror.NewGatewayError(gwerror.COMMON_INVALID_REQUEST, "invalid request body").WithRequestID(requestID))
return
}
// 构造消息
ctx := r.Context()
messages := []adapter.Message{{Role: "user", Content: req.Prompt}}
provider, err := h.router.SelectProvider(ctx, req.Model)
if err != nil {
h.writeError(w, r, err.(*gwerror.GatewayError).WithRequestID(requestID))
return
}
options := adapter.CompletionOptions{
Temperature: req.Temperature,
MaxTokens: req.MaxTokens,
TopP: req.TopP,
Stream: req.Stream,
Stop: req.Stop,
}
if req.Stream {
h.handleStream(ctx, w, r, provider, req.Model, messages, options, requestID)
return
}
response, err := provider.ChatCompletion(ctx, req.Model, messages, options)
if err != nil {
h.writeError(w, r, err.(*gwerror.GatewayError).WithRequestID(requestID))
return
}
// 转换响应为Completion格式
compResp := model.CompletionResponse{
ID: response.ID,
Object: "text_completion",
Created: response.Created,
Model: response.Model,
Choices: make([]model.Choice1, len(response.Choices)),
}
for i, c := range response.Choices {
compResp.Choices[i] = model.Choice1{
Text: c.Message.Content,
Index: i,
FinishReason: c.FinishReason,
}
}
compResp.Usage = model.Usage{
PromptTokens: response.Usage.PromptTokens,
CompletionTokens: response.Usage.CompletionTokens,
TotalTokens: response.Usage.TotalTokens,
}
h.writeJSON(w, http.StatusOK, compResp, requestID)
}
// ModelsHandle /v1/models端点
func (h *Handler) ModelsHandle(w http.ResponseWriter, r *http.Request) {
requestID := r.Header.Get("X-Request-ID")
if requestID == "" {
requestID = generateRequestID()
}
// 返回支持的模型列表
models := []map[string]interface{}{
{"id": "gpt-4", "object": "model", "created": 1687882411, "owned_by": "openai"},
{"id": "gpt-3.5-turbo", "object": "model", "created": 1677610602, "owned_by": "openai"},
{"id": "claude-3-opus", "object": "model", "created": 1709598254, "owned_by": "anthropic"},
{"id": "claude-3-sonnet", "object": "model", "created": 1709598255, "owned_by": "anthropic"},
}
h.writeJSON(w, http.StatusOK, map[string]interface{}{
"object": "list",
"data": models,
}, requestID)
}
// HealthHandle /health端点
func (h *Handler) HealthHandle(w http.ResponseWriter, r *http.Request) {
healthStatus := h.router.GetHealthStatus()
allHealthy := true
services := make(map[string]bool)
for name, health := range healthStatus {
services[name] = health.Available
if !health.Available {
allHealthy = false
}
}
status := "healthy"
statusCode := http.StatusOK
if !allHealthy {
status = "degraded"
statusCode = http.StatusServiceUnavailable
}
h.writeJSON(w, statusCode, model.HealthStatus{
Status: status,
Timestamp: time.Now(),
Services: services,
}, "")
}
func (h *Handler) writeJSON(w http.ResponseWriter, status int, data interface{}, requestID string) {
w.Header().Set("Content-Type", "application/json")
if requestID != "" {
w.Header().Set("X-Request-ID", requestID)
}
w.WriteHeader(status)
json.NewEncoder(w).Encode(data)
}
func (h *Handler) writeError(w http.ResponseWriter, r *http.Request, err *gwerror.GatewayError) {
info := err.GetErrorInfo()
w.Header().Set("Content-Type", "application/json")
if err.RequestID != "" {
w.Header().Set("X-Request-ID", err.RequestID)
}
w.WriteHeader(info.HTTPStatus)
resp := model.ErrorResponse{
Error: model.ErrorDetail{
Message: err.Message,
Type: "gateway_error",
Code: string(err.Code),
},
}
json.NewEncoder(w).Encode(resp)
}
func generateRequestID() string {
return fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano())
}
func marshalJSON(v interface{}) string {
data, _ := json.Marshal(v)
return string(data)
}