chore: initial snapshot for gitea/github upload

This commit is contained in:
Your Name
2026-03-26 16:04:46 +08:00
commit a699a1ac98
3497 changed files with 1586237 additions and 0 deletions

View File

@@ -0,0 +1,123 @@
# LiteLLM Google GenAI Interface
Interact with Google GenAI functions through LiteLLM using the native Google request/response format.
## Overview
This module provides a native interface to Google's Generative AI API, allowing you to use Google's content generation capabilities with both streaming and non-streaming modes, in both synchronous and asynchronous contexts.
## Available Functions
### Non-Streaming Functions
- `generate_content()` - Synchronous content generation
- `agenerate_content()` - Asynchronous content generation
### Streaming Functions
- `generate_content_stream()` - Synchronous streaming content generation
- `agenerate_content_stream()` - Asynchronous streaming content generation
## Usage Examples
### Basic Non-Streaming Usage
```python
from litellm.google_genai import generate_content, agenerate_content
from google.genai.types import ContentDict, PartDict
# Synchronous usage
contents = ContentDict(
parts=[
PartDict(text="Hello, can you tell me a short joke?")
],
)
response = generate_content(
contents=contents,
model="gemini-pro", # or your preferred model
# Add other model-specific parameters as needed
)
print(response)
```
### Async Non-Streaming Usage
```python
import asyncio
from litellm.google_genai import agenerate_content
from google.genai.types import ContentDict, PartDict
async def main():
contents = ContentDict(
parts=[
PartDict(text="Hello, can you tell me a short joke?")
],
)
response = await agenerate_content(
contents=contents,
model="gemini-pro",
# Add other model-specific parameters as needed
)
print(response)
# Run the async function
asyncio.run(main())
```
### Streaming Usage
```python
from litellm.google_genai import generate_content_stream
from google.genai.types import ContentDict, PartDict
# Synchronous streaming
contents = ContentDict(
parts=[
PartDict(text="Tell me a story about space exploration")
],
)
for chunk in generate_content_stream(
contents=contents,
model="gemini-pro",
):
print(f"Chunk: {chunk}")
```
### Async Streaming Usage
```python
import asyncio
from litellm.google_genai import agenerate_content_stream
from google.genai.types import ContentDict, PartDict
async def main():
contents = ContentDict(
parts=[
PartDict(text="Tell me a story about space exploration")
],
)
async for chunk in agenerate_content_stream(
contents=contents,
model="gemini-pro",
):
print(f"Async chunk: {chunk}")
asyncio.run(main())
```
## Testing
This module includes comprehensive tests covering:
- Sync and async non-streaming requests
- Sync and async streaming requests
- Response validation
- Error handling scenarios
See `tests/unified_google_tests/base_google_test.py` for test implementation examples.

View File

@@ -0,0 +1,19 @@
"""
This allows using Google GenAI model in their native interface.
This module provides generate_content functionality for Google GenAI models.
"""
from .main import (
agenerate_content,
agenerate_content_stream,
generate_content,
generate_content_stream,
)
__all__ = [
"generate_content",
"agenerate_content",
"generate_content_stream",
"agenerate_content_stream",
]

View File

@@ -0,0 +1,19 @@
"""
Google GenAI Adapters for LiteLLM
This module provides adapters for transforming Google GenAI generate_content requests
to/from LiteLLM completion format with full support for:
- Text content transformation
- Tool calling (function declarations, function calls, function responses)
- Streaming (both regular and tool calling)
- Mixed content (text + tool calls)
"""
from .handler import GenerateContentToCompletionHandler
from .transformation import GoogleGenAIAdapter, GoogleGenAIStreamWrapper
__all__ = [
"GoogleGenAIAdapter",
"GoogleGenAIStreamWrapper",
"GenerateContentToCompletionHandler",
]

View File

