Files
lijiaoqiao/llm-gateway-competitors/litellm-wheel-src/litellm/proxy/auth/auth_checks.py
2026-03-26 16:04:46 +08:00

3562 lines
118 KiB
Python

# What is this?
## Common auth checks between jwt + key based auth
"""
Got Valid Token from Cache, DB
Run checks for:
1. If user can call model
2. If user is in budget
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
"""
import asyncio
import re
import time
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, cast
from fastapi import HTTPException, Request, status
from pydantic import BaseModel
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.caching.dual_cache import LimitedSizeOrderedDict
from litellm.constants import (
CLI_JWT_EXPIRATION_HOURS,
CLI_JWT_TOKEN_NAME,
DEFAULT_ACCESS_GROUP_CACHE_TTL,
DEFAULT_IN_MEMORY_TTL,
DEFAULT_MANAGEMENT_OBJECT_IN_MEMORY_CACHE_TTL,
DEFAULT_MAX_RECURSE_DEPTH,
EMAIL_BUDGET_ALERT_MAX_SPEND_ALERT_PERCENTAGE,
)
from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
from litellm.proxy._types import (
RBAC_ROLES,
CallInfo,
LiteLLM_AccessGroupTable,
LiteLLM_BudgetTable,
LiteLLM_EndUserTable,
Litellm_EntityType,
LiteLLM_JWTAuth,
LiteLLM_ObjectPermissionTable,
LiteLLM_OrganizationMembershipTable,
LiteLLM_OrganizationTable,
LiteLLM_ProjectTableCachedObj,
LiteLLM_TagTable,
LiteLLM_TeamMembership,
LiteLLM_TeamTable,
LiteLLM_TeamTableCachedObj,
LiteLLM_UserTable,
LiteLLMRoutes,
LitellmUserRoles,
NewTeamRequest,
ProxyErrorTypes,
ProxyException,
RoleBasedPermissions,
SpecialModelNames,
UserAPIKeyAuth,
)
from litellm.proxy.auth.route_checks import RouteChecks
from litellm.proxy.db.exception_handler import PrismaDBExceptionHandler
from litellm.proxy.guardrails.tool_name_extraction import (
TOOL_CAPABLE_CALL_TYPES,
extract_request_tool_names,
)
from litellm.proxy.route_llm_request import route_request
from litellm.proxy.utils import PrismaClient, ProxyLogging, log_db_metrics
from litellm.router import Router
from litellm.utils import get_utc_datetime
from .auth_checks_organization import organization_role_based_access_check
from .auth_utils import get_model_from_request
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = Union[_Span, Any]
else:
Span = Any
# Per-key timestamp of the most recent DB access (used by _should_check_db);
# bounded at 100 entries so the map cannot grow without limit.
last_db_access_time = LimitedSizeOrderedDict(max_size=100)
db_cache_expiry = DEFAULT_IN_MEMORY_TTL # refresh every 5s
# Union of OpenAI-compatible routes and proxy management routes.
all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value
def _log_budget_lookup_failure(entity: str, error: Exception) -> None:
"""
Log a warning when budget lookup fails; cache will not be populated.
Skips logging for expected "user not found" cases (bare Exception from
get_user_object when user_id_upsert=False). Adds a schema migration hint
when the error appears schema-related.
"""
# Skip logging for expected "user not found" - not caching is correct
if str(error) == "" and type(error).__name__ == "Exception":
return
err_str = str(error).lower()
hint = ""
if any(
x in err_str
for x in ("column", "schema", "does not exist", "prisma", "migrate")
):
hint = (
" Run `prisma db push` or `prisma migrate deploy` to fix schema mismatches."
)
verbose_proxy_logger.error(
f"Budget lookup failed for {entity}; cache will not be populated. "
f"Each request will hit the database. Error: {error}.{hint}"
)
def _is_model_cost_zero(
model: Optional[Union[str, List[str]]], llm_router: Optional[Router]
) -> bool:
"""
Check if a model has zero cost (no configured pricing).
Uses the router's get_model_group_info method to get pricing information.
Args:
model: The model name or list of model names
llm_router: The LiteLLM router instance
Returns:
bool: True if all costs for the model are zero, False otherwise
"""
if model is None or llm_router is None:
return False
# Handle list of models
model_list = [model] if isinstance(model, str) else model
for model_name in model_list:
try:
# Use router's get_model_group_info method directly for better reliability
model_group_info = llm_router.get_model_group_info(model_group=model_name)
if model_group_info is None:
# Model not found or no pricing info available
# Conservative approach: assume it has cost
verbose_proxy_logger.debug(
f"No model group info found for {model_name}, assuming it has cost"
)
return False
# Check costs for this model
# Only allow bypass if BOTH costs are explicitly set to 0 (not None)
input_cost = model_group_info.input_cost_per_token
output_cost = model_group_info.output_cost_per_token
# If costs are not explicitly configured (None), assume it has cost
if input_cost is None or output_cost is None:
verbose_proxy_logger.debug(
f"Model {model_name} has undefined cost (input: {input_cost}, output: {output_cost}), assuming it has cost"
)
return False
# If either cost is non-zero, return False
if input_cost > 0 or output_cost > 0:
verbose_proxy_logger.debug(
f"Model {model_name} has non-zero cost (input: {input_cost}, output: {output_cost})"
)
return False
# This model has zero cost explicitly configured
verbose_proxy_logger.debug(
f"Model {model_name} has zero cost explicitly configured (input: {input_cost}, output: {output_cost})"
)
except Exception as e:
# If we can't determine the cost, assume it has cost (conservative approach)
verbose_proxy_logger.debug(
f"Error checking cost for model {model_name}: {str(e)}, assuming it has cost"
)
return False
# All models checked have zero cost
return True
async def _run_project_checks(
project_object: Optional[LiteLLM_ProjectTableCachedObj],
_model: Optional[Union[str, List[str]]],
llm_router: Optional[Router],
skip_budget_checks: bool,
valid_token: Optional[UserAPIKeyAuth],
proxy_logging_obj: ProxyLogging,
) -> None:
"""
Run all project-level checks: blocked, model access, budget, soft budget.
Extracted from common_checks() to keep statement count manageable.
"""
if project_object is None:
return
# 1.1. If project is blocked
if project_object.blocked is True:
raise Exception(
f"Project={project_object.project_id} is blocked. Update via `/project/update` if you're an admin."
)
# 2.2 If project can call model
if _model and len(project_object.models) > 0:
can_project_access_model(
model=_model,
project_object=project_object,
llm_router=llm_router,
)
if not skip_budget_checks:
# 3.0.2. If project is in budget
await _project_max_budget_check(
project_object=project_object,
valid_token=valid_token,
proxy_logging_obj=proxy_logging_obj,
)
# 3.0.3. If project is over soft budget (alert only, doesn't block)
await _project_soft_budget_check(
project_object=project_object,
valid_token=valid_token,
proxy_logging_obj=proxy_logging_obj,
)
def _enforce_user_param_check(
general_settings: dict, request: Request, request_body: dict, route: str
) -> None:
if not general_settings.get("enforce_user_param", False):
return
http_method = request.method if hasattr(request, "method") else None
is_post_method = http_method and http_method.upper() == "POST"
is_openai_route = RouteChecks.is_llm_api_route(route=route)
is_mcp_route = (
route in LiteLLMRoutes.mcp_routes.value
or RouteChecks.check_route_access(
route=route, allowed_routes=LiteLLMRoutes.mcp_routes.value
)
)
if (
is_post_method
and is_openai_route
and not is_mcp_route
and "user" not in request_body
):
raise Exception(
f"'user' param not passed in. 'enforce_user_param'={general_settings['enforce_user_param']}"
)
def _reject_clientside_metadata_tags_check(
general_settings: dict, request_body: dict, route: str
) -> None:
if not general_settings.get("reject_clientside_metadata_tags", False):
return
if (
RouteChecks.is_llm_api_route(route=route)
and "metadata" in request_body
and isinstance(request_body["metadata"], dict)
and "tags" in request_body["metadata"]
):
raise ProxyException(
message=f"Client-side 'metadata.tags' not allowed in request. 'reject_clientside_metadata_tags'={general_settings['reject_clientside_metadata_tags']}. Tags can only be set via API key metadata.",
type=ProxyErrorTypes.bad_request_error,
param="metadata.tags",
code=status.HTTP_400_BAD_REQUEST,
)
def _global_proxy_budget_check(
    global_proxy_spend: Optional[float], skip_budget_checks: bool, route: str
) -> None:
    """
    Enforce the proxy-wide litellm.max_budget on LLM API routes.
    Model-listing routes (/models, /v1/models) are exempt, as are
    requests with budget checks skipped or no known global spend.
    """
    if litellm.max_budget <= 0 or skip_budget_checks or global_proxy_spend is None:
        return
    if not RouteChecks.is_llm_api_route(route=route):
        return
    if route in ("/v1/models", "/models"):
        return
    if global_proxy_spend > litellm.max_budget:
        raise litellm.BudgetExceededError(
            current_cost=global_proxy_spend, max_budget=litellm.max_budget
        )
def _guardrail_modification_check(
request_body: dict, team_object: Optional[LiteLLM_TeamTable]
) -> None:
_request_metadata: dict = request_body.get("metadata", {}) or {}
if not _request_metadata.get("guardrails"):
return
from litellm.proxy.guardrails.guardrail_helpers import can_modify_guardrails
if not can_modify_guardrails(team_object):
raise HTTPException(
status_code=403,
detail={
"error": "Your team does not have permission to modify guardrails."
},
)
async def check_tools_allowlist(
    request_body: dict,
    valid_token: Optional[UserAPIKeyAuth],
    team_object: Optional[LiteLLM_TeamTable],
    route: str,
) -> None:
    """
    Enforce the key/team tool allowlist stored in metadata.allowed_tools.

    The hot path is DB-free: the effective allowlist is read straight from
    valid_token.metadata (key level) and valid_token.team_metadata (team
    level), with a non-empty key-level list taking precedence.

    Raises:
        ProxyException (tool_access_denied, 403): if any requested tool is
        absent from the effective allowlist.
    """
    from litellm.litellm_core_utils.api_route_to_call_types import (
        get_call_types_for_route,
    )

    if valid_token is None:
        return
    call_types = get_call_types_for_route(route)
    # Only routes whose call types can carry tools are subject to the check.
    if not call_types:
        return
    if not any(ct.value in TOOL_CAPABLE_CALL_TYPES for ct in call_types):
        return
    requested = extract_request_tool_names(route, request_body)
    if not requested:
        return
    key_metadata = (
        valid_token.metadata if isinstance(valid_token.metadata, dict) else {}
    ) or {}
    team_metadata = (
        valid_token.team_metadata if isinstance(valid_token.team_metadata, dict) else {}
    ) or {}
    key_list = key_metadata.get("allowed_tools")
    team_list = team_metadata.get("allowed_tools")
    # Key-level allowlist wins when it is a non-empty list.
    if isinstance(key_list, list) and len(key_list) > 0:
        effective = key_list
    else:
        effective = team_list
    # No configured allowlist -> everything is permitted.
    if not isinstance(effective, list) or len(effective) == 0:
        return
    permitted = {str(t) for t in effective}
    blocked = [tool for tool in requested if tool not in permitted]
    if blocked:
        raise ProxyException(
            message=f"Tool(s) {blocked} are not in the allowed tools list for this key/team.",
            type=ProxyErrorTypes.tool_access_denied,
            param="tools",
            code=status.HTTP_403_FORBIDDEN,
        )
