260 lines
9.0 KiB
Python
260 lines
9.0 KiB
Python
|
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||
|
|
|
||
|
|
import litellm
|
||
|
|
from litellm._logging import verbose_logger
|
||
|
|
from litellm.constants import XAI_API_BASE
|
||
|
|
from litellm.llms.openai.responses.transformation import OpenAIResponsesAPIConfig
|
||
|
|
from litellm.secret_managers.main import get_secret_str
|
||
|
|
from litellm.types.llms.openai import ResponsesAPIOptionalRequestParams
|
||
|
|
from litellm.types.llms.xai import XAIWebSearchTool, XAIXSearchTool
|
||
|
|
from litellm.types.router import GenericLiteLLMParams
|
||
|
|
from litellm.types.utils import LlmProviders
|
||
|
|
|
||
|
|
if TYPE_CHECKING:
|
||
|
|
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||
|
|
|
||
|
|
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||
|
|
else:
|
||
|
|
LiteLLMLoggingObj = Any
|
||
|
|
|
||
|
|
|
||
|
|
class XAIResponsesAPIConfig(OpenAIResponsesAPIConfig):
|
||
|
|
"""
|
||
|
|
Configuration for XAI's Responses API.
|
||
|
|
|
||
|
|
Inherits from OpenAIResponsesAPIConfig since XAI's Responses API is largely
|
||
|
|
compatible with OpenAI's, with a few differences:
|
||
|
|
- Does not support the 'instructions' parameter
|
||
|
|
- Requires code_interpreter tools to have 'container' field removed
|
||
|
|
- Recommends store=false when sending images
|
||
|
|
|
||
|
|
Reference: https://docs.x.ai/docs/api-reference#create-new-response
|
||
|
|
"""
|
||
|
|
|
||
|
|
@property
|
||
|
|
def custom_llm_provider(self) -> LlmProviders:
|
||
|
|
return LlmProviders.XAI
|
||
|
|
|
||
|
|
def get_supported_openai_params(self, model: str) -> list:
|
||
|
|
"""
|
||
|
|
Get supported parameters for XAI Responses API.
|
||
|
|
|
||
|
|
XAI supports most OpenAI Responses API params except 'instructions'.
|
||
|
|
"""
|
||
|
|
supported_params = super().get_supported_openai_params(model)
|
||
|
|
|
||
|
|
# Remove 'instructions' as it's not supported by XAI
|
||
|
|
if "instructions" in supported_params:
|
||
|
|
supported_params.remove("instructions")
|
||
|
|
|
||
|
|
return supported_params
|
||
|
|
|
||
|
|
def _transform_web_search_tool(
|
||
|
|
self, tool: Dict[str, Any]
|
||
|
|
) -> Union[XAIWebSearchTool, Dict[str, Any]]:
|
||
|
|
"""
|
||
|
|
Transform web_search tool to XAI format.
|
||
|
|
|
||
|
|
XAI supports web_search with specific filters:
|
||
|
|
- allowed_domains (max 5)
|
||
|
|
- excluded_domains (max 5)
|
||
|
|
- enable_image_understanding
|
||
|
|
|
||
|
|
XAI does NOT support search_context_size (OpenAI-specific).
|
||
|
|
"""
|
||
|
|
xai_tool: Dict[str, Any] = {"type": "web_search"}
|
||
|
|
|
||
|
|
# Remove search_context_size if present (not supported by XAI)
|
||
|
|
if "search_context_size" in tool:
|
||
|
|
verbose_logger.info(
|
||
|
|
"XAI does not support 'search_context_size' parameter. Removing it from web_search tool."
|
||
|
|
)
|
||
|
|
|
||
|
|
# Handle filters (XAI-specific structure)
|
||
|
|
filters = {}
|
||
|
|
if "allowed_domains" in tool:
|
||
|
|
allowed_domains = tool["allowed_domains"]
|
||
|
|
filters["allowed_domains"] = allowed_domains
|
||
|
|
|
||
|
|
if "excluded_domains" in tool:
|
||
|
|
excluded_domains = tool["excluded_domains"]
|
||
|
|
filters["excluded_domains"] = excluded_domains
|
||
|
|
|
||
|
|
# Add filters if any were specified
|
||
|
|
if filters:
|
||
|
|
xai_tool["filters"] = filters
|
||
|
|
|
||
|
|
# Handle enable_image_understanding (top-level in XAI format)
|
||
|
|
if "enable_image_understanding" in tool:
|
||
|
|
xai_tool["enable_image_understanding"] = tool["enable_image_understanding"]
|
||
|
|
|
||
|
|
return xai_tool
|
||
|
|
|
||
|
|
def _transform_x_search_tool(
|
||
|
|
self, tool: Dict[str, Any]
|
||
|
|
) -> Union[XAIXSearchTool, Dict[str, Any]]:
|
||
|
|
"""
|
||
|
|
Transform x_search tool to XAI format.
|
||
|
|
|
||
|
|
XAI supports x_search with specific parameters:
|
||
|
|
- allowed_x_handles (max 10)
|
||
|
|
- excluded_x_handles (max 10)
|
||
|
|
- from_date (ISO8601: YYYY-MM-DD)
|
||
|
|
- to_date (ISO8601: YYYY-MM-DD)
|
||
|
|
- enable_image_understanding
|
||
|
|
- enable_video_understanding
|
||
|
|
"""
|
||
|
|
xai_tool: Dict[str, Any] = {"type": "x_search"}
|
||
|
|
|
||
|
|
# Handle allowed_x_handles
|
||
|
|
if "allowed_x_handles" in tool:
|
||
|
|
allowed_handles = tool["allowed_x_handles"]
|
||
|
|
xai_tool["allowed_x_handles"] = allowed_handles
|
||
|
|
|
||
|
|
# Handle excluded_x_handles
|
||
|
|
if "excluded_x_handles" in tool:
|
||
|
|
excluded_handles = tool["excluded_x_handles"]
|
||
|
|
xai_tool["excluded_x_handles"] = excluded_handles
|
||
|
|
|
||
|
|
# Handle date range
|
||
|
|
if "from_date" in tool:
|
||
|
|
xai_tool["from_date"] = tool["from_date"]
|
||
|
|
|
||
|
|
if "to_date" in tool:
|
||
|
|
xai_tool["to_date"] = tool["to_date"]
|
||
|
|
|
||
|
|
# Handle media understanding flags
|
||
|
|
if "enable_image_understanding" in tool:
|
||
|
|
xai_tool["enable_image_understanding"] = tool["enable_image_understanding"]
|
||
|
|
|
||
|
|
if "enable_video_understanding" in tool:
|
||
|
|
xai_tool["enable_video_understanding"] = tool["enable_video_understanding"]
|
||
|
|
|
||
|
|
return xai_tool
|
||
|
|
|
||
|
|
def map_openai_params(
|
||
|
|
self,
|
||
|
|
response_api_optional_params: ResponsesAPIOptionalRequestParams,
|
||
|
|
model: str,
|
||
|
|
drop_params: bool,
|
||
|
|
) -> Dict:
|
||
|
|
"""
|
||
|
|
Map parameters for XAI Responses API.
|
||
|
|
|
||
|
|
Handles XAI-specific transformations:
|
||
|
|
1. Drops 'instructions' parameter (not supported)
|
||
|
|
2. Transforms code_interpreter tools to remove 'container' field
|
||
|
|
3. Transforms web_search tools to XAI format (removes search_context_size, adds filters)
|
||
|
|
4. Transforms x_search tools to XAI format
|
||
|
|
5. Sets store=false when images are detected (recommended by XAI)
|
||
|
|
"""
|
||
|
|
params = dict(response_api_optional_params)
|
||
|
|
|
||
|
|
# Drop instructions parameter (not supported by XAI)
|
||
|
|
if "instructions" in params:
|
||
|
|
verbose_logger.debug(
|
||
|
|
"XAI Responses API does not support 'instructions' parameter. Dropping it."
|
||
|
|
)
|
||
|
|
params.pop("instructions")
|
||
|
|
|
||
|
|
if "metadata" in params:
|
||
|
|
verbose_logger.debug(
|
||
|
|
"XAI Responses API does not support 'metadata' parameter. Dropping it."
|
||
|
|
)
|
||
|
|
params.pop("metadata")
|
||
|
|
|
||
|
|
# Transform tools
|
||
|
|
if "tools" in params and params["tools"]:
|
||
|
|
tools_list = params["tools"]
|
||
|
|
# Ensure tools is a list for iteration
|
||
|
|
if not isinstance(tools_list, list):
|
||
|
|
tools_list = [tools_list]
|
||
|
|
|
||
|
|
transformed_tools: List[Any] = []
|
||
|
|
for tool in tools_list:
|
||
|
|
if isinstance(tool, dict):
|
||
|
|
tool_type = tool.get("type")
|
||
|
|
|
||
|
|
if tool_type == "code_interpreter":
|
||
|
|
# XAI supports code_interpreter but doesn't use the container field
|
||
|
|
verbose_logger.debug(
|
||
|
|
"XAI: Transforming code_interpreter tool, removing container field"
|
||
|
|
)
|
||
|
|
transformed_tools.append({"type": "code_interpreter"})
|
||
|
|
|
||
|
|
elif tool_type == "web_search":
|
||
|
|
# Transform web_search to XAI format
|
||
|
|
verbose_logger.debug(
|
||
|
|
"XAI: Transforming web_search tool to XAI format"
|
||
|
|
)
|
||
|
|
transformed_tools.append(self._transform_web_search_tool(tool))
|
||
|
|
|
||
|
|
elif tool_type == "x_search":
|
||
|
|
# Transform x_search to XAI format
|
||
|
|
verbose_logger.debug(
|
||
|
|
"XAI: Transforming x_search tool to XAI format"
|
||
|
|
)
|
||
|
|
transformed_tools.append(self._transform_x_search_tool(tool))
|
||
|
|
|
||
|
|
else:
|
||
|
|
# Keep other tools as-is
|
||
|
|
transformed_tools.append(tool)
|
||
|
|
else:
|
||
|
|
transformed_tools.append(tool)
|
||
|
|
|
||
|
|
params["tools"] = transformed_tools
|
||
|
|
|
||
|
|
return params
|
||
|
|
|
||
|
|
def validate_environment(
|
||
|
|
self, headers: dict, model: str, litellm_params: Optional[GenericLiteLLMParams]
|
||
|
|
) -> dict:
|
||
|
|
"""
|
||
|
|
Validate environment and set up headers for XAI API.
|
||
|
|
|
||
|
|
Uses XAI_API_KEY from environment or litellm_params.
|
||
|
|
"""
|
||
|
|
litellm_params = litellm_params or GenericLiteLLMParams()
|
||
|
|
api_key = (
|
||
|
|
litellm_params.api_key or litellm.api_key or get_secret_str("XAI_API_KEY")
|
||
|
|
)
|
||
|
|
|
||
|
|
if not api_key:
|
||
|
|
raise ValueError(
|
||
|
|
"XAI API key is required. Set XAI_API_KEY environment variable or pass api_key parameter."
|
||
|
|
)
|
||
|
|
|
||
|
|
headers.update(
|
||
|
|
{
|
||
|
|
"Authorization": f"Bearer {api_key}",
|
||
|
|
}
|
||
|
|
)
|
||
|
|
return headers
|
||
|
|
|
||
|
|
def get_complete_url(
|
||
|
|
self,
|
||
|
|
api_base: Optional[str],
|
||
|
|
litellm_params: dict,
|
||
|
|
) -> str:
|
||
|
|
"""
|
||
|
|
Get the complete URL for XAI Responses API endpoint.
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
str: The full URL for the XAI /responses endpoint
|
||
|
|
"""
|
||
|
|
api_base = (
|
||
|
|
api_base
|
||
|
|
or litellm.api_base
|
||
|
|
or get_secret_str("XAI_API_BASE")
|
||
|
|
or XAI_API_BASE
|
||
|
|
)
|
||
|
|
|
||
|
|
# Remove trailing slashes
|
||
|
|
api_base = api_base.rstrip("/")
|
||
|
|
|
||
|
|
return f"{api_base}/responses"
|
||
|
|
|
||
|
|
def supports_native_websocket(self) -> bool:
|
||
|
|
"""XAI does not support native WebSocket for Responses API"""
|
||
|
|
return False
|