chore: initial public snapshot for github upload

This commit is contained in:
Your Name
2026-03-26 20:06:14 +08:00
commit 0e5ecd930e
3497 changed files with 1586236 additions and 0 deletions

View File

@@ -0,0 +1,506 @@
"""
Unified deployment affinity (session stickiness) for the Router.
Features (independently enable-able):
1. Responses API continuity: when a `previous_response_id` is provided, route to the
deployment that generated the original response (highest priority).
2. API-key affinity: map an API key hash -> deployment id for a TTL and re-use that
deployment for subsequent requests to the same router deployment model name
(alias-safe, aligns to `model_map_information.model_map_key`).
This is designed to support "implicit prompt caching" scenarios (no explicit cache_control),
where routing to a consistent deployment is still beneficial.
"""
import hashlib
from typing import Any, Dict, List, Optional, cast
from typing_extensions import TypedDict
from litellm._logging import verbose_router_logger
from litellm.caching.dual_cache import DualCache
from litellm.integrations.custom_logger import CustomLogger, Span
from litellm.responses.utils import ResponsesAPIRequestUtils
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import CallTypes
class DeploymentAffinityCacheValue(TypedDict):
    """Shape of the value stored under a deployment-affinity cache key."""

    # Router deployment id (`model_info.id`) that subsequent requests
    # for the same affinity key should be pinned to.
    model_id: str
