chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,42 @@
|
||||
from litellm.llms.base_llm.image_generation.transformation import (
|
||||
BaseImageGenerationConfig,
|
||||
)
|
||||
from litellm.llms.vertex_ai.common_utils import (
|
||||
VertexAIModelRoute,
|
||||
get_vertex_ai_model_route,
|
||||
)
|
||||
|
||||
from .vertex_gemini_transformation import VertexAIGeminiImageGenerationConfig
|
||||
from .vertex_imagen_transformation import VertexAIImagenImageGenerationConfig
|
||||
|
||||
# Public API of this module.
__all__ = [
    "VertexAIGeminiImageGenerationConfig",
    "VertexAIImagenImageGenerationConfig",
    "get_vertex_ai_image_generation_config",
]
|
||||
|
||||
|
||||
def get_vertex_ai_image_generation_config(model: str) -> BaseImageGenerationConfig:
    """Return the image generation config for a Vertex AI model.

    Gemini image models (e.g. "gemini-2.5-flash-image") are served through
    the generateContent API, while every other model (e.g.
    "imagegeneration@006") falls back to the Imagen predict API.

    Args:
        model: The Vertex AI model name.

    Returns:
        BaseImageGenerationConfig: The matching configuration instance.
    """
    route = get_vertex_ai_model_route(model)

    # Gemini -> generateContent API; anything else (NON_GEMINI routes such
    # as imagegeneration@006) -> Imagen predict API.
    if route == VertexAIModelRoute.GEMINI:
        return VertexAIGeminiImageGenerationConfig()
    return VertexAIImagenImageGenerationConfig()
|
||||
@@ -0,0 +1,36 @@
|
||||
"""
|
||||
Vertex AI Image Generation Cost Calculator
|
||||
"""
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.llm_cost_calc.utils import (
|
||||
calculate_image_response_cost_from_usage,
|
||||
)
|
||||
from litellm.types.utils import ImageResponse
|
||||
|
||||
|
||||
def cost_calculator(
    model: str,
    image_response: ImageResponse,
) -> float:
    """Compute the cost of a Vertex AI image generation call.

    Token-based pricing (from response usage data) takes precedence; when
    unavailable, falls back to a flat per-image price from the model info.
    """
    # Fetched up front (may raise for unknown models), matching the
    # original call order.
    model_info = litellm.get_model_info(
        model=model,
        custom_llm_provider="vertex_ai",
    )

    usage_cost = calculate_image_response_cost_from_usage(
        model=model,
        image_response=image_response,
        custom_llm_provider="vertex_ai",
    )
    if usage_cost is not None:
        return usage_cost

    # Flat fallback: per-image price times number of generated images.
    per_image_price: float = model_info.get("output_cost_per_image") or 0.0
    image_count: int = len(image_response.data) if image_response.data else 0
    return per_image_price * image_count
