Files
lijiaoqiao/llm-gateway-competitors/litellm-wheel-src/litellm/llms/bedrock/passthrough/transformation.py
2026-03-26 20:06:14 +08:00

250 lines
8.7 KiB
Python

import json
import re
from typing import TYPE_CHECKING, List, Optional, Tuple, cast

from httpx import Response

from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.llms.base_llm.passthrough.transformation import BasePassthroughConfig

from ..base_aws_llm import BaseAWSLLM
from ..common_utils import BedrockEventStreamDecoderBase, BedrockModelInfo

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
    from litellm.types.utils import CostResponseTypes

if TYPE_CHECKING:
    from httpx import URL
class BedrockPassthroughConfig(
    BaseAWSLLM, BedrockModelInfo, BedrockEventStreamDecoderBase, BasePassthroughConfig
):
    """
    Pass-through configuration for AWS Bedrock runtime endpoints.

    Responsibilities visible in this class:
    - build the upstream runtime URL (optionally swapping in a ``model_id``),
    - sign the outgoing request,
    - turn non-streaming and collected streaming responses into LiteLLM
      response objects for logging.
    """

    def is_streaming_request(self, endpoint: str, request_data: dict) -> bool:
        """
        Return True when ``endpoint`` is a streaming route.

        Detection is purely path based: any endpoint containing the substring
        ``"stream"`` counts as streaming. ``request_data`` is accepted for
        interface compatibility and is not inspected.
        """
        return "stream" in endpoint

    def _encode_model_id_for_endpoint(self, model_id: str) -> str:
        """
        Encode model_id (especially ARNs) for use in Bedrock endpoints.

        ARNs contain special characters like colons and slashes that need to be
        properly URL-encoded when used in HTTP request paths. For example:
            arn:aws:bedrock:us-east-1:123:application-inference-profile/abc123
        becomes:
            arn:aws:bedrock:us-east-1:123:application-inference-profile%2Fabc123

        Args:
            model_id: The model ID or ARN to encode

        Returns:
            The encoded model_id suitable for use in endpoint URLs. If the
            encoded value cannot be recovered from the helper's output, the
            original ``model_id`` is returned unchanged.
        """
        from litellm.passthrough.utils import CommonUtils

        # The shared helper works on whole endpoint paths, so wrap the
        # model_id in a throwaway path and unwrap the result afterwards.
        prefix, suffix = "/model/", "/converse"
        encoded_endpoint = CommonUtils.encode_bedrock_runtime_modelid_arn(
            f"{prefix}{model_id}{suffix}"
        )

        # Strip the known wrapper rather than capturing "everything up to the
        # next slash": if the helper leaves a "/" unencoded, a `[^/]+` capture
        # would silently return only the part before that slash.
        if (
            len(encoded_endpoint) > len(prefix) + len(suffix)
            and encoded_endpoint.startswith(prefix)
            and encoded_endpoint.endswith(suffix)
        ):
            return encoded_endpoint[len(prefix) : -len(suffix)]

        # Fallback to original model_id if extraction fails
        return model_id

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        endpoint: str,
        request_query_params: Optional[dict],
        litellm_params: dict,
    ) -> Tuple["URL", str]:
        """
        Build the full upstream URL for a pass-through request.

        Args:
            api_base: Optional base URL forwarded to ``get_runtime_endpoint``.
            api_key: Unused here; part of the shared interface.
            model: Model name, used for region resolution.
            endpoint: Request path, e.g. ``/model/<name>/converse``.
            request_query_params: Query parameters to attach (``None`` is
                treated as an empty dict).
            litellm_params: Source of ``model_id``,
                ``aws_bedrock_runtime_endpoint`` and region-related settings.

        Returns:
            ``(formatted_url, endpoint_url)`` where ``endpoint_url`` is the
            runtime base URL the path was joined onto.
        """
        # Shallow copy: the helpers below receive the copy, not the caller's dict.
        optional_params = litellm_params.copy()
        model_id = optional_params.get("model_id", None)

        aws_region_name = self._get_aws_region_name(
            optional_params=optional_params,
            model=model,
            model_id=model_id,
        )
        aws_bedrock_runtime_endpoint = optional_params.get(
            "aws_bedrock_runtime_endpoint"
        )
        endpoint_url, _ = self.get_runtime_endpoint(
            api_base=api_base,
            aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
            aws_region_name=aws_region_name,
            endpoint_type="runtime",
        )

        # If model_id is provided (e.g., Application Inference Profile ARN), use it
        # in the endpoint instead of the translated model name.
        if model_id is not None:
            # Encode the model_id if it's an ARN to properly handle special characters
            encoded_model_id = self._encode_model_id_for_endpoint(model_id)
            model_segment = f"model/{encoded_model_id}/"
            # A callable replacement is inserted verbatim; a plain string would
            # be parsed as a regex template, so a backslash or "\g<...>" in the
            # model_id would be misinterpreted or raise re.error.
            endpoint = re.sub(
                r"model/[^/]+/", lambda _match: model_segment, endpoint
            )

        return (
            self.format_url(endpoint, endpoint_url, request_query_params or {}),
            endpoint_url,
        )

    def sign_request(
        self,
        headers: dict,
        litellm_params: dict,
        request_data: Optional[dict],
        api_base: str,
        model: Optional[str] = None,
    ) -> Tuple[dict, Optional[bytes]]:
        """
        Sign the outgoing request by delegating to ``_sign_request``.

        The call uses ``service_name="bedrock"``, a shallow copy of
        ``litellm_params`` as the optional params, and an empty dict when
        ``request_data`` is ``None``/empty. The delegate's result is returned
        as-is.
        """
        optional_params = litellm_params.copy()
        return self._sign_request(
            service_name="bedrock",
            headers=headers,
            optional_params=optional_params,
            request_data=request_data or {},
            api_base=api_base,
            model=model,
        )

    def logging_non_streaming_response(
        self,
        model: str,
        custom_llm_provider: str,
        httpx_response: Response,
        request_data: dict,
        logging_obj: Logging,
        endpoint: str,
    ) -> Optional["CostResponseTypes"]:
        """
        Convert a non-streaming Bedrock response into a ``ModelResponse``.

        The route is chosen from ``endpoint``: paths containing ``"invoke"``
        use the ``invoke/`` chat config, otherwise paths containing
        ``"converse"`` use the ``converse/`` chat config.

        Returns:
            The transformed ``ModelResponse``, or ``None`` when the endpoint is
            neither an invoke nor a converse route.

        Raises:
            ValueError: If no provider chat config is found for the model.
        """
        from litellm import encoding
        from litellm.types.utils import LlmProviders, ModelResponse
        from litellm.utils import ProviderConfigManager

        # "invoke" is checked first, so a path containing both substrings is
        # handled as an invoke route.
        if "invoke" in endpoint:
            chat_config_model = "invoke/" + model
        elif "converse" in endpoint:
            chat_config_model = "converse/" + model
        else:
            return None

        provider_chat_config = ProviderConfigManager.get_provider_chat_config(
            provider=LlmProviders(custom_llm_provider),
            model=chat_config_model,
        )
        if provider_chat_config is None:
            raise ValueError(f"No provider config found for model: {model}")

        # Pass-through requests carry no LiteLLM-shaped messages, so a fixed
        # placeholder message is supplied to the transformer.
        litellm_model_response: ModelResponse = provider_chat_config.transform_response(
            model=model,
            messages=[{"role": "user", "content": "no-message-pass-through-endpoint"}],
            raw_response=httpx_response,
            model_response=ModelResponse(),
            logging_obj=logging_obj,
            optional_params={},
            litellm_params={},
            api_key="",
            request_data=request_data,
            encoding=encoding,
        )
        return litellm_model_response

    def _convert_raw_bytes_to_str_lines(self, raw_bytes: List[bytes]) -> List[str]:
        """
        Decode raw event-stream bytes into a list of parsed messages.

        Each byte chunk is fed to a botocore ``EventStreamBuffer``; every event
        the buffer yields is passed to ``_parse_message_from_event`` and kept
        unless the parser returns ``None``.
        """
        from botocore.eventstream import EventStreamBuffer

        all_chunks: List[str] = []
        event_stream_buffer = EventStreamBuffer()
        for chunk in raw_bytes:
            event_stream_buffer.add_data(chunk)
            # Drain whatever events are available after each chunk is added.
            for event in event_stream_buffer:
                message = self._parse_message_from_event(event)
                if message is not None:
                    all_chunks.append(message)
        return all_chunks

    def handle_logging_collected_chunks(
        self,
        all_chunks: List[str],
        litellm_logging_obj: "LiteLLMLoggingObj",
        model: str,
        custom_llm_provider: str,
        endpoint: str,
    ) -> Optional["CostResponseTypes"]:
        """
        1. Convert all_chunks to a ModelResponseStream
        2. combine model_response_stream to model_response
        3. Return the model_response

        Returns:
            The combined model response, or ``None`` when the endpoint is
            neither an invoke nor a converse route, or when no chunk could be
            translated.

        Raises:
            ValueError: If the endpoint is an invoke route and no invoke
                provider can be determined for ``model``.
        """
        from litellm.litellm_core_utils.streaming_handler import (
            convert_generic_chunk_to_model_response_stream,
            generic_chunk_has_all_required_fields,
        )
        from litellm.llms.bedrock.chat import get_bedrock_event_stream_decoder
        from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
            AmazonInvokeConfig,
        )
        from litellm.main import stream_chunk_builder
        from litellm.types.utils import GenericStreamingChunk, ModelResponseStream

        # Pick the decoder flavour from the route: invoke routes need a
        # provider, converse routes use the decoder without one.
        if "invoke" in endpoint:
            invoke_provider = AmazonInvokeConfig.get_bedrock_invoke_provider(model)
            if invoke_provider is None:
                raise ValueError(
                    f"Invalid invoke provider: {invoke_provider}, for model: {model}"
                )
        elif "converse" in endpoint:
            invoke_provider = None
        else:
            return None

        obj = get_bedrock_event_stream_decoder(
            invoke_provider=invoke_provider,
            model=model,
            sync_stream=True,
            json_mode=False,
        )

        all_translated_chunks = []
        for chunk in all_chunks:
            message = json.loads(chunk)
            translated_chunk = obj._chunk_parser(chunk_data=message)

            if isinstance(
                translated_chunk, dict
            ) and generic_chunk_has_all_required_fields(cast(dict, translated_chunk)):
                chunk_obj = convert_generic_chunk_to_model_response_stream(
                    cast(GenericStreamingChunk, translated_chunk)
                )
            elif isinstance(translated_chunk, ModelResponseStream):
                chunk_obj = translated_chunk
            else:
                # Neither a complete generic chunk nor a ModelResponseStream: skip it.
                continue

            all_translated_chunks.append(chunk_obj)

        if all_translated_chunks:
            model_response = stream_chunk_builder(
                chunks=all_translated_chunks,
                logging_obj=litellm_logging_obj,
            )
            return model_response
        return None