Files
lijiaoqiao/llm-gateway-competitors/sub2api-tar/backend/internal/handler/gateway_helper.go

401 lines
13 KiB
Go
Raw Normal View History

package handler
import (
"context"
"encoding/json"
"fmt"
"math/rand/v2"
"net/http"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// claudeCodeValidator is a singleton validator for Claude Code client detection
var claudeCodeValidator = service.NewClaudeCodeValidator()
const claudeCodeParsedRequestContextKey = "claude_code_parsed_request"
// SetClaudeCodeClientContext 检查请求是否来自 Claude Code 客户端,并设置到 context 中
// 返回更新后的 context
func SetClaudeCodeClientContext(c *gin.Context, body []byte, parsedReq *service.ParsedRequest) {
if c == nil || c.Request == nil {
return
}
if parsedReq != nil {
c.Set(claudeCodeParsedRequestContextKey, parsedReq)
}
ua := c.GetHeader("User-Agent")
// Fast path非 Claude CLI UA 直接判定 false避免热路径二次 JSON 反序列化。
if !claudeCodeValidator.ValidateUserAgent(ua) {
ctx := service.SetClaudeCodeClient(c.Request.Context(), false)
c.Request = c.Request.WithContext(ctx)
return
}
isClaudeCode := false
if !strings.Contains(c.Request.URL.Path, "messages") {
// 与 Validate 行为一致:非 messages 路径 UA 命中即可视为 Claude Code 客户端。
isClaudeCode = true
} else {
// 仅在确认为 Claude CLI 且 messages 路径时再做 body 解析。
bodyMap := claudeCodeBodyMapFromParsedRequest(parsedReq)
if bodyMap == nil {
bodyMap = claudeCodeBodyMapFromContextCache(c)
}
if bodyMap == nil && len(body) > 0 {
_ = json.Unmarshal(body, &bodyMap)
}
isClaudeCode = claudeCodeValidator.Validate(c.Request, bodyMap)
}
// 更新 request context
ctx := service.SetClaudeCodeClient(c.Request.Context(), isClaudeCode)
// 仅在确认为 Claude Code 客户端时提取版本号写入 context
if isClaudeCode {
if version := claudeCodeValidator.ExtractVersion(ua); version != "" {
ctx = service.SetClaudeCodeVersion(ctx, version)
}
}
c.Request = c.Request.WithContext(ctx)
}
func claudeCodeBodyMapFromParsedRequest(parsedReq *service.ParsedRequest) map[string]any {
if parsedReq == nil {
return nil
}
bodyMap := map[string]any{
"model": parsedReq.Model,
}
if parsedReq.System != nil || parsedReq.HasSystem {
bodyMap["system"] = parsedReq.System
}
if parsedReq.MetadataUserID != "" {
bodyMap["metadata"] = map[string]any{"user_id": parsedReq.MetadataUserID}
}
return bodyMap
}
func claudeCodeBodyMapFromContextCache(c *gin.Context) map[string]any {
if c == nil {
return nil
}
if cached, ok := c.Get(service.OpenAIParsedRequestBodyKey); ok {
if bodyMap, ok := cached.(map[string]any); ok {
return bodyMap
}
}
if cached, ok := c.Get(claudeCodeParsedRequestContextKey); ok {
switch v := cached.(type) {
case *service.ParsedRequest:
return claudeCodeBodyMapFromParsedRequest(v)
case service.ParsedRequest:
return claudeCodeBodyMapFromParsedRequest(&v)
}
}
return nil
}
// 并发槽位等待相关常量
//
// 性能优化说明:
// 原实现使用固定间隔100ms轮询并发槽位存在以下问题
// 1. 高并发时频繁轮询增加 Redis 压力
// 2. 固定间隔可能导致多个请求同时重试(惊群效应)
//
// 新实现使用指数退避 + 抖动算法:
// 1. 初始退避 100ms每次乘以 1.5,最大 2s
// 2. 添加 ±20% 的随机抖动,分散重试时间点
// 3. 减少 Redis 压力,避免惊群效应
const (
// maxConcurrencyWait 等待并发槽位的最大时间
maxConcurrencyWait = 30 * time.Second
// defaultPingInterval 流式响应等待时发送 ping 的默认间隔
defaultPingInterval = 10 * time.Second
// initialBackoff 初始退避时间
initialBackoff = 100 * time.Millisecond
// backoffMultiplier 退避时间乘数(指数退避)
backoffMultiplier = 1.5
// maxBackoff 最大退避时间
maxBackoff = 2 * time.Second
)
// SSEPingFormat defines the format of SSE ping events for different platforms
type SSEPingFormat string
const (
// SSEPingFormatClaude is the Claude/Anthropic SSE ping format
SSEPingFormatClaude SSEPingFormat = "data: {\"type\": \"ping\"}\n\n"
// SSEPingFormatNone indicates no ping should be sent (e.g., OpenAI has no ping spec)
SSEPingFormatNone SSEPingFormat = ""
// SSEPingFormatComment is an SSE comment ping for OpenAI/Codex CLI clients
SSEPingFormatComment SSEPingFormat = ":\n\n"
)
// ConcurrencyError represents a concurrency limit error with context
type ConcurrencyError struct {
SlotType string
IsTimeout bool
}
func (e *ConcurrencyError) Error() string {
if e.IsTimeout {
return fmt.Sprintf("timeout waiting for %s concurrency slot", e.SlotType)
}
return fmt.Sprintf("%s concurrency limit reached", e.SlotType)
}
// ConcurrencyHelper provides common concurrency slot management for gateway handlers
type ConcurrencyHelper struct {
concurrencyService *service.ConcurrencyService
pingFormat SSEPingFormat
pingInterval time.Duration
}
// NewConcurrencyHelper creates a new ConcurrencyHelper
func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFormat SSEPingFormat, pingInterval time.Duration) *ConcurrencyHelper {
if pingInterval <= 0 {
pingInterval = defaultPingInterval
}
return &ConcurrencyHelper{
concurrencyService: concurrencyService,
pingFormat: pingFormat,
pingInterval: pingInterval,
}
}
// wrapReleaseOnDone ensures release runs at most once and still triggers on context cancellation.
// 用于避免客户端断开或上游超时导致的并发槽位泄漏。
// 优化:基于 context.AfterFunc 注册回调,避免每请求额外守护 goroutine。
func wrapReleaseOnDone(ctx context.Context, releaseFunc func()) func() {
if releaseFunc == nil {
return nil
}
var once sync.Once
var stop func() bool
release := func() {
once.Do(func() {
if stop != nil {
_ = stop()
}
releaseFunc()
})
}
stop = context.AfterFunc(ctx, release)
return release
}
// IncrementWaitCount increments the wait count for a user
func (h *ConcurrencyHelper) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
return h.concurrencyService.IncrementWaitCount(ctx, userID, maxWait)
}
// DecrementWaitCount decrements the wait count for a user
func (h *ConcurrencyHelper) DecrementWaitCount(ctx context.Context, userID int64) {
h.concurrencyService.DecrementWaitCount(ctx, userID)
}
// IncrementAccountWaitCount increments the wait count for an account
func (h *ConcurrencyHelper) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
return h.concurrencyService.IncrementAccountWaitCount(ctx, accountID, maxWait)
}
// DecrementAccountWaitCount decrements the wait count for an account
func (h *ConcurrencyHelper) DecrementAccountWaitCount(ctx context.Context, accountID int64) {
h.concurrencyService.DecrementAccountWaitCount(ctx, accountID)
}
// TryAcquireUserSlot 尝试立即获取用户并发槽位。
// 返回值: (releaseFunc, acquired, error)
func (h *ConcurrencyHelper) TryAcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (func(), bool, error) {
result, err := h.concurrencyService.AcquireUserSlot(ctx, userID, maxConcurrency)
if err != nil {
return nil, false, err
}
if !result.Acquired {
return nil, false, nil
}
return result.ReleaseFunc, true, nil
}
// TryAcquireAccountSlot 尝试立即获取账号并发槽位。
// 返回值: (releaseFunc, acquired, error)
func (h *ConcurrencyHelper) TryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (func(), bool, error) {
result, err := h.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
if err != nil {
return nil, false, err
}
if !result.Acquired {
return nil, false, nil
}
return result.ReleaseFunc, true, nil
}
// AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary.
// For streaming requests, sends ping events during the wait.
// streamStarted is updated if streaming response has begun.
func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, userID int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
ctx := c.Request.Context()
// Try to acquire immediately
releaseFunc, acquired, err := h.TryAcquireUserSlot(ctx, userID, maxConcurrency)
if err != nil {
return nil, err
}
if acquired {
return releaseFunc, nil
}
// Need to wait - handle streaming ping if needed
return h.waitForSlotWithPing(c, "user", userID, maxConcurrency, isStream, streamStarted)
}
// AcquireAccountSlotWithWait acquires an account concurrency slot, waiting if necessary.
// For streaming requests, sends ping events during the wait.
// streamStarted is updated if streaming response has begun.
func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
ctx := c.Request.Context()
// Try to acquire immediately
releaseFunc, acquired, err := h.TryAcquireAccountSlot(ctx, accountID, maxConcurrency)
if err != nil {
return nil, err
}
if acquired {
return releaseFunc, nil
}
// Need to wait - handle streaming ping if needed
return h.waitForSlotWithPing(c, "account", accountID, maxConcurrency, isStream, streamStarted)
}
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests.
// streamStarted pointer is updated when streaming begins (for proper error handling by caller).
func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
return h.waitForSlotWithPingTimeout(c, slotType, id, maxConcurrency, maxConcurrencyWait, isStream, streamStarted, false)
}
// waitForSlotWithPingTimeout waits for a concurrency slot with a custom timeout.
func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType string, id int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool, tryImmediate bool) (func(), error) {
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
defer cancel()
acquireSlot := func() (*service.AcquireResult, error) {
if slotType == "user" {
return h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
}
return h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
}
if tryImmediate {
result, err := acquireSlot()
if err != nil {
return nil, err
}
if result.Acquired {
return result.ReleaseFunc, nil
}
}
// Determine if ping is needed (streaming + ping format defined)
needPing := isStream && h.pingFormat != ""
var flusher http.Flusher
if needPing {
var ok bool
flusher, ok = c.Writer.(http.Flusher)
if !ok {
return nil, fmt.Errorf("streaming not supported")
}
}
// Only create ping ticker if ping is needed
var pingCh <-chan time.Time
if needPing {
pingTicker := time.NewTicker(h.pingInterval)
defer pingTicker.Stop()
pingCh = pingTicker.C
}
backoff := initialBackoff
timer := time.NewTimer(backoff)
defer timer.Stop()
for {
select {
case <-ctx.Done():
return nil, &ConcurrencyError{
SlotType: slotType,
IsTimeout: true,
}
case <-pingCh:
// Send ping to keep connection alive
if !*streamStarted {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
*streamStarted = true
}
if _, err := fmt.Fprint(c.Writer, string(h.pingFormat)); err != nil {
return nil, err
}
flusher.Flush()
case <-timer.C:
// Try to acquire slot
result, err := acquireSlot()
if err != nil {
return nil, err
}
if result.Acquired {
return result.ReleaseFunc, nil
}
backoff = nextBackoff(backoff)
timer.Reset(backoff)
}
}
}
// AcquireAccountSlotWithWaitTimeout acquires an account slot with a custom timeout (keeps SSE ping).
func (h *ConcurrencyHelper) AcquireAccountSlotWithWaitTimeout(c *gin.Context, accountID int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) {
return h.waitForSlotWithPingTimeout(c, "account", accountID, maxConcurrency, timeout, isStream, streamStarted, true)
}
// nextBackoff 计算下一次退避时间
// 性能优化:使用指数退避 + 随机抖动,避免惊群效应
// current: 当前退避时间
// 返回值下一次退避时间100ms ~ 2s 之间)
func nextBackoff(current time.Duration) time.Duration {
// 指数退避:当前时间 * 1.5
next := time.Duration(float64(current) * backoffMultiplier)
if next > maxBackoff {
next = maxBackoff
}
// 添加 ±20% 的随机抖动jitter 范围 0.8 ~ 1.2
// 抖动可以分散多个请求的重试时间点,避免同时冲击 Redis
jitter := 0.8 + rand.Float64()*0.4
jittered := time.Duration(float64(next) * jitter)
if jittered < initialBackoff {
return initialBackoff
}
if jittered > maxBackoff {
return maxBackoff
}
return jittered
}