chore: initial public snapshot for github upload

This commit is contained in:
Your Name
2026-03-26 20:06:14 +08:00
commit 0e5ecd930e
3497 changed files with 1586236 additions and 0 deletions

View File

@@ -0,0 +1,71 @@
from typing import TYPE_CHECKING, Union
from litellm.types.guardrails import SupportedGuardrailIntegrations
from .prompt_shield import AzureContentSafetyPromptShieldGuardrail
from .text_moderation import AzureContentSafetyTextModerationGuardrail
if TYPE_CHECKING:
from litellm.types.guardrails import Guardrail, LitellmParams
def initialize_guardrail(litellm_params: "LitellmParams", guardrail: "Guardrail"):
    """Construct and register an Azure Content Safety guardrail instance.

    Dispatches on the suffix of ``litellm_params.guardrail`` (expected form
    ``"azure/<guardrail_type>"``) to pick the concrete guardrail class,
    instantiates it from the litellm params, registers it as a litellm
    callback, and returns it.

    Args:
        litellm_params: Guardrail connection/config params; must carry
            ``api_key``, ``api_base``, and a ``guardrail`` string.
        guardrail: Guardrail definition dict; must carry ``guardrail_name``.

    Raises:
        ValueError: If api_key, api_base, or guardrail_name is missing, or
            the guardrail suffix is not a recognized Azure guardrail type.
    """
    import litellm

    if not litellm_params.api_key:
        raise ValueError("Azure Content Safety: api_key is required")
    if not litellm_params.api_base:
        raise ValueError("Azure Content Safety: api_base is required")

    guardrail_name = guardrail.get("guardrail_name")
    if not guardrail_name:
        raise ValueError("Azure Content Safety: guardrail_name is required")

    # "azure/prompt_shield" -> "prompt_shield". Guard the split so a value
    # without "/" produces the descriptive ValueError below, not IndexError.
    guardrail_parts = litellm_params.guardrail.split("/")
    azure_guardrail = guardrail_parts[1] if len(guardrail_parts) > 1 else ""

    # Both guardrail types take identical constructor kwargs, so dispatch
    # via a map instead of duplicating the construction per branch.
    guardrail_classes = {
        "prompt_shield": AzureContentSafetyPromptShieldGuardrail,
        "text_moderations": AzureContentSafetyTextModerationGuardrail,
    }
    guardrail_class = guardrail_classes.get(azure_guardrail)
    if guardrail_class is None:
        raise ValueError(
            f"Azure Content Safety: {azure_guardrail} is not a valid guardrail"
        )

    azure_content_safety_guardrail: Union[
        AzureContentSafetyPromptShieldGuardrail,
        AzureContentSafetyTextModerationGuardrail,
    ] = guardrail_class(
        guardrail_name=guardrail_name,
        **{
            **litellm_params.model_dump(exclude_none=True),
            # Explicit overrides win over whatever model_dump produced.
            "api_key": litellm_params.api_key,
            "api_base": litellm_params.api_base,
            "default_on": litellm_params.default_on,
            "event_hook": litellm_params.mode,
        },
    )

    litellm.logging_callback_manager.add_litellm_callback(
        azure_content_safety_guardrail
    )
    return azure_content_safety_guardrail
# Maps guardrail integration name -> initializer callable used to build the
# guardrail from proxy config. Both Azure integrations share one initializer;
# it selects the concrete class from the "azure/<type>" guardrail string.
guardrail_initializer_registry = {
    SupportedGuardrailIntegrations.AZURE_PROMPT_SHIELD.value: initialize_guardrail,
    SupportedGuardrailIntegrations.AZURE_TEXT_MODERATIONS.value: initialize_guardrail,
}
# Maps guardrail integration name -> guardrail class, for callers that need
# the class itself rather than an initialized instance.
guardrail_class_registry = {
    SupportedGuardrailIntegrations.AZURE_PROMPT_SHIELD.value: AzureContentSafetyPromptShieldGuardrail,
    SupportedGuardrailIntegrations.AZURE_TEXT_MODERATIONS.value: AzureContentSafetyTextModerationGuardrail,
}

