# +-----------------------------------------------+
# |                                               |
# |               PII Masking                     |
# |         with Microsoft Presidio               |
# |   https://github.com/BerriAI/litellm/issues/  |
# +-----------------------------------------------+
#
#  Tell us how we can improve! - Krrish & Ishaan


import asyncio
import json
import threading
from contextlib import asynccontextmanager
from datetime import datetime
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncGenerator,
    Dict,
    List,
    Literal,
    Optional,
    Tuple,
    Union,
    cast,
)

import aiohttp

import litellm  # noqa: E401
from litellm import get_secret
from litellm._logging import verbose_proxy_logger
from litellm.types.utils import GenericGuardrailAPIInputs

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj

from litellm.caching.caching import DualCache
from litellm.exceptions import BlockedPiiEntityError, GuardrailRaisedException
from litellm.integrations.custom_guardrail import (
    CustomGuardrail,
    log_guardrail_information,
)
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.guardrails import (
    GuardrailEventHooks,
    LitellmParams,
    PiiAction,
    PiiEntityType,
    PresidioPerRequestConfig,
)
from litellm.types.proxy.guardrails.guardrail_hooks.presidio import (
    PresidioAnalyzeRequest,
    PresidioAnalyzeResponseItem,
)
from litellm.types.utils import GuardrailStatus, StreamingChoices
from litellm.utils import (
    EmbeddingResponse,
    ImageResponse,
    ModelResponse,
    ModelResponseStream,
)