class DeploymentAffinityCheck(CustomLogger):
    """
    Router deployment affinity callback.

    NOTE: This is a Router-only callback intended to be wired through
    `Router(optional_pre_call_checks=[...])`.
    """

    CACHE_KEY_PREFIX = "deployment_affinity:v1"

    def __init__(
        self,
        cache: DualCache,
        ttl_seconds: int,
        enable_user_key_affinity: bool,
        enable_responses_api_affinity: bool,
        enable_session_id_affinity: bool = False,
    ):
        super().__init__()
        self.cache = cache
        self.ttl_seconds = ttl_seconds
        self.enable_user_key_affinity = enable_user_key_affinity
        self.enable_responses_api_affinity = enable_responses_api_affinity
        self.enable_session_id_affinity = enable_session_id_affinity

    @staticmethod
    def _looks_like_sha256_hex(value: str) -> bool:
        """
        Return True iff `value` looks like a SHA-256 digest: exactly 64 hex chars.

        FIX: the previous implementation used `int(value, 16)`, which also accepts
        surrounding whitespace, a leading `+`/`-` sign, a `0x`/`0X` prefix, and `_`
        digit separators — so 64-char strings like "0x" + 62 hex chars, or hex with
        underscores, were misclassified as digests and stored un-hashed in cache
        keys. An explicit character check avoids all of those false positives.
        """
        if len(value) != 64:
            return False
        return all(c in "0123456789abcdefABCDEF" for c in value)

    @staticmethod
    def _hash_user_key(user_key: str) -> str:
        """
        Hash user identifiers before storing them in cache keys.

        This avoids putting raw API keys / user identifiers into Redis keys (and therefore
        into logs/metrics), while keeping the cache key stable and a fixed length.
        """
        # If the proxy already provides a stable SHA-256 (e.g. `metadata.user_api_key_hash`),
        # keep it as-is to avoid double-hashing and to make correlation/debugging possible.
        if DeploymentAffinityCheck._looks_like_sha256_hex(user_key):
            return user_key.lower()
        return hashlib.sha256(user_key.encode("utf-8")).hexdigest()

    @staticmethod
    def _get_model_map_key_from_litellm_model_name(
        litellm_model_name: str,
    ) -> Optional[str]:
        """
        Best-effort derivation of a stable "model map key" for affinity scoping.

        The intent is to align with `standard_logging_payload.model_map_information.model_map_key`,
        which is typically the base model identifier (stable across deployments/endpoints).

        Notes:
        - When the model name is in "provider/model" format, the provider prefix is stripped.
        - For Azure, the string after "azure/" is commonly an *Azure deployment name*, which may
          differ across instances. If `base_model` is not explicitly set, we skip deriving a
          model-map key from the model string to avoid generating unstable keys.
        """
        if not litellm_model_name:
            return None
        if "/" not in litellm_model_name:
            return litellm_model_name
        provider_prefix, remainder = litellm_model_name.split("/", 1)
        if provider_prefix == "azure":
            # Azure "model" strings are deployment names, not stable model ids.
            return None
        return remainder

    @staticmethod
    def _get_model_map_key_from_deployment(deployment: dict) -> Optional[str]:
        """
        Derive a stable model-map key from a router deployment dict.

        Lookup order:
        1. `deployment.model_name` (Router's canonical group name after alias
           resolution) — stable across provider-specific deployments and aligned
           with `model_map_information.model_map_key` in standard logging.
        2. `model_info.base_model`, then `litellm_params.base_model` (important
           for Azure, where the raw model string is a per-instance deployment name).
        3. Fallback: parse `litellm_params.model`.
        """
        model_name = deployment.get("model_name")
        if isinstance(model_name, str) and model_name:
            return model_name
        model_info = deployment.get("model_info")
        if isinstance(model_info, dict):
            base_model = model_info.get("base_model")
            if isinstance(base_model, str) and base_model:
                return base_model
        litellm_params = deployment.get("litellm_params")
        if isinstance(litellm_params, dict):
            base_model = litellm_params.get("base_model")
            if isinstance(base_model, str) and base_model:
                return base_model
            litellm_model_name = litellm_params.get("model")
            if isinstance(litellm_model_name, str) and litellm_model_name:
                return (
                    DeploymentAffinityCheck._get_model_map_key_from_litellm_model_name(
                        litellm_model_name
                    )
                )
        return None

    @staticmethod
    def _get_stable_model_map_key_from_deployments(
        healthy_deployments: List[dict],
    ) -> Optional[str]:
        """
        Only use model-map key scoping when it is stable across the deployment set.

        This prevents accidentally keying on per-deployment identifiers like Azure deployment
        names (when `base_model` is not configured).
        """
        if not healthy_deployments:
            return None
        keys: List[str] = []
        for deployment in healthy_deployments:
            key = DeploymentAffinityCheck._get_model_map_key_from_deployment(deployment)
            if key is None:
                return None
            keys.append(key)
        unique_keys = set(keys)
        if len(unique_keys) != 1:
            # Mixed keys across the set -> scoping would be unstable; bail out.
            return None
        return keys[0]

    @staticmethod
    def _shorten_for_logs(value: str, keep: int = 8) -> str:
        """Truncate sensitive-ish identifiers before they hit debug logs."""
        if len(value) <= keep:
            return value
        return f"{value[:keep]}..."

    @classmethod
    def get_affinity_cache_key(cls, model_group: str, user_key: str) -> str:
        """Cache key for API-key -> deployment affinity, scoped by model group."""
        hashed_user_key = cls._hash_user_key(user_key=user_key)
        return f"{cls.CACHE_KEY_PREFIX}:{model_group}:{hashed_user_key}"

    @classmethod
    def get_session_affinity_cache_key(cls, model_group: str, session_id: str) -> str:
        """Cache key for session-id -> deployment affinity, scoped by model group."""
        return f"{cls.CACHE_KEY_PREFIX}:session:{model_group}:{session_id}"

    @staticmethod
    def _get_user_key_from_metadata_dict(metadata: dict) -> Optional[str]:
        # NOTE: affinity is keyed on the *API key hash* provided by the proxy (not the
        # OpenAI `user` parameter, which is an end-user identifier).
        user_key = metadata.get("user_api_key_hash")
        if user_key is None:
            return None
        return str(user_key)

    @staticmethod
    def _get_session_id_from_metadata_dict(metadata: dict) -> Optional[str]:
        session_id = metadata.get("session_id")
        if session_id is None:
            return None
        return str(session_id)

    @staticmethod
    def _iter_metadata_dicts(request_kwargs: dict) -> List[dict]:
        """
        Return all metadata dicts available on the request.

        Depending on the endpoint, Router may populate `metadata` or `litellm_metadata`.
        Users may also send one or both, so we check both (rather than using `or`).
        """
        metadata_dicts: List[dict] = []
        for key in ("litellm_metadata", "metadata"):
            md = request_kwargs.get(key)
            if isinstance(md, dict):
                metadata_dicts.append(md)
        return metadata_dicts

    @staticmethod
    def _get_user_key_from_request_kwargs(request_kwargs: dict) -> Optional[str]:
        """
        Extract a stable affinity key from request kwargs.

        Source (proxy): `metadata.user_api_key_hash`

        Note: the OpenAI `user` parameter is an end-user identifier and is intentionally
        not used for deployment affinity.
        """
        # Check metadata dicts (Proxy usage)
        for metadata in DeploymentAffinityCheck._iter_metadata_dicts(request_kwargs):
            user_key = DeploymentAffinityCheck._get_user_key_from_metadata_dict(
                metadata=metadata
            )
            if user_key is not None:
                return user_key
        return None

    @staticmethod
    def _get_session_id_from_request_kwargs(request_kwargs: dict) -> Optional[str]:
        """Extract `session_id` from any available metadata dict, else None."""
        for metadata in DeploymentAffinityCheck._iter_metadata_dicts(request_kwargs):
            session_id = DeploymentAffinityCheck._get_session_id_from_metadata_dict(
                metadata=metadata
            )
            if session_id is not None:
                return session_id
        return None

    @staticmethod
    def _find_deployment_by_model_id(
        healthy_deployments: List[dict], model_id: str
    ) -> Optional[dict]:
        """Return the deployment whose `model_info.id` matches `model_id`, else None."""
        for deployment in healthy_deployments:
            model_info = deployment.get("model_info")
            if not isinstance(model_info, dict):
                continue
            deployment_model_id = model_info.get("id")
            if deployment_model_id is not None and str(deployment_model_id) == str(
                model_id
            ):
                return deployment
        return None

    async def async_filter_deployments(
        self,
        model: str,
        healthy_deployments: List,
        messages: Optional[List[AllMessageValues]],
        request_kwargs: Optional[dict] = None,
        parent_otel_span: Optional[Span] = None,
    ) -> List[dict]:
        """
        Optionally filter healthy deployments based on:
        1. `previous_response_id` (Responses API continuity) [highest priority]
        2. cached API-key deployment affinity
        """
        request_kwargs = request_kwargs or {}
        typed_healthy_deployments = cast(List[dict], healthy_deployments)
        # 1) Responses API continuity (high priority)
        if self.enable_responses_api_affinity:
            previous_response_id = request_kwargs.get("previous_response_id")
            if previous_response_id is not None:
                responses_model_id = (
                    ResponsesAPIRequestUtils.get_model_id_from_response_id(
                        str(previous_response_id)
                    )
                )
                if responses_model_id is not None:
                    deployment = self._find_deployment_by_model_id(
                        healthy_deployments=typed_healthy_deployments,
                        model_id=responses_model_id,
                    )
                    if deployment is not None:
                        verbose_router_logger.debug(
                            "DeploymentAffinityCheck: previous_response_id pinning -> deployment=%s",
                            responses_model_id,
                        )
                        return [deployment]
        stable_model_map_key = self._get_stable_model_map_key_from_deployments(
            healthy_deployments=typed_healthy_deployments
        )
        if stable_model_map_key is None:
            # No stable scoping key -> affinity lookups would be unreliable; skip.
            return typed_healthy_deployments
        # 2) Session-id -> deployment affinity
        if self.enable_session_id_affinity:
            session_id = self._get_session_id_from_request_kwargs(
                request_kwargs=request_kwargs
            )
            if session_id is not None:
                session_cache_key = self.get_session_affinity_cache_key(
                    model_group=stable_model_map_key, session_id=session_id
                )
                session_cache_result = await self.cache.async_get_cache(
                    key=session_cache_key
                )
                session_model_id: Optional[str] = None
                if isinstance(session_cache_result, dict):
                    session_model_id = cast(
                        Optional[str], session_cache_result.get("model_id")
                    )
                elif isinstance(session_cache_result, str):
                    session_model_id = session_cache_result
                if session_model_id:
                    session_deployment = self._find_deployment_by_model_id(
                        healthy_deployments=typed_healthy_deployments,
                        model_id=session_model_id,
                    )
                    if session_deployment is not None:
                        verbose_router_logger.debug(
                            "DeploymentAffinityCheck: session-id affinity hit -> deployment=%s session_id=%s",
                            session_model_id,
                            session_id,
                        )
                        return [session_deployment]
                    else:
                        verbose_router_logger.debug(
                            "DeploymentAffinityCheck: session-id pinned deployment=%s not found in healthy_deployments",
                            session_model_id,
                        )
        # 3) User key -> deployment affinity
        if not self.enable_user_key_affinity:
            return typed_healthy_deployments
        user_key = self._get_user_key_from_request_kwargs(request_kwargs=request_kwargs)
        if user_key is None:
            return typed_healthy_deployments
        cache_key = self.get_affinity_cache_key(
            model_group=stable_model_map_key, user_key=user_key
        )
        cache_result = await self.cache.async_get_cache(key=cache_key)
        model_id: Optional[str] = None
        if isinstance(cache_result, dict):
            model_id = cast(Optional[str], cache_result.get("model_id"))
        elif isinstance(cache_result, str):
            # Backwards / safety: allow raw string values.
            model_id = cache_result
        if not model_id:
            return typed_healthy_deployments
        deployment = self._find_deployment_by_model_id(
            healthy_deployments=typed_healthy_deployments,
            model_id=model_id,
        )
        if deployment is None:
            verbose_router_logger.debug(
                "DeploymentAffinityCheck: pinned deployment=%s not found in healthy_deployments",
                model_id,
            )
            return typed_healthy_deployments
        verbose_router_logger.debug(
            "DeploymentAffinityCheck: api-key affinity hit -> deployment=%s user_key=%s",
            model_id,
            self._shorten_for_logs(user_key),
        )
        return [deployment]

    async def async_pre_call_deployment_hook(
        self, kwargs: Dict[str, Any], call_type: Optional[CallTypes]
    ) -> Optional[dict]:
        """
        Persist/update the API-key -> deployment mapping for this request.

        Why pre-call?
        - LiteLLM runs async success callbacks via a background logging worker for performance.
        - We want affinity to be immediately available for subsequent requests.
        """
        if not self.enable_user_key_affinity and not self.enable_session_id_affinity:
            return None
        user_key = None
        if self.enable_user_key_affinity:
            user_key = self._get_user_key_from_request_kwargs(request_kwargs=kwargs)
        session_id = None
        if self.enable_session_id_affinity:
            session_id = self._get_session_id_from_request_kwargs(request_kwargs=kwargs)
        if user_key is None and session_id is None:
            return None
        metadata_dicts = self._iter_metadata_dicts(kwargs)
        model_info = kwargs.get("model_info")
        if not isinstance(model_info, dict):
            model_info = None
        if model_info is None:
            for metadata in metadata_dicts:
                maybe_model_info = metadata.get("model_info")
                if isinstance(maybe_model_info, dict):
                    model_info = maybe_model_info
                    break
        if model_info is None:
            # Router sets `model_info` after selecting a deployment. If it's missing, this is
            # likely a non-router call or a call path that doesn't support affinity.
            return None
        model_id = model_info.get("id")
        if not model_id:
            verbose_router_logger.warning(
                "DeploymentAffinityCheck: model_id missing; skipping affinity cache update."
            )
            return None
        # Scope affinity by the Router deployment model name (alias-safe, consistent across
        # heterogeneous providers, and matches standard logging's `model_map_key`).
        deployment_model_name: Optional[str] = None
        for metadata in metadata_dicts:
            maybe_deployment_model_name = metadata.get("deployment_model_name")
            if (
                isinstance(maybe_deployment_model_name, str)
                and maybe_deployment_model_name
            ):
                deployment_model_name = maybe_deployment_model_name
                break
        if not deployment_model_name:
            verbose_router_logger.warning(
                "DeploymentAffinityCheck: deployment_model_name missing; skipping affinity cache update. model_id=%s",
                model_id,
            )
            return None
        if user_key is not None:
            try:
                cache_key = self.get_affinity_cache_key(
                    model_group=deployment_model_name, user_key=user_key
                )
                await self.cache.async_set_cache(
                    cache_key,
                    DeploymentAffinityCacheValue(model_id=str(model_id)),
                    ttl=self.ttl_seconds,
                )
                verbose_router_logger.debug(
                    "DeploymentAffinityCheck: set affinity mapping model_map_key=%s deployment=%s ttl=%s user_key=%s",
                    deployment_model_name,
                    model_id,
                    self.ttl_seconds,
                    self._shorten_for_logs(user_key),
                )
            except Exception as e:
                # Non-blocking: affinity is a best-effort optimization.
                verbose_router_logger.debug(
                    "DeploymentAffinityCheck: failed to set user key affinity cache. model_map_key=%s error=%s",
                    deployment_model_name,
                    e,
                )
        # Also persist Session-ID affinity if enabled and session-id is provided
        if session_id is not None:
            try:
                session_cache_key = self.get_session_affinity_cache_key(
                    model_group=deployment_model_name, session_id=session_id
                )
                await self.cache.async_set_cache(
                    session_cache_key,
                    DeploymentAffinityCacheValue(model_id=str(model_id)),
                    ttl=self.ttl_seconds,
                )
                verbose_router_logger.debug(
                    "DeploymentAffinityCheck: set session affinity mapping model_map_key=%s deployment=%s ttl=%s session_id=%s",
                    deployment_model_name,
                    model_id,
                    self.ttl_seconds,
                    session_id,
                )
            except Exception as e:
                # Non-blocking: affinity is a best-effort optimization.
                verbose_router_logger.debug(
                    "DeploymentAffinityCheck: failed to set session affinity cache. model_map_key=%s error=%s",
                    deployment_model_name,
                    e,
                )
        return None

