176 lines
6.1 KiB
Python
176 lines
6.1 KiB
Python
"""
|
|
Transformation logic for Voyage AI's /v1/rerank endpoint.
|
|
|
|
Docs - https://docs.voyageai.com/docs/reranker
|
|
"""
|
|
|
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
|
|
import httpx
|
|
|
|
from litellm.llms.base_llm.chat.transformation import LiteLLMLoggingObj
|
|
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
|
|
from litellm.secret_managers.main import get_secret_str
|
|
from litellm.types.rerank import (
|
|
RerankBilledUnits,
|
|
RerankResponse,
|
|
RerankResponseMeta,
|
|
RerankTokens,
|
|
)
|
|
from litellm.types.utils import ModelInfo
|
|
|
|
from ..embedding.transformation import VoyageError
|
|
|
|
|
|
class VoyageRerankConfig(BaseRerankConfig):
|
|
def get_supported_cohere_rerank_params(self, model: str) -> list:
|
|
return ["query", "documents", "top_n", "return_documents"]
|
|
|
|
def map_cohere_rerank_params(
|
|
self,
|
|
non_default_params: dict,
|
|
model: str,
|
|
drop_params: bool,
|
|
query: str,
|
|
documents: List[Union[str, Dict[str, Any]]],
|
|
custom_llm_provider: Optional[str] = None,
|
|
top_n: Optional[int] = None,
|
|
rank_fields: Optional[List[str]] = None,
|
|
return_documents: Optional[bool] = True,
|
|
max_chunks_per_doc: Optional[int] = None,
|
|
max_tokens_per_doc: Optional[int] = None,
|
|
) -> Dict:
|
|
# Voyage AI uses 'top_k' instead of 'top_n'
|
|
optional_params: Dict[str, Any] = {"query": query, "documents": documents}
|
|
if top_n is not None:
|
|
optional_params["top_k"] = top_n
|
|
if return_documents is not None:
|
|
optional_params["return_documents"] = return_documents
|
|
# Return as dict - OptionalRerankParams is a TypedDict with total=False
|
|
# so all fields are optional and we can return the dict directly
|
|
return optional_params
|
|
|
|
def get_complete_url(
|
|
self,
|
|
api_base: Optional[str],
|
|
model: str,
|
|
optional_params: Optional[dict] = None,
|
|
) -> str:
|
|
if api_base is None:
|
|
return "https://api.voyageai.com/v1/rerank"
|
|
api_base = api_base.rstrip("/")
|
|
if not api_base.endswith("/v1/rerank"):
|
|
if api_base.endswith("/v1"):
|
|
api_base = f"{api_base}/rerank"
|
|
else:
|
|
api_base = f"{api_base}/v1/rerank"
|
|
return api_base
|
|
|
|
def transform_rerank_request(
|
|
self, model: str, optional_rerank_params: Dict, headers: Dict
|
|
) -> Dict:
|
|
return {"model": model, **optional_rerank_params}
|
|
|
|
def transform_rerank_response(
|
|
self,
|
|
model: str,
|
|
raw_response: httpx.Response,
|
|
model_response: RerankResponse,
|
|
logging_obj: LiteLLMLoggingObj,
|
|
api_key: Optional[str] = None,
|
|
request_data: Dict = {},
|
|
optional_params: Dict = {},
|
|
litellm_params: Dict = {},
|
|
) -> RerankResponse:
|
|
if raw_response.status_code != 200:
|
|
raise VoyageError(
|
|
message=raw_response.text, status_code=raw_response.status_code
|
|
)
|
|
|
|
logging_obj.post_call(original_response=raw_response.text)
|
|
|
|
try:
|
|
_json_response = raw_response.json()
|
|
except Exception:
|
|
raise VoyageError(
|
|
message=f"Failed to parse response: {raw_response.text}",
|
|
status_code=raw_response.status_code,
|
|
)
|
|
|
|
# Voyage AI returns results in "data" key, not "results"
|
|
_results: Optional[List[dict]] = _json_response.get("data")
|
|
if _results is None:
|
|
raise ValueError(f"No results found in the response={_json_response}")
|
|
|
|
# Transform to LiteLLM format
|
|
transformed_results = []
|
|
for result in _results:
|
|
transformed_result: Dict[str, Any] = {
|
|
"index": result["index"],
|
|
"relevance_score": result["relevance_score"],
|
|
}
|
|
if "document" in result:
|
|
if isinstance(result["document"], str):
|
|
transformed_result["document"] = {"text": result["document"]}
|
|
else:
|
|
transformed_result["document"] = result["document"]
|
|
transformed_results.append(transformed_result)
|
|
|
|
usage = _json_response.get("usage", {})
|
|
total_tokens = usage.get("total_tokens", 0)
|
|
_billed_units = RerankBilledUnits(total_tokens=total_tokens)
|
|
_tokens = RerankTokens(input_tokens=total_tokens, output_tokens=0)
|
|
rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)
|
|
|
|
return RerankResponse(
|
|
id=_json_response.get("id", f"voyage-rerank-{model}"),
|
|
results=transformed_results, # type: ignore
|
|
meta=rerank_meta,
|
|
)
|
|
|
|
def validate_environment(
|
|
self,
|
|
headers: Dict,
|
|
model: str,
|
|
api_key: Optional[str] = None,
|
|
optional_params: Optional[dict] = None,
|
|
) -> Dict:
|
|
if api_key is None:
|
|
api_key = get_secret_str("VOYAGE_API_KEY") or get_secret_str(
|
|
"VOYAGE_AI_API_KEY"
|
|
)
|
|
if api_key is None:
|
|
raise ValueError(
|
|
"Voyage AI API key is required. Set via `api_key` parameter or `VOYAGE_API_KEY` env var."
|
|
)
|
|
return {
|
|
"Authorization": f"Bearer {api_key}",
|
|
"content-type": "application/json",
|
|
}
|
|
|
|
def calculate_rerank_cost(
|
|
self,
|
|
model: str,
|
|
custom_llm_provider: Optional[str] = None,
|
|
billed_units: Optional[RerankBilledUnits] = None,
|
|
model_info: Optional[ModelInfo] = None,
|
|
) -> Tuple[float, float]:
|
|
if (
|
|
model_info is None
|
|
or "input_cost_per_token" not in model_info
|
|
or model_info["input_cost_per_token"] is None
|
|
or billed_units is None
|
|
):
|
|
return 0.0, 0.0
|
|
total_tokens = billed_units.get("total_tokens")
|
|
if total_tokens is None:
|
|
return 0.0, 0.0
|
|
return model_info["input_cost_per_token"] * total_tokens, 0.0
|
|
|
|
def get_error_class(
|
|
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
|
):
|
|
return VoyageError(
|
|
message=error_message, status_code=status_code, headers=headers
|
|
)
|