chore: initial snapshot for gitea/github upload
This commit is contained in:
@@ -0,0 +1,118 @@
|
||||
This makes it easier to pass through requests to the LLM APIs.
|
||||
|
||||
E.g. Route to VLLM's `/classify` endpoint:
|
||||
|
||||
|
||||
## SDK (Basic)
|
||||
|
||||
```python
|
||||
import litellm
|
||||
|
||||
|
||||
response = litellm.llm_passthrough_route(
|
||||
model="hosted_vllm/papluca/xlm-roberta-base-language-detection",
|
||||
method="POST",
|
||||
endpoint="classify",
|
||||
api_base="http://localhost:8090",
|
||||
api_key=None,
|
||||
json={
|
||||
"model": "swapped-for-litellm-model",
|
||||
"input": "Hello, world!",
|
||||
}
|
||||
)
|
||||
|
||||
print(response)
|
||||
```
|
||||
|
||||
## SDK (Router)
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from litellm import Router
|
||||
|
||||
router = Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "roberta-base-language-detection",
|
||||
"litellm_params": {
|
||||
"model": "hosted_vllm/papluca/xlm-roberta-base-language-detection",
|
||||
"api_base": "http://localhost:8090",
|
||||
}
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"model": "roberta-base-language-detection",
|
||||
"method": "POST",
|
||||
"endpoint": "classify",
|
||||
"api_base": "http://localhost:8090",
|
||||
"api_key": None,
|
||||
"json": {
|
||||
"model": "roberta-base-language-detection",
|
||||
"input": "Hello, world!",
|
||||
}
|
||||
}
|
||||
|
||||
async def main():
|
||||
response = await router.allm_passthrough_route(**request_data)
|
||||
print(response)
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
## PROXY
|
||||
|
||||
1. Setup config.yaml
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: roberta-base-language-detection
|
||||
litellm_params:
|
||||
model: hosted_vllm/papluca/xlm-roberta-base-language-detection
|
||||
api_base: http://localhost:8090
|
||||
```
|
||||
|
||||
2. Run the proxy
|
||||
|
||||
```bash
|
||||
litellm proxy --config config.yaml
|
||||
|
||||
# RUNNING on http://localhost:4000
|
||||
```
|
||||
|
||||
3. Use the proxy
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:4000/vllm/classify \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer <your-api-key>" \
|
||||
  -d '{"model": "roberta-base-language-detection", "input": "Hello, world!"}'
|
||||
```
|
||||
|
||||
# How to add a provider for passthrough
|
||||
|
||||
See [VLLMModelInfo](https://github.com/BerriAI/litellm/blob/main/litellm/llms/vllm/common_utils.py) for an example.
|
||||
|
||||
1. Inherit from BaseLLMModelInfo
|
||||
|
||||
```python
|
||||
from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
|
||||
|
||||
class VLLMModelInfo(BaseLLMModelInfo):
|
||||
pass
|
||||
```
|
||||
|
||||
2. Register the provider in the ProviderConfigManager.get_provider_model_info
|
||||
|
||||
```python
|
||||
from litellm.utils import ProviderConfigManager
|
||||
from litellm.types.utils import LlmProviders
|
||||
|
||||
provider_config = ProviderConfigManager.get_provider_model_info(
|
||||
model="my-test-model", provider=LlmProviders.VLLM
|
||||
)
|
||||
|
||||
print(provider_config)
|
||||
```
|
||||
@@ -0,0 +1,8 @@
|
||||
from .main import allm_passthrough_route, llm_passthrough_route
|
||||
from .utils import BasePassthroughUtils
|
||||
|
||||
__all__ = [
|
||||
"allm_passthrough_route",
|
||||
"llm_passthrough_route",
|
||||
"BasePassthroughUtils",
|
||||
]
|
||||
@@ -0,0 +1,433 @@
|
||||
"""
|
||||
This module is used to pass through requests to the LLM APIs.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
from functools import partial
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Coroutine,
|
||||
Generator,
|
||||
List,
|
||||
Optional,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import httpx
|
||||
from httpx._types import CookieTypes, QueryParamTypes, RequestFiles
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
|
||||
from litellm.passthrough.utils import CommonUtils
|
||||
from litellm.utils import client
|
||||
|
||||
# Module-level handler instance, reused by both routes for provider-specific
# error translation via _handle_error.
base_llm_http_handler = BaseLLMHTTPHandler()
from .utils import BasePassthroughUtils

if TYPE_CHECKING:
    # Imported only for static type checking to avoid import cycles at runtime.
    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
    from litellm.llms.base_llm.passthrough.transformation import BasePassthroughConfig
|
||||
|
||||
|
||||
@client
async def allm_passthrough_route(
    *,
    method: str,
    endpoint: str,
    model: str,
    custom_llm_provider: Optional[str] = None,
    api_base: Optional[str] = None,
    api_key: Optional[str] = None,
    request_query_params: Optional[dict] = None,
    request_headers: Optional[dict] = None,
    content: Optional[Any] = None,
    data: Optional[dict] = None,
    files: Optional[RequestFiles] = None,
    json: Optional[Any] = None,
    params: Optional[QueryParamTypes] = None,
    cookies: Optional[CookieTypes] = None,
    client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
    **kwargs,
) -> Union[httpx.Response, AsyncGenerator[Any, Any]]:
    """
    Async: Pass through a request to an LLM provider endpoint.

    Resolves the provider from ``model``/``custom_llm_provider``, then runs the
    sync ``llm_passthrough_route`` in an executor with
    ``allm_passthrough_route=True`` so it returns a coroutine that performs the
    actual HTTP call on the async client.

    Returns:
        An ``httpx.Response`` for non-streaming calls, or an async generator of
        raw byte chunks for streaming calls.

    Raises:
        httpx.HTTPStatusError: re-raised as-is so the caller (e.g. the proxy
            layer) can convert it to an appropriate response.
        Exception: other errors are translated via the provider's error
            handling when a provider config is available.
    """
    try:
        loop = asyncio.get_event_loop()
        # Signal the sync route to use the async client and return a coroutine.
        kwargs["allm_passthrough_route"] = True

        model, custom_llm_provider, api_key, api_base = get_llm_provider(
            model=model,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            api_key=api_key,
        )

        from litellm.types.utils import LlmProviders
        from litellm.utils import ProviderConfigManager

        # Caller-supplied provider_config wins; otherwise look it up.
        provider_config = cast(
            Optional["BasePassthroughConfig"], kwargs.get("provider_config")
        ) or ProviderConfigManager.get_provider_passthrough_config(
            provider=LlmProviders(custom_llm_provider),
            model=model,
        )

        if provider_config is None:
            raise Exception(f"Provider {custom_llm_provider} not found")

        func = partial(
            llm_passthrough_route,
            method=method,
            endpoint=endpoint,
            model=model,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            api_key=api_key,
            request_query_params=request_query_params,
            request_headers=request_headers,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            cookies=cookies,
            client=client,
            **kwargs,
        )

        # Run the (sync) request-building code in an executor, preserving
        # contextvars (logging context etc.) across the thread boundary.
        ctx = contextvars.copy_context()
        func_with_context = partial(ctx.run, func)
        init_response = await loop.run_in_executor(None, func_with_context)

        # Since allm_passthrough_route=True, we always get a coroutine from _async_passthrough_request
        if asyncio.iscoroutine(init_response):
            response = await init_response

            # Only call raise_for_status if it's a Response object (not a generator)
            if isinstance(response, httpx.Response):
                response.raise_for_status()

            return response
        else:
            # This shouldn't happen when allm_passthrough_route=True, but handle it for type safety
            raise Exception("Expected coroutine from async passthrough route")

    except httpx.HTTPStatusError as e:
        # For HTTP errors, re-raise as-is to preserve the original error details
        # The caller (e.g., proxy layer) can handle conversion to appropriate response format
        raise e
    except Exception as e:
        # For other exceptions, use provider-specific error handling
        from litellm.types.utils import LlmProviders
        from litellm.utils import ProviderConfigManager

        # Get the provider using the same logic as llm_passthrough_route
        _, resolved_custom_llm_provider, _, _ = get_llm_provider(
            model=model,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            api_key=api_key,
        )

        # Get provider config if available
        provider_config = None
        if resolved_custom_llm_provider:
            try:
                provider_config = cast(
                    Optional["BasePassthroughConfig"], kwargs.get("provider_config")
                ) or ProviderConfigManager.get_provider_passthrough_config(
                    provider=LlmProviders(resolved_custom_llm_provider),
                    model=model,
                )
            except Exception:
                # If we can't get provider config, pass None
                pass

        if provider_config is None:
            # If no provider config available, raise the original exception
            raise e

        raise base_llm_http_handler._handle_error(
            e=e,
            provider_config=provider_config,
        )
|
||||
|
||||
|
||||
@client
def llm_passthrough_route(
    *,
    method: str,
    endpoint: str,
    model: str,
    custom_llm_provider: Optional[str] = None,
    api_base: Optional[str] = None,
    api_key: Optional[str] = None,
    request_query_params: Optional[dict] = None,
    request_headers: Optional[dict] = None,
    allm_passthrough_route: bool = False,
    content: Optional[Any] = None,
    data: Optional[dict] = None,
    files: Optional[RequestFiles] = None,
    json: Optional[Any] = None,
    params: Optional[QueryParamTypes] = None,
    cookies: Optional[CookieTypes] = None,
    client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
    **kwargs,
) -> Union[
    httpx.Response,
    Coroutine[Any, Any, httpx.Response],
    Coroutine[Any, Any, Union[httpx.Response, AsyncGenerator[Any, Any]]],
    Generator[Any, Any, Any],
    AsyncGenerator[Any, Any],
]:
    """
    Pass through requests to the LLM APIs.

    Step 1. Build the request
    Step 2. Send the request
    Step 3. Return the response

    Args:
        method: HTTP method to use (e.g. "POST").
        endpoint: Provider endpoint path to hit (e.g. "classify").
        model: Model name; provider is resolved from it via get_llm_provider.
        custom_llm_provider: Explicit provider override.
        api_base: Base URL of the provider deployment.
        api_key: Credential forwarded to the provider.
        request_query_params: Query params from the incoming request.
        request_headers: Headers from the incoming request.
        allm_passthrough_route: Internal flag — True when invoked from the
            async wrapper; switches to the async client and returns a coroutine.
        content, data, files, json, params, cookies: Forwarded to
            httpx ``build_request``.
        client: Optional pre-configured handler; defaults to litellm's
            module-level sync/async client.

    Returns:
        ``httpx.Response`` (sync, non-streaming), a byte-chunk generator
        (sync streaming), or a coroutine resolving to either (async path).
    """
    from litellm.litellm_core_utils.get_litellm_params import get_litellm_params
    from litellm.types.utils import LlmProviders
    from litellm.utils import ProviderConfigManager

    _is_async = allm_passthrough_route

    if client is None:
        if _is_async:
            client = litellm.module_level_aclient
        else:
            client = litellm.module_level_client

    # NOTE(review): assumes the @client decorator injected litellm_logging_obj
    # into kwargs — the cast is unchecked; confirm against the decorator.
    litellm_logging_obj = cast("LiteLLMLoggingObj", kwargs.get("litellm_logging_obj"))

    model, custom_llm_provider, api_key, api_base = get_llm_provider(
        model=model,
        custom_llm_provider=custom_llm_provider,
        api_base=api_base,
        api_key=api_key,
    )

    litellm_params_dict = get_litellm_params(**kwargs)

    # Add model_id to litellm_params if present in kwargs (for Bedrock Application Inference Profiles)
    if "model_id" in kwargs:
        litellm_params_dict["model_id"] = kwargs["model_id"]

    litellm_logging_obj.update_environment_variables(
        model=model,
        litellm_params=litellm_params_dict,
        optional_params={},
        endpoint=endpoint,
        custom_llm_provider=custom_llm_provider,
        request_data=data if data else json,
    )

    # Caller-supplied provider_config wins; otherwise look it up.
    provider_config = cast(
        Optional["BasePassthroughConfig"], kwargs.get("provider_config")
    ) or ProviderConfigManager.get_provider_passthrough_config(
        provider=LlmProviders(custom_llm_provider),
        model=model,
    )
    if provider_config is None:
        raise Exception(f"Provider {custom_llm_provider} not found")

    updated_url, base_target_url = provider_config.get_complete_url(
        api_base=api_base,
        api_key=api_key,
        model=model,
        endpoint=endpoint,
        request_query_params=request_query_params,
        litellm_params=litellm_params_dict,
    )

    # [TODO: Refactor to bedrockpassthroughconfig] need to encode the id of application-inference-profile for bedrock
    if custom_llm_provider == "bedrock" and "application-inference-profile" in endpoint:
        encoded_url_str = CommonUtils.encode_bedrock_runtime_modelid_arn(
            str(updated_url)
        )
        updated_url = httpx.URL(encoded_url_str)

    # Add or update query parameters
    provider_api_key = provider_config.get_api_key(api_key)

    auth_headers = provider_config.validate_environment(
        headers={},
        model=model,
        messages=[],
        optional_params={},
        litellm_params=litellm_params_dict,
        api_key=provider_api_key,
        api_base=base_target_url,
    )

    headers = BasePassthroughUtils.forward_headers_from_request(
        request_headers=request_headers or {},
        headers=auth_headers,
        forward_headers=False,
    )

    # Sign AFTER header merge so the signature covers the final header set.
    headers, signed_json_body = provider_config.sign_request(
        headers=headers,
        litellm_params=litellm_params_dict,
        request_data=data if data else json,
        api_base=str(updated_url),
        model=model,
    )

    ## SWAP MODEL IN JSON BODY [TODO: REFACTOR TO A provider_config.transform_request method]
    if json and isinstance(json, dict) and "model" in json:
        json["model"] = model

    # Prefer the signed body; fall back to content, then data/json. data/json
    # are suppressed whenever a signed body or raw content is present.
    request = client.client.build_request(
        method=method,
        url=updated_url,
        content=signed_json_body if signed_json_body is not None else content,
        data=data if (signed_json_body is None and content is None) else None,
        files=files,
        json=json if (signed_json_body is None and content is None) else None,
        params=params,
        headers=headers,
        cookies=cookies,
    )

    ## IS STREAMING REQUEST
    is_streaming_request = provider_config.is_streaming_request(
        endpoint=endpoint,
        request_data=data or json or {},
    )

    # Update logging object with streaming status
    litellm_logging_obj.stream = is_streaming_request

    ## LOGGING PRE-CALL
    request_data = data if data else json
    litellm_logging_obj.pre_call(
        input=request_data,
        api_key=provider_api_key,
        additional_args={
            "complete_input_dict": request_data,
            "api_base": str(updated_url),
            "headers": headers,
        },
    )

    try:
        if _is_async:
            # Return the coroutine to be awaited by the caller
            return _async_passthrough_request(
                client=client,
                request=request,
                is_streaming_request=is_streaming_request,
                litellm_logging_obj=litellm_logging_obj,
                provider_config=provider_config,
            )
        else:
            # Sync path - client.client.send returns Response directly
            response: httpx.Response = client.client.send(request=request, stream=is_streaming_request)  # type: ignore
            response.raise_for_status()

            if (
                hasattr(response, "iter_bytes") and is_streaming_request
            ):  # yield the chunk, so we can store it in the logging object
                return _sync_streaming(response, litellm_logging_obj, provider_config)
            else:
                # For non-streaming responses, yield the entire response
                return response
    except Exception as e:
        if provider_config is None:
            raise e
        raise base_llm_http_handler._handle_error(
            e=e,
            provider_config=provider_config,
        )
|
||||
|
||||
|
||||
async def _async_passthrough_request(
    client: Union[HTTPHandler, AsyncHTTPHandler],
    request: httpx.Request,
    is_streaming_request: bool,
    litellm_logging_obj: "LiteLLMLoggingObj",
    provider_config: "BasePassthroughConfig",
) -> Union[httpx.Response, AsyncGenerator[Any, Any]]:
    """
    Handle async passthrough requests.

    Sends the pre-built request on the async client. Streaming requests hand
    the pending send off to ``_async_streaming`` (which awaits it and re-yields
    chunks); non-streaming requests are awaited, fully read, and returned.
    """
    # For async clients, send() hands back a coroutine rather than a Response.
    pending = client.client.send(request=request, stream=is_streaming_request)

    # Guard: a sync client here would return a Response directly — reject it.
    if not asyncio.iscoroutine(pending):
        raise Exception("Expected coroutine from async client")

    if is_streaming_request:
        # Pass the un-awaited coroutine along; _async_streaming awaits it.
        return _async_streaming(
            response=pending,
            litellm_logging_obj=litellm_logging_obj,
            provider_config=provider_config,
        )

    response = await pending
    await response.aread()
    response.raise_for_status()
    return response
|
||||
|
||||
|
||||
def _sync_streaming(
    response: httpx.Response,
    litellm_logging_obj: "LiteLLMLoggingObj",
    provider_config: "BasePassthroughConfig",
):
    """
    Re-yield the raw bytes of a streaming sync response while buffering them,
    then hand the buffered payload to the logging object off-thread.

    Args:
        response: A streaming ``httpx.Response`` already checked for status.
        litellm_logging_obj: Logging object that receives the collected chunks.
        provider_config: Provider config forwarded to the flush call.

    Yields:
        Raw ``bytes`` chunks, unmodified, as they arrive.
    """
    from litellm.utils import executor

    # Buffer every chunk so the complete payload can be logged after the
    # stream finishes. (The previous try/except that only re-raised the same
    # exception was a no-op and has been removed.)
    raw_bytes: List[bytes] = []
    for chunk in response.iter_bytes():  # type: ignore
        raw_bytes.append(chunk)
        yield chunk

    # Flush on the shared executor so logging never blocks the consumer.
    executor.submit(
        litellm_logging_obj.flush_passthrough_collected_chunks,
        raw_bytes=raw_bytes,
        provider_config=provider_config,
    )
|
||||
|
||||
|
||||
async def _async_streaming(
    response: Coroutine[Any, Any, httpx.Response],
    litellm_logging_obj: "LiteLLMLoggingObj",
    provider_config: "BasePassthroughConfig",
):
    """
    Await the pending send, then re-yield the response's raw byte chunks while
    buffering them; on success, schedule an async flush of the collected
    chunks into the logging object.

    Args:
        response: The un-awaited coroutine returned by the async client's send.
        litellm_logging_obj: Logging object that receives the collected chunks.
        provider_config: Provider config forwarded to the flush call.

    Yields:
        Raw ``bytes`` chunks, unmodified, as they arrive.
    """
    iter_response = await response
    try:
        iter_response.raise_for_status()
        raw_bytes: List[bytes] = []

        async for chunk in iter_response.aiter_bytes():  # type: ignore
            raw_bytes.append(chunk)
            yield chunk

        # NOTE(review): the task handle is not retained — a fire-and-forget
        # task can be garbage-collected before it runs; confirm this flush
        # reliably completes or keep a reference to the task.
        asyncio.create_task(
            litellm_logging_obj.async_flush_passthrough_collected_chunks(
                raw_bytes=raw_bytes,
                provider_config=provider_config,
            )
        )
    except Exception:
        # Close the underlying stream before propagating, best-effort.
        try:
            await iter_response.aclose()
        except Exception:
            pass
        raise
|
||||
@@ -0,0 +1,119 @@
|
||||
from typing import Dict, List, Mapping, Optional, Union
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.constants import PASS_THROUGH_HEADER_PREFIX
|
||||
|
||||
|
||||
class BasePassthroughUtils:
    """Helpers for building passthrough requests: query-param merging and
    request-header forwarding."""

    @staticmethod
    def get_merged_query_parameters(
        existing_url: "httpx.URL",
        request_query_params: Mapping[str, Union[str, list]],
        default_query_params: Optional[Dict[str, Union[str, list]]] = None,
    ) -> Dict[str, Union[str, List[str]]]:
        """
        Merge query parameters from three sources, lowest to highest priority:
        provider defaults, params already on the target URL, then params from
        the incoming request (so clients can override anything).

        Args:
            existing_url: Target URL whose query string is already populated.
            request_query_params: Params taken from the incoming request.
            default_query_params: Provider-level defaults, if any.

        Returns:
            The merged parameter dict.
        """
        # Get the existing query params from the target URL
        existing_query_string = existing_url.query.decode("utf-8")
        existing_query_params = parse_qs(existing_query_string)

        # parse_qs returns a dict where each value is a list; flatten
        # single-element lists back to plain strings.
        updated_existing_query_params = {
            k: v[0] if len(v) == 1 else v for k, v in existing_query_params.items()
        }

        # Start with default query params (lowest priority)
        merged_params: Dict[str, Union[str, List[str]]] = {}
        if default_query_params:
            merged_params.update(default_query_params)

        # Override with existing URL query params (medium priority)
        merged_params.update(updated_existing_query_params)

        # Override with request query params (highest priority - client can override anything)
        merged_params.update(request_query_params)

        return merged_params

    @staticmethod
    def forward_headers_from_request(
        request_headers: dict,
        headers: dict,
        forward_headers: Optional[bool] = False,
    ) -> dict:
        """
        Helper to forward headers from original request.

        Also handles 'x-pass-' prefixed headers which are always forwarded
        with the prefix stripped, regardless of forward_headers setting.
        e.g., 'x-pass-anthropic-beta: value' becomes 'anthropic-beta: value'

        Args:
            request_headers: Headers from the incoming request (not mutated).
            headers: Pre-computed headers (e.g. auth); these win on conflict.
            forward_headers: When True, forward all request headers except
                hop-by-hop ones; when False, only 'x-pass-' headers pass.

        Returns:
            The combined header dict.
        """
        if forward_headers is True:
            # Work on a filtered copy instead of pop()-ing, so the caller's
            # dict is never mutated. content-length/host must NOT be forwarded.
            request_headers = {
                k: v
                for k, v in request_headers.items()
                if k not in ("content-length", "host")
            }

            # Combine request headers with custom headers (custom headers win)
            headers = {**request_headers, **headers}

        # Always process x-pass- prefixed headers (strip prefix and forward)
        for header_name, header_value in request_headers.items():
            if header_name.lower().startswith(PASS_THROUGH_HEADER_PREFIX):
                # Strip the 'x-pass-' prefix to get the actual header name
                actual_header_name = header_name[len(PASS_THROUGH_HEADER_PREFIX) :]
                headers[actual_header_name] = header_value

        return headers
|
||||
|
||||
|
||||
class CommonUtils:
    """Shared passthrough utilities that are not provider-specific."""

    @staticmethod
    def encode_bedrock_runtime_modelid_arn(endpoint: str) -> str:
        """
        Encodes any "/" found in the modelId of an AWS Bedrock Runtime Endpoint when arns are passed in.

        The modelId value can be an ARN which contains slashes that SHOULD NOT
        be treated as path separators, e.g. for endpoint /model/<modelId>/invoke:

            arn:aws:bedrock:ap-southeast-1:123456789012:application-inference-profile/abdefg12334

        must become

            arn:aws:bedrock:ap-southeast-1:123456789012:application-inference-profile%2Fabdefg12334

        so the ARN is treated as one path component. Otherwise the endpoint
        returns a 500 error when passed to Bedrock.

        See the APIs in https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Operations_Amazon_Bedrock_Runtime.html
        for the modelId patterns the regexes below are based on.

        Args:
            endpoint (str): The original endpoint string which may contain ARNs that contain slashes.

        Returns:
            str: The endpoint with properly encoded ARN slashes.
        """
        import re

        # Early exit: if no ARN detected, return unchanged
        if "arn:aws:" not in endpoint:
            return endpoint

        # (pattern, replacement) pairs. Custom model comes first because its
        # modelId contains two slashes; all other resource types have one.
        patterns = [
            (r"(custom-model)/([a-z0-9.-]+)/([a-z0-9]+)", r"\1%2F\2%2F\3"),
            (r"(:application-inference-profile)/", r"\1%2F"),
            (r"(:inference-profile)/", r"\1%2F"),
            (r"(:foundation-model)/", r"\1%2F"),
            (r"(:imported-model)/", r"\1%2F"),
            (r"(:provisioned-model)/", r"\1%2F"),
            (r"(:prompt)/", r"\1%2F"),
            (r"(:endpoint)/", r"\1%2F"),
            (r"(:prompt-router)/", r"\1%2F"),
            (r"(:default-prompt-router)/", r"\1%2F"),
        ]

        for pattern, replacement in patterns:
            # re.subn searches and substitutes in a single pass (the previous
            # search-then-sub scanned each pattern twice). Stop at the first
            # pattern that matched, since each ARN has only one resource type.
            endpoint, num_subs = re.subn(pattern, replacement, endpoint)
            if num_subs:
                break

        return endpoint
|
||||
Reference in New Issue
Block a user