chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,35 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from litellm.types.guardrails import SupportedGuardrailIntegrations
|
||||
|
||||
from .lasso import LassoGuardrail
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.types.guardrails import Guardrail, LitellmParams
|
||||
|
||||
|
||||
def initialize_guardrail(litellm_params: "LitellmParams", guardrail: "Guardrail"):
|
||||
import litellm
|
||||
|
||||
_lasso_callback = LassoGuardrail(
|
||||
guardrail_name=guardrail.get("guardrail_name", ""),
|
||||
api_key=litellm_params.api_key,
|
||||
api_base=litellm_params.api_base,
|
||||
user_id=litellm_params.lasso_user_id,
|
||||
conversation_id=litellm_params.lasso_conversation_id,
|
||||
event_hook=litellm_params.mode,
|
||||
default_on=litellm_params.default_on,
|
||||
)
|
||||
litellm.logging_callback_manager.add_litellm_callback(_lasso_callback)
|
||||
|
||||
return _lasso_callback
|
||||
|
||||
|
||||
guardrail_initializer_registry = {
|
||||
SupportedGuardrailIntegrations.LASSO.value: initialize_guardrail,
|
||||
}
|
||||
|
||||
|
||||
guardrail_class_registry = {
|
||||
SupportedGuardrailIntegrations.LASSO.value: LassoGuardrail,
|
||||
}
|
||||
@@ -0,0 +1,678 @@
|
||||
# +-------------------------------------------------------------+
|
||||
#
|
||||
# Use Lasso Security Guardrails for your LLM calls
|
||||
# https://www.lasso.security/
|
||||
#
|
||||
# +-------------------------------------------------------------+
|
||||
|
||||
import os
|
||||
import uuid
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Type,
|
||||
Union,
|
||||
TypedDict,
|
||||
)
|
||||
|
||||
try:
|
||||
import ulid
|
||||
|
||||
ULID_AVAILABLE = True
|
||||
except ImportError:
|
||||
ulid = None # type: ignore
|
||||
ULID_AVAILABLE = False
|
||||
|
||||
try:
|
||||
import httpx
|
||||
|
||||
HTTPX_AVAILABLE = True
|
||||
except ImportError:
|
||||
httpx = None # type: ignore
|
||||
HTTPX_AVAILABLE = False
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from litellm import DualCache
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.integrations.custom_guardrail import (
|
||||
CustomGuardrail,
|
||||
log_guardrail_information,
|
||||
)
|
||||
from litellm.integrations.custom_guardrail import dc as global_cache
|
||||
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.types.guardrails import GuardrailEventHooks
|
||||
import litellm
|
||||
|
||||
|
||||
class LassoResponse(TypedDict):
|
||||
"""Type definition for Lasso API response."""
|
||||
|
||||
violations_detected: bool
|
||||
deputies: Dict[str, bool]
|
||||
findings: Dict[str, List[Dict[str, Any]]]
|
||||
messages: Optional[List[Dict[str, str]]]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.types.proxy.guardrails.guardrail_hooks.base import GuardrailConfigModel
|
||||
|
||||
|
||||
class LassoGuardrailMissingSecrets(Exception):
|
||||
"""Exception raised when Lasso API key is missing."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class LassoGuardrailAPIError(Exception):
|
||||
"""Exception raised when there's an error calling the Lasso API."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class LassoGuardrail(CustomGuardrail):
|
||||
"""
|
||||
Lasso Security Guardrail integration for LiteLLM.
|
||||
|
||||
Provides content moderation, PII detection, and policy enforcement
|
||||
through the Lasso Security API.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lasso_api_key: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
conversation_id: Optional[str] = None,
|
||||
mask: Optional[bool] = False,
|
||||
**kwargs,
|
||||
):
|
||||
self.async_handler = get_async_httpx_client(
|
||||
llm_provider=httpxSpecialProvider.GuardrailCallback
|
||||
)
|
||||
self.lasso_api_key = lasso_api_key or api_key or os.environ.get("LASSO_API_KEY")
|
||||
self.user_id = user_id or os.environ.get("LASSO_USER_ID")
|
||||
self.conversation_id = conversation_id or os.environ.get(
|
||||
"LASSO_CONVERSATION_ID"
|
||||
)
|
||||
self.mask = mask or False
|
||||
|
||||
if self.lasso_api_key is None:
|
||||
raise LassoGuardrailMissingSecrets(
|
||||
"Couldn't get Lasso api key, either set the `LASSO_API_KEY` in the environment or "
|
||||
"pass it as a parameter to the guardrail in the config file"
|
||||
)
|
||||
|
||||
self.api_base = (
|
||||
api_base
|
||||
or os.getenv("LASSO_API_BASE")
|
||||
or "https://server.lasso.security/gateway/v3"
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f"Lasso guardrail initialized: {kwargs.get('guardrail_name', 'unknown')}, "
|
||||
f"event_hook: {kwargs.get('event_hook', 'unknown')}, mask: {self.mask}"
|
||||
)
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _generate_ulid(self) -> str:
|
||||
"""
|
||||
Generate a ULID (Universally Unique Lexicographically Sortable Identifier).
|
||||
Falls back to UUID if ULID library is not available.
|
||||
"""
|
||||
if ULID_AVAILABLE and ulid is not None:
|
||||
return str(ulid.ULID()) # type: ignore
|
||||
else:
|
||||
verbose_proxy_logger.debug("ULID library not available, using UUID")
|
||||
return str(uuid.uuid4())
|
||||
|
||||
@log_guardrail_information
|
||||
async def async_pre_call_hook(
|
||||
self,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
cache: DualCache, # Deprecated, use global_cache instead (kept to align with CustomGuardrail interface)
|
||||
data: dict,
|
||||
call_type: Literal[
|
||||
"completion",
|
||||
"text_completion",
|
||||
"embeddings",
|
||||
"image_generation",
|
||||
"moderation",
|
||||
"audio_transcription",
|
||||
"pass_through_endpoint",
|
||||
"rerank",
|
||||
"mcp_call",
|
||||
"anthropic_messages",
|
||||
],
|
||||
) -> Union[Exception, str, dict, None]:
|
||||
"""
|
||||
Runs before the LLM API call to validate and potentially modify input.
|
||||
Uses 'PROMPT' messageType as this is input to the model.
|
||||
"""
|
||||
# Check if this guardrail should run for this request
|
||||
event_type: GuardrailEventHooks = GuardrailEventHooks.pre_call
|
||||
if self.should_run_guardrail(data=data, event_type=event_type) is not True:
|
||||
return data
|
||||
|
||||
# Get or generate conversation_id and store it in data for post-call consistency
|
||||
# The conversation_id is being stored in the cache so it can be used by the post_call hook
|
||||
self._get_or_generate_conversation_id(data, global_cache)
|
||||
|
||||
return await self._run_lasso_guardrail(
|
||||
data, global_cache, message_type="PROMPT"
|
||||
)
|
||||
|
||||
@log_guardrail_information
|
||||
async def async_moderation_hook(
|
||||
self,
|
||||
data: dict,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
call_type: Literal[
|
||||
"completion",
|
||||
"embeddings",
|
||||
"image_generation",
|
||||
"moderation",
|
||||
"audio_transcription",
|
||||
"responses",
|
||||
"mcp_call",
|
||||
"anthropic_messages",
|
||||
],
|
||||
cache: DualCache,
|
||||
):
|
||||
"""
|
||||
This is used for during_call moderation.
|
||||
Uses 'PROMPT' messageType as this runs concurrently with input processing.
|
||||
"""
|
||||
# Check if this guardrail should run for this request
|
||||
event_type: GuardrailEventHooks = GuardrailEventHooks.during_call
|
||||
if self.should_run_guardrail(data=data, event_type=event_type) is not True:
|
||||
return data
|
||||
|
||||
return await self._run_lasso_guardrail(data, cache, message_type="PROMPT")
|
||||
|
||||
@log_guardrail_information
|
||||
async def async_post_call_success_hook(
|
||||
self,
|
||||
data: dict,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
response,
|
||||
):
|
||||
"""
|
||||
Runs after the LLM API call to validate the response.
|
||||
Uses 'COMPLETION' messageType as this is output from the model.
|
||||
"""
|
||||
# Check if this guardrail should run for this request
|
||||
event_type: GuardrailEventHooks = GuardrailEventHooks.post_call
|
||||
if self.should_run_guardrail(data=data, event_type=event_type) is not True:
|
||||
return response
|
||||
|
||||
# Extract messages from the response for validation
|
||||
if isinstance(response, litellm.ModelResponse):
|
||||
response_messages = []
|
||||
for choice in response.choices:
|
||||
if hasattr(choice, "message") and choice.message.content:
|
||||
response_messages.append(
|
||||
{"role": "assistant", "content": choice.message.content}
|
||||
)
|
||||
|
||||
if response_messages:
|
||||
# Include litellm_call_id from original data for conversation_id consistency
|
||||
response_data = {
|
||||
"messages": response_messages,
|
||||
"litellm_call_id": data.get("litellm_call_id"),
|
||||
}
|
||||
|
||||
# Handle masking for post-call
|
||||
if self.mask:
|
||||
headers = self._prepare_headers(response_data, global_cache)
|
||||
payload = self._prepare_payload(
|
||||
response_messages, response_data, global_cache, "COMPLETION"
|
||||
)
|
||||
api_url = f"{self.api_base}/classifix"
|
||||
|
||||
try:
|
||||
lasso_response = await self._call_lasso_api(
|
||||
headers=headers, payload=payload, api_url=api_url
|
||||
)
|
||||
self._process_lasso_response(lasso_response)
|
||||
|
||||
# Apply masking to the actual response if masked content is available
|
||||
masked_messages = lasso_response.get("messages")
|
||||
if (
|
||||
lasso_response.get("violations_detected")
|
||||
and masked_messages
|
||||
):
|
||||
self._apply_masking_to_model_response(
|
||||
response, masked_messages
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
"Applied Lasso masking to model response"
|
||||
)
|
||||
except Exception as e:
|
||||
if isinstance(e, HTTPException):
|
||||
raise e
|
||||
verbose_proxy_logger.error(
|
||||
f"Error in post-call Lasso masking: {str(e)}"
|
||||
)
|
||||
raise LassoGuardrailAPIError(
|
||||
f"Failed to apply post-call masking: {str(e)}"
|
||||
)
|
||||
else:
|
||||
# Use the same data for conversation_id consistency (no cache access needed)
|
||||
await self._run_lasso_guardrail(
|
||||
response_data, cache=global_cache, message_type="COMPLETION"
|
||||
)
|
||||
verbose_proxy_logger.debug("Post-call Lasso validation completed")
|
||||
else:
|
||||
verbose_proxy_logger.warning("No response messages found to validate")
|
||||
else:
|
||||
verbose_proxy_logger.warning(
|
||||
f"Unexpected response type for post-call hook: {type(response)}"
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _get_or_generate_conversation_id(self, data: dict, cache: DualCache) -> str:
|
||||
"""
|
||||
Get or generate a conversation_id for this request.
|
||||
|
||||
This method ensures session consistency by using litellm_call_id as a cache key.
|
||||
The same conversation_id is used for both pre-call and post-call hooks within
|
||||
the same request, enabling proper conversation grouping in Lasso UI.
|
||||
|
||||
Example:
|
||||
>>> guardrail = LassoGuardrail(lasso_api_key="key")
|
||||
>>> data = {"litellm_call_id": "call_123"}
|
||||
>>> conversation_id = guardrail._get_or_generate_conversation_id(data, cache)
|
||||
>>> # Returns consistent ID for same litellm_call_id
|
||||
|
||||
Args:
|
||||
data: The request data containing litellm_call_id
|
||||
cache: The cache instance for storing conversation_id
|
||||
|
||||
Returns:
|
||||
str: The conversation_id to use for this request
|
||||
"""
|
||||
# Use global conversation_id if set
|
||||
if self.conversation_id:
|
||||
return self.conversation_id
|
||||
|
||||
# Get the litellm_call_id which is consistent across all hooks for this request
|
||||
litellm_call_id = data.get("litellm_call_id")
|
||||
|
||||
if not litellm_call_id:
|
||||
# Fallback to generating a new ULID if no litellm_call_id available
|
||||
return self._generate_ulid()
|
||||
|
||||
# Use litellm_call_id as cache key for conversation_id
|
||||
cache_key = f"lasso_conversation_id:{litellm_call_id}"
|
||||
|
||||
# Try to get existing conversation_id from cache
|
||||
try:
|
||||
cached_conversation_id = cache.get_cache(cache_key)
|
||||
if cached_conversation_id:
|
||||
return cached_conversation_id
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.warning(f"Cache retrieval failed: {e}")
|
||||
|
||||
# Generate new conversation_id and store in cache
|
||||
generated_id = self._generate_ulid()
|
||||
|
||||
try:
|
||||
cache.set_cache(cache_key, generated_id, ttl=3600) # Cache for 1 hour
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.warning(f"Cache storage failed: {e}")
|
||||
|
||||
return generated_id
|
||||
|
||||
async def _run_lasso_guardrail(
|
||||
self,
|
||||
data: dict,
|
||||
cache: DualCache,
|
||||
message_type: Literal["PROMPT", "COMPLETION"] = "PROMPT",
|
||||
):
|
||||
"""
|
||||
Run the Lasso guardrail with the specified message type.
|
||||
|
||||
This is the core method that handles both classification and masking workflows.
|
||||
It chooses the appropriate API endpoint based on the masking configuration
|
||||
and processes the response according to Lasso's action-based system.
|
||||
|
||||
Workflow:
|
||||
1. Validate messages are present
|
||||
2. Prepare headers and payload
|
||||
3. Choose API endpoint (classify vs classifix)
|
||||
4. Call Lasso API
|
||||
5. Process response and apply masking if needed
|
||||
6. Handle blocking vs non-blocking violations
|
||||
|
||||
Args:
|
||||
data: The request data containing messages
|
||||
cache: The cache instance for storing conversation_id (optional for post-call)
|
||||
message_type: Either "PROMPT" for input or "COMPLETION" for output
|
||||
|
||||
Raises:
|
||||
LassoGuardrailAPIError: If the Lasso API call fails
|
||||
HTTPException: If blocking violations are detected
|
||||
"""
|
||||
messages: List[Dict[str, str]] = data.get("messages", [])
|
||||
if not messages:
|
||||
return data
|
||||
|
||||
if self.mask:
|
||||
return await self._handle_masking(data, cache, message_type, messages)
|
||||
else:
|
||||
return await self._handle_classification(
|
||||
data, cache, message_type, messages
|
||||
)
|
||||
|
||||
async def _handle_classification(
|
||||
self,
|
||||
data: dict,
|
||||
cache: DualCache,
|
||||
message_type: Literal["PROMPT", "COMPLETION"],
|
||||
messages: List[Dict[str, str]],
|
||||
) -> dict:
|
||||
"""Handle classification without masking."""
|
||||
try:
|
||||
headers = self._prepare_headers(data, cache)
|
||||
payload = self._prepare_payload(messages, data, cache, message_type)
|
||||
response = await self._call_lasso_api(headers=headers, payload=payload)
|
||||
self._process_lasso_response(response)
|
||||
return data
|
||||
except Exception as e:
|
||||
await self._handle_api_error(e, message_type)
|
||||
return data # This line won't be reached due to exception, but satisfies type checker
|
||||
|
||||
async def _handle_masking(
|
||||
self,
|
||||
data: dict,
|
||||
cache: DualCache,
|
||||
message_type: Literal["PROMPT", "COMPLETION"],
|
||||
messages: List[Dict[str, str]],
|
||||
) -> dict:
|
||||
"""Handle masking with classifix endpoint."""
|
||||
try:
|
||||
headers = self._prepare_headers(data, cache)
|
||||
payload = self._prepare_payload(messages, data, cache, message_type)
|
||||
api_url = f"{self.api_base}/classifix"
|
||||
response = await self._call_lasso_api(
|
||||
headers=headers, payload=payload, api_url=api_url
|
||||
)
|
||||
self._process_lasso_response(response)
|
||||
|
||||
# Apply masking to messages if violations detected and masked messages are available
|
||||
if response.get("violations_detected") and response.get("messages"):
|
||||
data["messages"] = response["messages"]
|
||||
self._log_masking_applied(message_type, dict(response))
|
||||
|
||||
return data
|
||||
except Exception as e:
|
||||
await self._handle_api_error(e, message_type)
|
||||
return data # This line won't be reached due to exception, but satisfies type checker
|
||||
|
||||
async def _handle_api_error(
|
||||
self,
|
||||
error: Exception,
|
||||
message_type: Literal["PROMPT", "COMPLETION"],
|
||||
) -> None:
|
||||
"""Handle API errors with specific error types."""
|
||||
if isinstance(error, HTTPException):
|
||||
raise error
|
||||
|
||||
# Log error with context
|
||||
verbose_proxy_logger.error(
|
||||
f"Error calling Lasso API: {str(error)}",
|
||||
extra={
|
||||
"guardrail_name": getattr(self, "guardrail_name", "unknown"),
|
||||
"message_type": message_type,
|
||||
"error_type": type(error).__name__,
|
||||
},
|
||||
)
|
||||
|
||||
# Handle specific error types if httpx is available
|
||||
if HTTPX_AVAILABLE:
|
||||
if isinstance(error, httpx.TimeoutException):
|
||||
raise LassoGuardrailAPIError("Lasso API timeout")
|
||||
elif isinstance(error, httpx.HTTPStatusError):
|
||||
if error.response.status_code == 401:
|
||||
raise LassoGuardrailMissingSecrets("Invalid API key")
|
||||
elif error.response.status_code == 429:
|
||||
raise LassoGuardrailAPIError("Lasso API rate limit exceeded")
|
||||
else:
|
||||
raise LassoGuardrailAPIError(
|
||||
f"API error: {error.response.status_code}"
|
||||
)
|
||||
|
||||
# Generic error handling
|
||||
raise LassoGuardrailAPIError(
|
||||
f"Failed to verify request safety with Lasso API: {str(error)}"
|
||||
)
|
||||
|
||||
def _log_masking_applied(
|
||||
self,
|
||||
message_type: Literal["PROMPT", "COMPLETION"],
|
||||
response: Dict[str, Any],
|
||||
) -> None:
|
||||
"""Log masking application with structured context."""
|
||||
conversation_id = getattr(self, "conversation_id", "unknown")
|
||||
verbose_proxy_logger.debug(
|
||||
"Lasso masking applied",
|
||||
extra={
|
||||
"guardrail_name": getattr(self, "guardrail_name", "unknown"),
|
||||
"message_type": message_type,
|
||||
"violations_count": len(response.get("findings", {})),
|
||||
"masked_fields": len(response.get("messages", [])),
|
||||
"conversation_id": conversation_id,
|
||||
},
|
||||
)
|
||||
|
||||
def _prepare_headers(self, data: dict, cache: DualCache) -> Dict[str, str]:
|
||||
"""Prepare headers for the Lasso API request."""
|
||||
if not self.lasso_api_key:
|
||||
raise LassoGuardrailMissingSecrets(
|
||||
"Couldn't get Lasso api key, either set the `LASSO_API_KEY` in the environment or "
|
||||
"pass it as a parameter to the guardrail in the config file"
|
||||
)
|
||||
|
||||
headers: Dict[str, str] = {
|
||||
"lasso-api-key": self.lasso_api_key,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
# Add optional headers if provided
|
||||
if self.user_id:
|
||||
headers["lasso-user-id"] = self.user_id
|
||||
|
||||
# Always include conversation_id (generated or provided)
|
||||
conversation_id = self._get_or_generate_conversation_id(data, cache)
|
||||
|
||||
headers["lasso-conversation-id"] = conversation_id
|
||||
|
||||
return headers
|
||||
|
||||
def _prepare_payload(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
data: dict,
|
||||
cache: DualCache,
|
||||
message_type: Literal["PROMPT", "COMPLETION"] = "PROMPT",
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Prepare the payload for the Lasso API request.
|
||||
|
||||
Args:
|
||||
messages: List of message objects
|
||||
message_type: Type of message - "PROMPT" for input, "COMPLETION" for output
|
||||
data: Request data (used for conversation_id generation)
|
||||
cache: Cache instance for storing conversation_id (optional for post-call)
|
||||
"""
|
||||
payload: Dict[str, Any] = {"messages": messages, "messageType": message_type}
|
||||
|
||||
# Add optional parameters if available
|
||||
if self.user_id:
|
||||
payload["userId"] = self.user_id
|
||||
|
||||
# Always include sessionId (conversation_id - generated or provided)
|
||||
conversation_id = self._get_or_generate_conversation_id(data, cache)
|
||||
|
||||
payload["sessionId"] = conversation_id
|
||||
|
||||
return payload
|
||||
|
||||
async def _call_lasso_api(
|
||||
self,
|
||||
headers: Dict[str, str],
|
||||
payload: Dict[str, Any],
|
||||
api_url: Optional[str] = None,
|
||||
) -> LassoResponse:
|
||||
"""Call the Lasso API and return the response."""
|
||||
url = api_url or f"{self.api_base}/classify"
|
||||
verbose_proxy_logger.debug(
|
||||
f"Calling Lasso API with messageType: {payload.get('messageType')}"
|
||||
)
|
||||
response = await self.async_handler.post(
|
||||
url=url,
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=10.0,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def _process_lasso_response(self, response: LassoResponse) -> None:
|
||||
"""
|
||||
Process the Lasso API response and handle violations according to action types.
|
||||
|
||||
This method implements the action-based blocking logic:
|
||||
- BLOCK: Raises HTTPException to stop request/response
|
||||
- AUTO_MASKING: Logs warning and continues (masking applied elsewhere)
|
||||
- WARN: Logs warning and continues
|
||||
|
||||
Example Response:
|
||||
{
|
||||
"violations_detected": true,
|
||||
"findings": {
|
||||
"jailbreak": [{
|
||||
"action": "BLOCK",
|
||||
"severity": "HIGH"
|
||||
}]
|
||||
}
|
||||
}
|
||||
|
||||
Args:
|
||||
response: The response dictionary from Lasso API
|
||||
|
||||
Raises:
|
||||
HTTPException: If any finding has "action": "BLOCK"
|
||||
"""
|
||||
if response and response.get("violations_detected") is True:
|
||||
violated_deputies = self._parse_violated_deputies(response)
|
||||
verbose_proxy_logger.warning(
|
||||
f"Lasso guardrail detected violations: {violated_deputies}"
|
||||
)
|
||||
|
||||
# Check if any findings have "BLOCK" action
|
||||
blocking_violations = self._check_for_blocking_actions(response)
|
||||
|
||||
if blocking_violations:
|
||||
# Block the request/response for findings with "BLOCK" action
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "Violated Lasso guardrail policy",
|
||||
"detection_message": f"Blocking violations detected: {', '.join(blocking_violations)}",
|
||||
"lasso_response": response,
|
||||
},
|
||||
)
|
||||
else:
|
||||
# Continue with warning for non-blocking violations (e.g., AUTO_MASKING)
|
||||
verbose_proxy_logger.info(
|
||||
f"Non-blocking Lasso violations detected, continuing with warning: {violated_deputies}"
|
||||
)
|
||||
|
||||
def _check_for_blocking_actions(self, response: LassoResponse) -> List[str]:
|
||||
"""
|
||||
Check findings for actions that should block the request/response.
|
||||
|
||||
Examines the findings section of the Lasso response to identify which
|
||||
deputies have violations with "BLOCK" action. This enables granular
|
||||
control where some violations (like PII) can be masked while others
|
||||
(like jailbreaks) are blocked entirely.
|
||||
|
||||
Args:
|
||||
response: The response dictionary from Lasso API
|
||||
|
||||
Returns:
|
||||
List[str]: Names of deputies with blocking violations
|
||||
|
||||
Example:
|
||||
>>> response = {
|
||||
... "findings": {
|
||||
... "jailbreak": [{"action": "BLOCK"}],
|
||||
... "pattern-detection": [{"action": "AUTO_MASKING"}]
|
||||
... }
|
||||
... }
|
||||
>>> guardrail._check_for_blocking_actions(response)
|
||||
['jailbreak']
|
||||
"""
|
||||
blocking_violations = []
|
||||
findings = response.get("findings", {})
|
||||
|
||||
for deputy_name, deputy_findings in findings.items():
|
||||
if isinstance(deputy_findings, list):
|
||||
for finding in deputy_findings:
|
||||
if isinstance(finding, dict) and finding.get("action") == "BLOCK":
|
||||
if deputy_name not in blocking_violations:
|
||||
blocking_violations.append(deputy_name)
|
||||
break # No need to check other findings for this deputy
|
||||
|
||||
return blocking_violations
|
||||
|
||||
def _parse_violated_deputies(self, response: LassoResponse) -> List[str]:
|
||||
"""Parse the response to extract violated deputies."""
|
||||
violated_deputies = []
|
||||
if "deputies" in response:
|
||||
for deputy, is_violated in response["deputies"].items():
|
||||
if is_violated:
|
||||
violated_deputies.append(deputy)
|
||||
return violated_deputies
|
||||
|
||||
def _apply_masking_to_model_response(
|
||||
self,
|
||||
model_response: litellm.ModelResponse,
|
||||
masked_messages: List[Dict[str, str]],
|
||||
) -> None:
|
||||
"""Apply masking to the actual model response when mask=True and masked content is available."""
|
||||
masked_index = 0
|
||||
for choice in model_response.choices:
|
||||
if (
|
||||
hasattr(choice, "message")
|
||||
and choice.message.content
|
||||
and masked_index < len(masked_messages)
|
||||
):
|
||||
# Replace the content with the masked version from Lasso
|
||||
choice.message.content = masked_messages[masked_index]["content"]
|
||||
masked_index += 1
|
||||
verbose_proxy_logger.debug(
|
||||
f"Applied masked content to choice {masked_index}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_config_model() -> Optional[Type["GuardrailConfigModel"]]:
|
||||
from litellm.types.proxy.guardrails.guardrail_hooks.lasso import (
|
||||
LassoGuardrailConfigModel,
|
||||
)
|
||||
|
||||
return LassoGuardrailConfigModel
|
||||
Reference in New Issue
Block a user