from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from fastapi.responses import ORJSONResponse, StreamingResponse
import litellm
from litellm._uuid import uuid
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import UserAPIKeyAuth, user_api_key_auth
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm.types.llms.vertex_ai import TokenCountDetailsResponse

router = APIRouter(
    tags=["google genai endpoints"],
)


@router.post(
    "/v1beta/models/{model_name:path}:generateContent",
    dependencies=[Depends(user_api_key_auth)],
)
@router.post(
    "/models/{model_name:path}:generateContent",
    dependencies=[Depends(user_api_key_auth)],
)
async def google_generate_content(
    request: Request,
    model_name: str,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
    from litellm.proxy.proxy_server import (
        general_settings,
        llm_router,
        proxy_config,
        version,
    )

    data = await _read_request_body(request=request)
    if "model" not in data:
        data["model"] = model_name

    # Extract generationConfig and pass it as the config parameter
    generation_config = data.pop("generationConfig", None)
    if generation_config:
        data["config"] = generation_config

    # Add user authentication metadata for cost tracking
    data = await add_litellm_data_to_request(
        data=data,
        request=request,
        user_api_key_dict=user_api_key_dict,
        proxy_config=proxy_config,
        general_settings=general_settings,
        version=version,
    )

    # Create logging object with full request metadata so callbacks (e.g. S3) get user/trace_id
    data["litellm_call_id"] = request.headers.get(
        "x-litellm-call-id", str(uuid.uuid4())
    )
    logging_obj, data = litellm.utils.function_setup(
        original_function="agenerate_content",
        rules_obj=litellm.utils.Rules(),
        start_time=datetime.now(),
        **data,
    )
    data["litellm_logging_obj"] = logging_obj

    # call router
    if llm_router is None:
        raise HTTPException(status_code=500, detail="Router not initialized")

    response = await llm_router.agenerate_content(**data)
    return response
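
# Usage sketch (illustrative only, not executed): calling generateContent through the proxy.
# The base URL and virtual key are assumptions; "generationConfig" in the request body is
# remapped to the `config` parameter by the handler above before the router call.
#
#   curl -X POST "http://localhost:4000/v1beta/models/gemini-2.5-flash:generateContent" \
#     -H "Authorization: Bearer sk-1234" \
#     -H "Content-Type: application/json" \
#     -d '{
#       "contents": [{"parts": [{"text": "Hello, how are you?"}]}],
#       "generationConfig": {"temperature": 0.2}
#     }'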


@router.post(
    "/v1beta/models/{model_name:path}:streamGenerateContent",
    dependencies=[Depends(user_api_key_auth)],
)
@router.post(
    "/models/{model_name:path}:streamGenerateContent",
    dependencies=[Depends(user_api_key_auth)],
)
async def google_stream_generate_content(
    request: Request,
    model_name: str,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
    from litellm.proxy.proxy_server import (
        general_settings,
        llm_router,
        proxy_config,
        version,
    )

    data = await _read_request_body(request=request)
    if "model" not in data:
        data["model"] = model_name
    data["stream"] = True  # enforce streaming for this endpoint

    # Extract generationConfig and pass it as the config parameter
    generation_config = data.pop("generationConfig", None)
    if generation_config:
        data["config"] = generation_config

    # Add user authentication metadata for cost tracking
    data = await add_litellm_data_to_request(
        data=data,
        request=request,
        user_api_key_dict=user_api_key_dict,
        proxy_config=proxy_config,
        general_settings=general_settings,
        version=version,
    )

    # Create logging object with full request metadata so streaming END callbacks (e.g. S3) get user/trace_id
    data["litellm_call_id"] = request.headers.get(
        "x-litellm-call-id", str(uuid.uuid4())
    )
    logging_obj, data = litellm.utils.function_setup(
        original_function="agenerate_content_stream",
        rules_obj=litellm.utils.Rules(),
        start_time=datetime.now(),
        **data,
    )
    data["litellm_logging_obj"] = logging_obj

    # call router
    if llm_router is None:
        raise HTTPException(status_code=500, detail="Router not initialized")
    response = await llm_router.agenerate_content_stream(**data)

    # Check if response is an async iterator (streaming response)
    if response is not None and hasattr(response, "__aiter__"):
        return StreamingResponse(content=response, media_type="text/event-stream")
    return response
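
# Usage sketch (illustrative only): streaming variant, same assumptions as above (proxy URL,
# key). The handler forces "stream": true and returns server-sent events whenever the router
# yields an async iterator.
#
#   curl -N -X POST "http://localhost:4000/v1beta/models/gemini-2.5-flash:streamGenerateContent" \
#     -H "Authorization: Bearer sk-1234" \
#     -H "Content-Type: application/json" \
#     -d '{"contents": [{"parts": [{"text": "Tell me a short story"}]}]}'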


@router.post(
    "/v1beta/models/{model_name:path}:countTokens",
    dependencies=[Depends(user_api_key_auth)],
    response_model=TokenCountDetailsResponse,
)
@router.post(
    "/models/{model_name:path}:countTokens",
    dependencies=[Depends(user_api_key_auth)],
    response_model=TokenCountDetailsResponse,
)
async def google_count_tokens(request: Request, model_name: str):
    """
    Count tokens for a Google GenAI-style request body. Example response:

    ```json
    {
        "totalTokens": 31,
        "totalBillableCharacters": 96,
        "promptTokensDetails": [
            {
                "modality": "TEXT",
                "tokenCount": 31
            }
        ]
    }
    ```
    """
    from litellm.google_genai.adapters.transformation import GoogleGenAIAdapter
    from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
    from litellm.proxy.proxy_server import token_counter as internal_token_counter

    data = await _read_request_body(request=request)
    contents = data.get("contents", [])

    # Create a TokenCountRequest for the internal endpoint
    from litellm.proxy._types import TokenCountRequest

    # Translate contents to OpenAI-format messages using the adapter
    messages = (
        GoogleGenAIAdapter()
        .translate_generate_content_to_completion(model_name, contents)
        .get("messages", [])
    )
    token_request = TokenCountRequest(
        model=model_name,
        contents=contents,
        messages=messages,  # for compatibility with OpenAI-like token counting
    )

    # Call the internal token counter; call_endpoint=True marks this as an endpoint-originated request
    token_response = await internal_token_counter(
        request=token_request,
        call_endpoint=True,
    )
    if token_response is not None:
        # Cast the response to the well-known format
        original_response: dict = token_response.original_response or {}
        if original_response:
            return TokenCountDetailsResponse(
                totalTokens=original_response.get("totalTokens", 0),
                promptTokensDetails=original_response.get("promptTokensDetails", []),
            )
        else:
            return TokenCountDetailsResponse(
                totalTokens=token_response.total_tokens or 0,
                promptTokensDetails=[],
            )

    #########################################################
    # Return the response in the well known format
    #########################################################
    return TokenCountDetailsResponse(
        totalTokens=0,
        promptTokensDetails=[],
    )
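
# Usage sketch (illustrative only): counting tokens for a request body. The proxy URL and key
# are assumptions; only "contents" is read from the body, and it is also translated to
# OpenAI-style messages for compatibility with OpenAI-like token counting.
#
#   curl -X POST "http://localhost:4000/v1beta/models/gemini-2.5-flash:countTokens" \
#     -H "Authorization: Bearer sk-1234" \
#     -H "Content-Type: application/json" \
#     -d '{"contents": [{"parts": [{"text": "Hello, how are you?"}]}]}'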


# ============================================================
# Google Interactions API Endpoints
# Per OpenAPI spec: https://ai.google.dev/static/api/interactions.openapi.json
# ============================================================
@router.post(
    "/v1beta/interactions",
    dependencies=[Depends(user_api_key_auth)],
    response_class=ORJSONResponse,
    tags=["interactions"],
)
@router.post(
    "/interactions",
    dependencies=[Depends(user_api_key_auth)],
    response_class=ORJSONResponse,
    tags=["interactions"],
)
async def create_interaction(
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Create a new interaction using Google's Interactions API.

    Per OpenAPI spec: POST /{api_version}/interactions

    Supports both model interactions and agent interactions:
    - Model: Provide `model` parameter (e.g., "gemini-2.5-flash")
    - Agent: Provide `agent` parameter (e.g., "deep-research-pro-preview-12-2025")

    Example:
    ```bash
    curl -X POST "http://localhost:4000/v1beta/interactions" \
        -H "Authorization: Bearer sk-1234" \
        -H "Content-Type: application/json" \
        -d '{
            "model": "gemini/gemini-2.5-flash",
            "input": "Hello, how are you?"
        }'
    ```
    """
    from litellm.proxy.proxy_server import (
        general_settings,
        llm_router,
        proxy_config,
        proxy_logging_obj,
        select_data_generator,
        user_api_base,
        user_max_tokens,
        user_model,
        user_request_timeout,
        user_temperature,
        version,
    )

    data = await _read_request_body(request=request)

    # Default to gemini provider for interactions
    if "custom_llm_provider" not in data:
        data["custom_llm_provider"] = "gemini"

    processor = ProxyBaseLLMRequestProcessing(data=data)
    try:
        return await processor.base_process_llm_request(
            request=request,
            fastapi_response=fastapi_response,
            user_api_key_dict=user_api_key_dict,
            route_type="acreate_interaction",
            proxy_logging_obj=proxy_logging_obj,
            llm_router=llm_router,
            general_settings=general_settings,
            proxy_config=proxy_config,
            select_data_generator=select_data_generator,
            model=data.get("model") or data.get("agent"),
            user_model=user_model,
            user_temperature=user_temperature,
            user_request_timeout=user_request_timeout,
            user_max_tokens=user_max_tokens,
            user_api_base=user_api_base,
            version=version,
        )
    except Exception as e:
        raise await processor._handle_llm_api_exception(
            e=e,
            user_api_key_dict=user_api_key_dict,
            proxy_logging_obj=proxy_logging_obj,
            version=version,
        )
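
# Agent-interaction sketch (illustrative only): the docstring above shows the model form; per
# the same handler, an agent interaction passes `agent` instead of `model`. The agent name is
# the example value from the docstring, and the proxy URL and key are placeholders.
#
#   curl -X POST "http://localhost:4000/v1beta/interactions" \
#     -H "Authorization: Bearer sk-1234" \
#     -H "Content-Type: application/json" \
#     -d '{
#       "agent": "deep-research-pro-preview-12-2025",
#       "input": "Summarize recent work on long-context attention"
#     }'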


@router.get(
    "/v1beta/interactions/{interaction_id}",
    dependencies=[Depends(user_api_key_auth)],
    response_class=ORJSONResponse,
    tags=["interactions"],
)
@router.get(
    "/interactions/{interaction_id}",
    dependencies=[Depends(user_api_key_auth)],
    response_class=ORJSONResponse,
    tags=["interactions"],
)
async def get_interaction(
    request: Request,
    interaction_id: str,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get an interaction by ID.

    Per OpenAPI spec: GET /{api_version}/interactions/{interaction_id}
    """
    from litellm.proxy.proxy_server import (
        general_settings,
        llm_router,
        proxy_config,
        proxy_logging_obj,
        select_data_generator,
        user_api_base,
        user_max_tokens,
        user_model,
        user_request_timeout,
        user_temperature,
        version,
    )

    data = {"interaction_id": interaction_id, "custom_llm_provider": "gemini"}
    processor = ProxyBaseLLMRequestProcessing(data=data)
    try:
        return await processor.base_process_llm_request(
            request=request,
            fastapi_response=fastapi_response,
            user_api_key_dict=user_api_key_dict,
            route_type="aget_interaction",
            proxy_logging_obj=proxy_logging_obj,
            llm_router=llm_router,
            general_settings=general_settings,
            proxy_config=proxy_config,
            select_data_generator=select_data_generator,
            model=None,
            user_model=user_model,
            user_temperature=user_temperature,
            user_request_timeout=user_request_timeout,
            user_max_tokens=user_max_tokens,
            user_api_base=user_api_base,
            version=version,
        )
    except Exception as e:
        raise await processor._handle_llm_api_exception(
            e=e,
            user_api_key_dict=user_api_key_dict,
            proxy_logging_obj=proxy_logging_obj,
            version=version,
        )
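
# Usage sketch (illustrative only): retrieving an interaction by ID. The interaction ID,
# proxy URL, and key are placeholders; the ID comes from a prior create_interaction response.
#
#   curl -X GET "http://localhost:4000/v1beta/interactions/interactions-123" \
#     -H "Authorization: Bearer sk-1234"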


@router.delete(
    "/v1beta/interactions/{interaction_id}",
    dependencies=[Depends(user_api_key_auth)],
    response_class=ORJSONResponse,
    tags=["interactions"],
)
@router.delete(
    "/interactions/{interaction_id}",
    dependencies=[Depends(user_api_key_auth)],
    response_class=ORJSONResponse,
    tags=["interactions"],
)
async def delete_interaction(
    request: Request,
    interaction_id: str,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Delete an interaction by ID.

    Per OpenAPI spec: DELETE /{api_version}/interactions/{interaction_id}
    """
    from litellm.proxy.proxy_server import (
        general_settings,
        llm_router,
        proxy_config,
        proxy_logging_obj,
        select_data_generator,
        user_api_base,
        user_max_tokens,
        user_model,
        user_request_timeout,
        user_temperature,
        version,
    )

    data = {"interaction_id": interaction_id, "custom_llm_provider": "gemini"}
    processor = ProxyBaseLLMRequestProcessing(data=data)
    try:
        return await processor.base_process_llm_request(
            request=request,
            fastapi_response=fastapi_response,
            user_api_key_dict=user_api_key_dict,
            route_type="adelete_interaction",
            proxy_logging_obj=proxy_logging_obj,
            llm_router=llm_router,
            general_settings=general_settings,
            proxy_config=proxy_config,
            select_data_generator=select_data_generator,
            model=None,
            user_model=user_model,
            user_temperature=user_temperature,
            user_request_timeout=user_request_timeout,
            user_max_tokens=user_max_tokens,
            user_api_base=user_api_base,
            version=version,
        )
    except Exception as e:
        raise await processor._handle_llm_api_exception(
            e=e,
            user_api_key_dict=user_api_key_dict,
            proxy_logging_obj=proxy_logging_obj,
            version=version,
        )
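
# Usage sketch (illustrative only): deleting an interaction by ID (placeholders as above).
#
#   curl -X DELETE "http://localhost:4000/v1beta/interactions/interactions-123" \
#     -H "Authorization: Bearer sk-1234"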


@router.post(
    "/v1beta/interactions/{interaction_id}/cancel",
    dependencies=[Depends(user_api_key_auth)],
    response_class=ORJSONResponse,
    tags=["interactions"],
)
@router.post(
    "/interactions/{interaction_id}/cancel",
    dependencies=[Depends(user_api_key_auth)],
    response_class=ORJSONResponse,
    tags=["interactions"],
)
async def cancel_interaction(
    request: Request,
    interaction_id: str,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Cancel an interaction by ID.

    Per OpenAPI spec: POST /{api_version}/interactions/{interaction_id}:cancel
    """
    from litellm.proxy.proxy_server import (
        general_settings,
        llm_router,
        proxy_config,
        proxy_logging_obj,
        select_data_generator,
        user_api_base,
        user_max_tokens,
        user_model,
        user_request_timeout,
        user_temperature,
        version,
    )

    data = {"interaction_id": interaction_id, "custom_llm_provider": "gemini"}
    processor = ProxyBaseLLMRequestProcessing(data=data)
    try:
        return await processor.base_process_llm_request(
            request=request,
            fastapi_response=fastapi_response,
            user_api_key_dict=user_api_key_dict,
            route_type="acancel_interaction",
            proxy_logging_obj=proxy_logging_obj,
            llm_router=llm_router,
            general_settings=general_settings,
            proxy_config=proxy_config,
            select_data_generator=select_data_generator,
            model=None,
            user_model=user_model,
            user_temperature=user_temperature,
            user_request_timeout=user_request_timeout,
            user_max_tokens=user_max_tokens,
            user_api_base=user_api_base,
            version=version,
        )
    except Exception as e:
        raise await processor._handle_llm_api_exception(
            e=e,
            user_api_key_dict=user_api_key_dict,
            proxy_logging_obj=proxy_logging_obj,
            version=version,
        )
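
# Usage sketch (illustrative only): cancelling an in-progress interaction by ID (placeholders
# as above). Note the proxy registers the route as ".../{interaction_id}/cancel", even though
# the OpenAPI spec referenced in the docstring writes it as ":cancel".
#
#   curl -X POST "http://localhost:4000/v1beta/interactions/interactions-123/cancel" \
#     -H "Authorization: Bearer sk-1234"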