chore: initial public snapshot for github upload
@@ -0,0 +1,10 @@
"""
A2A Protocol Providers.

This module contains provider-specific implementations for the A2A protocol.
"""

from litellm.a2a_protocol.providers.base import BaseA2AProviderConfig
from litellm.a2a_protocol.providers.config_manager import A2AProviderConfigManager

__all__ = ["BaseA2AProviderConfig", "A2AProviderConfigManager"]
@@ -0,0 +1,62 @@
"""
Base configuration for A2A protocol providers.
"""

from abc import ABC, abstractmethod
from typing import Any, AsyncIterator, Dict


class BaseA2AProviderConfig(ABC):
    """
    Base configuration class for A2A protocol providers.

    Each provider should implement this interface to define how to handle
    A2A requests for their specific agent type.
    """

    @abstractmethod
    async def handle_non_streaming(
        self,
        request_id: str,
        params: Dict[str, Any],
        api_base: str,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Handle non-streaming A2A request.

        Args:
            request_id: A2A JSON-RPC request ID
            params: A2A MessageSendParams containing the message
            api_base: Base URL of the agent
            **kwargs: Additional provider-specific parameters

        Returns:
            A2A SendMessageResponse dict
        """
        pass

    @abstractmethod
    async def handle_streaming(
        self,
        request_id: str,
        params: Dict[str, Any],
        api_base: str,
        **kwargs,
    ) -> AsyncIterator[Dict[str, Any]]:
        """
        Handle streaming A2A request.

        Args:
            request_id: A2A JSON-RPC request ID
            params: A2A MessageSendParams containing the message
            api_base: Base URL of the agent
            **kwargs: Additional provider-specific parameters

        Yields:
            A2A streaming response events
        """
        # This is an abstract method - subclasses must implement.
        # The yield is here to make this a generator function.
        if False:  # pragma: no cover
            yield {}
@@ -0,0 +1,47 @@
"""
A2A Provider Config Manager.

Manages provider-specific configurations for A2A protocol.
"""

from typing import Optional

from litellm.a2a_protocol.providers.base import BaseA2AProviderConfig


class A2AProviderConfigManager:
    """
    Manager for A2A provider configurations.

    Similar to ProviderConfigManager in litellm.utils, but specifically for A2A providers.
    """

    @staticmethod
    def get_provider_config(
        custom_llm_provider: Optional[str],
    ) -> Optional[BaseA2AProviderConfig]:
        """
        Get the provider configuration for a given custom_llm_provider.

        Args:
            custom_llm_provider: The provider identifier (e.g., "pydantic_ai_agents")

        Returns:
            Provider configuration instance or None if not found
        """
        if custom_llm_provider is None:
            return None

        if custom_llm_provider == "pydantic_ai_agents":
            from litellm.a2a_protocol.providers.pydantic_ai_agents.config import (
                PydanticAIProviderConfig,
            )

            return PydanticAIProviderConfig()

        # Add more providers here as needed:
        # elif custom_llm_provider == "another_provider":
        #     from litellm.a2a_protocol.providers.another_provider.config import AnotherProviderConfig
        #     return AnotherProviderConfig()

        return None
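
# Illustrative usage of the manager (a sketch; the request_id, params, and
# api_base values below are placeholders, not part of this module):
#
#   config = A2AProviderConfigManager.get_provider_config("pydantic_ai_agents")
#   if config is not None:
#       response = await config.handle_non_streaming(
#           request_id="req-1",
#           params={"message": {...}},
#           api_base="http://localhost:9999",
#       )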
@@ -0,0 +1,74 @@
# A2A to LiteLLM Completion Bridge

Routes A2A protocol requests through `litellm.acompletion`, enabling any LiteLLM-supported provider to be invoked via A2A.

## Flow

```
A2A Request → Transform → litellm.acompletion → Transform → A2A Response
```

## SDK Usage

Use the existing `asend_message` and `asend_message_streaming` functions with `litellm_params`:

```python
from litellm.a2a_protocol import asend_message, asend_message_streaming
from a2a.types import SendMessageRequest, SendStreamingMessageRequest, MessageSendParams
from uuid import uuid4

# Non-streaming
request = SendMessageRequest(
    id=str(uuid4()),
    params=MessageSendParams(
        message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
    )
)
response = await asend_message(
    request=request,
    api_base="http://localhost:2024",
    litellm_params={"custom_llm_provider": "langgraph", "model": "agent"},
)

# Streaming
stream_request = SendStreamingMessageRequest(
    id=str(uuid4()),
    params=MessageSendParams(
        message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
    )
)
async for chunk in asend_message_streaming(
    request=stream_request,
    api_base="http://localhost:2024",
    litellm_params={"custom_llm_provider": "langgraph", "model": "agent"},
):
    print(chunk)
```

## Proxy Usage

Configure an agent with `custom_llm_provider` in `litellm_params`:

```yaml
agents:
  - agent_name: my-langgraph-agent
    agent_card_params:
      name: "LangGraph Agent"
      url: "http://localhost:2024"  # Used as api_base
    litellm_params:
      custom_llm_provider: langgraph
      model: agent
```

When an A2A request hits `/a2a/{agent_id}/message/send`, the bridge:

1. Detects `custom_llm_provider` in the agent's `litellm_params`
2. Transforms the A2A message → OpenAI messages
3. Calls `litellm.acompletion(model="langgraph/agent", api_base="http://localhost:2024")`
4. Transforms the response → A2A format
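
For reference, a direct call to that route might look like the sketch below. This is illustrative only: the proxy base URL (`http://localhost:4000`), the agent id, and the `Authorization` header are assumptions about your deployment, and the body shape mirrors the A2A JSON-RPC format used elsewhere in this module rather than a documented proxy schema.

```python
import httpx

resp = httpx.post(
    "http://localhost:4000/a2a/my-langgraph-agent/message/send",
    json={
        "jsonrpc": "2.0",
        "id": "req-1",
        "method": "message/send",
        "params": {
            "message": {
                "role": "user",
                "parts": [{"kind": "text", "text": "Hello!"}],
                "messageId": "msg-1",
            }
        },
    },
    headers={"Authorization": "Bearer sk-1234"},  # assumed proxy auth key
)
print(resp.json())
```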

## Classes

- `A2ACompletionBridgeTransformation` - Static methods for message format conversion
- `A2ACompletionBridgeHandler` - Static methods for handling requests (streaming/non-streaming)
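
To see what these classes do end to end, here is a minimal sketch of the transformation step, using the static methods defined in `transformation.py` in this commit (the message payload is illustrative):

```python
from litellm.a2a_protocol.litellm_completion_bridge.transformation import (
    A2ACompletionBridgeTransformation,
)

a2a_message = {
    "role": "user",
    "parts": [{"kind": "text", "text": "Hello!"}],
    "messageId": "abc123",
}

# Step 2 of the flow: A2A message -> OpenAI messages
openai_messages = A2ACompletionBridgeTransformation.a2a_message_to_openai_messages(
    a2a_message
)
# -> [{"role": "user", "content": "Hello!"}]

# Step 4 of the flow: the ModelResponse returned by litellm.acompletion is
# wrapped back into an A2A SendMessageResponse via
# A2ACompletionBridgeTransformation.openai_response_to_a2a_response(...)
```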
@@ -0,0 +1,5 @@
"""
LiteLLM Completion bridge provider for A2A protocol.

Routes A2A requests through litellm.acompletion based on custom_llm_provider.
"""
@@ -0,0 +1,301 @@
"""
Handler for A2A to LiteLLM completion bridge.

Routes A2A requests through litellm.acompletion based on custom_llm_provider.

A2A Streaming Events (in order):
1. Task event (kind: "task") - Initial task creation with status "submitted"
2. Status update (kind: "status-update") - Status change to "working"
3. Artifact update (kind: "artifact-update") - Content/artifact delivery
4. Status update (kind: "status-update") - Final status "completed" with final=true
"""
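
# Illustrative JSON-RPC envelopes for the events above (see the builder methods
# in transformation.py; the ids and trailing fields are examples, not a schema):
#   {"jsonrpc": "2.0", "id": "<request_id>", "result": {"kind": "task", "status": {"state": "submitted"}, ...}}
#   {"jsonrpc": "2.0", "id": "<request_id>", "result": {"kind": "status-update", "status": {"state": "working"}, "final": False, ...}}
#   {"jsonrpc": "2.0", "id": "<request_id>", "result": {"kind": "artifact-update", "artifact": {"parts": [...]}, ...}}
#   {"jsonrpc": "2.0", "id": "<request_id>", "result": {"kind": "status-update", "status": {"state": "completed"}, "final": True, ...}}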

from typing import Any, AsyncIterator, Dict, Optional

import litellm
from litellm._logging import verbose_logger
from litellm.a2a_protocol.litellm_completion_bridge.pydantic_ai_transformation import (
    PydanticAITransformation,
)
from litellm.a2a_protocol.litellm_completion_bridge.transformation import (
    A2ACompletionBridgeTransformation,
    A2AStreamingContext,
)


class A2ACompletionBridgeHandler:
    """
    Static methods for handling A2A requests via LiteLLM completion.
    """

    @staticmethod
    async def handle_non_streaming(
        request_id: str,
        params: Dict[str, Any],
        litellm_params: Dict[str, Any],
        api_base: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Handle non-streaming A2A request via litellm.acompletion.

        Args:
            request_id: A2A JSON-RPC request ID
            params: A2A MessageSendParams containing the message
            litellm_params: Agent's litellm_params (custom_llm_provider, model, etc.)
            api_base: API base URL from agent_card_params

        Returns:
            A2A SendMessageResponse dict
        """
        # Check if this is a Pydantic AI agent request
        custom_llm_provider = litellm_params.get("custom_llm_provider")
        if custom_llm_provider == "pydantic_ai_agents":
            if api_base is None:
                raise ValueError("api_base is required for Pydantic AI agents")

            verbose_logger.info(
                f"Pydantic AI: Routing to Pydantic AI agent at {api_base}"
            )

            # Send request directly to Pydantic AI agent
            response_data = await PydanticAITransformation.send_non_streaming_request(
                api_base=api_base,
                request_id=request_id,
                params=params,
            )

            return response_data

        # Extract message from params
        message = params.get("message", {})

        # Transform A2A message to OpenAI format
        openai_messages = (
            A2ACompletionBridgeTransformation.a2a_message_to_openai_messages(message)
        )

        # Get completion params
        model = litellm_params.get("model", "agent")

        # Build full model string if provider specified.
        # Skip prepending if model already starts with the provider prefix.
        if custom_llm_provider and not model.startswith(f"{custom_llm_provider}/"):
            full_model = f"{custom_llm_provider}/{model}"
        else:
            full_model = model

        verbose_logger.info(
            f"A2A completion bridge: model={full_model}, api_base={api_base}"
        )

        # Build completion params dict
        completion_params = {
            "model": full_model,
            "messages": openai_messages,
            "api_base": api_base,
            "stream": False,
        }
        # Add litellm_params (contains api_key, client_id, client_secret, tenant_id, etc.)
        litellm_params_to_add = {
            k: v
            for k, v in litellm_params.items()
            if k not in ("model", "custom_llm_provider")
        }
        completion_params.update(litellm_params_to_add)

        # Call litellm.acompletion
        response = await litellm.acompletion(**completion_params)

        # Transform response to A2A format
        a2a_response = (
            A2ACompletionBridgeTransformation.openai_response_to_a2a_response(
                response=response,
                request_id=request_id,
            )
        )

        verbose_logger.info(f"A2A completion bridge completed: request_id={request_id}")

        return a2a_response

    @staticmethod
    async def handle_streaming(
        request_id: str,
        params: Dict[str, Any],
        litellm_params: Dict[str, Any],
        api_base: Optional[str] = None,
    ) -> AsyncIterator[Dict[str, Any]]:
        """
        Handle streaming A2A request via litellm.acompletion with stream=True.

        Emits proper A2A streaming events:
        1. Task event (kind: "task") - Initial task with status "submitted"
        2. Status update (kind: "status-update") - Status "working"
        3. Artifact update (kind: "artifact-update") - Content delivery
        4. Status update (kind: "status-update") - Final "completed" status

        Args:
            request_id: A2A JSON-RPC request ID
            params: A2A MessageSendParams containing the message
            litellm_params: Agent's litellm_params (custom_llm_provider, model, etc.)
            api_base: API base URL from agent_card_params

        Yields:
            A2A streaming response events
        """
        # Check if this is a Pydantic AI agent request
        custom_llm_provider = litellm_params.get("custom_llm_provider")
        if custom_llm_provider == "pydantic_ai_agents":
            if api_base is None:
                raise ValueError("api_base is required for Pydantic AI agents")

            verbose_logger.info(
                f"Pydantic AI: Faking streaming for Pydantic AI agent at {api_base}"
            )

            # Get non-streaming response first
            response_data = await PydanticAITransformation.send_non_streaming_request(
                api_base=api_base,
                request_id=request_id,
                params=params,
            )

            # Convert to fake streaming
            async for chunk in PydanticAITransformation.fake_streaming_from_response(
                response_data=response_data,
                request_id=request_id,
            ):
                yield chunk

            return

        # Extract message from params
        message = params.get("message", {})

        # Create streaming context
        ctx = A2AStreamingContext(
            request_id=request_id,
            input_message=message,
        )

        # Transform A2A message to OpenAI format
        openai_messages = (
            A2ACompletionBridgeTransformation.a2a_message_to_openai_messages(message)
        )

        # Get completion params
        model = litellm_params.get("model", "agent")

        # Build full model string if provider specified.
        # Skip prepending if model already starts with the provider prefix.
        if custom_llm_provider and not model.startswith(f"{custom_llm_provider}/"):
            full_model = f"{custom_llm_provider}/{model}"
        else:
            full_model = model

        verbose_logger.info(
            f"A2A completion bridge streaming: model={full_model}, api_base={api_base}"
        )

        # Build completion params dict
        completion_params = {
            "model": full_model,
            "messages": openai_messages,
            "api_base": api_base,
            "stream": True,
        }
        # Add litellm_params (contains api_key, client_id, client_secret, tenant_id, etc.)
        litellm_params_to_add = {
            k: v
            for k, v in litellm_params.items()
            if k not in ("model", "custom_llm_provider")
        }
        completion_params.update(litellm_params_to_add)

        # 1. Emit initial task event (kind: "task", status: "submitted")
        task_event = A2ACompletionBridgeTransformation.create_task_event(ctx)
        yield task_event

        # 2. Emit status update (kind: "status-update", status: "working")
        working_event = A2ACompletionBridgeTransformation.create_status_update_event(
            ctx=ctx,
            state="working",
            final=False,
            message_text="Processing request...",
        )
        yield working_event

        # Call litellm.acompletion with streaming
        response = await litellm.acompletion(**completion_params)

        # 3. Accumulate content and emit artifact update
        accumulated_text = ""
        chunk_count = 0
        async for chunk in response:  # type: ignore[union-attr]
            chunk_count += 1

            # Extract delta content
            content = ""
            if chunk is not None and hasattr(chunk, "choices") and chunk.choices:
                choice = chunk.choices[0]
                if hasattr(choice, "delta") and choice.delta:
                    content = choice.delta.content or ""

            if content:
                accumulated_text += content

        # Emit artifact update with accumulated content
        if accumulated_text:
            artifact_event = (
                A2ACompletionBridgeTransformation.create_artifact_update_event(
                    ctx=ctx,
                    text=accumulated_text,
                )
            )
            yield artifact_event

        # 4. Emit final status update (kind: "status-update", status: "completed", final: true)
        completed_event = A2ACompletionBridgeTransformation.create_status_update_event(
            ctx=ctx,
            state="completed",
            final=True,
        )
        yield completed_event

        verbose_logger.info(
            f"A2A completion bridge streaming completed: request_id={request_id}, chunks={chunk_count}"
        )


# Convenience functions that delegate to the class methods
async def handle_a2a_completion(
    request_id: str,
    params: Dict[str, Any],
    litellm_params: Dict[str, Any],
    api_base: Optional[str] = None,
) -> Dict[str, Any]:
    """Convenience function for non-streaming A2A completion."""
    return await A2ACompletionBridgeHandler.handle_non_streaming(
        request_id=request_id,
        params=params,
        litellm_params=litellm_params,
        api_base=api_base,
    )


async def handle_a2a_completion_streaming(
    request_id: str,
    params: Dict[str, Any],
    litellm_params: Dict[str, Any],
    api_base: Optional[str] = None,
) -> AsyncIterator[Dict[str, Any]]:
    """Convenience function for streaming A2A completion."""
    async for chunk in A2ACompletionBridgeHandler.handle_streaming(
        request_id=request_id,
        params=params,
        litellm_params=litellm_params,
        api_base=api_base,
    ):
        yield chunk
@@ -0,0 +1,284 @@
"""
Transformation utilities for A2A <-> OpenAI message format conversion.

A2A Message Format:
    {
        "role": "user",
        "parts": [{"kind": "text", "text": "Hello!"}],
        "messageId": "abc123"
    }

OpenAI Message Format:
    {"role": "user", "content": "Hello!"}

A2A Streaming Events:
- Task event (kind: "task") - Initial task creation with status "submitted"
- Status update (kind: "status-update") - Status changes (working, completed)
- Artifact update (kind: "artifact-update") - Content/artifact delivery
"""

from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from uuid import uuid4

from litellm._logging import verbose_logger


class A2AStreamingContext:
    """
    Context holder for A2A streaming state.
    Tracks task_id, context_id, and message accumulation.
    """

    def __init__(self, request_id: str, input_message: Dict[str, Any]):
        self.request_id = request_id
        self.task_id = str(uuid4())
        self.context_id = str(uuid4())
        self.input_message = input_message
        self.accumulated_text = ""
        self.has_emitted_task = False
        self.has_emitted_working = False


class A2ACompletionBridgeTransformation:
    """
    Static methods for transforming between A2A and OpenAI message formats.
    """

    @staticmethod
    def a2a_message_to_openai_messages(
        a2a_message: Dict[str, Any],
    ) -> List[Dict[str, str]]:
        """
        Transform an A2A message to OpenAI message format.

        Args:
            a2a_message: A2A message with role, parts, and messageId

        Returns:
            List of OpenAI-format messages
        """
        role = a2a_message.get("role", "user")
        parts = a2a_message.get("parts", [])
        # Map A2A roles to OpenAI roles. A2A uses "agent" for model turns,
        # which has no direct OpenAI equivalent, so map it to "assistant".
        openai_role = role
        if role == "user":
            openai_role = "user"
        elif role in ("assistant", "agent"):
            openai_role = "assistant"
        elif role == "system":
            openai_role = "system"
        # Extract text content from parts
        content_parts = []
        for part in parts:
            kind = part.get("kind", "")
            if kind == "text":
                text = part.get("text", "")
                content_parts.append(text)

        content = "\n".join(content_parts) if content_parts else ""

        verbose_logger.debug(
            f"A2A -> OpenAI transform: role={role} -> {openai_role}, content_length={len(content)}"
        )

        return [{"role": openai_role, "content": content}]

    @staticmethod
    def openai_response_to_a2a_response(
        response: Any,
        request_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Transform a LiteLLM ModelResponse to A2A SendMessageResponse format.

        Args:
            response: LiteLLM ModelResponse object
            request_id: Original A2A request ID

        Returns:
            A2A SendMessageResponse dict
        """
        # Extract content from response
        content = ""
        if hasattr(response, "choices") and response.choices:
            choice = response.choices[0]
            if hasattr(choice, "message") and choice.message:
                content = choice.message.content or ""

        # Build A2A message
        a2a_message = {
            "role": "agent",
            "parts": [{"kind": "text", "text": content}],
            "messageId": uuid4().hex,
        }

        # Build A2A response
        a2a_response = {
            "jsonrpc": "2.0",
            "id": request_id,
            "result": {
                "message": a2a_message,
            },
        }

        verbose_logger.debug(f"OpenAI -> A2A transform: content_length={len(content)}")

        return a2a_response

    @staticmethod
    def _get_timestamp() -> str:
        """Get current timestamp in ISO format with timezone."""
        return datetime.now(timezone.utc).isoformat()

    @staticmethod
    def create_task_event(
        ctx: A2AStreamingContext,
    ) -> Dict[str, Any]:
        """
        Create the initial task event with status 'submitted'.

        This is the first event emitted in an A2A streaming response.
        """
        return {
            "id": ctx.request_id,
            "jsonrpc": "2.0",
            "result": {
                "contextId": ctx.context_id,
                "history": [
                    {
                        "contextId": ctx.context_id,
                        "kind": "message",
                        "messageId": ctx.input_message.get("messageId", uuid4().hex),
                        "parts": ctx.input_message.get("parts", []),
                        "role": ctx.input_message.get("role", "user"),
                        "taskId": ctx.task_id,
                    }
                ],
                "id": ctx.task_id,
                "kind": "task",
                "status": {
                    "state": "submitted",
                },
            },
        }

    @staticmethod
    def create_status_update_event(
        ctx: A2AStreamingContext,
        state: str,
        final: bool = False,
        message_text: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Create a status update event.

        Args:
            ctx: Streaming context
            state: Status state ('working', 'completed')
            final: Whether this is the final event
            message_text: Optional message text for 'working' status
        """
        status: Dict[str, Any] = {
            "state": state,
            "timestamp": A2ACompletionBridgeTransformation._get_timestamp(),
        }

        # Add message for 'working' status
        if state == "working" and message_text:
            status["message"] = {
                "contextId": ctx.context_id,
                "kind": "message",
                "messageId": str(uuid4()),
                "parts": [{"kind": "text", "text": message_text}],
                "role": "agent",
                "taskId": ctx.task_id,
            }

        return {
            "id": ctx.request_id,
            "jsonrpc": "2.0",
            "result": {
                "contextId": ctx.context_id,
                "final": final,
                "kind": "status-update",
                "status": status,
                "taskId": ctx.task_id,
            },
        }

    @staticmethod
    def create_artifact_update_event(
        ctx: A2AStreamingContext,
        text: str,
    ) -> Dict[str, Any]:
        """
        Create an artifact update event with content.

        Args:
            ctx: Streaming context
            text: The text content for the artifact
        """
        return {
            "id": ctx.request_id,
            "jsonrpc": "2.0",
            "result": {
                "artifact": {
                    "artifactId": str(uuid4()),
                    "name": "response",
                    "parts": [{"kind": "text", "text": text}],
                },
                "contextId": ctx.context_id,
                "kind": "artifact-update",
                "taskId": ctx.task_id,
            },
        }

    @staticmethod
    def openai_chunk_to_a2a_chunk(
        chunk: Any,
        request_id: Optional[str] = None,
        is_final: bool = False,
    ) -> Optional[Dict[str, Any]]:
        """
        Transform a LiteLLM streaming chunk to A2A streaming format.

        NOTE: This method is deprecated for streaming. Use the event-based
        methods (create_task_event, create_status_update_event,
        create_artifact_update_event) instead for proper A2A streaming.

        Args:
            chunk: LiteLLM ModelResponse chunk
            request_id: Original A2A request ID
            is_final: Whether this is the final chunk

        Returns:
            A2A streaming chunk dict or None if no content
        """
        # Extract delta content
        content = ""
        if chunk is not None and hasattr(chunk, "choices") and chunk.choices:
            choice = chunk.choices[0]
            if hasattr(choice, "delta") and choice.delta:
                content = choice.delta.content or ""

        if not content and not is_final:
            return None

        # Build A2A streaming chunk (legacy format)
        a2a_chunk = {
            "jsonrpc": "2.0",
            "id": request_id,
            "result": {
                "message": {
                    "role": "agent",
                    "parts": [{"kind": "text", "text": content}],
                    "messageId": uuid4().hex,
                },
                "final": is_final,
            },
        }

        return a2a_chunk
@@ -0,0 +1,16 @@
"""
Pydantic AI agent provider for A2A protocol.

Pydantic AI agents follow A2A protocol but don't support streaming natively.
This provider handles fake streaming by converting non-streaming responses into streaming chunks.
"""

from litellm.a2a_protocol.providers.pydantic_ai_agents.config import (
    PydanticAIProviderConfig,
)
from litellm.a2a_protocol.providers.pydantic_ai_agents.handler import PydanticAIHandler
from litellm.a2a_protocol.providers.pydantic_ai_agents.transformation import (
    PydanticAITransformation,
)

__all__ = ["PydanticAIHandler", "PydanticAITransformation", "PydanticAIProviderConfig"]
@@ -0,0 +1,50 @@
"""
Pydantic AI provider configuration.
"""

from typing import Any, AsyncIterator, Dict

from litellm.a2a_protocol.providers.base import BaseA2AProviderConfig
from litellm.a2a_protocol.providers.pydantic_ai_agents.handler import PydanticAIHandler


class PydanticAIProviderConfig(BaseA2AProviderConfig):
    """
    Provider configuration for Pydantic AI agents.

    Pydantic AI agents follow A2A protocol but don't support streaming natively.
    This config provides fake streaming by converting non-streaming responses into streaming chunks.
    """

    async def handle_non_streaming(
        self,
        request_id: str,
        params: Dict[str, Any],
        api_base: str,
        **kwargs,
    ) -> Dict[str, Any]:
        """Handle non-streaming request to Pydantic AI agent."""
        return await PydanticAIHandler.handle_non_streaming(
            request_id=request_id,
            params=params,
            api_base=api_base,
            timeout=kwargs.get("timeout", 60.0),
        )

    async def handle_streaming(
        self,
        request_id: str,
        params: Dict[str, Any],
        api_base: str,
        **kwargs,
    ) -> AsyncIterator[Dict[str, Any]]:
        """Handle streaming request with fake streaming."""
        async for chunk in PydanticAIHandler.handle_streaming(
            request_id=request_id,
            params=params,
            api_base=api_base,
            timeout=kwargs.get("timeout", 60.0),
            chunk_size=kwargs.get("chunk_size", 50),
            delay_ms=kwargs.get("delay_ms", 10),
        ):
            yield chunk
@@ -0,0 +1,102 @@
"""
Handler for Pydantic AI agents.

Pydantic AI agents follow A2A protocol but don't support streaming natively.
This handler provides fake streaming by converting non-streaming responses into streaming chunks.
"""

from typing import Any, AsyncIterator, Dict

from litellm._logging import verbose_logger
from litellm.a2a_protocol.providers.pydantic_ai_agents.transformation import (
    PydanticAITransformation,
)


class PydanticAIHandler:
    """
    Handler for Pydantic AI agent requests.

    Provides:
    - Direct non-streaming requests to Pydantic AI agents
    - Fake streaming by converting non-streaming responses into streaming chunks
    """

    @staticmethod
    async def handle_non_streaming(
        request_id: str,
        params: Dict[str, Any],
        api_base: str,
        timeout: float = 60.0,
    ) -> Dict[str, Any]:
        """
        Handle non-streaming request to Pydantic AI agent.

        Args:
            request_id: A2A JSON-RPC request ID
            params: A2A MessageSendParams containing the message
            api_base: Base URL of the Pydantic AI agent
            timeout: Request timeout in seconds

        Returns:
            A2A SendMessageResponse dict
        """
        verbose_logger.info(f"Pydantic AI: Routing to Pydantic AI agent at {api_base}")

        # Send request directly to Pydantic AI agent
        response_data = await PydanticAITransformation.send_non_streaming_request(
            api_base=api_base,
            request_id=request_id,
            params=params,
            timeout=timeout,
        )

        return response_data

    @staticmethod
    async def handle_streaming(
        request_id: str,
        params: Dict[str, Any],
        api_base: str,
        timeout: float = 60.0,
        chunk_size: int = 50,
        delay_ms: int = 10,
    ) -> AsyncIterator[Dict[str, Any]]:
        """
        Handle streaming request to Pydantic AI agent with fake streaming.

        Since Pydantic AI agents don't support streaming natively, this method:
        1. Makes a non-streaming request
        2. Converts the response into streaming chunks

        Args:
            request_id: A2A JSON-RPC request ID
            params: A2A MessageSendParams containing the message
            api_base: Base URL of the Pydantic AI agent
            timeout: Request timeout in seconds
            chunk_size: Number of characters per chunk
            delay_ms: Delay between chunks in milliseconds

        Yields:
            A2A streaming response events
        """
        verbose_logger.info(
            f"Pydantic AI: Faking streaming for Pydantic AI agent at {api_base}"
        )

        # Get raw task response first (not the transformed A2A format)
        raw_response = await PydanticAITransformation.send_and_get_raw_response(
            api_base=api_base,
            request_id=request_id,
            params=params,
            timeout=timeout,
        )

        # Convert raw task response to fake streaming chunks
        async for chunk in PydanticAITransformation.fake_streaming_from_response(
            response_data=raw_response,
            request_id=request_id,
            chunk_size=chunk_size,
            delay_ms=delay_ms,
        ):
            yield chunk
@@ -0,0 +1,530 @@
"""
Transformation layer for Pydantic AI agents.

Pydantic AI agents follow A2A protocol but don't support streaming.
This module provides fake streaming by converting non-streaming responses into streaming chunks.
"""

import asyncio
from typing import Any, AsyncIterator, Dict, Tuple, cast
from uuid import uuid4

from litellm._logging import verbose_logger
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    get_async_httpx_client,
)


class PydanticAITransformation:
    """
    Transformation layer for Pydantic AI agents.

    Handles:
    - Direct A2A requests to Pydantic AI endpoints
    - Polling for task completion (since Pydantic AI doesn't support streaming)
    - Fake streaming by chunking non-streaming responses
    """

    @staticmethod
    def _remove_none_values(obj: Any) -> Any:
        """
        Recursively remove None values from a dict/list structure.

        FastA2A/Pydantic AI servers don't accept None values for optional fields -
        they expect those fields to be omitted entirely.

        Args:
            obj: Dict, list, or other value to clean

        Returns:
            Cleaned object with None values removed
        """
        if isinstance(obj, dict):
            return {
                k: PydanticAITransformation._remove_none_values(v)
                for k, v in obj.items()
                if v is not None
            }
        elif isinstance(obj, list):
            return [
                PydanticAITransformation._remove_none_values(item)
                for item in obj
                if item is not None
            ]
        else:
            return obj
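
    # Example (illustrative): an explicit None is dropped rather than sent as
    # null, since FastA2A rejects nulls for optional fields:
    #   _remove_none_values({"message": {"kind": "message", "metadata": None}})
    #   -> {"message": {"kind": "message"}}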

    @staticmethod
    def _params_to_dict(params: Any) -> Dict[str, Any]:
        """
        Convert params to a dict, handling Pydantic models.

        Args:
            params: Dict or Pydantic model

        Returns:
            Dict representation of params
        """
        if hasattr(params, "model_dump"):
            # Pydantic v2 model
            return params.model_dump(mode="python", exclude_none=True)
        elif hasattr(params, "dict"):
            # Pydantic v1 model
            return params.dict(exclude_none=True)
        elif isinstance(params, dict):
            return params
        else:
            # Try to convert to dict
            return dict(params)

    @staticmethod
    async def _poll_for_completion(
        client: AsyncHTTPHandler,
        endpoint: str,
        task_id: str,
        request_id: str,
        max_attempts: int = 30,
        poll_interval: float = 0.5,
    ) -> Dict[str, Any]:
        """
        Poll for task completion using the tasks/get method.

        Args:
            client: HTTPX async client
            endpoint: API endpoint URL
            task_id: Task ID to poll for
            request_id: JSON-RPC request ID
            max_attempts: Maximum polling attempts
            poll_interval: Seconds between poll attempts

        Returns:
            Completed task response
        """
        for attempt in range(max_attempts):
            poll_request = {
                "jsonrpc": "2.0",
                "id": f"{request_id}-poll-{attempt}",
                "method": "tasks/get",
                "params": {"id": task_id},
            }

            response = await client.post(
                endpoint,
                json=poll_request,
                headers={"Content-Type": "application/json"},
            )
            response.raise_for_status()
            poll_data = response.json()

            result = poll_data.get("result", {})
            status = result.get("status", {})
            state = status.get("state", "")

            verbose_logger.debug(
                f"Pydantic AI: Poll attempt {attempt + 1}/{max_attempts}, state={state}"
            )

            if state == "completed":
                return poll_data
            elif state in ("failed", "canceled"):
                raise Exception(f"Task {task_id} ended with state: {state}")

            await asyncio.sleep(poll_interval)

        raise TimeoutError(
            f"Task {task_id} did not complete within {max_attempts * poll_interval} seconds"
        )

    @staticmethod
    async def _send_and_poll_raw(
        api_base: str,
        request_id: str,
        params: Any,
        timeout: float = 60.0,
    ) -> Dict[str, Any]:
        """
        Send a request to Pydantic AI agent and return the raw task response.

        This is an internal method used by both non-streaming and streaming handlers.
        Returns the raw Pydantic AI task format with history/artifacts.

        Args:
            api_base: Base URL of the Pydantic AI agent
            request_id: A2A JSON-RPC request ID
            params: A2A MessageSendParams containing the message
            timeout: Request timeout in seconds

        Returns:
            Raw Pydantic AI task response (with history/artifacts)
        """
        # Convert params to dict if it's a Pydantic model
        params_dict = PydanticAITransformation._params_to_dict(params)

        # Remove None values - FastA2A doesn't accept null for optional fields
        params_dict = PydanticAITransformation._remove_none_values(params_dict)

        # Ensure the message has 'kind': 'message' as required by FastA2A/Pydantic AI
        if "message" in params_dict:
            params_dict["message"]["kind"] = "message"

        # Build A2A JSON-RPC request using message/send method for FastA2A compatibility
        a2a_request = {
            "jsonrpc": "2.0",
            "id": request_id,
            "method": "message/send",
            "params": params_dict,
        }

        # FastA2A uses the root endpoint (/), not /messages
        endpoint = api_base.rstrip("/")

        verbose_logger.info(f"Pydantic AI: Sending non-streaming request to {endpoint}")

        # Send request to Pydantic AI agent using shared async HTTP client
        client = get_async_httpx_client(
            llm_provider=cast(Any, "pydantic_ai_agent"),
            params={"timeout": timeout},
        )
        response = await client.post(
            endpoint,
            json=a2a_request,
            headers={"Content-Type": "application/json"},
        )
        response.raise_for_status()
        response_data = response.json()

        # Check if task is already completed
        result = response_data.get("result", {})
        status = result.get("status", {})
        state = status.get("state", "")

        if state != "completed":
            # Need to poll for completion
            task_id = result.get("id")
            if task_id:
                verbose_logger.info(
                    f"Pydantic AI: Task {task_id} submitted, polling for completion..."
                )
                response_data = await PydanticAITransformation._poll_for_completion(
                    client=client,
                    endpoint=endpoint,
                    task_id=task_id,
                    request_id=request_id,
                )

        verbose_logger.info(
            f"Pydantic AI: Received completed response for request_id={request_id}"
        )

        return response_data

    @staticmethod
    async def send_non_streaming_request(
        api_base: str,
        request_id: str,
        params: Any,
        timeout: float = 60.0,
    ) -> Dict[str, Any]:
        """
        Send a non-streaming A2A request to Pydantic AI agent and wait for completion.

        Args:
            api_base: Base URL of the Pydantic AI agent (e.g., "http://localhost:9999")
            request_id: A2A JSON-RPC request ID
            params: A2A MessageSendParams containing the message (dict or Pydantic model)
            timeout: Request timeout in seconds

        Returns:
            Standard A2A non-streaming response format with message
        """
        # Get raw task response
        raw_response = await PydanticAITransformation._send_and_poll_raw(
            api_base=api_base,
            request_id=request_id,
            params=params,
            timeout=timeout,
        )

        # Transform to standard A2A non-streaming format
        return PydanticAITransformation._transform_to_a2a_response(
            response_data=raw_response,
            request_id=request_id,
        )

    @staticmethod
    async def send_and_get_raw_response(
        api_base: str,
        request_id: str,
        params: Any,
        timeout: float = 60.0,
    ) -> Dict[str, Any]:
        """
        Send a request to Pydantic AI agent and return the raw task response.

        Used by streaming handler to get raw response for fake streaming.

        Args:
            api_base: Base URL of the Pydantic AI agent
            request_id: A2A JSON-RPC request ID
            params: A2A MessageSendParams containing the message
            timeout: Request timeout in seconds

        Returns:
            Raw Pydantic AI task response (with history/artifacts)
        """
        return await PydanticAITransformation._send_and_poll_raw(
            api_base=api_base,
            request_id=request_id,
            params=params,
            timeout=timeout,
        )

    @staticmethod
    def _transform_to_a2a_response(
        response_data: Dict[str, Any],
        request_id: str,
    ) -> Dict[str, Any]:
        """
        Transform Pydantic AI task response to standard A2A non-streaming format.

        Pydantic AI returns a task with history/artifacts, but the standard A2A
        non-streaming format expects:
            {
                "jsonrpc": "2.0",
                "id": "...",
                "result": {
                    "message": {
                        "role": "agent",
                        "parts": [{"kind": "text", "text": "..."}],
                        "messageId": "..."
                    }
                }
            }

        Args:
            response_data: Pydantic AI task response
            request_id: Original request ID

        Returns:
            Standard A2A non-streaming response format
        """
        # Extract the agent response text
        full_text, message_id, parts = PydanticAITransformation._extract_response_text(
            response_data
        )

        # Build standard A2A message
        a2a_message = {
            "role": "agent",
            "parts": parts if parts else [{"kind": "text", "text": full_text}],
            "messageId": message_id,
        }

        # Return standard A2A non-streaming format
        return {
            "jsonrpc": "2.0",
            "id": request_id,
            "result": {
                "message": a2a_message,
            },
        }

    @staticmethod
    def _extract_response_text(response_data: Dict[str, Any]) -> Tuple[str, str, list]:
        """
        Extract response text from completed task response.

        Pydantic AI returns completed tasks with:
        - history: list of messages (user and agent)
        - artifacts: list of result artifacts

        Args:
            response_data: Completed task response

        Returns:
            Tuple of (full_text, message_id, parts)
        """
        result = response_data.get("result", {})

        # Try to extract from artifacts first (preferred for results)
        artifacts = result.get("artifacts", [])
        if artifacts:
            for artifact in artifacts:
                parts = artifact.get("parts", [])
                for part in parts:
                    if part.get("kind") == "text":
                        text = part.get("text", "")
                        if text:
                            return text, str(uuid4()), parts

        # Fall back to history - get the last agent message
        history = result.get("history", [])
        for msg in reversed(history):
            if msg.get("role") == "agent":
                parts = msg.get("parts", [])
                message_id = msg.get("messageId", str(uuid4()))
                full_text = ""
                for part in parts:
                    if part.get("kind") == "text":
                        full_text += part.get("text", "")
                if full_text:
                    return full_text, message_id, parts

        # Fall back to message field (original format)
        message = result.get("message", {})
        if message:
            parts = message.get("parts", [])
            message_id = message.get("messageId", str(uuid4()))
            full_text = ""
            for part in parts:
                if part.get("kind") == "text":
                    full_text += part.get("text", "")
            return full_text, message_id, parts

        return "", str(uuid4()), []

    @staticmethod
    async def fake_streaming_from_response(
        response_data: Dict[str, Any],
        request_id: str,
        chunk_size: int = 50,
        delay_ms: int = 10,
    ) -> AsyncIterator[Dict[str, Any]]:
        """
        Convert a non-streaming A2A response into fake streaming chunks.

        Emits proper A2A streaming events:
        1. Task event (kind: "task") - Initial task with status "submitted"
        2. Status update (kind: "status-update") - Status "working"
        3. Artifact update chunks (kind: "artifact-update") - Content delivery in chunks
        4. Status update (kind: "status-update") - Final "completed" status

        Args:
            response_data: Non-streaming A2A response dict (completed task)
            request_id: A2A JSON-RPC request ID
            chunk_size: Number of characters per chunk (default: 50)
            delay_ms: Delay between chunks in milliseconds (default: 10)

        Yields:
            A2A streaming response events
        """
        # Extract the response text from completed task
        full_text, message_id, parts = PydanticAITransformation._extract_response_text(
            response_data
        )

        # Extract input message from raw response for history
        result = response_data.get("result", {})
        history = result.get("history", [])
        input_message = {}
        for msg in history:
            if msg.get("role") == "user":
                input_message = msg
                break

        # Generate IDs for streaming events
        task_id = str(uuid4())
        context_id = str(uuid4())
        artifact_id = str(uuid4())
        input_message_id = input_message.get("messageId", str(uuid4()))

        # 1. Emit initial task event (kind: "task", status: "submitted")
        # Format matches A2ACompletionBridgeTransformation.create_task_event
        task_event = {
            "jsonrpc": "2.0",
            "id": request_id,
            "result": {
                "contextId": context_id,
                "history": [
                    {
                        "contextId": context_id,
                        "kind": "message",
                        "messageId": input_message_id,
                        "parts": input_message.get(
                            "parts", [{"kind": "text", "text": ""}]
                        ),
                        "role": "user",
                        "taskId": task_id,
                    }
                ],
                "id": task_id,
                "kind": "task",
                "status": {
                    "state": "submitted",
                },
            },
        }
        yield task_event

        # 2. Emit status update (kind: "status-update", status: "working")
        # Format matches A2ACompletionBridgeTransformation.create_status_update_event
        working_event = {
            "jsonrpc": "2.0",
            "id": request_id,
            "result": {
                "contextId": context_id,
                "final": False,
                "kind": "status-update",
                "status": {
                    "state": "working",
                },
                "taskId": task_id,
            },
        }
        yield working_event

        # Small delay to simulate processing
        await asyncio.sleep(delay_ms / 1000.0)

        # 3. Emit artifact update chunks (kind: "artifact-update")
        # Format matches A2ACompletionBridgeTransformation.create_artifact_update_event
        if full_text:
            # Split text into chunks
            for i in range(0, len(full_text), chunk_size):
                chunk_text = full_text[i : i + chunk_size]
                is_last_chunk = (i + chunk_size) >= len(full_text)

                artifact_event = {
                    "jsonrpc": "2.0",
                    "id": request_id,
                    "result": {
                        "contextId": context_id,
                        "kind": "artifact-update",
                        "taskId": task_id,
                        "artifact": {
                            "artifactId": artifact_id,
                            "parts": [
                                {
                                    "kind": "text",
                                    "text": chunk_text,
                                }
                            ],
                        },
                    },
                }
                yield artifact_event

                # Add delay between chunks (except for last chunk)
                if not is_last_chunk:
                    await asyncio.sleep(delay_ms / 1000.0)

        # 4. Emit final status update (kind: "status-update", status: "completed", final: true)
        completed_event = {
            "jsonrpc": "2.0",
            "id": request_id,
            "result": {
                "contextId": context_id,
                "final": True,
                "kind": "status-update",
                "status": {
                    "state": "completed",
                },
                "taskId": task_id,
            },
        }
        yield completed_event

        verbose_logger.info(
            f"Pydantic AI: Fake streaming completed for request_id={request_id}"
        )