@@ -0,0 +1,183 @@
from typing import Any, AsyncIterator, Coroutine, Dict, List, Optional, Union, cast
import litellm
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import ModelResponse
from .transformation import GoogleGenAIAdapter
# Initialize adapter
GOOGLE_GENAI_ADAPTER = GoogleGenAIAdapter()
class GenerateContentToCompletionHandler:
    """Handler for transforming generate_content calls to completion format when provider config is None.

    The Google GenAI request is translated into OpenAI chat-completion format,
    dispatched through ``litellm.completion`` / ``litellm.acompletion``, and
    the result is translated back into generate_content format.
    """

    @staticmethod
    def _prepare_completion_kwargs(
        model: str,
        contents: Union[List[Dict[str, Any]], Dict[str, Any]],
        config: Optional[Dict[str, Any]] = None,
        stream: bool = False,
        litellm_params: Optional[GenericLiteLLMParams] = None,
        extra_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Prepare kwargs for litellm.completion/acompletion.

        Translates the generate_content request to completion format and
        forwards selected passthrough kwargs (``metadata``, ``extra_headers``).
        """
        completion_request = (
            GOOGLE_GENAI_ADAPTER.translate_generate_content_to_completion(
                model=model,
                contents=contents,
                config=config,
                litellm_params=litellm_params,
                **(extra_kwargs or {}),
            )
        )
        completion_kwargs: Dict[str, Any] = dict(completion_request)
        if extra_kwargs is not None:
            # Forward metadata so custom callbacks keep working.
            if "metadata" in extra_kwargs:
                completion_kwargs["metadata"] = extra_kwargs["metadata"]
            # Forward extra_headers for providers that require custom headers
            # (e.g., github_copilot).
            if "extra_headers" in extra_kwargs:
                completion_kwargs["extra_headers"] = extra_kwargs["extra_headers"]
        if stream:
            completion_kwargs["stream"] = stream
        return completion_kwargs

    @staticmethod
    def _transform_completion_response(
        completion_response: Any,
        stream: bool,
        is_async: bool,
    ) -> Union[Dict[str, Any], AsyncIterator[bytes]]:
        """Translate a completion result back into generate_content format.

        Shared by the sync and async handlers. When ``stream`` is True but the
        response is not actually iterable (error cases, or when streaming is
        not properly supported), it is translated as a regular response.
        """
        iter_attr = "__aiter__" if is_async else "__iter__"
        if stream and hasattr(completion_response, iter_attr):
            # Transform streaming completion response to generate_content format.
            transformed_stream = (
                GOOGLE_GENAI_ADAPTER.translate_completion_output_params_streaming(
                    completion_response
                )
            )
            if transformed_stream is not None:
                return transformed_stream
            raise ValueError("Failed to transform streaming response")
        # Non-streaming (or non-iterable) response: translate directly.
        return GOOGLE_GENAI_ADAPTER.translate_completion_to_generate_content(
            cast(ModelResponse, completion_response)
        )

    @staticmethod
    async def async_generate_content_handler(
        model: str,
        contents: Union[List[Dict[str, Any]], Dict[str, Any]],
        litellm_params: GenericLiteLLMParams,
        config: Optional[Dict[str, Any]] = None,
        stream: bool = False,
        **kwargs,
    ) -> Union[Dict[str, Any], AsyncIterator[bytes]]:
        """Handle generate_content call asynchronously using completion adapter."""
        completion_kwargs = (
            GenerateContentToCompletionHandler._prepare_completion_kwargs(
                model=model,
                contents=contents,
                config=config,
                stream=stream,
                litellm_params=litellm_params,
                extra_kwargs=kwargs,
            )
        )
        try:
            completion_response = await litellm.acompletion(**completion_kwargs)
            return GenerateContentToCompletionHandler._transform_completion_response(
                completion_response, stream=stream, is_async=True
            )
        except Exception as e:
            # Chain the original exception so the root cause stays visible.
            raise ValueError(
                f"Error calling litellm.acompletion for generate_content: {str(e)}"
            ) from e

    @staticmethod
    def generate_content_handler(
        model: str,
        contents: Union[List[Dict[str, Any]], Dict[str, Any]],
        litellm_params: GenericLiteLLMParams,
        config: Optional[Dict[str, Any]] = None,
        stream: bool = False,
        _is_async: bool = False,
        **kwargs,
    ) -> Union[
        Dict[str, Any],
        AsyncIterator[bytes],
        Coroutine[Any, Any, Union[Dict[str, Any], AsyncIterator[bytes]]],
    ]:
        """Handle generate_content call using completion adapter.

        When ``_is_async`` is True, returns the coroutine from the async
        handler instead of executing synchronously.
        """
        if _is_async:
            return GenerateContentToCompletionHandler.async_generate_content_handler(
                model=model,
                contents=contents,
                config=config,
                stream=stream,
                litellm_params=litellm_params,
                **kwargs,
            )
        completion_kwargs = (
            GenerateContentToCompletionHandler._prepare_completion_kwargs(
                model=model,
                contents=contents,
                config=config,
                stream=stream,
                litellm_params=litellm_params,
                extra_kwargs=kwargs,
            )
        )
        try:
            completion_response = litellm.completion(**completion_kwargs)
            return GenerateContentToCompletionHandler._transform_completion_response(
                completion_response, stream=stream, is_async=False
            )
        except Exception as e:
            # Chain the original exception so the root cause stays visible.
            raise ValueError(
                f"Error calling litellm.completion for generate_content: {str(e)}"
            ) from e

View File

@@ -0,0 +1,783 @@
import json
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union, cast
from litellm import verbose_logger
from litellm.litellm_core_utils.json_validation_rule import normalize_tool_schema
from litellm.types.llms.openai import (
AllMessageValues,
ChatCompletionAssistantMessage,
ChatCompletionAssistantToolCall,
ChatCompletionImageObject,
ChatCompletionRequest,
ChatCompletionSystemMessage,
ChatCompletionTextObject,
ChatCompletionToolCallFunctionChunk,
ChatCompletionToolChoiceValues,
ChatCompletionToolMessage,
ChatCompletionToolParam,
ChatCompletionUserMessage,
)
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import (
AdapterCompletionStreamWrapper,
Choices,
ModelResponse,
ModelResponseStream,
StreamingChoices,
)
class GoogleGenAIStreamWrapper(AdapterCompletionStreamWrapper):
    """
    Wrapper for streaming Google GenAI generate_content responses.
    Transforms OpenAI streaming chunks to Google GenAI format.
    """

    sent_first_chunk: bool = False
    # State tracking for accumulating partial tool calls across chunks,
    # keyed by the tool call's streaming index.
    accumulated_tool_calls: Dict[str, Dict[str, Any]]

    def __init__(self, completion_stream: Any):
        self.sent_first_chunk = False
        self.accumulated_tool_calls = {}
        # Guards the one-shot fallback used when completion_stream is not
        # actually iterable (e.g. a plain ModelResponse from an error path).
        self._returned_response = False
        super().__init__(completion_stream)

    def _flush_accumulated_tool_calls(self) -> Optional[Dict[str, Any]]:
        """Build a final chunk from tool calls still buffered at stream end.

        Returns None when nothing parseable is buffered. The accumulator is
        always cleared to prevent memory leaks.
        """
        if not self.accumulated_tool_calls:
            return None
        try:
            parts = []
            for tool_call_index, tool_call_data in self.accumulated_tool_calls.items():
                try:
                    # For tool calls with no arguments, the accumulated args
                    # string is "", which is not valid JSON; default to an
                    # empty JSON object in that case.
                    parsed_args = json.loads(tool_call_data["arguments"] or "{}")
                    parts.append(
                        {
                            "functionCall": {
                                "name": tool_call_data["name"]
                                or "undefined_tool_name",
                                "args": parsed_args,
                            }
                        }
                    )
                except json.JSONDecodeError:
                    # Can happen if the stream is abruptly cut off
                    # mid-argument string.
                    verbose_logger.warning(
                        f"Could not parse tool call arguments at end of stream for index {tool_call_index}. "
                        f"Name: {tool_call_data['name']}. "
                        f"Partial args: {tool_call_data['arguments']}"
                    )
            if not parts:
                return None
            return {
                "candidates": [
                    {
                        "content": {"parts": parts, "role": "model"},
                        "finishReason": "STOP",
                        "index": 0,
                        "safetyRatings": [],
                    }
                ]
            }
        finally:
            # Ensure the accumulator is always cleared to prevent memory leaks.
            self.accumulated_tool_calls.clear()

    def __next__(self):
        try:
            if not hasattr(self.completion_stream, "__iter__"):
                if self._returned_response:
                    raise StopIteration
                self._returned_response = True
                return GoogleGenAIAdapter().translate_completion_to_generate_content(
                    self.completion_stream
                )
            for chunk in self.completion_stream:
                if chunk == "None" or chunk is None:
                    continue
                transformed_chunk = GoogleGenAIAdapter().translate_streaming_completion_to_generate_content(
                    chunk, self
                )
                if transformed_chunk:
                    return transformed_chunk
            # Stream exhausted: flush any partially accumulated tool calls so
            # the sync path matches the async path's behavior.
            final_chunk = self._flush_accumulated_tool_calls()
            if final_chunk is not None:
                return final_chunk
            raise StopIteration
        except StopIteration:
            raise
        except Exception:
            # Any unexpected error terminates iteration.
            raise StopIteration

    async def __anext__(self):
        try:
            if not hasattr(self.completion_stream, "__aiter__"):
                if self._returned_response:
                    raise StopAsyncIteration
                self._returned_response = True
                return GoogleGenAIAdapter().translate_completion_to_generate_content(
                    self.completion_stream
                )
            async for chunk in self.completion_stream:
                if chunk == "None" or chunk is None:
                    continue
                transformed_chunk = GoogleGenAIAdapter().translate_streaming_completion_to_generate_content(
                    chunk, self
                )
                if transformed_chunk:
                    return transformed_chunk
            # After the stream is exhausted, emit any remaining accumulated
            # tool calls as a final chunk.
            final_chunk = self._flush_accumulated_tool_calls()
            if final_chunk is not None:
                return final_chunk
            raise StopAsyncIteration
        except StopAsyncIteration:
            raise
        except Exception:
            # Any unexpected error terminates iteration.
            raise StopAsyncIteration

    def google_genai_sse_wrapper(self) -> Iterator[bytes]:
        """
        Convert Google GenAI streaming chunks to Server-Sent Events format.
        """
        for chunk in self.completion_stream:
            if isinstance(chunk, dict):
                payload = f"data: {json.dumps(chunk)}\n\n"
                yield payload.encode()
            else:
                yield chunk

    async def async_google_genai_sse_wrapper(self) -> AsyncIterator[bytes]:
        """
        Async version of google_genai_sse_wrapper.
        """
        async for chunk in self.completion_stream:
            if isinstance(chunk, dict):
                payload = f"data: {json.dumps(chunk)}\n\n"
                yield payload.encode()
            elif isinstance(chunk, ModelResponseStream):
                # Transform OpenAI streaming chunk to Google GenAI format.
                transformed_chunk = GoogleGenAIAdapter().translate_streaming_completion_to_generate_content(
                    chunk, self
                )
                if isinstance(transformed_chunk, dict):
                    # Only emit non-empty chunks.
                    payload = f"data: {json.dumps(transformed_chunk)}\n\n"
                    yield payload.encode()
                else:
                    # Empty chunk: continue to next iteration.
                    continue
            else:
                # For other chunk types, yield them directly as bytes.
                if hasattr(chunk, "encode"):
                    yield chunk.encode()
                else:
                    yield str(chunk).encode()
class GoogleGenAIAdapter:
    """Adapter for transforming Google GenAI generate_content requests to/from litellm.completion format"""

    def __init__(self) -> None:
        pass

    def translate_generate_content_to_completion(
        self,
        model: str,
        contents: Union[List[Dict[str, Any]], Dict[str, Any]],
        config: Optional[Dict[str, Any]] = None,
        litellm_params: Optional[GenericLiteLLMParams] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Transform generate_content request to litellm completion format

        Args:
            model: The model name
            contents: Generate content contents (can be list or single dict)
            config: Optional config parameters
            litellm_params: Optional litellm-specific params (api_key, api_base, ...)
            **kwargs: Additional parameters from the original request

        Returns:
            Dict in OpenAI format
        """
        # Top-level fields may arrive in camelCase (wire format) or
        # snake_case (python SDK) form; accept both.
        system_instruction = kwargs.get("systemInstruction") or kwargs.get(
            "system_instruction"
        )
        tools = kwargs.get("tools")
        tool_config = kwargs.get("toolConfig") or kwargs.get("tool_config")

        # Normalize contents to list format.
        contents_list = [contents] if isinstance(contents, dict) else contents

        # Transform contents to OpenAI messages format.
        messages = self._transform_contents_to_messages(
            contents_list, system_instruction=system_instruction
        )

        # Base request as a dict (compatible with ChatCompletionRequest).
        completion_request: ChatCompletionRequest = {
            "model": model,
            "messages": messages,
        }

        #########################################################
        # Supported OpenAI chat completion params:
        # temperature, max_tokens, top_p, stop, tools, tool_choice
        #########################################################
        if config:
            # Map common Google GenAI config parameters to OpenAI equivalents.
            if "temperature" in config:
                completion_request["temperature"] = config["temperature"]
            if "maxOutputTokens" in config:
                completion_request["max_tokens"] = config["maxOutputTokens"]
            if "topP" in config:
                completion_request["top_p"] = config["topP"]
            if "topK" in config:
                # OpenAI chat completions has no direct topK equivalent; dropped.
                pass
            if "stopSequences" in config:
                completion_request["stop"] = config["stopSequences"]

        # Handle tools transformation (Google GenAI -> OpenAI format).
        if tools and isinstance(tools, list) and len(tools) > 0:
            openai_tools = self._transform_google_genai_tools_to_openai(tools)
            if openai_tools:
                completion_request["tools"] = openai_tools

        # Handle tool_config (tool choice).
        if tool_config:
            tool_choice = self._transform_google_genai_tool_config_to_openai(
                tool_config
            )
            if tool_choice:
                completion_request["tool_choice"] = tool_choice

        # Forward any litellm specific params (api_base, api_key, ...).
        completion_request_dict = dict(completion_request)
        if litellm_params:
            completion_request_dict = self._add_generic_litellm_params_to_request(
                completion_request_dict=completion_request_dict,
                litellm_params=litellm_params,
            )
        return completion_request_dict

    def _add_generic_litellm_params_to_request(
        self,
        completion_request_dict: Dict[str, Any],
        litellm_params: Optional[GenericLiteLLMParams] = None,
    ) -> dict:
        """Add generic litellm params to request, e.g. api_base, api_key, api_version.

        Only fields declared on GenericLiteLLMParams are forwarded; None
        values are excluded.
        """
        allowed_fields = GenericLiteLLMParams.model_fields.keys()
        if litellm_params:
            litellm_dict = litellm_params.model_dump(exclude_none=True)
            for key, value in litellm_dict.items():
                if key in allowed_fields:
                    completion_request_dict[key] = value
        return completion_request_dict

    def translate_completion_output_params_streaming(
        self,
        completion_stream: Any,
    ) -> Union[AsyncIterator[bytes], None]:
        """Transform streaming completion output to Google GenAI SSE format."""
        google_genai_wrapper = GoogleGenAIStreamWrapper(
            completion_stream=completion_stream
        )
        # Return the SSE-wrapped version for proper event formatting.
        return google_genai_wrapper.async_google_genai_sse_wrapper()

    def _transform_google_genai_tools_to_openai(
        self,
        tools: List[Dict[str, Any]],
    ) -> List[ChatCompletionToolParam]:
        """Transform Google GenAI tools to OpenAI tools format.

        Accepts the JSON schema under either "parametersJsonSchema" or the
        standard Google GenAI "parameters" field.
        """
        openai_tools: List[Dict[str, Any]] = []
        for tool in tools:
            if "functionDeclarations" not in tool:
                continue
            for func_decl in tool["functionDeclarations"]:
                function_chunk: Dict[str, Any] = {
                    "name": func_decl.get("name", ""),
                }
                if "description" in func_decl:
                    function_chunk["description"] = func_decl["description"]
                if "parametersJsonSchema" in func_decl:
                    function_chunk["parameters"] = func_decl["parametersJsonSchema"]
                elif "parameters" in func_decl:
                    # Google GenAI function declarations commonly carry the
                    # schema under "parameters".
                    function_chunk["parameters"] = func_decl["parameters"]
                openai_tools.append({"type": "function", "function": function_chunk})
        # Normalize the tool schemas before returning.
        normalized_tools = [normalize_tool_schema(tool) for tool in openai_tools]
        return cast(List[ChatCompletionToolParam], normalized_tools)

    def _transform_google_genai_tool_config_to_openai(
        self,
        tool_config: Dict[str, Any],
    ) -> Optional[ChatCompletionToolChoiceValues]:
        """Transform Google GenAI tool_config to OpenAI tool_choice."""
        function_calling_config = tool_config.get("functionCallingConfig", {})
        mode = function_calling_config.get("mode", "AUTO")
        mode_mapping = {"AUTO": "auto", "ANY": "required", "NONE": "none"}
        tool_choice = mode_mapping.get(mode, "auto")
        return cast(ChatCompletionToolChoiceValues, tool_choice)

    def _transform_contents_to_messages(
        self,
        contents: List[Dict[str, Any]],
        system_instruction: Optional[Dict[str, Any]] = None,
    ) -> List[AllMessageValues]:
        """Transform Google GenAI contents to OpenAI messages format."""
        messages: List[AllMessageValues] = []

        # Handle system instruction (only the first text part is used).
        if system_instruction:
            system_parts = system_instruction.get("parts", [])
            if system_parts and "text" in system_parts[0]:
                messages.append(
                    ChatCompletionSystemMessage(
                        role="system", content=system_parts[0]["text"]
                    )
                )

        for content in contents:
            role = content.get("role", "user")
            parts = content.get("parts", [])
            if role == "user":
                # User messages may mix text, images and function responses.
                content_parts: List[
                    Union[ChatCompletionTextObject, ChatCompletionImageObject]
                ] = []
                tool_messages: List[ChatCompletionToolMessage] = []
                for part in parts:
                    if isinstance(part, dict):
                        if "text" in part:
                            content_parts.append(
                                cast(
                                    ChatCompletionTextObject,
                                    {"type": "text", "text": part["text"]},
                                )
                            )
                        elif "inline_data" in part or "inlineData" in part:
                            # Base64 image data; key may be snake_case
                            # (python SDK) or camelCase (wire format).
                            inline_data = (
                                part.get("inline_data") or part.get("inlineData") or {}
                            )
                            mime_type = inline_data.get("mime_type", "image/jpeg")
                            data = inline_data.get("data", "")
                            content_parts.append(
                                cast(
                                    ChatCompletionImageObject,
                                    {
                                        "type": "image_url",
                                        "image_url": {
                                            "url": f"data:{mime_type};base64,{data}"
                                        },
                                    },
                                )
                            )
                        elif "functionResponse" in part:
                            # Transform function response to tool message.
                            func_response = part["functionResponse"]
                            tool_messages.append(
                                ChatCompletionToolMessage(
                                    role="tool",
                                    tool_call_id=f"call_{func_response.get('name', 'unknown')}",
                                    content=json.dumps(
                                        func_response.get("response", {})
                                    ),
                                )
                            )
                    elif isinstance(part, str):
                        content_parts.append(
                            cast(
                                ChatCompletionTextObject,
                                {"type": "text", "text": part},
                            )
                        )
                # Add user message if there's content.
                if content_parts:
                    # Single text part: use simple string format for
                    # backward compatibility.
                    if (
                        len(content_parts) == 1
                        and isinstance(content_parts[0], dict)
                        and content_parts[0].get("type") == "text"
                    ):
                        text_part = cast(ChatCompletionTextObject, content_parts[0])
                        messages.append(
                            ChatCompletionUserMessage(
                                role="user", content=text_part["text"]
                            )
                        )
                    else:
                        # Multimodal format (array of content parts).
                        messages.append(
                            ChatCompletionUserMessage(
                                role="user", content=content_parts
                            )
                        )
                # Tool messages follow the user message.
                messages.extend(tool_messages)
            elif role == "model":
                # Assistant messages may mix text and function calls.
                combined_text = ""
                tool_calls: List[ChatCompletionAssistantToolCall] = []
                for part in parts:
                    if isinstance(part, dict):
                        if "text" in part:
                            combined_text += part["text"]
                        elif "functionCall" in part:
                            # Transform function call to tool call.
                            func_call = part["functionCall"]
                            tool_calls.append(
                                ChatCompletionAssistantToolCall(
                                    id=f"call_{func_call.get('name', 'unknown')}",
                                    type="function",
                                    function=ChatCompletionToolCallFunctionChunk(
                                        name=func_call.get("name", ""),
                                        arguments=json.dumps(
                                            func_call.get("args", {})
                                        ),
                                    ),
                                )
                            )
                    elif isinstance(part, str):
                        combined_text += part
                if tool_calls:
                    assistant_message = ChatCompletionAssistantMessage(
                        role="assistant",
                        content=combined_text if combined_text else None,
                        tool_calls=tool_calls,
                    )
                else:
                    assistant_message = ChatCompletionAssistantMessage(
                        role="assistant",
                        content=combined_text if combined_text else None,
                    )
                messages.append(assistant_message)
        return messages

    def translate_completion_to_generate_content(
        self,
        response: ModelResponse,
    ) -> Dict[str, Any]:
        """
        Transform litellm completion response to Google GenAI generate_content format

        Args:
            response: ModelResponse from litellm.completion

        Returns:
            Dict in Google GenAI generate_content response format

        Raises:
            ValueError: if the response has no choices or no message.
        """
        choice = response.choices[0] if response.choices else None
        if not choice:
            raise ValueError("Invalid completion response: no choices found")

        # Handle different choice types (Choices vs StreamingChoices).
        if isinstance(choice, Choices):
            if not choice.message:
                raise ValueError(
                    "Invalid completion response: no message found in choice"
                )
            parts = self._transform_openai_message_to_google_genai_parts(choice.message)
        else:
            # Fallback for generic choice objects.
            message_content = getattr(choice, "message", {}).get(
                "content", ""
            ) or getattr(choice, "delta", {}).get("content", "")
            parts = [{"text": message_content}] if message_content else []

        generate_content_response: Dict[str, Any] = {
            "candidates": [
                {
                    "content": {"parts": parts, "role": "model"},
                    "finishReason": self._map_finish_reason(
                        getattr(choice, "finish_reason", None)
                    ),
                    "index": 0,
                    "safetyRatings": [],
                }
            ],
            "usageMetadata": (
                self._map_usage(getattr(response, "usage", None))
                if hasattr(response, "usage") and getattr(response, "usage", None)
                else {
                    "promptTokenCount": 0,
                    "candidatesTokenCount": 0,
                    "totalTokenCount": 0,
                }
            ),
        }

        # Add top-level text field for convenience (common in Google GenAI
        # responses).
        text_content = "".join(
            part["text"]
            for part in parts
            if isinstance(part, dict) and "text" in part
        )
        if text_content:
            generate_content_response["text"] = text_content
        return generate_content_response

    def translate_streaming_completion_to_generate_content(
        self,
        response: Union[ModelResponse, ModelResponseStream],
        wrapper: GoogleGenAIStreamWrapper,
    ) -> Optional[Dict[str, Any]]:
        """
        Transform streaming litellm completion chunk to Google GenAI generate_content format

        Args:
            response: Streaming ModelResponse chunk from litellm.completion
            wrapper: GoogleGenAIStreamWrapper instance (holds tool-call
                accumulation state)

        Returns:
            Dict in Google GenAI streaming format, or None for empty chunks.
        """
        choice = response.choices[0] if response.choices else None
        if not choice:
            # No choices -> nothing to emit for this chunk.
            return None

        if isinstance(choice, StreamingChoices):
            if choice.delta:
                parts = self._transform_openai_delta_to_google_genai_parts_with_accumulation(
                    choice.delta, wrapper
                )
            else:
                parts = []
            finish_reason = getattr(choice, "finish_reason", None)
        else:
            # Fallback for generic choice objects.
            message_content = getattr(choice, "delta", {}).get("content", "")
            parts = [{"text": message_content}] if message_content else []
            finish_reason = getattr(choice, "finish_reason", None)

        # Only emit a chunk when there are parts or this is the final chunk.
        if not parts and not finish_reason:
            return None

        streaming_chunk: Dict[str, Any] = {
            "candidates": [
                {
                    "content": {"parts": parts, "role": "model"},
                    "finishReason": (
                        self._map_finish_reason(finish_reason)
                        if finish_reason
                        else None
                    ),
                    "index": 0,
                    "safetyRatings": [],
                }
            ]
        }

        # Usage metadata only appears in the final chunk (finish_reason set).
        if finish_reason:
            streaming_chunk["usageMetadata"] = (
                self._map_usage(getattr(response, "usage", None))
                if hasattr(response, "usage") and getattr(response, "usage", None)
                else {
                    "promptTokenCount": 0,
                    "candidatesTokenCount": 0,
                    "totalTokenCount": 0,
                }
            )

        # Add top-level text field for convenience.
        text_content = "".join(
            part["text"]
            for part in parts
            if isinstance(part, dict) and "text" in part
        )
        if text_content:
            streaming_chunk["text"] = text_content
        return streaming_chunk

    def _transform_openai_message_to_google_genai_parts(
        self,
        message: Any,
    ) -> List[Dict[str, Any]]:
        """Transform OpenAI message to Google GenAI parts format."""
        parts: List[Dict[str, Any]] = []
        # Text content, if present.
        if hasattr(message, "content") and message.content:
            parts.append({"text": message.content})
        # Tool calls, if present.
        if hasattr(message, "tool_calls") and message.tool_calls:
            for tool_call in message.tool_calls:
                if hasattr(tool_call, "function") and tool_call.function:
                    try:
                        args = (
                            json.loads(tool_call.function.arguments)
                            if tool_call.function.arguments
                            else {}
                        )
                    except json.JSONDecodeError:
                        # Unparseable arguments degrade to an empty object.
                        args = {}
                    parts.append(
                        {
                            "functionCall": {
                                "name": tool_call.function.name
                                or "undefined_tool_name",
                                "args": args,
                            }
                        }
                    )
        # Google GenAI candidates always carry at least one part.
        return parts if parts else [{"text": ""}]

    def _transform_openai_delta_to_google_genai_parts_with_accumulation(
        self, delta: Any, wrapper: GoogleGenAIStreamWrapper
    ) -> List[Dict[str, Any]]:
        """Transforms OpenAI delta to Google GenAI parts, accumulating streaming tool calls.

        Tool-call fragments are keyed by their stream ``index`` on the
        wrapper until the accumulated arguments parse as complete JSON.
        """
        # Defensive: ensure the wrapper state exists.
        if not hasattr(wrapper, "accumulated_tool_calls"):
            wrapper.accumulated_tool_calls = {}

        parts: List[Dict[str, Any]] = []
        if hasattr(delta, "content") and delta.content:
            parts.append({"text": delta.content})

        tool_calls = delta.tool_calls or []
        for tool_call in tool_calls:
            if not hasattr(tool_call, "function"):
                continue
            # `index` is the only reliable key for streamed tool calls.
            tool_call_index = getattr(tool_call, "index", None)
            if tool_call_index is None:
                continue

            if tool_call_index not in wrapper.accumulated_tool_calls:
                wrapper.accumulated_tool_calls[tool_call_index] = {
                    "name": "",
                    "arguments": "",
                }

            function_name = getattr(tool_call.function, "name", None)
            args_chunk = getattr(tool_call.function, "arguments", None)

            # Skip chunks that carry no new data.
            if not function_name and not args_chunk:
                verbose_logger.debug(
                    f"Skipping empty tool call chunk for index: {tool_call_index}"
                )
                continue
            if function_name:
                wrapper.accumulated_tool_calls[tool_call_index]["name"] = function_name
            if args_chunk:
                wrapper.accumulated_tool_calls[tool_call_index][
                    "arguments"
                ] += args_chunk

            accumulated_data = wrapper.accumulated_tool_calls[tool_call_index]
            accumulated_name = accumulated_data["name"]
            accumulated_args = accumulated_data["arguments"]
            try:
                # Attempt to parse the accumulated arguments string; emit the
                # part only once both the name and complete JSON have arrived.
                parsed_args = json.loads(accumulated_args)
                if accumulated_name:
                    parts.append(
                        {
                            "functionCall": {
                                "name": accumulated_name,
                                "args": parsed_args,
                            }
                        }
                    )
                    # Completed: drop from the accumulator.
                    del wrapper.accumulated_tool_calls[tool_call_index]
            except json.JSONDecodeError:
                # Arguments JSON still incomplete; keep accumulating.
                pass
        return parts

    def _map_finish_reason(self, finish_reason: Optional[str]) -> str:
        """Map OpenAI finish reasons to Google GenAI finish reasons."""
        if not finish_reason:
            return "STOP"
        mapping = {
            "stop": "STOP",
            "length": "MAX_TOKENS",
            "content_filter": "SAFETY",
            "tool_calls": "STOP",
            "function_call": "STOP",
        }
        return mapping.get(finish_reason, "STOP")

    def _map_usage(self, usage: Any) -> Dict[str, int]:
        """Map OpenAI usage to Google GenAI usage format."""
        return {
            "promptTokenCount": getattr(usage, "prompt_tokens", 0) or 0,
            "candidatesTokenCount": getattr(usage, "completion_tokens", 0) or 0,
            "totalTokenCount": getattr(usage, "total_tokens", 0) or 0,
        }

View File

@@ -0,0 +1,548 @@
import asyncio
import contextvars
from functools import partial
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterator, Optional, Union
import httpx
from pydantic import BaseModel, ConfigDict
import litellm
from litellm.constants import request_timeout
# Import the adapter for fallback to completion format
from litellm.google_genai.adapters.handler import GenerateContentToCompletionHandler
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.base_llm.google_genai.transformation import (
BaseGoogleGenAIGenerateContentConfig,
)
from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
from litellm.types.router import GenericLiteLLMParams
from litellm.utils import ProviderConfigManager, client
if TYPE_CHECKING:
from litellm.types.google_genai.main import (
GenerateContentConfigDict,
GenerateContentContentListUnionDict,
GenerateContentResponse,
ToolConfigDict,
)
else:
GenerateContentConfigDict = Any
GenerateContentContentListUnionDict = Any
GenerateContentResponse = Any
ToolConfigDict = Any
####### ENVIRONMENT VARIABLES ###################
# Initialize any necessary instances or variables here
base_llm_http_handler = BaseLLMHTTPHandler()
#################################################
class GenerateContentSetupResult(BaseModel):
    """Internal Type - Result of setting up a generate content call"""

    # Allow non-pydantic field types (e.g. the LiteLLM logging object).
    model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)
    # Model name after provider resolution via litellm.get_llm_provider.
    model: str
    # Provider-ready request payload; empty dict on the adapter-fallback path.
    request_body: Dict[str, Any]
    # Resolved provider identifier.
    custom_llm_provider: str
    # Provider transformation config; None signals callers to fall back to the
    # completion adapter.
    generate_content_provider_config: Optional[BaseGoogleGenAIGenerateContentConfig]
    # Mapped optional params for the generateContent request.
    generate_content_config_dict: Dict[str, Any]
    # Generic LiteLLM params parsed from the caller's kwargs.
    litellm_params: GenericLiteLLMParams
    # Logging object used for pre/post call logging.
    litellm_logging_obj: LiteLLMLoggingObj
    # Unique id for this call, if supplied by the caller.
    litellm_call_id: Optional[str]
