chore: initial public snapshot for github upload

This commit is contained in:
Your Name
2026-03-26 20:06:14 +08:00
commit 0e5ecd930e
3497 changed files with 1586236 additions and 0 deletions

View File

@@ -0,0 +1,260 @@
from __future__ import annotations
import json
import time
from typing import AsyncIterator, Iterator, Optional
import httpx
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.openai import OpenAIChatCompletionChunk
from ...custom_httpx.llm_http_handler import BaseLLMHTTPHandler
# -------------------------------
# Errors
# -------------------------------
class GenAIHubOrchestrationError(BaseLLMException):
    """Exception raised for SAP GenAI Hub Orchestration Service failures."""

    def __init__(self, status_code: int, message: str):
        # Mirror the arguments onto the instance for callers that inspect them
        # directly, then delegate to the litellm base exception.
        self.status_code = status_code
        self.message = message
        super().__init__(status_code=status_code, message=message)
# -------------------------------
# Stream parsing helpers
# -------------------------------
def _now_ts() -> int:
return int(time.time())
def _is_terminal_chunk(chunk: OpenAIChatCompletionChunk) -> bool:
"""OpenAI-shaped chunk is terminal if any choice has a non-None finish_reason."""
try:
for ch in chunk.choices or []:
if ch.finish_reason is not None:
return True
except Exception:
pass
return False
class _StreamParser:
    """Normalize orchestration streaming events into OpenAI-like chunks."""

    @staticmethod
    def _from_orchestration_result(evt: dict) -> Optional[OpenAIChatCompletionChunk]:
        """
        Accepts orchestration_result shape and maps it to an OpenAI-like *chunk*.

        Returns None when the event has no (or an empty) ``orchestration_result``.
        """
        orc = evt.get("orchestration_result") or {}
        if not orc:
            return None
        # Fall back to the surrounding event's request_id/created when the
        # orchestration payload omits them, so every emitted chunk is complete.
        return OpenAIChatCompletionChunk.model_validate(
            {
                "id": orc.get("id") or evt.get("request_id") or "stream-chunk",
                "object": orc.get("object") or "chat.completion.chunk",
                "created": orc.get("created") or evt.get("created") or _now_ts(),
                "model": orc.get("model") or "unknown",
                "choices": [
                    {
                        "index": c.get("index", 0),
                        "delta": c.get("delta") or {},
                        "finish_reason": c.get("finish_reason"),
                    }
                    for c in (orc.get("choices") or [])
                ],
            }
        )

    @staticmethod
    def to_openai_chunk(event_obj: dict) -> Optional[OpenAIChatCompletionChunk]:
        """
        Accepts:
        - {"final_result": <openai-style CHUNK>} (IMPORTANT: this is just another chunk, NOT terminal)
        - {"orchestration_result": {...}} (map to chunk)
        - already-openai-shaped chunks
        - other events (ignored; returns None)
        Raises:
        - ValueError for in-stream error objects
        """
        # In-stream error per spec (surface as exception). Checked first so an
        # error event is never misread as a content chunk.
        if "code" in event_obj or "error" in event_obj:
            raise ValueError(json.dumps(event_obj))
        # FINAL RESULT IS *NOT* TERMINAL: treat it as the next chunk
        if "final_result" in event_obj:
            fr = event_obj["final_result"] or {}
            # ensure it looks like an OpenAI chunk
            if "object" not in fr:
                fr["object"] = "chat.completion.chunk"
            return OpenAIChatCompletionChunk.model_validate(fr)
        # Orchestration incremental delta
        if "orchestration_result" in event_obj:
            return _StreamParser._from_orchestration_result(event_obj)
        # Already an OpenAI-like chunk
        if "choices" in event_obj and "object" in event_obj:
            return OpenAIChatCompletionChunk.model_validate(event_obj)
        # Unknown / heartbeat / metrics
        return None
# -------------------------------
# Iterators
# -------------------------------
class SAPStreamIterator:
    """
    Sync iterator over an httpx streaming response that yields OpenAIChatCompletionChunk.
    Accepts both SSE `data: ...` and raw JSON lines. Closes on terminal chunk or [DONE].
    """

    def __init__(
        self,
        response: Iterator,
        event_prefix: str = "data: ",
        final_msg: str = "[DONE]",
    ):
        # `response` is expected to yield decoded text lines
        # (e.g. httpx Response.iter_lines()).
        self._resp = response
        self._iter = response
        self._prefix = event_prefix
        self._final = final_msg
        # Once True the iterator is exhausted; __next__ raises StopIteration.
        self._done = False

    def __iter__(self) -> Iterator[OpenAIChatCompletionChunk]:
        return self

    def __next__(self) -> OpenAIChatCompletionChunk:
        if self._done:
            raise StopIteration
        for raw in self._iter:
            line = (raw or "").strip()
            if not line:
                # Skip SSE keep-alive blank lines.
                continue
            # Strip the SSE "data: " prefix when present; raw JSON lines pass through.
            payload = (
                line[len(self._prefix) :] if line.startswith(self._prefix) else line
            )
            if payload == self._final:
                self._safe_close()
                raise StopIteration
            try:
                obj = json.loads(payload)
            except Exception:
                # Ignore non-JSON noise (partial lines, comments, heartbeats).
                continue
            try:
                chunk = _StreamParser.to_openai_chunk(obj)
            except ValueError as e:
                # In-stream error object: stop and surface it.
                # NOTE(review): the async iterator wraps this in
                # GenAIHubOrchestrationError(502) instead — confirm which is intended.
                self._safe_close()
                raise e
            if chunk is None:
                # Unknown / heartbeat / metrics event.
                continue
            # Close on terminal
            if _is_terminal_chunk(chunk):
                self._safe_close()
            return chunk
        self._safe_close()
        raise StopIteration

    def _safe_close(self) -> None:
        # Marks the iterator exhausted; it does NOT close the underlying
        # response — the HTTP client owns that lifecycle.
        if self._done:
            return
        else:
            self._done = True
class AsyncSAPStreamIterator:
    """Async counterpart of SAPStreamIterator: yields OpenAI-style chunks."""

    # Marker consumed by litellm to distinguish sync/async stream wrappers.
    sync_stream = False

    def __init__(
        self,
        response: AsyncIterator,
        event_prefix: str = "data: ",
        final_msg: str = "[DONE]",
    ):
        self._resp = response
        self._prefix = event_prefix
        self._final = final_msg
        # Bound to the underlying line iterator lazily on first __anext__.
        self._line_iter = None
        # Once True the iterator is exhausted; __anext__ raises StopAsyncIteration.
        self._done = False

    def __aiter__(self):
        return self

    async def __anext__(self):
        if self._done:
            raise StopAsyncIteration
        if self._line_iter is None:
            self._line_iter = self._resp
        while True:
            try:
                raw = await self._line_iter.__anext__()
            except (StopAsyncIteration, httpx.ReadError, OSError):
                # Transport errors at EOF are treated the same as a normal end.
                await self._aclose()
                raise StopAsyncIteration
            line = (raw or "").strip()
            if not line:
                # Skip SSE keep-alive blank lines.
                continue
            # Strip the SSE "data: " prefix when present; raw JSON passes through.
            payload = (
                line[len(self._prefix) :] if line.startswith(self._prefix) else line
            )
            if payload == self._final:
                await self._aclose()
                raise StopAsyncIteration
            try:
                obj = json.loads(payload)
            except Exception:
                # Ignore non-JSON noise (partial lines, comments, heartbeats).
                continue
            try:
                chunk = _StreamParser.to_openai_chunk(obj)
            except ValueError as e:
                # NOTE(review): errors are wrapped in a 502 here, but the sync
                # iterator re-raises the bare ValueError — confirm which is intended.
                await self._aclose()
                raise GenAIHubOrchestrationError(502, str(e))
            if chunk is None:
                continue
            # If terminal, close BEFORE returning. Next __anext__() will stop immediately.
            if any(c.finish_reason is not None for c in (chunk.choices or [])):
                await self._aclose()
            return chunk

    async def _aclose(self):
        # Marks the iterator exhausted; the HTTP client owns the response lifecycle.
        if self._done:
            return
        else:
            self._done = True
# -------------------------------
# LLM handler
# -------------------------------
class GenAIHubOrchestration(BaseLLMHTTPHandler):
    """HTTP handler for SAP GenAI Hub orchestration requests."""

    def _add_stream_param_to_request_body(
        self, data: dict, provider_config: BaseConfig, fake_stream: bool
    ):
        """Force ``config.stream.enabled = True`` on the orchestration request body.

        Fix: the original probed ``data.get("config", {})`` but then assigned via
        ``data["config"]["stream"]``, which raised KeyError whenever the body had
        no ``config`` key at all. ``setdefault`` handles a missing ``config`` and
        a missing/None ``stream`` uniformly.
        """
        config = data.setdefault("config", {})
        stream_cfg = config.get("stream")
        if stream_cfg is not None:
            stream_cfg["enabled"] = True
        else:
            config["stream"] = {"enabled": True}
        return data

View File

@@ -0,0 +1,130 @@
from typing import Union, Literal
from pydantic import BaseModel, Field, field_validator
def validate_different_content(v: Union[str, dict, list]) -> str:
    """Coerce OpenAI-style message content into a plain string.

    Accepts a string, a ``{"text": ...}`` dict, or a list of strings /
    ``{"text": ...}`` dicts (joined with newlines; empty text parts dropped).
    Empty containers normalize to "".

    Raises:
        ValueError: for any shape that cannot be coerced to a string.

    Fix: removed the unreachable ``return v`` that followed the raise.
    """
    if v in ((), {}, []):
        return ""
    if isinstance(v, dict) and "text" in v:
        return v["text"]
    if isinstance(v, list):
        parts = []
        for item in v:
            if isinstance(item, dict) and "text" in item:
                # Drop empty text parts; keep non-empty ones.
                if item["text"]:
                    parts.append(item["text"])
            elif isinstance(item, str):
                parts.append(item)
        return "\n".join(parts)
    if isinstance(v, str):
        return v
    raise ValueError("Content must be a string")
class TextContent(BaseModel):
    """OpenAI-style text content part: ``{"type": "text", "text": "..."}``."""

    # Serialized under the "type" key; trailing underscore avoids shadowing builtins.
    type_: Literal["text"] = Field(default="text", alias="type")
    text: str
class ImageURLContent(BaseModel):
    """Image reference inside an image content part."""

    url: str
    # OpenAI image detail hint; "auto" lets the provider choose.
    detail: str = "auto"
class ImageContent(BaseModel):
    """OpenAI-style image content part: ``{"type": "image_url", "image_url": {...}}``."""

    type_: Literal["image_url"] = Field(default="image_url", alias="type")
    image_url: ImageURLContent
class FunctionObj(BaseModel):
    """Function invocation payload inside a tool call (arguments are a JSON string)."""

    name: str
    arguments: str
class FunctionTool(BaseModel):
    """Function definition inside a chat-completion tool declaration."""

    description: str = ""
    name: str
    # JSON Schema for the function arguments; must be an object schema.
    parameters: dict = {"type": "object", "properties": {}}
    strict: bool = False

    @field_validator("parameters", mode="before")
    @classmethod
    def ensure_object_type(cls, v: dict) -> dict:
        """Ensure parameters has type='object' as required by SAP Orchestration Service."""
        if not v:
            return {"type": "object", "properties": {}}
        if "type" not in v:
            v = {"type": "object", **v}
        if "properties" not in v:
            v["properties"] = {}
        return v
class ChatCompletionTool(BaseModel):
    """OpenAI tool wrapper: ``{"type": "function", "function": {...}}``."""

    type_: Literal["function"] = Field(default="function", alias="type")
    function: FunctionTool
class MessageToolCall(BaseModel):
    """A tool call emitted by the assistant within a message."""

    id: str
    type_: Literal["function"] = Field(default="function", alias="type")
    function: FunctionObj
class SAPMessage(BaseModel):
    """
    Model for SystemChatMessage and DeveloperChatMessage
    """

    role: Literal["system", "developer"] = "system"
    content: str
    # Coerce dict/list content shapes into a plain string before validation.
    _content_validator = field_validator("content", mode="before")(
        validate_different_content
    )
class SAPUserMessage(BaseModel):
    """User chat message; content may be a string or structured content parts."""

    role: Literal["user"] = "user"
    content: Union[
        str, TextContent, ImageContent, list[Union[TextContent, ImageContent]]
    ]
class SAPAssistantMessage(BaseModel):
    """Assistant chat message, optionally carrying tool calls."""

    role: Literal["assistant"] = "assistant"
    content: str = ""
    refusal: str = ""
    tool_calls: list[MessageToolCall] = []
    # Coerce dict/list content shapes into a plain string before validation.
    _content_validator = field_validator("content", mode="before")(
        validate_different_content
    )
class SAPToolChatMessage(BaseModel):
    """Tool-result message linked to the originating tool call by id."""

    role: Literal["tool"] = "tool"
    tool_call_id: str
    content: str
    # Coerce dict/list content shapes into a plain string before validation.
    _content_validator = field_validator("content", mode="before")(
        validate_different_content
    )
class ResponseFormat(BaseModel):
    """Simple response-format selector: plain text or generic JSON object."""

    type_: Literal["text", "json_object"] = Field(default="text", alias="type")
class JSONResponseSchema(BaseModel):
    """Named JSON schema for structured output (serialized under "schema")."""

    description: str = ""
    name: str
    # Trailing underscore avoids clashing with BaseModel.schema; alias restores the key.
    schema_: dict = Field(default_factory=dict, alias="schema")
    strict: bool = False
class ResponseFormatJSONSchema(BaseModel):
    """Response-format wrapper for schema-constrained JSON output."""

    type_: Literal["json_schema"] = Field(default="json_schema", alias="type")
    json_schema: JSONResponseSchema

View File

@@ -0,0 +1,351 @@
"""
Translate from OpenAI's `/v1/chat/completions` to SAP Generative AI Hub's Orchestration Service `/v2/completion`
"""
from typing import (
List,
Optional,
Union,
Dict,
Tuple,
Any,
TYPE_CHECKING,
Iterator,
AsyncIterator,
)
from functools import cached_property
import litellm
import httpx
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse
from ...openai.chat.gpt_transformation import OpenAIGPTConfig
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
from ..credentials import get_token_creator
from .models import (
SAPMessage,
SAPAssistantMessage,
SAPToolChatMessage,
ChatCompletionTool,
ResponseFormatJSONSchema,
ResponseFormat,
SAPUserMessage,
)
from .handler import (
GenAIHubOrchestrationError,
AsyncSAPStreamIterator,
SAPStreamIterator,
)
def validate_dict(data: dict, model) -> dict:
    """Validate *data* against the pydantic *model* and return its aliased dump."""
    instance = model(**data)
    return instance.model_dump(by_alias=True)
class GenAIHubOrchestrationConfig(OpenAIGPTConfig):
    """
    Translate OpenAI chat-completion requests/responses to SAP Generative AI
    Hub's Orchestration Service `/v2/completion` endpoint.
    """

    # OpenAI-style defaults; litellm config convention keeps them on the class.
    frequency_penalty: Optional[int] = None
    function_call: Optional[Union[str, dict]] = None
    functions: Optional[list] = None
    logit_bias: Optional[dict] = None
    max_tokens: Optional[int] = None
    n: Optional[int] = None
    presence_penalty: Optional[int] = None
    stop: Optional[Union[str, list]] = None
    temperature: Optional[int] = None
    top_p: Optional[int] = None
    response_format: Optional[dict] = None
    tools: Optional[list] = None
    tool_choice: Optional[Union[str, dict]] = None
    # SAP-specific: which model version the orchestration deployment should use.
    model_version: str = "latest"

    def __init__(
        self,
        frequency_penalty: Optional[int] = None,
        function_call: Optional[Union[str, dict]] = None,
        functions: Optional[list] = None,
        logit_bias: Optional[dict] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[int] = None,
        stop: Optional[Union[str, list]] = None,
        temperature: Optional[int] = None,
        top_p: Optional[int] = None,
        response_format: Optional[dict] = None,
        tools: Optional[list] = None,
        tool_choice: Optional[Union[str, dict]] = None,
    ) -> None:
        locals_ = locals().copy()
        # litellm config pattern: non-None constructor args are stored on the
        # *class*, acting as process-wide defaults shared by all instances.
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)
        # Credentials are resolved lazily via run_env_setup() on first use.
        self.token_creator = None
        self._base_url = None
        self._resource_group = None

    def run_env_setup(self, service_key: Optional[str] = None) -> None:
        """Resolve token factory, AI API base URL and resource group from env/config."""
        try:
            self.token_creator, self._base_url, self._resource_group = get_token_creator(service_key)  # type: ignore
        except ValueError as err:
            # Surface credential problems as a 400-style provider error.
            raise GenAIHubOrchestrationError(status_code=400, message=err.args[0])

    @property
    def headers(self) -> Dict[str, str]:
        """Auth + routing headers; mints a (cached) bearer token on each access."""
        if self.token_creator is None:
            self.run_env_setup()
        access_token = self.token_creator()  # type: ignore
        return {
            "Authorization": access_token,
            "AI-Resource-Group": self.resource_group,
            "Content-Type": "application/json",
            "AI-Client-Type": "LiteLLM",
        }

    @property
    def base_url(self) -> str:
        # AI API root, resolved lazily on first access.
        if self._base_url is None:
            self.run_env_setup()
        return self._base_url  # type: ignore

    @property
    def resource_group(self) -> str:
        # SAP AI Core resource group, resolved lazily on first access.
        if self._resource_group is None:
            self.run_env_setup()
        return self._resource_group  # type: ignore

    @cached_property
    def deployment_url(self) -> str:
        """Discover the newest 'orchestration' deployment URL via the lm API.

        Cached per config instance.
        NOTE(review): raises IndexError when no orchestration deployment
        exists — confirm whether a clearer error is wanted.
        """
        client = litellm.module_level_client
        deployments = client.get(
            f"{self.base_url}/lm/deployments", headers=self.headers
        ).json()
        valid: List[Tuple[str, str]] = []
        for dep in deployments.get("resources", []):
            if dep.get("scenarioId") == "orchestration":
                # Double-check the configuration's executable to filter lookalikes.
                cfg = client.get(
                    f'{self.base_url}/lm/configurations/{dep["configurationId"]}',
                    headers=self.headers,
                ).json()
                if cfg.get("executableId") == "orchestration":
                    valid.append((dep["deploymentUrl"], dep["createdAt"]))
        # newest first
        return sorted(valid, key=lambda x: x[1], reverse=True)[0][0]

    @classmethod
    def get_config(cls):
        return super().get_config()

    def get_supported_openai_params(self, model):
        """OpenAI params supported for *model*, minus per-provider gaps."""
        params = [
            "frequency_penalty",
            "logit_bias",
            "logprobs",
            "top_logprobs",
            "max_tokens",
            "max_completion_tokens",
            "prediction",
            "n",
            "presence_penalty",
            "seed",
            "stop",
            "stream",
            "stream_options",
            "temperature",
            "top_p",
            "tools",
            "tool_choice",
            "function_call",
            "functions",
            "extra_headers",
            "parallel_tool_calls",
            "response_format",
            "timeout",
        ]
        # Remove response_format for providers that don't support it on SAP GenAI Hub
        if (
            model.startswith("amazon")
            or model.startswith("cohere")
            or model.startswith("alephalpha")
            or model == "gpt-4"
        ):
            params.remove("response_format")
        if model.startswith("gemini") or model.startswith("amazon"):
            params.remove("tool_choice")
        return params

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """Resolve credentials (``api_key`` doubles as an AI Core service key)."""
        if api_key:
            self.run_env_setup(api_key)
        return self.headers

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ):
        # Always route to the discovered orchestration deployment; the caller's
        # api_base is intentionally ignored.
        api_base_ = f"{self.deployment_url}/v2/completion"
        return api_base_

    def transform_request(
        self,
        model: str,
        messages: List[Dict[str, str]],  # type: ignore
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """Build the orchestration `/v2/completion` request body from OpenAI inputs."""
        # Filter out parameters that are not valid model params for SAP Orchestration API
        # - tools, model_version, deployment_url: handled separately
        excluded_params = {"tools", "model_version", "deployment_url"}
        # Filter strict for GPT models only - SAP AI Core doesn't accept it as a model param
        # LangChain agents pass strict=true at top level, which fails for GPT models
        # Anthropic models accept strict, so preserve it for them
        if model.startswith("gpt"):
            excluded_params.add("strict")
        model_params = {
            k: v for k, v in optional_params.items() if k not in excluded_params
        }
        model_version = optional_params.pop("model_version", "latest")
        # Map each OpenAI message onto its SAP message schema (validates/normalizes).
        template = []
        for message in messages:
            if message["role"] == "user":
                template.append(validate_dict(message, SAPUserMessage))
            elif message["role"] == "assistant":
                template.append(validate_dict(message, SAPAssistantMessage))
            elif message["role"] == "tool":
                template.append(validate_dict(message, SAPToolChatMessage))
            else:
                # system / developer messages share one schema
                template.append(validate_dict(message, SAPMessage))
        tools_ = optional_params.pop("tools", [])
        tools_ = [validate_dict(tool, ChatCompletionTool) for tool in tools_]
        if tools_ != []:
            tools = {"tools": tools_}
        else:
            tools = {}
        response_format = model_params.pop("response_format", {})
        resp_type = response_format.get("type", None)
        if resp_type:
            if resp_type == "json_schema":
                response_format = validate_dict(
                    response_format, ResponseFormatJSONSchema
                )
            else:
                response_format = validate_dict(response_format, ResponseFormat)
            response_format = {"response_format": response_format}
        # `stream` itself is signalled via config.stream (see the HTTP handler),
        # so it must not leak into the model params.
        model_params.pop("stream", False)
        stream_config = {}
        if "stream_options" in model_params:
            stream_options = model_params.pop("stream_options", {})
            stream_config["chunk_size"] = stream_options.get("chunk_size", 100)
            if "delimiters" in stream_options:
                stream_config["delimiters"] = stream_options.get("delimiters")
        config = {
            "config": {
                "modules": {
                    "prompt_templating": {
                        "prompt": {"template": template, **tools, **response_format},
                        "model": {
                            "name": model,
                            "params": model_params,
                            "version": model_version,
                        },
                    },
                },
                "stream": stream_config,
            }
        }
        return config

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """Unwrap the provider's ``final_result`` envelope into a ModelResponse."""
        logging_obj.post_call(
            input=messages,
            api_key=api_key,
            original_response=raw_response.text,
            additional_args={"complete_input_dict": request_data},
        )
        response = ModelResponse.model_validate(raw_response.json()["final_result"])
        # Strip markdown code blocks if JSON response_format was used with Anthropic models
        # SAP GenAI Hub with Anthropic models sometimes wraps JSON in ```json ... ```
        # based on prompt phrasing. GPT/Gemini models don't exhibit this behavior,
        # so we gate the stripping to avoid accidentally modifying valid responses.
        response_format = optional_params.get("response_format", {})
        if response_format.get("type") in ("json_object", "json_schema"):
            if model.startswith("anthropic"):
                response = self._strip_markdown_json(response)
        return response

    def _strip_markdown_json(self, response: ModelResponse) -> ModelResponse:
        """Strip markdown code block wrapper from JSON content if present.
        SAP GenAI Hub with Anthropic models sometimes returns JSON wrapped in
        markdown code blocks (```json ... ```) depending on prompt phrasing.
        This method strips that wrapper to ensure consistent JSON output.
        """
        import re

        for choice in response.choices or []:
            if choice.message and choice.message.content:
                content = choice.message.content.strip()
                # Match ```json ... ``` or ``` ... ```
                match = re.match(r"^```(?:json)?\s*\n?(.*?)\n?```$", content, re.DOTALL)
                if match:
                    choice.message.content = match.group(1).strip()
        return response

    def get_model_response_iterator(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], "ModelResponse"],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ):
        """Wrap a raw line stream in the matching (a)sync SAP chunk iterator."""
        if sync_stream:
            return SAPStreamIterator(response=streaming_response)  # type: ignore
        else:
            return AsyncSAPStreamIterator(response=streaming_response)  # type: ignore

View File

@@ -0,0 +1,332 @@
from __future__ import annotations
from typing import Any, Callable, Dict, Final, List, Optional, Sequence, Tuple
from datetime import datetime, timedelta, timezone
from threading import Lock
from pathlib import Path
from dataclasses import dataclass
import json
import os
import tempfile
from litellm import sap_service_key
from litellm.llms.custom_httpx.http_handler import _get_httpx_client
# Suffix appended to auth/cert URLs to form the OAuth token endpoint.
AUTH_ENDPOINT_SUFFIX = "/oauth/token"
# Environment variables controlling credential discovery.
CONFIG_FILE_ENV_VAR = "AICORE_CONFIG"
HOME_PATH_ENV_VAR = "AICORE_HOME"
PROFILE_ENV_VAR = "AICORE_PROFILE"
VCAP_SERVICES_ENV_VAR = "VCAP_SERVICES"
VCAP_AICORE_SERVICE_NAME = "aicore"
SERVICE_KEY_ENV_VAR = "AICORE_SERVICE_KEY"
# Default config directory when $AICORE_HOME is unset: ~/.aicore
DEFAULT_HOME_PATH = os.path.join(os.path.expanduser("~"), ".aicore")
def _get_home() -> str:
    """Resolve the AI Core home directory, honoring $AICORE_HOME."""
    return os.environ.get(HOME_PATH_ENV_VAR, DEFAULT_HOME_PATH)
def _get_nested(d: Dict[str, Any], path: Sequence[str]) -> Any:
cur: Any = d
for k in path:
if not isinstance(cur, dict) or k not in cur:
raise KeyError(".".join(path))
cur = cur[k]
return cur
def _load_json_env(var_name: str) -> Optional[Dict[str, Any]]:
raw = os.environ.get(var_name)
if not raw:
return None
try:
return json.loads(raw)
except json.JSONDecodeError:
return None
def _load_vcap() -> Dict[str, Any]:
    """Return parsed $VCAP_SERVICES, or {} when absent or invalid."""
    parsed = _load_json_env(VCAP_SERVICES_ENV_VAR)
    return parsed if parsed is not None else {}
def _get_vcap_service(label: str) -> Optional[Dict[str, Any]]:
    """Return the first VCAP service entry whose label matches, else None."""
    matches = (
        svc
        for services in _load_vcap().values()
        for svc in services
        if svc.get("label") == label
    )
    return next(matches, None)
@dataclass(frozen=True)
class CredentialsValue:
    """Declarative spec for one credential field and where to find it."""

    # Field name in the resolved credentials dict (and AICORE_<NAME> env var).
    name: str
    # Path under the VCAP service's "credentials" object, when applicable.
    vcap_key: Optional[Tuple[str, ...]] = None
    # Value used when no source provides one.
    default: Optional[str] = None
    # Optional normalization applied to the resolved value.
    transform_fn: Optional[Callable[[str], str]] = None
# Ordered catalogue of every credential the SAP AI Core client may need.
CREDENTIAL_VALUES: Final[List[CredentialsValue]] = [
    # OAuth client credentials
    CredentialsValue("client_id", ("clientid",)),
    CredentialsValue("client_secret", ("clientsecret",)),
    # Token endpoint: normalized to always end with /oauth/token
    CredentialsValue(
        "auth_url",
        ("url",),
        transform_fn=lambda url: url.rstrip("/")
        + ("" if url.endswith(AUTH_ENDPOINT_SUFFIX) else AUTH_ENDPOINT_SUFFIX),
    ),
    # AI API base URL: normalized to always end with /v2
    CredentialsValue(
        "base_url",
        ("serviceurls", "AI_API_URL"),
        transform_fn=lambda url: url.rstrip("/")
        + ("" if url.endswith("/v2") else "/v2"),
    ),
    CredentialsValue("resource_group", default="default"),
    # mTLS token endpoint (X.509 flow); replaces auth_url when present
    CredentialsValue(
        "cert_url",
        ("certurl",),
        transform_fn=lambda url: url.rstrip("/")
        + ("" if url.endswith(AUTH_ENDPOINT_SUFFIX) else AUTH_ENDPOINT_SUFFIX),
    ),
    # file paths (kept for config compatibility)
    CredentialsValue("cert_file_path"),
    CredentialsValue("key_file_path"),
    # inline PEMs from VCAP; un-escape literal "\n" sequences
    CredentialsValue(
        "cert_str", ("certificate",), transform_fn=lambda s: s.replace("\\n", "\n")
    ),
    CredentialsValue(
        "key_str", ("key",), transform_fn=lambda s: s.replace("\\n", "\n")
    ),
]
def init_conf(profile: Optional[str] = None) -> Dict[str, Any]:
    """
    Loads config JSON from:
    1) $AICORE_CONFIG if set, otherwise
    2) $AICORE_HOME/config.json (or config_<profile>.json when profile is given/not default)
    Returns {} when nothing is found.

    Raises:
        KeyError: when the located file exists but is not valid JSON.
            NOTE(review): KeyError is an unusual type for this — confirm
            callers rely on it before changing to ValueError.
        FileNotFoundError: when an explicit $AICORE_CONFIG path or a
            non-default profile was requested but no file exists there.
    """
    home = Path(_get_home())
    # Explicit argument takes precedence over $AICORE_PROFILE.
    profile = profile or os.environ.get(PROFILE_ENV_VAR)
    cfg_env = os.getenv(CONFIG_FILE_ENV_VAR)
    cfg_path = (
        Path(cfg_env)
        if cfg_env
        else (
            home
            / (
                "config.json"
                if profile in (None, "", "default")
                else f"config_{profile}.json"
            )
        )
    )
    if cfg_path and cfg_path.exists():
        try:
            with cfg_path.open(encoding="utf-8") as f:
                return json.load(f)
        except json.JSONDecodeError:
            raise KeyError(f"{cfg_path} is not valid JSON. Please fix or remove it!")
    # If an explicit non-default profile was requested but not found, raise.
    if cfg_env or (profile not in (None, "", "default")):
        raise FileNotFoundError(
            f"Unable to locate profile config file at '{cfg_path}' in AICORE_HOME '{home}'"
        )
    return {}
def _env_name(name: str) -> str:
return f"AICORE_{name.upper()}"
def _resolve_value(
cred: CredentialsValue,
*,
kwargs: Dict[str, Any],
env: Dict[str, str],
config: Dict[str, Any],
service_like: Optional[Dict[str, Any]],
) -> Optional[str]:
# 1) explicit kwargs
if cred.name in kwargs and kwargs[cred.name] is not None:
return kwargs[cred.name]
# 2) environment variables (primary name)
env_key = _env_name(cred.name)
if env_key in env and env[env_key] is not None:
return env[env_key]
# 3) config file (accept both prefixed and plain keys)
for key in (env_key, cred.name):
if key in config and config[key] is not None:
return config[key]
# 4) service-like source (AICORE_SERVICE_KEY first, else VCAP)
if service_like and cred.vcap_key:
try:
val = _get_nested(service_like, ("credentials",) + cred.vcap_key)
if val is not None:
return val
except KeyError:
pass
# 5) default
return cred.default
def fetch_credentials(
    service_key: Optional[str] = None, profile: Optional[str] = None, **kwargs
) -> Dict[str, str]:
    """
    Resolution order per key:
        kwargs
        > env (AICORE_<NAME>)
        > config (AICORE_<NAME> or plain <name>)
        > service-like source from JSON in $AICORE_SERVICE_KEY (same structure as a VCAP service object)
          falling back to service entry in $VCAP_SERVICES with label 'aicore'
        > default
    """
    config = init_conf(profile)
    env = os.environ  # snapshot for testability
    service_like = None
    if not config:
        # Prefer AICORE_SERVICE_KEY if present; otherwise fall back to the VCAP service.
        # NOTE(review): `service_key` arrives as Optional[str]; _get_nested expects
        # a dict, so a raw JSON *string* here silently resolves nothing — confirm
        # whether it should be json.loads()'d first.
        service_like = (
            service_key
            or sap_service_key
            or _load_json_env(SERVICE_KEY_ENV_VAR)
            or _get_vcap_service(VCAP_AICORE_SERVICE_NAME)
        )
    out: Dict[str, str] = {}
    for cred in CREDENTIAL_VALUES:
        value = _resolve_value(cred, kwargs=kwargs, env=env, config=config, service_like=service_like)  # type: ignore
        if value is None:
            # Unresolved keys are omitted rather than stored as None.
            continue
        if cred.transform_fn:
            value = cred.transform_fn(value)
        out[cred.name] = value
    # The X.509 token endpoint (cert_url) supersedes the password-flow auth_url.
    if "cert_url" in out.keys():
        out["auth_url"] = out.pop("cert_url")
    return out
def get_token_creator(
    service_key: Optional[str] = None,
    profile: Optional[str] = None,
    *,
    timeout: float = 30.0,
    expiry_buffer_minutes: int = 60,
    **overrides,
) -> Tuple[Callable[[], str], str, str]:
    """
    Creates a callable that fetches and caches an OAuth2 bearer token
    using credentials from `fetch_credentials()`.
    The callable:
    - Automatically loads credentials via fetch_credentials(profile, **overrides)
    - Fetches a new token only if expired or near expiry
    - Caches token thread-safely with a configurable refresh buffer
    Args:
        profile: Optional AICore profile name
        timeout: HTTP request timeout in seconds (default 30s)
        expiry_buffer_minutes: Refresh the token this many minutes before expiry
        overrides: Any explicit credential overrides (client_id, client_secret, etc.)
    Returns:
        Tuple of (token factory returning "Bearer <token>", base_url, resource_group).
    Raises:
        ValueError: on missing auth_url/client_id, or when not exactly one of the
            three auth modes (secret / inline PEM / cert files) is configured.
    """
    # Resolve credentials using your helper
    credentials: Dict[str, str] = fetch_credentials(
        service_key=service_key, profile=profile, **overrides
    )
    auth_url = credentials.get("auth_url")
    client_id = credentials.get("client_id")
    client_secret = credentials.get("client_secret")
    cert_str = credentials.get("cert_str")
    key_str = credentials.get("key_str")
    cert_file_path = credentials.get("cert_file_path")
    key_file_path = credentials.get("key_file_path")
    # Sanity check
    if not auth_url or not client_id:
        raise ValueError(
            "fetch_credentials did not return valid 'auth_url' or 'client_id'"
        )
    # Exactly one auth mode must be configured.
    modes = [
        client_secret is not None,
        (cert_str is not None and key_str is not None),
        (cert_file_path is not None and key_file_path is not None),
    ]
    if sum(bool(m) for m in modes) != 1:
        raise ValueError(
            "Invalid credentials: provide exactly one of client_secret, "
            "(cert_str & key_str), or (cert_file_path & key_file_path)."
        )
    # Closure state guarded by `lock` for thread-safe caching.
    lock = Lock()
    token: Optional[str] = None
    token_expiry: Optional[datetime] = None

    def _request_token(cert_pair=None) -> tuple[str, datetime]:
        # POST the client-credentials grant and compute the expiry instant.
        # NOTE(review): `cert_pair` and the outer `timeout` are currently NOT
        # passed to the client — _get_httpx_client() is used as-is, so the mTLS
        # flows send no client certificate. The commented line below shows the
        # intended behavior; confirm how to thread cert/timeout through.
        data = {"grant_type": "client_credentials", "client_id": client_id}
        if client_secret:
            data["client_secret"] = client_secret
        client = _get_httpx_client()
        # with httpx.Client(cert=cert_pair, timeout=timeout) as client:
        resp = client.post(auth_url, data=data)
        try:
            resp.raise_for_status()
            payload = resp.json()
            access_token = payload["access_token"]
            expires_in = int(payload.get("expires_in", 3600))
            expiry_date = datetime.now(timezone.utc) + timedelta(seconds=expires_in)
            return f"Bearer {access_token}", expiry_date
        except Exception as e:
            msg = getattr(resp, "text", str(e))
            raise RuntimeError(f"Token request failed: {msg}") from e

    def _fetch_token() -> tuple[str, datetime]:
        # Case 1: secret-based auth
        if client_secret:
            return _request_token()
        # Case 2: cert/key strings — materialize PEMs into a throwaway temp dir
        # that lives only for the duration of the request.
        if cert_str and key_str:
            cert_str_fixed = cert_str.replace("\\n", "\n")
            key_str_fixed = key_str.replace("\\n", "\n")
            with tempfile.TemporaryDirectory() as tmp:
                cert_path = os.path.join(tmp, "cert.pem")
                key_path = os.path.join(tmp, "key.pem")
                with open(cert_path, "w") as f:
                    f.write(cert_str_fixed)
                with open(key_path, "w") as f:
                    f.write(key_str_fixed)
                return _request_token(cert_pair=(cert_path, key_path))
        # Case 3: file-based cert/key
        return _request_token(cert_pair=(cert_file_path, key_file_path))

    def get_token() -> str:
        # Serve the cached token unless missing or within the refresh buffer.
        nonlocal token, token_expiry
        with lock:
            now = datetime.now(timezone.utc)
            if (
                token is None
                or token_expiry is None
                or token_expiry - now < timedelta(minutes=expiry_buffer_minutes)
            ):
                token, token_expiry = _fetch_token()
            return token

    return get_token, credentials["base_url"], credentials["resource_group"]

View File

@@ -0,0 +1,177 @@
"""
Translates from OpenAI's `/v1/embeddings` to SAP Generative AI Hub's `/v2/embeddings` route.
"""
from typing import Optional, List, Dict, Literal, Union
from pydantic import BaseModel, Field
from functools import cached_property
import httpx
from litellm.llms.base_llm.embedding.transformation import (
BaseEmbeddingConfig,
LiteLLMLoggingObj,
)
from litellm.types.llms.openai import AllEmbeddingInputValues
from litellm.types.utils import EmbeddingResponse
from ..chat.handler import GenAIHubOrchestrationError
from ..credentials import get_token_creator
class Usage(BaseModel):
    """Token accounting reported by the embeddings endpoint."""

    prompt_tokens: int
    total_tokens: int
class EmbeddingItem(BaseModel):
    """One embedding vector in the response, positioned by input index."""

    object: Literal["embedding"]
    embedding: List[float] = Field(
        ..., description="Vector of floats (length varies by model)."
    )
    index: int
class FinalResult(BaseModel):
    """OpenAI-shaped embeddings payload nested under the provider envelope."""

    object: Literal["list"]
    data: List[EmbeddingItem]
    model: str
    usage: Usage
class EmbeddingsResponse(BaseModel):
    """Top-level provider envelope: request id plus the final result."""

    request_id: str
    final_result: FinalResult
class EmbeddingModel(BaseModel):
    """Embedding model selector for the request config."""

    name: str
    version: str = "latest"
    # Accepts the incoming key "parameters" while serializing as "params".
    params: dict = Field(default_factory=dict, validation_alias="parameters")
class EmbeddingsModules(BaseModel):
    """Modules section of the embeddings request config."""

    embeddings: EmbeddingModel
class EmbeddingInput(BaseModel):
    """Input payload: one string or a batch of strings, with an input kind tag."""

    text: Union[str, List[str]]
    type: Literal["text", "document", "query"] = "text"
class EmbeddingRequest(BaseModel):
    """Complete embeddings request body: config modules plus input."""

    config: EmbeddingsModules
    input: EmbeddingInput
def validate_dict(data: dict, model) -> dict:
    """Round-trip *data* through pydantic *model* to validate it, returning a plain dict."""
    instance = model(**data)
    return instance.model_dump()
class GenAIHubEmbeddingConfig(BaseEmbeddingConfig):
    """Translate OpenAI `/v1/embeddings` calls to the SAP GenAI Hub embeddings endpoint."""

    def __init__(self):
        super().__init__()
        self._access_token_data = {}
        # NOTE(review): credentials are resolved eagerly at construction time
        # (the chat config resolves lazily) — confirm this coupling is intended.
        self.token_creator, self.base_url, self.resource_group = get_token_creator()

    @property
    def headers(self) -> Dict:
        """Auth + routing headers; mints a (cached) bearer token on each access."""
        access_token = self.token_creator()
        # headers for completions and embeddings requests
        headers = {
            "Authorization": access_token,
            "AI-Resource-Group": self.resource_group,
            "Content-Type": "application/json",
            "AI-Client-Type": "LiteLLM",
        }
        return headers

    @cached_property
    def deployment_url(self) -> str:
        """Discover the newest orchestration deployment URL (cached per instance).

        NOTE(review): filters on scenarioId/executableId == "orchestration" even
        though this config serves embeddings — confirm this is the right scenario.
        Raises IndexError when no matching deployment exists.
        """
        with httpx.Client(timeout=30) as client:
            valid_deployments = []
            deployments = client.get(
                self.base_url + "/lm/deployments", headers=self.headers
            ).json()
            for deployment in deployments.get("resources", []):
                if deployment["scenarioId"] == "orchestration":
                    # Double-check the configuration's executable to filter lookalikes.
                    config_details = client.get(
                        self.base_url
                        + f'/lm/configurations/{deployment["configurationId"]}',
                        headers=self.headers,
                    ).json()
                    if config_details["executableId"] == "orchestration":
                        valid_deployments.append(
                            (deployment["deploymentUrl"], deployment["createdAt"])
                        )
            # newest first
            return sorted(valid_deployments, key=lambda x: x[1], reverse=True)[0][0]

    def get_error_class(self, error_message, status_code, headers):
        # Map transport errors onto the shared SAP GenAI Hub exception type.
        return GenAIHubOrchestrationError(status_code, error_message)

    def get_supported_openai_params(self, model: str) -> list:
        # Only text-embedding-3* models accept the `dimensions` parameter.
        if "text-embedding-3" in model:
            return ["encoding_format", "dimensions"]
        else:
            return [
                "encoding_format",
            ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        # Pass-through: the provider accepts litellm's optional params unchanged.
        return optional_params

    def validate_environment(self, headers: dict, *args, **kwargs) -> dict:
        # Caller-provided headers are ignored; always use freshly minted auth headers.
        return self.headers

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        # Route to the discovered deployment; the caller's api_base is ignored.
        url = self.deployment_url.rstrip("/") + "/v2/embeddings"
        return url

    def transform_embedding_request(
        self,
        model: str,
        input: AllEmbeddingInputValues,
        optional_params: dict,
        headers: dict,
    ) -> dict:
        """Build the SAP embeddings request body (validated via pydantic models)."""
        model_dict = {}
        model_dict["name"] = model
        model_dict["version"] = optional_params.get("version", "latest")
        model_dict["params"] = optional_params.get("parameters", {})
        input_dict = {"text": input}
        body = {
            "config": {
                "modules": {
                    "embeddings": {"model": validate_dict(model_dict, EmbeddingModel)}
                }
            },
            "input": validate_dict(input_dict, EmbeddingInput),
        }
        return body

    def transform_embedding_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: EmbeddingResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str],
        request_data: dict,
        optional_params: dict,
        litellm_params: dict,
    ) -> EmbeddingResponse:
        # The provider wraps the OpenAI-shaped payload under "final_result".
        return EmbeddingResponse.model_validate(raw_response.json()["final_result"])