chore: initial snapshot for gitea/github upload

This commit is contained in:
Your Name
2026-03-26 16:04:46 +08:00
commit a699a1ac98
3497 changed files with 1586237 additions and 0 deletions

View File

@@ -0,0 +1,12 @@
# GCS (Google Cloud Storage) Bucket Logging on LiteLLM Gateway
This folder contains the GCS Bucket Logging integration for LiteLLM Gateway.
## Folder Structure
- `gcs_bucket.py`: This is the main file that handles failure/success logging to GCS Bucket
- `gcs_bucket_base.py`: This file contains the GCSBucketBase class which handles Authentication for GCS Buckets
## Further Reading
- [Doc on setting up GCS Bucket Logging on LiteLLM Proxy (Gateway)](https://docs.litellm.ai/docs/observability/gcs_bucket_integration)
- [Doc on Key / Team Based logging with GCS](https://docs.litellm.ai/docs/proxy/team_logging)

View File

@@ -0,0 +1,419 @@
import asyncio
import hashlib
import json
import os
import time
from litellm._uuid import uuid
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from urllib.parse import quote
from litellm._logging import verbose_logger
from litellm.constants import LITELLM_ASYNCIO_QUEUE_MAXSIZE
from litellm.integrations.additional_logging_utils import AdditionalLoggingUtils
from litellm.integrations.gcs_bucket.gcs_bucket_base import GCSBucketBase
from litellm.proxy._types import CommonProxyErrors
from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus
from litellm.types.integrations.gcs_bucket import *
from litellm.types.utils import StandardLoggingPayload
if TYPE_CHECKING:
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
else:
VertexBase = Any
class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils):
    """Success/failure logger that ships StandardLoggingPayloads to a GCS bucket.

    Payloads are buffered on a bounded ``asyncio.Queue`` and flushed either on
    a periodic timer (every ``flush_interval`` seconds) or inline when the
    queue hits its maxsize. With ``GCS_USE_BATCHED_LOGGING`` enabled (the
    default), each flush combines payloads that share the same
    bucket/credentials into a single NDJSON object upload; otherwise each
    payload is uploaded as its own GCS object (legacy behavior).
    """

    def __init__(self, bucket_name: Optional[str] = None) -> None:
        """Initialize the logger.

        Args:
            bucket_name: Optional bucket override; when omitted,
                ``GCSBucketBase`` falls back to the ``GCS_BUCKET_NAME`` env var.

        Raises:
            ValueError: when the caller is not a premium user (GCS bucket
                logging is a premium feature).
        """
        from litellm.proxy.proxy_server import premium_user

        # Premium gate first: fail before creating the queue or spawning the
        # periodic-flush background task on an instance we are about to reject.
        if premium_user is not True:
            raise ValueError(
                f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
            )
        self.batch_size = int(os.getenv("GCS_BATCH_SIZE", GCS_DEFAULT_BATCH_SIZE))
        self.flush_interval = int(
            os.getenv("GCS_FLUSH_INTERVAL", GCS_DEFAULT_FLUSH_INTERVAL_SECONDS)
        )
        self.use_batched_logging = (
            os.getenv(
                "GCS_USE_BATCHED_LOGGING", str(GCS_DEFAULT_USE_BATCHED_LOGGING).lower()
            ).lower()
            == "true"
        )
        self.flush_lock = asyncio.Lock()
        # BUGFIX: the previous implementation called super().__init__() twice —
        # once with bucket_name and once with only the batching params. The
        # second call re-ran GCSBucketBase.__init__ with bucket_name=None,
        # silently replacing an explicitly passed bucket_name with the
        # GCS_BUCKET_NAME env var (or None). A single call preserves both the
        # bucket and the batching configuration.
        super().__init__(
            bucket_name=bucket_name,
            flush_lock=self.flush_lock,
            batch_size=self.batch_size,
            flush_interval=self.flush_interval,
        )
        AdditionalLoggingUtils.__init__(self)
        # Bounded queue: producers trigger an inline flush when full instead of
        # blocking or dropping entries.
        self.log_queue: asyncio.Queue[GCSLogQueueItem] = asyncio.Queue(  # type: ignore[assignment]
            maxsize=LITELLM_ASYNCIO_QUEUE_MAXSIZE
        )
        # NOTE(review): assumes a running event loop at construction time —
        # create_task raises RuntimeError otherwise.
        asyncio.create_task(self.periodic_flush())

    #### ASYNC ####

    async def _enqueue_log_item(self, kwargs, response_obj) -> None:
        """Validate and enqueue one payload from ``kwargs``.

        Shared by the success and failure handlers. When the queue is at
        maxsize, flushes inline first to make room (no blocking, no data
        dropped).

        Raises:
            ValueError: if ``standard_logging_object`` is missing from kwargs.
        """
        logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
            "standard_logging_object", None
        )
        if logging_payload is None:
            raise ValueError("standard_logging_object not found in kwargs")
        # When queue is at maxsize, flush immediately to make room (no blocking, no data dropped)
        if self.log_queue.full():
            await self.flush_queue()
        await self.log_queue.put(
            GCSLogQueueItem(
                payload=logging_payload, kwargs=kwargs, response_obj=response_obj
            )
        )

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Queue a successful request's payload for upload to GCS."""
        from litellm.proxy.proxy_server import premium_user

        if premium_user is not True:
            raise ValueError(
                f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
            )
        try:
            verbose_logger.debug(
                "GCS Logger: async_log_success_event logging kwargs: %s, response_obj: %s",
                kwargs,
                response_obj,
            )
            await self._enqueue_log_item(kwargs=kwargs, response_obj=response_obj)
        except Exception as e:
            verbose_logger.exception(f"GCS Bucket logging error: {str(e)}")

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """Queue a failed request's payload for upload to GCS."""
        try:
            verbose_logger.debug(
                "GCS Logger: async_log_failure_event logging kwargs: %s, response_obj: %s",
                kwargs,
                response_obj,
            )
            await self._enqueue_log_item(kwargs=kwargs, response_obj=response_obj)
        except Exception as e:
            verbose_logger.exception(f"GCS Bucket logging error: {str(e)}")

    def _drain_queue_batch(self) -> List[GCSLogQueueItem]:
        """
        Drain items from the queue (non-blocking), respecting batch_size limit.
        This prevents unbounded queue growth when processing is slower than log accumulation.

        Returns:
            List of items to process, up to batch_size items
        """
        items_to_process: List[GCSLogQueueItem] = []
        while len(items_to_process) < self.batch_size:
            try:
                items_to_process.append(self.log_queue.get_nowait())
            except asyncio.QueueEmpty:
                break
        return items_to_process

    def _generate_batch_object_name(self, date_str: str, batch_id: str) -> str:
        """
        Generate object name for a batched log file.
        Format: {date}/batch-{batch_id}.ndjson
        """
        return f"{date_str}/batch-{batch_id}.ndjson"

    def _get_config_key(self, kwargs: Dict[str, Any]) -> str:
        """
        Extract a synchronous grouping key from kwargs to group items by GCS config.
        This allows us to batch items with the same bucket/credentials together.

        Returns a string key that uniquely identifies the GCS config combination.
        This key may contain sensitive information (bucket names, paths) - use
        _sanitize_config_key() for logging purposes.
        """
        standard_callback_dynamic_params = (
            kwargs.get("standard_callback_dynamic_params", None) or {}
        )
        bucket_name = (
            standard_callback_dynamic_params.get("gcs_bucket_name", None)
            or self.BUCKET_NAME
            or "default"
        )
        path_service_account = (
            standard_callback_dynamic_params.get("gcs_path_service_account", None)
            or self.path_service_account_json
            or "default"
        )
        return f"{bucket_name}|{path_service_account}"

    def _sanitize_config_key(self, config_key: str) -> str:
        """
        Create a sanitized version of the config key for logging.
        Uses a hash to avoid exposing sensitive bucket names or service account paths.

        Returns a short hash prefix for safe logging.
        """
        hash_obj = hashlib.sha256(config_key.encode("utf-8"))
        return f"config-{hash_obj.hexdigest()[:8]}"

    def _group_items_by_config(
        self, items: List[GCSLogQueueItem]
    ) -> Dict[str, List[GCSLogQueueItem]]:
        """
        Group items by their GCS config (bucket + credentials).
        This ensures items with different configs are processed separately.

        Returns a dict mapping config_key -> list of items with that config.
        """
        grouped: Dict[str, List[GCSLogQueueItem]] = {}
        for item in items:
            config_key = self._get_config_key(item["kwargs"])
            if config_key not in grouped:
                grouped[config_key] = []
            grouped[config_key].append(item)
        return grouped

    def _combine_payloads_to_ndjson(self, items: List[GCSLogQueueItem]) -> str:
        """
        Combine multiple log payloads into newline-delimited JSON (NDJSON) format.
        Each line is a valid JSON object representing one log entry.
        """
        lines = []
        for item in items:
            logging_payload = item["payload"]
            json_line = json.dumps(logging_payload, default=str, ensure_ascii=False)
            lines.append(json_line)
        return "\n".join(lines)

    async def _send_grouped_batch(
        self, items: List[GCSLogQueueItem], config_key: str
    ) -> Tuple[int, int]:
        """
        Send a batch of items that share the same GCS config as one NDJSON object.

        The GCS config (bucket, credentials, auth headers) is resolved from the
        first item's kwargs — valid because all items in the group share the
        same config key.

        Returns:
            (success_count, error_count)
        """
        if not items:
            return (0, 0)
        first_kwargs = items[0]["kwargs"]
        try:
            gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
                first_kwargs
            )
            headers = await self.construct_request_headers(
                vertex_instance=gcs_logging_config["vertex_instance"],
                service_account_json=gcs_logging_config["path_service_account"],
            )
            bucket_name = gcs_logging_config["bucket_name"]
            current_date = self._get_object_date_from_datetime(
                datetime.now(timezone.utc)
            )
            # Millisecond timestamp + random suffix keeps batch names unique
            # across concurrent flushes.
            batch_id = f"{int(time.time() * 1000)}-{uuid.uuid4().hex[:8]}"
            object_name = self._generate_batch_object_name(current_date, batch_id)
            combined_payload = self._combine_payloads_to_ndjson(items)
            await self._log_json_data_on_gcs(
                headers=headers,
                bucket_name=bucket_name,
                object_name=object_name,
                logging_payload=combined_payload,
            )
            success_count = len(items)
            error_count = 0
            return (success_count, error_count)
        except Exception as e:
            success_count = 0
            error_count = len(items)
            verbose_logger.exception(
                f"GCS Bucket error logging batch payload to GCS bucket: {str(e)}"
            )
            return (success_count, error_count)

    async def _send_individual_logs(self, items: List[GCSLogQueueItem]) -> None:
        """
        Send each log individually as separate GCS objects (legacy behavior).
        This is used when GCS_USE_BATCHED_LOGGING is disabled.
        """
        for item in items:
            await self._send_single_log_item(item)

    async def _send_single_log_item(self, item: GCSLogQueueItem) -> None:
        """
        Send a single log item to GCS as an individual object.
        Errors are logged and swallowed so one bad item cannot halt the flush.
        """
        try:
            gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
                item["kwargs"]
            )
            headers = await self.construct_request_headers(
                vertex_instance=gcs_logging_config["vertex_instance"],
                service_account_json=gcs_logging_config["path_service_account"],
            )
            bucket_name = gcs_logging_config["bucket_name"]
            object_name = self._get_object_name(
                kwargs=item["kwargs"],
                logging_payload=item["payload"],
                response_obj=item["response_obj"],
            )
            await self._log_json_data_on_gcs(
                headers=headers,
                bucket_name=bucket_name,
                object_name=object_name,
                logging_payload=item["payload"],
            )
        except Exception as e:
            verbose_logger.exception(
                f"GCS Bucket error logging individual payload to GCS bucket: {str(e)}"
            )

    async def async_send_batch(self):
        """
        Process queued logs - sends logs to GCS Bucket.

        If `GCS_USE_BATCHED_LOGGING` is enabled (default), batches multiple log
        payloads into single GCS object uploads (NDJSON format), dramatically
        reducing API calls. If disabled, sends each log individually as
        separate GCS objects (legacy behavior).
        """
        items_to_process = self._drain_queue_batch()
        if not items_to_process:
            return
        if self.use_batched_logging:
            grouped_items = self._group_items_by_config(items_to_process)
            for config_key, group_items in grouped_items.items():
                await self._send_grouped_batch(group_items, config_key)
        else:
            await self._send_individual_logs(items_to_process)

    def _get_object_name(
        self, kwargs: Dict, logging_payload: StandardLoggingPayload, response_obj: Any
    ) -> str:
        """
        Get the object name to use for the current payload.

        Failure payloads (error_str set) get a random failure-prefixed name;
        success payloads are named after the response id. A `gcs_log_id` in
        litellm metadata overrides both (used for testing).
        """
        current_date = self._get_object_date_from_datetime(datetime.now(timezone.utc))
        if logging_payload.get("error_str", None) is not None:
            object_name = self._generate_failure_object_name(
                request_date_str=current_date,
            )
        else:
            object_name = self._generate_success_object_name(
                request_date_str=current_date,
                response_id=response_obj.get("id", ""),
            )
        # used for testing
        _litellm_params = kwargs.get("litellm_params", None) or {}
        _metadata = _litellm_params.get("metadata", None) or {}
        if "gcs_log_id" in _metadata:
            object_name = _metadata["gcs_log_id"]
        return object_name

    async def get_request_response_payload(
        self,
        request_id: str,
        start_time_utc: Optional[datetime],
        end_time_utc: Optional[datetime],
    ) -> Optional[dict]:
        """
        Get the request and response payload for a given `request_id`.

        Tries the start date, the next day, then the previous day (objects are
        keyed by upload date, which may straddle the request date near
        midnight) until it finds the payload.

        Raises:
            ValueError: if start_time_utc is None.
        """
        if start_time_utc is None:
            raise ValueError(
                "start_time_utc is required for getting a payload from GCS Bucket"
            )
        dates_to_try = [
            start_time_utc,
            start_time_utc + timedelta(days=1),
            start_time_utc - timedelta(days=1),
        ]
        date_str = None
        for date in dates_to_try:
            try:
                date_str = self._get_object_date_from_datetime(datetime_obj=date)
                object_name = self._generate_success_object_name(
                    request_date_str=date_str,
                    response_id=request_id,
                )
                # safe="" so the date's "/" separator is percent-encoded too,
                # as required by the GCS JSON API object path.
                encoded_object_name = quote(object_name, safe="")
                response = await self.download_gcs_object(encoded_object_name)
                if response is not None:
                    loaded_response = json.loads(response)
                    return loaded_response
            except Exception as e:
                verbose_logger.debug(
                    f"Failed to fetch payload for date {date_str}: {str(e)}"
                )
                continue
        return None

    def _generate_success_object_name(
        self,
        request_date_str: str,
        response_id: str,
    ) -> str:
        """Object name for a success payload: {date}/{response_id}."""
        return f"{request_date_str}/{response_id}"

    def _generate_failure_object_name(
        self,
        request_date_str: str,
    ) -> str:
        """Object name for a failure payload: {date}/failure-{random hex}."""
        return f"{request_date_str}/failure-{uuid.uuid4().hex}"

    def _get_object_date_from_datetime(self, datetime_obj: datetime) -> str:
        """Format a datetime as the YYYY-MM-DD prefix used for object names."""
        return datetime_obj.strftime("%Y-%m-%d")

    async def flush_queue(self):
        """
        Override flush_queue to work with asyncio.Queue.
        """
        await self.async_send_batch()
        self.last_flush_time = time.time()

    async def periodic_flush(self):
        """
        Override periodic_flush to work with asyncio.Queue.
        Runs forever; sleeps flush_interval seconds between flushes.
        """
        while True:
            await asyncio.sleep(self.flush_interval)
            verbose_logger.debug(
                f"GCS Bucket periodic flush after {self.flush_interval} seconds"
            )
            await self.flush_queue()

    async def async_health_check(self) -> IntegrationHealthCheckStatus:
        """GCS bucket logging has no health-check endpoint."""
        raise NotImplementedError("GCS Bucket does not support health check")

