chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,419 @@
|
||||
"""
|
||||
Generic container file handler for LiteLLM.
|
||||
|
||||
This module provides a single generic handler that can process any container file
|
||||
endpoint defined in endpoints.json, eliminating the need for individual handler methods.
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Coroutine, Dict, Optional, Type, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.types.containers.main import (
|
||||
ContainerFileListResponse,
|
||||
ContainerFileObject,
|
||||
DeleteContainerFileResponse,
|
||||
)
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.base_llm.containers.transformation import BaseContainerConfig
|
||||
|
||||
|
||||
# Response type mapping
|
||||
RESPONSE_TYPES: Dict[str, Type] = {
|
||||
"ContainerFileListResponse": ContainerFileListResponse,
|
||||
"ContainerFileObject": ContainerFileObject,
|
||||
"DeleteContainerFileResponse": DeleteContainerFileResponse,
|
||||
}
|
||||
|
||||
|
||||
def _load_endpoints_config() -> Dict:
|
||||
"""Load the endpoints configuration from JSON file."""
|
||||
config_path = Path(__file__).parent.parent.parent / "containers" / "endpoints.json"
|
||||
with open(config_path) as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def _get_endpoint_config(endpoint_name: str) -> Optional[Dict]:
|
||||
"""Get config for a specific endpoint by name."""
|
||||
config = _load_endpoints_config()
|
||||
for endpoint in config["endpoints"]:
|
||||
if endpoint["name"] == endpoint_name or endpoint["async_name"] == endpoint_name:
|
||||
return endpoint
|
||||
return None
|
||||
|
||||
|
||||
def _build_url(
|
||||
api_base: str,
|
||||
path_template: str,
|
||||
path_params: Dict[str, str],
|
||||
) -> str:
|
||||
"""Build the full URL by substituting path parameters.
|
||||
|
||||
The api_base from get_complete_url already includes /containers,
|
||||
so we need to strip that prefix from the path_template.
|
||||
"""
|
||||
# api_base ends with /containers, path_template starts with /containers
|
||||
# So we need to strip /containers from the path
|
||||
if path_template.startswith("/containers"):
|
||||
path_template = path_template[len("/containers") :]
|
||||
|
||||
url = f"{api_base.rstrip('/')}{path_template}"
|
||||
for param, value in path_params.items():
|
||||
url = url.replace(f"{{{param}}}", value)
|
||||
return url
|
||||
|
||||
|
||||
def _build_query_params(
|
||||
query_param_names: list,
|
||||
kwargs: Dict[str, Any],
|
||||
) -> Dict[str, str]:
|
||||
"""Build query parameters from kwargs."""
|
||||
params = {}
|
||||
for param_name in query_param_names:
|
||||
value = kwargs.get(param_name)
|
||||
if value is not None:
|
||||
params[param_name] = str(value) if not isinstance(value, str) else value
|
||||
return params
|
||||
|
||||
|
||||
def _prepare_multipart_file_upload(
|
||||
file: Any,
|
||||
headers: Dict[str, Any],
|
||||
) -> tuple:
|
||||
"""
|
||||
Prepare file and headers for multipart upload.
|
||||
|
||||
Returns:
|
||||
Tuple of (files_dict, headers_without_content_type)
|
||||
"""
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
extract_file_data,
|
||||
)
|
||||
|
||||
extracted = extract_file_data(file)
|
||||
filename = extracted.get("filename") or "file"
|
||||
content = extracted.get("content") or b""
|
||||
content_type = extracted.get("content_type") or "application/octet-stream"
|
||||
files = {"file": (filename, content, content_type)}
|
||||
|
||||
# Remove content-type header - httpx will set it automatically for multipart
|
||||
headers_copy = headers.copy()
|
||||
headers_copy.pop("content-type", None)
|
||||
headers_copy.pop("Content-Type", None)
|
||||
|
||||
return files, headers_copy
|
||||
|
||||
|
||||
class GenericContainerHandler:
|
||||
"""
|
||||
Generic handler for container file API endpoints.
|
||||
|
||||
This single handler can process any endpoint defined in endpoints.json,
|
||||
eliminating the need for individual handler methods per endpoint.
|
||||
"""
|
||||
|
||||
def handle(
|
||||
self,
|
||||
endpoint_name: str,
|
||||
container_provider_config: "BaseContainerConfig",
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
logging_obj: "LiteLLMLoggingObj",
|
||||
extra_headers: Optional[Dict[str, Any]] = None,
|
||||
extra_query: Optional[Dict[str, Any]] = None,
|
||||
timeout: Union[float, httpx.Timeout] = 600,
|
||||
_is_async: bool = False,
|
||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||
**kwargs,
|
||||
) -> Union[Any, Coroutine[Any, Any, Any]]:
|
||||
"""
|
||||
Generic handler for any container file endpoint.
|
||||
|
||||
Args:
|
||||
endpoint_name: Name of the endpoint (e.g., "list_container_files")
|
||||
container_provider_config: Provider-specific configuration
|
||||
litellm_params: LiteLLM parameters including api_key, api_base
|
||||
logging_obj: Logging object for request logging
|
||||
extra_headers: Additional HTTP headers
|
||||
extra_query: Additional query parameters
|
||||
timeout: Request timeout
|
||||
_is_async: Whether to make async request
|
||||
client: Optional HTTP client
|
||||
**kwargs: Path params and query params (e.g., container_id, file_id, after, limit)
|
||||
"""
|
||||
if _is_async:
|
||||
return self._async_handle(
|
||||
endpoint_name=endpoint_name,
|
||||
container_provider_config=container_provider_config,
|
||||
litellm_params=litellm_params,
|
||||
logging_obj=logging_obj,
|
||||
extra_headers=extra_headers,
|
||||
extra_query=extra_query,
|
||||
timeout=timeout,
|
||||
client=client,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return self._sync_handle(
|
||||
endpoint_name=endpoint_name,
|
||||
container_provider_config=container_provider_config,
|
||||
litellm_params=litellm_params,
|
||||
logging_obj=logging_obj,
|
||||
extra_headers=extra_headers,
|
||||
extra_query=extra_query,
|
||||
timeout=timeout,
|
||||
client=client,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _sync_handle(
|
||||
self,
|
||||
endpoint_name: str,
|
||||
container_provider_config: "BaseContainerConfig",
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
logging_obj: "LiteLLMLoggingObj",
|
||||
extra_headers: Optional[Dict[str, Any]] = None,
|
||||
extra_query: Optional[Dict[str, Any]] = None,
|
||||
timeout: Union[float, httpx.Timeout] = 600,
|
||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""Synchronous request handler."""
|
||||
endpoint_config = _get_endpoint_config(endpoint_name)
|
||||
if not endpoint_config:
|
||||
raise ValueError(f"Unknown endpoint: {endpoint_name}")
|
||||
|
||||
# Get HTTP client
|
||||
if client is None or not isinstance(client, HTTPHandler):
|
||||
http_client = _get_httpx_client(
|
||||
params={"ssl_verify": litellm_params.get("ssl_verify", None)}
|
||||
)
|
||||
else:
|
||||
http_client = client
|
||||
|
||||
# Build request
|
||||
headers = container_provider_config.validate_environment(
|
||||
headers=extra_headers or {},
|
||||
api_key=litellm_params.get("api_key", None),
|
||||
)
|
||||
if extra_headers:
|
||||
headers.update(extra_headers)
|
||||
|
||||
api_base = container_provider_config.get_complete_url(
|
||||
api_base=litellm_params.get("api_base", None),
|
||||
litellm_params=dict(litellm_params),
|
||||
)
|
||||
|
||||
# Build URL with path params
|
||||
path_params = {
|
||||
p: kwargs.get(p, "") for p in endpoint_config.get("path_params", [])
|
||||
}
|
||||
url = _build_url(api_base, endpoint_config["path"], path_params)
|
||||
|
||||
# Build query params
|
||||
query_params = _build_query_params(
|
||||
endpoint_config.get("query_params", []), kwargs
|
||||
)
|
||||
if extra_query:
|
||||
query_params.update(extra_query)
|
||||
|
||||
# Log request
|
||||
logging_obj.pre_call(
|
||||
input="",
|
||||
api_key="",
|
||||
additional_args={
|
||||
"api_base": url,
|
||||
"headers": headers,
|
||||
"params": query_params,
|
||||
},
|
||||
)
|
||||
|
||||
# Make request
|
||||
method = endpoint_config["method"].upper()
|
||||
returns_binary = endpoint_config.get("returns_binary", False)
|
||||
is_multipart = endpoint_config.get("is_multipart", False)
|
||||
|
||||
try:
|
||||
if method == "GET":
|
||||
response = http_client.get(
|
||||
url=url, headers=headers, params=query_params
|
||||
)
|
||||
elif method == "DELETE":
|
||||
response = http_client.delete(
|
||||
url=url, headers=headers, params=query_params
|
||||
)
|
||||
elif method == "POST":
|
||||
if is_multipart and "file" in kwargs:
|
||||
files, headers = _prepare_multipart_file_upload(
|
||||
kwargs["file"], headers
|
||||
)
|
||||
response = http_client.post(
|
||||
url=url, headers=headers, params=query_params, files=files
|
||||
)
|
||||
else:
|
||||
response = http_client.post(
|
||||
url=url, headers=headers, params=query_params
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported HTTP method: {method}")
|
||||
|
||||
# For binary responses, return raw content
|
||||
if returns_binary:
|
||||
return response.content
|
||||
|
||||
# Check for error response
|
||||
response_json = response.json()
|
||||
if "error" in response_json:
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
|
||||
error_msg = response_json.get("error", {}).get(
|
||||
"message", str(response_json)
|
||||
)
|
||||
raise BaseLLMException(
|
||||
status_code=response.status_code,
|
||||
message=error_msg,
|
||||
headers=dict(response.headers),
|
||||
)
|
||||
|
||||
# Parse response
|
||||
response_type = RESPONSE_TYPES.get(endpoint_config["response_type"])
|
||||
if response_type:
|
||||
return response_type(**response_json)
|
||||
return response_json
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
async def _async_handle(
|
||||
self,
|
||||
endpoint_name: str,
|
||||
container_provider_config: "BaseContainerConfig",
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
logging_obj: "LiteLLMLoggingObj",
|
||||
extra_headers: Optional[Dict[str, Any]] = None,
|
||||
extra_query: Optional[Dict[str, Any]] = None,
|
||||
timeout: Union[float, httpx.Timeout] = 600,
|
||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""Asynchronous request handler."""
|
||||
endpoint_config = _get_endpoint_config(endpoint_name)
|
||||
if not endpoint_config:
|
||||
raise ValueError(f"Unknown endpoint: {endpoint_name}")
|
||||
|
||||
# Get HTTP client
|
||||
if client is None or not isinstance(client, AsyncHTTPHandler):
|
||||
http_client = get_async_httpx_client(
|
||||
llm_provider=litellm.LlmProviders.OPENAI,
|
||||
params={"ssl_verify": litellm_params.get("ssl_verify", None)},
|
||||
)
|
||||
else:
|
||||
http_client = client
|
||||
|
||||
# Build request
|
||||
headers = container_provider_config.validate_environment(
|
||||
headers=extra_headers or {},
|
||||
api_key=litellm_params.get("api_key", None),
|
||||
)
|
||||
if extra_headers:
|
||||
headers.update(extra_headers)
|
||||
|
||||
api_base = container_provider_config.get_complete_url(
|
||||
api_base=litellm_params.get("api_base", None),
|
||||
litellm_params=dict(litellm_params),
|
||||
)
|
||||
|
||||
# Build URL with path params
|
||||
path_params = {
|
||||
p: kwargs.get(p, "") for p in endpoint_config.get("path_params", [])
|
||||
}
|
||||
url = _build_url(api_base, endpoint_config["path"], path_params)
|
||||
|
||||
# Build query params
|
||||
query_params = _build_query_params(
|
||||
endpoint_config.get("query_params", []), kwargs
|
||||
)
|
||||
if extra_query:
|
||||
query_params.update(extra_query)
|
||||
|
||||
# Log request
|
||||
logging_obj.pre_call(
|
||||
input="",
|
||||
api_key="",
|
||||
additional_args={
|
||||
"api_base": url,
|
||||
"headers": headers,
|
||||
"params": query_params,
|
||||
},
|
||||
)
|
||||
|
||||
# Make request
|
||||
method = endpoint_config["method"].upper()
|
||||
returns_binary = endpoint_config.get("returns_binary", False)
|
||||
is_multipart = endpoint_config.get("is_multipart", False)
|
||||
|
||||
try:
|
||||
if method == "GET":
|
||||
response = await http_client.get(
|
||||
url=url, headers=headers, params=query_params
|
||||
)
|
||||
elif method == "DELETE":
|
||||
response = await http_client.delete(
|
||||
url=url, headers=headers, params=query_params
|
||||
)
|
||||
elif method == "POST":
|
||||
if is_multipart and "file" in kwargs:
|
||||
files, headers = _prepare_multipart_file_upload(
|
||||
kwargs["file"], headers
|
||||
)
|
||||
response = await http_client.post(
|
||||
url=url, headers=headers, params=query_params, files=files
|
||||
)
|
||||
else:
|
||||
response = await http_client.post(
|
||||
url=url, headers=headers, params=query_params
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported HTTP method: {method}")
|
||||
|
||||
# For binary responses, return raw content
|
||||
if returns_binary:
|
||||
return response.content
|
||||
|
||||
# Check for error response
|
||||
response_json = response.json()
|
||||
if "error" in response_json:
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
|
||||
error_msg = response_json.get("error", {}).get(
|
||||
"message", str(response_json)
|
||||
)
|
||||
raise BaseLLMException(
|
||||
status_code=response.status_code,
|
||||
message=error_msg,
|
||||
headers=dict(response.headers),
|
||||
)
|
||||
|
||||
# Parse response
|
||||
response_type = RESPONSE_TYPES.get(endpoint_config["response_type"])
|
||||
if response_type:
|
||||
return response_type(**response_json)
|
||||
return response_json
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
# Singleton instance
|
||||
generic_container_handler = GenericContainerHandler()
|
||||
Reference in New Issue
Block a user