"""
DB helpers for LiteLLM_ToolTable — the global tool registry.

Tools are auto-discovered from LLM responses and upserted here.
Admins use the management endpoints to read and update input_policy / output_policy.
"""

import uuid
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import ToolDiscoveryQueueItem
from litellm.types.tool_management import (
    LiteLLM_ToolTableRow,
    ToolPolicyOverrideRow,
)

if TYPE_CHECKING:
    from litellm.proxy.utils import PrismaClient

# Policy used whenever a tool has no stored (or a falsy) policy value.
_DEFAULT_POLICY = "untrusted"
# Effective policy reported for tools listed in an object permission's blocked_tools.
_BLOCKED_POLICY = "blocked"

# Attributes read from a row object that is neither a dict nor exposes model_dump();
# mirrors the fields passed to LiteLLM_ToolTableRow in _row_to_model.
_TOOL_ROW_FIELDS: Tuple[str, ...] = (
    "tool_id",
    "tool_name",
    "origin",
    "input_policy",
    "output_policy",
    "call_count",
    "assignments",
    "key_hash",
    "team_id",
    "key_alias",
    "user_agent",
    "last_used_at",
    "created_at",
    "updated_at",
    "created_by",
    "updated_by",
)


def _row_to_model(row: Union[dict, Any]) -> LiteLLM_ToolTableRow:
    """
    Convert a Prisma model instance or dict to LiteLLM_ToolTableRow.

    Objects with a callable ``model_dump`` are dumped to a dict first; any other
    non-dict object is read attribute-by-attribute (missing attributes become None).
    Falsy policies fall back to "untrusted" and a falsy call_count to 0.
    """
    model_dump = getattr(row, "model_dump", None)
    if callable(model_dump):
        row = model_dump()
    elif not isinstance(row, dict):
        row = {field: getattr(row, field, None) for field in _TOOL_ROW_FIELDS}
    return LiteLLM_ToolTableRow(
        tool_id=row.get("tool_id", ""),
        tool_name=row.get("tool_name", ""),
        origin=row.get("origin"),
        input_policy=row.get("input_policy") or _DEFAULT_POLICY,
        output_policy=row.get("output_policy") or _DEFAULT_POLICY,
        call_count=int(row.get("call_count") or 0),
        assignments=row.get("assignments"),
        key_hash=row.get("key_hash"),
        team_id=row.get("team_id"),
        key_alias=row.get("key_alias"),
        user_agent=row.get("user_agent"),
        last_used_at=row.get("last_used_at"),
        created_at=row.get("created_at"),
        updated_at=row.get("updated_at"),
        created_by=row.get("created_by"),
        updated_by=row.get("updated_by"),
    )


def _discovered_tool_upsert_data(
    item: ToolDiscoveryQueueItem, now: datetime
) -> Dict[str, Dict[str, Any]]:
    """Build the Prisma upsert ``data`` payload (create/update branches) for one discovered tool."""
    tool_name = item.get("tool_name", "")
    created_by = item.get("created_by") or "system"
    return {
        "create": {
            "tool_id": str(uuid.uuid4()),
            "tool_name": tool_name,
            "origin": item.get("origin") or "user_defined",
            "input_policy": _DEFAULT_POLICY,
            "output_policy": _DEFAULT_POLICY,
            "call_count": 1,
            "created_by": created_by,
            "updated_by": created_by,
            "key_hash": item.get("key_hash"),
            "team_id": item.get("team_id"),
            "key_alias": item.get("key_alias"),
            "user_agent": item.get("user_agent"),
            "last_used_at": now,
        },
        # Existing rows only get usage bookkeeping: policies and attribution are preserved.
        "update": {
            "call_count": {"increment": 1},
            "updated_at": now,
            "last_used_at": now,
        },
    }


