"""Volcengine Responses API configuration for LiteLLM."""
import types
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    List,
    Literal,
    Optional,
    Tuple,
    Union,
    get_args,
    get_origin,
)

import httpx
from pydantic import fields as pyd_fields

import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.core_helpers import process_response_headers
from litellm.litellm_core_utils.llm_response_utils.convert_dict_to_response import (
    _safe_convert_created_field,
)
from litellm.llms.openai.responses.transformation import OpenAIResponsesAPIConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import (
    ResponseInputParam,
    ResponsesAPIOptionalRequestParams,
    ResponsesAPIResponse,
    ResponsesAPIStreamingResponse,
)
from litellm.types.responses.main import DeleteResponseResult
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import LlmProviders

from ..common_utils import (
    VolcEngineError,
    get_volcengine_base_url,
    get_volcengine_headers,
)

# Type checkers see the real LiteLLM logging class; at runtime the annotation
# degrades to ``Any`` so this module does not import the logging module.
if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    LiteLLMLoggingObj = Any

class VolcEngineResponsesAPIConfig(OpenAIResponsesAPIConfig):
    """
    Responses API configuration for Volcengine.

    Builds on the OpenAI Responses API transformer, but:
    - only forwards the optional parameters listed in ``_SUPPORTED_OPTIONAL_PARAMS``;
    - targets the Volcengine ``/api/v3/responses`` endpoint;
    - tolerates responses and streaming events that omit fields the OpenAI
      response models require.
    """

    # Optional request parameters accepted by this config; anything else is
    # dropped before the request is built.
    _SUPPORTED_OPTIONAL_PARAMS: List[str] = [
        # Doc-listed knobs
        "instructions",
        "max_output_tokens",
        "previous_response_id",
        "store",
        "reasoning",
        "stream",
        "temperature",
        "top_p",
        "text",
        "tools",
        "tool_choice",
        "max_tool_calls",
        "thinking",
        "caching",
        "expire_at",
        "context_management",
        # LiteLLM-internal metadata (not sent to provider)
        "metadata",
        # Request plumbing helpers
        "extra_headers",
        "extra_query",
        "extra_body",
        "timeout",
    ]

    # Allowlisted names that are never provider body fields. They are also
    # stripped from a caller-supplied ``extra_body`` so they cannot leak through it.
    _NON_BODY_PARAMS: Tuple[str, ...] = (
        "metadata",
        "extra_headers",
        "extra_query",
        "extra_body",
        "timeout",
    )

    @property
    def custom_llm_provider(self) -> LlmProviders:
        """Provider identifier for this config."""
        return LlmProviders.VOLCENGINE

    def get_supported_openai_params(self, model: str) -> list:
        """
        Volcengine Responses API: only documented parameters are supported.

        ``metadata`` is accepted but not advertised: it is LiteLLM-internal and
        is dropped before the request is sent.
        """
        return ["input", "model"] + [
            param for param in self._SUPPORTED_OPTIONAL_PARAMS if param != "metadata"
        ]

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> VolcEngineError:
        """Wrap a provider error, normalising ``headers`` to ``httpx.Headers``."""
        typed_headers: httpx.Headers = (
            headers
            if isinstance(headers, httpx.Headers)
            else httpx.Headers(headers or {})
        )
        return VolcEngineError(
            status_code=status_code,
            message=error_message,
            headers=typed_headers,
        )

    def validate_environment(
        self, headers: dict, model: str, litellm_params: Optional[GenericLiteLLMParams]
    ) -> dict:
        """
        Build auth headers for Volcengine Responses API.

        The key is resolved from ``litellm_params.api_key``, then ``litellm.api_key``,
        then the ``ARK_API_KEY`` and ``VOLCENGINE_API_KEY`` secrets.

        Raises:
            ValueError: if no non-empty API key can be found.
        """
        if litellm_params is None:
            litellm_params = GenericLiteLLMParams()
        elif isinstance(litellm_params, dict):
            litellm_params = GenericLiteLLMParams(**litellm_params)

        api_key = (
            litellm_params.api_key
            or litellm.api_key
            or get_secret_str("ARK_API_KEY")
            or get_secret_str("VOLCENGINE_API_KEY")
        )

        # ``not`` (rather than ``is None``) also rejects the empty string the
        # ``or`` chain yields when the last secret is set but blank.
        if not api_key:
            raise ValueError(
                "Volcengine API key is required. Set ARK_API_KEY / VOLCENGINE_API_KEY or pass api_key."
            )

        return get_volcengine_headers(api_key=api_key, extra_headers=headers)

    def get_complete_url(
        self,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        Construct Volcengine Responses API endpoint.

        Accepts a bare host, a base ending in ``/api/v3``, or a full
        ``.../responses`` URL, and always returns the ``/responses`` endpoint.
        """
        base_url = (
            api_base
            or litellm.api_base
            or get_secret_str("VOLCENGINE_API_BASE")
            or get_secret_str("ARK_API_BASE")
            or get_volcengine_base_url()
        )

        base_url = base_url.rstrip("/")

        if base_url.endswith("/responses"):
            return base_url
        if base_url.endswith("/api/v3"):
            return f"{base_url}/responses"
        return f"{base_url}/api/v3/responses"

    def map_openai_params(
        self,
        response_api_optional_params: ResponsesAPIOptionalRequestParams,
        model: str,
        drop_params: bool,
    ) -> Dict:
        """
        Volcengine Responses API aligns with OpenAI parameters.

        Keep only the allowlisted parameters and drop LiteLLM-internal
        ``metadata``. Parameters the allowlist removes, such as
        ``parallel_tool_calls``, are logged at debug level.
        """
        provided = dict(response_api_optional_params)
        allowed = set(self._SUPPORTED_OPTIONAL_PARAMS)

        params = {key: value for key, value in provided.items() if key in allowed}

        # LiteLLM metadata is internal-only; don't send to provider
        params.pop("metadata", None)

        # Compute the dropped keys from the *input*. Checking ``params`` here
        # would never find them, because the comprehension above filtered them out.
        dropped = sorted(key for key in provided if key not in allowed)
        if dropped:
            verbose_logger.debug(
                "Volcengine Responses API: dropping unsupported params: %s", dropped
            )

        return params

    def transform_responses_api_request(
        self,
        model: str,
        input: Union[str, ResponseInputParam],
        response_api_optional_request_params: Dict,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Dict:
        """
        Re-filter the optional params with the documented allowlist, then
        delegate to the OpenAI base transformer.

        Volcengine rejects undocumented fields, so:
        - ``metadata`` is removed;
        - a dict ``extra_body`` keeps only allowlisted *body* keys, so
          ``metadata`` and plumbing helpers cannot be smuggled through it;
        - an ``extra_body`` left empty by that filtering is removed entirely.
        """
        allowed = set(self._SUPPORTED_OPTIONAL_PARAMS)

        sanitized_optional = {
            key: value
            for key, value in response_api_optional_request_params.items()
            if key in allowed
        }
        # Ensure metadata never reaches provider
        sanitized_optional.pop("metadata", None)

        extra_body = sanitized_optional.get("extra_body")
        if isinstance(extra_body, dict):
            body_allowed = allowed.difference(self._NON_BODY_PARAMS)
            filtered_body = {
                key: value for key, value in extra_body.items() if key in body_allowed
            }
            if filtered_body:
                sanitized_optional["extra_body"] = filtered_body
            else:
                sanitized_optional.pop("extra_body", None)

        return super().transform_responses_api_request(
            model=model,
            input=input,
            response_api_optional_request_params=sanitized_optional,
            litellm_params=litellm_params,
            headers=headers,
        )

    def transform_streaming_response(
        self,
        model: str,
        parsed_chunk: dict,
        logging_obj: LiteLLMLoggingObj,
    ) -> ResponsesAPIStreamingResponse:
        """
        Volcengine may omit required fields; auto-fill them using event model defaults.

        A missing or null ``response.output`` is replaced with ``[]``. Other
        missing fields are filled heuristically from the event model's field
        annotations.
        """
        chunk = parsed_chunk

        # Patch missing/null response.output on response.* events (copy, don't mutate input)
        if isinstance(chunk, dict):
            resp = chunk.get("response")
            if isinstance(resp, dict) and resp.get("output") is None:
                chunk = {**chunk, "response": {**resp, "output": []}}

        event_type = str(chunk.get("type")) if isinstance(chunk, dict) else None
        event_pydantic_model = OpenAIResponsesAPIConfig.get_event_model_class(
            event_type=event_type
        )

        patched_chunk = self._fill_missing_fields(chunk, event_pydantic_model)

        return event_pydantic_model(**patched_chunk)

    def transform_response_api_response(
        self,
        model: str,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ResponsesAPIResponse:
        """
        Log the raw body, then parse it into a ``ResponsesAPIResponse``.

        Raises:
            VolcEngineError: if logging fails or the body is not a JSON object.
        """
        try:
            logging_obj.post_call(
                original_response=raw_response.text,
                additional_args={"complete_input_dict": {}},
            )
        except Exception as e:
            raise VolcEngineError(
                message=raw_response.text, status_code=raw_response.status_code
            ) from e

        return self._build_responses_api_response(raw_response)

    @staticmethod
    def _load_response_json(raw_response: httpx.Response) -> Dict[str, Any]:
        """
        Decode ``raw_response`` as a JSON object.

        Raises:
            VolcEngineError: carrying the raw body text and HTTP status when the
                body is not valid JSON or is not a JSON object.
        """
        try:
            body = raw_response.json()
        except Exception as e:
            raise VolcEngineError(
                message=raw_response.text, status_code=raw_response.status_code
            ) from e
        if not isinstance(body, dict):
            # ``Model(**body)`` below would otherwise fail with a bare TypeError.
            raise VolcEngineError(
                message=raw_response.text, status_code=raw_response.status_code
            )
        return body

    @classmethod
    def _build_responses_api_response(
        cls, raw_response: httpx.Response
    ) -> ResponsesAPIResponse:
        """
        Parse a Responses API body into ``ResponsesAPIResponse`` and attach headers.

        ``created_at`` is normalised with ``_safe_convert_created_field``. If
        strict validation fails, the model is built with ``model_construct``
        instead of raising.
        """
        raw_response_json = cls._load_response_json(raw_response)
        if "created_at" in raw_response_json:
            raw_response_json["created_at"] = _safe_convert_created_field(
                raw_response_json["created_at"]
            )

        try:
            response = ResponsesAPIResponse(**raw_response_json)
        except Exception:
            verbose_logger.debug(
                "Volcengine Responses API: falling back to model_construct for response parsing."
            )
            response = ResponsesAPIResponse.model_construct(**raw_response_json)

        raw_response_headers = dict(raw_response.headers)
        response._hidden_params["additional_headers"] = process_response_headers(
            raw_response_headers
        )
        response._hidden_params["headers"] = raw_response_headers
        return response

    #########################################################
    ########## DELETE RESPONSE API TRANSFORMATION ##############
    #########################################################
    def transform_delete_response_api_request(
        self,
        response_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """Return ``(url, body)`` for deleting ``response_id``; the body is empty."""
        url = f"{api_base.rstrip('/')}/{response_id}"
        data: Dict = {}
        return url, data

    def transform_delete_response_api_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> DeleteResponseResult:
        """
        Parse a delete response, falling back to ``model_construct`` on validation errors.

        Raises:
            VolcEngineError: if the body is not a JSON object.
        """
        raw_response_json = self._load_response_json(raw_response)
        try:
            return DeleteResponseResult(**raw_response_json)
        except Exception:
            verbose_logger.debug(
                "Volcengine Responses API: falling back to model_construct for delete response parsing."
            )
            return DeleteResponseResult.model_construct(**raw_response_json)

    #########################################################
    ########## GET RESPONSE API TRANSFORMATION ###############
    #########################################################
    def transform_get_response_api_request(
        self,
        response_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """Return ``(url, params)`` for fetching ``response_id``; no params are sent."""
        url = f"{api_base.rstrip('/')}/{response_id}"
        data: Dict = {}
        return url, data

    def transform_get_response_api_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ResponsesAPIResponse:
        """
        Parse a retrieved response the same way as a create response.

        Raises:
            VolcEngineError: if the body is not a JSON object.
        """
        return self._build_responses_api_response(raw_response)

    #########################################################
    ########## LIST INPUT ITEMS TRANSFORMATION #############
    #########################################################
    def transform_list_input_items_request(
        self,
        response_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
        after: Optional[str] = None,
        before: Optional[str] = None,
        include: Optional[List[str]] = None,
        limit: int = 20,
        order: Literal["asc", "desc"] = "desc",
    ) -> Tuple[str, Dict]:
        """
        Return ``(url, query_params)`` for listing a response's input items.

        ``include`` is sent as a comma-separated string. ``None`` values are omitted.
        """
        url = f"{api_base.rstrip('/')}/{response_id}/input_items"
        params: Dict[str, Any] = {}
        if after is not None:
            params["after"] = after
        if before is not None:
            params["before"] = before
        if include:
            params["include"] = ",".join(include)
        if limit is not None:
            params["limit"] = limit
        if order is not None:
            params["order"] = order
        return url, params

    def transform_list_input_items_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> Dict:
        """
        Return the decoded JSON body unchanged.

        Raises:
            VolcEngineError: if the body is not a JSON object.
        """
        return self._load_response_json(raw_response)

    #########################################################
    ########## CANCEL RESPONSE API TRANSFORMATION ##########
    #########################################################
    def transform_cancel_response_api_request(
        self,
        response_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """Return ``(url, body)`` for cancelling ``response_id``; the body is empty."""
        url = f"{api_base.rstrip('/')}/{response_id}/cancel"
        data: Dict = {}
        return url, data

    def transform_cancel_response_api_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ResponsesAPIResponse:
        """
        Parse a cancel response the same way as a create response.

        Raises:
            VolcEngineError: if the body is not a JSON object.
        """
        return self._build_responses_api_response(raw_response)

    def should_fake_stream(
        self,
        model: Optional[str],
        stream: Optional[bool],
        custom_llm_provider: Optional[str] = None,
    ) -> bool:
        """
        Volcengine Responses API supports native streaming; never fall back to fake stream.
        """
        return False

    @staticmethod
    def _fill_missing_fields(chunk: Any, event_model: Any) -> Dict[str, Any]:
        """
        Heuristically fill missing fields of ``chunk`` for ``event_model``.

        Present fields are recursed into, so nested models get the same
        treatment. A missing field with a declared default or factory gets that
        value, including an explicit ``None`` default. A missing required field
        gets a zero-like guess from its annotation. Non-dict chunks, or a
        ``None`` model, are returned unchanged.
        """
        if not isinstance(chunk, dict) or event_model is None:
            return chunk

        patched: Dict[str, Any] = dict(chunk)
        fields_map = getattr(event_model, "model_fields", {}) or {}

        for name, field in fields_map.items():
            if name in patched:
                patched[name] = VolcEngineResponsesAPIConfig._maybe_fill_nested(
                    patched[name], field.annotation
                )
            elif field.default_factory is not None:
                patched[name] = field.default_factory()
            elif field.default is not pyd_fields.PydanticUndefined:
                # Honour the model's own default (even None) instead of guessing.
                patched[name] = field.default
            else:
                # Heuristic defaults for missing required fields
                patched[name] = VolcEngineResponsesAPIConfig._default_for_annotation(
                    field.annotation
                )

        return patched

    @staticmethod
    def _is_union(origin: Any) -> bool:
        """True for ``typing.Union``/``Optional`` origins and PEP 604 ``X | Y`` unions."""
        # ``types.UnionType`` only exists on Python 3.10+; default to Union otherwise.
        return origin is Union or origin is getattr(types, "UnionType", Union)

    @staticmethod
    def _strip_annotated(annotation: Any) -> Any:
        """Unwrap ``Annotated[X, ...]`` to ``X``; other annotations pass through."""
        while hasattr(annotation, "__metadata__") and hasattr(annotation, "__origin__"):
            annotation = annotation.__origin__
        return annotation

    @staticmethod
    def _default_for_annotation(annotation: Any) -> Any:
        """
        Guess a safe value for a missing required field with ``annotation``.

        Lists and dicts become empty containers; ``int``/``float``/``bool``/``str``
        become their zero values. A union becomes ``[]`` if any option is a list,
        otherwise ``None``. Anything else becomes ``None``.
        """
        cfg = VolcEngineResponsesAPIConfig
        annotation = cfg._strip_annotated(annotation)
        origin = get_origin(annotation)

        if annotation is list or origin is list:
            return []
        if annotation is dict or origin is dict:
            return {}
        if cfg._is_union(origin):
            # Prefer empty list when any option is a list
            for arg in get_args(annotation):
                arg = cfg._strip_annotated(arg)
                if arg is list or get_origin(arg) is list:
                    return []
            return None
        for scalar_type, zero_value in ((int, 0), (float, 0.0), (bool, False), (str, "")):
            if annotation is scalar_type:
                return zero_value

        # Fallback to None when no safer guess exists
        return None

    @staticmethod
    def _list_item_annotation(annotation: Any) -> Any:
        """
        Return the element annotation to use for the items of a list value.

        Through a Union/Optional this looks for a ``List[...]`` option, so
        ``Optional[List[X]]`` yields ``X``. Otherwise it falls back to the first
        type argument. Returns ``None`` when nothing applies.
        """
        cfg = VolcEngineResponsesAPIConfig
        annotation = cfg._strip_annotated(annotation)
        args = get_args(annotation)
        if cfg._is_union(get_origin(annotation)):
            for arg in args:
                arg = cfg._strip_annotated(arg)
                if get_origin(arg) is list:
                    item_args = get_args(arg)
                    return item_args[0] if item_args else None
            return None
        return args[0] if args else None

    @staticmethod
    def _maybe_fill_nested(value: Any, annotation: Any) -> Any:
        """
        Recursively fill nested dict/list structures based on the annotated model.

        Dicts are filled against the best-matching model class. Lists have each
        element filled against the list's element annotation. Other values are
        returned unchanged.
        """
        cfg = VolcEngineResponsesAPIConfig

        if isinstance(value, dict):
            model_cls = cfg._pick_model_class(annotation, value)
            if model_cls is not None:
                return cfg._fill_missing_fields(value, model_cls)
            return value

        if isinstance(value, list):
            elem_ann = cfg._list_item_annotation(annotation)
            if elem_ann is not None:
                return [cfg._maybe_fill_nested(v, elem_ann) for v in value]

        return value

    @staticmethod
    def _literal_type_values(model_cls: Any) -> Optional[Tuple[Any, ...]]:
        """Return the allowed values of ``model_cls``'s ``type`` field if it is a ``Literal``, else None."""
        try:
            type_field = model_cls.model_fields.get("type")
        except Exception:
            return None
        if type_field is None:
            return None
        literal_ann = VolcEngineResponsesAPIConfig._strip_annotated(type_field.annotation)
        if get_origin(literal_ann) is not Literal:
            return None
        return get_args(literal_ann)

    @staticmethod
    def _pick_model_class(annotation: Any, value: Any) -> Optional[Any]:
        """
        Choose the best-matching Pydantic model class for a nested dict.

        Candidates are ``annotation`` itself, or its Union members (including
        ``X | Y`` and ``Annotated`` forms), that expose ``model_fields``.

        Selection:
        - If ``value`` has a ``type``, return the candidate whose ``type``
          Literal contains it. If none does, return the first candidate whose
          ``type`` is not Literal-constrained, or ``None``. A model whose
          discriminator contradicts the value is never force-fitted.
        - If ``value`` has no ``type``, return the first candidate.
        """
        cfg = VolcEngineResponsesAPIConfig
        annotation = cfg._strip_annotated(annotation)

        candidates: List[Any] = []
        if hasattr(annotation, "model_fields"):
            candidates.append(annotation)
        if cfg._is_union(get_origin(annotation)):
            for arg in get_args(annotation):
                arg = cfg._strip_annotated(arg)
                if hasattr(arg, "model_fields"):
                    candidates.append(arg)

        if not candidates:
            return None

        v_type = value.get("type") if isinstance(value, dict) else None
        if v_type is None:
            # Nothing to discriminate on; fall back to the first candidate.
            return candidates[0]

        fallback: Optional[Any] = None
        for candidate in candidates:
            literal_values = cfg._literal_type_values(candidate)
            if literal_values is None:
                if fallback is None:
                    fallback = candidate
            elif v_type in literal_values:
                return candidate
        return fallback

    def supports_native_websocket(self) -> bool:
        """VolcEngine does not support native WebSocket for Responses API"""
        return False