chore: initial snapshot for gitea/github upload

commit a699a1ac98
Author: Your Name
Date:   2026-03-26 16:04:46 +08:00
3497 changed files with 1586237 additions and 0 deletions

View File

@@ -0,0 +1,168 @@
import logging
import os
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import httpx
from litellm.types.llms.openai import AllMessageValues, ChatCompletionRequest
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
LoggingClass = LiteLLMLoggingObj
else:
LoggingClass = Any
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from ...openai.chat.gpt_transformation import OpenAIGPTConfig
from ..common_utils import HuggingFaceError, _fetch_inference_provider_mapping
logger = logging.getLogger(__name__)
BASE_URL = "https://router.huggingface.co"
def _build_chat_completion_url(model_url: str) -> str:
    # Strip any trailing slash
    model_url = model_url.rstrip("/")
    # If the URL already ends with /v1, only the /chat/completions suffix is needed
    if model_url.endswith("/v1"):
        model_url += "/chat/completions"
    # Otherwise append /v1/chat/completions unless it is already present
    if not model_url.endswith("/chat/completions"):
        model_url += "/v1/chat/completions"
    return model_url
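# Illustrative behavior of the helper above (hosts are hypothetical; results derived from the logic, not from live calls):
#   _build_chat_completion_url("https://my-endpoint.example")      -> "https://my-endpoint.example/v1/chat/completions"
#   _build_chat_completion_url("https://my-endpoint.example/v1/")  -> "https://my-endpoint.example/v1/chat/completions"
#   _build_chat_completion_url("https://x/v1/chat/completions")    -> unchanged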
class HuggingFaceChatConfig(OpenAIGPTConfig):
"""
Reference: https://huggingface.co/docs/huggingface_hub/guides/inference
"""
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: Dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
default_headers = {
"content-type": "application/json",
}
if api_key is not None:
default_headers["Authorization"] = f"Bearer {api_key}"
headers = {**headers, **default_headers}
return headers
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
return HuggingFaceError(
status_code=status_code, message=error_message, headers=headers
)
def get_base_url(self, model: str, base_url: Optional[str]) -> Optional[str]:
"""
Get the API base for the Huggingface API.
Do not add the chat/embedding/rerank extension here. Let the handler do this.
"""
if model.startswith(("http://", "https://")):
base_url = model
elif base_url is None:
base_url = os.getenv("HF_API_BASE") or os.getenv("HUGGINGFACE_API_BASE", "")
return base_url
def get_complete_url(
self,
api_base: Optional[str],
api_key: Optional[str],
model: str,
optional_params: dict,
litellm_params: dict,
stream: Optional[bool] = None,
) -> str:
"""
Get the complete URL for the API call.
        Handles provider-specific routing through the Hugging Face router.
"""
# Check if api_base is provided
if api_base is not None:
complete_url = api_base
complete_url = _build_chat_completion_url(complete_url)
        elif os.getenv("HF_API_BASE") or os.getenv("HUGGINGFACE_API_BASE"):
            complete_url = str(
                os.getenv("HF_API_BASE") or os.getenv("HUGGINGFACE_API_BASE")
            )
elif model.startswith(("http://", "https://")):
complete_url = model
complete_url = _build_chat_completion_url(complete_url)
# Default construction with provider
else:
# Parse provider and model
complete_url = "https://router.huggingface.co/v1/chat/completions"
first_part, remaining = model.split("/", 1)
if "/" in remaining:
provider = first_part
if provider == "hf-inference":
route = f"{provider}/models/{model}/v1/chat/completions"
elif provider == "novita":
route = f"{provider}/v3/openai/chat/completions"
elif provider == "fireworks-ai":
route = f"{provider}/inference/v1/chat/completions"
else:
route = f"{provider}/v1/chat/completions"
complete_url = f"{BASE_URL}/{route}"
# Ensure URL doesn't end with a slash
complete_url = complete_url.rstrip("/")
return complete_url
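    # Illustrative routing (model names are examples only; results follow the branches above):
    #   model="novita/meta-llama/llama-3.1-8b-instruct"
    #       -> "https://router.huggingface.co/novita/v3/openai/chat/completions"
    #   model="fireworks-ai/accounts/fireworks/models/llama-v3p1-8b-instruct"
    #       -> "https://router.huggingface.co/fireworks-ai/inference/v1/chat/completions"
    #   model="meta-llama/Llama-3.1-8B-Instruct" (no provider prefix)
    #       -> "https://router.huggingface.co/v1/chat/completions"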
def transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
if litellm_params.get("api_base"):
return dict(
ChatCompletionRequest(model=model, messages=messages, **optional_params)
)
if "max_retries" in optional_params:
logger.warning("`max_retries` is not supported. It will be ignored.")
optional_params.pop("max_retries", None)
first_part, remaining = model.split("/", 1)
mapped_model = model
if "/" in remaining:
provider = first_part
model_id = remaining
provider_mapping = _fetch_inference_provider_mapping(model_id)
if provider not in provider_mapping:
raise HuggingFaceError(
message=f"Model {model_id} is not supported for provider {provider}",
status_code=404,
headers={},
)
provider_mapping = provider_mapping[provider]
if provider_mapping["status"] == "staging":
logger.warning(
f"Model {model_id} is in staging mode for provider {provider}. Meant for test purposes only."
)
mapped_model = provider_mapping["providerId"]
messages = self._transform_messages(messages=messages, model=mapped_model)
return dict(
ChatCompletionRequest(
model=mapped_model, messages=messages, **optional_params
)
)
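    # Illustrative flow for transform_request (mapping contents are hypothetical): for
    # model="together/deepseek-ai/DeepSeek-R1", _fetch_inference_provider_mapping("deepseek-ai/DeepSeek-R1")
    # might return something like
    #   {"together": {"status": "live", "providerId": "deepseek-ai/DeepSeek-R1", ...}}
    # in which case the outgoing request uses model="deepseek-ai/DeepSeek-R1" (the provider's id),
    # and a "staging" status only triggers the warning above.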

View File

@@ -0,0 +1,102 @@
import os
from functools import lru_cache
from typing import Literal, Optional, Union
import httpx
from litellm.llms.base_llm.chat.transformation import BaseLLMException
HF_HUB_URL = "https://huggingface.co"
class HuggingFaceError(BaseLLMException):
def __init__(
self,
status_code,
message,
request: Optional[httpx.Request] = None,
response: Optional[httpx.Response] = None,
headers: Optional[Union[httpx.Headers, dict]] = None,
):
super().__init__(
status_code=status_code,
message=message,
request=request,
response=response,
headers=headers,
)
hf_tasks = Literal[
"text-generation-inference",
"conversational",
"text-classification",
"text-generation",
]
hf_task_list = [
"text-generation-inference",
"conversational",
"text-classification",
"text-generation",
]
def output_parser(generated_text: str):
"""
    Strip leading/trailing chat-template special tokens from the generated text; currently only ChatML-style tokens are checked.
Initial issue that prompted this - https://github.com/BerriAI/litellm/issues/763
"""
chat_template_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "<s>", "</s>"]
for token in chat_template_tokens:
if generated_text.strip().startswith(token):
generated_text = generated_text.replace(token, "", 1)
if generated_text.endswith(token):
generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
return generated_text
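# Example of the stripping behavior (computed from the logic above):
#   output_parser("<|assistant|> Hello </s>") -> ' Hello '
# Only leading/trailing template tokens are removed; tokens embedded in the middle of the text are left untouched.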
@lru_cache(maxsize=128)
def _fetch_inference_provider_mapping(model: str) -> dict:
"""
Fetch provider mappings for a model from the Hugging Face Hub.
Args:
model: The model identifier (e.g., 'meta-llama/Llama-2-7b')
Returns:
dict: The inference provider mapping for the model
Raises:
ValueError: If no provider mapping is found
HuggingFaceError: If the API request fails
"""
headers = {"Accept": "application/json"}
if os.getenv("HUGGINGFACE_API_KEY"):
headers["Authorization"] = f"Bearer {os.getenv('HUGGINGFACE_API_KEY')}"
path = f"{HF_HUB_URL}/api/models/{model}"
params = {"expand": ["inferenceProviderMapping"]}
try:
response = httpx.get(path, headers=headers, params=params)
response.raise_for_status()
provider_mapping = response.json().get("inferenceProviderMapping")
if provider_mapping is None:
raise ValueError(f"No provider mapping found for model {model}")
return provider_mapping
except httpx.HTTPError as e:
if hasattr(e, "response"):
status_code = getattr(e.response, "status_code", 500)
headers = getattr(e.response, "headers", {})
else:
status_code = 500
headers = {}
raise HuggingFaceError(
message=f"Failed to fetch provider mapping: {str(e)}",
status_code=status_code,
headers=headers,
)
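# Usage sketch (requires network access to the Hub; the response shape below is indicative only):
#   mapping = _fetch_inference_provider_mapping("meta-llama/Llama-3.1-8B-Instruct")
#   # e.g. {"hf-inference": {"status": "live", "providerId": "meta-llama/Llama-3.1-8B-Instruct", ...}, ...}
# Results are memoised per-process via lru_cache, so repeated completions for the same model only hit the Hub once.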

View File

@@ -0,0 +1,425 @@
import json
import os
from typing import Any, Callable, Dict, List, Literal, Optional, Union, get_args
import httpx
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
get_async_httpx_client,
)
from litellm.types.utils import EmbeddingResponse
from ...base import BaseLLM
from ..common_utils import HuggingFaceError
from .transformation import HuggingFaceEmbeddingConfig
config = HuggingFaceEmbeddingConfig()
HF_HUB_URL = "https://huggingface.co"
hf_tasks_embeddings = Literal[ # pipeline tags + hf tei endpoints - https://huggingface.github.io/text-embeddings-inference/#/
"sentence-similarity", "feature-extraction", "rerank", "embed", "similarity"
]
def get_hf_task_embedding_for_model(
model: str, task_type: Optional[str], api_base: str
) -> Optional[str]:
if task_type is not None:
if task_type in get_args(hf_tasks_embeddings):
return task_type
else:
raise Exception(
"Invalid task_type={}. Expected one of={}".format(
task_type, hf_tasks_embeddings
)
)
http_client = HTTPHandler(concurrent_limit=1)
model_info = http_client.get(url=f"{api_base}/api/models/{model}")
model_info_dict = model_info.json()
pipeline_tag: Optional[str] = model_info_dict.get("pipeline_tag", None)
return pipeline_tag
async def async_get_hf_task_embedding_for_model(
model: str, task_type: Optional[str], api_base: str
) -> Optional[str]:
if task_type is not None:
if task_type in get_args(hf_tasks_embeddings):
return task_type
else:
raise Exception(
"Invalid task_type={}. Expected one of={}".format(
task_type, hf_tasks_embeddings
)
)
http_client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.HUGGINGFACE,
)
model_info = await http_client.get(url=f"{api_base}/api/models/{model}")
model_info_dict = model_info.json()
pipeline_tag: Optional[str] = model_info_dict.get("pipeline_tag", None)
return pipeline_tag
class HuggingFaceEmbedding(BaseLLM):
_client_session: Optional[httpx.Client] = None
_aclient_session: Optional[httpx.AsyncClient] = None
def __init__(self) -> None:
super().__init__()
def _transform_input_on_pipeline_tag(
self, input: List, pipeline_tag: Optional[str]
) -> dict:
if pipeline_tag is None:
return {"inputs": input}
if pipeline_tag == "sentence-similarity" or pipeline_tag == "similarity":
if len(input) < 2:
raise HuggingFaceError(
status_code=400,
message="sentence-similarity requires 2+ sentences",
)
return {"inputs": {"source_sentence": input[0], "sentences": input[1:]}}
elif pipeline_tag == "rerank":
if len(input) < 2:
raise HuggingFaceError(
status_code=400,
message="reranker requires 2+ sentences",
)
return {"inputs": {"query": input[0], "texts": input[1:]}}
return {"inputs": input} # default to feature-extraction pipeline tag
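    # Illustrative payloads produced above (inputs are example strings):
    #   pipeline_tag="sentence-similarity", input=["a", "b", "c"]
    #       -> {"inputs": {"source_sentence": "a", "sentences": ["b", "c"]}}
    #   pipeline_tag="rerank", input=["query", "doc1", "doc2"]
    #       -> {"inputs": {"query": "query", "texts": ["doc1", "doc2"]}}
    #   pipeline_tag=None or "feature-extraction"
    #       -> {"inputs": input}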
async def _async_transform_input(
self,
model: str,
task_type: Optional[str],
embed_url: str,
input: List,
optional_params: dict,
) -> dict:
hf_task = await async_get_hf_task_embedding_for_model(
model=model, task_type=task_type, api_base=HF_HUB_URL
)
data = self._transform_input_on_pipeline_tag(input=input, pipeline_tag=hf_task)
if len(optional_params.keys()) > 0:
data["options"] = optional_params
return data
def _process_optional_params(self, data: dict, optional_params: dict) -> dict:
special_options_keys = config.get_special_options_params()
special_parameters_keys = [
"min_length",
"max_length",
"top_k",
"top_p",
"temperature",
"repetition_penalty",
"max_time",
]
for k, v in optional_params.items():
if k in special_options_keys:
data.setdefault("options", {})
data["options"][k] = v
elif k in special_parameters_keys:
data.setdefault("parameters", {})
data["parameters"][k] = v
else:
data[k] = v
return data
def _transform_input(
self,
input: List,
model: str,
call_type: Literal["sync", "async"],
optional_params: dict,
embed_url: str,
) -> dict:
data: Dict = {}
## TRANSFORMATION ##
if "sentence-transformers" in model:
if len(input) == 0:
raise HuggingFaceError(
status_code=400,
message="sentence transformers requires 2+ sentences",
)
data = {"inputs": {"source_sentence": input[0], "sentences": input[1:]}}
else:
data = {"inputs": input}
task_type = optional_params.pop("input_type", None)
if call_type == "sync":
hf_task = get_hf_task_embedding_for_model(
model=model, task_type=task_type, api_base=HF_HUB_URL
)
elif call_type == "async":
return self._async_transform_input(
model=model, task_type=task_type, embed_url=embed_url, input=input
) # type: ignore
data = self._transform_input_on_pipeline_tag(
input=input, pipeline_tag=hf_task
)
if len(optional_params.keys()) > 0:
data = self._process_optional_params(
data=data, optional_params=optional_params
)
return data
def _process_embedding_response(
self,
embeddings: dict,
model_response: EmbeddingResponse,
model: str,
input: List,
encoding: Any,
) -> EmbeddingResponse:
output_data = []
if "similarities" in embeddings:
for idx, embedding in embeddings["similarities"]:
output_data.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding, # flatten list returned from hf
}
)
else:
for idx, embedding in enumerate(embeddings):
if isinstance(embedding, float):
output_data.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding, # flatten list returned from hf
}
)
elif isinstance(embedding, list) and isinstance(embedding[0], float):
output_data.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding, # flatten list returned from hf
}
)
else:
output_data.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding[0][
0
], # flatten list returned from hf
}
)
model_response.object = "list"
model_response.data = output_data
model_response.model = model
input_tokens = 0
for text in input:
input_tokens += len(encoding.encode(text))
setattr(
model_response,
"usage",
litellm.Usage(
prompt_tokens=input_tokens,
completion_tokens=input_tokens,
total_tokens=input_tokens,
prompt_tokens_details=None,
completion_tokens_details=None,
),
)
return model_response
async def aembedding(
self,
model: str,
input: list,
model_response: litellm.utils.EmbeddingResponse,
timeout: Union[float, httpx.Timeout],
logging_obj: LiteLLMLoggingObj,
optional_params: dict,
api_base: str,
api_key: Optional[str],
headers: dict,
encoding: Callable,
client: Optional[AsyncHTTPHandler] = None,
):
## TRANSFORMATION ##
data = self._transform_input(
input=input,
model=model,
call_type="sync",
optional_params=optional_params,
embed_url=api_base,
)
## LOGGING
logging_obj.pre_call(
input=input,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": api_base,
},
)
## COMPLETION CALL
if client is None:
client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.HUGGINGFACE,
)
response = await client.post(api_base, headers=headers, data=json.dumps(data))
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=response,
)
embeddings = response.json()
if "error" in embeddings:
raise HuggingFaceError(status_code=500, message=embeddings["error"])
## PROCESS RESPONSE ##
return self._process_embedding_response(
embeddings=embeddings,
model_response=model_response,
model=model,
input=input,
encoding=encoding,
)
def embedding(
self,
model: str,
input: list,
model_response: EmbeddingResponse,
optional_params: dict,
litellm_params: dict,
logging_obj: LiteLLMLoggingObj,
encoding: Callable,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
aembedding: Optional[bool] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
headers={},
) -> EmbeddingResponse:
super().embedding()
headers = config.validate_environment(
api_key=api_key,
headers=headers,
model=model,
optional_params=optional_params,
messages=[],
litellm_params=litellm_params,
)
task_type = optional_params.get("input_type", None)
task = get_hf_task_embedding_for_model(
model=model, task_type=task_type, api_base=HF_HUB_URL
)
# print_verbose(f"{model}, {task}")
embed_url = ""
if "https" in model:
embed_url = model
elif api_base:
embed_url = api_base
elif "HF_API_BASE" in os.environ:
embed_url = os.getenv("HF_API_BASE", "")
elif "HUGGINGFACE_API_BASE" in os.environ:
embed_url = os.getenv("HUGGINGFACE_API_BASE", "")
else:
embed_url = (
f"https://router.huggingface.co/hf-inference/pipeline/{task}/{model}"
)
## ROUTING ##
if aembedding is True:
return self.aembedding(
input=input,
model_response=model_response,
timeout=timeout,
logging_obj=logging_obj,
headers=headers,
api_base=embed_url, # type: ignore
api_key=api_key,
client=client if isinstance(client, AsyncHTTPHandler) else None,
model=model,
optional_params=optional_params,
encoding=encoding,
)
## TRANSFORMATION ##
data = self._transform_input(
input=input,
model=model,
call_type="sync",
optional_params=optional_params,
embed_url=embed_url,
)
## LOGGING
logging_obj.pre_call(
input=input,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": embed_url,
},
)
## COMPLETION CALL
if client is None or not isinstance(client, HTTPHandler):
client = HTTPHandler(concurrent_limit=1)
response = client.post(embed_url, headers=headers, data=json.dumps(data))
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=response,
)
embeddings = response.json()
if "error" in embeddings:
raise HuggingFaceError(status_code=500, message=embeddings["error"])
## PROCESS RESPONSE ##
return self._process_embedding_response(
embeddings=embeddings,
model_response=model_response,
model=model,
input=input,
encoding=encoding,
)
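    # Usage sketch via the public litellm API (model and input are illustrative; requires an HF token):
    #   import litellm
    #   resp = litellm.embedding(
    #       model="huggingface/BAAI/bge-small-en-v1.5",
    #       input=["good morning from litellm"],
    #   )
    # For "huggingface/"-prefixed models, litellm dispatches this into HuggingFaceEmbedding.embedding() above.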

View File

@@ -0,0 +1,590 @@
import json
import os
import time
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import httpx
import litellm
from litellm.litellm_core_utils.prompt_templates.common_utils import (
convert_content_list_to_str,
)
from litellm.litellm_core_utils.prompt_templates.factory import (
custom_prompt,
prompt_factory,
)
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import Choices, Message, ModelResponse, Usage
from litellm.utils import token_counter
from ..common_utils import HuggingFaceError, hf_task_list, hf_tasks, output_parser
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
LoggingClass = LiteLLMLoggingObj
else:
LoggingClass = Any
tgi_models_cache = None
conv_models_cache = None
class HuggingFaceEmbeddingConfig(BaseConfig):
"""
Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate
"""
hf_task: Optional[
hf_tasks
] = None # litellm-specific param, used to know the api spec to use when calling huggingface api
best_of: Optional[int] = None
decoder_input_details: Optional[bool] = None
details: Optional[bool] = True # enables returning logprobs + best of
max_new_tokens: Optional[int] = None
repetition_penalty: Optional[float] = None
return_full_text: Optional[
bool
] = False # by default don't return the input as part of the output
seed: Optional[int] = None
temperature: Optional[float] = None
top_k: Optional[int] = None
top_n_tokens: Optional[int] = None
top_p: Optional[int] = None
truncate: Optional[int] = None
typical_p: Optional[float] = None
watermark: Optional[bool] = None
def __init__(
self,
best_of: Optional[int] = None,
decoder_input_details: Optional[bool] = None,
details: Optional[bool] = None,
max_new_tokens: Optional[int] = None,
repetition_penalty: Optional[float] = None,
return_full_text: Optional[bool] = None,
seed: Optional[int] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_n_tokens: Optional[int] = None,
top_p: Optional[int] = None,
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: Optional[bool] = None,
) -> None:
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return super().get_config()
def get_special_options_params(self):
return ["use_cache", "wait_for_model"]
def get_supported_openai_params(self, model: str):
return [
"stream",
"temperature",
"max_tokens",
"max_completion_tokens",
"top_p",
"stop",
"n",
"echo",
]
def map_openai_params(
self,
non_default_params: Dict,
optional_params: Dict,
model: str,
drop_params: bool,
) -> Dict:
for param, value in non_default_params.items():
            # temperature, top_p, n, stream, stop, max_tokens, presence_penalty default to None
if param == "temperature":
if value == 0.0 or value == 0:
# hugging face exception raised when temp==0
# Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
value = 0.01
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "n":
optional_params["best_of"] = value
optional_params[
"do_sample"
] = True # Need to sample if you want best of for hf inference endpoints
if param == "stream":
optional_params["stream"] = value
if param == "stop":
optional_params["stop"] = value
if param == "max_tokens" or param == "max_completion_tokens":
# HF TGI raises the following exception when max_new_tokens==0
# Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
if value == 0:
value = 1
optional_params["max_new_tokens"] = value
if param == "echo":
# https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
# Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
optional_params["decoder_input_details"] = True
return optional_params
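    # Illustrative parameter mapping (starting from empty optional_params; derived from the branches above):
    #   {"temperature": 0, "n": 2, "max_tokens": 0}
    #       -> {"temperature": 0.01, "best_of": 2, "do_sample": True, "max_new_tokens": 1}
    #   {"echo": True} -> {"decoder_input_details": True}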
def get_hf_api_key(self) -> Optional[str]:
return get_secret_str("HUGGINGFACE_API_KEY")
def read_tgi_conv_models(self):
try:
global tgi_models_cache, conv_models_cache
# Check if the cache is already populated
# so we don't keep on reading txt file if there are 1k requests
if (tgi_models_cache is not None) and (conv_models_cache is not None):
return tgi_models_cache, conv_models_cache
# If not, read the file and populate the cache
tgi_models = set()
script_directory = os.path.dirname(os.path.abspath(__file__))
script_directory = os.path.dirname(script_directory)
# Construct the file path relative to the script's directory
file_path = os.path.join(
script_directory,
"huggingface_llms_metadata",
"hf_text_generation_models.txt",
)
with open(file_path, "r") as file:
for line in file:
tgi_models.add(line.strip())
# Cache the set for future use
tgi_models_cache = tgi_models
# If not, read the file and populate the cache
file_path = os.path.join(
script_directory,
"huggingface_llms_metadata",
"hf_conversational_models.txt",
)
conv_models = set()
with open(file_path, "r") as file:
for line in file:
conv_models.add(line.strip())
# Cache the set for future use
conv_models_cache = conv_models
return tgi_models, conv_models
except Exception:
return set(), set()
def get_hf_task_for_model(self, model: str) -> Tuple[hf_tasks, str]:
# read text file, cast it to set
# read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
if model.split("/")[0] in hf_task_list:
split_model = model.split("/", 1)
return split_model[0], split_model[1] # type: ignore
tgi_models, conversational_models = self.read_tgi_conv_models()
if model in tgi_models:
return "text-generation-inference", model
elif model in conversational_models:
return "conversational", model
elif "roneneldan/TinyStories" in model:
return "text-generation", model
else:
return "text-generation-inference", model # default to tgi
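    # Illustrative resolution (assuming the models below are not listed in the bundled metadata files):
    #   "conversational/facebook/blenderbot-400M-distill" -> ("conversational", "facebook/blenderbot-400M-distill")
    #   "roneneldan/TinyStories-1M"                       -> ("text-generation", "roneneldan/TinyStories-1M")
    #   "some-org/some-unknown-model"                     -> ("text-generation-inference", "some-org/some-unknown-model")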
def transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
task = litellm_params.get("task", None)
## VALIDATE API FORMAT
        if task is None or not isinstance(task, str) or task not in hf_task_list:
            raise Exception(
                "Invalid hf task - {}. Valid tasks - {}.".format(task, hf_task_list)
            )
## Load Config
config = litellm.HuggingFaceEmbeddingConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > huggingfaceConfig(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
### MAP INPUT PARAMS
#### HANDLE SPECIAL PARAMS
special_params = self.get_special_options_params()
special_params_dict = {}
# Create a list of keys to pop after iteration
keys_to_pop = []
for k, v in optional_params.items():
if k in special_params:
special_params_dict[k] = v
keys_to_pop.append(k)
# Pop the keys from the dictionary after iteration
for k in keys_to_pop:
optional_params.pop(k)
if task == "conversational":
inference_params = deepcopy(optional_params)
inference_params.pop("details")
inference_params.pop("return_full_text")
past_user_inputs = []
generated_responses = []
text = ""
for message in messages:
if message["role"] == "user":
if text != "":
past_user_inputs.append(text)
text = convert_content_list_to_str(message)
elif message["role"] == "assistant" or message["role"] == "system":
generated_responses.append(convert_content_list_to_str(message))
data = {
"inputs": {
"text": text,
"past_user_inputs": past_user_inputs,
"generated_responses": generated_responses,
},
"parameters": inference_params,
}
elif task == "text-generation-inference":
# always send "details" and "return_full_text" as params
if model in litellm.custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = litellm.custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details.get("roles") or {},
initial_prompt_value=model_prompt_details.get(
"initial_prompt_value", ""
),
final_prompt_value=model_prompt_details.get(
"final_prompt_value", ""
),
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages)
data = {
"inputs": prompt, # type: ignore
"parameters": optional_params,
"stream": ( # type: ignore
True
if "stream" in optional_params
and isinstance(optional_params["stream"], bool)
and optional_params["stream"] is True # type: ignore
else False
),
}
else:
            # Neither TGI nor conversational models.
            # This branch strips 'details' and 'return_full_text' from the params before sending.
if model in litellm.custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = litellm.custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details.get("roles", {}),
initial_prompt_value=model_prompt_details.get(
"initial_prompt_value", ""
),
final_prompt_value=model_prompt_details.get(
"final_prompt_value", ""
),
bos_token=model_prompt_details.get("bos_token", ""),
eos_token=model_prompt_details.get("eos_token", ""),
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages)
inference_params = deepcopy(optional_params)
inference_params.pop("details")
inference_params.pop("return_full_text")
data = {
"inputs": prompt, # type: ignore
}
if task == "text-generation-inference":
data["parameters"] = inference_params
data["stream"] = ( # type: ignore
True # type: ignore
if "stream" in optional_params and optional_params["stream"] is True
else False
)
### RE-ADD SPECIAL PARAMS
if len(special_params_dict.keys()) > 0:
data.update({"options": special_params_dict})
return data
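    # Illustrative request body for task="text-generation-inference" (prompt text is an example;
    # "parameters" typically includes the config defaults plus any mapped OpenAI params):
    #   {
    #       "inputs": "<prompt built by prompt_factory/custom_prompt>",
    #       "parameters": {"details": True, "return_full_text": False, ...},
    #       "stream": False,
    #   }
    # "use_cache" / "wait_for_model", if supplied, are re-attached under an "options" key.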
def get_api_base(self, api_base: Optional[str], model: str) -> str:
"""
Get the API base for the Huggingface API.
Do not add the chat/embedding/rerank extension here. Let the handler do this.
"""
if "https" in model:
completion_url = model
elif api_base is not None:
completion_url = api_base
elif "HF_API_BASE" in os.environ:
completion_url = os.getenv("HF_API_BASE", "")
elif "HUGGINGFACE_API_BASE" in os.environ:
completion_url = os.getenv("HUGGINGFACE_API_BASE", "")
else:
completion_url = f"https://api-inference.huggingface.co/models/{model}"
return completion_url
def validate_environment(
self,
headers: Dict,
model: str,
messages: List[AllMessageValues],
optional_params: Dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> Dict:
default_headers = {
"content-type": "application/json",
}
if api_key is not None:
default_headers[
"Authorization"
] = f"Bearer {api_key}" # Huggingface Inference Endpoint default is to accept bearer tokens
headers = {**headers, **default_headers}
return headers
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
return HuggingFaceError(
status_code=status_code, message=error_message, headers=headers
)
def _convert_streamed_response_to_complete_response(
self,
response: httpx.Response,
logging_obj: LoggingClass,
model: str,
data: dict,
api_key: Optional[str] = None,
) -> List[Dict[str, Any]]:
streamed_response = CustomStreamWrapper(
completion_stream=response.iter_lines(),
model=model,
custom_llm_provider="huggingface",
logging_obj=logging_obj,
)
content = ""
for chunk in streamed_response:
content += chunk["choices"][0]["delta"]["content"]
completion_response: List[Dict[str, Any]] = [{"generated_text": content}]
## LOGGING
logging_obj.post_call(
input=data,
api_key=api_key,
original_response=completion_response,
additional_args={"complete_input_dict": data},
)
return completion_response
def convert_to_model_response_object( # noqa: PLR0915
self,
completion_response: Union[List[Dict[str, Any]], Dict[str, Any]],
model_response: ModelResponse,
task: Optional[hf_tasks],
optional_params: dict,
encoding: Any,
messages: List[AllMessageValues],
model: str,
):
if task is None:
task = "text-generation-inference" # default to tgi
if task == "conversational":
if len(completion_response["generated_text"]) > 0: # type: ignore
model_response.choices[0].message.content = completion_response[ # type: ignore
"generated_text"
]
elif task == "text-generation-inference":
if (
not isinstance(completion_response, list)
or not isinstance(completion_response[0], dict)
or "generated_text" not in completion_response[0]
):
raise HuggingFaceError(
status_code=422,
message=f"response is not in expected format - {completion_response}",
headers=None,
)
if len(completion_response[0]["generated_text"]) > 0:
model_response.choices[0].message.content = output_parser( # type: ignore
completion_response[0]["generated_text"]
)
## GETTING LOGPROBS + FINISH REASON
if (
"details" in completion_response[0]
and "tokens" in completion_response[0]["details"]
):
model_response.choices[0].finish_reason = completion_response[0][
"details"
]["finish_reason"]
sum_logprob = 0
for token in completion_response[0]["details"]["tokens"]:
if token["logprob"] is not None:
sum_logprob += token["logprob"]
setattr(model_response.choices[0].message, "_logprob", sum_logprob) # type: ignore
if "best_of" in optional_params and optional_params["best_of"] > 1:
if (
"details" in completion_response[0]
and "best_of_sequences" in completion_response[0]["details"]
):
choices_list = []
for idx, item in enumerate(
completion_response[0]["details"]["best_of_sequences"]
):
sum_logprob = 0
for token in item["tokens"]:
if token["logprob"] is not None:
sum_logprob += token["logprob"]
if len(item["generated_text"]) > 0:
message_obj = Message(
content=output_parser(item["generated_text"]),
logprobs=sum_logprob,
)
else:
message_obj = Message(content=None)
choice_obj = Choices(
finish_reason=item["finish_reason"],
index=idx + 1,
message=message_obj,
)
choices_list.append(choice_obj)
model_response.choices.extend(choices_list)
elif task == "text-classification":
model_response.choices[0].message.content = json.dumps( # type: ignore
completion_response
)
else:
if (
isinstance(completion_response, list)
and len(completion_response[0]["generated_text"]) > 0
):
model_response.choices[0].message.content = output_parser( # type: ignore
completion_response[0]["generated_text"]
)
## CALCULATING USAGE
prompt_tokens = 0
try:
prompt_tokens = token_counter(model=model, messages=messages)
except Exception:
            # this should remain non-blocking; do not fail the response if calculating usage fails
pass
output_text = model_response["choices"][0]["message"].get("content", "")
if output_text is not None and len(output_text) > 0:
completion_tokens = 0
try:
completion_tokens = len(
encoding.encode(
model_response["choices"][0]["message"].get("content", "")
)
) ##[TODO] use the llama2 tokenizer here
except Exception:
                # this should remain non-blocking; do not fail the response if calculating usage fails
pass
else:
completion_tokens = 0
model_response.created = int(time.time())
model_response.model = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
model_response._hidden_params["original_response"] = completion_response
return model_response
def transform_response(
self,
model: str,
raw_response: httpx.Response,
model_response: ModelResponse,
logging_obj: LoggingClass,
request_data: Dict,
messages: List[AllMessageValues],
optional_params: Dict,
litellm_params: Dict,
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
## Some servers might return streaming responses even though stream was not set to true. (e.g. Baseten)
task = litellm_params.get("task", None)
is_streamed = False
if (
raw_response.__dict__["headers"].get("Content-Type", "")
== "text/event-stream"
):
is_streamed = True
# iterate over the complete streamed response, and return the final answer
if is_streamed:
completion_response = self._convert_streamed_response_to_complete_response(
response=raw_response,
logging_obj=logging_obj,
model=model,
data=request_data,
api_key=api_key,
)
else:
## LOGGING
logging_obj.post_call(
input=request_data,
api_key=api_key,
original_response=raw_response.text,
additional_args={"complete_input_dict": request_data},
)
## RESPONSE OBJECT
try:
completion_response = raw_response.json()
if isinstance(completion_response, dict):
completion_response = [completion_response]
except Exception:
raise HuggingFaceError(
message=f"Original Response received: {raw_response.text}",
status_code=raw_response.status_code,
)
if isinstance(completion_response, dict) and "error" in completion_response:
raise HuggingFaceError(
message=completion_response["error"], # type: ignore
status_code=raw_response.status_code,
)
return self.convert_to_model_response_object(
completion_response=completion_response,
model_response=model_response,
task=task if task is not None and task in hf_task_list else None,
optional_params=optional_params,
encoding=encoding,
messages=messages,
model=model,
)

View File

@@ -0,0 +1,5 @@
"""
HuggingFace Rerank - uses `llm_http_handler.py` to make httpx requests
Request/Response transformation is handled in `transformation.py`
"""

View File

@@ -0,0 +1,301 @@
import os
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import httpx
from typing_extensions import TypedDict
import litellm
from litellm._uuid import uuid
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.rerank import (
OptionalRerankParams,
RerankBilledUnits,
RerankResponse,
RerankResponseDocument,
RerankResponseMeta,
RerankResponseResult,
RerankTokens,
)
from litellm.utils import token_counter
from ..common_utils import HuggingFaceError
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
LoggingClass = LiteLLMLoggingObj
else:
LoggingClass = Any
class HuggingFaceRerankResponseItem(TypedDict):
"""Type definition for HuggingFace rerank API response items."""
index: int
score: float
text: Optional[str] # Optional, included when return_text=True
class HuggingFaceRerankResponse(TypedDict):
"""Type definition for HuggingFace rerank API complete response."""
# The response is a list of HuggingFaceRerankResponseItem
pass
# Type alias for the actual response structure
HuggingFaceRerankResponseList = List[HuggingFaceRerankResponseItem]
class HuggingFaceRerankConfig(BaseRerankConfig):
def get_api_base(self, model: str, api_base: Optional[str]) -> str:
if api_base is not None:
return api_base
elif os.getenv("HF_API_BASE") is not None:
return os.getenv("HF_API_BASE", "")
elif os.getenv("HUGGINGFACE_API_BASE") is not None:
return os.getenv("HUGGINGFACE_API_BASE", "")
else:
return "https://api-inference.huggingface.co"
def get_complete_url(
self,
api_base: Optional[str],
model: str,
optional_params: Optional[dict] = None,
) -> str:
"""
Get the complete URL for the API call, including the /rerank suffix if necessary.
"""
# Get base URL from api_base or default
base_url = self.get_api_base(model=model, api_base=api_base)
# Remove trailing slashes and ensure we have the /rerank endpoint
base_url = base_url.rstrip("/")
if not base_url.endswith("/rerank"):
base_url = f"{base_url}/rerank"
return base_url
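    # Illustrative URLs (the first host is a hypothetical TEI deployment):
    #   api_base="https://my-tei-endpoint.example"  -> "https://my-tei-endpoint.example/rerank"
    #   api_base=None, no env overrides             -> "https://api-inference.huggingface.co/rerank"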
def get_supported_cohere_rerank_params(self, model: str) -> list:
return [
"query",
"documents",
"top_n",
"return_documents",
]
def map_cohere_rerank_params(
self,
non_default_params: Optional[dict],
model: str,
drop_params: bool,
query: str,
documents: List[Union[str, Dict[str, Any]]],
custom_llm_provider: Optional[str] = None,
top_n: Optional[int] = None,
rank_fields: Optional[List[str]] = None,
return_documents: Optional[bool] = True,
max_chunks_per_doc: Optional[int] = None,
max_tokens_per_doc: Optional[int] = None,
) -> Dict:
optional_rerank_params = {}
if non_default_params is not None:
for k, v in non_default_params.items():
if k == "documents" and v is not None:
optional_rerank_params["texts"] = v
elif k == "return_documents" and v is not None and isinstance(v, bool):
optional_rerank_params["return_text"] = v
elif k == "top_n" and v is not None:
optional_rerank_params["top_n"] = v
elif k == "query" and v is not None:
optional_rerank_params["query"] = v
return OptionalRerankParams(**optional_rerank_params) # type: ignore
def validate_environment(
self,
headers: dict,
model: str,
api_key: Optional[str] = None,
optional_params: Optional[dict] = None,
api_base: Optional[str] = None,
) -> dict:
# Get API credentials
api_key, api_base = self.get_api_credentials(api_key=api_key, api_base=api_base)
default_headers = {
"accept": "application/json",
"content-type": "application/json",
}
if api_key:
default_headers["Authorization"] = f"Bearer {api_key}"
if "Authorization" in headers:
default_headers["Authorization"] = headers["Authorization"]
return {**default_headers, **headers}
def transform_rerank_request(
self,
model: str,
optional_rerank_params: Union[OptionalRerankParams, dict],
headers: dict,
) -> dict:
if "query" not in optional_rerank_params:
raise ValueError("query is required for HuggingFace rerank")
if "texts" not in optional_rerank_params:
raise ValueError(
"Cohere 'documents' param is required for HuggingFace rerank"
)
        # Build the request body with HF TEI defaults, then overlay the mapped params
        # (the Cohere-style `return_documents` has already been mapped to `return_text`).
request_body = {
"raw_scores": False,
"truncate": False,
"truncation_direction": "Right",
}
request_body.update(optional_rerank_params)
return request_body
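    # Illustrative final request body (query/documents are examples), after the Cohere-style params
    # have been mapped by map_cohere_rerank_params above:
    #   {
    #       "raw_scores": False, "truncate": False, "truncation_direction": "Right",
    #       "query": "what is panda?", "texts": ["the giant panda ...", "python is a language"],
    #       "return_text": True, "top_n": 1,
    #   }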
def transform_rerank_response(
self,
model: str,
raw_response: httpx.Response,
model_response: RerankResponse,
logging_obj: LoggingClass,
api_key: Optional[str] = None,
request_data: dict = {},
optional_params: dict = {},
litellm_params: dict = {},
) -> RerankResponse:
try:
raw_response_json: HuggingFaceRerankResponseList = raw_response.json()
except Exception:
raise HuggingFaceError(
message=getattr(raw_response, "text", str(raw_response)),
status_code=getattr(raw_response, "status_code", 500),
)
# Use standard litellm token counter for proper token estimation
input_text = request_data.get("query", "")
try:
# Calculate tokens for the raw response JSON string
response_text = str(raw_response_json)
estimated_output_tokens = token_counter(model=model, text=response_text)
# Calculate input tokens from query and documents
query = request_data.get("query", "")
documents = request_data.get("texts", [])
# Convert documents to string if they're not already
documents_text = ""
for doc in documents:
if isinstance(doc, str):
documents_text += doc + " "
elif isinstance(doc, dict) and "text" in doc:
documents_text += doc["text"] + " "
# Calculate input tokens using the same model
input_text = query + " " + documents_text
estimated_input_tokens = token_counter(model=model, text=input_text)
except Exception:
# Fallback to reasonable estimates if token counting fails
estimated_output_tokens = (
len(raw_response_json) * 10 if raw_response_json else 10
)
estimated_input_tokens = (
len(input_text) * 4 if "input_text" in locals() else 0
)
_billed_units = RerankBilledUnits(search_units=1)
_tokens = RerankTokens(
input_tokens=estimated_input_tokens, output_tokens=estimated_output_tokens
)
rerank_meta = RerankResponseMeta(
api_version={"version": "1.0"}, billed_units=_billed_units, tokens=_tokens
)
# Check if documents should be returned based on request parameters
should_return_documents = request_data.get(
"return_text", False
) or request_data.get("return_documents", False)
original_documents = request_data.get("texts", [])
results = []
for item in raw_response_json:
# Extract required fields with defaults to handle None values
index = item.get("index")
score = item.get("score")
# Skip items that don't have required fields
if index is None or score is None:
continue
# Create RerankResponseResult with required fields
result = RerankResponseResult(index=index, relevance_score=score)
# Add optional document field if needed
if should_return_documents:
text_content = item.get("text", "")
# 1. First try to use text returned directly from API if available
if text_content:
result["document"] = RerankResponseDocument(text=text_content)
# 2. If no text in API response but original documents are available, use those
elif original_documents and 0 <= item.get("index", -1) < len(
original_documents
):
doc = original_documents[item.get("index")]
if isinstance(doc, str):
result["document"] = RerankResponseDocument(text=doc)
elif isinstance(doc, dict) and "text" in doc:
result["document"] = RerankResponseDocument(text=doc["text"])
results.append(result)
return RerankResponse(
id=str(uuid.uuid4()),
results=results,
meta=rerank_meta,
)
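    # Illustrative transformation (scores are examples): a raw TEI response like
    #   [{"index": 0, "score": 0.98, "text": "the giant panda ..."}, {"index": 1, "score": 0.02}]
    # becomes RerankResponse.results of roughly
    #   [{"index": 0, "relevance_score": 0.98, "document": {"text": "the giant panda ..."}},
    #    {"index": 1, "relevance_score": 0.02, ...}]
    # with documents included only when return_text/return_documents was requested.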
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
return HuggingFaceError(message=error_message, status_code=status_code)
def get_api_credentials(
self,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> Tuple[Optional[str], Optional[str]]:
"""
Get API key and base URL from multiple sources.
Returns tuple of (api_key, api_base).
Parameters:
api_key: API key provided directly to this function, takes precedence over all other sources
api_base: API base provided directly to this function, takes precedence over all other sources
"""
# Get API key from multiple sources
final_api_key = (
api_key or litellm.huggingface_key or get_secret_str("HUGGINGFACE_API_KEY")
)
# Get API base from multiple sources
final_api_base = (
api_base
or litellm.api_base
or get_secret_str("HF_API_BASE")
or get_secret_str("HUGGINGFACE_API_BASE")
)
return final_api_key, final_api_base
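    # Usage sketch via the public litellm API (model, query, and documents are illustrative;
    # requires an HF token or a self-hosted TEI endpoint passed as api_base):
    #   import litellm
    #   resp = litellm.rerank(
    #       model="huggingface/BAAI/bge-reranker-base",
    #       query="what is panda?",
    #       documents=["the giant panda is a bear native to China", "python is a language"],
    #       top_n=1,
    #       return_documents=True,
    #   )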