View File

@@ -0,0 +1,172 @@
"""
Encrypted-content-aware deployment affinity for the Router.
When Codex or other models use `store: false` with `include: ["reasoning.encrypted_content"]`,
the response output items contain encrypted reasoning tokens tied to the originating
organization's API key. If a follow-up request containing those items is routed to a
different deployment (different org), OpenAI rejects it with an `invalid_encrypted_content`
error because the organization_id doesn't match.
This callback solves the problem by encoding the originating deployment's ``model_id``
into the response output items that carry ``encrypted_content``. Two encoding strategies:
1. **Items with IDs**: Encode model_id into the item ID itself (e.g., ``encitem_...``)
2. **Items without IDs** (Codex): Wrap the encrypted_content with model_id metadata
(e.g., ``litellm_enc:{base64_metadata};{original_encrypted_content}``)
The encoded model_id is decoded on the next request so the router can pin to the correct
deployment without any cache lookup.
Response post-processing (encoding) is handled by
``ResponsesAPIRequestUtils._update_encrypted_content_item_ids_in_response`` which is
called inside ``_update_responses_api_response_id_with_model_id`` in ``responses/utils.py``.
Request pre-processing (ID/content restoration before forwarding to upstream) is handled by
``ResponsesAPIRequestUtils._restore_encrypted_content_item_ids_in_input`` which is called
in ``get_optional_params_responses_api``.
This pre-call check is responsible only for the routing decision: it reads the encoded
``model_id`` from either item IDs or wrapped encrypted_content and pins the request to
the matching deployment.
Safe to enable globally:
- Only activates when encoded markers appear in the request ``input``.
- No effect on embedding models, chat completions, or first-time requests.
- No quota reduction -- first requests are fully load balanced.
- No cache required.
"""
from typing import Any, List, Optional, cast
from litellm._logging import verbose_router_logger
from litellm.integrations.custom_logger import CustomLogger, Span
from litellm.responses.utils import ResponsesAPIRequestUtils
from litellm.types.llms.openai import AllMessageValues
class EncryptedContentAffinityCheck(CustomLogger):
    """
    Routes follow-up Responses API requests to the deployment that produced
    the encrypted output items they reference.

    The ``model_id`` is decoded directly from the litellm-encoded item IDs --
    no caching or TTL management needed.

    Wired via ``Router(optional_pre_call_checks=["encrypted_content_affinity"])``.
    """

    def __init__(self) -> None:
        super().__init__()

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _extract_model_id_from_input(request_input: Any) -> Optional[str]:
        """
        Scan ``input`` items for litellm-encoded encrypted-content markers and
        return the ``model_id`` embedded in the first one found.

        Checks both:
        1. Encoded item IDs (encitem_...) - for clients that send IDs
        2. Wrapped encrypted_content (litellm_enc:...) - for clients like Codex that don't send IDs

        ``input`` can be:
        - a plain string -> no encoded markers
        - a list of items -> check each item's ``id`` and ``encrypted_content`` fields
        """
        if not isinstance(request_input, list):
            return None
        for item in request_input:
            if not isinstance(item, dict):
                continue
            # First, try to decode from item ID (if present)
            item_id = item.get("id")
            if item_id and isinstance(item_id, str):
                decoded = ResponsesAPIRequestUtils._decode_encrypted_item_id(item_id)
                if decoded:
                    return decoded.get("model_id")
            # If no encoded ID, check if encrypted_content itself is wrapped
            encrypted_content = item.get("encrypted_content")
            if encrypted_content and isinstance(encrypted_content, str):
                (
                    model_id,
                    _,
                ) = ResponsesAPIRequestUtils._unwrap_encrypted_content_with_model_id(
                    encrypted_content
                )
                if model_id:
                    return model_id
        return None

    @staticmethod
    def _find_deployment_by_model_id(
        healthy_deployments: List[dict], model_id: str
    ) -> Optional[dict]:
        """Return the deployment whose `model_info.id` matches `model_id`, else None."""
        for deployment in healthy_deployments:
            model_info = deployment.get("model_info")
            if not isinstance(model_info, dict):
                continue
            deployment_model_id = model_info.get("id")
            if deployment_model_id is not None and str(deployment_model_id) == str(
                model_id
            ):
                return deployment
        return None

    # ------------------------------------------------------------------
    # Request routing (pre-call filter)
    # ------------------------------------------------------------------
    async def async_filter_deployments(
        self,
        model: str,
        healthy_deployments: List,
        messages: Optional[List[AllMessageValues]],
        request_kwargs: Optional[dict] = None,
        parent_otel_span: Optional[Span] = None,
    ) -> List[dict]:
        """
        If the request ``input`` contains litellm-encoded item IDs, decode the
        embedded ``model_id`` and pin the request to that deployment.
        """
        request_kwargs = request_kwargs or {}
        typed_healthy_deployments = cast(List[dict], healthy_deployments)
        # Signal to the response post-processor that encrypted item IDs should be
        # encoded in the output of this request.
        # FIX: the previous `setdefault("litellm_metadata", {})` returned an
        # existing value unchanged -- if the caller passed `litellm_metadata=None`
        # (a common state for litellm request kwargs), the item assignment below
        # raised TypeError. Replace non-dict values with a fresh dict instead.
        litellm_metadata = request_kwargs.get("litellm_metadata")
        if not isinstance(litellm_metadata, dict):
            litellm_metadata = {}
            request_kwargs["litellm_metadata"] = litellm_metadata
        litellm_metadata["encrypted_content_affinity_enabled"] = True
        request_input = request_kwargs.get("input")
        model_id = self._extract_model_id_from_input(request_input)
        if not model_id:
            # First-time request (no markers) -> fully load balanced.
            return typed_healthy_deployments
        verbose_router_logger.debug(
            "EncryptedContentAffinityCheck: decoded model_id=%s from input item IDs",
            model_id,
        )
        deployment = self._find_deployment_by_model_id(
            healthy_deployments=typed_healthy_deployments,
            model_id=model_id,
        )
        if deployment is not None:
            verbose_router_logger.debug(
                "EncryptedContentAffinityCheck: pinning -> deployment=%s",
                model_id,
            )
            request_kwargs["_encrypted_content_affinity_pinned"] = True
            return [deployment]
        # The originating deployment is gone/unhealthy; the upstream call will
        # likely fail with invalid_encrypted_content, so log loudly.
        verbose_router_logger.error(
            "EncryptedContentAffinityCheck: decoded deployment=%s not found in healthy_deployments",
            model_id,
        )
        return typed_healthy_deployments

