chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,71 @@
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
from litellm.types.guardrails import SupportedGuardrailIntegrations
|
||||
|
||||
from .prompt_shield import AzureContentSafetyPromptShieldGuardrail
|
||||
from .text_moderation import AzureContentSafetyTextModerationGuardrail
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.types.guardrails import Guardrail, LitellmParams
|
||||
|
||||
|
||||
def initialize_guardrail(litellm_params: "LitellmParams", guardrail: "Guardrail"):
    """Initialize an Azure Content Safety guardrail instance.

    Chooses between the Prompt Shield and Text Moderation implementations
    based on the suffix of ``litellm_params.guardrail`` (expected form
    ``"<provider>/<type>"``, e.g. ``"azure/prompt_shield"``), registers the
    instance as a litellm callback, and returns it.

    Args:
        litellm_params: Parsed guardrail parameters; ``api_key`` and
            ``api_base`` are required.
        guardrail: Raw guardrail config; must contain ``"guardrail_name"``.

    Returns:
        The initialized guardrail instance.

    Raises:
        ValueError: If a required parameter is missing or the guardrail
            type is not recognized.
    """
    import litellm

    if not litellm_params.api_key:
        raise ValueError("Azure Content Safety: api_key is required")
    if not litellm_params.api_base:
        raise ValueError("Azure Content Safety: api_base is required")

    # Validate the "<provider>/<type>" shape before splitting; the original
    # unchecked split("/")[1] raised a bare IndexError on malformed input.
    guardrail_str = litellm_params.guardrail or ""
    if "/" not in guardrail_str:
        raise ValueError(
            f"Azure Content Safety: {guardrail_str} is not a valid guardrail"
        )
    azure_guardrail = guardrail_str.split("/")[1]

    guardrail_name = guardrail.get("guardrail_name")
    if not guardrail_name:
        raise ValueError("Azure Content Safety: guardrail_name is required")

    # Both implementations take identical constructor kwargs; the explicit
    # keys below intentionally override any model_dump() values.
    init_kwargs = {
        **litellm_params.model_dump(exclude_none=True),
        "api_key": litellm_params.api_key,
        "api_base": litellm_params.api_base,
        "default_on": litellm_params.default_on,
        "event_hook": litellm_params.mode,
    }

    azure_content_safety_guardrail: Union[
        AzureContentSafetyPromptShieldGuardrail,
        AzureContentSafetyTextModerationGuardrail,
    ]
    if azure_guardrail == "prompt_shield":
        azure_content_safety_guardrail = AzureContentSafetyPromptShieldGuardrail(
            guardrail_name=guardrail_name,
            **init_kwargs,
        )
    elif azure_guardrail == "text_moderations":
        azure_content_safety_guardrail = AzureContentSafetyTextModerationGuardrail(
            guardrail_name=guardrail_name,
            **init_kwargs,
        )
    else:
        raise ValueError(
            f"Azure Content Safety: {azure_guardrail} is not a valid guardrail"
        )

    litellm.logging_callback_manager.add_litellm_callback(
        azure_content_safety_guardrail
    )
    return azure_content_safety_guardrail
|
||||
|
||||
|
||||
# Maps each supported Azure integration value to its initializer. Both
# entries share ``initialize_guardrail``, which dispatches on the
# "<provider>/<type>" suffix of the configured guardrail string.
guardrail_initializer_registry = {
    SupportedGuardrailIntegrations.AZURE_PROMPT_SHIELD.value: initialize_guardrail,
    SupportedGuardrailIntegrations.AZURE_TEXT_MODERATIONS.value: initialize_guardrail,
}


