"""
Translates from OpenAI's `/v1/embeddings` to the SAP Generative AI Hub
orchestration service's `/v2/embeddings` route.
"""

from functools import cached_property
from typing import Optional, List, Dict, Literal, Union

import httpx
from pydantic import AliasChoices, BaseModel, Field

from litellm.llms.base_llm.embedding.transformation import (
    BaseEmbeddingConfig,
    LiteLLMLoggingObj,
)
from litellm.types.llms.openai import AllEmbeddingInputValues
from litellm.types.utils import EmbeddingResponse

from ..chat.handler import GenAIHubOrchestrationError
from ..credentials import get_token_creator


class Usage(BaseModel):
    """Token accounting reported alongside the embeddings."""

    prompt_tokens: int
    total_tokens: int


class EmbeddingItem(BaseModel):
    """A single embedding vector and its position in the input batch."""

    object: Literal["embedding"]
    embedding: List[float] = Field(
        ..., description="Vector of floats (length varies by model)."
    )
    index: int


class FinalResult(BaseModel):
    """OpenAI-shaped list of embeddings (the ``final_result`` payload)."""

    object: Literal["list"]
    data: List[EmbeddingItem]
    model: str
    usage: Usage


class EmbeddingsResponse(BaseModel):
    """Top-level orchestration embeddings response."""

    request_id: str
    final_result: FinalResult


class EmbeddingModel(BaseModel):
    """Model selection sent in the embeddings module config.

    ``params`` may be supplied under either ``params`` or ``parameters`` and
    is always serialized by ``model_dump()`` as ``params``.
    """

    name: str
    version: str = "latest"
    # Accept both spellings. With only validation_alias="parameters", pydantic
    # v2 ignores a "params" input key (populate_by_name is off by default), so
    # parameters passed under the field's own name were silently dropped.
    params: dict = Field(
        default_factory=dict,
        validation_alias=AliasChoices("params", "parameters"),
    )


class EmbeddingsModules(BaseModel):
    """Embeddings module wrapper around an ``EmbeddingModel``."""

    embeddings: EmbeddingModel


class EmbeddingInput(BaseModel):
    """Text (single string or batch) to embed, plus its input type."""

    text: Union[str, List[str]]
    type: Literal["text", "document", "query"] = "text"


class EmbeddingRequest(BaseModel):
    """Request schema; not used by ``GenAIHubEmbeddingConfig`` below.

    NOTE(review): this nests ``config.embeddings`` directly, whereas
    ``GenAIHubEmbeddingConfig.transform_embedding_request`` sends
    ``config.modules.embeddings.model`` -- confirm which shape the service
    expects before relying on this model.
    """

    config: EmbeddingsModules
    input: EmbeddingInput


def validate_dict(data: dict, model) -> dict:
    """Validate ``data`` against the pydantic ``model`` class and return a plain dict.

    Raises:
        pydantic.ValidationError: if ``data`` does not satisfy ``model``.
    """
    return model(**data).model_dump()