|
||||
@@ -0,0 +1,282 @@
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
from openai.types.image import Image
|
||||
|
||||
import litellm
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
|
||||
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
|
||||
from litellm.types.utils import ImageResponse
|
||||
|
||||
|
||||
class VertexImageGeneration(VertexLLM):
    """Vertex AI Imagen image generation via the `:predict` endpoint."""

    def process_image_generation_response(
        self,
        json_response: Dict[str, Any],
        model_response: ImageResponse,
        model: Optional[str] = None,
    ) -> ImageResponse:
        """Convert a raw `:predict` JSON response into an ImageResponse.

        Raises:
            litellm.InternalServerError: if 'predictions' is missing from
                the response body.
        """
        if "predictions" not in json_response:
            raise litellm.InternalServerError(
                message=f"image generation response does not contain 'predictions', got {json_response}",
                llm_provider="vertex_ai",
                model=model,
            )

        predictions = json_response["predictions"]
        response_data: List[Image] = []

        for prediction in predictions:
            # Each prediction carries the image as base64-encoded bytes.
            bytes_base64_encoded = prediction["bytesBase64Encoded"]
            image_object = Image(b64_json=bytes_base64_encoded)
            response_data.append(image_object)

        model_response.data = response_data
        return model_response

    def transform_optional_params(self, optional_params: Optional[dict]) -> dict:
        """
        Transform the optional params to the format expected by the Vertex AI API.

        For example, "aspect_ratio" is transformed to "aspectRatio".
        A default of sampleCount=1 is always included (caller-provided keys
        win over defaults).
        """
        default_params = {
            "sampleCount": 1,
        }
        if optional_params is None:
            return default_params

        def snake_to_camel(snake_str: str) -> str:
            """Convert snake_case to camelCase"""
            components = snake_str.split("_")
            return components[0] + "".join(word.capitalize() for word in components[1:])

        transformed_params = default_params.copy()
        for key, value in optional_params.items():
            if "_" in key:
                camel_case_key = snake_to_camel(key)
                transformed_params[camel_case_key] = value
            else:
                transformed_params[key] = value

        return transformed_params

    def image_generation(
        self,
        prompt: str,
        api_base: Optional[str],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
        model_response: ImageResponse,
        logging_obj: Any,
        model: str = "imagegeneration",  # vertex ai uses imagegeneration as the default model
        client: Optional[Any] = None,
        optional_params: Optional[dict] = None,
        timeout: Optional[int] = None,
        aimg_generation=False,
        extra_headers: Optional[dict] = None,
    ) -> ImageResponse:
        """Synchronous Imagen image generation via the Vertex `:predict` API.

        When ``aimg_generation`` is True, returns the coroutine from
        ``aimage_generation`` instead (the caller is expected to await it).
        """
        if aimg_generation is True:
            return self.aimage_generation(  # type: ignore
                prompt=prompt,
                api_base=api_base,
                vertex_project=vertex_project,
                vertex_location=vertex_location,
                vertex_credentials=vertex_credentials,
                model=model,
                client=client,
                optional_params=optional_params,
                timeout=timeout,
                logging_obj=logging_obj,
                model_response=model_response,
            )

        if client is None:
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
                    _httpx_timeout = httpx.Timeout(timeout)
                    _params["timeout"] = _httpx_timeout
            else:
                # No timeout given: fall back to a generous read timeout.
                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)

            sync_handler: HTTPHandler = HTTPHandler(**_params)  # type: ignore
        else:
            sync_handler = client  # type: ignore

        auth_header: Optional[str] = None
        auth_header, _ = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai",
        )
        # Resolves the final bearer token and the full :predict URL
        # (honors a custom api_base when provided).
        auth_header, api_base = self._get_token_and_url(
            model=model,
            gemini_api_key=None,
            auth_header=auth_header,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_credentials=vertex_credentials,
            stream=False,
            custom_llm_provider="vertex_ai",
            api_base=api_base,
            should_use_v1beta1_features=False,
            mode="image_generation",
        )
        optional_params = optional_params or {
            "sampleCount": 1
        }  # default optional params

        # Transform optional params to camelCase format
        optional_params = self.transform_optional_params(optional_params)

        request_data = {
            "instances": [{"prompt": prompt}],
            "parameters": optional_params,
        }

        headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)

        logging_obj.pre_call(
            input=prompt,
            api_key="",
            additional_args={
                "complete_input_dict": optional_params,
                "api_base": api_base,
                "headers": headers,
            },
        )

        response = sync_handler.post(
            url=api_base,
            headers=headers,
            data=json.dumps(request_data),
        )

        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} {response.text}")

        json_response = response.json()
        return self.process_image_generation_response(
            json_response, model_response, model
        )

    async def aimage_generation(
        self,
        prompt: str,
        api_base: Optional[str],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
        model_response: ImageResponse,
        logging_obj: Any,
        model: str = "imagegeneration",  # vertex ai uses imagegeneration as the default model
        client: Optional[AsyncHTTPHandler] = None,
        optional_params: Optional[dict] = None,
        timeout: Optional[int] = None,
        extra_headers: Optional[dict] = None,
    ):
        """Async Imagen image generation via the Vertex `:predict` API."""
        if client is None:
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
                    _httpx_timeout = httpx.Timeout(timeout)
                    _params["timeout"] = _httpx_timeout
            else:
                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)

            # Fix: `_params` (carrying a proper httpx.Timeout, matching the
            # sync path) was previously computed and then discarded — a raw
            # int was passed instead. Pass the computed params through.
            self.async_handler = get_async_httpx_client(
                llm_provider=litellm.LlmProviders.VERTEX_AI,
                params=_params,
            )
        else:
            self.async_handler = client  # type: ignore

        # make POST request to
        # https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict

        """
        Docs link: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
        curl -X POST \
        -H "Authorization: Bearer $(gcloud auth print-access-token)" \
        -H "Content-Type: application/json; charset=utf-8" \
        -d {
            "instances": [
                {
                    "prompt": "a cat"
                }
            ],
            "parameters": {
                "sampleCount": 1
            }
        } \
        "https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
        """
        auth_header: Optional[str] = None
        auth_header, _ = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai",
        )
        auth_header, api_base = self._get_token_and_url(
            model=model,
            gemini_api_key=None,
            auth_header=auth_header,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_credentials=vertex_credentials,
            stream=False,
            custom_llm_provider="vertex_ai",
            api_base=api_base,
            should_use_v1beta1_features=False,
            mode="image_generation",
        )

        # Transform optional params to camelCase format
        optional_params = self.transform_optional_params(optional_params)

        request_data = {
            "instances": [{"prompt": prompt}],
            "parameters": optional_params,
        }

        headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)

        logging_obj.pre_call(
            input=prompt,
            api_key="",
            additional_args={
                "complete_input_dict": optional_params,
                "api_base": api_base,
                "headers": headers,
            },
        )

        response = await self.async_handler.post(
            url=api_base,
            headers=headers,
            data=json.dumps(request_data),
        )

        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} {response.text}")

        json_response = response.json()
        return self.process_image_generation_response(
            json_response, model_response, model
        )

    def is_image_generation_response(self, json_response: Dict[str, Any]) -> bool:
        """Return True if the `:predict` response contains base64 image data.

        Fix: guard against an empty 'predictions' list before indexing
        (previously raised IndexError on e.g. safety-filtered responses).
        """
        predictions = json_response.get("predictions") or []
        if predictions and "bytesBase64Encoded" in predictions[0]:
            return True
        return False
