chore: initial public snapshot for github upload

This commit is contained in:
Your Name
2026-03-26 20:06:14 +08:00
commit 0e5ecd930e
3497 changed files with 1586236 additions and 0 deletions

View File

@@ -0,0 +1,6 @@
"""
A2A (Agent-to-Agent) Protocol Provider for LiteLLM
"""
from .chat.transformation import A2AConfig
__all__ = ["A2AConfig"]

View File

@@ -0,0 +1,6 @@
"""
A2A Chat Completion Implementation
"""
from .transformation import A2AConfig
__all__ = ["A2AConfig"]

View File

@@ -0,0 +1,155 @@
# A2A Protocol Guardrail Translation Handler
Handler for processing A2A (Agent-to-Agent) Protocol messages with guardrails.
## Overview
This handler processes A2A JSON-RPC 2.0 input/output by:
1. Extracting text from message parts (`kind: "text"`)
2. Applying guardrails to text content
3. Mapping guardrailed text back to original structure
## A2A Protocol Format
### Input Format (JSON-RPC 2.0)
```json
{
"jsonrpc": "2.0",
"id": "request-id",
"method": "message/send",
"params": {
"message": {
"kind": "message",
"messageId": "...",
"role": "user",
"parts": [
{"kind": "text", "text": "Hello, my SSN is 123-45-6789"}
]
},
"metadata": {
"guardrails": ["block-ssn"]
}
}
}
```
### Output Formats
The handler supports multiple A2A response formats:
**Direct message:**
```json
{
"result": {
"kind": "message",
"parts": [{"kind": "text", "text": "Response text"}]
}
}
```
**Nested message:**
```json
{
"result": {
"message": {
"parts": [{"kind": "text", "text": "Response text"}]
}
}
}
```
**Task with artifacts:**
```json
{
"result": {
"kind": "task",
"artifacts": [
{"parts": [{"kind": "text", "text": "Artifact text"}]}
]
}
}
```
**Task with status message:**
```json
{
"result": {
"kind": "task",
"status": {
"message": {
"parts": [{"kind": "text", "text": "Status message"}]
}
}
}
}
```
**Streaming artifact-update:**
```json
{
"result": {
"kind": "artifact-update",
"artifact": {
"parts": [{"kind": "text", "text": "Streaming text"}]
}
}
}
```
## Usage
The handler is automatically discovered and applied when guardrails are used with A2A endpoints.
### Via LiteLLM Proxy
```bash
curl -X POST 'http://localhost:4000/a2a/my-agent' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer your-api-key' \
-d '{
"jsonrpc": "2.0",
"id": "1",
"method": "message/send",
"params": {
"message": {
"kind": "message",
"messageId": "msg-1",
"role": "user",
"parts": [{"kind": "text", "text": "Hello, my SSN is 123-45-6789"}]
},
"metadata": {
"guardrails": ["block-ssn"]
}
}
}'
```
### Specifying Guardrails
Guardrails can be specified in the A2A request via the `metadata.guardrails` field:
```json
{
"params": {
"message": {...},
"metadata": {
"guardrails": ["block-ssn", "pii-filter"]
}
}
}
```
## Extension
Override these methods to customize behavior:
- `_extract_texts_from_result()`: Custom text extraction from A2A responses
- `_extract_texts_from_parts()`: Custom text extraction from message parts
- `_apply_text_to_path()`: Custom application of guardrailed text
## Call Types
This handler is registered for:
- `CallTypes.send_message`: Synchronous A2A message sending
- `CallTypes.asend_message`: Asynchronous A2A message sending

View File

@@ -0,0 +1,11 @@
"""A2A Protocol handler for Unified Guardrails."""
from litellm.llms.a2a.chat.guardrail_translation.handler import A2AGuardrailHandler
from litellm.types.utils import CallTypes
guardrail_translation_mappings = {
CallTypes.send_message: A2AGuardrailHandler,
CallTypes.asend_message: A2AGuardrailHandler,
}
__all__ = ["guardrail_translation_mappings"]

View File