View File

@@ -0,0 +1,347 @@
import json
import os
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
from litellm.integrations.gcs_bucket.gcs_bucket_mock_client import (
should_use_gcs_mock,
create_mock_gcs_client,
mock_vertex_auth_methods,
)
from litellm._logging import verbose_logger
from litellm.integrations.custom_batch_logger import CustomBatchLogger
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.types.integrations.gcs_bucket import *
from litellm.types.utils import StandardCallbackDynamicParams, StandardLoggingPayload
if TYPE_CHECKING:
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
else:
VertexBase = Any
IAM_AUTH_KEY = "IAM_AUTH"
class GCSBucketBase(CustomBatchLogger):
    """Shared plumbing for GCS bucket logging.

    Handles Vertex-AI-based authentication (header construction), per-request
    GCS logging config resolution (bucket + credentials, with dynamic
    overrides), and raw object upload/download/delete against the GCS JSON API.
    """

    def __init__(self, bucket_name: Optional[str] = None, **kwargs) -> None:
        """Initialize the base logger.

        Args:
            bucket_name: Optional bucket override; falls back to the
                ``GCS_BUCKET_NAME`` env var.
            **kwargs: Forwarded to ``CustomBatchLogger`` (e.g. flush_lock,
                batch_size, flush_interval).
        """
        # When GCS_MOCK is enabled, patch auth + HTTP handlers before the
        # httpx client is created so every call below hits the mock layer.
        self.is_mock_mode = should_use_gcs_mock()
        if self.is_mock_mode:
            mock_vertex_auth_methods()
            create_mock_gcs_client()
        self.async_httpx_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )
        _path_service_account = os.getenv("GCS_PATH_SERVICE_ACCOUNT")
        _bucket_name = bucket_name or os.getenv("GCS_BUCKET_NAME")
        self.path_service_account_json: Optional[str] = _path_service_account
        self.BUCKET_NAME: Optional[str] = _bucket_name
        # Cache of VertexBase auth instances, keyed by credentials string
        # (or IAM_AUTH_KEY when no credentials are provided).
        self.vertex_instances: Dict[str, VertexBase] = {}
        super().__init__(**kwargs)

    async def construct_request_headers(
        self,
        service_account_json: Optional[str],
        vertex_instance: Optional[VertexBase] = None,
    ) -> Dict[str, str]:
        """Build Authorization/Content-Type headers for a GCS API call (async).

        Falls back to the shared ``vertex_chat_completion`` instance when no
        vertex_instance is supplied.
        """
        from litellm import vertex_chat_completion

        if vertex_instance is None:
            vertex_instance = vertex_chat_completion
        _auth_header, vertex_project = await vertex_instance._ensure_access_token_async(
            credentials=service_account_json,
            project_id=None,
            custom_llm_provider="vertex_ai",
        )
        auth_header, _ = vertex_instance._get_token_and_url(
            model="gcs-bucket",
            auth_header=_auth_header,
            vertex_credentials=service_account_json,
            vertex_project=vertex_project,
            vertex_location=None,
            gemini_api_key=None,
            stream=None,
            custom_llm_provider="vertex_ai",
            api_base=None,
        )
        verbose_logger.debug("constructed auth_header %s", auth_header)
        headers = {
            "Authorization": f"Bearer {auth_header}",  # auth_header
            "Content-Type": "application/json",
        }
        return headers

    def sync_construct_request_headers(self) -> Dict[str, str]:
        """
        Construct request headers for GCS API calls (synchronous variant).
        """
        from litellm import vertex_chat_completion

        # Get project_id from environment if available, otherwise None
        # This helps support use of this library to auth to pull secrets
        # from Secret Manager.
        project_id = os.getenv("GOOGLE_SECRET_MANAGER_PROJECT_ID")
        _auth_header, vertex_project = vertex_chat_completion._ensure_access_token(
            credentials=self.path_service_account_json,
            project_id=project_id,
            custom_llm_provider="vertex_ai",
        )
        auth_header, _ = vertex_chat_completion._get_token_and_url(
            model="gcs-bucket",
            auth_header=_auth_header,
            vertex_credentials=self.path_service_account_json,
            vertex_project=vertex_project,
            vertex_location=None,
            gemini_api_key=None,
            stream=None,
            custom_llm_provider="vertex_ai",
            api_base=None,
        )
        verbose_logger.debug("constructed auth_header %s", auth_header)
        headers = {
            "Authorization": f"Bearer {auth_header}",  # auth_header
            "Content-Type": "application/json",
        }
        return headers

    def _handle_folders_in_bucket_name(
        self,
        bucket_name: str,
        object_name: str,
    ) -> Tuple[str, str]:
        """
        Handles when the user passes a bucket name with a folder postfix

        Example:
            - Bucket name: "my-bucket/my-folder/dev"
            - Object name: "my-object"
            - Returns: bucket_name="my-bucket", object_name="my-folder/dev/my-object"
        """
        if "/" in bucket_name:
            bucket_name, prefix = bucket_name.split("/", 1)
            object_name = f"{prefix}/{object_name}"
            return bucket_name, object_name
        return bucket_name, object_name

    async def get_gcs_logging_config(
        self, kwargs: Optional[Dict[str, Any]] = None
    ) -> GCSLoggingConfig:
        """
        This function is used to get the GCS logging config for the GCS Bucket Logger.
        It checks if the dynamic parameters are provided in the kwargs and uses them
        to get the GCS logging config. If no dynamic parameters are provided, it uses
        the default values.

        Raises:
            ValueError: if no bucket name can be resolved.

        NOTE: the default was previously the mutable literal ``{}`` — a classic
        shared-mutable-default pitfall; ``None`` is normalized below.
        """
        if kwargs is None:
            kwargs = {}
        standard_callback_dynamic_params: Optional[
            StandardCallbackDynamicParams
        ] = kwargs.get("standard_callback_dynamic_params", None)
        bucket_name: str
        path_service_account: Optional[str]
        if standard_callback_dynamic_params is not None:
            verbose_logger.debug("Using dynamic GCS logging")
            verbose_logger.debug(
                "standard_callback_dynamic_params: %s", standard_callback_dynamic_params
            )
            _bucket_name: Optional[str] = (
                standard_callback_dynamic_params.get("gcs_bucket_name", None)
                or self.BUCKET_NAME
            )
            _path_service_account: Optional[str] = (
                standard_callback_dynamic_params.get("gcs_path_service_account", None)
                or self.path_service_account_json
            )
            if _bucket_name is None:
                raise ValueError(
                    "GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment."
                )
            bucket_name = _bucket_name
            path_service_account = _path_service_account
            vertex_instance = await self.get_or_create_vertex_instance(
                credentials=path_service_account
            )
        else:
            # If no dynamic parameters, use the default instance
            if self.BUCKET_NAME is None:
                raise ValueError(
                    "GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment."
                )
            bucket_name = self.BUCKET_NAME
            path_service_account = self.path_service_account_json
            vertex_instance = await self.get_or_create_vertex_instance(
                credentials=path_service_account
            )
        return GCSLoggingConfig(
            bucket_name=bucket_name,
            vertex_instance=vertex_instance,
            path_service_account=path_service_account,
        )

    async def get_or_create_vertex_instance(
        self, credentials: Optional[str]
    ) -> VertexBase:
        """
        This function is used to get the Vertex instance for the GCS Bucket Logger.
        It checks if the Vertex instance is already created and cached, if not it
        creates a new instance (warming its access token) and caches it.
        """
        from litellm.llms.vertex_ai.vertex_llm_base import VertexBase

        _in_memory_key = self._get_in_memory_key_for_vertex_instance(credentials)
        if _in_memory_key not in self.vertex_instances:
            vertex_instance = VertexBase()
            await vertex_instance._ensure_access_token_async(
                credentials=credentials,
                project_id=None,
                custom_llm_provider="vertex_ai",
            )
            self.vertex_instances[_in_memory_key] = vertex_instance
        return self.vertex_instances[_in_memory_key]

    def _get_in_memory_key_for_vertex_instance(self, credentials: Optional[str]) -> str:
        """
        Returns key to use for caching the Vertex instance in-memory.

        When using Vertex with Key based logging, we need to cache the Vertex
        instance in-memory.
        - If a credentials string is provided, it is used as the key.
        - If no credentials string is provided, "IAM_AUTH" is used as the key.
        """
        return credentials or IAM_AUTH_KEY

    async def download_gcs_object(self, object_name: str, **kwargs):
        """
        Download an object from GCS.

        https://cloud.google.com/storage/docs/downloading-objects#download-object-json

        Args:
            object_name: The (already URL-encoded) object path.

        Returns:
            The raw object bytes, or None on any error (best-effort).
        """
        try:
            gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
                kwargs=kwargs
            )
            headers = await self.construct_request_headers(
                vertex_instance=gcs_logging_config["vertex_instance"],
                service_account_json=gcs_logging_config["path_service_account"],
            )
            bucket_name = gcs_logging_config["bucket_name"]
            bucket_name, object_name = self._handle_folders_in_bucket_name(
                bucket_name=bucket_name,
                object_name=object_name,
            )
            url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}?alt=media"
            # Send the GET request to download the object
            response = await self.async_httpx_client.get(url=url, headers=headers)
            if response.status_code != 200:
                verbose_logger.error(
                    "GCS object download error: %s", str(response.text)
                )
                return None
            verbose_logger.debug(
                "GCS object download response status code: %s", response.status_code
            )
            # Return the content of the downloaded object
            return response.content
        except Exception as e:
            verbose_logger.error("GCS object download error: %s", str(e))
            return None

    async def delete_gcs_object(self, object_name: str, **kwargs):
        """
        Delete an object from GCS.

        Returns:
            The response body text on success, or None on any error
            (best-effort).
        """
        try:
            gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
                kwargs=kwargs
            )
            headers = await self.construct_request_headers(
                vertex_instance=gcs_logging_config["vertex_instance"],
                service_account_json=gcs_logging_config["path_service_account"],
            )
            bucket_name = gcs_logging_config["bucket_name"]
            bucket_name, object_name = self._handle_folders_in_bucket_name(
                bucket_name=bucket_name,
                object_name=object_name,
            )
            url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}"
            # Send the DELETE request to delete the object
            response = await self.async_httpx_client.delete(url=url, headers=headers)
            # BUGFIX: the original condition was
            # `(status != 200) or (status != 204)`, which is always true, so
            # successful deletes (the GCS JSON API returns 204 No Content)
            # were logged as errors and returned None.
            if response.status_code not in (200, 204):
                verbose_logger.error(
                    "GCS object delete error: %s, status code: %s",
                    str(response.text),
                    response.status_code,
                )
                return None
            verbose_logger.debug(
                "GCS object delete response status code: %s, response: %s",
                response.status_code,
                response.text,
            )
            # Return the body of the delete response (empty on 204)
            return response.text
        except Exception as e:
            # BUGFIX: previously mislabelled as a "download" error.
            verbose_logger.error("GCS object delete error: %s", str(e))
            return None

    async def _log_json_data_on_gcs(
        self,
        headers: Dict[str, str],
        bucket_name: str,
        object_name: str,
        logging_payload: Union[StandardLoggingPayload, str],
    ):
        """
        Helper function to make POST request to GCS Bucket in the specified bucket.

        Accepts either a pre-serialized string (e.g. NDJSON batches) or a
        payload dict that will be JSON-serialized here.
        """
        if isinstance(logging_payload, str):
            json_logged_payload = logging_payload
        else:
            json_logged_payload = json.dumps(logging_payload, default=str)
        bucket_name, object_name = self._handle_folders_in_bucket_name(
            bucket_name=bucket_name,
            object_name=object_name,
        )
        # NOTE(review): object_name is interpolated unencoded into the query
        # string; slashes are accepted by GCS here but other reserved chars
        # may not be — confirm callers sanitize names.
        response = await self.async_httpx_client.post(
            headers=headers,
            url=f"https://storage.googleapis.com/upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}",
            data=json_logged_payload,
        )
        if response.status_code != 200:
            verbose_logger.error("GCS Bucket logging error: %s", str(response.text))
        verbose_logger.debug("GCS Bucket response %s", response)
        verbose_logger.debug("GCS Bucket status code %s", response.status_code)
        verbose_logger.debug("GCS Bucket response.text %s", response.text)
        return response.json()

