chore: initial public snapshot for github upload

Your Name
2026-03-26 20:06:14 +08:00
commit 0e5ecd930e
3497 changed files with 1586236 additions and 0 deletions


@@ -0,0 +1,210 @@
# Arize Phoenix Prompt Management Integration
This integration lets you use prompt versions from Arize Phoenix with LiteLLM's `completion()` function.
## Features
- Fetch prompt versions from Arize Phoenix API
- Workspace-based access control through Arize Phoenix permissions
- Mustache/Handlebars-style variable templating (`{{variable}}`)
- Support for multi-message chat templates
- Automatic model and parameter configuration from prompt metadata
- OpenAI and Anthropic provider parameter support
## Configuration
Configure Arize Phoenix access in your application:
```python
# Credentials are passed per-call to litellm.completion()
# api_base must include your workspace, e.g. "https://app.phoenix.arize.com/s/your-workspace/v1"
api_key = "your-arize-phoenix-token"
api_base = "https://app.phoenix.arize.com/s/krrishdholakia/v1"
```
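If `api_key` is omitted when prompts are initialized through LiteLLM's prompt registry, the initializer falls back to the `PHOENIX_API_KEY` environment variable; a minimal sketch:
```python
import os

# Set once in your environment; the Arize Phoenix prompt initializer
# reads PHOENIX_API_KEY when api_key is not passed explicitly.
os.environ["PHOENIX_API_KEY"] = "your-arize-phoenix-token"
```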
## Usage
### Basic Usage
```python
import litellm
# Use with completion
response = litellm.completion(
model="arize/gpt-4o",
prompt_id="UHJvbXB0VmVyc2lvbjox", # Your prompt version ID
prompt_variables={"question": "What is artificial intelligence?"},
api_key="your-arize-phoenix-token",
api_base="https://app.phoenix.arize.com/s/krrishdholakia/v1",
)
print(response.choices[0].message.content)
```
### With Additional Messages
You can also combine prompt templates with additional messages; the rendered template messages are prepended to the messages you pass:
```python
response = litellm.completion(
model="arize/gpt-4o",
prompt_id="UHJvbXB0VmVyc2lvbjox",
prompt_variables={"question": "Explain quantum computing"},
api_key="your-arize-phoenix-token",
api_base="https://app.phoenix.arize.com/s/krrishdholakia/v1",
messages=[
{"role": "user", "content": "Please keep your response under 100 words."}
],
)
```
### Direct Manager Usage
You can also use the prompt manager directly:
```python
from litellm.integrations.arize.arize_phoenix_prompt_manager import ArizePhoenixPromptManager
# Initialize the manager
manager = ArizePhoenixPromptManager(
api_key="your-arize-phoenix-token",
api_base="https://app.phoenix.arize.com/s/krrishdholakia/v1",
prompt_id="UHJvbXB0VmVyc2lvbjox",
)
# Get rendered messages
messages, metadata = manager.get_prompt_template(
prompt_id="UHJvbXB0VmVyc2lvbjox",
prompt_variables={"question": "What is machine learning?"}
)
print("Rendered messages:", messages)
print("Metadata:", metadata)
```
## Prompt Format
Prompt versions returned by the Arize Phoenix API use the following structure:
```json
{
"data": {
"description": "A chatbot prompt",
"model_provider": "OPENAI",
"model_name": "gpt-4o",
"template": {
"type": "chat",
"messages": [
{
"role": "system",
"content": [
{
"type": "text",
"text": "You are a chatbot"
}
]
},
{
"role": "user",
"content": [
{
"type": "text",
"text": "{{question}}"
}
]
}
]
},
"template_type": "CHAT",
"template_format": "MUSTACHE",
"invocation_parameters": {
"type": "openai",
"openai": {
"temperature": 1.0
}
},
"id": "UHJvbXB0VmVyc2lvbjox"
}
}
```
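To illustrate how this payload is consumed, here is a minimal sketch (not the integration's actual parsing code) that pulls the model and provider parameters out of the `data` object; `payload` is assumed to hold the JSON above:
```python
def extract_completion_params(data: dict) -> dict:
    """Map an Arize Phoenix prompt version payload to completion kwargs."""
    invocation = data.get("invocation_parameters", {})
    # Provider-specific params are nested under the provider key, e.g. "openai"
    provider_params = invocation.get(invocation.get("type", ""), {})
    return {
        "model": data.get("model_name"),
        "temperature": provider_params.get("temperature"),
        "max_tokens": provider_params.get("max_tokens"),
    }

params = extract_completion_params(payload["data"])
# -> {"model": "gpt-4o", "temperature": 1.0, "max_tokens": None}
```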
### Variable Substitution
Variables in your prompt templates use Mustache/Handlebars syntax:
- `{{variable_name}}` - Simple variable substitution
Example:
```
Template: "Hello {{name}}, your order {{order_id}} is ready!"
Variables: {"name": "Alice", "order_id": "12345"}
Result: "Hello Alice, your order 12345 is ready!"
```
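Under the hood, the integration renders these templates with Jinja2, whose default delimiters already match the `{{variable}}` syntax; a standalone sketch:
```python
from jinja2 import Environment

env = Environment()
template = env.from_string("Hello {{name}}, your order {{order_id}} is ready!")
print(template.render(name="Alice", order_id="12345"))
# Hello Alice, your order 12345 is ready!
```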
## API Reference
### ArizePhoenixPromptManager
Main class for managing Arize Phoenix prompts.
**Methods:**
- `get_prompt_template(prompt_id, prompt_variables)` - Get and render a prompt template
- `get_available_prompts()` - List available prompt IDs
- `reload_prompts()` - Reload prompts from Arize Phoenix
### ArizePhoenixClient
Low-level client for the Arize Phoenix API.
**Methods:**
- `get_prompt_version(prompt_version_id)` - Fetch a prompt version
- `test_connection()` - Test API connection
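For example, to fetch the raw prompt version payload directly (the import path below is inferred from the package layout):
```python
from litellm.integrations.arize.arize_phoenix_client import ArizePhoenixClient

client = ArizePhoenixClient(
    api_key="your-arize-phoenix-token",
    api_base="https://app.phoenix.arize.com/s/your-workspace/v1",
)
data = client.get_prompt_version("UHJvbXB0VmVyc2lvbjox")  # returns None on 404
print(data and data.get("model_name"))
client.close()
```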
## Error Handling
The integration provides detailed error messages:
- **404**: Prompt version not found
- **401**: Authentication failed (check your access token)
- **403**: Access denied (check workspace permissions)
Example:
```python
try:
response = litellm.completion(
model="arize/gpt-4o",
prompt_id="invalid-id",
        api_key="your-arize-phoenix-token",
        api_base="https://app.phoenix.arize.com/s/krrishdholakia/v1",
)
except Exception as e:
print(f"Error: {e}")
```
## Getting Your Prompt Version ID and API Base
1. Log in to Arize Phoenix
2. Navigate to your workspace
3. Go to Prompts section
4. Select a prompt version
5. The ID will be in the URL: `/s/{workspace}/v1/prompt_versions/{PROMPT_VERSION_ID}`
Your `api_base` should be: `https://app.phoenix.arize.com/s/{workspace}/v1`
For example:
- Workspace: `krrishdholakia`
- API Base: `https://app.phoenix.arize.com/s/krrishdholakia/v1`
- Prompt Version ID: `UHJvbXB0VmVyc2lvbjox`
You can also fetch it via API:
```bash
curl -L -X GET 'https://app.phoenix.arize.com/s/krrishdholakia/v1/prompt_versions/UHJvbXB0VmVyc2lvbjox' \
-H 'Authorization: Bearer YOUR_TOKEN'
```
## Support
For issues or questions:
- LiteLLM Issues: https://github.com/BerriAI/litellm/issues
- Arize Phoenix Docs: https://docs.arize.com/phoenix


@@ -0,0 +1,52 @@
import os
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from litellm.types.prompts.init_prompts import PromptLiteLLMParams, PromptSpec
from litellm.integrations.custom_prompt_management import CustomPromptManagement
from litellm.types.prompts.init_prompts import SupportedPromptIntegrations
from .arize_phoenix_prompt_manager import ArizePhoenixPromptManager
# Global instances
global_arize_config: Optional[dict] = None
def prompt_initializer(
litellm_params: "PromptLiteLLMParams", prompt_spec: "PromptSpec"
) -> "CustomPromptManagement":
"""
Initialize a prompt from Arize Phoenix.
"""
api_key = getattr(litellm_params, "api_key", None) or os.environ.get(
"PHOENIX_API_KEY"
)
api_base = getattr(litellm_params, "api_base", None)
prompt_id = getattr(litellm_params, "prompt_id", None)
if not api_key or not api_base:
raise ValueError(
"api_key and api_base are required for Arize Phoenix prompt integration"
)
    arize_prompt_manager = ArizePhoenixPromptManager(
        api_key=api_key,
        api_base=api_base,
        prompt_id=prompt_id,
        **litellm_params.model_dump(exclude={"api_key", "api_base", "prompt_id"}),
    )
    return arize_prompt_manager
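# Maps the Arize Phoenix integration name to its prompt initializer.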
prompt_initializer_registry = {
SupportedPromptIntegrations.ARIZE_PHOENIX.value: prompt_initializer,
}


@@ -0,0 +1,502 @@
import json
from typing import TYPE_CHECKING, Any, Dict, Optional, Type
from typing_extensions import override
from litellm._logging import verbose_logger
from litellm.integrations.opentelemetry_utils.base_otel_llm_obs_attributes import (
BaseLLMObsOTELAttributes,
safe_set_attribute,
)
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.types.utils import StandardLoggingPayload
if TYPE_CHECKING:
from opentelemetry.trace import Span
from litellm.integrations._types.open_inference import (
MessageAttributes,
ImageAttributes,
SpanAttributes,
AudioAttributes,
EmbeddingAttributes,
OpenInferenceSpanKindValues,
)
class ArizeOTELAttributes(BaseLLMObsOTELAttributes):
@staticmethod
@override
def set_messages(span: "Span", kwargs: Dict[str, Any]):
messages = kwargs.get("messages")
# for /chat/completions
# https://docs.arize.com/arize/large-language-models/tracing/semantic-conventions
if messages:
last_message = messages[-1]
safe_set_attribute(
span,
SpanAttributes.INPUT_VALUE,
last_message.get("content", ""),
)
# LLM_INPUT_MESSAGES shows up under `input_messages` tab on the span page.
for idx, msg in enumerate(messages):
prefix = f"{SpanAttributes.LLM_INPUT_MESSAGES}.{idx}"
# Set the role per message.
safe_set_attribute(
span, f"{prefix}.{MessageAttributes.MESSAGE_ROLE}", msg.get("role")
)
# Set the content per message.
safe_set_attribute(
span,
f"{prefix}.{MessageAttributes.MESSAGE_CONTENT}",
msg.get("content", ""),
)
@staticmethod
@override
def set_response_output_messages(span: "Span", response_obj):
"""
Sets output message attributes on the span from the LLM response.
Args:
span: The OpenTelemetry span to set attributes on
response_obj: The response object containing choices with messages
"""
from litellm.integrations._types.open_inference import (
MessageAttributes,
SpanAttributes,
)
for idx, choice in enumerate(response_obj.get("choices", [])):
response_message = choice.get("message", {})
safe_set_attribute(
span,
SpanAttributes.OUTPUT_VALUE,
response_message.get("content", ""),
)
# This shows up under `output_messages` tab on the span page.
prefix = f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.{idx}"
safe_set_attribute(
span,
f"{prefix}.{MessageAttributes.MESSAGE_ROLE}",
response_message.get("role"),
)
safe_set_attribute(
span,
f"{prefix}.{MessageAttributes.MESSAGE_CONTENT}",
response_message.get("content", ""),
)
def _set_response_attributes(span: "Span", response_obj):
"""Helper to set response output and token usage attributes on span."""
if not hasattr(response_obj, "get"):
return
_set_choice_outputs(span, response_obj, MessageAttributes, SpanAttributes)
_set_image_outputs(span, response_obj, ImageAttributes, SpanAttributes)
_set_audio_outputs(span, response_obj, AudioAttributes, SpanAttributes)
_set_embedding_outputs(span, response_obj, EmbeddingAttributes, SpanAttributes)
_set_structured_outputs(span, response_obj, MessageAttributes, SpanAttributes)
_set_usage_outputs(span, response_obj, SpanAttributes)
def _set_choice_outputs(span: "Span", response_obj, msg_attrs, span_attrs):
for idx, choice in enumerate(response_obj.get("choices", [])):
response_message = choice.get("message", {})
safe_set_attribute(
span,
span_attrs.OUTPUT_VALUE,
response_message.get("content", ""),
)
prefix = f"{span_attrs.LLM_OUTPUT_MESSAGES}.{idx}"
safe_set_attribute(
span,
f"{prefix}.{msg_attrs.MESSAGE_ROLE}",
response_message.get("role"),
)
safe_set_attribute(
span,
f"{prefix}.{msg_attrs.MESSAGE_CONTENT}",
response_message.get("content", ""),
)
def _set_image_outputs(span: "Span", response_obj, image_attrs, span_attrs):
images = response_obj.get("data", [])
for i, image in enumerate(images):
img_url = image.get("url")
if img_url is None and image.get("b64_json"):
img_url = f"data:image/png;base64,{image.get('b64_json')}"
if not img_url:
continue
if i == 0:
safe_set_attribute(span, span_attrs.OUTPUT_VALUE, img_url)
safe_set_attribute(span, f"{image_attrs.IMAGE_URL}.{i}", img_url)
def _set_audio_outputs(span: "Span", response_obj, audio_attrs, span_attrs):
audio = response_obj.get("audio", [])
for i, audio_item in enumerate(audio):
audio_url = audio_item.get("url")
if audio_url is None and audio_item.get("b64_json"):
audio_url = f"data:audio/wav;base64,{audio_item.get('b64_json')}"
if audio_url:
if i == 0:
safe_set_attribute(span, span_attrs.OUTPUT_VALUE, audio_url)
safe_set_attribute(span, f"{audio_attrs.AUDIO_URL}.{i}", audio_url)
audio_mime = audio_item.get("mime_type")
if audio_mime:
safe_set_attribute(span, f"{audio_attrs.AUDIO_MIME_TYPE}.{i}", audio_mime)
audio_transcript = audio_item.get("transcript")
if audio_transcript:
safe_set_attribute(
span, f"{audio_attrs.AUDIO_TRANSCRIPT}.{i}", audio_transcript
)
def _set_embedding_outputs(span: "Span", response_obj, embedding_attrs, span_attrs):
embeddings = response_obj.get("data", [])
for i, embedding_item in enumerate(embeddings):
embedding_vector = embedding_item.get("embedding")
if embedding_vector:
if i == 0:
safe_set_attribute(
span,
span_attrs.OUTPUT_VALUE,
str(embedding_vector),
)
safe_set_attribute(
span,
f"{embedding_attrs.EMBEDDING_VECTOR}.{i}",
str(embedding_vector),
)
embedding_text = embedding_item.get("text")
if embedding_text:
safe_set_attribute(
span,
f"{embedding_attrs.EMBEDDING_TEXT}.{i}",
str(embedding_text),
)
def _set_structured_outputs(span: "Span", response_obj, msg_attrs, span_attrs):
output_items = response_obj.get("output", [])
for i, item in enumerate(output_items):
prefix = f"{span_attrs.LLM_OUTPUT_MESSAGES}.{i}"
if not hasattr(item, "type"):
continue
item_type = item.type
if item_type == "reasoning" and hasattr(item, "summary"):
for summary in item.summary:
if hasattr(summary, "text"):
safe_set_attribute(
span,
f"{prefix}.{msg_attrs.MESSAGE_REASONING_SUMMARY}",
summary.text,
)
elif item_type == "message" and hasattr(item, "content"):
message_content = ""
content_list = item.content
if content_list and len(content_list) > 0:
first_content = content_list[0]
message_content = getattr(first_content, "text", "")
message_role = getattr(item, "role", "assistant")
safe_set_attribute(span, span_attrs.OUTPUT_VALUE, message_content)
safe_set_attribute(
span, f"{prefix}.{msg_attrs.MESSAGE_CONTENT}", message_content
)
safe_set_attribute(span, f"{prefix}.{msg_attrs.MESSAGE_ROLE}", message_role)
def _set_usage_outputs(span: "Span", response_obj, span_attrs):
usage = response_obj and response_obj.get("usage")
if not usage:
return
safe_set_attribute(
span, span_attrs.LLM_TOKEN_COUNT_TOTAL, usage.get("total_tokens")
)
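    # Prefer OpenAI Chat Completions token names, falling back to the
    # input_tokens/output_tokens naming used by some other response formats.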
completion_tokens = usage.get("completion_tokens") or usage.get("output_tokens")
if completion_tokens:
safe_set_attribute(
span, span_attrs.LLM_TOKEN_COUNT_COMPLETION, completion_tokens
)
prompt_tokens = usage.get("prompt_tokens") or usage.get("input_tokens")
if prompt_tokens:
safe_set_attribute(span, span_attrs.LLM_TOKEN_COUNT_PROMPT, prompt_tokens)
    # output_tokens_details may exist with a None value, so guard the lookup
    reasoning_tokens = (usage.get("output_tokens_details") or {}).get(
        "reasoning_tokens"
    )
if reasoning_tokens:
safe_set_attribute(
span,
span_attrs.LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING,
reasoning_tokens,
)
def _infer_open_inference_span_kind(call_type: Optional[str]) -> str:
"""
Map LiteLLM call types to OpenInference span kinds.
"""
if not call_type:
return OpenInferenceSpanKindValues.UNKNOWN.value
lowered = str(call_type).lower()
if "embed" in lowered:
return OpenInferenceSpanKindValues.EMBEDDING.value
if "rerank" in lowered:
return OpenInferenceSpanKindValues.RERANKER.value
if "search" in lowered:
return OpenInferenceSpanKindValues.RETRIEVER.value
if "moderation" in lowered or "guardrail" in lowered:
return OpenInferenceSpanKindValues.GUARDRAIL.value
if lowered == "call_mcp_tool" or lowered == "mcp" or lowered.endswith("tool"):
return OpenInferenceSpanKindValues.TOOL.value
if "asend_message" in lowered or "a2a" in lowered or "assistant" in lowered:
return OpenInferenceSpanKindValues.AGENT.value
if any(
keyword in lowered
for keyword in (
"completion",
"chat",
"image",
"audio",
"speech",
"transcription",
"generate_content",
"response",
"videos",
"realtime",
"pass_through",
"anthropic_messages",
"ocr",
)
):
return OpenInferenceSpanKindValues.LLM.value
if any(
keyword in lowered
for keyword in ("file", "batch", "container", "fine_tuning_job")
):
return OpenInferenceSpanKindValues.CHAIN.value
return OpenInferenceSpanKindValues.UNKNOWN.value
def _set_tool_attributes(
span: "Span", optional_tools: Optional[list], metadata_tools: Optional[list]
):
"""set tool attributes on span from optional_params or tool call metadata"""
if optional_tools:
for idx, tool in enumerate(optional_tools):
if not isinstance(tool, dict):
continue
function = (
tool.get("function") if isinstance(tool.get("function"), dict) else None
)
if not function:
continue
tool_name = function.get("name")
if tool_name:
safe_set_attribute(
span, f"{SpanAttributes.LLM_TOOLS}.{idx}.name", tool_name
)
tool_description = function.get("description")
if tool_description:
safe_set_attribute(
span,
f"{SpanAttributes.LLM_TOOLS}.{idx}.description",
tool_description,
)
params = function.get("parameters")
if params is not None:
safe_set_attribute(
span,
f"{SpanAttributes.LLM_TOOLS}.{idx}.parameters",
json.dumps(params),
)
if metadata_tools and isinstance(metadata_tools, list):
for idx, tool in enumerate(metadata_tools):
if not isinstance(tool, dict):
continue
tool_name = tool.get("name")
if tool_name:
safe_set_attribute(
span,
f"{SpanAttributes.LLM_INVOCATION_PARAMETERS}.tools.{idx}.name",
tool_name,
)
tool_description = tool.get("description")
if tool_description:
safe_set_attribute(
span,
f"{SpanAttributes.LLM_INVOCATION_PARAMETERS}.tools.{idx}.description",
tool_description,
)
def set_attributes(
span: "Span", kwargs, response_obj, attributes: Type[BaseLLMObsOTELAttributes]
):
"""
Populates span with OpenInference-compliant LLM attributes for Arize and Phoenix tracing.
"""
try:
optional_params = _sanitize_optional_params(kwargs.get("optional_params"))
litellm_params = kwargs.get("litellm_params", {}) or {}
standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
"standard_logging_object"
)
if standard_logging_payload is None:
raise ValueError("standard_logging_object not found in kwargs")
        metadata = standard_logging_payload.get("metadata")
_set_metadata_attributes(span, metadata, SpanAttributes)
metadata_tools = _extract_metadata_tools(metadata)
optional_tools = _extract_optional_tools(optional_params)
call_type = standard_logging_payload.get("call_type")
_set_request_attributes(
span=span,
kwargs=kwargs,
standard_logging_payload=standard_logging_payload,
optional_params=optional_params,
litellm_params=litellm_params,
response_obj=response_obj,
span_attrs=SpanAttributes,
)
span_kind = _infer_open_inference_span_kind(call_type=call_type)
_set_tool_attributes(span, optional_tools, metadata_tools)
if (
optional_tools or metadata_tools
) and span_kind != OpenInferenceSpanKindValues.TOOL.value:
span_kind = OpenInferenceSpanKindValues.TOOL.value
safe_set_attribute(span, SpanAttributes.OPENINFERENCE_SPAN_KIND, span_kind)
attributes.set_messages(span, kwargs)
        model_params = standard_logging_payload.get("model_parameters")
_set_model_params(span, model_params, SpanAttributes)
_set_response_attributes(span=span, response_obj=response_obj)
except Exception as e:
verbose_logger.error(
f"[Arize/Phoenix] Failed to set OpenInference span attributes: {e}"
)
if hasattr(span, "record_exception"):
span.record_exception(e)
def _sanitize_optional_params(optional_params: Optional[dict]) -> dict:
if not isinstance(optional_params, dict):
return {}
optional_params.pop("secret_fields", None)
return optional_params
def _set_metadata_attributes(span: "Span", metadata: Optional[Any], span_attrs) -> None:
if metadata is not None:
safe_set_attribute(span, span_attrs.METADATA, safe_dumps(metadata))
def _extract_metadata_tools(metadata: Optional[Any]) -> Optional[list]:
if not isinstance(metadata, dict):
return None
llm_obj = metadata.get("llm")
if isinstance(llm_obj, dict):
return llm_obj.get("tools")
return None
def _extract_optional_tools(optional_params: dict) -> Optional[list]:
return optional_params.get("tools") if isinstance(optional_params, dict) else None
def _set_request_attributes(
span: "Span",
kwargs,
standard_logging_payload: StandardLoggingPayload,
optional_params: dict,
litellm_params: dict,
response_obj,
span_attrs,
):
if kwargs.get("model"):
safe_set_attribute(span, span_attrs.LLM_MODEL_NAME, kwargs.get("model"))
safe_set_attribute(
span, "llm.request.type", standard_logging_payload.get("call_type")
)
safe_set_attribute(
span,
span_attrs.LLM_PROVIDER,
litellm_params.get("custom_llm_provider", "Unknown"),
)
if optional_params.get("max_tokens"):
safe_set_attribute(
span, "llm.request.max_tokens", optional_params.get("max_tokens")
)
if optional_params.get("temperature"):
safe_set_attribute(
span, "llm.request.temperature", optional_params.get("temperature")
)
if optional_params.get("top_p"):
safe_set_attribute(span, "llm.request.top_p", optional_params.get("top_p"))
safe_set_attribute(
span, "llm.is_streaming", str(optional_params.get("stream", False))
)
if optional_params.get("user"):
safe_set_attribute(span, "llm.user", optional_params.get("user"))
if response_obj and response_obj.get("id"):
safe_set_attribute(span, "llm.response.id", response_obj.get("id"))
if response_obj and response_obj.get("model"):
safe_set_attribute(span, "llm.response.model", response_obj.get("model"))
def _set_model_params(span: "Span", model_params: Optional[dict], span_attrs) -> None:
if not model_params:
return
safe_set_attribute(
span, span_attrs.LLM_INVOCATION_PARAMETERS, safe_dumps(model_params)
)
if model_params.get("user"):
user_id = model_params.get("user")
if user_id is not None:
safe_set_attribute(span, span_attrs.USER_ID, user_id)


@@ -0,0 +1,214 @@
"""
arize AI is OTEL compatible
this file has Arize ai specific helper functions
"""
import os
from datetime import datetime
from typing import TYPE_CHECKING, Any, Optional, Union
from litellm.integrations.arize import _utils
from litellm.integrations.arize._utils import ArizeOTELAttributes
from litellm.integrations.opentelemetry import OpenTelemetry
from litellm.types.integrations.arize import ArizeConfig
from litellm.types.services import ServiceLoggerPayload
from litellm.types.utils import StandardCallbackDynamicParams
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
from litellm.types.integrations.arize import Protocol as _Protocol
Protocol = _Protocol
Span = Union[_Span, Any]
else:
Protocol = Any
Span = Any
class ArizeLogger(OpenTelemetry):
"""
Arize logger that sends traces to an Arize endpoint.
Creates its own dedicated TracerProvider so it can coexist with the
generic ``otel`` callback (or any other OTEL-based integration) without
fighting over the global ``opentelemetry.trace`` TracerProvider singleton.
"""
def _init_tracing(self, tracer_provider):
"""
Override to always create a *private* TracerProvider for Arize.
See ArizePhoenixLogger._init_tracing for full rationale.
"""
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.trace import SpanKind
if tracer_provider is not None:
self.tracer = tracer_provider.get_tracer("litellm")
self.span_kind = SpanKind
return
provider = TracerProvider(resource=self._get_litellm_resource(self.config))
provider.add_span_processor(self._get_span_processor())
self.tracer = provider.get_tracer("litellm")
self.span_kind = SpanKind
def _init_otel_logger_on_litellm_proxy(self):
"""
Override: Arize should NOT overwrite the proxy's
``open_telemetry_logger``. That attribute is reserved for the
primary ``otel`` callback which handles proxy-level parent spans.
"""
pass
def set_attributes(self, span: Span, kwargs, response_obj: Optional[Any]):
ArizeLogger.set_arize_attributes(span, kwargs, response_obj)
return
@staticmethod
def set_arize_attributes(span: Span, kwargs, response_obj):
_utils.set_attributes(span, kwargs, response_obj, ArizeOTELAttributes)
return
@staticmethod
def get_arize_config() -> ArizeConfig:
"""
Helper function to get Arize configuration.
Returns:
ArizeConfig: A Pydantic model containing Arize configuration.
Raises:
ValueError: If required environment variables are not set.
"""
space_id = os.environ.get("ARIZE_SPACE_ID")
space_key = os.environ.get("ARIZE_SPACE_KEY")
api_key = os.environ.get("ARIZE_API_KEY")
project_name = os.environ.get("ARIZE_PROJECT_NAME")
grpc_endpoint = os.environ.get("ARIZE_ENDPOINT")
http_endpoint = os.environ.get("ARIZE_HTTP_ENDPOINT")
endpoint = None
protocol: Protocol = "otlp_grpc"
if grpc_endpoint:
protocol = "otlp_grpc"
endpoint = grpc_endpoint
elif http_endpoint:
protocol = "otlp_http"
endpoint = http_endpoint
else:
protocol = "otlp_grpc"
endpoint = "https://otlp.arize.com/v1"
return ArizeConfig(
space_id=space_id,
space_key=space_key,
api_key=api_key,
protocol=protocol,
endpoint=endpoint,
project_name=project_name,
)
async def async_service_success_hook(
self,
payload: ServiceLoggerPayload,
parent_otel_span: Optional[Span] = None,
start_time: Optional[Union[datetime, float]] = None,
end_time: Optional[Union[datetime, float]] = None,
event_metadata: Optional[dict] = None,
):
"""Arize is used mainly for LLM I/O tracing, sending router+caching metrics adds bloat to arize logs"""
pass
async def async_service_failure_hook(
self,
payload: ServiceLoggerPayload,
error: Optional[str] = "",
parent_otel_span: Optional[Span] = None,
start_time: Optional[Union[datetime, float]] = None,
end_time: Optional[Union[float, datetime]] = None,
event_metadata: Optional[dict] = None,
):
"""Arize is used mainly for LLM I/O tracing, sending router+caching metrics adds bloat to arize logs"""
pass
# def create_litellm_proxy_request_started_span(
# self,
# start_time: datetime,
# headers: dict,
# ):
# """Arize is used mainly for LLM I/O tracing, sending Proxy Server Request adds bloat to arize logs"""
# pass
async def async_health_check(self):
"""
Performs a health check for Arize integration.
Returns:
dict: Health check result with status and message
"""
try:
config = self.get_arize_config()
if not config.space_id and not config.space_key:
return {
"status": "unhealthy",
"error_message": "ARIZE_SPACE_ID or ARIZE_SPACE_KEY environment variable not set",
}
if not config.api_key:
return {
"status": "unhealthy",
"error_message": "ARIZE_API_KEY environment variable not set",
}
return {
"status": "healthy",
"message": "Arize credentials are configured properly",
}
except Exception as e:
return {
"status": "unhealthy",
"error_message": f"Arize health check failed: {str(e)}",
}
def construct_dynamic_otel_headers(
self, standard_callback_dynamic_params: StandardCallbackDynamicParams
) -> Optional[dict]:
"""
Construct dynamic Arize headers from standard callback dynamic params
This is used for team/key based logging.
Returns:
dict: A dictionary of dynamic Arize headers
"""
dynamic_headers = {}
#########################################################
# `arize-space-id` handling
        # both `arize_space_id` and the legacy `arize_space_key` params
        # map to this header
#########################################################
if standard_callback_dynamic_params.get("arize_space_id"):
dynamic_headers["arize-space-id"] = standard_callback_dynamic_params.get(
"arize_space_id"
)
if standard_callback_dynamic_params.get("arize_space_key"):
dynamic_headers["arize-space-id"] = standard_callback_dynamic_params.get(
"arize_space_key"
)
#########################################################
# `api_key` handling
#########################################################
if standard_callback_dynamic_params.get("arize_api_key"):
dynamic_headers["api_key"] = standard_callback_dynamic_params.get(
"arize_api_key"
)
return dynamic_headers


@@ -0,0 +1,360 @@
import os
from typing import TYPE_CHECKING, Any, Optional, Union
from litellm._logging import verbose_logger
from litellm.integrations.arize import _utils
from litellm.integrations.arize._utils import ArizeOTELAttributes
from litellm.types.integrations.arize_phoenix import ArizePhoenixConfig
if TYPE_CHECKING:
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.trace import Span as _Span
from opentelemetry.trace import SpanKind
from litellm.integrations.opentelemetry import OpenTelemetry as _OpenTelemetry
from litellm.integrations.opentelemetry import (
OpenTelemetryConfig as _OpenTelemetryConfig,
)
from litellm.types.integrations.arize import Protocol as _Protocol
Protocol = _Protocol
OpenTelemetryConfig = _OpenTelemetryConfig
Span = Union[_Span, Any]
OpenTelemetry = _OpenTelemetry
else:
Protocol = Any
OpenTelemetryConfig = Any
Span = Any
TracerProvider = Any
SpanKind = Any
# Import OpenTelemetry at runtime
try:
from litellm.integrations.opentelemetry import OpenTelemetry
except ImportError:
OpenTelemetry = None # type: ignore
ARIZE_HOSTED_PHOENIX_ENDPOINT = "https://otlp.arize.com/v1/traces"
class ArizePhoenixLogger(OpenTelemetry): # type: ignore
"""
Arize Phoenix logger that sends traces to a Phoenix endpoint.
Creates its own dedicated TracerProvider so it can coexist with the
generic ``otel`` callback (or any other OTEL-based integration) without
fighting over the global ``opentelemetry.trace`` TracerProvider singleton.
"""
def _init_tracing(self, tracer_provider):
"""
Override to always create a *private* TracerProvider for Arize Phoenix.
The base ``OpenTelemetry._init_tracing`` falls back to the global
TracerProvider when one already exists. That causes whichever
integration initialises second to silently reuse the first one's
exporter, so spans only reach one destination.
By creating our own provider we guarantee Arize Phoenix always gets
its own exporter pipeline, regardless of initialisation order.
"""
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.trace import SpanKind
if tracer_provider is not None:
# Explicitly supplied (e.g. in tests) — honour it.
self.tracer = tracer_provider.get_tracer("litellm")
self.span_kind = SpanKind
return
# Always create a dedicated provider — never touch the global one.
provider = TracerProvider(resource=self._get_litellm_resource(self.config))
provider.add_span_processor(self._get_span_processor())
self.tracer = provider.get_tracer("litellm")
self.span_kind = SpanKind
verbose_logger.debug(
"ArizePhoenixLogger: Created dedicated TracerProvider "
"(endpoint=%s, exporter=%s)",
self.config.endpoint,
self.config.exporter,
)
def _init_otel_logger_on_litellm_proxy(self):
"""
Override: Arize Phoenix should NOT overwrite the proxy's
``open_telemetry_logger``. That attribute is reserved for the
primary ``otel`` callback which handles proxy-level parent spans.
"""
pass
def set_attributes(self, span: Span, kwargs, response_obj: Optional[Any]):
ArizePhoenixLogger.set_arize_phoenix_attributes(span, kwargs, response_obj)
return
@staticmethod
def set_arize_phoenix_attributes(span: Span, kwargs, response_obj):
from litellm.integrations.opentelemetry_utils.base_otel_llm_obs_attributes import (
safe_set_attribute,
)
_utils.set_attributes(span, kwargs, response_obj, ArizeOTELAttributes)
# Dynamic project name: check metadata first, then fall back to env var config
dynamic_project_name = ArizePhoenixLogger._get_dynamic_project_name(kwargs)
if dynamic_project_name:
safe_set_attribute(span, "openinference.project.name", dynamic_project_name)
else:
# Fall back to static config from env var
config = ArizePhoenixLogger.get_arize_phoenix_config()
if config.project_name:
safe_set_attribute(
span, "openinference.project.name", config.project_name
)
return
@staticmethod
def _get_dynamic_project_name(kwargs) -> Optional[str]:
"""
Retrieve dynamic Phoenix project name from request metadata.
Users can set `metadata.phoenix_project_name` in their request to route
traces to different Phoenix projects dynamically.
"""
standard_logging_payload = kwargs.get("standard_logging_object")
if isinstance(standard_logging_payload, dict):
metadata = standard_logging_payload.get("metadata")
if isinstance(metadata, dict):
project_name = metadata.get("phoenix_project_name")
if project_name:
return str(project_name)
# Also check litellm_params.metadata for SDK usage
litellm_params = kwargs.get("litellm_params")
if isinstance(litellm_params, dict):
metadata = litellm_params.get("metadata") or {}
else:
metadata = {}
if isinstance(metadata, dict):
project_name = metadata.get("phoenix_project_name")
if project_name:
return str(project_name)
return None
def _get_phoenix_context(self, kwargs):
"""
Build a trace context for Phoenix's dedicated TracerProvider.
The base ``_get_span_context`` returns parent spans from the global
TracerProvider (the ``otel`` callback). Those spans live on a
*different* TracerProvider, so they won't appear in Phoenix — using
them as parents just creates broken links.
Instead we:
1. Honour an incoming ``traceparent`` HTTP header (distributed tracing).
2. In proxy mode, create our *own* parent span on Phoenix's tracer
so the hierarchy is visible end-to-end inside Phoenix.
3. In SDK (non-proxy) mode, just return (None, None) for a root span.
"""
from opentelemetry import trace
litellm_params = kwargs.get("litellm_params", {}) or {}
proxy_server_request = litellm_params.get("proxy_server_request", {}) or {}
headers = proxy_server_request.get("headers", {}) or {}
# Propagate distributed trace context if the caller sent a traceparent
traceparent_ctx = (
self.get_traceparent_from_header(headers=headers)
if headers.get("traceparent")
else None
)
is_proxy_mode = bool(proxy_server_request)
if is_proxy_mode:
# Create a parent span on Phoenix's own tracer so both parent
# and child are exported to Phoenix.
start_time_val = kwargs.get("start_time", kwargs.get("api_call_start_time"))
parent_span = self.tracer.start_span(
name="litellm_proxy_request",
start_time=self._to_ns(start_time_val)
if start_time_val is not None
else None,
context=traceparent_ctx,
kind=self.span_kind.SERVER,
)
ctx = trace.set_span_in_context(parent_span)
return ctx, parent_span
# SDK mode — no parent span needed
return traceparent_ctx, None
def _handle_success(self, kwargs, response_obj, start_time, end_time):
"""
Override to always create spans on ArizePhoenixLogger's dedicated TracerProvider.
The base class's ``_get_span_context`` would find the parent span created by
the ``otel`` callback on the *global* TracerProvider. That span is invisible
in Phoenix (different exporter pipeline), so we ignore it and build our own
hierarchy via ``_get_phoenix_context``.
"""
from opentelemetry.trace import Status, StatusCode
verbose_logger.debug(
"ArizePhoenixLogger: Logging kwargs: %s, OTEL config settings=%s",
kwargs,
self.config,
)
ctx, parent_span = self._get_phoenix_context(kwargs)
# Create litellm_request span (child of our parent when in proxy mode)
span = self.tracer.start_span(
name=self._get_span_name(kwargs),
start_time=self._to_ns(start_time),
context=ctx,
)
span.set_status(Status(StatusCode.OK))
self.set_attributes(span, kwargs, response_obj)
# Raw-request sub-span (if enabled) — must be created before
# ending the parent span so the hierarchy is valid.
self._maybe_log_raw_request(kwargs, response_obj, start_time, end_time, span)
span.end(end_time=self._to_ns(end_time))
# Guardrail span
self._create_guardrail_span(kwargs=kwargs, context=ctx)
# Annotate and close our proxy parent span
if parent_span is not None:
parent_span.set_status(Status(StatusCode.OK))
self.set_attributes(parent_span, kwargs, response_obj)
parent_span.end(end_time=self._to_ns(end_time))
# Metrics & cost recording
self._record_metrics(kwargs, response_obj, start_time, end_time)
# Semantic logs
if self.config.enable_events:
self._emit_semantic_logs(kwargs, response_obj, span)
def _handle_failure(self, kwargs, response_obj, start_time, end_time):
"""
Override to always create failure spans on ArizePhoenixLogger's dedicated
TracerProvider. Mirrors ``_handle_success`` but sets ERROR status.
"""
from opentelemetry.trace import Status, StatusCode
verbose_logger.debug(
"ArizePhoenixLogger: Failure - Logging kwargs: %s, OTEL config settings=%s",
kwargs,
self.config,
)
ctx, parent_span = self._get_phoenix_context(kwargs)
# Create litellm_request span (child of our parent when in proxy mode)
span = self.tracer.start_span(
name=self._get_span_name(kwargs),
start_time=self._to_ns(start_time),
context=ctx,
)
span.set_status(Status(StatusCode.ERROR))
self.set_attributes(span, kwargs, response_obj)
self._record_exception_on_span(span=span, kwargs=kwargs)
span.end(end_time=self._to_ns(end_time))
# Guardrail span
self._create_guardrail_span(kwargs=kwargs, context=ctx)
# Annotate and close our proxy parent span
if parent_span is not None:
parent_span.set_status(Status(StatusCode.ERROR))
self.set_attributes(parent_span, kwargs, response_obj)
self._record_exception_on_span(span=parent_span, kwargs=kwargs)
parent_span.end(end_time=self._to_ns(end_time))
@staticmethod
def get_arize_phoenix_config() -> ArizePhoenixConfig:
"""
Retrieves the Arize Phoenix configuration based on environment variables.
Returns:
ArizePhoenixConfig: A Pydantic model containing Arize Phoenix configuration.
"""
api_key = os.environ.get("PHOENIX_API_KEY", None)
        collector_endpoint = os.environ.get("PHOENIX_COLLECTOR_HTTP_ENDPOINT", None)
        if not collector_endpoint:
            # Fall back to the generic (typically gRPC) collector endpoint
            collector_endpoint = os.environ.get("PHOENIX_COLLECTOR_ENDPOINT", None)
endpoint = None
protocol: Protocol = "otlp_http"
if collector_endpoint:
# Parse the endpoint to determine protocol
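            # 4317 is the conventional OTLP/gRPC port; an explicit grpc://
            # scheme likewise selects the gRPC exporter.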
if collector_endpoint.startswith("grpc://") or (
":4317" in collector_endpoint and "/v1/traces" not in collector_endpoint
):
endpoint = collector_endpoint
protocol = "otlp_grpc"
else:
# Phoenix Cloud endpoints (app.phoenix.arize.com) include the space in the URL
if "app.phoenix.arize.com" in collector_endpoint:
endpoint = collector_endpoint
protocol = "otlp_http"
# For other HTTP endpoints, ensure they have the correct path
elif "/v1/traces" not in collector_endpoint:
if collector_endpoint.endswith("/v1"):
endpoint = collector_endpoint + "/traces"
elif collector_endpoint.endswith("/"):
endpoint = f"{collector_endpoint}v1/traces"
else:
endpoint = f"{collector_endpoint}/v1/traces"
else:
endpoint = collector_endpoint
protocol = "otlp_http"
else:
            # No endpoint specified: assume a self-hosted Phoenix instance
endpoint = "http://localhost:6006/v1/traces"
protocol = "otlp_http"
verbose_logger.debug(
f"No PHOENIX_COLLECTOR_ENDPOINT found, using default local Phoenix endpoint: {endpoint}"
)
otlp_auth_headers = None
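        # OTLP exporters expect auth headers as comma-separated "Name=Value" pairs.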
if api_key is not None:
otlp_auth_headers = f"Authorization=Bearer {api_key}"
elif "app.phoenix.arize.com" in endpoint:
# Phoenix Cloud requires an API key
raise ValueError(
"PHOENIX_API_KEY must be set when using Phoenix Cloud (app.phoenix.arize.com)."
)
project_name = os.environ.get("PHOENIX_PROJECT_NAME", "default")
return ArizePhoenixConfig(
otlp_auth_headers=otlp_auth_headers,
protocol=protocol,
endpoint=endpoint,
project_name=project_name,
)
    # NOTE: additional proxy-server spans cannot be suppressed here; the previous override methods were removed.
async def async_health_check(self):
config = self.get_arize_phoenix_config()
if not config.otlp_auth_headers:
return {
"status": "unhealthy",
"error_message": "PHOENIX_API_KEY environment variable not set",
}
return {
"status": "healthy",
"message": "Arize-Phoenix credentials are configured properly",
}


@@ -0,0 +1,108 @@
"""
Arize Phoenix API client for fetching prompt versions from Arize Phoenix.
"""
from typing import Any, Dict, Optional
from litellm.llms.custom_httpx.http_handler import HTTPHandler
class ArizePhoenixClient:
"""
Client for interacting with Arize Phoenix API to fetch prompt versions.
Supports:
- Authentication with Bearer tokens
- Fetching prompt versions
- Direct API base URL configuration
"""
def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None):
"""
Initialize the Arize Phoenix client.
Args:
api_key: Arize Phoenix API token
api_base: Base URL for the Arize Phoenix API (e.g., 'https://app.phoenix.arize.com/s/workspace/v1')
"""
self.api_key = api_key
self.api_base = api_base
if not self.api_key:
raise ValueError("api_key is required")
if not self.api_base:
raise ValueError("api_base is required")
# Set up authentication headers
self.headers = {
"Authorization": f"Bearer {self.api_key}",
"Accept": "application/json",
}
# Initialize HTTPHandler
self.http_handler = HTTPHandler(disable_default_headers=True)
def get_prompt_version(self, prompt_version_id: str) -> Optional[Dict[str, Any]]:
"""
Fetch a prompt version from Arize Phoenix.
Args:
prompt_version_id: The ID of the prompt version to fetch
Returns:
Dictionary containing prompt version data, or None if not found
"""
url = f"{self.api_base}/v1/prompt_versions/{prompt_version_id}"
try:
# Use the underlying httpx client directly to avoid query param extraction
response = self.http_handler.get(url, headers=self.headers)
response.raise_for_status()
data = response.json()
return data.get("data")
except Exception as e:
# Check if it's an HTTP error
response = getattr(e, "response", None)
if response is not None and hasattr(response, "status_code"):
if response.status_code == 404:
return None
elif response.status_code == 403:
raise Exception(
f"Access denied to prompt version '{prompt_version_id}'. Check your Arize Phoenix permissions."
)
elif response.status_code == 401:
raise Exception(
"Authentication failed. Check your Arize Phoenix API key and permissions."
)
else:
raise Exception(
f"Failed to fetch prompt version '{prompt_version_id}': {e}"
)
else:
raise Exception(
f"Error fetching prompt version '{prompt_version_id}': {e}"
)
def test_connection(self) -> bool:
"""
Test the connection to the Arize Phoenix API.
Returns:
True if connection is successful, False otherwise
"""
try:
# Try to access the prompt_versions endpoint to test connection
url = f"{self.api_base}/prompt_versions"
response = self.http_handler.client.get(url, headers=self.headers)
response.raise_for_status()
return True
except Exception:
return False
def close(self):
"""Close the HTTP handler to free resources."""
if hasattr(self, "http_handler"):
self.http_handler.close()


@@ -0,0 +1,488 @@
"""
Arize Phoenix prompt manager that integrates with LiteLLM's prompt management system.
Fetches prompt versions from Arize Phoenix and provides workspace-based access control.
"""
from typing import Any, Dict, List, Optional, Tuple, Union
from jinja2 import DictLoader, Environment, select_autoescape
from litellm.integrations.custom_prompt_management import CustomPromptManagement
from litellm.integrations.prompt_management_base import (
PromptManagementBase,
PromptManagementClient,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.prompts.init_prompts import PromptSpec
from litellm.types.utils import StandardCallbackDynamicParams
from .arize_phoenix_client import ArizePhoenixClient
class ArizePhoenixPromptTemplate:
"""
Represents a prompt template loaded from Arize Phoenix.
"""
def __init__(
self,
template_id: str,
messages: List[Dict[str, Any]],
metadata: Dict[str, Any],
model: Optional[str] = None,
):
self.template_id = template_id
self.messages = messages
self.metadata = metadata
self.model = model or metadata.get("model_name")
self.model_provider = metadata.get("model_provider")
self.temperature = metadata.get("temperature")
self.max_tokens = metadata.get("max_tokens")
self.invocation_parameters = metadata.get("invocation_parameters", {})
self.description = metadata.get("description", "")
self.template_format = metadata.get("template_format", "MUSTACHE")
def __repr__(self):
return (
f"ArizePhoenixPromptTemplate(id='{self.template_id}', model='{self.model}')"
)
class ArizePhoenixTemplateManager:
"""
Manager for loading and rendering prompt templates from Arize Phoenix.
Supports:
- Fetching prompt versions from Arize Phoenix API
- Workspace-based access control through Arize Phoenix permissions
- Mustache/Handlebars-style templating (using Jinja2)
- Model configuration and invocation parameters
- Multi-message chat templates
"""
def __init__(
self,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
prompt_id: Optional[str] = None,
):
self.api_key = api_key
self.api_base = api_base
self.prompt_id = prompt_id
self.prompts: Dict[str, ArizePhoenixPromptTemplate] = {}
self.arize_client = ArizePhoenixClient(
api_key=self.api_key, api_base=self.api_base
)
self.jinja_env = Environment(
loader=DictLoader({}),
autoescape=select_autoescape(["html", "xml"]),
            # Jinja2's default delimiters already match the Mustache-style {{variable}} syntax
variable_start_string="{{",
variable_end_string="}}",
block_start_string="{%",
block_end_string="%}",
comment_start_string="{#",
comment_end_string="#}",
)
# Load prompt from Arize Phoenix if prompt_id is provided
if self.prompt_id:
self._load_prompt_from_arize(self.prompt_id)
def _load_prompt_from_arize(self, prompt_version_id: str) -> None:
"""Load a specific prompt version from Arize Phoenix."""
try:
# Fetch the prompt version from Arize Phoenix
prompt_data = self.arize_client.get_prompt_version(prompt_version_id)
if prompt_data:
template = self._parse_prompt_data(prompt_data, prompt_version_id)
self.prompts[prompt_version_id] = template
else:
raise ValueError(f"Prompt version '{prompt_version_id}' not found")
except Exception as e:
raise Exception(
f"Failed to load prompt version '{prompt_version_id}' from Arize Phoenix: {e}"
)
def _parse_prompt_data(
self, data: Dict[str, Any], prompt_version_id: str
) -> ArizePhoenixPromptTemplate:
"""Parse Arize Phoenix prompt data and extract messages and metadata."""
template_data = data.get("template", {})
messages = template_data.get("messages", [])
# Extract invocation parameters
invocation_params = data.get("invocation_parameters", {})
provider_params = {}
# Extract provider-specific parameters
if "openai" in invocation_params:
provider_params = invocation_params["openai"]
elif "anthropic" in invocation_params:
provider_params = invocation_params["anthropic"]
else:
# Try to find any nested provider params
for key, value in invocation_params.items():
if isinstance(value, dict):
provider_params = value
break
# Build metadata dictionary
metadata = {
"model_name": data.get("model_name"),
"model_provider": data.get("model_provider"),
"description": data.get("description", ""),
"template_type": data.get("template_type"),
"template_format": data.get("template_format", "MUSTACHE"),
"invocation_parameters": invocation_params,
"temperature": provider_params.get("temperature"),
"max_tokens": provider_params.get("max_tokens"),
}
return ArizePhoenixPromptTemplate(
template_id=prompt_version_id,
messages=messages,
metadata=metadata,
)
def render_template(
self, template_id: str, variables: Optional[Dict[str, Any]] = None
) -> List[AllMessageValues]:
"""Render a template with the given variables and return formatted messages."""
if template_id not in self.prompts:
raise ValueError(f"Template '{template_id}' not found")
template = self.prompts[template_id]
rendered_messages: List[AllMessageValues] = []
for message in template.messages:
role = message.get("role", "user")
content_parts = message.get("content", [])
# Render each content part
rendered_content_parts = []
for part in content_parts:
if part.get("type") == "text":
text = part.get("text", "")
# Render the text with Jinja2 (Mustache-style)
jinja_template = self.jinja_env.from_string(text)
rendered_text = jinja_template.render(**(variables or {}))
rendered_content_parts.append(rendered_text)
                else:
                    # Non-text parts are not templated; stringify them so the
                    # join below does not fail on non-string content
                    rendered_content_parts.append(str(part))
# Combine rendered content
final_content = " ".join(rendered_content_parts)
rendered_messages.append(
{"role": role, "content": final_content} # type: ignore
)
return rendered_messages
def get_template(self, template_id: str) -> Optional[ArizePhoenixPromptTemplate]:
"""Get a template by ID."""
return self.prompts.get(template_id)
def list_templates(self) -> List[str]:
"""List all available template IDs."""
return list(self.prompts.keys())
class ArizePhoenixPromptManager(CustomPromptManagement):
"""
Arize Phoenix prompt manager that integrates with LiteLLM's prompt management system.
This class enables using prompt versions from Arize Phoenix with the
litellm completion() function by implementing the PromptManagementBase interface.
    Usage:
        # Use with completion
        response = litellm.completion(
            model="arize/gpt-4o",
            prompt_id="UHJvbXB0VmVyc2lvbjox",
            prompt_variables={"question": "What is AI?"},
            api_key="your-arize-phoenix-token",
            api_base="https://app.phoenix.arize.com/s/your-workspace/v1",
            messages=[{"role": "user", "content": "This will be combined with the prompt"}]
        )
"""
def __init__(
self,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
prompt_id: Optional[str] = None,
**kwargs,
):
super().__init__(**kwargs)
self.api_key = api_key
self.api_base = api_base
self.prompt_id = prompt_id
self._prompt_manager: Optional[ArizePhoenixTemplateManager] = None
@property
def integration_name(self) -> str:
"""Integration name used in model names like 'arize/gpt-4o'."""
return "arize"
@property
def prompt_manager(self) -> ArizePhoenixTemplateManager:
"""Get or create the prompt manager instance."""
if self._prompt_manager is None:
self._prompt_manager = ArizePhoenixTemplateManager(
api_key=self.api_key,
api_base=self.api_base,
prompt_id=self.prompt_id,
)
return self._prompt_manager
def get_prompt_template(
self,
prompt_id: str,
prompt_variables: Optional[Dict[str, Any]] = None,
) -> Tuple[List[AllMessageValues], Dict[str, Any]]:
"""
Get a prompt template and render it with variables.
Args:
prompt_id: The ID of the prompt version
prompt_variables: Variables to substitute in the template
Returns:
Tuple of (rendered_messages, metadata)
"""
template = self.prompt_manager.get_template(prompt_id)
if not template:
raise ValueError(f"Prompt template '{prompt_id}' not found")
# Render the template
rendered_messages = self.prompt_manager.render_template(
prompt_id, prompt_variables or {}
)
# Extract metadata
metadata = {
"model": template.model,
"temperature": template.temperature,
"max_tokens": template.max_tokens,
}
# Add additional invocation parameters
invocation_params = template.invocation_parameters
provider_params = {}
if "openai" in invocation_params:
provider_params = invocation_params["openai"]
elif "anthropic" in invocation_params:
provider_params = invocation_params["anthropic"]
# Add any additional parameters
for key, value in provider_params.items():
if key not in metadata:
metadata[key] = value
return rendered_messages, metadata
def pre_call_hook(
self,
user_id: Optional[str],
messages: List[AllMessageValues],
function_call: Optional[Union[Dict[str, Any], str]] = None,
litellm_params: Optional[Dict[str, Any]] = None,
prompt_id: Optional[str] = None,
prompt_variables: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Tuple[List[AllMessageValues], Optional[Dict[str, Any]]]:
"""
Pre-call hook that processes the prompt template before making the LLM call.
"""
if not prompt_id:
return messages, litellm_params
try:
# Get the rendered messages and metadata
rendered_messages, prompt_metadata = self.get_prompt_template(
prompt_id, prompt_variables
)
# Merge rendered messages with existing messages
if rendered_messages:
# Prepend rendered messages to existing messages
final_messages = rendered_messages + messages
else:
final_messages = messages
# Update litellm_params with prompt metadata
if litellm_params is None:
litellm_params = {}
# Apply model and parameters from prompt metadata
if prompt_metadata.get("model") and not self.ignore_prompt_manager_model:
litellm_params["model"] = prompt_metadata["model"]
if not self.ignore_prompt_manager_optional_params:
for param in [
"temperature",
"max_tokens",
"top_p",
"frequency_penalty",
"presence_penalty",
]:
if param in prompt_metadata:
litellm_params[param] = prompt_metadata[param]
return final_messages, litellm_params
except Exception as e:
# Log error but don't fail the call
import litellm
litellm._logging.verbose_proxy_logger.error(
f"Error in Arize Phoenix prompt pre_call_hook: {e}"
)
return messages, litellm_params
def get_available_prompts(self) -> List[str]:
"""Get list of available prompt IDs."""
return self.prompt_manager.list_templates()
def reload_prompts(self) -> None:
"""Reload prompts from Arize Phoenix."""
if self.prompt_id:
self._prompt_manager = None # Reset to force reload
self.prompt_manager # This will trigger reload
def should_run_prompt_management(
self,
prompt_id: Optional[str],
prompt_spec: Optional[PromptSpec],
dynamic_callback_params: StandardCallbackDynamicParams,
) -> bool:
"""
Determine if prompt management should run based on the prompt_id.
For Arize Phoenix, we always return True and handle the prompt loading
in the _compile_prompt_helper method.
"""
return True
def _compile_prompt_helper(
self,
prompt_id: Optional[str],
prompt_spec: Optional[PromptSpec],
prompt_variables: Optional[dict],
dynamic_callback_params: StandardCallbackDynamicParams,
prompt_label: Optional[str] = None,
prompt_version: Optional[int] = None,
) -> PromptManagementClient:
"""
Compile an Arize Phoenix prompt template into a PromptManagementClient structure.
This method:
1. Loads the prompt version from Arize Phoenix
2. Renders it with the provided variables
3. Returns formatted chat messages
4. Extracts model and optional parameters from metadata
"""
if prompt_id is None:
raise ValueError("prompt_id is required for Arize Phoenix prompt manager")
try:
# Load the prompt from Arize Phoenix if not already loaded
if prompt_id not in self.prompt_manager.prompts:
self.prompt_manager._load_prompt_from_arize(prompt_id)
# Get the rendered messages and metadata
rendered_messages, prompt_metadata = self.get_prompt_template(
prompt_id, prompt_variables
)
# Extract model from metadata (if specified)
template_model = prompt_metadata.get("model")
# Extract optional parameters from metadata
optional_params = {}
for param in [
"temperature",
"max_tokens",
"top_p",
"frequency_penalty",
"presence_penalty",
]:
if param in prompt_metadata:
optional_params[param] = prompt_metadata[param]
return PromptManagementClient(
prompt_id=prompt_id,
prompt_template=rendered_messages,
prompt_template_model=template_model,
prompt_template_optional_params=optional_params,
completed_messages=None,
)
except Exception as e:
raise ValueError(f"Error compiling prompt '{prompt_id}': {e}")
async def async_compile_prompt_helper(
self,
prompt_id: Optional[str],
prompt_variables: Optional[dict],
dynamic_callback_params: StandardCallbackDynamicParams,
prompt_spec: Optional[PromptSpec] = None,
prompt_label: Optional[str] = None,
prompt_version: Optional[int] = None,
) -> PromptManagementClient:
"""
Async version of compile prompt helper. Since Arize Phoenix operations are synchronous,
this simply delegates to the sync version.
"""
if prompt_id is None:
raise ValueError("prompt_id is required for Arize Phoenix prompt manager")
return self._compile_prompt_helper(
prompt_id=prompt_id,
prompt_spec=prompt_spec,
prompt_variables=prompt_variables,
dynamic_callback_params=dynamic_callback_params,
prompt_label=prompt_label,
prompt_version=prompt_version,
)
def get_chat_completion_prompt(
self,
model: str,
messages: List[AllMessageValues],
non_default_params: dict,
prompt_id: Optional[str],
prompt_variables: Optional[dict],
dynamic_callback_params: StandardCallbackDynamicParams,
prompt_spec: Optional[PromptSpec] = None,
prompt_label: Optional[str] = None,
prompt_version: Optional[int] = None,
ignore_prompt_manager_model: Optional[bool] = False,
ignore_prompt_manager_optional_params: Optional[bool] = False,
) -> Tuple[str, List[AllMessageValues], dict]:
"""
Get chat completion prompt from Arize Phoenix and return processed model, messages, and parameters.
"""
return PromptManagementBase.get_chat_completion_prompt(
self,
model,
messages,
non_default_params,
prompt_id,
prompt_variables,
dynamic_callback_params,
prompt_spec=prompt_spec,
prompt_label=prompt_label,
prompt_version=prompt_version,
ignore_prompt_manager_model=ignore_prompt_manager_model,
ignore_prompt_manager_optional_params=ignore_prompt_manager_optional_params,
)