async def common_checks( # noqa: PLR0915
    request_body: dict,
    team_object: Optional[LiteLLM_TeamTable],
    user_object: Optional[LiteLLM_UserTable],
    end_user_object: Optional[LiteLLM_EndUserTable],
    global_proxy_spend: Optional[float],
    general_settings: dict,
    route: str,
    llm_router: Optional[Router],
    proxy_logging_obj: ProxyLogging,
    valid_token: Optional[UserAPIKeyAuth],
    request: Request,
    skip_budget_checks: bool = False,
    project_object: Optional[LiteLLM_ProjectTableCachedObj] = None,
) -> bool:
    """
    Common checks across jwt + key-based auth.

    1. If team is blocked
    1.1. If project is blocked
    2. If team can call model
    2.2 If project can call model
    3. If team is in budget
    3.0.2. If project is in budget
    3.0.3. If project is over soft budget (alert only)
    4. If user passed in (JWT or key.user_id) - is in budget
    5. If end_user (either via JWT or 'user' passed to /chat/completions, /embeddings endpoint) is in budget
    6. [OPTIONAL] If 'enforce_end_user' enabled - did developer pass in 'user' param for openai endpoints
    7. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
    8. [OPTIONAL] If guardrails modified - is request allowed to change this
    9. Check if request body is safe
    10. [OPTIONAL] Organization checks - is user_object.organization_id is set, run these checks
    11. [OPTIONAL] Vector store checks - is the object allowed to access the vector store

    Returns True when every check passes; each failing check raises
    (Exception / ProxyException / HTTPException / BudgetExceededError).
    """
    from litellm.proxy.proxy_server import prisma_client, user_api_key_cache
    _model: Optional[Union[str, List[str]]] = get_model_from_request(
        request_body, route
    )
    # 1. If team is blocked
    if team_object is not None and team_object.blocked is True:
        raise Exception(
            f"Team={team_object.team_id} is blocked. Update via `/team/unblock` if your admin."
        )
    # 2. If team can call model
    if _model and team_object:
        if not await can_team_access_model(
            model=_model,
            team_object=team_object,
            llm_router=llm_router,
            team_model_aliases=valid_token.team_model_aliases if valid_token else None,
        ):
            raise ProxyException(
                message=f"Team not allowed to access model. Team={team_object.team_id}, Model={_model}. Allowed team models = {team_object.models}",
                type=ProxyErrorTypes.team_model_access_denied,
                param="model",
                code=status.HTTP_401_UNAUTHORIZED,
            )
    # Require trace id for agent keys when agent has require_trace_id_on_calls_by_agent
    if valid_token is not None and valid_token.agent_id:
        from litellm.proxy.agent_endpoints.agent_registry import global_agent_registry
        from litellm.proxy.litellm_pre_call_utils import get_chain_id_from_headers
        agent = global_agent_registry.get_agent_by_id(agent_id=valid_token.agent_id)
        if agent is not None:
            require_trace_id = (agent.litellm_params or {}).get(
                "require_trace_id_on_calls_by_agent"
            )
            if require_trace_id:
                headers_dict = dict(request.headers)
                trace_id = get_chain_id_from_headers(headers_dict)
                if not trace_id:
                    raise ProxyException(
                        message="Requests made with this agent's key must include the x-litellm-trace-id header.",
                        type=ProxyErrorTypes.bad_request_error,
                        param=None,
                        code=status.HTTP_400_BAD_REQUEST,
                    )
    ## 2.1 If user can call model (if personal key)
    # Only applies when there is no team attached to the key.
    if _model and team_object is None and user_object is not None:
        await can_user_call_model(
            model=_model,
            llm_router=llm_router,
            user_object=user_object,
        )
    # 1.1 - 2.2 - 3.0.2 - 3.0.3: Project checks (blocked, model access, budget)
    await _run_project_checks(
        project_object=project_object,
        _model=_model,
        llm_router=llm_router,
        skip_budget_checks=skip_budget_checks,
        valid_token=valid_token,
        proxy_logging_obj=proxy_logging_obj,
    )
    # If this is a free model, skip all budget checks
    if not skip_budget_checks:
        # 3. If team is in budget
        await _team_max_budget_check(
            team_object=team_object,
            proxy_logging_obj=proxy_logging_obj,
            valid_token=valid_token,
        )
        # 3.0.5. If team is over soft budget (alert only, doesn't block)
        await _team_soft_budget_check(
            team_object=team_object,
            proxy_logging_obj=proxy_logging_obj,
            valid_token=valid_token,
        )
        # 3.1. If organization is in budget
        await _organization_max_budget_check(
            valid_token=valid_token,
            team_object=team_object,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            proxy_logging_obj=proxy_logging_obj,
        )
        # 3.2. Tag-level budget check (tags read from the request body).
        await _tag_max_budget_check(
            request_body=request_body,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            proxy_logging_obj=proxy_logging_obj,
            valid_token=valid_token,
        )
        # 4. If user is in budget
        ## 4.1 check personal budget, if personal key
        if (
            (team_object is None or team_object.team_id is None)
            and user_object is not None
            and user_object.max_budget is not None
        ):
            user_budget = user_object.max_budget
            if user_budget < user_object.spend:
                raise litellm.BudgetExceededError(
                    current_cost=user_object.spend,
                    max_budget=user_budget,
                    message=f"ExceededBudget: User={user_object.user_id} over budget. Spend={user_object.spend}, Budget={user_budget}",
                )
        ## 4.2 check team member budget, if team key
        await _check_team_member_budget(
            team_object=team_object,
            user_object=user_object,
            valid_token=valid_token,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            proxy_logging_obj=proxy_logging_obj,
        )
        # 5. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
        if (
            end_user_object is not None
            and end_user_object.litellm_budget_table is not None
        ):
            end_user_budget = end_user_object.litellm_budget_table.max_budget
            if end_user_budget is not None and end_user_object.spend > end_user_budget:
                raise litellm.BudgetExceededError(
                    current_cost=end_user_object.spend,
                    max_budget=end_user_budget,
                    message=f"ExceededBudget: End User={end_user_object.user_id} over budget. Spend={end_user_object.spend}, Budget={end_user_budget}",
                )
    # 6. Enforce 'user' param when configured (raises on violation).
    _enforce_user_param_check(general_settings, request, request_body, route)
    # Reject client-supplied metadata.tags when configured.
    _reject_clientside_metadata_tags_check(general_settings, request_body, route)
    # 7. Proxy-wide budget check (litellm.max_budget).
    _global_proxy_budget_check(global_proxy_spend, skip_budget_checks, route)
    # 8. Guardrail-modification permission check.
    _guardrail_modification_check(request_body, team_object)
    # 10 [OPTIONAL] Organization RBAC checks
    organization_role_based_access_check(
        user_object=user_object, route=route, request_body=request_body
    )
    token_team = getattr(valid_token, "team_id", None)
    # Dashboard-issued keys carry team_id == "litellm-dashboard".
    token_type: Literal["ui", "api"] = (
        "ui" if token_team is not None and token_team == "litellm-dashboard" else "api"
    )
    # Raises if the route is not allowed for this token; the returned
    # boolean itself is not consumed.
    _is_route_allowed = _is_allowed_route(
        route=route,
        token_type=token_type,
        user_obj=user_object,
        request=request,
        request_data=request_body,
        valid_token=valid_token,
    )
    # 11. [OPTIONAL] Vector store checks - is the object allowed to access the vector store
    await vector_store_access_check(
        request_body=request_body,
        team_object=team_object,
        valid_token=valid_token,
    )
    # 12. [OPTIONAL] Tool allowlist - key/team allowed_tools (no DB in hot path)
    await check_tools_allowlist(
        request_body=request_body,
        valid_token=valid_token,
        team_object=team_object,
        route=route,
    )
    return True
def _is_ui_route(
    route: str,
    user_obj: Optional[LiteLLM_UserTable] = None,
) -> bool:
    """
    Return True when the route is one used by the dashboard UI.

    Matches first by prefix against the UI route list, then by route
    pattern. (user_obj is accepted for interface compatibility but is
    not consulted.)
    """
    # UI-management tokens may only touch these routes.
    ui_routes = LiteLLMRoutes.ui_routes.value
    if route is not None and isinstance(route, str):
        # Prefix match against every allowed UI route.
        for candidate in ui_routes:
            if route.startswith(candidate):
                return True
    # Fall back to pattern matching (templated paths).
    for candidate in ui_routes:
        if RouteChecks._route_matches_pattern(route=route, pattern=candidate):
            return True
    return False
def _get_user_role(
user_obj: Optional[LiteLLM_UserTable],
) -> Optional[LitellmUserRoles]:
if user_obj is None:
return None
_user = user_obj
_user_role = _user.user_role
try:
role = LitellmUserRoles(_user_role)
except ValueError:
return LitellmUserRoles.INTERNAL_USER
return role
def _is_api_route_allowed(
    route: str,
    request: Request,
    request_data: dict,
    valid_token: Optional[UserAPIKeyAuth],
    user_obj: Optional[LiteLLM_UserTable] = None,
) -> bool:
    """
    Gate an API-token request: proxy admins pass through unconditionally,
    everyone else goes through the non-admin allowed-routes check (which
    raises on disallowed routes).
    """
    role = _get_user_role(user_obj=user_obj)
    if valid_token is None:
        raise Exception("Invalid proxy server token passed. valid_token=None.")
    # Admins skip the per-route restriction check entirely.
    if not _is_user_proxy_admin(user_obj=user_obj):
        RouteChecks.non_proxy_admin_allowed_routes_check(
            user_obj=user_obj,
            _user_role=role,
            route=route,
            request=request,
            request_data=request_data,
            valid_token=valid_token,
        )
    return True
def _is_user_proxy_admin(user_obj: Optional[LiteLLM_UserTable]):
if user_obj is None:
return False
if (
user_obj.user_role is not None
and user_obj.user_role == LitellmUserRoles.PROXY_ADMIN.value
):
return True
if (
user_obj.user_role is not None
and user_obj.user_role == LitellmUserRoles.PROXY_ADMIN.value
):
return True
return False
def _is_allowed_route(
    route: str,
    token_type: Literal["ui", "api"],
    request: Request,
    request_data: dict,
    valid_token: Optional[UserAPIKeyAuth],
    user_obj: Optional[LiteLLM_UserTable] = None,
) -> bool:
    """
    Dispatch between the UI-token route check and the normal API-token
    route check. UI tokens hitting UI routes are allowed immediately;
    everything else goes through the API route check.
    """
    is_ui_token = token_type == "ui"
    if is_ui_token and _is_ui_route(route=route, user_obj=user_obj):
        return True
    return _is_api_route_allowed(
        route=route,
        request=request,
        request_data=request_data,
        valid_token=valid_token,
        user_obj=user_obj,
    )
def _allowed_routes_check(user_route: str, allowed_routes: list) -> bool:
    """
    Decide whether user_route is covered by allowed_routes. Helper for
    `allowed_routes_check`.

    Parameters:
    - user_route: str - the route the user is trying to call
    - allowed_routes: List[str|LiteLLMRoutes] - allowed entries; each may
      be a literal route or a LiteLLMRoutes group name (expanded via its
      path templates and regex-matched).
    """
    from starlette.routing import compile_path

    for entry in allowed_routes:
        if entry in LiteLLMRoutes.__members__:
            # Group name: expand into its templates and match each one.
            for template in LiteLLMRoutes[entry].value:
                pattern, _, _ = compile_path(template)
                if pattern.match(user_route):
                    return True
        elif entry == user_route:
            return True
    return False
def allowed_routes_check(
    user_role: LitellmUserRoles,
    user_route: str,
    litellm_proxy_roles: LiteLLM_JWTAuth,
) -> bool:
    """
    Check whether a user (not necessarily an admin) may access the route.

    - PROXY_ADMIN: checked against admin_allowed_routes.
    - TEAM: checked against team_allowed_routes, defaulting to
      openai + info routes when none are configured.
    - any other role: denied.
    """
    if user_role == LitellmUserRoles.PROXY_ADMIN:
        return _allowed_routes_check(
            user_route=user_route,
            allowed_routes=litellm_proxy_roles.admin_allowed_routes,
        )
    if user_role == LitellmUserRoles.TEAM:
        team_routes = litellm_proxy_roles.team_allowed_routes
        if team_routes is None:
            # By default allow a team to call openai + info routes.
            return _allowed_routes_check(
                user_route=user_route, allowed_routes=["openai_routes", "info_routes"]
            )
        return _allowed_routes_check(
            user_route=user_route,
            allowed_routes=team_routes,
        )
    return False
def allowed_route_check_inside_route(
    user_api_key_dict: UserAPIKeyAuth,
    requested_user_id: Optional[str],
) -> bool:
    """
    Allow the call when the caller is a proxy admin (full or view-only),
    or when a non-admin is asking about their own user_id.
    """
    admin_roles = (
        LitellmUserRoles.PROXY_ADMIN,
        LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY,
    )
    if user_api_key_dict.user_role in admin_roles:
        return True
    # Non-admins may still query themselves.
    return (
        requested_user_id is not None
        and user_api_key_dict.user_id is not None
        and user_api_key_dict.user_id == requested_user_id
    )
def get_actual_routes(allowed_routes: list) -> list:
    """
    Expand LiteLLMRoutes group names into their concrete route lists.
    Names that are not group names pass through as literal routes.
    """
    expanded: list = []
    for name in allowed_routes:
        try:
            value = LiteLLMRoutes[name].value
        except KeyError:
            # Not a group name -> keep as a literal route.
            expanded.append(name)
            continue
        # Sets are converted to lists; other iterables extend directly.
        expanded.extend(list(value) if isinstance(value, set) else value)
    return expanded
async def get_default_end_user_budget(
prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache,
parent_otel_span: Optional[Span] = None,
) -> Optional[LiteLLM_BudgetTable]:
"""
Fetches the default end user budget from the database if litellm.max_end_user_budget_id is configured.
This budget is applied to end users who don't have an explicit budget_id set.
Results are cached for performance.
Args:
prisma_client: Database client instance
user_api_key_cache: Cache for storing/retrieving budget data
parent_otel_span: Optional OpenTelemetry span for tracing
Returns:
LiteLLM_BudgetTable if configured and found, None otherwise
"""
if prisma_client is None or litellm.max_end_user_budget_id is None:
return None
cache_key = f"default_end_user_budget:{litellm.max_end_user_budget_id}"
# Check cache first
cached_budget = await user_api_key_cache.async_get_cache(key=cache_key)
if cached_budget is not None:
return LiteLLM_BudgetTable(**cached_budget)
# Fetch from database
try:
budget_record = await prisma_client.db.litellm_budgettable.find_unique(
where={"budget_id": litellm.max_end_user_budget_id}
)
if budget_record is None:
verbose_proxy_logger.warning(
f"Default end user budget not found in database: {litellm.max_end_user_budget_id}"
)
return None
# Cache the budget for 60 seconds
await user_api_key_cache.async_set_cache(
key=cache_key,
value=budget_record.dict(),
ttl=DEFAULT_MANAGEMENT_OBJECT_IN_MEMORY_CACHE_TTL,
)
return LiteLLM_BudgetTable(**budget_record.dict())
except Exception as e:
verbose_proxy_logger.error(f"Error fetching default end user budget: {str(e)}")
return None
async def _apply_default_budget_to_end_user(
end_user_obj: LiteLLM_EndUserTable,
prisma_client: PrismaClient,
user_api_key_cache: DualCache,
parent_otel_span: Optional[Span] = None,
) -> LiteLLM_EndUserTable:
"""
Helper function to apply default budget to end user if they don't have a budget assigned.
Args:
end_user_obj: The end user object to potentially apply default budget to
prisma_client: Database client instance
user_api_key_cache: Cache for storing/retrieving data
parent_otel_span: Optional OpenTelemetry span for tracing
Returns:
Updated end user object with default budget applied if applicable
"""
# If end user already has a budget assigned, no need to apply default
if end_user_obj.litellm_budget_table is not None:
return end_user_obj
# If no default budget configured, return as-is
if litellm.max_end_user_budget_id is None:
return end_user_obj
# Fetch and apply default budget
default_budget = await get_default_end_user_budget(
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=parent_otel_span,
)
if default_budget is not None:
# Apply default budget to end user object
end_user_obj.litellm_budget_table = default_budget
verbose_proxy_logger.debug(
f"Applied default budget {litellm.max_end_user_budget_id} to end user {end_user_obj.user_id}"
)
return end_user_obj
def _check_end_user_budget(
    end_user_obj: LiteLLM_EndUserTable,
    route: str,
) -> None:
    """
    Raise BudgetExceededError when the end user's spend exceeds their
    max_budget.

    Info routes are always exempt; end users without a budget table (or
    without a max_budget) are never blocked.
    """
    if RouteChecks.is_info_route(route):
        return
    budget_table = end_user_obj.litellm_budget_table
    if budget_table is None:
        return
    limit = budget_table.max_budget
    if limit is None or end_user_obj.spend <= limit:
        return
    raise litellm.BudgetExceededError(
        current_cost=end_user_obj.spend,
        max_budget=limit,
        message=f"ExceededBudget: End User={end_user_obj.user_id} over budget. Spend={end_user_obj.spend}, Budget={limit}",
    )