# Maps each supported Azure integration value to the guardrail class
# implementing it.
guardrail_class_registry = {
    SupportedGuardrailIntegrations.AZURE_PROMPT_SHIELD.value: AzureContentSafetyPromptShieldGuardrail,
    SupportedGuardrailIntegrations.AZURE_TEXT_MODERATIONS.value: AzureContentSafetyTextModerationGuardrail,
}
|
||||
@@ -0,0 +1,165 @@
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
|
||||
# Azure Content Safety APIs have a 10,000 character limit per request.
|
||||
AZURE_CONTENT_SAFETY_MAX_TEXT_LENGTH = 10000
|
||||
|
||||
|
||||
class AzureGuardrailBase:
    """
    Base class for Azure guardrails.

    Holds the shared state (API credentials, API version, async HTTP
    client) and the shared helpers (authenticated POST, word-boundary
    text splitting, user-prompt extraction) used by every Azure Content
    Safety guardrail implementation.
    """

    def __init__(
        self,
        api_key: str,
        api_base: str,
        **kwargs: Any,
    ):
        # Hand any remaining kwargs to the next class in the MRO
        # (typically CustomGuardrail).
        super().__init__(**kwargs)

        self.api_key = api_key
        self.api_base = api_base
        # Fall back to a fixed API version when none is configured.
        self.api_version: str = kwargs.get("api_version") or "2024-09-01"
        self.async_handler = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.GuardrailCallback
        )

    async def _post_to_content_safety(
        self, endpoint_path: str, request_body: Dict[str, Any]
    ) -> Dict[str, Any]:
        """POST to an Azure Content Safety endpoint with standard auth headers.

        Args:
            endpoint_path: The API action, e.g. ``"text:shieldPrompt"`` or
                ``"text:analyze"``.
            request_body: JSON-serialisable request payload.

        Returns:
            Parsed JSON response dict.
        """
        request_url = f"{self.api_base}/contentsafety/{endpoint_path}?api-version={self.api_version}"
        request_headers = {
            "Ocp-Apim-Subscription-Key": self.api_key,
            "Content-Type": "application/json",
        }

        verbose_proxy_logger.debug(
            "Azure Content Safety request [%s]: %s", endpoint_path, request_body
        )
        http_response = await self.async_handler.post(
            url=request_url,
            headers=request_headers,
            json=request_body,
        )
        parsed: Dict[str, Any] = http_response.json()
        verbose_proxy_logger.debug(
            "Azure Content Safety response [%s]: %s", endpoint_path, parsed
        )
        return parsed

    @staticmethod
    def split_text_by_words(text: str, max_length: int) -> List[str]:
        """
        Split text into chunks at word boundaries without breaking words.

        Always returns at least one chunk; text that already fits within
        max_length comes back as a single-element list, so callers can
        iterate uniformly without a length check.

        Args:
            text: The text to split
            max_length: Maximum character length of each chunk

        Returns:
            List of text chunks, each not exceeding max_length
        """
        if len(text) <= max_length:
            return [text]

        # Alternate runs of non-whitespace and whitespace, so original
        # newlines, tabs, and repeated spaces survive inside each chunk.
        pieces = re.findall(r"\S+|\s+", text)

        out: List[str] = []
        buf = ""
        for piece in pieces:
            # If this run still fits, just accumulate it.
            if len(buf) + len(piece) <= max_length:
                buf += piece
                continue

            # Limit reached: emit whatever has accumulated so far.
            if buf:
                out.append(buf)
                buf = ""

            # A single run longer than max_length must be hard-split.
            while len(piece) > max_length:
                out.append(piece[:max_length])
                piece = piece[max_length:]

            buf = piece

        if buf:
            out.append(buf)

        return out

    def get_user_prompt(self, messages: List["AllMessageValues"]) -> Optional[str]:
        """
        Get the last consecutive block of messages from the user.

        Example:
            messages = [
                {"role": "user", "content": "Hello, how are you?"},
                {"role": "assistant", "content": "I'm good, thank you!"},
                {"role": "user", "content": "What is the weather in Tokyo?"},
            ]
            get_user_prompt(messages) -> "What is the weather in Tokyo?"
        """
        from litellm.litellm_core_utils.prompt_templates.common_utils import (
            convert_content_list_to_str,
        )

        if not messages:
            return None

        # Walk backwards collecting the trailing run of user messages,
        # stopping at the first non-user message.
        trailing_user_messages: List["AllMessageValues"] = []
        for msg in reversed(messages):
            if msg.get("role") != "user":
                break
            trailing_user_messages.append(msg)

        if not trailing_user_messages:
            return None

        # Restore chronological order, then join the message texts.
        texts = [
            convert_content_list_to_str(msg)
            for msg in reversed(trailing_user_messages)
        ]
        combined = "\n".join(texts).strip()
        return combined or None
