"""
Max Iterations Limiter for LiteLLM Proxy.

Enforces a per-session cap on the number of LLM calls an agentic loop can make.
Callers send a `session_id` with each request (via `x-litellm-session-id` header
or `metadata.session_id`), and this hook counts calls per session. When the count
exceeds `max_iterations` (configured in agent litellm_params or key metadata),
returns 429.

Works across multiple proxy instances via DualCache (in-memory + Redis).
Follows the same pattern as parallel_request_limiter_v3.py.
"""

import os
from typing import TYPE_CHECKING, Any, Optional, Union

from fastapi import HTTPException

from litellm import DualCache
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth

if TYPE_CHECKING:
    from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache

    InternalUsageCache = _InternalUsageCache
else:
    InternalUsageCache = Any

# Redis Lua script for atomic increment with TTL.
# Returns the new count after increment.
# Only sets EXPIRE on first increment (when count becomes 1).
MAX_ITERATIONS_INCREMENT_SCRIPT = """
local key = KEYS[1]
local ttl = tonumber(ARGV[1])

local current = redis.call('INCR', key)
if current == 1 then
    redis.call('EXPIRE', key, ttl)
end

return current
"""

# Default TTL for session iteration counters (1 hour)
DEFAULT_MAX_ITERATIONS_TTL = 3600


def _resolve_ttl_seconds() -> int:
    """
    Read the counter TTL (seconds) from LITELLM_MAX_ITERATIONS_TTL.

    Falls back to DEFAULT_MAX_ITERATIONS_TTL when the variable is unset, is not
    an integer, or is not strictly positive. A non-positive TTL must never reach
    the Lua script: Redis EXPIRE with a timeout <= 0 deletes the key, so every
    INCR would return 1 and the iteration limit would silently never trigger.
    """
    raw_ttl = os.getenv("LITELLM_MAX_ITERATIONS_TTL")
    if raw_ttl is None:
        return DEFAULT_MAX_ITERATIONS_TTL

    try:
        ttl = int(raw_ttl)
    except ValueError:
        verbose_proxy_logger.warning(
            "MaxIterationsHandler: invalid LITELLM_MAX_ITERATIONS_TTL=%r, "
            "using default %s",
            raw_ttl,
            DEFAULT_MAX_ITERATIONS_TTL,
        )
        return DEFAULT_MAX_ITERATIONS_TTL

    if ttl <= 0:
        verbose_proxy_logger.warning(
            "MaxIterationsHandler: non-positive LITELLM_MAX_ITERATIONS_TTL=%r, "
            "using default %s",
            raw_ttl,
            DEFAULT_MAX_ITERATIONS_TTL,
        )
        return DEFAULT_MAX_ITERATIONS_TTL

    return ttl


