1113 lines
41 KiB
Python
1113 lines
41 KiB
Python
import re
|
|
from copy import deepcopy
|
|
from enum import Enum
|
|
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union, get_type_hints
|
|
|
|
import httpx
|
|
|
|
import litellm
|
|
from litellm._logging import verbose_logger
|
|
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
|
|
from litellm.litellm_core_utils.prompt_templates.common_utils import unpack_defs
|
|
from litellm.llms.base_llm.base_utils import BaseLLMModelInfo, BaseTokenCounter
|
|
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
|
from litellm.types.llms.openai import AllMessageValues
|
|
from litellm.types.llms.vertex_ai import PartType, Schema
|
|
from litellm.types.utils import TokenCountResponse
|
|
from litellm.utils import supports_response_schema, supports_system_messages
|
|
|
|
|
|
class VertexAIError(BaseLLMException):
    """Vertex AI flavoured `BaseLLMException`.

    Adds no behaviour of its own: the status code, message and optional
    headers are forwarded unchanged to the base class constructor.
    """

    def __init__(
        self,
        status_code: int,
        message: str,
        headers: Optional[Union[Dict, httpx.Headers]] = None,
    ):
        base_kwargs = {
            "message": message,
            "status_code": status_code,
            "headers": headers,
        }
        super().__init__(**base_kwargs)
|
|
|
|
|
|
class VertexAIModelRoute(str, Enum):
    """Handler routes a Vertex AI model name can be dispatched to.

    Members are also plain strings (``str`` mixin). Declaration order is
    significant for anything that iterates over the enum.
    """

    # Third-party models hosted on Vertex AI
    PARTNER_MODELS = "partner_models"
    # Gemini request/response format
    GEMINI = "gemini"
    # Model names carrying the "gemma/" prefix
    GEMMA = "gemma"
    # BGE models
    BGE = "bge"
    # Model Garden deployments
    MODEL_GARDEN = "model_garden"
    # Fallback for everything that is not recognised as one of the above
    NON_GEMINI = "non_gemini"
    OPENAI_COMPATIBLE = "openai"
    # Model names carrying the "agent_engine/" prefix
    AGENT_ENGINE = "agent_engine"
|
|
|
|
|
|
# Every route value followed by a slash (e.g. "gemini/"), in enum declaration order.
VERTEX_AI_MODEL_ROUTES = [member.value + "/" for member in VertexAIModelRoute]
|
|
|
|
|
|
def get_vertex_ai_model_route(
    model: str, litellm_params: Optional[dict] = None
) -> VertexAIModelRoute:
    """
    Pick the Vertex AI handler route for a model name.

    Checks run in a fixed order and the first hit wins:

    1. ``litellm_params["base_model"]`` contains "gemini"   -> GEMINI
    2. model contains "agent_engine/"                       -> AGENT_ENGINE
    3. all-digit model AND ``litellm_params["api_base"]``   -> GEMINI
    4. recognised partner model                             -> PARTNER_MODELS
    5. model contains "bge" (case-insensitive)              -> BGE
    6. model contains "gemma/"                              -> GEMMA
    7. model contains "openai"                              -> MODEL_GARDEN
    8. model contains "gemini"                              -> GEMINI
    9. anything else                                        -> NON_GEMINI

    Args:
        model: The model name (e.g., "llama3-405b", "gemini-pro", "gemma/gemma-3-12b-it", "openai/gpt-oss-120b")
        litellm_params: Optional litellm parameters dict; ``base_model`` and
            ``api_base`` are the only keys consulted.

    Returns:
        VertexAIModelRoute: The route enum indicating which handler should be used

    Examples:
        >>> get_vertex_ai_model_route("llama3-405b")
        VertexAIModelRoute.PARTNER_MODELS

        >>> get_vertex_ai_model_route("gemini-pro")
        VertexAIModelRoute.GEMINI

        >>> get_vertex_ai_model_route("gemma/gemma-3-12b-it")
        VertexAIModelRoute.GEMMA

        >>> get_vertex_ai_model_route("openai/gpt-oss-120b")
        VertexAIModelRoute.MODEL_GARDEN

        >>> get_vertex_ai_model_route("1234567890", {"api_base": "http://10.96.32.8"})
        VertexAIModelRoute.GEMINI  # Numeric endpoints with api_base use HTTP path
    """
    from litellm.llms.vertex_ai.vertex_ai_partner_models.main import (
        VertexAIPartnerModels,
    )

    params = litellm_params or {}

    # An explicit gemini base_model overrides everything else.
    base_model = params.get("base_model")
    if base_model is not None and "gemini" in base_model:
        return VertexAIModelRoute.GEMINI

    # Agent Engine (Reasoning Engines).
    if "agent_engine/" in model:
        return VertexAIModelRoute.AGENT_ENGINE

    # Numeric endpoint ID with a custom api_base (PSC endpoint): use the
    # GEMINI (HTTP) path so PSC endpoints are handled properly.
    if model.isdigit() and params.get("api_base"):
        return VertexAIModelRoute.GEMINI

    # Partner models (llama, mistral, claude, etc.).
    if VertexAIPartnerModels.is_vertex_partner_model(model=model):
        return VertexAIModelRoute.PARTNER_MODELS

    # The case-insensitive substring test also covers the "bge/" prefix form.
    if "bge" in model.lower():
        return VertexAIModelRoute.BGE

    # Remaining substring markers, evaluated in this exact order.
    substring_routes = (
        ("gemma/", VertexAIModelRoute.GEMMA),
        ("openai", VertexAIModelRoute.MODEL_GARDEN),
        ("gemini", VertexAIModelRoute.GEMINI),
    )
    for marker, route in substring_routes:
        if marker in model:
            return route

    # Legacy vertex models such as chat-bison, text-bison, etc.
    return VertexAIModelRoute.NON_GEMINI