@@ -0,0 +1,428 @@
"""
A2A Protocol Handler for Unified Guardrails
This module provides guardrail translation support for A2A (Agent-to-Agent) Protocol.
It handles both JSON-RPC 2.0 input requests and output responses, extracting text
from message parts and applying guardrails.
A2A Protocol Format:
- Input: JSON-RPC 2.0 with params.message.parts containing text parts
- Output: JSON-RPC 2.0 with result containing message/artifact parts
"""
import json
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from litellm._logging import verbose_proxy_logger
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
from litellm.types.utils import GenericGuardrailAPIInputs
if TYPE_CHECKING:
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.proxy._types import UserAPIKeyAuth
class A2AGuardrailHandler(BaseTranslation):
    """
    Handler for processing A2A Protocol messages with guardrails.
    This class provides methods to:
    1. Process input messages (pre-call hook) - extracts text from A2A message parts
    2. Process output responses (post-call hook) - extracts text from A2A response parts
    A2A Message Format:
    - Input: params.message.parts[].text (where kind == "text")
    - Output: result.message.parts[].text or result.artifacts[].parts[].text
    """
    async def process_input_messages(
        self,
        data: dict,
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional["LiteLLMLoggingObj"] = None,
    ) -> Any:
        """
        Process A2A input messages by applying guardrails to text content.
        Extracts text from A2A message parts and applies guardrails.
        Args:
            data: The A2A JSON-RPC 2.0 request data
            guardrail_to_apply: The guardrail instance to apply
            litellm_logging_obj: Optional logging object
        Returns:
            Modified data with guardrails applied to text content
        """
        # A2A request format: { "params": { "message": { "parts": [...] } } }
        params = data.get("params", {})
        message = params.get("message", {})
        parts = message.get("parts", [])
        if not parts:
            verbose_proxy_logger.debug("A2A: No parts in message, skipping guardrail")
            return data
        texts_to_check: List[str] = []
        text_part_indices: List[int] = []  # Track which parts contain text
        # Step 1: Extract text from all text parts
        # Only parts with kind == "text" and a non-empty text field participate.
        for part_idx, part in enumerate(parts):
            if part.get("kind") == "text":
                text = part.get("text", "")
                if text:
                    texts_to_check.append(text)
                    text_part_indices.append(part_idx)
        # Step 2: Apply guardrail to all texts in batch
        if texts_to_check:
            inputs = GenericGuardrailAPIInputs(texts=texts_to_check)
            # Pass the structured A2A message to guardrails
            inputs["structured_messages"] = [message]
            # Include agent model info if available
            model = data.get("model")
            if model:
                inputs["model"] = model
            guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
                inputs=inputs,
                request_data=data,
                input_type="request",
                logging_obj=litellm_logging_obj,
            )
            guardrailed_texts = guardrailed_inputs.get("texts", [])
            # Step 3: Apply guardrailed text back to original parts
            # Length mismatch means the guardrail dropped/added texts; in that
            # case the original parts are left untouched.
            if guardrailed_texts and len(guardrailed_texts) == len(text_part_indices):
                for task_idx, part_idx in enumerate(text_part_indices):
                    # `parts` aliases data["params"]["message"]["parts"], so
                    # this mutates `data` in place.
                    parts[part_idx]["text"] = guardrailed_texts[task_idx]
        verbose_proxy_logger.debug("A2A: Processed input message: %s", message)
        return data
    async def process_output_response(
        self,
        response: Any,
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional["LiteLLMLoggingObj"] = None,
        user_api_key_dict: Optional["UserAPIKeyAuth"] = None,
    ) -> Any:
        """
        Process A2A output response by applying guardrails to text content.
        Handles multiple A2A response formats:
        - Direct message: {"result": {"kind": "message", "parts": [...]}}
        - Nested message: {"result": {"message": {"parts": [...]}}}
        - Task with artifacts: {"result": {"kind": "task", "artifacts": [{"parts": [...]}]}}
        - Task with status message: {"result": {"kind": "task", "status": {"message": {"parts": [...]}}}}
        Args:
            response: A2A JSON-RPC 2.0 response dict or object
            guardrail_to_apply: The guardrail instance to apply
            litellm_logging_obj: Optional logging object
            user_api_key_dict: User API key metadata
        Returns:
            Modified response with guardrails applied to text content
        """
        # Handle both dict and Pydantic model responses
        if hasattr(response, "model_dump"):
            response_dict = response.model_dump()
            is_pydantic = True
        elif isinstance(response, dict):
            response_dict = response
            is_pydantic = False
        else:
            verbose_proxy_logger.warning(
                "A2A: Unknown response type %s, skipping guardrail", type(response)
            )
            return response
        result = response_dict.get("result", {})
        if not result or not isinstance(result, dict):
            verbose_proxy_logger.debug("A2A: No result in response, skipping guardrail")
            return response
        # Find all text-containing parts in the response
        texts_to_check: List[str] = []
        # Each mapping is (path_to_parts_list, part_index)
        # path_to_parts_list is a tuple of keys to navigate to the parts list
        task_mappings: List[Tuple[Tuple[str, ...], int]] = []
        # Extract texts from all possible locations
        self._extract_texts_from_result(
            result=result,
            texts_to_check=texts_to_check,
            task_mappings=task_mappings,
        )
        if not texts_to_check:
            verbose_proxy_logger.debug("A2A: No text content in response")
            return response
        # Step 2: Apply guardrail to all texts in batch
        # Create a request_data dict with response info and user API key metadata
        request_data: dict = {"response": response_dict}
        # Add user API key metadata with prefixed keys
        user_metadata = self.transform_user_api_key_dict_to_metadata(user_api_key_dict)
        if user_metadata:
            request_data["litellm_metadata"] = user_metadata
        inputs = GenericGuardrailAPIInputs(texts=texts_to_check)
        guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
            inputs=inputs,
            request_data=request_data,
            input_type="response",
            logging_obj=litellm_logging_obj,
        )
        guardrailed_texts = guardrailed_inputs.get("texts", [])
        # Step 3: Apply guardrailed text back to original response
        if guardrailed_texts and len(guardrailed_texts) == len(task_mappings):
            for task_idx, (path, part_idx) in enumerate(task_mappings):
                self._apply_text_to_path(
                    result=result,
                    path=path,
                    part_idx=part_idx,
                    text=guardrailed_texts[task_idx],
                )
        verbose_proxy_logger.debug("A2A: Processed output response")
        # Update the original response
        if is_pydantic:
            # For Pydantic models, we need to update the underlying dict
            # and the model will reflect the changes
            # NOTE(review): model_dump() returns a *detached copy*, so mutating
            # response_dict here does not appear to write back into the Pydantic
            # model that is returned — guardrailed text may be lost for Pydantic
            # responses. TODO confirm and, if so, rebuild/assign the model fields.
            response_dict["result"] = result
            return response
        else:
            # Dict path: `result` aliases response["result"], so the response
            # was already mutated in place; the reassignment is a no-op safeguard.
            response["result"] = result
            return response
    async def process_output_streaming_response(
        self,
        responses_so_far: List[Any],
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional["LiteLLMLoggingObj"] = None,
        user_api_key_dict: Optional["UserAPIKeyAuth"] = None,
    ) -> List[Any]:
        """
        Process A2A streaming output by applying guardrails to accumulated text.
        responses_so_far can be a list of JSON-RPC 2.0 objects (dict or NDJSON str), e.g.:
        - task with history, status-update, artifact-update (with result.artifact.parts),
        - then status-update (final). Text is extracted from result.artifact.parts,
        result.message.parts, result.parts, etc., concatenated in order, guardrailed once,
        then the combined guardrailed text is written into the first chunk that had text
        and all other text parts in other chunks are cleared (in-place).
        """
        from litellm.llms.a2a.common_utils import extract_text_from_a2a_response
        # Parse each item; keep alignment with responses_so_far (None where unparseable)
        parsed: List[Optional[Dict[str, Any]]] = [None] * len(responses_so_far)
        for i, item in enumerate(responses_so_far):
            if isinstance(item, dict):
                obj = item
            elif isinstance(item, str):
                try:
                    obj = json.loads(item.strip())
                except (json.JSONDecodeError, TypeError):
                    continue
            else:
                # Unknown chunk type (not dict/str): leave it untouched.
                continue
            # Only chunks with a dict "result" can carry text parts.
            if isinstance(obj.get("result"), dict):
                parsed[i] = obj
        valid_parsed = [(i, obj) for i, obj in enumerate(parsed) if obj is not None]
        if not valid_parsed:
            return responses_so_far
        # Collect text from each chunk in order (by original index in responses_so_far)
        text_parts: List[str] = []
        chunk_indices_with_text: List[int] = []  # indices into valid_parsed
        for idx, (orig_i, obj) in enumerate(valid_parsed):
            t = extract_text_from_a2a_response(obj)
            if t:
                text_parts.append(t)
                chunk_indices_with_text.append(orig_i)
        combined_text = "".join(text_parts)
        if not combined_text:
            return responses_so_far
        request_data: dict = {"responses_so_far": responses_so_far}
        user_metadata = self.transform_user_api_key_dict_to_metadata(user_api_key_dict)
        if user_metadata:
            request_data["litellm_metadata"] = user_metadata
        # Guardrail is applied ONCE over the concatenation of all chunk texts.
        inputs = GenericGuardrailAPIInputs(texts=[combined_text])
        guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
            inputs=inputs,
            request_data=request_data,
            input_type="response",
            logging_obj=litellm_logging_obj,
        )
        guardrailed_texts = guardrailed_inputs.get("texts", [])
        if not guardrailed_texts:
            return responses_so_far
        guardrailed_text = guardrailed_texts[0]
        # Find first chunk (by original index) that has text; put full guardrailed text there and clear rest
        first_chunk_with_text: Optional[int] = (
            chunk_indices_with_text[0] if chunk_indices_with_text else None
        )
        for orig_i, obj in valid_parsed:
            result = obj.get("result", {})
            if not isinstance(result, dict):
                continue
            texts_in_chunk: List[str] = []
            mappings: List[Tuple[Tuple[str, ...], int]] = []
            self._extract_texts_from_result(
                result=result,
                texts_to_check=texts_in_chunk,
                task_mappings=mappings,
            )
            if not mappings:
                continue
            if orig_i == first_chunk_with_text:
                # Put full guardrailed text in first text part; clear others
                for task_idx, (path, part_idx) in enumerate(mappings):
                    text = guardrailed_text if task_idx == 0 else ""
                    self._apply_text_to_path(
                        result=result,
                        path=path,
                        part_idx=part_idx,
                        text=text,
                    )
            else:
                # Every other chunk's text parts are blanked so the combined
                # guardrailed output is not duplicated downstream.
                for path, part_idx in mappings:
                    self._apply_text_to_path(
                        result=result,
                        path=path,
                        part_idx=part_idx,
                        text="",
                    )
        # Write back to responses_so_far where we had NDJSON strings
        # (dict chunks were mutated in place and need no re-serialization).
        for i, item in enumerate(responses_so_far):
            if isinstance(item, str) and parsed[i] is not None:
                responses_so_far[i] = json.dumps(parsed[i]) + "\n"
        return responses_so_far
    def _extract_texts_from_result(
        self,
        result: Dict[str, Any],
        texts_to_check: List[str],
        task_mappings: List[Tuple[Tuple[str, ...], int]],
    ) -> None:
        """
        Extract text from all possible locations in an A2A result.

        Appends found texts to ``texts_to_check`` and, in lockstep, records
        a (path, part_index) entry in ``task_mappings`` so the guardrailed
        text can later be written back via ``_apply_text_to_path``.

        Handles multiple response formats:
        1. Direct message with parts: {"parts": [...]}
        2. Nested message: {"message": {"parts": [...]}}
        3. Task with artifacts: {"artifacts": [{"parts": [...]}]}
        4. Task with status message: {"status": {"message": {"parts": [...]}}}
        5. Streaming artifact-update: {"artifact": {"parts": [...]}}
        """
        # Case 1: Direct parts in result (direct message)
        if "parts" in result:
            self._extract_texts_from_parts(
                parts=result["parts"],
                path=("parts",),
                texts_to_check=texts_to_check,
                task_mappings=task_mappings,
            )
        # Case 2: Nested message
        message = result.get("message")
        if message and isinstance(message, dict) and "parts" in message:
            self._extract_texts_from_parts(
                parts=message["parts"],
                path=("message", "parts"),
                texts_to_check=texts_to_check,
                task_mappings=task_mappings,
            )
        # Case 3: Streaming artifact-update (singular artifact)
        artifact = result.get("artifact")
        if artifact and isinstance(artifact, dict) and "parts" in artifact:
            self._extract_texts_from_parts(
                parts=artifact["parts"],
                path=("artifact", "parts"),
                texts_to_check=texts_to_check,
                task_mappings=task_mappings,
            )
        # Case 4: Task with status message
        status = result.get("status", {})
        if isinstance(status, dict):
            status_message = status.get("message")
            if (
                status_message
                and isinstance(status_message, dict)
                and "parts" in status_message
            ):
                self._extract_texts_from_parts(
                    parts=status_message["parts"],
                    path=("status", "message", "parts"),
                    texts_to_check=texts_to_check,
                    task_mappings=task_mappings,
                )
        # Case 5: Task with artifacts (plural, array)
        artifacts = result.get("artifacts", [])
        if artifacts and isinstance(artifacts, list):
            for artifact_idx, art in enumerate(artifacts):
                if isinstance(art, dict) and "parts" in art:
                    self._extract_texts_from_parts(
                        parts=art["parts"],
                        # Array index is encoded as a string; decoded again
                        # by _apply_text_to_path via str.isdigit().
                        path=("artifacts", str(artifact_idx), "parts"),
                        texts_to_check=texts_to_check,
                        task_mappings=task_mappings,
                    )
    def _extract_texts_from_parts(
        self,
        parts: List[Dict[str, Any]],
        path: Tuple[str, ...],
        texts_to_check: List[str],
        task_mappings: List[Tuple[Tuple[str, ...], int]],
    ) -> None:
        """Extract text from message parts.

        Only parts with kind == "text" and a non-empty text field are
        recorded; each recorded text gets a matching (path, part_idx) entry.
        """
        for part_idx, part in enumerate(parts):
            if part.get("kind") == "text":
                text = part.get("text", "")
                if text:
                    texts_to_check.append(text)
                    task_mappings.append((path, part_idx))
    def _apply_text_to_path(
        self,
        result: Dict[Union[str, int], Any],
        path: Tuple[str, ...],
        part_idx: int,
        text: str,
    ) -> None:
        """Apply guardrailed text back to the specified path in the result."""
        # Navigate to the parts list
        current = result
        for key in path:
            if key.isdigit():
                # Array index (e.g. the artifact index recorded as a string
                # by _extract_texts_from_result)
                current = current[int(key)]
            else:
                current = current[key]
        # Update the text in the part
        current[part_idx]["text"] = text