View File

@@ -0,0 +1,165 @@
import re
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from litellm._logging import verbose_proxy_logger
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
if TYPE_CHECKING:
from litellm.types.llms.openai import AllMessageValues
# Azure Content Safety APIs have a 10,000 character limit per request.
AZURE_CONTENT_SAFETY_MAX_TEXT_LENGTH = 10000
class AzureGuardrailBase:
    """
    Shared foundation for Azure Content Safety guardrails.

    Centralizes credential storage, the async HTTP client, authenticated
    POSTs to the Content Safety API, word-boundary text chunking, and
    extraction of the trailing user prompt from a message list.
    """

    def __init__(
        self,
        api_key: str,
        api_base: str,
        **kwargs: Any,
    ):
        # Hand remaining keyword arguments to the next class in the MRO
        # (normally CustomGuardrail).
        super().__init__(**kwargs)
        self.api_key = api_key
        self.api_base = api_base
        self.api_version: str = kwargs.get("api_version") or "2024-09-01"
        self.async_handler = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.GuardrailCallback
        )

    async def _post_to_content_safety(
        self, endpoint_path: str, request_body: Dict[str, Any]
    ) -> Dict[str, Any]:
        """POST to an Azure Content Safety endpoint with standard auth headers.

        Args:
            endpoint_path: The API action, e.g. ``"text:shieldPrompt"`` or
                ``"text:analyze"``.
            request_body: JSON-serialisable request payload.

        Returns:
            Parsed JSON response dict.
        """
        request_url = f"{self.api_base}/contentsafety/{endpoint_path}?api-version={self.api_version}"
        auth_headers = {
            "Ocp-Apim-Subscription-Key": self.api_key,
            "Content-Type": "application/json",
        }
        verbose_proxy_logger.debug(
            "Azure Content Safety request [%s]: %s", endpoint_path, request_body
        )
        raw_response = await self.async_handler.post(
            url=request_url,
            headers=auth_headers,
            json=request_body,
        )
        parsed: Dict[str, Any] = raw_response.json()
        verbose_proxy_logger.debug(
            "Azure Content Safety response [%s]: %s", endpoint_path, parsed
        )
        return parsed

    @staticmethod
    def split_text_by_words(text: str, max_length: int) -> List[str]:
        """
        Split text into chunks at word boundaries without breaking words.

        Always returns at least one chunk. Short text (≤ max_length) comes
        back as a single-element list so callers can run one uniform loop.

        Args:
            text: The text to split
            max_length: Maximum character length of each chunk

        Returns:
            List of text chunks, each not exceeding max_length
        """
        if len(text) <= max_length:
            return [text]

        pieces: List[str] = []
        buffer = ""
        # Alternating runs of non-whitespace and whitespace characters, so
        # original newlines, tabs, and repeated spaces survive inside chunks.
        for run in re.findall(r"\S+|\s+", text):
            if len(buffer) + len(run) <= max_length:
                # Still room in the current chunk.
                buffer += run
                continue
            # Emit whatever has accumulated before starting a fresh chunk.
            if buffer:
                pieces.append(buffer)
            # A single run longer than the limit gets hard-split.
            while len(run) > max_length:
                pieces.append(run[:max_length])
                run = run[max_length:]
            buffer = run
        if buffer:
            pieces.append(buffer)
        return pieces

    def get_user_prompt(self, messages: List["AllMessageValues"]) -> Optional[str]:
        """
        Get the last consecutive block of messages from the user.

        Example:
            messages = [
                {"role": "user", "content": "Hello, how are you?"},
                {"role": "assistant", "content": "I'm good, thank you!"},
                {"role": "user", "content": "What is the weather in Tokyo?"},
            ]
            get_user_prompt(messages) -> "What is the weather in Tokyo?"
        """
        from litellm.litellm_core_utils.prompt_templates.common_utils import (
            convert_content_list_to_str,
        )

        if not messages:
            return None

        # Walk backwards, collecting user messages until a non-user role.
        trailing_user_messages = []
        for msg in reversed(messages):
            if msg.get("role") != "user":
                break
            trailing_user_messages.append(msg)

        if not trailing_user_messages:
            return None

        # Restore chronological order, then join the text of each message.
        joined = "".join(
            convert_content_list_to_str(msg) + "\n"
            for msg in reversed(trailing_user_messages)
        )
        stripped = joined.strip()
        return stripped or None

