chore: initial snapshot for gitea/github upload

This commit is contained in:
Your Name
2026-03-26 16:04:46 +08:00
commit a699a1ac98
3497 changed files with 1586237 additions and 0 deletions

View File

@@ -0,0 +1,17 @@
from .anthropic_messages.transformation import BaseAnthropicMessagesConfig
from .audio_transcription.transformation import BaseAudioTranscriptionConfig
from .batches.transformation import BaseBatchesConfig
from .chat.transformation import BaseConfig
from .embedding.transformation import BaseEmbeddingConfig
from .image_edit.transformation import BaseImageEditConfig
from .image_generation.transformation import BaseImageGenerationConfig
# Public API of the base-LLM config package: one abstract config interface per
# supported endpoint type (chat, embeddings, audio, images, batches, messages).
__all__ = [
    "BaseImageGenerationConfig",
    "BaseConfig",
    "BaseAudioTranscriptionConfig",
    "BaseAnthropicMessagesConfig",
    "BaseEmbeddingConfig",
    "BaseImageEditConfig",
    "BaseBatchesConfig",
]

View File

@@ -0,0 +1,122 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Optional, Tuple, Union
import httpx
from litellm.types.llms.anthropic_messages.anthropic_response import (
AnthropicMessagesResponse,
)
from litellm.types.router import GenericLiteLLMParams
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
from litellm.llms.base_llm.chat.transformation import BaseLLMException
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class BaseAnthropicMessagesConfig(ABC):
    """Provider-agnostic interface for serving Anthropic `/v1/messages`-style
    requests: subclasses translate requests/responses between the Anthropic
    messages format and a provider's native API."""

    @abstractmethod
    def validate_anthropic_messages_environment(  # use different name because return type is different from base config's validate_environment
        self,
        headers: dict,
        model: str,
        messages: List[Any],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> Tuple[dict, Optional[str]]:
        """
        OPTIONAL

        Validate the environment for the request

        Returns:
        - headers: dict
        - api_base: Optional[str] - If the provider needs to update the api_base, return it here. Otherwise, return None.
        """
        return headers, api_base

    @abstractmethod
    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        OPTIONAL

        Get the complete url for the request

        Some providers need `model` in `api_base`
        """
        return api_base or ""

    @abstractmethod
    def get_supported_anthropic_messages_params(self, model: str) -> list:
        """Return the Anthropic messages params supported for `model`."""
        pass

    @abstractmethod
    def transform_anthropic_messages_request(
        self,
        model: str,
        messages: List[Dict],
        anthropic_messages_optional_request_params: Dict,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Dict:
        """Translate an Anthropic messages request into the provider's native request body."""
        pass

    @abstractmethod
    def transform_anthropic_messages_response(
        self,
        model: str,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> AnthropicMessagesResponse:
        """Translate the provider's raw HTTP response into an AnthropicMessagesResponse."""
        pass

    def sign_request(
        self,
        headers: dict,
        optional_params: dict,
        request_data: dict,
        api_base: str,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        stream: Optional[bool] = None,
        fake_stream: Optional[bool] = None,
    ) -> Tuple[dict, Optional[bytes]]:
        """
        OPTIONAL

        Sign the request, providers like Bedrock need to sign the request before sending it to the API

        For all other providers, this is a no-op and we just return the headers
        """
        return headers, None

    def get_async_streaming_response_iterator(
        self,
        model: str,
        httpx_response: httpx.Response,
        request_body: dict,
        litellm_logging_obj: LiteLLMLoggingObj,
    ) -> AsyncIterator:
        # Streaming is provider-specific; concrete configs that support it
        # must override this method.
        raise NotImplementedError("Subclasses must implement this method")

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> "BaseLLMException":
        # Default error mapping; providers can override to return
        # provider-specific exception classes.
        from litellm.llms.base_llm.chat.transformation import BaseLLMException

        return BaseLLMException(
            message=error_message, status_code=status_code, headers=headers
        )

View File

@@ -0,0 +1,172 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, List, Optional, Union
import httpx
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.types.llms.openai import (
AllMessageValues,
OpenAIAudioTranscriptionOptionalParams,
)
from litellm.types.utils import FileTypes, ModelResponse, TranscriptionResponse
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
@dataclass
class AudioTranscriptionRequestData:
    """
    Structured data for audio transcription requests.

    Attributes:
        data: The request data (form data for multipart, json data for regular requests)
        files: Optional files dict for multipart form data
        content_type: Optional content type override
    """

    # Request payload: form fields for multipart uploads, or raw/json bytes.
    data: Union[dict, bytes]
    # Files mapping for multipart form data (None for non-multipart requests).
    files: Optional[dict] = None
    # Explicit Content-Type header override, if the provider requires one.
    content_type: Optional[str] = None
class BaseAudioTranscriptionConfig(BaseConfig, ABC):
    """Base interface for provider audio-transcription configs.

    Subclasses translate OpenAI-style transcription calls into provider
    requests and back; the chat-specific hooks inherited from BaseConfig are
    deliberately stubbed out below.
    """

    @abstractmethod
    def get_supported_openai_params(
        self, model: str
    ) -> List[OpenAIAudioTranscriptionOptionalParams]:
        """Return the OpenAI transcription params supported by `model`."""
        pass

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        OPTIONAL

        Build the complete request URL. Providers that embed `model` (or other
        values) in the URL should override this; the default returns
        `api_base` untouched.
        """
        return api_base or ""

    @abstractmethod
    def transform_audio_transcription_request(
        self,
        model: str,
        audio_file: FileTypes,
        optional_params: dict,
        litellm_params: dict,
    ) -> AudioTranscriptionRequestData:
        """Build the provider-specific request payload for `audio_file`."""
        raise NotImplementedError(
            "AudioTranscriptionConfig needs a request transformation for audio transcription models"
        )

    def transform_audio_transcription_response(
        self,
        raw_response: httpx.Response,
    ) -> TranscriptionResponse:
        """Translate the provider's raw HTTP response into a TranscriptionResponse."""
        raise NotImplementedError(
            "AudioTranscriptionConfig does not need a response transformation for audio transcription models"
        )

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        # Chat-style request transformation does not apply to transcription.
        raise NotImplementedError(
            "AudioTranscriptionConfig does not need a request transformation for audio transcription models"
        )

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        # Chat-style response transformation does not apply to transcription.
        raise NotImplementedError(
            "AudioTranscriptionConfig does not need a response transformation for audio transcription models"
        )

    def get_provider_specific_params(
        self,
        model: str,
        optional_params: dict,
        openai_params: List[OpenAIAudioTranscriptionOptionalParams],
    ) -> dict:
        """
        Collect parameters that are not OpenAI compatible so they can be
        forwarded to the provider verbatim.

        eg. if a user passes `diarize=True`, `diarize` is not an OpenAI
        parameter, so it is gathered here and sent through as-is.
        """
        return {
            name: val
            for name, val in optional_params.items()
            # Drop unset values and anything handled through another path.
            if val is not None
            and not self._should_exclude_param(param_name=name, model=model)
        }

    def _should_exclude_param(
        self,
        param_name: str,
        model: str,
    ) -> bool:
        """
        Decide whether `param_name` must NOT be forwarded as a
        provider-specific parameter.

        Args:
            param_name: Parameter name
            model: Model name

        Returns:
            True if the parameter should be excluded
        """
        # `model` is already part of the request URL;
        # OPENAI_TRANSCRIPTION_PARAMS is internal litellm bookkeeping.
        if param_name in ("model", "OPENAI_TRANSCRIPTION_PARAMS"):
            return True
        # OpenAI-standard params have their own translation path.
        return param_name in self.get_supported_openai_params(model)

View File

@@ -0,0 +1,264 @@
import json
from abc import abstractmethod
from typing import List, Optional, Union, cast
import litellm
from litellm.types.utils import (
Choices,
Delta,
GenericStreamingChunk,
ModelResponse,
ModelResponseStream,
StreamingChoices,
)
def convert_model_response_to_streaming(
    model_response: ModelResponse,
) -> ModelResponseStream:
    """
    Convert a ModelResponse to ModelResponseStream.

    This function transforms a standard completion response into a streaming chunk format
    by converting 'message' fields to 'delta' fields.

    Args:
        model_response: The ModelResponse to convert

    Returns:
        ModelResponseStream: A streaming chunk version of the response

    Raises:
        ValueError: If the conversion fails (chained to the underlying error)
    """
    try:
        streaming_choices: List[StreamingChoices] = [
            StreamingChoices(
                index=choice.index,
                # Re-use the full message payload as the delta of a single chunk.
                delta=Delta(
                    **cast(Choices, choice).message.model_dump(),
                ),
                finish_reason=choice.finish_reason,
            )
            for choice in model_response.choices
        ]
        return ModelResponseStream(
            id=model_response.id,
            object="chat.completion.chunk",
            created=model_response.created,
            model=model_response.model,
            choices=streaming_choices,
        )
    except Exception as e:
        # Chain the original exception so the root cause stays visible.
        raise ValueError(
            f"Failed to convert ModelResponse to ModelResponseStream: {model_response}. Error: {e}"
        ) from e
class BaseModelResponseIterator:
    """Shared sync/async iterator that decodes a raw (SSE) byte/str stream into
    parsed streaming chunks via `chunk_parser`, which provider subclasses
    override."""

    def __init__(
        self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
    ):
        self.streaming_response = streaming_response
        self.response_iterator = self.streaming_response
        self.json_mode = json_mode

    def chunk_parser(
        self, chunk: dict
    ) -> Union[GenericStreamingChunk, ModelResponseStream]:
        # Default parser emits an empty, unfinished chunk; subclasses override
        # to translate a decoded JSON payload into a real streaming chunk.
        return GenericStreamingChunk(
            text="",
            is_finished=False,
            finish_reason="",
            usage=None,
            index=0,
            tool_use=None,
        )

    # Sync iterator
    def __iter__(self):
        return self

    @staticmethod
    def _string_to_dict_parser(str_line: str) -> Optional[dict]:
        # Strip the SSE "data:" framing, then JSON-decode. Returns None when
        # the payload is absent or not valid JSON (e.g. the "[DONE]" sentinel).
        stripped_json_chunk: Optional[dict] = None
        stripped_chunk = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(
            str_line
        )
        try:
            if stripped_chunk is not None:
                stripped_json_chunk = json.loads(stripped_chunk)
            else:
                stripped_json_chunk = None
        except json.JSONDecodeError:
            stripped_json_chunk = None
        return stripped_json_chunk

    def _handle_string_chunk(
        self, str_line: str
    ) -> Union[GenericStreamingChunk, ModelResponseStream]:
        # chunk is a str at this point
        stripped_json_chunk = BaseModelResponseIterator._string_to_dict_parser(
            str_line=str_line
        )
        if "[DONE]" in str_line:
            # Terminal sentinel: emit a finished chunk with finish_reason "stop".
            return GenericStreamingChunk(
                text="",
                is_finished=True,
                finish_reason="stop",
                usage=None,
                index=0,
                tool_use=None,
            )
        elif stripped_json_chunk:
            return self.chunk_parser(chunk=stripped_json_chunk)
        else:
            # Undecodable/empty payload: emit an empty, unfinished chunk.
            return GenericStreamingChunk(
                text="",
                is_finished=False,
                finish_reason="",
                usage=None,
                index=0,
                tool_use=None,
            )

    def __next__(self):
        while True:
            try:
                chunk = self.response_iterator.__next__()
            except StopIteration:
                raise StopIteration
            except ValueError as e:
                raise RuntimeError(f"Error receiving chunk from stream: {e}")
            try:
                str_line = chunk
                if isinstance(chunk, bytes):  # Handle binary data
                    str_line = chunk.decode("utf-8")  # Convert bytes to string
                    index = str_line.find("data:")
                    if index != -1:
                        str_line = str_line[index:]
                # Skip empty lines (common in SSE streams between events).
                # Only apply to str chunks — non-string objects (e.g. Pydantic
                # BaseModel events from the Responses API) must pass through.
                if isinstance(str_line, str) and (
                    not str_line or not str_line.strip()
                ):
                    continue
                # chunk is a str at this point
                return self._handle_string_chunk(str_line=str_line)
            except StopIteration:
                raise StopIteration
            except ValueError as e:
                raise RuntimeError(
                    f"Error parsing chunk: {e},\nReceived chunk: {chunk}"
                )

    # Async iterator
    def __aiter__(self):
        self.async_response_iterator = self.streaming_response.__aiter__()
        return self

    async def __anext__(self):
        while True:
            try:
                chunk = await self.async_response_iterator.__anext__()
            except StopAsyncIteration:
                raise StopAsyncIteration
            except ValueError as e:
                raise RuntimeError(f"Error receiving chunk from stream: {e}")
            try:
                str_line = chunk
                if isinstance(chunk, bytes):  # Handle binary data
                    str_line = chunk.decode("utf-8")  # Convert bytes to string
                    index = str_line.find("data:")
                    if index != -1:
                        str_line = str_line[index:]
                # Skip empty lines (common in SSE streams between events).
                # Only apply to str chunks — non-string objects (e.g. Pydantic
                # BaseModel events from the Responses API) must pass through.
                if isinstance(str_line, str) and (
                    not str_line or not str_line.strip()
                ):
                    continue
                # chunk is a str at this point
                chunk = self._handle_string_chunk(str_line=str_line)
                return chunk
            except StopAsyncIteration:
                raise StopAsyncIteration
            except ValueError as e:
                raise RuntimeError(
                    f"Error parsing chunk: {e},\nReceived chunk: {chunk}"
                )
class MockResponseIterator:  # for returning ai21 streaming responses
    """Wraps a complete ModelResponse and replays it as a one-chunk stream
    (both sync and async iteration protocols are supported)."""

    def __init__(
        self, model_response: ModelResponse, json_mode: Optional[bool] = False
    ):
        self.model_response = model_response
        self.json_mode = json_mode
        self.is_done = False

    def _chunk_parser(self, chunk_data: ModelResponse) -> ModelResponseStream:
        # Reuse the shared conversion helper for message -> delta translation.
        return convert_model_response_to_streaming(chunk_data)

    # Sync iterator
    def __iter__(self):
        return self

    def __next__(self):
        if not self.is_done:
            self.is_done = True
            return self._chunk_parser(self.model_response)
        raise StopIteration

    # Async iterator
    def __aiter__(self):
        return self

    async def __anext__(self):
        if not self.is_done:
            self.is_done = True
            return self._chunk_parser(self.model_response)
        raise StopAsyncIteration
class FakeStreamResponseIterator:
    """Replays a stored (non-streaming) response as a single fake stream chunk.

    NOTE: `chunk_parser` is marked abstract but the class does not inherit
    from ABC, so instantiation is not blocked — subclasses are still expected
    to override it.
    """

    def __init__(self, model_response, json_mode: Optional[bool] = False):
        self.model_response = model_response
        self.json_mode = json_mode
        self.is_done = False

    @abstractmethod
    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
        # Subclasses convert the stored response into a streaming chunk.
        pass

    # Sync iterator
    def __iter__(self):
        return self

    def __next__(self):
        if not self.is_done:
            self.is_done = True
            return self.chunk_parser(self.model_response)
        raise StopIteration

    # Async iterator
    def __aiter__(self):
        return self

    async def __anext__(self):
        if not self.is_done:
            self.is_done = True
            return self.chunk_parser(self.model_response)
        raise StopAsyncIteration

View File

@@ -0,0 +1,227 @@
"""
Utility functions for base LLM classes.
"""
import copy
import json
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Type, Union
from openai.lib import _parsing, _pydantic
from pydantic import BaseModel
from litellm._logging import verbose_logger
from litellm.types.llms.openai import AllMessageValues, ChatCompletionToolCallChunk
from litellm.types.utils import Message, ProviderSpecificModelInfo, TokenCountResponse
class BaseTokenCounter(ABC):
    """Interface for provider-side token counting (e.g. a provider's
    count-tokens endpoint) used in place of local tokenization."""

    @abstractmethod
    async def count_tokens(
        self,
        model_to_use: str,
        messages: Optional[List[Dict[str, Any]]],
        contents: Optional[List[Dict[str, Any]]],
        deployment: Optional[Dict[str, Any]] = None,
        request_model: str = "",
        tools: Optional[List[Dict[str, Any]]] = None,
        system: Optional[Any] = None,
    ) -> Optional[TokenCountResponse]:
        """Return a TokenCountResponse for the given request, or None when no
        count could be produced."""
        pass

    @abstractmethod
    def should_use_token_counting_api(
        self,
        custom_llm_provider: Optional[str] = None,
    ) -> bool:
        """
        Returns True if we should use this API for token counting for the selected `custom_llm_provider`
        """
        return False
class BaseLLMModelInfo(ABC):
    """Interface describing provider-level model metadata: available models,
    credential/base-URL resolution, and base-model name normalization."""

    def get_provider_info(
        self,
        model: str,
    ) -> Optional[ProviderSpecificModelInfo]:
        """
        Default values all models of this provider support.
        """
        return None

    @abstractmethod
    def get_models(
        self, api_key: Optional[str] = None, api_base: Optional[str] = None
    ) -> List[str]:
        """
        Returns a list of models supported by this provider.
        """
        return []

    @staticmethod
    @abstractmethod
    def get_api_key(api_key: Optional[str] = None) -> Optional[str]:
        """Return the API key to use for this provider."""
        pass

    @staticmethod
    @abstractmethod
    def get_api_base(
        api_base: Optional[str] = None,
    ) -> Optional[str]:
        """Return the API base URL to use for this provider."""
        pass

    @abstractmethod
    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """Validate the environment for the request and return the headers to send."""
        pass

    @staticmethod
    @abstractmethod
    def get_base_model(model: str) -> Optional[str]:
        """
        Returns the base model name from the given model name.

        Some providers like bedrock - can receive model=`invoke/anthropic.claude-3-opus-20240229-v1:0` or `converse/anthropic.claude-3-opus-20240229-v1:0`
        This function will return `anthropic.claude-3-opus-20240229-v1:0`
        """
        pass

    def get_token_counter(self) -> Optional[BaseTokenCounter]:
        """
        Factory method to create a token counter for this provider.

        Returns:
            Optional TokenCounterInterface implementation for this provider,
            or None if token counting is not supported.
        """
        return None
def _convert_tool_response_to_message(
    tool_calls: List[ChatCompletionToolCallChunk],
) -> Optional[Message]:
    """
    In JSON mode, Anthropic API returns JSON schema as a tool call, we need to convert it to a message to follow the OpenAI format

    Returns None when there is no tool call / no arguments to convert.
    """
    ## HANDLE JSON MODE - anthropic returns single function call
    if not tool_calls:
        # Defensive: avoid IndexError when the provider returned no tool calls.
        return None
    json_mode_content_str: Optional[str] = tool_calls[0]["function"].get("arguments")
    if json_mode_content_str is None:
        return None
    try:
        args = json.loads(json_mode_content_str)
    except json.JSONDecodeError:
        # json decode error does occur, return the original tool response str
        return Message(content=json_mode_content_str)
    if isinstance(args, dict) and (values := args.get("values")) is not None:
        return Message(content=json.dumps(values))
    # a lot of the times the `values` key is not present in the tool response
    # relevant issue: https://github.com/BerriAI/litellm/issues/6741
    return Message(content=json.dumps(args))
def _dict_to_response_format_helper(
response_format: dict, ref_template: Optional[str] = None
) -> dict:
if ref_template is not None and response_format.get("type") == "json_schema":
# Deep copy to avoid modifying original
modified_format = copy.deepcopy(response_format)
schema = modified_format["json_schema"]["schema"]
# Update all $ref values in the schema
def update_refs(schema):
stack = [(schema, [])]
visited = set()
while stack:
obj, path = stack.pop()
obj_id = id(obj)
if obj_id in visited:
continue
visited.add(obj_id)
if isinstance(obj, dict):
if "$ref" in obj:
ref_path = obj["$ref"]
model_name = ref_path.split("/")[-1]
obj["$ref"] = ref_template.format(model=model_name)
for k, v in obj.items():
if isinstance(v, (dict, list)):
stack.append((v, path + [k]))
elif isinstance(obj, list):
for i, item in enumerate(obj):
if isinstance(item, (dict, list)):
stack.append((item, path + [i]))
update_refs(schema)
return modified_format
return response_format
def type_to_response_format_param(
    response_format: Optional[Union[Type[BaseModel], dict]],
    ref_template: Optional[str] = None,
) -> Optional[dict]:
    """
    Re-implementation of openai's 'type_to_response_format_param' function

    Used for converting pydantic object to api schema.
    """
    if response_format is None:
        return None
    if isinstance(response_format, dict):
        return _dict_to_response_format_helper(response_format, ref_template)
    # `is_basemodel_type` is a TypeGuard; type checkers don't narrow its
    # negation, but at runtime only a `type` can reach this point.
    if not _parsing._completions.is_basemodel_type(response_format):
        raise TypeError(f"Unsupported response_format type - {response_format}")
    schema = (
        response_format.model_json_schema(ref_template=ref_template)
        if ref_template is not None
        else _pydantic.to_strict_json_schema(response_format)
    )
    return {
        "type": "json_schema",
        "json_schema": {
            "schema": schema,
            "name": response_format.__name__,
            "strict": True,
        },
    }
def map_developer_role_to_system_role(
    messages: List[AllMessageValues],
) -> List[AllMessageValues]:
    """
    Translate `developer` role to `system` role for non-OpenAI providers.
    """
    translated: List[AllMessageValues] = []
    for message in messages:
        if message["role"] != "developer":
            translated.append(message)
            continue
        verbose_logger.debug(
            "Translating developer role to system role for non-OpenAI providers."
        )  # ensure user knows what's happening with their input.
        translated.append({"role": "system", "content": message["content"]})
    return translated

View File

@@ -0,0 +1,218 @@
import types
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import httpx
from httpx import Headers
from litellm.types.llms.openai import (
AllMessageValues,
CreateBatchRequest,
)
from litellm.types.utils import LiteLLMBatch, LlmProviders
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
from ..chat.transformation import BaseLLMException as _BaseLLMException
LiteLLMLoggingObj = _LiteLLMLoggingObj
BaseLLMException = _BaseLLMException
else:
LiteLLMLoggingObj = Any
BaseLLMException = Any
class BaseBatchesConfig(ABC):
    """
    Abstract base class for batch processing configurations across different LLM providers.

    This class defines the interface that all provider-specific batch configurations
    must implement to work with LiteLLM's unified batch processing system.
    """

    def __init__(self):
        # Stateless base; concrete configs may define class-level attributes
        # that get_config() picks up.
        pass

    @property
    @abstractmethod
    def custom_llm_provider(self) -> LlmProviders:
        """Return the LLM provider type for this configuration."""
        pass

    @classmethod
    def get_config(cls):
        """Get configuration dictionary for this class."""
        # Keep only plain, set (non-None) class attributes: skip dunders,
        # ABC internals, and anything callable-like.
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not k.startswith("_abc")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    @abstractmethod
    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """
        Validate and prepare environment-specific headers and parameters.

        Args:
            headers: HTTP headers dictionary
            model: Model name
            messages: List of messages
            optional_params: Optional parameters
            litellm_params: LiteLLM parameters
            api_key: API key
            api_base: API base URL

        Returns:
            Updated headers dictionary
        """
        pass

    @abstractmethod
    def get_complete_batch_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: Dict,
        litellm_params: Dict,
        data: CreateBatchRequest,
    ) -> str:
        """
        Get the complete URL for batch creation request.

        Args:
            api_base: Base API URL
            api_key: API key
            model: Model name
            optional_params: Optional parameters
            litellm_params: LiteLLM parameters
            data: Batch creation request data

        Returns:
            Complete URL for the batch request
        """
        pass

    @abstractmethod
    def transform_create_batch_request(
        self,
        model: str,
        create_batch_data: CreateBatchRequest,
        optional_params: dict,
        litellm_params: dict,
    ) -> Union[bytes, str, Dict[str, Any]]:
        """
        Transform the batch creation request to provider-specific format.

        Args:
            model: Model name
            create_batch_data: Batch creation request data
            optional_params: Optional parameters
            litellm_params: LiteLLM parameters

        Returns:
            Transformed request data
        """
        pass

    @abstractmethod
    def transform_create_batch_response(
        self,
        model: Optional[str],
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        litellm_params: dict,
    ) -> LiteLLMBatch:
        """
        Transform provider-specific batch response to LiteLLM format.

        Args:
            model: Model name
            raw_response: Raw HTTP response
            logging_obj: Logging object
            litellm_params: LiteLLM parameters

        Returns:
            LiteLLM batch object
        """
        pass

    @abstractmethod
    def transform_retrieve_batch_request(
        self,
        batch_id: str,
        optional_params: dict,
        litellm_params: dict,
    ) -> Union[bytes, str, Dict[str, Any]]:
        """
        Transform the batch retrieval request to provider-specific format.

        Args:
            batch_id: Batch ID to retrieve
            optional_params: Optional parameters
            litellm_params: LiteLLM parameters

        Returns:
            Transformed request data
        """
        pass

    @abstractmethod
    def transform_retrieve_batch_response(
        self,
        model: Optional[str],
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        litellm_params: dict,
    ) -> LiteLLMBatch:
        """
        Transform provider-specific batch retrieval response to LiteLLM format.

        Args:
            model: Model name
            raw_response: Raw HTTP response
            logging_obj: Logging object
            litellm_params: LiteLLM parameters

        Returns:
            LiteLLM batch object
        """
        pass

    @abstractmethod
    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[Dict, Headers]
    ) -> "BaseLLMException":
        """
        Get the appropriate error class for this provider.

        Args:
            error_message: Error message
            status_code: HTTP status code
            headers: Response headers

        Returns:
            Provider-specific exception class
        """
        pass