async def batch_upsert_tools(
    prisma_client: "PrismaClient",
    items: List[ToolDiscoveryQueueItem],
) -> None:
    """
    Batch-upsert tool registry rows via Prisma.

    On first insert: sets input_policy/output_policy = "untrusted" (default), call_count = 1.
    On conflict: increments call_count and refreshes updated_at / last_used_at;
    preserves existing policies.

    Items without a tool_name are skipped. Each item is upserted independently, so a
    failure on one item does not discard the rest of the batch. Errors are logged and
    never raised.
    """
    if not items:
        return
    try:
        data = [item for item in items if item.get("tool_name")]
        if not data:
            return
        now = datetime.now(timezone.utc)
        table = prisma_client.db.litellm_tooltable

        failures = 0
        last_error: Optional[Exception] = None
        last_failed_tool: Optional[str] = None
        for item in data:
            try:
                await table.upsert(
                    where={"tool_name": item.get("tool_name", "")},
                    data=_discovered_tool_upsert_data(item, now),
                )
            except Exception as item_error:
                # Keep going: the remaining tools should still get their call_count bump.
                failures += 1
                last_error = item_error
                last_failed_tool = item.get("tool_name")

        verbose_proxy_logger.debug(
            "tool_registry_writer: upserted %d tool(s)", len(data) - failures
        )
        if failures:
            verbose_proxy_logger.error(
                "tool_registry_writer batch_upsert_tools error: %d of %d upsert(s) failed; "
                "last failure (tool_name=%s): %s",
                failures,
                len(data),
                last_failed_tool,
                last_error,
            )
    except Exception as e:
        verbose_proxy_logger.error(
            "tool_registry_writer batch_upsert_tools error: %s", e
        )


async def list_tools(
    prisma_client: "PrismaClient",
    input_policy: Optional[str] = None,
) -> List[LiteLLM_ToolTableRow]:
    """
    Return all tools (newest created first), optionally filtered by input_policy.

    Returns an empty list (after logging) if the query fails.
    """
    try:
        where = {"input_policy": input_policy} if input_policy is not None else {}
        rows = await prisma_client.db.litellm_tooltable.find_many(
            where=where,
            order={"created_at": "desc"},
        )
        return [_row_to_model(row) for row in rows]
    except Exception as e:
        verbose_proxy_logger.error("tool_registry_writer list_tools error: %s", e)
        return []


async def get_tool(
    prisma_client: "PrismaClient",
    tool_name: str,
) -> Optional[LiteLLM_ToolTableRow]:
    """Return a single tool row by tool_name, or None if missing or the query fails."""
    try:
        row = await prisma_client.db.litellm_tooltable.find_unique(
            where={"tool_name": tool_name},
        )
        if row is None:
            return None
        return _row_to_model(row)
    except Exception as e:
        verbose_proxy_logger.error("tool_registry_writer get_tool error: %s", e)
        return None


async def update_tool_policy(
    prisma_client: "PrismaClient",
    tool_name: str,
    updated_by: Optional[str],
    input_policy: Optional[str] = None,
    output_policy: Optional[str] = None,
) -> Optional[LiteLLM_ToolTableRow]:
    """
    Update input_policy and/or output_policy for a tool.

    Upserts the row if it does not exist yet (unspecified policies default to
    "untrusted" on create). On update, only the policies that are not None are
    changed. Returns the re-read row, or None if any step fails.
    """
    try:
        _updated_by = updated_by or "system"
        now = datetime.now(timezone.utc)
        create_data: dict = {
            "tool_id": str(uuid.uuid4()),
            "tool_name": tool_name,
            "input_policy": input_policy or _DEFAULT_POLICY,
            "output_policy": output_policy or _DEFAULT_POLICY,
            "created_by": _updated_by,
            "updated_by": _updated_by,
            "created_at": now,
            "updated_at": now,
        }
        update_data: dict = {
            "updated_by": _updated_by,
            "updated_at": now,
        }
        if input_policy is not None:
            update_data["input_policy"] = input_policy
        if output_policy is not None:
            update_data["output_policy"] = output_policy
        await prisma_client.db.litellm_tooltable.upsert(
            where={"tool_name": tool_name},
            data={
                "create": create_data,
                "update": update_data,
            },
        )
        return await get_tool(prisma_client, tool_name)
    except Exception as e:
        verbose_proxy_logger.error(
            "tool_registry_writer update_tool_policy error: %s", e
        )
        return None


async def get_tools_by_names(
    prisma_client: "PrismaClient",
    tool_names: List[str],
) -> Dict[str, Tuple[str, str]]:
    """
    Return a {tool_name: (input_policy, output_policy)} map for the given tool names.

    Names with no DB row are absent from the result; falsy policies map to
    "untrusted". Returns {} on empty input or query failure.
    """
    if not tool_names:
        return {}
    try:
        rows = await prisma_client.db.litellm_tooltable.find_many(
            where={"tool_name": {"in": tool_names}},
        )
        return {
            row.tool_name: (
                getattr(row, "input_policy", _DEFAULT_POLICY) or _DEFAULT_POLICY,
                getattr(row, "output_policy", _DEFAULT_POLICY) or _DEFAULT_POLICY,
            )
            for row in rows
        }
    except Exception as e:
        verbose_proxy_logger.error(
            "tool_registry_writer get_tools_by_names error: %s", e
        )
        return {}


