501 lines
18 KiB
Python
501 lines
18 KiB
Python
|
|
"""
|
||
|
|
A2A Protocol endpoints for LiteLLM Proxy.
|
||
|
|
|
||
|
|
Allows clients to invoke agents through LiteLLM using the A2A protocol.
|
||
|
|
The A2A SDK can point to LiteLLM's URL and invoke agents registered with LiteLLM.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import json
|
||
|
|
from typing import Any, Dict, Optional
|
||
|
|
|
||
|
|
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||
|
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||
|
|
|
||
|
|
from litellm._logging import verbose_proxy_logger
|
||
|
|
from litellm.proxy._types import UserAPIKeyAuth
|
||
|
|
from litellm.proxy.agent_endpoints.utils import merge_agent_headers
|
||
|
|
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||
|
|
from litellm.types.utils import all_litellm_params
|
||
|
|
|
||
|
|
router = APIRouter()
|
||
|
|
|
||
|
|
|
||
|
|
def _jsonrpc_error(
    request_id: Optional[str],
    code: int,
    message: str,
    status_code: int = 400,
) -> JSONResponse:
    """Build a JSON-RPC 2.0 error envelope wrapped in a JSONResponse.

    Args:
        request_id: JSON-RPC request id to echo back (may be None).
        code: JSON-RPC error code (e.g. -32600, -32603).
        message: Human-readable error description.
        status_code: HTTP status to attach to the response (default 400).
    """
    payload = {
        "jsonrpc": "2.0",
        "id": request_id,
        "error": {"code": code, "message": message},
    }
    return JSONResponse(content=payload, status_code=status_code)
|
||
|
|
|
||
|
|
|
||
|
|
def _get_agent(agent_id: str):
    """Resolve an agent by ID, falling back to a name lookup.

    Returns the registered agent, or None when neither lookup matches.
    """
    from litellm.proxy.agent_endpoints.agent_registry import global_agent_registry

    by_id = global_agent_registry.get_agent_by_id(agent_id=agent_id)
    if by_id is not None:
        return by_id
    # Callers may pass the agent's name (alias) instead of its id.
    return global_agent_registry.get_agent_by_name(agent_name=agent_id)
|
||
|
|
|
||
|
|
|
||
|
|
def _enforce_inbound_trace_id(agent: Any, request: Request) -> None:
    """Raise 400 if agent requires x-litellm-trace-id on inbound calls and it is missing."""
    configured_params = agent.litellm_params or {}
    if not configured_params.get("require_trace_id_on_calls_to_agent"):
        # Agent does not opt into trace-id enforcement; nothing to check.
        return

    from litellm.proxy.litellm_pre_call_utils import get_chain_id_from_headers

    if get_chain_id_from_headers(dict(request.headers)):
        return

    raise HTTPException(
        status_code=400,
        detail=(
            f"Agent '{agent.agent_id}' requires x-litellm-trace-id header "
            "on all inbound requests."
        ),
    )
|
||
|
|
|
||
|
|
|
||
|
|
async def _handle_stream_message(
    api_base: Optional[str],
    request_id: str,
    params: dict,
    litellm_params: Optional[dict] = None,
    agent_id: Optional[str] = None,
    metadata: Optional[dict] = None,
    proxy_server_request: Optional[dict] = None,
    *,
    agent_extra_headers: Optional[Dict[str, str]] = None,
    user_api_key_dict: Optional[UserAPIKeyAuth] = None,
    request_data: Optional[dict] = None,
    proxy_logging_obj: Optional[Any] = None,
) -> StreamingResponse:
    """Handle message/stream method via SDK functions.

    Streams NDJSON: one JSON-RPC payload per line. When user_api_key_dict,
    request_data, and proxy_logging_obj are provided, uses
    common_request_processing.async_streaming_data_generator with NDJSON
    serializers so proxy hooks and cost injection apply; otherwise chunks
    are serialized directly.

    Args:
        api_base: Backend agent URL; may be None (e.g. completion-bridge providers).
        request_id: JSON-RPC request id, echoed in every streamed payload.
        params: JSON-RPC params used to build MessageSendParams.
        litellm_params: Agent-level litellm params forwarded to the SDK call.
        agent_id: Registered agent id, for logging/metadata attribution.
        metadata: Request metadata forwarded to the SDK call.
        proxy_server_request: Proxy request context forwarded to the SDK call.
        agent_extra_headers: Merged static + dynamic headers for the backend agent.
        user_api_key_dict / request_data / proxy_logging_obj: When all three are
            set, streaming goes through proxy hooks (see above).
    """
    from litellm.a2a_protocol import asend_message_streaming
    from litellm.a2a_protocol.main import A2A_SDK_AVAILABLE

    if not A2A_SDK_AVAILABLE:
        # Without the 'a2a' package we can't build SDK request objects; emit a
        # single JSON-RPC error line so streaming clients get a parseable reply.
        async def _error_stream():
            yield json.dumps(
                {
                    "jsonrpc": "2.0",
                    "id": request_id,
                    "error": {
                        "code": -32603,
                        "message": "Server error: 'a2a' package not installed",
                    },
                }
            ) + "\n"

        return StreamingResponse(_error_stream(), media_type="application/x-ndjson")

    from a2a.types import MessageSendParams, SendStreamingMessageRequest

    # Proxy hooks require all three context objects; partial context falls back
    # to plain pass-through serialization.
    use_proxy_hooks = (
        user_api_key_dict is not None
        and request_data is not None
        and proxy_logging_obj is not None
    )

    async def stream_response():
        try:
            a2a_request = SendStreamingMessageRequest(
                id=request_id,
                params=MessageSendParams(**params),
            )
            a2a_stream = asend_message_streaming(
                request=a2a_request,
                api_base=api_base,
                litellm_params=litellm_params,
                agent_id=agent_id,
                metadata=metadata,
                proxy_server_request=proxy_server_request,
                agent_extra_headers=agent_extra_headers,
            )

            # Redundant None-checks alongside use_proxy_hooks keep type
            # checkers satisfied inside this closure.
            if (
                use_proxy_hooks
                and user_api_key_dict is not None
                and request_data is not None
                and proxy_logging_obj is not None
            ):
                from litellm.proxy.common_request_processing import (
                    ProxyBaseLLMRequestProcessing,
                )

                def _ndjson_chunk(chunk: Any) -> str:
                    # Pydantic chunks are serialized via model_dump; raw dicts pass through.
                    if hasattr(chunk, "model_dump"):
                        obj = chunk.model_dump(mode="json", exclude_none=True)
                    else:
                        obj = chunk
                    return json.dumps(obj) + "\n"

                def _ndjson_error(proxy_exc: Any) -> str:
                    # Serialize a mid-stream failure as a JSON-RPC error line.
                    return (
                        json.dumps(
                            {
                                "jsonrpc": "2.0",
                                "id": request_id,
                                "error": {
                                    "code": -32603,
                                    "message": getattr(
                                        proxy_exc,
                                        "message",
                                        f"Streaming error: {proxy_exc!s}",
                                    ),
                                },
                            }
                        )
                        + "\n"
                    )

                async for (
                    line
                ) in ProxyBaseLLMRequestProcessing.async_streaming_data_generator(
                    response=a2a_stream,
                    user_api_key_dict=user_api_key_dict,
                    request_data=request_data,
                    proxy_logging_obj=proxy_logging_obj,
                    serialize_chunk=_ndjson_chunk,
                    serialize_error=_ndjson_error,
                ):
                    yield line
            else:
                # No proxy context: serialize each chunk directly as NDJSON.
                async for chunk in a2a_stream:
                    if hasattr(chunk, "model_dump"):
                        yield json.dumps(
                            chunk.model_dump(mode="json", exclude_none=True)
                        ) + "\n"
                    else:
                        yield json.dumps(chunk) + "\n"
        except Exception as e:
            verbose_proxy_logger.exception(f"Error streaming A2A response: {e}")
            if (
                use_proxy_hooks
                and proxy_logging_obj is not None
                and user_api_key_dict is not None
                and request_data is not None
            ):
                # Give failure hooks a chance to transform the exception
                # (e.g. into a client-safe HTTPException).
                transformed_exception = await proxy_logging_obj.post_call_failure_hook(
                    user_api_key_dict=user_api_key_dict,
                    original_exception=e,
                    request_data=request_data,
                )
                if transformed_exception is not None:
                    e = transformed_exception
            if isinstance(e, HTTPException):
                # Fix: a bare `raise` here re-raised the ORIGINAL active
                # exception, silently discarding a transformed HTTPException
                # returned by post_call_failure_hook. `raise e` propagates
                # whichever exception we decided on.
                raise e
            yield json.dumps(
                {
                    "jsonrpc": "2.0",
                    "id": request_id,
                    "error": {"code": -32603, "message": f"Streaming error: {str(e)}"},
                }
            ) + "\n"

    return StreamingResponse(stream_response(), media_type="application/x-ndjson")
|
||
|
|
|
||
|
|
|
||
|
|
@router.get(
    "/a2a/{agent_id}/.well-known/agent-card.json",
    tags=["[beta] A2A Agents"],
    dependencies=[Depends(user_api_key_auth)],
)
@router.get(
    "/a2a/{agent_id}/.well-known/agent.json",
    tags=["[beta] A2A Agents"],
    dependencies=[Depends(user_api_key_auth)],
)
async def get_agent_card(
    agent_id: str,
    request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get the agent card for an agent (A2A discovery endpoint).

    Supports both standard paths:
    - /.well-known/agent-card.json
    - /.well-known/agent.json

    The URL in the agent card is rewritten to point to the LiteLLM proxy,
    so all subsequent A2A calls go through LiteLLM for logging and cost tracking.
    """
    from litellm.proxy.agent_endpoints.auth.agent_permission_handler import (
        AgentRequestHandler,
    )

    try:
        agent = _get_agent(agent_id)
        if agent is None:
            raise HTTPException(status_code=404, detail=f"Agent '{agent_id}' not found")

        # Permission check (admin users are allowed implicitly by the handler).
        allowed = await AgentRequestHandler.is_agent_allowed(
            agent_id=agent.agent_id,
            user_api_key_auth=user_api_key_dict,
        )
        if not allowed:
            raise HTTPException(
                status_code=403,
                detail=f"Agent '{agent_id}' is not allowed for your key/team. Contact proxy admin for access.",
            )

        # Serve a copy of the card whose URL targets this proxy rather than
        # the backend agent, so follow-up A2A calls route through LiteLLM.
        proxied_card = dict(agent.agent_card_params)
        proxied_card["url"] = f"{str(request.base_url).rstrip('/')}/a2a/{agent_id}"

        verbose_proxy_logger.debug(
            f"Returning agent card for '{agent_id}' with proxy URL: {proxied_card['url']}"
        )
        return JSONResponse(content=proxied_card)

    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error getting agent card: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
