import asyncio
from datetime import datetime
from typing import TYPE_CHECKING, Any, List, Optional

from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.proxy.pass_through_endpoints.success_handler import (
    PassThroughEndpointLogging,
)
from litellm.types.passthrough_endpoints.pass_through_endpoints import EndpointType

if TYPE_CHECKING:
    from litellm.llms.base_llm.google_genai.transformation import (
        BaseGoogleGenAIGenerateContentConfig,
    )
else:
    BaseGoogleGenAIGenerateContentConfig = Any

GLOBAL_PASS_THROUGH_SUCCESS_HANDLER_OBJ = PassThroughEndpointLogging()


class BaseGoogleGenAIGenerateContentStreamingIterator:
    """
    Base class for Google GenAI generate content streaming iterators; provides
    the common logic for streaming response handling and logging.
    """

    def __init__(
        self,
        litellm_logging_obj: LiteLLMLoggingObj,
        request_body: dict,
        model: str,
    ):
        self.litellm_logging_obj = litellm_logging_obj
        self.request_body = request_body
        self.start_time = datetime.now()
        self.collected_chunks: List[bytes] = []
        self.model = model

    async def _handle_async_streaming_logging(self) -> None:
        """Handle the logging after all chunks have been collected."""
        from litellm.proxy.pass_through_endpoints.streaming_handler import (
            PassThroughStreamingHandler,
        )

        end_time = datetime.now()
        # Fire-and-forget: schedule the logging coroutine on the running
        # event loop so end-of-stream logging does not block the caller.
        asyncio.create_task(
            PassThroughStreamingHandler._route_streaming_logging_to_handler(
                litellm_logging_obj=self.litellm_logging_obj,
                passthrough_success_handler_obj=GLOBAL_PASS_THROUGH_SUCCESS_HANDLER_OBJ,
                url_route="/v1/generateContent",
                request_body=self.request_body or {},
                endpoint_type=EndpointType.VERTEX_AI,
                start_time=self.start_time,
                raw_bytes=self.collected_chunks,
                end_time=end_time,
                model=self.model,
            )
        )


class GoogleGenAIGenerateContentStreamingIterator(
    BaseGoogleGenAIGenerateContentStreamingIterator
):
    """
    Synchronous streaming iterator for the Google GenAI generate content API.

    A usage sketch follows the class definition.
    """

    def __init__(
        self,
        response,
        model: str,
        logging_obj: LiteLLMLoggingObj,
        generate_content_provider_config: BaseGoogleGenAIGenerateContentConfig,
        litellm_metadata: dict,
        custom_llm_provider: str,
        request_body: Optional[dict] = None,
    ):
        super().__init__(
            litellm_logging_obj=logging_obj,
            request_body=request_body or {},
            model=model,
        )
        self.response = response
        self.model = model
        self.generate_content_provider_config = generate_content_provider_config
        self.litellm_metadata = litellm_metadata
        self.custom_llm_provider = custom_llm_provider
        # Store the iterator once to avoid consuming the stream more than once.
        self.stream_iterator = response.iter_bytes()

    def __iter__(self):
        return self

    def __next__(self):
        # Record each raw chunk for logging, then yield it unchanged. Note
        # that, unlike the async iterator, this sync path does not schedule
        # end-of-stream logging when the stream is exhausted; StopIteration
        # from the underlying stream simply propagates to the caller.
        chunk = next(self.stream_iterator)
        self.collected_chunks.append(chunk)
        return chunk

    def __aiter__(self):
        return self

    async def __anext__(self):
        # This iterator is sync-only; use
        # AsyncGoogleGenAIGenerateContentStreamingIterator for async iteration.
        raise NotImplementedError(
            "Use AsyncGoogleGenAIGenerateContentStreamingIterator for async iteration"
        )
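

# The function below is an illustrative usage sketch, not part of the original
# module: the `_demo_*` name, the model string, and the provider label are
# assumptions. It presumes `response` is an httpx.Response opened in streaming
# mode (e.g. via `client.stream("POST", ...)`) so that `iter_bytes()` yields
# raw chunks.
def _demo_sync_stream(
    response,
    logging_obj: LiteLLMLoggingObj,
    provider_config: BaseGoogleGenAIGenerateContentConfig,
) -> bytes:
    """Drain the sync iterator and return the concatenated raw bytes."""
    iterator = GoogleGenAIGenerateContentStreamingIterator(
        response=response,
        model="gemini-1.5-pro",  # placeholder model name
        logging_obj=logging_obj,
        generate_content_provider_config=provider_config,
        litellm_metadata={},
        custom_llm_provider="vertex_ai",  # placeholder provider label
    )
    return b"".join(iterator)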


class AsyncGoogleGenAIGenerateContentStreamingIterator(
    BaseGoogleGenAIGenerateContentStreamingIterator
):
    """
    Async streaming iterator for the Google GenAI generate content API.

    A usage sketch follows the class definition.
    """

    def __init__(
        self,
        response,
        model: str,
        logging_obj: LiteLLMLoggingObj,
        generate_content_provider_config: BaseGoogleGenAIGenerateContentConfig,
        litellm_metadata: dict,
        custom_llm_provider: str,
        request_body: Optional[dict] = None,
    ):
        super().__init__(
            litellm_logging_obj=logging_obj,
            request_body=request_body or {},
            model=model,
        )
        self.response = response
        self.model = model
        self.generate_content_provider_config = generate_content_provider_config
        self.litellm_metadata = litellm_metadata
        self.custom_llm_provider = custom_llm_provider
        # Store the async iterator once to avoid consuming the stream more
        # than once.
        self.stream_iterator = response.aiter_bytes()

    def __aiter__(self):
        return self

    async def __anext__(self):
        try:
            # Record each raw chunk for logging, then yield it unchanged.
            chunk = await self.stream_iterator.__anext__()
            self.collected_chunks.append(chunk)
            return chunk
        except StopAsyncIteration:
            # Stream exhausted: schedule logging over the collected chunks,
            # then signal the end of iteration to the caller.
            await self._handle_async_streaming_logging()
            raise
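

# Async counterpart of the sketch above; likewise illustrative rather than
# part of the original module. It assumes a running event loop (required by
# the asyncio.create_task call in the logging handler) and an httpx.Response
# opened in streaming mode so that `aiter_bytes()` is available.
async def _demo_async_stream(
    response,
    logging_obj: LiteLLMLoggingObj,
    provider_config: BaseGoogleGenAIGenerateContentConfig,
) -> bytes:
    """Drain the async iterator and return the concatenated raw bytes."""
    iterator = AsyncGoogleGenAIGenerateContentStreamingIterator(
        response=response,
        model="gemini-1.5-pro",  # placeholder model name
        logging_obj=logging_obj,
        generate_content_provider_config=provider_config,
        litellm_metadata={},
        custom_llm_provider="vertex_ai",  # placeholder provider label
    )
    chunks: List[bytes] = []
    async for chunk in iterator:
        chunks.append(chunk)
    return b"".join(chunks)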