View File

@@ -0,0 +1,55 @@
"""
Bridge for transforming API requests to another API requests
"""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union
if TYPE_CHECKING:
from pydantic import BaseModel
from litellm import LiteLLMLoggingObj, ModelResponse
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.types.llms.openai import AllMessageValues
class CompletionTransformationBridge(ABC):
    """Bridge interface for translating /chat/completions requests into another
    API's request/response shape and back."""

    @abstractmethod
    def transform_request(
        self,
        model: str,
        messages: List["AllMessageValues"],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
        litellm_logging_obj: "LiteLLMLoggingObj",
    ) -> dict:
        """Transform /chat/completions api request to another request"""
        pass

    @abstractmethod
    def transform_response(
        self,
        model: str,
        raw_response: "BaseModel",  # the response from the other API
        model_response: "ModelResponse",
        logging_obj: "LiteLLMLoggingObj",
        request_data: dict,
        messages: List["AllMessageValues"],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> "ModelResponse":
        """Transform another response to /chat/completions api response"""
        pass

    @abstractmethod
    def get_model_response_iterator(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], "ModelResponse"],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ) -> "BaseModelResponseIterator":
        """Return an iterator adapting the other API's stream into litellm chunks."""
        pass

View File

@@ -0,0 +1,466 @@
"""
Common base config for all LLM providers
"""
import types
from abc import ABC, abstractmethod
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Iterator,
List,
Optional,
Tuple,
Type,
Union,
cast,
)
import httpx
from pydantic import BaseModel
from litellm.constants import DEFAULT_MAX_TOKENS, RESPONSE_FORMAT_TOOL_NAME
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.llms.openai import (
AllMessageValues,
ChatCompletionToolChoiceFunctionParam,
ChatCompletionToolChoiceObjectParam,
ChatCompletionToolParam,
ChatCompletionToolParamFunctionChunk,
)
if TYPE_CHECKING:
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
from litellm.types.utils import ModelResponse
from ..base_utils import (
map_developer_role_to_system_role,
type_to_response_format_param,
)
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class BaseLLMException(Exception):
    """Base exception for provider errors; carries the HTTP status, headers,
    and originating request/response so callers can map it to a typed error."""

    def __init__(
        self,
        status_code: int,
        message: str,
        headers: Optional[Union[dict, httpx.Headers]] = None,
        request: Optional[httpx.Request] = None,
        response: Optional[httpx.Response] = None,
        body: Optional[dict] = None,
    ):
        self.status_code = status_code
        self.message: str = message
        self.headers = headers
        # Fall back to placeholder request/response objects so downstream
        # error mappers can always rely on these attributes existing.
        self.request = (
            request
            if request
            else httpx.Request(method="POST", url="https://docs.litellm.ai/docs")
        )
        self.response = (
            response
            if response
            else httpx.Response(status_code=status_code, request=self.request)
        )
        self.body = body
        # Initialize the base Exception with the human-readable message.
        super().__init__(self.message)
class BaseConfig(ABC):
def __init__(self):
    # Stateless base; concrete configs may define class-level attributes
    # that get_config() picks up.
    pass
@classmethod
def get_config(cls):
    """Return the class-level configuration as a dict.

    Filters out dunders/ABC internals, callables (functions, class/static
    methods, properties, mocks), and unset (None) values so only plain
    config attributes remain.
    """
    return {
        k: v
        for k, v in cls.__dict__.items()
        if not k.startswith("__")
        and not k.startswith("_abc")
        and not k.startswith("_is_base_class")
        and not isinstance(
            v,
            (
                types.FunctionType,
                types.BuiltinFunctionType,
                classmethod,
                staticmethod,
                property,
            ),
        )
        and v is not None
        and not callable(v)  # Filter out any callable objects including mocks
    }
def get_json_schema_from_pydantic_object(
    self, response_format: Optional[Union[Type[BaseModel], dict]]
) -> Optional[dict]:
    # Convert a pydantic model (or pass through a dict) into an OpenAI-style
    # json_schema response_format payload.
    return type_to_response_format_param(response_format=response_format)
def is_thinking_enabled(self, non_default_params: dict) -> bool:
    """Return True when extended thinking is requested — either via an
    explicit `thinking: {"type": "enabled"}` block or a `reasoning_effort`
    value."""
    thinking_cfg = non_default_params.get("thinking", {})
    if thinking_cfg.get("type") == "enabled":
        return True
    return non_default_params.get("reasoning_effort") is not None
def is_max_tokens_in_request(self, non_default_params: dict) -> bool:
    """
    OpenAI spec allows max_tokens or max_completion_tokens to be specified.
    """
    return any(
        key in non_default_params
        for key in ("max_tokens", "max_completion_tokens")
    )
def update_optional_params_with_thinking_tokens(
    self, non_default_params: dict, optional_params: dict
):
    """
    Handles scenario where max tokens is not specified. For anthropic models (anthropic api/bedrock/vertex ai), this requires having the max tokens being set and being greater than the thinking token budget.

    Checks 'optional_params' for 'thinking' and 'non_default_params' for 'max_tokens'

    if 'thinking' is enabled and 'max_tokens' or 'max_completion_tokens' is not specified, set 'max_tokens' to the thinking token budget + DEFAULT_MAX_TOKENS
    """
    is_thinking_enabled = self.is_thinking_enabled(optional_params)
    if not is_thinking_enabled or self.is_max_tokens_in_request(non_default_params):
        return
    # BUGFIX: thinking can be "enabled" via `reasoning_effort` alone, in which
    # case optional_params has no "thinking" key — use .get() so that path
    # does not raise a KeyError.
    thinking_token_budget = cast(dict, optional_params.get("thinking") or {}).get(
        "budget_tokens", None
    )
    if thinking_token_budget is not None:
        optional_params["max_tokens"] = (
            thinking_token_budget + DEFAULT_MAX_TOKENS
        )
def should_fake_stream(
    self,
    model: Optional[str],
    stream: Optional[bool],
    custom_llm_provider: Optional[str] = None,
) -> bool:
    """
    Returns True if the model/provider should fake stream
    """
    # Default: no fake streaming; overridden where a provider needs it.
    return False
def _add_tools_to_optional_params(self, optional_params: dict, tools: List) -> dict:
"""
Helper util to add tools to optional_params.
"""
if "tools" not in optional_params:
optional_params["tools"] = tools
else:
optional_params["tools"] = [
*optional_params["tools"],
*tools,
]
return optional_params
def translate_developer_role_to_system_role(
    self,
    messages: List[AllMessageValues],
) -> List[AllMessageValues]:
    """
    Translate `developer` role to `system` role for non-OpenAI providers.

    Overridden by OpenAI/Azure
    """
    # Delegates to the shared module-level helper.
    return map_developer_role_to_system_role(messages=messages)
def should_retry_llm_api_inside_llm_translation_on_http_error(
    self, e: httpx.HTTPStatusError, litellm_params: dict
) -> bool:
    """
    Returns True if the model/provider should retry the LLM API on UnprocessableEntityError

    Overridden by azure ai - where different models support different parameters
    """
    # Default: never retry inside the translation layer.
    return False
def transform_request_on_unprocessable_entity_error(
    self, e: httpx.HTTPStatusError, request_data: dict
) -> dict:
    """
    Transform the request data on UnprocessableEntityError

    Default implementation returns the request unchanged.
    """
    return request_data
@property
def max_retry_on_unprocessable_entity_error(self) -> int:
    """
    Returns the max retry count for UnprocessableEntityError

    Used if `should_retry_llm_api_inside_llm_translation_on_http_error` is True
    """
    # Default: no retries.
    return 0
@abstractmethod
def get_supported_openai_params(self, model: str) -> list:
    """Return the OpenAI params this provider supports for `model`."""
    pass
def _add_response_format_to_tools(
self,
optional_params: dict,
value: dict,
is_response_format_supported: bool,
enforce_tool_choice: bool = True,
) -> dict:
"""
Follow similar approach to anthropic - translate to a single tool call.
When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
- You usually want to provide a single tool
- You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
- Remember that the model will pass the input to the tool, so the name of the tool and description should be from the models perspective.
Add response format to tools
This is used to translate response_format to a tool call, for models/APIs that don't support response_format directly.
"""
json_schema: Optional[dict] = None
if "response_schema" in value:
json_schema = value["response_schema"]
elif "json_schema" in value:
json_schema = value["json_schema"]["schema"]
if json_schema and not is_response_format_supported:
_tool_choice = ChatCompletionToolChoiceObjectParam(
type="function",
function=ChatCompletionToolChoiceFunctionParam(
name=RESPONSE_FORMAT_TOOL_NAME
),
)
_tool = ChatCompletionToolParam(
type="function",
function=ChatCompletionToolParamFunctionChunk(
name=RESPONSE_FORMAT_TOOL_NAME, parameters=json_schema
),
)
optional_params.setdefault("tools", [])
optional_params["tools"].append(_tool)
if enforce_tool_choice:
optional_params["tool_choice"] = _tool_choice
optional_params["json_mode"] = True
elif is_response_format_supported:
optional_params["response_format"] = value
return optional_params
    @abstractmethod
    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        Map OpenAI-style parameters onto this provider's parameter names.

        Args:
            non_default_params: OpenAI params the caller explicitly set.
            optional_params: provider-specific params accumulated so far.
            model: target model name.
            drop_params: presumably, when True unsupported params are dropped
                instead of raising -- confirm against call sites.

        Returns:
            The provider-ready optional-params dict.
        """
        pass
    @abstractmethod
    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """
        Validate credentials/config for the request and return the headers
        to send (implementations typically add auth headers and raise when
        required credentials are missing).
        """
        pass
def sign_request(
self,
headers: dict,
optional_params: dict,
request_data: dict,
api_base: str,
api_key: Optional[str] = None,
model: Optional[str] = None,
stream: Optional[bool] = None,
fake_stream: Optional[bool] = None,
) -> Tuple[dict, Optional[bytes]]:
"""
Some providers like Bedrock require signing the request. The sign request funtion needs access to `request_data` and `complete_url`
Args:
headers: dict
optional_params: dict
request_data: dict - the request body being sent in http request
api_base: str - the complete url being sent in http request
Returns:
dict - the signed headers
Update the headers with the signed headers in this function. The return values will be sent as headers in the http request.
"""
return headers, None
def get_complete_url(
self,
api_base: Optional[str],
api_key: Optional[str],
model: str,
optional_params: dict,
litellm_params: dict,
stream: Optional[bool] = None,
) -> str:
"""
OPTIONAL
Get the complete url for the request
Some providers need `model` in `api_base`
"""
if api_base is None:
raise ValueError("api_base is required")
return api_base
    @abstractmethod
    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """Build the provider-specific HTTP request body for a chat completion."""
        pass
async def async_transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
"""
Override to allow for http requests on async calls - e.g. converting url to base64
Currently only used by openai.py
"""
return self.transform_request(
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
headers=headers,
)
    @abstractmethod
    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: "ModelResponse",
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> "ModelResponse":
        """
        Parse the provider's raw HTTP response into `model_response`.

        Args:
            raw_response: the provider's HTTP response.
            model_response: ModelResponse instance to populate and return.
            encoding: tokenizer/encoding object -- presumably used for usage
                accounting; confirm per provider implementation.
            json_mode: whether JSON mode (tool-based response_format) was used.
        """
        pass
    @abstractmethod
    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """Build the provider-specific exception for an HTTP error response."""
        pass
def get_model_response_iterator(
self,
streaming_response: Union[Iterator[str], AsyncIterator[str], "ModelResponse"],
sync_stream: bool,
json_mode: Optional[bool] = False,
) -> Any:
pass
    async def get_async_custom_stream_wrapper(
        self,
        model: str,
        custom_llm_provider: str,
        logging_obj: LiteLLMLoggingObj,
        api_base: str,
        headers: dict,
        data: dict,
        messages: list,
        client: Optional[AsyncHTTPHandler] = None,
        json_mode: Optional[bool] = None,
        signed_json_body: Optional[bytes] = None,
    ) -> "CustomStreamWrapper":
        """
        Build an async CustomStreamWrapper for providers with bespoke
        streaming. Presumably only invoked when `has_custom_stream_wrapper`
        is True -- confirm at call sites. Not implemented in the base class.
        """
        raise NotImplementedError
    def get_sync_custom_stream_wrapper(
        self,
        model: str,
        custom_llm_provider: str,
        logging_obj: LiteLLMLoggingObj,
        api_base: str,
        headers: dict,
        data: dict,
        messages: list,
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
        json_mode: Optional[bool] = None,
        signed_json_body: Optional[bytes] = None,
    ) -> "CustomStreamWrapper":
        """
        Sync counterpart of `get_async_custom_stream_wrapper`. Presumably only
        invoked when `has_custom_stream_wrapper` is True -- confirm at call
        sites. Not implemented in the base class.
        """
        raise NotImplementedError
@property
def custom_llm_provider(self) -> Optional[str]:
return None
@property
def has_custom_stream_wrapper(self) -> bool:
return False
@property
def supports_stream_param_in_request_body(self) -> bool:
"""
Some providers like Bedrock invoke do not support the stream parameter in the request body.
By default, this is true for almost all providers.
"""
return True
def post_stream_processing(self, stream: Any) -> Any:
"""Hook for providers to post-process streaming responses. Default: pass-through."""
return stream
def calculate_additional_costs(
self, model: str, prompt_tokens: int, completion_tokens: int
) -> Optional[dict]:
"""
Calculate any additional costs beyond standard token costs.
This is used for provider-specific infrastructure costs, routing fees, etc.
Args:
model: The model name
prompt_tokens: Number of prompt tokens
completion_tokens: Number of completion tokens
Returns:
Optional dictionary with cost names and amounts, e.g.:
{"Infrastructure Fee": 0.001, "Routing Cost": 0.0005}
Returns None if no additional costs apply.
"""
return None

View File

@@ -0,0 +1,75 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, List, Optional, Union
import httpx
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.types.llms.openai import AllMessageValues, OpenAITextCompletionUserMessage
from litellm.types.utils import ModelResponse
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class BaseTextCompletionConfig(BaseConfig, ABC):
    """Base config for providers implementing the legacy /completions API.

    Subclasses implement `transform_text_completion_request`; the chat-style
    `transform_request` / `transform_response` hooks inherited from BaseConfig
    are unused on this path and raise NotImplementedError.
    """

    @abstractmethod
    def transform_text_completion_request(
        self,
        model: str,
        messages: Union[List[AllMessageValues], List[OpenAITextCompletionUserMessage]],
        optional_params: dict,
        headers: dict,
    ) -> dict:
        """Build the provider-specific text-completion request body."""
        return {}

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        OPTIONAL

        Get the complete url for the request (some providers need `model` in
        `api_base`). Returns an empty string when no api_base is configured.
        """
        return api_base or ""

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        # BUG FIX: the original message was copy-pasted from the audio
        # transcription config and referenced "AudioTranscriptionConfig".
        raise NotImplementedError(
            "TextCompletionConfig does not need a request transformation for text completion models"
        )

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        # BUG FIX: same copy-paste -- message now names this config correctly.
        raise NotImplementedError(
            "TextCompletionConfig does not need a response transformation for text completion models"
        )

View File

