chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,16 @@
|
||||
"""
|
||||
Response Polling Module for Background Responses with Cache
|
||||
"""
|
||||
from litellm.proxy.response_polling.background_streaming import (
|
||||
background_streaming_task,
|
||||
)
|
||||
from litellm.proxy.response_polling.polling_handler import (
|
||||
ResponsePollingHandler,
|
||||
should_use_polling_for_request,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ResponsePollingHandler",
|
||||
"background_streaming_task",
|
||||
"should_use_polling_for_request",
|
||||
]
|
||||
@@ -0,0 +1,323 @@
|
||||
"""
|
||||
Background Streaming Task for Polling Via Cache Feature
|
||||
|
||||
Handles streaming responses from LLM providers and updates Redis cache
|
||||
with partial results for polling.
|
||||
|
||||
Follows OpenAI Response Streaming format:
|
||||
https://platform.openai.com/docs/api-reference/responses-streaming
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Request, Response
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy.auth.user_api_key_auth import UserAPIKeyAuth
|
||||
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
|
||||
from litellm.proxy.response_polling.polling_handler import ResponsePollingHandler
|
||||
|
||||
|
||||
async def background_streaming_task(  # noqa: PLR0915
    polling_id: str,
    data: dict,
    polling_handler: ResponsePollingHandler,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth,
    general_settings: dict,
    llm_router,
    proxy_config,
    proxy_logging_obj,
    select_data_generator,
    user_model,
    user_temperature,
    user_request_timeout,
    user_max_tokens,
    user_api_base,
    version,
):
    """
    Background task to stream response and update cache

    Follows OpenAI Response Streaming format:
    https://platform.openai.com/docs/api-reference/responses-streaming

    Processes streaming events and builds Response object:
    https://platform.openai.com/docs/api-reference/responses/object

    Args:
        polling_id: Cache key suffix identifying this background response.
        data: Original request payload; mutated here to force streaming.
        polling_handler: Handler that persists partial state to Redis.
        (remaining parameters are forwarded verbatim to
        ProxyBaseLLMRequestProcessing.base_process_llm_request)

    On any failure the polling state is marked "failed" with an error dict;
    this function never raises to its caller.
    """

    try:
        verbose_proxy_logger.info(f"Starting background streaming for {polling_id}")

        # Update status to in_progress (OpenAI format)
        await polling_handler.update_state(
            polling_id=polling_id,
            status="in_progress",
        )

        # Force streaming mode and remove background flag so the provider
        # streams to us directly instead of using native background mode.
        data["stream"] = True
        data.pop("background", None)

        # Create processor
        processor = ProxyBaseLLMRequestProcessing(data=data)

        # Make streaming request
        response = await processor.base_process_llm_request(
            request=request,
            fastapi_response=fastapi_response,
            user_api_key_dict=user_api_key_dict,
            route_type="aresponses",
            proxy_logging_obj=proxy_logging_obj,
            llm_router=llm_router,
            general_settings=general_settings,
            proxy_config=proxy_config,
            select_data_generator=select_data_generator,
            model=None,
            user_model=user_model,
            user_temperature=user_temperature,
            user_request_timeout=user_request_timeout,
            user_max_tokens=user_max_tokens,
            user_api_base=user_api_base,
            version=version,
        )

        # Process streaming response following OpenAI events format
        # https://platform.openai.com/docs/api-reference/responses-streaming
        output_items: dict[str, dict[str, Any]] = {}  # Track output items by ID
        # Track accumulated text deltas by (item_id, content_index)
        accumulated_text: dict[tuple[str, int], str] = {}

        # ResponsesAPIResponse fields to extract from response.completed
        usage_data = None
        reasoning_data = None
        tool_choice_data = None
        tools_data = None
        model_data = None
        instructions_data = None
        temperature_data = None
        top_p_data = None
        max_output_tokens_data = None
        previous_response_id_data = None
        text_data = None
        truncation_data = None
        parallel_tool_calls_data = None
        user_data = None
        store_data = None
        incomplete_details_data = None

        state_dirty = False  # Track if state needs to be synced
        # get_running_loop() is the supported call from inside a coroutine;
        # get_event_loop() is deprecated in this context on modern Python.
        last_update_time = asyncio.get_running_loop().time()
        UPDATE_INTERVAL = 0.150  # 150ms batching interval

        async def flush_state_if_needed(force: bool = False) -> None:
            """Flush accumulated state to Redis if interval elapsed or forced"""
            nonlocal state_dirty, last_update_time

            current_time = asyncio.get_running_loop().time()
            if state_dirty and (
                force or (current_time - last_update_time) >= UPDATE_INTERVAL
            ):
                # Convert output_items dict to list for update
                output_list = list(output_items.values())
                await polling_handler.update_state(
                    polling_id=polling_id,
                    output=output_list,
                )
                state_dirty = False
                last_update_time = current_time

        # Handle StreamingResponse
        if hasattr(response, "body_iterator"):
            async for chunk in response.body_iterator:
                # Parse chunk (SSE frames arrive as "data: <json>")
                if isinstance(chunk, bytes):
                    chunk = chunk.decode("utf-8")

                if isinstance(chunk, str) and chunk.startswith("data: "):
                    chunk_data = chunk[6:].strip()
                    if chunk_data == "[DONE]":
                        break

                    try:
                        event = json.loads(chunk_data)
                        event_type = event.get("type", "")

                        # Process different event types based on OpenAI streaming spec
                        if event_type == "response.output_item.added":
                            # New output item added
                            item = event.get("item", {})
                            item_id = item.get("id")
                            if item_id:
                                output_items[item_id] = item
                                state_dirty = True

                        elif event_type == "response.content_part.added":
                            # Content part added to an output item
                            item_id = event.get("item_id")
                            content_part = event.get("part", {})

                            if item_id and item_id in output_items:
                                # Update the output item with new content
                                if "content" not in output_items[item_id]:
                                    output_items[item_id]["content"] = []
                                output_items[item_id]["content"].append(content_part)
                                state_dirty = True

                        elif event_type == "response.output_text.delta":
                            # Text delta - accumulate text content
                            # https://platform.openai.com/docs/api-reference/responses-streaming/response-text-delta
                            item_id = event.get("item_id")
                            content_index = event.get("content_index", 0)
                            delta = event.get("delta", "")

                            if item_id and item_id in output_items:
                                # Accumulate text delta
                                key = (item_id, content_index)
                                if key not in accumulated_text:
                                    accumulated_text[key] = ""
                                accumulated_text[key] += delta

                                # Update the content in output_items
                                if "content" in output_items[item_id]:
                                    content_list = output_items[item_id]["content"]
                                    if content_index < len(content_list):
                                        # Update existing content part with accumulated text
                                        if isinstance(
                                            content_list[content_index], dict
                                        ):
                                            content_list[content_index][
                                                "text"
                                            ] = accumulated_text[key]
                                state_dirty = True

                        elif event_type == "response.content_part.done":
                            # Content part completed
                            item_id = event.get("item_id")
                            content_part = event.get("part", {})
                            content_index = event.get("content_index", 0)

                            if item_id and item_id in output_items:
                                # Update with final content from event
                                if "content" in output_items[item_id]:
                                    content_list = output_items[item_id]["content"]
                                    if content_index < len(content_list):
                                        content_list[content_index] = content_part
                                state_dirty = True

                        elif event_type == "response.output_item.done":
                            # Output item completed - use final item data
                            item = event.get("item", {})
                            item_id = item.get("id")
                            if item_id:
                                output_items[item_id] = item
                                state_dirty = True

                        elif event_type == "response.in_progress":
                            # Response is now in progress
                            # https://platform.openai.com/docs/api-reference/responses-streaming/response-in-progress
                            await polling_handler.update_state(
                                polling_id=polling_id,
                                status="in_progress",
                            )

                        elif event_type == "response.completed":
                            # Response completed - extract all ResponsesAPIResponse fields
                            # https://platform.openai.com/docs/api-reference/responses-streaming/response-completed
                            response_data = event.get("response", {})

                            # Core response fields
                            usage_data = response_data.get("usage")
                            reasoning_data = response_data.get("reasoning")
                            tool_choice_data = response_data.get("tool_choice")
                            tools_data = response_data.get("tools")

                            # Additional ResponsesAPIResponse fields
                            model_data = response_data.get("model")
                            instructions_data = response_data.get("instructions")
                            temperature_data = response_data.get("temperature")
                            top_p_data = response_data.get("top_p")
                            max_output_tokens_data = response_data.get(
                                "max_output_tokens"
                            )
                            previous_response_id_data = response_data.get(
                                "previous_response_id"
                            )
                            text_data = response_data.get("text")
                            truncation_data = response_data.get("truncation")
                            parallel_tool_calls_data = response_data.get(
                                "parallel_tool_calls"
                            )
                            user_data = response_data.get("user")
                            store_data = response_data.get("store")
                            incomplete_details_data = response_data.get(
                                "incomplete_details"
                            )

                            # Also update output from final response if available
                            if "output" in response_data:
                                final_output = response_data.get("output", [])
                                for item in final_output:
                                    item_id = item.get("id")
                                    if item_id:
                                        output_items[item_id] = item
                                state_dirty = True

                        # Flush state to Redis if interval elapsed
                        await flush_state_if_needed()

                    except json.JSONDecodeError as e:
                        # Malformed chunk: log and keep consuming the stream.
                        verbose_proxy_logger.warning(
                            f"Failed to parse streaming chunk: {e}"
                        )

        # Final flush to ensure all accumulated state is saved
        await flush_state_if_needed(force=True)

        # Mark as completed with all ResponsesAPIResponse fields
        await polling_handler.update_state(
            polling_id=polling_id,
            status="completed",
            usage=usage_data,
            reasoning=reasoning_data,
            tool_choice=tool_choice_data,
            tools=tools_data,
            model=model_data,
            instructions=instructions_data,
            temperature=temperature_data,
            top_p=top_p_data,
            max_output_tokens=max_output_tokens_data,
            previous_response_id=previous_response_id_data,
            text=text_data,
            truncation=truncation_data,
            parallel_tool_calls=parallel_tool_calls_data,
            user=user_data,
            store=store_data,
            incomplete_details=incomplete_details_data,
        )

        verbose_proxy_logger.info(
            f"Completed background streaming for {polling_id}, output_items={len(output_items)}"
        )

    except Exception as e:
        verbose_proxy_logger.error(
            f"Error in background streaming task for {polling_id}: {str(e)}"
        )
        import traceback

        verbose_proxy_logger.error(traceback.format_exc())

        # Surface the failure to pollers instead of leaving a stale state.
        await polling_handler.update_state(
            polling_id=polling_id,
            status="failed",
            error={
                "type": "internal_error",
                "message": str(e),
                "code": "background_streaming_error",
            },
        )
|
||||
@@ -0,0 +1,330 @@
|
||||
"""
|
||||
Response Polling Handler for Background Responses with Cache
|
||||
"""
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm._uuid import uuid4
|
||||
from litellm.caching.redis_cache import RedisCache
|
||||
from litellm.types.llms.openai import ResponsesAPIResponse, ResponsesAPIStatus
|
||||
|
||||
|
||||
class ResponsePollingHandler:
    """Handles polling-based responses with Redis cache.

    Persists an OpenAI-format Response object per polling ID so that clients
    can poll for partial/final results while generation runs in the background.
    All methods degrade to no-ops when no Redis cache is configured.
    """

    # Redis key namespace for polling entries.
    CACHE_KEY_PREFIX = "litellm:polling:response:"
    POLLING_ID_PREFIX = "litellm_poll_"  # Clear prefix to identify polling IDs

    def __init__(self, redis_cache: Optional[RedisCache] = None, ttl: int = 3600):
        self.redis_cache = redis_cache
        self.ttl = ttl  # Time-to-live for cache entries (default: 1 hour)

    @classmethod
    def generate_polling_id(cls) -> str:
        """Generate a unique UUID for polling with clear prefix"""
        return f"{cls.POLLING_ID_PREFIX}{uuid4()}"

    @classmethod
    def is_polling_id(cls, response_id: str) -> bool:
        """Check if a response_id is a polling ID"""
        return response_id.startswith(cls.POLLING_ID_PREFIX)

    @classmethod
    def get_cache_key(cls, polling_id: str) -> str:
        """Get Redis cache key for a polling ID"""
        return f"{cls.CACHE_KEY_PREFIX}{polling_id}"

    async def create_initial_state(
        self,
        polling_id: str,
        request_data: Dict[str, Any],
    ) -> ResponsesAPIResponse:
        """
        Create initial state in Redis for a polling request

        Uses OpenAI ResponsesAPIResponse object:
        https://platform.openai.com/docs/api-reference/responses/object

        Args:
            polling_id: Unique identifier for this polling request
            request_data: Original request data

        Returns:
            ResponsesAPIResponse object following OpenAI spec
        """
        created_timestamp = int(datetime.now(timezone.utc).timestamp())

        # Create OpenAI-compliant response object
        response = ResponsesAPIResponse(
            id=polling_id,
            object="response",
            status="queued",  # OpenAI native status
            created_at=created_timestamp,
            output=[],
            metadata=request_data.get("metadata", {}),
            usage=None,
        )

        cache_key = self.get_cache_key(polling_id)

        if self.redis_cache:
            # Store ResponsesAPIResponse directly in Redis
            await self.redis_cache.async_set_cache(
                key=cache_key,
                value=response.model_dump_json(),  # Pydantic v2 method
                ttl=self.ttl,
            )
            verbose_proxy_logger.debug(
                f"Created initial polling state for {polling_id} with TTL={self.ttl}s"
            )

        return response

    async def update_state(
        self,
        polling_id: str,
        status: Optional[ResponsesAPIStatus] = None,
        usage: Optional[Dict] = None,
        error: Optional[Dict] = None,
        incomplete_details: Optional[Dict] = None,
        reasoning: Optional[Dict] = None,
        tool_choice: Optional[Any] = None,
        tools: Optional[list] = None,
        output: Optional[list] = None,
        # Additional ResponsesAPIResponse fields
        model: Optional[str] = None,
        instructions: Optional[str] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        max_output_tokens: Optional[int] = None,
        previous_response_id: Optional[str] = None,
        text: Optional[Dict] = None,
        truncation: Optional[str] = None,
        parallel_tool_calls: Optional[bool] = None,
        user: Optional[str] = None,
        store: Optional[bool] = None,
    ) -> None:
        """
        Update the polling state in Redis

        Uses OpenAI Response object format with native status types:
        https://platform.openai.com/docs/api-reference/responses/object

        Args:
            polling_id: Unique identifier for this polling request
            status: OpenAI ResponsesAPIStatus value
            usage: Usage information
            error: Error dict (automatically sets status to "failed")
            incomplete_details: Details for incomplete responses
            reasoning: Reasoning configuration from response.completed
            tool_choice: Tool choice configuration from response.completed
            tools: Tools list from response.completed
            output: Full output list to replace current output
            model: Model identifier
            instructions: System instructions
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            max_output_tokens: Maximum output tokens
            previous_response_id: ID of previous response in conversation
            text: Text configuration
            truncation: Truncation setting
            parallel_tool_calls: Whether parallel tool calls are enabled
            user: User identifier
            store: Whether to store the response
        """
        if not self.redis_cache:
            return

        cache_key = self.get_cache_key(polling_id)

        # Get current state
        cached_state = await self.redis_cache.async_get_cache(cache_key)
        if not cached_state:
            verbose_proxy_logger.warning(
                f"No cached state found for polling_id: {polling_id}"
            )
            return

        # Parse existing ResponsesAPIResponse from cache
        state = json.loads(cached_state)

        # Update status (using OpenAI native status values)
        if status:
            state["status"] = status

        # Replace full output list if provided
        if output is not None:
            state["output"] = output

        # Update usage
        if usage:
            state["usage"] = usage

        # Handle error (sets status to OpenAI's "failed")
        if error:
            state["status"] = "failed"
            state["error"] = error  # Use OpenAI's 'error' field

        # Handle incomplete details
        if incomplete_details:
            state["incomplete_details"] = incomplete_details

        # Update reasoning, tool_choice, tools from response.completed
        if reasoning is not None:
            state["reasoning"] = reasoning
        if tool_choice is not None:
            state["tool_choice"] = tool_choice
        if tools is not None:
            state["tools"] = tools

        # Update additional ResponsesAPIResponse fields
        if model is not None:
            state["model"] = model
        if instructions is not None:
            state["instructions"] = instructions
        if temperature is not None:
            state["temperature"] = temperature
        if top_p is not None:
            state["top_p"] = top_p
        if max_output_tokens is not None:
            state["max_output_tokens"] = max_output_tokens
        if previous_response_id is not None:
            state["previous_response_id"] = previous_response_id
        if text is not None:
            state["text"] = text
        if truncation is not None:
            state["truncation"] = truncation
        if parallel_tool_calls is not None:
            state["parallel_tool_calls"] = parallel_tool_calls
        if user is not None:
            state["user"] = user
        if store is not None:
            state["store"] = store

        # Update cache with configured TTL
        await self.redis_cache.async_set_cache(
            key=cache_key,
            value=json.dumps(state),
            ttl=self.ttl,
        )

        output_count = len(state.get("output", []))
        # .get() keeps the debug log safe even if a cached blob lacks "status".
        verbose_proxy_logger.debug(
            f"Updated polling state for {polling_id}: status={state.get('status')}, output_items={output_count}"
        )

    async def get_state(self, polling_id: str) -> Optional[Dict[str, Any]]:
        """Get current polling state from Redis"""
        if not self.redis_cache:
            return None

        cache_key = self.get_cache_key(polling_id)
        cached_state = await self.redis_cache.async_get_cache(cache_key)

        if cached_state:
            return json.loads(cached_state)

        return None

    async def cancel_polling(self, polling_id: str) -> bool:
        """
        Cancel a polling request

        Following OpenAI Response object format for cancelled status

        Returns:
            True if a cancellation update was issued, False when no Redis
            cache is configured (consistent with delete_polling).
        """
        if not self.redis_cache:
            return False

        await self.update_state(
            polling_id=polling_id,
            status="cancelled",
        )
        return True

    async def delete_polling(self, polling_id: str) -> bool:
        """Delete a polling request from cache"""
        if not self.redis_cache:
            return False

        cache_key = self.get_cache_key(polling_id)
        # Use RedisCache's async_delete_cache method which handles Redis/RedisCluster
        await self.redis_cache.async_delete_cache(cache_key)
        return True
|
||||
|
||||
|
||||
def should_use_polling_for_request(
    background_mode: bool,
    polling_via_cache_enabled,  # Can be False, "all", or List[str]
    redis_cache,  # RedisCache or None
    model: str,
    llm_router,  # Router instance or None
    native_background_mode: Optional[
        List[str]
    ] = None,  # List of models that should use native background mode
) -> bool:
    """
    Decide whether a request should be served via polling-through-cache.

    Args:
        background_mode: Whether background=true was set in the request
        polling_via_cache_enabled: Config value - False, "all", or list of providers
        redis_cache: Redis cache instance (required for polling)
        model: Model name from the request (e.g., "gpt-5" or "openai/gpt-4o")
        llm_router: LiteLLM router instance for looking up model deployments
        native_background_mode: List of model names that should use native provider
            background mode instead of polling via cache

    Returns:
        True if polling should be used, False otherwise
    """
    # Polling needs all three: background mode, an enabling config, and Redis.
    if not background_mode or not polling_via_cache_enabled or not redis_cache:
        return False

    # Models pinned to native provider background mode never poll via cache.
    if native_background_mode and model in native_background_mode:
        verbose_proxy_logger.debug(
            f"Model {model} is in native_background_mode list, skipping polling via cache"
        )
        return False

    # Wildcard config enables polling for every provider.
    if polling_via_cache_enabled == "all":
        return True

    if not isinstance(polling_via_cache_enabled, list):
        return False

    # "provider/model" form: decide from the prefix alone.
    prefix, slash, _ = model.partition("/")
    if slash:
        return prefix in polling_via_cache_enabled

    if llm_router is None:
        return False

    # Bare model name: inspect every deployment registered under it.
    try:
        deployment_indices = llm_router.model_name_to_deployment_indices.get(model, [])
        for deployment_idx in deployment_indices:
            params = llm_router.model_list[deployment_idx].get("litellm_params", {})

            # Explicit provider override wins; otherwise derive it from the
            # deployment's underlying model string (e.g. "openai/gpt-5").
            resolved_provider = params.get("custom_llm_provider")
            if not resolved_provider:
                underlying_model = params.get("model", "")
                if "/" in underlying_model:
                    resolved_provider = underlying_model.split("/")[0]

            # A single matching deployment is enough to enable polling.
            if resolved_provider and resolved_provider in polling_via_cache_enabled:
                verbose_proxy_logger.debug(
                    f"Polling enabled for model={model}, provider={resolved_provider}"
                )
                return True
    except Exception as e:
        verbose_proxy_logger.debug(
            f"Could not resolve provider for model {model}: {e}"
        )

    return False
|
||||
Reference in New Issue
Block a user