View File

@@ -0,0 +1,332 @@
"""
Enforce TPM/RPM rate limits set on model deployments.
This pre-call check ensures that model-level TPM/RPM limits are enforced
across all requests, regardless of routing strategy.
When enabled via `enforce_model_rate_limits: true` in litellm_settings,
requests that exceed the configured TPM/RPM limits will receive a 429 error.
"""
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
import httpx
import litellm
from litellm._logging import verbose_router_logger
from litellm.caching.dual_cache import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.router import RouterErrors
from litellm.types.utils import StandardLoggingPayload
from litellm.utils import get_utc_datetime
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = Union[_Span, Any]
else:
Span = Any
class RoutingArgs:
    """Constants for rate-limit counter bookkeeping."""

    # Counter keys expire after one minute, matching the per-minute
    # TPM/RPM window encoded in the cache-key's "%H-%M" suffix.
    ttl: int = 60
class ModelRateLimitingCheck(CustomLogger):
    """
    Pre-call check that enforces TPM/RPM limits on model deployments.

    This check runs before each request and raises a RateLimitError
    if the deployment has exceeded its configured TPM or RPM limits.

    Unlike the usage-based-routing strategy which uses limits for routing decisions,
    this check actively enforces those limits across ALL routing strategies.
    """

    def __init__(self, dual_cache: DualCache):
        self.dual_cache = dual_cache

    def _get_deployment_limits(
        self, deployment: Dict
    ) -> tuple[Optional[int], Optional[int]]:
        """
        Extract TPM and RPM limits from a deployment configuration.

        Checks in order:
        1. Top-level 'tpm'/'rpm' fields
        2. litellm_params.tpm/rpm
        3. model_info.tpm/rpm

        Returns:
            Tuple of (tpm_limit, rpm_limit)
        """
        # Check top-level
        tpm = deployment.get("tpm")
        rpm = deployment.get("rpm")
        # Check litellm_params
        if tpm is None:
            tpm = deployment.get("litellm_params", {}).get("tpm")
        if rpm is None:
            rpm = deployment.get("litellm_params", {}).get("rpm")
        # Check model_info
        if tpm is None:
            tpm = deployment.get("model_info", {}).get("tpm")
        if rpm is None:
            rpm = deployment.get("model_info", {}).get("rpm")
        return tpm, rpm

    def _get_cache_keys(self, deployment: Dict, current_minute: str) -> tuple[str, str]:
        """Get the cache keys for TPM and RPM tracking."""
        model_id = deployment.get("model_info", {}).get("id")
        deployment_name = deployment.get("litellm_params", {}).get("model")
        tpm_key = f"{model_id}:{deployment_name}:tpm:{current_minute}"
        rpm_key = f"{model_id}:{deployment_name}:rpm:{current_minute}"
        return tpm_key, rpm_key

    @staticmethod
    def _raise_rate_limit_error(
        *,
        kind: str,
        limit: int,
        current: Any,
        model_id: Optional[str],
        model_name: Optional[str],
        model_group: str,
    ) -> None:
        """
        Raise a litellm.RateLimitError (HTTP 429) for an exceeded limit.

        `kind` is "tpm" or "rpm". Deduplicates the four near-identical error
        constructions previously inlined in the sync and async checks.

        FIX: `num_retries=0` is now set on the sync path as well -- previously
        only the async path disabled retries, contradicting the documented
        "return 429 immediately" behavior.
        """
        raise litellm.RateLimitError(
            message=f"Model rate limit exceeded. {kind.upper()} limit={limit}, current usage={current}",
            llm_provider="",
            model=model_name,
            response=httpx.Response(
                status_code=429,
                content=(
                    f"{RouterErrors.user_defined_ratelimit_error.value} "
                    f"{kind} limit={limit}. current usage={current}. "
                    f"id={model_id}, model_group={model_group}"
                ),
                headers={"retry-after": str(60)},
                request=httpx.Request(
                    method="model_rate_limit_check",
                    url="https://github.com/BerriAI/litellm",
                ),
            ),
            num_retries=0,  # Don't retry - return 429 immediately
        )

    def pre_call_check(self, deployment: Dict) -> Optional[Dict]:
        """
        Synchronous pre-call check for model rate limits.

        Raises RateLimitError if deployment exceeds TPM/RPM limits.
        """
        try:
            tpm_limit, rpm_limit = self._get_deployment_limits(deployment)
            # If no limits are set, allow the request
            if tpm_limit is None and rpm_limit is None:
                return deployment
            dt = get_utc_datetime()
            current_minute = dt.strftime("%H-%M")
            tpm_key, rpm_key = self._get_cache_keys(deployment, current_minute)
            model_id = deployment.get("model_info", {}).get("id")
            model_name = deployment.get("litellm_params", {}).get("model")
            model_group = deployment.get("model_name", "")
            # Check TPM limit (local cache only; tokens are tracked post-call)
            if tpm_limit is not None:
                current_tpm = self.dual_cache.get_cache(key=tpm_key, local_only=True)
                if current_tpm is not None and current_tpm >= tpm_limit:
                    self._raise_rate_limit_error(
                        kind="tpm",
                        limit=tpm_limit,
                        current=current_tpm,
                        model_id=model_id,
                        model_name=model_name,
                        model_group=model_group,
                    )
            # Check RPM limit (atomic increment-first to avoid race conditions)
            if rpm_limit is not None:
                current_rpm = self.dual_cache.increment_cache(
                    key=rpm_key, value=1, ttl=RoutingArgs.ttl
                )
                if current_rpm is not None and current_rpm > rpm_limit:
                    self._raise_rate_limit_error(
                        kind="rpm",
                        limit=rpm_limit,
                        current=current_rpm,
                        model_id=model_id,
                        model_name=model_name,
                        model_group=model_group,
                    )
            return deployment
        except litellm.RateLimitError:
            raise
        except Exception as e:
            verbose_router_logger.debug(
                f"Error in ModelRateLimitingCheck.pre_call_check: {str(e)}"
            )
            # Don't fail the request if rate limit check fails
            return deployment

    async def async_pre_call_check(
        self, deployment: Dict, parent_otel_span: Optional[Span] = None
    ) -> Optional[Dict]:
        """
        Async pre-call check for model rate limits.

        Raises RateLimitError if deployment exceeds TPM/RPM limits.
        """
        try:
            tpm_limit, rpm_limit = self._get_deployment_limits(deployment)
            # If no limits are set, allow the request
            if tpm_limit is None and rpm_limit is None:
                return deployment
            dt = get_utc_datetime()
            current_minute = dt.strftime("%H-%M")
            tpm_key, rpm_key = self._get_cache_keys(deployment, current_minute)
            model_id = deployment.get("model_info", {}).get("id")
            model_name = deployment.get("litellm_params", {}).get("model")
            model_group = deployment.get("model_name", "")
            # Check TPM limit (local cache only; tokens are tracked post-call)
            if tpm_limit is not None:
                current_tpm = await self.dual_cache.async_get_cache(
                    key=tpm_key, local_only=True
                )
                if current_tpm is not None and current_tpm >= tpm_limit:
                    self._raise_rate_limit_error(
                        kind="tpm",
                        limit=tpm_limit,
                        current=current_tpm,
                        model_id=model_id,
                        model_name=model_name,
                        model_group=model_group,
                    )
            # Check RPM limit (atomic increment-first to avoid race conditions)
            if rpm_limit is not None:
                current_rpm = await self.dual_cache.async_increment_cache(
                    key=rpm_key,
                    value=1,
                    ttl=RoutingArgs.ttl,
                    parent_otel_span=parent_otel_span,
                )
                if current_rpm is not None and current_rpm > rpm_limit:
                    self._raise_rate_limit_error(
                        kind="rpm",
                        limit=rpm_limit,
                        current=current_rpm,
                        model_id=model_id,
                        model_name=model_name,
                        model_group=model_group,
                    )
            return deployment
        except litellm.RateLimitError:
            raise
        except Exception as e:
            verbose_router_logger.debug(
                f"Error in ModelRateLimitingCheck.async_pre_call_check: {str(e)}"
            )
            # Don't fail the request if rate limit check fails
            return deployment

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """
        Track TPM usage after successful request.

        This updates the TPM counter with the actual tokens used.
        Always tracks tokens - the pre-call check handles enforcement.
        """
        try:
            standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
                "standard_logging_object"
            )
            if standard_logging_object is None:
                return
            model_id = standard_logging_object.get("model_id")
            if model_id is None:
                return
            total_tokens = standard_logging_object.get("total_tokens", 0)
            model = standard_logging_object.get("hidden_params", {}).get(
                "litellm_model_name"
            )
            verbose_router_logger.debug(
                f"[TPM TRACKING] model_id={model_id}, total_tokens={total_tokens}, model={model}"
            )
            if not model or not total_tokens:
                return
            dt = get_utc_datetime()
            current_minute = dt.strftime("%H-%M")
            tpm_key = f"{model_id}:{model}:tpm:{current_minute}"
            verbose_router_logger.debug(
                f"[TPM TRACKING] Incrementing {tpm_key} by {total_tokens}"
            )
            await self.dual_cache.async_increment_cache(
                key=tpm_key,
                value=total_tokens,
                ttl=RoutingArgs.ttl,
            )
        except Exception as e:
            # Best-effort tracking: never fail the request path on logging errors.
            verbose_router_logger.debug(
                f"Error in ModelRateLimitingCheck.async_log_success_event: {str(e)}"
            )

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        """
        Sync version of tracking TPM usage after successful request.
        Always tracks tokens - the pre-call check handles enforcement.
        """
        try:
            standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
                "standard_logging_object"
            )
            if standard_logging_object is None:
                return
            model_id = standard_logging_object.get("model_id")
            if model_id is None:
                return
            total_tokens = standard_logging_object.get("total_tokens", 0)
            model = standard_logging_object.get("hidden_params", {}).get(
                "litellm_model_name"
            )
            if not model or not total_tokens:
                return
            dt = get_utc_datetime()
            current_minute = dt.strftime("%H-%M")
            tpm_key = f"{model_id}:{model}:tpm:{current_minute}"
            self.dual_cache.increment_cache(
                key=tpm_key,
                value=total_tokens,
                ttl=RoutingArgs.ttl,
            )
        except Exception as e:
            # Best-effort tracking: never fail the request path on logging errors.
            verbose_router_logger.debug(
                f"Error in ModelRateLimitingCheck.log_success_event: {str(e)}"
            )