@@ -0,0 +1,268 @@
from __future__ import annotations
import types
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any
import httpx
from litellm.types.containers.main import ContainerCreateOptionalRequestParams
from litellm.types.router import GenericLiteLLMParams
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
from litellm.types.containers.main import (
ContainerFileListResponse as _ContainerFileListResponse,
)
from litellm.types.containers.main import (
ContainerListResponse as _ContainerListResponse,
)
from litellm.types.containers.main import ContainerObject as _ContainerObject
from litellm.types.containers.main import (
DeleteContainerResult as _DeleteContainerResult,
)
from ..chat.transformation import BaseLLMException as _BaseLLMException
LiteLLMLoggingObj = _LiteLLMLoggingObj
BaseLLMException = _BaseLLMException
ContainerObject = _ContainerObject
DeleteContainerResult = _DeleteContainerResult
ContainerListResponse = _ContainerListResponse
ContainerFileListResponse = _ContainerFileListResponse
else:
LiteLLMLoggingObj = Any
BaseLLMException = Any
ContainerObject = Any
DeleteContainerResult = Any
ContainerListResponse = Any
ContainerFileListResponse = Any
class BaseContainerConfig(ABC):
    """Abstract transformation interface for the Containers API.

    Subclasses map LiteLLM's container CRUD / file operations onto a
    provider's HTTP API: each ``transform_*_request`` builds the URL and
    params/body, and each ``transform_*_response`` parses the raw HTTP
    response into the shared typed objects.
    """

    def __init__(self):
        pass

    @classmethod
    def get_config(cls):
        # Collect class-level config attributes, skipping dunders, ABC
        # machinery, callables, and unset (None) values.
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not k.startswith("_abc")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    @abstractmethod
    def get_supported_openai_params(self) -> list:
        pass

    @abstractmethod
    def map_openai_params(
        self,
        container_create_optional_params: ContainerCreateOptionalRequestParams,
        drop_params: bool,
    ) -> dict:
        pass

    @abstractmethod
    def validate_environment(
        self,
        headers: dict,
        api_key: str | None = None,
    ) -> dict:
        return {}

    @abstractmethod
    def get_complete_url(
        self,
        api_base: str | None,
        litellm_params: dict,
    ) -> str:
        """Get the complete url for the request.

        OPTIONAL - Some providers need `model` in `api_base`.
        """
        if api_base is None:
            msg = "api_base is required"
            raise ValueError(msg)
        return api_base

    @abstractmethod
    def transform_container_create_request(
        self,
        name: str,
        container_create_optional_request_params: dict,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> dict:
        """Transform the container creation request.

        Returns:
            dict: Request data for container creation.
        """
        ...

    @abstractmethod
    def transform_container_create_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ContainerObject:
        """Transform the container creation response."""
        ...

    @abstractmethod
    def transform_container_list_request(
        self,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
        after: str | None = None,
        limit: int | None = None,
        order: str | None = None,
        extra_query: dict[str, Any] | None = None,
    ) -> tuple[str, dict]:
        """Transform the container list request into a URL and params.

        Returns:
            tuple[str, dict]: (url, params) for the container list request.
        """
        ...

    @abstractmethod
    def transform_container_list_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ContainerListResponse:
        """Transform the container list response."""
        ...

    @abstractmethod
    def transform_container_retrieve_request(
        self,
        container_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> tuple[str, dict]:
        """Transform the container retrieve request into a URL and data/params.

        Returns:
            tuple[str, dict]: (url, params) for the container retrieve request.
        """
        ...

    @abstractmethod
    def transform_container_retrieve_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ContainerObject:
        """Transform the container retrieve response."""
        ...

    @abstractmethod
    def transform_container_delete_request(
        self,
        container_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> tuple[str, dict]:
        """Transform the container delete request into a URL and data.

        Returns:
            tuple[str, dict]: (url, data) for the container delete request.
        """
        ...

    @abstractmethod
    def transform_container_delete_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> DeleteContainerResult:
        """Transform the container delete response."""
        ...

    @abstractmethod
    def transform_container_file_list_request(
        self,
        container_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
        after: str | None = None,
        limit: int | None = None,
        order: str | None = None,
        extra_query: dict[str, Any] | None = None,
    ) -> tuple[str, dict]:
        """Transform the container file list request into a URL and params.

        Returns:
            tuple[str, dict]: (url, params) for the container file list request.
        """
        ...

    @abstractmethod
    def transform_container_file_list_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ContainerFileListResponse:
        """Transform the container file list response."""
        ...

    @abstractmethod
    def transform_container_file_content_request(
        self,
        container_id: str,
        file_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> tuple[str, dict]:
        """Transform the container file content request into a URL and params.

        Returns:
            tuple[str, dict]: (url, params) for the container file content request.
        """
        ...

    @abstractmethod
    def transform_container_file_content_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> bytes:
        """Transform the container file content response.

        Returns:
            bytes: The raw file content.
        """
        ...

    def get_error_class(
        self,
        error_message: str,
        status_code: int,
        headers: dict | httpx.Headers,
    ) -> BaseLLMException:
        """Build the provider-specific exception for an HTTP error response."""
        from ..chat.transformation import BaseLLMException

        # BUG FIX: this previously `raise`d the exception, contradicting the
        # annotated return type and every sibling config (e.g. the evals
        # config), which *return* the instance for the caller to raise.
        return BaseLLMException(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )

View File

@@ -0,0 +1,89 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, List, Optional
import httpx
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.types.llms.openai import AllEmbeddingInputValues, AllMessageValues
from litellm.types.utils import EmbeddingResponse, ModelResponse
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class BaseEmbeddingConfig(BaseConfig, ABC):
    """Base config for provider embedding endpoints.

    Subclasses implement the embedding-specific request/response transforms;
    the chat-oriented `transform_request` / `transform_response` hooks
    inherited from BaseConfig are unused on this path and raise
    NotImplementedError.
    """

    @abstractmethod
    def transform_embedding_request(
        self,
        model: str,
        input: AllEmbeddingInputValues,
        optional_params: dict,
        headers: dict,
    ) -> dict:
        """Build the provider-specific embedding request body."""
        return {}

    @abstractmethod
    def transform_embedding_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: EmbeddingResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str],
        request_data: dict,
        optional_params: dict,
        litellm_params: dict,
    ) -> EmbeddingResponse:
        """Populate `model_response` from the provider's raw HTTP response."""
        return model_response

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """Optional override point for the full endpoint URL.

        Some providers embed `model` in the URL; the default returns the
        configured api_base, or an empty string when none is set.
        """
        if api_base:
            return api_base
        return ""

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        raise NotImplementedError(
            "EmbeddingConfig does not need a request transformation for chat models"
        )

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        raise NotImplementedError(
            "EmbeddingConfig does not need a response transformation for chat models"
        )

View File

@@ -0,0 +1,7 @@
"""
Base configuration for Evals API
"""
from .transformation import BaseEvalsAPIConfig
__all__ = ["BaseEvalsAPIConfig"]

View File

@@ -0,0 +1,542 @@
"""
Base configuration class for Evals API
"""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
import httpx
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.llms.openai_evals import (
CancelEvalResponse,
CancelRunResponse,
CreateEvalRequest,
CreateRunRequest,
DeleteEvalResponse,
Eval,
ListEvalsParams,
ListEvalsResponse,
ListRunsParams,
ListRunsResponse,
Run,
RunDeleteResponse,
UpdateEvalRequest,
)
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import LlmProviders
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class BaseEvalsAPIConfig(ABC):
    """
    Base configuration for Evals API providers.

    Abstract transformation interface for the OpenAI-compatible Evals API:
    each `transform_*_request` builds the provider-specific URL / headers /
    body, and each `transform_*_response` parses the raw HTTP response into
    the shared typed objects (Eval, Run, list/cancel/delete responses).
    """
    def __init__(self):
        pass
    @property
    @abstractmethod
    def custom_llm_provider(self) -> LlmProviders:
        # The provider slug this config serves (used for routing/dispatch).
        pass
    @abstractmethod
    def validate_environment(
        self, headers: dict, litellm_params: Optional[GenericLiteLLMParams]
    ) -> dict:
        """
        Validate and update headers with provider-specific requirements
        Args:
            headers: Base headers dictionary
            litellm_params: LiteLLM parameters
        Returns:
            Updated headers dictionary
        """
        return headers
    @abstractmethod
    def get_complete_url(
        self,
        api_base: Optional[str],
        endpoint: str,
        eval_id: Optional[str] = None,
    ) -> str:
        """
        Get the complete URL for the API request
        Args:
            api_base: Base API URL
            endpoint: API endpoint (e.g. 'evals', 'evals/{id}')
            eval_id: Optional eval ID for specific eval operations
        Returns:
            Complete URL
        """
        # Default body (abstract): requires api_base and joins under /v1/.
        if api_base is None:
            raise ValueError("api_base is required")
        return f"{api_base}/v1/{endpoint}"
    # ---- Eval API transformations ----
    @abstractmethod
    def transform_create_eval_request(
        self,
        create_request: CreateEvalRequest,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Dict:
        """
        Transform create eval request to provider-specific format
        Args:
            create_request: Eval creation parameters
            litellm_params: LiteLLM parameters
            headers: Request headers
        Returns:
            Provider-specific request body
        """
        pass
    @abstractmethod
    def transform_create_eval_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> Eval:
        """
        Transform provider response to Eval object
        Args:
            raw_response: Raw HTTP response
            logging_obj: Logging object
        Returns:
            Eval object
        """
        pass
    @abstractmethod
    def transform_list_evals_request(
        self,
        list_params: ListEvalsParams,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """
        Transform list evals request parameters
        Args:
            list_params: List parameters (pagination, filters)
            litellm_params: LiteLLM parameters
            headers: Request headers
        Returns:
            Tuple of (url, query_params)
        """
        pass
    @abstractmethod
    def transform_list_evals_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ListEvalsResponse:
        """
        Transform provider response to ListEvalsResponse
        Args:
            raw_response: Raw HTTP response
            logging_obj: Logging object
        Returns:
            ListEvalsResponse object
        """
        pass
    @abstractmethod
    def transform_get_eval_request(
        self,
        eval_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """
        Transform get eval request
        Args:
            eval_id: Eval ID
            api_base: Base API URL
            litellm_params: LiteLLM parameters
            headers: Request headers
        Returns:
            Tuple of (url, headers)
        """
        pass
    @abstractmethod
    def transform_get_eval_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> Eval:
        """
        Transform provider response to Eval object
        Args:
            raw_response: Raw HTTP response
            logging_obj: Logging object
        Returns:
            Eval object
        """
        pass
    @abstractmethod
    def transform_update_eval_request(
        self,
        eval_id: str,
        update_request: UpdateEvalRequest,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict, Dict]:
        """
        Transform update eval request
        Args:
            eval_id: Eval ID
            update_request: Update parameters
            api_base: Base API URL
            litellm_params: LiteLLM parameters
            headers: Request headers
        Returns:
            Tuple of (url, headers, body)
        """
        pass
    @abstractmethod
    def transform_update_eval_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> Eval:
        """
        Transform provider response to Eval object
        Args:
            raw_response: Raw HTTP response
            logging_obj: Logging object
        Returns:
            Eval object
        """
        pass
    @abstractmethod
    def transform_delete_eval_request(
        self,
        eval_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """
        Transform delete eval request
        Args:
            eval_id: Eval ID
            api_base: Base API URL
            litellm_params: LiteLLM parameters
            headers: Request headers
        Returns:
            Tuple of (url, headers)
        """
        pass
    @abstractmethod
    def transform_delete_eval_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> DeleteEvalResponse:
        """
        Transform provider response to DeleteEvalResponse
        Args:
            raw_response: Raw HTTP response
            logging_obj: Logging object
        Returns:
            DeleteEvalResponse object
        """
        pass
    @abstractmethod
    def transform_cancel_eval_request(
        self,
        eval_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict, Dict]:
        """
        Transform cancel eval request
        Args:
            eval_id: Eval ID
            api_base: Base API URL
            litellm_params: LiteLLM parameters
            headers: Request headers
        Returns:
            Tuple of (url, headers, body)
        """
        pass
    @abstractmethod
    def transform_cancel_eval_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> CancelEvalResponse:
        """
        Transform provider response to CancelEvalResponse
        Args:
            raw_response: Raw HTTP response
            logging_obj: Logging object
        Returns:
            CancelEvalResponse object
        """
        pass
    # Run API Transformations
    @abstractmethod
    def transform_create_run_request(
        self,
        eval_id: str,
        create_request: CreateRunRequest,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """
        Transform create run request to provider-specific format
        Args:
            eval_id: Eval ID
            create_request: Run creation parameters
            litellm_params: LiteLLM parameters
            headers: Request headers
        Returns:
            Tuple of (url, request_body)
        """
        pass
    @abstractmethod
    def transform_create_run_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> Run:
        """
        Transform provider response to Run object
        Args:
            raw_response: Raw HTTP response
            logging_obj: Logging object
        Returns:
            Run object
        """
        pass
    @abstractmethod
    def transform_list_runs_request(
        self,
        eval_id: str,
        list_params: ListRunsParams,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """
        Transform list runs request parameters
        Args:
            eval_id: Eval ID
            list_params: List parameters (pagination, filters)
            litellm_params: LiteLLM parameters
            headers: Request headers
        Returns:
            Tuple of (url, query_params)
        """
        pass
    @abstractmethod
    def transform_list_runs_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ListRunsResponse:
        """
        Transform provider response to ListRunsResponse
        Args:
            raw_response: Raw HTTP response
            logging_obj: Logging object
        Returns:
            ListRunsResponse object
        """
        pass
    @abstractmethod
    def transform_get_run_request(
        self,
        eval_id: str,
        run_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """
        Transform get run request
        Args:
            eval_id: Eval ID
            run_id: Run ID
            api_base: Base API URL
            litellm_params: LiteLLM parameters
            headers: Request headers
        Returns:
            Tuple of (url, headers)
        """
        pass
    @abstractmethod
    def transform_get_run_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> Run:
        """
        Transform provider response to Run object
        Args:
            raw_response: Raw HTTP response
            logging_obj: Logging object
        Returns:
            Run object
        """
        pass
    @abstractmethod
    def transform_cancel_run_request(
        self,
        eval_id: str,
        run_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict, Dict]:
        """
        Transform cancel run request
        Args:
            eval_id: Eval ID
            run_id: Run ID
            api_base: Base API URL
            litellm_params: LiteLLM parameters
            headers: Request headers
        Returns:
            Tuple of (url, headers, body)
        """
        pass
    @abstractmethod
    def transform_cancel_run_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> CancelRunResponse:
        """
        Transform provider response to CancelRunResponse
        Args:
            raw_response: Raw HTTP response
            logging_obj: Logging object
        Returns:
            CancelRunResponse object
        """
        pass
    @abstractmethod
    def transform_delete_run_request(
        self,
        eval_id: str,
        run_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict, Dict]:
        """
        Transform delete run request
        Args:
            eval_id: Eval ID
            run_id: Run ID
            api_base: Base API URL
            litellm_params: LiteLLM parameters
            headers: Request headers
        Returns:
            Tuple of (url, headers, body)
        """
        pass
    @abstractmethod
    def transform_delete_run_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> "RunDeleteResponse":
        """
        Transform provider response to RunDeleteResponse
        Args:
            raw_response: Raw HTTP response
            logging_obj: Logging object
        Returns:
            RunDeleteResponse object
        """
        pass
    def get_error_class(
        self,
        error_message: str,
        status_code: int,
        headers: dict,
    ) -> Exception:
        """Get appropriate error class for the provider."""
        # Default: a generic BaseLLMException instance is returned for the
        # caller to raise; providers may override with specific subclasses.
        return BaseLLMException(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )

View File

@@ -0,0 +1,325 @@
"""
Azure Blob Storage backend implementation for file storage.
This module implements the Azure Blob Storage backend for storing files
in Azure Data Lake Storage Gen2. It inherits from AzureBlobStorageLogger
to reuse all authentication and Azure Storage operations.
"""
import time
from typing import Optional
from urllib.parse import quote
from litellm._logging import verbose_logger
from litellm._uuid import uuid
from .storage_backend import BaseFileStorageBackend
from litellm.integrations.azure_storage.azure_storage import AzureBlobStorageLogger
class AzureBlobStorageBackend(BaseFileStorageBackend, AzureBlobStorageLogger):
"""
Azure Blob Storage backend implementation.
Inherits from AzureBlobStorageLogger to reuse:
- Authentication (account key and Azure AD)
- Service client management
- Token management
- All Azure Storage helper methods
Reads configuration from the same environment variables as AzureBlobStorageLogger.
"""
    def __init__(self, **kwargs):
        """
        Initialize Azure Blob Storage backend.
        Inherits all functionality from AzureBlobStorageLogger which handles:
        - Reading environment variables
        - Authentication (account key and Azure AD)
        - Service client management
        - Token management
        Environment variables (same as AzureBlobStorageLogger):
        - AZURE_STORAGE_ACCOUNT_NAME (required)
        - AZURE_STORAGE_FILE_SYSTEM (required)
        - AZURE_STORAGE_ACCOUNT_KEY (optional, if using account key auth)
        - AZURE_STORAGE_TENANT_ID (optional, if using Azure AD)
        - AZURE_STORAGE_CLIENT_ID (optional, if using Azure AD)
        - AZURE_STORAGE_CLIENT_SECRET (optional, if using Azure AD)
        Note: We skip periodic_flush since we're not using this as a logger.
        """
        # Initialize AzureBlobStorageLogger (handles all auth and config);
        # kwargs are forwarded unchanged to the logger's constructor.
        AzureBlobStorageLogger.__init__(self, **kwargs)
# Disable logging functionality - we're only using this for file storage
# The periodic_flush task will be created but will do nothing since we override it
async def periodic_flush(self):
"""
Override to do nothing - we're not using this as a logger.
This prevents the periodic flush task from doing any work.
"""
# Do nothing - this class is used for file storage, not logging
return
async def async_log_success_event(self, *args, **kwargs):
"""
Override to do nothing - we're not using this as a logger.
"""
# Do nothing - this class is used for file storage, not logging
pass
async def async_log_failure_event(self, *args, **kwargs):
"""
Override to do nothing - we're not using this as a logger.
"""
# Do nothing - this class is used for file storage, not logging
pass
def _generate_file_name(
self, original_filename: str, file_naming_strategy: str
) -> str:
"""Generate file name based on naming strategy."""
if file_naming_strategy == "original_filename":
# Use original filename, but sanitize it
return quote(original_filename, safe="")
elif file_naming_strategy == "timestamp":
# Use timestamp
extension = (
original_filename.split(".")[-1] if "." in original_filename else ""
)
timestamp = int(time.time() * 1000) # milliseconds
return f"{timestamp}.{extension}" if extension else str(timestamp)
else: # default to "uuid"
# Use UUID
extension = (
original_filename.split(".")[-1] if "." in original_filename else ""
)
file_uuid = str(uuid.uuid4())
return f"{file_uuid}.{extension}" if extension else file_uuid
async def upload_file(
self,
file_content: bytes,
filename: str,
content_type: str,
path_prefix: Optional[str] = None,
file_naming_strategy: str = "uuid",
) -> str:
"""
Upload a file to Azure Blob Storage.
Returns the blob URL in format: https://{account}.blob.core.windows.net/{container}/{path}
"""
try:
# Generate file name
file_name = self._generate_file_name(filename, file_naming_strategy)
# Build full path
if path_prefix:
# Remove leading/trailing slashes and normalize
prefix = path_prefix.strip("/")
full_path = f"{prefix}/{file_name}"
else:
full_path = file_name
if self.azure_storage_account_key:
# Use Azure SDK with account key (reuse logger's method)
storage_url = await self._upload_file_with_account_key(
file_content=file_content,
full_path=full_path,
)
else:
# Use REST API with Azure AD token (reuse logger's methods)
storage_url = await self._upload_file_with_azure_ad(
file_content=file_content,
full_path=full_path,
)
verbose_logger.debug(
f"Successfully uploaded file to Azure Blob Storage: {storage_url}"
)
return storage_url
except Exception as e:
verbose_logger.exception(
f"Error uploading file to Azure Blob Storage: {str(e)}"
)
raise
    async def _upload_file_with_account_key(
        self, file_content: bytes, full_path: str
    ) -> str:
        """Upload file using Azure SDK with account key authentication.

        Ensures the filesystem (container) and any parent directory exist,
        then performs the Data Lake create/append/flush sequence, and returns
        the blob-endpoint URL for the written object.
        """
        # Reuse the logger's service client method
        service_client = await self.get_service_client()
        file_system_client = service_client.get_file_system_client(
            file_system=self.azure_storage_file_system
        )
        # Create filesystem (container) if it doesn't exist
        if not await file_system_client.exists():
            await file_system_client.create_file_system()
            verbose_logger.debug(
                f"Created filesystem: {self.azure_storage_file_system}"
            )
        # Extract directory and filename (similar to logger's pattern)
        path_parts = full_path.split("/")
        if len(path_parts) > 1:
            directory_path = "/".join(path_parts[:-1])
            file_name = path_parts[-1]
            # Create directory if needed (like logger does)
            directory_client = file_system_client.get_directory_client(directory_path)
            if not await directory_client.exists():
                await directory_client.create_directory()
                verbose_logger.debug(f"Created directory: {directory_path}")
            # Get file client from directory (same pattern as logger)
            file_client = directory_client.get_file_client(file_name)
        else:
            # No directory, create file directly in root
            file_client = file_system_client.get_file_client(full_path)
        # Create, append, and flush (same pattern as logger's upload_to_azure_data_lake_with_azure_account_key)
        # NOTE: order is significant -- append writes at offset 0, and flush
        # commits exactly len(file_content) bytes.
        await file_client.create_file()
        await file_client.append_data(
            data=file_content, offset=0, length=len(file_content)
        )
        await file_client.flush_data(position=len(file_content), offset=0)
        # Return blob URL (not DFS URL)
        blob_url = f"https://{self.azure_storage_account_name}.blob.core.windows.net/{self.azure_storage_file_system}/{full_path}"
        return blob_url
    async def _upload_file_with_azure_ad(
        self, file_content: bytes, full_path: str
    ) -> str:
        """Upload file using REST API with Azure AD authentication.

        Refreshes the Azure AD token, then performs the Data Lake 3-step
        upload (create, append, flush) against the DFS endpoint and returns
        the blob-endpoint URL for the written object.
        """
        # Reuse the logger's token management
        await self.set_valid_azure_ad_token()
        from litellm.llms.custom_httpx.http_handler import (
            get_async_httpx_client,
            httpxSpecialProvider,
        )
        async_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )
        # Use DFS endpoint for upload
        base_url = f"https://{self.azure_storage_account_name}.dfs.core.windows.net/{self.azure_storage_file_system}/{full_path}"
        # Execute 3-step upload process: create, append, flush
        # Reuse the logger's helper methods
        await self._create_file(async_client, base_url)
        # Append data - logger's _append_data expects string, so we create our own for bytes
        await self._append_data_bytes(async_client, base_url, file_content)
        await self._flush_data(async_client, base_url, len(file_content))
        # Return blob URL (not DFS URL)
        blob_url = f"https://{self.azure_storage_account_name}.blob.core.windows.net/{self.azure_storage_file_system}/{full_path}"
        return blob_url
async def _append_data_bytes(self, client, base_url: str, file_content: bytes):
"""Append binary data to file using REST API."""
from litellm.constants import AZURE_STORAGE_MSFT_VERSION
headers = {
"x-ms-version": AZURE_STORAGE_MSFT_VERSION,
"Content-Type": "application/octet-stream",
"Authorization": f"Bearer {self.azure_auth_token}",
}
response = await client.patch(
f"{base_url}?action=append&position=0",
headers=headers,
content=file_content,
)
response.raise_for_status()
async def download_file(self, storage_url: str) -> bytes:
"""
Download a file from Azure Blob Storage.
Args:
storage_url: Blob URL in format: https://{account}.blob.core.windows.net/{container}/{path}
Returns:
bytes: File content
"""
try:
# Parse blob URL to extract path
# URL format: https://{account}.blob.core.windows.net/{container}/{path}
if ".blob.core.windows.net/" not in storage_url:
raise ValueError(f"Invalid Azure Blob Storage URL: {storage_url}")
# Extract path after container name
container_and_path = storage_url.split(".blob.core.windows.net/", 1)[1]
path_parts = container_and_path.split("/", 1)
if len(path_parts) < 2:
raise ValueError(
f"Invalid Azure Blob Storage URL format: {storage_url}"
)
file_path = path_parts[1] # Path after container name
if self.azure_storage_account_key:
# Use Azure SDK (reuse logger's service client)
return await self._download_file_with_account_key(file_path)
else:
# Use REST API (reuse logger's token management)
return await self._download_file_with_azure_ad(file_path)
except Exception as e:
verbose_logger.exception(
f"Error downloading file from Azure Blob Storage: {str(e)}"
)
raise
async def _download_file_with_account_key(self, file_path: str) -> bytes:
"""Download file using Azure SDK with account key."""
# Reuse the logger's service client method
service_client = await self.get_service_client()
file_system_client = service_client.get_file_system_client(
file_system=self.azure_storage_file_system
)
# Ensure filesystem exists (should already exist, but check for safety)
if not await file_system_client.exists():
raise ValueError(
f"Filesystem {self.azure_storage_file_system} does not exist"
)
file_client = file_system_client.get_file_client(file_path)
# Download file
download_response = await file_client.download_file()
file_content = await download_response.readall()
return file_content
async def _download_file_with_azure_ad(self, file_path: str) -> bytes:
"""Download file using REST API with Azure AD token."""
# Reuse the logger's token management
await self.set_valid_azure_ad_token()
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.constants import AZURE_STORAGE_MSFT_VERSION
async_client = get_async_httpx_client(
llm_provider=httpxSpecialProvider.LoggingCallback
)
# Use blob endpoint for download (simpler than DFS)
blob_url = f"https://{self.azure_storage_account_name}.blob.core.windows.net/{self.azure_storage_file_system}/{file_path}"
headers = {
"x-ms-version": AZURE_STORAGE_MSFT_VERSION,
"Authorization": f"Bearer {self.azure_auth_token}",
}
response = await async_client.get(blob_url, headers=headers)
response.raise_for_status()
return response.content

View File

@@ -0,0 +1,78 @@
"""
Base storage backend interface for file storage backends.
This module defines the abstract base class that all file storage backends
(e.g., Azure Blob Storage, S3, GCS) must implement.
"""
from abc import ABC, abstractmethod
from typing import Optional
class BaseFileStorageBackend(ABC):
    """
    Abstract base class for file storage backends.

    All storage backends (Azure Blob Storage, S3, GCS, etc.) must implement
    these methods to provide a consistent interface for file operations.

    Instances are created via ``get_storage_backend(backend_type)``.
    """
    @abstractmethod
    async def upload_file(
        self,
        file_content: bytes,
        filename: str,
        content_type: str,
        path_prefix: Optional[str] = None,
        file_naming_strategy: str = "uuid",
    ) -> str:
        """
        Upload a file to the storage backend.

        Args:
            file_content: The file content as bytes
            filename: Original filename (may be used for naming strategy)
            content_type: MIME type of the file
            path_prefix: Optional path prefix for organizing files
            file_naming_strategy: Strategy for naming files ("uuid", "timestamp", "original_filename")

        Returns:
            str: The storage URL where the file can be accessed/downloaded;
                this same URL is later passed to download_file/delete_file.

        Raises:
            Exception: If upload fails
        """
        pass
    @abstractmethod
    async def download_file(self, storage_url: str) -> bytes:
        """
        Download a file from the storage backend.

        Args:
            storage_url: The storage URL returned from upload_file

        Returns:
            bytes: The file content

        Raises:
            Exception: If download fails
        """
        pass
    async def delete_file(self, storage_url: str) -> None:
        """
        Delete a file from the storage backend.

        This is optional and can be overridden by backends that support deletion.
        Default implementation does nothing (silently succeeds).

        Args:
            storage_url: The storage URL of the file to delete

        Raises:
            Exception: If deletion fails
        """
        # Default implementation: no-op
        # Backends can override if they support deletion
        pass

View File

@@ -0,0 +1,40 @@
"""
Factory for creating storage backend instances.
This module provides a factory function to instantiate the correct storage backend
based on the backend type. Backends use the same configuration as their corresponding
callbacks (e.g., azure_storage uses the same env vars as AzureBlobStorageLogger).
"""
from litellm._logging import verbose_logger
from .azure_blob_storage_backend import AzureBlobStorageBackend
from .storage_backend import BaseFileStorageBackend
def get_storage_backend(backend_type: str) -> BaseFileStorageBackend:
    """
    Factory function to create a storage backend instance.

    Backends reuse the configuration (env vars) of their corresponding
    callbacks. For example, "azure_storage" reads the same env vars as
    AzureBlobStorageLogger.

    Args:
        backend_type: Backend type identifier (e.g., "azure_storage")

    Returns:
        BaseFileStorageBackend: Instance of the appropriate storage backend

    Raises:
        ValueError: If backend_type is not supported
    """
    verbose_logger.debug(f"Creating storage backend: type={backend_type}")

    # Guard clause: reject unknown backends before constructing anything.
    if backend_type != "azure_storage":
        raise ValueError(
            f"Unsupported storage backend type: {backend_type}. "
            f"Supported types: azure_storage"
        )
    return AzureBlobStorageBackend()

View File

@@ -0,0 +1,261 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import httpx
from openai.types.file_deleted import FileDeleted
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.files import TwoStepFileUploadConfig
from litellm.types.llms.openai import (
AllMessageValues,
CreateFileRequest,
FileContentRequest,
OpenAICreateFileRequestOptionalParams,
OpenAIFileObject,
OpenAIFilesPurpose,
)
from litellm.types.utils import LlmProviders, ModelResponse
from ..chat.transformation import BaseConfig
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
from litellm.router import Router as _Router
from litellm.types.llms.openai import HttpxBinaryResponseContent
LiteLLMLoggingObj = _LiteLLMLoggingObj
Span = Any
Router = _Router
else:
LiteLLMLoggingObj = Any
Span = Any
Router = Any
class BaseFilesConfig(BaseConfig):
    """Base transformation/config interface for OpenAI-compatible Files
    endpoints (create / retrieve / delete / list / content) across providers.

    Concrete provider configs implement the transform_* pairs; the generic
    handler calls the request transform, performs the HTTP call, then calls
    the matching response transform.
    """

    @property
    @abstractmethod
    def custom_llm_provider(self) -> LlmProviders:
        """Return the LLM provider identifier for this config."""
        pass

    @property
    def file_upload_http_method(self) -> str:
        """
        HTTP method to use for file uploads.

        Override this in provider configs if they need different methods.
        Default is POST (used by most providers like OpenAI, Anthropic).
        S3-based providers like Bedrock should return "PUT".
        """
        return "POST"

    @abstractmethod
    def get_supported_openai_params(
        self, model: str
    ) -> List[OpenAICreateFileRequestOptionalParams]:
        """Return the OpenAI file-create optional params supported for `model`."""
        pass

    def get_complete_file_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        data: CreateFileRequest,
    ):
        """Build the files-endpoint URL; defaults to the generic URL logic.

        `data` is provided so providers whose upload URL depends on the file
        payload (e.g. pre-signed URLs) can override this.
        """
        return self.get_complete_url(
            api_base=api_base,
            api_key=api_key,
            model=model,
            optional_params=optional_params,
            litellm_params=litellm_params,
        )

    @abstractmethod
    def transform_create_file_request(
        self,
        model: str,
        create_file_data: CreateFileRequest,
        optional_params: dict,
        litellm_params: dict,
    ) -> Union[dict, str, bytes, "TwoStepFileUploadConfig"]:
        """
        Transform OpenAI-style file creation request into provider-specific format.

        Returns:
            - dict: For pre-signed single-step uploads (e.g., Bedrock S3)
            - str/bytes: For traditional file uploads
            - TwoStepFileUploadConfig: For two-step upload process (e.g., Manus, GCS)
        """
        pass

    @abstractmethod
    def transform_create_file_response(
        self,
        model: Optional[str],
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        litellm_params: dict,
    ) -> OpenAIFileObject:
        """Transform the provider's file-create response into OpenAI format."""
        pass

    @abstractmethod
    def transform_retrieve_file_request(
        self,
        file_id: str,
        optional_params: dict,
        litellm_params: dict,
    ) -> tuple[str, dict]:
        """Transform file retrieve request into provider-specific format."""
        pass

    @abstractmethod
    def transform_retrieve_file_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        litellm_params: dict,
    ) -> OpenAIFileObject:
        """Transform file retrieve response into OpenAI format."""
        pass

    @abstractmethod
    def transform_delete_file_request(
        self,
        file_id: str,
        optional_params: dict,
        litellm_params: dict,
    ) -> tuple[str, dict]:
        """Transform file delete request into provider-specific format."""
        pass

    @abstractmethod
    def transform_delete_file_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        litellm_params: dict,
    ) -> "FileDeleted":
        """Transform file delete response into OpenAI format."""
        pass

    @abstractmethod
    def transform_list_files_request(
        self,
        purpose: Optional[str],
        optional_params: dict,
        litellm_params: dict,
    ) -> tuple[str, dict]:
        """Transform file list request into provider-specific format."""
        pass

    @abstractmethod
    def transform_list_files_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        litellm_params: dict,
    ) -> List[OpenAIFileObject]:
        """Transform file list response into OpenAI format."""
        pass

    @abstractmethod
    def transform_file_content_request(
        self,
        file_content_request: "FileContentRequest",
        optional_params: dict,
        litellm_params: dict,
    ) -> tuple[str, dict]:
        """Transform file content request into provider-specific format."""
        pass

    @abstractmethod
    def transform_file_content_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        litellm_params: dict,
    ) -> "HttpxBinaryResponseContent":
        """Transform file content response into OpenAI format."""
        pass

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        # BUGFIX: the previous message referenced AudioTranscriptionConfig
        # (copy-paste from the audio-transcription base class).
        raise NotImplementedError(
            "BaseFilesConfig does not use 'transform_request' - use the file-specific transform methods instead"
        )

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        # BUGFIX: the previous message referenced AudioTranscriptionConfig
        # (copy-paste from the audio-transcription base class).
        raise NotImplementedError(
            "BaseFilesConfig does not use 'transform_response' - use the file-specific transform methods instead"
        )
class BaseFileEndpoints(ABC):
    """Interface for router/proxy-level implementations of the OpenAI Files
    endpoints (create / retrieve / list / delete / content)."""
    @abstractmethod
    async def acreate_file(
        self,
        create_file_request: CreateFileRequest,
        llm_router: Router,
        target_model_names_list: List[str],
        litellm_parent_otel_span: Span,
        user_api_key_dict: UserAPIKeyAuth,
    ) -> OpenAIFileObject:
        """Create (upload) a file, targeting the given model names via the router."""
        pass
    @abstractmethod
    async def afile_retrieve(
        self,
        file_id: str,
        litellm_parent_otel_span: Optional[Span],
        llm_router: Optional[Router] = None,
    ) -> OpenAIFileObject:
        """Retrieve a single file object by its ID."""
        pass
    @abstractmethod
    async def afile_list(
        self,
        purpose: Optional[OpenAIFilesPurpose],
        litellm_parent_otel_span: Optional[Span],
        **data: Dict,
    ) -> List[OpenAIFileObject]:
        """List files, optionally filtered by purpose."""
        pass
    @abstractmethod
    async def afile_delete(
        self,
        file_id: str,
        litellm_parent_otel_span: Optional[Span],
        llm_router: Router,
        **data: Dict,
    ) -> OpenAIFileObject:
        """Delete a file by ID.

        NOTE(review): annotated to return OpenAIFileObject rather than
        FileDeleted — confirm against implementations before changing.
        """
        pass
    @abstractmethod
    async def afile_content(
        self,
        file_id: str,
        litellm_parent_otel_span: Optional[Span],
        llm_router: Router,
        **data: Dict,
    ) -> "HttpxBinaryResponseContent":
        """Download a file's raw content by ID."""
        pass

View File

@@ -0,0 +1,211 @@
import types
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import httpx
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.types.google_genai.main import (
GenerateContentConfigDict,
GenerateContentContentListUnionDict,
GenerateContentResponse,
ToolConfigDict,
)
else:
GenerateContentConfigDict = Any
GenerateContentContentListUnionDict = Any
GenerateContentResponse = Any
LiteLLMLoggingObj = Any
ToolConfigDict = Any
from litellm.types.router import GenericLiteLLMParams
class BaseGoogleGenAIGenerateContentConfig(ABC):
    """Base configuration class for Google GenAI generate_content functionality"""
    def __init__(self):
        # Stateless base; provider subclasses hold their own config.
        pass
    @classmethod
    def get_config(cls):
        """Return the class's declared configuration values: non-dunder,
        non-ABC-bookkeeping, non-callable attributes whose value is not None."""
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not k.startswith("_abc")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }
    @abstractmethod
    def get_supported_generate_content_optional_params(self, model: str) -> List[str]:
        """
        Get the list of supported Google GenAI parameters for the model.

        Args:
            model: The model name

        Returns:
            List of supported parameter names
        """
        raise NotImplementedError(
            "get_supported_generate_content_optional_params is not implemented"
        )
    @abstractmethod
    def map_generate_content_optional_params(
        self,
        generate_content_config_dict: GenerateContentConfigDict,
        model: str,
    ) -> Dict[str, Any]:
        """
        Map Google GenAI parameters to provider-specific format.

        Args:
            generate_content_config_dict: Optional parameters for generate content
            model: The model name

        Returns:
            Mapped parameters for the provider
        """
        raise NotImplementedError(
            "map_generate_content_optional_params is not implemented"
        )
    @abstractmethod
    def validate_environment(
        self,
        api_key: Optional[str],
        headers: Optional[dict],
        model: str,
        litellm_params: Optional[Union[GenericLiteLLMParams, dict]],
    ) -> dict:
        """
        Validate the environment and return headers for the request.

        Args:
            api_key: API key
            headers: Existing headers
            model: The model name
            litellm_params: LiteLLM parameters

        Returns:
            Updated headers
        """
        raise NotImplementedError("validate_environment is not implemented")
    def sync_get_auth_token_and_url(
        self,
        api_base: Optional[str],
        model: str,
        litellm_params: dict,
        stream: bool,
    ) -> Tuple[dict, str]:
        """
        Sync version of get_auth_token_and_url.

        Args:
            api_base: Base API URL
            model: The model name
            litellm_params: LiteLLM parameters
            stream: Whether this is a streaming call

        Returns:
            Tuple of headers and API base

        Raises:
            NotImplementedError: If the provider does not implement sync auth.
        """
        raise NotImplementedError("sync_get_auth_token_and_url is not implemented")
    async def get_auth_token_and_url(
        self,
        api_base: Optional[str],
        model: str,
        litellm_params: dict,
        stream: bool,
    ) -> Tuple[dict, str]:
        """
        Get the complete URL for the request.

        Args:
            api_base: Base API URL
            model: The model name
            litellm_params: LiteLLM parameters
            stream: Whether this is a streaming call

        Returns:
            Tuple of headers and API base
        """
        raise NotImplementedError("get_auth_token_and_url is not implemented")
    @abstractmethod
    def transform_generate_content_request(
        self,
        model: str,
        contents: GenerateContentContentListUnionDict,
        tools: Optional[ToolConfigDict],
        generate_content_config_dict: Dict,
        system_instruction: Optional[Any] = None,
    ) -> dict:
        """
        Transform the request parameters for the generate content API.

        Args:
            model: The model name
            contents: Input contents
            tools: Tools
            generate_content_config_dict: Generation config parameters
            system_instruction: Optional system instruction

        Returns:
            Transformed request data
        """
        pass
    @abstractmethod
    def transform_generate_content_response(
        self,
        model: str,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> GenerateContentResponse:
        """
        Transform the raw response from the generate content API.

        Args:
            model: The model name
            raw_response: Raw HTTP response
            logging_obj: LiteLLM logging object for this call

        Returns:
            Transformed response data
        """
        pass
    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> Exception:
        """
        Get the appropriate exception class for the error.

        Args:
            error_message: Error message
            status_code: HTTP status code
            headers: Response headers

        Returns:
            Exception instance (a BaseLLMException; callers raise it)
        """
        from litellm.llms.base_llm.chat.transformation import BaseLLMException
        return BaseLLMException(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )

View File

@@ -0,0 +1,107 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Optional
if TYPE_CHECKING:
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.proxy._types import UserAPIKeyAuth
class BaseTranslation(ABC):
    """Base class for guardrail request/response translation handlers."""

    @staticmethod
    def transform_user_api_key_dict_to_metadata(
        user_api_key_dict: Optional[Any],
    ) -> Dict[str, Any]:
        """
        Flatten a UserAPIKeyAuth object (or plain dict) into a metadata dict
        whose keys carry a 'user_api_key_' prefix, making the origin of each
        field explicit.

        Args:
            user_api_key_dict: UserAPIKeyAuth object or dict with user information

        Returns:
            Dict with keys prefixed with 'user_api_key_'
        """
        if user_api_key_dict is None:
            return {}

        # Pydantic objects expose model_dump(); anything else is used as-is.
        raw = (
            user_api_key_dict.model_dump()
            if hasattr(user_api_key_dict, "model_dump")
            else user_api_key_dict
        )
        if not isinstance(raw, dict):
            return {}

        # Drop None values and private fields; prefix every remaining key
        # (unless it already carries the prefix).
        return {
            (key if key.startswith("user_api_key_") else f"user_api_key_{key}"): value
            for key, value in raw.items()
            if value is not None and not key.startswith("_")
        }

    @abstractmethod
    async def process_input_messages(
        self,
        data: dict,
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional["LiteLLMLoggingObj"] = None,
    ) -> Any:
        """
        Run the guardrail over the request's input messages.

        user_api_key_dict metadata is expected to already be present in ``data``.
        """
        ...

    @abstractmethod
    async def process_output_response(
        self,
        response: Any,
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional["LiteLLMLoggingObj"] = None,
        user_api_key_dict: Optional["UserAPIKeyAuth"] = None,
    ) -> Any:
        """
        Run the guardrail over the LLM's final response.

        Args:
            response: The response object from the LLM
            guardrail_to_apply: The guardrail instance to apply
            litellm_logging_obj: Optional logging object
            user_api_key_dict: User API key metadata (passed separately since
                the response object does not carry it)
        """
        ...

    async def process_output_streaming_response(
        self,
        responses_so_far: List[Any],
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional["LiteLLMLoggingObj"] = None,
        user_api_key_dict: Optional["UserAPIKeyAuth"] = None,
    ) -> Any:
        """
        Run the guardrail over a partially-streamed response.

        Subclasses may override; the default is a pass-through.
        """
        return responses_so_far

    def extract_request_tool_names(self, data: dict) -> List[str]:
        """
        Return tool names found in the request body for allowlist/policy
        checks. Tool-capable handlers override this; the default is empty.
        """
        return []

View File

@@ -0,0 +1,130 @@
import types
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union

import httpx
from httpx._types import RequestFiles

from litellm.types.images.main import ImageEditOptionalRequestParams
from litellm.types.responses.main import *
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import FileTypes
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
from litellm.utils import ImageResponse as _ImageResponse
from ..chat.transformation import BaseLLMException as _BaseLLMException
LiteLLMLoggingObj = _LiteLLMLoggingObj
BaseLLMException = _BaseLLMException
ImageResponse = _ImageResponse
else:
LiteLLMLoggingObj = Any
BaseLLMException = Any
ImageResponse = Any
class BaseImageEditConfig(ABC):
    """Provider-agnostic configuration/transformation interface for the
    image-edit endpoint. Concrete providers implement the parameter mapping
    and the request/response transforms."""

    def __init__(self):
        pass

    @classmethod
    def get_config(cls):
        """Return the class's declared configuration values: non-dunder,
        non-ABC-bookkeeping, non-callable attributes whose value is not None."""
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not k.startswith("_abc")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    @abstractmethod
    def get_supported_openai_params(self, model: str) -> list:
        """Return the OpenAI image-edit params supported for `model`."""
        pass

    @abstractmethod
    def map_openai_params(
        self,
        image_edit_optional_params: ImageEditOptionalRequestParams,
        model: str,
        drop_params: bool,
    ) -> Dict:
        """Map OpenAI-style optional params onto the provider's request params."""
        pass

    @abstractmethod
    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
    ) -> dict:
        """Validate credentials/config and return the request headers."""
        return {}

    @abstractmethod
    def get_complete_url(
        self,
        model: str,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        OPTIONAL

        Get the complete url for the request

        Some providers need `model` in `api_base`
        """
        if api_base is None:
            raise ValueError("api_base is required")
        return api_base

    @abstractmethod
    def transform_image_edit_request(
        self,
        model: str,
        prompt: Optional[str],
        image: Optional[FileTypes],
        image_edit_optional_request_params: Dict,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[Dict, RequestFiles]:
        """Build the provider request: (body dict, multipart files)."""
        pass

    @abstractmethod
    def transform_image_edit_response(
        self,
        model: str,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ImageResponse:
        """Parse the provider's raw HTTP response into an ImageResponse."""
        pass

    def use_multipart_form_data(self) -> bool:
        """
        Return True if the provider uses multipart/form-data for image edit requests.
        Return False if the provider uses JSON requests.
        Default is True for backwards compatibility with OpenAI-style providers.
        """
        return True

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """Build the provider-specific exception for an error response.

        Args:
            error_message: Error message from the provider
            status_code: HTTP status code
            headers: Response headers

        Returns:
            BaseLLMException instance (callers raise it).
        """
        from ..chat.transformation import BaseLLMException

        # BUGFIX: previously this *raised* the exception, contradicting the
        # annotated `-> BaseLLMException` return type; the exception is now
        # returned for the caller to raise.
        return BaseLLMException(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )

View File

@@ -0,0 +1,112 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, List, Optional, Union
import httpx
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.llms.openai import (
AllMessageValues,
OpenAIImageGenerationOptionalParams,
)
from litellm.types.utils import ImageResponse
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class BaseImageGenerationConfig(ABC):
    """Provider-agnostic configuration/transformation interface for the
    OpenAI-style image generation endpoint."""

    @abstractmethod
    def get_supported_openai_params(
        self, model: str
    ) -> List[OpenAIImageGenerationOptionalParams]:
        """Return the OpenAI image-generation params supported for `model`."""
        pass

    @abstractmethod
    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Map OpenAI-style params onto the provider's request params."""
        pass

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        OPTIONAL

        Get the complete url for the request

        Some providers need `model` in `api_base`
        """
        return api_base or ""

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """Validate credentials/config; returns request headers (default: none)."""
        return {}

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """Build the provider-specific exception for an error response.

        BUGFIX: previously this *raised* the exception, contradicting the
        annotated `-> BaseLLMException` return type; the exception is now
        returned for the caller to raise.
        """
        return BaseLLMException(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )

    def transform_image_generation_request(
        self,
        model: str,
        prompt: str,
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        # BUGFIX: message previously referenced ImageVariationConfig /
        # 'transform_request_image_variation' (copy-paste) and misspelled
        # "implements" as "implementa".
        raise NotImplementedError(
            "ImageGenerationConfig implements 'transform_image_generation_request' for image generation models"
        )

    def transform_image_generation_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ImageResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ImageResponse:
        # BUGFIX: message previously referenced ImageVariationConfig's
        # 'transform_response_image_variation' method (copy-paste).
        raise NotImplementedError(
            "ImageGenerationConfig implements 'transform_image_generation_response' for image generation models"
        )

    def use_multipart_form_data(self) -> bool:
        """
        Returns True if this provider requires multipart/form-data instead of JSON.
        Override this method in subclasses that need form-data (e.g., Stability AI).
        """
        return False

View File

@@ -0,0 +1,134 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, List, Optional
import httpx
from aiohttp import ClientResponse
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.types.llms.openai import (
AllMessageValues,
OpenAIImageVariationOptionalParams,
)
from litellm.types.utils import (
FileTypes,
HttpHandlerRequestFields,
ImageResponse,
ModelResponse,
)
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class BaseImageVariationConfig(BaseConfig, ABC):
    """Provider-agnostic configuration/transformation interface for the
    OpenAI-style image variation endpoint."""

    @abstractmethod
    def get_supported_openai_params(
        self, model: str
    ) -> List[OpenAIImageVariationOptionalParams]:
        """Return the OpenAI image-variation params supported for `model`."""
        pass

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        OPTIONAL

        Get the complete url for the request

        Some providers need `model` in `api_base`
        """
        return api_base or ""

    @abstractmethod
    def transform_request_image_variation(
        self,
        model: Optional[str],
        image: FileTypes,
        optional_params: dict,
        headers: dict,
    ) -> HttpHandlerRequestFields:
        """Build the provider's image-variation request fields."""
        pass

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """Validate credentials/config; returns request headers (default: none)."""
        return {}

    @abstractmethod
    async def async_transform_response_image_variation(
        self,
        model: Optional[str],
        raw_response: ClientResponse,
        model_response: ImageResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        image: FileTypes,
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
    ) -> ImageResponse:
        """Async variant: parse an aiohttp response into an ImageResponse."""
        pass

    @abstractmethod
    def transform_response_image_variation(
        self,
        model: Optional[str],
        raw_response: httpx.Response,
        model_response: ImageResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        image: FileTypes,
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
    ) -> ImageResponse:
        """Parse an httpx response into an ImageResponse."""
        pass

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        # BUGFIX: fixed typo "implementa" -> "implements" in the message.
        raise NotImplementedError(
            "ImageVariationConfig implements 'transform_request_image_variation' for image variation models"
        )

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        raise NotImplementedError(
            "ImageVariationConfig implements 'transform_response_image_variation' for image variation models"
        )

View File

@@ -0,0 +1,5 @@
"""Base classes for Interactions API implementations."""
from litellm.llms.base_llm.interactions.transformation import BaseInteractionsAPIConfig
__all__ = ["BaseInteractionsAPIConfig"]

View File

@@ -0,0 +1,310 @@
"""
Base transformation class for Interactions API implementations.
This follows the same pattern as BaseResponsesAPIConfig for the Responses API.
Per OpenAPI spec (https://ai.google.dev/static/api/interactions.openapi.json):
- Create: POST /{api_version}/interactions
- Get: GET /{api_version}/interactions/{interaction_id}
- Delete: DELETE /{api_version}/interactions/{interaction_id}
"""
import types
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import httpx
from litellm.types.interactions import (
CancelInteractionResult,
DeleteInteractionResult,
InteractionInput,
InteractionsAPIOptionalRequestParams,
InteractionsAPIResponse,
InteractionsAPIStreamingResponse,
)
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import LlmProviders
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
from ..chat.transformation import BaseLLMException as _BaseLLMException
LiteLLMLoggingObj = _LiteLLMLoggingObj
BaseLLMException = _BaseLLMException
else:
LiteLLMLoggingObj = Any
BaseLLMException = Any
class BaseInteractionsAPIConfig(ABC):
"""
Base configuration class for Google Interactions API implementations.
Per OpenAPI spec, the Interactions API supports two types of interactions:
- Model interactions (with model parameter)
- Agent interactions (with agent parameter)
Implementations should override the abstract methods to provide
provider-specific transformations for requests and responses.
"""
    def __init__(self):
        # Stateless base; provider configs add their own state if needed.
        pass
    @property
    @abstractmethod
    def custom_llm_provider(self) -> LlmProviders:
        """Return the LLM provider identifier for this config."""
        pass
    @classmethod
    def get_config(cls):
        """Return the class's declared configuration values: non-dunder,
        non-ABC-bookkeeping, non-callable attributes whose value is not None."""
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not k.startswith("_abc")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }
    @abstractmethod
    def get_supported_params(self, model: str) -> List[str]:
        """
        Return the list of supported parameters for the given model.

        Args:
            model: The model name

        Returns:
            List of supported request parameter names
        """
        pass
    @abstractmethod
    def validate_environment(
        self, headers: dict, model: str, litellm_params: Optional[GenericLiteLLMParams]
    ) -> dict:
        """
        Validate and prepare environment settings including headers.

        Returns:
            dict: headers to send with the request
        """
        return {}
    @abstractmethod
    def get_complete_url(
        self,
        api_base: Optional[str],
        model: Optional[str],
        agent: Optional[str] = None,
        litellm_params: Optional[dict] = None,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Get the complete URL for the interaction request.

        Per OpenAPI spec: POST /{api_version}/interactions

        Args:
            api_base: Base URL for the API
            model: The model name (for model interactions)
            agent: The agent name (for agent interactions)
            litellm_params: LiteLLM parameters
            stream: Whether this is a streaming request

        Returns:
            The complete URL for the request

        Raises:
            ValueError: If api_base is None (default implementation).
        """
        if api_base is None:
            raise ValueError("api_base is required")
        return api_base
    @abstractmethod
    def transform_request(
        self,
        model: Optional[str],
        agent: Optional[str],
        input: Optional[InteractionInput],
        optional_params: InteractionsAPIOptionalRequestParams,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Dict:
        """
        Transform the input request into the provider's expected format.

        Per OpenAPI spec, the request body should be either:
        - CreateModelInteractionParams (with model)
        - CreateAgentInteractionParams (with agent)
        i.e. exactly one of ``model`` / ``agent`` is expected to be set.

        Args:
            model: The model name (for model interactions)
            agent: The agent name (for agent interactions)
            input: The input content (string, content object, or list)
            optional_params: Optional parameters for the request
            litellm_params: LiteLLM-specific parameters
            headers: Request headers

        Returns:
            The transformed request body as a dictionary
        """
        pass
@abstractmethod
def transform_response(
    self,
    model: Optional[str],
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
) -> InteractionsAPIResponse:
    """
    Transform the raw HTTP response into an InteractionsAPIResponse.

    Per OpenAPI spec, the response is an Interaction object.

    Args:
        model: The model the request was made with, if any.
        raw_response: Raw httpx response from the provider.
        logging_obj: LiteLLM logging object for request/response tracking.

    Returns:
        The parsed InteractionsAPIResponse.
    """
    pass
@abstractmethod
def transform_streaming_response(
    self,
    model: Optional[str],
    parsed_chunk: dict,
    logging_obj: LiteLLMLoggingObj,
) -> InteractionsAPIStreamingResponse:
    """
    Transform a parsed streaming chunk into an InteractionsAPIStreamingResponse.

    Per OpenAPI spec, streaming uses SSE with various event types.

    Args:
        model: The model the request was made with, if any.
        parsed_chunk: One already-parsed SSE event payload.
        logging_obj: LiteLLM logging object for request/response tracking.

    Returns:
        The typed streaming response event.
    """
    pass
# =========================================================
# GET INTERACTION TRANSFORMATION
# =========================================================
@abstractmethod
def transform_get_interaction_request(
    self,
    interaction_id: str,
    api_base: str,
    litellm_params: GenericLiteLLMParams,
    headers: dict,
) -> Tuple[str, Dict]:
    """
    Transform the get interaction request into URL and query params.

    Per OpenAPI spec: GET /{api_version}/interactions/{interaction_id}

    Args:
        interaction_id: ID of the interaction to fetch.
        api_base: Base URL for the provider API.
        litellm_params: LiteLLM-specific parameters.
        headers: Request headers.

    Returns:
        Tuple of (URL, query_params)
    """
    pass
@abstractmethod
def transform_get_interaction_response(
    self,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
) -> InteractionsAPIResponse:
    """
    Transform the get interaction response.

    Args:
        raw_response: Raw httpx response from the provider.
        logging_obj: LiteLLM logging object for request/response tracking.

    Returns:
        The parsed InteractionsAPIResponse.
    """
    pass
# =========================================================
# DELETE INTERACTION TRANSFORMATION
# =========================================================
@abstractmethod
def transform_delete_interaction_request(
    self,
    interaction_id: str,
    api_base: str,
    litellm_params: GenericLiteLLMParams,
    headers: dict,
) -> Tuple[str, Dict]:
    """
    Transform the delete interaction request into URL and body.

    Per OpenAPI spec: DELETE /{api_version}/interactions/{interaction_id}

    Args:
        interaction_id: ID of the interaction to delete.
        api_base: Base URL for the provider API.
        litellm_params: LiteLLM-specific parameters.
        headers: Request headers.

    Returns:
        Tuple of (URL, request_body)
    """
    pass
@abstractmethod
def transform_delete_interaction_response(
    self,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
    interaction_id: str,
) -> DeleteInteractionResult:
    """
    Transform the delete interaction response.

    Args:
        raw_response: Raw httpx response from the provider.
        logging_obj: LiteLLM logging object for request/response tracking.
        interaction_id: ID of the interaction that was deleted.

    Returns:
        The typed deletion result.
    """
    pass
# =========================================================
# CANCEL INTERACTION TRANSFORMATION
# =========================================================
@abstractmethod
def transform_cancel_interaction_request(
    self,
    interaction_id: str,
    api_base: str,
    litellm_params: GenericLiteLLMParams,
    headers: dict,
) -> Tuple[str, Dict]:
    """
    Transform the cancel interaction request into URL and body.

    Args:
        interaction_id: ID of the interaction to cancel.
        api_base: Base URL for the provider API.
        litellm_params: LiteLLM-specific parameters.
        headers: Request headers.

    Returns:
        Tuple of (URL, request_body)
    """
    pass
@abstractmethod
def transform_cancel_interaction_response(
    self,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
) -> CancelInteractionResult:
    """
    Transform the cancel interaction response.

    Args:
        raw_response: Raw httpx response from the provider.
        logging_obj: LiteLLM logging object for request/response tracking.

    Returns:
        The typed cancellation result.
    """
    pass
# =========================================================
# ERROR HANDLING
# =========================================================
def get_error_class(
    self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
    """
    Build the appropriate exception instance for an HTTP error.

    The contract (per the annotated return type) is to *return* the
    exception — callers are expected to `raise config.get_error_class(...)`.
    The previous implementation raised it directly, contradicting the
    declared return type; it now returns the instance.

    Args:
        error_message: Error message extracted from the provider response.
        status_code: HTTP status code of the failed request.
        headers: Response headers from the failed request.

    Returns:
        A BaseLLMException describing the failure. Subclasses may override
        to return provider-specific exception types.
    """
    # Imported locally to avoid a circular import at module load time.
    from ..chat.transformation import BaseLLMException

    return BaseLLMException(
        status_code=status_code,
        message=error_message,
        headers=headers,
    )
def should_fake_stream(
    self,
    model: Optional[str],
    stream: Optional[bool],
    custom_llm_provider: Optional[str] = None,
) -> bool:
    """
    Whether litellm should simulate streaming for the given model.

    Providers that lack native streaming support override this to return
    True; the default assumes native streaming is available.
    """
    # Default: never fake a stream.
    return False

View File

@@ -0,0 +1,41 @@
"""
Managed Resources Module
This module provides base classes and utilities for managing resources
(files, vector stores, etc.) with target_model_names support.
The BaseManagedResource class provides common functionality for:
- Storing unified resource IDs with model mappings
- Retrieving resources by unified ID
- Deleting resources across multiple models
- Creating resources for multiple models
- Filtering deployments based on model mappings
"""
from .base_managed_resource import BaseManagedResource
from .utils import (
decode_unified_id,
encode_unified_id,
extract_model_id_from_unified_id,
extract_provider_resource_id_from_unified_id,
extract_resource_type_from_unified_id,
extract_target_model_names_from_unified_id,
extract_unified_uuid_from_unified_id,
generate_unified_id_string,
is_base64_encoded_unified_id,
parse_unified_id,
)
__all__ = [
"BaseManagedResource",
"is_base64_encoded_unified_id",
"extract_target_model_names_from_unified_id",
"extract_resource_type_from_unified_id",
"extract_unified_uuid_from_unified_id",
"extract_model_id_from_unified_id",
"extract_provider_resource_id_from_unified_id",
"generate_unified_id_string",
"encode_unified_id",
"decode_unified_id",
"parse_unified_id",
]

View File

@@ -0,0 +1,607 @@
# What is this?
## Base class for managing resources (files, vector stores, etc.) with target_model_names support
## This provides common functionality for creating, retrieving, and managing resources across multiple models
import base64
import json
from abc import ABC, abstractmethod
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generic,
List,
Optional,
TypeVar,
Union,
cast,
)
from litellm import verbose_logger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.utils import SpecialEnums
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache
from litellm.proxy.utils import PrismaClient as _PrismaClient
from litellm.router import Router as _Router
Span = Union[_Span, Any]
InternalUsageCache = _InternalUsageCache
PrismaClient = _PrismaClient
Router = _Router
else:
Span = Any
InternalUsageCache = Any
PrismaClient = Any
Router = Any
# Generic type for resource objects
ResourceObjectType = TypeVar("ResourceObjectType")


class BaseManagedResource(ABC, Generic[ResourceObjectType]):
    """
    Base class for managing resources with target_model_names support.

    This class provides common functionality for:
    - Storing unified resource IDs with model mappings
    - Retrieving resources by unified ID
    - Deleting resources across multiple models
    - Creating resources for multiple models
    - Filtering deployments based on model mappings

    Subclasses should implement:
    - resource_type: str property
    - table_name: str property
    - create_resource_for_model: method to create resource on a specific model
    - get_unified_resource_id_format: method to generate unified ID format
    """

    def __init__(
        self,
        internal_usage_cache: InternalUsageCache,
        prisma_client: PrismaClient,
    ):
        self.internal_usage_cache = internal_usage_cache
        self.prisma_client = prisma_client

    # ============================================================================
    # ABSTRACT METHODS
    # ============================================================================
    @property
    @abstractmethod
    def resource_type(self) -> str:
        """
        Return the resource type identifier (e.g., 'file', 'vector_store',
        'vector_store_file'). Used for logging and unified ID generation.
        """
        pass

    @property
    @abstractmethod
    def table_name(self) -> str:
        """
        Return the database table name for this resource type.

        Example: 'litellm_managedfiletable', 'litellm_managedvectorstoretable'
        """
        pass

    @abstractmethod
    def get_unified_resource_id_format(
        self,
        resource_object: ResourceObjectType,
        target_model_names_list: List[str],
    ) -> str:
        """
        Generate the format string for the unified resource ID.

        This should return a string that will be base64 encoded.

        Example for files:
            "litellm_proxy:application/json;unified_id,{uuid};target_model_names,{models};..."

        Args:
            resource_object: The resource object returned from the provider
            target_model_names_list: List of target model names

        Returns:
            Format string to be base64 encoded
        """
        pass

    @abstractmethod
    async def create_resource_for_model(
        self,
        llm_router: Router,
        model: str,
        request_data: Dict[str, Any],
        litellm_parent_otel_span: Span,
    ) -> ResourceObjectType:
        """
        Create a resource for a specific model.

        Args:
            llm_router: LiteLLM router instance
            model: Model name to create resource for
            request_data: Request data for resource creation
            litellm_parent_otel_span: OpenTelemetry span for tracing

        Returns:
            Resource object from the provider
        """
        pass

    # ============================================================================
    # COMMON STORAGE OPERATIONS
    # ============================================================================
    async def store_unified_resource_id(
        self,
        unified_resource_id: str,
        resource_object: Optional[ResourceObjectType],
        litellm_parent_otel_span: Optional[Span],
        model_mappings: Dict[str, str],
        user_api_key_dict: UserAPIKeyAuth,
        additional_db_fields: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Store unified resource ID with model mappings in cache and database.

        Args:
            unified_resource_id: The unified resource ID (base64 encoded)
            resource_object: The resource object to store (can be None)
            litellm_parent_otel_span: OpenTelemetry span for tracing
            model_mappings: Dictionary mapping model_id -> provider_resource_id
            user_api_key_dict: User API key authentication details
            additional_db_fields: Additional fields to store in database
        """
        verbose_logger.info(
            f"Storing LiteLLM Managed {self.resource_type} with id={unified_resource_id} in cache"
        )
        # Prepare cache data — the raw object is kept as-is in cache
        cache_data: Dict[str, Any] = {
            "unified_resource_id": unified_resource_id,
            "resource_object": resource_object,
            "model_mappings": model_mappings,
            "flat_model_resource_ids": list(model_mappings.values()),
            "created_by": user_api_key_dict.user_id,
            "updated_by": user_api_key_dict.user_id,
        }
        if additional_db_fields:
            cache_data.update(additional_db_fields)
        # Only cache when there is an object worth serving back
        if resource_object is not None:
            await self.internal_usage_cache.async_set_cache(
                key=unified_resource_id,
                value=cache_data,
                litellm_parent_otel_span=litellm_parent_otel_span,
            )
        # Prepare database data — model_mappings serialized as a JSON string
        db_data: Dict[str, Any] = {
            "unified_resource_id": unified_resource_id,
            "model_mappings": json.dumps(model_mappings),
            "flat_model_resource_ids": list(model_mappings.values()),
            "created_by": user_api_key_dict.user_id,
            "updated_by": user_api_key_dict.user_id,
        }
        if resource_object is not None:
            # Handle both Pydantic models and plain dicts
            if hasattr(resource_object, "model_dump_json"):
                db_data["resource_object"] = resource_object.model_dump_json()  # type: ignore
            elif isinstance(resource_object, dict):
                db_data["resource_object"] = json.dumps(resource_object)
            # Extract storage metadata from hidden params if present
            hidden_params = getattr(resource_object, "_hidden_params", {}) or {}
            if "storage_backend" in hidden_params:
                db_data["storage_backend"] = hidden_params["storage_backend"]
            if "storage_url" in hidden_params:
                db_data["storage_url"] = hidden_params["storage_url"]
        if additional_db_fields:
            db_data.update(additional_db_fields)
        # Store in database
        table = getattr(self.prisma_client.db, self.table_name)
        result = await table.create(data=db_data)
        verbose_logger.debug(
            f"LiteLLM Managed {self.resource_type} with id={unified_resource_id} stored in db: {result}"
        )

    async def get_unified_resource_id(
        self,
        unified_resource_id: str,
        litellm_parent_otel_span: Optional[Span] = None,
    ) -> Optional[Dict[str, Any]]:
        """
        Retrieve unified resource by ID from cache or database.

        Args:
            unified_resource_id: The unified resource ID to retrieve
            litellm_parent_otel_span: OpenTelemetry span for tracing

        Returns:
            Dictionary containing resource data or None if not found
        """
        # Check cache first
        result = cast(
            Optional[dict],
            await self.internal_usage_cache.async_get_cache(
                key=unified_resource_id,
                litellm_parent_otel_span=litellm_parent_otel_span,
            ),
        )
        if result:
            return result
        # Fall back to the database
        table = getattr(self.prisma_client.db, self.table_name)
        db_object = await table.find_first(
            where={"unified_resource_id": unified_resource_id}
        )
        if db_object:
            return db_object.model_dump()
        return None

    async def delete_unified_resource_id(
        self,
        unified_resource_id: str,
        litellm_parent_otel_span: Optional[Span] = None,
    ) -> Optional[ResourceObjectType]:
        """
        Delete unified resource from cache and database.

        Args:
            unified_resource_id: The unified resource ID to delete
            litellm_parent_otel_span: OpenTelemetry span for tracing

        Returns:
            The deleted resource object, as stored in the database

        Raises:
            Exception: If no resource exists with the given ID
        """
        # Fetch the current value so it can be returned after deletion
        table = getattr(self.prisma_client.db, self.table_name)
        initial_value = await table.find_first(
            where={"unified_resource_id": unified_resource_id}
        )
        if initial_value is None:
            raise Exception(
                f"LiteLLM Managed {self.resource_type} with id={unified_resource_id} not found"
            )
        # Invalidate the cache entry by overwriting it with None
        await self.internal_usage_cache.async_set_cache(
            key=unified_resource_id,
            value=None,
            litellm_parent_otel_span=litellm_parent_otel_span,
        )
        # Delete from database
        await table.delete(where={"unified_resource_id": unified_resource_id})
        return initial_value.resource_object

    async def can_user_access_unified_resource_id(
        self,
        unified_resource_id: str,
        user_api_key_dict: UserAPIKeyAuth,
        litellm_parent_otel_span: Optional[Span] = None,
    ) -> bool:
        """
        Check if user has access to the unified resource ID.

        Uses get_unified_resource_id() which checks cache first before hitting
        the database, avoiding direct DB queries in the critical request path.

        Args:
            unified_resource_id: The unified resource ID to check
            user_api_key_dict: User API key authentication details
            litellm_parent_otel_span: OpenTelemetry span for tracing

        Returns:
            True if the resource exists and was created by this user
        """
        user_id = user_api_key_dict.user_id
        # Use cached method instead of direct DB query
        resource = await self.get_unified_resource_id(
            unified_resource_id, litellm_parent_otel_span
        )
        if resource:
            return resource.get("created_by") == user_id
        return False

    # ============================================================================
    # MODEL MAPPING OPERATIONS
    # ============================================================================
    async def get_model_resource_id_mapping(
        self,
        resource_ids: List[str],
        litellm_parent_otel_span: Span,
    ) -> Dict[str, Dict[str, str]]:
        """
        Get model-specific resource IDs for a list of unified resource IDs.

        Args:
            resource_ids: List of unified resource IDs
            litellm_parent_otel_span: OpenTelemetry span for tracing

        Returns:
            Dictionary mapping unified_resource_id -> model_id -> provider_resource_id

        Example:
            {
                "unified_resource_id_1": {
                    "model_id_1": "provider_resource_id_1",
                    "model_id_2": "provider_resource_id_2"
                }
            }
        """
        resource_id_mapping: Dict[str, Dict[str, str]] = {}
        for resource_id in resource_ids:
            unified_resource_object = await self.get_unified_resource_id(
                resource_id, litellm_parent_otel_span
            )
            if unified_resource_object:
                model_mappings = unified_resource_object.get("model_mappings", {})
                # DB rows store the mapping as a JSON string; cache stores a dict
                if isinstance(model_mappings, str):
                    model_mappings = json.loads(model_mappings)
                resource_id_mapping[resource_id] = model_mappings
        return resource_id_mapping

    # ============================================================================
    # RESOURCE CREATION OPERATIONS
    # ============================================================================
    async def create_resource_for_each_model(
        self,
        llm_router: Router,
        request_data: Dict[str, Any],
        target_model_names_list: List[str],
        litellm_parent_otel_span: Span,
    ) -> List[ResourceObjectType]:
        """
        Create a resource for each model in the target list.

        Args:
            llm_router: LiteLLM router instance
            request_data: Request data for resource creation
            target_model_names_list: List of target model names
            litellm_parent_otel_span: OpenTelemetry span for tracing

        Returns:
            List of resource objects created for each model

        Raises:
            Exception: If the router has not been initialized
        """
        if llm_router is None:
            raise Exception("LLM Router not initialized. Ensure models added to proxy.")
        responses = []
        for model in target_model_names_list:
            individual_response = await self.create_resource_for_model(
                llm_router=llm_router,
                model=model,
                request_data=request_data,
                litellm_parent_otel_span=litellm_parent_otel_span,
            )
            responses.append(individual_response)
        return responses

    def generate_unified_resource_id(
        self,
        resource_objects: List[ResourceObjectType],
        target_model_names_list: List[str],
    ) -> str:
        """
        Generate a unified resource ID from multiple resource objects.

        Args:
            resource_objects: List of resource objects from different models
            target_model_names_list: List of target model names

        Returns:
            Base64 encoded unified resource ID
        """
        # Use the first resource object to generate the format
        unified_id_format = self.get_unified_resource_id_format(
            resource_object=resource_objects[0],
            target_model_names_list=target_model_names_list,
        )
        # Convert to URL-safe base64 and strip padding
        base64_unified_id = (
            base64.urlsafe_b64encode(unified_id_format.encode()).decode().rstrip("=")
        )
        return base64_unified_id

    def extract_model_mappings_from_responses(
        self,
        resource_objects: List[ResourceObjectType],
    ) -> Dict[str, str]:
        """
        Extract model mappings from resource objects.

        Args:
            resource_objects: List of resource objects from different models

        Returns:
            Dictionary mapping model_id -> provider_resource_id
        """
        model_mappings: Dict[str, str] = {}
        for resource_object in resource_objects:
            # Providers attach the per-model mapping via hidden params
            hidden_params = getattr(resource_object, "_hidden_params", {}) or {}
            model_resource_id_mapping = hidden_params.get("model_resource_id_mapping")
            if model_resource_id_mapping and isinstance(
                model_resource_id_mapping, dict
            ):
                model_mappings.update(model_resource_id_mapping)
        return model_mappings

    # ============================================================================
    # DEPLOYMENT FILTERING
    # ============================================================================
    async def async_filter_deployments(
        self,
        model: str,
        healthy_deployments: List,
        request_kwargs: Optional[Dict] = None,
        parent_otel_span: Optional[Span] = None,
        resource_id_key: str = "resource_id",
    ) -> List[Dict]:
        """
        Filter deployments based on model mappings for a resource.

        This is used by the router to select only deployments that have
        the resource available.

        Args:
            model: Model name
            healthy_deployments: List of healthy deployments
            request_kwargs: Request kwargs containing resource_id and mappings
            parent_otel_span: OpenTelemetry span for tracing
            resource_id_key: Key to use for resource ID in request_kwargs

        Returns:
            Filtered list of deployments; unfiltered when no mapping applies
        """
        if request_kwargs is None:
            return healthy_deployments
        resource_id = cast(Optional[str], request_kwargs.get(resource_id_key))
        model_resource_id_mapping = cast(
            Optional[Dict[str, Dict[str, str]]],
            request_kwargs.get("model_resource_id_mapping"),
        )
        allowed_model_ids: List[str] = []
        if resource_id and model_resource_id_mapping:
            model_id_dict = model_resource_id_mapping.get(resource_id, {})
            allowed_model_ids = list(model_id_dict.keys())
        # No mapping constraint -> leave the deployment list untouched
        if len(allowed_model_ids) == 0:
            return healthy_deployments
        return [
            deployment
            for deployment in healthy_deployments
            if deployment.get("model_info", {}).get("id") in allowed_model_ids
        ]

    # ============================================================================
    # UTILITY METHODS
    # ============================================================================
    def get_unified_id_prefix(self) -> str:
        """
        Get the prefix for unified IDs for this resource type.

        Returns:
            Prefix string (e.g., "litellm_proxy:")
        """
        return SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value

    @staticmethod
    def _resource_object_unified_id(resource_data: Any) -> Optional[str]:
        """Read the unified id off a parsed resource object (dict or model)."""
        if isinstance(resource_data, dict):
            return resource_data.get("id")
        return getattr(resource_data, "id", None)

    async def list_user_resources(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        limit: Optional[int] = None,
        after: Optional[str] = None,
        additional_filters: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """
        List resources created by a user.

        Args:
            user_api_key_dict: User API key authentication details
            limit: Maximum number of resources to return (defaults to 20)
            after: Cursor for pagination (returns records with id > after)
            additional_filters: Additional filters to apply

        Returns:
            Dictionary with list of resources and pagination info
        """
        where_clause: Dict[str, Any] = {}
        # Filter by user who created the resource
        if user_api_key_dict.user_id:
            where_clause["created_by"] = user_api_key_dict.user_id
        if after:
            where_clause["id"] = {"gt": after}
        if additional_filters:
            where_clause.update(additional_filters)

        # Single source of truth for the page size (was computed three times)
        max_items = limit or 20
        table = getattr(self.prisma_client.db, self.table_name)
        resources = await table.find_many(
            where=where_clause,
            take=max_items,
            order={"created_at": "desc"},
        )

        resource_objects: List[Any] = []
        for resource in resources:
            try:
                # Stop once we have enough
                if len(resource_objects) >= max_items:
                    break
                # Parse resource object (stored as JSON string or dict)
                resource_data = resource.resource_object
                if isinstance(resource_data, str):
                    resource_data = json.loads(resource_data)
                # Expose the unified ID instead of the provider-specific one
                if hasattr(resource_data, "id"):
                    resource_data.id = resource.unified_resource_id
                elif isinstance(resource_data, dict):
                    resource_data["id"] = resource.unified_resource_id
                resource_objects.append(resource_data)
            except Exception as e:
                # Skip unparseable rows rather than failing the whole listing
                verbose_logger.warning(
                    f"Failed to parse {self.resource_type} object "
                    f"{resource.unified_resource_id}: {e}"
                )
                continue

        # BUGFIX: objects parsed from JSON are plain dicts, so attribute
        # access like `resource_objects[0].id` raised AttributeError whenever
        # any results existed; use a dict/object-safe accessor instead.
        first_id = (
            self._resource_object_unified_id(resource_objects[0])
            if resource_objects
            else None
        )
        last_id = (
            self._resource_object_unified_id(resource_objects[-1])
            if resource_objects
            else None
        )
        return {
            "object": "list",
            "data": resource_objects,
            "first_id": first_id,
            "last_id": last_id,
            "has_more": len(resource_objects) == max_items,
        }

View File

@@ -0,0 +1,364 @@
"""
Utility functions for managed resources.
This module provides common utility functions that can be used across
different managed resource types (files, vector stores, etc.).
"""
import base64
import re
from typing import List, Optional, Union, Literal
def is_base64_encoded_unified_id(
    resource_id: str,
    prefix: str = "litellm_proxy:",
) -> Union[str, Literal[False]]:
    """
    Check whether a resource ID is a base64 encoded unified ID.

    Args:
        resource_id: The resource ID to check
        prefix: The expected prefix for unified IDs

    Returns:
        The decoded string when it is a valid unified ID, False otherwise
    """
    # Non-string inputs can never be a unified ID
    if not isinstance(resource_id, str):
        return False
    # urlsafe_b64decode requires len % 4 == 0; restore stripped '=' padding
    candidate = resource_id + "=" * (-len(resource_id) % 4)
    try:
        decoded_text = base64.urlsafe_b64decode(candidate).decode()
    except Exception:
        return False
    return decoded_text if decoded_text.startswith(prefix) else False
def extract_target_model_names_from_unified_id(
    unified_id: str,
) -> List[str]:
    """
    Extract target model names from a unified resource ID.

    Args:
        unified_id: The unified resource ID (decoded or encoded)

    Returns:
        List of target model names (empty when none are present)

    Example:
        unified_id = "litellm_proxy:vector_store;unified_id,uuid;target_model_names,gpt-4,gemini-2.0"
        returns: ["gpt-4", "gemini-2.0"]
    """
    try:
        if not isinstance(unified_id, str):
            return []
        # Decode if the caller passed the base64 encoded form
        decoded = is_base64_encoded_unified_id(unified_id)
        if decoded:
            unified_id = decoded
        found = re.search(r"target_model_names,([^;]+)", unified_id)
        if not found:
            return []
        # Split on comma and strip whitespace from each model name
        return [name.strip() for name in found.group(1).split(",")]
    except Exception:
        return []
def extract_resource_type_from_unified_id(
    unified_id: str,
) -> Optional[str]:
    """
    Extract resource type from a unified resource ID.

    Args:
        unified_id: The unified resource ID (decoded or encoded)

    Returns:
        Resource type string or None

    Example:
        unified_id = "litellm_proxy:vector_store;unified_id,uuid;..."
        returns: "vector_store"
    """
    try:
        if not isinstance(unified_id, str):
            return None
        # Decode if the caller passed the base64 encoded form
        decoded = is_base64_encoded_unified_id(unified_id)
        working_id = decoded if decoded else unified_id
        # Resource type sits between the prefix and the first semicolon
        found = re.search(r"litellm_proxy:([^;]+)", working_id)
        return found.group(1).strip() if found else None
    except Exception:
        return None
def extract_unified_uuid_from_unified_id(
    unified_id: str,
) -> Optional[str]:
    """
    Extract the UUID from a unified resource ID.

    Args:
        unified_id: The unified resource ID (decoded or encoded)

    Returns:
        UUID string or None

    Example:
        unified_id = "litellm_proxy:vector_store;unified_id,abc-123;..."
        returns: "abc-123"
    """
    try:
        if not isinstance(unified_id, str):
            return None
        # Decode if the caller passed the base64 encoded form
        decoded = is_base64_encoded_unified_id(unified_id)
        working_id = decoded if decoded else unified_id
        found = re.search(r"unified_id,([^;]+)", working_id)
        return found.group(1).strip() if found else None
    except Exception:
        return None
def extract_model_id_from_unified_id(
    unified_id: str,
) -> Optional[str]:
    """
    Extract model ID from a unified resource ID.

    Args:
        unified_id: The unified resource ID (decoded or encoded)

    Returns:
        Model ID string or None

    Example:
        unified_id = "litellm_proxy:vector_store;...;model_id,gpt-4-model-id;..."
        returns: "gpt-4-model-id"
    """
    try:
        if not isinstance(unified_id, str):
            return None
        # Decode if the caller passed the base64 encoded form
        decoded = is_base64_encoded_unified_id(unified_id)
        working_id = decoded if decoded else unified_id
        found = re.search(r"model_id,([^;]+)", working_id)
        return found.group(1).strip() if found else None
    except Exception:
        return None
def extract_provider_resource_id_from_unified_id(
    unified_id: str,
) -> Optional[str]:
    """
    Extract provider resource ID from a unified resource ID.

    Args:
        unified_id: The unified resource ID (decoded or encoded)

    Returns:
        Provider resource ID string or None

    Example:
        unified_id = "litellm_proxy:vector_store;...;resource_id,vs_abc123;..."
        returns: "vs_abc123"
    """
    try:
        if not isinstance(unified_id, str):
            return None
        # Decode if the caller passed the base64 encoded form
        decoded = is_base64_encoded_unified_id(unified_id)
        if decoded:
            unified_id = decoded
        # Different resource types use different key names for the same field
        for pattern in (
            r"resource_id,([^;]+)",
            r"vector_store_id,([^;]+)",
            r"file_id,([^;]+)",
        ):
            found = re.search(pattern, unified_id)
            if found:
                return found.group(1).strip()
        return None
    except Exception:
        return None
def generate_unified_id_string(
    resource_type: str,
    unified_uuid: str,
    target_model_names: List[str],
    provider_resource_id: str,
    model_id: str,
    additional_fields: Optional[dict] = None,
) -> str:
    """
    Generate a unified ID string (before base64 encoding).

    Args:
        resource_type: Type of resource (e.g., "vector_store", "file")
        unified_uuid: UUID for this unified resource
        target_model_names: List of target model names
        provider_resource_id: Resource ID from the provider
        model_id: Model ID from the router
        additional_fields: Additional fields to include in the ID

    Returns:
        Unified ID string (not yet base64 encoded)

    Example:
        generate_unified_id_string(
            resource_type="vector_store",
            unified_uuid="abc-123",
            target_model_names=["gpt-4", "gemini"],
            provider_resource_id="vs_xyz",
            model_id="model-id-123",
        )
        returns: "litellm_proxy:vector_store;unified_id,abc-123;target_model_names,gpt-4,gemini;resource_id,vs_xyz;model_id,model-id-123"
    """
    models_csv = ",".join(target_model_names)
    segments = [
        f"litellm_proxy:{resource_type}",
        f"unified_id,{unified_uuid}",
        f"target_model_names,{models_csv}",
        f"resource_id,{provider_resource_id}",
        f"model_id,{model_id}",
    ]
    # Append any caller-supplied key/value pairs in the same key,value form
    for key, value in (additional_fields or {}).items():
        segments.append(f"{key},{value}")
    return ";".join(segments)
def encode_unified_id(unified_id_string: str) -> str:
    """
    Encode a unified ID string to base64.

    Args:
        unified_id_string: The unified ID string to encode

    Returns:
        Base64 encoded unified ID (URL-safe, '=' padding stripped)
    """
    raw_bytes = unified_id_string.encode()
    encoded = base64.urlsafe_b64encode(raw_bytes).decode()
    # Padding is stripped so the ID is cleaner; decoders restore it
    return encoded.rstrip("=")
def decode_unified_id(encoded_unified_id: str) -> Optional[str]:
    """
    Decode a base64 encoded unified ID.

    Args:
        encoded_unified_id: The base64 encoded unified ID

    Returns:
        Decoded unified ID string, or None if it is not a valid unified ID
    """
    # Restore '=' padding stripped by encode_unified_id
    padding = "=" * (-len(encoded_unified_id) % 4)
    try:
        decoded = base64.urlsafe_b64decode(encoded_unified_id + padding).decode()
    except Exception:
        return None
    # Only accept strings carrying the litellm unified-id prefix
    if not decoded.startswith("litellm_proxy:"):
        return None
    return decoded
def parse_unified_id(
    unified_id: str,
) -> Optional[dict]:
    """
    Parse a unified ID into its components.

    Args:
        unified_id: The unified ID (encoded or decoded)

    Returns:
        Dictionary with parsed components or None if invalid

    Example:
        {
            "resource_type": "vector_store",
            "unified_uuid": "abc-123",
            "target_model_names": ["gpt-4", "gemini"],
            "provider_resource_id": "vs_xyz",
            "model_id": "model-id-123"
        }
    """
    try:
        decoded = decode_unified_id(unified_id)
        if not decoded:
            # The caller may have passed the already-decoded form
            if not unified_id.startswith("litellm_proxy:"):
                return None
            decoded = unified_id
        return {
            "resource_type": extract_resource_type_from_unified_id(decoded),
            "unified_uuid": extract_unified_uuid_from_unified_id(decoded),
            "target_model_names": extract_target_model_names_from_unified_id(decoded),
            "provider_resource_id": extract_provider_resource_id_from_unified_id(
                decoded
            ),
            "model_id": extract_model_id_from_unified_id(decoded),
        }
    except Exception:
        return None

View File

@@ -0,0 +1,22 @@
"""Base OCR transformation module."""
from .transformation import (
BaseOCRConfig,
DocumentType,
OCRPage,
OCRPageDimensions,
OCRPageImage,
OCRRequestData,
OCRResponse,
OCRUsageInfo,
)
__all__ = [
"BaseOCRConfig",
"DocumentType",
"OCRResponse",
"OCRPage",
"OCRPageDimensions",
"OCRPageImage",
"OCRUsageInfo",
"OCRRequestData",
]

View File

@@ -0,0 +1,258 @@
"""
Base OCR transformation configuration.
"""
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import httpx
from pydantic import PrivateAttr
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.llms.base import LiteLLMPydanticObjectBase
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
# DocumentType for OCR - providers always receive a dict with
# type="document_url" or type="image_url" (str values only).
# File-type inputs are preprocessed to this format in litellm/ocr/main.py,
# so provider transformations never see raw file paths/bytes here.
DocumentType = Dict[str, str]
class OCRPageDimensions(LiteLLMPydanticObjectBase):
    """Page dimensions from OCR response."""

    # All fields optional: providers may omit dimension metadata.
    dpi: Optional[int] = None  # resolution in dots-per-inch, if reported
    height: Optional[int] = None  # page height (units provider-defined — TODO confirm)
    width: Optional[int] = None  # page width (units provider-defined — TODO confirm)
class OCRPageImage(LiteLLMPydanticObjectBase):
    """Image extracted from OCR page."""

    image_base64: Optional[str] = None  # base64-encoded image payload
    bbox: Optional[Dict[str, Any]] = None  # bounding box of the image on the page
    # Let provider-specific extra fields pass through unchanged.
    model_config = {"extra": "allow"}
class OCRPage(LiteLLMPydanticObjectBase):
    """Single page from OCR response."""

    index: int  # position of the page within the document — presumably 0-based; confirm per provider
    markdown: str  # extracted page content, rendered as markdown
    images: Optional[List[OCRPageImage]] = None  # images found on this page
    dimensions: Optional[OCRPageDimensions] = None  # page size metadata, if reported
    # Let provider-specific extra fields pass through unchanged.
    model_config = {"extra": "allow"}
class OCRUsageInfo(LiteLLMPydanticObjectBase):
    """Usage information from OCR response."""

    pages_processed: Optional[int] = None  # number of pages the provider processed
    doc_size_bytes: Optional[int] = None  # size of the input document in bytes
    # Let provider-specific extra usage fields pass through unchanged.
    model_config = {"extra": "allow"}
class OCRResponse(LiteLLMPydanticObjectBase):
    """
    Standard OCR response format.

    Standardized to Mistral OCR format - other providers should transform
    their raw responses into this shape.
    """

    pages: List[OCRPage]  # one entry per processed page
    model: str  # model that produced the OCR output
    document_annotation: Optional[Any] = None  # optional document-level annotation payload
    usage_info: Optional[OCRUsageInfo] = None  # provider-reported usage metadata
    object: str = "ocr"  # response object tag
    # Let provider-specific extra fields pass through unchanged.
    model_config = {"extra": "allow"}
    # Private attribute (excluded from serialization) used for
    # litellm-internal bookkeeping attached to the response.
    _hidden_params: dict = PrivateAttr(default_factory=dict)
class OCRRequestData(LiteLLMPydanticObjectBase):
    """OCR request data structure."""

    # Request payload: either a JSON-serializable dict or raw bytes.
    data: Optional[Union[Dict, bytes]] = None
    # File attachments for multipart uploads — presumably keyed by form
    # field name; confirm against provider handlers.
    files: Optional[Dict[str, Any]] = None
class BaseOCRConfig:
    """
    Provider-agnostic base class for OCR request/response transformations.

    Concrete providers subclass this and override the ``transform_*`` and
    ``get_complete_url`` hooks. The async variants fall back to their sync
    counterparts by default.
    """

    def __init__(self) -> None:
        pass

    def get_supported_ocr_params(self, model: str) -> list:
        """
        Return the OCR parameter names this provider supports.

        The base implementation supports none; providers override as needed.
        """
        return []

    def map_ocr_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
    ) -> dict:
        """
        Map unified OCR params onto provider-specific ones.

        Base behavior is a no-op passthrough of ``optional_params``.
        """
        return optional_params

    def validate_environment(
        self,
        headers: Dict,
        model: str,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        litellm_params: Optional[dict] = None,
        **kwargs,
    ) -> Dict:
        """
        Validate credentials/environment and return the headers to send.

        Base behavior returns ``headers`` unchanged; providers typically add
        auth headers here.
        """
        return headers

    def get_complete_url(
        self,
        api_base: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: Optional[dict] = None,
        **kwargs,
    ) -> str:
        """
        Return the full URL of the provider's OCR endpoint.

        Must be overridden by each provider.
        """
        raise NotImplementedError("get_complete_url must be implemented by provider")

    def transform_ocr_request(
        self,
        model: str,
        document: DocumentType,
        optional_params: dict,
        headers: dict,
        **kwargs,
    ) -> OCRRequestData:
        """
        Build the provider-specific request payload for an OCR call.

        By the time this runs, file-type documents have already been converted
        to document_url/image_url dicts with base64 data URIs by the
        preprocessing in litellm/ocr/main.py.

        Args:
            model: Model name.
            document: Always a dict with type="document_url" or type="image_url".
            optional_params: Optional request parameters.
            headers: Request headers.

        Returns:
            OCRRequestData carrying ``data`` and/or ``files``.
        """
        raise NotImplementedError(
            "transform_ocr_request must be implemented by provider"
        )

    async def async_transform_ocr_request(
        self,
        model: str,
        document: DocumentType,
        optional_params: dict,
        headers: dict,
        **kwargs,
    ) -> OCRRequestData:
        """
        Async variant of ``transform_ocr_request``.

        Providers that need awaitable work (e.g. Azure AI's URL-to-base64
        conversion) override this; the default delegates to the sync version.
        """
        return self.transform_ocr_request(
            model=model,
            document=document,
            optional_params=optional_params,
            headers=headers,
            **kwargs,
        )

    def transform_ocr_response(
        self,
        model: str,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        **kwargs,
    ) -> OCRResponse:
        """
        Convert the provider's raw HTTP response into the standard OCRResponse.

        Must be overridden by each provider.
        """
        raise NotImplementedError(
            "transform_ocr_response must be implemented by provider"
        )

    async def async_transform_ocr_response(
        self,
        model: str,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        **kwargs,
    ) -> OCRResponse:
        """
        Async variant of ``transform_ocr_response``.

        Providers needing async post-processing (e.g. Azure Document
        Intelligence operation polling) override this; the default delegates
        to the sync version.
        """
        return self.transform_ocr_response(
            model=model,
            raw_response=raw_response,
            logging_obj=logging_obj,
            **kwargs,
        )

    def get_error_class(
        self,
        error_message: str,
        status_code: int,
        headers: dict,
    ) -> Exception:
        """Build the exception instance to raise for a provider error."""
        return BaseLLMException(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )

View File

@@ -0,0 +1,139 @@
from abc import abstractmethod
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
from ..base_utils import BaseLLMModelInfo
if TYPE_CHECKING:
from httpx import URL, Headers, Response
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.types.utils import CostResponseTypes
from ..chat.transformation import BaseLLMException
class BasePassthroughConfig(BaseLLMModelInfo):
    @abstractmethod
    def is_streaming_request(self, endpoint: str, request_data: dict) -> bool:
        """Return True when the incoming passthrough request expects a stream."""
        pass

    def format_url(
        self,
        endpoint: str,
        base_target_url: str,
        request_query_params: Optional[dict],
    ) -> "URL":
        """
        Join ``base_target_url`` and ``endpoint``, then attach query params.

        Args:
            endpoint: Path to append to the base url.
            base_target_url: Base url the endpoint is appended to.
            request_query_params: Optional query params to encode into the url.

        Returns:
            httpx.URL: the fully formatted url.
        """
        from urllib.parse import urlencode

        import httpx

        joined = "{}/{}".format(base_target_url.rstrip("/"), endpoint.lstrip("/"))
        result = httpx.URL(joined)
        if not request_query_params:
            return result
        return result.copy_with(query=urlencode(request_query_params).encode("ascii"))

    @abstractmethod
    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        endpoint: str,
        request_query_params: Optional[dict],
        litellm_params: dict,
    ) -> Tuple["URL", str]:
        """
        Build the complete url for the request.

        Returns:
            - complete_url: URL - the full url to call.
            - base_target_url: str - the base url the endpoint was appended to;
              useful for auth headers.
        """
        pass

    def sign_request(
        self,
        headers: dict,
        litellm_params: dict,
        request_data: Optional[dict],
        api_base: str,
        model: Optional[str] = None,
    ) -> Tuple[dict, Optional[bytes]]:
        """
        Hook for providers (e.g. Bedrock) that must sign outgoing requests.

        Signing needs the request body (``request_data``) and the final url
        (``api_base``). Implementations update and return the headers (and
        optionally a re-encoded body); the returned headers are what gets sent
        on the wire.

        The base implementation performs no signing.
        """
        return headers, None

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, "Headers"]
    ) -> "BaseLLMException":
        """Map an upstream error to the exception litellm should raise."""
        from litellm.llms.base_llm.chat.transformation import BaseLLMException

        return BaseLLMException(
            status_code=status_code, message=error_message, headers=headers
        )

    def logging_non_streaming_response(
        self,
        model: str,
        custom_llm_provider: str,
        httpx_response: "Response",
        request_data: dict,
        logging_obj: "LiteLLMLoggingObj",
        endpoint: str,
    ) -> Optional["CostResponseTypes"]:
        """Optional hook to log a completed non-streaming passthrough response."""
        pass

    def handle_logging_collected_chunks(
        self,
        all_chunks: List[str],
        litellm_logging_obj: "LiteLLMLoggingObj",
        model: str,
        custom_llm_provider: str,
        endpoint: str,
    ) -> Optional["CostResponseTypes"]:
        """Optional hook to log a completed streaming passthrough response."""
        return None

    def _convert_raw_bytes_to_str_lines(self, raw_bytes: List[bytes]) -> List[str]:
        """
        Decode collected byte chunks into non-empty, stripped text lines.

        Mirrors what ``aiter_lines()`` would have produced for the same
        stream, so each element is a complete ``data: {}`` chunk.

        Args:
            raw_bytes: Byte chunks collected from ``aiter_bytes()``.

        Returns:
            One string per non-blank line of the decoded payload.
        """
        decoded = b"".join(raw_bytes).decode("utf-8")
        return [segment.strip() for segment in decoded.split("\n") if segment.strip()]