View File

@@ -0,0 +1,254 @@
"""
Mock client for GCS Bucket integration testing.
This module intercepts GCS API calls and Vertex AI auth calls, returning successful
mock responses, allowing full code execution without making actual network calls.
Usage:
Set GCS_MOCK=true in environment variables or config to enable mock mode.
"""
import asyncio
from litellm._logging import verbose_logger
from litellm.integrations.mock_client_factory import (
MockClientConfig,
create_mock_client_factory,
MockResponse,
)
import os

# Use factory for POST handler
_config = MockClientConfig(
    name="GCS",
    env_var="GCS_MOCK",
    default_latency_ms=150,
    default_status_code=200,
    default_json_data={"kind": "storage#object", "name": "mock-object"},
    url_matchers=["storage.googleapis.com"],
    patch_async_handler=True,
    patch_sync_client=False,
)
_create_mock_gcs_post, should_use_gcs_mock = create_mock_client_factory(_config)

# Store original methods for GET/DELETE (GCS-specific)
_original_async_handler_get = None
_original_async_handler_delete = None
_mocks_initialized = False

# Default mock latency in seconds (simulates network round-trip)
# Typical GCS API calls take 100-300ms for uploads, 50-150ms for GET/DELETE
# Idiom fix: use a plain `import os` instead of the inline __import__("os")
# hack the original used to read the env var.
_MOCK_LATENCY_SECONDS = float(os.getenv("GCS_MOCK_LATENCY_MS", "150")) / 1000.0
async def _mock_async_handler_get(
    self, url, params=None, headers=None, follow_redirects=None
):
    """Monkey-patched AsyncHTTPHandler.get that intercepts GCS calls.

    GCS URLs receive a canned StandardLoggingPayload after a simulated
    network delay; everything else is forwarded to the original handler.
    """
    targets_gcs = isinstance(url, str) and "storage.googleapis.com" in url
    if not targets_gcs:
        # Delegate non-GCS traffic to the real implementation.
        if _original_async_handler_get is None:
            raise RuntimeError("Original AsyncHTTPHandler.get not available")
        return await _original_async_handler_get(
            self,
            url=url,
            params=params,
            headers=headers,
            follow_redirects=follow_redirects,
        )

    verbose_logger.info(f"[GCS MOCK] GET to {url}")
    await asyncio.sleep(_MOCK_LATENCY_SECONDS)
    # Return a minimal but valid StandardLoggingPayload JSON string as bytes
    # This matches what GCS returns when downloading with ?alt=media
    canned_payload = {
        "id": "mock-request-id",
        "trace_id": "mock-trace-id",
        "call_type": "completion",
        "stream": False,
        "response_cost": 0.0,
        "status": "success",
        "status_fields": {"llm_api_status": "success"},
        "custom_llm_provider": "mock",
        "total_tokens": 0,
        "prompt_tokens": 0,
        "completion_tokens": 0,
        "startTime": 0.0,
        "endTime": 0.0,
        "completionStartTime": 0.0,
        "response_time": 0.0,
        "model_map_information": {"model": "mock-model"},
        "model": "mock-model",
        "model_id": None,
        "model_group": None,
        "api_base": "https://api.mock.com",
        "metadata": {},
        "cache_hit": None,
        "cache_key": None,
        "saved_cache_cost": 0.0,
        "request_tags": [],
        "end_user": None,
        "requester_ip_address": None,
        "messages": None,
        "response": None,
        "error_str": None,
        "error_information": None,
        "model_parameters": {},
        "hidden_params": {},
        "guardrail_information": None,
        "standard_built_in_tools_params": None,
    }
    return MockResponse(
        status_code=200,
        json_data=canned_payload,
        url=url,
        elapsed_seconds=_MOCK_LATENCY_SECONDS,
    )