@log_db_metrics
async def get_end_user_object(
    end_user_id: Optional[str],
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
    route: str,
    parent_otel_span: Optional[Span] = None,
    proxy_logging_obj: Optional[ProxyLogging] = None,
) -> Optional[LiteLLM_EndUserTable]:
    """
    Resolve an end user from cache or the database.

    If the end user has no budget_id, the configured default budget
    (litellm.max_end_user_budget_id) is applied. Budget limits are
    enforced here and propagate BudgetExceededError when exceeded.

    Args:
        end_user_id: the end user's ID (None -> returns None)
        prisma_client: database client; required (raises when None)
        user_api_key_cache: cache for end user records
        route: current request route (info routes skip budget checks)
        parent_otel_span: optional OpenTelemetry span for tracing
        proxy_logging_obj: optional proxy logging object

    Returns:
        LiteLLM_EndUserTable if found, None otherwise.
    """
    if prisma_client is None:
        raise Exception("No db connected")
    if end_user_id is None:
        return None
    _key = "end_user_id:{}".format(end_user_id)

    # Fast path: serve the cached record.
    cached = await user_api_key_cache.async_get_cache(key=_key)
    if cached is not None:
        end_user = LiteLLM_EndUserTable(**cached)
        end_user = await _apply_default_budget_to_end_user(
            end_user_obj=end_user,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            parent_otel_span=parent_otel_span,
        )
        _check_end_user_budget(end_user_obj=end_user, route=route)
        return end_user

    # Slow path: database lookup.
    try:
        record = await prisma_client.db.litellm_endusertable.find_unique(
            where={"user_id": end_user_id},
            include={"litellm_budget_table": True, "object_permission": True},
        )
        if record is None:
            # Treated as "not found" by the handler below.
            raise Exception
        end_user = LiteLLM_EndUserTable(**record.dict())
        end_user = await _apply_default_budget_to_end_user(
            end_user_obj=end_user,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            parent_otel_span=parent_otel_span,
        )
        # Always cache the dict form for consistency with the fast path.
        await user_api_key_cache.async_set_cache(
            key="end_user_id:{}".format(end_user_id), value=end_user.dict()
        )
        _check_end_user_budget(end_user_obj=end_user, route=route)
        return end_user
    except Exception as e:
        # Budget violations must propagate; everything else means "not found".
        if isinstance(e, litellm.BudgetExceededError):
            raise e
        return None
@log_db_metrics
async def get_tag_objects_batch(
    tag_names: List[str],
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
    parent_otel_span: Optional[Span] = None,
    proxy_logging_obj: Optional[ProxyLogging] = None,
) -> Dict[str, LiteLLM_TagTable]:
    """
    Resolve multiple tags, preferring cache and batching DB reads.

    Cached tags are served directly; all cache misses are fetched from the
    database in a single `find_many` query and written back to cache.

    Args:
        tag_names: tag names to resolve
        prisma_client: Prisma database client (None -> empty result)
        user_api_key_cache: cache for tag objects
        parent_otel_span: optional OpenTelemetry span for tracing
        proxy_logging_obj: optional proxy logging object

    Returns:
        Mapping of tag_name -> LiteLLM_TagTable for every tag found.
    """
    if prisma_client is None or not tag_names:
        return {}
    found: Dict[str, LiteLLM_TagTable] = {}
    misses: List[str] = []
    # Serve whatever the cache already has; remember the misses.
    for name in tag_names:
        cached = await user_api_key_cache.async_get_cache(key=f"tag:{name}")
        if cached is None:
            misses.append(name)
        elif isinstance(cached, dict):
            found[name] = LiteLLM_TagTable(**cached)
        else:
            found[name] = cached
    # One DB round-trip covers all cache misses.
    if misses:
        try:
            records = await prisma_client.db.litellm_tagtable.find_many(
                where={"tag_name": {"in": misses}},
                include={"litellm_budget_table": True},
            )
            for record in records:
                name = record.tag_name
                # Cache with default TTL (same as end_user objects).
                await user_api_key_cache.async_set_cache(
                    key=f"tag:{name}", value=record.dict()
                )
                found[name] = LiteLLM_TagTable(**record.dict())
        except Exception as e:
            verbose_proxy_logger.debug(f"Error batch fetching tags from database: {e}")
    return found
@log_db_metrics
async def get_tag_object(
    tag_name: Optional[str],
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
    parent_otel_span: Optional[Span] = None,
    proxy_logging_obj: Optional[ProxyLogging] = None,
) -> Optional[LiteLLM_TagTable]:
    """
    Resolve a single tag via the batch helper (cache first, then DB).

    Uses the default cache TTL (same as end_user objects) to avoid drift.

    Args:
        tag_name: name of the tag to fetch
        prisma_client: Prisma database client
        user_api_key_cache: cache for tag objects
        parent_otel_span: optional OpenTelemetry span for tracing
        proxy_logging_obj: optional proxy logging object

    Returns:
        LiteLLM_TagTable when found, None otherwise.
    """
    if prisma_client is None or tag_name is None:
        return None
    # Delegate to the batch helper for consistent caching behavior.
    batch_result = await get_tag_objects_batch(
        tag_names=[tag_name],
        prisma_client=prisma_client,
        user_api_key_cache=user_api_key_cache,
        parent_otel_span=parent_otel_span,
        proxy_logging_obj=proxy_logging_obj,
    )
    return batch_result.get(tag_name)
@log_db_metrics
async def get_team_membership(
    user_id: str,
    team_id: str,
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
    parent_otel_span: Optional[Span] = None,
    proxy_logging_obj: Optional[ProxyLogging] = None,
) -> Optional["LiteLLM_TeamMembership"]:
    """
    Return the team-membership record for (user_id, team_id), or None.

    Membership is looked up in isolation rather than as one combined
    key + team + user + membership query: keys arrive frequently for
    different users/teams, so the stable key/team/user info stays cached
    and only the frequently-changing membership is refreshed.
    """
    from litellm.proxy._types import LiteLLM_TeamMembership

    if prisma_client is None:
        raise Exception("No db connected")
    if user_id is None or team_id is None:
        return None
    cache_key = "team_membership:{}:{}".format(user_id, team_id)
    # Fast path: serve the cached record.
    cached = await user_api_key_cache.async_get_cache(key=cache_key)
    if cached is not None:
        return LiteLLM_TeamMembership(**cached)
    # Slow path: database lookup.
    try:
        record = await prisma_client.db.litellm_teammembership.find_unique(
            where={"user_id_team_id": {"user_id": user_id, "team_id": team_id}},
            include={"litellm_budget_table": True},
        )
        if record is None:
            return None
        # Store the dict form so the cache stays serializable.
        await user_api_key_cache.async_set_cache(key=cache_key, value=record.dict())
        return LiteLLM_TeamMembership(**record.dict())
    except Exception:
        verbose_proxy_logger.exception(
            "Error getting team membership for user_id: %s, team_id: %s",
            user_id,
            team_id,
        )
        return None
def model_in_access_group(
    model: str, team_models: Optional[List[str]], llm_router: Optional[Router]
) -> bool:
    """
    Decide whether `model` is reachable through `team_models`.

    True when:
    - the team has no model restriction (team_models is None), or
    - the model appears directly in team_models, or
    - any team model is the name of a router access group containing
      `model`, or
    - the model appears among the team models that are not access groups.
    """
    from collections import defaultdict

    if team_models is None:
        return True
    if model in team_models:
        return True

    access_groups: dict[str, list[str]] = defaultdict(list)
    if llm_router:
        access_groups = llm_router.get_model_access_groups(model_name=model)

    # If any team model names an access group for this model, allow it.
    if len(access_groups) > 0 and any(m in access_groups for m in team_models):
        return True

    # Otherwise compare against the team models that are not access groups.
    non_group_models = [m for m in team_models if m not in access_groups]
    return model in non_group_models
def _should_check_db(
key: str, last_db_access_time: LimitedSizeOrderedDict, db_cache_expiry: int
) -> bool:
"""
Prevent calling db repeatedly for items that don't exist in the db.
"""
current_time = time.time()
# if key doesn't exist in last_db_access_time -> check db
if key not in last_db_access_time:
return True
elif (
last_db_access_time[key][0] is not None
): # check db for non-null values (for refresh operations)
return True
elif last_db_access_time[key][0] is None:
if current_time - last_db_access_time[key] >= db_cache_expiry:
return True
return False
def _update_last_db_access_time(
    key: str, value: Optional[Any], last_db_access_time: LimitedSizeOrderedDict
):
    """Record the latest DB lookup result for *key* together with when it happened."""
    lookup_timestamp = time.time()
    last_db_access_time[key] = (value, lookup_timestamp)
def _get_role_based_permissions(
    rbac_role: RBAC_ROLES,
    general_settings: dict,
    key: Literal["models", "routes"],
) -> Optional[List[str]]:
    """
    Look up the permission list (`models` or `routes`) configured for
    *rbac_role* under `general_settings["role_permissions"]`.

    Returns None when role_permissions is explicitly null or the role has
    no entry.
    """
    permissions = cast(
        Optional[List[RoleBasedPermissions]],
        general_settings.get("role_permissions", []),
    )
    if permissions is None:
        return None

    # first entry matching the role wins
    matching = next((p for p in permissions if p.role == rbac_role), None)
    if matching is None:
        return None
    return getattr(matching, key)
def get_role_based_models(
    rbac_role: RBAC_ROLES,
    general_settings: dict,
) -> Optional[List[str]]:
    """
    Return the list of models configured for *rbac_role*, or None.

    Used by JWT Auth.
    """
    allowed_models = _get_role_based_permissions(
        rbac_role=rbac_role,
        general_settings=general_settings,
        key="models",
    )
    return allowed_models
def get_role_based_routes(
    rbac_role: RBAC_ROLES,
    general_settings: dict,
) -> Optional[List[str]]:
    """Return the list of routes configured for *rbac_role*, or None."""
    allowed_routes = _get_role_based_permissions(
        rbac_role=rbac_role,
        general_settings=general_settings,
        key="routes",
    )
    return allowed_routes
async def _get_fuzzy_user_object(
    prisma_client: PrismaClient,
    sso_user_id: Optional[str] = None,
    user_email: Optional[str] = None,
) -> Optional[LiteLLM_UserTable]:
    """
    Fallback lookup for an SSO user when no exact user_id match exists.

    Resolution order:
    1. match on the `sso_user_id` column
    2. case-insensitive match on `user_email`

    When the user is found via email and an sso_user_id was supplied, the
    row is back-filled with that sso_user_id in a fire-and-forget task.
    """
    user_record = None

    # 1) try the sso_user_id column first
    if sso_user_id is not None:
        user_record = await prisma_client.db.litellm_usertable.find_unique(
            where={"sso_user_id": sso_user_id},
            include={"organization_memberships": True},
        )

    # 2) fall back to a case-insensitive email match (same pattern as
    #    _check_duplicate_user_email, so differently-cased emails still match)
    if user_record is None and user_email is not None:
        user_record = await prisma_client.db.litellm_usertable.find_first(
            where={"user_email": {"equals": user_email, "mode": "insensitive"}},
            include={"organization_memberships": True},
        )
        if user_record is not None and sso_user_id is not None:  # update sso_user_id
            # back-fill the sso id without blocking the auth path
            asyncio.create_task(
                prisma_client.db.litellm_usertable.update(
                    where={"user_id": user_record.user_id},
                    data={"sso_user_id": sso_user_id},
                )
            )

    return user_record
@log_db_metrics
async def get_user_object(
    user_id: Optional[str],
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
    user_id_upsert: bool,
    parent_otel_span: Optional[Span] = None,
    proxy_logging_obj: Optional[ProxyLogging] = None,
    sso_user_id: Optional[str] = None,
    user_email: Optional[str] = None,
    check_db_only: Optional[bool] = None,
) -> Optional[LiteLLM_UserTable]:
    """
    Resolve a user from cache or the proxy User Table.

    Lookup order: in-memory cache (unless check_db_only) -> DB by user_id
    (throttled for recent "not found" results) -> fuzzy DB lookup by
    sso_user_id / user_email -> optional auto-create.

    Args:
        user_id: User id to resolve; None short-circuits to None.
        prisma_client: DB client; required once the cache misses.
        user_api_key_cache: Cache consulted and updated for user objects.
        user_id_upsert: If True, create the user when not found in the DB.
        parent_otel_span: Optional OpenTelemetry span for tracing.
        proxy_logging_obj: Optional proxy logging object.
        sso_user_id: Optional SSO id used by the fuzzy fallback lookup.
        user_email: Optional email used by the fuzzy fallback and upsert.
        check_db_only: If True, skip the cache read and go straight to DB.

    Returns:
        LiteLLM_UserTable for the user, or None when user_id is None.

    Raises:
        Exception: when no DB is connected.
        ValueError: when the user doesn't exist (and user_id_upsert is
            False) or any other error occurs during the lookup.
    """
    if user_id is None:
        return None
    # check if in cache
    if not check_db_only:
        cached_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
        if cached_user_obj is not None:
            # cache entries may be stored as dicts or as model instances
            if isinstance(cached_user_obj, dict):
                return LiteLLM_UserTable(**cached_user_obj)
            elif isinstance(cached_user_obj, LiteLLM_UserTable):
                return cached_user_obj
    # else, check db
    if prisma_client is None:
        raise Exception("No db connected")
    try:
        db_access_time_key = "user_id:{}".format(user_id)
        # throttle DB hits for ids that recently resolved to "not found"
        should_check_db = _should_check_db(
            key=db_access_time_key,
            last_db_access_time=last_db_access_time,
            db_cache_expiry=db_cache_expiry,
        )
        if should_check_db:
            response = await prisma_client.db.litellm_usertable.find_unique(
                where={"user_id": user_id}, include={"organization_memberships": True}
            )
            if response is None:
                # exact-id miss -> try matching by sso_user_id / email
                response = await _get_fuzzy_user_object(
                    prisma_client=prisma_client,
                    sso_user_id=sso_user_id,
                    user_email=user_email,
                )
        else:
            response = None
        if response is None:
            if user_id_upsert:
                # auto-create the user, applying proxy-wide defaults if set
                new_user_params: Dict[str, Any] = {
                    "user_id": user_id,
                }
                if user_email is not None:
                    new_user_params["user_email"] = user_email
                if litellm.default_internal_user_params is not None:
                    new_user_params.update(litellm.default_internal_user_params)
                response = await prisma_client.db.litellm_usertable.create(
                    data=new_user_params,
                    include={"organization_memberships": True},
                )
            else:
                # deliberate bare raise: converted to a descriptive
                # ValueError by the except block below
                raise Exception
        if (
            response.organization_memberships is not None
            and len(response.organization_memberships) > 0
        ):
            # dump each organization membership to type LiteLLM_OrganizationMembershipTable
            _dumped_memberships = [
                LiteLLM_OrganizationMembershipTable(**membership.model_dump())
                for membership in response.organization_memberships
                if membership is not None
            ]
            response.organization_memberships = _dumped_memberships
        _response = LiteLLM_UserTable(**dict(response))
        response_dict = _response.model_dump()
        # save the user object to cache
        await user_api_key_cache.async_set_cache(
            key=user_id,
            value=response_dict,
            ttl=DEFAULT_MANAGEMENT_OBJECT_IN_MEMORY_CACHE_TTL,
        )
        # save to db access time (enables the _should_check_db throttle)
        _update_last_db_access_time(
            key=db_access_time_key,
            value=response_dict,
            last_db_access_time=last_db_access_time,
        )
        return _response
    except Exception as e:  # if user not in db
        _log_budget_lookup_failure("user", e)
        raise ValueError(
            f"User doesn't exist in db. 'user_id'={user_id}. Create user via `/user/new` call. Got error - {e}"
        )
