chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,270 @@
|
||||
"""Support for OpenAI gpt-5 model family."""
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import litellm
|
||||
from litellm.utils import _supports_factory
|
||||
|
||||
from .gpt_transformation import OpenAIGPTConfig
|
||||
|
||||
|
||||
def _normalize_reasoning_effort_for_chat_completion(
|
||||
value: Union[str, dict, None],
|
||||
) -> Optional[str]:
|
||||
"""Convert reasoning_effort to the string format expected by OpenAI chat completion API.
|
||||
|
||||
The chat completion API expects a simple string: 'none', 'low', 'medium', 'high', or 'xhigh'.
|
||||
Config/deployments may pass the Responses API format: {'effort': 'high', 'summary': 'detailed'}.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
if isinstance(value, dict) and "effort" in value:
|
||||
return value["effort"]
|
||||
return None
|
||||
|
||||
|
||||
def _get_effort_level(value: Union[str, dict, None]) -> Optional[str]:
|
||||
"""Extract the effective effort level from reasoning_effort (string or dict).
|
||||
|
||||
Use this for guards that compare effort level (e.g. xhigh validation, "none" checks).
|
||||
Ensures dict inputs like {"effort": "none", "summary": "detailed"} are correctly
|
||||
treated as effort="none" for validation purposes.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
if isinstance(value, dict) and "effort" in value:
|
||||
return value["effort"]
|
||||
return None
|
||||
|
||||
|
||||
class OpenAIGPT5Config(OpenAIGPTConfig):
|
||||
"""Configuration for gpt-5 models including GPT-5-Codex variants.
|
||||
|
||||
Handles OpenAI API quirks for the gpt-5 series like:
|
||||
|
||||
- Mapping ``max_tokens`` -> ``max_completion_tokens``.
|
||||
- Dropping unsupported ``temperature`` values when requested.
|
||||
- Support for GPT-5-Codex models optimized for code generation.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def is_model_gpt_5_model(cls, model: str) -> bool:
|
||||
# gpt-5-chat* behaves like a regular chat model (supports temperature, etc.)
|
||||
# Don't route it through GPT-5 reasoning-specific parameter restrictions.
|
||||
return "gpt-5" in model and "gpt-5-chat" not in model
|
||||
|
||||
@classmethod
|
||||
def is_model_gpt_5_search_model(cls, model: str) -> bool:
|
||||
"""Check if the model is a GPT-5 search variant (e.g. gpt-5-search-api).
|
||||
|
||||
Search-only models have a severely restricted parameter set compared to
|
||||
regular GPT-5 models. They are identified by name convention (contain
|
||||
both ``gpt-5`` and ``search``). Note: ``supports_web_search`` in model
|
||||
info is a *different* concept — it indicates a model can *use* web
|
||||
search as a tool, which many non-search-only models also support.
|
||||
"""
|
||||
return "gpt-5" in model and "search" in model
|
||||
|
||||
@classmethod
|
||||
def is_model_gpt_5_codex_model(cls, model: str) -> bool:
|
||||
"""Check if the model is specifically a GPT-5 Codex variant."""
|
||||
return "gpt-5-codex" in model
|
||||
|
||||
@classmethod
|
||||
def is_model_gpt_5_2_model(cls, model: str) -> bool:
|
||||
"""Check if the model is a gpt-5.2 variant (including pro)."""
|
||||
model_name = model.split("/")[-1]
|
||||
return model_name.startswith("gpt-5.2") or model_name.startswith("gpt-5.4")
|
||||
|
||||
@classmethod
|
||||
def is_model_gpt_5_4_model(cls, model: str) -> bool:
|
||||
"""Check if the model is a gpt-5.4 variant (including pro)."""
|
||||
model_name = model.split("/")[-1]
|
||||
return model_name.startswith("gpt-5.4")
|
||||
|
||||
@classmethod
|
||||
def is_model_gpt_5_4_plus_model(cls, model: str) -> bool:
|
||||
"""Check if the model is gpt-5.4 or newer (5.4, 5.5, 5.6, etc., including pro)."""
|
||||
model_name = model.split("/")[-1]
|
||||
if not model_name.startswith("gpt-5."):
|
||||
return False
|
||||
try:
|
||||
version_str = model_name.replace("gpt-5.", "").split("-")[0]
|
||||
major = version_str.split(".")[0]
|
||||
return int(major) >= 4
|
||||
except (ValueError, IndexError):
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def _supports_reasoning_effort_level(cls, model: str, level: str) -> bool:
|
||||
"""Check if the model supports a specific reasoning_effort level.
|
||||
|
||||
Looks up ``supports_{level}_reasoning_effort`` in the model map via
|
||||
the shared ``_supports_factory`` helper.
|
||||
Returns False for unknown models (safe fallback).
|
||||
"""
|
||||
return _supports_factory(
|
||||
model=model,
|
||||
custom_llm_provider=None,
|
||||
key=f"supports_{level}_reasoning_effort",
|
||||
)
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> list:
|
||||
if self.is_model_gpt_5_search_model(model):
|
||||
return [
|
||||
"max_tokens",
|
||||
"max_completion_tokens",
|
||||
"stream",
|
||||
"stream_options",
|
||||
"web_search_options",
|
||||
"service_tier",
|
||||
"safety_identifier",
|
||||
"response_format",
|
||||
"user",
|
||||
"store",
|
||||
"verbosity",
|
||||
"max_retries",
|
||||
"extra_headers",
|
||||
]
|
||||
|
||||
from litellm.utils import supports_tool_choice
|
||||
|
||||
base_gpt_series_params = super().get_supported_openai_params(model=model)
|
||||
gpt_5_only_params = ["reasoning_effort", "verbosity"]
|
||||
base_gpt_series_params.extend(gpt_5_only_params)
|
||||
if not supports_tool_choice(model=model):
|
||||
base_gpt_series_params.remove("tool_choice")
|
||||
|
||||
non_supported_params = [
|
||||
"presence_penalty",
|
||||
"frequency_penalty",
|
||||
"stop",
|
||||
"logit_bias",
|
||||
"modalities",
|
||||
"prediction",
|
||||
"audio",
|
||||
"web_search_options",
|
||||
]
|
||||
|
||||
# gpt-5.1/5.2 support logprobs, top_p, top_logprobs when reasoning_effort="none"
|
||||
if not self._supports_reasoning_effort_level(model, "none"):
|
||||
non_supported_params.extend(["logprobs", "top_p", "top_logprobs"])
|
||||
|
||||
return [
|
||||
param
|
||||
for param in base_gpt_series_params
|
||||
if param not in non_supported_params
|
||||
]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
if self.is_model_gpt_5_search_model(model):
|
||||
if "max_tokens" in non_default_params:
|
||||
optional_params["max_completion_tokens"] = non_default_params.pop(
|
||||
"max_tokens"
|
||||
)
|
||||
return super()._map_openai_params(
|
||||
non_default_params=non_default_params,
|
||||
optional_params=optional_params,
|
||||
model=model,
|
||||
drop_params=drop_params,
|
||||
)
|
||||
|
||||
# Get raw reasoning_effort and effective effort level for all guards.
|
||||
# Use effective_effort (extracted string) for xhigh validation, "none" checks, and
|
||||
# tool/sampling guards — dict inputs like {"effort": "none", "summary": "detailed"}
|
||||
# must be treated as effort="none" to avoid incorrect tool-drop or sampling errors.
|
||||
raw_reasoning_effort = non_default_params.get(
|
||||
"reasoning_effort"
|
||||
) or optional_params.get("reasoning_effort")
|
||||
effective_effort = _get_effort_level(raw_reasoning_effort)
|
||||
|
||||
# Normalize dict reasoning_effort to string for Chat Completions API.
|
||||
# Example: {"effort": "high", "summary": "detailed"} -> "high"
|
||||
if isinstance(raw_reasoning_effort, dict) and "effort" in raw_reasoning_effort:
|
||||
normalized = _normalize_reasoning_effort_for_chat_completion(
|
||||
raw_reasoning_effort
|
||||
)
|
||||
if normalized is not None:
|
||||
if "reasoning_effort" in non_default_params:
|
||||
non_default_params["reasoning_effort"] = normalized
|
||||
if "reasoning_effort" in optional_params:
|
||||
optional_params["reasoning_effort"] = normalized
|
||||
|
||||
if effective_effort is not None and effective_effort == "xhigh":
|
||||
if not self._supports_reasoning_effort_level(model, "xhigh"):
|
||||
if litellm.drop_params or drop_params:
|
||||
non_default_params.pop("reasoning_effort", None)
|
||||
else:
|
||||
raise litellm.utils.UnsupportedParamsError(
|
||||
message=(
|
||||
"reasoning_effort='xhigh' is only supported for gpt-5.1-codex-max, gpt-5.2, and gpt-5.4+ models."
|
||||
),
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
################################################################
|
||||
# max_tokens is not supported for gpt-5 models on OpenAI API
|
||||
# Relevant issue: https://github.com/BerriAI/litellm/issues/13381
|
||||
################################################################
|
||||
if "max_tokens" in non_default_params:
|
||||
optional_params["max_completion_tokens"] = non_default_params.pop(
|
||||
"max_tokens"
|
||||
)
|
||||
|
||||
# gpt-5.1/5.2 support logprobs, top_p, top_logprobs only when reasoning_effort="none"
|
||||
supports_none = self._supports_reasoning_effort_level(model, "none")
|
||||
if supports_none:
|
||||
sampling_params = ["logprobs", "top_logprobs", "top_p"]
|
||||
has_sampling = any(p in non_default_params for p in sampling_params)
|
||||
if has_sampling and effective_effort not in (None, "none"):
|
||||
if litellm.drop_params or drop_params:
|
||||
for p in sampling_params:
|
||||
non_default_params.pop(p, None)
|
||||
else:
|
||||
raise litellm.utils.UnsupportedParamsError(
|
||||
message=(
|
||||
"gpt-5.1/5.2/5.4 only support logprobs, top_p, top_logprobs when "
|
||||
"reasoning_effort='none'. Current reasoning_effort='{}'. "
|
||||
"To drop unsupported params set `litellm.drop_params = True`"
|
||||
).format(effective_effort),
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
if "temperature" in non_default_params:
|
||||
temperature_value: Optional[float] = non_default_params.pop("temperature")
|
||||
if temperature_value is not None:
|
||||
# models supporting reasoning_effort="none" also support flexible temperature
|
||||
if supports_none and (
|
||||
effective_effort == "none" or effective_effort is None
|
||||
):
|
||||
optional_params["temperature"] = temperature_value
|
||||
elif temperature_value == 1:
|
||||
optional_params["temperature"] = temperature_value
|
||||
elif litellm.drop_params or drop_params:
|
||||
pass
|
||||
else:
|
||||
raise litellm.utils.UnsupportedParamsError(
|
||||
message=(
|
||||
"gpt-5 models (including gpt-5-codex) don't support temperature={}. "
|
||||
"Only temperature=1 is supported. "
|
||||
"For gpt-5.1, temperature is supported when reasoning_effort='none' (or not specified, as it defaults to 'none'). "
|
||||
"To drop unsupported params set `litellm.drop_params = True`"
|
||||
).format(temperature_value),
|
||||
status_code=400,
|
||||
)
|
||||
return super()._map_openai_params(
|
||||
non_default_params=non_default_params,
|
||||
optional_params=optional_params,
|
||||
model=model,
|
||||
drop_params=drop_params,
|
||||
)
|
||||
@@ -0,0 +1,48 @@
|
||||
"""
|
||||
Support for GPT-4o audio Family
|
||||
|
||||
OpenAI Doc: https://platform.openai.com/docs/guides/audio/quickstart?audio-generation-quickstart-example=audio-in&lang=python
|
||||
"""
|
||||
|
||||
import litellm
|
||||
|
||||
from .gpt_transformation import OpenAIGPTConfig
|
||||
|
||||
|
||||
class OpenAIGPTAudioConfig(OpenAIGPTConfig):
|
||||
"""
|
||||
Reference: https://platform.openai.com/docs/guides/audio
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return super().get_config()
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> list:
|
||||
"""
|
||||
Get the supported OpenAI params for the `gpt-audio` models
|
||||
|
||||
"""
|
||||
|
||||
all_openai_params = super().get_supported_openai_params(model=model)
|
||||
audio_specific_params = ["audio"]
|
||||
return all_openai_params + audio_specific_params
|
||||
|
||||
def is_model_gpt_audio_model(self, model: str) -> bool:
|
||||
if model in litellm.open_ai_chat_completion_models and "audio" in model:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
return super()._map_openai_params(
|
||||
non_default_params=non_default_params,
|
||||
optional_params=optional_params,
|
||||
model=model,
|
||||
drop_params=drop_params,
|
||||
)
|
||||
@@ -0,0 +1,819 @@
|
||||
"""
|
||||
Support for gpt model family
|
||||
"""
|
||||
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Coroutine,
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
||||
from litellm.litellm_core_utils.llm_response_utils.convert_dict_to_response import (
|
||||
_extract_reasoning_content,
|
||||
_handle_invalid_parallel_tool_calls,
|
||||
_should_convert_tool_call_to_json_mode,
|
||||
)
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import get_tool_call_names
|
||||
from litellm.litellm_core_utils.prompt_templates.image_handling import (
|
||||
async_convert_url_to_base64,
|
||||
convert_url_to_base64,
|
||||
)
|
||||
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
|
||||
from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
ChatCompletionFileObject,
|
||||
ChatCompletionFileObjectFile,
|
||||
ChatCompletionImageObject,
|
||||
ChatCompletionImageUrlObject,
|
||||
OpenAIChatCompletionChoices,
|
||||
OpenAIMessageContentListBlock,
|
||||
)
|
||||
from litellm.types.utils import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Choices,
|
||||
Function,
|
||||
Message,
|
||||
ModelResponse,
|
||||
ModelResponseStream,
|
||||
)
|
||||
from litellm.utils import convert_to_model_response_object
|
||||
|
||||
from ..common_utils import OpenAIError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
from litellm.llms.base_llm.base_utils import BaseTokenCounter
|
||||
from litellm.types.llms.openai import ChatCompletionToolParam
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
|
||||
"""
|
||||
Reference: https://platform.openai.com/docs/api-reference/chat/create
|
||||
|
||||
The class `OpenAIConfig` provides configuration for the OpenAI's Chat API interface. Below are the parameters:
|
||||
|
||||
- `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby minimizing repetition.
|
||||
|
||||
- `function_call` (string or object): This optional parameter controls how the model calls functions.
|
||||
|
||||
- `functions` (array): An optional parameter. It is a list of functions for which the model may generate JSON inputs.
|
||||
|
||||
- `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.
|
||||
|
||||
- `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion.
|
||||
|
||||
- `n` (integer or null): This optional parameter helps to set how many chat completion choices to generate for each input message.
|
||||
|
||||
- `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on if they appear in the text so far, hence increasing the model's likelihood to talk about new topics.
|
||||
|
||||
- `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.
|
||||
|
||||
- `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.
|
||||
|
||||
- `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
|
||||
"""
|
||||
|
||||
# Add a class variable to track if this is the base class
|
||||
_is_base_class = True
|
||||
|
||||
frequency_penalty: Optional[int] = None
|
||||
function_call: Optional[Union[str, dict]] = None
|
||||
functions: Optional[list] = None
|
||||
logit_bias: Optional[dict] = None
|
||||
max_tokens: Optional[int] = None
|
||||
n: Optional[int] = None
|
||||
presence_penalty: Optional[int] = None
|
||||
stop: Optional[Union[str, list]] = None
|
||||
temperature: Optional[int] = None
|
||||
top_p: Optional[int] = None
|
||||
response_format: Optional[dict] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
frequency_penalty: Optional[int] = None,
|
||||
function_call: Optional[Union[str, dict]] = None,
|
||||
functions: Optional[list] = None,
|
||||
logit_bias: Optional[dict] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
n: Optional[int] = None,
|
||||
presence_penalty: Optional[int] = None,
|
||||
stop: Optional[Union[str, list]] = None,
|
||||
temperature: Optional[int] = None,
|
||||
top_p: Optional[int] = None,
|
||||
response_format: Optional[dict] = None,
|
||||
) -> None:
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
self.__class__._is_base_class = False
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return super().get_config()
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> list:
|
||||
base_params = [
|
||||
"frequency_penalty",
|
||||
"logit_bias",
|
||||
"logprobs",
|
||||
"top_logprobs",
|
||||
"max_tokens",
|
||||
"max_completion_tokens",
|
||||
"modalities",
|
||||
"prediction",
|
||||
"n",
|
||||
"presence_penalty",
|
||||
"seed",
|
||||
"stop",
|
||||
"stream",
|
||||
"stream_options",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"function_call",
|
||||
"functions",
|
||||
"max_retries",
|
||||
"extra_headers",
|
||||
"parallel_tool_calls",
|
||||
"audio",
|
||||
"web_search_options",
|
||||
"service_tier",
|
||||
"safety_identifier",
|
||||
"prompt_cache_key",
|
||||
"prompt_cache_retention",
|
||||
"store",
|
||||
] # works across all models
|
||||
|
||||
model_specific_params = []
|
||||
if (
|
||||
model != "gpt-3.5-turbo-16k" and model != "gpt-4"
|
||||
): # gpt-4 does not support 'response_format'
|
||||
model_specific_params.append("response_format")
|
||||
|
||||
# Normalize model name for responses API (e.g., "responses/gpt-4.1" -> "gpt-4.1")
|
||||
model_for_check = (
|
||||
model.split("responses/", 1)[1] if "responses/" in model else model
|
||||
)
|
||||
if (
|
||||
model_for_check in litellm.open_ai_chat_completion_models
|
||||
) or model_for_check in litellm.open_ai_text_completion_models:
|
||||
model_specific_params.append(
|
||||
"user"
|
||||
) # user is not a param supported by all openai-compatible endpoints - e.g. azure ai
|
||||
return base_params + model_specific_params
|
||||
|
||||
def _map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
"""
|
||||
If any supported_openai_params are in non_default_params, add them to optional_params, so they are use in API call
|
||||
|
||||
Args:
|
||||
non_default_params (dict): Non-default parameters to filter.
|
||||
optional_params (dict): Optional parameters to update.
|
||||
model (str): Model name for parameter support check.
|
||||
|
||||
Returns:
|
||||
dict: Updated optional_params with supported non-default parameters.
|
||||
"""
|
||||
supported_openai_params = self.get_supported_openai_params(model)
|
||||
for param, value in non_default_params.items():
|
||||
if param in supported_openai_params:
|
||||
optional_params[param] = value
|
||||
return optional_params
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
return self._map_openai_params(
|
||||
non_default_params=non_default_params,
|
||||
optional_params=optional_params,
|
||||
model=model,
|
||||
drop_params=drop_params,
|
||||
)
|
||||
|
||||
def contains_pdf_url(self, content_item: ChatCompletionFileObjectFile) -> bool:
|
||||
potential_pdf_url_starts = ["https://", "http://", "www."]
|
||||
file_id = content_item.get("file_id")
|
||||
if file_id and any(
|
||||
file_id.startswith(start) for start in potential_pdf_url_starts
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _handle_pdf_url(
|
||||
self, content_item: ChatCompletionFileObjectFile
|
||||
) -> ChatCompletionFileObjectFile:
|
||||
content_copy = content_item.copy()
|
||||
file_id = content_copy.get("file_id")
|
||||
if file_id is not None:
|
||||
base64_data = convert_url_to_base64(file_id)
|
||||
content_copy["file_data"] = base64_data
|
||||
content_copy["filename"] = "my_file.pdf"
|
||||
content_copy.pop("file_id")
|
||||
return content_copy
|
||||
|
||||
async def _async_handle_pdf_url(
|
||||
self, content_item: ChatCompletionFileObjectFile
|
||||
) -> ChatCompletionFileObjectFile:
|
||||
file_id = content_item.get("file_id")
|
||||
if file_id is not None: # check for file id being url done in _handle_pdf_url
|
||||
base64_data = await async_convert_url_to_base64(file_id)
|
||||
content_item["file_data"] = base64_data
|
||||
content_item["filename"] = "my_file.pdf"
|
||||
content_item.pop("file_id")
|
||||
return content_item
|
||||
|
||||
def _common_file_data_check(
|
||||
self, content_item: ChatCompletionFileObjectFile
|
||||
) -> ChatCompletionFileObjectFile:
|
||||
file_data = content_item.get("file_data")
|
||||
filename = content_item.get("filename")
|
||||
if file_data is not None and filename is None:
|
||||
content_item["filename"] = "my_file.pdf"
|
||||
return content_item
|
||||
|
||||
def _apply_common_transform_content_item(
|
||||
self,
|
||||
content_item: OpenAIMessageContentListBlock,
|
||||
) -> OpenAIMessageContentListBlock:
|
||||
litellm_specific_params = {"format"}
|
||||
if content_item.get("type") == "image_url":
|
||||
content_item = cast(ChatCompletionImageObject, content_item)
|
||||
if isinstance(content_item["image_url"], str):
|
||||
content_item["image_url"] = {
|
||||
"url": content_item["image_url"],
|
||||
}
|
||||
elif isinstance(content_item["image_url"], dict):
|
||||
new_image_url_obj = ChatCompletionImageUrlObject(
|
||||
**{ # type: ignore
|
||||
k: v
|
||||
for k, v in content_item["image_url"].items()
|
||||
if k not in litellm_specific_params
|
||||
}
|
||||
)
|
||||
content_item["image_url"] = new_image_url_obj
|
||||
elif content_item.get("type") == "file":
|
||||
content_item = cast(ChatCompletionFileObject, content_item)
|
||||
file_obj = content_item["file"]
|
||||
new_file_obj = ChatCompletionFileObjectFile(
|
||||
**{ # type: ignore
|
||||
k: v
|
||||
for k, v in file_obj.items()
|
||||
if k not in litellm_specific_params
|
||||
}
|
||||
)
|
||||
content_item["file"] = new_file_obj
|
||||
|
||||
return content_item
|
||||
|
||||
def _transform_content_item(
|
||||
self,
|
||||
content_item: OpenAIMessageContentListBlock,
|
||||
) -> OpenAIMessageContentListBlock:
|
||||
content_item = self._apply_common_transform_content_item(content_item)
|
||||
content_item_type = content_item.get("type")
|
||||
potential_file_obj = content_item.get("file")
|
||||
if content_item_type == "file" and potential_file_obj:
|
||||
file_obj = cast(ChatCompletionFileObjectFile, potential_file_obj)
|
||||
content_item_typed = cast(ChatCompletionFileObject, content_item)
|
||||
if self.contains_pdf_url(file_obj):
|
||||
file_obj = self._handle_pdf_url(file_obj)
|
||||
file_obj = self._common_file_data_check(file_obj)
|
||||
content_item_typed["file"] = file_obj
|
||||
content_item = content_item_typed
|
||||
return content_item
|
||||
|
||||
async def _async_transform_content_item(
|
||||
self, content_item: OpenAIMessageContentListBlock, is_async: bool = False
|
||||
) -> OpenAIMessageContentListBlock:
|
||||
content_item = self._apply_common_transform_content_item(content_item)
|
||||
content_item_type = content_item.get("type")
|
||||
potential_file_obj = content_item.get("file")
|
||||
if content_item_type == "file" and potential_file_obj:
|
||||
file_obj = cast(ChatCompletionFileObjectFile, potential_file_obj)
|
||||
content_item_typed = cast(ChatCompletionFileObject, content_item)
|
||||
if self.contains_pdf_url(file_obj):
|
||||
file_obj = await self._async_handle_pdf_url(file_obj)
|
||||
file_obj = self._common_file_data_check(file_obj)
|
||||
content_item_typed["file"] = file_obj
|
||||
content_item = content_item_typed
|
||||
return content_item
|
||||
|
||||
# fmt: off
|
||||
|
||||
@overload
|
||||
def _transform_messages(
|
||||
self, messages: List[AllMessageValues], model: str, is_async: Literal[True]
|
||||
) -> Coroutine[Any, Any, List[AllMessageValues]]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def _transform_messages(
|
||||
self,
|
||||
messages: List[AllMessageValues],
|
||||
model: str,
|
||||
is_async: Literal[False] = False,
|
||||
) -> List[AllMessageValues]:
|
||||
...
|
||||
|
||||
# fmt: on
|
||||
|
||||
def _transform_messages(
|
||||
self, messages: List[AllMessageValues], model: str, is_async: bool = False
|
||||
) -> Union[List[AllMessageValues], Coroutine[Any, Any, List[AllMessageValues]]]:
|
||||
"""OpenAI no longer supports image_url as a string, so we need to convert it to a dict"""
|
||||
|
||||
async def _async_transform():
|
||||
for message in messages:
|
||||
message_content = message.get("content")
|
||||
message_role = message.get("role")
|
||||
|
||||
if (
|
||||
message_role == "user"
|
||||
and message_content
|
||||
and isinstance(message_content, list)
|
||||
):
|
||||
message_content_types = cast(
|
||||
List[OpenAIMessageContentListBlock], message_content
|
||||
)
|
||||
for i, content_item in enumerate(message_content_types):
|
||||
message_content_types[
|
||||
i
|
||||
] = await self._async_transform_content_item(
|
||||
cast(OpenAIMessageContentListBlock, content_item),
|
||||
)
|
||||
return messages
|
||||
|
||||
if is_async:
|
||||
return _async_transform()
|
||||
else:
|
||||
for message in messages:
|
||||
message_content = message.get("content")
|
||||
message_role = message.get("role")
|
||||
if (
|
||||
message_role == "user"
|
||||
and message_content
|
||||
and isinstance(message_content, list)
|
||||
):
|
||||
message_content_types = cast(
|
||||
List[OpenAIMessageContentListBlock], message_content
|
||||
)
|
||||
for i, content_item in enumerate(message_content):
|
||||
message_content_types[i] = self._transform_content_item(
|
||||
cast(OpenAIMessageContentListBlock, content_item)
|
||||
)
|
||||
return messages
|
||||
|
||||
def remove_cache_control_flag_from_messages_and_tools(
|
||||
self,
|
||||
model: str, # allows overrides to selectively run this
|
||||
messages: List[AllMessageValues],
|
||||
tools: Optional[List["ChatCompletionToolParam"]] = None,
|
||||
) -> Tuple[List[AllMessageValues], Optional[List["ChatCompletionToolParam"]]]:
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
filter_value_from_dict,
|
||||
)
|
||||
from litellm.types.llms.openai import ChatCompletionToolParam
|
||||
|
||||
for i, message in enumerate(messages):
|
||||
messages[i] = cast(
|
||||
AllMessageValues, filter_value_from_dict(message, "cache_control") # type: ignore
|
||||
)
|
||||
if tools is not None:
|
||||
for i, tool in enumerate(tools):
|
||||
tools[i] = cast(
|
||||
ChatCompletionToolParam,
|
||||
filter_value_from_dict(tool, "cache_control"), # type: ignore
|
||||
)
|
||||
return messages, tools
|
||||
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
"""
|
||||
Transform the overall request to be sent to the API.
|
||||
|
||||
Returns:
|
||||
dict: The transformed request. Sent as the body of the API call.
|
||||
"""
|
||||
messages = self._transform_messages(messages=messages, model=model)
|
||||
messages, tools = self.remove_cache_control_flag_from_messages_and_tools(
|
||||
model=model, messages=messages, tools=optional_params.get("tools", [])
|
||||
)
|
||||
if tools is not None and len(tools) > 0:
|
||||
optional_params["tools"] = tools
|
||||
|
||||
optional_params.pop("max_retries", None)
|
||||
|
||||
return {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
**optional_params,
|
||||
}
|
||||
|
||||
async def async_transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
transformed_messages = await self._transform_messages(
|
||||
messages=messages, model=model, is_async=True
|
||||
)
|
||||
(
|
||||
transformed_messages,
|
||||
tools,
|
||||
) = self.remove_cache_control_flag_from_messages_and_tools(
|
||||
model=model,
|
||||
messages=transformed_messages,
|
||||
tools=optional_params.get("tools", []),
|
||||
)
|
||||
if tools is not None and len(tools) > 0:
|
||||
optional_params["tools"] = tools
|
||||
if self.__class__._is_base_class:
|
||||
return {
|
||||
"model": model,
|
||||
"messages": transformed_messages,
|
||||
**optional_params,
|
||||
}
|
||||
else:
|
||||
## allow for any object specific behaviour to be handled
|
||||
return self.transform_request(
|
||||
model, messages, optional_params, litellm_params, headers
|
||||
)
|
||||
|
||||
def _passed_in_tools(self, optional_params: dict) -> bool:
|
||||
return optional_params.get("tools", None) is not None
|
||||
|
||||
def _check_and_fix_if_content_is_tool_call(
|
||||
self, content: str, optional_params: dict
|
||||
) -> Optional[ChatCompletionMessageToolCall]:
|
||||
"""
|
||||
Check if the content is a tool call
|
||||
"""
|
||||
import json
|
||||
|
||||
if not self._passed_in_tools(optional_params):
|
||||
return None
|
||||
tool_call_names = get_tool_call_names(optional_params.get("tools", []))
|
||||
try:
|
||||
json_content = json.loads(content)
|
||||
if (
|
||||
json_content.get("type") == "function"
|
||||
and json_content.get("name") in tool_call_names
|
||||
):
|
||||
return ChatCompletionMessageToolCall(
|
||||
function=Function(
|
||||
name=json_content.get("name"),
|
||||
arguments=json_content.get("arguments"),
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
def _get_finish_reason(self, message: Message, received_finish_reason: str) -> str:
|
||||
if message.tool_calls is not None:
|
||||
return "tool_calls"
|
||||
else:
|
||||
return received_finish_reason
|
||||
|
||||
def _transform_choices(
|
||||
self,
|
||||
choices: List[OpenAIChatCompletionChoices],
|
||||
json_mode: Optional[bool] = None,
|
||||
optional_params: Optional[dict] = None,
|
||||
) -> List[Choices]:
|
||||
transformed_choices = []
|
||||
|
||||
for choice in choices:
|
||||
## HANDLE JSON MODE - anthropic returns single function call]
|
||||
tool_calls = choice["message"].get("tool_calls", None)
|
||||
new_tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
|
||||
message_content = choice["message"].get("content", None)
|
||||
if tool_calls is not None:
|
||||
_openai_tool_calls = []
|
||||
for _tc in tool_calls:
|
||||
_openai_tc = ChatCompletionMessageToolCall(**_tc) # type: ignore
|
||||
_openai_tool_calls.append(_openai_tc)
|
||||
fixed_tool_calls = _handle_invalid_parallel_tool_calls(
|
||||
_openai_tool_calls
|
||||
)
|
||||
|
||||
if fixed_tool_calls is not None:
|
||||
new_tool_calls = fixed_tool_calls
|
||||
elif (
|
||||
optional_params is not None
|
||||
and message_content
|
||||
and isinstance(message_content, str)
|
||||
):
|
||||
new_tool_call = self._check_and_fix_if_content_is_tool_call(
|
||||
message_content, optional_params
|
||||
)
|
||||
if new_tool_call is not None:
|
||||
choice["message"]["content"] = None # remove the content
|
||||
new_tool_calls = [new_tool_call]
|
||||
|
||||
translated_message: Optional[Message] = None
|
||||
finish_reason: Optional[str] = None
|
||||
if new_tool_calls and _should_convert_tool_call_to_json_mode(
|
||||
tool_calls=new_tool_calls,
|
||||
convert_tool_call_to_json_mode=json_mode,
|
||||
):
|
||||
# to support response_format on claude models
|
||||
json_mode_content_str: Optional[str] = (
|
||||
str(new_tool_calls[0]["function"].get("arguments", "")) or None
|
||||
)
|
||||
if json_mode_content_str is not None:
|
||||
translated_message = Message(content=json_mode_content_str)
|
||||
finish_reason = "stop"
|
||||
|
||||
if translated_message is None:
|
||||
## get the reasoning content
|
||||
(
|
||||
reasoning_content,
|
||||
content_str,
|
||||
) = _extract_reasoning_content(cast(dict, choice["message"]))
|
||||
|
||||
translated_message = Message(
|
||||
role="assistant",
|
||||
content=content_str,
|
||||
reasoning_content=reasoning_content,
|
||||
thinking_blocks=None,
|
||||
tool_calls=new_tool_calls,
|
||||
)
|
||||
|
||||
if finish_reason is None:
|
||||
finish_reason = choice["finish_reason"]
|
||||
|
||||
translated_choice = Choices(
|
||||
finish_reason=finish_reason,
|
||||
index=choice["index"],
|
||||
message=translated_message,
|
||||
logprobs=None,
|
||||
enhancements=None,
|
||||
)
|
||||
|
||||
translated_choice.finish_reason = map_finish_reason(
|
||||
self._get_finish_reason(translated_message, choice["finish_reason"])
|
||||
)
|
||||
transformed_choices.append(translated_choice)
|
||||
|
||||
return transformed_choices
|
||||
|
||||
def transform_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ModelResponse:
|
||||
"""
|
||||
Transform the response from the API.
|
||||
|
||||
Returns:
|
||||
dict: The transformed response.
|
||||
"""
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
original_response=raw_response.text,
|
||||
additional_args={"complete_input_dict": request_data},
|
||||
)
|
||||
|
||||
## RESPONSE OBJECT
|
||||
try:
|
||||
completion_response = raw_response.json()
|
||||
except Exception as e:
|
||||
response_headers = getattr(raw_response, "headers", None)
|
||||
raise OpenAIError(
|
||||
message="Unable to get json response - {}, Original Response: {}".format(
|
||||
str(e), raw_response.text
|
||||
),
|
||||
status_code=raw_response.status_code,
|
||||
headers=response_headers,
|
||||
)
|
||||
raw_response_headers = dict(raw_response.headers)
|
||||
final_response_obj = convert_to_model_response_object(
|
||||
response_object=completion_response,
|
||||
model_response_object=model_response,
|
||||
hidden_params={"headers": raw_response_headers},
|
||||
_response_headers=raw_response_headers,
|
||||
)
|
||||
|
||||
return cast(ModelResponse, final_response_obj)
|
||||
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||
) -> BaseLLMException:
|
||||
return OpenAIError(
|
||||
status_code=status_code,
|
||||
message=error_message,
|
||||
headers=cast(httpx.Headers, headers),
|
||||
)
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
stream: Optional[bool] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Get the complete URL for the API call.
|
||||
|
||||
Returns:
|
||||
str: The complete URL for the API call.
|
||||
"""
|
||||
if api_base is None:
|
||||
api_base = "https://api.openai.com"
|
||||
endpoint = "chat/completions"
|
||||
|
||||
# Remove trailing slash from api_base if present
|
||||
api_base = api_base.rstrip("/")
|
||||
|
||||
# Check if endpoint is already in the api_base
|
||||
if endpoint in api_base:
|
||||
return api_base
|
||||
|
||||
return f"{api_base}/{endpoint}"
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
if api_key is not None:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
# Ensure Content-Type is set to application/json
|
||||
if "content-type" not in headers and "Content-Type" not in headers:
|
||||
headers["Content-Type"] = "application/json"
|
||||
|
||||
return headers
|
||||
|
||||
def get_models(
|
||||
self, api_key: Optional[str] = None, api_base: Optional[str] = None
|
||||
) -> List[str]:
|
||||
"""
|
||||
Calls OpenAI's `/v1/models` endpoint and returns the list of models.
|
||||
"""
|
||||
|
||||
if api_base is None:
|
||||
api_base = "https://api.openai.com"
|
||||
if api_key is None:
|
||||
api_key = get_secret_str("OPENAI_API_KEY")
|
||||
|
||||
# Strip api_base to just the base URL (scheme + host + port)
|
||||
parsed_url = httpx.URL(api_base)
|
||||
base_url = f"{parsed_url.scheme}://{parsed_url.host}"
|
||||
if parsed_url.port:
|
||||
base_url += f":{parsed_url.port}"
|
||||
|
||||
response = litellm.module_level_client.get(
|
||||
url=f"{base_url}/v1/models",
|
||||
headers={"Authorization": f"Bearer {api_key}"},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Failed to get models: {response.text}")
|
||||
|
||||
models = response.json()["data"]
|
||||
return [model["id"] for model in models]
|
||||
|
||||
@staticmethod
|
||||
def get_api_key(api_key: Optional[str] = None) -> Optional[str]:
|
||||
return (
|
||||
api_key
|
||||
or litellm.api_key
|
||||
or litellm.openai_key
|
||||
or get_secret_str("OPENAI_API_KEY")
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
|
||||
return (
|
||||
api_base
|
||||
or litellm.api_base
|
||||
or get_secret_str("OPENAI_BASE_URL")
|
||||
or get_secret_str("OPENAI_API_BASE")
|
||||
or "https://api.openai.com/v1"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_base_model(model: Optional[str] = None) -> Optional[str]:
|
||||
return model
|
||||
|
||||
def get_token_counter(self) -> Optional["BaseTokenCounter"]:
|
||||
from litellm.llms.openai.responses.count_tokens.token_counter import (
|
||||
OpenAITokenCounter,
|
||||
)
|
||||
|
||||
return OpenAITokenCounter()
|
||||
|
||||
def get_model_response_iterator(
|
||||
self,
|
||||
streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
|
||||
sync_stream: bool,
|
||||
json_mode: Optional[bool] = False,
|
||||
) -> Any:
|
||||
return OpenAIChatCompletionStreamingHandler(
|
||||
streaming_response=streaming_response,
|
||||
sync_stream=sync_stream,
|
||||
json_mode=json_mode,
|
||||
)
|
||||
|
||||
|
||||
class OpenAIChatCompletionStreamingHandler(BaseModelResponseIterator):
|
||||
def _map_reasoning_to_reasoning_content(self, choices: list) -> list:
|
||||
"""
|
||||
Map 'reasoning' field to 'reasoning_content' field in delta.
|
||||
|
||||
Some OpenAI-compatible providers (e.g., GLM-5, hosted_vllm) return
|
||||
delta.reasoning, but LiteLLM expects delta.reasoning_content.
|
||||
|
||||
Args:
|
||||
choices: List of choice objects from the streaming chunk
|
||||
|
||||
Returns:
|
||||
List of choices with reasoning field mapped to reasoning_content
|
||||
"""
|
||||
for choice in choices:
|
||||
delta = choice.get("delta", {})
|
||||
if "reasoning" in delta:
|
||||
delta["reasoning_content"] = delta.pop("reasoning")
|
||||
return choices
|
||||
|
||||
def chunk_parser(self, chunk: dict) -> ModelResponseStream:
|
||||
try:
|
||||
choices = chunk.get("choices", [])
|
||||
choices = self._map_reasoning_to_reasoning_content(choices)
|
||||
|
||||
kwargs = {
|
||||
"id": chunk["id"],
|
||||
"object": "chat.completion.chunk",
|
||||
"created": chunk.get("created"),
|
||||
"model": chunk.get("model"),
|
||||
"choices": choices,
|
||||
}
|
||||
if "usage" in chunk and chunk["usage"] is not None:
|
||||
kwargs["usage"] = chunk["usage"]
|
||||
return ModelResponseStream(**kwargs)
|
||||
except Exception as e:
|
||||
raise e
|
||||
@@ -0,0 +1,3 @@
|
||||
Translation of OpenAI `/chat/completions` input and output to a custom guardrail.
|
||||
|
||||
This enables guardrails to be applied to OpenAI `/chat/completions` requests and responses.
|
||||
@@ -0,0 +1,12 @@
|
||||
"""OpenAI Chat Completions message handler for Unified Guardrails."""
|
||||
|
||||
from litellm.llms.openai.chat.guardrail_translation.handler import (
|
||||
OpenAIChatCompletionsHandler,
|
||||
)
|
||||
from litellm.types.utils import CallTypes
|
||||
|
||||
guardrail_translation_mappings = {
|
||||
CallTypes.completion: OpenAIChatCompletionsHandler,
|
||||
CallTypes.acompletion: OpenAIChatCompletionsHandler,
|
||||
}
|
||||
__all__ = ["guardrail_translation_mappings"]
|
||||
@@ -0,0 +1,808 @@
|
||||
"""
|
||||
OpenAI Chat Completions Message Handler for Unified Guardrails
|
||||
|
||||
This module provides a class-based handler for OpenAI-format chat completions.
|
||||
The class methods can be overridden for custom behavior.
|
||||
|
||||
Pattern Overview:
|
||||
-----------------
|
||||
1. Extract text content from messages/responses (both string and list formats)
|
||||
2. Create async tasks to apply guardrails to each text segment
|
||||
3. Track mappings to know where each response belongs
|
||||
4. Apply guardrail responses back to the original structure
|
||||
|
||||
This pattern can be replicated for other message formats (e.g., Anthropic).
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
|
||||
from litellm.main import stream_chunk_builder
|
||||
from litellm.types.llms.openai import ChatCompletionToolParam
|
||||
from litellm.types.utils import (
|
||||
Choices,
|
||||
GenericGuardrailAPIInputs,
|
||||
ModelResponse,
|
||||
ModelResponseStream,
|
||||
StreamingChoices,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
|
||||
|
||||
class OpenAIChatCompletionsHandler(BaseTranslation):
|
||||
"""
|
||||
Handler for processing OpenAI chat completions messages with guardrails.
|
||||
|
||||
This class provides methods to:
|
||||
1. Process input messages (pre-call hook)
|
||||
2. Process output responses (post-call hook)
|
||||
|
||||
Methods can be overridden to customize behavior for different message formats.
|
||||
"""
|
||||
|
||||
async def process_input_messages(
|
||||
self,
|
||||
data: dict,
|
||||
guardrail_to_apply: "CustomGuardrail",
|
||||
litellm_logging_obj: Optional[Any] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Process input messages by applying guardrails to text content.
|
||||
"""
|
||||
messages = data.get("messages")
|
||||
if messages is None:
|
||||
return data
|
||||
|
||||
texts_to_check: List[str] = []
|
||||
images_to_check: List[str] = []
|
||||
tool_calls_to_check: List[ChatCompletionToolParam] = []
|
||||
text_task_mappings: List[Tuple[int, Optional[int]]] = []
|
||||
tool_call_task_mappings: List[Tuple[int, int]] = []
|
||||
# text_task_mappings: Track (message_index, content_index) for each text
|
||||
# content_index is None for string content, int for list content
|
||||
# tool_call_task_mappings: Track (message_index, tool_call_index) for each tool call
|
||||
|
||||
# Step 1: Extract all text content, images, and tool calls
|
||||
for msg_idx, message in enumerate(messages):
|
||||
self._extract_inputs(
|
||||
message=message,
|
||||
msg_idx=msg_idx,
|
||||
texts_to_check=texts_to_check,
|
||||
images_to_check=images_to_check,
|
||||
tool_calls_to_check=tool_calls_to_check,
|
||||
text_task_mappings=text_task_mappings,
|
||||
tool_call_task_mappings=tool_call_task_mappings,
|
||||
)
|
||||
|
||||
# Step 2: Apply guardrail to all texts and tool calls in batch
|
||||
if texts_to_check or tool_calls_to_check:
|
||||
inputs = GenericGuardrailAPIInputs(texts=texts_to_check)
|
||||
if images_to_check:
|
||||
inputs["images"] = images_to_check
|
||||
if tool_calls_to_check:
|
||||
inputs["tool_calls"] = tool_calls_to_check # type: ignore
|
||||
if messages:
|
||||
inputs[
|
||||
"structured_messages"
|
||||
] = messages # pass the openai /chat/completions messages to the guardrail, as-is
|
||||
# Pass tools (function definitions) to the guardrail
|
||||
tools = data.get("tools")
|
||||
if tools:
|
||||
inputs["tools"] = tools
|
||||
# Include model information if available
|
||||
model = data.get("model")
|
||||
if model:
|
||||
inputs["model"] = model
|
||||
|
||||
guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
|
||||
inputs=inputs,
|
||||
request_data=data,
|
||||
input_type="request",
|
||||
logging_obj=litellm_logging_obj,
|
||||
)
|
||||
|
||||
guardrailed_texts = guardrailed_inputs.get("texts", [])
|
||||
guardrailed_tool_calls = guardrailed_inputs.get("tool_calls", [])
|
||||
guardrailed_tools = guardrailed_inputs.get("tools")
|
||||
if guardrailed_tools is not None:
|
||||
data["tools"] = guardrailed_tools
|
||||
|
||||
# Step 3: Map guardrail responses back to original message structure
|
||||
if guardrailed_texts and texts_to_check:
|
||||
await self._apply_guardrail_responses_to_input_texts(
|
||||
messages=messages,
|
||||
responses=guardrailed_texts,
|
||||
task_mappings=text_task_mappings,
|
||||
)
|
||||
|
||||
# Step 4: Apply guardrailed tool calls back to messages
|
||||
if guardrailed_tool_calls:
|
||||
# Note: The guardrail may modify tool_calls_to_check in place
|
||||
# or we may need to handle returned tool calls differently
|
||||
await self._apply_guardrail_responses_to_input_tool_calls(
|
||||
messages=messages,
|
||||
tool_calls=guardrailed_tool_calls, # type: ignore
|
||||
task_mappings=tool_call_task_mappings,
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
"OpenAI Chat Completions: Processed input messages: %s", messages
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
def extract_request_tool_names(self, data: dict) -> List[str]:
|
||||
"""Extract tool names from OpenAI chat completions request (tools[].function.name, functions[].name)."""
|
||||
names: List[str] = []
|
||||
for tool in data.get("tools") or []:
|
||||
if isinstance(tool, dict) and tool.get("type") == "function":
|
||||
fn = tool.get("function")
|
||||
if isinstance(fn, dict) and fn.get("name"):
|
||||
names.append(str(fn["name"]))
|
||||
for fn in data.get("functions") or []:
|
||||
if isinstance(fn, dict) and fn.get("name"):
|
||||
names.append(str(fn["name"]))
|
||||
return names
|
||||
|
||||
def _extract_inputs(
|
||||
self,
|
||||
message: Dict[str, Any],
|
||||
msg_idx: int,
|
||||
texts_to_check: List[str],
|
||||
images_to_check: List[str],
|
||||
tool_calls_to_check: List[ChatCompletionToolParam],
|
||||
text_task_mappings: List[Tuple[int, Optional[int]]],
|
||||
tool_call_task_mappings: List[Tuple[int, int]],
|
||||
) -> None:
|
||||
"""
|
||||
Extract text content, images, and tool calls from a message.
|
||||
|
||||
Override this method to customize text/image/tool call extraction logic.
|
||||
"""
|
||||
content = message.get("content", None)
|
||||
if content is not None:
|
||||
if isinstance(content, str):
|
||||
# Simple string content
|
||||
texts_to_check.append(content)
|
||||
text_task_mappings.append((msg_idx, None))
|
||||
|
||||
elif isinstance(content, list):
|
||||
# List content (e.g., multimodal with text and images)
|
||||
for content_idx, content_item in enumerate(content):
|
||||
# Extract text
|
||||
text_str = content_item.get("text", None)
|
||||
if text_str is not None:
|
||||
texts_to_check.append(text_str)
|
||||
text_task_mappings.append((msg_idx, int(content_idx)))
|
||||
|
||||
# Extract images (image_url)
|
||||
if content_item.get("type") == "image_url":
|
||||
image_url = content_item.get("image_url", {})
|
||||
if isinstance(image_url, dict):
|
||||
url = image_url.get("url")
|
||||
if url:
|
||||
images_to_check.append(url)
|
||||
elif isinstance(image_url, str):
|
||||
images_to_check.append(image_url)
|
||||
|
||||
# Extract tool calls (typically in assistant messages)
|
||||
tool_calls = message.get("tool_calls", None)
|
||||
if tool_calls is not None and isinstance(tool_calls, list):
|
||||
for tool_call_idx, tool_call in enumerate(tool_calls):
|
||||
if isinstance(tool_call, dict):
|
||||
# Add the full tool call object to the list
|
||||
tool_calls_to_check.append(cast(ChatCompletionToolParam, tool_call))
|
||||
tool_call_task_mappings.append((msg_idx, int(tool_call_idx)))
|
||||
|
||||
async def _apply_guardrail_responses_to_input_texts(
|
||||
self,
|
||||
messages: List[Dict[str, Any]],
|
||||
responses: List[str],
|
||||
task_mappings: List[Tuple[int, Optional[int]]],
|
||||
) -> None:
|
||||
"""
|
||||
Apply guardrail responses back to input message text content.
|
||||
|
||||
Override this method to customize how text responses are applied.
|
||||
"""
|
||||
for task_idx, guardrail_response in enumerate(responses):
|
||||
mapping = task_mappings[task_idx]
|
||||
msg_idx = cast(int, mapping[0])
|
||||
content_idx_optional = cast(Optional[int], mapping[1])
|
||||
|
||||
# Handle content
|
||||
content = messages[msg_idx].get("content", None)
|
||||
if content is None:
|
||||
continue
|
||||
|
||||
if isinstance(content, str) and content_idx_optional is None:
|
||||
# Replace string content with guardrail response
|
||||
messages[msg_idx]["content"] = guardrail_response
|
||||
|
||||
elif isinstance(content, list) and content_idx_optional is not None:
|
||||
# Replace specific text item in list content
|
||||
messages[msg_idx]["content"][content_idx_optional][
|
||||
"text"
|
||||
] = guardrail_response
|
||||
|
||||
async def _apply_guardrail_responses_to_input_tool_calls(
|
||||
self,
|
||||
messages: List[Dict[str, Any]],
|
||||
tool_calls: List[Dict[str, Any]],
|
||||
task_mappings: List[Tuple[int, int]],
|
||||
) -> None:
|
||||
"""
|
||||
Apply guardrailed tool calls back to input messages.
|
||||
|
||||
The guardrail may have modified the tool_calls list in place,
|
||||
so we apply the modified tool calls back to the original messages.
|
||||
|
||||
Override this method to customize how tool call responses are applied.
|
||||
"""
|
||||
for task_idx, (msg_idx, tool_call_idx) in enumerate(task_mappings):
|
||||
if task_idx < len(tool_calls):
|
||||
guardrailed_tool_call = tool_calls[task_idx]
|
||||
message_tool_calls = messages[msg_idx].get("tool_calls", None)
|
||||
if message_tool_calls is not None and isinstance(
|
||||
message_tool_calls, list
|
||||
):
|
||||
if tool_call_idx < len(message_tool_calls):
|
||||
# Replace the tool call with the guardrailed version
|
||||
message_tool_calls[tool_call_idx] = guardrailed_tool_call
|
||||
|
||||
async def process_output_response(
|
||||
self,
|
||||
response: "ModelResponse",
|
||||
guardrail_to_apply: "CustomGuardrail",
|
||||
litellm_logging_obj: Optional[Any] = None,
|
||||
user_api_key_dict: Optional[Any] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Process output response by applying guardrails to text content.
|
||||
|
||||
Args:
|
||||
response: LiteLLM ModelResponse object
|
||||
guardrail_to_apply: The guardrail instance to apply
|
||||
litellm_logging_obj: Optional logging object
|
||||
user_api_key_dict: User API key metadata to pass to guardrails
|
||||
|
||||
Returns:
|
||||
Modified response with guardrail applied to content
|
||||
|
||||
Response Format Support:
|
||||
- String content: choice.message.content = "text here"
|
||||
- List content: choice.message.content = [{"type": "text", "text": "text here"}, ...]
|
||||
"""
|
||||
|
||||
# Step 0: Check if response has any text content to process
|
||||
if not self._has_text_content(response):
|
||||
verbose_proxy_logger.warning(
|
||||
"OpenAI Chat Completions: No text content in response, skipping guardrail"
|
||||
)
|
||||
return response
|
||||
|
||||
texts_to_check: List[str] = []
|
||||
images_to_check: List[str] = []
|
||||
tool_calls_to_check: List[Dict[str, Any]] = []
|
||||
text_task_mappings: List[Tuple[int, Optional[int]]] = []
|
||||
tool_call_task_mappings: List[Tuple[int, int]] = []
|
||||
# text_task_mappings: Track (choice_index, content_index) for each text
|
||||
# content_index is None for string content, int for list content
|
||||
# tool_call_task_mappings: Track (choice_index, tool_call_index) for each tool call
|
||||
|
||||
# Step 1: Extract all text content, images, and tool calls from response choices
|
||||
for choice_idx, choice in enumerate(response.choices):
|
||||
self._extract_output_text_images_and_tool_calls(
|
||||
choice=choice,
|
||||
choice_idx=choice_idx,
|
||||
texts_to_check=texts_to_check,
|
||||
images_to_check=images_to_check,
|
||||
tool_calls_to_check=tool_calls_to_check,
|
||||
text_task_mappings=text_task_mappings,
|
||||
tool_call_task_mappings=tool_call_task_mappings,
|
||||
)
|
||||
|
||||
# Step 2: Apply guardrail to all texts and tool calls in batch
|
||||
if texts_to_check or tool_calls_to_check:
|
||||
# Create a request_data dict with response info and user API key metadata
|
||||
request_data: dict = {"response": response}
|
||||
|
||||
# Add user API key metadata with prefixed keys
|
||||
user_metadata = self.transform_user_api_key_dict_to_metadata(
|
||||
user_api_key_dict
|
||||
)
|
||||
if user_metadata:
|
||||
request_data["litellm_metadata"] = user_metadata
|
||||
|
||||
inputs = GenericGuardrailAPIInputs(texts=texts_to_check)
|
||||
if images_to_check:
|
||||
inputs["images"] = images_to_check
|
||||
if tool_calls_to_check:
|
||||
inputs["tool_calls"] = tool_calls_to_check # type: ignore
|
||||
# Include model information from the response if available
|
||||
if hasattr(response, "model") and response.model:
|
||||
inputs["model"] = response.model
|
||||
|
||||
guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
|
||||
inputs=inputs,
|
||||
request_data=request_data,
|
||||
input_type="response",
|
||||
logging_obj=litellm_logging_obj,
|
||||
)
|
||||
|
||||
guardrailed_texts = guardrailed_inputs.get("texts", [])
|
||||
|
||||
# Step 3: Map guardrail responses back to original response structure
|
||||
if guardrailed_texts and texts_to_check:
|
||||
await self._apply_guardrail_responses_to_output_texts(
|
||||
response=response,
|
||||
responses=guardrailed_texts,
|
||||
task_mappings=text_task_mappings,
|
||||
)
|
||||
|
||||
# Step 4: Apply guardrailed tool calls back to response
|
||||
if tool_calls_to_check:
|
||||
await self._apply_guardrail_responses_to_output_tool_calls(
|
||||
response=response,
|
||||
tool_calls=tool_calls_to_check,
|
||||
task_mappings=tool_call_task_mappings,
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
"OpenAI Chat Completions: Processed output response: %s", response
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
async def process_output_streaming_response(
|
||||
self,
|
||||
responses_so_far: List["ModelResponseStream"],
|
||||
guardrail_to_apply: "CustomGuardrail",
|
||||
litellm_logging_obj: Optional[Any] = None,
|
||||
user_api_key_dict: Optional[Any] = None,
|
||||
) -> List["ModelResponseStream"]:
|
||||
"""
|
||||
Process output streaming responses by applying guardrails to text content.
|
||||
|
||||
Args:
|
||||
responses_so_far: List of LiteLLM ModelResponseStream objects
|
||||
guardrail_to_apply: The guardrail instance to apply
|
||||
litellm_logging_obj: Optional logging object
|
||||
user_api_key_dict: User API key metadata to pass to guardrails
|
||||
|
||||
Returns:
|
||||
Modified list of responses with guardrail applied to content
|
||||
|
||||
Response Format Support:
|
||||
- String content: choice.message.content = "text here"
|
||||
- List content: choice.message.content = [{"type": "text", "text": "text here"}, ...]
|
||||
"""
|
||||
# check if the stream has ended
|
||||
has_stream_ended = False
|
||||
for chunk in responses_so_far:
|
||||
if chunk.choices and chunk.choices[0].finish_reason is not None:
|
||||
has_stream_ended = True
|
||||
break
|
||||
|
||||
if has_stream_ended:
|
||||
# convert to model response
|
||||
model_response = cast(
|
||||
ModelResponse,
|
||||
stream_chunk_builder(
|
||||
chunks=responses_so_far, logging_obj=litellm_logging_obj
|
||||
),
|
||||
)
|
||||
# run process_output_response
|
||||
await self.process_output_response(
|
||||
response=model_response,
|
||||
guardrail_to_apply=guardrail_to_apply,
|
||||
litellm_logging_obj=litellm_logging_obj,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
)
|
||||
|
||||
return responses_so_far
|
||||
|
||||
# Step 0: Check if any response has text content to process
|
||||
has_any_text_content = False
|
||||
for response in responses_so_far:
|
||||
if self._has_text_content(response):
|
||||
has_any_text_content = True
|
||||
break
|
||||
|
||||
if not has_any_text_content:
|
||||
verbose_proxy_logger.warning(
|
||||
"OpenAI Chat Completions: No text content in streaming responses, skipping guardrail"
|
||||
)
|
||||
return responses_so_far
|
||||
|
||||
# Step 1: Combine all streaming chunks into complete text per choice
|
||||
# For streaming, we need to concatenate all delta.content across all chunks
|
||||
# Key: (choice_idx, content_idx), Value: combined text
|
||||
combined_texts = self._combine_streaming_texts(responses_so_far)
|
||||
|
||||
# Step 2: Create lists for guardrail processing
|
||||
texts_to_check: List[str] = []
|
||||
images_to_check: List[str] = []
|
||||
task_mappings: List[Tuple[int, Optional[int]]] = []
|
||||
# Track (choice_index, content_index) for each combined text
|
||||
|
||||
for (map_choice_idx, map_content_idx), combined_text in combined_texts.items():
|
||||
texts_to_check.append(combined_text)
|
||||
task_mappings.append((map_choice_idx, map_content_idx))
|
||||
|
||||
# Step 3: Apply guardrail to all combined texts in batch
|
||||
if texts_to_check:
|
||||
# Create a request_data dict with response info and user API key metadata
|
||||
request_data: dict = {"responses": responses_so_far}
|
||||
|
||||
# Add user API key metadata with prefixed keys
|
||||
user_metadata = self.transform_user_api_key_dict_to_metadata(
|
||||
user_api_key_dict
|
||||
)
|
||||
if user_metadata:
|
||||
request_data["litellm_metadata"] = user_metadata
|
||||
|
||||
inputs = GenericGuardrailAPIInputs(texts=texts_to_check)
|
||||
if images_to_check:
|
||||
inputs["images"] = images_to_check
|
||||
# Include model information from the first response if available
|
||||
if (
|
||||
responses_so_far
|
||||
and hasattr(responses_so_far[0], "model")
|
||||
and responses_so_far[0].model
|
||||
):
|
||||
inputs["model"] = responses_so_far[0].model
|
||||
guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
|
||||
inputs=inputs,
|
||||
request_data=request_data,
|
||||
input_type="response",
|
||||
logging_obj=litellm_logging_obj,
|
||||
)
|
||||
|
||||
guardrailed_texts = guardrailed_inputs.get("texts", [])
|
||||
|
||||
# Step 4: Apply guardrailed text back to all streaming chunks
|
||||
# For each choice, replace the combined text across all chunks
|
||||
await self._apply_guardrail_responses_to_output_streaming(
|
||||
responses=responses_so_far,
|
||||
guardrailed_texts=guardrailed_texts,
|
||||
task_mappings=task_mappings,
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
"OpenAI Chat Completions: Processed output streaming responses: %s",
|
||||
responses_so_far,
|
||||
)
|
||||
|
||||
return responses_so_far
|
||||
|
||||
def _combine_streaming_texts(
|
||||
self, responses_so_far: List["ModelResponseStream"]
|
||||
) -> Dict[Tuple[int, Optional[int]], str]:
|
||||
"""
|
||||
Combine all streaming chunks into complete text per choice.
|
||||
|
||||
For streaming, we need to concatenate all delta.content across all chunks.
|
||||
|
||||
Args:
|
||||
responses_so_far: List of LiteLLM ModelResponseStream objects
|
||||
|
||||
Returns:
|
||||
Dict mapping (choice_idx, content_idx) to combined text string
|
||||
"""
|
||||
combined_texts: Dict[Tuple[int, Optional[int]], str] = {}
|
||||
|
||||
for response_idx, response in enumerate(responses_so_far):
|
||||
for choice_idx, choice in enumerate(response.choices):
|
||||
if isinstance(choice, litellm.StreamingChoices):
|
||||
content = choice.delta.content
|
||||
elif isinstance(choice, litellm.Choices):
|
||||
content = choice.message.content
|
||||
else:
|
||||
continue
|
||||
|
||||
if content is None:
|
||||
continue
|
||||
|
||||
if isinstance(content, str):
|
||||
# String content - accumulate for this choice
|
||||
str_key: Tuple[int, Optional[int]] = (choice_idx, None)
|
||||
if str_key not in combined_texts:
|
||||
combined_texts[str_key] = ""
|
||||
combined_texts[str_key] += content
|
||||
|
||||
elif isinstance(content, list):
|
||||
# List content - accumulate for each content item
|
||||
for content_idx, content_item in enumerate(content):
|
||||
text_str = content_item.get("text")
|
||||
if text_str:
|
||||
list_key: Tuple[int, Optional[int]] = (
|
||||
choice_idx,
|
||||
content_idx,
|
||||
)
|
||||
if list_key not in combined_texts:
|
||||
combined_texts[list_key] = ""
|
||||
combined_texts[list_key] += text_str
|
||||
|
||||
return combined_texts
|
||||
|
||||
def _has_text_content(
|
||||
self, response: Union["ModelResponse", "ModelResponseStream"]
|
||||
) -> bool:
|
||||
"""
|
||||
Check if response has any text content or tool calls to process.
|
||||
|
||||
Override this method to customize text content detection.
|
||||
"""
|
||||
from litellm.types.utils import ModelResponse, ModelResponseStream
|
||||
|
||||
if isinstance(response, ModelResponse):
|
||||
for choice in response.choices:
|
||||
if isinstance(choice, litellm.Choices):
|
||||
# Check for text content
|
||||
if choice.message.content and isinstance(
|
||||
choice.message.content, str
|
||||
):
|
||||
return True
|
||||
# Check for tool calls
|
||||
if choice.message.tool_calls and isinstance(
|
||||
choice.message.tool_calls, list
|
||||
):
|
||||
if len(choice.message.tool_calls) > 0:
|
||||
return True
|
||||
elif isinstance(response, ModelResponseStream):
|
||||
for streaming_choice in response.choices:
|
||||
if isinstance(streaming_choice, litellm.StreamingChoices):
|
||||
# Check for text content
|
||||
if streaming_choice.delta.content and isinstance(
|
||||
streaming_choice.delta.content, str
|
||||
):
|
||||
return True
|
||||
# Check for tool calls
|
||||
if streaming_choice.delta.tool_calls and isinstance(
|
||||
streaming_choice.delta.tool_calls, list
|
||||
):
|
||||
if len(streaming_choice.delta.tool_calls) > 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _extract_output_text_images_and_tool_calls(
|
||||
self,
|
||||
choice: Union[Choices, StreamingChoices],
|
||||
choice_idx: int,
|
||||
texts_to_check: List[str],
|
||||
images_to_check: List[str],
|
||||
tool_calls_to_check: List[Dict[str, Any]],
|
||||
text_task_mappings: List[Tuple[int, Optional[int]]],
|
||||
tool_call_task_mappings: List[Tuple[int, int]],
|
||||
) -> None:
|
||||
"""
|
||||
Extract text content, images, and tool calls from a response choice.
|
||||
|
||||
Override this method to customize text/image/tool call extraction logic.
|
||||
"""
|
||||
verbose_proxy_logger.debug(
|
||||
"OpenAI Chat Completions: Processing choice: %s", choice
|
||||
)
|
||||
|
||||
# Determine content source and tool calls based on choice type
|
||||
content = None
|
||||
tool_calls: Optional[List[Any]] = None
|
||||
if isinstance(choice, litellm.Choices):
|
||||
content = choice.message.content
|
||||
tool_calls = choice.message.tool_calls
|
||||
elif isinstance(choice, litellm.StreamingChoices):
|
||||
content = choice.delta.content
|
||||
tool_calls = choice.delta.tool_calls
|
||||
else:
|
||||
# Unknown choice type, skip processing
|
||||
return
|
||||
|
||||
# Process content if it exists
|
||||
if content and isinstance(content, str):
|
||||
# Simple string content
|
||||
texts_to_check.append(content)
|
||||
text_task_mappings.append((choice_idx, None))
|
||||
|
||||
elif content and isinstance(content, list):
|
||||
# List content (e.g., multimodal response)
|
||||
for content_idx, content_item in enumerate(content):
|
||||
# Extract text
|
||||
content_text = content_item.get("text")
|
||||
if content_text:
|
||||
texts_to_check.append(content_text)
|
||||
text_task_mappings.append((choice_idx, int(content_idx)))
|
||||
|
||||
# Extract images
|
||||
if content_item.get("type") == "image_url":
|
||||
image_url = content_item.get("image_url", {})
|
||||
if isinstance(image_url, dict):
|
||||
url = image_url.get("url")
|
||||
if url:
|
||||
images_to_check.append(url)
|
||||
|
||||
# Process tool calls if they exist
|
||||
if tool_calls is not None and isinstance(tool_calls, list):
|
||||
for tool_call_idx, tool_call in enumerate(tool_calls):
|
||||
# Convert tool call to dict format for guardrail processing
|
||||
tool_call_dict = self._convert_tool_call_to_dict(tool_call)
|
||||
if tool_call_dict:
|
||||
tool_calls_to_check.append(tool_call_dict)
|
||||
tool_call_task_mappings.append((choice_idx, int(tool_call_idx)))
|
||||
|
||||
def _convert_tool_call_to_dict(
|
||||
self, tool_call: Union[Dict[str, Any], Any]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Convert a tool call object to dictionary format.
|
||||
|
||||
Tool calls can be either dict or object depending on the type.
|
||||
"""
|
||||
if isinstance(tool_call, dict):
|
||||
return tool_call
|
||||
elif hasattr(tool_call, "id") and hasattr(tool_call, "function"):
|
||||
# Convert object to dict
|
||||
function = tool_call.function
|
||||
function_dict = {}
|
||||
if hasattr(function, "name"):
|
||||
function_dict["name"] = function.name
|
||||
if hasattr(function, "arguments"):
|
||||
function_dict["arguments"] = function.arguments
|
||||
|
||||
tool_call_dict = {
|
||||
"id": tool_call.id if hasattr(tool_call, "id") else None,
|
||||
"type": tool_call.type if hasattr(tool_call, "type") else "function",
|
||||
"function": function_dict,
|
||||
}
|
||||
return tool_call_dict
|
||||
return None
|
||||
|
||||
async def _apply_guardrail_responses_to_output_texts(
|
||||
self,
|
||||
response: "ModelResponse",
|
||||
responses: List[str],
|
||||
task_mappings: List[Tuple[int, Optional[int]]],
|
||||
) -> None:
|
||||
"""
|
||||
Apply guardrail text responses back to output response.
|
||||
|
||||
Override this method to customize how text responses are applied.
|
||||
"""
|
||||
for task_idx, guardrail_response in enumerate(responses):
|
||||
mapping = task_mappings[task_idx]
|
||||
choice_idx = cast(int, mapping[0])
|
||||
content_idx_optional = cast(Optional[int], mapping[1])
|
||||
|
||||
choice = cast(Choices, response.choices[choice_idx])
|
||||
|
||||
# Handle content
|
||||
content = choice.message.content
|
||||
if content is None:
|
||||
continue
|
||||
|
||||
if isinstance(content, str) and content_idx_optional is None:
|
||||
# Replace string content with guardrail response
|
||||
choice.message.content = guardrail_response
|
||||
|
||||
elif isinstance(content, list) and content_idx_optional is not None:
|
||||
# Replace specific text item in list content
|
||||
choice.message.content[content_idx_optional]["text"] = guardrail_response # type: ignore
|
||||
|
||||
async def _apply_guardrail_responses_to_output_tool_calls(
|
||||
self,
|
||||
response: "ModelResponse",
|
||||
tool_calls: List[Dict[str, Any]],
|
||||
task_mappings: List[Tuple[int, int]],
|
||||
) -> None:
|
||||
"""
|
||||
Apply guardrailed tool calls back to output response.
|
||||
|
||||
The guardrail may have modified the tool_calls list in place,
|
||||
so we apply the modified tool calls back to the original response.
|
||||
|
||||
Override this method to customize how tool call responses are applied.
|
||||
"""
|
||||
for task_idx, (choice_idx, tool_call_idx) in enumerate(task_mappings):
|
||||
if task_idx < len(tool_calls):
|
||||
guardrailed_tool_call = tool_calls[task_idx]
|
||||
choice = cast(Choices, response.choices[choice_idx])
|
||||
choice_tool_calls = choice.message.tool_calls
|
||||
|
||||
if choice_tool_calls is not None and isinstance(
|
||||
choice_tool_calls, list
|
||||
):
|
||||
if tool_call_idx < len(choice_tool_calls):
|
||||
# Update the tool call with guardrailed version
|
||||
existing_tool_call = choice_tool_calls[tool_call_idx]
|
||||
# Update object attributes (output responses always have typed objects)
|
||||
if "function" in guardrailed_tool_call:
|
||||
func_dict = guardrailed_tool_call["function"]
|
||||
if "arguments" in func_dict:
|
||||
existing_tool_call.function.arguments = func_dict[
|
||||
"arguments"
|
||||
]
|
||||
if "name" in func_dict:
|
||||
existing_tool_call.function.name = func_dict["name"]
|
||||
|
||||
async def _apply_guardrail_responses_to_output_streaming(
|
||||
self,
|
||||
responses: List["ModelResponseStream"],
|
||||
guardrailed_texts: List[str],
|
||||
task_mappings: List[Tuple[int, Optional[int]]],
|
||||
) -> None:
|
||||
"""
|
||||
Apply guardrail responses back to output streaming responses.
|
||||
|
||||
For streaming responses, the guardrailed text (which is the combined text from all chunks)
|
||||
is placed in the first chunk, and subsequent chunks are cleared.
|
||||
|
||||
Args:
|
||||
responses: List of ModelResponseStream objects to modify
|
||||
guardrailed_texts: List of guardrailed text responses (combined from all chunks)
|
||||
task_mappings: List of tuples (choice_idx, content_idx)
|
||||
|
||||
Override this method to customize how responses are applied to streaming responses.
|
||||
"""
|
||||
# Build a mapping of what guardrailed text to use for each (choice_idx, content_idx)
|
||||
guardrail_map: Dict[Tuple[int, Optional[int]], str] = {}
|
||||
for task_idx, guardrail_response in enumerate(guardrailed_texts):
|
||||
mapping = task_mappings[task_idx]
|
||||
choice_idx = cast(int, mapping[0])
|
||||
content_idx_optional = cast(Optional[int], mapping[1])
|
||||
guardrail_map[(choice_idx, content_idx_optional)] = guardrail_response
|
||||
|
||||
# Track which choices we've already set the guardrailed text for
|
||||
# Key: (choice_idx, content_idx), Value: boolean (True if already set)
|
||||
already_set: Dict[Tuple[int, Optional[int]], bool] = {}
|
||||
|
||||
# Iterate through all responses and update content
|
||||
for response_idx, response in enumerate(responses):
|
||||
for choice_idx_in_response, choice in enumerate(response.choices):
|
||||
if isinstance(choice, litellm.StreamingChoices):
|
||||
content = choice.delta.content
|
||||
elif isinstance(choice, litellm.Choices):
|
||||
content = choice.message.content
|
||||
else:
|
||||
continue
|
||||
|
||||
if content is None:
|
||||
continue
|
||||
|
||||
if isinstance(content, str):
|
||||
# String content
|
||||
str_key: Tuple[int, Optional[int]] = (choice_idx_in_response, None)
|
||||
if str_key in guardrail_map:
|
||||
if str_key not in already_set:
|
||||
# First chunk - set the complete guardrailed text
|
||||
if isinstance(choice, litellm.StreamingChoices):
|
||||
choice.delta.content = guardrail_map[str_key]
|
||||
elif isinstance(choice, litellm.Choices):
|
||||
choice.message.content = guardrail_map[str_key]
|
||||
already_set[str_key] = True
|
||||
else:
|
||||
# Subsequent chunks - clear the content
|
||||
if isinstance(choice, litellm.StreamingChoices):
|
||||
choice.delta.content = ""
|
||||
elif isinstance(choice, litellm.Choices):
|
||||
choice.message.content = ""
|
||||
|
||||
elif isinstance(content, list):
|
||||
# List content - handle each content item
|
||||
for content_idx, content_item in enumerate(content):
|
||||
if "text" in content_item:
|
||||
list_key: Tuple[int, Optional[int]] = (
|
||||
choice_idx_in_response,
|
||||
content_idx,
|
||||
)
|
||||
if list_key in guardrail_map:
|
||||
if list_key not in already_set:
|
||||
# First chunk - set the complete guardrailed text
|
||||
content_item["text"] = guardrail_map[list_key]
|
||||
already_set[list_key] = True
|
||||
else:
|
||||
# Subsequent chunks - clear the text
|
||||
content_item["text"] = ""
|
||||
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
LLM Calling done in `openai/openai.py`
|
||||
"""
|
||||
@@ -0,0 +1,179 @@
|
||||
"""
|
||||
Support for o1/o3 model family
|
||||
|
||||
https://platform.openai.com/docs/guides/reasoning
|
||||
|
||||
Translations handled by LiteLLM:
|
||||
- modalities: image => drop param (if user opts in to dropping param)
|
||||
- role: system ==> translate to role 'user'
|
||||
- streaming => faked by LiteLLM
|
||||
- Tools, response_format => drop param (if user opts in to dropping param)
|
||||
- Logprobs => drop param (if user opts in to dropping param)
|
||||
"""
|
||||
|
||||
from typing import Any, Coroutine, List, Literal, Optional, Union, cast, overload
|
||||
|
||||
import litellm
|
||||
from litellm import verbose_logger
|
||||
from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
|
||||
from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage
|
||||
from litellm.utils import (
|
||||
supports_function_calling,
|
||||
supports_parallel_function_calling,
|
||||
supports_response_schema,
|
||||
supports_system_messages,
|
||||
)
|
||||
|
||||
from .gpt_transformation import OpenAIGPTConfig
|
||||
|
||||
|
||||
class OpenAIOSeriesConfig(OpenAIGPTConfig):
|
||||
"""
|
||||
Reference: https://platform.openai.com/docs/guides/reasoning
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return super().get_config()
|
||||
|
||||
def translate_developer_role_to_system_role(
|
||||
self, messages: List[AllMessageValues]
|
||||
) -> List[AllMessageValues]:
|
||||
"""
|
||||
O-series models support `developer` role.
|
||||
"""
|
||||
return messages
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> list:
|
||||
"""
|
||||
Get the supported OpenAI params for the given model
|
||||
|
||||
"""
|
||||
|
||||
all_openai_params = super().get_supported_openai_params(model=model)
|
||||
non_supported_params = [
|
||||
"logprobs",
|
||||
"top_p",
|
||||
"presence_penalty",
|
||||
"frequency_penalty",
|
||||
"top_logprobs",
|
||||
]
|
||||
|
||||
o_series_only_param = ["reasoning_effort"]
|
||||
|
||||
all_openai_params.extend(o_series_only_param)
|
||||
|
||||
try:
|
||||
model, custom_llm_provider, api_base, api_key = get_llm_provider(
|
||||
model=model
|
||||
)
|
||||
except Exception:
|
||||
verbose_logger.debug(
|
||||
f"Unable to infer model provider for model={model}, defaulting to openai for o1 supported param check"
|
||||
)
|
||||
custom_llm_provider = "openai"
|
||||
|
||||
_supports_function_calling = supports_function_calling(
|
||||
model, custom_llm_provider
|
||||
)
|
||||
_supports_response_schema = supports_response_schema(model, custom_llm_provider)
|
||||
_supports_parallel_tool_calls = supports_parallel_function_calling(
|
||||
model, custom_llm_provider
|
||||
)
|
||||
|
||||
if not _supports_function_calling:
|
||||
non_supported_params.append("tools")
|
||||
non_supported_params.append("tool_choice")
|
||||
non_supported_params.append("function_call")
|
||||
non_supported_params.append("functions")
|
||||
|
||||
if not _supports_parallel_tool_calls:
|
||||
non_supported_params.append("parallel_tool_calls")
|
||||
|
||||
if not _supports_response_schema:
|
||||
non_supported_params.append("response_format")
|
||||
|
||||
return [
|
||||
param for param in all_openai_params if param not in non_supported_params
|
||||
]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
):
|
||||
if "max_tokens" in non_default_params:
|
||||
optional_params["max_completion_tokens"] = non_default_params.pop(
|
||||
"max_tokens"
|
||||
)
|
||||
if "temperature" in non_default_params:
|
||||
temperature_value: Optional[float] = non_default_params.pop("temperature")
|
||||
if temperature_value is not None:
|
||||
if temperature_value == 1:
|
||||
optional_params["temperature"] = temperature_value
|
||||
else:
|
||||
## UNSUPPORTED TOOL CHOICE VALUE
|
||||
if litellm.drop_params is True or drop_params is True:
|
||||
pass
|
||||
else:
|
||||
raise litellm.utils.UnsupportedParamsError(
|
||||
message="O-series models don't support temperature={}. Only temperature=1 is supported. To drop unsupported openai params from the call, set `litellm.drop_params = True`".format(
|
||||
temperature_value
|
||||
),
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
return super()._map_openai_params(
|
||||
non_default_params, optional_params, model, drop_params
|
||||
)
|
||||
|
||||
def is_model_o_series_model(self, model: str) -> bool:
|
||||
model = model.split("/")[-1] # could be "openai/o3" or "o3"
|
||||
return (
|
||||
len(model) > 1
|
||||
and model[0] == "o"
|
||||
and model[1].isdigit()
|
||||
and model in litellm.open_ai_chat_completion_models
|
||||
)
|
||||
|
||||
@overload
|
||||
def _transform_messages(
|
||||
self, messages: List[AllMessageValues], model: str, is_async: Literal[True]
|
||||
) -> Coroutine[Any, Any, List[AllMessageValues]]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def _transform_messages(
|
||||
self,
|
||||
messages: List[AllMessageValues],
|
||||
model: str,
|
||||
is_async: Literal[False] = False,
|
||||
) -> List[AllMessageValues]:
|
||||
...
|
||||
|
||||
def _transform_messages(
|
||||
self, messages: List[AllMessageValues], model: str, is_async: bool = False
|
||||
) -> Union[List[AllMessageValues], Coroutine[Any, Any, List[AllMessageValues]]]:
|
||||
"""
|
||||
Handles limitations of O-1 model family.
|
||||
- modalities: image => drop param (if user opts in to dropping param)
|
||||
- role: system ==> translate to role 'user'
|
||||
"""
|
||||
_supports_system_messages = supports_system_messages(model, "openai")
|
||||
for i, message in enumerate(messages):
|
||||
if message["role"] == "system" and not _supports_system_messages:
|
||||
new_message = ChatCompletionUserMessage(
|
||||
content=message["content"], role="user"
|
||||
)
|
||||
messages[i] = new_message # Replace the old message with the new one
|
||||
|
||||
if is_async:
|
||||
return super()._transform_messages(
|
||||
messages, model, is_async=cast(Literal[True], True)
|
||||
)
|
||||
else:
|
||||
return super()._transform_messages(
|
||||
messages, model, is_async=cast(Literal[False], False)
|
||||
)
|
||||
Reference in New Issue
Block a user