class GenAIHubEmbeddingConfig(BaseEmbeddingConfig):
    """Routes OpenAI-style embedding calls through a GenAI Hub orchestration deployment."""

    def __init__(self):
        super().__init__()
        self._access_token_data = {}
        self.token_creator, self.base_url, self.resource_group = get_token_creator()

    @property
    def headers(self) -> Dict:
        """Request headers carrying a token obtained from ``token_creator``."""
        access_token = self.token_creator()
        # headers for completions and embeddings requests
        return {
            "Authorization": access_token,
            "AI-Resource-Group": self.resource_group,
            "Content-Type": "application/json",
            "AI-Client-Type": "LiteLLM",
        }

    @cached_property
    def deployment_url(self) -> str:
        """URL of the newest orchestration deployment in the resource group.

        A deployment qualifies when its ``scenarioId`` is ``orchestration``,
        its configuration's ``executableId`` is ``orchestration``, and it has
        a non-empty ``deploymentUrl``. The newest by ``createdAt`` wins. The
        result is cached on this instance.

        Raises:
            httpx.HTTPStatusError: if a deployment/configuration lookup fails.
            GenAIHubOrchestrationError: if no qualifying deployment exists.
        """
        headers = self.headers  # fetch the token once for all lookups
        candidates = []
        with httpx.Client(timeout=30) as client:
            response = client.get(self.base_url + "/lm/deployments", headers=headers)
            # Fail loudly on auth/permission errors instead of treating the
            # error body as "no deployments".
            response.raise_for_status()
            for deployment in response.json().get("resources", []):
                if deployment.get("scenarioId") != "orchestration":
                    continue
                config_response = client.get(
                    self.base_url
                    + f'/lm/configurations/{deployment["configurationId"]}',
                    headers=headers,
                )
                config_response.raise_for_status()
                if config_response.json().get("executableId") != "orchestration":
                    continue
                url = deployment.get("deploymentUrl")
                if url:  # a deployment without a URL cannot serve requests
                    candidates.append((url, deployment.get("createdAt", "")))

        if not candidates:
            raise GenAIHubOrchestrationError(
                404,
                "No orchestration deployment found in resource group "
                f"'{self.resource_group}'.",
            )
        # createdAt is compared as a string, as before; this orders correctly
        # for uniformly formatted ISO-8601 timestamps (assumed -- confirm).
        newest_url, _ = max(candidates, key=lambda candidate: candidate[1])
        return newest_url

    def get_error_class(self, error_message, status_code, headers):
        return GenAIHubOrchestrationError(status_code, error_message)

    def get_supported_openai_params(self, model: str) -> list:
        """OpenAI embedding params this provider accepts for ``model``."""
        if "text-embedding-3" in model:
            return ["encoding_format", "dimensions"]
        return ["encoding_format"]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Copy supported OpenAI params from ``non_default_params`` into ``optional_params``.

        Params not listed by ``get_supported_openai_params`` are ignored.
        """
        supported = self.get_supported_openai_params(model)
        for key, value in non_default_params.items():
            if key in supported:
                optional_params[key] = value
        return optional_params

    def validate_environment(self, headers: dict, *args, **kwargs) -> dict:
        """Merge caller ``headers`` with the auth/routing headers (the latter win)."""
        return {**(headers or {}), **self.headers}

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """Embeddings endpoint of the discovered orchestration deployment."""
        return self.deployment_url.rstrip("/") + "/v2/embeddings"

    def transform_embedding_request(
        self,
        model: str,
        input: AllEmbeddingInputValues,
        optional_params: dict,
        headers: dict,
    ) -> dict:
        """Build the orchestration request body.

        Model params are the supported OpenAI params found in
        ``optional_params``, overridden by an explicit ``parameters`` dict.
        ``version`` defaults to ``"latest"``.

        Raises:
            pydantic.ValidationError: if the model spec or input is invalid
                (e.g. token-id input, which ``EmbeddingInput`` does not accept).
        """
        params = {
            key: optional_params[key]
            for key in self.get_supported_openai_params(model)
            if key in optional_params
        }
        params.update(optional_params.get("parameters") or {})
        model_dict = {
            "name": model,
            "version": optional_params.get("version", "latest"),
            "params": params,
        }
        return {
            "config": {
                "modules": {
                    "embeddings": {"model": validate_dict(model_dict, EmbeddingModel)}
                }
            },
            "input": validate_dict({"text": input}, EmbeddingInput),
        }

    def transform_embedding_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: EmbeddingResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str],
        request_data: dict,
        optional_params: dict,
        litellm_params: dict,
    ) -> EmbeddingResponse:
        """Convert the ``final_result`` of the response into an ``EmbeddingResponse``.

        Raises:
            GenAIHubOrchestrationError: if the body is not JSON or lacks
                ``final_result`` (e.g. an error payload).
        """
        try:
            payload = raw_response.json()
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            payload = None
        if not isinstance(payload, dict) or "final_result" not in payload:
            raise GenAIHubOrchestrationError(
                raw_response.status_code, raw_response.text
            )
        return EmbeddingResponse.model_validate(payload["final_result"])