async def list_overrides_for_tool(
    prisma_client: "PrismaClient",
    tool_name: str,
) -> List[ToolPolicyOverrideRow]:
    """
    Return override-like rows for a tool.

    Finds object permissions whose blocked_tools contains ``tool_name`` and emits
    one "blocked" row per linked verification token (key scope) and per linked
    team (team scope). The object_permission_id is used as the override_id.
    Returns [] (after logging) on failure.
    """
    out: List[ToolPolicyOverrideRow] = []
    try:
        perms = await prisma_client.db.litellm_objectpermissiontable.find_many(
            where={"blocked_tools": {"has": tool_name}},
            include={
                "verification_tokens": True,
                "teams": True,
            },
        )
        for perm in perms:
            op_id = getattr(perm, "object_permission_id", None) or ""
            tokens = getattr(perm, "verification_tokens", []) or []
            teams = getattr(perm, "teams", []) or []
            for token in tokens:
                out.append(
                    ToolPolicyOverrideRow(
                        override_id=op_id,
                        tool_name=tool_name,
                        team_id=None,
                        key_hash=getattr(token, "token", None),
                        input_policy=_BLOCKED_POLICY,
                        key_alias=getattr(token, "key_alias", None),
                        created_at=None,
                        updated_at=None,
                    )
                )
            for team in teams:
                out.append(
                    ToolPolicyOverrideRow(
                        override_id=op_id,
                        tool_name=tool_name,
                        team_id=getattr(team, "team_id", None),
                        key_hash=None,
                        input_policy=_BLOCKED_POLICY,
                        # Teams have no key alias; the team alias is shown in that column.
                        key_alias=getattr(team, "team_alias", None),
                        created_at=None,
                        updated_at=None,
                    )
                )
        return out
    except Exception as e:
        verbose_proxy_logger.error(
            "tool_registry_writer list_overrides_for_tool error: %s", e
        )
        return []


class ToolPolicyRegistry:
    """
    In-memory registry of tool policies synced from DB.

    Hot path uses get_effective_policies only — no DB, no cache.
    """

    def __init__(self) -> None:
        # tool_name -> input_policy / output_policy, as last loaded from LiteLLM_ToolTable.
        self._tool_input_policies: Dict[str, str] = {}
        self._tool_output_policies: Dict[str, str] = {}
        # object_permission_id -> blocked_tools list for that permission.
        self._blocked_tools_by_op_id: Dict[str, List[str]] = {}
        self._initialized: bool = False

    def is_initialized(self) -> bool:
        """Return True once sync_tool_policy_from_db has completed successfully at least once."""
        return self._initialized

    async def sync_tool_policy_from_db(self, prisma_client: "PrismaClient") -> None:
        """
        Load all tool policies and object-permission blocked_tools from DB.

        New state is built in locals and swapped in only after both queries succeed,
        so a failure leaves the previous snapshot intact instead of a half-updated
        mix. Errors are logged and re-raised.
        """
        try:
            tools = await prisma_client.db.litellm_tooltable.find_many()
            perms = await prisma_client.db.litellm_objectpermissiontable.find_many()

            input_policies: Dict[str, str] = {}
            output_policies: Dict[str, str] = {}
            for row in tools:
                input_policies[row.tool_name] = (
                    getattr(row, "input_policy", _DEFAULT_POLICY) or _DEFAULT_POLICY
                )
                output_policies[row.tool_name] = (
                    getattr(row, "output_policy", _DEFAULT_POLICY) or _DEFAULT_POLICY
                )

            blocked_by_op_id: Dict[str, List[str]] = {}
            for row in perms:
                op_id = getattr(row, "object_permission_id", None)
                if op_id:
                    blocked_by_op_id[op_id] = list(
                        getattr(row, "blocked_tools", None) or []
                    )

            self._tool_input_policies = input_policies
            self._tool_output_policies = output_policies
            self._blocked_tools_by_op_id = blocked_by_op_id
            self._initialized = True
            verbose_proxy_logger.info(
                "ToolPolicyRegistry: synced %d tool policies and %d object permissions from DB",
                len(self._tool_input_policies),
                len(self._blocked_tools_by_op_id),
            )
        except Exception as e:
            verbose_proxy_logger.exception(
                "ToolPolicyRegistry sync_tool_policy_from_db error: %s", e
            )
            raise

    def get_input_policy(self, tool_name: str) -> str:
        """Return the global input_policy for ``tool_name`` ("untrusted" if unknown)."""
        return self._tool_input_policies.get(tool_name, _DEFAULT_POLICY)

    def get_output_policy(self, tool_name: str) -> str:
        """Return the global output_policy for ``tool_name`` ("untrusted" if unknown)."""
        return self._tool_output_policies.get(tool_name, _DEFAULT_POLICY)

    def get_effective_policies(
        self,
        tool_names: List[str],
        object_permission_id: Optional[str] = None,
        team_object_permission_id: Optional[str] = None,
    ) -> Dict[str, str]:
        """
        Return effective input_policy per tool from in-memory state.

        If tool is in key or team blocked_tools -> "blocked", else global
        input_policy or "untrusted". Permission ids are whitespace-stripped;
        blank ids are ignored.
        """
        if not tool_names:
            return {}
        blocked: set = set()
        for op_id in (object_permission_id, team_object_permission_id):
            if op_id and op_id.strip():
                blocked.update(self._blocked_tools_by_op_id.get(op_id.strip(), []))
        result: Dict[str, str] = {}
        for name in tool_names:
            if name in blocked:
                result[name] = _BLOCKED_POLICY
            else:
                result[name] = self._tool_input_policies.get(name, _DEFAULT_POLICY)
        return result