class GenerateContentHelper:
    """Helper class for Google GenAI generate content operations"""

    @staticmethod
    def mock_generate_content_response(
        mock_response: str = "This is a mock response from Google GenAI generate_content.",
    ) -> Dict[str, Any]:
        """Mock response for generate_content for testing purposes

        Returns a dict shaped like a generateContent response: top-level
        ``text``, one candidate carrying the same text, and fixed usage
        metadata (10/20/30 tokens).
        """
        return {
            "text": mock_response,
            "candidates": [
                {
                    "content": {"parts": [{"text": mock_response}], "role": "model"},
                    "finishReason": "STOP",
                    "index": 0,
                    "safetyRatings": [],
                }
            ],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 20,
                "totalTokenCount": 30,
            },
        }

    @staticmethod
    def setup_generate_content_call(
        model: str,
        contents: GenerateContentContentListUnionDict,
        config: Optional[GenerateContentConfigDict] = None,
        custom_llm_provider: Optional[str] = None,
        tools: Optional[ToolConfigDict] = None,
        **kwargs,
    ) -> GenerateContentSetupResult:
        """
        Common setup logic for generate_content calls

        Resolves the LLM provider, builds the provider-specific request body,
        and performs pre-call logging.

        Args:
            model: The model name
            contents: The content to generate from
            config: Optional configuration
            custom_llm_provider: Optional custom LLM provider
            tools: Optional tools
            **kwargs: Additional keyword arguments

        Returns:
            GenerateContentSetupResult containing all setup information. When
            no provider config exists for the resolved provider,
            ``generate_content_provider_config`` is None and ``request_body``
            is empty — callers must fall back to the completion adapter.

        Raises:
            ValueError: if a string mock_response reaches this function on a
                non-streaming call (entry points must short-circuit mocks
                first), or if no ``litellm_logging_obj`` was supplied.
        """
        litellm_logging_obj: Optional[LiteLLMLoggingObj] = kwargs.get(
            "litellm_logging_obj"
        )
        litellm_call_id: Optional[str] = kwargs.get("litellm_call_id", None)
        # get llm provider logic
        litellm_params = GenericLiteLLMParams(**kwargs)
        ## MOCK RESPONSE LOGIC (only for non-streaming)
        # Hitting this branch means the public entry point failed to intercept
        # the mock before calling setup.
        if (
            not kwargs.get("stream", False)
            and litellm_params.mock_response
            and isinstance(litellm_params.mock_response, str)
        ):
            raise ValueError("Mock response should be handled by caller")
        (
            model,
            custom_llm_provider,
            dynamic_api_key,
            dynamic_api_base,
        ) = litellm.get_llm_provider(
            model=model,
            custom_llm_provider=custom_llm_provider,
            api_base=litellm_params.api_base,
            api_key=litellm_params.api_key,
        )
        if litellm_params.custom_llm_provider is None:
            litellm_params.custom_llm_provider = custom_llm_provider
        # get provider config
        generate_content_provider_config: Optional[
            BaseGoogleGenAIGenerateContentConfig
        ] = ProviderConfigManager.get_provider_google_genai_generate_content_config(
            model=model,
            provider=litellm.LlmProviders(custom_llm_provider),
        )
        if generate_content_provider_config is None:
            # Use adapter to transform to completion format when provider config is None
            # Signal that we should use the adapter by returning special result
            if litellm_logging_obj is None:
                raise ValueError("litellm_logging_obj is required, but got None")
            return GenerateContentSetupResult(
                model=model,
                custom_llm_provider=custom_llm_provider,
                request_body={},  # Will be handled by adapter
                generate_content_provider_config=None,  # type: ignore
                generate_content_config_dict=dict(config or {}),
                litellm_params=litellm_params,
                litellm_logging_obj=litellm_logging_obj,
                litellm_call_id=litellm_call_id,
            )
        #########################################################################################
        # Construct request body
        #########################################################################################
        # Create Google Optional Params Config
        generate_content_config_dict = (
            generate_content_provider_config.map_generate_content_optional_params(
                generate_content_config_dict=config or {},
                model=model,
            )
        )
        # Extract systemInstruction from kwargs to pass to transform
        # (both camelCase and snake_case spellings are accepted)
        system_instruction = kwargs.get("systemInstruction") or kwargs.get(
            "system_instruction"
        )
        request_body = (
            generate_content_provider_config.transform_generate_content_request(
                model=model,
                contents=contents,
                tools=tools,
                generate_content_config_dict=generate_content_config_dict,
                system_instruction=system_instruction,
            )
        )
        # Pre Call logging
        if litellm_logging_obj is None:
            raise ValueError("litellm_logging_obj is required, but got None")
        litellm_logging_obj.update_environment_variables(
            model=model,
            optional_params=dict(generate_content_config_dict),
            litellm_params={
                "litellm_call_id": litellm_call_id,
            },
            custom_llm_provider=custom_llm_provider,
        )
        return GenerateContentSetupResult(
            model=model,
            custom_llm_provider=custom_llm_provider,
            request_body=request_body,
            generate_content_provider_config=generate_content_provider_config,
            generate_content_config_dict=generate_content_config_dict,
            litellm_params=litellm_params,
            litellm_logging_obj=litellm_logging_obj,
            litellm_call_id=litellm_call_id,
        )
