chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,6 @@
|
||||
"""
|
||||
A2A (Agent-to-Agent) Protocol Provider for LiteLLM
|
||||
"""
|
||||
from .chat.transformation import A2AConfig
|
||||
|
||||
__all__ = ["A2AConfig"]
|
||||
@@ -0,0 +1,6 @@
|
||||
"""
|
||||
A2A Chat Completion Implementation
|
||||
"""
|
||||
from .transformation import A2AConfig
|
||||
|
||||
__all__ = ["A2AConfig"]
|
||||
@@ -0,0 +1,155 @@
|
||||
# A2A Protocol Guardrail Translation Handler
|
||||
|
||||
Handler for processing A2A (Agent-to-Agent) Protocol messages with guardrails.
|
||||
|
||||
## Overview
|
||||
|
||||
This handler processes A2A JSON-RPC 2.0 input/output by:
|
||||
1. Extracting text from message parts (`kind: "text"`)
|
||||
2. Applying guardrails to text content
|
||||
3. Mapping guardrailed text back to original structure
|
||||
|
||||
## A2A Protocol Format
|
||||
|
||||
### Input Format (JSON-RPC 2.0)
|
||||
|
||||
```json
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": "request-id",
|
||||
"method": "message/send",
|
||||
"params": {
|
||||
"message": {
|
||||
"kind": "message",
|
||||
"messageId": "...",
|
||||
"role": "user",
|
||||
"parts": [
|
||||
{"kind": "text", "text": "Hello, my SSN is 123-45-6789"}
|
||||
]
|
||||
},
|
||||
"metadata": {
|
||||
"guardrails": ["block-ssn"]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Output Formats
|
||||
|
||||
The handler supports multiple A2A response formats:
|
||||
|
||||
**Direct message:**
|
||||
```json
|
||||
{
|
||||
"result": {
|
||||
"kind": "message",
|
||||
"parts": [{"kind": "text", "text": "Response text"}]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Nested message:**
|
||||
```json
|
||||
{
|
||||
"result": {
|
||||
"message": {
|
||||
"parts": [{"kind": "text", "text": "Response text"}]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Task with artifacts:**
|
||||
```json
|
||||
{
|
||||
"result": {
|
||||
"kind": "task",
|
||||
"artifacts": [
|
||||
{"parts": [{"kind": "text", "text": "Artifact text"}]}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Task with status message:**
|
||||
```json
|
||||
{
|
||||
"result": {
|
||||
"kind": "task",
|
||||
"status": {
|
||||
"message": {
|
||||
"parts": [{"kind": "text", "text": "Status message"}]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Streaming artifact-update:**
|
||||
```json
|
||||
{
|
||||
"result": {
|
||||
"kind": "artifact-update",
|
||||
"artifact": {
|
||||
"parts": [{"kind": "text", "text": "Streaming text"}]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
The handler is automatically discovered and applied when guardrails are used with A2A endpoints.
|
||||
|
||||
### Via LiteLLM Proxy
|
||||
|
||||
```bash
|
||||
curl -X POST 'http://localhost:4000/a2a/my-agent' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-H 'Authorization: Bearer your-api-key' \
|
||||
-d '{
|
||||
"jsonrpc": "2.0",
|
||||
"id": "1",
|
||||
"method": "message/send",
|
||||
"params": {
|
||||
"message": {
|
||||
"kind": "message",
|
||||
"messageId": "msg-1",
|
||||
"role": "user",
|
||||
"parts": [{"kind": "text", "text": "Hello, my SSN is 123-45-6789"}]
|
||||
},
|
||||
"metadata": {
|
||||
"guardrails": ["block-ssn"]
|
||||
}
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
### Specifying Guardrails
|
||||
|
||||
Guardrails can be specified in the A2A request via the `metadata.guardrails` field:
|
||||
|
||||
```json
|
||||
{
|
||||
"params": {
|
||||
"message": {...},
|
||||
"metadata": {
|
||||
"guardrails": ["block-ssn", "pii-filter"]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Extension
|
||||
|
||||
Override these methods to customize behavior:
|
||||
|
||||
- `_extract_texts_from_result()`: Custom text extraction from A2A responses
|
||||
- `_extract_texts_from_parts()`: Custom text extraction from message parts
|
||||
- `_apply_text_to_path()`: Custom application of guardrailed text
|
||||
|
||||
## Call Types
|
||||
|
||||
This handler is registered for:
|
||||
- `CallTypes.send_message`: Synchronous A2A message sending
|
||||
- `CallTypes.asend_message`: Asynchronous A2A message sending
|
||||
@@ -0,0 +1,11 @@
|
||||
"""A2A Protocol handler for Unified Guardrails."""
|
||||
|
||||
from litellm.llms.a2a.chat.guardrail_translation.handler import A2AGuardrailHandler
|
||||
from litellm.types.utils import CallTypes
|
||||
|
||||
guardrail_translation_mappings = {
|
||||
CallTypes.send_message: A2AGuardrailHandler,
|
||||
CallTypes.asend_message: A2AGuardrailHandler,
|
||||
}
|
||||
|
||||
__all__ = ["guardrail_translation_mappings"]
|
||||
@@ -0,0 +1,428 @@
|
||||
"""
|
||||
A2A Protocol Handler for Unified Guardrails
|
||||
|
||||
This module provides guardrail translation support for A2A (Agent-to-Agent) Protocol.
|
||||
It handles both JSON-RPC 2.0 input requests and output responses, extracting text
|
||||
from message parts and applying guardrails.
|
||||
|
||||
A2A Protocol Format:
|
||||
- Input: JSON-RPC 2.0 with params.message.parts containing text parts
|
||||
- Output: JSON-RPC 2.0 with result containing message/artifact parts
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
|
||||
from litellm.types.utils import GenericGuardrailAPIInputs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
|
||||
|
||||
class A2AGuardrailHandler(BaseTranslation):
    """
    Handler for processing A2A Protocol messages with guardrails.

    This class provides methods to:
    1. Process input messages (pre-call hook) - extracts text from A2A message parts
    2. Process output responses (post-call hook) - extracts text from A2A response parts
    3. Process streaming output - guardrails the concatenated text of all chunks

    A2A Message Format:
    - Input: params.message.parts[].text (where kind == "text")
    - Output: result.message.parts[].text or result.artifacts[].parts[].text
    """

    async def process_input_messages(
        self,
        data: dict,
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional["LiteLLMLoggingObj"] = None,
    ) -> Any:
        """
        Process A2A input messages by applying guardrails to text content.

        Extracts text from A2A message parts, runs the guardrail over all
        texts in one batch, then writes the guardrailed texts back into the
        same parts in place (``data`` is mutated).

        Args:
            data: The A2A JSON-RPC 2.0 request data
            guardrail_to_apply: The guardrail instance to apply
            litellm_logging_obj: Optional logging object

        Returns:
            Modified data with guardrails applied to text content
        """
        # A2A request format: { "params": { "message": { "parts": [...] } } }
        params = data.get("params", {})
        message = params.get("message", {})
        parts = message.get("parts", [])

        if not parts:
            verbose_proxy_logger.debug("A2A: No parts in message, skipping guardrail")
            return data

        texts_to_check: List[str] = []
        text_part_indices: List[int] = []  # Track which parts contain text

        # Step 1: Extract text from all text parts (non-text parts are left alone)
        for part_idx, part in enumerate(parts):
            if part.get("kind") == "text":
                text = part.get("text", "")
                if text:
                    texts_to_check.append(text)
                    text_part_indices.append(part_idx)

        # Step 2: Apply guardrail to all texts in batch
        if texts_to_check:
            inputs = GenericGuardrailAPIInputs(texts=texts_to_check)

            # Pass the structured A2A message to guardrails
            inputs["structured_messages"] = [message]

            # Include agent model info if available
            model = data.get("model")
            if model:
                inputs["model"] = model

            guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
                inputs=inputs,
                request_data=data,
                input_type="request",
                logging_obj=litellm_logging_obj,
            )

            guardrailed_texts = guardrailed_inputs.get("texts", [])

            # Step 3: Apply guardrailed text back to original parts.
            # Only write back when counts line up 1:1; otherwise the mapping
            # between guardrailed texts and part indices is ambiguous and the
            # original parts are kept unchanged.
            if guardrailed_texts and len(guardrailed_texts) == len(text_part_indices):
                for task_idx, part_idx in enumerate(text_part_indices):
                    parts[part_idx]["text"] = guardrailed_texts[task_idx]

            verbose_proxy_logger.debug("A2A: Processed input message: %s", message)

        return data

    async def process_output_response(
        self,
        response: Any,
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional["LiteLLMLoggingObj"] = None,
        user_api_key_dict: Optional["UserAPIKeyAuth"] = None,
    ) -> Any:
        """
        Process A2A output response by applying guardrails to text content.

        Handles multiple A2A response formats:
        - Direct message: {"result": {"kind": "message", "parts": [...]}}
        - Nested message: {"result": {"message": {"parts": [...]}}}
        - Task with artifacts: {"result": {"kind": "task", "artifacts": [{"parts": [...]}]}}
        - Task with status message: {"result": {"kind": "task", "status": {"message": {"parts": [...]}}}}

        Args:
            response: A2A JSON-RPC 2.0 response dict or object
            guardrail_to_apply: The guardrail instance to apply
            litellm_logging_obj: Optional logging object
            user_api_key_dict: User API key metadata

        Returns:
            Modified response with guardrails applied to text content
        """
        # Handle both dict and Pydantic model responses
        if hasattr(response, "model_dump"):
            response_dict = response.model_dump()
            is_pydantic = True
        elif isinstance(response, dict):
            response_dict = response
            is_pydantic = False
        else:
            verbose_proxy_logger.warning(
                "A2A: Unknown response type %s, skipping guardrail", type(response)
            )
            return response

        result = response_dict.get("result", {})
        if not result or not isinstance(result, dict):
            verbose_proxy_logger.debug("A2A: No result in response, skipping guardrail")
            return response

        # Find all text-containing parts in the response
        texts_to_check: List[str] = []
        # Each mapping is (path_to_parts_list, part_index)
        # path_to_parts_list is a tuple of keys to navigate to the parts list
        task_mappings: List[Tuple[Tuple[str, ...], int]] = []

        # Extract texts from all possible locations
        self._extract_texts_from_result(
            result=result,
            texts_to_check=texts_to_check,
            task_mappings=task_mappings,
        )

        if not texts_to_check:
            verbose_proxy_logger.debug("A2A: No text content in response")
            return response

        # Step 2: Apply guardrail to all texts in batch
        # Create a request_data dict with response info and user API key metadata
        request_data: dict = {"response": response_dict}

        # Add user API key metadata with prefixed keys
        user_metadata = self.transform_user_api_key_dict_to_metadata(user_api_key_dict)
        if user_metadata:
            request_data["litellm_metadata"] = user_metadata

        inputs = GenericGuardrailAPIInputs(texts=texts_to_check)

        guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
            inputs=inputs,
            request_data=request_data,
            input_type="response",
            logging_obj=litellm_logging_obj,
        )

        guardrailed_texts = guardrailed_inputs.get("texts", [])

        # Step 3: Apply guardrailed text back to original response
        # (only when the guardrail returned exactly one text per mapping).
        if guardrailed_texts and len(guardrailed_texts) == len(task_mappings):
            for task_idx, (path, part_idx) in enumerate(task_mappings):
                self._apply_text_to_path(
                    result=result,
                    path=path,
                    part_idx=part_idx,
                    text=guardrailed_texts[task_idx],
                )

            verbose_proxy_logger.debug("A2A: Processed output response")

        # Update the original response
        if is_pydantic:
            # NOTE(review): model_dump() returns a *copy*; mutating
            # response_dict / result here does not write back into the
            # Pydantic model, so the original `response` object is returned
            # without the guardrailed text. TODO confirm whether the
            # guardrailed fields should be re-applied to the model instance.
            response_dict["result"] = result
            return response
        else:
            # Dict responses were mutated in place; re-assign for clarity.
            response["result"] = result
            return response

    async def process_output_streaming_response(
        self,
        responses_so_far: List[Any],
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional["LiteLLMLoggingObj"] = None,
        user_api_key_dict: Optional["UserAPIKeyAuth"] = None,
    ) -> List[Any]:
        """
        Process A2A streaming output by applying guardrails to accumulated text.

        responses_so_far can be a list of JSON-RPC 2.0 objects (dict or NDJSON str), e.g.:
        - task with history, status-update, artifact-update (with result.artifact.parts),
        - then status-update (final). Text is extracted from result.artifact.parts,
        result.message.parts, result.parts, etc., concatenated in order, guardrailed once,
        then the combined guardrailed text is written into the first chunk that had text
        and all other text parts in other chunks are cleared (in-place).
        """
        # Local import to avoid a circular import at module load time
        # (common_utils is part of the same a2a package).
        from litellm.llms.a2a.common_utils import extract_text_from_a2a_response

        # Parse each item; keep alignment with responses_so_far (None where unparseable)
        parsed: List[Optional[Dict[str, Any]]] = [None] * len(responses_so_far)
        for i, item in enumerate(responses_so_far):
            if isinstance(item, dict):
                obj = item
            elif isinstance(item, str):
                try:
                    obj = json.loads(item.strip())
                except (json.JSONDecodeError, TypeError):
                    continue
            else:
                # Unknown chunk type: leave it untouched.
                continue
            # Only chunks with a dict "result" participate in guardrailing.
            if isinstance(obj.get("result"), dict):
                parsed[i] = obj

        valid_parsed = [(i, obj) for i, obj in enumerate(parsed) if obj is not None]
        if not valid_parsed:
            return responses_so_far

        # Collect text from each chunk in order (by original index in responses_so_far)
        text_parts: List[str] = []
        chunk_indices_with_text: List[int] = []  # indices into valid_parsed
        for idx, (orig_i, obj) in enumerate(valid_parsed):
            t = extract_text_from_a2a_response(obj)
            if t:
                text_parts.append(t)
                chunk_indices_with_text.append(orig_i)

        combined_text = "".join(text_parts)
        if not combined_text:
            return responses_so_far

        request_data: dict = {"responses_so_far": responses_so_far}
        user_metadata = self.transform_user_api_key_dict_to_metadata(user_api_key_dict)
        if user_metadata:
            request_data["litellm_metadata"] = user_metadata

        # Guardrail the whole accumulated text once (not per chunk).
        inputs = GenericGuardrailAPIInputs(texts=[combined_text])
        guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
            inputs=inputs,
            request_data=request_data,
            input_type="response",
            logging_obj=litellm_logging_obj,
        )
        guardrailed_texts = guardrailed_inputs.get("texts", [])
        if not guardrailed_texts:
            return responses_so_far
        guardrailed_text = guardrailed_texts[0]

        # Find first chunk (by original index) that has text; put full guardrailed text there and clear rest
        first_chunk_with_text: Optional[int] = (
            chunk_indices_with_text[0] if chunk_indices_with_text else None
        )

        for orig_i, obj in valid_parsed:
            result = obj.get("result", {})
            if not isinstance(result, dict):
                continue
            texts_in_chunk: List[str] = []
            mappings: List[Tuple[Tuple[str, ...], int]] = []
            self._extract_texts_from_result(
                result=result,
                texts_to_check=texts_in_chunk,
                task_mappings=mappings,
            )
            if not mappings:
                continue
            if orig_i == first_chunk_with_text:
                # Put full guardrailed text in first text part; clear others
                for task_idx, (path, part_idx) in enumerate(mappings):
                    text = guardrailed_text if task_idx == 0 else ""
                    self._apply_text_to_path(
                        result=result,
                        path=path,
                        part_idx=part_idx,
                        text=text,
                    )
            else:
                # All other chunks have their text parts blanked so the
                # combined text is not duplicated across the stream.
                for path, part_idx in mappings:
                    self._apply_text_to_path(
                        result=result,
                        path=path,
                        part_idx=part_idx,
                        text="",
                    )

        # Write back to responses_so_far where we had NDJSON strings
        # (dict chunks were mutated in place and need no re-serialization).
        for i, item in enumerate(responses_so_far):
            if isinstance(item, str) and parsed[i] is not None:
                responses_so_far[i] = json.dumps(parsed[i]) + "\n"

        return responses_so_far

    def _extract_texts_from_result(
        self,
        result: Dict[str, Any],
        texts_to_check: List[str],
        task_mappings: List[Tuple[Tuple[str, ...], int]],
    ) -> None:
        """
        Extract text from all possible locations in an A2A result.

        Appends found texts to ``texts_to_check`` and, in lockstep, a
        (path, part_index) entry to ``task_mappings`` so the guardrailed
        text can later be written back via ``_apply_text_to_path``.

        Handles multiple response formats:
        1. Direct message with parts: {"parts": [...]}
        2. Nested message: {"message": {"parts": [...]}}
        3. Task with artifacts: {"artifacts": [{"parts": [...]}]}
        4. Task with status message: {"status": {"message": {"parts": [...]}}}
        5. Streaming artifact-update: {"artifact": {"parts": [...]}}
        """
        # Case 1: Direct parts in result (direct message)
        if "parts" in result:
            self._extract_texts_from_parts(
                parts=result["parts"],
                path=("parts",),
                texts_to_check=texts_to_check,
                task_mappings=task_mappings,
            )

        # Case 2: Nested message
        message = result.get("message")
        if message and isinstance(message, dict) and "parts" in message:
            self._extract_texts_from_parts(
                parts=message["parts"],
                path=("message", "parts"),
                texts_to_check=texts_to_check,
                task_mappings=task_mappings,
            )

        # Case 3: Streaming artifact-update (singular artifact)
        artifact = result.get("artifact")
        if artifact and isinstance(artifact, dict) and "parts" in artifact:
            self._extract_texts_from_parts(
                parts=artifact["parts"],
                path=("artifact", "parts"),
                texts_to_check=texts_to_check,
                task_mappings=task_mappings,
            )

        # Case 4: Task with status message
        status = result.get("status", {})
        if isinstance(status, dict):
            status_message = status.get("message")
            if (
                status_message
                and isinstance(status_message, dict)
                and "parts" in status_message
            ):
                self._extract_texts_from_parts(
                    parts=status_message["parts"],
                    path=("status", "message", "parts"),
                    texts_to_check=texts_to_check,
                    task_mappings=task_mappings,
                )

        # Case 5: Task with artifacts (plural, array)
        artifacts = result.get("artifacts", [])
        if artifacts and isinstance(artifacts, list):
            for artifact_idx, art in enumerate(artifacts):
                if isinstance(art, dict) and "parts" in art:
                    self._extract_texts_from_parts(
                        parts=art["parts"],
                        # Array index is stored as a string; _apply_text_to_path
                        # converts digit-only keys back to int indices.
                        path=("artifacts", str(artifact_idx), "parts"),
                        texts_to_check=texts_to_check,
                        task_mappings=task_mappings,
                    )

    def _extract_texts_from_parts(
        self,
        parts: List[Dict[str, Any]],
        path: Tuple[str, ...],
        texts_to_check: List[str],
        task_mappings: List[Tuple[Tuple[str, ...], int]],
    ) -> None:
        """Extract text from message parts (only non-empty `kind == "text"` parts)."""
        for part_idx, part in enumerate(parts):
            if part.get("kind") == "text":
                text = part.get("text", "")
                if text:
                    texts_to_check.append(text)
                    task_mappings.append((path, part_idx))

    def _apply_text_to_path(
        self,
        result: Dict[Union[str, int], Any],
        path: Tuple[str, ...],
        part_idx: int,
        text: str,
    ) -> None:
        """Apply guardrailed text back to the specified path in the result.

        ``path`` mirrors what ``_extract_texts_from_parts`` recorded: dict
        keys, with digit-only strings standing in for list indices.
        """
        # Navigate to the parts list
        current = result
        for key in path:
            if key.isdigit():
                # Array index (stored as string by _extract_texts_from_result)
                current = current[int(key)]
            else:
                current = current[key]

        # Update the text in the part
        current[part_idx]["text"] = text
|
||||
@@ -0,0 +1,105 @@
|
||||
"""
|
||||
A2A Streaming Response Iterator
|
||||
"""
|
||||
from typing import Optional, Union
|
||||
|
||||
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
|
||||
from litellm.types.utils import GenericStreamingChunk, ModelResponseStream
|
||||
|
||||
from ..common_utils import extract_text_from_a2a_response
|
||||
|
||||
|
||||
class A2AModelResponseIterator(BaseModelResponseIterator):
    """
    Iterator that adapts A2A streaming responses to OpenAI-style chunks.

    Each incoming A2A JSON-RPC chunk is parsed into a ``GenericStreamingChunk``
    carrying the extracted text and a finish reason derived from the task state.
    """

    def __init__(
        self,
        streaming_response,
        sync_stream: bool,
        json_mode: Optional[bool] = False,
        model: str = "a2a/agent",
    ):
        super().__init__(
            streaming_response=streaming_response,
            sync_stream=sync_stream,
            json_mode=json_mode,
        )
        # Remember the model name so downstream consumers can attribute chunks.
        self.model = model

    def chunk_parser(
        self, chunk: dict
    ) -> Union[GenericStreamingChunk, ModelResponseStream]:
        """
        Convert one A2A streaming chunk into a GenericStreamingChunk.

        Expected chunk shapes (JSON-RPC 2.0):

        {
            "jsonrpc": "2.0",
            "id": "request-id",
            "result": {
                "message": {
                    "parts": [{"kind": "text", "text": "content"}]
                }
            }
        }

        Or for tasks:
        {
            "jsonrpc": "2.0",
            "result": {
                "kind": "task",
                "status": {"state": "running"},
                "artifacts": [{"parts": [{"kind": "text", "text": "content"}]}]
            }
        }

        Parse failures are swallowed: a malformed chunk yields an empty,
        non-final chunk rather than breaking the stream.
        """
        content = ""
        reason: Optional[str] = None
        try:
            content = extract_text_from_a2a_response(chunk)
            reason = self._get_finish_reason(chunk)
        except Exception:
            # Best-effort: emit an empty chunk when parsing fails.
            content = ""
            reason = None

        return GenericStreamingChunk(
            text=content,
            is_finished=bool(reason),
            finish_reason=reason or "",
            usage=None,
            index=0,
            tool_use=None,
        )

    def _get_finish_reason(self, chunk: dict) -> Optional[str]:
        """Map the A2A task state (or `done` marker) to an OpenAI finish reason."""
        result = chunk.get("result", {})

        if isinstance(result, dict):
            status = result.get("status", {})
            state = status.get("state") if isinstance(status, dict) else None
            # Both terminal states map to "stop" — the only valid finish_reason
            # for this kind of chunk (there is no separate "failed" reason).
            if state in ("completed", "failed"):
                return "stop"

        # Explicit [DONE] marker also ends the stream.
        if chunk.get("done") is True:
            return "stop"

        return None
|
||||
@@ -0,0 +1,373 @@
|
||||
"""
|
||||
A2A Protocol Transformation for LiteLLM
|
||||
"""
|
||||
import uuid
|
||||
from typing import Any, Dict, Iterator, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import Choices, Message, ModelResponse
|
||||
|
||||
from ..common_utils import (
|
||||
A2AError,
|
||||
convert_messages_to_prompt,
|
||||
extract_text_from_a2a_response,
|
||||
)
|
||||
from .streaming_iterator import A2AModelResponseIterator
|
||||
|
||||
|
||||
class A2AConfig(BaseConfig):
|
||||
"""
|
||||
Configuration for A2A (Agent-to-Agent) Protocol.
|
||||
|
||||
Handles transformation between OpenAI and A2A JSON-RPC 2.0 formats.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def resolve_agent_config_from_registry(
|
||||
model: str,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
headers: Optional[Dict[str, Any]],
|
||||
optional_params: Dict[str, Any],
|
||||
) -> tuple[Optional[str], Optional[str], Optional[Dict[str, Any]]]:
|
||||
"""
|
||||
Resolve agent configuration from registry if model format is "a2a/<agent-name>".
|
||||
|
||||
Extracts agent name from model string and looks up configuration in the
|
||||
agent registry (if available in proxy context).
|
||||
|
||||
Args:
|
||||
model: Model string (e.g., "a2a/my-agent")
|
||||
api_base: Explicit api_base (takes precedence over registry)
|
||||
api_key: Explicit api_key (takes precedence over registry)
|
||||
headers: Explicit headers (takes precedence over registry)
|
||||
optional_params: Dict to merge additional litellm_params into
|
||||
|
||||
Returns:
|
||||
Tuple of (api_base, api_key, headers) with registry values filled in
|
||||
"""
|
||||
# Extract agent name from model (e.g., "a2a/my-agent" -> "my-agent")
|
||||
agent_name = model.split("/", 1)[1] if "/" in model else None
|
||||
|
||||
# Only lookup if agent name exists and some config is missing
|
||||
if not agent_name or (
|
||||
api_base is not None and api_key is not None and headers is not None
|
||||
):
|
||||
return api_base, api_key, headers
|
||||
|
||||
# Try registry lookup (only available in proxy context)
|
||||
try:
|
||||
from litellm.proxy.agent_endpoints.agent_registry import (
|
||||
global_agent_registry,
|
||||
)
|
||||
|
||||
agent = global_agent_registry.get_agent_by_name(agent_name)
|
||||
if agent:
|
||||
# Get api_base from agent card URL
|
||||
if api_base is None and agent.agent_card_params:
|
||||
api_base = agent.agent_card_params.get("url")
|
||||
|
||||
# Get api_key, headers, and other params from litellm_params
|
||||
if agent.litellm_params:
|
||||
if api_key is None:
|
||||
api_key = agent.litellm_params.get("api_key")
|
||||
|
||||
if headers is None:
|
||||
agent_headers = agent.litellm_params.get("headers")
|
||||
if agent_headers:
|
||||
headers = agent_headers
|
||||
|
||||
# Merge other litellm_params (timeout, max_retries, etc.)
|
||||
for key, value in agent.litellm_params.items():
|
||||
if (
|
||||
key not in ["api_key", "api_base", "headers", "model"]
|
||||
and key not in optional_params
|
||||
):
|
||||
optional_params[key] = value
|
||||
except ImportError:
|
||||
pass # Registry not available (not running in proxy context)
|
||||
|
||||
return api_base, api_key, headers
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List[str]:
|
||||
"""Return list of supported OpenAI parameters"""
|
||||
return [
|
||||
"stream",
|
||||
"temperature",
|
||||
"max_tokens",
|
||||
"top_p",
|
||||
]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
"""
|
||||
Map OpenAI parameters to A2A parameters.
|
||||
|
||||
For A2A protocol, we need to map the stream parameter so
|
||||
transform_request can determine which JSON-RPC method to use.
|
||||
"""
|
||||
# Map stream parameter
|
||||
for param, value in non_default_params.items():
|
||||
if param == "stream" and value is True:
|
||||
optional_params["stream"] = value
|
||||
|
||||
return optional_params
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Validate environment and set headers for A2A requests.
|
||||
|
||||
Args:
|
||||
headers: Request headers dict
|
||||
model: Model name
|
||||
messages: Messages list
|
||||
optional_params: Optional parameters
|
||||
litellm_params: LiteLLM parameters
|
||||
api_key: API key (optional for A2A)
|
||||
api_base: API base URL
|
||||
|
||||
Returns:
|
||||
Updated headers dict
|
||||
"""
|
||||
# Ensure Content-Type is set to application/json for JSON-RPC 2.0
|
||||
if "content-type" not in headers and "Content-Type" not in headers:
|
||||
headers["Content-Type"] = "application/json"
|
||||
|
||||
# Add Authorization header if API key is provided
|
||||
if api_key is not None:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
return headers
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
stream: Optional[bool] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Get the complete A2A agent endpoint URL.
|
||||
|
||||
A2A agents use JSON-RPC 2.0 at the base URL, not specific paths.
|
||||
The method (message/send or message/stream) is specified in the
|
||||
JSON-RPC request body, not in the URL.
|
||||
|
||||
Args:
|
||||
api_base: Base URL of the A2A agent (e.g., "http://0.0.0.0:9999")
|
||||
api_key: API key (not used for URL construction)
|
||||
model: Model name (not used for A2A, agent determined by api_base)
|
||||
optional_params: Optional parameters
|
||||
litellm_params: LiteLLM parameters
|
||||
stream: Whether this is a streaming request (affects JSON-RPC method)
|
||||
|
||||
Returns:
|
||||
Complete URL for the A2A endpoint (base URL)
|
||||
"""
|
||||
if api_base is None:
|
||||
raise ValueError("api_base is required for A2A provider")
|
||||
|
||||
# A2A uses JSON-RPC 2.0 at the base URL
|
||||
# Remove trailing slash for consistency
|
||||
return api_base.rstrip("/")
|
||||
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
"""
|
||||
Transform OpenAI request to A2A JSON-RPC 2.0 format.
|
||||
|
||||
Args:
|
||||
model: Model name
|
||||
messages: List of OpenAI messages
|
||||
optional_params: Optional parameters
|
||||
litellm_params: LiteLLM parameters
|
||||
headers: Request headers
|
||||
|
||||
Returns:
|
||||
A2A JSON-RPC 2.0 request dict
|
||||
"""
|
||||
# Generate request ID
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
if not messages:
|
||||
raise ValueError("At least one message is required for A2A completion")
|
||||
|
||||
# Convert all messages to maintain conversation history
|
||||
# Use helper to format conversation with role prefixes
|
||||
full_context = convert_messages_to_prompt(messages)
|
||||
|
||||
# Create single A2A message with full conversation context
|
||||
a2a_message = {
|
||||
"role": "user",
|
||||
"parts": [{"kind": "text", "text": full_context}],
|
||||
"messageId": str(uuid.uuid4()),
|
||||
}
|
||||
|
||||
# Build JSON-RPC 2.0 request
|
||||
# For A2A protocol, the method is "message/send" for non-streaming
|
||||
# and "message/stream" for streaming
|
||||
stream = optional_params.get("stream", False)
|
||||
method = "message/stream" if stream else "message/send"
|
||||
|
||||
request_data = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"method": method,
|
||||
"params": {"message": a2a_message},
|
||||
}
|
||||
|
||||
return request_data
|
||||
|
||||
def transform_response(
    self,
    model: str,
    raw_response: httpx.Response,
    model_response: ModelResponse,
    logging_obj: Any,
    request_data: dict,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    encoding: Any,
    api_key: Optional[str] = None,
    json_mode: Optional[bool] = None,
) -> ModelResponse:
    """
    Convert an A2A JSON-RPC 2.0 response into an OpenAI-style ModelResponse.

    Args:
        model: Model name
        raw_response: HTTP response from the A2A agent
        model_response: Model response object to populate
        logging_obj: Logging object
        request_data: Original request data
        messages: Original messages
        optional_params: Optional parameters
        litellm_params: LiteLLM parameters
        encoding: Encoding object
        api_key: API key
        json_mode: JSON mode flag

    Returns:
        The populated ``model_response``.

    Raises:
        A2AError: If the body is not valid JSON, or carries a JSON-RPC error.
    """
    try:
        payload = raw_response.json()
    except Exception as exc:
        raise A2AError(
            status_code=raw_response.status_code,
            message=f"Failed to parse A2A response: {str(exc)}",
            headers=dict(raw_response.headers),
        )

    # JSON-RPC 2.0 signals failure via an "error" object in the body,
    # independently of the HTTP status code.
    if "error" in payload:
        err = payload["error"]
        raise A2AError(
            status_code=raw_response.status_code,
            message=f"A2A error: {err.get('message', 'Unknown error')}",
            headers=dict(raw_response.headers),
        )

    # Pull the text content out of whichever A2A result shape was returned.
    completion_text = extract_text_from_a2a_response(payload)

    model_response.choices = [
        Choices(
            finish_reason="stop",
            index=0,
            message=Message(content=completion_text, role="assistant"),
        )
    ]
    model_response.model = model
    # Reuse the JSON-RPC response id when present; otherwise mint one.
    model_response.id = payload.get("id", str(uuid.uuid4()))
    return model_response
|
||||
|
||||
def get_model_response_iterator(
    self,
    streaming_response: Union[Iterator, Any],
    sync_stream: bool,
    json_mode: Optional[bool] = False,
) -> BaseModelResponseIterator:
    """
    Wrap a raw A2A streaming response in the A2A-specific iterator.

    Args:
        streaming_response: Raw (sync or async) stream of response chunks
        sync_stream: True when the underlying stream is synchronous
        json_mode: JSON mode flag, passed through to the iterator

    Returns:
        An ``A2AModelResponseIterator`` over the given stream.
    """
    iterator = A2AModelResponseIterator(
        streaming_response=streaming_response,
        sync_stream=sync_stream,
        json_mode=json_mode,
    )
    return iterator
|
||||
|
||||
def _openai_message_to_a2a_message(self, message: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert OpenAI message to A2A message format.
|
||||
|
||||
Args:
|
||||
message: OpenAI message dict
|
||||
|
||||
Returns:
|
||||
A2A message dict
|
||||
"""
|
||||
content = message.get("content", "")
|
||||
role = message.get("role", "user")
|
||||
|
||||
return {
|
||||
"role": role,
|
||||
"parts": [{"kind": "text", "text": str(content)}],
|
||||
"messageId": str(uuid.uuid4()),
|
||||
}
|
||||
|
||||
def get_error_class(
    self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
    """
    Map an HTTP failure onto the A2A exception type.

    Args:
        error_message: Human-readable error text
        status_code: HTTP status code returned by the agent
        headers: Response headers, as a plain dict or ``httpx.Headers``

    Returns:
        An ``A2AError`` carrying the message, status code and headers.
    """
    # Normalize httpx.Headers to a plain dict before constructing the error.
    if isinstance(headers, httpx.Headers):
        normalized_headers = dict(headers)
    else:
        normalized_headers = headers
    return A2AError(
        status_code=status_code,
        message=error_message,
        headers=normalized_headers,
    )
|
||||
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
Common utilities for A2A (Agent-to-Agent) Protocol
|
||||
"""
|
||||
from typing import Any, Dict, List, Optional

from pydantic import BaseModel

from litellm.litellm_core_utils.prompt_templates.common_utils import (
    convert_content_list_to_str,
)
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.llms.openai import AllMessageValues
|
||||
|
||||
|
||||
class A2AError(BaseLLMException):
    """Exception raised for A2A protocol failures.

    Thin wrapper around ``BaseLLMException`` that carries the HTTP status
    code, a human-readable message, and the response headers.
    """

    def __init__(
        self,
        status_code: int,
        message: str,
        headers: Optional[Dict[str, Any]] = None,
    ):
        """
        Args:
            status_code: HTTP status code associated with the error
            message: Human-readable error description
            headers: Response headers, if any; a fresh empty dict is used
                when omitted.
        """
        # `None` default instead of `headers={}`: a mutable literal default
        # is shared across all calls and can be mutated downstream (B006).
        super().__init__(
            status_code=status_code,
            message=message,
            headers=headers if headers is not None else {},
        )
|
||||
|
||||
|
||||
def convert_messages_to_prompt(messages: List[AllMessageValues]) -> str:
    """
    Flatten OpenAI messages into one newline-joined prompt string.

    Each message becomes a ``"{role}: {content}"`` line; messages whose
    content resolves to an empty string are omitted. Content may be a plain
    string or an OpenAI content-part list (normalized via
    ``convert_content_list_to_str``).

    Args:
        messages: OpenAI-format conversation messages

    Returns:
        The conversation rendered as a single prompt string.
    """
    lines: List[str] = []
    for message in messages:
        # Normalize content to text (handles both str and part-list content).
        text = convert_content_list_to_str(message=message)

        # Pull the role out of whichever container type we were given.
        if isinstance(message, BaseModel):
            role = message.model_dump().get("role", "user")
        elif isinstance(message, dict):
            role = message.get("role", "user")
        else:
            role = dict(message).get("role", "user")  # type: ignore

        if text:
            lines.append(f"{role}: {text}")

    return "\n".join(lines)
|
||||
|
||||
|
||||
def extract_text_from_a2a_message(
    message: Dict[str, Any], depth: int = 0, max_depth: int = 10
) -> str:
    """
    Extract text content from the parts of an A2A message.

    Walks ``message["parts"]``, collecting the ``text`` of every part with
    ``kind == "text"`` and recursing into parts that themselves contain a
    ``parts`` list, up to ``max_depth`` levels (guards against cyclic or
    maliciously deep structures).

    Args:
        message: A2A message dict with a ``parts`` list
        depth: Current recursion depth (internal use)
        max_depth: Maximum recursion depth to prevent infinite loops

    Returns:
        Space-joined text from all text parts ("" when nothing is found).
    """
    if message is None or depth >= max_depth:
        return ""

    parts = message.get("parts", [])
    text_parts: List[str] = []

    for part in parts:
        # Responses come from external agents: tolerate malformed entries
        # instead of raising AttributeError on non-dict parts.
        if not isinstance(part, dict):
            continue
        if part.get("kind") == "text":
            text_parts.append(part.get("text", ""))
        # Handle nested parts if they exist
        elif "parts" in part:
            nested_text = extract_text_from_a2a_message(part, depth + 1, max_depth)
            if nested_text:
                text_parts.append(nested_text)

    return " ".join(text_parts)
|
||||
|
||||
|
||||
def extract_text_from_a2a_response(
    response_dict: Dict[str, Any], max_depth: int = 10
) -> str:
    """
    Extract text from the ``result`` of an A2A JSON-RPC response.

    Tries the known A2A result shapes, in order:
      1. Direct message:   ``{"result": {"kind": "message", "parts": [...]}}``
      2. Nested message:   ``{"result": {"message": {"parts": [...]}}}``
      3. Streaming update: ``{"result": {"kind": "artifact-update", "artifact": {...}}}``
      4. Task status:      ``{"result": {"kind": "task", "status": {"message": {...}}}}``
      5. Task artifacts:   ``{"result": {"kind": "task", "artifacts": [{...}]}}``

    Args:
        response_dict: Parsed A2A response body
        max_depth: Maximum recursion depth for nested parts

    Returns:
        Text from the first matching shape, or "" when none matches.
    """
    result = response_dict.get("result", {})
    if not isinstance(result, dict):
        return ""

    # 1. The result itself is a message carrying parts.
    if "parts" in result:
        return extract_text_from_a2a_message(result, depth=0, max_depth=max_depth)

    # 2. Message nested one level down.
    nested_message = result.get("message")
    if nested_message:
        return extract_text_from_a2a_message(
            nested_message, depth=0, max_depth=max_depth
        )

    # 3. Streaming artifact-update events carry a single artifact.
    single_artifact = result.get("artifact")
    if single_artifact and isinstance(single_artifact, dict):
        return extract_text_from_a2a_message(
            single_artifact, depth=0, max_depth=max_depth
        )

    # 4. Task results may embed the message in their status (seen with
    #    Gemini A2A agents).
    status = result.get("status", {})
    if isinstance(status, dict):
        status_message = status.get("message")
        if status_message:
            return extract_text_from_a2a_message(
                status_message, depth=0, max_depth=max_depth
            )

    # 5. Task results may carry an artifacts array; use the first entry.
    artifacts = result.get("artifacts", [])
    if artifacts and len(artifacts) > 0:
        return extract_text_from_a2a_message(
            artifacts[0], depth=0, max_depth=max_depth
        )

    return ""
|
||||
Reference in New Issue
Block a user