chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,12 @@
|
||||
"""
|
||||
Vertex AI Agent Engine (Reasoning Engines) Provider
|
||||
|
||||
Supports Vertex AI Reasoning Engines via the :query and :streamQuery endpoints.
|
||||
"""
|
||||
|
||||
from litellm.llms.vertex_ai.agent_engine.transformation import (
|
||||
VertexAgentEngineConfig,
|
||||
VertexAgentEngineError,
|
||||
)
|
||||
|
||||
__all__ = ["VertexAgentEngineConfig", "VertexAgentEngineError"]
|
||||
@@ -0,0 +1,90 @@
|
||||
"""
|
||||
SSE Stream Iterator for Vertex AI Agent Engine.
|
||||
|
||||
Handles Server-Sent Events (SSE) streaming responses from Vertex AI Reasoning Engines.
|
||||
"""
|
||||
|
||||
from typing import Any, Union
|
||||
|
||||
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
|
||||
from litellm.types.llms.openai import ChatCompletionUsageBlock
|
||||
from litellm.types.utils import (
|
||||
Delta,
|
||||
GenericStreamingChunk,
|
||||
ModelResponseStream,
|
||||
StreamingChoices,
|
||||
)
|
||||
|
||||
|
||||
class VertexAgentEngineResponseIterator(BaseModelResponseIterator):
    """
    Iterator for Vertex Agent Engine streaming responses.

    Sync/async iteration is inherited from ``BaseModelResponseIterator``;
    this class only supplies ``chunk_parser``, which converts one decoded
    Agent Engine chunk into an OpenAI-compatible ``ModelResponseStream``.
    """

    def __init__(self, streaming_response: Any, sync_stream: bool) -> None:
        """Forward the raw stream and the sync/async flag to the base iterator."""
        super().__init__(streaming_response=streaming_response, sync_stream=sync_stream)

    def chunk_parser(
        self, chunk: dict
    ) -> Union[GenericStreamingChunk, ModelResponseStream]:
        """
        Parse a Vertex Agent Engine response chunk into ModelResponseStream.

        Expected chunk shape:
        {
            "content": {
                "parts": [{"text": "..."}],
                "role": "model"
            },
            "finish_reason": "STOP",
            "usage_metadata": {
                "prompt_token_count": 100,
                "candidates_token_count": 50,
                "total_token_count": 150
            }
        }

        Args:
            chunk: One decoded JSON object from the stream.

        Returns:
            A ``ModelResponseStream`` with a single choice at index 0. The
            delta carries the text of the first part that has a ``"text"``
            key (``None`` when there is none); ``usage`` is set only when
            the chunk has non-empty ``usage_metadata``.
        """
        # A key that is present with an explicit null would make a plain
        # ``.get(key, {})`` return None, so fall back explicitly.
        content = chunk.get("content")
        parts = (content.get("parts") or []) if isinstance(content, dict) else []

        # Only the first text-bearing part is surfaced.
        text = None
        for part in parts:
            if isinstance(part, dict) and "text" in part:
                text = part["text"]
                break

        # Lower-casing already maps "STOP" to "stop"; ``str()`` keeps a
        # non-string value from raising on ``.lower()``.
        raw_finish_reason = chunk.get("finish_reason")
        finish_reason = str(raw_finish_reason).lower() if raw_finish_reason else None

        usage = None
        usage_metadata = chunk.get("usage_metadata") or {}
        if usage_metadata:
            usage = ChatCompletionUsageBlock(
                prompt_tokens=usage_metadata.get("prompt_token_count", 0),
                completion_tokens=usage_metadata.get("candidates_token_count", 0),
                total_tokens=usage_metadata.get("total_token_count", 0),
            )

        # OpenAI-compatible chunk; the role is only set when there is text.
        return ModelResponseStream(
            choices=[
                StreamingChoices(
                    finish_reason=finish_reason,
                    index=0,
                    delta=Delta(
                        content=text,
                        role="assistant" if text else None,
                    ),
                )
            ],
            usage=usage,
        )
|
||||
@@ -0,0 +1,517 @@
|
||||
"""
|
||||
Transformation for Vertex AI Agent Engine (Reasoning Engines)
|
||||
|
||||
Handles the transformation between LiteLLM's OpenAI-compatible format and
|
||||
Vertex AI Reasoning Engine's API format.
|
||||
|
||||
API Reference:
|
||||
- :query endpoint - for session management (create, get, list, delete)
|
||||
- :streamQuery endpoint - for actual queries (stream_query method)
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm._uuid import uuid
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
convert_content_list_to_str,
|
||||
)
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
|
||||
from litellm.llms.vertex_ai.agent_engine.sse_iterator import (
|
||||
VertexAgentEngineResponseIterator,
|
||||
)
|
||||
from litellm.llms.vertex_ai.common_utils import get_vertex_base_url
|
||||
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import Choices, Message, ModelResponse, Usage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.utils import CustomStreamWrapper
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
HTTPHandler = Any
|
||||
AsyncHTTPHandler = Any
|
||||
CustomStreamWrapper = Any
|
||||
|
||||
|
||||
class VertexAgentEngineError(BaseLLMException):
    """Error type raised for Vertex Agent Engine failures."""

    def __init__(self, status_code: int, message: str):
        """
        Args:
            status_code: Status code associated with the failure.
            message: Human-readable description of the failure.
        """
        # Both attributes are stored on the instance before the base class
        # initialiser runs with the same values.
        self.status_code, self.message = status_code, message
        super().__init__(message=message, status_code=status_code)
|
||||
|
||||
|
||||
class VertexAgentEngineConfig(BaseConfig, VertexBase):
|
||||
"""
|
||||
Configuration for Vertex AI Agent Engine (Reasoning Engines).
|
||||
|
||||
Model format: vertex_ai/agent_engine/<resource_id>
|
||||
Where resource_id is the numeric ID of the reasoning engine.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
    """
    Initialise both parent classes explicitly.

    ``BaseConfig`` receives the keyword arguments unchanged; ``VertexBase``
    is initialised without arguments.
    """
    BaseConfig.__init__(self, **kwargs)
    VertexBase.__init__(self)
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List[str]:
    """
    Return the OpenAI-compatible parameters accepted for Agent Engine.

    Args:
        model: Model identifier (not consulted).

    Returns:
        A list containing only ``"user"``.
    """
    supported_params: List[str] = ["user"]
    return supported_params
|
||||
|
||||
def map_openai_params(
    self,
    non_default_params: dict,
    optional_params: dict,
    model: str,
    drop_params: bool,
) -> dict:
    """
    Translate OpenAI parameters into Agent Engine parameters.

    Only ``user`` is mapped: its value is copied to ``optional_params``
    under ``user_id``. ``model`` and ``drop_params`` are not consulted.

    Args:
        non_default_params: Parameters supplied by the caller.
        optional_params: Dict that is updated in place.
        model: Model identifier (unused).
        drop_params: Drop-unsupported flag (unused).

    Returns:
        The same ``optional_params`` object that was passed in.
    """
    if "user" not in non_default_params:
        return optional_params

    # 'user' doubles as the Agent Engine user_id for session management.
    optional_params["user_id"] = non_default_params["user"]
    return optional_params
|
||||
|
||||
def _parse_model_string(self, model: str) -> Tuple[str, str]:
    """
    Split a model string into its resource path and engine ID.

    A leading ``agent_engine/`` prefix is removed first. Two forms are
    then recognised:

    - ``projects/<p>/locations/<l>/reasoningEngines/<id>``: returned as
      the resource path, with the last ``/``-separated segment as the
      engine ID.
    - anything else: treated as a bare engine ID and returned for both
      values.

    Returns:
        ``(resource_path, engine_id)``
    """
    prefix = "agent_engine/"
    if model.startswith(prefix):
        model = model[len(prefix) :]

    is_full_path = model.startswith("projects/")
    engine_id = model.split("/")[-1] if is_full_path else model
    return model, engine_id
|
||||
|
||||
def get_complete_url(
    self,
    api_base: Optional[str],
    api_key: Optional[str],
    model: str,
    optional_params: dict,
    litellm_params: dict,
    stream: Optional[bool] = None,
) -> str:
    """
    Build the request URL for a Vertex Agent Engine call.

    The ``:streamQuery`` endpoint is used for every request, streaming or
    not. ``api_base``, ``api_key``, ``optional_params`` and ``stream`` are
    accepted for interface compatibility and do not affect the result.

    Args:
        model: Engine ID or full ``projects/...`` resource path, optionally
            prefixed with ``agent_engine/``.
        litellm_params: Source of the Vertex project and location.

    Returns:
        ``<base_url>/v1beta1/<resource_path>:streamQuery``

    Raises:
        ValueError: If only an engine ID was given and no Vertex project
            could be resolved.
    """
    resource_path, engine_id = self._parse_model_string(model)
    is_full_path = resource_path.startswith("projects/")

    # Project and location as resolved by the VertexBase helpers.
    vertex_project = self.safe_get_vertex_ai_project(litellm_params)
    vertex_location = self.safe_get_vertex_ai_location(litellm_params)

    if not vertex_location and is_full_path:
        # The host is derived from the location, so prefer the region named
        # in the resource path over the hard-coded default.
        # NOTE(review): assumes the regional host must match the resource's
        # location - confirm against the Vertex AI REST reference.
        segments = resource_path.split("/")
        if "locations" in segments:
            location_index = segments.index("locations") + 1
            if location_index < len(segments) and segments[location_index]:
                vertex_location = segments[location_index]

    vertex_location = vertex_location or "us-central1"

    # Expand a bare engine ID into the full resource path.
    if not is_full_path:
        if not vertex_project:
            raise ValueError(
                "vertex_project is required for Vertex Agent Engine. "
                "Set via litellm_params['vertex_project'] or VERTEXAI_PROJECT env var."
            )
        resource_path = f"projects/{vertex_project}/locations/{vertex_location}/reasoningEngines/{engine_id}"

    base_url = get_vertex_base_url(vertex_location)

    # Always use :streamQuery for actual queries; the :query endpoint only
    # supports session management methods
    # (create_session, get_session, list_sessions, delete_session, etc.)
    endpoint = f"{base_url}/v1beta1/{resource_path}:streamQuery"

    verbose_logger.debug(f"Vertex Agent Engine URL: {endpoint}")
    return endpoint
|
||||
|
||||
def _get_auth_headers(
    self,
    optional_params: dict,
    litellm_params: dict,
) -> Dict[str, str]:
    """
    Build the authentication headers for a request.

    Credentials and project are read from ``litellm_params`` and exchanged
    for an access token via ``get_access_token``. ``optional_params`` is
    not consulted.

    Returns:
        A dict with a bearer ``Authorization`` header and a JSON
        ``Content-Type`` header.
    """
    access_token, project_id = self.get_access_token(
        credentials=self.safe_get_vertex_ai_credentials(litellm_params),
        project_id=self.safe_get_vertex_ai_project(litellm_params),
    )

    verbose_logger.debug(
        f"Vertex Agent Engine: Authenticated for project {project_id}"
    )

    auth_headers = {
        "Authorization": f"Bearer {access_token}",
        "Content-Type": "application/json",
    }
    return auth_headers
|
||||
|
||||
def _get_user_id(self, optional_params: dict) -> str:
    """
    Return the caller-supplied user ID, or generate one.

    ``user_id`` is checked first, then ``user``; the first truthy value
    wins. When neither is set, an ID of the form ``litellm-user-<8 chars>``
    is built from a fresh UUID4.
    """
    for key in ("user_id", "user"):
        candidate = optional_params.get(key)
        if candidate:
            return candidate

    # Nothing supplied: fabricate a short random identifier.
    return f"litellm-user-{str(uuid.uuid4())[:8]}"
|
||||
|
||||
def _get_session_id(self, optional_params: dict) -> Optional[str]:
    """Return ``optional_params["session_id"]``, or ``None`` when the key is absent."""
    session_id: Optional[str] = optional_params.get("session_id")
    return session_id
|
||||
|
||||
def transform_request(
    self,
    model: str,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    headers: dict,
) -> dict:
    """
    Build the Vertex Agent Engine request payload.

    Only the last entry of ``messages`` is sent as the prompt. The
    resulting payload has this shape:

    {
        "class_method": "stream_query",
        "input": {
            "message": "...",
            "user_id": "...",
            "session_id": "..." (only when provided)
        }
    }

    ``model``, ``litellm_params`` and ``headers`` are not consulted.
    """
    # The prompt is the content of the most recent message only.
    latest_message = messages[-1]
    prompt = convert_content_list_to_str(latest_message)

    user_id = self._get_user_id(optional_params)
    session_id = self._get_session_id(optional_params)

    agent_input: Dict[str, Any] = {
        "message": prompt,
        "user_id": user_id,
    }
    if session_id:
        agent_input["session_id"] = session_id

    # class_method is always stream_query.
    payload = {
        "class_method": "stream_query",
        "input": agent_input,
    }

    verbose_logger.debug(f"Vertex Agent Engine payload: {payload}")
    return payload
|
||||
|
||||
def validate_environment(
    self,
    headers: dict,
    model: str,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
) -> dict:
    """
    Add authentication headers to ``headers`` and return it.

    The incoming dict is updated in place with the result of
    ``_get_auth_headers``. ``model``, ``messages``, ``api_key`` and
    ``api_base`` are not consulted.
    """
    headers.update(self._get_auth_headers(optional_params, litellm_params))
    return headers
|
||||
|
||||
def _extract_text_from_response(self, response_data: dict) -> str:
    """
    Extract text content from one decoded response object.

    Lookup order:
      1. The ``"text"`` value of the first dict in ``content.parts`` that
         has a ``"text"`` key.
      2. The first non-empty string value in ``actions.state_delta``.

    Missing, null or non-dict ``content`` / ``actions`` / ``state_delta``
    values and non-dict parts are skipped rather than raising.

    Returns:
        The extracted text, or ``""`` when nothing was found.
    """
    content = response_data.get("content")
    parts = content.get("parts") if isinstance(content, dict) else None
    for part in parts or []:
        # Only dict parts can carry a "text" key; on a string part the
        # ``in`` check would be a substring test.
        if isinstance(part, dict) and "text" in part:
            return part["text"]

    # Fall back to actions.state_delta.
    actions = response_data.get("actions")
    state_delta = actions.get("state_delta") if isinstance(actions, dict) else None
    if isinstance(state_delta, dict):
        for value in state_delta.values():
            if isinstance(value, str) and value:
                return value

    return ""
|
||||
|
||||
def _calculate_usage(
    self, model: str, messages: List[AllMessageValues], content: str
) -> Optional[Usage]:
    """
    Compute token usage with LiteLLM's ``token_counter``.

    Both counts are computed with the ``"gpt-3.5-turbo"`` model name; the
    ``model`` argument is not consulted.

    Returns:
        A ``Usage`` object, or ``None`` if counting raised (the failure is
        logged as a warning).
    """
    try:
        from litellm.utils import token_counter

        prompt_count = token_counter(model="gpt-3.5-turbo", messages=messages)
        completion_count = token_counter(
            model="gpt-3.5-turbo", text=content, count_response_tokens=True
        )

        return Usage(
            prompt_tokens=prompt_count,
            completion_tokens=completion_count,
            total_tokens=prompt_count + completion_count,
        )
    except Exception as e:
        # Best-effort: usage is optional, so never fail the response over it.
        verbose_logger.warning(f"Failed to calculate token usage: {str(e)}")
        return None
|
||||
|
||||
def transform_response(
    self,
    model: str,
    raw_response: httpx.Response,
    model_response: ModelResponse,
    logging_obj: LiteLLMLoggingObj,
    request_data: dict,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    encoding: Any,
    api_key: Optional[str] = None,
    json_mode: Optional[bool] = None,
) -> ModelResponse:
    """
    Transform a Vertex Agent Engine response into a LiteLLM ModelResponse.

    The body is treated as a line-delimited stream even for non-streaming
    requests: every non-empty line is parsed as JSON, text is extracted
    from each JSON object, and the last non-empty text becomes the
    assistant message. Lines that are not valid JSON are skipped.

    Returns:
        ``model_response`` with a single ``stop`` choice, ``model`` set,
        and ``usage`` attached when it could be computed.

    Raises:
        VertexAgentEngineError: On any failure while processing the
            response; carries the response status code and chains the
            original exception.
    """
    try:
        content_type = raw_response.headers.get("content-type", "").lower()
        verbose_logger.debug(
            f"Vertex Agent Engine response Content-Type: {content_type}"
        )

        response_text = raw_response.text
        verbose_logger.debug(f"Response (first 500 chars): {response_text[:500]}")

        content = ""
        for raw_line in response_text.strip().split("\n"):
            line = raw_line.strip()
            if not line:
                continue

            # Unwrap SSE "data:" framing so the JSON payload behind it can
            # be parsed instead of being rejected by json.loads.
            if line.startswith("data:"):
                line = line[len("data:") :].strip()
                if not line:
                    continue

            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                continue

            if isinstance(data, dict):
                text = self._extract_text_from_response(data)
                if text:
                    content = text  # Use the last non-empty text

        message = Message(content=content, role="assistant")
        choice = Choices(finish_reason="stop", index=0, message=message)

        model_response.choices = [choice]
        model_response.model = model

        calculated_usage = self._calculate_usage(model, messages, content)
        if calculated_usage:
            setattr(model_response, "usage", calculated_usage)

        return model_response

    except Exception as e:
        verbose_logger.error(
            f"Error processing Vertex Agent Engine response: {str(e)}"
        )
        raise VertexAgentEngineError(
            message=f"Error processing response: {str(e)}",
            status_code=raw_response.status_code,
        ) from e
|
||||
|
||||
def get_streaming_response(
    self,
    model: str,
    raw_response: httpx.Response,
) -> VertexAgentEngineResponseIterator:
    """
    Wrap a synchronous HTTP response in a streaming iterator.

    The iterator is fed from ``raw_response.iter_lines()`` and created in
    sync mode. ``model`` is not consulted.
    """
    line_stream = raw_response.iter_lines()
    return VertexAgentEngineResponseIterator(
        streaming_response=line_stream,
        sync_stream=True,
    )
|
||||
|
||||
def get_sync_custom_stream_wrapper(
    self,
    model: str,
    custom_llm_provider: str,
    logging_obj: LiteLLMLoggingObj,
    api_base: str,
    headers: dict,
    data: dict,
    messages: list,
    client: Optional[Union[HTTPHandler, "AsyncHTTPHandler"]] = None,
    json_mode: Optional[bool] = None,
    signed_json_body: Optional[bytes] = None,
) -> "CustomStreamWrapper":
    """
    Issue a synchronous streaming request and wrap it in a CustomStreamWrapper.

    Args:
        api_base: Full request URL.
        headers: Request headers (including authentication).
        data: Request payload; serialised with ``json.dumps``.
        messages: Original messages, passed to the post-call log.
        client: Optional handler; replaced by a fresh sync client when it
            is ``None`` or not an ``HTTPHandler``.

    ``json_mode`` and ``signed_json_body`` are not consulted.

    Raises:
        VertexAgentEngineError: If the response status is not 200; the
            message is the decoded response body.
    """
    from litellm.llms.custom_httpx.http_handler import (
        HTTPHandler,
        _get_httpx_client,
    )
    from litellm.utils import CustomStreamWrapper

    if client is None or not isinstance(client, HTTPHandler):
        client = _get_httpx_client(params={})

    # Avoid logging sensitive api_base directly
    verbose_logger.debug("Making sync streaming request to Vertex AI endpoint.")

    response = client.post(
        api_base,
        headers=headers,
        data=json.dumps(data),
        stream=True,
        logging_obj=logging_obj,
    )

    if response.status_code != 200:
        # The body is read as bytes; decode it so the message is the body
        # text rather than a "b'...'" repr.
        error_body = response.read()
        if isinstance(error_body, (bytes, bytearray)):
            error_message = bytes(error_body).decode("utf-8", errors="replace")
        else:
            error_message = str(error_body)
        raise VertexAgentEngineError(
            status_code=response.status_code, message=error_message
        )

    completion_stream = self.get_streaming_response(
        model=model, raw_response=response
    )

    streaming_response = CustomStreamWrapper(
        completion_stream=completion_stream,
        model=model,
        custom_llm_provider=custom_llm_provider,
        logging_obj=logging_obj,
    )

    # LOGGING
    logging_obj.post_call(
        input=messages,
        api_key="",
        original_response="first stream response received",
        additional_args={"complete_input_dict": data},
    )

    return streaming_response
|
||||
|
||||
async def get_async_custom_stream_wrapper(
    self,
    model: str,
    custom_llm_provider: str,
    logging_obj: LiteLLMLoggingObj,
    api_base: str,
    headers: dict,
    data: dict,
    messages: list,
    client: Optional["AsyncHTTPHandler"] = None,
    json_mode: Optional[bool] = None,
    signed_json_body: Optional[bytes] = None,
) -> "CustomStreamWrapper":
    """
    Issue an asynchronous streaming request and wrap it in a CustomStreamWrapper.

    Args:
        api_base: Full request URL.
        headers: Request headers (including authentication).
        data: Request payload; serialised with ``json.dumps``.
        messages: Original messages, passed to the post-call log.
        client: Optional handler; replaced by a fresh async client when it
            is ``None`` or not an ``AsyncHTTPHandler``.

    ``json_mode`` and ``signed_json_body`` are not consulted.

    Raises:
        VertexAgentEngineError: If the response status is not 200; the
            message is the decoded response body.
    """
    from litellm.llms.custom_httpx.http_handler import (
        AsyncHTTPHandler,
        get_async_httpx_client,
    )
    from litellm.utils import CustomStreamWrapper

    if client is None or not isinstance(client, AsyncHTTPHandler):
        client = get_async_httpx_client(
            llm_provider=cast(Any, "vertex_ai"), params={}
        )

    # Avoid logging sensitive api_base directly
    verbose_logger.debug("Making async streaming request to Vertex AI endpoint.")

    response = await client.post(
        api_base,
        headers=headers,
        data=json.dumps(data),
        stream=True,
        logging_obj=logging_obj,
    )

    if response.status_code != 200:
        # The body is read as bytes; decode it so the message is the body
        # text rather than a "b'...'" repr.
        error_body = await response.aread()
        if isinstance(error_body, (bytes, bytearray)):
            error_message = bytes(error_body).decode("utf-8", errors="replace")
        else:
            error_message = str(error_body)
        raise VertexAgentEngineError(
            status_code=response.status_code, message=error_message
        )

    # Async counterpart of the sync iterator: fed from aiter_lines().
    completion_stream = VertexAgentEngineResponseIterator(
        streaming_response=response.aiter_lines(),
        sync_stream=False,
    )

    streaming_response = CustomStreamWrapper(
        completion_stream=completion_stream,
        model=model,
        custom_llm_provider=custom_llm_provider,
        logging_obj=logging_obj,
    )

    # LOGGING
    logging_obj.post_call(
        input=messages,
        api_key="",
        original_response="first stream response received",
        additional_args={"complete_input_dict": data},
    )

    return streaming_response
|
||||
|
||||
@property
def has_custom_stream_wrapper(self) -> bool:
    """Always ``True``: this config provides its own stream wrappers."""
    return True
|
||||
|
||||
@property
def supports_stream_param_in_request_body(self) -> bool:
    """Always ``False``: `stream` must not be sent in the Agent Engine request body."""
    return False
|
||||
|
||||
def get_error_class(
    self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
    """
    Build the provider-specific exception for an error.

    Returns a ``VertexAgentEngineError`` carrying ``status_code`` and
    ``error_message``; ``headers`` is not consulted.
    """
    error = VertexAgentEngineError(status_code=status_code, message=error_message)
    return error
|
||||
|
||||
def should_fake_stream(
    self,
    model: Optional[str],
    stream: Optional[bool],
    custom_llm_provider: Optional[str] = None,
) -> bool:
    """
    Report whether streaming should be simulated.

    Always returns ``False`` regardless of the arguments, so real
    streaming is used.
    """
    return False
|
||||
@@ -0,0 +1,52 @@
|
||||
"""
|
||||
Custom AWS Security Credentials Supplier for Vertex AI WIF.
|
||||
|
||||
Wraps boto3/botocore credentials so that google-auth can use them
|
||||
for the AWS-to-GCP Workload Identity Federation token exchange
|
||||
without hitting the EC2 instance metadata service.
|
||||
|
||||
Requires google-auth >= 2.29.0.
|
||||
"""
|
||||
|
||||
from typing import Callable
|
||||
|
||||
from google.auth import aws
|
||||
|
||||
|
||||
class AwsCredentialsSupplier(aws.AwsSecurityCredentialsSupplier):
    """
    Credentials supplier handing AWS credentials to google-auth for the
    Workload Identity Federation token exchange.

    The credentials come from a caller-provided callable rather than from
    this class itself. The callable is invoked on every
    ``get_aws_security_credentials()`` call, so whatever it returns at
    that moment (for example rotated temporary STS tokens) is used.
    """

    def __init__(self, credentials_provider: Callable, aws_region: str):
        """
        Args:
            credentials_provider: Zero-argument callable returning an
                object with ``access_key``, ``secret_key`` and ``token``
                attributes (e.g. botocore credentials).
            aws_region: AWS region string (e.g. "us-east-1").
        """
        self._credentials_provider = credentials_provider
        self._region = aws_region

    def get_aws_security_credentials(self, context, request):
        """Fetch credentials from the provider and wrap them for google-auth."""
        creds = self._credentials_provider()
        access_key_id = creds.access_key
        secret_access_key = creds.secret_key
        session_token = creds.token
        return aws.AwsSecurityCredentials(
            access_key_id=access_key_id,
            secret_access_key=secret_access_key,
            session_token=session_token,
        )

    def get_aws_region(self, context, request):
        """Return the region given at construction time."""
        return self._region
|
||||
@@ -0,0 +1,6 @@
|
||||
# Vertex AI Batch Prediction Jobs
|
||||
|
||||
Implementation for calling Vertex AI Batch Prediction endpoints through the OpenAI Batch API spec.
|
||||
|
||||
Vertex Docs: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/batch-prediction-gemini
|
||||
|
||||
@@ -0,0 +1,378 @@
|
||||
import json
|
||||
from typing import Any, Coroutine, Dict, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.llms.vertex_ai.common_utils import get_vertex_base_url
|
||||
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
|
||||
from litellm.types.llms.openai import CreateBatchRequest
|
||||
from litellm.types.llms.vertex_ai import (
|
||||
VERTEX_CREDENTIALS_TYPES,
|
||||
VertexAIBatchPredictionJob,
|
||||
)
|
||||
from litellm.types.utils import LiteLLMBatch
|
||||
|
||||
from .transformation import VertexAIBatchTransformation
|
||||
|
||||
|
||||
class VertexAIBatchPrediction(VertexLLM):
|
||||
def __init__(self, gcs_bucket_name: str, *args, **kwargs):
    """
    Args:
        gcs_bucket_name: Bucket name stored on the instance as
            ``gcs_bucket_name``.
        *args, **kwargs: Forwarded unchanged to the parent initialiser.
    """
    super().__init__(*args, **kwargs)
    self.gcs_bucket_name = gcs_bucket_name
|
||||
|
||||
def create_batch(
    self,
    _is_async: bool,
    create_batch_data: CreateBatchRequest,
    api_base: Optional[str],
    vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
    vertex_project: Optional[str],
    vertex_location: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
    """
    Create a Vertex AI batch prediction job from an OpenAI-style request.

    The location defaults to ``us-central1`` and the project to the one
    returned with the access token. ``timeout`` and ``max_retries`` are
    accepted but not used here.

    Returns:
        The converted batch object, or, when ``_is_async`` is ``True``,
        the coroutine from ``_async_create_batch``.

    Raises:
        Exception: If the synchronous POST returns a non-200 status.
    """
    sync_handler = _get_httpx_client()

    access_token, project_id = self._ensure_access_token(
        credentials=vertex_credentials,
        project_id=vertex_project,
        custom_llm_provider="vertex_ai",
    )

    resolved_location = vertex_location or "us-central1"
    resolved_project = vertex_project or project_id

    default_api_base = self.create_vertex_batch_url(
        vertex_location=resolved_location,
        vertex_project=resolved_project,
    )

    # Text after the last ":" of the default URL, or "" when there is none.
    endpoint = default_api_base.split(":")[-1] if ":" in default_api_base else ""

    _, api_base = self._check_custom_proxy(
        api_base=api_base,
        custom_llm_provider="vertex_ai",
        gemini_api_key=None,
        endpoint=endpoint,
        stream=None,
        auth_header=None,
        url=default_api_base,
        model=None,
        vertex_project=resolved_project,
        vertex_location=resolved_location,
        vertex_api_version="v1",
    )

    headers = {
        "Content-Type": "application/json; charset=utf-8",
        "Authorization": f"Bearer {access_token}",
    }

    vertex_batch_request: VertexAIBatchPredictionJob = VertexAIBatchTransformation.transform_openai_batch_request_to_vertex_ai_batch_request(
        request=create_batch_data
    )

    if _is_async is True:
        return self._async_create_batch(
            vertex_batch_request=vertex_batch_request,
            api_base=api_base,
            headers=headers,
        )

    response = sync_handler.post(
        url=api_base,
        headers=headers,
        data=json.dumps(vertex_batch_request),
    )

    if response.status_code != 200:
        raise Exception(f"Error: {response.status_code} {response.text}")

    return VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
        response=response.json()
    )
|
||||
|
||||
async def _async_create_batch(
    self,
    vertex_batch_request: VertexAIBatchPredictionJob,
    api_base: str,
    headers: Dict[str, str],
) -> LiteLLMBatch:
    """
    POST the batch job asynchronously and convert the response.

    An ``httpx.HTTPStatusError`` from the POST is logged (status and the
    first 1000 characters of the body) and re-raised unchanged.

    Raises:
        httpx.HTTPStatusError: Propagated from the client.
        Exception: If the response status is not 200.
    """
    async_client = get_async_httpx_client(
        llm_provider=litellm.LlmProviders.VERTEX_AI,
    )
    request_body = json.dumps(vertex_batch_request)

    try:
        response = await async_client.post(
            url=api_base,
            headers=headers,
            data=request_body,
        )
    except httpx.HTTPStatusError as e:
        error_body = e.response.text
        litellm.verbose_logger.error(
            "Vertex AI batch create failed: status=%s, body=%s",
            e.response.status_code,
            error_body[:1000],
        )
        raise

    if response.status_code != 200:
        raise Exception(f"Error: {response.status_code} {response.text}")

    return VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
        response=response.json()
    )
|
||||
|
||||
def create_vertex_batch_url(
    self,
    vertex_location: str,
    vertex_project: str,
) -> str:
    """
    Return the ``batchPredictionJobs`` collection URL for a project/location.

    Shape:
    <base_url>/v1/projects/PROJECT_ID/locations/LOCATION/batchPredictionJobs
    where ``base_url`` comes from ``get_vertex_base_url(vertex_location)``.
    """
    base_url = get_vertex_base_url(vertex_location)
    resource = f"projects/{vertex_project}/locations/{vertex_location}/batchPredictionJobs"
    return f"{base_url}/v1/{resource}"
|
||||
|
||||
def retrieve_batch(
    self,
    _is_async: bool,
    batch_id: str,
    api_base: Optional[str],
    vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
    vertex_project: Optional[str],
    vertex_location: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    logging_obj: Optional[Any] = None,
) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
    """
    Fetch one Vertex AI batch prediction job and convert it to OpenAI form.

    The job URL is the batch collection URL with ``batch_id`` appended.
    The location defaults to ``us-central1`` and the project to the one
    returned with the access token. ``timeout`` and ``max_retries`` are
    accepted but not used here.

    Returns:
        The converted batch object, or, when ``_is_async`` is ``True``,
        the coroutine from ``_async_retrieve_batch``.

    Raises:
        Exception: If the synchronous GET returns a non-200 status.
    """
    sync_handler = _get_httpx_client()

    access_token, project_id = self._ensure_access_token(
        credentials=vertex_credentials,
        project_id=vertex_project,
        custom_llm_provider="vertex_ai",
    )

    resolved_location = vertex_location or "us-central1"
    resolved_project = vertex_project or project_id

    collection_url = self.create_vertex_batch_url(
        vertex_location=resolved_location,
        vertex_project=resolved_project,
    )
    default_api_base = f"{collection_url}/{batch_id}"

    # Text after the last ":" of the default URL, or "" when there is none.
    endpoint = default_api_base.split(":")[-1] if ":" in default_api_base else ""

    _, api_base = self._check_custom_proxy(
        api_base=api_base,
        custom_llm_provider="vertex_ai",
        gemini_api_key=None,
        endpoint=endpoint,
        stream=None,
        auth_header=None,
        url=default_api_base,
        model=None,
        vertex_project=resolved_project,
        vertex_location=resolved_location,
        vertex_api_version="v1",
    )

    headers = {
        "Content-Type": "application/json; charset=utf-8",
        "Authorization": f"Bearer {access_token}",
    }

    if _is_async is True:
        return self._async_retrieve_batch(
            api_base=api_base,
            headers=headers,
            logging_obj=logging_obj,
        )

    # Emit a pre-call log entry only for real Logging objects.
    if logging_obj is not None:
        from litellm.litellm_core_utils.litellm_logging import Logging

        if isinstance(logging_obj, Logging):
            # The curl rendering redacts the bearer token.
            request_str = (
                f"\nGET Request Sent from LiteLLM:\n"
                f"curl -X GET \\\n"
                f"{api_base} \\\n"
                f"-H 'Authorization: Bearer ***REDACTED***' \\\n"
                f"-H 'Content-Type: application/json; charset=utf-8'\n"
            )
            logging_obj.pre_call(
                input="",
                api_key="",
                additional_args={
                    "complete_input_dict": {},
                    "api_base": api_base,
                    "headers": headers,
                    "request_str": request_str,
                },
            )

    response = sync_handler.get(
        url=api_base,
        headers=headers,
    )

    if response.status_code != 200:
        raise Exception(f"Error: {response.status_code} {response.text}")

    return VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
        response=response.json()
    )
|
||||
|
||||
async def _async_retrieve_batch(
    self,
    api_base: str,
    headers: Dict[str, str],
    logging_obj: Optional[Any] = None,
) -> LiteLLMBatch:
    """
    GET a batch prediction job asynchronously and convert the response.

    When ``logging_obj`` is a LiteLLM ``Logging`` instance, a pre-call
    entry is recorded first, including a curl rendering of the request
    with the bearer token redacted.

    Raises:
        Exception: If the response status is not 200.
    """
    async_client = get_async_httpx_client(
        llm_provider=litellm.LlmProviders.VERTEX_AI,
    )

    # Emit a pre-call log entry only for real Logging objects.
    if logging_obj is not None:
        from litellm.litellm_core_utils.litellm_logging import Logging

        if isinstance(logging_obj, Logging):
            request_str = (
                f"\nGET Request Sent from LiteLLM:\n"
                f"curl -X GET \\\n"
                f"{api_base} \\\n"
                f"-H 'Authorization: Bearer ***REDACTED***' \\\n"
                f"-H 'Content-Type: application/json; charset=utf-8'\n"
            )
            logging_obj.pre_call(
                input="",
                api_key="",
                additional_args={
                    "complete_input_dict": {},
                    "api_base": api_base,
                    "headers": headers,
                    "request_str": request_str,
                },
            )

    response = await async_client.get(
        url=api_base,
        headers=headers,
    )
    if response.status_code != 200:
        raise Exception(f"Error: {response.status_code} {response.text}")

    return VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
        response=response.json()
    )
|
||||
|
||||
def list_batches(
    self,
    _is_async: bool,
    after: Optional[str],
    limit: Optional[int],
    api_base: Optional[str],
    vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
    vertex_project: Optional[str],
    vertex_location: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
):
    """
    List Vertex AI batch prediction jobs in an OpenAI-style list response.

    Args:
        _is_async: When exactly `True`, the coroutine from
            `_async_list_batches` is returned instead of a result.
        after: Page token, sent as the `pageToken` query parameter when set.
        limit: Page size, sent (as a string) as `pageSize` when set.
        api_base: Optional custom base URL, resolved via `_check_custom_proxy`.
        vertex_credentials: Credentials used to obtain the access token.
        vertex_project: Project id; falls back to the project id returned
            alongside the access token.
        vertex_location: Location; falls back to "us-central1".
        timeout: Accepted for interface compatibility; not used in this method.
        max_retries: Accepted for interface compatibility; not used in this method.

    Returns:
        The transformed list response, or a coroutine producing it when
        `_is_async` is `True`.

    Raises:
        Exception: if the synchronous response status code is not 200.
    """
    http_client = _get_httpx_client()

    access_token, project_id = self._ensure_access_token(
        credentials=vertex_credentials,
        project_id=vertex_project,
        custom_llm_provider="vertex_ai",
    )

    default_api_base = self.create_vertex_batch_url(
        vertex_location=vertex_location or "us-central1",
        vertex_project=vertex_project or project_id,
    )

    # The "endpoint" is whatever follows the last ':' of the default URL,
    # or empty when the URL has no ':' at all.
    colon_segments = default_api_base.split(":")
    endpoint = colon_segments[-1] if len(colon_segments) > 1 else ""

    _, api_base = self._check_custom_proxy(
        api_base=api_base,
        custom_llm_provider="vertex_ai",
        gemini_api_key=None,
        endpoint=endpoint,
        stream=None,
        auth_header=None,
        url=default_api_base,
    )

    headers = {
        "Content-Type": "application/json; charset=utf-8",
        "Authorization": f"Bearer {access_token}",
    }

    # Only include pagination parameters the caller actually supplied.
    pagination_candidates = (
        ("pageSize", None if limit is None else str(limit)),
        ("pageToken", after),
    )
    params: Dict[str, Any] = {
        key: value for key, value in pagination_candidates if value is not None
    }

    if _is_async is True:
        return self._async_list_batches(
            api_base=api_base,
            headers=headers,
            params=params,
        )

    response = http_client.get(url=api_base, headers=headers, params=params)
    if response.status_code != 200:
        raise Exception(f"Error: {response.status_code} {response.text}")

    return VertexAIBatchTransformation.transform_vertex_ai_batch_list_response_to_openai_list_response(
        response=response.json()
    )
|
||||
|
||||
async def _async_list_batches(
    self,
    api_base: str,
    headers: Dict[str, str],
    params: Dict[str, Any],
):
    """
    Async counterpart of the list call: GET the batch jobs and transform them.

    Args:
        api_base: URL the GET request is sent to.
        headers: Request headers passed through unchanged.
        params: Query parameters passed through unchanged.

    Returns:
        The JSON body transformed by
        `VertexAIBatchTransformation.transform_vertex_ai_batch_list_response_to_openai_list_response`.

    Raises:
        Exception: if the response status code is not 200.
    """
    async_client = get_async_httpx_client(
        llm_provider=litellm.LlmProviders.VERTEX_AI,
    )
    response = await async_client.get(url=api_base, headers=headers, params=params)

    if response.status_code != 200:
        raise Exception(f"Error: {response.status_code} {response.text}")

    return VertexAIBatchTransformation.transform_vertex_ai_batch_list_response_to_openai_list_response(
        response=response.json()
    )
|
||||
@@ -0,0 +1,227 @@
|
||||
from litellm._uuid import uuid
|
||||
from typing import Any, Dict
|
||||
|
||||
from litellm.llms.vertex_ai.common_utils import (
|
||||
_convert_vertex_datetime_to_openai_datetime,
|
||||
)
|
||||
from litellm.types.llms.openai import BatchJobStatus, CreateBatchRequest
|
||||
from litellm.types.llms.vertex_ai import *
|
||||
from litellm.types.utils import LiteLLMBatch
|
||||
|
||||
|
||||
class VertexAIBatchTransformation:
    """
    Transforms OpenAI Batch requests to Vertex AI Batch requests

    API Ref: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/batch-prediction-gemini
    """

    @classmethod
    def transform_openai_batch_request_to_vertex_ai_batch_request(
        cls,
        request: CreateBatchRequest,
    ) -> VertexAIBatchPredictionJob:
        """
        Transforms OpenAI Batch requests to Vertex AI Batch requests

        The request's `input_file_id` is used as the GCS source URI. The
        model is parsed out of that URI and the output prefix is the URI
        with its final path segment removed.

        Raises:
            ValueError: if `input_file_id` is missing from the request.
        """
        request_display_name = f"litellm-vertex-batch-{uuid.uuid4()}"
        input_file_id = request.get("input_file_id")
        if input_file_id is None:
            raise ValueError("input_file_id is required, but not provided")
        input_config: InputConfig = InputConfig(
            gcsSource=GcsSource(uris=[input_file_id]), instancesFormat="jsonl"
        )
        model: str = cls._get_model_from_gcs_file(input_file_id)
        output_config: OutputConfig = OutputConfig(
            predictionsFormat="jsonl",
            gcsDestination=GcsDestination(
                outputUriPrefix=cls._get_gcs_uri_prefix_from_file(input_file_id)
            ),
        )
        return VertexAIBatchPredictionJob(
            inputConfig=input_config,
            outputConfig=output_config,
            model=model,
            displayName=request_display_name,
        )

    @classmethod
    def transform_vertex_ai_batch_response_to_openai_batch_response(
        cls, response: VertexBatchPredictionResponse
    ) -> LiteLLMBatch:
        """
        Build a `LiteLLMBatch` from a single Vertex AI batch job response.

        `completion_window` is always "24hrs", `endpoint` is always empty and
        `error_file_id` is always None; the remaining fields are derived from
        the response by the helper methods on this class.
        """
        return LiteLLMBatch(
            id=cls._get_batch_id_from_vertex_ai_batch_response(response),
            completion_window="24hrs",
            created_at=_convert_vertex_datetime_to_openai_datetime(
                vertex_datetime=response.get("createTime", "")
            ),
            endpoint="",
            input_file_id=cls._get_input_file_id_from_vertex_ai_batch_response(
                response
            ),
            object="batch",
            status=cls._get_batch_job_status_from_vertex_ai_batch_response(response),
            error_file_id=None,  # Vertex AI doesn't seem to have a direct equivalent
            output_file_id=cls._get_output_file_id_from_vertex_ai_batch_response(
                response
            ),
        )

    @classmethod
    def transform_vertex_ai_batch_list_response_to_openai_list_response(
        cls, response: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Transforms Vertex AI batch list response into OpenAI-compatible list response.

        Returns a dict with `object`, `data`, `first_id`, `last_id`,
        `has_more` (true when a `nextPageToken` is present) and
        `next_page_token`.
        """
        # A missing or null "batchPredictionJobs" is treated as an empty page.
        batch_jobs = response.get("batchPredictionJobs", []) or []
        data = [
            cls.transform_vertex_ai_batch_response_to_openai_batch_response(job)
            for job in batch_jobs
        ]

        first_id = data[0].id if data else None
        last_id = data[-1].id if data else None
        next_page_token = response.get("nextPageToken")

        return {
            "object": "list",
            "data": data,
            "first_id": first_id,
            "last_id": last_id,
            "has_more": bool(next_page_token),
            "next_page_token": next_page_token,
        }

    @classmethod
    def _get_batch_id_from_vertex_ai_batch_response(
        cls, response: VertexBatchPredictionResponse
    ) -> str:
        """
        Gets the batch id from the Vertex AI Batch response safely

        vertex response: `projects/510528649030/locations/us-central1/batchPredictionJobs/3814889423749775360`
        returns: `3814889423749775360`

        Returns an empty string when the response has no `name`.
        """
        _name = response.get("name", "")
        if not _name:
            return ""

        # The id is the final '/'-separated segment of the resource name.
        return _name.rsplit("/", 1)[-1]

    @classmethod
    def _get_input_file_id_from_vertex_ai_batch_response(
        cls, response: VertexBatchPredictionResponse
    ) -> str:
        """
        Gets the input file id from the Vertex AI Batch response

        Returns the first URI under `inputConfig.gcsSource.uris`, or an empty
        string when any level of that path is missing or the list is empty.
        """
        input_config = response.get("inputConfig")
        if input_config is None:
            return ""

        gcs_source = input_config.get("gcsSource")
        if gcs_source is None:
            return ""

        # Truthiness check also covers an explicit null value, not just [].
        uris = gcs_source.get("uris", "")
        if not uris:
            return ""

        return uris[0]

    @classmethod
    def _get_output_file_id_from_vertex_ai_batch_response(
        cls, response: VertexBatchPredictionResponse
    ) -> str:
        """
        Gets the output file id from the Vertex AI Batch response

        Prefers `outputInfo.gcsOutputDirectory` with "/predictions.jsonl"
        appended. When that directory is absent or empty, falls back to
        `outputConfig.gcsDestination.outputUriPrefix`; if that path is also
        missing, the bare "/predictions.jsonl" string is returned.
        """
        output_file_id: str = (
            response.get("outputInfo", OutputInfo()).get("gcsOutputDirectory", "")
            + "/predictions.jsonl"
        )
        if output_file_id != "/predictions.jsonl":
            return output_file_id

        output_config = response.get("outputConfig")
        if output_config is None:
            return output_file_id

        gcs_destination = output_config.get("gcsDestination")
        if gcs_destination is None:
            return output_file_id

        output_uri_prefix = gcs_destination.get("outputUriPrefix", "")
        return output_uri_prefix

    @classmethod
    def _get_batch_job_status_from_vertex_ai_batch_response(
        cls, response: VertexBatchPredictionResponse
    ) -> BatchJobStatus:
        """
        Gets the batch job status from the Vertex AI Batch response

        A missing, null, or unrecognised `state` is mapped the same way as
        `JOB_STATE_UNSPECIFIED` instead of raising `KeyError`, so one
        unexpected job state cannot break a whole list/retrieve call.

        ref: https://cloud.google.com/vertex-ai/docs/reference/rest/v1/JobState
        """
        state_mapping: Dict[str, BatchJobStatus] = {
            "JOB_STATE_UNSPECIFIED": "failed",
            "JOB_STATE_QUEUED": "validating",
            "JOB_STATE_PENDING": "validating",
            "JOB_STATE_RUNNING": "in_progress",
            "JOB_STATE_SUCCEEDED": "completed",
            "JOB_STATE_FAILED": "failed",
            "JOB_STATE_CANCELLING": "cancelling",
            "JOB_STATE_CANCELLED": "cancelled",
            "JOB_STATE_PAUSED": "in_progress",
            "JOB_STATE_EXPIRED": "expired",
            "JOB_STATE_UPDATING": "in_progress",
            "JOB_STATE_PARTIALLY_SUCCEEDED": "completed",
        }

        vertex_state = response.get("state") or "JOB_STATE_UNSPECIFIED"
        return state_mapping.get(
            vertex_state, state_mapping["JOB_STATE_UNSPECIFIED"]
        )

    @classmethod
    def _get_gcs_uri_prefix_from_file(cls, input_file_id: str) -> str:
        """
        Gets the gcs uri prefix from the input file id

        Example:
        input_file_id: "gs://litellm-testing-bucket/vtx_batch.jsonl"
        returns: "gs://litellm-testing-bucket"

        input_file_id: "gs://litellm-testing-bucket/batches/vtx_batch.jsonl"
        returns: "gs://litellm-testing-bucket/batches"
        """
        # Drop everything after the last '/', i.e. the filename.
        return input_file_id.rsplit("/", 1)[0]

    @classmethod
    def _get_model_from_gcs_file(cls, gcs_file_uri: str) -> str:
        """
        Extracts the model from the gcs file uri

        When files are uploaded using LiteLLM (/v1/files), the model is stored in the gcs file uri

        Why?
        - Because Vertex Requires the `model` param in create batch jobs request, but OpenAI does not require this

        gcs_file_uri format: gs://litellm-testing-bucket/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/e9412502-2c91-42a6-8e61-f5c294cc0fc8
        returns: "publishers/google/models/gemini-1.5-flash-001"

        Raises:
            IndexError: if the (URL-decoded) uri does not contain "publishers/".
        """
        from urllib.parse import unquote

        decoded_uri = unquote(gcs_file_uri)

        # Keep the three path segments that follow "publishers/".
        model_path = decoded_uri.split("publishers/")[1]
        parts = model_path.split("/")
        model = f"publishers/{'/'.join(parts[:3])}"
        return model
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,198 @@
|
||||
"""
|
||||
Transformation logic for context caching.
|
||||
|
||||
Why separate file? Make it easy to see how transformation works
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import List, Optional, Tuple, Literal
|
||||
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.llms.vertex_ai import CachedContentRequestBody
|
||||
from litellm.utils import is_cached_message
|
||||
|
||||
from ..common_utils import get_supports_system_message
|
||||
from ..gemini.transformation import (
|
||||
_gemini_convert_messages_with_history,
|
||||
_transform_system_message,
|
||||
)
|
||||
|
||||
|
||||
def get_first_continuous_block_idx(
    filtered_messages: List[Tuple[int, AllMessageValues]]  # (idx, message)
) -> int:
    """
    Locate the end of the first run of consecutive original indices.

    Args:
        filtered_messages: (original index, message) pairs.

    Returns:
        int: Position within `filtered_messages` of the last entry whose
        original index continues the run started by the first entry, or
        -1 when the list is empty.
    """
    if not filtered_messages:
        return -1

    original_indices = [entry[0] for entry in filtered_messages]

    # Walk neighbouring pairs; the run ends at the first gap.
    neighbours = zip(original_indices, original_indices[1:])
    for position, (previous, current) in enumerate(neighbours):
        if current != previous + 1:
            return position

    # No gap found: the whole list is one continuous run.
    return len(original_indices) - 1
|
||||
|
||||
|
||||
def extract_ttl_from_cached_messages(messages: List[AllMessageValues]) -> Optional[str]:
    """
    Return the first valid TTL declared on a cached message, if any.

    Only messages accepted by `is_cached_message` are inspected, and only
    when their content is a non-empty, non-string iterable. Within those,
    dict content items carrying an "ephemeral" `cache_control` dict are
    checked for a `ttl` that passes `_is_valid_ttl_format`.

    Args:
        messages: List of messages to extract TTL from

    Returns:
        Optional[str]: TTL string in format "3600s" or None if not found/invalid
    """
    for message in messages:
        if not is_cached_message(message):
            continue

        content = message.get("content")
        if not content or isinstance(content, str):
            continue

        for part in content:
            # Non-dict parts cannot carry cache_control.
            if not isinstance(part, dict):
                continue

            cache_control = part.get("cache_control")
            if not cache_control or not isinstance(cache_control, dict):
                continue

            if cache_control.get("type") == "ephemeral":
                candidate_ttl = cache_control.get("ttl")
                if candidate_ttl and _is_valid_ttl_format(candidate_ttl):
                    return str(candidate_ttl)

    return None
|
||||
|
||||
|
||||
def _is_valid_ttl_format(ttl: str) -> bool:
|
||||
"""
|
||||
Validate TTL format. Should be a string ending with 's' for seconds.
|
||||
Examples: "3600s", "7200s", "1.5s"
|
||||
|
||||
Args:
|
||||
ttl: TTL string to validate
|
||||
|
||||
Returns:
|
||||
bool: True if valid format, False otherwise
|
||||
"""
|
||||
if not isinstance(ttl, str):
|
||||
return False
|
||||
|
||||
# TTL should end with 's' and contain a valid number before it
|
||||
pattern = r"^([0-9]*\.?[0-9]+)s$"
|
||||
match = re.match(pattern, ttl)
|
||||
|
||||
if not match:
|
||||
return False
|
||||
|
||||
try:
|
||||
# Ensure the numeric part is valid and positive
|
||||
numeric_part = float(match.group(1))
|
||||
return numeric_part > 0
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def separate_cached_messages(
    messages: List[AllMessageValues],
) -> Tuple[List[AllMessageValues], List[AllMessageValues]]:
    """
    Split messages into the first continuous cached block and the rest.

    The cached block spans from the first cached message to the end of the
    first run of consecutively-indexed cached messages; everything before
    and after that span is returned as non-cached.

    Args:
        messages: List of messages to be separated.

    Returns:
        Tuple containing:
        - cached_messages: List of cached messages.
        - non_cached_messages: List of non-cached messages.
    """
    # Pair every cached message with its position in the original list.
    indexed_cached: List[Tuple[int, AllMessageValues]] = [
        (position, candidate)
        for position, candidate in enumerate(messages)
        if is_cached_message(message=candidate)
    ]

    block_end = get_first_continuous_block_idx(indexed_cached)

    # Nothing cached: every message is non-cached.
    if not indexed_cached or block_end is None:
        return [], messages

    start = indexed_cached[0][0]
    stop = indexed_cached[block_end][0] + 1

    return messages[start:stop], messages[:start] + messages[stop:]
|
||||
|
||||
|
||||
def transform_openai_messages_to_gemini_context_caching(
    model: str,
    messages: List[AllMessageValues],
    custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
    cache_key: str,
    vertex_project: Optional[str],
    vertex_location: Optional[str],
) -> CachedContentRequestBody:
    """
    Build the cached-content request body from OpenAI-format messages.

    Args:
        model: Model name; sent as "models/<model>", prefixed with the
            project/location/publisher path for the Vertex AI providers.
        messages: Messages to place in the cache.
        custom_llm_provider: Provider the request body is built for.
        cache_key: Stored as the cached content's `displayName`.
        vertex_project: Project used in the Vertex AI model path.
        vertex_location: Location used in the Vertex AI model path.

    Returns:
        CachedContentRequestBody with `contents`, `model`, `displayName`,
        plus `ttl` and `system_instruction` when available.
    """
    # TTL is read from the untouched messages, before the system message
    # transformation below runs.
    ttl = extract_ttl_from_cached_messages(messages)

    system_instruction, remaining_messages = _transform_system_message(
        supports_system_message=get_supports_system_message(
            model=model, custom_llm_provider=custom_llm_provider
        ),
        messages=messages,
    )

    contents = _gemini_convert_messages_with_history(
        messages=remaining_messages, model=model
    )

    model_name = "models/{}".format(model)
    if custom_llm_provider in ("vertex_ai", "vertex_ai_beta"):
        model_name = f"projects/{vertex_project}/locations/{vertex_location}/publishers/google/{model_name}"

    data = CachedContentRequestBody(
        contents=contents,
        model=model_name,
        displayName=cache_key,
    )

    if ttl:
        data["ttl"] = ttl

    if system_instruction is not None:
        data["system_instruction"] = system_instruction

    return data
|
||||
@@ -0,0 +1,578 @@
|
||||
from typing import List, Literal, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.caching.caching import Cache, LiteLLMCacheType
|
||||
from litellm.constants import MINIMUM_PROMPT_CACHE_TOKEN_COUNT
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.openai.openai import AllMessageValues
|
||||
from litellm.utils import is_prompt_caching_valid_prompt
|
||||
from litellm.types.llms.vertex_ai import (
|
||||
CachedContentListAllResponseBody,
|
||||
VertexAICachedContentResponseObject,
|
||||
)
|
||||
|
||||
from ..common_utils import VertexAIError
|
||||
from ..vertex_llm_base import VertexBase
|
||||
from .transformation import (
|
||||
separate_cached_messages,
|
||||
transform_openai_messages_to_gemini_context_caching,
|
||||
)
|
||||
|
||||
# Local cache instance; only used for calling the 'get_cache_key' function.
local_cache_obj = Cache(type=LiteLLMCacheType.LOCAL)

# Reasonable upper bound for pagination.
MAX_PAGINATION_PAGES = 100
|
||||
|
||||
|
||||
class ContextCachingEndpoints(VertexBase):
|
||||
"""
|
||||
Covers context caching endpoints for Vertex AI + Google AI Studio
|
||||
|
||||
v0: covers Google AI Studio
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
    """No-op constructor; it does not invoke the base-class initializer."""
|
||||
|
||||
def _get_token_and_url_context_caching(
    self,
    gemini_api_key: Optional[str],
    custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
    api_base: Optional[str],
    vertex_project: Optional[str],
    vertex_location: Optional[str],
    vertex_auth_header: Optional[str],
) -> Tuple[Optional[str], str]:
    """
    Internal function. Returns the token and url for the call.

    Builds the default `cachedContents` URL for Google AI Studio
    ("gemini") or Vertex AI, then hands it to `_check_custom_proxy`
    together with the optional custom `api_base`.

    Returns
        token, url
    """
    endpoint = "cachedContents"

    if custom_llm_provider == "gemini":
        # Google AI Studio authenticates through the key in the query string.
        auth_header = None
        url = "https://generativelanguage.googleapis.com/v1beta/{}?key={}".format(
            endpoint, gemini_api_key
        )
    else:
        auth_header = vertex_auth_header
        # "vertex_ai" uses the v1 API; any other non-gemini provider uses v1beta1.
        url_version = "v1" if custom_llm_provider == "vertex_ai" else "v1beta1"
        # The "global" location has no regional host prefix.
        if vertex_location == "global":
            host = "aiplatform.googleapis.com"
        else:
            host = f"{vertex_location}-aiplatform.googleapis.com"
        url = f"https://{host}/{url_version}/projects/{vertex_project}/locations/{vertex_location}/{endpoint}"

    proxy_api_version = "v1beta1" if custom_llm_provider == "vertex_ai_beta" else "v1"

    return self._check_custom_proxy(
        api_base=api_base,
        custom_llm_provider=custom_llm_provider,
        gemini_api_key=gemini_api_key,
        endpoint=endpoint,
        stream=None,
        auth_header=auth_header,
        url=url,
        model=None,
        vertex_project=vertex_project,
        vertex_location=vertex_location,
        vertex_api_version=proxy_api_version,
    )
|
||||
|
||||
def check_cache(
    self,
    cache_key: str,
    client: HTTPHandler,
    headers: dict,
    api_key: str,
    api_base: Optional[str],
    logging_obj: Logging,
    custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
    vertex_project: Optional[str],
    vertex_location: Optional[str],
    vertex_auth_header: Optional[str],
) -> Optional[str]:
    """
    Checks if content already cached.

    Currently, checks cache list, for cache key == displayName, since Google doesn't let us set the name of the cache (their API docs are out of sync with actual implementation).

    Pages through the cached-content list, following `nextPageToken`, for
    at most `MAX_PAGINATION_PAGES` pages. The page token is
    percent-encoded before being placed in the query string.

    Returns
    - cached_content_name - str - cached content name stored on google. (if found.)
    OR
    - None (also returned when the list call answers with HTTP 403)

    Raises
    - VertexAIError - for any other HTTP error status, or (as status 500)
      for any other exception raised while making the request.
    """
    from urllib.parse import quote

    _, base_url = self._get_token_and_url_context_caching(
        gemini_api_key=api_key,
        custom_llm_provider=custom_llm_provider,
        api_base=api_base,
        vertex_project=vertex_project,
        vertex_location=vertex_location,
        vertex_auth_header=vertex_auth_header,
    )

    # base_url may already contain a query string, so pick the right joiner once.
    separator = "&" if "?" in base_url else "?"
    page_token: Optional[str] = None

    # Iterate through all pages
    for _ in range(MAX_PAGINATION_PAGES):
        if page_token:
            # Tokens are opaque; encode them so characters such as '+', '/',
            # '&' or '=' cannot corrupt the query string.
            encoded_token = quote(page_token, safe="")
            url = f"{base_url}{separator}pageToken={encoded_token}"
        else:
            url = base_url

        try:
            ## LOGGING
            logging_obj.pre_call(
                input="",
                api_key="",
                additional_args={
                    "complete_input_dict": {},
                    "api_base": url,
                    "headers": headers,
                },
            )

            resp = client.get(url=url, headers=headers)
            resp.raise_for_status()
        except httpx.HTTPStatusError as e:
            # A 403 on the list call is treated as "no cache found".
            if e.response.status_code == 403:
                return None
            raise VertexAIError(
                status_code=e.response.status_code, message=e.response.text
            ) from e
        except Exception as e:
            raise VertexAIError(status_code=500, message=str(e)) from e

        raw_response = resp.json()
        logging_obj.post_call(original_response=raw_response)

        if "cachedContents" not in raw_response:
            return None

        all_cached_items = CachedContentListAllResponseBody(**raw_response)

        if "cachedContents" not in all_cached_items:
            return None

        # Check current page for matching cache_key
        for cached_item in all_cached_items["cachedContents"]:
            display_name = cached_item.get("displayName")
            if display_name is not None and display_name == cache_key:
                return cached_item.get("name")

        # Check if there are more pages
        page_token = all_cached_items.get("nextPageToken")
        if not page_token:
            # No more pages, cache not found
            break

    return None
|
||||
|
||||
async def async_check_cache(
    self,
    cache_key: str,
    client: AsyncHTTPHandler,
    headers: dict,
    api_key: str,
    api_base: Optional[str],
    logging_obj: Logging,
    custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
    vertex_project: Optional[str],
    vertex_location: Optional[str],
    vertex_auth_header: Optional[str],
) -> Optional[str]:
    """
    Checks if content already cached.

    Currently, checks cache list, for cache key == displayName, since Google doesn't let us set the name of the cache (their API docs are out of sync with actual implementation).

    Pages through the cached-content list, following `nextPageToken`, for
    at most `MAX_PAGINATION_PAGES` pages. The page token is
    percent-encoded before being placed in the query string.

    Returns
    - cached_content_name - str - cached content name stored on google. (if found.)
    OR
    - None (also returned when the list call answers with HTTP 403)

    Raises
    - VertexAIError - for any other HTTP error status, or (as status 500)
      for any other exception raised while making the request.
    """
    from urllib.parse import quote

    _, base_url = self._get_token_and_url_context_caching(
        gemini_api_key=api_key,
        custom_llm_provider=custom_llm_provider,
        api_base=api_base,
        vertex_project=vertex_project,
        vertex_location=vertex_location,
        vertex_auth_header=vertex_auth_header,
    )

    # base_url may already contain a query string, so pick the right joiner once.
    separator = "&" if "?" in base_url else "?"
    page_token: Optional[str] = None

    # Iterate through all pages
    for _ in range(MAX_PAGINATION_PAGES):
        if page_token:
            # Tokens are opaque; encode them so characters such as '+', '/',
            # '&' or '=' cannot corrupt the query string.
            encoded_token = quote(page_token, safe="")
            url = f"{base_url}{separator}pageToken={encoded_token}"
        else:
            url = base_url

        try:
            ## LOGGING
            logging_obj.pre_call(
                input="",
                api_key="",
                additional_args={
                    "complete_input_dict": {},
                    "api_base": url,
                    "headers": headers,
                },
            )

            resp = await client.get(url=url, headers=headers)
            resp.raise_for_status()
        except httpx.HTTPStatusError as e:
            # A 403 on the list call is treated as "no cache found".
            if e.response.status_code == 403:
                return None
            raise VertexAIError(
                status_code=e.response.status_code, message=e.response.text
            ) from e
        except Exception as e:
            raise VertexAIError(status_code=500, message=str(e)) from e

        raw_response = resp.json()
        logging_obj.post_call(original_response=raw_response)

        if "cachedContents" not in raw_response:
            return None

        all_cached_items = CachedContentListAllResponseBody(**raw_response)

        if "cachedContents" not in all_cached_items:
            return None

        # Check current page for matching cache_key
        for cached_item in all_cached_items["cachedContents"]:
            display_name = cached_item.get("displayName")
            if display_name is not None and display_name == cache_key:
                return cached_item.get("name")

        # Check if there are more pages
        page_token = all_cached_items.get("nextPageToken")
        if not page_token:
            # No more pages, cache not found
            break

    return None
|
||||
|
||||
def check_and_create_cache(
    self,
    messages: List[AllMessageValues],  # receives openai format messages
    optional_params: dict,  # cache the tools if present, in case cache content exists in messages
    api_key: str,
    api_base: Optional[str],
    model: str,
    client: Optional[HTTPHandler],
    timeout: Optional[Union[float, httpx.Timeout]],
    logging_obj: Logging,
    custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
    vertex_project: Optional[str],
    vertex_location: Optional[str],
    vertex_auth_header: Optional[str],
    extra_headers: Optional[dict] = None,
    cached_content: Optional[str] = None,
) -> Tuple[List[AllMessageValues], dict, Optional[str]]:
    """
    Reuse or create a cached-content entry for the cached part of `messages`.

    Receives
    - messages: List of dict - messages in the openai format

    Returns
    - messages - List[dict] - filtered list of messages in the openai format.
    - optional_params - dict - the same dict; "tools" is popped from it once
      caching proceeds.
    - cached_content - str - the cache content id, to be passed in the gemini request body

    The inputs are returned untouched when `cached_content` is already
    given, when no message is marked for caching, or when
    `is_prompt_caching_valid_prompt` rejects the cached messages.

    Raises
    - VertexAIError - on an HTTP error status or a timeout while creating
      the cache (lookup errors are raised by `check_cache`).

    Follows - https://ai.google.dev/api/caching#request-body
    """
    if cached_content is not None:
        return messages, optional_params, cached_content

    cached_messages, non_cached_messages = separate_cached_messages(
        messages=messages
    )
    if len(cached_messages) == 0:
        return messages, optional_params, None

    # Gemini requires a minimum of 1024 tokens for context caching.
    # Skip caching if the cached content is too small to avoid API errors.
    meets_minimum_size = is_prompt_caching_valid_prompt(
        model=model,
        messages=cached_messages,
        custom_llm_provider=custom_llm_provider,
    )
    if not meets_minimum_size:
        verbose_logger.debug(
            "Vertex AI context caching: cached content is below minimum token "
            "count (%d). Skipping context caching.",
            MINIMUM_PROMPT_CACHE_TOKEN_COUNT,
        )
        return messages, optional_params, None

    # Tools move out of the request params and into the cached content.
    tools = optional_params.pop("tools", None)

    ## AUTHORIZATION ##
    token, url = self._get_token_and_url_context_caching(
        gemini_api_key=api_key,
        custom_llm_provider=custom_llm_provider,
        api_base=api_base,
        vertex_project=vertex_project,
        vertex_location=vertex_location,
        vertex_auth_header=vertex_auth_header,
    )

    headers = {"Content-Type": "application/json"}
    if token is not None:
        headers["Authorization"] = f"Bearer {token}"
    if extra_headers is not None:
        headers.update(extra_headers)

    # Build a handler when none (or a non-HTTPHandler object) was supplied.
    if not isinstance(client, HTTPHandler):
        handler_kwargs = {}
        if timeout is not None:
            if isinstance(timeout, (float, int)):
                timeout = httpx.Timeout(timeout)
            handler_kwargs["timeout"] = timeout
        client = HTTPHandler(**handler_kwargs)  # type: ignore

    ## CHECK IF CACHED ALREADY
    generated_cache_key = local_cache_obj.get_cache_key(
        messages=cached_messages, tools=tools, model=model
    )
    existing_cache_name = self.check_cache(
        cache_key=generated_cache_key,
        client=client,
        headers=headers,
        api_key=api_key,
        api_base=api_base,
        logging_obj=logging_obj,
        custom_llm_provider=custom_llm_provider,
        vertex_project=vertex_project,
        vertex_location=vertex_location,
        vertex_auth_header=vertex_auth_header,
    )
    if existing_cache_name:
        return non_cached_messages, optional_params, existing_cache_name

    ## TRANSFORM REQUEST
    cached_content_request_body = transform_openai_messages_to_gemini_context_caching(
        model=model,
        messages=cached_messages,
        cache_key=generated_cache_key,
        custom_llm_provider=custom_llm_provider,
        vertex_project=vertex_project,
        vertex_location=vertex_location,
    )
    cached_content_request_body["tools"] = tools

    ## LOGGING
    logging_obj.pre_call(
        input=messages,
        api_key="",
        additional_args={
            "complete_input_dict": cached_content_request_body,
            "api_base": url,
            "headers": headers,
        },
    )

    try:
        response = client.post(
            url=url, headers=headers, json=cached_content_request_body  # type: ignore
        )
        response.raise_for_status()
    except httpx.HTTPStatusError as err:
        raise VertexAIError(
            status_code=err.response.status_code, message=err.response.text
        )
    except httpx.TimeoutException:
        raise VertexAIError(status_code=408, message="Timeout error occurred.")

    created = response.json()
    cached_content_response_obj = VertexAICachedContentResponseObject(
        name=created.get("name"), model=created.get("model")
    )
    return (
        non_cached_messages,
        optional_params,
        cached_content_response_obj["name"],
    )
|
||||
|
||||
async def async_check_and_create_cache(
    self,
    messages: List[AllMessageValues],  # receives openai format messages
    optional_params: dict,  # cache the tools if present, in case cache content exists in messages
    api_key: str,
    api_base: Optional[str],
    model: str,
    client: Optional[AsyncHTTPHandler],
    timeout: Optional[Union[float, httpx.Timeout]],
    logging_obj: Logging,
    custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
    vertex_project: Optional[str],
    vertex_location: Optional[str],
    vertex_auth_header: Optional[str],
    extra_headers: Optional[dict] = None,
    cached_content: Optional[str] = None,
) -> Tuple[List[AllMessageValues], dict, Optional[str]]:
    """
    Async: look up, or create, a cached-content entry for the cacheable messages.

    Receives
    - messages: openai-format messages
    - optional_params: request params; "tools" is popped from this dict once
      caching is attempted, and sent in the cache creation body instead
    - cached_content: an existing cache id; when given, nothing else is done

    Returns a 3-tuple
    - messages - the messages still to be sent in the request (the full list
      when no caching happens, otherwise only the non-cached ones)
    - optional_params - the same dict object that was passed in
    - cached_content - the cache name to put in the gemini request body, or None

    Raises
    - VertexAIError: the create call returned an error status (status code and
      body are forwarded) or timed out (status 408)

    Follows - https://ai.google.dev/api/caching#request-body
    """
    # Caller already has a cache id: nothing to look up or create.
    if cached_content is not None:
        return messages, optional_params, cached_content

    cached_messages, non_cached_messages = separate_cached_messages(messages=messages)
    if len(cached_messages) == 0:
        return messages, optional_params, None

    # Gemini requires a minimum of 1024 tokens for context caching.
    # Skip caching if the cached content is too small to avoid API errors.
    prompt_is_cacheable = is_prompt_caching_valid_prompt(
        model=model,
        messages=cached_messages,
        custom_llm_provider=custom_llm_provider,
    )
    if not prompt_is_cacheable:
        verbose_logger.debug(
            "Vertex AI context caching: cached content is below minimum token "
            "count (%d). Skipping context caching.",
            MINIMUM_PROMPT_CACHE_TOKEN_COUNT,
        )
        return messages, optional_params, None

    # From here on the tools travel with the cached content, not the request.
    tools = optional_params.pop("tools", None)

    ## AUTHORIZATION ##
    token, url = self._get_token_and_url_context_caching(
        gemini_api_key=api_key,
        custom_llm_provider=custom_llm_provider,
        api_base=api_base,
        vertex_project=vertex_project,
        vertex_location=vertex_location,
        vertex_auth_header=vertex_auth_header,
    )

    headers = {"Content-Type": "application/json"}
    if token is not None:
        headers["Authorization"] = f"Bearer {token}"
    if extra_headers is not None:
        # Applied last, so caller-supplied headers win over the defaults above.
        headers.update(extra_headers)

    # isinstance() is False for None, so this also covers the "no client" case.
    if not isinstance(client, AsyncHTTPHandler):
        client = get_async_httpx_client(
            params={"timeout": timeout}, llm_provider=litellm.LlmProviders.VERTEX_AI
        )

    ## CHECK IF CACHED ALREADY
    generated_cache_key = local_cache_obj.get_cache_key(
        messages=cached_messages, tools=tools, model=model
    )
    existing_cache_name = await self.async_check_cache(
        cache_key=generated_cache_key,
        client=client,
        headers=headers,
        api_key=api_key,
        api_base=api_base,
        logging_obj=logging_obj,
        custom_llm_provider=custom_llm_provider,
        vertex_project=vertex_project,
        vertex_location=vertex_location,
        vertex_auth_header=vertex_auth_header,
    )
    if existing_cache_name:
        return non_cached_messages, optional_params, existing_cache_name

    ## TRANSFORM REQUEST
    cached_content_request_body = transform_openai_messages_to_gemini_context_caching(
        model=model,
        messages=cached_messages,
        cache_key=generated_cache_key,
        custom_llm_provider=custom_llm_provider,
        vertex_project=vertex_project,
        vertex_location=vertex_location,
    )
    cached_content_request_body["tools"] = tools

    ## LOGGING
    logging_obj.pre_call(
        input=messages,
        api_key="",
        additional_args={
            "complete_input_dict": cached_content_request_body,
            "api_base": url,
            "headers": headers,
        },
    )

    ## CREATE THE CACHE ENTRY
    try:
        response = await client.post(
            url=url, headers=headers, json=cached_content_request_body  # type: ignore
        )
        response.raise_for_status()
    except httpx.HTTPStatusError as err:
        raise VertexAIError(
            status_code=err.response.status_code, message=err.response.text
        )
    except httpx.TimeoutException:
        raise VertexAIError(status_code=408, message="Timeout error occurred.")

    created = response.json()
    cached_content_response_obj = VertexAICachedContentResponseObject(
        name=created.get("name"), model=created.get("model")
    )
    return (
        non_cached_messages,
        optional_params,
        cached_content_response_obj["name"],
    )
|
||||
|
||||
def get_cache(self):
    """Placeholder with no implementation; always returns None."""
    return None
|
||||
|
||||
async def async_get_cache(self):
    """Placeholder with no implementation; awaiting it yields None."""
    return None
|
||||
@@ -0,0 +1,273 @@
|
||||
# What is this?
|
||||
## Cost calculation for Google AI Studio / Vertex AI models
|
||||
from typing import Literal, Optional, Tuple, Union
|
||||
|
||||
import litellm
|
||||
from litellm import verbose_logger
|
||||
from litellm.litellm_core_utils.llm_cost_calc.utils import (
|
||||
_is_above_128k,
|
||||
generic_cost_per_token,
|
||||
)
|
||||
from litellm.types.utils import ModelInfo, Usage
|
||||
|
||||
"""
|
||||
Gemini pricing covers:
|
||||
- token
|
||||
- image
|
||||
- audio
|
||||
- video
|
||||
"""
|
||||
|
||||
"""
|
||||
Vertex AI -> character based pricing
|
||||
|
||||
Google AI Studio -> token based pricing
|
||||
"""
|
||||
|
||||
# Models that never use the "above 128k tokens" character rates:
# cost_per_character only applies the tiered rate when the model name is NOT
# in this list. The check is an exact match on the model string.
models_without_dynamic_pricing = ["gemini-1.0-pro", "gemini-pro", "gemini-2"]
|
||||
|
||||
|
||||
def cost_router(
    model: str,
    custom_llm_provider: str,
    call_type: Union[Literal["embedding", "aembedding"], str],
) -> Literal["cost_per_character", "cost_per_token"]:
    """
    Pick which google cost calculation applies to a request.

    Token pricing is chosen only for the "vertex_ai" provider, and only when
    one of these holds:
    - the model name contains a partner / token-priced family name
    - the call is an embedding call ("embedding" or "aembedding")

    Everything else, including every other provider, is routed to
    character pricing.

    Returns
    - str, the name of the google cost calc function to route to.
    """
    if custom_llm_provider != "vertex_ai":
        return "cost_per_character"

    # Substrings of model names that are billed per token on Vertex AI.
    token_priced_families = (
        "claude",
        "llama",
        "mistral",
        "jamba",
        "codestral",
        "gemma",
        "gemini-2",
    )
    if any(family in model for family in token_priced_families):
        return "cost_per_token"

    if call_type in ("embedding", "aembedding"):
        return "cost_per_token"

    return "cost_per_character"
|
||||
|
||||
|
||||
def cost_per_character(
    model: str,
    custom_llm_provider: str,
    usage: Usage,
    prompt_characters: Optional[float] = None,
    completion_characters: Optional[float] = None,
) -> Tuple[float, float]:
    """
    Calculates the character-based cost for a given VertexAI model.

    Input:
        - model: str, the model name without provider prefix
        - custom_llm_provider: str, "vertex_ai-*"
        - usage: Usage, token usage; used whenever token pricing is the fallback
        - prompt_characters: float, the number of input characters
        - completion_characters: float, the number of output characters

    Returns:
        Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd

    Fallback behaviour:
        Each side (prompt / completion) is priced per token via cost_per_token
        when its character count is None, or when the character pricing needed
        for it is missing from the model info. Character-pricing errors are
        logged at debug level and are not raised; an error raised by the
        cost_per_token fallback itself does propagate.
    """
    ## GET MODEL INFO
    model_info = litellm.get_model_info(
        model=model, custom_llm_provider=custom_llm_provider
    )

    ## CALCULATE INPUT COST
    if prompt_characters is None:
        prompt_cost, _ = cost_per_token(
            model=model,
            custom_llm_provider=custom_llm_provider,
            usage=usage,
        )
    else:
        try:
            # NOTE(review): with "1 token = 4 char" the token estimate would be
            # characters / 4, not characters * 4 - confirm the intended
            # threshold before changing it, since it decides the pricing tier.
            if (
                _is_above_128k(tokens=prompt_characters * 4)  # 1 token = 4 char
                and model not in models_without_dynamic_pricing
            ):
                ## check if character pricing, else default to token pricing
                # The assert is deliberate control flow: a failure is caught
                # below and triggers the token-pricing fallback.
                assert (
                    "input_cost_per_character_above_128k_tokens" in model_info
                    and model_info["input_cost_per_character_above_128k_tokens"]
                    is not None
                ), "model info for model={} does not have 'input_cost_per_character_above_128k_tokens'-pricing for > 128k tokens\nmodel_info={}".format(
                    model, model_info
                )
                prompt_cost = (
                    prompt_characters
                    * model_info["input_cost_per_character_above_128k_tokens"]
                )
            else:
                assert (
                    "input_cost_per_character" in model_info
                    and model_info["input_cost_per_character"] is not None
                ), "model info for model={} does not have 'input_cost_per_character'-pricing\nmodel_info={}".format(
                    model, model_info
                )
                prompt_cost = prompt_characters * model_info["input_cost_per_character"]
        except Exception as e:
            verbose_logger.debug(
                "litellm.litellm_core_utils.llm_cost_calc.google.py::cost_per_character(): Exception occured - {}\nDefaulting to None".format(
                    str(e)
                )
            )
            prompt_cost, _ = cost_per_token(
                model=model,
                custom_llm_provider=custom_llm_provider,
                usage=usage,
            )

    ## CALCULATE OUTPUT COST
    if completion_characters is None:
        _, completion_cost = cost_per_token(
            model=model,
            custom_llm_provider=custom_llm_provider,
            usage=usage,
        )
    else:
        try:
            # NOTE(review): same "* 4" token estimate as on the input side.
            if (
                _is_above_128k(tokens=completion_characters * 4)  # 1 token = 4 char
                and model not in models_without_dynamic_pricing
            ):
                assert (
                    "output_cost_per_character_above_128k_tokens" in model_info
                    and model_info["output_cost_per_character_above_128k_tokens"]
                    is not None
                ), "model info for model={} does not have 'output_cost_per_character_above_128k_tokens' pricing\nmodel_info={}".format(
                    model, model_info
                )
                # A per-character rate must be applied to the character count
                # (it was previously multiplied by the completion token count).
                completion_cost = (
                    completion_characters
                    * model_info["output_cost_per_character_above_128k_tokens"]
                )
            else:
                assert (
                    "output_cost_per_character" in model_info
                    and model_info["output_cost_per_character"] is not None
                ), "model info for model={} does not have 'output_cost_per_character'-pricing\nmodel_info={}".format(
                    model, model_info
                )
                completion_cost = (
                    completion_characters * model_info["output_cost_per_character"]
                )
        except Exception as e:
            verbose_logger.debug(
                "litellm.litellm_core_utils.llm_cost_calc.google.py::cost_per_character(): Exception occured - {}\nDefaulting to None".format(
                    str(e)
                )
            )
            _, completion_cost = cost_per_token(
                model=model,
                custom_llm_provider=custom_llm_provider,
                usage=usage,
            )

    return prompt_cost, completion_cost
|
||||
|
||||
|
||||
def _handle_128k_pricing(
    model_info: ModelInfo,
    usage: Usage,
) -> Tuple[float, float]:
    """
    Price a request for a model that has "above 128k tokens" rates.

    Each side is priced independently: the above-128k rate is used when that
    side's token count is above 128k and the rate is present; otherwise the
    regular per-token rate from model_info is used.

    Returns:
        Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
    """
    input_rate_above_128k = model_info.get("input_cost_per_token_above_128k_tokens")
    output_rate_above_128k = model_info.get("output_cost_per_token_above_128k_tokens")

    prompt_tokens = usage.prompt_tokens
    completion_tokens = usage.completion_tokens

    ## CALCULATE INPUT COST
    if _is_above_128k(tokens=prompt_tokens) and input_rate_above_128k is not None:
        input_rate = input_rate_above_128k
    else:
        input_rate = model_info["input_cost_per_token"]
    prompt_cost = prompt_tokens * input_rate

    ## CALCULATE OUTPUT COST
    if _is_above_128k(tokens=completion_tokens) and output_rate_above_128k is not None:
        output_rate = output_rate_above_128k
    else:
        output_rate = model_info["output_cost_per_token"]
    completion_cost = completion_tokens * output_rate

    return prompt_cost, completion_cost
|
||||
|
||||
|
||||
def cost_per_token(
    model: str,
    custom_llm_provider: str,
    usage: Usage,
    service_tier: Optional[str] = None,
) -> Tuple[float, float]:
    """
    Calculates the token-based cost of a request for a given model.

    Input:
        - model: str, the model name without provider prefix
        - custom_llm_provider: str, either "vertex_ai-*" or "gemini"
        - usage: Usage, carries the prompt and completion token counts
        - service_tier: optional tier derived from Gemini trafficType
          ("priority" for ON_DEMAND_PRIORITY, "flex" for FLEX/batch).
          It is only forwarded on the generic pricing path; the above-128k
          path below does not receive it.

    Returns:
        Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
    """
    ## GET MODEL INFO
    model_info = litellm.get_model_info(
        model=model, custom_llm_provider=custom_llm_provider
    )

    ## HANDLE 128k+ PRICING
    # A model takes the tiered path as soon as either above-128k rate is set.
    tiered_rate_keys = (
        "input_cost_per_token_above_128k_tokens",
        "output_cost_per_token_above_128k_tokens",
    )
    if any(model_info.get(rate_key) is not None for rate_key in tiered_rate_keys):
        return _handle_128k_pricing(model_info=model_info, usage=usage)

    return generic_cost_per_token(
        model=model,
        custom_llm_provider=custom_llm_provider,
        usage=usage,
        service_tier=service_tier,
    )
|
||||
@@ -0,0 +1,48 @@
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from litellm.llms.gemini.count_tokens.handler import GoogleAIStudioTokenCounter
|
||||
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
|
||||
|
||||
|
||||
class VertexAITokenCounter(GoogleAIStudioTokenCounter, VertexBase):
    """Token counter that targets the Vertex AI countTokens endpoint."""

    async def validate_environment(
        self,
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        headers: Optional[Dict[str, Any]] = None,
        model: str = "",
        litellm_params: Optional[Dict[str, Any]] = None,
    ) -> Tuple[Dict[str, Any], str]:
        """
        Returns a Tuple of headers and url for the Vertex AI countTokens endpoint.

        Credentials, project and location are all read from litellm_params.
        The incoming api_base, api_key and headers arguments are not read:
        the url and the Authorization header are always rebuilt here.
        """
        params: Dict[str, Any] = litellm_params or {}

        credentials = self.get_vertex_ai_credentials(litellm_params=params)
        project = self.get_vertex_ai_project(litellm_params=params)
        location = self.get_vertex_ai_location(litellm_params=params)
        use_v1beta1 = self.is_using_v1beta1_features(params)

        # The token call also returns the project id that is used from here on.
        access_token, project = await self._ensure_access_token_async(
            credentials=credentials,
            project_id=project,
            custom_llm_provider="vertex_ai",
        )

        bearer_token, count_tokens_url = self._get_token_and_url(
            model=model,
            gemini_api_key=None,
            auth_header=access_token,
            vertex_project=project,
            vertex_location=location,
            vertex_credentials=credentials,
            stream=False,
            custom_llm_provider="vertex_ai",
            api_base=None,
            should_use_v1beta1_features=use_v1beta1,
            mode="count_tokens",
        )

        return {"Authorization": f"Bearer {bearer_token}"}, count_tokens_url
|
||||
@@ -0,0 +1,246 @@
|
||||
import asyncio
|
||||
import urllib.parse
|
||||
from typing import Any, Coroutine, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm import LlmProviders
|
||||
from litellm.integrations.gcs_bucket.gcs_bucket_base import (
|
||||
GCSBucketBase,
|
||||
GCSLoggingConfig,
|
||||
)
|
||||
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
|
||||
from litellm.types.llms.openai import (
|
||||
CreateFileRequest,
|
||||
FileContentRequest,
|
||||
HttpxBinaryResponseContent,
|
||||
OpenAIFileObject,
|
||||
)
|
||||
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
|
||||
|
||||
from .transformation import VertexAIJsonlFilesTransformation
|
||||
|
||||
# Module-level transformation object shared by VertexAIFilesHandler; it
# converts between OpenAI file payloads and the GCS upload payload/response.
vertex_ai_files_transformation = VertexAIJsonlFilesTransformation()
|
||||
|
||||
|
||||
class VertexAIFilesHandler(GCSBucketBase):
    """
    Handles Calling VertexAI in OpenAI Files API format v1/files/*

    This implementation uploads files on GCS Buckets and downloads their
    content back from GCS.
    """

    def __init__(self):
        super().__init__()
        self.async_httpx_client = get_async_httpx_client(
            llm_provider=LlmProviders.VERTEX_AI,
        )

    async def async_create_file(
        self,
        create_file_data: CreateFileRequest,
        api_base: Optional[str],
        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
    ) -> OpenAIFileObject:
        """
        Upload the file in create_file_data to GCS and describe it as an OpenAI file.

        Bucket and credentials come from the GCS logging config; the api_base,
        vertex_* , timeout and max_retries arguments are not read here.
        """
        gcs_config: GCSLoggingConfig = await self.get_gcs_logging_config(kwargs={})
        request_headers = await self.construct_request_headers(
            vertex_instance=gcs_config["vertex_instance"],
            service_account_json=gcs_config["path_service_account"],
        )
        target_bucket = gcs_config["bucket_name"]

        (
            upload_payload,
            object_name,
        ) = vertex_ai_files_transformation.transform_openai_file_content_to_vertex_ai_file_content(
            openai_file_content=create_file_data.get("file")
        )

        gcs_upload_response = await self._log_json_data_on_gcs(
            headers=request_headers,
            bucket_name=target_bucket,
            object_name=object_name,
            logging_payload=upload_payload,
        )

        return vertex_ai_files_transformation.transform_gcs_bucket_response_to_openai_file_object(
            create_file_data=create_file_data,
            gcs_upload_response=gcs_upload_response,
        )

    def create_file(
        self,
        _is_async: bool,
        create_file_data: CreateFileRequest,
        api_base: Optional[str],
        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
    ) -> Union[OpenAIFileObject, Coroutine[Any, Any, OpenAIFileObject]]:
        """
        Creates a file on VertexAI GCS Bucket

        With _is_async=True the un-awaited coroutine is returned for the caller
        to await; otherwise the coroutine is run to completion with
        asyncio.run and the resulting file object is returned.
        """
        upload = self.async_create_file(
            create_file_data=create_file_data,
            api_base=api_base,
            vertex_credentials=vertex_credentials,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            timeout=timeout,
            max_retries=max_retries,
        )
        if _is_async:
            return upload
        return asyncio.run(upload)

    def _extract_bucket_and_object_from_file_id(self, file_id: str) -> Tuple[str, str]:
        """
        Extract bucket name and object path from URL-encoded file_id.

        Expected format: gs%3A%2F%2Fbucket-name%2Fpath%2Fto%2Ffile
        Which decodes to: gs://bucket-name/path/to/file

        Returns:
            tuple: (bucket_name, url_encoded_object_path)
                - bucket_name: "bucket-name"
                - url_encoded_object_path: "path%2Fto%2Ffile"
              The object path is "" when the id holds nothing after the bucket.
        """
        scheme = "gs://"
        location = urllib.parse.unquote(file_id)
        if location.startswith(scheme):
            location = location[len(scheme):]

        # Everything up to the first "/" is the bucket; the rest is the object.
        bucket_name, _, object_path = location.partition("/")

        # safe="" so that "/" inside the object path is percent-encoded too.
        return bucket_name, urllib.parse.quote(object_path, safe="")

    async def afile_content(
        self,
        file_content_request: FileContentRequest,
        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
    ) -> HttpxBinaryResponseContent:
        """
        Download file content from GCS bucket for VertexAI files.

        Args:
            file_content_request: Contains file_id (URL-encoded GCS path)
            vertex_credentials: VertexAI credentials (not read here)
            vertex_project: VertexAI project ID (not read here)
            vertex_location: VertexAI location (not read here)
            timeout: Request timeout (not read here)
            max_retries: Max retry attempts (not read here)

        Returns:
            HttpxBinaryResponseContent: Binary content wrapped in compatible response format

        Raises:
            ValueError: file_id is missing, or the download returned nothing
        """
        file_id = file_content_request.get("file_id")
        if not file_id:
            raise ValueError("file_id is required in file_content_request")

        bucket_name, encoded_object_path = self._extract_bucket_and_object_from_file_id(
            file_id
        )

        file_content = await self.download_gcs_object(
            object_name=encoded_object_path,
            standard_callback_dynamic_params={"gcs_bucket_name": bucket_name},
        )

        gcs_uri = urllib.parse.unquote(file_id)
        if file_content is None:
            raise ValueError(f"Failed to download file from GCS: {gcs_uri}")

        # No HTTP response object is available at this point, so a synthetic
        # 200 response is built around the downloaded bytes.
        mock_response = httpx.Response(
            status_code=200,
            content=file_content,
            headers={"content-type": "application/octet-stream"},
            request=httpx.Request(method="GET", url=gcs_uri),
        )
        return HttpxBinaryResponseContent(response=mock_response)

    def file_content(
        self,
        _is_async: bool,
        file_content_request: FileContentRequest,
        api_base: Optional[str],
        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
    ) -> Union[
        HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent]
    ]:
        """
        Download file content from GCS bucket for VertexAI files.
        Supports both sync and async operations.

        Args:
            _is_async: Whether to run asynchronously
            file_content_request: Contains file_id (URL-encoded GCS path)
            api_base: API base (unused for GCS operations)
            vertex_credentials: VertexAI credentials
            vertex_project: VertexAI project ID
            vertex_location: VertexAI location
            timeout: Request timeout
            max_retries: Max retry attempts

        Returns:
            HttpxBinaryResponseContent or Coroutine: Binary content wrapped in compatible response format
        """
        download = self.afile_content(
            file_content_request=file_content_request,
            vertex_credentials=vertex_credentials,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            timeout=timeout,
            max_retries=max_retries,
        )
        if _is_async:
            return download
        return asyncio.run(download)
|
||||
@@ -0,0 +1,607 @@
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from httpx import Headers, Response
|
||||
from openai.types.file_deleted import FileDeleted
|
||||
|
||||
from litellm._uuid import uuid
|
||||
from litellm.files.utils import FilesAPIUtils
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.llms.base_llm.files.transformation import (
|
||||
BaseFilesConfig,
|
||||
LiteLLMLoggingObj,
|
||||
)
|
||||
from litellm.llms.vertex_ai.common_utils import (
|
||||
_convert_vertex_datetime_to_openai_datetime,
|
||||
)
|
||||
from litellm.llms.vertex_ai.gemini.transformation import _transform_request_body
|
||||
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
|
||||
VertexGeminiConfig,
|
||||
)
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
CreateFileRequest,
|
||||
FileTypes,
|
||||
HttpxBinaryResponseContent,
|
||||
OpenAICreateFileRequestOptionalParams,
|
||||
OpenAIFileObject,
|
||||
PathLike,
|
||||
)
|
||||
from litellm.types.llms.vertex_ai import GcsBucketResponse
|
||||
from litellm.types.utils import ExtractedFileData, LlmProviders
|
||||
|
||||
from ..common_utils import VertexAIError
|
||||
from ..vertex_llm_base import VertexBase
|
||||
|
||||
|
||||
class VertexAIFilesConfig(VertexBase, BaseFilesConfig):
|
||||
"""
|
||||
Config for VertexAI Files
|
||||
"""
|
||||
|
||||
def __init__(self):
    """Create the JSONL transformation helper, then run the base initialisers."""
    # Assigned before super().__init__(), as in the original ordering.
    self.jsonl_transformation = VertexAIJsonlFilesTransformation()
    super().__init__()
|
||||
|
||||
@property
def custom_llm_provider(self) -> LlmProviders:
    """Provider this files config belongs to: always Vertex AI."""
    return LlmProviders.VERTEX_AI
|
||||
|
||||
def validate_environment(
    self,
    headers: dict,
    model: str,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
) -> dict:
    """
    Add a bearer Authorization header to `headers` and return that same dict.

    The given api_key is used when set. Otherwise an access token is obtained
    from the "vertex_credentials" / "vertex_project" entries of litellm_params.

    Raises:
        ValueError: no api_key was given and no access token could be obtained
    """
    bearer = api_key
    if not bearer:
        bearer, _ = self.get_access_token(
            credentials=litellm_params.get("vertex_credentials"),
            project_id=litellm_params.get("vertex_project"),
        )
        if not bearer:
            raise ValueError("api_key is required")

    headers["Authorization"] = f"Bearer {bearer}"
    return headers
|
||||
|
||||
def _get_content_from_openai_file(self, openai_file_content: FileTypes) -> str:
|
||||
"""
|
||||
Helper to extract content from various OpenAI file types and return as string.
|
||||
|
||||
Handles:
|
||||
- Direct content (str, bytes, IO[bytes])
|
||||
- Tuple formats: (filename, content, [content_type], [headers])
|
||||
- PathLike objects
|
||||
"""
|
||||
content: Union[str, bytes] = b""
|
||||
# Extract file content from tuple if necessary
|
||||
if isinstance(openai_file_content, tuple):
|
||||
# Take the second element which is always the file content
|
||||
file_content = openai_file_content[1]
|
||||
else:
|
||||
file_content = openai_file_content
|
||||
|
||||
# Handle different file content types
|
||||
if isinstance(file_content, str):
|
||||
# String content can be used directly
|
||||
content = file_content
|
||||
elif isinstance(file_content, bytes):
|
||||
# Bytes content can be decoded
|
||||
content = file_content
|
||||
elif isinstance(file_content, PathLike): # PathLike
|
||||
with open(str(file_content), "rb") as f:
|
||||
content = f.read()
|
||||
elif hasattr(file_content, "read"): # IO[bytes]
|
||||
# File-like objects need to be read
|
||||
content = file_content.read()
|
||||
|
||||
# Ensure content is string
|
||||
if isinstance(content, bytes):
|
||||
content = content.decode("utf-8")
|
||||
|
||||
return content
|
||||
|
||||
def _get_gcs_object_name_from_batch_jsonl(
|
||||
self,
|
||||
openai_jsonl_content: List[Dict[str, Any]],
|
||||
) -> str:
|
||||
"""
|
||||
Gets a unique GCS object name for the VertexAI batch prediction job
|
||||
|
||||
named as: litellm-vertex-{model}-{uuid}
|
||||
"""
|
||||
_model = openai_jsonl_content[0].get("body", {}).get("model", "")
|
||||
if "publishers/google/models" not in _model:
|
||||
_model = f"publishers/google/models/{_model}"
|
||||
object_name = f"litellm-vertex-files/{_model}/{uuid.uuid4()}"
|
||||
return object_name
|
||||
|
||||
def get_object_name(
    self, extracted_file_data: ExtractedFileData, purpose: str
) -> str:
    """
    Get the object name for the request

    Resolution order:
    1. purpose == "batch" and the content holds at least one JSONL record:
       a generated batch object name based on the first record
    2. the extracted filename, when there is one
    3. the current unix timestamp, as a string

    Raises:
        ValueError: the extracted file data has no content
    """
    content = extracted_file_data.get("content")
    if content is None:
        raise ValueError("file content is required")

    if purpose == "batch":
        ## 1. If jsonl, check if there's a model name
        text = self._get_content_from_openai_file(content)
        # One JSON document per non-blank line.
        records = [json.loads(line) for line in text.splitlines() if line.strip()]
        if len(records) > 0:
            return self._get_gcs_object_name_from_batch_jsonl(records)

    ## 2. If not jsonl, return the filename; 3. else fall back to a timestamp
    return extracted_file_data.get("filename") or str(int(time.time()))
|
||||
|
||||
def get_complete_file_url(
    self,
    api_base: Optional[str],
    api_key: Optional[str],
    model: str,
    optional_params: Dict,
    litellm_params: Dict,
    data: CreateFileRequest,
) -> str:
    """
    Get the complete url for the request

    Builds the GCS media-upload url
    "{api_base}/upload/storage/v1/b/{bucket}/o?uploadType=media&name={object}".

    The bucket is taken from, in order: litellm_params["bucket_name"],
    litellm_params["litellm_metadata"]["gcs_bucket_name"] (popped from the
    metadata dict when it is reached), then the GCS_BUCKET_NAME env var.
    api_base defaults to "https://storage.googleapis.com".

    Raises:
        ValueError: no bucket name can be resolved, or data has no "file" or
            no "purpose"
    """
    # `or {}` also covers a "litellm_metadata" key that is explicitly None,
    # which previously failed with AttributeError on `.pop`.
    metadata = litellm_params.get("litellm_metadata") or {}
    bucket_name = (
        litellm_params.get("bucket_name")
        or metadata.pop("gcs_bucket_name", None)
        or os.getenv("GCS_BUCKET_NAME")
    )
    if not bucket_name:
        raise ValueError("GCS bucket_name is required")

    file_data = data.get("file")
    purpose = data.get("purpose")
    if file_data is None:
        raise ValueError("file is required")
    if purpose is None:
        raise ValueError("purpose is required")

    extracted_file_data = extract_file_data(file_data)
    object_name = self.get_object_name(extracted_file_data, purpose)

    # NOTE(review): object_name is placed in the query string as-is; names
    # containing "&", "#" or "+" would need percent-encoding - confirm whether
    # callers can supply such filenames.
    endpoint = (
        f"upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}"
    )
    # The default makes api_base always truthy, so no further check is needed.
    api_base = api_base or "https://storage.googleapis.com"

    return f"{api_base}/{endpoint}"
|
||||
|
||||
def get_supported_openai_params(
    self, model: str
) -> List[OpenAICreateFileRequestOptionalParams]:
    """Return the supported optional create-file params: always an empty list."""
    supported: List[OpenAICreateFileRequestOptionalParams] = []
    return supported
|
||||
|
||||
def map_openai_params(
    self,
    non_default_params: dict,
    optional_params: dict,
    model: str,
    drop_params: bool,
) -> dict:
    """
    Return optional_params unchanged.

    No OpenAI params are mapped for file requests, so non_default_params,
    model and drop_params are not read.
    """
    return optional_params
|
||||
|
||||
def _map_openai_to_vertex_params(
    self,
    openai_request_body: Dict[str, Any],
) -> Dict[str, Any]:
    """
    wrapper to call VertexGeminiConfig.map_openai_params

    The whole request body is passed as the non-default params, starting from
    empty optional params and with drop_params disabled. The model is the
    body's "model" entry, or "" when absent.
    """
    from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
        VertexGeminiConfig,
    )

    return VertexGeminiConfig().map_openai_params(
        model=openai_request_body.get("model", ""),
        non_default_params=openai_request_body,
        optional_params={},
        drop_params=False,
    )
|
||||
|
||||
def _transform_openai_jsonl_content_to_vertex_ai_jsonl_content(
|
||||
self, openai_jsonl_content: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Transforms OpenAI JSONL content to VertexAI JSONL content
|
||||
|
||||
jsonl body for vertex is {"request": <request_body>}
|
||||
Example Vertex jsonl
|
||||
{"request":{"contents": [{"role": "user", "parts": [{"text": "What is the relation between the following video and image samples?"}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/video/animals.mp4", "mimeType": "video/mp4"}}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/image/cricket.jpeg", "mimeType": "image/jpeg"}}]}]}}
|
||||
{"request":{"contents": [{"role": "user", "parts": [{"text": "Describe what is happening in this video."}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/video/another_video.mov", "mimeType": "video/mov"}}]}]}}
|
||||
"""
|
||||
|
||||
vertex_jsonl_content = []
|
||||
for _openai_jsonl_content in openai_jsonl_content:
|
||||
openai_request_body = _openai_jsonl_content.get("body") or {}
|
||||
vertex_request_body = _transform_request_body(
|
||||
messages=openai_request_body.get("messages", []),
|
||||
model=openai_request_body.get("model", ""),
|
||||
optional_params=self._map_openai_to_vertex_params(openai_request_body),
|
||||
custom_llm_provider="vertex_ai",
|
||||
litellm_params={},
|
||||
cached_content=None,
|
||||
)
|
||||
vertex_jsonl_content.append({"request": vertex_request_body})
|
||||
return vertex_jsonl_content
|
||||
|
||||
def transform_create_file_request(
    self,
    model: str,
    create_file_data: CreateFileRequest,
    optional_params: dict,
    litellm_params: dict,
) -> Union[bytes, str, dict]:
    """
    Build the upload payload for a create-file request.

    Two cases:
    1. Batch upload (.jsonl): every non-blank line is parsed as JSON,
       converted to Vertex AI's JSONL shape and re-serialized, one object per
       line, into a single string.
    2. Basic upload: ``bytes`` content is returned unchanged.

    Raises:
        ValueError: If the file or its content is missing, or a non-batch
            upload has content that is not ``bytes``.
    """
    upload = create_file_data.get("file")
    if upload is None:
        raise ValueError("file is required")

    extracted_file_data = extract_file_data(upload)
    payload = extracted_file_data.get("content")
    if payload is None:
        raise ValueError("file content is required")

    is_batch_upload = FilesAPIUtils.is_batch_jsonl_file(
        create_file_data=create_file_data,
        extracted_file_data=extracted_file_data,
    )
    if not is_batch_upload:
        # Basic upload: only raw bytes can be forwarded as-is.
        if isinstance(payload, bytes):
            return payload
        raise ValueError("Unsupported file content type")

    raw_text = self._get_content_from_openai_file(payload)

    # One JSON document per non-blank line.
    openai_entries = []
    for line in raw_text.splitlines():
        if line.strip():
            openai_entries.append(json.loads(line))

    vertex_entries = self._transform_openai_jsonl_content_to_vertex_ai_jsonl_content(
        openai_entries
    )
    return "\n".join(map(json.dumps, vertex_entries))
|
||||
|
||||
def transform_create_file_response(
    self,
    model: Optional[str],
    raw_response: Response,
    logging_obj: LiteLLMLoggingObj,
    litellm_params: dict,
) -> OpenAIFileObject:
    """
    Convert a GCS upload response into an OpenAI-style file object.

    The returned ``id`` is ``gs://`` followed by the GCS ``id`` with its last
    path segment removed.

    Raises:
        VertexAIError: If the response JSON cannot be read as a
            ``GcsBucketResponse``.
    """
    payload = raw_response.json()

    try:
        bucket_object = GcsBucketResponse(**payload)  # type: ignore
    except Exception as e:
        raise VertexAIError(
            status_code=raw_response.status_code,
            message=f"Error reading GCS bucket response: {e}",
            headers=raw_response.headers,
        )

    # Drop the trailing numeric ID segment from the GCS id.
    full_id = bucket_object.get("id", "")
    object_path = full_id.rpartition("/")[0] if full_id else ""

    return OpenAIFileObject(
        purpose=bucket_object.get("purpose", "batch"),
        id=f"gs://{object_path}",
        filename=bucket_object.get("name", ""),
        created_at=_convert_vertex_datetime_to_openai_datetime(
            vertex_datetime=bucket_object.get("timeCreated", "")
        ),
        status="uploaded",
        bytes=int(bucket_object.get("size", 0)),
        object="file",
    )
|
||||
|
||||
def get_error_class(
    self, error_message: str, status_code: int, headers: Union[Dict, Headers]
) -> BaseLLMException:
    """Build the ``VertexAIError`` carrying the given status, message and headers."""
    error_kwargs = {
        "status_code": status_code,
        "message": error_message,
        "headers": headers,
    }
    return VertexAIError(**error_kwargs)
|
||||
|
||||
def _parse_gcs_uri(self, file_id: str) -> Tuple[str, str]:
|
||||
"""
|
||||
Parse a GCS URI (gs://bucket/path/to/object) into (bucket, url-encoded-object-path).
|
||||
Handles both raw and URL-encoded input.
|
||||
"""
|
||||
import urllib.parse
|
||||
|
||||
decoded = urllib.parse.unquote(file_id)
|
||||
if decoded.startswith("gs://"):
|
||||
full_path = decoded[5:]
|
||||
else:
|
||||
full_path = decoded
|
||||
|
||||
if "/" in full_path:
|
||||
bucket_name, object_path = full_path.split("/", 1)
|
||||
else:
|
||||
bucket_name = full_path
|
||||
object_path = ""
|
||||
|
||||
encoded_object = urllib.parse.quote(object_path, safe="")
|
||||
return bucket_name, encoded_object
|
||||
|
||||
def transform_retrieve_file_request(
    self,
    file_id: str,
    optional_params: dict,
    litellm_params: dict,
) -> tuple[str, dict]:
    """
    Build the GCS object URL used to retrieve a file.

    Returns:
        ``(url, params)`` where ``params`` is always an empty dict.
    """
    bucket_name, object_name = self._parse_gcs_uri(file_id)
    request_url = (
        f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}"
    )
    request_params: dict = {}
    return request_url, request_params
|
||||
|
||||
def transform_retrieve_file_response(
    self,
    raw_response: Response,
    logging_obj: LiteLLMLoggingObj,
    litellm_params: dict,
) -> OpenAIFileObject:
    """
    Convert a GCS object-metadata response into an OpenAI-style file object.

    The ``id`` is ``gs://`` plus the GCS ``id`` without its last path segment;
    ``purpose`` is read from the object's ``metadata`` and defaults to
    ``"batch"``.
    """
    payload = raw_response.json()

    full_id = payload.get("id", "")
    object_path = full_id.rpartition("/")[0] if full_id else ""

    return OpenAIFileObject(
        id=f"gs://{object_path}",
        bytes=int(payload.get("size", 0)),
        created_at=_convert_vertex_datetime_to_openai_datetime(
            vertex_datetime=payload.get("timeCreated", "")
        ),
        filename=payload.get("name", ""),
        object="file",
        purpose=payload.get("metadata", {}).get("purpose", "batch"),
        status="processed",
        status_details=None,
    )
|
||||
|
||||
def transform_delete_file_request(
    self,
    file_id: str,
    optional_params: dict,
    litellm_params: dict,
) -> tuple[str, dict]:
    """
    Build the GCS object URL used to delete a file.

    Returns:
        ``(url, params)`` where ``params`` is always an empty dict.
    """
    bucket_name, object_name = self._parse_gcs_uri(file_id)
    request_url = (
        f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}"
    )
    request_params: dict = {}
    return request_url, request_params
|
||||
|
||||
def transform_delete_file_response(
    self,
    raw_response: Response,
    logging_obj: LiteLLMLoggingObj,
    litellm_params: dict,
) -> FileDeleted:
    """
    Build the OpenAI-style deletion result for a GCS delete call.

    The file id is rebuilt from the request URL (``.../b/<bucket>/o/<object>``)
    as ``gs://<bucket>/<decoded object>``. When the response has no request,
    or the URL lacks those markers, the id falls back to ``"deleted"``.
    """
    import urllib.parse

    deleted_id = "deleted"
    request = raw_response.request if hasattr(raw_response, "request") else None
    request_url = str(request.url) if request else ""

    if "/b/" in request_url and "/o/" in request_url:
        bucket = request_url.rpartition("/b/")[2].partition("/o/")[0]
        # Strip any query string before decoding the object name.
        encoded_name = request_url.rpartition("/o/")[2].partition("?")[0]
        deleted_id = f"gs://{bucket}/{urllib.parse.unquote(encoded_name)}"

    return FileDeleted(id=deleted_id, deleted=True, object="file")
|
||||
|
||||
def transform_list_files_request(
    self,
    purpose: Optional[str],
    optional_params: dict,
    litellm_params: dict,
) -> tuple[str, dict]:
    """
    File listing is not supported by this config.

    Raises:
        NotImplementedError: Always.
    """
    unsupported_message = "VertexAIFilesConfig does not support file listing"
    raise NotImplementedError(unsupported_message)
|
||||
|
||||
def transform_list_files_response(
    self,
    raw_response: Response,
    logging_obj: LiteLLMLoggingObj,
    litellm_params: dict,
) -> List[OpenAIFileObject]:
    """
    File listing is not supported by this config.

    Raises:
        NotImplementedError: Always.
    """
    unsupported_message = "VertexAIFilesConfig does not support file listing"
    raise NotImplementedError(unsupported_message)
|
||||
|
||||
def transform_file_content_request(
    self,
    file_content_request,
    optional_params: dict,
    litellm_params: dict,
) -> tuple[str, dict]:
    """
    Build the GCS download URL (``alt=media``) for a file's content.

    ``file_content_request["file_id"]`` is read (empty string when absent)
    and parsed as a GCS URI.

    Returns:
        ``(url, params)`` where ``params`` is always an empty dict.
    """
    requested_id = file_content_request.get("file_id", "")
    bucket_name, object_name = self._parse_gcs_uri(requested_id)
    download_url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}?alt=media"
    request_params: dict = {}
    return download_url, request_params
|
||||
|
||||
def transform_file_content_response(
    self,
    raw_response: Response,
    logging_obj: LiteLLMLoggingObj,
    litellm_params: dict,
) -> HttpxBinaryResponseContent:
    """Wrap the raw HTTP response in an ``HttpxBinaryResponseContent``."""
    binary_content = HttpxBinaryResponseContent(response=raw_response)
    return binary_content
|
||||
|
||||
|
||||
class VertexAIJsonlFilesTransformation(VertexGeminiConfig):
    """
    Transforms OpenAI /v1/files/* requests to VertexAI /v1/files/* requests
    """

    def transform_openai_file_content_to_vertex_ai_file_content(
        self, openai_file_content: Optional[FileTypes] = None
    ) -> Tuple[str, str]:
        """
        Convert an OpenAI batch JSONL upload into Vertex AI JSONL.

        Returns:
            ``(vertex_jsonl_string, gcs_object_name)``.

        Raises:
            ValueError: If ``openai_file_content`` is None.
        """
        if openai_file_content is None:
            raise ValueError("contents of file are None")

        # Read the content of the file
        file_content = self._get_content_from_openai_file(openai_file_content)

        # Split into lines and parse each non-blank line as JSON
        openai_jsonl_content = [
            json.loads(line) for line in file_content.splitlines() if line.strip()
        ]
        vertex_jsonl_content = (
            self._transform_openai_jsonl_content_to_vertex_ai_jsonl_content(
                openai_jsonl_content
            )
        )
        vertex_jsonl_string = "\n".join(
            json.dumps(item) for item in vertex_jsonl_content
        )
        object_name = self._get_gcs_object_name(
            openai_jsonl_content=openai_jsonl_content
        )
        return vertex_jsonl_string, object_name

    def _transform_openai_jsonl_content_to_vertex_ai_jsonl_content(
        self, openai_jsonl_content: List[Dict[str, Any]]
    ):
        """
        Transforms OpenAI JSONL content to VertexAI JSONL content

        jsonl body for vertex is {"request": <request_body>}
        Example Vertex jsonl
        {"request":{"contents": [{"role": "user", "parts": [{"text": "Describe what is happening in this video."}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/video/another_video.mov", "mimeType": "video/mov"}}]}]}}
        """
        vertex_jsonl_content = []
        for _openai_jsonl_content in openai_jsonl_content:
            # A missing or null body is treated as an empty request.
            openai_request_body = _openai_jsonl_content.get("body") or {}
            vertex_request_body = _transform_request_body(
                messages=openai_request_body.get("messages", []),
                model=openai_request_body.get("model", ""),
                optional_params=self._map_openai_to_vertex_params(openai_request_body),
                custom_llm_provider="vertex_ai",
                litellm_params={},
                cached_content=None,
            )
            vertex_jsonl_content.append({"request": vertex_request_body})
        return vertex_jsonl_content

    def _get_gcs_object_name(
        self,
        openai_jsonl_content: List[Dict[str, Any]],
    ) -> str:
        """
        Gets a unique GCS object name for the VertexAI batch prediction job

        named as: litellm-vertex-files/{model}/{uuid}

        The model is taken from the first JSONL entry's body and prefixed with
        ``publishers/google/models/`` when it does not already contain that
        path. An empty file, or a first entry without a body or model, yields
        an empty model name.
        """
        # Guard against an empty JSONL file and against null body/model values,
        # consistent with how missing bodies are handled during transformation.
        first_entry = openai_jsonl_content[0] if openai_jsonl_content else {}
        first_body = first_entry.get("body") or {}
        _model = first_body.get("model") or ""
        if "publishers/google/models" not in _model:
            _model = f"publishers/google/models/{_model}"
        object_name = f"litellm-vertex-files/{_model}/{uuid.uuid4()}"
        return object_name

    def _map_openai_to_vertex_params(
        self,
        openai_request_body: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        wrapper to call VertexGeminiConfig.map_openai_params
        """
        _model = openai_request_body.get("model", "")
        vertex_params = self.map_openai_params(
            model=_model,
            non_default_params=openai_request_body,
            optional_params={},
            drop_params=False,
        )
        return vertex_params

    def _get_content_from_openai_file(self, openai_file_content: FileTypes) -> str:
        """
        Helper to extract content from various OpenAI file types and return as string.

        Handles:
        - Direct content (str, bytes, IO[bytes])
        - Tuple formats: (filename, content, [content_type], [headers])
        - PathLike objects

        Content of any other type results in an empty string.
        """
        content: Union[str, bytes] = b""
        # Extract file content from tuple if necessary
        if isinstance(openai_file_content, tuple):
            # Take the second element which is always the file content
            file_content = openai_file_content[1]
        else:
            file_content = openai_file_content

        # Handle different file content types
        if isinstance(file_content, str):
            # String content can be used directly
            content = file_content
        elif isinstance(file_content, bytes):
            # Bytes content is kept as-is here and decoded below
            content = file_content
        elif isinstance(file_content, PathLike):  # PathLike
            with open(str(file_content), "rb") as f:
                content = f.read()
        elif hasattr(file_content, "read"):  # IO[bytes]
            # File-like objects need to be read
            content = file_content.read()

        # Ensure content is string
        if isinstance(content, bytes):
            content = content.decode("utf-8")

        return content

    def transform_gcs_bucket_response_to_openai_file_object(
        self, create_file_data: CreateFileRequest, gcs_upload_response: Dict[str, Any]
    ) -> OpenAIFileObject:
        """
        Transforms GCS Bucket upload file response to OpenAI FileObject
        """
        gcs_id = gcs_upload_response.get("id", "")
        # Remove the last numeric ID from the path
        gcs_id = "/".join(gcs_id.split("/")[:-1]) if gcs_id else ""

        return OpenAIFileObject(
            purpose=create_file_data.get("purpose", "batch"),
            id=f"gs://{gcs_id}",
            filename=gcs_upload_response.get("name", ""),
            created_at=_convert_vertex_datetime_to_openai_datetime(
                vertex_datetime=gcs_upload_response.get("timeCreated", "")
            ),
            status="uploaded",
            # Coerce to int, matching the other GCS response transformations
            # in this file; a missing or null size becomes 0.
            bytes=int(gcs_upload_response.get("size", 0) or 0),
            object="file",
        )
|
||||
@@ -0,0 +1,374 @@
|
||||
import json
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from typing import Any, Coroutine, Literal, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler, get_async_httpx_client
|
||||
from litellm.llms.vertex_ai.common_utils import get_vertex_base_url
|
||||
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
|
||||
from litellm.types.fine_tuning import OpenAIFineTuningHyperparameters
|
||||
from litellm.types.llms.openai import FineTuningJobCreate
|
||||
from litellm.types.llms.vertex_ai import (
|
||||
VERTEX_CREDENTIALS_TYPES,
|
||||
FineTuneHyperparameters,
|
||||
FineTuneJobCreate,
|
||||
FineTunesupervisedTuningSpec,
|
||||
ResponseSupervisedTuningSpec,
|
||||
ResponseTuningJob,
|
||||
)
|
||||
from litellm.types.utils import LiteLLMFineTuningJob
|
||||
|
||||
|
||||
class VertexFineTuningAPI(VertexLLM):
    """
    Vertex AI fine-tuning (``tuningJobs``) support.

    Converts OpenAI-style fine-tuning requests/responses to and from the
    Vertex AI format, creates tuning jobs (sync and async), and exposes an
    async pass-through for a fixed set of Vertex AI POST routes.
    """

    def __init__(self) -> None:
        super().__init__()
        self.async_handler = get_async_httpx_client(
            llm_provider=litellm.LlmProviders.VERTEX_AI,
            params={"timeout": 600.0},
        )

    def convert_response_created_at(self, response: ResponseTuningJob):
        """
        Return the job's ``createTime`` as a Unix timestamp in seconds.

        Best effort: returns 0 when ``createTime`` is missing or unparsable.
        """
        try:
            create_time_str = response.get("createTime", "") or ""
            # fromisoformat does not accept a trailing "Z" on older Pythons.
            create_time_datetime = datetime.fromisoformat(
                create_time_str.replace("Z", "+00:00")
            )
            # Convert to Unix timestamp (seconds since epoch)
            return int(create_time_datetime.timestamp())
        except Exception:
            return 0

    def convert_openai_request_to_vertex(
        self,
        create_fine_tuning_job_data: FineTuningJobCreate,
        original_hyperparameters: Optional[dict] = None,
        kwargs: Optional[dict] = None,
    ) -> FineTuneJobCreate:
        """
        convert request from OpenAI format to Vertex format
        https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning

        Args:
            create_fine_tuning_job_data: OpenAI-style job creation request.
            original_hyperparameters: Raw hyperparameters dict; only
                ``adapter_size`` is read from it. ``None`` means empty.
            kwargs: Forwarded to the hyperparameter transformation.
        """
        supervised_tuning_spec = FineTunesupervisedTuningSpec(
            training_dataset_uri=create_fine_tuning_job_data.training_file,
        )

        if create_fine_tuning_job_data.validation_file:
            supervised_tuning_spec[
                "validation_dataset"
            ] = create_fine_tuning_job_data.validation_file

        _vertex_hyperparameters = (
            self._transform_openai_hyperparameters_to_vertex_hyperparameters(
                create_fine_tuning_job_data=create_fine_tuning_job_data,
                kwargs=kwargs,
                original_hyperparameters=original_hyperparameters or {},
            )
        )

        # Only attach hyperparameters when at least one was set.
        if _vertex_hyperparameters:
            supervised_tuning_spec["hyperParameters"] = _vertex_hyperparameters

        return FineTuneJobCreate(
            baseModel=create_fine_tuning_job_data.model,
            supervisedTuningSpec=supervised_tuning_spec,
            tunedModelDisplayName=create_fine_tuning_job_data.suffix,
        )

    def _transform_openai_hyperparameters_to_vertex_hyperparameters(
        self,
        create_fine_tuning_job_data: FineTuningJobCreate,
        original_hyperparameters: Optional[dict] = None,
        kwargs: Optional[dict] = None,
    ) -> FineTuneHyperparameters:
        """
        Map OpenAI hyperparameters onto Vertex hyperparameters.

        ``n_epochs`` becomes ``epoch_count`` and ``learning_rate_multiplier``
        is copied; ``adapter_size`` is taken from ``original_hyperparameters``.
        Falsy values are omitted.
        """
        # Avoid a shared mutable default; None means "no extra hyperparameters".
        original_hyperparameters = original_hyperparameters or {}

        _oai_hyperparameters = create_fine_tuning_job_data.hyperparameters
        _vertex_hyperparameters = FineTuneHyperparameters()
        if _oai_hyperparameters:
            if _oai_hyperparameters.n_epochs:
                _vertex_hyperparameters["epoch_count"] = int(
                    _oai_hyperparameters.n_epochs
                )
            if _oai_hyperparameters.learning_rate_multiplier:
                _vertex_hyperparameters["learning_rate_multiplier"] = float(
                    _oai_hyperparameters.learning_rate_multiplier
                )

        _adapter_size = original_hyperparameters.get("adapter_size", None)
        if _adapter_size:
            _vertex_hyperparameters["adapter_size"] = _adapter_size

        return _vertex_hyperparameters

    def convert_vertex_response_to_open_ai_response(
        self, response: ResponseTuningJob
    ) -> LiteLLMFineTuningJob:
        """
        Convert a Vertex tuning-job response into a ``LiteLLMFineTuningJob``.

        Job states map to OpenAI statuses; a missing or unrecognized state is
        reported as ``"queued"``.
        """
        status: Literal[
            "validating_files", "queued", "running", "succeeded", "failed", "cancelled"
        ] = "queued"
        # .get: a response without "state" falls back to the default status
        # instead of raising KeyError.
        state = response.get("state")
        if state == "JOB_STATE_SUCCEEDED":
            status = "succeeded"
        elif state == "JOB_STATE_FAILED":
            status = "failed"
        elif state == "JOB_STATE_CANCELLED":
            status = "cancelled"
        elif state == "JOB_STATE_RUNNING":
            status = "running"

        created_at = self.convert_response_created_at(response)

        _supervisedTuningSpec: ResponseSupervisedTuningSpec = (
            response.get("supervisedTuningSpec", None) or {}
        )
        training_uri: str = _supervisedTuningSpec.get("trainingDatasetUri", "") or ""
        return LiteLLMFineTuningJob(
            id=response.get("name", "") or "",
            created_at=created_at,
            fine_tuned_model=response.get("tunedModelDisplayName", ""),
            finished_at=None,
            hyperparameters=self._translate_vertex_response_hyperparameters(
                vertex_hyper_parameters=_supervisedTuningSpec.get("hyperParameters")
                or FineTuneHyperparameters()
            ),
            model=response.get("baseModel", "") or "",
            object="fine_tuning.job",
            organization_id="",
            result_files=[],
            seed=0,
            status=status,
            trained_tokens=None,
            training_file=training_uri,
            validation_file=None,
            estimated_finish=None,
            integrations=[],
        )

    def _translate_vertex_response_hyperparameters(
        self, vertex_hyper_parameters: FineTuneHyperparameters
    ) -> OpenAIFineTuningHyperparameters:
        """
        translate vertex response hyperparameters to openai hyperparameters

        ``epoch_count`` is renamed to ``n_epochs`` (default 0); all remaining
        keys are passed through unchanged.
        """
        _dict_remaining_hyperparameters: dict = dict(vertex_hyper_parameters)
        return OpenAIFineTuningHyperparameters(
            n_epochs=_dict_remaining_hyperparameters.pop("epoch_count", 0),
            **_dict_remaining_hyperparameters,
        )

    def _tuning_job_response_to_openai(
        self, response: httpx.Response
    ) -> LiteLLMFineTuningJob:
        """Validate a create-tuning-job HTTP response and convert it to OpenAI format."""
        if response.status_code != 200:
            raise Exception(
                f"Error creating fine tuning job. Status code: {response.status_code}. Response: {response.text}"
            )

        response_json = response.json()
        verbose_logger.debug(
            "got response from creating fine tuning job: %s", response_json
        )
        vertex_response = ResponseTuningJob(  # type: ignore
            **response_json,
        )
        verbose_logger.debug("vertex_response %s", vertex_response)
        return self.convert_vertex_response_to_open_ai_response(vertex_response)

    async def acreate_fine_tuning_job(
        self,
        fine_tuning_url: str,
        headers: dict,
        request_data: FineTuneJobCreate,
    ):
        """
        POST the tuning job asynchronously and return the OpenAI-style job.

        Any failure is logged with its traceback and re-raised.
        """
        try:
            verbose_logger.debug(
                "about to create fine tuning job: %s, request_data: %s",
                fine_tuning_url,
                json.dumps(request_data, indent=4),
            )
            if self.async_handler is None:
                raise ValueError(
                    "VertexAI Fine Tuning - async_handler is not initialized"
                )
            response = await self.async_handler.post(
                headers=headers,
                url=fine_tuning_url,
                json=request_data,  # type: ignore
            )
            return self._tuning_job_response_to_openai(response)
        except Exception as e:
            verbose_logger.error("asyncerror creating fine tuning job %s", e)
            verbose_logger.error(traceback.format_exc())
            # Bare raise keeps the original traceback intact.
            raise

    def create_fine_tuning_job(
        self,
        _is_async: bool,
        create_fine_tuning_job_data: FineTuningJobCreate,
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        kwargs: Optional[dict] = None,
        original_hyperparameters: Optional[dict] = None,
    ) -> Union[LiteLLMFineTuningJob, Coroutine[Any, Any, LiteLLMFineTuningJob]]:
        """
        Create a Vertex AI tuning job from an OpenAI-style request.

        Returns the converted job directly, or a coroutine producing it when
        ``_is_async`` is True.

        NOTE(review): ``timeout`` is accepted but not applied; both HTTP
        handlers use a fixed 600s timeout. Confirm whether that is intended.

        Raises:
            Exception: If Vertex responds with a non-200 status code.
        """
        verbose_logger.debug(
            "creating fine tuning job, args= %s", create_fine_tuning_job_data
        )
        _auth_header, vertex_project = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai_beta",
        )

        auth_header, _ = self._get_token_and_url(
            model="",
            auth_header=_auth_header,
            gemini_api_key=None,
            vertex_credentials=vertex_credentials,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            stream=False,
            custom_llm_provider="vertex_ai_beta",
            api_base=api_base,
        )

        headers = {
            "Authorization": f"Bearer {auth_header}",
            "Content-Type": "application/json",
        }

        fine_tune_job = self.convert_openai_request_to_vertex(
            create_fine_tuning_job_data=create_fine_tuning_job_data,
            kwargs=kwargs,
            original_hyperparameters=original_hyperparameters or {},
        )

        base_url = get_vertex_base_url(vertex_location)
        fine_tuning_url = f"{base_url}/v1/projects/{vertex_project}/locations/{vertex_location}/tuningJobs"
        if _is_async is True:
            return self.acreate_fine_tuning_job(  # type: ignore
                fine_tuning_url=fine_tuning_url,
                headers=headers,
                request_data=fine_tune_job,
            )
        sync_handler = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))

        verbose_logger.debug(
            "about to create fine tuning job: %s, request_data: %s",
            fine_tuning_url,
            fine_tune_job,
        )
        response = sync_handler.post(
            headers=headers,
            url=fine_tuning_url,
            json=fine_tune_job,  # type: ignore
        )
        return self._tuning_job_response_to_openai(response)

    async def pass_through_vertex_ai_POST_request(
        self,
        request_data: dict,
        vertex_project: str,
        vertex_location: str,
        vertex_credentials: str,
        request_route: str,
    ):
        """
        Forward a POST body to a supported Vertex AI route and return its JSON.

        For ``cachedContents`` routes a bare ``model`` in ``request_data`` is
        expanded in place to its full publisher path.

        Raises:
            ValueError: If ``request_route`` is not a supported route or the
                async handler is missing.
            Exception: If Vertex responds with a non-200 status code.
        """
        _auth_header, vertex_project = await self._ensure_access_token_async(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai_beta",
        )
        auth_header, _ = self._get_token_and_url(
            model="",
            auth_header=_auth_header,
            gemini_api_key=None,
            vertex_credentials=vertex_credentials,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            stream=False,
            custom_llm_provider="vertex_ai_beta",
            api_base="",
        )

        headers = {
            "Authorization": f"Bearer {auth_header}",
            "Content-Type": "application/json",
        }

        base_url = get_vertex_base_url(vertex_location)
        location_path = f"projects/{vertex_project}/locations/{vertex_location}"

        # Routes that are appended as-is under the v1 location path.
        direct_v1_markers = (
            "generateContent",
            "predict",
            "/batchPredictionJobs",
            "countTokens",
        )

        url = None
        if request_route == "/tuningJobs":
            url = f"{base_url}/v1/{location_path}/tuningJobs"
        elif "/tuningJobs/" in request_route and "cancel" in request_route:
            # NOTE(review): the route already contains "/tuningJobs/", so the
            # resulting path repeats that segment — confirm this is intended.
            url = f"{base_url}/v1/{location_path}/tuningJobs{request_route}"
        elif any(marker in request_route for marker in direct_v1_markers):
            url = f"{base_url}/v1/{location_path}{request_route}"
        elif "cachedContents" in request_route:
            _model = request_data.get("model")
            if _model is not None and "/publishers/google/models/" not in _model:
                request_data[
                    "model"
                ] = f"projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{_model}"

            url = f"{base_url}/v1beta1/{location_path}{request_route}"
        else:
            raise ValueError(f"Unsupported Vertex AI request route: {request_route}")
        if self.async_handler is None:
            raise ValueError("VertexAI Fine Tuning - async_handler is not initialized")

        response = await self.async_handler.post(
            headers=headers,
            url=url,
            json=request_data,  # type: ignore
        )

        if response.status_code != 200:
            # Message kept verbatim for callers matching on it, although this
            # path serves more than fine-tuning routes.
            raise Exception(
                f"Error creating fine tuning job. Status code: {response.status_code}. Response: {response.text}"
            )

        return response.json()
|
||||
@@ -0,0 +1,45 @@
|
||||
"""
|
||||
Cost calculator for Vertex AI Gemini.
|
||||
|
||||
Used because there are differences in how Google AI Studio and Vertex AI Gemini handle web search requests.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.types.utils import ModelInfo, Usage
|
||||
|
||||
|
||||
def cost_per_web_search_request(usage: "Usage", model_info: "ModelInfo") -> float:
    """
    Calculate the cost of a web search request for Vertex AI Gemini.

    Vertex AI charges $35/1000 prompts, independent of the number of web search requests.

    For a single call, this is $35e-3 USD.

    Args:
        usage: The usage object for the web search request.
        model_info: The model info for the web search request (not used).

    Returns:
        ``35e-3`` when the usage's prompt token details report web search
        requests (``web_search_requests`` is set), otherwise ``0.0``.
    """
    cost_per_llm_call_with_web_search = 35e-3

    prompt_tokens_details = getattr(usage, "prompt_tokens_details", None)
    if prompt_tokens_details is None:
        return 0.0

    from litellm.types.utils import PromptTokensDetailsWrapper

    if not isinstance(prompt_tokens_details, PromptTokensDetailsWrapper):
        return 0.0

    # Only bill calls that actually recorded web search usage; merely having
    # prompt token details must not be charged as a web search call.
    if getattr(prompt_tokens_details, "web_search_requests", None) is None:
        return 0.0

    return cost_per_llm_call_with_web_search
|
||||
@@ -0,0 +1,888 @@
|
||||
"""
|
||||
Transformation logic from OpenAI format to Gemini format.
|
||||
|
||||
Why separate file? Make it easy to see how transformation works
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union, cast
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
_get_image_mime_type_from_url,
|
||||
)
|
||||
from litellm.litellm_core_utils.prompt_templates.factory import (
|
||||
convert_generic_image_chunk_to_openai_image_obj,
|
||||
convert_to_anthropic_image_obj,
|
||||
convert_to_gemini_tool_call_invoke,
|
||||
convert_to_gemini_tool_call_result,
|
||||
response_schema_prompt,
|
||||
)
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.types.files import (
|
||||
get_file_mime_type_for_file_type,
|
||||
get_file_type_from_extension,
|
||||
is_gemini_1_5_accepted_file_type,
|
||||
)
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
ChatCompletionAssistantMessage,
|
||||
ChatCompletionAudioObject,
|
||||
ChatCompletionFileObject,
|
||||
ChatCompletionImageObject,
|
||||
ChatCompletionTextObject,
|
||||
ChatCompletionUserMessage,
|
||||
)
|
||||
from litellm.types.llms.vertex_ai import *
|
||||
from litellm.types.llms.vertex_ai import (
|
||||
GenerationConfig,
|
||||
PartType,
|
||||
RequestBody,
|
||||
SafetSettingsConfig,
|
||||
SystemInstructions,
|
||||
ToolConfig,
|
||||
Tools,
|
||||
)
|
||||
from litellm.types.utils import GenericImageParsingChunk, LlmProviders
|
||||
|
||||
from ..common_utils import (
|
||||
_check_text_in_content,
|
||||
get_supports_response_schema,
|
||||
get_supports_system_message,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
def _convert_detail_to_media_resolution_enum(
|
||||
detail: Optional[str],
|
||||
) -> Optional[Dict[str, str]]:
|
||||
if detail == "low":
|
||||
return {"level": "MEDIA_RESOLUTION_LOW"}
|
||||
elif detail == "medium":
|
||||
return {"level": "MEDIA_RESOLUTION_MEDIUM"}
|
||||
elif detail == "high":
|
||||
return {"level": "MEDIA_RESOLUTION_HIGH"}
|
||||
elif detail == "ultra_high":
|
||||
return {"level": "MEDIA_RESOLUTION_ULTRA_HIGH"}
|
||||
return None
|
||||
|
||||
|
||||
def _get_highest_media_resolution(
|
||||
current: Optional[str], new_detail: Optional[str]
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Compare two media resolution values and return the highest one.
|
||||
Resolution hierarchy: ultra_high > high > medium > low > None
|
||||
"""
|
||||
resolution_priority = {"ultra_high": 4, "high": 3, "medium": 2, "low": 1}
|
||||
current_priority = resolution_priority.get(current, 0) if current else 0
|
||||
new_priority = resolution_priority.get(new_detail, 0) if new_detail else 0
|
||||
|
||||
if new_priority > current_priority:
|
||||
return new_detail
|
||||
return current
|
||||
|
||||
|
||||
def _extract_max_media_resolution_from_messages(
|
||||
messages: List[AllMessageValues],
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Extract the highest media resolution (detail) from image content in messages.
|
||||
|
||||
This is used to set the global media_resolution in generation_config for
|
||||
Gemini 2.x models which don't support per-part media resolution.
|
||||
|
||||
Args:
|
||||
messages: List of messages in OpenAI format
|
||||
|
||||
Returns:
|
||||
The highest detail level found ("high", "low", or None)
|
||||
"""
|
||||
max_resolution: Optional[str] = None
|
||||
for msg in messages:
|
||||
content = msg.get("content")
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
detail: Optional[str] = None
|
||||
if item.get("type") == "image_url":
|
||||
image_url = item.get("image_url")
|
||||
if isinstance(image_url, dict):
|
||||
detail = image_url.get("detail")
|
||||
elif item.get("type") == "file":
|
||||
file_obj = item.get("file")
|
||||
if isinstance(file_obj, dict):
|
||||
detail = file_obj.get("detail")
|
||||
if detail:
|
||||
max_resolution = _get_highest_media_resolution(
|
||||
max_resolution, detail
|
||||
)
|
||||
return max_resolution
|
||||
|
||||
|
||||
def _apply_gemini_3_metadata(
|
||||
part: PartType,
|
||||
model: Optional[str],
|
||||
media_resolution_enum: Optional[Dict[str, str]],
|
||||
video_metadata: Optional[Dict[str, Any]],
|
||||
) -> PartType:
|
||||
"""
|
||||
Apply the unique media_resolution and video_metadata parameters of Gemini 3+
|
||||
"""
|
||||
if model is None:
|
||||
return part
|
||||
|
||||
from .vertex_and_google_ai_studio_gemini import VertexGeminiConfig
|
||||
|
||||
if not VertexGeminiConfig._is_gemini_3_or_newer(model):
|
||||
return part
|
||||
|
||||
part_dict = dict(part)
|
||||
|
||||
if media_resolution_enum is not None:
|
||||
part_dict["media_resolution"] = media_resolution_enum
|
||||
|
||||
if video_metadata is not None:
|
||||
gemini_video_metadata = {}
|
||||
if "fps" in video_metadata:
|
||||
gemini_video_metadata["fps"] = video_metadata["fps"]
|
||||
if "start_offset" in video_metadata:
|
||||
gemini_video_metadata["startOffset"] = video_metadata["start_offset"]
|
||||
if "end_offset" in video_metadata:
|
||||
gemini_video_metadata["endOffset"] = video_metadata["end_offset"]
|
||||
if gemini_video_metadata:
|
||||
part_dict["video_metadata"] = gemini_video_metadata
|
||||
|
||||
return cast(PartType, part_dict)
|
||||
|
||||
|
||||
def _process_gemini_media(
|
||||
image_url: str,
|
||||
format: Optional[str] = None,
|
||||
media_resolution_enum: Optional[Dict[str, str]] = None,
|
||||
model: Optional[str] = None,
|
||||
video_metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> PartType:
|
||||
"""
|
||||
Given a media URL (image, audio, or video), return the appropriate PartType for Gemini
|
||||
By the way, actually video_metadata can only be used with videos; it cannot be used with images, audio, or files. However, I haven't made any special handling because vertex returns a parameter error.
|
||||
|
||||
Args:
|
||||
image_url: The URL or base64 string of the media (image, audio, or video)
|
||||
format: The MIME type of the media
|
||||
media_resolution_enum: Media resolution level (for Gemini 3+)
|
||||
model: The model name (to check version compatibility)
|
||||
video_metadata: Video-specific metadata (fps, start_offset, end_offset)
|
||||
"""
|
||||
|
||||
try:
|
||||
# GCS URIs
|
||||
if "gs://" in image_url:
|
||||
# Figure out file type
|
||||
extension_with_dot = os.path.splitext(image_url)[-1] # Ex: ".png"
|
||||
extension = extension_with_dot[1:] # Ex: "png"
|
||||
|
||||
if not format:
|
||||
file_type = get_file_type_from_extension(extension)
|
||||
|
||||
# Validate the file type is supported by Gemini
|
||||
if not is_gemini_1_5_accepted_file_type(file_type):
|
||||
raise Exception(f"File type not supported by gemini - {file_type}")
|
||||
|
||||
mime_type = get_file_mime_type_for_file_type(file_type)
|
||||
else:
|
||||
mime_type = format
|
||||
file_data = FileDataType(mime_type=mime_type, file_uri=image_url)
|
||||
part: PartType = {"file_data": file_data}
|
||||
return _apply_gemini_3_metadata(
|
||||
part, model, media_resolution_enum, video_metadata
|
||||
)
|
||||
elif (
|
||||
"https://" in image_url
|
||||
and (image_type := format or _get_image_mime_type_from_url(image_url))
|
||||
is not None
|
||||
):
|
||||
file_data = FileDataType(mime_type=image_type, file_uri=image_url)
|
||||
part = {"file_data": file_data}
|
||||
return _apply_gemini_3_metadata(
|
||||
part, model, media_resolution_enum, video_metadata
|
||||
)
|
||||
elif "http://" in image_url or "https://" in image_url or "base64" in image_url:
|
||||
image = convert_to_anthropic_image_obj(image_url, format=format)
|
||||
_blob: BlobType = {"data": image["data"], "mime_type": image["media_type"]}
|
||||
part = {"inline_data": cast(BlobType, _blob)}
|
||||
return _apply_gemini_3_metadata(
|
||||
part, model, media_resolution_enum, video_metadata
|
||||
)
|
||||
raise Exception("Invalid image received - {}".format(image_url))
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
def _snake_to_camel(snake_str: str) -> str:
|
||||
"""Convert snake_case to camelCase"""
|
||||
components = snake_str.split("_")
|
||||
return components[0] + "".join(x.capitalize() for x in components[1:])
|
||||
|
||||
|
||||
def _camel_to_snake(camel_str: str) -> str:
|
||||
"""Convert camelCase to snake_case"""
|
||||
import re
|
||||
|
||||
return re.sub(r"(?<!^)(?=[A-Z])", "_", camel_str).lower()
|
||||
|
||||
|
||||
def _get_equivalent_key(key: str, available_keys: set) -> Optional[str]:
|
||||
"""
|
||||
Get the equivalent key from available keys, checking both camelCase and snake_case variants
|
||||
"""
|
||||
if key in available_keys:
|
||||
return key
|
||||
|
||||
# Try camelCase version
|
||||
camel_key = _snake_to_camel(key)
|
||||
if camel_key in available_keys:
|
||||
return camel_key
|
||||
|
||||
# Try snake_case version
|
||||
snake_key = _camel_to_snake(key)
|
||||
if snake_key in available_keys:
|
||||
return snake_key
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def check_if_part_exists_in_parts(
    parts: List[PartType], part: PartType, excluded_keys: Optional[List[str]] = None
) -> bool:
    """
    Check if a part exists in a list of parts

    Handles both camelCase and snake_case key variations (e.g., function_call vs functionCall)

    Args:
        parts: Existing parts to search.
        part: The part to look for.
        excluded_keys: Keys of ``part`` to ignore when comparing. Defaults to
            no exclusions.

    Returns:
        True if some element of ``parts`` has, for every compared key of
        ``part``, an equivalent key holding an equal value. A ``part`` with no
        keys left to compare matches any existing part. False otherwise,
        including when ``parts`` is empty.
    """
    keys_to_compare = set(part.keys()) - set(excluded_keys or ())

    def _matches(candidate: PartType) -> bool:
        candidate_keys = set(candidate.keys())
        for key in keys_to_compare:
            equivalent_key = _get_equivalent_key(key, candidate_keys)
            if equivalent_key is None or candidate.get(equivalent_key) != part.get(
                key
            ):
                return False
        return True

    return any(_matches(candidate) for candidate in parts)
|
||||
|
||||
|
||||
def _gemini_convert_messages_with_history(  # noqa: PLR0915
    messages: List[AllMessageValues],
    model: Optional[str] = None,
) -> List[ContentType]:
    """
    Converts given messages from OpenAI format to Gemini format

    - Parts must be iterable
    - Roles must alternate b/w 'user' and 'model' (same as anthropic -> merge consecutive roles)
    - Please ensure that function response turn comes immediately after a function call turn

    Each pass of the outer loop consumes, in order: a run of user/system
    messages (merged into one "user" content), a run of assistant messages
    (merged into one "model" content), and at most one tool/function message.
    Tool results are buffered and flushed as a single "user" content once a
    non-tool message follows, or at the end of the conversation.

    Args:
        messages: Messages in OpenAI format.
        model: Model name, forwarded to the media and tool-call converters.

    Returns:
        The Gemini ``contents`` list. If nothing could be produced, a single
        user content holding a blank text part is returned.

    Raises:
        Exception: If a file element has neither ``file_id`` nor ``file_data``,
            if a file part cannot be built, or if a message is not consumed by
            any of the stages above (e.g. an unsupported role).
    """
    user_message_types = {"user", "system"}
    tool_call_message_roles = ["tool", "function"]
    contents: List[ContentType] = []

    last_message_with_tool_calls = None

    msg_i = 0
    tool_call_responses = []
    while msg_i < len(messages):
        user_content: List[PartType] = []
        init_msg_i = msg_i
        ## MERGE CONSECUTIVE USER CONTENT ##
        while msg_i < len(messages) and messages[msg_i]["role"] in user_message_types:
            _message_content = messages[msg_i].get("content")
            if _message_content is not None and isinstance(_message_content, list):
                _parts: List[PartType] = []
                for element_idx, element in enumerate(_message_content):
                    if (
                        element["type"] == "text"
                        and "text" in element
                        and len(element["text"]) > 0
                    ):
                        element = cast(ChatCompletionTextObject, element)
                        _part = PartType(text=element["text"])
                        _parts.append(_part)
                    elif element["type"] == "image_url":
                        element = cast(ChatCompletionImageObject, element)
                        img_element = element
                        media_format: Optional[str] = None
                        media_resolution_enum: Optional[Dict[str, str]] = None
                        if isinstance(img_element["image_url"], dict):
                            image_url = img_element["image_url"]["url"]
                            media_format = img_element["image_url"].get("format")
                            detail = img_element["image_url"].get("detail")
                            media_resolution_enum = (
                                _convert_detail_to_media_resolution_enum(detail)
                            )
                        else:
                            # Plain string form: just the URL, no extra hints.
                            image_url = img_element["image_url"]
                        _part = _process_gemini_media(
                            image_url=image_url,
                            format=media_format,
                            media_resolution_enum=media_resolution_enum,
                            model=model,
                        )
                        _parts.append(_part)
                    elif element["type"] == "input_audio":
                        audio_element = cast(ChatCompletionAudioObject, element)
                        audio_data = audio_element["input_audio"].get("data")
                        audio_format = audio_element["input_audio"].get("format")
                        # Audio without both data and format is skipped.
                        if audio_data is not None and audio_format is not None:
                            audio_format_modified = (
                                "audio/" + audio_format
                                if audio_format.startswith("audio/") is False
                                else audio_format
                            )  # Gemini expects audio/wav, audio/mp3, etc.
                            openai_image_str = (
                                convert_generic_image_chunk_to_openai_image_obj(
                                    image_chunk=GenericImageParsingChunk(
                                        type="base64",
                                        media_type=audio_format_modified,
                                        data=audio_data,
                                    )
                                )
                            )
                            _part = _process_gemini_media(
                                image_url=openai_image_str,
                                format=audio_format_modified,
                                model=model,
                            )
                            _parts.append(_part)
                    elif element["type"] == "file":
                        file_element = cast(ChatCompletionFileObject, element)
                        file_id = file_element["file"].get("file_id")
                        media_format = file_element["file"].get("format")
                        file_data = file_element["file"].get("file_data")
                        detail = file_element["file"].get("detail")
                        video_metadata = file_element["file"].get("video_metadata")
                        passed_file = file_id or file_data
                        if passed_file is None:
                            raise Exception(
                                "Unknown file type. Please pass in a file_id or file_data"
                            )

                        # Convert detail to media_resolution_enum
                        media_resolution_enum = (
                            _convert_detail_to_media_resolution_enum(detail)
                        )

                        try:
                            _part = _process_gemini_media(
                                image_url=passed_file,
                                format=media_format,
                                model=model,
                                media_resolution_enum=media_resolution_enum,
                                video_metadata=video_metadata,
                            )
                            _parts.append(_part)
                        except Exception as e:
                            # Keep the underlying failure attached as the cause.
                            raise Exception(
                                "Unable to determine mime type for file_id: {}, set this explicitly using message[{}].content[{}].file.format".format(
                                    file_id, msg_i, element_idx
                                )
                            ) from e
                user_content.extend(_parts)
            elif _message_content is not None and isinstance(_message_content, str):
                _part = PartType(text=_message_content)
                user_content.append(_part)

            msg_i += 1

        if user_content:
            # Check that user_content has a 'text' parameter.
            # - Known Vertex Error: Unable to submit request because it must have a text parameter.
            # - Relevant Issue: https://github.com/BerriAI/litellm/issues/5515
            has_text_in_content = _check_text_in_content(user_content)
            if has_text_in_content is False:
                verbose_logger.warning(
                    "No text in user content. Adding a blank text to user content, to ensure Gemini doesn't fail the request. Relevant Issue - https://github.com/BerriAI/litellm/issues/5515"
                )
                user_content.append(
                    PartType(text=" ")
                )  # add a blank text, to ensure Gemini doesn't fail the request.
            contents.append(ContentType(role="user", parts=user_content))
        assistant_content = []
        ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
        while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
            if isinstance(messages[msg_i], BaseModel):
                msg_dict: Union[ChatCompletionAssistantMessage, dict] = messages[msg_i].model_dump()  # type: ignore
            else:
                msg_dict = messages[msg_i]  # type: ignore
            assistant_msg = ChatCompletionAssistantMessage(**msg_dict)  # type: ignore
            _message_content = assistant_msg.get("content", None)
            reasoning_content = assistant_msg.get("reasoning_content", None)
            thinking_blocks = assistant_msg.get("thinking_blocks")
            if reasoning_content is not None:
                assistant_content.append(
                    PartType(thought=True, text=reasoning_content)
                )
            if thinking_blocks is not None:
                for block in thinking_blocks:
                    if block["type"] == "thinking":
                        block_thinking_str = block.get("thinking")
                        block_signature = block.get("signature")
                        if (
                            block_thinking_str is not None
                            and block_signature is not None
                        ):
                            try:
                                # The thinking string may be a JSON-encoded part.
                                assistant_content.append(
                                    PartType(
                                        thoughtSignature=block_signature,
                                        **json.loads(block_thinking_str),
                                    )
                                )
                            except Exception:
                                # Otherwise treat it as plain text.
                                assistant_content.append(
                                    PartType(
                                        thoughtSignature=block_signature,
                                        text=block_thinking_str,
                                    )
                                )
            if _message_content is not None and isinstance(_message_content, list):
                _parts = []
                for element in _message_content:
                    if isinstance(element, dict):
                        if element["type"] == "text":
                            _part = PartType(text=element["text"])
                            _parts.append(_part)

                assistant_content.extend(_parts)
            elif _message_content is not None and isinstance(_message_content, str):
                assistant_text = _message_content
                # Check if message has thought_signatures in provider_specific_fields
                provider_specific_fields = assistant_msg.get(
                    "provider_specific_fields"
                )
                thought_signatures = None
                if provider_specific_fields and isinstance(
                    provider_specific_fields, dict
                ):
                    thought_signatures = provider_specific_fields.get(
                        "thought_signatures"
                    )

                # If we have thought signatures, add them to the part
                if (
                    thought_signatures
                    and isinstance(thought_signatures, list)
                    and len(thought_signatures) > 0
                ):
                    # Use the first signature for the text part (Gemini expects one signature per part)
                    assistant_content.append(PartType(text=assistant_text, thoughtSignature=thought_signatures[0]))  # type: ignore
                else:
                    assistant_content.append(PartType(text=assistant_text))  # type: ignore

            ## HANDLE ASSISTANT IMAGES FIELD
            # Process images field if present (for generated images from assistant)
            assistant_images = assistant_msg.get("images")
            if assistant_images is not None and isinstance(assistant_images, list):
                for image_item in assistant_images:
                    if isinstance(image_item, dict):
                        image_url_obj = image_item.get("image_url")
                        if isinstance(image_url_obj, dict):
                            assistant_image_url = image_url_obj.get("url")
                            media_format = image_url_obj.get("format")
                            detail = image_url_obj.get("detail")
                            media_resolution_enum = (
                                _convert_detail_to_media_resolution_enum(detail)
                            )
                            if assistant_image_url:
                                _part = _process_gemini_media(
                                    image_url=assistant_image_url,
                                    format=media_format,
                                    media_resolution_enum=media_resolution_enum,
                                    model=model,
                                )
                                assistant_content.append(_part)

            ## HANDLE ASSISTANT FUNCTION CALL
            if (
                assistant_msg.get("tool_calls", []) is not None
                or assistant_msg.get("function_call") is not None
            ):  # support assistant tool invoke conversion
                gemini_tool_call_parts = convert_to_gemini_tool_call_invoke(
                    assistant_msg, model=model
                )
                ## check if gemini_tool_call already exists in assistant_content
                for gemini_tool_call_part in gemini_tool_call_parts:
                    if not check_if_part_exists_in_parts(
                        assistant_content,
                        gemini_tool_call_part,
                        excluded_keys=["thoughtSignature"],
                    ):
                        assistant_content.append(gemini_tool_call_part)
                last_message_with_tool_calls = assistant_msg

            msg_i += 1

        if assistant_content:
            contents.append(ContentType(role="model", parts=assistant_content))

        ## APPEND TOOL CALL MESSAGES ##
        if (
            msg_i < len(messages)
            and messages[msg_i]["role"] in tool_call_message_roles
        ):
            _part = convert_to_gemini_tool_call_result(
                messages[msg_i], last_message_with_tool_calls  # type: ignore
            )
            msg_i += 1
            # Handle both single part and list of parts (for Computer Use with images)
            if isinstance(_part, list):
                tool_call_responses.extend(_part)
            else:
                tool_call_responses.append(_part)
        # Flush buffered tool results once the run of tool messages has ended.
        if msg_i < len(messages) and (
            messages[msg_i]["role"] not in tool_call_message_roles
        ):
            if len(tool_call_responses) > 0:
                contents.append(ContentType(role="user", parts=tool_call_responses))
                tool_call_responses = []

        if msg_i == init_msg_i:  # prevent infinite loops
            raise Exception(
                "Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format(
                    messages[msg_i]
                )
            )
    # Tool results at the very end of the conversation are flushed here.
    if len(tool_call_responses) > 0:
        contents.append(ContentType(role="user", parts=tool_call_responses))

    if len(contents) == 0:
        verbose_logger.warning(
            """
            No contents in messages. Contents are required. See
            https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.publishers.models/generateContent#request-body.
            If the original request did not comply to OpenAI API requirements it should have failed by now,
            but LiteLLM does not check for missing messages.
            Setting an empty content to prevent an 400 error.
            Relevant Issue - https://github.com/BerriAI/litellm/issues/9733
            """
        )
        contents.append(ContentType(role="user", parts=[PartType(text=" ")]))
    return contents
|
||||
|
||||
|
||||
# Keys that LiteLLM consumes internally and must never be forwarded to the
|
||||
_LITELLM_INTERNAL_EXTRA_BODY_KEYS: frozenset = frozenset({"cache", "tags"})
|
||||
|
||||
|
||||
def _pop_and_merge_extra_body(data: RequestBody, optional_params: dict) -> None:
|
||||
"""Pop extra_body from optional_params and shallow-merge into data, deep-merging dict values."""
|
||||
extra_body: Optional[dict] = optional_params.pop("extra_body", None)
|
||||
if extra_body is not None:
|
||||
data_dict: dict = data # type: ignore[assignment]
|
||||
for k, v in extra_body.items():
|
||||
if k in _LITELLM_INTERNAL_EXTRA_BODY_KEYS:
|
||||
continue
|
||||
if (
|
||||
k in data_dict
|
||||
and isinstance(data_dict[k], dict)
|
||||
and isinstance(v, dict)
|
||||
):
|
||||
data_dict[k].update(v)
|
||||
else:
|
||||
data_dict[k] = v
|
||||
|
||||
|
||||
def _transform_request_body(  # noqa: PLR0915
    messages: List[AllMessageValues],
    model: str,
    optional_params: dict,
    custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
    litellm_params: dict,
    cached_content: Optional[str],
) -> RequestBody:
    """
    Common transformation logic across sync + async Gemini /generateContent calls.

    Args:
        messages: Messages in OpenAI format. System messages are split out
            first (see ``_transform_system_message``).
        model: Model name.
        optional_params: Mapped request parameters. ``response_schema`` is
            popped from the caller's dict when the model does not support it;
            every later removal happens on an internal copy.
        custom_llm_provider: Selects the Google AI Studio or Vertex message
            transformer, and whether ``labels`` may be sent.
        litellm_params: Updated in place with any ``litellm_param_*`` entries
            found in ``optional_params``. Also read for request metadata.
        cached_content: Value for ``cachedContent``, if any.

    Returns:
        The request body for the generateContent call.
    """
    # Separate system prompt from rest of message
    supports_system_message = get_supports_system_message(
        model=model, custom_llm_provider=custom_llm_provider
    )
    system_instructions, messages = _transform_system_message(
        supports_system_message=supports_system_message, messages=messages
    )
    # Checks for 'response_schema' support - if passed in
    if "response_schema" in optional_params:
        supports_response_schema = get_supports_response_schema(
            model=model, custom_llm_provider=custom_llm_provider
        )
        if supports_response_schema is False:
            # Fall back to describing the schema in an extra user message.
            user_response_schema_message = response_schema_prompt(
                model=model, response_schema=optional_params.get("response_schema")  # type: ignore
            )
            messages.append({"role": "user", "content": user_response_schema_message})
            optional_params.pop("response_schema")

    # Check for any 'litellm_param_*' set during optional param mapping
    internal_params = {
        k: v for k, v in optional_params.items() if k.startswith("litellm_param_")
    }
    litellm_params.update(internal_params)
    # Work on a copy from here on so the caller's dict is not popped from.
    optional_params = {
        k: v for k, v in optional_params.items() if k not in internal_params
    }

    if custom_llm_provider == "gemini":
        content = litellm.GoogleAIStudioGeminiConfig()._transform_messages(
            messages=messages, model=model
        )
    else:
        content = litellm.VertexGeminiConfig()._transform_messages(
            messages=messages, model=model
        )
    tools: Optional[Tools] = optional_params.pop("tools", None)
    tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None)
    safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop(
        "safety_settings", None
    )  # type: ignore
    # Drop output_config as it's not supported by Vertex AI
    optional_params.pop("output_config", None)
    config_fields = set(GenerationConfig.__annotations__.keys())

    # If the LiteLLM client sends Gemini-supported parameter "labels", add it
    # as "labels" field to the request sent to the Gemini backend.
    labels: Optional[dict[str, str]] = optional_params.pop("labels", None)
    # If the LiteLLM client sends OpenAI-supported parameter "metadata", add it
    # as "labels" field to the request sent to the Gemini backend.
    if labels is None and "metadata" in litellm_params:
        metadata = litellm_params["metadata"]
        if metadata is not None and "requester_metadata" in metadata:
            rm = metadata["requester_metadata"]
            # requester_metadata may be present but empty/None; only a dict
            # can be turned into labels. Non-string values are dropped.
            if isinstance(rm, dict):
                labels = {k: v for k, v in rm.items() if isinstance(v, str)}

    # Keep only params that GenerationConfig knows, in either key casing.
    filtered_params = {
        k: v
        for k, v in optional_params.items()
        if _get_equivalent_key(k, config_fields)
    }

    generation_config: GenerationConfig = GenerationConfig(**filtered_params)

    # For Gemini 2.x models, add media_resolution to generation_config (global)
    # Gemini 3+ supports per-part media_resolution, but 2.x only supports global
    # Gemini 1.x does not support mediaResolution at all
    if "gemini-2" in model:
        max_media_resolution = _extract_max_media_resolution_from_messages(messages)
        if max_media_resolution:
            media_resolution_value = _convert_detail_to_media_resolution_enum(
                max_media_resolution
            )
            if media_resolution_value:
                generation_config["mediaResolution"] = media_resolution_value[
                    "level"
                ]

    data = RequestBody(contents=content)
    if system_instructions is not None:
        data["system_instruction"] = system_instructions
    if tools is not None:
        data["tools"] = tools
    if tool_choice is not None:
        data["toolConfig"] = tool_choice
    if safety_settings is not None:
        data["safetySettings"] = safety_settings
    if len(generation_config) > 0:
        data["generationConfig"] = generation_config
    if cached_content is not None:
        data["cachedContent"] = cached_content
    # Only add labels for Vertex AI endpoints (not Google GenAI/AI Studio) and only if non-empty
    if labels and custom_llm_provider != LlmProviders.GEMINI:
        data["labels"] = labels
    _pop_and_merge_extra_body(data, optional_params)

    return data
|
||||
|
||||
|
||||
def sync_transform_request_body(
    gemini_api_key: Optional[str],
    messages: List[AllMessageValues],
    api_base: Optional[str],
    model: str,
    client: Optional[HTTPHandler],
    timeout: Optional[Union[float, httpx.Timeout]],
    extra_headers: Optional[dict],
    optional_params: dict,
    logging_obj: LiteLLMLoggingObj,
    custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
    litellm_params: dict,
    vertex_project: Optional[str],
    vertex_location: Optional[str],
    vertex_auth_header: Optional[str],
) -> RequestBody:
    """
    Build the Gemini request body (sync variant).

    Runs ``ContextCachingEndpoints.check_and_create_cache`` first, then feeds
    the messages, params and cached-content value it returns into
    ``_transform_request_body``.

    ``cached_content`` is popped from ``optional_params`` and passed to the
    cache check. The literal ``"dummy"`` is used as the API key when
    ``gemini_api_key`` is falsy.
    """
    from ..context_caching.vertex_ai_context_caching import ContextCachingEndpoints

    cache_endpoints = ContextCachingEndpoints()
    requested_cached_content = optional_params.pop("cached_content", None)

    cache_result = cache_endpoints.check_and_create_cache(
        messages=messages,
        optional_params=optional_params,
        api_key=gemini_api_key or "dummy",
        api_base=api_base,
        model=model,
        client=client,
        timeout=timeout,
        extra_headers=extra_headers,
        cached_content=requested_cached_content,
        logging_obj=logging_obj,
        custom_llm_provider=custom_llm_provider,
        vertex_project=vertex_project,
        vertex_location=vertex_location,
        vertex_auth_header=vertex_auth_header,
    )
    messages, optional_params, cached_content = cache_result

    return _transform_request_body(
        messages=messages,
        model=model,
        custom_llm_provider=custom_llm_provider,
        litellm_params=litellm_params,
        cached_content=cached_content,
        optional_params=optional_params,
    )
|
||||
|
||||
|
||||
async def async_transform_request_body(
    gemini_api_key: Optional[str],
    messages: List[AllMessageValues],
    api_base: Optional[str],
    model: str,
    client: Optional[AsyncHTTPHandler],
    timeout: Optional[Union[float, httpx.Timeout]],
    extra_headers: Optional[dict],
    optional_params: dict,
    logging_obj: LiteLLMLoggingObj,
    custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
    litellm_params: dict,
    vertex_project: Optional[str],
    vertex_location: Optional[str],
    vertex_auth_header: Optional[str],
) -> RequestBody:
    """
    Build the Gemini request body (async variant).

    Awaits ``ContextCachingEndpoints.async_check_and_create_cache`` first,
    then feeds the messages, params and cached-content value it returns into
    ``_transform_request_body``.

    ``cached_content`` is popped from ``optional_params`` and passed to the
    cache check. The literal ``"dummy"`` is used as the API key when
    ``gemini_api_key`` is falsy.
    """
    from ..context_caching.vertex_ai_context_caching import ContextCachingEndpoints

    context_caching_endpoints = ContextCachingEndpoints()

    (
        messages,
        optional_params,
        cached_content,
    ) = await context_caching_endpoints.async_check_and_create_cache(
        messages=messages,
        optional_params=optional_params,
        api_key=gemini_api_key or "dummy",
        api_base=api_base,
        model=model,
        client=client,
        timeout=timeout,
        extra_headers=extra_headers,
        cached_content=optional_params.pop("cached_content", None),
        logging_obj=logging_obj,
        custom_llm_provider=custom_llm_provider,
        vertex_project=vertex_project,
        vertex_location=vertex_location,
        vertex_auth_header=vertex_auth_header,
    )

    return _transform_request_body(
        messages=messages,
        model=model,
        custom_llm_provider=custom_llm_provider,
        litellm_params=litellm_params,
        cached_content=cached_content,
        optional_params=optional_params,
    )
|
||||
|
||||
|
||||
def _default_user_message_when_system_message_passed() -> ChatCompletionUserMessage:
    """
    Build the placeholder user message used when only system messages were passed.

    Gemini fails a request that carries a system message but no other
    message, so a minimal user turn (content ".") is added to keep the
    request valid.
    """
    placeholder = ChatCompletionUserMessage(content=".", role="user")
    return placeholder
|
||||
|
||||
|
||||
def _transform_system_message(
|
||||
supports_system_message: bool, messages: List[AllMessageValues]
|
||||
) -> Tuple[Optional[SystemInstructions], List[AllMessageValues]]:
|
||||
"""
|
||||
Extracts the system message from the openai message list.
|
||||
|
||||
Converts the system message to Gemini format
|
||||
|
||||
Returns
|
||||
- system_content_blocks: Optional[SystemInstructions] - the system message list in Gemini format.
|
||||
- messages: List[AllMessageValues] - filtered list of messages in OpenAI format (transformed separately)
|
||||
"""
|
||||
# Separate system prompt from rest of message
|
||||
system_prompt_indices = []
|
||||
system_content_blocks: List[PartType] = []
|
||||
if supports_system_message is True:
|
||||
for idx, message in enumerate(messages):
|
||||
if message["role"] == "system":
|
||||
_system_content_block: Optional[PartType] = None
|
||||
if isinstance(message["content"], str):
|
||||
_system_content_block = PartType(text=message["content"])
|
||||
elif isinstance(message["content"], list):
|
||||
system_text = ""
|
||||
for content in message["content"]:
|
||||
system_text += content.get("text") or ""
|
||||
_system_content_block = PartType(text=system_text)
|
||||
if _system_content_block is not None:
|
||||
system_content_blocks.append(_system_content_block)
|
||||
system_prompt_indices.append(idx)
|
||||
if len(system_prompt_indices) > 0:
|
||||
for idx in reversed(system_prompt_indices):
|
||||
messages.pop(idx)
|
||||
|
||||
if len(system_content_blocks) > 0:
|
||||
#########################################################
|
||||
# If no messages are passed in, add a blank user message
|
||||
# Relevant Issue - https://github.com/BerriAI/litellm/issues/13769
|
||||
#########################################################
|
||||
if len(messages) == 0:
|
||||
messages.append(_default_user_message_when_system_message_passed())
|
||||
#########################################################
|
||||
return SystemInstructions(parts=system_content_blocks), messages
|
||||
|
||||
return None, messages
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,347 @@
|
||||
"""
|
||||
Google AI Studio /batchEmbedContents Embeddings Endpoint
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, Literal, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.types.llms.openai import EmbeddingInput
|
||||
from litellm.types.llms.vertex_ai import (
|
||||
VertexAIBatchEmbeddingsRequestBody,
|
||||
VertexAIBatchEmbeddingsResponseObject,
|
||||
)
|
||||
from litellm.types.utils import EmbeddingResponse
|
||||
|
||||
from ..gemini.vertex_and_google_ai_studio_gemini import VertexLLM
|
||||
from .batch_embed_content_transformation import (
|
||||
_is_file_reference,
|
||||
_is_multimodal_input,
|
||||
process_embed_content_response,
|
||||
process_response,
|
||||
transform_openai_input_gemini_content,
|
||||
transform_openai_input_gemini_embed_content,
|
||||
)
|
||||
|
||||
|
||||
class GoogleBatchEmbeddings(VertexLLM):
    """
    Embedding handler for the `gemini` and `vertex_ai` providers.

    Text-only `gemini` input is transformed with
    `transform_openai_input_gemini_content` (a list of per-input requests).
    Multimodal input, and every `vertex_ai` request, is transformed with
    `transform_openai_input_gemini_embed_content` (one request whose content
    holds one part per input element).
    """

    @staticmethod
    def _parse_file_lookup_response(element: str, response: Any) -> Dict[str, str]:
        """
        Validate a file-metadata lookup response and reduce it to the fields
        needed to build a file part.

        Args:
            element: The file reference that was looked up (e.g. "files/abc").
            response: HTTP response returned by the lookup request.

        Returns:
            Dict with `mime_type` (empty string when absent from the response)
            and `uri` (falls back to the reference itself when absent).

        Raises:
            Exception: If the lookup did not return HTTP 200.
        """
        if response.status_code != 200:
            raise Exception(
                f"Error fetching file {element}: {response.status_code} {response.text}"
            )

        file_data = response.json()
        return {
            "mime_type": file_data.get("mimeType", ""),
            "uri": file_data.get("uri", element),
        }

    def _resolve_file_references(
        self,
        input: EmbeddingInput,
        api_key: str,
        sync_handler: HTTPHandler,
    ) -> Dict[str, Dict[str, str]]:
        """
        Resolve Gemini file references (files/...) to get mime_type and uri.

        One GET request is issued per file-reference element; all other
        elements are ignored.

        Args:
            input: EmbeddingInput that may contain file references
            api_key: Gemini API key, sent in the `x-goog-api-key` header
            sync_handler: HTTP client

        Returns:
            Dict mapping file name to {mime_type, uri}

        Raises:
            Exception: If any lookup does not return HTTP 200.
        """
        input_list = [input] if isinstance(input, str) else input
        resolved_files: Dict[str, Dict[str, str]] = {}

        for element in input_list:
            if not (isinstance(element, str) and _is_file_reference(element)):
                continue
            response = sync_handler.get(
                url=f"https://generativelanguage.googleapis.com/v1beta/{element}",
                headers={"x-goog-api-key": api_key},
            )
            resolved_files[element] = self._parse_file_lookup_response(
                element, response
            )

        return resolved_files

    async def _async_resolve_file_references(
        self,
        input: EmbeddingInput,
        api_key: str,
        async_handler: AsyncHTTPHandler,
    ) -> Dict[str, Dict[str, str]]:
        """
        Async version of _resolve_file_references.

        Args:
            input: EmbeddingInput that may contain file references
            api_key: Gemini API key, sent in the `x-goog-api-key` header
            async_handler: Async HTTP client

        Returns:
            Dict mapping file name to {mime_type, uri}

        Raises:
            Exception: If any lookup does not return HTTP 200.
        """
        input_list = [input] if isinstance(input, str) else input
        resolved_files: Dict[str, Dict[str, str]] = {}

        for element in input_list:
            if not (isinstance(element, str) and _is_file_reference(element)):
                continue
            response = await async_handler.get(
                url=f"https://generativelanguage.googleapis.com/v1beta/{element}",
                headers={"x-goog-api-key": api_key},
            )
            resolved_files[element] = self._parse_file_lookup_response(
                element, response
            )

        return resolved_files

    def batch_embeddings(
        self,
        model: str,
        input: EmbeddingInput,
        print_verbose,
        model_response: EmbeddingResponse,
        custom_llm_provider: Literal["gemini", "vertex_ai"],
        optional_params: dict,
        logging_obj: Any,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        encoding=None,
        vertex_project=None,
        vertex_location=None,
        vertex_credentials=None,
        aembedding: Optional[bool] = False,
        timeout=300,
        client=None,
        extra_headers: Optional[dict] = None,
    ) -> EmbeddingResponse:
        """
        Embed `input` and populate `model_response` in OpenAI embedding format.

        Args:
            model: Model name.
            input: A string or list of strings (text, data URIs, gs:// URLs or
                files/ references).
            print_verbose: Not used in this method.
            model_response: Response object that is populated and returned.
            custom_llm_provider: "gemini" or "vertex_ai".
            optional_params: Extra request parameters; None is treated as {}.
            logging_obj: Logging object; `pre_call` is invoked before the request.
            api_key: Gemini API key. File references are only resolved when set.
            api_base: Optional base URL override passed to URL resolution.
            encoding: Not used in this method.
            vertex_project / vertex_location / vertex_credentials: Vertex AI
                auth and routing settings.
            aembedding: When True, the coroutine from `async_batch_embeddings`
                is returned instead of a finished response.
            timeout: Numeric timeout in seconds; None selects a 600s default
                (5s connect). Only consulted when `client` is None.
            client: Optional pre-built sync HTTP handler.
            extra_headers: Headers merged last, overriding computed ones.

        Returns:
            The populated EmbeddingResponse (or a coroutine when `aembedding`).

        Raises:
            Exception: If the embedding request does not return HTTP 200.
        """
        _auth_header, vertex_project = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider=custom_llm_provider,
        )

        if client is None:
            _params = {}
            if timeout is not None:
                # NOTE(review): a non-numeric timeout (e.g. an httpx.Timeout
                # instance) is not forwarded here; the handler default applies.
                if isinstance(timeout, (float, int)):
                    _params["timeout"] = httpx.Timeout(timeout)
            else:
                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)

            sync_handler: HTTPHandler = HTTPHandler(**_params)  # type: ignore
        else:
            sync_handler = client  # type: ignore

        optional_params = optional_params or {}

        # Multimodal input and all Vertex AI traffic use the embedContent shape.
        is_multimodal = _is_multimodal_input(input)
        use_embed_content = is_multimodal or (custom_llm_provider == "vertex_ai")
        mode: Literal["embedding", "batch_embedding"] = (
            "embedding" if use_embed_content else "batch_embedding"
        )

        auth_header, url = self._get_token_and_url(
            model=model,
            auth_header=_auth_header,
            gemini_api_key=api_key,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_credentials=vertex_credentials,
            stream=None,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            should_use_v1beta1_features=False,
            mode=mode,
        )

        headers = {
            "Content-Type": "application/json; charset=utf-8",
        }
        if auth_header is not None:
            if isinstance(auth_header, dict):
                headers.update(auth_header)
            else:
                headers["Authorization"] = f"Bearer {auth_header}"
        if extra_headers is not None:
            headers.update(extra_headers)

        if aembedding is True:
            # The async path builds its own request body, hence data=None.
            return self.async_batch_embeddings(  # type: ignore
                model=model,
                api_base=api_base,
                url=url,
                data=None,
                model_response=model_response,
                timeout=timeout,
                headers=headers,
                input=input,
                use_embed_content=use_embed_content,
                api_key=api_key,
                optional_params=optional_params,
                logging_obj=logging_obj,
            )

        ### TRANSFORMATION (sync path) ###
        request_data: Any
        if use_embed_content:
            resolved_files = {}
            if api_key:
                resolved_files = self._resolve_file_references(
                    input=input, api_key=api_key, sync_handler=sync_handler
                )
            request_data = transform_openai_input_gemini_embed_content(
                input=input,
                model=model,
                optional_params=optional_params,
                resolved_files=resolved_files,
            )
        else:
            request_data = transform_openai_input_gemini_content(
                input=input, model=model, optional_params=optional_params
            )

        ## LOGGING
        logging_obj.pre_call(
            input=input,
            api_key="",
            additional_args={
                "complete_input_dict": request_data,
                "api_base": url,
                "headers": headers,
            },
        )

        response = sync_handler.post(
            url=url,
            headers=headers,
            data=json.dumps(request_data),
        )

        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} {response.text}")

        _json_response = response.json()

        if use_embed_content:
            return process_embed_content_response(
                input=input,
                model_response=model_response,
                model=model,
                response_json=_json_response,
            )

        _predictions = VertexAIBatchEmbeddingsResponseObject(**_json_response)  # type: ignore
        return process_response(
            model=model,
            model_response=model_response,
            _predictions=_predictions,
            input=input,
        )

    async def async_batch_embeddings(
        self,
        model: str,
        api_base: Optional[str],
        url: str,
        data: Optional[Union[VertexAIBatchEmbeddingsRequestBody, dict]],
        model_response: EmbeddingResponse,
        input: EmbeddingInput,
        timeout: Optional[Union[float, httpx.Timeout]],
        headers: Optional[dict] = None,
        client: Optional[AsyncHTTPHandler] = None,
        use_embed_content: bool = False,
        api_key: Optional[str] = None,
        optional_params: Optional[dict] = None,
        logging_obj: Optional[Any] = None,
    ) -> EmbeddingResponse:
        """
        Async counterpart of the sync request path in `batch_embeddings`.

        Args:
            model: Model name.
            api_base: Not used in this method.
            url: Fully resolved request URL.
            data: Ignored on entry; the request body is always rebuilt from
                `input` and `optional_params`.
            model_response: Response object that is populated and returned.
            input: A string or list of strings to embed.
            timeout: Forwarded as-is to the shared async client factory when
                `client` is None.
            headers: Request headers; None is treated as an empty dict.
            client: Optional pre-built async HTTP handler.
            use_embed_content: Selects the embedContent request/response shape.
            api_key: Gemini API key. File references are only resolved when set.
            optional_params: Extra request parameters; None is treated as {}.
            logging_obj: Optional logging object; `pre_call` is invoked when set.

        Returns:
            The populated EmbeddingResponse.

        Raises:
            Exception: If the embedding request does not return HTTP 200.
        """
        # Avoid a shared mutable default for headers.
        headers = {} if headers is None else headers

        if client is None:
            async_handler: AsyncHTTPHandler = get_async_httpx_client(
                llm_provider=litellm.LlmProviders.VERTEX_AI,
                params={"timeout": timeout},
            )
        else:
            async_handler = client  # type: ignore

        ### TRANSFORMATION (async path) ###
        if use_embed_content:
            resolved_files = {}
            if api_key:
                resolved_files = await self._async_resolve_file_references(
                    input=input, api_key=api_key, async_handler=async_handler
                )
            data = transform_openai_input_gemini_embed_content(
                input=input,
                model=model,
                optional_params=optional_params or {},
                resolved_files=resolved_files,
            )
        else:
            data = transform_openai_input_gemini_content(
                input=input, model=model, optional_params=optional_params or {}
            )

        ## LOGGING
        if logging_obj is not None:
            logging_obj.pre_call(
                input=input,
                api_key="",
                additional_args={
                    "complete_input_dict": data,
                    "api_base": url,
                    "headers": headers,
                },
            )

        response = await async_handler.post(
            url=url,
            headers=headers,
            data=json.dumps(data),
        )

        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} {response.text}")

        _json_response = response.json()

        if use_embed_content:
            return process_embed_content_response(
                input=input,
                model_response=model_response,
                model=model,
                response_json=_json_response,
            )

        _predictions = VertexAIBatchEmbeddingsResponseObject(**_json_response)  # type: ignore
        return process_response(
            model=model,
            model_response=model_response,
            _predictions=_predictions,
            input=input,
        )
|
||||
@@ -0,0 +1,308 @@
|
||||
"""
|
||||
Transformation logic from OpenAI /v1/embeddings format to Google AI Studio /batchEmbedContents format.
|
||||
|
||||
Why a separate file? To make it easy to see how the transformation works.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from litellm.types.llms.openai import EmbeddingInput
|
||||
from litellm.types.llms.vertex_ai import (
|
||||
BlobType,
|
||||
ContentType,
|
||||
EmbedContentRequest,
|
||||
FileDataType,
|
||||
PartType,
|
||||
VertexAIBatchEmbeddingsRequestBody,
|
||||
VertexAIBatchEmbeddingsResponseObject,
|
||||
)
|
||||
from litellm.types.utils import Embedding, EmbeddingResponse, Usage
|
||||
from litellm.utils import get_formatted_prompt, token_counter
|
||||
|
||||
# MIME types that `_parse_data_url` accepts, grouped by media family.
SUPPORTED_EMBEDDING_MIME_TYPES = {
    # images
    "image/jpeg",
    "image/png",
    # audio
    "audio/wav",
    "audio/mpeg",
    # video
    "video/quicktime",
    "video/mp4",
    # documents
    "application/pdf",
}
|
||||
|
||||
|
||||
def _is_file_reference(s: str) -> bool:
|
||||
"""Check if string is a Gemini file reference (files/...)."""
|
||||
return isinstance(s, str) and s.startswith("files/")
|
||||
|
||||
|
||||
def _is_gcs_url(s: str) -> bool:
|
||||
"""Check if string is a GCS URL (gs://...)."""
|
||||
return isinstance(s, str) and s.startswith("gs://")
|
||||
|
||||
|
||||
def _infer_mime_type_from_gcs_url(gcs_url: str) -> str:
|
||||
"""
|
||||
Infer MIME type from GCS URL file extension.
|
||||
|
||||
Args:
|
||||
gcs_url: GCS URL like gs://bucket/path/to/file.png
|
||||
|
||||
Returns:
|
||||
str: Inferred MIME type
|
||||
|
||||
Raises:
|
||||
ValueError: If file extension is not supported
|
||||
"""
|
||||
extension_to_mime = {
|
||||
".png": "image/png",
|
||||
".jpg": "image/jpeg",
|
||||
".jpeg": "image/jpeg",
|
||||
".mp3": "audio/mpeg",
|
||||
".wav": "audio/wav",
|
||||
".mp4": "video/mp4",
|
||||
".mov": "video/quicktime",
|
||||
".pdf": "application/pdf",
|
||||
}
|
||||
|
||||
gcs_url_lower = gcs_url.lower()
|
||||
for ext, mime_type in extension_to_mime.items():
|
||||
if gcs_url_lower.endswith(ext):
|
||||
return mime_type
|
||||
|
||||
raise ValueError(
|
||||
f"Unable to infer MIME type from GCS URL: {gcs_url}. "
|
||||
f"Supported extensions: {', '.join(extension_to_mime.keys())}"
|
||||
)
|
||||
|
||||
|
||||
def _parse_data_url(data_url: str) -> Tuple[str, str]:
|
||||
"""
|
||||
Parse a data URL to extract the media type and base64 data.
|
||||
|
||||
Args:
|
||||
data_url: Data URL in format: data:image/jpeg;base64,/9j/4AAQ...
|
||||
|
||||
Returns:
|
||||
tuple: (media_type, base64_data)
|
||||
media_type: e.g., "image/jpeg", "video/mp4", "audio/mpeg"
|
||||
base64_data: The base64-encoded data without the prefix
|
||||
|
||||
Raises:
|
||||
ValueError: If data URL format is invalid or MIME type is unsupported
|
||||
"""
|
||||
if not data_url.startswith("data:"):
|
||||
raise ValueError(f"Invalid data URL format: {data_url[:50]}...")
|
||||
|
||||
if "," not in data_url:
|
||||
raise ValueError(f"Invalid data URL format (missing comma): {data_url[:50]}...")
|
||||
|
||||
metadata, base64_data = data_url.split(",", 1)
|
||||
|
||||
metadata = metadata[5:]
|
||||
|
||||
if ";" in metadata:
|
||||
media_type = metadata.split(";")[0]
|
||||
else:
|
||||
media_type = metadata
|
||||
|
||||
if media_type not in SUPPORTED_EMBEDDING_MIME_TYPES:
|
||||
raise ValueError(
|
||||
f"Unsupported MIME type for embedding: {media_type}. "
|
||||
f"Supported types: {', '.join(sorted(SUPPORTED_EMBEDDING_MIME_TYPES))}"
|
||||
)
|
||||
|
||||
return media_type, base64_data
|
||||
|
||||
|
||||
def _is_multimodal_input(input: EmbeddingInput) -> bool:
    """
    Report whether any input element refers to non-text content.

    An element counts as multimodal when it is a string that is a base64 data
    URI ("data:" prefix containing ";base64,"), a Gemini file reference, or a
    GCS URL. Non-string elements are ignored.

    Args:
        input: EmbeddingInput (str or List[str])

    Returns:
        bool: True if at least one element is multimodal, otherwise False
    """
    candidates = [input] if isinstance(input, str) else input

    return any(
        isinstance(item, str)
        and (
            (item.startswith("data:") and ";base64," in item)
            or _is_file_reference(item)
            or _is_gcs_url(item)
        )
        for item in candidates
    )
|
||||
|
||||
|
||||
def transform_openai_input_gemini_content(
    input: EmbeddingInput, model: str, optional_params: dict
) -> VertexAIBatchEmbeddingsRequestBody:
    """
    Build a batch embedding request body with one text request per input.

    The OpenAI-style `dimensions` parameter, when present, is renamed to
    `outputDimensionality`; every other optional parameter is copied onto each
    request unchanged. `optional_params` itself is not modified.

    Args:
        input: A single string or an iterable of texts to embed
        model: Model name; sent as "models/<model>"
        optional_params: Extra per-request parameters

    Returns:
        VertexAIBatchEmbeddingsRequestBody holding the list of requests
    """
    request_params = optional_params.copy()
    if "dimensions" in request_params:
        dimensions = request_params.pop("dimensions")
        request_params["outputDimensionality"] = dimensions

    qualified_model = f"models/{model}"
    texts = [input] if isinstance(input, str) else input

    requests: List[EmbedContentRequest] = [
        EmbedContentRequest(
            model=qualified_model,
            content=ContentType(parts=[PartType(text=text)]),
            **request_params,
        )
        for text in texts
    ]

    return VertexAIBatchEmbeddingsRequestBody(requests=requests)
|
||||
|
||||
|
||||
def transform_openai_input_gemini_embed_content(
    input: EmbeddingInput,
    model: str,
    optional_params: dict,
    resolved_files: Optional[Dict[str, Dict[str, str]]] = None,
) -> dict:
    """
    Build a single embedContent request body from mixed text/media input.

    Each input element becomes one part, in order:
      - base64 data URI  -> inline_data part (MIME type taken from the URI)
      - gs:// URL        -> file_data part (MIME type inferred from extension)
      - files/ reference -> file_data part (looked up in `resolved_files`)
      - anything else    -> text part

    Args:
        input: EmbeddingInput (str or List[str])
        model: Model name (not used when building the body)
        optional_params: Additional parameters merged into the body;
            `dimensions` is renamed to `outputDimensionality`
        resolved_files: Dict mapping file names (files/abc) to {mime_type, uri}

    Returns:
        dict: Request body with `content.parts` plus the optional parameters

    Raises:
        ValueError: If an element is not a string, or a file reference is
            missing from `resolved_files`
    """
    file_lookup = resolved_files or {}

    request_params = optional_params.copy()
    if "dimensions" in request_params:
        dimensions = request_params.pop("dimensions")
        request_params["outputDimensionality"] = dimensions

    elements = [input] if isinstance(input, str) else input
    parts: List[PartType] = []

    for item in elements:
        if not isinstance(item, str):
            raise ValueError(f"Unsupported input type: {type(item)}")

        if item.startswith("data:") and ";base64," in item:
            mime_type, base64_data = _parse_data_url(item)
            blob: BlobType = {"mime_type": mime_type, "data": base64_data}
            part = PartType(inline_data=blob)
        elif _is_gcs_url(item):
            gcs_file: FileDataType = {
                "mime_type": _infer_mime_type_from_gcs_url(item),
                "file_uri": item,
            }
            part = PartType(file_data=gcs_file)
        elif _is_file_reference(item):
            if item not in file_lookup:
                raise ValueError(f"File reference {item} not resolved")
            file_info = file_lookup[item]
            uploaded_file: FileDataType = {
                "mime_type": file_info["mime_type"],
                "file_uri": file_info["uri"],
            }
            part = PartType(file_data=uploaded_file)
        else:
            part = PartType(text=item)

        parts.append(part)

    return {"content": ContentType(parts=parts), **request_params}
|
||||
|
||||
|
||||
def process_embed_content_response(
    input: EmbeddingInput,
    model_response: EmbeddingResponse,
    model: str,
    response_json: dict,
) -> EmbeddingResponse:
    """
    Convert an embedContent response into an OpenAI-style EmbeddingResponse.

    The response carries exactly one embedding, stored at index 0. Prompt
    tokens are reported as 0 for multimodal input; for text-only input they
    are counted locally with `token_counter`.

    Args:
        input: Original input
        model_response: EmbeddingResponse to populate (mutated and returned)
        model: Model name
        response_json: Raw JSON response from embedContent endpoint

    Returns:
        EmbeddingResponse with single embedding

    Raises:
        ValueError: If `response_json` has no "embedding" key
    """
    if "embedding" not in response_json:
        raise ValueError(
            f"embedContent response missing 'embedding' field: {response_json}"
        )

    vector = response_json["embedding"]["values"]
    model_response.data = [Embedding(embedding=vector, index=0, object="embedding")]
    model_response.model = model

    prompt_tokens = 0
    if not _is_multimodal_input(input):
        prompt_text = get_formatted_prompt(data={"input": input}, call_type="embedding")
        prompt_tokens = token_counter(model=model, text=prompt_text)

    model_response.usage = Usage(
        prompt_tokens=prompt_tokens, total_tokens=prompt_tokens
    )
    return model_response
|
||||
|
||||
|
||||
def process_response(
    input: EmbeddingInput,
    model_response: EmbeddingResponse,
    model: str,
    _predictions: VertexAIBatchEmbeddingsResponseObject,
) -> EmbeddingResponse:
    """
    Convert a batch embeddings response into an OpenAI-style EmbeddingResponse.

    Each returned embedding is tagged with its position in the response list,
    so consumers can match embeddings back to their inputs by `index`.
    Prompt tokens are counted locally with `token_counter`.

    Args:
        input: Original input, used only for token counting
        model_response: EmbeddingResponse to populate (mutated and returned)
        model: Model name
        _predictions: Parsed batch response containing an "embeddings" list

    Returns:
        EmbeddingResponse with one Embedding per returned vector
    """
    # `index` is the position of the embedding within the batch, not a constant 0.
    openai_embeddings: List[Embedding] = [
        Embedding(
            embedding=embedding["values"],
            index=position,
            object="embedding",
        )
        for position, embedding in enumerate(_predictions["embeddings"])
    ]

    model_response.data = openai_embeddings
    model_response.model = model

    input_text = get_formatted_prompt(data={"input": input}, call_type="embedding")
    prompt_tokens = token_counter(model=model, text=input_text)
    model_response.usage = Usage(
        prompt_tokens=prompt_tokens, total_tokens=prompt_tokens
    )

    return model_response
|
||||
@@ -0,0 +1,100 @@
|
||||
"""
|
||||
Transformation for Calling Google models in their native format.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Literal, Optional, Union
|
||||
|
||||
from litellm.llms.gemini.google_genai.transformation import GoogleGenAIConfig
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
|
||||
|
||||
class VertexAIGoogleGenAIConfig(GoogleGenAIConfig):
    """
    Google GenAI native-format configuration specialised for Vertex AI.

    Differs from the base config in the provider name, in using a Bearer
    `Authorization` header, and in converting config keys to snake_case.
    """

    HEADER_NAME = "Authorization"
    BEARER_PREFIX = "Bearer"

    @property
    def custom_llm_provider(self) -> Literal["gemini", "vertex_ai"]:
        """Provider identifier for this config; always "vertex_ai"."""
        return "vertex_ai"

    def validate_environment(
        self,
        api_key: Optional[str],
        headers: Optional[dict],
        model: str,
        litellm_params: Optional[Union[GenericLiteLLMParams, dict]],
    ) -> dict:
        """
        Build the request headers.

        Starts from a JSON content type, adds `Authorization: Bearer <api_key>`
        when a key is given, then lets caller-supplied `headers` override both.
        `model` and `litellm_params` are not used here.
        """
        merged: dict = {"Content-Type": "application/json"}

        if api_key is not None:
            merged[self.HEADER_NAME] = f"{self.BEARER_PREFIX} {api_key}"
        if headers is not None:
            merged.update(headers)

        return merged

    def _camel_to_snake(self, camel_str: str) -> str:
        """Convert camelCase to snake_case (underscore before each non-leading A-Z, then lowercase)."""
        pieces = [
            "_" + char if position > 0 and "A" <= char <= "Z" else char
            for position, char in enumerate(camel_str)
        ]
        return "".join(pieces).lower()

    def map_generate_content_optional_params(
        self,
        generate_content_config_dict,
        model: str,
    ):
        """
        Map Google GenAI parameters to provider-specific format.

        Args:
            generate_content_config_dict: Optional parameters for generate
                content, keyed in camelCase
            model: The model name (not used here)

        Returns:
            A new dict with the same values under snake_case keys
        """
        return {
            self._camel_to_snake(param): value
            for param, value in generate_content_config_dict.items()
        }

    def transform_generate_content_request(
        self,
        model: str,
        contents: Any,
        tools: Optional[Any],
        generate_content_config_dict: Dict,
        system_instruction: Optional[Any] = None,
    ) -> dict:
        """
        Assemble the generate-content request body.

        `model` and `contents` are always present; `tools`,
        `systemInstruction` and `generationConfig` are added only when the
        corresponding argument is truthy. Values are passed through unchanged.
        """
        request: dict = {"model": model, "contents": contents}

        optional_sections = (
            ("tools", tools),
            ("systemInstruction", system_instruction),
            ("generationConfig", generate_content_config_dict),
        )
        for key, value in optional_sections:
            if value:
                request[key] = value

        return request
|
||||
@@ -0,0 +1,42 @@
|
||||
from litellm.llms.base_llm.image_edit.transformation import BaseImageEditConfig
|
||||
from litellm.llms.vertex_ai.common_utils import (
|
||||
VertexAIModelRoute,
|
||||
get_vertex_ai_model_route,
|
||||
)
|
||||
|
||||
from .cost_calculator import cost_calculator
|
||||
from .vertex_gemini_transformation import VertexAIGeminiImageEditConfig
|
||||
from .vertex_imagen_transformation import VertexAIImagenImageEditConfig
|
||||
|
||||
__all__ = [
|
||||
"VertexAIGeminiImageEditConfig",
|
||||
"VertexAIImagenImageEditConfig",
|
||||
"get_vertex_ai_image_edit_config",
|
||||
"cost_calculator",
|
||||
]
|
||||
|
||||
|
||||
def get_vertex_ai_image_edit_config(model: str) -> BaseImageEditConfig:
    """
    Pick the image edit config class that matches a Vertex AI model.

    - Models routed as Gemini get VertexAIGeminiImageEditConfig
      (generateContent API).
    - Every other route falls back to VertexAIImagenImageEditConfig
      (predict API), e.g. imagegeneration@006.

    Args:
        model: The model name (e.g., "gemini-2.5-flash", "imagegeneration@006")

    Returns:
        BaseImageEditConfig: A new instance of the selected configuration class
    """
    is_gemini = get_vertex_ai_model_route(model) == VertexAIModelRoute.GEMINI
    config_class = (
        VertexAIGeminiImageEditConfig if is_gemini else VertexAIImagenImageEditConfig
    )
    return config_class()
|
||||
@@ -0,0 +1,34 @@
|
||||
"""
|
||||
Vertex AI Image Edit Cost Calculator
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import litellm
|
||||
from litellm.types.utils import ImageResponse
|
||||
|
||||
|
||||
def cost_calculator(
    model: str,
    image_response: Any,
) -> float:
    """
    Compute the cost of a Vertex AI image edit call.

    Pricing is per returned image: the model's `output_cost_per_image`
    (0.0 when missing or falsy) multiplied by the number of images in the
    response.

    Raises:
        ValueError: If `image_response` is not an ImageResponse. The model
            info lookup happens first, so its errors take precedence.
    """
    pricing = litellm.get_model_info(
        model=model,
        custom_llm_provider="vertex_ai",
    )
    per_image_cost: float = pricing.get("output_cost_per_image") or 0.0

    if not isinstance(image_response, ImageResponse):
        raise ValueError(
            f"image_response must be of type ImageResponse got type={type(image_response)}"
        )

    returned_images = image_response.data or []
    return per_image_cost * len(returned_images)
|
||||
@@ -0,0 +1,298 @@
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
from io import BufferedReader, BytesIO
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
import httpx
|
||||
from httpx._types import RequestFiles
|
||||
|
||||
import litellm
|
||||
from litellm.images.utils import ImageEditRequestUtils
|
||||
from litellm.llms.base_llm.image_edit.transformation import BaseImageEditConfig
|
||||
from litellm.llms.vertex_ai.common_utils import get_vertex_base_url
|
||||
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.images.main import ImageEditOptionalRequestParams
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.utils import FileTypes, ImageObject, ImageResponse, OpenAIImage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class VertexAIGeminiImageEditConfig(BaseImageEditConfig, VertexLLM):
|
||||
"""
|
||||
Vertex AI Gemini Image Edit Configuration
|
||||
|
||||
Uses generateContent API for Gemini models on Vertex AI
|
||||
"""
|
||||
|
||||
SUPPORTED_PARAMS: List[str] = ["size"]
|
||||
|
||||
def __init__(self) -> None:
    """Initialise both parent classes explicitly, config base first."""
    for parent in (BaseImageEditConfig, VertexLLM):
        parent.__init__(self)
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List[str]:
    """Return a fresh list copy of SUPPORTED_PARAMS; `model` is not consulted."""
    return [*self.SUPPORTED_PARAMS]
|
||||
|
||||
def map_openai_params(
    self,
    image_edit_optional_params: ImageEditOptionalRequestParams,
    model: str,
    drop_params: bool,
) -> Dict[str, Any]:
    """
    Translate OpenAI image edit parameters into Gemini parameters.

    Only supported parameters are considered. `size` is converted to
    `aspectRatio` via `_map_size_to_aspect_ratio`; nothing else is emitted.
    `drop_params` is not consulted here.
    """
    allowed = self.get_supported_openai_params(model)
    accepted = {
        name: image_edit_optional_params[name]  # type: ignore[literal-required]
        for name in image_edit_optional_params
        if name in allowed
    }

    if "size" not in accepted:
        return {}

    return {
        "aspectRatio": self._map_size_to_aspect_ratio(
            accepted["size"]  # type: ignore[arg-type]
        )
    }
|
||||
|
||||
def _resolve_vertex_project(self) -> Optional[str]:
    """
    Find the Vertex project from the first truthy source, in order:
    instance `_vertex_project`, env `VERTEXAI_PROJECT`,
    `litellm.vertex_project`, then the secret `VERTEXAI_PROJECT`.

    Later sources are only consulted when earlier ones are falsy; if all are
    falsy the last source's value is returned.
    """
    sources = (
        lambda: getattr(self, "_vertex_project", None),
        lambda: os.environ.get("VERTEXAI_PROJECT"),
        lambda: getattr(litellm, "vertex_project", None),
        lambda: get_secret_str("VERTEXAI_PROJECT"),
    )
    found = None
    for read in sources:
        found = read()
        if found:
            break
    return found
|
||||
|
||||
def _resolve_vertex_location(self) -> Optional[str]:
    """
    Find the Vertex location from the first truthy source, in order:
    instance `_vertex_location`, env `VERTEXAI_LOCATION`, env
    `VERTEX_LOCATION`, `litellm.vertex_location`, then the secrets
    `VERTEXAI_LOCATION` and `VERTEX_LOCATION`.

    Later sources are only consulted when earlier ones are falsy; if all are
    falsy the last source's value is returned.
    """
    sources = (
        lambda: getattr(self, "_vertex_location", None),
        lambda: os.environ.get("VERTEXAI_LOCATION"),
        lambda: os.environ.get("VERTEX_LOCATION"),
        lambda: getattr(litellm, "vertex_location", None),
        lambda: get_secret_str("VERTEXAI_LOCATION"),
        lambda: get_secret_str("VERTEX_LOCATION"),
    )
    found = None
    for read in sources:
        found = read()
        if found:
            break
    return found
|
||||
|
||||
def _resolve_vertex_credentials(self) -> Optional[str]:
    """
    Find the Vertex credentials from the first truthy source, in order:
    instance `_vertex_credentials`, env `VERTEXAI_CREDENTIALS`,
    `litellm.vertex_credentials`, env `GOOGLE_APPLICATION_CREDENTIALS`, then
    the secret `VERTEXAI_CREDENTIALS`.

    Later sources are only consulted when earlier ones are falsy; if all are
    falsy the last source's value is returned.
    """
    sources = (
        lambda: getattr(self, "_vertex_credentials", None),
        lambda: os.environ.get("VERTEXAI_CREDENTIALS"),
        lambda: getattr(litellm, "vertex_credentials", None),
        lambda: os.environ.get("GOOGLE_APPLICATION_CREDENTIALS"),
        lambda: get_secret_str("VERTEXAI_CREDENTIALS"),
    )
    found = None
    for read in sources:
        found = read()
        if found:
            break
    return found
|
||||
|
||||
def validate_environment(
    self,
    headers: dict,
    model: str,
    api_key: Optional[str] = None,
    litellm_params: Optional[dict] = None,
    api_base: Optional[str] = None,
) -> dict:
    """
    Produce the headers for an image edit request.

    When a custom API base is supplied (via `litellm_params["api_base"]` or
    `api_base`), credential handling is skipped and the given headers are
    returned as-is, so proxies and mock endpoints work without Vertex AI
    credentials. Otherwise an access token is obtained and applied through
    `set_headers`. `model` and `api_key` are not used here.
    """
    request_headers = headers or {}
    params = litellm_params or {}

    custom_base = params.get("api_base") or api_base
    if custom_base is not None:
        return request_headers

    # Explicit litellm_params win; env vars and globals are the fallback.
    project = self.safe_get_vertex_ai_project(params)
    if not project:
        project = self._resolve_vertex_project()

    credentials = self.safe_get_vertex_ai_credentials(params)
    if not credentials:
        credentials = self._resolve_vertex_credentials()

    token_and_project = self._ensure_access_token(
        credentials=credentials,
        project_id=project,
        custom_llm_provider="vertex_ai",
    )
    access_token, _ = token_and_project
    return self.set_headers(access_token, request_headers)
|
||||
|
||||
def get_complete_url(
    self,
    model: str,
    api_base: Optional[str],
    litellm_params: dict,
) -> str:
    """
    Build the generateContent URL for a Vertex AI Gemini model.

    A truthy `api_base` is returned directly (trailing slashes stripped), which
    lets callers point at proxies or mock endpoints. Otherwise project and
    location come from `litellm_params` first, then from the instance,
    environment and global fallbacks.

    Raises:
        ValueError: If no project or no location can be determined.
    """
    provider_prefix = "vertex_ai/"
    model_name = (
        model.replace(provider_prefix, "")
        if model.startswith(provider_prefix)
        else model
    )

    if api_base:
        return api_base.rstrip("/")

    project = self.safe_get_vertex_ai_project(litellm_params)
    if not project:
        project = self._resolve_vertex_project()

    location = self.safe_get_vertex_ai_location(litellm_params)
    if not location:
        location = self._resolve_vertex_location()

    if not (project and location):
        raise ValueError(
            "vertex_project and vertex_location are required for Vertex AI"
        )

    base_url = get_vertex_base_url(location)
    return f"{base_url}/v1/projects/{project}/locations/{location}/publishers/google/models/{model_name}:generateContent"
|
||||
|
||||
def transform_image_edit_request(  # type: ignore[override]
    self,
    model: str,
    prompt: Optional[str],
    image: Optional[FileTypes],
    image_edit_optional_request_params: Dict[str, Any],
    litellm_params: GenericLiteLLMParams,
    headers: dict,
) -> Tuple[Dict[str, Any], Optional[RequestFiles]]:
    """
    Build the JSON payload for a Gemini ``generateContent`` image edit.

    The source image(s) are embedded as base64 ``inlineData`` parts, followed
    by the text prompt when a non-empty one is given.

    Returns:
        A ``(payload, files)`` pair in which ``payload`` is the already
        serialised JSON string and ``files`` is an empty list.

    Raises:
        ValueError: If no usable image is supplied.
    """
    image_parts: List[Dict[str, Any]] = (
        self._prepare_inline_image_parts(image) if image else []
    )
    if not image_parts:
        raise ValueError("Vertex AI Gemini image edit requires at least one image.")

    # Images come first, then the optional instruction text.
    message_parts = list(image_parts)
    if not (prompt is None or prompt == ""):
        message_parts.append({"text": prompt})

    # Only image output is requested; the aspect ratio is forwarded when set.
    generation_config: Dict[str, Any] = {"response_modalities": ["IMAGE"]}
    if "aspectRatio" in image_edit_optional_request_params:
        generation_config["image_config"] = {
            "aspect_ratio": image_edit_optional_request_params["aspectRatio"]
        }

    serialized: Any = json.dumps(
        {
            "contents": {"role": "USER", "parts": message_parts},
            "generationConfig": generation_config,
        }
    )
    no_files = cast(RequestFiles, [])
    return cast(
        Tuple[Dict[str, Any], Optional[RequestFiles]], (serialized, no_files)
    )
|
||||
|
||||
def transform_image_edit_response(
    self,
    model: str,
    raw_response: httpx.Response,
    logging_obj: Any,
) -> ImageResponse:
    """
    Convert a Gemini ``generateContent`` response into an ``ImageResponse``.

    Every part carrying non-empty ``inlineData.data`` becomes one image
    entry with its base64 payload in ``b64_json``.

    Raises:
        The provider error from ``get_error_class`` when the response body
        cannot be decoded as JSON.
    """
    result = ImageResponse()
    try:
        payload = raw_response.json()
    except Exception as exc:
        raise self.get_error_class(
            error_message=f"Error transforming image edit response: {exc}",
            status_code=raw_response.status_code,
            headers=raw_response.headers,
        )

    images: List[ImageObject] = []
    for candidate in payload.get("candidates", []):
        for part in candidate.get("content", {}).get("parts", []):
            blob = part.get("inlineData")
            # Skip text parts and image parts with no payload.
            if not blob or not blob.get("data"):
                continue
            images.append(ImageObject(b64_json=blob["data"], url=None))

    result.data = cast(List[OpenAIImage], images)
    return result
|
||||
|
||||
def _map_size_to_aspect_ratio(self, size: str) -> str:
|
||||
"""Map OpenAI size format to Gemini aspect ratio format"""
|
||||
aspect_ratio_map = {
|
||||
"1024x1024": "1:1",
|
||||
"1792x1024": "16:9",
|
||||
"1024x1792": "9:16",
|
||||
"1280x896": "4:3",
|
||||
"896x1280": "3:4",
|
||||
}
|
||||
return aspect_ratio_map.get(size, "1:1")
|
||||
|
||||
def _prepare_inline_image_parts(
    self, image: Union[FileTypes, List[FileTypes]]
) -> List[Dict[str, Any]]:
    """
    Encode one image, or a list of images, as Gemini ``inlineData`` parts.

    ``None`` entries are skipped. Each remaining image contributes its
    detected MIME type and base64-encoded bytes.
    """
    candidates: List[FileTypes] = image if isinstance(image, list) else [image]
    return [
        {
            "inlineData": {
                "mimeType": ImageEditRequestUtils.get_image_content_type(entry),
                "data": base64.b64encode(self._read_all_bytes(entry)).decode(
                    "utf-8"
                ),
            }
        }
        for entry in candidates
        if entry is not None
    ]
|
||||
|
||||
def _read_all_bytes(self, image: FileTypes) -> bytes:
    """
    Return the full content of ``image`` as bytes.

    Raw ``bytes`` are returned untouched. ``BytesIO`` and ``BufferedReader``
    streams are read from the start and then rewound to the position they
    had on entry.

    Raises:
        ValueError: For any other input type.
    """
    if isinstance(image, bytes):
        return image
    if not isinstance(image, (BytesIO, BufferedReader)):
        raise ValueError("Unsupported image type for Vertex AI Gemini image edit.")

    saved_offset = image.tell()
    image.seek(0)
    content = image.read()
    image.seek(saved_offset)
    return content
|
||||
@@ -0,0 +1,365 @@
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
from io import BufferedRandom, BufferedReader, BytesIO
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
import httpx
|
||||
from httpx._types import RequestFiles
|
||||
|
||||
import litellm
|
||||
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
|
||||
from litellm.llms.base_llm.image_edit.transformation import BaseImageEditConfig
|
||||
from litellm.llms.vertex_ai.common_utils import get_vertex_base_url
|
||||
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.images.main import ImageEditOptionalRequestParams
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.utils import FileTypes, ImageObject, ImageResponse, OpenAIImage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class VertexAIImagenImageEditConfig(BaseImageEditConfig, VertexLLM):
    """
    Vertex AI Imagen Image Edit Configuration

    Uses predict API for Imagen models on Vertex AI
    """

    # OpenAI-style image edit parameter names that this config accepts;
    # get_supported_openai_params() hands out a copy of this list.
    SUPPORTED_PARAMS: List[str] = ["n", "size", "mask"]

    def __init__(self) -> None:
        """Initialise both base classes by calling each ``__init__`` explicitly."""
        BaseImageEditConfig.__init__(self)
        VertexLLM.__init__(self)
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List[str]:
    """Return a fresh copy of the OpenAI parameter names this config supports."""
    return [*self.SUPPORTED_PARAMS]
|
||||
|
||||
def map_openai_params(
    self,
    image_edit_optional_params: ImageEditOptionalRequestParams,
    model: str,
    drop_params: bool,
) -> Dict[str, Any]:
    """
    Translate OpenAI image edit parameters into their Imagen names.

    ``n`` becomes ``sampleCount``, ``size`` becomes ``aspectRatio`` (via the
    size-to-ratio table) and ``mask`` is passed through. Parameters that are
    not supported are ignored; ``drop_params`` is not consulted.
    """
    supported = self.get_supported_openai_params(model)

    def _requested(name: str) -> bool:
        # A parameter counts only if it is both supported and actually given.
        return name in supported and name in image_edit_optional_params

    imagen_params: Dict[str, Any] = {}
    if _requested("n"):
        imagen_params["sampleCount"] = image_edit_optional_params["n"]
    if _requested("size"):
        imagen_params["aspectRatio"] = self._map_size_to_aspect_ratio(
            image_edit_optional_params["size"]  # type: ignore[arg-type]
        )
    if _requested("mask"):
        imagen_params["mask"] = image_edit_optional_params["mask"]
    return imagen_params
|
||||
|
||||
def _resolve_vertex_project(self) -> Optional[str]:
|
||||
return (
|
||||
getattr(self, "_vertex_project", None)
|
||||
or os.environ.get("VERTEXAI_PROJECT")
|
||||
or getattr(litellm, "vertex_project", None)
|
||||
or get_secret_str("VERTEXAI_PROJECT")
|
||||
)
|
||||
|
||||
def _resolve_vertex_location(self) -> Optional[str]:
|
||||
return (
|
||||
getattr(self, "_vertex_location", None)
|
||||
or os.environ.get("VERTEXAI_LOCATION")
|
||||
or os.environ.get("VERTEX_LOCATION")
|
||||
or getattr(litellm, "vertex_location", None)
|
||||
or get_secret_str("VERTEXAI_LOCATION")
|
||||
or get_secret_str("VERTEX_LOCATION")
|
||||
)
|
||||
|
||||
def _resolve_vertex_credentials(self) -> Optional[str]:
|
||||
return (
|
||||
getattr(self, "_vertex_credentials", None)
|
||||
or os.environ.get("VERTEXAI_CREDENTIALS")
|
||||
or getattr(litellm, "vertex_credentials", None)
|
||||
or os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
|
||||
or get_secret_str("VERTEXAI_CREDENTIALS")
|
||||
)
|
||||
|
||||
def validate_environment(
    self,
    headers: dict,
    model: str,
    api_key: Optional[str] = None,
) -> dict:
    """
    Produce the request headers for an authenticated Vertex AI call.

    The project and credentials come from the ``_resolve_vertex_*``
    fallbacks and are used to obtain an access token, which ``set_headers``
    then applies to ``headers`` (an empty dict when none is given).
    ``model`` and ``api_key`` are not used.
    """
    request_headers = headers if headers else {}
    project_id = self._resolve_vertex_project()
    credentials = self._resolve_vertex_credentials()
    # The second element of the returned pair is not needed here.
    token, _unused = self._ensure_access_token(
        credentials=credentials,
        project_id=project_id,
        custom_llm_provider="vertex_ai",
    )
    return self.set_headers(token, request_headers)
|
||||
|
||||
def get_complete_url(
    self,
    model: str,
    api_base: Optional[str],
    litellm_params: dict,
) -> str:
    """
    Get the complete URL for Vertex AI Imagen predict API

    Args:
        model: Model name, optionally carrying a ``vertex_ai/`` prefix.
        api_base: Optional custom host. When set it replaces the regional
            Vertex host, and the ``/v1/projects/...:predict`` path is still
            appended.
        litellm_params: Per-request parameters. The project and location
            found here take precedence over the environment-level fallbacks,
            matching the Gemini image edit config.

    Returns:
        The ``:predict`` endpoint URL.

    Raises:
        ValueError: If no project or no location can be resolved.
    """
    request_params = litellm_params or {}

    # Request-level settings first, then env vars / module settings / secrets.
    vertex_project = (
        self.safe_get_vertex_ai_project(request_params)
        or self._resolve_vertex_project()
    )
    vertex_location = (
        self.safe_get_vertex_ai_location(request_params)
        or self._resolve_vertex_location()
    )
    if not vertex_project or not vertex_location:
        raise ValueError(
            "vertex_project and vertex_location are required for Vertex AI"
        )

    # Drop only a leading provider prefix; the rest of the name is untouched.
    provider_prefix = "vertex_ai/"
    model_name = (
        model[len(provider_prefix):] if model.startswith(provider_prefix) else model
    )

    base_url = (
        api_base.rstrip("/") if api_base else get_vertex_base_url(vertex_location)
    )
    return f"{base_url}/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model_name}:predict"
|
||||
|
||||
def transform_image_edit_request(  # type: ignore[override]
    self,
    model: str,
    prompt: Optional[str],
    image: Optional[FileTypes],
    image_edit_optional_request_params: Dict[str, Any],
    litellm_params: GenericLiteLLMParams,
    headers: dict,
) -> Tuple[Dict[str, Any], Optional[RequestFiles]]:
    """
    Build the JSON payload for an Imagen ``predict`` image edit.

    The request holds a single instance (the prompt plus the reference
    images, including any mask) and a ``parameters`` object. Only
    ``sampleCount`` comes from the caller; the other parameters are fixed
    defaults.

    Returns:
        A ``(payload, files)`` pair in which ``payload`` is the already
        serialised JSON string and ``files`` is an empty list.

    Raises:
        ValueError: If no reference image or no prompt is supplied.
    """
    no_image_message = (
        "Vertex AI Imagen image edit requires at least one reference image."
    )
    if image is None:
        raise ValueError(no_image_message)
    references = self._prepare_reference_images(
        image, image_edit_optional_request_params
    )
    if not references:
        raise ValueError(no_image_message)
    if prompt is None:
        raise ValueError("Vertex AI Imagen image edit requires a prompt.")

    # Everything except sampleCount is a fixed default that is not exposed
    # through the OpenAI-style parameters.
    parameters = {
        "sampleCount": image_edit_optional_request_params.get("sampleCount", 1),
        "editMode": "EDIT_MODE_INPAINT_INSERTION",
        "editConfig": {"baseSteps": 50},
        "guidanceScale": 7.5,
        "seed": None,  # serialised as null: no seed is specified
    }

    serialized: Any = json.dumps(
        {
            "instances": [{"prompt": prompt, "referenceImages": references}],
            "parameters": parameters,
        }
    )
    no_files = cast(RequestFiles, [])
    return cast(
        Tuple[Dict[str, Any], Optional[RequestFiles]], (serialized, no_files)
    )
|
||||
|
||||
def transform_image_edit_response(
    self,
    model: str,
    raw_response: httpx.Response,
    logging_obj: Any,
) -> ImageResponse:
    """
    Convert an Imagen ``predict`` response into an ``ImageResponse``.

    Each prediction that has a ``bytesBase64Encoded`` field becomes one
    image entry; predictions without it are ignored.

    Raises:
        The provider error from ``get_error_class`` when the response body
        cannot be decoded as JSON.
    """
    result = ImageResponse()
    try:
        payload = raw_response.json()
    except Exception as exc:
        raise self.get_error_class(
            error_message=f"Error transforming image edit response: {exc}",
            status_code=raw_response.status_code,
            headers=raw_response.headers,
        )

    images: List[ImageObject] = [
        ImageObject(b64_json=prediction["bytesBase64Encoded"], url=None)
        for prediction in payload.get("predictions", [])
        if "bytesBase64Encoded" in prediction
    ]
    result.data = cast(List[OpenAIImage], images)
    return result
|
||||
|
||||
def _map_size_to_aspect_ratio(self, size: str) -> str:
|
||||
"""Map OpenAI size format to Imagen aspect ratio format"""
|
||||
aspect_ratio_map = {
|
||||
"1024x1024": "1:1",
|
||||
"1792x1024": "16:9",
|
||||
"1024x1792": "9:16",
|
||||
"1280x896": "4:3",
|
||||
"896x1280": "3:4",
|
||||
}
|
||||
return aspect_ratio_map.get(size, "1:1")
|
||||
|
||||
def _prepare_reference_images(
    self,
    image: Union[FileTypes, List[FileTypes]],
    image_edit_optional_request_params: Dict[str, Any],
) -> List[Dict[str, Any]]:
    """
    Prepare reference images in the correct Imagen API format

    Every non-``None`` image becomes a ``REFERENCE_TYPE_RAW`` entry. If the
    optional params carry a ``mask``, it is appended as a
    ``REFERENCE_TYPE_MASK`` entry with a user-provided mask config.

    ``referenceId`` values are assigned sequentially (1, 2, ...) in the order
    entries are appended. The ids therefore stay unique and gap-free even
    when ``None`` images are skipped, and the mask id never collides with an
    image id.

    Args:
        image: A single image or a list of images.
        image_edit_optional_request_params: Mapped optional params; only
            ``mask`` is read here.

    Returns:
        The list of reference-image dicts for the ``referenceImages`` field.
    """
    images: List[FileTypes] = image if isinstance(image, list) else [image]

    reference_images: List[Dict[str, Any]] = []

    for img in images:
        if img is None:
            continue

        base64_data = base64.b64encode(self._read_all_bytes(img)).decode("utf-8")
        reference_images.append(
            {
                "referenceType": "REFERENCE_TYPE_RAW",
                # Count appended entries, not input positions, so a skipped
                # None cannot produce a duplicate or missing id.
                "referenceId": len(reference_images) + 1,
                "referenceImage": {"bytesBase64Encoded": base64_data},
            }
        )

    # Handle mask image if provided (for inpainting)
    mask_image = image_edit_optional_request_params.get("mask")
    if mask_image is not None:
        mask_base64 = base64.b64encode(self._read_all_bytes(mask_image)).decode(
            "utf-8"
        )
        reference_images.append(
            {
                "referenceType": "REFERENCE_TYPE_MASK",
                "referenceId": len(reference_images) + 1,
                "referenceImage": {"bytesBase64Encoded": mask_base64},
                "maskImageConfig": {
                    "maskMode": "MASK_MODE_USER_PROVIDED",
                    "dilation": 0.03,  # Default dilation value (not configurable via OpenAI API)
                },
            }
        )

    return reference_images
|
||||
|
||||
def _read_all_bytes(
    self, image: Any, depth: int = 0, max_depth: int = DEFAULT_MAX_RECURSE_DEPTH
) -> bytes:
    """
    Return the raw bytes of ``image``, unwrapping containers recursively.

    Accepted inputs, checked in this order:
      * list / tuple: the first non-``None`` item is read
      * dict: the first non-``None`` value under ``data``, ``bytes`` or
        ``content`` (a string value is base64-decoded), otherwise ``path``
      * ``bytes`` / ``bytearray``
      * ``BytesIO``, ``BufferedReader`` and ``BufferedRandom`` streams
      * ``str`` / ``Path``: treated as a filesystem path
      * any other object with a ``read`` method

    Args:
        image: The value to read.
        depth: Current recursion depth; callers normally leave it at 0.
        max_depth: Depth beyond which unwrapping is abandoned.

    Raises:
        ValueError: If the depth limit is exceeded, a list/tuple holds only
            ``None``, a path does not exist, or the type is unsupported.
    """
    # Guard against unbounded nesting of lists / dicts.
    if depth > max_depth:
        raise ValueError(
            f"Max recursion depth {max_depth} reached while reading image bytes for Vertex AI Imagen image edit."
        )

    if isinstance(image, (list, tuple)):
        # Only the first non-None element is used; the rest are ignored.
        # NOTE(review): for a (filename, content, ...) tuple this picks the
        # filename string, which is then treated as a path below. Confirm
        # whether such tuples can reach this method.
        for item in image:
            if item is not None:
                return self._read_all_bytes(
                    item, depth=depth + 1, max_depth=max_depth
                )
        raise ValueError("Unsupported image type for Vertex AI Imagen image edit.")

    if isinstance(image, dict):
        for key in ("data", "bytes", "content"):
            if key in image and image[key] is not None:
                value = image[key]
                if isinstance(value, str):
                    # String payloads are taken to be base64; if decoding
                    # fails, move on to the next candidate key.
                    try:
                        return base64.b64decode(value)
                    except Exception:
                        continue
                return self._read_all_bytes(
                    value, depth=depth + 1, max_depth=max_depth
                )
        if "path" in image:
            return self._read_all_bytes(
                image["path"], depth=depth + 1, max_depth=max_depth
            )
        # A dict with none of the usable keys falls through to the final
        # "unsupported type" error at the bottom.

    if isinstance(image, bytes):
        return image
    if isinstance(image, bytearray):
        return bytes(image)
    if isinstance(image, BytesIO):
        # Read from the start, then rewind to where the caller left it.
        current_pos = image.tell()
        image.seek(0)
        data = image.read()
        image.seek(current_pos)
        return data
    if isinstance(image, (BufferedReader, BufferedRandom)):
        # tell() may fail; in that case read from the current position
        # and do not attempt to seek.
        stream_pos: Optional[int] = None
        try:
            stream_pos = image.tell()
        except Exception:
            stream_pos = None
        if stream_pos is not None:
            image.seek(0)
        data = image.read()
        if stream_pos is not None:
            image.seek(stream_pos)
        return data
    if isinstance(image, (str, Path)):
        # Plain strings are interpreted as filesystem paths here, not as
        # base64 (that is only done for strings inside dicts, above).
        path_obj = Path(image)
        if not path_obj.exists():
            raise ValueError(
                f"Mask/image path does not exist for Vertex AI Imagen image edit: {path_obj}"
            )
        return path_obj.read_bytes()
    if hasattr(image, "read"):
        # Generic file-like object: read from its current position and
        # encode text results as UTF-8.
        data = image.read()
        if isinstance(data, str):
            data = data.encode("utf-8")
        return data
    raise ValueError(
        f"Unsupported image type for Vertex AI Imagen image edit. Got type={type(image)}"
    )
|
||||
@@ -0,0 +1,42 @@
|
||||
from litellm.llms.base_llm.image_generation.transformation import (
|
||||
BaseImageGenerationConfig,
|
||||
)
|
||||
from litellm.llms.vertex_ai.common_utils import (
|
||||
VertexAIModelRoute,
|
||||
get_vertex_ai_model_route,
|
||||
)
|
||||
|
||||
from .vertex_gemini_transformation import VertexAIGeminiImageGenerationConfig
|
||||
from .vertex_imagen_transformation import VertexAIImagenImageGenerationConfig
|
||||
|
||||
__all__ = [
|
||||
"VertexAIGeminiImageGenerationConfig",
|
||||
"VertexAIImagenImageGenerationConfig",
|
||||
"get_vertex_ai_image_generation_config",
|
||||
]
|
||||
|
||||
|
||||
def get_vertex_ai_image_generation_config(model: str) -> BaseImageGenerationConfig:
    """
    Pick the image generation config that matches a Vertex AI model.

    Models routed as Gemini get ``VertexAIGeminiImageGenerationConfig``
    (generateContent API). Every other route, including non-Gemini models
    such as ``imagegeneration@006``, gets
    ``VertexAIImagenImageGenerationConfig`` (predict API).

    Args:
        model: The model name (e.g., "gemini-2.5-flash-image", "imagegeneration@006")

    Returns:
        BaseImageGenerationConfig: A new instance of the selected config class.
    """
    is_gemini_route = get_vertex_ai_model_route(model) == VertexAIModelRoute.GEMINI
    config_cls = (
        VertexAIGeminiImageGenerationConfig
        if is_gemini_route
        else VertexAIImagenImageGenerationConfig
    )
    return config_cls()
|
||||
@@ -0,0 +1,36 @@
|
||||
"""
|
||||
Vertex AI Image Generation Cost Calculator
|
||||
"""
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.llm_cost_calc.utils import (
|
||||
calculate_image_response_cost_from_usage,
|
||||
)
|
||||
from litellm.types.utils import ImageResponse
|
||||
|
||||
|
||||
def cost_calculator(
    model: str,
    image_response: ImageResponse,
) -> float:
    """
    Compute the cost of a Vertex AI image generation response.

    A usage-based cost is returned when
    ``calculate_image_response_cost_from_usage`` produces one. Otherwise the
    cost is the model's ``output_cost_per_image`` (0.0 when unset) times the
    number of images in the response.
    """
    model_info = litellm.get_model_info(
        model=model,
        custom_llm_provider="vertex_ai",
    )

    usage_cost = calculate_image_response_cost_from_usage(
        model=model,
        image_response=image_response,
        custom_llm_provider="vertex_ai",
    )
    if usage_cost is not None:
        return usage_cost

    # Fall back to flat per-image pricing.
    per_image_cost: float = model_info.get("output_cost_per_image") or 0.0
    image_count: int = len(image_response.data) if image_response.data else 0
    return per_image_cost * image_count
|
||||
@@ -0,0 +1,282 @@
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
from openai.types.image import Image
|
||||
|
||||
import litellm
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
|
||||
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
|
||||
from litellm.types.utils import ImageResponse
|
||||
|
||||
|
||||
class VertexImageGeneration(VertexLLM):
|
||||
def process_image_generation_response(
    self,
    json_response: Dict[str, Any],
    model_response: ImageResponse,
    model: Optional[str] = None,
) -> ImageResponse:
    """
    Copy the images of a ``predict`` response into ``model_response``.

    Each prediction's ``bytesBase64Encoded`` value becomes one ``Image``
    with that value as ``b64_json``.

    Raises:
        litellm.InternalServerError: If the response has no ``predictions`` key.
    """
    if "predictions" not in json_response:
        raise litellm.InternalServerError(
            message=f"image generation response does not contain 'predictions', got {json_response}",
            llm_provider="vertex_ai",
            model=model,
        )

    images: List[Image] = [
        Image(b64_json=prediction["bytesBase64Encoded"])
        for prediction in json_response["predictions"]
    ]
    model_response.data = images
    return model_response
|
||||
|
||||
def transform_optional_params(self, optional_params: Optional[dict]) -> dict:
    """
    Convert optional params into the camelCase form the Vertex AI API expects.

    Keys containing an underscore are rewritten from snake_case to camelCase
    (``aspect_ratio`` becomes ``aspectRatio``); other keys are kept as-is.
    The result always starts from ``{"sampleCount": 1}``, which the given
    params may override.
    """
    result = {"sampleCount": 1}
    if optional_params is None:
        return result

    def _camel(name: str) -> str:
        # The first segment stays untouched; later segments are capitalised.
        head, *tail = name.split("_")
        return head + "".join(segment.capitalize() for segment in tail)

    for name, value in optional_params.items():
        target = _camel(name) if "_" in name else name
        result[target] = value
    return result
|
||||
|
||||
def image_generation(
    self,
    prompt: str,
    api_base: Optional[str],
    vertex_project: Optional[str],
    vertex_location: Optional[str],
    vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
    model_response: ImageResponse,
    logging_obj: Any,
    model: str = "imagegeneration",  # vertex ai uses imagegeneration as the default model
    client: Optional[Any] = None,
    optional_params: Optional[dict] = None,
    timeout: Optional[int] = None,
    aimg_generation=False,
    extra_headers: Optional[dict] = None,
) -> ImageResponse:
    """
    Generate images through the Vertex AI ``predict`` endpoint.

    Args:
        prompt: Text prompt, sent as the single request instance.
        api_base: Optional custom base URL, passed to ``_get_token_and_url``.
        vertex_project / vertex_location / vertex_credentials: Vertex AI
            settings used for authentication and URL construction.
        model_response: Response object that receives the generated images.
        logging_obj: Logging object; ``pre_call`` is invoked before the request.
        model: Model name.
        client: Optional pre-built HTTP handler; one is created when omitted.
        optional_params: Extra generation parameters (snake_case keys are
            converted to camelCase). Defaults to ``{"sampleCount": 1}``.
        timeout: Request timeout in seconds. When omitted, a 600s timeout
            with a 5s connect timeout is used for the new handler.
        aimg_generation: When ``True``, delegate to ``aimage_generation`` and
            return its coroutine instead of running the request here.
        extra_headers: Additional headers merged in by ``set_headers``.

    Returns:
        ``model_response`` populated with the generated images (or the
        coroutine of the async variant when ``aimg_generation`` is ``True``).

    Raises:
        Exception: If the endpoint answers with a non-200 status code.
    """
    if aimg_generation is True:
        return self.aimage_generation(  # type: ignore
            prompt=prompt,
            api_base=api_base,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_credentials=vertex_credentials,
            model=model,
            client=client,
            optional_params=optional_params,
            timeout=timeout,
            logging_obj=logging_obj,
            model_response=model_response,
            # Forward caller headers so the async path matches the sync one.
            extra_headers=extra_headers,
        )

    if client is None:
        _params = {}
        if timeout is not None:
            if isinstance(timeout, (int, float)):
                _params["timeout"] = httpx.Timeout(timeout)
        else:
            _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)

        sync_handler: HTTPHandler = HTTPHandler(**_params)  # type: ignore
    else:
        sync_handler = client  # type: ignore

    # Obtain an access token, then let _get_token_and_url resolve the final
    # auth header and the request URL for image generation.
    auth_header: Optional[str]
    auth_header, _ = self._ensure_access_token(
        credentials=vertex_credentials,
        project_id=vertex_project,
        custom_llm_provider="vertex_ai",
    )
    auth_header, api_base = self._get_token_and_url(
        model=model,
        gemini_api_key=None,
        auth_header=auth_header,
        vertex_project=vertex_project,
        vertex_location=vertex_location,
        vertex_credentials=vertex_credentials,
        stream=False,
        custom_llm_provider="vertex_ai",
        api_base=api_base,
        should_use_v1beta1_features=False,
        mode="image_generation",
    )
    optional_params = optional_params or {
        "sampleCount": 1
    }  # default optional params

    # Transform optional params to camelCase format
    optional_params = self.transform_optional_params(optional_params)

    request_data = {
        "instances": [{"prompt": prompt}],
        "parameters": optional_params,
    }

    headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)

    logging_obj.pre_call(
        input=prompt,
        api_key="",
        additional_args={
            "complete_input_dict": optional_params,
            "api_base": api_base,
            "headers": headers,
        },
    )

    response = sync_handler.post(
        url=api_base,
        headers=headers,
        data=json.dumps(request_data),
    )

    if response.status_code != 200:
        raise Exception(f"Error: {response.status_code} {response.text}")

    json_response = response.json()
    return self.process_image_generation_response(
        json_response, model_response, model
    )
|
||||
|
||||
async def aimage_generation(
    self,
    prompt: str,
    api_base: Optional[str],
    vertex_project: Optional[str],
    vertex_location: Optional[str],
    vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
    model_response: ImageResponse,
    logging_obj: Any,
    model: str = "imagegeneration",  # vertex ai uses imagegeneration as the default model
    client: Optional[AsyncHTTPHandler] = None,
    optional_params: Optional[dict] = None,
    timeout: Optional[int] = None,
    extra_headers: Optional[dict] = None,
) -> ImageResponse:
    """
    Async variant of ``image_generation``: POST to the Vertex AI ``predict``
    endpoint and fill ``model_response`` with the returned images.

    The request body has the form
    ``{"instances": [{"prompt": ...}], "parameters": {"sampleCount": 1, ...}}``.

    Args:
        prompt: Text prompt, sent as the single request instance.
        api_base: Optional custom base URL, passed to ``_get_token_and_url``.
        vertex_project / vertex_location / vertex_credentials: Vertex AI
            settings used for authentication and URL construction.
        model_response: Response object that receives the generated images.
        logging_obj: Logging object; ``pre_call`` is invoked before the request.
        model: Model name.
        client: Optional pre-built async handler. When omitted, one is
            obtained from ``get_async_httpx_client`` and stored on
            ``self.async_handler``.
        optional_params: Extra generation parameters (snake_case keys are
            converted to camelCase).
        timeout: Request timeout in seconds. When omitted, a 600s timeout
            with a 5s connect timeout is used, the same as the sync path.
        extra_headers: Additional headers merged in by ``set_headers``.

    Raises:
        Exception: If the endpoint answers with a non-200 status code.
    """
    if client is None:
        _params = {}
        if timeout is not None:
            if isinstance(timeout, (int, float)):
                _params["timeout"] = httpx.Timeout(timeout)
        else:
            _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)

        # Pass the prepared params so the httpx.Timeout (and the 600s
        # default) take effect instead of being computed and discarded.
        self.async_handler = get_async_httpx_client(
            llm_provider=litellm.LlmProviders.VERTEX_AI,
            params=_params,
        )
    else:
        self.async_handler = client  # type: ignore

    # Obtain an access token, then let _get_token_and_url resolve the final
    # auth header and the request URL for image generation.
    auth_header: Optional[str]
    auth_header, _ = self._ensure_access_token(
        credentials=vertex_credentials,
        project_id=vertex_project,
        custom_llm_provider="vertex_ai",
    )
    auth_header, api_base = self._get_token_and_url(
        model=model,
        gemini_api_key=None,
        auth_header=auth_header,
        vertex_project=vertex_project,
        vertex_location=vertex_location,
        vertex_credentials=vertex_credentials,
        stream=False,
        custom_llm_provider="vertex_ai",
        api_base=api_base,
        should_use_v1beta1_features=False,
        mode="image_generation",
    )

    # Transform optional params to camelCase format
    optional_params = self.transform_optional_params(optional_params)

    request_data = {
        "instances": [{"prompt": prompt}],
        "parameters": optional_params,
    }

    headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)

    logging_obj.pre_call(
        input=prompt,
        api_key="",
        additional_args={
            "complete_input_dict": optional_params,
            "api_base": api_base,
            "headers": headers,
        },
    )

    response = await self.async_handler.post(
        url=api_base,
        headers=headers,
        data=json.dumps(request_data),
    )

    if response.status_code != 200:
        raise Exception(f"Error: {response.status_code} {response.text}")

    json_response = response.json()
    return self.process_image_generation_response(
        json_response, model_response, model
    )
|
||||
|
||||
def is_image_generation_response(self, json_response: Dict[str, Any]) -> bool:
    """
    Report whether ``json_response`` looks like an image generation result.

    Returns ``True`` only when ``predictions`` is a non-empty list whose
    first entry is a dict with a ``bytesBase64Encoded`` field. A missing,
    null or empty ``predictions`` value yields ``False`` instead of raising.
    """
    predictions = json_response.get("predictions")
    if not isinstance(predictions, list) or not predictions:
        return False
    first_prediction = predictions[0]
    return (
        isinstance(first_prediction, dict)
        and "bytesBase64Encoded" in first_prediction
    )
|
||||
@@ -0,0 +1,327 @@
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.llms.base_llm.image_generation.transformation import (
|
||||
BaseImageGenerationConfig,
|
||||
)
|
||||
from litellm.llms.vertex_ai.common_utils import get_vertex_base_url
|
||||
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import (
|
||||
ImageObject,
|
||||
ImageResponse,
|
||||
ImageUsage,
|
||||
ImageUsageInputTokensDetails,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class VertexAIGeminiImageGenerationConfig(BaseImageGenerationConfig, VertexLLM):
    """
    Vertex AI Gemini Image Generation Configuration

    Uses generateContent API for Gemini image generation models on Vertex AI
    Supports models like gemini-2.5-flash-image, gemini-3-pro-image-preview, etc.
    """

    def __init__(self) -> None:
        """Initialise both base classes by calling each ``__init__`` explicitly."""
        BaseImageGenerationConfig.__init__(self)
        VertexLLM.__init__(self)
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> list:
    """
    List the parameters accepted for Gemini image generation.

    The list holds the OpenAI parameters ``n`` and ``size``, followed by the
    native Gemini imageConfig parameters (aspect ratio and image size) in
    both camelCase and snake_case spellings.
    """
    openai_params = ["n", "size"]
    image_config_params = [
        "aspectRatio",
        "aspect_ratio",
        "imageSize",
        "image_size",
    ]
    return openai_params + image_config_params
|
||||
|
||||
def map_openai_params(
    self,
    non_default_params: dict,
    optional_params: dict,
    model: str,
    drop_params: bool,
) -> dict:
    """
    Translate caller-supplied parameters into Gemini request parameters.

    A parameter is mapped only when it is supported and not already present
    in ``optional_params``. ``n`` becomes ``candidate_count``, ``size`` is
    converted to an ``aspectRatio``, and both spellings of aspect ratio and
    image size collapse to their camelCase names. ``drop_params`` is not
    consulted.
    """
    supported = self.get_supported_openai_params(model)
    # Straight renames; "size" needs a value conversion and is handled apart.
    renamed = {
        "n": "candidate_count",
        "aspectRatio": "aspectRatio",
        "aspect_ratio": "aspectRatio",
        "imageSize": "imageSize",
        "image_size": "imageSize",
    }

    result = {}
    for name, value in non_default_params.items():
        if name in optional_params or name not in supported:
            continue
        if name == "size":
            result["aspectRatio"] = self._map_size_to_aspect_ratio(value)
        else:
            # Supported names without a rename pass through unchanged.
            result[renamed.get(name, name)] = value
    return result
|
||||
|
||||
def _map_size_to_aspect_ratio(self, size: str) -> str:
|
||||
"""
|
||||
Map OpenAI size format to Gemini aspect ratio format
|
||||
"""
|
||||
aspect_ratio_map = {
|
||||
"1024x1024": "1:1",
|
||||
"1792x1024": "16:9",
|
||||
"1024x1792": "9:16",
|
||||
"1280x896": "4:3",
|
||||
"896x1280": "3:4",
|
||||
}
|
||||
return aspect_ratio_map.get(size, "1:1")
|
||||
|
||||
def _resolve_vertex_project(self) -> Optional[str]:
|
||||
return (
|
||||
getattr(self, "_vertex_project", None)
|
||||
or os.environ.get("VERTEXAI_PROJECT")
|
||||
or getattr(litellm, "vertex_project", None)
|
||||
or get_secret_str("VERTEXAI_PROJECT")
|
||||
)
|
||||
|
||||
def _resolve_vertex_location(self) -> Optional[str]:
|
||||
return (
|
||||
getattr(self, "_vertex_location", None)
|
||||
or os.environ.get("VERTEXAI_LOCATION")
|
||||
or os.environ.get("VERTEX_LOCATION")
|
||||
or getattr(litellm, "vertex_location", None)
|
||||
or get_secret_str("VERTEXAI_LOCATION")
|
||||
or get_secret_str("VERTEX_LOCATION")
|
||||
)
|
||||
|
||||
def _resolve_vertex_credentials(self) -> Optional[str]:
|
||||
return (
|
||||
getattr(self, "_vertex_credentials", None)
|
||||
or os.environ.get("VERTEXAI_CREDENTIALS")
|
||||
or getattr(litellm, "vertex_credentials", None)
|
||||
or os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
|
||||
or get_secret_str("VERTEXAI_CREDENTIALS")
|
||||
)
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
stream: Optional[bool] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Get the complete URL for Vertex AI Gemini generateContent API
|
||||
"""
|
||||
# Use the model name as provided, handling vertex_ai prefix
|
||||
model_name = model
|
||||
if model.startswith("vertex_ai/"):
|
||||
model_name = model.replace("vertex_ai/", "")
|
||||
|
||||
# If a custom api_base is provided, use it directly
|
||||
# This allows users to use proxies or mock endpoints
|
||||
if api_base:
|
||||
return api_base.rstrip("/")
|
||||
|
||||
# First check litellm_params (where vertex_ai_project/vertex_ai_location are passed)
|
||||
# then fall back to environment variables and other sources
|
||||
vertex_project = (
|
||||
self.safe_get_vertex_ai_project(litellm_params)
|
||||
or self._resolve_vertex_project()
|
||||
)
|
||||
vertex_location = (
|
||||
self.safe_get_vertex_ai_location(litellm_params)
|
||||
or self._resolve_vertex_location()
|
||||
)
|
||||
|
||||
if not vertex_project or not vertex_location:
|
||||
raise ValueError(
|
||||
"vertex_project and vertex_location are required for Vertex AI"
|
||||
)
|
||||
|
||||
base_url = get_vertex_base_url(vertex_location)
|
||||
|
||||
return f"{base_url}/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model_name}:generateContent"
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
headers = headers or {}
|
||||
|
||||
# If a custom api_base is provided, skip credential validation
|
||||
# This allows users to use proxies or mock endpoints without needing Vertex AI credentials
|
||||
_api_base = litellm_params.get("api_base") or api_base
|
||||
if _api_base is not None:
|
||||
return headers
|
||||
|
||||
# First check litellm_params (where vertex_ai_project/vertex_ai_credentials are passed)
|
||||
# then fall back to environment variables and other sources
|
||||
vertex_project = (
|
||||
self.safe_get_vertex_ai_project(litellm_params)
|
||||
or self._resolve_vertex_project()
|
||||
)
|
||||
vertex_credentials = (
|
||||
self.safe_get_vertex_ai_credentials(litellm_params)
|
||||
or self._resolve_vertex_credentials()
|
||||
)
|
||||
access_token, _ = self._ensure_access_token(
|
||||
credentials=vertex_credentials,
|
||||
project_id=vertex_project,
|
||||
custom_llm_provider="vertex_ai",
|
||||
)
|
||||
return self.set_headers(access_token, headers)
|
||||
|
||||
def transform_image_generation_request(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
"""
|
||||
Transform the image generation request to Gemini format
|
||||
|
||||
Uses generateContent API with responseModalities: ["IMAGE"]
|
||||
"""
|
||||
# Prepare messages with the prompt
|
||||
contents = [{"role": "user", "parts": [{"text": prompt}]}]
|
||||
|
||||
# Prepare generation config
|
||||
generation_config: Dict[str, Any] = {"responseModalities": ["IMAGE"]}
|
||||
|
||||
# Handle image-specific config parameters
|
||||
image_config: Dict[str, Any] = {}
|
||||
|
||||
# Map aspectRatio
|
||||
if "aspectRatio" in optional_params:
|
||||
image_config["aspectRatio"] = optional_params["aspectRatio"]
|
||||
elif "aspect_ratio" in optional_params:
|
||||
image_config["aspectRatio"] = optional_params["aspect_ratio"]
|
||||
|
||||
# Map imageSize (for Gemini 3 Pro)
|
||||
if "imageSize" in optional_params:
|
||||
image_config["imageSize"] = optional_params["imageSize"]
|
||||
elif "image_size" in optional_params:
|
||||
image_config["imageSize"] = optional_params["image_size"]
|
||||
|
||||
if image_config:
|
||||
generation_config["imageConfig"] = image_config
|
||||
|
||||
# Handle candidate_count (n parameter)
|
||||
if "candidate_count" in optional_params:
|
||||
generation_config["candidateCount"] = optional_params["candidate_count"]
|
||||
elif "n" in optional_params:
|
||||
generation_config["candidateCount"] = optional_params["n"]
|
||||
|
||||
request_body: Dict[str, Any] = {
|
||||
"contents": contents,
|
||||
"generationConfig": generation_config,
|
||||
}
|
||||
|
||||
return request_body
|
||||
|
||||
def _transform_image_usage(self, usage: dict) -> ImageUsage:
|
||||
input_tokens_details = ImageUsageInputTokensDetails(
|
||||
image_tokens=0,
|
||||
text_tokens=0,
|
||||
)
|
||||
tokens_details = usage.get("promptTokensDetails", [])
|
||||
for details in tokens_details:
|
||||
if isinstance(details, dict) and (modality := details.get("modality")):
|
||||
token_count = details.get("tokenCount", 0)
|
||||
if modality == "TEXT":
|
||||
input_tokens_details.text_tokens += token_count
|
||||
elif modality == "IMAGE":
|
||||
input_tokens_details.image_tokens += token_count
|
||||
|
||||
return ImageUsage(
|
||||
input_tokens=usage.get("promptTokenCount", 0),
|
||||
input_tokens_details=input_tokens_details,
|
||||
output_tokens=usage.get("candidatesTokenCount", 0),
|
||||
total_tokens=usage.get("totalTokenCount", 0),
|
||||
)
|
||||
|
||||
def transform_image_generation_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: ImageResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ImageResponse:
|
||||
"""
|
||||
Transform Gemini image generation response to litellm ImageResponse format
|
||||
"""
|
||||
try:
|
||||
response_data = raw_response.json()
|
||||
except Exception as e:
|
||||
raise self.get_error_class(
|
||||
error_message=f"Error transforming image generation response: {e}",
|
||||
status_code=raw_response.status_code,
|
||||
headers=raw_response.headers,
|
||||
)
|
||||
|
||||
if not model_response.data:
|
||||
model_response.data = []
|
||||
|
||||
# Gemini image generation models return in candidates format
|
||||
candidates = response_data.get("candidates", [])
|
||||
for candidate in candidates:
|
||||
content = candidate.get("content", {})
|
||||
parts = content.get("parts", [])
|
||||
for part in parts:
|
||||
# Look for inlineData with image
|
||||
if "inlineData" in part:
|
||||
inline_data = part["inlineData"]
|
||||
if "data" in inline_data:
|
||||
thought_sig = part.get("thoughtSignature")
|
||||
model_response.data.append(
|
||||
ImageObject(
|
||||
b64_json=inline_data["data"],
|
||||
url=None,
|
||||
provider_specific_fields={
|
||||
"thought_signature": thought_sig
|
||||
}
|
||||
if thought_sig
|
||||
else None,
|
||||
)
|
||||
)
|
||||
|
||||
if usage_metadata := response_data.get("usageMetadata", None):
|
||||
model_response.usage = self._transform_image_usage(usage_metadata)
|
||||
|
||||
return model_response
|
||||
@@ -0,0 +1,256 @@
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.llms.base_llm.image_generation.transformation import (
|
||||
BaseImageGenerationConfig,
|
||||
)
|
||||
from litellm.llms.vertex_ai.common_utils import get_vertex_base_url
|
||||
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
OpenAIImageGenerationOptionalParams,
|
||||
)
|
||||
from litellm.types.utils import ImageObject, ImageResponse
|
||||
|
||||
# Import the real logging class only for static type checkers; at runtime the
# alias is `Any`, so the logging module is not imported by this file.
if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class VertexAIImagenImageGenerationConfig(BaseImageGenerationConfig, VertexLLM):
    """
    Vertex AI Imagen Image Generation Configuration

    Uses predict API for Imagen models on Vertex AI
    Supports models like imagegeneration@006
    """

    def __init__(self) -> None:
        # Both bases are initialised explicitly rather than through super().
        BaseImageGenerationConfig.__init__(self)
        VertexLLM.__init__(self)

    def get_supported_openai_params(
        self, model: str
    ) -> List[OpenAIImageGenerationOptionalParams]:
        """
        Imagen API supported parameters

        ``model`` is not consulted.
        """
        return ["n", "size"]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        Translate caller-supplied parameters into Imagen parameter names.

        A parameter is mapped only when it is supported and not already a key
        of ``optional_params``; everything else is left out of the result.
        ``drop_params`` is accepted for interface compatibility and not read.

        Returns:
            A new dict holding only the mapped parameters.
        """
        supported_params = self.get_supported_openai_params(model)
        mapped_params: dict = {}

        for k, v in non_default_params.items():
            if k in optional_params or k not in supported_params:
                continue

            # Map OpenAI parameters to Imagen format
            if k == "n":
                mapped_params["sampleCount"] = v
            elif k == "size":
                # Map OpenAI size format to Imagen aspectRatio
                mapped_params["aspectRatio"] = self._map_size_to_aspect_ratio(v)
            else:
                # Only reachable if a subclass widens the supported list.
                mapped_params[k] = v

        return mapped_params

    def _map_size_to_aspect_ratio(self, size: str) -> str:
        """
        Map OpenAI size format to Imagen aspect ratio format

        Sizes missing from the table fall back to "1:1".
        """
        aspect_ratio_map = {
            "1024x1024": "1:1",
            "1792x1024": "16:9",
            "1024x1792": "9:16",
            "1280x896": "4:3",
            "896x1280": "3:4",
        }
        return aspect_ratio_map.get(size, "1:1")

    def _resolve_vertex_project(self) -> Optional[str]:
        """
        Return the first truthy project id among: the instance attribute
        ``_vertex_project``, the VERTEXAI_PROJECT environment variable,
        ``litellm.vertex_project`` and the VERTEXAI_PROJECT secret.
        """
        return (
            getattr(self, "_vertex_project", None)
            or os.environ.get("VERTEXAI_PROJECT")
            or getattr(litellm, "vertex_project", None)
            or get_secret_str("VERTEXAI_PROJECT")
        )

    def _resolve_vertex_location(self) -> Optional[str]:
        """
        Return the first truthy location among: the instance attribute
        ``_vertex_location``, the VERTEXAI_LOCATION / VERTEX_LOCATION
        environment variables, ``litellm.vertex_location`` and the
        VERTEXAI_LOCATION / VERTEX_LOCATION secrets.
        """
        return (
            getattr(self, "_vertex_location", None)
            or os.environ.get("VERTEXAI_LOCATION")
            or os.environ.get("VERTEX_LOCATION")
            or getattr(litellm, "vertex_location", None)
            or get_secret_str("VERTEXAI_LOCATION")
            or get_secret_str("VERTEX_LOCATION")
        )

    def _resolve_vertex_credentials(self) -> Optional[str]:
        """
        Return the first truthy credentials value among: the instance attribute
        ``_vertex_credentials``, the VERTEXAI_CREDENTIALS environment variable,
        ``litellm.vertex_credentials``, the GOOGLE_APPLICATION_CREDENTIALS
        environment variable and the VERTEXAI_CREDENTIALS secret.
        """
        return (
            getattr(self, "_vertex_credentials", None)
            or os.environ.get("VERTEXAI_CREDENTIALS")
            or getattr(litellm, "vertex_credentials", None)
            or os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
            or get_secret_str("VERTEXAI_CREDENTIALS")
        )

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Get the complete URL for Vertex AI Imagen predict API

        A non-empty ``api_base`` is returned as-is (minus trailing slashes) so
        proxies and mock endpoints can be used. Otherwise the URL is built from
        the resolved project and location.

        Raises:
            ValueError: if no project or no location can be resolved.
        """
        # If a custom api_base is provided, use it directly
        # This allows users to use proxies or mock endpoints
        if api_base:
            return api_base.rstrip("/")

        # Strip only a leading provider prefix; str.replace would also remove
        # any later occurrence of "vertex_ai/" inside the model name.
        prefix = "vertex_ai/"
        model_name = model[len(prefix):] if model.startswith(prefix) else model

        # First check litellm_params (where vertex_ai_project/vertex_ai_location are passed)
        # then fall back to environment variables and other sources
        vertex_project = (
            self.safe_get_vertex_ai_project(litellm_params)
            or self._resolve_vertex_project()
        )
        vertex_location = (
            self.safe_get_vertex_ai_location(litellm_params)
            or self._resolve_vertex_location()
        )

        if not vertex_project or not vertex_location:
            raise ValueError(
                "vertex_project and vertex_location are required for Vertex AI"
            )

        base_url = get_vertex_base_url(vertex_location)

        return f"{base_url}/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model_name}:predict"

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """
        Build the request headers, fetching a Vertex access token when needed.

        When a custom ``api_base`` is configured (in ``litellm_params`` or as
        the argument) the headers are returned untouched and no credentials
        are looked up. ``model``, ``messages``, ``optional_params`` and
        ``api_key`` are not read here.
        """
        headers = headers or {}

        # If a custom api_base is provided, skip credential validation
        # This allows users to use proxies or mock endpoints without needing Vertex AI credentials
        _api_base = litellm_params.get("api_base") or api_base
        # Truthiness (not `is not None`) keeps this in step with
        # get_complete_url: an empty api_base still targets Vertex and so
        # still needs an access token.
        if _api_base:
            return headers

        # First check litellm_params (where vertex_ai_project/vertex_ai_credentials are passed)
        # then fall back to environment variables and other sources
        vertex_project = (
            self.safe_get_vertex_ai_project(litellm_params)
            or self._resolve_vertex_project()
        )
        vertex_credentials = (
            self.safe_get_vertex_ai_credentials(litellm_params)
            or self._resolve_vertex_credentials()
        )
        access_token, _ = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai",
        )
        return self.set_headers(access_token, headers)

    def transform_image_generation_request(
        self,
        model: str,
        prompt: str,
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Transform the image generation request to Imagen format

        Uses predict API with instances and parameters. ``sampleCount``
        defaults to 1 and is overridden by any value in ``optional_params``.
        """
        # Defaults first so that caller-supplied values take precedence.
        parameters = {"sampleCount": 1, **optional_params}

        return {
            "instances": [{"prompt": prompt}],
            "parameters": parameters,
        }

    def transform_image_generation_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ImageResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ImageResponse:
        """
        Transform Imagen image generation response to litellm ImageResponse format

        Every prediction carrying ``bytesBase64Encoded`` becomes one
        ImageObject appended to ``model_response.data``.

        Raises:
            The exception built by ``self.get_error_class`` when the response
            body cannot be decoded as JSON.
        """
        try:
            response_data = raw_response.json()
        except Exception as e:
            raise self.get_error_class(
                error_message=f"Error transforming image generation response: {e}",
                status_code=raw_response.status_code,
                headers=raw_response.headers,
            ) from e

        if not model_response.data:
            model_response.data = []

        # Imagen format - predictions with generated images
        for prediction in response_data.get("predictions") or []:
            # Imagen returns images as bytesBase64Encoded
            if "bytesBase64Encoded" in prediction:
                model_response.data.append(
                    ImageObject(
                        b64_json=prediction["bytesBase64Encoded"],
                        url=None,
                    )
                )

        return model_response
|
||||
@@ -0,0 +1,188 @@
|
||||
import json
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
|
||||
VertexAIError,
|
||||
VertexLLM,
|
||||
)
|
||||
from litellm.types.utils import EmbeddingResponse
|
||||
|
||||
from .transformation import VertexAIMultimodalEmbeddingConfig
|
||||
|
||||
# Module-level config instance that VertexMultimodalEmbedding uses to build
# request bodies and headers and to parse embedding responses.
vertex_multimodal_embedding_handler = VertexAIMultimodalEmbeddingConfig()
|
||||
|
||||
|
||||
class VertexMultimodalEmbedding(VertexLLM):
    """HTTP handler that sends multimodal embedding requests to Vertex AI."""

    def __init__(self) -> None:
        super().__init__()
        self.SUPPORTED_MULTIMODAL_EMBEDDING_MODELS = [
            "multimodalembedding",
            "multimodalembedding@001",
        ]

    def _get_sync_handler(self, timeout, client) -> HTTPHandler:
        """Return ``client`` if given, otherwise a new HTTPHandler honouring ``timeout``."""
        if client is not None:
            return client  # type: ignore

        _params = {}
        if timeout is None:
            _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
        elif isinstance(timeout, (float, int)):
            _params["timeout"] = httpx.Timeout(timeout)
        elif isinstance(timeout, httpx.Timeout):
            # A ready-made Timeout object used to be dropped silently.
            _params["timeout"] = timeout

        return HTTPHandler(**_params)  # type: ignore

    def multimodal_embedding(
        self,
        model: str,
        input: Union[list, str],
        print_verbose,
        model_response: EmbeddingResponse,
        custom_llm_provider: Literal["gemini", "vertex_ai"],
        optional_params: dict,
        litellm_params: dict,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        headers: Optional[dict] = None,
        encoding=None,
        vertex_project=None,
        vertex_location=None,
        vertex_credentials=None,
        aembedding: Optional[bool] = False,
        timeout=300,
        client=None,
    ) -> EmbeddingResponse:
        """
        Send a multimodal embedding request and return the parsed response.

        Resolves the access token and endpoint URL, builds the request body and
        headers through ``vertex_multimodal_embedding_handler``, logs the
        pre-call details, then posts the request.

        When ``aembedding`` is True the coroutine from
        ``async_multimodal_embedding`` is returned instead of a response, and
        the caller is expected to await it.

        ``headers`` is updated in place when supplied; when omitted a fresh
        dict is used for each call. ``print_verbose`` and ``encoding`` are not
        read here.
        """
        # A `{}` default would be shared between calls, and
        # validate_environment below writes the Authorization header into it.
        if headers is None:
            headers = {}

        _auth_header, vertex_project = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider=custom_llm_provider,
        )

        auth_header, url = self._get_token_and_url(
            model=model,
            auth_header=_auth_header,
            gemini_api_key=api_key,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_credentials=vertex_credentials,
            stream=None,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            should_use_v1beta1_features=False,
            mode="embedding",
        )

        request_data = vertex_multimodal_embedding_handler.transform_embedding_request(
            model, input, optional_params, headers
        )

        headers = vertex_multimodal_embedding_handler.validate_environment(
            headers=headers,
            model=model,
            messages=[],
            optional_params=optional_params,
            api_key=auth_header,
            api_base=api_base,
            litellm_params=litellm_params,
        )

        ## LOGGING
        logging_obj.pre_call(
            input=input,
            api_key="",
            additional_args={
                "complete_input_dict": request_data,
                "api_base": url,
                "headers": headers,
            },
        )

        if aembedding is True:
            return self.async_multimodal_embedding(  # type: ignore
                model=model,
                api_base=url,
                data=request_data,
                timeout=timeout,
                headers=headers,
                client=client,
                model_response=model_response,
                optional_params=optional_params,
                litellm_params=litellm_params,
                logging_obj=logging_obj,
                api_key=api_key,
            )

        # Built only on the sync path so async calls do not allocate an
        # unused sync HTTP client.
        sync_handler = self._get_sync_handler(timeout=timeout, client=client)

        response = sync_handler.post(
            url=url,
            headers=headers,
            data=json.dumps(request_data),
        )

        return vertex_multimodal_embedding_handler.transform_embedding_response(
            model=model,
            raw_response=response,
            model_response=model_response,
            logging_obj=logging_obj,
            api_key=api_key,
            request_data=request_data,
            optional_params=optional_params,
            litellm_params=litellm_params,
        )

    async def async_multimodal_embedding(
        self,
        model: str,
        api_base: str,
        optional_params: dict,
        litellm_params: dict,
        data: dict,
        model_response: EmbeddingResponse,
        timeout: Optional[Union[float, httpx.Timeout]],
        logging_obj: LiteLLMLoggingObj,
        headers: Optional[dict] = None,
        client: Optional[AsyncHTTPHandler] = None,
        api_key: Optional[str] = None,
    ) -> EmbeddingResponse:
        """
        Async variant: post ``data`` to ``api_base`` and parse the response.

        Raises:
            VertexAIError: with the upstream status code on an HTTP error
                status, or with status 408 on a timeout.
        """
        if headers is None:
            headers = {}

        if client is None:
            # Numeric timeouts are wrapped; None and httpx.Timeout pass through.
            if isinstance(timeout, (float, int)):
                timeout = httpx.Timeout(timeout)
            client = get_async_httpx_client(
                llm_provider=litellm.LlmProviders.VERTEX_AI,
                params={"timeout": timeout},
            )

        try:
            response = await client.post(api_base, headers=headers, json=data)  # type: ignore
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            raise VertexAIError(
                status_code=err.response.status_code, message=err.response.text
            ) from err
        except httpx.TimeoutException as err:
            raise VertexAIError(
                status_code=408, message="Timeout error occurred."
            ) from err

        return vertex_multimodal_embedding_handler.transform_embedding_response(
            model=model,
            raw_response=response,
            model_response=model_response,
            logging_obj=logging_obj,
            api_key=api_key,
            request_data=data,
            optional_params=optional_params,
            litellm_params=litellm_params,
        )
|
||||
@@ -0,0 +1,325 @@
|
||||
from typing import List, Optional, Union, cast
|
||||
|
||||
from httpx import Headers, Response
|
||||
|
||||
from litellm.exceptions import InternalServerError
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.llms.base_llm.embedding.transformation import LiteLLMLoggingObj
|
||||
from litellm.types.llms.openai import AllEmbeddingInputValues, AllMessageValues
|
||||
from litellm.types.llms.vertex_ai import (
|
||||
Instance,
|
||||
InstanceImage,
|
||||
InstanceVideo,
|
||||
MultimodalPredictions,
|
||||
VertexMultimodalEmbeddingRequest,
|
||||
)
|
||||
from litellm.types.utils import (
|
||||
Embedding,
|
||||
EmbeddingResponse,
|
||||
PromptTokensDetailsWrapper,
|
||||
Usage,
|
||||
)
|
||||
from litellm.utils import _count_characters, is_base64_encoded
|
||||
|
||||
from ...base_llm.embedding.transformation import BaseEmbeddingConfig
|
||||
from ..common_utils import VertexAIError
|
||||
|
||||
|
||||
class VertexAIMultimodalEmbeddingConfig(BaseEmbeddingConfig):
|
||||
def get_supported_openai_params(self, model: str) -> list:
    """Return the OpenAI embedding parameters accepted here.

    Only ``dimensions`` is supported; ``model`` is not consulted.
    """
    return ["dimensions"]
|
||||
|
||||
def map_openai_params(
    self,
    non_default_params: dict,
    optional_params: dict,
    model: str,
    drop_params: bool,
) -> dict:
    """Copy the OpenAI ``dimensions`` value into ``optional_params``.

    ``dimensions`` is stored under the key ``outputDimensionality``; every
    other entry of ``non_default_params`` is ignored. ``model`` and
    ``drop_params`` are not read.

    Returns:
        The same ``optional_params`` dict, updated in place.
    """
    if "dimensions" in non_default_params:
        optional_params["outputDimensionality"] = non_default_params["dimensions"]
    return optional_params
|
||||
|
||||
def validate_environment(
    self,
    headers: dict,
    model: str,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
) -> dict:
    """Add the JSON content type and bearer token to ``headers``.

    ``headers`` is updated in place and returned; existing ``Content-Type``
    and ``Authorization`` entries are overwritten. ``api_key`` is inserted
    verbatim after ``Bearer``. The remaining arguments are not read.
    """
    # Tolerate a missing headers dict instead of failing on `.update`.
    if headers is None:
        headers = {}

    headers.update(
        {
            "Content-Type": "application/json; charset=utf-8",
            "Authorization": f"Bearer {api_key}",
        }
    )
    return headers
|
||||
|
||||
def _is_gcs_uri(self, input_str: str) -> bool:
|
||||
"""Check if the input string is a GCS URI."""
|
||||
return "gs://" in input_str
|
||||
|
||||
def _is_video(self, input_str: str) -> bool:
|
||||
"""Check if the input string represents a video (mp4)."""
|
||||
return "mp4" in input_str
|
||||
|
||||
def _is_media_input(self, input_str: str) -> bool:
    """Check if the input string is a media element (GCS URI or base64 image)."""
    if self._is_gcs_uri(input_str):
        return True
    return is_base64_encoded(s=input_str)
|
||||
|
||||
def _create_image_instance(self, input_str: str) -> InstanceImage:
    """Create an InstanceImage from a GCS URI or base64 string.

    For non-GCS input containing a comma (e.g. a ``data:...;base64,`` URI),
    only the segment after the first comma is sent as the payload.
    """
    if self._is_gcs_uri(input_str):
        return InstanceImage(gcsUri=input_str)

    payload = input_str.split(",")[1] if "," in input_str else input_str
    return InstanceImage(bytesBase64Encoded=payload)
|
||||
|
||||
def _create_video_instance(self, input_str: str) -> InstanceVideo:
    """Create an InstanceVideo from a GCS URI."""
    return InstanceVideo(gcsUri=input_str)
|
||||
|
||||
def _process_input_element(self, input_element: str) -> Instance:
    """
    Process a single input element for multimodal embedding requests.
    Detects if the input is a GCS URI, base64 encoded image, or plain text.

    Args:
        input_element (str): The input element to process.

    Returns:
        Instance: A dictionary representing the processed input element.
    """
    # Empty input is passed through as text.
    if not input_element:
        return Instance(text=input_element)

    if self._is_gcs_uri(input_element):
        if self._is_video(input_element):
            return Instance(video=self._create_video_instance(input_element))
        return Instance(image=self._create_image_instance(input_element))

    if is_base64_encoded(s=input_element):
        return Instance(image=self._create_image_instance(input_element))

    return Instance(text=input_element)
|
||||
|
||||
def _try_merge_text_with_media(
    self, text_str: str, next_elem: Optional[str]
) -> tuple[Instance, bool]:
    """
    Try to merge a text element with a following media element into a single instance.

    Args:
        text_str: The text string to potentially merge.
        next_elem: The next element in the input list (may be media).

    Returns:
        A tuple of (Instance, consumed_next) where consumed_next indicates
        if the next element was merged into this instance.
    """
    instance_args: Instance = {"text": text_str}

    # Nothing to merge unless the next element is a non-empty media string.
    if not (
        next_elem and isinstance(next_elem, str) and self._is_media_input(next_elem)
    ):
        return instance_args, False

    if self._is_gcs_uri(next_elem) and self._is_video(next_elem):
        instance_args["video"] = self._create_video_instance(next_elem)
    else:
        instance_args["image"] = self._create_image_instance(next_elem)
    return instance_args, True
|
||||
|
||||
def process_openai_embedding_input(
    self, _input: Union[list, str]
) -> List[Instance]:
    """
    Process the input for multimodal embedding requests.

    A text string directly followed by a media string (GCS URI or base64)
    is merged with it into one instance; media strings are emitted on their
    own, and dict elements are passed through as instances.

    Args:
        _input (Union[list, str]): The input data to process.

    Returns:
        List[Instance]: List of Instance objects for the embedding request.

    Raises:
        ValueError: if an element is neither a str nor a dict.
    """
    _input_list = _input if isinstance(_input, list) else [_input]
    processed_instances: List[Instance] = []

    i = 0
    total = len(_input_list)
    while i < total:
        current = _input_list[i]

        if isinstance(current, dict):
            processed_instances.append(Instance(**current))
            i += 1
            continue

        if not isinstance(current, str):
            raise ValueError(f"Unsupported input type: {type(current)}")

        if self._is_media_input(current):
            # Current element is media - process it standalone
            processed_instances.append(self._process_input_element(current))
            i += 1
            continue

        # Current element is text - try to merge with next media element
        next_elem = _input_list[i + 1] if i + 1 < total else None
        instance, consumed_next = self._try_merge_text_with_media(
            text_str=current, next_elem=next_elem
        )
        processed_instances.append(instance)
        i += 2 if consumed_next else 1

    return processed_instances
|
||||
|
||||
def transform_embedding_request(
    self,
    model: str,
    input: AllEmbeddingInputValues,
    optional_params: dict,
    headers: dict,
) -> dict:
    """Build the Vertex multimodal embedding request body.

    Instances are chosen in this order:
      1. ``optional_params["instances"]`` when present (used verbatim);
      2. a list ``input``, converted by ``process_openai_embedding_input``;
      3. a string ``input``, converted by ``_process_input_element``;
      4. otherwise a single instance built from ``optional_params``.

    ``model`` and ``headers`` are not read here.
    """
    optional_params = optional_params or {}

    request_data = VertexMultimodalEmbeddingRequest(instances=[])

    if "instances" in optional_params:
        request_data["instances"] = optional_params["instances"]
    elif isinstance(input, list):
        vertex_instances: List[Instance] = self.process_openai_embedding_input(
            _input=input
        )
        request_data["instances"] = vertex_instances
    elif isinstance(input, str):
        request_data["instances"] = [self._process_input_element(input)]
    else:
        # Neither list nor str: construct the instance from optional params.
        request_data["instances"] = [Instance(**optional_params)]

    return cast(dict, request_data)
|
||||
|
||||
def transform_embedding_response(
    self,
    model: str,
    raw_response: Response,
    model_response: EmbeddingResponse,
    logging_obj: LiteLLMLoggingObj,
    api_key: Optional[str],
    request_data: dict,
    optional_params: dict,
    litellm_params: dict,
) -> EmbeddingResponse:
    """Populate ``model_response`` from a Vertex multimodal embedding response.

    Sets ``data`` (converted predictions), ``model`` and ``usage`` on
    ``model_response`` and returns it.

    Raises:
        VertexAIError: if the HTTP status is not 200; carries the status code.
        InternalServerError: if the body has no ``predictions`` key.
    """
    if raw_response.status_code != 200:
        # A typed error keeps the upstream status code available to callers;
        # the message text is the same as before.
        raise VertexAIError(
            status_code=raw_response.status_code,
            message=f"Error: {raw_response.status_code} {raw_response.text}",
        )

    _json_response = raw_response.json()
    if "predictions" not in _json_response:
        raise InternalServerError(
            message=f"embedding response does not contain 'predictions', got {_json_response}",
            llm_provider="vertex_ai",
            model=model,
        )

    vertex_predictions = MultimodalPredictions(
        predictions=_json_response["predictions"]
    )
    model_response.data = self.transform_embedding_response_to_openai(
        predictions=vertex_predictions
    )
    model_response.model = model
    model_response.usage = self.calculate_usage(
        request_data=cast(VertexMultimodalEmbeddingRequest, request_data),
        vertex_predictions=vertex_predictions,
    )

    return model_response
|
||||
|
||||
def calculate_usage(
    self,
    request_data: VertexMultimodalEmbeddingRequest,
    vertex_predictions: MultimodalPredictions,
) -> Usage:
    """
    Build a ``Usage`` object describing a multimodal embedding call.

    Token counts are always reported as 0; the meaningful numbers live in
    ``prompt_tokens_details``:

    - ``character_count``: result of ``_count_characters`` on the
      concatenation of every non-empty instance ``text``, or ``None`` when
      no instance carries text.
    - ``image_count``: number of instances with a truthy ``image`` entry.
    - ``video_length_seconds``: sum of ``endOffsetSec - startOffsetSec`` over
      every entry in each prediction's ``videoEmbeddings``.

    Args:
        request_data: Request body; only ``instances`` is read.
        vertex_predictions: Parsed response; only ``predictions`` is read.

    Returns:
        ``Usage`` with zeroed token counts and populated prompt details.
    """
    text_parts = []
    image_total = 0
    for item in request_data["instances"]:
        piece = item.get("text")
        if piece:
            text_parts.append(piece)
        if item.get("image"):
            image_total += 1

    chars = _count_characters("".join(text_parts)) if text_parts else None

    # Accumulated with a plain running total (not sum()) so the float
    # addition order and rounding stay exactly as a naive loop produces.
    video_seconds = 0.0
    for prediction in vertex_predictions["predictions"]:
        for segment in prediction.get("videoEmbeddings") or ():
            video_seconds += segment["endOffsetSec"] - segment["startOffsetSec"]

    details = PromptTokensDetailsWrapper(
        character_count=chars,
        image_count=image_total,
        video_length_seconds=video_seconds,
    )
    return Usage(
        prompt_tokens=0,
        completion_tokens=0,
        total_tokens=0,
        prompt_tokens_details=details,
    )
|
||||
|
||||
def transform_embedding_response_to_openai(
    self, predictions: MultimodalPredictions
) -> List[Embedding]:
    """
    Convert Vertex multimodal predictions into OpenAI-style ``Embedding`` items.

    For each non-empty prediction exactly one kind is used, checked in this
    order: ``textEmbedding``, ``imageEmbedding``, ``videoEmbeddings``. A
    video prediction yields one ``Embedding`` per entry in
    ``videoEmbeddings``; all of them carry the index of the prediction they
    came from.

    Args:
        predictions: Mapping that may contain a ``predictions`` list.

    Returns:
        List of ``Embedding`` objects; empty when there is nothing to convert.
    """
    converted: List[Embedding] = []
    if "predictions" not in predictions:
        return converted

    for position, item in enumerate(predictions["predictions"]):
        if not item:
            continue

        if "textEmbedding" in item:
            vectors = [item["textEmbedding"]]
        elif "imageEmbedding" in item:
            vectors = [item["imageEmbedding"]]
        elif "videoEmbeddings" in item:
            vectors = [segment["embedding"] for segment in item["videoEmbeddings"]]
        else:
            vectors = []

        for vector in vectors:
            converted.append(
                Embedding(embedding=vector, index=position, object="embedding")
            )

    return converted
|
||||
|
||||
def get_error_class(
    self, error_message: str, status_code: int, headers: Union[dict, Headers]
) -> BaseLLMException:
    """
    Wrap an error message, status code and headers in a ``VertexAIError``.

    Args:
        error_message: Text passed to the exception as ``message``.
        status_code: Status code passed through to the exception.
        headers: Headers passed through to the exception.

    Returns:
        A new ``VertexAIError`` instance (not raised here).
    """
    error = VertexAIError(
        status_code=status_code,
        message=error_message,
        headers=headers,
    )
    return error
|
||||
@@ -0,0 +1,4 @@
|
||||
"""Vertex AI OCR module."""
|
||||
from .transformation import VertexAIOCRConfig
|
||||
|
||||
__all__ = ["VertexAIOCRConfig"]
|
||||
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
Common utilities for Vertex AI OCR providers.
|
||||
|
||||
This module provides routing logic to determine which OCR configuration to use
|
||||
based on the model name.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.llms.base_llm.ocr.transformation import BaseOCRConfig
|
||||
|
||||
|
||||
def get_vertex_ai_ocr_config(model: str) -> Optional["BaseOCRConfig"]:
    """
    Pick the Vertex AI OCR configuration class for a model name.

    Selection is a plain substring test: any model name containing
    ``"deepseek"`` gets ``VertexAIDeepSeekOCRConfig``; every other name gets
    ``VertexAIOCRConfig``. As written, a config instance is always returned.

    Args:
        model: Model name, e.g. ``"vertex_ai/deepseek-ai/deepseek-ocr-maas"``
            or ``"vertex_ai/ocr/mistral-ocr-maas"``.

    Returns:
        A freshly constructed OCR configuration instance.
    """
    # Imported inside the function, as in the rest of this module's lazy
    # loading, rather than at module import time.
    from litellm.llms.vertex_ai.ocr.deepseek_transformation import (
        VertexAIDeepSeekOCRConfig,
    )
    from litellm.llms.vertex_ai.ocr.transformation import VertexAIOCRConfig

    config_cls = VertexAIDeepSeekOCRConfig if "deepseek" in model else VertexAIOCRConfig
    return config_cls()
|
||||
@@ -0,0 +1,394 @@
|
||||
"""
|
||||
Vertex AI DeepSeek OCR transformation implementation.
|
||||
"""
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.base_llm.ocr.transformation import (
|
||||
BaseOCRConfig,
|
||||
DocumentType,
|
||||
OCRPage,
|
||||
OCRRequestData,
|
||||
OCRResponse,
|
||||
OCRUsageInfo,
|
||||
)
|
||||
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class VertexAIDeepSeekOCRConfig(BaseOCRConfig):
    """
    Vertex AI DeepSeek OCR transformation configuration.

    DeepSeek OCR on Vertex AI is called through the chat completion route of
    the ``openapi`` endpoint. This class rewrites an OCR request as a chat
    completion request and maps the chat completion response back onto
    ``OCRResponse``.
    """

    def __init__(self) -> None:
        super().__init__()
        # Used by validate_environment to turn credentials into an access token.
        self.vertex_base = VertexBase()

    def validate_environment(
        self,
        headers: Dict,
        model: str,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        litellm_params: Optional[dict] = None,
        **kwargs,
    ) -> Dict:
        """
        Build the request headers, including a Bearer access token.

        Args:
            headers: Caller-supplied headers; they override the defaults set here.
            model: Not used by this method.
            api_key: Not used by this method.
            api_base: Not used by this method.
            litellm_params: Source of the Vertex project and credentials.
            **kwargs: Ignored.

        Returns:
            New dict with ``Authorization`` and ``Content-Type`` plus ``headers``.
        """
        # The safe_get_* helpers read from litellm_params without mutating it.
        litellm_params = litellm_params or {}

        vertex_project = VertexBase.safe_get_vertex_ai_project(
            litellm_params=litellm_params
        )
        vertex_credentials = VertexBase.safe_get_vertex_ai_credentials(
            litellm_params=litellm_params
        )

        # The second element returned with the token is not needed for headers.
        access_token, _ = self.vertex_base.get_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
        )

        # ``headers`` is spread last so caller-supplied values win on conflict.
        return {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
            **headers,
        }

    def get_complete_url(
        self,
        api_base: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: Optional[dict] = None,
        **kwargs,
    ) -> str:
        """
        Build the chat completions URL for Vertex AI DeepSeek OCR.

        Resulting format:
        ``{api_base}/v1/projects/{project}/locations/{location}/endpoints/openapi/chat/completions``

        When ``api_base`` is ``None`` the non-regional host
        ``https://aiplatform.googleapis.com`` is used; the location then only
        appears in the path.

        Args:
            api_base: Base URL override; a trailing slash is stripped.
            model: Not used in URL construction.
            optional_params: Not used by this method.
            litellm_params: Source of ``vertex_project`` and ``vertex_location``.
            **kwargs: Ignored.

        Returns:
            Complete endpoint URL.

        Raises:
            ValueError: If no Vertex project can be resolved.
        """
        # The safe_get_* helpers read from litellm_params without mutating it.
        litellm_params = litellm_params or {}

        vertex_project = VertexBase.safe_get_vertex_ai_project(
            litellm_params=litellm_params
        )
        vertex_location = VertexBase.safe_get_vertex_ai_location(
            litellm_params=litellm_params
        )

        if vertex_project is None:
            raise ValueError(
                "Missing vertex_project - Set VERTEXAI_PROJECT environment variable or pass vertex_project parameter"
            )

        if vertex_location is None:
            vertex_location = "us-central1"

        if api_base is None:
            api_base = "https://aiplatform.googleapis.com"

        # Ensure no trailing slash before the path is appended.
        api_base = api_base.rstrip("/")

        return f"{api_base}/v1/projects/{vertex_project}/locations/{vertex_location}/endpoints/openapi/chat/completions"

    def transform_ocr_request(
        self,
        model: str,
        document: DocumentType,
        optional_params: dict,
        headers: dict,
        **kwargs,
    ) -> OCRRequestData:
        """
        Transform an OCR request into a chat completion request body.

        - Input: ``{"type": "image_url", "image_url": "gs://..."}``
        - Output: ``{"model": "deepseek-ai/<model>", "messages": [{"role": "user",
          "content": [{"type": "image_url", "image_url": "gs://..."}]}]}``

        Both ``image_url`` and ``document_url`` documents are sent as an
        ``image_url`` content part.

        Args:
            model: Model name. ``"deepseek-ai/"`` is always prepended, so the
                value is presumably expected without that publisher prefix
                (confirm against the caller).
            document: Document dict with ``type`` and a matching URL key.
            optional_params: Only ``stream``, ``temperature``, ``max_tokens``,
                ``top_p``, ``n`` and ``stop`` are forwarded.
            headers: Not used by this method.
            **kwargs: Ignored.

        Returns:
            ``OCRRequestData`` whose ``data`` is the chat completion body.

        Raises:
            ValueError: If ``document`` is not a dict, its ``type`` is
                unsupported, or the URL for that type is missing or empty.
        """
        verbose_logger.debug(
            "Vertex AI DeepSeek OCR transform_ocr_request (sync) called"
        )

        if not isinstance(document, dict):
            raise ValueError(f"Expected document dict, got {type(document)}")

        doc_type = document.get("type")
        if doc_type not in ("image_url", "document_url"):
            raise ValueError(
                f"Unsupported document type: {doc_type}. Expected 'image_url' or 'document_url'"
            )

        # The URL is stored under the key named after the document type.
        source_url = document.get(doc_type, "")
        if not source_url:
            # Fail early instead of sending an empty content part to the API.
            raise ValueError(
                f"Document of type '{doc_type}' is missing a non-empty '{doc_type}' value"
            )

        data = {
            "model": "deepseek-ai/" + model,
            "messages": [
                {
                    "role": "user",
                    "content": [{"type": "image_url", "image_url": source_url}],
                }
            ],
        }

        # Forward only chat completion params; OCR-specific ones are dropped.
        forwarded_keys = ("stream", "temperature", "max_tokens", "top_p", "n", "stop")
        data.update(
            {key: value for key, value in optional_params.items() if key in forwarded_keys}
        )

        verbose_logger.debug(
            "Vertex AI DeepSeek OCR: Transformed request to chat completion format"
        )

        return OCRRequestData(data=data, files=None)

    async def async_transform_ocr_request(
        self,
        model: str,
        document: DocumentType,
        optional_params: dict,
        headers: dict,
        **kwargs,
    ) -> OCRRequestData:
        """
        Async variant of ``transform_ocr_request``.

        Delegates to the sync method; no awaiting is involved.

        Args:
            model: Model name.
            document: Document dict from the user.
            optional_params: Already mapped optional parameters.
            headers: Request headers.
            **kwargs: Forwarded to the sync method.

        Returns:
            ``OCRRequestData`` with the chat completion body.
        """
        return self.transform_ocr_request(
            model=model,
            document=document,
            optional_params=optional_params,
            headers=headers,
            **kwargs,
        )

    @staticmethod
    def _single_page_payload(content: Any, model: str, usage: Any) -> Dict[str, Any]:
        """Wrap ``content`` as the markdown of a single page at index 0."""
        return {
            "pages": [{"index": 0, "markdown": content}],
            "model": model,
            "usage_info": usage,
        }

    def _content_to_ocr_payload(
        self, content: Any, model: str, chat_usage: Any
    ) -> Dict[str, Any]:
        """Turn chat completion message content into an OCR payload with ``pages``."""
        if isinstance(content, str) and content.strip().startswith("{"):
            try:
                payload = json.loads(content)
            except json.JSONDecodeError:
                # Looked like JSON but is not; treat the text as markdown.
                return self._single_page_payload(content, model, chat_usage)
        elif isinstance(content, dict):
            payload = content
        else:
            # Plain text (or any other non-dict value) becomes one markdown page.
            return self._single_page_payload(content, model, chat_usage)

        if "pages" in payload:
            return payload

        # Structured content without pages: keep its model/usage_info if
        # present and expose the whole content as one markdown page.
        return self._single_page_payload(
            content if isinstance(content, str) else json.dumps(content),
            payload.get("model", model),
            payload.get("usage_info", chat_usage),
        )

    def transform_ocr_response(
        self,
        model: str,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        **kwargs,
    ) -> OCRResponse:
        """
        Transform a chat completion response into ``OCRResponse``.

        Expected response shape::

            {
                "id": "...",
                "object": "chat.completion",
                "choices": [{
                    "message": {
                        "role": "assistant",
                        "content": "<OCR result as JSON string or markdown>"
                    }
                }],
                "usage": {...}
            }

        The content of the first choice is used. A JSON object string (or a
        dict) is taken as the OCR payload; anything else, or JSON that fails
        to parse, becomes a single markdown page. The response-level
        ``usage`` is used as ``usage_info`` whenever the payload does not
        supply its own.

        Args:
            model: Model name; used when the payload does not name one.
            raw_response: Raw HTTP response from Vertex AI.
            logging_obj: Not used by this method.
            **kwargs: Ignored.

        Returns:
            ``OCRResponse`` with at least one page.

        Raises:
            ValueError: If the response has no choices or no content.
            Exception: Any other parsing error is logged and re-raised.
        """
        verbose_logger.debug("Vertex AI DeepSeek OCR transform_ocr_response called")
        verbose_logger.debug(f"Raw response: {raw_response.text}")

        try:
            response_json = raw_response.json()

            choices = response_json.get("choices", [])
            if not choices:
                raise ValueError("No choices in chat completion response")

            message = choices[0].get("message", {})
            content = message.get("content", "")
            if not content:
                raise ValueError("No content in chat completion response")

            ocr_data = self._content_to_ocr_payload(
                content, model, response_json.get("usage", {})
            )

            # Only a dict-shaped usage_info is converted; other values are dropped.
            usage_info = None
            usage_dict = ocr_data.get("usage_info") if "usage_info" in ocr_data else None
            if isinstance(usage_dict, dict):
                usage_info = OCRUsageInfo(**usage_dict)

            # Non-dict page entries are skipped.
            pages = [
                OCRPage(
                    index=page_data.get("index", 0),
                    markdown=page_data.get("markdown", ""),
                    images=page_data.get("images"),
                    dimensions=page_data.get("dimensions"),
                )
                for page_data in ocr_data.get("pages", [])
                if isinstance(page_data, dict)
            ]

            if not pages:
                # Guarantee one page even when the payload's pages were unusable.
                pages = [
                    OCRPage(
                        index=0, markdown=content if isinstance(content, str) else ""
                    )
                ]

            return OCRResponse(
                pages=pages,
                model=ocr_data.get("model", model),
                document_annotation=ocr_data.get("document_annotation"),
                usage_info=usage_info,
                object="ocr",
            )

        except Exception as e:
            verbose_logger.error(f"Error parsing Vertex AI DeepSeek OCR response: {e}")
            # Bare raise keeps the original exception and traceback intact.
            raise

    async def async_transform_ocr_response(
        self,
        model: str,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        **kwargs,
    ) -> OCRResponse:
        """
        Async variant of ``transform_ocr_response``.

        Delegates to the sync method; no awaiting is involved.

        Args:
            model: Model name.
            raw_response: Raw HTTP response.
            logging_obj: Logging object.
            **kwargs: Forwarded to the sync method.

        Returns:
            ``OCRResponse`` in standard format.
        """
        return self.transform_ocr_response(
            model=model,
            raw_response=raw_response,
            logging_obj=logging_obj,
            **kwargs,
        )
|
||||
@@ -0,0 +1,301 @@
|
||||
"""
|
||||
Vertex AI Mistral OCR transformation implementation.
|
||||
"""
|
||||
from typing import Dict, Optional
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.prompt_templates.image_handling import (
|
||||
async_convert_url_to_base64,
|
||||
convert_url_to_base64,
|
||||
)
|
||||
from litellm.llms.base_llm.ocr.transformation import DocumentType, OCRRequestData
|
||||
from litellm.llms.mistral.ocr.transformation import MistralOCRConfig
|
||||
from litellm.llms.vertex_ai.common_utils import get_vertex_base_url
|
||||
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
|
||||
|
||||
|
||||
class VertexAIOCRConfig(MistralOCRConfig):
    """
    OCR configuration for Mistral OCR models served through Vertex AI.

    Request building is inherited from ``MistralOCRConfig``; this class adds
    Vertex AI authentication headers, the Mistral publisher ``rawPredict``
    URL, and conversion of document/image URLs to base64 data URIs.

    Vertex AI OCR only accepts base64 data URIs (``data:image/...``,
    ``data:application/pdf;base64,...``). Any URL value that does not already
    start with ``data:`` is fetched and converted before the request is built.
    """

    def __init__(self) -> None:
        super().__init__()
        # Used by validate_environment to turn credentials into an access token.
        self.vertex_base = VertexBase()

    def validate_environment(
        self,
        headers: Dict,
        model: str,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        litellm_params: Optional[dict] = None,
        **kwargs,
    ) -> Dict:
        """
        Build the request headers, including a Bearer access token.

        Args:
            headers: Caller-supplied headers; they override the defaults set here.
            model: Not used by this method.
            api_key: Not used by this method.
            api_base: Not used by this method.
            litellm_params: Source of the Vertex project and credentials.
            **kwargs: Ignored.

        Returns:
            New dict with ``Authorization`` and ``Content-Type`` plus ``headers``.
        """
        # The safe_get_* helpers read from the params dict without mutating it.
        params = litellm_params or {}

        project = VertexBase.safe_get_vertex_ai_project(litellm_params=params)
        credentials = VertexBase.safe_get_vertex_ai_credentials(litellm_params=params)

        token, _resolved_project = self.vertex_base.get_access_token(
            credentials=credentials,
            project_id=project,
        )

        merged = {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        }
        # Applied last so caller-supplied headers win on conflict.
        merged.update(headers)
        return merged

    def get_complete_url(
        self,
        api_base: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: Optional[dict] = None,
        **kwargs,
    ) -> str:
        """
        Build the ``rawPredict`` URL for a Mistral publisher model on Vertex AI.

        Resulting format:
        ``{api_base}/v1/projects/{project}/locations/{location}/publishers/mistralai/models/{model}:rawPredict``

        Args:
            api_base: Base URL override; when ``None`` it comes from
                ``get_vertex_base_url`` for the resolved location. A trailing
                slash is stripped.
            model: Model name; inserted into the URL path.
            optional_params: Not used by this method.
            litellm_params: Source of ``vertex_project`` and ``vertex_location``.
            **kwargs: Ignored.

        Returns:
            Complete endpoint URL.

        Raises:
            ValueError: If no Vertex project can be resolved.
        """
        # The safe_get_* helpers read from the params dict without mutating it.
        params = litellm_params or {}

        project = VertexBase.safe_get_vertex_ai_project(litellm_params=params)
        location = VertexBase.safe_get_vertex_ai_location(litellm_params=params)

        if project is None:
            raise ValueError(
                "Missing vertex_project - Set VERTEXAI_PROJECT environment variable or pass vertex_project parameter"
            )

        location = "us-central1" if location is None else location
        base = get_vertex_base_url(location) if api_base is None else api_base
        base = base.rstrip("/")

        return f"{base}/v1/projects/{project}/locations/{location}/publishers/mistralai/models/{model}:rawPredict"

    @staticmethod
    def _needs_conversion(value) -> bool:
        """Return True for a non-empty value that is not already a ``data:`` URI."""
        return bool(value) and not value.startswith("data:")

    def _convert_url_to_data_uri_sync(self, url: str) -> str:
        """
        Fetch ``url`` and return its content as a base64 data URI (blocking).

        Args:
            url: The URL to convert.

        Returns:
            The data URI produced by ``convert_url_to_base64``.
        """
        verbose_logger.debug(
            f"Vertex AI OCR: Converting URL to base64 data URI (sync): {url}"
        )

        # The helper's result is used as-is as the full data URI.
        encoded = convert_url_to_base64(url=url)

        verbose_logger.debug(
            f"Vertex AI OCR: Converted URL to data URI (length: {len(encoded)})"
        )
        return encoded

    async def _convert_url_to_data_uri_async(self, url: str) -> str:
        """
        Fetch ``url`` and return its content as a base64 data URI (awaitable).

        Args:
            url: The URL to convert.

        Returns:
            The data URI produced by ``async_convert_url_to_base64``.
        """
        verbose_logger.debug(
            f"Vertex AI OCR: Converting URL to base64 data URI (async): {url}"
        )

        # The helper's result is used as-is as the full data URI.
        encoded = await async_convert_url_to_base64(url=url)

        verbose_logger.debug(
            f"Vertex AI OCR: Converted URL to data URI (length: {len(encoded)})"
        )
        return encoded

    def transform_ocr_request(
        self,
        model: str,
        document: DocumentType,
        optional_params: dict,
        headers: dict,
        **kwargs,
    ) -> OCRRequestData:
        """
        Build the OCR request, first inlining any remote URL as a data URI (sync).

        A shallow copy of ``document`` is modified, so the caller's dict is
        left untouched. Only ``document_url`` and ``image_url`` documents are
        inspected; other types are passed to the parent unchanged.

        Args:
            model: Model name.
            document: Document dict from the user.
            optional_params: Already mapped optional parameters.
            headers: Request headers.
            **kwargs: Forwarded to the parent implementation.

        Returns:
            ``OCRRequestData`` built by ``MistralOCRConfig.transform_ocr_request``.

        Raises:
            ValueError: If ``document`` is not a dict.
        """
        verbose_logger.debug("Vertex AI OCR transform_ocr_request (sync) called")

        if not isinstance(document, dict):
            raise ValueError(f"Expected document dict, got {type(document)}")

        prepared = document.copy()
        kind = document.get("type")

        if kind == "document_url":
            source = document.get("document_url", "")
            if self._needs_conversion(source):
                verbose_logger.debug(
                    "Vertex AI OCR: Converting document URL to base64 data URI (sync)"
                )
                prepared["document_url"] = self._convert_url_to_data_uri_sync(
                    url=source
                )
        elif kind == "image_url":
            source = document.get("image_url", "")
            if self._needs_conversion(source):
                verbose_logger.debug(
                    "Vertex AI OCR: Converting image URL to base64 data URI (sync)"
                )
                prepared["image_url"] = self._convert_url_to_data_uri_sync(url=source)

        # The parent class builds the actual request body.
        return super().transform_ocr_request(
            model=model,
            document=prepared,
            optional_params=optional_params,
            headers=headers,
            **kwargs,
        )

    async def async_transform_ocr_request(
        self,
        model: str,
        document: DocumentType,
        optional_params: dict,
        headers: dict,
        **kwargs,
    ) -> OCRRequestData:
        """
        Build the OCR request, first inlining any remote URL as a data URI (async).

        Same steps as ``transform_ocr_request`` except that the URL is fetched
        with the awaitable converter. The final request is still built by the
        parent's synchronous ``transform_ocr_request``.

        Args:
            model: Model name.
            document: Document dict from the user.
            optional_params: Already mapped optional parameters.
            headers: Request headers.
            **kwargs: Forwarded to the parent implementation.

        Returns:
            ``OCRRequestData`` built by ``MistralOCRConfig.transform_ocr_request``.

        Raises:
            ValueError: If ``document`` is not a dict.
        """
        verbose_logger.debug(
            f"Vertex AI OCR async_transform_ocr_request - model: {model}"
        )

        if not isinstance(document, dict):
            raise ValueError(f"Expected document dict, got {type(document)}")

        prepared = document.copy()
        kind = document.get("type")

        if kind == "document_url":
            source = document.get("document_url", "")
            if self._needs_conversion(source):
                verbose_logger.debug(
                    "Vertex AI OCR: Converting document URL to base64 data URI (async)"
                )
                prepared["document_url"] = await self._convert_url_to_data_uri_async(
                    url=source
                )
        elif kind == "image_url":
            source = document.get("image_url", "")
            if self._needs_conversion(source):
                verbose_logger.debug(
                    "Vertex AI OCR: Converting image URL to base64 data URI (async)"
                )
                prepared["image_url"] = await self._convert_url_to_data_uri_async(
                    url=source
                )

        # The parent class builds the actual request body.
        return super().transform_ocr_request(
            model=model,
            document=prepared,
            optional_params=optional_params,
            headers=headers,
            **kwargs,
        )
|
||||
@@ -0,0 +1,13 @@
|
||||
"""
|
||||
Vertex AI RAG Engine module.
|
||||
|
||||
Handles RAG ingestion via Vertex AI RAG Engine API.
|
||||
"""
|
||||
|
||||
from litellm.llms.vertex_ai.rag_engine.ingestion import VertexAIRAGIngestion
|
||||
from litellm.llms.vertex_ai.rag_engine.transformation import VertexAIRAGTransformation
|
||||
|
||||
__all__ = [
|
||||
"VertexAIRAGIngestion",
|
||||
"VertexAIRAGTransformation",
|
||||
]
|
||||
@@ -0,0 +1,312 @@
|
||||
"""
|
||||
Vertex AI RAG Engine Ingestion implementation.
|
||||
|
||||
Uses:
|
||||
- litellm.files.acreate_file for uploading files to GCS
|
||||
- Vertex AI RAG Engine REST API for importing files into corpus (via httpx)
|
||||
|
||||
Key differences from OpenAI:
|
||||
- Files must be uploaded to GCS first (via litellm.files.acreate_file)
|
||||
- Embedding is handled internally using text-embedding-005 by default
|
||||
- Chunking configured via unified chunking_strategy in ingest_options
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
|
||||
|
||||
from litellm import get_secret_str
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.vertex_ai.rag_engine.transformation import VertexAIRAGTransformation
|
||||
from litellm.rag.ingestion.base_ingestion import BaseRAGIngestion
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm import Router
|
||||
from litellm.types.rag import RAGIngestOptions
|
||||
|
||||
|
||||
def _get_str_or_none(value: Any) -> Optional[str]:
    """Return ``str(value)``, or ``None`` when ``value`` is ``None``."""
    if value is None:
        return None
    return str(value)
|
||||
|
||||
|
||||
def _get_int(value: Any, default: int) -> int:
    """Return ``int(value)``, or ``default`` when ``value`` is ``None``."""
    return default if value is None else int(value)
|
||||
|
||||
|
||||
class VertexAIRAGIngestion(BaseRAGIngestion):
|
||||
"""
|
||||
Vertex AI RAG Engine ingestion.
|
||||
|
||||
Uses litellm.files.acreate_file for GCS upload, then imports into RAG corpus.
|
||||
|
||||
Required config in vector_store:
|
||||
- vector_store_id: RAG corpus ID (required)
|
||||
|
||||
Optional config in vector_store:
|
||||
- vertex_project: GCP project ID (uses env VERTEXAI_PROJECT if not set)
|
||||
- vertex_location: GCP region (default: us-central1)
|
||||
- vertex_credentials: Path to credentials JSON (uses ADC if not set)
|
||||
- wait_for_import: Wait for import to complete (default: True)
|
||||
- import_timeout: Timeout in seconds (default: 600)
|
||||
|
||||
Chunking is configured via ingest_options["chunking_strategy"]:
|
||||
- chunk_size: Maximum size of chunks (default: 1000)
|
||||
- chunk_overlap: Overlap between chunks (default: 200)
|
||||
|
||||
Authentication:
|
||||
- Uses Application Default Credentials (ADC)
|
||||
- Run: gcloud auth application-default login
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    ingest_options: "RAGIngestOptions",
    router: Optional["Router"] = None,
):
    """
    Read and validate the Vertex AI RAG ingestion settings.

    All settings come from ``self.vector_store_config``, which this method
    reads but does not assign (presumably set up by
    ``BaseRAGIngestion.__init__`` - confirm).

    Args:
        ingest_options: Ingestion options forwarded to the base class.
        router: Optional router forwarded to the base class.

    Raises:
        ValueError: If ``vector_store_id`` is missing, then if no GCS bucket
            is configured, then if no Vertex project is configured (checked
            in that order).
    """
    super().__init__(ingest_options=ingest_options, router=router)

    config = self.vector_store_config

    # An existing RAG corpus ID is mandatory.
    self.corpus_id = config.get("vector_store_id")
    if not self.corpus_id:
        raise ValueError(
            "vector_store_id (corpus ID) is required for Vertex AI RAG ingestion. "
            "Please provide an existing RAG corpus ID."
        )

    # GCP settings: explicit config first, then the secret/env lookup.
    self.vertex_project = config.get("vertex_project") or get_secret_str(
        "VERTEXAI_PROJECT"
    )
    self.vertex_location = (
        config.get("vertex_location")
        or get_secret_str("VERTEXAI_LOCATION")
        or "us-central1"
    )
    self.vertex_credentials = config.get("vertex_credentials")

    # Bucket that uploaded files are written to.
    self.gcs_bucket = config.get("gcs_bucket") or os.environ.get("GCS_BUCKET_NAME")
    if not self.gcs_bucket:
        raise ValueError(
            "gcs_bucket is required for Vertex AI RAG ingestion. "
            "Set via vector_store config or GCS_BUCKET_NAME env var."
        )

    # Import behaviour.
    self.wait_for_import = config.get("wait_for_import", True)
    self.import_timeout = _get_int(config.get("import_timeout"), 600)

    # Checked last, after the bucket check above.
    if not self.vertex_project:
        raise ValueError(
            "vertex_project is required for Vertex AI RAG ingestion. "
            "Set via vector_store config or VERTEXAI_PROJECT env var."
        )
|
||||
|
||||
def _get_corpus_name(self) -> str:
    """Return the fully-qualified resource name of the target RAG corpus."""
    location_path = f"projects/{self.vertex_project}/locations/{self.vertex_location}"
    return f"{location_path}/ragCorpora/{self.corpus_id}"
|
||||
|
||||
async def _upload_file_to_gcs(
    self,
    file_content: bytes,
    filename: str,
    content_type: str,
) -> str:
    """
    Upload a file to GCS through ``litellm.acreate_file``.

    ``GCS_BUCKET_NAME`` is temporarily pointed at ``self.gcs_bucket`` for the
    duration of the upload and restored afterwards, whether or not the
    upload succeeds.

    Args:
        file_content: Raw bytes of the file.
        filename: Name to upload the file under.
        content_type: MIME type of the file.

    Returns:
        The ``id`` of the upload response, expected to be the GCS URI
        (gs://bucket/path/file).
    """
    import litellm

    env_key = "GCS_BUCKET_NAME"

    # NOTE: os.environ is process-wide, so this override is visible to any
    # other code that runs while the upload is awaited.
    previous_bucket = os.environ.get(env_key)
    if self.gcs_bucket:
        os.environ[env_key] = self.gcs_bucket

    try:
        verbose_logger.debug(
            f"Uploading file to GCS via litellm.files.acreate_file: {filename} "
            f"(bucket: {self.gcs_bucket})"
        )

        upload_response = await litellm.acreate_file(
            file=(filename, file_content, content_type),
            purpose="assistants",  # Purpose for file storage
            custom_llm_provider="vertex_ai",
            vertex_project=self.vertex_project,
            vertex_location=self.vertex_location,
            vertex_credentials=self.vertex_credentials,
        )

        # The response id is expected to carry the GCS URI.
        uploaded_uri = upload_response.id
        verbose_logger.info(f"Uploaded file to GCS: {uploaded_uri}")
        return uploaded_uri
    finally:
        # Put the environment back exactly as it was before the upload.
        if previous_bucket is None:
            os.environ.pop(env_key, None)
        else:
            os.environ[env_key] = previous_bucket
|
||||
|
||||
async def _import_file_to_corpus_via_sdk(
    self,
    gcs_uri: str,
) -> None:
    """
    Import a GCS file into the RAG corpus using the Vertex AI SDK.

    The REST API endpoint for importRagFiles is not publicly available,
    so the Python SDK is used instead.

    When ``self.wait_for_import`` is truthy the blocking ``rag.import_files``
    call is run in the default executor so the event loop stays responsive
    while the import completes. Otherwise ``rag.import_files_async`` is
    called and, if it returns an awaitable, awaited so the import request is
    actually issued.

    Args:
        gcs_uri: GCS URI of the file to import.

    Raises:
        ImportError: If the ``vertexai`` SDK (with ``rag``) is not installed.
    """
    import asyncio
    import functools
    import inspect

    try:
        from vertexai import init as vertexai_init
        from vertexai import rag  # type: ignore[import-not-found]
    except ImportError as e:
        raise ImportError(
            "vertexai.rag module not found. Vertex AI RAG requires "
            "google-cloud-aiplatform>=1.60.0. Install with: "
            "pip install 'google-cloud-aiplatform>=1.60.0'"
        ) from e

    # Initialize Vertex AI
    vertexai_init(project=self.vertex_project, location=self.vertex_location)

    # Get chunking config from ingest_options (unified interface)
    transformation_config = self._build_transformation_config()

    corpus_name = self._get_corpus_name()
    verbose_logger.debug(f"Importing {gcs_uri} into corpus {self.corpus_id}")

    if self.wait_for_import:
        # rag.import_files blocks until the import finishes; run it off the
        # event loop thread so other coroutines are not stalled meanwhile.
        blocking_import = functools.partial(
            rag.import_files,
            corpus_name=corpus_name,
            paths=[gcs_uri],
            transformation_config=transformation_config,
            timeout=self.import_timeout,
        )
        response = await asyncio.get_running_loop().run_in_executor(
            None, blocking_import
        )
        verbose_logger.info(
            f"Import complete: {response.imported_rag_files_count} files imported"
        )
    else:
        # Start the import without waiting for it to complete. A coroutine
        # that is never awaited never runs, so await the call when the SDK
        # hands back an awaitable.
        pending = rag.import_files_async(
            corpus_name=corpus_name,
            paths=[gcs_uri],
            transformation_config=transformation_config,
        )
        if inspect.isawaitable(pending):
            await pending
        verbose_logger.info("Import started asynchronously")
|
||||
|
||||
def _build_transformation_config(self) -> Any:
    """
    Create the SDK ``TransformationConfig`` for this ingestion.

    The chunk size and overlap come from ``self.chunking_strategy`` (the
    unified ingest option, not the vector store config), converted by
    ``VertexAIRAGTransformation``.

    Raises:
        ImportError: If the ``vertexai`` SDK (with ``rag``) is not installed.
    """
    try:
        from vertexai import rag  # type: ignore[import-not-found]
    except ImportError:
        raise ImportError(
            "vertexai.rag module not found. Vertex AI RAG requires "
            "google-cloud-aiplatform>=1.60.0. Install with: "
            "pip install 'google-cloud-aiplatform>=1.60.0'"
        )

    from typing import cast

    from litellm.types.rag import RAGChunkingStrategy

    # The cast only informs type checkers; it does nothing at runtime.
    strategy = cast(Optional[RAGChunkingStrategy], self.chunking_strategy)
    vertex_format = VertexAIRAGTransformation().transform_chunking_strategy_to_vertex_format(
        strategy
    )
    chunking = vertex_format["chunking_config"]

    sdk_chunking = rag.ChunkingConfig(
        chunk_size=chunking["chunk_size"],
        chunk_overlap=chunking["chunk_overlap"],
    )
    return rag.TransformationConfig(chunking_config=sdk_chunking)
|
||||
|
||||
async def embed(
    self,
    chunks: List[str],
) -> Optional[List[List[float]]]:
    """
    No-op embedding step.

    Vertex AI produces embeddings itself when files are imported, so there
    is nothing to compute here.

    Args:
        chunks: Text chunks; not used.

    Returns:
        Always None.
    """
    return None
|
||||
|
||||
async def store(
    self,
    file_content: Optional[bytes],
    filename: Optional[str],
    content_type: Optional[str],
    chunks: List[str],
    embeddings: Optional[List[List[float]]],
) -> Tuple[Optional[str], Optional[str]]:
    """
    Upload a file to GCS and import it into the Vertex AI RAG corpus.

    Steps:
        1. Upload the file to GCS.
        2. Import the uploaded file into the corpus via the SDK (which
           optionally waits for completion).

    Args:
        file_content: Raw file bytes.
        filename: Name of the file.
        content_type: MIME type; "application/octet-stream" is used when empty.
        chunks: Not used - Vertex AI handles chunking.
        embeddings: Not used - Vertex AI handles embedding.

    Returns:
        Tuple of (corpus_id, gcs_uri). The URI is None when there was no
        file content or filename to ingest.

    Raises:
        RuntimeError: If importing the uploaded file into the corpus fails.
    """
    nothing_to_ingest = not file_content or not filename
    if nothing_to_ingest:
        verbose_logger.warning(
            "No file content or filename provided for Vertex AI ingestion"
        )
        return _get_str_or_none(self.corpus_id), None

    # Step 1: the file has to be in GCS before it can be imported.
    effective_content_type = content_type or "application/octet-stream"
    gcs_uri = await self._upload_file_to_gcs(
        file_content=file_content,
        filename=filename,
        content_type=effective_content_type,
    )

    # Step 2: import the uploaded object; any failure surfaces as RuntimeError.
    try:
        await self._import_file_to_corpus_via_sdk(gcs_uri=gcs_uri)
    except Exception as e:
        failure = f"Failed to import file into RAG corpus: {e}"
        verbose_logger.error(failure)
        raise RuntimeError(failure) from e

    return str(self.corpus_id), gcs_uri
|
||||
@@ -0,0 +1,153 @@
|
||||
"""
|
||||
Transformation utilities for Vertex AI RAG Engine.
|
||||
|
||||
Handles transforming LiteLLM's unified formats to Vertex AI RAG Engine API format.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.constants import DEFAULT_CHUNK_OVERLAP, DEFAULT_CHUNK_SIZE
|
||||
from litellm.llms.vertex_ai.common_utils import get_vertex_base_url
|
||||
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
|
||||
from litellm.types.rag import RAGChunkingStrategy
|
||||
|
||||
|
||||
class VertexAIRAGTransformation(VertexBase):
    """
    Request/URL helpers for the Vertex AI RAG Engine API.

    Provides:
    - URL builders for the import and retrieve endpoints
    - Conversion of the unified chunking_strategy to Vertex AI format
    - The importRagFiles request payload
    - Bearer-token auth headers
    """

    def __init__(self):
        super().__init__()

    def get_import_rag_files_url(
        self,
        vertex_project: str,
        vertex_location: str,
        corpus_id: str,
    ) -> str:
        """
        Return the ``:importRagFiles`` URL for a corpus.

        Note: The REST endpoint for importRagFiles may not be publicly available.
        Vertex AI RAG Engine primarily uses gRPC-based SDK.
        """
        corpus_path = (
            f"projects/{vertex_project}/locations/{vertex_location}"
            f"/ragCorpora/{corpus_id}"
        )
        return f"{get_vertex_base_url(vertex_location)}/v1/{corpus_path}:importRagFiles"

    def get_retrieve_contexts_url(
        self,
        vertex_project: str,
        vertex_location: str,
    ) -> str:
        """Return the ``:retrieveContexts`` (search) URL for a project/location."""
        location_path = f"projects/{vertex_project}/locations/{vertex_location}"
        return f"{get_vertex_base_url(vertex_location)}/v1/{location_path}:retrieveContexts"

    def transform_chunking_strategy_to_vertex_format(
        self,
        chunking_strategy: Optional[RAGChunkingStrategy],
    ) -> Dict[str, Any]:
        """
        Convert LiteLLM's unified chunking_strategy to Vertex AI RAG format.

        Input (RAGChunkingStrategy), every key optional:
            {"chunk_size": 1000, "chunk_overlap": 200, "separators": [...]}

        Output (TransformationConfig shape):
            {"chunking_config": {"chunk_size": 1000, "chunk_overlap": 200}}

        Missing keys, or a missing/empty strategy, fall back to
        DEFAULT_CHUNK_SIZE and DEFAULT_CHUNK_OVERLAP. Separators are not
        carried over; a warning is logged when they are supplied.
        """
        strategy: Dict[str, Any] = dict(chunking_strategy) if chunking_strategy else {}

        # Custom separators have no Vertex AI equivalent here; warn and drop them.
        if strategy.get("separators"):
            verbose_logger.warning(
                "Vertex AI RAG Engine does not support custom separators. "
                "The 'separators' parameter will be ignored."
            )

        return {
            "chunking_config": {
                "chunk_size": strategy.get("chunk_size", DEFAULT_CHUNK_SIZE),
                "chunk_overlap": strategy.get("chunk_overlap", DEFAULT_CHUNK_OVERLAP),
            }
        }

    def build_import_rag_files_request(
        self,
        gcs_uri: str,
        chunking_strategy: Optional[RAGChunkingStrategy] = None,
    ) -> Dict[str, Any]:
        """
        Assemble the request payload for importing one file from GCS.

        Args:
            gcs_uri: GCS URI of the file to import (e.g., gs://bucket/path/file.txt)
            chunking_strategy: LiteLLM unified chunking config

        Returns:
            Request payload dict for importRagFiles API
        """
        import_config = {
            "gcs_source": {"uris": [gcs_uri]},
            "rag_file_transformation_config": (
                self.transform_chunking_strategy_to_vertex_format(chunking_strategy)
            ),
        }
        return {"import_rag_files_config": import_config}

    def get_auth_headers(
        self,
        vertex_credentials: Optional[str] = None,
        vertex_project: Optional[str] = None,
    ) -> Dict[str, str]:
        """
        Build JSON request headers carrying a Bearer access token.

        Credentials and, when ``vertex_project`` is not given, the project are
        resolved through the base class helpers.
        """
        resolved_credentials = self.get_vertex_ai_credentials(
            {"vertex_credentials": vertex_credentials}
        )
        resolved_project = vertex_project or self.get_vertex_ai_project({})

        token, _ = self._ensure_access_token(
            credentials=resolved_credentials,
            project_id=resolved_project,
            custom_llm_provider="vertex_ai",
        )

        return {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        }
|
||||
@@ -0,0 +1,164 @@
|
||||
"""
|
||||
Vertex AI Realtime (BidiGenerateContent) config.
|
||||
|
||||
Extends GeminiRealtimeConfig but adapts the WSS URL and auth header for the
|
||||
Vertex AI endpoint instead of Google AI Studio.
|
||||
|
||||
URL pattern:
|
||||
wss://{location}-aiplatform.googleapis.com/ws/
|
||||
google.cloud.aiplatform.v1.LlmBidiService/BidiGenerateContent
|
||||
|
||||
Auth: OAuth2 Bearer token (not an API key).
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import List, Optional
|
||||
|
||||
from litellm.llms.gemini.realtime.transformation import GeminiRealtimeConfig
|
||||
|
||||
|
||||
class VertexAIRealtimeConfig(GeminiRealtimeConfig):
    """
    Realtime config for Vertex AI (BidiGenerateContent).

    ``access_token`` and ``project`` must be pre-resolved by the caller
    (they require async I/O) and injected at construction time.
    """

    def __init__(self, access_token: str, project: str, location: str) -> None:
        self._access_token = access_token
        self._project = project
        self._location = location

    # ------------------------------------------------------------------
    # URL
    # ------------------------------------------------------------------

    def get_complete_url(
        self,
        api_base: Optional[str],
        model: str,
        api_key: Optional[str] = None,  # noqa: ARG002
    ) -> str:
        """
        Return the Vertex AI Live WSS endpoint URL.

        A non-empty *api_base* replaces the default aiplatform host (e.g. for
        enterprise / VPC-SC gateways); its http(s) scheme is rewritten to
        ws(s). Otherwise the host is derived from the configured location,
        with "global" mapping to the un-prefixed aiplatform host.
        """
        service_path = "/ws/google.cloud.aiplatform.v1.LlmBidiService/BidiGenerateContent"

        if api_base:
            # Callers may pass either an http(s):// or a ws(s):// base URL.
            custom_base = (
                api_base.rstrip("/")
                .replace("https://", "wss://")
                .replace("http://", "ws://")
            )
            return f"{custom_base}{service_path}"

        if self._location == "global":
            host = "aiplatform.googleapis.com"
        else:
            host = f"{self._location}-aiplatform.googleapis.com"
        return f"wss://{host}{service_path}"

    # ------------------------------------------------------------------
    # Auth headers
    # ------------------------------------------------------------------

    def validate_environment(
        self,
        headers: dict,
        model: str,  # noqa: ARG002
        api_key: Optional[str] = None,  # noqa: ARG002
    ) -> dict:
        """
        Return a copy of *headers* with the Vertex AI Bearer token added.

        ``api_key`` is intentionally ignored — Vertex AI uses OAuth2 tokens,
        not API keys. The token was resolved at config-construction time.
        The ``x-goog-user-project`` header is added only when a project is set.
        """
        merged = dict(headers)  # never mutate the caller's dict
        merged.update({"Authorization": f"Bearer {self._access_token}"})
        if self._project:
            merged.update({"x-goog-user-project": self._project})
        return merged

    # ------------------------------------------------------------------
    # Audio MIME type — Vertex AI needs the sample rate in the MIME string
    # ------------------------------------------------------------------

    def get_audio_mime_type(self, input_audio_format: str = "pcm16") -> str:
        """Map an input audio format name to its MIME type (octet-stream if unknown)."""
        known_formats = {
            "pcm16": "audio/pcm;rate=16000",
            "g711_ulaw": "audio/pcmu",
            "g711_alaw": "audio/pcma",
        }
        try:
            return known_formats[input_audio_format]
        except KeyError:
            return "application/octet-stream"

    # ------------------------------------------------------------------
    # Session setup message
    # ------------------------------------------------------------------

    def session_configuration_request(self, model: str) -> str:
        """
        Build the JSON ``setup`` message for Vertex AI Live.

        The model is sent as the fully-qualified path
        ``projects/{project}/locations/{location}/publishers/google/models/{model}``.

        The message also enables automatic activity detection (server VAD)
        and requests input and output audio transcription.
        """
        from litellm.types.llms.gemini import BidiGenerateContentSetup
        from litellm.types.llms.vertex_ai import GeminiResponseModalities

        response_modalities: list[GeminiResponseModalities] = ["AUDIO"]
        location_path = f"projects/{self._project}/locations/{self._location}"
        full_model_path = f"{location_path}/publishers/google/models/{model}"

        setup_config: BidiGenerateContentSetup = {
            "model": full_model_path,
            "generationConfig": {"responseModalities": response_modalities},
            # Server-side VAD with an 800 ms silence window.
            "realtimeInputConfig": {
                "automaticActivityDetection": {
                    "disabled": False,
                    "silenceDurationMs": 800,
                }
            },
            # Input transcript, so guardrails can inspect user speech.
            "inputAudioTranscription": {},
            # Output transcript, so clients can read what the model said.
            "outputAudioTranscription": {},
        }
        return json.dumps({"setup": setup_config})

    # ------------------------------------------------------------------
    # Request translation
    # ------------------------------------------------------------------

    def transform_realtime_request(
        self,
        message: str,
        model: str,
        session_configuration_request: Optional[str] = None,
    ) -> List[str]:
        """
        Translate an OpenAI realtime client message to Vertex AI format.

        ``session.update`` messages are dropped (an empty list is returned)
        because Vertex AI only accepts a single ``setup`` message at the
        start of the connection — sending a second one causes a 1007 close
        error. Every other message is delegated to the parent class.
        """
        parsed = json.loads(message)
        if parsed.get("type") != "session.update":
            return super().transform_realtime_request(
                message, model, session_configuration_request
            )

        # Do not forward as a second setup — Vertex AI rejects it.
        return []
|
||||
@@ -0,0 +1,5 @@
|
||||
"""
|
||||
Vertex AI Rerank - uses `llm_http_handler.py` to make httpx requests
|
||||
|
||||
Request/Response transformation is handled in `transformation.py`
|
||||
"""
|
||||
@@ -0,0 +1,248 @@
|
||||
"""
|
||||
Translates from Cohere's `/v1/rerank` input format to Vertex AI Discovery Engine's `/rank` input format.
|
||||
|
||||
Why separate file? Make it easy to see how transformation works
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
|
||||
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.rerank import (
|
||||
RerankResponse,
|
||||
RerankResponseMeta,
|
||||
RerankBilledUnits,
|
||||
RerankResponseResult,
|
||||
)
|
||||
|
||||
|
||||
class VertexAIRerankConfig(BaseRerankConfig, VertexBase):
    """
    Rerank configuration for the Vertex AI Discovery Engine ranking API.

    Reference: https://cloud.google.com/generative-ai-app-builder/docs/ranking#rank_or_rerank_a_set_of_records_according_to_a_query
    """

    def __init__(self) -> None:
        super().__init__()

    def get_complete_url(
        self,
        api_base: Optional[str],
        model: str,
        optional_params: Optional[Dict] = None,
    ) -> str:
        """
        Return the Discovery Engine ``:rank`` URL for the resolved project.

        The project comes from the access-token lookup (seeded with any
        project/credentials found in ``optional_params``), then the
        ``VERTEXAI_PROJECT`` secret, then ``litellm.vertex_project``.

        Raises:
            ValueError: If no project ID can be resolved.
        """
        params = optional_params or {}

        # Each helper receives its own copy of the params.
        credentials = self.safe_get_vertex_ai_credentials(params.copy())
        requested_project = self.safe_get_vertex_ai_project(params.copy())

        # The token itself is not needed here, only the project it resolves to.
        _, token_project = self._ensure_access_token(
            credentials=credentials,
            project_id=requested_project,
            custom_llm_provider="vertex_ai",
        )

        project_id = (
            token_project
            or get_secret_str("VERTEXAI_PROJECT")
            or litellm.vertex_project
        )
        if not project_id:
            raise ValueError(
                "Vertex AI project ID is required. Please set 'VERTEXAI_PROJECT', 'litellm.vertex_project', or pass 'vertex_project' parameter"
            )

        return f"https://discoveryengine.googleapis.com/v1/projects/{project_id}/locations/global/rankingConfigs/default_ranking_config:rank"

    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
        optional_params: Optional[Dict] = None,
    ) -> dict:
        """
        Build the request headers for the Discovery Engine API.

        Default headers (Bearer token, JSON content type, user project) are
        created from the resolved credentials; any entry in ``headers`` with
        the same key, including ``Authorization``, takes precedence.
        """
        litellm_params = optional_params.copy() if optional_params else {}
        credentials = self.safe_get_vertex_ai_credentials(litellm_params)
        requested_project = self.safe_get_vertex_ai_project(litellm_params)

        access_token, project_id = self._ensure_access_token(
            credentials=credentials,
            project_id=requested_project,
            custom_llm_provider="vertex_ai",
        )

        default_headers = {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
            "X-Goog-User-Project": project_id,
        }

        # Caller-supplied headers win on every key, Authorization included.
        return {**default_headers, **headers}

    def transform_rerank_request(
        self,
        model: str,
        optional_rerank_params: Dict,
        headers: dict,
    ) -> dict:
        """
        Convert a Cohere-style rerank request to Discovery Engine format.

        Each document becomes a record whose id is its position in the input.
        String documents are used as the content; dict documents supply
        ``text`` and ``title``. The fallback title is the first three words of
        the content.

        Raises:
            ValueError: If ``query`` or ``documents`` is missing.
        """
        for required_key in ("query", "documents"):
            if required_key not in optional_rerank_params:
                raise ValueError(f"{required_key} is required for Vertex AI rerank")

        top_n = optional_rerank_params.get("top_n", None)
        return_documents = optional_rerank_params.get("return_documents", True)

        records = []
        for position, document in enumerate(optional_rerank_params["documents"]):
            is_plain_text = isinstance(document, str)
            content = document if is_plain_text else document.get("text", str(document))
            leading_words = " ".join(content.split()[:3])  # First 3 words as title
            title = leading_words if is_plain_text else document.get("title", leading_words)
            records.append({"id": str(position), "title": title, "content": content})

        request_data = {
            "model": model,
            "query": optional_rerank_params["query"],
            "records": records,
        }
        if top_n is not None:
            request_data["topN"] = top_n

        # return_documents=False means only record IDs are wanted back, i.e.
        # record details should be ignored in the response.
        request_data["ignoreRecordDetailsInResponse"] = not return_documents
        return request_data

    def transform_rerank_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: RerankResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        request_data: dict = {},
        optional_params: dict = {},
        litellm_params: dict = {},
    ) -> RerankResponse:
        """
        Convert a Discovery Engine ranking response to Cohere format.

        Record ids are turned back into integer indices. A record without a
        ``score`` gets a relevance score of 1.0. Results are ordered by
        descending score, and the number of returned records is reported as
        billed search units.

        Raises:
            ValueError: If the response body cannot be parsed as JSON.
        """
        try:
            payload = raw_response.json()
        except Exception as e:
            raise ValueError(f"Failed to parse response: {e}")

        records = payload.get("records", [])

        # (index, score) pairs; 1.0 is the default when no score is present.
        scored = [(int(record["id"]), record.get("score", 1.0)) for record in records]
        scored.sort(key=lambda pair: pair[1], reverse=True)

        rerank_results = [
            RerankResponseResult(index=index, relevance_score=score)
            for index, score in scored
        ]
        meta = RerankResponseMeta(
            billed_units=RerankBilledUnits(search_units=len(records))
        )

        return RerankResponse(
            id=f"vertex_ai_rerank_{model}", results=rerank_results, meta=meta
        )

    def get_supported_cohere_rerank_params(self, model: str) -> list:
        """List the Cohere rerank parameters this provider accepts."""
        return ["query", "documents", "top_n", "return_documents"]

    def map_cohere_rerank_params(
        self,
        non_default_params: dict,
        model: str,
        drop_params: bool,
        query: str,
        documents: List[Union[str, Dict[str, Any]]],
        custom_llm_provider: Optional[str] = None,
        top_n: Optional[int] = None,
        rank_fields: Optional[List[str]] = None,
        return_documents: Optional[bool] = True,
        max_chunks_per_doc: Optional[int] = None,
        max_tokens_per_doc: Optional[int] = None,
    ) -> Dict:
        """
        Collect the rerank parameters used by the Vertex AI request.

        Only query, documents, top_n and return_documents are taken from the
        named arguments; entries in ``non_default_params`` are merged on top.
        """
        base_params = {
            "query": query,
            "documents": documents,
            "top_n": top_n,
            "return_documents": return_documents,
        }
        return {**base_params, **non_default_params}
|
||||
@@ -0,0 +1,244 @@
|
||||
from typing import Optional, Union
|
||||
|
||||
import httpx
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
import litellm
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.llms.openai.openai import HttpxBinaryResponseContent
|
||||
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
|
||||
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
|
||||
|
||||
|
||||
class VertexInput(TypedDict, total=False):
    """Synthesis input: plain ``text`` or ``ssml`` markup (all keys optional)."""

    # Plain text to synthesize.
    text: Optional[str]
    # SSML markup to synthesize.
    ssml: Optional[str]
|
||||
|
||||
|
||||
class VertexVoice(TypedDict, total=False):
    """Voice selection for a synthesis request (all keys optional)."""

    # Language code of the voice, e.g. "en-US".
    languageCode: str
    # Voice name, e.g. "en-US-Studio-O".
    name: str
|
||||
|
||||
|
||||
class VertexAudioConfig(TypedDict, total=False):
    """Audio output settings for a synthesis request (all keys optional)."""

    # Output encoding, e.g. "LINEAR16".
    audioEncoding: str
    # Speaking rate, carried as a string, e.g. "1".
    speakingRate: str
|
||||
|
||||
|
||||
class VertexTextToSpeechRequest(TypedDict, total=False):
    """Body of a ``text:synthesize`` request (all keys optional)."""

    # What to synthesize (text or SSML).
    input: VertexInput
    # Which voice to use.
    voice: VertexVoice
    # How to encode the resulting audio.
    audioConfig: Optional[VertexAudioConfig]
|
||||
|
||||
|
||||
class VertexTextToSpeechAPI(VertexLLM):
|
||||
"""
|
||||
Vertex methods to support for batches
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def audio_speech(
|
||||
self,
|
||||
logging_obj,
|
||||
vertex_project: Optional[str],
|
||||
vertex_location: Optional[str],
|
||||
vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
|
||||
api_base: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
model: str,
|
||||
input: str,
|
||||
voice: Optional[dict] = None,
|
||||
_is_async: Optional[bool] = False,
|
||||
optional_params: Optional[dict] = None,
|
||||
kwargs: Optional[dict] = None,
|
||||
) -> HttpxBinaryResponseContent:
|
||||
import base64
|
||||
|
||||
####### Authenticate with Vertex AI ########
|
||||
_auth_header, vertex_project = self._ensure_access_token(
|
||||
credentials=vertex_credentials,
|
||||
project_id=vertex_project,
|
||||
custom_llm_provider="vertex_ai_beta",
|
||||
)
|
||||
|
||||
auth_header, _ = self._get_token_and_url(
|
||||
model="",
|
||||
auth_header=_auth_header,
|
||||
gemini_api_key=None,
|
||||
vertex_credentials=vertex_credentials,
|
||||
vertex_project=vertex_project,
|
||||
vertex_location=vertex_location,
|
||||
stream=False,
|
||||
custom_llm_provider="vertex_ai_beta",
|
||||
api_base=api_base,
|
||||
)
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {auth_header}",
|
||||
"x-goog-user-project": vertex_project,
|
||||
"Content-Type": "application/json",
|
||||
"charset": "UTF-8",
|
||||
}
|
||||
|
||||
######### End of Authentication ###########
|
||||
|
||||
####### Build the request ################
|
||||
# API Ref: https://cloud.google.com/text-to-speech/docs/reference/rest/v1/text/synthesize
|
||||
kwargs = kwargs or {}
|
||||
optional_params = optional_params or {}
|
||||
|
||||
vertex_input = VertexInput(text=input)
|
||||
validate_vertex_input(vertex_input, kwargs, optional_params)
|
||||
|
||||
# required param
|
||||
if voice is not None:
|
||||
vertex_voice = VertexVoice(**voice)
|
||||
elif "voice" in kwargs:
|
||||
vertex_voice = VertexVoice(**kwargs["voice"])
|
||||
else:
|
||||
# use defaults to not fail the request
|
||||
vertex_voice = VertexVoice(
|
||||
languageCode="en-US",
|
||||
name="en-US-Studio-O",
|
||||
)
|
||||
|
||||
if "audioConfig" in kwargs:
|
||||
vertex_audio_config = VertexAudioConfig(**kwargs["audioConfig"])
|
||||
else:
|
||||
# use defaults to not fail the request
|
||||
vertex_audio_config = VertexAudioConfig(
|
||||
audioEncoding="LINEAR16",
|
||||
speakingRate="1",
|
||||
)
|
||||
|
||||
request = VertexTextToSpeechRequest(
|
||||
input=vertex_input,
|
||||
voice=vertex_voice,
|
||||
audioConfig=vertex_audio_config,
|
||||
)
|
||||
|
||||
url = "https://texttospeech.googleapis.com/v1/text:synthesize"
|
||||
########## End of building request ############
|
||||
|
||||
########## Log the request for debugging / logging ############
|
||||
logging_obj.pre_call(
|
||||
input=[],
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": request,
|
||||
"api_base": url,
|
||||
"headers": headers,
|
||||
},
|
||||
)
|
||||
|
||||
########## End of logging ############
|
||||
####### Send the request ###################
|
||||
if _is_async is True:
|
||||
return self.async_audio_speech( # type:ignore
|
||||
logging_obj=logging_obj, url=url, headers=headers, request=request
|
||||
)
|
||||
sync_handler = _get_httpx_client()
|
||||
|
||||
response = sync_handler.post(
|
||||
url=url,
|
||||
headers=headers,
|
||||
json=request, # type: ignore
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise Exception(
|
||||
f"Request failed with status code {response.status_code}, {response.text}"
|
||||
)
|
||||
############ Process the response ############
|
||||
_json_response = response.json()
|
||||
|
||||
response_content = _json_response["audioContent"]
|
||||
|
||||
# Decode base64 to get binary content
|
||||
binary_data = base64.b64decode(response_content)
|
||||
|
||||
# Create an httpx.Response object
|
||||
response = httpx.Response(
|
||||
status_code=200,
|
||||
content=binary_data,
|
||||
)
|
||||
|
||||
# Initialize the HttpxBinaryResponseContent instance
|
||||
http_binary_response = HttpxBinaryResponseContent(response)
|
||||
return http_binary_response
|
||||
|
||||
async def async_audio_speech(
    self,
    logging_obj,
    url: str,
    headers: dict,
    request: VertexTextToSpeechRequest,
) -> HttpxBinaryResponseContent:
    """
    Asynchronously POST a text-to-speech request and return the decoded audio.

    Args:
        logging_obj: Accepted for signature parity with the caller; not used here.
        url: Endpoint the request is posted to.
        headers: HTTP headers sent with the request.
        request: JSON-serialisable request body.

    Returns:
        HttpxBinaryResponseContent wrapping a synthetic ``httpx.Response``
        (status 200) whose content is the base64-decoded ``audioContent``.

    Raises:
        Exception: If the upstream status code is not 200.
        KeyError: If the JSON payload has no ``audioContent`` key.
    """
    import base64

    client = get_async_httpx_client(llm_provider=litellm.LlmProviders.VERTEX_AI)
    api_response = await client.post(
        url=url,
        headers=headers,
        json=request,  # type: ignore
    )

    # Guard clause: anything other than 200 is surfaced with the raw body text.
    if api_response.status_code != 200:
        raise Exception(
            f"Request did not return a 200 status code: {api_response.status_code}, {api_response.text}"
        )

    # The payload carries the audio as a base64 string under "audioContent".
    encoded_audio = api_response.json()["audioContent"]
    audio_bytes = base64.b64decode(encoded_audio)

    # Re-wrap the raw bytes in an httpx.Response so the caller receives the
    # binary-response type it expects.
    return HttpxBinaryResponseContent(
        httpx.Response(status_code=200, content=audio_bytes)
    )
|
||||
|
||||
|
||||
def validate_vertex_input(
    input_data: VertexInput, kwargs: dict, optional_params: dict
) -> None:
    """
    Normalise ``input_data`` in place so it holds exactly one of 'text' / 'ssml'.

    Steps:
        1. Keys whose value is None are removed.
        2. If ``use_ssml`` is truthy (``kwargs`` takes precedence over
           ``optional_params``), any 'text' value is moved to 'ssml'.
        3. Otherwise a 'text' value containing "<speak>" is treated as SSML
           and moved to 'ssml'.

    Raises:
        ValueError: If SSML was requested but no input exists, if the input is
            empty after clean-up, or if both 'text' and 'ssml' remain.
    """
    # Strip None-valued fields; pop() with a default is a no-op for absent keys.
    for field in ("text", "ssml"):
        if input_data.get(field) is None:
            input_data.pop(field, None)

    # kwargs wins; optional_params only supplies the fallback value.
    ssml_fallback = optional_params.get("use_ssml", False)
    force_ssml = kwargs.get("use_ssml", ssml_fallback)

    has_text = "text" in input_data
    if force_ssml:
        if has_text:
            input_data["ssml"] = input_data.pop("text")
        elif "ssml" not in input_data:
            raise ValueError("SSML input is required when use_ssml is True.")
    elif has_text and "<speak>" in (input_data["text"] or ""):
        # Auto-detect: text that already looks like SSML is sent as SSML.
        input_data["ssml"] = input_data.pop("text")

    if not input_data:
        raise ValueError("Either 'text' or 'ssml' must be provided.")
    if all(field in input_data for field in ("text", "ssml")):
        raise ValueError("Only one of 'text' or 'ssml' should be provided, not both.")
|
||||
@@ -0,0 +1,479 @@
|
||||
"""
|
||||
Vertex AI Text-to-Speech transformation
|
||||
|
||||
Maps OpenAI TTS spec to Google Cloud Text-to-Speech API
|
||||
Reference: https://cloud.google.com/text-to-speech/docs/reference/rest/v1/text/synthesize
|
||||
"""
|
||||
|
||||
import base64
|
||||
from typing import TYPE_CHECKING, Any, Coroutine, Dict, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.base_llm.text_to_speech.transformation import (
|
||||
BaseTextToSpeechConfig,
|
||||
TextToSpeechRequestData,
|
||||
)
|
||||
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
|
||||
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
|
||||
from litellm.types.llms.vertex_ai_text_to_speech import (
|
||||
VertexTextToSpeechAudioConfig,
|
||||
VertexTextToSpeechInput,
|
||||
VertexTextToSpeechVoice,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.types.llms.openai import HttpxBinaryResponseContent
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
HttpxBinaryResponseContent = Any
|
||||
|
||||
|
||||
class VertexAITextToSpeechConfig(BaseTextToSpeechConfig, VertexBase):
    """
    Configuration for Google Cloud/Vertex AI Text-to-Speech

    Translates an OpenAI-style speech request (input, voice, response_format,
    speed) into the ``text:synthesize`` request body and converts the JSON
    response (base64 ``audioContent``) back into binary audio.

    Reference: https://cloud.google.com/text-to-speech/docs/reference/rest/v1/text/synthesize
    """

    # Default values
    DEFAULT_LANGUAGE_CODE = "en-US"
    DEFAULT_VOICE_NAME = "en-US-Studio-O"
    DEFAULT_AUDIO_ENCODING = "LINEAR16"
    DEFAULT_SPEAKING_RATE = "1"

    # API endpoint
    TTS_API_URL = "https://texttospeech.googleapis.com/v1/text:synthesize"

    # Voice name mappings from OpenAI voices to Google Cloud voices
    # Users can pass either:
    # 1. OpenAI voice names (alloy, echo, fable, onyx, nova, shimmer) - will be mapped
    # 2. Google Cloud/Vertex AI voice names (en-US-Studio-O, en-US-Wavenet-D, etc.) - used directly
    VOICE_MAPPINGS = {
        "alloy": "en-US-Studio-O",
        "echo": "en-US-Studio-M",
        "fable": "en-GB-Studio-B",
        "onyx": "en-US-Wavenet-D",
        "nova": "en-US-Studio-O",
        "shimmer": "en-US-Wavenet-F",
    }

    # Response format mappings from OpenAI to Google Cloud audio encoding
    FORMAT_MAPPINGS = {
        "mp3": "MP3",
        "opus": "OGG_OPUS",
        "aac": "MP3",  # Google doesn't have AAC, use MP3
        "flac": "FLAC",
        "wav": "LINEAR16",
        "pcm": "LINEAR16",
    }

    def __init__(self) -> None:
        # Both bases are initialised explicitly rather than through super().
        BaseTextToSpeechConfig.__init__(self)
        VertexBase.__init__(self)

    def _map_voice_to_vertex_format(
        self,
        voice: Optional[Union[str, Dict]],
    ) -> Tuple[Optional[str], Optional[Dict]]:
        """
        Map voice to Vertex AI format.

        Supports:
        1. OpenAI voice names (alloy, echo, fable, onyx, nova, shimmer) - will be mapped
        2. Vertex AI voice names (en-US-Studio-O, en-US-Wavenet-D, etc.) - used directly
        3. Dict with languageCode and name - used as-is

        Returns:
            Tuple of (voice_str, voice_dict) where:
            - voice_str: Original string voice (for interface compatibility);
              None when ``voice`` is None or a dict
            - voice_dict: Vertex AI format dict with languageCode and name;
              None when ``voice`` is None
        """
        if voice is None:
            return None, None

        if isinstance(voice, dict):
            # Already in Vertex AI format
            return None, voice

        # voice is a string
        voice_str = voice

        # Map OpenAI voice if it's a known OpenAI voice, otherwise assume it is
        # already a Vertex AI voice name and use it directly.
        mapped_voice_name = self.VOICE_MAPPINGS.get(voice, voice)

        # Extract language code from voice name (e.g., "en-US-Studio-O" -> "en-US")
        parts = mapped_voice_name.split("-")
        if len(parts) >= 2:
            language_code = f"{parts[0]}-{parts[1]}"
        else:
            language_code = self.DEFAULT_LANGUAGE_CODE

        voice_dict = {
            "languageCode": language_code,
            "name": mapped_voice_name,
        }

        return voice_str, voice_dict

    def dispatch_text_to_speech(
        self,
        model: str,
        input: str,
        voice: Optional[Union[str, Dict]],
        optional_params: Dict,
        litellm_params_dict: Dict,
        logging_obj: "LiteLLMLoggingObj",
        timeout: Union[float, httpx.Timeout],
        extra_headers: Optional[Dict[str, Any]],
        base_llm_http_handler: Any,
        aspeech: bool,
        api_base: Optional[str],
        api_key: Optional[str],
        **kwargs: Any,
    ) -> Union[
        "HttpxBinaryResponseContent",
        Coroutine[Any, Any, "HttpxBinaryResponseContent"],
    ]:
        """
        Dispatch method to handle Vertex AI TTS requests

        Resolves Vertex AI credentials/project/location, stores them (and
        ``api_base``) in ``litellm_params_dict`` for the transform step, then
        delegates to ``base_llm_http_handler.text_to_speech_handler``.

        Args:
            voice: String voice name, or a dict in Vertex AI voice format.
                A non-empty dict is forwarded untouched through
                ``litellm_params_dict["vertex_voice_dict"]``.
            litellm_params_dict: Updated in place with the resolved values.
            base_llm_http_handler: The BaseLLMHTTPHandler instance from main.py
            aspeech: Forwarded as ``_is_async`` to the handler.

        Returns:
            Whatever ``text_to_speech_handler`` returns.
        """
        # Resolve Vertex AI credentials using VertexBase helpers
        vertex_credentials = self.safe_get_vertex_ai_credentials(litellm_params_dict)
        vertex_project = self.safe_get_vertex_ai_project(litellm_params_dict)
        vertex_location = self.safe_get_vertex_ai_location(litellm_params_dict)

        # Values the transform methods read back from litellm_params
        litellm_param_updates: Dict[str, Any] = {
            "vertex_credentials": vertex_credentials,
            "vertex_project": vertex_project,
            "vertex_location": vertex_location,
            "api_base": api_base,
        }

        # The handler interface takes a string voice, so a dict is reduced to
        # its name here.
        voice_str: Optional[str] = None
        if isinstance(voice, str):
            voice_str = voice
        elif isinstance(voice, dict) and voice:
            voice_str = voice.get("name")
            # Keep the full dict as well: reducing it to "name" alone would
            # discard languageCode and any other voice fields. The request
            # transform reads "vertex_voice_dict" from litellm_params first.
            litellm_param_updates["vertex_voice_dict"] = voice

        litellm_params_dict.update(litellm_param_updates)

        # Call the text_to_speech_handler
        response = base_llm_http_handler.text_to_speech_handler(
            model=model,
            input=input,
            voice=voice_str,
            text_to_speech_provider_config=self,
            text_to_speech_optional_params=optional_params,
            custom_llm_provider="vertex_ai",
            litellm_params=litellm_params_dict,
            logging_obj=logging_obj,
            timeout=timeout,
            extra_headers=extra_headers,
            client=None,
            _is_async=aspeech,
        )

        return response

    def get_supported_openai_params(self, model: str) -> list:
        """
        Vertex AI TTS supports these OpenAI parameters

        Note: Vertex AI also supports additional parameters like audioConfig
        which can be passed but are not part of the OpenAI spec
        """
        return ["voice", "response_format", "speed"]

    def map_openai_params(
        self,
        model: str,
        optional_params: Dict,
        voice: Optional[Union[str, Dict]] = None,
        drop_params: bool = False,
        kwargs: Optional[Dict] = None,
    ) -> Tuple[Optional[str], Dict]:
        """
        Map OpenAI parameters to Vertex AI TTS parameters

        Voice handling:
        - If voice is an OpenAI voice name (alloy, echo, etc.), it maps to a Vertex AI voice
        - If voice is already a Vertex AI voice name (en-US-Studio-O, etc.), it's used directly
        - If voice is a dict with languageCode and name, it's used as-is

        Note: the voice dict is stored in mapped_params["vertex_voice_dict"]
        because this method returns the voice as a string.

        Args:
            optional_params: Read for "response_format" and "speed".
            kwargs: Read for the Vertex-specific "audioConfig" and "use_ssml"
                keys. ``None`` is treated as an empty dict.

        Returns:
            Tuple of (mapped_voice_str, mapped_params)
        """
        # None default instead of a shared mutable `{}` default.
        kwargs = kwargs or {}
        mapped_params: Dict[str, Any] = {}

        ##########################################################
        # Map voice using helper
        ##########################################################
        mapped_voice_str, voice_dict = self._map_voice_to_vertex_format(voice)
        if voice_dict is not None:
            mapped_params["vertex_voice_dict"] = voice_dict

        # Map response format
        if "response_format" in optional_params:
            format_name = optional_params["response_format"]
            # Unknown names are passed through as-is (assumed to already be a
            # Google Cloud encoding name).
            mapped_params["audioEncoding"] = self.FORMAT_MAPPINGS.get(
                format_name, format_name
            )
        else:
            # Default to LINEAR16
            mapped_params["audioEncoding"] = self.DEFAULT_AUDIO_ENCODING

        # Map speed to speakingRate (stringified; None is skipped)
        if "speed" in optional_params:
            speed = optional_params["speed"]
            if speed is not None:
                mapped_params["speakingRate"] = str(speed)

        # Pass through Vertex AI-specific parameters from kwargs
        if "audioConfig" in kwargs:
            mapped_params["audioConfig"] = kwargs["audioConfig"]

        if "use_ssml" in kwargs:
            mapped_params["use_ssml"] = kwargs["use_ssml"]

        return mapped_voice_str, mapped_params

    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """
        Return a copy of ``headers`` with JSON content headers added.

        Note: Actual authentication is handled in transform_text_to_speech_request
        because Vertex AI requires OAuth2 token refresh
        """
        validated_headers = headers.copy()

        # Content-Type for JSON
        validated_headers["Content-Type"] = "application/json"
        validated_headers["charset"] = "UTF-8"

        return validated_headers

    def get_complete_url(
        self,
        model: str,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        Get the complete URL for Vertex AI TTS request

        A truthy ``api_base`` is returned unchanged; otherwise the default
        Google Cloud TTS endpoint is used:
        https://texttospeech.googleapis.com/v1/text:synthesize
        """
        if api_base:
            return api_base

        return self.TTS_API_URL

    def _validate_vertex_input(
        self,
        input_data: VertexTextToSpeechInput,
        optional_params: Dict,
    ) -> VertexTextToSpeechInput:
        """
        Validate and transform input for Vertex AI TTS

        Handles text vs SSML input detection and validation. ``input_data`` is
        modified in place and also returned.

        Raises:
            ValueError: If SSML is requested but no input exists, if no input
                remains after None values are dropped, or if both 'text' and
                'ssml' are present.
        """
        # Remove None values
        if input_data.get("text") is None:
            input_data.pop("text", None)
        if input_data.get("ssml") is None:
            input_data.pop("ssml", None)

        # Check if use_ssml is set
        use_ssml = optional_params.get("use_ssml", False)

        if use_ssml:
            if "text" in input_data:
                input_data["ssml"] = input_data.pop("text")
            elif "ssml" not in input_data:
                raise ValueError("SSML input is required when use_ssml is True.")
        else:
            # LiteLLM will auto-detect if text is in ssml format
            # check if "text" is an ssml - in this case we should pass it as ssml instead of text
            if input_data:
                _text = input_data.get("text", None) or ""
                if "<speak>" in _text:
                    input_data["ssml"] = input_data.pop("text")

        if not input_data:
            raise ValueError("Either 'text' or 'ssml' must be provided.")
        if "text" in input_data and "ssml" in input_data:
            raise ValueError(
                "Only one of 'text' or 'ssml' should be provided, not both."
            )

        return input_data

    def transform_text_to_speech_request(
        self,
        model: str,
        input: str,
        voice: Optional[str],
        optional_params: Dict,
        litellm_params: Dict,
        headers: dict,
    ) -> TextToSpeechRequestData:
        """
        Transform OpenAI TTS request to Vertex AI TTS format

        This method handles:
        1. Authentication with Vertex AI
        2. Building the request body
        3. Setting up headers (``headers`` is modified in place)

        Returns:
            TextToSpeechRequestData: Contains dict_body and headers
        """
        # Get Vertex AI credentials from litellm_params
        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES] = litellm_params.get(
            "vertex_credentials"
        )
        vertex_project: Optional[str] = litellm_params.get("vertex_project")

        ####### Authenticate with Vertex AI ########
        _auth_header, vertex_project = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai_beta",
        )

        # Only the token part of the result is used; the URL is discarded.
        auth_header, _ = self._get_token_and_url(
            model="",
            auth_header=_auth_header,
            gemini_api_key=None,
            vertex_credentials=vertex_credentials,
            vertex_project=vertex_project,
            vertex_location=litellm_params.get("vertex_location"),
            stream=False,
            custom_llm_provider="vertex_ai_beta",
            api_base=litellm_params.get("api_base"),
        )

        # Set authentication headers
        headers["Authorization"] = f"Bearer {auth_header}"
        headers["x-goog-user-project"] = vertex_project

        ####### Build the request ################
        vertex_input = VertexTextToSpeechInput(text=input)
        vertex_input = self._validate_vertex_input(vertex_input, optional_params)

        # Build voice configuration
        # Check for voice dict stored in:
        # 1. litellm_params by dispatch method
        # 2. optional_params by map_openai_params
        voice_dict = litellm_params.get("vertex_voice_dict") or optional_params.get(
            "vertex_voice_dict"
        )
        if voice_dict is not None and isinstance(voice_dict, dict):
            vertex_voice = VertexTextToSpeechVoice(**voice_dict)
        elif voice is not None and isinstance(voice, str):
            # No pre-mapped voice dict available: derive one with the shared
            # helper so OpenAI aliases (e.g. "alloy") are mapped and the
            # language code is parsed the same way as in map_openai_params.
            _, mapped_voice = self._map_voice_to_vertex_format(voice)
            vertex_voice = VertexTextToSpeechVoice(**(mapped_voice or {}))
        else:
            # Use defaults
            vertex_voice = VertexTextToSpeechVoice(
                languageCode=self.DEFAULT_LANGUAGE_CODE,
                name=self.DEFAULT_VOICE_NAME,
            )

        # Build audio configuration
        audio_encoding = optional_params.get(
            "audioEncoding", self.DEFAULT_AUDIO_ENCODING
        )
        speaking_rate = optional_params.get("speakingRate", self.DEFAULT_SPEAKING_RATE)

        # A full audioConfig in optional_params replaces the individual
        # audioEncoding / speakingRate values entirely.
        if "audioConfig" in optional_params:
            vertex_audio_config = VertexTextToSpeechAudioConfig(
                **optional_params["audioConfig"]
            )
        else:
            vertex_audio_config = VertexTextToSpeechAudioConfig(
                audioEncoding=audio_encoding,
                speakingRate=speaking_rate,
            )

        request_body: Dict[str, Any] = {
            "input": dict(vertex_input),
            "voice": dict(vertex_voice),
            "audioConfig": dict(vertex_audio_config),
        }

        return TextToSpeechRequestData(
            dict_body=request_body,
            headers=headers,
        )

    def transform_text_to_speech_response(
        self,
        model: str,
        raw_response: httpx.Response,
        logging_obj: "LiteLLMLoggingObj",
    ) -> "HttpxBinaryResponseContent":
        """
        Transform Vertex AI TTS response to standard format

        The response JSON carries base64-encoded audio under "audioContent".
        It is decoded and returned as HttpxBinaryResponseContent wrapping a
        new ``httpx.Response`` with status 200.

        Raises:
            ValueError: If "audioContent" is missing or empty.
        """
        # Imported here because the module-level name is only a typing alias
        # outside TYPE_CHECKING.
        from litellm.types.llms.openai import HttpxBinaryResponseContent

        # Parse JSON response
        _json_response = raw_response.json()

        # Get base64-encoded audio content
        response_content = _json_response.get("audioContent")
        if not response_content:
            raise ValueError("No audioContent in Vertex AI TTS response")

        # Decode base64 to get binary content
        binary_data = base64.b64decode(response_content)

        # Create an httpx.Response object with the binary data
        response = httpx.Response(
            status_code=200,
            content=binary_data,
        )

        # Initialize the HttpxBinaryResponseContent instance
        return HttpxBinaryResponseContent(response)
|
||||
@@ -0,0 +1,4 @@
|
||||
from .rag_api.transformation import VertexVectorStoreConfig
|
||||
from .search_api.transformation import VertexSearchAPIVectorStoreConfig
|
||||
|
||||
__all__ = ["VertexVectorStoreConfig", "VertexSearchAPIVectorStoreConfig"]
|
||||
@@ -0,0 +1,306 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.base_llm.vector_store.transformation import BaseVectorStoreConfig
|
||||
from litellm.llms.vertex_ai.common_utils import get_vertex_base_url
|
||||
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.vector_stores import (
|
||||
BaseVectorStoreAuthCredentials,
|
||||
VectorStoreCreateOptionalRequestParams,
|
||||
VectorStoreCreateResponse,
|
||||
VectorStoreIndexEndpoints,
|
||||
VectorStoreResultContent,
|
||||
VectorStoreSearchOptionalRequestParams,
|
||||
VectorStoreSearchResponse,
|
||||
VectorStoreSearchResult,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class VertexVectorStoreConfig(BaseVectorStoreConfig, VertexBase):
    """
    Configuration for Vertex AI Vector Store RAG API

    This implementation uses the Vertex AI RAG Engine API for vector store operations:
    ``:retrieveContexts`` for search and ``/ragCorpora`` for creation.
    """

    def __init__(self):
        super().__init__()

    def get_auth_credentials(
        self, litellm_params: dict
    ) -> BaseVectorStoreAuthCredentials:
        """
        Build the auth headers for a request.

        Returns:
            A dict with a single "headers" entry containing the bearer
            ``Authorization`` header and a JSON ``Content-Type``.
        """
        # Each helper gets its own shallow copy of the params.
        vertex_credentials = self.get_vertex_ai_credentials(dict(litellm_params))
        vertex_project = self.get_vertex_ai_project(dict(litellm_params))

        # Get access token using the base class method; the project id it
        # returns is not needed here.
        access_token, _ = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai",
        )

        return {
            "headers": {
                "Authorization": f"Bearer {access_token}",
                "Content-Type": "application/json",
            },
        }

    def get_vector_store_endpoints_by_type(self) -> VectorStoreIndexEndpoints:
        """Return the (method, path-suffix) pairs used for read and write operations."""
        return {
            "read": [("POST", ":retrieveContexts")],
            "write": [("POST", "/ragCorpora")],
        }

    def validate_environment(
        self, headers: dict, litellm_params: Optional[GenericLiteLLMParams]
    ) -> dict:
        """
        Validate and set up authentication for Vertex AI RAG API

        ``headers`` is updated in place with the auth headers and returned.
        """
        litellm_params = litellm_params or GenericLiteLLMParams()

        auth_headers = self.get_auth_credentials(litellm_params.model_dump())
        headers.update(auth_headers.get("headers", {}))
        return headers

    def get_complete_url(
        self,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        Get the Base endpoint for Vertex AI RAG API

        A truthy ``api_base`` wins (trailing slashes stripped); otherwise the
        URL is built from the resolved project and location.
        """
        vertex_location = self.get_vertex_ai_location(litellm_params)
        vertex_project = self.get_vertex_ai_project(litellm_params)

        if api_base:
            return api_base.rstrip("/")

        # Vertex AI RAG API endpoint for retrieveContexts
        base_url = get_vertex_base_url(vertex_location)
        return f"{base_url}/v1/projects/{vertex_project}/locations/{vertex_location}"

    def transform_search_vector_store_request(
        self,
        vector_store_id: str,
        query: Union[str, List[str]],
        vector_store_search_optional_params: VectorStoreSearchOptionalRequestParams,
        api_base: str,
        litellm_logging_obj: LiteLLMLoggingObj,
        litellm_params: dict,
    ) -> Tuple[str, Dict[str, Any]]:
        """
        Transform search request for Vertex AI RAG API

        Args:
            vector_store_id: Either a full corpus path starting with
                "projects/" or a bare corpus ID.
            query: Search text; a list is joined with single spaces.
            vector_store_search_optional_params: Read for "max_num_results",
                "filters" and "ranking_options".
            litellm_logging_obj: Its ``model_call_details["query"]`` is set to
                the (joined) query so the response transform can echo it.

        Returns:
            Tuple of (url, request_body).
        """
        # Convert query to string if it's a list
        if isinstance(query, list):
            query = " ".join(query)

        # Vertex AI RAG API endpoint for retrieving contexts
        url = f"{api_base}:retrieveContexts"

        # Use helper methods to get project and location, then construct full rag corpus path
        vertex_project = self.get_vertex_ai_project(litellm_params)
        vertex_location = self.get_vertex_ai_location(litellm_params)

        # Handle both full corpus path and just corpus ID
        if vector_store_id.startswith("projects/"):
            # Already a full path
            full_rag_corpus = vector_store_id
        else:
            # Just the corpus ID, construct full path
            full_rag_corpus = f"projects/{vertex_project}/locations/{vertex_location}/ragCorpora/{vector_store_id}"

        # Build the request body for Vertex AI RAG API
        request_body: Dict[str, Any] = {
            "vertex_rag_store": {"rag_resources": [{"rag_corpus": full_rag_corpus}]},
            "query": {"text": query},
        }

        #########################################################
        # Update logging object with details of the request
        #########################################################
        litellm_logging_obj.model_call_details["query"] = query

        # Add optional parameters. "rag_retrieval_config" is only created when
        # at least one of them is provided.
        query_body = request_body["query"]

        max_num_results = vector_store_search_optional_params.get("max_num_results")
        if max_num_results is not None:
            query_body["rag_retrieval_config"] = {"top_k": max_num_results}

        # Add filters if provided
        filters = vector_store_search_optional_params.get("filters")
        if filters is not None:
            query_body.setdefault("rag_retrieval_config", {})["filter"] = filters

        # Add ranking options if provided
        ranking_options = vector_store_search_optional_params.get("ranking_options")
        if ranking_options is not None:
            query_body.setdefault("rag_retrieval_config", {})[
                "ranking"
            ] = ranking_options

        return url, request_body

    def transform_search_vector_store_response(
        self, response: httpx.Response, litellm_logging_obj: LiteLLMLoggingObj
    ) -> VectorStoreSearchResponse:
        """
        Transform Vertex AI RAG API response to standard vector store search response

        Raises:
            The provider error class (from ``get_error_class``) carrying the
            response status code and headers if parsing or conversion fails.
        """
        try:
            response_json = response.json()
            # Extract contexts from Vertex AI response - handle nested structure
            contexts = response_json.get("contexts", {}).get("contexts", [])

            # Transform contexts to standard format
            search_results = []
            for context in contexts:
                content = [
                    VectorStoreResultContent(
                        text=context.get("text", ""),
                        type="text",
                    )
                ]

                # Extract file information
                source_uri = context.get("sourceUri", "")
                source_display_name = context.get("sourceDisplayName", "")

                # Generate file_id from source URI or use display name as fallback
                file_id = source_uri if source_uri else source_display_name
                filename = (
                    source_display_name if source_display_name else "Unknown Document"
                )

                # Build attributes with available metadata
                attributes = {}
                if source_uri:
                    attributes["sourceUri"] = source_uri
                if source_display_name:
                    attributes["sourceDisplayName"] = source_display_name

                # Add page span information if available
                page_span = context.get("pageSpan", {})
                if page_span:
                    attributes["pageSpan"] = page_span

                result = VectorStoreSearchResult(
                    score=context.get("score", 0.0),
                    content=content,
                    file_id=file_id,
                    filename=filename,
                    attributes=attributes,
                )
                search_results.append(result)

            return VectorStoreSearchResponse(
                object="vector_store.search_results.page",
                # Echo the query recorded by the request transform.
                search_query=litellm_logging_obj.model_call_details.get("query", ""),
                data=search_results,
            )

        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise self.get_error_class(
                error_message=str(e),
                status_code=response.status_code,
                headers=response.headers,
            ) from e

    def transform_create_vector_store_request(
        self,
        vector_store_create_optional_params: VectorStoreCreateOptionalRequestParams,
        api_base: str,
    ) -> Tuple[str, Dict[str, Any]]:
        """
        Transform create request for Vertex AI RAG Corpus

        Returns:
            Tuple of (url, request_body). "name" becomes the display name
            (default "litellm-vector-store") and "metadata", when given, is
            sent as "labels".
        """
        url = f"{api_base}/ragCorpora"  # Base URL for creating RAG corpus

        # Build the request body for Vertex AI RAG Corpus creation
        request_body: Dict[str, Any] = {
            "display_name": vector_store_create_optional_params.get(
                "name", "litellm-vector-store"
            ),
            "description": "Vector store created via LiteLLM",
        }

        # Add metadata if provided
        metadata = vector_store_create_optional_params.get("metadata")
        if metadata is not None:
            request_body["labels"] = metadata

        return url, request_body

    def transform_create_vector_store_response(
        self, response: httpx.Response
    ) -> VectorStoreCreateResponse:
        """
        Transform Vertex AI RAG Corpus creation response to standard vector store response

        Raises:
            The provider error class (from ``get_error_class``) carrying the
            response status code and headers if parsing or conversion fails.
        """
        try:
            response_json = response.json()

            # Extract the corpus ID from the response name (last path segment)
            corpus_name = response_json.get("name", "")
            corpus_id = (
                corpus_name.split("/")[-1] if "/" in corpus_name else corpus_name
            )

            # Handle createTime conversion
            create_time = response_json.get("createTime", 0)
            if isinstance(create_time, str):
                # Convert ISO timestamp to Unix timestamp
                from datetime import datetime

                try:
                    dt = datetime.fromisoformat(create_time.replace("Z", "+00:00"))
                    create_time = int(dt.timestamp())
                except ValueError:
                    # Unparseable timestamp: fall back to 0 rather than fail.
                    create_time = 0
            elif not isinstance(create_time, int):
                create_time = 0

            # Handle labels safely
            labels = response_json.get("labels", {})
            metadata = labels if isinstance(labels, dict) else {}

            # This method already reads "createTime" in lowerCamelCase, so the
            # display name is read as "displayName" first; the snake_case key
            # is kept as a fallback for payloads that use it.
            display_name = response_json.get("displayName") or response_json.get(
                "display_name", ""
            )

            return VectorStoreCreateResponse(
                id=corpus_id,
                object="vector_store",
                created_at=create_time,
                name=display_name,
                bytes=0,  # Vertex AI doesn't provide byte count in the same way
                file_counts={
                    "in_progress": 0,
                    "completed": 0,
                    "failed": 0,
                    "cancelled": 0,
                    "total": 0,
                },
                status="completed",  # Vertex AI corpus creation is typically synchronous
                expires_after=None,
                expires_at=None,
                last_active_at=None,
                metadata=metadata,
            )

        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise self.get_error_class(
                error_message=str(e),
                status_code=response.status_code,
                headers=response.headers,
            ) from e
|
||||
@@ -0,0 +1,267 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm import get_model_info
|
||||
from litellm.llms.base_llm.vector_store.transformation import BaseVectorStoreConfig
|
||||
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.vector_stores import (
|
||||
BaseVectorStoreAuthCredentials,
|
||||
VectorStoreCreateOptionalRequestParams,
|
||||
VectorStoreCreateResponse,
|
||||
VectorStoreIndexEndpoints,
|
||||
VectorStoreResultContent,
|
||||
VectorStoreSearchOptionalRequestParams,
|
||||
VectorStoreSearchResponse,
|
||||
VectorStoreSearchResult,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class VertexSearchAPIVectorStoreConfig(BaseVectorStoreConfig, VertexBase):
|
||||
"""
|
||||
Configuration for Vertex AI Search API Vector Store
|
||||
|
||||
This implementation uses the Vertex AI Search API for vector store operations.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
    """Initialise through the ``super()`` chain; no state is added here."""
    super().__init__()
|
||||
|
||||
def get_auth_credentials(
|
||||
self, litellm_params: dict
|
||||
) -> BaseVectorStoreAuthCredentials:
|
||||
# Get credentials and project info
|
||||
vertex_credentials = self.get_vertex_ai_credentials(dict(litellm_params))
|
||||
vertex_project = self.get_vertex_ai_project(dict(litellm_params))
|
||||
|
||||
# Get access token using the base class method
|
||||
access_token, project_id = self._ensure_access_token(
|
||||
credentials=vertex_credentials,
|
||||
project_id=vertex_project,
|
||||
custom_llm_provider="vertex_ai",
|
||||
)
|
||||
|
||||
return {
|
||||
"headers": {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
}
|
||||
|
||||
def get_vector_store_endpoints_by_type(self) -> VectorStoreIndexEndpoints:
|
||||
return {
|
||||
"read": [("POST", ":search")],
|
||||
"write": [],
|
||||
}
|
||||
|
||||
def validate_environment(
|
||||
self, headers: dict, litellm_params: Optional[GenericLiteLLMParams]
|
||||
) -> dict:
|
||||
"""
|
||||
Validate and set up authentication for Vertex AI RAG API
|
||||
"""
|
||||
litellm_params = litellm_params or GenericLiteLLMParams()
|
||||
auth_headers = self.get_auth_credentials(litellm_params.model_dump())
|
||||
headers.update(auth_headers.get("headers", {}))
|
||||
return headers
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
litellm_params: dict,
|
||||
) -> str:
|
||||
"""
|
||||
Get the Base endpoint for Vertex AI Search API
|
||||
"""
|
||||
vertex_location = self.get_vertex_ai_location(litellm_params)
|
||||
vertex_project = self.get_vertex_ai_project(litellm_params)
|
||||
collection_id = (
|
||||
litellm_params.get("vertex_collection_id") or "default_collection"
|
||||
)
|
||||
datastore_id = litellm_params.get("vector_store_id")
|
||||
if not datastore_id:
|
||||
raise ValueError("vector_store_id is required")
|
||||
if api_base:
|
||||
return api_base.rstrip("/")
|
||||
|
||||
# Vertex AI Search API endpoint for search
|
||||
return (
|
||||
f"https://discoveryengine.googleapis.com/v1/"
|
||||
f"projects/{vertex_project}/locations/{vertex_location}/"
|
||||
f"collections/{collection_id}/dataStores/{datastore_id}/servingConfigs/default_config"
|
||||
)
|
||||
|
||||
def transform_search_vector_store_request(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
query: Union[str, List[str]],
|
||||
vector_store_search_optional_params: VectorStoreSearchOptionalRequestParams,
|
||||
api_base: str,
|
||||
litellm_logging_obj: LiteLLMLoggingObj,
|
||||
litellm_params: dict,
|
||||
) -> Tuple[str, Dict[str, Any]]:
|
||||
"""
|
||||
Transform search request for Vertex AI RAG API
|
||||
"""
|
||||
# Convert query to string if it's a list
|
||||
if isinstance(query, list):
|
||||
query = " ".join(query)
|
||||
|
||||
# Vertex AI RAG API endpoint for retrieving contexts
|
||||
url = f"{api_base}:search"
|
||||
|
||||
# Construct full rag corpus path
|
||||
# Build the request body for Vertex AI Search API
|
||||
request_body = {"query": query, "pageSize": 10}
|
||||
|
||||
#########################################################
|
||||
# Update logging object with details of the request
|
||||
#########################################################
|
||||
litellm_logging_obj.model_call_details["query"] = query
|
||||
|
||||
return url, request_body
|
||||
|
||||
def transform_search_vector_store_response(
|
||||
self, response: httpx.Response, litellm_logging_obj: LiteLLMLoggingObj
|
||||
) -> VectorStoreSearchResponse:
|
||||
"""
|
||||
Transform Vertex AI Search API response to standard vector store search response
|
||||
|
||||
Handles the format from Discovery Engine Search API which returns:
|
||||
{
|
||||
"results": [
|
||||
{
|
||||
"id": "...",
|
||||
"document": {
|
||||
"derivedStructData": {
|
||||
"title": "...",
|
||||
"link": "...",
|
||||
"snippets": [...]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
"""
|
||||
try:
|
||||
response_json = response.json()
|
||||
|
||||
# Extract results from Vertex AI Search API response
|
||||
results = response_json.get("results", [])
|
||||
|
||||
# Transform results to standard format
|
||||
search_results: List[VectorStoreSearchResult] = []
|
||||
for result in results:
|
||||
document = result.get("document", {})
|
||||
derived_data = document.get("derivedStructData", {})
|
||||
|
||||
# Extract text content from snippets
|
||||
snippets = derived_data.get("snippets", [])
|
||||
text_content = ""
|
||||
|
||||
if snippets:
|
||||
# Combine all snippets into one text
|
||||
text_parts = [
|
||||
snippet.get("snippet", snippet.get("htmlSnippet", ""))
|
||||
for snippet in snippets
|
||||
]
|
||||
text_content = " ".join(text_parts)
|
||||
|
||||
# If no snippets, use title as fallback
|
||||
if not text_content:
|
||||
text_content = derived_data.get("title", "")
|
||||
|
||||
content = [
|
||||
VectorStoreResultContent(
|
||||
text=text_content,
|
||||
type="text",
|
||||
)
|
||||
]
|
||||
|
||||
# Extract file/document information
|
||||
document_link = derived_data.get("link", "")
|
||||
document_title = derived_data.get("title", "")
|
||||
document_id = result.get("id", "")
|
||||
|
||||
# Use link as file_id if available, otherwise use document ID
|
||||
file_id = document_link if document_link else document_id
|
||||
filename = document_title if document_title else "Unknown Document"
|
||||
|
||||
# Build attributes with available metadata
|
||||
attributes = {
|
||||
"document_id": document_id,
|
||||
}
|
||||
|
||||
if document_link:
|
||||
attributes["link"] = document_link
|
||||
if document_title:
|
||||
attributes["title"] = document_title
|
||||
|
||||
# Add display link if available
|
||||
display_link = derived_data.get("displayLink", "")
|
||||
if display_link:
|
||||
attributes["displayLink"] = display_link
|
||||
|
||||
# Add formatted URL if available
|
||||
formatted_url = derived_data.get("formattedUrl", "")
|
||||
if formatted_url:
|
||||
attributes["formattedUrl"] = formatted_url
|
||||
|
||||
# Note: Search API doesn't provide explicit scores in the response
|
||||
# You can use the position/rank as an implicit score
|
||||
score = 1.0 / (
|
||||
float(search_results.__len__() + 1)
|
||||
) # Decreasing score based on position
|
||||
|
||||
result_obj = VectorStoreSearchResult(
|
||||
score=score,
|
||||
content=content,
|
||||
file_id=file_id,
|
||||
filename=filename,
|
||||
attributes=attributes,
|
||||
)
|
||||
search_results.append(result_obj)
|
||||
|
||||
return VectorStoreSearchResponse(
|
||||
object="vector_store.search_results.page",
|
||||
search_query=litellm_logging_obj.model_call_details.get("query", ""),
|
||||
data=search_results,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise self.get_error_class(
|
||||
error_message=str(e),
|
||||
status_code=response.status_code,
|
||||
headers=response.headers,
|
||||
)
|
||||
|
||||
def transform_create_vector_store_request(
|
||||
self,
|
||||
vector_store_create_optional_params: VectorStoreCreateOptionalRequestParams,
|
||||
api_base: str,
|
||||
) -> Tuple[str, Dict]:
|
||||
raise NotImplementedError
|
||||
|
||||
def transform_create_vector_store_response(
|
||||
self, response: httpx.Response
|
||||
) -> VectorStoreCreateResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
def calculate_vector_store_cost(
|
||||
self,
|
||||
response: VectorStoreSearchResponse,
|
||||
) -> Tuple[float, float]:
|
||||
model_info = get_model_info(
|
||||
model="vertex_ai/search_api",
|
||||
)
|
||||
|
||||
input_cost_per_query = model_info.get("input_cost_per_query") or 0.0
|
||||
return input_cost_per_query, 0.0
|
||||
@@ -0,0 +1,123 @@
|
||||
"""
|
||||
AWS Workload Identity Federation (WIF) auth for Vertex AI.
|
||||
|
||||
Handles explicit AWS credentials for GCP WIF token exchange,
|
||||
bypassing the EC2 instance metadata service.
|
||||
|
||||
When aws_* keys are present in the WIF credential JSON, this module
|
||||
uses BaseAWSLLM to obtain AWS credentials and wraps them in a custom
|
||||
AwsSecurityCredentialsSupplier for google-auth.
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
|
||||
# Text of the ImportError raised when `google.auth.aws` cannot be imported.
GOOGLE_IMPORT_ERROR_MESSAGE = (
    "Google Cloud SDK not found. Install it with: pip install 'litellm[google]' "
    "or pip install google-cloud-aiplatform"
)

# AWS params recognized in WIF credential JSON for explicit auth.
# These match the kwargs accepted by BaseAWSLLM.get_credentials().
# Kept as a frozenset: it is only used for membership filtering and is never
# mutated.
_AWS_CREDENTIAL_KEYS = frozenset(
    {
        "aws_access_key_id",
        "aws_secret_access_key",
        "aws_session_token",
        "aws_region_name",
        "aws_session_name",
        "aws_profile_name",
        "aws_role_name",
        "aws_web_identity_token",
        "aws_sts_endpoint",
        "aws_external_id",
    }
)
|
||||
|
||||
|
||||
class VertexAIAwsWifAuth:
    """
    Builds GCP Workload Identity Federation credentials for Vertex AI from
    explicitly supplied AWS credentials instead of the EC2 metadata service.
    """

    @staticmethod
    def extract_aws_params(json_obj: dict) -> Dict[str, str]:
        """
        Pick the LiteLLM-specific aws_* entries out of a WIF credential JSON dict.

        Only keys listed in _AWS_CREDENTIAL_KEYS are considered. The result
        maps each recognized key present in ``json_obj`` to its value, and is
        empty when none are present.
        """
        recognized: Dict[str, str] = {}
        for name in _AWS_CREDENTIAL_KEYS:
            if name in json_obj:
                recognized[name] = json_obj[name]
        return recognized

    @staticmethod
    def credentials_from_explicit_aws(json_obj, aws_params, scopes):
        """
        Create GCP credentials for WIF backed by explicit AWS credentials.

        AWS credentials are obtained through BaseAWSLLM.get_credentials() and
        handed to google-auth via an AwsCredentialsSupplier, with
        ``credential_source`` left as None (no metadata endpoints).

        Args:
            json_obj: The WIF credential JSON dict (audience, token_url, etc.)
            aws_params: Dict of aws_* params extracted from json_obj
            scopes: OAuth scopes for the GCP credentials

        Raises:
            ImportError: If ``google.auth.aws`` is not importable.
            ValueError: If ``aws_params`` has no "aws_region_name".
        """
        try:
            from google.auth import aws
        except ImportError:
            raise ImportError(GOOGLE_IMPORT_ERROR_MESSAGE)

        from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
        from litellm.llms.vertex_ai.aws_credentials_supplier import (
            AwsCredentialsSupplier,
        )

        # The region is checked before anything touches AWS, so a
        # misconfigured file fails fast without AWS API calls.
        region = aws_params.get("aws_region_name")
        if not region:
            raise ValueError(
                "aws_region_name is required in the WIF credential JSON "
                "when using explicit AWS authentication. Add "
                '"aws_region_name": "<your-region>" to your credential file.'
            )

        aws_client = BaseAWSLLM()
        # Private snapshot: later changes to the caller's dict do not leak in.
        params_snapshot = dict(aws_params)

        def _resolve_aws_credentials():
            # Looked up each time the supplier invokes this callable, rather
            # than once up front.
            return aws_client.get_credentials(**params_snapshot)

        supplier = AwsCredentialsSupplier(
            credentials_provider=_resolve_aws_credentials,
            aws_region=region,
        )

        # Keyword arguments for aws.Credentials, forwarding optional JSON fields.
        gcp_kwargs = {
            "audience": json_obj.get("audience"),
            "subject_token_type": json_obj.get("subject_token_type"),
            "token_url": json_obj.get("token_url"),
            "credential_source": None,  # Not using metadata endpoints
            "aws_security_credentials_supplier": supplier,
            "service_account_impersonation_url": json_obj.get(
                "service_account_impersonation_url"
            ),
        }
        # universe_domain is only passed when the JSON sets it explicitly.
        if "universe_domain" in json_obj:
            gcp_kwargs["universe_domain"] = json_obj["universe_domain"]

        gcp_creds = aws.Credentials(**gcp_kwargs)

        # Scopes are applied only when requested and the credentials report
        # that they still need them.
        if not scopes or not getattr(gcp_creds, "requires_scopes", False):
            return gcp_creds
        return gcp_creds.with_scopes(scopes)
|
||||
@@ -0,0 +1,788 @@
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Callable, Optional, cast
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
||||
from litellm.llms.bedrock.common_utils import ModelResponseIterator
|
||||
from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS
|
||||
from litellm.types.llms.vertex_ai import *
|
||||
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
|
||||
|
||||
|
||||
class VertexAIError(Exception):
    """
    Error raised by the Vertex AI handlers in this module.

    Attributes:
        status_code: HTTP-style status code supplied by the raiser.
        message: Human-readable error text (also the exception message).
        request: Synthetic ``httpx.Request`` pointing at a placeholder URL.
        response: Synthetic ``httpx.Response`` carrying ``status_code``.
    """

    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        # Placeholder request. The URL must not carry leading whitespace:
        # with it, httpx parses the value as a relative path with no host.
        self.request = httpx.Request(
            method="POST", url="https://cloud.google.com/vertex-ai/"
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
class TextStreamer:
    """
    Fake streaming iterator for Vertex AI Model Garden calls

    Splits an already complete response on whitespace and hands the words out
    one at a time. The sync and async iterator protocols share the same
    cursor (``index``) over the same word list (``text``).
    """

    def __init__(self, text):
        # Words are the streaming unit.
        self.text = text.split()
        self.index = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= len(self.text):
            raise StopIteration
        word = self.text[self.index]
        self.index += 1
        return word

    def __aiter__(self):
        return self

    async def __anext__(self):
        if self.index >= len(self.text):
            # Nothing left to stream.
            raise StopAsyncIteration
        word = self.text[self.index]
        self.index += 1
        return word
|
||||
|
||||
|
||||
def _get_client_cache_key(
|
||||
model: str, vertex_project: Optional[str], vertex_location: Optional[str]
|
||||
):
|
||||
_cache_key = f"{model}-{vertex_project}-{vertex_location}"
|
||||
return _cache_key
|
||||
|
||||
|
||||
def _get_client_from_cache(client_cache_key: str):
    """
    Look up a cached Vertex model/client object in litellm's in-memory
    client cache and return whatever the cache yields for that key.
    """
    clients_cache = litellm.in_memory_llm_clients_cache
    return clients_cache.get_cache(client_cache_key)
|
||||
|
||||
|
||||
def _set_client_in_cache(client_cache_key: str, vertex_llm_model: Any):
    """
    Store a Vertex model/client object in litellm's in-memory client cache
    under ``client_cache_key``, using the default TTL for HTTPX clients.
    """
    cache_entry = {
        "key": client_cache_key,
        "value": vertex_llm_model,
        "ttl": _DEFAULT_TTL_FOR_HTTPX_CLIENTS,
    }
    litellm.in_memory_llm_clients_cache.set_cache(**cache_entry)
|
||||
|
||||
|
||||
def completion(  # noqa: PLR0915
    model: str,
    messages: list,
    model_response: ModelResponse,
    print_verbose: Callable,
    encoding,
    logging_obj,
    optional_params: dict,
    vertex_project=None,
    vertex_location=None,
    vertex_credentials=None,
    litellm_params=None,
    logger_fn=None,
    acompletion: bool = False,
):
    """
    NON-GEMINI/ANTHROPIC CALLS.

    This is the handler for OLDER PALM MODELS and VERTEX AI MODEL GARDEN

    For Vertex AI Anthropic: `vertex_anthropic.py`
    For Gemini: `vertex_httpx.py`

    Args:
        model: Model name used to pick the SDK class. The literal "private"
            selects the private-endpoint mode, after which `model` is replaced
            by `optional_params["model_id"]`. Names not found in any litellm
            Vertex model list are treated as a Model Garden endpoint id.
        messages: Chat messages; only string `content` values are used, joined
            with spaces into a single prompt.
        model_response: Response object that is filled in and returned for
            non-streaming calls.
        print_verbose: Callable used for debug output.
        encoding: Tokenizer exposing `.encode`, used for the token-count
            fallback.
        logging_obj: Logging object; `pre_call` / `post_call` are invoked on it.
        optional_params: Provider params. Mutated in place: config defaults
            are merged in and "safety_settings", "stream" and "model_id" may
            be popped.
        vertex_project: GCP project.
        vertex_location: GCP region; also used to build the API endpoint host.
        vertex_credentials: When a str, parsed as service-account JSON;
            otherwise `google.auth.default` is used.
        litellm_params: Not referenced in this function body.
        logger_fn: Not referenced in this function body.
        acompletion: When True, hand off to `async_streaming` /
            `async_completion` and return their result un-awaited.

    Returns:
        Depending on the path taken: the populated `model_response`, the SDK's
        streaming response, a `TextStreamer`, a `ModelResponseIterator`, or
        the result of the async handlers.

    Raises:
        VertexAIError: If `vertexai` cannot be imported or fails the version
            check.
        litellm.APIConnectionError: Wraps any other exception raised inside
            the main body.
    """
    try:
        import vertexai
    except Exception:
        raise VertexAIError(
            status_code=400,
            message="vertexai import failed please run `pip install google-cloud-aiplatform`. This is required for the 'vertex_ai/' route on LiteLLM",
        )

    # NOTE(review): because of `or`, the second hasattr is evaluated only when
    # `vertexai` has no `preview` attribute, in which case `vertexai.preview`
    # raises AttributeError instead of the VertexAIError below. Confirm the
    # intended check.
    if not (
        hasattr(vertexai, "preview") or hasattr(vertexai.preview, "language_models")
    ):
        raise VertexAIError(
            status_code=400,
            message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""",
        )
    try:
        import google.auth  # type: ignore
        from google.cloud import aiplatform  # type: ignore
        from google.cloud.aiplatform_v1beta1.types import (
            content as gapic_content_types,  # type: ignore
        )
        from google.protobuf import json_format  # type: ignore
        from google.protobuf.struct_pb2 import Value  # type: ignore
        from vertexai.language_models import CodeGenerationModel, TextGenerationModel
        from vertexai.preview.generative_models import GenerativeModel
        from vertexai.preview.language_models import ChatModel, CodeChatModel

        ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
        print_verbose(
            f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}"
        )

        # Reuse a previously cached model object for this model/project/location, if any.
        _cache_key = _get_client_cache_key(
            model=model, vertex_project=vertex_project, vertex_location=vertex_location
        )
        _vertex_llm_model_object = _get_client_from_cache(client_cache_key=_cache_key)

        # Load credentials - needed for both vertexai.init() and PredictionServiceClient
        from google.auth.credentials import Credentials

        if vertex_credentials is not None and isinstance(vertex_credentials, str):
            import google.oauth2.service_account

            # A string is parsed as service-account JSON content.
            json_obj = json.loads(vertex_credentials)

            creds = google.oauth2.service_account.Credentials.from_service_account_info(
                json_obj,
                scopes=["https://www.googleapis.com/auth/cloud-platform"],
            )
        else:
            # Fall back to application default credentials.
            creds, _ = google.auth.default(quota_project_id=vertex_project)

        # The SDK is only (re)initialized when there was no cached model object.
        if _vertex_llm_model_object is None:
            print_verbose(
                f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}"
            )
            vertexai.init(
                project=vertex_project,
                location=vertex_location,
                credentials=cast(Credentials, creds),
            )

        ## Load Config
        # Config values act as defaults: they never override caller-supplied params.
        config = litellm.VertexAIConfig.get_config()
        for k, v in config.items():
            if k not in optional_params:
                optional_params[k] = v

        ## Process safety settings into format expected by vertex AI
        safety_settings = None
        if "safety_settings" in optional_params:
            safety_settings = optional_params.pop("safety_settings")
            if not isinstance(safety_settings, list):
                raise ValueError("safety_settings must be a list")
            # Only the first element's type is checked.
            if len(safety_settings) > 0 and not isinstance(safety_settings[0], dict):
                raise ValueError("safety_settings must be a list of dicts")
            safety_settings = [
                gapic_content_types.SafetySetting(x) for x in safety_settings
            ]

        # vertexai does not use an API key, it looks for credentials.json in the environment

        # Flatten the conversation: messages whose content is not a plain
        # string are dropped from the prompt.
        prompt = " ".join(
            [
                message.get("content")
                for message in messages
                if isinstance(message.get("content", None), str)
            ]
        )

        mode = ""

        # request_str accumulates a pseudo-code trace of the SDK calls for logging.
        request_str = ""
        # NOTE(review): response_obj is never reassigned in this sync function,
        # so the `response_obj is not None` branches further down do not run here.
        response_obj = None
        instances = None
        client_options = {
            "api_endpoint": f"{vertex_location}-aiplatform.googleapis.com"
        }
        # fake_stream: the full response is fetched first and wrapped in an
        # iterator at the end, instead of calling an SDK streaming method.
        fake_stream = False
        if (
            model in litellm.vertex_language_models
            or model in litellm.vertex_vision_models
        ):
            llm_model: Any = _vertex_llm_model_object or GenerativeModel(model)
            # NOTE(review): no sync branch below handles mode == "vision", so
            # on this path completion_response stays None and the response
            # content becomes the string "None". Confirm whether this route is
            # still reachable.
            mode = "vision"
            request_str += f"llm_model = GenerativeModel({model})\n"
        elif model in litellm.vertex_chat_models:
            llm_model = _vertex_llm_model_object or ChatModel.from_pretrained(model)
            mode = "chat"
            request_str += f"llm_model = ChatModel.from_pretrained({model})\n"
        elif model in litellm.vertex_text_models:
            llm_model = _vertex_llm_model_object or TextGenerationModel.from_pretrained(
                model
            )
            mode = "text"
            request_str += f"llm_model = TextGenerationModel.from_pretrained({model})\n"
        elif model in litellm.vertex_code_text_models:
            llm_model = _vertex_llm_model_object or CodeGenerationModel.from_pretrained(
                model
            )
            mode = "text"
            request_str += f"llm_model = CodeGenerationModel.from_pretrained({model})\n"
            fake_stream = True
        elif model in litellm.vertex_code_chat_models:  # vertex_code_llm_models
            llm_model = _vertex_llm_model_object or CodeChatModel.from_pretrained(model)
            mode = "chat"
            request_str += f"llm_model = CodeChatModel.from_pretrained({model})\n"
        elif model == "private":
            mode = "private"
            # The real endpoint name is passed via optional_params["model_id"].
            model = optional_params.pop("model_id", None)
            # private endpoint requires a dict instead of JSON
            instances = [optional_params.copy()]
            instances[0]["prompt"] = prompt
            llm_model = aiplatform.PrivateEndpoint(
                endpoint_name=model,
                project=vertex_project,
                location=vertex_location,
            )
            request_str += f"llm_model = aiplatform.PrivateEndpoint(endpoint_name={model}, project={vertex_project}, location={vertex_location})\n"
        else:  # assume vertex model garden on public endpoint
            mode = "custom"

            # The single instance is optional_params plus the prompt, converted
            # to protobuf Values.
            instances = [optional_params.copy()]
            instances[0]["prompt"] = prompt
            instances = [
                json_format.ParseDict(instance_dict, Value())  # type: ignore[misc]
                for instance_dict in instances
            ]
            # Will determine the API used based on async parameter
            llm_model = None

        # NOTE: async prediction and streaming under "private" mode isn't supported by aiplatform right now
        if acompletion is True:
            # Everything resolved so far is handed to the async handlers;
            # remaining optional_params are spread in as extra keyword arguments.
            data = {
                "llm_model": llm_model,
                "mode": mode,
                "prompt": prompt,
                "logging_obj": logging_obj,
                "request_str": request_str,
                "model": model,
                "model_response": model_response,
                "encoding": encoding,
                "messages": messages,
                "print_verbose": print_verbose,
                "client_options": client_options,
                "instances": instances,
                "vertex_location": vertex_location,
                "vertex_project": vertex_project,
                "vertex_credentials": creds,
                "safety_settings": safety_settings,
                **optional_params,
            }
            if optional_params.get("stream", False) is True:
                # async streaming
                return async_streaming(**data)

            return async_completion(**data)

        completion_response = None

        stream = optional_params.pop(
            "stream", None
        )  # See note above on handling streaming for vertex ai
        if mode == "chat":
            chat = llm_model.start_chat()
            request_str += "chat = llm_model.start_chat()\n"

            if fake_stream is not True and stream is True:
                # NOTE: VertexAI does not accept stream=True as a param and raises an error,
                # we handle this by removing 'stream' from optional params and sending the request
                # after we get the response we add optional_params["stream"] = True, since main.py needs to know it's a streaming response to then transform it for the OpenAI format
                # ("stream" was already popped above, so this second pop is a no-op.)
                optional_params.pop(
                    "stream", None
                )  # vertex ai raises an error when passing stream in optional params

                request_str += (
                    f"chat.send_message_streaming({prompt}, **{optional_params})\n"
                )
                ## LOGGING
                logging_obj.pre_call(
                    input=prompt,
                    api_key=None,
                    additional_args={
                        "complete_input_dict": optional_params,
                        "request_str": request_str,
                    },
                )

                # The SDK's streaming response is returned as-is.
                model_response = chat.send_message_streaming(prompt, **optional_params)

                return model_response

            request_str += f"chat.send_message({prompt}, **{optional_params}).text\n"
            ## LOGGING
            logging_obj.pre_call(
                input=prompt,
                api_key=None,
                additional_args={
                    "complete_input_dict": optional_params,
                    "request_str": request_str,
                },
            )
            completion_response = chat.send_message(prompt, **optional_params).text
        elif mode == "text":
            if fake_stream is not True and stream is True:
                request_str += (
                    f"llm_model.predict_streaming({prompt}, **{optional_params})\n"
                )
                ## LOGGING
                logging_obj.pre_call(
                    input=prompt,
                    api_key=None,
                    additional_args={
                        "complete_input_dict": optional_params,
                        "request_str": request_str,
                    },
                )
                # The SDK's streaming response is returned as-is.
                model_response = llm_model.predict_streaming(prompt, **optional_params)

                return model_response

            request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n"
            ## LOGGING
            logging_obj.pre_call(
                input=prompt,
                api_key=None,
                additional_args={
                    "complete_input_dict": optional_params,
                    "request_str": request_str,
                },
            )
            completion_response = llm_model.predict(prompt, **optional_params).text
        elif mode == "custom":
            """
            Vertex AI Model Garden
            """

            if vertex_project is None or vertex_location is None:
                raise ValueError(
                    "Vertex project and location are required for custom endpoint"
                )

            ## LOGGING
            # Logged before the client lines are appended to request_str below.
            logging_obj.pre_call(
                input=prompt,
                api_key=None,
                additional_args={
                    "complete_input_dict": optional_params,
                    "request_str": request_str,
                },
            )
            llm_model = aiplatform.gapic.PredictionServiceClient(
                client_options=client_options,
                credentials=creds,  # type: ignore[arg-type]
            )
            request_str += f"llm_model = aiplatform.gapic.PredictionServiceClient(client_options={client_options}, credentials=...)\n"
            endpoint_path = llm_model.endpoint_path(
                project=vertex_project, location=vertex_location, endpoint=model
            )
            request_str += (
                f"llm_model.predict(endpoint={endpoint_path}, instances={instances})\n"
            )
            response = llm_model.predict(
                endpoint=endpoint_path, instances=instances
            ).predictions

            # Only the first prediction is used.
            completion_response = response[0]
            # When the prediction contains an "\nOutput:\n" marker, keep only
            # the text after its first occurrence.
            if (
                isinstance(completion_response, str)
                and "\nOutput:\n" in completion_response
            ):
                completion_response = completion_response.split("\nOutput:\n", 1)[1]
            if stream is True:
                # Fake streaming: the complete text is replayed word by word.
                response = TextStreamer(completion_response)
                return response
        elif mode == "private":
            """
            Vertex AI Model Garden deployed on private endpoint
            """
            if instances is None:
                raise ValueError("instances are required for private endpoint")
            if llm_model is None:
                raise ValueError("Unable to pick client for private endpoint")
            ## LOGGING
            # Logged before the predict line is appended to request_str below.
            logging_obj.pre_call(
                input=prompt,
                api_key=None,
                additional_args={
                    "complete_input_dict": optional_params,
                    "request_str": request_str,
                },
            )
            request_str += f"llm_model.predict(instances={instances})\n"
            response = llm_model.predict(instances=instances).predictions

            # Same post-processing as the "custom" branch.
            completion_response = response[0]
            if (
                isinstance(completion_response, str)
                and "\nOutput:\n" in completion_response
            ):
                completion_response = completion_response.split("\nOutput:\n", 1)[1]
            if stream is True:
                response = TextStreamer(completion_response)
                return response

        ## LOGGING
        logging_obj.post_call(
            input=prompt, api_key=None, original_response=completion_response
        )

        ## RESPONSE OBJECT
        if isinstance(completion_response, litellm.Message):
            model_response.choices[0].message = completion_response  # type: ignore
        elif len(str(completion_response)) > 0:
            # Anything else is stringified; note that str(None) is non-empty.
            model_response.choices[0].message.content = str(completion_response)  # type: ignore
        model_response.created = int(time.time())
        model_response.model = model
        ## CALCULATING USAGE
        if model in litellm.vertex_language_models and response_obj is not None:
            model_response.choices[0].finish_reason = map_finish_reason(  # type: ignore[assignment]
                response_obj.candidates[0].finish_reason.name
            )
            usage = Usage(
                prompt_tokens=response_obj.usage_metadata.prompt_token_count,
                completion_tokens=response_obj.usage_metadata.candidates_token_count,
                total_tokens=response_obj.usage_metadata.total_token_count,
            )
        else:
            # init prompt tokens
            # this block attempts to get usage from response_obj if it exists, if not it uses the litellm token counter
            prompt_tokens, completion_tokens, _ = 0, 0, 0
            if response_obj is not None:
                if hasattr(response_obj, "usage_metadata") and hasattr(
                    response_obj.usage_metadata, "prompt_token_count"
                ):
                    prompt_tokens = response_obj.usage_metadata.prompt_token_count
                    completion_tokens = (
                        response_obj.usage_metadata.candidates_token_count
                    )
            else:
                # Count tokens locally with the supplied encoder.
                prompt_tokens = len(encoding.encode(prompt))
                completion_tokens = len(
                    encoding.encode(
                        model_response["choices"][0]["message"].get("content", "")
                    )
                )

            usage = Usage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
            )
        setattr(model_response, "usage", usage)

        if fake_stream is True and stream is True:
            # Wrap the finished response so the caller can iterate it like a stream.
            return ModelResponseIterator(model_response)
        return model_response
    except Exception as e:
        # VertexAIError passes through untouched; everything else is reported
        # as a connection error for the vertex_ai provider.
        if isinstance(e, VertexAIError):
            raise e
        raise litellm.APIConnectionError(
            message=str(e), llm_provider="vertex_ai", model=model
        )
|
||||
|
||||
|
||||
async def async_completion(  # noqa: PLR0915
    llm_model,
    mode: str,
    prompt: str,
    model: str,
    messages: list,
    model_response: ModelResponse,
    request_str: str,
    print_verbose: Callable,
    logging_obj,
    encoding,
    client_options=None,
    instances=None,
    vertex_project=None,
    vertex_location=None,
    vertex_credentials=None,
    safety_settings=None,
    **optional_params,
):
    """
    Run a non-streaming async completion against a Vertex AI model.

    The call is dispatched on ``mode``:

    - ``"chat"``: ``llm_model.start_chat().send_message_async(prompt, ...)``
    - ``"text"``: ``llm_model.predict_async(prompt, ...)``
    - ``"custom"``: a Model Garden endpoint, called through a freshly built
      ``aiplatform.gapic.PredictionServiceAsyncClient`` (requires
      ``vertex_project`` and ``vertex_location``)
    - ``"private"``: ``llm_model.predict_async(instances=instances)``

    For "custom" and "private" the first prediction is used, and anything up
    to and including the first ``"\\nOutput:\\n"`` marker is stripped from a
    string prediction.

    Args:
        llm_model: Vertex SDK model object (replaced by the prediction client
            in "custom" mode).
        mode: One of "chat", "text", "custom", "private".
        prompt: Prompt text sent to the model and used for logging.
        model: Model / endpoint name; written to ``model_response.model``.
        messages: Not read by this function.
        model_response: Response object that is filled in and returned.
        request_str: Human-readable request trace, extended and logged.
        print_verbose: Not read by this function.
        logging_obj: Object exposing ``pre_call`` / ``post_call``.
        encoding: Tokenizer used for the fallback token count.
        client_options, vertex_credentials: Passed to the prediction client
            in "custom" mode.
        instances: Prediction instances for "custom" / "private" modes.
        vertex_project, vertex_location: Required for "custom" mode.
        safety_settings: Not read by this function.
        **optional_params: Forwarded to the SDK call in "chat" / "text" mode.

    Returns:
        ``model_response`` with message, ``created``, ``model`` and ``usage`` set.

    Raises:
        VertexAIError: Any failure is surfaced as ``VertexAIError``. An
            existing ``VertexAIError`` is re-raised unchanged; every other
            exception is wrapped with status code 500 and chained as the cause.
    """
    try:
        response_obj = None
        completion_response = None

        if mode == "chat":
            # chat-bison etc.
            chat = llm_model.start_chat()
            ## LOGGING
            logging_obj.pre_call(
                input=prompt,
                api_key=None,
                additional_args={
                    "complete_input_dict": optional_params,
                    "request_str": request_str,
                },
            )
            response_obj = await chat.send_message_async(prompt, **optional_params)
            completion_response = response_obj.text
        elif mode == "text":
            # gecko etc.
            request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n"
            ## LOGGING
            logging_obj.pre_call(
                input=prompt,
                api_key=None,
                additional_args={
                    "complete_input_dict": optional_params,
                    "request_str": request_str,
                },
            )
            response_obj = await llm_model.predict_async(prompt, **optional_params)
            completion_response = response_obj.text
        elif mode == "custom":
            # Vertex AI Model Garden
            from google.cloud import aiplatform  # type: ignore

            if vertex_project is None or vertex_location is None:
                raise ValueError(
                    "Vertex project and location are required for custom endpoint"
                )

            ## LOGGING
            logging_obj.pre_call(
                input=prompt,
                api_key=None,
                additional_args={
                    "complete_input_dict": optional_params,
                    "request_str": request_str,
                },
            )

            llm_model = aiplatform.gapic.PredictionServiceAsyncClient(
                client_options=client_options,
                credentials=vertex_credentials,
            )
            request_str += f"llm_model = aiplatform.gapic.PredictionServiceAsyncClient(client_options={client_options}, credentials=...)\n"
            endpoint_path = llm_model.endpoint_path(
                project=vertex_project, location=vertex_location, endpoint=model
            )
            request_str += (
                f"llm_model.predict(endpoint={endpoint_path}, instances={instances})\n"
            )
            response_obj = await llm_model.predict(
                endpoint=endpoint_path,
                instances=instances,
            )
            response = response_obj.predictions
            completion_response = response[0]
            if (
                isinstance(completion_response, str)
                and "\nOutput:\n" in completion_response
            ):
                completion_response = completion_response.split("\nOutput:\n", 1)[1]
        elif mode == "private":
            request_str += f"llm_model.predict_async(instances={instances})\n"
            response_obj = await llm_model.predict_async(
                instances=instances,
            )

            response = response_obj.predictions
            completion_response = response[0]
            if (
                isinstance(completion_response, str)
                and "\nOutput:\n" in completion_response
            ):
                completion_response = completion_response.split("\nOutput:\n", 1)[1]

        ## LOGGING
        logging_obj.post_call(
            input=prompt, api_key=None, original_response=completion_response
        )

        ## RESPONSE OBJECT
        if isinstance(completion_response, litellm.Message):
            model_response.choices[0].message = completion_response  # type: ignore
        elif completion_response is not None and len(str(completion_response)) > 0:
            # The None guard keeps a missing completion from being stored as
            # the literal text "None".
            model_response.choices[0].message.content = str(  # type: ignore
                completion_response
            )
        model_response.created = int(time.time())
        model_response.model = model

        ## CALCULATING USAGE
        if model in litellm.vertex_language_models and response_obj is not None:
            model_response.choices[0].finish_reason = map_finish_reason(  # type: ignore[assignment]
                response_obj.candidates[0].finish_reason.name
            )
            usage = Usage(
                prompt_tokens=response_obj.usage_metadata.prompt_token_count,
                completion_tokens=response_obj.usage_metadata.candidates_token_count,
                total_tokens=response_obj.usage_metadata.total_token_count,
            )
        else:
            # Prefer the usage reported on response_obj; otherwise count
            # tokens locally with the supplied encoding.
            prompt_tokens, completion_tokens = 0, 0
            if response_obj is not None and (
                hasattr(response_obj, "usage_metadata")
                and hasattr(response_obj.usage_metadata, "prompt_token_count")
            ):
                prompt_tokens = response_obj.usage_metadata.prompt_token_count
                completion_tokens = response_obj.usage_metadata.candidates_token_count
            else:
                prompt_tokens = len(encoding.encode(prompt))
                # The content may still be None when the model returned an
                # empty completion; encode an empty string in that case.
                completion_tokens = len(
                    encoding.encode(
                        model_response["choices"][0]["message"].get("content", "")
                        or ""
                    )
                )

            # set usage
            usage = Usage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
            )
        setattr(model_response, "usage", usage)
        return model_response
    except VertexAIError:
        # Keep the original status code instead of flattening it to 500.
        raise
    except Exception as e:
        raise VertexAIError(status_code=500, message=str(e)) from e
|
||||
|
||||
|
||||
async def async_streaming(  # noqa: PLR0915
    llm_model,
    mode: str,
    prompt: str,
    model: str,
    model_response: ModelResponse,
    messages: list,
    print_verbose: Callable,
    logging_obj,
    request_str: str,
    encoding=None,
    client_options=None,
    instances=None,
    vertex_project=None,
    vertex_location=None,
    vertex_credentials=None,
    safety_settings=None,
    **optional_params,
):
    """
    Start an async streaming call against a Vertex AI model.

    "chat" and "text" modes hand the SDK's streaming call result to the
    wrapper. "custom" and "private" modes await a regular prediction and,
    when ``stream`` was requested, wrap the first prediction's text in a
    ``TextStreamer``.

    Returns:
        A ``CustomStreamWrapper`` around the produced stream.

    Raises:
        ValueError: If "custom" mode lacks project/location, "private" mode
            lacks instances, or no response was produced for ``mode``.
    """

    def _log_pre_call(current_request_str: str) -> None:
        # Same pre-call payload for every mode that logs one.
        logging_obj.pre_call(
            input=prompt,
            api_key=None,
            additional_args={
                "complete_input_dict": optional_params,
                "request_str": current_request_str,
            },
        )

    def _first_prediction_text(predictions):
        # Take the first prediction and drop everything up to the first
        # "\nOutput:\n" marker when it is a string containing one.
        first = predictions[0]
        if isinstance(first, str) and "\nOutput:\n" in first:
            first = first.split("\nOutput:\n", 1)[1]
        return first

    response: Any = None

    if mode == "chat":
        chat = llm_model.start_chat()
        # vertex ai raises an error when passing stream in optional params
        optional_params.pop("stream", None)
        request_str += (
            f"chat.send_message_streaming_async({prompt}, **{optional_params})\n"
        )
        _log_pre_call(request_str)
        response = chat.send_message_streaming_async(prompt, **optional_params)

    elif mode == "text":
        # "stream" must not reach the SDK here either
        optional_params.pop("stream", None)
        request_str += (
            f"llm_model.predict_streaming_async({prompt}, **{optional_params})\n"
        )
        _log_pre_call(request_str)
        response = llm_model.predict_streaming_async(prompt, **optional_params)

    elif mode == "custom":
        from google.cloud import aiplatform  # type: ignore

        if vertex_project is None or vertex_location is None:
            raise ValueError(
                "Vertex project and location are required for custom endpoint"
            )

        stream = optional_params.pop("stream", None)

        # Logged before the client/endpoint lines are appended to request_str.
        _log_pre_call(request_str)

        llm_model = aiplatform.gapic.PredictionServiceAsyncClient(
            client_options=client_options,
            credentials=vertex_credentials,
        )
        request_str += f"llm_model = aiplatform.gapic.PredictionServiceAsyncClient(client_options={client_options}, credentials=...)\n"
        endpoint_path = llm_model.endpoint_path(
            project=vertex_project, location=vertex_location, endpoint=model
        )
        request_str += (
            f"client.predict(endpoint={endpoint_path}, instances={instances})\n"
        )
        prediction_result = await llm_model.predict(
            endpoint=endpoint_path,
            instances=instances,
        )

        response = prediction_result.predictions
        completion_text = _first_prediction_text(response)
        if stream:
            response = TextStreamer(completion_text)

    elif mode == "private":
        if instances is None:
            raise ValueError("Instances are required for private endpoint")

        stream = optional_params.pop("stream", None)
        # The first instance must not carry a "stream" key either.
        instances[0].pop("stream", None)
        request_str += f"llm_model.predict_async(instances={instances})\n"
        prediction_result = await llm_model.predict_async(
            instances=instances,
        )

        response = prediction_result.predictions
        completion_text = _first_prediction_text(response)
        if stream:
            response = TextStreamer(completion_text)

    if response is None:
        raise ValueError("Unable to generate response")

    logging_obj.post_call(input=prompt, api_key=None, original_response=response)

    return CustomStreamWrapper(
        completion_stream=response,
        model=model,
        custom_llm_provider="vertex_ai",
        logging_obj=logging_obj,
    )
|
||||
@@ -0,0 +1,24 @@
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig
|
||||
|
||||
|
||||
def get_vertex_ai_partner_model_config(
    model: str, vertex_publisher_or_api_spec: str
) -> BaseConfig:
    """
    Return the config used to transform responses for a Vertex AI partner model.

    The config class is chosen from ``vertex_publisher_or_api_spec``
    ("anthropic", "ai21", "openapi" or "mistralai") and imported lazily.

    Raises:
        ValueError: If the publisher / API spec is not one of the above.
    """
    spec = vertex_publisher_or_api_spec

    if spec == "anthropic":
        from .anthropic.transformation import VertexAIAnthropicConfig

        return VertexAIAnthropicConfig()

    if spec == "ai21":
        from .ai21.transformation import VertexAIAi21Config

        return VertexAIAi21Config()

    # Both the generic OpenAPI spec and Mistral use the Llama3 config.
    if spec in ("openapi", "mistralai"):
        from .llama3.transformation import VertexAILlama3Config

        return VertexAILlama3Config()

    raise ValueError(f"Unsupported model: {model}")
|
||||
@@ -0,0 +1,60 @@
|
||||
import types
|
||||
from typing import Optional
|
||||
|
||||
import litellm
|
||||
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
|
||||
|
||||
|
||||
class VertexAIAi21Config(OpenAIGPTConfig):
    """
    Configuration for the AI21 API interface on Vertex AI.

    Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/ai21

    -> Supports all OpenAI parameters
    """

    def __init__(
        self,
        max_tokens: Optional[int] = None,
    ) -> None:
        # Explicitly supplied values are stored on the class (not the
        # instance), which is where get_config() reads them from.
        supplied = {"max_tokens": max_tokens}
        for name, val in supplied.items():
            if val is None:
                continue
            setattr(self.__class__, name, val)

    @classmethod
    def get_config(cls):
        """Return the non-dunder, non-callable, non-None attributes set directly on this class."""
        skipped_kinds = (
            types.FunctionType,
            types.BuiltinFunctionType,
            classmethod,
            staticmethod,
        )
        config = {}
        for attr_name, attr_value in cls.__dict__.items():
            if attr_name.startswith("__"):
                continue
            if isinstance(attr_value, skipped_kinds):
                continue
            if attr_value is None:
                continue
            config[attr_name] = attr_value
        return config

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ):
        """Rename ``max_completion_tokens`` to ``max_tokens``, then delegate to ``litellm.OpenAIConfig``."""
        if "max_completion_tokens" in non_default_params:
            completion_cap = non_default_params.pop("max_completion_tokens")
            non_default_params["max_tokens"] = completion_cap

        openai_config = litellm.OpenAIConfig()
        return openai_config.map_openai_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
            drop_params=drop_params,
        )
|
||||
@@ -0,0 +1,167 @@
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from litellm.llms.anthropic.common_utils import AnthropicModelInfo
|
||||
from litellm.llms.anthropic.experimental_pass_through.messages.transformation import (
|
||||
AnthropicMessagesConfig,
|
||||
)
|
||||
from litellm.types.llms.anthropic import (
|
||||
ANTHROPIC_BETA_HEADER_VALUES,
|
||||
ANTHROPIC_HOSTED_TOOLS,
|
||||
)
|
||||
from litellm.types.llms.anthropic_tool_search import get_tool_search_beta_header
|
||||
from litellm.types.llms.vertex_ai import VertexPartnerProvider
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
|
||||
from ....vertex_llm_base import VertexBase
|
||||
|
||||
|
||||
class VertexAIPartnerModelsAnthropicMessagesConfig(AnthropicMessagesConfig, VertexBase):
    """Anthropic Messages configuration for Claude models called through Vertex AI."""

    def validate_anthropic_messages_environment(
        self,
        headers: dict,
        model: str,
        messages: List[Any],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> Tuple[dict, Optional[str]]:
        """
        Prepare headers and the request URL for a Vertex AI Claude call.

        - Adds a bearer ``Authorization`` header unless one is already present.
        - Resolves ``api_base`` through ``get_complete_vertex_url`` when the
          caller did not supply one.
        - Sets ``content-type`` and merges the required ``anthropic-beta``
          values into any the caller already supplied.

        Args:
            headers: Request headers; updated in place and returned.
            model: Model name used to build the Vertex URL.
            messages: Not read by this method.
            optional_params: Request params; ``stream``, ``tools`` and
                ``context_management`` are read.
            litellm_params: Source of the Vertex project, location and credentials.
            api_key: Not read by this method.
            api_base: Optional pre-resolved URL; used as-is when given.

        Returns:
            ``(headers, api_base)``.
        """
        vertex_ai_project = VertexBase.safe_get_vertex_ai_project(litellm_params)
        vertex_ai_location = VertexBase.safe_get_vertex_ai_location(litellm_params)

        project_id: Optional[str] = None
        if "Authorization" not in headers:
            vertex_credentials = VertexBase.safe_get_vertex_ai_credentials(
                litellm_params
            )

            access_token, project_id = self._ensure_access_token(
                credentials=vertex_credentials,
                project_id=vertex_ai_project,
                custom_llm_provider="vertex_ai",
            )

            headers["Authorization"] = f"Bearer {access_token}"
        else:
            # Authorization already in headers, but we still need project_id
            project_id = vertex_ai_project

        # Always calculate api_base if not provided, regardless of Authorization header
        if api_base is None:
            api_base = self.get_complete_vertex_url(
                custom_api_base=api_base,
                vertex_location=vertex_ai_location,
                vertex_project=vertex_ai_project,
                project_id=project_id or "",
                partner=VertexPartnerProvider.claude,
                stream=optional_params.get("stream", False),
                model=model,
            )

        headers["content-type"] = "application/json"

        beta_values = self._collect_vertex_beta_values(headers, optional_params)
        if beta_values:
            # Sorted so the header value is the same on every run.
            headers["anthropic-beta"] = ",".join(sorted(beta_values))

        return headers, api_base

    def _collect_vertex_beta_values(self, headers: dict, optional_params: dict) -> set:
        """Return the set of ``anthropic-beta`` values needed for this request."""
        beta_values: set = set()

        # Start from any beta values the caller already put in the headers,
        # ignoring empty fragments such as those produced by "a,,b".
        existing_beta = headers.get("anthropic-beta")
        if existing_beta:
            for fragment in existing_beta.split(","):
                fragment = fragment.strip()
                if fragment:
                    beta_values.add(fragment)

        # Context management: compact edits and all other edits use
        # different beta headers.
        context_management_param = optional_params.get("context_management")
        if context_management_param is not None:
            edits = context_management_param.get("edits") or []
            edit_types = [edit.get("type", "") for edit in edits]

            if any(edit_type == "compact_20260112" for edit_type in edit_types):
                beta_values.add(ANTHROPIC_BETA_HEADER_VALUES.COMPACT_2026_01_12.value)

            if any(edit_type != "compact_20260112" for edit_type in edit_types):
                beta_values.add(
                    ANTHROPIC_BETA_HEADER_VALUES.CONTEXT_MANAGEMENT_2025_06_27.value
                )

        # "tools" may be present but explicitly None; treat that as no tools.
        tools = optional_params.get("tools") or []

        # Web search tool
        for tool in tools:
            if isinstance(tool, dict) and (tool.get("type") or "").startswith(
                ANTHROPIC_HOSTED_TOOLS.WEB_SEARCH.value
            ):
                beta_values.add(
                    ANTHROPIC_BETA_HEADER_VALUES.WEB_SEARCH_2025_03_05.value
                )
                break

        # Tool search tools - Vertex AI uses a different beta header
        anthropic_model_info = AnthropicModelInfo()
        if anthropic_model_info.is_tool_search_used(tools):
            beta_values.add(get_tool_search_beta_header("vertex_ai"))

        return beta_values

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Return ``api_base`` unchanged.

        The URL is already resolved in ``validate_anthropic_messages_environment``.

        Raises:
            ValueError: If ``api_base`` is None.
        """
        if api_base is None:
            raise ValueError(
                "api_base is required. Unable to determine the correct api_base for the request."
            )
        return api_base  # no transformation is needed - handled in validate_environment

    def transform_anthropic_messages_request(
        self,
        model: str,
        messages: List[Dict],
        anthropic_messages_optional_request_params: Dict,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Dict:
        """
        Build the Anthropic Messages request body for Vertex AI.

        Takes the parent transformation, sets ``anthropic_version`` to
        "vertex-2023-10-16", and removes the ``model``, ``output_format`` and
        ``output_config`` keys from the body.
        """
        anthropic_messages_request = super().transform_anthropic_messages_request(
            model=model,
            messages=messages,
            anthropic_messages_optional_request_params=anthropic_messages_optional_request_params,
            litellm_params=litellm_params,
            headers=headers,
        )

        anthropic_messages_request["anthropic_version"] = "vertex-2023-10-16"

        # do not pass model in request body to vertex ai
        anthropic_messages_request.pop("model", None)

        # vertex ai does not support output_format as yet
        anthropic_messages_request.pop("output_format", None)

        # vertex ai does not support output_config
        anthropic_messages_request.pop("output_config", None)

        return anthropic_messages_request
|
||||
@@ -0,0 +1,230 @@
|
||||
# What is this?
|
||||
## Handler file for calling claude-3 on vertex ai
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.llms.base_llm.chat.transformation import LiteLLMLoggingObj
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
from ....anthropic.chat.transformation import AnthropicConfig
|
||||
|
||||
|
||||
class VertexAIError(Exception):
    """
    Error raised for failed Vertex AI calls.

    Attributes set on the instance: ``status_code``, ``message``, and
    placeholder ``request`` / ``response`` httpx objects built from the
    status code.
    """

    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        # Placeholder request; the URL must not carry leading whitespace.
        self.request = httpx.Request(
            method="POST", url="https://cloud.google.com/vertex-ai/"
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        # Pass the message to Exception so str(error) returns it.
        super().__init__(self.message)
|
||||
|
||||
|
||||
class VertexAIAnthropicConfig(AnthropicConfig):
    """
    Reference:https://docs.anthropic.com/claude/reference/messages_post

    Note that the API for Claude on Vertex differs from the Anthropic API documentation in the following ways:

    - `model` is not a valid parameter. The model is instead specified in the Google Cloud endpoint URL.
    - `anthropic_version` is a required parameter and must be set to "vertex-2023-10-16".

    The class `VertexAIAnthropicConfig` provides configuration for the VertexAI's Anthropic API interface. Below are the parameters:

    - `max_tokens` Required (integer) max tokens,
    - `anthropic_version` Required (string) version of anthropic for Vertex AI - must be "vertex-2023-10-16"
    - `system` Optional (string) the system prompt, conversion from openai format to this is handled in factory.py
    - `temperature` Optional (float) The amount of randomness injected into the response
    - `top_p` Optional (float) Use nucleus sampling.
    - `top_k` Optional (int) Only sample from the top K options for each subsequent token
    - `stop_sequences` Optional (List[str]) Custom text sequences that cause the model to stop generating

    Note: Please make sure to modify the default parameters as required for your use case.
    """

    @property
    def custom_llm_provider(self) -> Optional[str]:
        """Provider name for this config: always "vertex_ai"."""
        return "vertex_ai"

    def _add_context_management_beta_headers(
        self, beta_set: set, context_management: dict
    ) -> None:
        """
        Add context_management beta headers to the beta_set.

        - If any edit has type "compact_20260112", add compact-2026-01-12 header
        - For all other edits, add context-management-2025-06-27 header

        Args:
            beta_set: Set of beta headers to modify in-place
            context_management: The context_management dict from optional_params
        """
        from litellm.types.llms.anthropic import ANTHROPIC_BETA_HEADER_VALUES

        # A missing or None "edits" entry means there is nothing to add.
        edits = context_management.get("edits") or []
        edit_types = [edit.get("type", "") for edit in edits]

        # Add compact header if any compact edits exist
        if any(edit_type == "compact_20260112" for edit_type in edit_types):
            beta_set.add(ANTHROPIC_BETA_HEADER_VALUES.COMPACT_2026_01_12.value)

        # Add context management header if any other edits exist
        if any(edit_type != "compact_20260112" for edit_type in edit_types):
            beta_set.add(
                ANTHROPIC_BETA_HEADER_VALUES.CONTEXT_MANAGEMENT_2025_06_27.value
            )

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Build the request body for Claude on Vertex AI.

        Starts from the parent transformation, removes the keys Vertex AI
        does not accept (``model``, ``output_format``, ``output_config``,
        ``extra_headers``), then collects the beta flags for the request.
        The flags are written, in sorted order, to both
        ``data["anthropic_beta"]`` and ``headers["anthropic-beta"]``.

        Returns:
            The request body dict. ``headers`` is updated in place.
        """
        data = super().transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )

        data.pop("model", None)  # vertex anthropic doesn't accept 'model' parameter

        # VertexAI doesn't support output_format parameter, remove it if present
        data.pop("output_format", None)

        # VertexAI doesn't support output_config parameter, remove it if present
        data.pop("output_config", None)

        tools = optional_params.get("tools")
        tool_search_used = self.is_tool_search_used(tools)
        auto_betas = self.get_anthropic_beta_list(
            model=model,
            optional_params=optional_params,
            computer_tool_used=self.is_computer_tool_used(tools),
            prompt_caching_set=self.is_cache_control_set(messages),
            file_id_used=self.is_file_id_used(messages),
            mcp_server_used=self.is_mcp_server_used(optional_params.get("mcp_servers")),
        )

        beta_set = set(auto_betas)
        if tool_search_used:
            # Vertex requires this header for tool search
            beta_set.add("tool-search-tool-2025-10-19")

        # Add context_management beta headers (compact and/or context-management)
        context_management = optional_params.get("context_management")
        if context_management:
            self._add_context_management_beta_headers(beta_set, context_management)

        # Merge caller-supplied beta flags (comma-separated string or list).
        extra_headers = optional_params.get("extra_headers") or {}
        anthropic_beta_value = extra_headers.get("anthropic-beta", "")
        if isinstance(anthropic_beta_value, str) and anthropic_beta_value:
            for beta in anthropic_beta_value.split(","):
                beta = beta.strip()
                if beta:
                    beta_set.add(beta)
        elif isinstance(anthropic_beta_value, list):
            beta_set.update(anthropic_beta_value)

        data.pop("extra_headers", None)

        if beta_set:
            # Sorted so body and header are identical on every run.
            ordered_betas = sorted(beta_set)
            data["anthropic_beta"] = ordered_betas
            headers["anthropic-beta"] = ",".join(ordered_betas)

        return data

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        Override parent method to ensure VertexAI always uses tool-based structured outputs.
        VertexAI doesn't support the output_format parameter, so we force all models
        to use the tool-based approach for structured outputs.
        """
        # When a response_format is requested, pass the parent a model name
        # that takes the tool-based path instead of output_format
        # (e.g. so Claude Sonnet 4.5 uses tools).
        if "response_format" in non_default_params:
            model = "claude-3-sonnet-20240229"

        return super().map_openai_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
            drop_params=drop_params,
        )

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """Run the parent response transformation, then set ``response.model`` to ``model``."""
        response = super().transform_response(
            model,
            raw_response,
            model_response,
            logging_obj,
            request_data,
            messages,
            optional_params,
            litellm_params,
            encoding,
            api_key,
            json_mode,
        )
        response.model = model

        return response

    @classmethod
    def is_supported_model(cls, model: str, custom_llm_provider: str) -> bool:
        """
        Check if the model is supported by the VertexAI Anthropic API.

        True only for the "vertex_ai" / "vertex_ai_beta" providers, when the
        model name contains "claude" (case-insensitive) or is listed in
        ``litellm.vertex_anthropic_models``.
        """
        if custom_llm_provider not in ("vertex_ai", "vertex_ai_beta"):
            return False
        if "claude" in model.lower():
            return True
        return model in litellm.vertex_anthropic_models
|
||||
@@ -0,0 +1 @@
|
||||
# Count tokens handler for Vertex AI Partner Models (Anthropic, Mistral, etc.)
|
||||
@@ -0,0 +1,162 @@
|
||||
"""
|
||||
Token counter for Vertex AI Partner Models (Anthropic Claude, Mistral, etc.)
|
||||
|
||||
This handler provides token counting for partner models hosted on Vertex AI.
|
||||
Unlike Gemini models which use Google's token counting API, partner models use
|
||||
their respective publisher-specific count-tokens endpoints.
|
||||
"""
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
|
||||
from litellm.llms.vertex_ai.common_utils import get_vertex_base_url
|
||||
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
|
||||
|
||||
|
||||
class VertexAIPartnerModelsTokenCounter(VertexBase):
    """
    Token counter for Vertex AI Partner Models.

    Handles token counting for models like Claude (Anthropic), Mistral, etc.
    that are available through Vertex AI's partner model program.
    """

    def _get_publisher_for_model(self, model: str) -> str:
        """
        Determine the publisher name for the given model.

        Matching is case-insensitive, in line with the Claude check in
        ``handle_count_tokens_request``.

        Args:
            model: The model name (e.g., "claude-3-5-sonnet-20241022")

        Returns:
            Publisher name to use in the Vertex AI endpoint URL

        Raises:
            ValueError: If the model is not a recognized partner model
        """
        normalized = model.lower()
        if "claude" in normalized:
            return "anthropic"
        if "mistral" in normalized or "codestral" in normalized:
            return "mistralai"
        if "llama" in normalized or "meta/" in normalized:
            return "meta"
        raise ValueError(f"Unknown partner model: {model}")

    def _build_count_tokens_endpoint(
        self,
        model: str,
        project_id: str,
        vertex_location: str,
        api_base: Optional[str] = None,
    ) -> str:
        """
        Build the count-tokens endpoint URL for a partner model.

        Args:
            model: The model name
            project_id: Google Cloud project ID
            vertex_location: Vertex AI location (e.g., "us-east5")
            api_base: Optional custom API base URL

        Returns:
            Full endpoint URL for the count-tokens API
        """
        publisher = self._get_publisher_for_model(model)

        # Use custom api_base if provided, otherwise construct default.
        # A trailing slash is dropped so the path does not start with "//".
        if api_base:
            base_url = api_base.rstrip("/")
        else:
            base_url = get_vertex_base_url(vertex_location)

        # Format: /v1/projects/{project}/locations/{location}/publishers/{publisher}/models/count-tokens:rawPredict
        return (
            f"{base_url}/v1/projects/{project_id}/locations/{vertex_location}/"
            f"publishers/{publisher}/models/count-tokens:rawPredict"
        )

    async def handle_count_tokens_request(
        self,
        model: str,
        request_data: Dict[str, Any],
        litellm_params: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Handle token counting request for a Vertex AI partner model.

        Args:
            model: The model name
            request_data: Request payload (Anthropic Messages API format)
            litellm_params: LiteLLM parameters containing credentials, project, location

        Returns:
            Dict with ``input_tokens`` (0 when absent from the response) and
            ``tokenizer_used``.

        Raises:
            ValueError: If ``messages`` is missing, the model is not a
                recognized partner model, or the endpoint returns a non-200
                status.
        """
        # Validate request
        if "messages" not in request_data:
            raise ValueError("messages required for token counting")

        # Extract Vertex AI credentials and settings
        vertex_credentials = self.get_vertex_ai_credentials(litellm_params)
        vertex_project = self.get_vertex_ai_project(litellm_params)
        vertex_location = self.get_vertex_ai_location(litellm_params)

        # Map empty location/claude models to a supported region for count-tokens endpoint
        # https://docs.cloud.google.com/vertex-ai/generative-ai/docs/partner-models/claude/count-tokens
        if not vertex_location or "claude" in model.lower():
            vertex_location = "us-central1"

        # Get access token and resolved project ID
        access_token, project_id = await self._ensure_access_token_async(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai",
        )

        # Build the endpoint URL
        endpoint_url = self._build_count_tokens_endpoint(
            model=model,
            project_id=project_id,
            vertex_location=vertex_location,
            api_base=litellm_params.get("api_base"),
        )

        # Prepare headers
        headers = {"Authorization": f"Bearer {access_token}"}

        # Get async HTTP client
        from litellm import LlmProviders

        async_client = get_async_httpx_client(llm_provider=LlmProviders.VERTEX_AI)

        # Make the request
        # Note: Partner models (especially Claude) accept Anthropic Messages API format directly
        response = await async_client.post(
            endpoint_url,
            headers=headers,
            json=request_data,
            timeout=30.0,
        )

        # Check for errors
        if response.status_code != 200:
            error_text = response.text
            raise ValueError(
                f"Token counting request failed with status {response.status_code}: {error_text}"
            )

        # Parse response
        result = response.json()

        # Vertex AI Anthropic returns: {"input_tokens": 123}
        return {
            "input_tokens": result.get("input_tokens", 0),
            "tokenizer_used": "vertex_ai_partner_models",
        }
|
||||
@@ -0,0 +1,36 @@
|
||||
import litellm
|
||||
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
|
||||
|
||||
|
||||
class VertexAIGPTOSSTransformation(OpenAIGPTConfig):
    """
    Transformation for GPT-OSS model from VertexAI

    https://console.cloud.google.com/vertex-ai/publishers/openai/model-garden/gpt-oss-120b-maas?hl=id
    """

    def __init__(self):
        super().__init__()

    def get_supported_openai_params(self, model: str) -> list:
        """
        Return the OpenAI params supported for ``model``.

        Takes the parent's list and adds ``reasoning_effort``. When
        ``litellm.supports_function_calling`` reports False for the model,
        all tool-calling params are removed.
        """
        # Copy so the list returned by the parent is never mutated.
        supported_params = list(super().get_supported_openai_params(model=model))
        supported_params.append("reasoning_effort")

        #########################################################
        # VertexAI - GPT-OSS does not support tool calls
        #########################################################
        if litellm.supports_function_calling(model=model) is False:
            # The OpenAI request param is "tools"; "tool" is kept only so
            # that anything filtered before is still filtered.
            tool_calling_params = {
                "tool",
                "tools",
                "tool_choice",
                "function_call",
                "functions",
            }
            supported_params = [
                param
                for param in supported_params
                if param not in tool_calling_params
            ]

        return supported_params
|
||||
@@ -0,0 +1,226 @@
|
||||
import types
|
||||
from typing import Any, AsyncIterator, Iterator, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.openai.chat.gpt_transformation import (
|
||||
OpenAIChatCompletionStreamingHandler,
|
||||
OpenAIGPTConfig,
|
||||
)
|
||||
from litellm.types.llms.openai import AllMessageValues, OpenAIChatCompletionResponse
|
||||
from litellm.types.utils import (
|
||||
Delta,
|
||||
ModelResponse,
|
||||
ModelResponseStream,
|
||||
StreamingChoices,
|
||||
Usage,
|
||||
)
|
||||
|
||||
from ...common_utils import VertexAIError
|
||||
|
||||
|
||||
class VertexAILlama3Config(OpenAIGPTConfig):
    """
    Reference:https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/llama#streaming

    The class `VertexAILlama3Config` provides configuration for the VertexAI's Llama API interface. Below are the parameters:

    - `max_tokens` Required (integer) max tokens,

    Note: Please make sure to modify the default parameters as required for your use case.
    """

    max_tokens: Optional[int] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
    ) -> None:
        """
        Store non-None constructor arguments as class-level defaults.

        Values are set on the class (not the instance), which is what makes
        them visible to `get_config()`.
        """
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key == "max_tokens" and value is None:
                value = self.max_tokens
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        """
        Return the class-level configuration values as a dict.

        Dunder attributes, functions, class/static methods and None values
        are excluded.
        """
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self, model: str):
        """
        Return the parent's supported OpenAI params without `max_retries`.
        """
        supported_params = super().get_supported_openai_params(model=model)
        # list.remove() raises ValueError (not KeyError) when the item is
        # missing, so test membership instead of catching the wrong exception.
        if "max_retries" in supported_params:
            supported_params.remove("max_retries")
        return supported_params

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ):
        """
        Map OpenAI params onto the provider request.

        `max_completion_tokens` is renamed to `max_tokens` (in the caller's
        `non_default_params` dict) before delegating to the parent mapping.
        """
        if "max_completion_tokens" in non_default_params:
            non_default_params["max_tokens"] = non_default_params.pop(
                "max_completion_tokens"
            )
        return super().map_openai_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
            drop_params=drop_params,
        )

    def get_model_response_iterator(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ) -> Any:
        """
        Wrap the raw stream in a `VertexAILlama3StreamingHandler`.
        """
        return VertexAILlama3StreamingHandler(
            streaming_response=streaming_response,
            sync_stream=sync_stream,
            json_mode=json_mode,
        )

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """
        Populate `model_response` from the provider's HTTP response.

        Logs the raw response text, parses the JSON body as an OpenAI chat
        completion, then copies model/id/created/usage/choices onto
        `model_response`.

        Raises:
            VertexAIError: If the body cannot be parsed; carries the original
                status code, headers and response text.
        """
        ## LOGGING
        logging_obj.post_call(
            input=messages,
            api_key=api_key,
            original_response=raw_response.text,
            additional_args={"complete_input_dict": request_data},
        )

        ## RESPONSE OBJECT
        try:
            completion_response = OpenAIChatCompletionResponse(**raw_response.json())  # type: ignore
        except Exception as e:
            response_headers = getattr(raw_response, "headers", None)
            raise VertexAIError(
                message="Unable to get json response - {}, Original Response: {}".format(
                    str(e), raw_response.text
                ),
                status_code=raw_response.status_code,
                headers=response_headers,
            )
        model_response.model = completion_response.get("model", model)
        model_response.id = completion_response.get("id", "")
        model_response.created = completion_response.get("created", 0)
        # `or {}` also covers a body that carries an explicit `"usage": null`.
        setattr(
            model_response, "usage", Usage(**(completion_response.get("usage") or {}))
        )

        model_response.choices = self._transform_choices(  # type: ignore
            choices=completion_response["choices"],
            json_mode=json_mode,
        )

        return model_response
|
||||
|
||||
|
||||
class VertexAILlama3StreamingHandler(OpenAIChatCompletionStreamingHandler):
    """
    Streaming handler that guarantees the first emitted chunk carries
    role="assistant", since Vertex AI Llama models may omit the role in
    streaming deltas.

    If the very first chunk is also the final one (it already has a
    finish_reason, i.e. an empty response), it is split in two:
    1. First chunk: role="assistant", content="", finish_reason=None
    2. Second chunk: role=None, content=None, finish_reason=<original>

    This mirrors OpenAI's streaming format, where the role arrives on the
    first chunk and the finish_reason on a role-less final chunk.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Flips to True once a chunk with choices has been processed.
        self.sent_role = False
        # Holds the synthetic final chunk produced by a split, if any.
        self._pending_chunk: Optional[ModelResponseStream] = None

    def chunk_parser(self, chunk: dict) -> ModelResponseStream:
        """
        Parse a chunk with the parent parser and patch the first chunk's role.
        """
        parsed = super().chunk_parser(chunk)
        if self.sent_role or not parsed.choices:
            return parsed

        first_choice = parsed.choices[0]
        delta = first_choice.delta
        finish_reason = first_choice.finish_reason

        if finish_reason is not None:
            # First chunk is also the last one: queue a role-less final chunk
            # that carries the finish_reason ...
            self._pending_chunk = ModelResponseStream(
                id=parsed.id,
                object="chat.completion.chunk",
                created=parsed.created,
                model=parsed.model,
                choices=[
                    StreamingChoices(
                        index=0,
                        delta=Delta(content=None, role=None),
                        finish_reason=finish_reason,
                    )
                ],
            )
            # ... and turn the current chunk into a pure "role" chunk.
            first_choice.finish_reason = None  # type: ignore[assignment]
            delta.role = "assistant"
            if delta.content is None:
                delta.content = ""
            # A non-None provider_specific_fields keeps the downstream stream
            # wrapper from dropping this otherwise empty-content chunk.
            if delta.provider_specific_fields is None:
                delta.provider_specific_fields = {}
        elif delta.role is None:
            delta.role = "assistant"
            # Same trick: make sure an empty first chunk is still emitted.
            has_no_content = delta.content == "" or delta.content is None
            if has_no_content and delta.provider_specific_fields is None:
                delta.provider_specific_fields = {}

        self.sent_role = True
        return parsed

    def _pop_pending_chunk(self) -> Optional[ModelResponseStream]:
        """Return the queued split chunk (if any) and clear the slot."""
        pending, self._pending_chunk = self._pending_chunk, None
        return pending

    def __next__(self):
        # A chunk queued by a previous split takes priority over the stream.
        pending = self._pop_pending_chunk()
        if pending is not None:
            return pending
        return super().__next__()

    async def __anext__(self):
        # A chunk queued by a previous split takes priority over the stream.
        pending = self._pop_pending_chunk()
        if pending is not None:
            return pending
        return await super().__anext__()
|
||||
@@ -0,0 +1,344 @@
|
||||
# What is this?
|
||||
## API Handler for calling Vertex AI Partner Models
|
||||
from enum import Enum
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import httpx # type: ignore
|
||||
|
||||
import litellm
|
||||
from litellm import LlmProviders
|
||||
from litellm.types.llms.vertex_ai import VertexPartnerProvider
|
||||
from litellm.utils import ModelResponse
|
||||
|
||||
from ...custom_httpx.llm_http_handler import BaseLLMHTTPHandler
|
||||
from ..vertex_llm_base import VertexBase
|
||||
|
||||
# Module-level HTTP handler instance, created once at import time.
base_llm_http_handler = BaseLLMHTTPHandler()
|
||||
|
||||
|
||||
class VertexAIError(Exception):
    """
    Exception carrying an HTTP-style status code for Vertex AI failures.

    A placeholder `httpx.Request` / `httpx.Response` pair is attached as
    `request` and `response`, alongside `status_code` and `message`.
    """

    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        # The URL literal previously started with a space, which made it a
        # malformed (non-absolute) URL on the placeholder request.
        self.request = httpx.Request(
            method="POST", url="https://cloud.google.com/vertex-ai/"
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
class PartnerModelPrefixes(str, Enum):
    """
    Model-name fragments that identify Vertex AI partner models.

    Members are `str` instances, so they can be handed directly to
    `str.startswith()` or used in substring (`in`) checks.
    """

    META_PREFIX = "meta/"
    DEEPSEEK_PREFIX = "deepseek-ai"
    MISTRAL_PREFIX = "mistral"
    # NOTE: the member name is misspelled ("CODERESTAL"); the value is "codestral".
    CODERESTAL_PREFIX = "codestral"
    JAMBA_PREFIX = "jamba"
    CLAUDE_PREFIX = "claude"
    QWEN_PREFIX = "qwen"
    GPT_OSS_PREFIX = "openai/gpt-oss-"
    MINIMAX_PREFIX = "minimaxai/"
    MOONSHOT_PREFIX = "moonshotai/"
    ZAI_PREFIX = "zai-org/"
|
||||
|
||||
|
||||
class VertexAIPartnerModels(VertexBase):
    """
    API handler for Vertex AI partner models.

    Routes a request to the matching handler (Codestral FIM, Anthropic,
    the shared OpenAI-compatible HTTP handler, or the OpenAI-like chat
    handler) based on fragments of the model name.
    """

    def __init__(self) -> None:
        pass

    @staticmethod
    def is_vertex_partner_model(model: str):
        """
        Check if the model string is a Vertex AI Partner Model
        Only use this once you have confirmed that custom_llm_provider is vertex_ai

        Returns:
            bool: True if the model string is a Vertex AI Partner Model, False otherwise
        """
        # Every PartnerModelPrefixes member is a recognised prefix;
        # str.startswith accepts a tuple of candidates.
        return model.startswith(tuple(PartnerModelPrefixes))

    @staticmethod
    def should_use_openai_handler(model: str):
        """
        Return True when `model` should go through the OpenAI-compatible handler.

        Note this is a substring match (not a prefix match) against the
        fragments listed below.
        """
        OPENAI_LIKE_VERTEX_PROVIDERS = [
            "llama",
            PartnerModelPrefixes.DEEPSEEK_PREFIX,
            PartnerModelPrefixes.QWEN_PREFIX,
            PartnerModelPrefixes.GPT_OSS_PREFIX,
            PartnerModelPrefixes.MINIMAX_PREFIX,
            PartnerModelPrefixes.MOONSHOT_PREFIX,
            PartnerModelPrefixes.ZAI_PREFIX,
        ]
        return any(provider in model for provider in OPENAI_LIKE_VERTEX_PROVIDERS)

    def completion(
        self,
        model: str,
        messages: list,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        logging_obj,
        api_base: Optional[str],
        optional_params: dict,
        custom_prompt_dict: dict,
        headers: Optional[dict],
        timeout: Union[float, httpx.Timeout],
        litellm_params: dict,
        vertex_project=None,
        vertex_location=None,
        vertex_credentials=None,
        logger_fn=None,
        acompletion: bool = False,
        client=None,
    ):
        """
        Run a completion against a Vertex AI partner model.

        Resolves an access token and project id, builds the partner-specific
        endpoint URL, then dispatches by model name:

        - "codestral" with `litellm_params["text_completion"] is True`
          -> Codestral FIM text completion
        - "claude" -> Anthropic chat handler (adds the Authorization header
          and Vertex-specific optional params)
        - OpenAI-compatible partners -> shared `base_llm_http_handler`
        - anything else (mistral / codestral chat / jamba) -> OpenAI-like
          chat handler

        `optional_params` is updated in place (`stream`, and for some routes
        `model` / `anthropic_version` / `is_vertex_request`).

        Raises:
            VertexAIError: status 400 if the Vertex AI SDK is missing or too
                old; status 500 wrapping any error without a `status_code`.
                Errors that already carry a `status_code` are re-raised as is.
        """
        try:
            import vertexai

            from litellm.llms.anthropic.chat import AnthropicChatCompletion
            from litellm.llms.codestral.completion.handler import (
                CodestralTextCompletion,
            )
            from litellm.llms.openai_like.chat.handler import OpenAILikeChatHandler
            from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
                VertexLLM,
            )
        except Exception as e:
            raise VertexAIError(
                status_code=400,
                message=f"""vertexai import failed please run `pip install -U "google-cloud-aiplatform>=1.38"`. Got error: {e}""",
            )

        # The previous guard `hasattr(vertexai, "preview") or
        # hasattr(vertexai.preview, ...)` only evaluated its second operand
        # when `preview` was missing, which raised a raw AttributeError
        # instead of this upgrade hint.
        if not hasattr(vertexai, "preview"):
            raise VertexAIError(
                status_code=400,
                message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""",
            )
        try:
            vertex_httpx_logic = VertexLLM()

            access_token, project_id = vertex_httpx_logic._ensure_access_token(
                credentials=vertex_credentials,
                project_id=vertex_project,
                custom_llm_provider="vertex_ai",
            )

            openai_like_chat_completions = OpenAILikeChatHandler()
            codestral_fim_completions = CodestralTextCompletion()
            anthropic_chat_completions = AnthropicChatCompletion()

            ## CONSTRUCT API BASE
            stream: bool = optional_params.get("stream", False) or False

            optional_params["stream"] = stream

            if self.should_use_openai_handler(model):
                partner = VertexPartnerProvider.llama
            elif "mistral" in model or "codestral" in model:
                partner = VertexPartnerProvider.mistralai
            elif "jamba" in model:
                partner = VertexPartnerProvider.ai21
            elif "claude" in model:
                partner = VertexPartnerProvider.claude
            else:
                raise ValueError(f"Unknown partner model: {model}")

            api_base = self.get_complete_vertex_url(
                custom_api_base=api_base,
                vertex_location=vertex_location,
                vertex_project=vertex_project,
                project_id=project_id,
                partner=partner,
                stream=stream,
                model=model,
            )

            # The URL above is built from the full model string; everything
            # from "@" onwards is dropped only for the handler calls below.
            if "codestral" in model or "mistral" in model:
                model = model.split("@")[0]

            if "codestral" in model and litellm_params.get("text_completion") is True:
                optional_params["model"] = model
                text_completion_model_response = litellm.TextCompletionResponse(
                    stream=stream
                )
                return codestral_fim_completions.completion(
                    model=model,
                    messages=messages,
                    api_base=api_base,
                    api_key=access_token,
                    custom_prompt_dict=custom_prompt_dict,
                    model_response=text_completion_model_response,
                    print_verbose=print_verbose,
                    logging_obj=logging_obj,
                    optional_params=optional_params,
                    acompletion=acompletion,
                    litellm_params=litellm_params,
                    logger_fn=logger_fn,
                    timeout=timeout,
                    encoding=encoding,
                )
            elif "claude" in model:
                # Work on a copy so the caller's headers dict is not mutated
                # with the bearer token.
                headers = dict(headers) if headers is not None else {}
                headers["Authorization"] = "Bearer {}".format(access_token)

                optional_params.update(
                    {
                        "anthropic_version": "vertex-2023-10-16",
                        "is_vertex_request": True,
                    }
                )

                return anthropic_chat_completions.completion(
                    model=model,
                    messages=messages,
                    api_base=api_base,
                    acompletion=acompletion,
                    custom_prompt_dict=litellm.custom_prompt_dict,
                    model_response=model_response,
                    print_verbose=print_verbose,
                    optional_params=optional_params,
                    litellm_params=litellm_params,
                    logger_fn=logger_fn,
                    encoding=encoding,  # for calculating input/output tokens
                    api_key=access_token,
                    logging_obj=logging_obj,
                    headers=headers,
                    timeout=timeout,
                    client=client,
                    custom_llm_provider=LlmProviders.VERTEX_AI.value,
                )
            elif self.should_use_openai_handler(model):
                return base_llm_http_handler.completion(
                    model=model,
                    stream=stream,
                    messages=messages,
                    acompletion=acompletion,
                    api_base=api_base,
                    model_response=model_response,
                    optional_params=optional_params,
                    litellm_params=litellm_params,
                    custom_llm_provider="vertex_ai",
                    timeout=timeout,
                    headers=headers,
                    encoding=encoding,
                    api_key=access_token,
                    logging_obj=logging_obj,  # passed through to the handler
                    client=client,
                )
            return openai_like_chat_completions.completion(
                model=model,
                messages=messages,
                api_base=api_base,
                api_key=access_token,
                custom_prompt_dict=custom_prompt_dict,
                model_response=model_response,
                print_verbose=print_verbose,
                logging_obj=logging_obj,
                optional_params=optional_params,
                acompletion=acompletion,
                litellm_params=litellm_params,
                logger_fn=logger_fn,
                client=client,
                timeout=timeout,
                encoding=encoding,
                custom_llm_provider="vertex_ai",
                custom_endpoint=True,
            )

        except Exception as e:
            # Errors that already expose a status code are passed through
            # untouched; everything else is wrapped as a 500.
            if hasattr(e, "status_code"):
                raise
            raise VertexAIError(status_code=500, message=str(e)) from e

    async def count_tokens(
        self,
        model: str,
        messages: list,
        litellm_params: dict,
        vertex_project=None,
        vertex_location=None,
        vertex_credentials=None,
    ):
        """
        Count tokens for Vertex AI partner models (Anthropic Claude, Mistral, etc.)

        Args:
            model: The model name (e.g., "claude-3-5-sonnet-20241022")
            messages: List of messages in Anthropic Messages API format
            litellm_params: LiteLLM parameters dict (copied, not mutated)
            vertex_project: Optional Google Cloud project ID
            vertex_location: Optional Vertex AI location
            vertex_credentials: Optional Vertex AI credentials

        Returns:
            Dict containing token count information

        Raises:
            VertexAIError: status 400 if the Vertex AI SDK is missing or too
                old; status 500 wrapping any error without a `status_code`.
        """
        try:
            import vertexai
        except Exception as e:
            raise VertexAIError(
                status_code=400,
                message=f"""vertexai import failed please run `pip install -U "google-cloud-aiplatform>=1.38"`. Got error: {e}""",
            )

        # See completion(): the old `or` guard raised AttributeError exactly
        # when `preview` was missing.
        if not hasattr(vertexai, "preview"):
            raise VertexAIError(
                status_code=400,
                message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""",
            )

        try:
            from litellm.llms.vertex_ai.vertex_ai_partner_models.count_tokens.handler import (
                VertexAIPartnerModelsTokenCounter,
            )

            # Prepare request data in Anthropic Messages API format
            request_data = {
                "model": model,
                "messages": messages,
            }

            # Prepare litellm_params with credentials
            _litellm_params = litellm_params.copy()
            if vertex_project:
                _litellm_params["vertex_project"] = vertex_project
            if vertex_location:
                _litellm_params["vertex_location"] = vertex_location
            if vertex_credentials:
                _litellm_params["vertex_credentials"] = vertex_credentials

            # Call the token counter
            token_counter = VertexAIPartnerModelsTokenCounter()
            result = await token_counter.handle_count_tokens_request(
                model=model,
                request_data=request_data,
                litellm_params=_litellm_params,
            )

            return result

        except Exception as e:
            if hasattr(e, "status_code"):
                raise
            raise VertexAIError(status_code=500, message=str(e)) from e
|
||||
@@ -0,0 +1,183 @@
|
||||
"""
|
||||
Vertex AI BGE (BAAI General Embedding) Configuration
|
||||
|
||||
BGE models deployed on Vertex AI require different input/output format:
|
||||
- Request: Use "prompt" instead of "content" as the input field
|
||||
- Response: Embeddings are returned directly as arrays, not wrapped in objects
|
||||
|
||||
Model name handling:
|
||||
- Model names like "bge/endpoint_id" are automatically transformed in common_utils._get_vertex_url()
|
||||
- This module focuses on request/response transformation only
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from litellm.types.utils import EmbeddingResponse, Usage
|
||||
|
||||
from .types import (
|
||||
EmbeddingParameters,
|
||||
TaskType,
|
||||
TextEmbeddingBGEInput,
|
||||
VertexEmbeddingRequest,
|
||||
)
|
||||
|
||||
|
||||
class VertexBGEConfig:
    """
    Request/response transformation for BGE (BAAI General Embedding) models
    on Vertex AI.

    BGE models take the input text under the "prompt" field instead of
    "content", and return each embedding as a bare array.

    Supported model patterns (after provider split in main.py):
    - "bge-small-en-v1.5" (model name)
    - "bge/204379420394258432" (endpoint ID pattern)

    Note: Model name transformation (bge/ -> numeric ID) is handled
    in common_utils._get_vertex_url(); this class only deals with the
    request/response format.
    """

    @staticmethod
    def is_bge_model(model: str) -> bool:
        """
        Tell whether `model` names a BGE model.

        Matching is case-insensitive and covers both the plain model name
        ("bge-small-en-v1.5") and the endpoint pattern ("bge/<endpoint id>").

        Args:
            model: The model name after provider split

        Returns:
            bool: True if the model is a BGE model
        """
        # The "bge/" endpoint prefix is a special case of containing "bge",
        # so a single substring test covers both patterns.
        return "bge" in model.lower()

    @staticmethod
    def transform_request(
        input: Union[list, str], optional_params: dict, model: str
    ) -> VertexEmbeddingRequest:
        """
        Build a Vertex BGE embedding request from an OpenAI-style request.

        Each input text becomes one instance keyed by "prompt"; the
        `optional_params` are forwarded as the request parameters.

        Args:
            input: The input text(s) to embed
            optional_params: Optional parameters for the request
            model: The model name

        Returns:
            VertexEmbeddingRequest: The transformed request
        """
        request: VertexEmbeddingRequest = VertexEmbeddingRequest()
        task_type: Optional[TaskType] = optional_params.get("task_type")
        title = optional_params.get("title")

        # A single string is treated as a one-element batch.
        texts = [input] if isinstance(input, str) else input
        instances: List[TextEmbeddingBGEInput] = [
            VertexBGEConfig._create_embedding_input(
                prompt=text, task_type=task_type, title=title
            )
            for text in texts
        ]

        request["instances"] = instances
        request["parameters"] = EmbeddingParameters(**optional_params)
        return request

    @staticmethod
    def _create_embedding_input(
        prompt: str,
        task_type: Optional[TaskType] = None,
        title: Optional[str] = None,
    ) -> TextEmbeddingBGEInput:
        """
        Build one BGE instance, keyed by "prompt" rather than "content".

        Args:
            prompt: The prompt to be embedded
            task_type: The type of task to be performed
            title: The title of the document to be embedded

        Returns:
            TextEmbeddingBGEInput: A TextEmbeddingBGEInput object
        """
        instance = TextEmbeddingBGEInput(prompt=prompt)
        # Optional fields are only set when a value was supplied.
        if task_type is not None:
            instance["task_type"] = task_type
        if title is not None:
            instance["title"] = title
        return instance

    @staticmethod
    def transform_response(
        response: dict, model: str, model_response: EmbeddingResponse
    ) -> EmbeddingResponse:
        """
        Convert a Vertex BGE embedding response to the OpenAI format.

        BGE models return embeddings directly as arrays in predictions:
        {
            "predictions": [
                [0.002, 0.021, ...],
                [0.003, 0.022, ...]
            ]
        }

        Args:
            response: The raw response from Vertex AI
            model: The model name
            model_response: The EmbeddingResponse object to populate

        Returns:
            EmbeddingResponse: The transformed response in OpenAI format

        Raises:
            KeyError: If response doesn't contain 'predictions'
            ValueError: If predictions is not a list or contains invalid data
        """
        if "predictions" not in response:
            raise KeyError("Response missing 'predictions' field")

        predictions = response["predictions"]
        if not isinstance(predictions, list):
            raise ValueError(
                f"Expected 'predictions' to be a list, got {type(predictions)}"
            )

        data = []
        for position, vector in enumerate(predictions):
            if not isinstance(vector, list):
                raise ValueError(
                    f"Expected embedding at index {position} to be a list, got {type(vector)}"
                )
            data.append(
                {"object": "embedding", "index": position, "embedding": vector}
            )

        # The response carries no token counts, so usage is reported as 0.
        prompt_token_count = 0

        model_response.object = "list"
        model_response.data = data
        model_response.model = model
        setattr(
            model_response,
            "usage",
            Usage(
                prompt_tokens=prompt_token_count,
                completion_tokens=0,
                total_tokens=prompt_token_count,
            ),
        )
        return model_response
|
||||
@@ -0,0 +1,232 @@
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObject
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.llms.vertex_ai.vertex_ai_non_gemini import VertexAIError
|
||||
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
|
||||
from litellm.types.llms.vertex_ai import *
|
||||
from litellm.types.utils import EmbeddingResponse
|
||||
|
||||
from .types import *
|
||||
|
||||
|
||||
class VertexEmbedding(VertexBase):
|
||||
def __init__(self) -> None:
    """Initialise the handler by delegating to the parent `VertexBase` constructor."""
    super().__init__()
|
||||
|
||||
def embedding(
    self,
    model: str,
    input: Union[list, str],
    print_verbose,
    model_response: EmbeddingResponse,
    optional_params: dict,
    logging_obj: LiteLLMLoggingObject,
    custom_llm_provider: Literal[
        "vertex_ai", "vertex_ai_beta", "gemini"
    ],  # if it's vertex_ai or gemini (google ai studio)
    timeout: Optional[Union[float, httpx.Timeout]],
    api_key: Optional[str] = None,
    encoding=None,
    aembedding: Optional[bool] = False,
    api_base: Optional[str] = None,
    client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
    vertex_project: Optional[str] = None,
    vertex_location: Optional[str] = None,
    vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES] = None,
    gemini_api_key: Optional[str] = None,
    extra_headers: Optional[dict] = None,
) -> EmbeddingResponse:
    """
    Create embeddings through the Vertex AI / Google AI Studio endpoint.

    When `aembedding` is True, the coroutine from `async_embedding` is
    returned (the `client` argument is not forwarded on that path).
    Otherwise the request is sent synchronously: resolve credentials and
    URL, build the Vertex request, POST it, log, and convert the JSON
    response to the OpenAI format.

    Raises:
        VertexAIError: With the upstream status code on an HTTP error
            status, or with status 408 on a timeout.
    """
    if aembedding is True:
        return self.async_embedding(  # type: ignore
            model=model,
            input=input,
            logging_obj=logging_obj,
            model_response=model_response,
            optional_params=optional_params,
            encoding=encoding,
            custom_llm_provider=custom_llm_provider,
            timeout=timeout,
            api_base=api_base,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_credentials=vertex_credentials,
            gemini_api_key=gemini_api_key,
            extra_headers=extra_headers,
        )

    embedding_config = litellm.vertexAITextEmbeddingConfig

    # --- credentials and endpoint -------------------------------------
    use_v1beta1 = self.is_using_v1beta1_features(optional_params=optional_params)
    raw_auth_header, vertex_project = self._ensure_access_token(
        credentials=vertex_credentials,
        project_id=vertex_project,
        custom_llm_provider=custom_llm_provider,
    )
    auth_header, api_base = self._get_token_and_url(
        model=model,
        gemini_api_key=gemini_api_key,
        auth_header=raw_auth_header,
        vertex_project=vertex_project,
        vertex_location=vertex_location,
        vertex_credentials=vertex_credentials,
        stream=False,
        custom_llm_provider=custom_llm_provider,
        api_base=api_base,
        should_use_v1beta1_features=use_v1beta1,
        mode="embedding",
        # read from optional_params; defaults to False
        use_psc_endpoint_format=optional_params.get("use_psc_endpoint_format", False),
    )
    request_headers = self.set_headers(
        auth_header=auth_header, extra_headers=extra_headers
    )

    # --- request payload -----------------------------------------------
    request_body: VertexEmbeddingRequest = (
        embedding_config.transform_openai_request_to_vertex_embedding_request(
            input=input, optional_params=optional_params, model=model
        )
    )

    # --- HTTP client ---------------------------------------------------
    # Only a caller-supplied synchronous HTTPHandler is reused; anything
    # else (None or an async handler) is replaced by a fresh sync client.
    if isinstance(client, HTTPHandler):
        sync_client = client
    else:
        client_params = {"timeout": timeout} if timeout else {}
        sync_client = _get_httpx_client(params=client_params)

    ## LOGGING
    logging_obj.pre_call(
        input=request_body,
        api_key="",
        additional_args={
            "complete_input_dict": request_body,
            "api_base": api_base,
            "headers": request_headers,
        },
    )

    try:
        http_response = sync_client.post(url=api_base, headers=request_headers, json=request_body)  # type: ignore
        http_response.raise_for_status()
    except httpx.HTTPStatusError as err:
        raise VertexAIError(
            status_code=err.response.status_code, message=err.response.text
        )
    except httpx.TimeoutException:
        raise VertexAIError(status_code=408, message="Timeout error occurred.")

    response_json = http_response.json()

    ## LOGGING POST-CALL
    logging_obj.post_call(input=input, api_key=None, original_response=response_json)

    return embedding_config.transform_vertex_response_to_openai(
        response=response_json, model=model, model_response=model_response
    )
|
||||
|
||||
async def async_embedding(
    self,
    model: str,
    input: Union[list, str],
    model_response: EmbeddingResponse,
    logging_obj: LiteLLMLoggingObject,
    optional_params: dict,
    custom_llm_provider: Literal[
        "vertex_ai", "vertex_ai_beta", "gemini"
    ],  # if it's vertex_ai or gemini (google ai studio)
    timeout: Optional[Union[float, httpx.Timeout]],
    api_base: Optional[str] = None,
    client: Optional[AsyncHTTPHandler] = None,
    vertex_project: Optional[str] = None,
    vertex_location: Optional[str] = None,
    vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES] = None,
    gemini_api_key: Optional[str] = None,
    extra_headers: Optional[dict] = None,
    encoding=None,
) -> EmbeddingResponse:
    """
    Async embedding implementation.

    Obtains an access token, resolves the request URL, converts the
    OpenAI-style input into a Vertex embedding request, POSTs it and converts
    the JSON reply back into ``model_response``.

    Raises:
        VertexAIError: with the upstream status code and body when the
            response has an error status, or with status 408 on an httpx
            timeout.
    """
    should_use_v1beta1_features = self.is_using_v1beta1_features(
        optional_params=optional_params
    )
    # The token helper also returns a project id; it replaces the value passed in.
    token, vertex_project = await self._ensure_access_token_async(
        credentials=vertex_credentials,
        project_id=vertex_project,
        custom_llm_provider=custom_llm_provider,
    )
    # Extract use_psc_endpoint_format from optional_params
    use_psc_endpoint_format = optional_params.get("use_psc_endpoint_format", False)

    auth_header, api_base = self._get_token_and_url(
        model=model,
        gemini_api_key=gemini_api_key,
        auth_header=token,
        vertex_project=vertex_project,
        vertex_location=vertex_location,
        vertex_credentials=vertex_credentials,
        stream=False,
        custom_llm_provider=custom_llm_provider,
        api_base=api_base,
        should_use_v1beta1_features=should_use_v1beta1_features,
        mode="embedding",
        use_psc_endpoint_format=use_psc_endpoint_format,
    )
    request_headers = self.set_headers(
        auth_header=auth_header, extra_headers=extra_headers
    )
    request_body: VertexEmbeddingRequest = litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request(
        input=input, optional_params=optional_params, model=model
    )

    # Only forward a timeout to the client factory when one was supplied.
    client_kwargs = {"timeout": timeout} if timeout else {}
    # isinstance() is False for None, so this also covers "no client given".
    if not isinstance(client, AsyncHTTPHandler):
        client = get_async_httpx_client(
            params=client_kwargs, llm_provider=litellm.LlmProviders.VERTEX_AI
        )

    ## LOGGING
    logging_obj.pre_call(
        input=request_body,
        api_key="",
        additional_args={
            "complete_input_dict": request_body,
            "api_base": api_base,
            "headers": request_headers,
        },
    )

    try:
        response = await client.post(api_base, headers=request_headers, json=request_body)  # type: ignore
        response.raise_for_status()
    except httpx.HTTPStatusError as err:
        raise VertexAIError(
            status_code=err.response.status_code, message=err.response.text
        )
    except httpx.TimeoutException:
        raise VertexAIError(status_code=408, message="Timeout error occurred.")

    response_payload = response.json()

    ## LOGGING POST-CALL
    logging_obj.post_call(
        input=input, api_key=None, original_response=response_payload
    )

    return litellm.vertexAITextEmbeddingConfig.transform_vertex_response_to_openai(
        response=response_payload, model=model, model_response=model_response
    )
|
||||
@@ -0,0 +1,285 @@
|
||||
import types
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from litellm.types.utils import EmbeddingResponse, Usage
|
||||
|
||||
from .types import *
|
||||
|
||||
|
||||
class VertexAITextEmbeddingConfig(BaseModel):
    """
    Translates between OpenAI-style embedding calls and Vertex AI text-embedding requests/responses.

    Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#TextEmbeddingInput

    Args:
        auto_truncate: Optional(bool) If True, will truncate input text to fit within the model's max input length.
        task_type: Optional(str) The type of task to be performed. The default is "RETRIEVAL_QUERY".
        title: Optional(str) The title of the document to be embedded. (only valid with task_type=RETRIEVAL_DOCUMENT).
    """

    auto_truncate: Optional[bool] = None
    task_type: Optional[
        Literal[
            "RETRIEVAL_QUERY",
            "RETRIEVAL_DOCUMENT",
            "SEMANTIC_SIMILARITY",
            "CLASSIFICATION",
            "CLUSTERING",
            "QUESTION_ANSWERING",
            "FACT_VERIFICATION",
        ]
    ] = None
    title: Optional[str] = None

    def __init__(
        self,
        auto_truncate: Optional[bool] = None,
        task_type: Optional[
            Literal[
                "RETRIEVAL_QUERY",
                "RETRIEVAL_DOCUMENT",
                "SEMANTIC_SIMILARITY",
                "CLASSIFICATION",
                "CLUSTERING",
                "QUESTION_ANSWERING",
                "FACT_VERIFICATION",
            ]
        ] = None,
        title: Optional[str] = None,
    ) -> None:
        """Store every non-None argument as an attribute on the class itself (not the instance)."""
        supplied = (
            ("auto_truncate", auto_truncate),
            ("task_type", task_type),
            ("title", title),
        )
        for name, value in supplied:
            if value is not None:
                setattr(self.__class__, name, value)

    @classmethod
    def get_config(cls):
        """Return the class-level settings: non-dunder, non-callable, non-None entries of ``cls.__dict__``."""
        callable_kinds = (
            types.FunctionType,
            types.BuiltinFunctionType,
            classmethod,
            staticmethod,
        )
        return {
            name: value
            for name, value in cls.__dict__.items()
            if value is not None
            and not name.startswith("__")
            and not isinstance(value, callable_kinds)
        }

    def get_supported_openai_params(self):
        """List the OpenAI embedding parameters this config knows how to map."""
        return ["dimensions"]

    def map_openai_params(
        self, non_default_params: dict, optional_params: dict, kwargs: dict
    ):
        """
        Copy supported OpenAI parameters into ``optional_params`` under their Vertex names.

        ``dimensions`` becomes ``outputDimensionality``; an ``input_type`` entry
        is popped from ``kwargs`` and stored as ``task_type``. Both dicts are
        mutated in place and returned.
        """
        if "dimensions" in non_default_params:
            optional_params["outputDimensionality"] = non_default_params["dimensions"]

        if "input_type" in kwargs:
            optional_params["task_type"] = kwargs.pop("input_type")
        return optional_params, kwargs

    def get_mapped_special_auth_params(self) -> dict:
        """
        Common auth params across bedrock/vertex_ai/azure/watsonx
        """
        return {"project": "vertex_project", "region_name": "vertex_location"}

    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
        """Copy recognised generic auth params into ``optional_params`` under their Vertex names."""
        vertex_names = self.get_mapped_special_auth_params()

        for generic_name, value in non_default_params.items():
            if generic_name not in vertex_names:
                continue
            optional_params[vertex_names[generic_name]] = value
        return optional_params

    def transform_openai_request_to_vertex_embedding_request(
        self, input: Union[list, str], optional_params: dict, model: str
    ) -> VertexEmbeddingRequest:
        """
        Transforms an openai request to a vertex embedding request.

        All-digit model names are routed to the fine-tuned request builder and
        BGE models to ``VertexBGEConfig``; everything else gets one
        ``TextEmbeddingInput`` per input text.
        """
        # Import here to avoid circular import issues with litellm.__init__
        from litellm.llms.vertex_ai.vertex_embeddings.bge import VertexBGEConfig

        if model.isdigit():
            return self._transform_openai_request_to_fine_tuned_embedding_request(
                input, optional_params, model
            )
        if VertexBGEConfig.is_bge_model(model):
            return VertexBGEConfig.transform_request(
                input=input, optional_params=optional_params, model=model
            )

        task_type: Optional[TaskType] = optional_params.get("task_type")
        title = optional_params.get("title")

        # A single string is treated as a one-element batch.
        texts = [input] if isinstance(input, str) else input
        instances: List[TextEmbeddingInput] = [
            self.create_embedding_input(content=text, task_type=task_type, title=title)
            for text in texts
        ]

        return VertexEmbeddingRequest(
            instances=instances,
            parameters=EmbeddingParameters(**optional_params),
        )

    def _transform_openai_request_to_fine_tuned_embedding_request(
        self, input: Union[list, str], optional_params: dict, model: str
    ) -> VertexEmbeddingRequest:
        """
        Transforms an openai request to a vertex fine-tuned embedding request.

        Vertex Doc: https://console.cloud.google.com/vertex-ai/model-garden?hl=en&project=adroit-crow-413218&pageState=(%22galleryStateKey%22:(%22f%22:(%22g%22:%5B%5D,%22o%22:%5B%5D),%22s%22:%22%22))
        Sample Request:

        ```json
        {
            "instances" : [
                {
                "inputs": "How would the Future of AI in 10 Years look?",
                "parameters": {
                    "max_new_tokens": 128,
                    "temperature": 1.0,
                    "top_p": 0.9,
                    "top_k": 10
                }
                }
            ]
        }
        ```
        """
        # A single string is treated as a one-element batch.
        texts = [input] if isinstance(input, str) else input
        instances: List[TextEmbeddingFineTunedInput] = [
            TextEmbeddingFineTunedInput(inputs=text) for text in texts
        ]

        parameters = TextEmbeddingFineTunedParameters(**optional_params)
        # Remove 'shared_session' from parameters if present
        parameters.pop("shared_session", None)  # type: ignore

        return VertexEmbeddingRequest(instances=instances, parameters=parameters)

    def create_embedding_input(
        self,
        content: str,
        task_type: Optional[TaskType] = None,
        title: Optional[str] = None,
    ) -> TextEmbeddingInput:
        """
        Creates a TextEmbeddingInput object.

        Vertex requires a List of TextEmbeddingInput objects. This helper function creates a single TextEmbeddingInput object.

        Args:
            content (str): The content to be embedded.
            task_type (Optional[TaskType]): The type of task to be performed.
            title (Optional[str]): The title of the document to be embedded.

        Returns:
            TextEmbeddingInput: A TextEmbeddingInput object; ``task_type`` and
            ``title`` keys are present only when the arguments are not None.
        """
        single_input = TextEmbeddingInput(content=content)
        if task_type is not None:
            single_input["task_type"] = task_type
        if title is not None:
            single_input["title"] = title
        return single_input

    def transform_vertex_response_to_openai(
        self, response: dict, model: str, model_response: EmbeddingResponse
    ) -> EmbeddingResponse:
        """
        Transforms a vertex embedding response to an openai response.

        All-digit model names use the fine-tuned response shape and BGE models
        are delegated to ``VertexBGEConfig``. Otherwise each prediction's
        ``embeddings.values`` becomes one data entry and the per-prediction
        ``statistics.token_count`` values are summed into the usage block.
        """
        if model.isdigit():
            return self._transform_vertex_response_to_openai_for_fine_tuned_models(
                response, model, model_response
            )

        # Import here to avoid circular import issues with litellm.__init__
        from litellm.llms.vertex_ai.vertex_embeddings.bge import VertexBGEConfig

        if VertexBGEConfig.is_bge_model(model):
            return VertexBGEConfig.transform_response(
                response=response, model=model, model_response=model_response
            )

        data = []
        prompt_tokens: int = 0
        for position, prediction in enumerate(response["predictions"]):
            embeddings = prediction["embeddings"]
            data.append(
                {
                    "object": "embedding",
                    "index": position,
                    "embedding": embeddings["values"],
                }
            )
            prompt_tokens += embeddings["statistics"]["token_count"]

        model_response.object = "list"
        model_response.data = data
        model_response.model = model
        setattr(
            model_response,
            "usage",
            Usage(
                prompt_tokens=prompt_tokens,
                completion_tokens=0,
                total_tokens=prompt_tokens,
            ),
        )
        return model_response

    def _transform_vertex_response_to_openai_for_fine_tuned_models(
        self, response: dict, model: str, model_response: EmbeddingResponse
    ) -> EmbeddingResponse:
        """
        Transforms a vertex fine-tuned model embedding response to an openai response format.
        """
        # For fine-tuned models, we don't get token counts in the response
        prompt_tokens = 0

        # The embedding values are nested one level deeper, hence the [0].
        data = [
            {"object": "embedding", "index": position, "embedding": nested_values[0]}
            for position, nested_values in enumerate(response["predictions"])
        ]

        model_response.object = "list"
        model_response.data = data
        model_response.model = model
        setattr(
            model_response,
            "usage",
            Usage(
                prompt_tokens=prompt_tokens,
                completion_tokens=0,
                total_tokens=prompt_tokens,
            ),
        )
        return model_response
|
||||
@@ -0,0 +1,74 @@
|
||||
"""
|
||||
Types for Vertex Embeddings Requests
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
class TaskType(str, Enum):
    """String-valued enum of the ``task_type`` values accepted on embedding inputs."""

    RETRIEVAL_QUERY = "RETRIEVAL_QUERY"
    RETRIEVAL_DOCUMENT = "RETRIEVAL_DOCUMENT"
    SEMANTIC_SIMILARITY = "SEMANTIC_SIMILARITY"
    CLASSIFICATION = "CLASSIFICATION"
    CLUSTERING = "CLUSTERING"
    QUESTION_ANSWERING = "QUESTION_ANSWERING"
    FACT_VERIFICATION = "FACT_VERIFICATION"
    CODE_RETRIEVAL_QUERY = "CODE_RETRIEVAL_QUERY"
|
||||
|
||||
|
||||
class TextEmbeddingInput(TypedDict, total=False):
    """One ``instances`` entry of a text-embedding request; all keys are optional (``total=False``)."""

    # Text to embed.
    content: str
    # Task the embedding is intended for.
    task_type: Optional[TaskType]
    # Title of the document being embedded.
    title: Optional[str]
|
||||
|
||||
|
||||
class TextEmbeddingBGEInput(TypedDict, total=False):
    """
    ``instances`` entry variant that carries the text under ``prompt`` instead of ``content``.

    NOTE(review): presumably the shape built for BGE models — confirm against
    the BGE transformation code.
    """

    # Text to embed.
    prompt: str
    task_type: Optional[TaskType]
    title: Optional[str]
|
||||
|
||||
|
||||
# Fine-tuned models require a different input format
|
||||
# Ref: https://console.cloud.google.com/vertex-ai/model-garden?hl=en&project=adroit-crow-413218&pageState=(%22galleryStateKey%22:(%22f%22:(%22g%22:%5B%5D,%22o%22:%5B%5D),%22s%22:%22%22))
|
||||
class TextEmbeddingFineTunedInput(TypedDict, total=False):
    """One ``instances`` entry for a fine-tuned embedding model; the text goes under ``inputs``."""

    # Text to embed.
    inputs: str
|
||||
|
||||
|
||||
class TextEmbeddingFineTunedParameters(TypedDict, total=False):
    """Optional ``parameters`` payload for fine-tuned embedding requests; all keys are optional."""

    max_new_tokens: Optional[int]
    temperature: Optional[float]
    top_p: Optional[float]
    top_k: Optional[int]
|
||||
|
||||
|
||||
class EmbeddingParameters(TypedDict, total=False):
    """Optional ``parameters`` payload for standard text-embedding requests; all keys are optional."""

    auto_truncate: Optional[bool]
    output_dimensionality: Optional[int]
|
||||
|
||||
|
||||
class VertexEmbeddingRequest(TypedDict, total=False):
    """Top-level embedding request body: a list of instances plus optional parameters."""

    # One entry per input text; the entry type depends on the model family.
    instances: Union[
        List[TextEmbeddingInput],
        List[TextEmbeddingBGEInput],
        List[TextEmbeddingFineTunedInput],
    ]
    # Request-level options matching the chosen instance type.
    parameters: Optional[Union[EmbeddingParameters, TextEmbeddingFineTunedParameters]]
|
||||
|
||||
|
||||
# Example usage:
|
||||
# example_request: VertexEmbeddingRequest = {
|
||||
# "instances": [
|
||||
# {
|
||||
# "content": "I would like embeddings for this text!",
|
||||
# "task_type": "RETRIEVAL_DOCUMENT",
|
||||
# "title": "document title"
|
||||
# }
|
||||
# ],
|
||||
# "parameters": {
|
||||
# "auto_truncate": True,
|
||||
# "output_dimensionality": None
|
||||
# }
|
||||
# }
|
||||
@@ -0,0 +1 @@
|
||||
"""Vertex AI Gemma-AI Models Handler"""
|
||||
@@ -0,0 +1,144 @@
|
||||
"""
|
||||
API Handler for calling Vertex AI Gemma Models
|
||||
|
||||
These models use a custom prediction endpoint format that wraps messages in 'instances'
|
||||
with @requestFormat: "chatCompletions" and returns responses wrapped in 'predictions'.
|
||||
|
||||
Usage:
|
||||
|
||||
response = litellm.completion(
|
||||
model="vertex_ai/gemma/gemma-3-12b-it-1222199011122",
|
||||
messages=[{"role": "user", "content": "What is machine learning?"}],
|
||||
vertex_project="your-project-id",
|
||||
vertex_location="us-central1",
|
||||
)
|
||||
|
||||
Sent to this route when `model` is in the format `vertex_ai/gemma/{MODEL_NAME}`
|
||||
|
||||
The API expects a custom endpoint URL format:
|
||||
https://{ENDPOINT_NUMBER}.{location}-{REGION_NUMBER}.prediction.vertexai.goog/v1/projects/{PROJECT_ID}/locations/{location}/endpoints/{ENDPOINT_ID}:predict
|
||||
"""
|
||||
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import httpx # type: ignore
|
||||
|
||||
from litellm.utils import ModelResponse
|
||||
|
||||
from ..common_utils import VertexAIError, get_vertex_base_model_name
|
||||
from ..vertex_llm_base import VertexBase
|
||||
|
||||
|
||||
class VertexAIGemmaModels(VertexBase):
    """Entry point for completion calls routed to a Vertex AI Gemma prediction endpoint."""

    def __init__(self) -> None:
        # NOTE(review): VertexBase.__init__ is not invoked here, so the
        # attributes it sets are not initialised on this object — confirm
        # that is intended.
        pass

    def completion(
        self,
        model: str,
        messages: list,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        logging_obj,
        api_base: Optional[str],
        optional_params: dict,
        custom_prompt_dict: dict,
        headers: Optional[dict],
        timeout: Union[float, httpx.Timeout],
        litellm_params: dict,
        vertex_project=None,
        vertex_location=None,
        vertex_credentials=None,
        logger_fn=None,
        acompletion: bool = False,
        client=None,
    ):
        """
        Handles calling Vertex AI Gemma Models

        Sent to this route when `model` is in the format `vertex_ai/gemma/{MODEL_NAME}`

        Obtains an access token, makes sure ``api_base`` targets the
        ``:predict`` endpoint and delegates the request to
        ``VertexGemmaConfig.completion``. ``optional_params["stream"]`` is
        normalised to a bool in place.

        Raises:
            VertexAIError: 400 when the ``vertexai`` SDK cannot be imported,
                lacks the ``preview`` namespace, or ``api_base`` is missing;
                500 for any other failure that carries no ``status_code``.
                Exceptions that already have a ``status_code`` are re-raised
                unchanged.
        """
        try:
            import vertexai

            from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
                VertexLLM,
            )
            from litellm.llms.vertex_ai.vertex_gemma_models.transformation import (
                VertexGemmaConfig,
            )
        except Exception as e:
            raise VertexAIError(
                status_code=400,
                message=f"""vertexai import failed please run `pip install -U "google-cloud-aiplatform>=1.38"`. Got error: {e}""",
            ) from e

        # The previous guard evaluated `vertexai.preview` precisely when that
        # attribute was missing, so an outdated SDK surfaced as a bare
        # AttributeError instead of the upgrade hint below. Checking only for
        # `preview` keeps the outcome identical whenever it exists.
        if not hasattr(vertexai, "preview"):
            raise VertexAIError(
                status_code=400,
                message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""",
            )
        try:
            model = get_vertex_base_model_name(model=model)
            vertex_httpx_logic = VertexLLM()

            access_token, project_id = vertex_httpx_logic._ensure_access_token(
                credentials=vertex_credentials,
                project_id=vertex_project,
                custom_llm_provider="vertex_ai",
            )

            gemma_transformation = VertexGemmaConfig()

            ## CONSTRUCT API BASE
            stream: bool = optional_params.get("stream", False) or False
            optional_params["stream"] = stream

            # If api_base is not provided, it should be set as an environment variable
            # or passed explicitly because the endpoint URL is unique per deployment
            if api_base is None:
                raise VertexAIError(
                    status_code=400,
                    message="api_base is required for Vertex AI Gemma models. Please provide the full endpoint URL.",
                )

            # Check if we need to append :predict
            if not api_base.endswith(":predict"):
                _, api_base = self._check_custom_proxy(
                    api_base=api_base,
                    custom_llm_provider="vertex_ai",
                    gemini_api_key=None,
                    endpoint="predict",
                    stream=stream,
                    auth_header=None,
                    url=api_base,
                )
            # If api_base already ends with :predict, use it as-is

            # Use the custom transformation handler for gemma models
            return gemma_transformation.completion(
                model=model,
                messages=messages,
                api_base=api_base,
                api_key=access_token,
                custom_prompt_dict=custom_prompt_dict,
                model_response=model_response,
                print_verbose=print_verbose,
                logging_obj=logging_obj,
                optional_params=optional_params,
                acompletion=acompletion,
                litellm_params=litellm_params,
                logger_fn=logger_fn,
                client=client,
                timeout=timeout,
                encoding=encoding,
                custom_llm_provider="vertex_ai",
            )

        except Exception as e:
            # Errors that already carry an HTTP-style status are passed through
            # untouched; everything else is wrapped as a 500.
            if hasattr(e, "status_code"):
                raise
            raise VertexAIError(status_code=500, message=str(e)) from e
|
||||
@@ -0,0 +1,360 @@
|
||||
"""
|
||||
Transformation logic for Vertex AI Gemma Models
|
||||
|
||||
Handles the custom request/response format:
|
||||
- Request: Wraps messages in 'instances' with @requestFormat: "chatCompletions"
|
||||
- Response: Extracts data from 'predictions' wrapper
|
||||
|
||||
The actual message transformation reuses OpenAIGPTConfig since Gemma uses OpenAI-compatible format.
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Union, cast
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
|
||||
class VertexGemmaConfig(OpenAIGPTConfig):
    """
    Request/response adapter for Gemma deployments served by Vertex AI.

    Builds on OpenAIGPTConfig: the OpenAI-style request produced by the parent
    is wrapped in an ``instances`` list, and the ``predictions`` field of the
    reply is unwrapped before being converted into a ModelResponse.
    """

    def __init__(self) -> None:
        super().__init__()

    def should_fake_stream(
        self,
        model: Optional[str],
        stream: Optional[bool],
        custom_llm_provider: Optional[str] = None,
    ) -> bool:
        """
        Always request client-side fake streaming.

        Vertex AI Gemma models do not support streaming, so this returns True
        regardless of the arguments.
        """
        return True

    def _handle_fake_stream_response(
        self,
        model_response: ModelResponse,
        stream: bool,
    ) -> Union[ModelResponse, Any]:
        """
        Wrap a finished response in a fake stream when streaming was requested.

        Args:
            model_response: The completed model response
            stream: Whether streaming was requested

        Returns:
            MockResponseIterator if stream=True, otherwise the model_response
        """
        if not stream:
            return model_response

        from litellm.llms.base_llm.base_model_iterator import MockResponseIterator

        return MockResponseIterator(model_response=model_response)

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Build the Vertex Gemma request body.

        The parent class produces the OpenAI-compatible request; the keys
        ``model``, ``stream`` and ``stream_options`` are dropped and the rest
        is placed in a single ``instances`` entry tagged with
        ``@requestFormat: chatCompletions``.
        """
        chat_request = super().transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )

        # Not needed/supported by Vertex Gemma; streaming is faked client-side.
        for unsupported_key in ("model", "stream", "stream_options"):
            chat_request.pop(unsupported_key, None)

        instance = {"@requestFormat": "chatCompletions", **chat_request}
        return {"instances": [instance]}

    def _unwrap_predictions_response(
        self,
        response_json: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Return the ``predictions`` field of a Vertex Gemma response.

        Vertex Gemma wraps the OpenAI-compatible response in 'predictions';
        the unwrapped value is what the standard response converter consumes.

        Raises:
            BaseLLMException: status 422 when 'predictions' is absent.
        """
        if "predictions" in response_json:
            return response_json["predictions"]

        raise BaseLLMException(
            status_code=422,
            message="Invalid response format: missing 'predictions' field",
        )

    def completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        api_key: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable,
        logging_obj: Any,
        optional_params: dict,
        acompletion: bool,
        litellm_params: dict,
        logger_fn: Optional[Callable] = None,
        client: Optional[httpx.Client] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        encoding=None,
        custom_llm_provider: str = "vertex_ai",
    ):
        """
        Make completion request to Vertex Gemma endpoint.

        Dispatches to the async implementation (returning its coroutine) when
        ``acompletion`` is true, otherwise runs the sync implementation.
        ``custom_prompt_dict``, ``logger_fn``, ``client`` and
        ``custom_llm_provider`` are accepted but not forwarded.
        """
        forwarded = dict(
            model=model,
            messages=messages,
            api_base=api_base,
            api_key=api_key,
            model_response=model_response,
            print_verbose=print_verbose,
            logging_obj=logging_obj,
            optional_params=optional_params,
            litellm_params=litellm_params,
            timeout=timeout,
            encoding=encoding,
        )
        handler = self._async_completion if acompletion else self._sync_completion
        return handler(**forwarded)

    def _sync_completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        api_key: str,
        model_response: ModelResponse,
        print_verbose: Callable,
        logging_obj: Any,
        optional_params: dict,
        litellm_params: dict,
        timeout: Optional[Union[float, httpx.Timeout]],
        encoding: Any,
    ):
        """
        Synchronous completion request.

        Raises:
            BaseLLMException: with the upstream status code when the response
                status is not 200, or 422 when 'predictions' is missing.
        """
        from litellm.llms.custom_httpx.http_handler import HTTPHandler
        from litellm.utils import convert_to_model_response_object

        # Remember whether streaming was asked for; it is faked after the call.
        wants_stream = optional_params.get("stream", False)

        # Work on a copy so the caller's optional_params are left untouched.
        payload = self.transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params.copy(),
            litellm_params=litellm_params,
            headers={},
        )

        request_headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

        logging_obj.pre_call(
            input=messages,
            api_key=api_key,
            additional_args={
                "complete_input_dict": payload,
                "api_base": api_base,
            },
        )

        sync_client = HTTPHandler(concurrent_limit=1)
        raw_response = sync_client.post(
            url=api_base,
            headers=request_headers,
            json=payload,
            timeout=timeout,
        )
        if raw_response.status_code != 200:
            raise BaseLLMException(
                status_code=raw_response.status_code,
                message=f"Request failed: {raw_response.text}",
            )

        body = raw_response.json()

        # Unwrap predictions, then reuse litellm's standard response converter.
        converted = cast(
            ModelResponse,
            convert_to_model_response_object(
                response_object=self._unwrap_predictions_response(body),
                model_response_object=model_response,
                _response_headers={},
            ),
        )
        # Ensure model is set correctly
        converted.model = model

        logging_obj.post_call(
            input=messages,
            api_key=api_key,
            original_response=body,
            additional_args={"complete_input_dict": payload},
        )

        return self._handle_fake_stream_response(
            model_response=converted, stream=wants_stream
        )

    async def _async_completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        api_key: str,
        model_response: ModelResponse,
        print_verbose: Callable,
        logging_obj: Any,
        optional_params: dict,
        litellm_params: dict,
        timeout: Optional[Union[float, httpx.Timeout]],
        encoding: Any,
    ):
        """
        Asynchronous completion request.

        Raises:
            BaseLLMException: with the upstream status code when the response
                status is not 200, or 422 when 'predictions' is missing.
        """
        from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
        from litellm.types.utils import LlmProviders
        from litellm.utils import convert_to_model_response_object

        # Remember whether streaming was asked for; it is faked after the call.
        wants_stream = optional_params.get("stream", False)

        # NOTE(review): this relies on the inherited async_transform_request
        # producing the same 'instances' wrapper as transform_request in this
        # class — confirm against OpenAIGPTConfig.
        payload = await self.async_transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params.copy(),
            litellm_params=litellm_params,
            headers={},
        )

        request_headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

        logging_obj.pre_call(
            input=messages,
            api_key=api_key,
            additional_args={
                "complete_input_dict": payload,
                "api_base": api_base,
            },
        )

        async_client = get_async_httpx_client(
            llm_provider=LlmProviders.VERTEX_AI,
        )
        raw_response = await async_client.post(
            url=api_base,
            headers=request_headers,
            json=payload,
            timeout=timeout,
        )
        if raw_response.status_code != 200:
            raise BaseLLMException(
                status_code=raw_response.status_code,
                message=f"Request failed: {raw_response.text}",
            )

        body = raw_response.json()

        # Unwrap predictions, then reuse litellm's standard response converter.
        converted = cast(
            ModelResponse,
            convert_to_model_response_object(
                response_object=self._unwrap_predictions_response(body),
                model_response_object=model_response,
                _response_headers={},
            ),
        )
        # Ensure model is set correctly
        converted.model = model

        logging_obj.post_call(
            input=messages,
            api_key=api_key,
            original_response=body,
            additional_args={"complete_input_dict": payload},
        )

        return self._handle_fake_stream_response(
            model_response=converted, stream=wants_stream
        )
|
||||
@@ -0,0 +1,804 @@
|
||||
"""
|
||||
Base Vertex, Google AI Studio LLM Class
|
||||
|
||||
Handles Authentication and generating request urls for Vertex AI and Google AI Studio
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.asyncify import asyncify
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES, VertexPartnerProvider
|
||||
|
||||
from .common_utils import (
|
||||
_get_gemini_url,
|
||||
_get_vertex_url,
|
||||
all_gemini_url_modes,
|
||||
get_vertex_base_model_name,
|
||||
get_vertex_base_url,
|
||||
)
|
||||
|
||||
# Text of the ImportError raised when the optional Google auth libraries
# cannot be imported.
GOOGLE_IMPORT_ERROR_MESSAGE = (
    "Google Cloud SDK not found. "
    "Install it with: pip install 'litellm[google]' "
    "or pip install google-cloud-aiplatform"
)

# Static type checkers see the real credentials class; at runtime the alias is
# plain ``Any`` so nothing from google-auth is imported at module load.
if TYPE_CHECKING:
    from google.auth.credentials import Credentials as GoogleCredentialsObject
else:
    GoogleCredentialsObject = Any
|
||||
|
||||
|
||||
class VertexBase:
|
||||
def __init__(self) -> None:
    """Start with no tokens, no cached credentials, no project and no HTTP handler."""
    super().__init__()

    # Project / transport state.
    self.project_id: Optional[str] = None
    self.async_handler: Optional[AsyncHTTPHandler] = None

    # Token / credential state.
    self.access_token: Optional[str] = None
    self.refresh_token: Optional[str] = None
    self._credentials: Optional[GoogleCredentialsObject] = None

    # Cache: (credentials, requested project id) -> (credentials object, project id).
    self._credentials_project_mapping: Dict[
        Tuple[Optional[VERTEX_CREDENTIALS_TYPES], Optional[str]],
        Tuple[GoogleCredentialsObject, str],
    ] = dict()
|
||||
|
||||
def get_vertex_region(self, vertex_region: Optional[str], model: str) -> str:
    """
    Pick the Vertex AI region to use for ``model``.

    When the model's ``litellm.model_cost`` entry lists ``supported_regions``,
    the requested region is kept only if it appears in that list; otherwise
    (or when no region was requested) the first supported region is returned.
    Models without a region list keep the requested region, falling back to
    ``"us-central1"``.

    Args:
        vertex_region: Region requested by the caller, or ``None``.
        model: Model name, with or without the ``vertex_ai/`` prefix.

    Returns:
        The region to route the request to.
    """
    import litellm

    # model_cost is looked up with the "vertex_ai/" prefix; add it if missing.
    lookup_key = model if model.startswith("vertex_ai/") else f"vertex_ai/{model}"
    allowed_regions = litellm.model_cost.get(lookup_key, {}).get("supported_regions")

    if not allowed_regions:
        return vertex_region or "us-central1"

    fallback_region = allowed_regions[0]
    if vertex_region is None:
        # No explicit region: use the first supported one.
        return fallback_region
    if vertex_region in allowed_regions:
        return vertex_region

    # The requested region is not supported by this model: override it.
    verbose_logger.warning(
        "Vertex AI model '%s' does not support region '%s' "
        "(supported: %s). Routing to '%s'.",
        model,
        vertex_region,
        allowed_regions,
        fallback_region,
    )
    return fallback_region
|
||||
|
||||
def load_auth(
    self, credentials: Optional[VERTEX_CREDENTIALS_TYPES], project_id: Optional[str]
) -> Tuple[Any, str]:
    """
    Build Google credentials, refresh them, and resolve the project id.

    Args:
        credentials: ``None`` to use application-default credentials, a
            ``dict`` of credential info, or a ``str`` holding either a path to
            a JSON file or the JSON document itself.
        project_id: Explicit project id. When ``None`` it is taken from the
            loaded credentials.

    Returns:
        ``(creds, project_id)``: the refreshed credentials object and the
        resolved project id.

    Raises:
        Exception: A string ``credentials`` value could not be read or parsed
            as JSON (the underlying error is chained as ``__cause__``).
        ValueError: ``credentials`` has an unsupported type, or no project id
            could be resolved.
        TypeError: The resolved project id is not a ``str``.
    """
    scopes = ["https://www.googleapis.com/auth/cloud-platform"]

    if credentials is not None:
        if isinstance(credentials, str):
            # NOTE(review): both debug lines and the error below include the raw
            # `credentials` value, which may be the JSON key material itself.
            verbose_logger.debug(
                "Vertex: Loading vertex credentials from %s", credentials
            )
            verbose_logger.debug(
                "Vertex: checking if credentials is a valid path, os.path.exists(%s)=%s, current dir %s",
                credentials,
                os.path.exists(credentials),
                os.getcwd(),
            )

            try:
                if os.path.exists(credentials):
                    # Context manager so the file handle is always closed.
                    with open(credentials) as credentials_file:
                        json_obj = json.load(credentials_file)
                else:
                    json_obj = json.loads(credentials)
            except Exception as e:
                raise Exception(
                    "Unable to load vertex credentials from environment. Got={}".format(
                        credentials
                    )
                ) from e
        elif isinstance(credentials, dict):
            json_obj = credentials
        else:
            raise ValueError(
                "Invalid credentials type: {}".format(type(credentials))
            )

        # Check if the JSON object contains Workload Identity Federation configuration
        if "type" in json_obj and json_obj["type"] == "external_account":
            # If environment_id key contains "aws" value it corresponds to an AWS config file
            credential_source = json_obj.get("credential_source", {})
            environment_id = (
                credential_source.get("environment_id", "")
                if isinstance(credential_source, dict)
                else ""
            )
            if isinstance(environment_id, str) and "aws" in environment_id:
                # Check if explicit AWS params are in the JSON (bypasses metadata)
                from litellm.llms.vertex_ai.vertex_ai_aws_wif import (
                    VertexAIAwsWifAuth,
                )

                aws_params = VertexAIAwsWifAuth.extract_aws_params(json_obj)
                if aws_params:
                    creds = VertexAIAwsWifAuth.credentials_from_explicit_aws(
                        json_obj,
                        aws_params=aws_params,
                        scopes=scopes,
                    )
                else:
                    creds = self._credentials_from_identity_pool_with_aws(
                        json_obj,
                        scopes=scopes,
                    )
            else:
                creds = self._credentials_from_identity_pool(
                    json_obj,
                    scopes=scopes,
                )
        # Check if the JSON object contains Authorized User configuration (via gcloud auth application-default login)
        elif "type" in json_obj and json_obj["type"] == "authorized_user":
            creds = self._credentials_from_authorized_user(
                json_obj,
                scopes=scopes,
            )
            if project_id is None:
                # authorized user credentials don't have a project_id, only quota_project_id
                project_id = creds.quota_project_id
        else:
            creds = self._credentials_from_service_account(
                json_obj,
                scopes=scopes,
            )

        if project_id is None:
            project_id = getattr(creds, "project_id", None)
    else:
        creds, creds_project_id = self._credentials_from_default_auth(scopes=scopes)
        if project_id is None:
            project_id = creds_project_id

    self.refresh_auth(creds)

    if not project_id:
        raise ValueError("Could not resolve project_id")

    if not isinstance(project_id, str):
        raise TypeError(
            f"Expected project_id to be a str but got {type(project_id)}"
        )

    return creds, project_id
|
||||
|
||||
# Google Auth Helpers -- extracted for mocking purposes in tests
|
||||
def _credentials_from_identity_pool(self, json_obj, scopes):
    """
    Create ``google.auth.identity_pool`` credentials from ``json_obj``.

    ``scopes`` are applied with ``with_scopes`` only when scopes were given and
    the credentials report ``requires_scopes``.

    Raises:
        ImportError: ``google.auth`` cannot be imported.
    """
    try:
        from google.auth import identity_pool
    except ImportError:
        raise ImportError(GOOGLE_IMPORT_ERROR_MESSAGE)

    pool_credentials = identity_pool.Credentials.from_info(json_obj)
    needs_scoping = bool(scopes) and getattr(
        pool_credentials, "requires_scopes", False
    )
    if not needs_scoping:
        return pool_credentials
    return pool_credentials.with_scopes(scopes)
|
||||
|
||||
def _credentials_from_identity_pool_with_aws(self, json_obj, scopes):
    """
    Create ``google.auth.aws`` credentials from ``json_obj``.

    ``scopes`` are applied with ``with_scopes`` only when scopes were given and
    the credentials report ``requires_scopes``.

    Raises:
        ImportError: ``google.auth`` cannot be imported.
    """
    try:
        from google.auth import aws
    except ImportError:
        raise ImportError(GOOGLE_IMPORT_ERROR_MESSAGE)

    aws_credentials = aws.Credentials.from_info(json_obj)
    needs_scoping = bool(scopes) and getattr(
        aws_credentials, "requires_scopes", False
    )
    if not needs_scoping:
        return aws_credentials
    return aws_credentials.with_scopes(scopes)
|
||||
|
||||
def _credentials_from_authorized_user(self, json_obj, scopes):
    """
    Create OAuth2 authorized-user credentials from ``json_obj`` with ``scopes``.

    Raises:
        ImportError: ``google.oauth2.credentials`` cannot be imported.
    """
    try:
        from google.oauth2 import credentials as oauth2_credentials
    except ImportError:
        raise ImportError(GOOGLE_IMPORT_ERROR_MESSAGE)

    factory = oauth2_credentials.Credentials.from_authorized_user_info
    return factory(json_obj, scopes=scopes)
|
||||
|
||||
def _credentials_from_service_account(self, json_obj, scopes):
    """
    Create service-account credentials from ``json_obj`` with ``scopes``.

    Raises:
        ImportError: ``google.oauth2.service_account`` cannot be imported.
    """
    try:
        from google.oauth2 import service_account
    except ImportError:
        raise ImportError(GOOGLE_IMPORT_ERROR_MESSAGE)

    factory = service_account.Credentials.from_service_account_info
    return factory(json_obj, scopes=scopes)
|
||||
|
||||
def _credentials_from_default_auth(self, scopes):
    """
    Return whatever ``google.auth.default(scopes=scopes)`` returns.

    ``load_auth`` unpacks the result as ``(credentials, project_id)``.

    Raises:
        ImportError: ``google.auth`` cannot be imported.
    """
    try:
        import google.auth as google_auth
    except ImportError:
        raise ImportError(GOOGLE_IMPORT_ERROR_MESSAGE)

    default_auth_result = google_auth.default(scopes=scopes)
    return default_auth_result
|
||||
|
||||
def get_default_vertex_location(self) -> str:
    """Return the region used when the caller does not specify one."""
    default_location = "us-central1"
    return default_location
|
||||
|
||||
def get_api_base(
    self, api_base: Optional[str], vertex_location: Optional[str]
) -> str:
    """
    Return the API base URL to use.

    A non-empty ``api_base`` is returned unchanged. Otherwise the URL is built
    by ``get_vertex_base_url`` from ``vertex_location``, or from the default
    location when none is given.
    """
    if not api_base:
        location = vertex_location or self.get_default_vertex_location()
        return get_vertex_base_url(location)
    return api_base
|
||||
|
||||
@staticmethod
def create_vertex_url(
    vertex_location: str,
    vertex_project: str,
    partner: VertexPartnerProvider,
    stream: Optional[bool],
    model: str,
    api_base: Optional[str] = None,
) -> str:
    """
    Return the base url for the vertex partner models.

    Args:
        vertex_location: Region placed in the URL path (and used for the
            default host when ``api_base`` is ``None``).
        vertex_project: Project id placed in the URL path.
        partner: Partner provider; selects the publisher path and API version.
        stream: Truthy selects ``:streamRawPredict``, otherwise ``:rawPredict``
            (not used for the llama endpoint).
        model: Model name placed in the URL path (not used for llama).
        api_base: Host prefix; built with ``get_vertex_base_url`` when ``None``.
    """
    base = get_vertex_base_url(vertex_location) if api_base is None else api_base
    location_path = f"projects/{vertex_project}/locations/{vertex_location}"

    if partner == VertexPartnerProvider.llama:
        return f"{base}/v1/{location_path}/endpoints/openapi/chat/completions"

    predict_verb = "streamRawPredict" if stream else "rawPredict"
    if partner == VertexPartnerProvider.mistralai:
        return f"{base}/v1/{location_path}/publishers/mistralai/models/{model}:{predict_verb}"
    if partner == VertexPartnerProvider.ai21:
        return f"{base}/v1beta1/{location_path}/publishers/ai21/models/{model}:{predict_verb}"
    if partner == VertexPartnerProvider.claude:
        return f"{base}/v1/{location_path}/publishers/anthropic/models/{model}:{predict_verb}"
    # Any other partner value falls off the end and yields None.
|
||||
|
||||
def get_complete_vertex_url(
    self,
    custom_api_base: Optional[str],
    vertex_location: Optional[str],
    vertex_project: Optional[str],
    project_id: str,
    partner: VertexPartnerProvider,
    stream: Optional[bool],
    model: str,
) -> str:
    """
    Build the final request URL for a partner model.

    The partner URL is created first (on ``custom_api_base`` when given,
    otherwise on the regional default host). The text after its last ``:`` is
    then passed as the endpoint to ``_check_custom_proxy``, whose URL result is
    returned.

    Args:
        custom_api_base: Optional user-supplied API base.
        vertex_location: Requested region; resolved via ``get_vertex_region``.
        vertex_project: Preferred project id; ``project_id`` is used when falsy.
        project_id: Fallback project id.
        partner: Partner provider.
        stream: Whether a streaming endpoint is wanted.
        model: Model name.
    """
    effective_project = vertex_project or project_id

    # Use get_vertex_region to handle global-only models
    region = self.get_vertex_region(vertex_location, model)
    partner_url = VertexBase.create_vertex_url(
        vertex_location=region,
        vertex_project=effective_project,
        partner=partner,
        stream=stream,
        model=model,
        api_base=self.get_api_base(api_base=custom_api_base, vertex_location=region),
    )

    # Everything after the last ":" is treated as the endpoint name.
    _head, separator, suffix = partner_url.rpartition(":")
    endpoint = suffix if separator else ""

    _, final_url = self._check_custom_proxy(
        api_base=custom_api_base,
        custom_llm_provider="vertex_ai",
        gemini_api_key=None,
        endpoint=endpoint,
        stream=stream,
        auth_header=None,
        url=partner_url,
        model=model,
        vertex_project=effective_project,
        vertex_location=region,
        vertex_api_version="v1",  # Partner models typically use v1
    )
    return final_url
|
||||
|
||||
def refresh_auth(self, credentials: Any) -> None:
    """
    Refresh ``credentials`` in place using a ``google.auth`` transport request.

    Raises:
        ImportError: ``google.auth.transport.requests`` cannot be imported.
    """
    try:
        from google.auth.transport.requests import (
            Request,  # type: ignore[import-untyped]
        )
    except ImportError:
        raise ImportError(GOOGLE_IMPORT_ERROR_MESSAGE)

    transport_request = Request()
    credentials.refresh(transport_request)
|
||||
|
||||
def _ensure_access_token(
    self,
    credentials: Optional[VERTEX_CREDENTIALS_TYPES],
    project_id: Optional[str],
    custom_llm_provider: Literal[
        "vertex_ai", "vertex_ai_beta", "gemini"
    ],  # if it's vertex_ai or gemini (google ai studio)
) -> Tuple[str, str]:
    """
    Return ``(auth token, project id)`` for the given provider.

    For ``"gemini"`` both values are empty strings; every other provider is
    delegated to ``get_access_token``.
    """
    if custom_llm_provider == "gemini":
        return "", ""
    return self.get_access_token(
        credentials=credentials,
        project_id=project_id,
    )
|
||||
|
||||
def is_using_v1beta1_features(self, optional_params: dict) -> bool:
    """
    Helper for deciding whether a request should go to v1 or v1beta1.

    Intended to return True when a beta feature is enabled. The current
    implementation does not inspect ``optional_params`` and always returns
    False.
    """
    uses_beta_features = False
    return uses_beta_features
|
||||
|
||||
def _check_custom_proxy(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
custom_llm_provider: str,
|
||||
gemini_api_key: Optional[str],
|
||||
endpoint: str,
|
||||
stream: Optional[bool],
|
||||
auth_header: Optional[str],
|
||||
url: str,
|
||||
model: Optional[str] = None,
|
||||
vertex_project: Optional[str] = None,
|
||||
vertex_location: Optional[str] = None,
|
||||
vertex_api_version: Optional[Literal["v1", "v1beta1"]] = None,
|
||||
use_psc_endpoint_format: bool = False,
|
||||
) -> Tuple[Optional[str], str]:
|
||||
"""
|
||||
for cloudflare ai gateway - https://github.com/BerriAI/litellm/issues/4317
|
||||
|
||||
Handles custom api_base for:
|
||||
1. Gemini (Google AI Studio) - constructs /models/{model}:{endpoint}
|
||||
2. Vertex AI with standard proxies - constructs {api_base}:{endpoint}
|
||||
3. Vertex AI with PSC endpoints - constructs full path structure
|
||||
{api_base}/v1/projects/{project}/locations/{location}/endpoints/{model}:{endpoint}
|
||||
(only when use_psc_endpoint_format=True)
|
||||
|
||||
Args:
|
||||
use_psc_endpoint_format: If True, constructs PSC endpoint URL format.
|
||||
If False (default), uses api_base as-is and appends :{endpoint}
|
||||
|
||||
## Returns
|
||||
- (auth_header, url) - Tuple[Optional[str], str]
|
||||
"""
|
||||
if api_base:
|
||||
if custom_llm_provider == "gemini":
|
||||
# For Gemini (Google AI Studio), construct the full path like other providers
|
||||
if model is None:
|
||||
raise ValueError(
|
||||
"Model parameter is required for Gemini custom API base URLs"
|
||||
)
|
||||
url = "{}/models/{}:{}".format(api_base, model, endpoint)
|
||||
if gemini_api_key is None:
|
||||
raise ValueError(
|
||||
"Missing gemini_api_key, please set `GEMINI_API_KEY`"
|
||||
)
|
||||
if gemini_api_key is not None:
|
||||
auth_header = {"x-goog-api-key": gemini_api_key} # type: ignore[assignment]
|
||||
else:
|
||||
# For Vertex AI
|
||||
if use_psc_endpoint_format:
|
||||
# User explicitly specified PSC endpoint format
|
||||
# Construct full PSC/custom endpoint URL
|
||||
if not (vertex_project and vertex_location and model):
|
||||
raise ValueError(
|
||||
"vertex_project, vertex_location, and model are required when use_psc_endpoint_format=True"
|
||||
)
|
||||
# Strip routing prefixes (bge/, gemma/, etc.) for endpoint URL construction
|
||||
model_for_url = get_vertex_base_model_name(model=model)
|
||||
# Format: {api_base}/v1/projects/{project}/locations/{location}/endpoints/{model}:{endpoint}
|
||||
version = vertex_api_version or "v1"
|
||||
url = "{}/{}/projects/{}/locations/{}/endpoints/{}:{}".format(
|
||||
api_base.rstrip("/"),
|
||||
version,
|
||||
vertex_project,
|
||||
vertex_location,
|
||||
model_for_url,
|
||||
endpoint,
|
||||
)
|
||||
else:
|
||||
# Fallback to simple format if we don't have all parameters
|
||||
url = "{}:{}".format(api_base, endpoint)
|
||||
if stream is True:
|
||||
url = url + "?alt=sse"
|
||||
return auth_header, url
|
||||
|
||||
def _get_token_and_url(
    self,
    model: str,
    auth_header: Optional[str],
    gemini_api_key: Optional[str],
    vertex_project: Optional[str],
    vertex_location: Optional[str],
    vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
    stream: Optional[bool],
    custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
    api_base: Optional[str],
    should_use_v1beta1_features: Optional[bool] = False,
    mode: all_gemini_url_modes = "chat",
    use_psc_endpoint_format: bool = False,
) -> Tuple[Optional[str], str]:
    """
    Internal function. Returns the token and url for the call.

    Builds the default URL for Google AI Studio (``"gemini"``) or for Vertex AI
    (any other provider), then lets ``_check_custom_proxy`` apply a custom
    ``api_base`` if one is set.

    Returns:
        (token, url)
    """
    api_version: Optional[Literal["v1beta1", "v1"]] = None

    if custom_llm_provider != "gemini":
        vertex_location = self.get_vertex_region(
            vertex_region=vertex_location,
            model=model,
        )

        ### SET RUNTIME ENDPOINT ###
        api_version = "v1beta1" if should_use_v1beta1_features is True else "v1"
        url, endpoint = _get_vertex_url(
            mode=mode,
            model=model,
            stream=stream,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_api_version=api_version,
        )
    else:
        url, endpoint = _get_gemini_url(
            mode=mode,
            model=model,
            stream=stream,
            gemini_api_key=gemini_api_key,
        )
        auth_header = None  # this field is not used for gemini

    return self._check_custom_proxy(
        api_base=api_base,
        auth_header=auth_header,
        custom_llm_provider=custom_llm_provider,
        gemini_api_key=gemini_api_key,
        endpoint=endpoint,
        stream=stream,
        url=url,
        model=model,
        vertex_project=vertex_project,
        vertex_location=vertex_location,
        vertex_api_version=api_version,
        use_psc_endpoint_format=use_psc_endpoint_format,
    )
|
||||
|
||||
def _handle_reauthentication(
    self,
    credentials: Optional[VERTEX_CREDENTIALS_TYPES],
    project_id: Optional[str],
    credential_cache_key: Tuple,
    error: Exception,
) -> Tuple[str, str]:
    """
    Drop the cached credentials for ``credential_cache_key`` and retry
    ``get_access_token`` exactly once.

    Called by ``get_access_token`` when a refresh fails with
    "Reauthentication is needed".

    Args:
        credentials: The original credentials
        project_id: The project ID
        credential_cache_key: The cache key to clear
        error: The original error that triggered reauthentication

    Returns:
        Tuple of (access_token, project_id)

    Raises:
        The original ``error`` if the retry fails too.
    """
    verbose_logger.debug(
        f"Handling reauthentication for project_id: {project_id}. "
        f"Clearing cache and retrying once."
    )

    # Forget the stale entry so the retry loads credentials from scratch.
    self._credentials_project_mapping.pop(credential_cache_key, None)

    # _retry_reauth=True stops the retry from recursing into this handler again.
    try:
        refreshed = self.get_access_token(
            credentials=credentials,
            project_id=project_id,
            _retry_reauth=True,
        )
    except Exception as retry_error:
        verbose_logger.error(
            f"Reauthentication retry failed for project_id: {project_id}. "
            f"Original error: {str(error)}. Retry error: {str(retry_error)}"
        )
        # Re-raise the original error for better context
        raise error
    return refreshed
|
||||
|
||||
def get_access_token(
    self,
    credentials: Optional[VERTEX_CREDENTIALS_TYPES],
    project_id: Optional[str],
    _retry_reauth: bool = False,
) -> Tuple[str, str]:
    """
    Get access token and project id

    Steps:
    1. Look the credentials up in ``self._credentials_project_mapping``.
    2. On a miss, load them with ``load_auth`` and cache the result.
    3. If no project id was requested, adopt the one tied to the credentials.
    4. Refresh the credentials when they report ``expired``.
    5. Return the token and project id.

    Args:
        credentials: The credentials to use for authentication
        project_id: The Google Cloud project ID
        _retry_reauth: Internal flag to prevent infinite recursion during reauthentication

    Returns:
        Tuple of (access_token, project_id)

    Raises:
        ValueError: Credentials, token, or project id could not be resolved.
    """
    # Convert dict credentials to string for caching
    hashable_credentials = credentials
    if isinstance(credentials, dict):
        hashable_credentials = json.dumps(credentials)
    credential_cache_key = (hashable_credentials, project_id)
    creds: Optional[GoogleCredentialsObject] = None

    verbose_logger.debug(
        f"Checking cached credentials for project_id: {project_id}"
    )

    if credential_cache_key not in self._credentials_project_mapping:
        verbose_logger.debug(
            f"Credential cache key not found for project_id: {project_id}, loading new credentials"
        )

        try:
            creds, creds_project_id = self.load_auth(
                credentials=credentials, project_id=project_id
            )
        except Exception as e:
            verbose_logger.exception(
                f"Failed to load vertex credentials. Check to see if credentials containing partial/invalid information. Error: {str(e)}"
            )
            raise e

        if creds is None:
            raise ValueError(
                "Could not resolve credentials - either dynamically or from environment, for project_id: {}".format(
                    project_id
                )
            )
        # Cache the project_id and credentials from load_auth result (resolved project_id)
        self._credentials_project_mapping[credential_cache_key] = (
            creds,
            creds_project_id,
        )
    else:
        verbose_logger.debug(
            f"Cached credentials found for project_id: {project_id}."
        )
        # Retrieve both credentials and cached project_id
        cached_entry = self._credentials_project_mapping[credential_cache_key]
        verbose_logger.debug("cached_entry: %s", cached_entry)
        if isinstance(cached_entry, tuple):
            creds, creds_project_id = cached_entry
        else:
            # Backward compatibility with old cache format
            creds = cached_entry
            creds_project_id = creds.quota_project_id or getattr(
                creds, "project_id", None
            )
        verbose_logger.debug(
            "Using cached credentials for project_id: %s",
            creds_project_id,
        )

    ## VALIDATE CREDENTIALS
    verbose_logger.debug(f"Validating credentials for project_id: {project_id}")
    if project_id is None and isinstance(creds_project_id, str):
        project_id = creds_project_id
        # Also cache under the resolved project id for future lookups,
        # without overwriting an entry that is already there.
        self._credentials_project_mapping.setdefault(
            (hashable_credentials, project_id),
            (creds, creds_project_id),
        )

    # Check if credentials are None before accessing attributes
    if creds is None:
        raise ValueError("Credentials are None after loading")

    if creds.expired:
        try:
            verbose_logger.debug(
                f"Credentials expired, refreshing for project_id: {project_id}"
            )
            self.refresh_auth(creds)
            self._credentials_project_mapping[credential_cache_key] = (
                creds,
                creds_project_id,
            )
        except Exception as e:
            # if refresh fails, it's possible the user has re-authenticated via `gcloud auth application-default login`
            # in this case, we should try to reload the credentials by clearing the cache and retrying
            should_retry = (
                "Reauthentication is needed" in str(e) and not _retry_reauth
            )
            if should_retry:
                return self._handle_reauthentication(
                    credentials=credentials,
                    project_id=project_id,
                    credential_cache_key=credential_cache_key,
                    error=e,
                )
            raise e

    ## VALIDATION STEP
    if not isinstance(creds.token, str):
        raise ValueError(
            "Could not resolve credentials token. Got None or non-string token - {}".format(
                creds.token
            )
        )

    if project_id is None:
        raise ValueError("Could not resolve project_id")

    return creds.token, project_id
|
||||
|
||||
async def _ensure_access_token_async(
    self,
    credentials: Optional[VERTEX_CREDENTIALS_TYPES],
    project_id: Optional[str],
    custom_llm_provider: Literal[
        "vertex_ai", "vertex_ai_beta", "gemini"
    ],  # if it's vertex_ai or gemini (google ai studio)
) -> Tuple[str, str]:
    """
    Async version of _ensure_access_token

    For ``"gemini"`` both values are empty strings. Otherwise
    ``get_access_token`` is wrapped with ``asyncify`` and awaited; any
    exception it raises propagates unchanged.
    """
    if custom_llm_provider == "gemini":
        return "", ""

    awaitable_get_access_token = asyncify(self.get_access_token)
    return await awaitable_get_access_token(
        credentials=credentials,
        project_id=project_id,
    )
|
||||
|
||||
def set_headers(
    self, auth_header: Optional[str], extra_headers: Optional[dict]
) -> dict:
    """
    Build the request headers.

    Always sets ``Content-Type: application/json``; adds a Bearer
    ``Authorization`` header when ``auth_header`` is given; finally applies
    ``extra_headers``, which may override either of the above.
    """
    request_headers = {"Content-Type": "application/json"}

    if auth_header is not None:
        request_headers["Authorization"] = f"Bearer {auth_header}"

    # Applied last so caller-supplied headers win on key clashes.
    if extra_headers is not None:
        request_headers.update(extra_headers)

    return request_headers
|
||||
|
||||
@staticmethod
def get_vertex_ai_project(litellm_params: dict) -> Optional[str]:
    """
    Resolve the Vertex AI project, removing the keys it reads from
    ``litellm_params``.

    Lookup order: ``vertex_project``, ``vertex_ai_project``,
    ``litellm.vertex_project``, then ``get_secret_str("VERTEXAI_PROJECT")``.
    A key is popped only if every earlier source was empty.
    """
    for param_name in ("vertex_project", "vertex_ai_project"):
        popped_value = litellm_params.pop(param_name, None)
        if popped_value:
            return popped_value
    return litellm.vertex_project or get_secret_str("VERTEXAI_PROJECT")
|
||||
|
||||
@staticmethod
def get_vertex_ai_credentials(litellm_params: dict) -> Optional[str]:
    """
    Resolve the Vertex AI credentials, removing the keys it reads from
    ``litellm_params``.

    Lookup order: ``vertex_credentials``, ``vertex_ai_credentials``, then
    ``get_secret_str("VERTEXAI_CREDENTIALS")``. A key is popped only if every
    earlier source was empty.
    """
    for param_name in ("vertex_credentials", "vertex_ai_credentials"):
        popped_value = litellm_params.pop(param_name, None)
        if popped_value:
            return popped_value
    return get_secret_str("VERTEXAI_CREDENTIALS")
|
||||
|
||||
@staticmethod
def get_vertex_ai_location(litellm_params: dict) -> Optional[str]:
    """
    Resolve the Vertex AI location, removing the keys it reads from
    ``litellm_params``.

    Lookup order: ``vertex_location``, ``vertex_ai_location``,
    ``litellm.vertex_location``, then ``get_secret_str`` for
    ``VERTEXAI_LOCATION`` and ``VERTEX_LOCATION``. A key is popped only if
    every earlier source was empty.
    """
    for param_name in ("vertex_location", "vertex_ai_location"):
        popped_value = litellm_params.pop(param_name, None)
        if popped_value:
            return popped_value
    return (
        litellm.vertex_location
        or get_secret_str("VERTEXAI_LOCATION")
        or get_secret_str("VERTEX_LOCATION")
    )
|
||||
|
||||
@staticmethod
def safe_get_vertex_ai_project(litellm_params: dict) -> Optional[str]:
    """
    Safely get Vertex AI project without mutating the litellm_params dict.

    Same lookup order as get_vertex_ai_project(), but values are read with
    ``.get`` instead of ``.pop``, so the dict can be reused across calls.

    Args:
        litellm_params: Dictionary containing Vertex AI parameters

    Returns:
        Vertex AI project ID or None
    """
    for param_name in ("vertex_project", "vertex_ai_project"):
        found_value = litellm_params.get(param_name)
        if found_value:
            return found_value
    return litellm.vertex_project or get_secret_str("VERTEXAI_PROJECT")
|
||||
|
||||
@staticmethod
def safe_get_vertex_ai_credentials(litellm_params: dict) -> Optional[str]:
    """
    Safely get Vertex AI credentials without mutating the litellm_params dict.

    Same lookup order as get_vertex_ai_credentials(), but values are read with
    ``.get`` instead of ``.pop``, so the dict can be reused across calls.

    Args:
        litellm_params: Dictionary containing Vertex AI parameters

    Returns:
        Vertex AI credentials or None
    """
    for param_name in ("vertex_credentials", "vertex_ai_credentials"):
        found_value = litellm_params.get(param_name)
        if found_value:
            return found_value
    return get_secret_str("VERTEXAI_CREDENTIALS")
|
||||
|
||||
@staticmethod
def safe_get_vertex_ai_location(litellm_params: dict) -> Optional[str]:
    """
    Safely get Vertex AI location without mutating the litellm_params dict.

    Same lookup order as get_vertex_ai_location(), but values are read with
    ``.get`` instead of ``.pop``, so the dict can be reused across calls.

    Args:
        litellm_params: Dictionary containing Vertex AI parameters

    Returns:
        Vertex AI location/region or None
    """
    for param_name in ("vertex_location", "vertex_ai_location"):
        found_value = litellm_params.get(param_name)
        if found_value:
            return found_value
    return (
        litellm.vertex_location
        or get_secret_str("VERTEXAI_LOCATION")
        or get_secret_str("VERTEX_LOCATION")
    )
|
||||
@@ -0,0 +1,153 @@
|
||||
"""
|
||||
API Handler for calling Vertex AI Model Garden Models
|
||||
|
||||
Most Vertex Model Garden Models are OpenAI compatible - so this handler calls `openai_like_chat_completions`
|
||||
|
||||
Usage:
|
||||
|
||||
response = litellm.completion(
|
||||
model="vertex_ai/openai/5464397967697903616",
|
||||
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
||||
)
|
||||
|
||||
Sent to this route when `model` is in the format `vertex_ai/openai/{MODEL_ID}`
|
||||
|
||||
|
||||
Vertex Documentation for using the OpenAI /chat/completions endpoint: https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_pytorch_llama3_deployment.ipynb
|
||||
"""
|
||||
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import httpx # type: ignore
|
||||
|
||||
from litellm.llms.vertex_ai.common_utils import get_vertex_base_url
|
||||
from litellm.utils import ModelResponse
|
||||
|
||||
from ..common_utils import VertexAIError, get_vertex_base_model_name
|
||||
from ..vertex_llm_base import VertexBase
|
||||
|
||||
|
||||
def create_vertex_url(
    vertex_location: str,
    vertex_project: str,
    stream: Optional[bool],
    model: str,
    api_base: Optional[str] = None,
) -> str:
    """
    Return the base url for the vertex garden models.

    The URL is ``{regional base}/v1beta1/projects/{project}/locations/{location}/endpoints/{model}``,
    where the regional base comes from ``get_vertex_base_url``. ``stream`` and
    ``api_base`` are accepted but not used when building it.
    """
    endpoint_path = (
        f"projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}"
    )
    return f"{get_vertex_base_url(vertex_location)}/v1beta1/{endpoint_path}"
|
||||
|
||||
|
||||
class VertexAIModelGardenModels(VertexBase):
|
||||
def __init__(self) -> None:
    """
    Initialise the state inherited from ``VertexBase``.

    Without the base initialiser, instance attributes such as the credential
    cache are never created, so inherited methods that read them would fail
    with ``AttributeError``.
    """
    super().__init__()
|
||||
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
logging_obj,
|
||||
api_base: Optional[str],
|
||||
optional_params: dict,
|
||||
custom_prompt_dict: dict,
|
||||
headers: Optional[dict],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
litellm_params: dict,
|
||||
vertex_project=None,
|
||||
vertex_location=None,
|
||||
vertex_credentials=None,
|
||||
logger_fn=None,
|
||||
acompletion: bool = False,
|
||||
client=None,
|
||||
):
|
||||
"""
|
||||
Handles calling Vertex AI Model Garden Models in OpenAI compatible format
|
||||
|
||||
Sent to this route when `model` is in the format `vertex_ai/openai/{MODEL_ID}`
|
||||
"""
|
||||
try:
|
||||
import vertexai
|
||||
|
||||
from litellm.llms.openai_like.chat.handler import OpenAILikeChatHandler
|
||||
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
|
||||
VertexLLM,
|
||||
)
|
||||
except Exception as e:
|
||||
raise VertexAIError(
|
||||
status_code=400,
|
||||
message=f"""vertexai import failed please run `pip install -U "google-cloud-aiplatform>=1.38"`. Got error: {e}""",
|
||||
)
|
||||
|
||||
if not (
|
||||
hasattr(vertexai, "preview") or hasattr(vertexai.preview, "language_models")
|
||||
):
|
||||
raise VertexAIError(
|
||||
status_code=400,
|
||||
message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""",
|
||||
)
|
||||
try:
|
||||
model = get_vertex_base_model_name(model=model)
|
||||
vertex_httpx_logic = VertexLLM()
|
||||
|
||||
access_token, project_id = vertex_httpx_logic._ensure_access_token(
|
||||
credentials=vertex_credentials,
|
||||
project_id=vertex_project,
|
||||
custom_llm_provider="vertex_ai",
|
||||
)
|
||||
|
||||
openai_like_chat_completions = OpenAILikeChatHandler()
|
||||
|
||||
## CONSTRUCT API BASE
|
||||
stream: bool = optional_params.get("stream", False) or False
|
||||
optional_params["stream"] = stream
|
||||
default_api_base = create_vertex_url(
|
||||
vertex_location=vertex_location or "us-central1",
|
||||
vertex_project=vertex_project or project_id,
|
||||
stream=stream,
|
||||
model=model,
|
||||
)
|
||||
|
||||
if len(default_api_base.split(":")) > 1:
|
||||
endpoint = default_api_base.split(":")[-1]
|
||||
else:
|
||||
endpoint = ""
|
||||
|
||||
_, api_base = self._check_custom_proxy(
|
||||
api_base=api_base,
|
||||
custom_llm_provider="vertex_ai",
|
||||
gemini_api_key=None,
|
||||
endpoint=endpoint,
|
||||
stream=stream,
|
||||
auth_header=None,
|
||||
url=default_api_base,
|
||||
model=model,
|
||||
vertex_project=vertex_project or project_id,
|
||||
vertex_location=vertex_location or "us-central1",
|
||||
vertex_api_version="v1beta1",
|
||||
)
|
||||
model = ""
|
||||
return openai_like_chat_completions.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
api_base=api_base,
|
||||
api_key=access_token,
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
acompletion=acompletion,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
client=client,
|
||||
timeout=timeout,
|
||||
encoding=encoding,
|
||||
custom_llm_provider="vertex_ai",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise VertexAIError(status_code=500, message=str(e))
|
||||
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
Vertex AI Video Generation Module
|
||||
|
||||
This module provides support for Vertex AI's Veo video generation API.
|
||||
"""
|
||||
|
||||
from .transformation import VertexAIVideoConfig
|
||||
|
||||
__all__ = ["VertexAIVideoConfig"]
|
||||
@@ -0,0 +1,636 @@
|
||||
"""
|
||||
Vertex AI Video Generation Transformation
|
||||
|
||||
Handles transformation of requests/responses for Vertex AI's Veo video generation API.
|
||||
Based on: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/model-reference/veo-video-generation
|
||||
"""
|
||||
|
||||
import base64
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union, cast
|
||||
|
||||
import httpx
|
||||
from httpx._types import RequestFiles
|
||||
|
||||
from litellm.constants import DEFAULT_GOOGLE_VIDEO_DURATION_SECONDS
|
||||
from litellm.images.utils import ImageEditRequestUtils
|
||||
from litellm.llms.base_llm.videos.transformation import BaseVideoConfig
|
||||
from litellm.llms.vertex_ai.common_utils import (
|
||||
_convert_vertex_datetime_to_openai_datetime,
|
||||
get_vertex_base_url,
|
||||
)
|
||||
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.videos.main import VideoCreateOptionalRequestParams, VideoObject
|
||||
from litellm.types.videos.utils import (
|
||||
encode_video_id_with_provider,
|
||||
extract_original_video_id,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
from litellm.llms.base_llm.chat.transformation import (
|
||||
BaseLLMException as _BaseLLMException,
|
||||
)
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
BaseLLMException = _BaseLLMException
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
BaseLLMException = Any
|
||||
|
||||
|
||||
def _convert_image_to_vertex_format(image_file) -> Dict[str, str]:
    """
    Encode an image file as an inline image payload for Vertex AI requests.

    Args:
        image_file: File-like object opened in binary mode (e.g., open("path", "rb"))

    Returns:
        Dict holding the base64 text under "bytesBase64Encoded" and the
        detected content type under "mimeType".
    """
    detected_type = ImageEditRequestUtils.get_image_content_type(image_file)

    # Rewind seekable inputs before reading so the whole content is encoded.
    if hasattr(image_file, "seek"):
        image_file.seek(0)

    encoded_text = base64.b64encode(image_file.read()).decode("utf-8")
    return {"bytesBase64Encoded": encoded_text, "mimeType": detected_type}
|
||||
|
||||
|
||||
class VertexAIVideoConfig(BaseVideoConfig, VertexBase):
    """
    Configuration class for Vertex AI (Veo) video generation.

    Veo uses a long-running operation model:
    1. POST to :predictLongRunning returns operation name
    2. Poll operation using :fetchPredictOperation until done=true
    3. Extract video data (base64) from response
    """

    def __init__(self):
        BaseVideoConfig.__init__(self)
        VertexBase.__init__(self)

    @staticmethod
    def extract_model_from_operation_name(operation_name: str) -> Optional[str]:
        """
        Extract the model name from a Vertex AI operation name.

        Args:
            operation_name: Operation name in format:
                projects/PROJECT/locations/LOCATION/publishers/google/models/MODEL/operations/OPERATION_ID

        Returns:
            Model name (e.g., "veo-2.0-generate-001") or None if extraction fails
        """
        parts = operation_name.split("/")
        # Model is at index 7 in the operation name format
        if len(parts) >= 8:
            return parts[7]
        return None

    def get_supported_openai_params(self, model: str) -> list:
        """
        Get the list of supported OpenAI parameters for Veo video generation.
        Veo supports minimal parameters compared to OpenAI.
        """
        return ["model", "prompt", "input_reference", "seconds", "size"]

    def map_openai_params(
        self,
        video_create_optional_params: VideoCreateOptionalRequestParams,
        model: str,
        drop_params: bool,
    ) -> Dict[str, Any]:
        """
        Map OpenAI-style parameters to Veo format.

        Mappings:
        - input_reference (or image) → image (processed later in
          transform_video_create_request)
        - parameters → parameters (provider-specific block, passed through)
        - size → aspectRatio (e.g., "1280x720" → "16:9")
        - seconds → durationSeconds (string values are converted with int();
          the key is omitted when the value is missing, None, or not
          convertible - no default duration is set here)

        `model` and `drop_params` are not used by this mapping.
        """
        mapped_params: Dict[str, Any] = {}

        # Map input_reference to image (will be processed in transform_video_create_request)
        if "input_reference" in video_create_optional_params:
            mapped_params["image"] = video_create_optional_params["input_reference"]
        elif "image" in video_create_optional_params:
            mapped_params["image"] = video_create_optional_params["image"]

        # Pass through a provider-specific parameters block if provided directly
        if "parameters" in video_create_optional_params:
            mapped_params["parameters"] = video_create_optional_params["parameters"]

        # Map size to aspectRatio
        if "size" in video_create_optional_params:
            size = video_create_optional_params["size"]
            if size is not None:
                aspect_ratio = self._convert_size_to_aspect_ratio(size)
                if aspect_ratio:
                    mapped_params["aspectRatio"] = aspect_ratio

        # Map seconds to durationSeconds
        if "seconds" in video_create_optional_params:
            seconds = video_create_optional_params["seconds"]
            try:
                duration = int(seconds) if isinstance(seconds, str) else seconds
                if duration is not None:
                    mapped_params["durationSeconds"] = duration
            except (ValueError, TypeError):
                # Best effort: an unparseable value is dropped rather than
                # failing the request.
                pass

        return mapped_params

    def _convert_size_to_aspect_ratio(self, size: str) -> Optional[str]:
        """
        Convert OpenAI size format to Veo aspectRatio format.

        Supported aspect ratios: 9:16 (portrait), 16:9 (landscape)

        Known sizes are looked up directly. Any other "WIDTHxHEIGHT" value is
        classified by orientation, so a portrait size such as "1024x1792" maps
        to "9:16" instead of silently becoming landscape. Values that cannot be
        parsed fall back to "16:9".

        Returns:
            "16:9" or "9:16", or None when `size` is empty.
        """
        if not size:
            return None

        aspect_ratio_map = {
            "1280x720": "16:9",
            "1920x1080": "16:9",
            "720x1280": "9:16",
            "1080x1920": "9:16",
        }

        known_ratio = aspect_ratio_map.get(size)
        if known_ratio is not None:
            return known_ratio

        try:
            width_text, height_text = str(size).lower().split("x")
            if int(height_text) > int(width_text):
                return "9:16"
        except ValueError:
            # Not a WIDTHxHEIGHT string - use the landscape default below.
            pass

        return "16:9"

    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
        litellm_params: Optional[Union[GenericLiteLLMParams, dict]] = None,
    ) -> dict:
        """
        Validate environment and return headers for Vertex AI video generation.

        Vertex AI uses Bearer token authentication with access token from credentials.

        The returned dict starts from the Authorization and Content-Type
        defaults; entries in the incoming `headers` override them.
        `model` and `api_key` are not used.
        """
        # Extract Vertex AI parameters using safe helpers from VertexBase
        # Use safe_get_* methods that don't mutate litellm_params dict
        # Ensure litellm_params is a dict for type checking
        params_dict: Dict[str, Any] = (
            cast(Dict[str, Any], litellm_params) if litellm_params is not None else {}
        )

        vertex_project = VertexBase.safe_get_vertex_ai_project(
            litellm_params=params_dict
        )
        vertex_credentials = VertexBase.safe_get_vertex_ai_credentials(
            litellm_params=params_dict
        )

        # Get access token from Vertex credentials (the resolved project id
        # returned alongside it is not needed here)
        access_token, _ = self.get_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
        )

        headers = {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
            **headers,
        }

        return headers

    def get_complete_url(
        self,
        model: str,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        Get the model URL for Veo video generation.

        Returns the model resource URL, without an action suffix:
        https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/publishers/google/models/MODEL

        transform_video_create_request appends ":predictLongRunning" to it.

        Raises:
            ValueError: if no Vertex project can be resolved from litellm_params.
        """
        vertex_project = VertexBase.safe_get_vertex_ai_project(litellm_params)
        vertex_location = VertexBase.safe_get_vertex_ai_location(litellm_params)

        if not vertex_project:
            raise ValueError(
                "vertex_project is required for Vertex AI video generation. "
                "Set it via environment variable VERTEXAI_PROJECT or pass as parameter."
            )

        # Default to us-central1 if no location specified
        vertex_location = vertex_location or "us-central1"

        # Extract model name (remove vertex_ai/ prefix if present)
        model_name = model.replace("vertex_ai/", "")

        # Construct the URL
        if api_base:
            base_url = api_base.rstrip("/")
        else:
            base_url = get_vertex_base_url(vertex_location)

        url = f"{base_url}/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model_name}"

        return url

    def transform_video_create_request(
        self,
        model: str,
        prompt: str,
        api_base: str,
        video_create_optional_request_params: Dict,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[Dict, RequestFiles, str]:
        """
        Transform the video creation request for Veo API.

        Veo expects:
        {
            "instances": [
                {
                    "prompt": "A cat playing with a ball of yarn",
                    "image": {
                        "bytesBase64Encoded": "...",
                        "mimeType": "image/jpeg"
                    }
                }
            ],
            "parameters": {
                "aspectRatio": "16:9",
                "durationSeconds": 8
            }
        }

        Returns:
            Tuple of (JSON request body, empty file list, request URL ending
            in ":predictLongRunning").

        Raises:
            ValueError: if `image` is a string that is not a gs:// URI.
        """
        # Build instance with prompt
        instance_dict: Dict[str, Any] = {"prompt": prompt}
        # Shallow copy so the caller's dict is not modified by the pops below
        params_copy = video_create_optional_request_params.copy()

        # Check if user wants to provide full instance dict
        if "instances" in params_copy and isinstance(params_copy["instances"], dict):
            # Replace/merge with user-provided instance
            instance_dict.update(params_copy["instances"])
            params_copy.pop("instances")
        elif "image" in params_copy and params_copy["image"] is not None:
            image = params_copy["image"]
            if isinstance(image, dict):
                # Already in Vertex format e.g. {"gcsUri": "gs://..."} or
                # {"bytesBase64Encoded": "...", "mimeType": "..."}
                image_data = image
            elif isinstance(image, str) and image.startswith("gs://"):
                # Bare GCS URI — Vertex AI accepts gcsUri natively, no download needed
                image_data = {"gcsUri": image}
            elif isinstance(image, str):
                raise ValueError(
                    f"Unsupported image value '{image}'. "
                    "Provide a GCS URI (gs://...), a dict with 'gcsUri' or "
                    "'bytesBase64Encoded'/'mimeType', or a binary file-like object."
                )
            else:
                # File-like object — encode to base64
                image_data = _convert_image_to_vertex_format(image)
            instance_dict["image"] = image_data
            params_copy.pop("image")

        # Extract a nested "parameters" block that map_openai_params may have placed
        # inside params_copy (e.g. from provider-specific pass-through). Merging it
        # flat prevents the double-nesting bug:
        #   {"parameters": {"parameters": {...}}}  ← wrong
        #   {"parameters": {...}}                   ← correct
        nested_params = params_copy.pop("parameters", None)
        vertex_params: Dict[str, Any] = {}
        if isinstance(nested_params, dict):
            vertex_params.update(nested_params)
        # Remaining top-level keys are applied last, so they win over the
        # nested block on conflict.
        vertex_params.update(params_copy)

        # Build request data directly (TypedDict doesn't have model_dump)
        request_data: Dict[str, Any] = {"instances": [instance_dict]}

        # Only add parameters if there are any
        if vertex_params:
            request_data["parameters"] = vertex_params

        # Append :predictLongRunning endpoint to api_base
        url = f"{api_base}:predictLongRunning"

        # No files needed - everything is in JSON
        return request_data, [], url

    def transform_video_create_response(
        self,
        model: str,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        custom_llm_provider: Optional[str] = None,
        request_data: Optional[Dict] = None,
    ) -> VideoObject:
        """
        Transform the Veo video creation response.

        Veo returns:
        {
            "name": "projects/PROJECT_ID/locations/LOCATION/publishers/google/models/MODEL/operations/OPERATION_ID"
        }

        We return this as a VideoObject with:
        - id: operation name (used for polling), encoded with the provider
          and model when custom_llm_provider is given
        - status: "processing"
        - usage: includes duration_seconds for cost calculation (only filled
          in when request_data is provided)

        Raises:
            ValueError: if the response carries no operation name.
        """
        response_data = raw_response.json()

        operation_name = response_data.get("name")
        if not operation_name:
            raise ValueError(f"No operation name in Veo response: {response_data}")

        if custom_llm_provider:
            video_id = encode_video_id_with_provider(
                operation_name, custom_llm_provider, model
            )
        else:
            video_id = operation_name

        video_obj = VideoObject(
            id=video_id, object="video", status="processing", model=model
        )

        usage_data: Dict[str, Any] = {}
        if request_data:
            parameters = request_data.get("parameters", {})
            # Fall back to the configured default when the request did not
            # specify a duration.
            duration = (
                parameters.get("durationSeconds")
                or DEFAULT_GOOGLE_VIDEO_DURATION_SECONDS
            )
            if duration is not None:
                try:
                    usage_data["duration_seconds"] = float(duration)
                except (ValueError, TypeError):
                    # Best effort: leave usage empty if the value is not numeric.
                    pass

        video_obj.usage = usage_data
        return video_obj

    def transform_video_status_retrieve_request(
        self,
        video_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """
        Transform the video status retrieve request for Veo API.

        Veo polls operations using :fetchPredictOperation endpoint with POST request.

        Returns:
            Tuple of (request URL, request body holding the operation name).

        Raises:
            ValueError: if no model can be extracted from the operation name.
        """
        operation_name = extract_original_video_id(video_id)
        model = self.extract_model_from_operation_name(operation_name)

        if not model:
            raise ValueError(
                f"Invalid operation name format: {operation_name}. "
                "Expected format: projects/PROJECT/locations/LOCATION/publishers/google/models/MODEL/operations/OPERATION_ID"
            )

        # Construct the full URL including model ID
        # URL format: https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/publishers/google/models/MODEL:fetchPredictOperation
        # Strip trailing slashes from api_base and append model
        url = f"{api_base.rstrip('/')}/{model}:fetchPredictOperation"

        # Request body contains the operation name
        params = {"operationName": operation_name}

        return url, params

    def transform_video_status_retrieve_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        custom_llm_provider: Optional[str] = None,
    ) -> VideoObject:
        """
        Transform the Veo operation status response.

        Veo returns:
        {
            "name": "projects/.../operations/OPERATION_ID",
            "done": false  # or true when complete
        }

        When done=true:
        {
            "name": "projects/.../operations/OPERATION_ID",
            "done": true,
            "response": {
                "@type": "type.googleapis.com/cloud.ai.large_models.vision.GenerateVideoResponse",
                "raiMediaFilteredCount": 0,
                "videos": [
                    {
                        "bytesBase64Encoded": "...",
                        "mimeType": "video/mp4"
                    }
                ]
            }
        }

        Status mapping: an "error" entry → "failed"; otherwise done=true →
        "completed"; otherwise "processing".
        """
        response_data = raw_response.json()

        operation_name = response_data.get("name", "")
        is_done = response_data.get("done", False)
        error_data = response_data.get("error")

        # Extract model from operation name
        model = self.extract_model_from_operation_name(operation_name)

        if custom_llm_provider:
            video_id = encode_video_id_with_provider(
                operation_name, custom_llm_provider, model
            )
        else:
            video_id = operation_name

        # Convert createTime to Unix timestamp; fall back to the current time
        # when it is missing or cannot be converted
        create_time_str = response_data.get("metadata", {}).get("createTime")
        if create_time_str:
            try:
                created_at = _convert_vertex_datetime_to_openai_datetime(
                    create_time_str
                )
            except Exception:
                created_at = int(time.time())
        else:
            created_at = int(time.time())

        if error_data:
            status = "failed"
        elif is_done:
            status = "completed"
        else:
            status = "processing"

        video_obj = VideoObject(
            id=video_id,
            object="video",
            status=status,
            model=model,
            created_at=created_at,
            error=error_data,
        )
        return video_obj

    def transform_video_content_request(
        self,
        video_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
        variant: Optional[str] = None,
    ) -> Tuple[str, Dict]:
        """
        Transform the video content request for Veo API.

        The video bytes are carried inside the completed operation, so this
        builds the same :fetchPredictOperation request as status retrieval.
        `variant` is not used.
        """
        return self.transform_video_status_retrieve_request(
            video_id, api_base, litellm_params, headers
        )

    def transform_video_content_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> bytes:
        """
        Transform the Veo video content download response.

        Extracts the base64 encoded video from the response and decodes it to bytes.
        Only the first entry of "videos" is returned.

        Raises:
            ValueError: if the operation reported an error, is not done yet,
                or carries no base64 video data.
        """
        response_data = raw_response.json()

        # Surface a failed operation explicitly instead of reporting it as
        # "not complete" or "no video data".
        operation_error = response_data.get("error")
        if operation_error:
            raise ValueError(f"Video generation failed: {operation_error}")

        if not response_data.get("done", False):
            raise ValueError(
                "Video generation is not complete yet. "
                "Please check status with video_status() before downloading."
            )

        try:
            video_response = response_data.get("response", {})
            videos = video_response.get("videos", [])

            if not videos:
                raise ValueError("No video data found in completed operation")

            # Get the first video
            video_data = videos[0]
            base64_encoded = video_data.get("bytesBase64Encoded")

            if not base64_encoded:
                raise ValueError("No base64 encoded video data found")

            # Decode base64 to bytes
            video_bytes = base64.b64decode(base64_encoded)
            return video_bytes

        except (KeyError, IndexError) as e:
            raise ValueError(f"Failed to extract video data: {e}") from e

    def transform_video_remix_request(
        self,
        video_id: str,
        prompt: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
        extra_body: Optional[Dict[str, Any]] = None,
    ) -> Tuple[str, Dict]:
        """
        Video remix is not supported by Veo API.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "Video remix is not supported by Vertex AI Veo. "
            "Please use video_generation() to create new videos."
        )

    def transform_video_remix_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        custom_llm_provider: Optional[str] = None,
    ) -> VideoObject:
        """Video remix is not supported."""
        raise NotImplementedError("Video remix is not supported by Vertex AI Veo.")

    def transform_video_list_request(
        self,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
        after: Optional[str] = None,
        limit: Optional[int] = None,
        order: Optional[str] = None,
        extra_query: Optional[Dict[str, Any]] = None,
    ) -> Tuple[str, Dict]:
        """
        Video list is not supported by Veo API.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "Video list is not supported by Vertex AI Veo. "
            "Use the operations endpoint directly if you need to list operations."
        )

    def transform_video_list_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        custom_llm_provider: Optional[str] = None,
    ) -> Dict[str, str]:
        """Video list is not supported."""
        raise NotImplementedError("Video list is not supported by Vertex AI Veo.")

    def transform_video_delete_request(
        self,
        video_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """
        Video delete is not supported by Veo API.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "Video delete is not supported by Vertex AI Veo. "
            "Videos are automatically cleaned up by Google."
        )

    def transform_video_delete_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> VideoObject:
        """Video delete is not supported."""
        raise NotImplementedError("Video delete is not supported by Vertex AI Veo.")

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """Wrap an error message, status code and headers in a VertexAIError."""
        # Imported here rather than at module level, as in the original code.
        from litellm.llms.vertex_ai.common_utils import VertexAIError

        return VertexAIError(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )
|
||||
Reference in New Issue
Block a user