Files
lijiaoqiao/llm-gateway-competitors/litellm-wheel-src/litellm/llms/volcengine/responses/transformation.py

570 lines
19 KiB
Python
Raw Normal View History

from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Literal,
Optional,
Tuple,
Union,
get_args,
get_origin,
)
import httpx
from pydantic import fields as pyd_fields
import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.core_helpers import process_response_headers
from litellm.litellm_core_utils.llm_response_utils.convert_dict_to_response import (
_safe_convert_created_field,
)
from litellm.llms.openai.responses.transformation import OpenAIResponsesAPIConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import (
ResponseInputParam,
ResponsesAPIOptionalRequestParams,
ResponsesAPIResponse,
ResponsesAPIStreamingResponse,
)
from litellm.types.responses.main import DeleteResponseResult
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import LlmProviders
from ..common_utils import (
VolcEngineError,
get_volcengine_base_url,
get_volcengine_headers,
)
# The concrete logging class is imported only while type checking; at runtime
# the import is skipped and the alias is simply ``Any``.
if not TYPE_CHECKING:
    LiteLLMLoggingObj = Any
else:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
class VolcEngineResponsesAPIConfig(OpenAIResponsesAPIConfig):
    """
    Responses API configuration for the Volcengine provider.

    Builds on the OpenAI Responses API config and keeps an allowlist of the
    optional request parameters this provider config accepts.
    """

    # Allowlist of optional request parameters. It is assembled from three
    # groups so the origin of every entry stays visible; the resulting list
    # keeps the groups in this order.
    _SUPPORTED_OPTIONAL_PARAMS: List[str] = (
        # Doc-listed knobs
        [
            "instructions",
            "max_output_tokens",
            "previous_response_id",
            "store",
            "reasoning",
            "stream",
            "temperature",
            "top_p",
            "text",
            "tools",
            "tool_choice",
            "max_tool_calls",
            "thinking",
            "caching",
            "expire_at",
            "context_management",
        ]
        # LiteLLM-internal metadata (not sent to provider)
        + ["metadata"]
        # Request plumbing helpers
        + ["extra_headers", "extra_query", "extra_body", "timeout"]
    )
@property
def custom_llm_provider(self) -> LlmProviders:
    """Return the provider enum member for this config (``LlmProviders.VOLCENGINE``)."""
    return LlmProviders.VOLCENGINE
def get_supported_openai_params(self, model: str) -> list:
    """
    List the request parameters advertised as supported for Volcengine.

    Args:
        model: Model name; not consulted.

    Returns:
        A new list: ``"input"`` and ``"model"`` followed by the entries of
        ``_SUPPORTED_OPTIONAL_PARAMS``, with ``"metadata"`` taken out.
    """
    advertised = ["input", "model", *self._SUPPORTED_OPTIONAL_PARAMS]
    # "metadata" is accepted from callers but is internal-only, so it is not
    # advertised; it is dropped again before the request is sent.
    try:
        advertised.remove("metadata")
    except ValueError:
        # Nothing to hide when the allowlist has no "metadata" entry.
        pass
    return advertised