View File

@@ -0,0 +1,105 @@
"""
A2A Streaming Response Iterator
"""
from typing import Optional, Union
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.types.utils import GenericStreamingChunk, ModelResponseStream
from ..common_utils import extract_text_from_a2a_response
class A2AModelResponseIterator(BaseModelResponseIterator):
    """
    Streaming iterator for A2A JSON-RPC responses.

    Translates each incoming A2A chunk into an OpenAI-compatible
    ``GenericStreamingChunk`` so downstream streaming handlers can
    consume A2A agents transparently.
    """

    def __init__(
        self,
        streaming_response,
        sync_stream: bool,
        json_mode: Optional[bool] = False,
        model: str = "a2a/agent",
    ):
        # The base iterator owns the raw stream; we only retain the model
        # name for bookkeeping.
        super().__init__(
            streaming_response=streaming_response,
            sync_stream=sync_stream,
            json_mode=json_mode,
        )
        self.model = model

    def chunk_parser(
        self, chunk: dict
    ) -> Union[GenericStreamingChunk, ModelResponseStream]:
        """
        Convert one A2A JSON-RPC chunk to OpenAI streaming format.

        Text is pulled from the chunk's message/artifact parts via
        ``extract_text_from_a2a_response``; the finish reason comes from
        the task status (see ``_get_finish_reason``). A chunk that cannot
        be parsed yields an empty, non-final delta instead of raising, so
        a single bad chunk does not break the stream.
        """
        try:
            delta_text = extract_text_from_a2a_response(chunk)
            reason = self._get_finish_reason(chunk)
            return GenericStreamingChunk(
                text=delta_text,
                is_finished=reason is not None,
                finish_reason=reason if reason is not None else "",
                usage=None,
                index=0,
                tool_use=None,
            )
        except Exception:
            # Best-effort parsing: emit an empty delta on any failure.
            return GenericStreamingChunk(
                text="",
                is_finished=False,
                finish_reason="",
                usage=None,
                index=0,
                tool_use=None,
            )

    def _get_finish_reason(self, chunk: dict) -> Optional[str]:
        """Return "stop" when the chunk marks the stream as finished, else None."""
        jsonrpc_result = chunk.get("result", {})
        if isinstance(jsonrpc_result, dict):
            status = jsonrpc_result.get("status", {})
            # Both "completed" and "failed" terminate the stream; "failed"
            # maps to "stop" because that is the only valid finish_reason.
            if isinstance(status, dict) and status.get("state") in ("completed", "failed"):
                return "stop"
        # Some servers send an explicit done marker instead of a task state.
        return "stop" if chunk.get("done") is True else None

