chore: initial public snapshot for github upload

This commit is contained in:
Your Name
2026-03-26 20:06:14 +08:00
commit 0e5ecd930e
3497 changed files with 1586236 additions and 0 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,97 @@
from openai.types.batch import BatchRequestCounts
from openai.types.batch import Metadata as OpenAIBatchMetadata
from litellm.types.utils import LiteLLMBatch
class BedrockBatchesHandler:
    """
    Handler for Bedrock Batches.

    Specific providers/models needed some special handling.
    E.g. Twelve Labs Embedding Async Invoke
    """

    @staticmethod
    def _handle_async_invoke_status(
        batch_id: str, aws_region_name: str, logging_obj=None, **kwargs
    ) -> "LiteLLMBatch":
        """
        Handle async invoke status check for AWS Bedrock.

        This is for Twelve Labs Embedding Async Invoke.

        Args:
            batch_id: The async invoke ARN (reused as the batch id).
            aws_region_name: AWS region name.
            logging_obj: Optional LiteLLM logging object forwarded to the
                status lookup.
            **kwargs: Additional parameters forwarded to the status lookup.

        Returns:
            LiteLLMBatch: OpenAI-batch-shaped status; the S3 output URI is
            surfaced via metadata["output_file_id"].
        """
        import asyncio
        import concurrent.futures

        from litellm.llms.bedrock.embed.embedding import BedrockEmbedding

        async def _async_get_status() -> LiteLLMBatch:
            # Create embedding handler instance
            embedding_handler = BedrockEmbedding()
            # Get the status of the async invoke job
            status_response = await embedding_handler._get_async_invoke_status(
                invocation_arn=batch_id,
                aws_region_name=aws_region_name,
                logging_obj=logging_obj,
                **kwargs,
            )
            # Transform response to a LiteLLMBatch object.
            # (LiteLLMBatch is imported at module level; the previous
            # duplicate function-local import was removed.)
            openai_batch_metadata: OpenAIBatchMetadata = {
                "output_file_id": status_response["outputDataConfig"][
                    "s3OutputDataConfig"
                ]["s3Uri"],
                "failure_message": status_response.get("failureMessage") or "",
                "model_arn": status_response["modelArn"],
            }
            # NOTE(review): comparisons assume lowercase status values
            # ("completed"/"failed") — confirm against
            # _get_async_invoke_status's normalization.
            return LiteLLMBatch(
                id=status_response["invocationArn"],
                object="batch",
                status=status_response["status"],
                created_at=status_response["submitTime"],
                in_progress_at=status_response["lastModifiedTime"],
                completed_at=status_response.get("endTime"),
                failed_at=status_response.get("endTime")
                if status_response["status"] == "failed"
                else None,
                request_counts=BatchRequestCounts(
                    total=1,
                    completed=1 if status_response["status"] == "completed" else 0,
                    failed=1 if status_response["status"] == "failed" else 0,
                ),
                metadata=openai_batch_metadata,
                completion_window="24h",
                endpoint="/v1/embeddings",
                input_file_id="",
            )

        # Since this function is called from within an async context via
        # run_in_executor, run the coroutine on a fresh event loop in a
        # worker thread to avoid "event loop already running" conflicts.
        def run_in_thread() -> "LiteLLMBatch":
            new_loop = asyncio.new_event_loop()
            asyncio.set_event_loop(new_loop)
            try:
                return new_loop.run_until_complete(_async_get_status())
            finally:
                new_loop.close()

        with concurrent.futures.ThreadPoolExecutor() as executor:
            future = executor.submit(run_in_thread)
            return future.result()

View File

@@ -0,0 +1,549 @@
import os
import time
from typing import Any, Dict, List, Literal, Optional, Union, cast
from httpx import Headers, Response
from litellm.llms.base_llm.batches.transformation import BaseBatchesConfig
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.bedrock import (
BedrockCreateBatchRequest,
BedrockCreateBatchResponse,
BedrockInputDataConfig,
BedrockOutputDataConfig,
BedrockS3InputDataConfig,
BedrockS3OutputDataConfig,
)
from litellm.types.llms.openai import (
AllMessageValues,
CreateBatchRequest,
)
from litellm.types.utils import LiteLLMBatch, LlmProviders
from ..base_aws_llm import BaseAWSLLM
from ..common_utils import CommonBatchFilesUtils
class BedrockBatchesConfig(BaseAWSLLM, BaseBatchesConfig):
    """
    Config for Bedrock Batches - handles batch job creation and management for Bedrock
    """
    def __init__(self):
        super().__init__()
        # Shared helpers for S3 URI parsing, job naming and SigV4 request signing.
        self.common_utils = CommonBatchFilesUtils()
    @property
    def custom_llm_provider(self) -> LlmProviders:
        # Identifies this config as the Bedrock provider to the batches router.
        return LlmProviders.BEDROCK
    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """
        Validate and prepare environment for Bedrock batch requests.
        AWS credentials are handled by BaseAWSLLM.
        """
        # Add any Bedrock-specific headers if needed
        return headers
    def get_complete_batch_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: Dict,
        litellm_params: Dict,
        data: CreateBatchRequest,
    ) -> str:
        """
        Get the complete URL for Bedrock batch creation.
        Bedrock batch jobs are created via the model invocation job API.
        """
        aws_region_name = self._get_aws_region_name(optional_params, model)
        # Bedrock model invocation job endpoint
        # Format: https://bedrock.{region}.amazonaws.com/model-invocation-job
        bedrock_endpoint = (
            f"https://bedrock.{aws_region_name}.amazonaws.com/model-invocation-job"
        )
        return bedrock_endpoint
    def transform_create_batch_request(
        self,
        model: str,
        create_batch_data: CreateBatchRequest,
        optional_params: dict,
        litellm_params: dict,
    ) -> Dict[str, Any]:
        """
        Transform the batch creation request to Bedrock format.
        Bedrock batch inference requires:
        - modelId: The Bedrock model ID
        - jobName: Unique name for the batch job
        - inputDataConfig: Configuration for input data (S3 location)
        - outputDataConfig: Configuration for output data (S3 location)
        - roleArn: IAM role ARN for the batch job

        Returns a pre-signed request dict ({method, url, headers, data})
        rather than a plain body, so the HTTP handler can replay it verbatim
        without re-signing.
        """
        # Get required parameters
        input_file_id = create_batch_data.get("input_file_id")
        if not input_file_id:
            raise ValueError("input_file_id is required for Bedrock batch creation")
        # Extract S3 information from file ID using common utility
        input_bucket, input_key = self.common_utils.parse_s3_uri(input_file_id)
        # Get output S3 configuration
        output_bucket = litellm_params.get("s3_output_bucket_name") or os.getenv(
            "AWS_S3_OUTPUT_BUCKET_NAME"
        )
        if not output_bucket:
            # Use same bucket as input if no output bucket specified
            output_bucket = input_bucket
        # Get IAM role ARN
        role_arn = (
            litellm_params.get("aws_batch_role_arn")
            or optional_params.get("aws_batch_role_arn")
            or os.getenv("AWS_BATCH_ROLE_ARN")
        )
        if not role_arn:
            raise ValueError(
                "AWS IAM role ARN is required for Bedrock batch jobs. "
                "Set 'aws_batch_role_arn' in litellm_params or AWS_BATCH_ROLE_ARN env var"
            )
        if not model:
            raise ValueError(
                "Could not determine Bedrock model ID. Please pass `model` in your request body."
            )
        # Generate job name with the correct model ID using common utility
        job_name = self.common_utils.generate_unique_job_name(model, prefix="litellm")
        # Per-job output prefix keeps concurrent jobs from clobbering each other.
        output_key = f"litellm-batch-outputs/{job_name}/"
        # Build input data config
        input_data_config: BedrockInputDataConfig = {
            "s3InputDataConfig": BedrockS3InputDataConfig(
                s3Uri=f"s3://{input_bucket}/{input_key}"
            )
        }
        # Build output data config
        s3_output_config: BedrockS3OutputDataConfig = BedrockS3OutputDataConfig(
            s3Uri=f"s3://{output_bucket}/{output_key}"
        )
        # Add optional KMS encryption key ID if provided
        s3_encryption_key_id = litellm_params.get(
            "s3_encryption_key_id"
        ) or get_secret_str("AWS_S3_ENCRYPTION_KEY_ID")
        if s3_encryption_key_id:
            s3_output_config["s3EncryptionKeyId"] = s3_encryption_key_id
        output_data_config: BedrockOutputDataConfig = {
            "s3OutputDataConfig": s3_output_config
        }
        # Create Bedrock batch request with proper typing
        bedrock_request: BedrockCreateBatchRequest = {
            "modelId": model,
            "jobName": job_name,
            "inputDataConfig": input_data_config,
            "outputDataConfig": output_data_config,
            "roleArn": role_arn,
        }
        # Add optional parameters if provided
        completion_window = create_batch_data.get("completion_window")
        if completion_window:
            # Map OpenAI completion window to Bedrock timeout
            # OpenAI uses "24h", Bedrock expects timeout in hours
            if completion_window == "24h":
                bedrock_request["timeoutDurationInHours"] = 24
        # For Bedrock, we need to return a pre-signed request with AWS auth headers
        # Use common utility for AWS signing
        endpoint_url = f"https://bedrock.{self._get_aws_region_name(optional_params, model)}.amazonaws.com/model-invocation-job"
        signed_headers, signed_data = self.common_utils.sign_aws_request(
            service_name="bedrock",
            data=bedrock_request,
            endpoint_url=endpoint_url,
            optional_params=optional_params,
            method="POST",
        )
        # Return a pre-signed request format that the HTTP handler can use
        return {
            "method": "POST",
            "url": endpoint_url,
            "headers": signed_headers,
            "data": signed_data.decode("utf-8"),
        }
    def transform_create_batch_response(
        self,
        model: Optional[str],
        raw_response: Response,
        logging_obj: Any,
        litellm_params: dict,
    ) -> LiteLLMBatch:
        """
        Transform Bedrock batch creation response to LiteLLM format.
        """
        try:
            response_data: BedrockCreateBatchResponse = raw_response.json()
        except Exception as e:
            raise ValueError(f"Failed to parse Bedrock batch response: {e}")
        # Extract information from typed Bedrock response
        job_arn = response_data.get("jobArn", "")
        status_str: str = str(response_data.get("status", "Submitted"))
        # Map Bedrock status to OpenAI-compatible status
        # NOTE(review): this mapping is duplicated in
        # transform_retrieve_batch_response below — consider extracting a
        # shared class-level constant to keep the two in sync.
        status_mapping: Dict[str, str] = {
            "Submitted": "validating",
            "Validating": "validating",
            "Scheduled": "in_progress",
            "InProgress": "in_progress",
            "PartiallyCompleted": "completed",
            "Completed": "completed",
            "Failed": "failed",
            "Stopping": "cancelling",
            "Stopped": "cancelled",
            "Expired": "expired",
        }
        openai_status = cast(
            Literal[
                "validating",
                "failed",
                "in_progress",
                "finalizing",
                "completed",
                "expired",
                "cancelling",
                "cancelled",
            ],
            status_mapping.get(status_str, "validating"),
        )
        # Get original request data from litellm_params if available
        original_request = litellm_params.get("original_batch_request", {})
        # Create LiteLLM batch object
        return LiteLLMBatch(
            id=job_arn,  # Use ARN as the batch ID
            object="batch",
            endpoint=original_request.get("endpoint", "/v1/chat/completions"),
            errors=None,
            input_file_id=original_request.get("input_file_id", ""),
            completion_window=original_request.get("completion_window", "24h"),
            status=openai_status,
            output_file_id=None,  # Will be populated when job completes
            error_file_id=None,
            created_at=int(time.time()),
            in_progress_at=int(time.time()) if status_str == "InProgress" else None,
            expires_at=None,
            finalizing_at=None,
            completed_at=None,
            failed_at=None,
            expired_at=None,
            cancelling_at=None,
            cancelled_at=None,
            request_counts=None,
            metadata=original_request.get("metadata", {}),
        )
    def transform_retrieve_batch_request(
        self,
        batch_id: str,
        optional_params: dict,
        litellm_params: dict,
    ) -> Dict[str, Any]:
        """
        Transform batch retrieval request for Bedrock.
        Args:
            batch_id: Bedrock job ARN
            optional_params: Optional parameters
            litellm_params: LiteLLM parameters
        Returns:
            Transformed request data for Bedrock GetModelInvocationJob API
        """
        # For Bedrock, batch_id should be the full job ARN
        # The GetModelInvocationJob API expects the full ARN as the identifier
        if not batch_id.startswith("arn:aws:bedrock:"):
            raise ValueError(f"Invalid batch_id format. Expected ARN, got: {batch_id}")
        # Extract the job identifier from the ARN - use the full ARN path part
        # ARN format: arn:aws:bedrock:region:account:model-invocation-job/job-name
        arn_parts = batch_id.split(":")
        if len(arn_parts) < 6:
            raise ValueError(f"Invalid ARN format: {batch_id}")
        # Region is the 4th colon-separated field of the ARN.
        region = arn_parts[3]
        # arn_parts[5] contains "model-invocation-job/{jobId}"
        # Build the endpoint URL for GetModelInvocationJob
        # AWS API format: GET /model-invocation-job/{jobIdentifier}
        # Use the FULL ARN as jobIdentifier and URL-encode it (includes ':' and '/')
        import urllib.parse as _ul
        encoded_arn = _ul.quote(batch_id, safe="")
        endpoint_url = (
            f"https://bedrock.{region}.amazonaws.com/model-invocation-job/{encoded_arn}"
        )
        # Use common utility for AWS signing
        signed_headers, _ = self.common_utils.sign_aws_request(
            service_name="bedrock",
            data={},  # GET request has no body
            endpoint_url=endpoint_url,
            optional_params=optional_params,
            method="GET",
        )
        # Return pre-signed request format
        return {
            "method": "GET",
            "url": endpoint_url,
            "headers": signed_headers,
            "data": None,
        }
    def _parse_timestamps_and_status(self, response_data, status_str: str):
        """Helper to parse timestamps based on status.

        Returns a 6-tuple of epoch seconds (or None):
        (created_at, in_progress_at, completed_at, failed_at, cancelled_at,
        expires_at). endTime is attributed to completed/failed/cancelled
        depending on status_str.
        """
        import datetime
        def parse_timestamp(ts_str: Optional[str]) -> Optional[int]:
            # Returns None on missing or unparseable timestamps rather than raising.
            if not ts_str:
                return None
            try:
                dt = datetime.datetime.fromisoformat(ts_str.replace("Z", "+00:00"))
                return int(dt.timestamp())
            except Exception:
                return None
        # Callers str() the raw values first: str(datetime) round-trips through
        # fromisoformat, and ISO strings pass through unchanged.
        created_at = parse_timestamp(
            str(response_data.get("submitTime"))
            if response_data.get("submitTime") is not None
            else None
        )
        in_progress_states = {"InProgress", "Validating", "Scheduled"}
        in_progress_at = (
            parse_timestamp(
                str(response_data.get("lastModifiedTime"))
                if response_data.get("lastModifiedTime") is not None
                else None
            )
            if status_str in in_progress_states
            else None
        )
        completed_at = (
            parse_timestamp(
                str(response_data.get("endTime"))
                if response_data.get("endTime") is not None
                else None
            )
            if status_str in {"Completed", "PartiallyCompleted"}
            else None
        )
        failed_at = (
            parse_timestamp(
                str(response_data.get("endTime"))
                if response_data.get("endTime") is not None
                else None
            )
            if status_str == "Failed"
            else None
        )
        cancelled_at = (
            parse_timestamp(
                str(response_data.get("endTime"))
                if response_data.get("endTime") is not None
                else None
            )
            if status_str == "Stopped"
            else None
        )
        expires_at = parse_timestamp(
            str(response_data.get("jobExpirationTime"))
            if response_data.get("jobExpirationTime") is not None
            else None
        )
        return (
            created_at,
            in_progress_at,
            completed_at,
            failed_at,
            cancelled_at,
            expires_at,
        )
    def _extract_file_configs(self, response_data):
        """Helper to extract input and output file configurations.

        Returns (input_file_id, output_file_id) as S3 URIs; input defaults to
        "" and output to None when not present.
        """
        # Extract input file ID
        input_file_id = ""
        input_data_config = response_data.get("inputDataConfig", {})
        if isinstance(input_data_config, dict):
            s3_input_config = input_data_config.get("s3InputDataConfig", {})
            if isinstance(s3_input_config, dict):
                input_file_id = s3_input_config.get("s3Uri", "")
        # Extract output file ID
        output_file_id = None
        output_data_config = response_data.get("outputDataConfig", {})
        if isinstance(output_data_config, dict):
            s3_output_config = output_data_config.get("s3OutputDataConfig", {})
            if isinstance(s3_output_config, dict):
                output_file_id = s3_output_config.get("s3Uri", "")
        return input_file_id, output_file_id
    def _extract_errors_and_metadata(self, response_data, raw_response):
        """Helper to extract errors and enriched metadata.

        Returns (errors, enriched_metadata): errors is an OpenAI Errors object
        built from Bedrock's "message" field (or None), and enriched_metadata
        is a str->str dict of useful Bedrock job fields (dicts/lists are
        JSON-encoded so the metadata stays string-valued).
        """
        # Extract errors
        message = response_data.get("message")
        errors = None
        if message:
            from openai.types.batch import Errors
            from openai.types.batch_error import BatchError
            errors = Errors(
                data=[BatchError(message=message, code=str(raw_response.status_code))],
                object="list",
            )
        # Enrich metadata with useful Bedrock fields
        enriched_metadata_raw: Dict[str, Any] = {
            "jobName": response_data.get("jobName"),
            "clientRequestToken": response_data.get("clientRequestToken"),
            "modelId": response_data.get("modelId"),
            "roleArn": response_data.get("roleArn"),
            "timeoutDurationInHours": response_data.get("timeoutDurationInHours"),
            "vpcConfig": response_data.get("vpcConfig"),
        }
        import json as _json
        enriched_metadata: Dict[str, str] = {}
        for _k, _v in enriched_metadata_raw.items():
            if _v is None:
                continue
            if isinstance(_v, (dict, list)):
                try:
                    enriched_metadata[_k] = _json.dumps(_v)
                except Exception:
                    enriched_metadata[_k] = str(_v)
            else:
                enriched_metadata[_k] = str(_v)
        return errors, enriched_metadata
    def transform_retrieve_batch_response(
        self,
        model: Optional[str],
        raw_response: Response,
        logging_obj: Any,
        litellm_params: dict,
    ) -> LiteLLMBatch:
        """
        Transform Bedrock batch retrieval response to LiteLLM format.
        """
        from litellm.types.llms.bedrock import BedrockGetBatchResponse
        try:
            response_data: BedrockGetBatchResponse = raw_response.json()
        except Exception as e:
            raise ValueError(f"Failed to parse Bedrock batch response: {e}")
        job_arn = response_data.get("jobArn", "")
        status_str: str = str(response_data.get("status", "Submitted"))
        # Map Bedrock status to OpenAI-compatible status
        # NOTE(review): duplicated from transform_create_batch_response —
        # keep both tables in sync when adding statuses.
        status_mapping: Dict[str, str] = {
            "Submitted": "validating",
            "Validating": "validating",
            "Scheduled": "in_progress",
            "InProgress": "in_progress",
            "PartiallyCompleted": "completed",
            "Completed": "completed",
            "Failed": "failed",
            "Stopping": "cancelling",
            "Stopped": "cancelled",
            "Expired": "expired",
        }
        openai_status = cast(
            Literal[
                "validating",
                "failed",
                "in_progress",
                "finalizing",
                "completed",
                "expired",
                "cancelling",
                "cancelled",
            ],
            status_mapping.get(status_str, "validating"),
        )
        # Parse timestamps
        (
            created_at,
            in_progress_at,
            completed_at,
            failed_at,
            cancelled_at,
            expires_at,
        ) = self._parse_timestamps_and_status(response_data, status_str)
        # Extract file configurations
        input_file_id, output_file_id = self._extract_file_configs(response_data)
        # Extract errors and metadata
        errors, enriched_metadata = self._extract_errors_and_metadata(
            response_data, raw_response
        )
        return LiteLLMBatch(
            id=job_arn,
            object="batch",
            endpoint="/v1/chat/completions",
            errors=errors,
            input_file_id=input_file_id,
            completion_window="24h",
            status=openai_status,
            output_file_id=output_file_id,
            error_file_id=None,
            created_at=created_at or int(time.time()),
            in_progress_at=in_progress_at,
            expires_at=expires_at,
            finalizing_at=None,
            completed_at=completed_at,
            failed_at=failed_at,
            expired_at=None,
            cancelling_at=None,
            cancelled_at=cancelled_at,
            request_counts=None,
            metadata=enriched_metadata,
        )
    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[Dict, Headers]
    ) -> BaseLLMException:
        """
        Get Bedrock-specific error class using common utility.
        """
        return self.common_utils.get_error_class(error_message, status_code, headers)

View File

@@ -0,0 +1,30 @@
from typing import Optional
from .converse_handler import BedrockConverseLLM
from .invoke_handler import (
AmazonAnthropicClaudeStreamDecoder,
AmazonDeepSeekR1StreamDecoder,
AWSEventStreamDecoder,
BedrockLLM,
)
def get_bedrock_event_stream_decoder(
    invoke_provider: Optional[str], model: str, sync_stream: bool, json_mode: bool
) -> AWSEventStreamDecoder:
    """
    Select the event-stream decoder for a Bedrock invoke call.

    Args:
        invoke_provider: Bedrock invoke provider name (e.g. "anthropic",
            "deepseek_r1"); None/other values fall back to the generic decoder.
        model: Model name passed through to the decoder.
        sync_stream: Whether the stream is consumed synchronously
            (only used by the Anthropic decoder).
        json_mode: Whether JSON mode is enabled
            (only used by the Anthropic decoder).

    Returns:
        A provider-specific AWSEventStreamDecoder subclass, or the base
        AWSEventStreamDecoder when no provider-specific decoder exists.
    """
    # Equality with a non-empty literal already implies truthiness, so the
    # previous `invoke_provider and invoke_provider == ...` guards were redundant.
    if invoke_provider == "anthropic":
        return AmazonAnthropicClaudeStreamDecoder(
            model=model,
            sync_stream=sync_stream,
            json_mode=json_mode,
        )
    if invoke_provider == "deepseek_r1":
        return AmazonDeepSeekR1StreamDecoder(
            model=model,
            sync_stream=sync_stream,
        )
    return AWSEventStreamDecoder(model=model)

View File

@@ -0,0 +1,3 @@
# Package public API: re-export the AgentCore config from its transformation module.
from .transformation import AmazonAgentCoreConfig
__all__ = ["AmazonAgentCoreConfig"]

View File

@@ -0,0 +1,512 @@
import json
from typing import Any, Optional, Union
import httpx
import litellm
from litellm.anthropic_beta_headers_manager import (
update_headers_with_filtered_beta,
)
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObject
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_httpx_client,
get_async_httpx_client,
)
from litellm.types.utils import ModelResponse
from litellm.utils import CustomStreamWrapper
from ..base_aws_llm import BaseAWSLLM, Credentials
from ..common_utils import BedrockError, _get_all_bedrock_regions
from .invoke_handler import AWSEventStreamDecoder, MockResponseIterator, make_call
def make_sync_call(
    client: Optional[HTTPHandler],
    api_base: str,
    headers: dict,
    data: str,
    model: str,
    messages: list,
    logging_obj: LiteLLMLoggingObject,
    json_mode: Optional[bool] = False,
    fake_stream: bool = False,
    stream_chunk_size: int = 1024,
):
    """
    POST a Converse request synchronously and return a chunk iterator.

    Args:
        client: Optional pre-built sync HTTP client; a new one is created when None.
        api_base: Fully-signed Converse(-stream) endpoint URL.
        headers: Pre-signed request headers (SigV4).
        data: JSON-serialized request body.
        model: Model name, forwarded to the decoder/transformer.
        messages: Original messages, used only for logging and response transform.
        logging_obj: LiteLLM logging object for pre/post call hooks.
        json_mode: Whether JSON mode is enabled; forwarded to the decoder.
        fake_stream: When True, performs a non-streaming call and wraps the
            full response in a mock single-chunk iterator.
        stream_chunk_size: Byte chunk size used when iterating the raw stream.

    Returns:
        An iterator of decoded stream chunks (MockResponseIterator when
        fake_stream, AWSEventStreamDecoder output otherwise).

    Raises:
        BedrockError: On any non-200 response, with the raw body as message.
    """
    if client is None:
        client = _get_httpx_client()  # Create a new client if none provided
    # stream=not fake_stream: fake streaming needs the whole body up front.
    response = client.post(
        api_base,
        headers=headers,
        data=data,
        stream=not fake_stream,
        logging_obj=logging_obj,
    )
    if response.status_code != 200:
        raise BedrockError(
            status_code=response.status_code, message=str(response.read())
        )
    if fake_stream:
        # Transform the complete response once, then replay it as a stream.
        model_response: (
            ModelResponse
        ) = litellm.AmazonConverseConfig()._transform_response(
            model=model,
            response=response,
            model_response=litellm.ModelResponse(),
            stream=True,
            logging_obj=logging_obj,
            optional_params={},
            api_key="",
            data=data,
            messages=messages,
            encoding=litellm.encoding,
        )  # type: ignore
        completion_stream: Any = MockResponseIterator(
            model_response=model_response, json_mode=json_mode
        )
    else:
        # Real streaming: decode AWS event-stream frames as bytes arrive.
        decoder = AWSEventStreamDecoder(model=model, json_mode=json_mode)
        completion_stream = decoder.iter_bytes(
            response.iter_bytes(chunk_size=stream_chunk_size)
        )
    # LOGGING
    logging_obj.post_call(
        input=messages,
        api_key="",
        original_response="first stream response received",
        additional_args={"complete_input_dict": data},
    )
    return completion_stream
class BedrockConverseLLM(BaseAWSLLM):
    """
    Bedrock Converse API handler.

    Routes chat completions to the Bedrock `converse` / `converse-stream`
    endpoints, handling AWS credential resolution, SigV4 signing, and the
    sync/async and streaming/non-streaming variants.
    """
    def __init__(self) -> None:
        super().__init__()
    async def async_streaming(
        self,
        model: str,
        messages: list,
        api_base: str,
        model_response: ModelResponse,
        timeout: Optional[Union[float, httpx.Timeout]],
        encoding,
        logging_obj,
        stream,
        optional_params: dict,
        litellm_params: dict,
        credentials: Credentials,
        logger_fn=None,
        headers={},  # NOTE(review): mutable default argument — callers always pass it, but prefer None + guard
        client: Optional[AsyncHTTPHandler] = None,
        fake_stream: bool = False,
        json_mode: Optional[bool] = False,
        api_key: Optional[str] = None,
        stream_chunk_size: int = 1024,
    ) -> CustomStreamWrapper:
        """Async streaming path: transform, sign, POST, and wrap the event
        stream in a CustomStreamWrapper."""
        request_data = await litellm.AmazonConverseConfig()._async_transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )
        data = json.dumps(request_data)
        # SigV4-sign the request; region falls back to us-west-2 when unset.
        # NOTE(review): both extra_headers and headers receive the same dict —
        # confirm get_request_headers actually needs both.
        prepped = self.get_request_headers(
            credentials=credentials,
            aws_region_name=litellm_params.get("aws_region_name") or "us-west-2",
            extra_headers=headers,
            endpoint_url=api_base,
            data=data,
            headers=headers,
            api_key=api_key,
        )
        ## LOGGING
        logging_obj.pre_call(
            input=messages,
            api_key="",
            additional_args={
                "complete_input_dict": data,
                "api_base": api_base,
                "headers": dict(prepped.headers),
            },
        )
        completion_stream = await make_call(
            client=client,
            api_base=api_base,
            headers=dict(prepped.headers),
            data=data,
            model=model,
            messages=messages,
            logging_obj=logging_obj,
            fake_stream=fake_stream,
            json_mode=json_mode,
            stream_chunk_size=stream_chunk_size,
        )
        streaming_response = CustomStreamWrapper(
            completion_stream=completion_stream,
            model=model,
            custom_llm_provider="bedrock",
            logging_obj=logging_obj,
        )
        return streaming_response
    async def async_completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        model_response: ModelResponse,
        timeout: Optional[Union[float, httpx.Timeout]],
        encoding,
        logging_obj: LiteLLMLoggingObject,
        stream,
        optional_params: dict,
        litellm_params: dict,
        credentials: Credentials,
        logger_fn=None,
        headers: dict = {},  # NOTE(review): mutable default argument
        client: Optional[AsyncHTTPHandler] = None,
        api_key: Optional[str] = None,
    ) -> Union[ModelResponse, CustomStreamWrapper]:
        """Async non-streaming path: transform, sign, POST, and transform the
        full Converse response into a ModelResponse."""
        request_data = await litellm.AmazonConverseConfig()._async_transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )
        data = json.dumps(request_data)
        prepped = self.get_request_headers(
            credentials=credentials,
            aws_region_name=litellm_params.get("aws_region_name") or "us-west-2",
            extra_headers=headers,
            endpoint_url=api_base,
            data=data,
            headers=headers,
            api_key=api_key,
        )
        ## LOGGING
        logging_obj.pre_call(
            input=messages,
            api_key="",
            additional_args={
                "complete_input_dict": data,
                "api_base": api_base,
                "headers": prepped.headers,
            },
        )
        headers = dict(prepped.headers)
        # Build an async client when none (or a sync one) was provided.
        if client is None or not isinstance(client, AsyncHTTPHandler):
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
                    timeout = httpx.Timeout(timeout)
                _params["timeout"] = timeout
            client = get_async_httpx_client(
                params=_params, llm_provider=litellm.LlmProviders.BEDROCK
            )
        else:
            client = client  # type: ignore
        try:
            response = await client.post(
                url=api_base,
                headers=headers,
                data=data,
                logging_obj=logging_obj,
            )  # type: ignore
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise BedrockError(status_code=error_code, message=err.response.text)
        except httpx.TimeoutException:
            raise BedrockError(status_code=408, message="Timeout error occurred.")
        return litellm.AmazonConverseConfig()._transform_response(
            model=model,
            response=response,
            model_response=model_response,
            stream=stream if isinstance(stream, bool) else False,
            logging_obj=logging_obj,
            api_key="",
            data=data,
            messages=messages,
            optional_params=optional_params,
            encoding=encoding,
        )
    def completion(  # noqa: PLR0915
        self,
        model: str,
        messages: list,
        api_base: Optional[str],
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        encoding,
        logging_obj: LiteLLMLoggingObject,
        optional_params: dict,
        acompletion: bool,
        timeout: Optional[Union[float, httpx.Timeout]],
        litellm_params: dict,
        logger_fn=None,
        extra_headers: Optional[dict] = None,
        client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
        api_key: Optional[str] = None,
    ):
        """
        Entry point for Bedrock Converse completions.

        Resolves the model ID and region (possibly embedded in the model
        path), pops AWS credential params out of optional_params, signs the
        request, then routes to async/sync x streaming/non-streaming paths.
        """
        ## SETUP ##
        # Pop routing/config params so they are not forwarded as model params.
        stream = optional_params.pop("stream", None)
        stream_chunk_size = optional_params.pop("stream_chunk_size", 1024)
        unencoded_model_id = optional_params.pop("model_id", None)
        fake_stream = optional_params.pop("fake_stream", False)
        json_mode = optional_params.get("json_mode", False)
        if unencoded_model_id is not None:
            modelId = self.encode_model_id(model_id=unencoded_model_id)
        else:
            # Strip nova spec prefixes before encoding model ID for API URL
            _model_for_id = model
            _stripped = _model_for_id
            for rp in ["bedrock/converse/", "bedrock/", "converse/"]:
                if _stripped.startswith(rp):
                    _stripped = _stripped[len(rp) :]
                    break
            # Strip embedded region prefix (e.g. "bedrock/us-east-1/model" -> "model")
            # and capture it so it can be used as aws_region_name below.
            _region_from_model: Optional[str] = None
            _potential_region = _stripped.split("/", 1)[0]
            if _potential_region in _get_all_bedrock_regions() and "/" in _stripped:
                _region_from_model = _potential_region
                _stripped = _stripped.split("/", 1)[1]
            _model_for_id = _stripped
            for _nova_prefix in ["nova-2/", "nova/"]:
                if _stripped.startswith(_nova_prefix):
                    _model_for_id = _model_for_id.replace(_nova_prefix, "", 1)
                    break
            modelId = self.encode_model_id(model_id=_model_for_id)
            # Inject region extracted from model path so _get_aws_region_name picks it up
            if (
                _region_from_model is not None
                and "aws_region_name" not in optional_params
            ):
                optional_params["aws_region_name"] = _region_from_model
        fake_stream = litellm.AmazonConverseConfig().should_fake_stream(
            fake_stream=fake_stream,
            model=model,
            stream=stream,
            custom_llm_provider="bedrock",
        )
        ### SET REGION NAME ###
        aws_region_name = self._get_aws_region_name(
            optional_params=optional_params,
            model=model,
            model_id=unencoded_model_id,
        )
        ## CREDENTIALS ##
        # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
        aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
        aws_access_key_id = optional_params.pop("aws_access_key_id", None)
        aws_session_token = optional_params.pop("aws_session_token", None)
        aws_role_name = optional_params.pop("aws_role_name", None)
        aws_session_name = optional_params.pop("aws_session_name", None)
        aws_profile_name = optional_params.pop("aws_profile_name", None)
        aws_bedrock_runtime_endpoint = optional_params.pop(
            "aws_bedrock_runtime_endpoint", None
        )  # https://bedrock-runtime.{region_name}.amazonaws.com
        aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
        aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)
        aws_external_id = optional_params.pop("aws_external_id", None)
        optional_params.pop("aws_region_name", None)
        litellm_params[
            "aws_region_name"
        ] = aws_region_name  # [DO NOT DELETE] important for async calls
        credentials: Credentials = self.get_credentials(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
            aws_region_name=aws_region_name,
            aws_session_name=aws_session_name,
            aws_profile_name=aws_profile_name,
            aws_role_name=aws_role_name,
            aws_web_identity_token=aws_web_identity_token,
            aws_sts_endpoint=aws_sts_endpoint,
            aws_external_id=aws_external_id,
        )
        ### SET RUNTIME ENDPOINT ###
        endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint(
            api_base=api_base,
            aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
            aws_region_name=aws_region_name,
        )
        # Real streaming hits converse-stream; fake streaming uses plain converse.
        if (stream is not None and stream is True) and not fake_stream:
            endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream"
            proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse-stream"
        else:
            endpoint_url = f"{endpoint_url}/model/{modelId}/converse"
            proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse"
        ## COMPLETION CALL
        headers = {"Content-Type": "application/json"}
        if extra_headers is not None:
            headers = {"Content-Type": "application/json", **extra_headers}
        # Filter beta headers in HTTP headers before making the request
        headers = update_headers_with_filtered_beta(
            headers=headers, provider="bedrock_converse"
        )
        ### ROUTING (ASYNC, STREAMING, SYNC)
        if acompletion:
            # A sync client can't be used on the async path; let the async
            # helpers construct their own.
            if isinstance(client, HTTPHandler):
                client = None
            if stream is True:
                return self.async_streaming(
                    model=model,
                    messages=messages,
                    api_base=proxy_endpoint_url,
                    model_response=model_response,
                    encoding=encoding,
                    logging_obj=logging_obj,
                    optional_params=optional_params,
                    stream=True,
                    litellm_params=litellm_params,
                    logger_fn=logger_fn,
                    headers=headers,
                    timeout=timeout,
                    client=client,
                    json_mode=json_mode,
                    fake_stream=fake_stream,
                    credentials=credentials,
                    api_key=api_key,
                    stream_chunk_size=stream_chunk_size,
                )  # type: ignore
            ### ASYNC COMPLETION
            return self.async_completion(
                model=model,
                messages=messages,
                api_base=proxy_endpoint_url,
                model_response=model_response,
                encoding=encoding,
                logging_obj=logging_obj,
                optional_params=optional_params,
                stream=stream,  # type: ignore
                litellm_params=litellm_params,
                logger_fn=logger_fn,
                headers=headers,
                timeout=timeout,
                client=client,
                credentials=credentials,
                api_key=api_key,
            )  # type: ignore
        ## TRANSFORMATION ##
        _data = litellm.AmazonConverseConfig()._transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=extra_headers,
        )
        data = json.dumps(_data)
        prepped = self.get_request_headers(
            credentials=credentials,
            aws_region_name=aws_region_name,
            extra_headers=extra_headers,
            endpoint_url=proxy_endpoint_url,
            data=data,
            headers=headers,
            api_key=api_key,
        )
        ## LOGGING
        logging_obj.pre_call(
            input=messages,
            api_key="",
            additional_args={
                "complete_input_dict": data,
                "api_base": proxy_endpoint_url,
                "headers": prepped.headers,
            },
        )
        # Build a sync client when none (or an async one) was provided.
        if client is None or isinstance(client, AsyncHTTPHandler):
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
                    timeout = httpx.Timeout(timeout)
                _params["timeout"] = timeout
            client = _get_httpx_client(_params)  # type: ignore
        else:
            client = client
        if stream is not None and stream is True:
            completion_stream = make_sync_call(
                client=(
                    client
                    if client is not None and isinstance(client, HTTPHandler)
                    else None
                ),
                api_base=proxy_endpoint_url,
                headers=prepped.headers,  # type: ignore
                data=data,
                model=model,
                messages=messages,
                logging_obj=logging_obj,
                json_mode=json_mode,
                fake_stream=fake_stream,
                stream_chunk_size=stream_chunk_size,
            )
            streaming_response = CustomStreamWrapper(
                completion_stream=completion_stream,
                model=model,
                custom_llm_provider="bedrock",
                logging_obj=logging_obj,
            )
            return streaming_response
        ### COMPLETION
        try:
            response = client.post(
                url=proxy_endpoint_url,
                headers=prepped.headers,
                data=data,
                logging_obj=logging_obj,
            )  # type: ignore
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise BedrockError(status_code=error_code, message=err.response.text)
        except httpx.TimeoutException:
            raise BedrockError(status_code=408, message="Timeout error occurred.")
        return litellm.AmazonConverseConfig()._transform_response(
            model=model,
            response=response,
            model_response=model_response,
            stream=stream if isinstance(stream, bool) else False,
            logging_obj=logging_obj,
            api_key="",
            data=data,
            messages=messages,
            optional_params=optional_params,
            encoding=encoding,
        )

View File

@@ -0,0 +1,5 @@
"""
Uses base_llm_http_handler to call the 'converse like' endpoint.
Relevant issue: https://github.com/BerriAI/litellm/issues/8085
"""

View File

@@ -0,0 +1,3 @@
"""
Uses `converse_transformation.py` to transform the messages to the format required by Bedrock Converse.
"""

View File

@@ -0,0 +1,547 @@
"""
Transformation for Bedrock Invoke Agent
https://docs.aws.amazon.com/bedrock/latest/APIReference/API_agent-runtime_InvokeAgent.html
"""
import base64
import json
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import httpx
from litellm._logging import verbose_logger
from litellm._uuid import uuid
from litellm.litellm_core_utils.prompt_templates.common_utils import (
convert_content_list_to_str,
)
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
from litellm.llms.bedrock.common_utils import BedrockError
from litellm.types.llms.bedrock_invoke_agents import (
InvokeAgentChunkPayload,
InvokeAgentEvent,
InvokeAgentEventHeaders,
InvokeAgentEventList,
InvokeAgentMetadata,
InvokeAgentModelInvocationInput,
InvokeAgentModelInvocationOutput,
InvokeAgentOrchestrationTrace,
InvokeAgentPreProcessingTrace,
InvokeAgentTrace,
InvokeAgentTracePayload,
InvokeAgentUsage,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import Choices, Message, ModelResponse
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class AmazonInvokeAgentConfig(BaseConfig, BaseAWSLLM):
    """
    Config/transformation layer for Bedrock "Invoke Agent".

    Maps an OpenAI-style chat request onto the Bedrock agent-runtime
    ``InvokeAgent`` endpoint and parses the AWS binary event-stream response
    back into a LiteLLM ``ModelResponse``.
    https://docs.aws.amazon.com/bedrock/latest/APIReference/API_agent-runtime_InvokeAgent.html
    """

    def __init__(self, **kwargs):
        # Initialize both bases explicitly (their __init__ signatures are
        # unrelated, so cooperative super() is not used here).
        BaseConfig.__init__(self, **kwargs)
        BaseAWSLLM.__init__(self, **kwargs)
    def get_supported_openai_params(self, model: str) -> List[str]:
        """
        This is a base invoke agent model mapping. For Invoke Agent - define a bedrock provider specific config that extends this class.

        Bedrock Invoke Agents has 0 OpenAI compatible params

        As of May 29th, 2025 - they don't support streaming.
        """
        return []
    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        This is a base invoke agent model mapping. For Invoke Agent - define a bedrock provider specific config that extends this class.
        """
        # No OpenAI params are translatable; pass optional_params through.
        return optional_params
    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Get the complete url for the request
        """
        ### SET RUNTIME ENDPOINT ###
        aws_bedrock_runtime_endpoint = optional_params.get(
            "aws_bedrock_runtime_endpoint", None
        )  # https://bedrock-runtime.{region_name}.amazonaws.com
        endpoint_url, _ = self.get_runtime_endpoint(
            api_base=api_base,
            aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
            aws_region_name=self._get_aws_region_name(
                optional_params=optional_params, model=model
            ),
            endpoint_type="agent",
        )
        # model encodes the agent id + alias; session id groups multi-turn calls.
        agent_id, agent_alias_id = self._get_agent_id_and_alias_id(model)
        session_id = self._get_session_id(optional_params)
        endpoint_url = f"{endpoint_url}/agents/{agent_id}/agentAliases/{agent_alias_id}/sessions/{session_id}/text"
        return endpoint_url
    def sign_request(
        self,
        headers: dict,
        optional_params: dict,
        request_data: dict,
        api_base: str,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        stream: Optional[bool] = None,
        fake_stream: Optional[bool] = None,
    ) -> Tuple[dict, Optional[bytes]]:
        """SigV4-sign the request via BaseAWSLLM using the 'bedrock' service name."""
        return self._sign_request(
            service_name="bedrock",
            headers=headers,
            optional_params=optional_params,
            request_data=request_data,
            api_base=api_base,
            model=model,
            stream=stream,
            fake_stream=fake_stream,
            api_key=api_key,
        )
    def _get_agent_id_and_alias_id(self, model: str) -> tuple[str, str]:
        """
        model = "agent/L1RT58GYRW/MFPSBCXYTW"
        agent_id = "L1RT58GYRW"
        agent_alias_id = "MFPSBCXYTW"
        """
        # Split the model string by '/' and extract components
        parts = model.split("/")
        if len(parts) != 3 or parts[0] != "agent":
            raise ValueError(
                "Invalid model format. Expected format: 'model=agent/AGENT_ID/ALIAS_ID'"
            )
        return parts[1], parts[2]  # Return (agent_id, agent_alias_id)
    def _get_session_id(self, optional_params: dict) -> str:
        """Return the caller-supplied `sessionID`, or generate a fresh UUID4 string."""
        return optional_params.get("sessionID", None) or str(uuid.uuid4())
    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """Build the InvokeAgent JSON body from the last message's text content."""
        # use the last message content as the query
        query: str = convert_content_list_to_str(messages[-1])
        # enableTrace=True so usage/model info can be recovered from trace events.
        return {
            "inputText": query,
            "enableTrace": True,
            **optional_params,
        }
    def _parse_aws_event_stream(self, raw_content: bytes) -> InvokeAgentEventList:
        """
        Parse AWS event stream format using boto3/botocore's built-in parser.
        This is the same approach used in the existing AWSEventStreamDecoder.
        """
        try:
            from botocore.eventstream import EventStreamBuffer
            from botocore.parsers import EventStreamJSONParser
        except ImportError:
            raise ImportError("boto3/botocore is required for AWS event stream parsing")
        events: InvokeAgentEventList = []
        parser = EventStreamJSONParser()
        event_stream_buffer = EventStreamBuffer()
        # Add the entire response to the buffer
        event_stream_buffer.add_data(raw_content)
        # Process all events in the buffer
        for event in event_stream_buffer:
            try:
                headers = self._extract_headers_from_event(event)
                event_type = headers.get("event_type", "")
                if event_type == "chunk":
                    # Handle chunk events specially - they contain decoded content, not JSON
                    message = self._parse_message_from_event(event, parser)
                    parsed_event: InvokeAgentEvent = InvokeAgentEvent()
                    if message:
                        # For chunk events, create a payload with the decoded content
                        parsed_event = {
                            "headers": headers,
                            "payload": {
                                "bytes": base64.b64encode(
                                    message.encode("utf-8")
                                ).decode("utf-8")
                            },  # Re-encode for consistency
                        }
                        events.append(parsed_event)
                elif event_type == "trace":
                    # Handle trace events normally - they contain JSON
                    message = self._parse_message_from_event(event, parser)
                    if message:
                        try:
                            event_data = json.loads(message)
                            parsed_event = {
                                "headers": headers,
                                "payload": event_data,
                            }
                            events.append(parsed_event)
                        except json.JSONDecodeError as e:
                            verbose_logger.warning(
                                f"Failed to parse trace event JSON: {e}"
                            )
                else:
                    verbose_logger.debug(f"Unknown event type: {event_type}")
            except Exception as e:
                # Best-effort: a single malformed event is skipped, not fatal.
                verbose_logger.error(f"Error processing event: {e}")
                continue
        return events
    def _parse_message_from_event(self, event, parser) -> Optional[str]:
        """Extract message content from an AWS event, adapted from AWSEventStreamDecoder."""
        try:
            response_dict = event.to_response_dict()
            verbose_logger.debug(f"Response dict: {response_dict}")
            # Use the same response shape parsing as the existing decoder
            parsed_response = parser.parse(
                response_dict, self._get_response_stream_shape()
            )
            verbose_logger.debug(f"Parsed response: {parsed_response}")
            if response_dict["status_code"] != 200:
                decoded_body = response_dict["body"].decode()
                # NOTE(review): bytes.decode() always yields str, so the dict
                # branch below looks unreachable — confirm intent.
                if isinstance(decoded_body, dict):
                    error_message = decoded_body.get("message")
                elif isinstance(decoded_body, str):
                    error_message = decoded_body
                else:
                    error_message = ""
                # NOTE(review): if the ':exception-type' header is absent this
                # concatenation raises TypeError (None + str) — confirm headers
                # always carry it on error responses.
                exception_status = response_dict["headers"].get(":exception-type")
                error_message = exception_status + " " + error_message
                raise BedrockError(
                    status_code=response_dict["status_code"],
                    message=(
                        json.dumps(error_message)
                        if isinstance(error_message, dict)
                        else error_message
                    ),
                )
            if "chunk" in parsed_response:
                chunk = parsed_response.get("chunk")
                if not chunk:
                    return None
                return chunk.get("bytes").decode()
            else:
                # Fall back to the raw body for non-chunk events (e.g. trace).
                chunk = response_dict.get("body")
                if not chunk:
                    return None
                return chunk.decode()
        except Exception as e:
            # Swallows parse errors (incl. the BedrockError raised above) and
            # returns None so stream processing can continue.
            verbose_logger.debug(f"Error parsing message from event: {e}")
            return None
    def _extract_headers_from_event(self, event) -> InvokeAgentEventHeaders:
        """Extract headers from an AWS event for categorization."""
        try:
            response_dict = event.to_response_dict()
            headers = response_dict.get("headers", {})
            # Extract the event-type and content-type headers that we care about
            return InvokeAgentEventHeaders(
                event_type=headers.get(":event-type", ""),
                content_type=headers.get(":content-type", ""),
                message_type=headers.get(":message-type", ""),
            )
        except Exception as e:
            verbose_logger.debug(f"Error extracting headers: {e}")
            return InvokeAgentEventHeaders(
                event_type="", content_type="", message_type=""
            )
    def _get_response_stream_shape(self):
        """Get the response stream shape for parsing, reusing existing logic."""
        try:
            # Try to reuse the cached shape from the existing decoder
            from litellm.llms.bedrock.chat.invoke_handler import (
                get_response_stream_shape,
            )
            return get_response_stream_shape()
        except ImportError:
            # Fallback: create our own shape
            try:
                from botocore.loaders import Loader
                from botocore.model import ServiceModel
                loader = Loader()
                bedrock_service_dict = loader.load_service_model(
                    "bedrock-runtime", "service-2"
                )
                bedrock_service_model = ServiceModel(bedrock_service_dict)
                return bedrock_service_model.shape_for("ResponseStream")
            except Exception as e:
                verbose_logger.warning(f"Could not load response stream shape: {e}")
                return None
    def _extract_response_content(self, events: InvokeAgentEventList) -> str:
        """Extract the final response content from parsed events."""
        response_parts = []
        for event in events:
            headers = event.get("headers", {})
            payload = event.get("payload")
            event_type = headers.get(
                "event_type"
            )  # Note: using event_type not event-type
            if event_type == "chunk" and payload:
                # Extract base64 encoded content from chunk events
                chunk_payload: InvokeAgentChunkPayload = payload  # type: ignore
                encoded_bytes = chunk_payload.get("bytes", "")
                if encoded_bytes:
                    try:
                        decoded_content = base64.b64decode(encoded_bytes).decode(
                            "utf-8"
                        )
                        response_parts.append(decoded_content)
                    except Exception as e:
                        verbose_logger.warning(f"Failed to decode chunk content: {e}")
        return "".join(response_parts)
    def _extract_usage_info(self, events: InvokeAgentEventList) -> InvokeAgentUsage:
        """Extract token usage information from trace events."""
        usage_info = InvokeAgentUsage(
            inputTokens=0,
            outputTokens=0,
            model=None,
        )
        response_model: Optional[str] = None
        for event in events:
            if not self._is_trace_event(event):
                continue
            trace_data = self._get_trace_data(event)
            if not trace_data:
                continue
            verbose_logger.debug(f"Trace event: {trace_data}")
            # Extract usage from pre-processing trace
            self._extract_and_update_preprocessing_usage(
                trace_data=trace_data,
                usage_info=usage_info,
            )
            # Extract model from orchestration trace
            if response_model is None:
                response_model = self._extract_orchestration_model(trace_data)
        usage_info["model"] = response_model
        return usage_info
    def _is_trace_event(self, event: InvokeAgentEvent) -> bool:
        """Check if the event is a trace event."""
        headers = event.get("headers", {})
        event_type = headers.get("event_type")
        payload = event.get("payload")
        return event_type == "trace" and payload is not None
    def _get_trace_data(self, event: InvokeAgentEvent) -> Optional[InvokeAgentTrace]:
        """Extract trace data from a trace event."""
        payload = event.get("payload")
        if not payload:
            return None
        trace_payload: InvokeAgentTracePayload = payload  # type: ignore
        return trace_payload.get("trace", {})
    def _extract_and_update_preprocessing_usage(
        self, trace_data: InvokeAgentTrace, usage_info: InvokeAgentUsage
    ) -> None:
        """Extract usage information from preprocessing trace (mutates usage_info in place)."""
        pre_processing: Optional[InvokeAgentPreProcessingTrace] = trace_data.get(
            "preProcessingTrace"
        )
        if not pre_processing:
            return
        model_output: Optional[InvokeAgentModelInvocationOutput] = (
            pre_processing.get("modelInvocationOutput")
            or InvokeAgentModelInvocationOutput()
        )
        if not model_output:
            return
        metadata: Optional[InvokeAgentMetadata] = (
            model_output.get("metadata") or InvokeAgentMetadata()
        )
        if not metadata:
            return
        usage: Optional[Union[InvokeAgentUsage, Dict]] = metadata.get("usage", {})
        if not usage:
            return
        # Accumulate: multiple preprocessing traces can appear in one stream.
        usage_info["inputTokens"] += usage.get("inputTokens", 0)
        usage_info["outputTokens"] += usage.get("outputTokens", 0)
    def _extract_orchestration_model(
        self, trace_data: InvokeAgentTrace
    ) -> Optional[str]:
        """Extract model information from orchestration trace."""
        orchestration_trace: Optional[InvokeAgentOrchestrationTrace] = trace_data.get(
            "orchestrationTrace"
        )
        if not orchestration_trace:
            return None
        model_invocation: Optional[InvokeAgentModelInvocationInput] = (
            orchestration_trace.get("modelInvocationInput")
            or InvokeAgentModelInvocationInput()
        )
        if not model_invocation:
            return None
        return model_invocation.get("foundationModel")
    def _build_model_response(
        self,
        content: str,
        model: str,
        usage_info: InvokeAgentUsage,
        model_response: ModelResponse,
    ) -> ModelResponse:
        """Build the final ModelResponse object (mutates and returns model_response)."""
        # Create the message content
        message = Message(content=content, role="assistant")
        # Create choices
        choice = Choices(finish_reason="stop", index=0, message=message)
        # Update model response
        model_response.choices = [choice]
        model_response.model = usage_info.get("model", model)
        # Add usage information if available
        if usage_info:
            from litellm.types.utils import Usage
            usage = Usage(
                prompt_tokens=usage_info.get("inputTokens", 0),
                completion_tokens=usage_info.get("outputTokens", 0),
                total_tokens=usage_info.get("inputTokens", 0)
                + usage_info.get("outputTokens", 0),
            )
            setattr(model_response, "usage", usage)
        return model_response
    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """Decode the InvokeAgent event stream into a ModelResponse.

        Raises BedrockError (with the raw response's status code) on any
        processing failure.
        """
        try:
            # Get the raw binary content
            raw_content = raw_response.content
            verbose_logger.debug(
                f"Processing {len(raw_content)} bytes of AWS event stream data"
            )
            # Parse the AWS event stream format
            events = self._parse_aws_event_stream(raw_content)
            verbose_logger.debug(f"Parsed {len(events)} events from stream")
            # Extract response content from chunk events
            content = self._extract_response_content(events)
            # Extract usage information from trace events
            usage_info = self._extract_usage_info(events)
            # Build and return the model response
            return self._build_model_response(
                content=content,
                model=model,
                usage_info=usage_info,
                model_response=model_response,
            )
        except Exception as e:
            verbose_logger.error(
                f"Error processing Bedrock Invoke Agent response: {str(e)}"
            )
            raise BedrockError(
                message=f"Error processing response: {str(e)}",
                status_code=raw_response.status_code,
            )
    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        # Auth is handled by SigV4 signing; nothing to validate here.
        return headers
    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """Map provider errors onto BedrockError."""
        return BedrockError(status_code=status_code, message=error_message)
    def should_fake_stream(
        self,
        model: Optional[str],
        stream: Optional[bool],
        custom_llm_provider: Optional[str] = None,
    ) -> bool:
        # InvokeAgent does not support streaming (as of May 2025), so LiteLLM
        # always fakes the stream from the buffered response.
        return True

View File

@@ -0,0 +1,99 @@
import types
from typing import List, Optional
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
AmazonInvokeConfig,
)
class AmazonAI21Config(AmazonInvokeConfig, BaseConfig):
    """
    Config for Amazon Bedrock-hosted AI21 (Jurassic) models.

    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra

    Supported params for the Amazon / AI21 models:

    - `maxTokens` (int32): max tokens to generate per result (default 16;
      without `stopSequences`, generation stops after `maxTokens`)
    - `temperature` (float): sampling temperature (default 0.7; 0 is greedy)
    - `topP` (float): nucleus-sampling percentile (default 1)
    - `stopSequences` (array of strings): stop decoding on any of these
    - `frequencePenalty` (object): frequency penalty object
      # NOTE(review): spelled without the second "y"; kept as-is because it is
      # a public keyword argument callers may already use.
    - `presencePenalty` (object): presence penalty object
    - `countPenalty` (object): count penalty object
    """

    # Class-level defaults; __init__ promotes non-None arguments onto the
    # CLASS (not the instance), matching the other Bedrock config classes.
    maxTokens: Optional[int] = None
    temperature: Optional[float] = None
    topP: Optional[float] = None
    stopSequences: Optional[list] = None
    frequencePenalty: Optional[dict] = None
    presencePenalty: Optional[dict] = None
    countPenalty: Optional[dict] = None

    def __init__(
        self,
        maxTokens: Optional[int] = None,
        temperature: Optional[float] = None,
        topP: Optional[float] = None,
        stopSequences: Optional[list] = None,
        frequencePenalty: Optional[dict] = None,
        presencePenalty: Optional[dict] = None,
        countPenalty: Optional[dict] = None,
    ) -> None:
        # Snapshot locals first so the loop itself doesn't perturb them.
        for param_name, param_value in locals().copy().items():
            if param_name == "self" or param_value is None:
                continue
            setattr(self.__class__, param_name, param_value)
        AmazonInvokeConfig.__init__(self)

    @classmethod
    def get_config(cls):
        """Return the non-None, non-callable config attributes set on the class."""
        excluded_value_types = (
            types.FunctionType,
            types.BuiltinFunctionType,
            classmethod,
            staticmethod,
        )
        config: dict = {}
        for attr_name, attr_value in cls.__dict__.items():
            if attr_name.startswith("__") or attr_name.startswith("_abc"):
                continue
            if isinstance(attr_value, excluded_value_types):
                continue
            if attr_value is None:
                continue
            config[attr_name] = attr_value
        return config

    def get_supported_openai_params(self, model: str) -> List:
        """OpenAI params that can be translated for AI21 on Bedrock."""
        return [
            "max_tokens",
            "temperature",
            "top_p",
            "stream",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Translate supported OpenAI params to their AI21 equivalents."""
        openai_to_ai21 = {
            "max_tokens": "maxTokens",
            "temperature": "temperature",
            "top_p": "topP",
            "stream": "stream",
        }
        for openai_key, value in non_default_params.items():
            ai21_key = openai_to_ai21.get(openai_key)
            if ai21_key is not None:
                optional_params[ai21_key] = value
        return optional_params

View File

@@ -0,0 +1,75 @@
import types
from typing import List, Optional
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
AmazonInvokeConfig,
)
from litellm.llms.cohere.chat.transformation import CohereChatConfig
class AmazonCohereConfig(AmazonInvokeConfig, CohereChatConfig):
    """
    Config for Amazon Bedrock-hosted Cohere Command models.

    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=command

    Supported params for the Amazon / Cohere models:

    - `max_tokens` (integer): max tokens
    - `temperature` (float): model temperature
    - `return_likelihood` (string): n/a

    OpenAI-param support and mapping are delegated to ``CohereChatConfig``;
    this class only layers the Bedrock invoke plumbing on top.
    """

    # Class-level defaults; __init__ promotes non-None arguments onto the
    # CLASS (not the instance), matching the other Bedrock config classes.
    max_tokens: Optional[int] = None
    return_likelihood: Optional[str] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        return_likelihood: Optional[str] = None,
    ) -> None:
        # Snapshot locals first so the loop itself doesn't perturb them.
        for param_name, param_value in locals().copy().items():
            if param_name == "self" or param_value is None:
                continue
            setattr(self.__class__, param_name, param_value)
        AmazonInvokeConfig.__init__(self)

    @classmethod
    def get_config(cls):
        """Return the non-None, non-callable config attributes set on the class."""
        excluded_value_types = (
            types.FunctionType,
            types.BuiltinFunctionType,
            classmethod,
            staticmethod,
        )
        config: dict = {}
        for attr_name, attr_value in cls.__dict__.items():
            if attr_name.startswith("__") or attr_name.startswith("_abc"):
                continue
            if isinstance(attr_value, excluded_value_types):
                continue
            if attr_value is None:
                continue
            config[attr_name] = attr_value
        return config

    def get_supported_openai_params(self, model: str) -> List[str]:
        """Delegate to the base Cohere chat config's supported-param list."""
        return CohereChatConfig.get_supported_openai_params(self, model=model)

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Delegate OpenAI-to-Cohere param translation to the base chat config."""
        return CohereChatConfig.map_openai_params(
            self,
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
            drop_params=drop_params,
        )

View File

@@ -0,0 +1,135 @@
from typing import Any, List, Optional, cast
from httpx import Response
from litellm import verbose_logger
from litellm.litellm_core_utils.prompt_templates.common_utils import (
_parse_content_for_reasoning,
)
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
LiteLLMLoggingObj,
)
from litellm.types.llms.bedrock import AmazonDeepSeekR1StreamingResponse
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import (
ChatCompletionUsageBlock,
Choices,
Delta,
Message,
ModelResponse,
ModelResponseStream,
StreamingChoices,
)
from .amazon_llama_transformation import AmazonLlamaConfig
class AmazonDeepSeekR1Config(AmazonLlamaConfig):
    """Bedrock-hosted DeepSeek R1: Llama-style invoke config plus extraction of
    the model's `<think>` reasoning block into `reasoning_content`."""

    def transform_response(
        self,
        model: str,
        raw_response: Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """
        Extract the reasoning content, and return it as a separate field in the response.
        """
        # Parent (Llama-style) transformation builds the base ModelResponse.
        response = super().transform_response(
            model,
            raw_response,
            model_response,
            logging_obj,
            request_data,
            messages,
            optional_params,
            litellm_params,
            encoding,
            api_key,
            json_mode,
        )
        prompt = cast(Optional[str], request_data.get("prompt"))
        message_content = cast(
            Optional[str], cast(Choices, response.choices[0]).message.get("content")
        )
        # When the prompt was sent ending with an opening "<think>" tag, the
        # model's output begins mid-reasoning; re-prepend the tag so the shared
        # reasoning parser can split reasoning from the final answer.
        if prompt and prompt.strip().endswith("<think>") and message_content:
            message_content_with_reasoning_token = "<think>" + message_content
            reasoning, content = _parse_content_for_reasoning(
                message_content_with_reasoning_token
            )
            provider_specific_fields = (
                cast(Choices, response.choices[0]).message.provider_specific_fields
                or {}
            )
            if reasoning:
                provider_specific_fields["reasoning_content"] = reasoning
            # Rebuild the message so `content` holds only the final answer and
            # the reasoning travels in provider_specific_fields.
            message = Message(
                **{
                    **cast(Choices, response.choices[0]).message.model_dump(),
                    "content": content,
                    "provider_specific_fields": provider_specific_fields,
                }
            )
            cast(Choices, response.choices[0]).message = message
        return response
class AmazonDeepseekR1ResponseIterator(BaseModelResponseIterator):
    """Streaming iterator for Bedrock-hosted DeepSeek R1.

    R1 emits its chain-of-thought first, then the literal token "</think>",
    then the final answer. Chunks seen before the sentinel are routed to
    `reasoning_content`; chunks after it become regular `content`.
    """

    def __init__(self, streaming_response: Any, sync_stream: bool) -> None:
        super().__init__(streaming_response=streaming_response, sync_stream=sync_stream)
        # Flips to True once the "</think>" sentinel chunk is received.
        self.has_finished_thinking = False

    def chunk_parser(self, chunk: dict) -> ModelResponseStream:
        """
        Deepseek r1 starts by thinking, then it generates the response.

        Parses one raw Bedrock chunk into a ModelResponseStream, splitting
        reasoning vs. answer content around the "</think>" sentinel.
        """
        # Fix: dropped the original `try: ... except Exception as e: raise e`
        # wrapper — it was a no-op that only added noise to tracebacks.
        typed_chunk = AmazonDeepSeekR1StreamingResponse(**chunk)  # type: ignore
        generated_content = typed_chunk["generation"]
        if generated_content == "</think>" and not self.has_finished_thinking:
            verbose_logger.debug(
                "Deepseek r1: </think> received, setting has_finished_thinking to True"
            )
            # Suppress the sentinel itself; it is a mode switch, not content.
            generated_content = ""
            self.has_finished_thinking = True
        # Token counts may be absent (None) on intermediate chunks.
        prompt_token_count = typed_chunk.get("prompt_token_count") or 0
        generation_token_count = typed_chunk.get("generation_token_count") or 0
        usage = ChatCompletionUsageBlock(
            prompt_tokens=prompt_token_count,
            completion_tokens=generation_token_count,
            total_tokens=prompt_token_count + generation_token_count,
        )
        return ModelResponseStream(
            choices=[
                StreamingChoices(
                    finish_reason=typed_chunk["stop_reason"],
                    delta=Delta(
                        content=(
                            generated_content
                            if self.has_finished_thinking
                            else None
                        ),
                        reasoning_content=(
                            generated_content
                            if not self.has_finished_thinking
                            else None
                        ),
                    ),
                )
            ],
            usage=usage,
        )

View File

@@ -0,0 +1,80 @@
import types
from typing import List, Optional
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
AmazonInvokeConfig,
)
class AmazonLlamaConfig(AmazonInvokeConfig, BaseConfig):
    """
    Config for Amazon Bedrock-hosted Meta Llama models.

    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=meta.llama2-13b-chat-v1

    Supported params for the Amazon / Meta Llama models:

    - `max_gen_len` (integer): max tokens
    - `temperature` (float): temperature for model
    - `top_p` (float): top p for model

    NOTE(review): the ``maxTokenCount`` / ``topP`` __init__ keyword names do
    not match the ``max_gen_len`` / ``top_p`` request keys produced by
    ``map_openai_params``; kept as-is because they are public kwargs.
    """

    # Class-level defaults; __init__ promotes non-None arguments onto the
    # CLASS (not the instance), matching the other Bedrock config classes.
    max_gen_len: Optional[int] = None
    temperature: Optional[float] = None
    topP: Optional[float] = None

    def __init__(
        self,
        maxTokenCount: Optional[int] = None,
        temperature: Optional[float] = None,
        topP: Optional[int] = None,
    ) -> None:
        # Snapshot locals first so the loop itself doesn't perturb them.
        for param_name, param_value in locals().copy().items():
            if param_name == "self" or param_value is None:
                continue
            setattr(self.__class__, param_name, param_value)
        AmazonInvokeConfig.__init__(self)

    @classmethod
    def get_config(cls):
        """Return the non-None, non-callable config attributes set on the class."""
        excluded_value_types = (
            types.FunctionType,
            types.BuiltinFunctionType,
            classmethod,
            staticmethod,
        )
        config: dict = {}
        for attr_name, attr_value in cls.__dict__.items():
            if attr_name.startswith("__") or attr_name.startswith("_abc"):
                continue
            if isinstance(attr_value, excluded_value_types):
                continue
            if attr_value is None:
                continue
            config[attr_name] = attr_value
        return config

    def get_supported_openai_params(self, model: str) -> List:
        """OpenAI params that can be translated for Llama on Bedrock."""
        return [
            "max_tokens",
            "temperature",
            "top_p",
            "stream",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Translate supported OpenAI params to their Llama equivalents."""
        openai_to_llama = {
            "max_tokens": "max_gen_len",
            "temperature": "temperature",
            "top_p": "top_p",
            "stream": "stream",
        }
        for openai_key, value in non_default_params.items():
            llama_key = openai_to_llama.get(openai_key)
            if llama_key is not None:
                optional_params[llama_key] = value
        return optional_params

View File

@@ -0,0 +1,119 @@
import types
from typing import List, Optional, TYPE_CHECKING
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
AmazonInvokeConfig,
)
from litellm.llms.bedrock.common_utils import BedrockError
if TYPE_CHECKING:
from litellm.types.utils import ModelResponse
class AmazonMistralConfig(AmazonInvokeConfig, BaseConfig):
    """
    Config for Amazon Bedrock-hosted Mistral models.

    Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral.html

    Supported params for the Amazon / Mistral models:

    - `max_tokens` (integer): max tokens
    - `temperature` (float): temperature for model
    - `top_p` (float): top p for model
    - `stop` [string]: list of stop sequences that halt further generation
    - `top_k` (float): top k for model
    """

    # Class-level defaults; __init__ promotes non-None arguments onto the
    # CLASS (not the instance), matching the other Bedrock config classes.
    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[float] = None
    stop: Optional[List[str]] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[int] = None,
        top_k: Optional[float] = None,
        stop: Optional[List[str]] = None,
    ) -> None:
        # Snapshot locals first so the loop itself doesn't perturb them.
        for param_name, param_value in locals().copy().items():
            if param_name == "self" or param_value is None:
                continue
            setattr(self.__class__, param_name, param_value)
        AmazonInvokeConfig.__init__(self)

    @classmethod
    def get_config(cls):
        """Return the non-None, non-callable config attributes set on the class."""
        excluded_value_types = (
            types.FunctionType,
            types.BuiltinFunctionType,
            classmethod,
            staticmethod,
        )
        config: dict = {}
        for attr_name, attr_value in cls.__dict__.items():
            if attr_name.startswith("__") or attr_name.startswith("_abc"):
                continue
            if isinstance(attr_value, excluded_value_types):
                continue
            if attr_value is None:
                continue
            config[attr_name] = attr_value
        return config

    def get_supported_openai_params(self, model: str) -> List[str]:
        """OpenAI params that can be translated for Mistral on Bedrock."""
        return ["max_tokens", "temperature", "top_p", "stop", "stream"]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Copy through supported OpenAI params (Mistral uses the same key names)."""
        passthrough_keys = ("max_tokens", "temperature", "top_p", "stop", "stream")
        for key, value in non_default_params.items():
            if key in passthrough_keys:
                optional_params[key] = value
        return optional_params

    @staticmethod
    def get_outputText(
        completion_response: dict, model_response: "ModelResponse"
    ) -> str:
        """This function extracts the output text from a bedrock mistral completion.
        As a side effect, it updates the finish reason for a model response.

        Args:
            completion_response: JSON from the completion.
            model_response: ModelResponse

        Returns:
            A string with the response of the LLM

        Raises:
            BedrockError: if the payload matches neither known response shape.
        """
        if "choices" in completion_response:
            # OpenAI-compatible shape.
            first_choice = completion_response["choices"][0]
            output_text = first_choice["message"]["content"]
            model_response.choices[0].finish_reason = first_choice["finish_reason"]
            return output_text
        if "outputs" in completion_response:
            # Native Bedrock Mistral shape.
            first_output = completion_response["outputs"][0]
            output_text = first_output["text"]
            model_response.choices[0].finish_reason = first_output["stop_reason"]
            return output_text
        raise BedrockError(
            message="Unexpected mistral completion response", status_code=400
        )

View File

@@ -0,0 +1,266 @@
"""
Transformation for Bedrock Moonshot AI (Kimi K2) models.
Supports the Kimi K2 Thinking model available on Amazon Bedrock.
Model format: bedrock/moonshot.kimi-k2-thinking-v1:0
Reference: https://aws.amazon.com/about-aws/whats-new/2025/12/amazon-bedrock-fully-managed-open-weight-models/
"""
from typing import TYPE_CHECKING, Any, List, Optional, Union
import re
import httpx
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
AmazonInvokeConfig,
)
from litellm.llms.bedrock.common_utils import BedrockError
from litellm.llms.moonshot.chat.transformation import MoonshotChatConfig
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import Choices
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
from litellm.types.utils import ModelResponse
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class AmazonMoonshotConfig(AmazonInvokeConfig, MoonshotChatConfig):
"""
Configuration for Bedrock Moonshot AI (Kimi K2) models.
Reference:
https://aws.amazon.com/about-aws/whats-new/2025/12/amazon-bedrock-fully-managed-open-weight-models/
https://platform.moonshot.ai/docs/api/chat
Supported Params for the Amazon / Moonshot models:
- `max_tokens` (integer) max tokens
- `temperature` (float) temperature for model (0-1 for Moonshot)
- `top_p` (float) top p for model
- `stream` (bool) whether to stream responses
- `tools` (list) tool definitions (supported on kimi-k2-thinking)
- `tool_choice` (str|dict) tool choice specification (supported on kimi-k2-thinking)
NOT Supported on Bedrock:
- `stop` sequences (Bedrock doesn't support stopSequences field for this model)
Note: The kimi-k2-thinking model DOES support tool calls, unlike kimi-thinking-preview.
"""
def __init__(self, **kwargs):
AmazonInvokeConfig.__init__(self, **kwargs)
MoonshotChatConfig.__init__(self, **kwargs)
@property
def custom_llm_provider(self) -> Optional[str]:
return "bedrock"
def _get_model_id(self, model: str) -> str:
"""
Extract the actual model ID from the LiteLLM model name.
Removes routing prefixes like:
- bedrock/invoke/moonshot.kimi-k2-thinking -> moonshot.kimi-k2-thinking
- invoke/moonshot.kimi-k2-thinking -> moonshot.kimi-k2-thinking
- moonshot.kimi-k2-thinking -> moonshot.kimi-k2-thinking
"""
# Remove bedrock/ prefix if present
if model.startswith("bedrock/"):
model = model[8:]
# Remove invoke/ prefix if present
if model.startswith("invoke/"):
model = model[7:]
# Remove any provider prefix (e.g., moonshot/)
if "/" in model and not model.startswith("arn:"):
parts = model.split("/", 1)
if len(parts) == 2:
model = parts[1]
return model
def get_supported_openai_params(self, model: str) -> List[str]:
"""
Get the supported OpenAI params for Moonshot AI models on Bedrock.
Bedrock-specific limitations:
- stopSequences field is not supported on Bedrock (unlike native Moonshot API)
- functions parameter is not supported (use tools instead)
- tool_choice doesn't support "required" value
Note: kimi-k2-thinking DOES support tool calls (unlike kimi-thinking-preview)
The parent MoonshotChatConfig class handles the kimi-thinking-preview exclusion.
"""
excluded_params: List[str] = [
"functions",
"stop",
] # Bedrock doesn't support stopSequences
base_openai_params = super(
MoonshotChatConfig, self
).get_supported_openai_params(model=model)
final_params: List[str] = []
for param in base_openai_params:
if param not in excluded_params:
final_params.append(param)
return final_params
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
"""
Map OpenAI parameters to Moonshot AI parameters for Bedrock.
Handles Moonshot AI specific limitations:
- tool_choice doesn't support "required" value
- Temperature <0.3 limitation for n>1
- Temperature range is [0, 1] (not [0, 2] like OpenAI)
"""
return MoonshotChatConfig.map_openai_params(
self,
non_default_params=non_default_params,
optional_params=optional_params,
model=model,
drop_params=drop_params,
)
def transform_request(
    self,
    model: str,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    headers: dict,
) -> dict:
    """
    Build the Bedrock request body for Moonshot AI models.

    Delegates to Moonshot's transformation, which handles:
    - converting content lists to strings (Moonshot doesn't accept list content)
    - injecting the tool_choice="required" workaround message when needed
    - temperature and parameter validation
    """
    # Strip AWS credential kwargs so they never leak into the request body.
    self._get_boto_credentials_from_optional_params(optional_params, model)
    # Drop routing prefixes (bedrock/, invoke/, provider/) from the model name.
    bare_model_id = self._get_model_id(model)
    return MoonshotChatConfig.transform_request(
        self,
        model=bare_model_id,
        messages=messages,
        optional_params=optional_params,
        litellm_params=litellm_params,
        headers=headers,
    )
def _extract_reasoning_from_content(
self, content: str
) -> tuple[Optional[str], str]:
"""
Extract reasoning content from <reasoning> tags in the response.
Moonshot AI's Kimi K2 Thinking model returns reasoning in <reasoning> tags.
This method extracts that content and returns it separately.
Args:
content: The full content string from the API response
Returns:
tuple: (reasoning_content, main_content)
"""
if not content:
return None, content
# Match <reasoning>...</reasoning> tags
reasoning_match = re.match(
r"<reasoning>(.*?)</reasoning>\s*(.*)", content, re.DOTALL
)
if reasoning_match:
reasoning_content = reasoning_match.group(1).strip()
main_content = reasoning_match.group(2).strip()
return reasoning_content, main_content
return None, content
def transform_response(
    self,
    model: str,
    raw_response: httpx.Response,
    model_response: "ModelResponse",
    logging_obj: LiteLLMLoggingObj,
    request_data: dict,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    encoding: Any,
    api_key: Optional[str] = None,
    json_mode: Optional[bool] = None,
) -> "ModelResponse":
    """
    Transform the response from Bedrock Moonshot AI models.

    Moonshot AI uses an OpenAI-compatible response format, but returns
    reasoning content in <reasoning> tags. This method:
    1. Calls the parent class transformation
    2. Extracts reasoning content from <reasoning> tags
    3. Sets reasoning_content on the message object
    """
    # First, run the standard Moonshot (OpenAI-compatible) transformation.
    model_response = MoonshotChatConfig.transform_response(
        self,
        model=model,
        raw_response=raw_response,
        model_response=model_response,
        logging_obj=logging_obj,
        request_data=request_data,
        messages=messages,
        optional_params=optional_params,
        litellm_params=litellm_params,
        encoding=encoding,
        api_key=api_key,
        json_mode=json_mode,
    )
    # Then pull any <reasoning> block out of each choice's content.
    if model_response.choices and len(model_response.choices) > 0:
        for choice in model_response.choices:
            # Only process Choices (not StreamingChoices), which carry a
            # `message` attribute with string content.
            if (
                isinstance(choice, Choices)
                and choice.message
                and choice.message.content
            ):
                (
                    reasoning_content,
                    main_content,
                ) = self._extract_reasoning_from_content(choice.message.content)
                if reasoning_content:
                    # Surface the reasoning separately...
                    choice.message.reasoning_content = reasoning_content
                    # ...and keep only the answer text in `content`.
                    choice.message.content = main_content
    return model_response
def get_error_class(
    self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BedrockError:
    """Wrap an HTTP failure in LiteLLM's Bedrock-specific exception type."""
    return BedrockError(message=error_message, status_code=status_code)

View File

@@ -0,0 +1,120 @@
"""
Handles transforming requests for `bedrock/invoke/{nova} models`
Inherits from `AmazonConverseConfig`
Nova + Invoke API Tutorial: https://docs.aws.amazon.com/nova/latest/userguide/using-invoke-api.html
"""
from typing import Any, List, Optional
import httpx
from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.types.llms.bedrock import BedrockInvokeNovaRequest
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse
from ..converse_transformation import AmazonConverseConfig
from .base_invoke_transformation import AmazonInvokeConfig
class AmazonInvokeNovaConfig(AmazonInvokeConfig, AmazonConverseConfig):
    """
    Config for sending `nova` requests to `/bedrock/invoke/`.

    Builds the payload via the Converse transformation, then trims it into
    the shape the Invoke API accepts.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get_supported_openai_params(self, model: str) -> list:
        # Nova on the invoke endpoint accepts the same OpenAI params as Converse.
        return AmazonConverseConfig.get_supported_openai_params(self, model)

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        return AmazonConverseConfig.map_openai_params(
            self,
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
            drop_params=drop_params,
        )

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """Build a Converse-style payload, then restrict it to the invoke schema."""
        converse_payload = AmazonConverseConfig.transform_request(
            self,
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )
        nova_request = BedrockInvokeNovaRequest(**converse_payload)
        # /bedrock/invoke/ rejects empty `system` blocks.
        self._remove_empty_system_messages(nova_request)
        # Drop any keys the invoke request type does not declare.
        return self._filter_allowed_fields(nova_request)

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: Logging,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        # The invoke response for Nova matches the Converse response shape.
        return AmazonConverseConfig.transform_response(
            self,
            model,
            raw_response,
            model_response,
            logging_obj,
            request_data,
            messages,
            optional_params,
            litellm_params,
            encoding,
            api_key,
            json_mode,
        )

    def _filter_allowed_fields(
        self, bedrock_invoke_nova_request: BedrockInvokeNovaRequest
    ) -> dict:
        """Keep only the keys declared on `BedrockInvokeNovaRequest`."""
        permitted = set(BedrockInvokeNovaRequest.__annotations__.keys())
        return {
            key: value
            for key, value in bedrock_invoke_nova_request.items()
            if key in permitted
        }

    def _remove_empty_system_messages(
        self, bedrock_invoke_nova_request: BedrockInvokeNovaRequest
    ) -> None:
        """
        In-place removal of an empty `system` list.

        /bedrock/invoke/ does not allow empty `system` messages.
        """
        system_block = bedrock_invoke_nova_request.get("system", None)
        if isinstance(system_block, list) and not system_block:
            bedrock_invoke_nova_request.pop("system", None)

View File

@@ -0,0 +1,192 @@
"""
Transformation for Bedrock imported models that use OpenAI Chat Completions format.
Use this for models imported into Bedrock that accept the OpenAI API format.
Model format: bedrock/openai/<model-id>
Example: bedrock/openai/arn:aws:bedrock:us-east-1:123456789012:imported-model/abc123
"""
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
import httpx
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
from litellm.llms.bedrock.common_utils import BedrockError
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
from litellm.passthrough.utils import CommonUtils
from litellm.types.llms.openai import AllMessageValues
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class AmazonBedrockOpenAIConfig(OpenAIGPTConfig, BaseAWSLLM):
    """
    Configuration for Bedrock imported models that speak the OpenAI Chat
    Completions format.

    Inherits from OpenAIGPTConfig for standard OpenAI parameter handling and
    response transformation, and from BaseAWSLLM for Bedrock URL construction
    and AWS SigV4 request signing.

    Usage:
        model = "bedrock/openai/arn:aws:bedrock:us-east-1:123456789012:imported-model/abc123"
    """

    def __init__(self, **kwargs):
        OpenAIGPTConfig.__init__(self, **kwargs)
        BaseAWSLLM.__init__(self, **kwargs)

    @property
    def custom_llm_provider(self) -> Optional[str]:
        return "bedrock"

    def _get_openai_model_id(self, model: str) -> str:
        """
        Strip LiteLLM routing prefixes from the model name.

        Input format: bedrock/openai/<model-id>
        Returns: <model-id>
        """
        for prefix in ("bedrock/", "openai/"):
            if model.startswith(prefix):
                model = model[len(prefix):]
        return model

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Build the Bedrock invoke (or invoke-with-response-stream) URL for the
        model, resolving the AWS region and runtime endpoint first.
        """
        raw_model_id = self._get_openai_model_id(model)
        region = self._get_aws_region_name(
            optional_params=optional_params, model=model
        )
        runtime_override = optional_params.get("aws_bedrock_runtime_endpoint", None)
        endpoint_url, _proxy_endpoint_url = self.get_runtime_endpoint(
            api_base=api_base,
            aws_bedrock_runtime_endpoint=runtime_override,
            aws_region_name=region,
        )
        # ARN model IDs must be URL-encoded (e.g. :imported-model/ -> %2F).
        encoded_model_id = CommonUtils.encode_bedrock_runtime_modelid_arn(
            raw_model_id
        )
        action = "invoke-with-response-stream" if stream else "invoke"
        return f"{endpoint_url}/model/{encoded_model_id}/{action}"

    def sign_request(
        self,
        headers: dict,
        optional_params: dict,
        request_data: dict,
        api_base: str,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        stream: Optional[bool] = None,
        fake_stream: Optional[bool] = None,
    ) -> Tuple[dict, Optional[bytes]]:
        """Sign the outgoing request with AWS Signature Version 4."""
        return self._sign_request(
            service_name="bedrock",
            headers=headers,
            optional_params=optional_params,
            request_data=request_data,
            api_base=api_base,
            api_key=api_key,
            model=model,
            stream=stream,
            fake_stream=fake_stream,
        )

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Produce an OpenAI Chat Completions payload for a Bedrock imported model.

        `stream` is removed from optional_params (it is expressed in the URL),
        and AWS auth params are excluded from the request body.
        """
        optional_params.pop("stream", None)
        body_params = {
            key: value
            for key, value in optional_params.items()
            if key not in self.aws_authentication_params
        }
        return super().transform_request(
            model=self._get_openai_model_id(model),
            messages=messages,
            optional_params=body_params,
            litellm_params=litellm_params,
            headers=headers,
        )

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """No Bearer-token auth needed — AWS SigV4 signing covers authentication."""
        return headers

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BedrockError:
        """Wrap an HTTP failure in LiteLLM's Bedrock-specific exception type."""
        return BedrockError(status_code=status_code, message=error_message)

View File

@@ -0,0 +1,99 @@
"""
Handles transforming requests for `bedrock/invoke/{qwen2} models`
Inherits from `AmazonQwen3Config` since Qwen2 and Qwen3 architectures are mostly similar.
The main difference is in the response format: Qwen2 uses "text" field while Qwen3 uses "generation" field.
Qwen2 + Invoke API Tutorial: https://docs.aws.amazon.com/bedrock/latest/userguide/invoke-imported-model.html
"""
from typing import Any, List, Optional
import httpx
from litellm.llms.bedrock.chat.invoke_transformations.amazon_qwen3_transformation import (
AmazonQwen3Config,
)
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
LiteLLMLoggingObj,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse, Usage
class AmazonQwen2Config(AmazonQwen3Config):
    """
    Config for sending `qwen2` requests to `/bedrock/invoke/`.

    Inherits from AmazonQwen3Config since the Qwen2 and Qwen3 architectures
    are mostly similar; the main difference is the response format — Qwen2
    reports the completion under "text" while Qwen3 uses "generation".

    Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/invoke-imported-model.html
    """

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """
        Convert a Qwen2 Bedrock invoke response into OpenAI format.

        Reads "generation" (Qwen3 compatibility) with a fallback to "text"
        (Qwen2's own field), strips chat-template tokens, and fills in usage
        when the response reports it.
        """
        try:
            payload = (
                raw_response.json() if hasattr(raw_response, "json") else raw_response
            )
            # "generation" takes priority for parity with the Qwen3 parent;
            # Qwen2 itself puts the completion under "text".
            text = payload.get("generation", "") or payload.get("text", "")
            # Strip the assistant start/end chat-template tokens if present.
            start_token = "<|im_start|>assistant\n"
            if text.startswith(start_token):
                text = text[len(start_token):]
            end_token = "<|im_end|>"
            if text.endswith(end_token):
                text = text[: -len(end_token)]
            if hasattr(model_response, "choices") and len(model_response.choices) > 0:
                first_choice = model_response.choices[0]
                first_choice.message.content = text
                first_choice.finish_reason = "stop"
            if "usage" in payload:
                reported_usage = payload["usage"]
                setattr(
                    model_response,
                    "usage",
                    Usage(
                        prompt_tokens=reported_usage.get("prompt_tokens", 0),
                        completion_tokens=reported_usage.get("completion_tokens", 0),
                        total_tokens=reported_usage.get("total_tokens", 0),
                    ),
                )
            return model_response
        except Exception as e:
            if logging_obj:
                logging_obj.post_call(
                    input=messages,
                    api_key=api_key,
                    original_response=raw_response,
                    additional_args={"error": str(e)},
                )
            raise e

View File

@@ -0,0 +1,225 @@
"""
Handles transforming requests for `bedrock/invoke/{qwen3} models`
Inherits from `AmazonInvokeConfig`
Qwen3 + Invoke API Tutorial: https://docs.aws.amazon.com/bedrock/latest/userguide/invoke-imported-model.html
"""
from typing import Any, List, Optional
import httpx
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
AmazonInvokeConfig,
LiteLLMLoggingObj,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse, Usage
class AmazonQwen3Config(AmazonInvokeConfig, BaseConfig):
    """
    Config for sending `qwen3` requests to `/bedrock/invoke/`

    Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/invoke-imported-model.html
    """

    # Class-level defaults; any values passed to __init__ are stored here.
    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    stop: Optional[List[str]] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        stop: Optional[List[str]] = None,
    ) -> None:
        # Config-class convention: explicitly supplied values are stored on
        # the *class*, not the instance.
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)
        AmazonInvokeConfig.__init__(self)

    def get_supported_openai_params(self, model: str) -> List[str]:
        # OpenAI params that can be mapped onto the Qwen3 invoke schema.
        return [
            "max_tokens",
            "temperature",
            "top_p",
            "top_k",
            "stop",
            "stream",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        # Supported params keep their OpenAI names here; renaming happens in
        # transform_request.
        for k, v in non_default_params.items():
            if k == "max_tokens":
                optional_params["max_tokens"] = v
            if k == "temperature":
                optional_params["temperature"] = v
            if k == "top_p":
                optional_params["top_p"] = v
            if k == "top_k":
                optional_params["top_k"] = v
            if k == "stop":
                optional_params["stop"] = v
            if k == "stream":
                optional_params["stream"] = v
        return optional_params

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Transform OpenAI format to Qwen3 Bedrock invoke format
        """
        # Convert messages to a single chat-template prompt string
        prompt = self._convert_messages_to_prompt(messages)

        # Build the request body
        request_body = {
            "prompt": prompt,
        }

        # Add optional parameters
        # NOTE(review): max_tokens is forwarded as "max_gen_len" — confirm
        # this matches the imported model's expected schema.
        if "max_tokens" in optional_params:
            request_body["max_gen_len"] = optional_params["max_tokens"]
        if "temperature" in optional_params:
            request_body["temperature"] = optional_params["temperature"]
        if "top_p" in optional_params:
            request_body["top_p"] = optional_params["top_p"]
        if "top_k" in optional_params:
            request_body["top_k"] = optional_params["top_k"]
        if "stop" in optional_params:
            request_body["stop"] = optional_params["stop"]

        return request_body

    def _convert_messages_to_prompt(self, messages: List[AllMessageValues]) -> str:
        """
        Convert OpenAI messages format to Qwen3 prompt format

        Supports tool calls, multimodal content, and various message types.
        Uses the ChatML-style <|im_start|>/<|im_end|> template tokens.
        """
        prompt_parts = []

        for message in messages:
            role = message.get("role", "")
            content = message.get("content", "")
            tool_calls = message.get("tool_calls", [])

            if role == "system":
                prompt_parts.append(f"<|im_start|>system\n{content}<|im_end|>")
            elif role == "user":
                # Handle multimodal content: flatten text parts; images become
                # vision placeholder tokens.
                if isinstance(content, list):
                    text_content = []
                    for item in content:
                        if item.get("type") == "text":
                            text_content.append(item.get("text", ""))
                        elif item.get("type") == "image_url":
                            # For Qwen3, we can include image placeholders
                            text_content.append(
                                "<|vision_start|><|image_pad|><|vision_end|>"
                            )
                    content = "".join(text_content)
                prompt_parts.append(f"<|im_start|>user\n{content}<|im_end|>")
            elif role == "assistant":
                if tool_calls and isinstance(tool_calls, list):
                    # Handle tool calls: each becomes its own assistant turn
                    # with a <tool_call> payload.
                    for tool_call in tool_calls:
                        function_name = tool_call.get("function", {}).get("name", "")
                        function_args = tool_call.get("function", {}).get(
                            "arguments", ""
                        )
                        prompt_parts.append(
                            f'<|im_start|>assistant\n<tool_call>\n{{"name": "{function_name}", "arguments": "{function_args}"}}\n</tool_call><|im_end|>'
                        )
                else:
                    prompt_parts.append(f"<|im_start|>assistant\n{content}<|im_end|>")
            elif role == "tool":
                # Handle tool responses
                prompt_parts.append(f"<|im_start|>tool\n{content}<|im_end|>")

        # Add the assistant start token so the model generates the reply
        prompt_parts.append("<|im_start|>assistant\n")

        return "\n".join(prompt_parts)

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """
        Transform Qwen3 Bedrock response to OpenAI format
        """
        try:
            # raw_response may already be a parsed dict in some call paths.
            if hasattr(raw_response, "json"):
                response_data = raw_response.json()
            else:
                response_data = raw_response

            # Extract the generated text - Qwen3 uses the "generation" field
            generated_text = response_data.get("generation", "")

            # Clean up the response (remove chat-template tokens if present)
            if generated_text.startswith("<|im_start|>assistant\n"):
                generated_text = generated_text[len("<|im_start|>assistant\n") :]
            if generated_text.endswith("<|im_end|>"):
                generated_text = generated_text[: -len("<|im_end|>")]

            # Set the content in the existing model_response structure
            if hasattr(model_response, "choices") and len(model_response.choices) > 0:
                choice = model_response.choices[0]
                choice.message.content = generated_text
                choice.finish_reason = "stop"

            # Set usage information if available in the response
            if "usage" in response_data:
                usage_data = response_data["usage"]
                setattr(
                    model_response,
                    "usage",
                    Usage(
                        prompt_tokens=usage_data.get("prompt_tokens", 0),
                        completion_tokens=usage_data.get("completion_tokens", 0),
                        total_tokens=usage_data.get("total_tokens", 0),
                    ),
                )

            return model_response
        except Exception as e:
            # Record the failure on the logging object before re-raising.
            if logging_obj:
                logging_obj.post_call(
                    input=messages,
                    api_key=api_key,
                    original_response=raw_response,
                    additional_args={"error": str(e)},
                )
            raise e

View File

@@ -0,0 +1,116 @@
import re
import types
from typing import List, Optional, Union
import litellm
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
AmazonInvokeConfig,
)
class AmazonTitanConfig(AmazonInvokeConfig, BaseConfig):
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-text-express-v1

    Supported Params for the Amazon Titan models:

    - `maxTokenCount` (integer) max tokens,
    - `stopSequences` (string[]) list of stop sequence strings
    - `temperature` (float) temperature for model,
    - `topP` (float) top p for model
    """

    # Class-level defaults; any values passed to __init__ are stored here.
    maxTokenCount: Optional[int] = None
    stopSequences: Optional[list] = None
    temperature: Optional[float] = None
    topP: Optional[float] = None

    def __init__(
        self,
        maxTokenCount: Optional[int] = None,
        stopSequences: Optional[list] = None,
        temperature: Optional[float] = None,
        topP: Optional[float] = None,
    ) -> None:
        # Config-class convention: explicitly supplied values are stored on
        # the *class*, not the instance.
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)
        AmazonInvokeConfig.__init__(self)

    @classmethod
    def get_config(cls):
        """Return the non-default, non-callable attributes set on this config."""
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not k.startswith("_abc")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def _map_and_modify_arg(
        self,
        supported_params: dict,
        provider: str,
        model: str,
        stop: Union[List[str], str],
    ) -> dict:
        """
        filter params to fit the required provider format, drop those that don't fit if user sets `litellm.drop_params = True`.
        """
        # NOTE(review): this consults the global `litellm.drop_params`, not the
        # `drop_params` argument passed to map_openai_params — confirm intended.
        filtered_stop = None
        if "stop" in supported_params and litellm.drop_params:
            if provider == "bedrock" and "amazon" in model:
                # Only stop sequences made of pipes or "User:" are kept;
                # everything else is dropped.  TODO confirm against Titan docs.
                filtered_stop = []
                if isinstance(stop, list):
                    for s in stop:
                        if re.match(r"^(\|+|User:)$", s):
                            filtered_stop.append(s)
        if filtered_stop is not None:
            supported_params["stop"] = filtered_stop
        return supported_params

    def get_supported_openai_params(self, model: str) -> List[str]:
        # OpenAI params that can be mapped onto the Titan invoke schema.
        return [
            "max_tokens",
            "max_completion_tokens",
            "stop",
            "temperature",
            "top_p",
            "stream",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        # Translate OpenAI param names to Titan's camelCase equivalents.
        for k, v in non_default_params.items():
            if k == "max_tokens" or k == "max_completion_tokens":
                optional_params["maxTokenCount"] = v
            if k == "temperature":
                optional_params["temperature"] = v
            if k == "stop":
                # Titan only accepts a restricted set of stop sequences; see
                # _map_and_modify_arg.
                filtered_stop = self._map_and_modify_arg(
                    {"stop": v}, provider="bedrock", model=model, stop=v
                )
                optional_params["stopSequences"] = filtered_stop["stop"]
            if k == "top_p":
                optional_params["topP"] = v
            if k == "stream":
                optional_params["stream"] = v
        return optional_params

View File

@@ -0,0 +1,280 @@
"""
Transforms OpenAI-style requests into TwelveLabs Pegasus 1.2 requests for Bedrock.
Reference:
https://docs.twelvelabs.io/docs/models/pegasus
"""
import json
import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional
import httpx
import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.base_llm.base_utils import type_to_response_format_param
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
AmazonInvokeConfig,
)
from litellm.llms.bedrock.common_utils import BedrockError
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse, Usage
from litellm.utils import get_base64_str
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class AmazonTwelveLabsPegasusConfig(AmazonInvokeConfig, BaseConfig):
    """
    Handles transforming OpenAI-style requests into Bedrock InvokeModel requests for
    `twelvelabs.pegasus-1-2-v1:0`.

    Pegasus 1.2 requires an `inputPrompt` and a `mediaSource` that either references
    an S3 object or a base64-encoded clip. Optional OpenAI params (temperature,
    response_format, max_tokens) are translated to the TwelveLabs schema.
    """

    def get_supported_openai_params(self, model: str) -> List[str]:
        # OpenAI params that can be mapped onto the Pegasus invoke schema.
        return [
            "max_tokens",
            "max_completion_tokens",
            "temperature",
            "response_format",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Translate supported OpenAI params to their TwelveLabs names."""
        for param, value in non_default_params.items():
            if param in {"max_tokens", "max_completion_tokens"}:
                optional_params["maxOutputTokens"] = value
            if param == "temperature":
                optional_params["temperature"] = value
            if param == "response_format":
                optional_params["responseFormat"] = self._normalize_response_format(
                    value
                )
        return optional_params

    def _normalize_response_format(self, value: Any) -> Any:
        """Normalize response_format to TwelveLabs format.

        TwelveLabs expects:
            {
                "jsonSchema": {...}
            }

        But OpenAI format is:
            {
                "type": "json_schema",
                "json_schema": {
                    "name": "...",
                    "schema": {...}
                }
            }
        """
        if isinstance(value, dict):
            # If it has json_schema field, extract and transform it
            if "json_schema" in value:
                json_schema = value["json_schema"]
                # Extract the schema if nested
                if isinstance(json_schema, dict) and "schema" in json_schema:
                    return {"jsonSchema": json_schema["schema"]}
                # Otherwise use json_schema directly
                return {"jsonSchema": json_schema}
            # If it already has jsonSchema, return as is
            if "jsonSchema" in value:
                return value
            # Otherwise return the dict as is
            return value
        # Non-dict values go through the shared converter; fall back to the
        # original value when the converter yields None.
        return type_to_response_format_param(response_format=value) or value

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """Build the Pegasus InvokeModel body: inputPrompt plus optional
        mediaSource, temperature/maxOutputTokens, and responseFormat."""
        input_prompt = self._convert_messages_to_prompt(messages=messages)
        request_data: Dict[str, Any] = {"inputPrompt": input_prompt}

        media_source = self._build_media_source(optional_params)
        if media_source is not None:
            request_data["mediaSource"] = media_source

        # Handle temperature and maxOutputTokens
        for key in ("temperature", "maxOutputTokens"):
            if key in optional_params:
                request_data[key] = optional_params.get(key)

        # Handle responseFormat - transform to TwelveLabs format
        if "responseFormat" in optional_params:
            response_format = optional_params["responseFormat"]
            transformed_format = self._normalize_response_format(response_format)
            if transformed_format:
                request_data["responseFormat"] = transformed_format

        return request_data

    def _build_media_source(self, optional_params: dict) -> Optional[dict]:
        """
        Resolve the video source for the request.

        Precedence: an explicit mediaSource dict, then base64 input, then an
        S3 URI (with optional bucket owner). Returns None when no source is
        provided.
        """
        direct_source = optional_params.get("mediaSource") or optional_params.get(
            "media_source"
        )
        if isinstance(direct_source, dict):
            return direct_source

        base64_input = optional_params.get("video_base64") or optional_params.get(
            "base64_string"
        )
        if base64_input:
            # get_base64_str presumably strips any data-URL prefix — see
            # litellm.utils.get_base64_str.
            return {"base64String": get_base64_str(base64_input)}

        s3_uri = (
            optional_params.get("video_s3_uri")
            or optional_params.get("s3_uri")
            or optional_params.get("media_source_s3_uri")
        )
        if s3_uri:
            s3_location = {"uri": s3_uri}
            bucket_owner = (
                optional_params.get("video_s3_bucket_owner")
                or optional_params.get("s3_bucket_owner")
                or optional_params.get("media_source_bucket_owner")
            )
            if bucket_owner:
                s3_location["bucketOwner"] = bucket_owner
            return {"s3Location": s3_location}

        return None

    def _convert_messages_to_prompt(self, messages: List[AllMessageValues]) -> str:
        """Flatten chat messages into a single "role: content" prompt string,
        replacing non-text parts with <image>/<video>/<audio> placeholders."""
        prompt_parts: List[str] = []
        for message in messages:
            role = message.get("role", "user")
            content = message.get("content", "")
            if isinstance(content, list):
                text_fragments = []
                for item in content:
                    if isinstance(item, dict):
                        item_type = item.get("type")
                        if item_type == "text":
                            text_fragments.append(item.get("text", ""))
                        elif item_type == "image_url":
                            # The actual media travels via mediaSource; only a
                            # placeholder token goes into the prompt.
                            text_fragments.append("<image>")
                        elif item_type == "video_url":
                            text_fragments.append("<video>")
                        elif item_type == "audio_url":
                            text_fragments.append("<audio>")
                    elif isinstance(item, str):
                        text_fragments.append(item)
                content = " ".join(text_fragments)
            prompt_parts.append(f"{role}: {content}")
        return "\n".join(part for part in prompt_parts if part).strip()

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """
        Transform TwelveLabs Pegasus response to LiteLLM format.

        TwelveLabs response format:
            {
                "message": "...",
                "finishReason": "stop" | "length"
            }

        LiteLLM format:
            ModelResponse with choices[0].message.content and finish_reason
        """
        try:
            completion_response = raw_response.json()
        except Exception as e:
            raise BedrockError(
                message=f"Error parsing response: {raw_response.text}, error: {str(e)}",
                status_code=raw_response.status_code,
            )

        verbose_logger.debug(
            "twelvelabs pegasus response: %s",
            json.dumps(completion_response, indent=4, default=str),
        )

        # Extract message content
        message_content = completion_response.get("message", "")

        # Extract finish reason and map to LiteLLM format
        finish_reason_raw = completion_response.get("finishReason", "stop")
        finish_reason = map_finish_reason(finish_reason_raw)

        # Set the response content
        try:
            if (
                message_content
                and hasattr(model_response.choices[0], "message")
                and getattr(model_response.choices[0].message, "tool_calls", None)
                is None
            ):
                model_response.choices[0].message.content = message_content  # type: ignore
                model_response.choices[0].finish_reason = finish_reason
            else:
                raise Exception("Unable to set message content")
        except Exception as e:
            raise BedrockError(
                message=f"Error setting response content: {str(e)}. Response: {completion_response}",
                status_code=raw_response.status_code,
            )

        # Token counts come from Bedrock response headers when present; fall
        # back to local token counting otherwise.
        bedrock_input_tokens = raw_response.headers.get(
            "x-amzn-bedrock-input-token-count", None
        )
        bedrock_output_tokens = raw_response.headers.get(
            "x-amzn-bedrock-output-token-count", None
        )

        prompt_tokens = int(
            bedrock_input_tokens or litellm.token_counter(messages=messages)
        )
        completion_tokens = int(
            bedrock_output_tokens
            or litellm.token_counter(
                text=model_response.choices[0].message.content,  # type: ignore
                count_response_tokens=True,
            )
        )

        model_response.created = int(time.time())
        model_response.model = model
        usage = Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
        setattr(model_response, "usage", usage)
        return model_response

View File

@@ -0,0 +1,98 @@
import types
from typing import Optional
import litellm
from .base_invoke_transformation import AmazonInvokeConfig
class AmazonAnthropicConfig(AmazonInvokeConfig):
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude

    Supported Params for the legacy Amazon / Anthropic text-completion models:

    - `max_tokens_to_sample` (integer) max tokens,
    - `temperature` (float) model temperature,
    - `top_k` (integer) top k,
    - `top_p` (integer) top p,
    - `stop_sequences` (string[]) list of stop sequences - e.g. ["\\n\\nHuman:"],
    - `anthropic_version` (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
    """

    # Class-level defaults; any values passed to __init__ are stored here.
    max_tokens_to_sample: Optional[int] = litellm.max_tokens
    stop_sequences: Optional[list] = None
    temperature: Optional[float] = None
    top_k: Optional[int] = None
    top_p: Optional[int] = None
    anthropic_version: Optional[str] = None

    def __init__(
        self,
        max_tokens_to_sample: Optional[int] = None,
        stop_sequences: Optional[list] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[int] = None,
        anthropic_version: Optional[str] = None,
    ) -> None:
        # Config-class convention: explicitly supplied values are stored on
        # the *class*, not the instance.
        for name, value in locals().copy().items():
            if name != "self" and value is not None:
                setattr(self.__class__, name, value)

    @classmethod
    def get_config(cls):
        """Return the non-default, non-callable attributes set on this config."""
        return {
            attr_name: attr_value
            for attr_name, attr_value in cls.__dict__.items()
            if not attr_name.startswith("__")
            and not isinstance(
                attr_value,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and attr_value is not None
        }

    @staticmethod
    def get_legacy_anthropic_model_names():
        """Model IDs that still use the legacy text-completion request shape."""
        return [
            "anthropic.claude-v2",
            "anthropic.claude-instant-v1",
            "anthropic.claude-v2:1",
        ]

    def get_supported_openai_params(self, model: str):
        return [
            "max_tokens",
            "max_completion_tokens",
            "temperature",
            "stop",
            "top_p",
            "stream",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ):
        """Translate OpenAI params to the legacy Anthropic completion params."""
        for name, value in non_default_params.items():
            if name in ("max_tokens", "max_completion_tokens"):
                optional_params["max_tokens_to_sample"] = value
            elif name == "temperature":
                optional_params["temperature"] = value
            elif name == "top_p":
                optional_params["top_p"] = value
            elif name == "stop":
                optional_params["stop_sequences"] = value
            elif name == "stream" and value is True:
                optional_params["stream"] = value
        return optional_params

View File

@@ -0,0 +1,206 @@
from typing import TYPE_CHECKING, Any, List, Optional
import httpx
from litellm.llms.anthropic.chat.transformation import AnthropicConfig
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
AmazonInvokeConfig,
)
from litellm.llms.bedrock.common_utils import (
get_anthropic_beta_from_headers,
remove_custom_field_from_tools,
)
from litellm.types.llms.anthropic import ANTHROPIC_TOOL_SEARCH_BETA_HEADER
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class AmazonAnthropicClaudeConfig(AmazonInvokeConfig, AnthropicConfig):
    """
    Reference:
    https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude
    https://docs.anthropic.com/claude/docs/models-overview#model-comparison
    https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages-request-response.html
    Supported Params for the Amazon / Anthropic Claude models (Claude 3, Claude 4, etc.):
    Supports anthropic_beta parameter for beta features like:
    - computer-use-2025-01-24 (Claude 3.7 Sonnet)
    - computer-use-2024-10-22 (Claude 3.5 Sonnet v2)
    - token-efficient-tools-2025-02-19 (Claude 3.7 Sonnet)
    - interleaved-thinking-2025-05-14 (Claude 4 models)
    - output-128k-2025-02-19 (Claude 3.7 Sonnet)
    - dev-full-thinking-2025-05-14 (Claude 4 models)
    - context-1m-2025-08-07 (Claude Sonnet 4)
    """

    # Default anthropic_version injected into the request body when absent.
    anthropic_version: str = "bedrock-2023-05-31"

    @property
    def custom_llm_provider(self) -> Optional[str]:
        """Provider name reported to the shared Anthropic transformation logic."""
        return "bedrock"

    def get_supported_openai_params(self, model: str) -> List[str]:
        # Delegate to the shared Anthropic (Messages API) parameter list.
        return AnthropicConfig.get_supported_openai_params(self, model)

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Map OpenAI params via AnthropicConfig, forcing tool-based structured outputs."""
        # Force tool-based structured outputs for Bedrock Invoke
        # (similar to VertexAI fix in #19201)
        # Bedrock Invoke doesn't support output_format parameter
        original_model = model
        if "response_format" in non_default_params:
            # Use a model name that forces tool-based approach
            model = "claude-3-sonnet-20240229"
        optional_params = AnthropicConfig.map_openai_params(
            self,
            non_default_params,
            optional_params,
            model,
            drop_params,
        )
        # Restore original model name
        # NOTE(review): this rebinds a local only; it has no effect on the caller.
        model = original_model
        return optional_params

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Build the Anthropic Messages request body for Bedrock Invoke.

        Strips AWS auth params, removes body fields Bedrock Invoke rejects
        (model/stream/output_format/output_config, tool `custom` fields),
        injects `anthropic_version`, and computes the `anthropic_beta` list.
        """
        # Filter out AWS authentication parameters before passing to Anthropic transformation
        # AWS params should only be used for signing requests, not included in request body
        filtered_params = {
            k: v
            for k, v in optional_params.items()
            if k not in self.aws_authentication_params
        }
        filtered_params = self._normalize_bedrock_tool_search_tools(filtered_params)
        _anthropic_request = AnthropicConfig.transform_request(
            self,
            model=model,
            messages=messages,
            optional_params=filtered_params,
            litellm_params=litellm_params,
            headers=headers,
        )
        _anthropic_request.pop("model", None)
        _anthropic_request.pop("stream", None)
        # Bedrock Invoke doesn't support output_format parameter
        _anthropic_request.pop("output_format", None)
        # Bedrock Invoke doesn't support output_config parameter
        # Fixes: https://github.com/BerriAI/litellm/issues/22797
        _anthropic_request.pop("output_config", None)
        if "anthropic_version" not in _anthropic_request:
            _anthropic_request["anthropic_version"] = self.anthropic_version
        # Remove `custom` field from tools (Bedrock doesn't support it)
        # Claude Code sends `custom: {defer_loading: true}` on tool definitions,
        # which causes Bedrock to reject the request with "Extra inputs are not permitted"
        # Ref: https://github.com/BerriAI/litellm/issues/22847
        remove_custom_field_from_tools(_anthropic_request)
        tools = optional_params.get("tools")
        tool_search_used = self.is_tool_search_used(tools)
        programmatic_tool_calling_used = self.is_programmatic_tool_calling_used(tools)
        input_examples_used = self.is_input_examples_used(tools)
        # Betas requested explicitly via headers, plus auto-detected ones.
        beta_set = set(get_anthropic_beta_from_headers(headers))
        auto_betas = self.get_anthropic_beta_list(
            model=model,
            optional_params=optional_params,
            computer_tool_used=self.is_computer_tool_used(tools),
            prompt_caching_set=False,
            file_id_used=self.is_file_id_used(messages),
            mcp_server_used=self.is_mcp_server_used(optional_params.get("mcp_servers")),
        )
        beta_set.update(auto_betas)
        # Tool-search beta is only kept when programmatic tool calling or input
        # examples are also in use; otherwise it is dropped.
        if tool_search_used and not (
            programmatic_tool_calling_used or input_examples_used
        ):
            beta_set.discard(ANTHROPIC_TOOL_SEARCH_BETA_HEADER)
        if "opus-4" in model.lower() or "opus_4" in model.lower():
            beta_set.add("tool-search-tool-2025-10-19")
        # NOTE(review): an earlier comment here claimed unsupported beta headers
        # are filtered via anthropic_beta_headers_config.json, but no filtering
        # happens — the accumulated set is sent as-is. Confirm intended behavior.
        beta_list = list(beta_set)
        _anthropic_request["anthropic_beta"] = beta_list
        return _anthropic_request

    def _normalize_bedrock_tool_search_tools(self, optional_params: dict) -> dict:
        """
        Convert tool search entries to the format supported by the Bedrock Invoke API.
        """
        tools = optional_params.get("tools")
        if not tools or not isinstance(tools, list):
            return optional_params
        normalized_tools = []
        for tool in tools:
            tool_type = tool.get("type")
            if tool_type == "tool_search_tool_bm25_20251119":
                # Bedrock Invoke does not support the BM25 variant, so skip it.
                continue
            if tool_type == "tool_search_tool_regex_20251119":
                # Bedrock expects the un-dated regex variant name.
                normalized_tool = tool.copy()
                normalized_tool["type"] = "tool_search_tool_regex"
                normalized_tool["name"] = normalized_tool.get(
                    "name", "tool_search_tool_regex"
                )
                normalized_tools.append(normalized_tool)
                continue
            normalized_tools.append(tool)
        optional_params["tools"] = normalized_tools
        return optional_params

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """Delegate response parsing to the shared Anthropic transformation."""
        return AnthropicConfig.transform_response(
            self,
            model=model,
            raw_response=raw_response,
            model_response=model_response,
            logging_obj=logging_obj,
            request_data=request_data,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            encoding=encoding,
            api_key=api_key,
            json_mode=json_mode,
        )

View File

@@ -0,0 +1,613 @@
import copy
import json
import time
from functools import partial
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union, cast, get_args
import httpx
import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
from litellm.litellm_core_utils.prompt_templates.factory import (
cohere_message_pt,
custom_prompt,
deepseek_r1_pt,
prompt_factory,
)
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.llms.bedrock.chat.invoke_handler import make_call, make_sync_call
from litellm.llms.bedrock.common_utils import BedrockError
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_httpx_client,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse, Usage
from litellm.utils import CustomStreamWrapper
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
    def __init__(self, **kwargs):
        # Initialize both parents explicitly; cooperative super() is not used here.
        BaseConfig.__init__(self, **kwargs)
        BaseAWSLLM.__init__(self, **kwargs)
def get_supported_openai_params(self, model: str) -> List[str]:
"""
This is a base invoke model mapping. For Invoke - define a bedrock provider specific config that extends this class.
"""
return [
"max_tokens",
"max_completion_tokens",
"stream",
]
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
"""
This is a base invoke model mapping. For Invoke - define a bedrock provider specific config that extends this class.
"""
for param, value in non_default_params.items():
if param == "max_tokens" or param == "max_completion_tokens":
optional_params["max_tokens"] = value
if param == "stream":
optional_params["stream"] = value
return optional_params
def get_complete_url(
self,
api_base: Optional[str],
api_key: Optional[str],
model: str,
optional_params: dict,
litellm_params: dict,
stream: Optional[bool] = None,
) -> str:
"""
Get the complete url for the request
"""
provider = self.get_bedrock_invoke_provider(model)
modelId = self.get_bedrock_model_id(
model=model,
provider=provider,
optional_params=optional_params,
)
### SET RUNTIME ENDPOINT ###
aws_bedrock_runtime_endpoint = optional_params.get(
"aws_bedrock_runtime_endpoint", None
) # https://bedrock-runtime.{region_name}.amazonaws.com
endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint(
api_base=api_base,
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
aws_region_name=self._get_aws_region_name(
optional_params=optional_params, model=model
),
)
if (stream is not None and stream is True) and provider != "ai21":
endpoint_url = f"{endpoint_url}/model/{modelId}/invoke-with-response-stream"
proxy_endpoint_url = (
f"{proxy_endpoint_url}/model/{modelId}/invoke-with-response-stream"
)
else:
endpoint_url = f"{endpoint_url}/model/{modelId}/invoke"
proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/invoke"
return endpoint_url
def sign_request(
self,
headers: dict,
optional_params: dict,
request_data: dict,
api_base: str,
api_key: Optional[str] = None,
model: Optional[str] = None,
stream: Optional[bool] = None,
fake_stream: Optional[bool] = None,
) -> Tuple[dict, Optional[bytes]]:
return self._sign_request(
service_name="bedrock",
headers=headers,
optional_params=optional_params,
request_data=request_data,
api_base=api_base,
api_key=api_key,
model=model,
stream=stream,
fake_stream=fake_stream,
)
def _apply_config_to_params(self, config: dict, inference_params: dict) -> None:
"""Apply config values to inference_params if not already set."""
for k, v in config.items():
if k not in inference_params:
inference_params[k] = v
    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Build the provider-specific request body for the Bedrock Invoke API.

        anthropic / nova / twelvelabs / openai models are delegated to their own
        config classes; the remaining providers are sent a prompt string plus
        provider-specific inference params (with config defaults applied).

        Raises:
            BedrockError: 404 when the provider cannot be derived from `model`.
        """
        ## SETUP ##
        stream = optional_params.pop("stream", None)
        custom_prompt_dict: dict = litellm_params.pop("custom_prompt_dict", None) or {}
        hf_model_name = litellm_params.get("hf_model_name", None)
        provider = self.get_bedrock_invoke_provider(model)
        prompt, chat_history = self.convert_messages_to_prompt(
            model=hf_model_name or model,
            messages=messages,
            provider=provider,
            custom_prompt_dict=custom_prompt_dict,
        )
        inference_params = copy.deepcopy(optional_params)
        # AWS auth params are only used for signing; never send them in the body.
        inference_params = {
            k: v
            for k, v in inference_params.items()
            if k not in self.aws_authentication_params
        }
        request_data: dict = {}
        if provider == "cohere":
            if model.startswith("cohere.command-r"):
                ## LOAD CONFIG
                config = litellm.AmazonCohereChatConfig().get_config()
                self._apply_config_to_params(config, inference_params)
                _data = {"message": prompt, **inference_params}
                if chat_history is not None:
                    _data["chat_history"] = chat_history
                request_data = _data
            else:
                ## LOAD CONFIG
                config = litellm.AmazonCohereConfig.get_config()
                self._apply_config_to_params(config, inference_params)
                if stream is True:
                    inference_params[
                        "stream"
                    ] = True  # cohere requires stream = True in inference params
                request_data = {"prompt": prompt, **inference_params}
        elif provider == "anthropic":
            transformed_request = (
                litellm.AmazonAnthropicClaudeConfig().transform_request(
                    model=model,
                    messages=messages,
                    optional_params=optional_params,
                    litellm_params=litellm_params,
                    headers=headers,
                )
            )
            return transformed_request
        elif provider == "nova":
            return litellm.AmazonInvokeNovaConfig().transform_request(
                model=model,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                headers=headers,
            )
        elif provider == "ai21":
            ## LOAD CONFIG
            config = litellm.AmazonAI21Config.get_config()
            self._apply_config_to_params(config, inference_params)
            request_data = {"prompt": prompt, **inference_params}
        elif provider == "mistral":
            ## LOAD CONFIG
            config = litellm.AmazonMistralConfig.get_config()
            self._apply_config_to_params(config, inference_params)
            request_data = {"prompt": prompt, **inference_params}
        elif provider == "amazon":  # amazon titan
            ## LOAD CONFIG
            config = litellm.AmazonTitanConfig.get_config()
            self._apply_config_to_params(config, inference_params)
            request_data = {
                "inputText": prompt,
                "textGenerationConfig": inference_params,
            }
        elif provider == "meta" or provider == "llama" or provider == "deepseek_r1":
            ## LOAD CONFIG
            config = litellm.AmazonLlamaConfig.get_config()
            self._apply_config_to_params(config, inference_params)
            request_data = {"prompt": prompt, **inference_params}
        elif provider == "twelvelabs":
            return litellm.AmazonTwelveLabsPegasusConfig().transform_request(
                model=model,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                headers=headers,
            )
        elif provider == "openai":
            # OpenAI imported models use OpenAI Chat Completions format
            return litellm.AmazonBedrockOpenAIConfig().transform_request(
                model=model,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                headers=headers,
            )
        else:
            raise BedrockError(
                status_code=404,
                message="Bedrock Invoke HTTPX: Unknown provider={}, model={}. Try calling via converse route - `bedrock/converse/<model>`.".format(
                    provider, model
                ),
            )
        return request_data
    def transform_response(  # noqa: PLR0915
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """
        Parse a raw Bedrock Invoke HTTP response into a ModelResponse.

        anthropic / nova / twelvelabs responses are delegated to their own
        configs; for other providers the completion text is extracted here and
        usage is computed from Bedrock response headers, falling back to a
        local token count.

        Raises:
            BedrockError: on unparseable JSON (original status code), on
                provider-specific extraction failure (422), or when no content
                could be set on the response (original status code).
        """
        try:
            completion_response = raw_response.json()
        except Exception:
            raise BedrockError(
                message=raw_response.text, status_code=raw_response.status_code
            )
        verbose_logger.debug(
            "bedrock invoke response % s",
            json.dumps(completion_response, indent=4, default=str),
        )
        provider = self.get_bedrock_invoke_provider(model)
        outputText: Optional[str] = None
        try:
            if provider == "cohere":
                if "text" in completion_response:
                    outputText = completion_response["text"]  # type: ignore
                elif "generations" in completion_response:
                    outputText = completion_response["generations"][0]["text"]
                    model_response.choices[0].finish_reason = map_finish_reason(
                        completion_response["generations"][0]["finish_reason"]
                    )
            elif provider == "anthropic":
                return litellm.AmazonAnthropicClaudeConfig().transform_response(
                    model=model,
                    raw_response=raw_response,
                    model_response=model_response,
                    logging_obj=logging_obj,
                    request_data=request_data,
                    messages=messages,
                    optional_params=optional_params,
                    litellm_params=litellm_params,
                    encoding=encoding,
                    api_key=api_key,
                    json_mode=json_mode,
                )
            elif provider == "nova":
                return litellm.AmazonInvokeNovaConfig().transform_response(
                    model=model,
                    raw_response=raw_response,
                    model_response=model_response,
                    logging_obj=logging_obj,
                    request_data=request_data,
                    messages=messages,
                    optional_params=optional_params,
                    litellm_params=litellm_params,
                    encoding=encoding,
                )
            elif provider == "twelvelabs":
                return litellm.AmazonTwelveLabsPegasusConfig().transform_response(
                    model=model,
                    raw_response=raw_response,
                    model_response=model_response,
                    logging_obj=logging_obj,
                    request_data=request_data,
                    messages=messages,
                    optional_params=optional_params,
                    litellm_params=litellm_params,
                    encoding=encoding,
                    api_key=api_key,
                    json_mode=json_mode,
                )
            elif provider == "ai21":
                outputText = (
                    completion_response.get("completions")[0].get("data").get("text")
                )
            elif provider == "meta" or provider == "llama" or provider == "deepseek_r1":
                outputText = completion_response["generation"]
            elif provider == "mistral":
                outputText = litellm.AmazonMistralConfig.get_outputText(
                    completion_response, model_response
                )
            else:  # amazon titan
                outputText = completion_response.get("results")[0].get("outputText")
        except Exception as e:
            raise BedrockError(
                message="Error processing={}, Received error={}".format(
                    raw_response.text, str(e)
                ),
                status_code=422,
            )
        try:
            # Only set content when text was extracted and no tool calls are
            # already present; an existing tool_calls response is left as-is.
            if (
                outputText is not None
                and len(outputText) > 0
                and hasattr(model_response.choices[0], "message")
                and getattr(model_response.choices[0].message, "tool_calls", None)  # type: ignore
                is None
            ):
                model_response.choices[0].message.content = outputText  # type: ignore
            elif (
                hasattr(model_response.choices[0], "message")
                and getattr(model_response.choices[0].message, "tool_calls", None)  # type: ignore
                is not None
            ):
                pass
            else:
                raise Exception()
        except Exception as e:
            raise BedrockError(
                message="Error parsing received text={}.\nError-{}".format(
                    outputText, str(e)
                ),
                status_code=raw_response.status_code,
            )
        ## CALCULATING USAGE - bedrock returns usage in the headers
        bedrock_input_tokens = raw_response.headers.get(
            "x-amzn-bedrock-input-token-count", None
        )
        bedrock_output_tokens = raw_response.headers.get(
            "x-amzn-bedrock-output-token-count", None
        )
        prompt_tokens = int(
            bedrock_input_tokens or litellm.token_counter(messages=messages)
        )
        completion_tokens = int(
            bedrock_output_tokens
            or litellm.token_counter(
                text=model_response.choices[0].message.content,  # type: ignore
                count_response_tokens=True,
            )
        )
        model_response.created = int(time.time())
        model_response.model = model
        usage = Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
        setattr(model_response, "usage", usage)
        return model_response
    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        # No extra headers are required here; SigV4 auth is applied in sign_request.
        return headers
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
return BedrockError(status_code=status_code, message=error_message)
@track_llm_api_timing()
async def get_async_custom_stream_wrapper(
self,
model: str,
custom_llm_provider: str,
logging_obj: LiteLLMLoggingObj,
api_base: str,
headers: dict,
data: dict,
messages: list,
client: Optional[AsyncHTTPHandler] = None,
json_mode: Optional[bool] = None,
signed_json_body: Optional[bytes] = None,
) -> CustomStreamWrapper:
streaming_response = CustomStreamWrapper(
completion_stream=None,
make_call=partial(
make_call,
client=client,
api_base=api_base,
headers=headers,
data=json.dumps(data),
model=model,
messages=messages,
logging_obj=logging_obj,
fake_stream=True if "ai21" in api_base else False,
bedrock_invoke_provider=self.get_bedrock_invoke_provider(model),
json_mode=json_mode,
),
model=model,
custom_llm_provider="bedrock",
logging_obj=logging_obj,
)
return streaming_response
@track_llm_api_timing()
def get_sync_custom_stream_wrapper(
self,
model: str,
custom_llm_provider: str,
logging_obj: LiteLLMLoggingObj,
api_base: str,
headers: dict,
data: dict,
messages: list,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
json_mode: Optional[bool] = None,
signed_json_body: Optional[bytes] = None,
) -> CustomStreamWrapper:
if client is None or isinstance(client, AsyncHTTPHandler):
client = _get_httpx_client(params={})
streaming_response = CustomStreamWrapper(
completion_stream=None,
make_call=partial(
make_sync_call,
client=client,
api_base=api_base,
headers=headers,
data=json.dumps(data),
signed_json_body=signed_json_body,
model=model,
messages=messages,
logging_obj=logging_obj,
fake_stream=True if "ai21" in api_base else False,
bedrock_invoke_provider=self.get_bedrock_invoke_provider(model),
json_mode=json_mode,
),
model=model,
custom_llm_provider="bedrock",
logging_obj=logging_obj,
)
return streaming_response
    @property
    def has_custom_stream_wrapper(self) -> bool:
        # Invoke supplies its own stream wrappers (see get_*_custom_stream_wrapper).
        return True
    @property
    def supports_stream_param_in_request_body(self) -> bool:
        """
        Bedrock invoke does not allow passing `stream` in the request body.

        Streaming is selected via the `invoke-with-response-stream` endpoint
        instead (see get_complete_url).
        """
        return False
@staticmethod
def get_bedrock_invoke_provider(
model: str,
) -> Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL]:
"""
Helper function to get the bedrock provider from the model
handles 4 scenarios:
1. model=invoke/anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
2. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
3. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
4. model=us.amazon.nova-pro-v1:0 -> Returns `nova`
"""
if model.startswith("invoke/"):
model = model.replace("invoke/", "", 1)
# Special case: Check for "nova" in model name first (before "amazon")
# This handles amazon.nova-* models which would otherwise match "amazon" (Titan)
if "nova" in model.lower():
if "nova" in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, "nova")
_split_model = model.split(".")[0]
if _split_model in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, _split_model)
# If not a known provider, check for pattern with two slashes
provider = AmazonInvokeConfig._get_provider_from_model_path(model)
if provider is not None:
return provider
for provider in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
if provider in model:
return provider
return None
@staticmethod
def _get_provider_from_model_path(
model_path: str,
) -> Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL]:
"""
Helper function to get the provider from a model path with format: provider/model-name
Args:
model_path (str): The model path (e.g., 'llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n' or 'anthropic/model-name')
Returns:
Optional[str]: The provider name, or None if no valid provider found
"""
parts = model_path.split("/")
if len(parts) >= 1:
provider = parts[0]
if provider in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, provider)
return None
def convert_messages_to_prompt(
self, model, messages, provider, custom_prompt_dict
) -> Tuple[str, Optional[list]]:
# handle anthropic prompts and amazon titan prompts
prompt = ""
chat_history: Optional[list] = None
## CUSTOM PROMPT
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details.get(
"initial_prompt_value", ""
),
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
messages=messages,
)
return prompt, None
## ELSE
if provider == "anthropic" or provider == "amazon":
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="bedrock"
)
elif provider == "mistral":
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="bedrock"
)
elif provider == "meta" or provider == "llama":
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="bedrock"
)
elif provider == "cohere":
prompt, chat_history = cohere_message_pt(messages=messages)
elif provider == "deepseek_r1":
prompt = deepseek_r1_pt(messages=messages)
else:
prompt = ""
for message in messages:
if "role" in message:
if message["role"] == "user":
prompt += f"{message['content']}"
else:
prompt += f"{message['content']}"
else:
prompt += f"{message['content']}"
return prompt, chat_history # type: ignore

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,27 @@
"""
Helper util for handling bedrock-specific cost calculation
- e.g.: prompt caching
"""
from typing import TYPE_CHECKING, Optional, Tuple
from litellm.litellm_core_utils.llm_cost_calc.utils import generic_cost_per_token
if TYPE_CHECKING:
from litellm.types.utils import Usage
def cost_per_token(
    model: str, usage: "Usage", service_tier: Optional[str] = None
) -> Tuple[float, float]:
    """
    Calculates the cost per token for a given model, prompt tokens, and completion tokens.
    Follows the same logic as Anthropic's cost per token calculation.

    Args:
        model: Bedrock model name used for the pricing lookup.
        usage: Token usage for the call.
        service_tier: Optional service tier passed through to the pricing logic.

    Returns:
        Tuple[float, float]: (prompt_cost, completion_cost).
    """
    # Delegates entirely to the shared generic calculator with provider pinned
    # to "bedrock".
    return generic_cost_per_token(
        model=model,
        usage=usage,
        custom_llm_provider="bedrock",
        service_tier=service_tier,
    )

View File

@@ -0,0 +1,115 @@
"""
Bedrock Token Counter implementation using the CountTokens API.
"""
from typing import Any, Dict, List, Optional
from litellm._logging import verbose_logger
from litellm.llms.base_llm.base_utils import BaseTokenCounter
from litellm.llms.bedrock.common_utils import BedrockError, get_bedrock_base_model
from litellm.llms.bedrock.count_tokens.handler import BedrockCountTokensHandler
from litellm.types.utils import LlmProviders, TokenCountResponse
class BedrockTokenCounter(BaseTokenCounter):
    """Token counter for AWS Bedrock that delegates to the remote CountTokens API."""

    def should_use_token_counting_api(
        self,
        custom_llm_provider: Optional[str] = None,
    ) -> bool:
        """Return True when the provider is Bedrock, so the CountTokens API applies."""
        return custom_llm_provider == LlmProviders.BEDROCK.value

    async def count_tokens(
        self,
        model_to_use: str,
        messages: Optional[List[Dict[str, Any]]],
        contents: Optional[List[Dict[str, Any]]],
        deployment: Optional[Dict[str, Any]] = None,
        request_model: str = "",
        tools: Optional[List[Dict[str, Any]]] = None,
        system: Optional[Any] = None,
    ) -> Optional[TokenCountResponse]:
        """
        Count tokens via AWS Bedrock's CountTokens API (no local tokenizer).

        Args:
            model_to_use: The model identifier.
            messages: Messages to count tokens for; None/empty short-circuits.
            contents: Alternative content format (unused for Bedrock).
            deployment: Deployment configuration containing litellm_params.
            request_model: The original request model name.
            tools: Optional tool definitions included in the count.
            system: Optional system prompt included in the count.

        Returns:
            TokenCountResponse (with error=True on API failure), or None when
            there are no messages or the API returned no result.
        """
        if not messages:
            return None
        litellm_params = (deployment or {}).get("litellm_params", {})
        # Payload shape expected by BedrockCountTokensHandler.
        request_data: Dict[str, Any] = {"model": model_to_use, "messages": messages}
        if tools:
            request_data["tools"] = tools
        if system:
            request_data["system"] = system
        # Strip routing prefixes (bedrock/, converse/, ...) to get the raw model id.
        resolved_model = get_bedrock_base_model(model_to_use)
        try:
            api_result = await BedrockCountTokensHandler().handle_count_tokens_request(
                request_data=request_data,
                litellm_params=litellm_params,
                resolved_model=resolved_model,
            )
            if api_result is not None:
                return TokenCountResponse(
                    total_tokens=api_result.get("input_tokens", 0),
                    request_model=request_model,
                    model_used=model_to_use,
                    tokenizer_type="bedrock_api",
                    original_response=api_result,
                )
        except BedrockError as e:
            verbose_logger.warning(
                f"Bedrock CountTokens API error: status={e.status_code}, message={e.message}"
            )
            return TokenCountResponse(
                total_tokens=0,
                request_model=request_model,
                model_used=model_to_use,
                tokenizer_type="bedrock_api",
                error=True,
                error_message=e.message,
                status_code=e.status_code,
            )
        except Exception as e:
            verbose_logger.warning(f"Error calling Bedrock CountTokens API: {e}")
            return TokenCountResponse(
                total_tokens=0,
                request_model=request_model,
                model_used=model_to_use,
                tokenizer_type="bedrock_api",
                error=True,
                error_message=str(e),
                status_code=500,
            )
        return None

View File

@@ -0,0 +1,136 @@
"""
AWS Bedrock CountTokens API handler.
Simplified handler leveraging existing LiteLLM Bedrock infrastructure.
"""
from typing import Any, Dict
import httpx
import litellm
from litellm._logging import verbose_logger
from litellm.llms.bedrock.common_utils import BedrockError
from litellm.llms.bedrock.count_tokens.transformation import BedrockCountTokensConfig
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
class BedrockCountTokensHandler(BedrockCountTokensConfig):
    """
    Simplified handler for AWS Bedrock CountTokens API requests.
    Uses existing LiteLLM infrastructure for authentication and request handling.
    """

    async def handle_count_tokens_request(
        self,
        request_data: Dict[str, Any],
        litellm_params: Dict[str, Any],
        resolved_model: str,
    ) -> Dict[str, Any]:
        """
        Handle a CountTokens request using existing LiteLLM patterns.

        Flow: validate -> resolve AWS region -> transform payload -> build
        endpoint -> SigV4-sign -> POST -> transform response.

        Args:
            request_data: The incoming request payload
            litellm_params: LiteLLM configuration parameters
            resolved_model: The actual model ID resolved from router

        Returns:
            Dictionary containing token count response

        Raises:
            BedrockError: on validation failure, non-200 AWS responses, HTTP
                errors (original status preserved), or any unexpected error
                (wrapped with status 500).
        """
        try:
            # Validate the request
            self.validate_count_tokens_request(request_data)
            verbose_logger.debug(
                f"Processing CountTokens request for resolved model: {resolved_model}"
            )
            # Get AWS region using existing LiteLLM function
            aws_region_name = self._get_aws_region_name(
                optional_params=litellm_params,
                model=resolved_model,
                model_id=None,
            )
            verbose_logger.debug(f"Retrieved AWS region: {aws_region_name}")
            # Transform request to Bedrock format (supports both Converse and InvokeModel)
            bedrock_request = self.transform_anthropic_to_bedrock_count_tokens(
                request_data=request_data
            )
            verbose_logger.debug(f"Transformed request: {bedrock_request}")
            # Get endpoint URL using simplified function
            endpoint_url = self.get_bedrock_count_tokens_endpoint(
                resolved_model, aws_region_name
            )
            verbose_logger.debug(f"Making request to: {endpoint_url}")
            # Use existing _sign_request method from BaseAWSLLM
            # Extract api_key for bearer token auth if provided
            api_key = litellm_params.get("api_key", None)
            headers = {"Content-Type": "application/json"}
            signed_headers, signed_body = self._sign_request(
                service_name="bedrock",
                headers=headers,
                optional_params=litellm_params,
                request_data=bedrock_request,
                api_base=endpoint_url,
                model=resolved_model,
                api_key=api_key,
            )
            async_client = get_async_httpx_client(
                llm_provider=litellm.LlmProviders.BEDROCK
            )
            response = await async_client.post(
                endpoint_url,
                headers=signed_headers,
                data=signed_body,
                timeout=30.0,
            )
            verbose_logger.debug(f"Response status: {response.status_code}")
            if response.status_code != 200:
                error_text = response.text
                verbose_logger.error(f"AWS Bedrock error: {error_text}")
                raise BedrockError(
                    status_code=response.status_code,
                    message=error_text,
                )
            bedrock_response = response.json()
            verbose_logger.debug(f"Bedrock response: {bedrock_response}")
            # Transform response back to expected format
            final_response = self.transform_bedrock_response_to_anthropic(
                bedrock_response
            )
            verbose_logger.debug(f"Final response: {final_response}")
            return final_response
        except BedrockError:
            # Re-raise Bedrock exceptions as-is
            raise
        except httpx.HTTPStatusError as e:
            # HTTP errors - preserve the actual status code
            verbose_logger.error(f"HTTP error in CountTokens handler: {str(e)}")
            raise BedrockError(
                status_code=e.response.status_code,
                message=e.response.text,
            )
        except Exception as e:
            # Catch-all boundary: anything unexpected becomes a 500 BedrockError.
            verbose_logger.error(f"Error in CountTokens handler: {str(e)}")
            raise BedrockError(
                status_code=500,
                message=f"CountTokens processing error: {str(e)}",
            )

View File

@@ -0,0 +1,262 @@
"""
AWS Bedrock CountTokens API transformation logic.
This module handles the transformation of requests from Anthropic Messages API format
to AWS Bedrock's CountTokens API format and vice versa.
"""
import re
from typing import Any, Dict, List, Optional
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
from litellm.llms.bedrock.common_utils import get_bedrock_base_model
class BedrockCountTokensConfig(BaseAWSLLM):
    """
    Request/response transformation for the AWS Bedrock CountTokens API.

    AWS Bedrock CountTokens API Specification:
    - Endpoint: POST /model/{modelId}/count-tokens
    - Input formats: 'invokeModel' or 'converse'
    - Response: {"inputTokens": <number>}
    """

    def _detect_input_type(self, request_data: Dict[str, Any]) -> str:
        """
        Decide which CountTokens input format applies to this payload.

        Returns:
            'converse' when an Anthropic-style messages list is present,
            otherwise 'invokeModel' (raw model body).
        """
        messages = request_data.get("messages")
        if isinstance(messages, list):
            return "converse"
        # Prompt-based or already-raw Bedrock payloads go through InvokeModel.
        return "invokeModel"

    def transform_anthropic_to_bedrock_count_tokens(
        self,
        request_data: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Transform an Anthropic-format request into the Bedrock CountTokens
        request body, choosing Converse or InvokeModel shape automatically.

        Converse output:    {"input": {"converse": {"messages": [...], ...}}}
        InvokeModel output: {"input": {"invokeModel": {"body": "{...raw model input...}"}}}
        """
        if self._detect_input_type(request_data) == "converse":
            return self._transform_to_converse_format(request_data)
        return self._transform_to_invoke_model_format(request_data)

    def _convert_message(self, message: Dict[str, Any]) -> Dict[str, Any]:
        """Convert a single Anthropic message into Bedrock Converse shape."""
        raw_content = message.get("content", "")
        if isinstance(raw_content, str):
            content: List[Any] = [{"text": raw_content}]
        elif isinstance(raw_content, list):
            content = raw_content
        else:
            content = []
        return {"role": message.get("role"), "content": content}

    def _transform_to_converse_format(
        self, request_data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Build the Converse-style CountTokens input (messages/system/tools)."""
        converse_payload: Dict[str, Any] = {
            "messages": [
                self._convert_message(message)
                for message in request_data.get("messages", [])
            ]
        }
        system_blocks = self._transform_system(request_data.get("system"))
        if system_blocks:
            converse_payload["system"] = system_blocks
        tool_config = self._transform_tools(request_data.get("tools"))
        if tool_config:
            converse_payload["toolConfig"] = tool_config
        return {"input": {"converse": converse_payload}}

    def _transform_system(self, system: Optional[Any]) -> List[Dict[str, Any]]:
        """Normalize an Anthropic system prompt (str or block list) to Bedrock blocks."""
        if isinstance(system, str):
            return [{"text": system}]
        if isinstance(system, list):
            # Blocks like [{"type": "text", "text": "..."}] -> [{"text": "..."}]
            return [
                {"text": block.get("text", "")}
                for block in system
                if isinstance(block, dict)
            ]
        return []

    def _sanitize_tool_name(self, name: str) -> str:
        """Coerce a tool name into Bedrock's [a-zA-Z][a-zA-Z0-9_]* / max-64-char rule."""
        cleaned = re.sub(r"[^a-zA-Z0-9_]", "_", name)
        if cleaned and not cleaned[0].isalpha():
            cleaned = "t_" + cleaned
        return cleaned[:64]

    def _transform_tools(
        self, tools: Optional[List[Dict[str, Any]]]
    ) -> Optional[Dict[str, Any]]:
        """Convert Anthropic tool definitions into Bedrock toolConfig, or None."""
        if not tools:
            return None
        tool_specs = []
        for tool in tools:
            safe_name = self._sanitize_tool_name(tool.get("name", ""))
            tool_specs.append(
                {
                    "toolSpec": {
                        "name": safe_name,
                        # Bedrock requires a description; fall back to the name.
                        "description": tool.get("description") or safe_name,
                        "inputSchema": {
                            "json": tool.get(
                                "input_schema",
                                {"type": "object", "properties": {}},
                            )
                        },
                    }
                }
            )
        return {"tools": tool_specs}

    def _transform_to_invoke_model_format(
        self, request_data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Build the InvokeModel-style CountTokens input from a raw payload."""
        import json

        # 'model' is routing metadata, not part of the model input body.
        body = {key: value for key, value in request_data.items() if key != "model"}
        return {"input": {"invokeModel": {"body": json.dumps(body)}}}

    def get_bedrock_count_tokens_endpoint(
        self, model: str, aws_region_name: str
    ) -> str:
        """
        Build the full CountTokens endpoint URL for a model/region pair.

        Args:
            model: The resolved model ID from router lookup
            aws_region_name: AWS region (e.g., "eu-west-1")

        Returns:
            Complete endpoint URL for CountTokens API
        """
        # Strip any region prefix via the shared helper, then the optional
        # "bedrock/" provider prefix.
        model_id = get_bedrock_base_model(model)
        if model_id.startswith("bedrock/"):
            model_id = model_id[len("bedrock/") :]
        return (
            f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
            f"/model/{model_id}/count-tokens"
        )

    def transform_bedrock_response_to_anthropic(
        self, bedrock_response: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Map Bedrock's {"inputTokens": N} to Anthropic's {"input_tokens": N}."""
        return {"input_tokens": bedrock_response.get("inputTokens", 0)}

    def validate_count_tokens_request(self, request_data: Dict[str, Any]) -> None:
        """
        Validate a CountTokens request for either input format.

        Args:
            request_data: The request payload

        Raises:
            ValueError: If the request is invalid
        """
        if not request_data.get("model"):
            raise ValueError("model parameter is required")
        if self._detect_input_type(request_data) == "converse":
            # Converse format: a non-empty list of role/content dicts.
            messages = request_data.get("messages", [])
            if not messages:
                raise ValueError("messages parameter is required for Converse input")
            if not isinstance(messages, list):
                raise ValueError("messages must be a list")
            for i, message in enumerate(messages):
                if not isinstance(message, dict):
                    raise ValueError(f"Message {i} must be a dictionary")
                if "role" not in message:
                    raise ValueError(f"Message {i} must have a 'role' field")
                if "content" not in message:
                    raise ValueError(f"Message {i} must have a 'content' field")
        else:
            # InvokeModel bodies vary per model; just require some content
            # beyond the 'model' field itself.
            if len(request_data) <= 1:
                raise ValueError("Request must contain content to count tokens")

View File

@@ -0,0 +1,361 @@
"""
Transformation logic from OpenAI /v1/embeddings format to Bedrock Amazon Nova /invoke and /async-invoke format.
Why separate file? Make it easy to see how transformation works
Supports:
- Synchronous embeddings (SINGLE_EMBEDDING)
- Asynchronous embeddings with segmentation (SEGMENTED_EMBEDDING)
- Multimodal inputs: text, image, video, audio
- Multiple embedding purposes and dimensions
Docs - https://docs.aws.amazon.com/bedrock/latest/userguide/nova-embed.html
"""
from typing import List, Optional
from litellm.types.utils import (
Embedding,
EmbeddingResponse,
PromptTokensDetailsWrapper,
Usage,
)
class AmazonNovaEmbeddingConfig:
    """
    Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/nova-embed.html
    Amazon Nova Multimodal Embeddings supports:
    - Text, image, video, and audio inputs
    - Synchronous (InvokeModel) and asynchronous (StartAsyncInvoke) APIs
    - Multiple embedding purposes and dimensions
    """

    def __init__(self) -> None:
        pass

    def get_supported_openai_params(self) -> List[str]:
        # Only OpenAI's `dimensions` has a direct Nova mapping; all other
        # Nova-specific knobs are passed through via inference params.
        return [
            "dimensions",
        ]

    def map_openai_params(
        self, non_default_params: dict, optional_params: dict
    ) -> dict:
        """Map OpenAI-style parameters to Nova parameters."""
        for k, v in non_default_params.items():
            if k == "dimensions":
                # Map OpenAI dimensions to Nova embedding_dimension
                optional_params["embedding_dimension"] = v
            elif k in self.get_supported_openai_params():
                # NOTE(review): currently unreachable — "dimensions" is the only
                # supported param and is handled above. Kept for future params.
                optional_params[k] = v
        return optional_params

    def _parse_data_url(self, data_url: str) -> tuple:
        """
        Parse a data URL to extract the media type and base64 data.

        Args:
            data_url: Data URL in format: data:image/jpeg;base64,/9j/4AAQ...

        Returns:
            tuple: (media_type, base64_data)
                media_type: e.g., "image/jpeg", "video/mp4", "audio/mpeg"
                base64_data: The base64-encoded data without the prefix

        Raises:
            ValueError: if the string is not a well-formed data URL.
        """
        if not data_url.startswith("data:"):
            raise ValueError(f"Invalid data URL format: {data_url[:50]}...")
        # Split by comma to separate metadata from data
        # Format: data:image/jpeg;base64,<base64_data>
        if "," not in data_url:
            raise ValueError(
                f"Invalid data URL format (missing comma): {data_url[:50]}..."
            )
        metadata, base64_data = data_url.split(",", 1)
        # Extract media type from metadata
        # Remove 'data:' prefix and ';base64' suffix
        metadata = metadata[5:]  # Remove 'data:'
        if ";" in metadata:
            media_type = metadata.split(";")[0]
        else:
            media_type = metadata
        return media_type, base64_data

    def _transform_request(
        self,
        input: str,
        inference_params: dict,
        async_invoke_route: bool = False,
        model_id: Optional[str] = None,
        output_s3_uri: Optional[str] = None,
    ) -> dict:
        """
        Transform OpenAI-style input to Nova format.
        Only handles OpenAI params (dimensions). All other Nova-specific params
        should be passed via inference_params and will be passed through as-is.

        Args:
            input: The input text or media reference
            inference_params: Additional parameters (will be passed through)
            async_invoke_route: Whether this is for async invoke
            model_id: Model ID (for async invoke)
            output_s3_uri: S3 URI for output (for async invoke)

        Returns:
            dict: Nova embedding request
        """
        # Determine task type
        # Async jobs use segmented embeddings; sync calls embed a single input.
        task_type = "SEGMENTED_EMBEDDING" if async_invoke_route else "SINGLE_EMBEDDING"
        # Build the base request structure
        request: dict = {
            "schemaVersion": "nova-multimodal-embed-v1",
            "taskType": task_type,
        }
        # Start with inference_params (user-provided params)
        embedding_params = inference_params.copy()
        # output_s3_uri is consumed via the explicit kwarg, not the model input.
        embedding_params.pop("output_s3_uri", None)
        # Map OpenAI dimensions to embeddingDimension if provided
        # Precedence: "dimensions" (OpenAI name) wins over "embedding_dimension".
        if "dimensions" in embedding_params:
            embedding_params["embeddingDimension"] = embedding_params.pop("dimensions")
        elif "embedding_dimension" in embedding_params:
            embedding_params["embeddingDimension"] = embedding_params.pop(
                "embedding_dimension"
            )
        # Add required embeddingPurpose if not provided (required by Nova API)
        if "embeddingPurpose" not in embedding_params:
            embedding_params["embeddingPurpose"] = "GENERIC_INDEX"
        # Add required embeddingDimension if not provided (required by Nova API)
        if "embeddingDimension" not in embedding_params:
            embedding_params["embeddingDimension"] = 3072
        # For text/media input, add basic structure if user hasn't provided text/image/video/audio
        if (
            "text" not in embedding_params
            and "image" not in embedding_params
            and "video" not in embedding_params
            and "audio" not in embedding_params
        ):
            # Check if input is a data URL (e.g., data:image/jpeg;base64,...)
            if input.startswith("data:"):
                # Parse the data URL to extract media type and base64 data
                media_type, base64_data = self._parse_data_url(input)
                if media_type.startswith("image/"):
                    # Extract image format from MIME type (e.g., image/jpeg -> jpeg)
                    image_format = media_type.split("/")[1].lower()
                    # Nova API expects specific formats
                    if image_format == "jpg":
                        image_format = "jpeg"
                    embedding_params["image"] = {
                        "format": image_format,
                        "source": {"bytes": base64_data},
                    }
                elif media_type.startswith("video/"):
                    # Handle video data URLs
                    video_format = media_type.split("/")[1].lower()
                    embedding_params["video"] = {
                        "format": video_format,
                        "source": {"bytes": base64_data},
                    }
                elif media_type.startswith("audio/"):
                    # Handle audio data URLs
                    audio_format = media_type.split("/")[1].lower()
                    embedding_params["audio"] = {
                        "format": audio_format,
                        "source": {"bytes": base64_data},
                    }
                else:
                    # Fallback to text for unknown types
                    embedding_params["text"] = {"value": input, "truncationMode": "END"}
            elif input.startswith("s3://"):
                # S3 URL - default to text for now, user should specify modality
                embedding_params["text"] = {
                    "source": {"s3Location": {"uri": input}},
                    "truncationMode": "END",  # Required by Nova API
                }
            else:
                # Plain text input
                embedding_params["text"] = {
                    "value": input,
                    "truncationMode": "END",  # Required by Nova API
                }
        # Set the embedding params in the request
        if task_type == "SINGLE_EMBEDDING":
            request["singleEmbeddingParams"] = embedding_params
        else:
            request["segmentedEmbeddingParams"] = embedding_params
        # For async invoke, wrap in the async invoke format
        if async_invoke_route and model_id:
            return self._wrap_async_invoke_request(
                model_input=request,
                model_id=model_id,
                output_s3_uri=output_s3_uri,
            )
        return request

    def _wrap_async_invoke_request(
        self,
        model_input: dict,
        model_id: str,
        output_s3_uri: Optional[str] = None,
    ) -> dict:
        """
        Wrap the transformed request in the AWS Bedrock async invoke format.

        Args:
            model_input: The transformed Nova embedding request
            model_id: The model identifier (without async_invoke prefix)
            output_s3_uri: S3 URI for output data config

        Returns:
            dict: The wrapped async invoke request

        Raises:
            ValueError: if output_s3_uri is missing/blank (required for async).
        """
        import urllib.parse

        # Clean the model ID
        unquoted_model_id = urllib.parse.unquote(model_id)
        if unquoted_model_id.startswith("async_invoke/"):
            unquoted_model_id = unquoted_model_id.replace("async_invoke/", "")
        # Validate that the S3 URI is not empty
        if not output_s3_uri or output_s3_uri.strip() == "":
            raise ValueError("output_s3_uri is required for async invoke requests")
        return {
            "modelId": unquoted_model_id,
            "modelInput": model_input,
            "outputDataConfig": {"s3OutputDataConfig": {"s3Uri": output_s3_uri}},
        }

    def _transform_response(
        self,
        response_list: List[dict],
        model: str,
        batch_data: Optional[List[dict]] = None,
    ) -> EmbeddingResponse:
        """
        Transform Nova response to OpenAI format.

        Nova response format:
        {
            "embeddings": [
                {
                    "embeddingType": "TEXT" | "IMAGE" | "VIDEO" | "AUDIO" | "AUDIO_VIDEO_COMBINED",
                    "embedding": [0.1, 0.2, ...],
                    "truncatedCharLength": 100  # Optional, only for text
                }
            ]
        }
        """
        embeddings: List[Embedding] = []
        total_tokens = 0
        for response in response_list:
            # Nova response has an "embeddings" array
            if "embeddings" in response and isinstance(response["embeddings"], list):
                for item in response["embeddings"]:
                    if "embedding" in item:
                        embedding = Embedding(
                            embedding=item["embedding"],
                            index=len(embeddings),
                            object="embedding",
                        )
                        embeddings.append(embedding)
                        # Estimate token count
                        # For text, use truncatedCharLength if available
                        # NOTE(review): chars//4 is a rough heuristic, not an
                        # exact token count reported by the service.
                        if "truncatedCharLength" in item:
                            total_tokens += item["truncatedCharLength"] // 4
                        else:
                            # Rough estimate based on embedding dimension
                            total_tokens += len(item["embedding"]) // 4
            elif "embedding" in response:
                # Direct embedding response (fallback)
                embedding = Embedding(
                    embedding=response["embedding"],
                    index=len(embeddings),
                    object="embedding",
                )
                embeddings.append(embedding)
                total_tokens += len(response["embedding"]) // 4
        # Count images from original requests for cost calculation
        image_count = 0
        if batch_data:
            for request_data in batch_data:
                # Nova wraps params in singleEmbeddingParams or segmentedEmbeddingParams
                params = request_data.get(
                    "singleEmbeddingParams",
                    request_data.get("segmentedEmbeddingParams", {}),
                )
                if "image" in params:
                    image_count += 1
        prompt_tokens_details: Optional[PromptTokensDetailsWrapper] = None
        if image_count > 0:
            prompt_tokens_details = PromptTokensDetailsWrapper(
                image_count=image_count,
            )
        usage = Usage(
            prompt_tokens=total_tokens,
            total_tokens=total_tokens,
            prompt_tokens_details=prompt_tokens_details,
        )
        return EmbeddingResponse(data=embeddings, model=model, usage=usage)

    def _transform_async_invoke_response(
        self, response: dict, model: str
    ) -> EmbeddingResponse:
        """
        Transform async invoke response (invocation ARN) to OpenAI format.

        AWS async invoke returns:
        {
            "invocationArn": "arn:aws:bedrock:us-east-1:123456789012:async-invoke/abc123"
        }

        We transform this to a job-like embedding response with the ARN in hidden params.
        """
        invocation_arn = response.get("invocationArn", "")
        # Create a placeholder embedding object for the job
        embedding = Embedding(
            embedding=[],  # Empty embedding for async jobs
            index=0,
            object="embedding",
        )
        # Create usage object (empty for async jobs)
        usage = Usage(prompt_tokens=0, total_tokens=0)
        # Create hidden params with job ID
        from litellm.types.llms.base import HiddenParams

        hidden_params = HiddenParams()
        setattr(hidden_params, "_invocation_arn", invocation_arn)
        return EmbeddingResponse(
            data=[embedding],
            model=model,
            usage=usage,
            hidden_params=hidden_params,
        )

View File

@@ -0,0 +1,88 @@
"""
Transformation logic from OpenAI /v1/embeddings format to Bedrock Amazon Titan G1 /invoke format.
Why separate file? Make it easy to see how transformation works
Covers
- G1 request format
Docs - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-text.html
"""
import types
from typing import List
from litellm.types.llms.bedrock import (
AmazonTitanG1EmbeddingRequest,
AmazonTitanG1EmbeddingResponse,
)
from litellm.types.utils import Embedding, EmbeddingResponse, Usage
class AmazonTitanG1Config:
    """
    Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-text.html
    """

    def __init__(
        self,
    ) -> None:
        # Shared litellm config pattern: persist any non-None init args as
        # class-level attributes (no-op here, since there are no args).
        init_args = locals().copy()
        for arg_name, arg_value in init_args.items():
            if arg_name != "self" and arg_value is not None:
                setattr(self.__class__, arg_name, arg_value)

    @classmethod
    def get_config(cls):
        """Return the class-level config values as a plain dict."""
        excluded_types = (
            types.FunctionType,
            types.BuiltinFunctionType,
            classmethod,
            staticmethod,
        )
        return {
            key: value
            for key, value in cls.__dict__.items()
            if not key.startswith("__")
            and not isinstance(value, excluded_types)
            and value is not None
        }

    def get_supported_openai_params(self) -> List[str]:
        # Titan G1 text embeddings accept no OpenAI-style tuning params.
        return []

    def map_openai_params(
        self, non_default_params: dict, optional_params: dict
    ) -> dict:
        return optional_params

    def _transform_request(
        self, input: str, inference_params: dict
    ) -> AmazonTitanG1EmbeddingRequest:
        # G1 takes only the raw text; inference params are intentionally unused.
        return AmazonTitanG1EmbeddingRequest(inputText=input)

    def _transform_response(
        self, response_list: List[dict], model: str
    ) -> EmbeddingResponse:
        """Convert Titan G1 invoke responses into an OpenAI EmbeddingResponse."""
        token_total = 0
        embeddings: List[Embedding] = []
        for idx, raw_response in enumerate(response_list):
            parsed = AmazonTitanG1EmbeddingResponse(**raw_response)  # type: ignore
            embeddings.append(
                Embedding(
                    embedding=parsed["embedding"],
                    index=idx,
                    object="embedding",
                )
            )
            token_total += parsed["inputTextTokenCount"]
        usage = Usage(
            prompt_tokens=token_total,
            completion_tokens=0,
            total_tokens=token_total,
        )
        return EmbeddingResponse(model=model, usage=usage, data=embeddings)

View File

@@ -0,0 +1,101 @@
"""
Transformation logic from OpenAI /v1/embeddings format to Bedrock Amazon Titan multimodal /invoke format.
Why separate file? Make it easy to see how transformation works
Docs - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-mm.html
"""
from typing import List, Optional
from litellm.types.llms.bedrock import (
AmazonTitanMultimodalEmbeddingConfig,
AmazonTitanMultimodalEmbeddingRequest,
AmazonTitanMultimodalEmbeddingResponse,
)
from litellm.types.utils import (
Embedding,
EmbeddingResponse,
PromptTokensDetailsWrapper,
Usage,
)
from litellm.utils import get_base64_str, is_base64_encoded
class AmazonTitanMultimodalEmbeddingG1Config:
    """
    Reference - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-mm.html
    """

    def __init__(self) -> None:
        pass

    def get_supported_openai_params(self) -> List[str]:
        return ["dimensions"]

    def map_openai_params(
        self, non_default_params: dict, optional_params: dict
    ) -> dict:
        """Map OpenAI 'dimensions' onto Titan's embeddingConfig."""
        if "dimensions" in non_default_params:
            optional_params["embeddingConfig"] = AmazonTitanMultimodalEmbeddingConfig(
                outputEmbeddingLength=non_default_params["dimensions"]
            )
        return optional_params

    def _transform_request(
        self, input: str, inference_params: dict
    ) -> AmazonTitanMultimodalEmbeddingRequest:
        """Build a Titan multimodal request from base64 image data or plain text."""
        # Base64-encoded input is treated as an image; anything else as text.
        if is_base64_encoded(input):
            transformed_request = AmazonTitanMultimodalEmbeddingRequest(
                inputImage=get_base64_str(input)
            )
        else:
            transformed_request = AmazonTitanMultimodalEmbeddingRequest(inputText=input)
        for param_key, param_value in inference_params.items():
            transformed_request[param_key] = param_value  # type: ignore
        return transformed_request

    def _transform_response(
        self,
        response_list: List[dict],
        model: str,
        batch_data: Optional[List[dict]] = None,
    ) -> EmbeddingResponse:
        """Convert Titan multimodal invoke responses into an OpenAI EmbeddingResponse."""
        token_total = 0
        embeddings: List[Embedding] = []
        for idx, raw_response in enumerate(response_list):
            parsed = AmazonTitanMultimodalEmbeddingResponse(**raw_response)  # type: ignore
            embeddings.append(
                Embedding(
                    embedding=parsed["embedding"],
                    index=idx,
                    object="embedding",
                )
            )
            token_total += parsed["inputTextTokenCount"]
        # Count images from the original requests for cost calculation.
        image_count = sum(
            1 for request_data in (batch_data or []) if "inputImage" in request_data
        )
        prompt_tokens_details: Optional[PromptTokensDetailsWrapper] = None
        if image_count > 0:
            prompt_tokens_details = PromptTokensDetailsWrapper(image_count=image_count)
        usage = Usage(
            prompt_tokens=token_total,
            completion_tokens=0,
            total_tokens=token_total,
            prompt_tokens_details=prompt_tokens_details,
        )
        return EmbeddingResponse(model=model, usage=usage, data=embeddings)

View File

@@ -0,0 +1,131 @@
"""
Transformation logic from OpenAI /v1/embeddings format to Bedrock Amazon Titan V2 /invoke format.
Why separate file? Make it easy to see how transformation works
Covers
- v2 request format
Docs - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-text.html
"""
import types
from typing import List, Optional, Union
from litellm.types.llms.bedrock import (
AmazonTitanV2EmbeddingRequest,
AmazonTitanV2EmbeddingResponse,
)
from litellm.types.utils import Embedding, EmbeddingResponse, Usage
class AmazonTitanV2Config:
    """
    Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-text.html

    normalize: boolean - flag indicating whether or not to normalize the output embeddings. Defaults to true
    dimensions: int - The number of dimensions the output embeddings should have. The following values are accepted: 1024 (default), 512, 256.
    """

    normalize: Optional[bool] = None
    dimensions: Optional[int] = None

    def __init__(
        self, normalize: Optional[bool] = None, dimensions: Optional[int] = None
    ) -> None:
        # Persist any explicitly-passed settings as class-level defaults
        # (shared litellm config pattern).
        init_args = locals().copy()
        for arg_name, arg_value in init_args.items():
            if arg_name != "self" and arg_value is not None:
                setattr(self.__class__, arg_name, arg_value)

    @classmethod
    def get_config(cls):
        """Return the class-level config values as a plain dict."""
        excluded_types = (
            types.FunctionType,
            types.BuiltinFunctionType,
            classmethod,
            staticmethod,
        )
        return {
            key: value
            for key, value in cls.__dict__.items()
            if not key.startswith("__")
            and not isinstance(value, excluded_types)
            and value is not None
        }

    def get_supported_openai_params(self) -> List[str]:
        return ["dimensions", "encoding_format"]

    def map_openai_params(
        self, non_default_params: dict, optional_params: dict
    ) -> dict:
        """Translate OpenAI embedding params into Titan V2 request params."""
        for param, value in non_default_params.items():
            if param == "dimensions":
                optional_params["dimensions"] = value
            elif param == "encoding_format":
                # "base64" maps to AWS's binary embeddings; any other value
                # (including "float") falls back to float.
                optional_params["embeddingTypes"] = (
                    ["binary"] if value == "base64" else ["float"]
                )
        return optional_params

    def _transform_request(
        self, input: str, inference_params: dict
    ) -> AmazonTitanV2EmbeddingRequest:
        return AmazonTitanV2EmbeddingRequest(inputText=input, **inference_params)  # type: ignore

    def _transform_response(
        self, response_list: List[dict], model: str
    ) -> EmbeddingResponse:
        """Convert Titan V2 invoke responses into an OpenAI EmbeddingResponse."""
        token_total = 0
        embeddings: List[Embedding] = []
        for idx, raw_response in enumerate(response_list):
            parsed = AmazonTitanV2EmbeddingResponse(**raw_response)  # type: ignore
            # Prefer embeddingsByType (binary first, for encoding_format="base64",
            # then float), falling back to the legacy top-level "embedding" field.
            embedding_data: Union[List[float], List[int]]
            by_type = parsed.get("embeddingsByType", {})
            if "binary" in by_type:
                embedding_data = by_type["binary"]
            elif "float" in by_type:
                embedding_data = by_type["float"]
            elif "embedding" in parsed:
                embedding_data = parsed["embedding"]
            else:
                raise ValueError(f"No embedding data found in response: {raw_response}")
            embeddings.append(
                Embedding(
                    embedding=embedding_data,
                    index=idx,
                    object="embedding",
                )
            )
            token_total += parsed["inputTextTokenCount"]
        usage = Usage(
            prompt_tokens=token_total,
            completion_tokens=0,
            total_tokens=token_total,
        )
        return EmbeddingResponse(model=model, usage=usage, data=embeddings)

View File

@@ -0,0 +1,47 @@
"""
Transformation logic from OpenAI /v1/embeddings format to Bedrock Cohere /invoke format.
Why separate file? Make it easy to see how transformation works
"""
from typing import List
from litellm.llms.cohere.embed.transformation import CohereEmbeddingConfig
from litellm.types.llms.bedrock import CohereEmbeddingRequest
class BedrockCohereEmbeddingConfig:
    """Maps OpenAI embedding params onto Bedrock's Cohere embed request format."""

    def __init__(self) -> None:
        pass

    def get_supported_openai_params(self) -> List[str]:
        return ["encoding_format", "dimensions"]

    def map_openai_params(
        self, non_default_params: dict, optional_params: dict
    ) -> dict:
        """Rename OpenAI params to their Cohere equivalents; ignore the rest."""
        param_mapping = {
            "encoding_format": "embedding_types",
            "dimensions": "output_dimension",
        }
        for openai_key, value in non_default_params.items():
            cohere_key = param_mapping.get(openai_key)
            if cohere_key is not None:
                optional_params[cohere_key] = value
        return optional_params

    def _is_v3_model(self, model: str) -> bool:
        # Crude version sniff: any "3" in the model id marks a v3 model.
        return "3" in model

    def _transform_request(
        self, model: str, input: List[str], inference_params: dict
    ) -> CohereEmbeddingRequest:
        """Build a Bedrock CohereEmbeddingRequest, keeping only the fields the
        Bedrock variant of the request type declares."""
        base_request = CohereEmbeddingConfig()._transform_request(
            model, input, inference_params
        )
        bedrock_request = CohereEmbeddingRequest(
            input_type=base_request["input_type"],
        )
        for field in CohereEmbeddingRequest.__annotations__:
            if field in base_request:
                bedrock_request[field] = base_request[field]  # type: ignore
        return bedrock_request

View File

@@ -0,0 +1,699 @@
"""
Handles embedding calls to Bedrock's `/invoke` endpoint
"""
import copy
import json
import urllib.parse
from typing import Any, Callable, List, Optional, Tuple, Union, get_args
import httpx
import litellm
from litellm.constants import BEDROCK_EMBEDDING_PROVIDERS_LITERAL
from litellm.llms.cohere.embed.handler import embedding as cohere_embedding
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_httpx_client,
get_async_httpx_client,
)
from litellm.secret_managers.main import get_secret
from litellm.types.llms.bedrock import (
AmazonEmbeddingRequest,
CohereEmbeddingRequest,
)
from litellm.types.utils import EmbeddingResponse, LlmProviders
from ..base_aws_llm import BaseAWSLLM
from ..common_utils import BedrockError
from .amazon_nova_transformation import AmazonNovaEmbeddingConfig
from .amazon_titan_g1_transformation import AmazonTitanG1Config
from .amazon_titan_multimodal_transformation import (
AmazonTitanMultimodalEmbeddingG1Config,
)
from .amazon_titan_v2_transformation import AmazonTitanV2Config
from .cohere_transformation import BedrockCohereEmbeddingConfig
from .twelvelabs_marengo_transformation import TwelveLabsMarengoEmbeddingConfig
class BedrockEmbedding(BaseAWSLLM):
    def _load_credentials(
        self,
        optional_params: dict,
    ) -> Tuple[Any, str]:
        """
        Resolve AWS credentials and region for a Bedrock embedding call.

        Pops all aws_* auth params out of ``optional_params`` (mutating it in
        place) so they are not forwarded with the model request, resolves the
        region, and builds botocore Credentials via BaseAWSLLM.get_credentials.

        Returns:
            Tuple of (botocore Credentials, resolved aws_region_name).

        Raises:
            ImportError: if boto3/botocore is not installed.
        """
        try:
            from botocore.credentials import Credentials
        except ImportError:
            raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
        ## CREDENTIALS ##
        # pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
        aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
        aws_access_key_id = optional_params.pop("aws_access_key_id", None)
        aws_session_token = optional_params.pop("aws_session_token", None)
        aws_region_name = optional_params.pop("aws_region_name", None)
        aws_role_name = optional_params.pop("aws_role_name", None)
        aws_session_name = optional_params.pop("aws_session_name", None)
        aws_profile_name = optional_params.pop("aws_profile_name", None)
        aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
        aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)
        ### SET REGION NAME ###
        if aws_region_name is None:
            # check env #
            litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
            if litellm_aws_region_name is not None and isinstance(
                litellm_aws_region_name, str
            ):
                aws_region_name = litellm_aws_region_name
            # NOTE(review): AWS_REGION (the standard SDK variable) is checked
            # second and overwrites AWS_REGION_NAME when both are set.
            standard_aws_region_name = get_secret("AWS_REGION", None)
            if standard_aws_region_name is not None and isinstance(
                standard_aws_region_name, str
            ):
                aws_region_name = standard_aws_region_name
            # Final hard-coded fallback when no region is configured anywhere.
            if aws_region_name is None:
                aws_region_name = "us-west-2"
        credentials: Credentials = self.get_credentials(  # type: ignore
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
            aws_region_name=aws_region_name,
            aws_session_name=aws_session_name,
            aws_profile_name=aws_profile_name,
            aws_role_name=aws_role_name,
            aws_web_identity_token=aws_web_identity_token,
            aws_sts_endpoint=aws_sts_endpoint,
        )
        return credentials, aws_region_name
    async def async_embeddings(self):
        """NOTE(review): empty stub — performs no work. Confirm whether any
        caller awaits this before removing."""
        pass
def _make_sync_call(
self,
client: Optional[HTTPHandler],
timeout: Optional[Union[float, httpx.Timeout]],
api_base: str,
headers: dict,
data: dict,
) -> dict:
if client is None or not isinstance(client, HTTPHandler):
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout
client = _get_httpx_client(_params) # type: ignore
else:
client = client
try:
response = client.post(url=api_base, headers=headers, data=json.dumps(data)) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise BedrockError(status_code=error_code, message=err.response.text)
except httpx.TimeoutException:
raise BedrockError(status_code=408, message="Timeout error occurred.")
return response.json()
async def _make_async_call(
self,
client: Optional[AsyncHTTPHandler],
timeout: Optional[Union[float, httpx.Timeout]],
api_base: str,
headers: dict,
data: dict,
) -> dict:
if client is None or not isinstance(client, AsyncHTTPHandler):
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout
client = get_async_httpx_client(
params=_params, llm_provider=litellm.LlmProviders.BEDROCK
)
else:
client = client
try:
response = await client.post(url=api_base, headers=headers, data=json.dumps(data)) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise BedrockError(status_code=error_code, message=err.response.text)
except httpx.TimeoutException:
raise BedrockError(status_code=408, message="Timeout error occurred.")
return response.json()
def _transform_response(
self,
response_list: List[dict],
model: str,
provider: BEDROCK_EMBEDDING_PROVIDERS_LITERAL,
is_async_invoke: Optional[bool] = False,
batch_data: Optional[List[dict]] = None,
) -> Optional[EmbeddingResponse]:
"""
Transforms the response from the Bedrock embedding provider to the OpenAI format.
"""
returned_response: Optional[EmbeddingResponse] = None
# Handle async invoke responses (single response with invocationArn)
if (
is_async_invoke
and len(response_list) == 1
and "invocationArn" in response_list[0]
):
if provider == "twelvelabs":
returned_response = (
TwelveLabsMarengoEmbeddingConfig()._transform_async_invoke_response(
response=response_list[0], model=model
)
)
elif provider == "nova":
returned_response = (
AmazonNovaEmbeddingConfig()._transform_async_invoke_response(
response=response_list[0], model=model
)
)
else:
# For other providers, create a generic async response
invocation_arn = response_list[0].get("invocationArn", "")
from litellm.types.utils import Embedding, Usage
embedding = Embedding(
embedding=[],
index=0,
object="embedding", # Must be literal "embedding"
)
usage = Usage(prompt_tokens=0, total_tokens=0)
# Create hidden params with job ID
from litellm.types.llms.base import HiddenParams
hidden_params = HiddenParams()
setattr(hidden_params, "_invocation_arn", invocation_arn)
returned_response = EmbeddingResponse(
data=[embedding],
model=model,
usage=usage,
hidden_params=hidden_params,
)
else:
# Handle regular invoke responses
if model == "amazon.titan-embed-image-v1":
returned_response = (
AmazonTitanMultimodalEmbeddingG1Config()._transform_response(
response_list=response_list, model=model, batch_data=batch_data
)
)
elif model == "amazon.titan-embed-text-v1":
returned_response = AmazonTitanG1Config()._transform_response(
response_list=response_list, model=model
)
elif model == "amazon.titan-embed-text-v2:0":
returned_response = AmazonTitanV2Config()._transform_response(
response_list=response_list, model=model
)
elif provider == "twelvelabs":
returned_response = (
TwelveLabsMarengoEmbeddingConfig()._transform_response(
response_list=response_list, model=model
)
)
elif provider == "nova":
returned_response = AmazonNovaEmbeddingConfig()._transform_response(
response_list=response_list, model=model, batch_data=batch_data
)
##########################################################
# Validate returned response
##########################################################
if returned_response is None:
raise Exception(
"Unable to map model response to known provider format. model={}".format(
model
)
)
return returned_response
def _single_func_embeddings(
self,
client: Optional[HTTPHandler],
timeout: Optional[Union[float, httpx.Timeout]],
batch_data: List[dict],
credentials: Any,
extra_headers: Optional[dict],
endpoint_url: str,
aws_region_name: str,
model: str,
logging_obj: Any,
provider: BEDROCK_EMBEDDING_PROVIDERS_LITERAL,
api_key: Optional[str] = None,
is_async_invoke: Optional[bool] = False,
):
responses: List[dict] = []
for data in batch_data:
headers = {"Content-Type": "application/json"}
if extra_headers is not None:
headers = {"Content-Type": "application/json", **extra_headers}
prepped = self.get_request_headers( # type: ignore # type: ignore
credentials=credentials,
aws_region_name=aws_region_name,
extra_headers=extra_headers,
endpoint_url=endpoint_url,
data=json.dumps(data),
headers=headers,
api_key=api_key,
)
## LOGGING
logging_obj.pre_call(
input=data,
api_key="",
additional_args={
"complete_input_dict": data,
"api_base": prepped.url,
"headers": prepped.headers,
},
)
headers_for_request = (
dict(prepped.headers) if hasattr(prepped, "headers") else {}
)
response = self._make_sync_call(
client=client,
timeout=timeout,
api_base=prepped.url,
headers=headers_for_request,
data=data,
)
## LOGGING
logging_obj.post_call(
input=data,
api_key="",
original_response=response,
additional_args={"complete_input_dict": data},
)
responses.append(response)
return self._transform_response(
response_list=responses,
model=model,
provider=provider,
is_async_invoke=is_async_invoke,
batch_data=batch_data,
)
async def _async_single_func_embeddings(
self,
client: Optional[AsyncHTTPHandler],
timeout: Optional[Union[float, httpx.Timeout]],
batch_data: List[dict],
credentials: Any,
extra_headers: Optional[dict],
endpoint_url: str,
aws_region_name: str,
model: str,
logging_obj: Any,
provider: BEDROCK_EMBEDDING_PROVIDERS_LITERAL,
api_key: Optional[str] = None,
is_async_invoke: Optional[bool] = False,
):
responses: List[dict] = []
for data in batch_data:
headers = {"Content-Type": "application/json"}
if extra_headers is not None:
headers = {"Content-Type": "application/json", **extra_headers}
prepped = self.get_request_headers( # type: ignore # type: ignore
credentials=credentials,
aws_region_name=aws_region_name,
extra_headers=extra_headers,
endpoint_url=endpoint_url,
data=json.dumps(data),
headers=headers,
api_key=api_key,
)
## LOGGING
logging_obj.pre_call(
input=data,
api_key="",
additional_args={
"complete_input_dict": data,
"api_base": prepped.url,
"headers": prepped.headers,
},
)
# Convert CaseInsensitiveDict to regular dict for httpx compatibility
# This ensures custom headers are properly forwarded, especially with IAM roles and custom api_base
headers_for_request = (
dict(prepped.headers) if hasattr(prepped, "headers") else {}
)
response = await self._make_async_call(
client=client,
timeout=timeout,
api_base=prepped.url,
headers=headers_for_request,
data=data,
)
## LOGGING
logging_obj.post_call(
input=data,
api_key="",
original_response=response,
additional_args={"complete_input_dict": data},
)
responses.append(response)
## TRANSFORM RESPONSE ##
return self._transform_response(
response_list=responses,
model=model,
provider=provider,
is_async_invoke=is_async_invoke,
batch_data=batch_data,
)
    def embeddings(  # noqa: PLR0915
        self,
        model: str,
        input: List[str],
        api_base: Optional[str],
        model_response: EmbeddingResponse,
        print_verbose: Callable,
        encoding,
        logging_obj,
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]],
        timeout: Optional[Union[float, httpx.Timeout]],
        aembedding: Optional[bool],
        extra_headers: Optional[dict],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
    ) -> EmbeddingResponse:
        """
        Main entrypoint for Bedrock embeddings (sync and async).

        Resolves credentials/region, transforms the OpenAI-style request into
        provider-specific Bedrock payload(s), then routes to either the
        per-item invoke helpers (amazon titan / twelvelabs / nova) or the
        Cohere embedding path.

        Args:
            model: Bedrock model id; an "async_invoke/" prefix selects the
                StartAsyncInvoke route.
            input: List of inputs to embed.
            api_base: Optional endpoint override.
            model_response: Pre-allocated response object (used by the cohere path).
            print_verbose: Verbose-logging callable (part of the shared interface).
            encoding: Tokenizer/encoding object (used by the cohere path).
            logging_obj: LiteLLM logging object for pre/post call hooks.
            client: Optional pre-built sync/async HTTP client.
            timeout: Request timeout.
            aembedding: When True, the batch path returns a coroutine.
            extra_headers: Extra request headers.
            optional_params: Provider params; AWS auth params and ``model_id``
                are consumed (popped) from this dict.
            litellm_params: LiteLLM-specific params.
            api_key: Optional API key used when signing headers.

        Returns:
            EmbeddingResponse (or a coroutine resolving to one when aembedding=True).

        Raises:
            Exception: If the provider cannot be determined or the request
                cannot be mapped to a known provider format.
        """
        credentials, aws_region_name = self._load_credentials(optional_params)
        ### TRANSFORMATION ###
        unencoded_model_id = (
            optional_params.pop("model_id", None) or model
        ) # default to model if not passed
        # URL-encode so model ids containing "/" or ":" are path-safe.
        modelId = urllib.parse.quote(unencoded_model_id, safe="")
        aws_region_name = self._get_aws_region_name(
            optional_params={"aws_region_name": aws_region_name},
            model=model,
            model_id=unencoded_model_id,
        )
        # Check async invoke needs to be used
        has_async_invoke = "async_invoke/" in model
        if has_async_invoke:
            # Strip the routing prefix; downstream code expects the bare model id.
            model = model.replace("async_invoke/", "", 1)
        provider = self.get_bedrock_embedding_provider(model)
        if provider is None:
            raise Exception(
                f"Unable to determine bedrock embedding provider for model: {model}. "
                f"Supported providers: {list(get_args(BEDROCK_EMBEDDING_PROVIDERS_LITERAL))}"
            )
        # Strip AWS auth params (and OpenAI's `user`) before building the payload.
        inference_params = copy.deepcopy(optional_params)
        inference_params = {
            k: v
            for k, v in inference_params.items()
            if k.lower() not in self.aws_authentication_params
        }
        inference_params.pop(
            "user", None
        ) # make sure user is not passed in for bedrock call
        data: Optional[CohereEmbeddingRequest] = None
        batch_data: Optional[List] = None
        if provider == "cohere":
            data = BedrockCohereEmbeddingConfig()._transform_request(
                model=model, input=input, inference_params=inference_params
            )
        elif provider == "amazon" and model in [
            "amazon.titan-embed-image-v1",
            "amazon.titan-embed-text-v1",
            "amazon.titan-embed-text-v2:0",
        ]:
            # Titan models take one request per input item.
            batch_data = []
            for i in input:
                if model == "amazon.titan-embed-image-v1":
                    transformed_request: (
                        AmazonEmbeddingRequest
                    ) = AmazonTitanMultimodalEmbeddingG1Config()._transform_request(
                        input=i, inference_params=inference_params
                    )
                elif model == "amazon.titan-embed-text-v1":
                    transformed_request = AmazonTitanG1Config()._transform_request(
                        input=i, inference_params=inference_params
                    )
                elif model == "amazon.titan-embed-text-v2:0":
                    transformed_request = AmazonTitanV2Config()._transform_request(
                        input=i, inference_params=inference_params
                    )
                else:
                    raise Exception(
                        "Unmapped model. Received={}. Expected={}".format(
                            model,
                            [
                                "amazon.titan-embed-image-v1",
                                "amazon.titan-embed-text-v1",
                                "amazon.titan-embed-text-v2:0",
                            ],
                        )
                    )
                batch_data.append(transformed_request)
        elif provider == "twelvelabs":
            batch_data = []
            for i in input:
                twelvelabs_request = (
                    TwelveLabsMarengoEmbeddingConfig()._transform_request(
                        input=i,
                        inference_params=inference_params,
                        async_invoke_route=has_async_invoke,
                        model_id=modelId,
                        output_s3_uri=inference_params.get("output_s3_uri"),
                    )
                )
                batch_data.append(twelvelabs_request)
        elif provider == "nova":
            batch_data = []
            for i in input:
                nova_request = AmazonNovaEmbeddingConfig()._transform_request(
                    input=i,
                    inference_params=inference_params,
                    async_invoke_route=has_async_invoke,
                    model_id=modelId,
                    output_s3_uri=inference_params.get("output_s3_uri"),
                )
                batch_data.append(nova_request)
        ### SET RUNTIME ENDPOINT ###
        endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint(
            api_base=api_base,
            aws_bedrock_runtime_endpoint=optional_params.pop(
                "aws_bedrock_runtime_endpoint", None
            ),
            aws_region_name=aws_region_name,
        )
        # Async-invoke jobs and regular invokes use different URL shapes.
        if has_async_invoke:
            endpoint_url = f"{endpoint_url}/async-invoke"
        else:
            endpoint_url = f"{endpoint_url}/model/{modelId}/invoke"
        if batch_data is not None:
            if aembedding:
                # Returns a coroutine; the caller is expected to await it.
                return self._async_single_func_embeddings(  # type: ignore
                    client=(
                        client
                        if client is not None and isinstance(client, AsyncHTTPHandler)
                        else None
                    ),
                    timeout=timeout,
                    batch_data=batch_data,
                    credentials=credentials,
                    extra_headers=extra_headers,
                    endpoint_url=endpoint_url,
                    aws_region_name=aws_region_name,
                    model=model,
                    logging_obj=logging_obj,
                    api_key=api_key,
                    provider=provider,
                    is_async_invoke=has_async_invoke,
                )
            returned_response = self._single_func_embeddings(
                client=(
                    client
                    if client is not None and isinstance(client, HTTPHandler)
                    else None
                ),
                timeout=timeout,
                batch_data=batch_data,
                credentials=credentials,
                extra_headers=extra_headers,
                endpoint_url=endpoint_url,
                aws_region_name=aws_region_name,
                model=model,
                logging_obj=logging_obj,
                api_key=api_key,
                provider=provider,
                is_async_invoke=has_async_invoke,
            )
            if returned_response is None:
                raise Exception("Unable to map Bedrock request to provider")
            return returned_response
        elif data is None:
            raise Exception("Unable to map Bedrock request to provider")
        # Cohere path: single signed request.
        headers = {"Content-Type": "application/json"}
        if extra_headers is not None:
            headers = {"Content-Type": "application/json", **extra_headers}
        prepped = self.get_request_headers( # type: ignore
            credentials=credentials,
            aws_region_name=aws_region_name,
            extra_headers=extra_headers,
            endpoint_url=endpoint_url,
            data=json.dumps(data),
            headers=headers,
            api_key=api_key,
        )
        ## ROUTING ##
        # Convert CaseInsensitiveDict to regular dict for httpx compatibility
        headers_for_request = (
            dict(prepped.headers) if hasattr(prepped, "headers") else {}
        )
        return cohere_embedding(
            model=model,
            input=input,
            model_response=model_response,
            logging_obj=logging_obj,
            optional_params=optional_params,
            encoding=encoding,
            data=data, # type: ignore
            complete_api_base=prepped.url,
            api_key=None,
            aembedding=aembedding,
            timeout=timeout,
            client=client,
            headers=headers_for_request,
        )
async def _get_async_invoke_status(
self, invocation_arn: str, aws_region_name: str, logging_obj=None, **kwargs
) -> dict:
"""
Get the status of an async invoke job using the GetAsyncInvoke operation.
Args:
invocation_arn: The invocation ARN from the async invoke response
aws_region_name: AWS region name
**kwargs: Additional parameters (credentials, etc.)
Returns:
dict: Status response from AWS Bedrock
"""
# Get AWS credentials using the same method as other Bedrock methods
credentials, _ = self._load_credentials(kwargs)
# Get the runtime endpoint
endpoint_url, _ = self.get_runtime_endpoint(
api_base=None,
aws_bedrock_runtime_endpoint=kwargs.get("aws_bedrock_runtime_endpoint"),
aws_region_name=aws_region_name,
)
from urllib.parse import quote
# Encode the ARN for use in URL path
encoded_arn = quote(invocation_arn, safe="")
status_url = f"{endpoint_url.rstrip('/')}/async-invoke/{encoded_arn}"
# Prepare headers for GET request
headers = {"Content-Type": "application/json"}
# Use AWSRequest directly for GET requests (get_request_headers hardcodes POST)
try:
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
except ImportError:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
# Create AWSRequest with GET method and encoded URL
request = AWSRequest(
method="GET",
url=status_url,
data=None, # GET request, no body
headers=headers,
)
# Sign the request - SigV4Auth will create canonical string from request URL
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
sigv4.add_auth(request)
# Prepare the request
prepped = request.prepare()
# LOGGING
if logging_obj is not None:
# Create custom curl command for GET request
masked_headers = logging_obj._get_masked_headers(prepped.headers)
formatted_headers = " ".join(
[f"-H '{k}: {v}'" for k, v in masked_headers.items()]
)
custom_curl = "\n\nGET Request Sent from LiteLLM:\n"
custom_curl += "curl -X GET \\\n"
custom_curl += f"{prepped.url} \\\n"
custom_curl += f"{formatted_headers}\n"
logging_obj.pre_call(
input=invocation_arn,
api_key="",
additional_args={
"complete_input_dict": {"invocation_arn": invocation_arn},
"api_base": prepped.url,
"headers": prepped.headers,
"request_str": custom_curl, # Override with custom GET curl command
},
)
# Make the GET request
client = get_async_httpx_client(llm_provider=LlmProviders.BEDROCK)
response = await client.get(
url=prepped.url,
headers=prepped.headers,
)
# LOGGING
if logging_obj is not None:
logging_obj.post_call(
input=invocation_arn,
api_key="",
original_response=response,
additional_args={
"complete_input_dict": {"invocation_arn": invocation_arn}
},
)
# Parse response
if response.status_code == 200:
return response.json()
else:
raise Exception(
f"Failed to get async invoke status: {response.status_code} - {response.text}"
)

View File

@@ -0,0 +1,304 @@
"""
Transformation logic from OpenAI /v1/embeddings format to Bedrock TwelveLabs Marengo /invoke and /async-invoke format.
Why a separate file? To make it easy to see how the transformation works.
Docs - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-marengo.html
"""
from typing import List, Optional, Union, cast
from litellm.types.llms.bedrock import (
TWELVELABS_EMBEDDING_INPUT_TYPES,
TwelveLabsAsyncInvokeRequest,
TwelveLabsMarengoEmbeddingRequest,
TwelveLabsOutputDataConfig,
TwelveLabsS3Location,
TwelveLabsS3OutputDataConfig,
)
from litellm.types.utils import Embedding, EmbeddingResponse, Usage
class TwelveLabsMarengoEmbeddingConfig:
    """
    Reference - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-marengo.html
    Supports text, image, video, and audio inputs.
    - InvokeModel: text and image inputs
    - StartAsyncInvoke: video, audio, image, and text inputs
    """
    def __init__(self) -> None:
        pass
    def get_supported_openai_params(self) -> List[str]:
        """OpenAI-style parameters this config knows how to map to Marengo fields."""
        return [
            "encoding_format",
            "textTruncate",
            "embeddingOption",
            "startSec",
            "lengthSec",
            "useFixedLengthSec",
            "minClipSec",
            "input_type",
        ]
    def map_openai_params(
        self, non_default_params: dict, optional_params: dict
    ) -> dict:
        """Map supported OpenAI params onto Marengo fields in ``optional_params`` (mutated in place and returned)."""
        for k, v in non_default_params.items():
            if k == "encoding_format":
                # TwelveLabs doesn't have encoding_format, but we can map it to embeddingOption
                if v == "float":
                    optional_params["embeddingOption"] = ["visual-text", "visual-image"]
            elif k == "textTruncate":
                optional_params["textTruncate"] = v
            elif k == "embeddingOption":
                optional_params["embeddingOption"] = v
            elif k == "input_type":
                # Map input_type to inputType for Bedrock
                optional_params["inputType"] = v
            elif k in ["startSec", "lengthSec", "useFixedLengthSec", "minClipSec"]:
                optional_params[k] = v
        return optional_params
    def _extract_bucket_owner_from_params(self, inference_params: dict) -> str:
        """
        Extract bucket owner from inference parameters.

        Returns an empty string when not set, which callers treat as "no owner".
        """
        return inference_params.get("bucketOwner", "")
    def _is_s3_url(self, input: str) -> bool:
        """Check if input is an S3 URL."""
        return input.startswith("s3://")
    def _transform_request(
        self,
        input: str,
        inference_params: dict,
        async_invoke_route: bool = False,
        model_id: Optional[str] = None,
        output_s3_uri: Optional[str] = None,
    ) -> Union[TwelveLabsMarengoEmbeddingRequest, TwelveLabsAsyncInvokeRequest]:
        """
        Transform OpenAI-style input to TwelveLabs Marengo format/async-invoke format.
        Supports:
        - Text inputs (for both invoke and async-invoke)
        - Image inputs (for both invoke and async-invoke)
        - Video inputs (async-invoke only)
        - Audio inputs (async-invoke only)
        - S3 URLs for all media types (async-invoke only)
        """
        # Get input_type or default to "text"
        input_type = cast(
            TWELVELABS_EMBEDDING_INPUT_TYPES,
            inference_params.get("inputType")
            or inference_params.get("input_type")
            or "text",
        )
        # Validate that async-invoke is used for video/audio
        if input_type in ["video", "audio"] and not async_invoke_route:
            raise ValueError(
                f"Input type '{input_type}' requires async_invoke route. "
                f"Use model format: 'bedrock/async_invoke/model_id'"
            )
        transformed_request: TwelveLabsMarengoEmbeddingRequest = {
            "inputType": input_type
        }
        if input_type == "text":
            transformed_request["inputText"] = input
            # Set default textTruncate if not specified
            if "textTruncate" not in inference_params:
                transformed_request["textTruncate"] = "end"
        elif input_type in ["image", "video", "audio"]:
            if self._is_s3_url(input):
                # S3 URL input
                s3_location: TwelveLabsS3Location = {"uri": input}
                bucket_owner = self._extract_bucket_owner_from_params(inference_params)
                if bucket_owner:
                    s3_location["bucketOwner"] = bucket_owner
                transformed_request["mediaSource"] = {"s3Location": s3_location}
            else:
                # Base64 encoded input
                if input.startswith("data:"):
                    # Extract base64 data from data URL
                    b64_str = input.split(",", 1)[1] if "," in input else input
                else:
                    # Direct base64 string
                    from litellm.utils import get_base64_str
                    b64_str = get_base64_str(input)
                transformed_request["mediaSource"] = {"base64String": b64_str}
        # Apply any additional inference parameters
        for k, v in inference_params.items():
            if k not in [
                "inputType",
                "input_type", # Exclude both camelCase and snake_case
                "inputText",
                "mediaSource",
                "bucketOwner", # Don't include bucketOwner in the request
            ]: # Don't override core fields
                transformed_request[k] = v # type: ignore
        # If async invoke route, wrap in the async invoke format
        if async_invoke_route and model_id:
            return self._wrap_async_invoke_request(
                model_input=transformed_request,
                model_id=model_id,
                output_s3_uri=output_s3_uri,
            )
        return transformed_request
    def _wrap_async_invoke_request(
        self,
        model_input: TwelveLabsMarengoEmbeddingRequest,
        model_id: str,
        output_s3_uri: Optional[str] = None,
    ) -> TwelveLabsAsyncInvokeRequest:
        """
        Wrap the transformed request in the correct AWS Bedrock async invoke format.
        Args:
            model_input: The transformed TwelveLabs Marengo embedding request
            model_id: The model identifier (without async_invoke prefix)
            output_s3_uri: Optional S3 URI for output data config
        Returns:
            TwelveLabsAsyncInvokeRequest: The wrapped async invoke request
        Raises:
            ValueError: If output_s3_uri is missing or empty.
        """
        import urllib.parse
        # Clean the model ID (it arrives URL-encoded and may still carry the route prefix)
        unquoted_model_id = urllib.parse.unquote(model_id)
        if unquoted_model_id.startswith("async_invoke/"):
            unquoted_model_id = unquoted_model_id.replace("async_invoke/", "")
        # Validate that the S3 URI is not empty
        if not output_s3_uri or output_s3_uri.strip() == "":
            raise ValueError("output_s3_uri cannot be empty for async invoke requests")
        return TwelveLabsAsyncInvokeRequest(
            modelId=unquoted_model_id,
            modelInput=model_input,
            outputDataConfig=TwelveLabsOutputDataConfig(
                s3OutputDataConfig=TwelveLabsS3OutputDataConfig(s3Uri=output_s3_uri)
            ),
        )
    def _transform_response(
        self, response_list: List[dict], model: str
    ) -> EmbeddingResponse:
        """
        Transform TwelveLabs response to OpenAI format.
        Handles the actual TwelveLabs response format: {"data": [{"embedding": [...]}]}
        """
        embeddings: List[Embedding] = []
        # NOTE(review): token counts below are rough estimates (len // 4) —
        # confirm before relying on them for accounting.
        total_tokens = 0
        for response in response_list:
            # TwelveLabs response format has a "data" field containing the embeddings
            if "data" in response and isinstance(response["data"], list):
                for item in response["data"]:
                    if "embedding" in item:
                        # Single embedding response
                        embedding = Embedding(
                            embedding=item["embedding"],
                            index=len(embeddings),
                            object="embedding",
                        )
                        embeddings.append(embedding)
                        # Estimate token count (rough approximation)
                        if "inputTextTokenCount" in item:
                            total_tokens += item["inputTextTokenCount"]
                        else:
                            # Rough estimate: 1 token per 4 characters for text, or use embedding size
                            total_tokens += len(item["embedding"]) // 4
            elif "embedding" in response:
                # Direct embedding response (fallback for other formats)
                embedding = Embedding(
                    embedding=response["embedding"],
                    index=len(embeddings),
                    object="embedding",
                )
                embeddings.append(embedding)
                # Estimate token count (rough approximation)
                if "inputTextTokenCount" in response:
                    total_tokens += response["inputTextTokenCount"]
                else:
                    # Rough estimate: 1 token per 4 characters for text
                    total_tokens += len(response.get("inputText", "")) // 4
            elif "embeddings" in response:
                # Multiple embeddings response (from video/audio)
                for i, emb in enumerate(response["embeddings"]):
                    embedding = Embedding(
                        embedding=emb["embedding"],
                        index=len(embeddings),
                        object="embedding",
                    )
                    embeddings.append(embedding)
                    total_tokens += len(emb["embedding"]) // 4 # Rough estimate
        usage = Usage(prompt_tokens=total_tokens, total_tokens=total_tokens)
        return EmbeddingResponse(data=embeddings, model=model, usage=usage)
    def _transform_async_invoke_response(
        self, response: dict, model: str
    ) -> EmbeddingResponse:
        """
        Transform async invoke response (invocation ARN) to OpenAI format.
        AWS async invoke returns:
        {
            "invocationArn": "arn:aws:bedrock:us-east-1:123456789012:async-invoke/abc123"
        }
        We transform this to a job-like embedding response:
        {
            "object": "list",
            "data": [
                {
                    "object": "embedding_job_id:1234567890",
                    "embedding": [],
                    "index": 0
                }
            ],
            "model": "model",
            "usage": {}
        }
        """
        invocation_arn = response.get("invocationArn", "")
        # Create a placeholder embedding object for the job
        embedding = Embedding(
            embedding=[], # Empty embedding for async jobs
            index=0,
            object="embedding",
        )
        # Create usage object (empty for async jobs)
        usage = Usage(prompt_tokens=0, total_tokens=0)
        # Create hidden params with job ID
        from litellm.types.llms.base import HiddenParams
        hidden_params = HiddenParams()
        setattr(hidden_params, "_invocation_arn", invocation_arn)
        return EmbeddingResponse(
            data=[embedding],
            model=model,
            usage=usage,
            hidden_params=hidden_params,
        )

View File

@@ -0,0 +1,210 @@
import asyncio
import base64
from typing import Any, Coroutine, Optional, Tuple, Union
import httpx
from litellm import LlmProviders
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
from litellm.types.llms.openai import (
FileContentRequest,
HttpxBinaryResponseContent,
)
from litellm.types.utils import SpecialEnums
from ..base_aws_llm import BaseAWSLLM
class BedrockFilesHandler(BaseAWSLLM):
    """
    Handles downloading files from S3 for Bedrock batch processing.
    This implementation downloads files from S3 buckets where Bedrock
    stores batch output files.
    """
    def __init__(self):
        super().__init__()
        # Shared async client; the current download path below uses boto3 instead.
        self.async_httpx_client = get_async_httpx_client(
            llm_provider=LlmProviders.BEDROCK,
        )
    def _extract_s3_uri_from_file_id(self, file_id: str) -> str:
        """
        Extract S3 URI from encoded file ID.
        The file ID can be in two formats:
        1. Base64-encoded unified file ID containing: llm_output_file_id,s3://bucket/path
        2. Direct S3 URI: s3://bucket/path
        Args:
            file_id: Encoded file ID or direct S3 URI
        Returns:
            S3 URI (e.g., "s3://bucket-name/path/to/file")
        """
        # First, try to decode if it's a base64-encoded unified file ID
        try:
            # Add padding if needed
            padded = file_id + "=" * (-len(file_id) % 4)
            decoded = base64.urlsafe_b64decode(padded).decode()
            # Check if it's a unified file ID format
            if decoded.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value):
                # Extract llm_output_file_id from the decoded string
                if "llm_output_file_id," in decoded:
                    s3_uri = decoded.split("llm_output_file_id,")[1].split(";")[0]
                    return s3_uri
        except Exception:
            # Not base64 / not a unified ID — fall through to raw-URI handling.
            pass
        # If not base64 encoded or doesn't contain llm_output_file_id, assume it's already an S3 URI
        if file_id.startswith("s3://"):
            return file_id
        # If it doesn't start with s3://, assume it's a direct S3 URI and add the prefix
        return f"s3://{file_id}"
    def _parse_s3_uri(self, s3_uri: str) -> Tuple[str, str]:
        """
        Parse S3 URI to extract bucket name and object key.
        Args:
            s3_uri: S3 URI (e.g., "s3://bucket-name/path/to/file")
        Returns:
            Tuple of (bucket_name, object_key)
        Raises:
            ValueError: If the URI does not start with "s3://".
        """
        if not s3_uri.startswith("s3://"):
            raise ValueError(
                f"Invalid S3 URI format: {s3_uri}. Expected format: s3://bucket-name/path/to/file"
            )
        # Remove 's3://' prefix
        path = s3_uri[5:]
        if "/" in path:
            bucket_name, object_key = path.split("/", 1)
        else:
            # Bucket-only URI: empty object key.
            bucket_name = path
            object_key = ""
        return bucket_name, object_key
    async def afile_content(
        self,
        file_content_request: FileContentRequest,
        optional_params: dict,
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
    ) -> HttpxBinaryResponseContent:
        """
        Download file content from S3 bucket for Bedrock files.
        Args:
            file_content_request: Contains file_id (encoded or S3 URI)
            optional_params: Optional parameters containing AWS credentials
            timeout: Request timeout
            max_retries: Max retry attempts
        Returns:
            HttpxBinaryResponseContent: Binary content wrapped in compatible response format
        Raises:
            ValueError: If file_id is missing or the S3 download fails.
        """
        import boto3
        from botocore.credentials import Credentials
        file_id = file_content_request.get("file_id")
        if not file_id:
            raise ValueError("file_id is required in file_content_request")
        # Extract S3 URI from file ID
        s3_uri = self._extract_s3_uri_from_file_id(file_id)
        bucket_name, object_key = self._parse_s3_uri(s3_uri)
        # Get AWS credentials
        aws_region_name = self._get_aws_region_name(
            optional_params=optional_params, model=""
        )
        credentials: Credentials = self.get_credentials(
            aws_access_key_id=optional_params.get("aws_access_key_id"),
            aws_secret_access_key=optional_params.get("aws_secret_access_key"),
            aws_session_token=optional_params.get("aws_session_token"),
            aws_region_name=aws_region_name,
            aws_session_name=optional_params.get("aws_session_name"),
            aws_profile_name=optional_params.get("aws_profile_name"),
            aws_role_name=optional_params.get("aws_role_name"),
            aws_web_identity_token=optional_params.get("aws_web_identity_token"),
            aws_sts_endpoint=optional_params.get("aws_sts_endpoint"),
        )
        # Create S3 client
        s3_client = boto3.client(
            "s3",
            aws_access_key_id=credentials.access_key,
            aws_secret_access_key=credentials.secret_key,
            aws_session_token=credentials.token,
            region_name=aws_region_name,
            verify=self._get_ssl_verify(),
        )
        # Download file from S3
        # NOTE(review): boto3 calls are synchronous, so this blocks the event
        # loop for the duration of the download — confirm acceptable for large files.
        try:
            response = s3_client.get_object(Bucket=bucket_name, Key=object_key)
            file_content = response["Body"].read()
        except Exception as e:
            raise ValueError(
                f"Failed to download file from S3: {s3_uri}. Error: {str(e)}"
            )
        # Create mock HTTP response
        mock_response = httpx.Response(
            status_code=200,
            content=file_content,
            headers={"content-type": "application/octet-stream"},
            request=httpx.Request(method="GET", url=s3_uri),
        )
        return HttpxBinaryResponseContent(response=mock_response)
    def file_content(
        self,
        _is_async: bool,
        file_content_request: FileContentRequest,
        api_base: Optional[str],
        optional_params: dict,
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
    ) -> Union[
        HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent]
    ]:
        """
        Download file content from S3 bucket for Bedrock files.
        Supports both sync and async operations.
        Args:
            _is_async: Whether to run asynchronously
            file_content_request: Contains file_id (encoded or S3 URI)
            api_base: API base (unused for S3 operations)
            optional_params: Optional parameters containing AWS credentials
            timeout: Request timeout
            max_retries: Max retry attempts
        Returns:
            HttpxBinaryResponseContent or Coroutine: Binary content wrapped in compatible response format
        """
        if _is_async:
            return self.afile_content(
                file_content_request=file_content_request,
                optional_params=optional_params,
                timeout=timeout,
                max_retries=max_retries,
            )
        else:
            # NOTE(review): asyncio.run raises RuntimeError when an event loop
            # is already running — sync callers must not be inside async contexts.
            return asyncio.run(
                self.afile_content(
                    file_content_request=file_content_request,
                    optional_params=optional_params,
                    timeout=timeout,
                    max_retries=max_retries,
                )
            )

View File

@@ -0,0 +1,772 @@
import json
import os
import time
from typing import Any, Dict, List, Optional, Tuple, Union
import httpx
from httpx import Headers, Response
from openai.types.file_deleted import FileDeleted
from litellm._logging import verbose_logger
from litellm._uuid import uuid
from litellm.files.utils import FilesAPIUtils
from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.base_llm.files.transformation import (
BaseFilesConfig,
LiteLLMLoggingObj,
)
from litellm.types.llms.openai import (
AllMessageValues,
CreateFileRequest,
FileTypes,
HttpxBinaryResponseContent,
OpenAICreateFileRequestOptionalParams,
OpenAIFileObject,
PathLike,
)
from litellm.types.utils import ExtractedFileData, LlmProviders
from litellm.utils import get_llm_provider
from ..base_aws_llm import BaseAWSLLM
from ..common_utils import BedrockError
class BedrockFilesConfig(BaseAWSLLM, BaseFilesConfig):
"""
Config for Bedrock Files - handles S3 uploads for Bedrock batch processing
"""
    def __init__(self):
        # Helper that converts OpenAI-style batch JSONL into Bedrock's format.
        self.jsonl_transformation = BedrockJsonlFilesTransformation()
        super().__init__()
    @property
    def custom_llm_provider(self) -> LlmProviders:
        """Provider identifier used by LiteLLM routing for this config."""
        return LlmProviders.BEDROCK
    @property
    def file_upload_http_method(self) -> str:
        """
        Bedrock files are uploaded to S3, which requires PUT requests
        (rather than the POST used by most provider upload endpoints).
        """
        return "PUT"
    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """Return ``headers`` unchanged; S3 uploads need no extra headers here."""
        # No additional headers needed for S3 uploads - AWS credentials handled by BaseAWSLLM
        return headers
def _get_content_from_openai_file(self, openai_file_content: FileTypes) -> str:
"""
Helper to extract content from various OpenAI file types and return as string.
Handles:
- Direct content (str, bytes, IO[bytes])
- Tuple formats: (filename, content, [content_type], [headers])
- PathLike objects
"""
content: Union[str, bytes] = b""
# Extract file content from tuple if necessary
if isinstance(openai_file_content, tuple):
# Take the second element which is always the file content
file_content = openai_file_content[1]
else:
file_content = openai_file_content
# Handle different file content types
if isinstance(file_content, str):
# String content can be used directly
content = file_content
elif isinstance(file_content, bytes):
# Bytes content can be decoded
content = file_content
elif isinstance(file_content, PathLike): # PathLike
with open(str(file_content), "rb") as f:
content = f.read()
elif hasattr(file_content, "read"): # IO[bytes]
# File-like objects need to be read
content = file_content.read()
# Ensure content is string
if isinstance(content, bytes):
content = content.decode("utf-8")
return content
def _get_s3_object_name_from_batch_jsonl(
self,
openai_jsonl_content: List[Dict[str, Any]],
) -> str:
"""
Gets a unique S3 object name for the Bedrock batch processing job
named as: litellm-bedrock-files/{model}/{uuid}
"""
_model = openai_jsonl_content[0].get("body", {}).get("model", "")
# Remove bedrock/ prefix if present
if _model.startswith("bedrock/"):
_model = _model[8:]
# Replace colons with hyphens for Bedrock S3 URI compliance
_model = _model.replace(":", "-")
object_name = f"litellm-bedrock-files-{_model}-{uuid.uuid4()}.jsonl"
return object_name
def get_object_name(
    self, extracted_file_data: ExtractedFileData, purpose: str
) -> str:
    """
    Derive the S3 object name for an upload.

    Resolution order:
      1. ``batch`` purpose with parseable JSONL -> model-based unique name
      2. the original filename, when present
      3. the current unix timestamp as a last resort

    Raises:
        ValueError: when the extracted file data has no content.
    """
    file_body = extracted_file_data.get("content")
    if file_body is None:
        raise ValueError("file content is required")
    if purpose == "batch":
        # Try to derive a model-based name from the JSONL payload.
        decoded = self._get_content_from_openai_file(file_body)
        records = [
            json.loads(row) for row in decoded.splitlines() if row.strip()
        ]
        if records:
            return self._get_s3_object_name_from_batch_jsonl(records)
    original_name = extracted_file_data.get("filename")
    if original_name:
        return original_name
    # No usable name anywhere — fall back to a timestamp.
    return str(int(time.time()))
def get_complete_file_url(
    self,
    api_base: Optional[str],
    api_key: Optional[str],
    model: str,
    optional_params: Dict,
    litellm_params: Dict,
    data: CreateFileRequest,
) -> str:
    """
    Build the full S3 PUT URL for a file-upload request.

    The bucket comes from ``litellm_params["s3_bucket_name"]`` or the
    ``AWS_S3_BUCKET_NAME`` env var; the endpoint defaults to the regional
    S3 endpoint unless ``s3_endpoint_url`` is supplied in optional_params.

    Raises:
        ValueError: when bucket, file, or purpose is missing.
    """
    bucket_name = litellm_params.get("s3_bucket_name") or os.getenv(
        "AWS_S3_BUCKET_NAME"
    )
    if not bucket_name:
        raise ValueError(
            "S3 bucket_name is required. Set 's3_bucket_name' in litellm_params or AWS_S3_BUCKET_NAME env var"
        )
    region = self._get_aws_region_name(optional_params, model)
    uploaded_file = data.get("file")
    if uploaded_file is None:
        raise ValueError("file is required")
    purpose = data.get("purpose")
    if purpose is None:
        raise ValueError("purpose is required")
    object_name = self.get_object_name(
        extract_file_data(uploaded_file), purpose
    )
    # Explicit endpoint override wins; otherwise use the regional default.
    endpoint = (
        optional_params.get("s3_endpoint_url")
        or f"https://s3.{region}.amazonaws.com"
    )
    return f"{endpoint}/{bucket_name}/{object_name}"
def get_supported_openai_params(
    self, model: str
) -> List[OpenAICreateFileRequestOptionalParams]:
    """No OpenAI optional params are mapped for Bedrock file uploads."""
    return []
def map_openai_params(
    self,
    non_default_params: dict,
    optional_params: dict,
    model: str,
    drop_params: bool,
) -> dict:
    """Identity mapping: Bedrock file uploads take no OpenAI optional params."""
    return optional_params
# Providers whose InvokeModel body uses the Converse API format
# (messages + inferenceConfig + image blocks). Nova is the primary
# example; add others here as they adopt the same schema.
# NOTE: values are compared against get_bedrock_invoke_provider() output in
# _map_openai_to_bedrock_params.
CONVERSE_INVOKE_PROVIDERS = ("nova",)
def _map_openai_to_bedrock_params(
    self,
    openai_request_body: Dict[str, Any],
    provider: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Convert an OpenAI-style request body into the Bedrock ``modelInput``
    payload expected by InvokeModel batch inference.

    Dispatch per provider:
      * anthropic -> AmazonAnthropicClaudeConfig (native Claude schema)
      * Converse-format providers (e.g. nova) -> AmazonConverseConfig
        (image blocks converted, inference params wrapped in inferenceConfig)
      * everything else -> passthrough for OpenAI-compatible models
    """
    from litellm.types.utils import LlmProviders

    target_model = openai_request_body.get("model", "")
    chat_messages = openai_request_body.get("messages", [])
    # Everything besides model/messages is treated as an optional param.
    extra_params = {
        key: value
        for key, value in openai_request_body.items()
        if key not in ("model", "messages")
    }
    if provider == LlmProviders.ANTHROPIC:
        from litellm.llms.bedrock.chat.invoke_transformations.anthropic_claude3_transformation import (
            AmazonAnthropicClaudeConfig,
        )

        claude_config = AmazonAnthropicClaudeConfig()
        # The Claude config expects the params via ``optional_params``.
        claude_params = claude_config.map_openai_params(
            non_default_params={},
            optional_params=extra_params,
            model=target_model,
            drop_params=False,
        )
        return claude_config.transform_request(
            model=target_model,
            messages=chat_messages,
            optional_params=claude_params,
            litellm_params={},
            headers={},
        )
    if provider in self.CONVERSE_INVOKE_PROVIDERS:
        from litellm.llms.bedrock.chat.converse_transformation import (
            AmazonConverseConfig,
        )

        converse_config = AmazonConverseConfig()
        # Converse expects the params via ``non_default_params`` so image_url
        # blocks and inferenceConfig wrapping are applied correctly.
        converse_params = converse_config.map_openai_params(
            non_default_params=extra_params,
            optional_params={},
            model=target_model,
            drop_params=False,
        )
        return converse_config.transform_request(
            model=target_model,
            messages=chat_messages,
            optional_params=converse_params,
            litellm_params={},
            headers={},
        )
    # OpenAI-compatible models (openai.gpt-oss-*, qwen, deepseek, ...) take
    # the body as-is.
    return {
        "messages": chat_messages,
        **extra_params,
    }
def _transform_openai_jsonl_content_to_bedrock_jsonl_content(
    self, openai_jsonl_content: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """
    Convert OpenAI batch JSONL records into Bedrock batch records.

    Bedrock batch format per line:
    ``{"recordId": "<alphanumeric id>", "modelInput": {<InvokeModel body>}}``

    ``recordId`` is the OpenAI ``custom_id`` when present, otherwise a
    zero-padded ``CALL<index>`` identifier.
    """
    transformed: List[Dict[str, Any]] = []
    for index, record in enumerate(openai_jsonl_content):
        request_body = record.get("body", {})
        model_name = request_body.get("model", "")
        try:
            # Normalize e.g. "bedrock/<model>" down to the bare model id;
            # best-effort only — failures are logged and the raw name kept.
            model_name, _, _, _ = get_llm_provider(
                model=model_name,
                custom_llm_provider=None,
            )
        except Exception as e:
            verbose_logger.exception(
                f"litellm.llms.bedrock.files.transformation.py::_transform_openai_jsonl_content_to_bedrock_jsonl_content() - Error inferring custom_llm_provider - {str(e)}"
            )
        # Pick the per-provider transformation for the InvokeModel body.
        invoke_provider = self.get_bedrock_invoke_provider(model_name)
        model_input = self._map_openai_to_bedrock_params(
            openai_request_body=request_body, provider=invoke_provider
        )
        record_id = record.get("custom_id", f"CALL{str(index).zfill(7)}")
        transformed.append({"recordId": record_id, "modelInput": model_input})
    return transformed
def transform_create_file_request(
    self,
    model: str,
    create_file_data: CreateFileRequest,
    optional_params: dict,
    litellm_params: dict,
) -> Union[bytes, str, dict]:
    """
    Transform file request and return a pre-signed request for S3.
    This keeps the HTTP handler clean by doing all the signing here.

    Returns a dict with ``method``/``url``/``headers``/``data`` describing
    the exact S3 PUT the HTTP handler should execute. Also stashes the
    upload URL in ``litellm_params["upload_url"]`` so that
    ``transform_create_file_response`` can recover the ``s3://`` file id.

    Raises:
        ValueError: when the file/content is missing or of an unsupported type.
    """
    file_data = create_file_data.get("file")
    if file_data is None:
        raise ValueError("file is required")
    extracted_file_data = extract_file_data(file_data)
    extracted_file_data_content = extracted_file_data.get("content")
    if extracted_file_data_content is None:
        raise ValueError("file content is required")
    # Get and transform the file content
    if FilesAPIUtils.is_batch_jsonl_file(
        create_file_data=create_file_data,
        extracted_file_data=extracted_file_data,
    ):
        ## Transform JSONL content to Bedrock format
        original_file_content = self._get_content_from_openai_file(
            extracted_file_data_content
        )
        # One JSON object per non-empty line (OpenAI batch JSONL).
        openai_jsonl_content = [
            json.loads(line)
            for line in original_file_content.splitlines()
            if line.strip()
        ]
        bedrock_jsonl_content = (
            self._transform_openai_jsonl_content_to_bedrock_jsonl_content(
                openai_jsonl_content
            )
        )
        file_content = "\n".join(json.dumps(item) for item in bedrock_jsonl_content)
    elif isinstance(extracted_file_data_content, bytes):
        file_content = extracted_file_data_content.decode("utf-8")
    elif isinstance(extracted_file_data_content, str):
        file_content = extracted_file_data_content
    else:
        raise ValueError("Unsupported file content type")
    # Get the S3 URL for upload
    api_base = self.get_complete_file_url(
        api_base=None,
        api_key=None,
        model=model,
        optional_params=optional_params,
        litellm_params=litellm_params,
        data=create_file_data,
    )
    # Sign the request and return a pre-signed request object
    signed_headers, signed_body = self._sign_s3_request(
        content=file_content,
        api_base=api_base,
        optional_params=optional_params,
    )
    # Stash the URL for the response transformation (file id derivation).
    litellm_params["upload_url"] = api_base
    # Return a dict that tells the HTTP handler exactly what to do
    return {
        "method": "PUT",
        "url": api_base,
        "headers": signed_headers,
        "data": signed_body or file_content,
    }
def _sign_s3_request(
    self,
    content: str,
    api_base: str,
    optional_params: dict,
) -> Tuple[dict, str]:
    """
    Sign S3 PUT request using the same proven logic as S3Logger.
    Reuses the exact pattern from litellm/integrations/s3_v2.py

    Args:
        content: UTF-8 payload to upload (JSONL for batch files).
        api_base: Full S3 object URL the PUT will target.
        optional_params: May carry explicit AWS credentials / role settings.

    Returns:
        Tuple of (SigV4-signed headers, request body as str).

    Raises:
        ImportError: when botocore/requests are unavailable.
    """
    try:
        import hashlib

        import requests
        from botocore.auth import SigV4Auth
        from botocore.awsrequest import AWSRequest
    except ImportError:
        raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
    # Get AWS credentials using existing methods
    aws_region_name = self._get_aws_region_name(
        optional_params=optional_params, model=""
    )
    credentials = self.get_credentials(
        aws_access_key_id=optional_params.get("aws_access_key_id"),
        aws_secret_access_key=optional_params.get("aws_secret_access_key"),
        aws_session_token=optional_params.get("aws_session_token"),
        aws_region_name=aws_region_name,
        aws_session_name=optional_params.get("aws_session_name"),
        aws_profile_name=optional_params.get("aws_profile_name"),
        aws_role_name=optional_params.get("aws_role_name"),
        aws_web_identity_token=optional_params.get("aws_web_identity_token"),
        aws_sts_endpoint=optional_params.get("aws_sts_endpoint"),
    )
    # Calculate SHA256 hash of the content (REQUIRED for S3)
    content_hash = hashlib.sha256(content.encode("utf-8")).hexdigest()
    # Prepare headers with required S3 headers (same as s3_v2.py)
    request_headers = {
        "Content-Type": "application/json",  # JSONL files are JSON content
        "x-amz-content-sha256": content_hash,  # REQUIRED by S3
        "Content-Language": "en",
        "Cache-Control": "private, immutable, max-age=31536000, s-maxage=0",
    }
    # Use requests.Request to prepare the request (same pattern as s3_v2.py)
    req = requests.Request("PUT", api_base, data=content, headers=request_headers)
    prepped = req.prepare()
    # Sign the request with S3 service
    aws_request = AWSRequest(
        method=prepped.method,
        url=prepped.url,
        data=prepped.body,
        headers=prepped.headers,
    )
    # Get region name for non-LLM API calls (same as s3_v2.py)
    signing_region = self.get_aws_region_name_for_non_llm_api_calls(
        aws_region_name=aws_region_name
    )
    SigV4Auth(credentials, "s3", signing_region).add_auth(aws_request)
    # Return signed headers and body
    signed_body = aws_request.body
    if isinstance(signed_body, bytes):
        signed_body = signed_body.decode("utf-8")
    elif signed_body is None:
        signed_body = content  # Fallback to original content
    return dict(aws_request.headers), signed_body
def _convert_https_url_to_s3_uri(self, https_url: str) -> tuple[str, str]:
"""
Convert HTTPS S3 URL to s3:// URI format.
Args:
https_url: HTTPS S3 URL (e.g., "https://s3.us-west-2.amazonaws.com/bucket/key")
Returns:
Tuple of (s3_uri, filename)
Example:
Input: "https://s3.us-west-2.amazonaws.com/litellm-proxy/file.jsonl"
Output: ("s3://litellm-proxy/file.jsonl", "file.jsonl")
"""
import re
# Match HTTPS S3 URL patterns
# Pattern 1: https://s3.region.amazonaws.com/bucket/key
# Pattern 2: https://bucket.s3.region.amazonaws.com/key
pattern1 = r"https://s3\.([^.]+)\.amazonaws\.com/([^/]+)/(.+)"
pattern2 = r"https://([^.]+)\.s3\.([^.]+)\.amazonaws\.com/(.+)"
match1 = re.match(pattern1, https_url)
match2 = re.match(pattern2, https_url)
if match1:
# Pattern: https://s3.region.amazonaws.com/bucket/key
region, bucket, key = match1.groups()
s3_uri = f"s3://{bucket}/{key}"
elif match2:
# Pattern: https://bucket.s3.region.amazonaws.com/key
bucket, region, key = match2.groups()
s3_uri = f"s3://{bucket}/{key}"
else:
# Fallback: try to extract bucket and key from URL path
from urllib.parse import urlparse
parsed = urlparse(https_url)
path_parts = parsed.path.lstrip("/").split("/", 1)
if len(path_parts) >= 2:
bucket, key = path_parts[0], path_parts[1]
s3_uri = f"s3://{bucket}/{key}"
else:
raise ValueError(f"Unable to parse S3 URL: {https_url}")
# Extract filename from key
filename = key.split("/")[-1] if "/" in key else key
return s3_uri, filename
def transform_create_file_response(
    self,
    model: Optional[str],
    raw_response: Response,
    logging_obj: LiteLLMLoggingObj,
    litellm_params: dict,
) -> OpenAIFileObject:
    """
    Convert an S3 PUT response into an OpenAI-style ``FileObject``.

    The file id is the ``s3://`` URI derived from the upload URL stored in
    ``litellm_params["upload_url"]`` by ``transform_create_file_request``;
    when no URL was stored, id and filename stay empty.
    """
    # S3 PUT responses carry metadata in headers (ETag, lengths, ...).
    size_header = raw_response.headers.get("Content-Length", "0")
    s3_uri = ""
    object_filename = ""
    source_url = litellm_params.get("upload_url")
    if source_url:
        # Convert the HTTPS upload URL into the canonical s3:// form.
        s3_uri, object_filename = self._convert_https_url_to_s3_uri(source_url)
    return OpenAIFileObject(
        purpose="batch",  # Default purpose for Bedrock files
        id=s3_uri,
        filename=object_filename,
        created_at=int(time.time()),  # Current timestamp
        status="uploaded",
        bytes=int(size_header) if size_header.isdigit() else 0,
        object="file",
    )
def get_error_class(
    self, error_message: str, status_code: int, headers: Union[Dict, Headers]
) -> BaseLLMException:
    """Wrap provider errors in ``BedrockError`` for uniform error handling."""
    error = BedrockError(
        status_code=status_code, message=error_message, headers=headers
    )
    return error
def transform_retrieve_file_request(
    self,
    file_id: str,
    optional_params: dict,
    litellm_params: dict,
) -> tuple[str, dict]:
    """Not supported: Bedrock files live in S3; retrieval is not implemented."""
    raise NotImplementedError("BedrockFilesConfig does not support file retrieval")

def transform_retrieve_file_response(
    self,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
    litellm_params: dict,
) -> OpenAIFileObject:
    """Not supported: see ``transform_retrieve_file_request``."""
    raise NotImplementedError("BedrockFilesConfig does not support file retrieval")

def transform_delete_file_request(
    self,
    file_id: str,
    optional_params: dict,
    litellm_params: dict,
) -> tuple[str, dict]:
    """Not supported: deleting S3-backed Bedrock files is not implemented."""
    raise NotImplementedError("BedrockFilesConfig does not support file deletion")

def transform_delete_file_response(
    self,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
    litellm_params: dict,
) -> FileDeleted:
    """Not supported: see ``transform_delete_file_request``."""
    raise NotImplementedError("BedrockFilesConfig does not support file deletion")

def transform_list_files_request(
    self,
    purpose: Optional[str],
    optional_params: dict,
    litellm_params: dict,
) -> tuple[str, dict]:
    """Not supported: listing S3-backed Bedrock files is not implemented."""
    raise NotImplementedError("BedrockFilesConfig does not support file listing")

def transform_list_files_response(
    self,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
    litellm_params: dict,
) -> List[OpenAIFileObject]:
    """Not supported: see ``transform_list_files_request``."""
    raise NotImplementedError("BedrockFilesConfig does not support file listing")

def transform_file_content_request(
    self,
    file_content_request,
    optional_params: dict,
    litellm_params: dict,
) -> tuple[str, dict]:
    """Not supported: file content must be fetched from S3 directly."""
    raise NotImplementedError(
        "BedrockFilesConfig does not support file content retrieval"
    )

def transform_file_content_response(
    self,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
    litellm_params: dict,
) -> HttpxBinaryResponseContent:
    """Not supported: see ``transform_file_content_request``."""
    raise NotImplementedError(
        "BedrockFilesConfig does not support file content retrieval"
    )
class BedrockJsonlFilesTransformation:
    """
    Transforms OpenAI /v1/files/* requests to Bedrock S3 file uploads for batch processing
    """

    def transform_openai_file_content_to_bedrock_file_content(
        self, openai_file_content: Optional[FileTypes] = None
    ) -> Tuple[str, str]:
        """
        Transforms OpenAI FileContentRequest to Bedrock S3 file format.

        Returns:
            Tuple of (Bedrock-formatted JSONL string, S3 object name).

        Raises:
            ValueError: if no file content was provided.
        """
        if openai_file_content is None:
            raise ValueError("contents of file are None")
        # Read the content of the file
        file_content = self._get_content_from_openai_file(openai_file_content)
        # Split into lines and parse each line as JSON
        openai_jsonl_content = [
            json.loads(line) for line in file_content.splitlines() if line.strip()
        ]
        bedrock_jsonl_content = (
            self._transform_openai_jsonl_content_to_bedrock_jsonl_content(
                openai_jsonl_content
            )
        )
        bedrock_jsonl_string = "\n".join(
            json.dumps(item) for item in bedrock_jsonl_content
        )
        object_name = self._get_s3_object_name(
            openai_jsonl_content=openai_jsonl_content
        )
        return bedrock_jsonl_string, object_name

    def _transform_openai_jsonl_content_to_bedrock_jsonl_content(
        self, openai_jsonl_content: List[Dict[str, Any]]
    ):
        """
        Delegate to the main BedrockFilesConfig transformation method
        """
        config = BedrockFilesConfig()
        return config._transform_openai_jsonl_content_to_bedrock_jsonl_content(
            openai_jsonl_content
        )

    def _get_s3_object_name(
        self,
        openai_jsonl_content: List[Dict[str, Any]],
    ) -> str:
        """
        Gets a unique S3 object name for the Bedrock batch processing job
        named as: litellm-bedrock-files-{model}-{uuid}
        """
        _model = openai_jsonl_content[0].get("body", {}).get("model", "")
        # Remove bedrock/ prefix if present
        if _model.startswith("bedrock/"):
            _model = _model[8:]
        # Replace colons with hyphens for Bedrock S3 URI compliance
        # (fix: keeps this consistent with
        # BedrockFilesConfig._get_s3_object_name_from_batch_jsonl, which
        # already performed this replacement).
        _model = _model.replace(":", "-")
        object_name = f"litellm-bedrock-files-{_model}-{uuid.uuid4()}.jsonl"
        return object_name

    def _get_content_from_openai_file(self, openai_file_content: FileTypes) -> str:
        """
        Helper to extract content from various OpenAI file types and return as string.

        Handles:
        - Direct content (str, bytes, IO[bytes])
        - Tuple formats: (filename, content, [content_type], [headers])
        - PathLike objects
        """
        content: Union[str, bytes] = b""
        # Extract file content from tuple if necessary
        if isinstance(openai_file_content, tuple):
            # Take the second element which is always the file content
            file_content = openai_file_content[1]
        else:
            file_content = openai_file_content
        # Handle different file content types
        if isinstance(file_content, str):
            # String content can be used directly
            content = file_content
        elif isinstance(file_content, bytes):
            # Bytes content is decoded to str below
            content = file_content
        elif isinstance(file_content, PathLike):  # PathLike
            with open(str(file_content), "rb") as f:
                content = f.read()
        elif hasattr(file_content, "read"):  # IO[bytes]
            # File-like objects need to be read
            content = file_content.read()
        # Ensure content is string
        if isinstance(content, bytes):
            content = content.decode("utf-8")
        return content

    def transform_s3_bucket_response_to_openai_file_object(
        self, create_file_data: CreateFileRequest, s3_upload_response: Dict[str, Any]
    ) -> OpenAIFileObject:
        """
        Transforms S3 Bucket upload file response to OpenAI FileObject
        """
        # S3 response typically contains ETag, key, etc.
        object_key = s3_upload_response.get("Key", "")
        bucket_name = s3_upload_response.get("Bucket", "")
        # Extract filename from object key
        filename = object_key.split("/")[-1] if "/" in object_key else object_key
        return OpenAIFileObject(
            purpose=create_file_data.get("purpose", "batch"),
            id=f"s3://{bucket_name}/{object_key}",
            filename=filename,
            created_at=int(time.time()),  # Current timestamp
            status="uploaded",
            bytes=s3_upload_response.get("ContentLength", 0),
            object="file",
        )

View File

@@ -0,0 +1,9 @@
"""
Bedrock Image Edit Module
Handles image edit operations for Bedrock stability models.
"""
from .handler import BedrockImageEdit
__all__ = ["BedrockImageEdit"]

View File

@@ -0,0 +1,309 @@
"""
Bedrock Image Edit Handler
Handles image edit requests for Bedrock stability models.
"""
from __future__ import annotations
import json
from typing import TYPE_CHECKING, Any, Optional, Union
import httpx
from pydantic import BaseModel
import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging
from litellm.llms.bedrock.image_edit.stability_transformation import (
BedrockStabilityImageEditConfig,
)
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_httpx_client,
get_async_httpx_client,
)
from litellm.types.utils import ImageResponse
from ..base_aws_llm import BaseAWSLLM
from ..common_utils import BedrockError
if TYPE_CHECKING:
from botocore.awsrequest import AWSPreparedRequest
else:
AWSPreparedRequest = Any
class BedrockImageEditPreparedRequest(BaseModel):
    """
    Internal/Helper class for preparing the request for bedrock image edit
    """

    # Fully-qualified InvokeModel URL for the target model
    endpoint_url: str
    # SigV4-signed request (carries the auth headers to send)
    prepped: AWSPreparedRequest
    # JSON-encoded request body, ready to POST
    body: bytes
    # Un-encoded request body dict (kept for logging/response transformation)
    data: dict
class BedrockImageEdit(BaseAWSLLM):
    """
    Bedrock Image Edit handler

    Routes OpenAI-style image-edit calls to Bedrock's InvokeModel API for
    supported (currently Stability) edit models, in both sync and async
    flavors. Request signing and credential resolution come from BaseAWSLLM.
    """

    @classmethod
    def get_config_class(cls, model: str | None):
        """Return the transformation config class for *model*.

        Raises:
            ValueError: when the model is not a supported Stability edit model.
        """
        if BedrockStabilityImageEditConfig._is_stability_edit_model(model):
            return BedrockStabilityImageEditConfig
        else:
            raise ValueError(f"Unsupported model for bedrock image edit: {model}")

    def image_edit(
        self,
        model: str,
        image: list,
        prompt: Optional[str],
        model_response: ImageResponse,
        optional_params: dict,
        logging_obj: LitellmLogging,
        timeout: Optional[Union[float, httpx.Timeout]],
        aimage_edit: bool = False,
        api_base: Optional[str] = None,
        extra_headers: Optional[dict] = None,
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
        api_key: Optional[str] = None,
    ):
        """
        Entry point for Bedrock image edit.

        Prepares and signs the InvokeModel request, then either returns the
        ``async_image_edit`` coroutine (when ``aimage_edit`` is True) or
        executes the call synchronously and returns the transformed
        ``ImageResponse``.

        Raises:
            BedrockError: on HTTP errors (upstream status) or timeout (408).
        """
        prepared_request = self._prepare_request(
            model=model,
            image=image,
            prompt=prompt,
            optional_params=optional_params,
            api_base=api_base,
            extra_headers=extra_headers,
            logging_obj=logging_obj,
            api_key=api_key,
        )
        if aimage_edit is True:
            # Return the coroutine; the caller is responsible for awaiting it.
            return self.async_image_edit(
                prepared_request=prepared_request,
                timeout=timeout,
                model=model,
                logging_obj=logging_obj,
                prompt=prompt,
                model_response=model_response,
                client=(
                    client
                    if client is not None and isinstance(client, AsyncHTTPHandler)
                    else None
                ),
            )
        # NOTE(review): `timeout` is not forwarded to the sync client created
        # here (unlike the async path) — confirm the default client timeout
        # is acceptable.
        if client is None or not isinstance(client, HTTPHandler):
            client = _get_httpx_client()
        try:
            response = client.post(url=prepared_request.endpoint_url, headers=prepared_request.prepped.headers, data=prepared_request.body)  # type: ignore
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise BedrockError(status_code=error_code, message=err.response.text)
        except httpx.TimeoutException:
            raise BedrockError(status_code=408, message="Timeout error occurred.")
        ### FORMAT RESPONSE TO OPENAI FORMAT ###
        model_response = self._transform_response_dict_to_openai_response(
            model_response=model_response,
            model=model,
            logging_obj=logging_obj,
            prompt=prompt,
            response=response,
            data=prepared_request.data,
        )
        return model_response

    async def async_image_edit(
        self,
        prepared_request: BedrockImageEditPreparedRequest,
        timeout: Optional[Union[float, httpx.Timeout]],
        model: str,
        logging_obj: LitellmLogging,
        prompt: Optional[str],
        model_response: ImageResponse,
        client: Optional[AsyncHTTPHandler] = None,
    ) -> ImageResponse:
        """
        Asynchronous handler for bedrock image edit

        POSTs the already-signed request and transforms the result into an
        OpenAI-style ``ImageResponse``.

        Raises:
            BedrockError: on HTTP errors (upstream status) or timeout (408).
        """
        async_client = client or get_async_httpx_client(
            llm_provider=litellm.LlmProviders.BEDROCK,
            params={"timeout": timeout},
        )
        try:
            response = await async_client.post(url=prepared_request.endpoint_url, headers=prepared_request.prepped.headers, data=prepared_request.body)  # type: ignore
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise BedrockError(status_code=error_code, message=err.response.text)
        except httpx.TimeoutException:
            raise BedrockError(status_code=408, message="Timeout error occurred.")
        ### FORMAT RESPONSE TO OPENAI FORMAT ###
        model_response = self._transform_response_dict_to_openai_response(
            model=model,
            logging_obj=logging_obj,
            prompt=prompt,
            response=response,
            data=prepared_request.data,
            model_response=model_response,
        )
        return model_response

    def _prepare_request(
        self,
        model: str,
        image: list,
        prompt: Optional[str],
        optional_params: dict,
        api_base: Optional[str],
        extra_headers: Optional[dict],
        logging_obj: LitellmLogging,
        api_key: Optional[str],
    ) -> BedrockImageEditPreparedRequest:
        """
        Prepare the request body, headers, and endpoint URL for the Bedrock Image Edit API

        Args:
            model (str): The model to use for the image edit
            image (list): The images to edit
            prompt (Optional[str]): The prompt for the edit
            optional_params (dict): The optional parameters for the image edit
            api_base (Optional[str]): The base URL for the Bedrock API
            extra_headers (Optional[dict]): The extra headers to include in the request
            logging_obj (LitellmLogging): The logging object to use for logging
            api_key (Optional[str]): The API key to use
        Returns:
            BedrockImageEditPreparedRequest: The prepared request object
        """
        boto3_credentials_info = self._get_boto_credentials_from_optional_params(
            optional_params, model
        )
        # Use the existing ARN-aware provider detection method
        bedrock_provider = self.get_bedrock_invoke_provider(model)
        ### SET RUNTIME ENDPOINT ###
        modelId = self.get_bedrock_model_id(
            model=model,
            provider=bedrock_provider,
            optional_params=optional_params,
        )
        _, proxy_endpoint_url = self.get_runtime_endpoint(
            api_base=api_base,
            aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint,
            aws_region_name=boto3_credentials_info.aws_region_name,
        )
        proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/invoke"
        data = self._get_request_body(
            model=model,
            image=image,
            prompt=prompt,
            optional_params=optional_params,
        )
        # Make POST Request
        body = json.dumps(data).encode("utf-8")
        headers = {"Content-Type": "application/json"}
        if extra_headers is not None:
            headers = {"Content-Type": "application/json", **extra_headers}
        # SigV4-sign the request (headers carry the auth).
        prepped = self.get_request_headers(
            credentials=boto3_credentials_info.credentials,
            aws_region_name=boto3_credentials_info.aws_region_name,
            extra_headers=extra_headers,
            endpoint_url=proxy_endpoint_url,
            data=body,
            headers=headers,
            api_key=api_key,
        )
        ## LOGGING
        logging_obj.pre_call(
            input=prompt,
            api_key="",
            additional_args={
                "complete_input_dict": data,
                "api_base": proxy_endpoint_url,
                "headers": prepped.headers,
            },
        )
        return BedrockImageEditPreparedRequest(
            endpoint_url=proxy_endpoint_url,
            prepped=prepped,
            body=body,
            data=data,
        )

    def _get_request_body(
        self,
        model: str,
        image: list,
        prompt: Optional[str],
        optional_params: dict,
    ) -> dict:
        """
        Get the request body for the Bedrock Image Edit API

        Checks the model/provider and transforms the request body accordingly

        Returns:
            dict: The request body to use for the Bedrock Image Edit API
        """
        config_class = self.get_config_class(model=model)
        config_instance = config_class()
        # NOTE: only the first image is forwarded; additional images in the
        # list are ignored by this transformation.
        request_body, _ = config_instance.transform_image_edit_request(
            model=model,
            prompt=prompt,
            image=image[0] if image else None,
            image_edit_optional_request_params=optional_params,
            litellm_params={},
            headers={},
        )
        return dict(request_body)

    def _transform_response_dict_to_openai_response(
        self,
        model_response: ImageResponse,
        model: str,
        logging_obj: LitellmLogging,
        prompt: Optional[str],
        response: httpx.Response,
        data: dict,
    ) -> ImageResponse:
        """
        Transforms the Image Edit response from Bedrock to OpenAI format

        Logs the raw response, validates that it is parseable JSON, and
        delegates the actual shape conversion to the model's config class.
        """
        ## LOGGING
        if logging_obj is not None:
            logging_obj.post_call(
                input=prompt,
                api_key="",
                original_response=response.text,
                additional_args={"complete_input_dict": data},
            )
        verbose_logger.debug("raw model_response: %s", response.text)
        # NOTE(review): response_dict is only used as a JSON/None sanity
        # check; the config below re-reads from raw_response. The incoming
        # model_response argument is likewise replaced, not mutated.
        response_dict = response.json()
        if response_dict is None:
            raise ValueError("Error in response object format, got None")
        config_class = self.get_config_class(model=model)
        config_instance = config_class()
        model_response = config_instance.transform_image_edit_response(
            model=model,
            raw_response=response,
            logging_obj=logging_obj,
        )
        return model_response

View File

@@ -0,0 +1,399 @@
"""
Bedrock Stability AI Image Edit Transformation
Handles transformation between OpenAI-compatible format and Bedrock Stability AI Image Edit API format.
Supported models:
- stability.stable-conservative-upscale-v1:0
- stability.stable-creative-upscale-v1:0
- stability.stable-fast-upscale-v1:0
- stability.stable-outpaint-v1:0
- stability.stable-image-control-sketch-v1:0
- stability.stable-image-control-structure-v1:0
- stability.stable-image-erase-object-v1:0
- stability.stable-image-inpaint-v1:0
- stability.stable-image-remove-background-v1:0
- stability.stable-image-search-recolor-v1:0
- stability.stable-image-search-replace-v1:0
- stability.stable-image-style-guide-v1:0
- stability.stable-style-transfer-v1:0
API Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
"""
import base64
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
import httpx
from litellm.llms.base_llm.image_edit.transformation import BaseImageEditConfig
from litellm.types.images.main import ImageEditOptionalRequestParams
from litellm.types.llms.stability import (
OPENAI_SIZE_TO_STABILITY_ASPECT_RATIO,
)
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import FileTypes, ImageObject, ImageResponse
from litellm.utils import get_model_info
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class BedrockStabilityImageEditConfig(BaseImageEditConfig):
"""
Configuration for Bedrock Stability AI image edit.
Supports all Stability image edit operations through Bedrock.
"""
@classmethod
def _is_stability_edit_model(cls, model: Optional[str] = None) -> bool:
    """
    Return True when *model* is a Bedrock Stability image-edit model.

    Matches ids such as ``stability.stable-conservative-upscale-v1:0``,
    ``stability.stable-outpaint-v1:0``, ``stability.stable-image-inpaint-v1:0``,
    etc. — i.e. any ``stability.`` model whose name contains one of the
    known edit-operation keywords.
    """
    if not model:
        return False
    normalized = model.lower()
    if "stability." not in normalized:
        return False
    edit_keywords = (
        "upscale",
        "outpaint",
        "inpaint",
        "erase",
        "remove-background",
        "search-recolor",
        "search-replace",
        "control-sketch",
        "control-structure",
        "style-guide",
        "style-transfer",
    )
    return any(keyword in normalized for keyword in edit_keywords)
def get_supported_openai_params(self, model: str) -> list:
    """
    Return the OpenAI image-edit params Bedrock Stability accepts.

    ``n`` is handled by looping (Stability returns one image per call),
    ``size`` maps to ``aspect_ratio``, and ``response_format`` is limited
    to b64 output on the Stability side.
    """
    return ["n", "size", "response_format", "mask"]
def map_openai_params(
    self,
    image_edit_optional_params: ImageEditOptionalRequestParams,
    model: str,
    drop_params: bool,
) -> Dict:
    """
    Map OpenAI image-edit parameters to Bedrock Stability parameters.

    OpenAI -> Stability mappings:
      - ``size``            -> ``aspect_ratio`` (via lookup table; untranslatable
                               sizes are dropped)
      - ``n``               -> stored internally as ``_n`` (Stability returns one
                               image per request; the caller loops)
      - ``response_format`` -> stored internally as ``_response_format``
                               (Stability only returns b64)
      - other supported params (e.g. ``mask``) pass through unchanged.

    Unsupported params raise ``ValueError`` unless ``drop_params`` is True, in
    which case they are actually removed from the outgoing request. (Bug fix:
    the previous implementation copied the whole input dict up front, so
    "dropped" params silently leaked through to the provider.)
    """
    supported_params = self.get_supported_openai_params(model)
    # Build the outgoing dict explicitly so unsupported keys never leak in.
    mapped_params: Dict[str, Any] = {}
    for k, v in image_edit_optional_params.items():
        if k == "size":
            # Only emit aspect_ratio for sizes we can translate.
            if v in OPENAI_SIZE_TO_STABILITY_ASPECT_RATIO:
                mapped_params["aspect_ratio"] = OPENAI_SIZE_TO_STABILITY_ASPECT_RATIO[v]  # type: ignore
        elif k == "n":
            # Stored for caller-side looping; not sent to Stability.
            mapped_params["_n"] = v
        elif k == "response_format":
            # Stored for post-processing; Stability only emits b64.
            mapped_params["_response_format"] = v
        elif k in supported_params:
            # Supported and needing no translation (e.g. "mask").
            mapped_params[k] = v
        elif not drop_params:
            raise ValueError(
                f"Parameter {k} is not supported for model {model}. "
                f"Supported parameters are {supported_params}. "
                f"Set drop_params=True to drop unsupported parameters."
            )
        # else: drop_params=True -> silently drop the unsupported param.
    return mapped_params
def transform_image_edit_request(  # noqa: PLR0915
    self,
    model: str,
    prompt: Optional[str],
    image: Optional[FileTypes],
    image_edit_optional_request_params: Dict,
    litellm_params: GenericLiteLLMParams,
    headers: dict,
) -> Tuple[Dict, Any]:
    """
    Transform OpenAI-style request to Bedrock Stability request format.

    Args:
        model: Bedrock model id; ids containing "style-transfer" get the
            source image sent as ``init_image`` instead of ``image``.
        prompt: Optional text prompt (empty/None prompts are omitted).
        image: Source image as file-like object, raw bytes, or base64 string.
        image_edit_optional_request_params: Params already mapped by
            ``map_openai_params``; keys prefixed with ``_`` are internal and skipped.
        litellm_params: LiteLLM params (unused here; part of the base interface).
        headers: Request headers (unused here; part of the base interface).

    Returns:
        Tuple of (request body dict to be JSON-encoded by the handler,
        empty dict — Stability uses JSON bodies, not multipart files).
    """
    # Build Bedrock Stability request
    data: Dict[str, Any] = {
        "output_format": "png",  # Default to PNG
    }
    # Add prompt only if provided (some models don't require it)
    if prompt is not None and prompt != "":
        data["prompt"] = prompt
    # Convert image to base64 if provided
    if image is not None:
        image_b64: str
        if hasattr(image, "read") and callable(getattr(image, "read", None)):
            # File-like object (e.g., BufferedReader from open())
            image_bytes = image.read()  # type: ignore
            image_b64 = base64.b64encode(image_bytes).decode("utf-8")  # type: ignore
        elif isinstance(image, bytes):
            # Raw bytes
            image_b64 = base64.b64encode(image).decode("utf-8")
        elif isinstance(image, str):
            # Already a base64 string
            image_b64 = image
        else:
            # Try to handle as bytes
            image_b64 = base64.b64encode(bytes(image)).decode("utf-8")  # type: ignore
        # For style-transfer models, map image to init_image
        model_lower = model.lower()
        if "style-transfer" in model_lower:
            data["init_image"] = image_b64
        else:
            data["image"] = image_b64
    # Add optional params (already mapped in map_openai_params)
    for key, value in image_edit_optional_request_params.items():  # type: ignore
        # Skip internal params (prefixed with _)
        if key.startswith("_") or value is None:
            continue
        # File-like optional params (mask, init_image, style_image, etc.)
        if key in ["mask", "init_image", "style_image"]:
            # Handle case where value might be in a list
            file_value = value
            if isinstance(value, list) and len(value) > 0:
                file_value = value[0]
            if hasattr(file_value, "read") and callable(
                getattr(file_value, "read", None)
            ):
                file_bytes = file_value.read()  # type: ignore
            elif isinstance(file_value, bytes):
                file_bytes = file_value
            elif isinstance(file_value, str):
                # Already a base64 string
                data[key] = file_value
                continue
            else:
                file_bytes = file_value  # type: ignore
            if isinstance(file_bytes, bytes):
                file_b64 = base64.b64encode(file_bytes).decode("utf-8")
            else:
                # Fallback: stringify unknown types rather than failing here;
                # Bedrock will reject genuinely invalid payloads.
                file_b64 = str(file_bytes)
            data[key] = file_b64
            continue
        # Numeric fields that need to be converted to int/float
        numeric_int_fields = ["left", "right", "up", "down", "seed"]
        numeric_float_fields = [
            "strength",
            "creativity",
            "control_strength",
            "grow_mask",
            "fidelity",
            "composition_fidelity",
            "style_strength",
            "change_strength",
        ]
        if key in numeric_int_fields:
            # Convert to int (these are pixel values for outpaint)
            try:
                data[key] = int(value)  # type: ignore
            except (ValueError, TypeError):
                # Pass through unchanged if not coercible; let Bedrock validate
                data[key] = value  # type: ignore
        elif key in numeric_float_fields:
            # Convert to float
            try:
                data[key] = float(value)  # type: ignore
            except (ValueError, TypeError):
                data[key] = value  # type: ignore
        # Supported text fields; anything not listed is silently ignored
        elif key in [
            "negative_prompt",
            "aspect_ratio",
            "output_format",
            "model",
            "mode",
            "style_preset",
            "select_prompt",
            "search_prompt",
        ]:
            data[key] = value  # type: ignore
    return data, {}
def transform_image_edit_response(
    self,
    model: str,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
    api_key: Optional[str] = None,
    json_mode: Optional[bool] = None,
) -> ImageResponse:
    """
    Transform Bedrock Stability response to OpenAI-compatible ImageResponse.

    Bedrock returns: {"images": ["base64..."], "finish_reasons": [null], "seeds": [123]}
    OpenAI expects: {"data": [{"b64_json": "base64..."}], "created": timestamp}

    Raises:
        The provider error class when the body is not JSON, contains an
        "errors" key, or reports a non-null finish reason.
    """
    try:
        response_data = raw_response.json()
    except Exception as e:
        raise self.get_error_class(
            error_message=f"Error parsing Bedrock Stability response: {e}",
            status_code=raw_response.status_code,
            headers=raw_response.headers,
        )
    # Check for errors in response
    if "errors" in response_data:
        raise self.get_error_class(
            error_message=f"Bedrock Stability error: {response_data['errors']}",
            status_code=raw_response.status_code,
            headers=raw_response.headers,
        )
    # Check finish_reasons: a non-null entry signals a failed/filtered generation
    finish_reasons = response_data.get("finish_reasons", [])
    if finish_reasons and finish_reasons[0]:
        raise self.get_error_class(
            error_message=f"Bedrock Stability error: {finish_reasons[0]}",
            status_code=400,
            headers=raw_response.headers,
        )
    model_response = ImageResponse()
    if not model_response.data:
        model_response.data = []
    # Extract images from response (Stability only returns base64, never URLs)
    images = response_data.get("images", [])
    if images:
        for image_b64 in images:
            if image_b64:
                model_response.data.append(
                    ImageObject(
                        b64_json=image_b64,
                        url=None,
                        revised_prompt=None,
                    )
                )
    if not hasattr(model_response, "_hidden_params"):
        model_response._hidden_params = {}
    if "additional_headers" not in model_response._hidden_params:
        model_response._hidden_params["additional_headers"] = {}
    # Set cost based on model: surface the per-image cost as a response header
    model_info = get_model_info(model, custom_llm_provider="bedrock")
    cost_per_image = model_info.get("output_cost_per_image", 0)
    if cost_per_image is not None:
        model_response._hidden_params["additional_headers"][
            "llm_provider-x-litellm-response-cost"
        ] = float(cost_per_image)
    return model_response
def use_multipart_form_data(self) -> bool:
    """Bedrock Stability expects a JSON body, so multipart/form-data is never used."""
    return False
def get_complete_url(
    self,
    model: str,
    api_base: Optional[str],
    litellm_params: dict,
) -> str:
    """
    Get the complete URL for the Bedrock Image Edit API.

    Bedrock endpoint URLs are built by the handler (BedrockImageEdit.image_edit)
    from the model id and AWS region via boto3 — this base-class hook therefore
    only returns a placeholder and none of its arguments are consulted.
    """
    placeholder_endpoint = "bedrock://image-edit"
    return placeholder_endpoint
def validate_environment(
    self,
    headers: dict,
    model: str,
    api_key: Optional[str] = None,
) -> dict:
    """
    Ensure baseline headers for a Bedrock Stability image-edit call.

    Bedrock is authenticated with AWS credentials (handled by BaseAWSLLM /
    the handler's get_request_headers()), not an API key, so this hook only
    guarantees a headers dict exists and carries a JSON content type.

    Args:
        headers: The request headers to validate/update (may be None).
        model: The model name being used (unused here).
        api_key: Ignored — Bedrock uses AWS credentials.

    Returns:
        The updated headers dict.
    """
    prepared = headers if headers is not None else {}
    prepared.setdefault("Content-Type", "application/json")
    return prepared

View File

@@ -0,0 +1,220 @@
import types
from typing import Any, Dict, List, Optional
from openai.types.image import Image
from litellm.types.llms.bedrock import (
AmazonNovaCanvasColorGuidedGenerationParams,
AmazonNovaCanvasColorGuidedRequest,
AmazonNovaCanvasImageGenerationConfig,
AmazonNovaCanvasInpaintingParams,
AmazonNovaCanvasInpaintingRequest,
AmazonNovaCanvasRequestBase,
AmazonNovaCanvasTextToImageParams,
AmazonNovaCanvasTextToImageRequest,
AmazonNovaCanvasTextToImageResponse,
)
from litellm.llms.bedrock.common_utils import get_cached_model_info
from litellm.types.utils import ImageResponse
class AmazonNovaCanvasConfig:
    """
    Config for Amazon Nova Canvas image generation.

    Reference: https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/model-catalog/serverless/amazon.nova-canvas-v1:0
    """

    @classmethod
    def get_config(cls):
        """Return non-None, non-callable class attributes as a config dict."""
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    @classmethod
    def get_supported_openai_params(cls, model: Optional[str] = None) -> List:
        """OpenAI image params Nova Canvas can map: n, size, quality."""
        return ["n", "size", "quality"]

    @classmethod
    def _is_nova_model(cls, model: Optional[str] = None) -> bool:
        """
        Returns True if the model is a Nova Canvas model
        (the model id contains "amazon.nova-canvas").
        """
        if model and "amazon.nova-canvas" in model:
            return True
        return False

    @classmethod
    def transform_request_body(
        cls, text: str, optional_params: dict
    ) -> AmazonNovaCanvasRequestBase:
        """
        Transform the request body for Amazon Nova Canvas model.

        Dispatches on taskType (TEXT_IMAGE, COLOR_GUIDED_GENERATION, INPAINTING).

        Raises:
            ValueError: when params do not fit the expected TypedDict shapes.
            NotImplementedError: for any other task type.
        """
        task_type = optional_params.pop("taskType", "TEXT_IMAGE")
        image_generation_config = optional_params.pop("imageGenerationConfig", {})
        # Extract model_id parameter to prevent "extraneous key" error from Bedrock API
        # Following the same pattern as chat completions and embeddings
        unencoded_model_id = optional_params.pop("model_id", None)  # noqa: F841
        # Remaining optional params are treated as imageGenerationConfig overrides
        image_generation_config = {**image_generation_config, **optional_params}
        if task_type == "TEXT_IMAGE":
            text_to_image_params: Dict[str, Any] = image_generation_config.pop(
                "textToImageParams", {}
            )
            text_to_image_params = {"text": text, **text_to_image_params}
            try:
                text_to_image_params_typed = AmazonNovaCanvasTextToImageParams(
                    **text_to_image_params  # type: ignore
                )
            except Exception as e:
                raise ValueError(
                    f"Error transforming text to image params: {e}. Got params: {text_to_image_params}, Expected params: {AmazonNovaCanvasTextToImageParams.__annotations__}"
                )
            try:
                image_generation_config_typed = AmazonNovaCanvasImageGenerationConfig(
                    **image_generation_config
                )
            except Exception as e:
                raise ValueError(
                    f"Error transforming image generation config: {e}. Got params: {image_generation_config}, Expected params: {AmazonNovaCanvasImageGenerationConfig.__annotations__}"
                )
            return AmazonNovaCanvasTextToImageRequest(
                textToImageParams=text_to_image_params_typed,
                taskType=task_type,
                imageGenerationConfig=image_generation_config_typed,
            )
        if task_type == "COLOR_GUIDED_GENERATION":
            color_guided_generation_params: Dict[
                str, Any
            ] = image_generation_config.pop("colorGuidedGenerationParams", {})
            color_guided_generation_params = {
                "text": text,
                **color_guided_generation_params,
            }
            try:
                color_guided_generation_params_typed = AmazonNovaCanvasColorGuidedGenerationParams(
                    **color_guided_generation_params  # type: ignore
                )
            except Exception as e:
                raise ValueError(
                    f"Error transforming color guided generation params: {e}. Got params: {color_guided_generation_params}, Expected params: {AmazonNovaCanvasColorGuidedGenerationParams.__annotations__}"
                )
            try:
                image_generation_config_typed = AmazonNovaCanvasImageGenerationConfig(
                    **image_generation_config
                )
            except Exception as e:
                raise ValueError(
                    f"Error transforming image generation config: {e}. Got params: {image_generation_config}, Expected params: {AmazonNovaCanvasImageGenerationConfig.__annotations__}"
                )
            return AmazonNovaCanvasColorGuidedRequest(
                taskType=task_type,
                colorGuidedGenerationParams=color_guided_generation_params_typed,
                imageGenerationConfig=image_generation_config_typed,
            )
        if task_type == "INPAINTING":
            inpainting_params: Dict[str, Any] = image_generation_config.pop(
                "inpaintingParams", {}
            )
            inpainting_params = {"text": text, **inpainting_params}
            try:
                inpainting_params_typed = AmazonNovaCanvasInpaintingParams(
                    **inpainting_params  # type: ignore
                )
            except Exception as e:
                raise ValueError(
                    f"Error transforming inpainting params: {e}. Got params: {inpainting_params}, Expected params: {AmazonNovaCanvasInpaintingParams.__annotations__}"
                )
            try:
                image_generation_config_typed = AmazonNovaCanvasImageGenerationConfig(
                    **image_generation_config
                )
            except Exception as e:
                raise ValueError(
                    f"Error transforming image generation config: {e}. Got params: {image_generation_config}, Expected params: {AmazonNovaCanvasImageGenerationConfig.__annotations__}"
                )
            return AmazonNovaCanvasInpaintingRequest(
                taskType=task_type,
                inpaintingParams=inpainting_params_typed,
                imageGenerationConfig=image_generation_config_typed,
            )
        raise NotImplementedError(f"Task type {task_type} is not supported")

    @classmethod
    def map_openai_params(cls, non_default_params: dict, optional_params: dict) -> dict:
        """
        Map the OpenAI params to the Bedrock params:
        size "WxH" -> width/height ints; n -> numberOfImages;
        quality hd/premium -> "premium", standard -> "standard".
        """
        _size = non_default_params.get("size")
        if _size is not None:
            width, height = _size.split("x")
            optional_params["width"], optional_params["height"] = int(width), int(
                height
            )
        if non_default_params.get("n") is not None:
            optional_params["numberOfImages"] = non_default_params.get("n")
        if non_default_params.get("quality") is not None:
            if non_default_params.get("quality") in ("hd", "premium"):
                optional_params["quality"] = "premium"
            if non_default_params.get("quality") == "standard":
                optional_params["quality"] = "standard"
        return optional_params

    @classmethod
    def transform_response_dict_to_openai_response(
        cls, model_response: ImageResponse, response_dict: dict
    ) -> ImageResponse:
        """Copy base64 images from the Bedrock response into the OpenAI ImageResponse."""
        nova_response = AmazonNovaCanvasTextToImageResponse(**response_dict)
        openai_images: List[Image] = []
        for _img in nova_response.get("images", []):
            openai_images.append(Image(b64_json=_img))
        model_response.data = openai_images
        return model_response

    @classmethod
    def cost_calculator(
        cls,
        model: str,
        image_response: ImageResponse,
        size: Optional[str] = None,
        optional_params: Optional[dict] = None,
    ) -> float:
        """Price = per-image output cost (from model info) * number of images."""
        get_model_info = get_cached_model_info()
        model_info = get_model_info(
            model=model,
            custom_llm_provider="bedrock",
        )
        output_cost_per_image: float = model_info.get("output_cost_per_image") or 0.0
        num_images: int = 0
        if image_response.data:
            num_images = len(image_response.data)
        return output_cost_per_image * num_images

View File

@@ -0,0 +1,164 @@
import copy
import os
import types
from typing import List, Optional
from openai.types.image import Image
from litellm.llms.bedrock.common_utils import get_cached_model_info
from litellm.types.utils import ImageResponse
class AmazonStabilityConfig:
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0

    Supported Params for the Amazon / Stable Diffusion models:

    - `cfg_scale` (integer): Default `7`. Between [ 0 .. 35 ]. How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)

    - `seed` (float): Default: `0`. Between [ 0 .. 4294967295 ]. Random noise seed (omit this option or use 0 for a random seed)

    - `steps` (array of strings): Default `30`. Between [ 10 .. 50 ]. Number of diffusion steps to run.

    - `width` (integer): Default: `512`. multiple of 64 >= 128. Width of the image to generate, in pixels, in an increment divisible by 64.
        Engine-specific dimension validation:
        - SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
        - SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
        - SDXL v1.0: same as SDXL v0.9
        - SD v1.6: must be between 320x320 and 1536x1536

    - `height` (integer): Default: `512`. multiple of 64 >= 128. Height of the image to generate, in pixels, in an increment divisible by 64.
        Engine-specific dimension validation:
        - SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
        - SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
        - SDXL v1.0: same as SDXL v0.9
        - SD v1.6: must be between 320x320 and 1536x1536
    """

    # Optional inference defaults, stored on the class (see __init__).
    cfg_scale: Optional[int] = None
    seed: Optional[float] = None
    steps: Optional[List[str]] = None
    width: Optional[int] = None
    height: Optional[int] = None

    def __init__(
        self,
        cfg_scale: Optional[int] = None,
        seed: Optional[float] = None,
        steps: Optional[List[str]] = None,
        width: Optional[int] = None,
        height: Optional[int] = None,
    ) -> None:
        # NOTE: values are written onto the *class* so the classmethod
        # get_config() can see them — shared across instances by design.
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        """Return non-None, non-callable class attributes as a config dict."""
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    @classmethod
    def get_supported_openai_params(cls, model: Optional[str] = None) -> List:
        """Only `size` maps from OpenAI image params for Stability v1."""
        return ["size"]

    @classmethod
    def map_openai_params(
        cls,
        non_default_params: dict,
        optional_params: dict,
    ):
        """Map OpenAI `size` ("WxH") to Bedrock width/height ints."""
        _size = non_default_params.get("size")
        if _size is not None:
            width, height = _size.split("x")
            optional_params["width"] = int(width)
            optional_params["height"] = int(height)
        return optional_params

    @classmethod
    def transform_request_body(
        cls,
        text: str,
        optional_params: dict,
    ) -> dict:
        """Build the Stability v1 text_prompts request body from prompt + params."""
        inference_params = copy.deepcopy(optional_params)
        inference_params.pop(
            "user", None
        )  # make sure user is not passed in for bedrock call
        # Newlines are not meaningful to the model; flatten to spaces
        prompt = text.replace(os.linesep, " ")
        ## LOAD CONFIG
        config = cls.get_config()
        for k, v in config.items():
            if (
                k not in inference_params
            ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                inference_params[k] = v
        return {
            "text_prompts": [{"text": prompt, "weight": 1}],
            **inference_params,
        }

    @classmethod
    def transform_response_dict_to_openai_response(
        cls, model_response: ImageResponse, response_dict: dict
    ) -> ImageResponse:
        """Copy base64 artifacts from the Bedrock response into the OpenAI ImageResponse."""
        image_list: List[Image] = []
        for artifact in response_dict["artifacts"]:
            _image = Image(b64_json=artifact["base64"])
            image_list.append(_image)
        model_response.data = image_list
        return model_response

    @classmethod
    def cost_calculator(
        cls,
        model: str,
        image_response: ImageResponse,
        size: Optional[str] = None,
        optional_params: Optional[dict] = None,
    ) -> float:
        """Price via the size/steps-qualified pricing key times the image count."""
        optional_params = optional_params or {}
        # see model_prices_and_context_window.json for details on how steps is used
        # Reference pricing by steps for stability 1: https://aws.amazon.com/bedrock/pricing/
        _steps = optional_params.get("steps", 50)
        steps = "max-steps" if _steps > 50 else "50-steps"
        # size is stored in model_prices_and_context_window.json as 1024-x-1024
        # current size has 1024x1024
        # NOTE(review): callers appear to pass size as "1024x1024" while pricing
        # keys use "1024-x-1024" — confirm upstream normalizes the separator.
        size = size or "1024-x-1024"
        model = f"{size}/{steps}/{model}"
        get_model_info = get_cached_model_info()
        model_info = get_model_info(
            model=model,
            custom_llm_provider="bedrock",
        )
        output_cost_per_image: float = model_info.get("output_cost_per_image") or 0.0
        num_images: int = 0
        if image_response.data:
            num_images = len(image_response.data)
        return output_cost_per_image * num_images

View File

@@ -0,0 +1,128 @@
import types
from typing import List, Optional
from openai.types.image import Image
from litellm.llms.bedrock.common_utils import BedrockError
from litellm.types.llms.bedrock import (
AmazonStability3TextToImageRequest,
AmazonStability3TextToImageResponse,
)
from litellm.llms.bedrock.common_utils import get_cached_model_info
from litellm.types.utils import ImageResponse
class AmazonStability3Config:
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0

    Stability API Ref: https://platform.stability.ai/docs/api-reference#tag/Generate/paths/~1v2beta~1stable-image~1generate~1sd3/post
    """

    @classmethod
    def get_config(cls):
        """Return non-None, non-callable class attributes as a config dict."""
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    @classmethod
    def get_supported_openai_params(cls, model: Optional[str] = None) -> List:
        """
        No additional OpenAI params are mapped for stability 3
        """
        return []

    @classmethod
    def _is_stability_3_model(cls, model: Optional[str] = None) -> bool:
        """
        Returns True if the model is a Stability 3 model

        Stability 3 models follow this pattern:
            sd3-large
            sd3-large-turbo
            sd3-medium
            sd3.5-large
            sd3.5-large-turbo

        Stability ultra models
            stable-image-ultra-v1
        """
        if model:
            # NOTE: "sd3.5" ids already match the "sd3" substring check;
            # the second condition is kept for explicitness.
            if "sd3" in model or "sd3.5" in model:
                return True
            if "stable-image" in model:
                return True
        return False

    @classmethod
    def transform_request_body(
        cls, text: str, optional_params: dict
    ) -> AmazonStability3TextToImageRequest:
        """
        Transform the request body for the Stability 3 models
        """
        data = AmazonStability3TextToImageRequest(prompt=text, **optional_params)
        return data

    @classmethod
    def map_openai_params(cls, non_default_params: dict, optional_params: dict) -> dict:
        """
        Map the OpenAI params to the Bedrock params

        No OpenAI params are mapped for Stability 3, so directly return the optional_params
        """
        return optional_params

    @classmethod
    def transform_response_dict_to_openai_response(
        cls, model_response: ImageResponse, response_dict: dict
    ) -> ImageResponse:
        """
        Transform the response dict to the OpenAI response

        Raises:
            BedrockError: when any non-null finish reason is reported.
        """
        stability_3_response = AmazonStability3TextToImageResponse(**response_dict)
        # A non-null finish reason means the generation failed or was filtered
        finish_reasons = stability_3_response.get("finish_reasons", [])
        finish_reasons = [reason for reason in finish_reasons if reason]
        if len(finish_reasons) > 0:
            raise BedrockError(status_code=400, message="; ".join(finish_reasons))
        openai_images: List[Image] = []
        for _img in stability_3_response.get("images", []):
            openai_images.append(Image(b64_json=_img))
        model_response.data = openai_images
        return model_response

    @classmethod
    def cost_calculator(
        cls,
        model: str,
        image_response: ImageResponse,
        size: Optional[str] = None,
        optional_params: Optional[dict] = None,
    ) -> float:
        """Price = per-image output cost (from model info) * number of images."""
        get_model_info = get_cached_model_info()
        model_info = get_model_info(
            model=model,
            custom_llm_provider="bedrock",
        )
        output_cost_per_image: float = model_info.get("output_cost_per_image") or 0.0
        num_images: int = 0
        if image_response.data:
            num_images = len(image_response.data)
        return output_cost_per_image * num_images

View File

@@ -0,0 +1,160 @@
"""
Transformation logic for Amazon Titan Image Generation.
"""
import types
from typing import List, Optional
from openai.types.image import Image
from litellm.utils import get_model_info
from litellm.types.llms.bedrock import (
AmazonNovaCanvasImageGenerationConfig,
AmazonTitanImageGenerationRequestBody,
AmazonTitanTextToImageParams,
)
from litellm.types.utils import ImageResponse
class AmazonTitanImageGenerationConfig:
    """
    Config for Amazon Titan image generation models.

    Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-image.html
    """

    # Optional inference defaults, stored on the class (see __init__).
    cfg_scale: Optional[int] = None
    seed: Optional[float] = None
    steps: Optional[List[str]] = None
    width: Optional[int] = None
    height: Optional[int] = None

    def __init__(
        self,
        cfg_scale: Optional[int] = None,
        seed: Optional[float] = None,
        steps: Optional[List[str]] = None,
        width: Optional[int] = None,
        height: Optional[int] = None,
    ) -> None:
        # NOTE: values are written onto the *class* so the classmethod
        # get_config() can see them — mirrors the other Bedrock config classes.
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        """Return non-None, non-callable class attributes as a config dict."""
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    @classmethod
    def _is_titan_model(cls, model: Optional[str] = None) -> bool:
        """
        Returns True if the model is a Titan model
        (the model id contains "amazon.titan").
        """
        if model and "amazon.titan" in model:
            return True
        return False

    @classmethod
    def get_supported_openai_params(cls, model: Optional[str] = None) -> List:
        """OpenAI image params Titan can map: size, n, quality."""
        return ["size", "n", "quality"]

    @classmethod
    def map_openai_params(
        cls,
        non_default_params: dict,
        optional_params: dict,
    ):
        """
        Map OpenAI image params into Titan's nested imageGenerationConfig:
        size "WxH" -> width/height ints; n -> numberOfImages;
        quality hd/premium/high -> "premium", standard/medium/low -> "standard".
        Unrecognized quality values are dropped.
        """
        from typing import Any, Dict

        image_generation_config: Dict[str, Any] = {}
        for k, v in non_default_params.items():
            if k == "size" and v is not None:
                width, height = v.split("x")
                image_generation_config["width"] = int(width)
                image_generation_config["height"] = int(height)
            elif k == "n" and v is not None:
                image_generation_config["numberOfImages"] = v
            elif (
                k == "quality" and v is not None
            ):  # 'auto', 'hd', 'standard', 'high', 'medium', 'low'
                if v in ("hd", "premium", "high"):
                    image_generation_config["quality"] = "premium"
                elif v in ("standard", "medium", "low"):
                    image_generation_config["quality"] = "standard"
        if image_generation_config:
            optional_params["imageGenerationConfig"] = image_generation_config
        return optional_params

    @classmethod
    def transform_request_body(
        cls,
        text: str,
        optional_params: dict,
    ) -> AmazonTitanImageGenerationRequestBody:
        """
        Build the Titan request body from the prompt text and mapped params.

        Pops imageGenerationConfig / negativeText / taskType out of
        optional_params (mutating it) and assembles the typed request.
        """
        from typing import Any, Dict

        # BUGFIX: "imageGenerationConfig" was previously popped a second time
        # further down; the second pop always returned {} (the key was already
        # removed), so the subsequent merge was dead code and has been removed.
        image_generation_config = optional_params.pop("imageGenerationConfig", {})
        negative_text = optional_params.pop("negativeText", None)
        text_to_image_params: Dict[str, Any] = {"text": text}
        if negative_text:
            text_to_image_params["negativeText"] = negative_text
        task_type = optional_params.pop("taskType", "TEXT_IMAGE")
        return AmazonTitanImageGenerationRequestBody(
            taskType=task_type,
            textToImageParams=AmazonTitanTextToImageParams(**text_to_image_params),  # type: ignore
            imageGenerationConfig=AmazonNovaCanvasImageGenerationConfig(
                **image_generation_config
            ),
        )

    @classmethod
    def transform_response_dict_to_openai_response(
        cls, model_response: ImageResponse, response_dict: dict
    ) -> ImageResponse:
        """Copy base64 images from the Bedrock response into the OpenAI ImageResponse."""
        image_list: List[Image] = []
        for image in response_dict["images"]:
            _image = Image(b64_json=image)
            image_list.append(_image)
        model_response.data = image_list
        return model_response

    @classmethod
    def cost_calculator(
        cls,
        model: str,
        image_response: ImageResponse,
        size: Optional[str] = None,
        optional_params: Optional[dict] = None,
    ) -> float:
        """Price = per-image output cost (from model info) * number of images."""
        model_info = get_model_info(model=model)
        output_cost_per_image = model_info.get("output_cost_per_image") or 0.0
        if not image_response.data:
            return 0.0
        num_images = len(image_response.data)
        return output_cost_per_image * num_images

View File

@@ -0,0 +1,24 @@
from typing import Optional
from litellm.llms.bedrock.image_generation.image_handler import BedrockImageGeneration
from litellm.types.utils import ImageResponse
def cost_calculator(
    model: str,
    image_response: ImageResponse,
    size: Optional[str] = None,
    optional_params: Optional[dict] = None,
) -> float:
    """
    Bedrock image generation cost calculator.

    Dispatches to the per-model config class (Titan / Nova Canvas /
    Stability 1 / Stability 3) and lets it price the response.
    """
    model_config = BedrockImageGeneration.get_config_class(model=model)
    total_cost = model_config.cost_calculator(
        model=model,
        image_response=image_response,
        size=size,
        optional_params=optional_params,
    )
    return total_cost

View File

@@ -0,0 +1,333 @@
from __future__ import annotations
import json
from typing import TYPE_CHECKING, Any, Optional, Union
import httpx
from pydantic import BaseModel
import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging
from litellm.llms.bedrock.image_generation.amazon_nova_canvas_transformation import (
AmazonNovaCanvasConfig,
)
from litellm.llms.bedrock.image_generation.amazon_stability1_transformation import (
AmazonStabilityConfig,
)
from litellm.llms.bedrock.image_generation.amazon_stability3_transformation import (
AmazonStability3Config,
)
from litellm.llms.bedrock.image_generation.amazon_titan_transformation import (
AmazonTitanImageGenerationConfig,
)
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_httpx_client,
get_async_httpx_client,
)
from litellm.types.utils import ImageResponse
from ..base_aws_llm import BaseAWSLLM
from ..common_utils import BedrockError
if TYPE_CHECKING:
from botocore.awsrequest import AWSPreparedRequest
else:
AWSPreparedRequest = Any
class BedrockImagePreparedRequest(BaseModel):
    """
    Internal/Helper class for preparing the request for bedrock image generation
    """

    endpoint_url: str  # fully-qualified invoke URL for the target model
    prepped: AWSPreparedRequest  # SigV4-signed request (carries the auth headers)
    body: bytes  # JSON-encoded request payload sent over the wire
    data: dict  # un-encoded payload, kept for logging/response transforms
# Union of the per-provider image-generation config classes that
# BedrockImageGeneration.get_config_class() can return.
BedrockImageConfigClass = Union[
    type[AmazonTitanImageGenerationConfig],
    type[AmazonNovaCanvasConfig],
    type[AmazonStability3Config],
    type[AmazonStabilityConfig],
]
class BedrockImageGeneration(BaseAWSLLM):
"""
Bedrock Image Generation handler
"""
@classmethod
def get_config_class(cls, model: str | None) -> BedrockImageConfigClass:
    """
    Select the provider-specific image-generation config class for *model*.

    Falls back to the Stability v1 config when no other family matches.
    """
    if AmazonTitanImageGenerationConfig._is_titan_model(model):
        return AmazonTitanImageGenerationConfig
    elif AmazonNovaCanvasConfig._is_nova_model(model):
        return AmazonNovaCanvasConfig
    elif AmazonStability3Config._is_stability_3_model(model):
        return AmazonStability3Config
    else:
        # Consistency fix: use the directly imported class like the branches
        # above (litellm.AmazonStabilityConfig re-exports this same class).
        return AmazonStabilityConfig
def image_generation(
    self,
    model: str,
    prompt: str,
    model_response: ImageResponse,
    optional_params: dict,
    logging_obj: LitellmLogging,
    timeout: Optional[Union[float, httpx.Timeout]],
    aimg_generation: bool = False,
    api_base: Optional[str] = None,
    extra_headers: Optional[dict] = None,
    client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
    api_key: Optional[str] = None,
):
    """
    Entry point for Bedrock image generation.

    Prepares and signs the request, then either returns the coroutine from
    async_image_generation (when aimg_generation is True) or performs the
    HTTP call synchronously and returns the transformed ImageResponse.

    Raises:
        BedrockError: on non-2xx responses or request timeout.
    """
    prepared_request = self._prepare_request(
        model=model,
        optional_params=optional_params,
        api_base=api_base,
        extra_headers=extra_headers,
        logging_obj=logging_obj,
        prompt=prompt,
        api_key=api_key,
    )
    if aimg_generation is True:
        # Async path: hand back the coroutine; only reuse the caller's
        # client if it is actually an async client.
        return self.async_image_generation(
            prepared_request=prepared_request,
            timeout=timeout,
            model=model,
            logging_obj=logging_obj,
            prompt=prompt,
            model_response=model_response,
            client=(
                client
                if client is not None and isinstance(client, AsyncHTTPHandler)
                else None
            ),
        )
    # Sync path: fall back to a fresh sync client when none (or the wrong
    # kind) was supplied.
    if client is None or not isinstance(client, HTTPHandler):
        client = _get_httpx_client()
    try:
        response = client.post(url=prepared_request.endpoint_url, headers=prepared_request.prepped.headers, data=prepared_request.body)  # type: ignore
        response.raise_for_status()
    except httpx.HTTPStatusError as err:
        error_code = err.response.status_code
        raise BedrockError(status_code=error_code, message=err.response.text)
    except httpx.TimeoutException:
        raise BedrockError(status_code=408, message="Timeout error occurred.")
    ### FORMAT RESPONSE TO OPENAI FORMAT ###
    model_response = self._transform_response_dict_to_openai_response(
        model_response=model_response,
        model=model,
        logging_obj=logging_obj,
        prompt=prompt,
        response=response,
        data=prepared_request.data,
    )
    return model_response
async def async_image_generation(
    self,
    prepared_request: BedrockImagePreparedRequest,
    timeout: Optional[Union[float, httpx.Timeout]],
    model: str,
    logging_obj: LitellmLogging,
    prompt: str,
    model_response: ImageResponse,
    client: Optional[AsyncHTTPHandler] = None,
) -> ImageResponse:
    """
    Asynchronous handler for bedrock image generation

    Awaits the response from the bedrock image generation endpoint and
    transforms it to the OpenAI image-response shape.

    Raises:
        BedrockError: on non-2xx responses or request timeout.
    """
    # Reuse the caller's async client when given; otherwise build one with
    # the requested timeout.
    async_client = client or get_async_httpx_client(
        llm_provider=litellm.LlmProviders.BEDROCK,
        params={"timeout": timeout},
    )
    try:
        response = await async_client.post(url=prepared_request.endpoint_url, headers=prepared_request.prepped.headers, data=prepared_request.body)  # type: ignore
        response.raise_for_status()
    except httpx.HTTPStatusError as err:
        error_code = err.response.status_code
        raise BedrockError(status_code=error_code, message=err.response.text)
    except httpx.TimeoutException:
        raise BedrockError(status_code=408, message="Timeout error occurred.")
    ### FORMAT RESPONSE TO OPENAI FORMAT ###
    model_response = self._transform_response_dict_to_openai_response(
        model=model,
        logging_obj=logging_obj,
        prompt=prompt,
        response=response,
        data=prepared_request.data,
        model_response=model_response,
    )
    return model_response
def _extract_headers_from_optional_params(self, optional_params: dict) -> dict:
"""
Extract guardrail parameters from optional_params and convert them to headers.
"""
headers = {}
guardrail_identifier = optional_params.pop("guardrailIdentifier", None)
guardrail_version = optional_params.pop("guardrailVersion", None)
if guardrail_identifier is not None:
headers["x-amz-bedrock-guardrail-identifier"] = guardrail_identifier
if guardrail_version is not None:
headers["x-amz-bedrock-guardrail-version"] = guardrail_version
return headers
    def _prepare_request(
        self,
        model: str,
        optional_params: dict,
        api_base: Optional[str],
        extra_headers: Optional[dict],
        logging_obj: LitellmLogging,
        prompt: str,
        api_key: Optional[str],
    ) -> BedrockImagePreparedRequest:
        """
        Prepare the request body, headers, and endpoint URL for the Bedrock Image Generation API.

        Args:
            model (str): The model to use for the image generation
            optional_params (dict): The optional parameters for the image generation
            api_base (Optional[str]): The base URL for the Bedrock API
            extra_headers (Optional[dict]): The extra headers to include in the request
            logging_obj (LitellmLogging): The logging object to use for logging
            prompt (str): The prompt to use for the image generation
            api_key (Optional[str]): Optional key forwarded to the request signer

        Returns:
            BedrockImagePreparedRequest: The prepared request object containing:
                endpoint_url (str): The endpoint URL for the Bedrock Image Generation API
                prepped (httpx.Request): The signed, prepared request object
                body (bytes): The serialized request body
                data (dict): The un-serialized request payload (for logging)
        """
        # Resolve AWS credentials/region from optional_params.
        boto3_credentials_info = self._get_boto_credentials_from_optional_params(
            optional_params, model
        )
        # Use the existing ARN-aware provider detection method
        bedrock_provider = self.get_bedrock_invoke_provider(model)
        ### SET RUNTIME ENDPOINT ###
        modelId = self.get_bedrock_model_id(
            model=model,
            provider=bedrock_provider,
            optional_params=optional_params,
        )
        _, proxy_endpoint_url = self.get_runtime_endpoint(
            api_base=api_base,
            aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint,
            aws_region_name=boto3_credentials_info.aws_region_name,
        )
        proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/invoke"
        data = self._get_request_body(
            model=model,
            prompt=prompt,
            optional_params=optional_params,
        )
        # Make POST Request
        body = json.dumps(data).encode("utf-8")
        headers = {"Content-Type": "application/json"}
        if extra_headers is not None:
            headers = {"Content-Type": "application/json", **extra_headers}
        # Extract guardrail parameters and add them as headers.
        # NOTE(review): guardrail keys are popped *after* the body was built
        # above — confirm _get_request_body never maps guardrail params into
        # the request payload.
        guardrail_headers = self._extract_headers_from_optional_params(optional_params)
        headers.update(guardrail_headers)
        # SigV4-sign the request; returns a prepared request carrying the auth headers.
        prepped = self.get_request_headers(
            credentials=boto3_credentials_info.credentials,
            aws_region_name=boto3_credentials_info.aws_region_name,
            extra_headers=extra_headers,
            endpoint_url=proxy_endpoint_url,
            data=body,
            headers=headers,
            api_key=api_key,
        )
        ## LOGGING
        logging_obj.pre_call(
            input=prompt,
            api_key="",
            additional_args={
                "complete_input_dict": data,
                "api_base": proxy_endpoint_url,
                "headers": prepped.headers,
            },
        )
        return BedrockImagePreparedRequest(
            endpoint_url=proxy_endpoint_url,
            prepped=prepped,
            body=body,
            data=data,
        )
def _get_request_body(
self,
model: str,
prompt: str,
optional_params: dict,
) -> dict:
"""
Get the request body for the Bedrock Image Generation API
Checks the model/provider and transforms the request body accordingly
Returns:
dict: The request body to use for the Bedrock Image Generation API
"""
config_class = self.get_config_class(model=model)
request_body = config_class.transform_request_body(
text=prompt, optional_params=optional_params
)
return dict(request_body)
def _transform_response_dict_to_openai_response(
self,
model_response: ImageResponse,
model: str,
logging_obj: LitellmLogging,
prompt: str,
response: httpx.Response,
data: dict,
) -> ImageResponse:
"""
Transforms the Image Generation response from Bedrock to OpenAI format
"""
## LOGGING
if logging_obj is not None:
logging_obj.post_call(
input=prompt,
api_key="",
original_response=response.text,
additional_args={"complete_input_dict": data},
)
verbose_logger.debug("raw model_response: %s", response.text)
response_dict = response.json()
if response_dict is None:
raise ValueError("Error in response object format, got None")
config_class = self.get_config_class(model=model)
config_class.transform_response_dict_to_openai_response(
model_response=model_response,
response_dict=response_dict,
)
return model_response

View File

@@ -0,0 +1,550 @@
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Dict,
List,
Optional,
Tuple,
Union,
cast,
)
import httpx
from litellm.llms.anthropic.common_utils import AnthropicModelInfo
from litellm.llms.anthropic.experimental_pass_through.messages.transformation import (
AnthropicMessagesConfig,
)
from litellm.llms.base_llm.anthropic_messages.transformation import (
BaseAnthropicMessagesConfig,
)
from litellm.llms.bedrock.chat.invoke_handler import AWSEventStreamDecoder
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
AmazonInvokeConfig,
)
from litellm.llms.bedrock.common_utils import (
get_anthropic_beta_from_headers,
is_claude_4_5_on_bedrock,
remove_custom_field_from_tools,
)
from litellm.types.llms.anthropic import ANTHROPIC_TOOL_SEARCH_BETA_HEADER
from litellm.types.llms.openai import AllMessageValues
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import GenericStreamingChunk
from litellm.types.utils import GenericStreamingChunk as GChunk
from litellm.types.utils import ModelResponseStream
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class AmazonAnthropicClaudeMessagesConfig(
AnthropicMessagesConfig,
AmazonInvokeConfig,
):
"""
Call Claude model family in the /v1/messages API spec
Supports anthropic_beta parameter for beta features.
"""
DEFAULT_BEDROCK_ANTHROPIC_API_VERSION = "bedrock-2023-05-31"
# Beta header patterns that are not supported by Bedrock Invoke API
# These will be filtered out to prevent 400 "invalid beta flag" errors
    def __init__(self, **kwargs):
        # This class multiply inherits from the Anthropic /v1/messages config
        # and the Bedrock Invoke config; initialize each base explicitly so
        # both get their own setup regardless of MRO.
        BaseAnthropicMessagesConfig.__init__(self, **kwargs)
        AmazonInvokeConfig.__init__(self, **kwargs)
    def validate_anthropic_messages_environment(
        self,
        headers: dict,
        model: str,
        messages: List[Any],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> Tuple[dict, Optional[str]]:
        # Bedrock auth is handled via SigV4 in sign_request, so there is no
        # API-key validation to do here; pass headers/api_base through as-is.
        return headers, api_base
    def sign_request(
        self,
        headers: dict,
        optional_params: dict,
        request_data: dict,
        api_base: str,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        stream: Optional[bool] = None,
        fake_stream: Optional[bool] = None,
    ) -> Tuple[dict, Optional[bytes]]:
        """SigV4-sign the request via the Bedrock Invoke implementation."""
        # Explicit base-class call (not super()) so the Anthropic base in the
        # MRO is bypassed for signing.
        return AmazonInvokeConfig.sign_request(
            self=self,
            headers=headers,
            optional_params=optional_params,
            request_data=request_data,
            api_base=api_base,
            api_key=api_key,
            model=model,
            stream=stream,
            fake_stream=fake_stream,
        )
    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """Build the Bedrock Invoke endpoint URL (delegates to AmazonInvokeConfig)."""
        return AmazonInvokeConfig.get_complete_url(
            self=self,
            api_base=api_base,
            api_key=api_key,
            model=model,
            optional_params=optional_params,
            litellm_params=litellm_params,
            stream=stream,
        )
def _remove_ttl_from_cache_control(
self, anthropic_messages_request: Dict, model: Optional[str] = None
) -> None:
"""
Remove unsupported fields from cache_control for Bedrock.
Bedrock only supports `type` and `ttl` in cache_control. It does NOT support:
- `scope` (e.g., "global") - always removed
- `ttl` - removed for older models; Claude 4.5+ supports "5m" and "1h"
Processes both `system` and `messages` content blocks.
Args:
anthropic_messages_request: The request dictionary to modify in-place
model: The model name to check if it supports ttl
"""
is_claude_4_5 = False
if model:
is_claude_4_5 = self._is_claude_4_5_on_bedrock(model)
def _sanitize_cache_control(cache_control: dict) -> None:
if not isinstance(cache_control, dict):
return
# Bedrock doesn't support scope (e.g., "global" for cross-request caching)
cache_control.pop("scope", None)
# Remove ttl for models that don't support it
if "ttl" in cache_control:
ttl = cache_control["ttl"]
if is_claude_4_5 and ttl in ["5m", "1h"]:
return
cache_control.pop("ttl", None)
def _process_content_list(content: list) -> None:
for item in content:
if isinstance(item, dict) and "cache_control" in item:
_sanitize_cache_control(item["cache_control"])
# Process system (list of content blocks)
if "system" in anthropic_messages_request:
system = anthropic_messages_request["system"]
if isinstance(system, list):
_process_content_list(system)
# Process messages
if "messages" in anthropic_messages_request:
for message in anthropic_messages_request["messages"]:
if isinstance(message, dict) and "content" in message:
content = message["content"]
if isinstance(content, list):
_process_content_list(content)
def _supports_extended_thinking_on_bedrock(self, model: str) -> bool:
"""
Check if the model supports extended thinking beta headers on Bedrock.
On 3rd-party platforms (e.g., Amazon Bedrock), extended thinking is only
supported on: Claude Opus 4.5, Claude Opus 4.1, Opus 4, or Sonnet 4.
Ref: https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking
Args:
model: The model name
Returns:
True if the model supports extended thinking on Bedrock
"""
model_lower = model.lower()
# Supported models on Bedrock for extended thinking
supported_patterns = [
"opus-4.5",
"opus_4.5",
"opus-4-5",
"opus_4_5", # Opus 4.5
"opus-4.1",
"opus_4.1",
"opus-4-1",
"opus_4_1", # Opus 4.1
"opus-4",
"opus_4", # Opus 4
"sonnet-4",
"sonnet_4", # Sonnet 4
"sonnet-4.6",
"sonnet_4.6",
"sonnet-4-6",
"sonnet_4_6",
"opus-4.6",
"opus_4.6",
"opus-4-6",
"opus_4_6",
]
return any(pattern in model_lower for pattern in supported_patterns)
def _is_claude_opus_4_5(self, model: str) -> bool:
"""
Check if the model is Claude Opus 4.5.
Args:
model: The model name
Returns:
True if the model is Claude Opus 4.5
"""
model_lower = model.lower()
opus_4_5_patterns = [
"opus-4.5",
"opus_4.5",
"opus-4-5",
"opus_4_5",
]
return any(pattern in model_lower for pattern in opus_4_5_patterns)
    def _is_claude_4_5_on_bedrock(self, model: str) -> bool:
        """
        Check if the model is Claude 4.5 on Bedrock.

        Claude Sonnet 4.5, Haiku 4.5, and Opus 4.5 support 1-hour prompt caching.

        Args:
            model: The model name

        Returns:
            True if the model is Claude 4.5
        """
        # Thin wrapper over the shared module-level helper so subclasses can override.
        return is_claude_4_5_on_bedrock(model)
def _supports_tool_search_on_bedrock(self, model: str) -> bool:
"""
Check if the model supports tool search on Bedrock.
On Amazon Bedrock, server-side tool search is supported on Claude Opus 4.5
and Claude Sonnet 4.5 with the tool-search-tool-2025-10-19 beta header.
Ref: https://platform.claude.com/docs/en/agents-and-tools/tool-use/tool-search-tool
Args:
model: The model name
Returns:
True if the model supports tool search on Bedrock
"""
model_lower = model.lower()
# Supported models for tool search on Bedrock
supported_patterns = [
# Opus 4.5
"opus-4.5",
"opus_4.5",
"opus-4-5",
"opus_4_5",
# Sonnet 4.5
"sonnet-4.5",
"sonnet_4.5",
"sonnet-4-5",
"sonnet_4_5",
# Opus 4.6
"opus-4.6",
"opus_4.6",
"opus-4-6",
"opus_4_6",
# sonnet 4.6
"sonnet-4.6",
"sonnet_4.6",
"sonnet-4-6",
"sonnet_4_6",
]
return any(pattern in model_lower for pattern in supported_patterns)
def _get_tool_search_beta_header_for_bedrock(
self,
model: str,
tool_search_used: bool,
programmatic_tool_calling_used: bool,
input_examples_used: bool,
beta_set: set,
) -> None:
"""
Adjust tool search beta header for Bedrock.
Bedrock requires a different beta header for tool search on Opus 4 models
when tool search is used without programmatic tool calling or input examples.
Note: On Amazon Bedrock, server-side tool search is only supported on Claude Opus 4
with the `tool-search-tool-2025-10-19` beta header.
Ref: https://platform.claude.com/docs/en/agents-and-tools/tool-use/tool-search-tool
Args:
model: The model name
tool_search_used: Whether tool search is used
programmatic_tool_calling_used: Whether programmatic tool calling is used
input_examples_used: Whether input examples are used
beta_set: The set of beta headers to modify in-place
"""
if tool_search_used and not (
programmatic_tool_calling_used or input_examples_used
):
beta_set.discard(ANTHROPIC_TOOL_SEARCH_BETA_HEADER)
if self._supports_tool_search_on_bedrock(model):
beta_set.add("tool-search-tool-2025-10-19")
def _convert_output_format_to_inline_schema(
self,
output_format: Dict,
anthropic_messages_request: Dict,
) -> None:
"""
Convert Anthropic output_format to inline schema in message content.
Bedrock Invoke doesn't support the output_format parameter, so we embed
the schema directly into the user message content as text instructions.
This approach adds the schema to the last user message, instructing the model
to respond in the specified JSON format.
Args:
output_format: The output_format dict with 'type' and 'schema'
anthropic_messages_request: The request dict to modify in-place
Ref: https://aws.amazon.com/blogs/machine-learning/structured-data-response-with-amazon-bedrock-prompt-engineering-and-tool-use/
"""
import json
# Extract schema from output_format
schema = output_format.get("schema")
if not schema:
return
# Get messages from the request
messages = anthropic_messages_request.get("messages", [])
if not messages:
return
# Find the last user message
last_user_message_idx = None
for idx in range(len(messages) - 1, -1, -1):
if messages[idx].get("role") == "user":
last_user_message_idx = idx
break
if last_user_message_idx is None:
return
last_user_message = messages[last_user_message_idx]
content = last_user_message.get("content", [])
# Ensure content is a list
if isinstance(content, str):
content = [{"type": "text", "text": content}]
last_user_message["content"] = content
# Add schema as text content to the message
schema_text = {"type": "text", "text": json.dumps(schema)}
content.append(schema_text)
    def transform_anthropic_messages_request(
        self,
        model: str,
        messages: List[Dict],
        anthropic_messages_optional_request_params: Dict,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Dict:
        """
        Build the Bedrock Invoke request body for an Anthropic /v1/messages call.

        Starts from the generic Anthropic transformation, then applies the
        Bedrock-Invoke-specific adjustments listed inline below.
        """
        anthropic_messages_request = AnthropicMessagesConfig.transform_anthropic_messages_request(
            self=self,
            model=model,
            messages=messages,
            anthropic_messages_optional_request_params=anthropic_messages_optional_request_params,
            litellm_params=litellm_params,
            headers=headers,
        )
        #########################################################
        ############## BEDROCK Invoke SPECIFIC TRANSFORMATION ###
        #########################################################
        # 1. anthropic_version is required for all claude models
        if "anthropic_version" not in anthropic_messages_request:
            anthropic_messages_request[
                "anthropic_version"
            ] = self.DEFAULT_BEDROCK_ANTHROPIC_API_VERSION
        # 2. `stream` is not allowed in request body for bedrock invoke
        if "stream" in anthropic_messages_request:
            anthropic_messages_request.pop("stream", None)
        # 3. `model` is not allowed in request body for bedrock invoke
        if "model" in anthropic_messages_request:
            anthropic_messages_request.pop("model", None)
        # 4. Remove `ttl` field from cache_control in messages (Bedrock doesn't support it for older models)
        self._remove_ttl_from_cache_control(
            anthropic_messages_request=anthropic_messages_request, model=model
        )
        # 5. Convert `output_format` to inline schema (Bedrock invoke doesn't support output_format)
        output_format = anthropic_messages_request.pop("output_format", None)
        if output_format:
            self._convert_output_format_to_inline_schema(
                output_format=output_format,
                anthropic_messages_request=anthropic_messages_request,
            )
        # 6. Strip `output_config` — Bedrock Invoke doesn't support it
        # Fixes: https://github.com/BerriAI/litellm/issues/22797
        anthropic_messages_request.pop("output_config", None)
        # 7. Remove `custom` field from tools (Bedrock doesn't support it)
        # Claude Code sends `custom: {defer_loading: true}` on tool definitions,
        # which causes Bedrock to reject the request with "Extra inputs are not permitted"
        # Ref: https://github.com/BerriAI/litellm/issues/22847
        remove_custom_field_from_tools(anthropic_messages_request)
        # 8. AUTO-INJECT beta headers based on the features actually used,
        # then reconcile them against what Bedrock accepts.
        anthropic_model_info = AnthropicModelInfo()
        tools = anthropic_messages_optional_request_params.get("tools")
        messages_typed = cast(List[AllMessageValues], messages)
        tool_search_used = anthropic_model_info.is_tool_search_used(tools)
        programmatic_tool_calling_used = (
            anthropic_model_info.is_programmatic_tool_calling_used(tools)
        )
        input_examples_used = anthropic_model_info.is_input_examples_used(tools)
        # Merge caller-supplied beta headers with the auto-detected ones.
        beta_set = set(get_anthropic_beta_from_headers(headers))
        auto_betas = anthropic_model_info.get_anthropic_beta_list(
            model=model,
            optional_params=anthropic_messages_optional_request_params,
            computer_tool_used=anthropic_model_info.is_computer_tool_used(tools),
            prompt_caching_set=False,
            file_id_used=anthropic_model_info.is_file_id_used(messages_typed),
            mcp_server_used=anthropic_model_info.is_mcp_server_used(
                anthropic_messages_optional_request_params.get("mcp_servers")
            ),
        )
        beta_set.update(auto_betas)
        # Swap the generic tool-search header for the Bedrock-specific one when needed.
        self._get_tool_search_beta_header_for_bedrock(
            model=model,
            tool_search_used=tool_search_used,
            programmatic_tool_calling_used=programmatic_tool_calling_used,
            input_examples_used=input_examples_used,
            beta_set=beta_set,
        )
        # Bedrock's tool-search flag also requires the tool-examples beta.
        if "tool-search-tool-2025-10-19" in beta_set:
            beta_set.add("tool-examples-2025-10-29")
        if beta_set:
            anthropic_messages_request["anthropic_beta"] = list(beta_set)
        return anthropic_messages_request
    def get_async_streaming_response_iterator(
        self,
        model: str,
        httpx_response: httpx.Response,
        request_body: dict,
        litellm_logging_obj: LiteLLMLoggingObj,
    ) -> AsyncIterator:
        """
        Wrap a streaming Bedrock Invoke HTTP response in an async iterator
        that yields SSE-formatted Anthropic /v1/messages events.
        """
        # Decode the raw AWS event-stream bytes into Anthropic-format chunks.
        aws_decoder = AmazonAnthropicClaudeMessagesStreamDecoder(
            model=model,
        )
        completion_stream = aws_decoder.aiter_bytes(
            httpx_response.aiter_bytes(chunk_size=aws_decoder.DEFAULT_CHUNK_SIZE)
        )
        # Convert decoded Bedrock events to Server-Sent Events expected by Anthropic clients.
        return self.bedrock_sse_wrapper(
            completion_stream=completion_stream,
            litellm_logging_obj=litellm_logging_obj,
            request_body=request_body,
        )
    async def bedrock_sse_wrapper(
        self,
        completion_stream: AsyncIterator[
            Union[bytes, GenericStreamingChunk, ModelResponseStream, dict]
        ],
        litellm_logging_obj: LiteLLMLoggingObj,
        request_body: dict,
    ):
        """
        Bedrock invoke does not return SSE formatted data. This async generator
        wraps the decoded chunk stream so clients receive properly SSE-formatted
        Anthropic /v1/messages events.
        """
        from litellm.llms.anthropic.experimental_pass_through.messages.streaming_iterator import (
            BaseAnthropicMessagesStreamingIterator,
        )

        # Delegate SSE formatting (and stream logging) to the shared Anthropic
        # streaming iterator implementation.
        handler = BaseAnthropicMessagesStreamingIterator(
            litellm_logging_obj=litellm_logging_obj,
            request_body=request_body,
        )
        async for chunk in handler.async_sse_wrapper(completion_stream):
            yield chunk
class AmazonAnthropicClaudeMessagesStreamDecoder(AWSEventStreamDecoder):
    """
    Decode Bedrock invoke streaming events into the Anthropic /v1/messages
    streaming format.
    """

    def __init__(
        self,
        model: str,
    ) -> None:
        """
        Iterator to return Bedrock invoke response in anthropic /messages format
        """
        super().__init__(model=model)
        # Chunk size used when iterating the raw HTTP byte stream.
        self.DEFAULT_CHUNK_SIZE = 1024

    def _chunk_parser(
        self, chunk_data: dict
    ) -> Union[GChunk, ModelResponseStream, dict]:
        """
        Normalize a decoded Bedrock event into the Anthropic /v1/messages shape.

        Bedrock attaches usage metrics under ``amazon-bedrock-invocationMetrics``
        using camelCase keys; remap them to the snake_case ``usage`` block the
        Anthropic spec uses so streaming callers get a consistent shape.
        """
        invocation_metrics = chunk_data.pop("amazon-bedrock-invocationMetrics", {})
        if invocation_metrics:
            usage = {}
            token_key_map = {
                "inputTokenCount": "input_tokens",
                "outputTokenCount": "output_tokens",
            }
            for bedrock_key, anthropic_key in token_key_map.items():
                if bedrock_key in invocation_metrics:
                    usage[anthropic_key] = invocation_metrics[bedrock_key]
            chunk_data["usage"] = usage
        return chunk_data

View File

@@ -0,0 +1,3 @@
# /v1/messages
This folder contains the transformation logic for calling Bedrock models using the Anthropic `/v1/messages` API spec.

View File

@@ -0,0 +1,249 @@
import json
from typing import TYPE_CHECKING, List, Optional, Tuple, cast
from httpx import Response
from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.llms.base_llm.passthrough.transformation import BasePassthroughConfig
from ..base_aws_llm import BaseAWSLLM
from ..common_utils import BedrockEventStreamDecoderBase, BedrockModelInfo
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.types.utils import CostResponseTypes
if TYPE_CHECKING:
from httpx import URL
class BedrockPassthroughConfig(
BaseAWSLLM, BedrockModelInfo, BedrockEventStreamDecoderBase, BasePassthroughConfig
):
def is_streaming_request(self, endpoint: str, request_data: dict) -> bool:
return "stream" in endpoint
def _encode_model_id_for_endpoint(self, model_id: str) -> str:
"""
Encode model_id (especially ARNs) for use in Bedrock endpoints.
ARNs contain special characters like colons and slashes that need to be
properly URL-encoded when used in HTTP request paths. For example:
arn:aws:bedrock:us-east-1:123:application-inference-profile/abc123
becomes:
arn:aws:bedrock:us-east-1:123:application-inference-profile%2Fabc123
Args:
model_id: The model ID or ARN to encode
Returns:
The encoded model_id suitable for use in endpoint URLs
"""
from litellm.passthrough.utils import CommonUtils
import re
# Create a temporary endpoint with the model_id to check if encoding is needed
temp_endpoint = f"/model/{model_id}/converse"
encoded_temp_endpoint = CommonUtils.encode_bedrock_runtime_modelid_arn(
temp_endpoint
)
# Extract the encoded model_id from the temporary endpoint
encoded_model_id_match = re.search(r"/model/([^/]+)/", encoded_temp_endpoint)
if encoded_model_id_match:
return encoded_model_id_match.group(1)
else:
# Fallback to original model_id if extraction fails
return model_id
    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        endpoint: str,
        request_query_params: Optional[dict],
        litellm_params: dict,
    ) -> Tuple["URL", str]:
        """
        Resolve the full Bedrock runtime URL for a passthrough request.

        Returns:
            Tuple of (formatted request URL including query params, base
            runtime endpoint URL).
        """
        # Copy so region/endpoint lookups can pop keys without mutating caller state.
        optional_params = litellm_params.copy()
        model_id = optional_params.get("model_id", None)
        aws_region_name = self._get_aws_region_name(
            optional_params=optional_params,
            model=model,
            model_id=model_id,
        )
        aws_bedrock_runtime_endpoint = optional_params.get(
            "aws_bedrock_runtime_endpoint"
        )
        endpoint_url, _ = self.get_runtime_endpoint(
            api_base=api_base,
            aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
            aws_region_name=aws_region_name,
            endpoint_type="runtime",
        )
        # If model_id is provided (e.g., Application Inference Profile ARN), use it in the endpoint
        # instead of the translated model name
        if model_id is not None:
            import re

            # Encode the model_id if it's an ARN to properly handle special characters
            encoded_model_id = self._encode_model_id_for_endpoint(model_id)
            # Replace the model name in the endpoint with the encoded model_id
            endpoint = re.sub(r"model/[^/]+/", f"model/{encoded_model_id}/", endpoint)
        return (
            self.format_url(endpoint, endpoint_url, request_query_params or {}),
            endpoint_url,
        )
    def sign_request(
        self,
        headers: dict,
        litellm_params: dict,
        request_data: Optional[dict],
        api_base: str,
        model: Optional[str] = None,
    ) -> Tuple[dict, Optional[bytes]]:
        """SigV4-sign the passthrough request for the ``bedrock`` service."""
        # Copy so the shared signer can pop aws_* params without mutating caller state.
        optional_params = litellm_params.copy()
        return self._sign_request(
            service_name="bedrock",
            headers=headers,
            optional_params=optional_params,
            request_data=request_data or {},
            api_base=api_base,
            model=model,
        )
    def logging_non_streaming_response(
        self,
        model: str,
        custom_llm_provider: str,
        httpx_response: Response,
        request_data: dict,
        logging_obj: Logging,
        endpoint: str,
    ) -> Optional["CostResponseTypes"]:
        """
        Transform a non-streaming passthrough response into a ModelResponse so
        it can be logged/cost-tracked.

        Returns None when the endpoint is neither an invoke nor a converse
        call (nothing to transform).
        """
        from litellm import encoding
        from litellm.types.utils import LlmProviders, ModelResponse
        from litellm.utils import ProviderConfigManager

        # The chat config lookup is route-aware: prefix the model with the
        # Bedrock route ("invoke/" or "converse/") taken from the endpoint path.
        if "invoke" in endpoint:
            chat_config_model = "invoke/" + model
        elif "converse" in endpoint:
            chat_config_model = "converse/" + model
        else:
            return None
        provider_chat_config = ProviderConfigManager.get_provider_chat_config(
            provider=LlmProviders(custom_llm_provider),
            model=chat_config_model,
        )
        if provider_chat_config is None:
            raise ValueError(f"No provider config found for model: {model}")
        # Passthrough calls have no litellm-side message history; use a stub message.
        litellm_model_response: ModelResponse = provider_chat_config.transform_response(
            model=model,
            messages=[{"role": "user", "content": "no-message-pass-through-endpoint"}],
            raw_response=httpx_response,
            model_response=ModelResponse(),
            logging_obj=logging_obj,
            optional_params={},
            litellm_params={},
            api_key="",
            request_data=request_data,
            encoding=encoding,
        )
        return litellm_model_response
def _convert_raw_bytes_to_str_lines(self, raw_bytes: List[bytes]) -> List[str]:
from botocore.eventstream import EventStreamBuffer
all_chunks = []
event_stream_buffer = EventStreamBuffer()
for chunk in raw_bytes:
event_stream_buffer.add_data(chunk)
for event in event_stream_buffer:
message = self._parse_message_from_event(event)
if message is not None:
all_chunks.append(message)
return all_chunks
    def handle_logging_collected_chunks(
        self,
        all_chunks: List[str],
        litellm_logging_obj: "LiteLLMLoggingObj",
        model: str,
        custom_llm_provider: str,
        endpoint: str,
    ) -> Optional["CostResponseTypes"]:
        """
        Rebuild a complete ModelResponse from the streaming chunks collected
        during a passthrough call, for logging/cost tracking.

        Steps:
        1. Pick the right event-stream decoder for the route (invoke vs converse).
        2. Parse every collected chunk into a ModelResponseStream.
        3. Combine the streams via stream_chunk_builder and return the result.

        Returns None when the endpoint is neither invoke nor converse, or when
        no chunk could be translated.
        """
        from litellm.litellm_core_utils.streaming_handler import (
            convert_generic_chunk_to_model_response_stream,
            generic_chunk_has_all_required_fields,
        )
        from litellm.llms.bedrock.chat import get_bedrock_event_stream_decoder
        from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
            AmazonInvokeConfig,
        )
        from litellm.main import stream_chunk_builder
        from litellm.types.utils import GenericStreamingChunk, ModelResponseStream

        all_translated_chunks = []
        # Invoke routes need a provider-specific decoder; converse uses a generic one.
        if "invoke" in endpoint:
            invoke_provider = AmazonInvokeConfig.get_bedrock_invoke_provider(model)
            if invoke_provider is None:
                raise ValueError(
                    f"Invalid invoke provider: {invoke_provider}, for model: {model}"
                )
            obj = get_bedrock_event_stream_decoder(
                invoke_provider=invoke_provider,
                model=model,
                sync_stream=True,
                json_mode=False,
            )
        elif "converse" in endpoint:
            obj = get_bedrock_event_stream_decoder(
                invoke_provider=None,
                model=model,
                sync_stream=True,
                json_mode=False,
            )
        else:
            return None
        for chunk in all_chunks:
            message = json.loads(chunk)
            translated_chunk = obj._chunk_parser(chunk_data=message)
            # Decoders may return either a generic dict chunk or a typed
            # ModelResponseStream; normalize both, skip anything else.
            if isinstance(
                translated_chunk, dict
            ) and generic_chunk_has_all_required_fields(cast(dict, translated_chunk)):
                chunk_obj = convert_generic_chunk_to_model_response_stream(
                    cast(GenericStreamingChunk, translated_chunk)
                )
            elif isinstance(translated_chunk, ModelResponseStream):
                chunk_obj = translated_chunk
            else:
                continue
            all_translated_chunks.append(chunk_obj)
        if len(all_translated_chunks) > 0:
            model_response = stream_chunk_builder(
                chunks=all_translated_chunks,
                logging_obj=litellm_logging_obj,
            )
            return model_response
        return None

View File

@@ -0,0 +1,307 @@
"""
This file contains the handler for AWS Bedrock Nova Sonic realtime API.
This uses aws_sdk_bedrock_runtime for bidirectional streaming with Nova Sonic.
"""
import asyncio
import json
from typing import Any, Optional
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from ..base_aws_llm import BaseAWSLLM
from .transformation import BedrockRealtimeConfig
class BedrockRealtime(BaseAWSLLM):
"""Handler for Bedrock Nova Sonic realtime speech-to-speech API."""
    def __init__(self):
        # No realtime-specific state; rely on BaseAWSLLM initialization.
        super().__init__()
    async def async_realtime(
        self,
        model: str,
        websocket: Any,
        logging_obj: LiteLLMLogging,
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        timeout: Optional[float] = None,
        aws_region_name: Optional[str] = None,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
        aws_role_name: Optional[str] = None,
        aws_session_name: Optional[str] = None,
        aws_profile_name: Optional[str] = None,
        aws_web_identity_token: Optional[str] = None,
        aws_sts_endpoint: Optional[str] = None,
        aws_bedrock_runtime_endpoint: Optional[str] = None,
        aws_external_id: Optional[str] = None,
        **kwargs,
    ):
        """
        Establish bidirectional streaming connection with Bedrock Nova Sonic.

        Runs two concurrent tasks (client->Bedrock and Bedrock->client) until
        either side disconnects; closes the client websocket on failure.

        Args:
            model: Model ID (e.g., 'amazon.nova-sonic-v1:0')
            websocket: Client WebSocket connection
            logging_obj: LiteLLM logging object
            aws_region_name: AWS region
            Various AWS authentication parameters
        """
        try:
            from aws_sdk_bedrock_runtime.client import (
                BedrockRuntimeClient,
                InvokeModelWithBidirectionalStreamOperationInput,
            )
            from aws_sdk_bedrock_runtime.config import Config
            from smithy_aws_core.identity.environment import (
                EnvironmentCredentialsResolver,
            )
        except ImportError:
            raise ImportError(
                "Missing aws_sdk_bedrock_runtime. Install with: pip install aws-sdk-bedrock-runtime"
            )
        # Get AWS region
        if aws_region_name is None:
            optional_params = {
                "aws_region_name": aws_region_name,
            }
            aws_region_name = self._get_aws_region_name(optional_params, model)
        # Get endpoint URL: explicit api_base wins, then the runtime endpoint
        # override, then the default regional Bedrock runtime host.
        if api_base is not None:
            endpoint_uri = api_base
        elif aws_bedrock_runtime_endpoint is not None:
            endpoint_uri = aws_bedrock_runtime_endpoint
        else:
            endpoint_uri = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
        verbose_proxy_logger.debug(
            f"Bedrock Realtime: Connecting to {endpoint_uri} with model {model}"
        )
        # Initialize Bedrock client with aws_sdk_bedrock_runtime.
        # NOTE(review): only environment-based credentials are resolved here —
        # the explicit aws_access_key_id/aws_role_name/... parameters accepted
        # above are not wired into this resolver. Confirm intended.
        config = Config(
            endpoint_uri=endpoint_uri,
            region=aws_region_name,
            aws_credentials_identity_resolver=EnvironmentCredentialsResolver(),
        )
        bedrock_client = BedrockRuntimeClient(config=config)
        transformation_config = BedrockRealtimeConfig()
        try:
            # Initialize the bidirectional stream
            bedrock_stream = (
                await bedrock_client.invoke_model_with_bidirectional_stream(
                    InvokeModelWithBidirectionalStreamOperationInput(model_id=model)
                )
            )
            verbose_proxy_logger.debug(
                "Bedrock Realtime: Bidirectional stream established"
            )
            # Track state for transformation (shared by both forwarding tasks).
            session_state = {
                "current_output_item_id": None,
                "current_response_id": None,
                "current_conversation_id": None,
                "current_delta_chunks": None,
                "current_item_chunks": None,
                "current_delta_type": None,
                "session_configuration_request": None,
            }
            # Create tasks for bidirectional forwarding
            client_to_bedrock_task = asyncio.create_task(
                self._forward_client_to_bedrock(
                    websocket,
                    bedrock_stream,
                    transformation_config,
                    model,
                    session_state,
                )
            )
            bedrock_to_client_task = asyncio.create_task(
                self._forward_bedrock_to_client(
                    bedrock_stream,
                    websocket,
                    transformation_config,
                    model,
                    logging_obj,
                    session_state,
                )
            )
            # Wait for both tasks to complete (exceptions are collected, not raised).
            await asyncio.gather(
                client_to_bedrock_task,
                bedrock_to_client_task,
                return_exceptions=True,
            )
        except Exception as e:
            verbose_proxy_logger.exception(
                f"Error in BedrockRealtime.async_realtime: {e}"
            )
            # Best-effort close of the client socket with an internal-error code.
            try:
                await websocket.close(code=1011, reason=f"Internal error: {str(e)}")
            except Exception:
                pass
            raise
    async def _forward_client_to_bedrock(
        self,
        client_ws: Any,
        bedrock_stream: Any,
        transformation_config: BedrockRealtimeConfig,
        model: str,
        session_state: dict,
    ):
        """
        Forward messages from the client WebSocket to the Bedrock stream.

        Loops until the client disconnects (surfaced as an exception from
        receive_text), then closes Bedrock's input stream so the model side
        can finish.
        """
        try:
            from aws_sdk_bedrock_runtime.models import (
                BidirectionalInputPayloadPart,
                InvokeModelWithBidirectionalStreamInputChunk,
            )

            while True:
                # Receive message from client
                message = await client_ws.receive_text()
                verbose_proxy_logger.debug(
                    f"Bedrock Realtime: Received from client: {message[:200]}"
                )
                # Transform OpenAI format to Bedrock format (one client message
                # can expand to several Bedrock events).
                transformed_messages = transformation_config.transform_realtime_request(
                    message=message,
                    model=model,
                    session_configuration_request=session_state.get(
                        "session_configuration_request"
                    ),
                )
                # Send transformed messages to Bedrock
                for bedrock_message in transformed_messages:
                    event = InvokeModelWithBidirectionalStreamInputChunk(
                        value=BidirectionalInputPayloadPart(
                            bytes_=bedrock_message.encode("utf-8")
                        )
                    )
                    await bedrock_stream.input_stream.send(event)
                    verbose_proxy_logger.debug(
                        f"Bedrock Realtime: Sent to Bedrock: {bedrock_message[:200]}"
                    )
        except Exception as e:
            # Normal client disconnects land here too, hence debug (not error).
            verbose_proxy_logger.debug(
                f"Client to Bedrock forwarding ended: {e}", exc_info=True
            )
            # Close the Bedrock stream input
            try:
                await bedrock_stream.input_stream.close()
            except Exception:
                pass
async def _forward_bedrock_to_client(
    self,
    bedrock_stream: Any,
    client_ws: Any,
    transformation_config: BedrockRealtimeConfig,
    model: str,
    logging_obj: LiteLLMLogging,
    session_state: dict,
):
    """Forward messages from Bedrock stream to client WebSocket.

    Runs until the Bedrock stream ends or raises. Each received chunk is
    transformed into zero or more OpenAI realtime events and sent to the
    client. `session_state` is a mutable dict shared with the
    client->Bedrock task: it is read as input to every transform and
    overwritten with the transform's returned state afterwards.
    """
    try:
        while True:
            # Receive from Bedrock
            output = await bedrock_stream.await_output()
            result = await output[1].receive()
            # Chunks without payload bytes are skipped (loop continues).
            if result.value and result.value.bytes_:
                bedrock_response = result.value.bytes_.decode("utf-8")
                verbose_proxy_logger.debug(
                    f"Bedrock Realtime: Received from Bedrock: {bedrock_response[:200]}"
                )
                # Transform Bedrock format to OpenAI format
                from litellm.types.realtime import RealtimeResponseTransformInput

                # Snapshot the shared session state as input for this transform.
                realtime_response_transform_input: RealtimeResponseTransformInput = {
                    "current_output_item_id": session_state.get(
                        "current_output_item_id"
                    ),
                    "current_response_id": session_state.get("current_response_id"),
                    "current_conversation_id": session_state.get(
                        "current_conversation_id"
                    ),
                    "current_delta_chunks": session_state.get(
                        "current_delta_chunks"
                    ),
                    "current_item_chunks": session_state.get("current_item_chunks"),
                    "current_delta_type": session_state.get("current_delta_type"),
                    "session_configuration_request": session_state.get(
                        "session_configuration_request"
                    ),
                }
                transformed_response = transformation_config.transform_realtime_response(
                    message=bedrock_response,
                    model=model,
                    logging_obj=logging_obj,
                    realtime_response_transform_input=realtime_response_transform_input,
                )
                # Update session state with the state returned by the transform,
                # so the next iteration (and the sibling task) sees it.
                session_state.update(
                    {
                        "current_output_item_id": transformed_response.get(
                            "current_output_item_id"
                        ),
                        "current_response_id": transformed_response.get(
                            "current_response_id"
                        ),
                        "current_conversation_id": transformed_response.get(
                            "current_conversation_id"
                        ),
                        "current_delta_chunks": transformed_response.get(
                            "current_delta_chunks"
                        ),
                        "current_item_chunks": transformed_response.get(
                            "current_item_chunks"
                        ),
                        "current_delta_type": transformed_response.get(
                            "current_delta_type"
                        ),
                        "session_configuration_request": transformed_response.get(
                            "session_configuration_request"
                        ),
                    }
                )
                # Send transformed messages to client (one transform may yield many).
                openai_messages = transformed_response.get("response", [])
                for openai_message in openai_messages:
                    message_json = json.dumps(openai_message)
                    await client_ws.send_text(message_json)
                    verbose_proxy_logger.debug(
                        f"Bedrock Realtime: Sent to client: {message_json[:200]}"
                    )
    except Exception as e:
        # Stream end / disconnects land here; debug-level by design.
        verbose_proxy_logger.debug(
            f"Bedrock to client forwarding ended: {e}", exc_info=True
        )
        # Close the client WebSocket (best effort).
        try:
            await client_ws.close()
        except Exception:
            pass

View File

@@ -0,0 +1,179 @@
import json
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
import httpx
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_httpx_client,
get_async_httpx_client,
)
from litellm.types.llms.bedrock import BedrockPreparedRequest
from litellm.types.rerank import RerankRequest
from litellm.types.utils import RerankResponse
from ..base_aws_llm import BaseAWSLLM
from ..common_utils import BedrockError
from .transformation import BedrockRerankConfig
if TYPE_CHECKING:
from botocore.awsrequest import AWSPreparedRequest
else:
AWSPreparedRequest = Any
class BedrockRerankHandler(BaseAWSLLM):
    """Calls AWS Bedrock's agent-runtime `/rerank` endpoint (sync and async) with SigV4 auth."""

    async def arerank(
        self,
        prepared_request: BedrockPreparedRequest,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        client: Optional[AsyncHTTPHandler] = None,
    ):
        """
        POST an already-signed rerank request asynchronously.

        Args:
            prepared_request: Output of `_prepare_request` (endpoint URL, signed
                headers, and encoded body).
            timeout: Optional httpx timeout.
            client: Optional async client; a Bedrock-scoped one is created when None.

        Returns:
            RerankResponse parsed from Bedrock's JSON payload.

        Raises:
            BedrockError: on non-2xx responses (original status code) or timeout (408).
        """
        if client is None:
            client = get_async_httpx_client(llm_provider=litellm.LlmProviders.BEDROCK)
        try:
            response = await client.post(
                url=prepared_request["endpoint_url"],
                headers=dict(prepared_request["prepped"].headers),
                data=prepared_request["body"],
                timeout=timeout,
            )
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise BedrockError(status_code=error_code, message=err.response.text)
        except httpx.TimeoutException:
            raise BedrockError(status_code=408, message="Timeout error occurred.")
        return BedrockRerankConfig()._transform_response(response.json())

    def rerank(
        self,
        model: str,
        query: str,
        documents: List[Union[str, Dict[str, Any]]],
        optional_params: dict,
        logging_obj: LitellmLogging,
        top_n: Optional[int] = None,
        rank_fields: Optional[List[str]] = None,
        return_documents: Optional[bool] = True,
        max_chunks_per_doc: Optional[int] = None,
        _is_async: Optional[bool] = False,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        api_base: Optional[str] = None,
        extra_headers: Optional[dict] = None,
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
    ) -> RerankResponse:
        """
        Entry point for Bedrock rerank.

        Builds a Cohere-style request, transforms it to Bedrock's format,
        SigV4-signs it, then dispatches sync or async depending on `_is_async`.
        When `_is_async` is True, a coroutine is returned (hence the
        `type: ignore` below).

        NOTE(review): `max_chunks_per_doc` is accepted for interface parity but
        is not forwarded into the request body.
        """
        request_data = RerankRequest(
            model=model,
            query=query,
            documents=documents,
            top_n=top_n,
            rank_fields=rank_fields,
            return_documents=return_documents,
        )
        data = BedrockRerankConfig()._transform_request(request_data)
        prepared_request = self._prepare_request(
            model=model,
            optional_params=optional_params,
            api_base=api_base,
            extra_headers=extra_headers,
            data=cast(dict, data),
        )
        # Log the outbound request before sending.
        logging_obj.pre_call(
            input=data,
            api_key="",
            additional_args={
                "complete_input_dict": data,
                "api_base": prepared_request["endpoint_url"],
                "headers": dict(prepared_request["prepped"].headers),
            },
        )
        if _is_async:
            # Only reuse the caller's client when it is actually async-capable.
            return self.arerank(prepared_request, timeout=timeout, client=client if client is not None and isinstance(client, AsyncHTTPHandler) else None)  # type: ignore
        if client is None or not isinstance(client, HTTPHandler):
            client = _get_httpx_client()
        try:
            response = client.post(
                url=prepared_request["endpoint_url"],
                headers=dict(prepared_request["prepped"].headers),
                data=prepared_request["body"],
                timeout=timeout,
            )
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise BedrockError(status_code=error_code, message=err.response.text)
        except httpx.TimeoutException:
            raise BedrockError(status_code=408, message="Timeout error occurred.")
        logging_obj.post_call(
            original_response=response.text,
            api_key="",
        )
        response_json = response.json()
        return BedrockRerankConfig()._transform_response(response_json)

    def _prepare_request(
        self,
        model: str,
        api_base: Optional[str],
        extra_headers: Optional[dict],
        data: dict,
        optional_params: dict,
    ) -> BedrockPreparedRequest:
        """
        Resolve the rerank endpoint and SigV4-sign the JSON body.

        Returns:
            BedrockPreparedRequest with the endpoint URL, the prepared (signed)
            request, the encoded body, and the original payload dict.

        Raises:
            ImportError: when boto3/botocore is not installed.
        """
        try:
            from botocore.auth import SigV4Auth
            from botocore.awsrequest import AWSRequest
        except ImportError:
            raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
        boto3_credentials_info = self._get_boto_credentials_from_optional_params(
            optional_params, model
        )

        ### SET RUNTIME ENDPOINT ###
        _, proxy_endpoint_url = self.get_runtime_endpoint(
            api_base=api_base,
            aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint,
            aws_region_name=boto3_credentials_info.aws_region_name,
        )
        # Rerank is served by the agent-runtime service, not bedrock-runtime.
        proxy_endpoint_url = proxy_endpoint_url.replace(
            "bedrock-runtime", "bedrock-agent-runtime"
        )
        proxy_endpoint_url = f"{proxy_endpoint_url}/rerank"

        sigv4 = SigV4Auth(
            boto3_credentials_info.credentials,
            "bedrock",
            boto3_credentials_info.aws_region_name,
        )

        # Make POST Request
        body = json.dumps(data).encode("utf-8")

        headers = {"Content-Type": "application/json"}
        if extra_headers is not None:
            headers = {"Content-Type": "application/json", **extra_headers}

        request = AWSRequest(
            method="POST", url=proxy_endpoint_url, data=body, headers=headers
        )
        sigv4.add_auth(request)
        if (
            extra_headers is not None and "Authorization" in extra_headers
        ):  # prevent sigv4 from overwriting the auth header
            request.headers["Authorization"] = extra_headers["Authorization"]

        prepped = request.prepare()
        return BedrockPreparedRequest(
            endpoint_url=proxy_endpoint_url,
            prepped=prepped,
            body=body,
            data=data,
        )

View File

@@ -0,0 +1,118 @@
"""
Translates from Cohere's `/v1/rerank` input format to Bedrock's `/rerank` input format.
Why separate file? Make it easy to see how transformation works
"""
from litellm._uuid import uuid
from typing import List, Optional, Union
from litellm.types.llms.bedrock import (
BedrockRerankBedrockRerankingConfiguration,
BedrockRerankConfiguration,
BedrockRerankInlineDocumentSource,
BedrockRerankModelConfiguration,
BedrockRerankQuery,
BedrockRerankRequest,
BedrockRerankSource,
BedrockRerankTextDocument,
BedrockRerankTextQuery,
)
from litellm.types.rerank import (
RerankBilledUnits,
RerankRequest,
RerankResponse,
RerankResponseMeta,
RerankResponseResult,
RerankTokens,
)
class BedrockRerankConfig:
    """Maps Cohere-style rerank requests/responses to Bedrock's `/rerank` API shapes."""

    def _transform_sources(
        self, documents: List[Union[str, dict]]
    ) -> List[BedrockRerankSource]:
        """
        Transform the sources from RerankRequest format to Bedrock format.

        Strings become inline TEXT documents; dicts become inline JSON documents.
        """

        def _to_source(document: Union[str, dict]) -> BedrockRerankSource:
            if isinstance(document, str):
                inline_source = BedrockRerankInlineDocumentSource(
                    textDocument=BedrockRerankTextDocument(text=document),
                    type="TEXT",
                )
            else:
                inline_source = BedrockRerankInlineDocumentSource(
                    jsonDocument=document, type="JSON"
                )
            return BedrockRerankSource(
                inlineDocumentSource=inline_source, type="INLINE"
            )

        return [_to_source(document) for document in documents]

    def _transform_request(self, request_data: RerankRequest) -> BedrockRerankRequest:
        """
        Transform the request from RerankRequest format to Bedrock format.
        """
        rerank_query = BedrockRerankQuery(
            textQuery=BedrockRerankTextQuery(text=request_data.query),
            type="TEXT",
        )
        # Default to ranking every document when top_n is not given.
        number_of_results = request_data.top_n or len(request_data.documents)
        bedrock_reranking_config = BedrockRerankBedrockRerankingConfiguration(
            modelConfiguration=BedrockRerankModelConfiguration(
                modelArn=request_data.model
            ),
            numberOfResults=number_of_results,
        )
        return BedrockRerankRequest(
            queries=[rerank_query],
            rerankingConfiguration=BedrockRerankConfiguration(
                bedrockRerankingConfiguration=bedrock_reranking_config,
                type="BEDROCK_RERANKING_MODEL",
            ),
            sources=self._transform_sources(request_data.documents),
        )

    def _transform_response(self, response: dict) -> RerankResponse:
        """
        Transform the response from Bedrock into the RerankResponse format.

        example input:
        {"results":[{"index":0,"relevanceScore":0.6847912669181824},{"index":1,"relevanceScore":0.5980774760246277}]}
        """
        # Bedrock rerank bills per search unit; assume 1 when usage is absent.
        billed_units = RerankBilledUnits(**response.get("usage", {"search_units": 1}))
        token_usage = RerankTokens(**response.get("usage", {}))
        response_meta = RerankResponseMeta(
            billed_units=billed_units, tokens=token_usage
        )

        parsed_results: Optional[List[RerankResponseResult]] = None
        raw_results = response.get("results")
        if raw_results:
            parsed_results = [
                RerankResponseResult(
                    index=entry.get("index"),
                    relevance_score=entry.get("relevanceScore"),
                )
                for entry in raw_results
            ]
        if parsed_results is None:
            raise ValueError(f"No results found in the response={response}")

        return RerankResponse(
            id=response.get("id") or str(uuid.uuid4()),
            results=parsed_results,
            meta=response_meta,
        )  # Return response

View File

@@ -0,0 +1,356 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from urllib.parse import urlparse
import httpx
from litellm.llms.base_llm.vector_store.transformation import BaseVectorStoreConfig
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
from litellm.types.integrations.rag.bedrock_knowledgebase import (
BedrockKBContent,
BedrockKBResponse,
BedrockKBRetrievalConfiguration,
BedrockKBRetrievalQuery,
)
from litellm.types.router import GenericLiteLLMParams
from litellm.types.vector_stores import (
BaseVectorStoreAuthCredentials,
VectorStoreIndexEndpoints,
VECTOR_STORE_OPENAI_PARAMS,
VectorStoreResultContent,
VectorStoreSearchOptionalRequestParams,
VectorStoreSearchResponse,
VectorStoreSearchResult,
)
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class BedrockVectorStoreConfig(BaseVectorStoreConfig, BaseAWSLLM):
    """Vector store configuration for AWS Bedrock Knowledge Bases."""

    def __init__(self) -> None:
        # Initialize both bases explicitly (rather than via super()) so each
        # side of the multiple inheritance runs its own __init__ regardless
        # of MRO cooperation.
        BaseVectorStoreConfig.__init__(self)
        BaseAWSLLM.__init__(self)
def get_auth_credentials(
    self, litellm_params: dict
) -> BaseVectorStoreAuthCredentials:
    """No credential headers here: Bedrock KB auth happens via SigV4 request signing."""
    credentials: BaseVectorStoreAuthCredentials = {}
    return credentials
def get_vector_store_endpoints_by_type(self) -> VectorStoreIndexEndpoints:
    """Expose only the KB retrieve endpoint for reads; no write endpoints exist here."""
    read_endpoints = [
        ("POST", "/knowledgebases/{knowledge_base_id}/retrieve"),
    ]
    return {"read": read_endpoints, "write": []}
def get_supported_openai_params(
    self, model: str
) -> List[VECTOR_STORE_OPENAI_PARAMS]:
    """Return the OpenAI-style search parameters this provider can map."""
    supported_params: List[VECTOR_STORE_OPENAI_PARAMS] = [
        "filters",
        "max_num_results",
        "ranking_options",
    ]
    return supported_params
def _map_operator_to_aws(self, operator: str) -> str:
    """
    Map OpenAI-style operators to AWS Bedrock operator names.

    OpenAI uses: eq, ne, gt, gte, lt, lte, in, nin
    AWS uses: equals, notEquals, greaterThan, greaterThanOrEquals, lessThan,
    lessThanOrEquals, in, notIn, startsWith, listContains, stringContains

    AWS-native operator names (and anything unrecognized) pass through
    unchanged, which subsumes the identity entries for equals/notEquals/etc.
    """
    openai_to_aws = {
        "eq": "equals",
        "ne": "notEquals",
        "gt": "greaterThan",
        "gte": "greaterThanOrEquals",
        "lt": "lessThan",
        "lte": "lessThanOrEquals",
        "in": "in",
        "nin": "notIn",
    }
    return openai_to_aws.get(operator, operator)
def _map_operator_filter(self, filter_dict: dict) -> dict:
    """
    Map a single OpenAI operator filter to AWS KB format.

    OpenAI format: {"key": <key>, "value": <value>, "operator": <operator>}
    AWS KB format: {"operator": {"key": <key>, "value": <value>}}
    """
    key_and_value = {
        "key": filter_dict["key"],
        "value": filter_dict["value"],
    }
    return {self._map_operator_to_aws(filter_dict["operator"]): key_and_value}
def _map_and_or_filters(self, value: dict) -> dict:
    """
    Map OpenAI and/or filters to AWS KB format.

    OpenAI format: {"and" | "or": [{"key": <key>, "value": <value>, "operator": <operator>}]}
    AWS KB format: {"andAll" | "orAll": [{"operator": {"key": <key>, "value": <value>}}]}

    Note: AWS requires andAll/orAll to have at least 2 elements.
    For single filters, unwrap and return just the operator.
    """
    aws_filters: dict = {}
    # Fix: reuse the shared `_map_operator_filter` helper instead of
    # duplicating the operator/key/value expansion (previously repeated
    # inline for both branches via index-based loops).
    if "and" in value:
        and_filters = value["and"]
        # If only 1 filter, return just the operator (AWS requires andAll to have >=2 elements)
        if len(and_filters) == 1:
            return self._map_operator_filter(and_filters[0])
        aws_filters["andAll"] = [
            self._map_operator_filter(sub_filter) for sub_filter in and_filters
        ]
    if "or" in value:
        or_filters = value["or"]
        # If only 1 filter, return just the operator (AWS requires orAll to have >=2 elements)
        if len(or_filters) == 1:
            return self._map_operator_filter(or_filters[0])
        aws_filters["orAll"] = [
            self._map_operator_filter(sub_filter) for sub_filter in or_filters
        ]
    return aws_filters
def map_openai_params(
    self,
    non_default_params: dict,
    optional_params: dict,
    drop_params: bool,
) -> dict:
    """
    Translate OpenAI-style vector store search params into Bedrock KB params.

    - ``max_num_results`` -> ``numberOfResults``
    - ``filters``: single-operator and and/or filters are converted to AWS KB
      filter syntax; other dicts are assumed to already be in AWS KB format;
      non-dict values map to ``None``.
    """
    for param_name, param_value in non_default_params.items():
        if param_name == "max_num_results":
            optional_params["numberOfResults"] = param_value
        elif param_name == "filters" and param_value is not None:
            mapped_filters: Optional[Dict] = None
            if isinstance(param_value, dict):
                if "operator" in param_value:
                    # Single operator - map directly (no wrapping needed)
                    mapped_filters = self._map_operator_filter(param_value)
                elif "and" in param_value or "or" in param_value:
                    mapped_filters = self._map_and_or_filters(param_value)
                else:
                    # Assume it's already in AWS KB format
                    mapped_filters = param_value
            optional_params["filters"] = mapped_filters
    return optional_params
def validate_environment(
    self, headers: dict, litellm_params: Optional[GenericLiteLLMParams]
) -> dict:
    """Ensure a Content-Type header is present; mutates and returns `headers`."""
    if not headers:
        headers = {}
    if "Content-Type" not in headers:
        headers["Content-Type"] = "application/json"
    return headers
def get_complete_url(self, api_base: Optional[str], litellm_params: dict) -> str:
    """Build the agent-runtime base URL for Knowledge Base calls, ending at `/knowledgebases`."""
    resolved_region = self.get_aws_region_name_for_non_llm_api_calls(
        aws_region_name=litellm_params.get("aws_region_name")
    )
    endpoint_url, _ = self.get_runtime_endpoint(
        api_base=api_base,
        aws_bedrock_runtime_endpoint=litellm_params.get(
            "aws_bedrock_runtime_endpoint"
        ),
        aws_region_name=resolved_region,
        endpoint_type="agent",
    )
    return f"{endpoint_url}/knowledgebases"
def transform_search_vector_store_request(
    self,
    vector_store_id: str,
    query: Union[str, List[str]],
    vector_store_search_optional_params: VectorStoreSearchOptionalRequestParams,
    api_base: str,
    litellm_logging_obj: LiteLLMLoggingObj,
    litellm_params: dict,
) -> Tuple[str, Dict]:
    """
    Build the Bedrock KB `retrieve` URL and request body for a search call.

    List queries are flattened into a single space-joined string. Optional
    max_num_results/filters are nested under `vectorSearchConfiguration`.
    """
    if isinstance(query, list):
        query = " ".join(query)

    request_body: Dict[str, Any] = {
        "retrievalQuery": BedrockKBRetrievalQuery(text=query),
    }

    vector_search_config: Dict[str, Any] = {}
    max_results = vector_store_search_optional_params.get("max_num_results")
    if max_results is not None:
        vector_search_config["numberOfResults"] = max_results
    filters = vector_store_search_optional_params.get("filters")
    if filters is not None:
        vector_search_config["filter"] = filters

    if vector_search_config:
        # Create a properly typed retrieval configuration
        typed_retrieval_config: BedrockKBRetrievalConfiguration = {
            "vectorSearchConfiguration": vector_search_config
        }
        request_body["retrievalConfiguration"] = typed_retrieval_config

    # Stash the flattened query so the response transform can echo it back.
    litellm_logging_obj.model_call_details["query"] = query

    return f"{api_base}/{vector_store_id}/retrieve", request_body
def sign_request(
    self,
    headers: dict,
    optional_params: Dict,
    request_data: Dict,
    api_base: str,
    api_key: Optional[str] = None,
) -> Tuple[dict, Optional[bytes]]:
    """SigV4-sign the outgoing request against the `bedrock` service."""
    signed_headers_and_body = self._sign_request(
        service_name="bedrock",
        headers=headers,
        optional_params=optional_params,
        request_data=request_data,
        api_base=api_base,
        api_key=api_key,
    )
    return signed_headers_and_body
def _get_file_id_from_metadata(self, metadata: Dict[str, Any]) -> str:
    """
    Extract file_id from Bedrock KB metadata.
    Uses the source URI if available, otherwise builds a fallback ID from
    the chunk ID (or "unknown" when metadata is empty/missing).
    """
    if metadata:
        source_uri = metadata.get("x-amz-bedrock-kb-source-uri", "")
        if source_uri:
            return source_uri
        chunk_id = metadata.get("x-amz-bedrock-kb-chunk-id", "unknown")
    else:
        chunk_id = "unknown"
    return f"bedrock-kb-{chunk_id}"
def _get_filename_from_metadata(self, metadata: Dict[str, Any]) -> str:
    """
    Extract filename from Bedrock KB metadata.
    Tries the last path segment of the source URI, falling back to the URI's
    host, and finally to an ID built from the data source ID.
    """
    source_uri = metadata.get("x-amz-bedrock-kb-source-uri", "") if metadata else ""
    if source_uri:
        try:
            parsed = urlparse(source_uri)
            if parsed.path and parsed.path != "/":
                candidate = parsed.path.split("/")[-1]
            else:
                candidate = parsed.netloc
            # A trailing slash yields an empty segment; fall back to the host.
            if not candidate or candidate == "/":
                candidate = parsed.netloc
            return candidate
        except Exception:
            return source_uri
    if metadata:
        data_source_id = metadata.get("x-amz-bedrock-kb-data-source-id", "unknown")
    else:
        data_source_id = "unknown"
    return f"bedrock-kb-document-{data_source_id}"
def _get_attributes_from_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
    """
    Extract all attributes from Bedrock KB metadata.
    Returns a shallow copy of the metadata dictionary (empty dict when falsy).
    """
    return dict(metadata) if metadata else {}
def transform_search_vector_store_response(
    self, response: httpx.Response, litellm_logging_obj: LiteLLMLoggingObj
) -> VectorStoreSearchResponse:
    """
    Convert a Bedrock KB `retrieve` response into the OpenAI-style vector
    store search response. Retrieval results without text content are skipped.
    Any parsing failure is re-raised via `get_error_class` with the original
    HTTP status and headers.
    """
    try:
        parsed_response = BedrockKBResponse(**response.json())
        search_results: List[VectorStoreSearchResult] = []
        for retrieval_item in parsed_response.get("retrievalResults", []) or []:
            content: Optional[BedrockKBContent] = retrieval_item.get("content")
            text = content.get("text") if content else None
            if text is None:
                continue
            # Derive identifiers from the KB metadata via the shared helpers.
            item_metadata = retrieval_item.get("metadata", {}) or {}
            search_results.append(
                VectorStoreSearchResult(
                    score=retrieval_item.get("score"),
                    content=[VectorStoreResultContent(text=text, type="text")],
                    file_id=self._get_file_id_from_metadata(item_metadata),
                    filename=self._get_filename_from_metadata(item_metadata),
                    attributes=self._get_attributes_from_metadata(item_metadata),
                )
            )
        return VectorStoreSearchResponse(
            object="vector_store.search_results.page",
            search_query=litellm_logging_obj.model_call_details.get("query", ""),
            data=search_results,
        )
    except Exception as e:
        raise self.get_error_class(
            error_message=str(e),
            status_code=response.status_code,
            headers=response.headers,
        )
# Vector store creation is not yet implemented
def transform_create_vector_store_request(
    self,
    vector_store_create_optional_params,
    api_base: str,
) -> Tuple[str, Dict]:
    """Not supported: creating Bedrock Knowledge Bases via this config is unimplemented."""
    raise NotImplementedError
def transform_create_vector_store_response(self, response: httpx.Response):
    """Not supported: vector store creation is not implemented for Bedrock KB."""
    raise NotImplementedError