|
|
|
|
|
|
def get_supports_system_message(
    model: str, custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"]
) -> bool:
    """
    Report whether `model` accepts system messages.

    "vertex_ai_beta" is looked up as "vertex_ai". Models identified as
    gemini-spec models are always treated as supporting system messages.
    Any exception raised during the lookup is logged as a warning and the
    function falls back to ``False``.
    """
    lookup_provider = (
        "vertex_ai" if custom_llm_provider == "vertex_ai_beta" else custom_llm_provider
    )

    try:
        supported = supports_system_messages(
            model=model, custom_llm_provider=lookup_provider
        )
        # Vertex Models called in the `/gemini` request/response format also support system messages
        if litellm.VertexGeminiConfig._is_model_gemini_spec_model(model):
            supported = True
        return supported
    except Exception as e:
        verbose_logger.warning(
            "Unable to identify if system message supported. Defaulting to 'False'. Received error message - {}\nAdd it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json".format(
                str(e)
            )
        )
        return False
|
|
|
|
|
|
def get_supports_response_schema(
    model: str, custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"]
) -> bool:
    """
    Report whether `model` supports a response schema.

    "vertex_ai_beta" is looked up as "vertex_ai"; every other provider value
    is passed through unchanged to `supports_response_schema`.
    """
    lookup_provider = (
        "vertex_ai" if custom_llm_provider == "vertex_ai_beta" else custom_llm_provider
    )
    return supports_response_schema(
        model=model, custom_llm_provider=lookup_provider
    )
|
|
|
|
|
|
# "gemini-<major>" where major >= 2, followed by ".", "-" or end of string.
# The separator alternatives matter: Gemini 2.x names carry a minor version
# ("gemini-2.5-pro") while Gemini 3 names do not ("gemini-3-pro-preview").
_GEMINI_2_PLUS_PATTERN = re.compile(r"gemini-(?:[2-9]|[1-9]\d+)(?:[.\-]|$)")


def supports_response_json_schema(model: str) -> bool:
    """
    Check if the model supports responseJsonSchema (JSON Schema format).

    responseJsonSchema is supported by Gemini 2.0+ models and uses standard
    JSON Schema format with lowercase types (string, object, etc.) instead of
    the OpenAPI-style responseSchema with uppercase types (STRING, OBJECT, etc.).

    Benefits of responseJsonSchema:
    - Supports additionalProperties for stricter schema validation
    - Uses standard JSON Schema format (no type conversion needed)
    - Better compatibility with Pydantic's model_json_schema()

    The check is a case-insensitive search for ``gemini-<major>`` with a major
    version of 2 or higher, so provider prefixes such as ``vertex_ai/`` are
    tolerated. Both dotted names (``gemini-2.0-*``, ``gemini-2.5-*``) and
    dash-only names (``gemini-3-*``) match; ``gemini-1.5-*`` and unversioned
    names such as ``gemini-pro`` do not.

    Args:
        model: The model name (e.g., "gemini-2.0-flash", "gemini-3-pro-preview")

    Returns:
        True if the model supports responseJsonSchema, False otherwise
    """
    return _GEMINI_2_PLUS_PATTERN.search(model.lower()) is not None
|
|
|
|
|
|
from typing import Literal, Optional
|
|
|
|
# Closed set of request modes accepted as the `mode` argument of the URL
# builders in this module.
all_gemini_url_modes = Literal[
    "chat",
    "embedding",
    "batch_embedding",
    "image_generation",
    "count_tokens",
]
|
|
|
|
|
|
def get_vertex_base_model_name(model: str) -> str:
    """
    Strip a leading routing prefix from a model name.

    Prefixes such as "bge/", "gemma/" or "openai/" are used for internal
    routing but must not appear in the endpoint URL. The candidate prefixes
    are the entries of VERTEX_AI_MODEL_ROUTES (each route value followed by
    "/"); the first one the model name starts with is removed, and only once.

    Args:
        model: The model name with potential prefix (e.g., "bge/123456", "gemma/gemma-3-12b-it")

    Returns:
        str: The model name without routing prefix (e.g., "123456", "gemma-3-12b-it")

    Examples:
        >>> get_vertex_base_model_name("bge/378943383978115072")
        "378943383978115072"

        >>> get_vertex_base_model_name("gemma/gemma-3-12b-it")
        "gemma-3-12b-it"

        >>> get_vertex_base_model_name("openai/gpt-oss-120b")
        "gpt-oss-120b"

        >>> get_vertex_base_model_name("1234567890")
        "1234567890"
    """
    matched_prefix = next(
        (prefix for prefix in VERTEX_AI_MODEL_ROUTES if model.startswith(prefix)),
        None,
    )
    if matched_prefix is None:
        return model
    return model[len(matched_prefix):]
|
|
|
|
|
|
def get_vertex_base_url(
    vertex_location: Optional[str],
) -> str:
    """
    Return the scheme + host for Vertex AI API calls.

    The "global" location maps to the un-prefixed host; any other value is
    formatted in front of the host as ``<location>-``. No validation is
    performed, so ``None`` is formatted literally.
    """
    host_prefix = "" if vertex_location == "global" else f"{vertex_location}-"
    return f"https://{host_prefix}aiplatform.googleapis.com"
|
|
|
|
|
|
def _get_embedding_url(
    model: str,
    vertex_project: Optional[str],
    vertex_location: Optional[str],
    vertex_api_version: Literal["v1", "v1beta1"],
) -> Tuple[str, str]:
    """
    Build the URL for an embedding request.

    Handles these patterns:
    - routing prefix (e.g. bge/endpoint_id) -> stripped before the URL is built
    - all-digit model -> ``endpoints/<id>`` using `vertex_api_version`
    - any other model -> ``publishers/google/models/<model>``, always on ``v1``
    - model info flag ``uses_embed_content`` -> ``embedContent`` instead of ``predict``

    The model-info lookup uses the un-stripped model name and is best effort:
    any exception there simply selects the ``predict`` endpoint.

    Returns:
        (url, endpoint) tuple.
    """
    stripped_model = get_vertex_base_model_name(model=model)

    try:
        info = litellm.get_model_info(
            model=model,
            custom_llm_provider="vertex_ai",
        )
        uses_embed_content = info.get("uses_embed_content", False)
    except Exception:
        uses_embed_content = False

    endpoint = "embedContent" if uses_embed_content else "predict"
    base_url = get_vertex_base_url(vertex_location)

    if stripped_model.isdigit():
        api_version, resource = vertex_api_version, f"endpoints/{stripped_model}"
    else:
        api_version, resource = "v1", f"publishers/google/models/{stripped_model}"

    url = (
        f"{base_url}/{api_version}/projects/{vertex_project}"
        f"/locations/{vertex_location}/{resource}:{endpoint}"
    )
    return url, endpoint
