chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,167 @@
|
||||
"""
|
||||
This file contains the calling OpenAI's `/v1/realtime` endpoint.
|
||||
|
||||
This requires websockets, and is currently only supported on LiteLLM Proxy.
|
||||
"""
|
||||
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from litellm.constants import REALTIME_WEBSOCKET_MAX_MESSAGE_SIZE_BYTES
|
||||
from litellm.types.realtime import RealtimeQueryParams
|
||||
|
||||
from ....litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
|
||||
from ....litellm_core_utils.realtime_streaming import RealTimeStreaming
|
||||
from ....llms.custom_httpx.http_handler import get_shared_realtime_ssl_context
|
||||
from ..openai import OpenAIChatCompletion
|
||||
|
||||
|
||||
class OpenAIRealtime(OpenAIChatCompletion):
|
||||
"""
|
||||
Base handler for OpenAI-compatible realtime WebSocket connections.
|
||||
|
||||
Subclasses can override template methods to customize:
|
||||
- _get_default_api_base(): Default API base URL
|
||||
- _get_additional_headers(): Extra headers beyond Authorization
|
||||
- _get_ssl_config(): SSL configuration for WebSocket connection
|
||||
"""
|
||||
|
||||
def _get_default_api_base(self) -> str:
|
||||
"""
|
||||
Get the default API base URL for this provider.
|
||||
Override this in subclasses to set provider-specific defaults.
|
||||
"""
|
||||
return "https://api.openai.com/"
|
||||
|
||||
def _get_additional_headers(self, api_key: str) -> dict:
|
||||
"""
|
||||
Get additional headers beyond Authorization.
|
||||
Override this in subclasses to customize headers (e.g., remove OpenAI-Beta).
|
||||
|
||||
Args:
|
||||
api_key: API key for authentication
|
||||
|
||||
Returns:
|
||||
Dictionary of additional headers
|
||||
"""
|
||||
return {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"OpenAI-Beta": "realtime=v1",
|
||||
}
|
||||
|
||||
def _get_ssl_config(self, url: str) -> Any:
|
||||
"""
|
||||
Get SSL configuration for WebSocket connection.
|
||||
Override this in subclasses to customize SSL behavior.
|
||||
|
||||
Args:
|
||||
url: WebSocket URL (ws:// or wss://)
|
||||
|
||||
Returns:
|
||||
SSL configuration (None, True, or SSLContext)
|
||||
"""
|
||||
if url.startswith("ws://"):
|
||||
return None
|
||||
|
||||
# Use the shared SSL context which respects custom CA certs and SSL settings
|
||||
ssl_config = get_shared_realtime_ssl_context()
|
||||
|
||||
# If ssl_config is False (ssl_verify=False), websockets library needs True instead
|
||||
# to establish connection without verification (False would fail)
|
||||
if ssl_config is False:
|
||||
return True
|
||||
|
||||
return ssl_config
|
||||
|
||||
def _construct_url(self, api_base: str, query_params: RealtimeQueryParams) -> str:
|
||||
"""
|
||||
Construct the backend websocket URL with all query parameters (including 'model').
|
||||
"""
|
||||
from httpx import URL
|
||||
|
||||
api_base = api_base.replace("https://", "wss://")
|
||||
api_base = api_base.replace("http://", "ws://")
|
||||
url = URL(api_base)
|
||||
# Set the correct path
|
||||
url = url.copy_with(path="/v1/realtime")
|
||||
# Include all query parameters including 'model'
|
||||
if query_params:
|
||||
url = url.copy_with(params=query_params)
|
||||
return str(url)
|
||||
|
||||
async def async_realtime(
|
||||
self,
|
||||
model: str,
|
||||
websocket: Any,
|
||||
logging_obj: LiteLLMLogging,
|
||||
api_base: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
client: Optional[Any] = None,
|
||||
timeout: Optional[float] = None,
|
||||
query_params: Optional[RealtimeQueryParams] = None,
|
||||
user_api_key_dict: Optional[Any] = None,
|
||||
litellm_metadata: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
import websockets
|
||||
from websockets.asyncio.client import ClientConnection
|
||||
|
||||
if api_base is None:
|
||||
api_base = self._get_default_api_base()
|
||||
if api_key is None:
|
||||
raise ValueError("api_key is required for OpenAI realtime calls")
|
||||
|
||||
# Use all query params if provided, else fallback to just model
|
||||
if query_params is None:
|
||||
query_params = {"model": model}
|
||||
url = self._construct_url(api_base, query_params)
|
||||
|
||||
try:
|
||||
# Get provider-specific SSL configuration
|
||||
ssl_config = self._get_ssl_config(url)
|
||||
|
||||
# Get provider-specific headers
|
||||
headers = self._get_additional_headers(api_key)
|
||||
|
||||
# Log a masked request preview consistent with other endpoints.
|
||||
logging_obj.pre_call(
|
||||
input=None,
|
||||
api_key=api_key,
|
||||
additional_args={
|
||||
"api_base": url,
|
||||
"headers": headers,
|
||||
"complete_input_dict": {"query_params": query_params},
|
||||
},
|
||||
)
|
||||
async with websockets.connect( # type: ignore
|
||||
url,
|
||||
additional_headers=headers, # type: ignore
|
||||
max_size=REALTIME_WEBSOCKET_MAX_MESSAGE_SIZE_BYTES,
|
||||
ssl=ssl_config,
|
||||
) as backend_ws:
|
||||
realtime_streaming = RealTimeStreaming(
|
||||
websocket,
|
||||
cast(ClientConnection, backend_ws),
|
||||
logging_obj,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
request_data={"litellm_metadata": litellm_metadata or {}},
|
||||
)
|
||||
await realtime_streaming.bidirectional_forward()
|
||||
|
||||
except websockets.exceptions.InvalidStatusCode as e: # type: ignore
|
||||
await websocket.close(code=e.status_code, reason=str(e))
|
||||
except Exception as e:
|
||||
try:
|
||||
await websocket.close(
|
||||
code=1011, reason=f"Internal server error: {str(e)}"
|
||||
)
|
||||
except RuntimeError as close_error:
|
||||
if "already completed" in str(close_error) or "websocket.close" in str(
|
||||
close_error
|
||||
):
|
||||
# The WebSocket is already closed or the response is completed, so we can ignore this error
|
||||
pass
|
||||
else:
|
||||
# If it's a different RuntimeError, we might want to log it or handle it differently
|
||||
raise Exception(
|
||||
f"Unexpected error while closing WebSocket: {close_error}"
|
||||
)
|
||||
Reference in New Issue
Block a user