View File

@@ -0,0 +1,117 @@
"""
Base transformation class for realtime HTTP endpoints (client_secrets, realtime_calls).
These are HTTP (not WebSocket) endpoints used by the WebRTC flow:
POST /v1/realtime/client_secrets — obtains a short-lived ephemeral key
POST /v1/realtime/calls — exchanges an SDP offer using that key
"""
from abc import ABC, abstractmethod
from typing import Optional, Union
import httpx
class BaseRealtimeHTTPConfig(ABC):
    """
    Abstract base for provider-specific realtime HTTP credential / URL logic.

    Covers the two HTTP (non-WebSocket) endpoints of the WebRTC flow:
    POST /v1/realtime/client_secrets (mint a short-lived ephemeral key) and
    POST /v1/realtime/calls (SDP exchange using that key).
    Implement one subclass per provider (OpenAI, Azure, ...).
    """

    # -- Credential resolution ------------------------------------------ #

    @abstractmethod
    def get_api_base(
        self,
        api_base: Optional[str],
        **kwargs,
    ) -> str:
        """
        Resolve the provider API base URL.

        Resolution order (provider-specific):
            explicit api_base -> litellm.api_base -> env var -> hard-coded default
        """

    @abstractmethod
    def get_api_key(
        self,
        api_key: Optional[str],
        **kwargs,
    ) -> str:
        """
        Resolve the provider API key.

        Resolution order (provider-specific):
            explicit api_key -> litellm.api_key -> env var -> ""
        """

    # -- client_secrets endpoint ----------------------------------------- #

    @abstractmethod
    def get_complete_url(
        self, api_base: Optional[str], model: str, api_version: Optional[str] = None
    ) -> str:
        """Return the full URL for POST /realtime/client_secrets."""

    @abstractmethod
    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
    ) -> dict:
        """
        Build the request headers for the client_secrets call by merging the
        caller-supplied ``headers`` with the auth / content-type headers this
        provider requires.
        """

    # -- realtime_calls endpoint ----------------------------------------- #

    def get_realtime_calls_url(
        self, api_base: Optional[str], model: str, api_version: Optional[str] = None
    ) -> str:
        """Return the full URL for POST /realtime/calls (SDP exchange)."""
        trimmed = (api_base or "").rstrip("/")
        return f"{trimmed}/v1/realtime/calls"

    def get_realtime_calls_headers(self, ephemeral_key: str) -> dict:
        """
        Build headers for the realtime_calls POST.

        The Bearer token here is the ephemeral key obtained from
        client_secrets, not the long-lived provider key.
        """
        return {
            "Authorization": f"Bearer {ephemeral_key}",
        }

    # -- Error handling --------------------------------------------------- #

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ):
        """
        Map HTTP errors to LiteLLM exception types.

        Default: generic BaseLLMException. Override in subclasses for
        provider-specific error mapping (e.g., Azure uses different error codes).
        """
        from litellm.llms.base_llm.chat.transformation import BaseLLMException

        return BaseLLMException(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )

View File

@@ -0,0 +1,83 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, List, Optional, Union
import httpx
from litellm.types.realtime import (
RealtimeResponseTransformInput,
RealtimeResponseTypedDict,
)
from ..chat.transformation import BaseLLMException
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class BaseRealtimeConfig(ABC):
    """
    Abstract base for provider-specific realtime (WebSocket) transformations.

    Implementations should stay stateless; the caller owns session state
    (current ids, delta chunks, ...).
    """

    @abstractmethod
    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
    ) -> dict:
        """Build and return the headers used to open the realtime connection."""
        pass

    @abstractmethod
    def get_complete_url(
        self, api_base: Optional[str], model: str, api_key: Optional[str] = None
    ) -> str:
        """
        OPTIONAL

        Get the complete url for the request.

        Some providers need `model` in `api_base`.
        """
        return api_base or ""

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """Build the exception instance for an upstream error; callers raise it."""
        # Fix: previously this `raise`d the exception instead of returning it,
        # contradicting the declared return type and the sibling base configs
        # (e.g. BaseRealtimeHTTPConfig / passthrough), and breaking any caller
        # that assigns the result before raising.
        return BaseLLMException(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )

    @abstractmethod
    def transform_realtime_request(
        self,
        message: str,
        model: str,
        session_configuration_request: Optional[str] = None,
    ) -> List[str]:
        """Transform an outbound client message into provider wire messages."""
        pass

    def requires_session_configuration(
        self,
    ) -> bool:
        """True when an initial configuration message must be sent to set up the realtime session."""
        return False

    def session_configuration_request(
        self, model: str
    ) -> Optional[str]:
        """Message sent to set up the realtime session, if the provider needs one."""
        return None

    @abstractmethod
    def transform_realtime_response(
        self,
        message: Union[str, bytes],
        model: str,
        logging_obj: LiteLLMLoggingObj,
        realtime_response_transform_input: RealtimeResponseTransformInput,
    ) -> RealtimeResponseTypedDict:
        """
        Transform an inbound provider message into the standard response shape.

        Keep this stateless - leave state management (e.g. tracking
        current_output_item_id, current_response_id, current_conversation_id,
        current_delta_chunks) to the caller.
        """
        pass