|
|
|
|
|
|
def _get_vertex_url(
    mode: all_gemini_url_modes,
    model: str,
    stream: Optional[bool],
    vertex_project: Optional[str],
    vertex_location: Optional[str],
    vertex_api_version: Literal["v1", "v1beta1"],
) -> Tuple[str, str]:
    """
    Build the Vertex AI URL and endpoint name for a request.

    Per mode:
    - "chat": ``generateContent`` (``streamGenerateContent`` plus ``?alt=sse``
      when `stream` is True); an all-digit model is addressed via
      ``endpoints/`` (fine-tuned model), anything else via
      ``publishers/google/models/``.
    - "embedding": delegated to `_get_embedding_url`.
    - "image_generation": ``predict``; all-digit models use ``endpoints/`` with
      `vertex_api_version`, other models use the publisher path on ``v1``.
    - "count_tokens": ``countTokens`` on the publisher path.

    Returns:
        (url, endpoint) tuple.

    Raises:
        ValueError: for any other mode (including "batch_embedding").
    """
    model = litellm.VertexGeminiConfig.get_model_for_vertex_ai_url(model=model)

    if mode == "embedding":
        return _get_embedding_url(
            model=model,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_api_version=vertex_api_version,
        )

    if mode not in ("chat", "image_generation", "count_tokens"):
        raise ValueError(f"Unable to get vertex url/endpoint for mode: {mode}")

    base_url = get_vertex_base_url(vertex_location)
    location_path = f"projects/{vertex_project}/locations/{vertex_location}"
    custom_endpoint = f"endpoints/{model}"
    publisher_model = f"publishers/google/models/{model}"

    if mode == "chat":
        streaming = stream is True
        endpoint = "streamGenerateContent" if streaming else "generateContent"
        # A model made only of digits is a fine-tuned gemini model,
        # e.g. model = 4965075652664360960, and lives under endpoints/.
        resource = custom_endpoint if model.isdigit() else publisher_model
        url = f"{base_url}/{vertex_api_version}/{location_path}/{resource}:{endpoint}"
        if streaming:
            url += "?alt=sse"
    elif mode == "image_generation":
        endpoint = "predict"
        if model.isdigit():
            # Numeric model -> custom endpoint
            url = f"{base_url}/{vertex_api_version}/{location_path}/{custom_endpoint}:{endpoint}"
        else:
            # Regular model -> publisher model (pinned to v1)
            url = f"{base_url}/v1/{location_path}/{publisher_model}:{endpoint}"
    else:  # count_tokens
        endpoint = "countTokens"
        url = f"{base_url}/{vertex_api_version}/{location_path}/{publisher_model}:{endpoint}"

    return url, endpoint
|
|
|
|
|
|
def _get_gemini_url(
    mode: all_gemini_url_modes,
    model: str,
    stream: Optional[bool],
    gemini_api_key: Optional[str],
) -> Tuple[str, str]:
    """
    Build the Google AI Studio (generativelanguage) URL and endpoint name.

    - "chat" uses ``v1alpha`` when `VertexGeminiConfig._is_gemini_3_or_newer`
      says so and ``v1beta`` otherwise; streaming switches the endpoint to
      ``streamGenerateContent`` and appends ``&alt=sse``.
    - "embedding", "batch_embedding" and "count_tokens" always use ``v1beta``.
    - The API key is placed in the ``key`` query parameter.

    Returns:
        (url, endpoint) tuple.

    Raises:
        ValueError: for "image_generation" (not supported on this route) and
            for any unrecognised mode.
    """
    from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
        VertexGeminiConfig,
    )

    host = "https://generativelanguage.googleapis.com"
    model_path = f"models/{model}"
    if VertexGeminiConfig._is_gemini_3_or_newer(model):
        chat_api_version = "v1alpha"
    else:
        chat_api_version = "v1beta"

    if mode == "chat":
        if stream is True:
            endpoint = "streamGenerateContent"
            url = f"{host}/{chat_api_version}/{model_path}:{endpoint}?key={gemini_api_key}&alt=sse"
        else:
            endpoint = "generateContent"
            url = f"{host}/{chat_api_version}/{model_path}:{endpoint}?key={gemini_api_key}"
        return url, endpoint

    v1beta_endpoints = {
        "embedding": "embedContent",
        "batch_embedding": "batchEmbedContents",
        "count_tokens": "countTokens",
    }
    if mode in v1beta_endpoints:
        endpoint = v1beta_endpoints[mode]
        url = f"{host}/v1beta/{model_path}:{endpoint}?key={gemini_api_key}"
        return url, endpoint

    if mode == "image_generation":
        raise ValueError(
            "LiteLLM's `gemini/` route does not support image generation yet. Let us know if you need this feature by opening an issue at https://github.com/BerriAI/litellm/issues"
        )

    raise ValueError(f"Unsupported mode: {mode}")
|
|
|
|
|
|
def _check_text_in_content(parts: List[PartType]) -> bool:
    """
    Return True when at least one part carries a non-None 'text' value.

    - Known Vertex Error: Unable to submit request because it must have a text parameter.
    - The 'text' key only has to be present and not None; empty strings count.
    - Relevant Issue: https://github.com/BerriAI/litellm/issues/5515
    """
    return any(part.get("text") is not None for part in parts)
