chore: initial public snapshot for github upload

This commit is contained in:
Your Name
2026-03-26 20:06:14 +08:00
commit 0e5ecd930e
3497 changed files with 1586236 additions and 0 deletions

View File

@@ -0,0 +1,12 @@
"""
Vertex AI Agent Engine (Reasoning Engines) Provider
Supports Vertex AI Reasoning Engines via the :query and :streamQuery endpoints.
"""
from litellm.llms.vertex_ai.agent_engine.transformation import (
VertexAgentEngineConfig,
VertexAgentEngineError,
)
__all__ = ["VertexAgentEngineConfig", "VertexAgentEngineError"]

View File

@@ -0,0 +1,90 @@
"""
SSE Stream Iterator for Vertex AI Agent Engine.
Handles Server-Sent Events (SSE) streaming responses from Vertex AI Reasoning Engines.
"""
from typing import Any, Union
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.types.llms.openai import ChatCompletionUsageBlock
from litellm.types.utils import (
Delta,
GenericStreamingChunk,
ModelResponseStream,
StreamingChoices,
)
class VertexAgentEngineResponseIterator(BaseModelResponseIterator):
    """
    Streaming iterator for Vertex Agent Engine SSE responses.

    BaseModelResponseIterator supplies the sync/async iteration machinery;
    this subclass only implements ``chunk_parser`` to translate one Vertex
    Agent Engine JSON chunk into an OpenAI-compatible stream chunk.
    """

    def __init__(self, streaming_response: Any, sync_stream: bool) -> None:
        super().__init__(streaming_response=streaming_response, sync_stream=sync_stream)

    def chunk_parser(
        self, chunk: dict
    ) -> Union[GenericStreamingChunk, ModelResponseStream]:
        """
        Convert a single Vertex Agent Engine chunk into a ModelResponseStream.

        Expected chunk shape::

            {
                "content": {"parts": [{"text": "..."}], "role": "model"},
                "finish_reason": "STOP",
                "usage_metadata": {
                    "prompt_token_count": 100,
                    "candidates_token_count": 50,
                    "total_token_count": 150
                }
            }
        """
        # Delta text: the first part in content.parts that carries a "text" key.
        text = next(
            (
                part["text"]
                for part in chunk.get("content", {}).get("parts", [])
                if isinstance(part, dict) and "text" in part
            ),
            None,
        )

        # Normalize Vertex finish reasons to OpenAI-style lowercase values.
        raw_finish_reason = chunk.get("finish_reason")
        if not raw_finish_reason:
            finish_reason = None
        elif raw_finish_reason == "STOP":
            finish_reason = "stop"
        else:
            finish_reason = raw_finish_reason.lower()

        # Token accounting is only attached when the chunk carries usage_metadata.
        usage_metadata = chunk.get("usage_metadata", {})
        usage = (
            ChatCompletionUsageBlock(
                prompt_tokens=usage_metadata.get("prompt_token_count", 0),
                completion_tokens=usage_metadata.get("candidates_token_count", 0),
                total_tokens=usage_metadata.get("total_token_count", 0),
            )
            if usage_metadata
            else None
        )

        # Emit an OpenAI-compatible streaming chunk; role is only set on
        # chunks that actually carry text.
        delta = Delta(
            content=text,
            role="assistant" if text else None,
        )
        return ModelResponseStream(
            choices=[
                StreamingChoices(finish_reason=finish_reason, index=0, delta=delta)
            ],
            usage=usage,
        )

View File

@@ -0,0 +1,517 @@
"""
Transformation for Vertex AI Agent Engine (Reasoning Engines)
Handles the transformation between LiteLLM's OpenAI-compatible format and
Vertex AI Reasoning Engine's API format.
API Reference:
- :query endpoint - for session management (create, get, list, delete)
- :streamQuery endpoint - for actual queries (stream_query method)
"""
import json
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
import httpx
from litellm._logging import verbose_logger
from litellm._uuid import uuid
from litellm.litellm_core_utils.prompt_templates.common_utils import (
convert_content_list_to_str,
)
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.llms.vertex_ai.agent_engine.sse_iterator import (
VertexAgentEngineResponseIterator,
)
from litellm.llms.vertex_ai.common_utils import get_vertex_base_url
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import Choices, Message, ModelResponse, Usage
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.utils import CustomStreamWrapper
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
HTTPHandler = Any
AsyncHTTPHandler = Any
CustomStreamWrapper = Any
class VertexAgentEngineError(BaseLLMException):
    """Raised for errors returned by the Vertex Agent Engine API."""

    def __init__(self, status_code: int, message: str):
        # Mirror the values onto the instance for callers that read them
        # directly, then delegate to the shared LLM exception base.
        self.status_code, self.message = status_code, message
        super().__init__(message=message, status_code=status_code)
class VertexAgentEngineConfig(BaseConfig, VertexBase):
    """
    Configuration for Vertex AI Agent Engine (Reasoning Engines).
    Model format: vertex_ai/agent_engine/<resource_id>
    Where resource_id is the numeric ID of the reasoning engine.
    """

    def __init__(self, **kwargs):
        # Initialize both bases explicitly: BaseConfig consumes **kwargs,
        # VertexBase carries the GCP auth state used by _get_auth_headers.
        BaseConfig.__init__(self, **kwargs)
        VertexBase.__init__(self)

    def get_supported_openai_params(self, model: str) -> List[str]:
        """Vertex Agent Engine has limited OpenAI compatible params."""
        return ["user"]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Map OpenAI params to Agent Engine params."""
        # Map 'user' to 'user_id' for session management
        if "user" in non_default_params:
            optional_params["user_id"] = non_default_params["user"]
        return optional_params

    def _parse_model_string(self, model: str) -> Tuple[str, str]:
        """
        Parse model string to extract resource ID.
        Model format: agent_engine/<project_number>/<location>/<engine_id>
        Or: agent_engine/<engine_id> (uses default project/location)
        Returns: (resource_path, engine_id)
        """
        # Remove 'agent_engine/' prefix if present
        if model.startswith("agent_engine/"):
            model = model[len("agent_engine/") :]
        # Check if it's a full resource path
        if model.startswith("projects/"):
            # Full path: projects/123/locations/us-central1/reasoningEngines/456
            # The engine id is the last path segment.
            return model, model.split("/")[-1]
        # Just the engine ID; get_complete_url expands it to a full path.
        return model, model

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Get the complete URL for the request.
        For Vertex Agent Engine:
        - Non-streaming: :query endpoint (for session management)
        - Streaming: :streamQuery endpoint (for actual queries)
        """
        resource_path, engine_id = self._parse_model_string(model)
        # Get project and location from litellm_params or environment
        vertex_project = self.safe_get_vertex_ai_project(litellm_params)
        vertex_location = (
            self.safe_get_vertex_ai_location(litellm_params) or "us-central1"
        )
        # Build the full resource path if only engine_id was provided
        if not resource_path.startswith("projects/"):
            if not vertex_project:
                raise ValueError(
                    "vertex_project is required for Vertex Agent Engine. "
                    "Set via litellm_params['vertex_project'] or VERTEXAI_PROJECT env var."
                )
            resource_path = f"projects/{vertex_project}/locations/{vertex_location}/reasoningEngines/{engine_id}"
        base_url = get_vertex_base_url(vertex_location)
        # Always use :streamQuery endpoint for actual queries
        # The :query endpoint only supports session management methods
        # (create_session, get_session, list_sessions, delete_session, etc.)
        endpoint = f"{base_url}/v1beta1/{resource_path}:streamQuery"
        verbose_logger.debug(f"Vertex Agent Engine URL: {endpoint}")
        return endpoint

    def _get_auth_headers(
        self,
        optional_params: dict,
        litellm_params: dict,
    ) -> Dict[str, str]:
        """Get authentication headers using Google Cloud credentials."""
        vertex_credentials = self.safe_get_vertex_ai_credentials(litellm_params)
        vertex_project = self.safe_get_vertex_ai_project(litellm_params)
        # Get access token using VertexBase
        access_token, project_id = self.get_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
        )
        verbose_logger.debug(
            f"Vertex Agent Engine: Authenticated for project {project_id}"
        )
        return {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
        }

    def _get_user_id(self, optional_params: dict) -> str:
        """Get or generate user ID for session management."""
        user_id = optional_params.get("user_id") or optional_params.get("user")
        if user_id:
            return user_id
        # Generate a user ID so the Agent Engine session API always has one.
        return f"litellm-user-{str(uuid.uuid4())[:8]}"

    def _get_session_id(self, optional_params: dict) -> Optional[str]:
        """Get session ID if provided."""
        return optional_params.get("session_id")

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Transform the request to Vertex Agent Engine format.
        The API expects:
        {
            "class_method": "stream_query",
            "input": {
                "message": "...",
                "user_id": "...",
                "session_id": "..." (optional)
            }
        }
        """
        # Use the last message content as the prompt.
        # NOTE(review): earlier conversation turns are dropped here; history
        # is presumably carried by the Agent Engine session — confirm.
        prompt = convert_content_list_to_str(messages[-1])
        # Get user_id and session_id
        user_id = self._get_user_id(optional_params)
        session_id = self._get_session_id(optional_params)
        # Build the input
        input_data: Dict[str, Any] = {
            "message": prompt,
            "user_id": user_id,
        }
        if session_id:
            input_data["session_id"] = session_id
        # Build the request payload
        # Note: stream_query is used for both streaming and non-streaming
        # The difference is the endpoint (:streamQuery vs :query)
        payload = {
            "class_method": "stream_query",
            "input": input_data,
        }
        verbose_logger.debug(f"Vertex Agent Engine payload: {payload}")
        return payload

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """Validate environment and set up authentication headers."""
        auth_headers = self._get_auth_headers(optional_params, litellm_params)
        headers.update(auth_headers)
        return headers

    def _extract_text_from_response(self, response_data: dict) -> str:
        """Extract text content from the response.

        Returns the first text part found, or the first non-empty string
        value from actions.state_delta, or "" when neither is present.
        """
        # Try to get from content.parts
        content = response_data.get("content", {})
        parts = content.get("parts", [])
        for part in parts:
            if "text" in part:
                return part["text"]
        # Try actions.state_delta
        actions = response_data.get("actions", {})
        state_delta = actions.get("state_delta", {})
        for key, value in state_delta.items():
            if isinstance(value, str) and value:
                return value
        return ""

    def _calculate_usage(
        self, model: str, messages: List[AllMessageValues], content: str
    ) -> Optional[Usage]:
        """Calculate token usage using LiteLLM's token counter.

        Returns None if counting fails for any reason (usage is then simply
        omitted from the response).
        """
        try:
            from litellm.utils import token_counter

            # Agent Engine does not report token counts for non-streaming
            # responses, so approximate with the gpt-3.5-turbo tokenizer.
            prompt_tokens = token_counter(model="gpt-3.5-turbo", messages=messages)
            completion_tokens = token_counter(
                model="gpt-3.5-turbo", text=content, count_response_tokens=True
            )
            total_tokens = prompt_tokens + completion_tokens
            return Usage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens,
            )
        except Exception as e:
            verbose_logger.warning(f"Failed to calculate token usage: {str(e)}")
            return None

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """
        Transform Vertex Agent Engine response to LiteLLM ModelResponse format.
        The response is a streaming SSE format even for non-streaming requests.
        We need to collect all the chunks and extract the final response.
        """
        try:
            content_type = raw_response.headers.get("content-type", "").lower()
            verbose_logger.debug(
                f"Vertex Agent Engine response Content-Type: {content_type}"
            )
            # Parse the SSE response
            response_text = raw_response.text
            verbose_logger.debug(f"Response (first 500 chars): {response_text[:500]}")
            # Extract content from SSE stream.
            # NOTE(review): assumes each line is a bare JSON object (no
            # "data: " SSE prefix) — confirm against the :streamQuery wire
            # format; prefixed lines would be skipped as JSON decode errors.
            content = ""
            for line in response_text.strip().split("\n"):
                line = line.strip()
                if not line:
                    continue
                try:
                    data = json.loads(line)
                    if isinstance(data, dict):
                        text = self._extract_text_from_response(data)
                        if text:
                            content = text  # Use the last non-empty text
                except json.JSONDecodeError:
                    continue
            # Create the message
            message = Message(content=content, role="assistant")
            # Create choices
            choice = Choices(finish_reason="stop", index=0, message=message)
            # Update model response
            model_response.choices = [choice]
            model_response.model = model
            # Calculate usage
            calculated_usage = self._calculate_usage(model, messages, content)
            if calculated_usage:
                setattr(model_response, "usage", calculated_usage)
            return model_response
        except Exception as e:
            verbose_logger.error(
                f"Error processing Vertex Agent Engine response: {str(e)}"
            )
            raise VertexAgentEngineError(
                message=f"Error processing response: {str(e)}",
                status_code=raw_response.status_code,
            )

    def get_streaming_response(
        self,
        model: str,
        raw_response: httpx.Response,
    ) -> VertexAgentEngineResponseIterator:
        """Return a streaming iterator for SSE responses."""
        return VertexAgentEngineResponseIterator(
            streaming_response=raw_response.iter_lines(),
            sync_stream=True,
        )

    def get_sync_custom_stream_wrapper(
        self,
        model: str,
        custom_llm_provider: str,
        logging_obj: LiteLLMLoggingObj,
        api_base: str,
        headers: dict,
        data: dict,
        messages: list,
        client: Optional[Union[HTTPHandler, "AsyncHTTPHandler"]] = None,
        json_mode: Optional[bool] = None,
        signed_json_body: Optional[bytes] = None,
    ) -> "CustomStreamWrapper":
        """Get a CustomStreamWrapper for synchronous streaming.

        Raises:
            VertexAgentEngineError: on any non-200 response from the endpoint.
        """
        from litellm.llms.custom_httpx.http_handler import (
            HTTPHandler,
            _get_httpx_client,
        )
        from litellm.utils import CustomStreamWrapper

        # Fall back to a fresh sync client if none (or an async one) was passed.
        if client is None or not isinstance(client, HTTPHandler):
            client = _get_httpx_client(params={})
        # Avoid logging sensitive api_base directly
        verbose_logger.debug("Making sync streaming request to Vertex AI endpoint.")
        # Make streaming request
        response = client.post(
            api_base,
            headers=headers,
            data=json.dumps(data),
            stream=True,
            logging_obj=logging_obj,
        )
        if response.status_code != 200:
            raise VertexAgentEngineError(
                status_code=response.status_code, message=str(response.read())
            )
        # Create iterator for SSE stream
        completion_stream = self.get_streaming_response(
            model=model, raw_response=response
        )
        streaming_response = CustomStreamWrapper(
            completion_stream=completion_stream,
            model=model,
            custom_llm_provider=custom_llm_provider,
            logging_obj=logging_obj,
        )
        # LOGGING
        logging_obj.post_call(
            input=messages,
            api_key="",
            original_response="first stream response received",
            additional_args={"complete_input_dict": data},
        )
        return streaming_response

    async def get_async_custom_stream_wrapper(
        self,
        model: str,
        custom_llm_provider: str,
        logging_obj: LiteLLMLoggingObj,
        api_base: str,
        headers: dict,
        data: dict,
        messages: list,
        client: Optional["AsyncHTTPHandler"] = None,
        json_mode: Optional[bool] = None,
        signed_json_body: Optional[bytes] = None,
    ) -> "CustomStreamWrapper":
        """Get a CustomStreamWrapper for asynchronous streaming.

        Raises:
            VertexAgentEngineError: on any non-200 response from the endpoint.
        """
        from litellm.llms.custom_httpx.http_handler import (
            AsyncHTTPHandler,
            get_async_httpx_client,
        )
        from litellm.utils import CustomStreamWrapper

        # Fall back to a fresh async client if none (or a sync one) was passed.
        if client is None or not isinstance(client, AsyncHTTPHandler):
            client = get_async_httpx_client(
                llm_provider=cast(Any, "vertex_ai"), params={}
            )
        # Avoid logging sensitive api_base directly
        verbose_logger.debug("Making async streaming request to Vertex AI endpoint.")
        # Make async streaming request
        response = await client.post(
            api_base,
            headers=headers,
            data=json.dumps(data),
            stream=True,
            logging_obj=logging_obj,
        )
        if response.status_code != 200:
            raise VertexAgentEngineError(
                status_code=response.status_code, message=str(await response.aread())
            )
        # Create iterator for SSE stream (async)
        completion_stream = VertexAgentEngineResponseIterator(
            streaming_response=response.aiter_lines(),
            sync_stream=False,
        )
        streaming_response = CustomStreamWrapper(
            completion_stream=completion_stream,
            model=model,
            custom_llm_provider=custom_llm_provider,
            logging_obj=logging_obj,
        )
        # LOGGING
        logging_obj.post_call(
            input=messages,
            api_key="",
            original_response="first stream response received",
            additional_args={"complete_input_dict": data},
        )
        return streaming_response

    @property
    def has_custom_stream_wrapper(self) -> bool:
        """Indicates that this config has custom streaming support."""
        return True

    @property
    def supports_stream_param_in_request_body(self) -> bool:
        """Agent Engine does not allow passing `stream` in the request body."""
        return False

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        # All provider errors surface as VertexAgentEngineError.
        return VertexAgentEngineError(status_code=status_code, message=error_message)

    def should_fake_stream(
        self,
        model: Optional[str],
        stream: Optional[bool],
        custom_llm_provider: Optional[str] = None,
    ) -> bool:
        """Agent Engine always returns SSE streams, so we use real streaming."""
        return False

View File

@@ -0,0 +1,52 @@
"""
Custom AWS Security Credentials Supplier for Vertex AI WIF.
Wraps boto3/botocore credentials so that google-auth can use them
for the AWS-to-GCP Workload Identity Federation token exchange
without hitting the EC2 instance metadata service.
Requires google-auth >= 2.29.0.
"""
from typing import Callable
from google.auth import aws
class AwsCredentialsSupplier(aws.AwsSecurityCredentialsSupplier):
    """
    Supplies AWS credentials to google-auth's aws.Credentials for the
    AWS-to-GCP Workload Identity Federation token exchange.

    Bypasses google-auth's default EC2-metadata-based credential retrieval,
    so WIF works in environments where the instance metadata service is
    blocked. The ``credentials_provider`` callable is re-invoked on every
    get_aws_security_credentials() call, which means rotated/refreshed
    credentials (e.g. temporary STS tokens) are picked up automatically.
    """

    def __init__(self, credentials_provider: Callable, aws_region: str):
        """
        Args:
            credentials_provider: A zero-arg callable that returns a
                botocore.credentials.Credentials object (with access_key,
                secret_key, and token attributes).
            aws_region: The AWS region string (e.g. "us-east-1").
        """
        self._credentials_provider = credentials_provider
        self._region = aws_region

    def get_aws_security_credentials(self, context, request):
        """Return the current AWS credentials for the GCP token exchange."""
        creds = self._credentials_provider()
        return aws.AwsSecurityCredentials(
            access_key_id=creds.access_key,
            secret_access_key=creds.secret_key,
            session_token=creds.token,
        )

    def get_aws_region(self, context, request):
        """Return the AWS region used for credential verification."""
        return self._region

View File

@@ -0,0 +1,6 @@
# Vertex AI Batch Prediction Jobs
Implementation for calling Vertex AI batch endpoints using the OpenAI Batch API spec.
Vertex Docs: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/batch-prediction-gemini

View File

@@ -0,0 +1,378 @@
import json
from typing import Any, Coroutine, Dict, Optional, Union
import httpx
import litellm
from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
get_async_httpx_client,
)
from litellm.llms.vertex_ai.common_utils import get_vertex_base_url
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
from litellm.types.llms.openai import CreateBatchRequest
from litellm.types.llms.vertex_ai import (
VERTEX_CREDENTIALS_TYPES,
VertexAIBatchPredictionJob,
)
from litellm.types.utils import LiteLLMBatch
from .transformation import VertexAIBatchTransformation
class VertexAIBatchPrediction(VertexLLM):
    """Create, retrieve, and list Vertex AI batch prediction jobs in the
    OpenAI Batch API shape (transformation handled by
    VertexAIBatchTransformation)."""

    def __init__(self, gcs_bucket_name: str, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # GCS bucket used for batch inputs/outputs.
        self.gcs_bucket_name = gcs_bucket_name

    def create_batch(
        self,
        _is_async: bool,
        create_batch_data: CreateBatchRequest,
        api_base: Optional[str],
        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
    ) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
        """Create a batch prediction job.

        Returns the created batch (sync) or a coroutine resolving to it
        when ``_is_async`` is True. Raises Exception on non-200 responses.
        """
        sync_handler = _get_httpx_client()
        # Resolve GCP auth; project_id falls back to the credentials' project.
        access_token, project_id = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai",
        )
        default_api_base = self.create_vertex_batch_url(
            vertex_location=vertex_location or "us-central1",
            vertex_project=vertex_project or project_id,
        )
        # Extract the ":<verb>" suffix (if any) so a custom proxy can
        # reconstruct the endpoint.
        if len(default_api_base.split(":")) > 1:
            endpoint = default_api_base.split(":")[-1]
        else:
            endpoint = ""
        _, api_base = self._check_custom_proxy(
            api_base=api_base,
            custom_llm_provider="vertex_ai",
            gemini_api_key=None,
            endpoint=endpoint,
            stream=None,
            auth_header=None,
            url=default_api_base,
            model=None,
            vertex_project=vertex_project or project_id,
            vertex_location=vertex_location or "us-central1",
            vertex_api_version="v1",
        )
        headers = {
            "Content-Type": "application/json; charset=utf-8",
            "Authorization": f"Bearer {access_token}",
        }
        vertex_batch_request: VertexAIBatchPredictionJob = VertexAIBatchTransformation.transform_openai_batch_request_to_vertex_ai_batch_request(
            request=create_batch_data
        )
        if _is_async is True:
            return self._async_create_batch(
                vertex_batch_request=vertex_batch_request,
                api_base=api_base,
                headers=headers,
            )
        response = sync_handler.post(
            url=api_base,
            headers=headers,
            data=json.dumps(vertex_batch_request),
        )
        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} {response.text}")
        _json_response = response.json()
        vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
            response=_json_response
        )
        return vertex_batch_response

    async def _async_create_batch(
        self,
        vertex_batch_request: VertexAIBatchPredictionJob,
        api_base: str,
        headers: Dict[str, str],
    ) -> LiteLLMBatch:
        """Async variant of create_batch's HTTP call + response transform."""
        client = get_async_httpx_client(
            llm_provider=litellm.LlmProviders.VERTEX_AI,
        )
        try:
            response = await client.post(
                url=api_base,
                headers=headers,
                data=json.dumps(vertex_batch_request),
            )
        # NOTE(review): httpx raises HTTPStatusError only from
        # raise_for_status(); confirm the client wrapper calls it, otherwise
        # this handler is dead code and the status check below is the only
        # error path.
        except httpx.HTTPStatusError as e:
            error_body = e.response.text
            litellm.verbose_logger.error(
                "Vertex AI batch create failed: status=%s, body=%s",
                e.response.status_code,
                error_body[:1000],
            )
            raise
        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} {response.text}")
        _json_response = response.json()
        vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
            response=_json_response
        )
        return vertex_batch_response

    def create_vertex_batch_url(
        self,
        vertex_location: str,
        vertex_project: str,
    ) -> str:
        """Return the base url for the vertex garden models"""
        # POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION/batchPredictionJobs
        base_url = get_vertex_base_url(vertex_location)
        return f"{base_url}/v1/projects/{vertex_project}/locations/{vertex_location}/batchPredictionJobs"

    def retrieve_batch(
        self,
        _is_async: bool,
        batch_id: str,
        api_base: Optional[str],
        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        logging_obj: Optional[Any] = None,
    ) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
        """Retrieve a single batch prediction job by id.

        Returns the batch (sync) or a coroutine resolving to it when
        ``_is_async`` is True. Raises Exception on non-200 responses.
        """
        sync_handler = _get_httpx_client()
        access_token, project_id = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai",
        )
        default_api_base = self.create_vertex_batch_url(
            vertex_location=vertex_location or "us-central1",
            vertex_project=vertex_project or project_id,
        )
        # Append batch_id to the URL
        default_api_base = f"{default_api_base}/{batch_id}"
        if len(default_api_base.split(":")) > 1:
            endpoint = default_api_base.split(":")[-1]
        else:
            endpoint = ""
        _, api_base = self._check_custom_proxy(
            api_base=api_base,
            custom_llm_provider="vertex_ai",
            gemini_api_key=None,
            endpoint=endpoint,
            stream=None,
            auth_header=None,
            url=default_api_base,
            model=None,
            vertex_project=vertex_project or project_id,
            vertex_location=vertex_location or "us-central1",
            vertex_api_version="v1",
        )
        headers = {
            "Content-Type": "application/json; charset=utf-8",
            "Authorization": f"Bearer {access_token}",
        }
        if _is_async is True:
            return self._async_retrieve_batch(
                api_base=api_base,
                headers=headers,
                logging_obj=logging_obj,
            )
        # Log the request using logging_obj if available
        if logging_obj is not None:
            from litellm.litellm_core_utils.litellm_logging import Logging

            if isinstance(logging_obj, Logging):
                logging_obj.pre_call(
                    input="",
                    api_key="",
                    additional_args={
                        "complete_input_dict": {},
                        "api_base": api_base,
                        "headers": headers,
                        # Bearer token deliberately redacted in the logged curl.
                        "request_str": (
                            f"\nGET Request Sent from LiteLLM:\n"
                            f"curl -X GET \\\n"
                            f"{api_base} \\\n"
                            f"-H 'Authorization: Bearer ***REDACTED***' \\\n"
                            f"-H 'Content-Type: application/json; charset=utf-8'\n"
                        ),
                    },
                )
        response = sync_handler.get(
            url=api_base,
            headers=headers,
        )
        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} {response.text}")
        _json_response = response.json()
        vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
            response=_json_response
        )
        return vertex_batch_response

    async def _async_retrieve_batch(
        self,
        api_base: str,
        headers: Dict[str, str],
        logging_obj: Optional[Any] = None,
    ) -> LiteLLMBatch:
        """Async variant of retrieve_batch's HTTP call + response transform."""
        client = get_async_httpx_client(
            llm_provider=litellm.LlmProviders.VERTEX_AI,
        )
        # Log the request using logging_obj if available
        if logging_obj is not None:
            from litellm.litellm_core_utils.litellm_logging import Logging

            if isinstance(logging_obj, Logging):
                logging_obj.pre_call(
                    input="",
                    api_key="",
                    additional_args={
                        "complete_input_dict": {},
                        "api_base": api_base,
                        "headers": headers,
                        # Bearer token deliberately redacted in the logged curl.
                        "request_str": (
                            f"\nGET Request Sent from LiteLLM:\n"
                            f"curl -X GET \\\n"
                            f"{api_base} \\\n"
                            f"-H 'Authorization: Bearer ***REDACTED***' \\\n"
                            f"-H 'Content-Type: application/json; charset=utf-8'\n"
                        ),
                    },
                )
        response = await client.get(
            url=api_base,
            headers=headers,
        )
        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} {response.text}")
        _json_response = response.json()
        vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
            response=_json_response
        )
        return vertex_batch_response

    def list_batches(
        self,
        _is_async: bool,
        after: Optional[str],
        limit: Optional[int],
        api_base: Optional[str],
        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
    ):
        """List batch prediction jobs, paginated via ``after``/``limit``
        (mapped to Vertex's pageToken/pageSize)."""
        sync_handler = _get_httpx_client()
        access_token, project_id = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai",
        )
        default_api_base = self.create_vertex_batch_url(
            vertex_location=vertex_location or "us-central1",
            vertex_project=vertex_project or project_id,
        )
        if len(default_api_base.split(":")) > 1:
            endpoint = default_api_base.split(":")[-1]
        else:
            endpoint = ""
        # NOTE(review): unlike create_batch/retrieve_batch, this call omits
        # model/vertex_project/vertex_location/vertex_api_version — confirm
        # _check_custom_proxy's defaults make this equivalent, otherwise
        # custom-proxy routing may differ for list.
        _, api_base = self._check_custom_proxy(
            api_base=api_base,
            custom_llm_provider="vertex_ai",
            gemini_api_key=None,
            endpoint=endpoint,
            stream=None,
            auth_header=None,
            url=default_api_base,
        )
        headers = {
            "Content-Type": "application/json; charset=utf-8",
            "Authorization": f"Bearer {access_token}",
        }
        params: Dict[str, Any] = {}
        if limit is not None:
            params["pageSize"] = str(limit)
        if after is not None:
            params["pageToken"] = after
        if _is_async is True:
            return self._async_list_batches(
                api_base=api_base,
                headers=headers,
                params=params,
            )
        response = sync_handler.get(
            url=api_base,
            headers=headers,
            params=params,
        )
        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} {response.text}")
        _json_response = response.json()
        vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_list_response_to_openai_list_response(
            response=_json_response
        )
        return vertex_batch_response

    async def _async_list_batches(
        self,
        api_base: str,
        headers: Dict[str, str],
        params: Dict[str, Any],
    ):
        """Async variant of list_batches' HTTP call + response transform."""
        client = get_async_httpx_client(
            llm_provider=litellm.LlmProviders.VERTEX_AI,
        )
        response = await client.get(
            url=api_base,
            headers=headers,
            params=params,
        )
        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} {response.text}")
        _json_response = response.json()
        vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_list_response_to_openai_list_response(
            response=_json_response
        )
        return vertex_batch_response

View File

@@ -0,0 +1,227 @@
from litellm._uuid import uuid
from typing import Any, Dict
from litellm.llms.vertex_ai.common_utils import (
_convert_vertex_datetime_to_openai_datetime,
)
from litellm.types.llms.openai import BatchJobStatus, CreateBatchRequest
from litellm.types.llms.vertex_ai import *
from litellm.types.utils import LiteLLMBatch
class VertexAIBatchTransformation:
"""
Transforms OpenAI Batch requests to Vertex AI Batch requests
API Ref: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/batch-prediction-gemini
"""
@classmethod
def transform_openai_batch_request_to_vertex_ai_batch_request(
cls,
request: CreateBatchRequest,
) -> VertexAIBatchPredictionJob:
"""
Transforms OpenAI Batch requests to Vertex AI Batch requests
"""
request_display_name = f"litellm-vertex-batch-{uuid.uuid4()}"
input_file_id = request.get("input_file_id")
if input_file_id is None:
raise ValueError("input_file_id is required, but not provided")
input_config: InputConfig = InputConfig(
gcsSource=GcsSource(uris=[input_file_id]), instancesFormat="jsonl"
)
model: str = cls._get_model_from_gcs_file(input_file_id)
output_config: OutputConfig = OutputConfig(
predictionsFormat="jsonl",
gcsDestination=GcsDestination(
outputUriPrefix=cls._get_gcs_uri_prefix_from_file(input_file_id)
),
)
return VertexAIBatchPredictionJob(
inputConfig=input_config,
outputConfig=output_config,
model=model,
displayName=request_display_name,
)
@classmethod
def transform_vertex_ai_batch_response_to_openai_batch_response(
cls, response: VertexBatchPredictionResponse
) -> LiteLLMBatch:
return LiteLLMBatch(
id=cls._get_batch_id_from_vertex_ai_batch_response(response),
completion_window="24hrs",
created_at=_convert_vertex_datetime_to_openai_datetime(
vertex_datetime=response.get("createTime", "")
),
endpoint="",
input_file_id=cls._get_input_file_id_from_vertex_ai_batch_response(
response
),
object="batch",
status=cls._get_batch_job_status_from_vertex_ai_batch_response(response),
error_file_id=None, # Vertex AI doesn't seem to have a direct equivalent
output_file_id=cls._get_output_file_id_from_vertex_ai_batch_response(
response
),
)
@classmethod
def transform_vertex_ai_batch_list_response_to_openai_list_response(
cls, response: Dict[str, Any]
) -> Dict[str, Any]:
"""
Transforms Vertex AI batch list response into OpenAI-compatible list response.
"""
batch_jobs = response.get("batchPredictionJobs", []) or []
data = [
cls.transform_vertex_ai_batch_response_to_openai_batch_response(job)
for job in batch_jobs
]
first_id = data[0].id if len(data) > 0 else None
last_id = data[-1].id if len(data) > 0 else None
next_page_token = response.get("nextPageToken")
return {
"object": "list",
"data": data,
"first_id": first_id,
"last_id": last_id,
"has_more": bool(next_page_token),
"next_page_token": next_page_token,
}
@classmethod
def _get_batch_id_from_vertex_ai_batch_response(
cls, response: VertexBatchPredictionResponse
) -> str:
"""
Gets the batch id from the Vertex AI Batch response safely
vertex response: `projects/510528649030/locations/us-central1/batchPredictionJobs/3814889423749775360`
returns: `3814889423749775360`
"""
_name = response.get("name", "")
if not _name:
return ""
# Split by '/' and get the last part if it exists
parts = _name.split("/")
return parts[-1] if parts else _name
@classmethod
def _get_input_file_id_from_vertex_ai_batch_response(
cls, response: VertexBatchPredictionResponse
) -> str:
"""
Gets the input file id from the Vertex AI Batch response
"""
input_file_id: str = ""
input_config = response.get("inputConfig")
if input_config is None:
return input_file_id
gcs_source = input_config.get("gcsSource")
if gcs_source is None:
return input_file_id
uris = gcs_source.get("uris", "")
if len(uris) == 0:
return input_file_id
return uris[0]
@classmethod
def _get_output_file_id_from_vertex_ai_batch_response(
    cls, response: VertexBatchPredictionResponse
) -> str:
    """
    Gets the output file id from the Vertex AI batch response.

    Prefers ``outputInfo.gcsOutputDirectory`` + "/predictions.jsonl"; falls
    back to ``outputConfig.gcsDestination.outputUriPrefix`` when no output
    directory is present.
    """
    output_dir = response.get("outputInfo", OutputInfo()).get(
        "gcsOutputDirectory", ""
    )
    predictions_path = f"{output_dir}/predictions.jsonl"
    # A bare "/predictions.jsonl" means the output directory was empty.
    if predictions_path != "/predictions.jsonl":
        return predictions_path

    output_config = response.get("outputConfig")
    if output_config is None:
        return predictions_path
    gcs_destination = output_config.get("gcsDestination")
    if gcs_destination is None:
        return predictions_path
    return gcs_destination.get("outputUriPrefix", "")
@classmethod
def _get_batch_job_status_from_vertex_ai_batch_response(
    cls, response: VertexBatchPredictionResponse
) -> BatchJobStatus:
    """
    Maps the Vertex AI job state onto the OpenAI batch status enum.

    ref: https://cloud.google.com/vertex-ai/docs/reference/rest/v1/JobState
    """
    state_mapping: Dict[str, BatchJobStatus] = {
        "JOB_STATE_UNSPECIFIED": "failed",
        "JOB_STATE_QUEUED": "validating",
        "JOB_STATE_PENDING": "validating",
        "JOB_STATE_RUNNING": "in_progress",
        "JOB_STATE_SUCCEEDED": "completed",
        "JOB_STATE_FAILED": "failed",
        "JOB_STATE_CANCELLING": "cancelling",
        "JOB_STATE_CANCELLED": "cancelled",
        "JOB_STATE_PAUSED": "in_progress",
        "JOB_STATE_EXPIRED": "expired",
        "JOB_STATE_UPDATING": "in_progress",
        "JOB_STATE_PARTIALLY_SUCCEEDED": "completed",
    }
    vertex_state = response.get("state", "JOB_STATE_UNSPECIFIED")
    # Unknown / newly-introduced Vertex states map to "failed" instead of
    # raising a KeyError (the previous behavior for unmapped states).
    return state_mapping.get(vertex_state, "failed")
@classmethod
def _get_gcs_uri_prefix_from_file(cls, input_file_id: str) -> str:
    """
    Return the GCS uri prefix (the directory portion) of a file uri.

    Example:
        input_file_id: "gs://litellm-testing-bucket/vtx_batch.jsonl"
        returns: "gs://litellm-testing-bucket"

        input_file_id: "gs://litellm-testing-bucket/batches/vtx_batch.jsonl"
        returns: "gs://litellm-testing-bucket/batches"
    """
    # Drop everything after (and including) the final "/"; if there is no
    # separator at all, return the input unchanged.
    head, sep, _filename = input_file_id.rpartition("/")
    return head if sep else input_file_id
@classmethod
def _get_model_from_gcs_file(cls, gcs_file_uri: str) -> str:
    """
    Extracts the model from the gcs file uri.

    When files are uploaded using LiteLLM (/v1/files), the model is stored in
    the gcs file uri.

    Why?
    - Because Vertex requires the `model` param in create batch job requests,
      but OpenAI does not require this.

    gcs_file_uri format: gs://litellm-testing-bucket/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/e9412502-2c91-42a6-8e61-f5c294cc0fc8
    returns: "publishers/google/models/gemini-1.5-flash-001"
    """
    from urllib.parse import unquote

    decoded = unquote(gcs_file_uri)
    # Everything after the first "publishers/" marker holds the model path.
    remainder = decoded.split("publishers/")[1]
    # Keep only "<publisher>/models/<model-name>" (first three path segments).
    segments = remainder.split("/")[:3]
    return "publishers/" + "/".join(segments)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,198 @@
"""
Transformation logic for context caching.
Why separate file? Make it easy to see how transformation works
"""
import re
from typing import List, Optional, Tuple, Literal
from litellm.types.llms.openai import AllMessageValues
from litellm.types.llms.vertex_ai import CachedContentRequestBody
from litellm.utils import is_cached_message
from ..common_utils import get_supports_system_message
from ..gemini.transformation import (
_gemini_convert_messages_with_history,
_transform_system_message,
)
def get_first_continuous_block_idx(
    filtered_messages: List[Tuple[int, AllMessageValues]]  # (idx, message)
) -> int:
    """
    Find the array position that ends the first continuous run of message
    indices.

    Args:
        filtered_messages: List of tuples containing (index, message) pairs

    Returns:
        int: The array position where the first continuous sequence ends,
        or -1 for an empty input.
    """
    if not filtered_messages:
        return -1

    previous_idx = filtered_messages[0][0]
    for pos in range(1, len(filtered_messages)):
        current_idx = filtered_messages[pos][0]
        if current_idx != previous_idx + 1:
            # Run broken: the previous position ended the first continuous block.
            return pos - 1
        previous_idx = current_idx

    # The whole list is one continuous run.
    return len(filtered_messages) - 1
def extract_ttl_from_cached_messages(messages: List[AllMessageValues]) -> Optional[str]:
    """
    Extract TTL from cached messages. Returns the first valid TTL found.

    Only list-style message content is inspected; plain string content carries
    no cache_control metadata.

    Args:
        messages: List of messages to extract TTL from

    Returns:
        Optional[str]: TTL string in format "3600s" or None if not found/invalid
    """
    for msg in messages:
        if not is_cached_message(msg):
            continue

        msg_content = msg.get("content")
        if not msg_content or isinstance(msg_content, str):
            continue

        for item in msg_content:
            # Only dict items can carry cache_control metadata.
            if not isinstance(item, dict):
                continue

            cache_control = item.get("cache_control")
            if not isinstance(cache_control, dict) or not cache_control:
                continue
            if cache_control.get("type") != "ephemeral":
                continue

            ttl_value = cache_control.get("ttl")
            if ttl_value and _is_valid_ttl_format(ttl_value):
                return str(ttl_value)

    return None
def _is_valid_ttl_format(ttl: str) -> bool:
    """
    Validate TTL format: a positive (optionally fractional) number of seconds
    followed by 's'.

    Examples: "3600s", "7200s", "1.5s"

    Args:
        ttl: TTL string to validate

    Returns:
        bool: True if valid format, False otherwise
    """
    if not isinstance(ttl, str):
        return False

    # Number (optionally fractional) immediately followed by 's'.
    matched = re.match(r"^([0-9]*\.?[0-9]+)s$", ttl)
    if matched is None:
        return False

    try:
        # The numeric part must be strictly positive.
        return float(matched.group(1)) > 0
    except ValueError:
        return False
def separate_cached_messages(
    messages: List[AllMessageValues],
) -> Tuple[List[AllMessageValues], List[AllMessageValues]]:
    """
    Returns separated cached and non-cached messages.

    Only the FIRST continuous run of cached messages is treated as cached;
    everything else (including later cached messages) is returned as
    non-cached.

    Args:
        messages: List of messages to be separated.

    Returns:
        Tuple containing:
        - cached_messages: List of cached messages.
        - non_cached_messages: List of non-cached messages.
    """
    # Collect cached messages together with their original positions.
    indexed_cached: List[Tuple[int, AllMessageValues]] = [
        (idx, msg)
        for idx, msg in enumerate(messages)
        if is_cached_message(message=msg)
    ]

    if not indexed_cached:
        return [], messages

    # Only the first continuous block of cached messages is honored.
    block_end = get_first_continuous_block_idx(indexed_cached)
    start = indexed_cached[0][0]
    end = indexed_cached[block_end][0]

    cached_messages = messages[start : end + 1]
    non_cached_messages = messages[:start] + messages[end + 1 :]
    return cached_messages, non_cached_messages
def transform_openai_messages_to_gemini_context_caching(
    model: str,
    messages: List[AllMessageValues],
    custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
    cache_key: str,
    vertex_project: Optional[str],
    vertex_location: Optional[str],
) -> CachedContentRequestBody:
    """
    Build a Gemini/Vertex CachedContent request body from OpenAI-format
    messages.

    The cache_key is stored as the displayName so existing caches can be
    found again by listing cached contents.
    """
    # Extract TTL BEFORE the system-message transformation, which may strip
    # the cache_control metadata the TTL lives on.
    ttl = extract_ttl_from_cached_messages(messages)

    system_ok = get_supports_system_message(
        model=model, custom_llm_provider=custom_llm_provider
    )
    system_instruction, remaining_messages = _transform_system_message(
        supports_system_message=system_ok, messages=messages
    )
    contents = _gemini_convert_messages_with_history(
        messages=remaining_messages, model=model
    )

    # Vertex needs the fully-qualified publisher model resource name.
    model_name = f"models/{model}"
    if custom_llm_provider in ("vertex_ai", "vertex_ai_beta"):
        model_name = (
            f"projects/{vertex_project}/locations/{vertex_location}"
            f"/publishers/google/{model_name}"
        )

    request_body = CachedContentRequestBody(
        contents=contents,
        model=model_name,
        displayName=cache_key,
    )
    # Add TTL if present and valid
    if ttl:
        request_body["ttl"] = ttl
    if system_instruction is not None:
        request_body["system_instruction"] = system_instruction

    return request_body

View File

@@ -0,0 +1,578 @@
from typing import List, Literal, Optional, Tuple, Union
import httpx
import litellm
from litellm.caching.caching import Cache, LiteLLMCacheType
from litellm.constants import MINIMUM_PROMPT_CACHE_TOKEN_COUNT
from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
get_async_httpx_client,
)
from litellm._logging import verbose_logger
from litellm.llms.openai.openai import AllMessageValues
from litellm.utils import is_prompt_caching_valid_prompt
from litellm.types.llms.vertex_ai import (
CachedContentListAllResponseBody,
VertexAICachedContentResponseObject,
)
from ..common_utils import VertexAIError
from ..vertex_llm_base import VertexBase
from .transformation import (
separate_cached_messages,
transform_openai_messages_to_gemini_context_caching,
)
# Local cache instance used ONLY to derive deterministic cache keys via
# 'get_cache_key'; nothing is ever stored in or read from it.
local_cache_obj = Cache(
    type=LiteLLMCacheType.LOCAL
)  # only used for calling 'get_cache_key' function

# Reasonable upper bound for pagination when listing cached contents, so a
# misbehaving API can never cause an infinite loop.
MAX_PAGINATION_PAGES = 100  # Reasonable upper bound for pagination
class ContextCachingEndpoints(VertexBase):
    """
    Covers context caching endpoints for Vertex AI + Google AI Studio

    v0: covers Google AI Studio
    """

    def __init__(self) -> None:
        pass

    def _get_token_and_url_context_caching(
        self,
        gemini_api_key: Optional[str],
        custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
        api_base: Optional[str],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        vertex_auth_header: Optional[str],
    ) -> Tuple[Optional[str], str]:
        """
        Internal function. Returns the token and url for the call.

        Handles logic if it's google ai studio vs. vertex ai.

        Returns
            token, url
        """
        if custom_llm_provider == "gemini":
            # Google AI Studio: auth is the API key in the query string, so
            # no bearer token is needed.
            auth_header = None
            endpoint = "cachedContents"
            url = "https://generativelanguage.googleapis.com/v1beta/{}?key={}".format(
                endpoint, gemini_api_key
            )
        elif custom_llm_provider == "vertex_ai":
            auth_header = vertex_auth_header
            endpoint = "cachedContents"
            # The "global" location uses the region-less host.
            if vertex_location == "global":
                url = f"https://aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/{endpoint}"
            else:
                url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/{endpoint}"
        else:
            # vertex_ai_beta -> v1beta1 API surface.
            auth_header = vertex_auth_header
            endpoint = "cachedContents"
            if vertex_location == "global":
                url = f"https://aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/{endpoint}"
            else:
                url = f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/{endpoint}"

        return self._check_custom_proxy(
            api_base=api_base,
            custom_llm_provider=custom_llm_provider,
            gemini_api_key=gemini_api_key,
            endpoint=endpoint,
            stream=None,
            auth_header=auth_header,
            url=url,
            model=None,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_api_version="v1beta1"
            if custom_llm_provider == "vertex_ai_beta"
            else "v1",
        )

    def check_cache(
        self,
        cache_key: str,
        client: HTTPHandler,
        headers: dict,
        api_key: str,
        api_base: Optional[str],
        logging_obj: Logging,
        custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        vertex_auth_header: Optional[str],
    ) -> Optional[str]:
        """
        Checks if content already cached.

        Currently, checks cache list, for cache key == displayName, since Google doesn't let us set the name of the cache (their API docs are out of sync with actual implementation).

        Returns
        - cached_content_name - str - cached content name stored on google. (if found.)
        OR
        - None
        """
        _, base_url = self._get_token_and_url_context_caching(
            gemini_api_key=api_key,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_auth_header=vertex_auth_header,
        )

        page_token: Optional[str] = None

        # Iterate through all pages (bounded to avoid infinite pagination)
        for _ in range(MAX_PAGINATION_PAGES):
            # Build URL with pagination token if present
            if page_token:
                separator = "&" if "?" in base_url else "?"
                url = f"{base_url}{separator}pageToken={page_token}"
            else:
                url = base_url

            try:
                ## LOGGING
                logging_obj.pre_call(
                    input="",
                    api_key="",
                    additional_args={
                        "complete_input_dict": {},
                        "api_base": url,
                        "headers": headers,
                    },
                )

                resp = client.get(url=url, headers=headers)
                resp.raise_for_status()
            except httpx.HTTPStatusError as e:
                # 403 = caller may lack list permission; treat as "not cached"
                # instead of failing the whole completion call.
                if e.response.status_code == 403:
                    return None
                raise VertexAIError(
                    status_code=e.response.status_code, message=e.response.text
                )
            except Exception as e:
                raise VertexAIError(status_code=500, message=str(e))

            raw_response = resp.json()
            logging_obj.post_call(original_response=raw_response)

            if "cachedContents" not in raw_response:
                return None

            all_cached_items = CachedContentListAllResponseBody(**raw_response)

            if "cachedContents" not in all_cached_items:
                return None

            # Check current page for matching cache_key
            for cached_item in all_cached_items["cachedContents"]:
                display_name = cached_item.get("displayName")
                if display_name is not None and display_name == cache_key:
                    return cached_item.get("name")

            # Check if there are more pages
            page_token = all_cached_items.get("nextPageToken")
            if not page_token:
                # No more pages, cache not found
                break

        return None

    async def async_check_cache(
        self,
        cache_key: str,
        client: AsyncHTTPHandler,
        headers: dict,
        api_key: str,
        api_base: Optional[str],
        logging_obj: Logging,
        custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        vertex_auth_header: Optional[str],
    ) -> Optional[str]:
        """
        Async variant of `check_cache` - checks if content already cached.

        Currently, checks cache list, for cache key == displayName, since Google doesn't let us set the name of the cache (their API docs are out of sync with actual implementation).

        Returns
        - cached_content_name - str - cached content name stored on google. (if found.)
        OR
        - None
        """
        _, base_url = self._get_token_and_url_context_caching(
            gemini_api_key=api_key,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_auth_header=vertex_auth_header,
        )

        page_token: Optional[str] = None

        # Iterate through all pages (bounded to avoid infinite pagination)
        for _ in range(MAX_PAGINATION_PAGES):
            # Build URL with pagination token if present
            if page_token:
                separator = "&" if "?" in base_url else "?"
                url = f"{base_url}{separator}pageToken={page_token}"
            else:
                url = base_url

            try:
                ## LOGGING
                logging_obj.pre_call(
                    input="",
                    api_key="",
                    additional_args={
                        "complete_input_dict": {},
                        "api_base": url,
                        "headers": headers,
                    },
                )

                resp = await client.get(url=url, headers=headers)
                resp.raise_for_status()
            except httpx.HTTPStatusError as e:
                # 403 = caller may lack list permission; treat as "not cached"
                # instead of failing the whole completion call.
                if e.response.status_code == 403:
                    return None
                raise VertexAIError(
                    status_code=e.response.status_code, message=e.response.text
                )
            except Exception as e:
                raise VertexAIError(status_code=500, message=str(e))

            raw_response = resp.json()
            logging_obj.post_call(original_response=raw_response)

            if "cachedContents" not in raw_response:
                return None

            all_cached_items = CachedContentListAllResponseBody(**raw_response)

            if "cachedContents" not in all_cached_items:
                return None

            # Check current page for matching cache_key
            for cached_item in all_cached_items["cachedContents"]:
                display_name = cached_item.get("displayName")
                if display_name is not None and display_name == cache_key:
                    return cached_item.get("name")

            # Check if there are more pages
            page_token = all_cached_items.get("nextPageToken")
            if not page_token:
                # No more pages, cache not found
                break

        return None

    def check_and_create_cache(
        self,
        messages: List[AllMessageValues],  # receives openai format messages
        optional_params: dict,  # cache the tools if present, in case cache content exists in messages
        api_key: str,
        api_base: Optional[str],
        model: str,
        client: Optional[HTTPHandler],
        timeout: Optional[Union[float, httpx.Timeout]],
        logging_obj: Logging,
        custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        vertex_auth_header: Optional[str],
        extra_headers: Optional[dict] = None,
        cached_content: Optional[str] = None,
    ) -> Tuple[List[AllMessageValues], dict, Optional[str]]:
        """
        Receives
        - messages: List of dict - messages in the openai format

        Returns
        - messages - List[dict] - filtered list of messages in the openai format.
        - optional_params - dict - possibly with 'tools' removed (tools are cached)
        - cached_content - str - the cache content id, to be passed in the gemini request body

        Follows - https://ai.google.dev/api/caching#request-body
        """
        if cached_content is not None:
            # Caller already supplied a cache id - nothing to do.
            return messages, optional_params, cached_content

        cached_messages, non_cached_messages = separate_cached_messages(
            messages=messages
        )

        if len(cached_messages) == 0:
            return messages, optional_params, None

        # Gemini requires a minimum of 1024 tokens for context caching.
        # Skip caching if the cached content is too small to avoid API errors.
        if not is_prompt_caching_valid_prompt(
            model=model,
            messages=cached_messages,
            custom_llm_provider=custom_llm_provider,
        ):
            verbose_logger.debug(
                "Vertex AI context caching: cached content is below minimum token "
                "count (%d). Skipping context caching.",
                MINIMUM_PROMPT_CACHE_TOKEN_COUNT,
            )
            return messages, optional_params, None

        # Tools are part of the cached content, so remove them from the
        # per-request params and fold them into the cache key/body below.
        tools = optional_params.pop("tools", None)

        ## AUTHORIZATION ##
        token, url = self._get_token_and_url_context_caching(
            gemini_api_key=api_key,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_auth_header=vertex_auth_header,
        )

        headers = {
            "Content-Type": "application/json",
        }
        if token is not None:
            headers["Authorization"] = f"Bearer {token}"
        if extra_headers is not None:
            headers.update(extra_headers)

        if client is None or not isinstance(client, HTTPHandler):
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
                    timeout = httpx.Timeout(timeout)
                _params["timeout"] = timeout
            client = HTTPHandler(**_params)  # type: ignore
        else:
            client = client

        ## CHECK IF CACHED ALREADY
        generated_cache_key = local_cache_obj.get_cache_key(
            messages=cached_messages, tools=tools, model=model
        )
        google_cache_name = self.check_cache(
            cache_key=generated_cache_key,
            client=client,
            headers=headers,
            api_key=api_key,
            api_base=api_base,
            logging_obj=logging_obj,
            custom_llm_provider=custom_llm_provider,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_auth_header=vertex_auth_header,
        )
        if google_cache_name:
            return non_cached_messages, optional_params, google_cache_name

        ## TRANSFORM REQUEST
        cached_content_request_body = (
            transform_openai_messages_to_gemini_context_caching(
                model=model,
                messages=cached_messages,
                cache_key=generated_cache_key,
                custom_llm_provider=custom_llm_provider,
                vertex_project=vertex_project,
                vertex_location=vertex_location,
            )
        )
        cached_content_request_body["tools"] = tools

        ## LOGGING
        logging_obj.pre_call(
            input=messages,
            api_key="",
            additional_args={
                "complete_input_dict": cached_content_request_body,
                "api_base": url,
                "headers": headers,
            },
        )

        try:
            response = client.post(
                url=url, headers=headers, json=cached_content_request_body  # type: ignore
            )
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise VertexAIError(status_code=error_code, message=err.response.text)
        except httpx.TimeoutException:
            raise VertexAIError(status_code=408, message="Timeout error occurred.")

        raw_response_cached = response.json()
        cached_content_response_obj = VertexAICachedContentResponseObject(
            name=raw_response_cached.get("name"), model=raw_response_cached.get("model")
        )
        return (
            non_cached_messages,
            optional_params,
            cached_content_response_obj["name"],
        )

    async def async_check_and_create_cache(
        self,
        messages: List[AllMessageValues],  # receives openai format messages
        optional_params: dict,  # cache the tools if present, in case cache content exists in messages
        api_key: str,
        api_base: Optional[str],
        model: str,
        client: Optional[AsyncHTTPHandler],
        timeout: Optional[Union[float, httpx.Timeout]],
        logging_obj: Logging,
        custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        vertex_auth_header: Optional[str],
        extra_headers: Optional[dict] = None,
        cached_content: Optional[str] = None,
    ) -> Tuple[List[AllMessageValues], dict, Optional[str]]:
        """
        Async variant of `check_and_create_cache`.

        Receives
        - messages: List of dict - messages in the openai format

        Returns
        - messages - List[dict] - filtered list of messages in the openai format.
        - optional_params - dict - possibly with 'tools' removed (tools are cached)
        - cached_content - str - the cache content id, to be passed in the gemini request body

        Follows - https://ai.google.dev/api/caching#request-body
        """
        if cached_content is not None:
            # Caller already supplied a cache id - nothing to do.
            return messages, optional_params, cached_content

        cached_messages, non_cached_messages = separate_cached_messages(
            messages=messages
        )

        if len(cached_messages) == 0:
            return messages, optional_params, None

        # Gemini requires a minimum of 1024 tokens for context caching.
        # Skip caching if the cached content is too small to avoid API errors.
        if not is_prompt_caching_valid_prompt(
            model=model,
            messages=cached_messages,
            custom_llm_provider=custom_llm_provider,
        ):
            verbose_logger.debug(
                "Vertex AI context caching: cached content is below minimum token "
                "count (%d). Skipping context caching.",
                MINIMUM_PROMPT_CACHE_TOKEN_COUNT,
            )
            return messages, optional_params, None

        # Tools are part of the cached content, so remove them from the
        # per-request params and fold them into the cache key/body below.
        tools = optional_params.pop("tools", None)

        ## AUTHORIZATION ##
        token, url = self._get_token_and_url_context_caching(
            gemini_api_key=api_key,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_auth_header=vertex_auth_header,
        )

        headers = {
            "Content-Type": "application/json",
        }
        if token is not None:
            headers["Authorization"] = f"Bearer {token}"
        if extra_headers is not None:
            headers.update(extra_headers)

        if client is None or not isinstance(client, AsyncHTTPHandler):
            client = get_async_httpx_client(
                params={"timeout": timeout}, llm_provider=litellm.LlmProviders.VERTEX_AI
            )
        else:
            client = client

        ## CHECK IF CACHED ALREADY
        generated_cache_key = local_cache_obj.get_cache_key(
            messages=cached_messages, tools=tools, model=model
        )
        google_cache_name = await self.async_check_cache(
            cache_key=generated_cache_key,
            client=client,
            headers=headers,
            api_key=api_key,
            api_base=api_base,
            logging_obj=logging_obj,
            custom_llm_provider=custom_llm_provider,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_auth_header=vertex_auth_header,
        )
        if google_cache_name:
            return non_cached_messages, optional_params, google_cache_name

        ## TRANSFORM REQUEST
        cached_content_request_body = (
            transform_openai_messages_to_gemini_context_caching(
                model=model,
                messages=cached_messages,
                cache_key=generated_cache_key,
                custom_llm_provider=custom_llm_provider,
                vertex_project=vertex_project,
                vertex_location=vertex_location,
            )
        )
        cached_content_request_body["tools"] = tools

        ## LOGGING
        logging_obj.pre_call(
            input=messages,
            api_key="",
            additional_args={
                "complete_input_dict": cached_content_request_body,
                "api_base": url,
                "headers": headers,
            },
        )

        try:
            response = await client.post(
                url=url, headers=headers, json=cached_content_request_body  # type: ignore
            )
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise VertexAIError(status_code=error_code, message=err.response.text)
        except httpx.TimeoutException:
            raise VertexAIError(status_code=408, message="Timeout error occurred.")

        raw_response_cached = response.json()
        cached_content_response_obj = VertexAICachedContentResponseObject(
            name=raw_response_cached.get("name"), model=raw_response_cached.get("model")
        )
        return (
            non_cached_messages,
            optional_params,
            cached_content_response_obj["name"],
        )

    def get_cache(self):
        # Not implemented - placeholder for a direct cachedContents GET.
        pass

    async def async_get_cache(self):
        # Not implemented - placeholder for a direct cachedContents GET.
        pass

View File

@@ -0,0 +1,273 @@
# What is this?
## Cost calculation for Google AI Studio / Vertex AI models
from typing import Literal, Optional, Tuple, Union
import litellm
from litellm import verbose_logger
from litellm.litellm_core_utils.llm_cost_calc.utils import (
_is_above_128k,
generic_cost_per_token,
)
from litellm.types.utils import ModelInfo, Usage
"""
Gemini pricing covers:
- token
- image
- audio
- video
"""
"""
Vertex AI -> character based pricing
Google AI Studio -> token based pricing
"""
models_without_dynamic_pricing = ["gemini-1.0-pro", "gemini-pro", "gemini-2"]
def cost_router(
    model: str,
    custom_llm_provider: str,
    call_type: Union[Literal["embedding", "aembedding"], str],
) -> Literal["cost_per_character", "cost_per_token"]:
    """
    Route the cost calc to the right place, based on model/call_type/etc.

    Returns
    - str, the specific google cost calc function it should route to.
    """
    if custom_llm_provider == "vertex_ai":
        # Partner / OSS models hosted on Vertex are priced per token.
        token_priced_families = (
            "claude",
            "llama",
            "mistral",
            "jamba",
            "codestral",
            "gemma",
        )
        if any(family in model for family in token_priced_families):
            return "cost_per_token"
        # Embeddings are always token priced.
        if call_type in ("embedding", "aembedding"):
            return "cost_per_token"
        # gemini-2+ moved to token-based pricing on Vertex.
        if "gemini-2" in model:
            return "cost_per_token"
    return "cost_per_character"
def cost_per_character(
    model: str,
    custom_llm_provider: str,
    usage: Usage,
    prompt_characters: Optional[float] = None,
    completion_characters: Optional[float] = None,
) -> Tuple[float, float]:
    """
    Calculates the cost per character for a given VertexAI model, input messages, and response object.

    Input:
        - model: str, the model name without provider prefix
        - custom_llm_provider: str, "vertex_ai-*"
        - usage: token usage for the call (fallback when characters are unknown)
        - prompt_characters: float, the number of input characters
        - completion_characters: float, the number of output characters

    Returns:
        Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd

    Raises:
        Exception if model requires >128k pricing, but model cost not mapped
    """
    ## GET MODEL INFO
    # NOTE: previously this was fetched twice back-to-back with identical
    # arguments; a single lookup is sufficient.
    model_info = litellm.get_model_info(
        model=model, custom_llm_provider=custom_llm_provider
    )

    ## CALCULATE INPUT COST
    if prompt_characters is None:
        # No character count available - fall back to token-based pricing.
        prompt_cost, _ = cost_per_token(
            model=model,
            custom_llm_provider=custom_llm_provider,
            usage=usage,
        )
    else:
        try:
            if (
                _is_above_128k(tokens=prompt_characters * 4)  # 1 token = 4 char
                and model not in models_without_dynamic_pricing
            ):
                ## check if character pricing, else default to token pricing
                assert (
                    "input_cost_per_character_above_128k_tokens" in model_info
                    and model_info["input_cost_per_character_above_128k_tokens"]
                    is not None
                ), "model info for model={} does not have 'input_cost_per_character_above_128k_tokens'-pricing for > 128k tokens\nmodel_info={}".format(
                    model, model_info
                )
                prompt_cost = (
                    prompt_characters
                    * model_info["input_cost_per_character_above_128k_tokens"]
                )
            else:
                assert (
                    "input_cost_per_character" in model_info
                    and model_info["input_cost_per_character"] is not None
                ), "model info for model={} does not have 'input_cost_per_character'-pricing\nmodel_info={}".format(
                    model, model_info
                )
                prompt_cost = prompt_characters * model_info["input_cost_per_character"]
        except Exception as e:
            # Character pricing unavailable - log and fall back to tokens.
            verbose_logger.debug(
                "litellm.litellm_core_utils.llm_cost_calc.google.py::cost_per_character(): Exception occured - {}\nDefaulting to None".format(
                    str(e)
                )
            )
            prompt_cost, _ = cost_per_token(
                model=model,
                custom_llm_provider=custom_llm_provider,
                usage=usage,
            )

    ## CALCULATE OUTPUT COST
    if completion_characters is None:
        _, completion_cost = cost_per_token(
            model=model,
            custom_llm_provider=custom_llm_provider,
            usage=usage,
        )
    else:
        completion_tokens = usage.completion_tokens
        try:
            if (
                _is_above_128k(tokens=completion_characters * 4)  # 1 token = 4 char
                and model not in models_without_dynamic_pricing
            ):
                assert (
                    "output_cost_per_character_above_128k_tokens" in model_info
                    and model_info["output_cost_per_character_above_128k_tokens"]
                    is not None
                ), "model info for model={} does not have 'output_cost_per_character_above_128k_tokens' pricing\nmodel_info={}".format(
                    model, model_info
                )
                # NOTE(review): the >128k branch multiplies completion TOKENS by a
                # per-CHARACTER rate (pre-existing behavior) - confirm intended.
                completion_cost = (
                    completion_tokens
                    * model_info["output_cost_per_character_above_128k_tokens"]
                )
            else:
                assert (
                    "output_cost_per_character" in model_info
                    and model_info["output_cost_per_character"] is not None
                ), "model info for model={} does not have 'output_cost_per_character'-pricing\nmodel_info={}".format(
                    model, model_info
                )
                completion_cost = (
                    completion_characters * model_info["output_cost_per_character"]
                )
        except Exception as e:
            # Character pricing unavailable - log and fall back to tokens.
            verbose_logger.debug(
                "litellm.litellm_core_utils.llm_cost_calc.google.py::cost_per_character(): Exception occured - {}\nDefaulting to None".format(
                    str(e)
                )
            )
            _, completion_cost = cost_per_token(
                model=model,
                custom_llm_provider=custom_llm_provider,
                usage=usage,
            )

    return prompt_cost, completion_cost
def _handle_128k_pricing(
    model_info: ModelInfo,
    usage: Usage,
) -> Tuple[float, float]:
    """
    Price tokens for models that define separate above-128k rates.

    When the usage crosses the 128k threshold and the model defines a
    '*_above_128k_tokens' rate, that rate is applied to ALL tokens; otherwise
    the standard per-token rate is used.

    Returns:
        Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
    """
    # NOTE: the output rate was previously fetched twice; one lookup each
    # is sufficient.
    input_cost_per_token_above_128k_tokens = model_info.get(
        "input_cost_per_token_above_128k_tokens"
    )
    output_cost_per_token_above_128k_tokens = model_info.get(
        "output_cost_per_token_above_128k_tokens"
    )

    prompt_tokens = usage.prompt_tokens
    completion_tokens = usage.completion_tokens

    ## CALCULATE INPUT COST
    if (
        _is_above_128k(tokens=prompt_tokens)
        and input_cost_per_token_above_128k_tokens is not None
    ):
        prompt_cost = prompt_tokens * input_cost_per_token_above_128k_tokens
    else:
        prompt_cost = prompt_tokens * model_info["input_cost_per_token"]

    ## CALCULATE OUTPUT COST
    if (
        _is_above_128k(tokens=completion_tokens)
        and output_cost_per_token_above_128k_tokens is not None
    ):
        completion_cost = completion_tokens * output_cost_per_token_above_128k_tokens
    else:
        completion_cost = completion_tokens * model_info["output_cost_per_token"]

    return prompt_cost, completion_cost
def cost_per_token(
    model: str,
    custom_llm_provider: str,
    usage: Usage,
    service_tier: Optional[str] = None,
) -> Tuple[float, float]:
    """
    Calculates the cost per token for a given model, prompt tokens, and completion tokens.

    Input:
        - model: str, the model name without provider prefix
        - custom_llm_provider: str, either "vertex_ai-*" or "gemini"
        - usage: token usage for the call
        - service_tier: optional tier derived from Gemini trafficType
          ("priority" for ON_DEMAND_PRIORITY, "flex" for FLEX/batch).

    Returns:
        Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd

    Raises:
        Exception if model requires >128k pricing, but model cost not mapped
    """
    ## GET MODEL INFO
    model_info = litellm.get_model_info(
        model=model, custom_llm_provider=custom_llm_provider
    )

    ## HANDLE 128k+ PRICING
    # Models that declare dedicated >128k rates take the dynamic-pricing path.
    has_128k_pricing = (
        model_info.get("input_cost_per_token_above_128k_tokens") is not None
        or model_info.get("output_cost_per_token_above_128k_tokens") is not None
    )
    if has_128k_pricing:
        return _handle_128k_pricing(
            model_info=model_info,
            usage=usage,
        )

    return generic_cost_per_token(
        model=model,
        custom_llm_provider=custom_llm_provider,
        usage=usage,
        service_tier=service_tier,
    )

View File

@@ -0,0 +1,48 @@
from typing import Any, Dict, Optional, Tuple
from litellm.llms.gemini.count_tokens.handler import GoogleAIStudioTokenCounter
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
class VertexAITokenCounter(GoogleAIStudioTokenCounter, VertexBase):
    # Reuses the Google AI Studio countTokens handler but overrides
    # environment validation to authenticate against Vertex AI instead.

    async def validate_environment(
        self,
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        headers: Optional[Dict[str, Any]] = None,
        model: str = "",
        litellm_params: Optional[Dict[str, Any]] = None,
    ) -> Tuple[Dict[str, Any], str]:
        """
        Returns a Tuple of headers and url for the Vertex AI countTokens endpoint.

        NOTE(review): the incoming api_base / api_key / headers arguments are
        ignored here - the url and auth header are always derived from the
        litellm_params credentials. Confirm this is intended.
        """
        litellm_params = litellm_params or {}
        # Resolve credentials / project / location from litellm params
        # (VertexBase helpers handle env-var fallbacks).
        vertex_credentials = self.get_vertex_ai_credentials(
            litellm_params=litellm_params
        )
        vertex_project = self.get_vertex_ai_project(litellm_params=litellm_params)
        vertex_location = self.get_vertex_ai_location(litellm_params=litellm_params)
        should_use_v1beta1_features = self.is_using_v1beta1_features(litellm_params)

        # Exchange credentials for an access token; this may also resolve the
        # project id when it was not explicitly configured.
        _auth_header, vertex_project = await self._ensure_access_token_async(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai",
        )
        auth_header, api_base = self._get_token_and_url(
            model=model,
            gemini_api_key=None,
            auth_header=_auth_header,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_credentials=vertex_credentials,
            stream=False,
            custom_llm_provider="vertex_ai",
            api_base=None,
            should_use_v1beta1_features=should_use_v1beta1_features,
            mode="count_tokens",
        )

        headers = {
            "Authorization": f"Bearer {auth_header}",
        }
        return headers, api_base

View File

@@ -0,0 +1,246 @@
import asyncio
import urllib.parse
from typing import Any, Coroutine, Optional, Tuple, Union
import httpx
from litellm import LlmProviders
from litellm.integrations.gcs_bucket.gcs_bucket_base import (
GCSBucketBase,
GCSLoggingConfig,
)
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
from litellm.types.llms.openai import (
CreateFileRequest,
FileContentRequest,
HttpxBinaryResponseContent,
OpenAIFileObject,
)
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
from .transformation import VertexAIJsonlFilesTransformation
# Module-level singleton used to transform OpenAI JSONL batch payloads into
# the Vertex AI "request"-wrapped JSONL format (shared by the files handler).
vertex_ai_files_transformation = VertexAIJsonlFilesTransformation()
class VertexAIFilesHandler(GCSBucketBase):
    """
    Handles Calling VertexAI in OpenAI Files API format v1/files/*

    This implementation uploads files on GCS Buckets
    """

    def __init__(self):
        super().__init__()
        # Async client tagged with the Vertex AI provider so provider-specific
        # client settings/telemetry apply.
        self.async_httpx_client = get_async_httpx_client(
            llm_provider=LlmProviders.VERTEX_AI,
        )

    async def async_create_file(
        self,
        create_file_data: CreateFileRequest,
        api_base: Optional[str],
        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
    ) -> OpenAIFileObject:
        """
        Upload the given file content to the configured GCS bucket and return
        an OpenAI-style file object describing the upload.

        Note: api_base / vertex_* / timeout / max_retries are accepted for
        interface parity but the upload target comes from the GCS logging
        config (bucket name + service account) resolved below.
        """
        # Resolve bucket name and service-account credentials from env/config.
        gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
            kwargs={}
        )
        headers = await self.construct_request_headers(
            vertex_instance=gcs_logging_config["vertex_instance"],
            service_account_json=gcs_logging_config["path_service_account"],
        )
        bucket_name = gcs_logging_config["bucket_name"]
        # Transform the OpenAI JSONL payload into Vertex "request"-wrapped
        # JSONL and derive a unique GCS object name for it.
        (
            logging_payload,
            object_name,
        ) = vertex_ai_files_transformation.transform_openai_file_content_to_vertex_ai_file_content(
            openai_file_content=create_file_data.get("file")
        )
        gcs_upload_response = await self._log_json_data_on_gcs(
            headers=headers,
            bucket_name=bucket_name,
            object_name=object_name,
            logging_payload=logging_payload,
        )
        # Map the GCS upload response back onto the OpenAI file-object schema.
        return vertex_ai_files_transformation.transform_gcs_bucket_response_to_openai_file_object(
            create_file_data=create_file_data,
            gcs_upload_response=gcs_upload_response,
        )

    def create_file(
        self,
        _is_async: bool,
        create_file_data: CreateFileRequest,
        api_base: Optional[str],
        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
    ) -> Union[OpenAIFileObject, Coroutine[Any, Any, OpenAIFileObject]]:
        """
        Creates a file on VertexAI GCS Bucket

        Only supported for Async litellm.acreate_file
        """
        if _is_async:
            # Return the coroutine; the caller is responsible for awaiting it.
            return self.async_create_file(
                create_file_data=create_file_data,
                api_base=api_base,
                vertex_credentials=vertex_credentials,
                vertex_project=vertex_project,
                vertex_location=vertex_location,
                timeout=timeout,
                max_retries=max_retries,
            )
        else:
            # Sync path: drive the async implementation to completion.
            # NOTE(review): asyncio.run raises if invoked from within a running
            # event loop — presumably the sync path is never hit there; verify.
            return asyncio.run(
                self.async_create_file(
                    create_file_data=create_file_data,
                    api_base=api_base,
                    vertex_credentials=vertex_credentials,
                    vertex_project=vertex_project,
                    vertex_location=vertex_location,
                    timeout=timeout,
                    max_retries=max_retries,
                )
            )

    def _extract_bucket_and_object_from_file_id(self, file_id: str) -> Tuple[str, str]:
        """
        Extract bucket name and object path from URL-encoded file_id.

        Expected format: gs%3A%2F%2Fbucket-name%2Fpath%2Fto%2Ffile
        Which decodes to: gs://bucket-name/path/to/file

        Returns:
            tuple: (bucket_name, url_encoded_object_path)
                - bucket_name: "bucket-name"
                - url_encoded_object_path: "path%2Fto%2Ffile"
        """
        decoded_path = urllib.parse.unquote(file_id)

        if decoded_path.startswith("gs://"):
            full_path = decoded_path[5:]  # Remove 'gs://' prefix
        else:
            full_path = decoded_path

        # First path segment is the bucket; the rest is the object path.
        if "/" in full_path:
            bucket_name, object_path = full_path.split("/", 1)
        else:
            bucket_name = full_path
            object_path = ""

        # Re-encode the object path (safe="") so embedded slashes become %2F,
        # as required by the GCS JSON API object-name URL component.
        encoded_object_path = urllib.parse.quote(object_path, safe="")
        return bucket_name, encoded_object_path

    async def afile_content(
        self,
        file_content_request: FileContentRequest,
        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
    ) -> HttpxBinaryResponseContent:
        """
        Download file content from GCS bucket for VertexAI files.

        Args:
            file_content_request: Contains file_id (URL-encoded GCS path)
            vertex_credentials: VertexAI credentials
            vertex_project: VertexAI project ID
            vertex_location: VertexAI location
            timeout: Request timeout
            max_retries: Max retry attempts

        Returns:
            HttpxBinaryResponseContent: Binary content wrapped in compatible response format

        Raises:
            ValueError: if file_id is missing or the GCS download fails.
        """
        file_id = file_content_request.get("file_id")
        if not file_id:
            raise ValueError("file_id is required in file_content_request")

        bucket_name, encoded_object_path = self._extract_bucket_and_object_from_file_id(
            file_id
        )

        # Route the download at the bucket parsed from the file id rather than
        # the default logging bucket.
        download_kwargs = {
            "standard_callback_dynamic_params": {"gcs_bucket_name": bucket_name}
        }
        file_content = await self.download_gcs_object(
            object_name=encoded_object_path, **download_kwargs
        )
        if file_content is None:
            decoded_path = urllib.parse.unquote(file_id)
            raise ValueError(f"Failed to download file from GCS: {decoded_path}")

        # Wrap the raw bytes in a synthetic httpx.Response so callers get the
        # same HttpxBinaryResponseContent shape as other providers.
        decoded_path = urllib.parse.unquote(file_id)
        mock_response = httpx.Response(
            status_code=200,
            content=file_content,
            headers={"content-type": "application/octet-stream"},
            request=httpx.Request(method="GET", url=decoded_path),
        )
        return HttpxBinaryResponseContent(response=mock_response)

    def file_content(
        self,
        _is_async: bool,
        file_content_request: FileContentRequest,
        api_base: Optional[str],
        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
    ) -> Union[
        HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent]
    ]:
        """
        Download file content from GCS bucket for VertexAI files.

        Supports both sync and async operations.

        Args:
            _is_async: Whether to run asynchronously
            file_content_request: Contains file_id (URL-encoded GCS path)
            api_base: API base (unused for GCS operations)
            vertex_credentials: VertexAI credentials
            vertex_project: VertexAI project ID
            vertex_location: VertexAI location
            timeout: Request timeout
            max_retries: Max retry attempts

        Returns:
            HttpxBinaryResponseContent or Coroutine: Binary content wrapped in compatible response format
        """
        if _is_async:
            # Return the coroutine for the caller to await.
            return self.afile_content(
                file_content_request=file_content_request,
                vertex_credentials=vertex_credentials,
                vertex_project=vertex_project,
                vertex_location=vertex_location,
                timeout=timeout,
                max_retries=max_retries,
            )
        else:
            # Sync path: run the async download to completion.
            return asyncio.run(
                self.afile_content(
                    file_content_request=file_content_request,
                    vertex_credentials=vertex_credentials,
                    vertex_project=vertex_project,
                    vertex_location=vertex_location,
                    timeout=timeout,
                    max_retries=max_retries,
                )
            )

View File

@@ -0,0 +1,607 @@
import json
import os
import time
from typing import Any, Dict, List, Optional, Tuple, Union
from httpx import Headers, Response
from openai.types.file_deleted import FileDeleted
from litellm._uuid import uuid
from litellm.files.utils import FilesAPIUtils
from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.base_llm.files.transformation import (
BaseFilesConfig,
LiteLLMLoggingObj,
)
from litellm.llms.vertex_ai.common_utils import (
_convert_vertex_datetime_to_openai_datetime,
)
from litellm.llms.vertex_ai.gemini.transformation import _transform_request_body
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
VertexGeminiConfig,
)
from litellm.types.llms.openai import (
AllMessageValues,
CreateFileRequest,
FileTypes,
HttpxBinaryResponseContent,
OpenAICreateFileRequestOptionalParams,
OpenAIFileObject,
PathLike,
)
from litellm.types.llms.vertex_ai import GcsBucketResponse
from litellm.types.utils import ExtractedFileData, LlmProviders
from ..common_utils import VertexAIError
from ..vertex_llm_base import VertexBase
class VertexAIFilesConfig(VertexBase, BaseFilesConfig):
    """
    Config for VertexAI Files

    Maps the OpenAI Files API (create/retrieve/delete/content) onto Google
    Cloud Storage objects; batch ``.jsonl`` uploads are additionally rewritten
    into the Vertex AI "request"-wrapped JSONL format.
    """

    def __init__(self):
        # Helper used for jsonl-specific transformations.
        self.jsonl_transformation = VertexAIJsonlFilesTransformation()
        super().__init__()

    @property
    def custom_llm_provider(self) -> LlmProviders:
        return LlmProviders.VERTEX_AI

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """Ensure an OAuth bearer token is present in the request headers."""
        if not api_key:
            # No explicit key: mint an access token from vertex credentials.
            api_key, _ = self.get_access_token(
                credentials=litellm_params.get("vertex_credentials"),
                project_id=litellm_params.get("vertex_project"),
            )
            if not api_key:
                raise ValueError("api_key is required")
        headers["Authorization"] = f"Bearer {api_key}"
        return headers

    def _get_content_from_openai_file(self, openai_file_content: FileTypes) -> str:
        """
        Helper to extract content from various OpenAI file types and return as string.

        Handles:
        - Direct content (str, bytes, IO[bytes])
        - Tuple formats: (filename, content, [content_type], [headers])
        - PathLike objects
        """
        content: Union[str, bytes] = b""
        # Extract file content from tuple if necessary
        if isinstance(openai_file_content, tuple):
            # Take the second element which is always the file content
            file_content = openai_file_content[1]
        else:
            file_content = openai_file_content

        # Handle different file content types
        if isinstance(file_content, str):
            # String content can be used directly
            content = file_content
        elif isinstance(file_content, bytes):
            # Bytes content can be decoded
            content = file_content
        elif isinstance(file_content, PathLike):  # PathLike
            with open(str(file_content), "rb") as f:
                content = f.read()
        elif hasattr(file_content, "read"):  # IO[bytes]
            # File-like objects need to be read
            content = file_content.read()

        # Ensure content is string
        if isinstance(content, bytes):
            content = content.decode("utf-8")

        return content

    def _get_gcs_object_name_from_batch_jsonl(
        self,
        openai_jsonl_content: List[Dict[str, Any]],
    ) -> str:
        """
        Gets a unique GCS object name for the VertexAI batch prediction job

        named as: litellm-vertex-{model}-{uuid}
        """
        # Model from the first request line; normalized to the full
        # "publishers/google/models/..." path.
        _model = openai_jsonl_content[0].get("body", {}).get("model", "")
        if "publishers/google/models" not in _model:
            _model = f"publishers/google/models/{_model}"
        object_name = f"litellm-vertex-files/{_model}/{uuid.uuid4()}"
        return object_name

    def get_object_name(
        self, extracted_file_data: ExtractedFileData, purpose: str
    ) -> str:
        """
        Get the object name for the request
        """
        extracted_file_data_content = extracted_file_data.get("content")

        if extracted_file_data_content is None:
            raise ValueError("file content is required")

        if purpose == "batch":
            ## 1. If jsonl, check if there's a model name
            file_content = self._get_content_from_openai_file(
                extracted_file_data_content
            )

            # Split into lines and parse each line as JSON
            openai_jsonl_content = [
                json.loads(line) for line in file_content.splitlines() if line.strip()
            ]
            if len(openai_jsonl_content) > 0:
                return self._get_gcs_object_name_from_batch_jsonl(openai_jsonl_content)

        ## 2. If not jsonl, return the filename
        filename = extracted_file_data.get("filename")
        if filename:
            return filename
        ## 3. If no file name, return timestamp
        return str(int(time.time()))

    def get_complete_file_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: Dict,
        litellm_params: Dict,
        data: CreateFileRequest,
    ) -> str:
        """
        Get the complete url for the request
        """
        # Bucket precedence: explicit param > litellm_metadata > env var.
        # NOTE(review): `.pop` mutates the caller's litellm_metadata dict —
        # presumably intentional so the key isn't forwarded; verify.
        bucket_name = (
            litellm_params.get("bucket_name")
            or litellm_params.get("litellm_metadata", {}).pop("gcs_bucket_name", None)
            or os.getenv("GCS_BUCKET_NAME")
        )
        if not bucket_name:
            raise ValueError("GCS bucket_name is required")
        file_data = data.get("file")
        purpose = data.get("purpose")
        if file_data is None:
            raise ValueError("file is required")
        if purpose is None:
            raise ValueError("purpose is required")
        extracted_file_data = extract_file_data(file_data)
        object_name = self.get_object_name(extracted_file_data, purpose)
        # GCS JSON API media-upload endpoint.
        endpoint = (
            f"upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}"
        )
        api_base = api_base or "https://storage.googleapis.com"
        # NOTE(review): dead check — api_base is always truthy after the
        # fallback above.
        if not api_base:
            raise ValueError("api_base is required")

        return f"{api_base}/{endpoint}"

    def get_supported_openai_params(
        self, model: str
    ) -> List[OpenAICreateFileRequestOptionalParams]:
        # No optional OpenAI create-file params are supported for GCS uploads.
        return []

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        # Pass-through: nothing to map for file uploads.
        return optional_params

    def _map_openai_to_vertex_params(
        self,
        openai_request_body: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        wrapper to call VertexGeminiConfig.map_openai_params
        """
        from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
            VertexGeminiConfig,
        )

        config = VertexGeminiConfig()
        _model = openai_request_body.get("model", "")
        vertex_params = config.map_openai_params(
            model=_model,
            non_default_params=openai_request_body,
            optional_params={},
            drop_params=False,
        )
        return vertex_params

    def _transform_openai_jsonl_content_to_vertex_ai_jsonl_content(
        self, openai_jsonl_content: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """
        Transforms OpenAI JSONL content to VertexAI JSONL content

        jsonl body for vertex is {"request": <request_body>}

        Example Vertex jsonl
        {"request":{"contents": [{"role": "user", "parts": [{"text": "What is the relation between the following video and image samples?"}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/video/animals.mp4", "mimeType": "video/mp4"}}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/image/cricket.jpeg", "mimeType": "image/jpeg"}}]}]}}
        {"request":{"contents": [{"role": "user", "parts": [{"text": "Describe what is happening in this video."}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/video/another_video.mov", "mimeType": "video/mov"}}]}]}}
        """
        vertex_jsonl_content = []
        for _openai_jsonl_content in openai_jsonl_content:
            openai_request_body = _openai_jsonl_content.get("body") or {}
            # Re-use the chat-completion request transformation per line.
            vertex_request_body = _transform_request_body(
                messages=openai_request_body.get("messages", []),
                model=openai_request_body.get("model", ""),
                optional_params=self._map_openai_to_vertex_params(openai_request_body),
                custom_llm_provider="vertex_ai",
                litellm_params={},
                cached_content=None,
            )
            vertex_jsonl_content.append({"request": vertex_request_body})
        return vertex_jsonl_content

    def transform_create_file_request(
        self,
        model: str,
        create_file_data: CreateFileRequest,
        optional_params: dict,
        litellm_params: dict,
    ) -> Union[bytes, str, dict]:
        """
        2 Cases:
        1. Handle basic file upload
        2. Handle batch file upload (.jsonl)
        """
        file_data = create_file_data.get("file")
        if file_data is None:
            raise ValueError("file is required")
        extracted_file_data = extract_file_data(file_data)
        extracted_file_data_content = extracted_file_data.get("content")
        if extracted_file_data_content is None:
            raise ValueError("file content is required")

        if FilesAPIUtils.is_batch_jsonl_file(
            create_file_data=create_file_data,
            extracted_file_data=extracted_file_data,
        ):
            ## 1. If jsonl, check if there's a model name
            file_content = self._get_content_from_openai_file(
                extracted_file_data_content
            )

            # Split into lines and parse each line as JSON
            openai_jsonl_content = [
                json.loads(line) for line in file_content.splitlines() if line.strip()
            ]
            vertex_jsonl_content = (
                self._transform_openai_jsonl_content_to_vertex_ai_jsonl_content(
                    openai_jsonl_content
                )
            )
            # Serialize back to newline-delimited JSON for the GCS upload body.
            return "\n".join(json.dumps(item) for item in vertex_jsonl_content)
        elif isinstance(extracted_file_data_content, bytes):
            # Non-batch files are uploaded verbatim.
            return extracted_file_data_content
        else:
            raise ValueError("Unsupported file content type")

    def transform_create_file_response(
        self,
        model: Optional[str],
        raw_response: Response,
        logging_obj: LiteLLMLoggingObj,
        litellm_params: dict,
    ) -> OpenAIFileObject:
        """
        Transform VertexAI File upload response into OpenAI-style FileObject
        """
        response_json = raw_response.json()

        try:
            response_object = GcsBucketResponse(**response_json)  # type: ignore
        except Exception as e:
            raise VertexAIError(
                status_code=raw_response.status_code,
                message=f"Error reading GCS bucket response: {e}",
                headers=raw_response.headers,
            )

        gcs_id = response_object.get("id", "")
        # Remove the last numeric ID from the path
        # (GCS object ids end in "/<generation>"; the file id should not).
        gcs_id = "/".join(gcs_id.split("/")[:-1]) if gcs_id else ""

        return OpenAIFileObject(
            purpose=response_object.get("purpose", "batch"),
            id=f"gs://{gcs_id}",
            filename=response_object.get("name", ""),
            created_at=_convert_vertex_datetime_to_openai_datetime(
                vertex_datetime=response_object.get("timeCreated", "")
            ),
            status="uploaded",
            bytes=int(response_object.get("size", 0)),
            object="file",
        )

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[Dict, Headers]
    ) -> BaseLLMException:
        return VertexAIError(
            status_code=status_code, message=error_message, headers=headers
        )

    def _parse_gcs_uri(self, file_id: str) -> Tuple[str, str]:
        """
        Parse a GCS URI (gs://bucket/path/to/object) into (bucket, url-encoded-object-path).
        Handles both raw and URL-encoded input.
        """
        import urllib.parse

        decoded = urllib.parse.unquote(file_id)
        if decoded.startswith("gs://"):
            full_path = decoded[5:]
        else:
            full_path = decoded
        if "/" in full_path:
            bucket_name, object_path = full_path.split("/", 1)
        else:
            bucket_name = full_path
            object_path = ""
        # safe="" percent-encodes slashes too, as required for the GCS
        # object-name URL component.
        encoded_object = urllib.parse.quote(object_path, safe="")
        return bucket_name, encoded_object

    def transform_retrieve_file_request(
        self,
        file_id: str,
        optional_params: dict,
        litellm_params: dict,
    ) -> tuple[str, dict]:
        # GET object metadata (no ?alt=media).
        bucket, encoded_object = self._parse_gcs_uri(file_id)
        url = f"https://storage.googleapis.com/storage/v1/b/{bucket}/o/{encoded_object}"
        return url, {}

    def transform_retrieve_file_response(
        self,
        raw_response: Response,
        logging_obj: LiteLLMLoggingObj,
        litellm_params: dict,
    ) -> OpenAIFileObject:
        """Map a GCS object-metadata response onto an OpenAI FileObject."""
        response_json = raw_response.json()
        gcs_id = response_json.get("id", "")
        # Strip the trailing generation number from the GCS object id.
        gcs_id = "/".join(gcs_id.split("/")[:-1]) if gcs_id else ""
        return OpenAIFileObject(
            id=f"gs://{gcs_id}",
            bytes=int(response_json.get("size", 0)),
            created_at=_convert_vertex_datetime_to_openai_datetime(
                vertex_datetime=response_json.get("timeCreated", "")
            ),
            filename=response_json.get("name", ""),
            object="file",
            purpose=response_json.get("metadata", {}).get("purpose", "batch"),
            status="processed",
            status_details=None,
        )

    def transform_delete_file_request(
        self,
        file_id: str,
        optional_params: dict,
        litellm_params: dict,
    ) -> tuple[str, dict]:
        bucket, encoded_object = self._parse_gcs_uri(file_id)
        url = f"https://storage.googleapis.com/storage/v1/b/{bucket}/o/{encoded_object}"
        return url, {}

    def transform_delete_file_response(
        self,
        raw_response: Response,
        logging_obj: LiteLLMLoggingObj,
        litellm_params: dict,
    ) -> FileDeleted:
        """Build a FileDeleted result, recovering the gs:// id from the request URL."""
        file_id = "deleted"
        # GCS DELETE returns an empty body, so reconstruct the id from the URL.
        if hasattr(raw_response, "request") and raw_response.request:
            url = str(raw_response.request.url)
            if "/b/" in url and "/o/" in url:
                import urllib.parse

                bucket_part = url.split("/b/")[-1].split("/o/")[0]
                encoded_name = url.split("/o/")[-1].split("?")[0]
                file_id = f"gs://{bucket_part}/{urllib.parse.unquote(encoded_name)}"
        return FileDeleted(id=file_id, deleted=True, object="file")

    def transform_list_files_request(
        self,
        purpose: Optional[str],
        optional_params: dict,
        litellm_params: dict,
    ) -> tuple[str, dict]:
        raise NotImplementedError("VertexAIFilesConfig does not support file listing")

    def transform_list_files_response(
        self,
        raw_response: Response,
        logging_obj: LiteLLMLoggingObj,
        litellm_params: dict,
    ) -> List[OpenAIFileObject]:
        raise NotImplementedError("VertexAIFilesConfig does not support file listing")

    def transform_file_content_request(
        self,
        file_content_request,
        optional_params: dict,
        litellm_params: dict,
    ) -> tuple[str, dict]:
        # ?alt=media requests the object payload instead of its metadata.
        file_id = file_content_request.get("file_id", "")
        bucket, encoded_object = self._parse_gcs_uri(file_id)
        url = f"https://storage.googleapis.com/storage/v1/b/{bucket}/o/{encoded_object}?alt=media"
        return url, {}

    def transform_file_content_response(
        self,
        raw_response: Response,
        logging_obj: LiteLLMLoggingObj,
        litellm_params: dict,
    ) -> HttpxBinaryResponseContent:
        # Raw bytes pass through unchanged.
        return HttpxBinaryResponseContent(response=raw_response)
class VertexAIJsonlFilesTransformation(VertexGeminiConfig):
    """
    Transforms OpenAI /v1/files/* requests to VertexAI /v1/files/* requests
    """

    def transform_openai_file_content_to_vertex_ai_file_content(
        self, openai_file_content: Optional[FileTypes] = None
    ) -> Tuple[str, str]:
        """
        Transforms OpenAI FileContentRequest to VertexAI FileContentRequest

        Returns:
            Tuple of (vertex-format JSONL string, GCS object name).
        """
        if openai_file_content is None:
            raise ValueError("contents of file are None")
        # Read the content of the file
        file_content = self._get_content_from_openai_file(openai_file_content)

        # Split into lines and parse each line as JSON
        openai_jsonl_content = [
            json.loads(line) for line in file_content.splitlines() if line.strip()
        ]
        vertex_jsonl_content = (
            self._transform_openai_jsonl_content_to_vertex_ai_jsonl_content(
                openai_jsonl_content
            )
        )
        vertex_jsonl_string = "\n".join(
            json.dumps(item) for item in vertex_jsonl_content
        )
        object_name = self._get_gcs_object_name(
            openai_jsonl_content=openai_jsonl_content
        )
        return vertex_jsonl_string, object_name

    def _transform_openai_jsonl_content_to_vertex_ai_jsonl_content(
        self, openai_jsonl_content: List[Dict[str, Any]]
    ):
        """
        Transforms OpenAI JSONL content to VertexAI JSONL content

        jsonl body for vertex is {"request": <request_body>}

        Example Vertex jsonl
        {"request":{"contents": [{"role": "user", "parts": [{"text": "What is the relation between the following video and image samples?"}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/video/animals.mp4", "mimeType": "video/mp4"}}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/image/cricket.jpeg", "mimeType": "image/jpeg"}}]}]}}
        {"request":{"contents": [{"role": "user", "parts": [{"text": "Describe what is happening in this video."}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/video/another_video.mov", "mimeType": "video/mov"}}]}]}}
        """
        vertex_jsonl_content = []
        for _openai_jsonl_content in openai_jsonl_content:
            openai_request_body = _openai_jsonl_content.get("body") or {}
            # Re-use the chat-completion request transformation per line.
            vertex_request_body = _transform_request_body(
                messages=openai_request_body.get("messages", []),
                model=openai_request_body.get("model", ""),
                optional_params=self._map_openai_to_vertex_params(openai_request_body),
                custom_llm_provider="vertex_ai",
                litellm_params={},
                cached_content=None,
            )
            vertex_jsonl_content.append({"request": vertex_request_body})
        return vertex_jsonl_content

    def _get_gcs_object_name(
        self,
        openai_jsonl_content: List[Dict[str, Any]],
    ) -> str:
        """
        Gets a unique GCS object name for the VertexAI batch prediction job

        named as: litellm-vertex-{model}-{uuid}
        """
        # Model from the first request line; normalized to the full
        # "publishers/google/models/..." path.
        _model = openai_jsonl_content[0].get("body", {}).get("model", "")
        if "publishers/google/models" not in _model:
            _model = f"publishers/google/models/{_model}"
        object_name = f"litellm-vertex-files/{_model}/{uuid.uuid4()}"
        return object_name

    def _map_openai_to_vertex_params(
        self,
        openai_request_body: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        wrapper to call VertexGeminiConfig.map_openai_params
        """
        _model = openai_request_body.get("model", "")
        vertex_params = self.map_openai_params(
            model=_model,
            non_default_params=openai_request_body,
            optional_params={},
            drop_params=False,
        )
        return vertex_params

    def _get_content_from_openai_file(self, openai_file_content: FileTypes) -> str:
        """
        Helper to extract content from various OpenAI file types and return as string.

        Handles:
        - Direct content (str, bytes, IO[bytes])
        - Tuple formats: (filename, content, [content_type], [headers])
        - PathLike objects
        """
        content: Union[str, bytes] = b""
        # Extract file content from tuple if necessary
        if isinstance(openai_file_content, tuple):
            # Take the second element which is always the file content
            file_content = openai_file_content[1]
        else:
            file_content = openai_file_content

        # Handle different file content types
        if isinstance(file_content, str):
            # String content can be used directly
            content = file_content
        elif isinstance(file_content, bytes):
            # Bytes content can be decoded
            content = file_content
        elif isinstance(file_content, PathLike):  # PathLike
            with open(str(file_content), "rb") as f:
                content = f.read()
        elif hasattr(file_content, "read"):  # IO[bytes]
            # File-like objects need to be read
            content = file_content.read()

        # Ensure content is string
        if isinstance(content, bytes):
            content = content.decode("utf-8")

        return content

    def transform_gcs_bucket_response_to_openai_file_object(
        self, create_file_data: CreateFileRequest, gcs_upload_response: Dict[str, Any]
    ) -> OpenAIFileObject:
        """
        Transforms GCS Bucket upload file response to OpenAI FileObject
        """
        gcs_id = gcs_upload_response.get("id", "")
        # Remove the last numeric ID from the path
        # (GCS object ids end in "/<generation>"; the file id should not).
        gcs_id = "/".join(gcs_id.split("/")[:-1]) if gcs_id else ""

        return OpenAIFileObject(
            purpose=create_file_data.get("purpose", "batch"),
            id=f"gs://{gcs_id}",
            filename=gcs_upload_response.get("name", ""),
            created_at=_convert_vertex_datetime_to_openai_datetime(
                vertex_datetime=gcs_upload_response.get("timeCreated", "")
            ),
            status="uploaded",
            bytes=gcs_upload_response.get("size", 0),
            object="file",
        )

View File

@@ -0,0 +1,374 @@
import json
import traceback
from datetime import datetime
from typing import Any, Coroutine, Literal, Optional, Union
import httpx
import litellm
from litellm._logging import verbose_logger
from litellm.llms.custom_httpx.http_handler import HTTPHandler, get_async_httpx_client
from litellm.llms.vertex_ai.common_utils import get_vertex_base_url
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
from litellm.types.fine_tuning import OpenAIFineTuningHyperparameters
from litellm.types.llms.openai import FineTuningJobCreate
from litellm.types.llms.vertex_ai import (
VERTEX_CREDENTIALS_TYPES,
FineTuneHyperparameters,
FineTuneJobCreate,
FineTunesupervisedTuningSpec,
ResponseSupervisedTuningSpec,
ResponseTuningJob,
)
from litellm.types.utils import LiteLLMFineTuningJob
class VertexFineTuningAPI(VertexLLM):
"""
Vertex methods to support for batches
"""
def __init__(self) -> None:
super().__init__()
self.async_handler = get_async_httpx_client(
llm_provider=litellm.LlmProviders.VERTEX_AI,
params={"timeout": 600.0},
)
def convert_response_created_at(self, response: ResponseTuningJob):
try:
create_time_str = response.get("createTime", "") or ""
create_time_datetime = datetime.fromisoformat(
create_time_str.replace("Z", "+00:00")
)
# Convert to Unix timestamp (seconds since epoch)
created_at = int(create_time_datetime.timestamp())
return created_at
except Exception:
return 0
def convert_openai_request_to_vertex(
self,
create_fine_tuning_job_data: FineTuningJobCreate,
original_hyperparameters: dict = {},
kwargs: Optional[dict] = None,
) -> FineTuneJobCreate:
"""
convert request from OpenAI format to Vertex format
https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning
supervised_tuning_spec = FineTunesupervisedTuningSpec(
"""
supervised_tuning_spec = FineTunesupervisedTuningSpec(
training_dataset_uri=create_fine_tuning_job_data.training_file,
)
if create_fine_tuning_job_data.validation_file:
supervised_tuning_spec[
"validation_dataset"
] = create_fine_tuning_job_data.validation_file
_vertex_hyperparameters = (
self._transform_openai_hyperparameters_to_vertex_hyperparameters(
create_fine_tuning_job_data=create_fine_tuning_job_data,
kwargs=kwargs,
original_hyperparameters=original_hyperparameters,
)
)
if _vertex_hyperparameters and len(_vertex_hyperparameters) > 0:
supervised_tuning_spec["hyperParameters"] = _vertex_hyperparameters
fine_tune_job = FineTuneJobCreate(
baseModel=create_fine_tuning_job_data.model,
supervisedTuningSpec=supervised_tuning_spec,
tunedModelDisplayName=create_fine_tuning_job_data.suffix,
)
return fine_tune_job
def _transform_openai_hyperparameters_to_vertex_hyperparameters(
self,
create_fine_tuning_job_data: FineTuningJobCreate,
original_hyperparameters: dict = {},
kwargs: Optional[dict] = None,
) -> FineTuneHyperparameters:
_oai_hyperparameters = create_fine_tuning_job_data.hyperparameters
_vertex_hyperparameters = FineTuneHyperparameters()
if _oai_hyperparameters:
if _oai_hyperparameters.n_epochs:
_vertex_hyperparameters["epoch_count"] = int(
_oai_hyperparameters.n_epochs
)
if _oai_hyperparameters.learning_rate_multiplier:
_vertex_hyperparameters["learning_rate_multiplier"] = float(
_oai_hyperparameters.learning_rate_multiplier
)
_adapter_size = original_hyperparameters.get("adapter_size", None)
if _adapter_size:
_vertex_hyperparameters["adapter_size"] = _adapter_size
return _vertex_hyperparameters
def convert_vertex_response_to_open_ai_response(
self, response: ResponseTuningJob
) -> LiteLLMFineTuningJob:
status: Literal[
"validating_files", "queued", "running", "succeeded", "failed", "cancelled"
] = "queued"
if response["state"] == "JOB_STATE_PENDING":
status = "queued"
if response["state"] == "JOB_STATE_SUCCEEDED":
status = "succeeded"
if response["state"] == "JOB_STATE_FAILED":
status = "failed"
if response["state"] == "JOB_STATE_CANCELLED":
status = "cancelled"
if response["state"] == "JOB_STATE_RUNNING":
status = "running"
created_at = self.convert_response_created_at(response)
_supervisedTuningSpec: ResponseSupervisedTuningSpec = (
response.get("supervisedTuningSpec", None) or {}
)
training_uri: str = _supervisedTuningSpec.get("trainingDatasetUri", "") or ""
return LiteLLMFineTuningJob(
id=response.get("name", "") or "",
created_at=created_at,
fine_tuned_model=response.get("tunedModelDisplayName", ""),
finished_at=None,
hyperparameters=self._translate_vertex_response_hyperparameters(
vertex_hyper_parameters=_supervisedTuningSpec.get(
"hyperParameters", FineTuneHyperparameters()
)
or {}
),
model=response.get("baseModel", "") or "",
object="fine_tuning.job",
organization_id="",
result_files=[],
seed=0,
status=status,
trained_tokens=None,
training_file=training_uri,
validation_file=None,
estimated_finish=None,
integrations=[],
)
def _translate_vertex_response_hyperparameters(
self, vertex_hyper_parameters: FineTuneHyperparameters
) -> OpenAIFineTuningHyperparameters:
"""
translate vertex responsehyperparameters to openai hyperparameters
"""
_dict_remaining_hyperparameters: dict = dict(vertex_hyper_parameters)
return OpenAIFineTuningHyperparameters(
n_epochs=_dict_remaining_hyperparameters.pop("epoch_count", 0),
**_dict_remaining_hyperparameters,
)
    async def acreate_fine_tuning_job(
        self,
        fine_tuning_url: str,
        headers: dict,
        request_data: FineTuneJobCreate,
    ):
        """
        POST the Vertex tuning-job payload and return the response converted
        to an OpenAI-style fine-tuning job.

        Args:
            fine_tuning_url: Fully-resolved tuningJobs endpoint URL.
            headers: Request headers (must include the Authorization bearer).
            request_data: Vertex-format job-creation payload.

        Raises:
            Exception: on a non-200 response, or any transport/parse error
                (logged with traceback before re-raising).
        """
        try:
            verbose_logger.debug(
                "about to create fine tuning job: %s, request_data: %s",
                fine_tuning_url,
                json.dumps(request_data, indent=4),
            )
            if self.async_handler is None:
                raise ValueError(
                    "VertexAI Fine Tuning - async_handler is not initialized"
                )
            response = await self.async_handler.post(
                headers=headers,
                url=fine_tuning_url,
                json=request_data,  # type: ignore
            )
            if response.status_code != 200:
                raise Exception(
                    f"Error creating fine tuning job. Status code: {response.status_code}. Response: {response.text}"
                )
            verbose_logger.debug(
                "got response from creating fine tuning job: %s", response.json()
            )
            # Validate/shape the raw JSON into the Vertex response type.
            vertex_response = ResponseTuningJob(  # type: ignore
                **response.json(),
            )
            verbose_logger.debug("vertex_response %s", vertex_response)
            open_ai_response = self.convert_vertex_response_to_open_ai_response(
                vertex_response
            )
            return open_ai_response
        except Exception as e:
            # Log the failure with its traceback, then propagate to the caller.
            verbose_logger.error("asyncerror creating fine tuning job %s", e)
            trace_back_str = traceback.format_exc()
            verbose_logger.error(trace_back_str)
            raise e
    def create_fine_tuning_job(
        self,
        _is_async: bool,
        create_fine_tuning_job_data: FineTuningJobCreate,
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        kwargs: Optional[dict] = None,
        original_hyperparameters: Optional[dict] = {},  # NOTE(review): mutable default; appears safe because it is only read (`or {}`), but confirm
    ) -> Union[LiteLLMFineTuningJob, Coroutine[Any, Any, LiteLLMFineTuningJob]]:
        """
        Create a Vertex AI tuning job from an OpenAI-format fine-tuning request.

        Resolves credentials, converts the OpenAI payload to Vertex format, and
        POSTs to the regional ``tuningJobs`` endpoint. When ``_is_async`` is
        True, returns a coroutine performing the same work; otherwise returns
        the created job translated back into the OpenAI schema.

        Raises:
            Exception: If Vertex AI returns a non-200 status code (sync path).
        """
        verbose_logger.debug(
            "creating fine tuning job, args= %s", create_fine_tuning_job_data
        )
        # Resolve the access token and (possibly inferred) project id first.
        _auth_header, vertex_project = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai_beta",
        )
        auth_header, _ = self._get_token_and_url(
            model="",
            auth_header=_auth_header,
            gemini_api_key=None,
            vertex_credentials=vertex_credentials,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            stream=False,
            custom_llm_provider="vertex_ai_beta",
            api_base=api_base,
        )
        headers = {
            "Authorization": f"Bearer {auth_header}",
            "Content-Type": "application/json",
        }
        # Translate the OpenAI-format request into the Vertex tuning job payload.
        fine_tune_job = self.convert_openai_request_to_vertex(
            create_fine_tuning_job_data=create_fine_tuning_job_data,
            kwargs=kwargs,
            original_hyperparameters=original_hyperparameters or {},
        )
        base_url = get_vertex_base_url(vertex_location)
        fine_tuning_url = f"{base_url}/v1/projects/{vertex_project}/locations/{vertex_location}/tuningJobs"
        if _is_async is True:
            # Returns a coroutine; the caller is expected to await it.
            return self.acreate_fine_tuning_job(  # type: ignore
                fine_tuning_url=fine_tuning_url,
                headers=headers,
                request_data=fine_tune_job,
            )
        # NOTE(review): the `timeout` parameter is not used on this sync path —
        # the handler is hard-coded to 600s; confirm whether that is intended.
        sync_handler = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
        verbose_logger.debug(
            "about to create fine tuning job: %s, request_data: %s",
            fine_tuning_url,
            fine_tune_job,
        )
        response = sync_handler.post(
            headers=headers,
            url=fine_tuning_url,
            json=fine_tune_job,  # type: ignore
        )
        if response.status_code != 200:
            raise Exception(
                f"Error creating fine tuning job. Status code: {response.status_code}. Response: {response.text}"
            )
        verbose_logger.debug(
            "got response from creating fine tuning job: %s", response.json()
        )
        vertex_response = ResponseTuningJob(  # type: ignore
            **response.json(),
        )
        verbose_logger.debug("vertex_response %s", vertex_response)
        # Translate the Vertex response back into the OpenAI fine-tuning schema.
        open_ai_response = self.convert_vertex_response_to_open_ai_response(
            vertex_response
        )
        return open_ai_response
async def pass_through_vertex_ai_POST_request(
self,
request_data: dict,
vertex_project: str,
vertex_location: str,
vertex_credentials: str,
request_route: str,
):
_auth_header, vertex_project = await self._ensure_access_token_async(
credentials=vertex_credentials,
project_id=vertex_project,
custom_llm_provider="vertex_ai_beta",
)
auth_header, _ = self._get_token_and_url(
model="",
auth_header=_auth_header,
gemini_api_key=None,
vertex_credentials=vertex_credentials,
vertex_project=vertex_project,
vertex_location=vertex_location,
stream=False,
custom_llm_provider="vertex_ai_beta",
api_base="",
)
headers = {
"Authorization": f"Bearer {auth_header}",
"Content-Type": "application/json",
}
base_url = get_vertex_base_url(vertex_location)
url = None
if request_route == "/tuningJobs":
url = f"{base_url}/v1/projects/{vertex_project}/locations/{vertex_location}/tuningJobs"
elif "/tuningJobs/" in request_route and "cancel" in request_route:
url = f"{base_url}/v1/projects/{vertex_project}/locations/{vertex_location}/tuningJobs{request_route}"
elif "generateContent" in request_route:
url = f"{base_url}/v1/projects/{vertex_project}/locations/{vertex_location}{request_route}"
elif "predict" in request_route:
url = f"{base_url}/v1/projects/{vertex_project}/locations/{vertex_location}{request_route}"
elif "/batchPredictionJobs" in request_route:
url = f"{base_url}/v1/projects/{vertex_project}/locations/{vertex_location}{request_route}"
elif "countTokens" in request_route:
url = f"{base_url}/v1/projects/{vertex_project}/locations/{vertex_location}{request_route}"
elif "cachedContents" in request_route:
_model = request_data.get("model")
if _model is not None and "/publishers/google/models/" not in _model:
request_data[
"model"
] = f"projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{_model}"
url = f"{base_url}/v1beta1/projects/{vertex_project}/locations/{vertex_location}{request_route}"
else:
raise ValueError(f"Unsupported Vertex AI request route: {request_route}")
if self.async_handler is None:
raise ValueError("VertexAI Fine Tuning - async_handler is not initialized")
response = await self.async_handler.post(
headers=headers,
url=url,
json=request_data, # type: ignore
)
if response.status_code != 200:
raise Exception(
f"Error creating fine tuning job. Status code: {response.status_code}. Response: {response.text}"
)
response_json = response.json()
return response_json

View File

@@ -0,0 +1,45 @@
"""
Cost calculator for Vertex AI Gemini.
Used because there are differences in how Google AI Studio and Vertex AI Gemini handle web search requests.
"""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from litellm.types.utils import ModelInfo, Usage
def cost_per_web_search_request(usage: "Usage", model_info: "ModelInfo") -> float:
    """
    Calculate the cost of a web search request for Vertex AI Gemini.

    Vertex AI charges $35/1000 prompts that use grounding (web search),
    independent of the number of web search requests made within the call.
    For a single call, this is $35e-3 USD.

    Args:
        usage: The usage object for the web search request.
        model_info: The model info for the web search request (unused; kept for
            a consistent cost-calculator signature).

    Returns:
        The cost of the web search request, or 0.0 when no web search was made.
    """
    from litellm.types.utils import PromptTokensDetailsWrapper

    # Flat per-call fee for any LLM call that performed web search/grounding.
    cost_per_llm_call_with_web_search = 35e-3

    # Check if the usage object actually records web search requests.
    # Bug fix: previously any populated prompt_tokens_details triggered the
    # web-search charge even when no web search was made.
    makes_web_search_request = False
    if (
        usage is not None
        and usage.prompt_tokens_details is not None
        and isinstance(usage.prompt_tokens_details, PromptTokensDetailsWrapper)
        and usage.prompt_tokens_details.web_search_requests is not None
    ):
        makes_web_search_request = True

    # Calculate total cost
    if makes_web_search_request:
        return cost_per_llm_call_with_web_search
    else:
        return 0.0

View File

@@ -0,0 +1,888 @@
"""
Transformation logic from OpenAI format to Gemini format.
Why separate file? Make it easy to see how transformation works
"""
import json
import os
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union, cast
import httpx
from pydantic import BaseModel
import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.prompt_templates.common_utils import (
_get_image_mime_type_from_url,
)
from litellm.litellm_core_utils.prompt_templates.factory import (
convert_generic_image_chunk_to_openai_image_obj,
convert_to_anthropic_image_obj,
convert_to_gemini_tool_call_invoke,
convert_to_gemini_tool_call_result,
response_schema_prompt,
)
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.files import (
get_file_mime_type_for_file_type,
get_file_type_from_extension,
is_gemini_1_5_accepted_file_type,
)
from litellm.types.llms.openai import (
AllMessageValues,
ChatCompletionAssistantMessage,
ChatCompletionAudioObject,
ChatCompletionFileObject,
ChatCompletionImageObject,
ChatCompletionTextObject,
ChatCompletionUserMessage,
)
from litellm.types.llms.vertex_ai import *
from litellm.types.llms.vertex_ai import (
GenerationConfig,
PartType,
RequestBody,
SafetSettingsConfig,
SystemInstructions,
ToolConfig,
Tools,
)
from litellm.types.utils import GenericImageParsingChunk, LlmProviders
from ..common_utils import (
_check_text_in_content,
get_supports_response_schema,
get_supports_system_message,
)
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
def _convert_detail_to_media_resolution_enum(
detail: Optional[str],
) -> Optional[Dict[str, str]]:
if detail == "low":
return {"level": "MEDIA_RESOLUTION_LOW"}
elif detail == "medium":
return {"level": "MEDIA_RESOLUTION_MEDIUM"}
elif detail == "high":
return {"level": "MEDIA_RESOLUTION_HIGH"}
elif detail == "ultra_high":
return {"level": "MEDIA_RESOLUTION_ULTRA_HIGH"}
return None
def _get_highest_media_resolution(
current: Optional[str], new_detail: Optional[str]
) -> Optional[str]:
"""
Compare two media resolution values and return the highest one.
Resolution hierarchy: ultra_high > high > medium > low > None
"""
resolution_priority = {"ultra_high": 4, "high": 3, "medium": 2, "low": 1}
current_priority = resolution_priority.get(current, 0) if current else 0
new_priority = resolution_priority.get(new_detail, 0) if new_detail else 0
if new_priority > current_priority:
return new_detail
return current
def _extract_max_media_resolution_from_messages(
    messages: "List[AllMessageValues]",
) -> Optional[str]:
    """
    Find the highest image/file ``detail`` level used anywhere in *messages*.

    Used to set a single global ``media_resolution`` in generation_config for
    Gemini 2.x models, which do not support per-part media resolution.

    Args:
        messages: List of messages in OpenAI format.

    Returns:
        The highest detail level found (e.g. "ultra_high", "high", "low"),
        or None when no content item specifies one.
    """
    best: Optional[str] = None
    for message in messages:
        content = message.get("content")
        if not isinstance(content, list):
            continue
        for item in content:
            if not isinstance(item, dict):
                continue
            detail: Optional[str] = None
            item_type = item.get("type")
            if item_type == "image_url":
                nested = item.get("image_url")
                if isinstance(nested, dict):
                    detail = nested.get("detail")
            elif item_type == "file":
                nested = item.get("file")
                if isinstance(nested, dict):
                    detail = nested.get("detail")
            if detail:
                best = _get_highest_media_resolution(best, detail)
    return best
def _apply_gemini_3_metadata(
    part: "PartType",
    model: Optional[str],
    media_resolution_enum: Optional[Dict[str, str]],
    video_metadata: Optional[Dict[str, Any]],
) -> "PartType":
    """
    Attach the Gemini 3+-only ``media_resolution`` / ``video_metadata`` fields.

    Returns *part* unchanged when the model is unknown or predates Gemini 3.
    """
    if model is None:
        return part
    from .vertex_and_google_ai_studio_gemini import VertexGeminiConfig

    if not VertexGeminiConfig._is_gemini_3_or_newer(model):
        return part
    enriched = dict(part)
    if media_resolution_enum is not None:
        enriched["media_resolution"] = media_resolution_enum
    if video_metadata is not None:
        # Translate snake_case offset keys to Gemini's camelCase field names.
        key_map = {
            "fps": "fps",
            "start_offset": "startOffset",
            "end_offset": "endOffset",
        }
        translated = {
            gemini_key: video_metadata[source_key]
            for source_key, gemini_key in key_map.items()
            if source_key in video_metadata
        }
        if translated:
            enriched["video_metadata"] = translated
    return cast("PartType", enriched)
def _process_gemini_media(
    image_url: str,
    format: Optional[str] = None,
    media_resolution_enum: Optional[Dict[str, str]] = None,
    model: Optional[str] = None,
    video_metadata: Optional[Dict[str, Any]] = None,
) -> PartType:
    """
    Given a media URL (image, audio, or video), return the appropriate PartType for Gemini.

    NOTE: video_metadata only applies to videos; it cannot be used with images,
    audio, or files. No special handling is done here because Vertex itself
    rejects the combination with a parameter error.

    Args:
        image_url: The URL or base64 string of the media (image, audio, or video)
        format: The MIME type of the media
        media_resolution_enum: Media resolution level (for Gemini 3+)
        model: The model name (to check version compatibility)
        video_metadata: Video-specific metadata (fps, start_offset, end_offset)
    """
    try:
        # GCS URIs are passed to Gemini by reference (file_data), never inlined.
        if "gs://" in image_url:
            # Figure out file type from the extension when no format was given.
            extension_with_dot = os.path.splitext(image_url)[-1]  # Ex: ".png"
            extension = extension_with_dot[1:]  # Ex: "png"
            if not format:
                file_type = get_file_type_from_extension(extension)
                # Validate the file type is supported by Gemini
                if not is_gemini_1_5_accepted_file_type(file_type):
                    raise Exception(f"File type not supported by gemini - {file_type}")
                mime_type = get_file_mime_type_for_file_type(file_type)
            else:
                mime_type = format
            file_data = FileDataType(mime_type=mime_type, file_uri=image_url)
            part: PartType = {"file_data": file_data}
            return _apply_gemini_3_metadata(
                part, model, media_resolution_enum, video_metadata
            )
        # HTTPS URL with a determinable mime type -> also pass by reference.
        elif (
            "https://" in image_url
            and (image_type := format or _get_image_mime_type_from_url(image_url))
            is not None
        ):
            file_data = FileDataType(mime_type=image_type, file_uri=image_url)
            part = {"file_data": file_data}
            return _apply_gemini_3_metadata(
                part, model, media_resolution_enum, video_metadata
            )
        # Fallback: fetch/decode the media and send the bytes inline (inline_data).
        elif "http://" in image_url or "https://" in image_url or "base64" in image_url:
            image = convert_to_anthropic_image_obj(image_url, format=format)
            _blob: BlobType = {"data": image["data"], "mime_type": image["media_type"]}
            part = {"inline_data": cast(BlobType, _blob)}
            return _apply_gemini_3_metadata(
                part, model, media_resolution_enum, video_metadata
            )
        raise Exception("Invalid image received - {}".format(image_url))
    except Exception as e:
        raise e
def _snake_to_camel(snake_str: str) -> str:
"""Convert snake_case to camelCase"""
components = snake_str.split("_")
return components[0] + "".join(x.capitalize() for x in components[1:])
def _camel_to_snake(camel_str: str) -> str:
"""Convert camelCase to snake_case"""
import re
return re.sub(r"(?<!^)(?=[A-Z])", "_", camel_str).lower()
def _get_equivalent_key(key: str, available_keys: set) -> Optional[str]:
    """
    Resolve *key* against *available_keys*, tolerating case-style differences.

    Checks the exact spelling first, then the camelCase variant, then the
    snake_case variant; returns whichever member of *available_keys* matched,
    or None when no variant is present.
    """
    if key in available_keys:
        return key
    for candidate in (_snake_to_camel(key), _camel_to_snake(key)):
        if candidate in available_keys:
            return candidate
    return None
def check_if_part_exists_in_parts(
    parts: "List[PartType]", part: "PartType", excluded_keys: List[str] = []
) -> bool:
    """
    Return True when an equivalent of *part* already appears in *parts*.

    Keys listed in *excluded_keys* are ignored during comparison, and
    camelCase/snake_case spellings of the same key (e.g. ``function_call``
    vs ``functionCall``) are treated as equivalent.
    """
    relevant_keys = set(part.keys()) - set(excluded_keys)
    for candidate in parts:
        candidate_keys = set(candidate.keys())
        if all(
            (match := _get_equivalent_key(key, candidate_keys)) is not None
            and candidate.get(match, None) == part.get(key, None)
            for key in relevant_keys
        ):
            return True
    return False
def _gemini_convert_messages_with_history(  # noqa: PLR0915
    messages: List[AllMessageValues],
    model: Optional[str] = None,
) -> List[ContentType]:
    """
    Converts given messages from OpenAI format to Gemini format

    - Parts must be iterable
    - Roles must alternate b/w 'user' and 'model' (same as anthropic -> merge consecutive roles)
    - Please ensure that function response turn comes immediately after a function call turn

    Args:
        messages: OpenAI-format messages.
        model: Optional model name; used to gate Gemini-3-only per-part metadata.

    Returns:
        The messages converted into Gemini ``ContentType`` turns.

    Raises:
        Exception: On unknown file entries or when a message cannot be consumed
            (infinite-loop guard).
    """
    user_message_types = {"user", "system"}
    contents: List[ContentType] = []
    last_message_with_tool_calls = None
    msg_i = 0
    # Buffered tool/function results; flushed as a single "user" turn.
    tool_call_responses = []
    try:
        while msg_i < len(messages):
            user_content: List[PartType] = []
            init_msg_i = msg_i
            ## MERGE CONSECUTIVE USER CONTENT ##
            while (
                msg_i < len(messages) and messages[msg_i]["role"] in user_message_types
            ):
                _message_content = messages[msg_i].get("content")
                if _message_content is not None and isinstance(_message_content, list):
                    _parts: List[PartType] = []
                    for element_idx, element in enumerate(_message_content):
                        if (
                            element["type"] == "text"
                            and "text" in element
                            and len(element["text"]) > 0
                        ):
                            element = cast(ChatCompletionTextObject, element)
                            _part = PartType(text=element["text"])
                            _parts.append(_part)
                        elif element["type"] == "image_url":
                            element = cast(ChatCompletionImageObject, element)
                            img_element = element
                            format: Optional[str] = None
                            media_resolution_enum: Optional[Dict[str, str]] = None
                            if isinstance(img_element["image_url"], dict):
                                image_url = img_element["image_url"]["url"]
                                format = img_element["image_url"].get("format")
                                detail = img_element["image_url"].get("detail")
                                media_resolution_enum = (
                                    _convert_detail_to_media_resolution_enum(detail)
                                )
                            else:
                                image_url = img_element["image_url"]
                            _part = _process_gemini_media(
                                image_url=image_url,
                                format=format,
                                media_resolution_enum=media_resolution_enum,
                                model=model,
                            )
                            _parts.append(_part)
                        elif element["type"] == "input_audio":
                            audio_element = cast(ChatCompletionAudioObject, element)
                            audio_data = audio_element["input_audio"].get("data")
                            audio_format = audio_element["input_audio"].get("format")
                            if audio_data is not None and audio_format is not None:
                                audio_format_modified = (
                                    "audio/" + audio_format
                                    if audio_format.startswith("audio/") is False
                                    else audio_format
                                )  # Gemini expects audio/wav, audio/mp3, etc.
                                # Re-use the image pipeline: wrap the audio as a
                                # base64 data-URL so it becomes inline_data.
                                openai_image_str = (
                                    convert_generic_image_chunk_to_openai_image_obj(
                                        image_chunk=GenericImageParsingChunk(
                                            type="base64",
                                            media_type=audio_format_modified,
                                            data=audio_data,
                                        )
                                    )
                                )
                                _part = _process_gemini_media(
                                    image_url=openai_image_str,
                                    format=audio_format_modified,
                                    model=model,
                                )
                                _parts.append(_part)
                        elif element["type"] == "file":
                            file_element = cast(ChatCompletionFileObject, element)
                            file_id = file_element["file"].get("file_id")
                            format = file_element["file"].get("format")
                            file_data = file_element["file"].get("file_data")
                            detail = file_element["file"].get("detail")
                            video_metadata = file_element["file"].get("video_metadata")
                            passed_file = file_id or file_data
                            if passed_file is None:
                                raise Exception(
                                    "Unknown file type. Please pass in a file_id or file_data"
                                )
                            # Convert detail to media_resolution_enum
                            media_resolution_enum = (
                                _convert_detail_to_media_resolution_enum(detail)
                            )
                            try:
                                _part = _process_gemini_media(
                                    image_url=passed_file,
                                    format=format,
                                    model=model,
                                    media_resolution_enum=media_resolution_enum,
                                    video_metadata=video_metadata,
                                )
                                _parts.append(_part)
                            except Exception:
                                raise Exception(
                                    "Unable to determine mime type for file_id: {}, set this explicitly using message[{}].content[{}].file.format".format(
                                        file_id, msg_i, element_idx
                                    )
                                )
                    user_content.extend(_parts)
                elif _message_content is not None and isinstance(_message_content, str):
                    _part = PartType(text=_message_content)
                    user_content.append(_part)
                msg_i += 1
            if user_content:
                """
                check that user_content has 'text' parameter.
                - Known Vertex Error: Unable to submit request because it must have a text parameter.
                - Relevant Issue: https://github.com/BerriAI/litellm/issues/5515
                """
                has_text_in_content = _check_text_in_content(user_content)
                if has_text_in_content is False:
                    verbose_logger.warning(
                        "No text in user content. Adding a blank text to user content, to ensure Gemini doesn't fail the request. Relevant Issue - https://github.com/BerriAI/litellm/issues/5515"
                    )
                    user_content.append(
                        PartType(text=" ")
                    )  # add a blank text, to ensure Gemini doesn't fail the request.
                contents.append(ContentType(role="user", parts=user_content))
            assistant_content = []
            ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
            while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
                if isinstance(messages[msg_i], BaseModel):
                    msg_dict: Union[ChatCompletionAssistantMessage, dict] = messages[msg_i].model_dump()  # type: ignore
                else:
                    msg_dict = messages[msg_i]  # type: ignore
                assistant_msg = ChatCompletionAssistantMessage(**msg_dict)  # type: ignore
                _message_content = assistant_msg.get("content", None)
                reasoning_content = assistant_msg.get("reasoning_content", None)
                thinking_blocks = assistant_msg.get("thinking_blocks")
                # Re-emit reasoning content as a Gemini "thought" part.
                if reasoning_content is not None:
                    assistant_content.append(
                        PartType(thought=True, text=reasoning_content)
                    )
                # thinking_blocks carry Gemini thought signatures; re-encode each
                # block as a part (structured JSON if possible, else raw text).
                if thinking_blocks is not None:
                    for block in thinking_blocks:
                        if block["type"] == "thinking":
                            block_thinking_str = block.get("thinking")
                            block_signature = block.get("signature")
                            if (
                                block_thinking_str is not None
                                and block_signature is not None
                            ):
                                try:
                                    assistant_content.append(
                                        PartType(
                                            thoughtSignature=block_signature,
                                            **json.loads(block_thinking_str),
                                        )
                                    )
                                except Exception:
                                    assistant_content.append(
                                        PartType(
                                            thoughtSignature=block_signature,
                                            text=block_thinking_str,
                                        )
                                    )
                if _message_content is not None and isinstance(_message_content, list):
                    _parts = []
                    for element in _message_content:
                        if isinstance(element, dict):
                            if element["type"] == "text":
                                _part = PartType(text=element["text"])
                                _parts.append(_part)
                    assistant_content.extend(_parts)
                elif _message_content is not None and isinstance(_message_content, str):
                    assistant_text = _message_content
                    # Check if message has thought_signatures in provider_specific_fields
                    provider_specific_fields = assistant_msg.get(
                        "provider_specific_fields"
                    )
                    thought_signatures = None
                    if provider_specific_fields and isinstance(
                        provider_specific_fields, dict
                    ):
                        thought_signatures = provider_specific_fields.get(
                            "thought_signatures"
                        )
                    # If we have thought signatures, add them to the part
                    if (
                        thought_signatures
                        and isinstance(thought_signatures, list)
                        and len(thought_signatures) > 0
                    ):
                        # Use the first signature for the text part (Gemini expects one signature per part)
                        assistant_content.append(PartType(text=assistant_text, thoughtSignature=thought_signatures[0]))  # type: ignore
                    else:
                        assistant_content.append(PartType(text=assistant_text))  # type: ignore
                ## HANDLE ASSISTANT IMAGES FIELD
                # Process images field if present (for generated images from assistant)
                assistant_images = assistant_msg.get("images")
                if assistant_images is not None and isinstance(assistant_images, list):
                    for image_item in assistant_images:
                        if isinstance(image_item, dict):
                            image_url_obj = image_item.get("image_url")
                            if isinstance(image_url_obj, dict):
                                assistant_image_url = image_url_obj.get("url")
                                format = image_url_obj.get("format")
                                detail = image_url_obj.get("detail")
                                media_resolution_enum = (
                                    _convert_detail_to_media_resolution_enum(detail)
                                )
                                if assistant_image_url:
                                    _part = _process_gemini_media(
                                        image_url=assistant_image_url,
                                        format=format,
                                        media_resolution_enum=media_resolution_enum,
                                        model=model,
                                    )
                                    assistant_content.append(_part)
                ## HANDLE ASSISTANT FUNCTION CALL
                if (
                    assistant_msg.get("tool_calls", []) is not None
                    or assistant_msg.get("function_call") is not None
                ):  # support assistant tool invoke conversion
                    gemini_tool_call_parts = convert_to_gemini_tool_call_invoke(
                        assistant_msg, model=model
                    )
                    ## check if gemini_tool_call already exists in assistant_content
                    for gemini_tool_call_part in gemini_tool_call_parts:
                        if not check_if_part_exists_in_parts(
                            assistant_content,
                            gemini_tool_call_part,
                            excluded_keys=["thoughtSignature"],
                        ):
                            assistant_content.append(gemini_tool_call_part)
                    last_message_with_tool_calls = assistant_msg
                msg_i += 1
            if assistant_content:
                contents.append(ContentType(role="model", parts=assistant_content))
            ## APPEND TOOL CALL MESSAGES ##
            tool_call_message_roles = ["tool", "function"]
            if (
                msg_i < len(messages)
                and messages[msg_i]["role"] in tool_call_message_roles
            ):
                _part = convert_to_gemini_tool_call_result(
                    messages[msg_i], last_message_with_tool_calls  # type: ignore
                )
                msg_i += 1
                # Handle both single part and list of parts (for Computer Use with images)
                if isinstance(_part, list):
                    tool_call_responses.extend(_part)
                else:
                    tool_call_responses.append(_part)
            # Flush the buffered tool results once a non-tool message follows.
            if msg_i < len(messages) and (
                messages[msg_i]["role"] not in tool_call_message_roles
            ):
                if len(tool_call_responses) > 0:
                    contents.append(ContentType(role="user", parts=tool_call_responses))
                    tool_call_responses = []
            if msg_i == init_msg_i:  # prevent infinite loops
                raise Exception(
                    "Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format(
                        messages[msg_i]
                    )
                )
        # Trailing tool results with no following message still need flushing.
        if len(tool_call_responses) > 0:
            contents.append(ContentType(role="user", parts=tool_call_responses))
        if len(contents) == 0:
            verbose_logger.warning(
                """
                No contents in messages. Contents are required. See
                https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.publishers.models/generateContent#request-body.
                If the original request did not comply to OpenAI API requirements it should have failed by now,
                but LiteLLM does not check for missing messages.
                Setting an empty content to prevent an 400 error.
                Relevant Issue - https://github.com/BerriAI/litellm/issues/9733
                """
            )
            contents.append(ContentType(role="user", parts=[PartType(text=" ")]))
        return contents
    except Exception as e:
        raise e
# Keys that LiteLLM consumes internally and must never be forwarded to the
_LITELLM_INTERNAL_EXTRA_BODY_KEYS: frozenset = frozenset({"cache", "tags"})
def _pop_and_merge_extra_body(data: RequestBody, optional_params: dict) -> None:
"""Pop extra_body from optional_params and shallow-merge into data, deep-merging dict values."""
extra_body: Optional[dict] = optional_params.pop("extra_body", None)
if extra_body is not None:
data_dict: dict = data # type: ignore[assignment]
for k, v in extra_body.items():
if k in _LITELLM_INTERNAL_EXTRA_BODY_KEYS:
continue
if (
k in data_dict
and isinstance(data_dict[k], dict)
and isinstance(v, dict)
):
data_dict[k].update(v)
else:
data_dict[k] = v
def _transform_request_body(  # noqa: PLR0915
    messages: List[AllMessageValues],
    model: str,
    optional_params: dict,
    custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
    litellm_params: dict,
    cached_content: Optional[str],
) -> RequestBody:
    """
    Common transformation logic across sync + async Gemini /generateContent calls.

    Builds the Gemini ``RequestBody`` from OpenAI-format inputs: extracts system
    instructions, optionally injects a response-schema prompt, converts
    messages, and populates tools/toolConfig, safetySettings, generationConfig,
    labels, cachedContent, and any ``extra_body`` overrides.
    """
    # Separate system prompt from rest of message
    supports_system_message = get_supports_system_message(
        model=model, custom_llm_provider=custom_llm_provider
    )
    system_instructions, messages = _transform_system_message(
        supports_system_message=supports_system_message, messages=messages
    )
    # Checks for 'response_schema' support - if passed in
    if "response_schema" in optional_params:
        supports_response_schema = get_supports_response_schema(
            model=model, custom_llm_provider=custom_llm_provider
        )
        if supports_response_schema is False:
            # Model can't enforce a schema natively -> inject it as a user prompt.
            user_response_schema_message = response_schema_prompt(
                model=model, response_schema=optional_params.get("response_schema")  # type: ignore
            )
            messages.append({"role": "user", "content": user_response_schema_message})
            optional_params.pop("response_schema")
    # Check for any 'litellm_param_*' set during optional param mapping
    remove_keys = []
    for k, v in optional_params.items():
        if k.startswith("litellm_param_"):
            litellm_params.update({k: v})
            remove_keys.append(k)
    optional_params = {k: v for k, v in optional_params.items() if k not in remove_keys}
    try:
        if custom_llm_provider == "gemini":
            content = litellm.GoogleAIStudioGeminiConfig()._transform_messages(
                messages=messages, model=model
            )
        else:
            content = litellm.VertexGeminiConfig()._transform_messages(
                messages=messages, model=model
            )
        tools: Optional[Tools] = optional_params.pop("tools", None)
        tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None)
        safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop(
            "safety_settings", None
        )  # type: ignore
        # Drop output_config as it's not supported by Vertex AI
        optional_params.pop("output_config", None)
        config_fields = GenerationConfig.__annotations__.keys()
        # If the LiteLLM client sends Gemini-supported parameter "labels", add it
        # as "labels" field to the request sent to the Gemini backend.
        labels: Optional[dict[str, str]] = optional_params.pop("labels", None)
        # If the LiteLLM client sends OpenAI-supported parameter "metadata", add it
        # as "labels" field to the request sent to the Gemini backend.
        if labels is None and "metadata" in litellm_params:
            metadata = litellm_params["metadata"]
            if metadata is not None and "requester_metadata" in metadata:
                rm = metadata["requester_metadata"]
                # Gemini labels must be string-valued; non-strings are dropped.
                labels = {k: v for k, v in rm.items() if isinstance(v, str)}
        # Keep only params that map onto GenerationConfig fields (camel or snake).
        filtered_params = {
            k: v
            for k, v in optional_params.items()
            if _get_equivalent_key(k, set(config_fields))
        }
        generation_config: Optional[GenerationConfig] = GenerationConfig(
            **filtered_params
        )
        # For Gemini 2.x models, add media_resolution to generation_config (global)
        # Gemini 3+ supports per-part media_resolution, but 2.x only supports global
        # Gemini 1.x does not support mediaResolution at all
        if "gemini-2" in model:
            max_media_resolution = _extract_max_media_resolution_from_messages(messages)
            if max_media_resolution:
                media_resolution_value = _convert_detail_to_media_resolution_enum(
                    max_media_resolution
                )
                if media_resolution_value and generation_config is not None:
                    generation_config["mediaResolution"] = media_resolution_value[
                        "level"
                    ]
        data = RequestBody(contents=content)
        if system_instructions is not None:
            data["system_instruction"] = system_instructions
        if tools is not None:
            data["tools"] = tools
        if tool_choice is not None:
            data["toolConfig"] = tool_choice
        if safety_settings is not None:
            data["safetySettings"] = safety_settings
        if generation_config is not None and len(generation_config) > 0:
            data["generationConfig"] = generation_config
        if cached_content is not None:
            data["cachedContent"] = cached_content
        # Only add labels for Vertex AI endpoints (not Google GenAI/AI Studio) and only if non-empty
        if labels and custom_llm_provider != LlmProviders.GEMINI:
            data["labels"] = labels
        # User-supplied extra_body overrides/extends whatever was built above.
        _pop_and_merge_extra_body(data, optional_params)
    except Exception as e:
        raise e
    return data
def sync_transform_request_body(
    gemini_api_key: Optional[str],
    messages: List[AllMessageValues],
    api_base: Optional[str],
    model: str,
    client: Optional[HTTPHandler],
    timeout: Optional[Union[float, httpx.Timeout]],
    extra_headers: Optional[dict],
    optional_params: dict,
    logging_obj: LiteLLMLoggingObj,
    custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
    litellm_params: dict,
    vertex_project: Optional[str],
    vertex_location: Optional[str],
    vertex_auth_header: Optional[str],
) -> RequestBody:
    """
    Synchronous entry point: resolve context caching, then build the Gemini
    request body via :func:`_transform_request_body`.
    """
    from ..context_caching.vertex_ai_context_caching import ContextCachingEndpoints

    caching = ContextCachingEndpoints()
    messages, optional_params, cached_content = caching.check_and_create_cache(
        messages=messages,
        optional_params=optional_params,
        api_key=gemini_api_key or "dummy",
        api_base=api_base,
        model=model,
        client=client,
        timeout=timeout,
        extra_headers=extra_headers,
        cached_content=optional_params.pop("cached_content", None),
        logging_obj=logging_obj,
        custom_llm_provider=custom_llm_provider,
        vertex_project=vertex_project,
        vertex_location=vertex_location,
        vertex_auth_header=vertex_auth_header,
    )
    return _transform_request_body(
        messages=messages,
        model=model,
        custom_llm_provider=custom_llm_provider,
        litellm_params=litellm_params,
        cached_content=cached_content,
        optional_params=optional_params,
    )
async def async_transform_request_body(
    gemini_api_key: Optional[str],
    messages: List[AllMessageValues],
    api_base: Optional[str],
    model: str,
    client: Optional[AsyncHTTPHandler],
    timeout: Optional[Union[float, httpx.Timeout]],
    extra_headers: Optional[dict],
    optional_params: dict,
    logging_obj: litellm.litellm_core_utils.litellm_logging.Logging,  # type: ignore
    custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
    litellm_params: dict,
    vertex_project: Optional[str],
    vertex_location: Optional[str],
    vertex_auth_header: Optional[str],
) -> RequestBody:
    """
    Async entry point: resolve context caching, then build the Gemini request
    body via :func:`_transform_request_body`.
    """
    from ..context_caching.vertex_ai_context_caching import ContextCachingEndpoints

    caching = ContextCachingEndpoints()
    messages, optional_params, cached_content = (
        await caching.async_check_and_create_cache(
            messages=messages,
            optional_params=optional_params,
            api_key=gemini_api_key or "dummy",
            api_base=api_base,
            model=model,
            client=client,
            timeout=timeout,
            extra_headers=extra_headers,
            cached_content=optional_params.pop("cached_content", None),
            logging_obj=logging_obj,
            custom_llm_provider=custom_llm_provider,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_auth_header=vertex_auth_header,
        )
    )
    return _transform_request_body(
        messages=messages,
        model=model,
        custom_llm_provider=custom_llm_provider,
        litellm_params=litellm_params,
        cached_content=cached_content,
        optional_params=optional_params,
    )
def _default_user_message_when_system_message_passed() -> ChatCompletionUserMessage:
    """
    Build a minimal placeholder user message.

    Gemini fails requests whose contents are empty; when only system messages
    were supplied, this filler user turn keeps the request valid.
    """
    return ChatCompletionUserMessage(role="user", content=".")
def _transform_system_message(
    supports_system_message: bool, messages: List[AllMessageValues]
) -> Tuple[Optional[SystemInstructions], List[AllMessageValues]]:
    """
    Split OpenAI "system" messages out of *messages* and convert them to
    Gemini `SystemInstructions`.

    Args:
        supports_system_message: when not True, messages are returned untouched.
        messages: OpenAI-format message list; mutated in place (system
            entries are removed).

    Returns:
        - system instructions in Gemini format, or None when there are none
        - remaining messages in OpenAI format (transformed separately)
    """
    if supports_system_message is not True:
        return None, messages

    part_blocks: List[PartType] = []
    removal_indices: List[int] = []
    for idx, message in enumerate(messages):
        if message["role"] != "system":
            continue
        block: Optional[PartType] = None
        raw_content = message["content"]
        if isinstance(raw_content, str):
            block = PartType(text=raw_content)
        elif isinstance(raw_content, list):
            # Concatenate all text fragments of a multi-part system message.
            joined = "".join((fragment.get("text") or "") for fragment in raw_content)
            block = PartType(text=joined)
        if block is not None:
            part_blocks.append(block)
            removal_indices.append(idx)

    # Remove system entries back-to-front so earlier indices stay valid.
    for idx in reversed(removal_indices):
        messages.pop(idx)

    if part_blocks:
        # Gemini fails on an empty contents list; inject a blank user turn.
        # Relevant Issue - https://github.com/BerriAI/litellm/issues/13769
        if len(messages) == 0:
            messages.append(_default_user_message_when_system_message_passed())
        return SystemInstructions(parts=part_blocks), messages
    return None, messages

View File

@@ -0,0 +1,347 @@
"""
Google AI Studio /batchEmbedContents Embeddings Endpoint
"""
import json
from typing import Any, Dict, Literal, Optional, Union
import httpx
import litellm
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
get_async_httpx_client,
)
from litellm.types.llms.openai import EmbeddingInput
from litellm.types.llms.vertex_ai import (
VertexAIBatchEmbeddingsRequestBody,
VertexAIBatchEmbeddingsResponseObject,
)
from litellm.types.utils import EmbeddingResponse
from ..gemini.vertex_and_google_ai_studio_gemini import VertexLLM
from .batch_embed_content_transformation import (
_is_file_reference,
_is_multimodal_input,
process_embed_content_response,
process_response,
transform_openai_input_gemini_content,
transform_openai_input_gemini_embed_content,
)
class GoogleBatchEmbeddings(VertexLLM):
    """
    Embeddings handler for Google AI Studio / Vertex AI.

    Routes plain-text input to the /batchEmbedContents endpoint and
    multimodal input (data URIs, GCS URLs, Gemini file references) — or any
    Vertex AI call — to the embedContent endpoint.
    """

    def _resolve_file_references(
        self,
        input: EmbeddingInput,
        api_key: str,
        sync_handler: HTTPHandler,
    ) -> Dict[str, Dict[str, str]]:
        """
        Resolve Gemini file references (files/...) to get mime_type and uri.
        Args:
            input: EmbeddingInput that may contain file references
            api_key: Gemini API key
            sync_handler: HTTP client
        Returns:
            Dict mapping file name to {mime_type, uri}
        """
        input_list = [input] if isinstance(input, str) else input
        resolved_files: Dict[str, Dict[str, str]] = {}
        for element in input_list:
            if isinstance(element, str) and _is_file_reference(element):
                # Fetch file metadata from the Gemini Files API.
                url = f"https://generativelanguage.googleapis.com/v1beta/{element}"
                headers = {"x-goog-api-key": api_key}
                response = sync_handler.get(url=url, headers=headers)
                if response.status_code != 200:
                    raise Exception(
                        f"Error fetching file {element}: {response.status_code} {response.text}"
                    )
                file_data = response.json()
                resolved_files[element] = {
                    "mime_type": file_data.get("mimeType", ""),
                    "uri": file_data.get("uri", element),
                }
        return resolved_files

    async def _async_resolve_file_references(
        self,
        input: EmbeddingInput,
        api_key: str,
        async_handler: AsyncHTTPHandler,
    ) -> Dict[str, Dict[str, str]]:
        """
        Async version of _resolve_file_references.
        Args:
            input: EmbeddingInput that may contain file references
            api_key: Gemini API key
            async_handler: Async HTTP client
        Returns:
            Dict mapping file name to {mime_type, uri}
        """
        input_list = [input] if isinstance(input, str) else input
        resolved_files: Dict[str, Dict[str, str]] = {}
        for element in input_list:
            if isinstance(element, str) and _is_file_reference(element):
                url = f"https://generativelanguage.googleapis.com/v1beta/{element}"
                headers = {"x-goog-api-key": api_key}
                response = await async_handler.get(url=url, headers=headers)
                if response.status_code != 200:
                    raise Exception(
                        f"Error fetching file {element}: {response.status_code} {response.text}"
                    )
                file_data = response.json()
                resolved_files[element] = {
                    "mime_type": file_data.get("mimeType", ""),
                    "uri": file_data.get("uri", element),
                }
        return resolved_files

    def batch_embeddings(
        self,
        model: str,
        input: EmbeddingInput,
        print_verbose,
        model_response: EmbeddingResponse,
        custom_llm_provider: Literal["gemini", "vertex_ai"],
        optional_params: dict,
        logging_obj: Any,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        encoding=None,
        vertex_project=None,
        vertex_location=None,
        vertex_credentials=None,
        aembedding: Optional[bool] = False,
        timeout=300,
        client=None,
        extra_headers: Optional[dict] = None,
    ) -> EmbeddingResponse:
        """
        Sync entry point for Gemini/Vertex embeddings.

        When aembedding=True this returns the coroutine from
        async_batch_embeddings (awaited by the caller); otherwise it
        transforms the input, POSTs to the chosen endpoint, and maps the
        response into *model_response*.
        """
        _auth_header, vertex_project = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider=custom_llm_provider,
        )
        # Build a sync HTTP client unless one was injected by the caller.
        if client is None:
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
                    _httpx_timeout = httpx.Timeout(timeout)
                    _params["timeout"] = _httpx_timeout
            else:
                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
            sync_handler: HTTPHandler = HTTPHandler(**_params)  # type: ignore
        else:
            sync_handler = client  # type: ignore
        optional_params = optional_params or {}
        # Multimodal input — or any Vertex AI call — must use embedContent;
        # plain text on Google AI Studio uses batchEmbedContents.
        is_multimodal = _is_multimodal_input(input)
        use_embed_content = is_multimodal or (custom_llm_provider == "vertex_ai")
        mode: Literal["embedding", "batch_embedding"]
        if use_embed_content:
            mode = "embedding"
        else:
            mode = "batch_embedding"
        auth_header, url = self._get_token_and_url(
            model=model,
            auth_header=_auth_header,
            gemini_api_key=api_key,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_credentials=vertex_credentials,
            stream=None,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            should_use_v1beta1_features=False,
            mode=mode,
        )
        headers = {
            "Content-Type": "application/json; charset=utf-8",
        }
        # auth_header may be a ready-made header dict (Vertex) or a bearer
        # token string (Google AI Studio).
        if auth_header is not None:
            if isinstance(auth_header, dict):
                headers.update(auth_header)
            else:
                headers["Authorization"] = f"Bearer {auth_header}"
        if extra_headers is not None:
            headers.update(extra_headers)
        if aembedding is True:
            # Returns a coroutine; request transformation happens inside the
            # async path, so data=None here is just a placeholder.
            return self.async_batch_embeddings(  # type: ignore
                model=model,
                api_base=api_base,
                url=url,
                data=None,
                model_response=model_response,
                timeout=timeout,
                headers=headers,
                input=input,
                use_embed_content=use_embed_content,
                api_key=api_key,
                optional_params=optional_params,
                logging_obj=logging_obj,
            )
        ### TRANSFORMATION (sync path) ###
        request_data: Any
        if use_embed_content:
            # File references need their mime_type/uri resolved first; this
            # requires an API key (Gemini Files API).
            resolved_files = {}
            if api_key:
                resolved_files = self._resolve_file_references(
                    input=input, api_key=api_key, sync_handler=sync_handler
                )
            request_data = transform_openai_input_gemini_embed_content(
                input=input,
                model=model,
                optional_params=optional_params,
                resolved_files=resolved_files,
            )
        else:
            request_data = transform_openai_input_gemini_content(
                input=input, model=model, optional_params=optional_params
            )
        ## LOGGING
        logging_obj.pre_call(
            input=input,
            api_key="",
            additional_args={
                "complete_input_dict": request_data,
                "api_base": url,
                "headers": headers,
            },
        )
        response = sync_handler.post(
            url=url,
            headers=headers,
            data=json.dumps(request_data),
        )
        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} {response.text}")
        _json_response = response.json()
        # embedContent returns a single embedding; batchEmbedContents returns
        # a list — each has its own response mapper.
        if use_embed_content:
            return process_embed_content_response(
                input=input,
                model_response=model_response,
                model=model,
                response_json=_json_response,
            )
        else:
            _predictions = VertexAIBatchEmbeddingsResponseObject(**_json_response)  # type: ignore
            return process_response(
                model=model,
                model_response=model_response,
                _predictions=_predictions,
                input=input,
            )

    async def async_batch_embeddings(
        self,
        model: str,
        api_base: Optional[str],
        url: str,
        data: Optional[Union[VertexAIBatchEmbeddingsRequestBody, dict]],
        model_response: EmbeddingResponse,
        input: EmbeddingInput,
        timeout: Optional[Union[float, httpx.Timeout]],
        headers={},  # NOTE(review): mutable default arg; never mutated here, but prefer None
        client: Optional[AsyncHTTPHandler] = None,
        use_embed_content: bool = False,
        api_key: Optional[str] = None,
        optional_params: Optional[dict] = None,
        logging_obj: Optional[Any] = None,
    ) -> EmbeddingResponse:
        """
        Async embeddings call: transforms *input*, POSTs to *url*, and maps
        the JSON response into *model_response*.
        """
        if client is None:
            # NOTE(review): _params is built but never used — the async
            # client below only receives {"timeout": timeout}; dead code?
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
                    _httpx_timeout = httpx.Timeout(timeout)
                    _params["timeout"] = _httpx_timeout
            else:
                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
            async_handler: AsyncHTTPHandler = get_async_httpx_client(
                llm_provider=litellm.LlmProviders.VERTEX_AI,
                params={"timeout": timeout},
            )
        else:
            async_handler = client  # type: ignore
        ### TRANSFORMATION (async path) ###
        if use_embed_content:
            resolved_files = {}
            if api_key:
                resolved_files = await self._async_resolve_file_references(
                    input=input, api_key=api_key, async_handler=async_handler
                )
            data = transform_openai_input_gemini_embed_content(
                input=input,
                model=model,
                optional_params=optional_params or {},
                resolved_files=resolved_files,
            )
        else:
            data = transform_openai_input_gemini_content(
                input=input, model=model, optional_params=optional_params or {}
            )
        ## LOGGING
        if logging_obj is not None:
            logging_obj.pre_call(
                input=input,
                api_key="",
                additional_args={
                    "complete_input_dict": data,
                    "api_base": url,
                    "headers": headers,
                },
            )
        response = await async_handler.post(
            url=url,
            headers=headers,
            data=json.dumps(data),
        )
        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} {response.text}")
        _json_response = response.json()
        if use_embed_content:
            return process_embed_content_response(
                input=input,
                model_response=model_response,
                model=model,
                response_json=_json_response,
            )
        else:
            _predictions = VertexAIBatchEmbeddingsResponseObject(**_json_response)  # type: ignore
            return process_response(
                model=model,
                model_response=model_response,
                _predictions=_predictions,
                input=input,
            )

View File

@@ -0,0 +1,308 @@
"""
Transformation logic from OpenAI /v1/embeddings format to Google AI Studio /batchEmbedContents format.
Why separate file? Make it easy to see how transformation works
"""
from typing import Dict, List, Optional, Tuple
from litellm.types.llms.openai import EmbeddingInput
from litellm.types.llms.vertex_ai import (
BlobType,
ContentType,
EmbedContentRequest,
FileDataType,
PartType,
VertexAIBatchEmbeddingsRequestBody,
VertexAIBatchEmbeddingsResponseObject,
)
from litellm.types.utils import Embedding, EmbeddingResponse, Usage
from litellm.utils import get_formatted_prompt, token_counter
# MIME types accepted for multimodal embedding input; data URLs whose media
# type is not in this set are rejected (see _parse_data_url).
SUPPORTED_EMBEDDING_MIME_TYPES = {
    "image/png",
    "image/jpeg",
    "audio/mpeg",
    "audio/wav",
    "video/mp4",
    "video/quicktime",
    "application/pdf",
}
def _is_file_reference(s: str) -> bool:
"""Check if string is a Gemini file reference (files/...)."""
return isinstance(s, str) and s.startswith("files/")
def _is_gcs_url(s: str) -> bool:
"""Check if string is a GCS URL (gs://...)."""
return isinstance(s, str) and s.startswith("gs://")
def _infer_mime_type_from_gcs_url(gcs_url: str) -> str:
"""
Infer MIME type from GCS URL file extension.
Args:
gcs_url: GCS URL like gs://bucket/path/to/file.png
Returns:
str: Inferred MIME type
Raises:
ValueError: If file extension is not supported
"""
extension_to_mime = {
".png": "image/png",
".jpg": "image/jpeg",
".jpeg": "image/jpeg",
".mp3": "audio/mpeg",
".wav": "audio/wav",
".mp4": "video/mp4",
".mov": "video/quicktime",
".pdf": "application/pdf",
}
gcs_url_lower = gcs_url.lower()
for ext, mime_type in extension_to_mime.items():
if gcs_url_lower.endswith(ext):
return mime_type
raise ValueError(
f"Unable to infer MIME type from GCS URL: {gcs_url}. "
f"Supported extensions: {', '.join(extension_to_mime.keys())}"
)
def _parse_data_url(data_url: str) -> Tuple[str, str]:
    """
    Split a ``data:`` URL into its media type and base64 payload.

    Args:
        data_url: URL of the form ``data:image/jpeg;base64,/9j/4AAQ...``.

    Returns:
        tuple: (media_type, base64_data)
            media_type: e.g. "image/jpeg", "video/mp4", "audio/mpeg"
            base64_data: the base64-encoded data without the prefix

    Raises:
        ValueError: on malformed data URLs or unsupported MIME types.
    """
    if not data_url.startswith("data:"):
        raise ValueError(f"Invalid data URL format: {data_url[:50]}...")
    if "," not in data_url:
        raise ValueError(f"Invalid data URL format (missing comma): {data_url[:50]}...")
    metadata, base64_data = data_url.split(",", 1)
    # Drop the leading "data:" scheme; what remains is "<mime>[;base64...]".
    metadata = metadata[5:]
    media_type = metadata.split(";", 1)[0] if ";" in metadata else metadata
    if media_type not in SUPPORTED_EMBEDDING_MIME_TYPES:
        raise ValueError(
            f"Unsupported MIME type for embedding: {media_type}. "
            f"Supported types: {', '.join(sorted(SUPPORTED_EMBEDDING_MIME_TYPES))}"
        )
    return media_type, base64_data
def _is_multimodal_input(input: EmbeddingInput) -> bool:
    """
    Return True when *input* contains non-text embedding content.

    Non-text content is any string element that is a base64 data URI, a
    Gemini Files API reference ("files/..."), or a GCS URL ("gs://...").
    """
    elements = [input] if isinstance(input, str) else input
    for element in elements:
        if not isinstance(element, str):
            continue
        is_data_uri = element.startswith("data:") and ";base64," in element
        if is_data_uri or _is_file_reference(element) or _is_gcs_url(element):
            return True
    return False
def transform_openai_input_gemini_content(
    input: EmbeddingInput, model: str, optional_params: dict
) -> VertexAIBatchEmbeddingsRequestBody:
    """
    Build a Gemini /batchEmbedContents request body from OpenAI-style input.

    Each input string becomes one EmbedContentRequest; only the parts.text
    fields are counted. OpenAI's "dimensions" param is renamed to Gemini's
    "outputDimensionality".
    """
    gemini_model_name = "models/{}".format(model)
    gemini_params = optional_params.copy()
    if "dimensions" in gemini_params:
        gemini_params["outputDimensionality"] = gemini_params.pop("dimensions")
    texts = [input] if isinstance(input, str) else input
    requests: List[EmbedContentRequest] = [
        EmbedContentRequest(
            model=gemini_model_name,
            content=ContentType(parts=[PartType(text=text)]),
            **gemini_params,
        )
        for text in texts
    ]
    return VertexAIBatchEmbeddingsRequestBody(requests=requests)
def transform_openai_input_gemini_embed_content(
    input: EmbeddingInput,
    model: str,
    optional_params: dict,
    resolved_files: Optional[Dict[str, Dict[str, str]]] = None,
) -> dict:
    """
    Build a Gemini embedContent (multimodal) request body from OpenAI input.

    Args:
        input: str or list of str; elements may be plain text, base64 data
            URIs, GCS URLs (gs://...), or Gemini file references (files/...).
        model: Model name.
        optional_params: Extra params (taskType, outputDimensionality, etc.);
            OpenAI "dimensions" is renamed to "outputDimensionality".
        resolved_files: Mapping of file reference (files/abc) to
            {mime_type, uri}, produced by resolving references beforehand.

    Returns:
        dict: request body with a single "content" entry plus passthrough params.

    Raises:
        ValueError: for non-string elements or unresolved file references.
    """
    file_lookup = resolved_files or {}
    gemini_params = optional_params.copy()
    if "dimensions" in gemini_params:
        gemini_params["outputDimensionality"] = gemini_params.pop("dimensions")

    def _to_part(element) -> PartType:
        # Convert one input element into the matching Gemini part kind.
        if not isinstance(element, str):
            raise ValueError(f"Unsupported input type: {type(element)}")
        if element.startswith("data:") and ";base64," in element:
            mime_type, base64_data = _parse_data_url(element)
            inline: BlobType = {"mime_type": mime_type, "data": base64_data}
            return PartType(inline_data=inline)
        if _is_gcs_url(element):
            gcs_file: FileDataType = {
                "mime_type": _infer_mime_type_from_gcs_url(element),
                "file_uri": element,
            }
            return PartType(file_data=gcs_file)
        if _is_file_reference(element):
            if element not in file_lookup:
                raise ValueError(f"File reference {element} not resolved")
            resolved = file_lookup[element]
            ref_file: FileDataType = {
                "mime_type": resolved["mime_type"],
                "file_uri": resolved["uri"],
            }
            return PartType(file_data=ref_file)
        return PartType(text=element)

    elements = [input] if isinstance(input, str) else input
    parts: List[PartType] = [_to_part(element) for element in elements]
    return {
        "content": ContentType(parts=parts),
        **gemini_params,
    }
def process_embed_content_response(
    input: EmbeddingInput,
    model_response: EmbeddingResponse,
    model: str,
    response_json: dict,
) -> EmbeddingResponse:
    """
    Populate *model_response* from a Gemini embedContent response
    (single embedding, used for multimodal input).

    Raises:
        ValueError: if the response has no "embedding" field.
    """
    if "embedding" not in response_json:
        raise ValueError(
            f"embedContent response missing 'embedding' field: {response_json}"
        )
    model_response.data = [
        Embedding(
            embedding=response_json["embedding"]["values"],
            index=0,
            object="embedding",
        )
    ]
    model_response.model = model
    if _is_multimodal_input(input):
        # Text token counting is not meaningful for image/audio/video input.
        prompt_tokens = 0
    else:
        input_text = get_formatted_prompt(data={"input": input}, call_type="embedding")
        prompt_tokens = token_counter(model=model, text=input_text)
    model_response.usage = Usage(
        prompt_tokens=prompt_tokens, total_tokens=prompt_tokens
    )
    return model_response
def process_response(
    input: EmbeddingInput,
    model_response: EmbeddingResponse,
    model: str,
    _predictions: VertexAIBatchEmbeddingsResponseObject,
) -> EmbeddingResponse:
    """
    Populate *model_response* from a /batchEmbedContents response.

    Maps each returned embedding to an OpenAI-format Embedding object and
    estimates prompt tokens from the original text input.
    """
    openai_embeddings: List[Embedding] = [
        # Fix: `index` must reflect each embedding's position in the batch
        # (OpenAI embeddings contract); it was previously hard-coded to 0
        # for every element.
        Embedding(
            embedding=embedding["values"],
            index=idx,
            object="embedding",
        )
        for idx, embedding in enumerate(_predictions["embeddings"])
    ]
    model_response.data = openai_embeddings
    model_response.model = model
    input_text = get_formatted_prompt(data={"input": input}, call_type="embedding")
    prompt_tokens = token_counter(model=model, text=input_text)
    model_response.usage = Usage(
        prompt_tokens=prompt_tokens, total_tokens=prompt_tokens
    )
    return model_response

View File

@@ -0,0 +1,100 @@
"""
Transformation for Calling Google models in their native format.
"""
from typing import Any, Dict, Literal, Optional, Union
from litellm.llms.gemini.google_genai.transformation import GoogleGenAIConfig
from litellm.types.router import GenericLiteLLMParams
class VertexAIGoogleGenAIConfig(GoogleGenAIConfig):
    """
    Configuration for calling Google models on Vertex AI in their native
    (Google GenAI / generateContent) format.
    """

    # Vertex AI authenticates with an OAuth bearer token rather than the
    # x-goog-api-key header used by Google AI Studio.
    HEADER_NAME = "Authorization"
    BEARER_PREFIX = "Bearer"

    @property
    def custom_llm_provider(self) -> Literal["gemini", "vertex_ai"]:
        return "vertex_ai"

    def validate_environment(
        self,
        api_key: Optional[str],
        headers: Optional[dict],
        model: str,
        litellm_params: Optional[Union[GenericLiteLLMParams, dict]],
    ) -> dict:
        """Build request headers; caller-supplied headers take precedence."""
        default_headers = {
            "Content-Type": "application/json",
        }
        if api_key is not None:
            default_headers[self.HEADER_NAME] = f"{self.BEARER_PREFIX} {api_key}"
        if headers is not None:
            default_headers.update(headers)
        return default_headers

    def _camel_to_snake(self, camel_str: str) -> str:
        """Convert a camelCase string to snake_case (e.g. maxTokens -> max_tokens)."""
        import re

        return re.sub(r"(?<!^)(?=[A-Z])", "_", camel_str).lower()

    def map_generate_content_optional_params(
        self,
        generate_content_config_dict,
        model: str,
    ):
        """
        Map Google GenAI config parameters to the Vertex AI request format.

        Every camelCase key in *generate_content_config_dict* is converted to
        snake_case; values are passed through unchanged.

        Args:
            generate_content_config_dict: Optional generateContent config params
            model: The model name

        Returns:
            Dict with snake_case keys for Vertex AI
        """
        mapped_config: Dict = {}
        for param, value in generate_content_config_dict.items():
            # Fix (naming): the converted key is snake_case — the previous
            # local was misleadingly named `camel_case_key`.
            snake_case_key = self._camel_to_snake(param)
            mapped_config[snake_case_key] = value
        return mapped_config

    def transform_generate_content_request(
        self,
        model: str,
        contents: Any,
        tools: Optional[Any],
        generate_content_config_dict: Dict,
        system_instruction: Optional[Any] = None,
    ) -> dict:
        """
        Transform the generate content request for Vertex AI.
        Since Vertex AI natively supports Google GenAI format, we can pass most fields directly.
        """
        # Build the request in Google GenAI format that Vertex AI expects
        result = {
            "model": model,
            "contents": contents,
        }
        # Add tools if provided
        if tools:
            result["tools"] = tools
        # Add systemInstruction if provided
        if system_instruction:
            result["systemInstruction"] = system_instruction
        # Handle generationConfig - Vertex AI expects it in the same format
        if generate_content_config_dict:
            result["generationConfig"] = generate_content_config_dict
        return result

View File

@@ -0,0 +1,42 @@
from litellm.llms.base_llm.image_edit.transformation import BaseImageEditConfig
from litellm.llms.vertex_ai.common_utils import (
VertexAIModelRoute,
get_vertex_ai_model_route,
)
from .cost_calculator import cost_calculator
from .vertex_gemini_transformation import VertexAIGeminiImageEditConfig
from .vertex_imagen_transformation import VertexAIImagenImageEditConfig
# Public API of the Vertex AI image-edit package.
__all__ = [
    "VertexAIGeminiImageEditConfig",
    "VertexAIImagenImageEditConfig",
    "get_vertex_ai_image_edit_config",
    "cost_calculator",
]
def get_vertex_ai_image_edit_config(model: str) -> BaseImageEditConfig:
    """
    Pick the image-edit config class for a Vertex AI model.

    Gemini models are edited through the generateContent API
    (VertexAIGeminiImageEditConfig); everything else — including NON_GEMINI
    models like imagegeneration@006 — goes through the predict API
    (VertexAIImagenImageEditConfig).

    Args:
        model: The model name (e.g. "gemini-2.5-flash", "imagegeneration@006").

    Returns:
        BaseImageEditConfig: The appropriate configuration instance.
    """
    if get_vertex_ai_model_route(model) == VertexAIModelRoute.GEMINI:
        return VertexAIGeminiImageEditConfig()
    # Default: Imagen-style models use the predict endpoint.
    return VertexAIImagenImageEditConfig()

View File

@@ -0,0 +1,34 @@
"""
Vertex AI Image Edit Cost Calculator
"""
from typing import Any
import litellm
from litellm.types.utils import ImageResponse
def cost_calculator(
    model: str,
    image_response: Any,
) -> float:
    """
    Vertex AI image edit cost calculator.

    Mirrors image generation pricing: charge per returned image based on
    model metadata (`output_cost_per_image`).

    Args:
        model: Vertex AI model name.
        image_response: Must be a litellm ImageResponse.

    Returns:
        Total cost = per-image price * number of returned images.

    Raises:
        ValueError: if image_response is not an ImageResponse.
    """
    # Robustness: validate the response type before the model-info lookup so
    # callers get the clearer type error first (fail fast on bad input).
    if not isinstance(image_response, ImageResponse):
        raise ValueError(
            f"image_response must be of type ImageResponse got type={type(image_response)}"
        )
    model_info = litellm.get_model_info(
        model=model,
        custom_llm_provider="vertex_ai",
    )
    # Missing/None pricing metadata means the edit is billed as free.
    output_cost_per_image: float = model_info.get("output_cost_per_image") or 0.0
    num_images = len(image_response.data or [])
    return output_cost_per_image * num_images

View File

@@ -0,0 +1,298 @@
import base64
import json
import os
from io import BufferedReader, BytesIO
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
import httpx
from httpx._types import RequestFiles
import litellm
from litellm.images.utils import ImageEditRequestUtils
from litellm.llms.base_llm.image_edit.transformation import BaseImageEditConfig
from litellm.llms.vertex_ai.common_utils import get_vertex_base_url
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
from litellm.secret_managers.main import get_secret_str
from litellm.types.images.main import ImageEditOptionalRequestParams
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import FileTypes, ImageObject, ImageResponse, OpenAIImage
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class VertexAIGeminiImageEditConfig(BaseImageEditConfig, VertexLLM):
    """
    Vertex AI Gemini Image Edit Configuration
    Uses generateContent API for Gemini models on Vertex AI
    """

    # Only OpenAI image-edit params this backend can express; all others
    # are silently dropped in map_openai_params.
    SUPPORTED_PARAMS: List[str] = ["size"]

    def __init__(self) -> None:
        # Initialize both bases explicitly (multiple inheritance).
        BaseImageEditConfig.__init__(self)
        VertexLLM.__init__(self)

    def get_supported_openai_params(self, model: str) -> List[str]:
        """Return the OpenAI image-edit params supported for *model*."""
        return list(self.SUPPORTED_PARAMS)

    def map_openai_params(
        self,
        image_edit_optional_params: ImageEditOptionalRequestParams,
        model: str,
        drop_params: bool,
    ) -> Dict[str, Any]:
        """Filter to supported params and map OpenAI "size" to a Gemini aspect ratio."""
        supported_params = self.get_supported_openai_params(model)
        filtered_params = {
            key: value
            for key, value in image_edit_optional_params.items()
            if key in supported_params
        }
        mapped_params: Dict[str, Any] = {}
        if "size" in filtered_params:
            mapped_params["aspectRatio"] = self._map_size_to_aspect_ratio(
                filtered_params["size"]  # type: ignore[arg-type]
            )
        return mapped_params

    def _resolve_vertex_project(self) -> Optional[str]:
        """Resolve the GCP project from instance state, env vars, or litellm globals."""
        return (
            getattr(self, "_vertex_project", None)
            or os.environ.get("VERTEXAI_PROJECT")
            or getattr(litellm, "vertex_project", None)
            or get_secret_str("VERTEXAI_PROJECT")
        )

    def _resolve_vertex_location(self) -> Optional[str]:
        """Resolve the Vertex region from instance state, env vars, or litellm globals."""
        return (
            getattr(self, "_vertex_location", None)
            or os.environ.get("VERTEXAI_LOCATION")
            or os.environ.get("VERTEX_LOCATION")
            or getattr(litellm, "vertex_location", None)
            or get_secret_str("VERTEXAI_LOCATION")
            or get_secret_str("VERTEX_LOCATION")
        )

    def _resolve_vertex_credentials(self) -> Optional[str]:
        """Resolve service-account credentials from instance state, env vars, or litellm globals."""
        return (
            getattr(self, "_vertex_credentials", None)
            or os.environ.get("VERTEXAI_CREDENTIALS")
            or getattr(litellm, "vertex_credentials", None)
            or os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
            or get_secret_str("VERTEXAI_CREDENTIALS")
        )

    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
        litellm_params: Optional[dict] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """Attach Vertex AI auth headers (skipped when a custom api_base is set)."""
        headers = headers or {}
        litellm_params = litellm_params or {}
        # If a custom api_base is provided, skip credential validation
        # This allows users to use proxies or mock endpoints without needing Vertex AI credentials
        _api_base = litellm_params.get("api_base") or api_base
        if _api_base is not None:
            return headers
        # First check litellm_params (where vertex_ai_project/vertex_ai_credentials are passed)
        # then fall back to environment variables and other sources
        vertex_project = (
            self.safe_get_vertex_ai_project(litellm_params)
            or self._resolve_vertex_project()
        )
        vertex_credentials = (
            self.safe_get_vertex_ai_credentials(litellm_params)
            or self._resolve_vertex_credentials()
        )
        access_token, _ = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai",
        )
        return self.set_headers(access_token, headers)

    def get_complete_url(
        self,
        model: str,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        Get the complete URL for Vertex AI Gemini generateContent API
        """
        # Use the model name as provided, handling vertex_ai prefix
        model_name = model
        if model.startswith("vertex_ai/"):
            model_name = model.replace("vertex_ai/", "")
        # If a custom api_base is provided, use it directly
        # This allows users to use proxies or mock endpoints
        if api_base:
            return api_base.rstrip("/")
        # First check litellm_params (where vertex_ai_project/vertex_ai_location are passed)
        # then fall back to environment variables and other sources
        vertex_project = (
            self.safe_get_vertex_ai_project(litellm_params)
            or self._resolve_vertex_project()
        )
        vertex_location = (
            self.safe_get_vertex_ai_location(litellm_params)
            or self._resolve_vertex_location()
        )
        if not vertex_project or not vertex_location:
            raise ValueError(
                "vertex_project and vertex_location are required for Vertex AI"
            )
        base_url = get_vertex_base_url(vertex_location)
        return f"{base_url}/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model_name}:generateContent"

    def transform_image_edit_request(  # type: ignore[override]
        self,
        model: str,
        prompt: Optional[str],
        image: Optional[FileTypes],
        image_edit_optional_request_params: Dict[str, Any],
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[Dict[str, Any], Optional[RequestFiles]]:
        """
        Build the generateContent request body for an image edit.

        Returns (payload, files). NOTE: payload is the JSON-encoded string
        (the cast below papers over the Dict annotation) and files is
        always an empty list.

        Raises:
            ValueError: when no image could be converted to an inline part.
        """
        inline_parts = self._prepare_inline_image_parts(image) if image else []
        if not inline_parts:
            raise ValueError("Vertex AI Gemini image edit requires at least one image.")
        # Build parts list with image and prompt (if provided)
        parts = inline_parts.copy()
        if prompt is not None and prompt != "":
            parts.append({"text": prompt})
        # Correct format for Vertex AI Gemini image editing
        contents = {"role": "USER", "parts": parts}
        request_body: Dict[str, Any] = {"contents": contents}
        # Generation config with proper structure for image editing
        generation_config: Dict[str, Any] = {"response_modalities": ["IMAGE"]}
        # Add image-specific configuration
        image_config: Dict[str, Any] = {}
        if "aspectRatio" in image_edit_optional_request_params:
            image_config["aspect_ratio"] = image_edit_optional_request_params[
                "aspectRatio"
            ]
        if image_config:
            generation_config["image_config"] = image_config
        request_body["generationConfig"] = generation_config
        payload: Any = json.dumps(request_body)
        empty_files = cast(RequestFiles, [])
        return cast(
            Tuple[Dict[str, Any], Optional[RequestFiles]], (payload, empty_files)
        )

    def transform_image_edit_response(
        self,
        model: str,
        raw_response: httpx.Response,
        logging_obj: Any,
    ) -> ImageResponse:
        """Collect base64 images from every candidate's inlineData parts."""
        model_response = ImageResponse()
        try:
            response_json = raw_response.json()
        except Exception as exc:
            raise self.get_error_class(
                error_message=f"Error transforming image edit response: {exc}",
                status_code=raw_response.status_code,
                headers=raw_response.headers,
            )
        candidates = response_json.get("candidates", [])
        data_list: List[ImageObject] = []
        for candidate in candidates:
            content = candidate.get("content", {})
            parts = content.get("parts", [])
            for part in parts:
                inline_data = part.get("inlineData")
                if inline_data and inline_data.get("data"):
                    data_list.append(
                        ImageObject(
                            b64_json=inline_data["data"],
                            url=None,
                        )
                    )
        model_response.data = cast(List[OpenAIImage], data_list)
        return model_response

    def _map_size_to_aspect_ratio(self, size: str) -> str:
        """Map OpenAI size format to Gemini aspect ratio format"""
        aspect_ratio_map = {
            "1024x1024": "1:1",
            "1792x1024": "16:9",
            "1024x1792": "9:16",
            "1280x896": "4:3",
            "896x1280": "3:4",
        }
        # Unrecognized sizes fall back to square output.
        return aspect_ratio_map.get(size, "1:1")

    def _prepare_inline_image_parts(
        self, image: Union[FileTypes, List[FileTypes]]
    ) -> List[Dict[str, Any]]:
        """Convert the input image(s) to Gemini inlineData parts (base64-encoded)."""
        images: List[FileTypes]
        if isinstance(image, list):
            images = image
        else:
            images = [image]
        inline_parts: List[Dict[str, Any]] = []
        for img in images:
            if img is None:
                continue
            mime_type = ImageEditRequestUtils.get_image_content_type(img)
            image_bytes = self._read_all_bytes(img)
            inline_parts.append(
                {
                    "inlineData": {
                        "mimeType": mime_type,
                        "data": base64.b64encode(image_bytes).decode("utf-8"),
                    }
                }
            )
        return inline_parts

    def _read_all_bytes(self, image: FileTypes) -> bytes:
        """
        Read the full contents of *image*, restoring the stream position.

        NOTE(review): the BytesIO and BufferedReader branches are identical
        and could be merged; path-like inputs are not supported here.
        """
        if isinstance(image, bytes):
            return image
        if isinstance(image, BytesIO):
            current_pos = image.tell()
            image.seek(0)
            data = image.read()
            image.seek(current_pos)
            return data
        if isinstance(image, BufferedReader):
            current_pos = image.tell()
            image.seek(0)
            data = image.read()
            image.seek(current_pos)
            return data
        raise ValueError("Unsupported image type for Vertex AI Gemini image edit.")

View File

@@ -0,0 +1,365 @@
import base64
import json
import os
from io import BufferedRandom, BufferedReader, BytesIO
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
import httpx
from httpx._types import RequestFiles
import litellm
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
from litellm.llms.base_llm.image_edit.transformation import BaseImageEditConfig
from litellm.llms.vertex_ai.common_utils import get_vertex_base_url
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
from litellm.secret_managers.main import get_secret_str
from litellm.types.images.main import ImageEditOptionalRequestParams
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import FileTypes, ImageObject, ImageResponse, OpenAIImage
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class VertexAIImagenImageEditConfig(BaseImageEditConfig, VertexLLM):
    """
    Vertex AI Imagen Image Edit Configuration
    Uses predict API for Imagen models on Vertex AI
    """
    # OpenAI image-edit params this provider can honor; everything else is dropped.
    SUPPORTED_PARAMS: List[str] = ["n", "size", "mask"]
    def __init__(self) -> None:
        # Initialize both bases explicitly (no cooperative super() chain here).
        BaseImageEditConfig.__init__(self)
        VertexLLM.__init__(self)
    def get_supported_openai_params(self, model: str) -> List[str]:
        # Return a copy so callers cannot mutate the class-level list.
        return list(self.SUPPORTED_PARAMS)
    def map_openai_params(
        self,
        image_edit_optional_params: ImageEditOptionalRequestParams,
        model: str,
        drop_params: bool,
    ) -> Dict[str, Any]:
        """Translate supported OpenAI image-edit params into Imagen predict params."""
        supported_params = self.get_supported_openai_params(model)
        filtered_params = {
            key: value
            for key, value in image_edit_optional_params.items()
            if key in supported_params
        }
        mapped_params: Dict[str, Any] = {}
        # Map OpenAI parameters to Imagen format
        if "n" in filtered_params:
            mapped_params["sampleCount"] = filtered_params["n"]
        if "size" in filtered_params:
            mapped_params["aspectRatio"] = self._map_size_to_aspect_ratio(
                filtered_params["size"]  # type: ignore[arg-type]
            )
        if "mask" in filtered_params:
            mapped_params["mask"] = filtered_params["mask"]
        return mapped_params
    def _resolve_vertex_project(self) -> Optional[str]:
        # Precedence: instance attr > env var > litellm module setting > secret manager.
        return (
            getattr(self, "_vertex_project", None)
            or os.environ.get("VERTEXAI_PROJECT")
            or getattr(litellm, "vertex_project", None)
            or get_secret_str("VERTEXAI_PROJECT")
        )
    def _resolve_vertex_location(self) -> Optional[str]:
        # Same precedence as project; both VERTEXAI_/VERTEX_ spellings are accepted.
        return (
            getattr(self, "_vertex_location", None)
            or os.environ.get("VERTEXAI_LOCATION")
            or os.environ.get("VERTEX_LOCATION")
            or getattr(litellm, "vertex_location", None)
            or get_secret_str("VERTEXAI_LOCATION")
            or get_secret_str("VERTEX_LOCATION")
        )
    def _resolve_vertex_credentials(self) -> Optional[str]:
        # GOOGLE_APPLICATION_CREDENTIALS (a credentials file path) is also accepted.
        return (
            getattr(self, "_vertex_credentials", None)
            or os.environ.get("VERTEXAI_CREDENTIALS")
            or getattr(litellm, "vertex_credentials", None)
            or os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
            or get_secret_str("VERTEXAI_CREDENTIALS")
        )
    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
    ) -> dict:
        """Resolve Vertex credentials and attach an OAuth access token to *headers*."""
        headers = headers or {}
        vertex_project = self._resolve_vertex_project()
        vertex_credentials = self._resolve_vertex_credentials()
        access_token, _ = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai",
        )
        return self.set_headers(access_token, headers)
    def get_complete_url(
        self,
        model: str,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        Get the complete URL for Vertex AI Imagen predict API

        Raises:
            ValueError: if project or location cannot be resolved.
        """
        vertex_project = self._resolve_vertex_project()
        vertex_location = self._resolve_vertex_location()
        if not vertex_project or not vertex_location:
            raise ValueError(
                "vertex_project and vertex_location are required for Vertex AI"
            )
        # Use the model name as provided, handling vertex_ai prefix
        model_name = model
        if model.startswith("vertex_ai/"):
            model_name = model.replace("vertex_ai/", "")
        if api_base:
            base_url = api_base.rstrip("/")
        else:
            base_url = get_vertex_base_url(vertex_location)
        return f"{base_url}/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model_name}:predict"
    def transform_image_edit_request(  # type: ignore[override]
        self,
        model: str,
        prompt: Optional[str],
        image: Optional[FileTypes],
        image_edit_optional_request_params: Dict[str, Any],
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[Dict[str, Any], Optional[RequestFiles]]:
        """Build the Imagen ``:predict`` request body for an image edit.

        Returns a ``(payload, files)`` tuple; the payload is already
        JSON-encoded here (the cast keeps the declared tuple type so the
        base-class contract is satisfied).
        """
        # Prepare reference images in the correct Imagen format
        if image is None:
            raise ValueError(
                "Vertex AI Imagen image edit requires at least one reference image."
            )
        reference_images = self._prepare_reference_images(
            image, image_edit_optional_request_params
        )
        if not reference_images:
            raise ValueError(
                "Vertex AI Imagen image edit requires at least one reference image."
            )
        if prompt is None:
            raise ValueError("Vertex AI Imagen image edit requires a prompt.")
        # Correct Imagen instances format
        instances = [{"prompt": prompt, "referenceImages": reference_images}]
        # Extract OpenAI parameters and set sensible defaults for Vertex AI-specific parameters
        sample_count = image_edit_optional_request_params.get("sampleCount", 1)
        # Use sensible defaults for Vertex AI-specific parameters (not exposed to users)
        edit_mode = "EDIT_MODE_INPAINT_INSERTION"  # Default edit mode
        base_steps = 50  # Default number of steps
        # Imagen parameters with correct structure
        parameters = {
            "sampleCount": sample_count,
            "editMode": edit_mode,
            "editConfig": {"baseSteps": base_steps},
        }
        # Set default values for Vertex AI-specific parameters (not configurable by users via OpenAI API)
        parameters["guidanceScale"] = 7.5  # Default guidance scale
        # NOTE(review): this sends ``"seed": null`` in the JSON body — confirm the
        # predict API accepts an explicit null here (otherwise omit the key).
        parameters["seed"] = None  # Let Vertex AI choose random seed
        request_body: Dict[str, Any] = {
            "instances": instances,
            "parameters": parameters,
        }
        payload: Any = json.dumps(request_body)
        empty_files = cast(RequestFiles, [])
        return cast(
            Tuple[Dict[str, Any], Optional[RequestFiles]], (payload, empty_files)
        )
    def transform_image_edit_response(
        self,
        model: str,
        raw_response: httpx.Response,
        logging_obj: Any,
    ) -> ImageResponse:
        """Convert an Imagen predict response into a litellm ``ImageResponse``.

        Predictions without a ``bytesBase64Encoded`` field are silently skipped.
        """
        model_response = ImageResponse()
        try:
            response_json = raw_response.json()
        except Exception as exc:
            raise self.get_error_class(
                error_message=f"Error transforming image edit response: {exc}",
                status_code=raw_response.status_code,
                headers=raw_response.headers,
            )
        predictions = response_json.get("predictions", [])
        data_list: List[ImageObject] = []
        for prediction in predictions:
            # Imagen returns images as bytesBase64Encoded
            if "bytesBase64Encoded" in prediction:
                data_list.append(
                    ImageObject(
                        b64_json=prediction["bytesBase64Encoded"],
                        url=None,
                    )
                )
        model_response.data = cast(List[OpenAIImage], data_list)
        return model_response
    def _map_size_to_aspect_ratio(self, size: str) -> str:
        """Map OpenAI size format to Imagen aspect ratio format"""
        # Unknown sizes fall back to square (1:1).
        aspect_ratio_map = {
            "1024x1024": "1:1",
            "1792x1024": "16:9",
            "1024x1792": "9:16",
            "1280x896": "4:3",
            "896x1280": "3:4",
        }
        return aspect_ratio_map.get(size, "1:1")
    def _prepare_reference_images(
        self,
        image: Union[FileTypes, List[FileTypes]],
        image_edit_optional_request_params: Dict[str, Any],
    ) -> List[Dict[str, Any]]:
        """
        Prepare reference images in the correct Imagen API format

        Raw images get sequential ``referenceId``s starting at 1; an optional
        mask (from ``image_edit_optional_request_params["mask"]``) is appended
        last as a REFERENCE_TYPE_MASK entry.
        """
        images: List[FileTypes]
        if isinstance(image, list):
            images = image
        else:
            images = [image]
        reference_images: List[Dict[str, Any]] = []
        for idx, img in enumerate(images):
            if img is None:
                continue
            image_bytes = self._read_all_bytes(img)
            base64_data = base64.b64encode(image_bytes).decode("utf-8")
            # Create reference image structure
            reference_image = {
                "referenceType": "REFERENCE_TYPE_RAW",
                "referenceId": idx + 1,
                "referenceImage": {"bytesBase64Encoded": base64_data},
            }
            reference_images.append(reference_image)
        # Handle mask image if provided (for inpainting)
        mask_image = image_edit_optional_request_params.get("mask")
        if mask_image is not None:
            mask_bytes = self._read_all_bytes(mask_image)
            mask_base64 = base64.b64encode(mask_bytes).decode("utf-8")
            mask_reference = {
                "referenceType": "REFERENCE_TYPE_MASK",
                "referenceId": len(reference_images) + 1,
                "referenceImage": {"bytesBase64Encoded": mask_base64},
                "maskImageConfig": {
                    "maskMode": "MASK_MODE_USER_PROVIDED",
                    "dilation": 0.03,  # Default dilation value (not configurable via OpenAI API)
                },
            }
            reference_images.append(mask_reference)
        return reference_images
    def _read_all_bytes(
        self, image: Any, depth: int = 0, max_depth: int = DEFAULT_MAX_RECURSE_DEPTH
    ) -> bytes:
        """Best-effort extraction of raw bytes from the many shapes *image* may take.

        Handles nested lists/tuples and dicts (recursing up to ``max_depth``),
        raw bytes/bytearray, seekable streams, filesystem paths, and generic
        file-like objects exposing ``read()``.

        Raises:
            ValueError: on recursion overflow, missing path, or unsupported type.
        """
        if depth > max_depth:
            raise ValueError(
                f"Max recursion depth {max_depth} reached while reading image bytes for Vertex AI Imagen image edit."
            )
        if isinstance(image, (list, tuple)):
            # Use the first non-None entry; the rest are ignored.
            for item in image:
                if item is not None:
                    return self._read_all_bytes(
                        item, depth=depth + 1, max_depth=max_depth
                    )
            raise ValueError("Unsupported image type for Vertex AI Imagen image edit.")
        if isinstance(image, dict):
            for key in ("data", "bytes", "content"):
                if key in image and image[key] is not None:
                    value = image[key]
                    if isinstance(value, str):
                        # Strings under these keys are assumed base64; on decode
                        # failure, fall through to the next candidate key.
                        try:
                            return base64.b64decode(value)
                        except Exception:
                            continue
                    return self._read_all_bytes(
                        value, depth=depth + 1, max_depth=max_depth
                    )
            if "path" in image:
                return self._read_all_bytes(
                    image["path"], depth=depth + 1, max_depth=max_depth
                )
        if isinstance(image, bytes):
            return image
        if isinstance(image, bytearray):
            return bytes(image)
        if isinstance(image, BytesIO):
            # Read from the start, then restore the caller's stream position.
            current_pos = image.tell()
            image.seek(0)
            data = image.read()
            image.seek(current_pos)
            return data
        if isinstance(image, (BufferedReader, BufferedRandom)):
            stream_pos: Optional[int] = None
            try:
                stream_pos = image.tell()
            except Exception:
                # Unseekable stream: read it as-is without restoring position.
                stream_pos = None
            if stream_pos is not None:
                image.seek(0)
            data = image.read()
            if stream_pos is not None:
                image.seek(stream_pos)
            return data
        if isinstance(image, (str, Path)):
            path_obj = Path(image)
            if not path_obj.exists():
                raise ValueError(
                    f"Mask/image path does not exist for Vertex AI Imagen image edit: {path_obj}"
                )
            return path_obj.read_bytes()
        if hasattr(image, "read"):
            data = image.read()
            if isinstance(data, str):
                data = data.encode("utf-8")
            return data
        raise ValueError(
            f"Unsupported image type for Vertex AI Imagen image edit. Got type={type(image)}"
        )

View File

@@ -0,0 +1,42 @@
from litellm.llms.base_llm.image_generation.transformation import (
BaseImageGenerationConfig,
)
from litellm.llms.vertex_ai.common_utils import (
VertexAIModelRoute,
get_vertex_ai_model_route,
)
from .vertex_gemini_transformation import VertexAIGeminiImageGenerationConfig
from .vertex_imagen_transformation import VertexAIImagenImageGenerationConfig
__all__ = [
"VertexAIGeminiImageGenerationConfig",
"VertexAIImagenImageGenerationConfig",
"get_vertex_ai_image_generation_config",
]
def get_vertex_ai_image_generation_config(model: str) -> BaseImageGenerationConfig:
    """
    Pick the image-generation config class for a Vertex AI model.

    Gemini image models route to the generateContent API
    (VertexAIGeminiImageGenerationConfig); everything else — including
    NON_GEMINI models like imagegeneration@006 — falls back to the Imagen
    predict API (VertexAIImagenImageGenerationConfig).

    Args:
        model: The model name (e.g., "gemini-2.5-flash-image", "imagegeneration@006")

    Returns:
        BaseImageGenerationConfig: the matching configuration instance
    """
    route = get_vertex_ai_model_route(model)
    if route == VertexAIModelRoute.GEMINI:
        return VertexAIGeminiImageGenerationConfig()
    return VertexAIImagenImageGenerationConfig()

View File

@@ -0,0 +1,36 @@
"""
Vertex AI Image Generation Cost Calculator
"""
import litellm
from litellm.litellm_core_utils.llm_cost_calc.utils import (
calculate_image_response_cost_from_usage,
)
from litellm.types.utils import ImageResponse
def cost_calculator(
    model: str,
    image_response: ImageResponse,
) -> float:
    """
    Vertex AI Image Generation Cost Calculator

    Prefers token/usage-based pricing when available; otherwise falls back
    to a flat per-image price from the model-info map.
    """
    # Fetch model info first: an unknown model should fail loudly here.
    model_info = litellm.get_model_info(
        model=model,
        custom_llm_provider="vertex_ai",
    )
    usage_based_cost = calculate_image_response_cost_from_usage(
        model=model,
        image_response=image_response,
        custom_llm_provider="vertex_ai",
    )
    if usage_based_cost is not None:
        return usage_based_cost
    per_image_cost: float = model_info.get("output_cost_per_image") or 0.0
    image_count: int = len(image_response.data) if image_response.data else 0
    return per_image_cost * image_count

View File

@@ -0,0 +1,282 @@
import json
from typing import Any, Dict, List, Optional
import httpx
from openai.types.image import Image
import litellm
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
get_async_httpx_client,
)
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
from litellm.types.utils import ImageResponse
class VertexImageGeneration(VertexLLM):
    """Vertex AI Imagen image generation via the ``:predict`` REST endpoint."""
    def process_image_generation_response(
        self,
        json_response: Dict[str, Any],
        model_response: ImageResponse,
        model: Optional[str] = None,
    ) -> ImageResponse:
        """Copy base64-encoded images from a predict response into *model_response*.

        Raises:
            litellm.InternalServerError: if the response has no ``predictions`` key.
        """
        if "predictions" not in json_response:
            raise litellm.InternalServerError(
                message=f"image generation response does not contain 'predictions', got {json_response}",
                llm_provider="vertex_ai",
                model=model,
            )
        predictions = json_response["predictions"]
        response_data: List[Image] = []
        for prediction in predictions:
            # A KeyError here surfaces unexpected prediction shapes loudly.
            bytes_base64_encoded = prediction["bytesBase64Encoded"]
            image_object = Image(b64_json=bytes_base64_encoded)
            response_data.append(image_object)
        model_response.data = response_data
        return model_response
    def transform_optional_params(self, optional_params: Optional[dict]) -> dict:
        """
        Transform the optional params to the format expected by the Vertex AI API.
        For example, "aspect_ratio" is transformed to "aspectRatio".
        A default of ``{"sampleCount": 1}`` is always applied.
        """
        default_params = {
            "sampleCount": 1,
        }
        if optional_params is None:
            return default_params
        def snake_to_camel(snake_str: str) -> str:
            """Convert snake_case to camelCase"""
            components = snake_str.split("_")
            return components[0] + "".join(word.capitalize() for word in components[1:])
        transformed_params = default_params.copy()
        for key, value in optional_params.items():
            if "_" in key:
                camel_case_key = snake_to_camel(key)
                transformed_params[camel_case_key] = value
            else:
                transformed_params[key] = value
        return transformed_params
    def image_generation(
        self,
        prompt: str,
        api_base: Optional[str],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
        model_response: ImageResponse,
        logging_obj: Any,
        model: str = "imagegeneration",  # vertex ai uses imagegeneration as the default model
        client: Optional[Any] = None,
        optional_params: Optional[dict] = None,
        timeout: Optional[int] = None,
        aimg_generation=False,
        extra_headers: Optional[dict] = None,
    ) -> ImageResponse:
        """Synchronous Imagen image generation (delegates to the async path
        when ``aimg_generation`` is True)."""
        if aimg_generation is True:
            return self.aimage_generation(  # type: ignore
                prompt=prompt,
                api_base=api_base,
                vertex_project=vertex_project,
                vertex_location=vertex_location,
                vertex_credentials=vertex_credentials,
                model=model,
                client=client,
                optional_params=optional_params,
                timeout=timeout,
                logging_obj=logging_obj,
                model_response=model_response,
            )
        if client is None:
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
                    _httpx_timeout = httpx.Timeout(timeout)
                    _params["timeout"] = _httpx_timeout
                else:
                    # Non-numeric timeout: fall back to a generous default.
                    _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
            sync_handler: HTTPHandler = HTTPHandler(**_params)  # type: ignore
        else:
            sync_handler = client  # type: ignore
        # url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
        auth_header: Optional[str] = None
        auth_header, _ = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai",
        )
        # _get_token_and_url also resolves the final predict endpoint URL.
        auth_header, api_base = self._get_token_and_url(
            model=model,
            gemini_api_key=None,
            auth_header=auth_header,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_credentials=vertex_credentials,
            stream=False,
            custom_llm_provider="vertex_ai",
            api_base=api_base,
            should_use_v1beta1_features=False,
            mode="image_generation",
        )
        optional_params = optional_params or {
            "sampleCount": 1
        }  # default optional params
        # Transform optional params to camelCase format
        optional_params = self.transform_optional_params(optional_params)
        request_data = {
            "instances": [{"prompt": prompt}],
            "parameters": optional_params,
        }
        headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
        logging_obj.pre_call(
            input=prompt,
            api_key="",
            additional_args={
                "complete_input_dict": optional_params,
                "api_base": api_base,
                "headers": headers,
            },
        )
        response = sync_handler.post(
            url=api_base,
            headers=headers,
            data=json.dumps(request_data),
        )
        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} {response.text}")
        json_response = response.json()
        return self.process_image_generation_response(
            json_response, model_response, model
        )
    async def aimage_generation(
        self,
        prompt: str,
        api_base: Optional[str],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
        model_response: ImageResponse,
        logging_obj: Any,
        model: str = "imagegeneration",  # vertex ai uses imagegeneration as the default model
        client: Optional[AsyncHTTPHandler] = None,
        optional_params: Optional[dict] = None,
        timeout: Optional[int] = None,
        extra_headers: Optional[dict] = None,
    ):
        """Async Imagen image generation via the ``:predict`` endpoint."""
        response = None
        if client is None:
            # NOTE(review): ``_params`` is built here but never used — the raw
            # ``timeout`` is passed to get_async_httpx_client below instead.
            # Confirm whether the httpx.Timeout conversion was meant to apply.
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
                    _httpx_timeout = httpx.Timeout(timeout)
                    _params["timeout"] = _httpx_timeout
                else:
                    _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
            self.async_handler = get_async_httpx_client(
                llm_provider=litellm.LlmProviders.VERTEX_AI,
                params={"timeout": timeout},
            )
        else:
            self.async_handler = client  # type: ignore
        # make POST request to
        # https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict
        """
        Docs link: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
        curl -X POST \
        -H "Authorization: Bearer $(gcloud auth print-access-token)" \
        -H "Content-Type: application/json; charset=utf-8" \
        -d {
            "instances": [
                {
                    "prompt": "a cat"
                }
            ],
            "parameters": {
                "sampleCount": 1
            }
        } \
        "https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
        """
        auth_header: Optional[str] = None
        auth_header, _ = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai",
        )
        # _get_token_and_url also resolves the final predict endpoint URL.
        auth_header, api_base = self._get_token_and_url(
            model=model,
            gemini_api_key=None,
            auth_header=auth_header,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_credentials=vertex_credentials,
            stream=False,
            custom_llm_provider="vertex_ai",
            api_base=api_base,
            should_use_v1beta1_features=False,
            mode="image_generation",
        )
        # Transform optional params to camelCase format
        optional_params = self.transform_optional_params(optional_params)
        request_data = {
            "instances": [{"prompt": prompt}],
            "parameters": optional_params,
        }
        headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
        logging_obj.pre_call(
            input=prompt,
            api_key="",
            additional_args={
                "complete_input_dict": optional_params,
                "api_base": api_base,
                "headers": headers,
            },
        )
        response = await self.async_handler.post(
            url=api_base,
            headers=headers,
            data=json.dumps(request_data),
        )
        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} {response.text}")
        json_response = response.json()
        return self.process_image_generation_response(
            json_response, model_response, model
        )
    def is_image_generation_response(self, json_response: Dict[str, Any]) -> bool:
        """Heuristic: does this predict response look like image output?

        NOTE(review): indexes ``predictions[0]`` without an emptiness check —
        an empty ``predictions`` list would raise IndexError; confirm callers
        guarantee at least one prediction.
        """
        if "predictions" in json_response:
            if "bytesBase64Encoded" in json_response["predictions"][0]:
                return True
        return False

View File

@@ -0,0 +1,327 @@
import os
from typing import TYPE_CHECKING, Any, Dict, List, Optional
import httpx
import litellm
from litellm.llms.base_llm.image_generation.transformation import (
BaseImageGenerationConfig,
)
from litellm.llms.vertex_ai.common_utils import get_vertex_base_url
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import (
ImageObject,
ImageResponse,
ImageUsage,
ImageUsageInputTokensDetails,
)
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class VertexAIGeminiImageGenerationConfig(BaseImageGenerationConfig, VertexLLM):
    """
    Vertex AI Gemini Image Generation Configuration
    Uses generateContent API for Gemini image generation models on Vertex AI
    Supports models like gemini-2.5-flash-image, gemini-3-pro-image-preview, etc.
    """
    def __init__(self) -> None:
        # Initialize both bases explicitly (no cooperative super() chain here).
        BaseImageGenerationConfig.__init__(self)
        VertexLLM.__init__(self)
    def get_supported_openai_params(self, model: str) -> list:
        """
        Gemini image generation supported parameters
        Includes native Gemini imageConfig params (aspectRatio, imageSize)
        in both camelCase and snake_case variants.
        """
        return [
            "n",
            "size",
            "aspectRatio",
            "aspect_ratio",
            "imageSize",
            "image_size",
        ]
    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Map supported OpenAI params onto Gemini names; unsupported ones are dropped.

        Params already present in ``optional_params`` are not overridden.
        """
        supported_params = self.get_supported_openai_params(model)
        mapped_params = {}
        for k, v in non_default_params.items():
            if k not in optional_params.keys():
                if k in supported_params:
                    # Map OpenAI parameters to Gemini format
                    if k == "n":
                        mapped_params["candidate_count"] = v
                    elif k == "size":
                        # Map OpenAI size format to Gemini aspectRatio
                        mapped_params["aspectRatio"] = self._map_size_to_aspect_ratio(v)
                    elif k in ("aspectRatio", "aspect_ratio"):
                        mapped_params["aspectRatio"] = v
                    elif k in ("imageSize", "image_size"):
                        mapped_params["imageSize"] = v
                    else:
                        mapped_params[k] = v
        return mapped_params
    def _map_size_to_aspect_ratio(self, size: str) -> str:
        """
        Map OpenAI size format to Gemini aspect ratio format
        Unknown sizes fall back to square ("1:1").
        """
        aspect_ratio_map = {
            "1024x1024": "1:1",
            "1792x1024": "16:9",
            "1024x1792": "9:16",
            "1280x896": "4:3",
            "896x1280": "3:4",
        }
        return aspect_ratio_map.get(size, "1:1")
    def _resolve_vertex_project(self) -> Optional[str]:
        # Precedence: instance attr > env var > litellm module setting > secret manager.
        return (
            getattr(self, "_vertex_project", None)
            or os.environ.get("VERTEXAI_PROJECT")
            or getattr(litellm, "vertex_project", None)
            or get_secret_str("VERTEXAI_PROJECT")
        )
    def _resolve_vertex_location(self) -> Optional[str]:
        # Both VERTEXAI_/VERTEX_ spellings are accepted for the location.
        return (
            getattr(self, "_vertex_location", None)
            or os.environ.get("VERTEXAI_LOCATION")
            or os.environ.get("VERTEX_LOCATION")
            or getattr(litellm, "vertex_location", None)
            or get_secret_str("VERTEXAI_LOCATION")
            or get_secret_str("VERTEX_LOCATION")
        )
    def _resolve_vertex_credentials(self) -> Optional[str]:
        # GOOGLE_APPLICATION_CREDENTIALS (a credentials file path) is also accepted.
        return (
            getattr(self, "_vertex_credentials", None)
            or os.environ.get("VERTEXAI_CREDENTIALS")
            or getattr(litellm, "vertex_credentials", None)
            or os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
            or get_secret_str("VERTEXAI_CREDENTIALS")
        )
    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Get the complete URL for Vertex AI Gemini generateContent API

        Raises:
            ValueError: if project or location cannot be resolved (and no
                custom ``api_base`` was supplied).
        """
        # Use the model name as provided, handling vertex_ai prefix
        model_name = model
        if model.startswith("vertex_ai/"):
            model_name = model.replace("vertex_ai/", "")
        # If a custom api_base is provided, use it directly
        # This allows users to use proxies or mock endpoints
        if api_base:
            return api_base.rstrip("/")
        # First check litellm_params (where vertex_ai_project/vertex_ai_location are passed)
        # then fall back to environment variables and other sources
        vertex_project = (
            self.safe_get_vertex_ai_project(litellm_params)
            or self._resolve_vertex_project()
        )
        vertex_location = (
            self.safe_get_vertex_ai_location(litellm_params)
            or self._resolve_vertex_location()
        )
        if not vertex_project or not vertex_location:
            raise ValueError(
                "vertex_project and vertex_location are required for Vertex AI"
            )
        base_url = get_vertex_base_url(vertex_location)
        return f"{base_url}/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model_name}:generateContent"
    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """Attach a Vertex OAuth token to *headers* (skipped for custom api_base)."""
        headers = headers or {}
        # If a custom api_base is provided, skip credential validation
        # This allows users to use proxies or mock endpoints without needing Vertex AI credentials
        _api_base = litellm_params.get("api_base") or api_base
        if _api_base is not None:
            return headers
        # First check litellm_params (where vertex_ai_project/vertex_ai_credentials are passed)
        # then fall back to environment variables and other sources
        vertex_project = (
            self.safe_get_vertex_ai_project(litellm_params)
            or self._resolve_vertex_project()
        )
        vertex_credentials = (
            self.safe_get_vertex_ai_credentials(litellm_params)
            or self._resolve_vertex_credentials()
        )
        access_token, _ = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai",
        )
        return self.set_headers(access_token, headers)
    def transform_image_generation_request(
        self,
        model: str,
        prompt: str,
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Transform the image generation request to Gemini format
        Uses generateContent API with responseModalities: ["IMAGE"]
        """
        # Prepare messages with the prompt
        contents = [{"role": "user", "parts": [{"text": prompt}]}]
        # Prepare generation config
        generation_config: Dict[str, Any] = {"responseModalities": ["IMAGE"]}
        # Handle image-specific config parameters
        image_config: Dict[str, Any] = {}
        # Map aspectRatio (camelCase takes precedence over snake_case)
        if "aspectRatio" in optional_params:
            image_config["aspectRatio"] = optional_params["aspectRatio"]
        elif "aspect_ratio" in optional_params:
            image_config["aspectRatio"] = optional_params["aspect_ratio"]
        # Map imageSize (for Gemini 3 Pro)
        if "imageSize" in optional_params:
            image_config["imageSize"] = optional_params["imageSize"]
        elif "image_size" in optional_params:
            image_config["imageSize"] = optional_params["image_size"]
        if image_config:
            generation_config["imageConfig"] = image_config
        # Handle candidate_count (n parameter)
        if "candidate_count" in optional_params:
            generation_config["candidateCount"] = optional_params["candidate_count"]
        elif "n" in optional_params:
            generation_config["candidateCount"] = optional_params["n"]
        request_body: Dict[str, Any] = {
            "contents": contents,
            "generationConfig": generation_config,
        }
        return request_body
    def _transform_image_usage(self, usage: dict) -> ImageUsage:
        """Convert Gemini ``usageMetadata`` into an OpenAI-style ``ImageUsage``,
        splitting prompt tokens by TEXT/IMAGE modality."""
        input_tokens_details = ImageUsageInputTokensDetails(
            image_tokens=0,
            text_tokens=0,
        )
        tokens_details = usage.get("promptTokensDetails", [])
        for details in tokens_details:
            if isinstance(details, dict) and (modality := details.get("modality")):
                token_count = details.get("tokenCount", 0)
                if modality == "TEXT":
                    input_tokens_details.text_tokens += token_count
                elif modality == "IMAGE":
                    input_tokens_details.image_tokens += token_count
        return ImageUsage(
            input_tokens=usage.get("promptTokenCount", 0),
            input_tokens_details=input_tokens_details,
            output_tokens=usage.get("candidatesTokenCount", 0),
            total_tokens=usage.get("totalTokenCount", 0),
        )
    def transform_image_generation_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ImageResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ImageResponse:
        """
        Transform Gemini image generation response to litellm ImageResponse format

        Extracts every ``inlineData`` image part from each candidate and
        carries any ``thoughtSignature`` through provider_specific_fields.
        """
        try:
            response_data = raw_response.json()
        except Exception as e:
            raise self.get_error_class(
                error_message=f"Error transforming image generation response: {e}",
                status_code=raw_response.status_code,
                headers=raw_response.headers,
            )
        if not model_response.data:
            model_response.data = []
        # Gemini image generation models return in candidates format
        candidates = response_data.get("candidates", [])
        for candidate in candidates:
            content = candidate.get("content", {})
            parts = content.get("parts", [])
            for part in parts:
                # Look for inlineData with image
                if "inlineData" in part:
                    inline_data = part["inlineData"]
                    if "data" in inline_data:
                        thought_sig = part.get("thoughtSignature")
                        model_response.data.append(
                            ImageObject(
                                b64_json=inline_data["data"],
                                url=None,
                                provider_specific_fields={
                                    "thought_signature": thought_sig
                                }
                                if thought_sig
                                else None,
                            )
                        )
        if usage_metadata := response_data.get("usageMetadata", None):
            model_response.usage = self._transform_image_usage(usage_metadata)
        return model_response

View File

@@ -0,0 +1,256 @@
import os
from typing import TYPE_CHECKING, Any, List, Optional
import httpx
import litellm
from litellm.llms.base_llm.image_generation.transformation import (
BaseImageGenerationConfig,
)
from litellm.llms.vertex_ai.common_utils import get_vertex_base_url
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import (
AllMessageValues,
OpenAIImageGenerationOptionalParams,
)
from litellm.types.utils import ImageObject, ImageResponse
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class VertexAIImagenImageGenerationConfig(BaseImageGenerationConfig, VertexLLM):
"""
Vertex AI Imagen Image Generation Configuration
Uses predict API for Imagen models on Vertex AI
Supports models like imagegeneration@006
"""
    def __init__(self) -> None:
        # Initialize both bases explicitly, in this order (no cooperative
        # super() chain is used by these base classes).
        BaseImageGenerationConfig.__init__(self)
        VertexLLM.__init__(self)
def get_supported_openai_params(
self, model: str
) -> List[OpenAIImageGenerationOptionalParams]:
"""
Imagen API supported parameters
"""
return ["n", "size"]
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
supported_params = self.get_supported_openai_params(model)
mapped_params = {}
for k, v in non_default_params.items():
if k not in optional_params.keys():
if k in supported_params:
# Map OpenAI parameters to Imagen format
if k == "n":
mapped_params["sampleCount"] = v
elif k == "size":
# Map OpenAI size format to Imagen aspectRatio
mapped_params["aspectRatio"] = self._map_size_to_aspect_ratio(v)
else:
mapped_params[k] = v
return mapped_params
def _map_size_to_aspect_ratio(self, size: str) -> str:
"""
Map OpenAI size format to Imagen aspect ratio format
"""
aspect_ratio_map = {
"1024x1024": "1:1",
"1792x1024": "16:9",
"1024x1792": "9:16",
"1280x896": "4:3",
"896x1280": "3:4",
}
return aspect_ratio_map.get(size, "1:1")
def _resolve_vertex_project(self) -> Optional[str]:
return (
getattr(self, "_vertex_project", None)
or os.environ.get("VERTEXAI_PROJECT")
or getattr(litellm, "vertex_project", None)
or get_secret_str("VERTEXAI_PROJECT")
)
def _resolve_vertex_location(self) -> Optional[str]:
return (
getattr(self, "_vertex_location", None)
or os.environ.get("VERTEXAI_LOCATION")
or os.environ.get("VERTEX_LOCATION")
or getattr(litellm, "vertex_location", None)
or get_secret_str("VERTEXAI_LOCATION")
or get_secret_str("VERTEX_LOCATION")
)
def _resolve_vertex_credentials(self) -> Optional[str]:
return (
getattr(self, "_vertex_credentials", None)
or os.environ.get("VERTEXAI_CREDENTIALS")
or getattr(litellm, "vertex_credentials", None)
or os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
or get_secret_str("VERTEXAI_CREDENTIALS")
)
def get_complete_url(
self,
api_base: Optional[str],
api_key: Optional[str],
model: str,
optional_params: dict,
litellm_params: dict,
stream: Optional[bool] = None,
) -> str:
"""
Get the complete URL for Vertex AI Imagen predict API
"""
# Use the model name as provided, handling vertex_ai prefix
model_name = model
if model.startswith("vertex_ai/"):
model_name = model.replace("vertex_ai/", "")
# If a custom api_base is provided, use it directly
# This allows users to use proxies or mock endpoints
if api_base:
return api_base.rstrip("/")
# First check litellm_params (where vertex_ai_project/vertex_ai_location are passed)
# then fall back to environment variables and other sources
vertex_project = (
self.safe_get_vertex_ai_project(litellm_params)
or self._resolve_vertex_project()
)
vertex_location = (
self.safe_get_vertex_ai_location(litellm_params)
or self._resolve_vertex_location()
)
if not vertex_project or not vertex_location:
raise ValueError(
"vertex_project and vertex_location are required for Vertex AI"
)
base_url = get_vertex_base_url(vertex_location)
return f"{base_url}/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model_name}:predict"
def validate_environment(
    self,
    headers: dict,
    model: str,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
) -> dict:
    """Attach a Bearer access token to *headers*.

    When a custom ``api_base`` is configured (proxy / mock endpoint) the
    headers are returned untouched — no Vertex credentials are required.
    """
    headers = headers or {}
    # Custom endpoint → skip credential resolution entirely.
    if (litellm_params.get("api_base") or api_base) is not None:
        return headers
    # Per-request values in litellm_params win over env/config fallbacks.
    project = (
        self.safe_get_vertex_ai_project(litellm_params)
        or self._resolve_vertex_project()
    )
    credentials = (
        self.safe_get_vertex_ai_credentials(litellm_params)
        or self._resolve_vertex_credentials()
    )
    access_token, _ = self._ensure_access_token(
        credentials=credentials,
        project_id=project,
        custom_llm_provider="vertex_ai",
    )
    return self.set_headers(access_token, headers)
def transform_image_generation_request(
    self,
    model: str,
    prompt: str,
    optional_params: dict,
    litellm_params: dict,
    headers: dict,
) -> dict:
    """Build the Imagen ``:predict`` request body.

    Produces ``{"instances": [{"prompt": ...}], "parameters": {...}}`` where
    the parameters default to a single sample (``sampleCount=1``) and any
    caller-supplied optional params override the defaults.
    """
    parameters = {"sampleCount": 1}
    parameters.update(optional_params)
    return {
        "instances": [{"prompt": prompt}],
        "parameters": parameters,
    }
def transform_image_generation_response(
    self,
    model: str,
    raw_response: httpx.Response,
    model_response: ImageResponse,
    logging_obj: LiteLLMLoggingObj,
    request_data: dict,
    optional_params: dict,
    litellm_params: dict,
    encoding: Any,
    api_key: Optional[str] = None,
    json_mode: Optional[bool] = None,
) -> ImageResponse:
    """Map an Imagen ``:predict`` response onto litellm's ImageResponse.

    Raises the provider error class when the response body is not valid JSON.
    """
    try:
        payload = raw_response.json()
    except Exception as e:
        raise self.get_error_class(
            error_message=f"Error transforming image generation response: {e}",
            status_code=raw_response.status_code,
            headers=raw_response.headers,
        )
    if not model_response.data:
        model_response.data = []
    # Imagen returns one prediction per generated image, base64-encoded
    # under "bytesBase64Encoded".
    for prediction in payload.get("predictions", []):
        if "bytesBase64Encoded" not in prediction:
            continue
        model_response.data.append(
            ImageObject(
                b64_json=prediction["bytesBase64Encoded"],
                url=None,
            )
        )
    return model_response

View File

@@ -0,0 +1,188 @@
import json
from typing import Literal, Optional, Union
import httpx
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
get_async_httpx_client,
)
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
VertexAIError,
VertexLLM,
)
from litellm.types.utils import EmbeddingResponse
from .transformation import VertexAIMultimodalEmbeddingConfig
# Module-level singleton config; reused by the handler class below for all
# request/response transformation (it is stateless, so sharing is safe here).
vertex_multimodal_embedding_handler = VertexAIMultimodalEmbeddingConfig()
class VertexMultimodalEmbedding(VertexLLM):
    """Handler for Vertex AI multimodal embedding models.

    Routes OpenAI-style embedding calls (text, image, or video input) to the
    Vertex AI ``multimodalembedding`` predict endpoint, delegating request and
    response transformation to the module-level
    ``vertex_multimodal_embedding_handler`` config.
    """

    def __init__(self) -> None:
        super().__init__()
        # Model ids known to serve the multimodal embedding predict endpoint.
        self.SUPPORTED_MULTIMODAL_EMBEDDING_MODELS = [
            "multimodalembedding",
            "multimodalembedding@001",
        ]

    def multimodal_embedding(
        self,
        model: str,
        input: Union[list, str],
        print_verbose,
        model_response: EmbeddingResponse,
        custom_llm_provider: Literal["gemini", "vertex_ai"],
        optional_params: dict,
        litellm_params: dict,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        headers: dict = {},
        encoding=None,
        vertex_project=None,
        vertex_location=None,
        vertex_credentials=None,
        aembedding: Optional[bool] = False,
        timeout=300,
        client=None,
    ) -> EmbeddingResponse:
        """Execute a multimodal embedding request against Vertex AI.

        Resolves credentials and the predict URL, transforms the input into
        Vertex ``instances``, logs the outgoing request, then either POSTs
        synchronously or delegates to ``async_multimodal_embedding`` when
        ``aembedding`` is True (in which case a coroutine is returned — see
        the ``# type: ignore`` on that return).
        """
        # Resolve credentials first; this may also infer the project id.
        _auth_header, vertex_project = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider=custom_llm_provider,
        )
        # Build the final auth token and the full predict URL for this model.
        auth_header, url = self._get_token_and_url(
            model=model,
            auth_header=_auth_header,
            gemini_api_key=api_key,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_credentials=vertex_credentials,
            stream=None,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            should_use_v1beta1_features=False,
            mode="embedding",
        )
        if client is None:
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
                    _httpx_timeout = httpx.Timeout(timeout)
                    _params["timeout"] = _httpx_timeout
                else:
                    # Non-numeric timeout value: fall back to a generous default.
                    _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
            sync_handler: HTTPHandler = HTTPHandler(**_params)  # type: ignore
        else:
            # NOTE(review): assumes the supplied client is a sync HTTPHandler;
            # the async path below re-validates / creates its own client.
            sync_handler = client  # type: ignore
        request_data = vertex_multimodal_embedding_handler.transform_embedding_request(
            model, input, optional_params, headers
        )
        # Build auth headers using the resolved token (passed as api_key).
        headers = vertex_multimodal_embedding_handler.validate_environment(
            headers=headers,
            model=model,
            messages=[],
            optional_params=optional_params,
            api_key=auth_header,
            api_base=api_base,
            litellm_params=litellm_params,
        )
        ## LOGGING
        logging_obj.pre_call(
            input=input,
            api_key="",
            additional_args={
                "complete_input_dict": request_data,
                "api_base": url,
                "headers": headers,
            },
        )
        if aembedding is True:
            return self.async_multimodal_embedding(  # type: ignore
                model=model,
                api_base=url,
                data=request_data,
                timeout=timeout,
                headers=headers,
                client=client,
                model_response=model_response,
                optional_params=optional_params,
                litellm_params=litellm_params,
                logging_obj=logging_obj,
                api_key=api_key,
            )
        response = sync_handler.post(
            url=url,
            headers=headers,
            data=json.dumps(request_data),
        )
        # Status-code handling lives in transform_embedding_response.
        return vertex_multimodal_embedding_handler.transform_embedding_response(
            model=model,
            raw_response=response,
            model_response=model_response,
            logging_obj=logging_obj,
            api_key=api_key,
            request_data=request_data,
            optional_params=optional_params,
            litellm_params=litellm_params,
        )

    async def async_multimodal_embedding(
        self,
        model: str,
        api_base: str,
        optional_params: dict,
        litellm_params: dict,
        data: dict,
        model_response: EmbeddingResponse,
        timeout: Optional[Union[float, httpx.Timeout]],
        logging_obj: LiteLLMLoggingObj,
        headers={},
        client: Optional[AsyncHTTPHandler] = None,
        api_key: Optional[str] = None,
    ) -> EmbeddingResponse:
        """POST the prepared payload asynchronously and transform the response.

        Raises:
            VertexAIError: on a non-2xx HTTP status (with the provider's
                error body) or on request timeout (status 408).
        """
        if client is None:
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
                    timeout = httpx.Timeout(timeout)
                _params["timeout"] = timeout
            client = get_async_httpx_client(
                llm_provider=litellm.LlmProviders.VERTEX_AI,
                params={"timeout": timeout},
            )
        else:
            client = client  # type: ignore
        try:
            response = await client.post(api_base, headers=headers, json=data)  # type: ignore
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            # Surface the provider's error body along with its HTTP status.
            error_code = err.response.status_code
            raise VertexAIError(status_code=error_code, message=err.response.text)
        except httpx.TimeoutException:
            raise VertexAIError(status_code=408, message="Timeout error occurred.")
        return vertex_multimodal_embedding_handler.transform_embedding_response(
            model=model,
            raw_response=response,
            model_response=model_response,
            logging_obj=logging_obj,
            api_key=api_key,
            request_data=data,
            optional_params=optional_params,
            litellm_params=litellm_params,
        )

View File

@@ -0,0 +1,325 @@
from typing import List, Optional, Union, cast
from httpx import Headers, Response
from litellm.exceptions import InternalServerError
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.base_llm.embedding.transformation import LiteLLMLoggingObj
from litellm.types.llms.openai import AllEmbeddingInputValues, AllMessageValues
from litellm.types.llms.vertex_ai import (
Instance,
InstanceImage,
InstanceVideo,
MultimodalPredictions,
VertexMultimodalEmbeddingRequest,
)
from litellm.types.utils import (
Embedding,
EmbeddingResponse,
PromptTokensDetailsWrapper,
Usage,
)
from litellm.utils import _count_characters, is_base64_encoded
from ...base_llm.embedding.transformation import BaseEmbeddingConfig
from ..common_utils import VertexAIError
class VertexAIMultimodalEmbeddingConfig(BaseEmbeddingConfig):
    """Request/response transformation for Vertex AI multimodal embeddings.

    Converts OpenAI-style embedding input (plain text, GCS URIs, base64
    images/videos) into Vertex ``instances``, and maps Vertex predictions
    back into OpenAI ``Embedding`` objects plus character/image/video usage.
    """

    def get_supported_openai_params(self, model: str) -> list:
        """Return the OpenAI params this provider can honor."""
        return ["dimensions"]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Translate supported OpenAI params into Vertex equivalents."""
        for param, value in non_default_params.items():
            if param == "dimensions":
                optional_params["outputDimensionality"] = value
        return optional_params

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """Merge Bearer-auth defaults into *headers* (api_key is the access token)."""
        default_headers = {
            "Content-Type": "application/json; charset=utf-8",
            "Authorization": f"Bearer {api_key}",
        }
        # Defaults override any pre-existing auth/content-type entries.
        headers.update(default_headers)
        return headers

    def _is_gcs_uri(self, input_str: str) -> bool:
        """Check if the input string is a GCS URI."""
        return "gs://" in input_str

    def _is_video(self, input_str: str) -> bool:
        """Check if the input string represents a video (mp4)."""
        # NOTE(review): substring match — any string containing "mp4" is
        # treated as video; confirm this heuristic covers the intended cases.
        return "mp4" in input_str

    def _is_media_input(self, input_str: str) -> bool:
        """Check if the input string is a media element (GCS URI or base64 image)."""
        return self._is_gcs_uri(input_str) or is_base64_encoded(s=input_str)

    def _create_image_instance(self, input_str: str) -> InstanceImage:
        """Create an InstanceImage from a GCS URI or base64 string."""
        if self._is_gcs_uri(input_str):
            return InstanceImage(gcsUri=input_str)
        else:
            # Strip the "data:<mime>;base64," prefix if present.
            return InstanceImage(
                bytesBase64Encoded=(
                    input_str.split(",")[1] if "," in input_str else input_str
                )
            )

    def _create_video_instance(self, input_str: str) -> InstanceVideo:
        """Create an InstanceVideo from a GCS URI."""
        return InstanceVideo(gcsUri=input_str)

    def _process_input_element(self, input_element: str) -> Instance:
        """
        Process a single input element for multimodal embedding requests.

        Detects if the input is a GCS URI, base64 encoded image, or plain text.

        Args:
            input_element (str): The input element to process.

        Returns:
            Instance: A dictionary representing the processed input element.
        """
        if len(input_element) == 0:
            # Empty string is treated as (empty) text.
            return Instance(text=input_element)
        elif self._is_gcs_uri(input_element):
            if self._is_video(input_element):
                return Instance(video=self._create_video_instance(input_element))
            else:
                return Instance(image=self._create_image_instance(input_element))
        elif is_base64_encoded(s=input_element):
            return Instance(image=self._create_image_instance(input_element))
        else:
            return Instance(text=input_element)

    def _try_merge_text_with_media(
        self, text_str: str, next_elem: Optional[str]
    ) -> tuple[Instance, bool]:
        """
        Try to merge a text element with a following media element into a single instance.

        Args:
            text_str: The text string to potentially merge.
            next_elem: The next element in the input list (may be media).

        Returns:
            A tuple of (Instance, consumed_next) where consumed_next indicates
            if the next element was merged into this instance.
        """
        instance_args: Instance = {"text": text_str}
        if next_elem and isinstance(next_elem, str) and self._is_media_input(next_elem):
            if self._is_gcs_uri(next_elem) and self._is_video(next_elem):
                instance_args["video"] = self._create_video_instance(next_elem)
            else:
                instance_args["image"] = self._create_image_instance(next_elem)
            return instance_args, True
        return instance_args, False

    def process_openai_embedding_input(
        self, _input: Union[list, str]
    ) -> List[Instance]:
        """
        Process the input for multimodal embedding requests.

        Adjacent (text, media) pairs are merged into a single Instance so that
        text and its accompanying image/video are embedded together.

        Args:
            _input (Union[list, str]): The input data to process.

        Returns:
            List[Instance]: List of Instance objects for the embedding request.
        """
        _input_list = [_input] if not isinstance(_input, list) else _input
        processed_instances: List[Instance] = []
        i = 0
        while i < len(_input_list):
            current = _input_list[i]
            next_elem = _input_list[i + 1] if i + 1 < len(_input_list) else None
            if isinstance(current, str):
                if self._is_media_input(current):
                    # Current element is media - process it standalone
                    processed_instances.append(self._process_input_element(current))
                    i += 1
                else:
                    # Current element is text - try to merge with next media element
                    instance, consumed_next = self._try_merge_text_with_media(
                        text_str=current, next_elem=next_elem
                    )
                    processed_instances.append(instance)
                    i += 2 if consumed_next else 1
            elif isinstance(current, dict):
                # Pass pre-built Instance dicts through untouched.
                processed_instances.append(Instance(**current))
                i += 1
            else:
                raise ValueError(f"Unsupported input type: {type(current)}")
        return processed_instances

    def transform_embedding_request(
        self,
        model: str,
        input: AllEmbeddingInputValues,
        optional_params: dict,
        headers: dict,
    ) -> dict:
        """Build the Vertex predict request body (``{"instances": [...]}``).

        Caller-supplied ``instances`` in optional_params win outright;
        otherwise the OpenAI-style *input* is converted.
        """
        optional_params = optional_params or {}
        request_data = VertexMultimodalEmbeddingRequest(instances=[])
        if "instances" in optional_params:
            request_data["instances"] = optional_params["instances"]
        elif isinstance(input, list):
            vertex_instances: List[Instance] = self.process_openai_embedding_input(
                _input=input
            )
            request_data["instances"] = vertex_instances
        else:
            # construct instances
            vertex_request_instance = Instance(**optional_params)
            if isinstance(input, str):
                vertex_request_instance = self._process_input_element(input)
            request_data["instances"] = [vertex_request_instance]
        return cast(dict, request_data)

    def transform_embedding_response(
        self,
        model: str,
        raw_response: Response,
        model_response: EmbeddingResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str],
        request_data: dict,
        optional_params: dict,
        litellm_params: dict,
    ) -> EmbeddingResponse:
        """Map the raw Vertex predict response onto an EmbeddingResponse.

        Raises:
            Exception: on non-200 HTTP status.
            InternalServerError: when the body lacks a ``predictions`` key.
        """
        if raw_response.status_code != 200:
            raise Exception(f"Error: {raw_response.status_code} {raw_response.text}")
        _json_response = raw_response.json()
        if "predictions" not in _json_response:
            raise InternalServerError(
                message=f"embedding response does not contain 'predictions', got {_json_response}",
                llm_provider="vertex_ai",
                model=model,
            )
        _predictions = _json_response["predictions"]
        vertex_predictions = MultimodalPredictions(predictions=_predictions)
        model_response.data = self.transform_embedding_response_to_openai(
            predictions=vertex_predictions
        )
        model_response.model = model
        model_response.usage = self.calculate_usage(
            request_data=cast(VertexMultimodalEmbeddingRequest, request_data),
            vertex_predictions=vertex_predictions,
        )
        return model_response

    def calculate_usage(
        self,
        request_data: VertexMultimodalEmbeddingRequest,
        vertex_predictions: MultimodalPredictions,
    ) -> Usage:
        """Compute character/image/video usage (token counts are always 0 here)."""
        ## Calculate text embeddings usage
        prompt: Optional[str] = None
        character_count: Optional[int] = None
        for instance in request_data["instances"]:
            text = instance.get("text")
            if text:
                if prompt is None:
                    prompt = text
                else:
                    prompt += text
        if prompt is not None:
            character_count = _count_characters(prompt)
        ## Calculate image embeddings usage
        image_count = 0
        for instance in request_data["instances"]:
            if instance.get("image"):
                image_count += 1
        ## Calculate video embeddings usage
        video_length_seconds = 0.0
        for prediction in vertex_predictions["predictions"]:
            video_embeddings = prediction.get("videoEmbeddings")
            if video_embeddings:
                for embedding in video_embeddings:
                    # NOTE(review): assumes every video embedding carries both
                    # start/end offsets — a missing key would raise KeyError.
                    duration = embedding["endOffsetSec"] - embedding["startOffsetSec"]
                    video_length_seconds += duration
        prompt_tokens_details = PromptTokensDetailsWrapper(
            character_count=character_count,
            image_count=image_count,
            video_length_seconds=video_length_seconds,
        )
        return Usage(
            prompt_tokens=0,
            completion_tokens=0,
            total_tokens=0,
            prompt_tokens_details=prompt_tokens_details,
        )

    def transform_embedding_response_to_openai(
        self, predictions: MultimodalPredictions
    ) -> List[Embedding]:
        """Flatten Vertex text/image/video predictions into OpenAI Embedding objects.

        Note: all segments of one video prediction share that prediction's index.
        """
        openai_embeddings: List[Embedding] = []
        if "predictions" in predictions:
            for idx, _prediction in enumerate(predictions["predictions"]):
                if _prediction:
                    if "textEmbedding" in _prediction:
                        openai_embedding_object = Embedding(
                            embedding=_prediction["textEmbedding"],
                            index=idx,
                            object="embedding",
                        )
                        openai_embeddings.append(openai_embedding_object)
                    elif "imageEmbedding" in _prediction:
                        openai_embedding_object = Embedding(
                            embedding=_prediction["imageEmbedding"],
                            index=idx,
                            object="embedding",
                        )
                        openai_embeddings.append(openai_embedding_object)
                    elif "videoEmbeddings" in _prediction:
                        # One embedding per video segment.
                        for video_embedding in _prediction["videoEmbeddings"]:
                            openai_embedding_object = Embedding(
                                embedding=video_embedding["embedding"],
                                index=idx,
                                object="embedding",
                            )
                            openai_embeddings.append(openai_embedding_object)
        return openai_embeddings

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, Headers]
    ) -> BaseLLMException:
        """Return the provider-specific exception type for HTTP errors."""
        return VertexAIError(
            status_code=status_code, message=error_message, headers=headers
        )

View File

@@ -0,0 +1,4 @@
"""Vertex AI OCR module."""
from .transformation import VertexAIOCRConfig
__all__ = ["VertexAIOCRConfig"]

View File

@@ -0,0 +1,41 @@
"""
Common utilities for Vertex AI OCR providers.
This module provides routing logic to determine which OCR configuration to use
based on the model name.
"""
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from litellm.llms.base_llm.ocr.transformation import BaseOCRConfig
def get_vertex_ai_ocr_config(model: str) -> Optional["BaseOCRConfig"]:
    """Route a Vertex AI OCR model name to its provider-specific config.

    DeepSeek OCR models (any model name containing ``"deepseek"``) use the
    chat-completions-based DeepSeek config; every other model falls back to
    the Mistral-format Vertex AI OCR config.

    Args:
        model: The model name (e.g., "vertex_ai/ocr/<model>")

    Returns:
        OCR configuration instance for the specified model

    Examples:
        >>> get_vertex_ai_ocr_config("vertex_ai/deepseek-ai/deepseek-ocr-maas")
        <VertexAIDeepSeekOCRConfig object>
        >>> get_vertex_ai_ocr_config("vertex_ai/ocr/mistral-ocr-maas")
        <VertexAIOCRConfig object>
    """
    # Imported lazily to avoid a circular import at module load time.
    from litellm.llms.vertex_ai.ocr.deepseek_transformation import (
        VertexAIDeepSeekOCRConfig,
    )
    from litellm.llms.vertex_ai.ocr.transformation import VertexAIOCRConfig

    config: "BaseOCRConfig"
    if "deepseek" in model:
        config = VertexAIDeepSeekOCRConfig()
    else:
        config = VertexAIOCRConfig()
    return config

View File

@@ -0,0 +1,394 @@
"""
Vertex AI DeepSeek OCR transformation implementation.
"""
import json
from typing import TYPE_CHECKING, Any, Dict, Optional
import httpx
from litellm._logging import verbose_logger
from litellm.llms.base_llm.ocr.transformation import (
BaseOCRConfig,
DocumentType,
OCRPage,
OCRRequestData,
OCRResponse,
OCRUsageInfo,
)
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class VertexAIDeepSeekOCRConfig(BaseOCRConfig):
    """
    Vertex AI DeepSeek OCR transformation configuration.

    Vertex AI DeepSeek OCR uses the chat completion API format through the openapi endpoint.
    This transformation converts OCR requests to chat completion format and vice versa.
    """

    def __init__(self) -> None:
        super().__init__()
        # Reuse VertexBase for credential/token handling.
        self.vertex_base = VertexBase()

    def validate_environment(
        self,
        headers: Dict,
        model: str,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        litellm_params: Optional[dict] = None,
        **kwargs,
    ) -> Dict:
        """
        Validate environment and return headers for Vertex AI OCR.

        Vertex AI uses Bearer token authentication with access token from credentials.
        Caller-supplied *headers* override the generated defaults.
        """
        # Extract Vertex AI parameters using safe helpers from VertexBase
        # Use safe_get_* methods that don't mutate litellm_params dict
        litellm_params = litellm_params or {}
        vertex_project = VertexBase.safe_get_vertex_ai_project(
            litellm_params=litellm_params
        )
        vertex_credentials = VertexBase.safe_get_vertex_ai_credentials(
            litellm_params=litellm_params
        )
        # Get access token from Vertex credentials
        access_token, project_id = self.vertex_base.get_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
        )
        headers = {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
            **headers,
        }
        return headers

    def get_complete_url(
        self,
        api_base: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: Optional[dict] = None,
        **kwargs,
    ) -> str:
        """
        Get complete URL for Vertex AI DeepSeek OCR endpoint.

        Vertex AI endpoint format:
        https://{location}-aiplatform.googleapis.com/v1/projects/{project}/locations/{location}/endpoints/openapi/chat/completions

        Args:
            api_base: Vertex AI API base URL (optional)
            model: Model name (e.g., "deepseek-ai/deepseek-ocr-maas")
            optional_params: Optional parameters
            litellm_params: LiteLLM parameters containing vertex_project, vertex_location

        Returns: Complete URL for Vertex AI OCR endpoint

        Raises:
            ValueError: when no vertex_project can be resolved.
        """
        # Extract Vertex AI parameters using safe helpers from VertexBase
        # Use safe_get_* methods that don't mutate litellm_params dict
        litellm_params = litellm_params or {}
        vertex_project = VertexBase.safe_get_vertex_ai_project(
            litellm_params=litellm_params
        )
        vertex_location = VertexBase.safe_get_vertex_ai_location(
            litellm_params=litellm_params
        )
        if vertex_project is None:
            raise ValueError(
                "Missing vertex_project - Set VERTEXAI_PROJECT environment variable or pass vertex_project parameter"
            )
        if vertex_location is None:
            vertex_location = "us-central1"
        # Get API base URL
        if api_base is None:
            api_base = "https://aiplatform.googleapis.com"
        # Ensure no trailing slash
        api_base = api_base.rstrip("/")
        # Vertex AI DeepSeek OCR endpoint format
        # Format: https://{region}-aiplatform.googleapis.com/v1/projects/{project}/locations/{region}/endpoints/openapi/chat/completions
        return f"{api_base}/v1/projects/{vertex_project}/locations/{vertex_location}/endpoints/openapi/chat/completions"

    def transform_ocr_request(
        self,
        model: str,
        document: DocumentType,
        optional_params: dict,
        headers: dict,
        **kwargs,
    ) -> OCRRequestData:
        """
        Transform OCR request to chat completion format for Vertex AI DeepSeek OCR.

        Converts OCR document format to chat completion messages format:
        - Input: {"type": "image_url", "image_url": "gs://..."}
        - Output: {"model": "deepseek-ai/deepseek-ocr-maas", "messages": [{"role": "user", "content": [{"type": "image_url", "image_url": "gs://..."}]}]}

        Args:
            model: Model name (e.g., "deepseek-ai/deepseek-ocr-maas")
            document: Document dict from user (Mistral OCR format)
            optional_params: Already mapped optional parameters
            headers: Request headers
            **kwargs: Additional arguments

        Returns:
            OCRRequestData with JSON data in chat completion format

        Raises:
            ValueError: when *document* is not a dict or has an unsupported type.
        """
        verbose_logger.debug(
            "Vertex AI DeepSeek OCR transform_ocr_request (sync) called"
        )
        if not isinstance(document, dict):
            raise ValueError(f"Expected document dict, got {type(document)}")
        # Extract document type and URL
        doc_type = document.get("type")
        image_url = None
        document_url = None
        if doc_type == "image_url":
            image_url = document.get("image_url", "")
        elif doc_type == "document_url":
            document_url = document.get("document_url", "")
        else:
            raise ValueError(
                f"Unsupported document type: {doc_type}. Expected 'image_url' or 'document_url'"
            )
        # Build chat completion message content
        content_item = {}
        if image_url:
            content_item = {"type": "image_url", "image_url": image_url}
        elif document_url:
            # For document URLs, we use image_url type as well (Vertex AI supports both)
            content_item = {"type": "image_url", "image_url": document_url}
        # BUGFIX: the endpoint expects the publisher-qualified model id
        # ("deepseek-ai/<model>"); callers routed via get_vertex_ai_ocr_config
        # already pass the qualified name (e.g. "deepseek-ai/deepseek-ocr-maas"),
        # so only prepend the prefix when it is missing to avoid
        # "deepseek-ai/deepseek-ai/..." double-prefixing.
        qualified_model = (
            model if model.startswith("deepseek-ai/") else "deepseek-ai/" + model
        )
        # Build chat completion request
        data = {
            "model": qualified_model,
            "messages": [{"role": "user", "content": [content_item]}],
        }
        # Add optional parameters (stream, temperature, etc.)
        # Filter out OCR-specific params that don't apply to chat completion
        chat_completion_params = {}
        for key, value in optional_params.items():
            # Include common chat completion params
            if key in ["stream", "temperature", "max_tokens", "top_p", "n", "stop"]:
                chat_completion_params[key] = value
        data.update(chat_completion_params)
        verbose_logger.debug(
            "Vertex AI DeepSeek OCR: Transformed request to chat completion format"
        )
        return OCRRequestData(data=data, files=None)

    async def async_transform_ocr_request(
        self,
        model: str,
        document: DocumentType,
        optional_params: dict,
        headers: dict,
        **kwargs,
    ) -> OCRRequestData:
        """
        Transform OCR request to chat completion format for Vertex AI DeepSeek OCR (async).

        Same as sync version - no async-specific logic needed.

        Args:
            model: Model name
            document: Document dict from user
            optional_params: Already mapped optional parameters
            headers: Request headers
            **kwargs: Additional arguments

        Returns:
            OCRRequestData with JSON data in chat completion format
        """
        return self.transform_ocr_request(
            model=model,
            document=document,
            optional_params=optional_params,
            headers=headers,
            **kwargs,
        )

    def transform_ocr_response(
        self,
        model: str,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        **kwargs,
    ) -> OCRResponse:
        """
        Transform chat completion response to OCR format.

        Vertex AI DeepSeek OCR returns chat completion format:
        {
            "id": "...",
            "object": "chat.completion",
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": "<OCR result as JSON string or markdown>"
                }
            }],
            "usage": {...}
        }

        We need to extract the content and convert it to OCRResponse format.

        Args:
            model: Model name
            raw_response: Raw HTTP response from Vertex AI
            logging_obj: Logging object
            **kwargs: Additional arguments

        Returns:
            OCRResponse in standard format
        """
        verbose_logger.debug("Vertex AI DeepSeek OCR transform_ocr_response called")
        verbose_logger.debug(f"Raw response: {raw_response.text}")
        try:
            response_json = raw_response.json()
            # Extract content from chat completion response
            choices = response_json.get("choices", [])
            if not choices:
                raise ValueError("No choices in chat completion response")
            message = choices[0].get("message", {})
            content = message.get("content", "")
            if not content:
                raise ValueError("No content in chat completion response")
            # Try to parse content as JSON (OCR result might be JSON string)
            ocr_data = None
            try:
                # If content is a JSON string, parse it
                if isinstance(content, str) and content.strip().startswith("{"):
                    ocr_data = json.loads(content)
                elif isinstance(content, dict):
                    ocr_data = content
                else:
                    # If content is markdown text, create a single page with the markdown
                    ocr_data = {
                        "pages": [{"index": 0, "markdown": content}],
                        "model": model,
                        "usage_info": response_json.get("usage", {}),
                    }
            except json.JSONDecodeError:
                # If JSON parsing fails, treat content as markdown
                ocr_data = {
                    "pages": [{"index": 0, "markdown": content}],
                    "model": model,
                    "usage_info": response_json.get("usage", {}),
                }
            # Ensure we have the expected structure
            if "pages" not in ocr_data:
                # If OCR data doesn't have pages, wrap the content in a page
                ocr_data = {
                    "pages": [
                        {
                            "index": 0,
                            "markdown": content
                            if isinstance(content, str)
                            else json.dumps(content),
                        }
                    ],
                    "model": ocr_data.get("model", model),
                    "usage_info": ocr_data.get(
                        "usage_info", response_json.get("usage", {})
                    ),
                }
            # Convert usage info if present
            # NOTE(review): assumes the usage dict's keys match OCRUsageInfo's
            # fields; chat-completion usage keys may differ — confirm upstream.
            usage_info = None
            if "usage_info" in ocr_data:
                usage_dict = ocr_data["usage_info"]
                if isinstance(usage_dict, dict):
                    usage_info = OCRUsageInfo(**usage_dict)
            # Build OCRResponse
            pages = []
            for page_data in ocr_data.get("pages", []):
                # Ensure page has required fields
                if isinstance(page_data, dict):
                    page = OCRPage(
                        index=page_data.get("index", 0),
                        markdown=page_data.get("markdown", ""),
                        images=page_data.get("images"),
                        dimensions=page_data.get("dimensions"),
                    )
                    pages.append(page)
            if not pages:
                # Create a default page if none exist
                pages = [
                    OCRPage(
                        index=0, markdown=content if isinstance(content, str) else ""
                    )
                ]
            return OCRResponse(
                pages=pages,
                model=ocr_data.get("model", model),
                document_annotation=ocr_data.get("document_annotation"),
                usage_info=usage_info,
                object="ocr",
            )
        except Exception as e:
            verbose_logger.error(f"Error parsing Vertex AI DeepSeek OCR response: {e}")
            raise e

    async def async_transform_ocr_response(
        self,
        model: str,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        **kwargs,
    ) -> OCRResponse:
        """
        Async transform chat completion response to OCR format.

        Same as sync version - no async-specific logic needed.

        Args:
            model: Model name
            raw_response: Raw HTTP response
            logging_obj: Logging object
            **kwargs: Additional arguments

        Returns:
            OCRResponse in standard format
        """
        return self.transform_ocr_response(
            model=model,
            raw_response=raw_response,
            logging_obj=logging_obj,
            **kwargs,
        )

View File

@@ -0,0 +1,301 @@
"""
Vertex AI Mistral OCR transformation implementation.
"""
from typing import Dict, Optional
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.prompt_templates.image_handling import (
async_convert_url_to_base64,
convert_url_to_base64,
)
from litellm.llms.base_llm.ocr.transformation import DocumentType, OCRRequestData
from litellm.llms.mistral.ocr.transformation import MistralOCRConfig
from litellm.llms.vertex_ai.common_utils import get_vertex_base_url
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
class VertexAIOCRConfig(MistralOCRConfig):
"""
Vertex AI Mistral OCR transformation configuration.
Vertex AI uses Mistral's OCR API format through the Mistral publisher endpoint.
Inherits transformation logic from MistralOCRConfig since they use the same format.
Reference: Vertex AI Mistral OCR documentation
Important: Vertex AI OCR only supports base64 data URIs (data:image/..., data:application/pdf;base64,...).
Regular URLs are not supported.
"""
def __init__(self) -> None:
    # Inherit Mistral OCR transformation logic, and keep a VertexBase
    # instance around for credential / access-token handling.
    super().__init__()
    self.vertex_base = VertexBase()
def validate_environment(
    self,
    headers: Dict,
    model: str,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    litellm_params: Optional[dict] = None,
    **kwargs,
) -> Dict:
    """Build Bearer-auth headers for Vertex AI OCR.

    Resolves Vertex credentials (without mutating ``litellm_params``) and
    exchanges them for an access token. Caller-supplied *headers* override
    the generated defaults.
    """
    params = litellm_params or {}
    # safe_get_* helpers read values without mutating the params dict.
    project = VertexBase.safe_get_vertex_ai_project(litellm_params=params)
    credentials = VertexBase.safe_get_vertex_ai_credentials(litellm_params=params)
    access_token, _project_id = self.vertex_base.get_access_token(
        credentials=credentials,
        project_id=project,
    )
    merged = {
        "Authorization": f"Bearer {access_token}",
        "Content-Type": "application/json",
    }
    merged.update(headers)
    return merged
def get_complete_url(
    self,
    api_base: Optional[str],
    model: str,
    optional_params: dict,
    litellm_params: Optional[dict] = None,
    **kwargs,
) -> str:
    """Build the Vertex AI Mistral-publisher OCR ``:rawPredict`` URL.

    Format:
    ``https://{region}-aiplatform.googleapis.com/v1/projects/{project}/locations/{region}/publishers/mistralai/models/{model}:rawPredict``

    Raises:
        ValueError: when no vertex_project can be resolved.
    """
    params = litellm_params or {}
    # safe_get_* helpers read values without mutating the params dict.
    project = VertexBase.safe_get_vertex_ai_project(litellm_params=params)
    location = VertexBase.safe_get_vertex_ai_location(litellm_params=params)
    if project is None:
        raise ValueError(
            "Missing vertex_project - Set VERTEXAI_PROJECT environment variable or pass vertex_project parameter"
        )
    if location is None:
        location = "us-central1"
    # A caller-supplied api_base wins; otherwise derive it from the region.
    root = api_base if api_base is not None else get_vertex_base_url(location)
    root = root.rstrip("/")
    return (
        f"{root}/v1/projects/{project}/locations/{location}"
        f"/publishers/mistralai/models/{model}:rawPredict"
    )
def _convert_url_to_data_uri_sync(self, url: str) -> str:
    """Fetch *url* and return it as a base64 ``data:`` URI (sync).

    Vertex AI OCR cannot fetch remote URLs itself, so the content is
    downloaded here and inlined as a data URI.
    """
    verbose_logger.debug(
        f"Vertex AI OCR: Converting URL to base64 data URI (sync): {url}"
    )
    # convert_url_to_base64 already yields a complete
    # "data:<mime>;base64,..." URI — no further wrapping needed.
    encoded = convert_url_to_base64(url=url)
    verbose_logger.debug(
        f"Vertex AI OCR: Converted URL to data URI (length: {len(encoded)})"
    )
    return encoded
async def _convert_url_to_data_uri_async(self, url: str) -> str:
    """
    Fetch *url* and return its contents as a base64 ``data:`` URI (async).

    Vertex AI OCR has no outbound internet access, so remote documents
    must be inlined as data URIs before the request is sent.

    Args:
        url: The remote URL to fetch and encode.

    Returns:
        A full data URI string (e.g. ``data:image/jpeg;base64,...``).
    """
    verbose_logger.debug(
        f"Vertex AI OCR: Converting URL to base64 data URI (async): {url}"
    )
    # async_convert_url_to_base64 already produces a complete data URI such
    # as "data:image/jpeg;base64,..." — no extra wrapping is needed here.
    data_uri = await async_convert_url_to_base64(url=url)
    verbose_logger.debug(
        f"Vertex AI OCR: Converted URL to data URI (length: {len(data_uri)})"
    )
    return data_uri
def transform_ocr_request(
    self,
    model: str,
    document: DocumentType,
    optional_params: dict,
    headers: dict,
    **kwargs,
) -> OCRRequestData:
    """
    Transform an OCR request for Vertex AI (sync), inlining remote URLs.

    Vertex AI OCR has no internet access, so any remote document or image
    URL is fetched here and replaced with a base64 data URI before the
    request payload is built by the parent transformation.

    Args:
        model: Model name
        document: Document dict from user
        optional_params: Already mapped optional parameters
        headers: Request headers
        **kwargs: Additional arguments

    Returns:
        OCRRequestData with JSON data
    """
    verbose_logger.debug("Vertex AI OCR transform_ocr_request (sync) called")
    if not isinstance(document, dict):
        raise ValueError(f"Expected document dict, got {type(document)}")

    transformed_document = document.copy()
    # For both supported document types the dict key carrying the URL is
    # the same string as the "type" value.
    url_key = document.get("type")
    if url_key in ("document_url", "image_url"):
        raw_url = document.get(url_key, "")
        # Only convert remote URLs; values already in data-URI form pass through.
        if raw_url and not raw_url.startswith("data:"):
            verbose_logger.debug(
                "Vertex AI OCR: Converting document URL to base64 data URI (sync)"
                if url_key == "document_url"
                else "Vertex AI OCR: Converting image URL to base64 data URI (sync)"
            )
            transformed_document[url_key] = self._convert_url_to_data_uri_sync(
                url=raw_url
            )

    # Delegate payload construction to the parent transformation.
    return super().transform_ocr_request(
        model=model,
        document=transformed_document,
        optional_params=optional_params,
        headers=headers,
        **kwargs,
    )
async def async_transform_ocr_request(
    self,
    model: str,
    document: DocumentType,
    optional_params: dict,
    headers: dict,
    **kwargs,
) -> OCRRequestData:
    """
    Transform an OCR request for Vertex AI (async), inlining remote URLs.

    Vertex AI OCR has no internet access, so any remote document or image
    URL is fetched here and replaced with a base64 data URI before the
    request payload is built by the parent transformation.

    Args:
        model: Model name
        document: Document dict from user
        optional_params: Already mapped optional parameters
        headers: Request headers
        **kwargs: Additional arguments

    Returns:
        OCRRequestData with JSON data
    """
    verbose_logger.debug(
        f"Vertex AI OCR async_transform_ocr_request - model: {model}"
    )
    if not isinstance(document, dict):
        raise ValueError(f"Expected document dict, got {type(document)}")

    transformed_document = document.copy()
    # For both supported document types the dict key carrying the URL is
    # the same string as the "type" value.
    url_key = document.get("type")
    if url_key in ("document_url", "image_url"):
        raw_url = document.get(url_key, "")
        # Only convert remote URLs; values already in data-URI form pass through.
        if raw_url and not raw_url.startswith("data:"):
            verbose_logger.debug(
                "Vertex AI OCR: Converting document URL to base64 data URI (async)"
                if url_key == "document_url"
                else "Vertex AI OCR: Converting image URL to base64 data URI (async)"
            )
            transformed_document[url_key] = await self._convert_url_to_data_uri_async(
                url=raw_url
            )

    # Delegate payload construction to the parent transformation.
    return super().transform_ocr_request(
        model=model,
        document=transformed_document,
        optional_params=optional_params,
        headers=headers,
        **kwargs,
    )

View File

@@ -0,0 +1,13 @@
"""
Vertex AI RAG Engine module.
Handles RAG ingestion via Vertex AI RAG Engine API.
"""
from litellm.llms.vertex_ai.rag_engine.ingestion import VertexAIRAGIngestion
from litellm.llms.vertex_ai.rag_engine.transformation import VertexAIRAGTransformation
__all__ = [
"VertexAIRAGIngestion",
"VertexAIRAGTransformation",
]

View File

@@ -0,0 +1,312 @@
"""
Vertex AI RAG Engine Ingestion implementation.
Uses:
- litellm.files.acreate_file for uploading files to GCS
- Vertex AI RAG Engine REST API for importing files into corpus (via httpx)
Key differences from OpenAI:
- Files must be uploaded to GCS first (via litellm.files.acreate_file)
- Embedding is handled internally using text-embedding-005 by default
- Chunking configured via unified chunking_strategy in ingest_options
"""
from __future__ import annotations
import os
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
from litellm import get_secret_str
from litellm._logging import verbose_logger
from litellm.llms.vertex_ai.rag_engine.transformation import VertexAIRAGTransformation
from litellm.rag.ingestion.base_ingestion import BaseRAGIngestion
if TYPE_CHECKING:
from litellm import Router
from litellm.types.rag import RAGIngestOptions
def _get_str_or_none(value: Any) -> Optional[str]:
    """Return ``str(value)``, or ``None`` when *value* is ``None``."""
    return None if value is None else str(value)
def _get_int(value: Any, default: int) -> int:
    """Coerce *value* to ``int``, falling back to *default* when it is ``None``."""
    return default if value is None else int(value)
class VertexAIRAGIngestion(BaseRAGIngestion):
    """
    Vertex AI RAG Engine ingestion.

    Workflow:
        1. Upload the raw file to GCS via ``litellm.acreate_file``.
        2. Import the GCS object into an existing RAG corpus via the
           Vertex AI Python SDK (``vertexai.rag.import_files``).

    Required config in vector_store:
        - vector_store_id: RAG corpus ID (required)

    Optional config in vector_store:
        - vertex_project: GCP project ID (uses env VERTEXAI_PROJECT if not set)
        - vertex_location: GCP region (default: us-central1)
        - vertex_credentials: Path to credentials JSON (uses ADC if not set)
        - gcs_bucket: GCS bucket for uploads (uses env GCS_BUCKET_NAME if not set)
        - wait_for_import: Wait for import to complete (default: True)
        - import_timeout: Timeout in seconds (default: 600)

    Chunking is configured via ingest_options["chunking_strategy"]:
        - chunk_size: Maximum size of chunks (default: 1000)
        - chunk_overlap: Overlap between chunks (default: 200)

    Authentication:
        - Uses Application Default Credentials (ADC)
        - Run: gcloud auth application-default login
    """

    def __init__(
        self,
        ingest_options: "RAGIngestOptions",
        router: Optional["Router"] = None,
    ):
        super().__init__(ingest_options=ingest_options, router=router)

        # Corpus ID is mandatory — this integration imports into an existing
        # RAG corpus; it does not create one.
        self.corpus_id = self.vector_store_config.get("vector_store_id")
        if not self.corpus_id:
            raise ValueError(
                "vector_store_id (corpus ID) is required for Vertex AI RAG ingestion. "
                "Please provide an existing RAG corpus ID."
            )
        # GCP config — explicit vector_store config wins over env vars.
        self.vertex_project = self.vector_store_config.get(
            "vertex_project"
        ) or get_secret_str("VERTEXAI_PROJECT")
        self.vertex_location = (
            self.vector_store_config.get("vertex_location")
            or get_secret_str("VERTEXAI_LOCATION")
            or "us-central1"
        )
        self.vertex_credentials = self.vector_store_config.get("vertex_credentials")
        # GCS bucket that receives the raw file before corpus import.
        self.gcs_bucket = self.vector_store_config.get("gcs_bucket") or os.environ.get(
            "GCS_BUCKET_NAME"
        )
        if not self.gcs_bucket:
            raise ValueError(
                "gcs_bucket is required for Vertex AI RAG ingestion. "
                "Set via vector_store config or GCS_BUCKET_NAME env var."
            )
        # Import settings
        self.wait_for_import = self.vector_store_config.get("wait_for_import", True)
        self.import_timeout = _get_int(
            self.vector_store_config.get("import_timeout"), 600
        )
        # Validate required config (project may come from env, so check last)
        if not self.vertex_project:
            raise ValueError(
                "vertex_project is required for Vertex AI RAG ingestion. "
                "Set via vector_store config or VERTEXAI_PROJECT env var."
            )

    def _get_corpus_name(self) -> str:
        """Return the fully-qualified corpus resource name."""
        return f"projects/{self.vertex_project}/locations/{self.vertex_location}/ragCorpora/{self.corpus_id}"

    async def _upload_file_to_gcs(
        self,
        file_content: bytes,
        filename: str,
        content_type: str,
    ) -> str:
        """
        Upload a file to GCS using ``litellm.acreate_file``.

        Args:
            file_content: Raw file bytes.
            filename: Name of the file being uploaded.
            content_type: MIME type of the file.

        Returns:
            GCS URI of the uploaded file (gs://bucket/path/file).
        """
        import litellm

        # The vertex_ai file handler reads GCS_BUCKET_NAME to pick the target
        # bucket, so temporarily point it at ours and restore afterwards.
        # NOTE(review): mutating os.environ is not safe under concurrent
        # uploads with different buckets — confirm callers serialize this.
        original_bucket = os.environ.get("GCS_BUCKET_NAME")
        if self.gcs_bucket:
            os.environ["GCS_BUCKET_NAME"] = self.gcs_bucket
        try:
            # litellm.acreate_file accepts a (filename, bytes, content_type) tuple
            file_tuple = (filename, file_content, content_type)
            verbose_logger.debug(
                f"Uploading file to GCS via litellm.files.acreate_file: {filename} "
                f"(bucket: {self.gcs_bucket})"
            )
            # Upload to GCS using LiteLLM's file upload
            response = await litellm.acreate_file(
                file=file_tuple,
                purpose="assistants",  # Purpose for file storage
                custom_llm_provider="vertex_ai",
                vertex_project=self.vertex_project,
                vertex_location=self.vertex_location,
                vertex_credentials=self.vertex_credentials,
            )
            # The response.id is the GCS URI for the vertex_ai provider
            gcs_uri = response.id
            verbose_logger.info(f"Uploaded file to GCS: {gcs_uri}")
            return gcs_uri
        finally:
            # Restore the original env var (or remove ours if none existed)
            if original_bucket is not None:
                os.environ["GCS_BUCKET_NAME"] = original_bucket
            elif "GCS_BUCKET_NAME" in os.environ:
                del os.environ["GCS_BUCKET_NAME"]

    async def _import_file_to_corpus_via_sdk(
        self,
        gcs_uri: str,
    ) -> None:
        """
        Import a GCS file into the RAG corpus using the Vertex AI SDK.

        The REST endpoint for importRagFiles is not publicly available,
        so the Python SDK (``vertexai.rag``) is used instead.

        Args:
            gcs_uri: GCS URI of the file to import.

        Raises:
            ImportError: If google-cloud-aiplatform>=1.60.0 is not installed.
        """
        try:
            from vertexai import init as vertexai_init
            from vertexai import rag  # type: ignore[import-not-found]
        except ImportError:
            raise ImportError(
                "vertexai.rag module not found. Vertex AI RAG requires "
                "google-cloud-aiplatform>=1.60.0. Install with: "
                "pip install 'google-cloud-aiplatform>=1.60.0'"
            )
        # Initialize Vertex AI for the configured project/region
        vertexai_init(project=self.vertex_project, location=self.vertex_location)
        # Chunking config comes from ingest_options (unified interface)
        transformation_config = self._build_transformation_config()
        corpus_name = self._get_corpus_name()
        verbose_logger.debug(f"Importing {gcs_uri} into corpus {self.corpus_id}")
        if self.wait_for_import:
            # Synchronous import — block until the operation finishes
            response = rag.import_files(
                corpus_name=corpus_name,
                paths=[gcs_uri],
                transformation_config=transformation_config,
                timeout=self.import_timeout,
            )
            verbose_logger.info(
                f"Import complete: {response.imported_rag_files_count} files imported"
            )
        else:
            # Fire-and-forget import — do not wait for completion
            _ = rag.import_files_async(
                corpus_name=corpus_name,
                paths=[gcs_uri],
                transformation_config=transformation_config,
            )
            verbose_logger.info("Import started asynchronously")

    def _build_transformation_config(self) -> Any:
        """
        Build a Vertex AI ``TransformationConfig`` from the unified chunking_strategy.

        Uses chunking_strategy from ingest_options (not vector_store).

        Raises:
            ImportError: If google-cloud-aiplatform>=1.60.0 is not installed.
        """
        try:
            from vertexai import rag  # type: ignore[import-not-found]
        except ImportError:
            raise ImportError(
                "vertexai.rag module not found. Vertex AI RAG requires "
                "google-cloud-aiplatform>=1.60.0. Install with: "
                "pip install 'google-cloud-aiplatform>=1.60.0'"
            )
        # Translate the unified chunking config via the transformation class
        from typing import cast

        from litellm.types.rag import RAGChunkingStrategy

        transformation = VertexAIRAGTransformation()
        chunking_config = transformation.transform_chunking_strategy_to_vertex_format(
            cast(Optional[RAGChunkingStrategy], self.chunking_strategy)
        )
        chunk_size = chunking_config["chunking_config"]["chunk_size"]
        chunk_overlap = chunking_config["chunking_config"]["chunk_overlap"]
        return rag.TransformationConfig(
            chunking_config=rag.ChunkingConfig(
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
            ),
        )

    async def embed(
        self,
        chunks: List[str],
    ) -> Optional[List[List[float]]]:
        """
        Vertex AI handles embedding internally — skip this step.

        Returns:
            None (Vertex AI embeds when files are imported).
        """
        return None

    async def store(
        self,
        file_content: Optional[bytes],
        filename: Optional[str],
        content_type: Optional[str],
        chunks: List[str],
        embeddings: Optional[List[List[float]]],
    ) -> Tuple[Optional[str], Optional[str]]:
        """
        Store content in the Vertex AI RAG corpus.

        Vertex AI workflow:
            1. Upload file to GCS via litellm.files.acreate_file
            2. Import file into RAG corpus via SDK
            3. (Optional) Wait for import to complete

        Args:
            file_content: Raw file bytes
            filename: Name of the file
            content_type: MIME type
            chunks: Ignored — Vertex AI handles chunking
            embeddings: Ignored — Vertex AI handles embedding

        Returns:
            Tuple of (corpus_id, gcs_uri)

        Raises:
            RuntimeError: If the corpus import fails.
        """
        if not file_content or not filename:
            verbose_logger.warning(
                "No file content or filename provided for Vertex AI ingestion"
            )
            return _get_str_or_none(self.corpus_id), None
        # Step 1: Upload file to GCS
        gcs_uri = await self._upload_file_to_gcs(
            file_content=file_content,
            filename=filename,
            content_type=content_type or "application/octet-stream",
        )
        # Step 2: Import file into RAG corpus
        try:
            await self._import_file_to_corpus_via_sdk(gcs_uri=gcs_uri)
        except Exception as e:
            verbose_logger.error(f"Failed to import file into RAG corpus: {e}")
            raise RuntimeError(f"Failed to import file into RAG corpus: {e}") from e
        return str(self.corpus_id), gcs_uri

View File

@@ -0,0 +1,153 @@
"""
Transformation utilities for Vertex AI RAG Engine.
Handles transforming LiteLLM's unified formats to Vertex AI RAG Engine API format.
"""
from typing import Any, Dict, Optional
from litellm._logging import verbose_logger
from litellm.constants import DEFAULT_CHUNK_OVERLAP, DEFAULT_CHUNK_SIZE
from litellm.llms.vertex_ai.common_utils import get_vertex_base_url
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
from litellm.types.rag import RAGChunkingStrategy
class VertexAIRAGTransformation(VertexBase):
    """
    Transformation helpers for the Vertex AI RAG Engine API.

    Responsibilities:
        - Convert LiteLLM's unified chunking_strategy into Vertex AI format
        - Build importRagFiles request payloads
        - Produce endpoint URLs and auth headers
    """

    def __init__(self):
        super().__init__()

    def get_import_rag_files_url(
        self,
        vertex_project: str,
        vertex_location: str,
        corpus_id: str,
    ) -> str:
        """
        Return the URL for importing RAG files into a corpus.

        Note: the REST endpoint for importRagFiles may not be publicly
        available — Vertex AI RAG Engine primarily uses the gRPC-based SDK.
        """
        root = get_vertex_base_url(vertex_location)
        parent = f"projects/{vertex_project}/locations/{vertex_location}"
        return f"{root}/v1/{parent}/ragCorpora/{corpus_id}:importRagFiles"

    def get_retrieve_contexts_url(
        self,
        vertex_project: str,
        vertex_location: str,
    ) -> str:
        """Return the URL for the retrieveContexts (search) endpoint."""
        root = get_vertex_base_url(vertex_location)
        parent = f"projects/{vertex_project}/locations/{vertex_location}"
        return f"{root}/v1/{parent}:retrieveContexts"

    def transform_chunking_strategy_to_vertex_format(
        self,
        chunking_strategy: Optional[RAGChunkingStrategy],
    ) -> Dict[str, Any]:
        """
        Convert LiteLLM's unified chunking_strategy to Vertex AI RAG format.

        LiteLLM format (RAGChunkingStrategy)::

            {"chunk_size": 1000, "chunk_overlap": 200,
             "separators": ["\\n\\n", "\\n", " ", ""]}

        Vertex AI RAG format (TransformationConfig)::

            {"chunking_config": {"chunk_size": 1000, "chunk_overlap": 200}}

        Note: Vertex AI has no equivalent of custom separators, so only
        chunk_size and chunk_overlap are carried over; separators are
        ignored with a warning.
        """
        size = DEFAULT_CHUNK_SIZE
        overlap = DEFAULT_CHUNK_OVERLAP
        if chunking_strategy:
            size = chunking_strategy.get("chunk_size", DEFAULT_CHUNK_SIZE)
            overlap = chunking_strategy.get("chunk_overlap", DEFAULT_CHUNK_OVERLAP)
            # Surface the unsupported option instead of silently dropping it.
            if chunking_strategy.get("separators"):
                verbose_logger.warning(
                    "Vertex AI RAG Engine does not support custom separators. "
                    "The 'separators' parameter will be ignored."
                )
        return {
            "chunking_config": {
                "chunk_size": size,
                "chunk_overlap": overlap,
            }
        }

    def build_import_rag_files_request(
        self,
        gcs_uri: str,
        chunking_strategy: Optional[RAGChunkingStrategy] = None,
    ) -> Dict[str, Any]:
        """
        Build the request payload for the importRagFiles API.

        Args:
            gcs_uri: GCS URI of the file to import (e.g., gs://bucket/path/file.txt)
            chunking_strategy: LiteLLM unified chunking config

        Returns:
            Request payload dict for the importRagFiles API
        """
        return {
            "import_rag_files_config": {
                "gcs_source": {"uris": [gcs_uri]},
                "rag_file_transformation_config": (
                    self.transform_chunking_strategy_to_vertex_format(
                        chunking_strategy
                    )
                ),
            }
        }

    def get_auth_headers(
        self,
        vertex_credentials: Optional[str] = None,
        vertex_project: Optional[str] = None,
    ) -> Dict[str, str]:
        """
        Return Authorization / Content-Type headers for Vertex AI API calls.

        Credential and project resolution are delegated to the VertexBase
        helper methods.
        """
        creds = self.get_vertex_ai_credentials(
            {"vertex_credentials": vertex_credentials}
        )
        resolved_project = vertex_project or self.get_vertex_ai_project({})
        token, _ = self._ensure_access_token(
            credentials=creds,
            project_id=resolved_project,
            custom_llm_provider="vertex_ai",
        )
        return {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        }

View File

@@ -0,0 +1,164 @@
"""
Vertex AI Realtime (BidiGenerateContent) config.
Extends GeminiRealtimeConfig but adapts the WSS URL and auth header for the
Vertex AI endpoint instead of Google AI Studio.
URL pattern:
wss://{location}-aiplatform.googleapis.com/ws/
google.cloud.aiplatform.v1.LlmBidiService/BidiGenerateContent
Auth: OAuth2 Bearer token (not an API key).
"""
import json
from typing import List, Optional
from litellm.llms.gemini.realtime.transformation import GeminiRealtimeConfig
class VertexAIRealtimeConfig(GeminiRealtimeConfig):
    """
    Realtime (BidiGenerateContent) config for Vertex AI.

    ``access_token`` and ``project`` must be pre-resolved by the caller
    (they require async I/O) and injected at construction time.
    """

    def __init__(self, access_token: str, project: str, location: str) -> None:
        self._access_token = access_token
        self._project = project
        self._location = location

    # ------------------------------------------------------------------
    # URL
    # ------------------------------------------------------------------
    def get_complete_url(
        self,
        api_base: Optional[str],
        model: str,
        api_key: Optional[str] = None,  # noqa: ARG002
    ) -> str:
        """
        Build the Vertex AI Live WSS endpoint URL.

        A caller-supplied *api_base* overrides the default aiplatform host,
        letting enterprise / VPC-SC deployments target a custom gateway.
        """
        service_path = (
            "ws/google.cloud.aiplatform.v1.LlmBidiService/BidiGenerateContent"
        )
        if api_base:
            # Accept http(s):// bases and rewrite the scheme for websockets;
            # a fully-qualified wss:// base is passed through untouched.
            root = api_base.rstrip("/")
            root = root.replace("https://", "wss://").replace("http://", "ws://")
            return f"{root}/{service_path}"
        if self._location == "global":
            host = "aiplatform.googleapis.com"
        else:
            host = f"{self._location}-aiplatform.googleapis.com"
        return f"wss://{host}/{service_path}"

    # ------------------------------------------------------------------
    # Auth headers
    # ------------------------------------------------------------------
    def validate_environment(
        self,
        headers: dict,
        model: str,  # noqa: ARG002
        api_key: Optional[str] = None,  # noqa: ARG002
    ) -> dict:
        """
        Return headers carrying a Bearer token for Vertex AI.

        ``api_key`` is intentionally ignored — Vertex AI authenticates with
        OAuth2 tokens (resolved at config-construction time), not API keys.
        """
        out = dict(headers)
        out["Authorization"] = f"Bearer {self._access_token}"
        if self._project:
            out["x-goog-user-project"] = self._project
        return out

    # ------------------------------------------------------------------
    # Audio MIME type — Vertex AI needs the sample rate in the MIME string
    # ------------------------------------------------------------------
    def get_audio_mime_type(self, input_audio_format: str = "pcm16") -> str:
        known_formats = {
            "pcm16": "audio/pcm;rate=16000",
            "g711_ulaw": "audio/pcmu",
            "g711_alaw": "audio/pcma",
        }
        return known_formats.get(input_audio_format, "application/octet-stream")

    # ------------------------------------------------------------------
    # Session setup message
    # ------------------------------------------------------------------
    def session_configuration_request(self, model: str) -> str:
        """
        Return the JSON ``setup`` message for Vertex AI Live.

        Vertex AI requires the fully-qualified model path
        ``projects/{project}/locations/{location}/publishers/google/models/{model}``.
        The setup also enables automatic activity detection (server VAD) and
        input/output audio transcription so the proxy forwards transcript
        events.
        """
        from litellm.types.llms.gemini import BidiGenerateContentSetup
        from litellm.types.llms.vertex_ai import GeminiResponseModalities

        modalities: List[GeminiResponseModalities] = ["AUDIO"]
        qualified_model = (
            f"projects/{self._project}"
            f"/locations/{self._location}"
            f"/publishers/google/models/{model}"
        )
        setup: BidiGenerateContentSetup = {
            "model": qualified_model,
            "generationConfig": {"responseModalities": modalities},
            # Server-side VAD with sensible defaults for voice sessions.
            "realtimeInputConfig": {
                "automaticActivityDetection": {
                    "disabled": False,
                    "silenceDurationMs": 800,
                }
            },
            # Input transcript lets guardrails inspect user speech.
            "inputAudioTranscription": {},
            # Output transcript lets clients read what the model said.
            "outputAudioTranscription": {},
        }
        return json.dumps({"setup": setup})

    # ------------------------------------------------------------------
    # Request translation
    # ------------------------------------------------------------------
    def transform_realtime_request(
        self,
        message: str,
        model: str,
        session_configuration_request: Optional[str] = None,
    ) -> List[str]:
        """
        Translate OpenAI realtime client messages into Vertex AI format.

        ``session.update`` is swallowed (returns ``[]``) because Vertex AI
        accepts only a single ``setup`` message per connection — sending a
        second one causes a 1007 close error. The initial setup (sent
        automatically before bidirectional_forward) already enables AUDIO
        modality and server VAD, so nothing remains to configure.
        """
        parsed = json.loads(message)
        if parsed.get("type") == "session.update":
            # Never forward a second setup — Vertex AI rejects it.
            return []
        return super().transform_realtime_request(
            message, model, session_configuration_request
        )

View File

@@ -0,0 +1,5 @@
"""
Vertex AI Rerank - uses `llm_http_handler.py` to make httpx requests
Request/Response transformation is handled in `transformation.py`
"""

View File

@@ -0,0 +1,248 @@
"""
Translates from Cohere's `/v1/rerank` input format to Vertex AI Discovery Engine's `/rank` input format.
Why separate file? Make it easy to see how transformation works
"""
from typing import Any, Dict, List, Optional, Union
import httpx
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
from litellm.secret_managers.main import get_secret_str
from litellm.types.rerank import (
RerankResponse,
RerankResponseMeta,
RerankBilledUnits,
RerankResponseResult,
)
class VertexAIRerankConfig(BaseRerankConfig, VertexBase):
    """
    Configuration for the Vertex AI Discovery Engine Rerank API.

    Translates Cohere-style `/v1/rerank` requests into Discovery Engine
    `rankingConfigs:rank` calls and maps the responses back.

    Reference: https://cloud.google.com/generative-ai-app-builder/docs/ranking#rank_or_rerank_a_set_of_records_according_to_a_query
    """

    def __init__(self) -> None:
        super().__init__()

    def get_complete_url(
        self,
        api_base: Optional[str],
        model: str,
        optional_params: Optional[Dict] = None,
    ) -> str:
        """
        Build the complete URL for the Discovery Engine ranking API.

        Resolves the GCP project ID from (in order): credentials,
        VERTEXAI_PROJECT env var, litellm.vertex_project.

        Raises:
            ValueError: If no project ID can be resolved.
        """
        params = optional_params or {}
        # Pass copies so the safe_get_* helpers cannot mutate caller state.
        vertex_credentials = self.safe_get_vertex_ai_credentials(params.copy())
        vertex_project = self.safe_get_vertex_ai_project(params.copy())
        # _ensure_access_token also extracts the project_id from credentials
        # when one was not passed explicitly (same approach as embeddings).
        _, vertex_project = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai",
        )
        # Fallback to environment or litellm config
        project_id = (
            vertex_project
            or get_secret_str("VERTEXAI_PROJECT")
            or litellm.vertex_project
        )
        if not project_id:
            raise ValueError(
                "Vertex AI project ID is required. Please set 'VERTEXAI_PROJECT', 'litellm.vertex_project', or pass 'vertex_project' parameter"
            )
        return f"https://discoveryengine.googleapis.com/v1/projects/{project_id}/locations/global/rankingConfigs/default_ranking_config:rank"

    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
        optional_params: Optional[Dict] = None,
    ) -> dict:
        """
        Build authenticated headers for the Discovery Engine API.

        Caller-supplied headers take precedence over the generated defaults
        (including Authorization, when explicitly provided).
        """
        # optional_params carries vertex_credentials / vertex_project etc.
        litellm_params = optional_params.copy() if optional_params else {}
        vertex_credentials = self.safe_get_vertex_ai_credentials(litellm_params)
        vertex_project = self.safe_get_vertex_ai_project(litellm_params)
        # Get an OAuth2 access token via the base class helper
        access_token, project_id = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai",
        )
        default_headers = {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
            "X-Goog-User-Project": project_id,
        }
        # Caller-supplied headers win over the generated defaults.
        return {**default_headers, **headers}

    def transform_rerank_request(
        self,
        model: str,
        optional_rerank_params: Dict,
        headers: dict,
    ) -> dict:
        """
        Transform a Cohere-format rerank request into Discovery Engine format.

        Documents become `records` with sequential string IDs (their original
        index), so the response can be mapped back to Cohere-style indices.

        Raises:
            ValueError: If 'query' or 'documents' is missing.
        """
        if "query" not in optional_rerank_params:
            raise ValueError("query is required for Vertex AI rerank")
        if "documents" not in optional_rerank_params:
            raise ValueError("documents is required for Vertex AI rerank")
        query = optional_rerank_params["query"]
        documents = optional_rerank_params["documents"]
        top_n = optional_rerank_params.get("top_n", None)
        return_documents = optional_rerank_params.get("return_documents", True)
        # Convert documents to Discovery Engine records
        records = []
        for idx, document in enumerate(documents):
            if isinstance(document, str):
                content = document
                title = " ".join(document.split()[:3])  # First 3 words as title
            else:
                # Dict documents: prefer explicit text/title keys
                content = document.get("text", str(document))
                title = document.get("title", " ".join(content.split()[:3]))
            records.append({"id": str(idx), "title": title, "content": content})
        request_data = {"model": model, "query": query, "records": records}
        if top_n is not None:
            request_data["topN"] = top_n
        # Map return_documents to ignoreRecordDetailsInResponse:
        # return_documents=False means "only return record IDs".
        request_data["ignoreRecordDetailsInResponse"] = not return_documents
        return request_data

    def transform_rerank_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: RerankResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        request_data: dict = {},
        optional_params: dict = {},
        litellm_params: dict = {},
    ) -> RerankResponse:
        """
        Transform a Discovery Engine response into Cohere rerank format.

        Record IDs (assigned as the original document indices in
        transform_rerank_request) become Cohere result indices; results are
        sorted by relevance score, descending.

        Raises:
            ValueError: If the response body is not valid JSON.
        """
        try:
            raw_response_json = raw_response.json()
        except Exception as e:
            raise ValueError(f"Failed to parse response: {e}")
        records = raw_response_json.get("records", [])
        # Build Cohere-style result dicts. When the request was sent with
        # ignoreRecordDetailsInResponse=true the API omits scores entirely,
        # so a default of 1.0 is used for every record.
        results = []
        for record in records:
            score = record["score"] if "score" in record else 1.0
            results.append(
                {"index": int(record["id"]), "relevance_score": score}
            )
        # Cohere responses are ordered by relevance (descending)
        results.sort(key=lambda x: x["relevance_score"], reverse=True)
        rerank_results = [
            RerankResponseResult(
                index=result["index"], relevance_score=result["relevance_score"]
            )
            for result in results
        ]
        meta = RerankResponseMeta(
            billed_units=RerankBilledUnits(search_units=len(records))
        )
        return RerankResponse(
            id=f"vertex_ai_rerank_{model}", results=rerank_results, meta=meta
        )

    def get_supported_cohere_rerank_params(self, model: str) -> list:
        """Return the Cohere rerank params this provider understands."""
        return [
            "query",
            "documents",
            "top_n",
            "return_documents",
        ]

    def map_cohere_rerank_params(
        self,
        non_default_params: dict,
        model: str,
        drop_params: bool,
        query: str,
        documents: List[Union[str, Dict[str, Any]]],
        custom_llm_provider: Optional[str] = None,
        top_n: Optional[int] = None,
        rank_fields: Optional[List[str]] = None,
        return_documents: Optional[bool] = True,
        max_chunks_per_doc: Optional[int] = None,
        max_tokens_per_doc: Optional[int] = None,
    ) -> Dict:
        """
        Map Cohere rerank params into the dict consumed by
        transform_rerank_request. Explicit non_default_params override the
        named arguments on conflict.
        """
        result = {
            "query": query,
            "documents": documents,
            "top_n": top_n,
            "return_documents": return_documents,
        }
        result.update(non_default_params)
        return result

View File

@@ -0,0 +1,244 @@
from typing import Optional, Union
import httpx
from typing_extensions import TypedDict
import litellm
from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
get_async_httpx_client,
)
from litellm.llms.openai.openai import HttpxBinaryResponseContent
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
class VertexInput(TypedDict, total=False):
    """Synthesis input for Google Cloud Text-to-Speech (the `input` field).

    Callers supply either plain text or SSML (presumably one of the two is
    required — `validate_vertex_input` appears to enforce this; confirm).
    """

    # Plain text to synthesize.
    text: Optional[str]
    # SSML markup to synthesize (alternative to `text`).
    ssml: Optional[str]
class VertexVoice(TypedDict, total=False):
    """Voice selection for Google Cloud Text-to-Speech (the `voice` field)."""

    # BCP-47 language code, e.g. "en-US".
    languageCode: str
    # Full voice name, e.g. "en-US-Studio-O".
    name: str
class VertexAudioConfig(TypedDict, total=False):
    """Audio output settings for Text-to-Speech (the `audioConfig` field)."""

    # Output encoding, e.g. "LINEAR16".
    audioEncoding: str
    # Speaking rate. NOTE(review): typed and sent as a string here (e.g. "1"),
    # while the TTS REST API documents `speakingRate` as a number — confirm
    # the API accepts string values.
    speakingRate: str
class VertexTextToSpeechRequest(TypedDict, total=False):
    """Request body for the Text-to-Speech `text:synthesize` REST endpoint.

    Mirrors https://cloud.google.com/text-to-speech/docs/reference/rest/v1/text/synthesize
    """

    # Text or SSML to synthesize.
    input: VertexInput
    # Voice selection (language code + voice name).
    voice: VertexVoice
    # Output audio settings; optional.
    audioConfig: Optional[VertexAudioConfig]
class VertexTextToSpeechAPI(VertexLLM):
    """
    Handler for Google Cloud Text-to-Speech requests authenticated with
    Vertex AI credentials. Supports sync and async synthesis via the
    ``text:synthesize`` REST endpoint.
    """

    def __init__(self) -> None:
        super().__init__()

    def audio_speech(
        self,
        logging_obj,
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        model: str,
        input: str,
        voice: Optional[dict] = None,
        _is_async: Optional[bool] = False,
        optional_params: Optional[dict] = None,
        kwargs: Optional[dict] = None,
    ) -> HttpxBinaryResponseContent:
        """
        Synthesize ``input`` into audio and return the binary result.

        Authenticates with Vertex AI, builds a ``text:synthesize`` payload,
        and POSTs it. When ``_is_async`` is True, returns the coroutine from
        ``async_audio_speech`` instead of a response.

        Raises:
            Exception: on any non-200 response from the TTS API.
        """
        import base64

        ####### Authenticate with Vertex AI ########
        _auth_header, vertex_project = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai_beta",
        )
        auth_header, _ = self._get_token_and_url(
            model="",
            auth_header=_auth_header,
            gemini_api_key=None,
            vertex_credentials=vertex_credentials,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            stream=False,
            custom_llm_provider="vertex_ai_beta",
            api_base=api_base,
        )
        headers = {
            "Authorization": f"Bearer {auth_header}",
            "x-goog-user-project": vertex_project,
            "Content-Type": "application/json",
            "charset": "UTF-8",
        }
        ######### End of Authentication ###########

        ####### Build the request ################
        # API Ref: https://cloud.google.com/text-to-speech/docs/reference/rest/v1/text/synthesize
        kwargs = kwargs or {}
        optional_params = optional_params or {}

        vertex_input = VertexInput(text=input)
        # Normalizes text vs. ssml in place (may move `text` -> `ssml`).
        validate_vertex_input(vertex_input, kwargs, optional_params)
        # required param
        if voice is not None:
            vertex_voice = VertexVoice(**voice)
        elif "voice" in kwargs:
            vertex_voice = VertexVoice(**kwargs["voice"])
        else:
            # use defaults to not fail the request
            vertex_voice = VertexVoice(
                languageCode="en-US",
                name="en-US-Studio-O",
            )

        if "audioConfig" in kwargs:
            vertex_audio_config = VertexAudioConfig(**kwargs["audioConfig"])
        else:
            # use defaults to not fail the request
            vertex_audio_config = VertexAudioConfig(
                audioEncoding="LINEAR16",
                speakingRate="1",
            )

        request = VertexTextToSpeechRequest(
            input=vertex_input,
            voice=vertex_voice,
            audioConfig=vertex_audio_config,
        )

        url = "https://texttospeech.googleapis.com/v1/text:synthesize"
        ########## End of building request ############

        ########## Log the request for debugging / logging ############
        logging_obj.pre_call(
            input=[],
            api_key="",
            additional_args={
                "complete_input_dict": request,
                "api_base": url,
                "headers": headers,
            },
        )
        ########## End of logging ############

        ####### Send the request ###################
        if _is_async is True:
            return self.async_audio_speech(  # type:ignore
                logging_obj=logging_obj, url=url, headers=headers, request=request
            )
        sync_handler = _get_httpx_client()

        response = sync_handler.post(
            url=url,
            headers=headers,
            json=request,  # type: ignore
        )
        if response.status_code != 200:
            raise Exception(
                f"Request failed with status code {response.status_code}, {response.text}"
            )
        ############ Process the response ############
        _json_response = response.json()

        response_content = _json_response["audioContent"]

        # Decode base64 to get binary content
        binary_data = base64.b64decode(response_content)

        # Create an httpx.Response object
        response = httpx.Response(
            status_code=200,
            content=binary_data,
        )

        # Initialize the HttpxBinaryResponseContent instance
        http_binary_response = HttpxBinaryResponseContent(response)
        return http_binary_response

    async def async_audio_speech(
        self,
        logging_obj,
        url: str,
        headers: dict,
        request: VertexTextToSpeechRequest,
    ) -> HttpxBinaryResponseContent:
        """
        Async variant: POST the prebuilt synthesize request and decode the
        base64 ``audioContent`` field into binary audio.

        Raises:
            Exception: on any non-200 response from the TTS API.
        """
        import base64

        async_handler = get_async_httpx_client(
            llm_provider=litellm.LlmProviders.VERTEX_AI
        )

        response = await async_handler.post(
            url=url,
            headers=headers,
            json=request,  # type: ignore
        )

        if response.status_code != 200:
            raise Exception(
                f"Request did not return a 200 status code: {response.status_code}, {response.text}"
            )

        _json_response = response.json()

        response_content = _json_response["audioContent"]

        # Decode base64 to get binary content
        binary_data = base64.b64decode(response_content)

        # Create an httpx.Response object
        response = httpx.Response(
            status_code=200,
            content=binary_data,
        )

        # Initialize the HttpxBinaryResponseContent instance
        http_binary_response = HttpxBinaryResponseContent(response)
        return http_binary_response
def validate_vertex_input(
    input_data: VertexInput, kwargs: dict, optional_params: dict
) -> None:
    """
    Normalize *input_data* in place so that exactly one of ``text``/``ssml``
    is populated for the synthesize request.

    - Keys whose value is None are dropped.
    - With ``use_ssml`` (from kwargs or optional_params), ``text`` is moved
      to ``ssml``; missing SSML raises.
    - Otherwise, plain text containing a ``<speak>`` tag is auto-promoted
      to ``ssml``.

    Raises:
        ValueError: if neither key survives normalization, if both remain
            set, or if ``use_ssml`` is requested without usable input.
    """
    # Drop keys explicitly set to None so the payload only carries real values.
    for field in ("text", "ssml"):
        if input_data.get(field) is None:
            input_data.pop(field, None)

    use_ssml = kwargs.get("use_ssml", optional_params.get("use_ssml", False))
    if use_ssml:
        if "text" in input_data:
            input_data["ssml"] = input_data.pop("text")
        elif "ssml" not in input_data:
            raise ValueError("SSML input is required when use_ssml is True.")
    elif input_data:
        # LiteLLM will auto-detect if text is in ssml format
        # check if "text" is an ssml - in this case we should pass it as ssml instead of text
        plain_text = input_data.get("text", None) or ""
        if "<speak>" in plain_text:
            input_data["ssml"] = input_data.pop("text")

    if not input_data:
        raise ValueError("Either 'text' or 'ssml' must be provided.")
    if "text" in input_data and "ssml" in input_data:
        raise ValueError("Only one of 'text' or 'ssml' should be provided, not both.")

View File

@@ -0,0 +1,479 @@
"""
Vertex AI Text-to-Speech transformation
Maps OpenAI TTS spec to Google Cloud Text-to-Speech API
Reference: https://cloud.google.com/text-to-speech/docs/reference/rest/v1/text/synthesize
"""
import base64
from typing import TYPE_CHECKING, Any, Coroutine, Dict, Optional, Tuple, Union
import httpx
from litellm.llms.base_llm.text_to_speech.transformation import (
BaseTextToSpeechConfig,
TextToSpeechRequestData,
)
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
from litellm.types.llms.vertex_ai_text_to_speech import (
VertexTextToSpeechAudioConfig,
VertexTextToSpeechInput,
VertexTextToSpeechVoice,
)
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.types.llms.openai import HttpxBinaryResponseContent
else:
LiteLLMLoggingObj = Any
HttpxBinaryResponseContent = Any
class VertexAITextToSpeechConfig(BaseTextToSpeechConfig, VertexBase):
    """
    Configuration for Google Cloud/Vertex AI Text-to-Speech
    Reference: https://cloud.google.com/text-to-speech/docs/reference/rest/v1/text/synthesize
    """

    # Default values
    DEFAULT_LANGUAGE_CODE = "en-US"
    DEFAULT_VOICE_NAME = "en-US-Studio-O"
    DEFAULT_AUDIO_ENCODING = "LINEAR16"
    # speakingRate is serialized as a string in the request payload.
    DEFAULT_SPEAKING_RATE = "1"

    # API endpoint
    TTS_API_URL = "https://texttospeech.googleapis.com/v1/text:synthesize"

    # Voice name mappings from OpenAI voices to Google Cloud voices
    # Users can pass either:
    # 1. OpenAI voice names (alloy, echo, fable, onyx, nova, shimmer) - will be mapped
    # 2. Google Cloud/Vertex AI voice names (en-US-Studio-O, en-US-Wavenet-D, etc.) - used directly
    VOICE_MAPPINGS = {
        "alloy": "en-US-Studio-O",
        "echo": "en-US-Studio-M",
        "fable": "en-GB-Studio-B",
        "onyx": "en-US-Wavenet-D",
        "nova": "en-US-Studio-O",
        "shimmer": "en-US-Wavenet-F",
    }

    # Response format mappings from OpenAI to Google Cloud audio encoding
    FORMAT_MAPPINGS = {
        "mp3": "MP3",
        "opus": "OGG_OPUS",
        "aac": "MP3",  # Google doesn't have AAC, use MP3
        "flac": "FLAC",
        "wav": "LINEAR16",
        "pcm": "LINEAR16",
    }

    def __init__(self) -> None:
        # Initialize both bases explicitly (no cooperative super() chain here).
        BaseTextToSpeechConfig.__init__(self)
        VertexBase.__init__(self)

    def _map_voice_to_vertex_format(
        self,
        voice: Optional[Union[str, Dict]],
    ) -> Tuple[Optional[str], Optional[Dict]]:
        """
        Map voice to Vertex AI format.
        Supports both:
        1. OpenAI voice names (alloy, echo, fable, onyx, nova, shimmer) - will be mapped
        2. Vertex AI voice names (en-US-Studio-O, en-US-Wavenet-D, etc.) - used directly
        3. Dict with languageCode and name - used as-is
        Returns:
            Tuple of (voice_str, voice_dict) where:
            - voice_str: Original string voice (for interface compatibility)
            - voice_dict: Vertex AI format dict with languageCode and name
        """
        if voice is None:
            return None, None

        if isinstance(voice, dict):
            # Already in Vertex AI format
            return None, voice

        # voice is a string
        voice_str = voice
        # Map OpenAI voice if it's a known OpenAI voice, otherwise use directly
        if voice in self.VOICE_MAPPINGS:
            mapped_voice_name = self.VOICE_MAPPINGS[voice]
        else:
            # Assume it's already a Vertex AI voice name
            mapped_voice_name = voice

        # Extract language code from voice name (e.g., "en-US-Studio-O" -> "en-US")
        parts = mapped_voice_name.split("-")
        if len(parts) >= 2:
            language_code = f"{parts[0]}-{parts[1]}"
        else:
            # Fallback when the name doesn't encode a locale prefix.
            language_code = self.DEFAULT_LANGUAGE_CODE

        voice_dict = {
            "languageCode": language_code,
            "name": mapped_voice_name,
        }
        return voice_str, voice_dict

    def dispatch_text_to_speech(
        self,
        model: str,
        input: str,
        voice: Optional[Union[str, Dict]],
        optional_params: Dict,
        litellm_params_dict: Dict,
        logging_obj: "LiteLLMLoggingObj",
        timeout: Union[float, httpx.Timeout],
        extra_headers: Optional[Dict[str, Any]],
        base_llm_http_handler: Any,
        aspeech: bool,
        api_base: Optional[str],
        api_key: Optional[str],
        **kwargs: Any,
    ) -> Union[
        "HttpxBinaryResponseContent",
        Coroutine[Any, Any, "HttpxBinaryResponseContent"],
    ]:
        """
        Dispatch method to handle Vertex AI TTS requests
        This method encapsulates Vertex AI-specific credential resolution and parameter handling.
        Voice mapping is handled in map_openai_params (similar to Azure AVA pattern).
        Args:
            base_llm_http_handler: The BaseLLMHTTPHandler instance from main.py
        """
        # Resolve Vertex AI credentials using VertexBase helpers
        vertex_credentials = self.safe_get_vertex_ai_credentials(litellm_params_dict)
        vertex_project = self.safe_get_vertex_ai_project(litellm_params_dict)
        vertex_location = self.safe_get_vertex_ai_location(litellm_params_dict)

        # Convert voice to string if it's a dict (extract name)
        # Actual voice mapping happens in map_openai_params
        voice_str: Optional[str] = None
        if isinstance(voice, str):
            voice_str = voice
        elif isinstance(voice, dict):
            # Extract voice name from dict if needed
            voice_str = voice.get("name") if voice else None

        # Store credentials in litellm_params for use in transform methods
        litellm_params_dict.update(
            {
                "vertex_credentials": vertex_credentials,
                "vertex_project": vertex_project,
                "vertex_location": vertex_location,
                "api_base": api_base,
            }
        )

        # Call the text_to_speech_handler
        response = base_llm_http_handler.text_to_speech_handler(
            model=model,
            input=input,
            voice=voice_str,
            text_to_speech_provider_config=self,
            text_to_speech_optional_params=optional_params,
            custom_llm_provider="vertex_ai",
            litellm_params=litellm_params_dict,
            logging_obj=logging_obj,
            timeout=timeout,
            extra_headers=extra_headers,
            client=None,
            _is_async=aspeech,
        )
        return response

    def get_supported_openai_params(self, model: str) -> list:
        """
        Vertex AI TTS supports these OpenAI parameters
        Note: Vertex AI also supports additional parameters like audioConfig
        which can be passed but are not part of the OpenAI spec
        """
        return ["voice", "response_format", "speed"]

    def map_openai_params(
        self,
        model: str,
        optional_params: Dict,
        voice: Optional[Union[str, Dict]] = None,
        drop_params: bool = False,
        kwargs: Dict = {},
    ) -> Tuple[Optional[str], Dict]:
        """
        Map OpenAI parameters to Vertex AI TTS parameters
        Voice handling (similar to Azure AVA):
        - If voice is an OpenAI voice name (alloy, echo, etc.), it maps to a Vertex AI voice
        - If voice is already a Vertex AI voice name (en-US-Studio-O, etc.), it's used directly
        - If voice is a dict with languageCode and name, it's used as-is
        Note: For Vertex AI, voice dict is stored in mapped_params["vertex_voice_dict"]
        because the base class interface expects voice to be a string.
        Returns:
            Tuple of (mapped_voice_str, mapped_params)
        """
        mapped_params: Dict[str, Any] = {}

        ##########################################################
        # Map voice using helper
        ##########################################################
        mapped_voice_str, voice_dict = self._map_voice_to_vertex_format(voice)
        if voice_dict is not None:
            mapped_params["vertex_voice_dict"] = voice_dict

        # Map response format
        if "response_format" in optional_params:
            format_name = optional_params["response_format"]
            if format_name in self.FORMAT_MAPPINGS:
                mapped_params["audioEncoding"] = self.FORMAT_MAPPINGS[format_name]
            else:
                # Try to use it directly as Google Cloud format
                mapped_params["audioEncoding"] = format_name
        else:
            # Default to LINEAR16
            mapped_params["audioEncoding"] = self.DEFAULT_AUDIO_ENCODING

        # Map speed (OpenAI: 0.25-4.0, Vertex AI: speakingRate 0.25-4.0)
        if "speed" in optional_params:
            speed = optional_params["speed"]
            if speed is not None:
                # Vertex expects speakingRate as a string.
                mapped_params["speakingRate"] = str(speed)

        # Pass through Vertex AI-specific parameters from kwargs
        if "audioConfig" in kwargs:
            mapped_params["audioConfig"] = kwargs["audioConfig"]
        if "use_ssml" in kwargs:
            mapped_params["use_ssml"] = kwargs["use_ssml"]

        return mapped_voice_str, mapped_params

    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """
        Validate Vertex AI environment and set up authentication headers
        Note: Actual authentication is handled in transform_text_to_speech_request
        because Vertex AI requires OAuth2 token refresh
        """
        validated_headers = headers.copy()
        # Content-Type for JSON
        validated_headers["Content-Type"] = "application/json"
        validated_headers["charset"] = "UTF-8"
        return validated_headers

    def get_complete_url(
        self,
        model: str,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        Get the complete URL for Vertex AI TTS request
        Google Cloud TTS endpoint: https://texttospeech.googleapis.com/v1/text:synthesize
        """
        if api_base:
            return api_base
        return self.TTS_API_URL

    def _validate_vertex_input(
        self,
        input_data: VertexTextToSpeechInput,
        optional_params: Dict,
    ) -> VertexTextToSpeechInput:
        """
        Validate and transform input for Vertex AI TTS
        Handles text vs SSML input detection and validation

        Mutates and returns ``input_data``; raises ValueError when neither
        or both of text/ssml remain after normalization.
        """
        # Remove None values
        if input_data.get("text") is None:
            input_data.pop("text", None)
        if input_data.get("ssml") is None:
            input_data.pop("ssml", None)

        # Check if use_ssml is set
        use_ssml = optional_params.get("use_ssml", False)

        if use_ssml:
            if "text" in input_data:
                input_data["ssml"] = input_data.pop("text")
            elif "ssml" not in input_data:
                raise ValueError("SSML input is required when use_ssml is True.")
        else:
            # LiteLLM will auto-detect if text is in ssml format
            # check if "text" is an ssml - in this case we should pass it as ssml instead of text
            if input_data:
                _text = input_data.get("text", None) or ""
                if "<speak>" in _text:
                    input_data["ssml"] = input_data.pop("text")

        if not input_data:
            raise ValueError("Either 'text' or 'ssml' must be provided.")
        if "text" in input_data and "ssml" in input_data:
            raise ValueError(
                "Only one of 'text' or 'ssml' should be provided, not both."
            )

        return input_data

    def transform_text_to_speech_request(
        self,
        model: str,
        input: str,
        voice: Optional[str],
        optional_params: Dict,
        litellm_params: Dict,
        headers: dict,
    ) -> TextToSpeechRequestData:
        """
        Transform OpenAI TTS request to Vertex AI TTS format
        This method handles:
        1. Authentication with Vertex AI
        2. Building the request body
        3. Setting up headers
        Returns:
            TextToSpeechRequestData: Contains dict_body and headers
        """
        # Get Vertex AI credentials from litellm_params
        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES] = litellm_params.get(
            "vertex_credentials"
        )
        vertex_project: Optional[str] = litellm_params.get("vertex_project")

        ####### Authenticate with Vertex AI ########
        _auth_header, vertex_project = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai_beta",
        )
        auth_header, _ = self._get_token_and_url(
            model="",
            auth_header=_auth_header,
            gemini_api_key=None,
            vertex_credentials=vertex_credentials,
            vertex_project=vertex_project,
            vertex_location=litellm_params.get("vertex_location"),
            stream=False,
            custom_llm_provider="vertex_ai_beta",
            api_base=litellm_params.get("api_base"),
        )

        # Set authentication headers
        headers["Authorization"] = f"Bearer {auth_header}"
        headers["x-goog-user-project"] = vertex_project

        ####### Build the request ################
        vertex_input = VertexTextToSpeechInput(text=input)
        # May move `text` -> `ssml` in place; raises on invalid combinations.
        vertex_input = self._validate_vertex_input(vertex_input, optional_params)

        # Build voice configuration
        # Check for voice dict stored in:
        # 1. litellm_params by dispatch method
        # 2. optional_params by map_openai_params
        voice_dict = litellm_params.get("vertex_voice_dict") or optional_params.get(
            "vertex_voice_dict"
        )

        if voice_dict is not None and isinstance(voice_dict, dict):
            vertex_voice = VertexTextToSpeechVoice(**voice_dict)
        elif voice is not None and isinstance(voice, str):
            # Handle string voice (shouldn't normally happen if dispatch was called)
            parts = voice.split("-")
            if len(parts) >= 2:
                language_code = f"{parts[0]}-{parts[1]}"
            else:
                language_code = self.DEFAULT_LANGUAGE_CODE
            vertex_voice = VertexTextToSpeechVoice(
                languageCode=language_code,
                name=voice,
            )
        else:
            # Use defaults
            vertex_voice = VertexTextToSpeechVoice(
                languageCode=self.DEFAULT_LANGUAGE_CODE,
                name=self.DEFAULT_VOICE_NAME,
            )

        # Build audio configuration
        audio_encoding = optional_params.get(
            "audioEncoding", self.DEFAULT_AUDIO_ENCODING
        )
        speaking_rate = optional_params.get("speakingRate", self.DEFAULT_SPEAKING_RATE)

        # Check for full audioConfig in optional_params
        if "audioConfig" in optional_params:
            # A full audioConfig overrides the individually-mapped settings.
            vertex_audio_config = VertexTextToSpeechAudioConfig(
                **optional_params["audioConfig"]
            )
        else:
            vertex_audio_config = VertexTextToSpeechAudioConfig(
                audioEncoding=audio_encoding,
                speakingRate=speaking_rate,
            )

        request_body: Dict[str, Any] = {
            "input": dict(vertex_input),
            "voice": dict(vertex_voice),
            "audioConfig": dict(vertex_audio_config),
        }

        return TextToSpeechRequestData(
            dict_body=request_body,
            headers=headers,
        )

    def transform_text_to_speech_response(
        self,
        model: str,
        raw_response: httpx.Response,
        logging_obj: "LiteLLMLoggingObj",
    ) -> "HttpxBinaryResponseContent":
        """
        Transform Vertex AI TTS response to standard format
        Vertex AI returns JSON with base64-encoded audio content.
        We decode it and return as HttpxBinaryResponseContent.
        """
        from litellm.types.llms.openai import HttpxBinaryResponseContent

        # Parse JSON response
        _json_response = raw_response.json()

        # Get base64-encoded audio content
        response_content = _json_response.get("audioContent")
        if not response_content:
            raise ValueError("No audioContent in Vertex AI TTS response")

        # Decode base64 to get binary content
        binary_data = base64.b64decode(response_content)

        # Create an httpx.Response object with the binary data
        response = httpx.Response(
            status_code=200,
            content=binary_data,
        )

        # Initialize the HttpxBinaryResponseContent instance
        return HttpxBinaryResponseContent(response)

View File

@@ -0,0 +1,4 @@
from .rag_api.transformation import VertexVectorStoreConfig
from .search_api.transformation import VertexSearchAPIVectorStoreConfig
__all__ = ["VertexVectorStoreConfig", "VertexSearchAPIVectorStoreConfig"]

View File

@@ -0,0 +1,306 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import httpx
from litellm.llms.base_llm.vector_store.transformation import BaseVectorStoreConfig
from litellm.llms.vertex_ai.common_utils import get_vertex_base_url
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
from litellm.types.router import GenericLiteLLMParams
from litellm.types.vector_stores import (
BaseVectorStoreAuthCredentials,
VectorStoreCreateOptionalRequestParams,
VectorStoreCreateResponse,
VectorStoreIndexEndpoints,
VectorStoreResultContent,
VectorStoreSearchOptionalRequestParams,
VectorStoreSearchResponse,
VectorStoreSearchResult,
)
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class VertexVectorStoreConfig(BaseVectorStoreConfig, VertexBase):
    """
    Configuration for Vertex AI Vector Store RAG API
    This implementation uses the Vertex AI RAG Engine API for vector store operations.
    """

    def __init__(self):
        super().__init__()

    def get_auth_credentials(
        self, litellm_params: dict
    ) -> BaseVectorStoreAuthCredentials:
        """Build Authorization headers from the configured Vertex AI credentials."""
        # Get credentials and project info
        vertex_credentials = self.get_vertex_ai_credentials(dict(litellm_params))
        vertex_project = self.get_vertex_ai_project(dict(litellm_params))

        # Get access token using the base class method
        access_token, project_id = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai",
        )

        return {
            "headers": {
                "Authorization": f"Bearer {access_token}",
                "Content-Type": "application/json",
            },
        }

    def get_vector_store_endpoints_by_type(self) -> VectorStoreIndexEndpoints:
        # Search uses :retrieveContexts; corpus creation posts to /ragCorpora.
        return {
            "read": [("POST", ":retrieveContexts")],
            "write": [("POST", "/ragCorpora")],
        }

    def validate_environment(
        self, headers: dict, litellm_params: Optional[GenericLiteLLMParams]
    ) -> dict:
        """
        Validate and set up authentication for Vertex AI RAG API
        """
        litellm_params = litellm_params or GenericLiteLLMParams()
        auth_headers = self.get_auth_credentials(litellm_params.model_dump())
        headers.update(auth_headers.get("headers", {}))
        return headers

    def get_complete_url(
        self,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        Get the Base endpoint for Vertex AI RAG API
        """
        vertex_location = self.get_vertex_ai_location(litellm_params)
        vertex_project = self.get_vertex_ai_project(litellm_params)
        if api_base:
            return api_base.rstrip("/")

        # Vertex AI RAG API endpoint for retrieveContexts
        base_url = get_vertex_base_url(vertex_location)
        return f"{base_url}/v1/projects/{vertex_project}/locations/{vertex_location}"

    def transform_search_vector_store_request(
        self,
        vector_store_id: str,
        query: Union[str, List[str]],
        vector_store_search_optional_params: VectorStoreSearchOptionalRequestParams,
        api_base: str,
        litellm_logging_obj: LiteLLMLoggingObj,
        litellm_params: dict,
    ) -> Tuple[str, Dict[str, Any]]:
        """
        Transform search request for Vertex AI RAG API

        Returns (url, request_body) for POST {api_base}:retrieveContexts.
        """
        # Convert query to string if it's a list
        if isinstance(query, list):
            query = " ".join(query)

        # Vertex AI RAG API endpoint for retrieving contexts
        url = f"{api_base}:retrieveContexts"

        # Use helper methods to get project and location, then construct full rag corpus path
        vertex_project = self.get_vertex_ai_project(litellm_params)
        vertex_location = self.get_vertex_ai_location(litellm_params)

        # Handle both full corpus path and just corpus ID
        if vector_store_id.startswith("projects/"):
            # Already a full path
            full_rag_corpus = vector_store_id
        else:
            # Just the corpus ID, construct full path
            full_rag_corpus = f"projects/{vertex_project}/locations/{vertex_location}/ragCorpora/{vector_store_id}"

        # Build the request body for Vertex AI RAG API
        request_body: Dict[str, Any] = {
            "vertex_rag_store": {"rag_resources": [{"rag_corpus": full_rag_corpus}]},
            "query": {"text": query},
        }

        #########################################################
        # Update logging object with details of the request
        #########################################################
        litellm_logging_obj.model_call_details["query"] = query

        # Add optional parameters
        max_num_results = vector_store_search_optional_params.get("max_num_results")
        if max_num_results is not None:
            request_body["query"]["rag_retrieval_config"] = {"top_k": max_num_results}

        # Add filters if provided
        filters = vector_store_search_optional_params.get("filters")
        if filters is not None:
            if "rag_retrieval_config" not in request_body["query"]:
                request_body["query"]["rag_retrieval_config"] = {}
            request_body["query"]["rag_retrieval_config"]["filter"] = filters

        # Add ranking options if provided
        ranking_options = vector_store_search_optional_params.get("ranking_options")
        if ranking_options is not None:
            if "rag_retrieval_config" not in request_body["query"]:
                request_body["query"]["rag_retrieval_config"] = {}
            request_body["query"]["rag_retrieval_config"]["ranking"] = ranking_options

        return url, request_body

    def transform_search_vector_store_response(
        self, response: httpx.Response, litellm_logging_obj: LiteLLMLoggingObj
    ) -> VectorStoreSearchResponse:
        """
        Transform Vertex AI RAG API response to standard vector store search response
        """
        try:
            response_json = response.json()

            # Extract contexts from Vertex AI response - handle nested structure
            contexts = response_json.get("contexts", {}).get("contexts", [])

            # Transform contexts to standard format
            search_results = []
            for context in contexts:
                content = [
                    VectorStoreResultContent(
                        text=context.get("text", ""),
                        type="text",
                    )
                ]

                # Extract file information
                source_uri = context.get("sourceUri", "")
                source_display_name = context.get("sourceDisplayName", "")

                # Generate file_id from source URI or use display name as fallback
                file_id = source_uri if source_uri else source_display_name
                filename = (
                    source_display_name if source_display_name else "Unknown Document"
                )

                # Build attributes with available metadata
                attributes = {}
                if source_uri:
                    attributes["sourceUri"] = source_uri
                if source_display_name:
                    attributes["sourceDisplayName"] = source_display_name

                # Add page span information if available
                page_span = context.get("pageSpan", {})
                if page_span:
                    attributes["pageSpan"] = page_span

                result = VectorStoreSearchResult(
                    score=context.get("score", 0.0),
                    content=content,
                    file_id=file_id,
                    filename=filename,
                    attributes=attributes,
                )
                search_results.append(result)

            return VectorStoreSearchResponse(
                object="vector_store.search_results.page",
                search_query=litellm_logging_obj.model_call_details.get("query", ""),
                data=search_results,
            )
        except Exception as e:
            # Surface parse/shape errors through the provider error class.
            raise self.get_error_class(
                error_message=str(e),
                status_code=response.status_code,
                headers=response.headers,
            )

    def transform_create_vector_store_request(
        self,
        vector_store_create_optional_params: VectorStoreCreateOptionalRequestParams,
        api_base: str,
    ) -> Tuple[str, Dict[str, Any]]:
        """
        Transform create request for Vertex AI RAG Corpus

        Returns (url, request_body) for POST {api_base}/ragCorpora.
        """
        url = f"{api_base}/ragCorpora"  # Base URL for creating RAG corpus

        # Build the request body for Vertex AI RAG Corpus creation
        request_body: Dict[str, Any] = {
            "display_name": vector_store_create_optional_params.get(
                "name", "litellm-vector-store"
            ),
            "description": "Vector store created via LiteLLM",
        }

        # Add metadata if provided
        metadata = vector_store_create_optional_params.get("metadata")
        if metadata is not None:
            request_body["labels"] = metadata

        return url, request_body

    def transform_create_vector_store_response(
        self, response: httpx.Response
    ) -> VectorStoreCreateResponse:
        """
        Transform Vertex AI RAG Corpus creation response to standard vector store response
        """
        try:
            response_json = response.json()

            # Extract the corpus ID from the response name
            corpus_name = response_json.get("name", "")
            corpus_id = (
                corpus_name.split("/")[-1] if "/" in corpus_name else corpus_name
            )

            # Handle createTime conversion
            create_time = response_json.get("createTime", 0)
            if isinstance(create_time, str):
                # Convert ISO timestamp to Unix timestamp
                from datetime import datetime

                try:
                    dt = datetime.fromisoformat(create_time.replace("Z", "+00:00"))
                    create_time = int(dt.timestamp())
                except ValueError:
                    create_time = 0
            elif not isinstance(create_time, int):
                create_time = 0

            # Handle labels safely
            labels = response_json.get("labels", {})
            metadata = labels if isinstance(labels, dict) else {}

            return VectorStoreCreateResponse(
                id=corpus_id,
                object="vector_store",
                created_at=create_time,
                name=response_json.get("display_name", ""),
                bytes=0,  # Vertex AI doesn't provide byte count in the same way
                file_counts={
                    "in_progress": 0,
                    "completed": 0,
                    "failed": 0,
                    "cancelled": 0,
                    "total": 0,
                },
                status="completed",  # Vertex AI corpus creation is typically synchronous
                expires_after=None,
                expires_at=None,
                last_active_at=None,
                metadata=metadata,
            )
        except Exception as e:
            raise self.get_error_class(
                error_message=str(e),
                status_code=response.status_code,
                headers=response.headers,
            )

View File

@@ -0,0 +1,267 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import httpx
from litellm import get_model_info
from litellm.llms.base_llm.vector_store.transformation import BaseVectorStoreConfig
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
from litellm.types.router import GenericLiteLLMParams
from litellm.types.vector_stores import (
BaseVectorStoreAuthCredentials,
VectorStoreCreateOptionalRequestParams,
VectorStoreCreateResponse,
VectorStoreIndexEndpoints,
VectorStoreResultContent,
VectorStoreSearchOptionalRequestParams,
VectorStoreSearchResponse,
VectorStoreSearchResult,
)
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class VertexSearchAPIVectorStoreConfig(BaseVectorStoreConfig, VertexBase):
"""
Configuration for Vertex AI Search API Vector Store
This implementation uses the Vertex AI Search API for vector store operations.
"""
    def __init__(self):
        # Initialize BaseVectorStoreConfig/VertexBase state via the MRO chain.
        super().__init__()
    def get_auth_credentials(
        self, litellm_params: dict
    ) -> BaseVectorStoreAuthCredentials:
        """Build Authorization headers from the configured Vertex AI credentials."""
        # Get credentials and project info
        vertex_credentials = self.get_vertex_ai_credentials(dict(litellm_params))
        vertex_project = self.get_vertex_ai_project(dict(litellm_params))

        # Get access token using the base class method
        access_token, project_id = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai",
        )

        return {
            "headers": {
                "Authorization": f"Bearer {access_token}",
                "Content-Type": "application/json",
            },
        }
def get_vector_store_endpoints_by_type(self) -> VectorStoreIndexEndpoints:
return {
"read": [("POST", ":search")],
"write": [],
}
def validate_environment(
self, headers: dict, litellm_params: Optional[GenericLiteLLMParams]
) -> dict:
"""
Validate and set up authentication for Vertex AI RAG API
"""
litellm_params = litellm_params or GenericLiteLLMParams()
auth_headers = self.get_auth_credentials(litellm_params.model_dump())
headers.update(auth_headers.get("headers", {}))
return headers
    def get_complete_url(
        self,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        Get the Base endpoint for Vertex AI Search API

        Returns either the caller-supplied api_base (trailing slash stripped)
        or the Discovery Engine servingConfig URL built from project,
        location, collection and datastore.

        Raises:
            ValueError: if ``vector_store_id`` is missing from litellm_params.
        """
        vertex_location = self.get_vertex_ai_location(litellm_params)
        vertex_project = self.get_vertex_ai_project(litellm_params)
        collection_id = (
            litellm_params.get("vertex_collection_id") or "default_collection"
        )
        datastore_id = litellm_params.get("vector_store_id")
        # Validated before the api_base short-circuit so a misconfigured
        # store fails fast on both paths.
        if not datastore_id:
            raise ValueError("vector_store_id is required")
        if api_base:
            return api_base.rstrip("/")

        # Vertex AI Search API endpoint for search
        return (
            f"https://discoveryengine.googleapis.com/v1/"
            f"projects/{vertex_project}/locations/{vertex_location}/"
            f"collections/{collection_id}/dataStores/{datastore_id}/servingConfigs/default_config"
        )
def transform_search_vector_store_request(
self,
vector_store_id: str,
query: Union[str, List[str]],
vector_store_search_optional_params: VectorStoreSearchOptionalRequestParams,
api_base: str,
litellm_logging_obj: LiteLLMLoggingObj,
litellm_params: dict,
) -> Tuple[str, Dict[str, Any]]:
"""
Transform search request for Vertex AI RAG API
"""
# Convert query to string if it's a list
if isinstance(query, list):
query = " ".join(query)
# Vertex AI RAG API endpoint for retrieving contexts
url = f"{api_base}:search"
# Construct full rag corpus path
# Build the request body for Vertex AI Search API
request_body = {"query": query, "pageSize": 10}
#########################################################
# Update logging object with details of the request
#########################################################
litellm_logging_obj.model_call_details["query"] = query
return url, request_body
def transform_search_vector_store_response(
    self, response: httpx.Response, litellm_logging_obj: LiteLLMLoggingObj
) -> VectorStoreSearchResponse:
    """
    Transform Vertex AI Search API response to standard vector store search response

    Handles the format from Discovery Engine Search API which returns:
    {
        "results": [
            {
                "id": "...",
                "document": {
                    "derivedStructData": {
                        "title": "...",
                        "link": "...",
                        "snippets": [...]
                    }
                }
            }
        ]
    }

    Raises:
        The provider error class (via ``self.get_error_class``) if the
        response cannot be parsed/transformed.
    """
    try:
        response_json = response.json()
        # Extract results from Vertex AI Search API response
        results = response_json.get("results", [])
        # Transform results to standard format
        search_results: List[VectorStoreSearchResult] = []
        # rank starts at 1 so scores decrease as 1.0, 0.5, 0.333, ...
        for rank, result in enumerate(results, start=1):
            document = result.get("document", {})
            derived_data = document.get("derivedStructData", {})
            # Extract text content from snippets
            snippets = derived_data.get("snippets", [])
            text_content = ""
            if snippets:
                # Combine all snippets into one text
                text_parts = [
                    snippet.get("snippet", snippet.get("htmlSnippet", ""))
                    for snippet in snippets
                ]
                text_content = " ".join(text_parts)
            # If no snippets, use title as fallback
            if not text_content:
                text_content = derived_data.get("title", "")
            content = [
                VectorStoreResultContent(
                    text=text_content,
                    type="text",
                )
            ]
            # Extract file/document information
            document_link = derived_data.get("link", "")
            document_title = derived_data.get("title", "")
            document_id = result.get("id", "")
            # Use link as file_id if available, otherwise use document ID
            file_id = document_link if document_link else document_id
            filename = document_title if document_title else "Unknown Document"
            # Build attributes with available metadata
            attributes = {
                "document_id": document_id,
            }
            if document_link:
                attributes["link"] = document_link
            if document_title:
                attributes["title"] = document_title
            # Add display link if available
            display_link = derived_data.get("displayLink", "")
            if display_link:
                attributes["displayLink"] = display_link
            # Add formatted URL if available
            formatted_url = derived_data.get("formattedUrl", "")
            if formatted_url:
                attributes["formattedUrl"] = formatted_url
            # Note: Search API doesn't provide explicit scores in the response.
            # Use the result's rank as an implicit, decreasing score
            # (previously computed via `search_results.__len__() + 1`).
            score = 1.0 / float(rank)
            result_obj = VectorStoreSearchResult(
                score=score,
                content=content,
                file_id=file_id,
                filename=filename,
                attributes=attributes,
            )
            search_results.append(result_obj)
        return VectorStoreSearchResponse(
            object="vector_store.search_results.page",
            search_query=litellm_logging_obj.model_call_details.get("query", ""),
            data=search_results,
        )
    except Exception as e:
        raise self.get_error_class(
            error_message=str(e),
            status_code=response.status_code,
            headers=response.headers,
        )
def transform_create_vector_store_request(
    self,
    vector_store_create_optional_params: VectorStoreCreateOptionalRequestParams,
    api_base: str,
) -> Tuple[str, Dict]:
    # Vector store creation is not supported by this provider; callers
    # should treat NotImplementedError as "search-only backend".
    raise NotImplementedError
def transform_create_vector_store_response(
    self, response: httpx.Response
) -> VectorStoreCreateResponse:
    # Counterpart of transform_create_vector_store_request — creation is
    # not supported by this provider.
    raise NotImplementedError
def calculate_vector_store_cost(
    self,
    response: VectorStoreSearchResponse,
) -> Tuple[float, float]:
    """
    Return ``(query_cost, storage_cost)`` for one Vertex AI Search call.

    Cost is a flat per-query rate looked up from the litellm model-info
    registry; storage cost is always 0.0 here.
    """
    search_api_info = get_model_info(model="vertex_ai/search_api")
    per_query = search_api_info.get("input_cost_per_query") or 0.0
    return per_query, 0.0

View File

@@ -0,0 +1,123 @@
"""
AWS Workload Identity Federation (WIF) auth for Vertex AI.
Handles explicit AWS credentials for GCP WIF token exchange,
bypassing the EC2 instance metadata service.
When aws_* keys are present in the WIF credential JSON, this module
uses BaseAWSLLM to obtain AWS credentials and wraps them in a custom
AwsSecurityCredentialsSupplier for google-auth.
"""
from typing import Dict
GOOGLE_IMPORT_ERROR_MESSAGE = (
    "Google Cloud SDK not found. Install it with: pip install 'litellm[google]' "
    "or pip install google-cloud-aiplatform"
)

# AWS params recognized in WIF credential JSON for explicit auth.
# These match the kwargs accepted by BaseAWSLLM.get_credentials().
_AWS_CREDENTIAL_KEYS = frozenset(
    {
        "aws_access_key_id",
        "aws_secret_access_key",
        "aws_session_token",
        "aws_region_name",
        "aws_session_name",
        "aws_profile_name",
        "aws_role_name",
        "aws_web_identity_token",
        "aws_sts_endpoint",
        "aws_external_id",
    }
)


class VertexAIAwsWifAuth:
    """
    Handles AWS-to-GCP Workload Identity Federation credential creation
    for Vertex AI, using explicit AWS credentials rather than EC2 metadata.
    """

    @staticmethod
    def extract_aws_params(json_obj: dict) -> Dict[str, str]:
        """
        Pull the LiteLLM-specific aws_* keys out of a WIF credential JSON dict.

        Returns a mapping of recognized aws_* keys to their values; empty
        dict when the JSON carries none of them.
        """
        return {
            key: value
            for key, value in json_obj.items()
            if key in _AWS_CREDENTIAL_KEYS
        }

    @staticmethod
    def credentials_from_explicit_aws(json_obj, aws_params, scopes):
        """
        Create GCP credentials using explicit AWS credentials for WIF.

        BaseAWSLLM resolves the AWS side (static keys, profile, STS
        AssumeRole, ...); the result is wrapped in a custom
        AwsSecurityCredentialsSupplier so that google-auth never touches
        the EC2 metadata service.

        Args:
            json_obj: The WIF credential JSON dict (audience, token_url, etc.)
            aws_params: Dict of aws_* params extracted from json_obj
            scopes: OAuth scopes for the GCP credentials
        """
        try:
            from google.auth import aws
        except ImportError:
            raise ImportError(GOOGLE_IMPORT_ERROR_MESSAGE)

        from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
        from litellm.llms.vertex_ai.aws_credentials_supplier import (
            AwsCredentialsSupplier,
        )

        # Region is mandatory for the GCP token exchange; validate it up
        # front so a misconfiguration never triggers AWS API calls
        # (e.g. STS AssumeRole).
        region = aws_params.get("aws_region_name")
        if not region:
            raise ValueError(
                "aws_region_name is required in the WIF credential JSON "
                "when using explicit AWS authentication. Add "
                '"aws_region_name": "<your-region>" to your credential file.'
            )

        # Re-resolve AWS creds on every call so rotated/refreshed STS
        # tokens are picked up when google-auth refreshes the GCP token
        # in long-running processes.
        aws_llm = BaseAWSLLM()
        params_snapshot = dict(aws_params)  # avoid mutating caller's dict

        def _resolve_aws_credentials():
            return aws_llm.get_credentials(**params_snapshot)

        # Custom supplier with a lazy credentials provider.
        supplier = AwsCredentialsSupplier(
            credentials_provider=_resolve_aws_credentials,
            aws_region=region,
        )

        # Assemble kwargs for aws.Credentials, forwarding optional fields
        # from the credential JSON.
        credential_kwargs = {
            "audience": json_obj.get("audience"),
            "subject_token_type": json_obj.get("subject_token_type"),
            "token_url": json_obj.get("token_url"),
            "credential_source": None,  # metadata endpoints are bypassed
            "aws_security_credentials_supplier": supplier,
            "service_account_impersonation_url": json_obj.get(
                "service_account_impersonation_url"
            ),
        }
        # universe_domain defaults to googleapis.com when absent.
        if "universe_domain" in json_obj:
            credential_kwargs["universe_domain"] = json_obj["universe_domain"]

        creds = aws.Credentials(**credential_kwargs)
        if scopes and hasattr(creds, "requires_scopes") and creds.requires_scopes:
            creds = creds.with_scopes(scopes)
        return creds

View File

@@ -0,0 +1,788 @@
import json
import os
import time
from typing import Any, Callable, Optional, cast
import httpx
import litellm
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.bedrock.common_utils import ModelResponseIterator
from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS
from litellm.types.llms.vertex_ai import *
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
class VertexAIError(Exception):
    """
    Error raised for Vertex AI (legacy SDK) call failures.

    Carries the HTTP ``status_code`` plus synthetic httpx request/response
    objects so callers can treat it like other provider errors.
    """

    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        # FIX: URL previously had a stray leading space (" https://...")
        # which produced a malformed request URL.
        self.request = httpx.Request(
            method="POST", url="https://cloud.google.com/vertex-ai/"
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs
class TextStreamer:
    """
    Fake streaming iterator for Vertex AI Model Garden calls.

    Splits ``text`` on whitespace and yields one word per step; supports
    both sync (`for`) and async (`async for`) iteration over the same
    underlying cursor.
    """

    def __init__(self, text):
        self.text = text.split()  # let's assume words as a streaming unit
        self.index = 0

    def _advance(self):
        # Shared cursor logic for both iteration protocols; raises
        # IndexError when exhausted.
        word = self.text[self.index]
        self.index += 1
        return word

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return self._advance()
        except IndexError:
            raise StopIteration

    def __aiter__(self):
        return self

    async def __anext__(self):
        try:
            return self._advance()
        except IndexError:
            # once we run out of data to stream, we raise this error
            raise StopAsyncIteration
def _get_client_cache_key(
model: str, vertex_project: Optional[str], vertex_location: Optional[str]
):
_cache_key = f"{model}-{vertex_project}-{vertex_location}"
return _cache_key
def _get_client_from_cache(client_cache_key: str):
    """Look up a previously cached Vertex AI client object for this key (miss behavior is delegated to the in-memory cache)."""
    return litellm.in_memory_llm_clients_cache.get_cache(client_cache_key)
def _set_client_in_cache(client_cache_key: str, vertex_llm_model: Any):
    """Cache a Vertex AI client under ``client_cache_key``.

    Entries expire after the default httpx-client TTL so stale clients
    are eventually rebuilt.
    """
    litellm.in_memory_llm_clients_cache.set_cache(
        key=client_cache_key,
        value=vertex_llm_model,
        ttl=_DEFAULT_TTL_FOR_HTTPX_CLIENTS,
    )
def completion(  # noqa: PLR0915
    model: str,
    messages: list,
    model_response: ModelResponse,
    print_verbose: Callable,
    encoding,
    logging_obj,
    optional_params: dict,
    vertex_project=None,
    vertex_location=None,
    vertex_credentials=None,
    litellm_params=None,
    logger_fn=None,
    acompletion: bool = False,
):
    """
    NON-GEMINI/ANTHROPIC CALLS.

    This is the handler for OLDER PALM MODELS and VERTEX AI MODEL GARDEN

    For Vertex AI Anthropic: `vertex_anthropic.py`
    For Gemini: `vertex_httpx.py`
    """
    try:
        import vertexai
    except Exception:
        raise VertexAIError(
            status_code=400,
            message="vertexai import failed please run `pip install google-cloud-aiplatform`. This is required for the 'vertex_ai/' route on LiteLLM",
        )

    # FIX: this guard previously used `or`, which (a) skipped the
    # `language_models` check whenever `preview` existed and (b) raised an
    # unhandled AttributeError on `vertexai.preview` when it didn't.
    # `and` short-circuits safely and validates both attributes.
    if not (
        hasattr(vertexai, "preview") and hasattr(vertexai.preview, "language_models")
    ):
        raise VertexAIError(
            status_code=400,
            message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""",
        )
    try:
        import google.auth  # type: ignore
        from google.cloud import aiplatform  # type: ignore
        from google.cloud.aiplatform_v1beta1.types import (
            content as gapic_content_types,  # type: ignore
        )
        from google.protobuf import json_format  # type: ignore
        from google.protobuf.struct_pb2 import Value  # type: ignore
        from vertexai.language_models import CodeGenerationModel, TextGenerationModel
        from vertexai.preview.generative_models import GenerativeModel
        from vertexai.preview.language_models import ChatModel, CodeChatModel

        ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
        print_verbose(
            f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}"
        )

        _cache_key = _get_client_cache_key(
            model=model, vertex_project=vertex_project, vertex_location=vertex_location
        )
        _vertex_llm_model_object = _get_client_from_cache(client_cache_key=_cache_key)

        # Load credentials - needed for both vertexai.init() and PredictionServiceClient
        from google.auth.credentials import Credentials

        if vertex_credentials is not None and isinstance(vertex_credentials, str):
            import google.oauth2.service_account

            json_obj = json.loads(vertex_credentials)

            creds = google.oauth2.service_account.Credentials.from_service_account_info(
                json_obj,
                scopes=["https://www.googleapis.com/auth/cloud-platform"],
            )
        else:
            creds, _ = google.auth.default(quota_project_id=vertex_project)

        if _vertex_llm_model_object is None:
            print_verbose(
                f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}"
            )
            vertexai.init(
                project=vertex_project,
                location=vertex_location,
                credentials=cast(Credentials, creds),
            )

        ## Load Config
        config = litellm.VertexAIConfig.get_config()
        for k, v in config.items():
            if k not in optional_params:
                optional_params[k] = v

        ## Process safety settings into format expected by vertex AI
        safety_settings = None
        if "safety_settings" in optional_params:
            safety_settings = optional_params.pop("safety_settings")
            if not isinstance(safety_settings, list):
                raise ValueError("safety_settings must be a list")
            if len(safety_settings) > 0 and not isinstance(safety_settings[0], dict):
                raise ValueError("safety_settings must be a list of dicts")
            safety_settings = [
                gapic_content_types.SafetySetting(x) for x in safety_settings
            ]

        # vertexai does not use an API key, it looks for credentials.json in the environment

        # Flatten all string message contents into one prompt (legacy PaLM
        # models take a single text prompt, not a message list).
        prompt = " ".join(
            [
                message.get("content")
                for message in messages
                if isinstance(message.get("content", None), str)
            ]
        )

        mode = ""

        request_str = ""
        response_obj = None
        instances = None
        client_options = {
            "api_endpoint": f"{vertex_location}-aiplatform.googleapis.com"
        }
        fake_stream = False
        # Dispatch on model family: vision/chat/text/code models use the SDK
        # wrappers; "private" and unknown models go through Model Garden
        # prediction endpoints.
        if (
            model in litellm.vertex_language_models
            or model in litellm.vertex_vision_models
        ):
            llm_model: Any = _vertex_llm_model_object or GenerativeModel(model)
            mode = "vision"
            request_str += f"llm_model = GenerativeModel({model})\n"
        elif model in litellm.vertex_chat_models:
            llm_model = _vertex_llm_model_object or ChatModel.from_pretrained(model)
            mode = "chat"
            request_str += f"llm_model = ChatModel.from_pretrained({model})\n"
        elif model in litellm.vertex_text_models:
            llm_model = _vertex_llm_model_object or TextGenerationModel.from_pretrained(
                model
            )
            mode = "text"
            request_str += f"llm_model = TextGenerationModel.from_pretrained({model})\n"
        elif model in litellm.vertex_code_text_models:
            llm_model = _vertex_llm_model_object or CodeGenerationModel.from_pretrained(
                model
            )
            mode = "text"
            request_str += f"llm_model = CodeGenerationModel.from_pretrained({model})\n"
            fake_stream = True
        elif model in litellm.vertex_code_chat_models:  # vertex_code_llm_models
            llm_model = _vertex_llm_model_object or CodeChatModel.from_pretrained(model)
            mode = "chat"
            request_str += f"llm_model = CodeChatModel.from_pretrained({model})\n"
        elif model == "private":
            mode = "private"
            model = optional_params.pop("model_id", None)
            # private endpoint requires a dict instead of JSON
            instances = [optional_params.copy()]
            instances[0]["prompt"] = prompt
            llm_model = aiplatform.PrivateEndpoint(
                endpoint_name=model,
                project=vertex_project,
                location=vertex_location,
            )
            request_str += f"llm_model = aiplatform.PrivateEndpoint(endpoint_name={model}, project={vertex_project}, location={vertex_location})\n"
        else:  # assume vertex model garden on public endpoint
            mode = "custom"

            instances = [optional_params.copy()]
            instances[0]["prompt"] = prompt
            instances = [
                json_format.ParseDict(instance_dict, Value())  # type: ignore[misc]
                for instance_dict in instances
            ]
            # Will determine the API used based on async parameter
            llm_model = None

        # NOTE: async prediction and streaming under "private" mode isn't supported by aiplatform right now
        if acompletion is True:
            data = {
                "llm_model": llm_model,
                "mode": mode,
                "prompt": prompt,
                "logging_obj": logging_obj,
                "request_str": request_str,
                "model": model,
                "model_response": model_response,
                "encoding": encoding,
                "messages": messages,
                "print_verbose": print_verbose,
                "client_options": client_options,
                "instances": instances,
                "vertex_location": vertex_location,
                "vertex_project": vertex_project,
                "vertex_credentials": creds,
                "safety_settings": safety_settings,
                **optional_params,
            }
            if optional_params.get("stream", False) is True:
                # async streaming
                return async_streaming(**data)

            return async_completion(**data)

        completion_response = None
        stream = optional_params.pop(
            "stream", None
        )  # See note above on handling streaming for vertex ai
        if mode == "chat":
            chat = llm_model.start_chat()
            request_str += "chat = llm_model.start_chat()\n"

            if fake_stream is not True and stream is True:
                # NOTE: VertexAI does not accept stream=True as a param and raises an error,
                # we handle this by removing 'stream' from optional params and sending the request
                # after we get the response we add optional_params["stream"] = True, since main.py needs to know it's a streaming response to then transform it for the OpenAI format
                optional_params.pop(
                    "stream", None
                )  # vertex ai raises an error when passing stream in optional params
                request_str += (
                    f"chat.send_message_streaming({prompt}, **{optional_params})\n"
                )
                ## LOGGING
                logging_obj.pre_call(
                    input=prompt,
                    api_key=None,
                    additional_args={
                        "complete_input_dict": optional_params,
                        "request_str": request_str,
                    },
                )

                model_response = chat.send_message_streaming(prompt, **optional_params)

                return model_response

            request_str += f"chat.send_message({prompt}, **{optional_params}).text\n"
            ## LOGGING
            logging_obj.pre_call(
                input=prompt,
                api_key=None,
                additional_args={
                    "complete_input_dict": optional_params,
                    "request_str": request_str,
                },
            )
            completion_response = chat.send_message(prompt, **optional_params).text
        elif mode == "text":
            if fake_stream is not True and stream is True:
                request_str += (
                    f"llm_model.predict_streaming({prompt}, **{optional_params})\n"
                )
                ## LOGGING
                logging_obj.pre_call(
                    input=prompt,
                    api_key=None,
                    additional_args={
                        "complete_input_dict": optional_params,
                        "request_str": request_str,
                    },
                )
                model_response = llm_model.predict_streaming(prompt, **optional_params)

                return model_response

            request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n"
            ## LOGGING
            logging_obj.pre_call(
                input=prompt,
                api_key=None,
                additional_args={
                    "complete_input_dict": optional_params,
                    "request_str": request_str,
                },
            )
            completion_response = llm_model.predict(prompt, **optional_params).text
        elif mode == "custom":
            """
            Vertex AI Model Garden
            """
            if vertex_project is None or vertex_location is None:
                raise ValueError(
                    "Vertex project and location are required for custom endpoint"
                )

            ## LOGGING
            logging_obj.pre_call(
                input=prompt,
                api_key=None,
                additional_args={
                    "complete_input_dict": optional_params,
                    "request_str": request_str,
                },
            )
            llm_model = aiplatform.gapic.PredictionServiceClient(
                client_options=client_options,
                credentials=creds,  # type: ignore[arg-type]
            )
            request_str += f"llm_model = aiplatform.gapic.PredictionServiceClient(client_options={client_options}, credentials=...)\n"
            endpoint_path = llm_model.endpoint_path(
                project=vertex_project, location=vertex_location, endpoint=model
            )
            request_str += (
                f"llm_model.predict(endpoint={endpoint_path}, instances={instances})\n"
            )
            response = llm_model.predict(
                endpoint=endpoint_path, instances=instances
            ).predictions

            completion_response = response[0]
            if (
                isinstance(completion_response, str)
                and "\nOutput:\n" in completion_response
            ):
                completion_response = completion_response.split("\nOutput:\n", 1)[1]
            if stream is True:
                response = TextStreamer(completion_response)
                return response
        elif mode == "private":
            """
            Vertex AI Model Garden deployed on private endpoint
            """
            if instances is None:
                raise ValueError("instances are required for private endpoint")
            if llm_model is None:
                raise ValueError("Unable to pick client for private endpoint")
            ## LOGGING
            logging_obj.pre_call(
                input=prompt,
                api_key=None,
                additional_args={
                    "complete_input_dict": optional_params,
                    "request_str": request_str,
                },
            )
            request_str += f"llm_model.predict(instances={instances})\n"
            response = llm_model.predict(instances=instances).predictions

            completion_response = response[0]
            if (
                isinstance(completion_response, str)
                and "\nOutput:\n" in completion_response
            ):
                completion_response = completion_response.split("\nOutput:\n", 1)[1]
            if stream is True:
                response = TextStreamer(completion_response)
                return response

        ## LOGGING
        logging_obj.post_call(
            input=prompt, api_key=None, original_response=completion_response
        )

        ## RESPONSE OBJECT
        if isinstance(completion_response, litellm.Message):
            model_response.choices[0].message = completion_response  # type: ignore
        elif len(str(completion_response)) > 0:
            model_response.choices[0].message.content = str(completion_response)  # type: ignore
        model_response.created = int(time.time())
        model_response.model = model
        ## CALCULATING USAGE
        if model in litellm.vertex_language_models and response_obj is not None:
            model_response.choices[0].finish_reason = map_finish_reason(  # type: ignore[assignment]
                response_obj.candidates[0].finish_reason.name
            )
            usage = Usage(
                prompt_tokens=response_obj.usage_metadata.prompt_token_count,
                completion_tokens=response_obj.usage_metadata.candidates_token_count,
                total_tokens=response_obj.usage_metadata.total_token_count,
            )
        else:
            # init prompt tokens
            # this block attempts to get usage from response_obj if it exists, if not it uses the litellm token counter
            prompt_tokens, completion_tokens, _ = 0, 0, 0
            if response_obj is not None:
                if hasattr(response_obj, "usage_metadata") and hasattr(
                    response_obj.usage_metadata, "prompt_token_count"
                ):
                    prompt_tokens = response_obj.usage_metadata.prompt_token_count
                    completion_tokens = (
                        response_obj.usage_metadata.candidates_token_count
                    )
            else:
                prompt_tokens = len(encoding.encode(prompt))
                completion_tokens = len(
                    encoding.encode(
                        model_response["choices"][0]["message"].get("content", "")
                    )
                )

            usage = Usage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
            )
        setattr(model_response, "usage", usage)

        if fake_stream is True and stream is True:
            return ModelResponseIterator(model_response)

        return model_response
    except Exception as e:
        if isinstance(e, VertexAIError):
            raise e
        raise litellm.APIConnectionError(
            message=str(e), llm_provider="vertex_ai", model=model
        )
async def async_completion(  # noqa: PLR0915
    llm_model,
    mode: str,
    prompt: str,
    model: str,
    messages: list,
    model_response: ModelResponse,
    request_str: str,
    print_verbose: Callable,
    logging_obj,
    encoding,
    client_options=None,
    instances=None,
    vertex_project=None,
    vertex_location=None,
    vertex_credentials=None,
    safety_settings=None,
    **optional_params,
):
    """
    Add support for acompletion calls for gemini-pro

    Async, non-streaming path of the legacy Vertex AI handler. ``mode``
    selects the backend:
      - "chat": chat-style SDK model via ``send_message_async``
      - "text": text-generation SDK model via ``predict_async``
      - "custom": Model Garden public endpoint via the gapic async client
      - "private": Model Garden private endpoint via ``predict_async``

    Populates ``model_response`` in place and returns it; raises
    VertexAIError (status 500) on any failure.
    """
    try:
        response_obj = None
        completion_response = None
        if mode == "chat":
            # chat-bison etc.
            chat = llm_model.start_chat()
            ## LOGGING
            logging_obj.pre_call(
                input=prompt,
                api_key=None,
                additional_args={
                    "complete_input_dict": optional_params,
                    "request_str": request_str,
                },
            )
            response_obj = await chat.send_message_async(prompt, **optional_params)
            completion_response = response_obj.text
        elif mode == "text":
            # gecko etc.
            request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n"
            ## LOGGING
            logging_obj.pre_call(
                input=prompt,
                api_key=None,
                additional_args={
                    "complete_input_dict": optional_params,
                    "request_str": request_str,
                },
            )
            response_obj = await llm_model.predict_async(prompt, **optional_params)
            completion_response = response_obj.text
        elif mode == "custom":
            """
            Vertex AI Model Garden
            """
            from google.cloud import aiplatform  # type: ignore

            if vertex_project is None or vertex_location is None:
                raise ValueError(
                    "Vertex project and location are required for custom endpoint"
                )

            ## LOGGING
            logging_obj.pre_call(
                input=prompt,
                api_key=None,
                additional_args={
                    "complete_input_dict": optional_params,
                    "request_str": request_str,
                },
            )
            llm_model = aiplatform.gapic.PredictionServiceAsyncClient(
                client_options=client_options,
                credentials=vertex_credentials,
            )
            request_str += f"llm_model = aiplatform.gapic.PredictionServiceAsyncClient(client_options={client_options}, credentials=...)\n"
            endpoint_path = llm_model.endpoint_path(
                project=vertex_project, location=vertex_location, endpoint=model
            )
            request_str += (
                f"llm_model.predict(endpoint={endpoint_path}, instances={instances})\n"
            )
            response_obj = await llm_model.predict(
                endpoint=endpoint_path,
                instances=instances,
            )
            response = response_obj.predictions
            completion_response = response[0]
            # Some Model Garden models echo the prompt before "\nOutput:\n";
            # keep only the generated portion.
            if (
                isinstance(completion_response, str)
                and "\nOutput:\n" in completion_response
            ):
                completion_response = completion_response.split("\nOutput:\n", 1)[1]
        elif mode == "private":
            request_str += f"llm_model.predict_async(instances={instances})\n"
            response_obj = await llm_model.predict_async(
                instances=instances,
            )
            response = response_obj.predictions
            completion_response = response[0]
            if (
                isinstance(completion_response, str)
                and "\nOutput:\n" in completion_response
            ):
                completion_response = completion_response.split("\nOutput:\n", 1)[1]

        ## LOGGING
        logging_obj.post_call(
            input=prompt, api_key=None, original_response=completion_response
        )

        ## RESPONSE OBJECT
        if isinstance(completion_response, litellm.Message):
            model_response.choices[0].message = completion_response  # type: ignore
        elif len(str(completion_response)) > 0:
            model_response.choices[0].message.content = str(  # type: ignore
                completion_response
            )
        model_response.created = int(time.time())
        model_response.model = model
        ## CALCULATING USAGE
        if model in litellm.vertex_language_models and response_obj is not None:
            model_response.choices[0].finish_reason = map_finish_reason(  # type: ignore[assignment]
                response_obj.candidates[0].finish_reason.name
            )
            usage = Usage(
                prompt_tokens=response_obj.usage_metadata.prompt_token_count,
                completion_tokens=response_obj.usage_metadata.candidates_token_count,
                total_tokens=response_obj.usage_metadata.total_token_count,
            )
        else:
            # init prompt tokens
            # this block attempts to get usage from response_obj if it exists, if not it uses the litellm token counter
            prompt_tokens, completion_tokens, _ = 0, 0, 0
            if response_obj is not None and (
                hasattr(response_obj, "usage_metadata")
                and hasattr(response_obj.usage_metadata, "prompt_token_count")
            ):
                prompt_tokens = response_obj.usage_metadata.prompt_token_count
                completion_tokens = response_obj.usage_metadata.candidates_token_count
            else:
                prompt_tokens = len(encoding.encode(prompt))
                completion_tokens = len(
                    encoding.encode(
                        model_response["choices"][0]["message"].get("content", "")
                    )
                )

            # set usage
            usage = Usage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
            )
        setattr(model_response, "usage", usage)
        return model_response
    except Exception as e:
        raise VertexAIError(status_code=500, message=str(e))
async def async_streaming(  # noqa: PLR0915
    llm_model,
    mode: str,
    prompt: str,
    model: str,
    model_response: ModelResponse,
    messages: list,
    print_verbose: Callable,
    logging_obj,
    request_str: str,
    encoding=None,
    client_options=None,
    instances=None,
    vertex_project=None,
    vertex_location=None,
    vertex_credentials=None,
    safety_settings=None,
    **optional_params,
):
    """
    Add support for async streaming calls for gemini-pro

    Streams via the SDK's native async streaming for "chat"/"text" modes;
    "custom" and "private" Model Garden modes do not support streaming, so
    the full response is fetched and replayed through TextStreamer (fake
    streaming). The result is wrapped in CustomStreamWrapper for the
    OpenAI-format transformation.
    """
    response: Any = None
    if mode == "chat":
        chat = llm_model.start_chat()
        optional_params.pop(
            "stream", None
        )  # vertex ai raises an error when passing stream in optional params
        request_str += (
            f"chat.send_message_streaming_async({prompt}, **{optional_params})\n"
        )
        ## LOGGING
        logging_obj.pre_call(
            input=prompt,
            api_key=None,
            additional_args={
                "complete_input_dict": optional_params,
                "request_str": request_str,
            },
        )
        response = chat.send_message_streaming_async(prompt, **optional_params)
    elif mode == "text":
        optional_params.pop(
            "stream", None
        )  # See note above on handling streaming for vertex ai
        request_str += (
            f"llm_model.predict_streaming_async({prompt}, **{optional_params})\n"
        )
        ## LOGGING
        logging_obj.pre_call(
            input=prompt,
            api_key=None,
            additional_args={
                "complete_input_dict": optional_params,
                "request_str": request_str,
            },
        )
        response = llm_model.predict_streaming_async(prompt, **optional_params)
    elif mode == "custom":
        from google.cloud import aiplatform  # type: ignore

        if vertex_project is None or vertex_location is None:
            raise ValueError(
                "Vertex project and location are required for custom endpoint"
            )

        stream = optional_params.pop("stream", None)

        ## LOGGING
        logging_obj.pre_call(
            input=prompt,
            api_key=None,
            additional_args={
                "complete_input_dict": optional_params,
                "request_str": request_str,
            },
        )
        llm_model = aiplatform.gapic.PredictionServiceAsyncClient(
            client_options=client_options,
            credentials=vertex_credentials,
        )
        request_str += f"llm_model = aiplatform.gapic.PredictionServiceAsyncClient(client_options={client_options}, credentials=...)\n"
        endpoint_path = llm_model.endpoint_path(
            project=vertex_project, location=vertex_location, endpoint=model
        )
        request_str += (
            f"client.predict(endpoint={endpoint_path}, instances={instances})\n"
        )
        response_obj = await llm_model.predict(
            endpoint=endpoint_path,
            instances=instances,
        )

        response = response_obj.predictions
        completion_response = response[0]
        # Strip any echoed prompt before "\nOutput:\n" — keep only the
        # generated text.
        if (
            isinstance(completion_response, str)
            and "\nOutput:\n" in completion_response
        ):
            completion_response = completion_response.split("\nOutput:\n", 1)[1]
        if stream:
            response = TextStreamer(completion_response)
    elif mode == "private":
        if instances is None:
            raise ValueError("Instances are required for private endpoint")
        stream = optional_params.pop("stream", None)
        _ = instances[0].pop("stream", None)
        request_str += f"llm_model.predict_async(instances={instances})\n"
        response_obj = await llm_model.predict_async(
            instances=instances,
        )

        response = response_obj.predictions
        completion_response = response[0]
        if (
            isinstance(completion_response, str)
            and "\nOutput:\n" in completion_response
        ):
            completion_response = completion_response.split("\nOutput:\n", 1)[1]
        if stream:
            response = TextStreamer(completion_response)

    if response is None:
        raise ValueError("Unable to generate response")

    ## LOGGING
    logging_obj.post_call(input=prompt, api_key=None, original_response=response)

    streamwrapper = CustomStreamWrapper(
        completion_stream=response,
        model=model,
        custom_llm_provider="vertex_ai",
        logging_obj=logging_obj,
    )

    return streamwrapper

View File

@@ -0,0 +1,24 @@
from litellm.llms.base_llm.chat.transformation import BaseConfig
def get_vertex_ai_partner_model_config(
    model: str, vertex_publisher_or_api_spec: str
) -> BaseConfig:
    """
    Return the response-transformation config for a Vertex AI partner model.

    Imports are deferred to the matching branch so only the selected
    provider's module is loaded.
    """
    if vertex_publisher_or_api_spec == "anthropic":
        from .anthropic.transformation import VertexAIAnthropicConfig

        return VertexAIAnthropicConfig()
    if vertex_publisher_or_api_spec == "ai21":
        from .ai21.transformation import VertexAIAi21Config

        return VertexAIAi21Config()
    if vertex_publisher_or_api_spec in ("openapi", "mistralai"):
        from .llama3.transformation import VertexAILlama3Config

        return VertexAILlama3Config()
    raise ValueError(f"Unsupported model: {model}")

View File

@@ -0,0 +1,60 @@
import types
from typing import Optional
import litellm
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
class VertexAIAi21Config(OpenAIGPTConfig):
    """
    Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/ai21

    Configuration for VertexAI's AI21 API interface — supports all OpenAI
    parameters.
    """

    def __init__(
        self,
        max_tokens: Optional[int] = None,
    ) -> None:
        # litellm config convention: promote non-None constructor args to
        # class attributes so get_config() can pick them up.
        for name, val in locals().copy().items():
            if name == "self" or val is None:
                continue
            setattr(self.__class__, name, val)

    @classmethod
    def get_config(cls):
        excluded = (
            types.FunctionType,
            types.BuiltinFunctionType,
            classmethod,
            staticmethod,
        )
        return {
            attr: val
            for attr, val in cls.__dict__.items()
            if not attr.startswith("__")
            and not isinstance(val, excluded)
            and val is not None
        }

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ):
        # Vertex AI21 expects `max_tokens`; translate the OpenAI-style
        # `max_completion_tokens` alias before delegating to OpenAIConfig.
        if "max_completion_tokens" in non_default_params:
            non_default_params["max_tokens"] = non_default_params.pop(
                "max_completion_tokens"
            )
        return litellm.OpenAIConfig().map_openai_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
            drop_params=drop_params,
        )

View File

@@ -0,0 +1,167 @@
from typing import Any, Dict, List, Optional, Tuple
from litellm.llms.anthropic.common_utils import AnthropicModelInfo
from litellm.llms.anthropic.experimental_pass_through.messages.transformation import (
AnthropicMessagesConfig,
)
from litellm.types.llms.anthropic import (
ANTHROPIC_BETA_HEADER_VALUES,
ANTHROPIC_HOSTED_TOOLS,
)
from litellm.types.llms.anthropic_tool_search import get_tool_search_beta_header
from litellm.types.llms.vertex_ai import VertexPartnerProvider
from litellm.types.router import GenericLiteLLMParams
from ....vertex_llm_base import VertexBase
class VertexAIPartnerModelsAnthropicMessagesConfig(AnthropicMessagesConfig, VertexBase):
    """Anthropic /messages configuration for Claude models hosted on Vertex AI.

    Extends the generic Anthropic messages transformation with Vertex-specific
    concerns: OAuth bearer-token auth, Vertex endpoint URL construction, and
    the `anthropic-beta` headers Vertex requires for certain features.
    """

    def validate_anthropic_messages_environment(
        self,
        headers: dict,
        model: str,
        messages: List[Any],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> Tuple[dict, Optional[str]]:
        """
        OPTIONAL

        Validate the environment for the request.

        Resolves Vertex credentials/project/location, injects the bearer token
        (unless the caller already supplied an Authorization header), computes
        the Vertex `:rawPredict` / streaming URL when `api_base` is not given,
        and assembles any required `anthropic-beta` header values.

        Returns:
            Tuple of (headers, api_base) to use for the request.
        """
        vertex_ai_project = VertexBase.safe_get_vertex_ai_project(litellm_params)
        vertex_ai_location = VertexBase.safe_get_vertex_ai_location(litellm_params)
        project_id: Optional[str] = None
        if "Authorization" not in headers:
            vertex_credentials = VertexBase.safe_get_vertex_ai_credentials(
                litellm_params
            )
            # _ensure_access_token also resolves the effective project id
            # (from credentials when vertex_ai_project is None).
            access_token, project_id = self._ensure_access_token(
                credentials=vertex_credentials,
                project_id=vertex_ai_project,
                custom_llm_provider="vertex_ai",
            )
            headers["Authorization"] = f"Bearer {access_token}"
        else:
            # Authorization already in headers, but we still need project_id
            project_id = vertex_ai_project
        # Always calculate api_base if not provided, regardless of Authorization header
        if api_base is None:
            api_base = self.get_complete_vertex_url(
                custom_api_base=api_base,
                vertex_location=vertex_ai_location,
                vertex_project=vertex_ai_project,
                project_id=project_id or "",
                partner=VertexPartnerProvider.claude,
                stream=optional_params.get("stream", False),
                model=model,
            )
        headers["content-type"] = "application/json"
        # Add beta headers for Vertex AI
        tools = optional_params.get("tools", [])
        beta_values: set[str] = set()
        # Get existing beta headers if any (comma-separated string), and merge
        # rather than overwrite them.
        existing_beta = headers.get("anthropic-beta")
        if existing_beta:
            beta_values.update(b.strip() for b in existing_beta.split(","))
        # Check for context management
        context_management_param = optional_params.get("context_management")
        if context_management_param is not None:
            # Check edits array for compact_20260112 type
            edits = context_management_param.get("edits", [])
            has_compact = False
            has_other = False
            for edit in edits:
                edit_type = edit.get("type", "")
                if edit_type == "compact_20260112":
                    has_compact = True
                else:
                    has_other = True
            # Add compact header if any compact edits exist
            if has_compact:
                beta_values.add(ANTHROPIC_BETA_HEADER_VALUES.COMPACT_2026_01_12.value)
            # Add context management header if any other edits exist
            if has_other:
                beta_values.add(
                    ANTHROPIC_BETA_HEADER_VALUES.CONTEXT_MANAGEMENT_2025_06_27.value
                )
        # Check for web search tool
        for tool in tools:
            if isinstance(tool, dict) and tool.get("type", "").startswith(
                ANTHROPIC_HOSTED_TOOLS.WEB_SEARCH.value
            ):
                beta_values.add(
                    ANTHROPIC_BETA_HEADER_VALUES.WEB_SEARCH_2025_03_05.value
                )
                break
        # Check for tool search tools - Vertex AI uses different beta header
        anthropic_model_info = AnthropicModelInfo()
        if anthropic_model_info.is_tool_search_used(tools):
            beta_values.add(get_tool_search_beta_header("vertex_ai"))
        if beta_values:
            headers["anthropic-beta"] = ",".join(beta_values)
        return headers, api_base

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """Return the request URL; URL construction happens earlier in
        validate_anthropic_messages_environment, so api_base must be set."""
        if api_base is None:
            raise ValueError(
                "api_base is required. Unable to determine the correct api_base for the request."
            )
        return api_base  # no transformation is needed - handled in validate_environment

    def transform_anthropic_messages_request(
        self,
        model: str,
        messages: List[Dict],
        anthropic_messages_optional_request_params: Dict,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Dict:
        """Build the Anthropic request body, adjusted for Vertex AI:
        adds the required `anthropic_version` and strips fields Vertex
        rejects (`model`, `output_format`, `output_config`)."""
        anthropic_messages_request = super().transform_anthropic_messages_request(
            model=model,
            messages=messages,
            anthropic_messages_optional_request_params=anthropic_messages_optional_request_params,
            litellm_params=litellm_params,
            headers=headers,
        )
        anthropic_messages_request["anthropic_version"] = "vertex-2023-10-16"
        anthropic_messages_request.pop(
            "model", None
        )  # do not pass model in request body to vertex ai
        anthropic_messages_request.pop(
            "output_format", None
        )  # do not pass output_format in request body to vertex ai - vertex ai does not support output_format as yet
        anthropic_messages_request.pop(
            "output_config", None
        )  # do not pass output_config in request body to vertex ai - vertex ai does not support output_config
        return anthropic_messages_request

View File

@@ -0,0 +1,230 @@
# What is this?
## Handler file for calling claude-3 on vertex ai
from typing import Any, List, Optional
import httpx
import litellm
from litellm.llms.base_llm.chat.transformation import LiteLLMLoggingObj
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse
from ....anthropic.chat.transformation import AnthropicConfig
class VertexAIError(Exception):
    """Exception for Vertex AI call failures, carrying an HTTP status code
    plus a synthetic httpx request/response pair for callers that introspect
    them as they would on a real httpx error."""

    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        synthetic_request = httpx.Request(
            method="POST", url=" https://cloud.google.com/vertex-ai/"
        )
        self.request = synthetic_request
        self.response = httpx.Response(
            status_code=status_code, request=synthetic_request
        )
        # Initialize the Exception base with the human-readable message.
        super().__init__(self.message)
class VertexAIAnthropicConfig(AnthropicConfig):
    """
    Reference:https://docs.anthropic.com/claude/reference/messages_post
    Note that the API for Claude on Vertex differs from the Anthropic API documentation in the following ways:
    - `model` is not a valid parameter. The model is instead specified in the Google Cloud endpoint URL.
    - `anthropic_version` is a required parameter and must be set to "vertex-2023-10-16".
    The class `VertexAIAnthropicConfig` provides configuration for the VertexAI's Anthropic API interface. Below are the parameters:
    - `max_tokens` Required (integer) max tokens,
    - `anthropic_version` Required (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
    - `system` Optional (string) the system prompt, conversion from openai format to this is handled in factory.py
    - `temperature` Optional (float) The amount of randomness injected into the response
    - `top_p` Optional (float) Use nucleus sampling.
    - `top_k` Optional (int) Only sample from the top K options for each subsequent token
    - `stop_sequences` Optional (List[str]) Custom text sequences that cause the model to stop generating
    Note: Please make sure to modify the default parameters as required for your use case.
    """

    @property
    def custom_llm_provider(self) -> Optional[str]:
        """Provider tag used for routing/logging; always "vertex_ai" here."""
        return "vertex_ai"

    def _add_context_management_beta_headers(
        self, beta_set: set, context_management: dict
    ) -> None:
        """
        Add context_management beta headers to the beta_set.
        - If any edit has type "compact_20260112", add compact-2026-01-12 header
        - For all other edits, add context-management-2025-06-27 header
        Args:
            beta_set: Set of beta headers to modify in-place
            context_management: The context_management dict from optional_params
        """
        from litellm.types.llms.anthropic import ANTHROPIC_BETA_HEADER_VALUES
        edits = context_management.get("edits", [])
        has_compact = False
        has_other = False
        for edit in edits:
            edit_type = edit.get("type", "")
            if edit_type == "compact_20260112":
                has_compact = True
            else:
                has_other = True
        # Add compact header if any compact edits exist
        if has_compact:
            beta_set.add(ANTHROPIC_BETA_HEADER_VALUES.COMPACT_2026_01_12.value)
        # Add context management header if any other edits exist
        if has_other:
            beta_set.add(
                ANTHROPIC_BETA_HEADER_VALUES.CONTEXT_MANAGEMENT_2025_06_27.value
            )

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """Build the Anthropic request body for Vertex AI.

        Strips parameters Vertex rejects (`model`, `output_format`,
        `output_config`), then computes the union of auto-detected and
        caller-provided `anthropic-beta` values, writing them both into the
        request body (`anthropic_beta`) and the outgoing headers.
        """
        data = super().transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )
        data.pop("model", None)  # vertex anthropic doesn't accept 'model' parameter
        # VertexAI doesn't support output_format parameter, remove it if present
        data.pop("output_format", None)
        # VertexAI doesn't support output_config parameter, remove it if present
        data.pop("output_config", None)
        tools = optional_params.get("tools")
        tool_search_used = self.is_tool_search_used(tools)
        # Auto-detect beta headers implied by the request contents
        # (computer tools, prompt caching, file ids, MCP servers).
        auto_betas = self.get_anthropic_beta_list(
            model=model,
            optional_params=optional_params,
            computer_tool_used=self.is_computer_tool_used(tools),
            prompt_caching_set=self.is_cache_control_set(messages),
            file_id_used=self.is_file_id_used(messages),
            mcp_server_used=self.is_mcp_server_used(optional_params.get("mcp_servers")),
        )
        beta_set = set(auto_betas)
        if tool_search_used:
            beta_set.add(
                "tool-search-tool-2025-10-19"
            )  # Vertex requires this header for tool search
        # Add context_management beta headers (compact and/or context-management)
        context_management = optional_params.get("context_management")
        if context_management:
            self._add_context_management_beta_headers(beta_set, context_management)
        # Merge caller-supplied anthropic-beta values (string or list form)
        # from extra_headers into the computed set.
        extra_headers = optional_params.get("extra_headers") or {}
        anthropic_beta_value = extra_headers.get("anthropic-beta", "")
        if isinstance(anthropic_beta_value, str) and anthropic_beta_value:
            for beta in anthropic_beta_value.split(","):
                beta = beta.strip()
                if beta:
                    beta_set.add(beta)
        elif isinstance(anthropic_beta_value, list):
            beta_set.update(anthropic_beta_value)
        data.pop("extra_headers", None)
        if beta_set:
            data["anthropic_beta"] = list(beta_set)
            headers["anthropic-beta"] = ",".join(beta_set)
        return data

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        Override parent method to ensure VertexAI always uses tool-based structured outputs.
        VertexAI doesn't support the output_format parameter, so we force all models
        to use the tool-based approach for structured outputs.
        """
        # Temporarily override model name to force tool-based approach
        # This ensures Claude Sonnet 4.5 uses tools instead of output_format
        original_model = model
        if "response_format" in non_default_params:
            model = "claude-3-sonnet-20240229"  # Use a model that will use tool-based approach
        # Call parent method with potentially modified model name
        optional_params = super().map_openai_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
            drop_params=drop_params,
        )
        # Restore original model name for any other processing
        model = original_model
        return optional_params

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """Delegate to the base Anthropic transformation, then restore the
        caller's model name (Vertex responses do not echo it back)."""
        response = super().transform_response(
            model,
            raw_response,
            model_response,
            logging_obj,
            request_data,
            messages,
            optional_params,
            litellm_params,
            encoding,
            api_key,
            json_mode,
        )
        response.model = model
        return response

    @classmethod
    def is_supported_model(cls, model: str, custom_llm_provider: str) -> bool:
        """
        Check if the model is supported by the VertexAI Anthropic API.
        """
        if (
            custom_llm_provider != "vertex_ai"
            and custom_llm_provider != "vertex_ai_beta"
        ):
            return False
        if "claude" in model.lower():
            return True
        elif model in litellm.vertex_anthropic_models:
            return True
        return False

View File

@@ -0,0 +1 @@
# Count tokens handler for Vertex AI Partner Models (Anthropic, Mistral, etc.)

View File

@@ -0,0 +1,162 @@
"""
Token counter for Vertex AI Partner Models (Anthropic Claude, Mistral, etc.)
This handler provides token counting for partner models hosted on Vertex AI.
Unlike Gemini models which use Google's token counting API, partner models use
their respective publisher-specific count-tokens endpoints.
"""
from typing import Any, Dict, Optional
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
from litellm.llms.vertex_ai.common_utils import get_vertex_base_url
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
class VertexAIPartnerModelsTokenCounter(VertexBase):
    """
    Token counter for Vertex AI Partner Models.
    Handles token counting for models like Claude (Anthropic), Mistral, etc.
    that are available through Vertex AI's partner model program.
    """

    def _get_publisher_for_model(self, model: str) -> str:
        """
        Determine the publisher name for the given model.
        Args:
            model: The model name (e.g., "claude-3-5-sonnet-20241022")
        Returns:
            Publisher name to use in the Vertex AI endpoint URL
        Raises:
            ValueError: If the model is not a recognized partner model
        """
        # NOTE(review): matching here is case-sensitive, while the caller
        # lowercases the model for the claude-region check — confirm whether
        # mixed-case model names can reach this point.
        if "claude" in model:
            return "anthropic"
        elif "mistral" in model or "codestral" in model:
            return "mistralai"
        elif "llama" in model or "meta/" in model:
            return "meta"
        else:
            raise ValueError(f"Unknown partner model: {model}")

    def _build_count_tokens_endpoint(
        self,
        model: str,
        project_id: str,
        vertex_location: str,
        api_base: Optional[str] = None,
    ) -> str:
        """
        Build the count-tokens endpoint URL for a partner model.
        Args:
            model: The model name
            project_id: Google Cloud project ID
            vertex_location: Vertex AI location (e.g., "us-east5")
            api_base: Optional custom API base URL
        Returns:
            Full endpoint URL for the count-tokens API
        """
        publisher = self._get_publisher_for_model(model)
        # Use custom api_base if provided, otherwise construct default
        if api_base:
            base_url = api_base
        else:
            base_url = get_vertex_base_url(vertex_location)
        # Construct the count-tokens endpoint
        # Format: /v1/projects/{project}/locations/{location}/publishers/{publisher}/models/count-tokens:rawPredict
        endpoint = (
            f"{base_url}/v1/projects/{project_id}/locations/{vertex_location}/"
            f"publishers/{publisher}/models/count-tokens:rawPredict"
        )
        return endpoint

    async def handle_count_tokens_request(
        self,
        model: str,
        request_data: Dict[str, Any],
        litellm_params: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Handle token counting request for a Vertex AI partner model.
        Args:
            model: The model name
            request_data: Request payload (Anthropic Messages API format)
            litellm_params: LiteLLM parameters containing credentials, project, location
        Returns:
            Dict containing token count information
        Raises:
            ValueError: If required parameters are missing or invalid
        """
        # Validate request
        if "messages" not in request_data:
            raise ValueError("messages required for token counting")
        # Extract Vertex AI credentials and settings
        vertex_credentials = self.get_vertex_ai_credentials(litellm_params)
        vertex_project = self.get_vertex_ai_project(litellm_params)
        vertex_location = self.get_vertex_ai_location(litellm_params)
        # Map empty location/claude models to a supported region for count-tokens endpoint
        # https://docs.cloud.google.com/vertex-ai/generative-ai/docs/partner-models/claude/count-tokens
        if not vertex_location or "claude" in model.lower():
            vertex_location = "us-central1"
        # Get access token and resolved project ID
        access_token, project_id = await self._ensure_access_token_async(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai",
        )
        # Build the endpoint URL
        endpoint_url = self._build_count_tokens_endpoint(
            model=model,
            project_id=project_id,
            vertex_location=vertex_location,
            api_base=litellm_params.get("api_base"),
        )
        # Prepare headers
        headers = {"Authorization": f"Bearer {access_token}"}
        # Get async HTTP client
        from litellm import LlmProviders
        async_client = get_async_httpx_client(llm_provider=LlmProviders.VERTEX_AI)
        # Make the request
        # Note: Partner models (especially Claude) accept Anthropic Messages API format directly
        # NOTE(review): timeout is hard-coded to 30s here and does not honor a
        # caller-supplied request timeout — confirm this is intentional.
        response = await async_client.post(
            endpoint_url,
            headers=headers,
            json=request_data,
            timeout=30.0,
        )
        # Check for errors
        if response.status_code != 200:
            error_text = response.text
            raise ValueError(
                f"Token counting request failed with status {response.status_code}: {error_text}"
            )
        # Parse response
        result = response.json()
        # Return token count
        # Vertex AI Anthropic returns: {"input_tokens": 123}
        return {
            "input_tokens": result.get("input_tokens", 0),
            "tokenizer_used": "vertex_ai_partner_models",
        }

View File

@@ -0,0 +1,36 @@
import litellm
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
class VertexAIGPTOSSTransformation(OpenAIGPTConfig):
    """
    Transformation for GPT-OSS model from VertexAI
    https://console.cloud.google.com/vertex-ai/publishers/openai/model-garden/gpt-oss-120b-maas?hl=id
    """

    def __init__(self):
        super().__init__()

    def get_supported_openai_params(self, model: str) -> list:
        """Return the supported OpenAI params for GPT-OSS on Vertex AI.

        Extends the base GPT param list with `reasoning_effort`, then strips
        tool-calling params for models without function-calling support.
        """
        base_gpt_series_params = super().get_supported_openai_params(model=model)
        gpt_oss_only_params = ["reasoning_effort"]
        base_gpt_series_params.extend(gpt_oss_only_params)
        #########################################################
        # VertexAI - GPT-OSS does not support tool calls
        #########################################################
        if litellm.supports_function_calling(model=model) is False:
            # Fix: the OpenAI parameter is named "tools" (plural). The previous
            # list contained "tool", so "tools" was never actually removed.
            TOOL_CALLING_PARAMS_TO_REMOVE = [
                "tools",
                "tool_choice",
                "function_call",
                "functions",
            ]
            base_gpt_series_params = [
                param
                for param in base_gpt_series_params
                if param not in TOOL_CALLING_PARAMS_TO_REMOVE
            ]
        return base_gpt_series_params

View File

@@ -0,0 +1,226 @@
import types
from typing import Any, AsyncIterator, Iterator, List, Optional, Union
import httpx
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.openai.chat.gpt_transformation import (
OpenAIChatCompletionStreamingHandler,
OpenAIGPTConfig,
)
from litellm.types.llms.openai import AllMessageValues, OpenAIChatCompletionResponse
from litellm.types.utils import (
Delta,
ModelResponse,
ModelResponseStream,
StreamingChoices,
Usage,
)
from ...common_utils import VertexAIError
class VertexAILlama3Config(OpenAIGPTConfig):
    """
    Reference:https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/llama#streaming
    The class `VertexAILlama3Config` provides configuration for the VertexAI's Llama API interface. Below are the parameters:
    - `max_tokens` Required (integer) max tokens,
    Note: Please make sure to modify the default parameters as required for your use case.
    """

    max_tokens: Optional[int] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
    ) -> None:
        # litellm config pattern: non-None constructor args are promoted to
        # class-level attributes so get_config() reflects them.
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key == "max_tokens" and value is None:
                value = self.max_tokens
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        """Return class-level config attributes, excluding dunders, callables, and None values."""
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self, model: str):
        """Supported OpenAI params for Llama on Vertex AI (base list minus `max_retries`)."""
        supported_params = super().get_supported_openai_params(model=model)
        # BUG FIX: list.remove raises ValueError (not KeyError) when the
        # element is absent; the previous `except KeyError` never caught it.
        if "max_retries" in supported_params:
            supported_params.remove("max_retries")
        return supported_params

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ):
        """Map OpenAI params, translating `max_completion_tokens` to `max_tokens`."""
        if "max_completion_tokens" in non_default_params:
            non_default_params["max_tokens"] = non_default_params.pop(
                "max_completion_tokens"
            )
        return super().map_openai_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
            drop_params=drop_params,
        )

    def get_model_response_iterator(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ) -> Any:
        """Return the Llama-specific streaming handler (fills in missing roles)."""
        return VertexAILlama3StreamingHandler(
            streaming_response=streaming_response,
            sync_stream=sync_stream,
            json_mode=json_mode,
        )

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """Parse the raw OpenAI-compatible JSON response into a ModelResponse.

        Raises:
            VertexAIError: if the response body is not valid JSON for the
                expected OpenAI chat-completion schema.
        """
        ## LOGGING
        logging_obj.post_call(
            input=messages,
            api_key=api_key,
            original_response=raw_response.text,
            additional_args={"complete_input_dict": request_data},
        )
        ## RESPONSE OBJECT
        try:
            completion_response = OpenAIChatCompletionResponse(**raw_response.json())  # type: ignore
        except Exception as e:
            response_headers = getattr(raw_response, "headers", None)
            raise VertexAIError(
                message="Unable to get json response - {}, Original Response: {}".format(
                    str(e), raw_response.text
                ),
                status_code=raw_response.status_code,
                headers=response_headers,
            )
        model_response.model = completion_response.get("model", model)
        model_response.id = completion_response.get("id", "")
        model_response.created = completion_response.get("created", 0)
        setattr(model_response, "usage", Usage(**completion_response.get("usage", {})))
        model_response.choices = self._transform_choices(  # type: ignore
            choices=completion_response["choices"],
            json_mode=json_mode,
        )
        return model_response
class VertexAILlama3StreamingHandler(OpenAIChatCompletionStreamingHandler):
    """
    Vertex AI Llama models may not include role in streaming chunk deltas.
    This handler ensures the first chunk always has role="assistant".
    When Vertex AI returns a single chunk with both role and finish_reason (empty response),
    this handler splits it into two chunks:
    1. First chunk: role="assistant", content="", finish_reason=None
    2. Second chunk: role=None, content=None, finish_reason="stop"
    This matches OpenAI's streaming format where the first chunk has role and
    the final chunk has finish_reason but no role.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # True once the first chunk (carrying role="assistant") has been emitted.
        self.sent_role = False
        # Holds the synthesized final chunk when a single incoming chunk had to
        # be split into (role chunk, finish chunk); drained by __next__/__anext__.
        self._pending_chunk: Optional[ModelResponseStream] = None

    def chunk_parser(self, chunk: dict) -> ModelResponseStream:
        """Parse one raw chunk, ensuring the first emitted chunk carries
        role="assistant" and splitting a combined first+final chunk in two."""
        result = super().chunk_parser(chunk)
        if not self.sent_role and result.choices:
            delta = result.choices[0].delta
            finish_reason = result.choices[0].finish_reason
            # If this is both the first chunk AND the final chunk (has finish_reason),
            # we need to split it into two chunks to match OpenAI format
            if finish_reason is not None:
                # Create a pending final chunk with finish_reason but no role
                self._pending_chunk = ModelResponseStream(
                    id=result.id,
                    object="chat.completion.chunk",
                    created=result.created,
                    model=result.model,
                    choices=[
                        StreamingChoices(
                            index=0,
                            delta=Delta(content=None, role=None),
                            finish_reason=finish_reason,
                        )
                    ],
                )
                # Modify current chunk to be the first chunk with role but no finish_reason
                result.choices[0].finish_reason = None  # type: ignore[assignment]
                delta.role = "assistant"
                # Ensure content is empty string for first chunk, not None
                if delta.content is None:
                    delta.content = ""
                # Prevent downstream stream wrapper from dropping this chunk
                # (it drops empty-content chunks unless special fields are present)
                if delta.provider_specific_fields is None:
                    delta.provider_specific_fields = {}
            elif delta.role is None:
                delta.role = "assistant"
                # If the first chunk has empty content, ensure it's still emitted
                if (
                    delta.content == "" or delta.content is None
                ) and delta.provider_specific_fields is None:
                    delta.provider_specific_fields = {}
            self.sent_role = True
        return result

    def __next__(self):
        """Sync iteration: drain any pending split-off final chunk first."""
        # First return any pending chunk from a previous split
        if self._pending_chunk is not None:
            chunk = self._pending_chunk
            self._pending_chunk = None
            return chunk
        return super().__next__()

    async def __anext__(self):
        """Async iteration: drain any pending split-off final chunk first."""
        # First return any pending chunk from a previous split
        if self._pending_chunk is not None:
            chunk = self._pending_chunk
            self._pending_chunk = None
            return chunk
        return await super().__anext__()

View File

@@ -0,0 +1,344 @@
# What is this?
## API Handler for calling Vertex AI Partner Models
from enum import Enum
from typing import Callable, Optional, Union
import httpx # type: ignore
import litellm
from litellm import LlmProviders
from litellm.types.llms.vertex_ai import VertexPartnerProvider
from litellm.utils import ModelResponse
from ...custom_httpx.llm_http_handler import BaseLLMHTTPHandler
from ..vertex_llm_base import VertexBase
base_llm_http_handler = BaseLLMHTTPHandler()
class VertexAIError(Exception):
    """Exception raised for Vertex AI partner-model failures.

    Carries the HTTP status code and attaches a synthetic httpx
    request/response pair so callers can introspect `.request`/`.response`
    as they would on a genuine httpx error.
    """

    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        placeholder_request = httpx.Request(
            method="POST", url=" https://cloud.google.com/vertex-ai/"
        )
        self.request = placeholder_request
        self.response = httpx.Response(
            status_code=status_code, request=placeholder_request
        )
        # Pass the message up so str(exc) is meaningful.
        super().__init__(self.message)
class PartnerModelPrefixes(str, Enum):
    """Model-name prefixes that identify Vertex AI partner (non-Gemini) models."""

    META_PREFIX = "meta/"
    DEEPSEEK_PREFIX = "deepseek-ai"
    MISTRAL_PREFIX = "mistral"
    # NOTE(review): member name is misspelled ("CODERESTAL" vs "CODESTRAL"),
    # but it is referenced by name elsewhere, so renaming would be breaking.
    CODERESTAL_PREFIX = "codestral"
    JAMBA_PREFIX = "jamba"
    CLAUDE_PREFIX = "claude"
    QWEN_PREFIX = "qwen"
    GPT_OSS_PREFIX = "openai/gpt-oss-"
    MINIMAX_PREFIX = "minimaxai/"
    MOONSHOT_PREFIX = "moonshotai/"
    ZAI_PREFIX = "zai-org/"
class VertexAIPartnerModels(VertexBase):
    """Dispatcher for partner models (Claude, Mistral/Codestral, Llama, Jamba,
    DeepSeek, Qwen, GPT-OSS, MiniMax, Moonshot, Z.ai) hosted on Vertex AI.

    Routes each model family to the appropriate downstream handler
    (Anthropic, Codestral FIM, OpenAI-like, or the generic HTTP handler).
    """

    def __init__(self) -> None:
        pass

    @staticmethod
    def is_vertex_partner_model(model: str):
        """
        Check if the model string is a Vertex AI Partner Model
        Only use this once you have confirmed that custom_llm_provider is vertex_ai
        Returns:
            bool: True if the model string is a Vertex AI Partner Model, False otherwise
        """
        if (
            model.startswith(PartnerModelPrefixes.META_PREFIX)
            or model.startswith(PartnerModelPrefixes.DEEPSEEK_PREFIX)
            or model.startswith(PartnerModelPrefixes.MISTRAL_PREFIX)
            or model.startswith(PartnerModelPrefixes.CODERESTAL_PREFIX)
            or model.startswith(PartnerModelPrefixes.JAMBA_PREFIX)
            or model.startswith(PartnerModelPrefixes.CLAUDE_PREFIX)
            or model.startswith(PartnerModelPrefixes.QWEN_PREFIX)
            or model.startswith(PartnerModelPrefixes.GPT_OSS_PREFIX)
            or model.startswith(PartnerModelPrefixes.MINIMAX_PREFIX)
            or model.startswith(PartnerModelPrefixes.MOONSHOT_PREFIX)
            or model.startswith(PartnerModelPrefixes.ZAI_PREFIX)
        ):
            return True
        return False

    @staticmethod
    def should_use_openai_handler(model: str):
        """Return True for model families served via the OpenAI-compatible endpoint."""
        OPENAI_LIKE_VERTEX_PROVIDERS = [
            "llama",
            PartnerModelPrefixes.DEEPSEEK_PREFIX,
            PartnerModelPrefixes.QWEN_PREFIX,
            PartnerModelPrefixes.GPT_OSS_PREFIX,
            PartnerModelPrefixes.MINIMAX_PREFIX,
            PartnerModelPrefixes.MOONSHOT_PREFIX,
            PartnerModelPrefixes.ZAI_PREFIX,
        ]
        if any(provider in model for provider in OPENAI_LIKE_VERTEX_PROVIDERS):
            return True
        return False

    def completion(
        self,
        model: str,
        messages: list,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        logging_obj,
        api_base: Optional[str],
        optional_params: dict,
        custom_prompt_dict: dict,
        headers: Optional[dict],
        timeout: Union[float, httpx.Timeout],
        litellm_params: dict,
        vertex_project=None,
        vertex_location=None,
        vertex_credentials=None,
        logger_fn=None,
        acompletion: bool = False,
        client=None,
    ):
        """Run a chat/text completion for a Vertex AI partner model.

        Resolves Vertex auth, builds the partner endpoint URL, then routes to
        the handler matching the model family. Raises VertexAIError on
        import/setup failures or unknown model families.
        """
        try:
            import vertexai
            from litellm.llms.anthropic.chat import AnthropicChatCompletion
            from litellm.llms.codestral.completion.handler import (
                CodestralTextCompletion,
            )
            from litellm.llms.openai_like.chat.handler import OpenAILikeChatHandler
            from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
                VertexLLM,
            )
        except Exception as e:
            raise VertexAIError(
                status_code=400,
                message=f"""vertexai import failed please run `pip install -U "google-cloud-aiplatform>=1.38"`. Got error: {e}""",
            )
        # BUG FIX: this check previously used `or`, which (a) raised an
        # AttributeError on `vertexai.preview` when `preview` was missing and
        # (b) never raised when `preview` existed without `language_models`.
        if not (
            hasattr(vertexai, "preview")
            and hasattr(vertexai.preview, "language_models")
        ):
            raise VertexAIError(
                status_code=400,
                message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""",
            )
        try:
            vertex_httpx_logic = VertexLLM()
            # Resolve OAuth token + effective project id from credentials.
            access_token, project_id = vertex_httpx_logic._ensure_access_token(
                credentials=vertex_credentials,
                project_id=vertex_project,
                custom_llm_provider="vertex_ai",
            )
            openai_like_chat_completions = OpenAILikeChatHandler()
            codestral_fim_completions = CodestralTextCompletion()
            anthropic_chat_completions = AnthropicChatCompletion()
            ## CONSTRUCT API BASE
            stream: bool = optional_params.get("stream", False) or False
            optional_params["stream"] = stream
            # Map the model family to its Vertex partner publisher.
            if self.should_use_openai_handler(model):
                partner = VertexPartnerProvider.llama
            elif "mistral" in model or "codestral" in model:
                partner = VertexPartnerProvider.mistralai
            elif "jamba" in model:
                partner = VertexPartnerProvider.ai21
            elif "claude" in model:
                partner = VertexPartnerProvider.claude
            else:
                raise ValueError(f"Unknown partner model: {model}")
            api_base = self.get_complete_vertex_url(
                custom_api_base=api_base,
                vertex_location=vertex_location,
                vertex_project=vertex_project,
                project_id=project_id,
                partner=partner,
                stream=stream,
                model=model,
            )
            # Strip any "@version" suffix for mistral/codestral model names.
            if "codestral" in model or "mistral" in model:
                model = model.split("@")[0]
            if "codestral" in model and litellm_params.get("text_completion") is True:
                optional_params["model"] = model
                text_completion_model_response = litellm.TextCompletionResponse(
                    stream=stream
                )
                return codestral_fim_completions.completion(
                    model=model,
                    messages=messages,
                    api_base=api_base,
                    api_key=access_token,
                    custom_prompt_dict=custom_prompt_dict,
                    model_response=text_completion_model_response,
                    print_verbose=print_verbose,
                    logging_obj=logging_obj,
                    optional_params=optional_params,
                    acompletion=acompletion,
                    litellm_params=litellm_params,
                    logger_fn=logger_fn,
                    timeout=timeout,
                    encoding=encoding,
                )
            elif "claude" in model:
                if headers is None:
                    headers = {}
                headers.update({"Authorization": "Bearer {}".format(access_token)})
                optional_params.update(
                    {
                        "anthropic_version": "vertex-2023-10-16",
                        "is_vertex_request": True,
                    }
                )
                return anthropic_chat_completions.completion(
                    model=model,
                    messages=messages,
                    api_base=api_base,
                    acompletion=acompletion,
                    custom_prompt_dict=litellm.custom_prompt_dict,
                    model_response=model_response,
                    print_verbose=print_verbose,
                    optional_params=optional_params,
                    litellm_params=litellm_params,
                    logger_fn=logger_fn,
                    encoding=encoding,  # for calculating input/output tokens
                    api_key=access_token,
                    logging_obj=logging_obj,
                    headers=headers,
                    timeout=timeout,
                    client=client,
                    custom_llm_provider=LlmProviders.VERTEX_AI.value,
                )
            elif self.should_use_openai_handler(model):
                return base_llm_http_handler.completion(
                    model=model,
                    stream=stream,
                    messages=messages,
                    acompletion=acompletion,
                    api_base=api_base,
                    model_response=model_response,
                    optional_params=optional_params,
                    litellm_params=litellm_params,
                    custom_llm_provider="vertex_ai",
                    timeout=timeout,
                    headers=headers,
                    encoding=encoding,
                    api_key=access_token,
                    logging_obj=logging_obj,  # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
                    client=client,
                )
            # Fallback: generic OpenAI-like handler against the custom endpoint.
            return openai_like_chat_completions.completion(
                model=model,
                messages=messages,
                api_base=api_base,
                api_key=access_token,
                custom_prompt_dict=custom_prompt_dict,
                model_response=model_response,
                print_verbose=print_verbose,
                logging_obj=logging_obj,
                optional_params=optional_params,
                acompletion=acompletion,
                litellm_params=litellm_params,
                logger_fn=logger_fn,
                client=client,
                timeout=timeout,
                encoding=encoding,
                custom_llm_provider="vertex_ai",
                custom_endpoint=True,
            )
        except Exception as e:
            # Preserve provider errors that already carry a status code.
            if hasattr(e, "status_code"):
                raise e
            raise VertexAIError(status_code=500, message=str(e))

    async def count_tokens(
        self,
        model: str,
        messages: list,
        litellm_params: dict,
        vertex_project=None,
        vertex_location=None,
        vertex_credentials=None,
    ):
        """
        Count tokens for Vertex AI partner models (Anthropic Claude, Mistral, etc.)
        Args:
            model: The model name (e.g., "claude-3-5-sonnet-20241022")
            messages: List of messages in Anthropic Messages API format
            litellm_params: LiteLLM parameters dict
            vertex_project: Optional Google Cloud project ID
            vertex_location: Optional Vertex AI location
            vertex_credentials: Optional Vertex AI credentials
        Returns:
            Dict containing token count information
        """
        try:
            import vertexai
        except Exception as e:
            raise VertexAIError(
                status_code=400,
                message=f"""vertexai import failed please run `pip install -U "google-cloud-aiplatform>=1.38"`. Got error: {e}""",
            )
        # BUG FIX: same `or` -> `and` correction as in completion(): the old
        # check crashed with AttributeError when `preview` was absent.
        if not (
            hasattr(vertexai, "preview")
            and hasattr(vertexai.preview, "language_models")
        ):
            raise VertexAIError(
                status_code=400,
                message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""",
            )
        try:
            from litellm.llms.vertex_ai.vertex_ai_partner_models.count_tokens.handler import (
                VertexAIPartnerModelsTokenCounter,
            )
            # Prepare request data in Anthropic Messages API format
            request_data = {
                "model": model,
                "messages": messages,
            }
            # Prepare litellm_params with credentials
            _litellm_params = litellm_params.copy()
            if vertex_project:
                _litellm_params["vertex_project"] = vertex_project
            if vertex_location:
                _litellm_params["vertex_location"] = vertex_location
            if vertex_credentials:
                _litellm_params["vertex_credentials"] = vertex_credentials
            # Call the token counter
            token_counter = VertexAIPartnerModelsTokenCounter()
            result = await token_counter.handle_count_tokens_request(
                model=model,
                request_data=request_data,
                litellm_params=_litellm_params,
            )
            return result
        except Exception as e:
            if hasattr(e, "status_code"):
                raise e
            raise VertexAIError(status_code=500, message=str(e))

View File

@@ -0,0 +1,183 @@
"""
Vertex AI BGE (BAAI General Embedding) Configuration
BGE models deployed on Vertex AI require different input/output format:
- Request: Use "prompt" instead of "content" as the input field
- Response: Embeddings are returned directly as arrays, not wrapped in objects
Model name handling:
- Model names like "bge/endpoint_id" are automatically transformed in common_utils._get_vertex_url()
- This module focuses on request/response transformation only
"""
from typing import List, Optional, Union
from litellm.types.utils import EmbeddingResponse, Usage
from .types import (
EmbeddingParameters,
TaskType,
TextEmbeddingBGEInput,
VertexEmbeddingRequest,
)
class VertexBGEConfig:
    """
    Request/response transformation for BGE (BAAI General Embedding) models
    deployed on Vertex AI.

    BGE endpoints differ from the standard Vertex text-embedding API in two
    ways: the per-instance input field is named "prompt" (not "content"), and
    predictions come back as bare embedding vectors rather than wrapped
    objects.

    Supported model patterns (after the provider split in main.py):
      - "bge-small-en-v1.5" (model name)
      - "bge/204379420394258432" (endpoint ID pattern)

    Model-name rewriting (bge/ -> numeric endpoint id) happens in
    common_utils._get_vertex_url(); this class only reshapes payloads.
    """

    @staticmethod
    def is_bge_model(model: str) -> bool:
        """
        Return True when ``model`` names a BGE model.

        Matches the "bge/<endpoint_id>" prefix or any model name containing
        "bge", case-insensitively.
        """
        lowered = model.lower()
        if lowered.startswith("bge/"):
            return True
        return "bge" in lowered

    @staticmethod
    def transform_request(
        input: Union[list, str], optional_params: dict, model: str
    ) -> VertexEmbeddingRequest:
        """
        Convert an OpenAI-style embedding request into the Vertex BGE payload.

        Each input text becomes one instance whose text lives under the
        "prompt" key; the remaining optional params are forwarded as the
        request-level parameters.
        """
        texts = [input] if isinstance(input, str) else input
        task_type: Optional[TaskType] = optional_params.get("task_type")
        title = optional_params.get("title")

        instances: List[TextEmbeddingBGEInput] = [
            VertexBGEConfig._create_embedding_input(
                prompt=text, task_type=task_type, title=title
            )
            for text in texts
        ]

        request: VertexEmbeddingRequest = VertexEmbeddingRequest()
        request["instances"] = instances
        request["parameters"] = EmbeddingParameters(**optional_params)
        return request

    @staticmethod
    def _create_embedding_input(
        prompt: str,
        task_type: Optional[TaskType] = None,
        title: Optional[str] = None,
    ) -> TextEmbeddingBGEInput:
        """
        Build a single BGE instance dict.

        The optional keys ("task_type", "title") are only included when a
        value was provided.
        """
        instance = TextEmbeddingBGEInput(prompt=prompt)
        if task_type is not None:
            instance["task_type"] = task_type
        if title is not None:
            instance["title"] = title
        return instance

    @staticmethod
    def transform_response(
        response: dict, model: str, model_response: EmbeddingResponse
    ) -> EmbeddingResponse:
        """
        Convert a Vertex BGE prediction response into OpenAI embedding format.

        BGE predictions are raw vectors:
            {"predictions": [[0.002, 0.021, ...], [0.003, 0.022, ...]]}

        Raises:
            KeyError: if the payload has no 'predictions' field.
            ValueError: if predictions (or any single prediction) is not a list.
        """
        if "predictions" not in response:
            raise KeyError("Response missing 'predictions' field")

        _predictions = response["predictions"]
        if not isinstance(_predictions, list):
            raise ValueError(
                f"Expected 'predictions' to be a list, got {type(_predictions)}"
            )

        data = []
        for idx, embedding_values in enumerate(_predictions):
            if not isinstance(embedding_values, list):
                raise ValueError(
                    f"Expected embedding at index {idx} to be a list, got {type(embedding_values)}"
                )
            data.append(
                {"object": "embedding", "index": idx, "embedding": embedding_values}
            )

        model_response.object = "list"
        model_response.data = data
        model_response.model = model
        # BGE responses carry no token statistics, so usage is reported as zero.
        setattr(
            model_response,
            "usage",
            Usage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        )
        return model_response

View File

@@ -0,0 +1,232 @@
from typing import Literal, Optional, Union
import httpx
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObject
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_httpx_client,
get_async_httpx_client,
)
from litellm.llms.vertex_ai.vertex_ai_non_gemini import VertexAIError
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
from litellm.types.llms.vertex_ai import *
from litellm.types.utils import EmbeddingResponse
from .types import *
class VertexEmbedding(VertexBase):
    """
    Text-embedding handler for Vertex AI / Google AI Studio.

    Resolves credentials and the prediction URL via VertexBase, transforms the
    OpenAI-style request with litellm.vertexAITextEmbeddingConfig, performs the
    HTTP call (sync or async), and converts the response back to OpenAI format.
    """

    def __init__(self) -> None:
        super().__init__()

    def embedding(
        self,
        model: str,
        input: Union[list, str],
        print_verbose,
        model_response: EmbeddingResponse,
        optional_params: dict,
        logging_obj: LiteLLMLoggingObject,
        custom_llm_provider: Literal[
            "vertex_ai", "vertex_ai_beta", "gemini"
        ],  # if it's vertex_ai or gemini (google ai studio)
        timeout: Optional[Union[float, httpx.Timeout]],
        api_key: Optional[str] = None,
        encoding=None,
        aembedding: Optional[bool] = False,
        api_base: Optional[str] = None,
        client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
        vertex_project: Optional[str] = None,
        vertex_location: Optional[str] = None,
        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES] = None,
        gemini_api_key: Optional[str] = None,
        extra_headers: Optional[dict] = None,
    ) -> EmbeddingResponse:
        """
        Synchronous embedding entry point.

        When ``aembedding`` is True, returns the coroutine produced by
        :meth:`async_embedding` (the declared return type is inaccurate for
        that branch, hence the type: ignore); callers are expected to await it.

        Raises:
            VertexAIError: on a non-2xx response (original status code) or a
                request timeout (408).
        """
        # Async path: delegate and hand back the coroutine unchanged.
        if aembedding is True:
            return self.async_embedding(  # type: ignore
                model=model,
                input=input,
                logging_obj=logging_obj,
                model_response=model_response,
                optional_params=optional_params,
                encoding=encoding,
                custom_llm_provider=custom_llm_provider,
                timeout=timeout,
                api_base=api_base,
                vertex_project=vertex_project,
                vertex_location=vertex_location,
                vertex_credentials=vertex_credentials,
                gemini_api_key=gemini_api_key,
                extra_headers=extra_headers,
            )
        # v1beta1 endpoint is required when beta-only request features are set
        should_use_v1beta1_features = self.is_using_v1beta1_features(
            optional_params=optional_params
        )
        # Resolve an access token first (may also infer the project id) ...
        _auth_header, vertex_project = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider=custom_llm_provider,
        )
        # Extract use_psc_endpoint_format from optional_params
        use_psc_endpoint_format = optional_params.get("use_psc_endpoint_format", False)
        # ... then build the final auth header and prediction endpoint URL
        auth_header, api_base = self._get_token_and_url(
            model=model,
            gemini_api_key=gemini_api_key,
            auth_header=_auth_header,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_credentials=vertex_credentials,
            stream=False,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            should_use_v1beta1_features=should_use_v1beta1_features,
            mode="embedding",
            use_psc_endpoint_format=use_psc_endpoint_format,
        )
        headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
        # OpenAI request -> Vertex "instances"/"parameters" payload
        vertex_request: VertexEmbeddingRequest = litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request(
            input=input, optional_params=optional_params, model=model
        )
        _client_params = {}
        if timeout:
            _client_params["timeout"] = timeout
        # Reuse the provided sync client when compatible; otherwise create one
        if client is None or not isinstance(client, HTTPHandler):
            client = _get_httpx_client(params=_client_params)
        else:
            client = client  # type: ignore
        ## LOGGING
        logging_obj.pre_call(
            input=vertex_request,
            api_key="",
            additional_args={
                "complete_input_dict": vertex_request,
                "api_base": api_base,
                "headers": headers,
            },
        )
        try:
            response = client.post(url=api_base, headers=headers, json=vertex_request)  # type: ignore
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            # Surface the upstream status code and body verbatim
            error_code = err.response.status_code
            raise VertexAIError(status_code=error_code, message=err.response.text)
        except httpx.TimeoutException:
            raise VertexAIError(status_code=408, message="Timeout error occurred.")
        _json_response = response.json()
        ## LOGGING POST-CALL
        logging_obj.post_call(
            input=input, api_key=None, original_response=_json_response
        )
        # Vertex predictions -> OpenAI EmbeddingResponse
        model_response = (
            litellm.vertexAITextEmbeddingConfig.transform_vertex_response_to_openai(
                response=_json_response, model=model, model_response=model_response
            )
        )
        return model_response

    async def async_embedding(
        self,
        model: str,
        input: Union[list, str],
        model_response: EmbeddingResponse,
        logging_obj: LiteLLMLoggingObject,
        optional_params: dict,
        custom_llm_provider: Literal[
            "vertex_ai", "vertex_ai_beta", "gemini"
        ],  # if it's vertex_ai or gemini (google ai studio)
        timeout: Optional[Union[float, httpx.Timeout]],
        api_base: Optional[str] = None,
        client: Optional[AsyncHTTPHandler] = None,
        vertex_project: Optional[str] = None,
        vertex_location: Optional[str] = None,
        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES] = None,
        gemini_api_key: Optional[str] = None,
        extra_headers: Optional[dict] = None,
        encoding=None,
    ) -> EmbeddingResponse:
        """
        Async embedding implementation

        Mirrors :meth:`embedding` step for step using the async HTTP client.

        Raises:
            VertexAIError: on a non-2xx response (original status code) or a
                request timeout (408).
        """
        should_use_v1beta1_features = self.is_using_v1beta1_features(
            optional_params=optional_params
        )
        # Async token resolution (may also infer the project id)
        _auth_header, vertex_project = await self._ensure_access_token_async(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider=custom_llm_provider,
        )
        # Extract use_psc_endpoint_format from optional_params
        use_psc_endpoint_format = optional_params.get("use_psc_endpoint_format", False)
        auth_header, api_base = self._get_token_and_url(
            model=model,
            gemini_api_key=gemini_api_key,
            auth_header=_auth_header,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_credentials=vertex_credentials,
            stream=False,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            should_use_v1beta1_features=should_use_v1beta1_features,
            mode="embedding",
            use_psc_endpoint_format=use_psc_endpoint_format,
        )
        headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
        # OpenAI request -> Vertex "instances"/"parameters" payload
        vertex_request: VertexEmbeddingRequest = litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request(
            input=input, optional_params=optional_params, model=model
        )
        _async_client_params = {}
        if timeout:
            _async_client_params["timeout"] = timeout
        # Reuse the provided async client when compatible; otherwise create one
        if client is None or not isinstance(client, AsyncHTTPHandler):
            client = get_async_httpx_client(
                params=_async_client_params, llm_provider=litellm.LlmProviders.VERTEX_AI
            )
        else:
            client = client  # type: ignore
        ## LOGGING
        logging_obj.pre_call(
            input=vertex_request,
            api_key="",
            additional_args={
                "complete_input_dict": vertex_request,
                "api_base": api_base,
                "headers": headers,
            },
        )
        try:
            response = await client.post(api_base, headers=headers, json=vertex_request)  # type: ignore
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            # Surface the upstream status code and body verbatim
            error_code = err.response.status_code
            raise VertexAIError(status_code=error_code, message=err.response.text)
        except httpx.TimeoutException:
            raise VertexAIError(status_code=408, message="Timeout error occurred.")
        _json_response = response.json()
        ## LOGGING POST-CALL
        logging_obj.post_call(
            input=input, api_key=None, original_response=_json_response
        )
        # Vertex predictions -> OpenAI EmbeddingResponse
        model_response = (
            litellm.vertexAITextEmbeddingConfig.transform_vertex_response_to_openai(
                response=_json_response, model=model, model_response=model_response
            )
        )
        return model_response

View File

@@ -0,0 +1,285 @@
import types
from typing import List, Literal, Optional, Union
from pydantic import BaseModel
from litellm.types.utils import EmbeddingResponse, Usage
from .types import *
class VertexAITextEmbeddingConfig(BaseModel):
    """
    Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#TextEmbeddingInput

    Args:
        auto_truncate: Optional(bool) If True, will truncate input text to fit within the model's max input length.
        task_type: Optional(str) The type of task to be performed. The default is "RETRIEVAL_QUERY".
        title: Optional(str) The title of the document to be embedded. (only valid with task_type=RETRIEVAL_DOCUMENT).
    """

    auto_truncate: Optional[bool] = None
    task_type: Optional[
        Literal[
            "RETRIEVAL_QUERY",
            "RETRIEVAL_DOCUMENT",
            "SEMANTIC_SIMILARITY",
            "CLASSIFICATION",
            "CLUSTERING",
            "QUESTION_ANSWERING",
            "FACT_VERIFICATION",
        ]
    ] = None
    title: Optional[str] = None

    def __init__(
        self,
        auto_truncate: Optional[bool] = None,
        task_type: Optional[
            Literal[
                "RETRIEVAL_QUERY",
                "RETRIEVAL_DOCUMENT",
                "SEMANTIC_SIMILARITY",
                "CLASSIFICATION",
                "CLUSTERING",
                "QUESTION_ANSWERING",
                "FACT_VERIFICATION",
            ]
        ] = None,
        title: Optional[str] = None,
    ) -> None:
        # NOTE: litellm-wide config pattern — explicitly-passed constructor
        # args are written onto the *class* (not the instance) so that
        # get_config() below can report every configured value.
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        """Return all non-None, non-dunder, non-callable class attributes."""
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self):
        """OpenAI embedding params this config knows how to translate."""
        return ["dimensions"]

    def map_openai_params(
        self, non_default_params: dict, optional_params: dict, kwargs: dict
    ):
        """
        Map supported OpenAI params onto their Vertex equivalents.

        "dimensions" becomes "outputDimensionality"; a leftover "input_type"
        kwarg is popped into "task_type". Returns (optional_params, kwargs).
        """
        for param, value in non_default_params.items():
            if param == "dimensions":
                optional_params["outputDimensionality"] = value

        if "input_type" in kwargs:
            optional_params["task_type"] = kwargs.pop("input_type")
        return optional_params, kwargs

    def get_mapped_special_auth_params(self) -> dict:
        """
        Common auth params across bedrock/vertex_ai/azure/watsonx
        """
        return {"project": "vertex_project", "region_name": "vertex_location"}

    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
        """Translate provider-agnostic auth params into vertex_* params."""
        mapped_params = self.get_mapped_special_auth_params()

        for param, value in non_default_params.items():
            if param in mapped_params:
                optional_params[mapped_params[param]] = value
        return optional_params

    def transform_openai_request_to_vertex_embedding_request(
        self, input: Union[list, str], optional_params: dict, model: str
    ) -> VertexEmbeddingRequest:
        """
        Transforms an openai request to a vertex embedding request.

        Dispatches to the fine-tuned-model format when ``model`` is a bare
        numeric endpoint id, and to the BGE format for BGE models; otherwise
        builds the standard content-based instances payload.
        """
        # Import here to avoid circular import issues with litellm.__init__
        from litellm.llms.vertex_ai.vertex_embeddings.bge import VertexBGEConfig

        if model.isdigit():
            return self._transform_openai_request_to_fine_tuned_embedding_request(
                input, optional_params, model
            )

        if VertexBGEConfig.is_bge_model(model):
            return VertexBGEConfig.transform_request(
                input=input, optional_params=optional_params, model=model
            )

        vertex_request: VertexEmbeddingRequest = VertexEmbeddingRequest()
        vertex_text_embedding_input_list: List[TextEmbeddingInput] = []
        task_type: Optional[TaskType] = optional_params.get("task_type")
        title = optional_params.get("title")
        if isinstance(input, str):
            input = [input]  # Convert single string to list for uniform processing

        for text in input:
            embedding_input = self.create_embedding_input(
                content=text, task_type=task_type, title=title
            )
            vertex_text_embedding_input_list.append(embedding_input)

        vertex_request["instances"] = vertex_text_embedding_input_list
        vertex_request["parameters"] = EmbeddingParameters(**optional_params)

        return vertex_request

    def _transform_openai_request_to_fine_tuned_embedding_request(
        self, input: Union[list, str], optional_params: dict, model: str
    ) -> VertexEmbeddingRequest:
        """
        Transforms an openai request to a vertex fine-tuned embedding request.

        Fine-tuned endpoints use an "inputs" field per instance instead of
        "content".

        Vertex Doc: https://console.cloud.google.com/vertex-ai/model-garden?hl=en&project=adroit-crow-413218&pageState=(%22galleryStateKey%22:(%22f%22:(%22g%22:%5B%5D,%22o%22:%5B%5D),%22s%22:%22%22))
        Sample Request:

        ```json
        {
            "instances" : [
                {
                    "inputs": "How would the Future of AI in 10 Years look?",
                    "parameters": {
                        "max_new_tokens": 128,
                        "temperature": 1.0,
                        "top_p": 0.9,
                        "top_k": 10
                    }
                }
            ]
        }
        ```
        """
        vertex_request: VertexEmbeddingRequest = VertexEmbeddingRequest()
        vertex_text_embedding_input_list: List[TextEmbeddingFineTunedInput] = []
        if isinstance(input, str):
            input = [input]  # Convert single string to list for uniform processing

        for text in input:
            embedding_input = TextEmbeddingFineTunedInput(inputs=text)
            vertex_text_embedding_input_list.append(embedding_input)

        vertex_request["instances"] = vertex_text_embedding_input_list
        vertex_request["parameters"] = TextEmbeddingFineTunedParameters(
            **optional_params
        )

        # Remove 'shared_session' from parameters if present
        if (
            vertex_request["parameters"] is not None
            and "shared_session" in vertex_request["parameters"]
        ):
            del vertex_request["parameters"]["shared_session"]  # type: ignore[typeddict-item]

        return vertex_request

    def create_embedding_input(
        self,
        content: str,
        task_type: Optional[TaskType] = None,
        title: Optional[str] = None,
    ) -> TextEmbeddingInput:
        """
        Creates a TextEmbeddingInput object.

        Vertex requires a List of TextEmbeddingInput objects. This helper function creates a single TextEmbeddingInput object.

        Args:
            content (str): The content to be embedded.
            task_type (Optional[TaskType]): The type of task to be performed.
            title (Optional[str]): The title of the document to be embedded.

        Returns:
            TextEmbeddingInput: A TextEmbeddingInput object.
        """
        text_embedding_input = TextEmbeddingInput(content=content)
        if task_type is not None:
            text_embedding_input["task_type"] = task_type
        if title is not None:
            text_embedding_input["title"] = title
        return text_embedding_input

    def transform_vertex_response_to_openai(
        self, response: dict, model: str, model_response: EmbeddingResponse
    ) -> EmbeddingResponse:
        """
        Transforms a vertex embedding response to an openai response.

        Dispatches to the fine-tuned and BGE variants when applicable;
        otherwise unwraps the standard {"embeddings": {"values", "statistics"}}
        prediction shape and accumulates token counts into usage.
        """
        if model.isdigit():
            return self._transform_vertex_response_to_openai_for_fine_tuned_models(
                response, model, model_response
            )

        # Import here to avoid circular import issues with litellm.__init__
        from litellm.llms.vertex_ai.vertex_embeddings.bge import VertexBGEConfig

        if VertexBGEConfig.is_bge_model(model):
            return VertexBGEConfig.transform_response(
                response=response, model=model, model_response=model_response
            )

        _predictions = response["predictions"]
        embedding_response = []
        input_tokens: int = 0
        for idx, element in enumerate(_predictions):
            embedding = element["embeddings"]
            embedding_response.append(
                {
                    "object": "embedding",
                    "index": idx,
                    "embedding": embedding["values"],
                }
            )
            # Per-instance token counts are summed into the usage block
            input_tokens += embedding["statistics"]["token_count"]

        model_response.object = "list"
        model_response.data = embedding_response
        model_response.model = model
        usage = Usage(
            prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
        )
        setattr(model_response, "usage", usage)
        return model_response

    def _transform_vertex_response_to_openai_for_fine_tuned_models(
        self, response: dict, model: str, model_response: EmbeddingResponse
    ) -> EmbeddingResponse:
        """
        Transforms a vertex fine-tuned model embedding response to an openai response format.

        Fine-tuned predictions carry no token statistics, so usage is zero.
        """
        _predictions = response["predictions"]
        embedding_response = []
        # For fine-tuned models, we don't get token counts in the response
        input_tokens = 0

        for idx, embedding_values in enumerate(_predictions):
            embedding_response.append(
                {
                    "object": "embedding",
                    "index": idx,
                    "embedding": embedding_values[
                        0
                    ],  # The embedding values are nested one level deeper
                }
            )

        model_response.object = "list"
        model_response.data = embedding_response
        model_response.model = model
        usage = Usage(
            prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
        )
        setattr(model_response, "usage", usage)
        return model_response

View File

@@ -0,0 +1,74 @@
"""
Types for Vertex Embeddings Requests
"""
from enum import Enum
from typing import List, Optional, Union
from typing_extensions import TypedDict
class TaskType(str, Enum):
    """Embedding task types accepted by the Vertex AI text-embeddings API.

    Subclasses str so members serialize directly into JSON request bodies.
    """

    RETRIEVAL_QUERY = "RETRIEVAL_QUERY"
    RETRIEVAL_DOCUMENT = "RETRIEVAL_DOCUMENT"
    SEMANTIC_SIMILARITY = "SEMANTIC_SIMILARITY"
    CLASSIFICATION = "CLASSIFICATION"
    CLUSTERING = "CLUSTERING"
    QUESTION_ANSWERING = "QUESTION_ANSWERING"
    FACT_VERIFICATION = "FACT_VERIFICATION"
    CODE_RETRIEVAL_QUERY = "CODE_RETRIEVAL_QUERY"
class TextEmbeddingInput(TypedDict, total=False):
    # One instance for the standard Vertex text-embeddings API.
    content: str  # the text to embed
    task_type: Optional[TaskType]  # optional task hint for the model
    title: Optional[str]  # only valid with task_type=RETRIEVAL_DOCUMENT
class TextEmbeddingBGEInput(TypedDict, total=False):
    # BGE-deployed endpoints take the text under "prompt" instead of "content".
    prompt: str  # the text to embed
    task_type: Optional[TaskType]  # optional task hint for the model
    title: Optional[str]  # optional document title
# Fine-tuned models require a different input format
# Ref: https://console.cloud.google.com/vertex-ai/model-garden?hl=en&project=adroit-crow-413218&pageState=(%22galleryStateKey%22:(%22f%22:(%22g%22:%5B%5D,%22o%22:%5B%5D),%22s%22:%22%22))
class TextEmbeddingFineTunedInput(TypedDict, total=False):
    # Fine-tuned endpoints take the text under "inputs".
    inputs: str  # the text to embed
class TextEmbeddingFineTunedParameters(TypedDict, total=False):
    # Generation-style parameters forwarded to fine-tuned embedding endpoints.
    max_new_tokens: Optional[int]
    temperature: Optional[float]
    top_p: Optional[float]
    top_k: Optional[int]
class EmbeddingParameters(TypedDict, total=False):
    # Request-level parameters for the standard text-embeddings API.
    # NOTE(review): the live API documents camelCase keys ("autoTruncate",
    # "outputDimensionality") — confirm where these snake_case keys are
    # translated before relying on them at the wire level.
    auto_truncate: Optional[bool]  # truncate over-long inputs instead of erroring
    output_dimensionality: Optional[int]  # requested embedding dimension
class VertexEmbeddingRequest(TypedDict, total=False):
    # Top-level Vertex :predict embedding payload: one instance shape per
    # deployment flavor (standard / BGE / fine-tuned), plus optional params.
    instances: Union[
        List[TextEmbeddingInput],
        List[TextEmbeddingBGEInput],
        List[TextEmbeddingFineTunedInput],
    ]
    parameters: Optional[Union[EmbeddingParameters, TextEmbeddingFineTunedParameters]]
# Example usage:
# example_request: VertexEmbeddingRequest = {
# "instances": [
# {
# "content": "I would like embeddings for this text!",
# "task_type": "RETRIEVAL_DOCUMENT",
# "title": "document title"
# }
# ],
# "parameters": {
# "auto_truncate": True,
# "output_dimensionality": None
# }
# }

View File

@@ -0,0 +1 @@
"""Vertex AI Gemma-AI Models Handler"""

View File

@@ -0,0 +1,144 @@
"""
API Handler for calling Vertex AI Gemma Models
These models use a custom prediction endpoint format that wraps messages in 'instances'
with @requestFormat: "chatCompletions" and returns responses wrapped in 'predictions'.
Usage:
response = litellm.completion(
model="vertex_ai/gemma/gemma-3-12b-it-1222199011122",
messages=[{"role": "user", "content": "What is machine learning?"}],
vertex_project="your-project-id",
vertex_location="us-central1",
)
Sent to this route when `model` is in the format `vertex_ai/gemma/{MODEL_NAME}`
The API expects a custom endpoint URL format:
https://{ENDPOINT_NUMBER}.{location}-{REGION_NUMBER}.prediction.vertexai.goog/v1/projects/{PROJECT_ID}/locations/{location}/endpoints/{ENDPOINT_ID}:predict
"""
from typing import Callable, Optional, Union
import httpx # type: ignore
from litellm.utils import ModelResponse
from ..common_utils import VertexAIError, get_vertex_base_model_name
from ..vertex_llm_base import VertexBase
class VertexAIGemmaModels(VertexBase):
    """
    Handler for Gemma models deployed on Vertex AI.

    Routed to when ``model`` is in the format ``vertex_ai/gemma/{MODEL_NAME}``.
    Gemma deployments expose a per-deployment custom :predict endpoint, so this
    class only resolves credentials and the endpoint URL, then delegates the
    request/response handling to VertexGemmaConfig.
    """

    def __init__(self) -> None:
        # NOTE(review): deliberately does not call VertexBase.__init__ — this
        # handler keeps no instance state; confirm VertexBase helpers used
        # below do not depend on attributes set in its __init__.
        pass

    def completion(
        self,
        model: str,
        messages: list,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        logging_obj,
        api_base: Optional[str],
        optional_params: dict,
        custom_prompt_dict: dict,
        headers: Optional[dict],
        timeout: Union[float, httpx.Timeout],
        litellm_params: dict,
        vertex_project=None,
        vertex_location=None,
        vertex_credentials=None,
        logger_fn=None,
        acompletion: bool = False,
        client=None,
    ):
        """
        Handles calling Vertex AI Gemma Models

        Sent to this route when `model` is in the format `vertex_ai/gemma/{MODEL_NAME}`

        Raises:
            VertexAIError: 400 when the vertexai SDK is missing/outdated or
                ``api_base`` is not supplied; otherwise the underlying error's
                status code, or 500 for unstructured failures.
        """
        try:
            import vertexai

            from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
                VertexLLM,
            )
            from litellm.llms.vertex_ai.vertex_gemma_models.transformation import (
                VertexGemmaConfig,
            )
        except Exception as e:
            raise VertexAIError(
                status_code=400,
                message=f"""vertexai import failed please run `pip install -U "google-cloud-aiplatform>=1.38"`. Got error: {e}""",
            )

        # BUGFIX: this check previously used `or`, which skipped the second
        # condition whenever `preview` existed and raised AttributeError on
        # `vertexai.preview` when it did not. Both attributes must be present.
        if not (
            hasattr(vertexai, "preview")
            and hasattr(vertexai.preview, "language_models")
        ):
            raise VertexAIError(
                status_code=400,
                message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""",
            )

        try:
            # Strip the "gemma/" routing prefix -> bare deployed model name
            model = get_vertex_base_model_name(model=model)

            vertex_httpx_logic = VertexLLM()

            access_token, project_id = vertex_httpx_logic._ensure_access_token(
                credentials=vertex_credentials,
                project_id=vertex_project,
                custom_llm_provider="vertex_ai",
            )

            gemma_transformation = VertexGemmaConfig()

            ## CONSTRUCT API BASE
            stream: bool = optional_params.get("stream", False) or False
            optional_params["stream"] = stream

            # If api_base is not provided, it should be set as an environment variable
            # or passed explicitly because the endpoint URL is unique per deployment
            if api_base is None:
                raise VertexAIError(
                    status_code=400,
                    message="api_base is required for Vertex AI Gemma models. Please provide the full endpoint URL.",
                )

            # Check if we need to append :predict
            if not api_base.endswith(":predict"):
                _, api_base = self._check_custom_proxy(
                    api_base=api_base,
                    custom_llm_provider="vertex_ai",
                    gemini_api_key=None,
                    endpoint="predict",
                    stream=stream,
                    auth_header=None,
                    url=api_base,
                )
            # If api_base already ends with :predict, use it as-is

            # Use the custom transformation handler for gemma models
            return gemma_transformation.completion(
                model=model,
                messages=messages,
                api_base=api_base,
                api_key=access_token,
                custom_prompt_dict=custom_prompt_dict,
                model_response=model_response,
                print_verbose=print_verbose,
                logging_obj=logging_obj,
                optional_params=optional_params,
                acompletion=acompletion,
                litellm_params=litellm_params,
                logger_fn=logger_fn,
                client=client,
                timeout=timeout,
                encoding=encoding,
                custom_llm_provider="vertex_ai",
            )

        except Exception as e:
            # Preserve structured errors that already carry a status_code
            if hasattr(e, "status_code"):
                raise e
            raise VertexAIError(status_code=500, message=str(e))

View File

@@ -0,0 +1,360 @@
"""
Transformation logic for Vertex AI Gemma Models
Handles the custom request/response format:
- Request: Wraps messages in 'instances' with @requestFormat: "chatCompletions"
- Response: Extracts data from 'predictions' wrapper
The actual message transformation reuses OpenAIGPTConfig since Gemma uses OpenAI-compatible format.
"""
from typing import Any, Callable, Dict, List, Optional, Union, cast
import httpx
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse
class VertexGemmaConfig(OpenAIGPTConfig):
"""
Configuration and transformation class for Vertex AI Gemma models
Extends OpenAIGPTConfig to wrap/unwrap the instances/predictions format
used by Vertex AI's Gemma deployment endpoint.
"""
    def __init__(self) -> None:
        # No Gemma-specific state; defer entirely to OpenAIGPTConfig.
        super().__init__()
def should_fake_stream(
self,
model: Optional[str],
stream: Optional[bool],
custom_llm_provider: Optional[str] = None,
) -> bool:
"""
Vertex AI Gemma models do not support streaming.
Return True to enable fake streaming on the client side.
"""
return True
def _handle_fake_stream_response(
self,
model_response: ModelResponse,
stream: bool,
) -> Union[ModelResponse, Any]:
"""
Helper method to return fake stream iterator if streaming is requested.
Args:
model_response: The completed model response
stream: Whether streaming was requested
Returns:
MockResponseIterator if stream=True, otherwise the model_response
"""
if stream:
from litellm.llms.base_llm.base_model_iterator import MockResponseIterator
return MockResponseIterator(model_response=model_response)
return model_response
def transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
"""
Transform request to Vertex Gemma format.
Uses parent class to create OpenAI-compatible request, then wraps it
in the Vertex Gemma instances format.
"""
# Get the base OpenAI request from parent class
openai_request = super().transform_request(
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
headers=headers,
)
# Remove params not needed/supported by Vertex Gemma
openai_request.pop("model", None)
openai_request.pop(
"stream", None
) # Streaming not supported, will be faked client-side
openai_request.pop("stream_options", None) # Stream options not supported
# Wrap in Vertex Gemma format
return {
"instances": [
{
"@requestFormat": "chatCompletions",
**openai_request,
}
]
}
    def _unwrap_predictions_response(
        self,
        response_json: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Unwrap the Vertex Gemma predictions format to OpenAI format.

        Vertex Gemma wraps the OpenAI-compatible response in a 'predictions' field.
        This method extracts it so the parent class can process it normally.

        Raises:
            BaseLLMException: 422 when the 'predictions' wrapper is absent.
        """
        if "predictions" not in response_json:
            raise BaseLLMException(
                status_code=422,
                message="Invalid response format: missing 'predictions' field",
            )
        return response_json["predictions"]
def completion(
self,
model: str,
messages: list,
api_base: str,
api_key: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
logging_obj: Any,
optional_params: dict,
acompletion: bool,
litellm_params: dict,
logger_fn: Optional[Callable] = None,
client: Optional[httpx.Client] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
encoding=None,
custom_llm_provider: str = "vertex_ai",
):
"""
Make completion request to Vertex Gemma endpoint.
Supports both sync and async requests with fake streaming.
"""
if acompletion:
return self._async_completion(
model=model,
messages=messages,
api_base=api_base,
api_key=api_key,
model_response=model_response,
print_verbose=print_verbose,
logging_obj=logging_obj,
optional_params=optional_params,
litellm_params=litellm_params,
timeout=timeout,
encoding=encoding,
)
else:
return self._sync_completion(
model=model,
messages=messages,
api_base=api_base,
api_key=api_key,
model_response=model_response,
print_verbose=print_verbose,
logging_obj=logging_obj,
optional_params=optional_params,
litellm_params=litellm_params,
timeout=timeout,
encoding=encoding,
)
    def _sync_completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        api_key: str,
        model_response: ModelResponse,
        print_verbose: Callable,
        logging_obj: Any,
        optional_params: dict,
        litellm_params: dict,
        timeout: Optional[Union[float, httpx.Timeout]],
        encoding: Any,
    ):
        """Synchronous completion request

        POSTs the wrapped chatCompletions payload to the deployment endpoint,
        unwraps the 'predictions' envelope, and converts the result into
        ``model_response``. Returns a mock stream iterator instead when
        stream=True was requested.

        Raises:
            BaseLLMException: on any non-200 HTTP status.
        """
        from litellm.llms.custom_httpx.http_handler import HTTPHandler
        from litellm.utils import convert_to_model_response_object

        # Check if streaming is requested (will be faked)
        stream = optional_params.get("stream", False)

        # Transform the request using parent class methods
        # (copy optional_params so the caller's dict is not mutated)
        request_data = self.transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params.copy(),
            litellm_params=litellm_params,
            headers={},
        )

        # Set up headers
        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

        # Log the request
        logging_obj.pre_call(
            input=messages,
            api_key=api_key,
            additional_args={
                "complete_input_dict": request_data,
                "api_base": api_base,
            },
        )

        # Make the HTTP request
        http_handler = HTTPHandler(concurrent_limit=1)
        response = http_handler.post(
            url=api_base,
            headers=headers,
            json=request_data,
            timeout=timeout,
        )

        if response.status_code != 200:
            raise BaseLLMException(
                status_code=response.status_code,
                message=f"Request failed: {response.text}",
            )

        response_json = response.json()

        # Unwrap predictions to get OpenAI-compatible response
        openai_response = self._unwrap_predictions_response(response_json)

        # Use litellm's standard response converter
        model_response = cast(
            ModelResponse,
            convert_to_model_response_object(
                response_object=openai_response,
                model_response_object=model_response,
                _response_headers={},
            ),
        )

        # Ensure model is set correctly
        model_response.model = model

        # Log the response
        logging_obj.post_call(
            input=messages,
            api_key=api_key,
            original_response=response_json,
            additional_args={"complete_input_dict": request_data},
        )

        # Return fake stream iterator if streaming was requested
        return self._handle_fake_stream_response(
            model_response=model_response, stream=stream
        )
async def _async_completion(
self,
model: str,
messages: list,
api_base: str,
api_key: str,
model_response: ModelResponse,
print_verbose: Callable,
logging_obj: Any,
optional_params: dict,
litellm_params: dict,
timeout: Optional[Union[float, httpx.Timeout]],
encoding: Any,
):
"""Asynchronous completion request"""
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
from litellm.types.utils import LlmProviders
from litellm.utils import convert_to_model_response_object
# Check if streaming is requested (will be faked)
stream = optional_params.get("stream", False)
# Transform the request using parent class async methods
request_data = await self.async_transform_request(
model=model,
messages=messages,
optional_params=optional_params.copy(),
litellm_params=litellm_params,
headers={},
)
# Set up headers
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
# Log the request
logging_obj.pre_call(
input=messages,
api_key=api_key,
additional_args={
"complete_input_dict": request_data,
"api_base": api_base,
},
)
# Make the HTTP request
http_handler = get_async_httpx_client(
llm_provider=LlmProviders.VERTEX_AI,
)
response = await http_handler.post(
url=api_base,
headers=headers,
json=request_data,
timeout=timeout,
)
if response.status_code != 200:
raise BaseLLMException(
status_code=response.status_code,
message=f"Request failed: {response.text}",
)
response_json = response.json()
# Unwrap predictions to get OpenAI-compatible response
openai_response = self._unwrap_predictions_response(response_json)
# Use litellm's standard response converter
model_response = cast(
ModelResponse,
convert_to_model_response_object(
response_object=openai_response,
model_response_object=model_response,
_response_headers={},
),
)
# Ensure model is set correctly
model_response.model = model
# Log the response
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=response_json,
additional_args={"complete_input_dict": request_data},
)
# Return fake stream iterator if streaming was requested
return self._handle_fake_stream_response(
model_response=model_response, stream=stream
)

View File

@@ -0,0 +1,804 @@
"""
Base Vertex, Google AI Studio LLM Class
Handles Authentication and generating request urls for Vertex AI and Google AI Studio
"""
import json
import os
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple
import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.asyncify import asyncify
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES, VertexPartnerProvider
from .common_utils import (
_get_gemini_url,
_get_vertex_url,
all_gemini_url_modes,
get_vertex_base_model_name,
get_vertex_base_url,
)
GOOGLE_IMPORT_ERROR_MESSAGE = (
"Google Cloud SDK not found. Install it with: pip install 'litellm[google]' "
"or pip install google-cloud-aiplatform"
)
if TYPE_CHECKING:
from google.auth.credentials import Credentials as GoogleCredentialsObject
else:
GoogleCredentialsObject = Any
class VertexBase:
    """
    Base class shared by Vertex AI and Google AI Studio integrations.

    Responsibilities:
    - Load and cache Google Cloud credentials (service account, workload
      identity federation incl. AWS, authorized-user, application-default).
    - Resolve the project id and region to route requests to.
    - Build request URLs, including custom proxy / PSC endpoint handling.
    """

    def __init__(self) -> None:
        super().__init__()
        self.access_token: Optional[str] = None
        self.refresh_token: Optional[str] = None
        self._credentials: Optional[GoogleCredentialsObject] = None
        # Cache: (credentials, project_id) -> (loaded credentials, resolved project_id)
        self._credentials_project_mapping: Dict[
            Tuple[Optional[VERTEX_CREDENTIALS_TYPES], Optional[str]],
            Tuple[GoogleCredentialsObject, str],
        ] = {}
        self.project_id: Optional[str] = None
        self.async_handler: Optional[AsyncHTTPHandler] = None

    def get_vertex_region(self, vertex_region: Optional[str], model: str) -> str:
        """
        Resolve the region for a request, honoring per-model
        ``supported_regions`` from ``litellm.model_cost``.

        Falls back to ``"us-central1"`` when nothing else is known.
        """
        import litellm

        # Try to get supported_regions directly from model_cost.
        # Check both with and without the vertex_ai/ prefix.
        model_key = (
            f"vertex_ai/{model}" if not model.startswith("vertex_ai/") else model
        )
        model_info = litellm.model_cost.get(model_key, {})
        supported_regions = model_info.get("supported_regions")
        if supported_regions and len(supported_regions) > 0:
            # If user didn't specify region, use the first supported region
            if vertex_region is None:
                return supported_regions[0]
            # If user specified a region not supported by this model, override it
            if vertex_region not in supported_regions:
                verbose_logger.warning(
                    "Vertex AI model '%s' does not support region '%s' "
                    "(supported: %s). Routing to '%s'.",
                    model,
                    vertex_region,
                    supported_regions,
                    supported_regions[0],
                )
                return supported_regions[0]
            return vertex_region
        return vertex_region or "us-central1"

    def load_auth(
        self, credentials: Optional[VERTEX_CREDENTIALS_TYPES], project_id: Optional[str]
    ) -> Tuple[Any, str]:
        """
        Load Google credentials and resolve the project id.

        ``credentials`` may be a filesystem path to a JSON key file, a JSON
        string, or an already-parsed dict. When omitted, application-default
        credentials are used.

        Returns:
            Tuple of (credentials_object, project_id)

        Raises:
            ValueError: on invalid credential type or unresolvable project_id.
            TypeError: if the resolved project_id is not a string.
        """
        if credentials is not None:
            if isinstance(credentials, str):
                verbose_logger.debug(
                    "Vertex: Loading vertex credentials from %s", credentials
                )
                verbose_logger.debug(
                    "Vertex: checking if credentials is a valid path, os.path.exists(%s)=%s, current dir %s",
                    credentials,
                    os.path.exists(credentials),
                    os.getcwd(),
                )
                try:
                    if os.path.exists(credentials):
                        # BUGFIX: use a context manager so the file handle is
                        # closed (previously json.load(open(...)) leaked it).
                        with open(credentials) as credentials_file:
                            json_obj = json.load(credentials_file)
                    else:
                        json_obj = json.loads(credentials)
                except Exception:
                    raise Exception(
                        "Unable to load vertex credentials from environment. Got={}".format(
                            credentials
                        )
                    )
            elif isinstance(credentials, dict):
                json_obj = credentials
            else:
                raise ValueError(
                    "Invalid credentials type: {}".format(type(credentials))
                )

            # Check if the JSON object contains Workload Identity Federation configuration
            if "type" in json_obj and json_obj["type"] == "external_account":
                # If environment_id key contains "aws" value it corresponds to an AWS config file
                credential_source = json_obj.get("credential_source", {})
                environment_id = (
                    credential_source.get("environment_id", "")
                    if isinstance(credential_source, dict)
                    else ""
                )
                if isinstance(environment_id, str) and "aws" in environment_id:
                    # Check if explicit AWS params are in the JSON (bypasses metadata)
                    from litellm.llms.vertex_ai.vertex_ai_aws_wif import (
                        VertexAIAwsWifAuth,
                    )

                    aws_params = VertexAIAwsWifAuth.extract_aws_params(json_obj)
                    if aws_params:
                        creds = VertexAIAwsWifAuth.credentials_from_explicit_aws(
                            json_obj,
                            aws_params=aws_params,
                            scopes=["https://www.googleapis.com/auth/cloud-platform"],
                        )
                    else:
                        creds = self._credentials_from_identity_pool_with_aws(
                            json_obj,
                            scopes=["https://www.googleapis.com/auth/cloud-platform"],
                        )
                else:
                    creds = self._credentials_from_identity_pool(
                        json_obj,
                        scopes=["https://www.googleapis.com/auth/cloud-platform"],
                    )
            # Check if the JSON object contains Authorized User configuration (via gcloud auth application-default login)
            elif "type" in json_obj and json_obj["type"] == "authorized_user":
                creds = self._credentials_from_authorized_user(
                    json_obj,
                    scopes=["https://www.googleapis.com/auth/cloud-platform"],
                )
                if project_id is None:
                    project_id = (
                        creds.quota_project_id
                    )  # authorized user credentials don't have a project_id, only quota_project_id
            else:
                creds = self._credentials_from_service_account(
                    json_obj,
                    scopes=["https://www.googleapis.com/auth/cloud-platform"],
                )
                if project_id is None:
                    project_id = getattr(creds, "project_id", None)
        else:
            creds, creds_project_id = self._credentials_from_default_auth(
                scopes=["https://www.googleapis.com/auth/cloud-platform"]
            )
            if project_id is None:
                project_id = creds_project_id

        self.refresh_auth(creds)

        if not project_id:
            raise ValueError("Could not resolve project_id")

        if not isinstance(project_id, str):
            raise TypeError(
                f"Expected project_id to be a str but got {type(project_id)}"
            )

        return creds, project_id

    # Google Auth Helpers -- extracted for mocking purposes in tests
    def _credentials_from_identity_pool(self, json_obj, scopes):
        """Build workload-identity-federation credentials from a parsed config."""
        try:
            from google.auth import identity_pool
        except ImportError:
            raise ImportError(GOOGLE_IMPORT_ERROR_MESSAGE)

        creds = identity_pool.Credentials.from_info(json_obj)
        if scopes and hasattr(creds, "requires_scopes") and creds.requires_scopes:
            creds = creds.with_scopes(scopes)
        return creds

    def _credentials_from_identity_pool_with_aws(self, json_obj, scopes):
        """Build AWS workload-identity-federation credentials from a parsed config."""
        try:
            from google.auth import aws
        except ImportError:
            raise ImportError(GOOGLE_IMPORT_ERROR_MESSAGE)

        creds = aws.Credentials.from_info(json_obj)
        if scopes and hasattr(creds, "requires_scopes") and creds.requires_scopes:
            creds = creds.with_scopes(scopes)
        return creds

    def _credentials_from_authorized_user(self, json_obj, scopes):
        """Build credentials from `gcloud auth application-default login` output."""
        try:
            import google.oauth2.credentials
        except ImportError:
            raise ImportError(GOOGLE_IMPORT_ERROR_MESSAGE)

        return google.oauth2.credentials.Credentials.from_authorized_user_info(
            json_obj, scopes=scopes
        )

    def _credentials_from_service_account(self, json_obj, scopes):
        """Build service-account credentials from a parsed key file."""
        try:
            import google.oauth2.service_account
        except ImportError:
            raise ImportError(GOOGLE_IMPORT_ERROR_MESSAGE)

        return google.oauth2.service_account.Credentials.from_service_account_info(
            json_obj, scopes=scopes
        )

    def _credentials_from_default_auth(self, scopes):
        """Resolve application-default credentials; returns (creds, project_id)."""
        try:
            import google.auth as google_auth
        except ImportError:
            raise ImportError(GOOGLE_IMPORT_ERROR_MESSAGE)

        return google_auth.default(scopes=scopes)

    def get_default_vertex_location(self) -> str:
        """Default region used when no location is provided."""
        return "us-central1"

    def get_api_base(
        self, api_base: Optional[str], vertex_location: Optional[str]
    ) -> str:
        """Return the caller-supplied api_base, or the regional Vertex base url."""
        if api_base:
            return api_base
        return get_vertex_base_url(
            vertex_location or self.get_default_vertex_location()
        )

    @staticmethod
    def create_vertex_url(
        vertex_location: str,
        vertex_project: str,
        partner: VertexPartnerProvider,
        stream: Optional[bool],
        model: str,
        api_base: Optional[str] = None,
    ) -> str:
        """Return the base url for the vertex partner models"""
        if api_base is None:
            api_base = get_vertex_base_url(vertex_location)

        # NOTE(review): falls through (returns None) for an unknown partner —
        # callers only pass members of VertexPartnerProvider today.
        if partner == VertexPartnerProvider.llama:
            return f"{api_base}/v1/projects/{vertex_project}/locations/{vertex_location}/endpoints/openapi/chat/completions"
        elif partner == VertexPartnerProvider.mistralai:
            if stream:
                return f"{api_base}/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:streamRawPredict"
            else:
                return f"{api_base}/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:rawPredict"
        elif partner == VertexPartnerProvider.ai21:
            if stream:
                return f"{api_base}/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:streamRawPredict"
            else:
                return f"{api_base}/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:rawPredict"
        elif partner == VertexPartnerProvider.claude:
            if stream:
                return f"{api_base}/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:streamRawPredict"
            else:
                return f"{api_base}/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:rawPredict"

    def get_complete_vertex_url(
        self,
        custom_api_base: Optional[str],
        vertex_location: Optional[str],
        vertex_project: Optional[str],
        project_id: str,
        partner: VertexPartnerProvider,
        stream: Optional[bool],
        model: str,
    ) -> str:
        """Build the full partner-model url, applying custom proxy overrides."""
        # Use get_vertex_region to handle global-only models
        resolved_location = self.get_vertex_region(vertex_location, model)
        api_base = self.get_api_base(
            api_base=custom_api_base, vertex_location=resolved_location
        )

        default_api_base = VertexBase.create_vertex_url(
            vertex_location=resolved_location,
            vertex_project=vertex_project or project_id,
            partner=partner,
            stream=stream,
            model=model,
            api_base=api_base,
        )

        # The endpoint is the suffix after the last ":" (e.g. "rawPredict").
        if len(default_api_base.split(":")) > 1:
            endpoint = default_api_base.split(":")[-1]
        else:
            endpoint = ""

        _, api_base = self._check_custom_proxy(
            api_base=custom_api_base,
            custom_llm_provider="vertex_ai",
            gemini_api_key=None,
            endpoint=endpoint,
            stream=stream,
            auth_header=None,
            url=default_api_base,
            model=model,
            vertex_project=vertex_project or project_id,
            vertex_location=resolved_location,
            vertex_api_version="v1",  # Partner models typically use v1
        )
        return api_base

    def refresh_auth(self, credentials: Any) -> None:
        """Refresh a Google credentials object in-place (fetches a new token)."""
        try:
            from google.auth.transport.requests import (
                Request,  # type: ignore[import-untyped]
            )
        except ImportError:
            raise ImportError(GOOGLE_IMPORT_ERROR_MESSAGE)

        credentials.refresh(Request())

    def _ensure_access_token(
        self,
        credentials: Optional[VERTEX_CREDENTIALS_TYPES],
        project_id: Optional[str],
        custom_llm_provider: Literal[
            "vertex_ai", "vertex_ai_beta", "gemini"
        ],  # if it's vertex_ai or gemini (google ai studio)
    ) -> Tuple[str, str]:
        """
        Returns auth token and project id
        """
        if custom_llm_provider == "gemini":
            # Google AI Studio authenticates via an API key header, not OAuth.
            return "", ""
        else:
            return self.get_access_token(
                credentials=credentials,
                project_id=project_id,
            )

    def is_using_v1beta1_features(self, optional_params: dict) -> bool:
        """
        use this helper to decide if request should be sent to v1 or v1beta1

        Returns true if any beta feature is enabled
        Returns false in all other cases
        """
        return False

    def _check_custom_proxy(
        self,
        api_base: Optional[str],
        custom_llm_provider: str,
        gemini_api_key: Optional[str],
        endpoint: str,
        stream: Optional[bool],
        auth_header: Optional[str],
        url: str,
        model: Optional[str] = None,
        vertex_project: Optional[str] = None,
        vertex_location: Optional[str] = None,
        vertex_api_version: Optional[Literal["v1", "v1beta1"]] = None,
        use_psc_endpoint_format: bool = False,
    ) -> Tuple[Optional[str], str]:
        """
        for cloudflare ai gateway - https://github.com/BerriAI/litellm/issues/4317

        Handles custom api_base for:
        1. Gemini (Google AI Studio) - constructs /models/{model}:{endpoint}
        2. Vertex AI with standard proxies - constructs {api_base}:{endpoint}
        3. Vertex AI with PSC endpoints - constructs full path structure
           {api_base}/v1/projects/{project}/locations/{location}/endpoints/{model}:{endpoint}
           (only when use_psc_endpoint_format=True)

        Args:
            use_psc_endpoint_format: If True, constructs PSC endpoint URL format.
                If False (default), uses api_base as-is and appends :{endpoint}

        ## Returns
        - (auth_header, url) - Tuple[Optional[str], str]
        """
        if api_base:
            if custom_llm_provider == "gemini":
                # For Gemini (Google AI Studio), construct the full path like other providers
                if model is None:
                    raise ValueError(
                        "Model parameter is required for Gemini custom API base URLs"
                    )
                url = "{}/models/{}:{}".format(api_base, model, endpoint)
                if gemini_api_key is None:
                    raise ValueError(
                        "Missing gemini_api_key, please set `GEMINI_API_KEY`"
                    )
                # Guaranteed non-None after the check above (previously this was
                # wrapped in a redundant `is not None` re-check).
                auth_header = {"x-goog-api-key": gemini_api_key}  # type: ignore[assignment]
            else:
                # For Vertex AI
                if use_psc_endpoint_format:
                    # User explicitly specified PSC endpoint format
                    # Construct full PSC/custom endpoint URL
                    if not (vertex_project and vertex_location and model):
                        raise ValueError(
                            "vertex_project, vertex_location, and model are required when use_psc_endpoint_format=True"
                        )
                    # Strip routing prefixes (bge/, gemma/, etc.) for endpoint URL construction
                    model_for_url = get_vertex_base_model_name(model=model)
                    # Format: {api_base}/v1/projects/{project}/locations/{location}/endpoints/{model}:{endpoint}
                    version = vertex_api_version or "v1"
                    url = "{}/{}/projects/{}/locations/{}/endpoints/{}:{}".format(
                        api_base.rstrip("/"),
                        version,
                        vertex_project,
                        vertex_location,
                        model_for_url,
                        endpoint,
                    )
                else:
                    # Standard proxy format: append :{endpoint} to api_base as-is
                    url = "{}:{}".format(api_base, endpoint)
            if stream is True:
                url = url + "?alt=sse"
        return auth_header, url

    def _get_token_and_url(
        self,
        model: str,
        auth_header: Optional[str],
        gemini_api_key: Optional[str],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
        stream: Optional[bool],
        custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
        api_base: Optional[str],
        should_use_v1beta1_features: Optional[bool] = False,
        mode: all_gemini_url_modes = "chat",
        use_psc_endpoint_format: bool = False,
    ) -> Tuple[Optional[str], str]:
        """
        Internal function. Returns the token and url for the call.

        Handles logic if it's google ai studio vs. vertex ai.

        Returns
            token, url
        """
        version: Optional[Literal["v1beta1", "v1"]] = None
        if custom_llm_provider == "gemini":
            url, endpoint = _get_gemini_url(
                mode=mode,
                model=model,
                stream=stream,
                gemini_api_key=gemini_api_key,
            )
            auth_header = None  # this field is not used for gemini
        else:
            vertex_location = self.get_vertex_region(
                vertex_region=vertex_location,
                model=model,
            )

            ### SET RUNTIME ENDPOINT ###
            version = "v1beta1" if should_use_v1beta1_features is True else "v1"
            url, endpoint = _get_vertex_url(
                mode=mode,
                model=model,
                stream=stream,
                vertex_project=vertex_project,
                vertex_location=vertex_location,
                vertex_api_version=version,
            )

        return self._check_custom_proxy(
            api_base=api_base,
            auth_header=auth_header,
            custom_llm_provider=custom_llm_provider,
            gemini_api_key=gemini_api_key,
            endpoint=endpoint,
            stream=stream,
            url=url,
            model=model,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_api_version=version,
            use_psc_endpoint_format=use_psc_endpoint_format,
        )

    def _handle_reauthentication(
        self,
        credentials: Optional[VERTEX_CREDENTIALS_TYPES],
        project_id: Optional[str],
        credential_cache_key: Tuple,
        error: Exception,
    ) -> Tuple[str, str]:
        """
        Handle reauthentication when credentials refresh fails.

        This method clears the cached credentials and attempts to reload them once.
        It should only be called when "Reauthentication is needed" error occurs.

        Args:
            credentials: The original credentials
            project_id: The project ID
            credential_cache_key: The cache key to clear
            error: The original error that triggered reauthentication

        Returns:
            Tuple of (access_token, project_id)

        Raises:
            The original error if reauthentication fails
        """
        verbose_logger.debug(
            f"Handling reauthentication for project_id: {project_id}. "
            f"Clearing cache and retrying once."
        )
        # Clear the cached credentials
        if credential_cache_key in self._credentials_project_mapping:
            del self._credentials_project_mapping[credential_cache_key]

        # Retry once with _retry_reauth=True to prevent infinite recursion
        try:
            return self.get_access_token(
                credentials=credentials,
                project_id=project_id,
                _retry_reauth=True,
            )
        except Exception as retry_error:
            verbose_logger.error(
                f"Reauthentication retry failed for project_id: {project_id}. "
                f"Original error: {str(error)}. Retry error: {str(retry_error)}"
            )
            # Re-raise the original error for better context
            raise error

    def get_access_token(
        self,
        credentials: Optional[VERTEX_CREDENTIALS_TYPES],
        project_id: Optional[str],
        _retry_reauth: bool = False,
    ) -> Tuple[str, str]:
        """
        Get access token and project id

        1. Check if credentials are already in self._credentials_project_mapping
        2. If not, load credentials and add to self._credentials_project_mapping
        3. Check if loaded credentials have expired
        4. If expired, refresh credentials
        5. Return access token and project id

        Args:
            credentials: The credentials to use for authentication
            project_id: The Google Cloud project ID
            _retry_reauth: Internal flag to prevent infinite recursion during reauthentication

        Returns:
            Tuple of (access_token, project_id)
        """
        # Convert dict credentials to string for caching
        cache_credentials = (
            json.dumps(credentials) if isinstance(credentials, dict) else credentials
        )
        credential_cache_key = (cache_credentials, project_id)
        _credentials: Optional[GoogleCredentialsObject] = None

        verbose_logger.debug(
            f"Checking cached credentials for project_id: {project_id}"
        )
        if credential_cache_key in self._credentials_project_mapping:
            verbose_logger.debug(
                f"Cached credentials found for project_id: {project_id}."
            )
            # Retrieve both credentials and cached project_id
            cached_entry = self._credentials_project_mapping[credential_cache_key]
            verbose_logger.debug("cached_entry: %s", cached_entry)
            if isinstance(cached_entry, tuple):
                _credentials, credential_project_id = cached_entry
            else:
                # Backward compatibility with old cache format
                _credentials = cached_entry
                credential_project_id = _credentials.quota_project_id or getattr(
                    _credentials, "project_id", None
                )
            verbose_logger.debug(
                "Using cached credentials for project_id: %s",
                credential_project_id,
            )
        else:
            verbose_logger.debug(
                f"Credential cache key not found for project_id: {project_id}, loading new credentials"
            )
            try:
                _credentials, credential_project_id = self.load_auth(
                    credentials=credentials, project_id=project_id
                )
            except Exception as e:
                verbose_logger.exception(
                    f"Failed to load vertex credentials. Check to see if credentials containing partial/invalid information. Error: {str(e)}"
                )
                raise e

            if _credentials is None:
                raise ValueError(
                    "Could not resolve credentials - either dynamically or from environment, for project_id: {}".format(
                        project_id
                    )
                )
            # Cache the project_id and credentials from load_auth result (resolved project_id)
            self._credentials_project_mapping[credential_cache_key] = (
                _credentials,
                credential_project_id,
            )

        ## VALIDATE CREDENTIALS
        verbose_logger.debug(f"Validating credentials for project_id: {project_id}")
        if (
            project_id is None
            and credential_project_id is not None
            and isinstance(credential_project_id, str)
        ):
            project_id = credential_project_id
            # Update cache with resolved project_id for future lookups
            resolved_cache_key = (cache_credentials, project_id)
            if resolved_cache_key not in self._credentials_project_mapping:
                self._credentials_project_mapping[resolved_cache_key] = (
                    _credentials,
                    credential_project_id,
                )

        # Check if credentials are None before accessing attributes
        if _credentials is None:
            raise ValueError("Credentials are None after loading")

        if _credentials.expired:
            try:
                verbose_logger.debug(
                    f"Credentials expired, refreshing for project_id: {project_id}"
                )
                self.refresh_auth(_credentials)
                self._credentials_project_mapping[credential_cache_key] = (
                    _credentials,
                    credential_project_id,
                )
            except Exception as e:
                # if refresh fails, it's possible the user has re-authenticated via `gcloud auth application-default login`
                # in this case, we should try to reload the credentials by clearing the cache and retrying
                if "Reauthentication is needed" in str(e) and not _retry_reauth:
                    return self._handle_reauthentication(
                        credentials=credentials,
                        project_id=project_id,
                        credential_cache_key=credential_cache_key,
                        error=e,
                    )
                raise e

        ## VALIDATION STEP
        if _credentials.token is None or not isinstance(_credentials.token, str):
            raise ValueError(
                "Could not resolve credentials token. Got None or non-string token - {}".format(
                    _credentials.token
                )
            )

        if project_id is None:
            raise ValueError("Could not resolve project_id")

        return _credentials.token, project_id

    async def _ensure_access_token_async(
        self,
        credentials: Optional[VERTEX_CREDENTIALS_TYPES],
        project_id: Optional[str],
        custom_llm_provider: Literal[
            "vertex_ai", "vertex_ai_beta", "gemini"
        ],  # if it's vertex_ai or gemini (google ai studio)
    ) -> Tuple[str, str]:
        """
        Async version of _ensure_access_token
        """
        if custom_llm_provider == "gemini":
            return "", ""
        # get_access_token is blocking (may hit the network), so run it in a
        # worker thread. (Removed a redundant try/except that only re-raised.)
        return await asyncify(self.get_access_token)(
            credentials=credentials,
            project_id=project_id,
        )

    def set_headers(
        self, auth_header: Optional[str], extra_headers: Optional[dict]
    ) -> dict:
        """Build request headers, adding Bearer auth and any extras provided."""
        headers = {
            "Content-Type": "application/json",
        }
        if auth_header is not None:
            headers["Authorization"] = f"Bearer {auth_header}"
        if extra_headers is not None:
            headers.update(extra_headers)
        return headers

    @staticmethod
    def get_vertex_ai_project(litellm_params: dict) -> Optional[str]:
        """Pop the vertex project from params (mutates dict); falls back to globals/env."""
        return (
            litellm_params.pop("vertex_project", None)
            or litellm_params.pop("vertex_ai_project", None)
            or litellm.vertex_project
            or get_secret_str("VERTEXAI_PROJECT")
        )

    @staticmethod
    def get_vertex_ai_credentials(litellm_params: dict) -> Optional[str]:
        """Pop the vertex credentials from params (mutates dict); falls back to env."""
        return (
            litellm_params.pop("vertex_credentials", None)
            or litellm_params.pop("vertex_ai_credentials", None)
            or get_secret_str("VERTEXAI_CREDENTIALS")
        )

    @staticmethod
    def get_vertex_ai_location(litellm_params: dict) -> Optional[str]:
        """Pop the vertex location from params (mutates dict); falls back to globals/env."""
        return (
            litellm_params.pop("vertex_location", None)
            or litellm_params.pop("vertex_ai_location", None)
            or litellm.vertex_location
            or get_secret_str("VERTEXAI_LOCATION")
            or get_secret_str("VERTEX_LOCATION")
        )

    @staticmethod
    def safe_get_vertex_ai_project(litellm_params: dict) -> Optional[str]:
        """
        Safely get Vertex AI project without mutating the litellm_params dict.

        Unlike get_vertex_ai_project(), this does NOT pop values from the dict,
        making it safe to call multiple times with the same litellm_params.

        Args:
            litellm_params: Dictionary containing Vertex AI parameters

        Returns:
            Vertex AI project ID or None
        """
        return (
            litellm_params.get("vertex_project")
            or litellm_params.get("vertex_ai_project")
            or litellm.vertex_project
            or get_secret_str("VERTEXAI_PROJECT")
        )

    @staticmethod
    def safe_get_vertex_ai_credentials(litellm_params: dict) -> Optional[str]:
        """
        Safely get Vertex AI credentials without mutating the litellm_params dict.

        Unlike get_vertex_ai_credentials(), this does NOT pop values from the dict,
        making it safe to call multiple times with the same litellm_params.

        Args:
            litellm_params: Dictionary containing Vertex AI parameters

        Returns:
            Vertex AI credentials or None
        """
        return (
            litellm_params.get("vertex_credentials")
            or litellm_params.get("vertex_ai_credentials")
            or get_secret_str("VERTEXAI_CREDENTIALS")
        )

    @staticmethod
    def safe_get_vertex_ai_location(litellm_params: dict) -> Optional[str]:
        """
        Safely get Vertex AI location without mutating the litellm_params dict.

        Unlike get_vertex_ai_location(), this does NOT pop values from the dict,
        making it safe to call multiple times with the same litellm_params.

        Args:
            litellm_params: Dictionary containing Vertex AI parameters

        Returns:
            Vertex AI location/region or None
        """
        return (
            litellm_params.get("vertex_location")
            or litellm_params.get("vertex_ai_location")
            or litellm.vertex_location
            or get_secret_str("VERTEXAI_LOCATION")
            or get_secret_str("VERTEX_LOCATION")
        )

View File

@@ -0,0 +1,153 @@
"""
API Handler for calling Vertex AI Model Garden Models
Most Vertex Model Garden Models are OpenAI compatible - so this handler calls `openai_like_chat_completions`
Usage:
response = litellm.completion(
model="vertex_ai/openai/5464397967697903616",
messages=[{"role": "user", "content": "Hello, how are you?"}],
)
Sent to this route when `model` is in the format `vertex_ai/openai/{MODEL_ID}`
Vertex Documentation for using the OpenAI /chat/completions endpoint: https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_pytorch_llama3_deployment.ipynb
"""
from typing import Callable, Optional, Union
import httpx # type: ignore
from litellm.llms.vertex_ai.common_utils import get_vertex_base_url
from litellm.utils import ModelResponse
from ..common_utils import VertexAIError, get_vertex_base_model_name
from ..vertex_llm_base import VertexBase
def create_vertex_url(
    vertex_location: str,
    vertex_project: str,
    stream: Optional[bool],
    model: str,
    api_base: Optional[str] = None,
) -> str:
    """
    Return the base url for a Vertex AI Model Garden endpoint.

    Args:
        vertex_location: GCP region, e.g. "us-central1".
        vertex_project: GCP project id.
        stream: Unused; kept for signature parity with other url builders.
        model: Numeric Vertex endpoint id (e.g. "5464397967697903616").
        api_base: Optional custom base url. BUGFIX: this parameter was
            previously accepted but ignored; it now overrides the regional
            default when provided.

    Returns:
        Base url of the form
        {base}/v1beta1/projects/{project}/locations/{location}/endpoints/{model}
    """
    base_url = api_base or get_vertex_base_url(vertex_location)
    return f"{base_url}/v1beta1/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}"
class VertexAIModelGardenModels(VertexBase):
    """
    Handler for Vertex AI Model Garden models exposed through OpenAI-compatible
    endpoints (routed here when `model` is `vertex_ai/openai/{ENDPOINT_ID}`).
    """

    def __init__(self) -> None:
        pass

    def completion(
        self,
        model: str,
        messages: list,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        logging_obj,
        api_base: Optional[str],
        optional_params: dict,
        custom_prompt_dict: dict,
        headers: Optional[dict],
        timeout: Union[float, httpx.Timeout],
        litellm_params: dict,
        vertex_project=None,
        vertex_location=None,
        vertex_credentials=None,
        logger_fn=None,
        acompletion: bool = False,
        client=None,
    ):
        """
        Handles calling Vertex AI Model Garden Models in OpenAI compatible format

        Sent to this route when `model` is in the format `vertex_ai/openai/{MODEL_ID}`

        Raises:
            VertexAIError: 400 when the vertexai SDK is missing/outdated,
                500 for any failure while building or issuing the request.
        """
        try:
            import vertexai

            from litellm.llms.openai_like.chat.handler import OpenAILikeChatHandler
            from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
                VertexLLM,
            )
        except Exception as e:
            raise VertexAIError(
                status_code=400,
                message=f"""vertexai import failed please run `pip install -U "google-cloud-aiplatform>=1.38"`. Got error: {e}""",
            )

        # BUGFIX: this guard previously used `or`, which (a) never fired when
        # `vertexai.preview` existed — even without `language_models` — and
        # (b) raised a raw AttributeError (from evaluating `vertexai.preview`)
        # instead of the intended VertexAIError when `preview` was missing.
        # Both attributes must be present, so use `and`.
        if not (
            hasattr(vertexai, "preview")
            and hasattr(vertexai.preview, "language_models")
        ):
            raise VertexAIError(
                status_code=400,
                message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""",
            )
        try:
            # Strip routing prefixes (e.g. "openai/") to get the endpoint id.
            model = get_vertex_base_model_name(model=model)
            vertex_httpx_logic = VertexLLM()

            access_token, project_id = vertex_httpx_logic._ensure_access_token(
                credentials=vertex_credentials,
                project_id=vertex_project,
                custom_llm_provider="vertex_ai",
            )

            openai_like_chat_completions = OpenAILikeChatHandler()

            ## CONSTRUCT API BASE
            stream: bool = optional_params.get("stream", False) or False
            optional_params["stream"] = stream
            default_api_base = create_vertex_url(
                vertex_location=vertex_location or "us-central1",
                vertex_project=vertex_project or project_id,
                stream=stream,
                model=model,
            )

            # The endpoint suffix (after the last ":") feeds the proxy check.
            if len(default_api_base.split(":")) > 1:
                endpoint = default_api_base.split(":")[-1]
            else:
                endpoint = ""

            _, api_base = self._check_custom_proxy(
                api_base=api_base,
                custom_llm_provider="vertex_ai",
                gemini_api_key=None,
                endpoint=endpoint,
                stream=stream,
                auth_header=None,
                url=default_api_base,
                model=model,
                vertex_project=vertex_project or project_id,
                vertex_location=vertex_location or "us-central1",
                vertex_api_version="v1beta1",
            )
            # Model id is already baked into the url; send an empty model name.
            model = ""
            return openai_like_chat_completions.completion(
                model=model,
                messages=messages,
                api_base=api_base,
                api_key=access_token,
                custom_prompt_dict=custom_prompt_dict,
                model_response=model_response,
                print_verbose=print_verbose,
                logging_obj=logging_obj,
                optional_params=optional_params,
                acompletion=acompletion,
                litellm_params=litellm_params,
                logger_fn=logger_fn,
                client=client,
                timeout=timeout,
                encoding=encoding,
                custom_llm_provider="vertex_ai",
            )
        except Exception as e:
            raise VertexAIError(status_code=500, message=str(e))

View File

@@ -0,0 +1,9 @@
"""
Vertex AI Video Generation Module
This module provides support for Vertex AI's Veo video generation API.
"""
from .transformation import VertexAIVideoConfig
__all__ = ["VertexAIVideoConfig"]

View File

@@ -0,0 +1,636 @@
"""
Vertex AI Video Generation Transformation
Handles transformation of requests/responses for Vertex AI's Veo video generation API.
Based on: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/model-reference/veo-video-generation
"""
import base64
import time
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union, cast
import httpx
from httpx._types import RequestFiles
from litellm.constants import DEFAULT_GOOGLE_VIDEO_DURATION_SECONDS
from litellm.images.utils import ImageEditRequestUtils
from litellm.llms.base_llm.videos.transformation import BaseVideoConfig
from litellm.llms.vertex_ai.common_utils import (
_convert_vertex_datetime_to_openai_datetime,
get_vertex_base_url,
)
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
from litellm.types.router import GenericLiteLLMParams
from litellm.types.videos.main import VideoCreateOptionalRequestParams, VideoObject
from litellm.types.videos.utils import (
encode_video_id_with_provider,
extract_original_video_id,
)
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
from litellm.llms.base_llm.chat.transformation import (
BaseLLMException as _BaseLLMException,
)
LiteLLMLoggingObj = _LiteLLMLoggingObj
BaseLLMException = _BaseLLMException
else:
LiteLLMLoggingObj = Any
BaseLLMException = Any
def _convert_image_to_vertex_format(image_file) -> Dict[str, str]:
    """
    Encode an image file for inclusion in a Vertex AI request payload.

    Detects the MIME type, rewinds the stream, and base64-encodes the full
    contents.

    Args:
        image_file: File-like object opened in binary mode
            (e.g. ``open("path", "rb")``).

    Returns:
        Dict with ``bytesBase64Encoded`` and ``mimeType`` keys.
    """
    detected_mime = ImageEditRequestUtils.get_image_content_type(image_file)
    # Rewind in case MIME detection consumed bytes from the stream.
    if hasattr(image_file, "seek"):
        image_file.seek(0)
    raw_bytes = image_file.read()
    encoded = base64.b64encode(raw_bytes).decode("utf-8")
    return {"bytesBase64Encoded": encoded, "mimeType": detected_mime}
class VertexAIVideoConfig(BaseVideoConfig, VertexBase):
"""
Configuration class for Vertex AI (Veo) video generation.
Veo uses a long-running operation model:
1. POST to :predictLongRunning returns operation name
2. Poll operation using :fetchPredictOperation until done=true
3. Extract video data (base64) from response
"""
    def __init__(self):
        # Initialize both bases explicitly and independently (BaseVideoConfig
        # provides the video-transformation interface, VertexBase the
        # auth/credential helpers) instead of a single cooperative super() call.
        BaseVideoConfig.__init__(self)
        VertexBase.__init__(self)
@staticmethod
def extract_model_from_operation_name(operation_name: str) -> Optional[str]:
"""
Extract the model name from a Vertex AI operation name.
Args:
operation_name: Operation name in format:
projects/PROJECT/locations/LOCATION/publishers/google/models/MODEL/operations/OPERATION_ID
Returns:
Model name (e.g., "veo-2.0-generate-001") or None if extraction fails
"""
parts = operation_name.split("/")
# Model is at index 7 in the operation name format
if len(parts) >= 8:
return parts[7]
return None
def get_supported_openai_params(self, model: str) -> list:
"""
Get the list of supported OpenAI parameters for Veo video generation.
Veo supports minimal parameters compared to OpenAI.
"""
return ["model", "prompt", "input_reference", "seconds", "size"]
def map_openai_params(
self,
video_create_optional_params: VideoCreateOptionalRequestParams,
model: str,
drop_params: bool,
) -> Dict[str, Any]:
"""
Map OpenAI-style parameters to Veo format.
Mappings:
- prompt → prompt (in instances)
- input_reference → image (in instances)
- size → aspectRatio (e.g., "1280x720""16:9")
- seconds → durationSeconds (defaults to 4 seconds if not provided)
"""
mapped_params: Dict[str, Any] = {}
# Map input_reference to image (will be processed in transform_video_create_request)
if "input_reference" in video_create_optional_params:
mapped_params["image"] = video_create_optional_params["input_reference"]
elif "image" in video_create_optional_params:
mapped_params["image"] = video_create_optional_params["image"]
# Pass through a provider-specific parameters block if provided directly
if "parameters" in video_create_optional_params:
mapped_params["parameters"] = video_create_optional_params["parameters"]
# Map size to aspectRatio
if "size" in video_create_optional_params:
size = video_create_optional_params["size"]
if size is not None:
aspect_ratio = self._convert_size_to_aspect_ratio(size)
if aspect_ratio:
mapped_params["aspectRatio"] = aspect_ratio
# Map seconds to durationSeconds, default to 4 seconds (matching OpenAI)
if "seconds" in video_create_optional_params:
seconds = video_create_optional_params["seconds"]
try:
duration = int(seconds) if isinstance(seconds, str) else seconds
if duration is not None:
mapped_params["durationSeconds"] = duration
except (ValueError, TypeError):
# If conversion fails, use default
pass
return mapped_params
def _convert_size_to_aspect_ratio(self, size: str) -> Optional[str]:
"""
Convert OpenAI size format to Veo aspectRatio format.
Supported aspect ratios: 9:16 (portrait), 16:9 (landscape)
"""
if not size:
return None
aspect_ratio_map = {
"1280x720": "16:9",
"1920x1080": "16:9",
"720x1280": "9:16",
"1080x1920": "9:16",
}
return aspect_ratio_map.get(size, "16:9")
def validate_environment(
self,
headers: dict,
model: str,
api_key: Optional[str] = None,
litellm_params: Optional[Union[GenericLiteLLMParams, dict]] = None,
) -> dict:
"""
Validate environment and return headers for Vertex AI OCR.
Vertex AI uses Bearer token authentication with access token from credentials.
"""
# Extract Vertex AI parameters using safe helpers from VertexBase
# Use safe_get_* methods that don't mutate litellm_params dict
# Ensure litellm_params is a dict for type checking
params_dict: Dict[str, Any] = (
cast(Dict[str, Any], litellm_params) if litellm_params is not None else {}
)
vertex_project = VertexBase.safe_get_vertex_ai_project(
litellm_params=params_dict
)
vertex_credentials = VertexBase.safe_get_vertex_ai_credentials(
litellm_params=params_dict
)
# Get access token from Vertex credentials
access_token, project_id = self.get_access_token(
credentials=vertex_credentials,
project_id=vertex_project,
)
headers = {
"Authorization": f"Bearer {access_token}",
"Content-Type": "application/json",
**headers,
}
return headers
def get_complete_url(
self,
model: str,
api_base: Optional[str],
litellm_params: dict,
) -> str:
"""
Get the complete URL for Veo video generation.
Returns URL for :predictLongRunning endpoint:
https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/publishers/google/models/MODEL:predictLongRunning
"""
vertex_project = VertexBase.safe_get_vertex_ai_project(litellm_params)
vertex_location = VertexBase.safe_get_vertex_ai_location(litellm_params)
if not vertex_project:
raise ValueError(
"vertex_project is required for Vertex AI video generation. "
"Set it via environment variable VERTEXAI_PROJECT or pass as parameter."
)
# Default to us-central1 if no location specified
vertex_location = vertex_location or "us-central1"
# Extract model name (remove vertex_ai/ prefix if present)
model_name = model.replace("vertex_ai/", "")
# Construct the URL
if api_base:
base_url = api_base.rstrip("/")
else:
base_url = get_vertex_base_url(vertex_location)
url = f"{base_url}/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model_name}"
return url
def transform_video_create_request(
self,
model: str,
prompt: str,
api_base: str,
video_create_optional_request_params: Dict,
litellm_params: GenericLiteLLMParams,
headers: dict,
) -> Tuple[Dict, RequestFiles, str]:
"""
Transform the video creation request for Veo API.
Veo expects:
{
"instances": [
{
"prompt": "A cat playing with a ball of yarn",
"image": {
"bytesBase64Encoded": "...",
"mimeType": "image/jpeg"
}
}
],
"parameters": {
"aspectRatio": "16:9",
"durationSeconds": 8
}
}
"""
# Build instance with prompt
instance_dict: Dict[str, Any] = {"prompt": prompt}
params_copy = video_create_optional_request_params.copy()
# Check if user wants to provide full instance dict
if "instances" in params_copy and isinstance(params_copy["instances"], dict):
# Replace/merge with user-provided instance
instance_dict.update(params_copy["instances"])
params_copy.pop("instances")
elif "image" in params_copy and params_copy["image"] is not None:
image = params_copy["image"]
if isinstance(image, dict):
# Already in Vertex format e.g. {"gcsUri": "gs://..."} or
# {"bytesBase64Encoded": "...", "mimeType": "..."}
image_data = image
elif isinstance(image, str) and image.startswith("gs://"):
# Bare GCS URI — Vertex AI accepts gcsUri natively, no download needed
image_data = {"gcsUri": image}
elif isinstance(image, str):
raise ValueError(
f"Unsupported image value '{image}'. "
"Provide a GCS URI (gs://...), a dict with 'gcsUri' or "
"'bytesBase64Encoded'/'mimeType', or a binary file-like object."
)
else:
# File-like object — encode to base64
image_data = _convert_image_to_vertex_format(image)
instance_dict["image"] = image_data
params_copy.pop("image")
# Extract a nested "parameters" block that map_openai_params may have placed
# inside params_copy (e.g. from provider-specific pass-through). Merging it
# flat prevents the double-nesting bug:
# {"parameters": {"parameters": {...}}} ← wrong
# {"parameters": {...}} ← correct
nested_params = params_copy.pop("parameters", None)
vertex_params: Dict[str, Any] = {}
if isinstance(nested_params, dict):
vertex_params.update(nested_params)
vertex_params.update(params_copy)
# Build request data directly (TypedDict doesn't have model_dump)
request_data: Dict[str, Any] = {"instances": [instance_dict]}
# Only add parameters if there are any
if vertex_params:
request_data["parameters"] = vertex_params
# Append :predictLongRunning endpoint to api_base
url = f"{api_base}:predictLongRunning"
# No files needed - everything is in JSON
return request_data, [], url
def transform_video_create_response(
self,
model: str,
raw_response: httpx.Response,
logging_obj: LiteLLMLoggingObj,
custom_llm_provider: Optional[str] = None,
request_data: Optional[Dict] = None,
) -> VideoObject:
"""
Transform the Veo video creation response.
Veo returns:
{
"name": "projects/PROJECT_ID/locations/LOCATION/publishers/google/models/MODEL/operations/OPERATION_ID"
}
We return this as a VideoObject with:
- id: operation name (used for polling)
- status: "processing"
- usage: includes duration_seconds for cost calculation
"""
response_data = raw_response.json()
operation_name = response_data.get("name")
if not operation_name:
raise ValueError(f"No operation name in Veo response: {response_data}")
if custom_llm_provider:
video_id = encode_video_id_with_provider(
operation_name, custom_llm_provider, model
)
else:
video_id = operation_name
video_obj = VideoObject(
id=video_id, object="video", status="processing", model=model
)
usage_data = {}
if request_data:
parameters = request_data.get("parameters", {})
duration = (
parameters.get("durationSeconds")
or DEFAULT_GOOGLE_VIDEO_DURATION_SECONDS
)
if duration is not None:
try:
usage_data["duration_seconds"] = float(duration)
except (ValueError, TypeError):
pass
video_obj.usage = usage_data
return video_obj
def transform_video_status_retrieve_request(
self,
video_id: str,
api_base: str,
litellm_params: GenericLiteLLMParams,
headers: dict,
) -> Tuple[str, Dict]:
"""
Transform the video status retrieve request for Veo API.
Veo polls operations using :fetchPredictOperation endpoint with POST request.
"""
operation_name = extract_original_video_id(video_id)
model = self.extract_model_from_operation_name(operation_name)
if not model:
raise ValueError(
f"Invalid operation name format: {operation_name}. "
"Expected format: projects/PROJECT/locations/LOCATION/publishers/google/models/MODEL/operations/OPERATION_ID"
)
# Construct the full URL including model ID
# URL format: https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/publishers/google/models/MODEL:fetchPredictOperation
# Strip trailing slashes from api_base and append model
url = f"{api_base.rstrip('/')}/{model}:fetchPredictOperation"
# Request body contains the operation name
params = {"operationName": operation_name}
return url, params
def transform_video_status_retrieve_response(
self,
raw_response: httpx.Response,
logging_obj: LiteLLMLoggingObj,
custom_llm_provider: Optional[str] = None,
) -> VideoObject:
"""
Transform the Veo operation status response.
Veo returns:
{
"name": "projects/.../operations/OPERATION_ID",
"done": false # or true when complete
}
When done=true:
{
"name": "projects/.../operations/OPERATION_ID",
"done": true,
"response": {
"@type": "type.googleapis.com/cloud.ai.large_models.vision.GenerateVideoResponse",
"raiMediaFilteredCount": 0,
"videos": [
{
"bytesBase64Encoded": "...",
"mimeType": "video/mp4"
}
]
}
}
"""
response_data = raw_response.json()
operation_name = response_data.get("name", "")
is_done = response_data.get("done", False)
error_data = response_data.get("error")
# Extract model from operation name
model = self.extract_model_from_operation_name(operation_name)
if custom_llm_provider:
video_id = encode_video_id_with_provider(
operation_name, custom_llm_provider, model
)
else:
video_id = operation_name
# Convert createTime to Unix timestamp
create_time_str = response_data.get("metadata", {}).get("createTime")
if create_time_str:
try:
created_at = _convert_vertex_datetime_to_openai_datetime(
create_time_str
)
except Exception:
created_at = int(time.time())
else:
created_at = int(time.time())
if error_data:
status = "failed"
elif is_done:
status = "completed"
else:
status = "processing"
video_obj = VideoObject(
id=video_id,
object="video",
status=status,
model=model,
created_at=created_at,
error=error_data,
)
return video_obj
def transform_video_content_request(
self,
video_id: str,
api_base: str,
litellm_params: GenericLiteLLMParams,
headers: dict,
variant: Optional[str] = None,
) -> Tuple[str, Dict]:
"""
Transform the video content request for Veo API.
For Veo, we need to:
1. Poll the operation status to ensure it's complete
2. Extract the base64 video data from the response
3. Return it for decoding
Since we need to make an HTTP call here, we'll use the same fetchPredictOperation
approach as status retrieval.
"""
return self.transform_video_status_retrieve_request(
video_id, api_base, litellm_params, headers
)
def transform_video_content_response(
self,
raw_response: httpx.Response,
logging_obj: LiteLLMLoggingObj,
) -> bytes:
"""
Transform the Veo video content download response.
Extracts the base64 encoded video from the response and decodes it to bytes.
"""
response_data = raw_response.json()
if not response_data.get("done", False):
raise ValueError(
"Video generation is not complete yet. "
"Please check status with video_status() before downloading."
)
try:
video_response = response_data.get("response", {})
videos = video_response.get("videos", [])
if not videos or len(videos) == 0:
raise ValueError("No video data found in completed operation")
# Get the first video
video_data = videos[0]
base64_encoded = video_data.get("bytesBase64Encoded")
if not base64_encoded:
raise ValueError("No base64 encoded video data found")
# Decode base64 to bytes
video_bytes = base64.b64decode(base64_encoded)
return video_bytes
except (KeyError, IndexError) as e:
raise ValueError(f"Failed to extract video data: {e}")
def transform_video_remix_request(
self,
video_id: str,
prompt: str,
api_base: str,
litellm_params: GenericLiteLLMParams,
headers: dict,
extra_body: Optional[Dict[str, Any]] = None,
) -> Tuple[str, Dict]:
"""
Video remix is not supported by Veo API.
"""
raise NotImplementedError(
"Video remix is not supported by Vertex AI Veo. "
"Please use video_generation() to create new videos."
)
def transform_video_remix_response(
self,
raw_response: httpx.Response,
logging_obj: LiteLLMLoggingObj,
custom_llm_provider: Optional[str] = None,
) -> VideoObject:
"""Video remix is not supported."""
raise NotImplementedError("Video remix is not supported by Vertex AI Veo.")
def transform_video_list_request(
self,
api_base: str,
litellm_params: GenericLiteLLMParams,
headers: dict,
after: Optional[str] = None,
limit: Optional[int] = None,
order: Optional[str] = None,
extra_query: Optional[Dict[str, Any]] = None,
) -> Tuple[str, Dict]:
"""
Video list is not supported by Veo API.
"""
raise NotImplementedError(
"Video list is not supported by Vertex AI Veo. "
"Use the operations endpoint directly if you need to list operations."
)
def transform_video_list_response(
self,
raw_response: httpx.Response,
logging_obj: LiteLLMLoggingObj,
custom_llm_provider: Optional[str] = None,
) -> Dict[str, str]:
"""Video list is not supported."""
raise NotImplementedError("Video list is not supported by Vertex AI Veo.")
def transform_video_delete_request(
self,
video_id: str,
api_base: str,
litellm_params: GenericLiteLLMParams,
headers: dict,
) -> Tuple[str, Dict]:
"""
Video delete is not supported by Veo API.
"""
raise NotImplementedError(
"Video delete is not supported by Vertex AI Veo. "
"Videos are automatically cleaned up by Google."
)
def transform_video_delete_response(
self,
raw_response: httpx.Response,
logging_obj: LiteLLMLoggingObj,
) -> VideoObject:
"""Video delete is not supported."""
raise NotImplementedError("Video delete is not supported by Vertex AI Veo.")
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
from litellm.llms.vertex_ai.common_utils import VertexAIError
return VertexAIError(
status_code=status_code,
message=error_message,
headers=headers,
)