"""
Bedrock Batches transformation config.

Translates between the OpenAI-style batch API and AWS Bedrock's
model-invocation-job API: builds signed create/retrieve requests and maps
Bedrock job responses back into ``LiteLLMBatch`` objects.
"""

import datetime
import json
import os
import time
import urllib.parse
from typing import Any, Dict, List, Literal, Optional, Tuple, Union, cast

from httpx import Headers, Response

from litellm.llms.base_llm.batches.transformation import BaseBatchesConfig
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.bedrock import (
    BedrockCreateBatchRequest,
    BedrockCreateBatchResponse,
    BedrockInputDataConfig,
    BedrockOutputDataConfig,
    BedrockS3InputDataConfig,
    BedrockS3OutputDataConfig,
)
from litellm.types.llms.openai import (
    AllMessageValues,
    CreateBatchRequest,
)
from litellm.types.utils import LiteLLMBatch, LlmProviders

from ..base_aws_llm import BaseAWSLLM
from ..common_utils import CommonBatchFilesUtils

# OpenAI-compatible batch status literals accepted by LiteLLMBatch.status.
OpenAIBatchStatus = Literal[
    "validating",
    "failed",
    "in_progress",
    "finalizing",
    "completed",
    "expired",
    "cancelling",
    "cancelled",
]

# Single source of truth for mapping Bedrock job statuses to OpenAI-compatible
# batch statuses (previously duplicated in the create- and retrieve-response
# transforms, which risked the two copies drifting apart).
BEDROCK_TO_OPENAI_BATCH_STATUS: Dict[str, str] = {
    "Submitted": "validating",
    "Validating": "validating",
    "Scheduled": "in_progress",
    "InProgress": "in_progress",
    "PartiallyCompleted": "completed",
    "Completed": "completed",
    "Failed": "failed",
    "Stopping": "cancelling",
    "Stopped": "cancelled",
    "Expired": "expired",
}