|
|
|
|
|
|
def _fix_enum_empty_strings(schema, depth=0):
    """Replace empty strings in enum values with None, in place.

    Gemini doesn't accept empty strings in enums. The schema is walked through
    `properties`, `items` and `anyOf` (the same structures `_fix_enum_types`
    walks), so enums nested inside optional (`anyOf`) fields are fixed too.
    Nodes that are not dicts are ignored.

    Args:
        schema: The schema node to fix; mutated in place.
        depth: Current recursion depth.

    Raises:
        ValueError: If nesting exceeds DEFAULT_MAX_RECURSE_DEPTH.
    """
    if depth > DEFAULT_MAX_RECURSE_DEPTH:
        raise ValueError(
            f"Max depth of {DEFAULT_MAX_RECURSE_DEPTH} exceeded while processing schema."
        )

    # Property values / items may legitimately be non-dict (e.g. booleans);
    # there is nothing to fix on those and `.get` would fail on them.
    if not isinstance(schema, dict):
        return

    enum_values = schema.get("enum")
    if isinstance(enum_values, list):
        schema["enum"] = [None if value == "" else value for value in enum_values]

    properties = schema.get("properties")
    if isinstance(properties, dict):
        for value in properties.values():
            _fix_enum_empty_strings(value, depth=depth + 1)

    items = schema.get("items")
    if items is not None:
        _fix_enum_empty_strings(items, depth=depth + 1)

    anyof = schema.get("anyOf")
    if isinstance(anyof, list):
        for member in anyof:
            _fix_enum_empty_strings(member, depth=depth + 1)
|
|
|
|
|
|
def _fix_enum_types(schema, depth=0):
    """Drop `enum` from every schema node that is not string-typed, in place.

    Gemini / Vertex APIs only allow enums for string-typed fields. An enum is
    kept when the node's own `type` is "string" (case-insensitive) or when one
    of its `anyOf` members has that type; otherwise it is removed to avoid
    provider validation errors. Recurses through `properties`, `items` and
    `anyOf`. Non-dict nodes are ignored.

    Raises:
        ValueError: If nesting exceeds DEFAULT_MAX_RECURSE_DEPTH.
    """
    if depth > DEFAULT_MAX_RECURSE_DEPTH:
        raise ValueError(
            f"Max depth of {DEFAULT_MAX_RECURSE_DEPTH} exceeded while processing schema."
        )

    if not isinstance(schema, dict):
        return

    def _is_string_type(type_value) -> bool:
        return isinstance(type_value, str) and type_value.lower() == "string"

    raw_any_of = schema.get("anyOf")
    any_of_members = raw_any_of if isinstance(raw_any_of, list) else []

    if isinstance(schema.get("enum"), list):
        string_typed = _is_string_type(schema.get("type")) or any(
            isinstance(member, dict) and _is_string_type(member.get("type"))
            for member in any_of_members
        )
        if not string_typed:
            schema.pop("enum", None)

    # Recurse into nested structures
    nested_properties = schema.get("properties", None)
    if nested_properties is not None:
        for child in nested_properties.values():
            _fix_enum_types(child, depth=depth + 1)

    nested_items = schema.get("items", None)
    if nested_items is not None:
        _fix_enum_types(nested_items, depth=depth + 1)

    for member in any_of_members:
        if isinstance(member, dict):
            _fix_enum_types(member, depth=depth + 1)
|
|
|
|
|
|
def _build_vertex_schema(parameters: dict, add_property_ordering: bool = False):
    """
    This is a modified version of https://github.com/google-gemini/generative-ai-python/blob/8f77cc6ac99937cd3a81299ecf79608b91b06bbb/google/generativeai/types/content_types.py#L419

    Turns a JSON schema into the shape used for Vertex AI: unwinds $defs,
    converts nullable/typed fields, repairs enums and empty items, drops
    fields that are not part of `Schema`, and optionally adds propertyOrdering.

    Parameters:
        parameters: dict - the json schema to build from. It is mutated in
            place by the normalisation passes ("$defs" is popped from it).
        add_property_ordering: bool - whether to add propertyOrdering to the
            schema. Only applicable to schemas for structured outputs. See
            set_schema_property_ordering for more details.

    Returns:
        The result of `filter_schema_fields` applied to the normalised
        parameters.
    """
    # Fields declared on the Schema TypedDict are the only ones kept.
    allowed_fields = set(get_type_hints(Schema))

    # Expand $ref references using the definitions.
    # Note: defs are not pre-flattened as that causes exponential memory growth
    # with circular references (see issue #19098). unpack_defs handles nested
    # refs recursively and correctly detects/skips circular references.
    definitions = parameters.pop("$defs", {})
    unpack_defs(parameters, definitions)

    # In-place normalisation passes; the order is significant.
    in_place_passes = (
        # Nullable fields:
        #   * https://github.com/pydantic/pydantic/issues/1270
        #   * https://stackoverflow.com/a/58841311
        #   * https://github.com/pydantic/pydantic/discussions/4872
        convert_anyof_null_to_nullable,
        _convert_schema_types,
        # Gemini doesn't accept empty strings in enums
        _fix_enum_empty_strings,
        # Gemini requires enum only on strings
        _fix_enum_types,
        # Handle empty items objects
        process_items,
        add_object_type,
    )
    for apply_pass in in_place_passes:
        apply_pass(parameters)

    # Postprocessing: filter out fields that don't exist in Schema
    filtered = filter_schema_fields(parameters, allowed_fields)

    if add_property_ordering:
        set_schema_property_ordering(filtered)

    return filtered