_tool_policy_registry: Optional[ToolPolicyRegistry] = None


def get_tool_policy_registry() -> ToolPolicyRegistry:
    """Return the global ToolPolicyRegistry singleton, creating it on first use."""
    global _tool_policy_registry
    if _tool_policy_registry is None:
        _tool_policy_registry = ToolPolicyRegistry()
    return _tool_policy_registry


async def add_tool_to_object_permission_blocked(
    prisma_client: "PrismaClient",
    object_permission_id: str,
    tool_name: str,
) -> bool:
    """
    Add tool_name to the permission's blocked_tools if not already present.

    Returns True if the tool is (now) blocked, False on empty arguments, an
    unknown permission id, or a DB error. Only the DB row is written; the
    in-memory ToolPolicyRegistry is not touched here.
    """
    if not object_permission_id or not tool_name:
        return False
    try:
        row = await prisma_client.db.litellm_objectpermissiontable.find_unique(
            where={"object_permission_id": object_permission_id},
        )
        if row is None:
            return False
        current = list(getattr(row, "blocked_tools", []) or [])
        if tool_name in current:
            return True
        current.append(tool_name)
        # NOTE(review): read-modify-write; concurrent edits to the same permission
        # can overwrite each other.
        await prisma_client.db.litellm_objectpermissiontable.update(
            where={"object_permission_id": object_permission_id},
            data={"blocked_tools": current},
        )
        return True
    except Exception as e:
        verbose_proxy_logger.error(
            "tool_registry_writer add_tool_to_object_permission_blocked error: %s", e
        )
        return False


async def remove_tool_from_object_permission_blocked(
    prisma_client: "PrismaClient",
    object_permission_id: str,
    tool_name: str,
) -> bool:
    """
    Remove tool_name from the permission's blocked_tools.

    Returns False if tool was not in list, on empty arguments, an unknown
    permission id, or a DB error. Only the DB row is written; the in-memory
    ToolPolicyRegistry is not touched here.
    """
    if not object_permission_id or not tool_name:
        return False
    try:
        row = await prisma_client.db.litellm_objectpermissiontable.find_unique(
            where={"object_permission_id": object_permission_id},
        )
        if row is None:
            return False
        current = list(getattr(row, "blocked_tools", []) or [])
        if tool_name not in current:
            return False
        # Drop every occurrence, in case the list contains duplicates.
        current = [t for t in current if t != tool_name]
        await prisma_client.db.litellm_objectpermissiontable.update(
            where={"object_permission_id": object_permission_id},
            data={"blocked_tools": current},
        )
        return True
    except Exception as e:
        verbose_proxy_logger.error(
            "tool_registry_writer remove_tool_from_object_permission_blocked error: %s",
            e,
        )
        return False