Files
lijiaoqiao/llm-gateway-competitors/litellm-wheel-src/litellm/llms/azure/passthrough/transformation.py

86 lines
2.6 KiB
Python
Raw Normal View History

from typing import TYPE_CHECKING, List, Optional, Tuple
import httpx
from litellm.llms.azure.common_utils import BaseAzureLLM
from litellm.llms.base_llm.passthrough.transformation import BasePassthroughConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.router import GenericLiteLLMParams
if TYPE_CHECKING:
from httpx import URL
class AzurePassthroughConfig(BasePassthroughConfig):
def is_streaming_request(self, endpoint: str, request_data: dict) -> bool:
return "stream" in request_data
def get_complete_url(
self,
api_base: Optional[str],
api_key: Optional[str],
model: str,
endpoint: str,
request_query_params: Optional[dict],
litellm_params: dict,
) -> Tuple["URL", str]:
base_target_url = self.get_api_base(api_base)
if base_target_url is None:
raise Exception("Azure api base not found")
litellm_metadata = litellm_params.get("litellm_metadata") or {}
model_group = litellm_metadata.get("model_group")
if model_group and model_group in endpoint:
endpoint = endpoint.replace(model_group, model)
complete_url = BaseAzureLLM._get_base_azure_url(
api_base=base_target_url,
litellm_params=litellm_params,
route=endpoint,
default_api_version=litellm_params.get("api_version"),
)
return (
httpx.URL(complete_url),
base_target_url,
)
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
return BaseAzureLLM._base_validate_azure_environment(
headers=headers,
litellm_params=GenericLiteLLMParams(
**{**litellm_params, "api_key": api_key}
),
)
@staticmethod
def get_api_base(
api_base: Optional[str] = None,
) -> Optional[str]:
return api_base or get_secret_str("AZURE_API_BASE")
@staticmethod
def get_api_key(
api_key: Optional[str] = None,
) -> Optional[str]:
return api_key or get_secret_str("AZURE_API_KEY")
@staticmethod
def get_base_model(model: str) -> Optional[str]:
return model
def get_models(
self, api_key: Optional[str] = None, api_base: Optional[str] = None
) -> List[str]:
return super().get_models(api_key, api_base)