View File

@@ -0,0 +1,165 @@
#!/usr/bin/env python3
"""
Azure Prompt Shield Native Guardrail Integration for LiteLLM
"""
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, cast
from fastapi import HTTPException
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_guardrail import (
CustomGuardrail,
log_guardrail_information,
)
from litellm.types.utils import CallTypesLiteral
from .base import AzureGuardrailBase
if TYPE_CHECKING:
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.llms.openai import AllMessageValues
from litellm.types.proxy.guardrails.guardrail_hooks.azure.azure_prompt_shield import (
AzurePromptShieldGuardrailResponse,
)
from litellm.types.proxy.guardrails.guardrail_hooks.base import GuardrailConfigModel
class AzureContentSafetyPromptShieldGuardrail(AzureGuardrailBase, CustomGuardrail):
    """
    LiteLLM Built-in Guardrail for Azure Content Safety Guardrail (Prompt Shield).

    Scans prompts with the Azure Prompt Shield API to detect malicious
    content, injection attempts, and policy violations.

    Configuration:
        guardrail_name: Name of the guardrail instance
        api_key: Azure Prompt Shield API key
        api_base: Azure Prompt Shield API endpoint
        default_on: Whether to enable by default
    """

    def __init__(
        self,
        guardrail_name: str,
        api_key: str,
        api_base: str,
        **kwargs,
    ):
        """Initialize Azure Prompt Shield guardrail handler."""
        from litellm.types.guardrails import GuardrailEventHooks

        # AzureGuardrailBase.__init__ keeps api_key/api_base/api_version and
        # the HTTP client; everything else flows through to CustomGuardrail.
        super().__init__(
            api_key=api_key,
            api_base=api_base,
            guardrail_name=guardrail_name,
            supported_event_hooks=[
                GuardrailEventHooks.pre_call,
                GuardrailEventHooks.during_call,
            ],
            **kwargs,
        )
        verbose_proxy_logger.debug(
            f"Initialized Azure Prompt Shield Guardrail: {guardrail_name}"
        )

    async def async_make_request(
        self, user_prompt: str
    ) -> "AzurePromptShieldGuardrailResponse":
        """
        Make a request to the Azure Prompt Shield API.

        Long prompts are split at word boundaries into chunks that respect
        the Azure Content Safety 10 000-character limit. Each chunk is
        analysed independently; an attack in *any* chunk raises an
        HTTPException immediately. Returns the final chunk's response.
        """
        from litellm.types.proxy.guardrails.guardrail_hooks.azure.azure_prompt_shield import (
            AzurePromptShieldGuardrailRequestBody,
            AzurePromptShieldGuardrailResponse,
        )

        from .base import AZURE_CONTENT_SAFETY_MAX_TEXT_LENGTH

        result: Optional[AzurePromptShieldGuardrailResponse] = None
        for segment in self.split_text_by_words(
            user_prompt, AZURE_CONTENT_SAFETY_MAX_TEXT_LENGTH
        ):
            payload = AzurePromptShieldGuardrailRequestBody(
                documents=[], userPrompt=segment
            )
            raw_response = await self._post_to_content_safety(
                "text:shieldPrompt", cast(dict, payload)
            )
            result = cast(AzurePromptShieldGuardrailResponse, raw_response)
            if result["userPromptAnalysis"].get("attackDetected"):
                verbose_proxy_logger.warning(
                    "Azure Prompt Shield: Attack detected in chunk of length %d",
                    len(segment),
                )
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": "Violated Azure Prompt Shield guardrail policy",
                        "detection_message": f"Attack detected: {result['userPromptAnalysis']}",
                    },
                )
        # split_text_by_words guarantees at least one chunk, so the loop ran.
        assert result is not None
        return result

    @log_guardrail_information
    async def async_pre_call_hook(
        self,
        user_api_key_dict: "UserAPIKeyAuth",
        cache: Any,
        data: Dict[str, Any],
        call_type: CallTypesLiteral,
    ) -> Optional[Dict[str, Any]]:
        """
        Pre-call hook to scan user prompts before sending to LLM.
        Raises HTTPException if content should be blocked.
        """
        verbose_proxy_logger.debug(
            "Azure Prompt Shield: Running pre-call prompt scan, on call_type: %s",
            call_type,
        )
        incoming_messages: Optional[List[AllMessageValues]] = data.get("messages")
        if incoming_messages is None:
            verbose_proxy_logger.warning(
                "Azure Prompt Shield: not running guardrail. No messages in data"
            )
            return data

        user_prompt = self.get_user_prompt(incoming_messages)
        if not user_prompt:
            verbose_proxy_logger.warning("Azure Prompt Shield: No user prompt found")
            return None

        verbose_proxy_logger.debug(
            f"Azure Prompt Shield: User prompt: {user_prompt}"
        )
        await self.async_make_request(user_prompt=user_prompt)
        return None

    @staticmethod
    def get_config_model() -> Optional[Type["GuardrailConfigModel"]]:
        """
        Get the config model for the Azure Prompt Shield guardrail.
        """
        from litellm.types.proxy.guardrails.guardrail_hooks.azure.azure_prompt_shield import (
            AzurePromptShieldGuardrailConfigModel,
        )

        return AzurePromptShieldGuardrailConfigModel