async def _cache_management_object(
    key: str,
    value: BaseModel,
    user_api_key_cache: DualCache,
    proxy_logging_obj: Optional[ProxyLogging],
):
    """Write a management object (team/key/...) to the auth cache with the standard TTL."""
    ttl = DEFAULT_MANAGEMENT_OBJECT_IN_MEMORY_CACHE_TTL
    await user_api_key_cache.async_set_cache(key=key, value=value, ttl=ttl)
async def _cache_team_object(
    team_id: str,
    team_table: LiteLLM_TeamTableCachedObj,
    user_api_key_cache: DualCache,
    proxy_logging_obj: Optional[ProxyLogging],
):
    """Cache a team object under `team_id:<id>`, stamping its refresh time first."""
    ## CACHE REFRESH TIME!
    # record when this copy was pulled so staleness can be detected later
    team_table.last_refreshed_at = time.time()
    await _cache_management_object(
        key="team_id:{}".format(team_id),
        value=team_table,
        user_api_key_cache=user_api_key_cache,
        proxy_logging_obj=proxy_logging_obj,
    )
async def _cache_key_object(
    hashed_token: str,
    user_api_key_obj: UserAPIKeyAuth,
    user_api_key_cache: DualCache,
    proxy_logging_obj: Optional[ProxyLogging],
):
    """Cache a key object under its hashed token, stamping its refresh time first."""
    ## CACHE REFRESH TIME
    # record when this copy was pulled so staleness can be detected later
    user_api_key_obj.last_refreshed_at = time.time()
    await _cache_management_object(
        key=hashed_token,
        value=user_api_key_obj,
        user_api_key_cache=user_api_key_cache,
        proxy_logging_obj=proxy_logging_obj,
    )
async def _delete_cache_key_object(
    hashed_token: str,
    user_api_key_cache: DualCache,
    proxy_logging_obj: Optional[ProxyLogging],
):
    """Evict a key object from the in-memory cache and, when available, Redis."""
    user_api_key_cache.delete_cache(key=hashed_token)
    ## UPDATE REDIS CACHE ##
    if proxy_logging_obj is None:
        return
    await proxy_logging_obj.internal_usage_cache.dual_cache.async_delete_cache(
        key=hashed_token
    )
@log_db_metrics
async def _get_team_db_check(
    team_id: str, prisma_client: PrismaClient, team_id_upsert: Optional[bool] = None
):
    """
    Fetch a team row by id; optionally auto-create the team when missing.

    When `team_id_upsert` is truthy and no row exists, the team is created
    through the normal management endpoint as a proxy-admin (so all the
    usual creation side effects apply) and returned as a LiteLLM_TeamTable.
    """
    team_row = await prisma_client.db.litellm_teamtable.find_unique(
        where={"team_id": team_id}
    )
    if team_row is not None or not team_id_upsert:
        return team_row

    # Team missing + upsert requested: create it via the management endpoint.
    from litellm.proxy.management_endpoints.team_endpoints import new_team

    created_team_dict = await new_team(
        data=NewTeamRequest(team_id=team_id),
        http_request=Request(scope={"type": "http"}),
        user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
    )
    return LiteLLM_TeamTable(**created_team_dict)
async def _get_team_object_from_db(team_id: str, prisma_client: PrismaClient):
    """Read a team row straight from the DB — no cache, no upsert."""
    team_row = await prisma_client.db.litellm_teamtable.find_unique(
        where={"team_id": team_id}
    )
    return team_row
async def _get_team_object_from_user_api_key_cache(
    team_id: str,
    prisma_client: PrismaClient,
    user_api_key_cache: DualCache,
    last_db_access_time: LimitedSizeOrderedDict,
    db_cache_expiry: int,
    proxy_logging_obj: Optional[ProxyLogging],
    key: str,
    team_id_upsert: Optional[bool] = None,
) -> LiteLLM_TeamTableCachedObj:
    """
    Fetch a team from the DB (with negative-lookup throttling) and cache it.

    Args:
        team_id: Team to resolve.
        prisma_client: DB client used for the lookup / optional upsert.
        user_api_key_cache: Cache the resolved team is written back to.
        last_db_access_time: Tracks recent lookups so repeated misses don't
            hammer the DB.
        db_cache_expiry: Seconds before a "not found" entry is re-checked.
        proxy_logging_obj: Optional proxy logging object.
        key: Cache key ("team_id:<id>") used for throttling and caching.
        team_id_upsert: If True, auto-create the team when missing.

    Returns:
        LiteLLM_TeamTableCachedObj for the team.

    Raises:
        Exception: when the team is not found (or the lookup was
            throttled); get_team_object translates this into an HTTP 404.
    """
    db_access_time_key = key
    # skip the DB entirely if this key recently resolved to "not found"
    should_check_db = _should_check_db(
        key=db_access_time_key,
        last_db_access_time=last_db_access_time,
        db_cache_expiry=db_cache_expiry,
    )
    if should_check_db:
        response = await _get_team_db_check(
            team_id=team_id, prisma_client=prisma_client, team_id_upsert=team_id_upsert
        )
    else:
        response = None
    if response is None:
        # deliberate bare raise: the caller converts this into a 404
        raise Exception
    _response = LiteLLM_TeamTableCachedObj(**response.dict())
    # Load object_permission if object_permission_id exists but object_permission is not loaded
    if _response.object_permission_id and not _response.object_permission:
        try:
            _response.object_permission = await get_object_permission(
                object_permission_id=_response.object_permission_id,
                prisma_client=prisma_client,
                user_api_key_cache=user_api_key_cache,
                parent_otel_span=None,
                proxy_logging_obj=proxy_logging_obj,
            )
        except Exception as e:
            # best-effort: a failed permission load shouldn't fail team resolution
            verbose_proxy_logger.debug(
                f"Failed to load object_permission for team {team_id} with object_permission_id={_response.object_permission_id}: {e}"
            )
    # save the team object to cache
    await _cache_team_object(
        team_id=team_id,
        team_table=_response,
        user_api_key_cache=user_api_key_cache,
        proxy_logging_obj=proxy_logging_obj,
    )
    # save to db access time (enables the _should_check_db throttle)
    _update_last_db_access_time(
        key=db_access_time_key,
        value=_response,
        last_db_access_time=last_db_access_time,
    )
    return _response
async def _get_team_object_from_cache(
    key: str,
    proxy_logging_obj: Optional[ProxyLogging],
    user_api_key_cache: DualCache,
    parent_otel_span: Optional[Span],
) -> Optional[LiteLLM_TeamTableCachedObj]:
    """
    Look up a cached team object, preferring the shared (Redis-backed)
    cache over the local in-memory cache. Returns None on a full miss.
    """
    cached_value: Optional[LiteLLM_TeamTableCachedObj] = None

    ## CHECK REDIS CACHE ##
    dual_cache = (
        proxy_logging_obj.internal_usage_cache.dual_cache
        if proxy_logging_obj is not None
        else None
    )
    if dual_cache:
        cached_value = await dual_cache.async_get_cache(
            key=key, parent_otel_span=parent_otel_span
        )

    # fall back to the local in-memory cache
    if cached_value is None:
        cached_value = await user_api_key_cache.async_get_cache(key=key)

    # normalize: entries may be stored as dicts or as model instances
    if isinstance(cached_value, dict):
        return LiteLLM_TeamTableCachedObj(**cached_value)
    if isinstance(cached_value, LiteLLM_TeamTableCachedObj):
        return cached_value
    return None
async def get_team_object(
    team_id: str,
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
    parent_otel_span: Optional[Span] = None,
    proxy_logging_obj: Optional[ProxyLogging] = None,
    check_cache_only: Optional[bool] = None,
    check_db_only: Optional[bool] = None,
    team_id_upsert: Optional[bool] = None,
) -> LiteLLM_TeamTableCachedObj:
    """
    Resolve a team by id, cache-first then DB.

    - check_cache_only: never touch the DB; 404 on a cache miss.
    - check_db_only: skip the cache read and go straight to the DB.
    - team_id_upsert: auto-create the team if it doesn't exist in the DB.

    Raises:
        HTTPException(404): team not found in cache (check_cache_only) or db.
    """
    if prisma_client is None:
        raise Exception(
            "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
        )

    cache_key = "team_id:{}".format(team_id)

    # 1) cache lookup (unless explicitly bypassed)
    if not check_db_only:
        team_from_cache = await _get_team_object_from_cache(
            key=cache_key,
            proxy_logging_obj=proxy_logging_obj,
            user_api_key_cache=user_api_key_cache,
            parent_otel_span=parent_otel_span,
        )
        if team_from_cache is not None:
            return team_from_cache
        if check_cache_only:
            raise HTTPException(
                status_code=404,
                detail={
                    "error": f"Team doesn't exist in cache + check_cache_only=True. Team={team_id}."
                },
            )

    # 2) DB lookup (with optional upsert); any failure maps to a 404
    try:
        return await _get_team_object_from_user_api_key_cache(
            team_id=team_id,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            proxy_logging_obj=proxy_logging_obj,
            last_db_access_time=last_db_access_time,
            db_cache_expiry=db_cache_expiry,
            key=cache_key,
            team_id_upsert=team_id_upsert,
        )
    except Exception:
        raise HTTPException(
            status_code=404,
            detail={
                "error": f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call."
            },
        )
async def _cache_access_object(
    access_group_id: str,
    access_group_table: LiteLLM_AccessGroupTable,
    user_api_key_cache: DualCache,
    proxy_logging_obj: Optional[ProxyLogging] = None,
):
    """Store an access-group object in the auth cache under its dedicated TTL."""
    await user_api_key_cache.async_set_cache(
        key="access_group_id:{}".format(access_group_id),
        value=access_group_table,
        ttl=DEFAULT_ACCESS_GROUP_CACHE_TTL,
    )
async def _delete_cache_access_object(
    access_group_id: str,
    user_api_key_cache: DualCache,
    proxy_logging_obj: Optional[ProxyLogging] = None,
):
    """Remove an access-group entry from both the local cache and Redis."""
    cache_key = "access_group_id:{}".format(access_group_id)
    user_api_key_cache.delete_cache(key=cache_key)
    ## UPDATE REDIS CACHE ##
    if proxy_logging_obj is None:
        return
    await proxy_logging_obj.internal_usage_cache.dual_cache.async_delete_cache(
        key=cache_key
    )
@log_db_metrics
async def get_access_object(
    access_group_id: str,
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
    proxy_logging_obj: Optional[ProxyLogging] = None,
) -> LiteLLM_AccessGroupTable:
    """
    Resolve an access group by id, cache-first then DB.

    Unlike get_team_object, this has no check_cache_only or check_db_only
    flags; it always follows cache-first-then-db semantics.

    Raises:
        HTTPException(404): access group not found in db or cache.
    """
    if prisma_client is None:
        raise Exception(
            "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
        )

    cache_key = "access_group_id:{}".format(access_group_id)

    # Always check cache first; entries may be dicts or model instances.
    cached = await user_api_key_cache.async_get_cache(key=cache_key)
    if isinstance(cached, dict):
        return LiteLLM_AccessGroupTable(**cached)
    if isinstance(cached, LiteLLM_AccessGroupTable):
        return cached

    # Cache miss -> fetch from DB.
    try:
        db_row = await prisma_client.db.litellm_accessgrouptable.find_unique(
            where={"access_group_id": access_group_id}
        )
        if db_row is None:
            raise HTTPException(
                status_code=404,
                detail={
                    "error": f"Access group doesn't exist in db. Access group={access_group_id}."
                },
            )
        access_group_obj = LiteLLM_AccessGroupTable(**db_row.dict())
        # Save to cache for subsequent lookups.
        await _cache_access_object(
            access_group_id=access_group_id,
            access_group_table=access_group_obj,
            user_api_key_cache=user_api_key_cache,
            proxy_logging_obj=proxy_logging_obj,
        )
        return access_group_obj
    except HTTPException:
        # preserve the 404 raised above
        raise
    except Exception as e:
        verbose_proxy_logger.exception(
            "Error getting access group for access_group_id: %s",
            access_group_id,
        )
        raise HTTPException(
            status_code=404,
            detail={
                "error": f"Access group doesn't exist in db. Access group={access_group_id}. Error: {e}"
            },
        )