View File

@@ -0,0 +1,373 @@
"""
A2A Protocol Transformation for LiteLLM
"""
import uuid
from typing import Any, Dict, Iterator, List, Optional, Union
import httpx
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import Choices, Message, ModelResponse
from ..common_utils import (
A2AError,
convert_messages_to_prompt,
extract_text_from_a2a_response,
)
from .streaming_iterator import A2AModelResponseIterator
class A2AConfig(BaseConfig):
    """
    Configuration for A2A (Agent-to-Agent) Protocol.
    Handles transformation between OpenAI and A2A JSON-RPC 2.0 formats.
    """
    @staticmethod
    def resolve_agent_config_from_registry(
        model: str,
        api_base: Optional[str],
        api_key: Optional[str],
        headers: Optional[Dict[str, Any]],
        optional_params: Dict[str, Any],
    ) -> tuple[Optional[str], Optional[str], Optional[Dict[str, Any]]]:
        """
        Resolve agent configuration from registry if model format is "a2a/<agent-name>".
        Extracts agent name from model string and looks up configuration in the
        agent registry (if available in proxy context).
        Args:
            model: Model string (e.g., "a2a/my-agent")
            api_base: Explicit api_base (takes precedence over registry)
            api_key: Explicit api_key (takes precedence over registry)
            headers: Explicit headers (takes precedence over registry)
            optional_params: Dict to merge additional litellm_params into
                (mutated in place)
        Returns:
            Tuple of (api_base, api_key, headers) with registry values filled in
        """
        # Extract agent name from model (e.g., "a2a/my-agent" -> "my-agent")
        agent_name = model.split("/", 1)[1] if "/" in model else None
        # Only lookup if agent name exists and some config is missing
        if not agent_name or (
            api_base is not None and api_key is not None and headers is not None
        ):
            return api_base, api_key, headers
        # Try registry lookup (only available in proxy context)
        try:
            from litellm.proxy.agent_endpoints.agent_registry import (
                global_agent_registry,
            )
            agent = global_agent_registry.get_agent_by_name(agent_name)
            if agent:
                # Get api_base from agent card URL
                if api_base is None and agent.agent_card_params:
                    api_base = agent.agent_card_params.get("url")
                # Get api_key, headers, and other params from litellm_params
                if agent.litellm_params:
                    if api_key is None:
                        api_key = agent.litellm_params.get("api_key")
                    if headers is None:
                        agent_headers = agent.litellm_params.get("headers")
                        if agent_headers:
                            headers = agent_headers
                    # Merge other litellm_params (timeout, max_retries, etc.)
                    # Explicit optional_params keys win over registry values.
                    for key, value in agent.litellm_params.items():
                        if (
                            key not in ["api_key", "api_base", "headers", "model"]
                            and key not in optional_params
                        ):
                            optional_params[key] = value
        except ImportError:
            pass  # Registry not available (not running in proxy context)
        return api_base, api_key, headers
    def get_supported_openai_params(self, model: str) -> List[str]:
        """Return list of supported OpenAI parameters"""
        # NOTE(review): temperature/max_tokens/top_p are advertised here but
        # map_openai_params only forwards "stream" — confirm whether the other
        # params are intentionally dropped for the A2A protocol.
        return [
            "stream",
            "temperature",
            "max_tokens",
            "top_p",
        ]
    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        Map OpenAI parameters to A2A parameters.
        For A2A protocol, we need to map the stream parameter so
        transform_request can determine which JSON-RPC method to use.
        """
        # Map stream parameter (only param that affects the JSON-RPC method)
        for param, value in non_default_params.items():
            if param == "stream" and value is True:
                optional_params["stream"] = value
        return optional_params
    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """
        Validate environment and set headers for A2A requests.
        Args:
            headers: Request headers dict (mutated in place)
            model: Model name
            messages: Messages list
            optional_params: Optional parameters
            litellm_params: LiteLLM parameters
            api_key: API key (optional for A2A)
            api_base: API base URL
        Returns:
            Updated headers dict
        """
        # Ensure Content-Type is set to application/json for JSON-RPC 2.0
        if "content-type" not in headers and "Content-Type" not in headers:
            headers["Content-Type"] = "application/json"
        # Add Authorization header if API key is provided
        if api_key is not None:
            headers["Authorization"] = f"Bearer {api_key}"
        return headers
    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Get the complete A2A agent endpoint URL.
        A2A agents use JSON-RPC 2.0 at the base URL, not specific paths.
        The method (message/send or message/stream) is specified in the
        JSON-RPC request body, not in the URL.
        Args:
            api_base: Base URL of the A2A agent (e.g., "http://0.0.0.0:9999")
            api_key: API key (not used for URL construction)
            model: Model name (not used for A2A, agent determined by api_base)
            optional_params: Optional parameters
            litellm_params: LiteLLM parameters
            stream: Whether this is a streaming request (affects JSON-RPC method)
        Returns:
            Complete URL for the A2A endpoint (base URL)
        Raises:
            ValueError: If api_base is not provided.
        """
        if api_base is None:
            raise ValueError("api_base is required for A2A provider")
        # A2A uses JSON-RPC 2.0 at the base URL
        # Remove trailing slash for consistency
        return api_base.rstrip("/")
    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Transform OpenAI request to A2A JSON-RPC 2.0 format.

        The full OpenAI conversation history is flattened into a single
        A2A user message (role-prefixed lines), since A2A sends one
        message per request.
        Args:
            model: Model name
            messages: List of OpenAI messages
            optional_params: Optional parameters
            litellm_params: LiteLLM parameters
            headers: Request headers
        Returns:
            A2A JSON-RPC 2.0 request dict
        Raises:
            ValueError: If messages is empty.
        """
        # Generate request ID
        request_id = str(uuid.uuid4())
        if not messages:
            raise ValueError("At least one message is required for A2A completion")
        # Convert all messages to maintain conversation history
        # Use helper to format conversation with role prefixes
        full_context = convert_messages_to_prompt(messages)
        # Create single A2A message with full conversation context
        a2a_message = {
            "role": "user",
            "parts": [{"kind": "text", "text": full_context}],
            "messageId": str(uuid.uuid4()),
        }
        # Build JSON-RPC 2.0 request
        # For A2A protocol, the method is "message/send" for non-streaming
        # and "message/stream" for streaming
        stream = optional_params.get("stream", False)
        method = "message/stream" if stream else "message/send"
        request_data = {
            "jsonrpc": "2.0",
            "id": request_id,
            "method": method,
            "params": {"message": a2a_message},
        }
        return request_data
    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: Any,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """
        Transform A2A JSON-RPC 2.0 response to OpenAI format.
        Args:
            model: Model name
            raw_response: HTTP response from A2A agent
            model_response: Model response object to populate
            logging_obj: Logging object
            request_data: Original request data
            messages: Original messages
            optional_params: Optional parameters
            litellm_params: LiteLLM parameters
            encoding: Encoding object
            api_key: API key
            json_mode: JSON mode flag
        Returns:
            Populated ModelResponse object
        Raises:
            A2AError: If the body is not valid JSON or carries a JSON-RPC error.
        """
        try:
            response_json = raw_response.json()
        except Exception as e:
            raise A2AError(
                status_code=raw_response.status_code,
                message=f"Failed to parse A2A response: {str(e)}",
                headers=dict(raw_response.headers),
            )
        # Check for JSON-RPC error
        if "error" in response_json:
            error = response_json["error"]
            raise A2AError(
                status_code=raw_response.status_code,
                message=f"A2A error: {error.get('message', 'Unknown error')}",
                headers=dict(raw_response.headers),
            )
        # Extract text from A2A response
        text = extract_text_from_a2a_response(response_json)
        # Populate model response (single choice; A2A has no n>1 concept here)
        model_response.choices = [
            Choices(
                finish_reason="stop",
                index=0,
                message=Message(
                    content=text,
                    role="assistant",
                ),
            )
        ]
        # Set model
        model_response.model = model
        # Set ID from response (fall back to a fresh UUID if the agent omits it)
        model_response.id = response_json.get("id", str(uuid.uuid4()))
        return model_response
    def get_model_response_iterator(
        self,
        streaming_response: Union[Iterator, Any],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ) -> BaseModelResponseIterator:
        """
        Get streaming iterator for A2A responses.
        Args:
            streaming_response: Streaming response iterator
            sync_stream: Whether this is a sync stream
            json_mode: JSON mode flag
        Returns:
            A2A streaming iterator
        """
        return A2AModelResponseIterator(
            streaming_response=streaming_response,
            sync_stream=sync_stream,
            json_mode=json_mode,
        )
    def _openai_message_to_a2a_message(self, message: Dict[str, Any]) -> Dict[str, Any]:
        """
        Convert OpenAI message to A2A message format.

        NOTE(review): not referenced within this class (transform_request uses
        convert_messages_to_prompt instead) — presumably kept for external
        callers; verify before removing.
        Args:
            message: OpenAI message dict
        Returns:
            A2A message dict
        """
        content = message.get("content", "")
        role = message.get("role", "user")
        return {
            "role": role,
            "parts": [{"kind": "text", "text": str(content)}],
            "messageId": str(uuid.uuid4()),
        }
    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """Return appropriate error class for A2A errors"""
        # Convert headers to dict if needed
        headers_dict = dict(headers) if isinstance(headers, httpx.Headers) else headers
        return A2AError(
            status_code=status_code,
            message=error_message,
            headers=headers_dict,
        )

View File

@@ -0,0 +1,150 @@
"""
Common utilities for A2A (Agent-to-Agent) Protocol
"""
from typing import Any, Dict, List, Optional

from pydantic import BaseModel

from litellm.litellm_core_utils.prompt_templates.common_utils import (
    convert_content_list_to_str,
)
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.llms.openai import AllMessageValues
class A2AError(BaseLLMException):
    """Base exception for A2A protocol errors.

    Thin wrapper over ``BaseLLMException`` carrying the HTTP status code,
    a human-readable message, and any response headers.
    """

    def __init__(
        self,
        status_code: int,
        message: str,
        headers: Optional[Dict[str, Any]] = None,
    ):
        """
        Args:
            status_code: HTTP status code returned by the A2A agent.
            message: Human-readable error description.
            headers: Response headers, if available. Defaults to an empty
                dict. (A ``None`` sentinel is used instead of a mutable
                ``{}`` default, which would be shared across all calls.)
        """
        super().__init__(
            status_code=status_code,
            message=message,
            headers=headers if headers is not None else {},
        )
def convert_messages_to_prompt(messages: List[AllMessageValues]) -> str:
    """
    Flatten an OpenAI-style message list into one prompt string.

    Each message becomes a ``"{role}: {content}"`` line; messages whose
    content resolves to an empty string are dropped. Both plain-dict and
    Pydantic message objects are accepted, and content may be a string or
    a content-part list.

    Args:
        messages: OpenAI-format conversation history.

    Returns:
        Newline-joined conversation transcript.
    """

    def _role_of(message: Any) -> str:
        # Normalise the message to a dict view before reading the role.
        if isinstance(message, BaseModel):
            return message.model_dump().get("role", "user")
        if isinstance(message, dict):
            return message.get("role", "user")
        return dict(message).get("role", "user")  # type: ignore

    lines: List[str] = []
    for msg in messages:
        # LiteLLM helper handles both str content and content-part lists.
        text = convert_content_list_to_str(message=msg)
        role = _role_of(msg)
        if text:
            lines.append(f"{role}: {text}")
    return "\n".join(lines)
def extract_text_from_a2a_message(
    message: Dict[str, Any], depth: int = 0, max_depth: int = 10
) -> str:
    """
    Collect the text content of an A2A message's parts.

    Parts with ``kind == "text"`` contribute their ``text`` field; parts
    that instead contain their own ``parts`` list are descended into, up
    to ``max_depth`` levels (guarding against cyclic/degenerate payloads).

    Args:
        message: A2A message dict expected to hold a ``parts`` list.
        depth: Current recursion depth (internal use).
        max_depth: Recursion cut-off.

    Returns:
        Space-joined text drawn from all text parts.
    """
    if message is None or depth >= max_depth:
        return ""
    collected: List[str] = []
    for part in message.get("parts", []):
        if part.get("kind") == "text":
            collected.append(part.get("text", ""))
        elif "parts" in part:
            # Container part: recurse one level deeper.
            nested = extract_text_from_a2a_message(part, depth + 1, max_depth)
            if nested:
                collected.append(nested)
    return " ".join(collected)
def extract_text_from_a2a_response(
    response_dict: Dict[str, Any], max_depth: int = 10
) -> str:
    """
    Pull the text content out of an A2A JSON-RPC response's result.

    Several result shapes are supported, probed in this order (first
    match wins):
      1. direct message       -> result.parts
      2. nested message       -> result.message.parts
      3. artifact-update      -> result.artifact.parts
      4. task status message  -> result.status.message.parts
      5. task artifacts array -> result.artifacts[0].parts

    Args:
        response_dict: Full JSON-RPC response (with a ``result`` key).
        max_depth: Recursion limit forwarded to the part walker.

    Returns:
        Extracted text, or "" when no recognised shape contains text.
    """
    result = response_dict.get("result", {})
    if not isinstance(result, dict):
        return ""
    # Shape 1: the result itself is a message carrying parts.
    if "parts" in result:
        return extract_text_from_a2a_message(result, depth=0, max_depth=max_depth)
    # Shape 2: a message wrapped inside the result.
    nested_message = result.get("message")
    if nested_message:
        return extract_text_from_a2a_message(
            nested_message, depth=0, max_depth=max_depth
        )
    # Shape 3: streaming artifact-update carries a single artifact.
    single_artifact = result.get("artifact")
    if single_artifact and isinstance(single_artifact, dict):
        return extract_text_from_a2a_message(
            single_artifact, depth=0, max_depth=max_depth
        )
    # Shape 4: task results may tuck the message under status (seen with
    # Gemini A2A agents).
    status = result.get("status", {})
    if isinstance(status, dict):
        status_message = status.get("message")
        if status_message:
            return extract_text_from_a2a_message(
                status_message, depth=0, max_depth=max_depth
            )
    # Shape 5: completed tasks expose artifacts; only the first is read.
    artifacts = result.get("artifacts", [])
    if artifacts and len(artifacts) > 0:
        return extract_text_from_a2a_message(
            artifacts[0], depth=0, max_depth=max_depth
        )
    return ""