|
|
|
|
|
|
def _build_json_schema(parameters: dict) -> dict:
|
|
"""
|
|
Build a JSON Schema for use with Gemini's responseJsonSchema parameter.
|
|
|
|
Unlike _build_vertex_schema (used for responseSchema), this function:
|
|
- Does NOT convert types to uppercase (keeps standard JSON Schema format)
|
|
- Does NOT add propertyOrdering
|
|
- Does NOT filter fields (allows additionalProperties)
|
|
- Preserves $defs/$ref (Gemini 2.0+ supports JSON Schema references natively)
|
|
|
|
Parameters:
|
|
parameters: dict - the JSON schema to process
|
|
|
|
Returns:
|
|
dict - the processed schema in standard JSON Schema format
|
|
"""
|
|
# Gemini 2.0+ with responseJsonSchema accepts standard JSON Schema as-is,
|
|
# including $ref, $defs, anyOf, etc. No transformations needed — the
|
|
# OpenAPI-specific fixes (unpack_defs, add_object_type, convert_anyof, etc.)
|
|
# are only required for responseSchema (Gemini 1.5) and can break valid
|
|
# JSON Schema by adding conflicting fields to $ref nodes.
|
|
# See: https://blog.google/technology/developers/gemini-api-structured-outputs/
|
|
|
|
return parameters
|
|
|
|
|
|
def _filter_anyof_fields(schema_dict: Dict[str, Any]) -> Dict[str, Any]:
|
|
"""
|
|
When anyof is present, only keep the anyof field and its contents - otherwise VertexAI will throw an error - https://github.com/BerriAI/litellm/issues/11164
|
|
Filter out other fields in the same dict.
|
|
|
|
E.g. {"anyOf": [{"type": "string"}, {"type": "null"}], "default": "test"} -> {"anyOf": [{"type": "string"}, {"type": "null"}]}
|
|
|
|
Case 2: If additional metadata is present, try to keep it
|
|
E.g. {"anyOf": [{"type": "string"}, {"type": "null"}], "default": "test", "title": "test"} -> {"anyOf": [{"type": "string", "title": "test"}, {"type": "null", "title": "test"}]}
|
|
"""
|
|
title = schema_dict.get("title", None)
|
|
description = schema_dict.get("description", None)
|
|
|
|
if isinstance(schema_dict, dict) and schema_dict.get("anyOf"):
|
|
any_of = schema_dict["anyOf"]
|
|
if (
|
|
(title or description)
|
|
and isinstance(any_of, list)
|
|
and all(isinstance(item, dict) for item in any_of)
|
|
):
|
|
for item in any_of:
|
|
if title:
|
|
item["title"] = title
|
|
if description:
|
|
item["description"] = description
|
|
return {"anyOf": any_of}
|
|
return schema_dict
|
|
|
|
|
|
def process_items(schema, depth=0):
    """
    Replace every empty `items` object (``{}``) with ``{"type": "object"}``.

    Walks all dict values, and all dicts found directly inside list values,
    mutating the schema in place. Non-dict input is ignored.

    Raises:
        ValueError: If nesting exceeds DEFAULT_MAX_RECURSE_DEPTH.
    """
    if depth > DEFAULT_MAX_RECURSE_DEPTH:
        raise ValueError(
            f"Max depth of {DEFAULT_MAX_RECURSE_DEPTH} exceeded while processing schema. Please check the schema for excessive nesting."
        )

    if not isinstance(schema, dict):
        return

    if schema.get("items") == {}:
        schema["items"] = {"type": "object"}

    for child in schema.values():
        if isinstance(child, dict):
            process_items(child, depth + 1)
        elif isinstance(child, list):
            for element in child:
                if isinstance(element, dict):
                    process_items(element, depth + 1)
|
|
|
|
|
|
def set_schema_property_ordering(
    schema: Dict[str, Any], depth: int = 0
) -> Dict[str, Any]:
    """
    Record the declaration order of `properties` in `propertyOrdering`.

    vertex ai and generativeai apis order output of fields alphabetically, unless you specify the order.
    python dicts retain order, so we just use that. Note that this field only applies to structured outputs, and not tools.
    Function tools are not afflicted by the same alphabetical ordering issue, (the order of keys returned seems to be arbitrary, up to the model)
    https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.cachedContents#Schema.FIELDS.property_ordering

    A `propertyOrdering` already present on a node is left untouched.
    Recurses through `properties` and `items`.

    Args:
        schema: The schema dictionary to process (mutated in place)
        depth: Current recursion depth to prevent infinite loops

    Returns:
        The same `schema` object.

    Raises:
        ValueError: If nesting exceeds DEFAULT_MAX_RECURSE_DEPTH.
    """
    if depth > DEFAULT_MAX_RECURSE_DEPTH:
        raise ValueError(
            f"Max depth of {DEFAULT_MAX_RECURSE_DEPTH} exceeded while processing schema. Please check the schema for excessive nesting."
        )

    if "properties" in schema:
        props = schema["properties"]
        if isinstance(props, dict):
            # retain propertyOrdering as an escape hatch if user already specifies it
            if "propertyOrdering" not in schema:
                schema["propertyOrdering"] = list(props)
            for child in props.values():
                set_schema_property_ordering(child, depth + 1)

    if "items" in schema:
        set_schema_property_ordering(schema["items"], depth + 1)

    return schema
|
|
|
|
|
|
def filter_schema_fields(
    schema_dict: Dict[str, Any], valid_fields: Set[str], processed=None
) -> Dict[str, Any]:
    """
    Recursively copy a schema, keeping only keys listed in `valid_fields`.

    - `properties`, `items` and `anyOf` are filtered recursively.
    - `format` is kept only for the values "enum" and "date-time".
    - Nodes are first passed through `_filter_anyof_fields`.
    - `processed` holds the ids of nodes already visited; a node seen again
      (circular reference) is returned as-is rather than filtered again.
    - Non-dict input is returned unchanged.
    """
    seen = set() if processed is None else processed

    # Handle circular references
    node_id = id(schema_dict)
    if node_id in seen:
        return schema_dict
    seen.add(node_id)

    if not isinstance(schema_dict, dict):
        return schema_dict

    def _descend(node):
        return filter_schema_fields(node, valid_fields, seen)

    filtered: Dict[str, Any] = {}
    for key, value in _filter_anyof_fields(schema_dict).items():
        if key not in valid_fields:
            continue

        if key == "format":
            if value in {"enum", "date-time"}:
                filtered[key] = value
        elif key == "properties" and isinstance(value, dict):
            filtered[key] = {name: _descend(sub) for name, sub in value.items()}
        elif key == "items" and isinstance(value, dict):
            filtered[key] = _descend(value)
        elif key == "anyOf" and isinstance(value, list):
            filtered[key] = [_descend(member) for member in value]
        else:
            filtered[key] = value

    return filtered