@client
async def agenerate_content(
    model: str,
    contents: GenerateContentContentListUnionDict,
    config: Optional[GenerateContentConfigDict] = None,
    tools: Optional[ToolConfigDict] = None,
    # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
    # The extra values given here take precedence over values defined on the client or passed to this method.
    extra_headers: Optional[Dict[str, Any]] = None,
    extra_query: Optional[Dict[str, Any]] = None,
    extra_body: Optional[Dict[str, Any]] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
    # LiteLLM specific params,
    custom_llm_provider: Optional[str] = None,
    **kwargs,
) -> Any:
    """
    Async: Generate content using Google GenAI

    Delegates to the sync ``generate_content`` in a thread-pool executor
    (running it inside a copied contextvars context), then awaits the result
    if the sync call handed back a coroutine.

    Raises:
        The provider-mapped exception produced by ``litellm.exception_type``.
    """
    local_vars = locals()
    try:
        loop = asyncio.get_event_loop()
        # Flag consumed by generate_content() so downstream handlers know the
        # call originated from the async API.
        kwargs["agenerate_content"] = True
        # Handle generationConfig parameter from kwargs for backward compatibility
        if "generationConfig" in kwargs and config is None:
            config = kwargs.pop("generationConfig")
        # get custom llm provider so we can use this for mapping exceptions
        if custom_llm_provider is None:
            _, custom_llm_provider, _, _ = litellm.get_llm_provider(
                model=model,
                custom_llm_provider=custom_llm_provider,
            )
        func = partial(
            generate_content,
            model=model,
            contents=contents,
            config=config,
            extra_headers=extra_headers,
            extra_query=extra_query,
            extra_body=extra_body,
            timeout=timeout,
            custom_llm_provider=custom_llm_provider,
            tools=tools,
            **kwargs,
        )
        # Copy the current context so contextvars (e.g. request-scoped state)
        # survive the hop into the executor thread.
        ctx = contextvars.copy_context()
        func_with_context = partial(ctx.run, func)
        init_response = await loop.run_in_executor(None, func_with_context)
        if asyncio.iscoroutine(init_response):
            response = await init_response
        else:
            response = init_response
        return response
    except Exception as e:
        raise litellm.exception_type(
            model=model,
            custom_llm_provider=custom_llm_provider,
            original_exception=e,
            completion_kwargs=local_vars,
            extra_kwargs=kwargs,
        )
