chore: initial public snapshot for github upload
@@ -0,0 +1,210 @@
# Arize Phoenix Prompt Management Integration

This integration enables using prompt versions from Arize Phoenix with LiteLLM's `completion` function.

## Features

- Fetch prompt versions from the Arize Phoenix API
- Workspace-based access control through Arize Phoenix permissions
- Mustache/Handlebars-style variable templating (`{{variable}}`)
- Support for multi-message chat templates
- Automatic model and parameter configuration from prompt metadata
- OpenAI and Anthropic provider parameter support

## Configuration

Configure Arize Phoenix access in your application:

```python
import litellm

# Configure Arize Phoenix access
# api_base should include your workspace, e.g., "https://app.phoenix.arize.com/s/your-workspace/v1"
api_key = "your-arize-phoenix-token"
api_base = "https://app.phoenix.arize.com/s/krrishdholakia/v1"
```

## Usage

### Basic Usage

```python
import litellm

# Use with completion
response = litellm.completion(
    model="arize/gpt-4o",
    prompt_id="UHJvbXB0VmVyc2lvbjox",  # Your prompt version ID
    prompt_variables={"question": "What is artificial intelligence?"},
    api_key="your-arize-phoenix-token",
    api_base="https://app.phoenix.arize.com/s/krrishdholakia/v1",
)

print(response.choices[0].message.content)
```

### With Additional Messages

You can also combine prompt templates with additional messages:

```python
response = litellm.completion(
    model="arize/gpt-4o",
    prompt_id="UHJvbXB0VmVyc2lvbjox",
    prompt_variables={"question": "Explain quantum computing"},
    api_key="your-arize-phoenix-token",
    api_base="https://app.phoenix.arize.com/s/krrishdholakia/v1",
    messages=[
        {"role": "user", "content": "Please keep your response under 100 words."}
    ],
)
```

### Direct Manager Usage

You can also use the prompt manager directly:

```python
from litellm.integrations.arize.arize_phoenix_prompt_manager import ArizePhoenixPromptManager

# Initialize the manager
manager = ArizePhoenixPromptManager(
    api_key="your-arize-phoenix-token",
    api_base="https://app.phoenix.arize.com/s/krrishdholakia/v1",
    prompt_id="UHJvbXB0VmVyc2lvbjox",
)

# Get rendered messages
messages, metadata = manager.get_prompt_template(
    prompt_id="UHJvbXB0VmVyc2lvbjox",
    prompt_variables={"question": "What is machine learning?"},
)

print("Rendered messages:", messages)
print("Metadata:", metadata)
```

## Prompt Format

Arize Phoenix prompts support the following structure:

```json
{
  "data": {
    "description": "A chatbot prompt",
    "model_provider": "OPENAI",
    "model_name": "gpt-4o",
    "template": {
      "type": "chat",
      "messages": [
        {
          "role": "system",
          "content": [
            {
              "type": "text",
              "text": "You are a chatbot"
            }
          ]
        },
        {
          "role": "user",
          "content": [
            {
              "type": "text",
              "text": "{{question}}"
            }
          ]
        }
      ]
    },
    "template_type": "CHAT",
    "template_format": "MUSTACHE",
    "invocation_parameters": {
      "type": "openai",
      "openai": {
        "temperature": 1.0
      }
    },
    "id": "UHJvbXB0VmVyc2lvbjox"
  }
}
```

### Variable Substitution

Variables in your prompt templates use Mustache/Handlebars syntax:

- `{{variable_name}}` - Simple variable substitution

Example:

```
Template: "Hello {{name}}, your order {{order_id}} is ready!"
Variables: {"name": "Alice", "order_id": "12345"}
Result: "Hello Alice, your order 12345 is ready!"
```
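
A minimal sketch of this substitution in plain Python, assuming only simple `{{variable}}` placeholders (no Mustache sections, partials, or escaping). `render_mustache` is an illustrative helper, not part of the integration:

```python
import re

def render_mustache(template: str, variables: dict) -> str:
    # Replace each {{name}} placeholder with its value from `variables`.
    # Placeholders with no matching variable are left untouched.
    def substitute(match: re.Match) -> str:
        key = match.group(1).strip()
        return str(variables.get(key, match.group(0)))

    return re.sub(r"\{\{\s*([\w.]+)\s*\}\}", substitute, template)

rendered = render_mustache(
    "Hello {{name}}, your order {{order_id}} is ready!",
    {"name": "Alice", "order_id": "12345"},
)
print(rendered)  # Hello Alice, your order 12345 is ready!
```
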

## API Reference

### ArizePhoenixPromptManager

Main class for managing Arize Phoenix prompts.

**Methods:**
- `get_prompt_template(prompt_id, prompt_variables)` - Get and render a prompt template
- `get_available_prompts()` - List available prompt IDs
- `reload_prompts()` - Reload prompts from Arize Phoenix

### ArizePhoenixClient

Low-level client for the Arize Phoenix API.

**Methods:**
- `get_prompt_version(prompt_version_id)` - Fetch a prompt version
- `test_connection()` - Test the API connection

## Error Handling

The integration provides detailed error messages:

- **404**: Prompt version not found
- **401**: Authentication failed (check your access token)
- **403**: Access denied (check workspace permissions)

Example:

```python
try:
    response = litellm.completion(
        model="arize/gpt-4o",
        prompt_id="invalid-id",
        api_key="your-arize-phoenix-token",
        api_base="https://app.phoenix.arize.com/s/krrishdholakia/v1",
    )
except Exception as e:
    print(f"Error: {e}")
```

## Getting Your Prompt Version ID and API Base

1. Log in to Arize Phoenix
2. Navigate to your workspace
3. Go to the Prompts section
4. Select a prompt version
5. The ID will be in the URL: `/s/{workspace}/v1/prompt_versions/{PROMPT_VERSION_ID}`

Your `api_base` should be: `https://app.phoenix.arize.com/s/{workspace}/v1`

For example:
- Workspace: `krrishdholakia`
- API Base: `https://app.phoenix.arize.com/s/krrishdholakia/v1`
- Prompt Version ID: `UHJvbXB0VmVyc2lvbjox`
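
Putting those pieces together, the `api_base` for any workspace is just the workspace slug dropped into the path. The helper below is a hypothetical convenience, not part of the integration:

```python
def phoenix_api_base(workspace: str) -> str:
    # Arize Phoenix scopes its REST API under /s/{workspace}/v1.
    return f"https://app.phoenix.arize.com/s/{workspace}/v1"

print(phoenix_api_base("krrishdholakia"))
# https://app.phoenix.arize.com/s/krrishdholakia/v1
```
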

You can also fetch it via the API:

```bash
curl -L -X GET 'https://app.phoenix.arize.com/s/krrishdholakia/v1/prompt_versions/UHJvbXB0VmVyc2lvbjox' \
  -H 'Authorization: Bearer YOUR_TOKEN'
```
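
The body returned by that endpoint follows the structure shown under Prompt Format above. A sketch of pulling the model and template messages out of a fetched payload — `payload` here is a hard-coded stand-in for the parsed HTTP response, trimmed to the relevant keys:

```python
import json

# Stand-in for the parsed body of GET /v1/prompt_versions/{id},
# trimmed to the fields used below.
payload = json.loads("""
{
  "data": {
    "model_provider": "OPENAI",
    "model_name": "gpt-4o",
    "template": {
      "type": "chat",
      "messages": [
        {"role": "user", "content": [{"type": "text", "text": "{{question}}"}]}
      ]
    },
    "id": "UHJvbXB0VmVyc2lvbjox"
  }
}
""")

data = payload["data"]
messages = data["template"]["messages"]
print(data["model_name"])                 # gpt-4o
print(messages[0]["role"])                # user
print(messages[0]["content"][0]["text"])  # {{question}}
```
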

## Support

For issues or questions:
- LiteLLM Issues: https://github.com/BerriAI/litellm/issues
- Arize Phoenix Docs: https://docs.arize.com/phoenix

@@ -0,0 +1,52 @@
import os
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    from litellm.types.prompts.init_prompts import PromptLiteLLMParams, PromptSpec
    from litellm.integrations.custom_prompt_management import CustomPromptManagement

from litellm.types.prompts.init_prompts import SupportedPromptIntegrations

from .arize_phoenix_prompt_manager import ArizePhoenixPromptManager

# Global instances
global_arize_config: Optional[dict] = None


def prompt_initializer(
    litellm_params: "PromptLiteLLMParams", prompt_spec: "PromptSpec"
) -> "CustomPromptManagement":
    """
    Initialize a prompt from Arize Phoenix.
    """
    api_key = getattr(litellm_params, "api_key", None) or os.environ.get(
        "PHOENIX_API_KEY"
    )
    api_base = getattr(litellm_params, "api_base", None)
    prompt_id = getattr(litellm_params, "prompt_id", None)

    if not api_key or not api_base:
        raise ValueError(
            "api_key and api_base are required for the Arize Phoenix prompt integration"
        )

    return ArizePhoenixPromptManager(
        api_key=api_key,
        api_base=api_base,
        prompt_id=prompt_id,
        **litellm_params.model_dump(exclude={"api_key", "api_base", "prompt_id"}),
    )


prompt_initializer_registry = {
    SupportedPromptIntegrations.ARIZE_PHOENIX.value: prompt_initializer,
}

@@ -0,0 +1,502 @@
import json
from typing import TYPE_CHECKING, Any, Dict, Optional, Type

from typing_extensions import override

from litellm._logging import verbose_logger
from litellm.integrations.opentelemetry_utils.base_otel_llm_obs_attributes import (
    BaseLLMObsOTELAttributes,
    safe_set_attribute,
)
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.types.utils import StandardLoggingPayload

# These OpenInference attribute classes are referenced at runtime by the
# helpers below, so they must be imported at module level rather than only
# under TYPE_CHECKING (which would raise a NameError at call time).
from litellm.integrations._types.open_inference import (
    AudioAttributes,
    EmbeddingAttributes,
    ImageAttributes,
    MessageAttributes,
    OpenInferenceSpanKindValues,
    SpanAttributes,
)

if TYPE_CHECKING:
    from opentelemetry.trace import Span


class ArizeOTELAttributes(BaseLLMObsOTELAttributes):
    @staticmethod
    @override
    def set_messages(span: "Span", kwargs: Dict[str, Any]):
        messages = kwargs.get("messages")

        # for /chat/completions
        # https://docs.arize.com/arize/large-language-models/tracing/semantic-conventions
        if messages:
            last_message = messages[-1]
            safe_set_attribute(
                span,
                SpanAttributes.INPUT_VALUE,
                last_message.get("content", ""),
            )

            # LLM_INPUT_MESSAGES shows up under the `input_messages` tab on the span page.
            for idx, msg in enumerate(messages):
                prefix = f"{SpanAttributes.LLM_INPUT_MESSAGES}.{idx}"
                # Set the role per message.
                safe_set_attribute(
                    span, f"{prefix}.{MessageAttributes.MESSAGE_ROLE}", msg.get("role")
                )
                # Set the content per message.
                safe_set_attribute(
                    span,
                    f"{prefix}.{MessageAttributes.MESSAGE_CONTENT}",
                    msg.get("content", ""),
                )

    @staticmethod
    @override
    def set_response_output_messages(span: "Span", response_obj):
        """
        Sets output message attributes on the span from the LLM response.

        Args:
            span: The OpenTelemetry span to set attributes on
            response_obj: The response object containing choices with messages
        """
        from litellm.integrations._types.open_inference import (
            MessageAttributes,
            SpanAttributes,
        )

        for idx, choice in enumerate(response_obj.get("choices", [])):
            response_message = choice.get("message", {})
            safe_set_attribute(
                span,
                SpanAttributes.OUTPUT_VALUE,
                response_message.get("content", ""),
            )

            # This shows up under the `output_messages` tab on the span page.
            prefix = f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.{idx}"
            safe_set_attribute(
                span,
                f"{prefix}.{MessageAttributes.MESSAGE_ROLE}",
                response_message.get("role"),
            )
            safe_set_attribute(
                span,
                f"{prefix}.{MessageAttributes.MESSAGE_CONTENT}",
                response_message.get("content", ""),
            )


def _set_response_attributes(span: "Span", response_obj):
    """Helper to set response output and token usage attributes on the span."""

    if not hasattr(response_obj, "get"):
        return

    _set_choice_outputs(span, response_obj, MessageAttributes, SpanAttributes)
    _set_image_outputs(span, response_obj, ImageAttributes, SpanAttributes)
    _set_audio_outputs(span, response_obj, AudioAttributes, SpanAttributes)
    _set_embedding_outputs(span, response_obj, EmbeddingAttributes, SpanAttributes)
    _set_structured_outputs(span, response_obj, MessageAttributes, SpanAttributes)
    _set_usage_outputs(span, response_obj, SpanAttributes)


def _set_choice_outputs(span: "Span", response_obj, msg_attrs, span_attrs):
    for idx, choice in enumerate(response_obj.get("choices", [])):
        response_message = choice.get("message", {})
        safe_set_attribute(
            span,
            span_attrs.OUTPUT_VALUE,
            response_message.get("content", ""),
        )
        prefix = f"{span_attrs.LLM_OUTPUT_MESSAGES}.{idx}"
        safe_set_attribute(
            span,
            f"{prefix}.{msg_attrs.MESSAGE_ROLE}",
            response_message.get("role"),
        )
        safe_set_attribute(
            span,
            f"{prefix}.{msg_attrs.MESSAGE_CONTENT}",
            response_message.get("content", ""),
        )


def _set_image_outputs(span: "Span", response_obj, image_attrs, span_attrs):
    images = response_obj.get("data", [])
    for i, image in enumerate(images):
        img_url = image.get("url")
        if img_url is None and image.get("b64_json"):
            img_url = f"data:image/png;base64,{image.get('b64_json')}"

        if not img_url:
            continue

        if i == 0:
            safe_set_attribute(span, span_attrs.OUTPUT_VALUE, img_url)

        safe_set_attribute(span, f"{image_attrs.IMAGE_URL}.{i}", img_url)


def _set_audio_outputs(span: "Span", response_obj, audio_attrs, span_attrs):
    audio = response_obj.get("audio", [])
    for i, audio_item in enumerate(audio):
        audio_url = audio_item.get("url")
        if audio_url is None and audio_item.get("b64_json"):
            audio_url = f"data:audio/wav;base64,{audio_item.get('b64_json')}"

        if audio_url:
            if i == 0:
                safe_set_attribute(span, span_attrs.OUTPUT_VALUE, audio_url)
            safe_set_attribute(span, f"{audio_attrs.AUDIO_URL}.{i}", audio_url)

        audio_mime = audio_item.get("mime_type")
        if audio_mime:
            safe_set_attribute(span, f"{audio_attrs.AUDIO_MIME_TYPE}.{i}", audio_mime)

        audio_transcript = audio_item.get("transcript")
        if audio_transcript:
            safe_set_attribute(
                span, f"{audio_attrs.AUDIO_TRANSCRIPT}.{i}", audio_transcript
            )


def _set_embedding_outputs(span: "Span", response_obj, embedding_attrs, span_attrs):
    embeddings = response_obj.get("data", [])
    for i, embedding_item in enumerate(embeddings):
        embedding_vector = embedding_item.get("embedding")
        if embedding_vector:
            if i == 0:
                safe_set_attribute(
                    span,
                    span_attrs.OUTPUT_VALUE,
                    str(embedding_vector),
                )

            safe_set_attribute(
                span,
                f"{embedding_attrs.EMBEDDING_VECTOR}.{i}",
                str(embedding_vector),
            )

        embedding_text = embedding_item.get("text")
        if embedding_text:
            safe_set_attribute(
                span,
                f"{embedding_attrs.EMBEDDING_TEXT}.{i}",
                str(embedding_text),
            )


def _set_structured_outputs(span: "Span", response_obj, msg_attrs, span_attrs):
    output_items = response_obj.get("output", [])
    for i, item in enumerate(output_items):
        prefix = f"{span_attrs.LLM_OUTPUT_MESSAGES}.{i}"
        if not hasattr(item, "type"):
            continue

        item_type = item.type
        if item_type == "reasoning" and hasattr(item, "summary"):
            for summary in item.summary:
                if hasattr(summary, "text"):
                    safe_set_attribute(
                        span,
                        f"{prefix}.{msg_attrs.MESSAGE_REASONING_SUMMARY}",
                        summary.text,
                    )
        elif item_type == "message" and hasattr(item, "content"):
            message_content = ""
            content_list = item.content
            if content_list and len(content_list) > 0:
                first_content = content_list[0]
                message_content = getattr(first_content, "text", "")
            message_role = getattr(item, "role", "assistant")
            safe_set_attribute(span, span_attrs.OUTPUT_VALUE, message_content)
            safe_set_attribute(
                span, f"{prefix}.{msg_attrs.MESSAGE_CONTENT}", message_content
            )
            safe_set_attribute(span, f"{prefix}.{msg_attrs.MESSAGE_ROLE}", message_role)


def _set_usage_outputs(span: "Span", response_obj, span_attrs):
    usage = response_obj and response_obj.get("usage")
    if not usage:
        return

    safe_set_attribute(
        span, span_attrs.LLM_TOKEN_COUNT_TOTAL, usage.get("total_tokens")
    )
    completion_tokens = usage.get("completion_tokens") or usage.get("output_tokens")
    if completion_tokens:
        safe_set_attribute(
            span, span_attrs.LLM_TOKEN_COUNT_COMPLETION, completion_tokens
        )
    prompt_tokens = usage.get("prompt_tokens") or usage.get("input_tokens")
    if prompt_tokens:
        safe_set_attribute(span, span_attrs.LLM_TOKEN_COUNT_PROMPT, prompt_tokens)
    reasoning_tokens = usage.get("output_tokens_details", {}).get("reasoning_tokens")
    if reasoning_tokens:
        safe_set_attribute(
            span,
            span_attrs.LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING,
            reasoning_tokens,
        )


def _infer_open_inference_span_kind(call_type: Optional[str]) -> str:
    """
    Map LiteLLM call types to OpenInference span kinds.
    """

    if not call_type:
        return OpenInferenceSpanKindValues.UNKNOWN.value

    lowered = str(call_type).lower()

    if "embed" in lowered:
        return OpenInferenceSpanKindValues.EMBEDDING.value

    if "rerank" in lowered:
        return OpenInferenceSpanKindValues.RERANKER.value

    if "search" in lowered:
        return OpenInferenceSpanKindValues.RETRIEVER.value

    if "moderation" in lowered or "guardrail" in lowered:
        return OpenInferenceSpanKindValues.GUARDRAIL.value

    if lowered in ("call_mcp_tool", "mcp") or lowered.endswith("tool"):
        return OpenInferenceSpanKindValues.TOOL.value

    if "asend_message" in lowered or "a2a" in lowered or "assistant" in lowered:
        return OpenInferenceSpanKindValues.AGENT.value

    if any(
        keyword in lowered
        for keyword in (
            "completion",
            "chat",
            "image",
            "audio",
            "speech",
            "transcription",
            "generate_content",
            "response",
            "videos",
            "realtime",
            "pass_through",
            "anthropic_messages",
            "ocr",
        )
    ):
        return OpenInferenceSpanKindValues.LLM.value

    if any(
        keyword in lowered
        for keyword in ("file", "batch", "container", "fine_tuning_job")
    ):
        return OpenInferenceSpanKindValues.CHAIN.value

    return OpenInferenceSpanKindValues.UNKNOWN.value


def _set_tool_attributes(
    span: "Span", optional_tools: Optional[list], metadata_tools: Optional[list]
):
    """Set tool attributes on the span from optional_params or tool call metadata."""
    if optional_tools:
        for idx, tool in enumerate(optional_tools):
            if not isinstance(tool, dict):
                continue
            function = (
                tool.get("function") if isinstance(tool.get("function"), dict) else None
            )
            if not function:
                continue
            tool_name = function.get("name")
            if tool_name:
                safe_set_attribute(
                    span, f"{SpanAttributes.LLM_TOOLS}.{idx}.name", tool_name
                )
            tool_description = function.get("description")
            if tool_description:
                safe_set_attribute(
                    span,
                    f"{SpanAttributes.LLM_TOOLS}.{idx}.description",
                    tool_description,
                )
            params = function.get("parameters")
            if params is not None:
                safe_set_attribute(
                    span,
                    f"{SpanAttributes.LLM_TOOLS}.{idx}.parameters",
                    json.dumps(params),
                )

    if metadata_tools and isinstance(metadata_tools, list):
        for idx, tool in enumerate(metadata_tools):
            if not isinstance(tool, dict):
                continue
            tool_name = tool.get("name")
            if tool_name:
                safe_set_attribute(
                    span,
                    f"{SpanAttributes.LLM_INVOCATION_PARAMETERS}.tools.{idx}.name",
                    tool_name,
                )

            tool_description = tool.get("description")
            if tool_description:
                safe_set_attribute(
                    span,
                    f"{SpanAttributes.LLM_INVOCATION_PARAMETERS}.tools.{idx}.description",
                    tool_description,
                )


def set_attributes(
    span: "Span", kwargs, response_obj, attributes: Type[BaseLLMObsOTELAttributes]
):
    """
    Populates the span with OpenInference-compliant LLM attributes for Arize and Phoenix tracing.
    """
    try:
        optional_params = _sanitize_optional_params(kwargs.get("optional_params"))
        litellm_params = kwargs.get("litellm_params", {}) or {}
        standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
            "standard_logging_object"
        )
        if standard_logging_payload is None:
            raise ValueError("standard_logging_object not found in kwargs")

        metadata = (
            standard_logging_payload.get("metadata")
            if standard_logging_payload
            else None
        )
        _set_metadata_attributes(span, metadata, SpanAttributes)

        metadata_tools = _extract_metadata_tools(metadata)
        optional_tools = _extract_optional_tools(optional_params)

        call_type = standard_logging_payload.get("call_type")
        _set_request_attributes(
            span=span,
            kwargs=kwargs,
            standard_logging_payload=standard_logging_payload,
            optional_params=optional_params,
            litellm_params=litellm_params,
            response_obj=response_obj,
            span_attrs=SpanAttributes,
        )

        span_kind = _infer_open_inference_span_kind(call_type=call_type)
        _set_tool_attributes(span, optional_tools, metadata_tools)
        if (
            optional_tools or metadata_tools
        ) and span_kind != OpenInferenceSpanKindValues.TOOL.value:
            span_kind = OpenInferenceSpanKindValues.TOOL.value

        safe_set_attribute(span, SpanAttributes.OPENINFERENCE_SPAN_KIND, span_kind)
        attributes.set_messages(span, kwargs)

        model_params = (
            standard_logging_payload.get("model_parameters")
            if standard_logging_payload
            else None
        )
        _set_model_params(span, model_params, SpanAttributes)

        _set_response_attributes(span=span, response_obj=response_obj)

    except Exception as e:
        verbose_logger.error(
            f"[Arize/Phoenix] Failed to set OpenInference span attributes: {e}"
        )
        if hasattr(span, "record_exception"):
            span.record_exception(e)


def _sanitize_optional_params(optional_params: Optional[dict]) -> dict:
    if not isinstance(optional_params, dict):
        return {}
    optional_params.pop("secret_fields", None)
    return optional_params


def _set_metadata_attributes(span: "Span", metadata: Optional[Any], span_attrs) -> None:
    if metadata is not None:
        safe_set_attribute(span, span_attrs.METADATA, safe_dumps(metadata))


def _extract_metadata_tools(metadata: Optional[Any]) -> Optional[list]:
    if not isinstance(metadata, dict):
        return None
    llm_obj = metadata.get("llm")
    if isinstance(llm_obj, dict):
        return llm_obj.get("tools")
    return None


def _extract_optional_tools(optional_params: dict) -> Optional[list]:
    return optional_params.get("tools") if isinstance(optional_params, dict) else None


def _set_request_attributes(
    span: "Span",
    kwargs,
    standard_logging_payload: StandardLoggingPayload,
    optional_params: dict,
    litellm_params: dict,
    response_obj,
    span_attrs,
):
    if kwargs.get("model"):
        safe_set_attribute(span, span_attrs.LLM_MODEL_NAME, kwargs.get("model"))

    safe_set_attribute(
        span, "llm.request.type", standard_logging_payload.get("call_type")
    )
    safe_set_attribute(
        span,
        span_attrs.LLM_PROVIDER,
        litellm_params.get("custom_llm_provider", "Unknown"),
    )

    if optional_params.get("max_tokens"):
        safe_set_attribute(
            span, "llm.request.max_tokens", optional_params.get("max_tokens")
        )
    if optional_params.get("temperature"):
        safe_set_attribute(
            span, "llm.request.temperature", optional_params.get("temperature")
        )
    if optional_params.get("top_p"):
        safe_set_attribute(span, "llm.request.top_p", optional_params.get("top_p"))

    safe_set_attribute(
        span, "llm.is_streaming", str(optional_params.get("stream", False))
    )

    if optional_params.get("user"):
        safe_set_attribute(span, "llm.user", optional_params.get("user"))

    if response_obj and response_obj.get("id"):
        safe_set_attribute(span, "llm.response.id", response_obj.get("id"))
    if response_obj and response_obj.get("model"):
        safe_set_attribute(span, "llm.response.model", response_obj.get("model"))


def _set_model_params(span: "Span", model_params: Optional[dict], span_attrs) -> None:
    if not model_params:
        return

    safe_set_attribute(
        span, span_attrs.LLM_INVOCATION_PARAMETERS, safe_dumps(model_params)
    )
    user_id = model_params.get("user")
    if user_id is not None:
        safe_set_attribute(span, span_attrs.USER_ID, user_id)

@@ -0,0 +1,214 @@
"""
Arize AI is OTEL-compatible.

This file has Arize-specific helper functions.
"""

import os
from datetime import datetime
from typing import TYPE_CHECKING, Any, Optional, Union

from litellm.integrations.arize import _utils
from litellm.integrations.arize._utils import ArizeOTELAttributes
from litellm.integrations.opentelemetry import OpenTelemetry
from litellm.types.integrations.arize import ArizeConfig
from litellm.types.services import ServiceLoggerPayload
from litellm.types.utils import StandardCallbackDynamicParams

if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span

    from litellm.types.integrations.arize import Protocol as _Protocol

    Protocol = _Protocol
    Span = Union[_Span, Any]
else:
    Protocol = Any
    Span = Any


class ArizeLogger(OpenTelemetry):
    """
    Arize logger that sends traces to an Arize endpoint.

    Creates its own dedicated TracerProvider so it can coexist with the
    generic ``otel`` callback (or any other OTEL-based integration) without
    fighting over the global ``opentelemetry.trace`` TracerProvider singleton.
    """

    def _init_tracing(self, tracer_provider):
        """
        Override to always create a *private* TracerProvider for Arize.

        See ArizePhoenixLogger._init_tracing for the full rationale.
        """
        from opentelemetry.sdk.trace import TracerProvider
        from opentelemetry.trace import SpanKind

        if tracer_provider is not None:
            self.tracer = tracer_provider.get_tracer("litellm")
            self.span_kind = SpanKind
            return

        provider = TracerProvider(resource=self._get_litellm_resource(self.config))
        provider.add_span_processor(self._get_span_processor())
        self.tracer = provider.get_tracer("litellm")
        self.span_kind = SpanKind

    def _init_otel_logger_on_litellm_proxy(self):
        """
        Override: Arize should NOT overwrite the proxy's
        ``open_telemetry_logger``. That attribute is reserved for the
        primary ``otel`` callback, which handles proxy-level parent spans.
        """
        pass

    def set_attributes(self, span: Span, kwargs, response_obj: Optional[Any]):
        ArizeLogger.set_arize_attributes(span, kwargs, response_obj)

    @staticmethod
    def set_arize_attributes(span: Span, kwargs, response_obj):
        _utils.set_attributes(span, kwargs, response_obj, ArizeOTELAttributes)

    @staticmethod
    def get_arize_config() -> ArizeConfig:
        """
        Helper function to get the Arize configuration from environment variables.

        Returns:
            ArizeConfig: A Pydantic model containing the Arize configuration.
        """
        space_id = os.environ.get("ARIZE_SPACE_ID")
        space_key = os.environ.get("ARIZE_SPACE_KEY")
        api_key = os.environ.get("ARIZE_API_KEY")
        project_name = os.environ.get("ARIZE_PROJECT_NAME")

        grpc_endpoint = os.environ.get("ARIZE_ENDPOINT")
        http_endpoint = os.environ.get("ARIZE_HTTP_ENDPOINT")

        # Prefer the gRPC endpoint; fall back to HTTP, then the public default.
        protocol: Protocol = "otlp_grpc"
        if grpc_endpoint:
            endpoint = grpc_endpoint
        elif http_endpoint:
            protocol = "otlp_http"
            endpoint = http_endpoint
        else:
            endpoint = "https://otlp.arize.com/v1"

        return ArizeConfig(
            space_id=space_id,
            space_key=space_key,
            api_key=api_key,
            protocol=protocol,
            endpoint=endpoint,
            project_name=project_name,
        )

    async def async_service_success_hook(
        self,
        payload: ServiceLoggerPayload,
        parent_otel_span: Optional[Span] = None,
        start_time: Optional[Union[datetime, float]] = None,
        end_time: Optional[Union[datetime, float]] = None,
        event_metadata: Optional[dict] = None,
    ):
        """Arize is used mainly for LLM I/O tracing; sending router and caching metrics would add bloat to Arize logs."""
        pass

    async def async_service_failure_hook(
        self,
        payload: ServiceLoggerPayload,
        error: Optional[str] = "",
        parent_otel_span: Optional[Span] = None,
        start_time: Optional[Union[datetime, float]] = None,
        end_time: Optional[Union[float, datetime]] = None,
        event_metadata: Optional[dict] = None,
    ):
        """Arize is used mainly for LLM I/O tracing; sending router and caching metrics would add bloat to Arize logs."""
        pass

    # def create_litellm_proxy_request_started_span(
    #     self,
    #     start_time: datetime,
    #     headers: dict,
    # ):
    #     """Arize is used mainly for LLM I/O tracing; the Proxy Server Request span would add bloat to Arize logs."""
    #     pass

    async def async_health_check(self):
        """
        Performs a health check for the Arize integration.

        Returns:
            dict: Health check result with status and message
        """
        try:
            config = self.get_arize_config()

            if not config.space_id and not config.space_key:
                return {
                    "status": "unhealthy",
                    "error_message": "ARIZE_SPACE_ID or ARIZE_SPACE_KEY environment variable not set",
                }

            if not config.api_key:
                return {
                    "status": "unhealthy",
                    "error_message": "ARIZE_API_KEY environment variable not set",
                }

            return {
                "status": "healthy",
                "message": "Arize credentials are configured properly",
            }

        except Exception as e:
            return {
                "status": "unhealthy",
                "error_message": f"Arize health check failed: {str(e)}",
            }

    def construct_dynamic_otel_headers(
        self, standard_callback_dynamic_params: StandardCallbackDynamicParams
    ) -> Optional[dict]:
        """
        Construct dynamic Arize headers from standard callback dynamic params.

        This is used for team/key based logging.

        Returns:
            dict: A dictionary of dynamic Arize headers
        """
        dynamic_headers = {}

        #########################################################
        # `arize-space-id` handling
        # both `arize_space_id` (preferred) and the legacy
        # `arize_space_key` param map to this header
        #########################################################
        if standard_callback_dynamic_params.get("arize_space_id"):
            dynamic_headers["arize-space-id"] = standard_callback_dynamic_params.get(
                "arize_space_id"
            )
        if standard_callback_dynamic_params.get("arize_space_key"):
            dynamic_headers["arize-space-id"] = standard_callback_dynamic_params.get(
                "arize_space_key"
            )

        #########################################################
        # `api_key` handling
        #########################################################
        if standard_callback_dynamic_params.get("arize_api_key"):
            dynamic_headers["api_key"] = standard_callback_dynamic_params.get(
                "arize_api_key"
            )

        return dynamic_headers

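The header mapping above can be exercised in isolation. A minimal sketch, with a plain dict standing in for `StandardCallbackDynamicParams` and header names mirroring `construct_dynamic_otel_headers`:

```python
# Standalone sketch of the team/key-based header mapping performed by
# construct_dynamic_otel_headers. A plain dict stands in for
# StandardCallbackDynamicParams.
def build_arize_headers(params: dict) -> dict:
    headers = {}
    # Both `arize_space_id` and the legacy `arize_space_key` map to the
    # same `arize-space-id` header; the legacy key wins if both are set.
    if params.get("arize_space_id"):
        headers["arize-space-id"] = params["arize_space_id"]
    if params.get("arize_space_key"):
        headers["arize-space-id"] = params["arize_space_key"]
    if params.get("arize_api_key"):
        headers["api_key"] = params["arize_api_key"]
    return headers
```

This is why team/key-scoped credentials can override the env-var config per request: the proxy merges these headers into the OTEL exporter for that call.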
@@ -0,0 +1,360 @@

import os
from typing import TYPE_CHECKING, Any, Optional, Union

from litellm._logging import verbose_logger
from litellm.integrations.arize import _utils
from litellm.integrations.arize._utils import ArizeOTELAttributes
from litellm.types.integrations.arize_phoenix import ArizePhoenixConfig

if TYPE_CHECKING:
    from opentelemetry.sdk.trace import TracerProvider
    from opentelemetry.trace import Span as _Span
    from opentelemetry.trace import SpanKind

    from litellm.integrations.opentelemetry import OpenTelemetry as _OpenTelemetry
    from litellm.integrations.opentelemetry import (
        OpenTelemetryConfig as _OpenTelemetryConfig,
    )
    from litellm.types.integrations.arize import Protocol as _Protocol

    Protocol = _Protocol
    OpenTelemetryConfig = _OpenTelemetryConfig
    Span = Union[_Span, Any]
    OpenTelemetry = _OpenTelemetry
else:
    Protocol = Any
    OpenTelemetryConfig = Any
    Span = Any
    TracerProvider = Any
    SpanKind = Any
    # Import OpenTelemetry at runtime
    try:
        from litellm.integrations.opentelemetry import OpenTelemetry
    except ImportError:
        OpenTelemetry = None  # type: ignore


ARIZE_HOSTED_PHOENIX_ENDPOINT = "https://otlp.arize.com/v1/traces"

class ArizePhoenixLogger(OpenTelemetry):  # type: ignore
    """
    Arize Phoenix logger that sends traces to a Phoenix endpoint.

    Creates its own dedicated TracerProvider so it can coexist with the
    generic ``otel`` callback (or any other OTEL-based integration) without
    fighting over the global ``opentelemetry.trace`` TracerProvider singleton.
    """

    def _init_tracing(self, tracer_provider):
        """
        Override to always create a *private* TracerProvider for Arize Phoenix.

        The base ``OpenTelemetry._init_tracing`` falls back to the global
        TracerProvider when one already exists. That causes whichever
        integration initialises second to silently reuse the first one's
        exporter, so spans only reach one destination.

        By creating our own provider we guarantee Arize Phoenix always gets
        its own exporter pipeline, regardless of initialisation order.
        """
        from opentelemetry.sdk.trace import TracerProvider
        from opentelemetry.trace import SpanKind

        if tracer_provider is not None:
            # Explicitly supplied (e.g. in tests) — honour it.
            self.tracer = tracer_provider.get_tracer("litellm")
            self.span_kind = SpanKind
            return

        # Always create a dedicated provider — never touch the global one.
        provider = TracerProvider(resource=self._get_litellm_resource(self.config))
        provider.add_span_processor(self._get_span_processor())
        self.tracer = provider.get_tracer("litellm")
        self.span_kind = SpanKind
        verbose_logger.debug(
            "ArizePhoenixLogger: Created dedicated TracerProvider "
            "(endpoint=%s, exporter=%s)",
            self.config.endpoint,
            self.config.exporter,
        )

    def _init_otel_logger_on_litellm_proxy(self):
        """
        Override: Arize Phoenix should NOT overwrite the proxy's
        ``open_telemetry_logger``. That attribute is reserved for the
        primary ``otel`` callback which handles proxy-level parent spans.
        """
        pass

    def set_attributes(self, span: Span, kwargs, response_obj: Optional[Any]):
        ArizePhoenixLogger.set_arize_phoenix_attributes(span, kwargs, response_obj)
        return

    @staticmethod
    def set_arize_phoenix_attributes(span: Span, kwargs, response_obj):
        from litellm.integrations.opentelemetry_utils.base_otel_llm_obs_attributes import (
            safe_set_attribute,
        )

        _utils.set_attributes(span, kwargs, response_obj, ArizeOTELAttributes)

        # Dynamic project name: check metadata first, then fall back to env var config
        dynamic_project_name = ArizePhoenixLogger._get_dynamic_project_name(kwargs)
        if dynamic_project_name:
            safe_set_attribute(span, "openinference.project.name", dynamic_project_name)
        else:
            # Fall back to static config from env var
            config = ArizePhoenixLogger.get_arize_phoenix_config()
            if config.project_name:
                safe_set_attribute(
                    span, "openinference.project.name", config.project_name
                )

        return

    @staticmethod
    def _get_dynamic_project_name(kwargs) -> Optional[str]:
        """
        Retrieve the dynamic Phoenix project name from request metadata.

        Users can set `metadata.phoenix_project_name` in their request to route
        traces to different Phoenix projects dynamically.
        """
        standard_logging_payload = kwargs.get("standard_logging_object")
        if isinstance(standard_logging_payload, dict):
            metadata = standard_logging_payload.get("metadata")
            if isinstance(metadata, dict):
                project_name = metadata.get("phoenix_project_name")
                if project_name:
                    return str(project_name)

        # Also check litellm_params.metadata for SDK usage
        litellm_params = kwargs.get("litellm_params")
        if isinstance(litellm_params, dict):
            metadata = litellm_params.get("metadata") or {}
        else:
            metadata = {}
        if isinstance(metadata, dict):
            project_name = metadata.get("phoenix_project_name")
            if project_name:
                return str(project_name)

        return None

    def _get_phoenix_context(self, kwargs):
        """
        Build a trace context for Phoenix's dedicated TracerProvider.

        The base ``_get_span_context`` returns parent spans from the global
        TracerProvider (the ``otel`` callback). Those spans live on a
        *different* TracerProvider, so they won't appear in Phoenix — using
        them as parents just creates broken links.

        Instead we:
        1. Honour an incoming ``traceparent`` HTTP header (distributed tracing).
        2. In proxy mode, create our *own* parent span on Phoenix's tracer
           so the hierarchy is visible end-to-end inside Phoenix.
        3. In SDK (non-proxy) mode, return the traceparent context (if any)
           with no parent span, so the request span becomes a root span.
        """
        from opentelemetry import trace

        litellm_params = kwargs.get("litellm_params", {}) or {}
        proxy_server_request = litellm_params.get("proxy_server_request", {}) or {}
        headers = proxy_server_request.get("headers", {}) or {}

        # Propagate distributed trace context if the caller sent a traceparent
        traceparent_ctx = (
            self.get_traceparent_from_header(headers=headers)
            if headers.get("traceparent")
            else None
        )

        is_proxy_mode = bool(proxy_server_request)

        if is_proxy_mode:
            # Create a parent span on Phoenix's own tracer so both parent
            # and child are exported to Phoenix.
            start_time_val = kwargs.get("start_time", kwargs.get("api_call_start_time"))
            parent_span = self.tracer.start_span(
                name="litellm_proxy_request",
                start_time=self._to_ns(start_time_val)
                if start_time_val is not None
                else None,
                context=traceparent_ctx,
                kind=self.span_kind.SERVER,
            )
            ctx = trace.set_span_in_context(parent_span)
            return ctx, parent_span

        # SDK mode — no parent span needed
        return traceparent_ctx, None

    def _handle_success(self, kwargs, response_obj, start_time, end_time):
        """
        Override to always create spans on ArizePhoenixLogger's dedicated TracerProvider.

        The base class's ``_get_span_context`` would find the parent span created by
        the ``otel`` callback on the *global* TracerProvider. That span is invisible
        in Phoenix (different exporter pipeline), so we ignore it and build our own
        hierarchy via ``_get_phoenix_context``.
        """
        from opentelemetry.trace import Status, StatusCode

        verbose_logger.debug(
            "ArizePhoenixLogger: Logging kwargs: %s, OTEL config settings=%s",
            kwargs,
            self.config,
        )

        ctx, parent_span = self._get_phoenix_context(kwargs)

        # Create litellm_request span (child of our parent when in proxy mode)
        span = self.tracer.start_span(
            name=self._get_span_name(kwargs),
            start_time=self._to_ns(start_time),
            context=ctx,
        )
        span.set_status(Status(StatusCode.OK))
        self.set_attributes(span, kwargs, response_obj)

        # Raw-request sub-span (if enabled) — must be created before
        # ending the parent span so the hierarchy is valid.
        self._maybe_log_raw_request(kwargs, response_obj, start_time, end_time, span)
        span.end(end_time=self._to_ns(end_time))

        # Guardrail span
        self._create_guardrail_span(kwargs=kwargs, context=ctx)

        # Annotate and close our proxy parent span
        if parent_span is not None:
            parent_span.set_status(Status(StatusCode.OK))
            self.set_attributes(parent_span, kwargs, response_obj)
            parent_span.end(end_time=self._to_ns(end_time))

        # Metrics & cost recording
        self._record_metrics(kwargs, response_obj, start_time, end_time)

        # Semantic logs
        if self.config.enable_events:
            self._emit_semantic_logs(kwargs, response_obj, span)

    def _handle_failure(self, kwargs, response_obj, start_time, end_time):
        """
        Override to always create failure spans on ArizePhoenixLogger's dedicated
        TracerProvider. Mirrors ``_handle_success`` but sets ERROR status.
        """
        from opentelemetry.trace import Status, StatusCode

        verbose_logger.debug(
            "ArizePhoenixLogger: Failure - Logging kwargs: %s, OTEL config settings=%s",
            kwargs,
            self.config,
        )

        ctx, parent_span = self._get_phoenix_context(kwargs)

        # Create litellm_request span (child of our parent when in proxy mode)
        span = self.tracer.start_span(
            name=self._get_span_name(kwargs),
            start_time=self._to_ns(start_time),
            context=ctx,
        )
        span.set_status(Status(StatusCode.ERROR))
        self.set_attributes(span, kwargs, response_obj)
        self._record_exception_on_span(span=span, kwargs=kwargs)
        span.end(end_time=self._to_ns(end_time))

        # Guardrail span
        self._create_guardrail_span(kwargs=kwargs, context=ctx)

        # Annotate and close our proxy parent span
        if parent_span is not None:
            parent_span.set_status(Status(StatusCode.ERROR))
            self.set_attributes(parent_span, kwargs, response_obj)
            self._record_exception_on_span(span=parent_span, kwargs=kwargs)
            parent_span.end(end_time=self._to_ns(end_time))

    @staticmethod
    def get_arize_phoenix_config() -> ArizePhoenixConfig:
        """
        Retrieves the Arize Phoenix configuration based on environment variables.

        Returns:
            ArizePhoenixConfig: A Pydantic model containing Arize Phoenix configuration.
        """
        api_key = os.environ.get("PHOENIX_API_KEY", None)

        # Prefer the explicit HTTP endpoint; fall back to the generic one
        collector_endpoint = os.environ.get(
            "PHOENIX_COLLECTOR_HTTP_ENDPOINT", None
        ) or os.environ.get("PHOENIX_COLLECTOR_ENDPOINT", None)

        endpoint = None
        protocol: Protocol = "otlp_http"

        if collector_endpoint:
            # Parse the endpoint to determine protocol
            if collector_endpoint.startswith("grpc://") or (
                ":4317" in collector_endpoint and "/v1/traces" not in collector_endpoint
            ):
                endpoint = collector_endpoint
                protocol = "otlp_grpc"
            else:
                # Phoenix Cloud endpoints (app.phoenix.arize.com) include the space in the URL
                if "app.phoenix.arize.com" in collector_endpoint:
                    endpoint = collector_endpoint
                    protocol = "otlp_http"
                # For other HTTP endpoints, ensure they have the correct path
                elif "/v1/traces" not in collector_endpoint:
                    if collector_endpoint.endswith("/v1"):
                        endpoint = collector_endpoint + "/traces"
                    elif collector_endpoint.endswith("/"):
                        endpoint = f"{collector_endpoint}v1/traces"
                    else:
                        endpoint = f"{collector_endpoint}/v1/traces"
                else:
                    endpoint = collector_endpoint
                    protocol = "otlp_http"
        else:
            # If no endpoint is specified, assume a self-hosted Phoenix
            endpoint = "http://localhost:6006/v1/traces"
            protocol = "otlp_http"
            verbose_logger.debug(
                f"No PHOENIX_COLLECTOR_ENDPOINT found, using default local Phoenix endpoint: {endpoint}"
            )

        otlp_auth_headers = None
        if api_key is not None:
            otlp_auth_headers = f"Authorization=Bearer {api_key}"
        elif "app.phoenix.arize.com" in endpoint:
            # Phoenix Cloud requires an API key
            raise ValueError(
                "PHOENIX_API_KEY must be set when using Phoenix Cloud (app.phoenix.arize.com)."
            )

        project_name = os.environ.get("PHOENIX_PROJECT_NAME", "default")

        return ArizePhoenixConfig(
            otlp_auth_headers=otlp_auth_headers,
            protocol=protocol,
            endpoint=endpoint,
            project_name=project_name,
        )

    # NOTE: additional proxy-server spans cannot be suppressed here; the
    # previous override methods were removed.

    async def async_health_check(self):
        config = self.get_arize_phoenix_config()

        if not config.otlp_auth_headers:
            return {
                "status": "unhealthy",
                "error_message": "PHOENIX_API_KEY environment variable not set",
            }

        return {
            "status": "healthy",
            "message": "Arize-Phoenix credentials are configured properly",
        }
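The endpoint/protocol resolution in `get_arize_phoenix_config` is the subtle part of the file. A standalone sketch of just that branching (a pure function assumed for illustration; the real method also reads env vars and builds the config model):

```python
def resolve_phoenix_endpoint(collector_endpoint):
    """Map a raw collector endpoint to (endpoint, protocol), mirroring
    the branching in get_arize_phoenix_config."""
    if not collector_endpoint:
        # No endpoint configured: assume a local self-hosted Phoenix
        return "http://localhost:6006/v1/traces", "otlp_http"
    # grpc:// scheme or a bare :4317 port means OTLP/gRPC
    if collector_endpoint.startswith("grpc://") or (
        ":4317" in collector_endpoint and "/v1/traces" not in collector_endpoint
    ):
        return collector_endpoint, "otlp_grpc"
    # Phoenix Cloud URLs already include the space; use them verbatim
    if "app.phoenix.arize.com" in collector_endpoint:
        return collector_endpoint, "otlp_http"
    # Other HTTP endpoints are normalised to end in /v1/traces
    if "/v1/traces" not in collector_endpoint:
        if collector_endpoint.endswith("/v1"):
            return collector_endpoint + "/traces", "otlp_http"
        if collector_endpoint.endswith("/"):
            return collector_endpoint + "v1/traces", "otlp_http"
        return collector_endpoint + "/v1/traces", "otlp_http"
    return collector_endpoint, "otlp_http"
```

For example, `PHOENIX_COLLECTOR_ENDPOINT=http://phoenix:6006` resolves to `http://phoenix:6006/v1/traces` over OTLP/HTTP, while `grpc://phoenix:4317` is passed through as OTLP/gRPC.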
@@ -0,0 +1,108 @@

"""
Arize Phoenix API client for fetching prompt versions from Arize Phoenix.
"""

from typing import Any, Dict, Optional

from litellm.llms.custom_httpx.http_handler import HTTPHandler


class ArizePhoenixClient:
    """
    Client for interacting with the Arize Phoenix API to fetch prompt versions.

    Supports:
    - Authentication with Bearer tokens
    - Fetching prompt versions
    - Direct API base URL configuration
    """

    def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None):
        """
        Initialize the Arize Phoenix client.

        Args:
            api_key: Arize Phoenix API token
            api_base: Base URL for the Arize Phoenix API (e.g., 'https://app.phoenix.arize.com/s/workspace/v1')
        """
        self.api_key = api_key
        self.api_base = api_base

        if not self.api_key:
            raise ValueError("api_key is required")

        if not self.api_base:
            raise ValueError("api_base is required")

        # Set up authentication headers
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Accept": "application/json",
        }

        # Initialize HTTPHandler
        self.http_handler = HTTPHandler(disable_default_headers=True)

    def get_prompt_version(self, prompt_version_id: str) -> Optional[Dict[str, Any]]:
        """
        Fetch a prompt version from Arize Phoenix.

        Args:
            prompt_version_id: The ID of the prompt version to fetch

        Returns:
            Dictionary containing prompt version data, or None if not found
        """
        # api_base already includes the `/v1` suffix (see __init__ docstring)
        url = f"{self.api_base}/prompt_versions/{prompt_version_id}"

        try:
            response = self.http_handler.get(url, headers=self.headers)
            response.raise_for_status()

            data = response.json()
            return data.get("data")

        except Exception as e:
            # Check if it's an HTTP error
            response = getattr(e, "response", None)
            if response is not None and hasattr(response, "status_code"):
                if response.status_code == 404:
                    return None
                elif response.status_code == 403:
                    raise Exception(
                        f"Access denied to prompt version '{prompt_version_id}'. Check your Arize Phoenix permissions."
                    )
                elif response.status_code == 401:
                    raise Exception(
                        "Authentication failed. Check your Arize Phoenix API key and permissions."
                    )
                else:
                    raise Exception(
                        f"Failed to fetch prompt version '{prompt_version_id}': {e}"
                    )
            else:
                raise Exception(
                    f"Error fetching prompt version '{prompt_version_id}': {e}"
                )

    def test_connection(self) -> bool:
        """
        Test the connection to the Arize Phoenix API.

        Returns:
            True if the connection is successful, False otherwise
        """
        try:
            # Hit the prompt_versions endpoint to verify connectivity
            url = f"{self.api_base}/prompt_versions"
            response = self.http_handler.client.get(url, headers=self.headers)
            response.raise_for_status()
            return True
        except Exception:
            return False

    def close(self):
        """Close the HTTP handler to free resources."""
        if hasattr(self, "http_handler"):
            self.http_handler.close()
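The status-code policy in `get_prompt_version` can be summarised as a small pure function. A sketch under the same rules (404 is "not found", 403/401 are permission problems, anything else is a hard failure); the helper name and exception types are illustrative, as the real method raises plain `Exception`s:

```python
def classify_prompt_fetch_status(status_code: int, prompt_version_id: str):
    """Mirror get_prompt_version's handling of HTTP error codes."""
    if status_code == 404:
        # Missing prompt version is not an error: caller receives None
        return None
    if status_code == 403:
        raise PermissionError(
            f"Access denied to prompt version '{prompt_version_id}'. "
            "Check your Arize Phoenix permissions."
        )
    if status_code == 401:
        raise PermissionError(
            "Authentication failed. Check your Arize Phoenix API key."
        )
    raise RuntimeError(
        f"Failed to fetch prompt version '{prompt_version_id}': HTTP {status_code}"
    )
```

Treating 404 as `None` rather than an exception lets callers distinguish "prompt does not exist" from genuine access or transport failures.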
@@ -0,0 +1,488 @@

"""
Arize Phoenix prompt manager that integrates with LiteLLM's prompt management system.
Fetches prompt versions from Arize Phoenix and provides workspace-based access control.
"""

from typing import Any, Dict, List, Optional, Tuple, Union

from jinja2 import DictLoader, Environment, select_autoescape

from litellm.integrations.custom_prompt_management import CustomPromptManagement
from litellm.integrations.prompt_management_base import (
    PromptManagementBase,
    PromptManagementClient,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.prompts.init_prompts import PromptSpec
from litellm.types.utils import StandardCallbackDynamicParams

from .arize_phoenix_client import ArizePhoenixClient

class ArizePhoenixPromptTemplate:
    """
    Represents a prompt template loaded from Arize Phoenix.
    """

    def __init__(
        self,
        template_id: str,
        messages: List[Dict[str, Any]],
        metadata: Dict[str, Any],
        model: Optional[str] = None,
    ):
        self.template_id = template_id
        self.messages = messages
        self.metadata = metadata
        self.model = model or metadata.get("model_name")
        self.model_provider = metadata.get("model_provider")
        self.temperature = metadata.get("temperature")
        self.max_tokens = metadata.get("max_tokens")
        self.invocation_parameters = metadata.get("invocation_parameters", {})
        self.description = metadata.get("description", "")
        self.template_format = metadata.get("template_format", "MUSTACHE")

    def __repr__(self):
        return (
            f"ArizePhoenixPromptTemplate(id='{self.template_id}', model='{self.model}')"
        )

class ArizePhoenixTemplateManager:
    """
    Manager for loading and rendering prompt templates from Arize Phoenix.

    Supports:
    - Fetching prompt versions from the Arize Phoenix API
    - Workspace-based access control through Arize Phoenix permissions
    - Mustache/Handlebars-style templating (using Jinja2)
    - Model configuration and invocation parameters
    - Multi-message chat templates
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        prompt_id: Optional[str] = None,
    ):
        self.api_key = api_key
        self.api_base = api_base
        self.prompt_id = prompt_id
        self.prompts: Dict[str, ArizePhoenixPromptTemplate] = {}
        self.arize_client = ArizePhoenixClient(
            api_key=self.api_key, api_base=self.api_base
        )

        self.jinja_env = Environment(
            loader=DictLoader({}),
            autoescape=select_autoescape(["html", "xml"]),
            # Jinja2's default delimiters already match the Mustache-style
            # `{{variable}}` syntax used by Phoenix prompt templates
            variable_start_string="{{",
            variable_end_string="}}",
            block_start_string="{%",
            block_end_string="%}",
            comment_start_string="{#",
            comment_end_string="#}",
        )

        # Load prompt from Arize Phoenix if prompt_id is provided
        if self.prompt_id:
            self._load_prompt_from_arize(self.prompt_id)

    def _load_prompt_from_arize(self, prompt_version_id: str) -> None:
        """Load a specific prompt version from Arize Phoenix."""
        try:
            # Fetch the prompt version from Arize Phoenix
            prompt_data = self.arize_client.get_prompt_version(prompt_version_id)

            if prompt_data:
                template = self._parse_prompt_data(prompt_data, prompt_version_id)
                self.prompts[prompt_version_id] = template
            else:
                raise ValueError(f"Prompt version '{prompt_version_id}' not found")
        except Exception as e:
            raise Exception(
                f"Failed to load prompt version '{prompt_version_id}' from Arize Phoenix: {e}"
            ) from e

    def _parse_prompt_data(
        self, data: Dict[str, Any], prompt_version_id: str
    ) -> ArizePhoenixPromptTemplate:
        """Parse Arize Phoenix prompt data and extract messages and metadata."""
        template_data = data.get("template", {})
        messages = template_data.get("messages", [])

        # Extract invocation parameters
        invocation_params = data.get("invocation_parameters", {})
        provider_params = {}

        # Extract provider-specific parameters
        if "openai" in invocation_params:
            provider_params = invocation_params["openai"]
        elif "anthropic" in invocation_params:
            provider_params = invocation_params["anthropic"]
        else:
            # Try to find any nested provider params
            for key, value in invocation_params.items():
                if isinstance(value, dict):
                    provider_params = value
                    break

        # Build metadata dictionary
        metadata = {
            "model_name": data.get("model_name"),
            "model_provider": data.get("model_provider"),
            "description": data.get("description", ""),
            "template_type": data.get("template_type"),
            "template_format": data.get("template_format", "MUSTACHE"),
            "invocation_parameters": invocation_params,
            "temperature": provider_params.get("temperature"),
            "max_tokens": provider_params.get("max_tokens"),
        }

        return ArizePhoenixPromptTemplate(
            template_id=prompt_version_id,
            messages=messages,
            metadata=metadata,
        )

    def render_template(
        self, template_id: str, variables: Optional[Dict[str, Any]] = None
    ) -> List[AllMessageValues]:
        """Render a template with the given variables and return formatted messages."""
        if template_id not in self.prompts:
            raise ValueError(f"Template '{template_id}' not found")

        template = self.prompts[template_id]
        rendered_messages: List[AllMessageValues] = []

        for message in template.messages:
            role = message.get("role", "user")
            content_parts = message.get("content", [])

            # Render each content part
            rendered_content_parts = []
            for part in content_parts:
                if part.get("type") == "text":
                    text = part.get("text", "")
                    # Render the text with Jinja2 (Mustache-style)
                    jinja_template = self.jinja_env.from_string(text)
                    rendered_text = jinja_template.render(**(variables or {}))
                    rendered_content_parts.append(rendered_text)
                else:
                    # Non-text parts are stringified so the join below cannot fail
                    rendered_content_parts.append(str(part))

            # Combine rendered content
            final_content = " ".join(rendered_content_parts)

            rendered_messages.append(
                {"role": role, "content": final_content}  # type: ignore
            )

        return rendered_messages

    def get_template(self, template_id: str) -> Optional[ArizePhoenixPromptTemplate]:
        """Get a template by ID."""
        return self.prompts.get(template_id)

    def list_templates(self) -> List[str]:
        """List all available template IDs."""
        return list(self.prompts.keys())


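The `{{variable}}` substitution that `render_template` delegates to Jinja2 can be approximated with the standard library alone. A sketch for intuition only (not the integration's actual renderer, which supports full Jinja2 syntax and renders unknown names as empty strings):

```python
import re

def render_mustache(text: str, variables: dict) -> str:
    # Replace {{name}} placeholders with values from `variables`.
    # Unknown names are left untouched here, unlike Jinja2.
    return re.sub(
        r"\{\{\s*(\w+)\s*\}\}",
        lambda m: str(variables.get(m.group(1), m.group(0))),
        text,
    )
```

This is the behavior users see when passing `prompt_variables` to `litellm.completion`: each `{{question}}`-style placeholder in the Phoenix template is filled before the messages are sent to the model.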
class ArizePhoenixPromptManager(CustomPromptManagement):
    """
    Arize Phoenix prompt manager that integrates with LiteLLM's prompt management system.

    This class enables using prompt versions from Arize Phoenix with the
    litellm completion() function by implementing the PromptManagementBase interface.

    Usage:
        response = litellm.completion(
            model="arize/gpt-4o",
            prompt_id="UHJvbXB0VmVyc2lvbjox",
            prompt_variables={"question": "What is AI?"},
            api_key="your-arize-phoenix-token",
            api_base="https://app.phoenix.arize.com/s/your-workspace/v1",
            messages=[{"role": "user", "content": "This will be combined with the prompt"}]
        )
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        prompt_id: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.api_key = api_key
        self.api_base = api_base
        self.prompt_id = prompt_id
        self._prompt_manager: Optional[ArizePhoenixTemplateManager] = None

    @property
    def integration_name(self) -> str:
        """Integration name used in model names like 'arize/gpt-4o'."""
        return "arize"

    @property
    def prompt_manager(self) -> ArizePhoenixTemplateManager:
        """Get or create the prompt manager instance."""
        if self._prompt_manager is None:
            self._prompt_manager = ArizePhoenixTemplateManager(
                api_key=self.api_key,
                api_base=self.api_base,
                prompt_id=self.prompt_id,
            )
        return self._prompt_manager

    def get_prompt_template(
        self,
        prompt_id: str,
        prompt_variables: Optional[Dict[str, Any]] = None,
    ) -> Tuple[List[AllMessageValues], Dict[str, Any]]:
        """
        Get a prompt template and render it with variables.

        Args:
            prompt_id: The ID of the prompt version
            prompt_variables: Variables to substitute in the template

        Returns:
            Tuple of (rendered_messages, metadata)
        """
        template = self.prompt_manager.get_template(prompt_id)
        if not template:
            raise ValueError(f"Prompt template '{prompt_id}' not found")

        # Render the template
        rendered_messages = self.prompt_manager.render_template(
            prompt_id, prompt_variables or {}
        )

        # Extract metadata
        metadata = {
            "model": template.model,
            "temperature": template.temperature,
            "max_tokens": template.max_tokens,
        }

        # Add additional invocation parameters
        invocation_params = template.invocation_parameters
        provider_params = {}

        if "openai" in invocation_params:
            provider_params = invocation_params["openai"]
        elif "anthropic" in invocation_params:
            provider_params = invocation_params["anthropic"]

        # Add any additional parameters
        for key, value in provider_params.items():
            if key not in metadata:
                metadata[key] = value

        return rendered_messages, metadata

    def pre_call_hook(
        self,
        user_id: Optional[str],
        messages: List[AllMessageValues],
        function_call: Optional[Union[Dict[str, Any], str]] = None,
        litellm_params: Optional[Dict[str, Any]] = None,
        prompt_id: Optional[str] = None,
        prompt_variables: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> Tuple[List[AllMessageValues], Optional[Dict[str, Any]]]:
        """
        Pre-call hook that processes the prompt template before making the LLM call.
        """
        if not prompt_id:
            return messages, litellm_params

        try:
            # Get the rendered messages and metadata
            rendered_messages, prompt_metadata = self.get_prompt_template(
                prompt_id, prompt_variables
            )

            # Merge rendered messages with existing messages
            if rendered_messages:
                # Prepend rendered messages to existing messages
                final_messages = rendered_messages + messages
            else:
                final_messages = messages

            # Update litellm_params with prompt metadata
            if litellm_params is None:
                litellm_params = {}

            # Apply model and parameters from prompt metadata
            # (the ignore_* flags may not be set on this instance, so default to False)
            if prompt_metadata.get("model") and not getattr(
                self, "ignore_prompt_manager_model", False
            ):
                litellm_params["model"] = prompt_metadata["model"]

            if not getattr(self, "ignore_prompt_manager_optional_params", False):
                for param in [
                    "temperature",
                    "max_tokens",
                    "top_p",
                    "frequency_penalty",
                    "presence_penalty",
                ]:
                    if param in prompt_metadata:
                        litellm_params[param] = prompt_metadata[param]

            return final_messages, litellm_params

        except Exception as e:
            # Log the error but don't fail the call
            from litellm._logging import verbose_proxy_logger

            verbose_proxy_logger.error(
                f"Error in Arize Phoenix prompt pre_call_hook: {e}"
            )
            return messages, litellm_params

    def get_available_prompts(self) -> List[str]:
        """Get the list of available prompt IDs."""
        return self.prompt_manager.list_templates()

    def reload_prompts(self) -> None:
        """Reload prompts from Arize Phoenix."""
        if self.prompt_id:
            self._prompt_manager = None  # Reset to force a reload
            _ = self.prompt_manager  # Accessing the property triggers the reload

    def should_run_prompt_management(
        self,
        prompt_id: Optional[str],
        prompt_spec: Optional[PromptSpec],
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> bool:
        """
        Determine if prompt management should run based on the prompt_id.

        For Arize Phoenix, we always return True and handle the prompt loading
        in the _compile_prompt_helper method.
        """
        return True

    def _compile_prompt_helper(
        self,
        prompt_id: Optional[str],
        prompt_spec: Optional[PromptSpec],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
        prompt_label: Optional[str] = None,
        prompt_version: Optional[int] = None,
    ) -> PromptManagementClient:
        """
        Compile an Arize Phoenix prompt template into a PromptManagementClient structure.

        This method:
        1. Loads the prompt version from Arize Phoenix
        2. Renders it with the provided variables
        3. Returns formatted chat messages
        4. Extracts model and optional parameters from metadata
        """
        if prompt_id is None:
            raise ValueError("prompt_id is required for Arize Phoenix prompt manager")
        try:
            # Load the prompt from Arize Phoenix if not already loaded
            if prompt_id not in self.prompt_manager.prompts:
                self.prompt_manager._load_prompt_from_arize(prompt_id)

            # Get the rendered messages and metadata
            rendered_messages, prompt_metadata = self.get_prompt_template(
                prompt_id, prompt_variables
            )

            # Extract model from metadata (if specified)
            template_model = prompt_metadata.get("model")

            # Extract optional parameters from metadata
            optional_params = {}
            for param in [
                "temperature",
                "max_tokens",
                "top_p",
                "frequency_penalty",
                "presence_penalty",
            ]:
                if param in prompt_metadata:
                    optional_params[param] = prompt_metadata[param]

            return PromptManagementClient(
                prompt_id=prompt_id,
                prompt_template=rendered_messages,
                prompt_template_model=template_model,
                prompt_template_optional_params=optional_params,
                completed_messages=None,
            )

        except Exception as e:
            raise ValueError(f"Error compiling prompt '{prompt_id}': {e}") from e

    async def async_compile_prompt_helper(
        self,
        prompt_id: Optional[str],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
        prompt_spec: Optional[PromptSpec] = None,
        prompt_label: Optional[str] = None,
        prompt_version: Optional[int] = None,
    ) -> PromptManagementClient:
        """
        Async version of the compile-prompt helper. Since Arize Phoenix operations
        are synchronous, this simply delegates to the sync version.
        """
        if prompt_id is None:
            raise ValueError("prompt_id is required for Arize Phoenix prompt manager")
        return self._compile_prompt_helper(
            prompt_id=prompt_id,
            prompt_spec=prompt_spec,
            prompt_variables=prompt_variables,
            dynamic_callback_params=dynamic_callback_params,
            prompt_label=prompt_label,
            prompt_version=prompt_version,
        )

    def get_chat_completion_prompt(
        self,
        model: str,
        messages: List[AllMessageValues],
        non_default_params: dict,
        prompt_id: Optional[str],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
        prompt_spec: Optional[PromptSpec] = None,
        prompt_label: Optional[str] = None,
        prompt_version: Optional[int] = None,
        ignore_prompt_manager_model: Optional[bool] = False,
        ignore_prompt_manager_optional_params: Optional[bool] = False,
    ) -> Tuple[str, List[AllMessageValues], dict]:
        """
        Get a chat completion prompt from Arize Phoenix and return the processed
        model, messages, and parameters.
        """
        return PromptManagementBase.get_chat_completion_prompt(
            self,
            model,
            messages,
            non_default_params,
            prompt_id,
            prompt_variables,
            dynamic_callback_params,
            prompt_spec=prompt_spec,
            prompt_label=prompt_label,
            prompt_version=prompt_version,
            ignore_prompt_manager_model=ignore_prompt_manager_model,
            ignore_prompt_manager_optional_params=ignore_prompt_manager_optional_params,
        )
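The class above renders prompt templates with mustache-style `{{variable}}` substitution before building the final message list. The rendering helper itself is not shown in this file; a minimal standalone sketch of that substitution (function name and regex are illustrative, not the actual implementation) could look like:

```python
import re


def render_template(template: str, variables: dict) -> str:
    """Substitute {{variable}} placeholders; unknown variables are left intact."""

    def _sub(match: re.Match) -> str:
        name = match.group(1)
        # Fall back to the original placeholder when the variable is missing
        return str(variables.get(name, match.group(0)))

    return re.sub(r"\{\{\s*([\w.]+)\s*\}\}", _sub, template)
```

Leaving unknown placeholders intact (rather than raising or substituting an empty string) makes partially-rendered templates easy to spot when debugging a prompt version.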