441 lines
15 KiB
Python
441 lines
15 KiB
Python
|
|
"""
|
||
|
|
DB helpers for LiteLLM_ToolTable — the global tool registry.
|
||
|
|
|
||
|
|
Tools are auto-discovered from LLM responses and upserted here.
|
||
|
|
Admins use the management endpoints to read and update input_policy / output_policy.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import uuid
|
||
|
|
from datetime import datetime, timezone
|
||
|
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||
|
|
|
||
|
|
from litellm._logging import verbose_proxy_logger
|
||
|
|
from litellm.proxy._types import ToolDiscoveryQueueItem
|
||
|
|
from litellm.types.tool_management import (
|
||
|
|
LiteLLM_ToolTableRow,
|
||
|
|
ToolPolicyOverrideRow,
|
||
|
|
)
|
||
|
|
|
||
|
|
if TYPE_CHECKING:
|
||
|
|
from litellm.proxy.utils import PrismaClient
|
||
|
|
|
||
|
|
|
||
|
|
# Attribute names copied off a row object that is neither a dict nor exposes
# a callable ``model_dump``.
_TOOL_ROW_FIELDS: Tuple[str, ...] = (
    "tool_id",
    "tool_name",
    "origin",
    "input_policy",
    "output_policy",
    "call_count",
    "assignments",
    "key_hash",
    "team_id",
    "key_alias",
    "user_agent",
    "last_used_at",
    "created_at",
    "updated_at",
    "created_by",
    "updated_by",
)


def _row_to_model(row: Union[dict, Any]) -> LiteLLM_ToolTableRow:
    """Convert a Prisma model instance or dict to LiteLLM_ToolTableRow.

    Args:
        row: Either a dict, an object exposing a callable ``model_dump`` (its
            result is used as the mapping), or any other object whose
            attributes named in ``_TOOL_ROW_FIELDS`` are read with ``getattr``
            (missing attributes become ``None``).

    Returns:
        A ``LiteLLM_ToolTableRow`` populated from the row. Missing or ``None``
        ``tool_id`` / ``tool_name`` become ``""``, missing or falsy policies
        become ``"untrusted"``, and a missing or falsy ``call_count`` becomes 0.
    """
    model_dump = getattr(row, "model_dump", None)
    if callable(model_dump):
        row = model_dump()
    elif not isinstance(row, dict):
        row = {k: getattr(row, k, None) for k in _TOOL_ROW_FIELDS}
    return LiteLLM_ToolTableRow(
        # ``or ""`` rather than ``.get(key, "")``: on the getattr fallback path
        # the key is always present (possibly as None), so a ``.get`` default
        # would never be applied.
        tool_id=row.get("tool_id") or "",
        tool_name=row.get("tool_name") or "",
        origin=row.get("origin"),
        input_policy=row.get("input_policy") or "untrusted",
        output_policy=row.get("output_policy") or "untrusted",
        call_count=int(row.get("call_count") or 0),
        assignments=row.get("assignments"),
        key_hash=row.get("key_hash"),
        team_id=row.get("team_id"),
        key_alias=row.get("key_alias"),
        user_agent=row.get("user_agent"),
        last_used_at=row.get("last_used_at"),
        created_at=row.get("created_at"),
        updated_at=row.get("updated_at"),
        created_by=row.get("created_by"),
        updated_by=row.get("updated_by"),
    )
|
||
|
|
|
||
|
|
|
||
|
|
async def batch_upsert_tools(
    prisma_client: "PrismaClient",
    items: List[ToolDiscoveryQueueItem],
) -> None:
    """
    Batch-upsert tool registry rows via Prisma.

    On first insert: sets input_policy/output_policy = "untrusted" (default), call_count = 1.
    On conflict: increments call_count; preserves existing policies.

    Items without a truthy ``tool_name`` are skipped. Errors are logged and
    never raised. Each item is upserted independently, so one failing upsert
    does not prevent the remaining items from being written.

    Args:
        prisma_client: Client whose ``db.litellm_tooltable`` is upserted into.
        items: Discovered tools; read with ``.get`` for ``tool_name``,
            ``origin``, ``created_by``, ``key_hash``, ``team_id``,
            ``key_alias`` and ``user_agent``.
    """
    if not items:
        return
    try:
        data = [item for item in items if item.get("tool_name")]
        if not data:
            return
        now = datetime.now(timezone.utc)
        table = prisma_client.db.litellm_tooltable
    except Exception as e:
        verbose_proxy_logger.error(
            "tool_registry_writer batch_upsert_tools error: %s", e
        )
        return

    upserted = 0
    for item in data:
        tool_name = item.get("tool_name", "")
        try:
            created_by = item.get("created_by") or "system"
            await table.upsert(
                where={"tool_name": tool_name},
                data={
                    "create": {
                        "tool_id": str(uuid.uuid4()),
                        "tool_name": tool_name,
                        "origin": item.get("origin") or "user_defined",
                        "input_policy": "untrusted",
                        "output_policy": "untrusted",
                        "call_count": 1,
                        "created_by": created_by,
                        "updated_by": created_by,
                        "key_hash": item.get("key_hash"),
                        "team_id": item.get("team_id"),
                        "key_alias": item.get("key_alias"),
                        "user_agent": item.get("user_agent"),
                        "last_used_at": now,
                    },
                    # Policies are deliberately absent here so an existing
                    # row keeps whatever policy was configured for it.
                    "update": {
                        "call_count": {"increment": 1},
                        "updated_at": now,
                        "last_used_at": now,
                    },
                },
            )
            upserted += 1
        except Exception as e:
            # Best-effort: log this item and keep going with the rest.
            verbose_proxy_logger.error(
                "tool_registry_writer batch_upsert_tools error (tool_name=%s): %s",
                tool_name,
                e,
            )
    verbose_proxy_logger.debug(
        "tool_registry_writer: upserted %d tool(s)", upserted
    )
|
||
|
|
|
||
|
|
|
||
|
|
async def list_tools(
    prisma_client: "PrismaClient",
    input_policy: Optional[str] = None,
) -> List[LiteLLM_ToolTableRow]:
    """Fetch tool registry rows ordered by ``created_at`` descending.

    Args:
        prisma_client: Client whose ``db.litellm_tooltable`` is queried.
        input_policy: When not ``None``, only rows with this exact
            ``input_policy`` are returned.

    Returns:
        The matching rows converted with ``_row_to_model``; an empty list if
        the query or conversion raises (the error is logged).
    """
    filters: Dict[str, Any] = {}
    if input_policy is not None:
        filters["input_policy"] = input_policy
    try:
        records = await prisma_client.db.litellm_tooltable.find_many(
            where=filters,
            order={"created_at": "desc"},
        )
        return list(map(_row_to_model, records))
    except Exception as e:
        verbose_proxy_logger.error("tool_registry_writer list_tools error: %s", e)
        return []
|
||
|
|
|
||
|
|
|
||
|
|
async def get_tool(
    prisma_client: "PrismaClient",
    tool_name: str,
) -> Optional[LiteLLM_ToolTableRow]:
    """Look up one tool registry row by its ``tool_name``.

    Returns:
        The row converted with ``_row_to_model``, or ``None`` when no row
        matches or when the lookup raises (the error is logged).
    """
    try:
        record = await prisma_client.db.litellm_tooltable.find_unique(
            where={"tool_name": tool_name},
        )
        return None if record is None else _row_to_model(record)
    except Exception as e:
        verbose_proxy_logger.error("tool_registry_writer get_tool error: %s", e)
        return None
|
||
|
|
|
||
|
|
|
||
|
|
async def update_tool_policy(
    prisma_client: "PrismaClient",
    tool_name: str,
    updated_by: Optional[str],
    input_policy: Optional[str] = None,
    output_policy: Optional[str] = None,
) -> Optional[LiteLLM_ToolTableRow]:
    """Set a tool's input and/or output policy, creating the row if needed.

    Args:
        prisma_client: Client whose ``db.litellm_tooltable`` is upserted into.
        tool_name: Name of the tool whose row is written.
        updated_by: Actor recorded on the row; falsy values become ``"system"``.
        input_policy: New input policy. ``None`` leaves an existing row's value
            alone; a newly created row falls back to ``"untrusted"``.
        output_policy: New output policy, with the same ``None`` handling.

    Returns:
        The row as re-read through ``get_tool`` after the upsert, or ``None``
        if the upsert raises (the error is logged).
    """
    try:
        actor = updated_by or "system"
        timestamp = datetime.now(timezone.utc)

        # Only the policies the caller actually supplied may overwrite an
        # existing row.
        policy_changes = {
            field: value
            for field, value in (
                ("input_policy", input_policy),
                ("output_policy", output_policy),
            )
            if value is not None
        }

        await prisma_client.db.litellm_tooltable.upsert(
            where={"tool_name": tool_name},
            data={
                "create": {
                    "tool_id": str(uuid.uuid4()),
                    "tool_name": tool_name,
                    "input_policy": input_policy or "untrusted",
                    "output_policy": output_policy or "untrusted",
                    "created_by": actor,
                    "updated_by": actor,
                    "created_at": timestamp,
                    "updated_at": timestamp,
                },
                "update": {
                    "updated_by": actor,
                    "updated_at": timestamp,
                    **policy_changes,
                },
            },
        )
        return await get_tool(prisma_client, tool_name)
    except Exception as e:
        verbose_proxy_logger.error(
            "tool_registry_writer update_tool_policy error: %s", e
        )
        return None
|
||
|
|
|
||
|
|
|
||
|
|
async def get_tools_by_names(
    prisma_client: "PrismaClient",
    tool_names: List[str],
) -> Dict[str, Tuple[str, str]]:
    """
    Map each requested tool name found in the registry to its policy pair.

    Args:
        prisma_client: Client whose ``db.litellm_tooltable`` is queried.
        tool_names: Tool names to look up; an empty list skips the query.

    Returns:
        ``{tool_name: (input_policy, output_policy)}`` for the rows returned by
        the query. A missing or falsy policy is reported as ``"untrusted"``.
        An empty dict is returned when the lookup raises (the error is logged).
    """
    policies: Dict[str, Tuple[str, str]] = {}
    if not tool_names:
        return policies
    try:
        records = await prisma_client.db.litellm_tooltable.find_many(
            where={"tool_name": {"in": tool_names}},
        )
        for record in records:
            policies[record.tool_name] = (
                getattr(record, "input_policy", "untrusted") or "untrusted",
                getattr(record, "output_policy", "untrusted") or "untrusted",
            )
        return policies
    except Exception as e:
        verbose_proxy_logger.error(
            "tool_registry_writer get_tools_by_names error: %s", e
        )
        return {}
|
||
|
|
|
||
|
|
|
||
|
|
async def list_overrides_for_tool(
    prisma_client: "PrismaClient",
    tool_name: str,
) -> List[ToolPolicyOverrideRow]:
    """
    Build display rows for every key and team whose object permission blocks a tool.

    Queries object permissions whose ``blocked_tools`` contains ``tool_name``
    (including their ``verification_tokens`` and ``teams`` relations) and emits
    one row per related token followed by one row per related team, per
    permission.

    Returns:
        Rows with ``input_policy="blocked"`` and no timestamps. Token rows carry
        ``key_hash`` / ``key_alias`` from the token; team rows carry ``team_id``
        and put the team's ``team_alias`` in ``key_alias``. An empty list is
        returned when anything raises (the error is logged).
    """
    try:
        permissions = await prisma_client.db.litellm_objectpermissiontable.find_many(
            where={"blocked_tools": {"has": tool_name}},
            include={
                "verification_tokens": True,
                "teams": True,
            },
        )
        overrides: List[ToolPolicyOverrideRow] = []
        for permission in permissions:
            permission_id = getattr(permission, "object_permission_id", None) or ""
            linked_keys = getattr(permission, "verification_tokens", []) or []
            linked_teams = getattr(permission, "teams", []) or []
            # Key-scoped overrides first, then team-scoped ones.
            overrides.extend(
                ToolPolicyOverrideRow(
                    override_id=permission_id,
                    tool_name=tool_name,
                    team_id=None,
                    key_hash=getattr(key, "token", None),
                    input_policy="blocked",
                    key_alias=getattr(key, "key_alias", None),
                    created_at=None,
                    updated_at=None,
                )
                for key in linked_keys
            )
            overrides.extend(
                ToolPolicyOverrideRow(
                    override_id=permission_id,
                    tool_name=tool_name,
                    team_id=getattr(linked_team, "team_id", None),
                    key_hash=None,
                    input_policy="blocked",
                    key_alias=getattr(linked_team, "team_alias", None),
                    created_at=None,
                    updated_at=None,
                )
                for linked_team in linked_teams
            )
        return overrides
    except Exception as e:
        verbose_proxy_logger.error(
            "tool_registry_writer list_overrides_for_tool error: %s", e
        )
        return []
|
||
|
|
|
||
|
|
|
||
|
|
class ToolPolicyRegistry:
    """
    In-memory registry of tool policies synced from DB.

    Hot path uses get_effective_policies only — no DB, no cache.
    """

    def __init__(self) -> None:
        # tool_name -> global input policy
        self._tool_input_policies: Dict[str, str] = {}
        # tool_name -> global output policy
        self._tool_output_policies: Dict[str, str] = {}
        # object_permission_id -> tool names blocked by that permission
        self._blocked_tools_by_op_id: Dict[str, List[str]] = {}
        # Set to True after the first successful sync.
        self._initialized: bool = False

    def is_initialized(self) -> bool:
        """Return True once sync_tool_policy_from_db has succeeded at least once."""
        return self._initialized

    async def sync_tool_policy_from_db(self, prisma_client: "PrismaClient") -> None:
        """Load all tool policies and object-permission blocked_tools from DB.

        Both tables are read and the new maps are built completely before any
        registry attribute is replaced, so a failure in either query leaves the
        previously synced state untouched.

        Raises:
            Exception: Any error from the queries is logged and re-raised.
        """
        try:
            tools = await prisma_client.db.litellm_tooltable.find_many()
            perms = await prisma_client.db.litellm_objectpermissiontable.find_many()

            input_policies: Dict[str, str] = {
                row.tool_name: getattr(row, "input_policy", "untrusted") or "untrusted"
                for row in tools
            }
            output_policies: Dict[str, str] = {
                row.tool_name: getattr(row, "output_policy", "untrusted") or "untrusted"
                for row in tools
            }
            blocked_by_op_id: Dict[str, List[str]] = {}
            for row in perms:
                op_id = getattr(row, "object_permission_id", None)
                blocked = getattr(row, "blocked_tools", None) or []
                if op_id:
                    blocked_by_op_id[op_id] = list(blocked)

            # Swap in the new state only after everything above succeeded.
            self._tool_input_policies = input_policies
            self._tool_output_policies = output_policies
            self._blocked_tools_by_op_id = blocked_by_op_id
            self._initialized = True
            verbose_proxy_logger.info(
                "ToolPolicyRegistry: synced %d tool policies and %d object permissions from DB",
                len(self._tool_input_policies),
                len(self._blocked_tools_by_op_id),
            )
        except Exception as e:
            verbose_proxy_logger.exception(
                "ToolPolicyRegistry sync_tool_policy_from_db error: %s", e
            )
            raise

    def get_input_policy(self, tool_name: str) -> str:
        """Return the synced input policy for a tool, or "untrusted" if unknown."""
        return self._tool_input_policies.get(tool_name, "untrusted")

    def get_output_policy(self, tool_name: str) -> str:
        """Return the synced output policy for a tool, or "untrusted" if unknown."""
        return self._tool_output_policies.get(tool_name, "untrusted")

    def get_effective_policies(
        self,
        tool_names: List[str],
        object_permission_id: Optional[str] = None,
        team_object_permission_id: Optional[str] = None,
    ) -> Dict[str, str]:
        """
        Return effective input_policy per tool from in-memory state.

        If tool is in key or team blocked_tools -> "blocked", else global input_policy or "untrusted".

        Args:
            tool_names: Tools to resolve; an empty list yields an empty dict.
            object_permission_id: Permission id whose blocked tools apply.
                Surrounding whitespace is stripped; None or blank is ignored.
            team_object_permission_id: Second permission id, handled the same way.
        """
        if not tool_names:
            return {}
        blocked: set = set()
        for op_id in (object_permission_id, team_object_permission_id):
            if op_id and op_id.strip():
                blocked.update(self._blocked_tools_by_op_id.get(op_id.strip(), []))
        return {
            name: (
                "blocked"
                if name in blocked
                else self._tool_input_policies.get(name, "untrusted")
            )
            for name in tool_names
        }
|
||
|
|
|
||
|
|
|
||
|
|
# Process-wide registry instance; created lazily by get_tool_policy_registry().
_tool_policy_registry: Optional[ToolPolicyRegistry] = None


def get_tool_policy_registry() -> ToolPolicyRegistry:
    """Return the shared ToolPolicyRegistry, creating it on first use."""
    global _tool_policy_registry
    registry = _tool_policy_registry
    if registry is None:
        registry = ToolPolicyRegistry()
        _tool_policy_registry = registry
    return registry
|
||
|
|
|
||
|
|
|
||
|
|
async def add_tool_to_object_permission_blocked(
    prisma_client: "PrismaClient",
    object_permission_id: str,
    tool_name: str,
) -> bool:
    """Ensure ``tool_name`` is listed in an object permission's ``blocked_tools``.

    Returns:
        True when the tool is in the list after the call (whether it was
        already present or was just appended). False when either argument is
        empty, the permission row does not exist, or the DB call raises (the
        error is logged).
    """
    if not (object_permission_id and tool_name):
        return False
    lookup = {"object_permission_id": object_permission_id}
    try:
        permissions = prisma_client.db.litellm_objectpermissiontable
        permission = await permissions.find_unique(where=lookup)
        if permission is None:
            return False
        existing = list(getattr(permission, "blocked_tools", []) or [])
        if tool_name not in existing:
            # Write back the full list with the new name appended.
            await permissions.update(
                where={"object_permission_id": object_permission_id},
                data={"blocked_tools": [*existing, tool_name]},
            )
        return True
    except Exception as e:
        verbose_proxy_logger.error(
            "tool_registry_writer add_tool_to_object_permission_blocked error: %s", e
        )
        return False
|
||
|
|
|
||
|
|
|
||
|
|
async def remove_tool_from_object_permission_blocked(
    prisma_client: "PrismaClient",
    object_permission_id: str,
    tool_name: str,
) -> bool:
    """Drop every occurrence of ``tool_name`` from a permission's ``blocked_tools``.

    Returns:
        True when the tool was listed and the shortened list was written back.
        False when either argument is empty, the permission row does not exist,
        the tool was not in the list, or the DB call raises (the error is
        logged).
    """
    if not (object_permission_id and tool_name):
        return False
    try:
        permissions = prisma_client.db.litellm_objectpermissiontable
        permission = await permissions.find_unique(
            where={"object_permission_id": object_permission_id},
        )
        if permission is None:
            return False
        existing = list(getattr(permission, "blocked_tools", []) or [])
        if tool_name not in existing:
            return False
        remaining = [name for name in existing if name != tool_name]
        await permissions.update(
            where={"object_permission_id": object_permission_id},
            data={"blocked_tools": remaining},
        )
        return True
    except Exception as e:
        verbose_proxy_logger.error(
            "tool_registry_writer remove_tool_from_object_permission_blocked error: %s",
            e,
        )
        return False
|