@log_db_metrics
async def get_team_object_by_alias(
    team_alias: str,
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
    parent_otel_span: Optional["Span"] = None,
    proxy_logging_obj: Optional[ProxyLogging] = None,
) -> LiteLLM_TeamTableCachedObj:
    """
    Look up a team by its team_alias (name) in the database, cache-first.

    Args:
        team_alias: The team name/alias to look up
        prisma_client: Database client
        user_api_key_cache: Cache for storing results
        parent_otel_span: Optional OpenTelemetry span
        proxy_logging_obj: Optional proxy logging object

    Returns:
        LiteLLM_TeamTableCachedObj: The team object if found

    Raises:
        HTTPException: 404 if the team doesn't exist, 400 if multiple teams
            share the alias, 500 on unexpected lookup errors
    """
    if prisma_client is None:
        raise Exception(
            "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
        )
    # Check cache first (keyed by alias)
    cache_key = "team_alias:{}".format(team_alias)
    cached_team_obj = await _get_team_object_from_cache(
        key=cache_key,
        proxy_logging_obj=proxy_logging_obj,
        user_api_key_cache=user_api_key_cache,
        parent_otel_span=parent_otel_span,
    )
    if cached_team_obj is not None:
        return cached_team_obj
    # Query database by team_alias
    try:
        teams = await prisma_client.db.litellm_teamtable.find_many(
            where={"team_alias": team_alias}
        )
        if not teams:
            raise HTTPException(
                status_code=404,
                detail={
                    "error": f"Team with alias '{team_alias}' doesn't exist in db. Create team via `/team/new` call."
                },
            )
        # alias is not unique in the schema, so an ambiguous alias is a
        # caller error rather than an arbitrary pick
        if len(teams) > 1:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": f"Multiple teams found with alias '{team_alias}'. Please use team_id_jwt_field instead or ensure team aliases are unique."
                },
            )
        team = teams[0]
        team_obj = LiteLLM_TeamTableCachedObj(**team.model_dump())
        # Load object_permission if object_permission_id exists but object_permission is not loaded
        if team_obj.object_permission_id and not team_obj.object_permission:
            try:
                team_obj.object_permission = await get_object_permission(
                    object_permission_id=team_obj.object_permission_id,
                    prisma_client=prisma_client,
                    user_api_key_cache=user_api_key_cache,
                    parent_otel_span=parent_otel_span,
                    proxy_logging_obj=proxy_logging_obj,
                )
            except Exception as e:
                # best-effort: a failed permission load shouldn't fail team resolution
                verbose_proxy_logger.debug(
                    f"Failed to load object_permission for team {team_obj.team_id} with object_permission_id={team_obj.object_permission_id}: {e}"
                )
        # Cache the result by both alias and team_id
        await user_api_key_cache.async_set_cache(
            key=cache_key,
            value=team_obj,
            ttl=DEFAULT_IN_MEMORY_TTL,
        )
        # Also cache by team_id for consistency (id-based lookups hit too)
        team_id_cache_key = "team_id:{}".format(team_obj.team_id)
        await user_api_key_cache.async_set_cache(
            key=team_id_cache_key,
            value=team_obj,
            ttl=DEFAULT_IN_MEMORY_TTL,
        )
        return team_obj
    except HTTPException:
        # preserve the specific 404/400 raised above
        raise
    except Exception as e:
        verbose_proxy_logger.exception("Error looking up team by alias: %s", team_alias)
        raise HTTPException(
            status_code=500,
            detail={
                "error": f"Error looking up team by alias '{team_alias}': {str(e)}"
            },
        )
@log_db_metrics
async def get_org_object_by_alias(
    org_alias: str,
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
    parent_otel_span: Optional["Span"] = None,
    proxy_logging_obj: Optional[ProxyLogging] = None,
) -> Optional[LiteLLM_OrganizationTable]:
    """
    Resolve an organization by its organization_alias, cache-first then DB.

    Args:
        org_alias: The organization name/alias to look up
        prisma_client: Database client
        user_api_key_cache: Cache for storing results
        parent_otel_span: Optional OpenTelemetry span
        proxy_logging_obj: Optional proxy logging object

    Returns:
        LiteLLM_OrganizationTable if found, None otherwise

    Raises:
        HTTPException: 404 when no org matches, 400 when the alias is
            ambiguous, 500 on unexpected lookup errors
    """
    if prisma_client is None:
        raise Exception(
            "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
        )

    # Check cache first (keyed by alias); entries may be dicts or models.
    cache_key = "org_alias:{}".format(org_alias)
    cached = await user_api_key_cache.async_get_cache(key=cache_key)
    if isinstance(cached, dict):
        return LiteLLM_OrganizationTable(**cached)
    if isinstance(cached, LiteLLM_OrganizationTable):
        return cached

    # Query database by organization_alias
    try:
        matching_orgs = await prisma_client.db.litellm_organizationtable.find_many(
            where={"organization_alias": org_alias}
        )
        if not matching_orgs:
            raise HTTPException(
                status_code=404,
                detail={
                    "error": f"Organization with alias '{org_alias}' doesn't exist in db. Create organization via `/organization/new` call."
                },
            )
        # alias is not unique in the schema, so an ambiguous alias is a
        # caller error rather than an arbitrary pick
        if len(matching_orgs) > 1:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": f"Multiple organizations found with alias '{org_alias}'. Please use org_id_jwt_field instead or ensure organization aliases are unique."
                },
            )

        org_obj = LiteLLM_OrganizationTable(**matching_orgs[0].model_dump())

        # Cache by alias, and also by org_id so id-based lookups stay warm.
        await user_api_key_cache.async_set_cache(
            key=cache_key,
            value=org_obj.model_dump(),
            ttl=DEFAULT_IN_MEMORY_TTL,
        )
        await user_api_key_cache.async_set_cache(
            key="org_id:{}".format(org_obj.organization_id),
            value=org_obj.model_dump(),
            ttl=DEFAULT_IN_MEMORY_TTL,
        )
        return org_obj
    except HTTPException:
        # preserve the specific 404/400 raised above
        raise
    except Exception as e:
        verbose_proxy_logger.exception(
            "Error looking up organization by alias: %s", org_alias
        )
        raise HTTPException(
            status_code=500,
            detail={
                "error": f"Error looking up organization by alias '{org_alias}': {str(e)}"
            },
        )
class ExperimentalUIJWTToken:
    """
    Helpers for minting and decoding the proxy's encrypted session tokens
    (UI login and CLI). Tokens are UserAPIKeyAuth objects serialized to
    JSON and encrypted with the proxy's symmetric encrypt/decrypt helpers —
    not standards-based JWTs.
    """

    @staticmethod
    def get_experimental_ui_login_jwt_auth_token(user_info: LiteLLM_UserTable) -> str:
        """Mint an encrypted, short-lived (10 min) session token for the UI login flow."""
        from datetime import timedelta
        from litellm.proxy.common_utils.encrypt_decrypt_utils import (
            encrypt_value_helper,
        )
        if user_info.user_role is None:
            raise Exception("User role is required for experimental UI login")
        # Experimental UI flow uses fixed 10-min expiry for security (does not use LITELLM_UI_SESSION_DURATION)
        expiration_time = get_utc_datetime() + timedelta(minutes=10)
        # Format the expiration time as ISO 8601 string
        expires = expiration_time.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "+00:00"
        valid_token = UserAPIKeyAuth(
            token="ui-token",
            key_name="ui-token",
            key_alias="ui-token",
            max_budget=litellm.max_ui_session_budget,
            rpm_limit=100,  # allow user to have a conversation on test key pane of UI
            expires=expires,
            user_id=user_info.user_id,
            team_id="litellm-dashboard",
            models=user_info.models,
            max_parallel_requests=None,
            user_role=LitellmUserRoles(user_info.user_role),
        )
        # exclude_none keeps the encrypted payload compact
        return encrypt_value_helper(valid_token.model_dump_json(exclude_none=True))

    @staticmethod
    def get_cli_jwt_auth_token(
        user_info: LiteLLM_UserTable, team_id: Optional[str] = None
    ) -> str:
        """
        Generate a JWT token for CLI authentication with configurable expiration.

        The expiration time can be controlled via the LITELLM_CLI_JWT_EXPIRATION_HOURS
        environment variable (defaults to 24 hours).

        Args:
            user_info: User information from the database
            team_id: Team ID for the user (optional, uses user's team if available)

        Returns:
            Encrypted JWT token string
        """
        from datetime import timedelta
        from litellm.proxy.common_utils.encrypt_decrypt_utils import (
            encrypt_value_helper,
        )
        if user_info.user_role is None:
            raise Exception("User role is required for CLI JWT login")
        # Calculate expiration time (configurable via LITELLM_CLI_JWT_EXPIRATION_HOURS env var)
        expiration_time = get_utc_datetime() + timedelta(hours=CLI_JWT_EXPIRATION_HOURS)
        # Format the expiration time as ISO 8601 string
        expires = expiration_time.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "+00:00"
        # Use provided team_id, or fall back to user's teams if available
        _team_id = team_id
        if _team_id is None and hasattr(user_info, "teams") and user_info.teams:
            # Use first team if user has teams
            _team_id = user_info.teams[0] if len(user_info.teams) > 0 else None
        valid_token = UserAPIKeyAuth(
            token=CLI_JWT_TOKEN_NAME,
            key_name=CLI_JWT_TOKEN_NAME,
            key_alias=CLI_JWT_TOKEN_NAME,
            max_budget=litellm.max_ui_session_budget,
            expires=expires,
            user_id=user_info.user_id,
            team_id=_team_id,
            models=user_info.models,
            max_parallel_requests=None,
            user_role=LitellmUserRoles(user_info.user_role),
        )
        # exclude_none keeps the encrypted payload compact
        return encrypt_value_helper(valid_token.model_dump_json(exclude_none=True))

    @staticmethod
    def get_key_object_from_ui_hash_key(
        hashed_token: str,
    ) -> Optional[UserAPIKeyAuth]:
        """
        Decrypt a UI session token back into a UserAPIKeyAuth.

        Returns None when decryption fails (e.g. not a UI token); raises
        when the decrypted payload is not valid UserAPIKeyAuth JSON.
        """
        import json
        from litellm.proxy.auth.user_api_key_auth import UserAPIKeyAuth
        from litellm.proxy.common_utils.encrypt_decrypt_utils import (
            decrypt_value_helper,
        )
        # exception_type="debug" -> a non-UI token just logs at debug level
        decrypted_token = decrypt_value_helper(
            hashed_token, key="ui_hash_key", exception_type="debug"
        )
        if decrypted_token is None:
            return None
        try:
            return UserAPIKeyAuth(**json.loads(decrypted_token))
        except Exception as e:
            raise Exception(
                f"Invalid hash key. Hash key={hashed_token}. Decrypted token={decrypted_token}. Error: {e}"
            )
async def _fetch_key_object_from_db_with_reconnect(
    hashed_token: str,
    prisma_client: PrismaClient,
    parent_otel_span: Optional[Span],
    proxy_logging_obj: Optional[ProxyLogging],
) -> Optional[BaseModel]:
    """
    Fetch key object from DB and retry once if a DB connection error can be healed.

    Args:
        hashed_token: Hashed virtual key to look up in the combined view.
        prisma_client: DB client used for the lookup (and possible reconnect).
        parent_otel_span: Optional OpenTelemetry span for tracing.
        proxy_logging_obj: Optional proxy logging object.

    Returns:
        The key record from the combined view, or None when not found.

    Raises:
        Exception: re-raises the original DB error when it is not a
            transport error, or when reconnection does not succeed.
    """
    try:
        return await prisma_client.get_data(
            token=hashed_token,
            table_name="combined_view",
            parent_otel_span=parent_otel_span,
            proxy_logging_obj=proxy_logging_obj,
        )
    except Exception as e:
        # only transport-level failures are worth a reconnect attempt
        if PrismaDBExceptionHandler.is_database_transport_error(e):
            did_reconnect = False
            # hasattr guard: older/patched clients may lack reconnect support
            if hasattr(prisma_client, "attempt_db_reconnect"):
                # timeouts read via getattr with defensive defaults so a
                # misconfigured client attribute can't break the auth path
                auth_reconnect_timeout = getattr(
                    prisma_client, "_db_auth_reconnect_timeout_seconds", 2.0
                )
                if not isinstance(auth_reconnect_timeout, (int, float)):
                    auth_reconnect_timeout = 2.0
                auth_reconnect_lock_timeout = getattr(
                    prisma_client, "_db_auth_reconnect_lock_timeout_seconds", 0.1
                )
                if not isinstance(auth_reconnect_lock_timeout, (int, float)):
                    auth_reconnect_lock_timeout = 0.1
                did_reconnect = await prisma_client.attempt_db_reconnect(
                    reason="auth_get_key_object_lookup_failure",
                    timeout_seconds=auth_reconnect_timeout,
                    lock_timeout_seconds=auth_reconnect_lock_timeout,
                )
            if did_reconnect:
                # exactly one retry, after a successful reconnect
                return await prisma_client.get_data(
                    token=hashed_token,
                    table_name="combined_view",
                    parent_otel_span=parent_otel_span,
                    proxy_logging_obj=proxy_logging_obj,
                )
        raise
@log_db_metrics
async def get_jwt_key_mapping_object(
    jwt_claim_name: str,
    jwt_claim_value: str,
    prisma_client: PrismaClient,
) -> Optional[str]:
    """
    Lookup a JWT-to-virtual-key mapping from the database.

    Returns the hashed token (str) if a matching active mapping is found, else None.
    """
    mapping_row = await prisma_client.db.litellm_jwtkeymapping.find_first(
        where={
            "jwt_claim_name": jwt_claim_name,
            "jwt_claim_value": jwt_claim_value,
            "is_active": True,
        }
    )
    return mapping_row.token if mapping_row is not None else None
@log_db_metrics
async def get_key_object(
    hashed_token: str,
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
    parent_otel_span: Optional[Span] = None,
    proxy_logging_obj: Optional[ProxyLogging] = None,
    check_cache_only: Optional[bool] = None,
) -> UserAPIKeyAuth:
    """
    Resolve a virtual key (hashed token) from cache or the DB combined view.

    Args:
        hashed_token: Hashed virtual key to resolve.
        prisma_client: DB client; required (raises if None).
        user_api_key_cache: Cache consulted first and updated on a DB hit.
        parent_otel_span: Optional OpenTelemetry span for tracing.
        proxy_logging_obj: Optional proxy logging object.
        check_cache_only: If True, never touch the DB; raise on a cache miss.

    Returns:
        UserAPIKeyAuth built from the cached/DB key record, with its
        object_permission loaded when an object_permission_id is present.

    Raises:
        Exception: no DB connected, or cache miss with check_cache_only=True.
        ProxyException: 401 when the token is not found in the DB.
    """
    if prisma_client is None:
        raise Exception(
            "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
        )
    # check if in cache
    key = hashed_token
    cached_key_obj: Optional[UserAPIKeyAuth] = await user_api_key_cache.async_get_cache(
        key=key
    )
    if cached_key_obj is not None:
        # cache entries may be stored as dicts or as model instances
        if isinstance(cached_key_obj, dict):
            return UserAPIKeyAuth(**cached_key_obj)
        elif isinstance(cached_key_obj, UserAPIKeyAuth):
            return cached_key_obj
    if check_cache_only:
        raise Exception(
            f"Key doesn't exist in cache + check_cache_only=True. key={key}."
        )
    # else, check db (retries once if the DB connection can be healed)
    _valid_token: Optional[BaseModel] = await _fetch_key_object_from_db_with_reconnect(
        hashed_token=hashed_token,
        prisma_client=prisma_client,
        parent_otel_span=parent_otel_span,
        proxy_logging_obj=proxy_logging_obj,
    )
    if _valid_token is None:
        raise ProxyException(
            message="Authentication Error, Invalid proxy server token passed. key={}, not found in db. Create key via `/key/generate` call.".format(
                hashed_token
            ),
            type=ProxyErrorTypes.token_not_found_in_db,
            param="key",
            code=status.HTTP_401_UNAUTHORIZED,
        )
    _response = UserAPIKeyAuth(**_valid_token.model_dump(exclude_none=True))
    # Load object_permission if object_permission_id exists but object_permission is not loaded
    if _response.object_permission_id and not _response.object_permission:
        try:
            _response.object_permission = await get_object_permission(
                object_permission_id=_response.object_permission_id,
                prisma_client=prisma_client,
                user_api_key_cache=user_api_key_cache,
                parent_otel_span=parent_otel_span,
                proxy_logging_obj=proxy_logging_obj,
            )
        except Exception as e:
            # best-effort: a failed permission load shouldn't fail key auth
            verbose_proxy_logger.debug(
                f"Failed to load object_permission for key with object_permission_id={_response.object_permission_id}: {e}"
            )
    # save the key object to cache
    await _cache_key_object(
        hashed_token=hashed_token,
        user_api_key_obj=_response,
        user_api_key_cache=user_api_key_cache,
        proxy_logging_obj=proxy_logging_obj,
    )
    return _response