|
|
||
|
|
|
||
|
|
@router.post(
    "/a2a/{agent_id}",
    tags=["[beta] A2A Agents"],
    dependencies=[Depends(user_api_key_auth)],
)
@router.post(
    "/a2a/{agent_id}/message/send",
    tags=["[beta] A2A Agents"],
    dependencies=[Depends(user_api_key_auth)],
)
@router.post(
    "/v1/a2a/{agent_id}/message/send",
    tags=["[beta] A2A Agents"],
    dependencies=[Depends(user_api_key_auth)],
)
async def invoke_agent_a2a(  # noqa: PLR0915
    agent_id: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Invoke an agent using the A2A protocol (JSON-RPC 2.0).

    Supported methods:
    - message/send: Send a message and get a response
    - message/stream: Send a message and stream the response

    Flow: validate the JSON-RPC envelope, split litellm params out of the
    RPC params, resolve and authorize the agent, run proxy pre-call logic,
    merge backend headers, then dispatch on the JSON-RPC method. Unknown
    methods return a -32601 error; unexpected failures return -32603.
    """
    from litellm.a2a_protocol import asend_message
    from litellm.a2a_protocol.main import A2A_SDK_AVAILABLE
    from litellm.proxy.agent_endpoints.auth.agent_permission_handler import (
        AgentRequestHandler,
    )
    from litellm.proxy.proxy_server import (
        general_settings,
        proxy_config,
        proxy_logging_obj,
        version,
    )

    # Initialized before the try so the final except can still read body.get("id")
    # even when request.json() itself fails.
    body = {}
    try:
        body = await request.json()

        verbose_proxy_logger.debug(f"A2A request for agent '{agent_id}': {body}")

        # Validate JSON-RPC format
        if body.get("jsonrpc") != "2.0":
            return _jsonrpc_error(
                body.get("id"), -32600, "Invalid Request: jsonrpc must be '2.0'"
            )

        request_id = body.get("id")
        method = body.get("method")
        params = body.get("params", {})

        if params:
            # extract any litellm params from the params - eg. 'guardrails'
            # Moved from the RPC params into the top-level body so litellm's
            # pre-call processing picks them up; removed from params afterwards
            # (two-pass to avoid mutating the dict while iterating it).
            params_to_remove = []
            for key, value in params.items():
                if key in all_litellm_params:
                    params_to_remove.append(key)
                    body[key] = value
            for key in params_to_remove:
                params.pop(key)

        if not A2A_SDK_AVAILABLE:
            return _jsonrpc_error(
                request_id,
                -32603,
                "Server error: 'a2a' package not installed. Please install 'a2a-sdk'.",
                500,
            )

        # Find the agent
        agent = _get_agent(agent_id)
        if agent is None:
            return _jsonrpc_error(
                request_id, -32000, f"Agent '{agent_id}' not found", 404
            )

        # Per-key/team authorization for this agent.
        is_allowed = await AgentRequestHandler.is_agent_allowed(
            agent_id=agent.agent_id,
            user_api_key_auth=user_api_key_dict,
        )
        if not is_allowed:
            raise HTTPException(
                status_code=403,
                detail=f"Agent '{agent_id}' is not allowed for your key/team. Contact proxy admin for access.",
            )

        # Raises 400 if the agent mandates an inbound x-litellm-trace-id header.
        _enforce_inbound_trace_id(agent, request)

        # Get backend URL and agent name
        agent_url = agent.agent_card_params.get("url")
        agent_name = agent.agent_card_params.get("name", agent_id)

        # Get litellm_params (may include custom_llm_provider for completion bridge)
        litellm_params = agent.litellm_params or {}
        custom_llm_provider = litellm_params.get("custom_llm_provider")

        # URL is required unless using completion bridge with a provider that derives endpoint from model
        # (e.g., bedrock/agentcore derives endpoint from ARN in model string)
        if not agent_url and not custom_llm_provider:
            return _jsonrpc_error(
                request_id, -32000, f"Agent '{agent_id}' has no URL configured", 500
            )

        verbose_proxy_logger.info(
            f"Proxying A2A request to agent '{agent_id}' at {agent_url or 'completion-bridge'}"
        )

        # Set up data dict for litellm processing
        if "metadata" not in body:
            body["metadata"] = {}
        body["metadata"]["agent_id"] = agent.agent_id

        # Tag the request with the a2a_agent pseudo-model so litellm's cost
        # tracking and logging attribute it to this agent.
        body.update(
            {
                "model": f"a2a_agent/{agent_name}",
                "custom_llm_provider": "a2a_agent",
            }
        )

        # Add litellm data (user_api_key, user_id, team_id, etc.)
        from litellm.proxy.common_request_processing import (
            ProxyBaseLLMRequestProcessing,
        )

        processor = ProxyBaseLLMRequestProcessing(data=body)
        data, logging_obj = await processor.common_processing_pre_call_logic(
            request=request,
            general_settings=general_settings,
            user_api_key_dict=user_api_key_dict,
            proxy_logging_obj=proxy_logging_obj,
            proxy_config=proxy_config,
            route_type="asend_message",
            version=version,
        )

        # Build merged headers for the backend agent
        static_headers: Dict[str, str] = dict(agent.static_headers or {})

        # Header names are case-insensitive; normalize once for lookups below.
        raw_headers = dict(request.headers)
        normalized = {k.lower(): v for k, v in raw_headers.items()}

        dynamic_headers: Dict[str, str] = {}

        # 1. Admin-configured extra_headers: forward named headers from client request
        if agent.extra_headers:
            for header_name in agent.extra_headers:
                val = normalized.get(header_name.lower())
                if val is not None:
                    dynamic_headers[header_name] = val

        # 2. Convention-based forwarding: x-a2a-{agent_id_or_name}-{header_name}
        # Matches both agent_id (UUID) and agent_name (alias), case-insensitive.
        for alias in (agent.agent_id.lower(), agent.agent_name.lower()):
            prefix = f"x-a2a-{alias}-"
            for key, val in normalized.items():
                if key.startswith(prefix):
                    header_name = key[len(prefix) :]
                    if header_name:
                        dynamic_headers[header_name] = val

        agent_extra_headers = merge_agent_headers(
            dynamic_headers=dynamic_headers or None,
            static_headers=static_headers or None,
        )

        # Route through SDK functions
        if method == "message/send":
            from a2a.types import MessageSendParams, SendMessageRequest

            a2a_request = SendMessageRequest(
                id=request_id,
                params=MessageSendParams(**params),
            )
            response = await asend_message(
                request=a2a_request,
                api_base=agent_url,
                litellm_params=litellm_params,
                agent_id=agent.agent_id,
                metadata=data.get("metadata", {}),
                proxy_server_request=data.get("proxy_server_request"),
                litellm_logging_obj=logging_obj,
                agent_extra_headers=agent_extra_headers,
            )

            # Post-call hooks may replace the response (e.g. guardrails).
            response = await proxy_logging_obj.post_call_success_hook(
                user_api_key_dict=user_api_key_dict,
                data=data,
                response=response,
            )
            return JSONResponse(
                content=(
                    response.model_dump(mode="json", exclude_none=True)  # type: ignore
                    if hasattr(response, "model_dump")
                    else response
                )
            )

        elif method == "message/stream":
            # Delegates to the NDJSON streaming handler; passing the proxy
            # context enables hook-aware streaming.
            return await _handle_stream_message(
                api_base=agent_url,
                request_id=request_id,
                params=params,
                litellm_params=litellm_params,
                agent_id=agent.agent_id,
                metadata=data.get("metadata", {}),
                proxy_server_request=data.get("proxy_server_request"),
                agent_extra_headers=agent_extra_headers,
                user_api_key_dict=user_api_key_dict,
                request_data=data,
                proxy_logging_obj=proxy_logging_obj,
            )
        else:
            return _jsonrpc_error(request_id, -32601, f"Method '{method}' not found")

    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error invoking agent: {e}")
        return _jsonrpc_error(body.get("id"), -32603, f"Internal error: {str(e)}", 500)
|