"""Abstraction function for OpenAI's realtime API"""
import os
from typing import Any, Dict, Optional, cast
import litellm
from litellm.constants import REALTIME_WEBSOCKET_MAX_MESSAGE_SIZE_BYTES, request_timeout
from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
from litellm.llms.base_llm.realtime.transformation import BaseRealtimeConfig
from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
from litellm.secret_managers.main import get_secret_str
from litellm.types.realtime import (
RealtimeClientSecretRequest,
RealtimeExpiresAfter,
RealtimeQueryParams,
RealtimeSessionConfig,
)
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import LlmProviders
from litellm.utils import ProviderConfigManager
from ..litellm_core_utils.get_litellm_params import get_litellm_params
from ..litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from ..llms.azure.realtime.handler import AzureOpenAIRealtime
from ..llms.bedrock.realtime.handler import BedrockRealtime
from ..llms.custom_httpx.http_handler import get_shared_realtime_ssl_context
from ..llms.openai.realtime.handler import OpenAIRealtime
from ..llms.vertex_ai.realtime.transformation import VertexAIRealtimeConfig
from ..llms.vertex_ai.vertex_llm_base import VertexBase
from ..llms.xai.realtime.handler import XAIRealtime
from ..utils import client as wrapper_client

azure_realtime = AzureOpenAIRealtime()
openai_realtime = OpenAIRealtime()
bedrock_realtime = BedrockRealtime()
xai_realtime = XAIRealtime()
vertex_llm_base = VertexBase()
base_llm_http_handler = BaseLLMHTTPHandler()


def _build_litellm_metadata(kwargs: dict) -> dict:
"""Build the litellm_metadata dict for guardrail checking (internal only, not forwarded to provider)."""
metadata: dict = {**(kwargs.get("litellm_metadata") or {})}
guardrails = (
(kwargs.get("metadata") or {}).get("guardrails")
or kwargs.get("guardrails")
or []
)
if guardrails:
metadata["guardrails"] = guardrails
return metadata
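# Illustrative behavior (comments only, not executed): guardrails may arrive
# nested under "metadata" or as a top-level kwarg, and both normalize the same
# way; "pii_mask" is a placeholder guardrail name.
#
#     _build_litellm_metadata({"metadata": {"guardrails": ["pii_mask"]}})
#     # -> {"guardrails": ["pii_mask"]}
#     _build_litellm_metadata({"litellm_metadata": {"a": 1}, "guardrails": ["pii_mask"]})
#     # -> {"a": 1, "guardrails": ["pii_mask"]}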
def _get_realtime_http_provider_config(
custom_llm_provider: str,
dynamic_api_base: Optional[str],
dynamic_api_key: Optional[str],
litellm_params: GenericLiteLLMParams,
) -> tuple[Any, str, str]:
"""
Return (provider_config, resolved_api_base, resolved_api_key) for the
realtime HTTP endpoints (client_secrets / realtime_calls).
Uses ProviderConfigManager so each provider keeps its credential-resolution
and URL-construction logic in its own transformation class.
"""
from litellm.llms.base_llm.realtime.http_transformation import (
BaseRealtimeHTTPConfig,
)
provider_config: Optional[BaseRealtimeHTTPConfig] = None
if custom_llm_provider in LlmProviders._member_map_.values():
provider_config = ProviderConfigManager.get_provider_realtime_http_config(
model="",
provider=LlmProviders(custom_llm_provider),
)
raw_api_base = dynamic_api_base or litellm_params.api_base
raw_api_key = dynamic_api_key or litellm_params.api_key
if provider_config is not None:
resolved_api_base = provider_config.get_api_base(api_base=raw_api_base)
resolved_api_key = provider_config.get_api_key(api_key=raw_api_key)
else:
# Fallback for providers without a dedicated HTTP config (treated as OpenAI-compatible).
resolved_api_base = raw_api_base or litellm.api_base or "https://api.openai.com"
resolved_api_key = (
raw_api_key
or litellm.api_key
or litellm.openai_key
or get_secret_str("OPENAI_API_KEY")
or ""
)
return provider_config, resolved_api_base.rstrip("/"), resolved_api_key
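# Illustrative fallback (a sketch, assuming litellm.api_base is unset and the
# provider has no registered realtime HTTP config): unknown providers resolve
# to OpenAI-compatible defaults. "my_custom_provider" is a placeholder name.
#
#     cfg, base, key = _get_realtime_http_provider_config(
#         custom_llm_provider="my_custom_provider",
#         dynamic_api_base=None,
#         dynamic_api_key="sk-...",
#         litellm_params=GenericLiteLLMParams(),
#     )
#     # cfg is None, base == "https://api.openai.com", key == "sk-..."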
@wrapper_client
async def acreate_realtime_client_secret(
model: Optional[str] = None,
session: Optional[Dict[str, Any]] = None,
expires_after: Optional[Dict[str, Any]] = None,
timeout: Optional[float] = None,
**kwargs,
):
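    """Create a realtime client secret (ephemeral key) via the provider's HTTP API."""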
req = RealtimeClientSecretRequest(
model=model,
session=RealtimeSessionConfig(**session) if session else None,
expires_after=RealtimeExpiresAfter(**expires_after) if expires_after else None,
)
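    # Model precedence: session-level model, then top-level model, then the default.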
model_name = (
(req.session.model if req.session is not None else None)
or req.model
or "gpt-4o-realtime-preview"
)
litellm_logging_obj: LiteLLMLogging = kwargs.get("litellm_logging_obj") # type: ignore
litellm_params = GenericLiteLLMParams(**kwargs)
(
model_name,
custom_llm_provider,
dynamic_api_key,
dynamic_api_base,
) = get_llm_provider(
model=model_name,
api_base=litellm_params.api_base,
api_key=litellm_params.api_key,
)
(
provider_config,
resolved_api_base,
resolved_api_key,
) = _get_realtime_http_provider_config(
custom_llm_provider=custom_llm_provider,
dynamic_api_base=dynamic_api_base,
dynamic_api_key=dynamic_api_key,
litellm_params=litellm_params,
)
litellm_logging_obj.update_environment_variables(
model=model_name,
optional_params={"expires_after": expires_after, "session": session},
litellm_params={"api_base": resolved_api_base},
custom_llm_provider=custom_llm_provider,
)
request_data = req.model_dump(exclude_none=True, exclude={"model"})
return await base_llm_http_handler.async_realtime_client_secret_handler(
api_base=resolved_api_base,
api_key=resolved_api_key,
request_data=request_data,
logging_obj=litellm_logging_obj,
timeout=timeout or request_timeout,
provider_config=provider_config,
model=model_name,
extra_headers=kwargs.get("extra_headers"),
client=kwargs.get("client"),
api_version=litellm_params.api_version,
)
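# Usage sketch (illustrative; assumes OPENAI_API_KEY is set in the environment
# and that "anchor"/"seconds" follow the provider's expires_after schema):
#
#     secret = await acreate_realtime_client_secret(
#         model="gpt-4o-realtime-preview",
#         expires_after={"anchor": "created_at", "seconds": 600},
#     )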
@wrapper_client
async def arealtime_calls(
openai_ephemeral_key: str,
sdp_body: bytes,
model: Optional[str] = None,
session: Optional[Dict[str, Any]] = None,
timeout: Optional[float] = None,
**kwargs,
):
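    """Exchange a WebRTC SDP offer for an SDP answer via the provider's realtime calls endpoint."""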
model_name = model or "gpt-4o-realtime-preview"
litellm_logging_obj: LiteLLMLogging = kwargs.get("litellm_logging_obj") # type: ignore
litellm_params = GenericLiteLLMParams(**kwargs)
(
model_name,
custom_llm_provider,
dynamic_api_key,
dynamic_api_base,
) = get_llm_provider(
model=model_name,
api_base=litellm_params.api_base,
api_key=litellm_params.api_key,
)
provider_config, resolved_api_base, _ = _get_realtime_http_provider_config(
custom_llm_provider=custom_llm_provider,
dynamic_api_base=dynamic_api_base,
dynamic_api_key=dynamic_api_key,
litellm_params=litellm_params,
)
litellm_logging_obj.update_environment_variables(
model=model_name,
optional_params={"realtime_calls": True, "session": session},
litellm_params={"api_base": resolved_api_base},
custom_llm_provider=custom_llm_provider,
)
return await base_llm_http_handler.async_realtime_calls_handler(
api_base=resolved_api_base,
openai_ephemeral_key=openai_ephemeral_key,
sdp_body=sdp_body,
logging_obj=litellm_logging_obj,
timeout=timeout or request_timeout,
provider_config=provider_config,
model=model_name,
session_config=session,
extra_headers=kwargs.get("extra_headers"),
client=kwargs.get("client"),
api_version=litellm_params.api_version,
)
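# Usage sketch (illustrative): exchange a WebRTC SDP offer for an SDP answer
# using an ephemeral key minted by acreate_realtime_client_secret(); "ek_..."
# and offer_sdp are placeholders.
#
#     answer = await arealtime_calls(
#         openai_ephemeral_key="ek_...",
#         sdp_body=offer_sdp.encode("utf-8"),
#         model="gpt-4o-realtime-preview",
#     )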
@wrapper_client
async def _arealtime( # noqa: PLR0915
model: str,
websocket: Any, # fastapi websocket
api_base: Optional[str] = None,
api_key: Optional[str] = None,
api_version: Optional[str] = None,
azure_ad_token: Optional[str] = None,
client: Optional[Any] = None,
timeout: Optional[float] = None,
query_params: Optional[RealtimeQueryParams] = None,
**kwargs,
):
"""
Private function to handle the realtime API call.
For PROXY use only.
"""
headers = cast(Optional[dict], kwargs.get("headers"))
extra_headers = cast(Optional[dict], kwargs.get("extra_headers"))
if headers is None:
headers = {}
if extra_headers is not None:
headers.update(extra_headers)
litellm_logging_obj: LiteLLMLogging = kwargs.get("litellm_logging_obj") # type: ignore
user = kwargs.get("user", None)
litellm_params = GenericLiteLLMParams(**kwargs)
litellm_params_dict = get_litellm_params(**kwargs)
model, _custom_llm_provider, dynamic_api_key, dynamic_api_base = get_llm_provider(
model=model,
api_base=api_base,
api_key=api_key,
)
# Ensure query params use the normalized provider model (no proxy aliases).
if query_params is not None:
query_params = {**query_params, "model": model}
litellm_logging_obj.update_environment_variables(
model=model,
user=user,
optional_params={},
litellm_params=litellm_params_dict,
custom_llm_provider=_custom_llm_provider,
)
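    # Prefer a provider-registered realtime config (generic websocket path);
    # providers without one fall through to their hand-written handlers below.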
provider_config: Optional[BaseRealtimeConfig] = None
if _custom_llm_provider in LlmProviders._member_map_.values():
provider_config = ProviderConfigManager.get_provider_realtime_config(
model=model,
provider=LlmProviders(_custom_llm_provider),
)
if provider_config is not None:
await base_llm_http_handler.async_realtime(
model=model,
websocket=websocket,
logging_obj=litellm_logging_obj,
provider_config=provider_config,
api_base=api_base,
api_key=api_key,
client=client,
timeout=timeout,
headers=headers,
user_api_key_dict=kwargs.get("user_api_key_dict"),
litellm_metadata=_build_litellm_metadata(kwargs),
)
elif _custom_llm_provider == "azure":
api_base = (
dynamic_api_base
or litellm_params.api_base
or litellm.api_base
or get_secret_str("AZURE_API_BASE")
)
# set API KEY
api_key = (
dynamic_api_key
or litellm.api_key
or litellm.openai_key
or get_secret_str("AZURE_API_KEY")
)
api_version = api_version or litellm_params.api_version or "2024-10-01-preview"
realtime_protocol = (
kwargs.get("realtime_protocol")
            or getattr(litellm_params, "realtime_protocol", None)
or os.environ.get("LITELLM_AZURE_REALTIME_PROTOCOL")
or "beta"
)
await azure_realtime.async_realtime(
model=model,
websocket=websocket,
api_base=api_base,
api_key=api_key,
api_version=api_version,
            azure_ad_token=azure_ad_token,
client=None,
timeout=timeout,
logging_obj=litellm_logging_obj,
realtime_protocol=realtime_protocol,
user_api_key_dict=kwargs.get("user_api_key_dict"),
litellm_metadata=_build_litellm_metadata(kwargs),
)
elif _custom_llm_provider == "openai":
api_base = (
dynamic_api_base
or litellm_params.api_base
or litellm.api_base
or "https://api.openai.com/"
)
# set API KEY
api_key = (
dynamic_api_key
or litellm.api_key
or litellm.openai_key
or get_secret_str("OPENAI_API_KEY")
)
await openai_realtime.async_realtime(
model=model,
websocket=websocket,
logging_obj=litellm_logging_obj,
api_base=api_base,
api_key=api_key,
client=None,
timeout=timeout,
query_params=query_params,
user_api_key_dict=kwargs.get("user_api_key_dict"),
litellm_metadata=_build_litellm_metadata(kwargs),
)
elif _custom_llm_provider == "bedrock":
# Extract AWS parameters from kwargs
aws_region_name = kwargs.get("aws_region_name")
aws_access_key_id = kwargs.get("aws_access_key_id")
aws_secret_access_key = kwargs.get("aws_secret_access_key")
aws_session_token = kwargs.get("aws_session_token")
aws_role_name = kwargs.get("aws_role_name")
aws_session_name = kwargs.get("aws_session_name")
aws_profile_name = kwargs.get("aws_profile_name")
aws_web_identity_token = kwargs.get("aws_web_identity_token")
aws_sts_endpoint = kwargs.get("aws_sts_endpoint")
aws_bedrock_runtime_endpoint = kwargs.get("aws_bedrock_runtime_endpoint")
aws_external_id = kwargs.get("aws_external_id")
await bedrock_realtime.async_realtime(
model=model,
websocket=websocket,
logging_obj=litellm_logging_obj,
api_base=dynamic_api_base or api_base,
api_key=dynamic_api_key or api_key,
timeout=timeout,
aws_region_name=aws_region_name,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
aws_role_name=aws_role_name,
aws_session_name=aws_session_name,
aws_profile_name=aws_profile_name,
aws_web_identity_token=aws_web_identity_token,
aws_sts_endpoint=aws_sts_endpoint,
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
aws_external_id=aws_external_id,
)
elif _custom_llm_provider == "xai":
api_base = (
dynamic_api_base
or litellm_params.api_base
or get_secret_str("XAI_API_BASE")
or "https://api.x.ai/v1"
)
# set API KEY
api_key = dynamic_api_key or litellm.api_key or get_secret_str("XAI_API_KEY")
await xai_realtime.async_realtime(
model=model,
websocket=websocket,
logging_obj=litellm_logging_obj,
api_base=api_base,
api_key=api_key,
client=None,
timeout=timeout,
query_params=query_params,
user_api_key_dict=kwargs.get("user_api_key_dict"),
litellm_metadata=_build_litellm_metadata(kwargs),
)
elif _custom_llm_provider == "vertex_ai":
vertex_credentials = (
kwargs.get("vertex_credentials")
or kwargs.get("vertex_ai_credentials")
or get_secret_str("VERTEXAI_CREDENTIALS")
)
vertex_project = (
kwargs.get("vertex_project")
or kwargs.get("vertex_ai_project")
or litellm.vertex_project
or get_secret_str("VERTEXAI_PROJECT")
)
vertex_location = (
kwargs.get("vertex_location")
or kwargs.get("vertex_ai_location")
or litellm.vertex_location
or get_secret_str("VERTEXAI_LOCATION")
)
resolved_location = vertex_llm_base.get_vertex_region(
vertex_region=vertex_location, model=model
)
(
access_token,
resolved_project,
) = await vertex_llm_base._ensure_access_token_async(
credentials=vertex_credentials,
project_id=vertex_project,
custom_llm_provider="vertex_ai",
)
vertex_realtime_config = VertexAIRealtimeConfig(
access_token=access_token,
project=resolved_project,
location=resolved_location,
)
await base_llm_http_handler.async_realtime(
model=model,
websocket=websocket,
logging_obj=litellm_logging_obj,
provider_config=vertex_realtime_config,
api_base=dynamic_api_base or litellm_params.api_base,
api_key=None,
client=client,
timeout=timeout,
headers=headers,
user_api_key_dict=kwargs.get("user_api_key_dict"),
litellm_metadata=_build_litellm_metadata(kwargs),
)
else:
raise ValueError(f"Unsupported model: {model}")
async def _realtime_health_check(
model: str,
custom_llm_provider: str,
api_key: Optional[str],
api_base: Optional[str] = None,
api_version: Optional[str] = None,
realtime_protocol: Optional[str] = None,
):
"""
    Health check for the realtime API - attempts a websocket connection to the provider's realtime endpoint.
    Args:
        model: str - model name
        custom_llm_provider: str - custom llm provider
        api_key: Optional[str] - api key
        api_base: Optional[str] - api base
        api_version: Optional[str] - api version
        realtime_protocol: Optional[str] - protocol version ("GA"/"v1" for the GA path, "beta"/None for the beta path)
    Returns:
        bool - True if the connection succeeds
    Raises:
        Exception - if the connection fails
"""
import websockets
url: Optional[str] = None
if custom_llm_provider == "azure":
url = azure_realtime._construct_url(
api_base=api_base or "",
model=model,
api_version=api_version or "2024-10-01-preview",
realtime_protocol=realtime_protocol,
)
elif custom_llm_provider == "openai":
url = openai_realtime._construct_url(
api_base=api_base or "https://api.openai.com/",
query_params={"model": model},
)
elif custom_llm_provider == "xai":
url = xai_realtime._construct_url(
api_base=api_base or "https://api.x.ai/v1", query_params={"model": model}
)
elif custom_llm_provider == "vertex_ai":
vertex_location = litellm.vertex_location or get_secret_str("VERTEXAI_LOCATION")
resolved_location = vertex_llm_base.get_vertex_region(
vertex_region=vertex_location, model=model
)
(
access_token,
resolved_project,
) = await vertex_llm_base._ensure_access_token_async(
credentials=None,
project_id=litellm.vertex_project or get_secret_str("VERTEXAI_PROJECT"),
custom_llm_provider="vertex_ai",
)
vertex_realtime_config = VertexAIRealtimeConfig(
access_token=access_token,
project=resolved_project,
location=resolved_location,
)
url = vertex_realtime_config.get_complete_url(api_base=api_base, model=model)
ssl_context = get_shared_realtime_ssl_context()
headers = vertex_realtime_config.validate_environment(
headers={}, model=model, api_key=None
)
async with websockets.connect( # type: ignore
url,
additional_headers=headers,
max_size=REALTIME_WEBSOCKET_MAX_MESSAGE_SIZE_BYTES,
ssl=ssl_context,
):
return True
else:
raise ValueError(f"Unsupported model: {model}")
ssl_context = get_shared_realtime_ssl_context()
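    # A successful websocket handshake is treated as healthy; connection errors
    # propagate to the caller.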
async with websockets.connect( # type: ignore
url,
additional_headers={
"api-key": api_key, # type: ignore
},
max_size=REALTIME_WEBSOCKET_MAX_MESSAGE_SIZE_BYTES,
ssl=ssl_context,
):
return True
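# Health-check sketch (illustrative): returns True on a successful websocket
# handshake and raises on failure.
#
#     healthy = await _realtime_health_check(
#         model="gpt-4o-realtime-preview",
#         custom_llm_provider="openai",
#         api_key=os.environ.get("OPENAI_API_KEY"),
#     )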