View File

@@ -0,0 +1,281 @@
#!/usr/bin/env python3
"""
Azure Text Moderation Native Guardrail Integration for LiteLLM
"""
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Type, Union, cast
from fastapi import HTTPException
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_guardrail import (
CustomGuardrail,
log_guardrail_information,
)
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.utils import CallTypesLiteral
from .base import AzureGuardrailBase
if TYPE_CHECKING:
from litellm.types.llms.openai import AllMessageValues
from litellm.types.proxy.guardrails.guardrail_hooks.azure.azure_text_moderation import (
AzureTextModerationGuardrailResponse,
)
from litellm.types.proxy.guardrails.guardrail_hooks.base import GuardrailConfigModel
from litellm.types.utils import EmbeddingResponse, ImageResponse, ModelResponse
class AzureContentSafetyTextModerationGuardrail(AzureGuardrailBase, CustomGuardrail):
    """
    LiteLLM Built-in Guardrail for Azure Content Safety (Text Moderation).

    This guardrail scans prompts and responses using the Azure Text Moderation API to detect
    malicious content and policy violations based on severity thresholds.

    Configuration:
        guardrail_name: Name of the guardrail instance
        api_key: Azure Text Moderation API key
        api_base: Azure Text Moderation API endpoint
        severity_threshold: Global inclusive severity cutoff applied to all categories
        severity_threshold_by_category: Per-category inclusive severity cutoffs
        default_on: Whether to enable by default
    """

    # Fallback cutoff used when neither severity_threshold nor
    # severity_threshold_by_category is configured.
    default_severity_threshold: int = 2

    def __init__(
        self,
        guardrail_name: str,
        api_key: str,
        api_base: str,
        severity_threshold: Optional[int] = None,
        severity_threshold_by_category: Optional[Dict[str, int]] = None,
        **kwargs,
    ):
        """Initialize Azure Text Moderation guardrail handler."""
        from litellm.types.proxy.guardrails.guardrail_hooks.azure.azure_text_moderation import (
            AzureTextModerationRequestBodyOptionalParams,
        )

        # AzureGuardrailBase.__init__ stores api_key, api_base, api_version,
        # async_handler and forwards the rest to CustomGuardrail.
        super().__init__(
            api_key=api_key,
            api_base=api_base,
            guardrail_name=guardrail_name,
            **kwargs,
        )
        self.optional_params_request_body: (
            AzureTextModerationRequestBodyOptionalParams
        ) = {
            "categories": kwargs.get("categories")
            or [
                "Hate",
                "Sexual",
                "SelfHarm",
                "Violence",
            ],
            "blocklistNames": cast(
                Optional[List[str]], kwargs.get("blocklistNames") or None
            ),
            "haltOnBlocklistHit": kwargs.get("haltOnBlocklistHit") or False,
            "outputType": kwargs.get("outputType") or "FourSeverityLevels",
        }
        # BUGFIX: use `is not None` rather than truthiness so a configured
        # threshold of 0 ("flag everything") is preserved instead of being
        # silently dropped as falsy.
        self.severity_threshold = (
            int(severity_threshold) if severity_threshold is not None else None
        )
        self.severity_threshold_by_category = severity_threshold_by_category
        verbose_proxy_logger.info(
            f"Initialized Azure Text Moderation Guardrail: {guardrail_name}"
        )

    @staticmethod
    def get_config_model() -> Optional[Type["GuardrailConfigModel"]]:
        """Return the config model class for this guardrail."""
        from litellm.types.proxy.guardrails.guardrail_hooks.azure.azure_text_moderation import (
            AzureContentSafetyTextModerationConfigModel,
        )

        return AzureContentSafetyTextModerationConfigModel

    async def async_make_request(
        self, text: str
    ) -> "AzureTextModerationGuardrailResponse":
        """
        Make a request to the Azure Text Moderation API.

        Long texts are automatically split at word boundaries into chunks
        that respect the Azure Content Safety 10 000-character limit. Each
        chunk is analysed independently; a severity-threshold violation in
        *any* chunk raises an HTTPException immediately.

        Returns:
            The API response for the final chunk.
        """
        from .base import AZURE_CONTENT_SAFETY_MAX_TEXT_LENGTH

        from litellm.types.proxy.guardrails.guardrail_hooks.azure.azure_text_moderation import (
            AzureTextModerationGuardrailRequestBody,
            AzureTextModerationGuardrailResponse,
        )

        chunks = self.split_text_by_words(text, AZURE_CONTENT_SAFETY_MAX_TEXT_LENGTH)
        last_response: Optional[AzureTextModerationGuardrailResponse] = None
        for chunk in chunks:
            request_body = AzureTextModerationGuardrailRequestBody(
                text=chunk,
                **self.optional_params_request_body,  # type: ignore[misc]
            )
            response_json = await self._post_to_content_safety(
                "text:analyze", cast(dict, request_body)
            )
            chunk_response = cast(AzureTextModerationGuardrailResponse, response_json)
            # Callers only see the final response for multi-chunk texts, so
            # every intermediate chunk must be checked here to avoid silently
            # swallowing a violation in an earlier chunk.
            try:
                self.check_severity_threshold(response=chunk_response)
            except HTTPException:
                verbose_proxy_logger.warning(
                    "Azure Text Moderation: Violation detected in chunk of length %d",
                    len(chunk),
                )
                raise
            last_response = chunk_response
        # chunks is always non-empty (split_text_by_words guarantees ≥1 element)
        assert last_response is not None
        return last_response

    def _raise_severity_violation(
        self, category_name: str, threshold: Optional[int], severity: int
    ) -> None:
        """Raise the standard HTTP 400 error for a severity-threshold breach."""
        raise HTTPException(
            status_code=400,
            detail={
                "error": "Azure Content Safety Guardrail: {} crossed severity {}, Got severity: {}".format(
                    category_name,
                    threshold,
                    severity,
                )
            },
        )

    def check_severity_threshold(
        self, response: "AzureTextModerationGuardrailResponse"
    ) -> Literal[True]:
        """
        Validate every analyzed category against the configured thresholds.

        Checks applied (each independently, when configured):
        - Per-category thresholds (severity_threshold_by_category)
        - The global severity_threshold
        - default_severity_threshold when neither of the above is set

        Raises:
            HTTPException: 400 on the first category whose severity meets or
                exceeds its applicable threshold.

        Returns:
            True when no category violates its threshold.
        """
        if self.severity_threshold_by_category:
            for category in response["categoriesAnalysis"]:
                category_threshold = self.severity_threshold_by_category.get(
                    category["category"]
                )
                if (
                    category_threshold is not None
                    and category["severity"] >= category_threshold
                ):
                    self._raise_severity_violation(
                        category["category"],
                        category_threshold,
                        category["severity"],
                    )
        # BUGFIX: `is not None` (not truthiness) so a global threshold of 0
        # is honored rather than skipped.
        if self.severity_threshold is not None:
            for category in response["categoriesAnalysis"]:
                if category["severity"] >= self.severity_threshold:
                    self._raise_severity_violation(
                        category["category"],
                        self.severity_threshold,
                        category["severity"],
                    )
        if (
            self.severity_threshold is None
            and self.severity_threshold_by_category is None
        ):
            for category in response["categoriesAnalysis"]:
                if category["severity"] >= self.default_severity_threshold:
                    self._raise_severity_violation(
                        category["category"],
                        self.default_severity_threshold,
                        category["severity"],
                    )
        return True

    @log_guardrail_information
    async def async_pre_call_hook(
        self,
        user_api_key_dict: "UserAPIKeyAuth",
        cache: Any,
        data: Dict[str, Any],
        call_type: CallTypesLiteral,
    ) -> Optional[Dict[str, Any]]:
        """
        Pre-call hook to scan user prompts before sending to LLM.
        Raises HTTPException if content should be blocked.
        """
        verbose_proxy_logger.info(
            "Azure Text Moderation: Running pre-call prompt scan, on call_type: %s",
            call_type,
        )
        new_messages: Optional[List[AllMessageValues]] = data.get("messages")
        if new_messages is None:
            verbose_proxy_logger.warning(
                "Azure Text Moderation: not running guardrail. No messages in data"
            )
            return data
        user_prompt = self.get_user_prompt(new_messages)
        if user_prompt:
            verbose_proxy_logger.info(
                f"Azure Text Moderation: User prompt: {user_prompt}"
            )
            await self.async_make_request(
                text=user_prompt,
            )
        else:
            verbose_proxy_logger.warning("Azure Text Moderation: No text found")
        return None

    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: "UserAPIKeyAuth",
        response: Union[Any, "ModelResponse", "EmbeddingResponse", "ImageResponse"],
    ) -> Any:
        """Moderate the model's text output after a successful completion.

        Only chat-style responses with a Choices message are scanned; other
        response types pass through untouched.
        """
        from litellm.types.utils import Choices, ModelResponse

        if (
            isinstance(response, ModelResponse)
            and response.choices
            and isinstance(response.choices[0], Choices)
        ):
            content = response.choices[0].message.content or ""
            await self.async_make_request(
                text=content,
            )
        return response

    async def async_post_call_streaming_hook(
        self, user_api_key_dict: UserAPIKeyAuth, response: str
    ) -> Any:
        """Moderate an aggregated streaming response.

        On a guardrail violation, returns an SSE-formatted error payload
        instead of propagating the HTTPException, so the stream consumer
        receives the error inline.
        """
        try:
            if response is not None and len(response) > 0:
                await self.async_make_request(
                    text=response,
                )
            return response
        except HTTPException as e:
            import json

            error_returned = json.dumps({"error": e.detail})
            return f"data: {error_returned}\n\n"