@log_db_metrics
async def get_object_permission(
    object_permission_id: str,
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
    parent_otel_span: Optional[Span] = None,
    proxy_logging_obj: Optional[ProxyLogging] = None,
) -> Optional[LiteLLM_ObjectPermissionTable]:
    """
    Resolve an object-permission row by id, cache-first then DB.

    Returns None when the id is unknown, or — best-effort by design — when
    the DB lookup itself errors.
    """
    if prisma_client is None:
        raise Exception(
            "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
        )

    cache_key = "object_permission_id:{}".format(object_permission_id)

    # cache entries may be stored as dicts or as model instances
    cached_perm = await user_api_key_cache.async_get_cache(key=cache_key)
    if isinstance(cached_perm, dict):
        return LiteLLM_ObjectPermissionTable(**cached_perm)
    if isinstance(cached_perm, LiteLLM_ObjectPermissionTable):
        return cached_perm

    # cache miss -> DB; errors are swallowed and reported as "not found"
    try:
        db_row = await prisma_client.db.litellm_objectpermissiontable.find_unique(
            where={"object_permission_id": object_permission_id}
        )
        if db_row is None:
            return None
        await user_api_key_cache.async_set_cache(
            key=cache_key,
            value=db_row.model_dump(),
            ttl=DEFAULT_MANAGEMENT_OBJECT_IN_MEMORY_CACHE_TTL,
        )
        return LiteLLM_ObjectPermissionTable(**db_row.dict())
    except Exception:
        return None
@log_db_metrics
async def get_org_object(
    org_id: str,
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
    parent_otel_span: Optional[Span] = None,
    proxy_logging_obj: Optional[ProxyLogging] = None,
    include_budget_table: bool = False,
) -> Optional[LiteLLM_OrganizationTable]:
    """
    - Check if org id in proxy Org Table
    - if valid, return LiteLLM_OrganizationTable object
    - if not, then raise an error

    Args:
        org_id: Organization ID to look up
        prisma_client: Database client
        user_api_key_cache: Cache for storing results
        parent_otel_span: Optional OpenTelemetry span
        proxy_logging_obj: Optional proxy logging object
        include_budget_table: If True, includes litellm_budget_table in the query

    Raises:
        Exception: when no DB is connected, or the organization doesn't exist.
    """
    if prisma_client is None:
        raise Exception(
            "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
        )
    if not isinstance(org_id, str):
        return None
    # Use different cache key if budget table is included, so the two shapes
    # of the cached object never collide.
    cache_key = "org_id:{}".format(org_id)
    if include_budget_table:
        cache_key = "org_id:{}:with_budget".format(org_id)
    # check if in cache
    # BUG FIX: this read was previously missing `await`, so it bound a
    # coroutine object — the isinstance checks below never matched, every
    # call fell through to the DB, and an un-awaited coroutine leaked.
    cached_org_obj = await user_api_key_cache.async_get_cache(key=cache_key)
    if cached_org_obj is not None:
        if isinstance(cached_org_obj, dict):
            return LiteLLM_OrganizationTable(**cached_org_obj)
        elif isinstance(cached_org_obj, LiteLLM_OrganizationTable):
            return cached_org_obj
    # else, check db
    try:
        query_kwargs: Dict[str, Any] = {"where": {"organization_id": org_id}}
        if include_budget_table:
            query_kwargs["include"] = {"litellm_budget_table": True}
        response = await prisma_client.db.litellm_organizationtable.find_unique(
            **query_kwargs
        )
        if response is None:
            # converted into the descriptive Exception below
            raise Exception
        # Cache the result (as a dict when possible, matching other getters)
        await user_api_key_cache.async_set_cache(
            key=cache_key,
            value=(
                response.model_dump() if hasattr(response, "model_dump") else response
            ),
            ttl=DEFAULT_IN_MEMORY_TTL,
        )
        # NOTE(review): the DB path returns the raw Prisma row (not a
        # LiteLLM_OrganizationTable) — kept as-is for backward compatibility.
        return response
    except Exception:
        raise Exception(
            f"Organization doesn't exist in db. Organization={org_id}. Create organization via `/organization/new` call."
        )
async def _get_resources_from_access_groups(
    access_group_ids: List[str],
    resource_field: Literal[
        "access_model_names", "access_mcp_server_ids", "access_agent_ids"
    ],
    prisma_client: Optional[PrismaClient] = None,
    user_api_key_cache: Optional[DualCache] = None,
    proxy_logging_obj: Optional[ProxyLogging] = None,
) -> List[str]:
    """
    Fetch access groups by their IDs (from cache or DB) and collect
    the specified resource field across all of them.
    Args:
        access_group_ids: List of access group IDs to fetch
        resource_field: Which resource list to extract from each access group
            - "access_model_names": model names (for model access checks)
            - "access_mcp_server_ids": MCP server IDs (for MCP access checks)
            - "access_agent_ids": agent IDs (for agent access checks)
        prisma_client: Optional PrismaClient (lazy-imported from proxy_server if None)
        user_api_key_cache: Optional DualCache (lazy-imported from proxy_server if None)
        proxy_logging_obj: Optional ProxyLogging (lazy-imported from proxy_server if None)
    Returns:
        Deduplicated list of resource identifiers from all resolved access groups.
    """
    if not access_group_ids:
        return []

    # Lazy import to avoid circular imports
    if prisma_client is None or user_api_key_cache is None:
        from litellm.proxy.proxy_server import prisma_client as _prisma_client
        from litellm.proxy.proxy_server import proxy_logging_obj as _proxy_logging_obj
        from litellm.proxy.proxy_server import user_api_key_cache as _user_api_key_cache

        prisma_client = prisma_client or _prisma_client
        user_api_key_cache = user_api_key_cache or _user_api_key_cache
        proxy_logging_obj = proxy_logging_obj or _proxy_logging_obj

    if user_api_key_cache is None:
        return []

    resources: List[str] = []
    for ag_id in access_group_ids:
        try:
            ag = await get_access_object(
                access_group_id=ag_id,
                prisma_client=prisma_client,
                user_api_key_cache=user_api_key_cache,
                proxy_logging_obj=proxy_logging_obj,
            )
            # Fix: guard against the field existing but being None -
            # `extend(None)` would raise a TypeError and the group would be
            # wrongly reported as "could not fetch".
            resources.extend(getattr(ag, resource_field, None) or [])
        except Exception:
            # best-effort: a single unresolvable group shouldn't fail the check
            verbose_proxy_logger.debug(
                "Could not fetch access group %s for resource field %s",
                ag_id,
                resource_field,
            )
    # deduplicate across groups
    return list(set(resources))
async def _get_models_from_access_groups(
    access_group_ids: List[str],
    prisma_client: Optional[PrismaClient] = None,
    user_api_key_cache: Optional[DualCache] = None,
    proxy_logging_obj: Optional[ProxyLogging] = None,
) -> List[str]:
    """
    Resolve the model names granted by a set of unified access groups.

    Matching is done by model name for backwards compatibility.
    """
    return await _get_resources_from_access_groups(
        resource_field="access_model_names",
        access_group_ids=access_group_ids,
        prisma_client=prisma_client,
        user_api_key_cache=user_api_key_cache,
        proxy_logging_obj=proxy_logging_obj,
    )
async def _get_mcp_server_ids_from_access_groups(
    access_group_ids: List[str],
    prisma_client: Optional[PrismaClient] = None,
    user_api_key_cache: Optional[DualCache] = None,
    proxy_logging_obj: Optional[ProxyLogging] = None,
) -> List[str]:
    """
    Resolve the MCP server IDs granted by a set of unified access groups.

    MCPs are matched by server ID.
    """
    return await _get_resources_from_access_groups(
        resource_field="access_mcp_server_ids",
        access_group_ids=access_group_ids,
        prisma_client=prisma_client,
        user_api_key_cache=user_api_key_cache,
        proxy_logging_obj=proxy_logging_obj,
    )
async def _get_agent_ids_from_access_groups(
    access_group_ids: List[str],
    prisma_client: Optional[PrismaClient] = None,
    user_api_key_cache: Optional[DualCache] = None,
    proxy_logging_obj: Optional[ProxyLogging] = None,
) -> List[str]:
    """
    Resolve the agent IDs granted by a set of unified access groups.

    Agents are matched by agent ID.
    """
    return await _get_resources_from_access_groups(
        resource_field="access_agent_ids",
        access_group_ids=access_group_ids,
        prisma_client=prisma_client,
        user_api_key_cache=user_api_key_cache,
        proxy_logging_obj=proxy_logging_obj,
    )
def _check_model_access_helper(
    model: str,
    llm_router: Optional[Router],
    models: List[str],
    team_model_aliases: Optional[Dict[str, str]] = None,
    team_id: Optional[str] = None,
) -> bool:
    """
    Return True if `model` is permitted by the given allowed-model list.

    Checks, in order: router access groups named in `models`, team model
    aliases, wildcard patterns, "all models" grants (empty list / "*" /
    all-proxy-models), and finally an exact model-name match.
    """
    from collections import defaultdict

    access_groups: Dict[str, List[str]] = defaultdict(list)
    if llm_router:
        access_groups = llm_router.get_model_access_groups(
            model_name=model, team_id=team_id
        )

    # If the router resolved access groups for this model, any allowed entry
    # naming one of those groups grants access.
    if llm_router is not None and len(access_groups) > 0:
        if any(m in access_groups for m in models):
            return True

    # Drop entries that are access-group names; keep plain model entries.
    filtered_models = [m for m in models if m not in access_groups]

    if _model_in_team_aliases(model=model, team_model_aliases=team_model_aliases):
        return True

    if _model_matches_any_wildcard_pattern_in_list(
        model=model, allowed_model_list=filtered_models
    ):
        return True

    # "all models" access: both lists empty, an explicit "*", or the special
    # all-proxy-models sentinel.
    has_all_model_access = (
        (len(filtered_models) == 0 and len(models) == 0)
        or "*" in filtered_models
        or SpecialModelNames.all_proxy_models.value in filtered_models
    )

    if model is not None and model in filtered_models:
        return True
    return has_all_model_access or model is None
def _can_object_call_model(
    model: Union[str, List[str]],
    llm_router: Optional[Router],
    models: List[str],
    team_model_aliases: Optional[Dict[str, str]] = None,
    team_id: Optional[str] = None,
    object_type: Literal["user", "team", "key", "org", "project"] = "user",
    fallback_depth: int = 0,
) -> Literal[True]:
    """
    Verify that an object (user/team/key/org/project) may call `model`.

    Args:
        model: model name, or a list of model names (each checked individually)
        llm_router: router used to resolve aliases and access groups
        models: the object's allowed-model list
        team_model_aliases: team-level alias map, also consulted
        team_id: passed through when resolving router access groups
        object_type: selects the exception type raised on denial
        fallback_depth: recursion guard for nested model lists

    Returns:
        True when access is allowed.

    Raises:
        Exception: when the recursion guard is exceeded.
        ProxyException: when the object may not access the model.
    """
    if fallback_depth >= DEFAULT_MAX_RECURSE_DEPTH:
        raise Exception(
            "Unable to parse model, max fallback depth exceeded - received model: {}".format(
                model
            )
        )

    # A list of models: every entry must individually pass the check.
    if isinstance(model, list):
        for single_model in model:
            _can_object_call_model(
                model=single_model,
                llm_router=llm_router,
                models=models,
                team_model_aliases=team_model_aliases,
                team_id=team_id,
                object_type=object_type,
                fallback_depth=fallback_depth + 1,
            )
        return True

    # Resolve aliases: access is granted if either the alias or the
    # underlying model is in the allowed list.
    candidate_models = [model]
    if model in litellm.model_alias_map:
        candidate_models.append(litellm.model_alias_map[model])
    elif llm_router and model in llm_router.model_group_alias:
        resolved = llm_router._get_model_from_alias(model)
        if resolved:
            candidate_models.append(resolved)

    for candidate in candidate_models:
        if _check_model_access_helper(
            model=candidate,
            llm_router=llm_router,
            models=models,
            team_model_aliases=team_model_aliases,
            team_id=team_id,
        ):
            return True

    raise ProxyException(
        message=f"{object_type} not allowed to access model. This {object_type} can only access models={models}. Tried to access {model}",
        type=ProxyErrorTypes.get_model_access_error_type_for_object(
            object_type=object_type
        ),
        param="model",
        code=status.HTTP_401_UNAUTHORIZED,
    )
def _model_in_team_aliases(
model: str, team_model_aliases: Optional[Dict[str, str]] = None
) -> bool:
"""
Returns True if `model` being accessed is an alias of a team model
- `model=gpt-4o`
- `team_model_aliases={"gpt-4o": "gpt-4o-team-1"}`
- returns True
- `model=gp-4o`
- `team_model_aliases={"o-3": "o3-preview"}`
- returns False
"""
if team_model_aliases:
if model in team_model_aliases:
return True
return False
async def can_key_call_model(
    model: Union[str, List[str]],
    llm_model_list: Optional[list],
    valid_token: UserAPIKeyAuth,
    llm_router: Optional[litellm.Router],
) -> Literal[True]:
    """
    Verify a virtual key may call `model`.

    1. Check the key's own allowed-model list first.
    2. On denial, retry with models granted via the key's access groups.

    Returns:
        True when allowed.
    Raises:
        ProxyException when neither the key nor its access groups allow the model.
    """
    try:
        return _can_object_call_model(
            model=model,
            llm_router=llm_router,
            models=valid_token.models,
            team_model_aliases=valid_token.team_model_aliases,
            team_id=valid_token.team_id,
            object_type="key",
        )
    except ProxyException:
        # Denied natively - see if any of the key's access groups grant the model.
        group_ids = valid_token.access_group_ids or []
        if group_ids:
            group_models = await _get_models_from_access_groups(
                access_group_ids=group_ids,
            )
            if group_models:
                return _can_object_call_model(
                    model=model,
                    llm_router=llm_router,
                    models=group_models,
                    team_model_aliases=valid_token.team_model_aliases,
                    team_id=valid_token.team_id,
                    object_type="key",
                )
        raise
