chore: initial public snapshot for github upload
This commit is contained in:
2170
llm-gateway-competitors/litellm-wheel-src/litellm/__init__.py
Normal file
2170
llm-gateway-competitors/litellm-wheel-src/litellm/__init__.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,449 @@
|
||||
"""
|
||||
Lazy Import System
|
||||
|
||||
This module implements lazy loading for LiteLLM attributes. Instead of importing
|
||||
everything when the module loads, we only import things when they're actually used.
|
||||
|
||||
How it works:
|
||||
1. When someone accesses `litellm.some_attribute`, Python calls __getattr__ in __init__.py
|
||||
2. __getattr__ looks up the attribute name in a registry
|
||||
3. The registry points to a handler function (like _lazy_import_utils)
|
||||
4. The handler function imports the module and returns the attribute
|
||||
5. The result is cached so we don't import it again
|
||||
|
||||
This makes importing litellm much faster because we don't load heavy dependencies
|
||||
until they're actually needed.
|
||||
"""
|
||||
import importlib
|
||||
import sys
|
||||
from typing import Any, Optional, cast, Callable
|
||||
|
||||
# Import all the data structures that define what can be lazy-loaded
|
||||
# These are just lists of names and maps of where to find them
|
||||
from ._lazy_imports_registry import (
|
||||
# Name tuples
|
||||
COST_CALCULATOR_NAMES,
|
||||
LITELLM_LOGGING_NAMES,
|
||||
UTILS_NAMES,
|
||||
TOKEN_COUNTER_NAMES,
|
||||
LLM_CLIENT_CACHE_NAMES,
|
||||
BEDROCK_TYPES_NAMES,
|
||||
TYPES_UTILS_NAMES,
|
||||
CACHING_NAMES,
|
||||
HTTP_HANDLER_NAMES,
|
||||
DOTPROMPT_NAMES,
|
||||
LLM_CONFIG_NAMES,
|
||||
TYPES_NAMES,
|
||||
LLM_PROVIDER_LOGIC_NAMES,
|
||||
UTILS_MODULE_NAMES,
|
||||
# Import maps
|
||||
_UTILS_IMPORT_MAP,
|
||||
_COST_CALCULATOR_IMPORT_MAP,
|
||||
_TYPES_UTILS_IMPORT_MAP,
|
||||
_TOKEN_COUNTER_IMPORT_MAP,
|
||||
_BEDROCK_TYPES_IMPORT_MAP,
|
||||
_CACHING_IMPORT_MAP,
|
||||
_LITELLM_LOGGING_IMPORT_MAP,
|
||||
_DOTPROMPT_IMPORT_MAP,
|
||||
_TYPES_IMPORT_MAP,
|
||||
_LLM_CONFIGS_IMPORT_MAP,
|
||||
_LLM_PROVIDER_LOGIC_IMPORT_MAP,
|
||||
_UTILS_MODULE_IMPORT_MAP,
|
||||
)
|
||||
|
||||
|
||||
def _get_litellm_globals() -> dict:
    """Return the module-level namespace of ``litellm``.

    Lazily imported attributes are cached in this dict, so a second access
    of ``litellm.<name>`` is a plain attribute lookup with no import.
    """
    litellm_module = sys.modules["litellm"]
    return litellm_module.__dict__
|
||||
|
||||
|
||||
def _get_utils_globals() -> dict:
    """Return the module-level namespace of ``litellm.utils``.

    Lazily imported attributes of the utils module are cached in this dict,
    so a second access of ``litellm.utils.<name>`` needs no import.
    """
    utils_module = sys.modules["litellm.utils"]
    return utils_module.__dict__
|
||||
|
||||
|
||||
# These are special lazy loaders for things that are used internally
|
||||
# They're separate from the main lazy import system because they have specific use cases
|
||||
|
||||
# Lazy loader for default encoding - avoids importing heavy tiktoken library at startup
|
||||
# Cached tiktoken encoding; populated on first use so that importing litellm
# never pulls in the heavy tiktoken dependency up front.
_default_encoding: Optional[Any] = None


def _get_default_encoding() -> Any:
    """Return the default OpenAI tiktoken encoding, importing it on first use.

    The import of ``litellm.litellm_core_utils.default_encoding`` (and thus
    tiktoken) is deferred until the first call; the resolved encoding is
    cached in the module-level ``_default_encoding`` afterwards. Used by
    utils.py helpers that need the encoding without triggering the import at
    module-load time.
    """
    global _default_encoding
    if _default_encoding is not None:
        return _default_encoding

    from litellm.litellm_core_utils.default_encoding import encoding

    _default_encoding = encoding
    return _default_encoding
|
||||
|
||||
|
||||
# Lazy loader for get_modified_max_tokens to avoid importing token_counter at module import time
|
||||
# Cached reference to get_modified_max_tokens; filled in on first use so the
# token_counter module is not imported when litellm itself is imported.
_get_modified_max_tokens_func: Optional[Any] = None


def _get_modified_max_tokens() -> Any:
    """Return ``get_modified_max_tokens`` from the token_counter module.

    Importing ``litellm.litellm_core_utils.token_counter`` is deferred until
    the first call; the resolved function is cached in
    ``_get_modified_max_tokens_func``. Used by utils.py helpers that need the
    token counter without triggering its import at module-load time.
    """
    global _get_modified_max_tokens_func
    if _get_modified_max_tokens_func is not None:
        return _get_modified_max_tokens_func

    from litellm.litellm_core_utils.token_counter import (
        get_modified_max_tokens as _imported_func,
    )

    _get_modified_max_tokens_func = _imported_func
    return _get_modified_max_tokens_func
|
||||
|
||||
|
||||
# Lazy loader for token_counter to avoid importing token_counter module at module import time
|
||||
# Cached reference to token_counter (exposed as token_counter_new); filled in
# on first use so the token_counter module is not imported at litellm import.
_token_counter_new_func: Optional[Any] = None


def _get_token_counter_new() -> Any:
    """Return ``token_counter`` from the token_counter module (cached).

    Importing ``litellm.litellm_core_utils.token_counter`` is deferred until
    the first call; the resolved function is cached in
    ``_token_counter_new_func``. Used by utils.py helpers that need the token
    counter without triggering its import at module-load time.
    """
    global _token_counter_new_func
    if _token_counter_new_func is not None:
        return _token_counter_new_func

    from litellm.litellm_core_utils.token_counter import (
        token_counter as _imported_func,
    )

    _token_counter_new_func = _imported_func
    return _token_counter_new_func
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# MAIN LAZY IMPORT SYSTEM
|
||||
# ============================================================================
|
||||
|
||||
# This registry maps attribute names (like "ModelResponse") to handler functions
|
||||
# It's built once the first time someone accesses a lazy-loaded attribute
|
||||
# Example: {"ModelResponse": _lazy_import_utils, "Cache": _lazy_import_caching, ...}
|
||||
# Registry mapping lazily-loadable attribute names to handler callables,
# e.g. {"ModelResponse": _lazy_import_utils, "Cache": _lazy_import_caching}.
# Built once, on the first lazy attribute access.
_LAZY_IMPORT_REGISTRY: Optional[dict[str, Callable[[str], Any]]] = None


def _get_lazy_import_registry() -> dict[str, Callable[[str], Any]]:
    """Build (once) and return the name -> handler registry for lazy imports.

    Every attribute name in a category's name tuple is mapped to that
    category's handler function — e.g. each name in ``UTILS_NAMES`` maps to
    ``_lazy_import_utils``. Subsequent calls return the cached registry.

    Returns:
        Dict like ``{"ModelResponse": _lazy_import_utils, ...}``.
    """
    global _LAZY_IMPORT_REGISTRY
    if _LAZY_IMPORT_REGISTRY is None:
        # (name tuple, handler) pairs, in the original registration order so
        # that any overlapping names resolve to the same handler as before.
        category_handlers = (
            (COST_CALCULATOR_NAMES, _lazy_import_cost_calculator),
            (LITELLM_LOGGING_NAMES, _lazy_import_litellm_logging),
            (UTILS_NAMES, _lazy_import_utils),
            (TOKEN_COUNTER_NAMES, _lazy_import_token_counter),
            (LLM_CLIENT_CACHE_NAMES, _lazy_import_llm_client_cache),
            (BEDROCK_TYPES_NAMES, _lazy_import_bedrock_types),
            (TYPES_UTILS_NAMES, _lazy_import_types_utils),
            (CACHING_NAMES, _lazy_import_caching),
            (HTTP_HANDLER_NAMES, _lazy_import_http_handlers),
            (DOTPROMPT_NAMES, _lazy_import_dotprompt),
            (LLM_CONFIG_NAMES, _lazy_import_llm_configs),
            (TYPES_NAMES, _lazy_import_types),
            (LLM_PROVIDER_LOGIC_NAMES, _lazy_import_llm_provider_logic),
            (UTILS_MODULE_NAMES, _lazy_import_utils_module),
        )
        registry: dict[str, Callable[[str], Any]] = {}
        for names, handler in category_handlers:
            for attr_name in names:
                registry[attr_name] = handler
        _LAZY_IMPORT_REGISTRY = registry

    return _LAZY_IMPORT_REGISTRY
|
||||
|
||||
|
||||
def _generic_lazy_import(
    name: str, import_map: dict[str, tuple[str, str]], category: str
) -> Any:
    """Resolve *name* through *import_map*, caching it on the litellm module.

    This is the workhorse that most handler functions delegate to. On a
    cache miss it imports the owning module, fetches the attribute, stores
    it in the litellm module globals, and returns it; on a hit the cached
    value is returned without importing anything.

    Args:
        name: Attribute being accessed (e.g. ``"ModelResponse"``).
        import_map: Maps names to ``(module_path, attr_name)`` pairs. A
            leading ``"."`` in ``module_path`` marks a relative import
            resolved against the ``litellm`` package.
        category: Label used only in the AttributeError message.

    Raises:
        AttributeError: If *name* is not present in *import_map*.
    """
    target = import_map.get(name)
    if target is None:
        raise AttributeError(f"{category} lazy import: unknown attribute {name!r}")

    cache = _get_litellm_globals()
    if name in cache:
        return cache[name]

    module_path, attr_name = target

    # importlib caches modules in sys.modules, so repeated imports are cheap.
    if module_path.startswith("."):
        module = importlib.import_module(module_path, package="litellm")
    else:
        module = importlib.import_module(module_path)

    value = getattr(module, attr_name)
    cache[name] = value
    return value
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# HANDLER FUNCTIONS
|
||||
# ============================================================================
|
||||
# These functions are called when someone accesses a lazy-loaded attribute.
|
||||
# Most of them just call _generic_lazy_import with their specific import map.
|
||||
# The registry (above) maps attribute names to these handler functions.
|
||||
|
||||
|
||||
def _lazy_import_utils(name: str) -> Any:
    """Handler for utils module attributes (ModelResponse, token_counter, etc.); delegates to _generic_lazy_import."""
    return _generic_lazy_import(name, _UTILS_IMPORT_MAP, "Utils")


def _lazy_import_cost_calculator(name: str) -> Any:
    """Handler for cost calculator functions (completion_cost, cost_per_token, etc.)."""
    return _generic_lazy_import(name, _COST_CALCULATOR_IMPORT_MAP, "Cost calculator")


def _lazy_import_token_counter(name: str) -> Any:
    """Handler for token counter utilities."""
    return _generic_lazy_import(name, _TOKEN_COUNTER_IMPORT_MAP, "Token counter")


def _lazy_import_bedrock_types(name: str) -> Any:
    """Handler for Bedrock type aliases."""
    return _generic_lazy_import(name, _BEDROCK_TYPES_IMPORT_MAP, "Bedrock types")


def _lazy_import_types_utils(name: str) -> Any:
    """Handler for types from litellm.types.utils (BudgetConfig, ImageObject, etc.)."""
    return _generic_lazy_import(name, _TYPES_UTILS_IMPORT_MAP, "Types utils")


def _lazy_import_caching(name: str) -> Any:
    """Handler for caching classes (Cache, DualCache, RedisCache, etc.)."""
    return _generic_lazy_import(name, _CACHING_IMPORT_MAP, "Caching")


def _lazy_import_dotprompt(name: str) -> Any:
    """Handler for dotprompt integration globals."""
    return _generic_lazy_import(name, _DOTPROMPT_IMPORT_MAP, "Dotprompt")


def _lazy_import_types(name: str) -> Any:
    """Handler for type classes (GuardrailItem, etc.)."""
    return _generic_lazy_import(name, _TYPES_IMPORT_MAP, "Types")


def _lazy_import_llm_configs(name: str) -> Any:
    """Handler for LLM config classes (AnthropicConfig, OpenAILikeChatConfig, etc.)."""
    return _generic_lazy_import(name, _LLM_CONFIGS_IMPORT_MAP, "LLM config")


def _lazy_import_litellm_logging(name: str) -> Any:
    """Handler for litellm_logging module attributes (Logging, modify_integration)."""
    return _generic_lazy_import(name, _LITELLM_LOGGING_IMPORT_MAP, "Litellm logging")


def _lazy_import_llm_provider_logic(name: str) -> Any:
    """Handler for LLM provider logic functions (get_llm_provider, etc.)."""
    return _generic_lazy_import(
        name, _LLM_PROVIDER_LOGIC_IMPORT_MAP, "LLM provider logic"
    )
|
||||
|
||||
|
||||
def _lazy_import_utils_module(name: str) -> Any:
    """Resolve lazily imported attributes of ``litellm.utils``.

    Mirrors ``_generic_lazy_import`` but caches results in the globals of
    the ``litellm.utils`` module (via ``_get_utils_globals``) instead of the
    top-level ``litellm`` module.

    Raises:
        AttributeError: If *name* is not in ``_UTILS_MODULE_IMPORT_MAP``.
    """
    try:
        module_path, attr_name = _UTILS_MODULE_IMPORT_MAP[name]
    except KeyError:
        raise AttributeError(
            f"Utils module lazy import: unknown attribute {name!r}"
        ) from None

    # Cache lives on litellm.utils, not on litellm itself.
    utils_cache = _get_utils_globals()
    if name in utils_cache:
        return utils_cache[name]

    # Relative paths resolve against the litellm package; absolute paths
    # import as-is. importlib caches modules in sys.modules.
    if module_path.startswith("."):
        module = importlib.import_module(module_path, package="litellm")
    else:
        module = importlib.import_module(module_path)

    value = getattr(module, attr_name)
    utils_cache[name] = value
    return value
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SPECIAL HANDLERS
|
||||
# ============================================================================
|
||||
# These handlers have custom logic that doesn't fit the generic pattern
|
||||
|
||||
|
||||
def _lazy_import_llm_client_cache(name: str) -> Any:
    """Resolve the LLM client cache class or its module-level singleton.

    Two names are supported:
      - ``"LLMClientCache"``: the class itself;
      - ``"in_memory_llm_clients_cache"``: a singleton instance, created
        exactly once and then cached on the litellm module globals.

    Raises:
        AttributeError: For any other attribute name.
    """
    cache = _get_litellm_globals()

    # Already resolved earlier — reuse it.
    if name in cache:
        return cache[name]

    handler_module = importlib.import_module("litellm.caching.llm_caching_handler")
    llm_client_cache_cls = getattr(handler_module, "LLMClientCache")

    if name == "LLMClientCache":
        cache["LLMClientCache"] = llm_client_cache_cls
        return llm_client_cache_cls

    if name == "in_memory_llm_clients_cache":
        # Instantiate the singleton on first access only.
        singleton = llm_client_cache_cls()
        cache["in_memory_llm_clients_cache"] = singleton
        return singleton

    raise AttributeError(f"LLM client cache lazy import: unknown attribute {name!r}")
|
||||
|
||||
|
||||
def _lazy_import_http_handlers(name: str) -> Any:
    """Create and cache the module-level HTTP client instances.

    Unlike the plain lazy imports these are factory-built client objects,
    not re-exported attributes:

      - ``"module_level_aclient"``: async client via ``get_async_httpx_client``
      - ``"module_level_client"``: sync ``HTTPHandler``

    Both pick up ``request_timeout`` from the litellm module globals (if
    set) and are stored there after creation, so ``__getattr__`` is not
    invoked again for the same name.

    Raises:
        AttributeError: For any other attribute name.
    """
    cache = _get_litellm_globals()
    # Module-level timeout configuration, if the user set one.
    timeout = cache.get("request_timeout")

    if name == "module_level_aclient":
        from litellm.llms.custom_httpx.http_handler import get_async_httpx_client

        async_client = get_async_httpx_client(
            llm_provider=cast(Any, "litellm_module_level_client"),
            params={"timeout": timeout, "client_alias": "module level aclient"},
        )
        cache["module_level_aclient"] = async_client
        return async_client

    if name == "module_level_client":
        from litellm.llms.custom_httpx.http_handler import HTTPHandler

        sync_client = HTTPHandler(timeout=timeout)
        cache["module_level_client"] = sync_client
        return sync_client

    raise AttributeError(f"HTTP handlers lazy import: unknown attribute {name!r}")
|
||||
File diff suppressed because it is too large
Load Diff
352
llm-gateway-competitors/litellm-wheel-src/litellm/_logging.py
Normal file
352
llm-gateway-competitors/litellm-wheel-src/litellm/_logging.py
Normal file
@@ -0,0 +1,352 @@
|
||||
import ast
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from logging import Formatter
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
from litellm.litellm_core_utils.safe_json_loads import safe_json_loads
|
||||
|
||||
# Deprecated verbosity flag, kept for backward compatibility with callers
# that still set `litellm.set_verbose` directly.
set_verbose = False

if set_verbose is True:
    logging.warning(
        "`litellm.set_verbose` is deprecated. Please set `os.environ['LITELLM_LOG'] = 'DEBUG'` for debug logs."
    )
# NOTE(review): any non-empty JSON_LOGS value — including "False" or "0" —
# enables JSON logs, because bool() is applied to the raw string. Confirm
# this is intended before tightening the parsing.
json_logs = bool(os.getenv("JSON_LOGS", False))
# Create a handler for the logger (you may need to adapt this based on your needs)
log_level = os.getenv("LITELLM_LOG", "DEBUG")
# Logging levels are ints (logging.DEBUG == 10, etc.); the previous `str`
# annotation was simply wrong. An unknown LITELLM_LOG value still raises
# AttributeError here, preserving the original fail-fast behavior.
numeric_level: int = getattr(logging, log_level.upper())
handler = logging.StreamHandler()
handler.setLevel(numeric_level)
|
||||
|
||||
|
||||
def _try_parse_json_message(message: str) -> Optional[Dict[str, Any]]:
    """Parse *message* as JSON, returning the dict or None.

    Only messages whose stripped form begins with ``{`` or ``[`` are
    attempted (i.e. entirely-JSON messages such as json.dumps output), and
    only dict results are accepted — lists and scalars yield None. Parsing
    is delegated to the shared ``safe_json_loads`` for consistent error
    handling.
    """
    if not isinstance(message, str) or not message:
        return None
    stripped = message.strip()
    # Cheap pre-filter: skip messages that cannot be a JSON object/array.
    if stripped[:1] not in ("{", "["):
        return None
    candidate = safe_json_loads(message, default=None)
    return candidate if isinstance(candidate, dict) else None
|
||||
|
||||
|
||||
def _try_parse_embedded_python_dict(message: str) -> Optional[Dict[str, Any]]:
    """Find and parse a Python dict repr embedded anywhere in *message*.

    Scans left-to-right for brace-balanced substrings and evaluates each
    candidate with ``ast.literal_eval`` (safe — no arbitrary code is
    executed). The first non-empty dict found is returned; None otherwise.

    Handles log lines such as:
        "get_available_deployment for model: X, Selected deployment: {'model_name': '...', ...} for model: X"
    """
    if not isinstance(message, str) or not message or "{" not in message:
        return None

    length = len(message)
    scan_pos = 0
    while scan_pos < length:
        open_idx = message.find("{", scan_pos)
        if open_idx == -1:
            return None
        # Walk forward to the brace that balances the one at open_idx.
        depth = 0
        for close_idx in range(open_idx, length):
            ch = message[close_idx]
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    candidate = message[open_idx : close_idx + 1]
                    try:
                        parsed = ast.literal_eval(candidate)
                        # Empty dicts are skipped; keep scanning after them.
                        if isinstance(parsed, dict) and parsed:
                            return parsed
                    except (ValueError, SyntaxError, TypeError):
                        pass
                    break
        # Retry from just past this opening brace (covers nested/failed cases).
        scan_pos = open_idx + 1
    return None
|
||||
|
||||
|
||||
# Standard LogRecord attribute names, used to tell `extra={...}` keys apart
# from built-in record fields. Computed at runtime so version-specific
# attributes (e.g. taskName on newer Pythons) are covered automatically.
def _get_standard_record_attrs() -> frozenset:
    """Return the attribute names of a vanilla ``logging.LogRecord``."""
    probe_record = logging.LogRecord("", 0, "", 0, "", (), None)
    return frozenset(probe_record.__dict__)


# Snapshot used by JsonFormatter to surface only caller-supplied extras.
_STANDARD_RECORD_ATTRS = _get_standard_record_attrs()
|
||||
|
||||
|
||||
class JsonFormatter(Formatter):
    """Formats LogRecords as JSON objects (via safe_dumps).

    Embedded JSON / Python-dict payloads found in the message, plus any
    `extra={...}` fields on the record, are promoted to first-class keys of
    the emitted object. The base fields (message/level/timestamp) always win
    on key collisions.
    """

    def __init__(self):
        super(JsonFormatter, self).__init__()

    def formatTime(self, record, datefmt=None):
        # Use datetime to format the timestamp in ISO 8601 format
        # (local time — record.created is interpreted by fromtimestamp).
        dt = datetime.fromtimestamp(record.created)
        return dt.isoformat()

    def format(self, record):
        message_str = record.getMessage()
        # Base fields; later merge steps never overwrite keys already here.
        json_record: Dict[str, Any] = {
            "message": message_str,
            "level": record.levelname,
            "timestamp": self.formatTime(record),
        }

        # Parse embedded JSON or Python dict repr in message so sub-fields become first-class properties
        parsed = _try_parse_json_message(message_str)
        if parsed is None:
            parsed = _try_parse_embedded_python_dict(message_str)
        if parsed is not None:
            for key, value in parsed.items():
                if key not in json_record:
                    json_record[key] = value

        # Include extra attributes passed via logger.debug("msg", extra={...})
        # — anything absent from the standard LogRecord attribute set and not
        # already claimed by the base fields or the parsed payload.
        for key, value in record.__dict__.items():
            if key not in _STANDARD_RECORD_ATTRS and key not in json_record:
                json_record[key] = value

        if record.exc_info:
            json_record["stacktrace"] = self.formatException(record.exc_info)

        return safe_dumps(json_record)
|
||||
|
||||
|
||||
# Function to set up exception handlers for JSON logging
|
||||
def _setup_json_exception_handlers(formatter):
    """Route uncaught exceptions (sync and asyncio) through *formatter*.

    Installs a global ``sys.excepthook`` that logs uncaught exceptions as
    ERROR records formatted by *formatter*, and — best-effort — an asyncio
    loop exception handler doing the same.
    """
    # Create a handler with JSON formatting for exceptions
    error_handler = logging.StreamHandler()
    error_handler.setFormatter(formatter)

    # Setup excepthook for uncaught exceptions
    def json_excepthook(exc_type, exc_value, exc_traceback):
        # Synthesize a LogRecord so the exception goes through the same
        # formatter/handler path as normal log output.
        record = logging.LogRecord(
            name="LiteLLM",
            level=logging.ERROR,
            pathname="",
            lineno=0,
            msg=str(exc_value),
            args=(),
            exc_info=(exc_type, exc_value, exc_traceback),
        )
        error_handler.handle(record)

    sys.excepthook = json_excepthook

    # Configure asyncio exception handler if possible
    try:
        import asyncio

        def async_json_exception_handler(loop, context):
            exception = context.get("exception")
            if exception:
                record = logging.LogRecord(
                    name="LiteLLM",
                    level=logging.ERROR,
                    pathname="",
                    lineno=0,
                    msg=str(exception),
                    args=(),
                    exc_info=None,
                )
                error_handler.handle(record)
            else:
                # Non-exception contexts (e.g. pending-task warnings) keep
                # asyncio's default reporting.
                loop.default_exception_handler(context)

        # NOTE(review): asyncio.get_event_loop() is deprecated when no loop is
        # running (Python 3.10+); any failure here is deliberately swallowed
        # by the except below — confirm this best-effort behavior is intended.
        asyncio.get_event_loop().set_exception_handler(async_json_exception_handler)
    except Exception:
        pass
|
||||
|
||||
|
||||
# Create a formatter and set it for the handler
if json_logs:
    # JSON mode: format records as JSON and route uncaught exceptions
    # through JSON formatting as well.
    handler.setFormatter(JsonFormatter())
    _setup_json_exception_handlers(JsonFormatter())
else:
    # Human-readable mode: ANSI-green prefix with time/name/level, then
    # source file:line and the message.
    formatter = logging.Formatter(
        "\033[92m%(asctime)s - %(name)s:%(levelname)s\033[0m: %(filename)s:%(lineno)s - %(message)s",
        datefmt="%H:%M:%S",
    )

    handler.setFormatter(formatter)

# Package-level loggers shared by litellm, its router, and the proxy.
verbose_proxy_logger = logging.getLogger("LiteLLM Proxy")
verbose_router_logger = logging.getLogger("LiteLLM Router")
verbose_logger = logging.getLogger("LiteLLM")

# Add the handler to the logger
verbose_router_logger.addHandler(handler)
verbose_proxy_logger.addHandler(handler)
verbose_logger.addHandler(handler)
|
||||
|
||||
|
||||
def _suppress_loggers():
    """Raise noisy third-party loggers (httpx, APScheduler) to WARNING."""
    noisy_logger_names = (
        "httpx",  # per-request logging at INFO
        "apscheduler.executors.default",
        "apscheduler.scheduler",
    )
    for logger_name in noisy_logger_names:
        logging.getLogger(logger_name).setLevel(logging.WARNING)


# Apply the suppression as soon as this module is imported.
_suppress_loggers()
|
||||
|
||||
# Loggers managed by litellm's logging setup: the root logger plus the three
# package loggers (package, router, proxy).
ALL_LOGGERS = [
    logging.getLogger(),  # root logger
    verbose_logger,
    verbose_router_logger,
    verbose_proxy_logger,
]
|
||||
|
||||
|
||||
def _get_loggers_to_initialize():
    """Return the loggers that should receive the JSON handler.

    Starts from a copy of ``ALL_LOGGERS`` and appends third-party
    integration loggers (currently langfuse) when they are configured as
    litellm success/failure callbacks.
    """
    # Imported here to avoid a circular import at module load.
    import litellm

    target_loggers = list(ALL_LOGGERS)

    # Pick up the langfuse logger only when langfuse is actually in use.
    configured_callbacks = set(litellm.success_callback + litellm.failure_callback)
    if configured_callbacks & {"langfuse", "langfuse_otel"}:
        target_loggers.append(logging.getLogger("langfuse"))

    return target_loggers
|
||||
|
||||
|
||||
def _initialize_loggers_with_handler(handler: logging.Handler):
    """Attach *handler* exclusively to every managed logger.

    Existing handlers are removed and propagation to parent/root loggers is
    disabled, so each record is emitted exactly once — this prevents
    duplicate JSON log lines.
    """
    for target_logger in _get_loggers_to_initialize():
        target_logger.handlers.clear()
        target_logger.addHandler(handler)
        target_logger.propagate = False
|
||||
|
||||
|
||||
def _get_uvicorn_json_log_config():
    """Build a uvicorn ``log_config`` dict that emits JSON on stdout.

    Applies ``JsonFormatter`` to uvicorn's default, error, and access
    loggers so that access logs, error logs, and application logs are all
    JSON when json_logs is enabled.
    """
    formatter_path = "litellm._logging.JsonFormatter"
    # Reuse the module-level LITELLM_LOG level so uvicorn matches litellm.
    uvicorn_log_level = log_level.upper()

    stdout_handler = {"class": "logging.StreamHandler", "stream": "ext://sys.stdout"}

    def _json_formatter():
        # Fresh dict per formatter entry (dictConfig mutates its input).
        return {"()": formatter_path}

    def _logger_entry(handler_name):
        return {
            "handlers": [handler_name],
            "level": uvicorn_log_level,
            "propagate": False,
        }

    return {
        "version": 1,
        "disable_existing_loggers": False,
        "formatters": {
            "json": _json_formatter(),
            "default": _json_formatter(),
            "access": _json_formatter(),
        },
        "handlers": {
            "default": {"formatter": "json", **stdout_handler},
            "access": {"formatter": "access", **stdout_handler},
        },
        "loggers": {
            "uvicorn": _logger_entry("default"),
            "uvicorn.error": _logger_entry("default"),
            "uvicorn.access": _logger_entry("access"),
        },
    }
|
||||
|
||||
|
||||
def _turn_on_json():
    """Switch all managed loggers to JSON output.

    Installs a fresh StreamHandler carrying JsonFormatter on every managed
    logger, and routes uncaught/asyncio exceptions through JSON formatting
    as well.
    """
    json_handler = logging.StreamHandler()
    json_handler.setFormatter(JsonFormatter())

    _initialize_loggers_with_handler(json_handler)
    # Also capture uncaught exceptions as JSON.
    _setup_json_exception_handlers(JsonFormatter())
|
||||
|
||||
|
||||
def _turn_on_debug():
    """Set all three litellm loggers (package, router, proxy) to DEBUG."""
    verbose_logger.setLevel(level=logging.DEBUG)  # set package log to debug
    verbose_router_logger.setLevel(level=logging.DEBUG)  # set router logs to debug
    verbose_proxy_logger.setLevel(level=logging.DEBUG)  # set proxy logs to debug


def _disable_debugging():
    """Silence the litellm loggers entirely (logger.disabled skips all emission)."""
    verbose_logger.disabled = True
    verbose_router_logger.disabled = True
    verbose_proxy_logger.disabled = True


def _enable_debugging():
    """Re-enable the litellm loggers after a _disable_debugging() call."""
    verbose_logger.disabled = False
    verbose_router_logger.disabled = False
    verbose_proxy_logger.disabled = False
|
||||
|
||||
|
||||
def print_verbose(print_statement):
    """Print *print_statement* when the global ``set_verbose`` flag is truthy.

    Never raises — any error (including printing failures) is swallowed.
    """
    try:
        verbose_enabled = set_verbose
        if verbose_enabled:
            print(print_statement)  # noqa
    except Exception:
        pass
|
||||
|
||||
|
||||
def _is_debugging_on() -> bool:
    """Return True when debug logging is active.

    Either the package logger is at DEBUG level, or the deprecated
    ``set_verbose`` flag is set.
    """
    debug_enabled = verbose_logger.isEnabledFor(logging.DEBUG)
    return debug_enabled or set_verbose is True
|
||||
598
llm-gateway-competitors/litellm-wheel-src/litellm/_redis.py
Normal file
598
llm-gateway-competitors/litellm-wheel-src/litellm/_redis.py
Normal file
@@ -0,0 +1,598 @@
|
||||
# +-----------------------------------------------+
|
||||
# | |
|
||||
# | Give Feedback / Get Help |
|
||||
# | https://github.com/BerriAI/litellm/issues/new |
|
||||
# | |
|
||||
# +-----------------------------------------------+
|
||||
#
|
||||
# Thank you users! We ❤️ you! - Krrish & Ishaan
|
||||
|
||||
import inspect
|
||||
import json
|
||||
|
||||
# s/o [@Frank Colson](https://www.linkedin.com/in/frank-colson-422b9b183/) for this redis implementation
|
||||
import os
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import redis # type: ignore
|
||||
import redis.asyncio as async_redis # type: ignore
|
||||
|
||||
from litellm import get_secret, get_secret_str
|
||||
from litellm.constants import REDIS_CONNECTION_POOL_TIMEOUT, REDIS_SOCKET_TIMEOUT
|
||||
from litellm.litellm_core_utils.sensitive_data_masker import SensitiveDataMasker
|
||||
|
||||
from ._logging import verbose_logger
|
||||
|
||||
|
||||
def _get_redis_kwargs():
    """List the keyword arguments accepted when constructing a ``redis.Redis`` client.

    Starts from ``redis.Redis.__init__``'s signature, drops non-primitive /
    internal parameters, and appends LiteLLM-specific extras.
    """
    # Non-primitive / internal parameters we never forward.
    blocked = {
        "self",
        "connection_pool",
        "retry",
    }
    # LiteLLM-specific extras that are not part of redis.Redis's signature.
    extras = [
        "url",
        "redis_connect_func",
        "gcp_service_account",
        "gcp_ssl_ca_certs",
    ]

    spec = inspect.getfullargspec(redis.Redis)
    return [name for name in spec.args if name not in blocked] + extras
||||
def _get_redis_url_kwargs(client=None):
|
||||
if client is None:
|
||||
client = redis.Redis.from_url
|
||||
arg_spec = inspect.getfullargspec(redis.Redis.from_url)
|
||||
|
||||
# Only allow primitive arguments
|
||||
exclude_args = {
|
||||
"self",
|
||||
"connection_pool",
|
||||
"retry",
|
||||
}
|
||||
|
||||
include_args = ["url"]
|
||||
|
||||
available_args = [x for x in arg_spec.args if x not in exclude_args] + include_args
|
||||
|
||||
return available_args
|
||||
|
||||
|
||||
def _get_redis_cluster_kwargs(client=None):
    """List the keyword arguments accepted when building a ``redis.RedisCluster``.

    NOTE(review): the ``client`` parameter is currently unused — the argspec is
    always taken from ``redis.RedisCluster``. Kept for interface parity with
    ``_get_redis_url_kwargs``; confirm whether it should be honored.
    """
    if client is None:
        client = redis.Redis.from_url
    spec = inspect.getfullargspec(redis.RedisCluster)

    # Only allow primitive arguments
    blocked = {"self", "connection_pool", "retry", "host", "port", "startup_nodes"}
    allowed = [name for name in spec.args if name not in blocked]

    # Extras accepted on top of RedisCluster's own signature (order preserved).
    allowed += [
        "password",
        "username",
        "ssl",
        "ssl_cert_reqs",
        "ssl_check_hostname",
        "ssl_ca_certs",
        "redis_connect_func",  # Needed for sync clusters and IAM detection
        "gcp_service_account",
        "gcp_ssl_ca_certs",
        "max_connections",
    ]

    return allowed
||||
def _get_redis_env_kwarg_mapping():
    """Map ``REDIS_<ARG>`` environment-variable names to redis kwarg names."""
    mapping = {}
    for kwarg in _get_redis_kwargs():
        mapping["REDIS_" + kwarg.upper()] = kwarg
    return mapping
||||
def _redis_kwargs_from_environment():
    """Collect redis kwargs whose ``REDIS_*`` environment variables are set.

    Unset variables are skipped; values are resolved through ``get_secret`` so
    secret-manager indirection works too.
    """
    resolved = {}
    for env_name, kwarg_name in _get_redis_env_kwarg_mapping().items():
        secret = get_secret(env_name, default_value=None)  # type: ignore
        if secret is None:
            continue
        resolved[kwarg_name] = secret
    return resolved
||||
def _generate_gcp_iam_access_token(service_account: str) -> str:
    """
    Generate a GCP IAM access token for Redis authentication.

    Args:
        service_account: GCP service account in format
            'projects/-/serviceAccounts/name@project.iam.gserviceaccount.com'

    Returns:
        Access token string for GCP IAM authentication

    Raises:
        ImportError: when the optional ``google-cloud-iam`` dependency is missing.
    """
    try:
        from google.cloud import iam_credentials_v1
    except ImportError:
        raise ImportError(
            "google-cloud-iam is required for GCP IAM Redis authentication. "
            "Install it with: pip install google-cloud-iam"
        )

    iam_client = iam_credentials_v1.IAMCredentialsClient()
    token_request = iam_credentials_v1.GenerateAccessTokenRequest(
        name=service_account,
        scope=["https://www.googleapis.com/auth/cloud-platform"],
    )
    token_response = iam_client.generate_access_token(request=token_request)
    return str(token_response.access_token)
||||
def create_gcp_iam_redis_connect_func(
    service_account: str,
    ssl_ca_certs: Optional[str] = None,
) -> Callable:
    """
    Creates a custom Redis connection function for GCP IAM authentication.

    Args:
        service_account: GCP service account in format 'projects/-/serviceAccounts/name@project.iam.gserviceaccount.com'
        ssl_ca_certs: Path to SSL CA certificate file for secure connections

    Returns:
        A connection function that can be used with Redis clients

    NOTE(review): ``ssl_ca_certs`` is accepted but not used inside this
    function; SSL CA wiring happens in ``_get_redis_client_logic``. Confirm
    the parameter is still needed here.
    """

    def iam_connect(self):
        """Initialize the connection and authenticate using GCP IAM"""
        from redis.exceptions import (
            AuthenticationError,
            AuthenticationWrongNumberOfArgsError,
        )
        from redis.utils import str_if_bytes

        # Run the protocol parser's normal on-connect setup before AUTH.
        self._parser.on_connect(self)

        # Use a freshly minted IAM token as the AUTH credential.
        auth_args = (_generate_gcp_iam_access_token(service_account),)
        self.send_command("AUTH", *auth_args, check_health=False)

        try:
            auth_response = self.read_response()
        except AuthenticationWrongNumberOfArgsError:
            # Fallback to password auth if IAM fails
            if hasattr(self, "password") and self.password:
                self.send_command("AUTH", self.password, check_health=False)
                auth_response = self.read_response()
            else:
                raise

        # Redis replies "OK" on successful AUTH; anything else is a failure.
        if str_if_bytes(auth_response) != "OK":
            raise AuthenticationError("GCP IAM authentication failed")

    return iam_connect
||||
def get_redis_url_from_environment():
    """Build a Redis connection URL from ``REDIS_*`` environment variables.

    An explicit ``REDIS_URL`` wins outright; otherwise the URL is assembled
    from ``REDIS_HOST``/``REDIS_PORT`` plus optional SSL and credential vars.

    Raises:
        ValueError: when neither REDIS_URL nor both REDIS_HOST and REDIS_PORT
            are set.
    """
    env = os.environ
    if "REDIS_URL" in env:
        return env["REDIS_URL"]

    if "REDIS_HOST" not in env or "REDIS_PORT" not in env:
        raise ValueError(
            "Either 'REDIS_URL' or both 'REDIS_HOST' and 'REDIS_PORT' must be specified for Redis."
        )

    scheme = "rediss" if env.get("REDIS_SSL", "").lower() == "true" else "redis"

    # Build authentication part of URL
    credentials = ""
    if "REDIS_USERNAME" in env and "REDIS_PASSWORD" in env:
        credentials = f"{env['REDIS_USERNAME']}:{env['REDIS_PASSWORD']}@"
    elif "REDIS_PASSWORD" in env:
        credentials = f"{env['REDIS_PASSWORD']}@"

    return f"{scheme}://{credentials}{env['REDIS_HOST']}:{env['REDIS_PORT']}"
||||
def _get_redis_client_logic(**env_overrides):
    """
    Common functionality across sync + async redis client implementations

    Resolution order: environment-derived kwargs first, then explicit
    ``env_overrides`` (overrides win). Cluster/sentinel node lists given as
    JSON strings are decoded; GCP IAM auth is wired via a custom connect
    function. Returns the final kwargs dict for the redis client constructors.

    Raises:
        ValueError: when neither 'url', 'startup_nodes', 'sentinel_nodes',
            nor 'host' is available.
    """
    ### check if "os.environ/<key-name>" passed in
    for k, v in env_overrides.items():
        if isinstance(v, str) and v.startswith("os.environ/"):
            v = v.replace("os.environ/", "")
            value = get_secret(v)  # type: ignore
            env_overrides[k] = value

    # Explicit overrides take precedence over environment-derived values.
    redis_kwargs = {
        **_redis_kwargs_from_environment(),
        **env_overrides,
    }

    _startup_nodes: Optional[Union[str, list]] = redis_kwargs.get("startup_nodes", None) or get_secret(  # type: ignore
        "REDIS_CLUSTER_NODES"
    )

    # Cluster nodes may arrive as a JSON string — decode to a list of dicts.
    if _startup_nodes is not None and isinstance(_startup_nodes, str):
        redis_kwargs["startup_nodes"] = json.loads(_startup_nodes)

    _sentinel_nodes: Optional[Union[str, list]] = redis_kwargs.get("sentinel_nodes", None) or get_secret(  # type: ignore
        "REDIS_SENTINEL_NODES"
    )

    # Same JSON-string handling for sentinel nodes.
    if _sentinel_nodes is not None and isinstance(_sentinel_nodes, str):
        redis_kwargs["sentinel_nodes"] = json.loads(_sentinel_nodes)

    _sentinel_password: Optional[str] = redis_kwargs.get(
        "sentinel_password", None
    ) or get_secret_str("REDIS_SENTINEL_PASSWORD")

    if _sentinel_password is not None:
        redis_kwargs["sentinel_password"] = _sentinel_password

    _service_name: Optional[str] = redis_kwargs.get("service_name", None) or get_secret(  # type: ignore
        "REDIS_SERVICE_NAME"
    )

    if _service_name is not None:
        redis_kwargs["service_name"] = _service_name

    # Handle GCP IAM authentication
    _gcp_service_account = redis_kwargs.get("gcp_service_account") or get_secret_str(
        "REDIS_GCP_SERVICE_ACCOUNT"
    )
    _gcp_ssl_ca_certs = redis_kwargs.get("gcp_ssl_ca_certs") or get_secret_str(
        "REDIS_GCP_SSL_CA_CERTS"
    )

    if _gcp_service_account is not None:
        verbose_logger.debug(
            "Setting up GCP IAM authentication for Redis with service account."
        )
        redis_kwargs["redis_connect_func"] = create_gcp_iam_redis_connect_func(
            service_account=_gcp_service_account, ssl_ca_certs=_gcp_ssl_ca_certs
        )
        # Store GCP service account in redis_connect_func for async cluster access
        redis_kwargs["redis_connect_func"]._gcp_service_account = _gcp_service_account

    # Remove GCP-specific kwargs that shouldn't be passed to Redis client
    redis_kwargs.pop("gcp_service_account", None)
    redis_kwargs.pop("gcp_ssl_ca_certs", None)

    # Only enable SSL if explicitly requested AND SSL CA certs are provided
    if _gcp_ssl_ca_certs and redis_kwargs.get("ssl", False):
        redis_kwargs["ssl_ca_certs"] = _gcp_ssl_ca_certs

    if "url" in redis_kwargs and redis_kwargs["url"] is not None:
        # URL encodes host/port/db/password — drop the redundant pieces.
        redis_kwargs.pop("host", None)
        redis_kwargs.pop("port", None)
        redis_kwargs.pop("db", None)
        redis_kwargs.pop("password", None)
    elif "startup_nodes" in redis_kwargs and redis_kwargs["startup_nodes"] is not None:
        pass
    elif (
        "sentinel_nodes" in redis_kwargs and redis_kwargs["sentinel_nodes"] is not None
    ):
        pass
    elif "host" not in redis_kwargs or redis_kwargs["host"] is None:
        raise ValueError("Either 'host' or 'url' must be specified for redis.")

    # litellm.print_verbose(f"redis_kwargs: {redis_kwargs}")
    return redis_kwargs
|
||||
def init_redis_cluster(redis_kwargs) -> redis.RedisCluster:
    """Build a synchronous ``redis.RedisCluster`` from kwargs / environment.

    ``REDIS_CLUSTER_NODES`` (JSON list of node dicts), when set, overrides any
    ``startup_nodes`` already present in *redis_kwargs*.
    """
    env_nodes: Optional[str] = get_secret("REDIS_CLUSTER_NODES")  # type: ignore
    if env_nodes is not None:
        try:
            redis_kwargs["startup_nodes"] = json.loads(env_nodes)
        except json.JSONDecodeError:
            raise ValueError(
                "REDIS_CLUSTER_NODES environment variable is not valid JSON. Please ensure it's properly formatted."
            )

    verbose_logger.debug("init_redis_cluster: startup nodes are being initialized.")
    from redis.cluster import ClusterNode

    # Forward only the kwargs RedisCluster actually accepts.
    allowed = _get_redis_cluster_kwargs()
    cluster_kwargs = {k: v for k, v in redis_kwargs.items() if k in allowed}

    node_objs: List[ClusterNode] = [
        ClusterNode(**node) for node in redis_kwargs["startup_nodes"]
    ]

    cluster_kwargs.pop("startup_nodes", None)
    return redis.RedisCluster(startup_nodes=node_objs, **cluster_kwargs)  # type: ignore
|
||||
def _init_redis_sentinel(redis_kwargs) -> redis.Redis:
    """Connect via Redis Sentinel and return the master client for the service.

    Raises:
        ValueError: when 'sentinel_nodes' or 'service_name' is missing/empty.
    """
    sentinel_nodes = redis_kwargs.get("sentinel_nodes")
    sentinel_password = redis_kwargs.get("sentinel_password")
    service_name = redis_kwargs.get("service_name")

    if not sentinel_nodes or not service_name:
        raise ValueError(
            "Both 'sentinel_nodes' and 'service_name' are required for Redis Sentinel."
        )

    verbose_logger.debug("init_redis_sentinel: sentinel nodes are being initialized.")

    # Set up the Sentinel client and hand back the master for the service.
    sentinel_client = redis.Sentinel(
        sentinel_nodes,
        socket_timeout=REDIS_SOCKET_TIMEOUT,
        password=sentinel_password,
    )
    return sentinel_client.master_for(service_name)
|
||||
def _init_async_redis_sentinel(redis_kwargs) -> async_redis.Redis:
    """Connect via async Redis Sentinel and return the master client.

    Raises:
        ValueError: when 'sentinel_nodes' or 'service_name' is missing/empty.
    """
    sentinel_nodes = redis_kwargs.get("sentinel_nodes")
    sentinel_password = redis_kwargs.get("sentinel_password")
    service_name = redis_kwargs.get("service_name")

    if not sentinel_nodes or not service_name:
        raise ValueError(
            "Both 'sentinel_nodes' and 'service_name' are required for Redis Sentinel."
        )

    verbose_logger.debug("init_redis_sentinel: sentinel nodes are being initialized.")

    # Set up the async Sentinel client and hand back the master for the service.
    sentinel_client = async_redis.Sentinel(
        sentinel_nodes,
        socket_timeout=REDIS_SOCKET_TIMEOUT,
        password=sentinel_password,
    )
    return sentinel_client.master_for(service_name)
|
||||
def get_redis_client(**env_overrides):
    """Create a synchronous Redis client.

    Dispatch order: URL-based client, then cluster, then sentinel, then a
    plain ``redis.Redis`` built from the resolved kwargs.
    """
    redis_kwargs = _get_redis_client_logic(**env_overrides)

    if redis_kwargs.get("url") is not None:
        allowed = _get_redis_url_kwargs()
        url_kwargs = {k: v for k, v in redis_kwargs.items() if k in allowed}
        return redis.Redis.from_url(**url_kwargs)

    if "startup_nodes" in redis_kwargs or get_secret("REDIS_CLUSTER_NODES") is not None:  # type: ignore
        return init_redis_cluster(redis_kwargs)

    # Check for Redis Sentinel
    if "sentinel_nodes" in redis_kwargs and "service_name" in redis_kwargs:
        return _init_redis_sentinel(redis_kwargs)

    return redis.Redis(**redis_kwargs)
|
||||
def get_redis_async_client(
    connection_pool: Optional[async_redis.BlockingConnectionPool] = None,
    **env_overrides,
) -> Union[async_redis.Redis, async_redis.RedisCluster]:
    """Create an asynchronous Redis client (URL, cluster, sentinel, or plain).

    Args:
        connection_pool: Optional pre-built pool; when given with a URL config
            it short-circuits and wraps the pool directly.
        **env_overrides: kwargs merged over environment-derived settings by
            ``_get_redis_client_logic``.
    """
    redis_kwargs = _get_redis_client_logic(**env_overrides)
    if "url" in redis_kwargs and redis_kwargs["url"] is not None:
        if connection_pool is not None:
            return async_redis.Redis(connection_pool=connection_pool)
        args = _get_redis_url_kwargs(client=async_redis.Redis.from_url)
        url_kwargs = {}
        for arg in redis_kwargs:
            if arg in args:
                url_kwargs[arg] = redis_kwargs[arg]
            else:
                verbose_logger.debug(
                    "REDIS: ignoring argument: {}. Not an allowed async_redis.Redis.from_url arg.".format(
                        arg
                    )
                )
        return async_redis.Redis.from_url(**url_kwargs)

    if "startup_nodes" in redis_kwargs:
        from redis.cluster import ClusterNode

        # Forward only the kwargs RedisCluster actually accepts.
        args = _get_redis_cluster_kwargs()
        cluster_kwargs = {}
        for arg in redis_kwargs:
            if arg in args:
                cluster_kwargs[arg] = redis_kwargs[arg]

        # Handle GCP IAM authentication for async clusters
        redis_connect_func = cluster_kwargs.pop("redis_connect_func", None)
        from litellm import get_secret_str

        # Get GCP service account - first try from redis_connect_func, then from environment
        gcp_service_account = None
        if redis_connect_func and hasattr(redis_connect_func, "_gcp_service_account"):
            gcp_service_account = redis_connect_func._gcp_service_account
        else:
            gcp_service_account = redis_kwargs.get(
                "gcp_service_account"
            ) or get_secret_str("REDIS_GCP_SERVICE_ACCOUNT")

        verbose_logger.debug(
            f"DEBUG: Redis cluster kwargs: redis_connect_func={redis_connect_func is not None}, gcp_service_account_provided={gcp_service_account is not None}"
        )

        # If GCP IAM is configured (indicated by redis_connect_func), generate access token and use as password
        if redis_connect_func and gcp_service_account:
            verbose_logger.debug(
                "DEBUG: Generating IAM token for service account (value not logged for security reasons)"
            )
            try:
                # Generate IAM access token using the helper function
                access_token = _generate_gcp_iam_access_token(gcp_service_account)
                cluster_kwargs["password"] = access_token
                verbose_logger.debug(
                    "DEBUG: Successfully generated GCP IAM access token for async Redis cluster"
                )
            except Exception as e:
                verbose_logger.error(f"Failed to generate GCP IAM access token: {e}")
                from redis.exceptions import AuthenticationError

                raise AuthenticationError("Failed to generate GCP IAM access token")
        else:
            verbose_logger.debug(
                f"DEBUG: Not using GCP IAM auth - redis_connect_func={redis_connect_func is not None}, gcp_service_account_provided={gcp_service_account is not None}"
            )

        # Convert node dicts into ClusterNode objects for the constructor.
        new_startup_nodes: List[ClusterNode] = []

        for item in redis_kwargs["startup_nodes"]:
            new_startup_nodes.append(ClusterNode(**item))
        cluster_kwargs.pop("startup_nodes", None)

        # Create async RedisCluster with IAM token as password if available
        cluster_client = async_redis.RedisCluster(
            startup_nodes=new_startup_nodes, **cluster_kwargs  # type: ignore
        )

        return cluster_client

    # Check for Redis Sentinel
    if "sentinel_nodes" in redis_kwargs and "service_name" in redis_kwargs:
        return _init_async_redis_sentinel(redis_kwargs)

    # Debug-only dump of the resolved (masked) configuration.
    _pretty_print_redis_config(redis_kwargs=redis_kwargs)

    if connection_pool is not None:
        redis_kwargs["connection_pool"] = connection_pool

    return async_redis.Redis(
        **redis_kwargs,
    )
|
||||
def get_redis_connection_pool(**env_overrides):
    """Build an async ``BlockingConnectionPool`` from env vars / overrides.

    URL-based configs go through ``from_url`` (honoring ``max_connections``);
    otherwise kwargs are forwarded to the pool constructor, with the
    connection class switched to SSL when requested.
    """
    redis_kwargs = _get_redis_client_logic(**env_overrides)
    # Bug fix: redis_kwargs was previously passed as a stray positional
    # argument with no %s placeholder, which makes the logging call malformed.
    verbose_logger.debug("get_redis_connection_pool: redis_kwargs %s", redis_kwargs)
    if "url" in redis_kwargs and redis_kwargs["url"] is not None:
        pool_kwargs = {
            "timeout": REDIS_CONNECTION_POOL_TIMEOUT,
            "url": redis_kwargs["url"],
        }
        if "max_connections" in redis_kwargs:
            try:
                pool_kwargs["max_connections"] = int(redis_kwargs["max_connections"])
            except (TypeError, ValueError):
                # Best-effort: a bad value is logged and dropped, not fatal.
                verbose_logger.warning(
                    "REDIS: invalid max_connections value %r, ignoring",
                    redis_kwargs["max_connections"],
                )
        return async_redis.BlockingConnectionPool.from_url(**pool_kwargs)

    connection_class = async_redis.Connection
    if "ssl" in redis_kwargs:
        # 'ssl' itself is not a pool kwarg — swap in the SSL connection class.
        connection_class = async_redis.SSLConnection
        redis_kwargs.pop("ssl", None)
        redis_kwargs["connection_class"] = connection_class
    redis_kwargs.pop("startup_nodes", None)
    return async_redis.BlockingConnectionPool(
        timeout=REDIS_CONNECTION_POOL_TIMEOUT, **redis_kwargs
    )
|
||||
def _pretty_print_redis_config(redis_kwargs: dict) -> None:
    """Pretty print the Redis configuration using rich with sensitive data masking.

    No-op unless the verbose logger is at DEBUG level. Falls back to plain
    logging when ``rich`` is unavailable; never raises (errors are logged).
    """
    try:
        import logging

        from rich.console import Console
        from rich.panel import Panel
        from rich.table import Table
        from rich.text import Text

        # Debug-only output: skip all formatting work otherwise.
        if not verbose_logger.isEnabledFor(logging.DEBUG):
            return

        console = Console()

        # Initialize the sensitive data masker
        masker = SensitiveDataMasker()

        # Mask sensitive data in redis_kwargs
        masked_redis_kwargs = masker.mask_dict(redis_kwargs)

        # Create main panel title
        title = Text("Redis Configuration", style="bold blue")

        # Create configuration table
        config_table = Table(
            title="🔧 Redis Connection Parameters",
            show_header=True,
            header_style="bold magenta",
            title_justify="left",
        )
        config_table.add_column("Parameter", style="cyan", no_wrap=True)
        config_table.add_column("Value", style="yellow")

        # Add rows for each configuration parameter (None values are omitted)
        for key, value in masked_redis_kwargs.items():
            if value is not None:
                # Special handling for complex objects
                if isinstance(value, list):
                    if key == "startup_nodes" and value:
                        # Special handling for cluster nodes
                        value_str = f"[{len(value)} cluster nodes]"
                    elif key == "sentinel_nodes" and value:
                        # Special handling for sentinel nodes
                        value_str = f"[{len(value)} sentinel nodes]"
                    else:
                        value_str = str(value)
                else:
                    value_str = str(value)

                config_table.add_row(key, value_str)

        # Determine connection type from the distinguishing kwargs
        connection_type = "Standard Redis"
        if masked_redis_kwargs.get("startup_nodes"):
            connection_type = "Redis Cluster"
        elif masked_redis_kwargs.get("sentinel_nodes"):
            connection_type = "Redis Sentinel"
        elif masked_redis_kwargs.get("url"):
            connection_type = "Redis (URL-based)"

        # Create connection type info
        info_table = Table(
            title="📊 Connection Info",
            show_header=True,
            header_style="bold green",
            title_justify="left",
        )
        info_table.add_column("Property", style="cyan", no_wrap=True)
        info_table.add_column("Value", style="yellow")
        info_table.add_row("Connection Type", connection_type)

        # Print everything in a nice panel
        console.print("\n")
        console.print(Panel(title, border_style="blue"))
        console.print(info_table)
        console.print(config_table)
        console.print("\n")

    except ImportError:
        # Fallback to simple logging if rich is not available
        masker = SensitiveDataMasker()
        masked_redis_kwargs = masker.mask_dict(redis_kwargs)
        verbose_logger.info(f"Redis configuration: {masked_redis_kwargs}")
    except Exception as e:
        verbose_logger.error(f"Error pretty printing Redis configuration: {e}")
||||
@@ -0,0 +1,323 @@
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
|
||||
from .integrations.custom_logger import CustomLogger
|
||||
from .integrations.datadog.datadog import DataDogLogger
|
||||
from .integrations.opentelemetry import OpenTelemetry
|
||||
from .integrations.prometheus_services import PrometheusServicesLogger
|
||||
from .types.services import ServiceLoggerPayload, ServiceTypes
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
|
||||
Span = Union[_Span, Any]
|
||||
OTELClass = OpenTelemetry
|
||||
else:
|
||||
Span = Any
|
||||
OTELClass = Any
|
||||
UserAPIKeyAuth = Any
|
||||
|
||||
|
||||
class ServiceLogging(CustomLogger):
    """
    Separate class used for monitoring health of litellm-adjacent services (redis/postgres).
    """

    def __init__(self, mock_testing: bool = False) -> None:
        # When True, the hooks only increment counters (used by unit tests).
        self.mock_testing = mock_testing
        self.mock_testing_sync_success_hook = 0
        self.mock_testing_async_success_hook = 0
        self.mock_testing_sync_failure_hook = 0
        self.mock_testing_async_failure_hook = 0
        if "prometheus_system" in litellm.service_callback:
            self.prometheusServicesLogger = PrometheusServicesLogger()

    def service_success_hook(
        self,
        service: ServiceTypes,
        duration: float,
        call_type: str,
        parent_otel_span: Optional[Span] = None,
        start_time: Optional[Union[datetime, float]] = None,
        end_time: Optional[Union[float, datetime]] = None,
    ):
        """
        Handles both sync and async monitoring by checking for existing event loop.
        """

        if self.mock_testing:
            self.mock_testing_sync_success_hook += 1

        try:
            # Try to get the current event loop
            loop = asyncio.get_event_loop()
            # Check if the loop is running
            if loop.is_running():
                # If we're in a running loop, create a task
                loop.create_task(
                    self.async_service_success_hook(
                        service=service,
                        duration=duration,
                        call_type=call_type,
                        parent_otel_span=parent_otel_span,
                        start_time=start_time,
                        end_time=end_time,
                    )
                )
            else:
                # Loop exists but not running, we can use run_until_complete
                loop.run_until_complete(
                    self.async_service_success_hook(
                        service=service,
                        duration=duration,
                        call_type=call_type,
                        parent_otel_span=parent_otel_span,
                        start_time=start_time,
                        end_time=end_time,
                    )
                )
        except RuntimeError:
            # No event loop exists, create a new one and run
            asyncio.run(
                self.async_service_success_hook(
                    service=service,
                    duration=duration,
                    call_type=call_type,
                    parent_otel_span=parent_otel_span,
                    start_time=start_time,
                    end_time=end_time,
                )
            )

    def service_failure_hook(
        self, service: ServiceTypes, duration: float, error: Exception, call_type: str
    ):
        """
        [TODO] Not implemented for sync calls yet. V0 is focused on async monitoring (used by proxy).
        """
        if self.mock_testing:
            self.mock_testing_sync_failure_hook += 1

    async def async_service_success_hook(
        self,
        service: ServiceTypes,
        call_type: str,
        duration: float,
        parent_otel_span: Optional[Span] = None,
        start_time: Optional[Union[datetime, float]] = None,
        end_time: Optional[Union[datetime, float]] = None,
        event_metadata: Optional[dict] = None,
    ):
        """
        - For counting if the redis, postgres call is successful
        """
        if self.mock_testing:
            self.mock_testing_async_success_hook += 1

        payload = ServiceLoggerPayload(
            is_error=False,
            error=None,
            service=service,
            duration=duration,
            call_type=call_type,
            event_metadata=event_metadata,
        )

        # Fan the payload out to every configured service callback.
        for callback in litellm.service_callback:
            if callback == "prometheus_system":
                await self.init_prometheus_services_logger_if_none()
                await self.prometheusServicesLogger.async_service_success_hook(
                    payload=payload
                )
            elif callback == "datadog" or isinstance(callback, DataDogLogger):
                await self.init_datadog_logger_if_none()
                await self.dd_logger.async_service_success_hook(
                    payload=payload,
                    parent_otel_span=parent_otel_span,
                    start_time=start_time,
                    end_time=end_time,
                    event_metadata=event_metadata,
                )
            elif callback == "otel" or isinstance(callback, OpenTelemetry):
                # Prefer a caller-provided OTEL instance; fall back to the
                # proxy server's global logger.
                _otel_logger_to_use: Optional[OpenTelemetry] = None
                if isinstance(callback, OpenTelemetry):
                    _otel_logger_to_use = callback
                else:
                    from litellm.proxy.proxy_server import open_telemetry_logger

                    if open_telemetry_logger is not None and isinstance(
                        open_telemetry_logger, OpenTelemetry
                    ):
                        _otel_logger_to_use = open_telemetry_logger

                # OTEL logging requires a parent span to attach to.
                if _otel_logger_to_use is not None and parent_otel_span is not None:
                    await _otel_logger_to_use.async_service_success_hook(
                        payload=payload,
                        parent_otel_span=parent_otel_span,
                        start_time=start_time,
                        end_time=end_time,
                        event_metadata=event_metadata,
                    )

    async def init_prometheus_services_logger_if_none(self):
        """
        initializes prometheusServicesLogger if it is None or no attribute exists on ServiceLogging Object

        """
        if not hasattr(self, "prometheusServicesLogger"):
            self.prometheusServicesLogger = PrometheusServicesLogger()
        elif self.prometheusServicesLogger is None:
            # NOTE(review): this calls the current attribute value — if it is
            # literally None this raises TypeError; confirm the intent was
            # ``PrometheusServicesLogger()`` here.
            self.prometheusServicesLogger = self.prometheusServicesLogger()
        return

    async def init_datadog_logger_if_none(self):
        """
        initializes dd_logger if it is None or no attribute exists on ServiceLogging Object

        """
        from litellm.integrations.datadog.datadog import DataDogLogger

        if not hasattr(self, "dd_logger"):
            self.dd_logger: DataDogLogger = DataDogLogger()

        return

    async def init_otel_logger_if_none(self):
        """
        initializes otel_logger if it is None or no attribute exists on ServiceLogging Object

        """
        from litellm.proxy.proxy_server import open_telemetry_logger

        if not hasattr(self, "otel_logger"):
            if open_telemetry_logger is not None and isinstance(
                open_telemetry_logger, OpenTelemetry
            ):
                self.otel_logger: OpenTelemetry = open_telemetry_logger
            else:
                verbose_logger.warning(
                    "ServiceLogger: open_telemetry_logger is None or not an instance of OpenTelemetry"
                )
        return

    async def async_service_failure_hook(
        self,
        service: ServiceTypes,
        duration: float,
        error: Union[str, Exception],
        call_type: str,
        parent_otel_span: Optional[Span] = None,
        start_time: Optional[Union[datetime, float]] = None,
        end_time: Optional[Union[float, datetime]] = None,
        event_metadata: Optional[dict] = None,
    ):
        """
        - For counting if the redis, postgres call is unsuccessful
        """
        if self.mock_testing:
            self.mock_testing_async_failure_hook += 1

        # Normalize the error to a string for the payload.
        error_message = ""
        if isinstance(error, Exception):
            error_message = str(error)
        elif isinstance(error, str):
            error_message = error

        payload = ServiceLoggerPayload(
            is_error=True,
            error=error_message,
            service=service,
            duration=duration,
            call_type=call_type,
            event_metadata=event_metadata,
        )

        # Fan the failure payload out to every configured service callback.
        for callback in litellm.service_callback:
            if callback == "prometheus_system":
                await self.init_prometheus_services_logger_if_none()
                await self.prometheusServicesLogger.async_service_failure_hook(
                    payload=payload,
                    error=error,
                )
            elif callback == "datadog" or isinstance(callback, DataDogLogger):
                await self.init_datadog_logger_if_none()
                await self.dd_logger.async_service_failure_hook(
                    payload=payload,
                    error=error_message,
                    parent_otel_span=parent_otel_span,
                    start_time=start_time,
                    end_time=end_time,
                    event_metadata=event_metadata,
                )
            elif callback == "otel" or isinstance(callback, OpenTelemetry):
                # Prefer a caller-provided OTEL instance; fall back to the
                # proxy server's global logger.
                _otel_logger_to_use: Optional[OpenTelemetry] = None
                if isinstance(callback, OpenTelemetry):
                    _otel_logger_to_use = callback
                else:
                    from litellm.proxy.proxy_server import open_telemetry_logger

                    if open_telemetry_logger is not None and isinstance(
                        open_telemetry_logger, OpenTelemetry
                    ):
                        _otel_logger_to_use = open_telemetry_logger

                if not isinstance(error, str):
                    error = str(error)

                # OTEL logging requires a parent span to attach to.
                if _otel_logger_to_use is not None and parent_otel_span is not None:
                    await _otel_logger_to_use.async_service_failure_hook(
                        payload=payload,
                        error=error,
                        parent_otel_span=parent_otel_span,
                        start_time=start_time,
                        end_time=end_time,
                        event_metadata=event_metadata,
                    )

    async def async_post_call_failure_hook(
        self,
        request_data: dict,
        original_exception: Exception,
        user_api_key_dict: UserAPIKeyAuth,
        traceback_str: Optional[str] = None,
    ):
        """
        Hook to track failed litellm-service calls

        Note: ``traceback_str`` is accepted but not forwarded to the parent hook.
        """
        return await super().async_post_call_failure_hook(
            request_data,
            original_exception,
            user_api_key_dict,
        )

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """
        Hook to track latency for litellm proxy llm api calls
        """
        try:
            # Accept either datetime pairs (timedelta) or raw float seconds.
            _duration = end_time - start_time
            if isinstance(_duration, timedelta):
                _duration = _duration.total_seconds()
            elif isinstance(_duration, float):
                pass
            else:
                raise Exception(
                    "Duration={} is not a float or timedelta object. type={}".format(
                        _duration, type(_duration)
                    )
                )  # invalid _duration value
            # Batch polling callbacks (check_batch_cost) don't include call_type in kwargs.
            # Use .get() to avoid KeyError.
            await self.async_service_success_hook(
                service=ServiceTypes.LITELLM,
                duration=_duration,
                call_type=kwargs.get("call_type", "unknown"),
            )
        except Exception as e:
            raise e
||||
16
llm-gateway-competitors/litellm-wheel-src/litellm/_uuid.py
Normal file
16
llm-gateway-competitors/litellm-wheel-src/litellm/_uuid.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""
|
||||
Internal unified UUID helper.
|
||||
|
||||
Always uses fastuuid for performance.
|
||||
"""
|
||||
|
||||
import fastuuid as _uuid # type: ignore
|
||||
|
||||
|
||||
# Expose a module-like alias so callers can use: uuid.uuid4()
|
||||
uuid = _uuid
|
||||
|
||||
|
||||
def uuid4():
|
||||
"""Return a UUID4 using the selected backend."""
|
||||
return uuid.uuid4()
|
||||
@@ -0,0 +1,6 @@
|
||||
import importlib_metadata
|
||||
|
||||
try:
|
||||
version = importlib_metadata.version("litellm")
|
||||
except Exception:
|
||||
version = "unknown"
|
||||
@@ -0,0 +1,73 @@
|
||||
"""
|
||||
LiteLLM A2A - Wrapper for invoking A2A protocol agents.
|
||||
|
||||
This module provides a thin wrapper around the official `a2a` SDK that:
|
||||
- Handles httpx client creation and agent card resolution
|
||||
- Adds LiteLLM logging via @client decorator
|
||||
- Matches the A2A SDK interface (SendMessageRequest, SendMessageResponse, etc.)
|
||||
|
||||
Example usage (standalone functions with @client decorator):
|
||||
```python
|
||||
from litellm.a2a_protocol import asend_message
|
||||
from a2a.types import SendMessageRequest, MessageSendParams
|
||||
from uuid import uuid4
|
||||
|
||||
request = SendMessageRequest(
|
||||
id=str(uuid4()),
|
||||
params=MessageSendParams(
|
||||
message={
|
||||
"role": "user",
|
||||
"parts": [{"kind": "text", "text": "Hello!"}],
|
||||
"messageId": uuid4().hex,
|
||||
}
|
||||
)
|
||||
)
|
||||
response = await asend_message(
|
||||
base_url="http://localhost:10001",
|
||||
request=request,
|
||||
)
|
||||
print(response.model_dump(mode='json', exclude_none=True))
|
||||
```
|
||||
|
||||
Example usage (class-based):
|
||||
```python
|
||||
from litellm.a2a_protocol import A2AClient
|
||||
|
||||
client = A2AClient(base_url="http://localhost:10001")
|
||||
response = await client.send_message(request)
|
||||
```
|
||||
"""
|
||||
|
||||
from litellm.a2a_protocol.client import A2AClient
|
||||
from litellm.a2a_protocol.exceptions import (
|
||||
A2AAgentCardError,
|
||||
A2AConnectionError,
|
||||
A2AError,
|
||||
A2ALocalhostURLError,
|
||||
)
|
||||
from litellm.a2a_protocol.main import (
|
||||
aget_agent_card,
|
||||
asend_message,
|
||||
asend_message_streaming,
|
||||
create_a2a_client,
|
||||
send_message,
|
||||
)
|
||||
from litellm.types.agents import LiteLLMSendMessageResponse
|
||||
|
||||
__all__ = [
|
||||
# Client
|
||||
"A2AClient",
|
||||
# Functions
|
||||
"asend_message",
|
||||
"send_message",
|
||||
"asend_message_streaming",
|
||||
"aget_agent_card",
|
||||
"create_a2a_client",
|
||||
# Response types
|
||||
"LiteLLMSendMessageResponse",
|
||||
# Exceptions
|
||||
"A2AError",
|
||||
"A2AConnectionError",
|
||||
"A2AAgentCardError",
|
||||
"A2ALocalhostURLError",
|
||||
]
|
||||
@@ -0,0 +1,144 @@
|
||||
"""
|
||||
Custom A2A Card Resolver for LiteLLM.
|
||||
|
||||
Extends the A2A SDK's card resolver to support multiple well-known paths.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.constants import LOCALHOST_URL_PATTERNS
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.types import AgentCard
|
||||
|
||||
# Runtime imports with availability check
|
||||
_A2ACardResolver: Any = None
|
||||
AGENT_CARD_WELL_KNOWN_PATH: str = "/.well-known/agent-card.json"
|
||||
PREV_AGENT_CARD_WELL_KNOWN_PATH: str = "/.well-known/agent.json"
|
||||
|
||||
try:
|
||||
from a2a.client import A2ACardResolver as _A2ACardResolver # type: ignore[no-redef]
|
||||
from a2a.utils.constants import ( # type: ignore[no-redef]
|
||||
AGENT_CARD_WELL_KNOWN_PATH,
|
||||
PREV_AGENT_CARD_WELL_KNOWN_PATH,
|
||||
)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def is_localhost_or_internal_url(url: Optional[str]) -> bool:
|
||||
"""
|
||||
Check if a URL is a localhost or internal URL.
|
||||
|
||||
This detects common development URLs that are accidentally left in
|
||||
agent cards when deploying to production.
|
||||
|
||||
Args:
|
||||
url: The URL to check
|
||||
|
||||
Returns:
|
||||
True if the URL is localhost/internal
|
||||
"""
|
||||
if not url:
|
||||
return False
|
||||
|
||||
url_lower = url.lower()
|
||||
|
||||
return any(pattern in url_lower for pattern in LOCALHOST_URL_PATTERNS)
|
||||
|
||||
|
||||
def fix_agent_card_url(agent_card: "AgentCard", base_url: str) -> "AgentCard":
|
||||
"""
|
||||
Fix the agent card URL if it contains a localhost/internal address.
|
||||
|
||||
Many A2A agents are deployed with agent cards that contain internal URLs
|
||||
like "http://0.0.0.0:8001/" or "http://localhost:8000/". This function
|
||||
replaces such URLs with the provided base_url.
|
||||
|
||||
Args:
|
||||
agent_card: The agent card to fix
|
||||
base_url: The base URL to use as replacement
|
||||
|
||||
Returns:
|
||||
The agent card with the URL fixed if necessary
|
||||
"""
|
||||
card_url = getattr(agent_card, "url", None)
|
||||
|
||||
if card_url and is_localhost_or_internal_url(card_url):
|
||||
# Normalize base_url to ensure it ends with /
|
||||
fixed_url = base_url.rstrip("/") + "/"
|
||||
agent_card.url = fixed_url
|
||||
|
||||
return agent_card
|
||||
|
||||
|
||||
class LiteLLMA2ACardResolver(_A2ACardResolver): # type: ignore[misc]
|
||||
"""
|
||||
Custom A2A card resolver that supports multiple well-known paths.
|
||||
|
||||
Extends the base A2ACardResolver to try both:
|
||||
- /.well-known/agent-card.json (standard)
|
||||
- /.well-known/agent.json (previous/alternative)
|
||||
"""
|
||||
|
||||
async def get_agent_card(
|
||||
self,
|
||||
relative_card_path: Optional[str] = None,
|
||||
http_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> "AgentCard":
|
||||
"""
|
||||
Fetch the agent card, trying multiple well-known paths.
|
||||
|
||||
First tries the standard path, then falls back to the previous path.
|
||||
|
||||
Args:
|
||||
relative_card_path: Optional path to the agent card endpoint.
|
||||
If None, tries both well-known paths.
|
||||
http_kwargs: Optional dictionary of keyword arguments to pass to httpx.get
|
||||
|
||||
Returns:
|
||||
AgentCard from the A2A agent
|
||||
|
||||
Raises:
|
||||
A2AClientHTTPError or A2AClientJSONError if both paths fail
|
||||
"""
|
||||
# If a specific path is provided, use the parent implementation
|
||||
if relative_card_path is not None:
|
||||
return await super().get_agent_card(
|
||||
relative_card_path=relative_card_path,
|
||||
http_kwargs=http_kwargs,
|
||||
)
|
||||
|
||||
# Try both well-known paths
|
||||
paths = [
|
||||
AGENT_CARD_WELL_KNOWN_PATH,
|
||||
PREV_AGENT_CARD_WELL_KNOWN_PATH,
|
||||
]
|
||||
|
||||
last_error = None
|
||||
for path in paths:
|
||||
try:
|
||||
verbose_logger.debug(
|
||||
f"Attempting to fetch agent card from {self.base_url}{path}"
|
||||
)
|
||||
return await super().get_agent_card(
|
||||
relative_card_path=path,
|
||||
http_kwargs=http_kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_logger.debug(
|
||||
f"Failed to fetch agent card from {self.base_url}{path}: {e}"
|
||||
)
|
||||
last_error = e
|
||||
continue
|
||||
|
||||
# If we get here, all paths failed - re-raise the last error
|
||||
if last_error is not None:
|
||||
raise last_error
|
||||
|
||||
# This shouldn't happen, but just in case
|
||||
raise Exception(
|
||||
f"Failed to fetch agent card from {self.base_url}. "
|
||||
f"Tried paths: {', '.join(paths)}"
|
||||
)
|
||||
@@ -0,0 +1,109 @@
|
||||
"""
|
||||
LiteLLM A2A Client class.
|
||||
|
||||
Provides a class-based interface for A2A agent invocation.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, AsyncIterator, Dict, Optional
|
||||
|
||||
from litellm.types.agents import LiteLLMSendMessageResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.client import A2AClient as A2AClientType
|
||||
from a2a.types import (
|
||||
AgentCard,
|
||||
SendMessageRequest,
|
||||
SendStreamingMessageRequest,
|
||||
SendStreamingMessageResponse,
|
||||
)
|
||||
|
||||
|
||||
class A2AClient:
|
||||
"""
|
||||
LiteLLM wrapper for A2A agent invocation.
|
||||
|
||||
Creates the underlying A2A client once on first use and reuses it.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from litellm.a2a_protocol import A2AClient
|
||||
from a2a.types import SendMessageRequest, MessageSendParams
|
||||
from uuid import uuid4
|
||||
|
||||
client = A2AClient(base_url="http://localhost:10001")
|
||||
|
||||
request = SendMessageRequest(
|
||||
id=str(uuid4()),
|
||||
params=MessageSendParams(
|
||||
message={
|
||||
"role": "user",
|
||||
"parts": [{"kind": "text", "text": "Hello!"}],
|
||||
"messageId": uuid4().hex,
|
||||
}
|
||||
)
|
||||
)
|
||||
response = await client.send_message(request)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
timeout: float = 60.0,
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the A2A client wrapper.
|
||||
|
||||
Args:
|
||||
base_url: The base URL of the A2A agent (e.g., "http://localhost:10001")
|
||||
timeout: Request timeout in seconds (default: 60.0)
|
||||
extra_headers: Optional additional headers to include in requests
|
||||
"""
|
||||
self.base_url = base_url
|
||||
self.timeout = timeout
|
||||
self.extra_headers = extra_headers
|
||||
self._a2a_client: Optional["A2AClientType"] = None
|
||||
|
||||
async def _get_client(self) -> "A2AClientType":
|
||||
"""Get or create the underlying A2A client."""
|
||||
if self._a2a_client is None:
|
||||
from litellm.a2a_protocol.main import create_a2a_client
|
||||
|
||||
self._a2a_client = await create_a2a_client(
|
||||
base_url=self.base_url,
|
||||
timeout=self.timeout,
|
||||
extra_headers=self.extra_headers,
|
||||
)
|
||||
return self._a2a_client
|
||||
|
||||
async def get_agent_card(self) -> "AgentCard":
|
||||
"""Fetch the agent card from the server."""
|
||||
from litellm.a2a_protocol.main import aget_agent_card
|
||||
|
||||
return await aget_agent_card(
|
||||
base_url=self.base_url,
|
||||
timeout=self.timeout,
|
||||
extra_headers=self.extra_headers,
|
||||
)
|
||||
|
||||
async def send_message(
|
||||
self, request: "SendMessageRequest"
|
||||
) -> LiteLLMSendMessageResponse:
|
||||
"""Send a message to the A2A agent."""
|
||||
from litellm.a2a_protocol.main import asend_message
|
||||
|
||||
a2a_client = await self._get_client()
|
||||
return await asend_message(a2a_client=a2a_client, request=request)
|
||||
|
||||
async def send_message_streaming(
|
||||
self, request: "SendStreamingMessageRequest"
|
||||
) -> AsyncIterator["SendStreamingMessageResponse"]:
|
||||
"""Send a streaming message to the A2A agent."""
|
||||
from litellm.a2a_protocol.main import asend_message_streaming
|
||||
|
||||
a2a_client = await self._get_client()
|
||||
async for chunk in asend_message_streaming(
|
||||
a2a_client=a2a_client, request=request
|
||||
):
|
||||
yield chunk
|
||||
@@ -0,0 +1,107 @@
|
||||
"""
|
||||
Cost calculator for A2A (Agent-to-Agent) calls.
|
||||
|
||||
Supports dynamic cost parameters that allow platform owners
|
||||
to define custom costs per agent query or per token.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import (
|
||||
Logging as LitellmLoggingObject,
|
||||
)
|
||||
else:
|
||||
LitellmLoggingObject = Any
|
||||
|
||||
|
||||
class A2ACostCalculator:
|
||||
@staticmethod
|
||||
def calculate_a2a_cost(
|
||||
litellm_logging_obj: Optional[LitellmLoggingObject],
|
||||
) -> float:
|
||||
"""
|
||||
Calculate the cost of an A2A send_message call.
|
||||
|
||||
Supports multiple cost parameters for platform owners:
|
||||
- cost_per_query: Fixed cost per query
|
||||
- input_cost_per_token + output_cost_per_token: Token-based pricing
|
||||
|
||||
Priority order:
|
||||
1. response_cost - if set directly (backward compatibility)
|
||||
2. cost_per_query - fixed cost per query
|
||||
3. input_cost_per_token + output_cost_per_token - token-based cost
|
||||
4. Default to 0.0
|
||||
|
||||
Args:
|
||||
litellm_logging_obj: The LiteLLM logging object containing call details
|
||||
|
||||
Returns:
|
||||
float: The cost of the A2A call
|
||||
"""
|
||||
if litellm_logging_obj is None:
|
||||
return 0.0
|
||||
|
||||
model_call_details = litellm_logging_obj.model_call_details
|
||||
|
||||
# Check if user set a custom response cost (backward compatibility)
|
||||
response_cost = model_call_details.get("response_cost", None)
|
||||
if response_cost is not None:
|
||||
return float(response_cost)
|
||||
|
||||
# Get litellm_params for cost parameters
|
||||
litellm_params = model_call_details.get("litellm_params", {}) or {}
|
||||
|
||||
# Check for cost_per_query (fixed cost per query)
|
||||
if litellm_params.get("cost_per_query") is not None:
|
||||
return float(litellm_params["cost_per_query"])
|
||||
|
||||
# Check for token-based pricing
|
||||
input_cost_per_token = litellm_params.get("input_cost_per_token")
|
||||
output_cost_per_token = litellm_params.get("output_cost_per_token")
|
||||
|
||||
if input_cost_per_token is not None or output_cost_per_token is not None:
|
||||
return A2ACostCalculator._calculate_token_based_cost(
|
||||
model_call_details=model_call_details,
|
||||
input_cost_per_token=input_cost_per_token,
|
||||
output_cost_per_token=output_cost_per_token,
|
||||
)
|
||||
|
||||
# Default to 0.0 for A2A calls
|
||||
return 0.0
|
||||
|
||||
@staticmethod
|
||||
def _calculate_token_based_cost(
|
||||
model_call_details: dict,
|
||||
input_cost_per_token: Optional[float],
|
||||
output_cost_per_token: Optional[float],
|
||||
) -> float:
|
||||
"""
|
||||
Calculate cost based on token usage and per-token pricing.
|
||||
|
||||
Args:
|
||||
model_call_details: The model call details containing usage
|
||||
input_cost_per_token: Cost per input token (can be None, defaults to 0)
|
||||
output_cost_per_token: Cost per output token (can be None, defaults to 0)
|
||||
|
||||
Returns:
|
||||
float: The calculated cost
|
||||
"""
|
||||
# Get usage from model_call_details
|
||||
usage = model_call_details.get("usage")
|
||||
if usage is None:
|
||||
return 0.0
|
||||
|
||||
# Get token counts
|
||||
prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
|
||||
completion_tokens = getattr(usage, "completion_tokens", 0) or 0
|
||||
|
||||
# Calculate costs
|
||||
input_cost = prompt_tokens * (
|
||||
float(input_cost_per_token) if input_cost_per_token else 0.0
|
||||
)
|
||||
output_cost = completion_tokens * (
|
||||
float(output_cost_per_token) if output_cost_per_token else 0.0
|
||||
)
|
||||
|
||||
return input_cost + output_cost
|
||||
@@ -0,0 +1,203 @@
|
||||
"""
|
||||
A2A Protocol Exception Mapping Utils.
|
||||
|
||||
Maps A2A SDK exceptions to LiteLLM A2A exception types.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.a2a_protocol.card_resolver import (
|
||||
fix_agent_card_url,
|
||||
is_localhost_or_internal_url,
|
||||
)
|
||||
from litellm.a2a_protocol.exceptions import (
|
||||
A2AAgentCardError,
|
||||
A2AConnectionError,
|
||||
A2AError,
|
||||
A2ALocalhostURLError,
|
||||
)
|
||||
from litellm.constants import CONNECTION_ERROR_PATTERNS
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.client import A2AClient as A2AClientType
|
||||
|
||||
|
||||
# Runtime import
|
||||
A2A_SDK_AVAILABLE = False
|
||||
try:
|
||||
from a2a.client import A2AClient as _A2AClient # type: ignore[no-redef]
|
||||
|
||||
A2A_SDK_AVAILABLE = True
|
||||
except ImportError:
|
||||
_A2AClient = None # type: ignore[assignment, misc]
|
||||
|
||||
|
||||
class A2AExceptionCheckers:
|
||||
"""
|
||||
Helper class for checking various A2A error conditions.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def is_connection_error(error_str: str) -> bool:
|
||||
"""
|
||||
Check if an error string indicates a connection error.
|
||||
|
||||
Args:
|
||||
error_str: The error string to check
|
||||
|
||||
Returns:
|
||||
True if the error indicates a connection issue
|
||||
"""
|
||||
if not isinstance(error_str, str):
|
||||
return False
|
||||
|
||||
error_str_lower = error_str.lower()
|
||||
return any(pattern in error_str_lower for pattern in CONNECTION_ERROR_PATTERNS)
|
||||
|
||||
@staticmethod
|
||||
def is_localhost_url(url: Optional[str]) -> bool:
|
||||
"""
|
||||
Check if a URL is a localhost/internal URL.
|
||||
|
||||
Args:
|
||||
url: The URL to check
|
||||
|
||||
Returns:
|
||||
True if the URL is localhost/internal
|
||||
"""
|
||||
return is_localhost_or_internal_url(url)
|
||||
|
||||
@staticmethod
|
||||
def is_agent_card_error(error_str: str) -> bool:
|
||||
"""
|
||||
Check if an error string indicates an agent card error.
|
||||
|
||||
Args:
|
||||
error_str: The error string to check
|
||||
|
||||
Returns:
|
||||
True if the error is related to agent card fetching/parsing
|
||||
"""
|
||||
if not isinstance(error_str, str):
|
||||
return False
|
||||
|
||||
error_str_lower = error_str.lower()
|
||||
agent_card_patterns = [
|
||||
"agent card",
|
||||
"agent-card",
|
||||
".well-known",
|
||||
"card not found",
|
||||
"invalid agent",
|
||||
]
|
||||
return any(pattern in error_str_lower for pattern in agent_card_patterns)
|
||||
|
||||
|
||||
def map_a2a_exception(
|
||||
original_exception: Exception,
|
||||
card_url: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
) -> Exception:
|
||||
"""
|
||||
Map an A2A SDK exception to a LiteLLM A2A exception type.
|
||||
|
||||
Args:
|
||||
original_exception: The original exception from the A2A SDK
|
||||
card_url: The URL from the agent card (if available)
|
||||
api_base: The original API base URL
|
||||
model: The model/agent name
|
||||
|
||||
Returns:
|
||||
A mapped LiteLLM A2A exception
|
||||
|
||||
Raises:
|
||||
A2ALocalhostURLError: If the error is a connection error to a localhost URL
|
||||
A2AConnectionError: If the error is a general connection error
|
||||
A2AAgentCardError: If the error is related to agent card issues
|
||||
A2AError: For other A2A-related errors
|
||||
"""
|
||||
error_str = str(original_exception)
|
||||
|
||||
# Check for localhost URL connection error (special case - retryable)
|
||||
if (
|
||||
card_url
|
||||
and api_base
|
||||
and A2AExceptionCheckers.is_localhost_url(card_url)
|
||||
and A2AExceptionCheckers.is_connection_error(error_str)
|
||||
):
|
||||
raise A2ALocalhostURLError(
|
||||
localhost_url=card_url,
|
||||
base_url=api_base,
|
||||
original_error=original_exception,
|
||||
model=model,
|
||||
)
|
||||
|
||||
# Check for agent card errors
|
||||
if A2AExceptionCheckers.is_agent_card_error(error_str):
|
||||
raise A2AAgentCardError(
|
||||
message=error_str,
|
||||
url=api_base,
|
||||
model=model,
|
||||
)
|
||||
|
||||
# Check for general connection errors
|
||||
if A2AExceptionCheckers.is_connection_error(error_str):
|
||||
raise A2AConnectionError(
|
||||
message=error_str,
|
||||
url=card_url or api_base,
|
||||
model=model,
|
||||
)
|
||||
|
||||
# Default: wrap in generic A2AError
|
||||
raise A2AError(
|
||||
message=error_str,
|
||||
model=model,
|
||||
)
|
||||
|
||||
|
||||
def handle_a2a_localhost_retry(
|
||||
error: A2ALocalhostURLError,
|
||||
agent_card: Any,
|
||||
a2a_client: "A2AClientType",
|
||||
is_streaming: bool = False,
|
||||
) -> "A2AClientType":
|
||||
"""
|
||||
Handle A2ALocalhostURLError by fixing the URL and creating a new client.
|
||||
|
||||
This is called when we catch an A2ALocalhostURLError and want to retry
|
||||
with the corrected URL.
|
||||
|
||||
Args:
|
||||
error: The localhost URL error
|
||||
agent_card: The agent card object to fix
|
||||
a2a_client: The current A2A client
|
||||
is_streaming: Whether this is a streaming request (for logging)
|
||||
|
||||
Returns:
|
||||
A new A2A client with the fixed URL
|
||||
|
||||
Raises:
|
||||
ImportError: If the A2A SDK is not installed
|
||||
"""
|
||||
if not A2A_SDK_AVAILABLE or _A2AClient is None:
|
||||
raise ImportError(
|
||||
"A2A SDK is required for localhost retry handling. "
|
||||
"Install it with: pip install a2a"
|
||||
)
|
||||
|
||||
request_type = "streaming " if is_streaming else ""
|
||||
verbose_logger.warning(
|
||||
f"A2A {request_type}request to '{error.localhost_url}' failed: {error.original_error}. "
|
||||
f"Agent card contains localhost/internal URL. "
|
||||
f"Retrying with base_url '{error.base_url}'."
|
||||
)
|
||||
|
||||
# Fix the agent card URL
|
||||
fix_agent_card_url(agent_card, error.base_url)
|
||||
|
||||
# Create a new client with the fixed agent card (transport caches URL)
|
||||
return _A2AClient(
|
||||
httpx_client=a2a_client._transport.httpx_client, # type: ignore[union-attr]
|
||||
agent_card=agent_card,
|
||||
)
|
||||
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
A2A Protocol Exceptions.
|
||||
|
||||
Custom exception types for A2A protocol operations, following LiteLLM's exception pattern.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
class A2AError(Exception):
|
||||
"""
|
||||
Base exception for A2A protocol errors.
|
||||
|
||||
Follows the same pattern as LiteLLM's main exceptions.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
status_code: int = 500,
|
||||
llm_provider: str = "a2a_agent",
|
||||
model: Optional[str] = None,
|
||||
response: Optional[httpx.Response] = None,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
num_retries: Optional[int] = None,
|
||||
):
|
||||
self.status_code = status_code
|
||||
self.message = f"litellm.A2AError: {message}"
|
||||
self.llm_provider = llm_provider
|
||||
self.model = model
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
self.max_retries = max_retries
|
||||
self.num_retries = num_retries
|
||||
self.response = response or httpx.Response(
|
||||
status_code=self.status_code,
|
||||
request=httpx.Request(method="POST", url="https://litellm.ai"),
|
||||
)
|
||||
super().__init__(self.message)
|
||||
|
||||
def __str__(self) -> str:
|
||||
_message = self.message
|
||||
if self.num_retries:
|
||||
_message += f" LiteLLM Retried: {self.num_retries} times"
|
||||
if self.max_retries:
|
||||
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
||||
return _message
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.__str__()
|
||||
|
||||
|
||||
class A2AConnectionError(A2AError):
|
||||
"""
|
||||
Raised when connection to an A2A agent fails.
|
||||
|
||||
This typically occurs when:
|
||||
- The agent is unreachable
|
||||
- The agent card contains a localhost/internal URL
|
||||
- Network issues prevent connection
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
url: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
response: Optional[httpx.Response] = None,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
num_retries: Optional[int] = None,
|
||||
):
|
||||
self.url = url
|
||||
super().__init__(
|
||||
message=message,
|
||||
status_code=503,
|
||||
llm_provider="a2a_agent",
|
||||
model=model,
|
||||
response=response,
|
||||
litellm_debug_info=litellm_debug_info,
|
||||
max_retries=max_retries,
|
||||
num_retries=num_retries,
|
||||
)
|
||||
|
||||
|
||||
class A2AAgentCardError(A2AError):
|
||||
"""
|
||||
Raised when there's an issue with the agent card.
|
||||
|
||||
This includes:
|
||||
- Failed to fetch agent card
|
||||
- Invalid agent card format
|
||||
- Missing required fields
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
url: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
response: Optional[httpx.Response] = None,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
):
|
||||
self.url = url
|
||||
super().__init__(
|
||||
message=message,
|
||||
status_code=404,
|
||||
llm_provider="a2a_agent",
|
||||
model=model,
|
||||
response=response,
|
||||
litellm_debug_info=litellm_debug_info,
|
||||
)
|
||||
|
||||
|
||||
class A2ALocalhostURLError(A2AConnectionError):
|
||||
"""
|
||||
Raised when an agent card contains a localhost/internal URL.
|
||||
|
||||
Many A2A agents are deployed with agent cards that contain internal URLs
|
||||
like "http://0.0.0.0:8001/" or "http://localhost:8000/". This error
|
||||
indicates that the URL needs to be corrected and the request should be retried.
|
||||
|
||||
Attributes:
|
||||
localhost_url: The localhost/internal URL found in the agent card
|
||||
base_url: The public base URL that should be used instead
|
||||
original_error: The original connection error that was raised
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
localhost_url: str,
|
||||
base_url: str,
|
||||
original_error: Optional[Exception] = None,
|
||||
model: Optional[str] = None,
|
||||
):
|
||||
self.localhost_url = localhost_url
|
||||
self.base_url = base_url
|
||||
self.original_error = original_error
|
||||
|
||||
message = (
|
||||
f"Agent card contains localhost/internal URL '{localhost_url}'. "
|
||||
f"Retrying with base URL '{base_url}'."
|
||||
)
|
||||
super().__init__(
|
||||
message=message,
|
||||
url=localhost_url,
|
||||
model=model,
|
||||
)
|
||||
@@ -0,0 +1,74 @@
|
||||
# A2A to LiteLLM Completion Bridge
|
||||
|
||||
Routes A2A protocol requests through `litellm.acompletion`, enabling any LiteLLM-supported provider to be invoked via A2A.
|
||||
|
||||
## Flow
|
||||
|
||||
```
|
||||
A2A Request → Transform → litellm.acompletion → Transform → A2A Response
|
||||
```
|
||||
|
||||
## SDK Usage
|
||||
|
||||
Use the existing `asend_message` and `asend_message_streaming` functions with `litellm_params`:
|
||||
|
||||
```python
|
||||
from litellm.a2a_protocol import asend_message, asend_message_streaming
|
||||
from a2a.types import SendMessageRequest, SendStreamingMessageRequest, MessageSendParams
|
||||
from uuid import uuid4
|
||||
|
||||
# Non-streaming
|
||||
request = SendMessageRequest(
|
||||
id=str(uuid4()),
|
||||
params=MessageSendParams(
|
||||
message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
|
||||
)
|
||||
)
|
||||
response = await asend_message(
|
||||
request=request,
|
||||
api_base="http://localhost:2024",
|
||||
litellm_params={"custom_llm_provider": "langgraph", "model": "agent"},
|
||||
)
|
||||
|
||||
# Streaming
|
||||
stream_request = SendStreamingMessageRequest(
|
||||
id=str(uuid4()),
|
||||
params=MessageSendParams(
|
||||
message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
|
||||
)
|
||||
)
|
||||
async for chunk in asend_message_streaming(
|
||||
request=stream_request,
|
||||
api_base="http://localhost:2024",
|
||||
litellm_params={"custom_llm_provider": "langgraph", "model": "agent"},
|
||||
):
|
||||
print(chunk)
|
||||
```
|
||||
|
||||
## Proxy Usage
|
||||
|
||||
Configure an agent with `custom_llm_provider` in `litellm_params`:
|
||||
|
||||
```yaml
|
||||
agents:
|
||||
- agent_name: my-langgraph-agent
|
||||
agent_card_params:
|
||||
name: "LangGraph Agent"
|
||||
url: "http://localhost:2024" # Used as api_base
|
||||
litellm_params:
|
||||
custom_llm_provider: langgraph
|
||||
model: agent
|
||||
```
|
||||
|
||||
When an A2A request hits `/a2a/{agent_id}/message/send`, the bridge:
|
||||
|
||||
1. Detects `custom_llm_provider` in agent's `litellm_params`
|
||||
2. Transforms A2A message → OpenAI messages
|
||||
3. Calls `litellm.acompletion(model="langgraph/agent", api_base="http://localhost:2024")`
|
||||
4. Transforms response → A2A format
|
||||
|
||||
## Classes
|
||||
|
||||
- `A2ACompletionBridgeTransformation` - Static methods for message format conversion
|
||||
- `A2ACompletionBridgeHandler` - Static methods for handling requests (streaming/non-streaming)
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
"""
|
||||
A2A to LiteLLM Completion Bridge.
|
||||
|
||||
This module provides transformation between A2A protocol messages and
|
||||
LiteLLM completion API, enabling any LiteLLM-supported provider to be
|
||||
invoked via the A2A protocol.
|
||||
"""
|
||||
|
||||
from litellm.a2a_protocol.litellm_completion_bridge.handler import (
|
||||
A2ACompletionBridgeHandler,
|
||||
handle_a2a_completion,
|
||||
handle_a2a_completion_streaming,
|
||||
)
|
||||
from litellm.a2a_protocol.litellm_completion_bridge.transformation import (
|
||||
A2ACompletionBridgeTransformation,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"A2ACompletionBridgeTransformation",
|
||||
"A2ACompletionBridgeHandler",
|
||||
"handle_a2a_completion",
|
||||
"handle_a2a_completion_streaming",
|
||||
]
|
||||
@@ -0,0 +1,299 @@
|
||||
"""
|
||||
Handler for A2A to LiteLLM completion bridge.
|
||||
|
||||
Routes A2A requests through litellm.acompletion based on custom_llm_provider.
|
||||
|
||||
A2A Streaming Events (in order):
|
||||
1. Task event (kind: "task") - Initial task creation with status "submitted"
|
||||
2. Status update (kind: "status-update") - Status change to "working"
|
||||
3. Artifact update (kind: "artifact-update") - Content/artifact delivery
|
||||
4. Status update (kind: "status-update") - Final status "completed" with final=true
|
||||
"""
|
||||
|
||||
from typing import Any, AsyncIterator, Dict, Optional
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.a2a_protocol.litellm_completion_bridge.transformation import (
|
||||
A2ACompletionBridgeTransformation,
|
||||
A2AStreamingContext,
|
||||
)
|
||||
from litellm.a2a_protocol.providers.config_manager import A2AProviderConfigManager
|
||||
|
||||
|
||||
class A2ACompletionBridgeHandler:
|
||||
"""
|
||||
Static methods for handling A2A requests via LiteLLM completion.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
async def handle_non_streaming(
|
||||
request_id: str,
|
||||
params: Dict[str, Any],
|
||||
litellm_params: Dict[str, Any],
|
||||
api_base: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Handle non-streaming A2A request via litellm.acompletion.
|
||||
|
||||
Args:
|
||||
request_id: A2A JSON-RPC request ID
|
||||
params: A2A MessageSendParams containing the message
|
||||
litellm_params: Agent's litellm_params (custom_llm_provider, model, etc.)
|
||||
api_base: API base URL from agent_card_params
|
||||
|
||||
Returns:
|
||||
A2A SendMessageResponse dict
|
||||
"""
|
||||
# Get provider config for custom_llm_provider
|
||||
custom_llm_provider = litellm_params.get("custom_llm_provider")
|
||||
a2a_provider_config = A2AProviderConfigManager.get_provider_config(
|
||||
custom_llm_provider=custom_llm_provider
|
||||
)
|
||||
|
||||
# If provider config exists, use it
|
||||
if a2a_provider_config is not None:
|
||||
if api_base is None:
|
||||
raise ValueError(f"api_base is required for {custom_llm_provider}")
|
||||
|
||||
verbose_logger.info(f"A2A: Using provider config for {custom_llm_provider}")
|
||||
|
||||
response_data = await a2a_provider_config.handle_non_streaming(
|
||||
request_id=request_id,
|
||||
params=params,
|
||||
api_base=api_base,
|
||||
)
|
||||
|
||||
return response_data
|
||||
|
||||
# Extract message from params
|
||||
message = params.get("message", {})
|
||||
|
||||
# Transform A2A message to OpenAI format
|
||||
openai_messages = (
|
||||
A2ACompletionBridgeTransformation.a2a_message_to_openai_messages(message)
|
||||
)
|
||||
|
||||
# Get completion params
|
||||
custom_llm_provider = litellm_params.get("custom_llm_provider")
|
||||
model = litellm_params.get("model", "agent")
|
||||
|
||||
# Build full model string if provider specified
|
||||
# Skip prepending if model already starts with the provider prefix
|
||||
if custom_llm_provider and not model.startswith(f"{custom_llm_provider}/"):
|
||||
full_model = f"{custom_llm_provider}/{model}"
|
||||
else:
|
||||
full_model = model
|
||||
|
||||
verbose_logger.info(
|
||||
f"A2A completion bridge: model={full_model}, api_base={api_base}"
|
||||
)
|
||||
|
||||
# Build completion params dict
|
||||
completion_params = {
|
||||
"model": full_model,
|
||||
"messages": openai_messages,
|
||||
"api_base": api_base,
|
||||
"stream": False,
|
||||
}
|
||||
# Add litellm_params (contains api_key, client_id, client_secret, tenant_id, etc.)
|
||||
litellm_params_to_add = {
|
||||
k: v
|
||||
for k, v in litellm_params.items()
|
||||
if k not in ("model", "custom_llm_provider")
|
||||
}
|
||||
completion_params.update(litellm_params_to_add)
|
||||
|
||||
# Call litellm.acompletion
|
||||
response = await litellm.acompletion(**completion_params)
|
||||
|
||||
# Transform response to A2A format
|
||||
a2a_response = (
|
||||
A2ACompletionBridgeTransformation.openai_response_to_a2a_response(
|
||||
response=response,
|
||||
request_id=request_id,
|
||||
)
|
||||
)
|
||||
|
||||
verbose_logger.info(f"A2A completion bridge completed: request_id={request_id}")
|
||||
|
||||
return a2a_response
|
||||
|
||||
    @staticmethod
    async def handle_streaming(
        request_id: str,
        params: Dict[str, Any],
        litellm_params: Dict[str, Any],
        api_base: Optional[str] = None,
    ) -> AsyncIterator[Dict[str, Any]]:
        """
        Handle a streaming A2A request via litellm.acompletion with stream=True.

        Emits proper A2A streaming events, in order:
        1. Task event (kind: "task") - Initial task with status "submitted"
        2. Status update (kind: "status-update") - Status "working"
        3. Artifact update (kind: "artifact-update") - Content delivery
        4. Status update (kind: "status-update") - Final "completed" status

        Note: upstream completion chunks are accumulated and delivered as a
        single artifact-update event after the upstream stream finishes, not
        forwarded chunk-by-chunk.

        Args:
            request_id: A2A JSON-RPC request ID
            params: A2A MessageSendParams containing the message
            litellm_params: Agent's litellm_params (custom_llm_provider, model, etc.)
            api_base: API base URL from agent_card_params

        Yields:
            A2A streaming response events (JSON-RPC envelopes)

        Raises:
            ValueError: if a provider config exists but api_base is None
        """
        # Get provider config for custom_llm_provider
        custom_llm_provider = litellm_params.get("custom_llm_provider")
        a2a_provider_config = A2AProviderConfigManager.get_provider_config(
            custom_llm_provider=custom_llm_provider
        )

        # If a dedicated provider config exists, delegate streaming to it entirely.
        if a2a_provider_config is not None:
            if api_base is None:
                raise ValueError(f"api_base is required for {custom_llm_provider}")

            verbose_logger.info(
                f"A2A: Using provider config for {custom_llm_provider} (streaming)"
            )

            async for chunk in a2a_provider_config.handle_streaming(
                request_id=request_id,
                params=params,
                api_base=api_base,
            ):
                yield chunk

            return

        # Extract message from params
        message = params.get("message", {})

        # Create streaming context (generates task_id / context_id for the events)
        ctx = A2AStreamingContext(
            request_id=request_id,
            input_message=message,
        )

        # Transform A2A message to OpenAI format
        openai_messages = (
            A2ACompletionBridgeTransformation.a2a_message_to_openai_messages(message)
        )

        # Get completion params
        custom_llm_provider = litellm_params.get("custom_llm_provider")
        model = litellm_params.get("model", "agent")

        # Build full model string if provider specified
        # Skip prepending if model already starts with the provider prefix
        if custom_llm_provider and not model.startswith(f"{custom_llm_provider}/"):
            full_model = f"{custom_llm_provider}/{model}"
        else:
            full_model = model

        verbose_logger.info(
            f"A2A completion bridge streaming: model={full_model}, api_base={api_base}"
        )

        # Build completion params dict
        completion_params = {
            "model": full_model,
            "messages": openai_messages,
            "api_base": api_base,
            "stream": True,
        }
        # Add litellm_params (contains api_key, client_id, client_secret, tenant_id, etc.)
        # NOTE(review): any key in litellm_params other than model/custom_llm_provider
        # overwrites the entries above — including "stream" — confirm that is intended.
        litellm_params_to_add = {
            k: v
            for k, v in litellm_params.items()
            if k not in ("model", "custom_llm_provider")
        }
        completion_params.update(litellm_params_to_add)

        # 1. Emit initial task event (kind: "task", status: "submitted")
        task_event = A2ACompletionBridgeTransformation.create_task_event(ctx)
        yield task_event

        # 2. Emit status update (kind: "status-update", status: "working")
        working_event = A2ACompletionBridgeTransformation.create_status_update_event(
            ctx=ctx,
            state="working",
            final=False,
            message_text="Processing request...",
        )
        yield working_event

        # Call litellm.acompletion with streaming
        response = await litellm.acompletion(**completion_params)

        # 3. Accumulate content and emit artifact update
        accumulated_text = ""
        chunk_count = 0
        async for chunk in response:  # type: ignore[union-attr]
            chunk_count += 1

            # Extract delta content defensively (chunk shape may vary by provider)
            content = ""
            if chunk is not None and hasattr(chunk, "choices") and chunk.choices:
                choice = chunk.choices[0]
                if hasattr(choice, "delta") and choice.delta:
                    content = choice.delta.content or ""

            if content:
                accumulated_text += content

        # Emit artifact update with accumulated content (skipped when the
        # upstream stream produced no text at all)
        if accumulated_text:
            artifact_event = (
                A2ACompletionBridgeTransformation.create_artifact_update_event(
                    ctx=ctx,
                    text=accumulated_text,
                )
            )
            yield artifact_event

        # 4. Emit final status update (kind: "status-update", status: "completed", final: true)
        completed_event = A2ACompletionBridgeTransformation.create_status_update_event(
            ctx=ctx,
            state="completed",
            final=True,
        )
        yield completed_event

        verbose_logger.info(
            f"A2A completion bridge streaming completed: request_id={request_id}, chunks={chunk_count}"
        )
|
||||
|
||||
|
||||
# Convenience module-level wrappers that forward to the handler class
async def handle_a2a_completion(
    request_id: str,
    params: Dict[str, Any],
    litellm_params: Dict[str, Any],
    api_base: Optional[str] = None,
) -> Dict[str, Any]:
    """Run a non-streaming A2A completion through A2ACompletionBridgeHandler."""
    result = await A2ACompletionBridgeHandler.handle_non_streaming(
        request_id=request_id,
        params=params,
        litellm_params=litellm_params,
        api_base=api_base,
    )
    return result
|
||||
|
||||
|
||||
async def handle_a2a_completion_streaming(
    request_id: str,
    params: Dict[str, Any],
    litellm_params: Dict[str, Any],
    api_base: Optional[str] = None,
) -> AsyncIterator[Dict[str, Any]]:
    """Run a streaming A2A completion through the handler, re-yielding each event."""
    stream = A2ACompletionBridgeHandler.handle_streaming(
        request_id=request_id,
        params=params,
        litellm_params=litellm_params,
        api_base=api_base,
    )
    async for event in stream:
        yield event
|
||||
@@ -0,0 +1,284 @@
|
||||
"""
|
||||
Transformation utilities for A2A <-> OpenAI message format conversion.
|
||||
|
||||
A2A Message Format:
|
||||
{
|
||||
"role": "user",
|
||||
"parts": [{"kind": "text", "text": "Hello!"}],
|
||||
"messageId": "abc123"
|
||||
}
|
||||
|
||||
OpenAI Message Format:
|
||||
{"role": "user", "content": "Hello!"}
|
||||
|
||||
A2A Streaming Events:
|
||||
- Task event (kind: "task") - Initial task creation with status "submitted"
|
||||
- Status update (kind: "status-update") - Status changes (working, completed)
|
||||
- Artifact update (kind: "artifact-update") - Content/artifact delivery
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
|
||||
|
||||
class A2AStreamingContext:
    """
    Per-request state holder for A2A streaming.

    Generates fresh task/context identifiers and tracks the accumulated
    response text plus which lifecycle events were already emitted.
    """

    def __init__(self, request_id: str, input_message: Dict[str, Any]):
        # Identity of the originating JSON-RPC request and its input message.
        self.request_id = request_id
        self.input_message = input_message
        # Fresh identifiers for this task and its conversation context.
        self.task_id = str(uuid4())
        self.context_id = str(uuid4())
        # Streaming bookkeeping.
        self.accumulated_text = ""
        self.has_emitted_task = False
        self.has_emitted_working = False
|
||||
|
||||
|
||||
class A2ACompletionBridgeTransformation:
    """
    Static helpers for transforming between A2A and OpenAI message formats,
    and for building A2A streaming lifecycle events.

    A2A message format:
        {"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": "abc123"}

    OpenAI message format:
        {"role": "user", "content": "Hello!"}

    A2A streaming events:
        - Task event (kind: "task") - initial task creation with status "submitted"
        - Status update (kind: "status-update") - status changes (working, completed)
        - Artifact update (kind: "artifact-update") - content/artifact delivery
    """

    # A2A -> OpenAI role mapping. A2A uses "agent" for the agent side (this
    # module itself emits role "agent" in openai_response_to_a2a_response);
    # OpenAI has no "agent" role, so it maps to "assistant". Unknown roles
    # pass through unchanged, preserving previous behavior.
    _A2A_TO_OPENAI_ROLE = {
        "user": "user",
        "assistant": "assistant",
        "system": "system",
        "agent": "assistant",
    }

    @staticmethod
    def a2a_message_to_openai_messages(
        a2a_message: Dict[str, Any],
    ) -> List[Dict[str, str]]:
        """
        Transform an A2A message to OpenAI message format.

        Only "text" parts contribute content; their texts are joined with
        newlines. Non-text parts are ignored.

        Args:
            a2a_message: A2A message with role, parts, and messageId

        Returns:
            List containing a single OpenAI-format message
        """
        role = a2a_message.get("role", "user")
        parts = a2a_message.get("parts", [])

        # Map A2A roles to OpenAI roles (fix: "agent" now maps to
        # "assistant"; previously it passed through unmapped, producing a
        # role OpenAI-format consumers do not accept).
        openai_role = A2ACompletionBridgeTransformation._A2A_TO_OPENAI_ROLE.get(
            role, role
        )

        # Extract text content from parts
        content_parts = []
        for part in parts:
            if part.get("kind", "") == "text":
                content_parts.append(part.get("text", ""))

        content = "\n".join(content_parts) if content_parts else ""

        verbose_logger.debug(
            f"A2A -> OpenAI transform: role={role} -> {openai_role}, content_length={len(content)}"
        )

        return [{"role": openai_role, "content": content}]

    @staticmethod
    def openai_response_to_a2a_response(
        response: Any,
        request_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Transform a LiteLLM ModelResponse to A2A SendMessageResponse format.

        Args:
            response: LiteLLM ModelResponse object
            request_id: Original A2A request ID (echoed in the JSON-RPC envelope)

        Returns:
            A2A SendMessageResponse dict (JSON-RPC envelope with a message result)
        """
        # Extract content from the first choice, defensively: response shape
        # varies across providers.
        content = ""
        if hasattr(response, "choices") and response.choices:
            choice = response.choices[0]
            if hasattr(choice, "message") and choice.message:
                content = choice.message.content or ""

        # Build A2A message (A2A uses role "agent" for agent replies)
        a2a_message = {
            "role": "agent",
            "parts": [{"kind": "text", "text": content}],
            "messageId": uuid4().hex,
        }

        # Build A2A JSON-RPC response envelope
        a2a_response = {
            "jsonrpc": "2.0",
            "id": request_id,
            "result": {
                "message": a2a_message,
            },
        }

        verbose_logger.debug(f"OpenAI -> A2A transform: content_length={len(content)}")

        return a2a_response

    @staticmethod
    def _get_timestamp() -> str:
        """Get the current UTC timestamp in ISO format with timezone."""
        return datetime.now(timezone.utc).isoformat()

    @staticmethod
    def create_task_event(
        ctx: "A2AStreamingContext",
    ) -> Dict[str, Any]:
        """
        Create the initial task event with status 'submitted'.

        This is the first event emitted in an A2A streaming response; it
        echoes the input message back in the task history.
        """
        return {
            "id": ctx.request_id,
            "jsonrpc": "2.0",
            "result": {
                "contextId": ctx.context_id,
                "history": [
                    {
                        "contextId": ctx.context_id,
                        "kind": "message",
                        "messageId": ctx.input_message.get("messageId", uuid4().hex),
                        "parts": ctx.input_message.get("parts", []),
                        "role": ctx.input_message.get("role", "user"),
                        "taskId": ctx.task_id,
                    }
                ],
                "id": ctx.task_id,
                "kind": "task",
                "status": {
                    "state": "submitted",
                },
            },
        }

    @staticmethod
    def create_status_update_event(
        ctx: "A2AStreamingContext",
        state: str,
        final: bool = False,
        message_text: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Create a status update event.

        Args:
            ctx: Streaming context
            state: Status state ('working', 'completed')
            final: Whether this is the final event of the stream
            message_text: Optional message text (only attached for 'working' state)
        """
        status: Dict[str, Any] = {
            "state": state,
            "timestamp": A2ACompletionBridgeTransformation._get_timestamp(),
        }

        # Only the 'working' state carries a progress message
        if state == "working" and message_text:
            status["message"] = {
                "contextId": ctx.context_id,
                "kind": "message",
                "messageId": str(uuid4()),
                "parts": [{"kind": "text", "text": message_text}],
                "role": "agent",
                "taskId": ctx.task_id,
            }

        return {
            "id": ctx.request_id,
            "jsonrpc": "2.0",
            "result": {
                "contextId": ctx.context_id,
                "final": final,
                "kind": "status-update",
                "status": status,
                "taskId": ctx.task_id,
            },
        }

    @staticmethod
    def create_artifact_update_event(
        ctx: "A2AStreamingContext",
        text: str,
    ) -> Dict[str, Any]:
        """
        Create an artifact update event carrying response content.

        Args:
            ctx: Streaming context
            text: The text content for the artifact
        """
        return {
            "id": ctx.request_id,
            "jsonrpc": "2.0",
            "result": {
                "artifact": {
                    "artifactId": str(uuid4()),
                    "name": "response",
                    "parts": [{"kind": "text", "text": text}],
                },
                "contextId": ctx.context_id,
                "kind": "artifact-update",
                "taskId": ctx.task_id,
            },
        }

    @staticmethod
    def openai_chunk_to_a2a_chunk(
        chunk: Any,
        request_id: Optional[str] = None,
        is_final: bool = False,
    ) -> Optional[Dict[str, Any]]:
        """
        Transform a LiteLLM streaming chunk to (legacy) A2A streaming format.

        NOTE: This method is deprecated for streaming. Use the event-based
        methods (create_task_event, create_status_update_event,
        create_artifact_update_event) instead for proper A2A streaming.

        Args:
            chunk: LiteLLM ModelResponse chunk
            request_id: Original A2A request ID
            is_final: Whether this is the final chunk

        Returns:
            A2A streaming chunk dict, or None when the chunk carries no
            content and is not final.
        """
        # Extract delta content defensively
        content = ""
        if chunk is not None and hasattr(chunk, "choices") and chunk.choices:
            choice = chunk.choices[0]
            if hasattr(choice, "delta") and choice.delta:
                content = choice.delta.content or ""

        if not content and not is_final:
            return None

        # Build A2A streaming chunk (legacy format)
        return {
            "jsonrpc": "2.0",
            "id": request_id,
            "result": {
                "message": {
                    "role": "agent",
                    "parts": [{"kind": "text", "text": content}],
                    "messageId": uuid4().hex,
                },
                "final": is_final,
            },
        }
|
||||
@@ -0,0 +1,744 @@
|
||||
"""
|
||||
LiteLLM A2A SDK functions.
|
||||
|
||||
Provides standalone functions with @client decorator for LiteLLM logging integration.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any, AsyncIterator, Coroutine, Dict, Optional, Union
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger, verbose_proxy_logger
|
||||
from litellm.a2a_protocol.streaming_iterator import A2AStreamingIterator
|
||||
from litellm.a2a_protocol.utils import A2ARequestUtils
|
||||
from litellm.constants import DEFAULT_A2A_AGENT_TIMEOUT
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.types.agents import LiteLLMSendMessageResponse
|
||||
from litellm.utils import client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.client import A2AClient as A2AClientType
|
||||
from a2a.types import AgentCard, SendMessageRequest, SendStreamingMessageRequest
|
||||
|
||||
# Runtime imports with availability check
|
||||
A2A_SDK_AVAILABLE = False
|
||||
A2ACardResolver: Any = None
|
||||
_A2AClient: Any = None
|
||||
|
||||
try:
|
||||
from a2a.client import A2AClient as _A2AClient # type: ignore[no-redef]
|
||||
|
||||
A2A_SDK_AVAILABLE = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Import our custom card resolver that supports multiple well-known paths
|
||||
from litellm.a2a_protocol.card_resolver import LiteLLMA2ACardResolver
|
||||
from litellm.a2a_protocol.exception_mapping_utils import (
|
||||
handle_a2a_localhost_retry,
|
||||
map_a2a_exception,
|
||||
)
|
||||
from litellm.a2a_protocol.exceptions import A2ALocalhostURLError
|
||||
|
||||
# Use our custom resolver instead of the default A2A SDK resolver
|
||||
A2ACardResolver = LiteLLMA2ACardResolver
|
||||
|
||||
|
||||
def _set_usage_on_logging_obj(
|
||||
kwargs: Dict[str, Any],
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
) -> None:
|
||||
"""
|
||||
Set usage on litellm_logging_obj for standard logging payload.
|
||||
|
||||
Args:
|
||||
kwargs: The kwargs dict containing litellm_logging_obj
|
||||
prompt_tokens: Number of input tokens
|
||||
completion_tokens: Number of output tokens
|
||||
"""
|
||||
litellm_logging_obj = kwargs.get("litellm_logging_obj")
|
||||
if litellm_logging_obj is not None:
|
||||
usage = litellm.Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
litellm_logging_obj.model_call_details["usage"] = usage
|
||||
|
||||
|
||||
def _set_agent_id_on_logging_obj(
|
||||
kwargs: Dict[str, Any],
|
||||
agent_id: Optional[str],
|
||||
) -> None:
|
||||
"""
|
||||
Set agent_id on litellm_logging_obj for SpendLogs tracking.
|
||||
|
||||
Args:
|
||||
kwargs: The kwargs dict containing litellm_logging_obj
|
||||
agent_id: The A2A agent ID
|
||||
"""
|
||||
if agent_id is None:
|
||||
return
|
||||
|
||||
litellm_logging_obj = kwargs.get("litellm_logging_obj")
|
||||
if litellm_logging_obj is not None:
|
||||
# Set agent_id directly on model_call_details (same pattern as custom_llm_provider)
|
||||
litellm_logging_obj.model_call_details["agent_id"] = agent_id
|
||||
|
||||
|
||||
def _get_a2a_model_info(a2a_client: Any, kwargs: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Extract agent info and set model/custom_llm_provider for cost tracking.
|
||||
|
||||
Sets model info on the litellm_logging_obj if available.
|
||||
Returns the agent name for logging.
|
||||
"""
|
||||
agent_name = "unknown"
|
||||
|
||||
# Try to get agent card from our stored attribute first, then fallback to SDK attribute
|
||||
agent_card = getattr(a2a_client, "_litellm_agent_card", None)
|
||||
if agent_card is None:
|
||||
agent_card = getattr(a2a_client, "agent_card", None)
|
||||
|
||||
if agent_card is not None:
|
||||
agent_name = getattr(agent_card, "name", "unknown") or "unknown"
|
||||
|
||||
# Build model string
|
||||
model = f"a2a_agent/{agent_name}"
|
||||
custom_llm_provider = "a2a_agent"
|
||||
|
||||
# Set on litellm_logging_obj if available (for standard logging payload)
|
||||
litellm_logging_obj = kwargs.get("litellm_logging_obj")
|
||||
if litellm_logging_obj is not None:
|
||||
litellm_logging_obj.model = model
|
||||
litellm_logging_obj.custom_llm_provider = custom_llm_provider
|
||||
litellm_logging_obj.model_call_details["model"] = model
|
||||
litellm_logging_obj.model_call_details[
|
||||
"custom_llm_provider"
|
||||
] = custom_llm_provider
|
||||
|
||||
return agent_name
|
||||
|
||||
|
||||
async def _send_message_via_completion_bridge(
    request: "SendMessageRequest",
    custom_llm_provider: str,
    api_base: Optional[str],
    litellm_params: Dict[str, Any],
) -> LiteLLMSendMessageResponse:
    """
    Route a send_message through the LiteLLM completion bridge
    (e.g. LangGraph, Bedrock AgentCore).

    Requires request; api_base is optional for providers that derive the
    endpoint from the model.
    """
    verbose_logger.info(
        f"A2A using completion bridge: provider={custom_llm_provider}, api_base={api_base}"
    )

    # Lazy import — presumably to avoid a circular import at module load; confirm.
    from litellm.a2a_protocol.litellm_completion_bridge.handler import (
        A2ACompletionBridgeHandler,
    )

    # Normalize params to a plain dict (pydantic model or mapping).
    if hasattr(request.params, "model_dump"):
        params = request.params.model_dump(mode="json")
    else:
        params = dict(request.params)

    raw_response = await A2ACompletionBridgeHandler.handle_non_streaming(
        request_id=str(request.id),
        params=params,
        litellm_params=litellm_params,
        api_base=api_base,
    )

    return LiteLLMSendMessageResponse.from_dict(raw_response)
|
||||
|
||||
|
||||
async def _execute_a2a_send_with_retry(
    a2a_client: Any,
    request: Any,
    agent_card: Any,
    card_url: Optional[str],
    api_base: Optional[str],
    agent_name: Optional[str],
) -> Any:
    """Send an A2A message with retry logic for localhost URL errors.

    Makes at most two attempts. On a localhost URL error (raised directly,
    or produced by mapping another exception), the client is rebuilt via
    handle_a2a_localhost_retry and the send is retried once.

    Args:
        a2a_client: The A2A client used to send the message (may be replaced on retry)
        request: The SendMessageRequest to send
        agent_card: The agent card (used to rebuild the client / refresh card_url)
        card_url: Current agent card URL, passed to exception mapping
        api_base: API base URL, passed to exception mapping
        agent_name: Agent name, used as the model label in mapped exceptions

    Returns:
        The raw A2A SendMessageResponse.

    Raises:
        RuntimeError: if no response was received after the retry attempts.
        Exception: whatever map_a2a_exception raises for non-localhost errors.
    """
    a2a_response = None
    for _ in range(2):  # max 2 attempts: original + 1 retry
        try:
            a2a_response = await a2a_client.send_message(request)
            break  # success, exit retry loop
        except A2ALocalhostURLError as e:
            # Direct localhost error: rebuild the client and retry once.
            a2a_client = handle_a2a_localhost_retry(
                error=e,
                agent_card=agent_card,
                a2a_client=a2a_client,
                is_streaming=False,
            )
            card_url = agent_card.url if agent_card else None
        except Exception as e:
            # Map the raw exception; map_a2a_exception normally raises a
            # LiteLLM-mapped error itself.
            try:
                map_a2a_exception(e, card_url, api_base, model=agent_name)
            except A2ALocalhostURLError as localhost_err:
                # Mapping identified a localhost URL problem: rebuild and retry.
                a2a_client = handle_a2a_localhost_retry(
                    error=localhost_err,
                    agent_card=agent_card,
                    a2a_client=a2a_client,
                    is_streaming=False,
                )
                card_url = agent_card.url if agent_card else None
                continue
            except Exception:
                # Propagate the mapped (non-localhost) error unchanged.
                raise
    if a2a_response is None:
        raise RuntimeError(
            "A2A send_message failed: no response received after retry attempts."
        )
    return a2a_response
|
||||
|
||||
|
||||
@client
async def asend_message(
    a2a_client: Optional["A2AClientType"] = None,
    request: Optional["SendMessageRequest"] = None,
    api_base: Optional[str] = None,
    litellm_params: Optional[Dict[str, Any]] = None,
    agent_id: Optional[str] = None,
    agent_extra_headers: Optional[Dict[str, str]] = None,
    **kwargs: Any,
) -> LiteLLMSendMessageResponse:
    """
    Async: Send a message to an A2A agent.

    Uses the @client decorator for LiteLLM logging and tracking.
    If litellm_params contains custom_llm_provider, routes through the completion bridge.

    Args:
        a2a_client: An initialized a2a.client.A2AClient instance (optional if using completion bridge)
        request: SendMessageRequest from a2a.types (required in both flows)
        api_base: API base URL (required for completion bridge, optional for standard A2A)
        litellm_params: Optional dict with custom_llm_provider, model, etc. for completion bridge
        agent_id: Optional agent ID for tracking in SpendLogs
        agent_extra_headers: Optional agent-level headers; they override LiteLLM-internal ones
        **kwargs: Additional arguments passed to the client decorator

    Returns:
        LiteLLMSendMessageResponse (wraps a2a SendMessageResponse with _hidden_params)

    Raises:
        ValueError: if request is missing, or neither a2a_client nor api_base
            is provided in the standard flow.

    Example (standard A2A):
        ```python
        from litellm.a2a_protocol import asend_message, create_a2a_client
        from a2a.types import SendMessageRequest, MessageSendParams
        from uuid import uuid4

        a2a_client = await create_a2a_client(base_url="http://localhost:10001")
        request = SendMessageRequest(
            id=str(uuid4()),
            params=MessageSendParams(
                message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
            )
        )
        response = await asend_message(a2a_client=a2a_client, request=request)
        ```

    Example (completion bridge with LangGraph):
        ```python
        from litellm.a2a_protocol import asend_message
        from a2a.types import SendMessageRequest, MessageSendParams
        from uuid import uuid4

        request = SendMessageRequest(
            id=str(uuid4()),
            params=MessageSendParams(
                message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
            )
        )
        response = await asend_message(
            request=request,
            api_base="http://localhost:2024",
            litellm_params={"custom_llm_provider": "langgraph", "model": "agent"},
        )
        ```
    """
    litellm_params = litellm_params or {}
    # Trace id comes from the decorator-provided logging object when available.
    logging_obj = kwargs.get("litellm_logging_obj")
    trace_id = getattr(logging_obj, "litellm_trace_id", None) if logging_obj else None
    custom_llm_provider = litellm_params.get("custom_llm_provider")

    # Route through completion bridge if custom_llm_provider is set
    if custom_llm_provider:
        if request is None:
            raise ValueError("request is required for completion bridge")
        return await _send_message_via_completion_bridge(
            request=request,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            litellm_params=litellm_params,
        )

    # Standard A2A client flow
    if request is None:
        raise ValueError("request is required")

    # Create A2A client if not provided but api_base is available
    if a2a_client is None:
        if api_base is None:
            raise ValueError(
                "Either a2a_client or api_base is required for standard A2A flow"
            )
        trace_id = trace_id or str(uuid.uuid4())
        extra_headers: Dict[str, str] = {"X-LiteLLM-Trace-Id": trace_id}
        if agent_id:
            extra_headers["X-LiteLLM-Agent-Id"] = agent_id
        # Overlay agent-level headers (agent headers take precedence over LiteLLM internal ones)
        if agent_extra_headers:
            extra_headers.update(agent_extra_headers)
        a2a_client = await create_a2a_client(
            base_url=api_base, extra_headers=extra_headers
        )

    # Type assertion: a2a_client is guaranteed to be non-None here
    assert a2a_client is not None

    # Also stamps model / custom_llm_provider on the logging obj for cost tracking.
    agent_name = _get_a2a_model_info(a2a_client, kwargs)

    verbose_logger.info(f"A2A send_message request_id={request.id}, agent={agent_name}")

    # Get agent card URL for localhost retry logic
    agent_card = getattr(a2a_client, "_litellm_agent_card", None) or getattr(
        a2a_client, "agent_card", None
    )
    card_url = getattr(agent_card, "url", None) if agent_card else None

    # Ensure the outgoing message carries a context_id (dict or pydantic message).
    context_id = trace_id or str(uuid.uuid4())
    message = request.params.message
    if isinstance(message, dict):
        if message.get("context_id") is None:
            message["context_id"] = context_id
    else:
        if getattr(message, "context_id", None) is None:
            message.context_id = context_id

    a2a_response = await _execute_a2a_send_with_retry(
        a2a_client=a2a_client,
        request=request,
        agent_card=agent_card,
        card_url=card_url,
        api_base=api_base,
        agent_name=agent_name,
    )

    verbose_logger.info(f"A2A send_message completed, request_id={request.id}")

    # Wrap in LiteLLM response type for _hidden_params support
    response = LiteLLMSendMessageResponse.from_a2a_response(a2a_response)

    # Calculate token usage from request and response
    response_dict = a2a_response.model_dump(mode="json", exclude_none=True)
    (
        prompt_tokens,
        completion_tokens,
        _,
    ) = A2ARequestUtils.calculate_usage_from_request_response(
        request=request,
        response_dict=response_dict,
    )

    # Set usage on logging obj for standard logging payload
    _set_usage_on_logging_obj(
        kwargs=kwargs,
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )

    # Set agent_id on logging obj for SpendLogs tracking
    _set_agent_id_on_logging_obj(kwargs=kwargs, agent_id=agent_id)

    return response
|
||||
|
||||
|
||||
@client
def send_message(
    a2a_client: "A2AClientType",
    request: "SendMessageRequest",
    **kwargs: Any,
) -> Union[LiteLLMSendMessageResponse, Coroutine[Any, Any, LiteLLMSendMessageResponse]]:
    """
    Sync: Send a message to an A2A agent.

    Uses the @client decorator for LiteLLM logging and tracking. When called
    from inside a running event loop, the coroutine is returned for the caller
    to await; otherwise the async implementation is driven to completion with
    asyncio.run().

    Args:
        a2a_client: An initialized a2a.client.A2AClient instance
        request: SendMessageRequest from a2a.types
        **kwargs: Additional arguments passed to the client decorator

    Returns:
        LiteLLMSendMessageResponse (wraps a2a SendMessageResponse with _hidden_params)
    """
    try:
        running_loop = asyncio.get_running_loop()
    except RuntimeError:
        running_loop = None

    coro = asend_message(a2a_client=a2a_client, request=request, **kwargs)
    if running_loop is not None:
        # Already inside an event loop: hand the coroutine back to the caller.
        return coro
    return asyncio.run(coro)
|
||||
|
||||
|
||||
def _build_streaming_logging_obj(
    request: "SendStreamingMessageRequest",
    agent_name: str,
    agent_id: Optional[str],
    litellm_params: Optional[Dict[str, Any]],
    metadata: Optional[Dict[str, Any]],
    proxy_server_request: Optional[Dict[str, Any]],
) -> Logging:
    """
    Construct the Logging object used for streaming A2A requests.

    Stamps model / custom_llm_provider / agent_id onto the logging object and
    merges metadata and proxy request info into its litellm_params.
    """
    model = f"a2a_agent/{agent_name}"

    logging_obj = Logging(
        model=model,
        messages=[{"role": "user", "content": "streaming-request"}],
        stream=False,
        call_type="asend_message_streaming",
        start_time=datetime.datetime.now(),
        litellm_call_id=str(request.id),
        function_id=str(request.id),
    )

    # Mirror model/provider info onto model_call_details for the payload.
    logging_obj.model = model
    logging_obj.custom_llm_provider = "a2a_agent"
    logging_obj.model_call_details["model"] = model
    logging_obj.model_call_details["custom_llm_provider"] = "a2a_agent"
    if agent_id:
        logging_obj.model_call_details["agent_id"] = agent_id

    # Merge metadata / proxy request into a copy of litellm_params.
    merged_params: Dict[str, Any] = dict(litellm_params) if litellm_params else {}
    if metadata:
        merged_params["metadata"] = metadata
    if proxy_server_request:
        merged_params["proxy_server_request"] = proxy_server_request

    logging_obj.litellm_params = merged_params
    logging_obj.optional_params = merged_params
    logging_obj.model_call_details["litellm_params"] = merged_params
    logging_obj.model_call_details["metadata"] = metadata or {}

    return logging_obj
|
||||
|
||||
|
||||
async def asend_message_streaming(  # noqa: PLR0915
    a2a_client: Optional["A2AClientType"] = None,
    request: Optional["SendStreamingMessageRequest"] = None,
    api_base: Optional[str] = None,
    litellm_params: Optional[Dict[str, Any]] = None,
    agent_id: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
    proxy_server_request: Optional[Dict[str, Any]] = None,
    agent_extra_headers: Optional[Dict[str, str]] = None,
) -> AsyncIterator[Any]:
    """
    Async: Send a streaming message to an A2A agent.

    If litellm_params contains custom_llm_provider, routes through the completion bridge.

    Args:
        a2a_client: An initialized a2a.client.A2AClient instance (optional if using completion bridge)
        request: SendStreamingMessageRequest from a2a.types
        api_base: API base URL (required for completion bridge)
        litellm_params: Optional dict with custom_llm_provider, model, etc. for completion bridge
        agent_id: Optional agent ID for tracking in SpendLogs
        metadata: Optional metadata dict (contains user_api_key, user_id, team_id, etc.)
        proxy_server_request: Optional proxy server request data
        agent_extra_headers: Optional extra headers merged into the headers of a
            client created here (ignored when a2a_client is supplied by the caller)

    Yields:
        SendStreamingMessageResponse chunks from the agent

    Raises:
        ValueError: If request is None, or if neither a2a_client nor api_base is
            provided for the standard A2A flow.

    Example (completion bridge with LangGraph):
        ```python
        from litellm.a2a_protocol import asend_message_streaming
        from a2a.types import SendStreamingMessageRequest, MessageSendParams
        from uuid import uuid4

        request = SendStreamingMessageRequest(
            id=str(uuid4()),
            params=MessageSendParams(
                message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
            )
        )
        async for chunk in asend_message_streaming(
            request=request,
            api_base="http://localhost:2024",
            litellm_params={"custom_llm_provider": "langgraph", "model": "agent"},
        ):
            print(chunk)
        ```
    """
    litellm_params = litellm_params or {}
    custom_llm_provider = litellm_params.get("custom_llm_provider")

    # Route through completion bridge if custom_llm_provider is set
    if custom_llm_provider:
        if request is None:
            raise ValueError("request is required for completion bridge")
        # api_base is optional for providers that derive endpoint from model (e.g., bedrock/agentcore)

        verbose_logger.info(
            f"A2A streaming using completion bridge: provider={custom_llm_provider}"
        )

        # Imported lazily so the bridge (and litellm.acompletion machinery)
        # is only loaded when actually routed through.
        from litellm.a2a_protocol.litellm_completion_bridge.handler import (
            A2ACompletionBridgeHandler,
        )

        # Extract params from request (pydantic model or plain mapping)
        params = (
            request.params.model_dump(mode="json")
            if hasattr(request.params, "model_dump")
            else dict(request.params)
        )

        async for chunk in A2ACompletionBridgeHandler.handle_streaming(
            request_id=str(request.id),
            params=params,
            litellm_params=litellm_params,
            api_base=api_base,
        ):
            yield chunk
        return

    # Standard A2A client flow
    if request is None:
        raise ValueError("request is required")

    # Create A2A client if not provided but api_base is available
    if a2a_client is None:
        if api_base is None:
            raise ValueError(
                "Either a2a_client or api_base is required for standard A2A flow"
            )
        # Mirror the non-streaming path: always include trace and agent-id headers
        streaming_extra_headers: Dict[str, str] = {
            "X-LiteLLM-Trace-Id": str(request.id),
        }
        if agent_id:
            streaming_extra_headers["X-LiteLLM-Agent-Id"] = agent_id
        if agent_extra_headers:
            streaming_extra_headers.update(agent_extra_headers)
        a2a_client = await create_a2a_client(
            base_url=api_base, extra_headers=streaming_extra_headers
        )

    # Type assertion: a2a_client is guaranteed to be non-None here
    assert a2a_client is not None

    verbose_logger.info(f"A2A send_message_streaming request_id={request.id}")

    # Build logging object for streaming completion callbacks.
    # Clients created by create_a2a_client stash the card on
    # _litellm_agent_card; fall back to the SDK's own attribute if present.
    agent_card = getattr(a2a_client, "_litellm_agent_card", None) or getattr(
        a2a_client, "agent_card", None
    )
    card_url = getattr(agent_card, "url", None) if agent_card else None
    agent_name = getattr(agent_card, "name", "unknown") if agent_card else "unknown"

    logging_obj = _build_streaming_logging_obj(
        request=request,
        agent_name=agent_name,
        agent_id=agent_id,
        litellm_params=litellm_params,
        metadata=metadata,
        proxy_server_request=proxy_server_request,
    )

    # Retry loop: if connection fails due to localhost URL in agent card, retry with fixed URL
    # Connection errors in streaming typically occur on first chunk iteration
    first_chunk = True
    for attempt in range(2):  # max 2 attempts: original + 1 retry
        stream = a2a_client.send_message_streaming(request)
        iterator = A2AStreamingIterator(
            stream=stream,
            request=request,
            logging_obj=logging_obj,
            agent_name=agent_name,
        )

        try:
            # first_chunk stays True until one chunk is successfully yielded;
            # it gates the retry/exception-mapping below so a failure
            # mid-stream is never retried (chunks were already delivered).
            first_chunk = True
            async for chunk in iterator:
                if first_chunk:
                    first_chunk = False  # connection succeeded
                yield chunk
            return  # stream completed successfully
        except A2ALocalhostURLError as e:
            # Only retry on first chunk, not mid-stream
            if first_chunk and attempt == 0:
                # handle_a2a_localhost_retry returns a client pointed at the
                # corrected (non-localhost) URL; agent_card is mutated in place
                # so card_url must be re-read for exception mapping.
                a2a_client = handle_a2a_localhost_retry(
                    error=e,
                    agent_card=agent_card,
                    a2a_client=a2a_client,
                    is_streaming=True,
                )
                card_url = agent_card.url if agent_card else None
            else:
                raise
        except Exception as e:
            # Only map exception on first chunk
            if first_chunk and attempt == 0:
                try:
                    # map_a2a_exception raises the mapped exception type;
                    # A2ALocalhostURLError is intercepted for a retry.
                    map_a2a_exception(e, card_url, api_base, model=agent_name)
                except A2ALocalhostURLError as localhost_err:
                    # Localhost URL error - fix and retry
                    a2a_client = handle_a2a_localhost_retry(
                        error=localhost_err,
                        agent_card=agent_card,
                        a2a_client=a2a_client,
                        is_streaming=True,
                    )
                    card_url = agent_card.url if agent_card else None
                    continue
                except Exception:
                    # Re-raise the mapped exception
                    raise
            # Mid-stream failure, second attempt, or map_a2a_exception
            # returned without raising: propagate the original error.
            raise
|
||||
|
||||
|
||||
async def create_a2a_client(
    base_url: str,
    timeout: float = 60.0,
    extra_headers: Optional[Dict[str, str]] = None,
) -> "A2AClientType":
    """
    Build a ready-to-use A2A client for the agent at *base_url*.

    Resolves the agent card first, then wraps it in an ``A2AClient``.
    The returned client can be reused for multiple requests.

    Args:
        base_url: Base URL of the A2A agent (e.g., "http://localhost:10001")
        timeout: Request timeout in seconds (default: 60.0)
        extra_headers: Additional headers attached to every request

    Returns:
        An initialized a2a.client.A2AClient instance

    Raises:
        ImportError: If the optional ``a2a`` SDK is not installed.

    Example:
        ```python
        from litellm.a2a_protocol import create_a2a_client, asend_message

        client = await create_a2a_client(base_url="http://localhost:10001")
        response1 = await asend_message(a2a_client=client, request=request1)
        response2 = await asend_message(a2a_client=client, request=request2)
        ```
    """
    if not A2A_SDK_AVAILABLE:
        raise ImportError(
            "The 'a2a' package is required for A2A agent invocation. "
            "Install it with: pip install a2a-sdk"
        )

    verbose_logger.info(f"Creating A2A client for {base_url}")

    # Per-agent params for get_async_httpx_client: the params dict is hashed
    # into the cache key, so agents with different extra_headers get separate
    # cached clients (keeping per-agent auth isolated) while connections are
    # still reused within the same agent.
    #
    # Only kwargs accepted by AsyncHTTPHandler.__init__ (e.g. timeout) may be
    # real params; "disable_aiohttp_transport" carries cache-key-only data and
    # is filtered out before reaching the constructor.
    client_params: dict = {"timeout": timeout}
    if extra_headers:
        # Fold the header set into the cache key so every unique header
        # combination maps to a distinct cached client.
        client_params["disable_aiohttp_transport"] = str(sorted(extra_headers.items()))

    async_handler = get_async_httpx_client(
        llm_provider=httpxSpecialProvider.A2AProvider,
        params=client_params,
    )
    httpx_client = async_handler.client
    if extra_headers:
        httpx_client.headers.update(extra_headers)
        verbose_proxy_logger.debug(
            f"A2A client created with extra_headers={list(extra_headers.keys())}"
        )

    # Fetch the agent card so the client knows the agent's capabilities/endpoints.
    resolver = A2ACardResolver(httpx_client=httpx_client, base_url=base_url)
    agent_card = await resolver.get_agent_card()

    verbose_logger.debug(
        f"Resolved agent card: {agent_card.name if hasattr(agent_card, 'name') else 'unknown'}"
    )

    client = _A2AClient(httpx_client=httpx_client, agent_card=agent_card)

    # The SDK does not expose the card on the client; stash it so callers
    # (e.g. logging) can retrieve it later.
    client._litellm_agent_card = agent_card  # type: ignore[attr-defined]

    verbose_logger.info(f"A2A client created for {base_url}")

    return client
|
||||
|
||||
|
||||
async def aget_agent_card(
    base_url: str,
    timeout: float = DEFAULT_A2A_AGENT_TIMEOUT,
    extra_headers: Optional[Dict[str, str]] = None,
) -> "AgentCard":
    """
    Fetch the agent card from an A2A agent.

    Args:
        base_url: The base URL of the A2A agent (e.g., "http://localhost:10001")
        timeout: Request timeout in seconds (default: DEFAULT_A2A_AGENT_TIMEOUT)
        extra_headers: Optional additional headers to include in requests

    Returns:
        AgentCard from the A2A agent

    Raises:
        ImportError: If the optional ``a2a`` SDK is not installed.
    """
    if not A2A_SDK_AVAILABLE:
        raise ImportError(
            "The 'a2a' package is required for A2A agent invocation. "
            "Install it with: pip install a2a-sdk"
        )

    verbose_logger.info(f"Fetching agent card from {base_url}")

    # Use LiteLLM's cached httpx client. Mirror create_a2a_client: fold any
    # extra_headers into the cache key (via the cache-key-only
    # "disable_aiohttp_transport" param) so agents with different auth headers
    # never share a cached client.
    _params: dict = {"timeout": timeout}
    if extra_headers:
        _params["disable_aiohttp_transport"] = str(sorted(extra_headers.items()))
    http_handler = get_async_httpx_client(
        llm_provider=httpxSpecialProvider.A2A,
        params=_params,
    )
    httpx_client = http_handler.client
    # BUG FIX: extra_headers was accepted and documented but never applied,
    # so card fetches against authenticated agents were sent without headers.
    if extra_headers:
        httpx_client.headers.update(extra_headers)

    resolver = A2ACardResolver(
        httpx_client=httpx_client,
        base_url=base_url,
    )
    agent_card = await resolver.get_agent_card()

    verbose_logger.info(
        f"Fetched agent card: {agent_card.name if hasattr(agent_card, 'name') else 'unknown'}"
    )
    return agent_card
|
||||
@@ -0,0 +1,10 @@
|
||||
"""
|
||||
A2A Protocol Providers.
|
||||
|
||||
This module contains provider-specific implementations for the A2A protocol.
|
||||
"""
|
||||
|
||||
from litellm.a2a_protocol.providers.base import BaseA2AProviderConfig
|
||||
from litellm.a2a_protocol.providers.config_manager import A2AProviderConfigManager
|
||||
|
||||
__all__ = ["BaseA2AProviderConfig", "A2AProviderConfigManager"]
|
||||
@@ -0,0 +1,62 @@
|
||||
"""
|
||||
Base configuration for A2A protocol providers.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, AsyncIterator, Dict
|
||||
|
||||
|
||||
class BaseA2AProviderConfig(ABC):
    """
    Abstract interface for A2A protocol provider configurations.

    A concrete provider implements both handlers below to define how A2A
    requests are serviced for its specific agent type.
    """

    @abstractmethod
    async def handle_non_streaming(
        self,
        request_id: str,
        params: Dict[str, Any],
        api_base: str,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Service a non-streaming A2A request.

        Args:
            request_id: A2A JSON-RPC request ID
            params: A2A MessageSendParams containing the message
            api_base: Base URL of the agent
            **kwargs: Additional provider-specific parameters

        Returns:
            A2A SendMessageResponse dict
        """
        ...

    @abstractmethod
    async def handle_streaming(
        self,
        request_id: str,
        params: Dict[str, Any],
        api_base: str,
        **kwargs,
    ) -> AsyncIterator[Dict[str, Any]]:
        """
        Service a streaming A2A request.

        Args:
            request_id: A2A JSON-RPC request ID
            params: A2A MessageSendParams containing the message
            api_base: Base URL of the agent
            **kwargs: Additional provider-specific parameters

        Yields:
            A2A streaming response events
        """
        # The unreachable yield marks this abstract method as an async
        # generator so overrides inherit the expected iterator protocol.
        if False:  # pragma: no cover
            yield {}
|
||||
@@ -0,0 +1,47 @@
|
||||
"""
|
||||
A2A Provider Config Manager.
|
||||
|
||||
Manages provider-specific configurations for A2A protocol.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from litellm.a2a_protocol.providers.base import BaseA2AProviderConfig
|
||||
|
||||
|
||||
class A2AProviderConfigManager:
    """
    Registry/factory for A2A provider configurations.

    Mirrors ProviderConfigManager in litellm.utils, scoped to A2A providers.
    """

    @staticmethod
    def get_provider_config(
        custom_llm_provider: Optional[str],
    ) -> Optional[BaseA2AProviderConfig]:
        """
        Resolve the provider configuration for *custom_llm_provider*.

        Args:
            custom_llm_provider: The provider identifier (e.g., "pydantic_ai_agents")

        Returns:
            Provider configuration instance, or None when the provider is
            unknown or not given.
        """
        if custom_llm_provider == "pydantic_ai_agents":
            # Imported lazily so provider dependencies are only loaded on use.
            from litellm.a2a_protocol.providers.pydantic_ai_agents.config import (
                PydanticAIProviderConfig,
            )

            return PydanticAIProviderConfig()

        # Register additional providers above as they are added:
        # elif custom_llm_provider == "another_provider": ...

        # None or unrecognized provider: nothing to dispatch to.
        return None
|
||||
@@ -0,0 +1,74 @@
|
||||
# A2A to LiteLLM Completion Bridge
|
||||
|
||||
Routes A2A protocol requests through `litellm.acompletion`, enabling any LiteLLM-supported provider to be invoked via A2A.
|
||||
|
||||
## Flow
|
||||
|
||||
```
|
||||
A2A Request → Transform → litellm.acompletion → Transform → A2A Response
|
||||
```
|
||||
|
||||
## SDK Usage
|
||||
|
||||
Use the existing `asend_message` and `asend_message_streaming` functions with `litellm_params`:
|
||||
|
||||
```python
|
||||
from litellm.a2a_protocol import asend_message, asend_message_streaming
|
||||
from a2a.types import SendMessageRequest, SendStreamingMessageRequest, MessageSendParams
|
||||
from uuid import uuid4
|
||||
|
||||
# Non-streaming
|
||||
request = SendMessageRequest(
|
||||
id=str(uuid4()),
|
||||
params=MessageSendParams(
|
||||
message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
|
||||
)
|
||||
)
|
||||
response = await asend_message(
|
||||
request=request,
|
||||
api_base="http://localhost:2024",
|
||||
litellm_params={"custom_llm_provider": "langgraph", "model": "agent"},
|
||||
)
|
||||
|
||||
# Streaming
|
||||
stream_request = SendStreamingMessageRequest(
|
||||
id=str(uuid4()),
|
||||
params=MessageSendParams(
|
||||
message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
|
||||
)
|
||||
)
|
||||
async for chunk in asend_message_streaming(
|
||||
request=stream_request,
|
||||
api_base="http://localhost:2024",
|
||||
litellm_params={"custom_llm_provider": "langgraph", "model": "agent"},
|
||||
):
|
||||
print(chunk)
|
||||
```
|
||||
|
||||
## Proxy Usage
|
||||
|
||||
Configure an agent with `custom_llm_provider` in `litellm_params`:
|
||||
|
||||
```yaml
|
||||
agents:
|
||||
- agent_name: my-langgraph-agent
|
||||
agent_card_params:
|
||||
name: "LangGraph Agent"
|
||||
url: "http://localhost:2024" # Used as api_base
|
||||
litellm_params:
|
||||
custom_llm_provider: langgraph
|
||||
model: agent
|
||||
```
|
||||
|
||||
When an A2A request hits `/a2a/{agent_id}/message/send`, the bridge:
|
||||
|
||||
1. Detects `custom_llm_provider` in agent's `litellm_params`
|
||||
2. Transforms A2A message → OpenAI messages
|
||||
3. Calls `litellm.acompletion(model="langgraph/agent", api_base="http://localhost:2024")`
|
||||
4. Transforms response → A2A format
|
||||
|
||||
## Classes
|
||||
|
||||
- `A2ACompletionBridgeTransformation` - Static methods for message format conversion
|
||||
- `A2ACompletionBridgeHandler` - Static methods for handling requests (streaming/non-streaming)
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
"""
|
||||
LiteLLM Completion bridge provider for A2A protocol.
|
||||
|
||||
Routes A2A requests through litellm.acompletion based on custom_llm_provider.
|
||||
"""
|
||||
@@ -0,0 +1,301 @@
|
||||
"""
|
||||
Handler for A2A to LiteLLM completion bridge.
|
||||
|
||||
Routes A2A requests through litellm.acompletion based on custom_llm_provider.
|
||||
|
||||
A2A Streaming Events (in order):
|
||||
1. Task event (kind: "task") - Initial task creation with status "submitted"
|
||||
2. Status update (kind: "status-update") - Status change to "working"
|
||||
3. Artifact update (kind: "artifact-update") - Content/artifact delivery
|
||||
4. Status update (kind: "status-update") - Final status "completed" with final=true
|
||||
"""
|
||||
|
||||
from typing import Any, AsyncIterator, Dict, Optional
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.a2a_protocol.litellm_completion_bridge.pydantic_ai_transformation import (
|
||||
PydanticAITransformation,
|
||||
)
|
||||
from litellm.a2a_protocol.litellm_completion_bridge.transformation import (
|
||||
A2ACompletionBridgeTransformation,
|
||||
A2AStreamingContext,
|
||||
)
|
||||
|
||||
|
||||
class A2ACompletionBridgeHandler:
    """
    Static methods for handling A2A requests via LiteLLM completion.

    Non-streaming requests are transformed to OpenAI messages, sent through
    ``litellm.acompletion``, and the result is transformed back to an A2A
    response. Streaming requests additionally emit the A2A event sequence
    (task -> working -> artifact-update -> completed).
    """

    @staticmethod
    async def handle_non_streaming(
        request_id: str,
        params: Dict[str, Any],
        litellm_params: Dict[str, Any],
        api_base: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Handle non-streaming A2A request via litellm.acompletion.

        Args:
            request_id: A2A JSON-RPC request ID
            params: A2A MessageSendParams containing the message
            litellm_params: Agent's litellm_params (custom_llm_provider, model, etc.)
            api_base: API base URL from agent_card_params

        Returns:
            A2A SendMessageResponse dict

        Raises:
            ValueError: If the provider is "pydantic_ai_agents" and api_base is None.
        """
        # Check if this is a Pydantic AI agent request
        custom_llm_provider = litellm_params.get("custom_llm_provider")
        if custom_llm_provider == "pydantic_ai_agents":
            if api_base is None:
                raise ValueError("api_base is required for Pydantic AI agents")

            verbose_logger.info(
                f"Pydantic AI: Routing to Pydantic AI agent at {api_base}"
            )

            # Send request directly to Pydantic AI agent
            # (bypasses litellm.acompletion entirely)
            response_data = await PydanticAITransformation.send_non_streaming_request(
                api_base=api_base,
                request_id=request_id,
                params=params,
            )

            return response_data

        # Extract message from params
        message = params.get("message", {})

        # Transform A2A message to OpenAI format
        openai_messages = (
            A2ACompletionBridgeTransformation.a2a_message_to_openai_messages(message)
        )

        # Get completion params
        # NOTE(review): custom_llm_provider was already read above; this
        # second lookup is redundant but harmless.
        custom_llm_provider = litellm_params.get("custom_llm_provider")
        model = litellm_params.get("model", "agent")

        # Build full model string if provider specified
        # Skip prepending if model already starts with the provider prefix
        if custom_llm_provider and not model.startswith(f"{custom_llm_provider}/"):
            full_model = f"{custom_llm_provider}/{model}"
        else:
            full_model = model

        verbose_logger.info(
            f"A2A completion bridge: model={full_model}, api_base={api_base}"
        )

        # Build completion params dict
        completion_params = {
            "model": full_model,
            "messages": openai_messages,
            "api_base": api_base,
            "stream": False,
        }
        # Add litellm_params (contains api_key, client_id, client_secret, tenant_id, etc.)
        # NOTE(review): because of dict.update below, keys present in
        # litellm_params (e.g. "api_base", "stream") override the defaults
        # set above.
        litellm_params_to_add = {
            k: v
            for k, v in litellm_params.items()
            if k not in ("model", "custom_llm_provider")
        }
        completion_params.update(litellm_params_to_add)

        # Call litellm.acompletion
        response = await litellm.acompletion(**completion_params)

        # Transform response to A2A format
        a2a_response = (
            A2ACompletionBridgeTransformation.openai_response_to_a2a_response(
                response=response,
                request_id=request_id,
            )
        )

        verbose_logger.info(f"A2A completion bridge completed: request_id={request_id}")

        return a2a_response

    @staticmethod
    async def handle_streaming(
        request_id: str,
        params: Dict[str, Any],
        litellm_params: Dict[str, Any],
        api_base: Optional[str] = None,
    ) -> AsyncIterator[Dict[str, Any]]:
        """
        Handle streaming A2A request via litellm.acompletion with stream=True.

        Emits proper A2A streaming events:
        1. Task event (kind: "task") - Initial task with status "submitted"
        2. Status update (kind: "status-update") - Status "working"
        3. Artifact update (kind: "artifact-update") - Content delivery
        4. Status update (kind: "status-update") - Final "completed" status

        Args:
            request_id: A2A JSON-RPC request ID
            params: A2A MessageSendParams containing the message
            litellm_params: Agent's litellm_params (custom_llm_provider, model, etc.)
            api_base: API base URL from agent_card_params

        Yields:
            A2A streaming response events

        Raises:
            ValueError: If the provider is "pydantic_ai_agents" and api_base is None.
        """
        # Check if this is a Pydantic AI agent request
        custom_llm_provider = litellm_params.get("custom_llm_provider")
        if custom_llm_provider == "pydantic_ai_agents":
            if api_base is None:
                raise ValueError("api_base is required for Pydantic AI agents")

            verbose_logger.info(
                f"Pydantic AI: Faking streaming for Pydantic AI agent at {api_base}"
            )

            # Get non-streaming response first
            response_data = await PydanticAITransformation.send_non_streaming_request(
                api_base=api_base,
                request_id=request_id,
                params=params,
            )

            # Convert to fake streaming
            async for chunk in PydanticAITransformation.fake_streaming_from_response(
                response_data=response_data,
                request_id=request_id,
            ):
                yield chunk

            return

        # Extract message from params
        message = params.get("message", {})

        # Create streaming context (generates task_id/context_id shared by
        # every event below)
        ctx = A2AStreamingContext(
            request_id=request_id,
            input_message=message,
        )

        # Transform A2A message to OpenAI format
        openai_messages = (
            A2ACompletionBridgeTransformation.a2a_message_to_openai_messages(message)
        )

        # Get completion params
        # NOTE(review): redundant second lookup of custom_llm_provider.
        custom_llm_provider = litellm_params.get("custom_llm_provider")
        model = litellm_params.get("model", "agent")

        # Build full model string if provider specified
        # Skip prepending if model already starts with the provider prefix
        if custom_llm_provider and not model.startswith(f"{custom_llm_provider}/"):
            full_model = f"{custom_llm_provider}/{model}"
        else:
            full_model = model

        verbose_logger.info(
            f"A2A completion bridge streaming: model={full_model}, api_base={api_base}"
        )

        # Build completion params dict
        completion_params = {
            "model": full_model,
            "messages": openai_messages,
            "api_base": api_base,
            "stream": True,
        }
        # Add litellm_params (contains api_key, client_id, client_secret, tenant_id, etc.)
        # NOTE(review): keys in litellm_params override the defaults above.
        litellm_params_to_add = {
            k: v
            for k, v in litellm_params.items()
            if k not in ("model", "custom_llm_provider")
        }
        completion_params.update(litellm_params_to_add)

        # 1. Emit initial task event (kind: "task", status: "submitted")
        task_event = A2ACompletionBridgeTransformation.create_task_event(ctx)
        yield task_event

        # 2. Emit status update (kind: "status-update", status: "working")
        working_event = A2ACompletionBridgeTransformation.create_status_update_event(
            ctx=ctx,
            state="working",
            final=False,
            message_text="Processing request...",
        )
        yield working_event

        # Call litellm.acompletion with streaming
        response = await litellm.acompletion(**completion_params)

        # 3. Accumulate content and emit artifact update
        # NOTE(review): the whole provider stream is drained first and one
        # artifact-update carrying the full text is emitted afterwards -
        # chunks are NOT forwarded incrementally to the A2A consumer.
        accumulated_text = ""
        chunk_count = 0
        async for chunk in response:  # type: ignore[union-attr]
            chunk_count += 1

            # Extract delta content
            content = ""
            if chunk is not None and hasattr(chunk, "choices") and chunk.choices:
                choice = chunk.choices[0]
                if hasattr(choice, "delta") and choice.delta:
                    content = choice.delta.content or ""

            if content:
                accumulated_text += content

        # Emit artifact update with accumulated content
        # (skipped entirely when the provider produced no text)
        if accumulated_text:
            artifact_event = (
                A2ACompletionBridgeTransformation.create_artifact_update_event(
                    ctx=ctx,
                    text=accumulated_text,
                )
            )
            yield artifact_event

        # 4. Emit final status update (kind: "status-update", status: "completed", final: true)
        completed_event = A2ACompletionBridgeTransformation.create_status_update_event(
            ctx=ctx,
            state="completed",
            final=True,
        )
        yield completed_event

        verbose_logger.info(
            f"A2A completion bridge streaming completed: request_id={request_id}, chunks={chunk_count}"
        )
|
||||
|
||||
|
||||
# Convenience functions that delegate to the class methods
|
||||
async def handle_a2a_completion(
    request_id: str,
    params: Dict[str, Any],
    litellm_params: Dict[str, Any],
    api_base: Optional[str] = None,
) -> Dict[str, Any]:
    """Module-level convenience wrapper for A2ACompletionBridgeHandler.handle_non_streaming."""
    response = await A2ACompletionBridgeHandler.handle_non_streaming(
        request_id=request_id,
        params=params,
        litellm_params=litellm_params,
        api_base=api_base,
    )
    return response
|
||||
|
||||
|
||||
async def handle_a2a_completion_streaming(
    request_id: str,
    params: Dict[str, Any],
    litellm_params: Dict[str, Any],
    api_base: Optional[str] = None,
) -> AsyncIterator[Dict[str, Any]]:
    """Module-level convenience wrapper for A2ACompletionBridgeHandler.handle_streaming."""
    stream = A2ACompletionBridgeHandler.handle_streaming(
        request_id=request_id,
        params=params,
        litellm_params=litellm_params,
        api_base=api_base,
    )
    async for event in stream:
        yield event
|
||||
@@ -0,0 +1,284 @@
|
||||
"""
|
||||
Transformation utilities for A2A <-> OpenAI message format conversion.
|
||||
|
||||
A2A Message Format:
|
||||
{
|
||||
"role": "user",
|
||||
"parts": [{"kind": "text", "text": "Hello!"}],
|
||||
"messageId": "abc123"
|
||||
}
|
||||
|
||||
OpenAI Message Format:
|
||||
{"role": "user", "content": "Hello!"}
|
||||
|
||||
A2A Streaming Events:
|
||||
- Task event (kind: "task") - Initial task creation with status "submitted"
|
||||
- Status update (kind: "status-update") - Status changes (working, completed)
|
||||
- Artifact update (kind: "artifact-update") - Content/artifact delivery
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
|
||||
|
||||
class A2AStreamingContext:
    """
    Mutable per-exchange state for an A2A streaming response.

    Generates fresh task/context identifiers and tracks text accumulation
    plus which one-shot events have already been emitted.
    """

    def __init__(self, request_id: str, input_message: Dict[str, Any]):
        # Fresh identifiers for this task and its conversation context.
        self.task_id, self.context_id = str(uuid4()), str(uuid4())
        self.request_id = request_id
        self.input_message = input_message
        # Running transcript of streamed text.
        self.accumulated_text = ""
        # One-shot event flags (task and "working" status each emitted once).
        self.has_emitted_task = False
        self.has_emitted_working = False
|
||||
|
||||
|
||||
class A2ACompletionBridgeTransformation:
|
||||
"""
|
||||
Static methods for transforming between A2A and OpenAI message formats.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def a2a_message_to_openai_messages(
|
||||
a2a_message: Dict[str, Any],
|
||||
) -> List[Dict[str, str]]:
|
||||
"""
|
||||
Transform an A2A message to OpenAI message format.
|
||||
|
||||
Args:
|
||||
a2a_message: A2A message with role, parts, and messageId
|
||||
|
||||
Returns:
|
||||
List of OpenAI-format messages
|
||||
"""
|
||||
role = a2a_message.get("role", "user")
|
||||
parts = a2a_message.get("parts", [])
|
||||
|
||||
# Map A2A roles to OpenAI roles
|
||||
openai_role = role
|
||||
if role == "user":
|
||||
openai_role = "user"
|
||||
elif role == "assistant":
|
||||
openai_role = "assistant"
|
||||
elif role == "system":
|
||||
openai_role = "system"
|
||||
|
||||
# Extract text content from parts
|
||||
content_parts = []
|
||||
for part in parts:
|
||||
kind = part.get("kind", "")
|
||||
if kind == "text":
|
||||
text = part.get("text", "")
|
||||
content_parts.append(text)
|
||||
|
||||
content = "\n".join(content_parts) if content_parts else ""
|
||||
|
||||
verbose_logger.debug(
|
||||
f"A2A -> OpenAI transform: role={role} -> {openai_role}, content_length={len(content)}"
|
||||
)
|
||||
|
||||
return [{"role": openai_role, "content": content}]
|
||||
|
||||
@staticmethod
|
||||
def openai_response_to_a2a_response(
|
||||
response: Any,
|
||||
request_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Transform a LiteLLM ModelResponse to A2A SendMessageResponse format.
|
||||
|
||||
Args:
|
||||
response: LiteLLM ModelResponse object
|
||||
request_id: Original A2A request ID
|
||||
|
||||
Returns:
|
||||
A2A SendMessageResponse dict
|
||||
"""
|
||||
# Extract content from response
|
||||
content = ""
|
||||
if hasattr(response, "choices") and response.choices:
|
||||
choice = response.choices[0]
|
||||
if hasattr(choice, "message") and choice.message:
|
||||
content = choice.message.content or ""
|
||||
|
||||
# Build A2A message
|
||||
a2a_message = {
|
||||
"role": "agent",
|
||||
"parts": [{"kind": "text", "text": content}],
|
||||
"messageId": uuid4().hex,
|
||||
}
|
||||
|
||||
# Build A2A response
|
||||
a2a_response = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"result": {
|
||||
"message": a2a_message,
|
||||
},
|
||||
}
|
||||
|
||||
verbose_logger.debug(f"OpenAI -> A2A transform: content_length={len(content)}")
|
||||
|
||||
return a2a_response
|
||||
|
||||
@staticmethod
|
||||
def _get_timestamp() -> str:
|
||||
"""Get current timestamp in ISO format with timezone."""
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
@staticmethod
|
||||
def create_task_event(
|
||||
ctx: A2AStreamingContext,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create the initial task event with status 'submitted'.
|
||||
|
||||
This is the first event emitted in an A2A streaming response.
|
||||
"""
|
||||
return {
|
||||
"id": ctx.request_id,
|
||||
"jsonrpc": "2.0",
|
||||
"result": {
|
||||
"contextId": ctx.context_id,
|
||||
"history": [
|
||||
{
|
||||
"contextId": ctx.context_id,
|
||||
"kind": "message",
|
||||
"messageId": ctx.input_message.get("messageId", uuid4().hex),
|
||||
"parts": ctx.input_message.get("parts", []),
|
||||
"role": ctx.input_message.get("role", "user"),
|
||||
"taskId": ctx.task_id,
|
||||
}
|
||||
],
|
||||
"id": ctx.task_id,
|
||||
"kind": "task",
|
||||
"status": {
|
||||
"state": "submitted",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def create_status_update_event(
|
||||
ctx: A2AStreamingContext,
|
||||
state: str,
|
||||
final: bool = False,
|
||||
message_text: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create a status update event.
|
||||
|
||||
Args:
|
||||
ctx: Streaming context
|
||||
state: Status state ('working', 'completed')
|
||||
final: Whether this is the final event
|
||||
message_text: Optional message text for 'working' status
|
||||
"""
|
||||
status: Dict[str, Any] = {
|
||||
"state": state,
|
||||
"timestamp": A2ACompletionBridgeTransformation._get_timestamp(),
|
||||
}
|
||||
|
||||
# Add message for 'working' status
|
||||
if state == "working" and message_text:
|
||||
status["message"] = {
|
||||
"contextId": ctx.context_id,
|
||||
"kind": "message",
|
||||
"messageId": str(uuid4()),
|
||||
"parts": [{"kind": "text", "text": message_text}],
|
||||
"role": "agent",
|
||||
"taskId": ctx.task_id,
|
||||
}
|
||||
|
||||
return {
|
||||
"id": ctx.request_id,
|
||||
"jsonrpc": "2.0",
|
||||
"result": {
|
||||
"contextId": ctx.context_id,
|
||||
"final": final,
|
||||
"kind": "status-update",
|
||||
"status": status,
|
||||
"taskId": ctx.task_id,
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def create_artifact_update_event(
|
||||
ctx: A2AStreamingContext,
|
||||
text: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create an artifact update event with content.
|
||||
|
||||
Args:
|
||||
ctx: Streaming context
|
||||
text: The text content for the artifact
|
||||
"""
|
||||
return {
|
||||
"id": ctx.request_id,
|
||||
"jsonrpc": "2.0",
|
||||
"result": {
|
||||
"artifact": {
|
||||
"artifactId": str(uuid4()),
|
||||
"name": "response",
|
||||
"parts": [{"kind": "text", "text": text}],
|
||||
},
|
||||
"contextId": ctx.context_id,
|
||||
"kind": "artifact-update",
|
||||
"taskId": ctx.task_id,
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def openai_chunk_to_a2a_chunk(
|
||||
chunk: Any,
|
||||
request_id: Optional[str] = None,
|
||||
is_final: bool = False,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Transform a LiteLLM streaming chunk to A2A streaming format.
|
||||
|
||||
NOTE: This method is deprecated for streaming. Use the event-based
|
||||
methods (create_task_event, create_status_update_event,
|
||||
create_artifact_update_event) instead for proper A2A streaming.
|
||||
|
||||
Args:
|
||||
chunk: LiteLLM ModelResponse chunk
|
||||
request_id: Original A2A request ID
|
||||
is_final: Whether this is the final chunk
|
||||
|
||||
Returns:
|
||||
A2A streaming chunk dict or None if no content
|
||||
"""
|
||||
# Extract delta content
|
||||
content = ""
|
||||
if chunk is not None and hasattr(chunk, "choices") and chunk.choices:
|
||||
choice = chunk.choices[0]
|
||||
if hasattr(choice, "delta") and choice.delta:
|
||||
content = choice.delta.content or ""
|
||||
|
||||
if not content and not is_final:
|
||||
return None
|
||||
|
||||
# Build A2A streaming chunk (legacy format)
|
||||
a2a_chunk = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"result": {
|
||||
"message": {
|
||||
"role": "agent",
|
||||
"parts": [{"kind": "text", "text": content}],
|
||||
"messageId": uuid4().hex,
|
||||
},
|
||||
"final": is_final,
|
||||
},
|
||||
}
|
||||
|
||||
return a2a_chunk
|
||||
@@ -0,0 +1,16 @@
|
||||
"""
|
||||
Pydantic AI agent provider for A2A protocol.
|
||||
|
||||
Pydantic AI agents follow A2A protocol but don't support streaming natively.
|
||||
This provider handles fake streaming by converting non-streaming responses into streaming chunks.
|
||||
"""
|
||||
|
||||
from litellm.a2a_protocol.providers.pydantic_ai_agents.config import (
|
||||
PydanticAIProviderConfig,
|
||||
)
|
||||
from litellm.a2a_protocol.providers.pydantic_ai_agents.handler import PydanticAIHandler
|
||||
from litellm.a2a_protocol.providers.pydantic_ai_agents.transformation import (
|
||||
PydanticAITransformation,
|
||||
)
|
||||
|
||||
__all__ = ["PydanticAIHandler", "PydanticAITransformation", "PydanticAIProviderConfig"]
|
||||
@@ -0,0 +1,50 @@
|
||||
"""
|
||||
Pydantic AI provider configuration.
|
||||
"""
|
||||
|
||||
from typing import Any, AsyncIterator, Dict
|
||||
|
||||
from litellm.a2a_protocol.providers.base import BaseA2AProviderConfig
|
||||
from litellm.a2a_protocol.providers.pydantic_ai_agents.handler import PydanticAIHandler
|
||||
|
||||
|
||||
class PydanticAIProviderConfig(BaseA2AProviderConfig):
    """
    Provider configuration for Pydantic AI agents.

    Pydantic AI agents follow A2A protocol but don't support streaming natively.
    This config provides fake streaming by converting non-streaming responses into streaming chunks.
    """

    async def handle_non_streaming(
        self,
        request_id: str,
        params: Dict[str, Any],
        api_base: str,
        **kwargs,
    ) -> Dict[str, Any]:
        """Handle non-streaming request to Pydantic AI agent."""
        request_timeout = kwargs.get("timeout", 60.0)
        return await PydanticAIHandler.handle_non_streaming(
            request_id=request_id,
            params=params,
            api_base=api_base,
            timeout=request_timeout,
        )

    async def handle_streaming(
        self,
        request_id: str,
        params: Dict[str, Any],
        api_base: str,
        **kwargs,
    ) -> AsyncIterator[Dict[str, Any]]:
        """Handle streaming request with fake streaming."""
        # Delegate to the handler, forwarding the tunable chunking knobs.
        stream = PydanticAIHandler.handle_streaming(
            request_id=request_id,
            params=params,
            api_base=api_base,
            timeout=kwargs.get("timeout", 60.0),
            chunk_size=kwargs.get("chunk_size", 50),
            delay_ms=kwargs.get("delay_ms", 10),
        )
        async for event in stream:
            yield event
|
||||
@@ -0,0 +1,102 @@
|
||||
"""
|
||||
Handler for Pydantic AI agents.
|
||||
|
||||
Pydantic AI agents follow A2A protocol but don't support streaming natively.
|
||||
This handler provides fake streaming by converting non-streaming responses into streaming chunks.
|
||||
"""
|
||||
|
||||
from typing import Any, AsyncIterator, Dict
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.a2a_protocol.providers.pydantic_ai_agents.transformation import (
|
||||
PydanticAITransformation,
|
||||
)
|
||||
|
||||
|
||||
class PydanticAIHandler:
    """
    Handler for Pydantic AI agent requests.

    Provides:
    - Direct non-streaming requests to Pydantic AI agents
    - Fake streaming by converting non-streaming responses into streaming chunks
    """

    @staticmethod
    async def handle_non_streaming(
        request_id: str,
        params: Dict[str, Any],
        api_base: str,
        timeout: float = 60.0,
    ) -> Dict[str, Any]:
        """
        Handle non-streaming request to Pydantic AI agent.

        Args:
            request_id: A2A JSON-RPC request ID
            params: A2A MessageSendParams containing the message
            api_base: Base URL of the Pydantic AI agent
            timeout: Request timeout in seconds

        Returns:
            A2A SendMessageResponse dict
        """
        verbose_logger.info(f"Pydantic AI: Routing to Pydantic AI agent at {api_base}")

        # The transformation layer performs the request, polls for
        # completion, and converts the task into A2A response format.
        return await PydanticAITransformation.send_non_streaming_request(
            api_base=api_base,
            request_id=request_id,
            params=params,
            timeout=timeout,
        )

    @staticmethod
    async def handle_streaming(
        request_id: str,
        params: Dict[str, Any],
        api_base: str,
        timeout: float = 60.0,
        chunk_size: int = 50,
        delay_ms: int = 10,
    ) -> AsyncIterator[Dict[str, Any]]:
        """
        Handle streaming request to Pydantic AI agent with fake streaming.

        Since Pydantic AI agents don't support streaming natively, this method:
        1. Makes a non-streaming request
        2. Converts the response into streaming chunks

        Args:
            request_id: A2A JSON-RPC request ID
            params: A2A MessageSendParams containing the message
            api_base: Base URL of the Pydantic AI agent
            timeout: Request timeout in seconds
            chunk_size: Number of characters per chunk
            delay_ms: Delay between chunks in milliseconds

        Yields:
            A2A streaming response events
        """
        verbose_logger.info(
            f"Pydantic AI: Faking streaming for Pydantic AI agent at {api_base}"
        )

        # Fetch the raw (untransformed) task response first; the fake
        # streamer needs the original history/artifacts structure.
        raw_task_response = await PydanticAITransformation.send_and_get_raw_response(
            api_base=api_base,
            request_id=request_id,
            params=params,
            timeout=timeout,
        )

        # Replay the completed response as A2A streaming events.
        fake_stream = PydanticAITransformation.fake_streaming_from_response(
            response_data=raw_task_response,
            request_id=request_id,
            chunk_size=chunk_size,
            delay_ms=delay_ms,
        )
        async for event in fake_stream:
            yield event
|
||||
@@ -0,0 +1,530 @@
|
||||
"""
|
||||
Transformation layer for Pydantic AI agents.
|
||||
|
||||
Pydantic AI agents follow A2A protocol but don't support streaming.
|
||||
This module provides fake streaming by converting non-streaming responses into streaming chunks.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, AsyncIterator, Dict, cast
|
||||
from uuid import uuid4
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
|
||||
|
||||
class PydanticAITransformation:
    """
    Transformation layer for Pydantic AI agents.

    Handles:
    - Direct A2A requests to Pydantic AI endpoints
    - Polling for task completion (since Pydantic AI doesn't support streaming)
    - Fake streaming by chunking non-streaming responses
    """

    @staticmethod
    def _remove_none_values(obj: Any) -> Any:
        """
        Recursively remove None values from a dict/list structure.

        FastA2A/Pydantic AI servers don't accept None values for optional fields -
        they expect those fields to be omitted entirely.

        Args:
            obj: Dict, list, or other value to clean

        Returns:
            Cleaned object with None values removed
        """
        if isinstance(obj, dict):
            # Drop None-valued keys, recursing into remaining values.
            return {
                k: PydanticAITransformation._remove_none_values(v)
                for k, v in obj.items()
                if v is not None
            }
        elif isinstance(obj, list):
            # Drop None list entries, recursing into remaining items.
            return [
                PydanticAITransformation._remove_none_values(item)
                for item in obj
                if item is not None
            ]
        else:
            # Scalars (str/int/bool/...) pass through unchanged.
            return obj

    @staticmethod
    def _params_to_dict(params: Any) -> Dict[str, Any]:
        """
        Convert params to a dict, handling Pydantic models.

        Args:
            params: Dict or Pydantic model

        Returns:
            Dict representation of params
        """
        if hasattr(params, "model_dump"):
            # Pydantic v2 model
            return params.model_dump(mode="python", exclude_none=True)
        elif hasattr(params, "dict"):
            # Pydantic v1 model
            return params.dict(exclude_none=True)
        elif isinstance(params, dict):
            # Already a plain dict; returned as-is (not copied).
            return params
        else:
            # Try to convert to dict (e.g. a mapping or iterable of pairs).
            return dict(params)

    @staticmethod
    async def _poll_for_completion(
        client: AsyncHTTPHandler,
        endpoint: str,
        task_id: str,
        request_id: str,
        max_attempts: int = 30,
        poll_interval: float = 0.5,
    ) -> Dict[str, Any]:
        """
        Poll for task completion using tasks/get method.

        Args:
            client: HTTPX async client
            endpoint: API endpoint URL
            task_id: Task ID to poll for
            request_id: JSON-RPC request ID
            max_attempts: Maximum polling attempts
            poll_interval: Seconds between poll attempts

        Returns:
            Completed task response

        Raises:
            Exception: If the task ends in a 'failed' or 'canceled' state.
            TimeoutError: If the task does not complete within
                max_attempts * poll_interval seconds of polling (request
                latency itself is not counted toward this budget).
        """
        for attempt in range(max_attempts):
            # Each poll gets a distinct JSON-RPC id derived from the original.
            poll_request = {
                "jsonrpc": "2.0",
                "id": f"{request_id}-poll-{attempt}",
                "method": "tasks/get",
                "params": {"id": task_id},
            }

            response = await client.post(
                endpoint,
                json=poll_request,
                headers={"Content-Type": "application/json"},
            )
            response.raise_for_status()
            poll_data = response.json()

            result = poll_data.get("result", {})
            status = result.get("status", {})
            state = status.get("state", "")

            verbose_logger.debug(
                f"Pydantic AI: Poll attempt {attempt + 1}/{max_attempts}, state={state}"
            )

            if state == "completed":
                return poll_data
            elif state in ("failed", "canceled"):
                raise Exception(f"Task {task_id} ended with state: {state}")

            # Task still pending/working: wait before the next poll.
            await asyncio.sleep(poll_interval)

        raise TimeoutError(
            f"Task {task_id} did not complete within {max_attempts * poll_interval} seconds"
        )

    @staticmethod
    async def _send_and_poll_raw(
        api_base: str,
        request_id: str,
        params: Any,
        timeout: float = 60.0,
    ) -> Dict[str, Any]:
        """
        Send a request to Pydantic AI agent and return the raw task response.

        This is an internal method used by both non-streaming and streaming handlers.
        Returns the raw Pydantic AI task format with history/artifacts.

        Args:
            api_base: Base URL of the Pydantic AI agent
            request_id: A2A JSON-RPC request ID
            params: A2A MessageSendParams containing the message
            timeout: Request timeout in seconds

        Returns:
            Raw Pydantic AI task response (with history/artifacts)
        """
        # Convert params to dict if it's a Pydantic model
        params_dict = PydanticAITransformation._params_to_dict(params)

        # Remove None values - FastA2A doesn't accept null for optional fields
        params_dict = PydanticAITransformation._remove_none_values(params_dict)

        # Ensure the message has 'kind': 'message' as required by FastA2A/Pydantic AI
        if "message" in params_dict:
            params_dict["message"]["kind"] = "message"

        # Build A2A JSON-RPC request using message/send method for FastA2A compatibility
        a2a_request = {
            "jsonrpc": "2.0",
            "id": request_id,
            "method": "message/send",
            "params": params_dict,
        }

        # FastA2A uses root endpoint (/) not /messages
        endpoint = api_base.rstrip("/")

        verbose_logger.info(f"Pydantic AI: Sending non-streaming request to {endpoint}")

        # Send request to Pydantic AI agent using shared async HTTP client
        client = get_async_httpx_client(
            llm_provider=cast(Any, "pydantic_ai_agent"),
            params={"timeout": timeout},
        )
        response = await client.post(
            endpoint,
            json=a2a_request,
            headers={"Content-Type": "application/json"},
        )
        response.raise_for_status()
        response_data = response.json()

        # Check if task is already completed
        result = response_data.get("result", {})
        status = result.get("status", {})
        state = status.get("state", "")

        if state != "completed":
            # Need to poll for completion
            # NOTE(review): if the server returned no task id, we fall
            # through and return the non-completed response as-is.
            task_id = result.get("id")
            if task_id:
                verbose_logger.info(
                    f"Pydantic AI: Task {task_id} submitted, polling for completion..."
                )
                response_data = await PydanticAITransformation._poll_for_completion(
                    client=client,
                    endpoint=endpoint,
                    task_id=task_id,
                    request_id=request_id,
                )

        verbose_logger.info(
            f"Pydantic AI: Received completed response for request_id={request_id}"
        )

        return response_data

    @staticmethod
    async def send_non_streaming_request(
        api_base: str,
        request_id: str,
        params: Any,
        timeout: float = 60.0,
    ) -> Dict[str, Any]:
        """
        Send a non-streaming A2A request to Pydantic AI agent and wait for completion.

        Args:
            api_base: Base URL of the Pydantic AI agent (e.g., "http://localhost:9999")
            request_id: A2A JSON-RPC request ID
            params: A2A MessageSendParams containing the message (dict or Pydantic model)
            timeout: Request timeout in seconds

        Returns:
            Standard A2A non-streaming response format with message
        """
        # Get raw task response
        raw_response = await PydanticAITransformation._send_and_poll_raw(
            api_base=api_base,
            request_id=request_id,
            params=params,
            timeout=timeout,
        )

        # Transform to standard A2A non-streaming format
        return PydanticAITransformation._transform_to_a2a_response(
            response_data=raw_response,
            request_id=request_id,
        )

    @staticmethod
    async def send_and_get_raw_response(
        api_base: str,
        request_id: str,
        params: Any,
        timeout: float = 60.0,
    ) -> Dict[str, Any]:
        """
        Send a request to Pydantic AI agent and return the raw task response.

        Used by streaming handler to get raw response for fake streaming.

        Args:
            api_base: Base URL of the Pydantic AI agent
            request_id: A2A JSON-RPC request ID
            params: A2A MessageSendParams containing the message
            timeout: Request timeout in seconds

        Returns:
            Raw Pydantic AI task response (with history/artifacts)
        """
        return await PydanticAITransformation._send_and_poll_raw(
            api_base=api_base,
            request_id=request_id,
            params=params,
            timeout=timeout,
        )

    @staticmethod
    def _transform_to_a2a_response(
        response_data: Dict[str, Any],
        request_id: str,
    ) -> Dict[str, Any]:
        """
        Transform Pydantic AI task response to standard A2A non-streaming format.

        Pydantic AI returns a task with history/artifacts, but the standard A2A
        non-streaming format expects:
        {
            "jsonrpc": "2.0",
            "id": "...",
            "result": {
                "message": {
                    "role": "agent",
                    "parts": [{"kind": "text", "text": "..."}],
                    "messageId": "..."
                }
            }
        }

        Args:
            response_data: Pydantic AI task response
            request_id: Original request ID

        Returns:
            Standard A2A non-streaming response format
        """
        # Extract the agent response text
        full_text, message_id, parts = PydanticAITransformation._extract_response_text(
            response_data
        )

        # Build standard A2A message
        # (falls back to a single synthetic text part when no parts were found)
        a2a_message = {
            "role": "agent",
            "parts": parts if parts else [{"kind": "text", "text": full_text}],
            "messageId": message_id,
        }

        # Return standard A2A non-streaming format
        return {
            "jsonrpc": "2.0",
            "id": request_id,
            "result": {
                "message": a2a_message,
            },
        }

    @staticmethod
    def _extract_response_text(response_data: Dict[str, Any]) -> tuple[str, str, list]:
        """
        Extract response text from completed task response.

        Pydantic AI returns completed tasks with:
        - history: list of messages (user and agent)
        - artifacts: list of result artifacts

        Lookup order: artifacts first, then the most recent agent message
        in history, then a top-level 'message' field; returns empty values
        if none of these yields text.

        Args:
            response_data: Completed task response

        Returns:
            Tuple of (full_text, message_id, parts)
        """
        result = response_data.get("result", {})

        # Try to extract from artifacts first (preferred for results)
        artifacts = result.get("artifacts", [])
        if artifacts:
            for artifact in artifacts:
                parts = artifact.get("parts", [])
                for part in parts:
                    if part.get("kind") == "text":
                        text = part.get("text", "")
                        if text:
                            # Artifacts carry no messageId, so generate one.
                            return text, str(uuid4()), parts

        # Fall back to history - get the last agent message
        history = result.get("history", [])
        for msg in reversed(history):
            if msg.get("role") == "agent":
                parts = msg.get("parts", [])
                message_id = msg.get("messageId", str(uuid4()))
                full_text = ""
                for part in parts:
                    if part.get("kind") == "text":
                        full_text += part.get("text", "")
                if full_text:
                    return full_text, message_id, parts

        # Fall back to message field (original format)
        message = result.get("message", {})
        if message:
            parts = message.get("parts", [])
            message_id = message.get("messageId", str(uuid4()))
            full_text = ""
            for part in parts:
                if part.get("kind") == "text":
                    full_text += part.get("text", "")
            return full_text, message_id, parts

        # Nothing usable found.
        return "", str(uuid4()), []

    @staticmethod
    async def fake_streaming_from_response(
        response_data: Dict[str, Any],
        request_id: str,
        chunk_size: int = 50,
        delay_ms: int = 10,
    ) -> AsyncIterator[Dict[str, Any]]:
        """
        Convert a non-streaming A2A response into fake streaming chunks.

        Emits proper A2A streaming events:
        1. Task event (kind: "task") - Initial task with status "submitted"
        2. Status update (kind: "status-update") - Status "working"
        3. Artifact update chunks (kind: "artifact-update") - Content delivery in chunks
        4. Status update (kind: "status-update") - Final "completed" status

        Args:
            response_data: Non-streaming A2A response dict (completed task)
            request_id: A2A JSON-RPC request ID
            chunk_size: Number of characters per chunk (default: 50)
            delay_ms: Delay between chunks in milliseconds (default: 10)

        Yields:
            A2A streaming response events
        """
        # Extract the response text from completed task
        # (message_id/parts from the completed task are not used here; only
        # full_text is replayed as artifact chunks)
        full_text, message_id, parts = PydanticAITransformation._extract_response_text(
            response_data
        )

        # Extract input message from raw response for history
        result = response_data.get("result", {})
        history = result.get("history", [])
        input_message = {}
        for msg in history:
            if msg.get("role") == "user":
                input_message = msg
                break

        # Generate IDs for streaming events
        # NOTE(review): fresh task/context ids are generated rather than
        # reusing the ids from the completed task response.
        task_id = str(uuid4())
        context_id = str(uuid4())
        artifact_id = str(uuid4())
        input_message_id = input_message.get("messageId", str(uuid4()))

        # 1. Emit initial task event (kind: "task", status: "submitted")
        # Format matches A2ACompletionBridgeTransformation.create_task_event
        task_event = {
            "jsonrpc": "2.0",
            "id": request_id,
            "result": {
                "contextId": context_id,
                "history": [
                    {
                        "contextId": context_id,
                        "kind": "message",
                        "messageId": input_message_id,
                        "parts": input_message.get(
                            "parts", [{"kind": "text", "text": ""}]
                        ),
                        "role": "user",
                        "taskId": task_id,
                    }
                ],
                "id": task_id,
                "kind": "task",
                "status": {
                    "state": "submitted",
                },
            },
        }
        yield task_event

        # 2. Emit status update (kind: "status-update", status: "working")
        # Format matches A2ACompletionBridgeTransformation.create_status_update_event
        working_event = {
            "jsonrpc": "2.0",
            "id": request_id,
            "result": {
                "contextId": context_id,
                "final": False,
                "kind": "status-update",
                "status": {
                    "state": "working",
                },
                "taskId": task_id,
            },
        }
        yield working_event

        # Small delay to simulate processing
        await asyncio.sleep(delay_ms / 1000.0)

        # 3. Emit artifact update chunks (kind: "artifact-update")
        # Format matches A2ACompletionBridgeTransformation.create_artifact_update_event
        if full_text:
            # Split text into chunks
            for i in range(0, len(full_text), chunk_size):
                chunk_text = full_text[i : i + chunk_size]
                is_last_chunk = (i + chunk_size) >= len(full_text)

                artifact_event = {
                    "jsonrpc": "2.0",
                    "id": request_id,
                    "result": {
                        "contextId": context_id,
                        "kind": "artifact-update",
                        "taskId": task_id,
                        "artifact": {
                            # All chunks share a single artifactId.
                            "artifactId": artifact_id,
                            "parts": [
                                {
                                    "kind": "text",
                                    "text": chunk_text,
                                }
                            ],
                        },
                    },
                }
                yield artifact_event

                # Add delay between chunks (except for last chunk)
                if not is_last_chunk:
                    await asyncio.sleep(delay_ms / 1000.0)

        # 4. Emit final status update (kind: "status-update", status: "completed", final: true)
        completed_event = {
            "jsonrpc": "2.0",
            "id": request_id,
            "result": {
                "contextId": context_id,
                "final": True,
                "kind": "status-update",
                "status": {
                    "state": "completed",
                },
                "taskId": task_id,
            },
        }
        yield completed_event

        verbose_logger.info(
            f"Pydantic AI: Fake streaming completed for request_id={request_id}"
        )
|
||||
@@ -0,0 +1,184 @@
|
||||
"""
|
||||
A2A Streaming Iterator with token tracking and logging support.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Optional
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.a2a_protocol.cost_calculator import A2ACostCalculator
|
||||
from litellm.a2a_protocol.utils import A2ARequestUtils
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.litellm_core_utils.thread_pool_executor import executor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.types import SendStreamingMessageRequest, SendStreamingMessageResponse
|
||||
|
||||
|
||||
class A2AStreamingIterator:
    """
    Async iterator for A2A streaming responses with token tracking.

    Wraps an upstream async iterator of streaming responses, passing each
    chunk through unchanged while collecting text for token counting.
    When the stream is exhausted it computes usage/cost and fires the
    LiteLLM success handlers (async and sync).
    """

    def __init__(
        self,
        stream: AsyncIterator["SendStreamingMessageResponse"],
        request: "SendStreamingMessageRequest",
        logging_obj: LiteLLMLoggingObj,
        agent_name: str = "unknown",
    ):
        # Upstream source of streaming chunks; consumed via __anext__.
        self.stream = stream
        # Original request, used later for prompt-token counting.
        self.request = request
        # LiteLLM logging object that receives usage/cost on completion.
        self.logging_obj = logging_obj
        # Agent label; stored for identification (not read elsewhere in this class).
        self.agent_name = agent_name
        # Wall-clock start, paired with end_time in _handle_stream_complete.
        self.start_time = datetime.now()

        # Collect chunks for token counting
        self.chunks: List[Any] = []
        self.collected_text_parts: List[str] = []
        # Last chunk whose status reached "completed" (or last seen chunk).
        self.final_chunk: Optional[Any] = None

    def __aiter__(self):
        """Return self; this object is its own async iterator."""
        return self

    async def __anext__(self) -> "SendStreamingMessageResponse":
        """
        Yield the next upstream chunk, recording it for later accounting.

        On StopAsyncIteration, triggers completion logging before
        re-raising so callers see a normally-terminated stream.
        """
        try:
            chunk = await self.stream.__anext__()

            # Store chunk
            self.chunks.append(chunk)

            # Extract text from chunk for token counting
            self._collect_text_from_chunk(chunk)

            # Check if this is the final chunk (completed status)
            if self._is_completed_chunk(chunk):
                self.final_chunk = chunk

            return chunk

        except StopAsyncIteration:
            # Stream ended - handle logging
            # If no chunk reported "completed", fall back to the last one.
            if self.final_chunk is None and self.chunks:
                self.final_chunk = self.chunks[-1]
            await self._handle_stream_complete()
            raise

    def _collect_text_from_chunk(self, chunk: Any) -> None:
        """Extract text from a streaming chunk and add to collected parts."""
        try:
            # Chunks are expected to be Pydantic models; anything else
            # degrades to an empty dict (and thus no text).
            chunk_dict = (
                chunk.model_dump(mode="json", exclude_none=True)
                if hasattr(chunk, "model_dump")
                else {}
            )
            text = A2ARequestUtils.extract_text_from_response(chunk_dict)
            if text:
                self.collected_text_parts.append(text)
        except Exception:
            # Text extraction is best-effort; never break the stream for it.
            verbose_logger.debug("Failed to extract text from A2A streaming chunk")

    def _is_completed_chunk(self, chunk: Any) -> bool:
        """Check if chunk indicates stream completion (status.state == 'completed')."""
        try:
            chunk_dict = (
                chunk.model_dump(mode="json", exclude_none=True)
                if hasattr(chunk, "model_dump")
                else {}
            )
            result = chunk_dict.get("result", {})
            if isinstance(result, dict):
                status = result.get("status", {})
                if isinstance(status, dict):
                    return status.get("state") == "completed"
        except Exception:
            # Malformed chunks are simply treated as non-final.
            pass
        return False

    async def _handle_stream_complete(self) -> None:
        """
        Handle logging and token counting when stream completes.

        Counts prompt tokens from the request text and completion tokens
        from the last collected chunk text, computes cost, and dispatches
        both the async and sync success handlers. All failures are
        swallowed (logged at debug) so logging can never break the caller.
        """
        try:
            end_time = datetime.now()

            # Calculate tokens from collected text
            input_message = A2ARequestUtils.get_input_message_from_request(self.request)
            input_text = A2ARequestUtils.extract_text_from_message(input_message)
            prompt_tokens = A2ARequestUtils.count_tokens(input_text)

            # Use the last (most complete) text from chunks
            output_text = (
                self.collected_text_parts[-1] if self.collected_text_parts else ""
            )
            completion_tokens = A2ARequestUtils.count_tokens(output_text)

            total_tokens = prompt_tokens + completion_tokens

            # Create usage object
            usage = litellm.Usage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens,
            )

            # Set usage on logging obj
            self.logging_obj.model_call_details["usage"] = usage
            # Mark stream flag for downstream callbacks
            # NOTE(review): deliberately set to False even though this was a
            # streaming call - presumably so callbacks treat the aggregated
            # result as a complete (non-chunked) response; confirm intent.
            self.logging_obj.model_call_details["stream"] = False

            # Calculate cost using A2ACostCalculator
            response_cost = A2ACostCalculator.calculate_a2a_cost(self.logging_obj)
            self.logging_obj.model_call_details["response_cost"] = response_cost

            # Build result for logging
            result = self._build_logging_result(usage)

            # Call success handlers - they will build standard_logging_object
            # Async handler is fire-and-forget on the running event loop.
            asyncio.create_task(
                self.logging_obj.async_success_handler(
                    result=result,
                    start_time=self.start_time,
                    end_time=end_time,
                    cache_hit=None,
                )
            )

            # Sync handler runs on the shared thread pool to avoid blocking.
            executor.submit(
                self.logging_obj.success_handler,
                result=result,
                cache_hit=None,
                start_time=self.start_time,
                end_time=end_time,
            )

            verbose_logger.info(
                f"A2A streaming completed: prompt_tokens={prompt_tokens}, "
                f"completion_tokens={completion_tokens}, total_tokens={total_tokens}, "
                f"response_cost={response_cost}"
            )

        except Exception as e:
            verbose_logger.debug(f"Error in A2A streaming completion handler: {e}")

    def _build_logging_result(self, usage: litellm.Usage) -> Dict[str, Any]:
        """
        Build a result dict for logging.

        Returns a JSON-RPC-shaped dict carrying the usage payload, plus the
        final chunk's 'result' when one is available and serializable.
        """
        result: Dict[str, Any] = {
            "id": getattr(self.request, "id", "unknown"),
            "jsonrpc": "2.0",
            "usage": usage.model_dump()
            if hasattr(usage, "model_dump")
            else dict(usage),
        }

        # Add final chunk result if available
        if self.final_chunk:
            try:
                chunk_dict = self.final_chunk.model_dump(mode="json", exclude_none=True)
                result["result"] = chunk_dict.get("result", {})
            except Exception:
                # Best-effort: an unserializable final chunk is simply omitted.
                pass

        return result
|
||||
@@ -0,0 +1,138 @@
|
||||
"""
|
||||
Utility functions for A2A protocol.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.types import SendMessageRequest, SendStreamingMessageRequest
|
||||
|
||||
|
||||
class A2ARequestUtils:
|
||||
"""Utility class for A2A request/response processing."""
|
||||
|
||||
@staticmethod
|
||||
def extract_text_from_message(message: Any) -> str:
|
||||
"""
|
||||
Extract text content from A2A message parts.
|
||||
|
||||
Args:
|
||||
message: A2A message dict or object with 'parts' containing text parts
|
||||
|
||||
Returns:
|
||||
Concatenated text from all text parts
|
||||
"""
|
||||
if message is None:
|
||||
return ""
|
||||
|
||||
# Handle both dict and object access
|
||||
if isinstance(message, dict):
|
||||
parts = message.get("parts", [])
|
||||
else:
|
||||
parts = getattr(message, "parts", []) or []
|
||||
|
||||
text_parts: List[str] = []
|
||||
for part in parts:
|
||||
if isinstance(part, dict):
|
||||
if part.get("kind") == "text":
|
||||
text_parts.append(part.get("text", ""))
|
||||
else:
|
||||
if getattr(part, "kind", None) == "text":
|
||||
text_parts.append(getattr(part, "text", ""))
|
||||
|
||||
return " ".join(text_parts)
|
||||
|
||||
@staticmethod
|
||||
def extract_text_from_response(response_dict: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Extract text content from A2A response result.
|
||||
|
||||
Args:
|
||||
response_dict: A2A response dict with 'result' containing message
|
||||
|
||||
Returns:
|
||||
Text from response message parts
|
||||
"""
|
||||
result = response_dict.get("result", {})
|
||||
if not isinstance(result, dict):
|
||||
return ""
|
||||
|
||||
message = result.get("message", {})
|
||||
return A2ARequestUtils.extract_text_from_message(message)
|
||||
|
||||
@staticmethod
|
||||
def get_input_message_from_request(
|
||||
request: "Union[SendMessageRequest, SendStreamingMessageRequest]",
|
||||
) -> Any:
|
||||
"""
|
||||
Extract the input message from an A2A request.
|
||||
|
||||
Args:
|
||||
request: The A2A SendMessageRequest or SendStreamingMessageRequest
|
||||
|
||||
Returns:
|
||||
The message object/dict or None
|
||||
"""
|
||||
params = getattr(request, "params", None)
|
||||
if params is None:
|
||||
return None
|
||||
return getattr(params, "message", None)
|
||||
|
||||
@staticmethod
|
||||
def count_tokens(text: str) -> int:
|
||||
"""
|
||||
Count tokens in text using litellm.token_counter.
|
||||
|
||||
Args:
|
||||
text: Text to count tokens for
|
||||
|
||||
Returns:
|
||||
Token count, or 0 if counting fails
|
||||
"""
|
||||
if not text:
|
||||
return 0
|
||||
try:
|
||||
return litellm.token_counter(text=text)
|
||||
except Exception:
|
||||
verbose_logger.debug("Failed to count tokens")
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
def calculate_usage_from_request_response(
|
||||
request: "Union[SendMessageRequest, SendStreamingMessageRequest]",
|
||||
response_dict: Dict[str, Any],
|
||||
) -> Tuple[int, int, int]:
|
||||
"""
|
||||
Calculate token usage from A2A request and response.
|
||||
|
||||
Args:
|
||||
request: The A2A SendMessageRequest or SendStreamingMessageRequest
|
||||
response_dict: The A2A response as a dict
|
||||
|
||||
Returns:
|
||||
Tuple of (prompt_tokens, completion_tokens, total_tokens)
|
||||
"""
|
||||
# Count input tokens
|
||||
input_message = A2ARequestUtils.get_input_message_from_request(request)
|
||||
input_text = A2ARequestUtils.extract_text_from_message(input_message)
|
||||
prompt_tokens = A2ARequestUtils.count_tokens(input_text)
|
||||
|
||||
# Count output tokens
|
||||
output_text = A2ARequestUtils.extract_text_from_response(response_dict)
|
||||
completion_tokens = A2ARequestUtils.count_tokens(output_text)
|
||||
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
|
||||
return prompt_tokens, completion_tokens, total_tokens
|
||||
|
||||
|
||||
# Backwards compatibility aliases
|
||||
def extract_text_from_a2a_message(message: Any) -> str:
|
||||
return A2ARequestUtils.extract_text_from_message(message)
|
||||
|
||||
|
||||
def extract_text_from_a2a_response(response_dict: Dict[str, Any]) -> str:
|
||||
return A2ARequestUtils.extract_text_from_response(response_dict)
|
||||
@@ -0,0 +1,182 @@
|
||||
{
|
||||
"description": "Mapping of Anthropic beta headers for each provider. Keys are input header names, values are provider-specific header names (or null if unsupported). Only headers present in mapping keys with non-null values can be forwarded.",
|
||||
"anthropic": {
|
||||
"advanced-tool-use-2025-11-20": "advanced-tool-use-2025-11-20",
|
||||
"bash_20241022": null,
|
||||
"bash_20250124": null,
|
||||
"code-execution-2025-08-25": "code-execution-2025-08-25",
|
||||
"compact-2026-01-12": "compact-2026-01-12",
|
||||
"computer-use-2025-01-24": "computer-use-2025-01-24",
|
||||
"computer-use-2025-11-24": "computer-use-2025-11-24",
|
||||
"context-1m-2025-08-07": "context-1m-2025-08-07",
|
||||
"context-management-2025-06-27": "context-management-2025-06-27",
|
||||
"effort-2025-11-24": "effort-2025-11-24",
|
||||
"fast-mode-2026-02-01": "fast-mode-2026-02-01",
|
||||
"files-api-2025-04-14": "files-api-2025-04-14",
|
||||
"structured-output-2024-03-01": null,
|
||||
"fine-grained-tool-streaming-2025-05-14": "fine-grained-tool-streaming-2025-05-14",
|
||||
"interleaved-thinking-2025-05-14": "interleaved-thinking-2025-05-14",
|
||||
"mcp-client-2025-11-20": "mcp-client-2025-11-20",
|
||||
"mcp-client-2025-04-04": "mcp-client-2025-04-04",
|
||||
"mcp-servers-2025-12-04": null,
|
||||
"oauth-2025-04-20": "oauth-2025-04-20",
|
||||
"output-128k-2025-02-19": "output-128k-2025-02-19",
|
||||
"prompt-caching-scope-2026-01-05": "prompt-caching-scope-2026-01-05",
|
||||
"skills-2025-10-02": "skills-2025-10-02",
|
||||
"structured-outputs-2025-11-13": "structured-outputs-2025-11-13",
|
||||
"text_editor_20241022": null,
|
||||
"text_editor_20250124": null,
|
||||
"token-efficient-tools-2025-02-19": "token-efficient-tools-2025-02-19",
|
||||
"web-fetch-2025-09-10": "web-fetch-2025-09-10",
|
||||
"web-search-2025-03-05": "web-search-2025-03-05"
|
||||
},
|
||||
"azure_ai": {
|
||||
"advanced-tool-use-2025-11-20": "advanced-tool-use-2025-11-20",
|
||||
"bash_20241022": null,
|
||||
"bash_20250124": null,
|
||||
"code-execution-2025-08-25": "code-execution-2025-08-25",
|
||||
"compact-2026-01-12": null,
|
||||
"computer-use-2025-01-24": "computer-use-2025-01-24",
|
||||
"computer-use-2025-11-24": "computer-use-2025-11-24",
|
||||
"context-1m-2025-08-07": "context-1m-2025-08-07",
|
||||
"context-management-2025-06-27": "context-management-2025-06-27",
|
||||
"effort-2025-11-24": "effort-2025-11-24",
|
||||
"fast-mode-2026-02-01": null,
|
||||
"files-api-2025-04-14": "files-api-2025-04-14",
|
||||
"fine-grained-tool-streaming-2025-05-14": null,
|
||||
"interleaved-thinking-2025-05-14": "interleaved-thinking-2025-05-14",
|
||||
"mcp-client-2025-11-20": "mcp-client-2025-11-20",
|
||||
"mcp-client-2025-04-04": "mcp-client-2025-04-04",
|
||||
"mcp-servers-2025-12-04": null,
|
||||
"output-128k-2025-02-19": null,
|
||||
"structured-output-2024-03-01": null,
|
||||
"prompt-caching-scope-2026-01-05": "prompt-caching-scope-2026-01-05",
|
||||
"skills-2025-10-02": "skills-2025-10-02",
|
||||
"structured-outputs-2025-11-13": "structured-outputs-2025-11-13",
|
||||
"text_editor_20241022": null,
|
||||
"text_editor_20250124": null,
|
||||
"token-efficient-tools-2025-02-19": null,
|
||||
"web-fetch-2025-09-10": "web-fetch-2025-09-10",
|
||||
"web-search-2025-03-05": "web-search-2025-03-05"
|
||||
},
|
||||
"bedrock_converse": {
|
||||
"advanced-tool-use-2025-11-20": null,
|
||||
"bash_20241022": null,
|
||||
"bash_20250124": null,
|
||||
"code-execution-2025-08-25": null,
|
||||
"compact-2026-01-12": null,
|
||||
"computer-use-2025-01-24": "computer-use-2025-01-24",
|
||||
"computer-use-2025-11-24": "computer-use-2025-11-24",
|
||||
"context-1m-2025-08-07": "context-1m-2025-08-07",
|
||||
"context-management-2025-06-27": "context-management-2025-06-27",
|
||||
"effort-2025-11-24": null,
|
||||
"fast-mode-2026-02-01": null,
|
||||
"files-api-2025-04-14": null,
|
||||
"fine-grained-tool-streaming-2025-05-14": null,
|
||||
"interleaved-thinking-2025-05-14": "interleaved-thinking-2025-05-14",
|
||||
"mcp-client-2025-11-20": null,
|
||||
"mcp-client-2025-04-04": null,
|
||||
"mcp-servers-2025-12-04": null,
|
||||
"output-128k-2025-02-19": null,
|
||||
"structured-output-2024-03-01": null,
|
||||
"prompt-caching-scope-2026-01-05": null,
|
||||
"skills-2025-10-02": null,
|
||||
"structured-outputs-2025-11-13": "structured-outputs-2025-11-13",
|
||||
"text_editor_20241022": null,
|
||||
"text_editor_20250124": null,
|
||||
"token-efficient-tools-2025-02-19": null,
|
||||
"tool-search-tool-2025-10-19": null,
|
||||
"web-fetch-2025-09-10": null,
|
||||
"web-search-2025-03-05": null
|
||||
},
|
||||
"bedrock": {
|
||||
"advanced-tool-use-2025-11-20": "tool-search-tool-2025-10-19",
|
||||
"bash_20241022": null,
|
||||
"bash_20250124": null,
|
||||
"code-execution-2025-08-25": null,
|
||||
"compact-2026-01-12": "compact-2026-01-12",
|
||||
"computer-use-2025-01-24": "computer-use-2025-01-24",
|
||||
"computer-use-2025-11-24": "computer-use-2025-11-24",
|
||||
"context-1m-2025-08-07": "context-1m-2025-08-07",
|
||||
"context-management-2025-06-27": "context-management-2025-06-27",
|
||||
"effort-2025-11-24": null,
|
||||
"fast-mode-2026-02-01": null,
|
||||
"files-api-2025-04-14": null,
|
||||
"fine-grained-tool-streaming-2025-05-14": null,
|
||||
"interleaved-thinking-2025-05-14": "interleaved-thinking-2025-05-14",
|
||||
"mcp-client-2025-11-20": null,
|
||||
"mcp-client-2025-04-04": null,
|
||||
"mcp-servers-2025-12-04": null,
|
||||
"output-128k-2025-02-19": null,
|
||||
"structured-output-2024-03-01": null,
|
||||
"prompt-caching-scope-2026-01-05": null,
|
||||
"skills-2025-10-02": null,
|
||||
"structured-outputs-2025-11-13": null,
|
||||
"text_editor_20241022": null,
|
||||
"text_editor_20250124": null,
|
||||
"token-efficient-tools-2025-02-19": null,
|
||||
"tool-search-tool-2025-10-19": "tool-search-tool-2025-10-19",
|
||||
"web-fetch-2025-09-10": null,
|
||||
"web-search-2025-03-05": null
|
||||
},
|
||||
"vertex_ai": {
|
||||
"advanced-tool-use-2025-11-20": "tool-search-tool-2025-10-19",
|
||||
"bash_20241022": null,
|
||||
"bash_20250124": null,
|
||||
"code-execution-2025-08-25": null,
|
||||
"compact-2026-01-12": null,
|
||||
"computer-use-2025-01-24": "computer-use-2025-01-24",
|
||||
"computer-use-2025-11-24": "computer-use-2025-11-24",
|
||||
"context-1m-2025-08-07": "context-1m-2025-08-07",
|
||||
"context-management-2025-06-27": "context-management-2025-06-27",
|
||||
"effort-2025-11-24": null,
|
||||
"fast-mode-2026-02-01": null,
|
||||
"files-api-2025-04-14": null,
|
||||
"fine-grained-tool-streaming-2025-05-14": null,
|
||||
"interleaved-thinking-2025-05-14": "interleaved-thinking-2025-05-14",
|
||||
"mcp-client-2025-11-20": null,
|
||||
"mcp-client-2025-04-04": null,
|
||||
"mcp-servers-2025-12-04": null,
|
||||
"output-128k-2025-02-19": null,
|
||||
"structured-output-2024-03-01": null,
|
||||
"prompt-caching-scope-2026-01-05": null,
|
||||
"skills-2025-10-02": null,
|
||||
"structured-outputs-2025-11-13": null,
|
||||
"text_editor_20241022": null,
|
||||
"text_editor_20250124": null,
|
||||
"token-efficient-tools-2025-02-19": null,
|
||||
"tool-search-tool-2025-10-19": "tool-search-tool-2025-10-19",
|
||||
"web-fetch-2025-09-10": null,
|
||||
"web-search-2025-03-05": "web-search-2025-03-05"
|
||||
},
|
||||
"databricks": {
|
||||
"advanced-tool-use-2025-11-20": "advanced-tool-use-2025-11-20",
|
||||
"bash_20241022": null,
|
||||
"bash_20250124": null,
|
||||
"code-execution-2025-08-25": "code-execution-2025-08-25",
|
||||
"compact-2026-01-12": "compact-2026-01-12",
|
||||
"computer-use-2025-01-24": "computer-use-2025-01-24",
|
||||
"computer-use-2025-11-24": "computer-use-2025-11-24",
|
||||
"context-1m-2025-08-07": "context-1m-2025-08-07",
|
||||
"context-management-2025-06-27": "context-management-2025-06-27",
|
||||
"effort-2025-11-24": "effort-2025-11-24",
|
||||
"fast-mode-2026-02-01": "fast-mode-2026-02-01",
|
||||
"files-api-2025-04-14": "files-api-2025-04-14",
|
||||
"structured-output-2024-03-01": null,
|
||||
"fine-grained-tool-streaming-2025-05-14": "fine-grained-tool-streaming-2025-05-14",
|
||||
"interleaved-thinking-2025-05-14": "interleaved-thinking-2025-05-14",
|
||||
"mcp-client-2025-11-20": "mcp-client-2025-11-20",
|
||||
"mcp-client-2025-04-04": "mcp-client-2025-04-04",
|
||||
"mcp-servers-2025-12-04": null,
|
||||
"oauth-2025-04-20": "oauth-2025-04-20",
|
||||
"output-128k-2025-02-19": "output-128k-2025-02-19",
|
||||
"prompt-caching-scope-2026-01-05": "prompt-caching-scope-2026-01-05",
|
||||
"skills-2025-10-02": "skills-2025-10-02",
|
||||
"structured-outputs-2025-11-13": "structured-outputs-2025-11-13",
|
||||
"text_editor_20241022": null,
|
||||
"text_editor_20250124": null,
|
||||
"token-efficient-tools-2025-02-19": "token-efficient-tools-2025-02-19",
|
||||
"web-fetch-2025-09-10": "web-fetch-2025-09-10",
|
||||
"web-search-2025-03-05": "web-search-2025-03-05"
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,385 @@
|
||||
"""
|
||||
Centralized manager for Anthropic beta headers across different providers.
|
||||
|
||||
This module provides utilities to:
|
||||
1. Load beta header configuration from JSON (mapping of supported headers per provider)
|
||||
2. Filter and map beta headers based on provider support
|
||||
3. Handle provider-specific header name mappings (e.g., advanced-tool-use -> tool-search-tool)
|
||||
4. Support remote fetching and caching similar to model cost map
|
||||
|
||||
Design:
|
||||
- JSON config contains mapping of beta headers for each provider
|
||||
- Keys are input header names, values are provider-specific header names (or null if unsupported)
|
||||
- Only headers present in mapping keys with non-null values can be forwarded
|
||||
- This enforces stricter validation than the previous unsupported list approach
|
||||
|
||||
Configuration can be loaded from:
|
||||
- Remote URL (default): Fetches from GitHub repository
|
||||
- Local file: Set LITELLM_LOCAL_ANTHROPIC_BETA_HEADERS=True to use bundled config only
|
||||
|
||||
Environment Variables:
|
||||
- LITELLM_LOCAL_ANTHROPIC_BETA_HEADERS: Set to "True" to disable remote fetching
|
||||
- LITELLM_ANTHROPIC_BETA_HEADERS_URL: Custom URL for remote config (optional)
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from importlib.resources import files
|
||||
from typing import Dict, List, Optional, Set
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.litellm_core_utils.litellm_logging import verbose_logger
|
||||
|
||||
# Cache for the loaded configuration
|
||||
_BETA_HEADERS_CONFIG: Optional[Dict] = None
|
||||
|
||||
|
||||
class GetAnthropicBetaHeadersConfig:
|
||||
"""
|
||||
Handles fetching, validating, and loading the Anthropic beta headers configuration.
|
||||
|
||||
Similar to GetModelCostMap, this class manages the lifecycle of the beta headers
|
||||
configuration with support for remote fetching and local fallback.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def load_local_beta_headers_config() -> Dict:
|
||||
"""Load the local backup beta headers config bundled with the package."""
|
||||
try:
|
||||
content = json.loads(
|
||||
files("litellm")
|
||||
.joinpath("anthropic_beta_headers_config.json")
|
||||
.read_text(encoding="utf-8")
|
||||
)
|
||||
return content
|
||||
except Exception as e:
|
||||
verbose_logger.error(f"Failed to load local beta headers config: {e}")
|
||||
# Return empty config as fallback
|
||||
return {
|
||||
"anthropic": {},
|
||||
"azure_ai": {},
|
||||
"bedrock": {},
|
||||
"bedrock_converse": {},
|
||||
"vertex_ai": {},
|
||||
"provider_aliases": {},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _check_is_valid_dict(fetched_config: dict) -> bool:
|
||||
"""Check if fetched config is a non-empty dict with expected structure."""
|
||||
if not isinstance(fetched_config, dict):
|
||||
verbose_logger.warning(
|
||||
"LiteLLM: Fetched beta headers config is not a dict (type=%s). "
|
||||
"Falling back to local backup.",
|
||||
type(fetched_config).__name__,
|
||||
)
|
||||
return False
|
||||
|
||||
if len(fetched_config) == 0:
|
||||
verbose_logger.warning(
|
||||
"LiteLLM: Fetched beta headers config is empty. "
|
||||
"Falling back to local backup.",
|
||||
)
|
||||
return False
|
||||
|
||||
# Check for at least one provider key
|
||||
provider_keys = [
|
||||
"anthropic",
|
||||
"azure_ai",
|
||||
"bedrock",
|
||||
"bedrock_converse",
|
||||
"vertex_ai",
|
||||
]
|
||||
has_provider = any(key in fetched_config for key in provider_keys)
|
||||
|
||||
if not has_provider:
|
||||
verbose_logger.warning(
|
||||
"LiteLLM: Fetched beta headers config missing provider keys. "
|
||||
"Falling back to local backup.",
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def validate_beta_headers_config(cls, fetched_config: dict) -> bool:
|
||||
"""
|
||||
Validate the integrity of a fetched beta headers config.
|
||||
|
||||
Returns True if all checks pass, False otherwise.
|
||||
"""
|
||||
return cls._check_is_valid_dict(fetched_config)
|
||||
|
||||
@staticmethod
|
||||
def fetch_remote_beta_headers_config(url: str, timeout: int = 5) -> dict:
|
||||
"""
|
||||
Fetch the beta headers config from a remote URL.
|
||||
|
||||
Returns the parsed JSON dict. Raises on network/parse errors
|
||||
(caller is expected to handle).
|
||||
"""
|
||||
response = httpx.get(url, timeout=timeout)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
def get_beta_headers_config(url: str) -> dict:
|
||||
"""
|
||||
Public entry point — returns the beta headers config dict.
|
||||
|
||||
1. If ``LITELLM_LOCAL_ANTHROPIC_BETA_HEADERS`` is set, uses the local backup only.
|
||||
2. Otherwise fetches from ``url``, validates integrity, and falls back
|
||||
to the local backup on any failure.
|
||||
|
||||
Args:
|
||||
url: URL to fetch the remote beta headers configuration from
|
||||
|
||||
Returns:
|
||||
Dict containing the beta headers configuration
|
||||
"""
|
||||
# Check if local-only mode is enabled
|
||||
if os.getenv("LITELLM_LOCAL_ANTHROPIC_BETA_HEADERS", "").lower() == "true":
|
||||
# verbose_logger.debug("Using local Anthropic beta headers config (LITELLM_LOCAL_ANTHROPIC_BETA_HEADERS=True)")
|
||||
return GetAnthropicBetaHeadersConfig.load_local_beta_headers_config()
|
||||
|
||||
try:
|
||||
content = GetAnthropicBetaHeadersConfig.fetch_remote_beta_headers_config(url)
|
||||
except Exception as e:
|
||||
verbose_logger.warning(
|
||||
"LiteLLM: Failed to fetch remote beta headers config from %s: %s. "
|
||||
"Falling back to local backup.",
|
||||
url,
|
||||
str(e),
|
||||
)
|
||||
return GetAnthropicBetaHeadersConfig.load_local_beta_headers_config()
|
||||
|
||||
# Validate the fetched config
|
||||
if not GetAnthropicBetaHeadersConfig.validate_beta_headers_config(
|
||||
fetched_config=content
|
||||
):
|
||||
verbose_logger.warning(
|
||||
"LiteLLM: Fetched beta headers config failed integrity check. "
|
||||
"Using local backup instead. url=%s",
|
||||
url,
|
||||
)
|
||||
return GetAnthropicBetaHeadersConfig.load_local_beta_headers_config()
|
||||
|
||||
return content
|
||||
|
||||
|
||||
def _load_beta_headers_config() -> Dict:
|
||||
"""
|
||||
Load the beta headers configuration.
|
||||
Uses caching to avoid repeated fetches/file reads.
|
||||
|
||||
This function is called by all public API functions and manages the global cache.
|
||||
|
||||
Returns:
|
||||
Dict containing the beta headers configuration
|
||||
"""
|
||||
global _BETA_HEADERS_CONFIG
|
||||
|
||||
if _BETA_HEADERS_CONFIG is not None:
|
||||
return _BETA_HEADERS_CONFIG
|
||||
|
||||
# Get the URL from environment or use default
|
||||
from litellm import anthropic_beta_headers_url
|
||||
|
||||
_BETA_HEADERS_CONFIG = get_beta_headers_config(url=anthropic_beta_headers_url)
|
||||
verbose_logger.debug("Loaded and cached beta headers config")
|
||||
|
||||
return _BETA_HEADERS_CONFIG
|
||||
|
||||
|
||||
def reload_beta_headers_config() -> Dict:
|
||||
"""
|
||||
Force reload the beta headers configuration from source (remote or local).
|
||||
Clears the cache and fetches fresh configuration.
|
||||
|
||||
Returns:
|
||||
Dict containing the newly loaded beta headers configuration
|
||||
"""
|
||||
global _BETA_HEADERS_CONFIG
|
||||
_BETA_HEADERS_CONFIG = None
|
||||
verbose_logger.info("Reloading beta headers config (cache cleared)")
|
||||
return _load_beta_headers_config()
|
||||
|
||||
|
||||
def get_provider_name(provider: str) -> str:
|
||||
"""
|
||||
Resolve provider aliases to canonical provider names.
|
||||
|
||||
Args:
|
||||
provider: Provider name (may be an alias)
|
||||
|
||||
Returns:
|
||||
Canonical provider name
|
||||
"""
|
||||
config = _load_beta_headers_config()
|
||||
aliases = config.get("provider_aliases", {})
|
||||
return aliases.get(provider, provider)
|
||||
|
||||
|
||||
def filter_and_transform_beta_headers(
|
||||
beta_headers: List[str],
|
||||
provider: str,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Filter and transform beta headers based on provider's mapping configuration.
|
||||
|
||||
This function:
|
||||
1. Only allows headers that are present in the provider's mapping keys
|
||||
2. Filters out headers with null values (unsupported)
|
||||
3. Maps headers to provider-specific names (e.g., advanced-tool-use -> tool-search-tool)
|
||||
|
||||
Args:
|
||||
beta_headers: List of Anthropic beta header values
|
||||
provider: Provider name (e.g., "anthropic", "bedrock", "vertex_ai")
|
||||
|
||||
Returns:
|
||||
List of filtered and transformed beta headers for the provider
|
||||
"""
|
||||
if not beta_headers:
|
||||
return []
|
||||
|
||||
config = _load_beta_headers_config()
|
||||
provider = get_provider_name(provider)
|
||||
|
||||
# Get the header mapping for this provider
|
||||
provider_mapping = config.get(provider, {})
|
||||
|
||||
filtered_headers: Set[str] = set()
|
||||
|
||||
for header in beta_headers:
|
||||
header = header.strip()
|
||||
|
||||
# Check if header is in the mapping
|
||||
if header not in provider_mapping:
|
||||
verbose_logger.debug(
|
||||
f"Dropping unknown beta header '{header}' for provider '{provider}' (not in mapping)"
|
||||
)
|
||||
continue
|
||||
|
||||
# Get the mapped header value
|
||||
mapped_header = provider_mapping[header]
|
||||
|
||||
# Skip if header is unsupported (null value)
|
||||
if mapped_header is None:
|
||||
verbose_logger.debug(
|
||||
f"Dropping unsupported beta header '{header}' for provider '{provider}'"
|
||||
)
|
||||
continue
|
||||
|
||||
# Add the mapped header
|
||||
filtered_headers.add(mapped_header)
|
||||
|
||||
return sorted(list(filtered_headers))
|
||||
|
||||
|
||||
def is_beta_header_supported(
|
||||
beta_header: str,
|
||||
provider: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a specific beta header is supported by a provider.
|
||||
|
||||
Args:
|
||||
beta_header: The Anthropic beta header value
|
||||
provider: Provider name
|
||||
|
||||
Returns:
|
||||
True if the header is in the mapping with a non-null value, False otherwise
|
||||
"""
|
||||
config = _load_beta_headers_config()
|
||||
provider = get_provider_name(provider)
|
||||
provider_mapping = config.get(provider, {})
|
||||
|
||||
# Header is supported if it's in the mapping and has a non-null value
|
||||
return beta_header in provider_mapping and provider_mapping[beta_header] is not None
|
||||
|
||||
|
||||
def get_provider_beta_header(
|
||||
anthropic_beta_header: str,
|
||||
provider: str,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Get the provider-specific beta header name for a given Anthropic beta header.
|
||||
|
||||
This function handles header transformations/mappings (e.g., advanced-tool-use -> tool-search-tool).
|
||||
|
||||
Args:
|
||||
anthropic_beta_header: The Anthropic beta header value
|
||||
provider: Provider name
|
||||
|
||||
Returns:
|
||||
The provider-specific header name if supported, or None if unsupported/unknown
|
||||
"""
|
||||
config = _load_beta_headers_config()
|
||||
provider = get_provider_name(provider)
|
||||
|
||||
# Get the header mapping for this provider
|
||||
provider_mapping = config.get(provider, {})
|
||||
|
||||
# Check if header is in the mapping
|
||||
if anthropic_beta_header not in provider_mapping:
|
||||
return None
|
||||
|
||||
# Return the mapped value (could be None if unsupported)
|
||||
return provider_mapping[anthropic_beta_header]
|
||||
|
||||
|
||||
def update_headers_with_filtered_beta(
|
||||
headers: dict,
|
||||
provider: str,
|
||||
) -> dict:
|
||||
"""
|
||||
Update headers dict by filtering and transforming anthropic-beta header values.
|
||||
Modifies the headers dict in place and returns it.
|
||||
|
||||
Args:
|
||||
headers: Request headers dict (will be modified in place)
|
||||
provider: Provider name
|
||||
|
||||
Returns:
|
||||
Updated headers dict
|
||||
"""
|
||||
existing_beta = headers.get("anthropic-beta")
|
||||
if not existing_beta:
|
||||
return headers
|
||||
|
||||
# Parse existing beta headers
|
||||
beta_values = [b.strip() for b in existing_beta.split(",") if b.strip()]
|
||||
|
||||
# Filter and transform based on provider
|
||||
filtered_beta_values = filter_and_transform_beta_headers(
|
||||
beta_headers=beta_values,
|
||||
provider=provider,
|
||||
)
|
||||
|
||||
# Update or remove the header
|
||||
if filtered_beta_values:
|
||||
headers["anthropic-beta"] = ",".join(filtered_beta_values)
|
||||
else:
|
||||
# Remove the header if no values remain
|
||||
headers.pop("anthropic-beta", None)
|
||||
|
||||
return headers
|
||||
|
||||
|
||||
def get_unsupported_headers(provider: str) -> List[str]:
|
||||
"""
|
||||
Get all beta headers that are unsupported by a provider (have null values in mapping).
|
||||
|
||||
Args:
|
||||
provider: Provider name
|
||||
|
||||
Returns:
|
||||
List of unsupported Anthropic beta header names
|
||||
"""
|
||||
config = _load_beta_headers_config()
|
||||
provider = get_provider_name(provider)
|
||||
provider_mapping = config.get(provider, {})
|
||||
|
||||
# Return headers with null values
|
||||
return [header for header, value in provider_mapping.items() if value is None]
|
||||
@@ -0,0 +1,6 @@
|
||||
"""
|
||||
Anthropic module for LiteLLM
|
||||
"""
|
||||
from .messages import acreate, create
|
||||
|
||||
__all__ = ["acreate", "create"]
|
||||
@@ -0,0 +1,19 @@
|
||||
"""Anthropic error format utilities."""
|
||||
|
||||
from .exception_mapping_utils import (
|
||||
ANTHROPIC_ERROR_TYPE_MAP,
|
||||
AnthropicExceptionMapping,
|
||||
)
|
||||
from .exceptions import (
|
||||
AnthropicErrorDetail,
|
||||
AnthropicErrorResponse,
|
||||
AnthropicErrorType,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AnthropicErrorType",
|
||||
"AnthropicErrorDetail",
|
||||
"AnthropicErrorResponse",
|
||||
"ANTHROPIC_ERROR_TYPE_MAP",
|
||||
"AnthropicExceptionMapping",
|
||||
]
|
||||
@@ -0,0 +1,172 @@
|
||||
"""
|
||||
Utilities for mapping exceptions to Anthropic error format.
|
||||
|
||||
Similar to litellm/litellm_core_utils/exception_mapping_utils.py but for Anthropic response format.
|
||||
"""
|
||||
|
||||
from litellm.litellm_core_utils.safe_json_loads import safe_json_loads
|
||||
from typing import Dict, Optional
|
||||
|
||||
from .exceptions import AnthropicErrorResponse, AnthropicErrorType
|
||||
|
||||
|
||||
# HTTP status code -> Anthropic error type
|
||||
# Source: https://docs.anthropic.com/en/api/errors
|
||||
ANTHROPIC_ERROR_TYPE_MAP: Dict[int, AnthropicErrorType] = {
|
||||
400: "invalid_request_error",
|
||||
401: "authentication_error",
|
||||
403: "permission_error",
|
||||
404: "not_found_error",
|
||||
413: "request_too_large",
|
||||
429: "rate_limit_error",
|
||||
500: "api_error",
|
||||
529: "overloaded_error",
|
||||
}
|
||||
|
||||
|
||||
class AnthropicExceptionMapping:
|
||||
"""
|
||||
Helper class for mapping exceptions to Anthropic error format.
|
||||
|
||||
Similar pattern to ExceptionCheckers in litellm_core_utils/exception_mapping_utils.py
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def get_error_type(status_code: int) -> AnthropicErrorType:
|
||||
"""Map HTTP status code to Anthropic error type."""
|
||||
return ANTHROPIC_ERROR_TYPE_MAP.get(status_code, "api_error")
|
||||
|
||||
@staticmethod
|
||||
def create_error_response(
|
||||
status_code: int,
|
||||
message: str,
|
||||
request_id: Optional[str] = None,
|
||||
) -> AnthropicErrorResponse:
|
||||
"""
|
||||
Create an Anthropic-formatted error response dict.
|
||||
|
||||
Anthropic error format:
|
||||
{
|
||||
"type": "error",
|
||||
"error": {"type": "...", "message": "..."},
|
||||
"request_id": "req_..."
|
||||
}
|
||||
"""
|
||||
error_type = AnthropicExceptionMapping.get_error_type(status_code)
|
||||
|
||||
response: AnthropicErrorResponse = {
|
||||
"type": "error",
|
||||
"error": {
|
||||
"type": error_type,
|
||||
"message": message,
|
||||
},
|
||||
}
|
||||
|
||||
if request_id:
|
||||
response["request_id"] = request_id
|
||||
|
||||
return response
|
||||
|
||||
@staticmethod
|
||||
def extract_error_message(raw_message: str) -> str:
|
||||
"""
|
||||
Extract error message from various provider response formats.
|
||||
|
||||
Handles:
|
||||
- Bedrock: {"detail": {"message": "..."}}
|
||||
- AWS: {"Message": "..."}
|
||||
- Generic: {"message": "..."}
|
||||
- Plain strings
|
||||
"""
|
||||
parsed = safe_json_loads(raw_message)
|
||||
if isinstance(parsed, dict):
|
||||
# Bedrock format
|
||||
if "detail" in parsed and isinstance(parsed["detail"], dict):
|
||||
return parsed["detail"].get("message", raw_message)
|
||||
# AWS/generic format
|
||||
return parsed.get("Message") or parsed.get("message") or raw_message
|
||||
return raw_message
|
||||
|
||||
@staticmethod
|
||||
def _is_anthropic_error_dict(parsed: dict) -> bool:
|
||||
"""
|
||||
Check if a parsed dict is in Anthropic error format.
|
||||
|
||||
Anthropic error format:
|
||||
{
|
||||
"type": "error",
|
||||
"error": {"type": "...", "message": "..."}
|
||||
}
|
||||
"""
|
||||
return (
|
||||
parsed.get("type") == "error"
|
||||
and isinstance(parsed.get("error"), dict)
|
||||
and "type" in parsed["error"]
|
||||
and "message" in parsed["error"]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_message_from_dict(parsed: dict, raw_message: str) -> str:
|
||||
"""
|
||||
Extract error message from a parsed provider-specific dict.
|
||||
|
||||
Handles:
|
||||
- Bedrock: {"detail": {"message": "..."}}
|
||||
- AWS: {"Message": "..."}
|
||||
- Generic: {"message": "..."}
|
||||
"""
|
||||
# Bedrock format
|
||||
if "detail" in parsed and isinstance(parsed["detail"], dict):
|
||||
return parsed["detail"].get("message", raw_message)
|
||||
# AWS/generic format
|
||||
return parsed.get("Message") or parsed.get("message") or raw_message
|
||||
|
||||
@staticmethod
|
||||
def transform_to_anthropic_error(
|
||||
status_code: int,
|
||||
raw_message: str,
|
||||
request_id: Optional[str] = None,
|
||||
) -> AnthropicErrorResponse:
|
||||
"""
|
||||
Transform an error message to Anthropic format.
|
||||
|
||||
- If already in Anthropic format: passthrough unchanged
|
||||
- Otherwise: extract message and create Anthropic error
|
||||
|
||||
Parses JSON only once for efficiency.
|
||||
|
||||
Args:
|
||||
status_code: HTTP status code
|
||||
raw_message: Raw error message (may be JSON string or plain text)
|
||||
request_id: Optional request ID to include
|
||||
|
||||
Returns:
|
||||
AnthropicErrorResponse dict
|
||||
"""
|
||||
# Try to parse as JSON once
|
||||
parsed: Optional[dict] = safe_json_loads(raw_message)
|
||||
if not isinstance(parsed, dict):
|
||||
parsed = None
|
||||
|
||||
# If parsed and already in Anthropic format - passthrough
|
||||
if parsed is not None and AnthropicExceptionMapping._is_anthropic_error_dict(
|
||||
parsed
|
||||
):
|
||||
# Optionally add request_id if provided and not present
|
||||
if request_id and "request_id" not in parsed:
|
||||
parsed["request_id"] = request_id
|
||||
return parsed # type: ignore
|
||||
|
||||
# Extract message - use parsed dict if available, otherwise raw string
|
||||
if parsed is not None:
|
||||
message = AnthropicExceptionMapping._extract_message_from_dict(
|
||||
parsed, raw_message
|
||||
)
|
||||
else:
|
||||
message = raw_message
|
||||
|
||||
return AnthropicExceptionMapping.create_error_response(
|
||||
status_code=status_code,
|
||||
message=message,
|
||||
request_id=request_id,
|
||||
)
|
||||
@@ -0,0 +1,41 @@
|
||||
"""Anthropic error format type definitions."""
|
||||
|
||||
from typing_extensions import Literal, Required, TypedDict
|
||||
|
||||
|
||||
# Known Anthropic error types
# Source: https://docs.anthropic.com/en/api/errors
AnthropicErrorType = Literal[
    "invalid_request_error",
    "authentication_error",
    "permission_error",
    "not_found_error",
    "request_too_large",
    "rate_limit_error",
    "api_error",
    "overloaded_error",
]


class AnthropicErrorDetail(TypedDict):
    """Inner error detail in Anthropic format."""

    # Machine-readable error category (one of AnthropicErrorType above).
    type: AnthropicErrorType
    # Human-readable description of the error.
    message: str


class AnthropicErrorResponse(TypedDict, total=False):
    """
    Anthropic-formatted error response.

    Format:
    {
        "type": "error",
        "error": {"type": "...", "message": "..."},
        "request_id": "req_..."  # optional
    }

    Note: `total=False` makes keys optional by default; `type` and `error`
    are re-marked mandatory via `Required`, leaving only `request_id` optional.
    """

    type: Required[Literal["error"]]
    error: Required[AnthropicErrorDetail]
    request_id: str
||||
@@ -0,0 +1,144 @@
|
||||
"""
|
||||
Interface for Anthropic's messages API
|
||||
|
||||
Use this to call LLMs in Anthropic /messages Request/Response format
|
||||
|
||||
This is an __init__.py file to allow the following interface
|
||||
|
||||
- litellm.messages.acreate
|
||||
- litellm.messages.create
|
||||
|
||||
"""
|
||||
|
||||
from typing import Any, AsyncIterator, Coroutine, Dict, List, Optional, Union
|
||||
|
||||
from litellm.llms.anthropic.experimental_pass_through.messages.handler import (
|
||||
anthropic_messages as _async_anthropic_messages,
|
||||
)
|
||||
from litellm.llms.anthropic.experimental_pass_through.messages.handler import (
|
||||
anthropic_messages_handler as _sync_anthropic_messages,
|
||||
)
|
||||
from litellm.types.llms.anthropic_messages.anthropic_response import (
|
||||
AnthropicMessagesResponse,
|
||||
)
|
||||
|
||||
|
||||
async def acreate(
    max_tokens: int,
    messages: List[Dict],
    model: str,
    metadata: Optional[Dict] = None,
    stop_sequences: Optional[List[str]] = None,
    stream: Optional[bool] = False,
    system: Optional[str] = None,
    temperature: Optional[float] = None,
    thinking: Optional[Dict] = None,
    tool_choice: Optional[Dict] = None,
    tools: Optional[List[Dict]] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    container: Optional[Dict] = None,
    **kwargs
) -> Union[AnthropicMessagesResponse, AsyncIterator]:
    """
    Async wrapper for Anthropic's messages API

    Args:
        max_tokens (int): Maximum tokens to generate (required)
        messages (List[Dict]): List of message objects with role and content (required)
        model (str): Model name to use (required)
        metadata (Dict, optional): Request metadata
        stop_sequences (List[str], optional): Custom stop sequences
        stream (bool, optional): Whether to stream the response
        system (str, optional): System prompt
        temperature (float, optional): Sampling temperature (0.0 to 1.0)
        thinking (Dict, optional): Extended thinking configuration
        tool_choice (Dict, optional): Tool choice configuration
        tools (List[Dict], optional): List of tool definitions
        top_k (int, optional): Top K sampling parameter
        top_p (float, optional): Nucleus sampling parameter
        container (Dict, optional): Container config with skills for code execution
        **kwargs: Additional arguments (e.g. api_key, api_base) forwarded to the handler

    Returns:
        AnthropicMessagesResponse when stream is falsy, or an AsyncIterator of
        response chunks when stream=True (per the return annotation).
    """
    # Pure pass-through: every parameter is forwarded unchanged to the async
    # handler, which performs the actual provider call.
    return await _async_anthropic_messages(
        max_tokens=max_tokens,
        messages=messages,
        model=model,
        metadata=metadata,
        stop_sequences=stop_sequences,
        stream=stream,
        system=system,
        temperature=temperature,
        thinking=thinking,
        tool_choice=tool_choice,
        tools=tools,
        top_k=top_k,
        top_p=top_p,
        container=container,
        **kwargs,
    )
||||
|
||||
|
||||
def create(
    max_tokens: int,
    messages: List[Dict],
    model: str,
    metadata: Optional[Dict] = None,
    stop_sequences: Optional[List[str]] = None,
    stream: Optional[bool] = False,
    system: Optional[str] = None,
    temperature: Optional[float] = None,
    thinking: Optional[Dict] = None,
    tool_choice: Optional[Dict] = None,
    tools: Optional[List[Dict]] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    container: Optional[Dict] = None,
    **kwargs
) -> Union[
    AnthropicMessagesResponse,
    AsyncIterator[Any],
    Coroutine[Any, Any, Union[AnthropicMessagesResponse, AsyncIterator[Any]]],
]:
    """
    Synchronous wrapper for Anthropic's messages API

    Args:
        max_tokens (int): Maximum tokens to generate (required)
        messages (List[Dict]): List of message objects with role and content (required)
        model (str): Model name to use (required)
        metadata (Dict, optional): Request metadata
        stop_sequences (List[str], optional): Custom stop sequences
        stream (bool, optional): Whether to stream the response
        system (str, optional): System prompt
        temperature (float, optional): Sampling temperature (0.0 to 1.0)
        thinking (Dict, optional): Extended thinking configuration
        tool_choice (Dict, optional): Tool choice configuration
        tools (List[Dict], optional): List of tool definitions
        top_k (int, optional): Top K sampling parameter
        top_p (float, optional): Nucleus sampling parameter
        container (Dict, optional): Container config with skills for code execution
        **kwargs: Additional arguments (e.g. api_key, api_base) forwarded to the handler

    Returns:
        AnthropicMessagesResponse, an (async) iterator of chunks when
        stream=True, or — per the return annotation — a coroutine resolving
        to either of those, depending on the underlying handler.
    """
    # Pure pass-through: every parameter is forwarded unchanged to the sync
    # handler, which performs the actual provider call.
    return _sync_anthropic_messages(
        max_tokens=max_tokens,
        messages=messages,
        model=model,
        metadata=metadata,
        stop_sequences=stop_sequences,
        stream=stream,
        system=system,
        temperature=temperature,
        thinking=thinking,
        tool_choice=tool_choice,
        tools=tools,
        top_k=top_k,
        top_p=top_p,
        container=container,
        **kwargs,
    )
||||
@@ -0,0 +1,116 @@
|
||||
## Use LLM API endpoints in Anthropic Interface
|
||||
|
||||
Note: This is called `anthropic_interface` because `anthropic` is a known python package and was failing mypy type checking.
|
||||
|
||||
|
||||
## Usage
|
||||
---
|
||||
|
||||
### LiteLLM Python SDK
|
||||
|
||||
#### Non-streaming example
|
||||
```python showLineNumbers title="Example using LiteLLM Python SDK"
|
||||
import litellm
|
||||
response = await litellm.anthropic.messages.acreate(
|
||||
messages=[{"role": "user", "content": "Hello, can you tell me a short joke?"}],
|
||||
api_key=api_key,
|
||||
model="anthropic/claude-3-haiku-20240307",
|
||||
max_tokens=100,
|
||||
)
|
||||
```
|
||||
|
||||
Example response:
|
||||
```json
|
||||
{
|
||||
"content": [
|
||||
{
|
||||
"text": "Hi! this is a very short joke",
|
||||
"type": "text"
|
||||
}
|
||||
],
|
||||
"id": "msg_013Zva2CMHLNnXjNJJKqJ2EF",
|
||||
"model": "claude-3-7-sonnet-20250219",
|
||||
"role": "assistant",
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": null,
|
||||
"type": "message",
|
||||
"usage": {
|
||||
"input_tokens": 2095,
|
||||
"output_tokens": 503,
|
||||
"cache_creation_input_tokens": 2095,
|
||||
"cache_read_input_tokens": 0
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### Streaming example
|
||||
```python showLineNumbers title="Example using LiteLLM Python SDK"
|
||||
import litellm
|
||||
response = await litellm.anthropic.messages.acreate(
|
||||
messages=[{"role": "user", "content": "Hello, can you tell me a short joke?"}],
|
||||
api_key=api_key,
|
||||
model="anthropic/claude-3-haiku-20240307",
|
||||
max_tokens=100,
|
||||
stream=True,
|
||||
)
|
||||
async for chunk in response:
|
||||
print(chunk)
|
||||
```
|
||||
|
||||
### LiteLLM Proxy Server
|
||||
|
||||
|
||||
1. Setup config.yaml
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: anthropic-claude
|
||||
litellm_params:
|
||||
model: claude-3-7-sonnet-latest
|
||||
```
|
||||
|
||||
2. Start proxy
|
||||
|
||||
```bash
|
||||
litellm --config /path/to/config.yaml
|
||||
```
|
||||
|
||||
3. Test it!
|
||||
|
||||
<Tabs>
|
||||
<TabItem label="Anthropic Python SDK" value="python">
|
||||
|
||||
```python showLineNumbers title="Example using LiteLLM Proxy Server"
|
||||
import anthropic
|
||||
|
||||
# point anthropic sdk to litellm proxy
|
||||
client = anthropic.Anthropic(
|
||||
base_url="http://0.0.0.0:4000",
|
||||
api_key="sk-1234",
|
||||
)
|
||||
|
||||
response = client.messages.create(
|
||||
messages=[{"role": "user", "content": "Hello, can you tell me a short joke?"}],
|
||||
model="anthropic/claude-3-haiku-20240307",
|
||||
max_tokens=100,
|
||||
)
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem label="curl" value="curl">
|
||||
|
||||
```bash showLineNumbers title="Example using LiteLLM Proxy Server"
|
||||
curl -L -X POST 'http://0.0.0.0:4000/v1/messages' \
|
||||
-H 'content-type: application/json' \
|
||||
-H "x-api-key: $LITELLM_API_KEY" \
|
||||
-H 'anthropic-version: 2023-06-01' \
|
||||
-d '{
|
||||
"model": "anthropic-claude",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello, can you tell me a short joke?"
|
||||
}
|
||||
],
|
||||
"max_tokens": 100
|
||||
}'
|
||||
```
|
||||
1484
llm-gateway-competitors/litellm-wheel-src/litellm/assistants/main.py
Normal file
1484
llm-gateway-competitors/litellm-wheel-src/litellm/assistants/main.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,161 @@
|
||||
from typing import Optional, Union
|
||||
|
||||
import litellm
|
||||
|
||||
from ..exceptions import UnsupportedParamsError
|
||||
from ..types.llms.openai import *
|
||||
|
||||
|
||||
def get_optional_params_add_message(
    role: Optional[str],
    content: Optional[
        Union[
            str,
            List[
                Union[
                    MessageContentTextObject,
                    MessageContentImageFileObject,
                    MessageContentImageURLObject,
                ]
            ],
        ]
    ],
    attachments: Optional[List[Attachment]],
    metadata: Optional[dict],
    custom_llm_provider: str,
    **kwargs,
):
    """
    Map OpenAI "create message" params onto what the target provider supports.

    Azure doesn't support 'attachments' for creating a message.

    Reference - https://learn.microsoft.com/en-us/azure/ai-services/openai/assistants-reference-messages?tabs=python#create-message

    Args:
        role: Message role (e.g. "user").
        content: Message content - plain string or a list of content objects.
        attachments: File attachments (not supported on Azure).
        metadata: Arbitrary request metadata.
        custom_llm_provider: Provider name ("openai", "azure", ...).
        **kwargs: Extra provider-specific params, passed through untouched.

    Returns:
        dict: params to forward to the provider.

    Raises:
        UnsupportedParamsError: if a non-default param is unsupported by the
            provider and `litellm.drop_params` is not enabled.
    """
    # NOTE: locals() must be captured before any other local is created —
    # it is used as the snapshot of all passed-in parameters.
    passed_params = locals()
    custom_llm_provider = passed_params.pop("custom_llm_provider")
    special_params = passed_params.pop("kwargs")
    for k, v in special_params.items():
        passed_params[k] = v

    default_params = {
        "role": None,
        "content": None,
        "attachments": None,
        "metadata": None,
    }

    # Only params the caller actually set (differ from their defaults).
    non_default_params = {
        k: v
        for k, v in passed_params.items()
        if (k in default_params and v != default_params[k])
    }
    optional_params = {}

    ## raise exception if non-default value passed for unsupported params
    def _check_valid_arg(supported_params):
        if len(non_default_params.keys()) > 0:
            keys = list(non_default_params.keys())
            for k in keys:
                if (
                    litellm.drop_params is True and k not in supported_params
                ):  # drop the unsupported non-default values
                    non_default_params.pop(k, None)
                elif k not in supported_params:
                    # CONSISTENCY FIX: use the directly-imported exception
                    # class (see module imports), matching
                    # get_optional_params_image_gen in this module, instead of
                    # going through litellm.utils.
                    raise UnsupportedParamsError(
                        status_code=500,
                        message="k={}, not supported by {}. Supported params={}. To drop it from the call, set `litellm.drop_params = True`.".format(
                            k, custom_llm_provider, supported_params
                        ),
                    )
        return non_default_params

    if custom_llm_provider == "openai":
        optional_params = non_default_params
    elif custom_llm_provider == "azure":
        supported_params = (
            litellm.AzureOpenAIAssistantsAPIConfig().get_supported_openai_create_message_params()
        )
        _check_valid_arg(supported_params=supported_params)
        optional_params = litellm.AzureOpenAIAssistantsAPIConfig().map_openai_params_create_message_params(
            non_default_params=non_default_params, optional_params=optional_params
        )
    # Forward any non-standard params (e.g. api_key, api_base) untouched.
    for k in passed_params.keys():
        if k not in default_params.keys():
            optional_params[k] = passed_params[k]
    return optional_params
|
||||
|
||||
|
||||
def get_optional_params_image_gen(
    n: Optional[int] = None,
    quality: Optional[str] = None,
    response_format: Optional[str] = None,
    size: Optional[str] = None,
    style: Optional[str] = None,
    user: Optional[str] = None,
    custom_llm_provider: Optional[str] = None,
    **kwargs,
):
    """
    Map OpenAI image-generation params onto what the target provider supports.

    Args:
        n: Number of images to generate.
        quality: Image quality setting.
        response_format: Response format (e.g. "url" or "b64_json").
        size: Image size as "WIDTHxHEIGHT" (e.g. "1024x1024").
        style: Image style setting.
        user: End-user identifier.
        custom_llm_provider: Provider name ("openai", "azure", "bedrock", ...).
        **kwargs: Extra provider-specific params, passed through untouched.

    Returns:
        dict: params to forward to the provider.

    Raises:
        UnsupportedParamsError: if a non-default param is unsupported by the
            provider and `litellm.drop_params` is not enabled.
    """
    # retrieve all parameters passed to the function
    # NOTE: locals() must be captured before any other local is created.
    passed_params = locals()
    custom_llm_provider = passed_params.pop("custom_llm_provider")
    special_params = passed_params.pop("kwargs")
    for k, v in special_params.items():
        passed_params[k] = v

    default_params = {
        "n": None,
        "quality": None,
        "response_format": None,
        "size": None,
        "style": None,
        "user": None,
    }

    # Only params the caller actually set (differ from their defaults).
    non_default_params = {
        k: v
        for k, v in passed_params.items()
        if (k in default_params and v != default_params[k])
    }
    optional_params = {}

    ## raise exception if non-default value passed for unsupported params
    def _check_valid_arg(supported_params):
        if len(non_default_params.keys()) > 0:
            keys = list(non_default_params.keys())
            for k in keys:
                if (
                    litellm.drop_params is True and k not in supported_params
                ):  # drop the unsupported non-default values
                    non_default_params.pop(k, None)
                elif k not in supported_params:
                    # BUGFIX: the old message was copy-pasted from the
                    # embedding param-mapper and always blamed
                    # "user/encoding format" regardless of the actual
                    # offending param; name the real param instead.
                    raise UnsupportedParamsError(
                        status_code=500,
                        message=f"Setting `{k}` is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.",
                    )
        return non_default_params

    if (
        custom_llm_provider == "openai"
        or custom_llm_provider == "azure"
        or custom_llm_provider in litellm.openai_compatible_providers
    ):
        optional_params = non_default_params
    elif custom_llm_provider == "bedrock":
        supported_params = ["size"]
        _check_valid_arg(supported_params=supported_params)
        # Bedrock takes separate integer width/height instead of "WxH".
        if size is not None:
            width, height = size.split("x")
            optional_params["width"] = int(width)
            optional_params["height"] = int(height)
    elif custom_llm_provider == "vertex_ai":
        supported_params = ["n"]
        """
        All params here: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
        """
        _check_valid_arg(supported_params=supported_params)
        if n is not None:
            optional_params["sampleCount"] = int(n)

    # Forward any non-standard params (e.g. api_key, api_base) untouched.
    for k in passed_params.keys():
        if k not in default_params.keys():
            optional_params[k] = passed_params[k]
    return optional_params
||||
@@ -0,0 +1,11 @@
|
||||
# Implementation of `litellm.batch_completion`, `litellm.batch_completion_models`, `litellm.batch_completion_models_all_responses`
|
||||
|
||||
Doc: https://docs.litellm.ai/docs/completion/batching
|
||||
|
||||
|
||||
LiteLLM Python SDK allows you to:
|
||||
1. `litellm.batch_completion`: batch `litellm.completion` calls for a given model.
|
||||
2. `litellm.batch_completion_models` Send a request to multiple language models concurrently and return the response
|
||||
as soon as one of the models responds.
|
||||
3. `litellm.batch_completion_models_all_responses` Send a request to multiple language models concurrently and return a list of responses
|
||||
from all models that respond.
|
||||
@@ -0,0 +1,273 @@
|
||||
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
|
||||
from typing import List, Optional
|
||||
|
||||
import litellm
|
||||
from litellm._logging import print_verbose
|
||||
from litellm.utils import get_optional_params
|
||||
|
||||
from ..llms.vllm.completion import handler as vllm_handler
|
||||
|
||||
|
||||
def batch_completion(
    model: str,
    # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
    messages: List = [],
    functions: Optional[List] = None,
    function_call: Optional[str] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    n: Optional[int] = None,
    stream: Optional[bool] = None,
    stop=None,
    max_tokens: Optional[int] = None,
    presence_penalty: Optional[float] = None,
    frequency_penalty: Optional[float] = None,
    logit_bias: Optional[dict] = None,
    user: Optional[str] = None,
    deployment_id=None,
    request_timeout: Optional[int] = None,
    timeout: Optional[int] = 600,
    max_workers: Optional[int] = 100,
    # Optional liteLLM function params
    **kwargs,
):
    """
    Batch litellm.completion function for a given model.

    Args:
        model (str): The model to use for generating completions.
        messages (List, optional): List of message-lists, one per completion.
            Defaults to []. NOTE(review): mutable default argument — assumed
            to be read-only here; confirm it is never mutated.
        functions (List, optional): List of functions to use as input for generating completions. Defaults to None.
        function_call (str, optional): The function call to use as input for generating completions. Defaults to None.
        temperature (float, optional): The temperature parameter for generating completions. Defaults to None.
        top_p (float, optional): The top-p parameter for generating completions. Defaults to None.
        n (int, optional): The number of completions to generate. Defaults to None.
        stream (bool, optional): Whether to stream completions or not. Defaults to None.
        stop (optional): The stop parameter for generating completions. Defaults to None.
        max_tokens (float, optional): The maximum number of tokens to generate. Defaults to None.
        presence_penalty (float, optional): The presence penalty for generating completions. Defaults to None.
        frequency_penalty (float, optional): The frequency penalty for generating completions. Defaults to None.
        logit_bias (dict, optional): The logit bias for generating completions. Defaults to None.
        user (str, optional): The user string for generating completions. Defaults to None.
        deployment_id (optional): The deployment ID for generating completions. Defaults to None.
        request_timeout (int, optional): The request timeout for generating completions. Defaults to None.
        timeout (int, optional): Per-call timeout in seconds. Defaults to 600.
        max_workers (int, optional): The maximum number of threads to use for parallel processing. Defaults to 100.
        **kwargs: Extra litellm params, forwarded to each completion call.

    Returns:
        list: A list of completion results; failed calls contribute their
        Exception object in place of a response (errors are returned, not raised).
    """
    # Snapshot of every parameter — reused below as the kwargs template for
    # each per-message completion call. Must run before any new local exists.
    args = locals()

    batch_messages = messages
    completions = []
    model = model  # no-op assignment; kept byte-identical (review: removable)
    custom_llm_provider = None
    # "provider/model" prefix routing, e.g. "vllm/...".
    if model.split("/", 1)[0] in litellm.provider_list:
        custom_llm_provider = model.split("/", 1)[0]
        model = model.split("/", 1)[1]
    if custom_llm_provider == "vllm":
        # vLLM has a native batch API — map params once and delegate.
        optional_params = get_optional_params(
            functions=functions,
            function_call=function_call,
            temperature=temperature,
            top_p=top_p,
            n=n,
            stream=stream or False,
            stop=stop,
            max_tokens=max_tokens,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            logit_bias=logit_bias,
            user=user,
            # params to identify the model
            model=model,
            custom_llm_provider=custom_llm_provider,
        )
        results = vllm_handler.batch_completions(
            model=model,
            messages=batch_messages,
            custom_prompt_dict=litellm.custom_prompt_dict,
            optional_params=optional_params,
        )
    # all non VLLM models for batch completion models
    else:

        def chunks(lst, n):
            """Yield successive n-sized chunks from lst."""
            for i in range(0, len(lst), n):
                yield lst[i : i + n]

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Submit one litellm.completion per message list, in chunks of 100.
            for sub_batch in chunks(batch_messages, 100):
                for message_list in sub_batch:
                    kwargs_modified = args.copy()
                    # max_workers is a local-only knob, not a completion param.
                    kwargs_modified.pop("max_workers")
                    kwargs_modified["messages"] = message_list
                    original_kwargs = {}
                    # Unpack the caller's **kwargs so they become real
                    # keyword args on the completion call.
                    if "kwargs" in kwargs_modified:
                        original_kwargs = kwargs_modified.pop("kwargs")
                    future = executor.submit(
                        litellm.completion, **kwargs_modified, **original_kwargs
                    )
                    completions.append(future)

        # Retrieve the results from the futures
        # results = [future.result() for future in completions]
        # return exceptions if any
        results = []
        for future in completions:
            try:
                results.append(future.result())
            except Exception as exc:
                # Deliberate: a failed call yields its exception in the
                # result list rather than aborting the whole batch.
                results.append(exc)

    return results
||||
|
||||
|
||||
# send one request to multiple models
|
||||
# return as soon as one of the llms responds
|
||||
def batch_completion_models(*args, **kwargs):
    """
    Send a request to multiple language models concurrently and return the response
    as soon as one of the models responds.

    Args:
        *args: Variable-length positional arguments passed to the completion function.
        **kwargs: Additional keyword arguments:
            - models (str or list of str): The language models to send requests to.
            - deployments (list of dict): Alternative to `models`; each dict
              carries per-deployment params (model, api base, ...). This path
              also pops `model_list` and a nested `kwargs` dict.
            - Other keyword arguments to be passed to the completion function.

    Returns:
        str or None: The response from one of the language models, or None if no response is received.

    Note:
        This function utilizes a ThreadPoolExecutor to parallelize requests to multiple models.
        It sends requests concurrently and returns the response from the first model that responds.
        In the `models` path, results are checked in the caller's model order,
        and an exception raised by a model propagates to the caller; in the
        `deployments` path failed futures are logged and dropped instead.
    """

    if "model" in kwargs:
        # Per-call model is supplied below; drop any caller-provided one.
        kwargs.pop("model")
    if "models" in kwargs:
        models = kwargs["models"]
        kwargs.pop("models")
        futures = {}
        with ThreadPoolExecutor(max_workers=len(models)) as executor:
            for model in models:
                futures[model] = executor.submit(
                    litellm.completion, *args, model=model, **kwargs
                )

            # Check results in the original `models` order, not completion
            # order, so earlier-listed models take priority.
            for model, future in sorted(
                futures.items(), key=lambda x: models.index(x[0])
            ):
                # NOTE(review): future.result() is called twice and any
                # exception from it propagates here — confirm intended.
                if future.result() is not None:
                    return future.result()
    elif "deployments" in kwargs:
        deployments = kwargs["deployments"]
        kwargs.pop("deployments")
        # NOTE(review): assumes "model_list" is always present alongside
        # "deployments" (KeyError otherwise) — confirm against callers.
        kwargs.pop("model_list")
        nested_kwargs = kwargs.pop("kwargs", {})
        futures = {}
        with ThreadPoolExecutor(max_workers=len(deployments)) as executor:
            for deployment in deployments:
                for key in kwargs.keys():
                    if (
                        key not in deployment
                    ):  # don't override deployment values e.g. model name, api base, etc.
                        deployment[key] = kwargs[key]
                # kwargs is rebound per deployment; subsequent iterations see
                # the previous deployment's merged values.
                kwargs = {**deployment, **nested_kwargs}
                futures[deployment["model"]] = executor.submit(
                    litellm.completion, **kwargs
                )

        # Return the first future that completes successfully; drop the ones
        # that raise and keep waiting on the remainder.
        while futures:
            # wait for the first returned future
            print_verbose("\n\n waiting for next result\n\n")
            done, _ = wait(futures.values(), return_when=FIRST_COMPLETED)
            print_verbose(f"done list\n{done}")
            for future in done:
                try:
                    result = future.result()
                    return result
                except Exception:
                    # if model 1 fails, continue with response from model 2, model3
                    print_verbose(
                        "\n\ngot an exception, ignoring, removing from futures"
                    )
                    print_verbose(futures)
                    new_futures = {}
                    for key, value in futures.items():
                        if future == value:
                            print_verbose(f"removing key{key}")
                            continue
                        else:
                            new_futures[key] = value
                    futures = new_futures
                    print_verbose(f"new futures{futures}")
                    continue

            print_verbose("\n\ndone looping through futures\n\n")
            print_verbose(futures)

    return None  # If no response is received from any model
||||
|
||||
|
||||
def batch_completion_models_all_responses(*args, **kwargs):
    """
    Fan one completion request out to several models and collect every response.

    Args:
        *args: Positional arguments forwarded to `litellm.completion`.
        **kwargs: Keyword arguments forwarded to `litellm.completion`. Must
            include `models` (str, or list/tuple of str); any `model` key is
            discarded in favor of the per-call model.

    Returns:
        list: Non-None responses, ordered like `models`. Models whose request
        raises are logged and skipped rather than aborting the batch.

    Raises:
        Exception: if `models` is missing from kwargs.
        TypeError: if `models` is neither a string nor a list/tuple of strings.
    """
    import concurrent.futures

    # ANSI escape codes for colored output

    kwargs.pop("model", None)  # per-model value is supplied below
    if "models" not in kwargs:
        raise Exception("'models' param not in kwargs")
    models = kwargs.pop("models")

    # Normalize to a concrete list of model names.
    if isinstance(models, str):
        models = [models]
    elif isinstance(models, (list, tuple)):
        models = list(models)
    else:
        raise TypeError("'models' must be a string or list of strings")

    if not models:
        return []

    collected = []

    with concurrent.futures.ThreadPoolExecutor(max_workers=len(models)) as pool:
        pending = [
            pool.submit(litellm.completion, *args, model=target, **kwargs)
            for target in models
        ]

        # Gather in submission order so output order matches `models`.
        for fut in pending:
            try:
                outcome = fut.result()
            except Exception as e:
                print_verbose(
                    f"batch_completion_models_all_responses: model request failed: {str(e)}"
                )
                continue
            if outcome is not None:
                collected.append(outcome)

    return collected
||||
@@ -0,0 +1,442 @@
|
||||
import json
|
||||
from typing import Any, List, Literal, Optional, Tuple
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.types.llms.openai import Batch
|
||||
from litellm.types.utils import CallTypes, ModelInfo, Usage
|
||||
from litellm.utils import token_counter
|
||||
|
||||
|
||||
async def calculate_batch_cost_and_usage(
    file_content_dictionary: List[dict],
    custom_llm_provider: Literal[
        "openai", "azure", "vertex_ai", "hosted_vllm", "anthropic"
    ],
    model_name: Optional[str] = None,
    model_info: Optional[ModelInfo] = None,
) -> Tuple[float, Usage, List[str]]:
    """
    Compute total cost, aggregate usage, and models used for a batch.

    Args:
        file_content_dictionary: Parsed lines of the batch output file.
        custom_llm_provider: Provider the batch ran against.
        model_name: Optional explicit model name.
        model_info: Optional deployment-level model info with custom batch
            pricing. Threaded through to batch_cost_calculator so that
            deployment-specific pricing (e.g. input_cost_per_token_batches)
            is used instead of the global cost map.

    Returns:
        (total_cost, aggregate_usage, models) tuple.
    """
    total_cost = _batch_cost_calculator(
        custom_llm_provider=custom_llm_provider,
        file_content_dictionary=file_content_dictionary,
        model_name=model_name,
        model_info=model_info,
    )
    total_usage = _get_batch_job_total_usage_from_file_content(
        file_content_dictionary=file_content_dictionary,
        custom_llm_provider=custom_llm_provider,
        model_name=model_name,
    )
    models_used = _get_batch_models_from_file_content(
        file_content_dictionary, model_name
    )
    return total_cost, total_usage, models_used
||||
|
||||
|
||||
async def _handle_completed_batch(
    batch: Batch,
    custom_llm_provider: Literal[
        "openai", "azure", "vertex_ai", "hosted_vllm", "anthropic"
    ],
    model_name: Optional[str] = None,
    litellm_params: Optional[dict] = None,
) -> Tuple[float, Usage, List[str]]:
    """Helper function to process a completed batch and handle logging.

    Fetches the batch output file and delegates the cost/usage/models
    aggregation to calculate_batch_cost_and_usage instead of duplicating
    that logic here (identical behavior: model_info defaults to None).

    Args:
        batch: The batch object
        custom_llm_provider: The LLM provider
        model_name: Optional model name
        litellm_params: Optional litellm parameters containing credentials (api_key, api_base, etc.)

    Returns:
        (total_cost, aggregate_usage, models) tuple.
    """
    # Get batch results
    file_content_dictionary = await _get_batch_output_file_content_as_dictionary(
        batch, custom_llm_provider, litellm_params=litellm_params
    )

    # Calculate costs, usage and models in one place (keeps pricing logic DRY).
    return await calculate_batch_cost_and_usage(
        file_content_dictionary=file_content_dictionary,
        custom_llm_provider=custom_llm_provider,
        model_name=model_name,
    )
||||
|
||||
|
||||
def _get_batch_models_from_file_content(
|
||||
file_content_dictionary: List[dict],
|
||||
model_name: Optional[str] = None,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Get the models from the file content
|
||||
"""
|
||||
if model_name:
|
||||
return [model_name]
|
||||
batch_models = []
|
||||
for _item in file_content_dictionary:
|
||||
if _batch_response_was_successful(_item):
|
||||
_response_body = _get_response_from_batch_job_output_file(_item)
|
||||
_model = _response_body.get("model")
|
||||
if _model:
|
||||
batch_models.append(_model)
|
||||
return batch_models
|
||||
|
||||
|
||||
def _batch_cost_calculator(
    file_content_dictionary: List[dict],
    custom_llm_provider: Literal[
        "openai", "azure", "vertex_ai", "hosted_vllm", "anthropic"
    ] = "openai",
    model_name: Optional[str] = None,
    model_info: Optional[ModelInfo] = None,
) -> float:
    """
    Calculate the total cost of a batch from its output-file contents.

    Vertex AI batch output has its own response shape, so it is routed to a
    specialized calculator; all other providers share the generic per-line
    logic (which honors deployment-level model_info pricing).
    """
    # Vertex AI path: requires an explicit model name.
    if custom_llm_provider == "vertex_ai" and model_name:
        vertex_cost, _ = calculate_vertex_ai_batch_cost_and_usage(
            file_content_dictionary, model_name
        )
        verbose_logger.debug("vertex_ai_total_cost=%s", vertex_cost)
        return vertex_cost

    # Generic path for every other provider.
    generic_cost = _get_batch_job_cost_from_file_content(
        file_content_dictionary=file_content_dictionary,
        custom_llm_provider=custom_llm_provider,
        model_info=model_info,
    )
    verbose_logger.debug("total_cost=%s", generic_cost)
    return generic_cost
||||
|
||||
|
||||
def calculate_vertex_ai_batch_cost_and_usage(
    vertex_ai_batch_responses: List[dict],
    model_name: Optional[str] = None,
) -> Tuple[float, Usage]:
    """
    Aggregate cost and token usage from Vertex AI batch responses.

    Vertex AI batch output lines have format:
    {"request": ..., "status": "", "response": {"candidates": [...], "usageMetadata": {...}}}

    usageMetadata contains promptTokenCount, candidatesTokenCount, totalTokenCount.

    Lines without a "response" body are skipped entirely; lines whose cost
    lookup raises still contribute their tokens (but no cost).
    """
    from litellm.cost_calculator import batch_cost_calculator

    resolved_model = model_name or "gemini-2.0-flash-001"
    cost_sum = 0.0
    prompt_sum = 0
    completion_sum = 0
    token_sum = 0

    for line in vertex_ai_batch_responses:
        body = line.get("response")
        if body is None:
            continue

        meta = body.get("usageMetadata", {})
        # `or 0` guards against explicit nulls in the metadata.
        p_tokens = meta.get("promptTokenCount", 0) or 0
        c_tokens = meta.get("candidatesTokenCount", 0) or 0
        t_tokens = meta.get("totalTokenCount", 0) or (p_tokens + c_tokens)

        per_line_usage = Usage(
            prompt_tokens=p_tokens,
            completion_tokens=c_tokens,
            total_tokens=t_tokens,
        )

        try:
            prompt_cost, completion_cost = batch_cost_calculator(
                usage=per_line_usage,
                model=resolved_model,
                custom_llm_provider="vertex_ai",
            )
            cost_sum += prompt_cost + completion_cost
        except Exception as e:
            # Best-effort pricing: keep counting tokens even when the cost
            # lookup fails for this line.
            verbose_logger.debug(
                "vertex_ai batch cost calculation error for line: %s", str(e)
            )

        prompt_sum += p_tokens
        completion_sum += c_tokens
        token_sum += t_tokens

    verbose_logger.info(
        "vertex_ai batch cost: cost=%s, prompt=%d, completion=%d, total=%d",
        cost_sum,
        prompt_sum,
        completion_sum,
        token_sum,
    )

    return cost_sum, Usage(
        total_tokens=token_sum,
        prompt_tokens=prompt_sum,
        completion_tokens=completion_sum,
    )
||||
|
||||
|
||||
async def _get_batch_output_file_content_as_dictionary(
    batch: Batch,
    custom_llm_provider: Literal[
        "openai", "azure", "vertex_ai", "hosted_vllm", "anthropic"
    ] = "openai",
    litellm_params: Optional[dict] = None,
) -> List[dict]:
    """
    Get the batch output file content as a list of dictionaries

    Args:
        batch: The batch object
        custom_llm_provider: The LLM provider
        litellm_params: Optional litellm parameters containing credentials (api_key, api_base, etc.)
            Required for Azure and other providers that need authentication
    """
    from litellm.files.main import afile_content
    from litellm.proxy.openai_files_endpoints.common_utils import (
        _is_base64_encoded_unified_file_id,
    )

    # Guard clauses: cases where the output file cannot be fetched at all.
    if custom_llm_provider == "vertex_ai":
        raise ValueError("Vertex AI does not support file content retrieval")

    if batch.output_file_id is None:
        raise ValueError("Output file id is None cannot retrieve file content")

    file_id = batch.output_file_id
    unified_file_id = _is_base64_encoded_unified_file_id(file_id)
    if unified_file_id:
        # Unified file IDs embed the provider file id after "llm_output_file_id,".
        try:
            file_id = unified_file_id.split("llm_output_file_id,")[1].split(";")[0]
            verbose_logger.debug(
                f"Extracted LLM output file ID from unified file ID: {file_id}"
            )
        except (IndexError, AttributeError) as e:
            # Extraction failed; fall back to the original output_file_id.
            verbose_logger.error(
                f"Failed to extract LLM output file ID from unified file ID: {batch.output_file_id}, error: {e}"
            )

    # Forward provider credentials (api_key, api_base, ...) so authenticated
    # providers such as Azure can serve the file.
    request_kwargs = {
        "file_id": file_id,
        "custom_llm_provider": custom_llm_provider,
    }
    request_kwargs.update(_extract_file_access_credentials(litellm_params))

    file_response = await afile_content(**request_kwargs)  # type: ignore[reportArgumentType]
    return _get_file_content_as_dictionary(file_response.content)
|
||||
|
||||
|
||||
def _extract_file_access_credentials(litellm_params: Optional[dict]) -> dict:
|
||||
"""
|
||||
Extract credentials from litellm_params for file access operations.
|
||||
|
||||
This method extracts relevant authentication and configuration parameters
|
||||
needed for accessing files across different providers (Azure, Vertex AI, etc.).
|
||||
|
||||
Args:
|
||||
litellm_params: Dictionary containing litellm parameters with credentials
|
||||
|
||||
Returns:
|
||||
Dictionary containing only the credentials needed for file access
|
||||
"""
|
||||
credentials = {}
|
||||
|
||||
if litellm_params:
|
||||
# List of credential keys that should be passed to file operations
|
||||
credential_keys = [
|
||||
"api_key",
|
||||
"api_base",
|
||||
"api_version",
|
||||
"organization",
|
||||
"azure_ad_token",
|
||||
"azure_ad_token_provider",
|
||||
"vertex_project",
|
||||
"vertex_location",
|
||||
"vertex_credentials",
|
||||
"timeout",
|
||||
"max_retries",
|
||||
]
|
||||
for key in credential_keys:
|
||||
if key in litellm_params:
|
||||
credentials[key] = litellm_params[key]
|
||||
|
||||
return credentials
|
||||
|
||||
|
||||
def _get_file_content_as_dictionary(file_content: bytes) -> List[dict]:
    """
    Get the file content as a list of dictionaries from JSON Lines format

    Args:
        file_content: Raw JSONL bytes (one JSON object per non-empty line).

    Returns:
        One parsed dict per non-empty line, in file order.

    Raises:
        UnicodeDecodeError: if the content is not valid UTF-8.
        json.JSONDecodeError: if any line is not valid JSON.
    """
    # The original wrapped this in `try/except Exception as e: raise e`,
    # which re-raised unchanged and only added noise — removed.
    _file_content_str = file_content.decode("utf-8")
    # Split by newlines and parse each non-empty line as a separate JSON object
    json_objects = [
        json.loads(line) for line in _file_content_str.strip().split("\n") if line
    ]
    verbose_logger.debug("json_objects=%s", json.dumps(json_objects, indent=4))
    return json_objects
|
||||
|
||||
|
||||
def _get_batch_job_cost_from_file_content(
    file_content_dictionary: List[dict],
    custom_llm_provider: Literal[
        "openai", "azure", "vertex_ai", "hosted_vllm", "anthropic"
    ] = "openai",
    model_info: Optional[ModelInfo] = None,
) -> float:
    """
    Get the cost of a batch job from the file content

    Sums the cost of every successful (HTTP 200) response line. When
    ``model_info`` is given, cost is computed from per-line usage via
    ``batch_cost_calculator``; otherwise ``litellm.completion_cost`` is used.

    Raises:
        Exception: re-raises any error from cost calculation after logging it.
    """
    from litellm.cost_calculator import batch_cost_calculator

    try:
        total_cost: float = 0.0
        # parse the file content as json
        verbose_logger.debug(
            "file_content_dictionary=%s", json.dumps(file_content_dictionary, indent=4)
        )
        for _item in file_content_dictionary:
            if not _batch_response_was_successful(_item):
                continue
            _response_body = _get_response_from_batch_job_output_file(_item)
            if model_info is not None:
                # Custom pricing path: derive cost from the line's usage block.
                usage = _get_batch_job_usage_from_response_body(_response_body)
                model = _response_body.get("model", "")
                prompt_cost, completion_cost = batch_cost_calculator(
                    usage=usage,
                    model=model,
                    custom_llm_provider=custom_llm_provider,
                    model_info=model_info,
                )
                total_cost += prompt_cost + completion_cost
            else:
                total_cost += litellm.completion_cost(
                    completion_response=_response_body,
                    custom_llm_provider=custom_llm_provider,
                    call_type=CallTypes.aretrieve_batch.value,
                )
        verbose_logger.debug("total_cost=%s", total_cost)
        return total_cost
    except Exception as e:
        # BUG FIX: the original called verbose_logger.error("...", e) with no
        # %s placeholder, so `e` was an unconsumed lazy-format argument and the
        # logging call itself raised a formatting error.
        verbose_logger.error("error in _get_batch_job_cost_from_file_content: %s", e)
        raise e
|
||||
|
||||
|
||||
def _get_batch_job_total_usage_from_file_content(
    file_content_dictionary: List[dict],
    custom_llm_provider: Literal[
        "openai", "azure", "vertex_ai", "hosted_vllm", "anthropic"
    ] = "openai",
    model_name: Optional[str] = None,
) -> Usage:
    """
    Get the tokens of a batch job from the file content
    """
    # Vertex AI usage lives in `usageMetadata`; delegate to the specialized path.
    if custom_llm_provider == "vertex_ai" and model_name:
        _ignored_cost, vertex_usage = calculate_vertex_ai_batch_cost_and_usage(
            file_content_dictionary, model_name
        )
        return vertex_usage

    # Other providers: sum per-line OpenAI-style `usage` blocks from
    # successful (HTTP 200) response lines only.
    token_totals = {"total": 0, "prompt": 0, "completion": 0}
    for line in file_content_dictionary:
        if not _batch_response_was_successful(line):
            continue
        body = _get_response_from_batch_job_output_file(line)
        line_usage: Usage = _get_batch_job_usage_from_response_body(body)
        token_totals["total"] += line_usage.total_tokens
        token_totals["prompt"] += line_usage.prompt_tokens
        token_totals["completion"] += line_usage.completion_tokens

    return Usage(
        total_tokens=token_totals["total"],
        prompt_tokens=token_totals["prompt"],
        completion_tokens=token_totals["completion"],
    )
|
||||
|
||||
|
||||
def _get_batch_job_input_file_usage(
    file_content_dictionary: List[dict],
    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
    model_name: Optional[str] = None,
) -> Usage:
    """
    Count the number of tokens in the input file

    Used for batch rate limiting to count the number of tokens in the input file
    """
    input_tokens: int = 0

    for request_line in file_content_dictionary:
        request_body = request_line.get("body", {})
        # Per-line model wins; fall back to the caller-supplied model name.
        request_model = request_body.get("model", model_name or "")
        request_messages = request_body.get("messages", [])
        if not request_messages:
            continue
        input_tokens += token_counter(model=request_model, messages=request_messages)

    # Input files only contain prompts, so completion tokens are always 0.
    return Usage(
        total_tokens=input_tokens,
        prompt_tokens=input_tokens,
        completion_tokens=0,
    )
|
||||
|
||||
|
||||
def _get_batch_job_usage_from_response_body(response_body: dict) -> Usage:
    """
    Get the tokens of a batch job from the response body
    """
    # Missing/None `usage` becomes an empty dict, yielding an all-zero Usage.
    usage_payload = response_body.get("usage", None) or {}
    return Usage(**usage_payload)
|
||||
|
||||
|
||||
def _get_response_from_batch_job_output_file(batch_job_output_file: dict) -> Any:
|
||||
"""
|
||||
Get the response from the batch job output file
|
||||
"""
|
||||
_response: dict = batch_job_output_file.get("response", None) or {}
|
||||
_response_body = _response.get("body", None) or {}
|
||||
return _response_body
|
||||
|
||||
|
||||
def _batch_response_was_successful(batch_job_output_file: dict) -> bool:
|
||||
"""
|
||||
Check if the batch job response status == 200
|
||||
"""
|
||||
_response: dict = batch_job_output_file.get("response", None) or {}
|
||||
return _response.get("status_code", None) == 200
|
||||
1181
llm-gateway-competitors/litellm-wheel-src/litellm/batches/main.py
Normal file
1181
llm-gateway-competitors/litellm-wheel-src/litellm/batches/main.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,10 @@
|
||||
{
|
||||
"posts": [
|
||||
{
|
||||
"title": "Incident Report: SERVER_ROOT_PATH regression broke UI routing",
|
||||
"description": "How a single line removal caused UI 404s for all deployments using SERVER_ROOT_PATH, and the tests we added to prevent it from happening again.",
|
||||
"date": "2026-02-21",
|
||||
"url": "https://docs.litellm.ai/blog/server-root-path-incident"
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -0,0 +1,230 @@
|
||||
# +-----------------------------------------------+
|
||||
# | |
|
||||
# | NOT PROXY BUDGET MANAGER |
|
||||
# | proxy budget manager is in proxy_server.py |
|
||||
# | |
|
||||
# +-----------------------------------------------+
|
||||
#
|
||||
# Thank you users! We ❤️ you! - Krrish & Ishaan
|
||||
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from typing import Literal, Optional
|
||||
|
||||
import litellm
|
||||
from litellm.constants import (
|
||||
DAYS_IN_A_MONTH,
|
||||
DAYS_IN_A_WEEK,
|
||||
DAYS_IN_A_YEAR,
|
||||
HOURS_IN_A_DAY,
|
||||
)
|
||||
from litellm.utils import ModelResponse
|
||||
|
||||
|
||||
class BudgetManager:
    """Client-side per-user budget tracking for LiteLLM.

    NOT the proxy budget manager — that lives in proxy_server.py. Spend is
    tracked per user in ``self.user_dict`` and persisted either to a local
    ``user_cost.json`` file (client_type="local") or to a hosted endpoint
    (client_type="hosted").
    """

    def __init__(
        self,
        project_name: str,
        client_type: str = "local",
        api_base: Optional[str] = None,
        headers: Optional[dict] = None,
    ):
        """
        Args:
            project_name: Identifier used when persisting to the hosted API.
            client_type: "local" (JSON file on disk) or "hosted" (remote API).
            api_base: Hosted API base URL; defaults to https://api.litellm.ai.
            headers: HTTP headers for hosted requests; defaults to JSON content type.
        """
        self.client_type = client_type
        self.project_name = project_name
        self.api_base = api_base or "https://api.litellm.ai"
        self.headers = headers or {"Content-Type": "application/json"}
        ## load the data or init the initial dictionaries
        self.load_data()

    def print_verbose(self, print_statement):
        # Best-effort logging; never let logging break budget tracking.
        try:
            if litellm.set_verbose:
                import logging

                logging.info(print_statement)
        except Exception:
            pass

    def load_data(self):
        """Populate ``self.user_dict`` from local disk or the hosted API."""
        if self.client_type == "local":
            # Check if user dict file exists
            if os.path.isfile("user_cost.json"):
                # Load the user dict
                with open("user_cost.json", "r") as json_file:
                    self.user_dict = json.load(json_file)
            else:
                self.print_verbose("User Dictionary not found!")
                self.user_dict = {}
            self.print_verbose(f"user dict from local: {self.user_dict}")
        elif self.client_type == "hosted":
            # Load the user_dict from hosted db
            url = self.api_base + "/get_budget"
            data = {"project_name": self.project_name}
            response = litellm.module_level_client.post(
                url, headers=self.headers, json=data
            )
            response = response.json()
            if response["status"] == "error":
                self.user_dict = (
                    {}
                )  # assume this means the user dict hasn't been stored yet
            else:
                self.user_dict = response["data"]

    def create_budget(
        self,
        total_budget: float,
        user: str,
        duration: Optional[Literal["daily", "weekly", "monthly", "yearly"]] = None,
        created_at: Optional[float] = None,
    ):
        """Create (or replace) a budget for ``user``.

        Args:
            total_budget: Budget cap (USD).
            user: User identifier.
            duration: Optional reset period; omit for a non-resetting budget.
            created_at: Epoch timestamp of creation; defaults to the call time.
                (BUG FIX: was ``created_at: float = time.time()``, which froze
                the default at import time for every subsequent call.)

        Returns:
            The stored budget dict for ``user``.

        Raises:
            ValueError: if ``duration`` is not one of the supported values.
        """
        if created_at is None:
            created_at = time.time()
        self.user_dict[user] = {"total_budget": total_budget}
        if duration is None:
            return self.user_dict[user]

        if duration == "daily":
            duration_in_days = 1
        elif duration == "weekly":
            duration_in_days = DAYS_IN_A_WEEK
        elif duration == "monthly":
            duration_in_days = DAYS_IN_A_MONTH
        elif duration == "yearly":
            duration_in_days = DAYS_IN_A_YEAR
        else:
            raise ValueError(
                """duration needs to be one of ["daily", "weekly", "monthly", "yearly"]"""
            )
        self.user_dict[user] = {
            "total_budget": total_budget,
            "duration": duration_in_days,
            "created_at": created_at,
            "last_updated_at": created_at,
        }
        self._save_data_thread()  # [Non-Blocking] Update persistent storage without blocking execution
        return self.user_dict[user]

    def projected_cost(self, model: str, messages: list, user: str):
        """Estimate total spend if ``messages`` were sent: prompt cost + current spend."""
        text = "".join(message["content"] for message in messages)
        prompt_tokens = litellm.token_counter(model=model, text=text)
        prompt_cost, _ = litellm.cost_per_token(
            model=model, prompt_tokens=prompt_tokens, completion_tokens=0
        )
        current_cost = self.user_dict[user].get("current_cost", 0)
        projected_cost = prompt_cost + current_cost
        return projected_cost

    def get_total_budget(self, user: str):
        """Return the configured budget cap for ``user``."""
        return self.user_dict[user]["total_budget"]

    def update_cost(
        self,
        user: str,
        completion_obj: Optional[ModelResponse] = None,
        model: Optional[str] = None,
        input_text: Optional[str] = None,
        output_text: Optional[str] = None,
    ):
        """Add the cost of one completion to ``user``'s running totals.

        Either pass ``completion_obj``, or the (``model``, ``input_text``,
        ``output_text``) triple for text-only accounting.

        Raises:
            ValueError: if neither form of input is provided.
        """
        if model and input_text and output_text:
            # Text path: count tokens on both sides and price them.
            prompt_tokens = litellm.token_counter(
                model=model, messages=[{"role": "user", "content": input_text}]
            )
            completion_tokens = litellm.token_counter(
                model=model, messages=[{"role": "user", "content": output_text}]
            )
            (
                prompt_tokens_cost_usd_dollar,
                completion_tokens_cost_usd_dollar,
            ) = litellm.cost_per_token(
                model=model,
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
            )
            cost = prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
        elif completion_obj:
            cost = litellm.completion_cost(completion_response=completion_obj)
            model = completion_obj[
                "model"
            ]  # if this throws an error try, model = completion_obj['model']
        else:
            raise ValueError(
                "Either a chat completion object or the text response needs to be passed in. Learn more - https://docs.litellm.ai/docs/budget_manager"
            )

        self.user_dict[user]["current_cost"] = cost + self.user_dict[user].get(
            "current_cost", 0
        )
        # Track per-model spend alongside the aggregate.
        if "model_cost" in self.user_dict[user]:
            self.user_dict[user]["model_cost"][model] = cost + self.user_dict[user][
                "model_cost"
            ].get(model, 0)
        else:
            self.user_dict[user]["model_cost"] = {model: cost}

        self._save_data_thread()  # [Non-Blocking] Update persistent storage without blocking execution
        return {"user": self.user_dict[user]}

    def get_current_cost(self, user):
        """Return ``user``'s accumulated spend (0 if none recorded)."""
        return self.user_dict[user].get("current_cost", 0)

    def get_model_cost(self, user):
        """Return ``user``'s per-model spend dict (0 if none recorded)."""
        return self.user_dict[user].get("model_cost", 0)

    def is_valid_user(self, user: str) -> bool:
        """Return True if a budget exists for ``user``."""
        return user in self.user_dict

    def get_users(self):
        """Return all user identifiers with budgets."""
        return list(self.user_dict.keys())

    def reset_cost(self, user):
        """Zero out ``user``'s aggregate and per-model spend."""
        self.user_dict[user]["current_cost"] = 0
        self.user_dict[user]["model_cost"] = {}
        return {"user": self.user_dict[user]}

    def reset_on_duration(self, user: str):
        """Reset ``user``'s spend if their budget period has elapsed."""
        # Get current and creation time
        last_updated_at = self.user_dict[user]["last_updated_at"]
        current_time = time.time()

        # Convert duration from days to seconds
        duration_in_seconds = (
            self.user_dict[user]["duration"] * HOURS_IN_A_DAY * 60 * 60
        )

        # Check if duration has elapsed
        if current_time - last_updated_at >= duration_in_seconds:
            # Reset cost if duration has elapsed and update the creation time
            self.reset_cost(user)
            self.user_dict[user]["last_updated_at"] = current_time
            self._save_data_thread()  # Save the data

    def update_budget_all_users(self):
        """Apply period-based resets to every user that has a duration set."""
        for user in self.get_users():
            if "duration" in self.user_dict[user]:
                self.reset_on_duration(user)

    def _save_data_thread(self):
        # Persist in a background thread so callers are never blocked on I/O.
        thread = threading.Thread(
            target=self.save_data
        )  # [Non-Blocking]: saves data without blocking execution
        thread.start()

    def save_data(self):
        """Persist ``self.user_dict`` to disk or the hosted API."""
        if self.client_type == "local":
            import json

            # save the user dict
            with open("user_cost.json", "w") as json_file:
                json.dump(
                    self.user_dict, json_file, indent=4
                )  # Indent for pretty formatting
            return {"status": "success"}
        elif self.client_type == "hosted":
            url = self.api_base + "/set_budget"
            data = {"project_name": self.project_name, "user_dict": self.user_dict}
            response = litellm.module_level_client.post(
                url, headers=self.headers, json=data
            )
            response = response.json()
            return response
|
||||
@@ -0,0 +1,41 @@
|
||||
# Caching on LiteLLM
|
||||
|
||||
LiteLLM supports multiple caching mechanisms. This allows users to choose the most suitable caching solution for their use case.
|
||||
|
||||
The following caching mechanisms are supported:
|
||||
|
||||
1. **RedisCache**
|
||||
2. **RedisSemanticCache**
|
||||
3. **QdrantSemanticCache**
|
||||
4. **InMemoryCache**
|
||||
5. **DiskCache**
|
||||
6. **S3Cache**
|
||||
7. **AzureBlobCache**
|
||||
8. **DualCache** (updates both Redis and an in-memory cache simultaneously)
|
||||
|
||||
## Folder Structure
|
||||
|
||||
```
litellm/caching/
├── azure_blob_cache.py
├── base_cache.py
├── caching.py
├── caching_handler.py
├── disk_cache.py
├── dual_cache.py
├── gcs_cache.py
├── in_memory_cache.py
├── qdrant_semantic_cache.py
├── redis_cache.py
├── redis_semantic_cache.py
├── s3_cache.py
```
|
||||
|
||||
## Documentation
|
||||
- [Caching on LiteLLM Gateway](https://docs.litellm.ai/docs/proxy/caching)
|
||||
- [Caching on LiteLLM Python](https://docs.litellm.ai/docs/caching/all_caches)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
from .azure_blob_cache import AzureBlobCache
|
||||
from .caching import Cache, LiteLLMCacheType
|
||||
from .disk_cache import DiskCache
|
||||
from .dual_cache import DualCache
|
||||
from .in_memory_cache import InMemoryCache
|
||||
from .qdrant_semantic_cache import QdrantSemanticCache
|
||||
from .redis_cache import RedisCache
|
||||
from .redis_cluster_cache import RedisClusterCache
|
||||
from .redis_semantic_cache import RedisSemanticCache
|
||||
from .s3_cache import S3Cache
|
||||
from .gcs_cache import GCSCache
|
||||
@@ -0,0 +1,30 @@
|
||||
from functools import lru_cache
|
||||
from typing import Callable, Optional, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def lru_cache_wrapper(
    maxsize: Optional[int] = None,
) -> Callable[[Callable[..., T]], Callable[..., T]]:
    """
    Wrapper for lru_cache that caches success and exceptions

    Unlike plain ``functools.lru_cache``, a call that raises is also
    memoized: the same exception instance is re-raised on subsequent calls
    with the same arguments, so the wrapped function runs at most once per
    argument tuple.

    Args:
        maxsize: Maximum cache size (None means unbounded).

    Returns:
        A decorator that memoizes both return values and raised exceptions.
    """

    def decorator(f: Callable[..., T]) -> Callable[..., T]:
        # Local import keeps this module's top-level imports unchanged.
        from functools import wraps

        @lru_cache(maxsize=maxsize)
        def cached_call(*args, **kwargs):
            # Store the outcome as a tagged tuple so exceptions are cached too.
            try:
                return ("success", f(*args, **kwargs))
            except Exception as e:
                return ("error", e)

        @wraps(f)  # FIX: preserve the wrapped function's name/docstring/signature
        def wrapped(*args, **kwargs):
            status, payload = cached_call(*args, **kwargs)
            if status == "error":
                raise payload
            return payload

        return wrapped

    return decorator
|
||||
@@ -0,0 +1,107 @@
|
||||
"""
|
||||
Azure Blob Cache implementation
|
||||
|
||||
Has 4 methods:
|
||||
- set_cache
|
||||
- get_cache
|
||||
- async_set_cache
|
||||
- async_get_cache
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from contextlib import suppress
|
||||
|
||||
from litellm._logging import print_verbose, verbose_logger
|
||||
|
||||
from .base_cache import BaseCache
|
||||
|
||||
|
||||
class AzureBlobCache(BaseCache):
    """Azure Blob Storage cache backend.

    Stores each cache entry as a JSON blob named by the cache key. Auth uses
    DefaultAzureCredential; sync and async clients are kept separately since
    the Azure SDK splits them across packages.
    """

    def __init__(self, account_url, container) -> None:
        # Azure SDK imports are deferred so the dependency is only required
        # when this cache type is actually used.
        from azure.core.exceptions import ResourceExistsError
        from azure.identity import DefaultAzureCredential
        from azure.identity.aio import (
            DefaultAzureCredential as AsyncDefaultAzureCredential,
        )
        from azure.storage.blob import BlobServiceClient
        from azure.storage.blob.aio import BlobServiceClient as AsyncBlobServiceClient

        self.container_client = BlobServiceClient(
            account_url=account_url,
            credential=DefaultAzureCredential(),
        ).get_container_client(container)
        self.async_container_client = AsyncBlobServiceClient(
            account_url=account_url,
            credential=AsyncDefaultAzureCredential(),
        ).get_container_client(container)

        # Idempotent container creation: ignore "already exists".
        with suppress(ResourceExistsError):
            self.container_client.create_container()

    def set_cache(self, key, value, **kwargs) -> None:
        """Store ``value`` (JSON-serialized) under ``key``; errors are logged, not raised."""
        print_verbose(f"LiteLLM SET Cache - Azure Blob. Key={key}. Value={value}")
        serialized_value = json.dumps(value)
        try:
            # BUG FIX: pass overwrite=True to match async_set_cache. Without it,
            # re-setting an existing key raised ResourceExistsError, which the
            # broad except swallowed — so the stale value was silently kept.
            self.container_client.upload_blob(key, serialized_value, overwrite=True)
        except Exception as e:
            # NON blocking - notify users Azure Blob is throwing an exception
            print_verbose(f"LiteLLM set_cache() - Got exception from Azure Blob: {e}")

    async def async_set_cache(self, key, value, **kwargs) -> None:
        """Async variant of ``set_cache``; errors are logged, not raised."""
        print_verbose(f"LiteLLM SET Cache - Azure Blob. Key={key}. Value={value}")
        serialized_value = json.dumps(value)
        try:
            await self.async_container_client.upload_blob(
                key, serialized_value, overwrite=True
            )
        except Exception as e:
            # NON blocking - notify users Azure Blob is throwing an exception
            print_verbose(f"LiteLLM set_cache() - Got exception from Azure Blob: {e}")

    def get_cache(self, key, **kwargs):
        """Return the JSON-decoded value for ``key``, or None if the blob is absent."""
        from azure.core.exceptions import ResourceNotFoundError

        try:
            print_verbose(f"Get Azure Blob Cache: key: {key}")
            as_bytes = self.container_client.download_blob(key).readall()
            as_str = as_bytes.decode("utf-8")
            cached_response = json.loads(as_str)

            verbose_logger.debug(
                f"Got Azure Blob Cache: key: {key}, cached_response {cached_response}. Type Response {type(cached_response)}"
            )

            return cached_response
        except ResourceNotFoundError:
            # Cache miss: no blob stored under this key.
            return None

    async def async_get_cache(self, key, **kwargs):
        """Async variant of ``get_cache``; returns None on a cache miss."""
        from azure.core.exceptions import ResourceNotFoundError

        try:
            print_verbose(f"Get Azure Blob Cache: key: {key}")
            blob = await self.async_container_client.download_blob(key)
            as_bytes = await blob.readall()
            as_str = as_bytes.decode("utf-8")
            cached_response = json.loads(as_str)
            verbose_logger.debug(
                f"Got Azure Blob Cache: key: {key}, cached_response {cached_response}. Type Response {type(cached_response)}"
            )
            return cached_response
        except ResourceNotFoundError:
            return None

    def flush_cache(self) -> None:
        """Delete every blob in the container."""
        for blob in self.container_client.walk_blobs():
            self.container_client.delete_blob(blob.name)

    async def disconnect(self) -> None:
        """Close both the sync and async container clients."""
        self.container_client.close()
        await self.async_container_client.close()

    async def async_set_cache_pipeline(self, cache_list, **kwargs) -> None:
        """Write many (key, value) pairs concurrently."""
        tasks = []
        for val in cache_list:
            tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
        await asyncio.gather(*tasks)
|
||||
@@ -0,0 +1,64 @@
|
||||
"""
|
||||
Base Cache implementation. All cache implementations should inherit from this class.
|
||||
|
||||
Has 4 methods:
|
||||
- set_cache
|
||||
- get_cache
|
||||
- async_set_cache
|
||||
- async_get_cache
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
Span = Any
|
||||
|
||||
|
||||
class BaseCache(ABC):
    """Abstract base class for all cache backends.

    Subclasses override the sync/async get/set methods they support;
    unimplemented operations raise NotImplementedError.
    """

    def __init__(self, default_ttl: int = 60):
        # Fallback TTL (seconds) used when the caller supplies none.
        self.default_ttl = default_ttl

    def get_ttl(self, **kwargs) -> Optional[int]:
        """Return the TTL from ``kwargs["ttl"]`` (coerced to int), else the default."""
        requested_ttl: Optional[int] = kwargs.get("ttl")
        if requested_ttl is None:
            return self.default_ttl
        try:
            return int(requested_ttl)
        except ValueError:
            # Un-parseable TTL values fall back to the default.
            return self.default_ttl

    def set_cache(self, key, value, **kwargs):
        raise NotImplementedError

    async def async_set_cache(self, key, value, **kwargs):
        raise NotImplementedError

    @abstractmethod
    async def async_set_cache_pipeline(self, cache_list, **kwargs):
        pass

    def get_cache(self, key, **kwargs):
        raise NotImplementedError

    async def async_get_cache(self, key, **kwargs):
        raise NotImplementedError

    async def batch_cache_write(self, key, value, **kwargs):
        raise NotImplementedError

    async def disconnect(self):
        raise NotImplementedError

    async def test_connection(self) -> dict:
        """
        Test the cache connection.

        Returns:
            dict: {"status": "success" | "failed", "message": str, "error": Optional[str]}
        """
        raise NotImplementedError
|
||||
@@ -0,0 +1,926 @@
|
||||
# +-----------------------------------------------+
|
||||
# | |
|
||||
# | Give Feedback / Get Help |
|
||||
# | https://github.com/BerriAI/litellm/issues/new |
|
||||
# | |
|
||||
# +-----------------------------------------------+
|
||||
#
|
||||
# Thank you users! We ❤️ you! - Krrish & Ishaan
|
||||
|
||||
import ast
|
||||
import hashlib
|
||||
import json
|
||||
import time
|
||||
import traceback
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.constants import CACHED_STREAMING_CHUNK_DELAY
|
||||
from litellm.litellm_core_utils.model_param_helper import ModelParamHelper
|
||||
from litellm.types.caching import *
|
||||
from litellm.types.utils import EmbeddingResponse, all_litellm_params
|
||||
|
||||
from .azure_blob_cache import AzureBlobCache
|
||||
from .base_cache import BaseCache
|
||||
from .disk_cache import DiskCache
|
||||
from .dual_cache import DualCache # noqa
|
||||
from .gcs_cache import GCSCache
|
||||
from .in_memory_cache import InMemoryCache
|
||||
from .qdrant_semantic_cache import QdrantSemanticCache
|
||||
from .redis_cache import RedisCache
|
||||
from .redis_cluster_cache import RedisClusterCache
|
||||
from .redis_semantic_cache import RedisSemanticCache
|
||||
from .s3_cache import S3Cache
|
||||
|
||||
|
||||
def print_verbose(print_statement):
    """Log at debug level and, when ``litellm.set_verbose`` is on, echo to stdout.

    Best-effort: any failure while logging is swallowed so caching never
    breaks on a logging error.
    """
    try:
        verbose_logger.debug(print_statement)
        if not litellm.set_verbose:
            return
        print(print_statement)  # noqa
    except Exception:
        pass
|
||||
|
||||
|
||||
class CacheMode(str, Enum):
    """Whether caching applies to every request by default or is opt-in."""

    # Cache is always on (requests may still opt out).
    default_on = "default_on"
    # Cache is opt-in per request.
    default_off = "default_off"
|
||||
|
||||
|
||||
#### LiteLLM.Completion / Embedding Cache ####
|
||||
class Cache:
|
||||
def __init__(
|
||||
self,
|
||||
type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
|
||||
mode: Optional[
|
||||
CacheMode
|
||||
] = CacheMode.default_on, # when default_on cache is always on, when default_off cache is opt in
|
||||
host: Optional[str] = None,
|
||||
port: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
namespace: Optional[str] = None,
|
||||
ttl: Optional[float] = None,
|
||||
default_in_memory_ttl: Optional[float] = None,
|
||||
default_in_redis_ttl: Optional[float] = None,
|
||||
similarity_threshold: Optional[float] = None,
|
||||
supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
|
||||
"completion",
|
||||
"acompletion",
|
||||
"embedding",
|
||||
"aembedding",
|
||||
"atranscription",
|
||||
"transcription",
|
||||
"atext_completion",
|
||||
"text_completion",
|
||||
"arerank",
|
||||
"rerank",
|
||||
"responses",
|
||||
"aresponses",
|
||||
],
|
||||
# s3 Bucket, boto3 configuration
|
||||
azure_account_url: Optional[str] = None,
|
||||
azure_blob_container: Optional[str] = None,
|
||||
s3_bucket_name: Optional[str] = None,
|
||||
s3_region_name: Optional[str] = None,
|
||||
s3_api_version: Optional[str] = None,
|
||||
s3_use_ssl: Optional[bool] = True,
|
||||
s3_verify: Optional[Union[bool, str]] = None,
|
||||
s3_endpoint_url: Optional[str] = None,
|
||||
s3_aws_access_key_id: Optional[str] = None,
|
||||
s3_aws_secret_access_key: Optional[str] = None,
|
||||
s3_aws_session_token: Optional[str] = None,
|
||||
s3_config: Optional[Any] = None,
|
||||
s3_path: Optional[str] = None,
|
||||
gcs_bucket_name: Optional[str] = None,
|
||||
gcs_path_service_account: Optional[str] = None,
|
||||
gcs_path: Optional[str] = None,
|
||||
redis_semantic_cache_embedding_model: str = "text-embedding-ada-002",
|
||||
redis_semantic_cache_index_name: Optional[str] = None,
|
||||
redis_flush_size: Optional[int] = None,
|
||||
redis_startup_nodes: Optional[List] = None,
|
||||
disk_cache_dir: Optional[str] = None,
|
||||
qdrant_api_base: Optional[str] = None,
|
||||
qdrant_api_key: Optional[str] = None,
|
||||
qdrant_collection_name: Optional[str] = None,
|
||||
qdrant_quantization_config: Optional[str] = None,
|
||||
qdrant_semantic_cache_embedding_model: str = "text-embedding-ada-002",
|
||||
qdrant_semantic_cache_vector_size: Optional[int] = None,
|
||||
# GCP IAM authentication parameters
|
||||
gcp_service_account: Optional[str] = None,
|
||||
gcp_ssl_ca_certs: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initializes the cache based on the given type.
|
||||
|
||||
Args:
|
||||
type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", "qdrant-semantic", "s3" or "disk". Defaults to "local".
|
||||
|
||||
# Redis Cache Args
|
||||
host (str, optional): The host address for the Redis cache. Required if type is "redis".
|
||||
port (int, optional): The port number for the Redis cache. Required if type is "redis".
|
||||
password (str, optional): The password for the Redis cache. Required if type is "redis".
|
||||
namespace (str, optional): The namespace for the Redis cache. Required if type is "redis".
|
||||
ttl (float, optional): The ttl for the Redis cache
|
||||
redis_flush_size (int, optional): The number of keys to flush at a time. Defaults to 1000. Only used if batch redis set caching is used.
|
||||
redis_startup_nodes (list, optional): The list of startup nodes for the Redis cache. Defaults to None.
|
||||
|
||||
# Qdrant Cache Args
|
||||
qdrant_api_base (str, optional): The url for your qdrant cluster. Required if type is "qdrant-semantic".
|
||||
qdrant_api_key (str, optional): The api_key for the local or cloud qdrant cluster.
|
||||
qdrant_collection_name (str, optional): The name for your qdrant collection. Required if type is "qdrant-semantic".
|
||||
similarity_threshold (float, optional): The similarity threshold for semantic-caching, Required if type is "redis-semantic" or "qdrant-semantic".
|
||||
|
||||
# Disk Cache Args
|
||||
disk_cache_dir (str, optional): The directory for the disk cache. Defaults to None.
|
||||
|
||||
# S3 Cache Args
|
||||
s3_bucket_name (str, optional): The bucket name for the s3 cache. Defaults to None.
|
||||
s3_region_name (str, optional): The region name for the s3 cache. Defaults to None.
|
||||
s3_api_version (str, optional): The api version for the s3 cache. Defaults to None.
|
||||
s3_use_ssl (bool, optional): The use ssl for the s3 cache. Defaults to True.
|
||||
s3_verify (bool, optional): The verify for the s3 cache. Defaults to None.
|
||||
s3_endpoint_url (str, optional): The endpoint url for the s3 cache. Defaults to None.
|
||||
s3_aws_access_key_id (str, optional): The aws access key id for the s3 cache. Defaults to None.
|
||||
s3_aws_secret_access_key (str, optional): The aws secret access key for the s3 cache. Defaults to None.
|
||||
s3_aws_session_token (str, optional): The aws session token for the s3 cache. Defaults to None.
|
||||
s3_config (dict, optional): The config for the s3 cache. Defaults to None.
|
||||
|
||||
# GCS Cache Args
|
||||
gcs_bucket_name (str, optional): The bucket name for the gcs cache. Defaults to None.
|
||||
gcs_path_service_account (str, optional): Path to the service account json.
|
||||
gcs_path (str, optional): Folder path inside the bucket to store cache files.
|
||||
|
||||
# Common Cache Args
|
||||
supported_call_types (list, optional): List of call types to cache for. Defaults to cache == on for all call types.
|
||||
**kwargs: Additional keyword arguments for redis.Redis() cache
|
||||
|
||||
Raises:
|
||||
ValueError: If an invalid cache type is provided.
|
||||
|
||||
Returns:
|
||||
None. Cache is set as a litellm param
|
||||
"""
|
||||
if type == LiteLLMCacheType.REDIS:
|
||||
# Check REDIS_CLUSTER_NODES env var if no explicit startup nodes
|
||||
if not redis_startup_nodes:
|
||||
_env_cluster_nodes = litellm.get_secret("REDIS_CLUSTER_NODES")
|
||||
if _env_cluster_nodes is not None and isinstance(
|
||||
_env_cluster_nodes, str
|
||||
):
|
||||
redis_startup_nodes = json.loads(_env_cluster_nodes)
|
||||
|
||||
if redis_startup_nodes:
|
||||
# Only pass GCP parameters if they are provided
|
||||
cluster_kwargs = {
|
||||
"host": host,
|
||||
"port": port,
|
||||
"password": password,
|
||||
"redis_flush_size": redis_flush_size,
|
||||
"startup_nodes": redis_startup_nodes,
|
||||
**kwargs,
|
||||
}
|
||||
if gcp_service_account is not None:
|
||||
cluster_kwargs["gcp_service_account"] = gcp_service_account
|
||||
if gcp_ssl_ca_certs is not None:
|
||||
cluster_kwargs["gcp_ssl_ca_certs"] = gcp_ssl_ca_certs
|
||||
|
||||
self.cache: BaseCache = RedisClusterCache(**cluster_kwargs)
|
||||
else:
|
||||
self.cache = RedisCache(
|
||||
host=host,
|
||||
port=port,
|
||||
password=password,
|
||||
redis_flush_size=redis_flush_size,
|
||||
**kwargs,
|
||||
)
|
||||
elif type == LiteLLMCacheType.REDIS_SEMANTIC:
|
||||
self.cache = RedisSemanticCache(
|
||||
host=host,
|
||||
port=port,
|
||||
password=password,
|
||||
similarity_threshold=similarity_threshold,
|
||||
embedding_model=redis_semantic_cache_embedding_model,
|
||||
index_name=redis_semantic_cache_index_name,
|
||||
**kwargs,
|
||||
)
|
||||
elif type == LiteLLMCacheType.QDRANT_SEMANTIC:
|
||||
self.cache = QdrantSemanticCache(
|
||||
qdrant_api_base=qdrant_api_base,
|
||||
qdrant_api_key=qdrant_api_key,
|
||||
collection_name=qdrant_collection_name,
|
||||
similarity_threshold=similarity_threshold,
|
||||
quantization_config=qdrant_quantization_config,
|
||||
embedding_model=qdrant_semantic_cache_embedding_model,
|
||||
vector_size=qdrant_semantic_cache_vector_size,
|
||||
)
|
||||
elif type == LiteLLMCacheType.LOCAL:
|
||||
self.cache = InMemoryCache()
|
||||
elif type == LiteLLMCacheType.S3:
|
||||
self.cache = S3Cache(
|
||||
s3_bucket_name=s3_bucket_name,
|
||||
s3_region_name=s3_region_name,
|
||||
s3_api_version=s3_api_version,
|
||||
s3_use_ssl=s3_use_ssl,
|
||||
s3_verify=s3_verify,
|
||||
s3_endpoint_url=s3_endpoint_url,
|
||||
s3_aws_access_key_id=s3_aws_access_key_id,
|
||||
s3_aws_secret_access_key=s3_aws_secret_access_key,
|
||||
s3_aws_session_token=s3_aws_session_token,
|
||||
s3_config=s3_config,
|
||||
s3_path=s3_path,
|
||||
**kwargs,
|
||||
)
|
||||
elif type == LiteLLMCacheType.GCS:
|
||||
self.cache = GCSCache(
|
||||
bucket_name=gcs_bucket_name,
|
||||
path_service_account=gcs_path_service_account,
|
||||
gcs_path=gcs_path,
|
||||
)
|
||||
elif type == LiteLLMCacheType.AZURE_BLOB:
|
||||
self.cache = AzureBlobCache(
|
||||
account_url=azure_account_url,
|
||||
container=azure_blob_container,
|
||||
)
|
||||
elif type == LiteLLMCacheType.DISK:
|
||||
self.cache = DiskCache(disk_cache_dir=disk_cache_dir)
|
||||
if "cache" not in litellm.input_callback:
|
||||
litellm.input_callback.append("cache")
|
||||
if "cache" not in litellm.success_callback:
|
||||
litellm.logging_callback_manager.add_litellm_success_callback("cache")
|
||||
if "cache" not in litellm._async_success_callback:
|
||||
litellm.logging_callback_manager.add_litellm_async_success_callback("cache")
|
||||
self.supported_call_types = supported_call_types # default to ["completion", "acompletion", "embedding", "aembedding"]
|
||||
self.type = type
|
||||
self.namespace = namespace
|
||||
self.redis_flush_size = redis_flush_size
|
||||
self.ttl = ttl
|
||||
self.mode: CacheMode = mode or CacheMode.default_on
|
||||
|
||||
if self.type == LiteLLMCacheType.LOCAL and default_in_memory_ttl is not None:
|
||||
self.ttl = default_in_memory_ttl
|
||||
|
||||
if (
|
||||
self.type == LiteLLMCacheType.REDIS
|
||||
or self.type == LiteLLMCacheType.REDIS_SEMANTIC
|
||||
) and default_in_redis_ttl is not None:
|
||||
self.ttl = default_in_redis_ttl
|
||||
|
||||
if self.namespace is not None and isinstance(self.cache, RedisCache):
|
||||
self.cache.namespace = self.namespace
|
||||
|
||||
def get_cache_key(self, **kwargs) -> str:
    """
    Build the cache key for the given call arguments.

    Args:
        **kwargs: kwargs to litellm.completion() or embedding()

    Returns:
        str: The hashed (and optionally namespaced) cache key.
    """
    # Reuse a previously computed key stashed in litellm_params, if any.
    preset_cache_key = self._get_preset_cache_key_from_kwargs(**kwargs)
    if preset_cache_key is not None:
        verbose_logger.debug("\nReturning preset cache key: %s", preset_cache_key)
        return preset_cache_key

    llm_api_params = ModelParamHelper._get_all_llm_api_params()
    litellm_only_params = all_litellm_params
    key_parts = []
    for param in kwargs:
        if param in llm_api_params:
            value: Optional[str] = self._get_param_value(param, kwargs)
            if value is not None:
                key_parts.append(f"{str(param)}: {str(value)}")
        elif param not in litellm_only_params:
            # user passed an optional provider-specific param - e.g. top_k
            if (
                litellm.enable_caching_on_provider_specific_optional_params is True
            ):  # feature flagged for now
                if kwargs[param] is None:
                    continue  # ignore None params
                key_parts.append(f"{str(param)}: {str(kwargs[param])}")

    cache_key = "".join(key_parts)
    verbose_logger.debug("\nCreated cache key: %s", cache_key)
    hashed_cache_key = Cache._get_hashed_cache_key(cache_key)
    hashed_cache_key = self._add_namespace_to_cache_key(hashed_cache_key, **kwargs)
    # Store the computed key so repeated lookups skip this work.
    self._set_preset_cache_key_in_kwargs(
        preset_cache_key=hashed_cache_key, **kwargs
    )
    return hashed_cache_key
|
||||
|
||||
def _get_param_value(
|
||||
self,
|
||||
param: str,
|
||||
kwargs: dict,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Get the value for the given param from kwargs
|
||||
"""
|
||||
if param == "model":
|
||||
return self._get_model_param_value(kwargs)
|
||||
elif param == "file":
|
||||
return self._get_file_param_value(kwargs)
|
||||
return kwargs[param]
|
||||
|
||||
def _get_model_param_value(self, kwargs: dict) -> str:
    """
    Resolve the 'model' value used in the cache key.

    Precedence:
    1. caching group, if configured https://docs.litellm.ai/docs/routing#caching-across-model-groups
    2. model_group - set on requests sent through the litellm.Router()
    3. the raw `model` kwarg
    """
    metadata: Dict = kwargs.get("metadata", {}) or {}
    litellm_params: Dict = kwargs.get("litellm_params", {}) or {}
    router_metadata: Dict = litellm_params.get("metadata", {}) or {}
    model_group: Optional[str] = (
        metadata.get("model_group") or router_metadata.get("model_group")
    )
    return (
        self._get_caching_group(metadata, model_group)
        or model_group
        or kwargs["model"]
    )
|
||||
|
||||
def _get_caching_group(
|
||||
self, metadata: dict, model_group: Optional[str]
|
||||
) -> Optional[str]:
|
||||
caching_groups: Optional[List] = metadata.get("caching_groups", [])
|
||||
if caching_groups:
|
||||
for group in caching_groups:
|
||||
if model_group in group:
|
||||
return str(group)
|
||||
return None
|
||||
|
||||
def _get_file_param_value(self, kwargs: dict) -> str:
|
||||
"""
|
||||
Handles getting the value for the 'file' param from kwargs. Used for `transcription` requests
|
||||
"""
|
||||
file = kwargs.get("file")
|
||||
metadata = kwargs.get("metadata", {})
|
||||
litellm_params = kwargs.get("litellm_params", {})
|
||||
return (
|
||||
metadata.get("file_checksum")
|
||||
or getattr(file, "name", None)
|
||||
or metadata.get("file_name")
|
||||
or litellm_params.get("file_name")
|
||||
)
|
||||
|
||||
def _get_preset_cache_key_from_kwargs(self, **kwargs) -> Optional[str]:
|
||||
"""
|
||||
Get the preset cache key from kwargs["litellm_params"]
|
||||
|
||||
We use _get_preset_cache_keys for two reasons
|
||||
|
||||
1. optional params like max_tokens, get transformed for bedrock -> max_new_tokens
|
||||
2. avoid doing duplicate / repeated work
|
||||
"""
|
||||
if kwargs:
|
||||
if "litellm_params" in kwargs:
|
||||
return kwargs["litellm_params"].get("preset_cache_key", None)
|
||||
return None
|
||||
|
||||
def _set_preset_cache_key_in_kwargs(self, preset_cache_key: str, **kwargs) -> None:
|
||||
"""
|
||||
Set the calculated cache key in kwargs
|
||||
|
||||
This is used to avoid doing duplicate / repeated work
|
||||
|
||||
Placed in kwargs["litellm_params"]
|
||||
"""
|
||||
if kwargs:
|
||||
if "litellm_params" in kwargs:
|
||||
kwargs["litellm_params"]["preset_cache_key"] = preset_cache_key
|
||||
|
||||
@staticmethod
def _get_hashed_cache_key(cache_key: str) -> str:
    """
    Hash the raw cache key with SHA-256 and return the hex digest.

    Keeps stored keys a fixed, backend-safe length regardless of how
    large the prompt/params string grows.

    Args:
        cache_key (str): The cache key to hash.

    Returns:
        str: The hashed cache key.
    """
    hash_hex = hashlib.sha256(cache_key.encode()).hexdigest()
    verbose_logger.debug("Hashed cache key (SHA-256): %s", hash_hex)
    return hash_hex
|
||||
|
||||
def _add_namespace_to_cache_key(self, hash_hex: str, **kwargs) -> str:
    """
    Prefix the hashed key with a namespace when one is configured.

    Precedence: per-request cache control ("cache" kwarg) >
    request metadata ("redis_namespace") > the cache-level default.

    Args:
        hash_hex (str): The hashed cache key.
        **kwargs: Additional keyword arguments.

    Returns:
        str: The final cache key, namespaced when applicable.
    """
    dynamic_cache_control: DynamicCacheControl = kwargs.get("cache", {})
    candidates = (
        dynamic_cache_control.get("namespace"),
        kwargs.get("metadata", {}).get("redis_namespace"),
        self.namespace,
    )
    namespace = next((ns for ns in candidates if ns), None)
    if namespace:
        hash_hex = f"{namespace}:{hash_hex}"
    verbose_logger.debug("Final hashed key: %s", hash_hex)
    return hash_hex
|
||||
|
||||
def generate_streaming_content(self, content):
    """
    Yield a cached completion back as assistant-delta chunks, pausing
    between chunks to mimic live streaming to the caller.
    """
    chunk_size = 5  # Adjust the chunk size as needed
    start = 0
    while start < len(content):
        yield {
            "choices": [
                {
                    "delta": {
                        "role": "assistant",
                        "content": content[start : start + chunk_size],
                    }
                }
            ]
        }
        time.sleep(CACHED_STREAMING_CHUNK_DELAY)
        start += chunk_size
|
||||
|
||||
def _get_cache_logic(
|
||||
self,
|
||||
cached_result: Optional[Any],
|
||||
max_age: Optional[float],
|
||||
):
|
||||
"""
|
||||
Common get cache logic across sync + async implementations
|
||||
"""
|
||||
# Check if a timestamp was stored with the cached response
|
||||
if (
|
||||
cached_result is not None
|
||||
and isinstance(cached_result, dict)
|
||||
and "timestamp" in cached_result
|
||||
):
|
||||
timestamp = cached_result["timestamp"]
|
||||
current_time = time.time()
|
||||
|
||||
# Calculate age of the cached response
|
||||
response_age = current_time - timestamp
|
||||
|
||||
# Check if the cached response is older than the max-age
|
||||
if max_age is not None and response_age > max_age:
|
||||
return None # Cached response is too old
|
||||
|
||||
# If the response is fresh, or there's no max-age requirement, return the cached response
|
||||
# cached_response is in `b{} convert it to ModelResponse
|
||||
cached_response = cached_result.get("response")
|
||||
try:
|
||||
if isinstance(cached_response, dict):
|
||||
pass
|
||||
else:
|
||||
cached_response = json.loads(
|
||||
cached_response # type: ignore
|
||||
) # Convert string to dictionary
|
||||
except Exception:
|
||||
cached_response = ast.literal_eval(cached_response) # type: ignore
|
||||
return cached_response
|
||||
return cached_result
|
||||
|
||||
def get_cache(self, dynamic_cache_object: Optional[BaseCache] = None, **kwargs):
    """
    Retrieves the cached result for the given arguments.

    Args:
        dynamic_cache_object: per-request cache backend; falls back to self.cache.
        **kwargs: kwargs to litellm.completion() or embedding()

    Returns:
        The cached result if it exists, otherwise None.
    """
    try:  # never block execution
        if self.should_use_cache(**kwargs) is not True:
            return
        messages = kwargs.get("messages", [])
        if "cache_key" in kwargs:
            cache_key = kwargs["cache_key"]
        else:
            cache_key = self.get_cache_key(**kwargs)
        if cache_key is not None:
            cache_control_args: DynamicCacheControl = kwargs.get("cache", {})
            # Fix: explicit None checks instead of an `or` chain, so a
            # max-age of 0 ("never serve from cache") is honored rather
            # than silently falling through to infinity.
            max_age = cache_control_args.get("s-maxage")
            if max_age is None:
                max_age = cache_control_args.get("s-max-age")
            if max_age is None:
                max_age = float("inf")
            if dynamic_cache_object is not None:
                cached_result = dynamic_cache_object.get_cache(
                    cache_key, messages=messages
                )
            else:
                cached_result = self.cache.get_cache(cache_key, messages=messages)
            return self._get_cache_logic(
                cached_result=cached_result, max_age=max_age
            )
    except Exception:
        print_verbose(f"An exception occurred: {traceback.format_exc()}")
        return None
|
||||
|
||||
async def async_get_cache(
    self, dynamic_cache_object: Optional[BaseCache] = None, **kwargs
):
    """
    Async get cache implementation.

    Used for embedding calls in async wrapper.
    """
    try:  # never block execution
        if self.should_use_cache(**kwargs) is not True:
            return

        # (removed a dead `kwargs.get("messages", [])` whose result was discarded)
        if "cache_key" in kwargs:
            cache_key = kwargs["cache_key"]
        else:
            cache_key = self.get_cache_key(**kwargs)
        if cache_key is not None:
            cache_control_args = kwargs.get("cache", {})
            # Consistent with sync get_cache: prefer "s-maxage", and use
            # explicit None checks so a max-age of 0 is honored.
            max_age = cache_control_args.get("s-maxage")
            if max_age is None:
                max_age = cache_control_args.get("s-max-age")
            if max_age is None:
                max_age = float("inf")
            if dynamic_cache_object is not None:
                cached_result = await dynamic_cache_object.async_get_cache(
                    cache_key, **kwargs
                )
            else:
                cached_result = await self.cache.async_get_cache(
                    cache_key, **kwargs
                )
            return self._get_cache_logic(
                cached_result=cached_result, max_age=max_age
            )
    except Exception:
        print_verbose(f"An exception occurred: {traceback.format_exc()}")
        return None
|
||||
|
||||
def _add_cache_logic(self, result, **kwargs):
    """
    Common implementation across sync + async add_cache functions.

    Returns:
        (cache_key, cached_data, kwargs): cached_data wraps the result
        with a write timestamp used for max-age checks on read.

    Raises:
        Exception: if no cache key could be computed.
    """
    # (removed a pointless `except Exception as e: raise e` wrapper that
    # only re-raised and clobbered the traceback origin)
    if "cache_key" in kwargs:
        cache_key = kwargs["cache_key"]
    else:
        cache_key = self.get_cache_key(**kwargs)
    if cache_key is None:
        raise Exception("cache key is None")

    # Serialize pydantic responses so every backend stores plain data.
    if isinstance(result, BaseModel):
        result = result.model_dump_json()

    ## DEFAULT TTL ##
    if self.ttl is not None:
        kwargs["ttl"] = self.ttl
    ## Get Cache-Controls ##
    _cache_kwargs = kwargs.get("cache", None)
    if isinstance(_cache_kwargs, dict):
        for k, v in _cache_kwargs.items():
            if k == "ttl":
                kwargs["ttl"] = v

    cached_data = {"timestamp": time.time(), "response": result}
    return cache_key, cached_data, kwargs
|
||||
|
||||
def add_cache(self, result, **kwargs):
    """
    Adds a result to the cache.

    Args:
        result: the response to cache
        **kwargs: kwargs to litellm.completion() or embedding()

    Returns:
        None
    """
    try:
        if self.should_use_cache(**kwargs) is not True:
            return
        cache_key, cached_data, kwargs = self._add_cache_logic(
            result=result, **kwargs
        )
        self.cache.set_cache(cache_key, cached_data, **kwargs)
    except Exception as e:
        # typo fix in log message: "Excepton" -> "Exception"
        verbose_logger.exception(f"LiteLLM Cache: Exception add_cache: {str(e)}")
|
||||
|
||||
async def async_add_cache(
    self, result, dynamic_cache_object: Optional[BaseCache] = None, **kwargs
):
    """
    Async implementation of add_cache.

    When redis batching is configured (redis_flush_size), writes are
    buffered in memory and flushed in batches instead of written one by one.
    """
    try:
        if self.should_use_cache(**kwargs) is not True:
            return
        if self.type == "redis" and self.redis_flush_size is not None:
            # high traffic - fill in results in memory and then flush
            await self.batch_cache_write(result, **kwargs)
        else:
            cache_key, cached_data, kwargs = self._add_cache_logic(
                result=result, **kwargs
            )
            if dynamic_cache_object is not None:
                await dynamic_cache_object.async_set_cache(
                    cache_key, cached_data, **kwargs
                )
            else:
                await self.cache.async_set_cache(cache_key, cached_data, **kwargs)
    except Exception as e:
        # typo fix in log message: "Excepton" -> "Exception"
        verbose_logger.exception(f"LiteLLM Cache: Exception add_cache: {str(e)}")
|
||||
|
||||
def _convert_to_cached_embedding(
    self, embedding_response: Any, model: Optional[str]
) -> CachedEmbedding:
    """
    Convert any embedding response into the standardized CachedEmbedding TypedDict format.

    Accepts a plain dict, a pydantic-style object (exposing ``model_dump``),
    or any object readable via ``vars()``.

    Raises:
        ValueError: if an expected key is missing from the embedding response.
    """
    try:
        # Normalize the three accepted input shapes to one dict, then build
        # the CachedEmbedding once (previously duplicated in each branch).
        if isinstance(embedding_response, dict):
            data = embedding_response
        elif hasattr(embedding_response, "model_dump"):
            data = embedding_response.model_dump()
        else:
            data = vars(embedding_response)
        return {
            "embedding": data.get("embedding"),
            "index": data.get("index"),
            "object": data.get("object"),
            "model": model,
        }
    except KeyError as e:
        raise ValueError(f"Missing expected key in embedding response: {e}")
|
||||
|
||||
def add_embedding_response_to_cache(
    self,
    result: EmbeddingResponse,
    input: str,
    kwargs: dict,
    idx_in_result_data: int = 0,
) -> Tuple[str, dict, dict]:
    """
    Prepare one embedding (result.data[idx_in_result_data]) for caching
    under the key derived from this single input string.

    Returns the (cache_key, cached_data, kwargs) triple from _add_cache_logic.
    """
    kwargs["cache_key"] = self.get_cache_key(**{**kwargs, "input": input})
    raw_embedding = result.data[idx_in_result_data]

    # Always convert to properly typed CachedEmbedding
    embedding_dict: CachedEmbedding = self._convert_to_cached_embedding(
        raw_embedding, result.model
    )

    return self._add_cache_logic(result=embedding_dict, **kwargs)
|
||||
|
||||
async def async_add_cache_pipeline(
    self, result, dynamic_cache_object: Optional[BaseCache] = None, **kwargs
):
    """
    Async implementation of add_cache for Embedding calls.

    Does a bulk write, to prevent using too many clients.
    """
    try:
        if self.should_use_cache(**kwargs) is not True:
            return

        # set default ttl if not set
        if self.ttl is not None:
            kwargs["ttl"] = self.ttl

        cache_list = []
        if isinstance(kwargs["input"], list):
            for idx, single_input in enumerate(kwargs["input"]):
                (
                    cache_key,
                    cached_data,
                    kwargs,
                ) = self.add_embedding_response_to_cache(
                    result, single_input, kwargs, idx
                )
                cache_list.append((cache_key, cached_data))
        elif isinstance(kwargs["input"], str):
            cache_key, cached_data, kwargs = self.add_embedding_response_to_cache(
                result, kwargs["input"], kwargs
            )
            cache_list.append((cache_key, cached_data))

        if dynamic_cache_object is not None:
            await dynamic_cache_object.async_set_cache_pipeline(
                cache_list=cache_list, **kwargs
            )
        else:
            await self.cache.async_set_cache_pipeline(
                cache_list=cache_list, **kwargs
            )
    except Exception as e:
        # typo fix in log message: "Excepton" -> "Exception"
        verbose_logger.exception(f"LiteLLM Cache: Exception add_cache: {str(e)}")
|
||||
|
||||
def should_use_cache(self, **kwargs):
    """
    Returns true if we should use the cache for LLM API calls.

    default_on  -> always True.
    default_off -> opt-in only: the caller must pass cache={"use-cache": True}.
    """
    if self.mode == CacheMode.default_on:
        return True

    # when mode == default_off -> Cache is opt in only
    _cache = kwargs.get("cache", None)
    verbose_logger.debug("should_use_cache: kwargs: %s; _cache: %s", kwargs, _cache)
    return bool(
        _cache
        and isinstance(_cache, dict)
        and _cache.get("use-cache", False) is True
    )
|
||||
|
||||
async def batch_cache_write(self, result, **kwargs):
    # Queue a cache write for batched flushing; used when redis_flush_size
    # is configured (see async_add_cache). Delegates key/payload building to
    # _add_cache_logic, then hands the entry to the backend's batch writer.
    cache_key, cached_data, kwargs = self._add_cache_logic(result=result, **kwargs)
    await self.cache.batch_cache_write(cache_key, cached_data, **kwargs)
|
||||
|
||||
async def ping(self):
    """
    Ping the underlying cache backend, if it supports pinging.

    Returns the backend's ping result, or None when the backend has no
    `ping` method. Fix: the original `getattr(self.cache, "ping")` had no
    default, so a backend without `ping` raised AttributeError instead of
    reaching the graceful `if cache_ping` guard.
    """
    cache_ping = getattr(self.cache, "ping", None)
    if cache_ping:
        return await cache_ping()
    return None
|
||||
|
||||
async def delete_cache_keys(self, keys):
    """
    Delete the given keys from the backend, if it supports bulk deletion.

    Returns the backend's result, or None when the backend has no
    `delete_cache_keys` method. Fix: the original getattr had no default,
    so a backend without the method raised AttributeError instead of
    reaching the graceful guard.
    """
    cache_delete_cache_keys = getattr(self.cache, "delete_cache_keys", None)
    if cache_delete_cache_keys:
        return await cache_delete_cache_keys(keys)
    return None
|
||||
|
||||
async def disconnect(self):
    # Close the underlying cache connection when the backend supports it;
    # silently a no-op for backends without a disconnect() method.
    if hasattr(self.cache, "disconnect"):
        await self.cache.disconnect()
|
||||
|
||||
def _supports_async(self) -> bool:
    """
    Internal method to check if the cache type supports async get/set operations.

    All cache types now support async operations, so this is unconditionally
    True; kept as a method for backwards compatibility with older callers.
    """
    return True
|
||||
|
||||
|
||||
def enable_cache(
    type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
    host: Optional[str] = None,
    port: Optional[str] = None,
    password: Optional[str] = None,
    supported_call_types: Optional[List[CachingSupportedCallTypes]] = None,
    **kwargs,
):
    """
    Enable cache with the specified configuration.

    Args:
        type (Optional[Literal["local", "redis", "s3", "disk"]]): The type of cache to enable. Defaults to "local".
        host (Optional[str]): The host address of the cache server. Defaults to None.
        port (Optional[str]): The port number of the cache server. Defaults to None.
        password (Optional[str]): The password for the cache server. Defaults to None.
        supported_call_types: The call types to cache for. Defaults to all call
            types (completion/embedding/transcription/text_completion/rerank/
            responses and their async variants).
        **kwargs: Additional keyword arguments.

    Returns:
        None

    Raises:
        None
    """
    # Fix: build a fresh default list per call instead of a mutable default
    # argument - previously every Cache instance aliased the same
    # module-level list, so mutating one leaked into all others.
    if supported_call_types is None:
        supported_call_types = [
            "completion",
            "acompletion",
            "embedding",
            "aembedding",
            "atranscription",
            "transcription",
            "atext_completion",
            "text_completion",
            "arerank",
            "rerank",
            "responses",
            "aresponses",
        ]
    print_verbose("LiteLLM: Enabling Cache")
    if "cache" not in litellm.input_callback:
        litellm.input_callback.append("cache")
    if "cache" not in litellm.success_callback:
        litellm.logging_callback_manager.add_litellm_success_callback("cache")
    if "cache" not in litellm._async_success_callback:
        litellm.logging_callback_manager.add_litellm_async_success_callback("cache")

    if litellm.cache is None:
        litellm.cache = Cache(
            type=type,
            host=host,
            port=port,
            password=password,
            supported_call_types=supported_call_types,
            **kwargs,
        )
    print_verbose(f"LiteLLM: Cache enabled, litellm.cache={litellm.cache}")
    print_verbose(f"LiteLLM Cache: {vars(litellm.cache)}")
|
||||
|
||||
|
||||
def update_cache(
    type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
    host: Optional[str] = None,
    port: Optional[str] = None,
    password: Optional[str] = None,
    supported_call_types: Optional[List[CachingSupportedCallTypes]] = None,
    **kwargs,
):
    """
    Update the cache for LiteLLM.

    Args:
        type (Optional[Literal["local", "redis", "s3", "disk"]]): The type of cache. Defaults to "local".
        host (Optional[str]): The host of the cache. Defaults to None.
        port (Optional[str]): The port of the cache. Defaults to None.
        password (Optional[str]): The password for the cache. Defaults to None.
        supported_call_types: The call types to cache for. Defaults to all call
            types (completion/embedding/transcription/text_completion/rerank/
            responses and their async variants).
        **kwargs: Additional keyword arguments for the cache.

    Returns:
        None
    """
    # Fix: fresh default list per call instead of a shared mutable default
    # argument (same fix as enable_cache).
    if supported_call_types is None:
        supported_call_types = [
            "completion",
            "acompletion",
            "embedding",
            "aembedding",
            "atranscription",
            "transcription",
            "atext_completion",
            "text_completion",
            "arerank",
            "rerank",
            "responses",
            "aresponses",
        ]
    print_verbose("LiteLLM: Updating Cache")
    litellm.cache = Cache(
        type=type,
        host=host,
        port=port,
        password=password,
        supported_call_types=supported_call_types,
        **kwargs,
    )
    print_verbose(f"LiteLLM: Cache Updated, litellm.cache={litellm.cache}")
    print_verbose(f"LiteLLM Cache: {vars(litellm.cache)}")
|
||||
|
||||
|
||||
def disable_cache():
    """
    Disable the cache used by LiteLLM.

    Removes the "cache" callback from input_callback, success_callback, and
    _async_success_callback, then sets litellm.cache to None.

    Parameters:
        None

    Returns:
        None
    """
    from contextlib import suppress

    print_verbose("LiteLLM: Disabling Cache")
    # Fix: one suppress() per removal. A single shared block meant that a
    # ValueError from the first remove() skipped the remaining removals,
    # leaving stale "cache" callbacks registered.
    with suppress(ValueError):
        litellm.input_callback.remove("cache")
    with suppress(ValueError):
        litellm.success_callback.remove("cache")
    with suppress(ValueError):
        litellm._async_success_callback.remove("cache")

    litellm.cache = None
    print_verbose(f"LiteLLM: Cache disabled, litellm.cache={litellm.cache}")
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,93 @@
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from .base_cache import BaseCache
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
Span = Any
|
||||
|
||||
|
||||
class DiskCache(BaseCache):
    """On-disk cache backend backed by the `diskcache` package.

    The async methods delegate to their sync counterparts, since diskcache
    itself is synchronous.
    """

    def __init__(self, disk_cache_dir: Optional[str] = None):
        try:
            import diskcache as dc
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError(
                "Please install litellm with `litellm[caching]` to use disk caching."
            ) from e

        # if users don't provider one, use the default litellm cache
        cache_dir = ".litellm_cache" if disk_cache_dir is None else disk_cache_dir
        self.disk_cache = dc.Cache(cache_dir)

    def set_cache(self, key, value, **kwargs):
        """Store a value; honors an optional `ttl` kwarg (seconds)."""
        if "ttl" in kwargs:
            self.disk_cache.set(key, value, expire=kwargs["ttl"])
        else:
            self.disk_cache.set(key, value)

    async def async_set_cache(self, key, value, **kwargs):
        self.set_cache(key=key, value=value, **kwargs)

    async def async_set_cache_pipeline(self, cache_list, **kwargs):
        """Write a batch of (key, value) pairs, applying a shared ttl if given."""
        has_ttl = "ttl" in kwargs
        for cache_key, cache_value in cache_list:
            if has_ttl:
                self.set_cache(key=cache_key, value=cache_value, ttl=kwargs["ttl"])
            else:
                self.set_cache(key=cache_key, value=cache_value)

    def get_cache(self, key, **kwargs):
        """Return the stored value (JSON-decoded when possible), else None."""
        stored = self.disk_cache.get(key)
        if not stored:
            return None
        try:
            return json.loads(stored)  # type: ignore
        except Exception:
            return stored

    def batch_get_cache(self, keys: list, **kwargs):
        return [self.get_cache(key=k, **kwargs) for k in keys]

    def increment_cache(self, key, value: int, **kwargs) -> int:
        """Add `value` to the stored counter (missing keys count as 0)."""
        current = self.get_cache(key=key) or 0
        value = current + value  # type: ignore
        self.set_cache(key, value, **kwargs)
        return value

    async def async_get_cache(self, key, **kwargs):
        return self.get_cache(key=key, **kwargs)

    async def async_batch_get_cache(self, keys: list, **kwargs):
        return [self.get_cache(key=k, **kwargs) for k in keys]

    async def async_increment(self, key, value: int, **kwargs) -> int:
        """Async counterpart of increment_cache."""
        current = await self.async_get_cache(key=key) or 0
        value = current + value  # type: ignore
        await self.async_set_cache(key, value, **kwargs)
        return value

    def flush_cache(self):
        """Remove every entry from the disk cache."""
        self.disk_cache.clear()

    async def disconnect(self):
        # diskcache holds no network connection; nothing to close.
        pass

    def delete_cache(self, key):
        self.disk_cache.pop(key)
|
||||
@@ -0,0 +1,506 @@
|
||||
"""
|
||||
Dual Cache implementation - Class to update both Redis and an in-memory cache simultaneously.
|
||||
|
||||
Has 4 primary methods:
|
||||
- set_cache
|
||||
- get_cache
|
||||
- async_set_cache
|
||||
- async_get_cache
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import traceback
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.types.caching import RedisPipelineIncrementOperation
|
||||
|
||||
import litellm
|
||||
from litellm._logging import print_verbose, verbose_logger
|
||||
from litellm.constants import DEFAULT_MAX_REDIS_BATCH_CACHE_SIZE
|
||||
|
||||
from .base_cache import BaseCache
|
||||
from .in_memory_cache import InMemoryCache
|
||||
from .redis_cache import RedisCache
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
Span = Any
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
class LimitedSizeOrderedDict(OrderedDict):
    """An :class:`OrderedDict` bounded to ``max_size`` entries.

    When inserting a *new* key would grow the dict past ``max_size``, the
    oldest (first-inserted) entry is evicted first.
    """

    def __init__(self, *args, max_size=100, **kwargs):
        super().__init__(*args, **kwargs)
        # Maximum number of entries retained before eviction kicks in.
        self.max_size = max_size

    def __setitem__(self, key, value):
        # BUG FIX: only evict when adding a NEW key. Updating an existing key
        # does not grow the dict, so evicting would needlessly drop the oldest
        # (still-valid) entry.
        if key not in self and len(self) >= self.max_size:
            self.popitem(last=False)
        super().__setitem__(key, value)
|
||||
|
||||
|
||||
class DualCache(BaseCache):
    """
    DualCache is a cache implementation that updates both Redis and an in-memory cache simultaneously.
    When data is updated or inserted, it is written to both the in-memory cache + Redis.
    This ensures that even if Redis hasn't been updated yet, the in-memory cache reflects the most recent data.
    """

    def __init__(
        self,
        in_memory_cache: Optional[InMemoryCache] = None,
        redis_cache: Optional[RedisCache] = None,
        default_in_memory_ttl: Optional[float] = None,
        default_redis_ttl: Optional[float] = None,
        default_redis_batch_cache_expiry: Optional[float] = None,
        default_max_redis_batch_cache_size: int = DEFAULT_MAX_REDIS_BATCH_CACHE_SIZE,
    ) -> None:
        super().__init__()
        # If in_memory_cache is not provided, use the default InMemoryCache
        self.in_memory_cache = in_memory_cache or InMemoryCache()
        # If redis_cache is not provided, use the default RedisCache
        self.redis_cache = redis_cache
        # Tracks when each key was last fetched from Redis during a batch get,
        # so async_batch_get_cache can throttle repeated Redis reads for the
        # same key. Size-bounded to prevent unbounded memory growth.
        self.last_redis_batch_access_time = LimitedSizeOrderedDict(
            max_size=default_max_redis_batch_cache_size
        )
        # Guards the check-then-set updates to last_redis_batch_access_time
        # (see _reserve_redis_batch_keys / _rollback_redis_batch_key_reservations).
        self._last_redis_batch_access_time_lock = Lock()
        # Minimum seconds between Redis batch reads for the same key.
        self.redis_batch_cache_expiry = (
            default_redis_batch_cache_expiry
            or litellm.default_redis_batch_cache_expiry
            or 10
        )
        self.default_in_memory_ttl = (
            default_in_memory_ttl or litellm.default_in_memory_ttl
        )
        self.default_redis_ttl = default_redis_ttl or litellm.default_redis_ttl

    def update_cache_ttl(
        self, default_in_memory_ttl: Optional[float], default_redis_ttl: Optional[float]
    ):
        """Override the default TTLs; a ``None`` argument leaves that TTL unchanged."""
        if default_in_memory_ttl is not None:
            self.default_in_memory_ttl = default_in_memory_ttl

        if default_redis_ttl is not None:
            self.default_redis_ttl = default_redis_ttl

    def set_cache(self, key, value, local_only: bool = False, **kwargs):
        """Write to the in-memory cache, and to Redis unless ``local_only``.

        Exceptions are logged via ``print_verbose`` and swallowed — a failed
        write does not propagate to the caller.
        """
        # Update both Redis and in-memory cache
        try:
            if self.in_memory_cache is not None:
                # Apply the default in-memory TTL when the caller didn't set one.
                if "ttl" not in kwargs and self.default_in_memory_ttl is not None:
                    kwargs["ttl"] = self.default_in_memory_ttl

                self.in_memory_cache.set_cache(key, value, **kwargs)

            if self.redis_cache is not None and local_only is False:
                self.redis_cache.set_cache(key, value, **kwargs)
        except Exception as e:
            print_verbose(e)

    def increment_cache(
        self, key, value: int, local_only: bool = False, **kwargs
    ) -> int:
        """
        Key - the key in cache

        Value - int - the value you want to increment by

        Returns - int - the incremented value
        """
        # NOTE: when both layers are active, the Redis result overwrites the
        # in-memory result, so Redis is the source of truth for the count.
        try:
            result: int = value
            if self.in_memory_cache is not None:
                result = self.in_memory_cache.increment_cache(key, value, **kwargs)

            if self.redis_cache is not None and local_only is False:
                result = self.redis_cache.increment_cache(key, value, **kwargs)

            return result
        except Exception as e:
            verbose_logger.error(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
            raise e

    def get_cache(
        self,
        key,
        parent_otel_span: Optional[Span] = None,
        local_only: bool = False,
        **kwargs,
    ):
        """Read-through get: in-memory first, then Redis (unless ``local_only``).

        A Redis hit back-fills the in-memory cache. Exceptions are logged and
        the method falls through to an implicit ``None`` return.
        """
        # Try to fetch from in-memory cache first
        try:
            result = None
            if self.in_memory_cache is not None:
                in_memory_result = self.in_memory_cache.get_cache(key, **kwargs)

                if in_memory_result is not None:
                    result = in_memory_result

            if result is None and self.redis_cache is not None and local_only is False:
                # If not found in in-memory cache, try fetching from Redis
                redis_result = self.redis_cache.get_cache(
                    key, parent_otel_span=parent_otel_span
                )

                if redis_result is not None:
                    # Update in-memory cache with the value from Redis
                    self.in_memory_cache.set_cache(key, redis_result, **kwargs)

                    result = redis_result

            print_verbose(f"get cache: cache result: {result}")
            return result
        except Exception:
            verbose_logger.error(traceback.format_exc())

    def batch_get_cache(
        self,
        keys: list,
        parent_otel_span: Optional[Span] = None,
        local_only: bool = False,
        **kwargs,
    ):
        """Synchronous wrapper that drives :meth:`async_batch_get_cache`.

        Runs the coroutine in a fresh event loop; if this thread already has a
        running loop, the work is pushed to a one-off worker thread to avoid
        nested-event-loop errors.
        """
        # Snapshot the call arguments so the inner closure can replay them.
        # NOTE(review): locals() captures **kwargs as a single 'kwargs' entry,
        # so extra keyword args arrive nested under kwargs['kwargs'] in
        # async_batch_get_cache — confirm callers don't rely on pass-through.
        received_args = locals()
        received_args.pop("self")

        def run_in_new_loop():
            """Run the coroutine in a new event loop within this thread."""
            new_loop = asyncio.new_event_loop()
            try:
                asyncio.set_event_loop(new_loop)
                return new_loop.run_until_complete(
                    self.async_batch_get_cache(**received_args)
                )
            finally:
                new_loop.close()
                asyncio.set_event_loop(None)

        try:
            # First, try to get the current event loop
            _ = asyncio.get_running_loop()
            # If we're already in an event loop, run in a separate thread
            # to avoid nested event loop issues
            with ThreadPoolExecutor(max_workers=1) as executor:
                future = executor.submit(run_in_new_loop)
                return future.result()

        except RuntimeError:
            # No running event loop, we can safely run in this thread
            return run_in_new_loop()

    async def async_get_cache(
        self,
        key,
        parent_otel_span: Optional[Span] = None,
        local_only: bool = False,
        **kwargs,
    ):
        """Async read-through get: in-memory first, then Redis; Redis hits
        back-fill the in-memory cache. Exceptions are logged, returning ``None``.
        """
        # Try to fetch from in-memory cache first
        try:
            print_verbose(
                f"async get cache: cache key: {key}; local_only: {local_only}"
            )
            result = None
            if self.in_memory_cache is not None:
                in_memory_result = await self.in_memory_cache.async_get_cache(
                    key, **kwargs
                )

                print_verbose(f"in_memory_result: {in_memory_result}")
                if in_memory_result is not None:
                    result = in_memory_result

            if result is None and self.redis_cache is not None and local_only is False:
                # If not found in in-memory cache, try fetching from Redis
                redis_result = await self.redis_cache.async_get_cache(
                    key, parent_otel_span=parent_otel_span
                )

                if redis_result is not None:
                    # Update in-memory cache with the value from Redis
                    await self.in_memory_cache.async_set_cache(
                        key, redis_result, **kwargs
                    )

                    result = redis_result

            print_verbose(f"get cache: cache result: {result}")
            return result
        except Exception:
            verbose_logger.error(traceback.format_exc())

    def _reserve_redis_batch_keys(
        self,
        current_time: float,
        keys: List[str],
        result: List[Any],
    ) -> Tuple[List[str], Dict[str, Optional[float]]]:
        """
        Atomically choose keys to fetch from Redis and reserve their access time.
        This prevents check-then-act races under concurrent async callers.
        """
        # Keys that missed in-memory (result[i] is None) AND whose last Redis
        # read is older than redis_batch_cache_expiry are selected; their
        # previous timestamps are remembered so a failed read can be rolled back.
        sublist_keys: List[str] = []
        previous_access_times: Dict[str, Optional[float]] = {}

        with self._last_redis_batch_access_time_lock:
            for key, value in zip(keys, result):
                if value is not None:
                    continue

                if (
                    key not in self.last_redis_batch_access_time
                    or current_time - self.last_redis_batch_access_time[key]
                    >= self.redis_batch_cache_expiry
                ):
                    sublist_keys.append(key)
                    previous_access_times[key] = self.last_redis_batch_access_time.get(
                        key
                    )
                    self.last_redis_batch_access_time[key] = current_time

        return sublist_keys, previous_access_times

    def _rollback_redis_batch_key_reservations(
        self, previous_access_times: Dict[str, Optional[float]]
    ) -> None:
        """Restore the pre-reservation access times (used when the Redis read fails)."""
        with self._last_redis_batch_access_time_lock:
            for key, previous_time in previous_access_times.items():
                if previous_time is None:
                    # Key had no prior timestamp - drop the reservation entirely.
                    self.last_redis_batch_access_time.pop(key, None)
                else:
                    self.last_redis_batch_access_time[key] = previous_time

    async def async_batch_get_cache(
        self,
        keys: list,
        parent_otel_span: Optional[Span] = None,
        local_only: bool = False,
        **kwargs,
    ):
        """Batch read-through get.

        1. Read all keys from the in-memory cache.
        2. For misses, read from Redis — but only keys whose last Redis read
           is older than ``redis_batch_cache_expiry`` (per-key throttle).
        3. Redis hits back-fill the in-memory cache.

        Returns a list aligned with ``keys`` (``None`` for misses), or ``None``
        if an unexpected exception is logged.
        """
        try:
            result = [None] * len(keys)
            if self.in_memory_cache is not None:
                in_memory_result = await self.in_memory_cache.async_batch_get_cache(
                    keys, **kwargs
                )

                if in_memory_result is not None:
                    result = in_memory_result

            if None in result and self.redis_cache is not None and local_only is False:
                """
                - for the none values in the result
                - check the redis cache
                """
                current_time = time.time()
                sublist_keys, previous_access_times = self._reserve_redis_batch_keys(
                    current_time, keys, result
                )

                # Only hit Redis if enough time has passed since last access.
                if len(sublist_keys) > 0:
                    try:
                        # If not found in in-memory cache, try fetching from Redis
                        redis_result = await self.redis_cache.async_batch_get_cache(
                            sublist_keys, parent_otel_span=parent_otel_span
                        )
                    except Exception:
                        # Do not throttle subsequent callers if the Redis read fails.
                        self._rollback_redis_batch_key_reservations(
                            previous_access_times
                        )
                        raise

                    # Short-circuit if redis_result is None or contains only None values
                    if redis_result is None or all(
                        v is None for v in redis_result.values()
                    ):
                        return result

                    # Pre-compute key-to-index mapping for O(1) lookup
                    key_to_index = {key: i for i, key in enumerate(keys)}

                    # Update both result and in-memory cache in a single loop
                    for key, value in redis_result.items():
                        result[key_to_index[key]] = value

                        if value is not None and self.in_memory_cache is not None:
                            await self.in_memory_cache.async_set_cache(
                                key, value, **kwargs
                            )

            return result
        except Exception:
            verbose_logger.error(traceback.format_exc())

    async def async_set_cache(self, key, value, local_only: bool = False, **kwargs):
        """Async write to in-memory cache, and to Redis unless ``local_only``.

        Exceptions are logged and swallowed.
        """
        print_verbose(
            f"async set cache: cache key: {key}; local_only: {local_only}; value: {value}"
        )
        try:
            if self.in_memory_cache is not None:
                # Apply the default in-memory TTL when the caller didn't set one.
                if "ttl" not in kwargs and self.default_in_memory_ttl is not None:
                    kwargs["ttl"] = self.default_in_memory_ttl
                await self.in_memory_cache.async_set_cache(key, value, **kwargs)

            if self.redis_cache is not None and local_only is False:
                await self.redis_cache.async_set_cache(key, value, **kwargs)
        except Exception as e:
            verbose_logger.exception(
                f"LiteLLM Cache: Excepton async add_cache: {str(e)}"
            )

    # async_batch_set_cache
    async def async_set_cache_pipeline(
        self, cache_list: list, local_only: bool = False, **kwargs
    ):
        """
        Batch write values to the cache
        """
        # cache_list is a list of (key, value) pairs; the same ttl is applied
        # to every entry. Exceptions are logged and swallowed.
        print_verbose(
            f"async batch set cache: cache keys: {cache_list}; local_only: {local_only}"
        )
        try:
            if self.in_memory_cache is not None:
                if "ttl" not in kwargs and self.default_in_memory_ttl is not None:
                    kwargs["ttl"] = self.default_in_memory_ttl
                await self.in_memory_cache.async_set_cache_pipeline(
                    cache_list=cache_list, **kwargs
                )

            if self.redis_cache is not None and local_only is False:
                # ttl is popped so it is not passed twice (positionally + via kwargs).
                await self.redis_cache.async_set_cache_pipeline(
                    cache_list=cache_list, ttl=kwargs.pop("ttl", None), **kwargs
                )
        except Exception as e:
            verbose_logger.exception(
                f"LiteLLM Cache: Excepton async add_cache: {str(e)}"
            )

    async def async_increment_cache(
        self,
        key,
        value: float,
        parent_otel_span: Optional[Span] = None,
        local_only: bool = False,
        **kwargs,
    ) -> float:
        """
        Key - the key in cache

        Value - float - the value you want to increment by

        Returns - float - the incremented value
        """
        # NOTE: when both layers are active, the Redis result wins.
        try:
            result: float = value
            if self.in_memory_cache is not None:
                result = await self.in_memory_cache.async_increment(
                    key, value, **kwargs
                )

            if self.redis_cache is not None and local_only is False:
                result = await self.redis_cache.async_increment(
                    key,
                    value,
                    parent_otel_span=parent_otel_span,
                    ttl=kwargs.get("ttl", None),
                )

            return result
        except Exception as e:
            raise e  # don't log if exception is raised

    async def async_increment_cache_pipeline(
        self,
        increment_list: List["RedisPipelineIncrementOperation"],
        local_only: bool = False,
        parent_otel_span: Optional[Span] = None,
        **kwargs,
    ) -> Optional[List[float]]:
        """Apply a batch of increment operations to both layers.

        Returns the per-key results from the last layer that ran (Redis when
        enabled and not ``local_only``, otherwise in-memory).
        """
        try:
            result: Optional[List[float]] = None
            if self.in_memory_cache is not None:
                result = await self.in_memory_cache.async_increment_pipeline(
                    increment_list=increment_list,
                    parent_otel_span=parent_otel_span,
                )

            if self.redis_cache is not None and local_only is False:
                result = await self.redis_cache.async_increment_pipeline(
                    increment_list=increment_list,
                    parent_otel_span=parent_otel_span,
                )

            return result
        except Exception as e:
            raise e  # don't log if exception is raised

    async def async_set_cache_sadd(
        self, key, value: List, local_only: bool = False, **kwargs
    ) -> None:
        """
        Add value to a set

        Key - the key in cache

        Value - str - the value you want to add to the set

        Returns - None
        """
        try:
            if self.in_memory_cache is not None:
                _ = await self.in_memory_cache.async_set_cache_sadd(
                    key, value, ttl=kwargs.get("ttl", None)
                )

            if self.redis_cache is not None and local_only is False:
                _ = await self.redis_cache.async_set_cache_sadd(
                    key, value, ttl=kwargs.get("ttl", None)
                )

            return None
        except Exception as e:
            raise e  # don't log, if exception is raised

    def flush_cache(self):
        """Clear all entries from both cache layers."""
        if self.in_memory_cache is not None:
            self.in_memory_cache.flush_cache()
        if self.redis_cache is not None:
            self.redis_cache.flush_cache()

    def delete_cache(self, key):
        """
        Delete a key from the cache
        """
        if self.in_memory_cache is not None:
            self.in_memory_cache.delete_cache(key)
        if self.redis_cache is not None:
            self.redis_cache.delete_cache(key)

    async def async_delete_cache(self, key: str):
        """
        Delete a key from the cache
        """
        # In-memory deletion is synchronous; only the Redis delete is awaited.
        if self.in_memory_cache is not None:
            self.in_memory_cache.delete_cache(key)
        if self.redis_cache is not None:
            await self.redis_cache.async_delete_cache(key)

    async def async_get_ttl(self, key: str) -> Optional[int]:
        """
        Get the remaining TTL of a key in in-memory cache or redis
        """
        # in_memory_cache is always set in __init__, so it is queried first;
        # Redis is consulted only when the in-memory layer has no TTL entry.
        ttl = await self.in_memory_cache.async_get_ttl(key)
        if ttl is None and self.redis_cache is not None:
            ttl = await self.redis_cache.async_get_ttl(key)
        return ttl
|
||||
@@ -0,0 +1,113 @@
|
||||
"""GCS Cache implementation
|
||||
Supports syncing responses to Google Cloud Storage Buckets using HTTP requests.
|
||||
"""
|
||||
import asyncio
import json
from typing import Optional
from urllib.parse import quote

from litellm._logging import print_verbose, verbose_logger
from litellm.integrations.gcs_bucket.gcs_bucket_base import GCSBucketBase
from litellm.llms.custom_httpx.http_handler import (
    get_async_httpx_client,
    _get_httpx_client,
    httpxSpecialProvider,
)

from .base_cache import BaseCache
|
||||
|
||||
|
||||
class GCSCache(BaseCache):
    """Cache backend that persists entries as objects in a Google Cloud
    Storage bucket via the GCS JSON API over plain HTTP.

    Values are JSON-serialized on write and JSON-decoded on read. All cache
    failures are logged and swallowed, so a failed read behaves like a miss.
    """

    def __init__(
        self,
        bucket_name: Optional[str] = None,
        path_service_account: Optional[str] = None,
        gcs_path: Optional[str] = None,
    ) -> None:
        super().__init__()
        # Fall back to the bucket / credentials configured on GCSBucketBase
        # when not supplied explicitly.
        self.bucket_name = bucket_name or GCSBucketBase(bucket_name=None).BUCKET_NAME
        self.path_service_account = (
            path_service_account
            or GCSBucketBase(bucket_name=None).path_service_account_json
        )
        # Optional folder prefix inside the bucket; normalized to end with "/".
        self.key_prefix = gcs_path.rstrip("/") + "/" if gcs_path else ""
        # create httpx clients
        self.async_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )
        self.sync_client = _get_httpx_client()

    def _object_name(self, key: str) -> str:
        """Return the percent-encoded GCS object name for ``key``.

        BUG FIX: the GCS JSON API requires object names embedded in request
        URLs to be URL-encoded. ``key_prefix`` always contains a "/" and keys
        may contain reserved characters; interpolating the raw name produced
        URLs that addressed the wrong resource.
        """
        return quote(self.key_prefix + key, safe="")

    def _construct_headers(self) -> dict:
        """Build authenticated request headers via the GCS integration helper."""
        base = GCSBucketBase(bucket_name=self.bucket_name)
        base.path_service_account_json = self.path_service_account
        base.BUCKET_NAME = self.bucket_name
        return base.sync_construct_request_headers()

    def set_cache(self, key, value, **kwargs):
        """Upload ``value`` (JSON-serialized) to GCS; errors are logged and swallowed."""
        try:
            print_verbose(f"LiteLLM SET Cache - GCS. Key={key}. Value={value}")
            headers = self._construct_headers()
            object_name = self._object_name(key)
            bucket_name = self.bucket_name
            url = f"https://storage.googleapis.com/upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}"
            data = json.dumps(value)
            self.sync_client.post(url=url, data=data, headers=headers)
        except Exception as e:
            print_verbose(f"GCS Caching: set_cache() - Got exception from GCS: {e}")

    async def async_set_cache(self, key, value, **kwargs):
        """Async counterpart of :meth:`set_cache`; errors are logged and swallowed."""
        try:
            headers = self._construct_headers()
            object_name = self._object_name(key)
            bucket_name = self.bucket_name
            url = f"https://storage.googleapis.com/upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}"
            data = json.dumps(value)
            await self.async_client.post(url=url, data=data, headers=headers)
        except Exception as e:
            print_verbose(
                f"GCS Caching: async_set_cache() - Got exception from GCS: {e}"
            )

    def get_cache(self, key, **kwargs):
        """Download and JSON-decode the object for ``key``.

        Returns None on a non-200 response or any error (logged).
        """
        try:
            headers = self._construct_headers()
            object_name = self._object_name(key)
            bucket_name = self.bucket_name
            url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}?alt=media"
            response = self.sync_client.get(url=url, headers=headers)
            if response.status_code == 200:
                cached_response = json.loads(response.text)
                verbose_logger.debug(
                    f"Got GCS Cache: key: {key}, cached_response {cached_response}. Type Response {type(cached_response)}"
                )
                return cached_response
            return None
        except Exception as e:
            verbose_logger.error(
                f"GCS Caching: get_cache() - Got exception from GCS: {e}"
            )

    async def async_get_cache(self, key, **kwargs):
        """Async counterpart of :meth:`get_cache`; returns None on failure."""
        try:
            headers = self._construct_headers()
            object_name = self._object_name(key)
            bucket_name = self.bucket_name
            url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}?alt=media"
            response = await self.async_client.get(url=url, headers=headers)
            if response.status_code == 200:
                return json.loads(response.text)
            return None
        except Exception as e:
            verbose_logger.error(
                f"GCS Caching: async_get_cache() - Got exception from GCS: {e}"
            )

    def flush_cache(self):
        # Flushing a whole GCS bucket is intentionally unsupported.
        pass

    async def disconnect(self):
        # HTTP clients are shared/managed elsewhere; nothing to close here.
        pass

    async def async_set_cache_pipeline(self, cache_list, **kwargs):
        """Upload a batch of ``(key, value)`` pairs concurrently."""
        tasks = []
        for val in cache_list:
            tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
        await asyncio.gather(*tasks)
|
||||
@@ -0,0 +1,288 @@
|
||||
"""
|
||||
In-Memory Cache implementation
|
||||
|
||||
Has 4 methods:
|
||||
- set_cache
|
||||
- get_cache
|
||||
- async_set_cache
|
||||
- async_get_cache
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
import heapq
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.types.caching import RedisPipelineIncrementOperation
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from litellm.constants import MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB
|
||||
|
||||
from .base_cache import BaseCache
|
||||
|
||||
|
||||
class InMemoryCache(BaseCache):
|
||||
def __init__(
|
||||
self,
|
||||
max_size_in_memory: Optional[int] = 200,
|
||||
default_ttl: Optional[
|
||||
int
|
||||
] = 600, # default ttl is 10 minutes. At maximum litellm rate limiting logic requires objects to be in memory for 1 minute
|
||||
max_size_per_item: Optional[int] = 1024, # 1MB = 1024KB
|
||||
):
|
||||
"""
|
||||
max_size_in_memory [int]: Maximum number of items in cache. done to prevent memory leaks. Use 200 items as a default
|
||||
"""
|
||||
self.max_size_in_memory = (
|
||||
max_size_in_memory if max_size_in_memory is not None else 200
|
||||
) # set an upper bound of 200 items in-memory
|
||||
self.default_ttl = default_ttl or 600
|
||||
self.max_size_per_item = (
|
||||
max_size_per_item or MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB
|
||||
) # 1MB = 1024KB
|
||||
|
||||
# in-memory cache
|
||||
self.cache_dict: dict = {}
|
||||
self.ttl_dict: dict = {}
|
||||
self.expiration_heap: list[tuple[float, str]] = []
|
||||
|
||||
def check_value_size(self, value: Any):
|
||||
"""
|
||||
Check if value size exceeds max_size_per_item (1MB)
|
||||
Returns True if value size is acceptable, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Fast path for common primitive types that are typically small
|
||||
if (
|
||||
isinstance(value, (bool, int, float, str))
|
||||
and len(str(value))
|
||||
< self.max_size_per_item * MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB
|
||||
): # Conservative estimate
|
||||
return True
|
||||
|
||||
# Direct size check for bytes objects
|
||||
if isinstance(value, bytes):
|
||||
return sys.getsizeof(value) / 1024 <= self.max_size_per_item
|
||||
|
||||
# Handle special types without full conversion when possible
|
||||
if hasattr(value, "__sizeof__"): # Use __sizeof__ if available
|
||||
size = value.__sizeof__() / 1024
|
||||
return size <= self.max_size_per_item
|
||||
|
||||
# Fallback for complex types
|
||||
if isinstance(value, BaseModel) and hasattr(
|
||||
value, "model_dump"
|
||||
): # Pydantic v2
|
||||
value = value.model_dump()
|
||||
elif hasattr(value, "isoformat"): # datetime objects
|
||||
return True # datetime strings are always small
|
||||
|
||||
# Only convert to JSON if absolutely necessary
|
||||
if not isinstance(value, (str, bytes)):
|
||||
value = json.dumps(value, default=str)
|
||||
|
||||
return sys.getsizeof(value) / 1024 <= self.max_size_per_item
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _is_key_expired(self, key: str) -> bool:
|
||||
"""
|
||||
Check if a specific key is expired
|
||||
"""
|
||||
return key in self.ttl_dict and time.time() > self.ttl_dict[key]
|
||||
|
||||
def _remove_key(self, key: str) -> None:
|
||||
"""
|
||||
Remove a key from both cache_dict and ttl_dict
|
||||
"""
|
||||
self.cache_dict.pop(key, None)
|
||||
self.ttl_dict.pop(key, None)
|
||||
|
||||
def evict_cache(self):
|
||||
"""
|
||||
Eviction policy:
|
||||
1. First, remove expired items from ttl_dict and cache_dict
|
||||
2. If cache is still at or above max_size_in_memory, evict items with earliest expiration times
|
||||
|
||||
|
||||
This guarantees the following:
|
||||
- 1. When item ttl not set: At minimum each item will remain in memory for the default ttl
|
||||
- 2. When ttl is set: the item will remain in memory for at least that amount of time, unless cache size requires eviction
|
||||
- 3. the size of in-memory cache is bounded
|
||||
|
||||
"""
|
||||
current_time = time.time()
|
||||
|
||||
# Step 1: Remove expired or outdated items
|
||||
while self.expiration_heap:
|
||||
expiration_time, key = self.expiration_heap[0]
|
||||
|
||||
# Case 1: Heap entry is outdated
|
||||
if expiration_time != self.ttl_dict.get(key):
|
||||
heapq.heappop(self.expiration_heap)
|
||||
# Case 2: Entry is valid but expired
|
||||
elif expiration_time <= current_time:
|
||||
heapq.heappop(self.expiration_heap)
|
||||
self._remove_key(key)
|
||||
else:
|
||||
# Case 3: Entry is valid and not expired
|
||||
break
|
||||
|
||||
# Step 2: Evict if cache is still full
|
||||
while len(self.cache_dict) >= self.max_size_in_memory:
|
||||
expiration_time, key = heapq.heappop(self.expiration_heap)
|
||||
# Skip if key was removed or updated
|
||||
if self.ttl_dict.get(key) == expiration_time:
|
||||
self._remove_key(key)
|
||||
|
||||
# de-reference the removed item
|
||||
# https://www.geeksforgeeks.org/diagnosing-and-fixing-memory-leaks-in-python/
|
||||
# One of the most common causes of memory leaks in Python is the retention of objects that are no longer being used.
|
||||
# This can occur when an object is referenced by another object, but the reference is never removed.
|
||||
|
||||
def allow_ttl_override(self, key: str) -> bool:
|
||||
"""
|
||||
Check if ttl is set for a key
|
||||
"""
|
||||
ttl_time = self.ttl_dict.get(key)
|
||||
if ttl_time is None: # if ttl is not set, allow override
|
||||
return True
|
||||
elif float(ttl_time) < time.time(): # if ttl is expired, allow override
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def set_cache(self, key, value, **kwargs):
|
||||
# Handle the edge case where max_size_in_memory is 0
|
||||
if self.max_size_in_memory == 0:
|
||||
return # Don't cache anything if max size is 0
|
||||
|
||||
if len(self.cache_dict) >= self.max_size_in_memory:
|
||||
# only evict when cache is full
|
||||
self.evict_cache()
|
||||
if not self.check_value_size(value):
|
||||
return
|
||||
|
||||
self.cache_dict[key] = value
|
||||
if self.allow_ttl_override(key): # if ttl is not set, set it to default ttl
|
||||
if "ttl" in kwargs and kwargs["ttl"] is not None:
|
||||
self.ttl_dict[key] = time.time() + float(kwargs["ttl"])
|
||||
heapq.heappush(self.expiration_heap, (self.ttl_dict[key], key))
|
||||
else:
|
||||
self.ttl_dict[key] = time.time() + self.default_ttl
|
||||
heapq.heappush(self.expiration_heap, (self.ttl_dict[key], key))
|
||||
|
||||
async def async_set_cache(self, key, value, **kwargs):
|
||||
self.set_cache(key=key, value=value, **kwargs)
|
||||
|
||||
async def async_set_cache_pipeline(self, cache_list, ttl=None, **kwargs):
|
||||
for cache_key, cache_value in cache_list:
|
||||
if ttl is not None:
|
||||
self.set_cache(key=cache_key, value=cache_value, ttl=ttl)
|
||||
else:
|
||||
self.set_cache(key=cache_key, value=cache_value)
|
||||
|
||||
async def async_set_cache_sadd(self, key, value: List, ttl: Optional[float]):
|
||||
"""
|
||||
Add value to set
|
||||
"""
|
||||
# get the value
|
||||
init_value = self.get_cache(key=key) or set()
|
||||
for val in value:
|
||||
init_value.add(val)
|
||||
self.set_cache(key, init_value, ttl=ttl)
|
||||
return value
|
||||
|
||||
def evict_element_if_expired(self, key: str) -> bool:
|
||||
"""
|
||||
Returns True if the element is expired and removed from the cache
|
||||
|
||||
Returns False if the element is not expired
|
||||
"""
|
||||
if self._is_key_expired(key):
|
||||
self._remove_key(key)
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_cache(self, key, **kwargs):
|
||||
if key in self.cache_dict:
|
||||
if self.evict_element_if_expired(key):
|
||||
return None
|
||||
original_cached_response = self.cache_dict[key]
|
||||
try:
|
||||
cached_response = json.loads(original_cached_response)
|
||||
except Exception:
|
||||
cached_response = original_cached_response
|
||||
return cached_response
|
||||
return None
|
||||
|
||||
def batch_get_cache(self, keys: list, **kwargs):
|
||||
return_val = []
|
||||
for k in keys:
|
||||
val = self.get_cache(key=k, **kwargs)
|
||||
return_val.append(val)
|
||||
return return_val
|
||||
|
||||
def increment_cache(self, key, value: int, **kwargs) -> int:
|
||||
# get the value
|
||||
init_value = self.get_cache(key=key) or 0
|
||||
value = init_value + value
|
||||
self.set_cache(key, value, **kwargs)
|
||||
return value
|
||||
|
||||
async def async_get_cache(self, key, **kwargs):
|
||||
return self.get_cache(key=key, **kwargs)
|
||||
|
||||
async def async_batch_get_cache(self, keys: list, **kwargs):
|
||||
return_val = []
|
||||
for k in keys:
|
||||
val = self.get_cache(key=k, **kwargs)
|
||||
return_val.append(val)
|
||||
return return_val
|
||||
|
||||
async def async_increment(self, key, value: float, **kwargs) -> float:
|
||||
# get the value
|
||||
init_value = await self.async_get_cache(key=key) or 0
|
||||
value = init_value + value
|
||||
await self.async_set_cache(key, value, **kwargs)
|
||||
return value
|
||||
|
||||
async def async_increment_pipeline(
    self, increment_list: List["RedisPipelineIncrementOperation"], **kwargs
) -> Optional[List[float]]:
    """Apply a batch of increment operations and collect the new totals.

    Mirrors the Redis pipeline API; in-memory there is no real pipelining,
    so each operation is simply awaited in order.
    """
    totals: List[float] = []
    for op in increment_list:
        new_total = await self.async_increment(
            op["key"], op["increment_value"], **kwargs
        )
        totals.append(new_total)
    return totals
|
||||
|
||||
def flush_cache(self):
    """Drop every entry and all TTL bookkeeping (values, ttl map, expiry heap)."""
    for container in (self.cache_dict, self.ttl_dict, self.expiration_heap):
        container.clear()
|
||||
|
||||
async def disconnect(self):
    """No-op: an in-memory cache holds no external connection to tear down.

    Present so this class satisfies the common async cache interface.
    """
    pass
|
||||
|
||||
def delete_cache(self, key):
    """Remove ``key`` from the cache if present."""
    # _remove_key presumably also clears the TTL/heap bookkeeping for the
    # key (it is the same routine the expiry path uses) -- confirm there.
    self._remove_key(key)
|
||||
|
||||
async def async_get_ttl(self, key: str) -> Optional[int]:
    """Return the ttl_dict entry for ``key``, or None when untracked.

    NOTE(review): this hands back the raw stored value from ``ttl_dict``
    rather than computing remaining seconds -- confirm against callers
    whether that value is a duration or an absolute expiry timestamp.
    """
    try:
        return self.ttl_dict[key]
    except KeyError:
        return None
|
||||
|
||||
async def async_get_oldest_n_keys(self, n: int) -> List[str]:
    """Return the ``n`` keys with the smallest ``ttl_dict`` values.

    "Oldest" here means soonest-expiring by ttl_dict value. Ties preserve
    insertion order, matching the previous stable-sort implementation.

    Args:
        n: maximum number of keys to return.

    Returns:
        List[str]: up to ``n`` keys, ordered by ascending ttl_dict value.
    """
    import heapq

    # heapq.nsmallest is documented to be equivalent to
    # sorted(iterable, key=key)[:n] but runs in O(m log n) instead of
    # sorting the whole dict (O(m log m)).
    return heapq.nsmallest(n, self.ttl_dict, key=self.ttl_dict.get)
|
||||
@@ -0,0 +1,50 @@
|
||||
"""
|
||||
Add the event loop to the cache key, to prevent event loop closed errors.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
from .in_memory_cache import InMemoryCache
|
||||
|
||||
|
||||
class LLMClientCache(InMemoryCache):
    """In-memory cache for LLM HTTP client objects (OpenAI, Azure, httpx, ...).

    Keys are scoped to the running asyncio event loop so a client created on
    one loop is never handed to code running on another loop (which would
    surface as "event loop is closed" errors).

    IMPORTANT: this cache intentionally does NOT close clients on eviction.
    Evicted clients may still be in use by in-flight requests. Closing them
    eagerly causes ``RuntimeError: Cannot send a request, as the client has
    been closed.`` errors in production after the TTL (1 hour) expires.

    Clients that are no longer referenced will be garbage-collected normally.
    For explicit shutdown cleanup, use ``close_litellm_async_clients()``.
    """

    def update_cache_key_with_event_loop(self, key):
        """Suffix ``key`` with the running event loop's id.

        When no event loop is running (sync context), the key is returned
        unchanged.
        """
        try:
            running_loop = asyncio.get_running_loop()
        except RuntimeError:
            # No current running event loop -- use the bare key.
            return key
        return f"{key}-{str(id(running_loop))}"

    def set_cache(self, key, value, **kwargs):
        scoped_key = self.update_cache_key_with_event_loop(key)
        return super().set_cache(scoped_key, value, **kwargs)

    async def async_set_cache(self, key, value, **kwargs):
        scoped_key = self.update_cache_key_with_event_loop(key)
        return await super().async_set_cache(scoped_key, value, **kwargs)

    def get_cache(self, key, **kwargs):
        scoped_key = self.update_cache_key_with_event_loop(key)
        return super().get_cache(scoped_key, **kwargs)

    async def async_get_cache(self, key, **kwargs):
        scoped_key = self.update_cache_key_with_event_loop(key)
        return await super().async_get_cache(scoped_key, **kwargs)
|
||||
@@ -0,0 +1,446 @@
|
||||
"""
|
||||
Qdrant Semantic Cache implementation
|
||||
|
||||
Has 4 methods:
|
||||
- set_cache
|
||||
- get_cache
|
||||
- async_set_cache
|
||||
- async_get_cache
|
||||
"""
|
||||
|
||||
import ast
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, cast
|
||||
|
||||
import litellm
|
||||
from litellm._logging import print_verbose
|
||||
from litellm.constants import QDRANT_SCALAR_QUANTILE, QDRANT_VECTOR_SIZE
|
||||
from litellm.types.utils import EmbeddingResponse
|
||||
|
||||
from .base_cache import BaseCache
|
||||
|
||||
|
||||
class QdrantSemanticCache(BaseCache):
    """Semantic cache backed by a Qdrant vector collection.

    Prompts are embedded and looked up by cosine similarity rather than by
    exact key; a hit requires ``score >= similarity_threshold``.

    Has 4 methods:
    - set_cache
    - get_cache
    - async_set_cache
    - async_get_cache
    """

    def __init__(  # noqa: PLR0915
        self,
        qdrant_api_base=None,
        qdrant_api_key=None,
        collection_name=None,
        similarity_threshold=None,
        quantization_config=None,
        embedding_model="text-embedding-ada-002",
        host_type=None,  # unused; kept for backward compatibility with callers
        vector_size=None,
    ):
        """Connect to Qdrant and ensure the target collection exists.

        Args:
            qdrant_api_base: Qdrant URL; supports "os.environ/VAR" indirection
                and falls back to QDRANT_URL / QDRANT_API_BASE env vars.
            qdrant_api_key: API key; supports the same indirections.
            collection_name: required name of the Qdrant collection.
            similarity_threshold: required minimum score [0, 1] for a hit.
            quantization_config: 'binary' (default), 'scalar' or 'product'.
            embedding_model: model used to embed prompts.
            vector_size: embedding dimension; defaults to QDRANT_VECTOR_SIZE.

        Raises:
            Exception / ValueError on missing configuration or Qdrant errors.
        """
        import os

        from litellm.llms.custom_httpx.http_handler import (
            _get_httpx_client,
            get_async_httpx_client,
            httpxSpecialProvider,
        )
        from litellm.secret_managers.main import get_secret_str

        if collection_name is None:
            raise Exception("collection_name must be provided, passed None")

        self.collection_name = collection_name
        print_verbose(
            f"qdrant semantic-cache initializing COLLECTION - {self.collection_name}"
        )

        if similarity_threshold is None:
            raise Exception("similarity_threshold must be provided, passed None")
        self.similarity_threshold = similarity_threshold
        self.embedding_model = embedding_model
        self.vector_size = (
            vector_size if vector_size is not None else QDRANT_VECTOR_SIZE
        )
        headers = {}

        # check if defined as os.environ/ variable
        if qdrant_api_base:
            if isinstance(qdrant_api_base, str) and qdrant_api_base.startswith(
                "os.environ/"
            ):
                qdrant_api_base = get_secret_str(qdrant_api_base)
        if qdrant_api_key:
            if isinstance(qdrant_api_key, str) and qdrant_api_key.startswith(
                "os.environ/"
            ):
                qdrant_api_key = get_secret_str(qdrant_api_key)

        qdrant_api_base = (
            qdrant_api_base or os.getenv("QDRANT_URL") or os.getenv("QDRANT_API_BASE")
        )
        qdrant_api_key = qdrant_api_key or os.getenv("QDRANT_API_KEY")
        headers = {"Content-Type": "application/json"}
        if qdrant_api_key:
            headers["api-key"] = qdrant_api_key

        if qdrant_api_base is None:
            raise ValueError("Qdrant url must be provided")

        self.qdrant_api_base = qdrant_api_base
        self.qdrant_api_key = qdrant_api_key
        print_verbose(f"qdrant semantic-cache qdrant_api_base: {self.qdrant_api_base}")

        self.headers = headers

        self.sync_client = _get_httpx_client()
        self.async_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.Caching
        )

        if quantization_config is None:
            print_verbose(
                "Quantization config is not provided. Default binary quantization will be used."
            )
        collection_exists = self.sync_client.get(
            url=f"{self.qdrant_api_base}/collections/{self.collection_name}/exists",
            headers=self.headers,
        )
        if collection_exists.status_code != 200:
            raise ValueError(
                f"Error from qdrant checking if /collections exist {collection_exists.text}"
            )

        if collection_exists.json()["result"]["exists"]:
            # Collection already present -- just record its details.
            collection_details = self.sync_client.get(
                url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
                headers=self.headers,
            )
            self.collection_info = collection_details.json()
            print_verbose(
                f"Collection already exists.\nCollection details:{self.collection_info}"
            )
        else:
            # Build the quantization payload for the new collection.
            if quantization_config is None or quantization_config == "binary":
                quantization_params = {
                    "binary": {
                        "always_ram": False,
                    }
                }
            elif quantization_config == "scalar":
                quantization_params = {
                    "scalar": {
                        "type": "int8",
                        "quantile": QDRANT_SCALAR_QUANTILE,
                        "always_ram": False,
                    }
                }
            elif quantization_config == "product":
                quantization_params = {
                    "product": {"compression": "x16", "always_ram": False}
                }
            else:
                raise Exception(
                    "Quantization config must be one of 'scalar', 'binary' or 'product'"
                )

            new_collection_status = self.sync_client.put(
                url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
                json={
                    "vectors": {"size": self.vector_size, "distance": "Cosine"},
                    "quantization_config": quantization_params,
                },
                headers=self.headers,
            )
            if new_collection_status.json()["result"]:
                collection_details = self.sync_client.get(
                    url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
                    headers=self.headers,
                )
                self.collection_info = collection_details.json()
                print_verbose(
                    f"New collection created.\nCollection details:{self.collection_info}"
                )
            else:
                raise Exception("Error while creating new collection")

    def _get_cache_logic(self, cached_response: Any):
        """Decode a cached payload: JSON first, ``ast.literal_eval`` fallback."""
        if cached_response is None:
            return cached_response
        try:
            cached_response = json.loads(
                cached_response
            )  # Convert string to dictionary
        except Exception:
            cached_response = ast.literal_eval(cached_response)
        return cached_response

    def set_cache(self, key, value, **kwargs):
        """Embed the prompt from kwargs["messages"] and upsert it with ``value``.

        ``key`` is unused: semantic caching keys on the prompt embedding.
        """
        print_verbose(f"qdrant semantic-cache set_cache, kwargs: {kwargs}")
        from litellm._uuid import uuid

        # get the prompt -- assumes every message["content"] is a str;
        # list-form content would raise here (TODO confirm upstream guards)
        messages = kwargs["messages"]
        prompt = ""
        for message in messages:
            prompt += message["content"]

        # create an embedding for prompt
        embedding_response = cast(
            EmbeddingResponse,
            litellm.embedding(
                model=self.embedding_model,
                input=prompt,
                cache={"no-store": True, "no-cache": True},
            ),
        )

        # get the embedding
        embedding = embedding_response["data"][0]["embedding"]

        value = str(value)
        assert isinstance(value, str)

        data = {
            "points": [
                {
                    "id": str(uuid.uuid4()),
                    "vector": embedding,
                    "payload": {
                        "text": prompt,
                        "response": value,
                    },
                },
            ]
        }
        self.sync_client.put(
            url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points",
            headers=self.headers,
            json=data,
        )
        return

    def get_cache(self, key, **kwargs):
        """Look up the nearest cached prompt; return its response on a hit.

        A hit requires the top result's score >= similarity_threshold;
        otherwise (or on empty results) returns None.
        """
        print_verbose(f"sync qdrant semantic-cache get_cache, kwargs: {kwargs}")

        # get the messages
        messages = kwargs["messages"]
        prompt = ""
        for message in messages:
            prompt += message["content"]

        # convert to embedding
        embedding_response = cast(
            EmbeddingResponse,
            litellm.embedding(
                model=self.embedding_model,
                input=prompt,
                cache={"no-store": True, "no-cache": True},
            ),
        )

        # get the embedding
        embedding = embedding_response["data"][0]["embedding"]

        data = {
            "vector": embedding,
            "params": {
                "quantization": {
                    "ignore": False,
                    "rescore": True,
                    "oversampling": 3.0,
                }
            },
            "limit": 1,
            "with_payload": True,
        }

        search_response = self.sync_client.post(
            url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points/search",
            headers=self.headers,
            json=data,
        )
        results = search_response.json()["result"]

        if results is None:
            return None
        if isinstance(results, list):
            if len(results) == 0:
                return None

        similarity = results[0]["score"]
        cached_prompt = results[0]["payload"]["text"]

        # check similarity, if more than self.similarity_threshold, return results
        print_verbose(
            f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
        )
        if similarity >= self.similarity_threshold:
            # cache hit !
            cached_value = results[0]["payload"]["response"]
            print_verbose(
                f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
            )
            return self._get_cache_logic(cached_response=cached_value)
        else:
            # cache miss !
            return None
        # (dead trailing `pass` removed -- both branches above return)

    async def async_set_cache(self, key, value, **kwargs):
        """Async variant of set_cache; routes the embedding through the
        proxy router when the embedding model is registered there."""
        from litellm._uuid import uuid

        from litellm.proxy.proxy_server import llm_model_list, llm_router

        print_verbose(f"async qdrant semantic-cache set_cache, kwargs: {kwargs}")

        # get the prompt
        messages = kwargs["messages"]
        prompt = ""
        for message in messages:
            prompt += message["content"]
        # create an embedding for prompt
        router_model_names = (
            [m["model_name"] for m in llm_model_list]
            if llm_model_list is not None
            else []
        )
        if llm_router is not None and self.embedding_model in router_model_names:
            user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
            embedding_response = await llm_router.aembedding(
                model=self.embedding_model,
                input=prompt,
                cache={"no-store": True, "no-cache": True},
                metadata={
                    "user_api_key": user_api_key,
                    "semantic-cache-embedding": True,
                    "trace_id": kwargs.get("metadata", {}).get("trace_id", None),
                },
            )
        else:
            # convert to embedding
            embedding_response = await litellm.aembedding(
                model=self.embedding_model,
                input=prompt,
                cache={"no-store": True, "no-cache": True},
            )

        # get the embedding
        embedding = embedding_response["data"][0]["embedding"]

        value = str(value)
        assert isinstance(value, str)

        data = {
            "points": [
                {
                    "id": str(uuid.uuid4()),
                    "vector": embedding,
                    "payload": {
                        "text": prompt,
                        "response": value,
                    },
                },
            ]
        }

        await self.async_client.put(
            url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points",
            headers=self.headers,
            json=data,
        )
        return

    async def async_get_cache(self, key, **kwargs):
        """Async variant of get_cache; additionally records the observed
        similarity in kwargs["metadata"]["semantic-similarity"]."""
        print_verbose(f"async qdrant semantic-cache get_cache, kwargs: {kwargs}")
        from litellm.proxy.proxy_server import llm_model_list, llm_router

        # get the messages
        messages = kwargs["messages"]
        prompt = ""
        for message in messages:
            prompt += message["content"]

        router_model_names = (
            [m["model_name"] for m in llm_model_list]
            if llm_model_list is not None
            else []
        )
        if llm_router is not None and self.embedding_model in router_model_names:
            user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
            embedding_response = await llm_router.aembedding(
                model=self.embedding_model,
                input=prompt,
                cache={"no-store": True, "no-cache": True},
                metadata={
                    "user_api_key": user_api_key,
                    "semantic-cache-embedding": True,
                    "trace_id": kwargs.get("metadata", {}).get("trace_id", None),
                },
            )
        else:
            # convert to embedding
            embedding_response = await litellm.aembedding(
                model=self.embedding_model,
                input=prompt,
                cache={"no-store": True, "no-cache": True},
            )

        # get the embedding
        embedding = embedding_response["data"][0]["embedding"]

        data = {
            "vector": embedding,
            "params": {
                "quantization": {
                    "ignore": False,
                    "rescore": True,
                    "oversampling": 3.0,
                }
            },
            "limit": 1,
            "with_payload": True,
        }

        search_response = await self.async_client.post(
            url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points/search",
            headers=self.headers,
            json=data,
        )

        results = search_response.json()["result"]

        if results is None:
            kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
            return None
        if isinstance(results, list):
            if len(results) == 0:
                kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
                return None

        similarity = results[0]["score"]
        cached_prompt = results[0]["payload"]["text"]

        # check similarity, if more than self.similarity_threshold, return results
        print_verbose(
            f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
        )

        # update kwargs["metadata"] with similarity, don't rewrite the original metadata
        kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity

        if similarity >= self.similarity_threshold:
            # cache hit !
            cached_value = results[0]["payload"]["response"]
            print_verbose(
                f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
            )
            return self._get_cache_logic(cached_response=cached_value)
        else:
            # cache miss !
            return None
        # (dead trailing `pass` removed -- both branches above return)

    async def _collection_info(self):
        """Return the collection details captured during __init__."""
        return self.collection_info

    async def async_set_cache_pipeline(self, cache_list, **kwargs):
        """Store many (key, value) pairs concurrently via async_set_cache."""
        tasks = []
        for val in cache_list:
            tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
        await asyncio.gather(*tasks)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,109 @@
|
||||
"""
|
||||
Redis Cluster Cache implementation
|
||||
|
||||
Key differences:
|
||||
- RedisClient NEEDs to be re-used across requests, adds 3000ms latency if it's re-created
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
||||
|
||||
from litellm.caching.redis_cache import RedisCache
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
from redis.asyncio import Redis, RedisCluster
|
||||
from redis.asyncio.client import Pipeline
|
||||
|
||||
pipeline = Pipeline
|
||||
async_redis_client = Redis
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
pipeline = Any
|
||||
async_redis_client = Any
|
||||
Span = Any
|
||||
|
||||
|
||||
class RedisClusterCache(RedisCache):
    """RedisCache variant for Redis Cluster deployments.

    Per the module docstring above: the cluster client must be re-used across
    requests -- re-creating it per request adds ~3000ms latency.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Lazily-populated singleton clients; see init_async_client().
        self.redis_async_redis_cluster_client: Optional[RedisCluster] = None
        self.redis_sync_redis_cluster_client: Optional[RedisCluster] = None

    def init_async_client(self):
        """Return the shared async cluster client, creating it on first use."""
        from redis.asyncio import RedisCluster

        from .._redis import get_redis_async_client

        # Fast path: reuse the previously built cluster client.
        if self.redis_async_redis_cluster_client:
            return self.redis_async_redis_cluster_client

        _redis_client = get_redis_async_client(
            connection_pool=self.async_redis_conn_pool, **self.redis_kwargs
        )
        # Only genuine RedisCluster instances are memoized; any other client
        # type is returned uncached. NOTE(review): presumably
        # get_redis_async_client can return a non-cluster client -- confirm.
        if isinstance(_redis_client, RedisCluster):
            self.redis_async_redis_cluster_client = _redis_client

        return _redis_client

    def _run_redis_mget_operation(self, keys: List[str]) -> List[Any]:
        """
        Overrides `_run_redis_mget_operation` in redis_cache.py

        mget_nonatomic splits the fetch across hash slots, which a plain
        MGET cannot do on a cluster.
        """
        return self.redis_client.mget_nonatomic(keys=keys)  # type: ignore

    async def _async_run_redis_mget_operation(self, keys: List[str]) -> List[Any]:
        """
        Overrides `_async_run_redis_mget_operation` in redis_cache.py
        """
        async_redis_cluster_client = self.init_async_client()
        return await async_redis_cluster_client.mget_nonatomic(keys=keys)  # type: ignore

    async def test_connection(self) -> dict:
        """
        Test the Redis Cluster connection.

        Builds a fresh throwaway cluster client from self.redis_kwargs
        (so a stale cached client cannot mask configuration problems),
        pings it, then closes it.

        Returns:
            dict: {"status": "success" | "failed", "message": str, "error": Optional[str]}
        """
        try:
            import redis.asyncio as redis_async
            from redis.cluster import ClusterNode

            # Create ClusterNode objects from startup_nodes
            cluster_kwargs = self.redis_kwargs.copy()
            startup_nodes = cluster_kwargs.pop("startup_nodes", [])

            new_startup_nodes: List[ClusterNode] = []
            for item in startup_nodes:
                new_startup_nodes.append(ClusterNode(**item))

            # Create a fresh Redis Cluster client with current settings
            redis_client = redis_async.RedisCluster(
                startup_nodes=new_startup_nodes, **cluster_kwargs  # type: ignore
            )

            # Test the connection
            ping_result = await redis_client.ping()  # type: ignore[attr-defined, misc]

            # Close the connection
            await redis_client.aclose()  # type: ignore[attr-defined]

            if ping_result:
                return {
                    "status": "success",
                    "message": "Redis Cluster connection test successful",
                }
            else:
                return {
                    "status": "failed",
                    "message": "Redis Cluster ping returned False",
                }
        except Exception as e:
            from litellm._logging import verbose_logger

            verbose_logger.error(f"Redis Cluster connection test failed: {str(e)}")
            return {
                "status": "failed",
                "message": f"Redis Cluster connection failed: {str(e)}",
                "error": str(e),
            }
|
||||
@@ -0,0 +1,450 @@
|
||||
"""
|
||||
Redis Semantic Cache implementation for LiteLLM
|
||||
|
||||
The RedisSemanticCache provides semantic caching functionality using Redis as a backend.
|
||||
This cache stores responses based on the semantic similarity of prompts rather than
|
||||
exact matching, allowing for more flexible caching of LLM responses.
|
||||
|
||||
This implementation uses RedisVL's SemanticCache to find semantically similar prompts
|
||||
and their cached responses.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Tuple, cast
|
||||
|
||||
import litellm
|
||||
from litellm._logging import print_verbose
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
get_str_from_messages,
|
||||
)
|
||||
from litellm.types.utils import EmbeddingResponse
|
||||
|
||||
from .base_cache import BaseCache
|
||||
|
||||
|
||||
class RedisSemanticCache(BaseCache):
|
||||
"""
|
||||
Redis-backed semantic cache for LLM responses.
|
||||
|
||||
This cache uses vector similarity to find semantically similar prompts that have been
|
||||
previously sent to the LLM, allowing for cache hits even when prompts are not identical
|
||||
but carry similar meaning.
|
||||
"""
|
||||
|
||||
DEFAULT_REDIS_INDEX_NAME: str = "litellm_semantic_cache_index"
|
||||
|
||||
def __init__(
    self,
    host: Optional[str] = None,
    port: Optional[str] = None,
    password: Optional[str] = None,
    redis_url: Optional[str] = None,
    similarity_threshold: Optional[float] = None,
    embedding_model: str = "text-embedding-ada-002",
    index_name: Optional[str] = None,
    **kwargs,
):
    """
    Initialize the Redis Semantic Cache.

    Args:
        host: Redis host address
        port: Redis port
        password: Redis password
        redis_url: Full Redis URL (alternative to separate host/port/password)
        similarity_threshold: Threshold for semantic similarity (0.0 to 1.0)
            where 1.0 requires exact matches and 0.0 accepts any match
        embedding_model: Model to use for generating embeddings
        index_name: Name for the Redis index
        **kwargs: Accepted for interface compatibility; not used directly
            in this initializer

    Raises:
        ValueError: If similarity_threshold is not provided or required
            Redis connection information is missing
    """
    from redisvl.extensions.llmcache import SemanticCache
    from redisvl.utils.vectorize import CustomTextVectorizer

    if index_name is None:
        index_name = self.DEFAULT_REDIS_INDEX_NAME

    print_verbose(f"Redis semantic-cache initializing index - {index_name}")

    # Validate similarity threshold
    if similarity_threshold is None:
        raise ValueError("similarity_threshold must be provided, passed None")

    # Store configuration
    self.similarity_threshold = similarity_threshold

    # Convert similarity threshold [0,1] to distance threshold [0,2]
    # For cosine distance: 0 = most similar, 2 = least similar
    # While similarity: 1 = most similar, 0 = least similar
    self.distance_threshold = 1 - similarity_threshold
    self.embedding_model = embedding_model

    # Set up Redis connection
    if redis_url is None:
        try:
            # Attempt to use provided parameters or fallback to environment variables
            host = host or os.environ["REDIS_HOST"]
            port = port or os.environ["REDIS_PORT"]
            password = password or os.environ["REDIS_PASSWORD"]
        except KeyError as e:
            # Raise a more informative exception if any of the required keys are missing
            missing_var = e.args[0]
            raise ValueError(
                f"Missing required Redis configuration: {missing_var}. "
                f"Provide {missing_var} or redis_url."
            ) from e

        redis_url = f"redis://:{password}@{host}:{port}"

    print_verbose(f"Redis semantic-cache redis_url: {redis_url}")

    # Initialize the Redis vectorizer and cache.
    # CustomTextVectorizer lets redisvl call back into _get_embedding so
    # embeddings are produced by litellm rather than redisvl's defaults.
    cache_vectorizer = CustomTextVectorizer(self._get_embedding)

    # overwrite=False: reuse an existing index of the same name instead of
    # clobbering it.
    self.llmcache = SemanticCache(
        name=index_name,
        redis_url=redis_url,
        vectorizer=cache_vectorizer,
        distance_threshold=self.distance_threshold,
        overwrite=False,
    )
|
||||
|
||||
def _get_ttl(self, **kwargs) -> Optional[int]:
    """Extract an optional ``ttl`` from kwargs, coerced to int seconds.

    Returns:
        Optional[int]: the TTL in seconds, or None when no TTL was supplied.
    """
    raw_ttl = kwargs.get("ttl")
    return None if raw_ttl is None else int(raw_ttl)
|
||||
|
||||
def _get_embedding(self, prompt: str) -> List[float]:
    """
    Generate an embedding vector for the given prompt using the configured embedding model.

    Args:
        prompt: The text to generate an embedding for

    Returns:
        List[float]: The embedding vector
    """
    # Create an embedding from prompt.
    # no-store/no-cache: never cache the embedding call itself, or the
    # semantic cache would recurse into the response cache.
    embedding_response = cast(
        EmbeddingResponse,
        litellm.embedding(
            model=self.embedding_model,
            input=prompt,
            cache={"no-store": True, "no-cache": True},
        ),
    )
    embedding = embedding_response["data"][0]["embedding"]
    return embedding
|
||||
|
||||
def _get_cache_logic(self, cached_response: Any) -> Any:
    """Decode a raw cached payload into a Python object.

    bytes are utf-8 decoded, then JSON decoding is attempted, with
    ``ast.literal_eval`` as a fallback for repr()-style payloads.

    Args:
        cached_response: The raw cached response

    Returns:
        The parsed object, or None when the payload is None or unparseable.
    """
    if cached_response is None:
        return None

    # Normalize bytes to text before parsing.
    payload = (
        cached_response.decode("utf-8")
        if isinstance(cached_response, bytes)
        else cached_response
    )

    try:
        return json.loads(payload)
    except json.JSONDecodeError:
        pass

    try:
        return ast.literal_eval(payload)
    except (ValueError, SyntaxError) as parse_err:
        print_verbose(f"Error parsing cached response: {str(parse_err)}")
        return None
|
||||
|
||||
def set_cache(self, key: str, value: Any, **kwargs) -> None:
    """
    Store a value in the semantic cache.

    Args:
        key: The cache key (not directly used in semantic caching)
        value: The response value to cache
        **kwargs: Additional arguments including 'messages' for the prompt
            and optional 'ttl' for time-to-live
    """
    print_verbose(f"Redis semantic-cache set_cache, kwargs: {kwargs}")

    # Declared up front so the except-clause can reference it even when the
    # failure happens before value_str is assigned.
    value_str: Optional[str] = None
    try:
        # Extract the prompt from messages
        messages = kwargs.get("messages", [])
        if not messages:
            print_verbose("No messages provided for semantic caching")
            return

        prompt = get_str_from_messages(messages)
        value_str = str(value)

        # Get TTL and store in Redis semantic cache
        ttl = self._get_ttl(**kwargs)
        if ttl is not None:
            self.llmcache.store(prompt, value_str, ttl=int(ttl))
        else:
            self.llmcache.store(prompt, value_str)
    except Exception as e:
        # Best-effort cache write: failures are logged, never raised.
        print_verbose(
            f"Error setting {value_str or value} in the Redis semantic cache: {str(e)}"
        )
|
||||
|
||||
def get_cache(self, key: str, **kwargs) -> Any:
    """
    Retrieve a semantically similar cached response.

    Args:
        key: The cache key (not directly used in semantic caching)
        **kwargs: Additional arguments including 'messages' for the prompt

    Returns:
        The cached response if a semantically similar prompt is found, else None.
        Lookup errors are logged and treated as a cache miss (None).
    """
    print_verbose(f"Redis semantic-cache get_cache, kwargs: {kwargs}")

    try:
        # Extract the prompt from messages
        messages = kwargs.get("messages", [])
        if not messages:
            print_verbose("No messages provided for semantic cache lookup")
            return None

        prompt = get_str_from_messages(messages)
        # Check the cache for semantically similar prompts; llmcache applies
        # the distance threshold configured in __init__.
        results = self.llmcache.check(prompt=prompt)

        # Return None if no similar prompts found
        if not results:
            return None

        # Process the best matching result
        cache_hit = results[0]
        vector_distance = float(cache_hit["vector_distance"])

        # Convert vector distance back to similarity score
        # For cosine distance: 0 = most similar, 2 = least similar
        # While similarity: 1 = most similar, 0 = least similar
        similarity = 1 - vector_distance

        cached_prompt = cache_hit["prompt"]
        cached_response = cache_hit["response"]

        print_verbose(
            f"Cache hit: similarity threshold: {self.similarity_threshold}, "
            f"actual similarity: {similarity}, "
            f"current prompt: {prompt}, "
            f"cached prompt: {cached_prompt}"
        )

        return self._get_cache_logic(cached_response=cached_response)
    except Exception as e:
        print_verbose(f"Error retrieving from Redis semantic cache: {str(e)}")
        # Explicit miss on error (previously an implicit fall-through None).
        return None
|
||||
|
||||
async def _get_async_embedding(self, prompt: str, **kwargs) -> List[float]:
    """
    Asynchronously generate an embedding for the given prompt.

    When the configured embedding model is registered with the proxy's
    router, the request is routed through it (propagating user_api_key and
    trace metadata); otherwise litellm.aembedding is called directly.

    Args:
        prompt: The text to generate an embedding for
        **kwargs: Additional arguments that may contain metadata

    Returns:
        List[float]: The embedding vector

    Raises:
        ValueError: If embedding generation fails (original error chained).
    """
    # Imported here (not at module top) to avoid an import cycle with the
    # proxy server -- TODO confirm.
    from litellm.proxy.proxy_server import llm_model_list, llm_router

    # Route the embedding request through the proxy if appropriate
    router_model_names = (
        [m["model_name"] for m in llm_model_list]
        if llm_model_list is not None
        else []
    )

    try:
        if llm_router is not None and self.embedding_model in router_model_names:
            # Use the router for embedding generation
            user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
            embedding_response = await llm_router.aembedding(
                model=self.embedding_model,
                input=prompt,
                cache={"no-store": True, "no-cache": True},
                metadata={
                    "user_api_key": user_api_key,
                    "semantic-cache-embedding": True,
                    "trace_id": kwargs.get("metadata", {}).get("trace_id", None),
                },
            )
        else:
            # Generate embedding directly
            embedding_response = await litellm.aembedding(
                model=self.embedding_model,
                input=prompt,
                cache={"no-store": True, "no-cache": True},
            )

        # Extract and return the embedding vector
        return embedding_response["data"][0]["embedding"]
    except Exception as e:
        print_verbose(f"Error generating async embedding: {str(e)}")
        raise ValueError(f"Failed to generate embedding: {str(e)}") from e
|
||||
|
||||
async def async_set_cache(self, key: str, value: Any, **kwargs) -> None:
    """
    Asynchronously store a response in the semantic cache.

    The entry is keyed by the embedding of the request prompt (taken from
    ``kwargs['messages']``), not by ``key``.

    Args:
        key: Conventional cache key; unused by semantic caching.
        value: Response payload to cache (stringified before storage).
        **kwargs: Must contain ``messages``; may contain ``ttl``.
    """
    print_verbose(f"Async Redis semantic-cache set_cache, kwargs: {kwargs}")

    try:
        messages = kwargs.get("messages", [])
        if not messages:
            print_verbose("No messages provided for semantic caching")
            return

        prompt = get_str_from_messages(messages)
        value_str = str(value)

        # Embed the *prompt* — that embedding is what the semantic index
        # is keyed on; the response is stored alongside it.
        prompt_embedding = await self._get_async_embedding(prompt, **kwargs)

        # Only forward a ttl when one is configured.
        store_kwargs: dict = {"vector": prompt_embedding}
        ttl = self._get_ttl(**kwargs)
        if ttl is not None:
            store_kwargs["ttl"] = ttl

        await self.llmcache.astore(prompt, value_str, **store_kwargs)
    except Exception as e:
        print_verbose(f"Error in async_set_cache: {str(e)}")
|
||||
|
||||
async def async_get_cache(self, key: str, **kwargs) -> Any:
    """
    Asynchronously look up a semantically similar cached response.

    Records the achieved similarity score under
    ``kwargs['metadata']['semantic-similarity']`` (0.0 on miss or error).

    Args:
        key: Conventional cache key; unused by semantic caching.
        **kwargs: Must contain ``messages`` describing the prompt.

    Returns:
        The cached response on a semantic hit, otherwise None.
    """
    print_verbose(f"Async Redis semantic-cache get_cache, kwargs: {kwargs}")

    try:
        messages = kwargs.get("messages", [])
        if not messages:
            print_verbose("No messages provided for semantic cache lookup")
            kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
            return None

        prompt = get_str_from_messages(messages)

        # Embed the prompt and probe the semantic index with the vector.
        prompt_embedding = await self._get_async_embedding(prompt, **kwargs)
        results = await self.llmcache.acheck(prompt=prompt, vector=prompt_embedding)

        if not results:
            # Cache miss: record zero similarity for observability.
            kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
            return None

        best_match = results[0]
        vector_distance = float(best_match["vector_distance"])

        # Cosine distance runs 0 (identical) .. 2 (opposite); report it as a
        # similarity where 1 means identical and 0 means unrelated.
        similarity = 1 - vector_distance

        cached_prompt = best_match["prompt"]
        cached_response = best_match["response"]

        # Attach the score without clobbering caller-supplied metadata.
        kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity

        print_verbose(
            f"Cache hit: similarity threshold: {self.similarity_threshold}, "
            f"actual similarity: {similarity}, "
            f"current prompt: {prompt}, "
            f"cached prompt: {cached_prompt}"
        )

        return self._get_cache_logic(cached_response=cached_response)
    except Exception as e:
        print_verbose(f"Error in async_get_cache: {str(e)}")
        kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
|
||||
|
||||
async def _index_info(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get information about the Redis index.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Information about the Redis index
|
||||
"""
|
||||
aindex = await self.llmcache._get_async_index()
|
||||
return await aindex.info()
|
||||
|
||||
async def async_set_cache_pipeline(
    self, cache_list: List[Tuple[str, Any]], **kwargs
) -> None:
    """
    Asynchronously store multiple entries in the semantic cache.

    Each (key, value) pair is delegated to ``async_set_cache`` and all
    stores run concurrently.

    Args:
        cache_list: (key, value) pairs to store.
        **kwargs: Forwarded unchanged to each ``async_set_cache`` call.
    """
    try:
        pending = [
            self.async_set_cache(entry_key, entry_value, **kwargs)
            for entry_key, entry_value in cache_list
        ]
        await asyncio.gather(*pending)
    except Exception as e:
        print_verbose(f"Error in async_set_cache_pipeline: {str(e)}")
|
||||
@@ -0,0 +1,193 @@
|
||||
"""
|
||||
S3 Cache implementation
|
||||
|
||||
Has 4 methods:
|
||||
- set_cache
|
||||
- get_cache
|
||||
- async_set_cache (uses run_in_executor)
|
||||
- async_get_cache (uses run_in_executor)
|
||||
"""
|
||||
|
||||
import ast
|
||||
import asyncio
|
||||
import json
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
from litellm._logging import print_verbose, verbose_logger
|
||||
|
||||
from .base_cache import BaseCache
|
||||
|
||||
|
||||
class S3Cache(BaseCache):
    """S3-backed cache.

    Values are stored as JSON objects under ``<s3_path>/<key>`` (":" in cache
    keys is rewritten to "/"). The async variants delegate to the sync methods
    via ``run_in_executor`` because boto3 is synchronous. All failures are
    logged and swallowed — caching is best-effort and must never break the
    calling request.
    """

    def __init__(
        self,
        s3_bucket_name,
        s3_region_name=None,
        s3_api_version=None,
        s3_use_ssl: Optional[bool] = True,
        s3_verify=None,
        s3_endpoint_url=None,
        s3_aws_access_key_id=None,
        s3_aws_secret_access_key=None,
        s3_aws_session_token=None,
        s3_config=None,
        s3_path=None,
        **kwargs,
    ):
        # Imported lazily so boto3 is only required when S3 caching is used.
        import boto3

        self.bucket_name = s3_bucket_name
        # Normalize the optional prefix so joining with keys never doubles "/".
        self.key_prefix = s3_path.rstrip("/") + "/" if s3_path else ""
        # S3 client honoring any custom endpoint (e.g. MinIO, localstack).
        self.s3_client = boto3.client(
            "s3",
            region_name=s3_region_name,
            endpoint_url=s3_endpoint_url,
            api_version=s3_api_version,
            use_ssl=s3_use_ssl,
            verify=s3_verify,
            aws_access_key_id=s3_aws_access_key_id,
            aws_secret_access_key=s3_aws_secret_access_key,
            aws_session_token=s3_aws_session_token,
            config=s3_config,
            **kwargs,
        )

    def _to_s3_key(self, key: str) -> str:
        """Convert a cache key to an S3 object key (prefix + ":" -> "/")."""
        return self.key_prefix + key.replace(":", "/")

    def set_cache(self, key, value, **kwargs):
        """Store ``value`` (JSON-serialized) under ``key``.

        An optional ``ttl`` kwarg sets the object's ``Expires``/``Cache-Control``
        headers; without it the object is marked cacheable for one year.
        Exceptions are logged and swallowed.
        """
        try:
            print_verbose(f"LiteLLM SET Cache - S3. Key={key}. Value={value}")
            ttl = kwargs.get("ttl", None)
            # Convert value to JSON before storing in S3
            serialized_value = json.dumps(value)
            key = self._to_s3_key(key)

            if ttl is not None:
                cache_control = f"immutable, max-age={ttl}, s-maxage={ttl}"

                # Absolute expiry matching the requested ttl.
                expiration_time = datetime.now(timezone.utc) + timedelta(seconds=ttl)
                # Upload the data to S3 with the calculated expiration time
                self.s3_client.put_object(
                    Bucket=self.bucket_name,
                    Key=key,
                    Body=serialized_value,
                    Expires=expiration_time,
                    CacheControl=cache_control,
                    ContentType="application/json",
                    ContentLanguage="en",
                    ContentDisposition=f'inline; filename="{key}.json"',
                )
            else:
                # No ttl: mark as long-lived (one year).
                cache_control = "immutable, max-age=31536000, s-maxage=31536000"
                # Upload the data to S3 without specifying Expires
                self.s3_client.put_object(
                    Bucket=self.bucket_name,
                    Key=key,
                    Body=serialized_value,
                    CacheControl=cache_control,
                    ContentType="application/json",
                    ContentLanguage="en",
                    ContentDisposition=f'inline; filename="{key}.json"',
                )
        except Exception as e:
            print_verbose(f"S3 Caching: set_cache() - Got exception from S3: {e}")

    async def async_set_cache(self, key, value, **kwargs):
        """
        Asynchronously set cache using run_in_executor to avoid blocking the event loop.
        Compatible with Python 3.8+.
        """
        try:
            verbose_logger.debug(f"Set ASYNC S3 Cache: Key={key}. Value={value}")
            # get_running_loop() avoids the DeprecationWarning (and the
            # wrong-loop hazard) of get_event_loop() inside a coroutine.
            loop = asyncio.get_running_loop()
            func = partial(self.set_cache, key, value, **kwargs)
            await loop.run_in_executor(None, func)
        except Exception as e:
            verbose_logger.error(
                f"S3 Caching: async_set_cache() - Got exception from S3: {e}"
            )

    def get_cache(self, key, **kwargs):
        """Fetch and deserialize a cached value; returns None on miss/expiry/error."""
        import botocore

        try:
            key = self._to_s3_key(key)

            print_verbose(f"Get S3 Cache: key: {key}")
            # Download the data from S3
            cached_response = self.s3_client.get_object(
                Bucket=self.bucket_name, Key=key
            )

            if cached_response is not None:
                if "Expires" in cached_response:
                    # Honor the object's Expires header: treat stale objects
                    # as cache misses.
                    expires_time = cached_response["Expires"]
                    current_time = datetime.now(expires_time.tzinfo)

                    if current_time > expires_time:
                        return None

                # Body is bytes; decode then parse. json first, falling back
                # to literal_eval for legacy entries written with str(dict).
                cached_response = (
                    cached_response["Body"].read().decode("utf-8")
                )  # Convert bytes to string
                try:
                    cached_response = json.loads(
                        cached_response
                    )  # Convert string to dictionary
                except Exception:
                    cached_response = ast.literal_eval(cached_response)
                if not isinstance(cached_response, dict):
                    cached_response = dict(cached_response)
                verbose_logger.debug(
                    f"Got S3 Cache: key: {key}, cached_response {cached_response}. Type Response {type(cached_response)}"
                )

            return cached_response
        except botocore.exceptions.ClientError as e:  # type: ignore
            if e.response["Error"]["Code"] == "NoSuchKey":
                verbose_logger.debug(
                    f"S3 Cache: The specified key '{key}' does not exist in the S3 bucket."
                )
                return None
            # Previously other client errors (auth, bucket, throttling) were
            # silently swallowed; log them so operators can see the failure.
            verbose_logger.error(
                f"S3 Caching: get_cache() - Got exception from S3: {e}"
            )
            return None

        except Exception as e:
            verbose_logger.error(
                f"S3 Caching: get_cache() - Got exception from S3: {e}"
            )

    async def async_get_cache(self, key, **kwargs):
        """
        Asynchronously get cache using run_in_executor to avoid blocking the event loop.
        Compatible with Python 3.8+.
        """
        try:
            verbose_logger.debug(f"Get ASYNC S3 Cache: key: {key}")
            # See async_set_cache: prefer get_running_loop() in coroutines.
            loop = asyncio.get_running_loop()
            func = partial(self.get_cache, key, **kwargs)
            result = await loop.run_in_executor(None, func)
            return result
        except Exception as e:
            verbose_logger.error(
                f"S3 Caching: async_get_cache() - Got exception from S3: {e}"
            )
            return None

    def flush_cache(self):
        """No-op: S3 objects are left in place (no bulk flush implemented)."""
        pass

    async def disconnect(self):
        """No-op: boto3 clients hold no persistent connection to close."""
        pass

    async def async_set_cache_pipeline(self, cache_list, **kwargs):
        """Store multiple (key, value) pairs concurrently."""
        tasks = []
        for val in cache_list:
            tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
        await asyncio.gather(*tasks)
|
||||
@@ -0,0 +1,4 @@
|
||||
Logic specific for `litellm.completion`.
|
||||
|
||||
Includes:
|
||||
- Bridge for transforming completion requests to responses api requests
|
||||
@@ -0,0 +1,3 @@
|
||||
from .litellm_responses_transformation import responses_api_bridge
|
||||
|
||||
__all__ = ["responses_api_bridge"]
|
||||
@@ -0,0 +1,3 @@
|
||||
from .handler import responses_api_bridge
|
||||
|
||||
__all__ = ["responses_api_bridge"]
|
||||
@@ -0,0 +1,331 @@
|
||||
"""
|
||||
Handler for transforming /chat/completions api requests to litellm.responses requests
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Coroutine, Optional, Union
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from litellm.types.llms.openai import ResponsesAPIResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm import CustomStreamWrapper, LiteLLMLoggingObj, ModelResponse
|
||||
|
||||
|
||||
class ResponsesToCompletionBridgeHandlerInputKwargs(TypedDict):
    """Validated keyword arguments for ResponsesToCompletionBridgeHandler.

    Produced by ``validate_input_kwargs``; every key is required and
    type-checked before the bridge proceeds.
    """

    # Model identifier forwarded to litellm.responses / aresponses.
    model: str
    # Chat-completions style message list.
    messages: list
    # Optional/provider params; may carry the "stream" flag.
    optional_params: dict
    # litellm-internal params; fallback source for the "stream" flag.
    litellm_params: dict
    # Request headers forwarded to the provider.
    headers: dict
    # Response object populated by the transformation handler.
    model_response: "ModelResponse"
    # Logging handle for this request.
    logging_obj: "LiteLLMLoggingObj"
    # Provider slug used for stream wrapping and post-processing.
    custom_llm_provider: str
|
||||
|
||||
|
||||
class ResponsesToCompletionBridgeHandler:
    """Bridges /chat/completions-style calls onto litellm's responses API.

    ``completion``/``acompletion`` transform chat-completions arguments into a
    responses API request, invoke ``litellm.responses``/``aresponses``, and
    transform the result back into a ``ModelResponse`` — or, when streaming was
    requested, wrap the stream in a ``CustomStreamWrapper``.
    """

    def __init__(self):
        # Imported lazily to avoid a circular import at module load time.
        from .transformation import LiteLLMResponsesTransformationHandler

        super().__init__()
        # Performs the request/response translation between the two APIs.
        self.transformation_handler = LiteLLMResponsesTransformationHandler()

    @staticmethod
    def _resolve_stream_flag(optional_params: dict, litellm_params: dict) -> bool:
        """Return the effective ``stream`` flag.

        ``optional_params`` takes precedence; ``litellm_params`` is the
        fallback; the default is False.
        """
        stream = optional_params.get("stream")
        if stream is None:
            stream = litellm_params.get("stream", False)
        return bool(stream)

    @staticmethod
    def _coerce_response_object(
        response_obj: Any,
        hidden_params: Optional[dict],
    ) -> "ResponsesAPIResponse":
        """Normalize a stream's completed payload into a ResponsesAPIResponse.

        Accepts an already-typed response or a plain dict (validated first,
        falling back to ``model_construct`` for lenient construction).
        ``hidden_params`` are merged onto the response without overwriting
        entries the response already carries.

        Raises:
            ValueError: If ``response_obj`` is neither a ResponsesAPIResponse
                nor a dict.
        """
        if isinstance(response_obj, ResponsesAPIResponse):
            response = response_obj
        elif isinstance(response_obj, dict):
            try:
                response = ResponsesAPIResponse(**response_obj)
            except Exception:
                # Validation failed; build the model without validation.
                response = ResponsesAPIResponse.model_construct(**response_obj)
        else:
            raise ValueError("Unexpected responses stream payload")

        if hidden_params:
            existing = getattr(response, "_hidden_params", None)
            if not isinstance(existing, dict) or not existing:
                # No usable dict present — attach a copy of hidden_params.
                setattr(response, "_hidden_params", dict(hidden_params))
            else:
                # Merge, preserving values already on the response.
                for key, value in hidden_params.items():
                    existing.setdefault(key, value)
        return response

    def _collect_response_from_stream(self, stream_iter: Any) -> "ResponsesAPIResponse":
        """Drain a sync responses stream and return its completed response.

        Raises:
            ValueError: If the stream finished without producing a completed
                response, or the completed payload is invalid.
        """
        # Exhaust the iterator; the stream object accumulates the final
        # response on itself as a side effect of iteration.
        for _ in stream_iter:
            pass

        completed = getattr(stream_iter, "completed_response", None)
        response_obj = getattr(completed, "response", None) if completed else None
        if response_obj is None:
            raise ValueError("Stream ended without a completed response")

        hidden_params = getattr(stream_iter, "_hidden_params", None)
        response = self._coerce_response_object(response_obj, hidden_params)
        if not isinstance(response, ResponsesAPIResponse):
            raise ValueError("Stream completed response is invalid")
        return response

    async def _collect_response_from_stream_async(
        self, stream_iter: Any
    ) -> "ResponsesAPIResponse":
        """Async counterpart of ``_collect_response_from_stream``."""
        # Exhaust the async iterator; the final response accumulates on the
        # stream object itself.
        async for _ in stream_iter:
            pass

        completed = getattr(stream_iter, "completed_response", None)
        response_obj = getattr(completed, "response", None) if completed else None
        if response_obj is None:
            raise ValueError("Stream ended without a completed response")

        hidden_params = getattr(stream_iter, "_hidden_params", None)
        response = self._coerce_response_object(response_obj, hidden_params)
        if not isinstance(response, ResponsesAPIResponse):
            raise ValueError("Stream completed response is invalid")
        return response

    def validate_input_kwargs(
        self, kwargs: dict
    ) -> ResponsesToCompletionBridgeHandlerInputKwargs:
        """Validate and narrow raw kwargs into the typed input dict.

        Raises:
            ValueError: If any required key is missing or has the wrong type.
        """
        # Imported here to avoid import cycles with the litellm package root.
        from litellm import LiteLLMLoggingObj
        from litellm.types.utils import ModelResponse

        model = kwargs.get("model")
        if model is None or not isinstance(model, str):
            raise ValueError("model is required")

        custom_llm_provider = kwargs.get("custom_llm_provider")
        if custom_llm_provider is None or not isinstance(custom_llm_provider, str):
            raise ValueError("custom_llm_provider is required")

        messages = kwargs.get("messages")
        if messages is None or not isinstance(messages, list):
            raise ValueError("messages is required")

        optional_params = kwargs.get("optional_params")
        if optional_params is None or not isinstance(optional_params, dict):
            raise ValueError("optional_params is required")

        litellm_params = kwargs.get("litellm_params")
        if litellm_params is None or not isinstance(litellm_params, dict):
            raise ValueError("litellm_params is required")

        headers = kwargs.get("headers")
        if headers is None or not isinstance(headers, dict):
            raise ValueError("headers is required")

        model_response = kwargs.get("model_response")
        if model_response is None or not isinstance(model_response, ModelResponse):
            raise ValueError("model_response is required")

        logging_obj = kwargs.get("logging_obj")
        if logging_obj is None or not isinstance(logging_obj, LiteLLMLoggingObj):
            raise ValueError("logging_obj is required")

        return ResponsesToCompletionBridgeHandlerInputKwargs(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
            model_response=model_response,
            logging_obj=logging_obj,
            custom_llm_provider=custom_llm_provider,
        )

    def completion(
        self, *args, **kwargs
    ) -> Union[
        Coroutine[Any, Any, Union["ModelResponse", "CustomStreamWrapper"]],
        "ModelResponse",
        "CustomStreamWrapper",
    ]:
        """Execute a chat-completion via the responses API (sync entry point).

        When ``acompletion=True`` is present, delegates to ``acompletion`` and
        returns the coroutine. Otherwise transforms the request, calls
        ``litellm.responses``, and returns either a transformed
        ``ModelResponse`` or a ``CustomStreamWrapper`` depending on the
        resolved ``stream`` flag.

        Raises:
            ValueError: If required kwargs are missing/invalid, or a
                non-streaming call's stream ends without a completed response.
        """
        if kwargs.get("acompletion") is True:
            # Async path requested — hand back the coroutine unawaited.
            return self.acompletion(**kwargs)

        # Imported lazily to avoid circular imports.
        from litellm import responses
        from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper

        validated_kwargs = self.validate_input_kwargs(kwargs)
        model = validated_kwargs["model"]
        messages = validated_kwargs["messages"]
        optional_params = validated_kwargs["optional_params"]
        litellm_params = validated_kwargs["litellm_params"]
        headers = validated_kwargs["headers"]
        model_response = validated_kwargs["model_response"]
        logging_obj = validated_kwargs["logging_obj"]
        custom_llm_provider = validated_kwargs["custom_llm_provider"]

        # Translate chat-completions arguments into a responses API request.
        request_data = self.transformation_handler.transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
            litellm_logging_obj=logging_obj,
            client=kwargs.get("client"),
        )

        result = responses(
            **request_data,
        )

        stream = self._resolve_stream_flag(optional_params, litellm_params)
        if isinstance(result, ResponsesAPIResponse):
            # Non-streaming provider response: translate directly.
            return self.transformation_handler.transform_response(
                model=model,
                raw_response=result,
                model_response=model_response,
                logging_obj=logging_obj,
                request_data=request_data,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                encoding=kwargs.get("encoding"),
                api_key=kwargs.get("api_key"),
                json_mode=kwargs.get("json_mode"),
            )
        elif not stream:
            # Provider streamed, but caller did not ask for a stream:
            # drain the stream and translate the completed response.
            responses_api_response = self._collect_response_from_stream(result)
            return self.transformation_handler.transform_response(
                model=model,
                raw_response=responses_api_response,
                model_response=model_response,
                logging_obj=logging_obj,
                request_data=request_data,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                encoding=kwargs.get("encoding"),
                api_key=kwargs.get("api_key"),
                json_mode=kwargs.get("json_mode"),
            )
        else:
            # Caller asked for a stream: wrap the responses stream so it
            # yields chat-completions chunks.
            completion_stream = self.transformation_handler.get_model_response_iterator(
                streaming_response=result,  # type: ignore
                sync_stream=True,
                json_mode=kwargs.get("json_mode"),
            )
            streamwrapper = CustomStreamWrapper(
                completion_stream=completion_stream,
                model=model,
                custom_llm_provider=custom_llm_provider,
                logging_obj=logging_obj,
            )
            return self._apply_post_stream_processing(
                streamwrapper, model, custom_llm_provider
            )

    async def acompletion(
        self, *args, **kwargs
    ) -> Union["ModelResponse", "CustomStreamWrapper"]:
        """Async counterpart of ``completion``; calls ``litellm.aresponses``.

        Raises:
            ValueError: If required kwargs are missing/invalid, or a
                non-streaming call's stream ends without a completed response.
        """
        # Imported lazily to avoid circular imports.
        from litellm import aresponses
        from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper

        validated_kwargs = self.validate_input_kwargs(kwargs)
        model = validated_kwargs["model"]
        messages = validated_kwargs["messages"]
        optional_params = validated_kwargs["optional_params"]
        litellm_params = validated_kwargs["litellm_params"]
        headers = validated_kwargs["headers"]
        model_response = validated_kwargs["model_response"]
        logging_obj = validated_kwargs["logging_obj"]
        custom_llm_provider = validated_kwargs["custom_llm_provider"]

        # NOTE(review): this try/except re-raises unchanged and is redundant;
        # kept as-is to avoid any behavior change in this documentation pass.
        try:
            request_data = self.transformation_handler.transform_request(
                model=model,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                headers=headers,
                litellm_logging_obj=logging_obj,
            )
        except Exception as e:
            raise e

        result = await aresponses(
            **request_data,
            aresponses=True,
        )

        stream = self._resolve_stream_flag(optional_params, litellm_params)
        if isinstance(result, ResponsesAPIResponse):
            # Non-streaming provider response: translate directly.
            return self.transformation_handler.transform_response(
                model=model,
                raw_response=result,
                model_response=model_response,
                logging_obj=logging_obj,
                request_data=request_data,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                encoding=kwargs.get("encoding"),
                api_key=kwargs.get("api_key"),
                json_mode=kwargs.get("json_mode"),
            )
        elif not stream:
            # Provider streamed, but caller did not ask for a stream:
            # drain the async stream and translate the completed response.
            responses_api_response = await self._collect_response_from_stream_async(
                result
            )
            return self.transformation_handler.transform_response(
                model=model,
                raw_response=responses_api_response,
                model_response=model_response,
                logging_obj=logging_obj,
                request_data=request_data,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                encoding=kwargs.get("encoding"),
                api_key=kwargs.get("api_key"),
                json_mode=kwargs.get("json_mode"),
            )
        else:
            # Caller asked for a stream: wrap the async responses stream so
            # it yields chat-completions chunks.
            completion_stream = self.transformation_handler.get_model_response_iterator(
                streaming_response=result,  # type: ignore
                sync_stream=False,
                json_mode=kwargs.get("json_mode"),
            )
            streamwrapper = CustomStreamWrapper(
                completion_stream=completion_stream,
                model=model,
                custom_llm_provider=custom_llm_provider,
                logging_obj=logging_obj,
            )
            return self._apply_post_stream_processing(
                streamwrapper, model, custom_llm_provider
            )

    @staticmethod
    def _apply_post_stream_processing(
        stream: "CustomStreamWrapper",
        model: str,
        custom_llm_provider: str,
    ) -> Any:
        """Apply provider-specific post-stream processing if available."""
        from litellm.types.utils import LlmProviders
        from litellm.utils import ProviderConfigManager

        try:
            provider_config = ProviderConfigManager.get_provider_chat_config(
                model=model, provider=LlmProviders(custom_llm_provider)
            )
        except (ValueError, KeyError):
            # Unknown provider slug — return the stream unmodified.
            return stream

        if provider_config is not None:
            return provider_config.post_stream_processing(stream)
        return stream
|
||||
|
||||
|
||||
# Module-level singleton used by litellm.completion to bridge chat-completions
# requests onto the responses API.
responses_api_bridge = ResponsesToCompletionBridgeHandler()
|
||||
File diff suppressed because it is too large
Load Diff
1530
llm-gateway-competitors/litellm-wheel-src/litellm/constants.py
Normal file
1530
llm-gateway-competitors/litellm-wheel-src/litellm/constants.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,241 @@
|
||||
# Container Files API
|
||||
|
||||
This module provides a unified interface for container file operations across multiple LLM providers (OpenAI, Azure OpenAI, etc.).
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
endpoints.json # Declarative endpoint definitions
|
||||
↓
|
||||
endpoint_factory.py # Auto-generates SDK functions
|
||||
↓
|
||||
container_handler.py # Generic HTTP handler
|
||||
↓
|
||||
BaseContainerConfig # Provider-specific transformations
|
||||
├── OpenAIContainerConfig
|
||||
└── AzureContainerConfig (example)
|
||||
```
|
||||
|
||||
## Files Overview
|
||||
|
||||
| File | Purpose |
|
||||
|------|---------|
|
||||
| `endpoints.json` | **Single source of truth** - Defines all container file endpoints |
|
||||
| `endpoint_factory.py` | Auto-generates SDK functions (`list_container_files`, etc.) |
|
||||
| `main.py` | Core container operations (create, list, retrieve, delete containers) |
|
||||
| `utils.py` | Request parameter utilities |
|
||||
|
||||
## Adding a New Endpoint
|
||||
|
||||
To add a new container file endpoint (e.g., `get_container_file_content`):
|
||||
|
||||
### Step 1: Add to `endpoints.json`
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "get_container_file_content",
|
||||
"async_name": "aget_container_file_content",
|
||||
"path": "/containers/{container_id}/files/{file_id}/content",
|
||||
"method": "GET",
|
||||
"path_params": ["container_id", "file_id"],
|
||||
"query_params": [],
|
||||
"response_type": "ContainerFileContentResponse"
|
||||
}
|
||||
```
|
||||
|
||||
### Step 2: Add Response Type (if new)
|
||||
|
||||
In `litellm/types/containers/main.py`:
|
||||
|
||||
```python
|
||||
class ContainerFileContentResponse(BaseModel):
|
||||
"""Response for file content download."""
|
||||
content: bytes
|
||||
# ... other fields
|
||||
```
|
||||
|
||||
### Step 3: Register Response Type
|
||||
|
||||
In `litellm/llms/custom_httpx/container_handler.py`, add to `RESPONSE_TYPES`:
|
||||
|
||||
```python
|
||||
RESPONSE_TYPES = {
|
||||
# ... existing types
|
||||
"ContainerFileContentResponse": ContainerFileContentResponse,
|
||||
}
|
||||
```
|
||||
|
||||
### Step 4: Update Router (one-time setup)
|
||||
|
||||
In `litellm/router.py`, add the call_type to the factory_function Literal and `_init_containers_api_endpoints` condition.
|
||||
|
||||
In `litellm/proxy/route_llm_request.py`, add to the route mappings and skip-model-routing lists.
|
||||
|
||||
### Step 5: Update Proxy Handler Factory (if new path params)
|
||||
|
||||
If your endpoint has a new combination of path parameters, add a handler in `litellm/proxy/container_endpoints/handler_factory.py`:
|
||||
|
||||
```python
|
||||
elif path_params == ["container_id", "file_id", "new_param"]:
|
||||
async def handler(...):
|
||||
# handler implementation
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Adding a New Provider (e.g., Azure OpenAI)
|
||||
|
||||
### Step 1: Create Provider Config
|
||||
|
||||
Create `litellm/llms/azure/containers/transformation.py`:
|
||||
|
||||
```python
|
||||
from typing import Dict, Optional, Tuple, Any
|
||||
import httpx
|
||||
|
||||
from litellm.llms.base_llm.containers.transformation import BaseContainerConfig
|
||||
from litellm.types.containers.main import (
|
||||
ContainerFileListResponse,
|
||||
ContainerFileObject,
|
||||
DeleteContainerFileResponse,
|
||||
)
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
|
||||
|
||||
class AzureContainerConfig(BaseContainerConfig):
|
||||
"""Configuration class for Azure OpenAI container API."""
|
||||
|
||||
def get_supported_openai_params(self) -> list:
|
||||
return ["name", "expires_after", "file_ids", "extra_headers"]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
container_create_optional_params,
|
||||
drop_params: bool,
|
||||
) -> Dict:
|
||||
return dict(container_create_optional_params)
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
api_key: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""Azure uses api-key header instead of Bearer token."""
|
||||
import litellm
|
||||
|
||||
api_key = (
|
||||
api_key
|
||||
or litellm.azure_key
|
||||
or get_secret_str("AZURE_API_KEY")
|
||||
)
|
||||
headers["api-key"] = api_key
|
||||
return headers
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
litellm_params: dict,
|
||||
) -> str:
|
||||
"""
|
||||
Azure format:
|
||||
https://{resource}.openai.azure.com/openai/containers?api-version=2024-xx
|
||||
"""
|
||||
if api_base is None:
|
||||
raise ValueError("api_base is required for Azure")
|
||||
|
||||
api_version = litellm_params.get("api_version", "2024-02-15-preview")
|
||||
return f"{api_base.rstrip('/')}/openai/containers?api-version={api_version}"
|
||||
|
||||
# Implement remaining abstract methods from BaseContainerConfig:
|
||||
# - transform_container_create_request
|
||||
# - transform_container_create_response
|
||||
# - transform_container_list_request
|
||||
# - transform_container_list_response
|
||||
# - transform_container_retrieve_request
|
||||
# - transform_container_retrieve_response
|
||||
# - transform_container_delete_request
|
||||
# - transform_container_delete_response
|
||||
# - transform_container_file_list_request
|
||||
# - transform_container_file_list_response
|
||||
```
|
||||
|
||||
### Step 2: Register Provider Config
|
||||
|
||||
In `litellm/utils.py`, find `ProviderConfigManager.get_provider_container_config()` and add:
|
||||
|
||||
```python
|
||||
@staticmethod
|
||||
def get_provider_container_config(
|
||||
provider: LlmProviders,
|
||||
) -> Optional[BaseContainerConfig]:
|
||||
if provider == LlmProviders.OPENAI:
|
||||
from litellm.llms.openai.containers.transformation import OpenAIContainerConfig
|
||||
return OpenAIContainerConfig()
|
||||
elif provider == LlmProviders.AZURE:
|
||||
from litellm.llms.azure.containers.transformation import AzureContainerConfig
|
||||
return AzureContainerConfig()
|
||||
return None
|
||||
```
|
||||
|
||||
### Step 3: Test the New Provider
|
||||
|
||||
```bash
|
||||
# Create container via Azure
|
||||
curl -X POST "http://localhost:4000/v1/containers" \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-H "custom-llm-provider: azure" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"name": "My Azure Container"}'
|
||||
|
||||
# List container files via Azure
|
||||
curl -X GET "http://localhost:4000/v1/containers/cntr_123/files" \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-H "custom-llm-provider: azure"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## How Provider Selection Works
|
||||
|
||||
1. **Proxy receives request** with `custom-llm-provider` header/query/body
|
||||
2. **Router calls** `ProviderConfigManager.get_provider_container_config(provider)`
|
||||
3. **Generic handler** uses the provider config for:
|
||||
- URL construction (`get_complete_url`)
|
||||
- Authentication (`validate_environment`)
|
||||
- Request/response transformation
|
||||
|
||||
---
|
||||
|
||||
## Testing
|
||||
|
||||
Run the container API tests:
|
||||
|
||||
```bash
|
||||
cd <path-to-your-litellm-checkout>
|
||||
python -m pytest tests/test_litellm/containers/ -v
|
||||
```
|
||||
|
||||
Test via proxy:
|
||||
|
||||
```bash
|
||||
# Start proxy
|
||||
cd litellm/proxy && python proxy_cli.py --config proxy_config.yaml --port 4000
|
||||
|
||||
# Test endpoints
|
||||
curl -X GET "http://localhost:4000/v1/containers/cntr_123/files" \
|
||||
-H "Authorization: Bearer sk-1234"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Endpoint Reference
|
||||
|
||||
| Endpoint | Method | Path |
|
||||
|----------|--------|------|
|
||||
| List container files | GET | `/v1/containers/{container_id}/files` |
|
||||
| Retrieve container file | GET | `/v1/containers/{container_id}/files/{file_id}` |
|
||||
| Delete container file | DELETE | `/v1/containers/{container_id}/files/{file_id}` |
|
||||
|
||||
See `endpoints.json` for the complete list.
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
"""Container management functions for LiteLLM."""
|
||||
|
||||
# Auto-generated container file functions from endpoints.json
|
||||
from .endpoint_factory import (
|
||||
adelete_container_file,
|
||||
alist_container_files,
|
||||
aretrieve_container_file,
|
||||
aretrieve_container_file_content,
|
||||
delete_container_file,
|
||||
list_container_files,
|
||||
retrieve_container_file,
|
||||
retrieve_container_file_content,
|
||||
)
|
||||
from .main import (
|
||||
acreate_container,
|
||||
adelete_container,
|
||||
alist_containers,
|
||||
aretrieve_container,
|
||||
create_container,
|
||||
delete_container,
|
||||
list_containers,
|
||||
retrieve_container,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Core container operations
|
||||
"acreate_container",
|
||||
"adelete_container",
|
||||
"alist_containers",
|
||||
"aretrieve_container",
|
||||
"create_container",
|
||||
"delete_container",
|
||||
"list_containers",
|
||||
"retrieve_container",
|
||||
# Container file operations (auto-generated from endpoints.json)
|
||||
"adelete_container_file",
|
||||
"alist_container_files",
|
||||
"aretrieve_container_file",
|
||||
"aretrieve_container_file_content",
|
||||
"delete_container_file",
|
||||
"list_container_files",
|
||||
"retrieve_container_file",
|
||||
"retrieve_container_file_content",
|
||||
]
|
||||
@@ -0,0 +1,232 @@
|
||||
"""
|
||||
Factory for generating container SDK functions from JSON config.
|
||||
|
||||
This module reads endpoints.json and dynamically generates SDK functions
|
||||
that use the generic container handler.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import json
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Type
|
||||
|
||||
import litellm
|
||||
from litellm.constants import request_timeout as DEFAULT_REQUEST_TIMEOUT
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.base_llm.containers.transformation import BaseContainerConfig
|
||||
from litellm.llms.custom_httpx.container_handler import generic_container_handler
|
||||
from litellm.types.containers.main import (
|
||||
ContainerFileListResponse,
|
||||
ContainerFileObject,
|
||||
DeleteContainerFileResponse,
|
||||
)
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.utils import ProviderConfigManager, client
|
||||
|
||||
# Response type mapping
|
||||
RESPONSE_TYPES: Dict[str, Type] = {
|
||||
"ContainerFileListResponse": ContainerFileListResponse,
|
||||
"ContainerFileObject": ContainerFileObject,
|
||||
"DeleteContainerFileResponse": DeleteContainerFileResponse,
|
||||
}
|
||||
|
||||
|
||||
def _load_endpoints_config() -> Dict:
    """Read and parse the container endpoints config (endpoints.json) shipped next to this module."""
    config_file = Path(__file__).parent / "endpoints.json"
    return json.loads(config_file.read_text())
|
||||
|
||||
|
||||
def create_sync_endpoint_function(endpoint_config: Dict) -> Callable:
    """
    Build a synchronous SDK function for one container endpoint.

    The endpoint's name, response type, and path parameters come from the
    JSON config; the actual HTTP call is delegated to the shared
    ``generic_container_handler``.
    """
    endpoint_name = endpoint_config["name"]
    response_cls = RESPONSE_TYPES.get(endpoint_config["response_type"])
    path_param_names = endpoint_config.get("path_params", [])

    @client
    def endpoint_func(
        timeout: int = 600,
        custom_llm_provider: Literal["openai"] = "openai",
        extra_headers: Optional[Dict[str, Any]] = None,
        extra_query: Optional[Dict[str, Any]] = None,
        extra_body: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        # Snapshot of the call arguments, used for exception reporting below.
        call_args = locals()
        try:
            logging_obj: LiteLLMLoggingObj = kwargs.pop("litellm_logging_obj")
            call_id: Optional[str] = kwargs.get("litellm_call_id")
            is_async_call = kwargs.pop("async_call", False) is True

            # Short-circuit: honor a caller-supplied mock response.
            mock = kwargs.get("mock_response")
            if mock is not None:
                if isinstance(mock, str):
                    mock = json.loads(mock)
                return response_cls(**mock) if response_cls else mock

            # Resolve the provider-specific container config.
            generic_params = GenericLiteLLMParams(**kwargs)
            provider_config: Optional[
                BaseContainerConfig
            ] = ProviderConfigManager.get_provider_container_config(
                provider=litellm.LlmProviders(custom_llm_provider),
            )
            if provider_config is None:
                raise ValueError(
                    f"Container provider config not found for: {custom_llm_provider}"
                )

            # Path params are the endpoint-specific values worth logging.
            logged_params = {k: kwargs.get(k) for k in path_param_names if k in kwargs}

            # Pre-call logging setup.
            logging_obj.update_environment_variables(
                model="",
                optional_params=logged_params,
                litellm_params={"litellm_call_id": call_id},
                custom_llm_provider=custom_llm_provider,
            )

            # Delegate the request to the shared generic handler.
            return generic_container_handler.handle(
                endpoint_name=endpoint_name,
                container_provider_config=provider_config,
                litellm_params=generic_params,
                logging_obj=logging_obj,
                extra_headers=extra_headers,
                extra_query=extra_query,
                timeout=timeout or DEFAULT_REQUEST_TIMEOUT,
                _is_async=is_async_call,
                **kwargs,
            )

        except Exception as e:
            raise litellm.exception_type(
                model="",
                custom_llm_provider=custom_llm_provider,
                original_exception=e,
                completion_kwargs=call_args,
                extra_kwargs=kwargs,
            )

    return endpoint_func
|
||||
|
||||
|
||||
def create_async_endpoint_function(
    sync_func: Callable,
    endpoint_config: Dict,
) -> Callable:
    """
    Create an async SDK function that wraps *sync_func*.

    The sync function runs in the default thread-pool executor (with the
    current contextvars context preserved) so the event loop is never
    blocked; when the sync call itself returns a coroutine (the handler's
    async path, triggered by ``async_call=True``), it is awaited here.
    """

    @client
    async def async_endpoint_func(
        timeout: int = 600,
        custom_llm_provider: Literal["openai"] = "openai",
        extra_headers: Optional[Dict[str, Any]] = None,
        extra_query: Optional[Dict[str, Any]] = None,
        extra_body: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        local_vars = locals()
        try:
            # FIX: get_running_loop() is the correct call inside a coroutine;
            # get_event_loop() is deprecated in this context since Python 3.10
            # and returns the running loop here anyway.
            loop = asyncio.get_running_loop()
            kwargs["async_call"] = True

            func = partial(
                sync_func,
                timeout=timeout,
                custom_llm_provider=custom_llm_provider,
                extra_headers=extra_headers,
                extra_query=extra_query,
                extra_body=extra_body,
                **kwargs,
            )

            # Preserve contextvars (request-scoped state) across the
            # executor boundary.
            ctx = contextvars.copy_context()
            func_with_context = partial(ctx.run, func)
            init_response = await loop.run_in_executor(None, func_with_context)

            # The sync wrapper returns a coroutine when async_call=True.
            if asyncio.iscoroutine(init_response):
                return await init_response
            return init_response
        except Exception as e:
            raise litellm.exception_type(
                model="",
                custom_llm_provider=custom_llm_provider,
                original_exception=e,
                completion_kwargs=local_vars,
                extra_kwargs=kwargs,
            )

    return async_endpoint_func
|
||||
|
||||
|
||||
def generate_container_endpoints() -> Dict[str, Callable]:
    """
    Generate all container endpoint functions from the JSON config.

    Returns:
        Dict mapping both sync and async function names to their
        implementations.
    """
    config = _load_endpoints_config()
    registry: Dict[str, Callable] = {}

    for spec in config["endpoints"]:
        # Sync function first; the async variant wraps it.
        sync_fn = create_sync_endpoint_function(spec)
        registry[spec["name"]] = sync_fn
        registry[spec["async_name"]] = create_async_endpoint_function(sync_fn, spec)

    return registry
|
||||
|
||||
|
||||
def get_all_endpoint_names() -> List[str]:
    """Get all endpoint names (sync and async) from config."""
    config = _load_endpoints_config()
    return [
        name
        for endpoint in config["endpoints"]
        for name in (endpoint["name"], endpoint["async_name"])
    ]
|
||||
|
||||
|
||||
def get_async_endpoint_names() -> List[str]:
    """Get all async endpoint names for router registration."""
    config = _load_endpoints_config()
    async_names: List[str] = []
    for endpoint in config["endpoints"]:
        async_names.append(endpoint["async_name"])
    return async_names
|
||||
|
||||
|
||||
# Generate endpoints once at module import time.
_generated_endpoints = generate_container_endpoints()


def _lookup(name: str) -> Optional[Callable]:
    """Fetch a generated endpoint function by name (None if not generated)."""
    return _generated_endpoints.get(name)


# Bind generated functions to module-level names so they can be imported
# directly (e.g. ``from .endpoint_factory import list_container_files``).
list_container_files = _lookup("list_container_files")
alist_container_files = _lookup("alist_container_files")
upload_container_file = _lookup("upload_container_file")
aupload_container_file = _lookup("aupload_container_file")
retrieve_container_file = _lookup("retrieve_container_file")
aretrieve_container_file = _lookup("aretrieve_container_file")
delete_container_file = _lookup("delete_container_file")
adelete_container_file = _lookup("adelete_container_file")
retrieve_container_file_content = _lookup("retrieve_container_file_content")
aretrieve_container_file_content = _lookup("aretrieve_container_file_content")
|
||||
@@ -0,0 +1,51 @@
|
||||
{
|
||||
"endpoints": [
|
||||
{
|
||||
"name": "list_container_files",
|
||||
"async_name": "alist_container_files",
|
||||
"path": "/containers/{container_id}/files",
|
||||
"method": "GET",
|
||||
"path_params": ["container_id"],
|
||||
"query_params": ["after", "limit", "order"],
|
||||
"response_type": "ContainerFileListResponse"
|
||||
},
|
||||
{
|
||||
"name": "upload_container_file",
|
||||
"async_name": "aupload_container_file",
|
||||
"path": "/containers/{container_id}/files",
|
||||
"method": "POST",
|
||||
"path_params": ["container_id"],
|
||||
"query_params": [],
|
||||
"response_type": "ContainerFileObject",
|
||||
"is_multipart": true
|
||||
},
|
||||
{
|
||||
"name": "retrieve_container_file",
|
||||
"async_name": "aretrieve_container_file",
|
||||
"path": "/containers/{container_id}/files/{file_id}",
|
||||
"method": "GET",
|
||||
"path_params": ["container_id", "file_id"],
|
||||
"query_params": [],
|
||||
"response_type": "ContainerFileObject"
|
||||
},
|
||||
{
|
||||
"name": "delete_container_file",
|
||||
"async_name": "adelete_container_file",
|
||||
"path": "/containers/{container_id}/files/{file_id}",
|
||||
"method": "DELETE",
|
||||
"path_params": ["container_id", "file_id"],
|
||||
"query_params": [],
|
||||
"response_type": "DeleteContainerFileResponse"
|
||||
},
|
||||
{
|
||||
"name": "retrieve_container_file_content",
|
||||
"async_name": "aretrieve_container_file_content",
|
||||
"path": "/containers/{container_id}/files/{file_id}/content",
|
||||
"method": "GET",
|
||||
"path_params": ["container_id", "file_id"],
|
||||
"query_params": [],
|
||||
"response_type": "raw",
|
||||
"returns_binary": true
|
||||
}
|
||||
]
|
||||
}
|
||||
1290
llm-gateway-competitors/litellm-wheel-src/litellm/containers/main.py
Normal file
1290
llm-gateway-competitors/litellm-wheel-src/litellm/containers/main.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,70 @@
|
||||
from typing import Dict
|
||||
|
||||
from litellm.llms.base_llm.containers.transformation import BaseContainerConfig
|
||||
from litellm.types.containers.main import (
|
||||
ContainerCreateOptionalRequestParams,
|
||||
ContainerListOptionalRequestParams,
|
||||
)
|
||||
|
||||
|
||||
class ContainerRequestUtils:
    """Helpers for filtering and mapping optional container request parameters."""

    @staticmethod
    def get_requested_container_create_optional_param(
        passed_params: dict,
    ) -> ContainerCreateOptionalRequestParams:
        """Extract only valid container creation parameters from the passed parameters."""
        result = ContainerCreateOptionalRequestParams()

        # Only these keys are meaningful for container creation.
        for key in ("expires_after", "file_ids", "extra_headers", "extra_body"):
            value = passed_params.get(key)
            if value is not None:
                result[key] = value  # type: ignore

        return result

    @staticmethod
    def get_optional_params_container_create(
        container_provider_config: BaseContainerConfig,
        container_create_optional_params: ContainerCreateOptionalRequestParams,
    ) -> Dict:
        """Get the optional parameters for container creation."""
        supported = container_provider_config.get_supported_openai_params()

        # Drop anything the provider does not support before mapping.
        filtered = {
            key: value
            for key, value in container_create_optional_params.items()
            if key in supported
        }

        return container_provider_config.map_openai_params(
            container_create_optional_params=filtered,  # type: ignore
            drop_params=False,
        )

    @staticmethod
    def get_requested_container_list_optional_param(
        passed_params: dict,
    ) -> ContainerListOptionalRequestParams:
        """Extract only valid container list parameters from the passed parameters."""
        result = ContainerListOptionalRequestParams()

        # Pagination/ordering plus passthrough request options.
        for key in ("after", "limit", "order", "extra_headers", "extra_query"):
            value = passed_params.get(key)
            if value is not None:
                result[key] = value  # type: ignore

        return result
|
||||
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"gpt-3.5-turbo-0613": 0.00015000000000000001,
|
||||
"claude-2": 0.00016454,
|
||||
"gpt-4-0613": 0.015408
|
||||
}
|
||||
2268
llm-gateway-competitors/litellm-wheel-src/litellm/cost_calculator.py
Normal file
2268
llm-gateway-competitors/litellm-wheel-src/litellm/cost_calculator.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,128 @@
|
||||
"""
|
||||
Handler for transforming /chat/completions api requests to litellm.responses requests
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm import LiteLLMLoggingObj
|
||||
from litellm.types.llms.openai import HttpxBinaryResponseContent
|
||||
|
||||
|
||||
class SpeechToCompletionBridgeHandlerInputKwargs(TypedDict):
    """Validated keyword arguments consumed by SpeechToCompletionBridgeHandler.speech()."""

    model: str
    input: str
    # voice may be a plain voice name or a provider-specific dict
    voice: Optional[Union[str, dict]]
    optional_params: dict
    litellm_params: dict
    logging_obj: "LiteLLMLoggingObj"
    headers: dict
    custom_llm_provider: str
|
||||
|
||||
|
||||
class SpeechToCompletionBridgeHandler:
    """
    Bridges ``speech()`` calls onto ``litellm.completion()`` for providers
    whose text-to-speech is exposed via audio-modality chat completions.
    """

    def __init__(self):
        from .transformation import SpeechToCompletionBridgeTransformationHandler

        super().__init__()
        self.transformation_handler = SpeechToCompletionBridgeTransformationHandler()

    def validate_input_kwargs(
        self, kwargs: dict
    ) -> SpeechToCompletionBridgeHandlerInputKwargs:
        """
        Validate the raw kwargs and narrow them to the typed dict the bridge
        expects.

        Raises:
            ValueError: if any required key is missing or of the wrong type.
        """
        from litellm import LiteLLMLoggingObj

        model = kwargs.get("model")
        if model is None or not isinstance(model, str):
            raise ValueError("model is required")

        custom_llm_provider = kwargs.get("custom_llm_provider")
        if custom_llm_provider is None or not isinstance(custom_llm_provider, str):
            raise ValueError("custom_llm_provider is required")

        input = kwargs.get("input")
        if input is None or not isinstance(input, str):
            raise ValueError("input is required")

        optional_params = kwargs.get("optional_params")
        if optional_params is None or not isinstance(optional_params, dict):
            raise ValueError("optional_params is required")

        litellm_params = kwargs.get("litellm_params")
        if litellm_params is None or not isinstance(litellm_params, dict):
            raise ValueError("litellm_params is required")

        # FIX: this headers validation block was duplicated verbatim in the
        # original; the redundant second copy has been removed.
        headers = kwargs.get("headers")
        if headers is None or not isinstance(headers, dict):
            raise ValueError("headers is required")

        logging_obj = kwargs.get("logging_obj")
        if logging_obj is None or not isinstance(logging_obj, LiteLLMLoggingObj):
            raise ValueError("logging_obj is required")

        return SpeechToCompletionBridgeHandlerInputKwargs(
            model=model,
            input=input,
            voice=kwargs.get("voice"),
            optional_params=optional_params,
            litellm_params=litellm_params,
            logging_obj=logging_obj,
            custom_llm_provider=custom_llm_provider,
            headers=headers,
        )

    def speech(
        self,
        model: str,
        input: str,
        voice: Optional[Union[str, dict]],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
        logging_obj: "LiteLLMLoggingObj",
        custom_llm_provider: str,
    ) -> "HttpxBinaryResponseContent":
        """
        Run a speech request through the completion API and return the audio
        as binary HTTP content.

        Raises:
            ValueError: on invalid inputs (see validate_input_kwargs).
            Exception: if completion() returns something other than a
                ModelResponse.
        """
        received_args = locals()
        from litellm import completion
        from litellm.types.utils import ModelResponse

        validated_kwargs = self.validate_input_kwargs(received_args)
        model = validated_kwargs["model"]
        input = validated_kwargs["input"]
        optional_params = validated_kwargs["optional_params"]
        litellm_params = validated_kwargs["litellm_params"]
        headers = validated_kwargs["headers"]
        logging_obj = validated_kwargs["logging_obj"]
        custom_llm_provider = validated_kwargs["custom_llm_provider"]
        voice = validated_kwargs["voice"]

        # Map the speech call onto completion() kwargs (audio modality).
        request_data = self.transformation_handler.transform_request(
            model=model,
            input=input,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
            litellm_logging_obj=logging_obj,
            custom_llm_provider=custom_llm_provider,
            voice=voice,
        )

        result = completion(
            **request_data,
        )

        if isinstance(result, ModelResponse):
            return self.transformation_handler.transform_response(
                model_response=result,
            )
        else:
            raise Exception("Unmapped response type. Got type: {}".format(type(result)))


speech_to_completion_bridge_handler = SpeechToCompletionBridgeHandler()
|
||||
@@ -0,0 +1,134 @@
|
||||
from typing import TYPE_CHECKING, Optional, Union, cast
|
||||
|
||||
from litellm.constants import OPENAI_CHAT_COMPLETION_PARAMS
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm import Logging as LiteLLMLoggingObj
|
||||
from litellm.types.llms.openai import HttpxBinaryResponseContent
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
|
||||
class SpeechToCompletionBridgeTransformationHandler:
    """Transforms speech() requests into chat-completion requests and maps the audio reply back."""

    def transform_request(
        self,
        model: str,
        input: str,
        voice: Optional[Union[str, dict]],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
        litellm_logging_obj: "LiteLLMLoggingObj",
        custom_llm_provider: str,
    ) -> dict:
        """Build the kwargs for a litellm.completion() call with audio modality."""
        # Forward only recognized chat-completion params.
        forwarded = {
            key: optional_params[key]
            for key in optional_params
            if key in OPENAI_CHAT_COMPLETION_PARAMS
        }

        # A string voice maps onto the chat "audio" param; response_format
        # (if present) becomes the audio output format.
        if isinstance(voice, str):
            audio_opts: dict = {"voice": voice}
            if "response_format" in optional_params:
                audio_opts["format"] = optional_params["response_format"]
            forwarded["audio"] = audio_opts

        combined = {
            "model": model,
            "messages": [
                {
                    "role": "user",
                    "content": input,
                }
            ],
            "modalities": ["audio"],
            **forwarded,
            **litellm_params,
            "headers": headers,
            "litellm_logging_obj": litellm_logging_obj,
            "custom_llm_provider": custom_llm_provider,
        }

        # Drop None values so completion() only sees explicit kwargs.
        return {key: value for key, value in combined.items() if value is not None}

    def _convert_pcm16_to_wav(
        self, pcm_data: bytes, sample_rate: int = 24000, channels: int = 1
    ) -> bytes:
        """
        Wrap raw PCM16 samples in a standard 44-byte WAV header.

        Args:
            pcm_data: Raw PCM16 audio data
            sample_rate: Sample rate in Hz (Gemini TTS typically uses 24000)
            channels: Number of audio channels (1 for mono)

        Returns:
            bytes: WAV formatted audio data
        """
        import struct

        bytes_per_frame = channels * 2  # 2 bytes per 16-bit sample

        wav_header = struct.pack(
            "<4sI4s4sIHHIIHH4sI",
            b"RIFF",                        # chunk ID
            36 + len(pcm_data),             # RIFF chunk size
            b"WAVE",                        # format
            b"fmt ",                        # fmt subchunk ID
            16,                             # fmt subchunk size (PCM)
            1,                              # audio format: PCM
            channels,                       # channel count
            sample_rate,                    # sample rate
            sample_rate * bytes_per_frame,  # byte rate
            bytes_per_frame,                # block align
            16,                             # bits per sample
            b"data",                        # data subchunk ID
            len(pcm_data),                  # data subchunk size
        )

        return wav_header + pcm_data

    def _is_gemini_tts_model(self, model: str) -> bool:
        """Check if the model is a Gemini TTS model that returns PCM16 data."""
        lowered = model.lower()
        return "gemini" in lowered and ("tts" in lowered or "preview-tts" in lowered)

    def transform_response(
        self, model_response: "ModelResponse"
    ) -> "HttpxBinaryResponseContent":
        """Convert the audio part of a chat completion into binary HTTP content."""
        import base64

        import httpx

        from litellm.types.llms.openai import HttpxBinaryResponseContent
        from litellm.types.utils import Choices

        first_choice = cast(Choices, model_response.choices[0])
        audio_part = first_choice.message.audio
        if audio_part is None:
            raise ValueError("No audio part found in the response")

        # Providers return audio base64-encoded; decode to raw bytes.
        audio_bytes = base64.b64decode(audio_part.data)

        model_name = getattr(model_response, "model", "")
        response_headers = {}
        if self._is_gemini_tts_model(model_name):
            # Gemini TTS emits raw PCM16 — wrap it in a WAV container so the
            # result is playable as a normal audio file.
            audio_bytes = self._convert_pcm16_to_wav(audio_bytes)
            response_headers["Content-Type"] = "audio/wav"
        else:
            response_headers["Content-Type"] = "audio/mpeg"

        wrapped = httpx.Response(
            status_code=200, content=audio_bytes, headers=response_headers
        )
        return HttpxBinaryResponseContent(wrapped)
|
||||
@@ -0,0 +1,33 @@
|
||||
"""
|
||||
Evals API operations
|
||||
"""
|
||||
|
||||
from .main import (
|
||||
acancel_eval,
|
||||
acreate_eval,
|
||||
adelete_eval,
|
||||
aget_eval,
|
||||
alist_evals,
|
||||
aupdate_eval,
|
||||
cancel_eval,
|
||||
create_eval,
|
||||
delete_eval,
|
||||
get_eval,
|
||||
list_evals,
|
||||
update_eval,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"acreate_eval",
|
||||
"alist_evals",
|
||||
"aget_eval",
|
||||
"aupdate_eval",
|
||||
"adelete_eval",
|
||||
"acancel_eval",
|
||||
"create_eval",
|
||||
"list_evals",
|
||||
"get_eval",
|
||||
"update_eval",
|
||||
"delete_eval",
|
||||
"cancel_eval",
|
||||
]
|
||||
1975
llm-gateway-competitors/litellm-wheel-src/litellm/evals/main.py
Normal file
1975
llm-gateway-competitors/litellm-wheel-src/litellm/evals/main.py
Normal file
File diff suppressed because it is too large
Load Diff
1030
llm-gateway-competitors/litellm-wheel-src/litellm/exceptions.py
Normal file
1030
llm-gateway-competitors/litellm-wheel-src/litellm/exceptions.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,6 @@
|
||||
# LiteLLM MCP Client
|
||||
|
||||
LiteLLM MCP Client is a client that allows you to use MCP tools with LiteLLM.
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
from .tools import call_openai_tool, load_mcp_tools
|
||||
|
||||
__all__ = ["load_mcp_tools", "call_openai_tool"]
|
||||
@@ -0,0 +1,697 @@
|
||||
"""
|
||||
LiteLLM Proxy uses this MCP Client to connect to other MCP servers.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
Generator,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
import httpx
|
||||
from mcp import ClientSession, ReadResourceResult, Resource, StdioServerParameters
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import stdio_client
|
||||
|
||||
streamable_http_client: Optional[Any] = None
|
||||
try:
|
||||
import mcp.client.streamable_http as streamable_http_module # type: ignore
|
||||
|
||||
streamable_http_client = getattr(
|
||||
streamable_http_module, "streamable_http_client", None
|
||||
)
|
||||
except ImportError:
|
||||
pass
|
||||
from mcp.types import CallToolRequestParams as MCPCallToolRequestParams
|
||||
from mcp.types import CallToolResult as MCPCallToolResult
|
||||
from mcp.types import (
|
||||
GetPromptRequestParams,
|
||||
GetPromptResult,
|
||||
Prompt,
|
||||
ResourceTemplate,
|
||||
TextContent,
|
||||
)
|
||||
from mcp.types import Tool as MCPTool
|
||||
from pydantic import AnyUrl
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.constants import MCP_CLIENT_TIMEOUT
|
||||
from litellm.llms.custom_httpx.http_handler import get_ssl_configuration
|
||||
from litellm.types.llms.custom_http import VerifyTypes
|
||||
from litellm.types.mcp import (
|
||||
MCPAuth,
|
||||
MCPAuthType,
|
||||
MCPStdioConfig,
|
||||
MCPTransport,
|
||||
MCPTransportType,
|
||||
)
|
||||
|
||||
|
||||
def to_basic_auth(auth_value: str) -> str:
    """Convert auth value to Basic Auth format (base64-encoded UTF-8)."""
    encoded = base64.b64encode(auth_value.encode("utf-8"))
    return encoded.decode()
|
||||
|
||||
|
||||
TSessionResult = TypeVar("TSessionResult")
|
||||
|
||||
|
||||
class MCPSigV4Auth(httpx.Auth):
    """
    httpx Auth class that signs each request with AWS SigV4.

    Used for MCP servers requiring AWS SigV4 authentication, such as AWS
    Bedrock AgentCore MCP servers. httpx invokes auth_flow() on every
    outgoing request, so the signature is recomputed per request.
    """

    # The request body must be available when signing (it is part of the
    # canonical SigV4 payload hash).
    requires_request_body = True

    def __init__(
        self,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
        aws_region_name: Optional[str] = None,
        aws_service_name: Optional[str] = None,
    ):
        try:
            from botocore.credentials import Credentials
        except ImportError:
            raise ImportError(
                "Missing botocore to use AWS SigV4 authentication. "
                "Run 'pip install boto3'."
            )

        self.service_name = aws_service_name or "bedrock-agentcore"
        self.region_name = aws_region_name or "us-east-1"

        # Note: os.environ/ prefixed values are already resolved by
        # ProxyConfig._check_for_os_environ_vars() at config load time, so
        # the values arriving here are plain strings.
        if aws_access_key_id and aws_secret_access_key:
            # Explicit keys win over the default chain.
            self.credentials = Credentials(
                access_key=aws_access_key_id,
                secret_key=aws_secret_access_key,
                token=aws_session_token,
            )
            return

        # Fall back to default boto3 credential chain
        import botocore.session

        session = botocore.session.get_session()
        self.credentials = session.get_credentials()
        if self.credentials is None:
            raise ValueError(
                "No AWS credentials found. Provide aws_access_key_id and "
                "aws_secret_access_key, or configure default credentials "
                "(env vars, ~/.aws/credentials, instance profile)."
            )

    def auth_flow(
        self, request: httpx.Request
    ) -> Generator[httpx.Request, httpx.Response, None]:
        """Sign *request* with SigV4 and yield it with auth headers applied."""
        from botocore.auth import SigV4Auth
        from botocore.awsrequest import AWSRequest

        # Mirror the outgoing request (including all headers, so they are
        # covered by the canonical SigV4 signature) into an AWSRequest.
        signable = AWSRequest(
            method=request.method,
            url=str(request.url),
            data=request.content,
            headers=dict(request.headers),
        )

        # add_auth() injects Authorization, X-Amz-Date and, when a session
        # token is present, X-Amz-Security-Token. Host is derived from the URL.
        signer = SigV4Auth(self.credentials, self.service_name, self.region_name)
        signer.add_auth(signable)

        # Copy the signed headers back onto the httpx request.
        for header_name, header_value in signable.headers.items():
            request.headers[header_name] = header_value

        yield request
|
||||
|
||||
|
||||
class MCPClient:
|
||||
"""
|
||||
MCP Client supporting:
|
||||
SSE and HTTP transports
|
||||
Authentication via Bearer token, Basic Auth, or API Key
|
||||
Tool calling with error handling and result parsing
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    server_url: str = "",
    transport_type: MCPTransportType = MCPTransport.http,
    auth_type: MCPAuthType = None,
    auth_value: Optional[Union[str, Dict[str, str]]] = None,
    timeout: Optional[float] = None,
    stdio_config: Optional[MCPStdioConfig] = None,
    extra_headers: Optional[Dict[str, str]] = None,
    ssl_verify: Optional[VerifyTypes] = None,
    aws_auth: Optional[httpx.Auth] = None,
):
    """Store connection, transport and auth settings for this MCP client."""
    self.server_url: str = server_url
    self.transport_type: MCPTransport = transport_type
    self.auth_type: MCPAuthType = auth_type
    # Fall back to the library-wide default timeout when none is given.
    self.timeout: float = MCP_CLIENT_TIMEOUT if timeout is None else timeout
    self._mcp_auth_value: Optional[Union[str, Dict[str, str]]] = None
    self.stdio_config: Optional[MCPStdioConfig] = stdio_config
    self.extra_headers: Optional[Dict[str, str]] = extra_headers
    self.ssl_verify: Optional[VerifyTypes] = ssl_verify
    self._aws_auth: Optional[httpx.Auth] = aws_auth
    # Normalize the auth value (e.g. basic-auth encoding) when provided.
    if auth_value:
        self.update_auth_value(auth_value)
|
||||
|
||||
def _create_transport_context(
|
||||
self,
|
||||
) -> Tuple[Any, Optional[httpx.AsyncClient]]:
|
||||
"""
|
||||
Create the appropriate transport context based on transport type.
|
||||
|
||||
Returns:
|
||||
Tuple of (transport_context, http_client).
|
||||
http_client is only set for HTTP transport and needs cleanup.
|
||||
"""
|
||||
http_client: Optional[httpx.AsyncClient] = None
|
||||
|
||||
if self.transport_type == MCPTransport.stdio:
|
||||
if not self.stdio_config:
|
||||
raise ValueError("stdio_config is required for stdio transport")
|
||||
server_params = StdioServerParameters(
|
||||
command=self.stdio_config.get("command", ""),
|
||||
args=self.stdio_config.get("args", []),
|
||||
env=self.stdio_config.get("env", {}),
|
||||
)
|
||||
return stdio_client(server_params), None
|
||||
|
||||
if self.transport_type == MCPTransport.sse:
|
||||
headers = self._get_auth_headers()
|
||||
httpx_client_factory = self._create_httpx_client_factory()
|
||||
return (
|
||||
sse_client(
|
||||
url=self.server_url,
|
||||
timeout=self.timeout,
|
||||
headers=headers,
|
||||
httpx_client_factory=httpx_client_factory,
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
# HTTP transport (default)
|
||||
if streamable_http_client is None:
|
||||
raise ImportError(
|
||||
"streamable_http_client is not available. "
|
||||
"Please install mcp with HTTP support."
|
||||
)
|
||||
|
||||
headers = self._get_auth_headers()
|
||||
httpx_client_factory = self._create_httpx_client_factory()
|
||||
verbose_logger.debug("litellm headers for streamable_http_client: %s", headers)
|
||||
http_client = httpx_client_factory(
|
||||
headers=headers,
|
||||
timeout=httpx.Timeout(self.timeout),
|
||||
)
|
||||
transport_ctx = streamable_http_client(
|
||||
url=self.server_url,
|
||||
http_client=http_client,
|
||||
)
|
||||
return transport_ctx, http_client
|
||||
|
||||
    async def _execute_session_operation(
        self,
        transport_ctx: Any,
        operation: Callable[[ClientSession], Awaitable[TSessionResult]],
    ) -> TSessionResult:
        """
        Execute an operation within a transport and session context.

        Enters the transport context, opens and initializes a ClientSession on
        its (read, write) streams, runs *operation*, then unwinds both
        contexts. Exceptions raised while EXITING either context are only
        logged at debug level, so they cannot mask the operation's own
        result or exception.

        NOTE(review): both __aexit__ calls receive (None, None, None) even
        when an exception is in flight, so the context managers never see the
        exception info — confirm this is intentional for these transports.
        """
        transport = await transport_ctx.__aenter__()
        try:
            # Transport yields at least (read_stream, write_stream); any extra
            # tuple elements are intentionally ignored.
            read_stream, write_stream = transport[0], transport[1]
            session_ctx = ClientSession(read_stream, write_stream)
            session = await session_ctx.__aenter__()
            try:
                await session.initialize()
                return await operation(session)
            finally:
                # Best-effort session teardown; never let exit errors escape.
                try:
                    await session_ctx.__aexit__(None, None, None)
                except BaseException as e:
                    verbose_logger.debug(f"Error during session context exit: {e}")
        finally:
            # Best-effort transport teardown; never let exit errors escape.
            try:
                await transport_ctx.__aexit__(None, None, None)
            except BaseException as e:
                verbose_logger.debug(f"Error during transport context exit: {e}")
|
||||
|
||||
async def run_with_session(
|
||||
self, operation: Callable[[ClientSession], Awaitable[TSessionResult]]
|
||||
) -> TSessionResult:
|
||||
"""Open a session, run the provided coroutine, and clean up."""
|
||||
http_client: Optional[httpx.AsyncClient] = None
|
||||
try:
|
||||
transport_ctx, http_client = self._create_transport_context()
|
||||
return await self._execute_session_operation(transport_ctx, operation)
|
||||
except Exception:
|
||||
verbose_logger.warning(
|
||||
"MCP client run_with_session failed for %s", self.server_url or "stdio"
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
if http_client is not None:
|
||||
try:
|
||||
await http_client.aclose()
|
||||
except BaseException as e:
|
||||
verbose_logger.debug(f"Error during http_client cleanup: {e}")
|
||||
|
||||
def update_auth_value(self, mcp_auth_value: Union[str, Dict[str, str]]):
|
||||
"""
|
||||
Set the authentication header for the MCP client.
|
||||
"""
|
||||
if isinstance(mcp_auth_value, dict):
|
||||
self._mcp_auth_value = mcp_auth_value
|
||||
else:
|
||||
if self.auth_type == MCPAuth.basic:
|
||||
# Assuming mcp_auth_value is in format "username:password", convert it when updating
|
||||
mcp_auth_value = to_basic_auth(mcp_auth_value)
|
||||
self._mcp_auth_value = mcp_auth_value
|
||||
|
||||
def _get_auth_headers(self) -> dict:
|
||||
"""Generate authentication headers based on auth type."""
|
||||
headers = {}
|
||||
|
||||
if self._mcp_auth_value:
|
||||
if isinstance(self._mcp_auth_value, str):
|
||||
if self.auth_type == MCPAuth.bearer_token:
|
||||
headers["Authorization"] = f"Bearer {self._mcp_auth_value}"
|
||||
elif self.auth_type == MCPAuth.basic:
|
||||
headers["Authorization"] = f"Basic {self._mcp_auth_value}"
|
||||
elif self.auth_type == MCPAuth.api_key:
|
||||
headers["X-API-Key"] = self._mcp_auth_value
|
||||
elif self.auth_type == MCPAuth.authorization:
|
||||
headers["Authorization"] = self._mcp_auth_value
|
||||
elif self.auth_type == MCPAuth.oauth2:
|
||||
headers["Authorization"] = f"Bearer {self._mcp_auth_value}"
|
||||
elif self.auth_type == MCPAuth.token:
|
||||
headers["Authorization"] = f"token {self._mcp_auth_value}"
|
||||
elif isinstance(self._mcp_auth_value, dict):
|
||||
headers.update(self._mcp_auth_value)
|
||||
# Note: aws_sigv4 auth is not handled here — SigV4 requires per-request
|
||||
# signing (including the body hash), so it uses httpx.Auth flow instead
|
||||
# of static headers. See MCPSigV4Auth and _create_httpx_client_factory().
|
||||
|
||||
# update the headers with the extra headers
|
||||
if self.extra_headers:
|
||||
headers.update(self.extra_headers)
|
||||
|
||||
return headers
|
||||
|
||||
def _create_httpx_client_factory(self) -> Callable[..., httpx.AsyncClient]:
|
||||
"""
|
||||
Create a custom httpx client factory that uses LiteLLM's SSL configuration.
|
||||
|
||||
This factory follows the same CA bundle path logic as http_handler.py:
|
||||
1. Check ssl_verify parameter (can be SSLContext, bool, or path to CA bundle)
|
||||
2. Check SSL_VERIFY environment variable
|
||||
3. Check SSL_CERT_FILE environment variable
|
||||
4. Fall back to certifi CA bundle
|
||||
"""
|
||||
|
||||
def factory(
|
||||
*,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
timeout: Optional[httpx.Timeout] = None,
|
||||
auth: Optional[httpx.Auth] = None,
|
||||
) -> httpx.AsyncClient:
|
||||
"""Create an httpx.AsyncClient with LiteLLM's SSL configuration."""
|
||||
# Get unified SSL configuration using the same logic as http_handler.py
|
||||
ssl_config = get_ssl_configuration(self.ssl_verify)
|
||||
|
||||
verbose_logger.debug(
|
||||
f"MCP client using SSL configuration: {type(ssl_config).__name__}"
|
||||
)
|
||||
|
||||
# Use SigV4 auth if configured and no explicit auth provided.
|
||||
# The MCP SDK's sse_client and streamable_http_client call this
|
||||
# factory without passing auth=, so self._aws_auth is used.
|
||||
# For non-SigV4 clients, self._aws_auth is None — no behavior change.
|
||||
effective_auth = auth if auth is not None else self._aws_auth
|
||||
|
||||
return httpx.AsyncClient(
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
auth=effective_auth,
|
||||
verify=ssl_config,
|
||||
follow_redirects=True,
|
||||
)
|
||||
|
||||
return factory
|
||||
|
||||
async def list_tools(self) -> List[MCPTool]:
|
||||
"""List available tools from the server."""
|
||||
verbose_logger.debug(
|
||||
f"MCP client listing tools from {self.server_url or 'stdio'}"
|
||||
)
|
||||
|
||||
async def _list_tools_operation(session: ClientSession):
|
||||
return await session.list_tools()
|
||||
|
||||
try:
|
||||
result = await self.run_with_session(_list_tools_operation)
|
||||
tool_count = len(result.tools)
|
||||
tool_names = [tool.name for tool in result.tools]
|
||||
verbose_logger.info(
|
||||
f"MCP client listed {tool_count} tools from {self.server_url or 'stdio'}: {tool_names}"
|
||||
)
|
||||
return result.tools
|
||||
except asyncio.CancelledError:
|
||||
verbose_logger.warning("MCP client list_tools was cancelled")
|
||||
raise
|
||||
except Exception as e:
|
||||
error_type = type(e).__name__
|
||||
verbose_logger.exception(
|
||||
f"MCP client list_tools failed - "
|
||||
f"Error Type: {error_type}, "
|
||||
f"Error: {str(e)}, "
|
||||
f"Server: {self.server_url or 'stdio'}, "
|
||||
f"Transport: {self.transport_type}"
|
||||
)
|
||||
|
||||
# Check if it's a stream/connection error
|
||||
if "BrokenResourceError" in error_type or "Broken" in error_type:
|
||||
verbose_logger.error(
|
||||
"MCP client detected broken connection/stream during list_tools - "
|
||||
"the MCP server may have crashed, disconnected, or timed out"
|
||||
)
|
||||
|
||||
# Return empty list instead of raising to allow graceful degradation
|
||||
return []
|
||||
|
||||
    async def call_tool(
        self,
        call_tool_request_params: MCPCallToolRequestParams,
        host_progress_callback: Optional[Callable] = None,
    ) -> MCPCallToolResult:
        """
        Call an MCP Tool.

        Args:
            call_tool_request_params: Name and arguments of the tool to invoke.
            host_progress_callback: Optional async callable invoked as
                (progress, total) to relay server-side progress to the host.

        Returns:
            The tool's result; on any non-cancellation failure, an
            MCPCallToolResult with isError=True and the error text as
            content is returned instead of raising.
        """
        verbose_logger.info(
            f"MCP client calling tool '{call_tool_request_params.name}' with arguments: {call_tool_request_params.arguments}"
        )

        async def on_progress(
            progress: float, total: float | None, message: str | None
        ):
            # Guard against a missing/zero total when computing a percentage.
            percentage = (progress / total * 100) if total else 0
            verbose_logger.info(
                f"MCP Tool '{call_tool_request_params.name}' progress: "
                f"{progress}/{total} ({percentage:.0f}%) - {message or ''}"
            )

            # Forward to Host if callback provided
            if host_progress_callback:
                try:
                    # NOTE(review): `message` is not forwarded — confirm the
                    # host callback only expects (progress, total).
                    await host_progress_callback(progress, total)
                except Exception as e:
                    # A failing host callback must not break the tool call.
                    verbose_logger.warning(f"Failed to forward to Host: {e}")

        async def _call_tool_operation(session: ClientSession):
            # Runs inside a fresh session opened by run_with_session.
            verbose_logger.debug("MCP client sending tool call to session")
            return await session.call_tool(
                name=call_tool_request_params.name,
                arguments=call_tool_request_params.arguments,
                progress_callback=on_progress,
            )

        try:
            tool_result = await self.run_with_session(_call_tool_operation)
            verbose_logger.info(
                f"MCP client tool call '{call_tool_request_params.name}' completed successfully"
            )
            return tool_result
        except asyncio.CancelledError:
            # Propagate cancellation untouched so task teardown works.
            verbose_logger.warning("MCP client tool call was cancelled")
            raise
        except Exception as e:
            import traceback

            # Full traceback at debug level only; summary at error level below.
            error_trace = traceback.format_exc()
            verbose_logger.debug(f"MCP client tool call traceback:\n{error_trace}")

            # Log detailed error information
            error_type = type(e).__name__
            verbose_logger.error(
                f"MCP client call_tool failed - "
                f"Error Type: {error_type}, "
                f"Error: {str(e)}, "
                f"Tool: {call_tool_request_params.name}, "
                f"Server: {self.server_url or 'stdio'}, "
                f"Transport: {self.transport_type}"
            )

            # Check if it's a stream/connection error
            if "BrokenResourceError" in error_type or "Broken" in error_type:
                verbose_logger.error(
                    "MCP client detected broken connection/stream - "
                    "the MCP server may have crashed, disconnected, or timed out."
                )

            # Return a default error result instead of raising
            return MCPCallToolResult(
                content=[
                    TextContent(type="text", text=f"{error_type}: {str(e)}")
                ],  # Empty content for error case
                isError=True,
            )
|
||||
|
||||
async def list_prompts(self) -> List[Prompt]:
|
||||
"""List available prompts from the server."""
|
||||
verbose_logger.debug(
|
||||
f"MCP client listing tools from {self.server_url or 'stdio'}"
|
||||
)
|
||||
|
||||
async def _list_prompts_operation(session: ClientSession):
|
||||
return await session.list_prompts()
|
||||
|
||||
try:
|
||||
result = await self.run_with_session(_list_prompts_operation)
|
||||
prompt_count = len(result.prompts)
|
||||
prompt_names = [prompt.name for prompt in result.prompts]
|
||||
verbose_logger.info(
|
||||
f"MCP client listed {prompt_count} tools from {self.server_url or 'stdio'}: {prompt_names}"
|
||||
)
|
||||
return result.prompts
|
||||
except asyncio.CancelledError:
|
||||
verbose_logger.warning("MCP client list_prompts was cancelled")
|
||||
raise
|
||||
except Exception as e:
|
||||
error_type = type(e).__name__
|
||||
verbose_logger.error(
|
||||
f"MCP client list_prompts failed - "
|
||||
f"Error Type: {error_type}, "
|
||||
f"Error: {str(e)}, "
|
||||
f"Server: {self.server_url or 'stdio'}, "
|
||||
f"Transport: {self.transport_type}"
|
||||
)
|
||||
|
||||
# Check if it's a stream/connection error
|
||||
if "BrokenResourceError" in error_type or "Broken" in error_type:
|
||||
verbose_logger.error(
|
||||
"MCP client detected broken connection/stream during list_tools - "
|
||||
"the MCP server may have crashed, disconnected, or timed out"
|
||||
)
|
||||
|
||||
# Return empty list instead of raising to allow graceful degradation
|
||||
return []
|
||||
|
||||
    async def get_prompt(
        self, get_prompt_request_params: GetPromptRequestParams
    ) -> GetPromptResult:
        """Fetch a prompt definition from the MCP server.

        Args:
            get_prompt_request_params: Name and optional arguments of the
                prompt to fetch.

        Returns:
            The server's GetPromptResult.

        Raises:
            Re-raises any failure after logging it — unlike the list_*
            helpers, this method does NOT degrade to a default value.
        """
        verbose_logger.info(
            f"MCP client fetching prompt '{get_prompt_request_params.name}' with arguments: {get_prompt_request_params.arguments}"
        )

        async def _get_prompt_operation(session: ClientSession):
            # Runs inside a fresh session opened by run_with_session.
            verbose_logger.debug("MCP client sending get_prompt request to session")
            return await session.get_prompt(
                name=get_prompt_request_params.name,
                arguments=get_prompt_request_params.arguments,
            )

        try:
            get_prompt_result = await self.run_with_session(_get_prompt_operation)
            verbose_logger.info(
                f"MCP client get_prompt '{get_prompt_request_params.name}' completed successfully"
            )
            return get_prompt_result
        except asyncio.CancelledError:
            # Propagate cancellation untouched so task teardown works.
            verbose_logger.warning("MCP client get_prompt was cancelled")
            raise
        except Exception as e:
            import traceback

            # Full traceback at debug level only; summary at error level below.
            error_trace = traceback.format_exc()
            verbose_logger.debug(f"MCP client get_prompt traceback:\n{error_trace}")

            # Log detailed error information
            error_type = type(e).__name__
            verbose_logger.error(
                f"MCP client get_prompt failed - "
                f"Error Type: {error_type}, "
                f"Error: {str(e)}, "
                f"Prompt: {get_prompt_request_params.name}, "
                f"Server: {self.server_url or 'stdio'}, "
                f"Transport: {self.transport_type}"
            )

            # Check if it's a stream/connection error
            if "BrokenResourceError" in error_type or "Broken" in error_type:
                verbose_logger.error(
                    "MCP client detected broken connection/stream during get_prompt - "
                    "the MCP server may have crashed, disconnected, or timed out."
                )

            raise
|
||||
|
||||
async def list_resources(self) -> list[Resource]:
|
||||
"""List available resources from the server."""
|
||||
verbose_logger.debug(
|
||||
f"MCP client listing resources from {self.server_url or 'stdio'}"
|
||||
)
|
||||
|
||||
async def _list_resources_operation(session: ClientSession):
|
||||
return await session.list_resources()
|
||||
|
||||
try:
|
||||
result = await self.run_with_session(_list_resources_operation)
|
||||
resource_count = len(result.resources)
|
||||
resource_names = [resource.name for resource in result.resources]
|
||||
verbose_logger.info(
|
||||
f"MCP client listed {resource_count} resources from {self.server_url or 'stdio'}: {resource_names}"
|
||||
)
|
||||
return result.resources
|
||||
except asyncio.CancelledError:
|
||||
verbose_logger.warning("MCP client list_resources was cancelled")
|
||||
raise
|
||||
except Exception as e:
|
||||
error_type = type(e).__name__
|
||||
verbose_logger.error(
|
||||
f"MCP client list_resources failed - "
|
||||
f"Error Type: {error_type}, "
|
||||
f"Error: {str(e)}, "
|
||||
f"Server: {self.server_url or 'stdio'}, "
|
||||
f"Transport: {self.transport_type}"
|
||||
)
|
||||
|
||||
# Check if it's a stream/connection error
|
||||
if "BrokenResourceError" in error_type or "Broken" in error_type:
|
||||
verbose_logger.error(
|
||||
"MCP client detected broken connection/stream during list_resources - "
|
||||
"the MCP server may have crashed, disconnected, or timed out"
|
||||
)
|
||||
|
||||
# Return empty list instead of raising to allow graceful degradation
|
||||
return []
|
||||
|
||||
    async def list_resource_templates(self) -> list[ResourceTemplate]:
        """List available resource templates from the server.

        On failure (other than cancellation) the error is logged and an
        empty list is returned so callers can degrade gracefully.
        """
        verbose_logger.debug(
            f"MCP client listing resource templates from {self.server_url or 'stdio'}"
        )

        async def _list_resource_templates_operation(session: ClientSession):
            # Runs inside a fresh session opened by run_with_session.
            return await session.list_resource_templates()

        try:
            result = await self.run_with_session(_list_resource_templates_operation)
            resource_template_count = len(result.resourceTemplates)
            resource_template_names = [
                resourceTemplate.name for resourceTemplate in result.resourceTemplates
            ]
            verbose_logger.info(
                f"MCP client listed {resource_template_count} resource templates from {self.server_url or 'stdio'}: {resource_template_names}"
            )
            return result.resourceTemplates
        except asyncio.CancelledError:
            # Propagate cancellation untouched so task teardown works.
            verbose_logger.warning("MCP client list_resource_templates was cancelled")
            raise
        except Exception as e:
            error_type = type(e).__name__
            verbose_logger.error(
                f"MCP client list_resource_templates failed - "
                f"Error Type: {error_type}, "
                f"Error: {str(e)}, "
                f"Server: {self.server_url or 'stdio'}, "
                f"Transport: {self.transport_type}"
            )

            # Check if it's a stream/connection error
            if "BrokenResourceError" in error_type or "Broken" in error_type:
                verbose_logger.error(
                    "MCP client detected broken connection/stream during list_resource_templates - "
                    "the MCP server may have crashed, disconnected, or timed out"
                )

            # Return empty list instead of raising to allow graceful degradation
            return []
|
||||
|
||||
    async def read_resource(self, url: AnyUrl) -> ReadResourceResult:
        """Fetch resource contents from the MCP server.

        Args:
            url: The resource URI to read.

        Returns:
            The server's ReadResourceResult.

        Raises:
            Re-raises any failure after logging it — unlike the list_*
            helpers, this method does NOT degrade to a default value.
        """
        verbose_logger.info(f"MCP client fetching resource '{url}'")

        async def _read_resource_operation(session: ClientSession):
            # Runs inside a fresh session opened by run_with_session.
            verbose_logger.debug("MCP client sending read_resource request to session")
            return await session.read_resource(url)

        try:
            read_resource_result = await self.run_with_session(_read_resource_operation)
            verbose_logger.info(
                f"MCP client read_resource '{url}' completed successfully"
            )
            return read_resource_result
        except asyncio.CancelledError:
            # Propagate cancellation untouched so task teardown works.
            verbose_logger.warning("MCP client read_resource was cancelled")
            raise
        except Exception as e:
            import traceback

            # Full traceback at debug level only; summary at error level below.
            error_trace = traceback.format_exc()
            verbose_logger.debug(f"MCP client read_resource traceback:\n{error_trace}")

            # Log detailed error information
            error_type = type(e).__name__
            verbose_logger.error(
                f"MCP client read_resource failed - "
                f"Error Type: {error_type}, "
                f"Error: {str(e)}, "
                f"Url: {url}, "
                f"Server: {self.server_url or 'stdio'}, "
                f"Transport: {self.transport_type}"
            )

            # Check if it's a stream/connection error
            if "BrokenResourceError" in error_type or "Broken" in error_type:
                verbose_logger.error(
                    "MCP client detected broken connection/stream during read_resource - "
                    "the MCP server may have crashed, disconnected, or timed out."
                )

            raise
|
||||
@@ -0,0 +1,159 @@
|
||||
import json
|
||||
from typing import Dict, List, Literal, Union
|
||||
|
||||
from mcp import ClientSession
|
||||
from mcp.types import CallToolRequestParams as MCPCallToolRequestParams
|
||||
from mcp.types import CallToolResult as MCPCallToolResult
|
||||
from mcp.types import Tool as MCPTool
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
from openai.types.responses.function_tool_param import FunctionToolParam
|
||||
from openai.types.shared_params.function_definition import FunctionDefinition
|
||||
|
||||
from litellm.types.utils import ChatCompletionMessageToolCall
|
||||
|
||||
|
||||
########################################################
|
||||
# List MCP Tool functions
|
||||
########################################################
|
||||
def transform_mcp_tool_to_openai_tool(mcp_tool: MCPTool) -> ChatCompletionToolParam:
    """Convert an MCP tool definition into an OpenAI chat-completions tool."""
    # Normalize the schema first so OpenAI always receives an object schema.
    function_def = FunctionDefinition(
        name=mcp_tool.name,
        description=mcp_tool.description or "",
        parameters=_normalize_mcp_input_schema(mcp_tool.inputSchema),
        strict=False,
    )
    return ChatCompletionToolParam(type="function", function=function_def)
|
||||
|
||||
|
||||
def _normalize_mcp_input_schema(input_schema: dict) -> dict:
|
||||
"""
|
||||
Normalize MCP input schema to ensure it's valid for OpenAI function calling.
|
||||
|
||||
OpenAI requires that function parameters have:
|
||||
- type: 'object'
|
||||
- properties: dict (can be empty)
|
||||
- additionalProperties: false (recommended)
|
||||
"""
|
||||
if not input_schema:
|
||||
return {"type": "object", "properties": {}, "additionalProperties": False}
|
||||
|
||||
# Make a copy to avoid modifying the original
|
||||
normalized_schema = dict(input_schema)
|
||||
|
||||
# Ensure type is 'object'
|
||||
if "type" not in normalized_schema:
|
||||
normalized_schema["type"] = "object"
|
||||
|
||||
# Ensure properties exists (can be empty)
|
||||
if "properties" not in normalized_schema:
|
||||
normalized_schema["properties"] = {}
|
||||
|
||||
# Add additionalProperties if not present (recommended by OpenAI)
|
||||
if "additionalProperties" not in normalized_schema:
|
||||
normalized_schema["additionalProperties"] = False
|
||||
|
||||
return normalized_schema
|
||||
|
||||
|
||||
def transform_mcp_tool_to_openai_responses_api_tool(
    mcp_tool: MCPTool,
) -> FunctionToolParam:
    """Convert an MCP tool definition into an OpenAI Responses API tool."""
    # Normalize the schema first so OpenAI always receives an object schema.
    return FunctionToolParam(
        type="function",
        name=mcp_tool.name,
        description=mcp_tool.description or "",
        parameters=_normalize_mcp_input_schema(mcp_tool.inputSchema),
        strict=False,
    )
|
||||
|
||||
|
||||
async def load_mcp_tools(
    session: ClientSession, format: Literal["mcp", "openai"] = "mcp"
) -> Union[List[MCPTool], List[ChatCompletionToolParam]]:
    """
    Load all available MCP tools.

    Args:
        session: Active MCP session to query.
        format: "mcp" (default) returns the raw MCP tools; "openai"
            converts each tool into an OpenAI-compatible tool first.
    """
    listed = await session.list_tools()
    if format != "openai":
        return listed.tools
    return [transform_mcp_tool_to_openai_tool(mcp_tool=t) for t in listed.tools]
|
||||
|
||||
|
||||
########################################################
|
||||
# Call MCP Tool functions
|
||||
########################################################
|
||||
|
||||
|
||||
async def call_mcp_tool(
    session: ClientSession,
    call_tool_request_params: MCPCallToolRequestParams,
) -> MCPCallToolResult:
    """Invoke a single MCP tool through *session* and return its result."""
    return await session.call_tool(
        name=call_tool_request_params.name,
        arguments=call_tool_request_params.arguments,
    )
|
||||
|
||||
|
||||
def _get_function_arguments(function: FunctionDefinition) -> dict:
    """Extract a tool call's arguments as a dict.

    JSON strings are parsed; anything that is not (or does not parse to)
    a dict collapses to an empty dict.
    """
    raw = function.get("arguments", {})
    if isinstance(raw, str):
        try:
            raw = json.loads(raw)
        except json.JSONDecodeError:
            return {}
    return raw if isinstance(raw, dict) else {}
|
||||
|
||||
|
||||
def transform_openai_tool_call_request_to_mcp_tool_call_request(
    openai_tool: Union[ChatCompletionMessageToolCall, Dict],
) -> MCPCallToolRequestParams:
    """Map an OpenAI tool call onto MCP CallToolRequestParams."""
    fn = openai_tool["function"]
    return MCPCallToolRequestParams(
        name=fn["name"],
        arguments=_get_function_arguments(fn),
    )
|
||||
|
||||
|
||||
async def call_openai_tool(
    session: ClientSession,
    openai_tool: ChatCompletionMessageToolCall,
) -> MCPCallToolResult:
    """
    Execute an OpenAI tool call against an MCP server.

    Args:
        session: The MCP session to use.
        openai_tool: Tool call taken from ``choices[0].message.tool_calls[0]``
            of an OpenAI chat-completions response.

    Returns:
        The result of the MCP tool call.
    """
    request_params = transform_openai_tool_call_request_to_mcp_tool_call_request(
        openai_tool=openai_tool,
    )
    return await call_mcp_tool(
        session=session,
        call_tool_request_params=request_params,
    )
|
||||
984
llm-gateway-competitors/litellm-wheel-src/litellm/files/main.py
Normal file
984
llm-gateway-competitors/litellm-wheel-src/litellm/files/main.py
Normal file
@@ -0,0 +1,984 @@
|
||||
"""
|
||||
Main File for Files API implementation
|
||||
|
||||
https://platform.openai.com/docs/api-reference/files
|
||||
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import time
|
||||
import uuid as uuid_module
|
||||
from functools import partial
|
||||
from typing import Any, Coroutine, Dict, Literal, Optional, Union, cast
|
||||
|
||||
import httpx
|
||||
|
||||
# Type aliases for provider parameters — the set of providers each Files
# API operation supports. Literal aliases give callers static checking on
# custom_llm_provider without an enum import.
FileCreateProvider = Literal[
    "openai",
    "azure",
    "gemini",
    "vertex_ai",
    "bedrock",
    "hosted_vllm",
    "manus",
    "anthropic",
]
# Providers that support retrieving file metadata.
FileRetrieveProvider = Literal[
    "openai", "azure", "gemini", "vertex_ai", "hosted_vllm", "manus", "anthropic"
]
# Providers that support deleting an uploaded file.
FileDeleteProvider = Literal["openai", "azure", "gemini", "manus", "anthropic"]
# Providers that support listing uploaded files.
FileListProvider = Literal["openai", "azure", "manus", "anthropic"]
# Providers that support downloading file content.
FileContentProvider = Literal[
    "openai", "azure", "vertex_ai", "bedrock", "hosted_vllm", "anthropic", "manus"
]
|
||||
|
||||
import litellm
|
||||
from litellm import get_secret_str
|
||||
from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.azure.common_utils import get_azure_credentials
|
||||
from litellm.llms.azure.files.handler import AzureOpenAIFilesAPI
|
||||
from litellm.llms.bedrock.files.handler import BedrockFilesHandler
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
|
||||
from litellm.llms.openai.common_utils import get_openai_credentials
|
||||
from litellm.llms.openai.openai import FileDeleted, FileObject, OpenAIFilesAPI
|
||||
from litellm.llms.vertex_ai.files.handler import VertexAIFilesHandler
|
||||
from litellm.types.llms.openai import (
|
||||
CreateFileRequest,
|
||||
FileContentRequest,
|
||||
FileExpiresAfter,
|
||||
FileTypes,
|
||||
HttpxBinaryResponseContent,
|
||||
OpenAIFileObject,
|
||||
)
|
||||
from litellm.types.router import *
|
||||
from litellm.types.utils import (
|
||||
OPENAI_COMPATIBLE_BATCH_AND_FILES_PROVIDERS,
|
||||
LlmProviders,
|
||||
)
|
||||
from litellm.utils import (
|
||||
ProviderConfigManager,
|
||||
client,
|
||||
get_litellm_params,
|
||||
supports_httpx_timeout,
|
||||
)
|
||||
|
||||
base_llm_http_handler = BaseLLMHTTPHandler()
|
||||
|
||||
####### ENVIRONMENT VARIABLES ###################
|
||||
openai_files_instance = OpenAIFilesAPI()
|
||||
azure_files_instance = AzureOpenAIFilesAPI()
|
||||
vertex_ai_files_instance = VertexAIFilesHandler()
|
||||
bedrock_files_instance = BedrockFilesHandler()
|
||||
#################################################
|
||||
|
||||
|
||||
@client
async def acreate_file(
    file: FileTypes,
    purpose: Literal["assistants", "batch", "fine-tune", "messages"],
    expires_after: Optional[FileExpiresAfter] = None,
    custom_llm_provider: FileCreateProvider = "openai",
    extra_headers: Optional[Dict[str, str]] = None,
    extra_body: Optional[Dict[str, str]] = None,
    **kwargs,
) -> OpenAIFileObject:
    """
    Async: Files are used to upload documents that can be used with features like Assistants, Fine-tuning, and Batch API.

    LiteLLM Equivalent of POST: POST https://api.openai.com/v1/files

    Runs the sync `create_file` in the default executor (preserving the
    current contextvars) and awaits the result if the provider path returned
    a coroutine.
    """
    # Fix: removed the redundant `try/except Exception as e: raise e` wrapper,
    # which only re-raised the exception and added a useless traceback frame.
    loop = asyncio.get_event_loop()
    # Signal the sync entrypoint that it is being driven asynchronously.
    kwargs["acreate_file"] = True

    call_args = {
        "file": file,
        "purpose": purpose,
        "expires_after": expires_after,
        "custom_llm_provider": custom_llm_provider,
        "extra_headers": extra_headers,
        "extra_body": extra_body,
        **kwargs,
    }

    # Use a partial + copied context so contextvars survive the executor hop.
    func = partial(create_file, **call_args)
    ctx = contextvars.copy_context()
    func_with_context = partial(ctx.run, func)
    init_response = await loop.run_in_executor(None, func_with_context)

    # Provider implementations may return either a value or a coroutine.
    if asyncio.iscoroutine(init_response):
        return await init_response
    return init_response  # type: ignore
|
||||
|
||||
|
||||
@client
def create_file(
    file: FileTypes,
    purpose: Literal["assistants", "batch", "fine-tune", "messages"],
    expires_after: Optional[FileExpiresAfter] = None,
    custom_llm_provider: Optional[FileCreateProvider] = None,
    extra_headers: Optional[Dict[str, str]] = None,
    extra_body: Optional[Dict[str, str]] = None,
    **kwargs,
) -> Union[OpenAIFileObject, Coroutine[Any, Any, OpenAIFileObject]]:
    """
    Upload a file for use with features like Assistants, Fine-tuning, and Batch API.

    LiteLLM Equivalent of POST: POST https://api.openai.com/v1/files

    Args:
        file: The file payload to upload.
        purpose: OpenAI-style file purpose.
        expires_after: Optional expiry policy, forwarded only when provided.
        custom_llm_provider: Target provider; routed via provider files config,
            OpenAI-compatible handlers, or Azure.
        extra_headers: Extra HTTP headers forwarded to the provider.
        extra_body: Extra body parameters forwarded to the provider.
        **kwargs: Provider/credential overrides plus internal flags
            ("acreate_file", "litellm_logging_obj", "client", ...).

    Returns:
        OpenAIFileObject, or a coroutine resolving to one when called via
        acreate_file (kwargs["acreate_file"] is True).

    Raises:
        ValueError: when no litellm_logging_obj is supplied (injected by @client).
        BadRequestError: for unsupported providers.
    """
    try:
        # NOTE: the async flag must be popped BEFORE kwargs are parsed into
        # GenericLiteLLMParams / litellm_params_dict below.
        _is_async = kwargs.pop("acreate_file", False) is True
        optional_params = GenericLiteLLMParams(**kwargs)
        litellm_params_dict = dict(**kwargs)
        logging_obj = cast(
            Optional[LiteLLMLoggingObj], kwargs.get("litellm_logging_obj")
        )
        if logging_obj is None:
            raise ValueError("logging_obj is required")
        client = kwargs.get("client")

        ### TIMEOUT LOGIC ###
        timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
        # set timeout for 10 minutes by default

        # Normalize timeout: collapse httpx.Timeout to its read timeout for
        # providers that can't accept an httpx.Timeout object.
        if (
            timeout is not None
            and isinstance(timeout, httpx.Timeout)
            and supports_httpx_timeout(cast(str, custom_llm_provider)) is False
        ):
            read_timeout = timeout.read or 600
            timeout = read_timeout  # default 10 min timeout
        elif timeout is not None and not isinstance(timeout, httpx.Timeout):
            timeout = float(timeout)  # type: ignore
        elif timeout is None:
            timeout = 600.0

        # expires_after is only included in the request when explicitly set,
        # since some providers reject unknown fields.
        if expires_after is not None:
            _create_file_request = CreateFileRequest(
                file=file,
                purpose=purpose,
                expires_after=expires_after,
                extra_headers=extra_headers,
                extra_body=extra_body,
            )
        else:
            _create_file_request = CreateFileRequest(
                file=file,
                purpose=purpose,
                extra_headers=extra_headers,
                extra_body=extra_body,
            )

        # Provider config path takes precedence (Bedrock, Manus, Anthropic, ...).
        # NOTE(review): LlmProviders(custom_llm_provider) raises ValueError for
        # unknown provider strings (and None) before the BadRequestError branch
        # below can fire — confirm intended.
        provider_config = ProviderConfigManager.get_provider_files_config(
            model="",
            provider=LlmProviders(custom_llm_provider),
        )
        if provider_config is not None:
            response = base_llm_http_handler.create_file(
                provider_config=provider_config,
                litellm_params=litellm_params_dict,
                create_file_data=_create_file_request,
                headers=extra_headers or {},
                api_base=optional_params.api_base,
                api_key=optional_params.api_key,
                logging_obj=logging_obj,
                _is_async=_is_async,
                client=(
                    client
                    if client is not None
                    and isinstance(client, (HTTPHandler, AsyncHTTPHandler))
                    else None
                ),
                timeout=timeout,
            )
        elif custom_llm_provider in OPENAI_COMPATIBLE_BATCH_AND_FILES_PROVIDERS:
            openai_creds = get_openai_credentials(
                api_base=optional_params.api_base,
                api_key=optional_params.api_key,
                organization=optional_params.organization,
            )
            response = openai_files_instance.create_file(
                _is_async=_is_async,
                api_base=openai_creds.api_base,
                api_key=openai_creds.api_key,
                timeout=timeout,
                max_retries=optional_params.max_retries,
                organization=openai_creds.organization,
                create_file_data=_create_file_request,
            )
        elif custom_llm_provider == "azure":
            azure_creds = get_azure_credentials(
                api_base=optional_params.api_base,
                api_key=optional_params.api_key,
                api_version=optional_params.api_version,
            )
            response = azure_files_instance.create_file(
                _is_async=_is_async,
                api_base=azure_creds.api_base,
                api_key=azure_creds.api_key,
                api_version=azure_creds.api_version,
                timeout=timeout,
                max_retries=optional_params.max_retries,
                create_file_data=_create_file_request,
                litellm_params=litellm_params_dict,
            )
        else:
            raise litellm.exceptions.BadRequestError(
                message="LiteLLM doesn't support {} for 'create_file'. Only ['openai', 'azure', 'vertex_ai', 'manus', 'anthropic'] are supported.".format(
                    custom_llm_provider
                ),
                model="n/a",
                llm_provider=custom_llm_provider,
                response=httpx.Response(
                    status_code=400,
                    content="Unsupported provider",
                    request=httpx.Request(method="create_file", url="https://github.com/BerriAI/litellm"),  # type: ignore
                ),
            )
        return response
    except Exception as e:
        raise e
|
||||
|
||||
|
||||
@client
async def afile_retrieve(
    file_id: str,
    custom_llm_provider: FileRetrieveProvider = "openai",
    extra_headers: Optional[Dict[str, str]] = None,
    extra_body: Optional[Dict[str, str]] = None,
    **kwargs,
) -> OpenAIFileObject:
    """
    Async: Retrieve metadata for a single uploaded file.

    LiteLLM Equivalent of GET https://api.openai.com/v1/files
    """
    try:
        # Flag the sync entrypoint to run in async mode.
        kwargs["is_async"] = True

        sync_call = partial(
            file_retrieve,
            file_id,
            custom_llm_provider,
            extra_headers,
            extra_body,
            **kwargs,
        )
        # Run inside a copied context so contextvars survive the executor hop.
        ctx_call = partial(contextvars.copy_context().run, sync_call)
        result = await asyncio.get_event_loop().run_in_executor(None, ctx_call)

        if asyncio.iscoroutine(result):
            result = await result

        # Re-wrap the provider response into the OpenAI file schema.
        return OpenAIFileObject(**result.model_dump())
    except Exception as e:
        raise e
|
||||
|
||||
|
||||
@client
def file_retrieve(
    file_id: str,
    custom_llm_provider: FileRetrieveProvider = "openai",
    extra_headers: Optional[Dict[str, str]] = None,
    extra_body: Optional[Dict[str, str]] = None,
    **kwargs,
) -> FileObject:
    """
    Retrieve metadata for the specified file.

    LiteLLM Equivalent of GET https://api.openai.com/v1/files/{file_id}

    Args:
        file_id: Provider-side file id to look up.
        custom_llm_provider: Target provider; OpenAI-compatible, Azure, or any
            provider exposing a files config (Manus, Bedrock, ...).
        extra_headers / extra_body: Forwarded to the provider request.
        **kwargs: Credential overrides plus internal flags ("is_async",
            "litellm_logging_obj", "client", ...).

    Returns:
        FileObject (or a coroutine resolving to one when kwargs["is_async"]).

    Raises:
        BadRequestError: for unsupported providers.
    """
    try:
        optional_params = GenericLiteLLMParams(**kwargs)
        ### TIMEOUT LOGIC ###
        timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
        # set timeout for 10 minutes by default

        # Collapse httpx.Timeout for providers that only accept a float.
        if (
            timeout is not None
            and isinstance(timeout, httpx.Timeout)
            and supports_httpx_timeout(custom_llm_provider) is False
        ):
            read_timeout = timeout.read or 600
            timeout = read_timeout  # default 10 min timeout
        elif timeout is not None and not isinstance(timeout, httpx.Timeout):
            timeout = float(timeout)  # type: ignore
        elif timeout is None:
            timeout = 600.0

        # Set by afile_retrieve; selects the async provider code paths.
        _is_async = kwargs.pop("is_async", False) is True

        if custom_llm_provider in OPENAI_COMPATIBLE_BATCH_AND_FILES_PROVIDERS:
            openai_creds = get_openai_credentials(
                api_base=optional_params.api_base,
                api_key=optional_params.api_key,
                organization=optional_params.organization,
            )
            response = openai_files_instance.retrieve_file(
                file_id=file_id,
                _is_async=_is_async,
                api_base=openai_creds.api_base,
                api_key=openai_creds.api_key,
                timeout=timeout,
                max_retries=optional_params.max_retries,
                organization=openai_creds.organization,
            )
        elif custom_llm_provider == "azure":
            azure_creds = get_azure_credentials(
                api_base=optional_params.api_base,
                api_key=optional_params.api_key,
                api_version=optional_params.api_version,
            )
            response = azure_files_instance.retrieve_file(
                _is_async=_is_async,
                api_base=azure_creds.api_base,
                api_key=azure_creds.api_key,
                api_version=azure_creds.api_version,
                timeout=timeout,
                max_retries=optional_params.max_retries,
                file_id=file_id,
            )
        else:
            # Try using provider config pattern (for Manus, Bedrock, etc.)
            provider_config = ProviderConfigManager.get_provider_files_config(
                model="",
                provider=LlmProviders(custom_llm_provider),
            )
            if provider_config is not None:
                litellm_params_dict = get_litellm_params(**kwargs)
                litellm_params_dict["api_key"] = optional_params.api_key
                litellm_params_dict["api_base"] = optional_params.api_base

                logging_obj = kwargs.get("litellm_logging_obj")
                if logging_obj is None:
                    # Runtime import avoids a module-level import cycle with
                    # litellm_logging.
                    from litellm.litellm_core_utils.litellm_logging import (
                        Logging as LiteLLMLoggingObj,
                    )

                    # Minimal synthetic logging object for direct (non-@client
                    # injected) invocations.
                    logging_obj = LiteLLMLoggingObj(
                        model="",
                        messages=[],
                        stream=False,
                        call_type="afile_retrieve" if _is_async else "file_retrieve",
                        start_time=time.time(),
                        litellm_call_id=kwargs.get(
                            "litellm_call_id", str(uuid_module.uuid4())
                        ),
                        function_id=str(kwargs.get("id") or ""),
                    )

                client = kwargs.get("client")
                response = base_llm_http_handler.retrieve_file(
                    file_id=file_id,
                    provider_config=provider_config,
                    litellm_params=litellm_params_dict,
                    headers=extra_headers or {},
                    logging_obj=logging_obj,
                    _is_async=_is_async,
                    # Only forward explicitly typed HTTP handlers; anything
                    # else falls back to the handler's own client.
                    client=(
                        client
                        if client is not None
                        and isinstance(client, (HTTPHandler, AsyncHTTPHandler))
                        else None
                    ),
                    timeout=timeout,
                )
            else:
                raise litellm.exceptions.BadRequestError(
                    message="LiteLLM doesn't support {} for 'file_retrieve'. Only 'openai', 'azure', 'manus', and 'anthropic' are supported.".format(
                        custom_llm_provider
                    ),
                    model="n/a",
                    llm_provider=custom_llm_provider,
                    response=httpx.Response(
                        status_code=400,
                        content="Unsupported provider",
                        # NOTE(review): method="create_thread" looks like a
                        # copy-paste leftover from the threads API — confirm.
                        request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"),  # type: ignore
                    ),
                )

        return cast(FileObject, response)
    except Exception as e:
        raise e
|
||||
|
||||
|
||||
# Delete file
@client
async def afile_delete(
    file_id: str,
    custom_llm_provider: FileDeleteProvider = "openai",
    extra_headers: Optional[Dict[str, str]] = None,
    extra_body: Optional[Dict[str, str]] = None,
    **kwargs,
) -> FileDeleted:
    """
    Async: Delete file.

    LiteLLM Equivalent of DELETE https://api.openai.com/v1/files/{file_id}

    Args:
        file_id: Provider-side file id to delete.
        custom_llm_provider: Target provider.
        extra_headers / extra_body: Forwarded to the provider request.
        **kwargs: Credential overrides; an optional "model" entry lets
            file_delete infer the provider from the model name.

    Returns:
        FileDeleted: deletion confirmation.
    """
    try:
        loop = asyncio.get_event_loop()
        # Optional model name used by file_delete to resolve the provider.
        model = kwargs.pop("model", None)
        # Flag the sync entrypoint to run in async mode.
        kwargs["is_async"] = True

        # Use a partial function to pass your keyword arguments
        func = partial(
            file_delete,
            file_id,
            model,
            custom_llm_provider,
            extra_headers,
            extra_body,
            **kwargs,
        )

        # Add the context to the function so contextvars survive the executor.
        ctx = contextvars.copy_context()
        func_with_context = partial(ctx.run, func)
        init_response = await loop.run_in_executor(None, func_with_context)
        if asyncio.iscoroutine(init_response):
            response = await init_response
        else:
            response = init_response  # type: ignore

        return cast(FileDeleted, response)
    except Exception as e:
        raise e
|
||||
|
||||
|
||||
@client
def file_delete(
    file_id: str,
    model: Optional[str] = None,
    custom_llm_provider: Union[FileDeleteProvider, str] = "openai",
    extra_headers: Optional[Dict[str, str]] = None,
    extra_body: Optional[Dict[str, str]] = None,
    **kwargs,
) -> FileDeleted:
    """
    Delete file.

    LiteLLM Equivalent of DELETE https://api.openai.com/v1/files/{file_id}

    Args:
        file_id: Provider-side file id to delete.
        model: Optional model name; when set, the provider is inferred via
            get_llm_provider (best-effort — failures are ignored).
        custom_llm_provider: Target provider when no model is given.
        extra_headers / extra_body: Forwarded to the provider request.
        **kwargs: Credential overrides plus internal flags ("is_async",
            "litellm_logging_obj", "client", ...).

    Returns:
        FileDeleted (or a coroutine resolving to one when kwargs["is_async"]).

    Raises:
        BadRequestError: for unsupported providers.
    """
    try:
        # Best-effort provider resolution from the model name; errors fall
        # through to the explicitly passed custom_llm_provider.
        try:
            if model is not None:
                _, custom_llm_provider, _, _ = get_llm_provider(
                    model, custom_llm_provider
                )
        except Exception:
            pass
        optional_params = GenericLiteLLMParams(**kwargs)
        litellm_params_dict = get_litellm_params(**kwargs)
        ### TIMEOUT LOGIC ###
        timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
        # set timeout for 10 minutes by default
        client = kwargs.get("client")

        # Collapse httpx.Timeout for providers that only accept a float.
        if (
            timeout is not None
            and isinstance(timeout, httpx.Timeout)
            and supports_httpx_timeout(custom_llm_provider) is False
        ):
            read_timeout = timeout.read or 600
            timeout = read_timeout  # default 10 min timeout
        elif timeout is not None and not isinstance(timeout, httpx.Timeout):
            timeout = float(timeout)  # type: ignore
        elif timeout is None:
            timeout = 600.0
        # Set by afile_delete; selects the async provider code paths.
        _is_async = kwargs.pop("is_async", False) is True
        if custom_llm_provider in OPENAI_COMPATIBLE_BATCH_AND_FILES_PROVIDERS:
            openai_creds = get_openai_credentials(
                api_base=optional_params.api_base,
                api_key=optional_params.api_key,
                organization=optional_params.organization,
            )
            response = openai_files_instance.delete_file(
                file_id=file_id,
                _is_async=_is_async,
                api_base=openai_creds.api_base,
                api_key=openai_creds.api_key,
                timeout=timeout,
                max_retries=optional_params.max_retries,
                organization=openai_creds.organization,
            )
        elif custom_llm_provider == "azure":
            azure_creds = get_azure_credentials(
                api_base=optional_params.api_base,
                api_key=optional_params.api_key,
                api_version=optional_params.api_version,
            )
            response = azure_files_instance.delete_file(
                _is_async=_is_async,
                api_base=azure_creds.api_base,
                api_key=azure_creds.api_key,
                api_version=azure_creds.api_version,
                timeout=timeout,
                max_retries=optional_params.max_retries,
                file_id=file_id,
                client=client,
                litellm_params=litellm_params_dict,
            )
        else:
            # Try using provider config pattern (for Manus, Bedrock, etc.)
            provider_config = ProviderConfigManager.get_provider_files_config(
                model="",
                provider=LlmProviders(custom_llm_provider),
            )
            if provider_config is not None:
                litellm_params_dict["api_key"] = optional_params.api_key
                litellm_params_dict["api_base"] = optional_params.api_base

                logging_obj = kwargs.get("litellm_logging_obj")
                if logging_obj is None:
                    # Runtime import avoids a module-level import cycle.
                    from litellm.litellm_core_utils.litellm_logging import (
                        Logging as LiteLLMLoggingObj,
                    )

                    # Minimal synthetic logging object for direct invocations.
                    logging_obj = LiteLLMLoggingObj(
                        model="",
                        messages=[],
                        stream=False,
                        call_type="afile_delete" if _is_async else "file_delete",
                        start_time=time.time(),
                        litellm_call_id=kwargs.get(
                            "litellm_call_id", str(uuid_module.uuid4())
                        ),
                        function_id=str(kwargs.get("id") or ""),
                    )

                response = base_llm_http_handler.delete_file(
                    file_id=file_id,
                    provider_config=provider_config,
                    litellm_params=litellm_params_dict,
                    headers=extra_headers or {},
                    logging_obj=logging_obj,
                    _is_async=_is_async,
                    # Only forward explicitly typed HTTP handlers.
                    client=(
                        client
                        if client is not None
                        and isinstance(client, (HTTPHandler, AsyncHTTPHandler))
                        else None
                    ),
                    timeout=timeout,
                )
            else:
                raise litellm.exceptions.BadRequestError(
                    message="LiteLLM doesn't support {} for 'file_delete'. Only 'openai', 'azure', 'gemini', 'manus', and 'anthropic' are supported.".format(
                        custom_llm_provider
                    ),
                    model="n/a",
                    llm_provider=custom_llm_provider,
                    response=httpx.Response(
                        status_code=400,
                        content="Unsupported provider",
                        # NOTE(review): method="create_thread" looks like a
                        # copy-paste leftover from the threads API — confirm.
                        request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"),  # type: ignore
                    ),
                )
        return cast(FileDeleted, response)
    except Exception as e:
        raise e
|
||||
|
||||
|
||||
# List files
@client
async def afile_list(
    custom_llm_provider: FileListProvider = "openai",
    purpose: Optional[str] = None,
    extra_headers: Optional[Dict[str, str]] = None,
    extra_body: Optional[Dict[str, str]] = None,
    **kwargs,
):
    """
    Async: List uploaded files.

    LiteLLM Equivalent of GET https://api.openai.com/v1/files
    """
    try:
        # Flag the sync entrypoint to run in async mode.
        kwargs["is_async"] = True

        sync_call = partial(
            file_list,
            custom_llm_provider,
            purpose,
            extra_headers,
            extra_body,
            **kwargs,
        )
        # Run inside a copied context so contextvars survive the executor hop.
        ctx_call = partial(contextvars.copy_context().run, sync_call)
        result = await asyncio.get_event_loop().run_in_executor(None, ctx_call)

        if asyncio.iscoroutine(result):
            result = await result  # type: ignore

        return result
    except Exception as e:
        raise e
|
||||
|
||||
|
||||
@client
def file_list(
    custom_llm_provider: FileListProvider = "openai",
    purpose: Optional[str] = None,
    extra_headers: Optional[Dict[str, str]] = None,
    extra_body: Optional[Dict[str, str]] = None,
    **kwargs,
):
    """
    List uploaded files.

    LiteLLM Equivalent of GET https://api.openai.com/v1/files

    Args:
        custom_llm_provider: Target provider; a provider files config
            (Manus, Bedrock, Vertex AI, ...) takes precedence, then
            OpenAI-compatible providers, then Azure.
        purpose: Optional purpose filter forwarded to the provider.
        extra_headers / extra_body: Forwarded to the provider request.
        **kwargs: Credential overrides plus internal flags ("is_async",
            "litellm_logging_obj", "client", ...).

    Returns:
        The provider's file-list response (or a coroutine resolving to it
        when kwargs["is_async"]).

    Raises:
        BadRequestError: for unsupported providers.
    """
    try:
        optional_params = GenericLiteLLMParams(**kwargs)
        ### TIMEOUT LOGIC ###
        timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
        # set timeout for 10 minutes by default

        # Collapse httpx.Timeout for providers that only accept a float.
        if (
            timeout is not None
            and isinstance(timeout, httpx.Timeout)
            and supports_httpx_timeout(custom_llm_provider) is False
        ):
            read_timeout = timeout.read or 600
            timeout = read_timeout  # default 10 min timeout
        elif timeout is not None and not isinstance(timeout, httpx.Timeout):
            timeout = float(timeout)  # type: ignore
        elif timeout is None:
            timeout = 600.0

        # Set by afile_list; selects the async provider code paths.
        _is_async = kwargs.pop("is_async", False) is True

        # Check if provider has a custom files config (e.g., Manus, Bedrock, Vertex AI)
        # NOTE(review): LlmProviders(custom_llm_provider) raises ValueError for
        # unknown provider strings before the BadRequestError branch — confirm.
        provider_config = ProviderConfigManager.get_provider_files_config(
            model="",
            provider=LlmProviders(custom_llm_provider),
        )
        if provider_config is not None:
            litellm_params_dict = get_litellm_params(**kwargs)
            litellm_params_dict["api_key"] = optional_params.api_key
            litellm_params_dict["api_base"] = optional_params.api_base

            logging_obj = kwargs.get("litellm_logging_obj")
            if logging_obj is None:
                # Runtime import avoids a module-level import cycle.
                from litellm.litellm_core_utils.litellm_logging import (
                    Logging as LiteLLMLoggingObj,
                )

                # Minimal synthetic logging object for direct invocations.
                logging_obj = LiteLLMLoggingObj(
                    model="",
                    messages=[],
                    stream=False,
                    call_type="afile_list" if _is_async else "file_list",
                    start_time=time.time(),
                    litellm_call_id=kwargs.get(
                        "litellm_call_id", str(uuid_module.uuid4())
                    ),
                    function_id=str(kwargs.get("id", "")),
                )

            client = kwargs.get("client")
            response = base_llm_http_handler.list_files(
                purpose=purpose,
                provider_config=provider_config,
                litellm_params=litellm_params_dict,
                headers=extra_headers or {},
                logging_obj=logging_obj,
                _is_async=_is_async,
                # Only forward explicitly typed HTTP handlers.
                client=(
                    client
                    if client is not None
                    and isinstance(client, (HTTPHandler, AsyncHTTPHandler))
                    else None
                ),
                timeout=timeout,
            )
            return response
        elif custom_llm_provider in OPENAI_COMPATIBLE_BATCH_AND_FILES_PROVIDERS:
            openai_creds = get_openai_credentials(
                api_base=optional_params.api_base,
                api_key=optional_params.api_key,
                organization=optional_params.organization,
            )
            response = openai_files_instance.list_files(
                purpose=purpose,
                _is_async=_is_async,
                api_base=openai_creds.api_base,
                api_key=openai_creds.api_key,
                timeout=timeout,
                max_retries=optional_params.max_retries,
                organization=openai_creds.organization,
            )
        elif custom_llm_provider == "azure":
            azure_creds = get_azure_credentials(
                api_base=optional_params.api_base,
                api_key=optional_params.api_key,
                api_version=optional_params.api_version,
            )
            response = azure_files_instance.list_files(
                _is_async=_is_async,
                api_base=azure_creds.api_base,
                api_key=azure_creds.api_key,
                api_version=azure_creds.api_version,
                timeout=timeout,
                max_retries=optional_params.max_retries,
                purpose=purpose,
            )
        else:
            raise litellm.exceptions.BadRequestError(
                message="LiteLLM doesn't support {} for 'file_list'. Only 'openai', 'azure', 'manus', and 'anthropic' are supported.".format(
                    custom_llm_provider
                ),
                model="n/a",
                llm_provider=custom_llm_provider,
                response=httpx.Response(
                    status_code=400,
                    content="Unsupported provider",
                    request=httpx.Request(method="file_list", url="https://github.com/BerriAI/litellm"),  # type: ignore
                ),
            )
        return response
    except Exception as e:
        raise e
|
||||
|
||||
|
||||
@client
async def afile_content(
    file_id: str,
    custom_llm_provider: FileContentProvider = "openai",
    extra_headers: Optional[Dict[str, str]] = None,
    extra_body: Optional[Dict[str, str]] = None,
    **kwargs,
) -> HttpxBinaryResponseContent:
    """
    Async: Get file contents.

    LiteLLM Equivalent of GET https://api.openai.com/v1/files
    """
    try:
        # Flag the sync entrypoint to run in async mode.
        kwargs["afile_content"] = True
        # Optional model name used by file_content to resolve the provider.
        model = kwargs.pop("model", None)

        sync_call = partial(
            file_content,
            file_id,
            model,
            custom_llm_provider,
            extra_headers,
            extra_body,
            **kwargs,
        )
        # Run inside a copied context so contextvars survive the executor hop.
        ctx_call = partial(contextvars.copy_context().run, sync_call)
        result = await asyncio.get_event_loop().run_in_executor(None, ctx_call)

        if asyncio.iscoroutine(result):
            result = await result  # type: ignore

        return result
    except Exception as e:
        raise e
|
||||
|
||||
|
||||
@client
def file_content(
    file_id: str,
    model: Optional[str] = None,
    custom_llm_provider: Optional[Union[FileContentProvider, str]] = None,
    extra_headers: Optional[Dict[str, str]] = None,
    extra_body: Optional[Dict[str, str]] = None,
    **kwargs,
) -> Union[HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent]]:
    """
    Return the raw contents of the specified file.

    LiteLLM Equivalent of GET https://api.openai.com/v1/files/{file_id}/content

    Args:
        file_id: Provider-side file id whose content is fetched.
        model: Optional model name; when set, the provider is inferred via
            get_llm_provider (best-effort — failures are ignored).
        custom_llm_provider: Target provider when no model is given.
        extra_headers / extra_body: Forwarded to the provider request.
        **kwargs: Credential overrides plus internal flags ("afile_content",
            "litellm_logging_obj", "client", ...).

    Returns:
        HttpxBinaryResponseContent (or a coroutine resolving to one when
        kwargs["afile_content"]).

    Raises:
        BadRequestError: for unsupported providers.
    """
    try:
        optional_params = GenericLiteLLMParams(**kwargs)
        litellm_params_dict = get_litellm_params(**kwargs)
        ### TIMEOUT LOGIC ###
        timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
        client = kwargs.get("client")
        # set timeout for 10 minutes by default

        # Best-effort provider resolution from the model name.
        try:
            if model is not None:
                _, custom_llm_provider, _, _ = get_llm_provider(
                    model, custom_llm_provider
                )
        except Exception:
            pass

        # Collapse httpx.Timeout for providers that only accept a float.
        if (
            timeout is not None
            and isinstance(timeout, httpx.Timeout)
            and supports_httpx_timeout(cast(str, custom_llm_provider)) is False
        ):
            read_timeout = timeout.read or 600
            timeout = read_timeout  # default 10 min timeout
        elif timeout is not None and not isinstance(timeout, httpx.Timeout):
            timeout = float(timeout)  # type: ignore
        elif timeout is None:
            timeout = 600.0

        _file_content_request = FileContentRequest(
            file_id=file_id,
            extra_headers=extra_headers,
            extra_body=extra_body,
        )

        # Set by afile_content; selects the async provider code paths.
        _is_async = kwargs.pop("afile_content", False) is True

        # Check if provider has a custom files config (e.g., Anthropic, Manus)
        provider_config = ProviderConfigManager.get_provider_files_config(
            model="",
            provider=LlmProviders(custom_llm_provider),
        )
        if provider_config is not None:
            litellm_params_dict["api_key"] = optional_params.api_key
            litellm_params_dict["api_base"] = optional_params.api_base

            logging_obj = kwargs.get("litellm_logging_obj")
            if logging_obj is None:
                # NOTE(review): unlike file_retrieve/file_delete, this branch
                # uses LiteLLMLoggingObj without a local runtime import —
                # confirm the name is bound at module level at runtime.
                logging_obj = LiteLLMLoggingObj(
                    model="",
                    messages=[],
                    stream=False,
                    call_type="afile_content" if _is_async else "file_content",
                    start_time=time.time(),
                    litellm_call_id=kwargs.get(
                        "litellm_call_id", str(uuid_module.uuid4())
                    ),
                    function_id=str(kwargs.get("id") or ""),
                )

            response = base_llm_http_handler.retrieve_file_content(
                file_content_request=_file_content_request,
                provider_config=provider_config,
                litellm_params=litellm_params_dict,
                headers=extra_headers or {},
                logging_obj=logging_obj,
                _is_async=_is_async,
                # Only forward explicitly typed HTTP handlers.
                client=(
                    client
                    if client is not None
                    and isinstance(client, (HTTPHandler, AsyncHTTPHandler))
                    else None
                ),
                timeout=timeout,
            )
            return response

        if custom_llm_provider in OPENAI_COMPATIBLE_BATCH_AND_FILES_PROVIDERS:
            openai_creds = get_openai_credentials(
                api_base=optional_params.api_base,
                api_key=optional_params.api_key,
                organization=optional_params.organization,
            )
            response = openai_files_instance.file_content(
                _is_async=_is_async,
                file_content_request=_file_content_request,
                api_base=openai_creds.api_base,
                api_key=openai_creds.api_key,
                timeout=timeout,
                max_retries=optional_params.max_retries,
                organization=openai_creds.organization,
            )
        elif custom_llm_provider == "azure":
            azure_creds = get_azure_credentials(
                api_base=optional_params.api_base,
                api_key=optional_params.api_key,
                api_version=optional_params.api_version,
            )
            response = azure_files_instance.file_content(
                _is_async=_is_async,
                api_base=azure_creds.api_base,
                api_key=azure_creds.api_key,
                api_version=azure_creds.api_version,
                timeout=timeout,
                max_retries=optional_params.max_retries,
                file_content_request=_file_content_request,
                client=client,
                litellm_params=litellm_params_dict,
            )
        elif custom_llm_provider == "vertex_ai":
            api_base = optional_params.api_base or ""
            # Vertex credentials fall back: explicit params -> litellm module
            # settings -> environment secrets.
            vertex_ai_project = (
                optional_params.vertex_project
                or litellm.vertex_project
                or get_secret_str("VERTEXAI_PROJECT")
            )
            vertex_ai_location = (
                optional_params.vertex_location
                or litellm.vertex_location
                or get_secret_str("VERTEXAI_LOCATION")
            )
            vertex_credentials = optional_params.vertex_credentials or get_secret_str(
                "VERTEXAI_CREDENTIALS"
            )

            response = vertex_ai_files_instance.file_content(
                _is_async=_is_async,
                file_content_request=_file_content_request,
                api_base=api_base,
                vertex_credentials=vertex_credentials,
                vertex_project=vertex_ai_project,
                vertex_location=vertex_ai_location,
                timeout=timeout,
                max_retries=optional_params.max_retries,
            )
        elif custom_llm_provider == "bedrock":
            response = bedrock_files_instance.file_content(
                _is_async=_is_async,
                file_content_request=_file_content_request,
                api_base=optional_params.api_base,
                optional_params=litellm_params_dict,
                timeout=timeout,
                max_retries=optional_params.max_retries,
            )
        else:
            raise litellm.exceptions.BadRequestError(
                message="LiteLLM doesn't support {} for 'file_content'. Supported providers are 'openai', 'azure', 'vertex_ai', 'bedrock', 'manus', 'anthropic'.".format(
                    custom_llm_provider
                ),
                model="n/a",
                llm_provider=custom_llm_provider,
                response=httpx.Response(
                    status_code=400,
                    content="Unsupported provider",
                    # NOTE(review): method="create_thread" looks like a
                    # copy-paste leftover from the threads API — confirm.
                    request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"),  # type: ignore
                ),
            )
        return response
    except Exception as e:
        raise e
|
||||
@@ -0,0 +1,32 @@
|
||||
from typing import Optional
|
||||
|
||||
from litellm.types.llms.openai import CreateFileRequest
|
||||
from litellm.types.utils import ExtractedFileData
|
||||
|
||||
|
||||
class FilesAPIUtils:
    """
    Utils for files API interface on litellm
    """

    @staticmethod
    def is_batch_jsonl_file(
        create_file_data: CreateFileRequest, extracted_file_data: ExtractedFileData
    ) -> bool:
        """Return True when the upload looks like a batch-API JSONL payload."""
        # Must be a batch upload with actual content and an accepted MIME type.
        if create_file_data.get("purpose") != "batch":
            return False
        if extracted_file_data.get("content") is None:
            return False
        return FilesAPIUtils.valid_content_type(
            extracted_file_data.get("content_type")
        )

    @staticmethod
    def valid_content_type(content_type: Optional[str]) -> bool:
        """Return True for MIME types accepted as batch JSONL uploads."""
        return content_type in {"application/jsonl", "application/octet-stream"}
|
||||
@@ -0,0 +1,826 @@
|
||||
"""
|
||||
Main File for Fine Tuning API implementation
|
||||
|
||||
https://platform.openai.com/docs/api-reference/fine-tuning
|
||||
|
||||
- fine_tuning.jobs.create()
|
||||
- fine_tuning.jobs.list()
|
||||
- client.fine_tuning.jobs.list_events()
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import os
|
||||
from functools import partial
|
||||
from typing import Any, Coroutine, Dict, Literal, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.azure.fine_tuning.handler import AzureOpenAIFineTuningAPI
|
||||
from litellm.llms.openai.fine_tuning.handler import OpenAIFineTuningAPI
|
||||
from litellm.llms.vertex_ai.fine_tuning.handler import VertexFineTuningAPI
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import FineTuningJobCreate, Hyperparameters
|
||||
from litellm.types.router import *
|
||||
from litellm.types.utils import LiteLLMFineTuningJob
|
||||
from litellm.utils import client, supports_httpx_timeout
|
||||
|
||||
####### ENVIRONMENT VARIABLES ###################
|
||||
openai_fine_tuning_apis_instance = OpenAIFineTuningAPI()
|
||||
azure_fine_tuning_apis_instance = AzureOpenAIFineTuningAPI()
|
||||
vertex_fine_tuning_apis_instance = VertexFineTuningAPI()
|
||||
#################################################
|
||||
|
||||
|
||||
def _prepare_azure_extra_body(
|
||||
extra_body: Optional[Dict[str, Any]],
|
||||
kwargs: Dict[str, Any],
|
||||
azure_specific_hyperparams: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Prepare extra_body for Azure fine-tuning API by combining Azure-specific parameters.
|
||||
|
||||
Azure fine-tuning API accepts additional parameters beyond the standard OpenAI spec:
|
||||
- trainingType: Type of training (e.g., 1 for supervised fine-tuning)
|
||||
- prompt_loss_weight: Weight for prompt loss in training
|
||||
|
||||
These parameters must be passed in the extra_body field when calling the Azure OpenAI SDK.
|
||||
|
||||
Args:
|
||||
extra_body: Optional existing extra_body dict
|
||||
kwargs: Request kwargs that may contain Azure-specific parameters
|
||||
azure_specific_hyperparams: Dict of Azure-specific hyperparameters already extracted
|
||||
|
||||
Returns:
|
||||
Dict containing all Azure-specific parameters to be passed in extra_body
|
||||
"""
|
||||
if extra_body is None:
|
||||
extra_body = {}
|
||||
|
||||
# Azure-specific root-level parameters
|
||||
azure_specific_params = ["trainingType"]
|
||||
for param in azure_specific_params:
|
||||
if param in kwargs:
|
||||
extra_body[param] = kwargs[param]
|
||||
|
||||
# Add Azure-specific hyperparameters
|
||||
if azure_specific_hyperparams:
|
||||
extra_body.update(azure_specific_hyperparams)
|
||||
|
||||
return extra_body
|
||||
|
||||
|
||||
@client
async def acreate_fine_tuning_job(
    model: str,
    training_file: str,
    hyperparameters: Optional[dict] = {},
    suffix: Optional[str] = None,
    validation_file: Optional[str] = None,
    integrations: Optional[List[str]] = None,
    seed: Optional[int] = None,
    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
    extra_headers: Optional[Dict[str, str]] = None,
    extra_body: Optional[Dict[str, str]] = None,
    **kwargs,
) -> LiteLLMFineTuningJob:
    """
    Async: Create a fine-tuning job for the given model and training file.

    Delegates to the synchronous ``create_fine_tuning_job`` in a thread
    executor; the sync entrypoint returns a coroutine when the async flag
    below is set, which is then awaited here.
    """
    verbose_logger.debug(
        "inside acreate_fine_tuning_job model=%s and kwargs=%s", model, kwargs
    )
    # Signal the sync implementation that this call is async.
    kwargs["acreate_fine_tuning_job"] = True

    sync_call = partial(
        create_fine_tuning_job,
        model,
        training_file,
        hyperparameters,
        suffix,
        validation_file,
        integrations,
        seed,
        custom_llm_provider,
        extra_headers,
        extra_body,
        **kwargs,
    )
    # Run inside a copied context so contextvars survive the executor hop.
    bound_call = partial(contextvars.copy_context().run, sync_call)
    init_response = await asyncio.get_event_loop().run_in_executor(None, bound_call)

    # The sync entrypoint hands back a coroutine when the async flag is set.
    if asyncio.iscoroutine(init_response):
        return await init_response
    return init_response  # type: ignore
|
||||
|
||||
|
||||
def _build_fine_tuning_job_data(
    model, training_file, hyperparameters, suffix, validation_file, integrations, seed
):
    """Assemble the typed OpenAI-spec payload for a fine-tuning job request."""
    job_fields = {
        "model": model,
        "training_file": training_file,
        "hyperparameters": hyperparameters,
        "suffix": suffix,
        "validation_file": validation_file,
        "integrations": integrations,
        "seed": seed,
    }
    return FineTuningJobCreate(**job_fields)
|
||||
|
||||
|
||||
def _resolve_fine_tuning_timeout(
    timeout: Any,
    custom_llm_provider: str,
) -> Union[float, httpx.Timeout]:
    """
    Normalise a raw timeout value for fine-tuning calls.

    Falsy values fall back to the 600s (10 minute) default. An
    httpx.Timeout object is kept only for providers that support it;
    otherwise its read timeout is flattened to a float.
    """
    resolved = timeout if timeout else 600.0
    if not isinstance(resolved, httpx.Timeout):
        return float(resolved)
    if supports_httpx_timeout(custom_llm_provider):
        return resolved
    # Provider can't consume an httpx.Timeout — fall back to read seconds.
    return float(resolved.read or 600)
|
||||
|
||||
|
||||
@client
def create_fine_tuning_job(
    model: str,
    training_file: str,
    hyperparameters: Optional[dict] = {},
    suffix: Optional[str] = None,
    validation_file: Optional[str] = None,
    integrations: Optional[List[str]] = None,
    seed: Optional[int] = None,
    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
    extra_headers: Optional[Dict[str, str]] = None,
    extra_body: Optional[Dict[str, str]] = None,
    **kwargs,
) -> Union[LiteLLMFineTuningJob, Coroutine[Any, Any, LiteLLMFineTuningJob]]:
    """
    Creates a fine-tuning job which begins the process of creating a new model from a given dataset.

    Response includes details of the enqueued job including job status and the name of the fine-tuned models once complete

    Args:
        model: Base model to fine-tune.
        training_file: File id of the uploaded training data.
        hyperparameters: Hyperparameters dict; for Azure, Azure-only keys
            (e.g. "prompt_loss_weight") are popped out and sent via extra_body.
        suffix: Optional suffix for the fine-tuned model name.
        validation_file: Optional file id of validation data.
        integrations: Optional list of integrations to enable.
        seed: Optional seed for reproducibility.
        custom_llm_provider: "openai", "azure", or "vertex_ai".
        extra_headers: Not referenced in this function's body.
        extra_body: Extra request-body fields; NOTE the Azure branch below
            rebinds this name from optional_params, shadowing the parameter.
        **kwargs: Provider credentials/overrides (api_key, api_base, ...) plus
            the internal "acreate_fine_tuning_job" flag set by the async wrapper.

    Returns:
        A LiteLLMFineTuningJob, or a coroutine resolving to one when invoked
        through acreate_fine_tuning_job.

    Raises:
        litellm.exceptions.BadRequestError: for unsupported providers.
    """
    try:
        # Pop the async flag injected by acreate_fine_tuning_job so it does
        # not leak into GenericLiteLLMParams(**kwargs).
        _is_async = kwargs.pop("acreate_fine_tuning_job", False) is True
        optional_params = GenericLiteLLMParams(**kwargs)

        # handle hyperparameters
        hyperparameters = hyperparameters or {}  # original hyperparameters

        # For Azure, extract Azure-specific hyperparameters before creating OpenAI-spec hyperparameters
        # (pop() mutates the caller-visible hyperparameters dict — intentional,
        # so the OpenAI-spec Hyperparameters below only sees standard keys).
        azure_specific_hyperparams = {}
        if custom_llm_provider == "azure":
            azure_hyperparameter_keys = ["prompt_loss_weight"]
            for key in azure_hyperparameter_keys:
                if key in hyperparameters:
                    azure_specific_hyperparams[key] = hyperparameters.pop(key)

        _oai_hyperparameters: Hyperparameters = Hyperparameters(
            **hyperparameters
        )  # Typed Hyperparameters for OpenAI Spec
        timeout = _resolve_fine_tuning_timeout(
            optional_params.timeout or kwargs.get("request_timeout", 600),
            custom_llm_provider,
        )

        # OpenAI
        if custom_llm_provider == "openai":
            # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
            api_base = (
                optional_params.api_base
                or litellm.api_base
                or os.getenv("OPENAI_BASE_URL")
                or os.getenv("OPENAI_API_BASE")
                or "https://api.openai.com/v1"
            )
            organization = (
                optional_params.organization
                or litellm.organization
                or os.getenv("OPENAI_ORGANIZATION", None)
                or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
            )
            # set API KEY
            api_key = (
                optional_params.api_key
                or litellm.api_key  # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
                or litellm.openai_key
                or os.getenv("OPENAI_API_KEY")
            )

            # Serialize the typed payload; exclude_none keeps the wire
            # request free of unset optional fields.
            create_fine_tuning_job_data_dict = _build_fine_tuning_job_data(
                model,
                training_file,
                _oai_hyperparameters,
                suffix,
                validation_file,
                integrations,
                seed,
            ).model_dump(exclude_none=True)

            response = openai_fine_tuning_apis_instance.create_fine_tuning_job(
                api_base=api_base,
                api_key=api_key,
                api_version=optional_params.api_version,
                organization=organization,
                create_fine_tuning_job_data=create_fine_tuning_job_data_dict,
                timeout=timeout,
                max_retries=optional_params.max_retries,
                _is_async=_is_async,
                client=kwargs.get(
                    "client", None
                ),  # note, when we add this to `GenericLiteLLMParams` it impacts a lot of other tests + linting
            )
        # Azure OpenAI
        elif custom_llm_provider == "azure":
            api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")  # type: ignore

            api_version = (
                optional_params.api_version
                or litellm.api_version
                or get_secret_str("AZURE_API_VERSION")
            )  # type: ignore

            api_key = (
                optional_params.api_key
                or litellm.api_key
                or litellm.azure_key
                or get_secret_str("AZURE_OPENAI_API_KEY")
                or get_secret_str("AZURE_API_KEY")
            )  # type: ignore

            # NOTE(review): this rebinding shadows the `extra_body` parameter —
            # the caller-supplied argument is discarded here. Confirm intended.
            extra_body = optional_params.get("extra_body", {})
            if extra_body is not None:
                extra_body.pop("azure_ad_token", None)
            else:
                # NOTE(review): return value discarded — presumably a
                # side-effecting cache warm-up; verify against get_secret_str.
                get_secret_str("AZURE_AD_TOKEN")  # type: ignore

            # Prepare Azure-specific parameters for extra_body
            extra_body = _prepare_azure_extra_body(
                extra_body, kwargs, azure_specific_hyperparams
            )

            create_fine_tuning_job_data_dict = _build_fine_tuning_job_data(
                model,
                training_file,
                _oai_hyperparameters,
                suffix,
                validation_file,
                integrations,
                seed,
            ).model_dump(exclude_none=True)

            # Add extra_body if it has Azure-specific parameters
            if extra_body:
                create_fine_tuning_job_data_dict["extra_body"] = extra_body

            response = azure_fine_tuning_apis_instance.create_fine_tuning_job(
                api_base=api_base,
                api_key=api_key,
                api_version=api_version,
                create_fine_tuning_job_data=create_fine_tuning_job_data_dict,
                timeout=timeout,
                max_retries=optional_params.max_retries,
                _is_async=_is_async,
                organization=optional_params.organization,
            )
        elif custom_llm_provider == "vertex_ai":
            api_base = optional_params.api_base or ""
            vertex_ai_project = (
                optional_params.vertex_project
                or litellm.vertex_project
                or get_secret_str("VERTEXAI_PROJECT")
            )
            vertex_ai_location = (
                optional_params.vertex_location
                or litellm.vertex_location
                or get_secret_str("VERTEXAI_LOCATION")
            )
            vertex_credentials = optional_params.vertex_credentials or get_secret_str(
                "VERTEXAI_CREDENTIALS"
            )
            # Vertex receives the typed object (not a dict) plus the original
            # (pre-Hyperparameters) dict for provider-specific handling.
            response = vertex_fine_tuning_apis_instance.create_fine_tuning_job(
                _is_async=_is_async,
                create_fine_tuning_job_data=_build_fine_tuning_job_data(
                    model,
                    training_file,
                    _oai_hyperparameters,
                    suffix,
                    validation_file,
                    integrations,
                    seed,
                ),
                vertex_credentials=vertex_credentials,
                vertex_project=vertex_ai_project,
                vertex_location=vertex_ai_location,
                timeout=timeout,
                api_base=api_base,
                kwargs=kwargs,
                original_hyperparameters=hyperparameters,
            )
        else:
            raise litellm.exceptions.BadRequestError(
                message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
                    custom_llm_provider
                ),
                model="n/a",
                llm_provider=custom_llm_provider,
                response=httpx.Response(
                    status_code=400,
                    content="Unsupported provider",
                    request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"),  # type: ignore
                ),
            )
        return response
    except Exception as e:
        verbose_logger.error("got exception in create_fine_tuning_job=%s", str(e))
        raise e
|
||||
|
||||
|
||||
@client
async def acancel_fine_tuning_job(
    fine_tuning_job_id: str,
    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
    extra_headers: Optional[Dict[str, str]] = None,
    extra_body: Optional[Dict[str, str]] = None,
    **kwargs,
) -> LiteLLMFineTuningJob:
    """
    Async: Immediately cancel a fine-tune job.

    Delegates to the synchronous ``cancel_fine_tuning_job`` in a thread
    executor and awaits the coroutine it returns in async mode.
    """
    # Signal the sync implementation that this call is async.
    kwargs["acancel_fine_tuning_job"] = True

    sync_call = partial(
        cancel_fine_tuning_job,
        fine_tuning_job_id,
        custom_llm_provider,
        extra_headers,
        extra_body,
        **kwargs,
    )
    # Run inside a copied context so contextvars survive the executor hop.
    bound_call = partial(contextvars.copy_context().run, sync_call)
    init_response = await asyncio.get_event_loop().run_in_executor(None, bound_call)

    # The sync entrypoint hands back a coroutine when the async flag is set.
    if asyncio.iscoroutine(init_response):
        return await init_response
    return init_response  # type: ignore
|
||||
|
||||
|
||||
@client
def cancel_fine_tuning_job(
    fine_tuning_job_id: str,
    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
    extra_headers: Optional[Dict[str, str]] = None,
    extra_body: Optional[Dict[str, str]] = None,
    **kwargs,
) -> Union[LiteLLMFineTuningJob, Coroutine[Any, Any, LiteLLMFineTuningJob]]:
    """
    Immediately cancel a fine-tune job.

    Response includes details of the enqueued job including job status and the name of the fine-tuned models once complete

    Args:
        fine_tuning_job_id: Id of the fine-tuning job to cancel.
        custom_llm_provider: "openai" or "azure" (other values raise).
        extra_headers: Not referenced in this function's body.
        extra_body: Not referenced in this function's body.
        **kwargs: Provider credentials/overrides plus the internal
            "acancel_fine_tuning_job" flag set by the async wrapper.

    Returns:
        A LiteLLMFineTuningJob, or a coroutine resolving to one in async mode.

    Raises:
        litellm.exceptions.BadRequestError: for unsupported providers.
    """
    try:
        optional_params = GenericLiteLLMParams(**kwargs)
        ### TIMEOUT LOGIC ###
        timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
        # set timeout for 10 minutes by default

        # Flatten httpx.Timeout to float seconds for providers that can't
        # consume an httpx.Timeout object.
        if (
            timeout is not None
            and isinstance(timeout, httpx.Timeout)
            and supports_httpx_timeout(custom_llm_provider) is False
        ):
            read_timeout = timeout.read or 600
            timeout = read_timeout  # default 10 min timeout
        elif timeout is not None and not isinstance(timeout, httpx.Timeout):
            timeout = float(timeout)  # type: ignore
        elif timeout is None:
            timeout = 600.0

        # Pop the async flag so it doesn't leak to downstream calls.
        _is_async = kwargs.pop("acancel_fine_tuning_job", False) is True

        # OpenAI
        if custom_llm_provider == "openai":
            # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
            api_base = (
                optional_params.api_base
                or litellm.api_base
                or os.getenv("OPENAI_BASE_URL")
                or os.getenv("OPENAI_API_BASE")
                or "https://api.openai.com/v1"
            )
            organization = (
                optional_params.organization
                or litellm.organization
                or os.getenv("OPENAI_ORGANIZATION", None)
                or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
            )
            # set API KEY
            api_key = (
                optional_params.api_key
                or litellm.api_key  # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
                or litellm.openai_key
                or os.getenv("OPENAI_API_KEY")
            )

            response = openai_fine_tuning_apis_instance.cancel_fine_tuning_job(
                api_base=api_base,
                api_key=api_key,
                api_version=optional_params.api_version,
                organization=organization,
                fine_tuning_job_id=fine_tuning_job_id,
                timeout=timeout,
                max_retries=optional_params.max_retries,
                _is_async=_is_async,
                client=kwargs.get("client", None),
            )
        # Azure OpenAI
        elif custom_llm_provider == "azure":
            # BUGFIX: previously called get_secret("AZURE_API_BASE"), but this
            # module only imports get_secret_str — that raised NameError
            # whenever no api_base was configured elsewhere.
            api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")  # type: ignore

            api_version = (
                optional_params.api_version
                or litellm.api_version
                or get_secret_str("AZURE_API_VERSION")
            )  # type: ignore

            api_key = (
                optional_params.api_key
                or litellm.api_key
                or litellm.azure_key
                or get_secret_str("AZURE_OPENAI_API_KEY")
                or get_secret_str("AZURE_API_KEY")
            )  # type: ignore

            extra_body = optional_params.get("extra_body", {})
            if extra_body is not None:
                # azure_ad_token must not be forwarded in the request body.
                extra_body.pop("azure_ad_token", None)
            else:
                get_secret_str("AZURE_AD_TOKEN")  # type: ignore

            response = azure_fine_tuning_apis_instance.cancel_fine_tuning_job(
                api_base=api_base,
                api_key=api_key,
                api_version=api_version,
                fine_tuning_job_id=fine_tuning_job_id,
                timeout=timeout,
                max_retries=optional_params.max_retries,
                _is_async=_is_async,
                organization=optional_params.organization,
            )
        else:
            raise litellm.exceptions.BadRequestError(
                # BUGFIX: message previously said 'create_batch' (copy-paste
                # from another endpoint) — corrected to name this operation.
                message="LiteLLM doesn't support {} for 'cancel_fine_tuning_job'. Only 'openai' and 'azure' are supported.".format(
                    custom_llm_provider
                ),
                model="n/a",
                llm_provider=custom_llm_provider,
                response=httpx.Response(
                    status_code=400,
                    content="Unsupported provider",
                    request=httpx.Request(method="cancel_fine_tuning_job", url="https://github.com/BerriAI/litellm"),  # type: ignore
                ),
            )
        return response
    except Exception as e:
        raise e
|
||||
|
||||
|
||||
async def alist_fine_tuning_jobs(
    after: Optional[str] = None,
    limit: Optional[int] = None,
    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
    extra_headers: Optional[Dict[str, str]] = None,
    extra_body: Optional[Dict[str, str]] = None,
    **kwargs,
):
    """
    Async: List your organization's fine-tuning jobs.

    Delegates to the synchronous ``list_fine_tuning_jobs`` in a thread
    executor and awaits the coroutine it returns in async mode.
    """
    # Signal the sync implementation that this call is async.
    kwargs["alist_fine_tuning_jobs"] = True

    sync_call = partial(
        list_fine_tuning_jobs,
        after,
        limit,
        custom_llm_provider,
        extra_headers,
        extra_body,
        **kwargs,
    )
    # Run inside a copied context so contextvars survive the executor hop.
    bound_call = partial(contextvars.copy_context().run, sync_call)
    init_response = await asyncio.get_event_loop().run_in_executor(None, bound_call)

    # The sync entrypoint hands back a coroutine when the async flag is set.
    if asyncio.iscoroutine(init_response):
        return await init_response
    return init_response  # type: ignore
|
||||
|
||||
|
||||
def list_fine_tuning_jobs(
    after: Optional[str] = None,
    limit: Optional[int] = None,
    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
    extra_headers: Optional[Dict[str, str]] = None,
    extra_body: Optional[Dict[str, str]] = None,
    **kwargs,
):
    """
    List your organization's fine-tuning jobs

    Params:

    - after: Optional[str] = None, Identifier for the last job from the previous pagination request.
    - limit: Optional[int] = None, Number of fine-tuning jobs to retrieve. Defaults to 20
    - custom_llm_provider: "openai" or "azure" (other values raise).
    - extra_headers / extra_body: not referenced in this function's body.
    - **kwargs: provider credentials/overrides plus the internal
      "alist_fine_tuning_jobs" flag set by the async wrapper.

    Raises:
        litellm.exceptions.BadRequestError: for unsupported providers.
    """
    try:
        optional_params = GenericLiteLLMParams(**kwargs)
        ### TIMEOUT LOGIC ###
        timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
        # set timeout for 10 minutes by default

        # Flatten httpx.Timeout to float seconds for providers that can't
        # consume an httpx.Timeout object.
        if (
            timeout is not None
            and isinstance(timeout, httpx.Timeout)
            and supports_httpx_timeout(custom_llm_provider) is False
        ):
            read_timeout = timeout.read or 600
            timeout = read_timeout  # default 10 min timeout
        elif timeout is not None and not isinstance(timeout, httpx.Timeout):
            timeout = float(timeout)  # type: ignore
        elif timeout is None:
            timeout = 600.0

        # Pop the async flag so it doesn't leak to downstream calls.
        _is_async = kwargs.pop("alist_fine_tuning_jobs", False) is True

        # OpenAI
        if custom_llm_provider == "openai":
            # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
            api_base = (
                optional_params.api_base
                or litellm.api_base
                or os.getenv("OPENAI_BASE_URL")
                or os.getenv("OPENAI_API_BASE")
                or "https://api.openai.com/v1"
            )
            organization = (
                optional_params.organization
                or litellm.organization
                or os.getenv("OPENAI_ORGANIZATION", None)
                or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
            )
            # set API KEY
            api_key = (
                optional_params.api_key
                or litellm.api_key  # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
                or litellm.openai_key
                or os.getenv("OPENAI_API_KEY")
            )

            response = openai_fine_tuning_apis_instance.list_fine_tuning_jobs(
                api_base=api_base,
                api_key=api_key,
                api_version=optional_params.api_version,
                organization=organization,
                after=after,
                limit=limit,
                timeout=timeout,
                max_retries=optional_params.max_retries,
                _is_async=_is_async,
                client=kwargs.get("client", None),
            )
        # Azure OpenAI
        elif custom_llm_provider == "azure":
            api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")  # type: ignore

            api_version = (
                optional_params.api_version
                or litellm.api_version
                or get_secret_str("AZURE_API_VERSION")
            )  # type: ignore

            api_key = (
                optional_params.api_key
                or litellm.api_key
                or litellm.azure_key
                or get_secret_str("AZURE_OPENAI_API_KEY")
                or get_secret_str("AZURE_API_KEY")
            )  # type: ignore

            extra_body = optional_params.get("extra_body", {})
            if extra_body is not None:
                # azure_ad_token must not be forwarded in the request body.
                extra_body.pop("azure_ad_token", None)
            else:
                # BUGFIX: previously called get_secret("AZURE_AD_TOKEN"), but
                # this module only imports get_secret_str — that raised
                # NameError whenever extra_body resolved to None.
                get_secret_str("AZURE_AD_TOKEN")  # type: ignore

            response = azure_fine_tuning_apis_instance.list_fine_tuning_jobs(
                api_base=api_base,
                api_key=api_key,
                api_version=api_version,
                after=after,
                limit=limit,
                timeout=timeout,
                max_retries=optional_params.max_retries,
                _is_async=_is_async,
                organization=optional_params.organization,
            )
        else:
            raise litellm.exceptions.BadRequestError(
                # BUGFIX: message previously said 'create_batch' (copy-paste
                # from another endpoint) — corrected to name this operation.
                message="LiteLLM doesn't support {} for 'list_fine_tuning_jobs'. Only 'openai' and 'azure' are supported.".format(
                    custom_llm_provider
                ),
                model="n/a",
                llm_provider=custom_llm_provider,
                response=httpx.Response(
                    status_code=400,
                    content="Unsupported provider",
                    request=httpx.Request(method="list_fine_tuning_jobs", url="https://github.com/BerriAI/litellm"),  # type: ignore
                ),
            )
        return response
    except Exception as e:
        raise e
|
||||
|
||||
|
||||
@client
async def aretrieve_fine_tuning_job(
    fine_tuning_job_id: str,
    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
    extra_headers: Optional[Dict[str, str]] = None,
    extra_body: Optional[Dict[str, str]] = None,
    **kwargs,
) -> LiteLLMFineTuningJob:
    """
    Async: Get info about a fine-tuning job.

    Delegates to the synchronous ``retrieve_fine_tuning_job`` in a thread
    executor and awaits the coroutine it returns in async mode.
    """
    # Signal the sync implementation that this call is async.
    kwargs["aretrieve_fine_tuning_job"] = True

    sync_call = partial(
        retrieve_fine_tuning_job,
        fine_tuning_job_id,
        custom_llm_provider,
        extra_headers,
        extra_body,
        **kwargs,
    )
    # Run inside a copied context so contextvars survive the executor hop.
    bound_call = partial(contextvars.copy_context().run, sync_call)
    init_response = await asyncio.get_event_loop().run_in_executor(None, bound_call)

    # The sync entrypoint hands back a coroutine when the async flag is set.
    if asyncio.iscoroutine(init_response):
        return await init_response
    return init_response  # type: ignore
|
||||
|
||||
|
||||
@client
def retrieve_fine_tuning_job(
    fine_tuning_job_id: str,
    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
    extra_headers: Optional[Dict[str, str]] = None,
    extra_body: Optional[Dict[str, str]] = None,
    **kwargs,
) -> Union[LiteLLMFineTuningJob, Coroutine[Any, Any, LiteLLMFineTuningJob]]:
    """
    Get info about a fine-tuning job.

    Args:
        fine_tuning_job_id: Id of the fine-tuning job to fetch.
        custom_llm_provider: "openai" or "azure" (other values raise).
        extra_headers: Not referenced in this function's body.
        extra_body: Not referenced in this function's body.
        **kwargs: Provider credentials/overrides plus the internal
            "aretrieve_fine_tuning_job" flag set by the async wrapper.

    Returns:
        A LiteLLMFineTuningJob, or a coroutine resolving to one in async mode.

    Raises:
        litellm.exceptions.BadRequestError: for unsupported providers.
    """
    try:
        optional_params = GenericLiteLLMParams(**kwargs)
        ### TIMEOUT LOGIC ###
        timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
        # set timeout for 10 minutes by default

        # Flatten httpx.Timeout to float seconds for providers that can't
        # consume an httpx.Timeout object.
        if (
            timeout is not None
            and isinstance(timeout, httpx.Timeout)
            and supports_httpx_timeout(custom_llm_provider) is False
        ):
            read_timeout = timeout.read or 600
            timeout = read_timeout  # default 10 min timeout
        elif timeout is not None and not isinstance(timeout, httpx.Timeout):
            timeout = float(timeout)  # type: ignore
        elif timeout is None:
            timeout = 600.0

        # Pop the async flag so it doesn't leak to downstream calls.
        _is_async = kwargs.pop("aretrieve_fine_tuning_job", False) is True

        # OpenAI
        if custom_llm_provider == "openai":
            api_base = (
                optional_params.api_base
                or litellm.api_base
                or os.getenv("OPENAI_BASE_URL")
                or os.getenv("OPENAI_API_BASE")
                or "https://api.openai.com/v1"
            )
            organization = (
                optional_params.organization
                or litellm.organization
                or os.getenv("OPENAI_ORGANIZATION", None)
                or None
            )
            api_key = (
                optional_params.api_key
                or litellm.api_key
                or litellm.openai_key
                or os.getenv("OPENAI_API_KEY")
            )

            response = openai_fine_tuning_apis_instance.retrieve_fine_tuning_job(
                api_base=api_base,
                api_key=api_key,
                api_version=optional_params.api_version,
                organization=organization,
                fine_tuning_job_id=fine_tuning_job_id,
                timeout=timeout,
                max_retries=optional_params.max_retries,
                _is_async=_is_async,
                client=kwargs.get("client", None),
            )
        # Azure OpenAI
        elif custom_llm_provider == "azure":
            api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")  # type: ignore

            api_version = (
                optional_params.api_version
                or litellm.api_version
                or get_secret_str("AZURE_API_VERSION")
            )  # type: ignore

            api_key = (
                optional_params.api_key
                or litellm.api_key
                or litellm.azure_key
                or get_secret_str("AZURE_OPENAI_API_KEY")
                or get_secret_str("AZURE_API_KEY")
            )  # type: ignore

            extra_body = optional_params.get("extra_body", {})
            if extra_body is not None:
                # azure_ad_token must not be forwarded in the request body.
                extra_body.pop("azure_ad_token", None)
            else:
                # NOTE(review): return value discarded — presumably a
                # side-effecting lookup; verify against get_secret_str.
                get_secret_str("AZURE_AD_TOKEN")  # type: ignore

            response = azure_fine_tuning_apis_instance.retrieve_fine_tuning_job(
                api_base=api_base,
                api_key=api_key,
                api_version=api_version,
                fine_tuning_job_id=fine_tuning_job_id,
                timeout=timeout,
                max_retries=optional_params.max_retries,
                _is_async=_is_async,
                organization=optional_params.organization,
            )
        else:
            raise litellm.exceptions.BadRequestError(
                message="LiteLLM doesn't support {} for 'retrieve_fine_tuning_job'. Only 'openai' and 'azure' are supported.".format(
                    custom_llm_provider
                ),
                model="n/a",
                llm_provider=custom_llm_provider,
                response=httpx.Response(
                    status_code=400,
                    content="Unsupported provider",
                    request=httpx.Request(method="retrieve_fine_tuning_job", url="https://github.com/BerriAI/litellm"),  # type: ignore
                ),
            )
        return response
    except Exception as e:
        raise e
|
||||
@@ -0,0 +1,123 @@
|
||||
# LiteLLM Google GenAI Interface
|
||||
|
||||
Interface to interact with Google GenAI Functions in the native Google interface format.
|
||||
|
||||
## Overview
|
||||
|
||||
This module provides a native interface to Google's Generative AI API, allowing you to use Google's content generation capabilities with both streaming and non-streaming modes, in both synchronous and asynchronous contexts.
|
||||
|
||||
## Available Functions
|
||||
|
||||
### Non-Streaming Functions
|
||||
|
||||
- `generate_content()` - Synchronous content generation
|
||||
- `agenerate_content()` - Asynchronous content generation
|
||||
|
||||
### Streaming Functions
|
||||
|
||||
- `generate_content_stream()` - Synchronous streaming content generation
|
||||
- `agenerate_content_stream()` - Asynchronous streaming content generation
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Basic Non-Streaming Usage
|
||||
|
||||
```python
|
||||
from litellm.google_genai import generate_content, agenerate_content
|
||||
from google.genai.types import ContentDict, PartDict
|
||||
|
||||
# Synchronous usage
|
||||
contents = ContentDict(
|
||||
parts=[
|
||||
PartDict(text="Hello, can you tell me a short joke?")
|
||||
],
|
||||
)
|
||||
|
||||
response = generate_content(
|
||||
contents=contents,
|
||||
model="gemini-pro", # or your preferred model
|
||||
# Add other model-specific parameters as needed
|
||||
)
|
||||
|
||||
print(response)
|
||||
```
|
||||
|
||||
### Async Non-Streaming Usage
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from litellm.google_genai import agenerate_content
|
||||
from google.genai.types import ContentDict, PartDict
|
||||
|
||||
async def main():
|
||||
contents = ContentDict(
|
||||
parts=[
|
||||
PartDict(text="Hello, can you tell me a short joke?")
|
||||
],
|
||||
)
|
||||
|
||||
response = await agenerate_content(
|
||||
contents=contents,
|
||||
model="gemini-pro",
|
||||
# Add other model-specific parameters as needed
|
||||
)
|
||||
|
||||
print(response)
|
||||
|
||||
# Run the async function
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
### Streaming Usage
|
||||
|
||||
```python
|
||||
from litellm.google_genai import generate_content_stream
|
||||
from google.genai.types import ContentDict, PartDict
|
||||
|
||||
# Synchronous streaming
|
||||
contents = ContentDict(
|
||||
parts=[
|
||||
PartDict(text="Tell me a story about space exploration")
|
||||
],
|
||||
)
|
||||
|
||||
for chunk in generate_content_stream(
|
||||
contents=contents,
|
||||
model="gemini-pro",
|
||||
):
|
||||
print(f"Chunk: {chunk}")
|
||||
```
|
||||
|
||||
### Async Streaming Usage
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from litellm.google_genai import agenerate_content_stream
|
||||
from google.genai.types import ContentDict, PartDict
|
||||
|
||||
async def main():
|
||||
contents = ContentDict(
|
||||
parts=[
|
||||
PartDict(text="Tell me a story about space exploration")
|
||||
],
|
||||
)
|
||||
|
||||
async for chunk in agenerate_content_stream(
|
||||
contents=contents,
|
||||
model="gemini-pro",
|
||||
):
|
||||
print(f"Async chunk: {chunk}")
|
||||
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
|
||||
## Testing
|
||||
|
||||
This module includes comprehensive tests covering:
|
||||
- Sync and async non-streaming requests
|
||||
- Sync and async streaming requests
|
||||
- Response validation
|
||||
- Error handling scenarios
|
||||
|
||||
See `tests/unified_google_tests/base_google_test.py` for test implementation examples.
|
||||
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
This allows using Google GenAI model in their native interface.
|
||||
|
||||
This module provides generate_content functionality for Google GenAI models.
|
||||
"""
|
||||
|
||||
from .main import (
|
||||
agenerate_content,
|
||||
agenerate_content_stream,
|
||||
generate_content,
|
||||
generate_content_stream,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"generate_content",
|
||||
"agenerate_content",
|
||||
"generate_content_stream",
|
||||
"agenerate_content_stream",
|
||||
]
|
||||
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
Google GenAI Adapters for LiteLLM
|
||||
|
||||
This module provides adapters for transforming Google GenAI generate_content requests
|
||||
to/from LiteLLM completion format with full support for:
|
||||
- Text content transformation
|
||||
- Tool calling (function declarations, function calls, function responses)
|
||||
- Streaming (both regular and tool calling)
|
||||
- Mixed content (text + tool calls)
|
||||
"""
|
||||
|
||||
from .handler import GenerateContentToCompletionHandler
|
||||
from .transformation import GoogleGenAIAdapter, GoogleGenAIStreamWrapper
|
||||
|
||||
__all__ = [
|
||||
"GoogleGenAIAdapter",
|
||||
"GoogleGenAIStreamWrapper",
|
||||
"GenerateContentToCompletionHandler",
|
||||
]
|
||||
@@ -0,0 +1,183 @@
|
||||
from typing import Any, AsyncIterator, Coroutine, Dict, List, Optional, Union, cast
|
||||
|
||||
import litellm
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
from .transformation import GoogleGenAIAdapter
|
||||
|
||||
# Initialize adapter
|
||||
GOOGLE_GENAI_ADAPTER = GoogleGenAIAdapter()
|
||||
|
||||
|
||||
class GenerateContentToCompletionHandler:
    """Handler for transforming generate_content calls to completion format when provider config is None.

    Bridges Google GenAI's ``generate_content`` API onto ``litellm.completion`` /
    ``litellm.acompletion``: the request is translated to OpenAI chat-completion
    kwargs, executed, and the result translated back to generate_content format.
    """

    @staticmethod
    def _prepare_completion_kwargs(
        model: str,
        contents: Union[List[Dict[str, Any]], Dict[str, Any]],
        config: Optional[Dict[str, Any]] = None,
        stream: bool = False,
        litellm_params: Optional[GenericLiteLLMParams] = None,
        extra_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Prepare kwargs for litellm.completion/acompletion.

        Args:
            model: Target model name.
            contents: generate_content contents (single dict or list of dicts).
            config: Optional Google GenAI generation config.
            stream: Whether the caller requested streaming.
            litellm_params: Generic litellm params (api_base, api_key, ...).
            extra_kwargs: Remaining caller kwargs; forwarded into the request
                translation (e.g. tools / systemInstruction).

        Returns:
            Keyword arguments ready for ``litellm.completion``/``acompletion``.
        """
        # Transform generate_content request to completion format.
        completion_request = (
            GOOGLE_GENAI_ADAPTER.translate_generate_content_to_completion(
                model=model,
                contents=contents,
                config=config,
                litellm_params=litellm_params,
                **(extra_kwargs or {}),
            )
        )

        completion_kwargs: Dict[str, Any] = dict(completion_request)

        # Forward extra_kwargs that should be passed to the completion call.
        if extra_kwargs is not None:
            # Forward metadata so custom callbacks still see it.
            if "metadata" in extra_kwargs:
                completion_kwargs["metadata"] = extra_kwargs["metadata"]
            # Forward extra_headers for providers that require custom headers
            # (e.g., github_copilot).
            if "extra_headers" in extra_kwargs:
                completion_kwargs["extra_headers"] = extra_kwargs["extra_headers"]

        if stream:
            completion_kwargs["stream"] = stream

        return completion_kwargs

    @staticmethod
    def _transform_completion_response(
        completion_response: Any,
        stream: bool,
        iteration_attr: str,
    ) -> Union[Dict[str, Any], AsyncIterator[bytes]]:
        """Translate a completion result back to generate_content format.

        Shared by the sync and async handlers (this logic was previously
        duplicated verbatim in both).

        Args:
            completion_response: Return value of litellm.completion/acompletion.
            stream: Whether the caller requested streaming.
            iteration_attr: "__iter__" (sync) or "__aiter__" (async); used to
                detect whether the provider actually returned a stream.

        Raises:
            ValueError: If the streaming response cannot be transformed.
        """
        if stream and hasattr(completion_response, iteration_attr):
            # Transform streaming completion response to generate_content format.
            transformed_stream = (
                GOOGLE_GENAI_ADAPTER.translate_completion_output_params_streaming(
                    completion_response
                )
            )
            if transformed_stream is not None:
                return transformed_stream
            raise ValueError("Failed to transform streaming response")

        # Non-streaming response — or stream=True but the provider returned a
        # plain ModelResponse (can happen in error cases or when streaming is
        # not properly supported).
        return GOOGLE_GENAI_ADAPTER.translate_completion_to_generate_content(
            cast(ModelResponse, completion_response)
        )

    @staticmethod
    async def async_generate_content_handler(
        model: str,
        contents: Union[List[Dict[str, Any]], Dict[str, Any]],
        litellm_params: GenericLiteLLMParams,
        config: Optional[Dict[str, Any]] = None,
        stream: bool = False,
        **kwargs,
    ) -> Union[Dict[str, Any], AsyncIterator[bytes]]:
        """Handle generate_content call asynchronously using completion adapter.

        Raises:
            ValueError: If the underlying acompletion call (or the response
                transformation) fails; the original exception is chained as
                ``__cause__``.
        """
        completion_kwargs = (
            GenerateContentToCompletionHandler._prepare_completion_kwargs(
                model=model,
                contents=contents,
                config=config,
                stream=stream,
                litellm_params=litellm_params,
                extra_kwargs=kwargs,
            )
        )

        try:
            completion_response = await litellm.acompletion(**completion_kwargs)
            return GenerateContentToCompletionHandler._transform_completion_response(
                completion_response=completion_response,
                stream=stream,
                iteration_attr="__aiter__",
            )
        except Exception as e:
            # Chain the original exception so callers can inspect the root cause.
            raise ValueError(
                f"Error calling litellm.acompletion for generate_content: {str(e)}"
            ) from e

    @staticmethod
    def generate_content_handler(
        model: str,
        contents: Union[List[Dict[str, Any]], Dict[str, Any]],
        litellm_params: GenericLiteLLMParams,
        config: Optional[Dict[str, Any]] = None,
        stream: bool = False,
        _is_async: bool = False,
        **kwargs,
    ) -> Union[
        Dict[str, Any],
        AsyncIterator[bytes],
        Coroutine[Any, Any, Union[Dict[str, Any], AsyncIterator[bytes]]],
    ]:
        """Handle generate_content call using completion adapter.

        When ``_is_async`` is True, returns the coroutine produced by the async
        handler instead of executing synchronously.

        Raises:
            ValueError: If the underlying completion call (or the response
                transformation) fails; the original exception is chained.
        """
        if _is_async:
            return GenerateContentToCompletionHandler.async_generate_content_handler(
                model=model,
                contents=contents,
                config=config,
                stream=stream,
                litellm_params=litellm_params,
                **kwargs,
            )

        completion_kwargs = (
            GenerateContentToCompletionHandler._prepare_completion_kwargs(
                model=model,
                contents=contents,
                config=config,
                stream=stream,
                litellm_params=litellm_params,
                extra_kwargs=kwargs,
            )
        )

        try:
            completion_response = litellm.completion(**completion_kwargs)
            return GenerateContentToCompletionHandler._transform_completion_response(
                completion_response=completion_response,
                stream=stream,
                iteration_attr="__iter__",
            )
        except Exception as e:
            # Chain the original exception so callers can inspect the root cause.
            raise ValueError(
                f"Error calling litellm.completion for generate_content: {str(e)}"
            ) from e
|
||||
@@ -0,0 +1,783 @@
|
||||
import json
|
||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union, cast
|
||||
|
||||
from litellm import verbose_logger
|
||||
from litellm.litellm_core_utils.json_validation_rule import normalize_tool_schema
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
ChatCompletionAssistantMessage,
|
||||
ChatCompletionAssistantToolCall,
|
||||
ChatCompletionImageObject,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionSystemMessage,
|
||||
ChatCompletionTextObject,
|
||||
ChatCompletionToolCallFunctionChunk,
|
||||
ChatCompletionToolChoiceValues,
|
||||
ChatCompletionToolMessage,
|
||||
ChatCompletionToolParam,
|
||||
ChatCompletionUserMessage,
|
||||
)
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.utils import (
|
||||
AdapterCompletionStreamWrapper,
|
||||
Choices,
|
||||
ModelResponse,
|
||||
ModelResponseStream,
|
||||
StreamingChoices,
|
||||
)
|
||||
|
||||
|
||||
class GoogleGenAIStreamWrapper(AdapterCompletionStreamWrapper):
    """
    Wrapper for streaming Google GenAI generate_content responses.
    Transforms OpenAI streaming chunks to Google GenAI format.

    Holds per-stream state for reassembling tool calls whose JSON arguments
    arrive split across multiple streaming chunks.
    """

    sent_first_chunk: bool = False
    # State tracking for accumulating partial tool calls, keyed by the tool
    # call's streaming index. Each value holds {"name": str, "arguments": str}.
    accumulated_tool_calls: Dict[str, Dict[str, Any]]

    def __init__(self, completion_stream: Any):
        self.sent_first_chunk = False
        self.accumulated_tool_calls = {}
        # Guards the non-stream fallback path so a plain (non-iterable)
        # response is yielded exactly once.
        self._returned_response = False
        super().__init__(completion_stream)

    def __next__(self):
        """Return the next transformed chunk (sync iteration).

        NOTE(review): unlike __anext__, this sync path does NOT flush
        accumulated_tool_calls when the stream ends — confirm whether
        trailing tool calls can be dropped for sync streams.
        """
        try:
            # Fallback: a non-iterable (plain) response is translated once.
            if not hasattr(self.completion_stream, "__iter__"):
                if self._returned_response:
                    raise StopIteration
                self._returned_response = True
                return GoogleGenAIAdapter().translate_completion_to_generate_content(
                    self.completion_stream
                )

            for chunk in self.completion_stream:
                # Skip sentinel/None chunks emitted by some streams.
                if chunk == "None" or chunk is None:
                    continue

                transformed_chunk = GoogleGenAIAdapter().translate_streaming_completion_to_generate_content(
                    chunk, self
                )
                # Empty transformations (no parts, no finish reason) are skipped.
                if transformed_chunk:
                    return transformed_chunk

            raise StopIteration
        except StopIteration:
            raise
        except Exception:
            # NOTE(review): any error is collapsed into StopIteration, silently
            # ending the stream — confirm this is intentional.
            raise StopIteration

    async def __anext__(self):
        """Return the next transformed chunk (async iteration).

        At end of stream, flushes any tool calls still sitting in the
        accumulator (e.g. when the final argument chunk completed the JSON
        but the stream ended before a separate emit) as one final chunk.
        """
        try:
            # Fallback: a non-async-iterable (plain) response is translated once.
            if not hasattr(self.completion_stream, "__aiter__"):
                if self._returned_response:
                    raise StopAsyncIteration
                self._returned_response = True
                return GoogleGenAIAdapter().translate_completion_to_generate_content(
                    self.completion_stream
                )

            async for chunk in self.completion_stream:
                # Skip sentinel/None chunks emitted by some streams.
                if chunk == "None" or chunk is None:
                    continue

                transformed_chunk = GoogleGenAIAdapter().translate_streaming_completion_to_generate_content(
                    chunk, self
                )
                if transformed_chunk:
                    return transformed_chunk

            # After the stream is exhausted, check for any remaining accumulated tool calls
            if self.accumulated_tool_calls:
                try:
                    parts = []
                    for (
                        tool_call_index,
                        tool_call_data,
                    ) in self.accumulated_tool_calls.items():
                        try:
                            # For tool calls with no arguments, accumulated_args will be "", which is not valid JSON.
                            # We default to an empty JSON object in this case.
                            parsed_args = json.loads(
                                tool_call_data["arguments"] or "{}"
                            )
                            function_call_part = {
                                "functionCall": {
                                    "name": tool_call_data["name"]
                                    or "undefined_tool_name",
                                    "args": parsed_args,
                                }
                            }
                            parts.append(function_call_part)
                        except json.JSONDecodeError:
                            # This can happen if the stream is abruptly cut off mid-argument string.
                            verbose_logger.warning(
                                f"Could not parse tool call arguments at end of stream for index {tool_call_index}. "
                                f"Name: {tool_call_data['name']}. "
                                f"Partial args: {tool_call_data['arguments']}"
                            )
                            pass
                    if parts:
                        final_chunk = {
                            "candidates": [
                                {
                                    "content": {"parts": parts, "role": "model"},
                                    "finishReason": "STOP",
                                    "index": 0,
                                    "safetyRatings": [],
                                }
                            ]
                        }
                        return final_chunk
                finally:
                    # Ensure the accumulator is always cleared to prevent memory leaks
                    self.accumulated_tool_calls.clear()
            raise StopAsyncIteration
        except StopAsyncIteration:
            raise
        except Exception:
            # NOTE(review): any error is collapsed into StopAsyncIteration,
            # silently ending the stream — confirm this is intentional.
            raise StopAsyncIteration

    def google_genai_sse_wrapper(self) -> Iterator[bytes]:
        """
        Convert Google GenAI streaming chunks to Server-Sent Events format.

        Dict chunks are serialized as `data: {...}\\n\\n` bytes; anything else
        is passed through unchanged.
        """
        for chunk in self.completion_stream:
            if isinstance(chunk, dict):
                payload = f"data: {json.dumps(chunk)}\n\n"
                yield payload.encode()
            else:
                yield chunk

    async def async_google_genai_sse_wrapper(self) -> AsyncIterator[bytes]:
        """
        Async version of google_genai_sse_wrapper.

        Additionally transforms raw ModelResponseStream chunks to Google GenAI
        format before SSE-encoding them.
        """
        from litellm.types.utils import ModelResponseStream

        async for chunk in self.completion_stream:
            if isinstance(chunk, dict):
                payload = f"data: {json.dumps(chunk)}\n\n"
                yield payload.encode()
            elif isinstance(chunk, ModelResponseStream):
                # Transform OpenAI streaming chunk to Google GenAI format
                transformed_chunk = GoogleGenAIAdapter().translate_streaming_completion_to_generate_content(
                    chunk, self
                )

                if isinstance(transformed_chunk, dict):  # Only return non-empty chunks
                    payload = f"data: {json.dumps(transformed_chunk)}\n\n"
                    yield payload.encode()
                else:
                    # For empty chunks, continue to next iteration
                    continue
            else:
                # For other chunk types, yield them directly
                if hasattr(chunk, "encode"):
                    yield chunk.encode()
                else:
                    yield str(chunk).encode()
|
||||
|
||||
|
||||
class GoogleGenAIAdapter:
    """Adapter for transforming Google GenAI generate_content requests to/from litellm.completion format"""

    def __init__(self) -> None:
        # Stateless adapter: all per-stream state lives on GoogleGenAIStreamWrapper.
        pass
|
||||
|
||||
    def translate_generate_content_to_completion(
        self,
        model: str,
        contents: Union[List[Dict[str, Any]], Dict[str, Any]],
        config: Optional[Dict[str, Any]] = None,
        litellm_params: Optional[GenericLiteLLMParams] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Transform a generate_content request to litellm completion format.

        Args:
            model: The model name
            contents: Generate content contents (can be list or single dict)
            config: Optional config parameters (temperature, maxOutputTokens, ...)
            litellm_params: Generic litellm params (api_base, api_key, ...)
                merged into the returned dict
            **kwargs: Additional parameters from the original request; both
                camelCase and snake_case spellings of systemInstruction /
                toolConfig are accepted

        Returns:
            Dict in OpenAI chat-completion request format
        """

        # Extract top-level fields from kwargs (accept both naming styles).
        system_instruction = kwargs.get("systemInstruction") or kwargs.get(
            "system_instruction"
        )
        tools = kwargs.get("tools")
        tool_config = kwargs.get("toolConfig") or kwargs.get("tool_config")

        # Normalize contents to list format
        if isinstance(contents, dict):
            contents_list = [contents]
        else:
            contents_list = contents

        # Transform contents to OpenAI messages format
        messages = self._transform_contents_to_messages(
            contents_list, system_instruction=system_instruction
        )

        # Create base request as dict (which is compatible with ChatCompletionRequest)
        completion_request: ChatCompletionRequest = {
            "model": model,
            "messages": messages,
        }

        #########################################################
        # Supported OpenAI chat completion params
        # - temperature
        # - max_tokens
        # - top_p
        # - frequency_penalty
        # - presence_penalty
        # - stop
        # - tools
        # - tool_choice
        #########################################################

        # Add config parameters if provided
        if config:
            # Map common Google GenAI config parameters to OpenAI equivalents
            if "temperature" in config:
                completion_request["temperature"] = config["temperature"]
            if "maxOutputTokens" in config:
                completion_request["max_tokens"] = config["maxOutputTokens"]
            if "topP" in config:
                completion_request["top_p"] = config["topP"]
            if "topK" in config:
                # OpenAI doesn't have direct topK, but we can pass it as extra
                # NOTE(review): topK is currently dropped silently — confirm
                # whether it should be forwarded as an extra param.
                pass
            if "stopSequences" in config:
                completion_request["stop"] = config["stopSequences"]

        # Handle tools transformation
        if tools:
            # Check if tools are already in OpenAI format or Google GenAI format
            # NOTE(review): this only checks for a non-empty list; all tools
            # are then treated as Google GenAI format.
            if isinstance(tools, list) and len(tools) > 0:
                # Tools are in Google GenAI format, transform them
                openai_tools = self._transform_google_genai_tools_to_openai(tools)

                if openai_tools:
                    completion_request["tools"] = openai_tools

        # Handle tool_config (tool choice)
        if tool_config:
            tool_choice = self._transform_google_genai_tool_config_to_openai(
                tool_config
            )
            if tool_choice:
                completion_request["tool_choice"] = tool_choice

        #########################################################
        # forward any litellm specific params
        #########################################################
        completion_request_dict = dict(completion_request)
        if litellm_params:
            completion_request_dict = self._add_generic_litellm_params_to_request(
                completion_request_dict=completion_request_dict,
                litellm_params=litellm_params,
            )

        return completion_request_dict
|
||||
|
||||
def _add_generic_litellm_params_to_request(
|
||||
self,
|
||||
completion_request_dict: Dict[str, Any],
|
||||
litellm_params: Optional[GenericLiteLLMParams] = None,
|
||||
) -> dict:
|
||||
"""Add generic litellm params to request. e.g add api_base, api_key, api_version, etc.
|
||||
|
||||
Args:
|
||||
completion_request_dict: Dict[str, Any]
|
||||
litellm_params: GenericLiteLLMParams
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]
|
||||
"""
|
||||
allowed_fields = GenericLiteLLMParams.model_fields.keys()
|
||||
if litellm_params:
|
||||
litellm_dict = litellm_params.model_dump(exclude_none=True)
|
||||
for key, value in litellm_dict.items():
|
||||
if key in allowed_fields:
|
||||
completion_request_dict[key] = value
|
||||
return completion_request_dict
|
||||
|
||||
    def translate_completion_output_params_streaming(
        self,
        completion_stream: Any,
    ) -> Union[AsyncIterator[bytes], None]:
        """Transform streaming completion output to Google GenAI format.

        Wraps the completion stream in a GoogleGenAIStreamWrapper and returns
        its SSE ("data: ...\\n\\n") async byte iterator.

        NOTE(review): this always returns the *async* SSE wrapper, even when
        given a sync stream — confirm sync callers are handled upstream.
        """
        google_genai_wrapper = GoogleGenAIStreamWrapper(
            completion_stream=completion_stream
        )
        # Return the SSE-wrapped version for proper event formatting
        return google_genai_wrapper.async_google_genai_sse_wrapper()
|
||||
|
||||
def _transform_google_genai_tools_to_openai(
|
||||
self,
|
||||
tools: List[Dict[str, Any]],
|
||||
) -> List[ChatCompletionToolParam]:
|
||||
"""Transform Google GenAI tools to OpenAI tools format"""
|
||||
openai_tools: List[Dict[str, Any]] = []
|
||||
|
||||
for tool in tools:
|
||||
if "functionDeclarations" in tool:
|
||||
for func_decl in tool["functionDeclarations"]:
|
||||
function_chunk: Dict[str, Any] = {
|
||||
"name": func_decl.get("name", ""),
|
||||
}
|
||||
|
||||
if "description" in func_decl:
|
||||
function_chunk["description"] = func_decl["description"]
|
||||
if "parametersJsonSchema" in func_decl:
|
||||
function_chunk["parameters"] = func_decl["parametersJsonSchema"]
|
||||
|
||||
openai_tool = {"type": "function", "function": function_chunk}
|
||||
openai_tools.append(openai_tool)
|
||||
|
||||
# normalize the tool schemas
|
||||
normalized_tools = [normalize_tool_schema(tool) for tool in openai_tools]
|
||||
|
||||
return cast(List[ChatCompletionToolParam], normalized_tools)
|
||||
|
||||
def _transform_google_genai_tool_config_to_openai(
|
||||
self,
|
||||
tool_config: Dict[str, Any],
|
||||
) -> Optional[ChatCompletionToolChoiceValues]:
|
||||
"""Transform Google GenAI tool_config to OpenAI tool_choice"""
|
||||
function_calling_config = tool_config.get("functionCallingConfig", {})
|
||||
mode = function_calling_config.get("mode", "AUTO")
|
||||
|
||||
mode_mapping = {"AUTO": "auto", "ANY": "required", "NONE": "none"}
|
||||
|
||||
tool_choice = mode_mapping.get(mode, "auto")
|
||||
return cast(ChatCompletionToolChoiceValues, tool_choice)
|
||||
|
||||
    def _transform_contents_to_messages(
        self,
        contents: List[Dict[str, Any]],
        system_instruction: Optional[Dict[str, Any]] = None,
    ) -> List[AllMessageValues]:
        """Transform Google GenAI contents to OpenAI messages format.

        Routing rules:
        - system_instruction (first text part only) -> one system message
        - role "user": text parts -> user message content; inline_data ->
          base64 image_url parts; functionResponse parts -> tool messages
          appended AFTER the user message
        - role "model": text parts concatenated; functionCall parts -> tool
          calls on one assistant message
        """
        messages: List[AllMessageValues] = []

        # Handle system instruction
        if system_instruction:
            system_parts = system_instruction.get("parts", [])
            # NOTE(review): only the first system part is used; additional
            # parts are dropped — confirm intended.
            if system_parts and "text" in system_parts[0]:
                messages.append(
                    ChatCompletionSystemMessage(
                        role="system", content=system_parts[0]["text"]
                    )
                )

        for content in contents:
            role = content.get("role", "user")
            parts = content.get("parts", [])

            if role == "user":
                # Handle user messages with potential function responses
                content_parts: List[
                    Union[ChatCompletionTextObject, ChatCompletionImageObject]
                ] = []
                tool_messages: List[ChatCompletionToolMessage] = []

                for part in parts:
                    if isinstance(part, dict):
                        if "text" in part:
                            content_parts.append(
                                cast(
                                    ChatCompletionTextObject,
                                    {"type": "text", "text": part["text"]},
                                )
                            )
                        elif "inline_data" in part:
                            # Handle Base64 image data
                            inline_data = part["inline_data"]
                            mime_type = inline_data.get("mime_type", "image/jpeg")
                            data = inline_data.get("data", "")
                            content_parts.append(
                                cast(
                                    ChatCompletionImageObject,
                                    {
                                        "type": "image_url",
                                        "image_url": {
                                            "url": f"data:{mime_type};base64,{data}"
                                        },
                                    },
                                )
                            )
                        elif "functionResponse" in part:
                            # Transform function response to tool message.
                            # tool_call_id mirrors the synthetic id created for
                            # the matching functionCall ("call_<name>").
                            func_response = part["functionResponse"]
                            tool_message = ChatCompletionToolMessage(
                                role="tool",
                                tool_call_id=f"call_{func_response.get('name', 'unknown')}",
                                content=json.dumps(func_response.get("response", {})),
                            )
                            tool_messages.append(tool_message)
                    elif isinstance(part, str):
                        # Bare strings are treated as text parts.
                        content_parts.append(
                            cast(
                                ChatCompletionTextObject, {"type": "text", "text": part}
                            )
                        )

                # Add user message if there's content
                if content_parts:
                    # If only one text part, use simple string format for backward compatibility
                    if (
                        len(content_parts) == 1
                        and isinstance(content_parts[0], dict)
                        and content_parts[0].get("type") == "text"
                    ):
                        text_part = cast(ChatCompletionTextObject, content_parts[0])
                        messages.append(
                            ChatCompletionUserMessage(
                                role="user", content=text_part["text"]
                            )
                        )
                    else:
                        # Use multimodal format (array of content parts)
                        messages.append(
                            ChatCompletionUserMessage(
                                role="user", content=content_parts
                            )
                        )

                # Add tool messages
                messages.extend(tool_messages)

            elif role == "model":
                # Handle assistant messages with potential function calls
                combined_text = ""
                tool_calls: List[ChatCompletionAssistantToolCall] = []

                for part in parts:
                    if isinstance(part, dict):
                        if "text" in part:
                            combined_text += part["text"]
                        elif "functionCall" in part:
                            # Transform function call to tool call
                            func_call = part["functionCall"]
                            tool_call = ChatCompletionAssistantToolCall(
                                id=f"call_{func_call.get('name', 'unknown')}",
                                type="function",
                                function=ChatCompletionToolCallFunctionChunk(
                                    name=func_call.get("name", ""),
                                    arguments=json.dumps(func_call.get("args", {})),
                                ),
                            )
                            tool_calls.append(tool_call)
                    elif isinstance(part, str):
                        combined_text += part

                # Create assistant message; tool_calls key is only present
                # when there are actual tool calls.
                if tool_calls:
                    assistant_message = ChatCompletionAssistantMessage(
                        role="assistant",
                        content=combined_text if combined_text else None,
                        tool_calls=tool_calls,
                    )
                else:
                    assistant_message = ChatCompletionAssistantMessage(
                        role="assistant",
                        content=combined_text if combined_text else None,
                    )

                messages.append(assistant_message)

        return messages
|
||||
|
||||
    def translate_completion_to_generate_content(
        self,
        response: ModelResponse,
    ) -> Dict[str, Any]:
        """
        Transform litellm completion response to Google GenAI generate_content format.

        Only the first choice is translated.

        Args:
            response: ModelResponse from litellm.completion

        Returns:
            Dict in Google GenAI generate_content response format

        Raises:
            ValueError: If the response has no choices, or a Choices entry has
                no message.
        """

        # Extract the main response content
        choice = response.choices[0] if response.choices else None
        if not choice:
            raise ValueError("Invalid completion response: no choices found")

        # Handle different choice types (Choices vs StreamingChoices)
        if isinstance(choice, Choices):
            if not choice.message:
                raise ValueError(
                    "Invalid completion response: no message found in choice"
                )
            parts = self._transform_openai_message_to_google_genai_parts(choice.message)
        else:
            # Fallback for generic choice objects
            # NOTE(review): assumes a dict-like .message/.delta here (.get) —
            # object-style choices would raise; confirm callers.
            message_content = getattr(choice, "message", {}).get(
                "content", ""
            ) or getattr(choice, "delta", {}).get("content", "")
            parts = [{"text": message_content}] if message_content else []

        # Create Google GenAI format response
        generate_content_response: Dict[str, Any] = {
            "candidates": [
                {
                    "content": {"parts": parts, "role": "model"},
                    "finishReason": self._map_finish_reason(
                        getattr(choice, "finish_reason", None)
                    ),
                    "index": 0,
                    "safetyRatings": [],
                }
            ],
            # Zeroed usage when the response carries no usage object.
            "usageMetadata": (
                self._map_usage(getattr(response, "usage", None))
                if hasattr(response, "usage") and getattr(response, "usage", None)
                else {
                    "promptTokenCount": 0,
                    "candidatesTokenCount": 0,
                    "totalTokenCount": 0,
                }
            ),
        }

        # Add text field for convenience (common in Google GenAI responses)
        text_content = ""
        for part in parts:
            if isinstance(part, dict) and "text" in part:
                text_content += part["text"]
        if text_content:
            generate_content_response["text"] = text_content

        return generate_content_response
|
||||
|
||||
    def translate_streaming_completion_to_generate_content(
        self,
        response: Union[ModelResponse, ModelResponseStream],
        wrapper: GoogleGenAIStreamWrapper,
    ) -> Optional[Dict[str, Any]]:
        """
        Transform streaming litellm completion chunk to Google GenAI generate_content format.

        Args:
            response: Streaming ModelResponse chunk from litellm.completion
            wrapper: GoogleGenAIStreamWrapper instance carrying tool-call
                accumulation state across chunks

        Returns:
            Dict in Google GenAI streaming generate_content response format,
            or None when the chunk contributes neither parts nor a finish
            reason (caller skips such chunks).
        """

        # Extract the main response content from streaming chunk
        choice = response.choices[0] if response.choices else None
        if not choice:
            # Return empty chunk if no choices
            return None

        # Handle streaming choice
        if isinstance(choice, StreamingChoices):
            if choice.delta:
                # Delta may be a partial tool call; accumulation state lives
                # on the wrapper.
                parts = self._transform_openai_delta_to_google_genai_parts_with_accumulation(
                    choice.delta, wrapper
                )
            else:
                parts = []
            finish_reason = getattr(choice, "finish_reason", None)
        else:
            # Fallback for generic choice objects
            message_content = getattr(choice, "delta", {}).get("content", "")
            parts = [{"text": message_content}] if message_content else []
            finish_reason = getattr(choice, "finish_reason", None)

        # Only create response chunk if we have parts or it's the final chunk
        if not parts and not finish_reason:
            return None

        # Create Google GenAI streaming format response
        streaming_chunk: Dict[str, Any] = {
            "candidates": [
                {
                    "content": {"parts": parts, "role": "model"},
                    "finishReason": (
                        self._map_finish_reason(finish_reason)
                        if finish_reason
                        else None
                    ),
                    "index": 0,
                    "safetyRatings": [],
                }
            ]
        }

        # Add usage metadata only in the final chunk (when finish_reason is present)
        if finish_reason:
            usage_metadata = (
                self._map_usage(getattr(response, "usage", None))
                if hasattr(response, "usage") and getattr(response, "usage", None)
                else {
                    "promptTokenCount": 0,
                    "candidatesTokenCount": 0,
                    "totalTokenCount": 0,
                }
            )
            streaming_chunk["usageMetadata"] = usage_metadata

        # Add text field for convenience (common in Google GenAI responses)
        text_content = ""
        for part in parts:
            if isinstance(part, dict) and "text" in part:
                text_content += part["text"]
        if text_content:
            streaming_chunk["text"] = text_content

        return streaming_chunk
|
||||
|
||||
def _transform_openai_message_to_google_genai_parts(
|
||||
self,
|
||||
message: Any,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Transform OpenAI message to Google GenAI parts format"""
|
||||
parts: List[Dict[str, Any]] = []
|
||||
|
||||
# Add text content if present
|
||||
if hasattr(message, "content") and message.content:
|
||||
parts.append({"text": message.content})
|
||||
|
||||
# Add tool calls if present
|
||||
if hasattr(message, "tool_calls") and message.tool_calls:
|
||||
for tool_call in message.tool_calls:
|
||||
if hasattr(tool_call, "function") and tool_call.function:
|
||||
try:
|
||||
args = (
|
||||
json.loads(tool_call.function.arguments)
|
||||
if tool_call.function.arguments
|
||||
else {}
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
|
||||
function_call_part = {
|
||||
"functionCall": {
|
||||
"name": tool_call.function.name or "undefined_tool_name",
|
||||
"args": args,
|
||||
}
|
||||
}
|
||||
parts.append(function_call_part)
|
||||
|
||||
return parts if parts else [{"text": ""}]
|
||||
|
||||
    def _transform_openai_delta_to_google_genai_parts_with_accumulation(
        self, delta: Any, wrapper: GoogleGenAIStreamWrapper
    ) -> List[Dict[str, Any]]:
        """Transforms OpenAI delta to Google GenAI parts, accumulating streaming tool calls.

        Tool-call fragments are keyed by their streaming ``index`` and buffered
        on ``wrapper.accumulated_tool_calls``; a ``functionCall`` part is
        emitted only once the accumulated arguments parse as valid JSON AND a
        name has arrived. Unfinished calls stay buffered (the async wrapper
        flushes them at end of stream).
        """

        # 1. Initialize wrapper state if it doesn't exist
        if not hasattr(wrapper, "accumulated_tool_calls"):
            wrapper.accumulated_tool_calls = {}

        parts: List[Dict[str, Any]] = []

        if hasattr(delta, "content") and delta.content:
            parts.append({"text": delta.content})

        # 2. Ensure tool_calls is iterable
        tool_calls = delta.tool_calls or []

        for tool_call in tool_calls:
            if not hasattr(tool_call, "function"):
                continue

            # 3. Use `index` as the primary key for accumulation
            tool_call_index = getattr(tool_call, "index", None)
            if tool_call_index is None:
                continue  # Index is essential for tracking streaming tool calls

            # Initialize accumulator for this index if it's new
            if tool_call_index not in wrapper.accumulated_tool_calls:
                wrapper.accumulated_tool_calls[tool_call_index] = {
                    "name": "",
                    "arguments": "",
                }

            # Accumulate name and arguments
            function_name = getattr(tool_call.function, "name", None)
            args_chunk = getattr(tool_call.function, "arguments", None)

            # Optimization: Skip chunks that have no new data
            if not function_name and not args_chunk:
                verbose_logger.debug(
                    f"Skipping empty tool call chunk for index: {tool_call_index}"
                )
                continue

            if function_name:
                # Name replaces (not appends) — providers send it whole.
                wrapper.accumulated_tool_calls[tool_call_index]["name"] = function_name

            if args_chunk:
                wrapper.accumulated_tool_calls[tool_call_index][
                    "arguments"
                ] += args_chunk

            # Attempt to parse and emit a complete tool call
            accumulated_data = wrapper.accumulated_tool_calls[tool_call_index]
            accumulated_name = accumulated_data["name"]
            accumulated_args = accumulated_data["arguments"]

            # 5. Attempt to parse arguments even if name hasn't arrived.
            # NOTE(review): for zero-argument calls accumulated_args stays ""
            # (invalid JSON), so the part is only emitted by the end-of-stream
            # flush in the wrapper — confirm intended.
            try:
                # Attempt to parse the accumulated arguments string
                parsed_args = json.loads(accumulated_args)

                # If parsing succeeds, but we don't have a name yet, wait.
                # The part will be created by a later chunk that brings the name.
                if accumulated_name:
                    # If successful, create the part and clean up
                    function_call_part = {
                        "functionCall": {"name": accumulated_name, "args": parsed_args}
                    }
                    parts.append(function_call_part)

                    # Remove the completed tool call from the accumulator
                    del wrapper.accumulated_tool_calls[tool_call_index]

            except json.JSONDecodeError:
                # The JSON for arguments is still incomplete.
                # We will continue to accumulate and wait for more chunks.
                pass

        return parts
|
||||
|
||||
def _map_finish_reason(self, finish_reason: Optional[str]) -> str:
|
||||
"""Map OpenAI finish reasons to Google GenAI finish reasons"""
|
||||
if not finish_reason:
|
||||
return "STOP"
|
||||
|
||||
mapping = {
|
||||
"stop": "STOP",
|
||||
"length": "MAX_TOKENS",
|
||||
"content_filter": "SAFETY",
|
||||
"tool_calls": "STOP",
|
||||
"function_call": "STOP",
|
||||
}
|
||||
|
||||
return mapping.get(finish_reason, "STOP")
|
||||
|
||||
def _map_usage(self, usage: Any) -> Dict[str, int]:
|
||||
"""Map OpenAI usage to Google GenAI usage format"""
|
||||
return {
|
||||
"promptTokenCount": getattr(usage, "prompt_tokens", 0) or 0,
|
||||
"candidatesTokenCount": getattr(usage, "completion_tokens", 0) or 0,
|
||||
"totalTokenCount": getattr(usage, "total_tokens", 0) or 0,
|
||||
}
|
||||
@@ -0,0 +1,548 @@
|
||||
import asyncio
|
||||
import contextvars
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterator, Optional, Union
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
import litellm
|
||||
from litellm.constants import request_timeout
|
||||
|
||||
# Import the adapter for fallback to completion format
|
||||
from litellm.google_genai.adapters.handler import GenerateContentToCompletionHandler
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.base_llm.google_genai.transformation import (
|
||||
BaseGoogleGenAIGenerateContentConfig,
|
||||
)
|
||||
from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.utils import ProviderConfigManager, client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.types.google_genai.main import (
|
||||
GenerateContentConfigDict,
|
||||
GenerateContentContentListUnionDict,
|
||||
GenerateContentResponse,
|
||||
ToolConfigDict,
|
||||
)
|
||||
else:
|
||||
GenerateContentConfigDict = Any
|
||||
GenerateContentContentListUnionDict = Any
|
||||
GenerateContentResponse = Any
|
||||
ToolConfigDict = Any
|
||||
|
||||
|
||||
####### ENVIRONMENT VARIABLES ###################
|
||||
# Initialize any necessary instances or variables here
|
||||
base_llm_http_handler = BaseLLMHTTPHandler()
|
||||
#################################################
|
||||
|
||||
|
||||
class GenerateContentSetupResult(BaseModel):
    """Internal Type - Result of setting up a generate content call.

    Produced by ``GenerateContentHelper.setup_generate_content_call`` and
    consumed by the public ``generate_content*`` entry points. When
    ``generate_content_provider_config`` is None the caller falls back to the
    completion-format adapter and ``request_body`` is left empty.
    """

    # arbitrary_types_allowed: litellm_logging_obj / litellm_params are not pydantic types.
    model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)

    # Model name, possibly rewritten by litellm.get_llm_provider.
    model: str
    # Provider-ready request payload ({} when the adapter path is used).
    request_body: Dict[str, Any]
    # Resolved provider name (e.g. after get_llm_provider).
    custom_llm_provider: str
    # None signals "no native generate-content config; use the completion adapter".
    generate_content_provider_config: Optional[BaseGoogleGenAIGenerateContentConfig]
    # Mapped optional params (or the raw config dict on the adapter path).
    generate_content_config_dict: Dict[str, Any]
    # Generic litellm params parsed from **kwargs.
    litellm_params: GenericLiteLLMParams
    # Logging object; setup raises if it was not supplied.
    litellm_logging_obj: LiteLLMLoggingObj
    # Optional call id forwarded into pre-call logging.
    litellm_call_id: Optional[str]
|
||||
|
||||
|
||||
class GenerateContentHelper:
    """Helper class for Google GenAI generate content operations.

    Holds the shared setup logic used by all four public entry points
    (``generate_content``, ``agenerate_content`` and their streaming
    variants), plus a canned mock response used for testing.
    """

    @staticmethod
    def mock_generate_content_response(
        mock_response: str = "This is a mock response from Google GenAI generate_content.",
    ) -> Dict[str, Any]:
        """Mock response for generate_content for testing purposes.

        Returns a dict shaped like a Google GenAI generateContent response:
        one candidate containing ``mock_response`` as text, a STOP finish
        reason, and fixed usage numbers (10/20/30 tokens).
        """
        return {
            "text": mock_response,
            "candidates": [
                {
                    "content": {"parts": [{"text": mock_response}], "role": "model"},
                    "finishReason": "STOP",
                    "index": 0,
                    "safetyRatings": [],
                }
            ],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 20,
                "totalTokenCount": 30,
            },
        }

    @staticmethod
    def setup_generate_content_call(
        model: str,
        contents: GenerateContentContentListUnionDict,
        config: Optional[GenerateContentConfigDict] = None,
        custom_llm_provider: Optional[str] = None,
        tools: Optional[ToolConfigDict] = None,
        **kwargs,
    ) -> GenerateContentSetupResult:
        """
        Common setup logic for generate_content calls

        Resolves the provider, maps optional params, builds the provider
        request body, and performs pre-call logging. If the provider has no
        native generate-content config, returns early with
        ``generate_content_provider_config=None`` so the caller can route the
        request through the completion-format adapter instead.

        Args:
            model: The model name
            contents: The content to generate from
            config: Optional configuration
            custom_llm_provider: Optional custom LLM provider
            tools: Optional tools
            **kwargs: Additional keyword arguments (must include
                ``litellm_logging_obj``; may include ``litellm_call_id``,
                ``systemInstruction``/``system_instruction``, ``stream``)

        Returns:
            GenerateContentSetupResult containing all setup information

        Raises:
            ValueError: if a non-streaming string mock_response reaches this
                function (callers must short-circuit mocks first), or if
                ``litellm_logging_obj`` is missing.
        """
        litellm_logging_obj: Optional[LiteLLMLoggingObj] = kwargs.get(
            "litellm_logging_obj"
        )
        litellm_call_id: Optional[str] = kwargs.get("litellm_call_id", None)

        # get llm provider logic
        litellm_params = GenericLiteLLMParams(**kwargs)

        ## MOCK RESPONSE LOGIC (only for non-streaming)
        # generate_content() checks mock_response before calling setup, so
        # hitting this branch means a caller forgot to short-circuit the mock.
        if (
            not kwargs.get("stream", False)
            and litellm_params.mock_response
            and isinstance(litellm_params.mock_response, str)
        ):
            raise ValueError("Mock response should be handled by caller")

        # May rewrite `model` (strips the provider prefix) and fill in the
        # provider from the model string when custom_llm_provider is None.
        (
            model,
            custom_llm_provider,
            dynamic_api_key,
            dynamic_api_base,
        ) = litellm.get_llm_provider(
            model=model,
            custom_llm_provider=custom_llm_provider,
            api_base=litellm_params.api_base,
            api_key=litellm_params.api_key,
        )

        # Backfill the resolved provider onto litellm_params for downstream use.
        if litellm_params.custom_llm_provider is None:
            litellm_params.custom_llm_provider = custom_llm_provider

        # get provider config
        generate_content_provider_config: Optional[
            BaseGoogleGenAIGenerateContentConfig
        ] = ProviderConfigManager.get_provider_google_genai_generate_content_config(
            model=model,
            provider=litellm.LlmProviders(custom_llm_provider),
        )

        if generate_content_provider_config is None:
            # Use adapter to transform to completion format when provider config is None
            # Signal that we should use the adapter by returning special result
            if litellm_logging_obj is None:
                raise ValueError("litellm_logging_obj is required, but got None")
            return GenerateContentSetupResult(
                model=model,
                custom_llm_provider=custom_llm_provider,
                request_body={},  # Will be handled by adapter
                generate_content_provider_config=None,  # type: ignore
                generate_content_config_dict=dict(config or {}),
                litellm_params=litellm_params,
                litellm_logging_obj=litellm_logging_obj,
                litellm_call_id=litellm_call_id,
            )

        #########################################################################################
        # Construct request body
        #########################################################################################
        # Create Google Optional Params Config
        generate_content_config_dict = (
            generate_content_provider_config.map_generate_content_optional_params(
                generate_content_config_dict=config or {},
                model=model,
            )
        )
        # Extract systemInstruction from kwargs to pass to transform
        # (both camelCase and snake_case spellings are accepted).
        system_instruction = kwargs.get("systemInstruction") or kwargs.get(
            "system_instruction"
        )
        request_body = (
            generate_content_provider_config.transform_generate_content_request(
                model=model,
                contents=contents,
                tools=tools,
                generate_content_config_dict=generate_content_config_dict,
                system_instruction=system_instruction,
            )
        )

        # Pre Call logging
        if litellm_logging_obj is None:
            raise ValueError("litellm_logging_obj is required, but got None")

        litellm_logging_obj.update_environment_variables(
            model=model,
            optional_params=dict(generate_content_config_dict),
            litellm_params={
                "litellm_call_id": litellm_call_id,
            },
            custom_llm_provider=custom_llm_provider,
        )

        return GenerateContentSetupResult(
            model=model,
            custom_llm_provider=custom_llm_provider,
            request_body=request_body,
            generate_content_provider_config=generate_content_provider_config,
            generate_content_config_dict=generate_content_config_dict,
            litellm_params=litellm_params,
            litellm_logging_obj=litellm_logging_obj,
            litellm_call_id=litellm_call_id,
        )
|
||||
|
||||
|
||||
@client
async def agenerate_content(
    model: str,
    contents: GenerateContentContentListUnionDict,
    config: Optional[GenerateContentConfigDict] = None,
    tools: Optional[ToolConfigDict] = None,
    # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
    # The extra values given here take precedence over values defined on the client or passed to this method.
    extra_headers: Optional[Dict[str, Any]] = None,
    extra_query: Optional[Dict[str, Any]] = None,
    extra_body: Optional[Dict[str, Any]] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
    # LiteLLM specific params,
    custom_llm_provider: Optional[str] = None,
    **kwargs,
) -> Any:
    """
    Async: Generate content using Google GenAI.

    Thin async wrapper over the sync ``generate_content``: the sync function
    is run in the default executor (with contextvars preserved), and if it
    hands back a coroutine that coroutine is awaited. Provider exceptions are
    normalized via ``litellm.exception_type``.
    """
    local_vars = locals()
    try:
        loop = asyncio.get_event_loop()
        # Tell the sync function this call originated from the async wrapper.
        kwargs["agenerate_content"] = True

        # Handle generationConfig parameter from kwargs for backward compatibility
        if "generationConfig" in kwargs and config is None:
            config = kwargs.pop("generationConfig")
        # get custom llm provider so we can use this for mapping exceptions
        if custom_llm_provider is None:
            _, custom_llm_provider, _, _ = litellm.get_llm_provider(
                model=model,
                custom_llm_provider=custom_llm_provider,
            )

        func = partial(
            generate_content,
            model=model,
            contents=contents,
            config=config,
            extra_headers=extra_headers,
            extra_query=extra_query,
            extra_body=extra_body,
            timeout=timeout,
            custom_llm_provider=custom_llm_provider,
            tools=tools,
            **kwargs,
        )

        # Preserve the current contextvars context inside the executor thread.
        ctx = contextvars.copy_context()
        func_with_context = partial(ctx.run, func)
        init_response = await loop.run_in_executor(None, func_with_context)

        # The sync call may return a coroutine (async handler path) — await it.
        if asyncio.iscoroutine(init_response):
            response = await init_response
        else:
            response = init_response

        return response
    except Exception as e:
        # Map provider-specific errors onto litellm's exception hierarchy.
        raise litellm.exception_type(
            model=model,
            custom_llm_provider=custom_llm_provider,
            original_exception=e,
            completion_kwargs=local_vars,
            extra_kwargs=kwargs,
        )
|
||||
|
||||
|
||||
@client
def generate_content(
    model: str,
    contents: GenerateContentContentListUnionDict,
    config: Optional[GenerateContentConfigDict] = None,
    tools: Optional[ToolConfigDict] = None,
    # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
    # The extra values given here take precedence over values defined on the client or passed to this method.
    extra_headers: Optional[Dict[str, Any]] = None,
    extra_query: Optional[Dict[str, Any]] = None,
    extra_body: Optional[Dict[str, Any]] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
    # LiteLLM specific params,
    custom_llm_provider: Optional[str] = None,
    **kwargs,
) -> Any:
    """
    Generate content using Google GenAI.

    Flow: short-circuit string mock responses, run common setup, then either
    route through the completion-format adapter (when the provider has no
    native generate-content config) or call the HTTP handler directly.
    Provider exceptions are normalized via ``litellm.exception_type``.
    """
    local_vars = locals()
    try:
        # Flag set by agenerate_content when this is the async code path.
        _is_async = kwargs.pop("agenerate_content", False)

        # Handle generationConfig parameter from kwargs for backward compatibility
        if "generationConfig" in kwargs and config is None:
            config = kwargs.pop("generationConfig")
        # Check for mock response first — setup raises if a mock reaches it.
        litellm_params = GenericLiteLLMParams(**kwargs)
        if litellm_params.mock_response and isinstance(
            litellm_params.mock_response, str
        ):
            return GenerateContentHelper.mock_generate_content_response(
                mock_response=litellm_params.mock_response
            )

        # Setup the call
        setup_result = GenerateContentHelper.setup_generate_content_call(
            model=model,
            contents=contents,
            config=config,
            custom_llm_provider=custom_llm_provider,
            tools=tools,
            **kwargs,
        )

        # Extract systemInstruction from kwargs to pass to handler
        # (both camelCase and snake_case spellings are accepted).
        system_instruction = kwargs.get("systemInstruction") or kwargs.get(
            "system_instruction"
        )

        # Check if we should use the adapter (when provider config is None)
        if setup_result.generate_content_provider_config is None:
            # Use the adapter to convert to completion format
            return GenerateContentToCompletionHandler.generate_content_handler(
                model=model,
                contents=contents,  # type: ignore
                config=setup_result.generate_content_config_dict,
                tools=tools,
                _is_async=_is_async,
                litellm_params=setup_result.litellm_params,
                extra_headers=extra_headers,
                **kwargs,
            )

        # Call the standard handler
        response = base_llm_http_handler.generate_content_handler(
            model=setup_result.model,
            contents=contents,
            tools=tools,
            generate_content_provider_config=setup_result.generate_content_provider_config,
            generate_content_config_dict=setup_result.generate_content_config_dict,
            custom_llm_provider=setup_result.custom_llm_provider,
            litellm_params=setup_result.litellm_params,
            logging_obj=setup_result.litellm_logging_obj,
            extra_headers=extra_headers,
            extra_body=extra_body,
            timeout=timeout or request_timeout,
            _is_async=_is_async,
            client=kwargs.get("client"),
            litellm_metadata=kwargs.get("litellm_metadata", {}),
            system_instruction=system_instruction,
        )

        return response
    except Exception as e:
        # Map provider-specific errors onto litellm's exception hierarchy.
        raise litellm.exception_type(
            model=model,
            custom_llm_provider=custom_llm_provider,
            original_exception=e,
            completion_kwargs=local_vars,
            extra_kwargs=kwargs,
        )
|
||||
|
||||
|
||||
@client
async def agenerate_content_stream(
    model: str,
    contents: GenerateContentContentListUnionDict,
    config: Optional[GenerateContentConfigDict] = None,
    tools: Optional[ToolConfigDict] = None,
    # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
    # The extra values given here take precedence over values defined on the client or passed to this method.
    extra_headers: Optional[Dict[str, Any]] = None,
    extra_query: Optional[Dict[str, Any]] = None,
    extra_body: Optional[Dict[str, Any]] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
    # LiteLLM specific params,
    custom_llm_provider: Optional[str] = None,
    **kwargs,
) -> Any:
    """
    Async: Generate content using Google GenAI with streaming response.

    Runs the shared setup, then either awaits the completion-format adapter
    (when the provider has no native generate-content config) or awaits the
    HTTP handler with ``stream=True``. Provider exceptions are normalized via
    ``litellm.exception_type``.
    """
    local_vars = locals()
    try:
        # Marker for downstream code that this is the async streaming path.
        kwargs["agenerate_content_stream"] = True

        # Handle generationConfig parameter from kwargs for backward compatibility
        if "generationConfig" in kwargs and config is None:
            config = kwargs.pop("generationConfig")
        # get custom llm provider so we can use this for mapping exceptions
        if custom_llm_provider is None:
            # NOTE(review): local_vars has no "base_url" key in this function,
            # so api_base is effectively always None here — confirm intended
            # (agenerate_content passes custom_llm_provider instead).
            _, custom_llm_provider, _, _ = litellm.get_llm_provider(
                model=model, api_base=local_vars.get("base_url", None)
            )

        # Setup the call
        setup_result = GenerateContentHelper.setup_generate_content_call(
            model=model,
            contents=contents,
            config=config,
            custom_llm_provider=custom_llm_provider,
            tools=tools,
            **kwargs,
        )

        # Extract systemInstruction from kwargs to pass to handler
        # (both camelCase and snake_case spellings are accepted).
        system_instruction = kwargs.get("systemInstruction") or kwargs.get(
            "system_instruction"
        )

        # Check if we should use the adapter (when provider config is None)
        if setup_result.generate_content_provider_config is None:
            # The adapter sets stream itself; drop any caller-supplied flag.
            if "stream" in kwargs:
                kwargs.pop("stream", None)

            # Use the adapter to convert to completion format
            return (
                await GenerateContentToCompletionHandler.async_generate_content_handler(
                    model=model,
                    contents=contents,  # type: ignore
                    config=setup_result.generate_content_config_dict,
                    litellm_params=setup_result.litellm_params,
                    tools=tools,
                    stream=True,
                    extra_headers=extra_headers,
                    **kwargs,
                )
            )

        # Call the handler with async enabled and streaming
        # Return the coroutine directly for the router to handle
        return await base_llm_http_handler.generate_content_handler(
            model=setup_result.model,
            contents=contents,
            generate_content_provider_config=setup_result.generate_content_provider_config,
            generate_content_config_dict=setup_result.generate_content_config_dict,
            tools=tools,
            custom_llm_provider=setup_result.custom_llm_provider,
            litellm_params=setup_result.litellm_params,
            logging_obj=setup_result.litellm_logging_obj,
            extra_headers=extra_headers,
            extra_body=extra_body,
            timeout=timeout or request_timeout,
            _is_async=True,
            client=kwargs.get("client"),
            stream=True,
            litellm_metadata=kwargs.get("litellm_metadata", {}),
            system_instruction=system_instruction,
        )

    except Exception as e:
        # Map provider-specific errors onto litellm's exception hierarchy.
        raise litellm.exception_type(
            model=model,
            custom_llm_provider=custom_llm_provider,
            original_exception=e,
            completion_kwargs=local_vars,
            extra_kwargs=kwargs,
        )
|
||||
|
||||
|
||||
@client
def generate_content_stream(
    model: str,
    contents: GenerateContentContentListUnionDict,
    config: Optional[GenerateContentConfigDict] = None,
    tools: Optional[ToolConfigDict] = None,
    # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
    # The extra values given here take precedence over values defined on the client or passed to this method.
    extra_headers: Optional[Dict[str, Any]] = None,
    extra_query: Optional[Dict[str, Any]] = None,
    extra_body: Optional[Dict[str, Any]] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
    # LiteLLM specific params,
    custom_llm_provider: Optional[str] = None,
    **kwargs,
) -> Iterator[Any]:
    """
    Generate content using Google GenAI with streaming response (sync).

    Runs the shared setup, then either routes through the completion-format
    adapter (when the provider has no native generate-content config) or calls
    the HTTP handler with ``stream=True``. Returns an iterator over streamed
    chunks. Provider exceptions are normalized via ``litellm.exception_type``.

    Fixes vs previous revision (for consistency with generate_content and
    agenerate_content_stream):
    - ``systemInstruction``/``system_instruction`` from kwargs is now
      forwarded to the handler (it was silently dropped before).
    - ``tools`` is now forwarded on the adapter path (it was dropped before).
    """
    local_vars = locals()
    try:
        # Remove any async-related flags since this is the sync function
        _is_async = kwargs.pop("agenerate_content_stream", False)

        # Handle generationConfig parameter from kwargs for backward compatibility
        if "generationConfig" in kwargs and config is None:
            config = kwargs.pop("generationConfig")
        # Setup the call
        setup_result = GenerateContentHelper.setup_generate_content_call(
            model=model,
            contents=contents,
            config=config,
            custom_llm_provider=custom_llm_provider,
            tools=tools,
            **kwargs,
        )

        # Extract systemInstruction from kwargs to pass to handler
        # (both camelCase and snake_case spellings are accepted).
        system_instruction = kwargs.get("systemInstruction") or kwargs.get(
            "system_instruction"
        )

        # Check if we should use the adapter (when provider config is None)
        if setup_result.generate_content_provider_config is None:
            # The adapter sets stream itself; drop any caller-supplied flag.
            if "stream" in kwargs:
                kwargs.pop("stream", None)

            # Use the adapter to convert to completion format
            return GenerateContentToCompletionHandler.generate_content_handler(
                model=model,
                contents=contents,  # type: ignore
                config=setup_result.generate_content_config_dict,
                tools=tools,
                _is_async=_is_async,
                litellm_params=setup_result.litellm_params,
                stream=True,
                extra_headers=extra_headers,
                **kwargs,
            )

        # Call the handler with streaming enabled (sync version)
        return base_llm_http_handler.generate_content_handler(
            model=setup_result.model,
            contents=contents,
            generate_content_provider_config=setup_result.generate_content_provider_config,
            generate_content_config_dict=setup_result.generate_content_config_dict,
            tools=tools,
            custom_llm_provider=setup_result.custom_llm_provider,
            litellm_params=setup_result.litellm_params,
            logging_obj=setup_result.litellm_logging_obj,
            extra_headers=extra_headers,
            extra_body=extra_body,
            timeout=timeout or request_timeout,
            _is_async=_is_async,
            client=kwargs.get("client"),
            stream=True,
            litellm_metadata=kwargs.get("litellm_metadata", {}),
            system_instruction=system_instruction,
        )

    except Exception as e:
        # Map provider-specific errors onto litellm's exception hierarchy.
        raise litellm.exception_type(
            model=model,
            custom_llm_provider=custom_llm_provider,
            original_exception=e,
            completion_kwargs=local_vars,
            extra_kwargs=kwargs,
        )
|
||||
@@ -0,0 +1,159 @@
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.proxy.pass_through_endpoints.success_handler import (
|
||||
PassThroughEndpointLogging,
|
||||
)
|
||||
from litellm.types.passthrough_endpoints.pass_through_endpoints import EndpointType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.llms.base_llm.google_genai.transformation import (
|
||||
BaseGoogleGenAIGenerateContentConfig,
|
||||
)
|
||||
else:
|
||||
BaseGoogleGenAIGenerateContentConfig = Any
|
||||
|
||||
GLOBAL_PASS_THROUGH_SUCCESS_HANDLER_OBJ = PassThroughEndpointLogging()
|
||||
|
||||
|
||||
class BaseGoogleGenAIGenerateContentStreamingIterator:
    """
    Base class for Google GenAI Generate Content streaming iterators that provides common logic
    for streaming response handling and logging.

    Subclasses accumulate raw response chunks in ``collected_chunks`` and, when
    the stream is exhausted, call ``_handle_async_streaming_logging`` to hand
    the full byte stream to the pass-through success-logging pipeline.
    """

    def __init__(
        self,
        litellm_logging_obj: LiteLLMLoggingObj,
        request_body: dict,
        model: str,
    ):
        # Logging object used to route the completed stream to success handlers.
        self.litellm_logging_obj = litellm_logging_obj
        # Original request payload, forwarded to the logging handler.
        self.request_body = request_body
        # Stream start timestamp; paired with end_time when logging fires.
        self.start_time = datetime.now()
        # Raw bytes of every chunk seen so far, in arrival order.
        self.collected_chunks: List[bytes] = []
        self.model = model
        # Strong reference to the fire-and-forget logging task. asyncio only
        # keeps weak references to tasks, so an unreferenced task created by
        # create_task() can be garbage-collected before it runs.
        self._logging_task: Optional["asyncio.Task"] = None

    async def _handle_async_streaming_logging(
        self,
    ):
        """Handle the logging after all chunks have been collected."""
        from litellm.proxy.pass_through_endpoints.streaming_handler import (
            PassThroughStreamingHandler,
        )

        end_time = datetime.now()
        # Retain the task reference so the event loop cannot drop it mid-flight
        # (fix: the task was previously created without keeping a reference).
        self._logging_task = asyncio.create_task(
            PassThroughStreamingHandler._route_streaming_logging_to_handler(
                litellm_logging_obj=self.litellm_logging_obj,
                passthrough_success_handler_obj=GLOBAL_PASS_THROUGH_SUCCESS_HANDLER_OBJ,
                url_route="/v1/generateContent",
                request_body=self.request_body or {},
                endpoint_type=EndpointType.VERTEX_AI,
                start_time=self.start_time,
                raw_bytes=self.collected_chunks,
                end_time=end_time,
                model=self.model,
            )
        )
|
||||
|
||||
|
||||
class GoogleGenAIGenerateContentStreamingIterator(
    BaseGoogleGenAIGenerateContentStreamingIterator
):
    """
    Synchronous streaming iterator for the Google GenAI generate-content API.

    Yields the raw byte chunks of the HTTP response unchanged, recording each
    one in ``collected_chunks`` (inherited from the base class).
    """

    def __init__(
        self,
        response,
        model: str,
        logging_obj: LiteLLMLoggingObj,
        generate_content_provider_config: BaseGoogleGenAIGenerateContentConfig,
        litellm_metadata: dict,
        custom_llm_provider: str,
        request_body: Optional[dict] = None,
    ):
        super().__init__(
            litellm_logging_obj=logging_obj,
            request_body=request_body or {},
            model=model,
        )
        self.response = response
        self.custom_llm_provider = custom_llm_provider
        self.litellm_metadata = litellm_metadata
        self.generate_content_provider_config = generate_content_provider_config
        self.model = model
        # Materialize the byte iterator exactly once — the underlying HTTP
        # stream cannot be consumed twice.
        self.stream_iterator = response.iter_bytes()

    def __iter__(self):
        return self

    def __next__(self):
        # next() propagates StopIteration naturally when the stream ends.
        chunk = next(self.stream_iterator)
        self.collected_chunks.append(chunk)
        return chunk

    def __aiter__(self):
        return self

    async def __anext__(self):
        # Async iteration over a synchronous response is intentionally
        # unsupported; the async variant owns that code path.
        raise NotImplementedError(
            "Use AsyncGoogleGenAIGenerateContentStreamingIterator for async iteration"
        )
|
||||
|
||||
|
||||
class AsyncGoogleGenAIGenerateContentStreamingIterator(
    BaseGoogleGenAIGenerateContentStreamingIterator
):
    """
    Asynchronous streaming iterator for the Google GenAI generate-content API.

    Yields raw byte chunks, records each one in ``collected_chunks``, and on
    stream exhaustion flushes the collected bytes to the logging pipeline
    before signalling ``StopAsyncIteration``.
    """

    def __init__(
        self,
        response,
        model: str,
        logging_obj: LiteLLMLoggingObj,
        generate_content_provider_config: BaseGoogleGenAIGenerateContentConfig,
        litellm_metadata: dict,
        custom_llm_provider: str,
        request_body: Optional[dict] = None,
    ):
        super().__init__(
            litellm_logging_obj=logging_obj,
            request_body=request_body or {},
            model=model,
        )
        self.response = response
        self.custom_llm_provider = custom_llm_provider
        self.litellm_metadata = litellm_metadata
        self.generate_content_provider_config = generate_content_provider_config
        self.model = model
        # Materialize the async byte iterator exactly once — the underlying
        # HTTP stream cannot be consumed twice.
        self.stream_iterator = response.aiter_bytes()

    def __aiter__(self):
        return self

    async def __anext__(self):
        try:
            chunk = await self.stream_iterator.__anext__()
        except StopAsyncIteration:
            # Stream finished: hand the accumulated chunks to the logging
            # pipeline, then re-raise to end iteration.
            await self._handle_async_streaming_logging()
            raise
        self.collected_chunks.append(chunk)
        return chunk
|
||||
1080
llm-gateway-competitors/litellm-wheel-src/litellm/images/main.py
Normal file
1080
llm-gateway-competitors/litellm-wheel-src/litellm/images/main.py
Normal file
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user