Files
lijiaoqiao/llm-gateway-competitors/litellm-wheel-src/litellm/llms/snowflake/embedding/transformation.py

70 lines
2.3 KiB
Python
Raw Normal View History

from typing import Optional, Union
import httpx
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
from litellm.types.llms.openai import AllEmbeddingInputValues
from litellm.types.utils import EmbeddingResponse
from ..utils import SnowflakeException, SnowflakeBaseConfig
class SnowflakeEmbeddingConfig(SnowflakeBaseConfig, BaseEmbeddingConfig):
"""
source: https://docs.snowflake.com/developer-guide/snowflake-rest-api/reference/cortex-embed
"""
def get_complete_url(
self,
api_base: Optional[str],
api_key: Optional[str],
model: str,
optional_params: dict,
litellm_params: dict,
stream: Optional[bool] = None,
) -> str:
api_base = self._get_api_base(api_base, optional_params)
return f"{api_base}/cortex/inference:embed"
def transform_embedding_request(
self,
model: str,
input: AllEmbeddingInputValues,
optional_params: dict,
headers: dict,
) -> dict:
return {"text": input, "model": model, **optional_params}
def transform_embedding_response(
self,
model: str,
raw_response: httpx.Response,
model_response: EmbeddingResponse,
logging_obj: LiteLLMLoggingObj,
api_key: Optional[str],
request_data: dict,
optional_params: dict,
litellm_params: dict,
) -> EmbeddingResponse:
response_json = raw_response.json()
# convert embeddings to 1d array
for item in response_json["data"]:
item["embedding"] = item["embedding"][0]
returned_response = EmbeddingResponse(**response_json)
returned_response.model = "snowflake/" + (returned_response.model or "")
if model is not None:
returned_response._hidden_params["model"] = model
return returned_response
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
return SnowflakeException(
message=error_message, status_code=status_code, headers=headers
)