"""
Dynamic configuration class generator for JSON-based providers.
"""
from typing import Any, Coroutine, List, Literal, Optional, Tuple, Union, overload
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.prompt_templates.common_utils import (
handle_messages_with_content_list_to_str_conversion,
)
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
from litellm.llms.openai_like.chat.transformation import OpenAILikeChatConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from .json_loader import SimpleProviderConfig
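
# Illustrative sketch (assumption, not part of this module): a provider entry
# in the providers JSON file might look like the following, inferred from the
# SimpleProviderConfig attributes this module reads:
#
#   {
#       "slug": "example-provider",
#       "base_class": "openai_gpt",
#       "base_url": "https://api.example.com/v1",
#       "api_base_env": "EXAMPLE_API_BASE",
#       "api_key_env": "EXAMPLE_API_KEY",
#       "special_handling": {"convert_content_list_to_string": true},
#       "param_mappings": {"max_completion_tokens": "max_tokens"},
#       "constraints": {"temperature_max": 1.0}
#   }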


def create_config_class(provider: SimpleProviderConfig):
    """Generate a config class dynamically from the provider's JSON configuration."""
    # Choose the base class according to the provider's JSON setting.
    base_class: type = (
        OpenAIGPTConfig if provider.base_class == "openai_gpt" else OpenAILikeChatConfig
    )

    class JSONProviderConfig(base_class):  # type: ignore[valid-type,misc]
        @overload
        def _transform_messages(
            self, messages: List[AllMessageValues], model: str, is_async: Literal[True]
        ) -> Coroutine[Any, Any, List[AllMessageValues]]:
            ...

        @overload
        def _transform_messages(
            self,
            messages: List[AllMessageValues],
            model: str,
            is_async: Literal[False] = False,
        ) -> List[AllMessageValues]:
            ...

        def _transform_messages(
            self, messages: List[AllMessageValues], model: str, is_async: bool = False
        ) -> Union[List[AllMessageValues], Coroutine[Any, Any, List[AllMessageValues]]]:
            """Transform messages based on special_handling config"""
            # Handle content list to string conversion if configured
            if provider.special_handling.get("convert_content_list_to_string"):
                messages = handle_messages_with_content_list_to_str_conversion(messages)
            if is_async:
                return super()._transform_messages(
                    messages=messages, model=model, is_async=True
                )
            else:
                return super()._transform_messages(
                    messages=messages, model=model, is_async=False
                )
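
        # Example (illustrative): with convert_content_list_to_string enabled,
        # a message like
        #     {"role": "user", "content": [{"type": "text", "text": "hi"}]}
        # is flattened to
        #     {"role": "user", "content": "hi"}
        # before the parent class's _transform_messages runs.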

        def _get_openai_compatible_provider_info(
            self, api_base: Optional[str], api_key: Optional[str]
        ) -> Tuple[Optional[str], Optional[str]]:
            """Get the API base and key from the JSON config."""
            # Resolve the base URL: explicit argument, then env var, then JSON default.
            resolved_base = api_base
            if not resolved_base and provider.api_base_env:
                resolved_base = get_secret_str(provider.api_base_env)
            if not resolved_base:
                resolved_base = provider.base_url
            # Resolve the API key: explicit argument, then env var.
            resolved_key = api_key or get_secret_str(provider.api_key_env)
            return resolved_base, resolved_key

        def get_complete_url(
            self,
            api_base: Optional[str],
            api_key: Optional[str],
            model: str,
            optional_params: dict,
            litellm_params: dict,
            stream: Optional[bool] = None,
        ) -> str:
            """Build complete URL for the API endpoint"""
            if not api_base:
                api_base = provider.base_url
            if api_base is None:
                raise ValueError(f"api_base is required for provider {provider.slug}")
            if not api_base.endswith("/chat/completions"):
                api_base = f"{api_base}/chat/completions"
            return api_base
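
        # Example (hypothetical base URL): with provider.base_url ==
        # "https://api.example.com/v1" and no api_base override, this returns
        # "https://api.example.com/v1/chat/completions".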

        def get_supported_openai_params(self, model: str) -> list:
            """Get supported OpenAI params, excluding tool-related params for models
            that don't support function calling."""
            from litellm.utils import supports_function_calling

            supported_params = super().get_supported_openai_params(model=model)
            _supports_fc = supports_function_calling(
                model=model, custom_llm_provider=provider.slug
            )
            if not _supports_fc:
                tool_params = [
                    "tools",
                    "tool_choice",
                    "function_call",
                    "functions",
                    "parallel_tool_calls",
                ]
                for param in tool_params:
                    if param in supported_params:
                        supported_params.remove(param)
                verbose_logger.debug(
                    f"Model {model} on provider {provider.slug} does not support "
                    f"function calling — removed tool-related params from supported params."
                )
            return supported_params
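
        # Example (illustrative): for a model without function-calling support,
        # a parent list such as ["temperature", "stream", "tools", "tool_choice"]
        # comes back as ["temperature", "stream"].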

        def map_openai_params(
            self,
            non_default_params: dict,
            optional_params: dict,
            model: str,
            drop_params: bool,
        ) -> dict:
            """Apply parameter mappings and constraints"""
            supported_params = self.get_supported_openai_params(model)
            # Copy over supported params, applying any provider-specific renames.
            for param, value in non_default_params.items():
                # Check parameter mappings first
                if param in provider.param_mappings:
                    optional_params[provider.param_mappings[param]] = value
                elif param in supported_params:
                    optional_params[param] = value
            # Apply temperature constraints if present
            if "temperature" in optional_params:
                temp = optional_params["temperature"]
                constraints = provider.constraints
                # Clamp to max
                if "temperature_max" in constraints:
                    temp = min(temp, constraints["temperature_max"])
                # Clamp to min
                if "temperature_min" in constraints:
                    temp = max(temp, constraints["temperature_min"])
                # Special case: some providers require a minimum temperature when n > 1
                if "temperature_min_with_n_gt_1" in constraints:
                    n = optional_params.get("n", 1)
                    if n > 1 and temp < constraints["temperature_min_with_n_gt_1"]:
                        temp = constraints["temperature_min_with_n_gt_1"]
                optional_params["temperature"] = temp
            return optional_params
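
        # Worked example (hypothetical constraints): with
        # constraints == {"temperature_max": 1.0, "temperature_min_with_n_gt_1": 0.3}
        # and optional_params == {"temperature": 1.5, "n": 2}, the max clamp yields
        # temperature = min(1.5, 1.0) = 1.0, and the n > 1 floor leaves it unchanged.
        # With {"temperature": 0.0, "n": 2}, the n > 1 floor raises it to 0.3.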

        @property
        def custom_llm_provider(self) -> Optional[str]:
            return provider.slug

    return JSONProviderConfig
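

# Cache of generated Responses API config classes, keyed by provider slug.
# (Note: the chat config classes from create_config_class above are not cached
# in this module; only the Responses variant is.)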
_responses_config_cache: dict = {}


def create_responses_config_class(provider: SimpleProviderConfig):
    """Generate a Responses API config class dynamically from JSON configuration.

    Parallel to create_config_class() but for /v1/responses endpoints.
    Classes are cached per provider slug to avoid regeneration on every request.
    """
    if provider.slug in _responses_config_cache:
        return _responses_config_cache[provider.slug]

    from litellm.llms.openai_like.responses.transformation import (
        OpenAILikeResponsesConfig,
    )
    from litellm.types.router import GenericLiteLLMParams

    class JSONProviderResponsesConfig(OpenAILikeResponsesConfig):
        @property
        def custom_llm_provider(self):  # type: ignore[override]
            return provider.slug

        def validate_environment(
            self,
            headers: dict,
            model: str,
            litellm_params: Optional[GenericLiteLLMParams],
        ) -> dict:
            litellm_params = litellm_params or GenericLiteLLMParams()
            # Prefer an explicitly passed key; fall back to the provider's env var.
            api_key = litellm_params.api_key or get_secret_str(provider.api_key_env)
            if api_key:
                headers["Authorization"] = f"Bearer {api_key}"
            return headers
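
        # Example (hypothetical env var name): with provider.api_key_env set to
        # "EXAMPLE_API_KEY" and that variable exported, the returned headers
        # include {"Authorization": "Bearer <value of EXAMPLE_API_KEY>"}.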

        def get_complete_url(
            self,
            api_base: Optional[str],
            litellm_params: dict,
        ) -> str:
            if not api_base:
                if provider.api_base_env:
                    api_base = get_secret_str(provider.api_base_env)
                if not api_base:
                    api_base = provider.base_url
            if api_base is None:
                raise ValueError(f"api_base is required for provider {provider.slug}")
            api_base = api_base.rstrip("/")
            return f"{api_base}/responses"

    _responses_config_cache[provider.slug] = JSONProviderResponsesConfig
    return JSONProviderResponsesConfig
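

# Minimal usage sketch (illustrative; the call sites and "example-model" are
# assumptions, not part of this module). A SimpleProviderConfig loaded by
# json_loader can be turned into config classes like so:
#
#   ChatConfig = create_config_class(provider)
#   url = ChatConfig().get_complete_url(
#       api_base=None,
#       api_key=None,
#       model="example-model",
#       optional_params={},
#       litellm_params={},
#   )
#   # -> f"{provider.base_url}/chat/completions"
#
#   ResponsesConfig = create_responses_config_class(provider)  # cached per slug
#   assert create_responses_config_class(provider) is ResponsesConfig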