|
||||
@@ -0,0 +1,165 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Azure Prompt Shield Native Guardrail Integrationfor LiteLLM
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, cast
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.integrations.custom_guardrail import (
|
||||
CustomGuardrail,
|
||||
log_guardrail_information,
|
||||
)
|
||||
from litellm.types.utils import CallTypesLiteral
|
||||
|
||||
from .base import AzureGuardrailBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.proxy.guardrails.guardrail_hooks.azure.azure_prompt_shield import (
|
||||
AzurePromptShieldGuardrailResponse,
|
||||
)
|
||||
from litellm.types.proxy.guardrails.guardrail_hooks.base import GuardrailConfigModel
|
||||
|
||||
|
||||
class AzureContentSafetyPromptShieldGuardrail(AzureGuardrailBase, CustomGuardrail):
    """
    LiteLLM Built-in Guardrail for Azure Content Safety Guardrail (Prompt Shield).

    Scans user prompts with the Azure Prompt Shield API to detect
    malicious content, injection attempts, and policy violations.

    Configuration:
        guardrail_name: Name of the guardrail instance
        api_key: Azure Prompt Shield API key
        api_base: Azure Prompt Shield API endpoint
        default_on: Whether to enable by default
    """

    def __init__(
        self,
        guardrail_name: str,
        api_key: str,
        api_base: str,
        **kwargs,
    ):
        """Initialize Azure Prompt Shield guardrail handler."""
        from litellm.types.guardrails import GuardrailEventHooks

        # Prompt Shield inspects inputs only, hence pre/during-call hooks.
        supported_event_hooks = [
            GuardrailEventHooks.pre_call,
            GuardrailEventHooks.during_call,
        ]
        # AzureGuardrailBase.__init__ stores api_key, api_base, api_version,
        # async_handler and forwards the rest to CustomGuardrail.
        super().__init__(
            api_key=api_key,
            api_base=api_base,
            guardrail_name=guardrail_name,
            supported_event_hooks=supported_event_hooks,
            **kwargs,
        )

        verbose_proxy_logger.debug(
            f"Initialized Azure Prompt Shield Guardrail: {guardrail_name}"
        )

    async def async_make_request(
        self, user_prompt: str
    ) -> "AzurePromptShieldGuardrailResponse":
        """
        Make a request to the Azure Prompt Shield API.

        Long prompts are split at word boundaries into chunks within the
        Azure Content Safety 10,000-character limit. Every chunk is
        analysed independently; an attack detected in *any* chunk raises
        an HTTPException immediately. The final chunk's response is the
        return value.
        """
        from litellm.types.proxy.guardrails.guardrail_hooks.azure.azure_prompt_shield import (
            AzurePromptShieldGuardrailRequestBody,
            AzurePromptShieldGuardrailResponse,
        )

        from .base import AZURE_CONTENT_SAFETY_MAX_TEXT_LENGTH

        prompt_chunks = self.split_text_by_words(
            user_prompt, AZURE_CONTENT_SAFETY_MAX_TEXT_LENGTH
        )

        last_response: Optional[AzurePromptShieldGuardrailResponse] = None

        for prompt_chunk in prompt_chunks:
            payload = AzurePromptShieldGuardrailRequestBody(
                documents=[], userPrompt=prompt_chunk
            )
            raw_response = await self._post_to_content_safety(
                "text:shieldPrompt", cast(dict, payload)
            )
            last_response = cast(AzurePromptShieldGuardrailResponse, raw_response)

            if last_response["userPromptAnalysis"].get("attackDetected"):
                verbose_proxy_logger.warning(
                    "Azure Prompt Shield: Attack detected in chunk of length %d",
                    len(prompt_chunk),
                )
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": "Violated Azure Prompt Shield guardrail policy",
                        "detection_message": f"Attack detected: {last_response['userPromptAnalysis']}",
                    },
                )

        # split_text_by_words always yields at least one chunk.
        assert last_response is not None
        return last_response

    @log_guardrail_information
    async def async_pre_call_hook(
        self,
        user_api_key_dict: "UserAPIKeyAuth",
        cache: Any,
        data: Dict[str, Any],
        call_type: CallTypesLiteral,
    ) -> Optional[Dict[str, Any]]:
        """
        Pre-call hook to scan user prompts before sending to LLM.

        Raises HTTPException if content should be blocked.
        """
        verbose_proxy_logger.debug(
            "Azure Prompt Shield: Running pre-call prompt scan, on call_type: %s",
            call_type,
        )
        request_messages: Optional[List[AllMessageValues]] = data.get("messages")
        if request_messages is None:
            verbose_proxy_logger.warning(
                "Azure Prompt Shield: not running guardrail. No messages in data"
            )
            return data

        user_prompt = self.get_user_prompt(request_messages)
        if not user_prompt:
            verbose_proxy_logger.warning("Azure Prompt Shield: No user prompt found")
            return None

        verbose_proxy_logger.debug(
            f"Azure Prompt Shield: User prompt: {user_prompt}"
        )
        # Raises HTTPException on any detected attack.
        await self.async_make_request(
            user_prompt=user_prompt,
        )
        return None

    @staticmethod
    def get_config_model() -> Optional[Type["GuardrailConfigModel"]]:
        """
        Get the config model for the Azure Prompt Shield guardrail.
        """
        from litellm.types.proxy.guardrails.guardrail_hooks.azure.azure_prompt_shield import (
            AzurePromptShieldGuardrailConfigModel,
        )

        return AzurePromptShieldGuardrailConfigModel
