chore: initial public snapshot for github upload

commit 0e5ecd930e
Author: Your Name
Date: 2026-03-26 20:06:14 +08:00
3497 changed files with 1586236 additions and 0 deletions


@@ -0,0 +1,373 @@
import json
import time
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union
import httpx
import litellm
from litellm.litellm_core_utils.prompt_templates.factory import cohere_messages_pt_v2
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse, Usage
from ..common_utils import ModelResponseIterator as CohereModelResponseIterator
from ..common_utils import validate_environment as cohere_validate_environment
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class CohereError(BaseLLMException):
def __init__(
self,
status_code: int,
message: str,
headers: Optional[httpx.Headers] = None,
):
self.status_code = status_code
self.message = message
self.request = httpx.Request(method="POST", url="https://api.cohere.ai/v1/chat")
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
status_code=status_code,
message=message,
headers=headers,
)
class CohereChatConfig(BaseConfig):
"""
Configuration class for Cohere's API interface.
Args:
preamble (str, optional): When specified, the default Cohere preamble will be replaced with the provided one.
chat_history (List[Dict[str, str]], optional): A list of previous messages between the user and the model.
generation_id (str, optional): Unique identifier for the generated reply.
response_id (str, optional): Unique identifier for the response.
conversation_id (str, optional): An alternative to chat_history, creates or resumes a persisted conversation.
prompt_truncation (str, optional): Dictates how the prompt will be constructed. Options: 'AUTO', 'AUTO_PRESERVE_ORDER', 'OFF'.
connectors (List[Dict[str, str]], optional): List of connectors (e.g., web-search) to enrich the model's reply.
search_queries_only (bool, optional): When true, the response will only contain a list of generated search queries.
documents (List[Dict[str, str]], optional): A list of relevant documents that the model can cite.
temperature (float, optional): A non-negative float that tunes the degree of randomness in generation.
max_tokens [DEPRECATED - use max_completion_tokens] (int, optional): The maximum number of tokens the model will generate as part of the response.
max_completion_tokens (int, optional): The maximum number of tokens the model will generate as part of the response.
k (int, optional): Ensures only the top k most likely tokens are considered for generation at each step.
p (float, optional): Ensures that only the most likely tokens, with total probability mass of p, are considered for generation.
frequency_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
presence_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
tools (List[Dict[str, str]], optional): A list of available tools (functions) that the model may suggest invoking.
tool_results (List[Dict[str, Any]], optional): A list of results from invoking tools.
seed (int, optional): A seed to assist reproducibility of the model's response.
"""
preamble: Optional[str] = None
chat_history: Optional[list] = None
generation_id: Optional[str] = None
response_id: Optional[str] = None
conversation_id: Optional[str] = None
prompt_truncation: Optional[str] = None
connectors: Optional[list] = None
search_queries_only: Optional[bool] = None
documents: Optional[list] = None
temperature: Optional[float] = None
max_tokens: Optional[int] = None
max_completion_tokens: Optional[int] = None
k: Optional[int] = None
p: Optional[float] = None
frequency_penalty: Optional[float] = None
presence_penalty: Optional[float] = None
tools: Optional[list] = None
tool_results: Optional[list] = None
seed: Optional[int] = None
def __init__(
self,
preamble: Optional[str] = None,
chat_history: Optional[list] = None,
generation_id: Optional[str] = None,
response_id: Optional[str] = None,
conversation_id: Optional[str] = None,
prompt_truncation: Optional[str] = None,
connectors: Optional[list] = None,
search_queries_only: Optional[bool] = None,
documents: Optional[list] = None,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
max_completion_tokens: Optional[int] = None,
k: Optional[int] = None,
p: Optional[float] = None,
frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
tools: Optional[list] = None,
tool_results: Optional[list] = None,
seed: Optional[int] = None,
) -> None:
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
return cohere_validate_environment(
headers=headers,
model=model,
messages=messages,
optional_params=optional_params,
api_key=api_key,
)
def get_supported_openai_params(self, model: str) -> List[str]:
return [
"stream",
"temperature",
"max_tokens",
"max_completion_tokens",
"top_p",
"frequency_penalty",
"presence_penalty",
"stop",
"n",
"tools",
"tool_choice",
"seed",
"extra_headers",
]
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
for param, value in non_default_params.items():
if param == "stream":
optional_params["stream"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "max_tokens":
optional_params["max_tokens"] = value
if param == "max_completion_tokens":
optional_params["max_tokens"] = value
if param == "n":
optional_params["num_generations"] = value
if param == "top_p":
optional_params["p"] = value
if param == "frequency_penalty":
optional_params["frequency_penalty"] = value
if param == "presence_penalty":
optional_params["presence_penalty"] = value
if param == "stop":
optional_params["stop_sequences"] = value
if param == "tools":
optional_params["tools"] = value
if param == "seed":
optional_params["seed"] = value
return optional_params
def transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
## Load Config
for k, v in litellm.CohereChatConfig.get_config().items():
if (
k not in optional_params
): # completion(top_k=3) takes precedence over cohere_config(top_k=3) - allows dynamically passed params to win
optional_params[k] = v
most_recent_message, chat_history = cohere_messages_pt_v2(
messages=messages, model=model, llm_provider="cohere_chat"
)
## Handle Tool Calling
if "tools" in optional_params:
_is_function_call = True
cohere_tools = self._construct_cohere_tool(tools=optional_params["tools"])
optional_params["tools"] = cohere_tools
if isinstance(most_recent_message, dict):
optional_params["tool_results"] = [most_recent_message]
elif isinstance(most_recent_message, str):
optional_params["message"] = most_recent_message
## if the last chat history message is from the 'USER' and 'tool_results' are provided, set force_single_step=True - otherwise the Cohere API fails
if len(chat_history) > 0 and chat_history[-1]["role"] == "USER":
optional_params["force_single_step"] = True
return optional_params
def transform_response(
self,
model: str,
raw_response: httpx.Response,
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
try:
raw_response_json = raw_response.json()
model_response.choices[0].message.content = raw_response_json["text"] # type: ignore
except Exception:
raise CohereError(
message=raw_response.text, status_code=raw_response.status_code
)
## ADD CITATIONS
if "citations" in raw_response_json:
setattr(model_response, "citations", raw_response_json["citations"])
## Tool calling response
cohere_tools_response = raw_response_json.get("tool_calls", None)
if cohere_tools_response is not None and cohere_tools_response != []:
# convert cohere_tools_response to OpenAI response format
tool_calls = []
for tool in cohere_tools_response:
function_name = tool.get("name", "")
generation_id = tool.get("generation_id", "")
parameters = tool.get("parameters", {})
tool_call = {
"id": f"call_{generation_id}",
"type": "function",
"function": {
"name": function_name,
"arguments": json.dumps(parameters),
},
}
tool_calls.append(tool_call)
_message = litellm.Message(
tool_calls=tool_calls,
content=None,
)
model_response.choices[0].message = _message # type: ignore
## CALCULATING USAGE - use cohere `billed_units` for returning usage
billed_units = raw_response_json.get("meta", {}).get("billed_units", {})
prompt_tokens = billed_units.get("input_tokens", 0)
completion_tokens = billed_units.get("output_tokens", 0)
model_response.created = int(time.time())
model_response.model = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
return model_response
def _construct_cohere_tool(
self,
tools: Optional[list] = None,
):
if tools is None:
tools = []
cohere_tools = []
for tool in tools:
cohere_tool = self._translate_openai_tool_to_cohere(tool)
cohere_tools.append(cohere_tool)
return cohere_tools
def _translate_openai_tool_to_cohere(
self,
openai_tool: dict,
):
# cohere tools look like this
"""
{
"name": "query_daily_sales_report",
"description": "Connects to a database to retrieve overall sales volumes and sales information for a given day.",
"parameter_definitions": {
"day": {
"description": "Retrieves sales data for this day, formatted as YYYY-MM-DD.",
"type": "str",
"required": True
}
}
}
"""
# OpenAI tools look like this
"""
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
"""
cohere_tool = {
"name": openai_tool["function"]["name"],
"description": openai_tool["function"]["description"],
"parameter_definitions": {},
}
for param_name, param_def in openai_tool["function"]["parameters"][
"properties"
].items():
required_params = (
openai_tool.get("function", {})
.get("parameters", {})
.get("required", [])
)
cohere_param_def = {
"description": param_def.get("description", ""),
"type": param_def.get("type", ""),
"required": param_name in required_params,
}
cohere_tool["parameter_definitions"][param_name] = cohere_param_def
return cohere_tool
def get_model_response_iterator(
self,
streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
sync_stream: bool,
json_mode: Optional[bool] = False,
):
return CohereModelResponseIterator(
streaming_response=streaming_response,
sync_stream=sync_stream,
json_mode=json_mode,
)
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
return CohereError(status_code=status_code, message=error_message)
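For orientation, a minimal usage sketch of the chat config above, assuming a working litellm install with `COHERE_API_KEY` exported; the model string and tool definition are illustrative, and the OpenAI-style tool shown is the shape `_construct_cohere_tool` translates into Cohere's `parameter_definitions` format.

```python
import litellm

# Minimal sketch: OpenAI-style params are remapped by CohereChatConfig.map_openai_params
# (e.g. top_p -> p, stop -> stop_sequences, n -> num_generations) before the request is sent.
response = litellm.completion(
    model="cohere_chat/command-r",  # illustrative model string
    messages=[{"role": "user", "content": "What's the weather in San Francisco?"}],
    max_tokens=100,
    top_p=0.9,
    seed=42,
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {"type": "string", "description": "City and state"}
                    },
                    "required": ["location"],
                },
            },
        }
    ],
)
print(response.choices[0].message)
```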


@@ -0,0 +1,364 @@
import time
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union
import httpx
import litellm
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.llms.cohere import CohereV2ChatResponse
from litellm.types.llms.openai import (
AllMessageValues,
ChatCompletionToolCallChunk,
ChatCompletionAnnotation,
ChatCompletionAnnotationURLCitation,
)
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
from litellm.types.utils import ModelResponse, Usage
from ..common_utils import CohereError
from ..common_utils import CohereV2ModelResponseIterator
from ..common_utils import validate_environment as cohere_validate_environment
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class CohereV2ChatConfig(OpenAIGPTConfig):
"""
Configuration class for Cohere's API interface.
Args:
preamble (str, optional): When specified, the default Cohere preamble will be replaced with the provided one.
chat_history (List[Dict[str, str]], optional): A list of previous messages between the user and the model.
generation_id (str, optional): Unique identifier for the generated reply.
response_id (str, optional): Unique identifier for the response.
conversation_id (str, optional): An alternative to chat_history, creates or resumes a persisted conversation.
prompt_truncation (str, optional): Dictates how the prompt will be constructed. Options: 'AUTO', 'AUTO_PRESERVE_ORDER', 'OFF'.
connectors (List[Dict[str, str]], optional): List of connectors (e.g., web-search) to enrich the model's reply.
search_queries_only (bool, optional): When true, the response will only contain a list of generated search queries.
documents (List[Dict[str, str]], optional): A list of relevant documents that the model can cite.
temperature (float, optional): A non-negative float that tunes the degree of randomness in generation.
max_tokens (int, optional): The maximum number of tokens the model will generate as part of the response.
k (int, optional): Ensures only the top k most likely tokens are considered for generation at each step.
p (float, optional): Ensures that only the most likely tokens, with total probability mass of p, are considered for generation.
frequency_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
presence_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
tools (List[Dict[str, str]], optional): A list of available tools (functions) that the model may suggest invoking.
tool_results (List[Dict[str, Any]], optional): A list of results from invoking tools.
seed (int, optional): A seed to assist reproducibility of the model's response.
"""
preamble: Optional[str] = None
chat_history: Optional[list] = None
generation_id: Optional[str] = None
response_id: Optional[str] = None
conversation_id: Optional[str] = None
prompt_truncation: Optional[str] = None
connectors: Optional[list] = None
search_queries_only: Optional[bool] = None
documents: Optional[list] = None
temperature: Optional[float] = None
max_tokens: Optional[int] = None
k: Optional[int] = None
p: Optional[float] = None
frequency_penalty: Optional[float] = None
presence_penalty: Optional[float] = None
tools: Optional[list] = None
tool_results: Optional[list] = None
seed: Optional[int] = None
def __init__(
self,
preamble: Optional[str] = None,
chat_history: Optional[list] = None,
generation_id: Optional[str] = None,
response_id: Optional[str] = None,
conversation_id: Optional[str] = None,
prompt_truncation: Optional[str] = None,
connectors: Optional[list] = None,
search_queries_only: Optional[bool] = None,
documents: Optional[list] = None,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
k: Optional[int] = None,
p: Optional[float] = None,
frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
tools: Optional[list] = None,
tool_results: Optional[list] = None,
seed: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
return cohere_validate_environment(
headers=headers,
model=model,
messages=messages,
optional_params=optional_params,
api_key=api_key,
)
def get_supported_openai_params(self, model: str) -> List[str]:
return [
"stream",
"temperature",
"max_tokens",
"top_p",
"frequency_penalty",
"presence_penalty",
"stop",
"n",
"tools",
"tool_choice",
"seed",
"extra_headers",
]
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
for param, value in non_default_params.items():
if param == "stream":
optional_params["stream"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "max_tokens":
optional_params["max_tokens"] = value
if param == "n":
optional_params["num_generations"] = value
if param == "top_p":
optional_params["p"] = value
if param == "frequency_penalty":
optional_params["frequency_penalty"] = value
if param == "presence_penalty":
optional_params["presence_penalty"] = value
if param == "stop":
optional_params["stop_sequences"] = value
if param == "tools":
optional_params["tools"] = value
if param == "seed":
optional_params["seed"] = value
return optional_params
def transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
"""
The Cohere v2 chat API follows the OpenAI format, so we can reuse the OpenAI transform_request logic.
"""
data = super().transform_request(
model, messages, optional_params, litellm_params, headers
)
return data
def transform_response(
self,
model: str,
raw_response: httpx.Response,
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
try:
raw_response_json = raw_response.json()
except Exception:
raise CohereError(
message=raw_response.text, status_code=raw_response.status_code
)
try:
cohere_v2_chat_response = CohereV2ChatResponse(**raw_response_json) # type: ignore
except Exception:
raise CohereError(message=raw_response.text, status_code=422)
cohere_content = cohere_v2_chat_response["message"].get("content", None)
if cohere_content is not None:
model_response.choices[0].message.content = "".join( # type: ignore
[
content.get("text", "")
for content in cohere_content
if content is not None
]
)
## ADD CITATIONS AS ANNOTATIONS
annotations: Optional[List[ChatCompletionAnnotation]] = None
citations = None
if (
"message" in cohere_v2_chat_response
and "citations" in cohere_v2_chat_response["message"]
):
citations = cohere_v2_chat_response["message"]["citations"]
if citations:
annotations = self._translate_citations_to_openai_annotations(citations)
## Tool calling response
cohere_tools_response = cohere_v2_chat_response["message"].get("tool_calls", [])
if cohere_tools_response is not None and cohere_tools_response != []:
# convert cohere_tools_response to OpenAI response format
tool_calls: List[ChatCompletionToolCallChunk] = []
for index, tool in enumerate(cohere_tools_response):
tool_call: ChatCompletionToolCallChunk = {
**tool, # type: ignore
"index": index,
}
tool_calls.append(tool_call)
_message = litellm.Message(
tool_calls=tool_calls,
content=None,
annotations=annotations,
)
model_response.choices[0].message = _message # type: ignore
else:
if annotations:
current_message = model_response.choices[0].message # type: ignore
current_message.annotations = annotations
## CALCULATING USAGE - use cohere `billed_units` for returning usage
token_usage = cohere_v2_chat_response["usage"].get("tokens", {})
prompt_tokens = token_usage.get("input_tokens", 0)
completion_tokens = token_usage.get("output_tokens", 0)
model_response.created = int(time.time())
model_response.model = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
return model_response
def get_model_response_iterator(
self,
streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
sync_stream: bool,
json_mode: Optional[bool] = False,
):
return CohereV2ModelResponseIterator(
streaming_response=streaming_response,
sync_stream=sync_stream,
json_mode=json_mode,
)
def get_complete_url(
self,
api_base: Optional[str],
api_key: Optional[str],
model: str,
optional_params: dict,
litellm_params: dict,
stream: Optional[bool] = None,
) -> str:
"""
Get the complete URL for Cohere v2 chat completion.
The api_base should already include the full path.
"""
if api_base is None:
raise ValueError("api_base is required")
return api_base
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
return CohereError(status_code=status_code, message=error_message)
def _translate_citations_to_openai_annotations(
self, citations: List[dict]
) -> List[ChatCompletionAnnotation]:
"""
Transform Cohere citations to OpenAI annotations format.
Creates separate annotations for each source in a citation, allowing multiple
annotations with the same start/end index if they reference different sources.
Args:
citations: List of Cohere citation objects with format:
{
"start": int,
"end": int,
"text": str,
"sources": [
{
"type": "document",
"document": {
"title": str,
"snippet": str,
...
},
"id": str
}
]
}
Returns:
List of OpenAI ChatCompletionAnnotation objects (one per source)
"""
annotations: List[ChatCompletionAnnotation] = []
for citation in citations:
start_index = citation.get("start", 0)
end_index = citation.get("end", 0)
# Extract source information - loop through all sources
sources = citation.get("sources", [])
if not sources:
continue
# Create an annotation for each source
for source in sources:
if source.get("type") == "document" and "document" in source:
document = source["document"]
title = document.get("title", "")
url = source.get("url") or f"source:{source.get('id', 'unknown')}"
url_citation: ChatCompletionAnnotationURLCitation = {
"start_index": start_index,
"end_index": end_index,
"title": title,
"url": url,
}
annotation: ChatCompletionAnnotation = {
"type": "url_citation",
"url_citation": url_citation,
}
annotations.append(annotation)
return annotations
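A standalone sketch of the citation-to-annotation mapping implemented by `_translate_citations_to_openai_annotations` above, using plain dicts so it runs without litellm; the sample citation values are illustrative.

```python
# Sample Cohere v2 citation (shape taken from the docstring above); values are illustrative.
citation = {
    "start": 0,
    "end": 30,
    "text": "Paris is the capital of France",
    "sources": [
        {"type": "document", "id": "doc_0", "document": {"title": "France", "snippet": "..."}}
    ],
}

annotations = []
for source in citation.get("sources", []):
    if source.get("type") == "document" and "document" in source:
        annotations.append(
            {
                "type": "url_citation",
                "url_citation": {
                    "start_index": citation.get("start", 0),
                    "end_index": citation.get("end", 0),
                    "title": source["document"].get("title", ""),
                    # Document sources carry no URL, so a synthetic "source:<id>" is used.
                    "url": source.get("url") or f"source:{source.get('id', 'unknown')}",
                },
            }
        )

print(annotations)
# [{'type': 'url_citation', 'url_citation': {'start_index': 0, 'end_index': 30,
#   'title': 'France', 'url': 'source:doc_0'}}]
```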


@@ -0,0 +1,417 @@
import json
from typing import List, Optional, Literal, Tuple
from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import (
ChatCompletionToolCallChunk,
ChatCompletionUsageBlock,
GenericStreamingChunk,
ProviderSpecificModelInfo,
)
class CohereError(BaseLLMException):
def __init__(self, status_code, message):
super().__init__(status_code=status_code, message=message)
class CohereModelInfo(BaseLLMModelInfo):
def get_provider_info(
self,
model: str,
) -> Optional[ProviderSpecificModelInfo]:
"""
Default values all models of this provider support.
"""
return None
def get_models(
self, api_key: Optional[str] = None, api_base: Optional[str] = None
) -> List[str]:
"""
Returns a list of models supported by this provider.
"""
return []
@staticmethod
def get_api_key(api_key: Optional[str] = None) -> Optional[str]:
return api_key
@staticmethod
def get_api_base(
api_base: Optional[str] = None,
) -> Optional[str]:
return api_base
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
return {}
@staticmethod
def get_base_model(model: str) -> Optional[str]:
"""
Returns the base model name from the given model name.
Some providers, like Bedrock, can receive model=`invoke/anthropic.claude-3-opus-20240229-v1:0` or `converse/anthropic.claude-3-opus-20240229-v1:0`.
This function will return `anthropic.claude-3-opus-20240229-v1:0`
"""
pass
@staticmethod
def get_cohere_route(model: str) -> Literal["v1", "v2"]:
"""
Get the Cohere route for the given model.
Args:
model: The model name (e.g., "cohere_chat/v2/command-r-plus", "command-r-plus")
Returns:
"v2" for standard Cohere v2 API (default), "v1" for Cohere v1 API
"""
# Check for explicit v1 route
if "v1/" in model:
return "v1"
# Default to v2 for all other cases
return "v2"
def validate_environment(
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
) -> dict:
"""
Return headers to use for cohere chat completion request
Cohere API Ref: https://docs.cohere.com/reference/chat
Expected headers:
{
"Request-Source": "unspecified:litellm",
"accept": "application/json",
"content-type": "application/json",
"Authorization": "Bearer $CO_API_KEY"
}
"""
headers.update(
{
"Request-Source": "unspecified:litellm",
"accept": "application/json",
"content-type": "application/json",
}
)
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
return headers
class ModelResponseIterator:
def __init__(
self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
):
self.streaming_response = streaming_response
self.response_iterator = self.streaming_response
self.content_blocks: List = []
self.tool_index = -1
self.json_mode = json_mode
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
try:
text = ""
tool_use: Optional[ChatCompletionToolCallChunk] = None
is_finished = False
finish_reason = ""
usage: Optional[ChatCompletionUsageBlock] = None
provider_specific_fields = None
index = int(chunk.get("index", 0))
if "text" in chunk:
text = chunk["text"]
elif "is_finished" in chunk and chunk["is_finished"] is True:
is_finished = chunk["is_finished"]
finish_reason = chunk["finish_reason"]
if "citations" in chunk:
provider_specific_fields = {"citations": chunk["citations"]}
returned_chunk = GenericStreamingChunk(
text=text,
tool_use=tool_use,
is_finished=is_finished,
finish_reason=finish_reason,
usage=usage,
index=index,
provider_specific_fields=provider_specific_fields,
)
return returned_chunk
except json.JSONDecodeError:
raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
# Sync iterator
def __iter__(self):
return self
def __next__(self):
try:
chunk = self.response_iterator.__next__()
except StopIteration:
raise StopIteration
except ValueError as e:
raise RuntimeError(f"Error receiving chunk from stream: {e}")
try:
return self.convert_str_chunk_to_generic_chunk(chunk=chunk)
except StopIteration:
raise StopIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
def convert_str_chunk_to_generic_chunk(self, chunk: str) -> GenericStreamingChunk:
"""
Convert a string chunk to a GenericStreamingChunk
Note: This is used for Cohere pass through streaming logging
"""
str_line = chunk
if isinstance(chunk, bytes): # Handle binary data
str_line = chunk.decode("utf-8") # Convert bytes to string
index = str_line.find("data:")
if index != -1:
str_line = str_line[index:]
data_json = json.loads(str_line)
return self.chunk_parser(chunk=data_json)
# Async iterator
def __aiter__(self):
self.async_response_iterator = self.streaming_response.__aiter__()
return self
async def __anext__(self):
try:
chunk = await self.async_response_iterator.__anext__()
except StopAsyncIteration:
raise StopAsyncIteration
except ValueError as e:
raise RuntimeError(f"Error receiving chunk from stream: {e}")
try:
return self.convert_str_chunk_to_generic_chunk(chunk=chunk)
except StopAsyncIteration:
raise StopAsyncIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
class CohereV2ModelResponseIterator:
"""V2-specific response iterator for Cohere streaming"""
def __init__(
self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
):
self.streaming_response = streaming_response
self.response_iterator = self.streaming_response
self.content_blocks: List = []
self.tool_index = -1
self.json_mode = json_mode
def _parse_content_delta(self, chunk: dict) -> str:
"""Parse content-delta chunks to extract text."""
delta = chunk.get("delta", {})
message = delta.get("message", {})
content = message.get("content", {})
if isinstance(content, dict) and "text" in content:
return content["text"]
elif isinstance(content, str):
return content
return ""
def _parse_tool_call_delta(
self, chunk: dict
) -> Optional[ChatCompletionToolCallChunk]:
"""Parse tool-call-delta chunks to extract tool calls."""
delta = chunk.get("delta", {})
tool_calls = delta.get("tool_calls", [])
if tool_calls:
return {
"id": tool_calls[0].get("id", ""),
"type": "function",
"function": {
"name": tool_calls[0].get("name", ""),
"arguments": tool_calls[0].get("arguments", ""),
},
} # type: ignore
return None
def _parse_tool_plan_delta(self, chunk: dict) -> Optional[dict]:
"""Parse tool-plan-delta events to extract tool plan."""
data = chunk.get("data", {})
delta = data.get("delta", {})
message = delta.get("message", {})
tool_plan = message.get("tool_plan", "")
if tool_plan:
return {"tool_plan": tool_plan}
return None
def _parse_citation_start(self, chunk: dict) -> Optional[dict]:
"""Parse citation-start events to extract citations."""
data = chunk.get("data", {})
delta = data.get("delta", {})
message = delta.get("message", {})
citations = message.get("citations", {})
if citations:
citation_data = {
"start": citations.get("start", 0),
"end": citations.get("end", 0),
"text": citations.get("text", ""),
"sources": citations.get("sources", []),
"type": citations.get("type", "TEXT_CONTENT"),
}
return {"citations": [citation_data]}
return None
def _parse_message_end(
self, chunk: dict
) -> Tuple[bool, str, Optional[ChatCompletionUsageBlock]]:
"""Parse message-end events to extract finish info and usage."""
data = chunk.get("data", {})
delta = data.get("delta", {})
is_finished = True
finish_reason = delta.get("finish_reason", "stop")
usage = None
usage_data = delta.get("usage", {})
if usage_data:
tokens_data = usage_data.get("tokens", {})
usage = ChatCompletionUsageBlock(
prompt_tokens=tokens_data.get("input_tokens", 0),
completion_tokens=tokens_data.get("output_tokens", 0),
total_tokens=tokens_data.get("input_tokens", 0)
+ tokens_data.get("output_tokens", 0),
)
return is_finished, finish_reason, usage
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
"""
Parse Cohere v2 streaming chunks.
v2 format:
- Content: chunk.type == "content-delta" -> chunk.delta.message.content.text
- Tool calls: chunk.type == "tool-call-delta" -> chunk.delta.tool_calls
- Tool plan: chunk.event == "tool-plan-delta" -> chunk.data.delta.message.tool_plan
- Citations: chunk.event == "citation-start" -> chunk.data.delta.message.citations
- Finish: chunk.event == "message-end" -> chunk.data.delta.finish_reason
"""
try:
text = ""
tool_use: Optional[ChatCompletionToolCallChunk] = None
is_finished = False
finish_reason = ""
usage: Optional[ChatCompletionUsageBlock] = None
provider_specific_fields = None
index = int(chunk.get("index", 0))
chunk_type = chunk.get("type", "")
event_type = chunk.get("event", "")
# Handle different chunk types
if chunk_type == "content-delta":
text = self._parse_content_delta(chunk)
elif chunk_type == "tool-call-delta":
tool_use = self._parse_tool_call_delta(chunk)
elif event_type == "tool-plan-delta":
provider_specific_fields = self._parse_tool_plan_delta(chunk)
elif event_type == "citation-start":
provider_specific_fields = self._parse_citation_start(chunk)
elif event_type == "message-end":
is_finished, finish_reason, usage = self._parse_message_end(chunk)
# Handle citations in any chunk type (fallback)
if "citations" in chunk:
if provider_specific_fields is None:
provider_specific_fields = {}
provider_specific_fields["citations"] = chunk["citations"]
return GenericStreamingChunk(
text=text,
tool_use=tool_use,
is_finished=is_finished,
finish_reason=finish_reason,
usage=usage,
index=index,
provider_specific_fields=provider_specific_fields,
)
except Exception as e:
raise ValueError(f"Failed to parse v2 chunk: {e}, chunk: {chunk}")
# Sync iterator
def __iter__(self):
return self
def __next__(self):
try:
chunk = self.response_iterator.__next__()
except StopIteration:
raise StopIteration
except ValueError as e:
raise RuntimeError(f"Error receiving chunk from stream: {e}")
try:
return self.convert_str_chunk_to_generic_chunk(chunk=chunk)
except StopIteration:
raise StopIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
def convert_str_chunk_to_generic_chunk(self, chunk: str) -> GenericStreamingChunk:
"""
Convert a string chunk to a GenericStreamingChunk for v2
Note: This is used for Cohere v2 pass through streaming logging
"""
str_line = chunk
if isinstance(chunk, bytes): # Handle binary data
str_line = chunk.decode("utf-8") # Convert bytes to string
index = str_line.find("data:")
if index != -1:
str_line = str_line[index:]
data_json = json.loads(str_line)
return self.chunk_parser(chunk=data_json)
# Async iterator
def __aiter__(self):
self.async_response_iterator = self.streaming_response.__aiter__()
return self
async def __anext__(self):
try:
chunk = await self.async_response_iterator.__anext__()
except StopAsyncIteration:
raise StopAsyncIteration
except ValueError as e:
raise RuntimeError(f"Error receiving chunk from stream: {e}")
try:
return self.convert_str_chunk_to_generic_chunk(chunk=chunk)
except StopAsyncIteration:
raise StopAsyncIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")


@@ -0,0 +1,185 @@
"""
Legacy /v1/embedding handler for Bedrock Cohere.
"""
import json
from typing import Any, Callable, Optional, Union
import httpx
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
get_async_httpx_client,
)
from litellm.types.llms.bedrock import CohereEmbeddingRequest
from litellm.types.utils import EmbeddingResponse
from .v1_transformation import CohereEmbeddingConfig
def validate_environment(api_key, headers: dict):
# Create a lowercase key lookup to avoid duplicate headers with different cases
# This is important when headers come from AWS signed requests (which use Title-Case)
existing_keys_lower = {k.lower(): k for k in headers.keys()}
# Only add headers if they don't already exist (case-insensitive check)
if "request-source" not in existing_keys_lower:
headers["Request-Source"] = "unspecified:litellm"
if "accept" not in existing_keys_lower:
headers["accept"] = "application/json"
if "content-type" not in existing_keys_lower:
headers["content-type"] = "application/json"
if api_key and "authorization" not in existing_keys_lower:
headers["Authorization"] = f"Bearer {api_key}"
return headers
class CohereError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(
method="POST", url="https://api.cohere.ai/v1/generate"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
async def async_embedding(
model: str,
data: Union[dict, CohereEmbeddingRequest],
input: list,
model_response: litellm.utils.EmbeddingResponse,
timeout: Optional[Union[float, httpx.Timeout]],
logging_obj: LiteLLMLoggingObj,
optional_params: dict,
api_base: str,
api_key: Optional[str],
headers: dict,
encoding: Callable,
client: Optional[AsyncHTTPHandler] = None,
):
## LOGGING
logging_obj.pre_call(
input=input,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": api_base,
},
)
## COMPLETION CALL
if client is None:
client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.COHERE,
params={"timeout": timeout},
)
try:
response = await client.post(api_base, headers=headers, data=json.dumps(data))
except httpx.HTTPStatusError as e:
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=e.response.text,
)
raise e
except Exception as e:
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=str(e),
)
raise e
## PROCESS RESPONSE ##
return CohereEmbeddingConfig()._transform_response(
response=response,
api_key=api_key,
logging_obj=logging_obj,
data=data,
model_response=model_response,
model=model,
encoding=encoding,
input=input,
)
def embedding(
model: str,
input: list,
model_response: EmbeddingResponse,
logging_obj: LiteLLMLoggingObj,
optional_params: dict,
headers: dict,
encoding: Any,
data: Optional[Union[dict, CohereEmbeddingRequest]] = None,
complete_api_base: Optional[str] = None,
api_key: Optional[str] = None,
aembedding: Optional[bool] = None,
timeout: Optional[Union[float, httpx.Timeout]] = httpx.Timeout(None),
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
):
headers = validate_environment(api_key, headers=headers)
embed_url = complete_api_base or "https://api.cohere.ai/v1/embed"
model = model
data = data or CohereEmbeddingConfig()._transform_request(
model=model, input=input, inference_params=optional_params
)
## ROUTING
if aembedding is True:
return async_embedding(
model=model,
data=data,
input=input,
model_response=model_response,
timeout=timeout,
logging_obj=logging_obj,
optional_params=optional_params,
api_base=embed_url,
api_key=api_key,
headers=headers,
encoding=encoding,
client=(
client
if client is not None and isinstance(client, AsyncHTTPHandler)
else None
),
)
## LOGGING
logging_obj.pre_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
)
## COMPLETION CALL
if client is None or not isinstance(client, HTTPHandler):
client = HTTPHandler(concurrent_limit=1)
response = client.post(embed_url, headers=headers, data=json.dumps(data))
return CohereEmbeddingConfig()._transform_response(
response=response,
api_key=api_key,
logging_obj=logging_obj,
data=data,
model_response=model_response,
model=model,
encoding=encoding,
input=input,
)
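A minimal usage sketch for the embedding path above, assuming a litellm install with `COHERE_API_KEY` exported; the model name is illustrative.

```python
import litellm

# litellm.aembedding mirrors this call for the async path handled by async_embedding above.
response = litellm.embedding(
    model="cohere/embed-english-v3.0",  # illustrative model name
    input=["good morning from litellm", "a second input string"],
)
print(response.usage)      # usage derived from Cohere's billed_units
print(len(response.data))  # one embedding object per input string
```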


@@ -0,0 +1,246 @@
"""
Transformation logic from the OpenAI /v1/embeddings format to Cohere's /v2/embed format.
Why a separate file? To make it easy to see how the transformation works.
Covers:
- v3 embedding models
- v2 embedding models
Docs - https://docs.cohere.com/v2/reference/embed
"""
from typing import Any, List, Optional, Union, cast
import httpx
import litellm
from litellm import COHERE_DEFAULT_EMBEDDING_INPUT_TYPE
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.base_llm import BaseEmbeddingConfig
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.llms.bedrock import (
CohereEmbeddingRequest,
CohereEmbeddingRequestWithModel,
)
from litellm.types.llms.openai import AllEmbeddingInputValues, AllMessageValues
from litellm.types.utils import EmbeddingResponse, PromptTokensDetailsWrapper, Usage
from litellm.utils import is_base64_encoded
from ..common_utils import CohereError
class CohereEmbeddingConfig(BaseEmbeddingConfig):
"""
Reference: https://docs.cohere.com/v2/reference/embed
"""
def __init__(self) -> None:
pass
def get_supported_openai_params(self, model: str) -> List[str]:
return ["encoding_format", "dimensions"]
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool = False,
) -> dict:
for k, v in non_default_params.items():
if k == "encoding_format":
if isinstance(v, list):
optional_params["embedding_types"] = v
else:
optional_params["embedding_types"] = [v]
elif k == "dimensions":
optional_params["output_dimension"] = v
return optional_params
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
default_headers = {
"Content-Type": "application/json",
}
if api_key:
default_headers["Authorization"] = f"Bearer {api_key}"
headers = {**default_headers, **headers}
return headers
def _is_v3_model(self, model: str) -> bool:
return "3" in model
def get_complete_url(
self,
api_base: Optional[str],
api_key: Optional[str],
model: str,
optional_params: dict,
litellm_params: dict,
stream: Optional[bool] = None,
) -> str:
return api_base or "https://api.cohere.ai/v2/embed"
def _transform_request(
self, model: str, input: List[str], inference_params: dict
) -> CohereEmbeddingRequestWithModel:
is_encoded = False
for input_str in input:
is_encoded = is_base64_encoded(input_str)
if is_encoded: # check if string is b64 encoded image or not
transformed_request = CohereEmbeddingRequestWithModel(
model=model,
images=input,
input_type="image",
)
else:
transformed_request = CohereEmbeddingRequestWithModel(
model=model,
texts=input,
input_type=COHERE_DEFAULT_EMBEDDING_INPUT_TYPE,
)
for k, v in inference_params.items():
transformed_request[k] = v # type: ignore
return transformed_request
def transform_embedding_request(
self,
model: str,
input: AllEmbeddingInputValues,
optional_params: dict,
headers: dict,
) -> dict:
if isinstance(input, list) and (
isinstance(input[0], list) or isinstance(input[0], int)
):
raise ValueError("Input must be a list of strings")
return cast(
dict,
self._transform_request(
model=model,
input=cast(List[str], input) if isinstance(input, List) else [input],
inference_params=optional_params,
),
)
def _calculate_usage(self, input: List[str], encoding: Any, meta: dict) -> Usage:
input_tokens = 0
text_tokens: Optional[int] = meta.get("billed_units", {}).get("input_tokens")
image_tokens: Optional[int] = meta.get("billed_units", {}).get("images")
prompt_tokens_details: Optional[PromptTokensDetailsWrapper] = None
if image_tokens is None and text_tokens is None:
for text in input:
input_tokens += len(encoding.encode(text))
else:
prompt_tokens_details = PromptTokensDetailsWrapper(
image_tokens=image_tokens,
text_tokens=text_tokens,
)
if image_tokens:
input_tokens += image_tokens
if text_tokens:
input_tokens += text_tokens
return Usage(
prompt_tokens=input_tokens,
completion_tokens=0,
total_tokens=input_tokens,
prompt_tokens_details=prompt_tokens_details,
)
def _transform_response(
self,
response: httpx.Response,
api_key: Optional[str],
logging_obj: LiteLLMLoggingObj,
data: Union[dict, CohereEmbeddingRequest],
model_response: EmbeddingResponse,
model: str,
encoding: Any,
input: list,
) -> EmbeddingResponse:
response_json = response.json()
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=response_json,
)
"""
response
{
'object': "list",
'data': [
]
'model',
'usage'
}
"""
embeddings = response_json["embeddings"]
output_data = []
for k, embedding_list in embeddings.items():
for idx, embedding in enumerate(embedding_list):
output_data.append(
{"object": "embedding", "index": idx, "embedding": embedding}
)
model_response.object = "list"
model_response.data = output_data
model_response.model = model
input_tokens = 0
for text in input:
input_tokens += len(encoding.encode(text))
setattr(
model_response,
"usage",
self._calculate_usage(input, encoding, response_json.get("meta", {})),
)
return model_response
def transform_embedding_response(
self,
model: str,
raw_response: httpx.Response,
model_response: EmbeddingResponse,
logging_obj: LiteLLMLoggingObj,
api_key: Optional[str],
request_data: dict,
optional_params: dict,
litellm_params: dict,
) -> EmbeddingResponse:
return self._transform_response(
response=raw_response,
api_key=api_key,
logging_obj=logging_obj,
data=request_data,
model_response=model_response,
model=model,
encoding=litellm.encoding,
input=logging_obj.model_call_details["input"],
)
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
return CohereError(
status_code=status_code,
message=error_message,
)
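A standalone sketch of the OpenAI-to-Cohere parameter mapping performed by `map_openai_params` and `_transform_request` above, using plain dicts so it runs without litellm; all values are illustrative.

```python
def map_openai_embed_params(non_default_params: dict) -> dict:
    # Mirrors CohereEmbeddingConfig.map_openai_params above:
    # encoding_format -> embedding_types (always a list), dimensions -> output_dimension.
    optional_params: dict = {}
    for k, v in non_default_params.items():
        if k == "encoding_format":
            optional_params["embedding_types"] = v if isinstance(v, list) else [v]
        elif k == "dimensions":
            optional_params["output_dimension"] = v
    return optional_params


cohere_params = map_openai_embed_params({"encoding_format": "float", "dimensions": 512})
print(cohere_params)  # {'embedding_types': ['float'], 'output_dimension': 512}

# _transform_request then adds model, texts (or images for base64 inputs) and input_type
# (litellm's COHERE_DEFAULT_EMBEDDING_INPUT_TYPE; "search_document" here is illustrative).
request_body = {
    "model": "embed-english-v3.0",
    "texts": ["hello world"],
    "input_type": "search_document",
    **cohere_params,
}
print(request_body)
```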


@@ -0,0 +1,162 @@
"""
Legacy /v1/embedding transformation logic for Bedrock Cohere.
"""
from typing import Any, List, Optional, Union
import httpx
from litellm import COHERE_DEFAULT_EMBEDDING_INPUT_TYPE
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.types.llms.bedrock import (
CohereEmbeddingRequest,
CohereEmbeddingRequestWithModel,
)
from litellm.types.utils import EmbeddingResponse, PromptTokensDetailsWrapper, Usage
from litellm.utils import is_base64_encoded
class CohereEmbeddingConfig:
"""
Reference: https://docs.cohere.com/v2/reference/embed
"""
def __init__(self) -> None:
pass
def get_supported_openai_params(self) -> List[str]:
return ["encoding_format"]
def map_openai_params(
self, non_default_params: dict, optional_params: dict
) -> dict:
for k, v in non_default_params.items():
if k == "encoding_format":
optional_params["embedding_types"] = v
return optional_params
def _is_v3_model(self, model: str) -> bool:
return "3" in model
def _transform_request(
self, model: str, input: List[str], inference_params: dict
) -> CohereEmbeddingRequestWithModel:
is_encoded = False
for input_str in input:
is_encoded = is_base64_encoded(input_str)
if is_encoded: # check if string is b64 encoded image or not
transformed_request = CohereEmbeddingRequestWithModel(
model=model,
images=input,
input_type="image",
)
else:
transformed_request = CohereEmbeddingRequestWithModel(
model=model,
texts=input,
input_type=COHERE_DEFAULT_EMBEDDING_INPUT_TYPE,
)
for k, v in inference_params.items():
transformed_request[k] = v # type: ignore
return transformed_request
def _calculate_usage(self, input: List[str], encoding: Any, meta: dict) -> Usage:
input_tokens = 0
text_tokens: Optional[int] = meta.get("billed_units", {}).get("input_tokens")
image_tokens: Optional[int] = meta.get("billed_units", {}).get("images")
prompt_tokens_details: Optional[PromptTokensDetailsWrapper] = None
if image_tokens is None and text_tokens is None:
for text in input:
input_tokens += len(encoding.encode(text))
else:
prompt_tokens_details = PromptTokensDetailsWrapper(
image_tokens=image_tokens,
text_tokens=text_tokens,
)
if image_tokens:
input_tokens += image_tokens
if text_tokens:
input_tokens += text_tokens
return Usage(
prompt_tokens=input_tokens,
completion_tokens=0,
total_tokens=input_tokens,
prompt_tokens_details=prompt_tokens_details,
)
def _transform_response(
self,
response: httpx.Response,
api_key: Optional[str],
logging_obj: LiteLLMLoggingObj,
data: Union[dict, CohereEmbeddingRequest],
model_response: EmbeddingResponse,
model: str,
encoding: Any,
input: list,
) -> EmbeddingResponse:
response_json = response.json()
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=response_json,
)
"""
response
{
'object': "list",
'data': [
]
'model',
'usage'
}
"""
embeddings = response_json["embeddings"]
output_data = []
is_embeddings_by_type = (
response_json.get("response_type") == "embeddings_by_type"
)
if isinstance(embeddings, dict):
is_embeddings_by_type = True
if is_embeddings_by_type:
for embedding_type in embeddings:
for idx, embedding in enumerate(embeddings[embedding_type]):
output_data.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding,
"type": embedding_type,
}
)
else:
for idx, embedding in enumerate(embeddings):
output_data.append(
{"object": "embedding", "index": idx, "embedding": embedding}
)
model_response.object = "list"
model_response.data = output_data
model_response.model = model
input_tokens = 0
for text in input:
input_tokens += len(encoding.encode(text))
setattr(
model_response,
"usage",
self._calculate_usage(input, encoding, response_json.get("meta", {})),
)
return model_response
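A standalone sketch of the `embeddings_by_type` flattening performed by `_transform_response` above; the response payload is illustrative.

```python
# Illustrative Cohere response where embeddings are keyed by type.
response_json = {
    "response_type": "embeddings_by_type",
    "embeddings": {"float": [[0.1, 0.2], [0.3, 0.4]], "int8": [[1, 2], [3, 4]]},
}

output_data = []
embeddings = response_json["embeddings"]
if isinstance(embeddings, dict):  # embeddings_by_type shape
    for embedding_type, vectors in embeddings.items():
        for idx, embedding in enumerate(vectors):
            output_data.append(
                {"object": "embedding", "index": idx, "embedding": embedding, "type": embedding_type}
            )
else:  # plain list of vectors
    for idx, embedding in enumerate(embeddings):
        output_data.append({"object": "embedding", "index": idx, "embedding": embedding})

print(output_data[0])
# {'object': 'embedding', 'index': 0, 'embedding': [0.1, 0.2], 'type': 'float'}
```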


@@ -0,0 +1,229 @@
# Cohere Rerank Guardrail Translation Handler
Handler for processing the rerank endpoint (`/v1/rerank`) with guardrails.
## Overview
This handler processes rerank requests by:
1. Extracting the query text from the request
2. Applying guardrails to the query
3. Updating the request with the guardrailed query
4. Returning the output unchanged (rankings are not text)
Note: Documents are not processed by guardrails as they represent the corpus
being searched, not user input. Only the query is guardrailed.
## Data Format
### Input Format
**With String Documents:**
```json
{
"model": "rerank-english-v3.0",
"query": "What is the capital of France?",
"documents": [
"Paris is the capital of France.",
"Berlin is the capital of Germany.",
"Madrid is the capital of Spain."
],
"top_n": 2
}
```
**With Dict Documents:**
```json
{
"model": "rerank-english-v3.0",
"query": "What is the capital of France?",
"documents": [
{"text": "Paris is the capital of France.", "id": "doc1"},
{"text": "Berlin is the capital of Germany.", "id": "doc2"},
{"text": "Madrid is the capital of Spain.", "id": "doc3"}
],
"top_n": 2
}
```
### Output Format
```json
{
"id": "rerank-abc123",
"results": [
{"index": 0, "relevance_score": 0.98},
{"index": 2, "relevance_score": 0.12}
],
"meta": {
"billed_units": {"search_units": 1}
}
}
```
## Usage
The handler is automatically discovered and applied when guardrails are used with the rerank endpoint.
### Example: Using Guardrails with Rerank
```bash
curl -X POST 'http://localhost:4000/v1/rerank' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer your-api-key' \
-d '{
"model": "rerank-english-v3.0",
"query": "What is machine learning?",
"documents": [
"Machine learning is a subset of AI.",
"Deep learning uses neural networks.",
"Python is a programming language."
],
"guardrails": ["content_filter"],
"top_n": 2
}'
```
The guardrail will be applied to the query only (not the documents).
### Example: PII Masking in Query
```bash
curl -X POST 'http://localhost:4000/v1/rerank' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer your-api-key' \
-d '{
"model": "rerank-english-v3.0",
"query": "Find documents about John Doe from john@example.com",
"documents": [
"Document 1 content here.",
"Document 2 content here.",
"Document 3 content here."
],
"guardrails": ["mask_pii"],
"top_n": 3
}'
```
The query will be masked to: "Find documents about [NAME_REDACTED] from [EMAIL_REDACTED]"
### Example: Mixed Document Types
```bash
curl -X POST 'http://localhost:4000/v1/rerank' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer your-api-key' \
-d '{
"model": "rerank-english-v3.0",
"query": "Technical documentation",
"documents": [
{"text": "This is document 1", "metadata": {"source": "wiki"}},
{"text": "This is document 2", "metadata": {"source": "docs"}},
"This is document 3 as a plain string"
],
"guardrails": ["content_moderation"]
}'
```
## Implementation Details
### Input Processing
- **Query Field**: `query` (string)
- Processing: Apply guardrail to query text
- Result: Updated query
- **Documents Field**: `documents` (list)
- Processing: Not processed (corpus being searched, not user input)
- Result: Unchanged
### Output Processing
- **Processing**: Not applicable (output contains relevance scores, not text)
- **Result**: Response returned unchanged
## Use Cases
1. **PII Protection**: Remove PII from queries before reranking
2. **Content Filtering**: Filter inappropriate content from search queries
3. **Compliance**: Ensure queries meet organizational or regulatory requirements
4. **Data Sanitization**: Clean up query text before semantic search operations
## Extension
Override these methods to customize behavior (a minimal subclass sketch follows this list):
- `process_input_messages()`: Customize how query is processed
- `process_output_response()`: Currently a no-op, but can be overridden if needed
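A minimal subclass sketch (the subclass name and logging line are illustrative):

```python
from litellm.llms.cohere.rerank.guardrail_translation.handler import CohereRerankHandler


class LoggingRerankHandler(CohereRerankHandler):
    """Illustrative subclass: log the query length before delegating to the default logic."""

    async def process_input_messages(self, data, guardrail_to_apply, litellm_logging_obj=None):
        print(f"rerank query length before guardrail: {len(data.get('query', ''))}")
        return await super().process_input_messages(data, guardrail_to_apply, litellm_logging_obj)
```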
## Supported Call Types
- `CallTypes.rerank` - Synchronous rerank
- `CallTypes.arerank` - Asynchronous rerank
## Notes
- Only the query is processed by guardrails
- Documents are not processed (they represent the corpus, not user input)
- Output processing is a no-op since rankings don't contain text
- Both sync and async call types use the same handler
- Works with all rerank providers (Cohere, Together AI, etc.)
## Common Patterns
### PII Masking in Search
```python
import litellm
response = litellm.rerank(
model="rerank-english-v3.0",
query="Find info about john@example.com",
documents=[
"Document 1 content.",
"Document 2 content.",
"Document 3 content."
],
guardrails=["mask_pii"],
top_n=2
)
# Query will have PII masked
# query becomes: "Find info about [EMAIL_REDACTED]"
print(response.results)
```
### Content Filtering
```python
import litellm
response = litellm.rerank(
model="rerank-english-v3.0",
query="Search query here",
documents=[
{"text": "Document 1 content", "id": "doc1"},
{"text": "Document 2 content", "id": "doc2"},
],
guardrails=["content_filter"],
)
```
### Async Rerank with Guardrails
```python
import litellm
import asyncio
async def rerank_with_guardrails():
response = await litellm.arerank(
model="rerank-english-v3.0",
query="Technical query",
documents=["Doc 1", "Doc 2", "Doc 3"],
guardrails=["sanitize"],
top_n=2
)
return response
result = asyncio.run(rerank_with_guardrails())
```


@@ -0,0 +1,11 @@
"""Cohere Rerank handler for Unified Guardrails."""
from litellm.llms.cohere.rerank.guardrail_translation.handler import CohereRerankHandler
from litellm.types.utils import CallTypes
guardrail_translation_mappings = {
CallTypes.rerank: CohereRerankHandler,
CallTypes.arerank: CohereRerankHandler,
}
__all__ = ["guardrail_translation_mappings", "CohereRerankHandler"]


@@ -0,0 +1,107 @@
"""
Cohere Rerank Handler for Unified Guardrails
This module provides guardrail translation support for the rerank endpoint.
The handler processes only the 'query' parameter for guardrails.
"""
from typing import TYPE_CHECKING, Any, Optional
from litellm._logging import verbose_proxy_logger
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
from litellm.types.utils import GenericGuardrailAPIInputs
if TYPE_CHECKING:
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.types.rerank import RerankResponse
class CohereRerankHandler(BaseTranslation):
"""
Handler for processing rerank requests with guardrails.
This class provides methods to:
1. Process input query (pre-call hook)
2. Process output response (post-call hook) - not applicable for rerank
The handler specifically processes:
- The 'query' parameter (string)
Note: Documents are not processed by guardrails as they are the corpus
being searched, not user input.
"""
async def process_input_messages(
self,
data: dict,
guardrail_to_apply: "CustomGuardrail",
litellm_logging_obj: Optional[Any] = None,
) -> Any:
"""
Process input query by applying guardrails.
Args:
data: Request data dictionary containing 'query'
guardrail_to_apply: The guardrail instance to apply
Returns:
Modified data with guardrails applied to query only
"""
# Process query only
query = data.get("query")
if query is not None and isinstance(query, str):
inputs = GenericGuardrailAPIInputs(texts=[query])
# Include model information if available
model = data.get("model")
if model:
inputs["model"] = model
guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
inputs=inputs,
request_data=data,
input_type="request",
logging_obj=litellm_logging_obj,
)
guardrailed_texts = guardrailed_inputs.get("texts", [])
data["query"] = guardrailed_texts[0] if guardrailed_texts else query
verbose_proxy_logger.debug(
"Rerank: Applied guardrail to query. "
"Original length: %d, New length: %d",
len(query),
len(data["query"]),
)
else:
verbose_proxy_logger.debug(
"Rerank: No query to process or query is not a string"
)
return data
async def process_output_response(
self,
response: "RerankResponse",
guardrail_to_apply: "CustomGuardrail",
litellm_logging_obj: Optional[Any] = None,
user_api_key_dict: Optional[Any] = None,
) -> Any:
"""
Process output response - not applicable for rerank.
Rerank responses contain relevance scores and indices, not text,
so there's nothing to apply guardrails to. This method returns
the response unchanged.
Args:
response: Rerank response object with rankings
guardrail_to_apply: The guardrail instance (unused)
litellm_logging_obj: Optional logging object (unused)
user_api_key_dict: User API key metadata (unused)
Returns:
Unmodified response (rankings don't need text guardrails)
"""
verbose_proxy_logger.debug(
"Rerank: Output processing not applicable "
"(output contains relevance scores, not text)"
)
return response


@@ -0,0 +1,5 @@
"""
Cohere Rerank - uses `llm_http_handler.py` to make httpx requests
Request/Response transformation is handled in `transformation.py`
"""


@@ -0,0 +1,158 @@
from typing import Any, Dict, List, Optional, Union
import httpx
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.rerank import OptionalRerankParams, RerankRequest, RerankResponse
from ..common_utils import CohereError
class CohereRerankConfig(BaseRerankConfig):
"""
Reference: https://docs.cohere.com/v2/reference/rerank
"""
def __init__(self) -> None:
pass
def get_complete_url(
self,
api_base: Optional[str],
model: str,
optional_params: Optional[dict] = None,
) -> str:
if api_base:
# Remove trailing slashes and ensure clean base URL
api_base = api_base.rstrip("/")
if not api_base.endswith("/v1/rerank"):
api_base = f"{api_base}/v1/rerank"
return api_base
return "https://api.cohere.ai/v1/rerank"
def get_supported_cohere_rerank_params(self, model: str) -> list:
return [
"query",
"documents",
"top_n",
"max_chunks_per_doc",
"rank_fields",
"return_documents",
]
def map_cohere_rerank_params(
self,
non_default_params: Optional[dict],
model: str,
drop_params: bool,
query: str,
documents: List[Union[str, Dict[str, Any]]],
custom_llm_provider: Optional[str] = None,
top_n: Optional[int] = None,
rank_fields: Optional[List[str]] = None,
return_documents: Optional[bool] = True,
max_chunks_per_doc: Optional[int] = None,
max_tokens_per_doc: Optional[int] = None,
) -> Dict:
"""
Map Cohere rerank params
No mapping required - returns all supported params
"""
return dict(
OptionalRerankParams(
query=query,
documents=documents,
top_n=top_n,
rank_fields=rank_fields,
return_documents=return_documents,
max_chunks_per_doc=max_chunks_per_doc,
)
)
def validate_environment(
self,
headers: dict,
model: str,
api_key: Optional[str] = None,
optional_params: Optional[dict] = None,
) -> dict:
if api_key is None:
api_key = (
get_secret_str("COHERE_API_KEY")
or get_secret_str("CO_API_KEY")
or litellm.cohere_key
)
if api_key is None:
raise ValueError(
"Cohere API key is required. Please set 'COHERE_API_KEY' or 'CO_API_KEY' or 'litellm.cohere_key'"
)
default_headers = {
"Authorization": f"Bearer {api_key}",
"accept": "application/json",
"content-type": "application/json",
}
# If 'Authorization' is provided in headers, it overrides the default.
if "Authorization" in headers:
default_headers["Authorization"] = headers["Authorization"]
# Merge other headers, overriding any default ones except Authorization
return {**default_headers, **headers}
def transform_rerank_request(
self,
model: str,
optional_rerank_params: Dict,
headers: dict,
) -> dict:
if "query" not in optional_rerank_params:
raise ValueError("query is required for Cohere rerank")
if "documents" not in optional_rerank_params:
raise ValueError("documents is required for Cohere rerank")
rerank_request = RerankRequest(
model=model,
query=optional_rerank_params["query"],
documents=optional_rerank_params["documents"],
top_n=optional_rerank_params.get("top_n", None),
rank_fields=optional_rerank_params.get("rank_fields", None),
return_documents=optional_rerank_params.get("return_documents", None),
max_chunks_per_doc=optional_rerank_params.get("max_chunks_per_doc", None),
)
return rerank_request.model_dump(exclude_none=True)
def transform_rerank_response(
self,
model: str,
raw_response: httpx.Response,
model_response: RerankResponse,
logging_obj: LiteLLMLoggingObj,
api_key: Optional[str] = None,
request_data: dict = {},
optional_params: dict = {},
litellm_params: dict = {},
) -> RerankResponse:
"""
Transform Cohere rerank response
No transformation required, litellm follows cohere API response format
"""
try:
raw_response_json = raw_response.json()
except Exception:
raise CohereError(
message=raw_response.text, status_code=raw_response.status_code
)
return RerankResponse(**raw_response_json)
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
return CohereError(message=error_message, status_code=status_code)
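A quick sketch of the URL and header handling above, assuming a litellm install; the proxy base URL and API key are illustrative.

```python
from litellm.llms.cohere.rerank.transformation import CohereRerankConfig

config = CohereRerankConfig()

# Default endpoint vs. a custom base: trailing slashes are stripped and /v1/rerank appended once.
print(config.get_complete_url(api_base=None, model="rerank-english-v3.0"))
# https://api.cohere.ai/v1/rerank
print(config.get_complete_url(api_base="https://my-proxy.example.com/", model="rerank-english-v3.0"))
# https://my-proxy.example.com/v1/rerank

# Headers: Authorization falls back to COHERE_API_KEY / CO_API_KEY / litellm.cohere_key
# when api_key is not passed explicitly; here an illustrative key is passed directly.
headers = config.validate_environment(headers={}, model="rerank-english-v3.0", api_key="sk-example")
print(headers["Authorization"])  # Bearer sk-example
```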


@@ -0,0 +1,88 @@
from typing import Any, Dict, List, Optional, Union
from litellm.llms.cohere.rerank.transformation import CohereRerankConfig
from litellm.types.rerank import OptionalRerankParams, RerankRequest
class CohereRerankV2Config(CohereRerankConfig):
"""
Reference: https://docs.cohere.com/v2/reference/rerank
"""
def __init__(self) -> None:
pass
def get_complete_url(
self,
api_base: Optional[str],
model: str,
optional_params: Optional[dict] = None,
) -> str:
if api_base:
# Remove trailing slashes and ensure clean base URL
api_base = api_base.rstrip("/")
if not api_base.endswith("/v2/rerank"):
api_base = f"{api_base}/v2/rerank"
return api_base
return "https://api.cohere.ai/v2/rerank"
def get_supported_cohere_rerank_params(self, model: str) -> list:
return [
"query",
"documents",
"top_n",
"max_tokens_per_doc",
"rank_fields",
"return_documents",
]
def map_cohere_rerank_params(
self,
non_default_params: Optional[dict],
model: str,
drop_params: bool,
query: str,
documents: List[Union[str, Dict[str, Any]]],
custom_llm_provider: Optional[str] = None,
top_n: Optional[int] = None,
rank_fields: Optional[List[str]] = None,
return_documents: Optional[bool] = True,
max_chunks_per_doc: Optional[int] = None,
max_tokens_per_doc: Optional[int] = None,
) -> Dict:
"""
Map Cohere rerank params
No mapping required - returns all supported params
"""
return dict(
OptionalRerankParams(
query=query,
documents=documents,
top_n=top_n,
rank_fields=rank_fields,
return_documents=return_documents,
max_tokens_per_doc=max_tokens_per_doc,
)
)
def transform_rerank_request(
self,
model: str,
optional_rerank_params: Dict,
headers: dict,
) -> dict:
if "query" not in optional_rerank_params:
raise ValueError("query is required for Cohere rerank")
if "documents" not in optional_rerank_params:
raise ValueError("documents is required for Cohere rerank")
rerank_request = RerankRequest(
model=model,
query=optional_rerank_params["query"],
documents=optional_rerank_params["documents"],
top_n=optional_rerank_params.get("top_n", None),
rank_fields=optional_rerank_params.get("rank_fields", None),
return_documents=optional_rerank_params.get("return_documents", None),
max_tokens_per_doc=optional_rerank_params.get("max_tokens_per_doc", None),
)
return rerank_request.model_dump(exclude_none=True)