# Source: litellm/llms/ragflow/chat/transformation.py
"""
RAGFlow provider configuration for OpenAI-compatible API.
RAGFlow provides OpenAI-compatible APIs with unique path structures:
- Chat endpoint: /api/v1/chats_openai/{chat_id}/chat/completions
- Agent endpoint: /api/v1/agents_openai/{agent_id}/chat/completions
Model name format:
- Chat: ragflow/chat/{chat_id}/{model_name}
- Agent: ragflow/agent/{agent_id}/{model_name}
"""
from typing import List, Optional, Tuple
import litellm
from litellm.llms.openai.openai import OpenAIConfig
from litellm.secret_managers.main import get_secret, get_secret_str
from litellm.types.llms.openai import AllMessageValues
class RAGFlowConfig(OpenAIConfig):
    """
    Configuration for RAGFlow's OpenAI-compatible API.

    RAGFlow exposes OpenAI-compatible endpoints under provider-specific paths:
    - Chat:  /api/v1/chats_openai/{chat_id}/chat/completions
    - Agent: /api/v1/agents_openai/{agent_id}/chat/completions

    The endpoint type, entity id, and downstream model name are all encoded
    in the litellm model string:
    - ragflow/chat/{chat_id}/{model_name}
    - ragflow/agent/{agent_id}/{model_name}
    """

    def _parse_ragflow_model(self, model: str) -> Tuple[str, str, str]:
        """
        Parse a RAGFlow model name of the form
        ``ragflow/{endpoint_type}/{id}/{model_name}``.

        Args:
            model: Model name, e.g. ``ragflow/chat/{chat_id}/{model}`` or
                ``ragflow/agent/{agent_id}/{model}``.

        Returns:
            Tuple of ``(endpoint_type, entity_id, model_name)``. The model
            name may itself contain slashes (e.g. ``org/model``).

        Raises:
            ValueError: If the model string does not match the expected format.
        """
        parts = model.split("/")
        if len(parts) < 4:
            raise ValueError(
                f"Invalid RAGFlow model format: {model}. "
                f"Expected format: ragflow/chat/{{chat_id}}/{{model}} or ragflow/agent/{{agent_id}}/{{model}}"
            )
        if parts[0] != "ragflow":
            raise ValueError(
                f"Invalid RAGFlow model format: {model}. Must start with 'ragflow/'"
            )
        endpoint_type = parts[1]
        if endpoint_type not in ("chat", "agent"):
            raise ValueError(
                f"Invalid RAGFlow endpoint type: {endpoint_type}. Must be 'chat' or 'agent'"
            )
        entity_id = parts[2]
        # Re-join the remainder so model names containing "/" survive intact.
        model_name = "/".join(parts[3:])
        return endpoint_type, entity_id, model_name

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Build the complete RAGFlow request URL for the given model.

        Constructs the URL based on the endpoint type parsed from ``model``:
        - Chat:  ``{api_base}/api/v1/chats_openai/{chat_id}/chat/completions``
        - Agent: ``{api_base}/api/v1/agents_openai/{agent_id}/chat/completions``

        Args:
            api_base: Base API URL (e.g. ``http://ragflow:9380`` or
                ``http://ragflow:9380/api/v1`` — any ``/v1``/``/api/v1``
                suffix is stripped before the RAGFlow path is appended).
            api_key: API key (not used in URL construction).
            model: Model name in ``ragflow/{endpoint_type}/{id}/{model}`` form.
            optional_params: Optional request parameters (unused here).
            litellm_params: LiteLLM parameters dict (may contain ``api_base``).
            stream: Whether streaming is enabled (unused here).

        Returns:
            Complete URL for the API call.

        Raises:
            ValueError: If no api_base can be resolved, or the model string
                is not a valid RAGFlow model format.
        """
        # Resolve api_base from, in order: explicit argument, litellm_params,
        # the global litellm setting, then the RAGFLOW_API_BASE secret/env.
        # BUG FIX: litellm_params is a dict (see the subscript write in
        # validate_environment), so the previous hasattr(litellm_params,
        # "api_base") check could never match; use .get() instead. Also use
        # only get_secret_str so the result is guaranteed to be a str.
        api_base = (
            api_base
            or (litellm_params or {}).get("api_base")
            or litellm.api_base
            or get_secret_str("RAGFLOW_API_BASE")
        )
        if api_base is None:
            raise ValueError(
                "api_base is required for RAGFlow provider. Set it via api_base parameter, RAGFLOW_API_BASE environment variable, or litellm.api_base"
            )

        # Parse the model name to extract endpoint type and entity id.
        endpoint_type, entity_id, _ = self._parse_ragflow_model(model)

        # Normalize the base: drop any trailing slash, then any /v1 or
        # /api/v1 suffix, since the RAGFlow path below already carries
        # /api/v1. Check /api/v1 first because it also ends with /v1.
        api_base = api_base.rstrip("/")
        if api_base.endswith("/api/v1"):
            api_base = api_base[: -len("/api/v1")]
        elif api_base.endswith("/v1"):
            api_base = api_base[: -len("/v1")]

        if endpoint_type == "chat":
            path = f"/api/v1/chats_openai/{entity_id}/chat/completions"
        else:  # agent
            path = f"/api/v1/agents_openai/{entity_id}/chat/completions"
        return f"{api_base}{path}"

    def _get_openai_compatible_provider_info(
        self,
        model: str,
        api_base: Optional[str],
        api_key: Optional[str],
        custom_llm_provider: str,
    ) -> Tuple[Optional[str], Optional[str], str]:
        """
        Resolve OpenAI-compatible provider information for RAGFlow.

        Args:
            model: Model name in RAGFlow format (validated here; the parsed
                pieces are consumed later by get_complete_url and
                transform_request).
            api_base: Base API URL from input params, if any.
            api_key: API key from input params, if any.
            custom_llm_provider: Custom LLM provider name (passed through).

        Returns:
            Tuple of ``(api_base, api_key, custom_llm_provider)``.

        Raises:
            ValueError: If the model string is not a valid RAGFlow format.
        """
        # Parse purely for early validation of the model string; the result
        # is not needed here.
        self._parse_ragflow_model(model)

        # api_base: input param, then global litellm setting, then env/secret.
        dynamic_api_base = (
            api_base
            or litellm.api_base
            or get_secret_str("RAGFLOW_API_BASE")
        )
        # api_key: input param, then global litellm setting, then env/secret.
        dynamic_api_key = (
            api_key or litellm.api_key or get_secret_str("RAGFLOW_API_KEY")
        )
        return dynamic_api_base, dynamic_api_key, custom_llm_provider

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """
        Validate the environment and set up headers for the RAGFlow API.

        Args:
            headers: Request headers (mutated in place and returned).
            model: Model name in RAGFlow format.
            messages: Chat messages (unused here).
            optional_params: Optional parameters (unused here).
            litellm_params: LiteLLM parameters dict (may contain ``api_key``;
                the parsed downstream model name is cached into it).
            api_key: API key from input params, if any.
            api_base: Base API URL (unused here).

        Returns:
            The updated headers dictionary.
        """
        # Resolve api_key from, in order: explicit argument, litellm_params,
        # the global litellm setting, then the RAGFLOW_API_KEY secret/env.
        # BUG FIX: litellm_params is a dict, so the previous
        # hasattr(litellm_params, "api_key") check could never match;
        # use .get() instead.
        api_key = (
            api_key
            or (litellm_params or {}).get("api_key")
            or litellm.api_key
            or get_secret_str("RAGFLOW_API_KEY")
        )
        if api_key is not None:
            headers["Authorization"] = f"Bearer {api_key}"

        # Default Content-Type without clobbering a caller-supplied header
        # in either capitalization.
        if "content-type" not in headers and "Content-Type" not in headers:
            headers["Content-Type"] = "application/json"

        # Cache the downstream model name for transform_request. A malformed
        # model string is tolerated here: transform_request has its own
        # fallback parse.
        try:
            _, _, actual_model = self._parse_ragflow_model(model)
            litellm_params["_ragflow_actual_model"] = actual_model
        except ValueError:
            pass
        return headers

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Transform the request for the RAGFlow API.

        Replaces the litellm model string with the actual downstream model
        name extracted from the RAGFlow format, then delegates to the
        OpenAI transformation.

        Args:
            model: Model name in RAGFlow format.
            messages: Chat messages.
            optional_params: Optional parameters.
            litellm_params: LiteLLM parameters (may contain the cached
                ``_ragflow_actual_model`` set by validate_environment).
            headers: Request headers.

        Returns:
            Transformed request dictionary.
        """
        # Prefer the model name cached by validate_environment.
        actual_model = litellm_params.get("_ragflow_actual_model")
        if actual_model is None:
            # Fallback: parse here; on failure keep the original string so
            # the request still goes out rather than erroring late.
            try:
                _, _, actual_model = self._parse_ragflow_model(model)
            except ValueError:
                actual_model = model
        # Delegate to the OpenAI-compatible transformation with the real
        # downstream model name.
        return super().transform_request(
            actual_model, messages, optional_params, litellm_params, headers
        )