async def _mock_async_handler_delete(
    self,
    url,
    data=None,
    json=None,
    params=None,
    headers=None,
    timeout=None,
    stream=False,
    content=None,
):
    """Monkey-patched AsyncHTTPHandler.delete that intercepts GCS calls.

    GCS URLs get a mock 204 No Content response after a simulated delay;
    all other traffic is forwarded to the original handler.
    """
    targets_gcs = isinstance(url, str) and "storage.googleapis.com" in url
    if targets_gcs:
        verbose_logger.info(f"[GCS MOCK] DELETE to {url}")
        await asyncio.sleep(_MOCK_LATENCY_SECONDS)
        # DELETE returns 204 No Content with empty body (not JSON)
        return MockResponse(
            status_code=204,
            json_data=None,  # Empty body for DELETE
            url=url,
            elapsed_seconds=_MOCK_LATENCY_SECONDS,
        )

    if _original_async_handler_delete is None:
        raise RuntimeError("Original AsyncHTTPHandler.delete not available")
    return await _original_async_handler_delete(
        self,
        url=url,
        data=data,
        json=json,
        params=params,
        headers=headers,
        timeout=timeout,
        stream=stream,
        content=content,
    )
def create_mock_gcs_client():
    """
    Monkey-patch AsyncHTTPHandler methods to intercept GCS calls.

    AsyncHTTPHandler is used by LiteLLM's get_async_httpx_client() which is what
    GCSBucketBase uses for making API calls.

    This function is idempotent - it only initializes mocks once, even if called
    multiple times.
    """
    global _original_async_handler_get, _original_async_handler_delete, _mocks_initialized

    # POST patching is delegated to the shared factory (itself idempotent),
    # and runs on every call just like the original did.
    _create_mock_gcs_post()

    # GET/DELETE patching happens at most once.
    if _mocks_initialized:
        return

    verbose_logger.debug("[GCS MOCK] Initializing GCS GET/DELETE handlers...")

    # Patch GET and DELETE handlers (GCS-specific)
    from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler

    if _original_async_handler_get is None:
        _original_async_handler_get = AsyncHTTPHandler.get
        AsyncHTTPHandler.get = _mock_async_handler_get  # type: ignore
        verbose_logger.debug("[GCS MOCK] Patched AsyncHTTPHandler.get")

    if _original_async_handler_delete is None:
        _original_async_handler_delete = AsyncHTTPHandler.delete
        AsyncHTTPHandler.delete = _mock_async_handler_delete  # type: ignore
        verbose_logger.debug("[GCS MOCK] Patched AsyncHTTPHandler.delete")

    verbose_logger.debug(
        f"[GCS MOCK] Mock latency set to {_MOCK_LATENCY_SECONDS*1000:.0f}ms"
    )
    verbose_logger.debug("[GCS MOCK] GCS mock client initialization complete")
    _mocks_initialized = True