View File

@@ -0,0 +1,134 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import httpx
from litellm.types.rerank import RerankBilledUnits, RerankResponse
from litellm.types.utils import ModelInfo
from ..chat.transformation import BaseLLMException
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class BaseRerankConfig(ABC):
    """
    Abstract base for provider-specific rerank transformations.

    Subclasses map the Cohere-style rerank interface onto a provider's
    request/response formats.
    """

    @abstractmethod
    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
        optional_params: Optional[dict] = None,
    ) -> dict:
        """Build and return the request headers (auth, content-type, ...)."""
        pass

    @abstractmethod
    def transform_rerank_request(
        self,
        model: str,
        optional_rerank_params: Dict,
        headers: dict,
    ) -> dict:
        """Transform unified rerank params into the provider request body."""
        return {}

    @abstractmethod
    def transform_rerank_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: RerankResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        request_data: dict = {},  # noqa: B006 - mutable defaults kept for subclass signature compatibility
        optional_params: dict = {},  # noqa: B006
        litellm_params: dict = {},  # noqa: B006
    ) -> RerankResponse:
        """Populate ``model_response`` from the provider's raw HTTP response."""
        return model_response

    @abstractmethod
    def get_complete_url(
        self,
        api_base: Optional[str],
        model: str,
        optional_params: Optional[dict] = None,
    ) -> str:
        """
        OPTIONAL

        Get the complete url for the request.

        Some providers need `model` in `api_base`.
        """
        return api_base or ""

    @abstractmethod
    def get_supported_cohere_rerank_params(self, model: str) -> list:
        """Return the Cohere rerank param names this provider supports."""
        pass

    @abstractmethod
    def map_cohere_rerank_params(
        self,
        non_default_params: dict,
        model: str,
        drop_params: bool,
        query: str,
        documents: List[Union[str, Dict[str, Any]]],
        custom_llm_provider: Optional[str] = None,
        top_n: Optional[int] = None,
        rank_fields: Optional[List[str]] = None,
        return_documents: Optional[bool] = True,
        max_chunks_per_doc: Optional[int] = None,
        max_tokens_per_doc: Optional[int] = None,
    ) -> Dict:
        """Map Cohere-style rerank params onto the provider's parameters."""
        pass

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """Build the exception instance for an upstream error; callers raise it."""
        # Fix: previously this `raise`d the exception instead of returning it,
        # contradicting the declared return type and the other base configs,
        # and breaking any caller that assigns the result before raising.
        return BaseLLMException(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )

    def calculate_rerank_cost(
        self,
        model: str,
        custom_llm_provider: Optional[str] = None,
        billed_units: Optional[RerankBilledUnits] = None,
        model_info: Optional[ModelInfo] = None,
    ) -> Tuple[float, float]:
        """
        Calculate the cost for a rerank call.

        Args:
            model: Model name without the provider prefix.
            custom_llm_provider: Provider used for the model; if provided, used
                to check that the litellm model info is for that provider.
            billed_units: Billed units reported by the provider; only
                ``search_units`` is used here.
            model_info: litellm model info for the given model; only
                ``input_cost_per_query`` is used here.

        Returns:
            Tuple[float, float]: (prompt_cost_in_usd, completion_cost_in_usd).
            Returns (0.0, 0.0) when any pricing input is missing.
        """
        # Without a per-query price and billed units there is nothing to charge.
        if (
            model_info is None
            or "input_cost_per_query" not in model_info
            or model_info["input_cost_per_query"] is None
            or billed_units is None
        ):
            return 0.0, 0.0

        search_units = billed_units.get("search_units")
        if search_units is None:
            return 0.0, 0.0

        prompt_cost = model_info["input_cost_per_query"] * search_units
        # Rerank has no completion-side cost.
        return prompt_cost, 0.0