class BedrockBatchesConfig(BaseAWSLLM, BaseBatchesConfig):
    """
    Config for Bedrock Batches - handles batch job creation and management for Bedrock
    """

    def __init__(self):
        super().__init__()
        # Shared helpers for S3 URI parsing, job naming, and SigV4 signing.
        self.common_utils = CommonBatchFilesUtils()

    @property
    def custom_llm_provider(self) -> LlmProviders:
        return LlmProviders.BEDROCK

    @staticmethod
    def _map_status_to_openai(status_str: str) -> OpenAIBatchStatus:
        """Map a Bedrock job status string to the OpenAI-compatible status.

        Unknown statuses fall back to "validating" (the safest "still pending"
        interpretation).
        """
        return cast(
            OpenAIBatchStatus,
            BEDROCK_TO_OPENAI_BATCH_STATUS.get(status_str, "validating"),
        )

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """
        Validate and prepare environment for Bedrock batch requests.
        AWS credentials are handled by BaseAWSLLM.
        """
        # Add any Bedrock-specific headers if needed
        return headers

    def get_complete_batch_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: Dict,
        litellm_params: Dict,
        data: CreateBatchRequest,
    ) -> str:
        """
        Get the complete URL for Bedrock batch creation.

        Bedrock batch jobs are created via the model invocation job API.
        Format: https://bedrock.{region}.amazonaws.com/model-invocation-job
        """
        aws_region_name = self._get_aws_region_name(optional_params, model)
        return f"https://bedrock.{aws_region_name}.amazonaws.com/model-invocation-job"

    def transform_create_batch_request(
        self,
        model: str,
        create_batch_data: CreateBatchRequest,
        optional_params: dict,
        litellm_params: dict,
    ) -> Dict[str, Any]:
        """
        Transform the batch creation request to Bedrock format.

        Bedrock batch inference requires:
        - modelId: The Bedrock model ID
        - jobName: Unique name for the batch job
        - inputDataConfig: Configuration for input data (S3 location)
        - outputDataConfig: Configuration for output data (S3 location)
        - roleArn: IAM role ARN for the batch job

        Returns a pre-signed request dict ({method, url, headers, data}) that
        the HTTP handler can send directly.

        Raises:
            ValueError: if input_file_id, the IAM role ARN, or the model is missing.
        """
        # Get required parameters
        input_file_id = create_batch_data.get("input_file_id")
        if not input_file_id:
            raise ValueError("input_file_id is required for Bedrock batch creation")

        # Extract S3 information from file ID using common utility
        input_bucket, input_key = self.common_utils.parse_s3_uri(input_file_id)

        # Get output S3 configuration
        output_bucket = litellm_params.get("s3_output_bucket_name") or os.getenv(
            "AWS_S3_OUTPUT_BUCKET_NAME"
        )
        if not output_bucket:
            # Use same bucket as input if no output bucket specified
            output_bucket = input_bucket

        # Get IAM role ARN
        role_arn = (
            litellm_params.get("aws_batch_role_arn")
            or optional_params.get("aws_batch_role_arn")
            or os.getenv("AWS_BATCH_ROLE_ARN")
        )
        if not role_arn:
            raise ValueError(
                "AWS IAM role ARN is required for Bedrock batch jobs. "
                "Set 'aws_batch_role_arn' in litellm_params or AWS_BATCH_ROLE_ARN env var"
            )

        if not model:
            raise ValueError(
                "Could not determine Bedrock model ID. Please pass `model` in your request body."
            )

        # Generate job name with the correct model ID using common utility
        job_name = self.common_utils.generate_unique_job_name(model, prefix="litellm")
        output_key = f"litellm-batch-outputs/{job_name}/"

        # Build input data config
        input_data_config: BedrockInputDataConfig = {
            "s3InputDataConfig": BedrockS3InputDataConfig(
                s3Uri=f"s3://{input_bucket}/{input_key}"
            )
        }

        # Build output data config
        s3_output_config: BedrockS3OutputDataConfig = BedrockS3OutputDataConfig(
            s3Uri=f"s3://{output_bucket}/{output_key}"
        )

        # Add optional KMS encryption key ID if provided
        s3_encryption_key_id = litellm_params.get(
            "s3_encryption_key_id"
        ) or get_secret_str("AWS_S3_ENCRYPTION_KEY_ID")
        if s3_encryption_key_id:
            s3_output_config["s3EncryptionKeyId"] = s3_encryption_key_id

        output_data_config: BedrockOutputDataConfig = {
            "s3OutputDataConfig": s3_output_config
        }

        # Create Bedrock batch request with proper typing
        bedrock_request: BedrockCreateBatchRequest = {
            "modelId": model,
            "jobName": job_name,
            "inputDataConfig": input_data_config,
            "outputDataConfig": output_data_config,
            "roleArn": role_arn,
        }

        # Add optional parameters if provided
        completion_window = create_batch_data.get("completion_window")
        if completion_window:
            # Map OpenAI completion window to Bedrock timeout.
            # Only "24h" (the sole value OpenAI accepts) is mapped; Bedrock
            # expects the timeout as an integer number of hours.
            if completion_window == "24h":
                bedrock_request["timeoutDurationInHours"] = 24

        # For Bedrock, we need to return a pre-signed request with AWS auth headers.
        # Reuse get_complete_batch_url so the endpoint is built in one place.
        endpoint_url = self.get_complete_batch_url(
            api_base=None,
            api_key=None,
            model=model,
            optional_params=optional_params,
            litellm_params=litellm_params,
            data=create_batch_data,
        )
        signed_headers, signed_data = self.common_utils.sign_aws_request(
            service_name="bedrock",
            data=bedrock_request,
            endpoint_url=endpoint_url,
            optional_params=optional_params,
            method="POST",
        )

        # Return a pre-signed request format that the HTTP handler can use
        return {
            "method": "POST",
            "url": endpoint_url,
            "headers": signed_headers,
            "data": signed_data.decode("utf-8"),
        }

    def transform_create_batch_response(
        self,
        model: Optional[str],
        raw_response: Response,
        logging_obj: Any,
        litellm_params: dict,
    ) -> LiteLLMBatch:
        """
        Transform Bedrock batch creation response to LiteLLM format.

        Raises:
            ValueError: if the response body is not valid JSON.
        """
        try:
            response_data: BedrockCreateBatchResponse = raw_response.json()
        except Exception as e:
            raise ValueError(f"Failed to parse Bedrock batch response: {e}")

        # Extract information from typed Bedrock response
        job_arn = response_data.get("jobArn", "")
        status_str: str = str(response_data.get("status", "Submitted"))
        openai_status = self._map_status_to_openai(status_str)

        # Get original request data from litellm_params if available
        original_request = litellm_params.get("original_batch_request", {})

        # Create LiteLLM batch object
        return LiteLLMBatch(
            id=job_arn,  # Use ARN as the batch ID
            object="batch",
            endpoint=original_request.get("endpoint", "/v1/chat/completions"),
            errors=None,
            input_file_id=original_request.get("input_file_id", ""),
            completion_window=original_request.get("completion_window", "24h"),
            status=openai_status,
            output_file_id=None,  # Will be populated when job completes
            error_file_id=None,
            created_at=int(time.time()),
            in_progress_at=int(time.time()) if status_str == "InProgress" else None,
            expires_at=None,
            finalizing_at=None,
            completed_at=None,
            failed_at=None,
            expired_at=None,
            cancelling_at=None,
            cancelled_at=None,
            request_counts=None,
            metadata=original_request.get("metadata", {}),
        )

    def transform_retrieve_batch_request(
        self,
        batch_id: str,
        optional_params: dict,
        litellm_params: dict,
    ) -> Dict[str, Any]:
        """
        Transform batch retrieval request for Bedrock.

        Args:
            batch_id: Bedrock job ARN
            optional_params: Optional parameters
            litellm_params: LiteLLM parameters

        Returns:
            Transformed request data for Bedrock GetModelInvocationJob API

        Raises:
            ValueError: if batch_id is not a well-formed Bedrock job ARN.
        """
        # For Bedrock, batch_id should be the full job ARN
        # The GetModelInvocationJob API expects the full ARN as the identifier
        if not batch_id.startswith("arn:aws:bedrock:"):
            raise ValueError(f"Invalid batch_id format. Expected ARN, got: {batch_id}")

        # ARN format: arn:aws:bedrock:region:account:model-invocation-job/job-name
        arn_parts = batch_id.split(":")
        if len(arn_parts) < 6:
            raise ValueError(f"Invalid ARN format: {batch_id}")

        region = arn_parts[3]
        # arn_parts[5] contains "model-invocation-job/{jobId}"

        # Build the endpoint URL for GetModelInvocationJob
        # AWS API format: GET /model-invocation-job/{jobIdentifier}
        # Use the FULL ARN as jobIdentifier and URL-encode it (includes ':' and '/')
        encoded_arn = urllib.parse.quote(batch_id, safe="")
        endpoint_url = (
            f"https://bedrock.{region}.amazonaws.com/model-invocation-job/{encoded_arn}"
        )

        # Use common utility for AWS signing
        signed_headers, _ = self.common_utils.sign_aws_request(
            service_name="bedrock",
            data={},  # GET request has no body
            endpoint_url=endpoint_url,
            optional_params=optional_params,
            method="GET",
        )

        # Return pre-signed request format
        return {
            "method": "GET",
            "url": endpoint_url,
            "headers": signed_headers,
            "data": None,
        }

    def _parse_timestamps_and_status(self, response_data, status_str: str):
        """Helper to parse timestamps based on status.

        Returns a 6-tuple of optional epoch seconds:
        (created_at, in_progress_at, completed_at, failed_at, cancelled_at, expires_at)
        """

        def parse_timestamp(raw: Any) -> Optional[int]:
            """Best-effort conversion of a Bedrock timestamp to epoch seconds."""
            if raw is None:
                return None
            # AWS JSON protocols may serialize timestamps as numeric epoch
            # seconds rather than ISO-8601 strings; accept both.
            if isinstance(raw, (int, float)):
                return int(raw)
            try:
                dt = datetime.datetime.fromisoformat(str(raw).replace("Z", "+00:00"))
                return int(dt.timestamp())
            except (ValueError, TypeError):
                # Unparseable timestamps are treated as absent.
                return None

        def field_ts(field: str) -> Optional[int]:
            """Parse a timestamp field from the response, if present."""
            return parse_timestamp(response_data.get(field))

        created_at = field_ts("submitTime")

        in_progress_states = {"InProgress", "Validating", "Scheduled"}
        in_progress_at = (
            field_ts("lastModifiedTime") if status_str in in_progress_states else None
        )

        # endTime doubles as the completion/failure/cancellation moment,
        # disambiguated by the job status.
        completed_at = (
            field_ts("endTime")
            if status_str in {"Completed", "PartiallyCompleted"}
            else None
        )
        failed_at = field_ts("endTime") if status_str == "Failed" else None
        cancelled_at = field_ts("endTime") if status_str == "Stopped" else None

        expires_at = field_ts("jobExpirationTime")

        return (
            created_at,
            in_progress_at,
            completed_at,
            failed_at,
            cancelled_at,
            expires_at,
        )

    def _extract_file_configs(self, response_data) -> Tuple[str, Optional[str]]:
        """Helper to extract input and output file configurations.

        Returns (input_file_id, output_file_id) as S3 URIs; defensive
        isinstance checks guard against missing/malformed config sections.
        """
        # Extract input file ID
        input_file_id = ""
        input_data_config = response_data.get("inputDataConfig", {})
        if isinstance(input_data_config, dict):
            s3_input_config = input_data_config.get("s3InputDataConfig", {})
            if isinstance(s3_input_config, dict):
                input_file_id = s3_input_config.get("s3Uri", "")

        # Extract output file ID
        output_file_id = None
        output_data_config = response_data.get("outputDataConfig", {})
        if isinstance(output_data_config, dict):
            s3_output_config = output_data_config.get("s3OutputDataConfig", {})
            if isinstance(s3_output_config, dict):
                output_file_id = s3_output_config.get("s3Uri", "")

        return input_file_id, output_file_id

    def _extract_errors_and_metadata(self, response_data, raw_response):
        """Helper to extract errors and enriched metadata.

        Returns (errors, enriched_metadata) where errors is an OpenAI
        Errors object (or None) and enriched_metadata is a str->str dict.
        """
        # Extract errors
        message = response_data.get("message")
        errors = None
        if message:
            # Lazy import: keep the hard dependency on the openai package
            # scoped to the error path only.
            from openai.types.batch import Errors
            from openai.types.batch_error import BatchError

            errors = Errors(
                data=[BatchError(message=message, code=str(raw_response.status_code))],
                object="list",
            )

        # Enrich metadata with useful Bedrock fields
        enriched_metadata_raw: Dict[str, Any] = {
            "jobName": response_data.get("jobName"),
            "clientRequestToken": response_data.get("clientRequestToken"),
            "modelId": response_data.get("modelId"),
            "roleArn": response_data.get("roleArn"),
            "timeoutDurationInHours": response_data.get("timeoutDurationInHours"),
            "vpcConfig": response_data.get("vpcConfig"),
        }

        # Metadata values must be strings; JSON-encode structured values.
        enriched_metadata: Dict[str, str] = {}
        for _k, _v in enriched_metadata_raw.items():
            if _v is None:
                continue
            if isinstance(_v, (dict, list)):
                try:
                    enriched_metadata[_k] = json.dumps(_v)
                except Exception:
                    enriched_metadata[_k] = str(_v)
            else:
                enriched_metadata[_k] = str(_v)

        return errors, enriched_metadata

    def transform_retrieve_batch_response(
        self,
        model: Optional[str],
        raw_response: Response,
        logging_obj: Any,
        litellm_params: dict,
    ) -> LiteLLMBatch:
        """
        Transform Bedrock batch retrieval response to LiteLLM format.

        Raises:
            ValueError: if the response body is not valid JSON.
        """
        # Local import, matching the original (likely avoids an import cycle).
        from litellm.types.llms.bedrock import BedrockGetBatchResponse

        try:
            response_data: BedrockGetBatchResponse = raw_response.json()
        except Exception as e:
            raise ValueError(f"Failed to parse Bedrock batch response: {e}")

        job_arn = response_data.get("jobArn", "")
        status_str: str = str(response_data.get("status", "Submitted"))
        openai_status = self._map_status_to_openai(status_str)

        # Parse timestamps
        (
            created_at,
            in_progress_at,
            completed_at,
            failed_at,
            cancelled_at,
            expires_at,
        ) = self._parse_timestamps_and_status(response_data, status_str)

        # Extract file configurations
        input_file_id, output_file_id = self._extract_file_configs(response_data)

        # Extract errors and metadata
        errors, enriched_metadata = self._extract_errors_and_metadata(
            response_data, raw_response
        )

        return LiteLLMBatch(
            id=job_arn,
            object="batch",
            endpoint="/v1/chat/completions",
            errors=errors,
            input_file_id=input_file_id,
            completion_window="24h",
            status=openai_status,
            output_file_id=output_file_id,
            error_file_id=None,
            created_at=created_at or int(time.time()),
            in_progress_at=in_progress_at,
            expires_at=expires_at,
            finalizing_at=None,
            completed_at=completed_at,
            failed_at=failed_at,
            expired_at=None,
            cancelling_at=None,
            cancelled_at=cancelled_at,
            request_counts=None,
            metadata=enriched_metadata,
        )

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[Dict, Headers]
    ) -> BaseLLMException:
        """
        Get Bedrock-specific error class using common utility.
        """
        return self.common_utils.get_error_class(error_message, status_code, headers)