# (removed duplicated extraction-metadata header: "161 lines / 5.0 KiB / Python")
"""
|
|
Translate from OpenAI's `/v1/embeddings` to Sagemaker's `/invoke`
|
|
|
|
In the Huggingface TGI format.
|
|
"""
|
|
|
|
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
|
|
|
if TYPE_CHECKING:
|
|
from litellm.types.llms.openai import AllEmbeddingInputValues
|
|
|
|
from httpx._models import Headers, Response
|
|
|
|
from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
|
|
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
|
from litellm.types.utils import Usage, EmbeddingResponse
|
|
from litellm.llms.voyage.embedding.transformation import VoyageEmbeddingConfig
|
|
|
|
from ..common_utils import SagemakerError
|
|
|
|
|
|
class SagemakerEmbeddingConfig(BaseEmbeddingConfig):
    """
    SageMaker embedding configuration factory for supporting embedding parameters.

    Translates OpenAI-style `/v1/embeddings` requests into the Hugging Face
    TGI/TEI payload shape expected by SageMaker `/invoke` endpoints, and maps
    the raw SageMaker response back into an OpenAI-style ``EmbeddingResponse``.
    Voyage models hosted on SageMaker are delegated to ``VoyageEmbeddingConfig``.
    """

    def __init__(self) -> None:
        pass

    @classmethod
    def get_model_config(cls, model: str) -> "BaseEmbeddingConfig":
        """
        Factory method to get the appropriate embedding config based on model type.

        Args:
            model: The model name

        Returns:
            Appropriate embedding config instance
        """
        # Voyage models use Voyage's own request/response schema, not HF TGI.
        if "voyage" in model.lower():
            return VoyageEmbeddingConfig()
        return cls()

    def get_supported_openai_params(self, model: str) -> List[str]:
        """
        Return the OpenAI optional params supported by this model family.

        Voyage models support Voyage-specific params; generic Hugging Face
        models on SageMaker accept none of the OpenAI optional params.
        """
        if "voyage" in model.lower():
            return VoyageEmbeddingConfig().get_supported_openai_params(model)
        return []

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        No-op mapping: generic SageMaker HF embedding models take no OpenAI
        optional params, so ``optional_params`` is passed through unchanged.
        """
        return optional_params

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, Headers]
    ) -> BaseLLMException:
        """Wrap provider errors in the SageMaker-specific exception type."""
        return SagemakerError(
            message=error_message, status_code=status_code, headers=headers
        )

    def transform_embedding_request(
        self,
        model: str,
        input: "AllEmbeddingInputValues",
        optional_params: dict,
        headers: dict,
    ) -> dict:
        """
        Transform embedding request for Hugging Face models on SageMaker.

        Args:
            model: The model name (unused; HF payload is model-agnostic).
            input: Text(s) to embed — may be a single string or a list.
            optional_params: Extra provider params merged into the payload.
            headers: Request headers (unused here).

        Returns:
            The JSON body for the SageMaker `/invoke` call.
        """
        # HF TGI/TEI models expect an "inputs" field (plural).
        return {"inputs": input, **optional_params}

    def transform_embedding_response(
        self,
        model: str,
        raw_response: Response,
        model_response: "EmbeddingResponse",
        logging_obj: Any,
        api_key: Optional[str] = None,
        request_data: Optional[dict] = None,
        optional_params: Optional[dict] = None,
        litellm_params: Optional[dict] = None,
    ) -> "EmbeddingResponse":
        """
        Transform embedding response for Hugging Face models on SageMaker.

        Accepts both the raw-array format returned by TEI containers and the
        ``{"embedding": [...]}`` wrapped format used by standard HF containers.

        Raises:
            SagemakerError: if the response body is not JSON, or is JSON in
                an unrecognized shape.
        """
        # Avoid mutable default arguments; None sentinels are behavior-identical.
        request_data = request_data if request_data is not None else {}

        try:
            response_data = raw_response.json()
        except Exception as e:
            raise SagemakerError(
                message=f"Failed to parse response: {str(e)}",
                status_code=raw_response.status_code,
            )

        # Handle both raw array format (TEI) and wrapped format (standard HF)
        if isinstance(response_data, list):
            # TEI and some HF models return raw embedding arrays directly
            embeddings = response_data
        elif isinstance(response_data, dict) and "embedding" in response_data:
            # Standard HF format with "embedding" key
            embeddings = response_data["embedding"]
        else:
            raise SagemakerError(
                status_code=500,
                message=f"Unexpected response format. Expected list or dict with 'embedding' key, got: {type(response_data).__name__}",
            )

        if not isinstance(embeddings, list):
            raise SagemakerError(
                status_code=422,
                message=f"HF response not in expected format - {embeddings}",
            )

        output_data = []
        for idx, embedding in enumerate(embeddings):
            output_data.append(
                {"object": "embedding", "index": idx, "embedding": embedding}
            )

        model_response.object = "list"
        model_response.data = output_data
        model_response.model = model

        # Approximate usage from the request payload: TGI responses carry no
        # token counts, so fall back to a whitespace word count.
        input_texts = request_data.get("inputs", [])
        if isinstance(input_texts, str):
            # A bare-string input would otherwise be iterated char-by-char,
            # producing a wildly wrong count.
            input_texts = [input_texts]
        input_tokens = 0
        for text in input_texts:
            # Skip non-string items (e.g. pre-tokenized id lists) rather than
            # crashing on `.split()`.
            if isinstance(text, str):
                input_tokens += len(text.split())  # simple word-count fallback

        model_response.usage = Usage(
            prompt_tokens=input_tokens,
            completion_tokens=0,
            total_tokens=input_tokens,
        )

        return model_response

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[Any],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """
        Validate environment for SageMaker embeddings.

        SageMaker auth is handled by SigV4 at the transport layer, so the only
        header this config contributes is the JSON content type.
        """
        return {"Content-Type": "application/json"}
|