@client
def generate_content(
    model: str,
    contents: GenerateContentContentListUnionDict,
    config: Optional[GenerateContentConfigDict] = None,
    tools: Optional[ToolConfigDict] = None,
    # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
    # The extra values given here take precedence over values defined on the client or passed to this method.
    extra_headers: Optional[Dict[str, Any]] = None,
    extra_query: Optional[Dict[str, Any]] = None,
    extra_body: Optional[Dict[str, Any]] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
    # LiteLLM specific params,
    custom_llm_provider: Optional[str] = None,
    **kwargs,
) -> Any:
    """
    Generate content using Google GenAI

    Order of operations:
      1. Short-circuit string mock responses.
      2. Resolve provider + build the request via
         ``GenerateContentHelper.setup_generate_content_call``.
      3. If no provider config exists, fall back to the completion adapter.
      4. Otherwise dispatch to the shared HTTP handler.

    Raises:
        The provider-mapped exception produced by ``litellm.exception_type``.
    """
    local_vars = locals()
    try:
        # True when invoked via agenerate_content(); forwarded to the handler.
        _is_async = kwargs.pop("agenerate_content", False)
        # Handle generationConfig parameter from kwargs for backward compatibility
        if "generationConfig" in kwargs and config is None:
            config = kwargs.pop("generationConfig")
        # Check for mock response first
        litellm_params = GenericLiteLLMParams(**kwargs)
        if litellm_params.mock_response and isinstance(
            litellm_params.mock_response, str
        ):
            return GenerateContentHelper.mock_generate_content_response(
                mock_response=litellm_params.mock_response
            )
        # Setup the call
        setup_result = GenerateContentHelper.setup_generate_content_call(
            model=model,
            contents=contents,
            config=config,
            custom_llm_provider=custom_llm_provider,
            tools=tools,
            **kwargs,
        )
        # Extract systemInstruction from kwargs to pass to handler
        # (both camelCase and snake_case spellings are accepted)
        system_instruction = kwargs.get("systemInstruction") or kwargs.get(
            "system_instruction"
        )
        # Check if we should use the adapter (when provider config is None)
        if setup_result.generate_content_provider_config is None:
            # Use the adapter to convert to completion format
            # NOTE(review): system_instruction is not passed explicitly here;
            # presumably the adapter reads it from **kwargs — confirm.
            return GenerateContentToCompletionHandler.generate_content_handler(
                model=model,
                contents=contents,  # type: ignore
                config=setup_result.generate_content_config_dict,
                tools=tools,
                _is_async=_is_async,
                litellm_params=setup_result.litellm_params,
                extra_headers=extra_headers,
                **kwargs,
            )
        # Call the standard handler
        response = base_llm_http_handler.generate_content_handler(
            model=setup_result.model,
            contents=contents,
            tools=tools,
            generate_content_provider_config=setup_result.generate_content_provider_config,
            generate_content_config_dict=setup_result.generate_content_config_dict,
            custom_llm_provider=setup_result.custom_llm_provider,
            litellm_params=setup_result.litellm_params,
            logging_obj=setup_result.litellm_logging_obj,
            extra_headers=extra_headers,
            extra_body=extra_body,
            timeout=timeout or request_timeout,
            _is_async=_is_async,
            client=kwargs.get("client"),
            litellm_metadata=kwargs.get("litellm_metadata", {}),
            system_instruction=system_instruction,
        )
        return response
    except Exception as e:
        raise litellm.exception_type(
            model=model,
            custom_llm_provider=custom_llm_provider,
            original_exception=e,
            completion_kwargs=local_vars,
            extra_kwargs=kwargs,
        )