def can_org_access_model(
    model: str,
    org_object: Optional[LiteLLM_OrganizationTable],
    llm_router: Optional[Router],
    team_model_aliases: Optional[Dict[str, str]] = None,
) -> Literal[True]:
    """
    Verify the organization may access `model`.

    Raises ProxyException when the organization's model list denies access.
    """
    org_models = org_object.models if org_object else []
    return _can_object_call_model(
        model=model,
        llm_router=llm_router,
        models=org_models,
        team_model_aliases=team_model_aliases,
        object_type="org",
    )
async def can_team_access_model(
    model: Union[str, List[str]],
    team_object: Optional[LiteLLM_TeamTable],
    llm_router: Optional[Router],
    team_model_aliases: Optional[Dict[str, str]] = None,
) -> Literal[True]:
    """
    Verify the team may access `model`.

    1. Check the team's own allowed-model list first.
    2. On denial, retry with models granted via the team's access groups.

    Returns:
        True when allowed.
    Raises:
        ProxyException when neither the team nor its access groups allow the model.
    """
    team_id = team_object.team_id if team_object else None
    try:
        return _can_object_call_model(
            model=model,
            llm_router=llm_router,
            models=team_object.models if team_object else [],
            team_model_aliases=team_model_aliases,
            team_id=team_id,
            object_type="team",
        )
    except ProxyException:
        # Denied natively - see if any of the team's access groups grant the model.
        group_ids: List[str] = []
        if team_object is not None:
            group_ids = team_object.access_group_ids or []
        if group_ids:
            group_models = await _get_models_from_access_groups(
                access_group_ids=group_ids,
            )
            if group_models:
                return _can_object_call_model(
                    model=model,
                    llm_router=llm_router,
                    models=group_models,
                    team_model_aliases=team_model_aliases,
                    team_id=team_id,
                    object_type="team",
                )
        raise
def can_project_access_model(
    model: Union[str, List[str]],
    project_object: LiteLLM_ProjectTableCachedObj,
    llm_router: Optional[Router],
) -> Literal[True]:
    """
    Verify the project may access `model`.

    Raises ProxyException if access is denied.
    """
    project_models = project_object.models if project_object else []
    return _can_object_call_model(
        model=model,
        llm_router=llm_router,
        models=project_models,
        object_type="project",
    )
async def can_user_call_model(
    model: Union[str, List[str]],
    llm_router: Optional[Router],
    user_object: Optional[LiteLLM_UserTable],
) -> Literal[True]:
    """
    Verify the user may call `model`.

    Users flagged with the `no_default_models` special name may only use
    team-scoped models, so user-level access is rejected outright for them.

    Returns:
        True when allowed (or when there is no user record to enforce).
    Raises:
        ProxyException when the user's model list denies access.
    """
    if user_object is None:
        # No user record -> nothing to enforce at the user level.
        return True

    if SpecialModelNames.no_default_models.value in user_object.models:
        raise ProxyException(
            message=f"User not allowed to access model. No default model access, only team models allowed. Tried to access {model}",
            type=ProxyErrorTypes.key_model_access_denied,
            param="model",
            code=status.HTTP_401_UNAUTHORIZED,
        )

    return _can_object_call_model(
        model=model,
        llm_router=llm_router,
        models=user_object.models,
        object_type="user",
    )
async def is_valid_fallback_model(
    model: str,
    llm_router: Optional[Router],
    user_model: Optional[str],
) -> Literal[True]:
    """
    Validate a fallback model by attempting to route a dummy request to it.

    Raises if the model cannot be routed - this surfaces misconfigured
    fallback models early instead of at failover time.
    """
    dummy_request = {
        "model": model,
        "messages": [{"role": "user", "content": "Who was Alexander?"}],
    }
    await route_request(
        data=dummy_request,
        llm_router=llm_router,
        user_model=user_model,
        route_type="acompletion",  # route type shouldn't affect the fallback model check
    )
    return True
async def _virtual_key_max_budget_check(
    valid_token: UserAPIKeyAuth,
    proxy_logging_obj: ProxyLogging,
    user_obj: Optional[LiteLLM_UserTable] = None,
):
    """
    Enforce a virtual key's max budget.

    When the key has both a spend and a max budget set, a "token_budget"
    alert is dispatched (non-blocking, via asyncio.create_task) and the
    request is rejected once spend has reached the budget.

    Args:
        valid_token: the resolved virtual key being checked.
        proxy_logging_obj: used to dispatch budget alerts asynchronously.
        user_obj: optional user record; only used to attach the user's
            email to the alert payload.

    Raises:
        litellm.BudgetExceededError: if the token's spend >= its max budget.
    """
    if valid_token.spend is not None and valid_token.max_budget is not None:
        ####################################
        # collect information for alerting #
        ####################################
        user_email = None
        # Check if the token has any user id information
        if user_obj is not None:
            user_email = user_obj.user_email

        call_info = CallInfo(
            token=valid_token.token,
            spend=valid_token.spend,
            max_budget=valid_token.max_budget,
            soft_budget=valid_token.soft_budget,
            user_id=valid_token.user_id,
            team_id=valid_token.team_id,
            organization_id=valid_token.org_id,
            user_email=user_email,
            key_alias=valid_token.key_alias,
            event_group=Litellm_EntityType.KEY,
        )
        # fire-and-forget: alerting must not block the request path
        asyncio.create_task(
            proxy_logging_obj.budget_alerts(
                type="token_budget",
                user_info=call_info,
            )
        )

        ####################################
        # enforce the hard budget limit    #
        ####################################
        if valid_token.spend >= valid_token.max_budget:
            raise litellm.BudgetExceededError(
                current_cost=valid_token.spend,
                max_budget=valid_token.max_budget,
            )
async def _virtual_key_soft_budget_check(
    valid_token: UserAPIKeyAuth,
    proxy_logging_obj: ProxyLogging,
    user_obj: Optional[LiteLLM_UserTable] = None,
):
    """
    Trigger a (non-blocking) alert when the token has crossed its soft budget.

    Soft budgets never block the request - this only dispatches an alert.

    Args:
        valid_token: the resolved virtual key being checked.
        proxy_logging_obj: used to dispatch budget alerts asynchronously.
        user_obj: optional user record; only used to attach the user's email.
    """
    # Fix: guard `spend is not None` explicitly - comparing None >= float
    # raises a TypeError (the max-budget check guards the same way).
    if (
        valid_token.soft_budget
        and valid_token.spend is not None
        and valid_token.spend >= valid_token.soft_budget
    ):
        verbose_proxy_logger.debug(
            "Crossed Soft Budget for token %s, spend %s, soft_budget %s",
            valid_token.token,
            valid_token.spend,
            valid_token.soft_budget,
        )
        call_info = CallInfo(
            token=valid_token.token,
            spend=valid_token.spend,
            max_budget=valid_token.max_budget,
            soft_budget=valid_token.soft_budget,
            user_id=valid_token.user_id,
            team_id=valid_token.team_id,
            team_alias=valid_token.team_alias,
            organization_id=valid_token.org_id,
            user_email=user_obj.user_email if user_obj else None,
            key_alias=valid_token.key_alias,
            event_group=Litellm_EntityType.KEY,
        )
        # fire-and-forget: alerting must not block the request path
        asyncio.create_task(
            proxy_logging_obj.budget_alerts(
                type="soft_budget",
                user_info=call_info,
            )
        )
async def _virtual_key_max_budget_alert_check(
    valid_token: UserAPIKeyAuth,
    proxy_logging_obj: ProxyLogging,
    user_obj: Optional[LiteLLM_UserTable] = None,
):
    """
    Send a warning alert when a key nears (but has not exceeded) its max budget.

    The warning fires once spend reaches
    EMAIL_BUDGET_ALERT_MAX_SPEND_ALERT_PERCENTAGE (default 80%) of max_budget,
    before the key actually exceeds it.
    """
    if valid_token.max_budget is None or valid_token.spend is None:
        return
    if valid_token.spend <= 0:
        return

    threshold = valid_token.max_budget * EMAIL_BUDGET_ALERT_MAX_SPEND_ALERT_PERCENTAGE

    # Alert only inside the [threshold, max_budget) window - once the key is
    # actually over budget the max-budget check handles it.
    if not (threshold <= valid_token.spend < valid_token.max_budget):
        return

    verbose_proxy_logger.debug(
        "Reached Max Budget Alert Threshold for token %s, spend %s, max_budget %s, alert_threshold %s",
        valid_token.token,
        valid_token.spend,
        valid_token.max_budget,
        threshold,
    )
    call_info = CallInfo(
        token=valid_token.token,
        spend=valid_token.spend,
        max_budget=valid_token.max_budget,
        soft_budget=valid_token.soft_budget,
        user_id=valid_token.user_id,
        team_id=valid_token.team_id,
        team_alias=valid_token.team_alias,
        organization_id=valid_token.org_id,
        user_email=user_obj.user_email if user_obj else None,
        key_alias=valid_token.key_alias,
        event_group=Litellm_EntityType.KEY,
    )
    # fire-and-forget: alerting must not block the request path
    asyncio.create_task(
        proxy_logging_obj.budget_alerts(
            type="max_budget_alert",
            user_info=call_info,
        )
    )
async def _check_team_member_budget(
    team_object: Optional[LiteLLM_TeamTable],
    user_object: Optional[LiteLLM_UserTable],
    valid_token: Optional[UserAPIKeyAuth],
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
    proxy_logging_obj: ProxyLogging,
):
    """Raise BudgetExceededError if this member has exhausted their per-member budget within the team."""
    # Need a team, a user, and a token bound to a user to resolve a membership.
    if (
        team_object is None
        or team_object.team_id is None
        or user_object is None
        or valid_token is None
        or valid_token.user_id is None
    ):
        return

    membership = await get_team_membership(
        user_id=valid_token.user_id,
        team_id=team_object.team_id,
        prisma_client=prisma_client,
        user_api_key_cache=user_api_key_cache,
        proxy_logging_obj=proxy_logging_obj,
    )
    if membership is None or membership.litellm_budget_table is None:
        return

    team_member_budget = membership.litellm_budget_table.max_budget
    if team_member_budget is None:
        return

    team_member_spend = membership.spend or 0.0
    if team_member_spend >= team_member_budget:
        raise litellm.BudgetExceededError(
            current_cost=team_member_spend,
            max_budget=team_member_budget,
            message=f"Budget has been exceeded! User={valid_token.user_id} in Team={team_object.team_id} Current cost: {team_member_spend}, Max budget: {team_member_budget}",
        )
async def _team_max_budget_check(
    team_object: Optional[LiteLLM_TeamTable],
    valid_token: Optional[UserAPIKeyAuth],
    proxy_logging_obj: ProxyLogging,
):
    """
    Enforce the team's max budget.

    Dispatches a (non-blocking) "team_budget" alert when a token is present,
    then raises.

    Raises:
        BudgetExceededError when the team's spend exceeds its max budget.
    """
    if team_object is None:
        return
    if team_object.max_budget is None or team_object.spend is None:
        return
    if team_object.spend <= team_object.max_budget:
        return

    if valid_token:
        call_info = CallInfo(
            token=valid_token.token,
            spend=team_object.spend,
            max_budget=team_object.max_budget,
            user_id=valid_token.user_id,
            team_id=valid_token.team_id,
            team_alias=valid_token.team_alias,
            organization_id=valid_token.org_id,
            event_group=Litellm_EntityType.TEAM,
        )
        # fire-and-forget: alerting must not block the request path
        asyncio.create_task(
            proxy_logging_obj.budget_alerts(
                type="team_budget",
                user_info=call_info,
            )
        )
    raise litellm.BudgetExceededError(
        current_cost=team_object.spend,
        max_budget=team_object.max_budget,
        message=f"Budget has been exceeded! Team={team_object.team_id} Current cost: {team_object.spend}, Max budget: {team_object.max_budget}",
    )
def _extract_soft_budget_alert_emails(
    metadata: Optional[dict],
) -> Optional[List[str]]:
    """
    Parse `soft_budget_alerting_emails` from team metadata.

    Accepts either a list of strings or a comma-separated string; entries
    that are empty/whitespace or not strings are dropped.

    Returns:
        A list of non-empty email strings, or None when nothing is configured.
    """
    if metadata is None or not isinstance(metadata, dict):
        return None
    raw = metadata.get("soft_budget_alerting_emails")
    if raw is None:
        return None
    emails: List[str] = []
    if isinstance(raw, list):
        emails = [e for e in raw if isinstance(e, str) and e.strip()]
    elif isinstance(raw, str):
        # Handle comma-separated string
        emails = [e.strip() for e in raw.split(",") if e.strip()]
    return emails or None


async def _team_soft_budget_check(
    team_object: Optional[LiteLLM_TeamTable],
    valid_token: Optional[UserAPIKeyAuth],
    proxy_logging_obj: ProxyLogging,
):
    """
    Trigger a (non-blocking) alert when the team has crossed its soft budget.

    Team soft budget alerts are only sent when recipient emails are configured
    via metadata.soft_budget_alerting_emails, not via global alerting.
    """
    if (
        team_object is not None
        and team_object.soft_budget is not None
        and team_object.spend is not None
        and team_object.spend >= team_object.soft_budget
    ):
        verbose_proxy_logger.debug(
            "Crossed Soft Budget for team %s, spend %s, soft_budget %s",
            team_object.team_id,
            team_object.spend,
            team_object.soft_budget,
        )
        if valid_token:
            alert_emails = _extract_soft_budget_alert_emails(team_object.metadata)
            # Only send team soft budget alerts if alert_emails are configured
            if not alert_emails:
                verbose_proxy_logger.debug(
                    "Skipping team soft budget alert for team %s: no alert_emails configured in metadata.soft_budget_alerting_emails",
                    team_object.team_id,
                )
                return

            call_info = CallInfo(
                token=valid_token.token,
                spend=team_object.spend,
                max_budget=team_object.max_budget,
                soft_budget=team_object.soft_budget,
                user_id=valid_token.user_id,
                team_id=valid_token.team_id,
                team_alias=valid_token.team_alias,
                organization_id=valid_token.org_id,
                user_email=None,  # Team-level alert, no specific user email
                key_alias=valid_token.key_alias,
                event_group=Litellm_EntityType.TEAM,
                alert_emails=alert_emails,
            )
            # fire-and-forget: alerting must not block the request path
            asyncio.create_task(
                proxy_logging_obj.budget_alerts(
                    type="soft_budget",
                    user_info=call_info,
                )
            )