class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
    """
    Guardrail that sends text to a Microsoft Presidio analyzer and anonymizer
    service to detect, mask or block PII in requests, and optionally to mask
    model output (``apply_to_output``) or restore masked tokens in model
    output (``output_parse_pii``).
    """

    user_api_key_cache = None
    # Parsed JSON from the ``presidio_ad_hoc_recognizers`` file; sent with every
    # /analyze request when not None.
    ad_hoc_recognizers = None

    # Class variables or attributes
    def __init__(
        self,
        mock_testing: bool = False,
        mock_redacted_text: Optional[dict] = None,
        presidio_analyzer_api_base: Optional[str] = None,
        presidio_anonymizer_api_base: Optional[str] = None,
        output_parse_pii: Optional[bool] = False,
        apply_to_output: bool = False,
        presidio_ad_hoc_recognizers: Optional[str] = None,
        logging_only: Optional[bool] = None,
        pii_entities_config: Optional[
            Dict[Union[PiiEntityType, str], PiiAction]
        ] = None,
        presidio_language: Optional[str] = None,
        presidio_score_thresholds: Optional[
            Dict[Union[PiiEntityType, str], float]
        ] = None,
        presidio_entities_deny_list: Optional[List[Union[PiiEntityType, str]]] = None,
        **kwargs,
    ):
        """
        Configure the Presidio guardrail.

        Args:
            mock_testing: if True, return right after basic attribute setup,
                skipping ad-hoc recognizer loading and endpoint validation.
            mock_redacted_text: canned response used instead of calling Presidio.
            presidio_analyzer_api_base: analyzer base URL; falls back to the
                ``PRESIDIO_ANALYZER_API_BASE`` secret.
            presidio_anonymizer_api_base: anonymizer base URL; falls back to the
                ``PRESIDIO_ANONYMIZER_API_BASE`` secret.
            output_parse_pii: number "replace" tokens per request and restore
                them in the model response (post_call).
            apply_to_output: also mask PII in the model response (post_call).
            presidio_ad_hoc_recognizers: path to a JSON file of ad-hoc recognizers.
            logging_only: run only as a logging hook (masks what gets logged).
            pii_entities_config: entity type -> action; its keys are sent as the
                ``entities`` list to /analyze.
            presidio_language: analyzer language, defaults to "en".
            presidio_score_thresholds: per-entity minimum score ("ALL" is the
                fallback key).
            presidio_entities_deny_list: entity types whose detections are dropped.
            **kwargs: forwarded to CustomGuardrail.__init__.

        Raises:
            Exception: if the ad-hoc recognizer file cannot be read/parsed, or
                if an API base cannot be resolved (see validate_environment).
        """
        if logging_only is True:
            self.logging_only = True
            kwargs["event_hook"] = GuardrailEventHooks.logging_only
        super().__init__(**kwargs)
        self.guardrail_provider = "presidio"
        self.pii_tokens: dict = (
            {}
        )  # mapping of PII token to original text - only used with Presidio `replace` operation
        self.mock_redacted_text = mock_redacted_text
        self.output_parse_pii = output_parse_pii or False
        self.apply_to_output = apply_to_output
        # When output_parse_pii or apply_to_output is enabled, the guardrail must
        # also run on post_call to unmask/mask the response. Expand the event_hook
        # so should_run_guardrail returns True for both pre_call and post_call.
        if (self.output_parse_pii or self.apply_to_output) and not logging_only:
            current_hook = self.event_hook
            if isinstance(current_hook, str) and current_hook != "post_call":
                self.event_hook = cast(List[GuardrailEventHooks], [current_hook, "post_call"])
            elif isinstance(current_hook, list) and "post_call" not in current_hook:
                self.event_hook = cast(List[GuardrailEventHooks], current_hook + ["post_call"])
        # Missing configs are normalized to empty containers, never None.
        self.pii_entities_config: Dict[Union[PiiEntityType, str], PiiAction] = (
            pii_entities_config or {}
        )
        self.presidio_score_thresholds: Dict[Union[PiiEntityType, str], float] = (
            presidio_score_thresholds or {}
        )
        self.presidio_entities_deny_list: List[Union[PiiEntityType, str]] = (
            presidio_entities_deny_list or []
        )
        self.presidio_language = presidio_language or "en"
        # Shared HTTP session to prevent memory leaks (issue #14540)
        self._http_session: Optional[aiohttp.ClientSession] = None
        # Lock to prevent race conditions when creating session under concurrent load
        # Note: asyncio.Lock() can be created without an event loop; it only needs one when awaited
        self._session_lock: asyncio.Lock = asyncio.Lock()
        # Track main thread ID to safely identify when we are running in main loop vs background thread
        self._main_thread_id = threading.get_ident()
        # Loop-bound session cache for background threads
        self._loop_sessions: Dict[asyncio.AbstractEventLoop, aiohttp.ClientSession] = {}

        if mock_testing is True:  # for testing purposes only
            return

        ad_hoc_recognizers = presidio_ad_hoc_recognizers
        if ad_hoc_recognizers is not None:
            try:
                with open(ad_hoc_recognizers, "r") as file:
                    self.ad_hoc_recognizers = json.load(file)
            except FileNotFoundError:
                raise Exception(f"File not found. file_path={ad_hoc_recognizers}")
            except json.JSONDecodeError as e:
                raise Exception(
                    f"Error decoding JSON file: {str(e)}, file_path={ad_hoc_recognizers}"
                )
            except Exception as e:
                raise Exception(
                    f"An error occurred: {str(e)}, file_path={ad_hoc_recognizers}"
                )

        self.validate_environment(
            presidio_analyzer_api_base=presidio_analyzer_api_base,
            presidio_anonymizer_api_base=presidio_anonymizer_api_base,
        )

    def validate_environment(
        self,
        presidio_analyzer_api_base: Optional[str] = None,
        presidio_anonymizer_api_base: Optional[str] = None,
    ):
        """
        Resolve and normalize the analyzer/anonymizer base URLs.

        Explicit arguments win over the ``PRESIDIO_ANALYZER_API_BASE`` /
        ``PRESIDIO_ANONYMIZER_API_BASE`` secrets. Each URL is forced to end in
        "/" and gets an "http://" prefix when no scheme is given, because the
        endpoint paths ("analyze", "anonymize") are appended directly.

        Raises:
            Exception: if either base URL cannot be resolved.
        """
        self.presidio_analyzer_api_base: Optional[
            str
        ] = presidio_analyzer_api_base or get_secret(
            "PRESIDIO_ANALYZER_API_BASE", None
        )  # type: ignore
        self.presidio_anonymizer_api_base: Optional[
            str
        ] = presidio_anonymizer_api_base or litellm.get_secret(
            "PRESIDIO_ANONYMIZER_API_BASE", None
        )  # type: ignore

        if self.presidio_analyzer_api_base is None:
            raise Exception("Missing `PRESIDIO_ANALYZER_API_BASE` from environment")
        if not self.presidio_analyzer_api_base.endswith("/"):
            self.presidio_analyzer_api_base += "/"
        if not (
            self.presidio_analyzer_api_base.startswith("http://")
            or self.presidio_analyzer_api_base.startswith("https://")
        ):
            # add http:// if unset, assume communicating over private network - e.g. render
            self.presidio_analyzer_api_base = (
                "http://" + self.presidio_analyzer_api_base
            )

        if self.presidio_anonymizer_api_base is None:
            raise Exception("Missing `PRESIDIO_ANONYMIZER_API_BASE` from environment")
        if not self.presidio_anonymizer_api_base.endswith("/"):
            self.presidio_anonymizer_api_base += "/"
        if not (
            self.presidio_anonymizer_api_base.startswith("http://")
            or self.presidio_anonymizer_api_base.startswith("https://")
        ):
            # add http:// if unset, assume communicating over private network - e.g. render
            self.presidio_anonymizer_api_base = (
                "http://" + self.presidio_anonymizer_api_base
            )

    @asynccontextmanager
    async def _get_session_iterator(
        self,
    ) -> AsyncGenerator[aiohttp.ClientSession, None]:
        """
        Async context manager for yielding an HTTP session.

        Sessions are cached and reused across calls; leaving the context does
        not close the session.

        Logic:
        1. If running in the main thread (where the object was initialized/destined to live normally),
           use the shared `self._http_session` (protected by a lock).
        2. If running in a background thread (e.g. logging hook), use a cached session for that loop.
""" current_loop = asyncio.get_running_loop() # Check if we are in the stored main thread if threading.get_ident() == self._main_thread_id: # Main thread -> use shared session async with self._session_lock: if self._http_session is None or self._http_session.closed: self._http_session = aiohttp.ClientSession() yield self._http_session else: # Background thread/loop -> use loop-bound session cache # This avoids "attached to a different loop" or "no running event loop" errors # when accessing the shared session created in the main loop if ( current_loop not in self._loop_sessions or self._loop_sessions[current_loop].closed ): self._loop_sessions[current_loop] = aiohttp.ClientSession() yield self._loop_sessions[current_loop] async def _close_http_session(self) -> None: """Close all cached HTTP sessions.""" if self._http_session is not None and not self._http_session.closed: await self._http_session.close() self._http_session = None for session in self._loop_sessions.values(): if not session.closed: await session.close() self._loop_sessions.clear() def __del__(self): """Cleanup: we try to close, but doing async cleanup in __del__ is risky.""" pass def _has_block_action(self) -> bool: """Return True if pii_entities_config has any BLOCK action (fail-closed on analyzer errors).""" if not self.pii_entities_config: return False return any( action == PiiAction.BLOCK for action in self.pii_entities_config.values() ) def _get_presidio_analyze_request_payload( self, text: str, presidio_config: Optional[PresidioPerRequestConfig], request_data: dict, ) -> PresidioAnalyzeRequest: """ Construct the payload for the Presidio analyze request API Ref: https://microsoft.github.io/presidio/api-docs/api-docs.html#tag/Analyzer/paths/~1analyze/post """ analyze_payload: PresidioAnalyzeRequest = PresidioAnalyzeRequest( text=text, language=self.presidio_language, ) ################################################################## ###### Check if user has configured any params for this 
guardrail ################################################################ if self.ad_hoc_recognizers is not None: analyze_payload["ad_hoc_recognizers"] = self.ad_hoc_recognizers if self.pii_entities_config: analyze_payload["entities"] = list(self.pii_entities_config.keys()) ################################################################## ######### End of adding config params ################################################################## # Check if client side request passed any dynamic params if presidio_config and presidio_config.language: analyze_payload["language"] = presidio_config.language casted_analyze_payload: dict = cast(dict, analyze_payload) casted_analyze_payload.update( self.get_guardrail_dynamic_request_body_params(request_data=request_data) ) return cast(PresidioAnalyzeRequest, casted_analyze_payload) async def analyze_text( self, text: str, presidio_config: Optional[PresidioPerRequestConfig], request_data: dict, ) -> Union[List[PresidioAnalyzeResponseItem], Dict]: """ Send text to the Presidio analyzer endpoint and get analysis results """ try: # Skip empty or whitespace-only text to avoid Presidio errors # Common in tool/function calling where assistant content is empty if not text or len(text.strip()) == 0: verbose_proxy_logger.debug( "Skipping Presidio analysis for empty/whitespace-only text" ) return [] if self.mock_redacted_text is not None: return self.mock_redacted_text # Use shared session to prevent memory leak (issue #14540) async with self._get_session_iterator() as session: # Make the request to /analyze analyze_url = f"{self.presidio_analyzer_api_base}analyze" analyze_payload: PresidioAnalyzeRequest = ( self._get_presidio_analyze_request_payload( text=text, presidio_config=presidio_config, request_data=request_data, ) ) verbose_proxy_logger.debug( "Making request to: %s with payload: %s", analyze_url, analyze_payload, ) def _fail_on_invalid_response( reason: str, ) -> List[PresidioAnalyzeResponseItem]: should_fail_closed = ( 
bool(self.pii_entities_config)
                        or self.output_parse_pii
                        or self.apply_to_output
                    )
                    # Raise when any PII protection (entity config, output
                    # unmasking or output masking) is configured; otherwise
                    # log and degrade to "no detections".
                    if should_fail_closed:
                        raise GuardrailRaisedException(
                            guardrail_name=self.guardrail_name,
                            message=f"Presidio analyzer returned invalid response; cannot verify PII when PII protection is configured: {reason}",
                            should_wrap_with_default_message=False,
                        )
                    verbose_proxy_logger.warning(
                        "Presidio analyzer %s, returning empty list", reason
                    )
                    return []

                async with session.post(
                    analyze_url,
                    json=analyze_payload,
                    headers={"Accept": "application/json"},
                ) as response:
                    # Validate HTTP status
                    if response.status >= 400:
                        error_body = await response.text()
                        return _fail_on_invalid_response(
                            f"HTTP {response.status} from Presidio analyzer: {error_body[:200]}"
                        )

                    # Validate Content-Type is JSON
                    content_type = getattr(
                        response,
                        "content_type",
                        response.headers.get("Content-Type", ""),
                    )
                    if "application/json" not in content_type:
                        error_body = await response.text()
                        return _fail_on_invalid_response(
                            f"expected application/json Content-Type but received '{content_type}'; body: '{error_body[:200]}'"
                        )

                    analyze_results = await response.json()
                    verbose_proxy_logger.debug("analyze_results: %s", analyze_results)

                    # Handle error responses from Presidio (e.g., {'error': 'No text provided'})
                    # Presidio may return a dict instead of a list when errors occur
                    if isinstance(analyze_results, dict):
                        if "error" in analyze_results:
                            return _fail_on_invalid_response(
                                f"error: {analyze_results.get('error')}"
                            )
                        # If it's a dict but not an error, try to process it as a single item
                        verbose_proxy_logger.debug(
                            "Presidio returned dict (not list), attempting to process as single item"
                        )
                        try:
                            return [PresidioAnalyzeResponseItem(**analyze_results)]
                        except Exception as e:
                            return _fail_on_invalid_response(
                                f"failed to parse dict response: {e}"
                            )

                    # Handle unexpected types (str, None, etc.) - e.g. from malformed/error
                    if not isinstance(analyze_results, list):
                        return _fail_on_invalid_response(
                            f"unexpected type {type(analyze_results).__name__} (expected list or dict), response: {str(analyze_results)[:200]}"
                        )

                    # Normal case: list of results. Malformed individual items
                    # are skipped with a warning rather than failing the call.
                    final_results = []
                    for item in analyze_results:
                        if not isinstance(item, dict):
                            verbose_proxy_logger.warning(
                                "Skipping invalid Presidio result item (expected dict, got %s): %s",
                                type(item).__name__,
                                str(item)[:100],
                            )
                            continue
                        try:
                            final_results.append(PresidioAnalyzeResponseItem(**item))
                        except Exception as e:
                            verbose_proxy_logger.warning(
                                "Failed to parse Presidio result item: %s (error: %s)",
                                item,
                                e,
                            )
                            continue
                    return final_results
        except GuardrailRaisedException:
            # Re-raise GuardrailRaisedException without wrapping
            raise
        except Exception as e:
            # Sanitize exception to avoid leaking the original text (which may
            # contain API keys or other secrets) in error responses.
            raise Exception(f"Presidio PII analysis failed: {type(e).__name__}") from e

    async def anonymize_text(
        self,
        text: str,
        analyze_results: Any,
        output_parse_pii: bool,
        masked_entity_count: Dict[str, int],
        request_data: Optional[Dict] = None,
    ) -> str:
        """
        Send analysis results to the Presidio anonymizer endpoint to get redacted text

        Args:
            text: the text that was analyzed.
            analyze_results: analyzer detections, forwarded as ``analyzer_results``.
            output_parse_pii: when True, every "replace" token gets a
                per-request sequence suffix and is recorded in
                ``request_data["metadata"]["pii_tokens"]`` for later unmasking.
            masked_entity_count: mutated in place; incremented once per item
                that carries an ``entity_type``.
            request_data: request dict that receives the pii_tokens map.

        Returns:
            ``text`` unchanged when there are no detections, otherwise the text
            with each anonymizer item spliced in.

        Raises:
            Exception: on HTTP >= 400, a non-JSON response, a null body, or any
                other failure (the latter re-raised with only the exception
                type name so the input text is not echoed).
        """
        try:
            # If there are no detections after filtering, return the original text
            if isinstance(analyze_results, list) and len(analyze_results) == 0:
                return text

            # Use shared session to prevent memory leak (issue #14540)
            async with self._get_session_iterator() as session:
                # Make the request to /anonymize
                anonymize_url = f"{self.presidio_anonymizer_api_base}anonymize"
                verbose_proxy_logger.debug("Making request to: %s", anonymize_url)
                anonymize_payload = {
                    "text": text,
                    "analyzer_results": analyze_results,
                }

                async with session.post(
                    anonymize_url,
                    json=anonymize_payload,
                    headers={"Accept": "application/json"},
                ) as response:
                    # Validate HTTP status
                    if response.status >= 400:
                        error_body = await response.text()
                        raise Exception(
                            f"Presidio anonymizer returned HTTP {response.status}: {error_body[:200]}"
                        )

                    # Validate Content-Type is JSON
                    content_type = getattr(
                        response,
                        "content_type",
                        response.headers.get("Content-Type", ""),
                    )
                    if "application/json" not in content_type:
                        error_body = await response.text()
                        raise Exception(
                            f"Presidio anonymizer returned non-JSON Content-Type '{content_type}'; body: '{error_body[:200]}'"
                        )

                    redacted_text = await response.json()

                new_text = text
                if redacted_text is not None:
                    verbose_proxy_logger.debug("redacted_text: %s", redacted_text)
                    # Process items in reverse order by start position so that
                    # replacing later spans first does not shift earlier coordinates.
                    for item in sorted(
                        redacted_text["items"], key=lambda x: x["start"], reverse=True
                    ):
                        start = item["start"]
                        end = item["end"]
                        replacement = item["text"]  # replacement token
                        if item["operator"] == "replace" and output_parse_pii is True:
                            if request_data is None:
                                verbose_proxy_logger.warning(
                                    "Presidio anonymize_text called without request_data — "
                                    "PII tokens cannot be stored per-request. "
                                    "This may indicate a missing caller update."
                                )
                                request_data = {}
                            # Store pii_tokens in metadata to avoid leaking to LLM providers.
                            # Providers like Anthropic reject unknown top-level fields.
                            if not request_data.get("metadata"):
                                request_data["metadata"] = {}
                            if "pii_tokens" not in request_data["metadata"]:
                                request_data["metadata"]["pii_tokens"] = {}
                            pii_tokens = request_data["metadata"]["pii_tokens"]

                            # Append a sequential number to make each token unique
                            # per request, so unmasking maps back to the correct
                            # original value. Format: `<ENTITY>` becomes `<ENTITY_N>`
                            # (e.g. `<PERSON_1>`); tokens without a closing `>`
                            # become `TOKEN_N`. This is LLM-friendly and degrades
                            # gracefully if the LLM doesn't echo the token verbatim.
                            seq = len(pii_tokens) + 1
                            if replacement.endswith(">"):
                                replacement = f"{replacement[:-1]}_{seq}>"
                            else:
                                replacement = f"{replacement}_{seq}"
                            # Use ORIGINAL text (not new_text) since start/end
                            # are treated here as the original text's coordinates.
                            # NOTE(review): the Presidio anonymizer API reference
                            # example reports items[].start/end as offsets into
                            # the *anonymized output* text, not the input. If so,
                            # text[start:end] here and the splice into new_text
                            # below select the wrong span whenever a replacement's
                            # length differs from the original value - confirm
                            # against the deployed Presidio version.
                            pii_tokens[replacement] = text[start:end]
                        new_text = new_text[:start] + replacement + new_text[end:]
                        entity_type = item.get("entity_type", None)
                        if entity_type is not None:
                            masked_entity_count[entity_type] = (
                                masked_entity_count.get(entity_type, 0) + 1
                            )
                    # When output_parse_pii is True, new_text contains sequentially
                    # numbered tokens (e.g. <PERSON_1>) that match the keys
                    # in pii_tokens. Returning redacted_text["text"] (Presidio's
                    # original output) would send un-numbered tokens to the LLM,
                    # making unmasking impossible.
                    # When output_parse_pii is False, no suffix is appended, so
                    # new_text is intended to equal redacted_text["text"].
                    return new_text
                else:
                    raise Exception("Invalid anonymizer response: received None")
        except Exception as e:
            # Sanitize exception to avoid leaking the original text (which may
            # contain API keys or other secrets) in error responses.
            # Our own validation errors above are already safe and re-raised as-is.
            error_str = str(e)
            if (
                "Invalid anonymizer response" in error_str
                or "Presidio anonymizer returned" in error_str
            ):
                raise
            raise Exception(
                f"Presidio PII anonymization failed: {type(e).__name__}"
            ) from e

    def filter_analyze_results_by_score(
        self, analyze_results: Union[List[PresidioAnalyzeResponseItem], Dict]
    ) -> Union[List[PresidioAnalyzeResponseItem], Dict]:
        """
        Drop detections that fall below configured per-entity score thresholds
        or match an entity type in the deny list.

        A threshold is looked up by the detection's entity type, falling back
        to the "ALL" key. When a threshold applies, detections without a score
        are dropped. Non-list input (e.g. a mock dict) is returned unchanged.
""" if not self.presidio_score_thresholds and not self.presidio_entities_deny_list: return analyze_results if not isinstance(analyze_results, list): return analyze_results filtered_results: List[PresidioAnalyzeResponseItem] = [] deny_list_strings = [ getattr(x, "value", str(x)) for x in self.presidio_entities_deny_list ] for item in analyze_results: entity_type = item.get("entity_type") str_entity_type = str( getattr(entity_type, "value", entity_type) if entity_type is not None else entity_type ) if entity_type and str_entity_type in deny_list_strings: continue if self.presidio_score_thresholds: score = item.get("score") threshold = None if entity_type is not None: threshold = self.presidio_score_thresholds.get(entity_type) if threshold is None: threshold = self.presidio_score_thresholds.get("ALL") if threshold is not None: if score is None or score < threshold: continue filtered_results.append(item) return filtered_results def raise_exception_if_blocked_entities_detected( self, analyze_results: Union[List[PresidioAnalyzeResponseItem], Dict] ): """ Raise an exception if blocked entities are detected """ if self.pii_entities_config is None: return if isinstance(analyze_results, Dict): # if mock testing is enabled, analyze_results is a dict # we don't need to raise an exception in this case return for result in analyze_results: entity_type = result.get("entity_type") if entity_type: # Check if entity_type is in config (supports both enum and string) if ( entity_type in self.pii_entities_config and self.pii_entities_config[entity_type] == PiiAction.BLOCK ): raise BlockedPiiEntityError( entity_type=entity_type, guardrail_name=self.guardrail_name, ) async def check_pii( self, text: str, output_parse_pii: bool, presidio_config: Optional[PresidioPerRequestConfig], request_data: dict, ) -> str: """ Calls Presidio Analyze + Anonymize endpoints for PII Analysis + Masking """ start_time = datetime.now() analyze_results: Optional[Union[List[PresidioAnalyzeResponseItem], Dict]] 
= None status: GuardrailStatus = "success" masked_entity_count: Dict[str, int] = {} exception_str: str = "" try: if self.mock_redacted_text is not None: redacted_text = self.mock_redacted_text else: # First get analysis results analyze_results = await self.analyze_text( text=text, presidio_config=presidio_config, request_data=request_data, ) verbose_proxy_logger.debug("analyze_results: %s", analyze_results) # Apply score threshold filtering if configured analyze_results = self.filter_analyze_results_by_score( analyze_results=analyze_results ) #################################################### # Blocked Entities check #################################################### self.raise_exception_if_blocked_entities_detected( analyze_results=analyze_results ) # Then anonymize the text using the analysis results anonymized_text = await self.anonymize_text( text=text, analyze_results=analyze_results, output_parse_pii=output_parse_pii, masked_entity_count=masked_entity_count, request_data=request_data, ) return anonymized_text return redacted_text["text"] except Exception as e: status = "guardrail_failed_to_respond" exception_str = str(e) raise e finally: #################################################### # Create Guardrail Trace for logging on Langfuse, Datadog, etc. 
#################################################### guardrail_json_response: Union[Exception, str, dict, List[dict]] = {} if status == "success": if isinstance(analyze_results, List): guardrail_json_response = [dict(item) for item in analyze_results] else: guardrail_json_response = exception_str self.add_standard_logging_guardrail_information_to_request_data( guardrail_provider=self.guardrail_provider, guardrail_json_response=guardrail_json_response, request_data=request_data, guardrail_status=status, start_time=start_time.timestamp(), end_time=datetime.now().timestamp(), duration=(datetime.now() - start_time).total_seconds(), masked_entity_count=masked_entity_count, ) async def async_pre_call_hook( self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: str, ): """ - Check if request turned off pii - Check if user allowed to turn off pii (key permissions -> 'allow_pii_controls') - Take the request data - Call /analyze -> get the results - Call /anonymize w/ the analyze results -> get the redacted text For multiple messages in /chat/completions, we'll need to call them in parallel. 
""" try: content_safety = data.get("content_safety", None) verbose_proxy_logger.debug("content_safety: %s", content_safety) presidio_config = self.get_presidio_settings_from_request_data(data) messages = data.get("messages", None) if messages is None: return data tasks = [] task_mappings: List[ Tuple[int, Optional[int]] ] = [] # Track (message_index, content_index) for each task for msg_idx, m in enumerate(messages): content = m.get("content", None) if content is None: continue if isinstance(content, str): tasks.append( self.check_pii( text=content, output_parse_pii=self.output_parse_pii, presidio_config=presidio_config, request_data=data, ) ) task_mappings.append( (msg_idx, None) ) # None indicates string content elif isinstance(content, list): for content_idx, c in enumerate(content): text_str = c.get("text", None) if text_str is None: continue tasks.append( self.check_pii( text=text_str, output_parse_pii=self.output_parse_pii, presidio_config=presidio_config, request_data=data, ) ) task_mappings.append((msg_idx, int(content_idx))) responses = await asyncio.gather(*tasks) # Map responses back to the correct message and content item for task_idx, r in enumerate(responses): mapping = task_mappings[task_idx] msg_idx = cast(int, mapping[0]) content_idx_optional = cast(Optional[int], mapping[1]) content = messages[msg_idx].get("content", None) if content is None: continue if isinstance(content, str) and content_idx_optional is None: messages[msg_idx][ "content" ] = r # replace content with redacted string elif isinstance(content, list) and content_idx_optional is not None: messages[msg_idx]["content"][content_idx_optional]["text"] = r verbose_proxy_logger.debug( f"Presidio PII Masking: Redacted pii message: {data['messages']}" ) data["messages"] = messages return data except Exception as e: raise e def logging_hook( self, kwargs: dict, result: Any, call_type: str ) -> Tuple[dict, Any]: from concurrent.futures import ThreadPoolExecutor def run_in_new_loop(): """Run the 
coroutine in a new event loop within this thread.""" new_loop = asyncio.new_event_loop() try: asyncio.set_event_loop(new_loop) return new_loop.run_until_complete( self.async_logging_hook( kwargs=kwargs, result=result, call_type=call_type ) ) finally: new_loop.close() asyncio.set_event_loop(None) try: # First, try to get the current event loop _ = asyncio.get_running_loop() # If we're already in an event loop, run in a separate thread # to avoid nested event loop issues with ThreadPoolExecutor(max_workers=1) as executor: future = executor.submit(run_in_new_loop) return future.result() except RuntimeError: # No running event loop, we can safely run in this thread return run_in_new_loop() async def async_logging_hook( self, kwargs: dict, result: Any, call_type: str ) -> Tuple[dict, Any]: """ Masks the input before logging to langfuse, datadog, etc. """ if ( call_type == "completion" or call_type == "acompletion" ): # /chat/completions requests messages: Optional[List] = kwargs.get("messages", None) tasks = [] task_mappings: List[ Tuple[int, Optional[int]] ] = [] # Track (message_index, content_index) for each task if messages is None: return kwargs, result presidio_config = self.get_presidio_settings_from_request_data(kwargs) for msg_idx, m in enumerate(messages): content = m.get("content", None) if content is None: continue if isinstance(content, str): tasks.append( self.check_pii( text=content, output_parse_pii=False, presidio_config=presidio_config, request_data=kwargs, ) ) # need to pass separately b/c presidio has context window limits task_mappings.append( (msg_idx, None) ) # None indicates string content elif isinstance(content, list): for content_idx, c in enumerate(content): text_str = c.get("text", None) if text_str is None: continue tasks.append( self.check_pii( text=text_str, output_parse_pii=False, presidio_config=presidio_config, request_data=kwargs, ) ) task_mappings.append((msg_idx, int(content_idx))) responses = await asyncio.gather(*tasks) # Map 
responses back to the correct message and content item for task_idx, r in enumerate(responses): mapping = task_mappings[task_idx] msg_idx = cast(int, mapping[0]) content_idx_optional = cast(Optional[int], mapping[1]) content = messages[msg_idx].get("content", None) if content is None: continue if isinstance(content, str) and content_idx_optional is None: messages[msg_idx][ "content" ] = r # replace content with redacted string elif isinstance(content, list) and content_idx_optional is not None: messages[msg_idx]["content"][content_idx_optional]["text"] = r verbose_proxy_logger.debug( f"Presidio PII Masking: Redacted pii message: {messages}" ) kwargs["messages"] = messages return kwargs, result async def async_post_call_success_hook( # type: ignore self, data: dict, user_api_key_dict: UserAPIKeyAuth, response: Union[ModelResponse, EmbeddingResponse, ImageResponse], ): """ Output parse the response object to replace the masked tokens with user sent values """ verbose_proxy_logger.debug( f"PII Masking Args: self.output_parse_pii={self.output_parse_pii}; type of response={type(response)}" ) if self.apply_to_output is True: if self._is_anthropic_message_response(response): return await self._process_anthropic_response_for_pii( response=cast(dict, response), request_data=data, mode="mask" ) return await self._mask_output_response( response=response, request_data=data ) if self.output_parse_pii is False and litellm.output_parse_pii is False: return response if isinstance(response, ModelResponse) and not isinstance( response.choices[0], StreamingChoices ): # /chat/completions requests await self._process_response_for_pii( response=response, request_data=data, mode="unmask", ) elif self._is_anthropic_message_response(response): await self._process_anthropic_response_for_pii( response=cast(dict, response), request_data=data, mode="unmask" ) return response @staticmethod def _unmask_pii_text(text: str, pii_tokens: Dict[str, str]) -> str: """ Replace PII tokens in *text* with 
their original values. Includes a fallback for tokens that were truncated by ``max_tokens``: if the *end* of ``text`` matches the *beginning* of a token and the overlap is long enough, the truncated suffix is replaced with the original value. The minimum overlap length is ``min(20, len(token) // 2)`` to reduce the risk of false positives when multiple tokens share a common prefix. """ for token, original_text in pii_tokens.items(): if token in text: text = text.replace(token, original_text) else: # FALLBACK: Handle truncated tokens (token cut off by max_tokens) # Only check at the very end of the text. min_overlap = min(20, len(token) // 2) for i in range(max(0, len(text) - len(token)), len(text)): sub = text[i:] if token.startswith(sub) and len(sub) >= min_overlap: text = text[:i] + original_text break return text @staticmethod def _is_anthropic_message_response(response: Any) -> bool: """Check if the response is an Anthropic native message dict.""" return ( isinstance(response, dict) and response.get("type") == "message" and isinstance(response.get("content"), list) ) async def _process_anthropic_response_for_pii( self, response: dict, request_data: dict, mode: Literal["mask", "unmask"], ) -> dict: """ Process an Anthropic native message dict for PII masking/unmasking. Handles content blocks with type == "text". 
"""
        # pii_tokens is written by anonymize_text during the pre-call hook.
        metadata = (request_data.get("metadata") or {}) if request_data else {}
        pii_tokens = metadata.get("pii_tokens", {})
        if not pii_tokens and mode == "unmask":
            verbose_proxy_logger.debug(
                "No pii_tokens in metadata for Anthropic response unmask"
            )

        presidio_config = self.get_presidio_settings_from_request_data(
            request_data or {}
        )

        content = response.get("content")
        if not isinstance(content, list):
            return response

        # Only "text" blocks are touched; other block types pass through as-is.
        for block in content:
            if not isinstance(block, dict) or block.get("type") != "text":
                continue
            text_value = block.get("text")
            if text_value is None:
                continue
            if mode == "unmask":
                block["text"] = self._unmask_pii_text(text_value, pii_tokens)
            elif mode == "mask":
                block["text"] = await self.check_pii(
                    text=text_value,
                    output_parse_pii=False,
                    presidio_config=presidio_config,
                    request_data=request_data,
                )

        return response

    async def _process_response_for_pii(
        self,
        response: ModelResponse,
        request_data: dict,
        mode: Literal["mask", "unmask"],
    ) -> ModelResponse:
        """
        Helper to process every choice of a ModelResponse for PII, in place.

        Iterates over the choices (no recursion) and, for each message,
        handles: string content, the "text" field of dict items in list
        content, tool call arguments, and legacy function_call arguments.
        "unmask" restores tokens from ``request_data["metadata"]["pii_tokens"]``;
        "mask" runs check_pii on each piece (sequentially). Returns ``response``.
        """
        metadata = (request_data.get("metadata") or {}) if request_data else {}
        pii_tokens = metadata.get("pii_tokens", {})
        if not pii_tokens and mode == "unmask":
            verbose_proxy_logger.debug(
                "No pii_tokens found in request_data['metadata'] — nothing to unmask"
            )

        presidio_config = self.get_presidio_settings_from_request_data(
            request_data or {}
        )

        for choice in response.choices:
            message = getattr(choice, "message", None)
            if message is None:
                continue

            # 1. Process content
            content = getattr(message, "content", None)
            if isinstance(content, str):
                if mode == "unmask":
                    message.content = self._unmask_pii_text(content, pii_tokens)
                elif mode == "mask":
                    message.content = await self.check_pii(
                        text=content,
                        output_parse_pii=False,
                        presidio_config=presidio_config,
                        request_data=request_data,
                    )
            elif isinstance(content, list):
                for item in content:
                    if not isinstance(item, dict):
                        continue
                    text_value = item.get("text")
                    if text_value is None:
                        continue
                    if mode == "unmask":
                        item["text"] = self._unmask_pii_text(text_value, pii_tokens)
                    elif mode == "mask":
                        item["text"] = await self.check_pii(
                            text=text_value,
                            output_parse_pii=False,
                            presidio_config=presidio_config,
                            request_data=request_data,
                        )

            # 2. Process tool calls (arguments are processed as raw strings)
            tool_calls = getattr(message, "tool_calls", None)
            if tool_calls:
                for tool_call in tool_calls:
                    function = getattr(tool_call, "function", None)
                    if function and hasattr(function, "arguments"):
                        args = function.arguments
                        if isinstance(args, str):
                            if mode == "unmask":
                                function.arguments = self._unmask_pii_text(
                                    args, pii_tokens
                                )
                            elif mode == "mask":
                                function.arguments = await self.check_pii(
                                    text=args,
                                    output_parse_pii=False,
                                    presidio_config=presidio_config,
                                    request_data=request_data,
                                )

            # 3. Process legacy function calls
            function_call = getattr(message, "function_call", None)
            if function_call and hasattr(function_call, "arguments"):
                args = function_call.arguments
                if isinstance(args, str):
                    if mode == "unmask":
                        function_call.arguments = self._unmask_pii_text(
                            args, pii_tokens
                        )
                    elif mode == "mask":
                        function_call.arguments = await self.check_pii(
                            text=args,
                            output_parse_pii=False,
                            presidio_config=presidio_config,
                            request_data=request_data,
                        )
        return response

    async def _mask_output_response(
        self,
        response: Union[ModelResponse, EmbeddingResponse, ImageResponse],
        request_data: dict,
    ):
        """
        Apply Presidio masking on model responses (non-streaming).
""" if not isinstance(response, ModelResponse): return response # skip streaming here; handled in async_post_call_streaming_iterator_hook if isinstance(response, ModelResponseStream): return response await self._process_response_for_pii( response=response, request_data=request_data, mode="mask", ) return response async def _stream_apply_output_masking( self, response: Any, request_data: dict, ) -> AsyncGenerator[Union[ModelResponseStream, bytes], None]: """Apply Presidio masking to streaming output (apply_to_output=True path).""" from litellm.llms.base_llm.base_model_iterator import ( convert_model_response_to_streaming, ) from litellm.main import stream_chunk_builder from litellm.types.utils import ModelResponse all_chunks: List[ModelResponseStream] = [] try: async for chunk in response: if isinstance(chunk, ModelResponseStream): all_chunks.append(chunk) elif isinstance(chunk, bytes): yield chunk # type: ignore[misc] continue if not all_chunks: verbose_proxy_logger.warning( "Presidio apply_to_output: streaming response contained only " "bytes chunks (Anthropic native SSE). Output PII masking was " "skipped for this response." 
) return assembled_model_response = stream_chunk_builder( chunks=all_chunks, messages=request_data.get("messages") ) if not isinstance(assembled_model_response, ModelResponse): for chunk in all_chunks: yield chunk return await self._process_response_for_pii( response=assembled_model_response, request_data=request_data, mode="mask", ) mock_response_stream = convert_model_response_to_streaming( assembled_model_response ) yield mock_response_stream except Exception as e: verbose_proxy_logger.error(f"Error masking streaming PII output: {str(e)}") for chunk in all_chunks: yield chunk async def _stream_pii_unmasking( self, response: Any, request_data: dict, ) -> AsyncGenerator[Union[ModelResponseStream, bytes], None]: """Apply PII unmasking to streaming output (output_parse_pii=True path).""" from litellm.llms.base_llm.base_model_iterator import ( convert_model_response_to_streaming, ) from litellm.main import stream_chunk_builder from litellm.types.utils import ModelResponse remaining_chunks: List[ModelResponseStream] = [] try: async for chunk in response: if isinstance(chunk, ModelResponseStream): remaining_chunks.append(chunk) elif isinstance(chunk, bytes): yield chunk # type: ignore[misc] continue if not remaining_chunks: return assembled_model_response = stream_chunk_builder( chunks=remaining_chunks, messages=request_data.get("messages") ) if not isinstance(assembled_model_response, ModelResponse): for chunk in remaining_chunks: yield chunk return self._preserve_usage_from_last_chunk( assembled_model_response, remaining_chunks ) await self._process_response_for_pii( response=assembled_model_response, request_data=request_data, mode="unmask", ) mock_response_stream = convert_model_response_to_streaming( assembled_model_response ) yield mock_response_stream except Exception as e: verbose_proxy_logger.error(f"Error in PII streaming processing: {str(e)}") for chunk in remaining_chunks: yield chunk async def async_post_call_streaming_iterator_hook( # type: 
ignore[override] self, user_api_key_dict: UserAPIKeyAuth, response: Any, request_data: dict, ) -> AsyncGenerator[Union[ModelResponseStream, bytes], None]: """ Process streaming response chunks to unmask PII tokens when needed. Note: the return type includes `bytes` because Anthropic native SSE streaming sends raw bytes chunks that pass through untransformed. The base class declares ModelResponseStream only. """ if self.apply_to_output: async for chunk in self._stream_apply_output_masking( response, request_data ): yield chunk return metadata = (request_data.get("metadata") or {}) if request_data else {} pii_tokens = metadata.get("pii_tokens", {}) if not pii_tokens and request_data: verbose_proxy_logger.debug( "No pii_tokens in request_data['metadata'] for streaming unmask path" ) if not (self.output_parse_pii and pii_tokens): async for chunk in response: yield chunk return async for chunk in self._stream_pii_unmasking(response, request_data): yield chunk @staticmethod def _preserve_usage_from_last_chunk( assembled_model_response: Any, chunks: List[Any], ) -> None: """Copy usage metadata from the last chunk when stream_chunk_builder misses it.""" if not getattr(assembled_model_response, "usage", None) and chunks: last_chunk_usage = getattr(chunks[-1], "usage", None) if last_chunk_usage: setattr(assembled_model_response, "usage", last_chunk_usage) def get_presidio_settings_from_request_data( self, data: dict ) -> Optional[PresidioPerRequestConfig]: if "metadata" in data: _metadata = data.get("metadata", None) if _metadata is None: return None _guardrail_config = _metadata.get("guardrail_config") if _guardrail_config: _presidio_config = PresidioPerRequestConfig(**_guardrail_config) return _presidio_config return None def print_verbose(self, print_statement): try: verbose_proxy_logger.debug(print_statement) if litellm.set_verbose: print(print_statement) # noqa except Exception: pass @log_guardrail_information async def apply_guardrail( self, inputs: 
"GenericGuardrailAPIInputs", request_data: dict, input_type: Literal["request", "response"], logging_obj: Optional["LiteLLMLoggingObj"] = None, ) -> "GenericGuardrailAPIInputs": """ UI will call this function to check: 1. If the connection to the guardrail is working 2. When Testing the guardrail with some text, this function will be called with the input text and returns a text after applying the guardrail """ texts = inputs.get("texts", []) # When input_type is "response" and pii_tokens are available, # unmask the text instead of masking it. metadata = (request_data.get("metadata") or {}) if request_data else {} pii_tokens = metadata.get("pii_tokens", {}) new_texts = [] if input_type == "response" and pii_tokens: for text in texts: new_texts.append(self._unmask_pii_text(text, pii_tokens)) else: for text in texts: modified_text = await self.check_pii( text=text, output_parse_pii=self.output_parse_pii, presidio_config=None, request_data=request_data or {}, ) new_texts.append(modified_text) inputs["texts"] = new_texts return inputs def update_in_memory_litellm_params(self, litellm_params: LitellmParams) -> None: """ Update the guardrails litellm params in memory """ super().update_in_memory_litellm_params(litellm_params) if litellm_params.pii_entities_config: self.pii_entities_config = litellm_params.pii_entities_config if litellm_params.presidio_score_thresholds: self.presidio_score_thresholds = litellm_params.presidio_score_thresholds if litellm_params.presidio_entities_deny_list: self.presidio_entities_deny_list = ( litellm_params.presidio_entities_deny_list )