|
|
|
|
|
|
def convert_anyof_null_to_nullable(schema, depth=0):
    """
    Convert null members within anyOf by removing them and marking all
    remaining members as nullable. The schema is mutated in place.

    For every `anyOf` list reached through `properties` / `items`:
    - all ``{"type": "null"}`` members are removed (including adjacent or
      repeated ones);
    - empty-dict members become ``{"type": "object"}``;
    - if a null member was removed, each remaining dict member gets
      ``nullable: True``, and an array member with an empty `items` has that
      `items` key dropped.

    Non-dict nodes are ignored.

    Raises:
        ValueError: If nesting exceeds DEFAULT_MAX_RECURSE_DEPTH, or if an
            `anyOf` has no members left once null members are removed.
    """
    if depth > DEFAULT_MAX_RECURSE_DEPTH:
        raise ValueError(
            f"Max depth of {DEFAULT_MAX_RECURSE_DEPTH} exceeded while processing schema. Please check the schema for excessive nesting."
        )

    if not isinstance(schema, dict):
        return

    anyof = schema.get("anyOf", None)
    if isinstance(anyof, list):
        # Build the filtered list first: removing from a list while iterating
        # over it skips the element that follows each removal.
        non_null_members = [
            member
            for member in anyof
            if not (isinstance(member, dict) and member.get("type") == "null")
        ]
        contains_null = len(non_null_members) != len(anyof)
        # Slice-assign so the existing list object is updated in place.
        anyof[:] = non_null_members

        for member in anyof:
            if isinstance(member, dict) and not member:
                # Handle empty object case
                member["type"] = "object"

        if len(anyof) == 0:
            # Edge case: response schema with only null type present is invalid in Vertex AI
            raise ValueError(
                "Invalid input: AnyOf schema with only null type is not supported. "
                "Please provide a non-null type."
            )

        if contains_null:
            # set all types to nullable following guidance found here: https://cloud.google.com/vertex-ai/generative-ai/docs/samples/generativeaionvertexai-gemini-controlled-generation-response-schema-3#generativeaionvertexai_gemini_controlled_generation_response_schema_3-python
            for member in anyof:
                if not isinstance(member, dict):
                    continue
                # Remove items field if type is array and items is empty
                if (
                    member.get("type") == "array"
                    and "items" in member
                    and not member["items"]
                ):
                    member.pop("items")
                member["nullable"] = True

    properties = schema.get("properties", None)
    if isinstance(properties, dict):
        for value in properties.values():
            convert_anyof_null_to_nullable(value, depth=depth + 1)

    items = schema.get("items", None)
    if items is not None:
        convert_anyof_null_to_nullable(items, depth=depth + 1)
|
|
|
|
|
|
def add_object_type(schema):
    """
    Make sure object-like schema nodes carry ``"type": "object"``, in place.

    Gemini requires all function parameters to be type OBJECT.

    - A node with no `type` and no anyOf/oneOf/allOf gets ``type: object``
      (covers tools with no arguments).
    - A node with `properties` is always typed ``object``. A `required` of
      None is removed. Gemini doesn't accept empty properties for object
      types, so an empty `properties` is removed together with `required`.
    - Recurses through `properties`, `items` and the dict members of
      anyOf/oneOf/allOf.
    """
    combinators = ("anyOf", "oneOf", "allOf")

    if "type" not in schema and not any(name in schema for name in combinators):
        schema["type"] = "object"

    props = schema.get("properties", None)
    if props is not None:
        schema["type"] = "object"
        if props:
            if "required" in schema and schema["required"] is None:
                schema.pop("required", None)
            for child in props.values():
                add_object_type(child)
        else:
            # Empty properties: drop them but keep the object type.
            schema.pop("properties", None)
            schema.pop("required", None)

    item_schema = schema.get("items", None)
    if item_schema is not None:
        add_object_type(item_schema)

    for name in combinators:
        members = schema.get(name, None)
        if isinstance(members, list):
            for member in members:
                if isinstance(member, dict):
                    add_object_type(member)
|
|
|
|
|
|
def strip_field(schema, field_name: str):
    """
    Remove `field_name` from a schema node and from every node reachable
    through `properties` and `items`. The schema is mutated in place.
    """
    schema.pop(field_name, None)

    nested_properties = schema.get("properties", None)
    if nested_properties is not None:
        for child in nested_properties.values():
            strip_field(child, field_name)

    nested_items = schema.get("items", None)
    if nested_items is not None:
        strip_field(nested_items, field_name)
|
|
|
|
|
|
def _convert_vertex_datetime_to_openai_datetime(vertex_datetime: str) -> int:
|
|
"""
|
|
Converts a Vertex AI datetime string to an OpenAI datetime integer
|
|
|
|
vertex_datetime: str = "2024-12-04T21:53:12.120184Z"
|
|
returns: int = 1722729192
|
|
"""
|
|
from datetime import datetime
|
|
|
|
# Parse the ISO format string to datetime object
|
|
dt = datetime.strptime(vertex_datetime, "%Y-%m-%dT%H:%M:%S.%fZ")
|
|
# Convert to Unix timestamp (seconds since epoch)
|
|
return int(dt.timestamp())
|
|
|
|
|
|
def _convert_schema_types(schema, depth=0):
    """
    Normalise `type` arrays for Vertex AI compatibility, in place.

    Transforms OpenAI-style schemas by converting multi-entry type arrays
    like ["string", "number"] to anyOf format, and single-entry arrays like
    ["string"] to the plain type. Type names are NOT re-cased here.

    When a type array is expanded:
    - each string entry becomes one anyOf member (non-string entries are
      skipped); "null" stays a minimal ``{"type": "null"}`` member so it can
      be stripped later;
    - "object"/"array" members receive deep copies of the object/array
      specific fields (properties, required, items, ...), which are then
      removed from the parent;
    - the parent's `type` is removed and any existing `anyOf` is replaced.

    Recurses through `properties`, `items` and `anyOf`. Non-dict nodes are
    ignored.

    Raises:
        ValueError: If nesting exceeds DEFAULT_MAX_RECURSE_DEPTH.
    """
    if depth > DEFAULT_MAX_RECURSE_DEPTH:
        raise ValueError(
            f"Max depth of {DEFAULT_MAX_RECURSE_DEPTH} exceeded while processing schema. Please check the schema for excessive nesting."
        )

    if not isinstance(schema, dict):
        return

    type_val = schema.get("type")
    if isinstance(type_val, list) and len(type_val) > 1:
        # Fields that are specific to object/array types and should move into
        # anyOf. A tuple (not a set) so the copied keys always come out in the
        # same order; set iteration order of strings varies between processes.
        type_specific_fields = (
            "properties",
            "required",
            "additionalProperties",
            "items",
            "minItems",
            "maxItems",
            "minProperties",
            "maxProperties",
        )

        any_of: List[Dict[str, Any]] = []
        moved_fields = False
        for t in type_val:
            if not isinstance(t, str):
                continue
            if t in ("object", "array"):
                # Carry the type-specific fields along with this member.
                item_schema: Dict[str, Any] = {"type": t}
                for field in type_specific_fields:
                    if field in schema:
                        item_schema[field] = deepcopy(schema[field])
                any_of.append(item_schema)
                moved_fields = True
            else:
                # Primitive types (and "null") only carry the type itself.
                any_of.append({"type": t})

        # Remove type-specific fields from parent if we moved them into anyOf
        if moved_fields:
            for field in type_specific_fields:
                schema.pop(field, None)

        schema["anyOf"] = any_of
        schema.pop("type")
    elif isinstance(type_val, list) and len(type_val) == 1:
        schema["type"] = type_val[0]

    # Recursively process nested properties, items, and anyOf
    properties = schema.get("properties")
    if isinstance(properties, dict):
        for prop_schema in properties.values():
            _convert_schema_types(prop_schema, depth + 1)

    if "items" in schema:
        _convert_schema_types(schema["items"], depth + 1)

    nested_any_of = schema.get("anyOf")
    if isinstance(nested_any_of, list):
        for anyof_schema in nested_any_of:
            _convert_schema_types(anyof_schema, depth + 1)