View File

@@ -0,0 +1,100 @@
"""
Check if prompt caching is valid for a given deployment
Route to previously cached model id, if valid
"""
from typing import List, Optional, cast
from litellm import verbose_logger
from litellm.caching.dual_cache import DualCache
from litellm.integrations.custom_logger import CustomLogger, Span
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import CallTypes, StandardLoggingPayload
from litellm.utils import is_prompt_caching_valid_prompt
from ..prompt_caching_cache import PromptCachingCache
class PromptCachingDeploymentCheck(CustomLogger):
    """
    Pin prompt-caching-eligible requests to the deployment that previously
    cached the same prompt, and record the deployment id after each success.
    """

    def __init__(self, cache: DualCache):
        self.cache = cache

    async def async_filter_deployments(
        self,
        model: str,
        healthy_deployments: List,
        messages: Optional[List[AllMessageValues]],
        request_kwargs: Optional[dict] = None,
        parent_otel_span: Optional[Span] = None,
    ) -> List[dict]:
        # Only prompts long enough to be cacheable (> 1024 tokens) qualify.
        if messages is None or not is_prompt_caching_valid_prompt(
            messages=messages,
            model=model,
        ):
            return healthy_deployments
        prompt_cache = PromptCachingCache(
            cache=self.cache,
        )
        cached_entry = await prompt_cache.async_get_model_id(
            messages=cast(List[AllMessageValues], messages),
            tools=None,
        )
        if cached_entry is None:
            return healthy_deployments
        cached_model_id = cached_entry["model_id"]
        pinned = next(
            (
                deployment
                for deployment in healthy_deployments
                if deployment["model_info"]["id"] == cached_model_id
            ),
            None,
        )
        # Fall back to the full set when the cached deployment is unavailable.
        return [pinned] if pinned is not None else healthy_deployments

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
            "standard_logging_object", None
        )
        if standard_logging_object is None:
            return
        call_type = standard_logging_object["call_type"]
        # Only completion-style calls participate in prompt caching.
        if call_type not in (
            CallTypes.completion.value,
            CallTypes.acompletion.value,
            CallTypes.anthropic_messages.value,
        ):
            verbose_logger.debug(
                "litellm.router_utils.pre_call_checks.prompt_caching_deployment_check: skipping adding model id to prompt caching cache, CALL TYPE IS NOT COMPLETION or ANTHROPIC MESSAGE"
            )
            return
        model = standard_logging_object["model"]
        messages = standard_logging_object["messages"]
        model_id = standard_logging_object["model_id"]
        if messages is None or not isinstance(messages, list):
            verbose_logger.debug(
                "litellm.router_utils.pre_call_checks.prompt_caching_deployment_check: skipping adding model id to prompt caching cache, MESSAGES IS NOT A LIST"
            )
            return
        if model_id is None:
            verbose_logger.debug(
                "litellm.router_utils.pre_call_checks.prompt_caching_deployment_check: skipping adding model id to prompt caching cache, MODEL ID IS NONE"
            )
            return
        ## PROMPT CACHING - cache model id, if prompt caching valid prompt + provider
        if is_prompt_caching_valid_prompt(
            model=model,
            messages=cast(List[AllMessageValues], messages),
        ):
            cache = PromptCachingCache(
                cache=self.cache,
            )
            await cache.async_add_model_id(
                model_id=model_id,
                messages=messages,
                tools=None,  # [TODO]: add tools once standard_logging_object supports it
            )
        return

