import hashlib import json import os import urllib.parse from datetime import datetime from typing import ( TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union, cast, get_args, ) import httpx from pydantic import BaseModel from litellm._logging import verbose_logger from litellm.caching.caching import DualCache from litellm.constants import ( BEDROCK_EMBEDDING_PROVIDERS_LITERAL, BEDROCK_INVOKE_PROVIDERS_LITERAL, BEDROCK_MAX_POLICY_SIZE, ) from litellm.litellm_core_utils.dd_tracing import tracer from litellm.secret_managers.main import get_secret, get_secret_str if TYPE_CHECKING: from botocore.awsrequest import AWSPreparedRequest from botocore.credentials import Credentials else: Credentials = Any AWSPreparedRequest = Any class Boto3CredentialsInfo(BaseModel): credentials: Credentials aws_region_name: str aws_bedrock_runtime_endpoint: Optional[str] class AwsAuthError(Exception): def __init__(self, status_code, message): self.status_code = status_code self.message = message self.request = httpx.Request( method="POST", url="https://us-west-2.console.aws.amazon.com/bedrock" ) self.response = httpx.Response(status_code=status_code, request=self.request) super().__init__( self.message ) # Call the base class constructor with the parameters it needs class BaseAWSLLM: def __init__(self) -> None: self.iam_cache = DualCache() super().__init__() self.aws_authentication_params = [ "aws_access_key_id", "aws_secret_access_key", "aws_session_token", "aws_region_name", "aws_session_name", "aws_profile_name", "aws_role_name", "aws_web_identity_token", "aws_sts_endpoint", "aws_bedrock_runtime_endpoint", "aws_external_id", ] def _get_ssl_verify(self, ssl_verify: Optional[Union[bool, str]] = None): """ Get SSL verification setting for boto3 clients. This ensures that custom CA certificates are properly used for all AWS API calls, including STS and Bedrock services. Returns: Union[bool, str]: SSL verification setting - False to disable, True to enable, or a string path to a CA bundle file """ from litellm.llms.custom_httpx.http_handler import get_ssl_verify return get_ssl_verify(ssl_verify=ssl_verify) def get_cache_key(self, credential_args: Dict[str, Optional[str]]) -> str: """ Generate a unique cache key based on the credential arguments. """ # Convert credential arguments to a JSON string and hash it to create a unique key credential_str = json.dumps(credential_args, sort_keys=True) return hashlib.sha256(credential_str.encode()).hexdigest() @tracer.wrap() def get_credentials( self, aws_access_key_id: Optional[str] = None, aws_secret_access_key: Optional[str] = None, aws_session_token: Optional[str] = None, aws_region_name: Optional[str] = None, aws_session_name: Optional[str] = None, aws_profile_name: Optional[str] = None, aws_role_name: Optional[str] = None, aws_web_identity_token: Optional[str] = None, aws_sts_endpoint: Optional[str] = None, aws_external_id: Optional[str] = None, ssl_verify: Optional[Union[bool, str]] = None, ): """ Return a boto3.Credentials object """ ## CHECK IS 'os.environ/' passed in params_to_check: List[Optional[str]] = [ aws_access_key_id, aws_secret_access_key, aws_session_token, aws_region_name, aws_session_name, aws_profile_name, aws_role_name, aws_web_identity_token, aws_sts_endpoint, aws_external_id, ] # Iterate over parameters and update if needed for i, param in enumerate(params_to_check): if param and param.startswith("os.environ/"): _v = get_secret(param) if _v is not None and isinstance(_v, str): params_to_check[i] = _v elif param is None: # check if uppercase value in env key = self.aws_authentication_params[i] if key.upper() in os.environ: params_to_check[i] = os.getenv(key.upper()) # Assign updated values back to parameters ( aws_access_key_id, aws_secret_access_key, aws_session_token, aws_region_name, aws_session_name, aws_profile_name, aws_role_name, aws_web_identity_token, aws_sts_endpoint, aws_external_id, ) = params_to_check verbose_logger.debug( "in get credentials\n" "aws_access_key_id=%s\n" "aws_secret_access_key=%s\n" "aws_session_token=%s\n" "aws_region_name=%s\n" "aws_session_name=%s\n" "aws_profile_name=%s\n" "aws_role_name=%s\n" "aws_web_identity_token=%s\n" "aws_sts_endpoint=%s\n" "aws_external_id=%s", aws_access_key_id, aws_secret_access_key, aws_session_token, aws_region_name, aws_session_name, aws_profile_name, aws_role_name, aws_web_identity_token, aws_sts_endpoint, aws_external_id, ) # create cache key for non-expiring auth flows args = { k: v for k, v in locals().items() if k.startswith("aws_") or k == "ssl_verify" } cache_key = self.get_cache_key(args) _cached_credentials = self.iam_cache.get_cache(cache_key) if _cached_credentials: return _cached_credentials ######################################################### # Handle diff boto3 auth flows # for each helper # Return: # Credentials - boto3.Credentials # cache ttl - Optional[int]. If None, the credentials are not cached. Some auth flows have no expiry time. ######################################################### if ( aws_web_identity_token is not None and aws_role_name is not None and aws_session_name is not None ): credentials, _cache_ttl = self._auth_with_web_identity_token( aws_web_identity_token=aws_web_identity_token, aws_role_name=aws_role_name, aws_session_name=aws_session_name, aws_region_name=aws_region_name, aws_sts_endpoint=aws_sts_endpoint, aws_external_id=aws_external_id, ) elif aws_role_name is not None: # Check if we're already running as the target role and can skip assumption # This handles IRSA (EKS), ECS task roles, and EC2 instance profiles if self._is_already_running_as_role(aws_role_name, ssl_verify=ssl_verify): verbose_logger.debug( "Already running as target role %s, using ambient credentials", aws_role_name, ) credentials, _cache_ttl = self._auth_with_env_vars() else: verbose_logger.debug( "Using role assumption: calling _auth_with_aws_role" ) # If aws_session_name is not provided, generate a default one if aws_session_name is None: aws_session_name = ( f"litellm-session-{int(datetime.now().timestamp())}" ) credentials, _cache_ttl = self._auth_with_aws_role( aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, aws_session_token=aws_session_token, aws_role_name=aws_role_name, aws_session_name=aws_session_name, aws_region_name=aws_region_name, aws_sts_endpoint=aws_sts_endpoint, aws_external_id=aws_external_id, ssl_verify=ssl_verify, ) elif aws_profile_name is not None: ### CHECK SESSION ### credentials, _cache_ttl = self._auth_with_aws_profile(aws_profile_name) elif ( aws_access_key_id is not None and aws_secret_access_key is not None and aws_session_token is not None ): credentials, _cache_ttl = self._auth_with_aws_session_token( aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, aws_session_token=aws_session_token, ) elif ( aws_access_key_id is not None and aws_secret_access_key is not None and aws_region_name is not None ): credentials, _cache_ttl = self._auth_with_access_key_and_secret_key( aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, aws_region_name=aws_region_name, ) else: credentials, _cache_ttl = self._auth_with_env_vars() self.iam_cache.set_cache(cache_key, credentials, ttl=_cache_ttl) return credentials def _get_aws_region_from_model_arn(self, model: Optional[str]) -> Optional[str]: try: # First check if the string contains the expected prefix if not isinstance(model, str) or "arn:aws:bedrock" not in model: return None # Split the ARN and check if we have enough parts parts = model.split(":") if len(parts) < 4: return None # Get the region from the correct position region = parts[3] if not region: # Check if region is empty return None return region except Exception: # Catch any unexpected errors and return None return None @staticmethod def _get_provider_from_model_path( model_path: str, ) -> Optional[BEDROCK_INVOKE_PROVIDERS_LITERAL]: """ Helper function to get the provider from a model path with format: provider/model-name Args: model_path (str): The model path (e.g., 'llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n' or 'anthropic/model-name') Returns: Optional[str]: The provider name, or None if no valid provider found """ parts = model_path.split("/") if len(parts) >= 1: provider = parts[0] if provider in get_args(BEDROCK_INVOKE_PROVIDERS_LITERAL): return cast(BEDROCK_INVOKE_PROVIDERS_LITERAL, provider) return None @staticmethod def get_bedrock_invoke_provider( model: str, ) -> Optional[BEDROCK_INVOKE_PROVIDERS_LITERAL]: """ Helper function to get the bedrock provider from the model handles 3 scenarions: 1. model=invoke/anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic` 2. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic` 3. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama` 4. model=us.amazon.nova-pro-v1:0 -> Returns `nova` """ if model.startswith("invoke/"): model = model.replace("invoke/", "", 1) # Special case: Check for "nova" in model name first (before "amazon") # This handles amazon.nova-* models which would otherwise match "amazon" (Titan) if "nova" in model.lower(): if "nova" in get_args(BEDROCK_INVOKE_PROVIDERS_LITERAL): return cast(BEDROCK_INVOKE_PROVIDERS_LITERAL, "nova") _split_model = model.split(".")[0] if _split_model in get_args(BEDROCK_INVOKE_PROVIDERS_LITERAL): return cast(BEDROCK_INVOKE_PROVIDERS_LITERAL, _split_model) # If not a known provider, check for pattern with two slashes provider = BaseAWSLLM._get_provider_from_model_path(model) if provider is not None: return provider for provider in get_args(BEDROCK_INVOKE_PROVIDERS_LITERAL): if provider in model: return provider return None @staticmethod def get_bedrock_model_id( optional_params: dict, provider: Optional[BEDROCK_INVOKE_PROVIDERS_LITERAL], model: str, ) -> str: model_id = optional_params.pop("model_id", None) if model_id is not None: model_id = BaseAWSLLM.encode_model_id(model_id=model_id) else: model_id = model model_id = model_id.replace("invoke/", "", 1) if provider == "llama" and "llama/" in model_id: model_id = BaseAWSLLM._get_model_id_from_model_with_spec( model_id, spec="llama" ) elif provider == "deepseek_r1" and "deepseek_r1/" in model_id: model_id = BaseAWSLLM._get_model_id_from_model_with_spec( model_id, spec="deepseek_r1" ) elif provider == "openai" and "openai/" in model_id: model_id = BaseAWSLLM._get_model_id_from_model_with_spec( model_id, spec="openai" ) elif provider == "qwen2" and "qwen2/" in model_id: model_id = BaseAWSLLM._get_model_id_from_model_with_spec( model_id, spec="qwen2" ) elif provider == "qwen3" and "qwen3/" in model_id: model_id = BaseAWSLLM._get_model_id_from_model_with_spec( model_id, spec="qwen3" ) elif provider == "stability" and "stability/" in model_id: model_id = BaseAWSLLM._get_model_id_from_model_with_spec( model_id, spec="stability" ) elif provider == "moonshot" and "moonshot/" in model_id: model_id = BaseAWSLLM._get_model_id_from_model_with_spec( model_id, spec="moonshot" ) elif "nova-2/" in model_id: model_id = BaseAWSLLM._get_model_id_from_model_with_spec( model_id, spec="nova-2" ) elif "nova/" in model_id: model_id = BaseAWSLLM._get_model_id_from_model_with_spec( model_id, spec="nova" ) return model_id @staticmethod def _get_model_id_from_model_with_spec( model: str, spec: str, ) -> str: """ Remove `llama` from modelID since `llama` is simply a spec to follow for custom bedrock models """ model_id = model.replace(spec + "/", "") return BaseAWSLLM.encode_model_id(model_id=model_id) @staticmethod def encode_model_id(model_id: str) -> str: """ Double encode the model ID to ensure it matches the expected double-encoded format. Args: model_id (str): The model ID to encode. Returns: str: The double-encoded model ID. """ return urllib.parse.quote(model_id, safe="") @staticmethod def get_bedrock_embedding_provider( model: str, ) -> Optional[BEDROCK_EMBEDDING_PROVIDERS_LITERAL]: """ Helper function to get the bedrock embedding provider from the model Handles scenarios like: 1. model=cohere.embed-english-v3:0 -> Returns `cohere` 2. model=amazon.titan-embed-text-v1 -> Returns `amazon` 3. model=amazon.nova-2-multimodal-embeddings-v1:0 -> Returns `nova` 4. model=us.twelvelabs.marengo-embed-2-7-v1:0 -> Returns `twelvelabs` 5. model=twelvelabs.marengo-embed-2-7-v1:0 -> Returns `twelvelabs` """ # Special case: Check for "nova" in model name first (before "amazon") # This handles amazon.nova-* models if "nova" in model.lower(): if "nova" in get_args(BEDROCK_EMBEDDING_PROVIDERS_LITERAL): return cast(BEDROCK_EMBEDDING_PROVIDERS_LITERAL, "nova") # Handle regional models like us.twelvelabs.marengo-embed-2-7-v1:0 if "." in model: parts = model.split(".") # Check if the second part (after potential region) is a known provider if len(parts) >= 2: potential_provider = parts[ 1 ] # e.g., "twelvelabs" from "us.twelvelabs.marengo-embed-2-7-v1:0" if potential_provider in get_args(BEDROCK_EMBEDDING_PROVIDERS_LITERAL): return cast(BEDROCK_EMBEDDING_PROVIDERS_LITERAL, potential_provider) # Check if the first part is a known provider (standard format) potential_provider = parts[ 0 ] # e.g., "cohere" from "cohere.embed-english-v3:0" if potential_provider in get_args(BEDROCK_EMBEDDING_PROVIDERS_LITERAL): return cast(BEDROCK_EMBEDDING_PROVIDERS_LITERAL, potential_provider) # Fallback: check if any provider name appears in the model string for provider in get_args(BEDROCK_EMBEDDING_PROVIDERS_LITERAL): if provider in model: return cast(BEDROCK_EMBEDDING_PROVIDERS_LITERAL, provider) return None def _get_aws_region_name( self, optional_params: dict, model: Optional[str] = None, model_id: Optional[str] = None, ) -> str: """ Get the AWS region name from the environment variables. Parameters: optional_params (dict): Optional parameters for the model call model (str): The model name model_id (str): The model ID. This is the ARN of the model, if passed in as a separate param. Returns: str: The AWS region name """ aws_region_name = optional_params.get("aws_region_name", None) ### SET REGION NAME ### if aws_region_name is None: # check model arn # if model_id is not None: aws_region_name = self._get_aws_region_from_model_arn(model_id) else: aws_region_name = self._get_aws_region_from_model_arn(model) # check env # litellm_aws_region_name = get_secret("AWS_REGION_NAME", None) if ( aws_region_name is None and litellm_aws_region_name is not None and isinstance(litellm_aws_region_name, str) ): aws_region_name = litellm_aws_region_name standard_aws_region_name = get_secret("AWS_REGION", None) if ( aws_region_name is None and standard_aws_region_name is not None and isinstance(standard_aws_region_name, str) ): aws_region_name = standard_aws_region_name if aws_region_name is None: try: import boto3 with tracer.trace("boto3.Session()"): session = boto3.Session() configured_region = session.region_name if configured_region: aws_region_name = configured_region else: aws_region_name = "us-west-2" except Exception: aws_region_name = "us-west-2" return aws_region_name def get_aws_region_name_for_non_llm_api_calls( self, aws_region_name: Optional[str] = None, ): """ Get the AWS region name for non-llm api calls. LLM API calls check the model arn and end up using that as the region name. For non-llm api calls eg. Guardrails, Vector Stores we just need to check the dynamic param or env vars. """ if aws_region_name is None: # check env # litellm_aws_region_name = get_secret("AWS_REGION_NAME", None) if litellm_aws_region_name is not None and isinstance( litellm_aws_region_name, str ): aws_region_name = litellm_aws_region_name standard_aws_region_name = get_secret("AWS_REGION", None) if standard_aws_region_name is not None and isinstance( standard_aws_region_name, str ): aws_region_name = standard_aws_region_name if aws_region_name is None: aws_region_name = "us-west-2" return aws_region_name @staticmethod def _parse_arn_account_and_role_name( arn: str, ) -> Optional[Tuple[str, str, str]]: """ Parse an ARN and return (partition, account_id, role_name). Handles: - arn:aws:iam::123456789012:role/MyRole - arn:aws:iam::123456789012:role/path/to/MyRole - arn:aws:sts::123456789012:assumed-role/MyRole/session-name Returns None if the ARN cannot be parsed. """ # ARN format: arn:PARTITION:SERVICE:REGION:ACCOUNT:RESOURCE parts = arn.split(":") if len(parts) < 6 or parts[0] != "arn": return None partition = parts[1] # e.g. "aws", "aws-cn", "aws-us-gov" account_id = parts[4] resource = ":".join(parts[5:]) # rejoin in case resource contains colons if resource.startswith("role/"): # arn:aws:iam::ACCOUNT:role/[path/]ROLE_NAME role_name = resource.split("/")[-1] elif resource.startswith("assumed-role/"): # arn:aws:sts::ACCOUNT:assumed-role/ROLE_NAME/SESSION role_parts = resource.split("/") if len(role_parts) >= 2: role_name = role_parts[1] else: return None else: return None return partition, account_id, role_name def _is_already_running_as_role( self, aws_role_name: str, ssl_verify: Optional[Union[bool, str]] = None, ) -> bool: """ Check if the current environment is already running as the target IAM role. This handles multiple AWS environments: - IRSA (EKS): AWS_ROLE_ARN + AWS_WEB_IDENTITY_TOKEN_FILE are set - ECS task roles: Uses sts:GetCallerIdentity to check current role ARN - EC2 instance profiles: Uses sts:GetCallerIdentity to check current role ARN Compares partition, account ID, and role name to avoid cross-account false matches. Returns True if the current identity matches the target role, meaning we can skip sts:AssumeRole and use ambient credentials directly. """ target_parsed = self._parse_arn_account_and_role_name(aws_role_name) if target_parsed is None: return False target_partition, target_account, target_role = target_parsed # Fast path: IRSA environment check (no API call needed) current_role_arn = os.getenv("AWS_ROLE_ARN") web_identity_token_file = os.getenv("AWS_WEB_IDENTITY_TOKEN_FILE") if current_role_arn and web_identity_token_file: return current_role_arn == aws_role_name # For ECS/EC2: call sts:GetCallerIdentity to check if already running as the role try: import boto3 with tracer.trace("boto3.client(sts).get_caller_identity"): sts_client = boto3.client( "sts", verify=self._get_ssl_verify(ssl_verify) ) identity = sts_client.get_caller_identity() caller_arn = identity.get("Arn", "") caller_parsed = self._parse_arn_account_and_role_name(caller_arn) if caller_parsed is not None: caller_partition, caller_account, caller_role = caller_parsed if ( caller_partition == target_partition and caller_account == target_account and caller_role == target_role ): verbose_logger.debug( "Current identity already matches target role: %s", aws_role_name, ) return True except Exception as e: verbose_logger.debug( "Could not determine current role identity: %s", str(e) ) return False @tracer.wrap() def _auth_with_web_identity_token( self, aws_web_identity_token: str, aws_role_name: str, aws_session_name: str, aws_region_name: Optional[str], aws_sts_endpoint: Optional[str], aws_external_id: Optional[str] = None, ssl_verify: Optional[Union[bool, str]] = None, ) -> Tuple[Credentials, Optional[int]]: """ Authenticate with AWS Web Identity Token """ import boto3 verbose_logger.debug( f"IN Web Identity Token: {aws_web_identity_token} | Role Name: {aws_role_name} | Session Name: {aws_session_name}" ) if aws_sts_endpoint is None: sts_endpoint = f"https://sts.{aws_region_name}.amazonaws.com" else: sts_endpoint = aws_sts_endpoint oidc_token = get_secret(aws_web_identity_token) if oidc_token is None: raise AwsAuthError( message="OIDC token could not be retrieved from secret manager.", status_code=401, ) with tracer.trace("boto3.client(sts)"): sts_client = boto3.client( "sts", region_name=aws_region_name, endpoint_url=sts_endpoint, verify=self._get_ssl_verify(ssl_verify), ) # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html assume_role_params = { "RoleArn": aws_role_name, "RoleSessionName": aws_session_name, "WebIdentityToken": oidc_token, "DurationSeconds": 3600, "Policy": '{"Version":"2012-10-17","Statement":[{"Sid":"BedrockLiteLLM","Effect":"Allow","Action":["bedrock:InvokeModel","bedrock:InvokeModelWithResponseStream"],"Resource":"*","Condition":{"Bool":{"aws:SecureTransport":"true"},"StringLike":{"aws:UserAgent":"litellm/*"}}}]}', } # Add ExternalId parameter if provided if aws_external_id is not None: assume_role_params["ExternalId"] = aws_external_id sts_response = sts_client.assume_role_with_web_identity(**assume_role_params) iam_creds_dict = { "aws_access_key_id": sts_response["Credentials"]["AccessKeyId"], "aws_secret_access_key": sts_response["Credentials"]["SecretAccessKey"], "aws_session_token": sts_response["Credentials"]["SessionToken"], "region_name": aws_region_name, } if sts_response["PackedPolicySize"] > BEDROCK_MAX_POLICY_SIZE: verbose_logger.warning( f"The policy size is greater than 75% of the allowed size, PackedPolicySize: {sts_response['PackedPolicySize']}" ) with tracer.trace("boto3.Session(**iam_creds_dict)"): session = boto3.Session(**iam_creds_dict) iam_creds = session.get_credentials() return iam_creds, self._get_default_ttl_for_boto3_credentials() def _handle_irsa_cross_account( self, irsa_role_arn: str, aws_role_name: str, aws_session_name: str, region: str, web_identity_token_file: str, aws_external_id: Optional[str] = None, aws_sts_endpoint: Optional[str] = None, ssl_verify: Optional[Union[bool, str]] = None, ) -> dict: """Handle cross-account role assumption for IRSA.""" import boto3 verbose_logger.debug("Cross-account role assumption detected") # Read the web identity token with open(web_identity_token_file, "r") as f: web_identity_token = f.read().strip() irsa_sts_kwargs: dict = { "region_name": region, "verify": self._get_ssl_verify(ssl_verify), } if aws_sts_endpoint is not None: irsa_sts_kwargs["endpoint_url"] = aws_sts_endpoint # Create an STS client without credentials with tracer.trace("boto3.client(sts) for manual IRSA"): sts_client = boto3.client("sts", **irsa_sts_kwargs) # Manually assume the IRSA role with the session name verbose_logger.debug( f"Manually assuming IRSA role {irsa_role_arn} with session {aws_session_name}" ) irsa_response = sts_client.assume_role_with_web_identity( RoleArn=irsa_role_arn, RoleSessionName=aws_session_name, WebIdentityToken=web_identity_token, ) # Extract the credentials from the IRSA assumption irsa_creds = irsa_response["Credentials"] # Create a new STS client with the IRSA credentials with tracer.trace("boto3.client(sts) with manual IRSA credentials"): sts_client_with_creds = boto3.client( "sts", aws_access_key_id=irsa_creds["AccessKeyId"], aws_secret_access_key=irsa_creds["SecretAccessKey"], aws_session_token=irsa_creds["SessionToken"], **irsa_sts_kwargs, ) # Get current caller identity for debugging try: caller_identity = sts_client_with_creds.get_caller_identity() verbose_logger.debug( f"Current identity after manual IRSA assumption: {caller_identity.get('Arn', 'unknown')}" ) except Exception as e: verbose_logger.debug(f"Failed to get caller identity: {e}") # Now assume the target role verbose_logger.debug( f"Attempting to assume target role: {aws_role_name} with session: {aws_session_name}" ) assume_role_params = { "RoleArn": aws_role_name, "RoleSessionName": aws_session_name, } # Add ExternalId parameter if provided if aws_external_id is not None: assume_role_params["ExternalId"] = aws_external_id return sts_client_with_creds.assume_role(**assume_role_params) def _handle_irsa_same_account( self, aws_role_name: str, aws_session_name: str, region: str, aws_external_id: Optional[str] = None, aws_sts_endpoint: Optional[str] = None, ssl_verify: Optional[Union[bool, str]] = None, ) -> dict: """Handle same-account role assumption for IRSA.""" import boto3 irsa_sts_kwargs: dict = { "region_name": region, "verify": self._get_ssl_verify(ssl_verify), } if aws_sts_endpoint is not None: irsa_sts_kwargs["endpoint_url"] = aws_sts_endpoint verbose_logger.debug("Same account role assumption, using automatic IRSA") with tracer.trace("boto3.client(sts) with automatic IRSA"): sts_client = boto3.client("sts", **irsa_sts_kwargs) # Get current caller identity for debugging try: caller_identity = sts_client.get_caller_identity() verbose_logger.debug( f"Current IRSA identity: {caller_identity.get('Arn', 'unknown')}" ) except Exception as e: verbose_logger.debug(f"Failed to get caller identity: {e}") # Assume the role verbose_logger.debug( f"Attempting to assume role: {aws_role_name} with session: {aws_session_name}" ) assume_role_params = { "RoleArn": aws_role_name, "RoleSessionName": aws_session_name, } # Add ExternalId parameter if provided if aws_external_id is not None: assume_role_params["ExternalId"] = aws_external_id return sts_client.assume_role(**assume_role_params) def _extract_credentials_and_ttl( self, sts_response: dict ) -> Tuple[Credentials, Optional[int]]: """Extract credentials and TTL from STS response.""" from botocore.credentials import Credentials sts_credentials = sts_response["Credentials"] credentials = Credentials( access_key=sts_credentials["AccessKeyId"], secret_key=sts_credentials["SecretAccessKey"], token=sts_credentials["SessionToken"], ) expiration_time = sts_credentials["Expiration"] ttl = int( (expiration_time - datetime.now(expiration_time.tzinfo)).total_seconds() ) return credentials, ttl @tracer.wrap() def _auth_with_aws_role( self, aws_access_key_id: Optional[str], aws_secret_access_key: Optional[str], aws_session_token: Optional[str], aws_role_name: str, aws_session_name: str, aws_region_name: Optional[str] = None, aws_sts_endpoint: Optional[str] = None, aws_external_id: Optional[str] = None, ssl_verify: Optional[Union[bool, str]] = None, ) -> Tuple[Credentials, Optional[int]]: """ Authenticate with AWS Role """ import boto3 from botocore.credentials import Credentials # Check if we're in an EKS/IRSA environment web_identity_token_file = os.getenv("AWS_WEB_IDENTITY_TOKEN_FILE") irsa_role_arn = os.getenv("AWS_ROLE_ARN") region = ( aws_region_name or os.getenv("AWS_REGION") or os.getenv("AWS_DEFAULT_REGION") ) # If we have IRSA environment variables and no explicit credentials, # we need to use the web identity token flow if ( web_identity_token_file and irsa_role_arn and aws_access_key_id is None and aws_secret_access_key is None ): # For cross-account role assumption with specific session names, # we need to manually assume the IRSA role first with the correct session name verbose_logger.debug( f"IRSA detected: using web identity token from {web_identity_token_file}" ) try: # Use passed-in region when set, else env, else default (align with AssumeRole path) region = region or "us-east-1" # Check if we need to do cross-account role assumption if aws_role_name != irsa_role_arn: sts_response = self._handle_irsa_cross_account( irsa_role_arn, aws_role_name, aws_session_name, region, web_identity_token_file, aws_external_id, aws_sts_endpoint=aws_sts_endpoint, ssl_verify=ssl_verify, ) else: sts_response = self._handle_irsa_same_account( aws_role_name, aws_session_name, region, aws_external_id, aws_sts_endpoint=aws_sts_endpoint, ssl_verify=ssl_verify, ) return self._extract_credentials_and_ttl(sts_response) except Exception as e: verbose_logger.debug(f"Failed to assume role via IRSA: {e}") if "AccessDenied" in str( e ) and "is not authorized to perform: sts:AssumeRole" in str(e): # Provide a more helpful error message for trust policy issues verbose_logger.error( f"Access denied when trying to assume role {aws_role_name}. " f"Please ensure the trust policy of {aws_role_name} allows " f"the current role to assume it. Current identity: check logs with verbose mode." ) # Re-raise the exception instead of falling through raise # In EKS/IRSA environments, use ambient credentials (no explicit keys needed) # This allows the web identity token to work automatically sts_client_kwargs: dict = {"verify": self._get_ssl_verify(ssl_verify)} if region is not None: sts_client_kwargs["region_name"] = region if aws_sts_endpoint is not None: sts_client_kwargs["endpoint_url"] = aws_sts_endpoint if aws_access_key_id is None and aws_secret_access_key is None: with tracer.trace("boto3.client(sts)"): sts_client = boto3.client("sts", **sts_client_kwargs) else: with tracer.trace("boto3.client(sts)"): sts_client = boto3.client( "sts", aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, aws_session_token=aws_session_token, **sts_client_kwargs, ) assume_role_params = { "RoleArn": aws_role_name, "RoleSessionName": aws_session_name, } # Add ExternalId parameter if provided if aws_external_id is not None: assume_role_params["ExternalId"] = aws_external_id try: sts_response = sts_client.assume_role(**assume_role_params) except Exception as e: error_str = str(e) if "AccessDenied" in error_str: # Only fall back to ambient credentials if we can positively # confirm the caller is already the target role (same account, # partition, and role name). This avoids silently using the # wrong identity when there is a genuine trust-policy or # permission misconfiguration. if self._is_already_running_as_role( aws_role_name, ssl_verify=ssl_verify ): verbose_logger.warning( "AssumeRole failed for %s (%s). " "Caller is already running as this role; " "falling back to ambient credentials.", aws_role_name, error_str, ) return self._auth_with_env_vars() # Genuine permission error — re-raise verbose_logger.error( "AssumeRole AccessDenied for %s and caller is NOT " "the same role. Re-raising. Error: %s", aws_role_name, error_str, ) raise # Extract the credentials from the response and convert to Session Credentials sts_credentials = sts_response["Credentials"] credentials = Credentials( access_key=sts_credentials["AccessKeyId"], secret_key=sts_credentials["SecretAccessKey"], token=sts_credentials["SessionToken"], ) sts_expiry = sts_credentials["Expiration"] # Convert to timezone-aware datetime for comparison current_time = datetime.now(sts_expiry.tzinfo) sts_ttl = (sts_expiry - current_time).total_seconds() - 60 return credentials, sts_ttl @tracer.wrap() def _auth_with_aws_profile( self, aws_profile_name: str ) -> Tuple[Credentials, Optional[int]]: """ Authenticate with AWS profile """ import boto3 # uses auth values from AWS profile usually stored in ~/.aws/credentials with tracer.trace("boto3.Session(profile_name=aws_profile_name)"): client = boto3.Session(profile_name=aws_profile_name) return client.get_credentials(), None @tracer.wrap() def _auth_with_aws_session_token( self, aws_access_key_id: str, aws_secret_access_key: str, aws_session_token: str, ) -> Tuple[Credentials, Optional[int]]: """ Authenticate with AWS Session Token """ ### CHECK FOR AWS SESSION TOKEN ### from botocore.credentials import Credentials credentials = Credentials( access_key=aws_access_key_id, secret_key=aws_secret_access_key, token=aws_session_token, ) return credentials, None @tracer.wrap() def _auth_with_access_key_and_secret_key( self, aws_access_key_id: str, aws_secret_access_key: str, aws_region_name: Optional[str], ) -> Tuple[Credentials, Optional[int]]: """ Authenticate with AWS Access Key and Secret Key """ import boto3 # Check if credentials are already in cache. These credentials have no expiry time. with tracer.trace( "boto3.Session(aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, region_name=aws_region_name)" ): session = boto3.Session( aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, region_name=aws_region_name, ) credentials = session.get_credentials() return credentials, self._get_default_ttl_for_boto3_credentials() @tracer.wrap() def _auth_with_env_vars(self) -> Tuple[Credentials, Optional[int]]: """ Authenticate with AWS Environment Variables """ import boto3 with tracer.trace("boto3.Session()"): session = boto3.Session() credentials = session.get_credentials() return credentials, None @tracer.wrap() def _get_default_ttl_for_boto3_credentials(self) -> int: """ Get the default TTL for boto3 credentials Returns `3600-60` which is 59 minutes """ return 3600 - 60 def get_runtime_endpoint( self, api_base: Optional[str], aws_bedrock_runtime_endpoint: Optional[str], aws_region_name: str, endpoint_type: Optional[Literal["runtime", "agent", "agentcore"]] = "runtime", ) -> Tuple[str, str]: env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT") if api_base is not None: endpoint_url = api_base elif aws_bedrock_runtime_endpoint is not None and isinstance( aws_bedrock_runtime_endpoint, str ): endpoint_url = aws_bedrock_runtime_endpoint elif env_aws_bedrock_runtime_endpoint and isinstance( env_aws_bedrock_runtime_endpoint, str ): endpoint_url = env_aws_bedrock_runtime_endpoint else: endpoint_url = self._select_default_endpoint_url( endpoint_type=endpoint_type, aws_region_name=aws_region_name, ) # Determine proxy_endpoint_url if aws_bedrock_runtime_endpoint is not None and isinstance( aws_bedrock_runtime_endpoint, str ): proxy_endpoint_url = aws_bedrock_runtime_endpoint elif env_aws_bedrock_runtime_endpoint and isinstance( env_aws_bedrock_runtime_endpoint, str ): proxy_endpoint_url = env_aws_bedrock_runtime_endpoint else: proxy_endpoint_url = endpoint_url return endpoint_url, proxy_endpoint_url def _select_default_endpoint_url( self, endpoint_type: Optional[Literal["runtime", "agent", "agentcore"]], aws_region_name: str, ) -> str: """ Select the default endpoint url based on the endpoint type Default endpoint url is https://bedrock-runtime.{aws_region_name}.amazonaws.com """ if endpoint_type == "agent": return f"https://bedrock-agent-runtime.{aws_region_name}.amazonaws.com" elif endpoint_type == "agentcore": return f"https://bedrock-agentcore.{aws_region_name}.amazonaws.com" else: return f"https://bedrock-runtime.{aws_region_name}.amazonaws.com" def _get_boto_credentials_from_optional_params( self, optional_params: dict, model: Optional[str] = None ) -> Boto3CredentialsInfo: """ Get boto3 credentials from optional params Args: optional_params (dict): Optional parameters for the model call Returns: Credentials: Boto3 credentials object """ try: from botocore.credentials import Credentials except ImportError: raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.") ## CREDENTIALS ## # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) aws_access_key_id = optional_params.pop("aws_access_key_id", None) aws_session_token = optional_params.pop("aws_session_token", None) aws_region_name = self._get_aws_region_name(optional_params, model) optional_params.pop("aws_region_name", None) aws_role_name = optional_params.pop("aws_role_name", None) aws_session_name = optional_params.pop("aws_session_name", None) aws_profile_name = optional_params.pop("aws_profile_name", None) aws_web_identity_token = optional_params.pop("aws_web_identity_token", None) aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None) aws_bedrock_runtime_endpoint = optional_params.pop( "aws_bedrock_runtime_endpoint", None ) # https://bedrock-runtime.{region_name}.amazonaws.com aws_external_id = optional_params.pop("aws_external_id", None) credentials: Credentials = self.get_credentials( aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, aws_session_token=aws_session_token, aws_region_name=aws_region_name, aws_session_name=aws_session_name, aws_profile_name=aws_profile_name, aws_role_name=aws_role_name, aws_web_identity_token=aws_web_identity_token, aws_sts_endpoint=aws_sts_endpoint, aws_external_id=aws_external_id, ) return Boto3CredentialsInfo( credentials=credentials, aws_region_name=aws_region_name, aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, ) @tracer.wrap() def get_request_headers( self, credentials: Credentials, aws_region_name: str, extra_headers: Optional[dict], endpoint_url: str, data: Union[str, bytes], headers: dict, api_key: Optional[str] = None, ) -> AWSPreparedRequest: if api_key is not None: aws_bearer_token: Optional[str] = api_key else: aws_bearer_token = get_secret_str("AWS_BEARER_TOKEN_BEDROCK") if aws_bearer_token: try: from botocore.awsrequest import AWSRequest except ImportError: raise ImportError( "Missing boto3 to call bedrock. Run 'pip install boto3'." ) headers["Authorization"] = f"Bearer {aws_bearer_token}" request = AWSRequest( method="POST", url=endpoint_url, data=data, headers=headers ) else: try: from botocore.auth import SigV4Auth from botocore.awsrequest import AWSRequest except ImportError: raise ImportError( "Missing boto3 to call bedrock. Run 'pip install boto3'." ) # Filter headers for AWS signature calculation # AWS SigV4 only includes specific headers in signature calculation aws_signature_headers = self._filter_headers_for_aws_signature(headers) sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name) request = AWSRequest( method="POST", url=endpoint_url, data=data, headers=aws_signature_headers, ) sigv4.add_auth(request) # Add back all original headers (including forwarded ones) after signature calculation for header_name, header_value in headers.items(): request.headers[header_name] = header_value if ( extra_headers is not None and "Authorization" in extra_headers ): # prevent sigv4 from overwriting the auth header request.headers["Authorization"] = extra_headers["Authorization"] prepped = request.prepare() return prepped def _filter_headers_for_aws_signature(self, headers: dict) -> dict: """ Filter headers to only include those that AWS SigV4 includes in signature calculation. This Fixes forwarded client headers from breaking the signature calculation. """ aws_signature_headers = {} aws_headers = { "host", "content-type", "date", "x-amz-date", "x-amz-security-token", "x-amz-content-sha256", "x-amz-algorithm", "x-amz-credential", "x-amz-signedheaders", "x-amz-signature", } for header_name, header_value in headers.items(): header_lower = header_name.lower() if ( header_lower in aws_headers or header_lower.startswith("x-amz-") or header_lower.startswith("x-amzn-") ): aws_signature_headers[header_name] = header_value return aws_signature_headers def _sign_request( self, service_name: Literal["bedrock", "sagemaker", "bedrock-agentcore", "s3vectors"], headers: dict, optional_params: dict, request_data: dict, api_base: str, model: Optional[str] = None, stream: Optional[bool] = None, fake_stream: Optional[bool] = None, api_key: Optional[str] = None, ) -> Tuple[dict, Optional[bytes]]: """ Sign a request for Bedrock or Sagemaker Returns: Tuple[dict, Optional[str]]: A tuple containing the headers and the json str body of the request """ if api_key is not None: aws_bearer_token: Optional[str] = api_key else: aws_bearer_token = get_secret_str("AWS_BEARER_TOKEN_BEDROCK") # If aws bearer token is set, use it directly in the header if aws_bearer_token: headers = headers or {} headers["Content-Type"] = "application/json" headers["Authorization"] = f"Bearer {aws_bearer_token}" return headers, json.dumps(request_data).encode() # If no bearer token is set, proceed with the existing SigV4 authentication try: from botocore.auth import SigV4Auth from botocore.awsrequest import AWSRequest from botocore.credentials import Credentials except ImportError: raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.") ## CREDENTIALS ## # pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them aws_secret_access_key = optional_params.get("aws_secret_access_key", None) aws_access_key_id = optional_params.get("aws_access_key_id", None) aws_session_token = optional_params.get("aws_session_token", None) aws_role_name = optional_params.get("aws_role_name", None) aws_session_name = optional_params.get("aws_session_name", None) aws_profile_name = optional_params.get("aws_profile_name", None) aws_web_identity_token = optional_params.get("aws_web_identity_token", None) aws_sts_endpoint = optional_params.get("aws_sts_endpoint", None) aws_external_id = optional_params.get("aws_external_id", None) aws_region_name = self._get_aws_region_name( optional_params=optional_params, model=model ) credentials: Credentials = self.get_credentials( aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, aws_session_token=aws_session_token, aws_region_name=aws_region_name, aws_session_name=aws_session_name, aws_profile_name=aws_profile_name, aws_role_name=aws_role_name, aws_web_identity_token=aws_web_identity_token, aws_sts_endpoint=aws_sts_endpoint, aws_external_id=aws_external_id, ) sigv4 = SigV4Auth(credentials, service_name, aws_region_name) if headers is not None: headers = {"Content-Type": "application/json", **headers} else: headers = {"Content-Type": "application/json"} aws_signature_headers = self._filter_headers_for_aws_signature(headers) request = AWSRequest( method="POST", url=api_base, data=json.dumps(request_data), headers=aws_signature_headers, ) sigv4.add_auth(request) request_headers_dict = dict(request.headers) # Add back original headers after signing. Only headers in SignedHeaders # are integrity-protected; forwarded headers (x-forwarded-*) must remain unsigned. for header_name, header_value in headers.items(): request_headers_dict[header_name] = header_value if ( headers is not None and "Authorization" in headers ): # prevent sigv4 from overwriting the auth header request_headers_dict["Authorization"] = headers["Authorization"] return request_headers_dict, request.body