def mock_vertex_auth_methods():
    """
    Monkey-patch Vertex AI auth methods to return fake tokens.

    This prevents auth failures when GCS_MOCK is enabled.
    This function is idempotent - it only patches once, even if called multiple times.
    """
    from litellm.llms.vertex_ai.vertex_llm_base import VertexBase

    # Stash the real methods exactly once so they stay recoverable even if
    # this function runs repeatedly.
    if not hasattr(VertexBase, "_original_ensure_access_token_async"):
        setattr(
            VertexBase,
            "_original_ensure_access_token_async",
            VertexBase._ensure_access_token_async,
        )
        setattr(
            VertexBase, "_original_ensure_access_token", VertexBase._ensure_access_token
        )
        setattr(
            VertexBase, "_original_get_token_and_url", VertexBase._get_token_and_url
        )

    async def _fake_ensure_access_token_async(
        self, credentials, project_id, custom_llm_provider
    ):
        """Mock async auth method - returns fake token."""
        verbose_logger.debug(
            "[GCS MOCK] Vertex AI auth: _ensure_access_token_async called"
        )
        return ("mock-gcs-token", "mock-project-id")

    def _fake_ensure_access_token(self, credentials, project_id, custom_llm_provider):
        """Mock sync auth method - returns fake token."""
        verbose_logger.debug(
            "[GCS MOCK] Vertex AI auth: _ensure_access_token called"
        )
        return ("mock-gcs-token", "mock-project-id")

    def _fake_get_token_and_url(
        self,
        model,
        auth_header,
        vertex_credentials,
        vertex_project,
        vertex_location,
        gemini_api_key,
        stream,
        custom_llm_provider,
        api_base,
    ):
        """Mock get_token_and_url - returns fake token."""
        verbose_logger.debug("[GCS MOCK] Vertex AI auth: _get_token_and_url called")
        return ("mock-gcs-token", "https://storage.googleapis.com")

    # Patch the methods (re-patching with equivalent mocks is harmless).
    VertexBase._ensure_access_token_async = _fake_ensure_access_token_async  # type: ignore
    VertexBase._ensure_access_token = _fake_ensure_access_token  # type: ignore
    VertexBase._get_token_and_url = _fake_get_token_and_url  # type: ignore
    verbose_logger.debug("[GCS MOCK] Patched Vertex AI auth methods")


# should_use_gcs_mock is already created by the factory