@client
async def agenerate_content_stream(
    model: str,
    contents: GenerateContentContentListUnionDict,
    config: Optional[GenerateContentConfigDict] = None,
    tools: Optional[ToolConfigDict] = None,
    # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
    # The extra values given here take precedence over values defined on the client or passed to this method.
    extra_headers: Optional[Dict[str, Any]] = None,
    extra_query: Optional[Dict[str, Any]] = None,
    extra_body: Optional[Dict[str, Any]] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
    # LiteLLM specific params,
    custom_llm_provider: Optional[str] = None,
    **kwargs,
) -> Any:
    """
    Async: Generate content using Google GenAI with streaming response

    Unlike ``agenerate_content``, this awaits the async handler directly
    (no executor hop) and always dispatches with ``stream=True``.

    Raises:
        The provider-mapped exception produced by ``litellm.exception_type``.
    """
    local_vars = locals()
    try:
        # Marker consumed downstream to distinguish the async streaming entry.
        kwargs["agenerate_content_stream"] = True
        # Handle generationConfig parameter from kwargs for backward compatibility
        if "generationConfig" in kwargs and config is None:
            config = kwargs.pop("generationConfig")
        # get custom llm provider so we can use this for mapping exceptions
        if custom_llm_provider is None:
            # NOTE(review): "base_url" is not a parameter of this function, so
            # this locals() lookup always yields None — confirm intent.
            _, custom_llm_provider, _, _ = litellm.get_llm_provider(
                model=model, api_base=local_vars.get("base_url", None)
            )
        # Setup the call
        setup_result = GenerateContentHelper.setup_generate_content_call(
            model=model,
            contents=contents,
            config=config,
            custom_llm_provider=custom_llm_provider,
            tools=tools,
            **kwargs,
        )
        # Extract systemInstruction from kwargs to pass to handler
        # (both camelCase and snake_case spellings are accepted)
        system_instruction = kwargs.get("systemInstruction") or kwargs.get(
            "system_instruction"
        )
        # Check if we should use the adapter (when provider config is None)
        if setup_result.generate_content_provider_config is None:
            # Drop any caller-supplied stream flag; the adapter call below
            # passes stream=True explicitly.
            if "stream" in kwargs:
                kwargs.pop("stream", None)
            # Use the adapter to convert to completion format
            return (
                await GenerateContentToCompletionHandler.async_generate_content_handler(
                    model=model,
                    contents=contents,  # type: ignore
                    config=setup_result.generate_content_config_dict,
                    litellm_params=setup_result.litellm_params,
                    tools=tools,
                    stream=True,
                    extra_headers=extra_headers,
                    **kwargs,
                )
            )
        # Call the handler with async enabled and streaming
        # Return the coroutine directly for the router to handle
        return await base_llm_http_handler.generate_content_handler(
            model=setup_result.model,
            contents=contents,
            generate_content_provider_config=setup_result.generate_content_provider_config,
            generate_content_config_dict=setup_result.generate_content_config_dict,
            tools=tools,
            custom_llm_provider=setup_result.custom_llm_provider,
            litellm_params=setup_result.litellm_params,
            logging_obj=setup_result.litellm_logging_obj,
            extra_headers=extra_headers,
            extra_body=extra_body,
            timeout=timeout or request_timeout,
            _is_async=True,
            client=kwargs.get("client"),
            stream=True,
            litellm_metadata=kwargs.get("litellm_metadata", {}),
            system_instruction=system_instruction,
        )
    except Exception as e:
        raise litellm.exception_type(
            model=model,
            custom_llm_provider=custom_llm_provider,
            original_exception=e,
            completion_kwargs=local_vars,
            extra_kwargs=kwargs,
        )