|
||||
@@ -0,0 +1,327 @@
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.llms.base_llm.image_generation.transformation import (
|
||||
BaseImageGenerationConfig,
|
||||
)
|
||||
from litellm.llms.vertex_ai.common_utils import get_vertex_base_url
|
||||
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import (
|
||||
ImageObject,
|
||||
ImageResponse,
|
||||
ImageUsage,
|
||||
ImageUsageInputTokensDetails,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class VertexAIGeminiImageGenerationConfig(BaseImageGenerationConfig, VertexLLM):
    """
    Vertex AI Gemini Image Generation Configuration

    Uses generateContent API for Gemini image generation models on Vertex AI
    Supports models like gemini-2.5-flash-image, gemini-3-pro-image-preview, etc.
    """

    def __init__(self) -> None:
        # Both bases are initialized explicitly; they do not share a
        # cooperative super().__init__() chain.
        BaseImageGenerationConfig.__init__(self)
        VertexLLM.__init__(self)

    def get_supported_openai_params(self, model: str) -> list:
        """
        Gemini image generation supported parameters

        Includes native Gemini imageConfig params (aspectRatio, imageSize)
        in both camelCase and snake_case variants.
        """
        return [
            "n",
            "size",
            "aspectRatio",
            "aspect_ratio",
            "imageSize",
            "image_size",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Map OpenAI-style params onto Gemini generateContent params.

        Keys already present in ``optional_params`` are skipped.
        NOTE(review): keys outside the supported list are passed through
        unchanged — ``drop_params`` is not consulted here; confirm intended.
        """
        supported_params = self.get_supported_openai_params(model)
        mapped_params = {}

        for k, v in non_default_params.items():
            if k not in optional_params.keys():
                if k in supported_params:
                    # Map OpenAI parameters to Gemini format
                    if k == "n":
                        mapped_params["candidate_count"] = v
                    elif k == "size":
                        # Map OpenAI size format to Gemini aspectRatio
                        mapped_params["aspectRatio"] = self._map_size_to_aspect_ratio(v)
                    elif k in ("aspectRatio", "aspect_ratio"):
                        mapped_params["aspectRatio"] = v
                    elif k in ("imageSize", "image_size"):
                        mapped_params["imageSize"] = v
                else:
                    # Unsupported params are forwarded as-is.
                    mapped_params[k] = v

        return mapped_params

    def _map_size_to_aspect_ratio(self, size: str) -> str:
        """
        Map OpenAI size format to Gemini aspect ratio format

        Unknown sizes fall back to "1:1".
        """
        aspect_ratio_map = {
            "1024x1024": "1:1",
            "1792x1024": "16:9",
            "1024x1792": "9:16",
            "1280x896": "4:3",
            "896x1280": "3:4",
        }
        return aspect_ratio_map.get(size, "1:1")

    def _resolve_vertex_project(self) -> Optional[str]:
        # Resolution order: instance attr -> env var -> litellm module-level
        # setting -> secret manager.
        return (
            getattr(self, "_vertex_project", None)
            or os.environ.get("VERTEXAI_PROJECT")
            or getattr(litellm, "vertex_project", None)
            or get_secret_str("VERTEXAI_PROJECT")
        )

    def _resolve_vertex_location(self) -> Optional[str]:
        # Accepts both VERTEXAI_LOCATION and VERTEX_LOCATION spellings.
        return (
            getattr(self, "_vertex_location", None)
            or os.environ.get("VERTEXAI_LOCATION")
            or os.environ.get("VERTEX_LOCATION")
            or getattr(litellm, "vertex_location", None)
            or get_secret_str("VERTEXAI_LOCATION")
            or get_secret_str("VERTEX_LOCATION")
        )

    def _resolve_vertex_credentials(self) -> Optional[str]:
        # Credentials may be a file path or inline JSON depending on source.
        return (
            getattr(self, "_vertex_credentials", None)
            or os.environ.get("VERTEXAI_CREDENTIALS")
            or getattr(litellm, "vertex_credentials", None)
            or os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
            or get_secret_str("VERTEXAI_CREDENTIALS")
        )

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Get the complete URL for Vertex AI Gemini generateContent API

        Raises:
            ValueError: when no custom api_base is set and project/location
                cannot be resolved from params or the environment.
        """
        # Use the model name as provided, handling vertex_ai prefix
        model_name = model
        if model.startswith("vertex_ai/"):
            model_name = model.replace("vertex_ai/", "")

        # If a custom api_base is provided, use it directly
        # This allows users to use proxies or mock endpoints
        if api_base:
            return api_base.rstrip("/")

        # First check litellm_params (where vertex_ai_project/vertex_ai_location are passed)
        # then fall back to environment variables and other sources
        vertex_project = (
            self.safe_get_vertex_ai_project(litellm_params)
            or self._resolve_vertex_project()
        )
        vertex_location = (
            self.safe_get_vertex_ai_location(litellm_params)
            or self._resolve_vertex_location()
        )

        if not vertex_project or not vertex_location:
            raise ValueError(
                "vertex_project and vertex_location are required for Vertex AI"
            )

        base_url = get_vertex_base_url(vertex_location)

        return f"{base_url}/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model_name}:generateContent"

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """Attach Vertex AI auth headers (skipped when a custom api_base is set)."""
        headers = headers or {}

        # If a custom api_base is provided, skip credential validation
        # This allows users to use proxies or mock endpoints without needing Vertex AI credentials
        _api_base = litellm_params.get("api_base") or api_base
        if _api_base is not None:
            return headers

        # First check litellm_params (where vertex_ai_project/vertex_ai_credentials are passed)
        # then fall back to environment variables and other sources
        vertex_project = (
            self.safe_get_vertex_ai_project(litellm_params)
            or self._resolve_vertex_project()
        )
        vertex_credentials = (
            self.safe_get_vertex_ai_credentials(litellm_params)
            or self._resolve_vertex_credentials()
        )
        access_token, _ = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai",
        )
        return self.set_headers(access_token, headers)

    def transform_image_generation_request(
        self,
        model: str,
        prompt: str,
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Transform the image generation request to Gemini format

        Uses generateContent API with responseModalities: ["IMAGE"]
        """
        # Prepare messages with the prompt
        contents = [{"role": "user", "parts": [{"text": prompt}]}]

        # Prepare generation config
        generation_config: Dict[str, Any] = {"responseModalities": ["IMAGE"]}

        # Handle image-specific config parameters
        image_config: Dict[str, Any] = {}

        # Map aspectRatio (camelCase takes precedence over snake_case)
        if "aspectRatio" in optional_params:
            image_config["aspectRatio"] = optional_params["aspectRatio"]
        elif "aspect_ratio" in optional_params:
            image_config["aspectRatio"] = optional_params["aspect_ratio"]

        # Map imageSize (for Gemini 3 Pro)
        if "imageSize" in optional_params:
            image_config["imageSize"] = optional_params["imageSize"]
        elif "image_size" in optional_params:
            image_config["imageSize"] = optional_params["image_size"]

        if image_config:
            generation_config["imageConfig"] = image_config

        # Handle candidate_count (n parameter)
        if "candidate_count" in optional_params:
            generation_config["candidateCount"] = optional_params["candidate_count"]
        elif "n" in optional_params:
            generation_config["candidateCount"] = optional_params["n"]

        request_body: Dict[str, Any] = {
            "contents": contents,
            "generationConfig": generation_config,
        }

        return request_body

    def _transform_image_usage(self, usage: dict) -> ImageUsage:
        """Convert Gemini usageMetadata into an ImageUsage.

        Splits prompt tokens into text vs. image modalities when
        promptTokensDetails is present.
        """
        input_tokens_details = ImageUsageInputTokensDetails(
            image_tokens=0,
            text_tokens=0,
        )
        tokens_details = usage.get("promptTokensDetails", [])
        for details in tokens_details:
            if isinstance(details, dict) and (modality := details.get("modality")):
                token_count = details.get("tokenCount", 0)
                if modality == "TEXT":
                    input_tokens_details.text_tokens += token_count
                elif modality == "IMAGE":
                    input_tokens_details.image_tokens += token_count

        return ImageUsage(
            input_tokens=usage.get("promptTokenCount", 0),
            input_tokens_details=input_tokens_details,
            output_tokens=usage.get("candidatesTokenCount", 0),
            total_tokens=usage.get("totalTokenCount", 0),
        )

    def transform_image_generation_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ImageResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ImageResponse:
        """
        Transform Gemini image generation response to litellm ImageResponse format
        """
        try:
            response_data = raw_response.json()
        except Exception as e:
            raise self.get_error_class(
                error_message=f"Error transforming image generation response: {e}",
                status_code=raw_response.status_code,
                headers=raw_response.headers,
            )

        if not model_response.data:
            model_response.data = []

        # Gemini image generation models return in candidates format
        candidates = response_data.get("candidates", [])
        for candidate in candidates:
            content = candidate.get("content", {})
            parts = content.get("parts", [])
            for part in parts:
                # Look for inlineData with image
                if "inlineData" in part:
                    inline_data = part["inlineData"]
                    if "data" in inline_data:
                        # Preserve the model's thought signature (if any) so
                        # callers can round-trip it in follow-up requests.
                        thought_sig = part.get("thoughtSignature")
                        model_response.data.append(
                            ImageObject(
                                b64_json=inline_data["data"],
                                url=None,
                                provider_specific_fields={
                                    "thought_signature": thought_sig
                                }
                                if thought_sig
                                else None,
                            )
                        )

        if usage_metadata := response_data.get("usageMetadata", None):
            model_response.usage = self._transform_image_usage(usage_metadata)

        return model_response
|
||||
@@ -0,0 +1,256 @@
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.llms.base_llm.image_generation.transformation import (
|
||||
BaseImageGenerationConfig,
|
||||
)
|
||||
from litellm.llms.vertex_ai.common_utils import get_vertex_base_url
|
||||
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
OpenAIImageGenerationOptionalParams,
|
||||
)
|
||||
from litellm.types.utils import ImageObject, ImageResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class VertexAIImagenImageGenerationConfig(BaseImageGenerationConfig, VertexLLM):
    """
    Vertex AI Imagen Image Generation Configuration

    Uses predict API for Imagen models on Vertex AI
    Supports models like imagegeneration@006
    """

    def __init__(self) -> None:
        # Both bases are initialized explicitly; they do not share a
        # cooperative super().__init__() chain.
        BaseImageGenerationConfig.__init__(self)
        VertexLLM.__init__(self)

    def get_supported_openai_params(
        self, model: str
    ) -> List[OpenAIImageGenerationOptionalParams]:
        """
        Imagen API supported parameters
        """
        return ["n", "size"]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Map OpenAI-style params onto Imagen predict params.

        Keys already present in ``optional_params`` are skipped.
        NOTE(review): keys outside the supported list are passed through
        unchanged — ``drop_params`` is not consulted here; confirm intended.
        """
        supported_params = self.get_supported_openai_params(model)
        mapped_params = {}

        for k, v in non_default_params.items():
            if k not in optional_params.keys():
                if k in supported_params:
                    # Map OpenAI parameters to Imagen format
                    if k == "n":
                        mapped_params["sampleCount"] = v
                    elif k == "size":
                        # Map OpenAI size format to Imagen aspectRatio
                        mapped_params["aspectRatio"] = self._map_size_to_aspect_ratio(v)
                else:
                    # Unsupported params are forwarded as-is.
                    mapped_params[k] = v

        return mapped_params

    def _map_size_to_aspect_ratio(self, size: str) -> str:
        """
        Map OpenAI size format to Imagen aspect ratio format

        Unknown sizes fall back to "1:1".
        """
        aspect_ratio_map = {
            "1024x1024": "1:1",
            "1792x1024": "16:9",
            "1024x1792": "9:16",
            "1280x896": "4:3",
            "896x1280": "3:4",
        }
        return aspect_ratio_map.get(size, "1:1")

    def _resolve_vertex_project(self) -> Optional[str]:
        # Resolution order: instance attr -> env var -> litellm module-level
        # setting -> secret manager.
        return (
            getattr(self, "_vertex_project", None)
            or os.environ.get("VERTEXAI_PROJECT")
            or getattr(litellm, "vertex_project", None)
            or get_secret_str("VERTEXAI_PROJECT")
        )

    def _resolve_vertex_location(self) -> Optional[str]:
        # Accepts both VERTEXAI_LOCATION and VERTEX_LOCATION spellings.
        return (
            getattr(self, "_vertex_location", None)
            or os.environ.get("VERTEXAI_LOCATION")
            or os.environ.get("VERTEX_LOCATION")
            or getattr(litellm, "vertex_location", None)
            or get_secret_str("VERTEXAI_LOCATION")
            or get_secret_str("VERTEX_LOCATION")
        )

    def _resolve_vertex_credentials(self) -> Optional[str]:
        # Credentials may be a file path or inline JSON depending on source.
        return (
            getattr(self, "_vertex_credentials", None)
            or os.environ.get("VERTEXAI_CREDENTIALS")
            or getattr(litellm, "vertex_credentials", None)
            or os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
            or get_secret_str("VERTEXAI_CREDENTIALS")
        )

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Get the complete URL for Vertex AI Imagen predict API

        Raises:
            ValueError: when no custom api_base is set and project/location
                cannot be resolved from params or the environment.
        """
        # Use the model name as provided, handling vertex_ai prefix
        model_name = model
        if model.startswith("vertex_ai/"):
            model_name = model.replace("vertex_ai/", "")

        # If a custom api_base is provided, use it directly
        # This allows users to use proxies or mock endpoints
        if api_base:
            return api_base.rstrip("/")

        # First check litellm_params (where vertex_ai_project/vertex_ai_location are passed)
        # then fall back to environment variables and other sources
        vertex_project = (
            self.safe_get_vertex_ai_project(litellm_params)
            or self._resolve_vertex_project()
        )
        vertex_location = (
            self.safe_get_vertex_ai_location(litellm_params)
            or self._resolve_vertex_location()
        )

        if not vertex_project or not vertex_location:
            raise ValueError(
                "vertex_project and vertex_location are required for Vertex AI"
            )

        base_url = get_vertex_base_url(vertex_location)

        return f"{base_url}/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model_name}:predict"

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """Attach Vertex AI auth headers (skipped when a custom api_base is set)."""
        headers = headers or {}

        # If a custom api_base is provided, skip credential validation
        # This allows users to use proxies or mock endpoints without needing Vertex AI credentials
        _api_base = litellm_params.get("api_base") or api_base
        if _api_base is not None:
            return headers

        # First check litellm_params (where vertex_ai_project/vertex_ai_credentials are passed)
        # then fall back to environment variables and other sources
        vertex_project = (
            self.safe_get_vertex_ai_project(litellm_params)
            or self._resolve_vertex_project()
        )
        vertex_credentials = (
            self.safe_get_vertex_ai_credentials(litellm_params)
            or self._resolve_vertex_credentials()
        )
        access_token, _ = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai",
        )
        return self.set_headers(access_token, headers)

    def transform_image_generation_request(
        self,
        model: str,
        prompt: str,
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Transform the image generation request to Imagen format

        Uses predict API with instances and parameters
        """
        # Default parameters
        default_params = {
            "sampleCount": 1,
        }

        # Merge with optional params (caller-provided keys win).
        parameters = {**default_params, **optional_params}

        request_body = {
            "instances": [{"prompt": prompt}],
            "parameters": parameters,
        }

        return request_body

    def transform_image_generation_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ImageResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ImageResponse:
        """
        Transform Imagen image generation response to litellm ImageResponse format
        """
        try:
            response_data = raw_response.json()
        except Exception as e:
            raise self.get_error_class(
                error_message=f"Error transforming image generation response: {e}",
                status_code=raw_response.status_code,
                headers=raw_response.headers,
            )

        if not model_response.data:
            model_response.data = []

        # Imagen format - predictions with generated images
        predictions = response_data.get("predictions", [])
        for prediction in predictions:
            # Imagen returns images as bytesBase64Encoded
            if "bytesBase64Encoded" in prediction:
                model_response.data.append(
                    ImageObject(
                        b64_json=prediction["bytesBase64Encoded"],
                        url=None,
                    )
                )

        return model_response
|
||||
Reference in New Issue
Block a user