View File

@@ -0,0 +1,57 @@
"""
For Responses API, we need routing affinity when a user sends a previous_response_id.
eg. If proxy admins are load balancing between N gpt-4.1-turbo deployments, and a user sends a previous_response_id,
we want to route to the same gpt-4.1-turbo deployment.
This is different from the normal behavior of the router, which does not have routing affinity for previous_response_id.
If previous_response_id is provided, route to the deployment that returned the previous response
"""
import warnings
from typing import List, Optional
from litellm.integrations.custom_logger import CustomLogger, Span
from litellm.responses.utils import ResponsesAPIRequestUtils
from litellm.types.llms.openai import AllMessageValues
class ResponsesApiDeploymentCheck(CustomLogger):
    """
    Deprecated shim: pins Responses API follow-up requests (those carrying a
    `previous_response_id`) to the deployment that produced the original
    response. Superseded by DeploymentAffinityCheck.
    """

    def __init__(self) -> None:
        super().__init__()
        warnings.warn(
            (
                "ResponsesApiDeploymentCheck is deprecated. "
                "Use DeploymentAffinityCheck(enable_responses_api_affinity=True) instead."
            ),
            DeprecationWarning,
            stacklevel=2,
        )

    async def async_filter_deployments(
        self,
        model: str,
        healthy_deployments: List,
        messages: Optional[List[AllMessageValues]],
        request_kwargs: Optional[dict] = None,
        parent_otel_span: Optional[Span] = None,
    ) -> List[dict]:
        kwargs = request_kwargs or {}
        prev_response_id = kwargs.get("previous_response_id", None)
        if prev_response_id is None:
            return healthy_deployments
        # The response id encodes the originating deployment's model_id.
        decoded = ResponsesAPIRequestUtils._decode_responses_api_response_id(
            response_id=prev_response_id,
        )
        target_model_id = decoded.get("model_id")
        if target_model_id is None:
            return healthy_deployments
        pinned = next(
            (
                deployment
                for deployment in healthy_deployments
                if deployment["model_info"]["id"] == target_model_id
            ),
            None,
        )
        # Fall back to load balancing when the original deployment is gone.
        return [pinned] if pinned is not None else healthy_deployments