@client
def generate_content_stream(
    model: str,
    contents: GenerateContentContentListUnionDict,
    config: Optional[GenerateContentConfigDict] = None,
    tools: Optional[ToolConfigDict] = None,
    # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
    # The extra values given here take precedence over values defined on the client or passed to this method.
    extra_headers: Optional[Dict[str, Any]] = None,
    extra_query: Optional[Dict[str, Any]] = None,
    extra_body: Optional[Dict[str, Any]] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
    # LiteLLM specific params,
    custom_llm_provider: Optional[str] = None,
    **kwargs,
) -> Iterator[Any]:
    """
    Generate content using Google GenAI with streaming response

    Sync counterpart of ``agenerate_content_stream``: resolves the provider,
    then dispatches with ``stream=True`` either to the completion adapter
    (when no provider config exists) or to the shared HTTP handler.

    Fixes vs. prior revision (for consistency with the sibling entry points):
    - forward ``tools`` on the adapter-fallback path (async variant did;
      this one silently dropped them);
    - extract and forward ``system_instruction`` to the handler, as
      ``generate_content`` and ``agenerate_content_stream`` already do.

    Raises:
        The provider-mapped exception produced by ``litellm.exception_type``.
    """
    local_vars = locals()
    try:
        # Remove any async-related flags since this is the sync function
        _is_async = kwargs.pop("agenerate_content_stream", False)
        # Handle generationConfig parameter from kwargs for backward compatibility
        if "generationConfig" in kwargs and config is None:
            config = kwargs.pop("generationConfig")
        # Setup the call
        setup_result = GenerateContentHelper.setup_generate_content_call(
            model=model,
            contents=contents,
            config=config,
            custom_llm_provider=custom_llm_provider,
            tools=tools,
            **kwargs,
        )
        # Extract systemInstruction from kwargs to pass to handler
        # (both camelCase and snake_case spellings are accepted)
        system_instruction = kwargs.get("systemInstruction") or kwargs.get(
            "system_instruction"
        )
        # Check if we should use the adapter (when provider config is None)
        if setup_result.generate_content_provider_config is None:
            # Drop any caller-supplied stream flag; the adapter call below
            # passes stream=True explicitly.
            if "stream" in kwargs:
                kwargs.pop("stream", None)
            # Use the adapter to convert to completion format
            return GenerateContentToCompletionHandler.generate_content_handler(
                model=model,
                contents=contents,  # type: ignore
                config=setup_result.generate_content_config_dict,
                tools=tools,
                _is_async=_is_async,
                litellm_params=setup_result.litellm_params,
                stream=True,
                extra_headers=extra_headers,
                **kwargs,
            )
        # Call the handler with streaming enabled (sync version)
        return base_llm_http_handler.generate_content_handler(
            model=setup_result.model,
            contents=contents,
            generate_content_provider_config=setup_result.generate_content_provider_config,
            generate_content_config_dict=setup_result.generate_content_config_dict,
            tools=tools,
            custom_llm_provider=setup_result.custom_llm_provider,
            litellm_params=setup_result.litellm_params,
            logging_obj=setup_result.litellm_logging_obj,
            extra_headers=extra_headers,
            extra_body=extra_body,
            timeout=timeout or request_timeout,
            _is_async=_is_async,
            client=kwargs.get("client"),
            stream=True,
            litellm_metadata=kwargs.get("litellm_metadata", {}),
            system_instruction=system_instruction,
        )
    except Exception as e:
        raise litellm.exception_type(
            model=model,
            custom_llm_provider=custom_llm_provider,
            original_exception=e,
            completion_kwargs=local_vars,
            extra_kwargs=kwargs,
        )

