chore: initial public snapshot for github upload
@@ -0,0 +1,170 @@
# Dynamic Rate Limiter v3 - Saturation-Aware Priority-Based Rate Limiting

## Overview

The v3 dynamic rate limiter implements saturation-aware rate limiting with priority-based allocation. It balances resource efficiency (unused capacity can be borrowed) with fairness guarantees (priority limits are enforced during high load).

**Key Behavior:**

- When the system is under 80% capacity: Generous mode - allows priority borrowing
- When the system is at or above 80% capacity: Strict mode - enforces normalized priority limits

## How It Works

### Flow Diagram
```
┌─────────────────────────────────────────────────────────────┐
│                      Incoming Request                       │
└────────────────────────────┬────────────────────────────────┘
                             │
                             ▼
┌─────────────────────────────────────────────────────────────┐
│ 1. Check Model Saturation                                   │
│    - Query v3 limiter's Redis counters                      │
│    - Calculate: current_usage / capacity                    │
│    - Returns: 0.0 (empty) to 1.0+ (saturated)               │
└────────────────────────────┬────────────────────────────────┘
                             │
                             ▼
                    ┌────────┴────────┐
                    │   Saturation?   │
                    └────────┬────────┘
                             │
           ┌─────────────────┴────────────────┐
           │                                  │
   < 80% (Generous)                   >= 80% (Strict)
           │                                  │
           ▼                                  ▼
┌─────────────────────┐            ┌─────────────────────┐
│    Generous Mode    │            │     Strict Mode     │
│                     │            │                     │
│ - Enforce model-    │            │ - Normalize         │
│   wide capacity     │            │   priority weights  │
│ - No priority       │            │   (if over 1.0)     │
│   restrictions      │            │                     │
│ - Allows borrowing  │            │ - Create priority-  │
│                     │            │   specific          │
│ - First-come-       │            │   descriptors       │
│   first-served      │            │                     │
│   until capacity    │            │ - Enforce strict    │
│                     │            │   limits per        │
│                     │            │   priority          │
└──────────┬──────────┘            └──────────┬──────────┘
           │                                  │
           │                                  ▼
           │                       ┌──────────────────────┐
           │                       │ Track model usage    │
           │                       │ for future           │
           │                       │ saturation checks    │
           │                       └──────────┬───────────┘
           │                                  │
           └─────────────────┬────────────────┘
                             │
                             ▼
                      ┌──────────────┐
                      │  v3 Limiter  │
                      │    Check     │
                      └──────┬───────┘
                             │
             ┌───────────────┴───────────────┐
             │                               │
             ▼                               ▼
        OVER_LIMIT                          OK
             │                               │
             ▼                               ▼
     Return 429 Error                Allow Request
```

## Configuration

### Priority Reservation

Set priority weights in your proxy configuration:

```python
litellm.priority_reservation = {
    "premium": 0.75,   # 75% of capacity
    "standard": 0.25,  # 25% of capacity
}
```

### Priority Reservation Settings

Configure saturation-aware behavior:

```python
litellm.priority_reservation_settings = PriorityReservationSettings(
    default_priority=0.5,       # Default weight for users without explicit priority
    saturation_threshold=0.80,  # 80% - threshold for strict mode enforcement
    tracking_multiplier=10,     # 10x - multiplier for non-blocking tracking in strict mode
)
```

**Settings:**

- `default_priority` (default: 0.5) - Priority weight for users without explicit priority metadata
- `saturation_threshold` (default: 0.80) - Saturation level (0.0-1.0) at which strict priority enforcement begins
- `tracking_multiplier` (default: 10) - Multiplier for model-wide tracking limits in strict mode

### User Priority Assignment

Set priority in user metadata:

```python
user_api_key_dict.metadata = {"priority": "premium"}
```
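
The lookup behaves roughly as follows. This is an illustrative sketch, not the litellm API: `resolve_priority_weight` and `DEFAULT_PRIORITY_WEIGHT` are hypothetical names; only the reservation mapping and the 0.5 default come from the settings above.

```python
DEFAULT_PRIORITY_WEIGHT = 0.5  # mirrors priority_reservation_settings.default_priority

def resolve_priority_weight(metadata: dict, reservation: dict) -> float:
    """Look up a user's reserved weight, falling back to the default weight."""
    return reservation.get(metadata.get("priority"), DEFAULT_PRIORITY_WEIGHT)

reservation = {"premium": 0.75, "standard": 0.25}
print(resolve_priority_weight({"priority": "premium"}, reservation))  # 0.75
print(resolve_priority_weight({}, reservation))                       # 0.5
```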

## Priority Weight Normalization

If priorities sum to > 1.0, they are automatically normalized:

```
Input:  {key_a: 0.60, key_b: 0.80} = 1.40 total
Output: {key_a: 0.43, key_b: 0.57} = 1.00 total
```

This ensures total allocation never exceeds model capacity.
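
The normalization above can be sketched as follows; the two-decimal rounding matches the example, though the real implementation may keep full precision:

```python
def normalize_priority_weights(weights: dict) -> dict:
    """Scale weights down proportionally when they sum to more than 1.0."""
    total = sum(weights.values())
    if total <= 1.0:
        return dict(weights)  # already within capacity, leave as-is
    return {k: round(v / total, 2) for k, v in weights.items()}

print(normalize_priority_weights({"key_a": 0.60, "key_b": 0.80}))
# {'key_a': 0.43, 'key_b': 0.57}
```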

## Implementation Details

### Saturation Detection

- Queries the v3 limiter's Redis counters for model-wide usage
- Checks both RPM and TPM and returns the higher saturation value
- Uses non-blocking reads (doesn't increment counters)
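
The "higher of RPM and TPM" rule can be sketched as a pure function (hypothetical helper, assuming capacities are the configured model-wide RPM/TPM limits):

```python
def model_saturation(rpm_used: int, rpm_capacity: int, tpm_used: int, tpm_capacity: int) -> float:
    """Return the higher of RPM and TPM saturation (0.0 empty, 1.0+ saturated)."""
    rpm_sat = rpm_used / rpm_capacity if rpm_capacity else 0.0
    tpm_sat = tpm_used / tpm_capacity if tpm_capacity else 0.0
    return max(rpm_sat, tpm_sat)

# TPM is the bottleneck here: 90% token usage pushes saturation past the 0.80 threshold
print(model_saturation(50, 100, 90_000, 100_000))  # 0.9
```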

### Mode Selection

**Generous Mode (< 80% saturation):**

- Creates single model-wide descriptor
- Enforces total capacity only
- Allows any priority to use available capacity
- Prevents over-subscription via model-wide limit

**Strict Mode (>= 80% saturation):**

- Creates priority-specific descriptors with normalized weights
- Each priority gets its reserved allocation
- Tracks model-wide usage separately (non-blocking, 10x multiplier)
- Ensures fairness under load
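
Strict mode's per-priority allocations follow from the normalized weights. A minimal sketch, assuming per-priority RPM limits are the weight times the model-wide capacity (`strict_mode_limits` is a hypothetical name, not a method of the handler):

```python
def strict_mode_limits(weights: dict, model_rpm_capacity: int) -> dict:
    """Per-priority RPM allocations after normalization (illustrative only)."""
    total = sum(weights.values())
    # Normalize only when over-allocated, matching the documented behavior
    normalized = {k: v / total for k, v in weights.items()} if total > 1.0 else weights
    return {k: int(v * model_rpm_capacity) for k, v in normalized.items()}

print(strict_mode_limits({"premium": 0.75, "standard": 0.25}, 1000))
# {'premium': 750, 'standard': 250}
```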

Test scenarios covered:

1. No rate limiting when under capacity
2. Priority queue behavior during saturation
3. Spillover capacity for default keys
4. Over-allocated priorities with normalization
5. Default priority value handling

### `_PROXY_DynamicRateLimitHandlerV3`

Main handler class inheriting from `CustomLogger`.

**Key Methods:**

- `async_pre_call_hook()` - Main entry point, routes to generous/strict mode
- `_check_model_saturation()` - Queries Redis for current usage
- `_handle_generous_mode()` - Enforces model-wide capacity only
- `_handle_strict_mode()` - Enforces normalized priority limits
- `_normalize_priority_weights()` - Handles over-allocation
- `_create_priority_based_descriptors()` - Creates rate limit descriptors

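The routing performed by the entry point can be reduced to a small sketch. This is a simplified standalone illustration with stand-in coroutines, not the actual handler implementation:

```python
import asyncio

SATURATION_THRESHOLD = 0.80  # mirrors priority_reservation_settings.saturation_threshold

async def pre_call_hook_sketch(check_saturation, handle_generous, handle_strict):
    """Route a request to generous or strict handling based on saturation."""
    saturation = await check_saturation()
    if saturation >= SATURATION_THRESHOLD:
        return await handle_strict()
    return await handle_generous()

async def main():
    # Stand-in coroutines for illustration only
    async def check_saturation(): return 0.85
    async def handle_generous(): return "generous"
    async def handle_strict(): return "strict"
    mode = await pre_call_hook_sketch(check_saturation, handle_generous, handle_strict)
    print(mode)  # strict

asyncio.run(main())
```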
@@ -0,0 +1,60 @@
import os
from typing import Literal, Union

from . import *
from .cache_control_check import _PROXY_CacheControlCheck
from .litellm_skills import SkillsInjectionHook
from .max_budget_limiter import _PROXY_MaxBudgetLimiter
from .max_budget_per_session_limiter import _PROXY_MaxBudgetPerSessionHandler
from .max_iterations_limiter import _PROXY_MaxIterationsHandler
from .parallel_request_limiter import _PROXY_MaxParallelRequestsHandler
from .parallel_request_limiter_v3 import _PROXY_MaxParallelRequestsHandler_v3
from .responses_id_security import ResponsesIDSecurity

#### CHECK IF ENTERPRISE HOOKS ARE AVAILABLE ####

try:
    from enterprise.enterprise_hooks import ENTERPRISE_PROXY_HOOKS
except ImportError:
    ENTERPRISE_PROXY_HOOKS = {}

# List of all available hooks that can be enabled
PROXY_HOOKS = {
    "max_budget_limiter": _PROXY_MaxBudgetLimiter,
    "parallel_request_limiter": _PROXY_MaxParallelRequestsHandler_v3,
    "cache_control_check": _PROXY_CacheControlCheck,
    "responses_id_security": ResponsesIDSecurity,
    "litellm_skills": SkillsInjectionHook,
    "max_iterations_limiter": _PROXY_MaxIterationsHandler,
    "max_budget_per_session_limiter": _PROXY_MaxBudgetPerSessionHandler,
}

## FEATURE FLAG HOOKS ##
if os.getenv("LEGACY_MULTI_INSTANCE_RATE_LIMITING", "false").lower() == "true":
    PROXY_HOOKS["parallel_request_limiter"] = _PROXY_MaxParallelRequestsHandler


### update PROXY_HOOKS with ENTERPRISE_PROXY_HOOKS ###
PROXY_HOOKS.update(ENTERPRISE_PROXY_HOOKS)


def get_proxy_hook(
    hook_name: Union[
        Literal[
            "max_budget_limiter",
            "managed_files",
            "parallel_request_limiter",
            "cache_control_check",
        ],
        str,
    ],
):
    """
    Factory method to get a proxy hook class by name.
    """
    if hook_name not in PROXY_HOOKS:
        raise ValueError(
            f"Unknown hook: {hook_name}. Available hooks: {list(PROXY_HOOKS.keys())}"
        )
    return PROXY_HOOKS[hook_name]
@@ -0,0 +1,156 @@
import traceback
from typing import Optional

from fastapi import HTTPException

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth


class _PROXY_AzureContentSafety(
    CustomLogger
):  # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
    # Class variables or attributes

    def __init__(self, endpoint, api_key, thresholds=None):
        try:
            from azure.ai.contentsafety.aio import ContentSafetyClient
            from azure.ai.contentsafety.models import (
                AnalyzeTextOptions,
                AnalyzeTextOutputType,
                TextCategory,
            )
            from azure.core.credentials import AzureKeyCredential
            from azure.core.exceptions import HttpResponseError
        except Exception as e:
            raise Exception(
                f"\033[91mAzure Content-Safety not installed, try running 'pip install azure-ai-contentsafety' to fix this error: {e}\n{traceback.format_exc()}\033[0m"
            )
        self.endpoint = endpoint
        self.api_key = api_key
        self.text_category = TextCategory
        self.analyze_text_options = AnalyzeTextOptions
        self.analyze_text_output_type = AnalyzeTextOutputType
        self.azure_http_error = HttpResponseError

        self.thresholds = self._configure_thresholds(thresholds)

        self.client = ContentSafetyClient(
            self.endpoint, AzureKeyCredential(self.api_key)
        )

    def _configure_thresholds(self, thresholds=None):
        default_thresholds = {
            self.text_category.HATE: 4,
            self.text_category.SELF_HARM: 4,
            self.text_category.SEXUAL: 4,
            self.text_category.VIOLENCE: 4,
        }

        if thresholds is None:
            return default_thresholds

        for key, default in default_thresholds.items():
            if key not in thresholds:
                thresholds[key] = default

        return thresholds

    def _compute_result(self, response):
        result = {}

        category_severity = {
            item.category: item.severity for item in response.categories_analysis
        }
        for category in self.text_category:
            severity = category_severity.get(category)
            if severity is not None:
                result[category] = {
                    "filtered": severity >= self.thresholds[category],
                    "severity": severity,
                }

        return result

    async def test_violation(self, content: str, source: Optional[str] = None):
        verbose_proxy_logger.debug("Testing Azure Content-Safety for: %s", content)

        # Construct a request
        request = self.analyze_text_options(
            text=content,
            output_type=self.analyze_text_output_type.EIGHT_SEVERITY_LEVELS,
        )

        # Analyze text
        try:
            response = await self.client.analyze_text(request)
        except self.azure_http_error:
            verbose_proxy_logger.debug(
                "Error in Azure Content-Safety: %s", traceback.format_exc()
            )
            verbose_proxy_logger.debug(traceback.format_exc())
            raise

        result = self._compute_result(response)
        verbose_proxy_logger.debug("Azure Content-Safety Result: %s", result)

        for key, value in result.items():
            if value["filtered"]:
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": "Violated content safety policy",
                        "source": source,
                        "category": key,
                        "severity": value["severity"],
                    },
                )

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,  # "completion", "embeddings", "image_generation", "moderation"
    ):
        verbose_proxy_logger.debug("Inside Azure Content-Safety Pre-Call Hook")
        try:
            if call_type == "completion" and "messages" in data:
                for m in data["messages"]:
                    if "content" in m and isinstance(m["content"], str):
                        await self.test_violation(content=m["content"], source="input")

        except HTTPException as e:
            raise e
        except Exception as e:
            verbose_proxy_logger.error(
                "litellm.proxy.hooks.azure_content_safety.py::async_pre_call_hook(): Exception occurred - {}".format(
                    str(e)
                )
            )
            verbose_proxy_logger.debug(traceback.format_exc())

    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response,
    ):
        verbose_proxy_logger.debug("Inside Azure Content-Safety Post-Call Hook")
        if isinstance(response, litellm.ModelResponse) and isinstance(
            response.choices[0], litellm.utils.Choices
        ):
            await self.test_violation(
                content=response.choices[0].message.content or "", source="output"
            )

    # async def async_post_call_streaming_hook(
    #     self,
    #     user_api_key_dict: UserAPIKeyAuth,
    #     response: str,
    # ):
    #     verbose_proxy_logger.debug("Inside Azure Content-Safety Call-Stream Hook")
    #     await self.test_violation(content=response, source="output")
@@ -0,0 +1,456 @@
"""
Batch Rate Limiter Hook

This hook implements rate limiting for batch API requests by:
1. Reading batch input files to count requests and estimate tokens at submission
2. Validating actual usage from output files when batches complete
3. Integrating with the existing parallel request limiter infrastructure

## Integration & Calling
This hook is automatically registered and called by the proxy system.
See BATCH_RATE_LIMITER_INTEGRATION.md for complete integration details.

Quick summary:
- Add to PROXY_HOOKS in litellm/proxy/hooks/__init__.py
- Gets auto-instantiated on proxy startup via _add_proxy_hooks()
- async_pre_call_hook() fires on POST /v1/batches (batch submission)
- async_log_success_event() fires on GET /v1/batches/{id} (batch completion)
"""

from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union

from fastapi import HTTPException
from pydantic import BaseModel

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.batches.batch_utils import (
    _get_batch_job_input_file_usage,
    _get_file_content_as_dictionary,
)
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth

if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span

    from litellm.proxy.hooks.parallel_request_limiter_v3 import (
        RateLimitDescriptor as _RateLimitDescriptor,
    )
    from litellm.proxy.hooks.parallel_request_limiter_v3 import (
        RateLimitStatus as _RateLimitStatus,
    )
    from litellm.proxy.hooks.parallel_request_limiter_v3 import (
        _PROXY_MaxParallelRequestsHandler_v3 as _ParallelRequestLimiter,
    )
    from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache
    from litellm.router import Router as _Router

    Span = Union[_Span, Any]
    InternalUsageCache = _InternalUsageCache
    Router = _Router
    ParallelRequestLimiter = _ParallelRequestLimiter
    RateLimitStatus = _RateLimitStatus
    RateLimitDescriptor = _RateLimitDescriptor
else:
    Span = Any
    InternalUsageCache = Any
    Router = Any
    ParallelRequestLimiter = Any
    RateLimitStatus = Dict[str, Any]
    RateLimitDescriptor = Dict[str, Any]


class BatchFileUsage(BaseModel):
    """
    Internal model for batch file usage tracking, used for batch rate limiting
    """

    total_tokens: int
    request_count: int


class _PROXY_BatchRateLimiter(CustomLogger):
    """
    Rate limiter for batch API requests.

    Handles rate limiting at two points:
    1. Batch submission - reads input file and reserves capacity
    2. Batch completion - reads output file and adjusts for actual usage
    """

    def __init__(
        self,
        internal_usage_cache: InternalUsageCache,
        parallel_request_limiter: ParallelRequestLimiter,
    ):
        """
        Initialize the batch rate limiter.

        Note: These dependencies are automatically injected by ProxyLogging._add_proxy_hooks()
        when this hook is registered in PROXY_HOOKS. See BATCH_RATE_LIMITER_INTEGRATION.md.

        Args:
            internal_usage_cache: Cache for storing rate limit data (auto-injected)
            parallel_request_limiter: Existing rate limiter to integrate with (needs custom injection)
        """
        self.internal_usage_cache = internal_usage_cache
        self.parallel_request_limiter = parallel_request_limiter

    def _raise_rate_limit_error(
        self,
        status: "RateLimitStatus",
        descriptors: List["RateLimitDescriptor"],
        batch_usage: BatchFileUsage,
        limit_type: str,
    ) -> None:
        """Raise HTTPException for rate limit exceeded."""
        from datetime import datetime

        # Find the descriptor for this status
        descriptor_index = next(
            (
                i
                for i, d in enumerate(descriptors)
                if d.get("key") == status.get("descriptor_key")
            ),
            0,
        )
        descriptor: RateLimitDescriptor = (
            descriptors[descriptor_index]
            if descriptors
            else {"key": "", "value": "", "rate_limit": None}
        )

        now = datetime.now().timestamp()
        window_size = self.parallel_request_limiter.window_size
        reset_time = now + window_size
        reset_time_formatted = datetime.fromtimestamp(reset_time).strftime(
            "%Y-%m-%d %H:%M:%S UTC"
        )

        remaining_display = max(0, status["limit_remaining"])
        current_limit = status["current_limit"]

        if limit_type == "requests":
            detail = (
                f"Batch rate limit exceeded for {descriptor.get('key', 'unknown')}: {descriptor.get('value', 'unknown')}. "
                f"Batch contains {batch_usage.request_count} requests but only {remaining_display} requests remaining "
                f"out of {current_limit} RPM limit. "
                f"Limit resets at: {reset_time_formatted}"
            )
        else:  # tokens
            detail = (
                f"Batch rate limit exceeded for {descriptor.get('key', 'unknown')}: {descriptor.get('value', 'unknown')}. "
                f"Batch contains {batch_usage.total_tokens} tokens but only {remaining_display} tokens remaining "
                f"out of {current_limit} TPM limit. "
                f"Limit resets at: {reset_time_formatted}"
            )

        raise HTTPException(
            status_code=429,
            detail=detail,
            headers={
                "retry-after": str(window_size),
                "rate_limit_type": limit_type,
                "reset_at": reset_time_formatted,
            },
        )

    async def _check_and_increment_batch_counters(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        data: Dict,
        batch_usage: BatchFileUsage,
    ) -> None:
        """
        Check rate limits and increment counters by the batch amounts.

        Raises HTTPException if any limit would be exceeded.
        """
        from litellm.types.caching import RedisPipelineIncrementOperation

        # Create descriptors and check if batch would exceed limits
        descriptors = self.parallel_request_limiter._create_rate_limit_descriptors(
            user_api_key_dict=user_api_key_dict,
            data=data,
            rpm_limit_type=None,
            tpm_limit_type=None,
            model_has_failures=False,
        )

        # Check current usage without incrementing
        rate_limit_response = await self.parallel_request_limiter.should_rate_limit(
            descriptors=descriptors,
            parent_otel_span=user_api_key_dict.parent_otel_span,
            read_only=True,
        )

        # Verify batch won't exceed any limits
        for status in rate_limit_response["statuses"]:
            rate_limit_type = status["rate_limit_type"]
            limit_remaining = status["limit_remaining"]

            required_capacity = (
                batch_usage.request_count
                if rate_limit_type == "requests"
                else batch_usage.total_tokens
                if rate_limit_type == "tokens"
                else 0
            )

            if required_capacity > limit_remaining:
                self._raise_rate_limit_error(
                    status, descriptors, batch_usage, rate_limit_type
                )

        # Build pipeline operations for batch increments
        # Reuse the same keys that descriptors check
        pipeline_operations: List[RedisPipelineIncrementOperation] = []

        for descriptor in descriptors:
            key = descriptor["key"]
            value = descriptor["value"]
            rate_limit = descriptor.get("rate_limit")

            if rate_limit is None:
                continue

            # Add RPM increment if limit is set
            if rate_limit.get("requests_per_unit") is not None:
                rpm_key = self.parallel_request_limiter.create_rate_limit_keys(
                    key=key, value=value, rate_limit_type="requests"
                )
                pipeline_operations.append(
                    RedisPipelineIncrementOperation(
                        key=rpm_key,
                        increment_value=batch_usage.request_count,
                        ttl=self.parallel_request_limiter.window_size,
                    )
                )

            # Add TPM increment if limit is set
            if rate_limit.get("tokens_per_unit") is not None:
                tpm_key = self.parallel_request_limiter.create_rate_limit_keys(
                    key=key, value=value, rate_limit_type="tokens"
                )
                pipeline_operations.append(
                    RedisPipelineIncrementOperation(
                        key=tpm_key,
                        increment_value=batch_usage.total_tokens,
                        ttl=self.parallel_request_limiter.window_size,
                    )
                )

        # Execute increments
        if pipeline_operations:
            await self.parallel_request_limiter.async_increment_tokens_with_ttl_preservation(
                pipeline_operations=pipeline_operations,
                parent_otel_span=user_api_key_dict.parent_otel_span,
            )

    async def count_input_file_usage(
        self,
        file_id: str,
        custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
        user_api_key_dict: Optional[UserAPIKeyAuth] = None,
    ) -> BatchFileUsage:
        """
        Count number of requests and tokens in a batch input file.

        Args:
            file_id: The file ID to read
            custom_llm_provider: The custom LLM provider to use for token encoding
            user_api_key_dict: User authentication information for file access (required for managed files)

        Returns:
            BatchFileUsage with total_tokens and request_count
        """
        try:
            # Check if this is a managed file (base64 encoded unified file ID)
            from litellm.proxy.openai_files_endpoints.common_utils import (
                _is_base64_encoded_unified_file_id,
            )

            # Managed files require bypassing the HTTP endpoint (which runs access-check hooks)
            # and calling the managed files hook directly with the user's credentials.
            is_managed_file = _is_base64_encoded_unified_file_id(file_id)
            if is_managed_file and user_api_key_dict is not None:
                file_content = await self._fetch_managed_file_content(
                    file_id=file_id,
                    user_api_key_dict=user_api_key_dict,
                )
            else:
                # For non-managed files, use the standard litellm.afile_content
                file_content = await litellm.afile_content(
                    file_id=file_id,
                    custom_llm_provider=custom_llm_provider,
                    user_api_key_dict=user_api_key_dict,
                )

            file_content_as_dict = _get_file_content_as_dictionary(file_content.content)

            input_file_usage = _get_batch_job_input_file_usage(
                file_content_dictionary=file_content_as_dict,
                custom_llm_provider=custom_llm_provider,
            )
            request_count = len(file_content_as_dict)
            return BatchFileUsage(
                total_tokens=input_file_usage.total_tokens,
                request_count=request_count,
            )

        except Exception as e:
            verbose_proxy_logger.error(
                f"Error counting input file usage for {file_id}: {str(e)}"
            )
            raise

    async def _fetch_managed_file_content(
        self,
        file_id: str,
        user_api_key_dict: UserAPIKeyAuth,
    ) -> Any:
        """
        Fetch file content from managed files hook.

        This is needed for managed files because they require proper user context
        to verify file ownership and access permissions.

        Args:
            file_id: The managed file ID (base64 encoded)
            user_api_key_dict: User authentication information

        Returns:
            HttpxBinaryResponseContent with the file content
        """
        from litellm.llms.base_llm.files.transformation import BaseFileEndpoints

        # Import proxy_server dependencies at runtime to avoid circular imports
        try:
            from litellm.proxy.proxy_server import llm_router, proxy_logging_obj
        except ImportError as e:
            raise ValueError(
                f"Cannot import proxy_server dependencies: {str(e)}. "
                "Managed files require proxy_server to be initialized."
            )

        # Get the managed files hook
        if proxy_logging_obj is None:
            raise ValueError(
                "proxy_logging_obj not available. Cannot access managed files hook."
            )

        managed_files_obj = proxy_logging_obj.get_proxy_hook("managed_files")
        if managed_files_obj is None:
            raise ValueError(
                "Managed files hook not found. Cannot access managed file."
            )

        if not isinstance(managed_files_obj, BaseFileEndpoints):
            raise ValueError("Managed files hook is not a BaseFileEndpoints instance.")

        if llm_router is None:
            raise ValueError("llm_router not available. Cannot access managed files.")

        # Use the managed files hook to get file content
        # This properly handles user permissions and file ownership
        file_content = await managed_files_obj.afile_content(
            file_id=file_id,
            litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
            llm_router=llm_router,
        )

        return file_content

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: Any,
        data: Dict,
        call_type: str,
    ) -> Union[Exception, str, Dict, None]:
        """
        Pre-call hook for batch operations.

        Only handles batch creation (acreate_batch):
        - Reads input file
        - Counts tokens and requests
        - Reserves rate limit capacity via parallel_request_limiter

        Args:
            user_api_key_dict: User authentication information
            cache: Cache instance (not used directly)
            data: Request data
            call_type: Type of call being made

        Returns:
            Modified data dict or None

        Raises:
            HTTPException: 429 if rate limit would be exceeded
        """
        # Only handle batch creation
        if call_type != "acreate_batch":
            verbose_proxy_logger.debug(
                f"Batch rate limiter: Not handling batch creation rate limiting for call type: {call_type}"
            )
            return data

        verbose_proxy_logger.debug(
            "Batch rate limiter: Handling batch creation rate limiting"
        )

        try:
            # Extract input_file_id from data
            input_file_id = data.get("input_file_id")
            if not input_file_id:
                verbose_proxy_logger.debug(
                    "No input_file_id in batch request, skipping rate limiting"
                )
                return data

            # Get custom_llm_provider for token counting
            custom_llm_provider = data.get("custom_llm_provider", "openai")

            # Count tokens and requests from input file
            verbose_proxy_logger.debug(
                f"Counting tokens from batch input file: {input_file_id}"
            )
            batch_usage = await self.count_input_file_usage(
                file_id=input_file_id,
                custom_llm_provider=custom_llm_provider,
                user_api_key_dict=user_api_key_dict,
            )

            verbose_proxy_logger.debug(
                f"Batch input file usage - Tokens: {batch_usage.total_tokens}, "
                f"Requests: {batch_usage.request_count}"
            )

            # Store batch usage in data for later reference
            data["_batch_token_count"] = batch_usage.total_tokens
            data["_batch_request_count"] = batch_usage.request_count

            # Directly increment counters by batch amounts (check happens atomically)
            # This will raise HTTPException if limits are exceeded
            await self._check_and_increment_batch_counters(
                user_api_key_dict=user_api_key_dict,
                data=data,
                batch_usage=batch_usage,
            )

            verbose_proxy_logger.debug(
                "Batch rate limit check passed, counters incremented"
            )
            return data

        except HTTPException:
            # Re-raise HTTP exceptions (rate limit exceeded)
            raise
        except Exception as e:
            verbose_proxy_logger.error(
                f"Error in batch rate limiting: {str(e)}", exc_info=True
            )
            # Don't block the request if rate limiting fails
            return data
@@ -0,0 +1,149 @@
# What this does?
## Gets a key's redis cache, and stores it in memory for 1 minute.
## This reduces the number of REDIS GET requests made during high-traffic by the proxy.
### [BETA] this is in Beta. And might change.

import traceback
from typing import Literal, Optional

from fastapi import HTTPException

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache, InMemoryCache, RedisCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth


class _PROXY_BatchRedisRequests(CustomLogger):
    # Class variables or attributes
    in_memory_cache: Optional[InMemoryCache] = None

    def __init__(self):
        if litellm.cache is not None:
            litellm.cache.async_get_cache = (
                self.async_get_cache
            )  # map the litellm 'get_cache' function to our custom function

    def print_verbose(
        self, print_statement, debug_level: Literal["INFO", "DEBUG"] = "DEBUG"
    ):
        if debug_level == "DEBUG":
            verbose_proxy_logger.debug(print_statement)
        elif debug_level == "INFO":
            verbose_proxy_logger.debug(print_statement)
        if litellm.set_verbose is True:
            print(print_statement)  # noqa

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,
    ):
        try:
            """
            Get the user key

            Check if a key starting with `litellm:<api_key>:<call_type>:` exists in-memory

            If not, get the relevant cache from redis
            """
            api_key = user_api_key_dict.api_key

            cache_key_name = f"litellm:{api_key}:{call_type}"
            self.in_memory_cache = cache.in_memory_cache

            key_value_dict = {}
            in_memory_cache_exists = False
            for key in cache.in_memory_cache.cache_dict.keys():
                if isinstance(key, str) and key.startswith(cache_key_name):
                    in_memory_cache_exists = True

            if in_memory_cache_exists is False and litellm.cache is not None:
                """
                - Check if `litellm.Cache` is redis
                - Get the relevant values
                """
                if litellm.cache.type is not None and isinstance(
                    litellm.cache.cache, RedisCache
                ):
                    # Initialize an empty list to store the keys
                    keys = []
                    self.print_verbose(f"cache_key_name: {cache_key_name}")
                    # Use the SCAN iterator to fetch keys matching the pattern
                    keys = await litellm.cache.cache.async_scan_iter(
                        pattern=cache_key_name, count=100
                    )
                    # If you need the truly "last" based on time or another criteria,
                    # ensure your key naming or storage strategy allows this determination
                    # Here you would sort or filter the keys as needed based on your strategy
                    self.print_verbose(f"redis keys: {keys}")
                    if len(keys) > 0:
                        key_value_dict = (
                            await litellm.cache.cache.async_batch_get_cache(
                                key_list=keys
                            )
                        )

                    ## Add to cache
                    if len(key_value_dict.items()) > 0:
                        await cache.in_memory_cache.async_set_cache_pipeline(
                            cache_list=list(key_value_dict.items()), ttl=60
                        )
                    ## Set cache namespace if it's a miss
                    data["metadata"]["redis_namespace"] = cache_key_name
        except HTTPException as e:
            raise e
        except Exception as e:
            verbose_proxy_logger.error(
                "litellm.proxy.hooks.batch_redis_get.py::async_pre_call_hook(): Exception occurred - {}".format(
                    str(e)
                )
            )
            verbose_proxy_logger.debug(traceback.format_exc())

    async def async_get_cache(self, *args, **kwargs):
        """
        - Check if the cache key is in-memory

        - Else:
            - add missing cache key from REDIS
            - update in-memory cache
            - return redis cache request
        """
||||
try: # never block execution
|
||||
cache_key: Optional[str] = None
|
||||
if "cache_key" in kwargs:
|
||||
cache_key = kwargs["cache_key"]
|
||||
elif litellm.cache is not None:
|
||||
cache_key = litellm.cache.get_cache_key(
|
||||
*args, **kwargs
|
||||
) # returns "<cache_key_name>:<hash>" - we pass redis_namespace in async_pre_call_hook. Done to avoid rewriting the async_set_cache logic
|
||||
|
||||
if (
|
||||
cache_key is not None
|
||||
and self.in_memory_cache is not None
|
||||
and litellm.cache is not None
|
||||
):
|
||||
cache_control_args = kwargs.get("cache", {})
|
||||
max_age = cache_control_args.get(
|
||||
"s-max-age", cache_control_args.get("s-maxage", float("inf"))
|
||||
)
|
||||
cached_result = self.in_memory_cache.get_cache(
|
||||
cache_key, *args, **kwargs
|
||||
)
|
||||
if cached_result is None:
|
||||
cached_result = await litellm.cache.cache.async_get_cache(
|
||||
cache_key, *args, **kwargs
|
||||
)
|
||||
if cached_result is not None:
|
||||
await self.in_memory_cache.async_set_cache(
|
||||
cache_key, cached_result, ttl=60
|
||||
)
|
||||
return litellm.cache._get_cache_logic(
|
||||
cached_result=cached_result, max_age=max_age
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
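The hook above layers a 60-second in-memory cache in front of Redis so repeated lookups during a traffic burst hit local memory instead of the network. As a hypothetical standalone sketch of that local layer (the `TTLCache` class and its method names are illustrative, not part of litellm):

```python
import time


class TTLCache:
    """Minimal in-memory cache with per-entry expiry, mirroring the 1-minute local layer."""

    def __init__(self, ttl: float = 60.0):
        self.ttl = ttl
        self._store: dict = {}

    def set(self, key, value):
        # Record the value together with its absolute expiry time
        self._store[key] = (value, time.monotonic() + self.ttl)

    def get(self, key):
        item = self._store.get(key)
        if item is None:
            return None
        value, expires = item
        if time.monotonic() > expires:
            # Expired: evict and report a miss so the caller falls back to Redis
            del self._store[key]
            return None
        return value
```

A miss from `get` is what triggers the Redis round-trip in the real hook; the batch `SCAN` + `MGET` then repopulates this local layer in one shot.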
@@ -0,0 +1,58 @@
# What this does?
## Checks if a key is allowed to use the cache controls passed in to the completion() call


from fastapi import HTTPException

from litellm import verbose_logger
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth


class _PROXY_CacheControlCheck(CustomLogger):
    # Class variables or attributes
    def __init__(self):
        pass

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,
    ):
        try:
            verbose_proxy_logger.debug("Inside Cache Control Check Pre-Call Hook")
            allowed_cache_controls = user_api_key_dict.allowed_cache_controls

            if data.get("cache", None) is None:
                return

            cache_args = data.get("cache", None)
            if isinstance(cache_args, dict):
                for k, v in cache_args.items():
                    if (
                        (allowed_cache_controls is not None)
                        and (isinstance(allowed_cache_controls, list))
                        and (
                            len(allowed_cache_controls) > 0
                        )  # assume an empty list is nullable - https://github.com/prisma/prisma/issues/847#issuecomment-546895663
                        and k not in allowed_cache_controls
                    ):
                        raise HTTPException(
                            status_code=403,
                            detail=f"Not allowed to set {k} as a cache control. Contact admin to change permissions.",
                        )
            else:  # invalid cache
                return

        except HTTPException as e:
            raise e
        except Exception as e:
            verbose_logger.exception(
                "litellm.proxy.hooks.cache_control_check.py::async_pre_call_hook(): Exception occurred - {}".format(
                    str(e)
                )
            )
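The permission logic above reduces to a small pure function: an empty or missing allow-list permits everything (it is treated as nullable), otherwise the first cache-control key outside the list is rejected. A simplified sketch (`first_disallowed_cache_control` is an illustrative name, not a litellm API):

```python
from typing import Optional


def first_disallowed_cache_control(
    cache_args: dict, allowed: Optional[list]
) -> Optional[str]:
    """Return the first cache-control key not in the allow-list, or None if all pass.

    A None or empty allow-list is treated as 'no restriction' (nullable semantics).
    """
    if not isinstance(allowed, list) or len(allowed) == 0:
        return None  # no restriction configured
    for k in cache_args:
        if k not in allowed:
            return k  # this key would trigger the 403 in the real hook
    return None
```

In the hook, a non-None result maps to an HTTP 403 with the offending key named in the error detail.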
@@ -0,0 +1,303 @@
# What is this?
## Allocates dynamic tpm/rpm quota for a project based on current traffic
## Tracks num active projects per minute

import asyncio
import os
from typing import List, Optional, Tuple, Union

from fastapi import HTTPException

import litellm
from litellm import ModelResponse, Router
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.router import ModelGroupInfo
from litellm.types.utils import CallTypesLiteral
from litellm.utils import get_utc_datetime

from .rate_limiter_utils import convert_priority_to_percent


class DynamicRateLimiterCache:
    """
    Thin wrapper on DualCache for this file.

    Tracks the number of active projects calling a model.
    """

    def __init__(self, cache: DualCache) -> None:
        self.cache = cache
        self.ttl = 60  # 1 min ttl

    async def async_get_cache(self, model: str) -> Optional[int]:
        dt = get_utc_datetime()
        current_minute = dt.strftime("%H-%M")
        key_name = "{}:{}".format(current_minute, model)
        _response = await self.cache.async_get_cache(key=key_name)
        response: Optional[int] = None
        if _response is not None:
            response = len(_response)
        return response

    async def async_set_cache_sadd(self, model: str, value: List):
        """
        Add value to set.

        Parameters:
        - model: str, the name of the model group
        - value: str, the team id

        Returns:
        - None

        Raises:
        - Exception, if unable to connect to cache client (if redis caching enabled)
        """
        try:
            dt = get_utc_datetime()
            current_minute = dt.strftime("%H-%M")

            key_name = "{}:{}".format(current_minute, model)
            await self.cache.async_set_cache_sadd(
                key=key_name, value=value, ttl=self.ttl
            )
        except Exception as e:
            verbose_proxy_logger.exception(
                "litellm.proxy.hooks.dynamic_rate_limiter.py::async_set_cache_sadd(): Exception occurred - {}".format(
                    str(e)
                )
            )
            raise e


class _PROXY_DynamicRateLimitHandler(CustomLogger):
    # Class variables or attributes
    def __init__(self, internal_usage_cache: DualCache):
        self.internal_usage_cache = DynamicRateLimiterCache(cache=internal_usage_cache)

    def update_variables(self, llm_router: Router):
        self.llm_router = llm_router

    async def check_available_usage(
        self, model: str, priority: Optional[str] = None
    ) -> Tuple[
        Optional[int], Optional[int], Optional[int], Optional[int], Optional[int]
    ]:
        """
        For a given model, get its available tpm/rpm.

        Params:
        - model: str, the name of the model in the router model_list
        - priority: Optional[str], the priority for the request.

        Returns:
        - Tuple[available_tpm, available_rpm, remaining_model_tpm, remaining_model_rpm, active_projects]
            - available_tpm: int or null - always 0 or positive.
            - available_rpm: int or null - always 0 or positive.
            - remaining_model_tpm: int or null. If available tpm is int, then this will be too.
            - remaining_model_rpm: int or null. If available rpm is int, then this will be too.
            - active_projects: int or null
        """
        try:
            # Get model info first for conversion
            model_group_info: Optional[
                ModelGroupInfo
            ] = self.llm_router.get_model_group_info(model_group=model)

            weight: float = 1
            if (
                litellm.priority_reservation is None
                or priority not in litellm.priority_reservation
            ):
                verbose_proxy_logger.error(
                    "Priority Reservation not set. priority={}, but litellm.priority_reservation is {}.".format(
                        priority, litellm.priority_reservation
                    )
                )
            elif priority is not None and litellm.priority_reservation is not None:
                if os.getenv("LITELLM_LICENSE", None) is None:
                    verbose_proxy_logger.error(
                        "PREMIUM FEATURE: Reserving tpm/rpm by priority is a premium feature. Please add a 'LITELLM_LICENSE' to your .env to enable this.\nGet a license: https://docs.litellm.ai/docs/proxy/enterprise."
                    )
                else:
                    value = litellm.priority_reservation[priority]
                    weight = convert_priority_to_percent(value, model_group_info)

            active_projects = await self.internal_usage_cache.async_get_cache(
                model=model
            )
            (
                current_model_tpm,
                current_model_rpm,
            ) = await self.llm_router.get_model_group_usage(model_group=model)
            total_model_tpm: Optional[int] = None
            total_model_rpm: Optional[int] = None
            if model_group_info is not None:
                if model_group_info.tpm is not None:
                    total_model_tpm = model_group_info.tpm
                if model_group_info.rpm is not None:
                    total_model_rpm = model_group_info.rpm

            remaining_model_tpm: Optional[int] = None
            if total_model_tpm is not None and current_model_tpm is not None:
                remaining_model_tpm = total_model_tpm - current_model_tpm
            elif total_model_tpm is not None:
                remaining_model_tpm = total_model_tpm

            remaining_model_rpm: Optional[int] = None
            if total_model_rpm is not None and current_model_rpm is not None:
                remaining_model_rpm = total_model_rpm - current_model_rpm
            elif total_model_rpm is not None:
                remaining_model_rpm = total_model_rpm

            available_tpm: Optional[int] = None

            if remaining_model_tpm is not None:
                if active_projects is not None:
                    available_tpm = int(remaining_model_tpm * weight / active_projects)
                else:
                    available_tpm = int(remaining_model_tpm * weight)

            if available_tpm is not None and available_tpm < 0:
                available_tpm = 0

            available_rpm: Optional[int] = None

            if remaining_model_rpm is not None:
                if active_projects is not None:
                    available_rpm = int(remaining_model_rpm * weight / active_projects)
                else:
                    available_rpm = int(remaining_model_rpm * weight)

            if available_rpm is not None and available_rpm < 0:
                available_rpm = 0
            return (
                available_tpm,
                available_rpm,
                remaining_model_tpm,
                remaining_model_rpm,
                active_projects,
            )
        except Exception as e:
            verbose_proxy_logger.exception(
                "litellm.proxy.hooks.dynamic_rate_limiter.py::check_available_usage: Exception occurred - {}".format(
                    str(e)
                )
            )
            return None, None, None, None, None

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: CallTypesLiteral,
    ) -> Optional[
        Union[Exception, str, dict]
    ]:  # raise an exception if invalid; return a str for the user to receive if rejected; or return a modified dictionary for passing into litellm
        """
        - For a model group
        - Check if tpm/rpm available
        - Raise RateLimitError if no tpm/rpm available
        """
        if "model" in data:
            key_priority: Optional[str] = user_api_key_dict.metadata.get(
                "priority", None
            )
            (
                available_tpm,
                available_rpm,
                model_tpm,
                model_rpm,
                active_projects,
            ) = await self.check_available_usage(
                model=data["model"], priority=key_priority
            )
            ### CHECK TPM ###
            if available_tpm is not None and available_tpm == 0:
                raise HTTPException(
                    status_code=429,
                    detail={
                        "error": "Key={} over available TPM={}. Model TPM={}, Active keys={}".format(
                            user_api_key_dict.api_key,
                            available_tpm,
                            model_tpm,
                            active_projects,
                        )
                    },
                )
            ### CHECK RPM ###
            elif available_rpm is not None and available_rpm == 0:
                raise HTTPException(
                    status_code=429,
                    detail={
                        "error": "Key={} over available RPM={}. Model RPM={}, Active keys={}".format(
                            user_api_key_dict.api_key,
                            available_rpm,
                            model_rpm,
                            active_projects,
                        )
                    },
                )
            elif available_rpm is not None or available_tpm is not None:
                ## UPDATE CACHE WITH ACTIVE PROJECT
                asyncio.create_task(
                    self.internal_usage_cache.async_set_cache_sadd(  # this is a set
                        model=data["model"],  # type: ignore
                        value=[user_api_key_dict.token or "default_key"],
                    )
                )
        return None

    async def async_post_call_success_hook(
        self, data: dict, user_api_key_dict: UserAPIKeyAuth, response
    ):
        try:
            if isinstance(response, ModelResponse):
                model_info = self.llm_router.get_model_info(
                    id=response._hidden_params["model_id"]
                )
                assert (
                    model_info is not None
                ), "Model info for model with id={} is None".format(
                    response._hidden_params["model_id"]
                )
                key_priority: Optional[str] = user_api_key_dict.metadata.get(
                    "priority", None
                )
                (
                    available_tpm,
                    available_rpm,
                    model_tpm,
                    model_rpm,
                    active_projects,
                ) = await self.check_available_usage(
                    model=model_info["model_name"], priority=key_priority
                )
                response._hidden_params[
                    "additional_headers"
                ] = {  # Add additional response headers - easier debugging
                    "x-litellm-model_group": model_info["model_name"],
                    "x-ratelimit-remaining-litellm-project-tokens": available_tpm,
                    "x-ratelimit-remaining-litellm-project-requests": available_rpm,
                    "x-ratelimit-remaining-model-tokens": model_tpm,
                    "x-ratelimit-remaining-model-requests": model_rpm,
                    "x-ratelimit-current-active-projects": active_projects,
                }

                return response
            return await super().async_post_call_success_hook(
                data=data,
                user_api_key_dict=user_api_key_dict,
                response=response,
            )
        except Exception as e:
            verbose_proxy_logger.exception(
                "litellm.proxy.hooks.dynamic_rate_limiter.py::async_post_call_success_hook(): Exception occurred - {}".format(
                    str(e)
                )
            )
            return response
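The core arithmetic of `check_available_usage` is: take the remaining model-wide capacity, scale it by the key's priority weight, split it evenly across currently active projects, and clamp at zero. Isolated as a hypothetical helper (the function name is illustrative):

```python
from typing import Optional


def available_quota(
    remaining: int, weight: float, active_projects: Optional[int]
) -> int:
    """Share remaining model capacity across active projects, scaled by priority weight.

    Mirrors: int(remaining * weight / active_projects), clamped to >= 0.
    With no recorded active projects, the full weighted remainder is available.
    """
    if active_projects:
        available = int(remaining * weight / active_projects)
    else:
        available = int(remaining * weight)
    return max(available, 0)
```

For example, 1000 TPM remaining, a 0.5 priority weight, and 5 active projects yields 100 TPM per project; a negative remainder (usage already over the limit) clamps to 0, which is what triggers the 429 in the pre-call hook.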
@@ -0,0 +1,809 @@
"""
Dynamic rate limiter v3 - Saturation-aware priority-based rate limiting
"""

import os
from datetime import datetime
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union

from fastapi import HTTPException

import litellm
from litellm import ModelResponse, Router
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.hooks.parallel_request_limiter_v3 import (
    RateLimitDescriptor,
    RateLimitDescriptorRateLimitObject,
    _PROXY_MaxParallelRequestsHandler_v3,
)
from litellm.proxy.hooks.rate_limiter_utils import convert_priority_to_percent
from litellm.proxy.utils import InternalUsageCache
from litellm.types.router import ModelGroupInfo
from litellm.types.utils import CallTypesLiteral

if TYPE_CHECKING:
    from litellm.types.utils import PriorityReservationSettings


def _get_priority_settings() -> "PriorityReservationSettings":
    """
    Get the priority reservation settings, guaranteed to be non-None.

    The settings are lazy-loaded in litellm.__init__ and always return an instance.
    This helper provides proper type narrowing for mypy.
    """
    settings = litellm.priority_reservation_settings
    if settings is None:
        # This should never happen due to lazy loading, but satisfy mypy
        from litellm.types.utils import PriorityReservationSettings

        return PriorityReservationSettings()
    return settings


class _PROXY_DynamicRateLimitHandlerV3(CustomLogger):
    """
    Saturation-aware priority-based rate limiter using v3 infrastructure.

    Key features:
    1. Model capacity ALWAYS enforced at 100% (prevents over-allocation)
    2. Priority usage tracked from the first request (accurate accounting)
    3. Priority limits only enforced when saturation >= threshold
    4. Three-phase checking prevents partial counter increments
    5. Reuses the v3 limiter's Redis-based tracking (multi-instance safe)

    How it works:
    - Phase 1: Read-only check of ALL limits (no increments)
    - Phase 2: Decide enforcement based on saturation
    - Phase 3: Increment counters only if the request is allowed
    - When under-saturated: priorities can borrow unused capacity (generous)
    - When saturated: strict priority-based limits enforced (fair)
    - Uses the v3 limiter's atomic Lua scripts for race-free increments
    """

    def __init__(
        self,
        internal_usage_cache: DualCache,
        time_provider: Optional[Callable[[], datetime]] = None,
    ):
        self.internal_usage_cache = InternalUsageCache(dual_cache=internal_usage_cache)
        self.v3_limiter = _PROXY_MaxParallelRequestsHandler_v3(
            self.internal_usage_cache, time_provider=time_provider
        )

    def update_variables(self, llm_router: Router):
        self.llm_router = llm_router

    def _get_saturation_check_cache_ttl(self) -> int:
        """Get the configurable TTL for the local cache when reading saturation values."""
        return _get_priority_settings().saturation_check_cache_ttl

    async def _get_saturation_value_from_cache(
        self,
        counter_key: str,
    ) -> Optional[str]:
        """
        Get a saturation value with a configurable local cache TTL.

        Uses DualCache with a configurable TTL for local cache storage.
        The TTL is configurable via litellm.priority_reservation_settings.saturation_check_cache_ttl

        Args:
            counter_key: The cache key for the saturation counter

        Returns:
            Counter value as a string, or None if not found
        """
        local_cache_ttl = self._get_saturation_check_cache_ttl()

        return await self.internal_usage_cache.async_get_cache(
            key=counter_key,
            litellm_parent_otel_span=None,
            local_only=False,
            ttl=local_cache_ttl,
        )

    def _get_priority_weight(
        self, priority: Optional[str], model_info: Optional[ModelGroupInfo] = None
    ) -> float:
        """Get the weight for a given priority from litellm.priority_reservation"""
        weight: float = _get_priority_settings().default_priority
        if (
            litellm.priority_reservation is None
            or priority not in litellm.priority_reservation
        ):
            verbose_proxy_logger.debug(
                "Priority Reservation not set for the given priority."
            )
        elif priority is not None and litellm.priority_reservation is not None:
            if os.getenv("LITELLM_LICENSE", None) is None:
                verbose_proxy_logger.error(
                    "PREMIUM FEATURE: Reserving tpm/rpm by priority is a premium feature. Please add a 'LITELLM_LICENSE' to your .env to enable this.\nGet a license: https://docs.litellm.ai/docs/proxy/enterprise."
                )
            else:
                value = litellm.priority_reservation[priority]
                weight = convert_priority_to_percent(value, model_info)
        return weight

    def _get_priority_from_user_api_key_dict(
        self, user_api_key_dict: UserAPIKeyAuth
    ) -> Optional[str]:
        """
        Get priority from user_api_key_dict.

        Checks team metadata first (takes precedence), then falls back to key metadata.

        Args:
            user_api_key_dict: User authentication info

        Returns:
            Priority string if found, None otherwise
        """
        priority: Optional[str] = None

        # Check team metadata first (takes precedence)
        if user_api_key_dict.team_metadata is not None:
            priority = user_api_key_dict.team_metadata.get("priority", None)

        # Fall back to key metadata
        if priority is None:
            priority = user_api_key_dict.metadata.get("priority", None)

        return priority

    def _normalize_priority_weights(
        self, model_info: ModelGroupInfo
    ) -> Dict[str, float]:
        """
        Normalize priority weights if they sum to > 1.0.

        Handles over-allocation: {key_a: 0.60, key_b: 0.80} -> {key_a: 0.43, key_b: 0.57}
        Converts absolute rpm/tpm values to percentages based on model capacity.
        """
        if litellm.priority_reservation is None:
            return {}

        # Convert all values to percentages first
        weights: Dict[str, float] = {}
        for k, v in litellm.priority_reservation.items():
            weights[k] = convert_priority_to_percent(v, model_info)

        total_weight = sum(weights.values())

        if total_weight > 1.0:
            normalized = {k: v / total_weight for k, v in weights.items()}
            verbose_proxy_logger.debug(
                f"Normalized over-allocated priorities: {weights} -> {normalized}"
            )
            return normalized

        return weights
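The normalization step above only rescales when the configured weights over-commit the model (sum > 1.0); under-committed configurations are left alone so unused capacity stays borrowable. A minimal standalone sketch of just that rescaling (helper name is illustrative):

```python
from typing import Dict


def normalize_weights(weights: Dict[str, float]) -> Dict[str, float]:
    """Rescale priority weights so they sum to at most 1.0.

    Over-allocated weights are divided by their total; weights that already
    sum to <= 1.0 are returned unchanged (spare capacity remains borrowable).
    """
    total = sum(weights.values())
    if total > 1.0:
        return {k: v / total for k, v in weights.items()}
    return dict(weights)
```

With the docstring's example, `{key_a: 0.60, key_b: 0.80}` (total 1.4) rescales to roughly `{key_a: 0.43, key_b: 0.57}`, preserving the 3:4 ratio between the two priorities.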
def _get_priority_allocation(
|
||||
self,
|
||||
model: str,
|
||||
priority: Optional[str],
|
||||
normalized_weights: Dict[str, float],
|
||||
model_info: Optional[ModelGroupInfo] = None,
|
||||
) -> tuple[float, str]:
|
||||
"""
|
||||
Get priority weight and pool key for a given priority.
|
||||
|
||||
For explicit priorities: returns specific allocation and unique pool key
|
||||
For default priority: returns default allocation and shared pool key
|
||||
|
||||
Args:
|
||||
model: Model name
|
||||
priority: Priority level (None for default)
|
||||
normalized_weights: Pre-computed normalized weights
|
||||
model_info: Model configuration (optional, for fallback conversion)
|
||||
|
||||
Returns:
|
||||
tuple: (priority_weight, priority_key)
|
||||
"""
|
||||
# Check if this key has an explicit priority in litellm.priority_reservation
|
||||
has_explicit_priority = (
|
||||
priority is not None
|
||||
and litellm.priority_reservation is not None
|
||||
and priority in litellm.priority_reservation
|
||||
)
|
||||
|
||||
if has_explicit_priority and priority is not None:
|
||||
# Explicit priority: get its specific allocation
|
||||
priority_weight = normalized_weights.get(
|
||||
priority, self._get_priority_weight(priority, model_info)
|
||||
)
|
||||
# Use unique key per priority level
|
||||
priority_key = f"{model}:{priority}"
|
||||
else:
|
||||
# No explicit priority: share the default_priority pool with ALL other default keys
|
||||
priority_weight = _get_priority_settings().default_priority
|
||||
# Use shared key for all default-priority requests
|
||||
priority_key = f"{model}:default_pool"
|
||||
|
||||
return priority_weight, priority_key
|
||||
|
||||
async def _check_model_saturation(
|
||||
self,
|
||||
model: str,
|
||||
model_group_info: ModelGroupInfo,
|
||||
) -> float:
|
||||
"""
|
||||
Check current saturation by directly querying v3 limiter's cache keys.
|
||||
|
||||
Reuses v3 limiter's Redis-based tracking (works across multiple instances).
|
||||
Reads counters WITHOUT incrementing them.
|
||||
|
||||
Returns:
|
||||
float: Saturation ratio (0.0 = empty, 1.0 = at capacity, >1.0 = over)
|
||||
"""
|
||||
try:
|
||||
max_saturation = 0.0
|
||||
|
||||
# Query RPM saturation - always read from Redis for multi-node consistency
|
||||
if model_group_info.rpm is not None and model_group_info.rpm > 0:
|
||||
# Use v3 limiter's key format: {key:value}:rate_limit_type
|
||||
counter_key = self.v3_limiter.create_rate_limit_keys(
|
||||
key="model_saturation_check",
|
||||
value=model,
|
||||
rate_limit_type="requests",
|
||||
)
|
||||
|
||||
# Query Redis directly for current counter value (skip local cache for consistency)
|
||||
counter_value = await self._get_saturation_value_from_cache(
|
||||
counter_key=counter_key
|
||||
)
|
||||
|
||||
if counter_value is not None:
|
||||
current_requests = int(counter_value)
|
||||
rpm_saturation = current_requests / model_group_info.rpm
|
||||
max_saturation = max(max_saturation, rpm_saturation)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f"Model {model} RPM: {current_requests}/{model_group_info.rpm} "
|
||||
f"({rpm_saturation:.1%})"
|
||||
)
|
||||
|
||||
# Query TPM saturation
|
||||
if model_group_info.tpm is not None and model_group_info.tpm > 0:
|
||||
counter_key = self.v3_limiter.create_rate_limit_keys(
|
||||
key="model_saturation_check",
|
||||
value=model,
|
||||
rate_limit_type="tokens",
|
||||
)
|
||||
|
||||
counter_value = await self._get_saturation_value_from_cache(
|
||||
counter_key=counter_key
|
||||
)
|
||||
|
||||
if counter_value is not None:
|
||||
current_tokens = float(counter_value)
|
||||
tpm_saturation = current_tokens / model_group_info.tpm
|
||||
max_saturation = max(max_saturation, tpm_saturation)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f"Model {model} TPM: {current_tokens}/{model_group_info.tpm} "
|
||||
f"({tpm_saturation:.1%})"
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f"Model {model} overall saturation: {max_saturation:.1%}"
|
||||
)
|
||||
|
||||
return max_saturation
|
||||
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(
|
||||
f"Error checking saturation for {model}: {str(e)}"
|
||||
)
|
||||
# Fail open: assume not saturated on error
|
||||
return 0.0
|
||||
|
||||
def _create_priority_based_descriptors(
|
||||
self,
|
||||
model: str,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
priority: Optional[str],
|
||||
) -> List[RateLimitDescriptor]:
|
||||
"""
|
||||
Create rate limit descriptors with normalized priority weights.
|
||||
|
||||
Uses normalized weights to handle over-allocation scenarios.
|
||||
|
||||
For explicit priorities: each priority gets its own pool (e.g., prod gets 75%)
|
||||
For default priority: ALL keys without explicit priority share ONE pool (e.g., all share 25%)
|
||||
"""
|
||||
descriptors: List[RateLimitDescriptor] = []
|
||||
|
||||
if litellm.priority_reservation is None:
|
||||
return descriptors
|
||||
|
||||
# Get model group info
|
||||
model_group_info: Optional[
|
||||
ModelGroupInfo
|
||||
] = self.llm_router.get_model_group_info(model_group=model)
|
||||
if model_group_info is None:
|
||||
return descriptors
|
||||
|
||||
# Get normalized priority weight and pool key
|
||||
normalized_weights = self._normalize_priority_weights(model_group_info)
|
||||
priority_weight, priority_key = self._get_priority_allocation(
|
||||
model=model,
|
||||
priority=priority,
|
||||
normalized_weights=normalized_weights,
|
||||
model_info=model_group_info,
|
||||
)
|
||||
|
||||
rate_limit_config: RateLimitDescriptorRateLimitObject = {}
|
||||
|
||||
# Apply priority weight to model limits
|
||||
if model_group_info.tpm is not None:
|
||||
reserved_tpm = int(model_group_info.tpm * priority_weight)
|
||||
rate_limit_config["tokens_per_unit"] = reserved_tpm
|
||||
|
||||
if model_group_info.rpm is not None:
|
||||
reserved_rpm = int(model_group_info.rpm * priority_weight)
|
||||
rate_limit_config["requests_per_unit"] = reserved_rpm
|
||||
|
||||
if rate_limit_config:
|
||||
rate_limit_config["window_size"] = self.v3_limiter.window_size
|
||||
|
||||
descriptors.append(
|
||||
RateLimitDescriptor(
|
||||
key="priority_model",
|
||||
value=priority_key,
|
||||
rate_limit=rate_limit_config,
|
||||
)
|
||||
)
|
||||
|
||||
return descriptors
|
||||
|
||||
def _create_model_tracking_descriptor(
|
||||
self,
|
||||
model: str,
|
||||
model_group_info: ModelGroupInfo,
|
||||
high_limit_multiplier: int = 1,
|
||||
) -> RateLimitDescriptor:
|
||||
"""
|
||||
Create a descriptor for tracking model-wide usage.
|
||||
|
||||
Args:
|
||||
model: Model name
|
||||
model_group_info: Model configuration with RPM/TPM limits
|
||||
high_limit_multiplier: Multiplier for limits (use >1 for tracking-only)
|
||||
|
||||
Returns:
|
||||
Rate limit descriptor for model-wide tracking
|
||||
"""
|
||||
return RateLimitDescriptor(
|
||||
key="model_saturation_check",
|
||||
value=model,
|
||||
rate_limit={
|
||||
"requests_per_unit": (
|
||||
model_group_info.rpm * high_limit_multiplier
|
||||
if model_group_info.rpm
|
||||
else None
|
||||
),
|
||||
"tokens_per_unit": (
|
||||
model_group_info.tpm * high_limit_multiplier
|
||||
if model_group_info.tpm
|
||||
else None
|
||||
),
|
||||
"window_size": self.v3_limiter.window_size,
|
||||
},
|
||||
)
|
||||
|
||||
async def _check_rate_limits(
|
||||
self,
|
||||
model: str,
|
||||
model_group_info: ModelGroupInfo,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
priority: Optional[str],
|
||||
saturation: float,
|
||||
data: dict,
|
||||
) -> None:
|
||||
"""
|
||||
Check rate limits using THREE-PHASE approach to prevent partial increments.
|
||||
|
||||
Phase 1: Read-only check of ALL limits (no increments)
|
||||
Phase 2: Decide which limits to enforce based on saturation
|
||||
Phase 3: Increment ALL counters atomically (model + priority)
|
||||
|
||||
This prevents the bug where:
|
||||
- Model counter increments in stage 1
|
||||
- Priority check fails in stage 2
|
||||
- Request blocked but model counter already incremented
|
||||
|
||||
Key behaviors:
|
||||
- All checks performed first (read-only)
|
||||
- Only increment counters if request will be allowed
|
||||
- Model capacity: Always enforced at 100%
|
||||
- Priority limits: Only enforced when saturated >= threshold
|
||||
- Both counters tracked from first request (accurate accounting)

        Args:
            model: Model name
            model_group_info: Model configuration
            user_api_key_dict: User authentication info
            priority: User's priority level
            saturation: Current saturation level
            data: Request data dictionary

        Raises:
            HTTPException: If any limit is exceeded
        """
        import json

        saturation_threshold = _get_priority_settings().saturation_threshold
        should_enforce_priority = saturation >= saturation_threshold

        # Build ALL descriptors upfront
        descriptors_to_check: List[RateLimitDescriptor] = []

        # Model-wide descriptor (always enforced)
        model_wide_descriptor = self._create_model_tracking_descriptor(
            model=model,
            model_group_info=model_group_info,
            high_limit_multiplier=1,
        )
        descriptors_to_check.append(model_wide_descriptor)

        # Priority descriptors (always tracked, conditionally enforced)
        priority_descriptors = self._create_priority_based_descriptors(
            model=model,
            user_api_key_dict=user_api_key_dict,
            priority=priority,
        )
        if priority_descriptors:
            descriptors_to_check.extend(priority_descriptors)

        # PHASE 1: Read-only check of ALL limits (no increments)
        check_response = await self.v3_limiter.should_rate_limit(
            descriptors=descriptors_to_check,
            parent_otel_span=user_api_key_dict.parent_otel_span,
            read_only=True,  # CRITICAL: don't increment counters yet
        )

        verbose_proxy_logger.debug(
            f"Read-only check: {json.dumps(check_response, indent=2)}"
        )

        # PHASE 2: Decide which limits to enforce
        if check_response["overall_code"] == "OVER_LIMIT":
            for status in check_response["statuses"]:
                if status["code"] == "OVER_LIMIT":
                    descriptor_key = status["descriptor_key"]

                    # Model-wide limit exceeded (ALWAYS enforced)
                    if descriptor_key == "model_saturation_check":
                        raise HTTPException(
                            status_code=429,
                            detail={
                                "error": f"Model capacity reached for {model}. "
                                f"Priority: {priority}, "
                                f"Rate limit type: {status['rate_limit_type']}, "
                                f"Remaining: {status['limit_remaining']}"
                            },
                            headers={
                                "retry-after": str(self.v3_limiter.window_size),
                                "rate_limit_type": str(status["rate_limit_type"]),
                                "x-litellm-priority": priority or "default",
                            },
                        )

                    # Priority limit exceeded (ONLY enforced when saturated)
                    elif descriptor_key == "priority_model" and should_enforce_priority:
                        verbose_proxy_logger.debug(
                            f"Enforcing priority limits for {model}, saturation: {saturation:.1%}, "
                            f"priority: {priority}"
                        )
                        raise HTTPException(
                            status_code=429,
                            detail={
                                "error": f"Priority-based rate limit exceeded. "
                                f"Priority: {priority}, "
                                f"Rate limit type: {status['rate_limit_type']}, "
                                f"Remaining: {status['limit_remaining']}, "
                                f"Model saturation: {saturation:.1%}"
                            },
                            headers={
                                "retry-after": str(self.v3_limiter.window_size),
                                "rate_limit_type": str(status["rate_limit_type"]),
                                "x-litellm-priority": priority or "default",
                                "x-litellm-saturation": f"{saturation:.2%}",
                            },
                        )

        # PHASE 3: Increment the counters separately to avoid early-exit issues.
        # The model counter must ALWAYS increment, but the priority counter might
        # be over its limit. If both were incremented together, the v3 limiter's
        # in-memory check would exit early and skip incrementing the model counter.

        # Step 3a: Increment the model-wide counter (always)
        model_increment_response = await self.v3_limiter.should_rate_limit(
            descriptors=[model_wide_descriptor],
            parent_otel_span=user_api_key_dict.parent_otel_span,
            read_only=False,
        )

        # Step 3b: Increment the priority counter (may be over limit, but we still track it)
        if priority_descriptors:
            priority_increment_response = await self.v3_limiter.should_rate_limit(
                descriptors=priority_descriptors,
                parent_otel_span=user_api_key_dict.parent_otel_span,
                read_only=False,
            )

            # Combine the responses for the post-call hook
            combined_response = {
                "overall_code": model_increment_response["overall_code"],
                "statuses": model_increment_response["statuses"]
                + priority_increment_response["statuses"],
            }
            data["litellm_proxy_rate_limit_response"] = combined_response
        else:
            data["litellm_proxy_rate_limit_response"] = model_increment_response

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: CallTypesLiteral,
    ) -> Optional[Union[Exception, str, dict]]:
        """
        Saturation-aware pre-call hook for priority-based rate limiting.

        Flow:
        1. Check the current saturation level
        2. THREE-PHASE rate limit check:
           - PHASE 1: Read-only check of ALL limits (no increments)
           - PHASE 2: Decide which limits to enforce based on saturation
           - PHASE 3: Increment ALL counters atomically if the request is allowed

        This three-phase approach ensures:
        - Model capacity is NEVER exceeded (always enforced at 100%)
        - Priority usage is tracked from the first request (accurate metrics)
        - Counters only increment when the request will be allowed (prevents phantom usage)
        - When under-saturated: priorities can borrow unused capacity (generous)
        - When saturated: fair allocation based on normalized priority weights (strict)

        Example with a 100 RPM model, 60% priority allocation, 80% threshold:
        - Saturation < 80%: the priority can use up to 100 RPM (only the model limit is enforced)
        - Saturation >= 80%: the priority is limited to 60 RPM (both limits enforced)
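
        The generous/strict switch in this example reduces to simple arithmetic
        (illustrative helper with the numbers above, not the limiter's real API):

```python
# Hypothetical sketch: the effective RPM cap for a priority, given model
# capacity, the priority's reserved fraction, and the current saturation.
def effective_rpm_cap(capacity_rpm, priority_fraction, saturation, threshold=0.8):
    if saturation < threshold:
        # Generous mode: only the model-wide capacity applies
        return capacity_rpm
    # Strict mode: the priority is held to its reserved share
    return int(capacity_rpm * priority_fraction)
```

        With the numbers above, 75% saturation yields a 100 RPM cap and 85%
        saturation yields a 60 RPM cap.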

        Prevents bugs where:
        - The model counter increments but the priority check fails → model over capacity
        - The priority counter increments but is not enforced → inaccurate metrics

        Args:
            user_api_key_dict: User authentication and metadata
            cache: Dual cache instance
            data: Request data containing the model name
            call_type: Type of API call being made

        Returns:
            None if the request is allowed; otherwise raises HTTPException
        """
        if "model" not in data:
            return None

        model = data["model"]
        priority = self._get_priority_from_user_api_key_dict(
            user_api_key_dict=user_api_key_dict
        )

        # Get model configuration
        model_group_info: Optional[
            ModelGroupInfo
        ] = self.llm_router.get_model_group_info(model_group=model)
        if model_group_info is None:
            verbose_proxy_logger.debug(
                f"No model group info for {model}, allowing request"
            )
            return None

        try:
            # STEP 1: Check the current saturation level
            saturation = await self._check_model_saturation(model, model_group_info)

            saturation_threshold = _get_priority_settings().saturation_threshold

            verbose_proxy_logger.debug(
                f"[Dynamic Rate Limiter] Model={model}, Saturation={saturation:.1%}, "
                f"Threshold={saturation_threshold:.1%}, Priority={priority}"
            )

            # STEP 2: Check rate limits in THREE phases
            # Phase 1: Read-only check of ALL limits (no increments)
            # Phase 2: Decide which limits to enforce (based on saturation)
            # Phase 3: Increment ALL counters only if the request will be allowed
            # This prevents partial increments and ensures accurate tracking
            await self._check_rate_limits(
                model=model,
                model_group_info=model_group_info,
                user_api_key_dict=user_api_key_dict,
                priority=priority,
                saturation=saturation,
                data=data,
            )

        except HTTPException:
            raise
        except Exception as e:
            verbose_proxy_logger.error(
                f"Error in dynamic rate limiter: {str(e)}, allowing request"
            )
            # Fail open on unexpected errors
            return None

        return None

    async def async_post_call_success_hook(
        self, data: dict, user_api_key_dict: UserAPIKeyAuth, response
    ):
        """
        Post-call hook that adds rate limit headers to the response.
        Leverages the v3 limiter's post-call hook functionality.
        """
        try:
            # Call the v3 limiter's post-call hook to add standard rate limit headers
            await self.v3_limiter.async_post_call_success_hook(
                data=data, user_api_key_dict=user_api_key_dict, response=response
            )

            # Add additional priority-specific headers
            if isinstance(response, ModelResponse):
                priority = self._get_priority_from_user_api_key_dict(
                    user_api_key_dict=user_api_key_dict
                )

                # Get existing additional headers
                additional_headers = (
                    getattr(response, "_hidden_params", {}).get(
                        "additional_headers", {}
                    )
                    or {}
                )

                # Add priority information
                additional_headers["x-litellm-priority"] = priority or "default"
                additional_headers["x-litellm-rate-limiter-version"] = "v3"

                # Update the response
                if not hasattr(response, "_hidden_params"):
                    response._hidden_params = {}
                response._hidden_params["additional_headers"] = additional_headers

            return response

        except Exception as e:
            verbose_proxy_logger.exception(
                f"Error in dynamic rate limiter v3 post-call hook: {str(e)}"
            )
            return response

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """
        Update token usage for priority-based rate limiting after successful API calls.

        Increments token counters for:
        - model_saturation_check: model-wide token tracking
        - priority_model: priority-specific token tracking
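
        The token count that gets added depends on the limiter's rate_limit_type;
        a standalone sketch (hypothetical dict-based usage, not the real Usage type):

```python
# Illustrative mapping from rate_limit_type to the token total that is added
# to both counters; unknown types count nothing.
def tokens_to_count(usage: dict, rate_limit_type: str) -> int:
    return {
        "output": usage.get("completion_tokens", 0),
        "input": usage.get("prompt_tokens", 0),
        "total": usage.get("total_tokens", 0),
    }.get(rate_limit_type, 0)
```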
        """
        from litellm.litellm_core_utils.core_helpers import (
            _get_parent_otel_span_from_kwargs,
        )
        from litellm.proxy.common_utils.callback_utils import (
            get_model_group_from_litellm_kwargs,
        )
        from litellm.types.caching import RedisPipelineIncrementOperation
        from litellm.types.utils import Usage

        try:
            verbose_proxy_logger.debug(
                "INSIDE dynamic rate limiter ASYNC SUCCESS LOGGING"
            )

            litellm_parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)

            # Get metadata from the standard_logging_object
            standard_logging_object = kwargs.get("standard_logging_object") or {}
            standard_logging_metadata = standard_logging_object.get("metadata") or {}

            # Get model and priority
            model_group = get_model_group_from_litellm_kwargs(kwargs)
            if not model_group:
                return

            # Get priority from user_api_key_auth_metadata in standard_logging_metadata.
            # This is where user_api_key_dict.metadata is stored during pre-call.
            user_api_key_auth_metadata = (
                standard_logging_metadata.get("user_api_key_auth_metadata") or {}
            )
            key_priority: Optional[str] = user_api_key_auth_metadata.get("priority")

            # Get the total tokens from the response
            total_tokens = 0
            rate_limit_type = self.v3_limiter.get_rate_limit_type()

            if isinstance(response_obj, ModelResponse):
                _usage = getattr(response_obj, "usage", None)
                if _usage and isinstance(_usage, Usage):
                    if rate_limit_type == "output":
                        total_tokens = _usage.completion_tokens
                    elif rate_limit_type == "input":
                        total_tokens = _usage.prompt_tokens
                    elif rate_limit_type == "total":
                        total_tokens = _usage.total_tokens

            if total_tokens == 0:
                return

            # Create pipeline operations for the token increments
            pipeline_operations: List[RedisPipelineIncrementOperation] = []

            # Model-wide token tracking (model_saturation_check)
            model_token_key = self.v3_limiter.create_rate_limit_keys(
                key="model_saturation_check",
                value=model_group,
                rate_limit_type="tokens",
            )
            pipeline_operations.append(
                RedisPipelineIncrementOperation(
                    key=model_token_key,
                    increment_value=total_tokens,
                    ttl=self.v3_limiter.window_size,
                )
            )

            # Priority-specific token tracking (priority_model)
            # Determine the priority key (same logic as _get_priority_allocation)
            has_explicit_priority = (
                key_priority is not None
                and litellm.priority_reservation is not None
                and key_priority in litellm.priority_reservation
            )

            if has_explicit_priority and key_priority is not None:
                priority_key = f"{model_group}:{key_priority}"
            else:
                priority_key = f"{model_group}:default_pool"

            priority_token_key = self.v3_limiter.create_rate_limit_keys(
                key="priority_model",
                value=priority_key,
                rate_limit_type="tokens",
            )
            pipeline_operations.append(
                RedisPipelineIncrementOperation(
                    key=priority_token_key,
                    increment_value=total_tokens,
                    ttl=self.v3_limiter.window_size,
                )
            )

            # Execute the token increments with TTL preservation
            if pipeline_operations:
                await self.v3_limiter.async_increment_tokens_with_ttl_preservation(
                    pipeline_operations=pipeline_operations,
                    parent_otel_span=litellm_parent_otel_span,
                )

            # Only log 'priority' if it's a known-safe value; otherwise, redact it.
            SAFE_PRIORITIES = {"low", "medium", "high", "default"}
            logged_priority = (
                key_priority if key_priority in SAFE_PRIORITIES else "REDACTED"
            )
            verbose_proxy_logger.debug(
                f"[Dynamic Rate Limiter] Incremented tokens by {total_tokens} for "
                f"model={model_group}, priority={logged_priority}"
            )

        except Exception as e:
            verbose_proxy_logger.exception(
                f"Error in dynamic rate limiter success event: {str(e)}"
            )
@@ -0,0 +1,28 @@
[
    {
        "name": "Zip code Recognizer",
        "supported_language": "en",
        "patterns": [
            {
                "name": "zip code (weak)",
                "regex": "(\\b\\d{5}(?:\\-\\d{4})?\\b)",
                "score": 0.01
            }
        ],
        "context": ["zip", "code"],
        "supported_entity": "ZIP"
    },
    {
        "name": "Swiss AHV Number Recognizer",
        "supported_language": "en",
        "patterns": [
            {
                "name": "AHV number (strong)",
                "regex": "(756\\.\\d{4}\\.\\d{4}\\.\\d{2})|(756\\d{10})",
                "score": 0.95
            }
        ],
        "context": ["AHV", "social security", "Swiss"],
        "supported_entity": "AHV_NUMBER"
    }
]
@@ -0,0 +1,607 @@
import asyncio
import json
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

import litellm
from litellm._logging import verbose_proxy_logger
from litellm._uuid import uuid
from litellm.proxy._types import (
    GenerateKeyRequest,
    GenerateKeyResponse,
    KeyRequest,
    LiteLLM_AuditLogs,
    Litellm_EntityType,
    LiteLLM_VerificationToken,
    LitellmTableNames,
    RegenerateKeyRequest,
    UpdateKeyRequest,
    UserAPIKeyAuth,
)

# NOTE: This is the prefix for all virtual keys stored in AWS Secrets Manager
LITELLM_PREFIX_STORED_VIRTUAL_KEYS = "litellm/"


class KeyManagementEventHooks:
    @staticmethod
    async def async_key_generated_hook(
        data: GenerateKeyRequest,
        response: GenerateKeyResponse,
        user_api_key_dict: UserAPIKeyAuth,
        litellm_changed_by: Optional[str] = None,
    ):
        """
        Hook that runs after a successful /key/generate request.

        Handles the following:
        - Sending an email with the key details
        - Storing audit logs for key generation
        - Storing the generated key in the secret manager
        """
        from litellm.proxy.management_helpers.audit_logs import (
            create_audit_log_for_update,
        )
        from litellm.proxy.proxy_server import litellm_proxy_admin_name

        # Send email notification - non-blocking, independent operation
        if data.send_invite_email is True:
            try:
                await KeyManagementEventHooks._send_key_created_email(
                    response.model_dump(exclude_none=True)
                )
            except Exception as e:
                verbose_proxy_logger.warning(f"Failed to send key created email: {e}")

        # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
        if litellm.store_audit_logs is True:
            _updated_values = response.model_dump_json(exclude_none=True)
            asyncio.create_task(
                create_audit_log_for_update(
                    request_data=LiteLLM_AuditLogs(
                        id=str(uuid.uuid4()),
                        updated_at=datetime.now(timezone.utc),
                        changed_by=litellm_changed_by
                        or user_api_key_dict.user_id
                        or litellm_proxy_admin_name,
                        changed_by_api_key=user_api_key_dict.api_key,
                        table_name=LitellmTableNames.KEY_TABLE_NAME,
                        object_id=response.token_id or "",
                        action="created",
                        updated_values=_updated_values,
                        before_value=None,
                    )
                )
            )

        # Store the generated key in the secret manager - non-blocking, independent operation
        try:
            await KeyManagementEventHooks._store_virtual_key_in_secret_manager(
                secret_name=data.key_alias or f"virtual-key-{response.token_id}",
                secret_token=response.key,
                team_id=data.team_id,
            )
        except Exception as e:
            verbose_proxy_logger.warning(
                f"Failed to store virtual key in secret manager: {e}"
            )

    @staticmethod
    async def async_key_updated_hook(
        data: UpdateKeyRequest,
        existing_key_row: Any,
        response: Any,
        user_api_key_dict: UserAPIKeyAuth,
        litellm_changed_by: Optional[str] = None,
    ):
        """
        Post /key/update processing hook.

        Handles the following:
        - Storing audit logs for the key update
        """
        from litellm.proxy.management_helpers.audit_logs import (
            create_audit_log_for_update,
        )
        from litellm.proxy.proxy_server import litellm_proxy_admin_name

        # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
        if litellm.store_audit_logs is True:
            _updated_values = json.dumps(data.json(exclude_none=True), default=str)

            _before_value = existing_key_row.json(exclude_none=True)
            _before_value = json.dumps(_before_value, default=str)

            asyncio.create_task(
                create_audit_log_for_update(
                    request_data=LiteLLM_AuditLogs(
                        id=str(uuid.uuid4()),
                        updated_at=datetime.now(timezone.utc),
                        changed_by=litellm_changed_by
                        or user_api_key_dict.user_id
                        or litellm_proxy_admin_name,
                        changed_by_api_key=user_api_key_dict.api_key,
                        table_name=LitellmTableNames.KEY_TABLE_NAME,
                        object_id=data.key,
                        action="updated",
                        updated_values=_updated_values,
                        before_value=_before_value,
                    )
                )
            )

    @staticmethod
    async def async_key_rotated_hook(
        data: Optional[RegenerateKeyRequest],
        existing_key_row: LiteLLM_VerificationToken,
        response: GenerateKeyResponse,
        user_api_key_dict: UserAPIKeyAuth,
        litellm_changed_by: Optional[str] = None,
    ):
        from litellm.proxy.management_helpers.audit_logs import (
            create_audit_log_for_update,
        )
        from litellm.proxy.proxy_server import litellm_proxy_admin_name

        # Rotate the stored key in the secret manager - non-blocking, independent operation
        if data is not None and response.token_id is not None:
            try:
                initial_secret_name = (
                    existing_key_row.key_alias
                    or f"virtual-key-{existing_key_row.token}"
                )
                new_secret_name = (
                    response.key_alias or data.key_alias or initial_secret_name
                )
                verbose_proxy_logger.info(
                    "Updating secret in secret manager: secret_name=%s",
                    new_secret_name,
                )
                team_id = getattr(existing_key_row, "team_id", None)
                await KeyManagementEventHooks._rotate_virtual_key_in_secret_manager(
                    current_secret_name=initial_secret_name,
                    new_secret_name=new_secret_name,
                    new_secret_value=response.key,
                    team_id=team_id,
                )
                verbose_proxy_logger.info(
                    "Secret updated in secret manager: secret_name=%s",
                    new_secret_name,
                )
            except Exception as e:
                verbose_proxy_logger.warning(
                    f"Failed to rotate virtual key in secret manager: {e}"
                )

        # Send a key rotated email if configured - non-blocking, independent operation
        try:
            await KeyManagementEventHooks._send_key_rotated_email(
                response=response.model_dump(exclude_none=True),
                existing_key_alias=existing_key_row.key_alias,
            )
        except Exception as e:
            verbose_proxy_logger.warning(f"Failed to send key rotated email: {e}")

        # Store the audit log
        if litellm.store_audit_logs is True and existing_key_row.token is not None:
            asyncio.create_task(
                create_audit_log_for_update(
                    request_data=LiteLLM_AuditLogs(
                        id=str(uuid.uuid4()),
                        updated_at=datetime.now(timezone.utc),
                        changed_by=litellm_changed_by
                        or user_api_key_dict.user_id
                        or litellm_proxy_admin_name,
                        changed_by_api_key=user_api_key_dict.token,
                        table_name=LitellmTableNames.KEY_TABLE_NAME,
                        object_id=existing_key_row.token,
                        action="rotated",
                        updated_values=response.model_dump_json(exclude_none=True),
                        before_value=existing_key_row.model_dump_json(
                            exclude_none=True
                        ),
                    )
                )
            )

    @staticmethod
    async def async_key_deleted_hook(
        data: KeyRequest,
        keys_being_deleted: List[LiteLLM_VerificationToken],
        response: dict,
        user_api_key_dict: UserAPIKeyAuth,
        litellm_changed_by: Optional[str] = None,
    ):
        """
        Post /key/delete processing hook.

        Handles the following:
        - Storing audit logs for key deletion
        """
        from litellm.proxy.management_helpers.audit_logs import (
            create_audit_log_for_update,
        )
        from litellm.proxy.proxy_server import litellm_proxy_admin_name

        # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
        # This runs after the validation loop, so audit logs are only inserted
        # once validation has passed.
        if litellm.store_audit_logs is True and data.keys is not None:
            # Make an audit log for each key deleted
            for key in keys_being_deleted:
                if key.token is None:
                    continue
                _key_row = key.model_dump_json(exclude_none=True)

                asyncio.create_task(
                    create_audit_log_for_update(
                        request_data=LiteLLM_AuditLogs(
                            id=str(uuid.uuid4()),
                            updated_at=datetime.now(timezone.utc),
                            changed_by=litellm_changed_by
                            or user_api_key_dict.user_id
                            or litellm_proxy_admin_name,
                            changed_by_api_key=user_api_key_dict.token,
                            table_name=LitellmTableNames.KEY_TABLE_NAME,
                            object_id=key.token,
                            action="deleted",
                            updated_values="{}",
                            before_value=_key_row,
                        )
                    )
                )
        # Delete the keys from the secret manager
        await KeyManagementEventHooks._delete_virtual_keys_from_secret_manager(
            keys_being_deleted=keys_being_deleted
        )

    @staticmethod
    async def _store_virtual_key_in_secret_manager(
        secret_name: str, secret_token: str, team_id: Optional[str] = None
    ):
        """
        Store a virtual key in the secret manager.

        Args:
            secret_name: Name of the virtual key
            secret_token: Value of the virtual key (example: sk-1234)
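
        The stored name follows the prefixing convention implemented in
        _get_secret_name; a standalone sketch of that join (illustrative helper
        only, not part of this class):

```python
# Hypothetical re-statement of the prefix join: the configured prefix and the
# key name are combined with exactly one "/" between them.
def prefixed_secret_name(prefix: str, secret_name: str) -> str:
    if prefix.endswith("/"):
        return f"{prefix}{secret_name}"
    return f"{prefix}/{secret_name}"
```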
|
||||
"""
|
||||
if litellm._key_management_settings is not None:
|
||||
if litellm._key_management_settings.store_virtual_keys is True:
|
||||
from litellm.secret_managers.base_secret_manager import (
|
||||
BaseSecretManager,
|
||||
)
|
||||
|
||||
# store the key in the secret manager
|
||||
if isinstance(litellm.secret_manager_client, BaseSecretManager):
|
||||
tags = getattr(litellm._key_management_settings, "tags", None)
|
||||
description = getattr(
|
||||
litellm._key_management_settings, "description", None
|
||||
)
|
||||
optional_params = await KeyManagementEventHooks._get_secret_manager_optional_params(
|
||||
team_id
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
f"Creating secret with {secret_name} and tags={tags} and description={description}"
|
||||
)
|
||||
|
||||
await litellm.secret_manager_client.async_write_secret(
|
||||
secret_name=KeyManagementEventHooks._get_secret_name(
|
||||
secret_name
|
||||
),
|
||||
description=description,
|
||||
secret_value=secret_token,
|
||||
tags=tags,
|
||||
optional_params=optional_params,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _rotate_virtual_key_in_secret_manager(
|
||||
current_secret_name: str,
|
||||
new_secret_name: str,
|
||||
new_secret_value: str,
|
||||
team_id: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Update a virtual key in the secret manager
|
||||
|
||||
Args:
|
||||
current_secret_name: Current name of the virtual key
|
||||
new_secret_name: New name of the virtual key
|
||||
new_secret_value: New value of the virtual key (example: sk-1234)
|
||||
team_id: Optional team ID to get team-specific secret manager settings
|
||||
"""
|
||||
if litellm._key_management_settings is not None:
|
||||
if litellm._key_management_settings.store_virtual_keys is True:
|
||||
from litellm.secret_managers.base_secret_manager import (
|
||||
BaseSecretManager,
|
||||
)
|
||||
|
||||
# store the key in the secret manager
|
||||
if isinstance(litellm.secret_manager_client, BaseSecretManager):
|
||||
optional_params = await KeyManagementEventHooks._get_secret_manager_optional_params(
|
||||
team_id
|
||||
)
|
||||
await litellm.secret_manager_client.async_rotate_secret(
|
||||
current_secret_name=KeyManagementEventHooks._get_secret_name(
|
||||
current_secret_name
|
||||
),
|
||||
new_secret_name=KeyManagementEventHooks._get_secret_name(
|
||||
new_secret_name
|
||||
),
|
||||
new_secret_value=new_secret_value,
|
||||
optional_params=optional_params,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_secret_name(secret_name: str) -> str:
|
||||
if litellm._key_management_settings.prefix_for_stored_virtual_keys.endswith(
|
||||
"/"
|
||||
):
|
||||
return f"{litellm._key_management_settings.prefix_for_stored_virtual_keys}{secret_name}"
|
||||
else:
|
||||
return f"{litellm._key_management_settings.prefix_for_stored_virtual_keys}/{secret_name}"
|
||||
|
||||
@staticmethod
|
||||
async def _delete_virtual_keys_from_secret_manager(
|
||||
keys_being_deleted: List[LiteLLM_VerificationToken],
|
||||
):
|
||||
"""
|
||||
Deletes virtual keys from the secret manager
|
||||
|
||||
Args:
|
||||
keys_being_deleted: List of keys being deleted, this is passed down from the /key/delete operation
|
||||
"""
|
||||
if litellm._key_management_settings is not None:
|
||||
if litellm._key_management_settings.store_virtual_keys is True:
|
||||
from litellm.secret_managers.base_secret_manager import (
|
||||
BaseSecretManager,
|
||||
)
|
||||
|
||||
if isinstance(litellm.secret_manager_client, BaseSecretManager):
|
||||
team_settings_cache: Dict[Optional[str], Optional[dict]] = {}
|
||||
for key in keys_being_deleted:
|
||||
if key.key_alias is not None:
|
||||
team_id = getattr(key, "team_id", None)
|
||||
if team_id not in team_settings_cache:
|
||||
team_settings_cache[
|
||||
team_id
|
||||
] = await KeyManagementEventHooks._get_secret_manager_optional_params(
|
||||
team_id
|
||||
)
|
||||
optional_params = team_settings_cache[team_id]
|
||||
await litellm.secret_manager_client.async_delete_secret(
|
||||
secret_name=KeyManagementEventHooks._get_secret_name(
|
||||
key.key_alias
|
||||
),
|
||||
optional_params=optional_params,
|
||||
)
|
||||
else:
|
||||
verbose_proxy_logger.warning(
|
||||
f"KeyManagementEventHooks._delete_virtual_key_from_secret_manager: Key alias not found for key {key.token}. Skipping deletion from secret manager."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _get_secret_manager_optional_params(
|
||||
team_id: Optional[str],
|
||||
) -> Optional[dict]:
|
||||
if team_id is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
from litellm.proxy import proxy_server as proxy_server_module
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
prisma_client = getattr(proxy_server_module, "prisma_client", None)
|
||||
user_api_key_cache = getattr(proxy_server_module, "user_api_key_cache", None)
|
||||
|
||||
if prisma_client is None or user_api_key_cache is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
from litellm.proxy.auth.auth_checks import get_team_object
|
||||
|
||||
team_obj = await get_team_object(
|
||||
team_id=team_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
verbose_proxy_logger.debug(
|
||||
f"Unable to load team metadata for team_id={team_id}: {exc}"
|
||||
)
|
||||
return None
|
||||
|
||||
metadata = getattr(team_obj, "metadata", None)
|
||||
if metadata is None:
|
||||
return None
|
||||
|
||||
if hasattr(metadata, "model_dump"):
|
||||
metadata = metadata.model_dump()
|
||||
|
||||
if not isinstance(metadata, dict):
|
||||
return None
|
||||
|
||||
team_settings = metadata.get("secret_manager_settings")
|
||||
if isinstance(team_settings, dict) and team_settings:
|
||||
return dict(team_settings)
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _is_email_sending_enabled() -> bool:
|
||||
"""
|
||||
Check if email sending is enabled via v2 enterprise loggers or v0 alerting config.
|
||||
|
||||
Returns True only if email is actually configured, preventing any email
|
||||
processing when the user has not opted in.
|
||||
"""
|
||||
# Check v2 enterprise email loggers
|
||||
try:
|
||||
from litellm_enterprise.enterprise_callbacks.send_emails.base_email import (
|
||||
BaseEmailLogger,
|
||||
)
|
||||
|
||||
initialized_email_loggers = (
|
||||
litellm.logging_callback_manager.get_custom_loggers_for_type(
|
||||
callback_type=BaseEmailLogger
|
||||
)
|
||||
)
|
||||
if len(initialized_email_loggers) > 0:
|
||||
return True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Check v0 alerting config
|
||||
from litellm.proxy.proxy_server import general_settings
|
||||
|
||||
if "email" in general_settings.get("alerting", []):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
    @staticmethod
    async def _send_key_created_email(response: dict):
        """
        Send the key-created email if email sending is enabled.

        This method is non-blocking - it will return silently if email is not
        configured, and will log warnings instead of raising exceptions on failure.
        """
        # Early exit if email is not enabled
        if not KeyManagementEventHooks._is_email_sending_enabled():
            verbose_proxy_logger.debug(
                "Email sending not enabled, skipping key created email"
            )
            return

        from litellm.proxy.proxy_server import general_settings, proxy_logging_obj

        ##########################
        # v2 integration for emails (enterprise)
        ##########################
        try:
            from litellm_enterprise.enterprise_callbacks.send_emails.base_email import (
                BaseEmailLogger,
            )
            from litellm_enterprise.types.enterprise_callbacks.send_emails import (
                SendKeyCreatedEmailEvent,
            )

            initialized_email_loggers = (
                litellm.logging_callback_manager.get_custom_loggers_for_type(
                    callback_type=BaseEmailLogger
                )
            )
            if len(initialized_email_loggers) > 0:
                event = SendKeyCreatedEmailEvent(
                    virtual_key=response.get("key", ""),
                    event="key_created",
                    event_group=Litellm_EntityType.KEY,
                    event_message="API Key Created",
                    token=response.get("token", ""),
                    spend=response.get("spend", 0.0),
                    max_budget=response.get("max_budget", 0.0),
                    user_id=response.get("user_id", None),
                    team_id=response.get("team_id", "Default Team"),
                    key_alias=response.get("key_alias", None),
                )
                for email_logger in initialized_email_loggers:
                    if isinstance(email_logger, BaseEmailLogger):
                        await email_logger.send_key_created_email(
                            send_key_created_email_event=event,
                        )
                return
        except ImportError:
            pass

        ##########################
        # v0 integration for emails
        ##########################
        if "email" in general_settings.get("alerting", []):
            from litellm.proxy._types import WebhookEvent

            event = WebhookEvent(
                event="key_created",
                event_group=Litellm_EntityType.KEY,
                event_message="API Key Created",
                token=response.get("token", ""),
                spend=response.get("spend", 0.0),
                max_budget=response.get("max_budget", 0.0),
                user_id=response.get("user_id", None),
                team_id=response.get("team_id", "Default Team"),
                key_alias=response.get("key_alias", None),
            )
            # If the user configured email alerting, send an email letting their end user know the key was created
            asyncio.create_task(
                proxy_logging_obj.slack_alerting_instance.send_key_created_or_user_invited_email(
                    webhook_event=event,
                )
            )

    @staticmethod
    async def _send_key_rotated_email(
        response: dict, existing_key_alias: Optional[str]
    ):
        """
        Send the key-rotated email if email sending is enabled.

        This method is non-blocking - it will return silently if email is not
        configured, and will log warnings instead of raising exceptions on failure.
        """
        # Early exit if email is not enabled
        if not KeyManagementEventHooks._is_email_sending_enabled():
            verbose_proxy_logger.debug(
                "Email sending not enabled, skipping key rotated email"
            )
            return

        try:
            from litellm_enterprise.enterprise_callbacks.send_emails.base_email import (
                BaseEmailLogger,
            )
        except ImportError:
            # Enterprise package not installed - v0 doesn't support the key rotated email
            verbose_proxy_logger.debug(
                "Enterprise package not installed, skipping key rotated email"
            )
            return

        try:
            from litellm_enterprise.types.enterprise_callbacks.send_emails import (
                SendKeyRotatedEmailEvent,
            )
        except ImportError:
            verbose_proxy_logger.debug(
                "Enterprise types not available, skipping key rotated email"
            )
            return

        event = SendKeyRotatedEmailEvent(
            virtual_key=response.get("key", ""),
            event="key_rotated",
            event_group=Litellm_EntityType.KEY,
            event_message="API Key Rotated",
            token=response.get("token", ""),
            spend=response.get("spend", 0.0),
            max_budget=response.get("max_budget", 0.0),
            user_id=response.get("user_id", None),
            team_id=response.get("team_id", "Default Team"),
            key_alias=response.get("key_alias", existing_key_alias),
        )

        ##########################
        # v2 integration for emails
        ##########################
        initialized_email_loggers = (
            litellm.logging_callback_manager.get_custom_loggers_for_type(
                callback_type=BaseEmailLogger
            )
        )
        if len(initialized_email_loggers) > 0:
            for email_logger in initialized_email_loggers:
                if isinstance(email_logger, BaseEmailLogger):
                    await email_logger.send_key_rotated_email(
                        send_key_rotated_email_event=event,
                    )
@@ -0,0 +1,39 @@
"""
LiteLLM Skills Hook - Proxy integration for skills

This module provides the CustomLogger hook for skills processing.
The actual skill logic is in litellm/llms/litellm_proxy/skills/.

Usage:
    from litellm.proxy.hooks.litellm_skills import SkillsInjectionHook

    # Register the hook in the proxy
    litellm.callbacks.append(SkillsInjectionHook())
"""

# Re-export from the SDK location for convenience
from litellm.llms.litellm_proxy.skills import (
    LITELLM_CODE_EXECUTION_TOOL,
    CodeExecutionHandler,
    LiteLLMInternalTools,
    SkillPromptInjectionHandler,
    SkillsSandboxExecutor,
    code_execution_handler,
    get_litellm_code_execution_tool,
)
from litellm.proxy.hooks.litellm_skills.main import (
    SkillsInjectionHook,
    skills_injection_hook,
)

__all__ = [
    "SkillsInjectionHook",
    "skills_injection_hook",
    "CodeExecutionHandler",
    "LiteLLMInternalTools",
    "LITELLM_CODE_EXECUTION_TOOL",
    "get_litellm_code_execution_tool",
    "code_execution_handler",
    "SkillPromptInjectionHandler",
    "SkillsSandboxExecutor",
]

@@ -0,0 +1,914 @@
"""
Skills Injection Hook for LiteLLM Proxy

Main hook that orchestrates skill processing:
- Fetches skills from the LiteLLM DB
- Injects SKILL.md content into the system prompt
- Adds the litellm_code_execution tool for automatic code execution
- Handles the agentic loop internally when litellm_code_execution is called

For non-Anthropic models (e.g., Bedrock, OpenAI):
- Skills are converted to OpenAI-style tools
- Skill file content (SKILL.md) is extracted and injected into the system prompt
- The litellm_code_execution tool is added - when the model calls it, LiteLLM
  handles execution automatically and returns the final response with file_ids

Usage:
    # Simple - LiteLLM handles everything automatically via the proxy.
    # The container parameter triggers the SkillsInjectionHook.
    response = await litellm.acompletion(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "Create a bouncing ball GIF"}],
        container={"skills": [{"skill_id": "litellm:skill_abc123"}]},
    )
    # The response includes file_ids for generated files.
"""

import base64
import json
from typing import Any, Dict, List, Optional, Union

from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.litellm_proxy.skills.prompt_injection import (
    SkillPromptInjectionHandler,
)
from litellm.proxy._types import LiteLLM_SkillsTable, UserAPIKeyAuth
from litellm.types.utils import CallTypes, CallTypesLiteral


class SkillsInjectionHook(CustomLogger):
    """
    Pre/post-call hook that processes skills from the container.skills parameter.

    Pre-call (async_pre_call_hook):
    - Skills with the 'litellm:' prefix are fetched from the LiteLLM DB
    - For Anthropic models: native skills pass through, LiteLLM skills are converted to tools
    - For non-Anthropic models: LiteLLM skills are converted to tools + an execute_code tool

    Post-call (async_post_call_success_deployment_hook):
    - If the response has a litellm_code_execution tool call, automatically execute the code
    - Continue the conversation loop until the model gives a final response
    - Return the response with generated files inline

    This hook is called automatically by litellm during completion calls.
    """

    def __init__(self, **kwargs):
        from litellm.llms.litellm_proxy.skills.constants import (
            DEFAULT_MAX_ITERATIONS,
            DEFAULT_SANDBOX_TIMEOUT,
        )

        self.optional_params = kwargs
        self.prompt_handler = SkillPromptInjectionHandler()
        self.max_iterations = kwargs.get("max_iterations", DEFAULT_MAX_ITERATIONS)
        self.sandbox_timeout = kwargs.get("sandbox_timeout", DEFAULT_SANDBOX_TIMEOUT)
        super().__init__(**kwargs)

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: CallTypesLiteral,
    ) -> Optional[Union[Exception, str, dict]]:
        """
        Process skills from container.skills before the LLM call.

        1. Check if container.skills exists in the request
        2. Separate skills by prefix (litellm: vs native)
        3. Fetch LiteLLM skills from the database
        4. For Anthropic: keep native skills in the container
        5. For non-Anthropic: convert LiteLLM skills to tools, inject content, add execute_code
        """
        # Only process completion-type calls
        if call_type not in ["completion", "acompletion", "anthropic_messages"]:
            return data

        container = data.get("container")
        if not container or not isinstance(container, dict):
            return data

        skills = container.get("skills")
        if not skills or not isinstance(skills, list):
            return data

        verbose_proxy_logger.debug(
            f"SkillsInjectionHook: Processing {len(skills)} skills"
        )

        litellm_skills: List[LiteLLM_SkillsTable] = []
        anthropic_skills: List[Dict[str, Any]] = []

        # Separate skills by prefix
        for skill in skills:
            if not isinstance(skill, dict):
                continue

            skill_id = skill.get("skill_id", "")
            # Accept both the documented "litellm:" prefix and the "litellm_" variant
            if skill_id.startswith(("litellm:", "litellm_")):
                # Fetch from LiteLLM DB
                db_skill = await self._fetch_skill_from_db(skill_id)
                if db_skill:
                    litellm_skills.append(db_skill)
                else:
                    verbose_proxy_logger.warning(
                        f"SkillsInjectionHook: Skill '{skill_id}' not found in LiteLLM DB"
                    )
            else:
                # Native Anthropic skill - pass through
                anthropic_skills.append(skill)

        # Check if using the messages API spec (anthropic_messages call type).
        # The messages API always uses the Anthropic-style tool format.
        use_anthropic_format = call_type == "anthropic_messages"

        if len(litellm_skills) > 0:
            data = self._process_for_messages_api(
                data=data,
                litellm_skills=litellm_skills,
                use_anthropic_format=use_anthropic_format,
            )

        return data

    def _process_for_messages_api(
        self,
        data: dict,
        litellm_skills: List[LiteLLM_SkillsTable],
        use_anthropic_format: bool = True,
    ) -> dict:
        """
        Process skills for the messages API (Anthropic-format tools).

        - Converts skills to Anthropic-style tools (name, description, input_schema)
        - Extracts and injects SKILL.md content into the system prompt
        - Adds the litellm_code_execution tool for code execution
        - Stores skill files in metadata for sandbox execution
        """
        from litellm.llms.litellm_proxy.skills.code_execution import (
            get_litellm_code_execution_tool_anthropic,
        )

        tools = data.get("tools", [])
        skill_contents: List[str] = []
        all_skill_files: Dict[str, Dict[str, bytes]] = {}
        all_module_paths: List[str] = []

        for skill in litellm_skills:
            # Convert skill to an Anthropic-style tool
            tools.append(self.prompt_handler.convert_skill_to_anthropic_tool(skill))

            # Extract skill content from file if available
            content = self.prompt_handler.extract_skill_content(skill)
            if content:
                skill_contents.append(content)

            # Extract all files for code execution
            skill_files = self.prompt_handler.extract_all_files(skill)
            if skill_files:
                all_skill_files[skill.skill_id] = skill_files
                for path in skill_files.keys():
                    if path.endswith(".py"):
                        all_module_paths.append(path)

        if tools:
            data["tools"] = tools

        # Inject skill content into the system prompt.
        # For the Anthropic messages API, use the top-level 'system' param instead of the messages array.
        if skill_contents:
            data = self.prompt_handler.inject_skill_content_to_messages(
                data, skill_contents, use_anthropic_format=use_anthropic_format
            )

        # Add the litellm_code_execution tool if we have skill files
        if all_skill_files:
            code_exec_tool = get_litellm_code_execution_tool_anthropic()
            data["tools"] = data.get("tools", []) + [code_exec_tool]

            # Store skill files in litellm_metadata for automatic code execution
            data["litellm_metadata"] = data.get("litellm_metadata", {})
            data["litellm_metadata"]["_skill_files"] = all_skill_files
            data["litellm_metadata"]["_litellm_code_execution_enabled"] = True

        # Remove container (not supported by underlying providers)
        data.pop("container", None)

        verbose_proxy_logger.debug(
            f"SkillsInjectionHook: Messages API - converted {len(litellm_skills)} skills to Anthropic tools, "
            f"injected {len(skill_contents)} skill contents, "
            f"added litellm_code_execution tool with {len(all_module_paths)} modules"
        )

        return data

    def _process_non_anthropic_model(
        self,
        data: dict,
        litellm_skills: List[LiteLLM_SkillsTable],
    ) -> dict:
        """
        Process skills for non-Anthropic models (OpenAI-format tools).

        - Converts skills to OpenAI-style tools
        - Extracts and injects SKILL.md content
        - Adds the execute_code tool for code execution
        - Stores skill files in metadata for sandbox execution
        """
        tools = data.get("tools", [])
        skill_contents: List[str] = []
        all_skill_files: Dict[str, Dict[str, bytes]] = {}
        all_module_paths: List[str] = []

        for skill in litellm_skills:
            # Convert skill to an OpenAI-style tool
            tools.append(self.prompt_handler.convert_skill_to_tool(skill))

            # Extract skill content from file if available
            content = self.prompt_handler.extract_skill_content(skill)
            if content:
                skill_contents.append(content)

            # Extract all files for code execution
            skill_files = self.prompt_handler.extract_all_files(skill)
            if skill_files:
                all_skill_files[skill.skill_id] = skill_files
                # Collect Python module paths
                for path in skill_files.keys():
                    if path.endswith(".py"):
                        all_module_paths.append(path)

        if tools:
            data["tools"] = tools

        # Inject skill content into the system prompt
        if skill_contents:
            data = self.prompt_handler.inject_skill_content_to_messages(
                data, skill_contents
            )

        # Add the litellm_code_execution tool if we have skill files
        if all_skill_files:
            from litellm.llms.litellm_proxy.skills.code_execution import (
                get_litellm_code_execution_tool,
            )

            data["tools"] = data.get("tools", []) + [get_litellm_code_execution_tool()]

            # Store skill files in litellm_metadata for automatic code execution.
            # Using litellm_metadata instead of metadata to avoid conflicts with user metadata.
            data["litellm_metadata"] = data.get("litellm_metadata", {})
            data["litellm_metadata"]["_skill_files"] = all_skill_files
            data["litellm_metadata"]["_litellm_code_execution_enabled"] = True

        # Remove container for non-Anthropic providers (they don't support it)
        data.pop("container", None)

        verbose_proxy_logger.debug(
            f"SkillsInjectionHook: Non-Anthropic model - converted {len(litellm_skills)} skills to tools, "
            f"injected {len(skill_contents)} skill contents, "
            f"added execute_code tool with {len(all_module_paths)} modules"
        )

        return data

    async def _fetch_skill_from_db(
        self, skill_id: str
    ) -> Optional[LiteLLM_SkillsTable]:
        """
        Fetch a skill from the LiteLLM database.

        Args:
            skill_id: The skill ID as referenced in container.skills

        Returns:
            LiteLLM_SkillsTable, or None if not found
        """
        try:
            from litellm.llms.litellm_proxy.skills.handler import LiteLLMSkillsHandler

            return await LiteLLMSkillsHandler.fetch_skill_from_db(skill_id)
        except Exception as e:
            verbose_proxy_logger.warning(
                f"SkillsInjectionHook: Error fetching skill {skill_id}: {e}"
            )
            return None

    def _is_anthropic_model(self, model: str) -> bool:
        """
        Check whether the model is an Anthropic model using get_llm_provider.

        Args:
            model: The model name/identifier

        Returns:
            True if it is an Anthropic model, False otherwise
        """
        try:
            from litellm.litellm_core_utils.get_llm_provider_logic import (
                get_llm_provider,
            )

            _, custom_llm_provider, _, _ = get_llm_provider(model=model)
            return custom_llm_provider == "anthropic"
        except Exception:
            # Fall back to a simple string check if get_llm_provider fails
            return "claude" in model.lower() or model.lower().startswith("anthropic/")

    async def async_post_call_success_deployment_hook(
        self,
        request_data: dict,
        response: Any,
        call_type: Optional[CallTypes],
    ) -> Optional[Any]:
        """
        Post-call hook to handle automatic code execution.

        Handles both the OpenAI format (response.choices) and the Anthropic/messages API
        format (response["content"]).

        If the response contains a tool call (litellm_code_execution or a skill tool):
        1. Execute the code in the sandbox
        2. Add the result to messages
        3. Make another LLM call
        4. Repeat until the model gives a final response
        5. Return the modified response with generated files
        """
        from litellm.llms.litellm_proxy.skills.code_execution import (
            LiteLLMInternalTools,
        )

        # Check if code execution is enabled for this request
        litellm_metadata = request_data.get("litellm_metadata") or {}
        metadata = request_data.get("metadata") or {}

        code_exec_enabled = litellm_metadata.get(
            "_litellm_code_execution_enabled"
        ) or metadata.get("_litellm_code_execution_enabled")
        if not code_exec_enabled:
            return None

        # Get skill files
        skill_files_by_id = litellm_metadata.get("_skill_files") or metadata.get(
            "_skill_files", {}
        )
        all_skill_files: Dict[str, bytes] = {}
        for files_dict in skill_files_by_id.values():
            all_skill_files.update(files_dict)

        if not all_skill_files:
            verbose_proxy_logger.warning(
                "SkillsInjectionHook: No skill files found, cannot execute code"
            )
            return None

        # Check for tool calls - handle both Anthropic and OpenAI formats
        tool_calls = self._extract_tool_calls(response)
        if not tool_calls:
            return None

        # Check if any tool call needs execution (litellm_code_execution or a skill tool)
        has_executable_tool = False
        for tc in tool_calls:
            tool_name = tc.get("name", "")
            # Execute if it's litellm_code_execution OR a skill tool (skill_xxx)
            if (
                tool_name == LiteLLMInternalTools.CODE_EXECUTION.value
                or tool_name.startswith("skill_")
            ):
                has_executable_tool = True
                break

        if not has_executable_tool:
            return None

        verbose_proxy_logger.debug(
            "SkillsInjectionHook: Detected tool call, starting execution loop"
        )

        # Start the agentic loop
        return await self._execute_code_loop_messages_api(
            data=request_data,
            response=response,
            skill_files=all_skill_files,
        )

    def _extract_tool_calls(self, response: Any) -> List[Dict[str, Any]]:
        """Extract tool calls from the response, handling both formats."""
        tool_calls = []

        # Get content - handle both dict and object responses
        content = None
        if isinstance(response, dict):
            content = response.get("content", [])
        elif hasattr(response, "content"):
            content = response.content

        # Anthropic/messages API format: response has a "content" list with tool_use blocks
        if content:
            for block in content:
                if isinstance(block, dict) and block.get("type") == "tool_use":
                    tool_calls.append(
                        {
                            "id": block.get("id"),
                            "name": block.get("name"),
                            "input": block.get("input", {}),
                        }
                    )
                elif (
                    hasattr(block, "type")
                    and getattr(block, "type", None) == "tool_use"
                ):
                    tool_calls.append(
                        {
                            "id": getattr(block, "id", None),
                            "name": getattr(block, "name", None),
                            "input": getattr(block, "input", {}),
                        }
                    )

        # OpenAI format: response has choices[0].message.tool_calls
        if not tool_calls and hasattr(response, "choices") and response.choices:  # type: ignore[union-attr]
            msg = response.choices[0].message  # type: ignore[union-attr]
            if hasattr(msg, "tool_calls") and msg.tool_calls:
                for tc in msg.tool_calls:
                    tool_calls.append(
                        {
                            "id": tc.id,
                            "name": tc.function.name,
                            "input": json.loads(tc.function.arguments)
                            if tc.function.arguments
                            else {},
                        }
                    )

        return tool_calls

    async def _execute_code_loop_messages_api(
        self,
        data: dict,
        response: Any,
        skill_files: Dict[str, bytes],
    ) -> Any:
        """
        Execute the code execution loop for the messages API (Anthropic format).

        Returns the final response with generated files inline.
        """
        import litellm
        from litellm.llms.litellm_proxy.skills.code_execution import (
            LiteLLMInternalTools,
        )
        from litellm.llms.litellm_proxy.skills.sandbox_executor import (
            SkillsSandboxExecutor,
        )

        # Ensure the response is not None
        if response is None:
            verbose_proxy_logger.error(
                "SkillsInjectionHook: Response is None, cannot execute code loop"
            )
            return None

        model = data.get("model", "")
        messages = list(data.get("messages", []))
        tools = data.get("tools", [])
        max_tokens = data.get("max_tokens", 4096)

        executor = SkillsSandboxExecutor(timeout=self.sandbox_timeout)
        generated_files: List[Dict[str, Any]] = []
        current_response = response

        for iteration in range(self.max_iterations):
            # Extract tool calls from the current response
            tool_calls = self._extract_tool_calls(current_response)
            stop_reason = (
                current_response.get("stop_reason")
                if isinstance(current_response, dict)
                else getattr(current_response, "stop_reason", None)
            )

            # Get content for the assistant message - convert to plain dicts
            raw_content = (
                current_response.get("content", [])
                if isinstance(current_response, dict)
                else getattr(current_response, "content", [])
            )
            content_blocks = []
            for block in raw_content or []:
                if isinstance(block, dict):
                    content_blocks.append(block)
                elif hasattr(block, "model_dump"):
                    content_blocks.append(block.model_dump())
                elif hasattr(block, "__dict__"):
                    content_blocks.append(dict(block.__dict__))
                else:
                    content_blocks.append({"type": "text", "text": str(block)})

            # Build the assistant message for conversation history (Anthropic format)
            assistant_msg = {"role": "assistant", "content": content_blocks}
            messages.append(assistant_msg)

            # Check if we're done (no tool calls)
            if stop_reason != "tool_use" or not tool_calls:
                verbose_proxy_logger.debug(
                    f"SkillsInjectionHook: Loop completed after {iteration + 1} iterations, "
                    f"{len(generated_files)} files generated"
                )
                return self._attach_files_to_response(current_response, generated_files)

            # Process tool calls
            tool_results = []
            for tc in tool_calls:
                tool_name = tc.get("name", "")
                tool_id = tc.get("id", "")
                tool_input = tc.get("input", {})

                # Execute if it's litellm_code_execution OR a skill tool
                if tool_name == LiteLLMInternalTools.CODE_EXECUTION.value:
                    code = tool_input.get("code", "")
                    result = await self._execute_code(
                        code, skill_files, executor, generated_files
                    )
                elif tool_name.startswith("skill_"):
                    # Skill tool - execute the skill's code
                    result = await self._execute_skill_tool(
                        tool_name, tool_input, skill_files, executor, generated_files
                    )
                else:
                    result = f"Tool '{tool_name}' not handled"

                tool_results.append(
                    {
                        "type": "tool_result",
                        "tool_use_id": tool_id,
                        "content": result,
                    }
                )

            # Add tool results to messages (Anthropic format)
            messages.append({"role": "user", "content": tool_results})

            # Make the next LLM call
            verbose_proxy_logger.debug(
                f"SkillsInjectionHook: Making LLM call iteration {iteration + 2}"
            )
            try:
                current_response = await litellm.anthropic.acreate(
                    model=model,
                    messages=messages,
                    tools=tools,
                    max_tokens=max_tokens,
                )
                if current_response is None:
                    verbose_proxy_logger.error(
                        "SkillsInjectionHook: LLM call returned None"
                    )
                    return self._attach_files_to_response(response, generated_files)
            except Exception as e:
                verbose_proxy_logger.error(f"SkillsInjectionHook: LLM call failed: {e}")
                return self._attach_files_to_response(response, generated_files)

        verbose_proxy_logger.warning(
            f"SkillsInjectionHook: Max iterations ({self.max_iterations}) reached"
        )
        return self._attach_files_to_response(current_response, generated_files)

    async def _execute_code(
        self,
        code: str,
        skill_files: Dict[str, bytes],
        executor: Any,
        generated_files: List[Dict[str, Any]],
    ) -> str:
        """Execute code in the sandbox and return the result string."""
        try:
            verbose_proxy_logger.debug(
                f"SkillsInjectionHook: Executing code ({len(code)} chars)"
            )

            exec_result = executor.execute(code=code, skill_files=skill_files)

            result = exec_result.get("output", "") or ""

            # Collect generated files
            if exec_result.get("files"):
                for f in exec_result["files"]:
                    generated_files.append(
                        {
                            "name": f["name"],
                            "mime_type": f["mime_type"],
                            "content_base64": f["content_base64"],
                            "size": len(base64.b64decode(f["content_base64"])),
                        }
                    )
                    result += f"\n\nGenerated file: {f['name']}"

            if exec_result.get("error"):
                result += f"\n\nError: {exec_result['error']}"

            return result or "Code executed successfully"
        except Exception as e:
            return f"Code execution failed: {str(e)}"

    async def _execute_skill_tool(
        self,
        tool_name: str,
        tool_input: Dict[str, Any],
        skill_files: Dict[str, bytes],
        executor: Any,
        generated_files: List[Dict[str, Any]],
    ) -> str:
        """Execute a skill tool by generating and running code based on the skill content."""
        # Generate code based on the available skill modules.
        # Look for Python modules in the skill.
        python_modules = [
            p
            for p in skill_files.keys()
            if p.endswith(".py") and not p.endswith("__init__.py")
        ]

        # Try to find the main builder/creator module
        main_module = None
        for mod in python_modules:
            if (
                "builder" in mod.lower()
                or "creator" in mod.lower()
                or "generator" in mod.lower()
            ):
                main_module = mod
                break

        if not main_module and python_modules:
            # Use the first non-init module
            main_module = python_modules[0]

        if main_module:
            # Convert path to import: "core/gif_builder.py" -> "core.gif_builder"
            import_path = main_module.replace("/", ".").replace(".py", "")

            # Generate code that imports and uses the module
            code = f"""
# Auto-generated code to execute skill
import sys
sys.path.insert(0, '/sandbox')

from {import_path} import *

# Try to find and use a Builder/Creator class
import inspect
module = __import__('{import_path}', fromlist=[''])

for name, obj in inspect.getmembers(module):
    if inspect.isclass(obj) and name != 'object':
        try:
            instance = obj()
            # Try common methods
            if hasattr(instance, 'create'):
                result = instance.create()
            elif hasattr(instance, 'build'):
                result = instance.build()
            elif hasattr(instance, 'generate'):
                result = instance.generate()
            elif hasattr(instance, 'save'):
                instance.save('output.gif')
            print(f'Used {{name}} class')
            break
        except Exception as e:
            print(f'Error with {{name}}: {{e}}')
            continue

# List generated files
import os
for f in os.listdir('.'):
    if f.endswith(('.gif', '.png', '.jpg')):
        print(f'Generated: {{f}}')
"""
        else:
            # Fallback generic code
            code = """
print('No executable skill module found')
"""

        return await self._execute_code(code, skill_files, executor, generated_files)

||||
async def _execute_code_loop(
|
||||
self,
|
||||
data: dict,
|
||||
response: Any,
|
||||
skill_files: Dict[str, bytes],
|
||||
) -> Any:
|
||||
"""
|
||||
Execute the code execution loop until model gives final response.
|
||||
|
||||
Returns the final response with generated files inline.
|
||||
"""
|
||||
import litellm
|
||||
from litellm.llms.litellm_proxy.skills.code_execution import (
|
||||
LiteLLMInternalTools,
|
||||
)
|
||||
from litellm.llms.litellm_proxy.skills.sandbox_executor import (
|
||||
SkillsSandboxExecutor,
|
||||
)
|
||||
|
||||
model = data.get("model", "")
|
||||
messages = list(data.get("messages", []))
|
||||
tools = data.get("tools", [])
|
||||
|
||||
# Keys to exclude when passing through to acompletion
|
||||
# These are either handled explicitly or are internal LiteLLM fields
|
||||
_EXCLUDED_ACOMPLETION_KEYS = frozenset(
|
||||
{
|
||||
"messages",
|
||||
"model",
|
||||
"tools",
|
||||
"metadata",
|
||||
"litellm_metadata",
|
||||
"container",
|
||||
}
|
||||
)
|
||||
|
||||
kwargs = {k: v for k, v in data.items() if k not in _EXCLUDED_ACOMPLETION_KEYS}
|
||||
|
||||
executor = SkillsSandboxExecutor(timeout=self.sandbox_timeout)
|
||||
generated_files: List[Dict[str, Any]] = []
|
||||
current_response: Any = response
|
||||
|
||||
for iteration in range(self.max_iterations):
|
||||
# OpenAI format response has choices[0].message
|
||||
assistant_message = current_response.choices[0].message # type: ignore[union-attr]
|
||||
stop_reason = current_response.choices[0].finish_reason # type: ignore[union-attr]
|
||||
|
||||
# Build assistant message for conversation history
|
||||
assistant_msg_dict: Dict[str, Any] = {
|
||||
"role": "assistant",
|
||||
"content": assistant_message.content,
|
||||
}
|
||||
if assistant_message.tool_calls:
|
||||
assistant_msg_dict["tool_calls"] = [
|
||||
{
|
||||
"id": tc.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc.function.name,
|
||||
"arguments": tc.function.arguments,
|
||||
},
|
||||
}
|
||||
for tc in assistant_message.tool_calls
|
||||
]
|
||||
messages.append(assistant_msg_dict)
|
||||
|
||||
# Check if we're done (no tool calls)
|
||||
if stop_reason != "tool_calls" or not assistant_message.tool_calls:
|
||||
verbose_proxy_logger.debug(
|
||||
f"SkillsInjectionHook: Code execution loop completed after "
|
||||
f"{iteration + 1} iterations, {len(generated_files)} files generated"
|
||||
)
|
||||
# Attach generated files to response
|
||||
return self._attach_files_to_response(current_response, generated_files)
|
||||
|
||||
# Process tool calls
|
||||
for tool_call in assistant_message.tool_calls:
|
||||
tool_name = tool_call.function.name
|
||||
|
||||
if tool_name == LiteLLMInternalTools.CODE_EXECUTION.value:
|
||||
tool_result = await self._execute_code_tool(
|
||||
tool_call=tool_call,
|
||||
skill_files=skill_files,
|
||||
executor=executor,
|
||||
generated_files=generated_files,
|
||||
)
|
||||
else:
|
||||
# Non-code-execution tool - cannot handle
|
||||
tool_result = f"Tool '{tool_name}' not handled automatically"
|
||||
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"content": tool_result,
|
||||
}
|
||||
)
|
||||
|
||||
            # Make the next LLM call with the accumulated conversation
            verbose_proxy_logger.debug(
                f"SkillsInjectionHook: Making LLM call iteration {iteration + 2}"
            )
            current_response = await litellm.acompletion(
                model=model,
                messages=messages,
                tools=tools,
                max_tokens=kwargs.get("max_tokens", 4096),
            )

        # Max iterations reached
        verbose_proxy_logger.warning(
            f"SkillsInjectionHook: Max iterations ({self.max_iterations}) reached"
        )
        return self._attach_files_to_response(current_response, generated_files)

    async def _execute_code_tool(
        self,
        tool_call: Any,
        skill_files: Dict[str, bytes],
        executor: Any,
        generated_files: List[Dict[str, Any]],
    ) -> str:
        """Execute a litellm_code_execution tool call and return the result string."""
        try:
            args = json.loads(tool_call.function.arguments)
            code = args.get("code", "")

            verbose_proxy_logger.debug(
                f"SkillsInjectionHook: Executing code ({len(code)} chars)"
            )

            exec_result = executor.execute(
                code=code,
                skill_files=skill_files,
            )

            # Build tool result content
            tool_result = exec_result.get("output", "") or ""

            # Collect generated files
            if exec_result.get("files"):
                tool_result += "\n\nGenerated files:"
                for f in exec_result["files"]:
                    file_content = base64.b64decode(f["content_base64"])
                    generated_files.append(
                        {
                            "name": f["name"],
                            "mime_type": f["mime_type"],
                            "content_base64": f["content_base64"],
                            "size": len(file_content),
                        }
                    )
                    tool_result += f"\n- {f['name']} ({len(file_content)} bytes)"

                    verbose_proxy_logger.debug(
                        f"SkillsInjectionHook: Generated file {f['name']} "
                        f"({len(file_content)} bytes)"
                    )

            if exec_result.get("error"):
                tool_result += f"\n\nError:\n{exec_result['error']}"

            return tool_result

        except Exception as e:
            verbose_proxy_logger.error(
                f"SkillsInjectionHook: Code execution failed: {e}"
            )
            return f"Code execution failed: {str(e)}"

    def _attach_files_to_response(
        self,
        response: Any,
        generated_files: List[Dict[str, Any]],
    ) -> Any:
        """
        Attach generated files to the response object.

        Files are added to response._litellm_generated_files for easy access.
        For dict responses, files are added as a key.
        """
        if not generated_files:
            return response

        # Handle dict response (Anthropic/messages API format)
        if isinstance(response, dict):
            response["_litellm_generated_files"] = generated_files
            verbose_proxy_logger.debug(
                f"SkillsInjectionHook: Attached {len(generated_files)} files to dict response"
            )
            return response

        # Handle object response (OpenAI format)
        try:
            response._litellm_generated_files = generated_files
        except AttributeError:
            pass

        # Also add to model_extra if available (for serialization)
        if hasattr(response, "model_extra"):
            if response.model_extra is None:
                response.model_extra = {}
            response.model_extra["_litellm_generated_files"] = generated_files

        verbose_proxy_logger.debug(
            f"SkillsInjectionHook: Attached {len(generated_files)} files to response"
        )

        return response


# Global instance for registration
skills_injection_hook = SkillsInjectionHook()

import litellm

litellm.logging_callback_manager.add_litellm_callback(skills_injection_hook)
@@ -0,0 +1,49 @@
from fastapi import HTTPException

from litellm import verbose_logger
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth


class _PROXY_MaxBudgetLimiter(CustomLogger):
    # Class variables or attributes
    def __init__(self):
        pass

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,
    ):
        try:
            verbose_proxy_logger.debug("Inside Max Budget Limiter Pre-Call Hook")
            cache_key = f"{user_api_key_dict.user_id}_user_api_key_user_id"
            user_row = await cache.async_get_cache(
                cache_key, parent_otel_span=user_api_key_dict.parent_otel_span
            )
            if user_row is None:  # value not yet cached
                return
            max_budget = user_row["max_budget"]
            curr_spend = user_row["spend"]

            if max_budget is None:
                return

            if curr_spend is None:
                return

            # CHECK IF REQUEST ALLOWED
            if curr_spend >= max_budget:
                raise HTTPException(status_code=429, detail="Max budget limit reached.")
        except HTTPException as e:
            raise e
        except Exception as e:
            verbose_logger.exception(
                "litellm.proxy.hooks.max_budget_limiter.py::async_pre_call_hook(): Exception occurred - {}".format(
                    str(e)
                )
            )
@@ -0,0 +1,267 @@
"""
Per-Session Budget Limiter for LiteLLM Proxy.

Enforces a dollar-amount cap per session (identified by `session_id` /
`x-litellm-trace-id`). After each successful LLM call the response cost is
accumulated against the session. When the accumulated spend exceeds
`max_budget_per_session` (configured in agent litellm_params), subsequent
requests for that session receive a 429.

Note: trace-id enforcement (require_trace_id_on_calls_by_agent) is handled
separately in auth_checks.py at the agent level, not in this hook.

Works across multiple proxy instances via DualCache (in-memory + Redis).
Follows the same pattern as max_iterations_limiter.py.
"""

import os
from typing import TYPE_CHECKING, Any, Optional, Union

from fastapi import HTTPException

from litellm import DualCache
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth

if TYPE_CHECKING:
    from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache

    InternalUsageCache = _InternalUsageCache
else:
    InternalUsageCache = Any


# Redis Lua script for atomic float increment with TTL.
# INCRBYFLOAT returns the new value as a string.
# Only sets EXPIRE on first call (when prior value was nil).
MAX_BUDGET_SESSION_INCREMENT_SCRIPT = """
local key = KEYS[1]
local amount = ARGV[1]
local ttl = tonumber(ARGV[2])

local existed = redis.call('EXISTS', key)
local new_val = redis.call('INCRBYFLOAT', key, amount)
if existed == 0 then
    redis.call('EXPIRE', key, ttl)
end

return new_val
"""

# Default TTL for session budget counters (1 hour)
DEFAULT_MAX_BUDGET_PER_SESSION_TTL = 3600


class _PROXY_MaxBudgetPerSessionHandler(CustomLogger):
    """
    Pre-call hook that enforces max_budget_per_session.

    Configuration (set in agent litellm_params):
        - max_budget_per_session: dollar cap per session_id

    Cache key pattern:
        {session_budget:<session_id>}:spend
    """

    def __init__(self, internal_usage_cache: InternalUsageCache):
        self.internal_usage_cache = internal_usage_cache
        self.ttl = int(
            os.getenv(
                "LITELLM_MAX_BUDGET_PER_SESSION_TTL",
                DEFAULT_MAX_BUDGET_PER_SESSION_TTL,
            )
        )

        if self.internal_usage_cache.dual_cache.redis_cache is not None:
            self.increment_script = (
                self.internal_usage_cache.dual_cache.redis_cache.async_register_script(
                    MAX_BUDGET_SESSION_INCREMENT_SCRIPT
                )
            )
        else:
            self.increment_script = None

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,
    ) -> Optional[Union[Exception, str, dict]]:
        """
        Before each LLM call, check if max_budget_per_session is set and
        whether accumulated spend exceeds the budget (429 if so).
        """
        max_budget = self._get_max_budget_per_session(user_api_key_dict)

        session_id = self._get_session_id(data)

        if max_budget is None or session_id is None:
            return None

        max_budget = float(max_budget)
        cache_key = self._make_cache_key(session_id)
        current_spend = await self._get_current_spend(cache_key)

        verbose_proxy_logger.debug(
            "MaxBudgetPerSessionHandler: session_id=%s, spend=%.4f, max=%.2f",
            session_id,
            current_spend,
            max_budget,
        )

        if current_spend >= max_budget:
            raise HTTPException(
                status_code=429,
                detail=(
                    f"Session budget exceeded for session {session_id}. "
                    f"Current spend: ${current_spend:.4f}, "
                    f"max_budget_per_session: ${max_budget:.2f}."
                ),
            )

        return None

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """
        After a successful LLM call, increment the session spend by the response cost.
        """
        try:
            litellm_params = kwargs.get("litellm_params") or {}
            metadata = litellm_params.get("metadata") or {}
            session_id = metadata.get("session_id")
            if session_id is None:
                return

            agent_id = metadata.get("agent_id")
            if agent_id is None:
                return

            from litellm.proxy.agent_endpoints.agent_registry import (
                global_agent_registry,
            )

            agent = global_agent_registry.get_agent_by_id(agent_id=str(agent_id))
            if agent is None:
                return

            agent_litellm_params = agent.litellm_params or {}
            max_budget = agent_litellm_params.get("max_budget_per_session")
            if max_budget is None:
                return

            response_cost = kwargs.get("response_cost") or 0.0
            if response_cost <= 0:
                return

            cache_key = self._make_cache_key(str(session_id))
            await self._increment_spend(cache_key, float(response_cost))

            verbose_proxy_logger.debug(
                "MaxBudgetPerSessionHandler: incremented session %s spend by %.6f",
                session_id,
                response_cost,
            )
        except Exception as e:
            verbose_proxy_logger.warning(
                "MaxBudgetPerSessionHandler: error in async_log_success_event: %s",
                str(e),
            )

    def _get_session_id(self, data: dict) -> Optional[str]:
        """Extract session_id from request metadata."""
        metadata = data.get("metadata") or {}
        session_id = metadata.get("session_id")
        if session_id is not None:
            return str(session_id)

        litellm_metadata = data.get("litellm_metadata") or {}
        session_id = litellm_metadata.get("session_id")
        if session_id is not None:
            return str(session_id)

        return None

    def _get_max_budget_per_session(
        self, user_api_key_dict: UserAPIKeyAuth
    ) -> Optional[float]:
        """Extract max_budget_per_session from agent litellm_params."""
        agent_id = user_api_key_dict.agent_id
        if agent_id is None:
            return None

        from litellm.proxy.agent_endpoints.agent_registry import global_agent_registry

        agent = global_agent_registry.get_agent_by_id(agent_id=agent_id)
        if agent is None:
            return None

        litellm_params = agent.litellm_params or {}
        max_budget = litellm_params.get("max_budget_per_session")
        if max_budget is not None:
            return float(max_budget)
        return None

    def _make_cache_key(self, session_id: str) -> str:
        return f"{{session_budget:{session_id}}}:spend"

    async def _get_current_spend(self, cache_key: str) -> float:
        """Read current accumulated spend for a session."""
        if self.internal_usage_cache.dual_cache.redis_cache is not None:
            try:
                result = await self.internal_usage_cache.dual_cache.redis_cache.async_get_cache(
                    key=cache_key
                )
                if result is not None:
                    return float(result)
                return 0.0
            except Exception as e:
                verbose_proxy_logger.warning(
                    "MaxBudgetPerSessionHandler: Redis GET failed, "
                    "falling back to in-memory: %s",
                    str(e),
                )

        result = await self.internal_usage_cache.async_get_cache(
            key=cache_key,
            litellm_parent_otel_span=None,
            local_only=True,
        )
        if result is not None:
            return float(result)
        return 0.0

    async def _increment_spend(self, cache_key: str, amount: float) -> float:
        """Atomically increment the session spend and return the new value."""
        if self.increment_script is not None:
            try:
                result = await self.increment_script(
                    keys=[cache_key],
                    args=[str(amount), self.ttl],
                )
                return float(result)
            except Exception as e:
                verbose_proxy_logger.warning(
                    "MaxBudgetPerSessionHandler: Redis INCRBYFLOAT failed, "
                    "falling back to in-memory: %s",
                    str(e),
                )

        return await self._in_memory_increment_spend(cache_key, amount)

    async def _in_memory_increment_spend(self, cache_key: str, amount: float) -> float:
        current = await self.internal_usage_cache.async_get_cache(
            key=cache_key,
            litellm_parent_otel_span=None,
            local_only=True,
        )
        new_value = (float(current) if current is not None else 0.0) + amount
        await self.internal_usage_cache.async_set_cache(
            key=cache_key,
            value=new_value,
            ttl=self.ttl,
            litellm_parent_otel_span=None,
            local_only=True,
        )
        return new_value
@@ -0,0 +1,221 @@
"""
Max Iterations Limiter for LiteLLM Proxy.

Enforces a per-session cap on the number of LLM calls an agentic loop can make.
Callers send a `session_id` with each request (via `x-litellm-session-id` header
or `metadata.session_id`), and this hook counts calls per session. When the count
exceeds `max_iterations` (configured in agent litellm_params or key metadata), returns 429.

Works across multiple proxy instances via DualCache (in-memory + Redis).
Follows the same pattern as parallel_request_limiter_v3.py.
"""

import os
from typing import TYPE_CHECKING, Any, Optional, Union

from fastapi import HTTPException

from litellm import DualCache
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth

if TYPE_CHECKING:
    from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache

    InternalUsageCache = _InternalUsageCache
else:
    InternalUsageCache = Any


# Redis Lua script for atomic increment with TTL.
# Returns the new count after increment.
# Only sets EXPIRE on first increment (when count becomes 1).
MAX_ITERATIONS_INCREMENT_SCRIPT = """
local key = KEYS[1]
local ttl = tonumber(ARGV[1])

local current = redis.call('INCR', key)
if current == 1 then
    redis.call('EXPIRE', key, ttl)
end

return current
"""

# Default TTL for session iteration counters (1 hour)
DEFAULT_MAX_ITERATIONS_TTL = 3600


class _PROXY_MaxIterationsHandler(CustomLogger):
    """
    Pre-call hook that enforces max_iterations per session.

    Configuration:
        - max_iterations: set in agent litellm_params (preferred)
          e.g. litellm_params={"max_iterations": 25}
          Falls back to key metadata max_iterations for backwards compatibility.
        - session_id: sent by caller via x-litellm-session-id header or
          metadata.session_id in request body

    Cache key pattern:
        {session_iterations:<session_id>}:count

    Multi-instance support:
        Uses Redis Lua script for atomic increment (same pattern as
        parallel_request_limiter_v3). Falls back to in-memory cache
        when Redis is unavailable.
    """

    def __init__(self, internal_usage_cache: InternalUsageCache):
        self.internal_usage_cache = internal_usage_cache
        self.ttl = int(
            os.getenv("LITELLM_MAX_ITERATIONS_TTL", DEFAULT_MAX_ITERATIONS_TTL)
        )

        # Register Lua script with Redis if available (same pattern as v3 limiter)
        if self.internal_usage_cache.dual_cache.redis_cache is not None:
            self.increment_script = (
                self.internal_usage_cache.dual_cache.redis_cache.async_register_script(
                    MAX_ITERATIONS_INCREMENT_SCRIPT
                )
            )
        else:
            self.increment_script = None

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,
    ) -> Optional[Union[Exception, str, dict]]:
        """
        Check session iteration count before making the API call.

        Extracts session_id from request metadata and max_iterations from
        agent litellm_params. If the session has exceeded max_iterations, raises 429.
        """
        # Extract session_id from request data
        session_id = self._get_session_id(data)
        if session_id is None:
            return None

        max_iterations = self._get_max_iterations(user_api_key_dict)
        if max_iterations is None:
            return None

        verbose_proxy_logger.debug(
            "MaxIterationsHandler: session_id=%s, max_iterations=%s",
            session_id,
            max_iterations,
        )

        # Increment and check
        cache_key = self._make_cache_key(session_id)
        current_count = await self._increment_and_get(cache_key)

        if current_count > max_iterations:
            raise HTTPException(
                status_code=429,
                detail=(
                    f"Max iterations exceeded for session {session_id}. "
                    f"Current count: {current_count}, max_iterations: {max_iterations}."
                ),
            )

        verbose_proxy_logger.debug(
            "MaxIterationsHandler: session_id=%s, count=%s/%s",
            session_id,
            current_count,
            max_iterations,
        )

        return None

    def _get_session_id(self, data: dict) -> Optional[str]:
        """Extract session_id from request metadata."""
        metadata = data.get("metadata") or {}
        session_id = metadata.get("session_id")
        if session_id is not None:
            return str(session_id)

        # Also check litellm_metadata (used for /thread and /assistant endpoints)
        litellm_metadata = data.get("litellm_metadata") or {}
        session_id = litellm_metadata.get("session_id")
        if session_id is not None:
            return str(session_id)

        return None

    def _get_max_iterations(self, user_api_key_dict: UserAPIKeyAuth) -> Optional[int]:
        """Extract max_iterations from agent litellm_params, with fallback to key metadata."""
        # Try agent litellm_params first
        agent_id = user_api_key_dict.agent_id
        if agent_id is not None:
            from litellm.proxy.agent_endpoints.agent_registry import (
                global_agent_registry,
            )

            agent = global_agent_registry.get_agent_by_id(agent_id=agent_id)
            if agent is not None:
                litellm_params = agent.litellm_params or {}
                max_iterations = litellm_params.get("max_iterations")
                if max_iterations is not None:
                    return int(max_iterations)

        # Fallback to key metadata for backwards compatibility
        metadata = user_api_key_dict.metadata or {}
        max_iterations = metadata.get("max_iterations")
        if max_iterations is not None:
            return int(max_iterations)
        return None

    def _make_cache_key(self, session_id: str) -> str:
        """
        Create cache key for session iteration counter.

        Uses Redis hash-tag pattern {session_iterations:<session_id>} so all
        keys for a session land on the same Redis Cluster slot.
        """
        return f"{{session_iterations:{session_id}}}:count"

    async def _increment_and_get(self, cache_key: str) -> int:
        """
        Atomically increment the session counter and return the new value.

        Tries Redis first (via registered Lua script for atomicity across
        instances), falls back to in-memory cache.
        """
        if self.increment_script is not None:
            try:
                result = await self.increment_script(
                    keys=[cache_key],
                    args=[self.ttl],
                )
                return int(result)
            except Exception as e:
                verbose_proxy_logger.warning(
                    "MaxIterationsHandler: Redis failed, falling back to in-memory: %s",
                    str(e),
                )

        # Fallback: in-memory cache
        return await self._in_memory_increment(cache_key)

    async def _in_memory_increment(self, cache_key: str) -> int:
        """Increment counter in in-memory cache with TTL."""
        current = await self.internal_usage_cache.async_get_cache(
            key=cache_key,
            litellm_parent_otel_span=None,
            local_only=True,
        )
        new_value = (int(current) if current is not None else 0) + 1
        await self.internal_usage_cache.async_set_cache(
            key=cache_key,
            value=new_value,
            ttl=self.ttl,
            litellm_parent_otel_span=None,
            local_only=True,
        )
        return new_value
@@ -0,0 +1,96 @@
# MCP Semantic Tool Filter Architecture

## Why Filter MCP Tools

When multiple MCP servers are connected, the proxy may expose hundreds of tools. Sending all tools in every request wastes context window tokens and increases cost. The semantic filter keeps only the top-K most relevant tools based on embedding similarity.

```mermaid
sequenceDiagram
    participant Client
    participant Hook as SemanticToolFilterHook
    participant Filter as SemanticMCPToolFilter
    participant Router as semantic-router
    participant LLM

    Client->>Hook: POST /chat/completions
    Note over Client,Hook: tools: [100+ MCP tools]
    Note over Client,Hook: messages: [{"role": "user", "content": "Get my Jira issues"}]

    rect rgb(240, 240, 240)
        Note over Hook: 1. Extract User Query
        Hook->>Filter: filter_tools("Get my Jira issues", tools)
    end

    rect rgb(240, 240, 240)
        Note over Filter: 2. Convert Tools → Routes
        Note over Filter: Tool name + description → Route
    end

    rect rgb(240, 240, 240)
        Note over Filter: 3. Semantic Matching
        Filter->>Router: router(query)
        Router->>Router: Embeddings + similarity
        Router-->>Filter: [top 10 matches]
    end

    rect rgb(240, 240, 240)
        Note over Filter: 4. Return Filtered Tools
        Filter-->>Hook: [10 relevant tools]
    end

    Hook->>LLM: POST /chat/completions
    Note over Hook,LLM: tools: [10 Jira-related tools] ← FILTERED
    Note over Hook,LLM: messages: [...] ← UNCHANGED

    LLM-->>Client: Response (unchanged)
```

## Filter Operations

The hook intercepts requests before they reach the LLM:

| Operation | Description |
|-----------|-------------|
| **Extract query** | Get user message from `messages[-1]` |
| **Convert to Routes** | Transform MCP tools into semantic-router Routes |
| **Semantic match** | Use `semantic-router` to find top-K similar tools |
| **Filter tools** | Replace request `tools` with filtered subset |
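The four operations can be sketched end-to-end. This is a simplified stand-in that ranks tools by cosine similarity over embeddings directly, rather than calling the actual `semantic-router` package; the `embed` callable is a hypothetical placeholder for whatever embedding function is configured.

```python
import math
from typing import Callable, Dict, List


def cosine(a: List[float], b: List[float]) -> float:
    """Cosine similarity between two vectors (0.0 if either is zero)."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb) if na and nb else 0.0


def filter_tools(
    query: str,
    tools: List[Dict],
    embed: Callable[[str], List[float]],  # hypothetical embedding function
    top_k: int = 10,
    threshold: float = 0.3,
) -> List[Dict]:
    """Rank tools by similarity between the user query and each tool's
    name + description; keep the top-k tools above the threshold."""
    q_vec = embed(query)
    scored = []
    for tool in tools:
        fn = tool.get("function", {})
        text = f"{fn.get('name', '')}: {fn.get('description', '')}"
        scored.append((cosine(q_vec, embed(text)), tool))
    scored.sort(key=lambda pair: pair[0], reverse=True)
    kept = [tool for score, tool in scored[:top_k] if score >= threshold]
    return kept or tools  # no matches above threshold -> keep all tools
```

With a real encoder, `embed` would batch calls; here the point is only the rank / cut / fail-open shape of the operation.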

## Trigger Conditions

The filter only runs when:
- Call type is `completion` or `acompletion`
- Request contains `tools` field
- Request contains `messages` field
- Filter is enabled in config
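Taken together, the conditions above amount to a small predicate. This sketch assumes a plain request dict and a boolean `enabled` flag standing in for the config check:

```python
def should_filter(call_type: str, data: dict, enabled: bool) -> bool:
    """Return True only when every trigger condition is met."""
    if not enabled:
        return False  # filter disabled in config
    if call_type not in ("completion", "acompletion"):
        return False  # endpoint does not support tool filtering
    if not data.get("tools"):
        return False  # nothing to filter
    if not data.get("messages"):
        return False  # no query to match against
    return True
```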

## What Does NOT Change

- Request messages
- Response body
- Non-tool parameters

## Integration with semantic-router

Reuses existing LiteLLM infrastructure:
- `semantic-router` - Already an optional dependency
- `LiteLLMRouterEncoder` - Wraps `Router.aembedding()` for embeddings
- `SemanticRouter` - Handles similarity calculation and top-K selection

## Configuration

```yaml
litellm_settings:
  mcp_semantic_tool_filter:
    enabled: true
    embedding_model: "openai/text-embedding-3-small"
    top_k: 10
    similarity_threshold: 0.3
```

## Error Handling

The filter fails gracefully:
- If filtering fails → Return all tools (no impact on functionality)
- If query extraction fails → Skip filtering
- If no matches found → Return all tools
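The fail-open behavior can be expressed as one small wrapper: any exception during filtering, a missing query, or an empty match set all fall back to the original tool list. `run_filter` here is a hypothetical stand-in for the real filtering step.

```python
from typing import Callable, List, Optional


def filter_or_fail_open(
    query: Optional[str],
    tools: List[dict],
    run_filter: Callable[[str, List[dict]], List[dict]],  # hypothetical filter step
) -> List[dict]:
    """Apply the semantic filter, but never let it break the request."""
    if query is None:
        return tools  # query extraction failed -> skip filtering
    try:
        filtered = run_filter(query, tools)
    except Exception:
        return tools  # filtering failed -> return all tools
    return filtered or tools  # no matches found -> return all tools
```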
@@ -0,0 +1,9 @@
"""
MCP Semantic Tool Filter Hook

Semantic filtering for MCP tools to reduce context window size
and improve tool selection accuracy.
"""
from litellm.proxy.hooks.mcp_semantic_filter.hook import SemanticToolFilterHook

__all__ = ["SemanticToolFilterHook"]
@@ -0,0 +1,375 @@
"""
Semantic Tool Filter Hook

Pre-call hook that filters MCP tools semantically before LLM inference.
Reduces context window size and improves tool selection accuracy.
"""
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from litellm._logging import verbose_proxy_logger
from litellm.constants import (
    DEFAULT_MCP_SEMANTIC_FILTER_EMBEDDING_MODEL,
    DEFAULT_MCP_SEMANTIC_FILTER_SIMILARITY_THRESHOLD,
    DEFAULT_MCP_SEMANTIC_FILTER_TOP_K,
)
from litellm.integrations.custom_logger import CustomLogger

if TYPE_CHECKING:
    from litellm.caching.caching import DualCache
    from litellm.proxy._experimental.mcp_server.semantic_tool_filter import (
        SemanticMCPToolFilter,
    )
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.router import Router


class SemanticToolFilterHook(CustomLogger):
    """
    Pre-call hook that filters MCP tools semantically.

    This hook:
    1. Extracts the user query from messages
    2. Filters tools based on semantic similarity to the query
    3. Returns only the top-k most relevant tools to the LLM
    """

    def __init__(self, semantic_filter: "SemanticMCPToolFilter"):
        """
        Initialize the hook.

        Args:
            semantic_filter: SemanticMCPToolFilter instance
        """
        super().__init__()
        self.filter = semantic_filter

        verbose_proxy_logger.debug(
            f"Initialized SemanticToolFilterHook with filter: "
            f"enabled={semantic_filter.enabled}, top_k={semantic_filter.top_k}"
        )

    def _should_expand_mcp_tools(self, tools: List[Any]) -> bool:
        """
        Check if tools contain MCP references with server_url="litellm_proxy".

        Only expands MCP tools pointing to the litellm proxy, not external MCP servers.
        """
        from litellm.responses.mcp.litellm_proxy_mcp_handler import (
            LiteLLM_Proxy_MCP_Handler,
        )

        return LiteLLM_Proxy_MCP_Handler._should_use_litellm_mcp_gateway(tools)

    async def _expand_mcp_tools(
        self,
        tools: List[Any],
        user_api_key_dict: "UserAPIKeyAuth",
    ) -> List[Dict[str, Any]]:
        """
        Expand MCP references to actual tool definitions.

        Reuses LiteLLM_Proxy_MCP_Handler._process_mcp_tools_to_openai_format,
        which internally does: parse -> fetch -> filter -> deduplicate -> transform.
        """
        from litellm.responses.mcp.litellm_proxy_mcp_handler import (
            LiteLLM_Proxy_MCP_Handler,
        )

        # Parse to separate MCP tools from other tools
        mcp_tools, _ = LiteLLM_Proxy_MCP_Handler._parse_mcp_tools(tools)

        if not mcp_tools:
            return []

        # Use the single combined method instead of 3 separate calls.
        # It already handles: fetch -> filter by allowed_tools -> deduplicate -> transform
        (
            openai_tools,
            _,
        ) = await LiteLLM_Proxy_MCP_Handler._process_mcp_tools_to_openai_format(
            user_api_key_auth=user_api_key_dict, mcp_tools_with_litellm_proxy=mcp_tools
        )

        # Convert Pydantic models to dicts for compatibility
        openai_tools_as_dicts = []
        for tool in openai_tools:
            if hasattr(tool, "model_dump"):
                tool_dict = tool.model_dump(exclude_none=True)
                verbose_proxy_logger.debug(
                    f"Converted Pydantic tool to dict: {type(tool).__name__} -> dict with keys: {list(tool_dict.keys())}"
                )
                openai_tools_as_dicts.append(tool_dict)
            elif hasattr(tool, "dict"):
                tool_dict = tool.dict(exclude_none=True)
                verbose_proxy_logger.debug(
                    f"Converted Pydantic tool (v1) to dict: {type(tool).__name__} -> dict"
                )
                openai_tools_as_dicts.append(tool_dict)
            elif isinstance(tool, dict):
                verbose_proxy_logger.debug(
                    f"Tool is already a dict with keys: {list(tool.keys())}"
                )
                openai_tools_as_dicts.append(tool)
            else:
                verbose_proxy_logger.warning(
                    f"Tool is unknown type: {type(tool)}, passing as-is"
                )
                openai_tools_as_dicts.append(tool)

        verbose_proxy_logger.debug(
            f"Expanded {len(mcp_tools)} MCP reference(s) to {len(openai_tools_as_dicts)} tools (all as dicts)"
        )

        return openai_tools_as_dicts

    def _get_metadata_variable_name(self, data: dict) -> str:
        if "litellm_metadata" in data:
            return "litellm_metadata"
        return "metadata"

    async def async_pre_call_hook(
        self,
        user_api_key_dict: "UserAPIKeyAuth",
        cache: "DualCache",
        data: dict,
        call_type: str,
    ) -> Optional[Union[Exception, str, dict]]:
        """
        Filter tools before the LLM call based on the user query.

        This hook is called before the LLM request is made. It filters the
        tools list to only include semantically relevant tools.

        Args:
            user_api_key_dict: User authentication
            cache: Cache instance
            data: Request data containing messages and tools
            call_type: Type of call (completion, acompletion, etc.)

        Returns:
            Modified data dict with filtered tools, or None if no changes
        """
        # Only filter endpoints that support tools
        if call_type not in ("completion", "acompletion", "aresponses"):
            verbose_proxy_logger.debug(
                f"Skipping semantic filter for call_type={call_type}"
            )
            return None

        # Check if tools are present
        tools = data.get("tools")
        if not tools:
            verbose_proxy_logger.debug("No tools in request, skipping semantic filter")
            return None

        original_tool_count = len(tools)

        # Check for MCP references (server_url="litellm_proxy") and expand them
        if self._should_expand_mcp_tools(tools):
            verbose_proxy_logger.debug(
                "Detected litellm_proxy MCP references, expanding before semantic filtering"
            )

            try:
                expanded_tools = await self._expand_mcp_tools(tools, user_api_key_dict)

                if not expanded_tools:
                    verbose_proxy_logger.warning(
                        "No tools expanded from MCP references"
                    )
                    return None

                verbose_proxy_logger.info(
                    f"Expanded {len(tools)} MCP reference(s) to {len(expanded_tools)} tools"
|
||||
)
|
||||
|
||||
# Update tools for filtering
|
||||
tools = expanded_tools
|
||||
original_tool_count = len(tools)
|
||||
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(
|
||||
f"Failed to expand MCP references: {e}", exc_info=True
|
||||
)
|
||||
return None
|
||||
|
||||
# Check if messages are present (try both "messages" and "input" for responses API)
|
||||
messages = data.get("messages", [])
|
||||
if not messages:
|
||||
messages = data.get("input", [])
|
||||
if not messages:
|
||||
verbose_proxy_logger.debug(
|
||||
"No messages in request, skipping semantic filter"
|
||||
)
|
||||
return None
|
||||
|
||||
# Check if filter is enabled
|
||||
if not self.filter.enabled:
|
||||
verbose_proxy_logger.debug("Semantic filter disabled, skipping")
|
||||
return None
|
||||
|
||||
try:
|
||||
# Extract user query from messages
|
||||
user_query = self.filter.extract_user_query(messages)
|
||||
if not user_query:
|
||||
verbose_proxy_logger.debug(
|
||||
"No user query found, skipping semantic filter"
|
||||
)
|
||||
return None
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f"Applying semantic filter to {len(tools)} tools "
|
||||
f"with query: '{user_query[:50]}...'"
|
||||
)
|
||||
|
||||
# Filter tools semantically
|
||||
filtered_tools = await self.filter.filter_tools(
|
||||
query=user_query,
|
||||
available_tools=tools, # type: ignore
|
||||
)
|
||||
|
||||
# Always update tools and emit header (even if count unchanged)
|
||||
data["tools"] = filtered_tools
|
||||
|
||||
# Store filter stats and tool names for response header
|
||||
filter_stats = f"{original_tool_count}->{len(filtered_tools)}"
|
||||
tool_names_csv = self._get_tool_names_csv(filtered_tools)
|
||||
|
||||
_metadata_variable_name = self._get_metadata_variable_name(data)
|
||||
data[_metadata_variable_name][
|
||||
"litellm_semantic_filter_stats"
|
||||
] = filter_stats
|
||||
data[_metadata_variable_name][
|
||||
"litellm_semantic_filter_tools"
|
||||
] = tool_names_csv
|
||||
|
||||
verbose_proxy_logger.info(f"Semantic tool filter: {filter_stats} tools")
|
||||
|
||||
return data
|
||||
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.warning(
|
||||
f"Semantic tool filter hook failed: {e}. Proceeding with all tools."
|
||||
)
|
||||
return None
|
||||
|
||||
async def async_post_call_response_headers_hook(
|
||||
self,
|
||||
data: dict,
|
||||
user_api_key_dict: "UserAPIKeyAuth",
|
||||
response: Any,
|
||||
request_headers: Optional[Dict[str, str]] = None,
|
||||
litellm_call_info: Optional[Dict[str, Any]] = None,
|
||||
) -> Optional[Dict[str, str]]:
|
||||
"""Add semantic filter stats and tool names to response headers."""
|
||||
from litellm.constants import MAX_MCP_SEMANTIC_FILTER_TOOLS_HEADER_LENGTH
|
||||
|
||||
_metadata_variable_name = self._get_metadata_variable_name(data)
|
||||
metadata = data[_metadata_variable_name]
|
||||
|
||||
filter_stats = metadata.get("litellm_semantic_filter_stats")
|
||||
if not filter_stats:
|
||||
return None
|
||||
|
||||
headers = {"x-litellm-semantic-filter": filter_stats}
|
||||
|
||||
# Add CSV of filtered tool names (nginx-safe length)
|
||||
tool_names_csv = metadata.get("litellm_semantic_filter_tools", "")
|
||||
if tool_names_csv:
|
||||
if len(tool_names_csv) > MAX_MCP_SEMANTIC_FILTER_TOOLS_HEADER_LENGTH:
|
||||
tool_names_csv = (
|
||||
tool_names_csv[: MAX_MCP_SEMANTIC_FILTER_TOOLS_HEADER_LENGTH - 3]
|
||||
+ "..."
|
||||
)
|
||||
|
||||
headers["x-litellm-semantic-filter-tools"] = tool_names_csv
|
||||
|
||||
return headers
|
||||
|
||||
def _get_tool_names_csv(self, tools: List[Any]) -> str:
|
||||
"""Extract tool names and return as CSV string."""
|
||||
if not tools:
|
||||
return ""
|
||||
|
||||
tool_names = []
|
||||
for tool in tools:
|
||||
name = (
|
||||
tool.get("name", "")
|
||||
if isinstance(tool, dict)
|
||||
else getattr(tool, "name", "")
|
||||
)
|
||||
if name:
|
||||
tool_names.append(name)
|
||||
|
||||
return ",".join(tool_names)
|
||||
|
||||
@staticmethod
|
||||
async def initialize_from_config(
|
||||
config: Optional[Dict[str, Any]],
|
||||
llm_router: Optional["Router"],
|
||||
) -> Optional["SemanticToolFilterHook"]:
|
||||
"""
|
||||
Initialize semantic tool filter from proxy config.
|
||||
|
||||
Args:
|
||||
config: Proxy configuration dict (litellm_settings.mcp_semantic_tool_filter)
|
||||
llm_router: LiteLLM router instance for embeddings
|
||||
|
||||
Returns:
|
||||
SemanticToolFilterHook instance if enabled, None otherwise
|
||||
"""
|
||||
from litellm.proxy._experimental.mcp_server.semantic_tool_filter import (
|
||||
SemanticMCPToolFilter,
|
||||
)
|
||||
|
||||
if not config or not config.get("enabled", False):
|
||||
verbose_proxy_logger.debug("Semantic tool filter not enabled in config")
|
||||
return None
|
||||
|
||||
if llm_router is None:
|
||||
verbose_proxy_logger.warning(
|
||||
"Cannot initialize semantic filter: llm_router is None"
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
embedding_model = config.get(
|
||||
"embedding_model", DEFAULT_MCP_SEMANTIC_FILTER_EMBEDDING_MODEL
|
||||
)
|
||||
top_k = config.get("top_k", DEFAULT_MCP_SEMANTIC_FILTER_TOP_K)
|
||||
similarity_threshold = config.get(
|
||||
"similarity_threshold", DEFAULT_MCP_SEMANTIC_FILTER_SIMILARITY_THRESHOLD
|
||||
)
|
||||
|
||||
semantic_filter = SemanticMCPToolFilter(
|
||||
embedding_model=embedding_model,
|
||||
litellm_router_instance=llm_router,
|
||||
top_k=top_k,
|
||||
similarity_threshold=similarity_threshold,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
# Build router from MCP registry on startup
|
||||
await semantic_filter.build_router_from_mcp_registry()
|
||||
|
||||
hook = SemanticToolFilterHook(semantic_filter)
|
||||
|
||||
verbose_proxy_logger.info(
|
||||
f"✅ MCP Semantic Tool Filter enabled: "
|
||||
f"embedding_model={embedding_model}, top_k={top_k}, "
|
||||
f"similarity_threshold={similarity_threshold}"
|
||||
)
|
||||
|
||||
return hook
|
||||
|
||||
except ImportError as e:
|
||||
verbose_proxy_logger.warning(
|
||||
f"semantic-router not installed. Install with: "
|
||||
f"pip install 'litellm[semantic-router]'. Error: {e}"
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(
|
||||
f"Failed to initialize MCP semantic tool filter: {e}"
|
||||
)
|
||||
return None
|
||||
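The response-header hook above caps the tool-name CSV so reverse proxies such as nginx do not reject oversized headers. A minimal standalone sketch of that truncation rule; `MAX_HEADER_LEN` and `truncate_header_value` are illustrative stand-ins, not LiteLLM's actual constant or helper:

```python
# Illustrative sketch of the header truncation above. MAX_HEADER_LEN is an
# assumed stand-in for MAX_MCP_SEMANTIC_FILTER_TOOLS_HEADER_LENGTH.
MAX_HEADER_LEN = 40  # small cap so the example is easy to inspect


def truncate_header_value(tool_names_csv: str, max_len: int = MAX_HEADER_LEN) -> str:
    """Cap a CSV header value, marking truncation with a trailing '...'."""
    if len(tool_names_csv) > max_len:
        return tool_names_csv[: max_len - 3] + "..."
    return tool_names_csv


short = truncate_header_value("get_weather,search_docs")  # unchanged, fits the cap
long = truncate_header_value(",".join(f"tool_{i}" for i in range(20)))  # truncated
```

Reserving three characters for the `...` marker keeps the final value exactly at the cap, so the header length never exceeds the proxy's limit.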
@@ -0,0 +1,318 @@
import json
from typing import List, Optional

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import Span
from litellm.proxy._types import UserAPIKeyAuth
from litellm.router_strategy.budget_limiter import RouterBudgetLimiting
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import (
    BudgetConfig,
    GenericBudgetConfigType,
    StandardLoggingPayload,
)

VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX = "virtual_key_spend"
END_USER_SPEND_CACHE_KEY_PREFIX = "end_user_model_spend"


class _PROXY_VirtualKeyModelMaxBudgetLimiter(RouterBudgetLimiting):
    """
    Handles budgets for model + virtual key

    Example: key=sk-1234567890, model=gpt-4o, max_budget=100, time_period=1d
    """

    def __init__(self, dual_cache: DualCache):
        self.dual_cache = dual_cache
        self.redis_increment_operation_queue = []

    async def is_key_within_model_budget(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        model: str,
    ) -> bool:
        """
        Check if the user_api_key_dict is within the model budget

        Raises:
            BudgetExceededError: If the user_api_key_dict has exceeded the model budget
        """
        _model_max_budget = user_api_key_dict.model_max_budget
        internal_model_max_budget: GenericBudgetConfigType = {}

        for _model, _budget_info in _model_max_budget.items():
            internal_model_max_budget[_model] = BudgetConfig(**_budget_info)

        verbose_proxy_logger.debug(
            "internal_model_max_budget %s",
            json.dumps(internal_model_max_budget, indent=4, default=str),
        )

        # check if current model is in internal_model_max_budget
        _current_model_budget_info = self._get_request_model_budget_config(
            model=model, internal_model_max_budget=internal_model_max_budget
        )
        if _current_model_budget_info is None:
            verbose_proxy_logger.debug(
                f"Model {model} not found in internal_model_max_budget"
            )
            return True

        # check if current model is within budget
        if (
            _current_model_budget_info.max_budget
            and _current_model_budget_info.max_budget > 0
        ):
            _current_spend = await self._get_virtual_key_spend_for_model(
                user_api_key_hash=user_api_key_dict.token,
                model=model,
                key_budget_config=_current_model_budget_info,
            )
            if (
                _current_spend is not None
                and _current_model_budget_info.max_budget is not None
                and _current_spend > _current_model_budget_info.max_budget
            ):
                raise litellm.BudgetExceededError(
                    message=f"LiteLLM Virtual Key: {user_api_key_dict.token}, key_alias: {user_api_key_dict.key_alias}, exceeded budget for model={model}",
                    current_cost=_current_spend,
                    max_budget=_current_model_budget_info.max_budget,
                )

        return True

    async def is_end_user_within_model_budget(
        self,
        end_user_id: str,
        end_user_model_max_budget: dict,
        model: str,
    ) -> bool:
        """
        Check if the end_user is within the model budget

        Raises:
            BudgetExceededError: If the end_user has exceeded the model budget
        """
        internal_model_max_budget: GenericBudgetConfigType = {}

        for _model, _budget_info in end_user_model_max_budget.items():
            internal_model_max_budget[_model] = BudgetConfig(**_budget_info)

        verbose_proxy_logger.debug(
            "end_user internal_model_max_budget %s",
            json.dumps(internal_model_max_budget, indent=4, default=str),
        )

        # check if current model is in internal_model_max_budget
        _current_model_budget_info = self._get_request_model_budget_config(
            model=model, internal_model_max_budget=internal_model_max_budget
        )
        if _current_model_budget_info is None:
            verbose_proxy_logger.debug(
                f"Model {model} not found in end_user_model_max_budget"
            )
            return True

        # check if current model is within budget
        if (
            _current_model_budget_info.max_budget
            and _current_model_budget_info.max_budget > 0
        ):
            _current_spend = await self._get_end_user_spend_for_model(
                end_user_id=end_user_id,
                model=model,
                key_budget_config=_current_model_budget_info,
            )
            if (
                _current_spend is not None
                and _current_model_budget_info.max_budget is not None
                and _current_spend > _current_model_budget_info.max_budget
            ):
                raise litellm.BudgetExceededError(
                    message=f"LiteLLM End User: {end_user_id}, exceeded budget for model={model}",
                    current_cost=_current_spend,
                    max_budget=_current_model_budget_info.max_budget,
                )

        return True

    async def _get_end_user_spend_for_model(
        self,
        end_user_id: str,
        model: str,
        key_budget_config: BudgetConfig,
    ) -> Optional[float]:
        # 1. model: directly look up `model`
        end_user_model_spend_cache_key = f"{END_USER_SPEND_CACHE_KEY_PREFIX}:{end_user_id}:{model}:{key_budget_config.budget_duration}"
        _current_spend = await self.dual_cache.async_get_cache(
            key=end_user_model_spend_cache_key,
        )

        if _current_spend is None:
            # 2. If 1 does not exist, check if passed as {custom_llm_provider}/model
            end_user_model_spend_cache_key = f"{END_USER_SPEND_CACHE_KEY_PREFIX}:{end_user_id}:{self._get_model_without_custom_llm_provider(model)}:{key_budget_config.budget_duration}"
            _current_spend = await self.dual_cache.async_get_cache(
                key=end_user_model_spend_cache_key,
            )
        return _current_spend

    async def _get_virtual_key_spend_for_model(
        self,
        user_api_key_hash: Optional[str],
        model: str,
        key_budget_config: BudgetConfig,
    ) -> Optional[float]:
        """
        Get the current spend for a virtual key for a model

        Lookup model in this order:
        1. model: directly look up `model`
        2. If 1 does not exist, check if passed as {custom_llm_provider}/model
        """

        # 1. model: directly look up `model`
        virtual_key_model_spend_cache_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{user_api_key_hash}:{model}:{key_budget_config.budget_duration}"
        _current_spend = await self.dual_cache.async_get_cache(
            key=virtual_key_model_spend_cache_key,
        )

        if _current_spend is None:
            # 2. If 1 does not exist, check if passed as {custom_llm_provider}/model
            # if "/" in model, remove first part before "/" - eg. openai/o1-preview -> o1-preview
            virtual_key_model_spend_cache_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{user_api_key_hash}:{self._get_model_without_custom_llm_provider(model)}:{key_budget_config.budget_duration}"
            _current_spend = await self.dual_cache.async_get_cache(
                key=virtual_key_model_spend_cache_key,
            )
        return _current_spend

    def _get_request_model_budget_config(
        self, model: str, internal_model_max_budget: GenericBudgetConfigType
    ) -> Optional[BudgetConfig]:
        """
        Get the budget config for the request model

        1. Check if `model` is in `internal_model_max_budget`
        2. If not, check if `model` without custom llm provider is in `internal_model_max_budget`
        """
        return internal_model_max_budget.get(
            model, None
        ) or internal_model_max_budget.get(
            self._get_model_without_custom_llm_provider(model), None
        )

    def _get_model_without_custom_llm_provider(self, model: str) -> str:
        if "/" in model:
            return model.split("/")[-1]
        return model

    async def async_filter_deployments(
        self,
        model: str,
        healthy_deployments: List,
        messages: Optional[List[AllMessageValues]],
        request_kwargs: Optional[dict] = None,
        parent_otel_span: Optional[Span] = None,  # type: ignore
    ) -> List[dict]:
        return healthy_deployments

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """
        Track spend for virtual key + model in DualCache

        Example: key=sk-1234567890, model=gpt-4o, max_budget=100, time_period=1d
        """
        verbose_proxy_logger.debug("in RouterBudgetLimiting.async_log_success_event")
        standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
            "standard_logging_object", None
        )
        if standard_logging_payload is None:
            verbose_proxy_logger.debug(
                "Skipping _PROXY_VirtualKeyModelMaxBudgetLimiter.async_log_success_event: standard_logging_payload is None"
            )
            return

        _litellm_params: dict = kwargs.get("litellm_params", {}) or {}
        _metadata: dict = _litellm_params.get("metadata", {}) or {}
        user_api_key_model_max_budget: Optional[dict] = _metadata.get(
            "user_api_key_model_max_budget", None
        )
        user_api_key_end_user_model_max_budget: Optional[dict] = _metadata.get(
            "user_api_key_end_user_model_max_budget", None
        )
        if (
            user_api_key_model_max_budget is None
            or len(user_api_key_model_max_budget) == 0
        ) and (
            user_api_key_end_user_model_max_budget is None
            or len(user_api_key_end_user_model_max_budget) == 0
        ):
            verbose_proxy_logger.debug(
                "Not running _PROXY_VirtualKeyModelMaxBudgetLimiter.async_log_success_event because user_api_key_model_max_budget and user_api_key_end_user_model_max_budget are None or empty."
            )
            return

        response_cost: float = standard_logging_payload.get("response_cost", 0)
        model = standard_logging_payload.get("model")
        virtual_key = standard_logging_payload.get("metadata", {}).get(
            "user_api_key_hash"
        )
        end_user_id = standard_logging_payload.get(
            "end_user"
        ) or standard_logging_payload.get("metadata", {}).get(
            "user_api_key_end_user_id"
        )

        if model is None:
            return

        if (
            virtual_key is not None
            and user_api_key_model_max_budget is not None
            and len(user_api_key_model_max_budget) > 0
        ):
            internal_model_max_budget: GenericBudgetConfigType = {}
            for _model, _budget_info in user_api_key_model_max_budget.items():
                internal_model_max_budget[_model] = BudgetConfig(**_budget_info)
            key_budget_config = self._get_request_model_budget_config(
                model=model, internal_model_max_budget=internal_model_max_budget
            )
            if key_budget_config is not None and key_budget_config.budget_duration:
                virtual_spend_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{virtual_key}:{model}:{key_budget_config.budget_duration}"
                virtual_start_time_key = f"virtual_key_budget_start_time:{virtual_key}"
                await self._increment_spend_for_key(
                    budget_config=key_budget_config,
                    spend_key=virtual_spend_key,
                    start_time_key=virtual_start_time_key,
                    response_cost=response_cost,
                )

        if (
            end_user_id is not None
            and user_api_key_end_user_model_max_budget is not None
            and len(user_api_key_end_user_model_max_budget) > 0
        ):
            internal_model_max_budget = {}
            for _model, _budget_info in user_api_key_end_user_model_max_budget.items():
                internal_model_max_budget[_model] = BudgetConfig(**_budget_info)
            key_budget_config = self._get_request_model_budget_config(
                model=model, internal_model_max_budget=internal_model_max_budget
            )
            if key_budget_config is not None and key_budget_config.budget_duration:
                end_user_spend_key = f"{END_USER_SPEND_CACHE_KEY_PREFIX}:{end_user_id}:{model}:{key_budget_config.budget_duration}"
                end_user_start_time_key = f"end_user_budget_start_time:{end_user_id}"
                await self._increment_spend_for_key(
                    budget_config=key_budget_config,
                    spend_key=end_user_spend_key,
                    start_time_key=end_user_start_time_key,
                    response_cost=response_cost,
                )

        verbose_proxy_logger.debug(
            "current state of in memory cache %s",
            json.dumps(
                self.dual_cache.in_memory_cache.cache_dict, indent=4, default=str
            ),
        )
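The spend-lookup helpers above try the exact model name first, then retry with the `custom_llm_provider` prefix stripped (e.g. `openai/gpt-4o` falls back to `gpt-4o`). A self-contained sketch of that two-step fallback over a plain dict; the helper names and budget values here are illustrative, not the class's actual cache plumbing:

```python
# Illustrative sketch of the two-step budget lookup used above.
def strip_provider(model: str) -> str:
    """'openai/gpt-4o' -> 'gpt-4o'; bare names pass through unchanged."""
    return model.split("/")[-1] if "/" in model else model


def get_budget_config(model: str, budgets: dict):
    """Exact-name lookup first, then retry without the provider prefix."""
    return budgets.get(model) or budgets.get(strip_provider(model))


budgets = {"gpt-4o": {"max_budget": 100, "budget_duration": "1d"}}
exact = get_budget_config("gpt-4o", budgets)            # direct hit
prefixed = get_budget_config("openai/gpt-4o", budgets)  # falls back to "gpt-4o"
missing = get_budget_config("openai/o1-preview", budgets)  # no config either way
```

The fallback lets admins configure budgets by bare model name while requests arrive with provider-prefixed names.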
@@ -0,0 +1,869 @@
import asyncio
import sys
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union

from fastapi import HTTPException
from pydantic import BaseModel
from typing_extensions import TypedDict

import litellm
from litellm import DualCache, ModelResponse
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
from litellm.proxy._types import CommonProxyErrors, CurrentItemRateLimit, UserAPIKeyAuth
from litellm.proxy.auth.auth_utils import (
    get_key_model_rpm_limit,
    get_key_model_tpm_limit,
)

if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span

    from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache

    Span = Union[_Span, Any]
    InternalUsageCache = _InternalUsageCache
else:
    Span = Any
    InternalUsageCache = Any


class CacheObject(TypedDict):
    current_global_requests: Optional[dict]
    request_count_api_key: Optional[dict]
    request_count_api_key_model: Optional[dict]
    request_count_user_id: Optional[dict]
    request_count_team_id: Optional[dict]
    request_count_end_user_id: Optional[dict]


class _PROXY_MaxParallelRequestsHandler(CustomLogger):
    # Class variables or attributes
    def __init__(self, internal_usage_cache: InternalUsageCache):
        self.internal_usage_cache = internal_usage_cache

    def print_verbose(self, print_statement):
        try:
            verbose_proxy_logger.debug(print_statement)
            if litellm.set_verbose:
                print(print_statement)  # noqa
        except Exception:
            pass

    async def check_key_in_limits(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,
        max_parallel_requests: int,
        tpm_limit: int,
        rpm_limit: int,
        current: Optional[dict],
        request_count_api_key: str,
        rate_limit_type: Literal["key", "model_per_key", "user", "customer", "team"],
        values_to_update_in_cache: List[Tuple[Any, Any]],
    ) -> dict:
        verbose_proxy_logger.info(
            f"Current Usage of {rate_limit_type} in this minute: {current}"
        )
        if current is None:
            if max_parallel_requests == 0 or tpm_limit == 0 or rpm_limit == 0:
                # base case
                raise self.raise_rate_limit_error(
                    additional_details=f"{CommonProxyErrors.max_parallel_request_limit_reached.value}. Hit limit for {rate_limit_type}. Current limits: max_parallel_requests: {max_parallel_requests}, tpm_limit: {tpm_limit}, rpm_limit: {rpm_limit}"
                )
            new_val = {
                "current_requests": 1,
                "current_tpm": 0,
                "current_rpm": 1,
            }
            values_to_update_in_cache.append((request_count_api_key, new_val))
        elif (
            int(current["current_requests"]) < max_parallel_requests
            and current["current_tpm"] < tpm_limit
            and current["current_rpm"] < rpm_limit
        ):
            # Increase count for this token
            new_val = {
                "current_requests": current["current_requests"] + 1,
                "current_tpm": current["current_tpm"],
                "current_rpm": current["current_rpm"] + 1,
            }
            values_to_update_in_cache.append((request_count_api_key, new_val))
        else:
            raise HTTPException(
                status_code=429,
                detail=f"LiteLLM Rate Limit Handler for rate limit type = {rate_limit_type}. {CommonProxyErrors.max_parallel_request_limit_reached.value}. current rpm: {current['current_rpm']}, rpm limit: {rpm_limit}, current tpm: {current['current_tpm']}, tpm limit: {tpm_limit}, current max_parallel_requests: {current['current_requests']}, max_parallel_requests: {max_parallel_requests}",
                headers={"retry-after": str(self.time_to_next_minute())},
            )

        await self.internal_usage_cache.async_batch_set_cache(
            cache_list=values_to_update_in_cache,
            ttl=60,
            litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
            local_only=True,
        )
        return new_val
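check_key_in_limits above implements a per-minute fixed-window counter: initialize on first sight, increment while every limit still has headroom, otherwise reject. A minimal in-memory sketch of the same decision logic, with the cache, TTL, and HTTP details stripped out (the function name and RuntimeError are illustrative):

```python
import sys


def check_and_increment(current, max_parallel_requests=sys.maxsize,
                        tpm_limit=sys.maxsize, rpm_limit=sys.maxsize):
    """Return the updated per-minute counter dict, or raise if a limit is hit."""
    if current is None:
        # a limit of 0 means the caller may never proceed
        if max_parallel_requests == 0 or tpm_limit == 0 or rpm_limit == 0:
            raise RuntimeError("rate limit set to 0")
        # first request in this window
        return {"current_requests": 1, "current_tpm": 0, "current_rpm": 1}
    if (
        current["current_requests"] < max_parallel_requests
        and current["current_tpm"] < tpm_limit
        and current["current_rpm"] < rpm_limit
    ):
        return {
            "current_requests": current["current_requests"] + 1,
            "current_tpm": current["current_tpm"],
            "current_rpm": current["current_rpm"] + 1,
        }
    raise RuntimeError("rate limit exceeded")


state = check_and_increment(None, rpm_limit=2)   # first request this minute
state = check_and_increment(state, rpm_limit=2)  # second request; a third would raise
```

In the real handler the counter lives in the cache under a `{key}::{precise_minute}::request_count` key with a 60-second TTL, so the window resets automatically.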

    def time_to_next_minute(self) -> float:
        # Get the current time
        now = datetime.now()

        # Calculate the next minute
        next_minute = (now + timedelta(minutes=1)).replace(second=0, microsecond=0)

        # Calculate the difference in seconds
        seconds_to_next_minute = (next_minute - now).total_seconds()

        return seconds_to_next_minute
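time_to_next_minute above produces the retry-after value: the seconds remaining until the current fixed window rolls over. The same computation, parameterized on an arbitrary timestamp so it can be checked deterministically (a sketch, not the method itself):

```python
from datetime import datetime, timedelta


def seconds_to_next_minute(now: datetime) -> float:
    """Seconds from `now` until the next minute boundary."""
    next_minute = (now + timedelta(minutes=1)).replace(second=0, microsecond=0)
    return (next_minute - now).total_seconds()


remaining = seconds_to_next_minute(datetime(2024, 1, 1, 12, 30, 45))  # 15.0
```

Clients can sleep for this many seconds and retry once the per-minute counters have expired from the cache.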

    def raise_rate_limit_error(
        self, additional_details: Optional[str] = None
    ) -> HTTPException:
        """
        Raise an HTTPException with a 429 status code and a retry-after header
        """
        error_message = "Max parallel request limit reached"
        if additional_details is not None:
            error_message = error_message + " " + additional_details
        raise HTTPException(
            status_code=429,
            detail=error_message,
            headers={"retry-after": str(self.time_to_next_minute())},
        )

    async def get_all_cache_objects(
        self,
        current_global_requests: Optional[str],
        request_count_api_key: Optional[str],
        request_count_api_key_model: Optional[str],
        request_count_user_id: Optional[str],
        request_count_team_id: Optional[str],
        request_count_end_user_id: Optional[str],
        parent_otel_span: Optional[Span] = None,
    ) -> CacheObject:
        keys = [
            current_global_requests,
            request_count_api_key,
            request_count_api_key_model,
            request_count_user_id,
            request_count_team_id,
            request_count_end_user_id,
        ]
        results = await self.internal_usage_cache.async_batch_get_cache(
            keys=keys,
            parent_otel_span=parent_otel_span,
        )

        if results is None:
            return CacheObject(
                current_global_requests=None,
                request_count_api_key=None,
                request_count_api_key_model=None,
                request_count_user_id=None,
                request_count_team_id=None,
                request_count_end_user_id=None,
            )

        return CacheObject(
            current_global_requests=results[0],
            request_count_api_key=results[1],
            request_count_api_key_model=results[2],
            request_count_user_id=results[3],
            request_count_team_id=results[4],
            request_count_end_user_id=results[5],
        )

    async def async_pre_call_hook(  # noqa: PLR0915
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,
    ):
        self.print_verbose("Inside Max Parallel Request Pre-Call Hook")
        api_key = user_api_key_dict.api_key
        max_parallel_requests = user_api_key_dict.max_parallel_requests
        if max_parallel_requests is None:
            max_parallel_requests = sys.maxsize
        if data is None:
            data = {}
        global_max_parallel_requests = data.get("metadata", {}).get(
            "global_max_parallel_requests", None
        )
        tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize)
        if tpm_limit is None:
            tpm_limit = sys.maxsize
        rpm_limit = getattr(user_api_key_dict, "rpm_limit", sys.maxsize)
        if rpm_limit is None:
            rpm_limit = sys.maxsize

        # values that need to get updated in cache; a batch_set_cache runs after this function
        values_to_update_in_cache: List[Tuple[Any, Any]] = []

        # ------------
        # Setup values
        # ------------
        new_val: Optional[dict] = None

        if global_max_parallel_requests is not None:
            # get value from cache
            _key = "global_max_parallel_requests"
            current_global_requests = await self.internal_usage_cache.async_get_cache(
                key=_key,
                local_only=True,
                litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
            )
            # check if below limit
            if current_global_requests is None:
                current_global_requests = 1
            # if above -> raise error
            if current_global_requests >= global_max_parallel_requests:
                return self.raise_rate_limit_error(
                    additional_details=f"Hit Global Limit: Limit={global_max_parallel_requests}, current: {current_global_requests}"
                )
            # if below -> increment
            else:
                await self.internal_usage_cache.async_increment_cache(
                    key=_key,
                    value=1,
                    local_only=True,
                    litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
                )
        _model = data.get("model", None)

        current_date = datetime.now().strftime("%Y-%m-%d")
        current_hour = datetime.now().strftime("%H")
        current_minute = datetime.now().strftime("%M")
        precise_minute = f"{current_date}-{current_hour}-{current_minute}"

        cache_objects: CacheObject = await self.get_all_cache_objects(
            current_global_requests=(
                "global_max_parallel_requests"
                if global_max_parallel_requests is not None
                else None
            ),
            request_count_api_key=(
                f"{api_key}::{precise_minute}::request_count"
                if api_key is not None
                else None
            ),
            request_count_api_key_model=(
                f"{api_key}::{_model}::{precise_minute}::request_count"
                if api_key is not None and _model is not None
                else None
            ),
            request_count_user_id=(
                f"{user_api_key_dict.user_id}::{precise_minute}::request_count"
                if user_api_key_dict.user_id is not None
                else None
            ),
            request_count_team_id=(
                f"{user_api_key_dict.team_id}::{precise_minute}::request_count"
                if user_api_key_dict.team_id is not None
                else None
            ),
            request_count_end_user_id=(
                f"{user_api_key_dict.end_user_id}::{precise_minute}::request_count"
                if user_api_key_dict.end_user_id is not None
                else None
            ),
            parent_otel_span=user_api_key_dict.parent_otel_span,
        )
        if api_key is not None:
            request_count_api_key = f"{api_key}::{precise_minute}::request_count"
            # CHECK IF REQUEST ALLOWED for key
            await self.check_key_in_limits(
                user_api_key_dict=user_api_key_dict,
                cache=cache,
                data=data,
                call_type=call_type,
                max_parallel_requests=max_parallel_requests,
                current=cache_objects["request_count_api_key"],
                request_count_api_key=request_count_api_key,
                tpm_limit=tpm_limit,
                rpm_limit=rpm_limit,
                rate_limit_type="key",
                values_to_update_in_cache=values_to_update_in_cache,
            )

        # Check if request under RPM/TPM per model for a given API Key
        if (
            get_key_model_tpm_limit(user_api_key_dict) is not None
            or get_key_model_rpm_limit(user_api_key_dict) is not None
        ):
            _model = data.get("model", None)
            request_count_api_key = (
                f"{api_key}::{_model}::{precise_minute}::request_count"
            )
            _tpm_limit_for_key_model = get_key_model_tpm_limit(user_api_key_dict)
            _rpm_limit_for_key_model = get_key_model_rpm_limit(user_api_key_dict)
            tpm_limit_for_model = None
            rpm_limit_for_model = None

            if _model is not None:
                if _tpm_limit_for_key_model:
                    tpm_limit_for_model = _tpm_limit_for_key_model.get(_model)

                if _rpm_limit_for_key_model:
                    rpm_limit_for_model = _rpm_limit_for_key_model.get(_model)

            new_val = await self.check_key_in_limits(
                user_api_key_dict=user_api_key_dict,
                cache=cache,
                data=data,
                call_type=call_type,
                max_parallel_requests=sys.maxsize,  # TODO: Support max parallel requests for a model
                current=cache_objects["request_count_api_key_model"],
                request_count_api_key=request_count_api_key,
                tpm_limit=tpm_limit_for_model or sys.maxsize,
                rpm_limit=rpm_limit_for_model or sys.maxsize,
                rate_limit_type="model_per_key",
                values_to_update_in_cache=values_to_update_in_cache,
            )
            _remaining_tokens = None
            _remaining_requests = None
            # Add remaining tokens, requests to metadata
            if new_val:
                if tpm_limit_for_model is not None:
                    _remaining_tokens = tpm_limit_for_model - new_val["current_tpm"]
                if rpm_limit_for_model is not None:
                    _remaining_requests = rpm_limit_for_model - new_val["current_rpm"]

            _remaining_limits_data = {
                f"litellm-key-remaining-tokens-{_model}": _remaining_tokens,
                f"litellm-key-remaining-requests-{_model}": _remaining_requests,
            }
|
||||
if "metadata" not in data:
|
||||
data["metadata"] = {}
|
||||
data["metadata"].update(_remaining_limits_data)
|
||||
|
||||
# check if REQUEST ALLOWED for user_id
|
||||
user_id = user_api_key_dict.user_id
|
||||
if user_id is not None:
|
||||
user_tpm_limit = user_api_key_dict.user_tpm_limit
|
||||
user_rpm_limit = user_api_key_dict.user_rpm_limit
|
||||
if user_tpm_limit is None:
|
||||
user_tpm_limit = sys.maxsize
|
||||
if user_rpm_limit is None:
|
||||
user_rpm_limit = sys.maxsize
|
||||
|
||||
request_count_api_key = f"{user_id}::{precise_minute}::request_count"
|
||||
# print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
|
||||
await self.check_key_in_limits(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
cache=cache,
|
||||
data=data,
|
||||
call_type=call_type,
|
||||
max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a user
|
||||
current=cache_objects["request_count_user_id"],
|
||||
request_count_api_key=request_count_api_key,
|
||||
tpm_limit=user_tpm_limit,
|
||||
rpm_limit=user_rpm_limit,
|
||||
rate_limit_type="user",
|
||||
values_to_update_in_cache=values_to_update_in_cache,
|
||||
)
|
||||
|
||||
# TEAM RATE LIMITS
|
||||
## get team tpm/rpm limits
|
||||
team_id = user_api_key_dict.team_id
|
||||
if team_id is not None:
|
||||
team_tpm_limit = user_api_key_dict.team_tpm_limit
|
||||
team_rpm_limit = user_api_key_dict.team_rpm_limit
|
||||
|
||||
if team_tpm_limit is None:
|
||||
team_tpm_limit = sys.maxsize
|
||||
if team_rpm_limit is None:
|
||||
team_rpm_limit = sys.maxsize
|
||||
|
||||
request_count_api_key = f"{team_id}::{precise_minute}::request_count"
|
||||
# print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
|
||||
await self.check_key_in_limits(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
cache=cache,
|
||||
data=data,
|
||||
call_type=call_type,
|
||||
max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a team
|
||||
current=cache_objects["request_count_team_id"],
|
||||
request_count_api_key=request_count_api_key,
|
||||
tpm_limit=team_tpm_limit,
|
||||
rpm_limit=team_rpm_limit,
|
||||
rate_limit_type="team",
|
||||
values_to_update_in_cache=values_to_update_in_cache,
|
||||
)
|
||||
|
||||
# End-User Rate Limits
|
||||
# Only enforce if user passed `user` to /chat, /completions, /embeddings
|
||||
if user_api_key_dict.end_user_id:
|
||||
end_user_tpm_limit = getattr(
|
||||
user_api_key_dict, "end_user_tpm_limit", sys.maxsize
|
||||
)
|
||||
end_user_rpm_limit = getattr(
|
||||
user_api_key_dict, "end_user_rpm_limit", sys.maxsize
|
||||
)
|
||||
|
||||
if end_user_tpm_limit is None:
|
||||
end_user_tpm_limit = sys.maxsize
|
||||
if end_user_rpm_limit is None:
|
||||
end_user_rpm_limit = sys.maxsize
|
||||
|
||||
# now do the same tpm/rpm checks
|
||||
request_count_api_key = (
|
||||
f"{user_api_key_dict.end_user_id}::{precise_minute}::request_count"
|
||||
)
|
||||
|
||||
# print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
|
||||
await self.check_key_in_limits(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
cache=cache,
|
||||
data=data,
|
||||
call_type=call_type,
|
||||
max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for an End-User
|
||||
request_count_api_key=request_count_api_key,
|
||||
current=cache_objects["request_count_end_user_id"],
|
||||
tpm_limit=end_user_tpm_limit,
|
||||
rpm_limit=end_user_rpm_limit,
|
||||
rate_limit_type="customer",
|
||||
values_to_update_in_cache=values_to_update_in_cache,
|
||||
)
|
||||
|
||||
asyncio.create_task(
|
||||
self.internal_usage_cache.async_batch_set_cache(
|
||||
cache_list=values_to_update_in_cache,
|
||||
ttl=60,
|
||||
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
|
||||
) # don't block execution for cache updates
|
||||
)
|
||||
|
||||
return

    async def async_log_success_event(  # noqa: PLR0915
        self, kwargs, response_obj, start_time, end_time
    ):
        from litellm.proxy.common_utils.callback_utils import (
            get_model_group_from_litellm_kwargs,
        )

        litellm_parent_otel_span: Union[Span, None] = _get_parent_otel_span_from_kwargs(
            kwargs=kwargs
        )
        try:
            self.print_verbose("INSIDE parallel request limiter ASYNC SUCCESS LOGGING")

            global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get(
                "global_max_parallel_requests", None
            )
            user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"]
            user_api_key_user_id = kwargs["litellm_params"]["metadata"].get(
                "user_api_key_user_id", None
            )
            user_api_key_team_id = kwargs["litellm_params"]["metadata"].get(
                "user_api_key_team_id", None
            )
            user_api_key_model_max_budget = kwargs["litellm_params"]["metadata"].get(
                "user_api_key_model_max_budget", None
            )
            user_api_key_end_user_id = kwargs.get("user")

            user_api_key_metadata = (
                kwargs["litellm_params"]["metadata"].get("user_api_key_metadata", {})
                or {}
            )

            # ------------
            # Setup values
            # ------------

            if global_max_parallel_requests is not None:
                # get value from cache
                _key = "global_max_parallel_requests"
                # decrement
                await self.internal_usage_cache.async_increment_cache(
                    key=_key,
                    value=-1,
                    local_only=True,
                    litellm_parent_otel_span=litellm_parent_otel_span,
                )

            current_date = datetime.now().strftime("%Y-%m-%d")
            current_hour = datetime.now().strftime("%H")
            current_minute = datetime.now().strftime("%M")
            precise_minute = f"{current_date}-{current_hour}-{current_minute}"

            total_tokens = 0

            if isinstance(response_obj, ModelResponse):
                total_tokens = response_obj.usage.total_tokens  # type: ignore

            # ------------
            # Update usage - API Key
            # ------------

            values_to_update_in_cache = []

            if user_api_key is not None:
                request_count_api_key = (
                    f"{user_api_key}::{precise_minute}::request_count"
                )

                current = await self.internal_usage_cache.async_get_cache(
                    key=request_count_api_key,
                    litellm_parent_otel_span=litellm_parent_otel_span,
                ) or {
                    "current_requests": 1,
                    "current_tpm": 0,
                    "current_rpm": 0,
                }

                new_val = {
                    "current_requests": max(current["current_requests"] - 1, 0),
                    "current_tpm": current["current_tpm"] + total_tokens,
                    "current_rpm": current["current_rpm"],
                }

                self.print_verbose(
                    f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
                )
                values_to_update_in_cache.append((request_count_api_key, new_val))

            # ------------
            # Update usage - model group + API Key
            # ------------
            model_group = get_model_group_from_litellm_kwargs(kwargs)
            if (
                user_api_key is not None
                and model_group is not None
                and (
                    "model_rpm_limit" in user_api_key_metadata
                    or "model_tpm_limit" in user_api_key_metadata
                    or user_api_key_model_max_budget is not None
                )
            ):
                request_count_api_key = (
                    f"{user_api_key}::{model_group}::{precise_minute}::request_count"
                )

                current = await self.internal_usage_cache.async_get_cache(
                    key=request_count_api_key,
                    litellm_parent_otel_span=litellm_parent_otel_span,
                ) or {
                    "current_requests": 1,
                    "current_tpm": 0,
                    "current_rpm": 0,
                }

                new_val = {
                    "current_requests": max(current["current_requests"] - 1, 0),
                    "current_tpm": current["current_tpm"] + total_tokens,
                    "current_rpm": current["current_rpm"],
                }

                self.print_verbose(
                    f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
                )
                values_to_update_in_cache.append((request_count_api_key, new_val))

            # ------------
            # Update usage - User
            # ------------
            if user_api_key_user_id is not None:
                total_tokens = 0

                if isinstance(response_obj, ModelResponse):
                    total_tokens = response_obj.usage.total_tokens  # type: ignore

                request_count_api_key = (
                    f"{user_api_key_user_id}::{precise_minute}::request_count"
                )

                current = await self.internal_usage_cache.async_get_cache(
                    key=request_count_api_key,
                    litellm_parent_otel_span=litellm_parent_otel_span,
                ) or {
                    "current_requests": 1,
                    "current_tpm": total_tokens,
                    "current_rpm": 1,
                }

                new_val = {
                    "current_requests": max(current["current_requests"] - 1, 0),
                    "current_tpm": current["current_tpm"] + total_tokens,
                    "current_rpm": current["current_rpm"],
                }

                self.print_verbose(
                    f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
                )
                values_to_update_in_cache.append((request_count_api_key, new_val))

            # ------------
            # Update usage - Team
            # ------------
            if user_api_key_team_id is not None:
                total_tokens = 0

                if isinstance(response_obj, ModelResponse):
                    total_tokens = response_obj.usage.total_tokens  # type: ignore

                request_count_api_key = (
                    f"{user_api_key_team_id}::{precise_minute}::request_count"
                )

                current = await self.internal_usage_cache.async_get_cache(
                    key=request_count_api_key,
                    litellm_parent_otel_span=litellm_parent_otel_span,
                ) or {
                    "current_requests": 1,
                    "current_tpm": total_tokens,
                    "current_rpm": 1,
                }

                new_val = {
                    "current_requests": max(current["current_requests"] - 1, 0),
                    "current_tpm": current["current_tpm"] + total_tokens,
                    "current_rpm": current["current_rpm"],
                }

                self.print_verbose(
                    f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
                )
                values_to_update_in_cache.append((request_count_api_key, new_val))

            # ------------
            # Update usage - End User
            # ------------
            if user_api_key_end_user_id is not None:
                total_tokens = 0

                if isinstance(response_obj, ModelResponse):
                    total_tokens = response_obj.usage.total_tokens  # type: ignore

                request_count_api_key = (
                    f"{user_api_key_end_user_id}::{precise_minute}::request_count"
                )

                current = await self.internal_usage_cache.async_get_cache(
                    key=request_count_api_key,
                    litellm_parent_otel_span=litellm_parent_otel_span,
                ) or {
                    "current_requests": 1,
                    "current_tpm": total_tokens,
                    "current_rpm": 1,
                }

                new_val = {
                    "current_requests": max(current["current_requests"] - 1, 0),
                    "current_tpm": current["current_tpm"] + total_tokens,
                    "current_rpm": current["current_rpm"],
                }

                self.print_verbose(
                    f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
                )
                values_to_update_in_cache.append((request_count_api_key, new_val))

            await self.internal_usage_cache.async_batch_set_cache(
                cache_list=values_to_update_in_cache,
                ttl=60,
                litellm_parent_otel_span=litellm_parent_otel_span,
            )
        except Exception as e:
            self.print_verbose(e)  # noqa

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        try:
            self.print_verbose("Inside Max Parallel Request Failure Hook")
            litellm_parent_otel_span: Union[
                Span, None
            ] = _get_parent_otel_span_from_kwargs(kwargs=kwargs)
            _metadata = kwargs["litellm_params"].get("metadata", {}) or {}
            global_max_parallel_requests = _metadata.get(
                "global_max_parallel_requests", None
            )
            user_api_key = _metadata.get("user_api_key", None)
            self.print_verbose(f"user_api_key: {user_api_key}")
            if user_api_key is None:
                return

            ## decrement call count if call failed
            if CommonProxyErrors.max_parallel_request_limit_reached.value in str(
                kwargs["exception"]
            ):
                pass  # ignore failed calls due to max limit being reached
            else:
                # ------------
                # Setup values
                # ------------

                if global_max_parallel_requests is not None:
                    # get value from cache
                    _key = "global_max_parallel_requests"
                    (
                        await self.internal_usage_cache.async_get_cache(
                            key=_key,
                            local_only=True,
                            litellm_parent_otel_span=litellm_parent_otel_span,
                        )
                    )
                    # decrement
                    await self.internal_usage_cache.async_increment_cache(
                        key=_key,
                        value=-1,
                        local_only=True,
                        litellm_parent_otel_span=litellm_parent_otel_span,
                    )

                current_date = datetime.now().strftime("%Y-%m-%d")
                current_hour = datetime.now().strftime("%H")
                current_minute = datetime.now().strftime("%M")
                precise_minute = f"{current_date}-{current_hour}-{current_minute}"

                request_count_api_key = (
                    f"{user_api_key}::{precise_minute}::request_count"
                )

                # ------------
                # Update usage
                # ------------
                current = await self.internal_usage_cache.async_get_cache(
                    key=request_count_api_key,
                    litellm_parent_otel_span=litellm_parent_otel_span,
                ) or {
                    "current_requests": 1,
                    "current_tpm": 0,
                    "current_rpm": 0,
                }

                new_val = {
                    "current_requests": max(current["current_requests"] - 1, 0),
                    "current_tpm": current["current_tpm"],
                    "current_rpm": current["current_rpm"],
                }

                self.print_verbose(f"updated_value in failure call: {new_val}")
                await self.internal_usage_cache.async_set_cache(
                    request_count_api_key,
                    new_val,
                    ttl=60,
                    litellm_parent_otel_span=litellm_parent_otel_span,
                )  # save in cache for up to 1 min.
        except Exception as e:
            verbose_proxy_logger.exception(
                "Inside Parallel Request Limiter: An exception occurred - {}".format(
                    str(e)
                )
            )

    async def get_internal_user_object(
        self,
        user_id: str,
        user_api_key_dict: UserAPIKeyAuth,
    ) -> Optional[dict]:
        """
        Helper to get the 'Internal User Object'

        It uses the `get_user_object` function from `litellm.proxy.auth.auth_checks`

        We need this because the UserApiKeyAuth object does not contain the rpm/tpm limits for a User, and re-reading the UserTable on every request could have a perf impact.
        """
        from litellm._logging import verbose_proxy_logger
        from litellm.proxy.auth.auth_checks import get_user_object
        from litellm.proxy.proxy_server import prisma_client

        try:
            _user_id_rate_limits = await get_user_object(
                user_id=user_id,
                prisma_client=prisma_client,
                user_api_key_cache=self.internal_usage_cache.dual_cache,
                user_id_upsert=False,
                parent_otel_span=user_api_key_dict.parent_otel_span,
                proxy_logging_obj=None,
            )

            if _user_id_rate_limits is None:
                return None

            return _user_id_rate_limits.model_dump()
        except Exception as e:
            verbose_proxy_logger.debug(
                "Parallel Request Limiter: Error getting user object - %s", str(e)
            )
            return None

    async def async_post_call_success_hook(
        self, data: dict, user_api_key_dict: UserAPIKeyAuth, response
    ):
        """
        Retrieve the key's remaining rate limits.
        """
        api_key = user_api_key_dict.api_key
        current_date = datetime.now().strftime("%Y-%m-%d")
        current_hour = datetime.now().strftime("%H")
        current_minute = datetime.now().strftime("%M")
        precise_minute = f"{current_date}-{current_hour}-{current_minute}"
        request_count_api_key = f"{api_key}::{precise_minute}::request_count"
        current: Optional[
            CurrentItemRateLimit
        ] = await self.internal_usage_cache.async_get_cache(
            key=request_count_api_key,
            litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
        )

        key_remaining_rpm_limit: Optional[int] = None
        key_rpm_limit: Optional[int] = None
        key_remaining_tpm_limit: Optional[int] = None
        key_tpm_limit: Optional[int] = None
        if current is not None:
            if user_api_key_dict.rpm_limit is not None:
                key_remaining_rpm_limit = (
                    user_api_key_dict.rpm_limit - current["current_rpm"]
                )
                key_rpm_limit = user_api_key_dict.rpm_limit
            if user_api_key_dict.tpm_limit is not None:
                key_remaining_tpm_limit = (
                    user_api_key_dict.tpm_limit - current["current_tpm"]
                )
                key_tpm_limit = user_api_key_dict.tpm_limit

        if hasattr(response, "_hidden_params"):
            _hidden_params = getattr(response, "_hidden_params")
        else:
            _hidden_params = None
        if _hidden_params is not None and (
            isinstance(_hidden_params, BaseModel) or isinstance(_hidden_params, dict)
        ):
            if isinstance(_hidden_params, BaseModel):
                _hidden_params = _hidden_params.model_dump()

            _additional_headers = _hidden_params.get("additional_headers", {}) or {}

            if key_remaining_rpm_limit is not None:
                _additional_headers[
                    "x-ratelimit-remaining-requests"
                ] = key_remaining_rpm_limit
            if key_rpm_limit is not None:
                _additional_headers["x-ratelimit-limit-requests"] = key_rpm_limit
            if key_remaining_tpm_limit is not None:
                _additional_headers[
                    "x-ratelimit-remaining-tokens"
                ] = key_remaining_tpm_limit
            if key_tpm_limit is not None:
                _additional_headers["x-ratelimit-limit-tokens"] = key_tpm_limit

            setattr(
                response,
                "_hidden_params",
                {**_hidden_params, "additional_headers": _additional_headers},
            )

        return await super().async_post_call_success_hook(
            data, user_api_key_dict, response
        )
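
The limiter above buckets every usage counter into a per-minute window: each cache key follows the `{entity_id}::{YYYY-MM-DD-HH-MM}::request_count` scheme, and a successful call decrements the in-flight request count while adding the response's tokens to the minute's TPM. A minimal standalone sketch of that bookkeeping (function names here are hypothetical, not part of the litellm API):

```python
from datetime import datetime
from typing import Optional


def precise_minute_key(entity_id: str, now: Optional[datetime] = None) -> str:
    # Mirrors the "{id}::{YYYY-MM-DD-HH-MM}::request_count" scheme used above.
    now = now or datetime.now()
    precise_minute = now.strftime("%Y-%m-%d-%H-%M")
    return f"{entity_id}::{precise_minute}::request_count"


def record_success(current: Optional[dict], total_tokens: int) -> dict:
    # On success, one in-flight request completes and its tokens count toward TPM;
    # RPM was already incremented when the request was admitted.
    current = current or {"current_requests": 1, "current_tpm": 0, "current_rpm": 0}
    return {
        "current_requests": max(current["current_requests"] - 1, 0),
        "current_tpm": current["current_tpm"] + total_tokens,
        "current_rpm": current["current_rpm"],
    }


key = precise_minute_key("sk-123", datetime(2024, 1, 2, 3, 4))
print(key)  # sk-123::2024-01-02-03-04::request_count
print(record_success({"current_requests": 2, "current_tpm": 10, "current_rpm": 5}, 7))
```

Because the keys embed the minute, counters roll over naturally; the 60-second TTL on the cache writes keeps stale windows from accumulating.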
File diff suppressed because it is too large
@@ -0,0 +1,284 @@
# +------------------------------------+
#
#       Prompt Injection Detection
#
# +------------------------------------+
#  Thank you users! We ❤️ you! - Krrish & Ishaan
## Reject a call if it contains a prompt injection attack.


from difflib import SequenceMatcher
from typing import List, Literal, Optional

from fastapi import HTTPException

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.constants import DEFAULT_PROMPT_INJECTION_SIMILARITY_THRESHOLD
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.prompt_templates.factory import (
    prompt_injection_detection_default_pt,
)
from litellm.proxy._types import LiteLLMPromptInjectionParams, UserAPIKeyAuth
from litellm.router import Router
from litellm.utils import get_formatted_prompt


class _OPTIONAL_PromptInjectionDetection(CustomLogger):
    # Class variables or attributes
    def __init__(
        self,
        prompt_injection_params: Optional[LiteLLMPromptInjectionParams] = None,
    ):
        self.prompt_injection_params = prompt_injection_params
        self.llm_router: Optional[Router] = None

        self.verbs = [
            "Ignore",
            "Disregard",
            "Skip",
            "Forget",
            "Neglect",
            "Overlook",
            "Omit",
            "Bypass",
            "Pay no attention to",
            "Do not follow",
            "Do not obey",
        ]
        self.adjectives = [
            "",
            "prior",
            "previous",
            "preceding",
            "above",
            "foregoing",
            "earlier",
            "initial",
        ]
        self.prepositions = [
            "",
            "and start over",
            "and start anew",
            "and begin afresh",
            "and start from scratch",
        ]

    def print_verbose(self, print_statement, level: Literal["INFO", "DEBUG"] = "DEBUG"):
        if level == "INFO":
            verbose_proxy_logger.info(print_statement)
        elif level == "DEBUG":
            verbose_proxy_logger.debug(print_statement)

        if litellm.set_verbose is True:
            print(print_statement)  # noqa

    def update_environment(self, router: Optional[Router] = None):
        self.llm_router = router

        if (
            self.prompt_injection_params is not None
            and self.prompt_injection_params.llm_api_check is True
        ):
            if self.llm_router is None:
                raise Exception(
                    "PromptInjectionDetection: Model List not set. Required for Prompt Injection detection."
                )

            self.print_verbose(
                f"model_names: {self.llm_router.model_names}; self.prompt_injection_params.llm_api_name: {self.prompt_injection_params.llm_api_name}"
            )
            if (
                self.prompt_injection_params.llm_api_name is None
                or self.prompt_injection_params.llm_api_name
                not in self.llm_router.model_names
            ):
                raise Exception(
                    "PromptInjectionDetection: Invalid LLM API Name. LLM API Name must be a 'model_name' in 'model_list'."
                )

    def generate_injection_keywords(self) -> List[str]:
        combinations = []
        for verb in self.verbs:
            for adj in self.adjectives:
                for prep in self.prepositions:
                    phrase = " ".join(filter(None, [verb, adj, prep])).strip()
                    if (
                        len(phrase.split()) > 2
                    ):  # additional check to ensure more than 2 words
                        combinations.append(phrase.lower())
        return combinations

    def check_user_input_similarity(
        self,
        user_input: str,
        similarity_threshold: float = DEFAULT_PROMPT_INJECTION_SIMILARITY_THRESHOLD,
    ) -> bool:
        user_input_lower = user_input.lower()
        keywords = self.generate_injection_keywords()

        for keyword in keywords:
            # Calculate the length of the keyword to extract substrings of the same length from user input
            keyword_length = len(keyword)

            for i in range(len(user_input_lower) - keyword_length + 1):
                # Extract a substring of the same length as the keyword
                substring = user_input_lower[i : i + keyword_length]

                # Calculate similarity
                match_ratio = SequenceMatcher(None, substring, keyword).ratio()
                if match_ratio > similarity_threshold:
                    self.print_verbose(
                        print_statement=f"Rejected user input - {user_input}. {match_ratio} similar to {keyword}",
                        level="INFO",
                    )
                    return True  # Found a highly similar substring
        return False  # No substring crossed the threshold

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,  # "completion", "embeddings", "image_generation", "moderation"
    ):
        try:
            """
            - check if user id part of call
            - check if user id part of blocked list
            """
            self.print_verbose("Inside Prompt Injection Detection Pre-Call Hook")
            try:
                assert call_type in [
                    "acompletion",
                    "completion",
                    "text_completion",
                    "embeddings",
                    "image_generation",
                    "moderation",
                    "audio_transcription",
                ]
            except Exception:
                self.print_verbose(
                    f"Call Type - {call_type}, not in accepted list - ['completion','embeddings','image_generation','moderation','audio_transcription']"
                )
                return data
            formatted_prompt = get_formatted_prompt(data=data, call_type=call_type)  # type: ignore

            is_prompt_attack = False

            if self.prompt_injection_params is not None:
                # 1. check if heuristics check turned on
                if self.prompt_injection_params.heuristics_check is True:
                    is_prompt_attack = self.check_user_input_similarity(
                        user_input=formatted_prompt
                    )
                    if is_prompt_attack is True:
                        raise HTTPException(
                            status_code=400,
                            detail={
                                "error": "Rejected message. This is a prompt injection attack."
                            },
                        )
                # 2. check if vector db similarity check turned on [TODO] Not Implemented yet
                if self.prompt_injection_params.vector_db_check is True:
                    pass
            else:
                is_prompt_attack = self.check_user_input_similarity(
                    user_input=formatted_prompt
                )

                if is_prompt_attack is True:
                    raise HTTPException(
                        status_code=400,
                        detail={
                            "error": "Rejected message. This is a prompt injection attack."
                        },
                    )

            return data

        except HTTPException as e:
            if (
                e.status_code == 400
                and isinstance(e.detail, dict)
                and "error" in e.detail  # type: ignore
                and self.prompt_injection_params is not None
                and self.prompt_injection_params.reject_as_response
            ):
                return e.detail.get("error")
            raise e
        except Exception as e:
            verbose_proxy_logger.exception(
                "litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occurred - {}".format(
                    str(e)
                )
            )

    async def async_moderation_hook(  # type: ignore
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        call_type: Literal[
            "acompletion",
            "completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
        ],
    ) -> Optional[bool]:
        self.print_verbose(
            f"IN ASYNC MODERATION HOOK - self.prompt_injection_params = {self.prompt_injection_params}"
        )

        if self.prompt_injection_params is None:
            return None

        formatted_prompt = get_formatted_prompt(data=data, call_type=call_type)  # type: ignore
        is_prompt_attack = False

        prompt_injection_system_prompt = getattr(
            self.prompt_injection_params,
            "llm_api_system_prompt",
            prompt_injection_detection_default_pt(),
        )

        # 3. check if llm api check turned on
        if (
            self.prompt_injection_params.llm_api_check is True
            and self.prompt_injection_params.llm_api_name is not None
            and self.llm_router is not None
        ):
            # make a call to the llm api
            response = await self.llm_router.acompletion(
                model=self.prompt_injection_params.llm_api_name,
                messages=[
                    {
                        "role": "system",
                        "content": prompt_injection_system_prompt,
                    },
                    {"role": "user", "content": formatted_prompt},
                ],
            )

            self.print_verbose(f"Received LLM Moderation response: {response}")
            self.print_verbose(
                f"llm_api_fail_call_string: {self.prompt_injection_params.llm_api_fail_call_string}"
            )
            if isinstance(response, litellm.ModelResponse) and isinstance(
                response.choices[0], litellm.Choices
            ):
                if self.prompt_injection_params.llm_api_fail_call_string in response.choices[0].message.content:  # type: ignore
                    is_prompt_attack = True

        if is_prompt_attack is True:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": "Rejected message. This is a prompt injection attack."
                },
            )

        return is_prompt_attack
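
The heuristic check in `check_user_input_similarity` above slides a keyword-length window across the lowered input and flags any substring whose `SequenceMatcher` ratio exceeds the threshold. A minimal standalone sketch of that core loop (the function name is hypothetical):

```python
from difflib import SequenceMatcher


def contains_similar_phrase(user_input: str, keyword: str, threshold: float = 0.7) -> bool:
    # Slide a window the length of the keyword across the lowered input and
    # compare each same-length substring against the keyword.
    text = user_input.lower()
    k = len(keyword)
    for i in range(len(text) - k + 1):
        if SequenceMatcher(None, text[i : i + k], keyword).ratio() > threshold:
            return True
    return False


print(contains_similar_phrase("please ignore all previous instructions", "ignore previous instructions"))
```

Note the cost: for each of the generated keywords, the loop computes a ratio at every input offset, so long prompts against many keywords can get expensive; the early `return True` on the first hit is what keeps the common (attack) case cheap.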
@@ -0,0 +1,370 @@
import asyncio
import traceback
from datetime import datetime
from typing import Any, List, Optional, Union, cast

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.core_helpers import (
    _get_parent_otel_span_from_kwargs,
    get_litellm_metadata_from_kwargs,
)
from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.auth_checks import (
    get_key_object,
    get_team_object,
    log_db_metrics,
)
from litellm.proxy.auth.route_checks import RouteChecks
from litellm.proxy.utils import ProxyUpdateSpend
from litellm.types.utils import (
    StandardLoggingPayload,
    StandardLoggingUserAPIKeyMetadata,
)
from litellm.utils import get_end_user_id_for_cost_tracking


class _ProxyDBLogger(CustomLogger):
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        await self._PROXY_track_cost_callback(
            kwargs, response_obj, start_time, end_time
        )

    async def async_post_call_failure_hook(
        self,
        request_data: dict,
        original_exception: Exception,
        user_api_key_dict: UserAPIKeyAuth,
        traceback_str: Optional[str] = None,
    ):
        request_route = user_api_key_dict.request_route
        if _ProxyDBLogger._should_track_errors_in_db() is False:
            return
        elif request_route is not None and not (
            RouteChecks.is_llm_api_route(route=request_route)
            or RouteChecks.is_info_route(route=request_route)
        ):
            return

        from litellm.proxy.proxy_server import proxy_logging_obj

        _metadata = dict(
            StandardLoggingUserAPIKeyMetadata(
                user_api_key_hash=user_api_key_dict.api_key,
                user_api_key_alias=user_api_key_dict.key_alias,
                user_api_key_spend=user_api_key_dict.spend,
                user_api_key_max_budget=user_api_key_dict.max_budget,
                user_api_key_budget_reset_at=(
                    user_api_key_dict.budget_reset_at.isoformat()
                    if user_api_key_dict.budget_reset_at
                    else None
                ),
                user_api_key_user_email=user_api_key_dict.user_email,
                user_api_key_user_id=user_api_key_dict.user_id,
                user_api_key_team_id=user_api_key_dict.team_id,
                user_api_key_org_id=user_api_key_dict.org_id,
                user_api_key_project_id=user_api_key_dict.project_id,
                user_api_key_team_alias=user_api_key_dict.team_alias,
                user_api_key_end_user_id=user_api_key_dict.end_user_id,
                user_api_key_request_route=user_api_key_dict.request_route,
                user_api_key_auth_metadata=user_api_key_dict.metadata,
            )
        )
        _metadata["user_api_key"] = user_api_key_dict.api_key
        _metadata["status"] = "failure"
        _metadata[
            "error_information"
        ] = StandardLoggingPayloadSetup.get_error_information(
            original_exception=original_exception,
            traceback_str=traceback_str,
        )

        _metadata = await _ProxyDBLogger._enrich_failure_metadata_with_key_info(
            metadata=_metadata,
        )

        existing_metadata: dict = request_data.get("metadata", None) or {}
        existing_metadata.update(_metadata)

        if "litellm_params" not in request_data:
            request_data["litellm_params"] = {}

        existing_litellm_params = request_data.get("litellm_params", {})
        existing_litellm_metadata = existing_litellm_params.get("metadata", {}) or {}

        # Preserve tags from existing metadata
        if existing_litellm_metadata.get("tags"):
            existing_metadata["tags"] = existing_litellm_metadata.get("tags")

        request_data["litellm_params"]["proxy_server_request"] = (
            request_data.get("proxy_server_request")
|
||||
or existing_litellm_params.get("proxy_server_request")
|
||||
or {}
|
||||
)
|
||||
request_data["litellm_params"]["metadata"] = existing_metadata
|
||||
|
||||
# Preserve model name and custom_llm_provider
|
||||
if "model" not in request_data:
|
||||
request_data["model"] = existing_litellm_params.get(
|
||||
"model"
|
||||
) or request_data.get("model", "")
|
||||
if "custom_llm_provider" not in request_data:
|
||||
request_data["custom_llm_provider"] = existing_litellm_params.get(
|
||||
"custom_llm_provider"
|
||||
) or request_data.get("custom_llm_provider", "")
|
||||
|
||||
await proxy_logging_obj.db_spend_update_writer.update_database(
|
||||
token=user_api_key_dict.api_key,
|
||||
response_cost=0.0,
|
||||
user_id=user_api_key_dict.user_id,
|
||||
end_user_id=user_api_key_dict.end_user_id,
|
||||
team_id=user_api_key_dict.team_id,
|
||||
kwargs=request_data,
|
||||
completion_response=original_exception,
|
||||
start_time=datetime.now(),
|
||||
end_time=datetime.now(),
|
||||
org_id=user_api_key_dict.org_id,
|
||||
)
|
||||
|
||||
@log_db_metrics
|
||||
async def _PROXY_track_cost_callback(
|
||||
self,
|
||||
kwargs, # kwargs to completion
|
||||
completion_response: Optional[
|
||||
Union[litellm.ModelResponse, Any]
|
||||
], # response from completion
|
||||
start_time=None,
|
||||
end_time=None, # start/end time for completion
|
||||
):
|
||||
from litellm.proxy.proxy_server import proxy_logging_obj, update_cache
|
||||
|
||||
verbose_proxy_logger.debug("INSIDE _PROXY_track_cost_callback")
|
||||
try:
|
||||
verbose_proxy_logger.debug(
|
||||
f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}"
|
||||
)
|
||||
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs)
|
||||
litellm_params = kwargs.get("litellm_params", {}) or {}
|
||||
end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
|
||||
metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
|
||||
user_id = cast(Optional[str], metadata.get("user_api_key_user_id", None))
|
||||
team_id = cast(Optional[str], metadata.get("user_api_key_team_id", None))
|
||||
org_id = cast(Optional[str], metadata.get("user_api_key_org_id", None))
|
||||
key_alias = cast(Optional[str], metadata.get("user_api_key_alias", None))
|
||||
end_user_max_budget = metadata.get("user_api_end_user_max_budget", None)
|
||||
sl_object: Optional[StandardLoggingPayload] = kwargs.get(
|
||||
"standard_logging_object", None
|
||||
)
|
||||
response_cost = (
|
||||
sl_object.get("response_cost", None)
|
||||
if sl_object is not None
|
||||
else kwargs.get("response_cost", None)
|
||||
)
|
||||
tags: Optional[List[str]] = (
|
||||
sl_object.get("request_tags", None) if sl_object is not None else None
|
||||
)
|
||||
|
||||
if response_cost is not None:
|
||||
user_api_key = metadata.get("user_api_key", None)
|
||||
if kwargs.get("cache_hit", False) is True:
|
||||
response_cost = 0.0
|
||||
verbose_proxy_logger.debug(
|
||||
f"Cache Hit: response_cost {response_cost}, for user_id {user_id}"
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f"user_api_key {user_api_key}, user_id {user_id}, team_id {team_id}, end_user_id {end_user_id}"
|
||||
)
|
||||
if _should_track_cost_callback(
|
||||
user_api_key=user_api_key,
|
||||
user_id=user_id,
|
||||
team_id=team_id,
|
||||
end_user_id=end_user_id,
|
||||
):
|
||||
## UPDATE DATABASE
|
||||
await proxy_logging_obj.db_spend_update_writer.update_database(
|
||||
token=user_api_key,
|
||||
response_cost=response_cost,
|
||||
user_id=user_id,
|
||||
end_user_id=end_user_id,
|
||||
team_id=team_id,
|
||||
kwargs=kwargs,
|
||||
completion_response=completion_response,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
org_id=org_id,
|
||||
)
|
||||
|
||||
# update cache
|
||||
asyncio.create_task(
|
||||
update_cache(
|
||||
token=user_api_key,
|
||||
user_id=user_id,
|
||||
end_user_id=end_user_id,
|
||||
response_cost=response_cost,
|
||||
team_id=team_id,
|
||||
parent_otel_span=parent_otel_span,
|
||||
tags=tags,
|
||||
)
|
||||
)
|
||||
|
||||
await proxy_logging_obj.slack_alerting_instance.customer_spend_alert(
|
||||
token=user_api_key,
|
||||
key_alias=key_alias,
|
||||
end_user_id=end_user_id,
|
||||
response_cost=response_cost,
|
||||
max_budget=end_user_max_budget,
|
||||
)
|
||||
else:
|
||||
# Non-model call types (health checks, afile_delete) have no model or standard_logging_object.
|
||||
# Use .get() for "stream" to avoid KeyError on health checks.
|
||||
if sl_object is None and not kwargs.get("model"):
|
||||
verbose_proxy_logger.warning(
|
||||
"Cost tracking - skipping, no standard_logging_object and no model for call_type=%s",
|
||||
kwargs.get("call_type", "unknown"),
|
||||
)
|
||||
return
|
||||
if kwargs.get("stream") is not True or (
|
||||
kwargs.get("stream") is True
|
||||
and "complete_streaming_response" in kwargs
|
||||
):
|
||||
if sl_object is not None:
|
||||
cost_tracking_failure_debug_info: Union[dict, str] = (
|
||||
sl_object["response_cost_failure_debug_info"] # type: ignore
|
||||
or "response_cost_failure_debug_info is None in standard_logging_object"
|
||||
)
|
||||
else:
|
||||
cost_tracking_failure_debug_info = (
|
||||
"standard_logging_object not found"
|
||||
)
|
||||
model = kwargs.get("model")
|
||||
raise Exception(
|
||||
f"Cost tracking failed for model={model}.\nDebug info - {cost_tracking_failure_debug_info}\nAdd custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = f"Error in tracking cost callback - {str(e)}\n Traceback:{traceback.format_exc()}"
|
||||
model = kwargs.get("model", "")
|
||||
metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
|
||||
litellm_metadata = kwargs.get("litellm_params", {}).get(
|
||||
"litellm_metadata", {}
|
||||
)
|
||||
old_metadata = kwargs.get("litellm_params", {}).get("metadata", {})
|
||||
call_type = kwargs.get("call_type", "")
|
||||
error_msg += f"\n Args to _PROXY_track_cost_callback\n model: {model}\n chosen_metadata: {metadata}\n litellm_metadata: {litellm_metadata}\n old_metadata: {old_metadata}\n call_type: {call_type}\n"
|
||||
asyncio.create_task(
|
||||
proxy_logging_obj.failed_tracking_alert(
|
||||
error_message=error_msg,
|
||||
failing_model=model,
|
||||
)
|
||||
)
|
||||
|
||||
verbose_proxy_logger.exception(
|
||||
"Error in tracking cost callback - %s", str(e)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _enrich_failure_metadata_with_key_info(metadata: dict) -> dict:
|
||||
"""
|
||||
Enriches failure spend log metadata by looking up the key object (and team object)
|
||||
from cache/DB when key fields are missing.
|
||||
|
||||
This handles two scenarios:
|
||||
1. Auth errors (401): UserAPIKeyAuth is created with only api_key set, all other
|
||||
fields are null. We look up the full key object to fill in alias, user_id,
|
||||
team_id, etc.
|
||||
2. Post-auth failures (provider errors, rate limits): key fields are populated
|
||||
but team_alias is missing because LiteLLM_VerificationTokenView SQL view
|
||||
doesn't include it. We look up the team object to fill in team_alias.
|
||||
"""
|
||||
api_key_hash = metadata.get("user_api_key")
|
||||
if not api_key_hash:
|
||||
return metadata
|
||||
|
||||
from litellm.proxy.proxy_server import (
|
||||
prisma_client,
|
||||
proxy_logging_obj,
|
||||
user_api_key_cache,
|
||||
)
|
||||
|
||||
# Step 1: If key fields are missing, look up the full key object
|
||||
if metadata.get("user_api_key_alias") is None:
|
||||
try:
|
||||
key_obj = await get_key_object(
|
||||
hashed_token=api_key_hash,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
if metadata.get("user_api_key_alias") is None:
|
||||
metadata["user_api_key_alias"] = key_obj.key_alias
|
||||
if metadata.get("user_api_key_user_id") is None:
|
||||
metadata["user_api_key_user_id"] = key_obj.user_id
|
||||
if metadata.get("user_api_key_team_id") is None:
|
||||
metadata["user_api_key_team_id"] = key_obj.team_id
|
||||
if metadata.get("user_api_key_org_id") is None:
|
||||
metadata["user_api_key_org_id"] = key_obj.org_id
|
||||
except Exception:
|
||||
verbose_proxy_logger.debug(
|
||||
"Failed to enrich failure metadata with key info for api_key=%s",
|
||||
api_key_hash,
|
||||
)
|
||||
|
||||
# Step 2: If team_id is known but team_alias is missing, look up the team object
|
||||
team_id = metadata.get("user_api_key_team_id")
|
||||
if team_id and metadata.get("user_api_key_team_alias") is None:
|
||||
try:
|
||||
team_obj = await get_team_object(
|
||||
team_id=team_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
if team_obj.team_alias is not None:
|
||||
metadata["user_api_key_team_alias"] = team_obj.team_alias
|
||||
except Exception:
|
||||
verbose_proxy_logger.debug(
|
||||
"Failed to enrich failure metadata with team_alias for team_id=%s",
|
||||
team_id,
|
||||
)
|
||||
return metadata
|
||||
|
||||
@staticmethod
|
||||
def _should_track_errors_in_db():
|
||||
"""
|
||||
Returns True if errors should be tracked in the database
|
||||
|
||||
By default, errors are tracked in the database
|
||||
|
||||
If users want to disable error tracking, they can set the disable_error_logs flag in the general_settings
|
||||
"""
|
||||
from litellm.proxy.proxy_server import general_settings
|
||||
|
||||
if general_settings.get("disable_error_logs") is True:
|
||||
return False
|
||||
return
|
||||
|
||||
|
||||
def _should_track_cost_callback(
|
||||
user_api_key: Optional[str],
|
||||
user_id: Optional[str],
|
||||
team_id: Optional[str],
|
||||
end_user_id: Optional[str],
|
||||
) -> bool:
|
||||
"""
|
||||
Determine if the cost callback should be tracked based on the kwargs
|
||||
"""
|
||||
|
||||
# don't run track cost callback if user opted into disabling spend
|
||||
if ProxyUpdateSpend.disable_spend_updates() is True:
|
||||
return False
|
||||
|
||||
if (
|
||||
user_api_key is not None
|
||||
or user_id is not None
|
||||
or team_id is not None
|
||||
or end_user_id is not None
|
||||
):
|
||||
return True
|
||||
return False
|
||||
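The spend-tracking gate in `_should_track_cost_callback` above reduces to: skip when spend updates are disabled, otherwise track if at least one of key/user/team/end-user is known. A minimal standalone sketch of that decision (function and parameter names here are illustrative; `disable_spend_updates` stands in for `ProxyUpdateSpend.disable_spend_updates()`):

```python
from typing import Optional


def should_track_cost(
    user_api_key: Optional[str],
    user_id: Optional[str],
    team_id: Optional[str],
    end_user_id: Optional[str],
    disable_spend_updates: bool = False,  # stand-in for ProxyUpdateSpend.disable_spend_updates()
) -> bool:
    """Track cost only when spend updates are enabled and at least one identity is known."""
    if disable_spend_updates:
        return False
    return any(v is not None for v in (user_api_key, user_id, team_id, end_user_id))


# A request carrying only a team_id is still tracked; a fully anonymous one is not,
# and the disable flag wins even when identities are present.
print(should_track_cost(None, None, "team-1", None))         # True
print(should_track_cost(None, None, None, None))             # False
print(should_track_cost("sk-hash", None, None, None, True))  # False
```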
@@ -0,0 +1,45 @@
"""
Shared utility functions for rate limiter hooks.
"""

from typing import Optional, Union

from litellm.types.router import ModelGroupInfo
from litellm.types.utils import PriorityReservationDict


def convert_priority_to_percent(
    value: Union[float, PriorityReservationDict], model_info: Optional[ModelGroupInfo]
) -> float:
    """
    Convert a priority reservation value to a percentage (0.0-1.0).

    Supports four formats:
    1. Plain float/int: 0.9 -> 0.9 (90%)
    2. Dict with percent: {"type": "percent", "value": 0.9} -> 0.9
    3. Dict with rpm: {"type": "rpm", "value": 900} -> 900/model_rpm
    4. Dict with tpm: {"type": "tpm", "value": 900000} -> 900000/model_tpm

    Args:
        value: Priority value as float or dict with type/value keys
        model_info: Model configuration containing rpm/tpm limits

    Returns:
        float: Percentage value between 0.0 and 1.0
    """
    if isinstance(value, (int, float)):
        return float(value)

    if isinstance(value, dict):
        val_type = value.get("type", "percent")
        val_num = value.get("value", 1.0)

        if val_type == "percent":
            return float(val_num)
        elif val_type == "rpm" and model_info and model_info.rpm and model_info.rpm > 0:
            return float(val_num) / model_info.rpm
        elif val_type == "tpm" and model_info and model_info.tpm and model_info.tpm > 0:
            return float(val_num) / model_info.tpm

        # Fallback: treat as percent
        return float(val_num)

    # Unknown value type: treat as fully reserved
    return 1.0
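`convert_priority_to_percent` normalizes the supported value shapes against the model's configured limits. A self-contained sketch of the same conversion rules, using a plain dataclass in place of litellm's `ModelGroupInfo` (all names here are illustrative):

```python
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class FakeModelInfo:  # stand-in for litellm's ModelGroupInfo
    rpm: Optional[int] = None
    tpm: Optional[int] = None


def to_percent(value: Union[float, dict], model_info: Optional[FakeModelInfo]) -> float:
    """Mirror of the conversion rules: numbers pass through; rpm/tpm dicts divide by the model limit."""
    if isinstance(value, (int, float)):
        return float(value)
    val_type = value.get("type", "percent")
    val_num = value.get("value", 1.0)
    if val_type == "rpm" and model_info and model_info.rpm:
        return float(val_num) / model_info.rpm
    if val_type == "tpm" and model_info and model_info.tpm:
        return float(val_num) / model_info.tpm
    return float(val_num)  # "percent" and fallback


info = FakeModelInfo(rpm=1000, tpm=1_000_000)
print(to_percent(0.9, info))
print(to_percent({"type": "rpm", "value": 900}, info))
print(to_percent({"type": "tpm", "value": 250_000}, info))
```

All three calls yield the same kind of answer: the fraction of the model's capacity reserved for the priority, which is what the strict-mode normalization operates on.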
@@ -0,0 +1,296 @@
"""
Security hook to prevent user B from seeing a response from user A.

Response IDs returned by the Responses API are encrypted together with the
requesting user_id/team_id, so later reads can verify that the caller is
allowed to access the response.
"""

from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional, Tuple, Union, cast

from fastapi import HTTPException

from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import LitellmUserRoles
from litellm.proxy.common_utils.encrypt_decrypt_utils import (
    decrypt_value_helper,
    encrypt_value_helper,
)
from litellm.types.llms.openai import (
    BaseLiteLLMOpenAIResponseObject,
    ResponsesAPIResponse,
)
from litellm.types.utils import CallTypesLiteral, LLMResponseTypes, SpecialEnums

if TYPE_CHECKING:
    from litellm.caching.caching import DualCache
    from litellm.proxy._types import UserAPIKeyAuth


class ResponsesIDSecurity(CustomLogger):
    def __init__(self):
        pass

    async def async_pre_call_hook(
        self,
        user_api_key_dict: "UserAPIKeyAuth",
        cache: "DualCache",
        data: dict,
        call_type: CallTypesLiteral,
    ) -> Optional[Union[Exception, str, dict]]:
        # Map all Responses API response ids back from their encrypted form
        responses_api_call_types = {
            "aresponses",
            "aget_responses",
            "adelete_responses",
            "acancel_responses",
        }
        if call_type not in responses_api_call_types:
            return None
        if call_type == "aresponses":
            # check 'previous_response_id' if present in the data
            previous_response_id = data.get("previous_response_id")
            if previous_response_id and self._is_encrypted_response_id(
                previous_response_id
            ):
                original_response_id, user_id, team_id = self._decrypt_response_id(
                    previous_response_id
                )
                self.check_user_access_to_response_id(
                    user_id, team_id, user_api_key_dict
                )
                data["previous_response_id"] = original_response_id
        elif call_type in {"aget_responses", "adelete_responses", "acancel_responses"}:
            response_id = data.get("response_id")

            if response_id and self._is_encrypted_response_id(response_id):
                original_response_id, user_id, team_id = self._decrypt_response_id(
                    response_id
                )

                self.check_user_access_to_response_id(
                    user_id, team_id, user_api_key_dict
                )
                data["response_id"] = original_response_id
        return data

    def check_user_access_to_response_id(
        self,
        response_id_user_id: Optional[str],
        response_id_team_id: Optional[str],
        user_api_key_dict: "UserAPIKeyAuth",
    ) -> bool:
        from litellm.proxy.proxy_server import general_settings

        if (
            user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value
            or user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN
        ):
            return True

        if response_id_user_id and response_id_user_id != user_api_key_dict.user_id:
            if general_settings.get("disable_responses_id_security", False):
                verbose_proxy_logger.debug(
                    f"Responses ID Security is disabled. User {user_api_key_dict.user_id} is accessing response id {response_id_user_id} which is not associated with them."
                )
                return True
            raise HTTPException(
                status_code=403,
                detail="Forbidden. The response id is not associated with the user this key belongs to. To disable this security feature, set general_settings::disable_responses_id_security to True in the config.yaml file.",
            )

        if response_id_team_id and response_id_team_id != user_api_key_dict.team_id:
            if general_settings.get("disable_responses_id_security", False):
                verbose_proxy_logger.debug(
                    f"Responses ID Security is disabled. Response belongs to team {response_id_team_id} but user {user_api_key_dict.user_id} is accessing it with team id {user_api_key_dict.team_id}."
                )
                return True
            raise HTTPException(
                status_code=403,
                detail="Forbidden. The response id is not associated with the team this key belongs to. To disable this security feature, set general_settings::disable_responses_id_security to True in the config.yaml file.",
            )

        return True

    def _is_encrypted_response_id(self, response_id: str) -> bool:
        split_result = response_id.split("resp_")
        if len(split_result) < 2:
            return False

        remaining_string = split_result[1]
        decrypted_value = decrypt_value_helper(
            value=remaining_string, key="response_id", return_original_value=True
        )

        if decrypted_value is None:
            return False

        if decrypted_value.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value):
            return True
        return False

    def _decrypt_response_id(
        self, response_id: str
    ) -> Tuple[str, Optional[str], Optional[str]]:
        """
        Returns:
        - original_response_id: the original response id
        - user_id: the user id
        - team_id: the team id
        """
        split_result = response_id.split("resp_")
        if len(split_result) < 2:
            return response_id, None, None

        remaining_string = split_result[1]
        decrypted_value = decrypt_value_helper(
            value=remaining_string, key="response_id", return_original_value=True
        )

        if decrypted_value is None:
            return response_id, None, None

        if decrypted_value.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value):
            # Expected format: "litellm_proxy:responses_api:response_id:{response_id};user_id:{user_id};team_id:{team_id}"
            parts = decrypted_value.split(";")

            if len(parts) >= 2:
                # Extract response_id from "litellm_proxy:responses_api:response_id:{response_id}"
                original_response_id = parts[0].split("response_id:")[-1]

                # Extract user_id from "user_id:{user_id}"
                user_id = parts[1].split("user_id:")[-1]

                # Extract team_id from "team_id:{team_id}" (may be absent)
                team_id = parts[2].split("team_id:")[-1] if len(parts) >= 3 else None

                return original_response_id, user_id, team_id
            else:
                # Fallback if format is unexpected
                return response_id, None, None
        return response_id, None, None

    def _get_signing_key(self) -> Optional[str]:
        """Get the signing key for encryption/decryption."""
        import os

        from litellm.proxy.proxy_server import master_key

        salt_key = os.getenv("LITELLM_SALT_KEY", None)
        if salt_key is None:
            salt_key = master_key
        return salt_key

    def _encrypt_response_id(
        self,
        response: BaseLiteLLMOpenAIResponseObject,
        user_api_key_dict: "UserAPIKeyAuth",
        request_cache: Optional[dict[str, str]] = None,
    ) -> BaseLiteLLMOpenAIResponseObject:
        # Encrypt the response id using the symmetric key, encoding the
        # user id and team id alongside the original response id.

        # Check if a signing key is available
        signing_key = self._get_signing_key()
        if signing_key is None:
            verbose_proxy_logger.debug(
                "Response ID encryption is enabled but no signing key is configured. "
                "Please set the LITELLM_SALT_KEY environment variable or configure a master_key. "
                "Skipping response ID encryption. "
                "See: https://docs.litellm.ai/docs/proxy/prod#5-set-litellm-salt-key"
            )
            return response

        response_id = getattr(response, "id", None)
        response_obj = getattr(response, "response", None)

        if (
            response_id
            and isinstance(response_id, str)
            and response_id.startswith("resp_")
        ):
            # Check request-scoped cache first (for streaming consistency)
            if request_cache is not None and response_id in request_cache:
                setattr(response, "id", request_cache[response_id])
            else:
                encrypted_response_id = SpecialEnums.LITELLM_MANAGED_RESPONSE_API_RESPONSE_ID_COMPLETE_STR.value.format(
                    response_id,
                    user_api_key_dict.user_id or "",
                    user_api_key_dict.team_id or "",
                )

                encoded_user_id_and_response_id = encrypt_value_helper(
                    value=encrypted_response_id
                )
                encrypted_id = f"resp_{encoded_user_id_and_response_id}"
                if request_cache is not None:
                    request_cache[response_id] = encrypted_id
                setattr(response, "id", encrypted_id)

        elif response_obj and isinstance(response_obj, ResponsesAPIResponse):
            # Check request-scoped cache first (for streaming consistency)
            if request_cache is not None and response_obj.id in request_cache:
                setattr(response_obj, "id", request_cache[response_obj.id])
            else:
                encrypted_response_id = SpecialEnums.LITELLM_MANAGED_RESPONSE_API_RESPONSE_ID_COMPLETE_STR.value.format(
                    response_obj.id,
                    user_api_key_dict.user_id or "",
                    user_api_key_dict.team_id or "",
                )
                encoded_user_id_and_response_id = encrypt_value_helper(
                    value=encrypted_response_id
                )
                encrypted_id = f"resp_{encoded_user_id_and_response_id}"
                if request_cache is not None:
                    request_cache[response_obj.id] = encrypted_id
                setattr(response_obj, "id", encrypted_id)
            setattr(response, "response", response_obj)
        return response

    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: "UserAPIKeyAuth",
        response: LLMResponseTypes,
    ) -> Any:
        """
        Encrypt the response id on successful Responses API calls, so the id
        returned to the client is bound to the calling user/team.
        """
        from litellm.proxy.proxy_server import general_settings

        if general_settings.get("disable_responses_id_security", False):
            return response
        if isinstance(response, ResponsesAPIResponse):
            response = cast(
                ResponsesAPIResponse,
                self._encrypt_response_id(
                    response, user_api_key_dict, request_cache=None
                ),
            )
        return response

    async def async_post_call_streaming_iterator_hook(  # type: ignore
        self, user_api_key_dict: "UserAPIKeyAuth", response: Any, request_data: dict
    ) -> AsyncGenerator[BaseLiteLLMOpenAIResponseObject, None]:
        from litellm.proxy.proxy_server import general_settings

        # Create a request-scoped cache for consistent encryption across streaming chunks.
        request_encryption_cache: dict[str, str] = {}

        async for chunk in response:
            if (
                isinstance(chunk, BaseLiteLLMOpenAIResponseObject)
                and user_api_key_dict.request_route
                == "/v1/responses"  # only encrypt the response id for the responses api
                and not general_settings.get("disable_responses_id_security", False)
            ):
                chunk = self._encrypt_response_id(
                    chunk, user_api_key_dict, request_encryption_cache
                )
            yield chunk
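The encrypt/decrypt pair above packs the original id plus the caller's identity into a single `;`-delimited string before symmetric encryption. A sketch of just the pack/unpack step — the delimiter layout follows the format quoted in `_decrypt_response_id`; the prefix constant here is illustrative, and the real code additionally encrypts the packed value:

```python
from typing import Optional, Tuple

# Illustrative prefix mirroring the packed format in _decrypt_response_id
PREFIX = "litellm_proxy:responses_api"


def pack_response_id(response_id: str, user_id: str, team_id: str) -> str:
    """Pack id + identity into the ';'-delimited layout used before encryption."""
    return f"{PREFIX}:response_id:{response_id};user_id:{user_id};team_id:{team_id}"


def unpack_response_id(packed: str) -> Tuple[str, Optional[str], Optional[str]]:
    """Recover (response_id, user_id, team_id); team_id may be absent in older ids."""
    parts = packed.split(";")
    if len(parts) < 2:
        return packed, None, None
    response_id = parts[0].split("response_id:")[-1]
    user_id = parts[1].split("user_id:")[-1]
    team_id = parts[2].split("team_id:")[-1] if len(parts) >= 3 else None
    return response_id, user_id, team_id


packed = pack_response_id("resp_abc123", "user-1", "team-9")
print(unpack_response_id(packed))  # ('resp_abc123', 'user-1', 'team-9')
```

Guarding the `team_id` segment by length (rather than indexing `parts[2]` unconditionally) keeps the unpack tolerant of ids minted before the team field existed.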
@@ -0,0 +1,209 @@
"""
Hooks that are triggered when a litellm user event occurs
"""

import asyncio
from datetime import datetime, timezone
from typing import Optional

from pydantic import BaseModel

import litellm
from litellm._logging import verbose_proxy_logger
from litellm._uuid import uuid
from litellm.proxy._types import (
    AUDIT_ACTIONS,
    CommonProxyErrors,
    LiteLLM_AuditLogs,
    Litellm_EntityType,
    LiteLLM_UserTable,
    LitellmTableNames,
    NewUserRequest,
    NewUserResponse,
    UserAPIKeyAuth,
    WebhookEvent,
)
from litellm.proxy.management_helpers.audit_logs import create_audit_log_for_update


class UserManagementEventHooks:
    @staticmethod
    async def async_user_created_hook(
        data: NewUserRequest,
        response: NewUserResponse,
        user_api_key_dict: UserAPIKeyAuth,
    ):
        """
        This hook is called when a new user is created on litellm.

        Handles:
        - Creating an audit log for the user creation
        - Sending a user invitation email to the user
        """
        from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client

        #########################################################
        ########## Send User Invitation Email ################
        #########################################################
        await UserManagementEventHooks.async_send_user_invitation_email(
            data=data,
            response=response,
            user_api_key_dict=user_api_key_dict,
        )

        #########################################################
        ########## CREATE AUDIT LOG ################
        #########################################################
        try:
            if prisma_client is None:
                raise Exception(CommonProxyErrors.db_not_connected_error.value)
            user_row: BaseModel = await prisma_client.db.litellm_usertable.find_first(
                where={"user_id": response.user_id}
            )

            user_row_litellm_typed = LiteLLM_UserTable(
                **user_row.model_dump(exclude_none=True)
            )
            asyncio.create_task(
                UserManagementEventHooks.create_internal_user_audit_log(
                    user_id=user_row_litellm_typed.user_id,
                    action="created",
                    litellm_changed_by=user_api_key_dict.user_id,
                    user_api_key_dict=user_api_key_dict,
                    litellm_proxy_admin_name=litellm_proxy_admin_name,
                    before_value=None,
                    after_value=user_row_litellm_typed.model_dump_json(
                        exclude_none=True
                    ),
                )
            )
        except Exception as e:
            verbose_proxy_logger.warning(
                "Unable to create audit log for user on `/user/new` - {}".format(str(e))
            )

    @staticmethod
    async def async_send_user_invitation_email(
        data: NewUserRequest,
        response: NewUserResponse,
        user_api_key_dict: UserAPIKeyAuth,
    ):
        """
        Send a user invitation email to the newly created user.
        """
        event = WebhookEvent(
            event="internal_user_created",
            event_group=Litellm_EntityType.USER,
            event_message="Welcome to LiteLLM Proxy",
            token=response.token,
            spend=response.spend or 0.0,
            max_budget=response.max_budget,
            user_id=response.user_id,
            user_email=response.user_email,
            team_id=response.team_id,
            key_alias=response.key_alias,
        )

        #########################################################
        ########## V2 USER INVITATION EMAIL ################
        #########################################################
        try:
            from litellm_enterprise.enterprise_callbacks.send_emails.base_email import (
                BaseEmailLogger,
            )

            use_enterprise_email_hooks = True
        except ImportError:
            verbose_proxy_logger.warning(
                "Defaulting to using Legacy Email Hooks. "
                + CommonProxyErrors.missing_enterprise_package.value
            )
            use_enterprise_email_hooks = False

        if use_enterprise_email_hooks and (data.send_invite_email is True):
            initialized_email_loggers = litellm.logging_callback_manager.get_custom_loggers_for_type(
                callback_type=BaseEmailLogger  # type: ignore
            )
            if len(initialized_email_loggers) > 0:
                for email_logger in initialized_email_loggers:
                    if isinstance(email_logger, BaseEmailLogger):  # type: ignore
                        await email_logger.send_user_invitation_email(  # type: ignore
                            event=event,
                        )

        #########################################################
        ########## LEGACY V1 USER INVITATION EMAIL ################
        #########################################################
        if data.send_invite_email is True:
            await UserManagementEventHooks.send_legacy_v1_user_invitation_email(
                data=data,
                response=response,
                user_api_key_dict=user_api_key_dict,
                event=event,
            )

    @staticmethod
    async def send_legacy_v1_user_invitation_email(
        data: NewUserRequest,
        response: NewUserResponse,
        user_api_key_dict: UserAPIKeyAuth,
        event: WebhookEvent,
    ):
        """
        Send a user invitation email via the legacy v1 email alerting path.
        """
        from litellm.proxy.proxy_server import general_settings, proxy_logging_obj

        # check if the user has set up email alerting
        if "email" not in general_settings.get("alerting", []):
            raise ValueError(
                "Email alerting is not set up in config.yaml. Please set `alerting=['email']`.\nDocs: https://docs.litellm.ai/docs/proxy/email"
            )

        # If email alerting is configured, send an email letting the end user know their key was created
        asyncio.create_task(
            proxy_logging_obj.slack_alerting_instance.send_key_created_or_user_invited_email(
                webhook_event=event,
            )
        )

    @staticmethod
    async def create_internal_user_audit_log(
        user_id: str,
        action: AUDIT_ACTIONS,
        litellm_changed_by: Optional[str],
        user_api_key_dict: UserAPIKeyAuth,
        litellm_proxy_admin_name: Optional[str],
        before_value: Optional[str] = None,
        after_value: Optional[str] = None,
    ):
        """
        Create an audit log for an internal user.

        Parameters:
        - user_id: str - The id of the user to create the audit log for.
        - action: AUDIT_ACTIONS - The action to create the audit log for.
        - litellm_changed_by: Optional[str] - The user id of the user who is changing the user.
        - user_api_key_dict: UserAPIKeyAuth - The user api key dictionary.
        - litellm_proxy_admin_name: Optional[str] - The name of the proxy admin.
        - before_value: Optional[str] - The serialized user row before the change.
        - after_value: Optional[str] - The serialized user row after the change.
        """
        if not litellm.store_audit_logs:
            return

        await create_audit_log_for_update(
            request_data=LiteLLM_AuditLogs(
                id=str(uuid.uuid4()),
                updated_at=datetime.now(timezone.utc),
                changed_by=litellm_changed_by
                or user_api_key_dict.user_id
                or litellm_proxy_admin_name,
                changed_by_api_key=user_api_key_dict.api_key,
                table_name=LitellmTableNames.USER_TABLE_NAME,
                object_id=user_id,
                action=action,
                updated_values=after_value,
                before_value=before_value,
            )
        )
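The invitation-email flow above uses try-import feature detection: attempt to import the enterprise email logger, set a flag, and fall back to the legacy path on `ImportError`. The pattern in isolation (the helper name and the missing-package name are illustrative):

```python
def optional_dependency_available(module_name: str) -> bool:
    """Return True if the optional dependency is importable, else signal a fallback."""
    try:
        __import__(module_name)  # import for the side effect of availability checking
        return True
    except ImportError:
        return False


# A stdlib module is always importable; a made-up enterprise package is not,
# so callers would take the legacy code path for the second case.
print(optional_dependency_available("json"))                                # True
print(optional_dependency_available("litellm_enterprise_missing_example"))  # False
```

Keeping the flag assignment inside both branches (rather than relying on the imported name later) means the rest of the function never touches a possibly-unbound import.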
|
||||
Reference in New Issue
Block a user