|
|
|
|
|
|
def get_vertex_project_id_from_url(url: str) -> Optional[str]:
    """
    Extract the project id (the path segment after ``/projects/``) from a url.

    `https://${LOCATION}-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/${LOCATION}/publishers/google/models/${MODEL_ID}:streamGenerateContent`

    Returns None when the url has no ``/projects/<id>`` segment.
    """
    found = re.search(r"/projects/([^/]+)", url)
    if found is None:
        return None
    return found.group(1)
|
|
|
|
|
|
def get_vertex_location_from_url(url: str) -> Optional[str]:
    """
    Extract the location (the path segment after ``/locations/``) from a url.

    `https://${LOCATION}-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/${LOCATION}/publishers/google/models/${MODEL_ID}:streamGenerateContent`

    Returns None when the url has no ``/locations/<name>`` segment.
    """
    found = re.search(r"/locations/([^/]+)", url)
    if found is None:
        return None
    return found.group(1)
|
|
|
|
|
|
def get_vertex_model_id_from_url(url: str) -> Optional[str]:
    """
    Extract the Vertex AI model id from a request url.

    Expected url shape:
    `https://${LOCATION}-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/${LOCATION}/publishers/google/models/${MODEL_ID}:streamGenerateContent`

    The model id is everything after `/models/` up to the `:method` suffix.
    A query string (`?`) or fragment (`#`) also terminates the id, so a url
    that has no `:method` suffix but does carry query parameters
    (e.g. `.../models/gemini-pro?view=FULL`) still yields a clean id.

    Returns:
        The model id, or None when the url contains no `/models/` segment.
    """
    # `:` ends the id before the method verb; `?` / `#` end it before any
    # query string or fragment when no verb is present.
    model_match = re.search(r"/models/([^:?#]+)", url)
    return model_match.group(1) if model_match else None
|
|
|
|
|
|
def replace_project_and_location_in_route(
    requested_route: str, vertex_project: str, vertex_location: str
) -> str:
    """
    Replace project and location values in the route with the provided values.

    Every `/projects/<id>/locations/<id>/` span in `requested_route` is
    rewritten to use `vertex_project` and `vertex_location`; the rest of the
    route is left untouched. A route without such a span is returned as-is.

    Args:
        requested_route: The incoming route path.
        vertex_project: Project id to place in the route.
        vertex_location: Location to place in the route.

    Returns:
        The route with project and location substituted.
    """
    replacement = f"/projects/{vertex_project}/locations/{vertex_location}/"

    # Pass a callable rather than a template string: `re.sub` interprets
    # backslash escapes and group references in a string replacement, which
    # would corrupt (or raise on) values that contain a backslash. A callable's
    # return value is inserted literally.
    return re.sub(
        r"/projects/[^/]+/locations/[^/]+/",
        lambda _match: replacement,
        requested_route,
    )
|
|
|
|
|
|
def construct_target_url(
    base_url: str,
    requested_route: str,
    vertex_location: Optional[str],
    vertex_project: Optional[str],
) -> httpx.URL:
    """
    Build the target Vertex AI url for a requested route.

    Lets the caller specify their own project id / location inside the route;
    when the route does not carry them, the supplied `vertex_project` and
    `vertex_location` are used to build the path.

    Handles the cachedContent scenario - https://github.com/BerriAI/litellm/issues/5460

    Constructed Url:
    POST https://LOCATION-aiplatform.googleapis.com/{version}/projects/PROJECT_ID/locations/LOCATION/cachedContents

    Args:
        base_url: Base url whose path is replaced with the constructed route.
        requested_route: Incoming route path.
        vertex_location: Location to use in the constructed path.
        vertex_project: Project id to use in the constructed path.

    Returns:
        `base_url` with its path replaced by the resolved route.
    """
    target_base = httpx.URL(base_url)

    # The route already names a target project id + location: only swap in the
    # configured values when both are available, then forward it as-is.
    if "locations" in requested_route:
        resolved_route = requested_route
        if vertex_project and vertex_location:
            resolved_route = replace_project_and_location_in_route(
                requested_route, vertex_project, vertex_location
            )
        return target_base.copy_with(path=resolved_route)

    # Otherwise the full path has to be assembled here:
    # - add endpoint version (e.g. v1beta1 for cachedContent, v1 for the rest)
    # - add the project id
    # - add the location
    vertex_version: Literal["v1", "v1beta1"] = (
        "v1beta1" if "cachedContent" in requested_route else "v1"
    )

    # An explicit version prefix on the route wins over the default above and
    # is stripped so it is not duplicated in the final path,
    # e.g. /v1beta1/publishers/google/models/gemini-3-pro-preview:streamGenerateContent
    if requested_route.startswith("/v1/"):
        vertex_version = "v1"
        requested_route = requested_route[len("/v1") :]
    elif requested_route.startswith("/v1beta1/"):
        vertex_version = "v1beta1"
        requested_route = requested_route[len("/v1beta1") :]

    full_path = (
        f"/{vertex_version}/projects/{vertex_project}"
        f"/locations/{vertex_location}{requested_route}"
    )
    return target_base.copy_with(path=full_path)