def get_error_class(
    self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> VolcEngineError:
    """
    Build the Volcengine-specific exception for a failed call.

    Args:
        error_message: Message to carry on the exception.
        status_code: HTTP status code of the failed response.
        headers: Response headers, either an ``httpx.Headers`` instance or a
            plain dict (a falsy value is treated as no headers).

    Returns:
        A ``VolcEngineError`` carrying the message, status code and headers.
    """
    if isinstance(headers, httpx.Headers):
        normalized = headers
    else:
        # Plain mappings are wrapped so the error always carries httpx.Headers.
        normalized = httpx.Headers(headers or {})
    return VolcEngineError(
        message=error_message,
        status_code=status_code,
        headers=normalized,
    )
def validate_environment(
    self, headers: dict, model: str, litellm_params: Optional[GenericLiteLLMParams]
) -> dict:
    """
    Build auth headers for Volcengine Responses API.

    Args:
        headers: Caller-supplied headers, forwarded as ``extra_headers``.
        model: Model name; not consulted.
        litellm_params: Per-request LiteLLM params. ``None`` and plain dicts
            are accepted and converted to ``GenericLiteLLMParams``.

    Returns:
        The header dict produced by ``get_volcengine_headers``.

    Raises:
        ValueError: If no non-empty API key can be resolved.
    """
    if litellm_params is None:
        litellm_params = GenericLiteLLMParams()
    elif isinstance(litellm_params, dict):
        litellm_params = GenericLiteLLMParams(**litellm_params)

    # Resolution order: explicit param, global litellm setting, then the two
    # environment/secret names. ``or`` keeps later lookups lazy.
    api_key = (
        litellm_params.api_key
        or litellm.api_key
        or get_secret_str("ARK_API_KEY")
        or get_secret_str("VOLCENGINE_API_KEY")
    )
    # Reject any falsy result: the ``or`` chain can end on an empty string,
    # which must not be sent as a credential.
    if not api_key:
        raise ValueError(
            "Volcengine API key is required. Set ARK_API_KEY / VOLCENGINE_API_KEY or pass api_key."
        )
    return get_volcengine_headers(api_key=api_key, extra_headers=headers)
def get_complete_url(
    self,
    api_base: Optional[str],
    litellm_params: dict,
) -> str:
    """
    Resolve the full Volcengine Responses API endpoint URL.

    Args:
        api_base: Explicit base URL; when falsy, the global litellm setting,
            the ``VOLCENGINE_API_BASE`` / ``ARK_API_BASE`` secrets and finally
            the provider default are tried in that order.
        litellm_params: Not consulted.

    Returns:
        The base URL without trailing slashes, extended so that it ends in
        ``/responses`` (inserting ``/api/v3`` when the base lacks it).
    """
    resolved = (
        api_base
        or litellm.api_base
        or get_secret_str("VOLCENGINE_API_BASE")
        or get_secret_str("ARK_API_BASE")
        or get_volcengine_base_url()
    ).rstrip("/")

    # Already a complete endpoint: use as-is.
    if resolved.endswith("/responses"):
        return resolved

    suffix = "/responses" if resolved.endswith("/api/v3") else "/api/v3/responses"
    return f"{resolved}{suffix}"
def map_openai_params(
    self,
    response_api_optional_params: ResponsesAPIOptionalRequestParams,
    model: str,
    drop_params: bool,
) -> Dict:
    """
    Keep only the optional parameters on the Volcengine allowlist.

    Volcengine Responses API aligns with OpenAI parameters; anything outside
    ``_SUPPORTED_OPTIONAL_PARAMS`` (for example ``parallel_tool_calls``) is
    removed, and ``metadata`` is always removed because it is LiteLLM-internal.

    Args:
        response_api_optional_params: Optional request parameters from the caller.
        model: Model name; not consulted.
        drop_params: Not consulted; unsupported parameters are always dropped.

    Returns:
        A new dict holding the allowlisted parameters in their original order.
    """
    requested = dict(response_api_optional_params)

    allowed = set(self._SUPPORTED_OPTIONAL_PARAMS)
    # LiteLLM metadata is internal-only; don't send to provider
    allowed.discard("metadata")

    params = {key: value for key, value in requested.items() if key in allowed}

    # Report dropped keys from the *unfiltered* input: checking the filtered
    # dict for e.g. "parallel_tool_calls" could never match anything.
    dropped = [
        key for key in requested if key not in allowed and key != "metadata"
    ]
    if dropped:
        verbose_logger.debug(
            "Volcengine Responses API: dropping unsupported params: %s", dropped
        )
    return params
def transform_responses_api_request(
    self,
    model: str,
    input: Union[str, ResponseInputParam],
    response_api_optional_request_params: Dict,
    litellm_params: GenericLiteLLMParams,
    headers: dict,
) -> Dict:
    """
    Re-filter optional parameters, then delegate to the OpenAI base transformer.

    Volcengine rejects undocumented fields, so parameters outside the
    allowlist are silently removed here (nothing is raised). The same applies
    to the keys of ``extra_body``.

    Args:
        model: Target model name, forwarded unchanged.
        input: Request input, forwarded unchanged.
        response_api_optional_request_params: Optional parameters; the dict
            itself is not modified.
        litellm_params: Forwarded unchanged.
        headers: Forwarded unchanged.

    Returns:
        Whatever the base class ``transform_responses_api_request`` returns
        for the sanitized parameters.
    """
    allowed = set(self._SUPPORTED_OPTIONAL_PARAMS)
    # Ensure metadata never reaches provider
    allowed.discard("metadata")

    sanitized_optional = {
        k: v for k, v in response_api_optional_request_params.items() if k in allowed
    }

    extra_body = sanitized_optional.get("extra_body")
    if isinstance(extra_body, dict):
        # extra_body keys are meant as provider body fields, so only provider
        # parameters are allowed inside it: the request-plumbing helper names
        # (and metadata, removed above) would otherwise leak into the body.
        body_allowed = allowed - {
            "extra_headers",
            "extra_query",
            "extra_body",
            "timeout",
        }
        filtered_body = {k: v for k, v in extra_body.items() if k in body_allowed}
        if filtered_body:
            sanitized_optional["extra_body"] = filtered_body
        else:
            sanitized_optional.pop("extra_body", None)

    return super().transform_responses_api_request(
        model=model,
        input=input,
        response_api_optional_request_params=sanitized_optional,
        litellm_params=litellm_params,
        headers=headers,
    )
def transform_streaming_response(
    self,
    model: str,
    parsed_chunk: dict,
    logging_obj: LiteLLMLoggingObj,
) -> ResponsesAPIStreamingResponse:
    """
    Convert one parsed streaming chunk into its typed event object.

    Volcengine may omit required fields, so the chunk is patched before the
    event model is instantiated.

    Args:
        model: Model name; not consulted.
        parsed_chunk: Decoded chunk; it is not modified in place.
        logging_obj: Not consulted.

    Returns:
        An instance of the event model selected by the chunk's ``type``.
    """
    event = parsed_chunk
    is_mapping = isinstance(event, dict)

    if is_mapping:
        inner = event.get("response")
        # Events carrying a response object without "output" get an empty one;
        # copies are made so the caller's chunk stays untouched.
        if isinstance(inner, dict) and "output" not in inner:
            event = {**event, "response": {**inner, "output": []}}

    event_type = str(event.get("type")) if is_mapping else None
    event_model = OpenAIResponsesAPIConfig.get_event_model_class(event_type=event_type)

    # Fill whatever else the event model declares but the chunk lacks.
    completed = self._fill_missing_fields(event, event_model)
    return event_model(**completed)
def transform_response_api_response(
    self,
    model: str,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
) -> ResponsesAPIResponse:
    """
    Parse the provider's HTTP reply into a ``ResponsesAPIResponse``.

    Args:
        model: Model name; not consulted.
        raw_response: HTTP response whose body is expected to be JSON.
        logging_obj: Receives the raw response text via ``post_call``.

    Returns:
        The parsed response with the raw and processed headers stored in
        ``_hidden_params``.

    Raises:
        VolcEngineError: If logging, JSON decoding or the ``created_at``
            conversion raises; carries the response text and status code.
    """
    try:
        logging_obj.post_call(
            original_response=raw_response.text,
            additional_args={"complete_input_dict": {}},
        )
        payload = raw_response.json()
        if "created_at" in payload:
            payload["created_at"] = _safe_convert_created_field(payload["created_at"])
    except Exception:
        raise VolcEngineError(
            message=raw_response.text, status_code=raw_response.status_code
        )

    header_map = dict(raw_response.headers)
    additional_headers = process_response_headers(header_map)

    try:
        parsed = ResponsesAPIResponse(**payload)
    except Exception:
        # Validation failed: build the object without validation rather than
        # rejecting the provider's payload.
        verbose_logger.debug(
            "Volcengine Responses API: falling back to model_construct for response parsing."
        )
        parsed = ResponsesAPIResponse.model_construct(**payload)

    parsed._hidden_params["additional_headers"] = additional_headers
    parsed._hidden_params["headers"] = header_map
    return parsed
#########################################################
########## DELETE RESPONSE API TRANSFORMATION ##############
#########################################################
def transform_delete_response_api_request(
    self,
    response_id: str,
    api_base: str,
    litellm_params: GenericLiteLLMParams,
    headers: dict,
) -> Tuple[str, Dict]:
    """
    Build the URL and body for deleting a stored response.

    Args:
        response_id: Identifier appended to the base URL.
        api_base: Responses endpoint URL.
        litellm_params: Not consulted.
        headers: Not consulted.

    Returns:
        ``(url, data)`` where ``url`` is ``api_base/response_id`` and ``data``
        is an empty dict.
    """
    return f"{api_base}/{response_id}", {}
def transform_delete_response_api_response(
    self,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
) -> DeleteResponseResult:
    """
    Parse the reply of a delete-response call.

    Args:
        raw_response: HTTP response whose body is expected to be JSON.
        logging_obj: Not consulted.

    Returns:
        A ``DeleteResponseResult``; built without validation when the payload
        does not validate.

    Raises:
        VolcEngineError: If the body cannot be decoded as JSON; carries the
            response text and status code.
    """
    try:
        payload = raw_response.json()
    except Exception:
        raise VolcEngineError(
            message=raw_response.text, status_code=raw_response.status_code
        )

    try:
        result = DeleteResponseResult(**payload)
    except Exception:
        verbose_logger.debug(
            "Volcengine Responses API: falling back to model_construct for delete response parsing."
        )
        result = DeleteResponseResult.model_construct(**payload)
    return result
#########################################################
########## GET RESPONSE API TRANSFORMATION ###############
#########################################################
def transform_get_response_api_request(
    self,
    response_id: str,
    api_base: str,
    litellm_params: GenericLiteLLMParams,
    headers: dict,
) -> Tuple[str, Dict]:
    """
    Build the URL and body for retrieving a stored response.

    Args:
        response_id: Identifier appended to the base URL.
        api_base: Responses endpoint URL.
        litellm_params: Not consulted.
        headers: Not consulted.

    Returns:
        ``(url, data)`` where ``url`` is ``api_base/response_id`` and ``data``
        is an empty dict.
    """
    return f"{api_base}/{response_id}", {}
def transform_get_response_api_response(
    self,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
) -> ResponsesAPIResponse:
    """
    Parse the reply of a get-response call into a ``ResponsesAPIResponse``.

    The payload is handled with the same tolerance as a freshly created
    response: ``created_at`` is passed through ``_safe_convert_created_field``
    and a payload that fails validation is built without validation instead
    of raising.

    Args:
        raw_response: HTTP response whose body is expected to be JSON.
        logging_obj: Not consulted.

    Returns:
        The parsed response with the raw and processed headers stored in
        ``_hidden_params``.

    Raises:
        VolcEngineError: If JSON decoding or the ``created_at`` conversion
            raises; carries the response text and status code.
    """
    try:
        raw_response_json = raw_response.json()
        if "created_at" in raw_response_json:
            raw_response_json["created_at"] = _safe_convert_created_field(
                raw_response_json["created_at"]
            )
    except Exception:
        raise VolcEngineError(
            message=raw_response.text, status_code=raw_response.status_code
        )

    raw_response_headers = dict(raw_response.headers)
    processed_headers = process_response_headers(raw_response_headers)

    try:
        response = ResponsesAPIResponse(**raw_response_json)
    except Exception:
        # A stored response may lack the same fields a new one does; do not
        # fail retrieval on validation alone.
        verbose_logger.debug(
            "Volcengine Responses API: falling back to model_construct for get response parsing."
        )
        response = ResponsesAPIResponse.model_construct(**raw_response_json)

    response._hidden_params["additional_headers"] = processed_headers
    response._hidden_params["headers"] = raw_response_headers
    return response
#########################################################
########## LIST INPUT ITEMS TRANSFORMATION #############
#########################################################
def transform_list_input_items_request(
    self,
    response_id: str,
    api_base: str,
    litellm_params: GenericLiteLLMParams,
    headers: dict,
    after: Optional[str] = None,
    before: Optional[str] = None,
    include: Optional[List[str]] = None,
    limit: int = 20,
    order: Literal["asc", "desc"] = "desc",
) -> Tuple[str, Dict]:
    """
    Build the URL and query parameters for listing a response's input items.

    Args:
        response_id: Identifier of the response whose input items are listed.
        api_base: Responses endpoint URL.
        litellm_params: Not consulted.
        headers: Not consulted.
        after: Optional cursor; omitted from the query when ``None``.
        before: Optional cursor; omitted from the query when ``None``.
        include: Optional field names, sent comma-joined; omitted when empty.
        limit: Page size; omitted only when ``None``.
        order: Sort order; omitted only when ``None``.

    Returns:
        ``(url, params)`` where ``url`` is ``api_base/response_id/input_items``.
    """
    candidates: Dict[str, Any] = {
        "after": after,
        "before": before,
        # An empty or missing include list contributes nothing.
        "include": ",".join(include) if include else None,
        "limit": limit,
        "order": order,
    }
    query: Dict[str, Any] = {
        name: value for name, value in candidates.items() if value is not None
    }
    return f"{api_base}/{response_id}/input_items", query
def transform_list_input_items_response(
    self,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
) -> Dict:
    """
    Decode the reply of a list-input-items call.

    Args:
        raw_response: HTTP response whose body is expected to be JSON.
        logging_obj: Not consulted.

    Returns:
        The decoded JSON body, unmodified.

    Raises:
        VolcEngineError: If the body cannot be decoded as JSON; carries the
            response text and status code.
    """
    try:
        decoded = raw_response.json()
    except Exception:
        raise VolcEngineError(
            message=raw_response.text, status_code=raw_response.status_code
        )
    return decoded
#########################################################
########## CANCEL RESPONSE API TRANSFORMATION ##########
#########################################################
def transform_cancel_response_api_request(
    self,
    response_id: str,
    api_base: str,
    litellm_params: GenericLiteLLMParams,
    headers: dict,
) -> Tuple[str, Dict]:
    """
    Build the URL and body for cancelling a response.

    Args:
        response_id: Identifier of the response to cancel.
        api_base: Responses endpoint URL.
        litellm_params: Not consulted.
        headers: Not consulted.

    Returns:
        ``(url, data)`` where ``url`` is ``api_base/response_id/cancel`` and
        ``data`` is an empty dict.
    """
    return f"{api_base}/{response_id}/cancel", {}
def transform_cancel_response_api_response(
    self,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
) -> ResponsesAPIResponse:
    """
    Parse the reply of a cancel-response call into a ``ResponsesAPIResponse``.

    The payload is handled with the same tolerance as a freshly created
    response: ``created_at`` is passed through ``_safe_convert_created_field``
    and a payload that fails validation is built without validation instead
    of raising.

    Args:
        raw_response: HTTP response whose body is expected to be JSON.
        logging_obj: Not consulted.

    Returns:
        The parsed response with the raw and processed headers stored in
        ``_hidden_params``.

    Raises:
        VolcEngineError: If JSON decoding or the ``created_at`` conversion
            raises; carries the response text and status code.
    """
    try:
        raw_response_json = raw_response.json()
        if "created_at" in raw_response_json:
            raw_response_json["created_at"] = _safe_convert_created_field(
                raw_response_json["created_at"]
            )
    except Exception:
        raise VolcEngineError(
            message=raw_response.text, status_code=raw_response.status_code
        )

    raw_response_headers = dict(raw_response.headers)
    processed_headers = process_response_headers(raw_response_headers)

    try:
        response = ResponsesAPIResponse(**raw_response_json)
    except Exception:
        # Do not fail the cancel call on validation alone when the provider
        # returns an incomplete response object.
        verbose_logger.debug(
            "Volcengine Responses API: falling back to model_construct for cancel response parsing."
        )
        response = ResponsesAPIResponse.model_construct(**raw_response_json)

    response._hidden_params["additional_headers"] = processed_headers
    response._hidden_params["headers"] = raw_response_headers
    return response
def should_fake_stream(
    self,
    model: Optional[str],
    stream: Optional[bool],
    custom_llm_provider: Optional[str] = None,
) -> bool:
    """
    Report whether streaming has to be emulated for this provider.

    Volcengine Responses API supports native streaming; never fall back to fake stream.

    Args:
        model: Not consulted.
        stream: Not consulted.
        custom_llm_provider: Not consulted.

    Returns:
        Always ``False``.
    """
    return False
@staticmethod
def _fill_missing_fields(chunk: Any, event_model: Any) -> Dict[str, Any]:
    """
    Return a copy of ``chunk`` with every field declared by ``event_model`` present.

    Fields already in the chunk are recursed into; absent fields receive, in
    order of preference, the field's non-``None`` default, the result of its
    default factory, or a heuristic default derived from its annotation.

    Args:
        chunk: Parsed payload. Anything that is not a dict is returned as-is.
        event_model: Class exposing ``model_fields``; ``None`` returns
            ``chunk`` unchanged.

    Returns:
        A shallow copy of ``chunk`` with the missing fields added, or
        ``chunk`` itself when nothing can be done.
    """
    if event_model is None or not isinstance(chunk, dict):
        return chunk

    completed: Dict[str, Any] = dict(chunk)
    declared = getattr(event_model, "model_fields", {}) or {}

    for field_name, info in declared.items():
        if field_name in completed:
            # Present: only its nested structure may need filling.
            completed[field_name] = VolcEngineResponsesAPIConfig._maybe_fill_nested(
                completed[field_name], info.annotation
            )
        elif (
            info.default is not pyd_fields.PydanticUndefined
            and info.default is not None
        ):
            # Explicit, non-None default.
            completed[field_name] = info.default
        elif (
            info.default_factory is not None
            and info.default_factory is not pyd_fields.PydanticUndefined
        ):
            completed[field_name] = info.default_factory()
        else:
            # No usable default declared: guess from the annotation.
            completed[field_name] = (
                VolcEngineResponsesAPIConfig._default_for_annotation(info.annotation)
            )
    return completed
@staticmethod
def _default_for_annotation(annotation: Any) -> Any:
    """
    Guess a placeholder value for a field from its type annotation.

    Args:
        annotation: The field's annotation.

    Returns:
        ``0`` for ``int``; a new empty list for a list annotation or a Union
        with a list member; ``None`` for everything else.
    """
    if annotation is int:
        return 0

    origin = get_origin(annotation)
    if annotation is list or origin is list:
        return []

    if origin is Union:
        # Prefer empty list when any option is a list
        members = get_args(annotation)
        if any(m is list or get_origin(m) is list for m in members):
            return []

    # Optional[...] and anything without a safer guess both map to None.
    return None
@staticmethod
def _maybe_fill_nested(value: Any, annotation: Any) -> Any:
    """
    Recursively fill nested dict/list structures based on the annotated model.

    Args:
        value: The value found in the payload.
        annotation: The annotation of the field holding ``value``.

    Returns:
        For a dict with a matching model class, the dict with missing fields
        filled; for a list whose element annotation is known, a new list with
        each element processed recursively; otherwise ``value`` unchanged.
    """
    if isinstance(value, dict):
        model_cls = VolcEngineResponsesAPIConfig._pick_model_class(annotation, value)
        if model_cls is None:
            return value
        return VolcEngineResponsesAPIConfig._fill_missing_fields(value, model_cls)

    if isinstance(value, list):
        elem_ann: Any = None
        if get_origin(annotation) is Union:
            # Optional[List[X]] / Union[List[X], ...]: the first argument of
            # the Union is the list type itself, not X. Use the element type
            # of the first parameterised list member instead.
            for member in get_args(annotation):
                member_args = get_args(member)
                if get_origin(member) is list and member_args:
                    elem_ann = member_args[0]
                    break
        else:
            args = get_args(annotation)
            elem_ann = args[0] if args else None

        if elem_ann is not None:
            return [
                VolcEngineResponsesAPIConfig._maybe_fill_nested(item, elem_ann)
                for item in value
            ]
    return value
@staticmethod
def _pick_model_class(annotation: Any, value: Any) -> Optional[Any]:
    """
    Choose the best-matching Pydantic model class for a nested dict.

    Args:
        annotation: A model class, or a Union that may contain model classes
            (recognised by having a ``model_fields`` attribute).
        value: The payload value the model should describe.

    Returns:
        ``None`` when the annotation offers no model class. Otherwise the
        first candidate whose ``type`` field is a ``Literal`` containing
        ``value["type"]``, or the first candidate when nothing matches or
        ``value`` is not a dict.
    """
    pool: List[Any] = [annotation] if hasattr(annotation, "model_fields") else []
    if get_origin(annotation) is Union:
        pool.extend(
            member
            for member in get_args(annotation)
            if hasattr(member, "model_fields")
        )
    if not pool:
        return None

    if isinstance(value, dict):
        # Try to match by literal "type" field when available
        wanted = value.get("type")
        for model in pool:
            try:
                type_field = model.model_fields.get("type")
                if type_field is not None:
                    declared = type_field.annotation
                    if get_origin(declared) is Literal and wanted in get_args(
                        declared
                    ):
                        return model
            except Exception:
                # A candidate that cannot be inspected is simply skipped.
                continue

    # Fall back to the first candidate
    return pool[0]
def supports_native_websocket(self) -> bool:
    """
    Report whether the provider offers native WebSocket for the Responses API.

    VolcEngine does not support native WebSocket for Responses API.

    Returns:
        Always ``False``.
    """
    return False