|
||||
@@ -0,0 +1,281 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Azure Text Moderation Native Guardrail Integrationfor LiteLLM
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Type, Union, cast
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.integrations.custom_guardrail import (
|
||||
CustomGuardrail,
|
||||
log_guardrail_information,
|
||||
)
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.types.utils import CallTypesLiteral
|
||||
|
||||
from .base import AzureGuardrailBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.proxy.guardrails.guardrail_hooks.azure.azure_text_moderation import (
|
||||
AzureTextModerationGuardrailResponse,
|
||||
)
|
||||
from litellm.types.proxy.guardrails.guardrail_hooks.base import GuardrailConfigModel
|
||||
from litellm.types.utils import EmbeddingResponse, ImageResponse, ModelResponse
|
||||
|
||||
|
||||
class AzureContentSafetyTextModerationGuardrail(AzureGuardrailBase, CustomGuardrail):
    """
    LiteLLM Built-in Guardrail for Azure Content Safety (Text Moderation).

    This guardrail scans prompts and responses using the Azure Text Moderation API to detect
    malicious content and policy violations based on severity thresholds.

    Configuration:
        guardrail_name: Name of the guardrail instance
        api_key: Azure Text Moderation API key
        api_base: Azure Text Moderation API endpoint
        default_on: Whether to enable by default
    """

    # Used only when neither severity_threshold nor
    # severity_threshold_by_category is configured.
    default_severity_threshold: int = 2

    def __init__(
        self,
        guardrail_name: str,
        api_key: str,
        api_base: str,
        severity_threshold: Optional[int] = None,
        severity_threshold_by_category: Optional[Dict[str, int]] = None,
        **kwargs,
    ):
        """Initialize Azure Text Moderation guardrail handler.

        Args:
            guardrail_name: Name of this guardrail instance.
            api_key: Azure Content Safety API key.
            api_base: Azure Content Safety endpoint.
            severity_threshold: Global severity threshold; 0 is a valid
                (strictest) value.
            severity_threshold_by_category: Optional per-category overrides.
            **kwargs: Optional request params (``categories``,
                ``blocklistNames``, ``haltOnBlocklistHit``, ``outputType``)
                plus anything CustomGuardrail accepts.
        """
        from litellm.types.proxy.guardrails.guardrail_hooks.azure.azure_text_moderation import (
            AzureTextModerationRequestBodyOptionalParams,
        )

        # AzureGuardrailBase.__init__ stores api_key, api_base, api_version,
        # async_handler and forwards the rest to CustomGuardrail.
        super().__init__(
            api_key=api_key,
            api_base=api_base,
            guardrail_name=guardrail_name,
            **kwargs,
        )

        # Optional parameters sent with every text:analyze request.
        self.optional_params_request_body: (
            AzureTextModerationRequestBodyOptionalParams
        ) = {
            "categories": kwargs.get("categories")
            or [
                "Hate",
                "Sexual",
                "SelfHarm",
                "Violence",
            ],
            "blocklistNames": cast(
                Optional[List[str]], kwargs.get("blocklistNames") or None
            ),
            "haltOnBlocklistHit": kwargs.get("haltOnBlocklistHit") or False,
            "outputType": kwargs.get("outputType") or "FourSeverityLevels",
        }

        # BUG FIX: previously `if severity_threshold` — a configured
        # threshold of 0 ("block everything") was falsy and silently
        # treated as unset, falling back to default_severity_threshold.
        self.severity_threshold = (
            int(severity_threshold) if severity_threshold is not None else None
        )
        self.severity_threshold_by_category = severity_threshold_by_category

        verbose_proxy_logger.info(
            f"Initialized Azure Text Moderation Guardrail: {guardrail_name}"
        )

    @staticmethod
    def get_config_model() -> Optional[Type["GuardrailConfigModel"]]:
        """Get the config model for the Azure Text Moderation guardrail."""
        from litellm.types.proxy.guardrails.guardrail_hooks.azure.azure_text_moderation import (
            AzureContentSafetyTextModerationConfigModel,
        )

        return AzureContentSafetyTextModerationConfigModel

    async def async_make_request(
        self, text: str
    ) -> "AzureTextModerationGuardrailResponse":
        """
        Make a request to the Azure Text Moderation API.

        Long texts are automatically split at word boundaries into chunks
        that respect the Azure Content Safety 10,000-character limit. Each
        chunk is analysed independently; a severity-threshold violation in
        *any* chunk raises an HTTPException immediately.
        """
        from litellm.types.proxy.guardrails.guardrail_hooks.azure.azure_text_moderation import (
            AzureTextModerationGuardrailRequestBody,
            AzureTextModerationGuardrailResponse,
        )

        from .base import AZURE_CONTENT_SAFETY_MAX_TEXT_LENGTH

        chunks = self.split_text_by_words(text, AZURE_CONTENT_SAFETY_MAX_TEXT_LENGTH)

        last_response: Optional[AzureTextModerationGuardrailResponse] = None

        for chunk in chunks:
            request_body = AzureTextModerationGuardrailRequestBody(
                text=chunk,
                **self.optional_params_request_body,  # type: ignore[misc]
            )
            response_json = await self._post_to_content_safety(
                "text:analyze", cast(dict, request_body)
            )

            chunk_response = cast(AzureTextModerationGuardrailResponse, response_json)

            # Callers only see the final chunk's response, so every
            # intermediate chunk must be checked here; otherwise a
            # violation in an earlier chunk would be silently swallowed.
            try:
                self.check_severity_threshold(response=chunk_response)
            except HTTPException:
                verbose_proxy_logger.warning(
                    "Azure Text Moderation: Violation detected in chunk of length %d",
                    len(chunk),
                )
                raise

            last_response = chunk_response

        # split_text_by_words guarantees at least one chunk.
        assert last_response is not None
        return last_response

    @staticmethod
    def _raise_severity_violation(category: Dict[str, Any], threshold: Any) -> None:
        """Raise the standard HTTP 400 for a category at/over its threshold."""
        raise HTTPException(
            status_code=400,
            detail={
                "error": "Azure Content Safety Guardrail: {} crossed severity {}, Got severity: {}".format(
                    category["category"],
                    threshold,
                    category["severity"],
                )
            },
        )

    def check_severity_threshold(
        self, response: "AzureTextModerationGuardrailResponse"
    ) -> Literal[True]:
        """
        Validate every analysed category against the configured thresholds.

        Precedence:
        - per-category thresholds (``severity_threshold_by_category``), then
        - the global ``severity_threshold``, then
        - ``default_severity_threshold`` when neither is configured.

        Raises:
            HTTPException: 400 when any category meets or exceeds its
                applicable threshold.

        Returns:
            True when no violation is found.
        """
        categories = response["categoriesAnalysis"]

        if self.severity_threshold_by_category:
            for category in categories:
                per_category_limit = self.severity_threshold_by_category.get(
                    category["category"]
                )
                if (
                    per_category_limit is not None
                    and category["severity"] >= per_category_limit
                ):
                    self._raise_severity_violation(category, per_category_limit)

        # BUG FIX: previously `if self.severity_threshold:` — a global
        # threshold of 0 was falsy and never enforced.
        if self.severity_threshold is not None:
            for category in categories:
                if category["severity"] >= self.severity_threshold:
                    self._raise_severity_violation(category, self.severity_threshold)

        if (
            self.severity_threshold is None
            and self.severity_threshold_by_category is None
        ):
            for category in categories:
                if category["severity"] >= self.default_severity_threshold:
                    self._raise_severity_violation(
                        category, self.default_severity_threshold
                    )
        return True

    @log_guardrail_information
    async def async_pre_call_hook(
        self,
        user_api_key_dict: "UserAPIKeyAuth",
        cache: Any,
        data: Dict[str, Any],
        call_type: CallTypesLiteral,
    ) -> Optional[Dict[str, Any]]:
        """
        Pre-call hook to scan user prompts before sending to LLM.

        Raises HTTPException if content should be blocked.
        """
        verbose_proxy_logger.info(
            "Azure Text Moderation: Running pre-call prompt scan, on call_type: %s",
            call_type,
        )
        new_messages: Optional[List[AllMessageValues]] = data.get("messages")
        if new_messages is None:
            verbose_proxy_logger.warning(
                "Azure Text Moderation: not running guardrail. No messages in data"
            )
            return data
        user_prompt = self.get_user_prompt(new_messages)

        if user_prompt:
            verbose_proxy_logger.info(
                f"Azure Text Moderation: User prompt: {user_prompt}"
            )
            # Raises HTTPException on a severity violation.
            await self.async_make_request(
                text=user_prompt,
            )
        else:
            verbose_proxy_logger.warning("Azure Text Moderation: No text found")
        return None

    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: "UserAPIKeyAuth",
        response: Union[Any, "ModelResponse", "EmbeddingResponse", "ImageResponse"],
    ) -> Any:
        """Scan the first choice of a completed (non-streaming) LLM response.

        Non-chat responses (embeddings, images) pass through unmodified.
        Raises HTTPException on a severity violation.
        """
        from litellm.types.utils import Choices, ModelResponse

        if (
            isinstance(response, ModelResponse)
            and response.choices
            and isinstance(response.choices[0], Choices)
        ):
            content = response.choices[0].message.content or ""
            await self.async_make_request(
                text=content,
            )
        return response

    async def async_post_call_streaming_hook(
        self, user_api_key_dict: UserAPIKeyAuth, response: str
    ) -> Any:
        """Scan an assembled streaming response.

        On a violation, returns an SSE-formatted error payload instead of
        the response text (streaming cannot surface an HTTP error after
        headers are sent).
        """
        try:
            if response is not None and len(response) > 0:
                await self.async_make_request(
                    text=response,
                )
            return response
        except HTTPException as e:
            import json

            error_returned = json.dumps({"error": e.detail})
            return f"data: {error_returned}\n\n"
|
||||
Reference in New Issue
Block a user