Files
lijiaoqiao/llm-gateway-competitors/litellm-wheel-src/litellm/llms/bedrock/realtime/handler.py
2026-03-26 20:06:14 +08:00

308 lines
12 KiB
Python

"""
This file contains the handler for AWS Bedrock Nova Sonic realtime API.
This uses aws_sdk_bedrock_runtime for bidirectional streaming with Nova Sonic.
"""
import asyncio
import json
from typing import Any, Optional
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from ..base_aws_llm import BaseAWSLLM
from .transformation import BedrockRealtimeConfig
class BedrockRealtime(BaseAWSLLM):
    """Handler for Bedrock Nova Sonic realtime speech-to-speech API.

    Bridges a client WebSocket and a Bedrock bidirectional stream: client
    messages are transformed with ``BedrockRealtimeConfig`` and sent to
    Bedrock, and Bedrock output is transformed back and sent to the client.
    """

    # A WebSocket close frame is a control frame (payload <= 125 bytes) and
    # spends 2 of those bytes on the status code, leaving 123 for the reason.
    _MAX_CLOSE_REASON_BYTES = 123

    def __init__(self):
        super().__init__()

    @staticmethod
    def _truncate_close_reason(
        reason: str, max_bytes: int = _MAX_CLOSE_REASON_BYTES
    ) -> str:
        """Return ``reason`` cut down to at most ``max_bytes`` of UTF-8.

        A multi-byte character split by the cut is dropped rather than
        emitted as invalid UTF-8.
        """
        encoded = reason.encode("utf-8")
        if len(encoded) <= max_bytes:
            return reason
        return encoded[:max_bytes].decode("utf-8", errors="ignore")

    async def async_realtime(
        self,
        model: str,
        websocket: Any,
        logging_obj: LiteLLMLogging,
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        timeout: Optional[float] = None,
        aws_region_name: Optional[str] = None,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
        aws_role_name: Optional[str] = None,
        aws_session_name: Optional[str] = None,
        aws_profile_name: Optional[str] = None,
        aws_web_identity_token: Optional[str] = None,
        aws_sts_endpoint: Optional[str] = None,
        aws_bedrock_runtime_endpoint: Optional[str] = None,
        aws_external_id: Optional[str] = None,
        **kwargs,
    ):
        """
        Establish bidirectional streaming connection with Bedrock Nova Sonic.

        Opens the Bedrock stream, then runs two forwarding tasks (client ->
        Bedrock and Bedrock -> client) and waits until both have finished.

        Args:
            model: Model ID (e.g., 'amazon.nova-sonic-v1:0')
            websocket: Client WebSocket connection
            logging_obj: LiteLLM logging object, passed to the response
                transformation
            api_base: Endpoint URL; takes precedence over
                ``aws_bedrock_runtime_endpoint``
            aws_region_name: AWS region; when None it is resolved through
                ``self._get_aws_region_name``
            aws_bedrock_runtime_endpoint: Endpoint URL used when ``api_base``
                is None. When both are None the regional
                ``bedrock-runtime`` endpoint is used.

        NOTE: ``api_key``, ``timeout`` and the remaining ``aws_*``
        credential/role parameters are accepted but never read in this
        method; the client is always configured with
        ``EnvironmentCredentialsResolver()``.

        Raises:
            ImportError: if ``aws_sdk_bedrock_runtime`` (or its smithy
                dependency) cannot be imported.
            Exception: anything raised while opening the stream or awaiting
                the forwarding tasks is logged, the client WebSocket is
                closed with code 1011 (best effort), and the error is
                re-raised.
        """
        try:
            from aws_sdk_bedrock_runtime.client import (
                BedrockRuntimeClient,
                InvokeModelWithBidirectionalStreamOperationInput,
            )
            from aws_sdk_bedrock_runtime.config import Config
            from smithy_aws_core.identity.environment import (
                EnvironmentCredentialsResolver,
            )
        except ImportError as e:
            # Chain the original error so the missing module is visible.
            raise ImportError(
                "Missing aws_sdk_bedrock_runtime. Install with: pip install aws-sdk-bedrock-runtime"
            ) from e

        # Get AWS region
        if aws_region_name is None:
            optional_params = {
                "aws_region_name": aws_region_name,
            }
            aws_region_name = self._get_aws_region_name(optional_params, model)

        # Get endpoint URL
        if api_base is not None:
            endpoint_uri = api_base
        elif aws_bedrock_runtime_endpoint is not None:
            endpoint_uri = aws_bedrock_runtime_endpoint
        else:
            endpoint_uri = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"

        verbose_proxy_logger.debug(
            f"Bedrock Realtime: Connecting to {endpoint_uri} with model {model}"
        )

        # Initialize Bedrock client with aws_sdk_bedrock_runtime
        config = Config(
            endpoint_uri=endpoint_uri,
            region=aws_region_name,
            aws_credentials_identity_resolver=EnvironmentCredentialsResolver(),
        )
        bedrock_client = BedrockRuntimeClient(config=config)
        transformation_config = BedrockRealtimeConfig()

        try:
            # Initialize the bidirectional stream
            bedrock_stream = (
                await bedrock_client.invoke_model_with_bidirectional_stream(
                    InvokeModelWithBidirectionalStreamOperationInput(model_id=model)
                )
            )
            verbose_proxy_logger.debug(
                "Bedrock Realtime: Bidirectional stream established"
            )

            # Track state for transformation. The dict is shared by both
            # forwarding tasks: the Bedrock -> client task writes it and the
            # client -> Bedrock task reads "session_configuration_request".
            session_state: dict = {
                "current_output_item_id": None,
                "current_response_id": None,
                "current_conversation_id": None,
                "current_delta_chunks": None,
                "current_item_chunks": None,
                "current_delta_type": None,
                "session_configuration_request": None,
            }

            # Create tasks for bidirectional forwarding
            client_to_bedrock_task = asyncio.create_task(
                self._forward_client_to_bedrock(
                    websocket,
                    bedrock_stream,
                    transformation_config,
                    model,
                    session_state,
                )
            )
            bedrock_to_client_task = asyncio.create_task(
                self._forward_bedrock_to_client(
                    bedrock_stream,
                    websocket,
                    transformation_config,
                    model,
                    logging_obj,
                    session_state,
                )
            )

            # Wait for both tasks to complete. return_exceptions=True keeps a
            # failure in one direction from propagating out of gather; each
            # forwarder closes the opposite connection when it exits.
            await asyncio.gather(
                client_to_bedrock_task,
                bedrock_to_client_task,
                return_exceptions=True,
            )
        except Exception as e:
            verbose_proxy_logger.exception(
                f"Error in BedrockRealtime.async_realtime: {e}"
            )
            # An over-long reason makes the close itself fail, which would
            # leave the client without the 1011 close; keep it within limits.
            close_reason = self._truncate_close_reason(f"Internal error: {str(e)}")
            try:
                await websocket.close(code=1011, reason=close_reason)
            except Exception as close_err:
                # Best effort: the socket may already be closed.
                verbose_proxy_logger.debug(
                    f"Bedrock Realtime: Failed to close client websocket: {close_err}"
                )
            raise

    async def _forward_client_to_bedrock(
        self,
        client_ws: Any,
        bedrock_stream: Any,
        transformation_config: BedrockRealtimeConfig,
        model: str,
        session_state: dict,
    ):
        """Forward messages from client WebSocket to Bedrock stream.

        Loops until receiving, transforming or sending raises. Exceptions are
        logged at debug level and not re-raised. The Bedrock input stream is
        closed on every exit path, including task cancellation.

        Args:
            client_ws: Client WebSocket; text frames are read from it.
            bedrock_stream: Bedrock bidirectional stream; events are sent on
                its ``input_stream``.
            transformation_config: Converts each client message into a
                sequence of Bedrock message strings.
            model: Model ID passed through to the transformation.
            session_state: Shared state dict; only
                "session_configuration_request" is read here.
        """
        try:
            from aws_sdk_bedrock_runtime.models import (
                BidirectionalInputPayloadPart,
                InvokeModelWithBidirectionalStreamInputChunk,
            )

            while True:
                # Receive message from client
                message = await client_ws.receive_text()
                verbose_proxy_logger.debug(
                    f"Bedrock Realtime: Received from client: {message[:200]}"
                )

                # Transform OpenAI format to Bedrock format
                transformed_messages = transformation_config.transform_realtime_request(
                    message=message,
                    model=model,
                    session_configuration_request=session_state.get(
                        "session_configuration_request"
                    ),
                )

                # Send transformed messages to Bedrock
                for bedrock_message in transformed_messages:
                    event = InvokeModelWithBidirectionalStreamInputChunk(
                        value=BidirectionalInputPayloadPart(
                            bytes_=bedrock_message.encode("utf-8")
                        )
                    )
                    await bedrock_stream.input_stream.send(event)
                    verbose_proxy_logger.debug(
                        f"Bedrock Realtime: Sent to Bedrock: {bedrock_message[:200]}"
                    )
        except Exception as e:
            verbose_proxy_logger.debug(
                f"Client to Bedrock forwarding ended: {e}", exc_info=True
            )
        finally:
            # Close the Bedrock stream input. Done in `finally` because
            # asyncio.CancelledError is not an Exception subclass, so an
            # `except Exception` cleanup would be skipped on cancellation.
            try:
                await bedrock_stream.input_stream.close()
            except Exception as close_err:
                # Best effort: the stream may already be closed.
                verbose_proxy_logger.debug(
                    f"Bedrock Realtime: Failed to close Bedrock input stream: {close_err}"
                )

    async def _forward_bedrock_to_client(
        self,
        bedrock_stream: Any,
        client_ws: Any,
        transformation_config: BedrockRealtimeConfig,
        model: str,
        logging_obj: LiteLLMLogging,
        session_state: dict,
    ):
        """Forward messages from Bedrock stream to client WebSocket.

        Loops until reading, transforming or sending raises. Exceptions are
        logged at debug level and not re-raised. The client WebSocket is
        closed on every exit path, including task cancellation.

        Args:
            bedrock_stream: Bedrock bidirectional stream to read output from.
            client_ws: Client WebSocket; transformed messages are sent as
                JSON text frames.
            transformation_config: Converts each Bedrock payload into a dict
                holding the outgoing "response" list and updated state.
            model: Model ID passed through to the transformation.
            logging_obj: LiteLLM logging object passed to the transformation.
            session_state: Shared state dict; read before and overwritten
                after every transformed payload.
        """
        try:
            # Hoisted out of the loop: the import is loop-invariant.
            from litellm.types.realtime import RealtimeResponseTransformInput

            while True:
                # Receive from Bedrock
                output = await bedrock_stream.await_output()
                result = await output[1].receive()

                # Skip events that carry no payload bytes.
                if not (result.value and result.value.bytes_):
                    continue

                bedrock_response = result.value.bytes_.decode("utf-8")
                verbose_proxy_logger.debug(
                    f"Bedrock Realtime: Received from Bedrock: {bedrock_response[:200]}"
                )

                # Transform Bedrock format to OpenAI format
                realtime_response_transform_input: RealtimeResponseTransformInput = {
                    "current_output_item_id": session_state.get(
                        "current_output_item_id"
                    ),
                    "current_response_id": session_state.get("current_response_id"),
                    "current_conversation_id": session_state.get(
                        "current_conversation_id"
                    ),
                    "current_delta_chunks": session_state.get(
                        "current_delta_chunks"
                    ),
                    "current_item_chunks": session_state.get("current_item_chunks"),
                    "current_delta_type": session_state.get("current_delta_type"),
                    "session_configuration_request": session_state.get(
                        "session_configuration_request"
                    ),
                }
                transformed_response = transformation_config.transform_realtime_response(
                    message=bedrock_response,
                    model=model,
                    logging_obj=logging_obj,
                    realtime_response_transform_input=realtime_response_transform_input,
                )

                # Update session state. Keys missing from the transformed
                # response are reset to None, since .get() defaults to None.
                session_state.update(
                    {
                        "current_output_item_id": transformed_response.get(
                            "current_output_item_id"
                        ),
                        "current_response_id": transformed_response.get(
                            "current_response_id"
                        ),
                        "current_conversation_id": transformed_response.get(
                            "current_conversation_id"
                        ),
                        "current_delta_chunks": transformed_response.get(
                            "current_delta_chunks"
                        ),
                        "current_item_chunks": transformed_response.get(
                            "current_item_chunks"
                        ),
                        "current_delta_type": transformed_response.get(
                            "current_delta_type"
                        ),
                        "session_configuration_request": transformed_response.get(
                            "session_configuration_request"
                        ),
                    }
                )

                # Send transformed messages to client
                openai_messages = transformed_response.get("response", [])
                for openai_message in openai_messages:
                    message_json = json.dumps(openai_message)
                    await client_ws.send_text(message_json)
                    verbose_proxy_logger.debug(
                        f"Bedrock Realtime: Sent to client: {message_json[:200]}"
                    )
        except Exception as e:
            verbose_proxy_logger.debug(
                f"Bedrock to client forwarding ended: {e}", exc_info=True
            )
        finally:
            # Close the client WebSocket. Done in `finally` because
            # asyncio.CancelledError is not an Exception subclass, so an
            # `except Exception` cleanup would be skipped on cancellation.
            try:
                await client_ws.close()
            except Exception as close_err:
                # Best effort: the socket may already be closed.
                verbose_proxy_logger.debug(
                    f"Bedrock Realtime: Failed to close client websocket: {close_err}"
                )