async def _project_max_budget_check(
    project_object: Optional[LiteLLM_ProjectTableCachedObj],
    valid_token: Optional[UserAPIKeyAuth],
    proxy_logging_obj: ProxyLogging,
):
    """
    Enforce the project's max budget.

    Dispatches a (non-blocking) "project_budget" alert when a token is
    present, then raises.

    Raises:
        BudgetExceededError when the project's spend exceeds its max budget.
    """
    if project_object is None:
        return

    budget_table = project_object.litellm_budget_table
    max_budget = budget_table.max_budget if budget_table is not None else None
    if max_budget is None or project_object.spend is None:
        return
    if project_object.spend <= max_budget:
        return

    if valid_token:
        call_info = CallInfo(
            token=valid_token.token,
            spend=project_object.spend,
            max_budget=max_budget,
            user_id=valid_token.user_id,
            team_id=valid_token.team_id,
            team_alias=valid_token.team_alias,
            organization_id=valid_token.org_id,
            event_group=Litellm_EntityType.PROJECT,
        )
        # fire-and-forget: alerting must not block the request path
        asyncio.create_task(
            proxy_logging_obj.budget_alerts(
                type="project_budget",
                user_info=call_info,
            )
        )
    raise litellm.BudgetExceededError(
        current_cost=project_object.spend,
        max_budget=max_budget,
        message=f"Budget has been exceeded! Project={project_object.project_id} Current cost: {project_object.spend}, Max budget: {max_budget}",
    )
async def _project_soft_budget_check(
    project_object: Optional[LiteLLM_ProjectTableCachedObj],
    valid_token: Optional[UserAPIKeyAuth],
    proxy_logging_obj: ProxyLogging,
):
    """
    Dispatch a (non-blocking) alert when the project has crossed its soft budget.

    Mirrors _team_soft_budget_check() pattern; never blocks the request.
    """
    if project_object is None:
        return

    budget_table = project_object.litellm_budget_table
    soft_budget = budget_table.soft_budget if budget_table is not None else None
    if soft_budget is None or project_object.spend is None:
        return
    if project_object.spend < soft_budget:
        return

    verbose_proxy_logger.debug(
        "Crossed Soft Budget for project %s, spend %s, soft_budget %s",
        project_object.project_id,
        project_object.spend,
        soft_budget,
    )
    if valid_token:
        call_info = CallInfo(
            token=valid_token.token,
            spend=project_object.spend,
            max_budget=None,
            soft_budget=soft_budget,
            user_id=valid_token.user_id,
            team_id=valid_token.team_id,
            team_alias=valid_token.team_alias,
            organization_id=valid_token.org_id,
            event_group=Litellm_EntityType.PROJECT,
        )
        # fire-and-forget: alerting must not block the request path
        asyncio.create_task(
            proxy_logging_obj.budget_alerts(
                type="soft_budget",
                user_info=call_info,
            )
        )
async def get_project_object(
    project_id: str,
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
    proxy_logging_obj: Optional[ProxyLogging] = None,
) -> Optional[LiteLLM_ProjectTableCachedObj]:
    """
    Load a project by id, cache-first and DB second.

    Follows get_team_object() caching pattern with TTL and last_refreshed_at.
    Returns LiteLLM_ProjectTableCachedObj or None if not found.
    """
    if prisma_client is None:
        return None

    cache_key = "project_id:{}".format(project_id)

    # Check cache first
    cached = await user_api_key_cache.async_get_cache(key=cache_key)
    if isinstance(cached, LiteLLM_ProjectTableCachedObj):
        return cached
    if isinstance(cached, dict):
        return LiteLLM_ProjectTableCachedObj(**cached)

    # Fetch from DB
    db_row = await prisma_client.db.litellm_projecttable.find_unique(
        where={"project_id": project_id},
        include={"litellm_budget_table": True},
    )
    if db_row is None:
        return None

    project_obj = LiteLLM_ProjectTableCachedObj(**db_row.model_dump())
    # Cache with TTL following _cache_management_object pattern
    project_obj.last_refreshed_at = time.time()
    await _cache_management_object(
        key=cache_key,
        value=project_obj,
        user_api_key_cache=user_api_key_cache,
        proxy_logging_obj=proxy_logging_obj,
    )
    return project_obj
async def _organization_max_budget_check(
    valid_token: Optional[UserAPIKeyAuth],
    team_object: Optional[LiteLLM_TeamTable],
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
    proxy_logging_obj: ProxyLogging,
):
    """
    Check if the organization is over its max budget.

    This function checks the organization budget using:
    1. First, tries to use valid_token.org_id (if key has organization_id set)
    2. Falls back to team_object.organization_id (if key doesn't have org_id but team does)

    This ensures organization budget checks work even when keys don't have
    organization_id set directly, as long as their team belongs to an organization.

    Raises:
        BudgetExceededError if the organization is over its max budget.
    Also dispatches a (non-blocking) budget alert before raising.
    """
    if valid_token is None or prisma_client is None:
        return

    # Determine organization_id: first try from token, then fallback to team
    org_id: Optional[str] = None
    if valid_token.org_id is not None:
        org_id = valid_token.org_id
    elif team_object is not None and team_object.organization_id is not None:
        org_id = team_object.organization_id

    # If no organization_id found, skip the check
    if org_id is None:
        return

    # Get organization object with budget table - use get_org_object so it can be mocked in tests
    try:
        org_table = await get_org_object(
            org_id=org_id,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            proxy_logging_obj=proxy_logging_obj,
            include_budget_table=True,
        )
    except Exception:
        # If organization lookup fails, skip the check
        return

    if org_table is None:
        return

    # Get max_budget from organization's budget table
    org_max_budget: Optional[float] = None
    if org_table.litellm_budget_table is not None:
        org_max_budget = org_table.litellm_budget_table.max_budget

    # Only check if organization has a valid max_budget set
    if org_max_budget is None or org_max_budget <= 0:
        return

    # Fix: treat a missing spend as 0.0 instead of raising a TypeError on a
    # `None >= float` comparison (matches the other budget checks).
    org_spend = org_table.spend or 0.0

    # Check if organization spend exceeds max budget
    if org_spend >= org_max_budget:
        # Trigger budget alert (fire-and-forget, must not block the request)
        call_info = CallInfo(
            token=valid_token.token,
            spend=org_spend,
            max_budget=org_max_budget,
            user_id=valid_token.user_id,
            team_id=valid_token.team_id,
            team_alias=valid_token.team_alias,
            organization_id=org_id,
            event_group=Litellm_EntityType.ORGANIZATION,
        )
        asyncio.create_task(
            proxy_logging_obj.budget_alerts(
                type="organization_budget",
                user_info=call_info,
            )
        )

        raise litellm.BudgetExceededError(
            current_cost=org_spend,
            max_budget=org_max_budget,
            message=f"Budget has been exceeded! Organization={org_id} Current cost: {org_spend}, Max budget: {org_max_budget}",
        )
async def _tag_max_budget_check(
    request_body: dict,
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
    proxy_logging_obj: ProxyLogging,
    valid_token: Optional[UserAPIKeyAuth],
):
    """
    Enforce per-tag max budgets for tags attached to the request.

    Raises:
        BudgetExceededError when any request tag's spend is above its max budget.
    """
    from litellm.proxy.common_utils.http_parsing_utils import get_tags_from_request_body

    if prisma_client is None:
        return

    # Get tags from request metadata
    request_tags = get_tags_from_request_body(request_body=request_body)
    if not request_tags:
        return

    # Single batched lookup for all tags on the request
    tag_objects = await get_tag_objects_batch(
        tag_names=request_tags,
        prisma_client=prisma_client,
        user_api_key_cache=user_api_key_cache,
        proxy_logging_obj=proxy_logging_obj,
    )

    # Check budget for each tag
    for tag_name in request_tags:
        tag_object = tag_objects.get(tag_name)
        if tag_object is None:
            continue
        budget_table = tag_object.litellm_budget_table
        if budget_table is None or budget_table.max_budget is None:
            continue
        if tag_object.spend is None:
            continue
        if tag_object.spend > budget_table.max_budget:
            raise litellm.BudgetExceededError(
                current_cost=tag_object.spend,
                max_budget=budget_table.max_budget,
                message=f"Budget has been exceeded! Tag={tag_name} Current cost: {tag_object.spend}, Max budget: {budget_table.max_budget}",
            )
def is_model_allowed_by_pattern(model: str, allowed_model_pattern: str) -> bool:
    """
    Check if a model matches an allowed wildcard pattern.

    The pattern is treated as a glob: `*` matches any sequence of characters,
    every other character matches literally. Regex metacharacters that appear
    in real model names (".", ":", "+", ...) are escaped before matching, so
    e.g. "openai/gpt.4*" no longer accidentally matches "openai/gptX4...".

    Args:
        model (str): The model to check (e.g., "bedrock/anthropic.claude-3-5-sonnet-20240620")
        allowed_model_pattern (str): The allowed pattern (e.g., "bedrock/*", "*", "openai/*")

    Returns:
        bool: True if the pattern contains a wildcard and the model matches it,
            False otherwise (exact, non-wildcard patterns are handled elsewhere).
    """
    if "*" not in allowed_model_pattern:
        return False
    # Escape regex metacharacters first so literal characters stay literal,
    # then expand each escaped "*" into the regex wildcard ".*".
    pattern = "^" + re.escape(allowed_model_pattern).replace("\\*", ".*") + "$"
    return bool(re.match(pattern, model))
def _model_matches_any_wildcard_pattern_in_list(
    model: str, allowed_model_list: list
) -> bool:
    """
    Return True when `model` matches any wildcard pattern in `allowed_model_list`.

    eg.
    - model=`bedrock/us.amazon.nova-micro-v1:0`, allowed_models=`bedrock/*` returns True
    - model=`bedrock/us.amazon.nova-micro-v1:0`, allowed_models=`bedrock/us.*` returns True
    - model=`bedrockzzzz/us.amazon.nova-micro-v1:0`, allowed_models=`bedrock/*` returns False
    """
    wildcard_patterns = [
        pattern for pattern in allowed_model_list if _is_wildcard_pattern(pattern)
    ]

    # First pass: match the model name exactly as given.
    for pattern in wildcard_patterns:
        if is_model_allowed_by_pattern(model=model, allowed_model_pattern=pattern):
            return True

    # Second pass: match "<provider>/<model>" resolved from a bare model name.
    for pattern in wildcard_patterns:
        if _model_custom_llm_provider_matches_wildcard_pattern(
            model=model, allowed_model_pattern=pattern
        ):
            return True

    return False
def _model_custom_llm_provider_matches_wildcard_pattern(
    model: str, allowed_model_pattern: str
) -> bool:
    """
    Match a bare model name against a provider-scoped wildcard pattern.

    Returns True for this scenario:
    - `model=gpt-4o`
    - `allowed_model_pattern=openai/*`
    or
    - `model=claude-3-5-sonnet-20240620`
    - `allowed_model_pattern=anthropic/*`
    """
    try:
        resolved_model, provider, _, _ = get_llm_provider(model=model)
    except Exception:
        # Provider could not be inferred -> cannot match a provider-scoped pattern.
        return False
    return is_model_allowed_by_pattern(
        model=f"{provider}/{resolved_model}",
        allowed_model_pattern=allowed_model_pattern,
    )
def _is_wildcard_pattern(allowed_model_pattern: str) -> bool:
"""
Returns True if the pattern is a wildcard pattern.
Checks if `*` is in the pattern.
"""
return "*" in allowed_model_pattern
async def vector_store_access_check(
    request_body: dict,
    team_object: Optional[LiteLLM_TeamTable],
    valid_token: Optional[UserAPIKeyAuth],
):
    """
    Verify the key and its team may access the vector store(s) on this request.

    Raises ProxyException if the object (key, team, org) cannot access the specific vector store.
    """
    from litellm.proxy.proxy_server import prisma_client

    #########################################################
    # Skip the check entirely when prerequisites are missing
    #########################################################
    if prisma_client is None:
        verbose_proxy_logger.debug(
            "Prisma client not found, skipping vector store access check"
        )
        return True

    if litellm.vector_store_registry is None:
        verbose_proxy_logger.debug(
            "Vector store registry not found, skipping vector store access check"
        )
        return True

    # Determine which vector stores this request wants to touch.
    vector_store_ids_to_run = litellm.vector_store_registry.get_vector_store_ids_to_run(
        non_default_params=request_body, tools=request_body.get("tools", None)
    )
    if vector_store_ids_to_run is None:
        verbose_proxy_logger.debug(
            "Vector store to run not found, skipping vector store access check"
        )
        return True

    #########################################################
    # Check if the object (key, team, org) has access to the vector store
    #########################################################
    # Key-level permissions
    if valid_token is not None and valid_token.object_permission_id is not None:
        key_permissions = (
            await prisma_client.db.litellm_objectpermissiontable.find_unique(
                where={"object_permission_id": valid_token.object_permission_id},
            )
        )
        if key_permissions is not None:
            _can_object_call_vector_stores(
                object_type="key",
                vector_store_ids_to_run=vector_store_ids_to_run,
                object_permissions=key_permissions,
            )

    # Team-level permissions
    if team_object is not None and team_object.object_permission_id is not None:
        team_permissions = (
            await prisma_client.db.litellm_objectpermissiontable.find_unique(
                where={"object_permission_id": team_object.object_permission_id},
            )
        )
        if team_permissions is not None:
            _can_object_call_vector_stores(
                object_type="team",
                vector_store_ids_to_run=vector_store_ids_to_run,
                object_permissions=team_permissions,
            )

    return True
def _can_object_call_vector_stores(
object_type: Literal["key", "team", "org"],
vector_store_ids_to_run: List[str],
object_permissions: Optional[LiteLLM_ObjectPermissionTable],
):
"""
Raises ProxyException if the object (key, team, org) cannot access the specific vector store.
"""
if object_permissions is None:
return True
if object_permissions.vector_stores is None:
return True
# If length is 0, then the object has access to all vector stores.
if len(object_permissions.vector_stores) == 0:
return True
for vector_store_id in vector_store_ids_to_run:
if vector_store_id not in object_permissions.vector_stores:
raise ProxyException(
message=f"User not allowed to access vector store. Tried to access {vector_store_id}. Only allowed to access {object_permissions.vector_stores}",
type=ProxyErrorTypes.get_vector_store_access_error_type_for_object(
object_type
),
param="vector_store",
code=status.HTTP_401_UNAUTHORIZED,
)
return True