class _PROXY_MaxIterationsHandler(CustomLogger):
    """
    Pre-call hook that enforces max_iterations per session.

    Configuration:
    - max_iterations: set in agent litellm_params (preferred)
      e.g. litellm_params={"max_iterations": 25}
      Falls back to key metadata max_iterations for backwards compatibility.
    - session_id: sent by caller via x-litellm-session-id header or
      metadata.session_id in request body
    - LITELLM_MAX_ITERATIONS_TTL: optional env var, counter lifetime in seconds
      (defaults to DEFAULT_MAX_ITERATIONS_TTL; must be a positive integer).

    Cache key pattern:
        {session_iterations:<session_id>}:count

    Multi-instance support:
        Uses Redis Lua script for atomic increment (same pattern as
        parallel_request_limiter_v3). Falls back to in-memory cache when
        Redis is unavailable.
    """

    def __init__(self, internal_usage_cache: InternalUsageCache):
        """
        Store the cache handle, resolve the TTL and register the Lua script.

        Args:
            internal_usage_cache: proxy usage cache; its `dual_cache.redis_cache`
                is used for the Redis path when it is not None.
        """
        self.internal_usage_cache = internal_usage_cache
        self.ttl = _resolve_ttl_seconds()

        # Register Lua script with Redis if available (same pattern as v3 limiter)
        redis_cache = self.internal_usage_cache.dual_cache.redis_cache
        if redis_cache is not None:
            self.increment_script = redis_cache.async_register_script(
                MAX_ITERATIONS_INCREMENT_SCRIPT
            )
        else:
            self.increment_script = None

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,
    ) -> Optional[Union[Exception, str, dict]]:
        """
        Check session iteration count before making the API call.

        Extracts session_id from request metadata and max_iterations from
        agent litellm_params (or key metadata). The hook is a no-op when either
        is missing. Otherwise the session counter is incremented, and the call
        is rejected once the counter exceeds max_iterations.

        Returns:
            Always None (the request data is not modified).

        Raises:
            HTTPException: status 429 when the incremented count is greater
                than max_iterations.
        """
        # Extract session_id from request data
        session_id = self._get_session_id(data)
        if session_id is None:
            return None

        max_iterations = self._get_max_iterations(user_api_key_dict)
        if max_iterations is None:
            return None

        verbose_proxy_logger.debug(
            "MaxIterationsHandler: session_id=%s, max_iterations=%s",
            session_id,
            max_iterations,
        )

        # Increment and check. The increment happens before the comparison, so
        # rejected calls still advance the counter.
        cache_key = self._make_cache_key(session_id)
        current_count = await self._increment_and_get(cache_key)

        if current_count > max_iterations:
            raise HTTPException(
                status_code=429,
                detail=(
                    f"Max iterations exceeded for session {session_id}. "
                    f"Current count: {current_count}, max_iterations: {max_iterations}."
                ),
            )

        verbose_proxy_logger.debug(
            "MaxIterationsHandler: session_id=%s, count=%s/%s",
            session_id,
            current_count,
            max_iterations,
        )
        return None

    def _get_session_id(self, data: dict) -> Optional[str]:
        """
        Extract session_id from request metadata.

        Looks at `metadata` first, then `litellm_metadata` (used for /thread and
        /assistant endpoints). Only the request body is inspected here; the
        x-litellm-session-id header is presumably copied into metadata upstream
        (not visible in this file - confirm against the caller).

        Returns:
            The session id as a string, or None when neither container has one.
        """
        for container_key in ("metadata", "litellm_metadata"):
            container = data.get(container_key) or {}
            session_id = container.get("session_id")
            if session_id is not None:
                return str(session_id)
        return None

    def _get_max_iterations(self, user_api_key_dict: UserAPIKeyAuth) -> Optional[int]:
        """
        Extract max_iterations from agent litellm_params, with fallback to key metadata.

        Returns:
            The configured limit as an int, or None when no limit is configured.

        Raises:
            ValueError / TypeError: if the configured value cannot be converted
                with int(); a malformed limit fails the request rather than
                being silently ignored.
        """
        # Try agent litellm_params first
        agent_id = user_api_key_dict.agent_id
        if agent_id is not None:
            # Imported lazily, as in the original code (presumably to avoid an
            # import cycle with the agent registry - confirm).
            from litellm.proxy.agent_endpoints.agent_registry import (
                global_agent_registry,
            )

            agent = global_agent_registry.get_agent_by_id(agent_id=agent_id)
            if agent is not None:
                litellm_params = agent.litellm_params or {}
                max_iterations = litellm_params.get("max_iterations")
                if max_iterations is not None:
                    return int(max_iterations)

        # Fallback to key metadata for backwards compatibility
        metadata = user_api_key_dict.metadata or {}
        max_iterations = metadata.get("max_iterations")
        if max_iterations is not None:
            return int(max_iterations)

        return None

    def _make_cache_key(self, session_id: str) -> str:
        """
        Create cache key for session iteration counter.

        Uses Redis hash-tag pattern {session_iterations:<session_id>} so all
        keys for a session land on the same Redis Cluster slot.
        """
        return f"{{session_iterations:{session_id}}}:count"

    async def _increment_and_get(self, cache_key: str) -> int:
        """
        Increment the session counter and return the new value.

        Tries Redis first (via registered Lua script, which is atomic across
        instances). Falls back to the in-memory cache when no script is
        registered or when the Redis call raises; the in-memory counter is
        independent of the Redis one.
        """
        if self.increment_script is not None:
            try:
                result = await self.increment_script(
                    keys=[cache_key],
                    args=[self.ttl],
                )
                return int(result)
            except Exception as e:
                # Best effort: a Redis outage must not take the proxy down.
                verbose_proxy_logger.warning(
                    "MaxIterationsHandler: Redis failed, falling back to in-memory: %s",
                    e,
                )

        # Fallback: in-memory cache
        return await self._in_memory_increment(cache_key)

    async def _in_memory_increment(self, cache_key: str) -> int:
        """
        Increment counter in in-memory cache with TTL.

        This is a read followed by a write (two awaits), not a single atomic
        operation.
        """
        current = await self.internal_usage_cache.async_get_cache(
            key=cache_key,
            litellm_parent_otel_span=None,
            local_only=True,
        )
        new_value = (int(current) if current is not None else 0) + 1
        await self.internal_usage_cache.async_set_cache(
            key=cache_key,
            value=new_value,
            ttl=self.ttl,
            litellm_parent_otel_span=None,
            local_only=True,
        )
        return new_value