chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,6 @@
|
||||
# Vertex AI Batch Prediction Jobs
|
||||
|
||||
Implementation to call VertexAI Batch endpoints in OpenAI Batch API spec
|
||||
|
||||
Vertex Docs: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/batch-prediction-gemini
|
||||
|
||||
@@ -0,0 +1,378 @@
|
||||
import json
|
||||
from typing import Any, Coroutine, Dict, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.llms.vertex_ai.common_utils import get_vertex_base_url
|
||||
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
|
||||
from litellm.types.llms.openai import CreateBatchRequest
|
||||
from litellm.types.llms.vertex_ai import (
|
||||
VERTEX_CREDENTIALS_TYPES,
|
||||
VertexAIBatchPredictionJob,
|
||||
)
|
||||
from litellm.types.utils import LiteLLMBatch
|
||||
|
||||
from .transformation import VertexAIBatchTransformation
|
||||
|
||||
|
||||
class VertexAIBatchPrediction(VertexLLM):
|
||||
def __init__(self, gcs_bucket_name: str, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.gcs_bucket_name = gcs_bucket_name
|
||||
|
||||
def create_batch(
|
||||
self,
|
||||
_is_async: bool,
|
||||
create_batch_data: CreateBatchRequest,
|
||||
api_base: Optional[str],
|
||||
vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
|
||||
vertex_project: Optional[str],
|
||||
vertex_location: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
|
||||
sync_handler = _get_httpx_client()
|
||||
|
||||
access_token, project_id = self._ensure_access_token(
|
||||
credentials=vertex_credentials,
|
||||
project_id=vertex_project,
|
||||
custom_llm_provider="vertex_ai",
|
||||
)
|
||||
|
||||
default_api_base = self.create_vertex_batch_url(
|
||||
vertex_location=vertex_location or "us-central1",
|
||||
vertex_project=vertex_project or project_id,
|
||||
)
|
||||
|
||||
if len(default_api_base.split(":")) > 1:
|
||||
endpoint = default_api_base.split(":")[-1]
|
||||
else:
|
||||
endpoint = ""
|
||||
|
||||
_, api_base = self._check_custom_proxy(
|
||||
api_base=api_base,
|
||||
custom_llm_provider="vertex_ai",
|
||||
gemini_api_key=None,
|
||||
endpoint=endpoint,
|
||||
stream=None,
|
||||
auth_header=None,
|
||||
url=default_api_base,
|
||||
model=None,
|
||||
vertex_project=vertex_project or project_id,
|
||||
vertex_location=vertex_location or "us-central1",
|
||||
vertex_api_version="v1",
|
||||
)
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json; charset=utf-8",
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
}
|
||||
|
||||
vertex_batch_request: VertexAIBatchPredictionJob = VertexAIBatchTransformation.transform_openai_batch_request_to_vertex_ai_batch_request(
|
||||
request=create_batch_data
|
||||
)
|
||||
|
||||
if _is_async is True:
|
||||
return self._async_create_batch(
|
||||
vertex_batch_request=vertex_batch_request,
|
||||
api_base=api_base,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
response = sync_handler.post(
|
||||
url=api_base,
|
||||
headers=headers,
|
||||
data=json.dumps(vertex_batch_request),
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Error: {response.status_code} {response.text}")
|
||||
|
||||
_json_response = response.json()
|
||||
vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
|
||||
response=_json_response
|
||||
)
|
||||
return vertex_batch_response
|
||||
|
||||
async def _async_create_batch(
|
||||
self,
|
||||
vertex_batch_request: VertexAIBatchPredictionJob,
|
||||
api_base: str,
|
||||
headers: Dict[str, str],
|
||||
) -> LiteLLMBatch:
|
||||
client = get_async_httpx_client(
|
||||
llm_provider=litellm.LlmProviders.VERTEX_AI,
|
||||
)
|
||||
try:
|
||||
response = await client.post(
|
||||
url=api_base,
|
||||
headers=headers,
|
||||
data=json.dumps(vertex_batch_request),
|
||||
)
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_body = e.response.text
|
||||
litellm.verbose_logger.error(
|
||||
"Vertex AI batch create failed: status=%s, body=%s",
|
||||
e.response.status_code,
|
||||
error_body[:1000],
|
||||
)
|
||||
raise
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Error: {response.status_code} {response.text}")
|
||||
|
||||
_json_response = response.json()
|
||||
vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
|
||||
response=_json_response
|
||||
)
|
||||
return vertex_batch_response
|
||||
|
||||
def create_vertex_batch_url(
|
||||
self,
|
||||
vertex_location: str,
|
||||
vertex_project: str,
|
||||
) -> str:
|
||||
"""Return the base url for the vertex garden models"""
|
||||
# POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION/batchPredictionJobs
|
||||
base_url = get_vertex_base_url(vertex_location)
|
||||
return f"{base_url}/v1/projects/{vertex_project}/locations/{vertex_location}/batchPredictionJobs"
|
||||
|
||||
def retrieve_batch(
|
||||
self,
|
||||
_is_async: bool,
|
||||
batch_id: str,
|
||||
api_base: Optional[str],
|
||||
vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
|
||||
vertex_project: Optional[str],
|
||||
vertex_location: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
logging_obj: Optional[Any] = None,
|
||||
) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
|
||||
sync_handler = _get_httpx_client()
|
||||
|
||||
access_token, project_id = self._ensure_access_token(
|
||||
credentials=vertex_credentials,
|
||||
project_id=vertex_project,
|
||||
custom_llm_provider="vertex_ai",
|
||||
)
|
||||
|
||||
default_api_base = self.create_vertex_batch_url(
|
||||
vertex_location=vertex_location or "us-central1",
|
||||
vertex_project=vertex_project or project_id,
|
||||
)
|
||||
|
||||
# Append batch_id to the URL
|
||||
default_api_base = f"{default_api_base}/{batch_id}"
|
||||
|
||||
if len(default_api_base.split(":")) > 1:
|
||||
endpoint = default_api_base.split(":")[-1]
|
||||
else:
|
||||
endpoint = ""
|
||||
|
||||
_, api_base = self._check_custom_proxy(
|
||||
api_base=api_base,
|
||||
custom_llm_provider="vertex_ai",
|
||||
gemini_api_key=None,
|
||||
endpoint=endpoint,
|
||||
stream=None,
|
||||
auth_header=None,
|
||||
url=default_api_base,
|
||||
model=None,
|
||||
vertex_project=vertex_project or project_id,
|
||||
vertex_location=vertex_location or "us-central1",
|
||||
vertex_api_version="v1",
|
||||
)
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json; charset=utf-8",
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
}
|
||||
|
||||
if _is_async is True:
|
||||
return self._async_retrieve_batch(
|
||||
api_base=api_base,
|
||||
headers=headers,
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
# Log the request using logging_obj if available
|
||||
if logging_obj is not None:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging
|
||||
|
||||
if isinstance(logging_obj, Logging):
|
||||
logging_obj.pre_call(
|
||||
input="",
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": {},
|
||||
"api_base": api_base,
|
||||
"headers": headers,
|
||||
"request_str": (
|
||||
f"\nGET Request Sent from LiteLLM:\n"
|
||||
f"curl -X GET \\\n"
|
||||
f"{api_base} \\\n"
|
||||
f"-H 'Authorization: Bearer ***REDACTED***' \\\n"
|
||||
f"-H 'Content-Type: application/json; charset=utf-8'\n"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
response = sync_handler.get(
|
||||
url=api_base,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Error: {response.status_code} {response.text}")
|
||||
|
||||
_json_response = response.json()
|
||||
vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
|
||||
response=_json_response
|
||||
)
|
||||
return vertex_batch_response
|
||||
|
||||
async def _async_retrieve_batch(
|
||||
self,
|
||||
api_base: str,
|
||||
headers: Dict[str, str],
|
||||
logging_obj: Optional[Any] = None,
|
||||
) -> LiteLLMBatch:
|
||||
client = get_async_httpx_client(
|
||||
llm_provider=litellm.LlmProviders.VERTEX_AI,
|
||||
)
|
||||
|
||||
# Log the request using logging_obj if available
|
||||
if logging_obj is not None:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging
|
||||
|
||||
if isinstance(logging_obj, Logging):
|
||||
logging_obj.pre_call(
|
||||
input="",
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": {},
|
||||
"api_base": api_base,
|
||||
"headers": headers,
|
||||
"request_str": (
|
||||
f"\nGET Request Sent from LiteLLM:\n"
|
||||
f"curl -X GET \\\n"
|
||||
f"{api_base} \\\n"
|
||||
f"-H 'Authorization: Bearer ***REDACTED***' \\\n"
|
||||
f"-H 'Content-Type: application/json; charset=utf-8'\n"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
response = await client.get(
|
||||
url=api_base,
|
||||
headers=headers,
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Error: {response.status_code} {response.text}")
|
||||
|
||||
_json_response = response.json()
|
||||
vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
|
||||
response=_json_response
|
||||
)
|
||||
return vertex_batch_response
|
||||
|
||||
def list_batches(
|
||||
self,
|
||||
_is_async: bool,
|
||||
after: Optional[str],
|
||||
limit: Optional[int],
|
||||
api_base: Optional[str],
|
||||
vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
|
||||
vertex_project: Optional[str],
|
||||
vertex_location: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
):
|
||||
sync_handler = _get_httpx_client()
|
||||
|
||||
access_token, project_id = self._ensure_access_token(
|
||||
credentials=vertex_credentials,
|
||||
project_id=vertex_project,
|
||||
custom_llm_provider="vertex_ai",
|
||||
)
|
||||
|
||||
default_api_base = self.create_vertex_batch_url(
|
||||
vertex_location=vertex_location or "us-central1",
|
||||
vertex_project=vertex_project or project_id,
|
||||
)
|
||||
|
||||
if len(default_api_base.split(":")) > 1:
|
||||
endpoint = default_api_base.split(":")[-1]
|
||||
else:
|
||||
endpoint = ""
|
||||
|
||||
_, api_base = self._check_custom_proxy(
|
||||
api_base=api_base,
|
||||
custom_llm_provider="vertex_ai",
|
||||
gemini_api_key=None,
|
||||
endpoint=endpoint,
|
||||
stream=None,
|
||||
auth_header=None,
|
||||
url=default_api_base,
|
||||
)
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json; charset=utf-8",
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
}
|
||||
|
||||
params: Dict[str, Any] = {}
|
||||
if limit is not None:
|
||||
params["pageSize"] = str(limit)
|
||||
if after is not None:
|
||||
params["pageToken"] = after
|
||||
|
||||
if _is_async is True:
|
||||
return self._async_list_batches(
|
||||
api_base=api_base,
|
||||
headers=headers,
|
||||
params=params,
|
||||
)
|
||||
|
||||
response = sync_handler.get(
|
||||
url=api_base,
|
||||
headers=headers,
|
||||
params=params,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Error: {response.status_code} {response.text}")
|
||||
|
||||
_json_response = response.json()
|
||||
vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_list_response_to_openai_list_response(
|
||||
response=_json_response
|
||||
)
|
||||
return vertex_batch_response
|
||||
|
||||
async def _async_list_batches(
|
||||
self,
|
||||
api_base: str,
|
||||
headers: Dict[str, str],
|
||||
params: Dict[str, Any],
|
||||
):
|
||||
client = get_async_httpx_client(
|
||||
llm_provider=litellm.LlmProviders.VERTEX_AI,
|
||||
)
|
||||
response = await client.get(
|
||||
url=api_base,
|
||||
headers=headers,
|
||||
params=params,
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Error: {response.status_code} {response.text}")
|
||||
|
||||
_json_response = response.json()
|
||||
vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_list_response_to_openai_list_response(
|
||||
response=_json_response
|
||||
)
|
||||
return vertex_batch_response
|
||||
@@ -0,0 +1,227 @@
|
||||
from litellm._uuid import uuid
|
||||
from typing import Any, Dict
|
||||
|
||||
from litellm.llms.vertex_ai.common_utils import (
|
||||
_convert_vertex_datetime_to_openai_datetime,
|
||||
)
|
||||
from litellm.types.llms.openai import BatchJobStatus, CreateBatchRequest
|
||||
from litellm.types.llms.vertex_ai import *
|
||||
from litellm.types.utils import LiteLLMBatch
|
||||
|
||||
|
||||
class VertexAIBatchTransformation:
|
||||
"""
|
||||
Transforms OpenAI Batch requests to Vertex AI Batch requests
|
||||
|
||||
API Ref: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/batch-prediction-gemini
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def transform_openai_batch_request_to_vertex_ai_batch_request(
|
||||
cls,
|
||||
request: CreateBatchRequest,
|
||||
) -> VertexAIBatchPredictionJob:
|
||||
"""
|
||||
Transforms OpenAI Batch requests to Vertex AI Batch requests
|
||||
"""
|
||||
request_display_name = f"litellm-vertex-batch-{uuid.uuid4()}"
|
||||
input_file_id = request.get("input_file_id")
|
||||
if input_file_id is None:
|
||||
raise ValueError("input_file_id is required, but not provided")
|
||||
input_config: InputConfig = InputConfig(
|
||||
gcsSource=GcsSource(uris=[input_file_id]), instancesFormat="jsonl"
|
||||
)
|
||||
model: str = cls._get_model_from_gcs_file(input_file_id)
|
||||
output_config: OutputConfig = OutputConfig(
|
||||
predictionsFormat="jsonl",
|
||||
gcsDestination=GcsDestination(
|
||||
outputUriPrefix=cls._get_gcs_uri_prefix_from_file(input_file_id)
|
||||
),
|
||||
)
|
||||
return VertexAIBatchPredictionJob(
|
||||
inputConfig=input_config,
|
||||
outputConfig=output_config,
|
||||
model=model,
|
||||
displayName=request_display_name,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def transform_vertex_ai_batch_response_to_openai_batch_response(
|
||||
cls, response: VertexBatchPredictionResponse
|
||||
) -> LiteLLMBatch:
|
||||
return LiteLLMBatch(
|
||||
id=cls._get_batch_id_from_vertex_ai_batch_response(response),
|
||||
completion_window="24hrs",
|
||||
created_at=_convert_vertex_datetime_to_openai_datetime(
|
||||
vertex_datetime=response.get("createTime", "")
|
||||
),
|
||||
endpoint="",
|
||||
input_file_id=cls._get_input_file_id_from_vertex_ai_batch_response(
|
||||
response
|
||||
),
|
||||
object="batch",
|
||||
status=cls._get_batch_job_status_from_vertex_ai_batch_response(response),
|
||||
error_file_id=None, # Vertex AI doesn't seem to have a direct equivalent
|
||||
output_file_id=cls._get_output_file_id_from_vertex_ai_batch_response(
|
||||
response
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def transform_vertex_ai_batch_list_response_to_openai_list_response(
|
||||
cls, response: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Transforms Vertex AI batch list response into OpenAI-compatible list response.
|
||||
"""
|
||||
|
||||
batch_jobs = response.get("batchPredictionJobs", []) or []
|
||||
data = [
|
||||
cls.transform_vertex_ai_batch_response_to_openai_batch_response(job)
|
||||
for job in batch_jobs
|
||||
]
|
||||
|
||||
first_id = data[0].id if len(data) > 0 else None
|
||||
last_id = data[-1].id if len(data) > 0 else None
|
||||
next_page_token = response.get("nextPageToken")
|
||||
|
||||
return {
|
||||
"object": "list",
|
||||
"data": data,
|
||||
"first_id": first_id,
|
||||
"last_id": last_id,
|
||||
"has_more": bool(next_page_token),
|
||||
"next_page_token": next_page_token,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _get_batch_id_from_vertex_ai_batch_response(
|
||||
cls, response: VertexBatchPredictionResponse
|
||||
) -> str:
|
||||
"""
|
||||
Gets the batch id from the Vertex AI Batch response safely
|
||||
|
||||
vertex response: `projects/510528649030/locations/us-central1/batchPredictionJobs/3814889423749775360`
|
||||
returns: `3814889423749775360`
|
||||
"""
|
||||
_name = response.get("name", "")
|
||||
if not _name:
|
||||
return ""
|
||||
|
||||
# Split by '/' and get the last part if it exists
|
||||
parts = _name.split("/")
|
||||
return parts[-1] if parts else _name
|
||||
|
||||
@classmethod
|
||||
def _get_input_file_id_from_vertex_ai_batch_response(
|
||||
cls, response: VertexBatchPredictionResponse
|
||||
) -> str:
|
||||
"""
|
||||
Gets the input file id from the Vertex AI Batch response
|
||||
"""
|
||||
input_file_id: str = ""
|
||||
input_config = response.get("inputConfig")
|
||||
if input_config is None:
|
||||
return input_file_id
|
||||
|
||||
gcs_source = input_config.get("gcsSource")
|
||||
if gcs_source is None:
|
||||
return input_file_id
|
||||
|
||||
uris = gcs_source.get("uris", "")
|
||||
if len(uris) == 0:
|
||||
return input_file_id
|
||||
|
||||
return uris[0]
|
||||
|
||||
@classmethod
|
||||
def _get_output_file_id_from_vertex_ai_batch_response(
|
||||
cls, response: VertexBatchPredictionResponse
|
||||
) -> str:
|
||||
"""
|
||||
Gets the output file id from the Vertex AI Batch response
|
||||
"""
|
||||
|
||||
output_file_id: str = (
|
||||
response.get("outputInfo", OutputInfo()).get("gcsOutputDirectory", "")
|
||||
+ "/predictions.jsonl"
|
||||
)
|
||||
if output_file_id != "/predictions.jsonl":
|
||||
return output_file_id
|
||||
|
||||
output_config = response.get("outputConfig")
|
||||
if output_config is None:
|
||||
return output_file_id
|
||||
|
||||
gcs_destination = output_config.get("gcsDestination")
|
||||
if gcs_destination is None:
|
||||
return output_file_id
|
||||
|
||||
output_uri_prefix = gcs_destination.get("outputUriPrefix", "")
|
||||
return output_uri_prefix
|
||||
|
||||
@classmethod
|
||||
def _get_batch_job_status_from_vertex_ai_batch_response(
|
||||
cls, response: VertexBatchPredictionResponse
|
||||
) -> BatchJobStatus:
|
||||
"""
|
||||
Gets the batch job status from the Vertex AI Batch response
|
||||
|
||||
ref: https://cloud.google.com/vertex-ai/docs/reference/rest/v1/JobState
|
||||
"""
|
||||
state_mapping: Dict[str, BatchJobStatus] = {
|
||||
"JOB_STATE_UNSPECIFIED": "failed",
|
||||
"JOB_STATE_QUEUED": "validating",
|
||||
"JOB_STATE_PENDING": "validating",
|
||||
"JOB_STATE_RUNNING": "in_progress",
|
||||
"JOB_STATE_SUCCEEDED": "completed",
|
||||
"JOB_STATE_FAILED": "failed",
|
||||
"JOB_STATE_CANCELLING": "cancelling",
|
||||
"JOB_STATE_CANCELLED": "cancelled",
|
||||
"JOB_STATE_PAUSED": "in_progress",
|
||||
"JOB_STATE_EXPIRED": "expired",
|
||||
"JOB_STATE_UPDATING": "in_progress",
|
||||
"JOB_STATE_PARTIALLY_SUCCEEDED": "completed",
|
||||
}
|
||||
|
||||
vertex_state = response.get("state", "JOB_STATE_UNSPECIFIED")
|
||||
return state_mapping[vertex_state]
|
||||
|
||||
@classmethod
|
||||
def _get_gcs_uri_prefix_from_file(cls, input_file_id: str) -> str:
|
||||
"""
|
||||
Gets the gcs uri prefix from the input file id
|
||||
|
||||
Example:
|
||||
input_file_id: "gs://litellm-testing-bucket/vtx_batch.jsonl"
|
||||
returns: "gs://litellm-testing-bucket"
|
||||
|
||||
input_file_id: "gs://litellm-testing-bucket/batches/vtx_batch.jsonl"
|
||||
returns: "gs://litellm-testing-bucket/batches"
|
||||
"""
|
||||
# Split the path and remove the filename
|
||||
path_parts = input_file_id.rsplit("/", 1)
|
||||
return path_parts[0]
|
||||
|
||||
@classmethod
|
||||
def _get_model_from_gcs_file(cls, gcs_file_uri: str) -> str:
|
||||
"""
|
||||
Extracts the model from the gcs file uri
|
||||
|
||||
When files are uploaded using LiteLLM (/v1/files), the model is stored in the gcs file uri
|
||||
|
||||
Why?
|
||||
- Because Vertex Requires the `model` param in create batch jobs request, but OpenAI does not require this
|
||||
|
||||
|
||||
gcs_file_uri format: gs://litellm-testing-bucket/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/e9412502-2c91-42a6-8e61-f5c294cc0fc8
|
||||
returns: "publishers/google/models/gemini-1.5-flash-001"
|
||||
"""
|
||||
from urllib.parse import unquote
|
||||
|
||||
decoded_uri = unquote(gcs_file_uri)
|
||||
|
||||
model_path = decoded_uri.split("publishers/")[1]
|
||||
parts = model_path.split("/")
|
||||
model = f"publishers/{'/'.join(parts[:3])}"
|
||||
return model
|
||||
Reference in New Issue
Block a user