View File

@@ -0,0 +1,159 @@
import asyncio
from datetime import datetime
from typing import TYPE_CHECKING, Any, List, Optional
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.proxy.pass_through_endpoints.success_handler import (
PassThroughEndpointLogging,
)
from litellm.types.passthrough_endpoints.pass_through_endpoints import EndpointType
if TYPE_CHECKING:
from litellm.llms.base_llm.google_genai.transformation import (
BaseGoogleGenAIGenerateContentConfig,
)
else:
BaseGoogleGenAIGenerateContentConfig = Any
GLOBAL_PASS_THROUGH_SUCCESS_HANDLER_OBJ = PassThroughEndpointLogging()
class BaseGoogleGenAIGenerateContentStreamingIterator:
    """
    Shared state and logging plumbing for Google GenAI generate-content
    streaming iterators; the sync and async variants subclass this.
    """

    def __init__(
        self,
        litellm_logging_obj: LiteLLMLoggingObj,
        request_body: dict,
        model: str,
    ):
        self.litellm_logging_obj = litellm_logging_obj
        self.request_body = request_body
        self.model = model
        self.start_time = datetime.now()
        # Raw response chunks accumulated during iteration; replayed to the
        # pass-through logging handler once the stream is exhausted.
        self.collected_chunks: List[bytes] = []

    async def _handle_async_streaming_logging(
        self,
    ):
        """Schedule (fire-and-forget) success logging for the collected chunks."""
        from litellm.proxy.pass_through_endpoints.streaming_handler import (
            PassThroughStreamingHandler,
        )

        completed_at = datetime.now()
        asyncio.create_task(
            PassThroughStreamingHandler._route_streaming_logging_to_handler(
                litellm_logging_obj=self.litellm_logging_obj,
                passthrough_success_handler_obj=GLOBAL_PASS_THROUGH_SUCCESS_HANDLER_OBJ,
                url_route="/v1/generateContent",
                request_body=self.request_body or {},
                endpoint_type=EndpointType.VERTEX_AI,
                start_time=self.start_time,
                raw_bytes=self.collected_chunks,
                end_time=completed_at,
                model=self.model,
            )
        )
class GoogleGenAIGenerateContentStreamingIterator(
    BaseGoogleGenAIGenerateContentStreamingIterator
):
    """
    Synchronous streaming iterator for the Google GenAI generate content API.
    Yields raw response bytes while buffering them for post-stream logging.
    """

    def __init__(
        self,
        response,
        model: str,
        logging_obj: LiteLLMLoggingObj,
        generate_content_provider_config: BaseGoogleGenAIGenerateContentConfig,
        litellm_metadata: dict,
        custom_llm_provider: str,
        request_body: Optional[dict] = None,
    ):
        super().__init__(
            litellm_logging_obj=logging_obj,
            request_body=request_body or {},
            model=model,
        )
        self.response = response
        self.model = model
        self.generate_content_provider_config = generate_content_provider_config
        self.litellm_metadata = litellm_metadata
        self.custom_llm_provider = custom_llm_provider
        # Materialize the byte iterator exactly once so the underlying HTTP
        # stream is never consumed twice.
        self.stream_iterator = response.iter_bytes()

    def __iter__(self):
        return self

    def __next__(self):
        # Raw bytes pass through unchanged; StopIteration from the underlying
        # iterator propagates as-is to end iteration.
        chunk = next(self.stream_iterator)
        self.collected_chunks.append(chunk)
        return chunk

    def __aiter__(self):
        return self

    async def __anext__(self):
        # Async iteration is intentionally unsupported on the sync iterator.
        raise NotImplementedError(
            "Use AsyncGoogleGenAIGenerateContentStreamingIterator for async iteration"
        )
class AsyncGoogleGenAIGenerateContentStreamingIterator(
    BaseGoogleGenAIGenerateContentStreamingIterator
):
    """
    Async streaming iterator specifically for Google GenAI generate content API.

    Yields raw response bytes; when the stream is exhausted, schedules the
    pass-through success logging with all collected chunks.
    """

    def __init__(
        self,
        response,
        model: str,
        logging_obj: LiteLLMLoggingObj,
        generate_content_provider_config: BaseGoogleGenAIGenerateContentConfig,
        litellm_metadata: dict,
        custom_llm_provider: str,
        request_body: Optional[dict] = None,
    ):
        super().__init__(
            litellm_logging_obj=logging_obj,
            request_body=request_body or {},
            model=model,
        )
        self.response = response
        self.model = model
        self.generate_content_provider_config = generate_content_provider_config
        self.litellm_metadata = litellm_metadata
        self.custom_llm_provider = custom_llm_provider
        # Store the async iterator once to avoid multiple stream consumption
        self.stream_iterator = response.aiter_bytes()

    def __aiter__(self):
        return self

    async def __anext__(self):
        try:
            # Get the next chunk from the stored async iterator
            chunk = await self.stream_iterator.__anext__()
            self.collected_chunks.append(chunk)
            # Just yield raw bytes
            return chunk
        except StopAsyncIteration:
            # End of stream: kick off logging before propagating exhaustion.
            await self._handle_async_streaming_logging()
            raise StopAsyncIteration