View File

@@ -0,0 +1,283 @@
import types
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
import httpx
from litellm.types.llms.openai import (
ResponseInputParam,
ResponsesAPIOptionalRequestParams,
ResponsesAPIResponse,
ResponsesAPIStreamingResponse,
)
from litellm.types.responses.main import *
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import LlmProviders
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
from ..chat.transformation import BaseLLMException as _BaseLLMException
LiteLLMLoggingObj = _LiteLLMLoggingObj
BaseLLMException = _BaseLLMException
else:
LiteLLMLoggingObj = Any
BaseLLMException = Any
class BaseResponsesAPIConfig(ABC):
    """
    Abstract base for provider-specific Responses API transformations.

    Covers create/stream plus the delete, get, list-input-items, cancel and
    compact sub-endpoints of the Responses API.
    """

    def __init__(self):
        pass

    @property
    @abstractmethod
    def custom_llm_provider(self) -> LlmProviders:
        """Provider enum value this config belongs to."""
        pass

    @classmethod
    def get_config(cls):
        """Return class-level config values (non-callable, non-dunder, non-None attrs)."""
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not k.startswith("_abc")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    @abstractmethod
    def get_supported_openai_params(self, model: str) -> list:
        """Return the OpenAI Responses API params supported for ``model``."""
        pass

    @abstractmethod
    def map_openai_params(
        self,
        response_api_optional_params: ResponsesAPIOptionalRequestParams,
        model: str,
        drop_params: bool,
    ) -> Dict:
        """Map OpenAI Responses API params onto provider-specific parameters."""
        pass

    @abstractmethod
    def validate_environment(
        self, headers: dict, model: str, litellm_params: Optional[GenericLiteLLMParams]
    ) -> dict:
        """Build and return the request headers."""
        return {}

    @abstractmethod
    def get_complete_url(
        self,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        OPTIONAL

        Get the complete url for the request.

        Some providers need `model` in `api_base`.
        """
        if api_base is None:
            raise ValueError("api_base is required")
        return api_base

    @abstractmethod
    def transform_responses_api_request(
        self,
        model: str,
        input: Union[str, ResponseInputParam],
        response_api_optional_request_params: Dict,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Dict:
        """Build the provider-specific request body for a create call."""
        pass

    @abstractmethod
    def transform_response_api_response(
        self,
        model: str,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ResponsesAPIResponse:
        """Convert the provider's raw HTTP response into a ResponsesAPIResponse."""
        pass

    @abstractmethod
    def transform_streaming_response(
        self,
        model: str,
        parsed_chunk: dict,
        logging_obj: LiteLLMLoggingObj,
    ) -> ResponsesAPIStreamingResponse:
        """
        Transform a parsed streaming response chunk into a ResponsesAPIStreamingResponse
        """
        pass

    #########################################################
    ########## DELETE RESPONSE API TRANSFORMATION ##############
    #########################################################
    @abstractmethod
    def transform_delete_response_api_request(
        self,
        response_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """Return (url, request_body) for deleting ``response_id``."""
        pass

    @abstractmethod
    def transform_delete_response_api_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> DeleteResponseResult:
        """Convert the provider's delete response into a DeleteResponseResult."""
        pass

    #########################################################
    ########## END DELETE RESPONSE API TRANSFORMATION #######
    #########################################################

    #########################################################
    ########## GET RESPONSE API TRANSFORMATION ###############
    #########################################################
    @abstractmethod
    def transform_get_response_api_request(
        self,
        response_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """Return (url, request_body) for fetching ``response_id``."""
        pass

    @abstractmethod
    def transform_get_response_api_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ResponsesAPIResponse:
        """Convert the provider's get response into a ResponsesAPIResponse."""
        pass

    #########################################################
    ########## LIST INPUT ITEMS API TRANSFORMATION ##########
    #########################################################
    @abstractmethod
    def transform_list_input_items_request(
        self,
        response_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
        after: Optional[str] = None,
        before: Optional[str] = None,
        include: Optional[List[str]] = None,
        limit: int = 20,
        order: Literal["asc", "desc"] = "desc",
    ) -> Tuple[str, Dict]:
        """Return (url, query_params) for listing input items of ``response_id``."""
        pass

    @abstractmethod
    def transform_list_input_items_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> Dict:
        """Convert the provider's list-input-items response into a dict."""
        pass

    #########################################################
    ########## END GET RESPONSE API TRANSFORMATION ##########
    #########################################################

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """Build the exception instance for an upstream error; callers raise it."""
        from ..chat.transformation import BaseLLMException

        # Fix: previously this `raise`d the exception instead of returning it,
        # contradicting the declared return type and the other base configs,
        # and breaking any caller that assigns the result before raising.
        return BaseLLMException(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )

    def should_fake_stream(
        self,
        model: Optional[str],
        stream: Optional[bool],
        custom_llm_provider: Optional[str] = None,
    ) -> bool:
        """Returns True if litellm should fake a stream for the given model and stream value"""
        return False

    def supports_native_websocket(self) -> bool:
        """
        Returns True if the provider has a native WebSocket endpoint for Responses API.

        Providers with native websocket support can connect directly to wss:// endpoints.
        Providers without native support will use the ManagedResponsesWebSocketHandler
        which makes HTTP streaming calls and forwards events over the websocket.

        Default: False (use managed websocket handler)
        """
        return False

    #########################################################
    ########## CANCEL RESPONSE API TRANSFORMATION ##########
    #########################################################
    @abstractmethod
    def transform_cancel_response_api_request(
        self,
        response_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """Return (url, request_body) for cancelling ``response_id``."""
        pass

    @abstractmethod
    def transform_cancel_response_api_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ResponsesAPIResponse:
        """Convert the provider's cancel response into a ResponsesAPIResponse."""
        pass

    #########################################################
    ########## END CANCEL RESPONSE API TRANSFORMATION #######
    #########################################################

    #########################################################
    ########## COMPACT RESPONSE API TRANSFORMATION ##########
    #########################################################
    @abstractmethod
    def transform_compact_response_api_request(
        self,
        model: str,
        input: Union[str, ResponseInputParam],
        response_api_optional_request_params: Dict,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """Return (url, request_body) for a compact-responses call."""
        pass

    @abstractmethod
    def transform_compact_response_api_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ResponsesAPIResponse:
        """Convert the provider's compact response into a ResponsesAPIResponse."""
        pass

    #########################################################
    ########## END COMPACT RESPONSE API TRANSFORMATION ######
    #########################################################

View File

@@ -0,0 +1,14 @@
"""
Base Search API module.
"""
from litellm.llms.base_llm.search.transformation import (
BaseSearchConfig,
SearchResponse,
SearchResult,
)
__all__ = [
"BaseSearchConfig",
"SearchResponse",
"SearchResult",
]

View File

@@ -0,0 +1,174 @@
"""
Base Search transformation configuration.
"""
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
import httpx
from pydantic import PrivateAttr
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.llms.base import LiteLLMPydanticObjectBase
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class SearchResult(LiteLLMPydanticObjectBase):
    """Single search result."""

    title: str  # result title
    url: str  # url of the result
    snippet: str  # short excerpt of the matching content
    date: Optional[str] = None  # publication date, if provided (format provider-defined)
    last_updated: Optional[str] = None  # last-updated timestamp, if provided
    model_config = {"extra": "allow"}  # keep unknown provider fields
class SearchResponse(LiteLLMPydanticObjectBase):
    """
    Standard Search response format.
    Standardized to Perplexity Search format - other providers should transform to this format.
    """

    results: List[SearchResult]  # search results, in provider order
    object: str = "search"  # response type discriminator
    model_config = {"extra": "allow"}  # keep unknown provider fields
    # Define private attributes using PrivateAttr
    # (excluded from serialization; presumably carries litellm-internal metadata - verify against callers)
    _hidden_params: dict = PrivateAttr(default_factory=dict)
class BaseSearchConfig:
    """
    Provider-agnostic base class for Search transformations.

    Concrete search providers override the URL and transform hooks below.
    """

    def __init__(self) -> None:
        pass

    @staticmethod
    def ui_friendly_name() -> str:
        """Human-readable provider name for the UI; override per provider."""
        return "Unknown Search Provider"

    def get_http_method(self) -> Literal["GET", "POST"]:
        """
        HTTP verb used for search requests.

        Defaults to 'POST'; override in provider implementations if needed.
        """
        return "POST"

    @staticmethod
    def get_supported_perplexity_optional_params() -> set:
        """
        Names of the Perplexity-style unified search parameters.

        These are the standard parameters providers are expected to translate
        from into their own request format.
        """
        return {
            "max_results",
            "search_domain_filter",
            "country",
            "max_tokens_per_page",
        }

    def validate_environment(
        self,
        headers: Dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        **kwargs,
    ) -> Dict:
        """
        Validate the environment and return the headers to send.

        Base behavior passes ``headers`` through unchanged; providers
        typically add auth headers here.
        """
        return headers

    def get_complete_url(
        self,
        api_base: Optional[str],
        optional_params: dict,
        data: Optional[Union[Dict, List[Dict]]] = None,
        **kwargs,
    ) -> str:
        """
        Return the full search endpoint URL; must be overridden per provider.

        Args:
            api_base: Base URL for the API.
            optional_params: Optional parameters for the request.
            data: Transformed request body from transform_search_request().
                GET-style providers (e.g., Google PSE) need it to build query
                parameters into the URL; may be a dict or a list of dicts.
            **kwargs: Additional keyword arguments.
        """
        raise NotImplementedError("get_complete_url must be implemented by provider")

    def transform_search_request(
        self,
        query: Union[str, List[str]],
        optional_params: dict,
        **kwargs,
    ) -> Union[Dict, List[Dict]]:
        """
        Build the provider-specific request body; must be overridden.

        Args:
            query: Search query - a string or a list of strings.
            optional_params: Optional parameters for the request.

        Returns:
            The request data (dict, or list of dicts for multi-query providers).
        """
        raise NotImplementedError(
            "transform_search_request must be implemented by provider"
        )

    def transform_search_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        **kwargs,
    ) -> SearchResponse:
        """
        Convert the provider's raw HTTP response into the standard
        SearchResponse; must be overridden per provider.
        """
        raise NotImplementedError(
            "transform_search_response must be implemented by provider"
        )

    def get_error_class(
        self,
        error_message: str,
        status_code: int,
        headers: dict,
    ) -> Exception:
        """Build the exception instance to raise for a provider error."""
        return BaseLLMException(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )

View File

@@ -0,0 +1,5 @@
"""Base Skills API configuration"""
from .transformation import BaseSkillsAPIConfig
__all__ = ["BaseSkillsAPIConfig"]

View File

@@ -0,0 +1,245 @@
"""
Base configuration class for Skills API
"""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
import httpx
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.llms.anthropic_skills import (
CreateSkillRequest,
DeleteSkillResponse,
ListSkillsParams,
ListSkillsResponse,
Skill,
)
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import LlmProviders
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class BaseSkillsAPIConfig(ABC):
    """
    Base configuration for Skills API providers.

    Subclasses implement the request/response transforms for the create,
    list, get and delete skill operations.
    """

    def __init__(self):
        pass

    @property
    @abstractmethod
    def custom_llm_provider(self) -> LlmProviders:
        """Provider enum value this config belongs to."""
        pass

    @abstractmethod
    def validate_environment(
        self, headers: dict, litellm_params: Optional[GenericLiteLLMParams]
    ) -> dict:
        """
        Merge provider-specific requirements (auth, versions, ...) into the
        given ``headers`` and return the result.

        Base behavior returns ``headers`` unchanged.
        """
        return headers

    @abstractmethod
    def get_complete_url(
        self,
        api_base: Optional[str],
        endpoint: str,
        skill_id: Optional[str] = None,
    ) -> str:
        """
        Build the full request URL.

        Args:
            api_base: Base API URL.
            endpoint: API endpoint (e.g., 'skills', 'skills/{id}').
            skill_id: Optional skill ID for single-skill operations.

        Default implementation produces ``{api_base}/v1/{endpoint}`` and
        requires ``api_base`` to be set.
        """
        if api_base is None:
            raise ValueError("api_base is required")
        return f"{api_base}/v1/{endpoint}"

    @abstractmethod
    def transform_create_skill_request(
        self,
        create_request: CreateSkillRequest,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Dict:
        """
        Build the provider-specific request body for creating a skill from
        ``create_request`` (skill creation parameters).
        """
        pass

    @abstractmethod
    def transform_create_skill_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> Skill:
        """Convert the provider's create-skill response into a Skill object."""
        pass

    @abstractmethod
    def transform_list_skills_request(
        self,
        list_params: ListSkillsParams,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """
        Build the list-skills request from ``list_params`` (pagination,
        filters).

        Returns:
            Tuple of (url, query_params).
        """
        pass

    @abstractmethod
    def transform_list_skills_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ListSkillsResponse:
        """Convert the provider's list-skills response into a ListSkillsResponse."""
        pass

    @abstractmethod
    def transform_get_skill_request(
        self,
        skill_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """
        Build the get-skill request for ``skill_id``.

        Returns:
            Tuple of (url, headers).
        """
        pass

    @abstractmethod
    def transform_get_skill_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> Skill:
        """Convert the provider's get-skill response into a Skill object."""
        pass

    @abstractmethod
    def transform_delete_skill_request(
        self,
        skill_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """
        Build the delete-skill request for ``skill_id``.

        Returns:
            Tuple of (url, headers).
        """
        pass

    @abstractmethod
    def transform_delete_skill_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> DeleteSkillResponse:
        """Convert the provider's delete-skill response into a DeleteSkillResponse."""
        pass

    def get_error_class(
        self,
        error_message: str,
        status_code: int,
        headers: dict,
    ) -> Exception:
        """Build the exception instance to raise for a provider error."""
        return BaseLLMException(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )

View File

@@ -0,0 +1,149 @@
import types
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, TypedDict, Union
import httpx
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
from litellm.types.llms.openai import (
HttpxBinaryResponseContent as _HttpxBinaryResponseContent,
)
from ..chat.transformation import BaseLLMException as _BaseLLMException
LiteLLMLoggingObj = _LiteLLMLoggingObj
BaseLLMException = _BaseLLMException
HttpxBinaryResponseContent = _HttpxBinaryResponseContent
else:
LiteLLMLoggingObj = Any
BaseLLMException = Any
HttpxBinaryResponseContent = Any
class TextToSpeechRequestData(TypedDict, total=False):
    """
    Structured return type for text-to-speech transformations.
    This ensures a consistent interface across all TTS providers.
    Providers should set ONE of: dict_body, ssml_body, or text_body.

    NOTE(review): the docstring mentions ``text_body`` but no such key is
    declared below - confirm whether it should be added or the doc is stale.
    """

    dict_body: Dict[str, Any]  # JSON request body (e.g., OpenAI TTS)
    ssml_body: str  # SSML/XML string body (e.g., Azure AVA TTS)
    headers: Dict[str, str]  # Provider-specific headers to merge with base headers
class BaseTextToSpeechConfig(ABC):
    """
    Abstract contract implemented by provider-specific text-to-speech configs.

    Subclasses translate OpenAI-style TTS requests into the provider's wire
    format and translate the provider's response back into LiteLLM's standard
    binary audio response.
    """

    def __init__(self):
        pass

    @classmethod
    def get_config(cls):
        """
        Return the class-level configuration attributes as a dict.

        Filters out dunder/ABC internals, callables, and unset (None) values
        so only plain configuration values remain.
        """
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not k.startswith("_abc")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    @abstractmethod
    def get_supported_openai_params(self, model: str) -> list:
        """
        Get list of OpenAI TTS parameters supported by this provider
        """
        pass

    @abstractmethod
    def map_openai_params(
        self,
        model: str,
        optional_params: Dict,
        voice: Optional[Union[str, Dict]] = None,
        drop_params: bool = False,
        kwargs: Dict = {},
    ) -> Tuple[Optional[str], Dict]:
        """
        Map OpenAI TTS parameters to provider-specific parameters.

        Returns:
            Tuple of (mapped voice, mapped optional params).
        """
        pass

    @abstractmethod
    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """
        Validate environment and return the headers to send with the request.
        """
        return {}

    @abstractmethod
    def get_complete_url(
        self,
        model: str,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        Get the complete url for the request.
        """
        if api_base is None:
            raise ValueError("api_base is required")
        return api_base

    @abstractmethod
    def transform_text_to_speech_request(
        self,
        model: str,
        input: str,
        voice: Optional[str],
        optional_params: Dict,
        litellm_params: Dict,
        headers: dict,
    ) -> TextToSpeechRequestData:
        """
        Transform request to provider-specific format.

        Returns:
            TextToSpeechRequestData: a structured dict where the provider sets
            ONE body field (``dict_body`` / ``ssml_body`` — per the TypedDict
            contract) plus any provider-specific ``headers`` to merge with the
            base headers.
        """
        pass

    @abstractmethod
    def transform_text_to_speech_response(
        self,
        model: str,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> "HttpxBinaryResponseContent":
        """
        Transform provider response to standard binary audio format.
        """
        pass

    def get_error_class(
        self, error_message: str, status_code: int, headers: Dict
    ) -> BaseLLMException:
        """
        Build the exception used to surface a provider error.

        Returns the exception so callers can decide how to propagate it.
        """
        from ..chat.transformation import BaseLLMException

        # Fix: previously this method `raise`d despite the
        # `-> BaseLLMException` return annotation; the sibling base configs'
        # contract is to return the exception for the caller to raise.
        return BaseLLMException(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )

View File

@@ -0,0 +1,163 @@
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import httpx
from litellm.types.router import GenericLiteLLMParams
from litellm.types.vector_stores import (
VECTOR_STORE_OPENAI_PARAMS,
BaseVectorStoreAuthCredentials,
VectorStoreCreateOptionalRequestParams,
VectorStoreCreateResponse,
VectorStoreIndexEndpoints,
VectorStoreSearchOptionalRequestParams,
VectorStoreSearchResponse,
)
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
from ..chat.transformation import BaseLLMException as _BaseLLMException
LiteLLMLoggingObj = _LiteLLMLoggingObj
BaseLLMException = _BaseLLMException
else:
LiteLLMLoggingObj = Any
BaseLLMException = Any
class BaseVectorStoreConfig:
    """
    Base contract for provider-specific vector store configurations.

    NOTE(review): methods below are decorated with @abstractmethod but this
    class does not inherit abc.ABC, so abstractness is not enforced at
    instantiation time — confirm whether subclassing ABC is intended.
    """

    def get_supported_openai_params(
        self, model: str
    ) -> List[VECTOR_STORE_OPENAI_PARAMS]:
        """Return the OpenAI params this provider supports (default: none)."""
        return []

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        drop_params: bool,
    ) -> dict:
        """Map OpenAI params to provider params (default: pass-through)."""
        return optional_params

    @abstractmethod
    def get_auth_credentials(
        self, litellm_params: dict
    ) -> BaseVectorStoreAuthCredentials:
        """Extract provider auth credentials from litellm params."""
        pass

    @abstractmethod
    def get_vector_store_endpoints_by_type(self) -> VectorStoreIndexEndpoints:
        """Return the provider's vector store endpoints keyed by operation."""
        pass

    @abstractmethod
    def transform_search_vector_store_request(
        self,
        vector_store_id: str,
        query: Union[str, List[str]],
        vector_store_search_optional_params: VectorStoreSearchOptionalRequestParams,
        api_base: str,
        litellm_logging_obj: LiteLLMLoggingObj,
        litellm_params: dict,
    ) -> Tuple[str, Dict]:
        """Build (url, request_body) for a vector store search."""
        pass

    async def atransform_search_vector_store_request(
        self,
        vector_store_id: str,
        query: Union[str, List[str]],
        vector_store_search_optional_params: VectorStoreSearchOptionalRequestParams,
        api_base: str,
        litellm_logging_obj: LiteLLMLoggingObj,
        litellm_params: dict,
    ) -> Tuple[str, Dict]:
        """
        Optional async version of transform_search_vector_store_request.
        If not implemented, the handler will fall back to the sync version.
        Providers that need to make async calls (e.g., generating embeddings) should override this.
        """
        # Default implementation: call the sync version
        return self.transform_search_vector_store_request(
            vector_store_id=vector_store_id,
            query=query,
            vector_store_search_optional_params=vector_store_search_optional_params,
            api_base=api_base,
            litellm_logging_obj=litellm_logging_obj,
            litellm_params=litellm_params,
        )

    @abstractmethod
    def transform_search_vector_store_response(
        self, response: httpx.Response, litellm_logging_obj: LiteLLMLoggingObj
    ) -> VectorStoreSearchResponse:
        """Parse a provider search response into the standard response type."""
        pass

    @abstractmethod
    def transform_create_vector_store_request(
        self,
        vector_store_create_optional_params: VectorStoreCreateOptionalRequestParams,
        api_base: str,
    ) -> Tuple[str, Dict]:
        """Build (url, request_body) for creating a vector store."""
        pass

    @abstractmethod
    def transform_create_vector_store_response(
        self, response: httpx.Response
    ) -> VectorStoreCreateResponse:
        """Parse a provider create response into the standard response type."""
        pass

    @abstractmethod
    def validate_environment(
        self, headers: dict, litellm_params: Optional[GenericLiteLLMParams]
    ) -> dict:
        """Validate credentials/config and return the headers to send."""
        return {}

    @abstractmethod
    def get_complete_url(
        self,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        OPTIONAL

        Get the complete url for the request

        Some providers need `model` in `api_base`
        """
        if api_base is None:
            raise ValueError("api_base is required")
        return api_base

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """
        Build the exception used to surface a provider error.

        Returns the exception so callers can decide how to propagate it.
        """
        from ..chat.transformation import BaseLLMException

        # Fix: previously raised despite the `-> BaseLLMException` return
        # annotation; sibling configs return the exception for callers to raise.
        return BaseLLMException(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )

    def sign_request(
        self,
        headers: dict,
        optional_params: Dict,
        request_data: Dict,
        api_base: str,
        api_key: Optional[str] = None,
    ) -> Tuple[dict, Optional[bytes]]:
        """Optionally sign or modify the request before sending.

        Providers like AWS Bedrock require SigV4 signing. Providers that don't
        require any signing can simply return the headers unchanged and ``None``
        for the signed body.
        """
        return headers, None

    def calculate_vector_store_cost(
        self,
        response: VectorStoreSearchResponse,
    ) -> Tuple[float, float]:
        """Return the cost tuple for a search; default is no cost.

        NOTE(review): tuple element semantics (likely request/response cost)
        are not documented here — confirm against callers.
        """
        return 0.0, 0.0

View File

@@ -0,0 +1,226 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
import httpx
from litellm.types.router import GenericLiteLLMParams
from litellm.types.vector_store_files import (
VectorStoreFileAuthCredentials,
VectorStoreFileChunkingStrategy,
VectorStoreFileContentResponse,
VectorStoreFileCreateRequest,
VectorStoreFileDeleteResponse,
VectorStoreFileListQueryParams,
VectorStoreFileListResponse,
VectorStoreFileObject,
VectorStoreFileUpdateRequest,
)
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
from ..chat.transformation import BaseLLMException as _BaseLLMException
LiteLLMLoggingObj = _LiteLLMLoggingObj
BaseLLMException = _BaseLLMException
else:
LiteLLMLoggingObj = Any
BaseLLMException = Any
class BaseVectorStoreFilesConfig(ABC):
    """Base configuration contract for provider-specific vector store file implementations."""

    def get_supported_openai_params(
        self,
        operation: str,
    ) -> Tuple[str, ...]:
        """Return the set of OpenAI params supported for the given operation."""
        return tuple()

    def map_openai_params(
        self,
        *,
        operation: str,
        non_default_params: Dict[str, Any],
        optional_params: Dict[str, Any],
        drop_params: bool,
    ) -> Dict[str, Any]:
        """Map non-default OpenAI params to provider-specific params."""
        return optional_params

    @abstractmethod
    def get_auth_credentials(
        self, litellm_params: Dict[str, Any]
    ) -> VectorStoreFileAuthCredentials:
        """Extract provider auth credentials from litellm params."""
        ...

    @abstractmethod
    def get_vector_store_file_endpoints_by_type(
        self,
    ) -> Dict[str, Tuple[Tuple[str, str], ...]]:
        """Return the provider's file endpoints keyed by operation type."""
        ...

    @abstractmethod
    def validate_environment(
        self,
        *,
        headers: Dict[str, str],
        litellm_params: Optional[GenericLiteLLMParams],
    ) -> Dict[str, str]:
        """Validate credentials/config and return the headers to send."""
        return {}

    @abstractmethod
    def get_complete_url(
        self,
        *,
        api_base: Optional[str],
        vector_store_id: str,
        litellm_params: Dict[str, Any],
    ) -> str:
        """Build the base URL for file operations on the given vector store."""
        if api_base is None:
            raise ValueError("api_base is required")
        return api_base

    @abstractmethod
    def transform_create_vector_store_file_request(
        self,
        *,
        vector_store_id: str,
        create_request: VectorStoreFileCreateRequest,
        api_base: str,
    ) -> Tuple[str, Dict[str, Any]]:
        """Build (url, request_body) for attaching a file to a vector store."""
        ...

    @abstractmethod
    def transform_create_vector_store_file_response(
        self,
        *,
        response: httpx.Response,
    ) -> VectorStoreFileObject:
        """Parse a provider create-file response into a VectorStoreFileObject."""
        ...

    @abstractmethod
    def transform_list_vector_store_files_request(
        self,
        *,
        vector_store_id: str,
        query_params: VectorStoreFileListQueryParams,
        api_base: str,
    ) -> Tuple[str, Dict[str, Any]]:
        """Build (url, query_params) for listing files in a vector store."""
        ...

    @abstractmethod
    def transform_list_vector_store_files_response(
        self,
        *,
        response: httpx.Response,
    ) -> VectorStoreFileListResponse:
        """Parse a provider list-files response into the standard list type."""
        ...

    @abstractmethod
    def transform_retrieve_vector_store_file_request(
        self,
        *,
        vector_store_id: str,
        file_id: str,
        api_base: str,
    ) -> Tuple[str, Dict[str, Any]]:
        """Build (url, params) for retrieving a single vector store file."""
        ...

    @abstractmethod
    def transform_retrieve_vector_store_file_response(
        self,
        *,
        response: httpx.Response,
    ) -> VectorStoreFileObject:
        """Parse a provider retrieve-file response into a VectorStoreFileObject."""
        ...

    @abstractmethod
    def transform_retrieve_vector_store_file_content_request(
        self,
        *,
        vector_store_id: str,
        file_id: str,
        api_base: str,
    ) -> Tuple[str, Dict[str, Any]]:
        """Build (url, params) for retrieving a file's parsed content."""
        ...

    @abstractmethod
    def transform_retrieve_vector_store_file_content_response(
        self,
        *,
        response: httpx.Response,
    ) -> VectorStoreFileContentResponse:
        """Parse a provider file-content response into the standard type."""
        ...

    @abstractmethod
    def transform_update_vector_store_file_request(
        self,
        *,
        vector_store_id: str,
        file_id: str,
        update_request: VectorStoreFileUpdateRequest,
        api_base: str,
    ) -> Tuple[str, Dict[str, Any]]:
        """Build (url, request_body) for updating a vector store file."""
        ...

    @abstractmethod
    def transform_update_vector_store_file_response(
        self,
        *,
        response: httpx.Response,
    ) -> VectorStoreFileObject:
        """Parse a provider update-file response into a VectorStoreFileObject."""
        ...

    @abstractmethod
    def transform_delete_vector_store_file_request(
        self,
        *,
        vector_store_id: str,
        file_id: str,
        api_base: str,
    ) -> Tuple[str, Dict[str, Any]]:
        """Build (url, params) for deleting a vector store file."""
        ...

    @abstractmethod
    def transform_delete_vector_store_file_response(
        self,
        *,
        response: httpx.Response,
    ) -> VectorStoreFileDeleteResponse:
        """Parse a provider delete-file response into the standard type."""
        ...

    def get_error_class(
        self,
        *,
        error_message: str,
        status_code: int,
        headers: Union[Dict[str, Any], httpx.Headers],
    ) -> BaseLLMException:
        """
        Build the exception used to surface a provider error.

        Returns the exception so callers can decide how to propagate it.
        """
        from ..chat.transformation import BaseLLMException

        # Fix: previously raised despite the `-> BaseLLMException` return
        # annotation; sibling configs return the exception for callers to raise.
        return BaseLLMException(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )

    def sign_request(
        self,
        *,
        headers: Dict[str, str],
        optional_params: Dict[str, Any],
        request_data: Dict[str, Any],
        api_base: str,
        api_key: Optional[str] = None,
    ) -> Tuple[Dict[str, str], Optional[bytes]]:
        """Optionally sign the request (e.g., SigV4); default is a no-op."""
        return headers, None

    def prepare_chunking_strategy(
        self,
        chunking_strategy: Optional[VectorStoreFileChunkingStrategy],
    ) -> Optional[VectorStoreFileChunkingStrategy]:
        """Adjust the chunking strategy for the provider; default pass-through."""
        return chunking_strategy

View File

@@ -0,0 +1,277 @@
import types
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
import httpx
from httpx._types import RequestFiles
from litellm.types.responses.main import *
from litellm.types.router import GenericLiteLLMParams
from litellm.types.videos.main import VideoCreateOptionalRequestParams
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
from litellm.types.videos.main import VideoObject as _VideoObject
from ..chat.transformation import BaseLLMException as _BaseLLMException
LiteLLMLoggingObj = _LiteLLMLoggingObj
BaseLLMException = _BaseLLMException
VideoObject = _VideoObject
else:
LiteLLMLoggingObj = Any
BaseLLMException = Any
VideoObject = Any
class BaseVideoConfig(ABC):
    """
    Abstract contract implemented by provider-specific video generation configs.

    Subclasses translate OpenAI-style video requests (create, remix, list,
    delete, status, content download) into the provider's wire format and back.
    """

    def __init__(self):
        pass

    @classmethod
    def get_config(cls):
        """
        Return the class-level configuration attributes as a dict.

        Filters out dunder/ABC internals, callables, and unset (None) values
        so only plain configuration values remain.
        """
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not k.startswith("_abc")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    @abstractmethod
    def get_supported_openai_params(self, model: str) -> list:
        """Return the OpenAI video params supported by this provider."""
        pass

    @abstractmethod
    def map_openai_params(
        self,
        video_create_optional_params: VideoCreateOptionalRequestParams,
        model: str,
        drop_params: bool,
    ) -> Dict:
        """Map OpenAI video-create params to provider-specific params."""
        pass

    @abstractmethod
    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
        litellm_params: Optional[GenericLiteLLMParams] = None,
    ) -> dict:
        """Validate credentials/config and return the headers to send."""
        return {}

    @abstractmethod
    def get_complete_url(
        self,
        model: str,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        OPTIONAL

        Get the complete url for the request

        Some providers need `model` in `api_base`
        """
        if api_base is None:
            raise ValueError("api_base is required")
        return api_base

    @abstractmethod
    def transform_video_create_request(
        self,
        model: str,
        prompt: str,
        api_base: str,
        video_create_optional_request_params: Dict,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[Dict, RequestFiles, str]:
        """Build (data, files, url) for a video creation request."""
        pass

    @abstractmethod
    def transform_video_create_response(
        self,
        model: str,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        custom_llm_provider: Optional[str] = None,
        request_data: Optional[Dict] = None,
    ) -> VideoObject:
        """Parse a provider create response into a VideoObject."""
        pass

    @abstractmethod
    def transform_video_content_request(
        self,
        video_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
        variant: Optional[str] = None,
    ) -> Tuple[str, Dict]:
        """
        Transform the video content request into a URL and data/params

        Returns:
            Tuple[str, Dict]: (url, params) for the video content request
        """
        pass

    @abstractmethod
    def transform_video_content_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> bytes:
        """Extract the downloaded video content as raw bytes."""
        pass

    async def async_transform_video_content_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> bytes:
        """
        Async transform video content download response to bytes.

        Optional method - providers can override if they need async
        transformations (e.g., RunwayML for downloading video from a
        CloudFront URL). Default implementation falls back to the sync
        transform_video_content_response.

        Args:
            raw_response: Raw HTTP response
            logging_obj: Logging object

        Returns:
            Video content as bytes
        """
        # Default implementation: call sync version
        return self.transform_video_content_response(
            raw_response=raw_response,
            logging_obj=logging_obj,
        )

    @abstractmethod
    def transform_video_remix_request(
        self,
        video_id: str,
        prompt: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
        extra_body: Optional[Dict[str, Any]] = None,
    ) -> Tuple[str, Dict]:
        """
        Transform the video remix request into a URL and data

        Returns:
            Tuple[str, Dict]: (url, data) for the video remix request
        """
        pass

    @abstractmethod
    def transform_video_remix_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        custom_llm_provider: Optional[str] = None,
    ) -> VideoObject:
        """Parse a provider remix response into a VideoObject."""
        pass

    @abstractmethod
    def transform_video_list_request(
        self,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
        after: Optional[str] = None,
        limit: Optional[int] = None,
        order: Optional[str] = None,
        extra_query: Optional[Dict[str, Any]] = None,
    ) -> Tuple[str, Dict]:
        """
        Transform the video list request into a URL and params

        Returns:
            Tuple[str, Dict]: (url, params) for the video list request
        """
        pass

    @abstractmethod
    def transform_video_list_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        custom_llm_provider: Optional[str] = None,
    ) -> Dict[str, str]:
        """Parse a provider list response into a dict payload."""
        pass

    @abstractmethod
    def transform_video_delete_request(
        self,
        video_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """
        Transform the video delete request into a URL and data

        Returns:
            Tuple[str, Dict]: (url, data) for the video delete request
        """
        pass

    @abstractmethod
    def transform_video_delete_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> VideoObject:
        """Parse a provider delete response into a VideoObject."""
        pass

    @abstractmethod
    def transform_video_status_retrieve_request(
        self,
        video_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """
        Transform the video retrieve request into a URL and data/params

        Returns:
            Tuple[str, Dict]: (url, params) for the video retrieve request
        """
        pass

    @abstractmethod
    def transform_video_status_retrieve_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        custom_llm_provider: Optional[str] = None,
    ) -> VideoObject:
        """Parse a provider status response into a VideoObject."""
        pass

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """
        Build the exception used to surface a provider error.

        Returns the exception so callers can decide how to propagate it.
        """
        from ..chat.transformation import BaseLLMException

        # Fix: previously raised despite the `-> BaseLLMException` return
        # annotation; sibling configs return the exception for callers to raise.
        return BaseLLMException(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )