chore: initial public snapshot for GitHub upload
This commit is contained in:
@@ -0,0 +1,23 @@
|
||||
"""
|
||||
GigaChat Provider for LiteLLM
|
||||
|
||||
GigaChat is Sber AI's large language model (Russia's leading LLM).
|
||||
Supports:
|
||||
- Chat completions (sync/async)
|
||||
- Streaming (sync/async)
|
||||
- Function calling / Tools
|
||||
- Structured output via JSON schema (emulated through function calls)
|
||||
- Image input (base64 and URL)
|
||||
- Embeddings
|
||||
|
||||
API Documentation: https://developers.sber.ru/docs/ru/gigachat/api/overview
|
||||
"""
|
||||
|
||||
from .chat.transformation import GigaChatConfig, GigaChatError
|
||||
from .embedding.transformation import GigaChatEmbeddingConfig
|
||||
|
||||
__all__ = [
|
||||
"GigaChatConfig",
|
||||
"GigaChatEmbeddingConfig",
|
||||
"GigaChatError",
|
||||
]
|
||||
@@ -0,0 +1,245 @@
|
||||
"""
|
||||
GigaChat OAuth Authenticator
|
||||
|
||||
Handles OAuth 2.0 token management for GigaChat API.
|
||||
Based on official GigaChat SDK authentication flow.
|
||||
"""
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.caching.caching import InMemoryCache
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
HTTPHandler,
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.utils import LlmProviders
|
||||
|
||||
# GigaChat OAuth endpoint
|
||||
GIGACHAT_AUTH_URL = "https://ngw.devices.sberbank.ru:9443/api/v2/oauth"
|
||||
|
||||
# Default scope for personal API access
|
||||
GIGACHAT_SCOPE = "GIGACHAT_API_PERS"
|
||||
|
||||
# Token expiry buffer in milliseconds (refresh token 60s before expiry)
|
||||
TOKEN_EXPIRY_BUFFER_MS = 60000
|
||||
|
||||
# Cache for access tokens
|
||||
_token_cache = InMemoryCache()
|
||||
|
||||
|
||||
class GigaChatAuthError(BaseLLMException):
    """Raised when GigaChat OAuth authentication fails (missing or rejected credentials)."""
|
||||
|
||||
|
||||
def _get_credentials() -> Optional[str]:
    """Read GigaChat credentials from the environment.

    Checks GIGACHAT_CREDENTIALS first and falls back to GIGACHAT_API_KEY.
    Returns None when neither is set.
    """
    primary = get_secret_str("GIGACHAT_CREDENTIALS")
    if primary:
        return primary
    return get_secret_str("GIGACHAT_API_KEY")
|
||||
|
||||
|
||||
def _get_auth_url() -> str:
    """Return the OAuth endpoint URL, honoring a GIGACHAT_AUTH_URL override."""
    override = get_secret_str("GIGACHAT_AUTH_URL")
    if override:
        return override
    return GIGACHAT_AUTH_URL
|
||||
|
||||
|
||||
def _get_scope() -> str:
    """Return the API scope, honoring a GIGACHAT_SCOPE override."""
    override = get_secret_str("GIGACHAT_SCOPE")
    if override:
        return override
    return GIGACHAT_SCOPE
|
||||
|
||||
|
||||
def _get_http_client() -> HTTPHandler:
    """Return the cached sync httpx client.

    SSL verification is disabled here — NOTE(review): presumably because the
    GigaChat endpoints use a CA absent from default trust stores; confirm
    whether pinning the proper CA bundle is feasible instead.
    """
    client_params = {"ssl_verify": False}
    return _get_httpx_client(params=client_params)
|
||||
|
||||
|
||||
def get_access_token(
    credentials: Optional[str] = None,
    scope: Optional[str] = None,
    auth_url: Optional[str] = None,
) -> str:
    """
    Get a valid access token, using the in-memory cache when possible.

    Args:
        credentials: Base64-encoded credentials (client_id:client_secret)
        scope: API scope (GIGACHAT_API_PERS, GIGACHAT_API_CORP, etc.)
        auth_url: OAuth endpoint URL

    Returns:
        Access token string

    Raises:
        GigaChatAuthError: If credentials are missing or authentication fails
    """
    import hashlib

    credentials = credentials or _get_credentials()
    if not credentials:
        raise GigaChatAuthError(
            status_code=401,
            message="GigaChat credentials not provided. Set GIGACHAT_CREDENTIALS or GIGACHAT_API_KEY environment variable.",
        )

    scope = scope or _get_scope()
    auth_url = auth_url or _get_auth_url()

    # Key the cache on a digest of the *full* credentials. The previous
    # credentials[:16] prefix could collide for distinct credentials that
    # share a prefix, handing one caller another caller's token.
    cache_key = f"gigachat_token:{hashlib.sha256(credentials.encode()).hexdigest()}"
    cached = _token_cache.get_cache(cache_key)
    if cached:
        token, expires_at = cached
        # Treat the token as expired TOKEN_EXPIRY_BUFFER_MS early so we never
        # hand out a token that is about to lapse mid-request.
        if time.time() * 1000 < expires_at - TOKEN_EXPIRY_BUFFER_MS:
            verbose_logger.debug("Using cached GigaChat access token")
            return token

    # Cache miss or stale entry: request a fresh token.
    token, expires_at = _request_token_sync(credentials, scope, auth_url)

    # Cache with a TTL that runs out slightly before the token itself expires.
    ttl_seconds = max(
        0, (expires_at - TOKEN_EXPIRY_BUFFER_MS - time.time() * 1000) / 1000
    )
    if ttl_seconds > 0:
        _token_cache.set_cache(cache_key, (token, expires_at), ttl=ttl_seconds)

    return token
|
||||
|
||||
|
||||
async def get_access_token_async(
    credentials: Optional[str] = None,
    scope: Optional[str] = None,
    auth_url: Optional[str] = None,
) -> str:
    """Async version of get_access_token.

    Args:
        credentials: Base64-encoded credentials (client_id:client_secret)
        scope: API scope (GIGACHAT_API_PERS, GIGACHAT_API_CORP, etc.)
        auth_url: OAuth endpoint URL

    Returns:
        Access token string

    Raises:
        GigaChatAuthError: If credentials are missing or authentication fails
    """
    import hashlib

    credentials = credentials or _get_credentials()
    if not credentials:
        raise GigaChatAuthError(
            status_code=401,
            message="GigaChat credentials not provided. Set GIGACHAT_CREDENTIALS or GIGACHAT_API_KEY environment variable.",
        )

    scope = scope or _get_scope()
    auth_url = auth_url or _get_auth_url()

    # Key the cache on a digest of the *full* credentials. The previous
    # credentials[:16] prefix could collide for distinct credentials that
    # share a prefix, handing one caller another caller's token.
    cache_key = f"gigachat_token:{hashlib.sha256(credentials.encode()).hexdigest()}"
    cached = _token_cache.get_cache(cache_key)
    if cached:
        token, expires_at = cached
        # Refresh TOKEN_EXPIRY_BUFFER_MS before actual expiry.
        if time.time() * 1000 < expires_at - TOKEN_EXPIRY_BUFFER_MS:
            verbose_logger.debug("Using cached GigaChat access token")
            return token

    # Cache miss or stale entry: request a fresh token.
    token, expires_at = await _request_token_async(credentials, scope, auth_url)

    # Cache with a TTL that runs out slightly before the token itself expires.
    ttl_seconds = max(
        0, (expires_at - TOKEN_EXPIRY_BUFFER_MS - time.time() * 1000) / 1000
    )
    if ttl_seconds > 0:
        _token_cache.set_cache(cache_key, (token, expires_at), ttl=ttl_seconds)

    return token
|
||||
|
||||
|
||||
def _request_token_sync(
    credentials: str,
    scope: str,
    auth_url: str,
) -> Tuple[str, int]:
    """
    Request a new access token from the GigaChat OAuth endpoint (sync).

    Returns:
        Tuple of (access_token, expires_at_ms)

    Raises:
        GigaChatAuthError: On an HTTP error status or a transport failure.
    """
    request_headers = {
        "Authorization": f"Basic {credentials}",
        "RqUID": str(uuid.uuid4()),
        "Content-Type": "application/x-www-form-urlencoded",
    }
    form_body = {"scope": scope}

    verbose_logger.debug(f"Requesting GigaChat access token from {auth_url}")

    http_client = _get_http_client()
    try:
        response = http_client.post(
            auth_url, headers=request_headers, data=form_body, timeout=30
        )
        response.raise_for_status()
        return _parse_token_response(response)
    except httpx.HTTPStatusError as e:
        raise GigaChatAuthError(
            status_code=e.response.status_code,
            message=f"GigaChat authentication failed: {e.response.text}",
        )
    except httpx.RequestError as e:
        raise GigaChatAuthError(
            status_code=500,
            message=f"GigaChat authentication request failed: {str(e)}",
        )
|
||||
|
||||
|
||||
async def _request_token_async(
    credentials: str,
    scope: str,
    auth_url: str,
) -> Tuple[str, int]:
    """Async version of _request_token_sync.

    Returns:
        Tuple of (access_token, expires_at_ms)

    Raises:
        GigaChatAuthError: On an HTTP error status or a transport failure.
    """
    request_headers = {
        "Authorization": f"Basic {credentials}",
        "RqUID": str(uuid.uuid4()),
        "Content-Type": "application/x-www-form-urlencoded",
    }
    form_body = {"scope": scope}

    verbose_logger.debug(f"Requesting GigaChat access token from {auth_url}")

    try:
        async_client = get_async_httpx_client(
            llm_provider=LlmProviders.GIGACHAT,
            params={"ssl_verify": False},
        )
        response = await async_client.post(
            auth_url, headers=request_headers, data=form_body, timeout=30
        )
        response.raise_for_status()
        return _parse_token_response(response)
    except httpx.HTTPStatusError as e:
        raise GigaChatAuthError(
            status_code=e.response.status_code,
            message=f"GigaChat authentication failed: {e.response.text}",
        )
    except httpx.RequestError as e:
        raise GigaChatAuthError(
            status_code=500,
            message=f"GigaChat authentication request failed: {str(e)}",
        )
|
||||
|
||||
|
||||
def _parse_token_response(response: httpx.Response) -> Tuple[str, int]:
    """Parse an OAuth token response into (access_token, expires_at_ms).

    The OAuth endpoint may return either the short form ('tok'/'exp') or the
    long form ('access_token'/'expires_at'); both are accepted.

    Raises:
        GigaChatAuthError: If the payload lacks a token or an expiry.
    """
    data = response.json()

    # GigaChat returns either 'tok'/'exp' or 'access_token'/'expires_at'
    access_token = data.get("tok") or data.get("access_token")
    expires_at = data.get("exp") or data.get("expires_at")

    if not access_token or expires_at is None:
        # Previously a missing expiry slipped through as None and crashed
        # later during the callers' TTL arithmetic; fail fast instead.
        raise GigaChatAuthError(
            status_code=500,
            message=f"Invalid token response: {data}",
        )

    # expires_at is in milliseconds; some responses serialize it as a string.
    if isinstance(expires_at, str):
        expires_at = int(expires_at)

    verbose_logger.debug("GigaChat access token obtained successfully")
    return access_token, expires_at
|
||||
@@ -0,0 +1,12 @@
|
||||
"""
|
||||
GigaChat Chat Module
|
||||
"""
|
||||
|
||||
from .transformation import GigaChatConfig, GigaChatError
|
||||
from .streaming import GigaChatModelResponseIterator
|
||||
|
||||
__all__ = [
|
||||
"GigaChatConfig",
|
||||
"GigaChatError",
|
||||
"GigaChatModelResponseIterator",
|
||||
]
|
||||
@@ -0,0 +1,137 @@
|
||||
"""
|
||||
GigaChat Streaming Response Handler
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, Optional
|
||||
|
||||
from litellm.types.llms.openai import (
|
||||
ChatCompletionToolCallChunk,
|
||||
ChatCompletionToolCallFunctionChunk,
|
||||
)
|
||||
from litellm.types.utils import GenericStreamingChunk
|
||||
|
||||
|
||||
class GigaChatModelResponseIterator:
    """Iterator over GigaChat streaming (SSE) responses.

    Wraps a sync or async iterator of raw SSE frames (strings of the form
    ``data: {...}`` or already-parsed dicts) and yields OpenAI-compatible
    GenericStreamingChunk objects. The previously duplicated SSE decoding in
    __next__/__anext__ is factored into _decode_sse.
    """

    # Sentinel returned by _decode_sse for the '[DONE]' terminator frame.
    _DONE = object()

    def __init__(
        self,
        streaming_response: Any,
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ):
        # sync_stream is accepted for interface compatibility; iteration mode
        # is decided by which protocol (__next__/__anext__) the caller uses.
        self.streaming_response = streaming_response
        self.response_iterator = self.streaming_response
        self.json_mode = json_mode

    def _empty_chunk(self) -> GenericStreamingChunk:
        """Return a no-op chunk for contentless or unparseable frames."""
        return GenericStreamingChunk(
            text="",
            tool_use=None,
            is_finished=False,
            finish_reason="",
            usage=None,
            index=0,
        )

    def _decode_sse(self, chunk: Any) -> Any:
        """Decode one raw frame.

        Returns the parsed dict, ``self._DONE`` for the '[DONE]' terminator,
        or None when a string frame is not valid JSON (caller then yields a
        no-op chunk). Non-string frames are passed through unchanged.
        """
        if not isinstance(chunk, str):
            return chunk
        # Strip the SSE "data: " prefix when present.
        if chunk.startswith("data: "):
            chunk = chunk[6:]
        if chunk.strip() == "[DONE]":
            return self._DONE
        try:
            return json.loads(chunk)
        except json.JSONDecodeError:
            return None

    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
        """Convert one parsed GigaChat chunk into a GenericStreamingChunk.

        GigaChat signals a function call with finish_reason == "function_call"
        plus a delta.function_call payload; this is remapped to the OpenAI
        "tool_calls" convention.
        """
        choices = chunk.get("choices", [])
        if not choices:
            return self._empty_chunk()

        choice = choices[0]
        delta = choice.get("delta", {})
        finish_reason = choice.get("finish_reason")

        # Extract text content (content may be present but None).
        text = delta.get("content", "") or ""
        tool_use: Optional[ChatCompletionToolCallChunk] = None

        if finish_reason == "function_call" and delta.get("function_call"):
            func_call = delta["function_call"]
            args = func_call.get("arguments", {})
            # GigaChat emits arguments as a dict; OpenAI expects a JSON string.
            if isinstance(args, dict):
                args = json.dumps(args, ensure_ascii=False)
            tool_use = ChatCompletionToolCallChunk(
                id=f"call_{uuid.uuid4().hex[:24]}",
                type="function",
                function=ChatCompletionToolCallFunctionChunk(
                    name=func_call.get("name", ""),
                    arguments=args,
                ),
                index=0,
            )
            finish_reason = "tool_calls"

        return GenericStreamingChunk(
            text=text,
            tool_use=tool_use,
            is_finished=finish_reason is not None,
            finish_reason=finish_reason or "",
            usage=None,
            index=choice.get("index", 0),
        )

    def __iter__(self):
        return self

    def __next__(self) -> GenericStreamingChunk:
        parsed = self._decode_sse(self.response_iterator.__next__())
        if parsed is self._DONE:
            raise StopIteration
        if parsed is None:
            # Unparseable frame: emit a harmless empty chunk.
            return self._empty_chunk()
        return self.chunk_parser(parsed)

    def __aiter__(self):
        return self

    async def __anext__(self) -> GenericStreamingChunk:
        parsed = self._decode_sse(await self.response_iterator.__anext__())
        if parsed is self._DONE:
            raise StopAsyncIteration
        if parsed is None:
            return self._empty_chunk()
        return self.chunk_parser(parsed)
|
||||
@@ -0,0 +1,510 @@
|
||||
"""
|
||||
GigaChat Chat Transformation
|
||||
|
||||
Transforms OpenAI-format requests to GigaChat format and back.
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import Choices, Message, ModelResponse, Usage
|
||||
|
||||
from ..authenticator import get_access_token
|
||||
from ..file_handler import upload_file_sync
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
# GigaChat API endpoint
|
||||
GIGACHAT_BASE_URL = "https://gigachat.devices.sberbank.ru/api/v1"
|
||||
|
||||
|
||||
def is_valid_json(value: str) -> bool:
    """Return True if *value* is a syntactically valid serialized JSON string."""
    try:
        json.loads(value)
        return True
    except json.JSONDecodeError:
        return False
|
||||
|
||||
|
||||
class GigaChatError(BaseLLMException):
    """Raised when the GigaChat chat API returns an error response."""
|
||||
|
||||
|
||||
class GigaChatConfig(BaseConfig):
    """
    Configuration class for GigaChat API.

    GigaChat is Sber's (Russia's largest bank) LLM API.

    Handles OpenAI -> GigaChat request transformation (roles, tools ->
    functions, multimodal content -> attachments) and the reverse response
    transformation (function_call -> tool_calls / structured-output content).

    Supported parameters:
        temperature: Sampling temperature (0-2, default 0.87)
        top_p: Nucleus sampling parameter
        max_tokens: Maximum tokens to generate
        repetition_penalty: Repetition penalty factor
        profanity_check: Enable content filtering
        stream: Enable streaming
    """

    # Class-level defaults; __init__ overwrites these on the *class* (litellm
    # config convention), so configured values persist across instances.
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    max_tokens: Optional[int] = None
    repetition_penalty: Optional[float] = None
    profanity_check: Optional[bool] = None

    def __init__(
        self,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        max_tokens: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        profanity_check: Optional[bool] = None,
    ) -> None:
        # NOTE: non-None arguments are stored on the class via setattr, not
        # the instance — this mirrors litellm's config pattern and means the
        # values are shared by every future instance.
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)
        # Instance variables for current request context (set by
        # validate_environment, read by _upload_image).
        self._current_credentials: Optional[str] = None
        self._current_api_base: Optional[str] = None

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """Get complete API URL for chat completions.

        Resolution order: explicit api_base -> GIGACHAT_API_BASE env var ->
        built-in default.
        """
        base = api_base or get_secret_str("GIGACHAT_API_BASE") or GIGACHAT_BASE_URL
        return f"{base}/chat/completions"

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """
        Set up headers with OAuth token.

        Exchanges the (base64) credentials for a Bearer access token and
        stashes the credentials/api_base on the instance for later image
        uploads. Raises GigaChatAuthError (via get_access_token) when no
        credentials are available.
        """
        # Get access token: api_key takes priority over env vars.
        credentials = (
            api_key
            or get_secret_str("GIGACHAT_CREDENTIALS")
            or get_secret_str("GIGACHAT_API_KEY")
        )
        access_token = get_access_token(credentials=credentials)

        # Store credentials for image uploads
        self._current_credentials = credentials
        self._current_api_base = api_base

        headers["Authorization"] = f"Bearer {access_token}"
        headers["Content-Type"] = "application/json"
        headers["Accept"] = "application/json"

        return headers

    def get_supported_openai_params(self, model: str) -> List[str]:
        """Return list of supported OpenAI parameters."""
        return [
            "stream",
            "temperature",
            "top_p",
            "max_tokens",
            "max_completion_tokens",
            "stop",
            "tools",
            "tool_choice",
            "functions",
            "function_call",
            "response_format",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Map OpenAI parameters to GigaChat parameters.

        Notable translations: tools -> functions, tool_choice ->
        function_call, and response_format (json_schema) is emulated by
        forcing a synthetic function call (flagged via the internal
        "_structured_output" key, consumed by transform_response).
        """
        for param, value in non_default_params.items():
            if param == "stream":
                optional_params["stream"] = value
            elif param == "temperature":
                # GigaChat: temperature 0 means use top_p=0 instead
                if value == 0:
                    optional_params["top_p"] = 0
                else:
                    optional_params["temperature"] = value
            elif param == "top_p":
                optional_params["top_p"] = value
            elif param in ("max_tokens", "max_completion_tokens"):
                optional_params["max_tokens"] = value
            elif param == "stop":
                # GigaChat doesn't support stop sequences; silently dropped.
                pass
            elif param == "tools":
                # Convert tools to functions format
                optional_params["functions"] = self._convert_tools_to_functions(value)
            elif param == "tool_choice":
                # Map OpenAI tool_choice to GigaChat function_call
                mapped_choice = self._map_tool_choice(value)
                if mapped_choice is not None:
                    optional_params["function_call"] = mapped_choice
            elif param == "functions":
                optional_params["functions"] = value
            elif param == "function_call":
                optional_params["function_call"] = value
            elif param == "response_format":
                # Handle structured output via function calling: register the
                # schema as a function and force the model to call it.
                if value.get("type") == "json_schema":
                    json_schema = value.get("json_schema", {})
                    schema_name = json_schema.get("name", "structured_output")
                    schema = json_schema.get("schema", {})

                    function_def = {
                        "name": schema_name,
                        "description": f"Output structured response: {schema_name}",
                        "parameters": schema,
                    }

                    if "functions" not in optional_params:
                        optional_params["functions"] = []
                    optional_params["functions"].append(function_def)
                    optional_params["function_call"] = {"name": schema_name}
                    # Internal marker so transform_response converts the
                    # resulting function_call back into plain content.
                    optional_params["_structured_output"] = True

        return optional_params

    def _convert_tools_to_functions(self, tools: List[dict]) -> List[dict]:
        """Convert OpenAI tools format to GigaChat functions format.

        Non-"function" tool types are silently skipped.
        """
        functions = []
        for tool in tools:
            if tool.get("type") == "function":
                func = tool.get("function", {})
                functions.append(
                    {
                        "name": func.get("name", ""),
                        "description": func.get("description", ""),
                        "parameters": func.get("parameters", {}),
                    }
                )
        return functions

    def _map_tool_choice(
        self, tool_choice: Union[str, dict]
    ) -> Optional[Union[str, dict]]:
        """
        Map OpenAI tool_choice to GigaChat function_call format.

        OpenAI format:
            - "auto": Call zero, one, or multiple functions (default)
            - "required": Call one or more functions
            - "none": Don't call any functions
            - {"type": "function", "function": {"name": "get_weather"}}: Force specific function

        GigaChat format:
            - "none": Disable function calls
            - "auto": Automatic mode (default)
            - {"name": "get_weather"}: Force specific function

        Args:
            tool_choice: OpenAI tool_choice value

        Returns:
            GigaChat function_call value or None
        """
        if tool_choice == "none":
            return "none"
        elif tool_choice == "auto":
            return "auto"
        elif tool_choice == "required":
            # GigaChat doesn't have a direct "required" equivalent
            # Use "auto" as the closest behavior
            return "auto"
        elif isinstance(tool_choice, dict):
            # OpenAI format: {"type": "function", "function": {"name": "func_name"}}
            # GigaChat format: {"name": "func_name"}
            if tool_choice.get("type") == "function":
                func_name = tool_choice.get("function", {}).get("name")
                if func_name:
                    return {"name": func_name}

        # Default to None (don't set function_call)
        return None

    def _upload_image(self, image_url: str) -> Optional[str]:
        """
        Upload image to GigaChat and return file_id.

        Uses the credentials captured by validate_environment. Upload
        failures are logged and swallowed (best-effort: the message is sent
        without the attachment).

        Args:
            image_url: URL or base64 data URL of the image

        Returns:
            file_id string or None if upload failed
        """
        try:
            return upload_file_sync(
                image_url=image_url,
                credentials=self._current_credentials,
                api_base=self._current_api_base,
            )
        except Exception as e:
            verbose_logger.error(f"Failed to upload image: {e}")
            return None

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """Transform OpenAI request to GigaChat format.

        Strips the "gigachat/" routing prefix from the model name, converts
        messages, and forwards the whitelisted generation parameters plus any
        functions/function_call entries.
        """
        # Transform messages
        giga_messages = self._transform_messages(messages)

        # Build request
        request_data = {
            "model": model.replace("gigachat/", ""),
            "messages": giga_messages,
        }

        # Add optional params (only keys GigaChat understands)
        for key in [
            "temperature",
            "top_p",
            "max_tokens",
            "stream",
            "repetition_penalty",
            "profanity_check",
        ]:
            if key in optional_params:
                request_data[key] = optional_params[key]

        # Add functions if present
        if "functions" in optional_params:
            request_data["functions"] = optional_params["functions"]
        if "function_call" in optional_params:
            request_data["function_call"] = optional_params["function_call"]

        return request_data

    def _transform_messages(self, messages: List[AllMessageValues]) -> List[dict]:
        """Transform OpenAI messages to GigaChat format.

        Per-message steps, in order: drop unsupported fields, remap roles
        (developer -> system, non-leading system -> user, tool -> function
        with JSON-encoded content), default None content to "", flatten
        multimodal list content into text + uploaded attachments, and convert
        the first tool_call into GigaChat's function_call shape.
        """
        transformed = []

        for i, msg in enumerate(messages):
            # Shallow copy so the caller's message objects are not mutated.
            message = dict(msg)

            # Remove unsupported fields
            message.pop("name", None)

            # Transform roles
            role = message.get("role", "user")
            if role == "developer":
                message["role"] = "system"
            elif role == "system" and i > 0:
                # GigaChat only allows system message as first message
                message["role"] = "user"
            elif role == "tool":
                message["role"] = "function"
                content = message.get("content", "")
                # Function results must be a JSON string; re-encode anything
                # that isn't already valid serialized JSON.
                if not isinstance(content, str) or not is_valid_json(content):
                    message["content"] = json.dumps(content, ensure_ascii=False)

            # Handle None content
            if message.get("content") is None:
                message["content"] = ""

            # Handle list content (multimodal) - extract text and images
            content = message.get("content")
            if isinstance(content, list):
                texts = []
                attachments = []
                for part in content:
                    if isinstance(part, dict):
                        if part.get("type") == "text":
                            texts.append(part.get("text", ""))
                        elif part.get("type") == "image_url":
                            # Extract image URL and upload to GigaChat;
                            # failed uploads are dropped (see _upload_image).
                            image_url = part.get("image_url", {})
                            if isinstance(image_url, str):
                                url = image_url
                            else:
                                url = image_url.get("url", "")
                            if url:
                                file_id = self._upload_image(url)
                                if file_id:
                                    attachments.append(file_id)
                message["content"] = "\n".join(texts) if texts else ""
                if attachments:
                    message["attachments"] = attachments

            # Transform tool_calls to function_call. NOTE: only the first
            # tool call is forwarded; GigaChat's function_call is singular.
            tool_calls = message.get("tool_calls")
            if tool_calls and isinstance(tool_calls, list) and len(tool_calls) > 0:
                tool_call = tool_calls[0]
                func = tool_call.get("function", {})
                args = func.get("arguments", "{}")
                # GigaChat expects arguments as a dict, not a JSON string.
                if isinstance(args, str):
                    try:
                        args = json.loads(args)
                    except json.JSONDecodeError:
                        args = {}
                message["function_call"] = {
                    "name": func.get("name", ""),
                    "arguments": args,
                }
                message.pop("tool_calls", None)

            transformed.append(message)

        return transformed

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """Transform GigaChat response to OpenAI format.

        Converts function_call results either into plain content (when the
        request was a "_structured_output" emulation) or into OpenAI
        tool_calls, then fills the ModelResponse in place.

        Raises:
            GigaChatError: If the response body is not valid JSON.
        """
        try:
            response_json = raw_response.json()
        except Exception:
            raise GigaChatError(
                status_code=raw_response.status_code,
                message=f"Invalid JSON response: {raw_response.text}",
            )

        # Set by map_openai_params when response_format was emulated.
        is_structured_output = optional_params.get("_structured_output", False)

        choices = []
        for choice in response_json.get("choices", []):
            message_data = choice.get("message", {})
            finish_reason = choice.get("finish_reason", "stop")

            # Transform function_call to tool_calls or content
            if finish_reason == "function_call" and message_data.get("function_call"):
                func_call = message_data["function_call"]
                args = func_call.get("arguments", {})

                if is_structured_output:
                    # Convert to content for structured output: the forced
                    # function's arguments ARE the structured answer.
                    if isinstance(args, dict):
                        content = json.dumps(args, ensure_ascii=False)
                    else:
                        content = str(args)
                    message_data["content"] = content
                    message_data.pop("function_call", None)
                    message_data.pop("functions_state_id", None)
                    finish_reason = "stop"
                else:
                    # Convert to tool_calls format (OpenAI wants arguments
                    # as a JSON string and a synthetic call id).
                    if isinstance(args, dict):
                        args = json.dumps(args, ensure_ascii=False)
                    message_data["tool_calls"] = [
                        {
                            "id": f"call_{uuid.uuid4().hex[:24]}",
                            "type": "function",
                            "function": {
                                "name": func_call.get("name", ""),
                                "arguments": args,
                            },
                        }
                    ]
                    message_data.pop("function_call", None)
                    finish_reason = "tool_calls"

            # Clean up GigaChat-specific fields
            message_data.pop("functions_state_id", None)

            choices.append(
                Choices(
                    index=choice.get("index", 0),
                    message=Message(
                        role=message_data.get("role", "assistant"),
                        content=message_data.get("content"),
                        tool_calls=message_data.get("tool_calls"),
                    ),
                    finish_reason=finish_reason,
                )
            )

        # Build usage
        usage_data = response_json.get("usage", {})
        usage = Usage(
            prompt_tokens=usage_data.get("prompt_tokens", 0),
            completion_tokens=usage_data.get("completion_tokens", 0),
            total_tokens=usage_data.get("total_tokens", 0),
        )

        model_response.id = response_json.get("id", f"chatcmpl-{uuid.uuid4().hex[:12]}")
        model_response.created = response_json.get("created", int(time.time()))
        model_response.model = model
        model_response.choices = choices  # type: ignore
        setattr(model_response, "usage", usage)

        return model_response

    def get_error_class(
        self,
        error_message: str,
        status_code: int,
        headers: Union[dict, httpx.Headers],
    ) -> BaseLLMException:
        """Return GigaChat error class."""
        return GigaChatError(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )

    def get_model_response_iterator(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ):
        """Return streaming response iterator."""
        # Imported lazily to avoid a circular import at module load time.
        from .streaming import GigaChatModelResponseIterator

        return GigaChatModelResponseIterator(
            streaming_response=streaming_response,
            sync_stream=sync_stream,
            json_mode=json_mode,
        )
|
||||
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
GigaChat Embedding Module
|
||||
"""
|
||||
|
||||
from .transformation import GigaChatEmbeddingConfig
|
||||
|
||||
__all__ = ["GigaChatEmbeddingConfig"]
|
||||
@@ -0,0 +1,212 @@
|
||||
"""
|
||||
GigaChat Embedding Transformation
|
||||
|
||||
Transforms OpenAI /v1/embeddings format to GigaChat format.
|
||||
API Documentation: https://developers.sber.ru/docs/ru/gigachat/api/reference/rest/post-embeddings
|
||||
"""
|
||||
|
||||
import types
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm import LlmProviders
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.types.llms.openai import AllEmbeddingInputValues, AllMessageValues
|
||||
from litellm.types.utils import EmbeddingResponse
|
||||
|
||||
from ..authenticator import get_access_token
|
||||
|
||||
# GigaChat API endpoint
|
||||
GIGACHAT_BASE_URL = "https://gigachat.devices.sberbank.ru/api/v1"
|
||||
|
||||
|
||||
class GigaChatEmbeddingError(BaseLLMException):
    """Raised when the GigaChat Embedding API returns an error response."""
|
||||
|
||||
|
||||
class GigaChatEmbeddingConfig(BaseEmbeddingConfig):
|
||||
"""
|
||||
Configuration class for GigaChat Embeddings API.
|
||||
|
||||
GigaChat embeddings endpoint: POST /api/v1/embeddings
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
    """The embedding config carries no per-instance state."""
|
||||
|
||||
@classmethod
def get_config(cls):
    """Return the class-level configuration values as a plain dict.

    Dunder attributes, callables/descriptors (functions, classmethods,
    staticmethods), and unset (None) values are excluded.
    """
    excluded_types = (
        types.FunctionType,
        types.BuiltinFunctionType,
        classmethod,
        staticmethod,
    )
    config = {}
    for attr_name, attr_value in cls.__dict__.items():
        if attr_name.startswith("__"):
            continue
        if isinstance(attr_value, excluded_types):
            continue
        if attr_value is None:
            continue
        config[attr_name] = attr_value
    return config
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List[str]:
|
||||
"""GigaChat embeddings don't support additional parameters."""
|
||||
return []
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
"""Map OpenAI params to GigaChat format (no special mapping needed)."""
|
||||
return optional_params
|
||||
|
||||
def _get_openai_compatible_provider_info(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
) -> Tuple[str, Optional[str], Optional[str]]:
|
||||
"""
|
||||
Returns provider info for GigaChat.
|
||||
|
||||
Returns:
|
||||
Tuple of (custom_llm_provider, api_base, dynamic_api_key)
|
||||
"""
|
||||
api_base = api_base or GIGACHAT_BASE_URL
|
||||
return LlmProviders.GIGACHAT.value, api_base, api_key
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
stream: Optional[bool] = None,
|
||||
) -> str:
|
||||
"""Get the complete URL for embeddings endpoint."""
|
||||
base = api_base or GIGACHAT_BASE_URL
|
||||
return f"{base}/embeddings"
|
||||
|
||||
def transform_embedding_request(
|
||||
self,
|
||||
model: str,
|
||||
input: AllEmbeddingInputValues,
|
||||
optional_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
"""
|
||||
Transform OpenAI embedding request to GigaChat format.
|
||||
|
||||
GigaChat format:
|
||||
{
|
||||
"model": "Embeddings",
|
||||
"input": ["text1", "text2", ...]
|
||||
}
|
||||
"""
|
||||
# Normalize input to list
|
||||
if isinstance(input, str):
|
||||
input_list: list = [input]
|
||||
elif isinstance(input, list):
|
||||
input_list = input
|
||||
else:
|
||||
input_list = [input]
|
||||
|
||||
# Remove gigachat/ prefix from model if present
|
||||
if model.startswith("gigachat/"):
|
||||
model = model[9:]
|
||||
|
||||
return {
|
||||
"model": model,
|
||||
"input": input_list,
|
||||
}
|
||||
|
||||
def transform_embedding_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: EmbeddingResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
api_key: Optional[str],
|
||||
request_data: dict,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
) -> EmbeddingResponse:
|
||||
"""
|
||||
Transform GigaChat embedding response to OpenAI format.
|
||||
|
||||
GigaChat returns:
|
||||
{
|
||||
"object": "list",
|
||||
"data": [{"object": "embedding", "embedding": [...], "index": 0, "usage": {...}}],
|
||||
"model": "Embeddings"
|
||||
}
|
||||
"""
|
||||
response_json = raw_response.json()
|
||||
|
||||
# Log response
|
||||
logging_obj.post_call(
|
||||
input=request_data.get("input"),
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": request_data},
|
||||
original_response=response_json,
|
||||
)
|
||||
|
||||
# Calculate total tokens from individual embeddings
|
||||
total_tokens = 0
|
||||
if "data" in response_json:
|
||||
for emb in response_json["data"]:
|
||||
if "usage" in emb and "prompt_tokens" in emb["usage"]:
|
||||
total_tokens += emb["usage"]["prompt_tokens"]
|
||||
# Remove usage from individual embeddings (not part of OpenAI format)
|
||||
if "usage" in emb:
|
||||
del emb["usage"]
|
||||
|
||||
# Set overall usage
|
||||
response_json["usage"] = {
|
||||
"prompt_tokens": total_tokens,
|
||||
"total_tokens": total_tokens,
|
||||
}
|
||||
|
||||
return EmbeddingResponse(**response_json)
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Set up headers with OAuth token for GigaChat.
|
||||
"""
|
||||
# Get access token via OAuth
|
||||
access_token = get_access_token(api_key)
|
||||
|
||||
default_headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
}
|
||||
return {**default_headers, **headers}
|
||||
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||
) -> BaseLLMException:
|
||||
"""Return GigaChat-specific error class."""
|
||||
return GigaChatEmbeddingError(
|
||||
status_code=status_code,
|
||||
message=error_message,
|
||||
)
|
||||
@@ -0,0 +1,211 @@
|
||||
"""
|
||||
GigaChat File Handler
|
||||
|
||||
Handles file uploads to GigaChat API for image processing.
|
||||
GigaChat requires files to be uploaded first, then referenced by file_id.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import re
|
||||
import uuid
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.types.utils import LlmProviders
|
||||
|
||||
from .authenticator import get_access_token, get_access_token_async
|
||||
|
||||
# GigaChat API endpoint (default; callers may override via api_base)
GIGACHAT_BASE_URL = "https://gigachat.devices.sberbank.ru/api/v1"

# In-memory cache mapping sha256(image_url) -> uploaded GigaChat file_id,
# so the same image is uploaded only once per process.
# NOTE(review): unbounded and never evicted -- entries may outlive the
# server-side file; confirm this is acceptable for long-lived processes.
_file_cache: Dict[str, str] = {}
|
||||
|
||||
|
||||
def _get_url_hash(url: str) -> str:
|
||||
"""Generate hash for URL to use as cache key."""
|
||||
return hashlib.sha256(url.encode()).hexdigest()
|
||||
|
||||
|
||||
def _parse_data_url(data_url: str) -> Optional[Tuple[bytes, str, str]]:
|
||||
"""
|
||||
Parse data URL (base64 image).
|
||||
|
||||
Returns:
|
||||
Tuple of (content_bytes, content_type, extension) or None
|
||||
"""
|
||||
match = re.match(r"data:([^;]+);base64,(.+)", data_url)
|
||||
if not match:
|
||||
return None
|
||||
|
||||
content_type = match.group(1)
|
||||
base64_data = match.group(2)
|
||||
content_bytes = base64.b64decode(base64_data)
|
||||
ext = content_type.split("/")[-1].split(";")[0] or "jpg"
|
||||
|
||||
return content_bytes, content_type, ext
|
||||
|
||||
|
||||
def _download_image_sync(url: str) -> Tuple[bytes, str, str]:
    """Fetch an image over HTTP (sync) and return (bytes, content_type, ext)."""
    # NOTE(review): SSL verification is disabled here even for arbitrary image
    # URLs, presumably to match the GigaChat certificate setup -- confirm.
    http_client = _get_httpx_client(params={"ssl_verify": False})
    resp = http_client.get(url)
    resp.raise_for_status()

    # Derive the extension from the Content-Type header; default to JPEG.
    media_type = resp.headers.get("content-type", "image/jpeg")
    extension = media_type.split("/")[-1].split(";")[0] or "jpg"

    return resp.content, media_type, extension
|
||||
|
||||
|
||||
async def _download_image_async(url: str) -> Tuple[bytes, str, str]:
    """Fetch an image over HTTP (async) and return (bytes, content_type, ext)."""
    # NOTE(review): SSL verification is disabled here even for arbitrary image
    # URLs, presumably to match the GigaChat certificate setup -- confirm.
    http_client = get_async_httpx_client(
        llm_provider=LlmProviders.GIGACHAT,
        params={"ssl_verify": False},
    )
    resp = await http_client.get(url)
    resp.raise_for_status()

    # Derive the extension from the Content-Type header; default to JPEG.
    media_type = resp.headers.get("content-type", "image/jpeg")
    extension = media_type.split("/")[-1].split(";")[0] or "jpg"

    return resp.content, media_type, extension
|
||||
|
||||
|
||||
def upload_file_sync(
    image_url: str,
    credentials: Optional[str] = None,
    api_base: Optional[str] = None,
) -> Optional[str]:
    """
    Upload an image to GigaChat's /files endpoint and return its file_id (sync).

    Results are memoized in the module-level ``_file_cache`` keyed by the
    SHA-256 of *image_url*, so the same image is uploaded at most once.

    Args:
        image_url: URL or base64 data URL of the image
        credentials: GigaChat credentials for auth
        api_base: Optional custom API base URL

    Returns:
        file_id string, or None if the upload failed (best-effort: every
        error is logged and swallowed).
    """
    cache_key = _get_url_hash(image_url)

    # Serve from cache when this exact URL was uploaded before.
    cached_id = _file_cache.get(cache_key)
    if cached_id is not None:
        verbose_logger.debug(f"Image found in cache: {cache_key[:16]}...")
        return cached_id

    try:
        # Obtain the image bytes: inline base64 first, otherwise download.
        decoded = _parse_data_url(image_url)
        if decoded is not None:
            content_bytes, content_type, ext = decoded
            verbose_logger.debug("Decoded base64 image")
        else:
            verbose_logger.debug(f"Downloading image from URL: {image_url[:80]}...")
            content_bytes, content_type, ext = _download_image_sync(image_url)

        # Random filename; GigaChat only cares about the extension/MIME type.
        filename = f"{uuid.uuid4()}.{ext}"

        # Exchange credentials for a short-lived OAuth token.
        access_token = get_access_token(credentials)

        # Multipart upload to GigaChat's files endpoint.
        upload_url = f"{api_base or GIGACHAT_BASE_URL}/files"

        # NOTE(review): ssl_verify=False is presumably required for GigaChat's
        # certificate chain -- confirm this is intentional.
        http_client = _get_httpx_client(params={"ssl_verify": False})
        response = http_client.post(
            upload_url,
            headers={"Authorization": f"Bearer {access_token}"},
            files={"file": (filename, content_bytes, content_type)},
            data={"purpose": "general"},
            timeout=60,
        )
        response.raise_for_status()

        file_id = response.json().get("id")
        if file_id:
            _file_cache[cache_key] = file_id
            verbose_logger.debug(f"File uploaded successfully, file_id: {file_id}")

        return file_id

    except Exception as e:
        verbose_logger.error(f"Error uploading file to GigaChat: {e}")
        return None
|
||||
|
||||
|
||||
async def upload_file_async(
    image_url: str,
    credentials: Optional[str] = None,
    api_base: Optional[str] = None,
) -> Optional[str]:
    """
    Upload an image to GigaChat's /files endpoint and return its file_id (async).

    Results are memoized in the module-level ``_file_cache`` keyed by the
    SHA-256 of *image_url*, so the same image is uploaded at most once.

    Args:
        image_url: URL or base64 data URL of the image
        credentials: GigaChat credentials for auth
        api_base: Optional custom API base URL

    Returns:
        file_id string, or None if the upload failed (best-effort: every
        error is logged and swallowed).
    """
    cache_key = _get_url_hash(image_url)

    # Serve from cache when this exact URL was uploaded before.
    cached_id = _file_cache.get(cache_key)
    if cached_id is not None:
        verbose_logger.debug(f"Image found in cache: {cache_key[:16]}...")
        return cached_id

    try:
        # Obtain the image bytes: inline base64 first, otherwise download.
        decoded = _parse_data_url(image_url)
        if decoded is not None:
            content_bytes, content_type, ext = decoded
            verbose_logger.debug("Decoded base64 image")
        else:
            verbose_logger.debug(f"Downloading image from URL: {image_url[:80]}...")
            content_bytes, content_type, ext = await _download_image_async(image_url)

        # Random filename; GigaChat only cares about the extension/MIME type.
        filename = f"{uuid.uuid4()}.{ext}"

        # Exchange credentials for a short-lived OAuth token.
        access_token = await get_access_token_async(credentials)

        # Multipart upload to GigaChat's files endpoint.
        upload_url = f"{api_base or GIGACHAT_BASE_URL}/files"

        # NOTE(review): ssl_verify=False is presumably required for GigaChat's
        # certificate chain -- confirm this is intentional.
        http_client = get_async_httpx_client(
            llm_provider=LlmProviders.GIGACHAT,
            params={"ssl_verify": False},
        )
        response = await http_client.post(
            upload_url,
            headers={"Authorization": f"Bearer {access_token}"},
            files={"file": (filename, content_bytes, content_type)},
            data={"purpose": "general"},
            timeout=60,
        )
        response.raise_for_status()

        file_id = response.json().get("id")
        if file_id:
            _file_cache[cache_key] = file_id
            verbose_logger.debug(f"File uploaded successfully, file_id: {file_id}")

        return file_id

    except Exception as e:
        verbose_logger.error(f"Error uploading file to GigaChat: {e}")
        return None
|
||||
Reference in New Issue
Block a user