|
|
|
|
|
|
class VertexAIModelInfo(BaseLLMModelInfo):
    """
    Model-info entry point for Vertex AI.

    Only token counting is wired up here; every other method raises
    NotImplementedError.
    """

    def get_token_counter(self) -> Optional[BaseTokenCounter]:
        """
        Create the token counter used for Vertex AI.

        Returns:
            A new `VertexAITokenCounter` instance.
        """
        token_counter = VertexAITokenCounter()
        return token_counter

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """
        Not implemented for Vertex AI.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Vertex AI models are not supported yet")

    def get_models(
        self, api_key: Optional[str] = None, api_base: Optional[str] = None
    ) -> List[str]:
        """
        Not implemented for Vertex AI.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Vertex AI models are not supported yet")

    @staticmethod
    def get_api_key(api_key: Optional[str] = None) -> Optional[str]:
        """
        Not implemented for Vertex AI.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Vertex AI models are not supported yet")

    @staticmethod
    def get_api_base(
        api_base: Optional[str] = None,
    ) -> Optional[str]:
        """
        Not implemented for Vertex AI.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Vertex AI models are not supported yet")

    @staticmethod
    def get_base_model(model: str) -> Optional[str]:
        """
        Not implemented for Vertex AI, regardless of the `model` passed in.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Vertex AI models are not supported yet")
|
|
|
|
|
|
class VertexAITokenCounter(BaseTokenCounter):
    """
    Token counter implementation for the Vertex AI provider.

    Models recognised by `VertexAIPartnerModels.is_vertex_partner_model` are
    counted through the partner-models handler; all other models go through
    the Vertex AI count-tokens handler.
    """

    def should_use_token_counting_api(
        self,
        custom_llm_provider: Optional[str] = None,
    ) -> bool:
        """Return True only when `custom_llm_provider` is the Vertex AI provider."""
        from litellm.types.utils import LlmProviders

        return custom_llm_provider == LlmProviders.VERTEX_AI.value

    async def count_tokens(
        self,
        model_to_use: str,
        messages: Optional[List[Dict[str, Any]]],
        contents: Optional[List[Dict[str, Any]]],
        deployment: Optional[Dict[str, Any]] = None,
        request_model: str = "",
        tools: Optional[List[Dict[str, Any]]] = None,
        system: Optional[Any] = None,
    ) -> Optional[TokenCountResponse]:
        """
        Count input tokens for a Vertex AI model.

        Args:
            model_to_use: Model whose tokens are counted; also decides which
                handler is used.
            messages: Chat messages. Only forwarded on the partner-model path
                (an empty list is sent when None).
            contents: Request contents. Only forwarded on the non-partner path.
            deployment: Deployment dict; its `litellm_params` (if any) are
                copied and passed on to the handler.
            request_model: Echoed back in the response as `request_model`.
            tools: Accepted for interface compatibility; not used here.
            system: Accepted for interface compatibility; not used here.

        Returns:
            A `TokenCountResponse`, or None when the handler returns None.
        """
        from litellm.llms.vertex_ai.vertex_ai_partner_models.main import (
            VertexAIPartnerModels,
        )

        deployment = deployment or {}
        # `or {}` also covers a deployment whose "litellm_params" key is
        # present but None, which would otherwise break the `.get` calls below.
        # Deep copy so the caller's deployment is never mutated.
        count_tokens_params_request = deepcopy(deployment.get("litellm_params") or {})

        # Check if this is a partner model (Claude, Mistral, etc.)
        if VertexAIPartnerModels.is_vertex_partner_model(model_to_use):
            # Use partner models token counter
            partner_models_handler = VertexAIPartnerModels()

            # Extract vertex-specific params from litellm_params; both the
            # `vertex_*` and `vertex_ai_*` spellings are accepted.
            vertex_project = count_tokens_params_request.get(
                "vertex_project"
            ) or count_tokens_params_request.get("vertex_ai_project")

            vertex_location = count_tokens_params_request.get(
                "vertex_location"
            ) or count_tokens_params_request.get("vertex_ai_location")

            # Count tokens not available on global location: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/partner-models/claude/count-tokens
            # so a dedicated count-tokens location, when set, takes precedence.
            vertex_location = (
                count_tokens_params_request.get("vertex_count_tokens_location")
                or vertex_location
            )

            vertex_credentials = count_tokens_params_request.get(
                "vertex_credentials"
            ) or count_tokens_params_request.get("vertex_ai_credentials")

            result = await partner_models_handler.count_tokens(
                model=model_to_use,
                messages=messages or [],
                litellm_params=count_tokens_params_request,
                vertex_project=vertex_project,
                vertex_location=vertex_location,
                vertex_credentials=vertex_credentials,
            )

            if result is not None:
                return TokenCountResponse(
                    total_tokens=result.get("input_tokens", 0),
                    request_model=request_model,
                    model_used=model_to_use,
                    tokenizer_type=result.get("tokenizer_used", ""),
                    original_response=result,
                )
        else:
            # Use standard Vertex AI (Gemini) token counter.
            # Aliased on import: the handler class has the same name as this
            # class and would otherwise shadow it inside this method.
            from litellm.llms.vertex_ai.count_tokens.handler import (
                VertexAITokenCounter as VertexAICountTokensHandler,
            )

            # Explicit model/contents override any same-named litellm_params.
            count_tokens_params_request.update(
                {
                    "model": model_to_use,
                    "contents": contents,
                }
            )
            result = await VertexAICountTokensHandler().acount_tokens(
                **count_tokens_params_request,
            )

            if result is not None:
                return TokenCountResponse(
                    total_tokens=result.get("totalTokens", 0),
                    request_model=request_model,
                    model_used=model_to_use,
                    tokenizer_type=result.get("tokenizer_used", ""),
                    original_response=result,
                )

        return None
|