chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,680 @@
|
||||
from typing import List, Set
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import (
|
||||
CommonProxyErrors,
|
||||
LiteLLM_AccessGroupTable,
|
||||
LitellmUserRoles,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.auth.auth_checks import (
|
||||
_cache_access_object,
|
||||
_cache_key_object,
|
||||
_cache_team_object,
|
||||
_delete_cache_access_object,
|
||||
_get_team_object_from_cache,
|
||||
)
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.db.exception_handler import PrismaDBExceptionHandler
|
||||
from litellm.proxy.utils import get_prisma_client_or_throw
|
||||
from litellm.types.access_group import (
|
||||
AccessGroupCreateRequest,
|
||||
AccessGroupResponse,
|
||||
AccessGroupUpdateRequest,
|
||||
)
|
||||
|
||||
# Shared router for the /v1/access_group CRUD endpoints (also exposed under
# /v1/unified_access_group via the alias registrations at the bottom of this file).
router = APIRouter(
    tags=["access group management"],
)
|
||||
|
||||
|
||||
def _require_proxy_admin(user_api_key_dict: UserAPIKeyAuth) -> None:
    """Reject the request with HTTP 403 unless the caller is a proxy admin."""
    if user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN:
        return
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail={"error": CommonProxyErrors.not_allowed_access.value},
    )
|
||||
|
||||
|
||||
def _record_to_response(record) -> AccessGroupResponse:
    """Map a Prisma access-group row onto the public AccessGroupResponse model."""
    field_names = (
        "access_group_id",
        "access_group_name",
        "description",
        "access_model_names",
        "access_mcp_server_ids",
        "access_agent_ids",
        "assigned_team_ids",
        "assigned_key_ids",
        "created_at",
        "created_by",
        "updated_at",
        "updated_by",
    )
    return AccessGroupResponse(
        **{name: getattr(record, name) for name in field_names}
    )
|
||||
|
||||
|
||||
def _record_to_access_group_table(record) -> LiteLLM_AccessGroupTable:
    """Build the LiteLLM_AccessGroupTable cache object from a Prisma record."""
    payload = record.dict()
    return LiteLLM_AccessGroupTable(**payload)
|
||||
|
||||
|
||||
async def _cache_access_group_record(record) -> None:
    """Write an access-group Prisma record into the user_api_key_cache.

    ``user_api_key_cache`` and ``proxy_logging_obj`` are imported lazily from
    ``proxy_server`` to avoid a circular import, following the same pattern as
    key_management_endpoints.
    """
    from litellm.proxy.proxy_server import proxy_logging_obj, user_api_key_cache

    await _cache_access_object(
        access_group_id=record.access_group_id,
        access_group_table=_record_to_access_group_table(record),
        user_api_key_cache=user_api_key_cache,
        proxy_logging_obj=proxy_logging_obj,
    )
|
||||
|
||||
|
||||
async def _invalidate_cache_access_group(access_group_id: str) -> None:
    """
    Invalidate (delete) an access group entry from both in-memory and Redis caches.

    Called after a group is deleted or rewritten so stale permissions are not
    served from cache.

    Uses a lazy import of user_api_key_cache and proxy_logging_obj from proxy_server
    to avoid circular imports, following the same pattern as key_management_endpoints.
    """
    from litellm.proxy.proxy_server import proxy_logging_obj, user_api_key_cache

    await _delete_cache_access_object(
        access_group_id=access_group_id,
        user_api_key_cache=user_api_key_cache,
        proxy_logging_obj=proxy_logging_obj,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DB sync helpers (called inside a Prisma transaction)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _sync_add_access_group_to_teams(
|
||||
tx, team_ids: List[str], access_group_id: str
|
||||
) -> None:
|
||||
"""Add access_group_id to each team's access_group_ids (idempotent)."""
|
||||
for team_id in team_ids:
|
||||
team = await tx.litellm_teamtable.find_unique(where={"team_id": team_id})
|
||||
if team is not None and access_group_id not in (team.access_group_ids or []):
|
||||
await tx.litellm_teamtable.update(
|
||||
where={"team_id": team_id},
|
||||
data={
|
||||
"access_group_ids": list(team.access_group_ids or [])
|
||||
+ [access_group_id]
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def _sync_remove_access_group_from_teams(
|
||||
tx, team_ids: List[str], access_group_id: str
|
||||
) -> None:
|
||||
"""Remove access_group_id from each team's access_group_ids (idempotent)."""
|
||||
for team_id in team_ids:
|
||||
team = await tx.litellm_teamtable.find_unique(where={"team_id": team_id})
|
||||
if team is not None and access_group_id in (team.access_group_ids or []):
|
||||
await tx.litellm_teamtable.update(
|
||||
where={"team_id": team_id},
|
||||
data={
|
||||
"access_group_ids": [
|
||||
ag for ag in team.access_group_ids if ag != access_group_id
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def _sync_add_access_group_to_keys(
|
||||
tx, key_tokens: List[str], access_group_id: str
|
||||
) -> None:
|
||||
"""Add access_group_id to each key's access_group_ids (idempotent)."""
|
||||
for token in key_tokens:
|
||||
key = await tx.litellm_verificationtoken.find_unique(where={"token": token})
|
||||
if key is not None and access_group_id not in (key.access_group_ids or []):
|
||||
await tx.litellm_verificationtoken.update(
|
||||
where={"token": token},
|
||||
data={
|
||||
"access_group_ids": list(key.access_group_ids or [])
|
||||
+ [access_group_id]
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def _sync_remove_access_group_from_keys(
|
||||
tx, key_tokens: List[str], access_group_id: str
|
||||
) -> None:
|
||||
"""Remove access_group_id from each key's access_group_ids (idempotent)."""
|
||||
for token in key_tokens:
|
||||
key = await tx.litellm_verificationtoken.find_unique(where={"token": token})
|
||||
if key is not None and access_group_id in (key.access_group_ids or []):
|
||||
await tx.litellm_verificationtoken.update(
|
||||
where={"token": token},
|
||||
data={
|
||||
"access_group_ids": [
|
||||
ag for ag in key.access_group_ids if ag != access_group_id
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cache patch helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _patch_team_caches_add_access_group(
    team_ids: List[str],
    access_group_id: str,
    user_api_key_cache,
    proxy_logging_obj,
) -> None:
    """Ensure each cached team object lists ``access_group_id``.

    Teams not present in the cache, or already referencing the group, are
    skipped; only modified entries are written back.
    """
    for current_team_id in team_ids:
        team_obj = await _get_team_object_from_cache(
            key="team_id:{}".format(current_team_id),
            proxy_logging_obj=proxy_logging_obj,
            user_api_key_cache=user_api_key_cache,
            parent_otel_span=None,
        )
        if team_obj is None:
            continue
        current_groups = list(team_obj.access_group_ids or [])
        if access_group_id in current_groups:
            continue
        team_obj.access_group_ids = current_groups + [access_group_id]
        await _cache_team_object(
            team_id=current_team_id,
            team_table=team_obj,
            user_api_key_cache=user_api_key_cache,
            proxy_logging_obj=proxy_logging_obj,
        )
|
||||
|
||||
|
||||
async def _patch_team_caches_remove_access_group(
    team_ids: List[str],
    access_group_id: str,
    user_api_key_cache,
    proxy_logging_obj,
) -> None:
    """Strip ``access_group_id`` from each cached team object.

    Teams that are not cached, or whose ``access_group_ids`` is empty, are
    left alone; any cached team with a non-empty list is rewritten.
    """
    for current_team_id in team_ids:
        team_obj = await _get_team_object_from_cache(
            key="team_id:{}".format(current_team_id),
            proxy_logging_obj=proxy_logging_obj,
            user_api_key_cache=user_api_key_cache,
            parent_otel_span=None,
        )
        if team_obj is None or not team_obj.access_group_ids:
            continue
        team_obj.access_group_ids = [
            g for g in team_obj.access_group_ids if g != access_group_id
        ]
        await _cache_team_object(
            team_id=current_team_id,
            team_table=team_obj,
            user_api_key_cache=user_api_key_cache,
            proxy_logging_obj=proxy_logging_obj,
        )
|
||||
|
||||
|
||||
async def _patch_key_caches_add_access_group(
    key_tokens: List[str],
    access_group_id: str,
    user_api_key_cache,
    proxy_logging_obj,
) -> None:
    """Ensure each cached key object (UserAPIKeyAuth) lists ``access_group_id``.

    Dict-shaped cache entries are coerced to UserAPIKeyAuth first; missing
    entries, entries of any other type, and keys already carrying the group
    are skipped.
    """
    for hashed_token in key_tokens:
        entry = await user_api_key_cache.async_get_cache(key=hashed_token)
        if entry is None:
            continue
        if isinstance(entry, dict):
            entry = UserAPIKeyAuth(**entry)
        if not isinstance(entry, UserAPIKeyAuth):
            continue
        current_groups = list(entry.access_group_ids or [])
        if access_group_id in current_groups:
            continue
        entry.access_group_ids = current_groups + [access_group_id]
        await _cache_key_object(
            hashed_token=hashed_token,
            user_api_key_obj=entry,
            user_api_key_cache=user_api_key_cache,
            proxy_logging_obj=proxy_logging_obj,
        )
|
||||
|
||||
|
||||
async def _patch_key_caches_remove_access_group(
    key_tokens: List[str],
    access_group_id: str,
    user_api_key_cache,
    proxy_logging_obj,
) -> None:
    """Strip ``access_group_id`` from each cached key object.

    Dict-shaped cache entries are coerced to UserAPIKeyAuth first; only
    entries that end up as a UserAPIKeyAuth with a non-empty
    ``access_group_ids`` are rewritten.
    """
    for hashed_token in key_tokens:
        entry = await user_api_key_cache.async_get_cache(key=hashed_token)
        if entry is None:
            continue
        if isinstance(entry, dict):
            entry = UserAPIKeyAuth(**entry)
        if not (isinstance(entry, UserAPIKeyAuth) and entry.access_group_ids):
            continue
        entry.access_group_ids = [
            g for g in entry.access_group_ids if g != access_group_id
        ]
        await _cache_key_object(
            hashed_token=hashed_token,
            user_api_key_obj=entry,
            user_api_key_cache=user_api_key_cache,
            proxy_logging_obj=proxy_logging_obj,
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CRUD endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post(
    "/v1/access_group",
    response_model=AccessGroupResponse,
    status_code=status.HTTP_201_CREATED,
)
async def create_access_group(
    data: AccessGroupCreateRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
) -> AccessGroupResponse:
    """Create a new access group. Proxy-admin only.

    Within a single transaction: rejects duplicate names with 409, inserts
    the access-group row, then back-references the new group id on every
    assigned team and key. After the transaction commits, the group and the
    touched team/key cache entries are patched so auth checks see the
    assignment immediately.

    Raises:
        HTTPException: 403 if the caller is not a proxy admin, 409 on a
            duplicate name (pre-checked or raced), plus whatever
            get_prisma_client_or_throw raises when the DB is unavailable.
    """
    _require_proxy_admin(user_api_key_dict)
    prisma_client = get_prisma_client_or_throw(
        CommonProxyErrors.db_not_connected_error.value
    )

    try:
        async with prisma_client.db.tx() as tx:
            # Pre-check gives a friendly 409; the unique-constraint handler
            # below is the real guard against a concurrent create.
            existing = await tx.litellm_accessgrouptable.find_unique(
                where={"access_group_name": data.access_group_name}
            )
            if existing is not None:
                raise HTTPException(
                    status_code=status.HTTP_409_CONFLICT,
                    detail=f"Access group '{data.access_group_name}' already exists",
                )

            # Optional list fields are normalized to [] so the row never
            # stores nulls for them.
            record = await tx.litellm_accessgrouptable.create(
                data={
                    "access_group_name": data.access_group_name,
                    "description": data.description,
                    "access_model_names": data.access_model_names or [],
                    "access_mcp_server_ids": data.access_mcp_server_ids or [],
                    "access_agent_ids": data.access_agent_ids or [],
                    "assigned_team_ids": data.assigned_team_ids or [],
                    "assigned_key_ids": data.assigned_key_ids or [],
                    "created_by": user_api_key_dict.user_id,
                    "updated_by": user_api_key_dict.user_id,
                }
            )

            # Sync team and key tables to reference the new access group
            await _sync_add_access_group_to_teams(
                tx, data.assigned_team_ids or [], record.access_group_id
            )
            await _sync_add_access_group_to_keys(
                tx, data.assigned_key_ids or [], record.access_group_id
            )
    except HTTPException:
        raise
    except Exception as e:
        # Race condition: another request created the same name between find_unique and create.
        if "unique constraint" in str(e).lower() or "P2002" in str(e):
            raise HTTPException(
                status_code=status.HTTP_409_CONFLICT,
                detail=f"Access group '{data.access_group_name}' already exists",
            )
        raise

    # Lazy import to avoid a circular import with proxy_server.
    from litellm.proxy.proxy_server import proxy_logging_obj, user_api_key_cache

    # Cache writes run only after a successful commit so a rolled-back
    # transaction never leaves stale cache entries behind.
    await _cache_access_group_record(record)
    await _patch_team_caches_add_access_group(
        data.assigned_team_ids or [],
        record.access_group_id,
        user_api_key_cache,
        proxy_logging_obj,
    )
    await _patch_key_caches_add_access_group(
        data.assigned_key_ids or [],
        record.access_group_id,
        user_api_key_cache,
        proxy_logging_obj,
    )

    return _record_to_response(record)
|
||||
|
||||
|
||||
@router.get(
    "/v1/access_group",
    response_model=List[AccessGroupResponse],
)
async def list_access_groups(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
) -> List[AccessGroupResponse]:
    """Return every access group, newest first. Proxy-admin only."""
    _require_proxy_admin(user_api_key_dict)
    prisma_client = get_prisma_client_or_throw(
        CommonProxyErrors.db_not_connected_error.value
    )

    rows = await prisma_client.db.litellm_accessgrouptable.find_many(
        order={"created_at": "desc"}
    )
    responses: List[AccessGroupResponse] = []
    for row in rows:
        responses.append(_record_to_response(row))
    return responses
|
||||
|
||||
|
||||
@router.get(
    "/v1/access_group/{access_group_id}",
    response_model=AccessGroupResponse,
)
async def get_access_group(
    access_group_id: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
) -> AccessGroupResponse:
    """Fetch one access group by id; 404 if it does not exist. Proxy-admin only."""
    _require_proxy_admin(user_api_key_dict)
    prisma_client = get_prisma_client_or_throw(
        CommonProxyErrors.db_not_connected_error.value
    )

    row = await prisma_client.db.litellm_accessgrouptable.find_unique(
        where={"access_group_id": access_group_id}
    )
    if row is not None:
        return _record_to_response(row)
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"Access group '{access_group_id}' not found",
    )
|
||||
|
||||
|
||||
@router.put(
    "/v1/access_group/{access_group_id}",
    response_model=AccessGroupResponse,
)
async def update_access_group(
    access_group_id: str,
    data: AccessGroupUpdateRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
) -> AccessGroupResponse:
    """Update an access group and re-sync team/key back-references. Proxy-admin only.

    Only fields explicitly set on the request are written (exclude_unset);
    list fields explicitly set to null are normalized to []. The delta
    between old and new team/key assignments is computed inside the
    transaction, the group row is updated, and the assignment changes are
    mirrored onto the team and key tables. After commit, the group and
    touched team/key caches are patched to match.

    Raises:
        HTTPException: 403 (not proxy admin), 404 (unknown group id),
            409 (renamed to an already-existing name).
    """
    _require_proxy_admin(user_api_key_dict)
    prisma_client = get_prisma_client_or_throw(
        CommonProxyErrors.db_not_connected_error.value
    )

    update_fields = data.model_dump(exclude_unset=True)
    update_data: dict = {"updated_by": user_api_key_dict.user_id}
    for field, value in update_fields.items():
        # Explicit null on a list-valued field means "clear the list".
        if (
            field
            in (
                "assigned_team_ids",
                "assigned_key_ids",
                "access_model_names",
                "access_mcp_server_ids",
                "access_agent_ids",
            )
            and value is None
        ):
            value = []
        update_data[field] = value

    # Initialize delta lists before the try block so they remain accessible
    # for cache updates after the transaction, even if an error path is added later.
    teams_to_add: List[str] = []
    teams_to_remove: List[str] = []
    keys_to_add: List[str] = []
    keys_to_remove: List[str] = []

    try:
        async with prisma_client.db.tx() as tx:
            # Read inside the transaction so delta computation is consistent with the write,
            # avoiding a TOCTOU race where a concurrent update could make deltas stale.
            existing = await tx.litellm_accessgrouptable.find_unique(
                where={"access_group_id": access_group_id}
            )
            if existing is None:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail=f"Access group '{access_group_id}' not found",
                )

            old_team_ids: Set[str] = set(existing.assigned_team_ids or [])
            old_key_ids: Set[str] = set(existing.assigned_key_ids or [])
            # Fields absent from the request keep their current assignments.
            new_team_ids: Set[str] = (
                set(update_fields["assigned_team_ids"] or [])
                if "assigned_team_ids" in update_fields
                else old_team_ids
            )
            new_key_ids: Set[str] = (
                set(update_fields["assigned_key_ids"] or [])
                if "assigned_key_ids" in update_fields
                else old_key_ids
            )

            teams_to_add = list(new_team_ids - old_team_ids)
            teams_to_remove = list(old_team_ids - new_team_ids)
            keys_to_add = list(new_key_ids - old_key_ids)
            keys_to_remove = list(old_key_ids - new_key_ids)

            record = await tx.litellm_accessgrouptable.update(
                where={"access_group_id": access_group_id},
                data=update_data,
            )

            # Mirror the assignment deltas onto the team and key tables.
            await _sync_add_access_group_to_teams(tx, teams_to_add, access_group_id)
            await _sync_remove_access_group_from_teams(
                tx, teams_to_remove, access_group_id
            )
            await _sync_add_access_group_to_keys(tx, keys_to_add, access_group_id)
            await _sync_remove_access_group_from_keys(
                tx, keys_to_remove, access_group_id
            )
    except HTTPException:
        raise
    except Exception as e:
        # Unique constraint violation (e.g. access_group_name already exists).
        if "unique constraint" in str(e).lower() or "P2002" in str(e):
            raise HTTPException(
                status_code=status.HTTP_409_CONFLICT,
                detail=f"Access group '{update_data.get('access_group_name', '')}' already exists",
            )
        raise

    # Lazy import to avoid a circular import with proxy_server.
    from litellm.proxy.proxy_server import proxy_logging_obj, user_api_key_cache

    # Cache patches run only after a successful commit.
    await _cache_access_group_record(record)
    await _patch_team_caches_add_access_group(
        teams_to_add, access_group_id, user_api_key_cache, proxy_logging_obj
    )
    await _patch_team_caches_remove_access_group(
        teams_to_remove, access_group_id, user_api_key_cache, proxy_logging_obj
    )
    await _patch_key_caches_add_access_group(
        keys_to_add, access_group_id, user_api_key_cache, proxy_logging_obj
    )
    await _patch_key_caches_remove_access_group(
        keys_to_remove, access_group_id, user_api_key_cache, proxy_logging_obj
    )

    return _record_to_response(record)
|
||||
|
||||
|
||||
@router.delete(
    "/v1/access_group/{access_group_id}",
    status_code=status.HTTP_204_NO_CONTENT,
)
async def delete_access_group(
    access_group_id: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
) -> None:
    """Delete an access group and scrub every reference to it. Proxy-admin only.

    Within one transaction: removes the group id from every team and key that
    references it — taking the union of a hasSome query and the group's own
    assignment lists so pre-existing out-of-sync data is also repaired — then
    deletes the group row. Cache entries for the group and all affected
    teams/keys are invalidated/patched afterwards.

    Raises:
        HTTPException: 403 (not proxy admin), 404 (unknown group id),
            503 (DB connection error), 500 (any other failure).
    """
    _require_proxy_admin(user_api_key_dict)
    prisma_client = get_prisma_client_or_throw(
        CommonProxyErrors.db_not_connected_error.value
    )

    try:
        affected_team_ids: List[str] = []
        affected_key_tokens: List[str] = []

        async with prisma_client.db.tx() as tx:
            existing = await tx.litellm_accessgrouptable.find_unique(
                where={"access_group_id": access_group_id}
            )
            if existing is None:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail=f"Access group '{access_group_id}' not found",
                )

            # Union of: teams that have this access_group_id in their own access_group_ids
            # AND teams listed in assigned_team_ids (handles out-of-sync data from before this sync was added)
            teams_with_group = await tx.litellm_teamtable.find_many(
                where={"access_group_ids": {"hasSome": [access_group_id]}}
            )
            all_affected_team_ids: Set[str] = {
                team.team_id for team in teams_with_group
            } | set(existing.assigned_team_ids or [])
            affected_team_ids = list(all_affected_team_ids)

            # Union of: keys that have this access_group_id in their own access_group_ids
            # AND keys listed in assigned_key_ids (handles out-of-sync data)
            keys_with_group = await tx.litellm_verificationtoken.find_many(
                where={"access_group_ids": {"hasSome": [access_group_id]}}
            )
            all_affected_key_tokens: Set[str] = {
                key.token for key in keys_with_group
            } | set(existing.assigned_key_ids or [])
            affected_key_tokens = list(all_affected_key_tokens)

            # Update teams returned by find_many directly — we already have their data.
            for team in teams_with_group:
                await tx.litellm_teamtable.update(
                    where={"team_id": team.team_id},
                    data={
                        "access_group_ids": [
                            ag
                            for ag in (team.access_group_ids or [])
                            if ag != access_group_id
                        ]
                    },
                )
            # Use _sync_remove only for out-of-sync teams not found by the hasSome query.
            out_of_sync_team_ids = set(existing.assigned_team_ids or []) - {
                t.team_id for t in teams_with_group
            }
            await _sync_remove_access_group_from_teams(
                tx, list(out_of_sync_team_ids), access_group_id
            )

            # Update keys returned by find_many directly — we already have their data.
            for key in keys_with_group:
                await tx.litellm_verificationtoken.update(
                    where={"token": key.token},
                    data={
                        "access_group_ids": [
                            ag
                            for ag in (key.access_group_ids or [])
                            if ag != access_group_id
                        ]
                    },
                )
            # Use _sync_remove only for out-of-sync keys not found by the hasSome query.
            out_of_sync_key_tokens = set(existing.assigned_key_ids or []) - {
                k.token for k in keys_with_group
            }
            await _sync_remove_access_group_from_keys(
                tx, list(out_of_sync_key_tokens), access_group_id
            )

            await tx.litellm_accessgrouptable.delete(
                where={"access_group_id": access_group_id}
            )

        # Lazy import to avoid a circular import with proxy_server.
        from litellm.proxy.proxy_server import proxy_logging_obj, user_api_key_cache

        # Cache invalidation/patching runs after the transaction commits.
        await _invalidate_cache_access_group(access_group_id)
        await _patch_team_caches_remove_access_group(
            affected_team_ids, access_group_id, user_api_key_cache, proxy_logging_obj
        )
        await _patch_key_caches_remove_access_group(
            affected_key_tokens, access_group_id, user_api_key_cache, proxy_logging_obj
        )

    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(
            "delete_access_group failed: access_group_id=%s error=%s",
            access_group_id,
            e,
        )
        if PrismaDBExceptionHandler.is_database_connection_error(e):
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail=CommonProxyErrors.db_not_connected_error.value,
            )
        # P2025 is Prisma's "record not found" code (e.g. raced delete).
        if "P2025" in str(e) or (
            "record" in str(e).lower() and "not found" in str(e).lower()
        ):
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Access group '{access_group_id}' not found",
            )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to delete access group. Please try again.",
        )
|
||||
|
||||
|
||||
# Alias routes for /v1/unified_access_group
|
||||
# Alias routes: expose the same handlers under /v1/unified_access_group.
_UNIFIED_ALIAS_ROUTES = [
    (
        "/v1/unified_access_group",
        create_access_group,
        ["POST"],
        {
            "response_model": AccessGroupResponse,
            "status_code": status.HTTP_201_CREATED,
        },
    ),
    (
        "/v1/unified_access_group",
        list_access_groups,
        ["GET"],
        {"response_model": List[AccessGroupResponse]},
    ),
    (
        "/v1/unified_access_group/{access_group_id}",
        get_access_group,
        ["GET"],
        {"response_model": AccessGroupResponse},
    ),
    (
        "/v1/unified_access_group/{access_group_id}",
        update_access_group,
        ["PUT"],
        {"response_model": AccessGroupResponse},
    ),
    (
        "/v1/unified_access_group/{access_group_id}",
        delete_access_group,
        ["DELETE"],
        {"status_code": status.HTTP_204_NO_CONTENT},
    ),
]
for _alias_path, _alias_endpoint, _alias_methods, _alias_kwargs in _UNIFIED_ALIAS_ROUTES:
    router.add_api_route(
        _alias_path, _alias_endpoint, methods=_alias_methods, **_alias_kwargs
    )
|
||||
@@ -0,0 +1,352 @@
|
||||
"""
|
||||
BUDGET MANAGEMENT
|
||||
|
||||
All /budget management endpoints
|
||||
|
||||
/budget/new
|
||||
/budget/info
|
||||
/budget/update
|
||||
/budget/delete
|
||||
/budget/settings
|
||||
/budget/list
|
||||
"""
|
||||
|
||||
#### BUDGET TABLE MANAGEMENT ####
|
||||
from datetime import timedelta
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
|
||||
from litellm.proxy._types import *
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.utils import jsonify_object
|
||||
|
||||
# Shared router for all /budget management endpoints; mounted by proxy_server.
router = APIRouter()
|
||||
|
||||
|
||||
@router.post(
    "/budget/new",
    tags=["budget management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def new_budget(
    budget_obj: BudgetNewRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Create a new budget object. Can apply this to teams, orgs, end-users, keys.

    Parameters:
    - budget_duration: Optional[str] - Budget reset period ("30d", "1h", etc.)
    - budget_id: Optional[str] - The id of the budget. If not provided, a new id will be generated.
    - max_budget: Optional[float] - The max budget for the budget.
    - soft_budget: Optional[float] - The soft budget for the budget.
    - max_parallel_requests: Optional[int] - The max number of parallel requests for the budget.
    - tpm_limit: Optional[int] - The tokens per minute limit for the budget.
    - rpm_limit: Optional[int] - The requests per minute limit for the budget.
    - model_max_budget: Optional[dict] - Specify max budget for a given model. Example: {"openai/gpt-4o-mini": {"max_budget": 100.0, "budget_duration": "1d", "tpm_limit": 100000, "rpm_limit": 100000}}
    - budget_reset_at: Optional[datetime] - Datetime when the initial budget is reset. Default is now.

    Raises:
        HTTPException: 500 if the DB is not connected; 400 on a negative
            max_budget/soft_budget or an invalid model_max_budget.
    """
    # Lazy import to avoid a circular import with proxy_server.
    from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    # Validate budget values are not negative
    if budget_obj.max_budget is not None and budget_obj.max_budget < 0:
        raise HTTPException(
            status_code=400,
            detail={
                "error": f"max_budget cannot be negative. Received: {budget_obj.max_budget}"
            },
        )
    if budget_obj.soft_budget is not None and budget_obj.soft_budget < 0:
        raise HTTPException(
            status_code=400,
            detail={
                "error": f"soft_budget cannot be negative. Received: {budget_obj.soft_budget}"
            },
        )

    # Validate model_max_budget if present
    if budget_obj.model_max_budget is not None and len(budget_obj.model_max_budget) > 0:
        # Local import to keep this heavy module out of the import cycle.
        from litellm.proxy.management_endpoints.key_management_endpoints import (
            validate_model_max_budget,
        )

        try:
            validate_model_max_budget(budget_obj.model_max_budget)
        except ValueError as e:
            raise HTTPException(status_code=400, detail={"error": str(e)})

    # if no budget_reset_at date is set, but a budget_duration is given, then set budget_reset_at initially to the first completed duration interval in future
    # NOTE(review): datetime.utcnow() yields a naive UTC datetime — presumably
    # the rest of the budget pipeline expects naive timestamps; confirm before
    # migrating to timezone-aware datetimes.
    if budget_obj.budget_reset_at is None and budget_obj.budget_duration is not None:
        budget_obj.budget_reset_at = datetime.utcnow() + timedelta(
            seconds=duration_in_seconds(duration=budget_obj.budget_duration)
        )

    budget_obj_json = budget_obj.model_dump(exclude_none=True)
    budget_obj_jsonified = jsonify_object(budget_obj_json)  # json dump any dictionaries
    response = await prisma_client.db.litellm_budgettable.create(
        data={
            **budget_obj_jsonified,  # type: ignore
            "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
            "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
        }  # type: ignore
    )

    return response
|
||||
|
||||
|
||||
@router.post(
    "/budget/update",
    tags=["budget management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def update_budget(
    budget_obj: BudgetNewRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Update an existing budget object.

    Parameters:
    - budget_duration: Optional[str] - Budget reset period ("30d", "1h", etc.)
    - budget_id: Optional[str] - The id of the budget. If not provided, a new id will be generated.
    - max_budget: Optional[float] - The max budget for the budget.
    - soft_budget: Optional[float] - The soft budget for the budget.
    - max_parallel_requests: Optional[int] - The max number of parallel requests for the budget.
    - tpm_limit: Optional[int] - The tokens per minute limit for the budget.
    - rpm_limit: Optional[int] - The requests per minute limit for the budget.
    - model_max_budget: Optional[dict] - Specify max budget for a given model. Example: {"openai/gpt-4o-mini": {"max_budget": 100.0, "budget_duration": "1d", "tpm_limit": 100000, "rpm_limit": 100000}}
    - budget_reset_at: Optional[datetime] - Update the Datetime when the budget was last reset.

    Raises:
        HTTPException: 500 if the DB is not connected; 400 on a missing
            budget_id, negative max_budget/soft_budget, or invalid
            model_max_budget.
    """
    # Lazy import to avoid a circular import with proxy_server.
    from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )
    if budget_obj.budget_id is None:
        raise HTTPException(status_code=400, detail={"error": "budget_id is required"})

    # Validate budget values are not negative
    if budget_obj.max_budget is not None and budget_obj.max_budget < 0:
        raise HTTPException(
            status_code=400,
            detail={
                "error": f"max_budget cannot be negative. Received: {budget_obj.max_budget}"
            },
        )
    if budget_obj.soft_budget is not None and budget_obj.soft_budget < 0:
        raise HTTPException(
            status_code=400,
            detail={
                "error": f"soft_budget cannot be negative. Received: {budget_obj.soft_budget}"
            },
        )

    # Validate model_max_budget if present in update
    if budget_obj.model_max_budget is not None and len(budget_obj.model_max_budget) > 0:
        # Local import to keep this heavy module out of the import cycle.
        from litellm.proxy.management_endpoints.key_management_endpoints import (
            validate_model_max_budget,
        )

        try:
            validate_model_max_budget(budget_obj.model_max_budget)
        except ValueError as e:
            raise HTTPException(status_code=400, detail={"error": str(e)})

    # exclude_unset: only the fields the caller actually sent are written.
    response = await prisma_client.db.litellm_budgettable.update(
        where={"budget_id": budget_obj.budget_id},
        data={
            **budget_obj.model_dump(exclude_unset=True),  # type: ignore
            "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
        },  # type: ignore
    )

    return response
|
||||
|
||||
|
||||
@router.post(
    "/budget/info",
    tags=["budget management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def info_budget(data: BudgetRequest):
    """
    Get the budget id specific information

    Parameters:
    - budgets: List[str] - The list of budget ids to get information for
    """
    from litellm.proxy.proxy_server import prisma_client

    # Guard clauses: DB connectivity first, then request validity.
    if prisma_client is None:
        raise HTTPException(status_code=500, detail={"error": "No db connected"})

    if not data.budgets:
        raise HTTPException(
            status_code=400,
            detail={
                "error": f"Specify list of budget id's to query. Passed in={data.budgets}"
            },
        )

    budget_rows = await prisma_client.db.litellm_budgettable.find_many(
        where={"budget_id": {"in": data.budgets}},
    )
    return budget_rows
|
||||
|
||||
|
||||
@router.get(
    "/budget/settings",
    tags=["budget management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def budget_settings(
    budget_id: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get list of configurable params + current value for a budget item + description of each field

    Used on Admin UI.

    Query Parameters:
    - budget_id: str - The budget id to get information for
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=400,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    # Only proxy admins may inspect budget settings.
    if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
        raise HTTPException(
            status_code=400,
            detail={
                "error": "{}, your role={}".format(
                    CommonProxyErrors.not_allowed_access.value,
                    user_api_key_dict.user_role,
                )
            },
        )

    # Load the stored budget row (if any) so current values can be reported.
    budget_row = await prisma_client.db.litellm_budgettable.find_first(
        where={"budget_id": budget_id}
    )
    stored_values = (
        budget_row.model_dump(exclude_none=True) if budget_row is not None else {}
    )

    # Params the Admin UI may configure, with their UI display types.
    configurable_params = {
        "max_parallel_requests": {"type": "Integer"},
        "tpm_limit": {"type": "Integer"},
        "rpm_limit": {"type": "Integer"},
        "budget_duration": {"type": "String"},
        "max_budget": {"type": "Float"},
        "soft_budget": {"type": "Float"},
        "model_max_budget": {"type": "Object"},
    }

    config_items = []
    for field_name, field_info in BudgetNewRequest.model_fields.items():
        if field_name not in configurable_params:
            continue
        config_items.append(
            ConfigList(
                field_name=field_name,
                field_type=configurable_params[field_name]["type"],
                field_description=field_info.description or "",
                field_value=stored_values.get(field_name, None),
                stored_in_db=True,
                field_default_value=field_info.default,
            )
        )

    return config_items
|
||||
|
||||
|
||||
@router.get(
    "/budget/list",
    tags=["budget management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def list_budget(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """List all the created budgets in proxy db. Used on Admin UI."""
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=400,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    # Admin-only endpoint; the f-string renders the same message as the
    # .format() call used elsewhere in this file.
    if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
        raise HTTPException(
            status_code=400,
            detail={
                "error": f"{CommonProxyErrors.not_allowed_access.value}, your role={user_api_key_dict.user_role}"
            },
        )

    return await prisma_client.db.litellm_budgettable.find_many()
|
||||
|
||||
|
||||
@router.post(
    "/budget/delete",
    tags=["budget management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def delete_budget(
    data: BudgetDeleteRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Delete budget

    Parameters:
    - id: str - The budget id to delete
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    # Only proxy admins may delete budgets.
    if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
        raise HTTPException(
            status_code=400,
            detail={
                "error": f"{CommonProxyErrors.not_allowed_access.value}, your role={user_api_key_dict.user_role}"
            },
        )

    deleted_row = await prisma_client.db.litellm_budgettable.delete(
        where={"budget_id": data.id}
    )
    return deleted_row
|
||||
@@ -0,0 +1,365 @@
|
||||
"""
|
||||
CACHE SETTINGS MANAGEMENT
|
||||
|
||||
Endpoints for managing cache configuration
|
||||
|
||||
GET /cache/settings - Get cache configuration including available settings
|
||||
POST /cache/settings/test - Test cache connection with provided credentials
|
||||
POST /cache/settings - Save cache settings to database
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.types.management_endpoints import (
|
||||
CACHE_SETTINGS_FIELDS,
|
||||
REDIS_TYPE_DESCRIPTIONS,
|
||||
CacheSettingsField,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class CacheSettingsManager:
    """
    Manages cache settings initialization and updates.

    Tracks the last cache params used to initialize the cache so repeated
    calls (e.g. on config re-reads) skip unnecessary reinitialization.
    """

    # Snapshot of the params the cache was last initialized with (class-wide state).
    _last_cache_params: Optional[Dict[str, Any]] = None

    @staticmethod
    def _cache_params_equal(params1: Dict[str, Any], params2: Dict[str, Any]) -> bool:
        """
        Compare two cache parameter dictionaries for equality.

        Normalizes values (scalars stringified so e.g. 6379 == "6379") and
        drops None values plus the UI-only `redis_type` field before comparing.
        """

        def normalize(params: Dict[str, Any]) -> Dict[str, Any]:
            normalized = {}
            for k, v in params.items():
                if k == "redis_type":  # UI-only field, never a real cache param
                    continue
                if v is not None:
                    # Stringify scalars so typed vs. string-typed values compare equal
                    normalized[k] = str(v) if not isinstance(v, (list, dict)) else v
            return normalized

        return normalize(params1) == normalize(params2)

    @staticmethod
    async def init_cache_settings_in_db(prisma_client, proxy_config):
        """
        Initialize cache settings from database into the router on startup.

        Only reinitializes if cache params have changed since the last call.
        Any failure is logged and swallowed so startup stays best-effort.
        """
        # NOTE: the redundant function-local `import json` was removed — the
        # module already imports json at the top of the file.
        try:
            cache_config = await prisma_client.db.litellm_cacheconfig.find_unique(
                where={"id": "cache_config"}
            )
            if cache_config is not None and cache_config.cache_settings:
                # Settings may be stored as a JSON string or already as a dict.
                cache_settings_json = cache_config.cache_settings
                if isinstance(cache_settings_json, str):
                    cache_settings_dict = json.loads(cache_settings_json)
                else:
                    cache_settings_dict = cache_settings_json

                # Decrypt values that were encrypted before being stored.
                decrypted_settings = proxy_config._decrypt_db_variables(
                    variables_dict=cache_settings_dict
                )

                # Remove redis_type if present (UI-only field, not a Cache
                # parameter); it is re-derived for the UI in get_cache_settings.
                cache_params = {
                    k: v for k, v in decrypted_settings.items() if k != "redis_type"
                }

                # Skip reinitialization when the effective params are unchanged.
                if (
                    CacheSettingsManager._last_cache_params is not None
                    and CacheSettingsManager._cache_params_equal(
                        CacheSettingsManager._last_cache_params, cache_params
                    )
                ):
                    verbose_proxy_logger.debug(
                        "Cache settings unchanged, skipping reinitialization"
                    )
                    return

                # Initialize cache only if params changed or cache not initialized.
                proxy_config._init_cache(cache_params=cache_params)

                # Remember what we just initialized with (copy: caller may mutate).
                CacheSettingsManager._last_cache_params = cache_params.copy()

                # Switch on LLM response caching.
                proxy_config.switch_on_llm_response_caching()

                verbose_proxy_logger.info("Cache settings initialized from database")
        except Exception as e:
            verbose_proxy_logger.exception(
                "litellm.proxy.management_endpoints.cache_settings_endpoints.py::CacheSettingsManager::init_cache_settings_in_db - {}".format(
                    str(e)
                )
            )

    @staticmethod
    def update_cache_params(cache_params: Dict[str, Any]):
        """
        Record *cache_params* as the last-initialized params.

        Called after cache settings are updated via the API so the next
        init_cache_settings_in_db call detects "unchanged" correctly.
        """
        CacheSettingsManager._last_cache_params = cache_params.copy()
|
||||
|
||||
|
||||
# Response body for GET /cache/settings: field metadata, current values,
# and per-Redis-type help text for the Admin UI.
class CacheSettingsResponse(BaseModel):
    fields: List[CacheSettingsField] = Field(
        description="List of all configurable cache settings with metadata"
    )
    current_values: Dict[str, Any] = Field(
        description="Current values of cache settings"
    )
    redis_type_descriptions: Dict[str, str] = Field(
        description="Descriptions for each Redis type option"
    )
|
||||
|
||||
|
||||
# Request body for POST /cache/settings/test: raw settings dict to probe with.
class CacheTestRequest(BaseModel):
    cache_settings: Dict[str, Any] = Field(
        description="Cache settings to test connection with"
    )
|
||||
|
||||
|
||||
# Result of a cache connection probe; `error` is populated only on failure.
class CacheTestResponse(BaseModel):
    status: str = Field(description="Connection status: 'success' or 'failed'")
    message: str = Field(description="Connection result message")
    error: Optional[str] = Field(
        default=None, description="Error message if connection failed"
    )
|
||||
|
||||
|
||||
# Request body for POST /cache/settings: full settings dict to persist.
class CacheSettingsUpdateRequest(BaseModel):
    cache_settings: Dict[str, Any] = Field(description="Cache settings to save")
|
||||
|
||||
|
||||
@router.get(
    "/cache/settings",
    tags=["Cache Settings"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=CacheSettingsResponse,
)
async def get_cache_settings(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get cache configuration and available settings.

    Returns:
    - fields: List of all configurable cache settings with their metadata (type, description, default, options)
    - current_values: Current values of cache settings from database
    """
    from litellm.proxy.proxy_server import prisma_client, proxy_config

    try:
        # Deep-copy the field templates so mutating field_value below does not
        # leak into the shared CACHE_SETTINGS_FIELDS module constant.
        cache_fields = [field.model_copy(deep=True) for field in CACHE_SETTINGS_FIELDS]

        # Try to get cache settings from database. Without a DB connection we
        # still return the field templates with empty current values.
        current_values = {}
        if prisma_client is not None:
            # Single-row config table keyed by the fixed id "cache_config".
            cache_config = await prisma_client.db.litellm_cacheconfig.find_unique(
                where={"id": "cache_config"}
            )
            if cache_config is not None and cache_config.cache_settings:
                # Settings may be stored as a JSON string or already as a dict.
                cache_settings_json = cache_config.cache_settings
                if isinstance(cache_settings_json, str):
                    cache_settings_dict = json.loads(cache_settings_json)
                else:
                    cache_settings_dict = cache_settings_json

                # Decrypt values that were encrypted before being stored.
                decrypted_settings = proxy_config._decrypt_db_variables(
                    variables_dict=cache_settings_dict
                )

                # Derive redis_type for UI based on settings.
                # UI uses redis_type to show/hide fields, backend only stores 'type';
                # presence of startup/sentinel nodes tells the topology apart.
                if decrypted_settings.get("type") == "redis":
                    if decrypted_settings.get("redis_startup_nodes"):
                        decrypted_settings["redis_type"] = "cluster"
                    elif decrypted_settings.get("sentinel_nodes"):
                        decrypted_settings["redis_type"] = "sentinel"
                    else:
                        decrypted_settings["redis_type"] = "node"

                current_values = decrypted_settings

        # Overlay stored values onto the field templates for the UI.
        for field in cache_fields:
            if field.field_name in current_values:
                field.field_value = current_values[field.field_name]

        return CacheSettingsResponse(
            fields=cache_fields,
            current_values=current_values,
            redis_type_descriptions=REDIS_TYPE_DESCRIPTIONS,
        )
    except Exception as e:
        # Surface any DB/decrypt failure as a 500 with the underlying message.
        verbose_proxy_logger.error(f"Error fetching cache settings: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Error fetching cache settings: {str(e)}"
        )
|
||||
|
||||
|
||||
@router.post(
    "/cache/settings/test",
    tags=["Cache Settings"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=CacheTestResponse,
)
async def test_cache_connection(
    request: CacheTestRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Test cache connection with provided credentials.

    Builds a throwaway Cache instance from the submitted settings and runs
    its test_connection method, so global cache state is never touched.
    """
    from litellm import Cache

    try:
        cache_settings = request.cache_settings.copy()
        verbose_proxy_logger.debug(
            "Testing cache connection with settings: %s", cache_settings
        )

        # Only Redis connections can be probed today.
        if cache_settings.get("type") != "redis":
            return CacheTestResponse(
                status="failed",
                message="Only Redis cache type is currently supported for testing",
            )

        # Throwaway instance — does not replace the proxy's active cache.
        probe_cache = Cache(**cache_settings)
        connection_result = await probe_cache.cache.test_connection()
        return CacheTestResponse(**connection_result)

    except Exception as e:
        # Connection problems are reported as a failed result, not an HTTP error.
        verbose_proxy_logger.error(f"Error testing cache connection: {str(e)}")
        return CacheTestResponse(
            status="failed",
            message=f"Cache connection test failed: {str(e)}",
            error=str(e),
        )
|
||||
|
||||
|
||||
@router.post(
    "/cache/settings",
    tags=["Cache Settings"],
    dependencies=[Depends(user_api_key_auth)],
)
async def update_cache_settings(
    request: CacheSettingsUpdateRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Save cache settings to database and initialize cache.

    This endpoint:
    1. Encrypts sensitive fields (passwords, etc.)
    2. Saves to LiteLLM_CacheConfig table
    3. Reinitializes cache with new settings
    """
    from litellm.proxy.proxy_server import (
        prisma_client,
        proxy_config,
        store_model_in_db,
    )

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": "Database not connected. Please connect a database."},
        )

    # Persisting config to the DB is gated behind STORE_MODEL_IN_DB.
    if store_model_in_db is not True:
        raise HTTPException(
            status_code=500,
            detail={
                "error": "Set `'STORE_MODEL_IN_DB='True'` in your env to enable this feature."
            },
        )

    try:
        cache_settings = request.cache_settings.copy()

        # Encrypt sensitive fields (keep redis_type for storage)
        encrypted_settings = proxy_config._encrypt_env_variables(
            environment_variables=cache_settings
        )

        # Save to database (single-row table keyed by the fixed id "cache_config";
        # upsert covers both first save and subsequent updates)
        await prisma_client.db.litellm_cacheconfig.upsert(
            where={"id": "cache_config"},
            data={
                "create": {
                    "id": "cache_config",
                    "cache_settings": json.dumps(encrypted_settings),
                },
                "update": {
                    "cache_settings": json.dumps(encrypted_settings),
                },
            },
        )

        # Reinitialize cache with new settings
        # Decrypt for initialization — round-trips the exact values just stored
        decrypted_settings = proxy_config._decrypt_db_variables(
            variables_dict=encrypted_settings
        )

        # Remove redis_type if present (UI-only field, not a Cache parameter)
        cache_params = {
            k: v for k, v in decrypted_settings.items() if k != "redis_type"
        }

        # Initialize cache (frontend sends type="redis", not redis_type)
        proxy_config._init_cache(cache_params=cache_params)

        # Update the last cache params to avoid reinitializing unnecessarily
        CacheSettingsManager.update_cache_params(cache_params)

        # Switch on LLM response caching
        proxy_config.switch_on_llm_response_caching()

        # NOTE: returns the plaintext settings as submitted, not the encrypted form.
        return {
            "message": "Cache settings updated successfully",
            "status": "success",
            "settings": cache_settings,
        }
    except Exception as e:
        verbose_proxy_logger.error(f"Error updating cache settings: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Error updating cache settings: {str(e)}"
        )
|
||||
@@ -0,0 +1,54 @@
|
||||
"""
|
||||
Endpoints for managing callbacks
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from litellm.litellm_core_utils.logging_callback_manager import CallbacksByType
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get(
    "/callbacks/list",
    tags=["Logging Callbacks"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=CallbacksByType,
)
async def list_callbacks():
    """
    View List of Active Logging Callbacks
    """
    from litellm import logging_callback_manager

    # The callback manager already groups active callbacks by type for us.
    return logging_callback_manager.get_callbacks_by_type()
|
||||
|
||||
|
||||
@router.get(
    "/callbacks/configs",
    tags=["Logging Callbacks"],
    dependencies=[Depends(user_api_key_auth)],
)
async def get_callback_configs():
    """
    Get Available Callback Configurations

    Returns the configuration details for all available logging callbacks,
    including supported parameters, field types, and descriptions.

    Reads litellm/integrations/callback_configs.json relative to this module.
    """
    # Resolve <litellm package root>/integrations/callback_configs.json
    # (three dirname calls climb from this module to the package root).
    config_path = os.path.join(
        os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
        "integrations",
        "callback_configs.json",
    )

    # Fix: open with explicit UTF-8 — relying on the platform default encoding
    # (e.g. cp1252 on Windows) can mis-decode the JSON file.
    with open(config_path, "r", encoding="utf-8") as f:
        configs = json.load(f)

    return configs
|
||||
@@ -0,0 +1,814 @@
|
||||
from datetime import datetime, timedelta
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import CommonProxyErrors
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from litellm.types.proxy.management_endpoints.common_daily_activity import (
|
||||
BreakdownMetrics,
|
||||
DailySpendData,
|
||||
DailySpendMetadata,
|
||||
KeyMetadata,
|
||||
KeyMetricWithMetadata,
|
||||
MetricWithMetadata,
|
||||
SpendAnalyticsPaginatedResponse,
|
||||
SpendMetrics,
|
||||
)
|
||||
|
||||
# Mapping from Prisma accessor names to actual PostgreSQL table names.
# Prisma exposes models via lowercase accessors (e.g. db.litellm_dailyuserspend)
# while the underlying tables keep their mixed-case names.
# NOTE(review): presumably used when building raw SQL against these tables —
# confirm at the call sites further down the file.
_PRISMA_TO_PG_TABLE: Dict[str, str] = {
    "litellm_dailyuserspend": "LiteLLM_DailyUserSpend",
    "litellm_dailyteamspend": "LiteLLM_DailyTeamSpend",
    "litellm_dailyorganizationspend": "LiteLLM_DailyOrganizationSpend",
    "litellm_dailyenduserspend": "LiteLLM_DailyEndUserSpend",
    "litellm_dailyagentspend": "LiteLLM_DailyAgentSpend",
    "litellm_dailytagspend": "LiteLLM_DailyTagSpend",
}
|
||||
|
||||
|
||||
def update_metrics(existing_metrics: SpendMetrics, record: Any) -> SpendMetrics:
    """Accumulate one daily-spend record into *existing_metrics* (mutates and returns it)."""
    # Counters that are summed 1:1 from the record onto the running totals.
    for attr in (
        "spend",
        "prompt_tokens",
        "completion_tokens",
        "cache_read_input_tokens",
        "cache_creation_input_tokens",
        "api_requests",
        "successful_requests",
        "failed_requests",
    ):
        setattr(
            existing_metrics, attr, getattr(existing_metrics, attr) + getattr(record, attr)
        )
    # total_tokens is derived from the record's prompt + completion counts.
    existing_metrics.total_tokens += record.prompt_tokens + record.completion_tokens
    return existing_metrics
|
||||
|
||||
|
||||
def _is_user_agent_tag(tag: Optional[str]) -> bool:
|
||||
"""Determine whether a tag should be treated as a User-Agent tag."""
|
||||
if not tag:
|
||||
return False
|
||||
normalized_tag = tag.strip().lower()
|
||||
return normalized_tag.startswith("user-agent:") or normalized_tag.startswith(
|
||||
"user agent:"
|
||||
)
|
||||
|
||||
|
||||
def compute_tag_metadata_totals(records: List[Any]) -> SpendMetrics:
    """
    Deduplicate spend metrics for tags using request_id, ignoring User-Agent prefixed tags.

    Each unique request_id contributes at most one record (the tag with max spend) to metadata.
    """
    # Pick, per request_id, the record with the highest spend; records with no
    # request_id or with a User-Agent tag are excluded entirely.
    best_per_request: Dict[str, Any] = {}
    for rec in records:
        rid = getattr(rec, "request_id", None)
        if not rid or _is_user_agent_tag(getattr(rec, "tag", None)):
            continue
        incumbent = best_per_request.get(rid)
        if incumbent is None or rec.spend > incumbent.spend:
            best_per_request[rid] = rec

    # Sum the surviving representatives into a single metrics object.
    totals = SpendMetrics()
    for rec in best_per_request.values():
        update_metrics(totals, rec)
    return totals
|
||||
|
||||
|
||||
def _key_metadata_for(
    api_key_metadata: Dict[str, Dict[str, Any]], api_key: str
) -> KeyMetadata:
    """Build KeyMetadata (key_alias / team_id) for *api_key* from the metadata lookup."""
    info = api_key_metadata.get(api_key, {})
    return KeyMetadata(
        key_alias=info.get("key_alias", None),
        team_id=info.get("team_id", None),
    )


def _update_api_key_breakdown(
    container: Any, record: Any, api_key_metadata: Dict[str, Dict[str, Any]]
) -> None:
    """Get-or-create container.api_key_breakdown[record.api_key] and accumulate *record* into it."""
    if record.api_key not in container.api_key_breakdown:
        container.api_key_breakdown[record.api_key] = KeyMetricWithMetadata(
            metrics=SpendMetrics(),
            metadata=_key_metadata_for(api_key_metadata, record.api_key),
        )
    container.api_key_breakdown[record.api_key].metrics = update_metrics(
        container.api_key_breakdown[record.api_key].metrics, record
    )


def update_breakdown_metrics(
    breakdown: BreakdownMetrics,
    record: Any,
    model_metadata: Dict[str, Dict[str, Any]],
    provider_metadata: Dict[str, Dict[str, Any]],
    api_key_metadata: Dict[str, Dict[str, Any]],
    entity_id_field: Optional[str] = None,
    entity_metadata_field: Optional[Dict[str, dict]] = None,
) -> BreakdownMetrics:
    """Updates breakdown metrics for a single record using the existing update_metrics function.

    The original inlined six near-identical copies of the per-api-key
    accumulation; that pattern is factored into _update_api_key_breakdown.
    """

    # Model breakdown (+ per-api-key breakdown nested under the model)
    if record.model:
        if record.model not in breakdown.models:
            breakdown.models[record.model] = MetricWithMetadata(
                metrics=SpendMetrics(),
                metadata=model_metadata.get(record.model, {}),
            )
        breakdown.models[record.model].metrics = update_metrics(
            breakdown.models[record.model].metrics, record
        )
        _update_api_key_breakdown(
            breakdown.models[record.model], record, api_key_metadata
        )

    # Model-group breakdown (metadata looked up under the group name)
    if record.model_group:
        if record.model_group not in breakdown.model_groups:
            breakdown.model_groups[record.model_group] = MetricWithMetadata(
                metrics=SpendMetrics(),
                metadata=model_metadata.get(record.model_group, {}),
            )
        breakdown.model_groups[record.model_group].metrics = update_metrics(
            breakdown.model_groups[record.model_group].metrics, record
        )
        _update_api_key_breakdown(
            breakdown.model_groups[record.model_group], record, api_key_metadata
        )

    # MCP server breakdown (keyed by the namespaced tool name)
    if record.mcp_namespaced_tool_name:
        tool_name = record.mcp_namespaced_tool_name
        if tool_name not in breakdown.mcp_servers:
            breakdown.mcp_servers[tool_name] = MetricWithMetadata(
                metrics=SpendMetrics(),
                metadata={},
            )
        breakdown.mcp_servers[tool_name].metrics = update_metrics(
            breakdown.mcp_servers[tool_name].metrics, record
        )
        _update_api_key_breakdown(
            breakdown.mcp_servers[tool_name], record, api_key_metadata
        )

    # Provider breakdown (always updated; null provider bucketed as "unknown")
    provider = record.custom_llm_provider or "unknown"
    if provider not in breakdown.providers:
        breakdown.providers[provider] = MetricWithMetadata(
            metrics=SpendMetrics(),
            metadata=provider_metadata.get(provider, {}),
        )
    breakdown.providers[provider].metrics = update_metrics(
        breakdown.providers[provider].metrics, record
    )
    _update_api_key_breakdown(breakdown.providers[provider], record, api_key_metadata)

    # Endpoint breakdown
    if record.endpoint:
        if record.endpoint not in breakdown.endpoints:
            breakdown.endpoints[record.endpoint] = MetricWithMetadata(
                metrics=SpendMetrics(),
                metadata={},
            )
        breakdown.endpoints[record.endpoint].metrics = update_metrics(
            breakdown.endpoints[record.endpoint].metrics, record
        )
        _update_api_key_breakdown(
            breakdown.endpoints[record.endpoint], record, api_key_metadata
        )

    # Top-level api-key breakdown
    if record.api_key not in breakdown.api_keys:
        breakdown.api_keys[record.api_key] = KeyMetricWithMetadata(
            metrics=SpendMetrics(),
            metadata=_key_metadata_for(api_key_metadata, record.api_key),
        )
    breakdown.api_keys[record.api_key].metrics = update_metrics(
        breakdown.api_keys[record.api_key].metrics, record
    )

    # Entity-specific metrics (team/org/tag/... chosen by entity_id_field)
    if entity_id_field:
        # A missing/falsy entity id is bucketed as "Unassigned".
        entity_value = getattr(record, entity_id_field, None) or "Unassigned"
        if entity_value not in breakdown.entities:
            breakdown.entities[entity_value] = MetricWithMetadata(
                metrics=SpendMetrics(),
                metadata=(
                    entity_metadata_field.get(entity_value, {})
                    if entity_metadata_field
                    else {}
                ),
            )
        breakdown.entities[entity_value].metrics = update_metrics(
            breakdown.entities[entity_value].metrics, record
        )
        _update_api_key_breakdown(
            breakdown.entities[entity_value], record, api_key_metadata
        )

    return breakdown
|
||||
|
||||
|
||||
async def get_api_key_metadata(
    prisma_client: PrismaClient,
    api_keys: Set[str],
) -> Dict[str, Dict[str, Any]]:
    """Look up ``key_alias`` / ``team_id`` for a set of API keys.

    Falls back to the deleted-keys table for any key missing from the active
    table, so key_alias and team_id are preserved in historical activity logs
    even after a key is deleted or regenerated.

    Args:
        prisma_client: Connected prisma client.
        api_keys: Hashed token values to resolve.

    Returns:
        Mapping of token -> {"key_alias": ..., "team_id": ...}.
    """
    active_records = await prisma_client.db.litellm_verificationtoken.find_many(
        where={"token": {"in": list(api_keys)}}
    )
    metadata: Dict[str, Dict[str, Any]] = {}
    for rec in active_records:
        metadata[rec.token] = {"key_alias": rec.key_alias, "team_id": rec.team_id}

    not_found = api_keys - set(metadata.keys())
    if not not_found:
        return metadata

    # Any key missing from the active table may have been deleted/regenerated;
    # check the deleted-keys table as a best-effort fallback.
    try:
        deleted_records = (
            await prisma_client.db.litellm_deletedverificationtoken.find_many(
                where={"token": {"in": list(not_found)}},
                order={"deleted_at": "desc"},
            )
        )
        # Records arrive ordered by deleted_at desc, so the first occurrence
        # of a token is its most recent deletion — setdefault keeps that one.
        for rec in deleted_records:
            metadata.setdefault(
                rec.token,
                {"key_alias": rec.key_alias, "team_id": rec.team_id},
            )
    except Exception as e:
        # Best-effort: missing deleted-key metadata should not fail the query.
        verbose_proxy_logger.warning(
            "Failed to fetch deleted key metadata for %d missing keys: %s",
            len(not_found),
            e,
        )

    return metadata
def _adjust_dates_for_timezone(
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
timezone_offset_minutes: Optional[int],
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
Adjust date range to account for timezone differences.
|
||||
|
||||
The database stores dates in UTC. When a user in a different timezone
|
||||
selects a local date range, we need to expand the UTC query range to
|
||||
capture all records that fall within their local date range.
|
||||
|
||||
Args:
|
||||
start_date: Start date in YYYY-MM-DD format (user's local date)
|
||||
end_date: End date in YYYY-MM-DD format (user's local date)
|
||||
timezone_offset_minutes: Minutes behind UTC (positive = west of UTC)
|
||||
This matches JavaScript's Date.getTimezoneOffset() convention.
|
||||
For example: PST = +480 (8 hours * 60 = 480 minutes behind UTC)
|
||||
|
||||
Returns:
|
||||
Tuple of (adjusted_start_date, adjusted_end_date) in YYYY-MM-DD format
|
||||
"""
|
||||
if timezone_offset_minutes is None or timezone_offset_minutes == 0:
|
||||
return start_date, end_date
|
||||
|
||||
start = datetime.strptime(start_date, "%Y-%m-%d")
|
||||
end = datetime.strptime(end_date, "%Y-%m-%d")
|
||||
|
||||
if timezone_offset_minutes > 0:
|
||||
# West of UTC (Americas): local evening extends into next UTC day
|
||||
# e.g., Feb 4 23:59 PST = Feb 5 07:59 UTC
|
||||
end = end + timedelta(days=1)
|
||||
else:
|
||||
# East of UTC (Asia/Europe): local morning starts in previous UTC day
|
||||
# e.g., Feb 4 00:00 IST = Feb 3 18:30 UTC
|
||||
start = start - timedelta(days=1)
|
||||
|
||||
return start.strftime("%Y-%m-%d"), end.strftime("%Y-%m-%d")
|
||||
|
||||
|
||||
def _build_where_conditions(
    *,
    entity_id_field: str,
    entity_id: Optional[Union[str, List[str]]],
    start_date: str,
    end_date: str,
    model: Optional[str],
    api_key: Optional[Union[str, List[str]]],
    exclude_entity_ids: Optional[List[str]] = None,
    timezone_offset_minutes: Optional[int] = None,
) -> Dict[str, Any]:
    """Build a prisma ``where`` clause for daily activity queries.

    The date range is always present; model, api_key, entity, and
    entity-exclusion filters are added only when supplied.
    """
    # Expand the UTC date range to cover the caller's local-time range.
    adj_start, adj_end = _adjust_dates_for_timezone(
        start_date, end_date, timezone_offset_minutes
    )

    where: Dict[str, Any] = {
        "date": {
            "gte": adj_start,
            "lte": adj_end,
        }
    }

    if model:
        where["model"] = model

    if api_key:
        # A list becomes an IN filter; a single key is matched directly.
        where["api_key"] = {"in": api_key} if isinstance(api_key, list) else api_key

    if entity_id is not None:
        if isinstance(entity_id, list):
            where[entity_id_field] = {"in": entity_id}
        else:
            where[entity_id_field] = {"equals": entity_id}

    if exclude_entity_ids:
        # Merge the NOT IN exclusion into any existing entity filter.
        entity_filter = where.get(entity_id_field, {})
        if isinstance(entity_filter, str):
            entity_filter = {"equals": entity_filter}
        entity_filter["not"] = {"in": exclude_entity_ids}
        where[entity_id_field] = entity_filter

    return where
def _build_aggregated_sql_query(
    *,
    table_name: str,
    entity_id_field: str,
    entity_id: Optional[Union[str, List[str]]],
    start_date: str,
    end_date: str,
    model: Optional[str],
    api_key: Optional[str],
    exclude_entity_ids: Optional[List[str]] = None,
    timezone_offset_minutes: Optional[int] = None,
) -> Tuple[str, List[Any]]:
    """Build a parameterized SQL GROUP BY query for aggregated daily activity.

    Groups by (date, api_key, model, model_group, custom_llm_provider,
    mcp_namespaced_tool_name, endpoint) with SUMs on all metric columns.
    The entity_id column is intentionally omitted from GROUP BY to collapse
    rows across entities — this is where the biggest row reduction comes from.

    Args:
        table_name: Prisma model name, mapped to a Postgres table via
            ``_PRISMA_TO_PG_TABLE``; unknown names raise ValueError.
        entity_id_field: Column name used for entity filters. NOTE(review):
            this identifier is interpolated directly into the SQL string —
            safe only as long as every caller passes a hard-coded column
            name, never user input; confirm at call sites.
        entity_id: Single entity id or list of ids to include.
        start_date / end_date: Local YYYY-MM-DD range; widened for timezone.
        model / api_key: Optional equality filters.
        exclude_entity_ids: Entity ids to exclude (NOT IN).
        timezone_offset_minutes: Minutes behind UTC (JS getTimezoneOffset
            convention), used to widen the UTC date range.

    Returns:
        Tuple of (sql_query, params_list) ready for prisma_client.db.query_raw().
    """
    pg_table = _PRISMA_TO_PG_TABLE.get(table_name)
    if pg_table is None:
        raise ValueError(f"Unknown table name: {table_name}")

    adjusted_start, adjusted_end = _adjust_dates_for_timezone(
        start_date, end_date, timezone_offset_minutes
    )

    sql_conditions: List[str] = []
    sql_params: List[Any] = []
    # Positional placeholder counter. Every condition appended below must bump
    # `p` by exactly the number of params it adds, in the same order the
    # params are appended — PostgreSQL $N placeholders are position-sensitive.
    p = 1  # parameter index (1-based for PostgreSQL $N placeholders)

    # Date range (always present)
    sql_conditions.append(f"date >= ${p}")
    sql_params.append(adjusted_start)
    p += 1

    sql_conditions.append(f"date <= ${p}")
    sql_params.append(adjusted_end)
    p += 1

    # Optional entity filter
    if entity_id is not None:
        if isinstance(entity_id, list):
            # One placeholder per list element: $p, $p+1, ...
            placeholders = ", ".join(f"${p + i}" for i in range(len(entity_id)))
            sql_conditions.append(f'"{entity_id_field}" IN ({placeholders})')
            sql_params.extend(entity_id)
            p += len(entity_id)
        else:
            sql_conditions.append(f'"{entity_id_field}" = ${p}')
            sql_params.append(entity_id)
            p += 1

    # Exclude specific entities
    if exclude_entity_ids:
        placeholders = ", ".join(f"${p + i}" for i in range(len(exclude_entity_ids)))
        sql_conditions.append(f'"{entity_id_field}" NOT IN ({placeholders})')
        sql_params.extend(exclude_entity_ids)
        p += len(exclude_entity_ids)

    # Optional model filter
    if model:
        sql_conditions.append(f"model = ${p}")
        sql_params.append(model)
        p += 1

    # Optional api_key filter
    if api_key:
        sql_conditions.append(f"api_key = ${p}")
        sql_params.append(api_key)
        p += 1

    where_clause = " AND ".join(sql_conditions)

    # Casts keep result types stable across drivers: money as float,
    # token/request counts as bigint.
    sql_query = f"""
        SELECT
            date,
            api_key,
            model,
            model_group,
            custom_llm_provider,
            mcp_namespaced_tool_name,
            endpoint,
            SUM(spend)::float AS spend,
            SUM(prompt_tokens)::bigint AS prompt_tokens,
            SUM(completion_tokens)::bigint AS completion_tokens,
            SUM(cache_read_input_tokens)::bigint AS cache_read_input_tokens,
            SUM(cache_creation_input_tokens)::bigint AS cache_creation_input_tokens,
            SUM(api_requests)::bigint AS api_requests,
            SUM(successful_requests)::bigint AS successful_requests,
            SUM(failed_requests)::bigint AS failed_requests
        FROM "{pg_table}"
        WHERE {where_clause}
        GROUP BY date, api_key, model, model_group, custom_llm_provider,
                 mcp_namespaced_tool_name, endpoint
        ORDER BY date DESC
    """

    return sql_query, sql_params
async def _aggregate_spend_records(
    *,
    prisma_client: PrismaClient,
    records: List[Any],
    entity_id_field: Optional[str],
    entity_metadata_field: Optional[Dict[str, dict]],
) -> Dict[str, Any]:
    """Aggregate raw daily rows into per-date ``DailySpendData`` plus totals.

    Returns:
        {"results": List[DailySpendData] sorted newest-first,
         "totals": SpendMetrics summed over every record}
    """
    # Resolve alias/team metadata once for every key that appears in the rows.
    seen_keys: Set[str] = {r.api_key for r in records if r.api_key}

    api_key_metadata: Dict[str, Dict[str, Any]] = {}
    model_metadata: Dict[str, Dict[str, Any]] = {}
    provider_metadata: Dict[str, Dict[str, Any]] = {}
    if seen_keys:
        api_key_metadata = await get_api_key_metadata(prisma_client, seen_keys)

    totals = SpendMetrics()
    by_date: Dict[str, Dict[str, Any]] = {}

    for record in records:
        bucket = by_date.setdefault(
            record.date,
            {"metrics": SpendMetrics(), "breakdown": BreakdownMetrics()},
        )

        bucket["metrics"] = update_metrics(bucket["metrics"], record)
        bucket["breakdown"] = update_breakdown_metrics(
            bucket["breakdown"],
            record,
            model_metadata,
            provider_metadata,
            api_key_metadata,
            entity_id_field=entity_id_field,
            entity_metadata_field=entity_metadata_field,
        )

        totals = update_metrics(totals, record)

    daily = [
        DailySpendData(
            date=datetime.strptime(date_str, "%Y-%m-%d").date(),
            metrics=data["metrics"],
            breakdown=data["breakdown"],
        )
        for date_str, data in by_date.items()
    ]
    daily.sort(key=lambda d: d.date, reverse=True)

    return {"results": daily, "totals": totals}
async def get_daily_activity(
    prisma_client: Optional[PrismaClient],
    table_name: str,
    entity_id_field: str,
    entity_id: Optional[Union[str, List[str]]],
    entity_metadata_field: Optional[Dict[str, dict]],
    start_date: Optional[str],
    end_date: Optional[str],
    model: Optional[str],
    api_key: Optional[Union[str, List[str]]],
    page: int,
    page_size: int,
    exclude_entity_ids: Optional[List[str]] = None,
    metadata_metrics_func: Optional[Callable[[List[Any]], SpendMetrics]] = None,
    timezone_offset_minutes: Optional[int] = None,
) -> SpendAnalyticsPaginatedResponse:
    """Common function to get daily activity for any entity type.

    Queries the given daily-spend table through prisma with optional
    entity/model/api_key filters, paginates the raw rows, and aggregates
    them into per-date results plus summary totals.

    Args:
        prisma_client: Connected prisma client; a 500 is raised when None.
        table_name: Prisma model attribute name on ``prisma_client.db``.
        entity_id_field: Column used for entity filtering and breakdowns.
        entity_id: Entity id(s) to include; None means no entity filter.
        entity_metadata_field: Per-entity metadata for breakdown entries.
        start_date / end_date: Required YYYY-MM-DD range (400 when missing).
        model / api_key: Optional filters.
        page / page_size: 1-based pagination over the raw row set.
        exclude_entity_ids: Entity ids to exclude.
        metadata_metrics_func: Optional override that computes summary
            totals from the raw page of rows instead of the aggregated sums.
        timezone_offset_minutes: Minutes behind UTC used to widen the range.

    Raises:
        HTTPException: 500 when the DB is unavailable or the query fails,
            400 when the date range is missing.
    """

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    if start_date is None or end_date is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={"error": "Please provide start_date and end_date"},
        )

    try:
        where_conditions = _build_where_conditions(
            entity_id_field=entity_id_field,
            entity_id=entity_id,
            start_date=start_date,
            end_date=end_date,
            model=model,
            api_key=api_key,
            exclude_entity_ids=exclude_entity_ids,
            timezone_offset_minutes=timezone_offset_minutes,
        )

        # Get total count for pagination
        total_count = await getattr(prisma_client.db, table_name).count(
            where=where_conditions
        )

        # Fetch paginated results
        daily_spend_data = await getattr(prisma_client.db, table_name).find_many(
            where=where_conditions,
            order=[
                {"date": "desc"},
            ],
            skip=(page - 1) * page_size,
            take=page_size,
        )

        # Fold the page of raw rows into per-date metrics and breakdowns.
        aggregated = await _aggregate_spend_records(
            prisma_client=prisma_client,
            records=daily_spend_data,
            entity_id_field=entity_id_field,
            entity_metadata_field=entity_metadata_field,
        )

        # Summary totals default to the aggregated sums; callers may supply a
        # custom function computed from the raw (paginated) rows instead.
        metadata_metrics = aggregated["totals"]
        if metadata_metrics_func:
            metadata_metrics = metadata_metrics_func(daily_spend_data)

        return SpendAnalyticsPaginatedResponse(
            results=aggregated["results"],
            metadata=DailySpendMetadata(
                total_spend=metadata_metrics.spend,
                total_prompt_tokens=metadata_metrics.prompt_tokens,
                total_completion_tokens=metadata_metrics.completion_tokens,
                total_tokens=metadata_metrics.total_tokens,
                total_api_requests=metadata_metrics.api_requests,
                total_successful_requests=metadata_metrics.successful_requests,
                total_failed_requests=metadata_metrics.failed_requests,
                total_cache_read_input_tokens=metadata_metrics.cache_read_input_tokens,
                total_cache_creation_input_tokens=metadata_metrics.cache_creation_input_tokens,
                page=page,
                total_pages=-(-total_count // page_size),  # Ceiling division
                has_more=(page * page_size) < total_count,
            ),
        )

    except Exception as e:
        # Surface any query/aggregation failure as a 500 with the cause.
        verbose_proxy_logger.exception(f"Error fetching daily activity: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={"error": f"Failed to fetch analytics: {str(e)}"},
        )
async def get_daily_activity_aggregated(
    prisma_client: Optional[PrismaClient],
    table_name: str,
    entity_id_field: str,
    entity_id: Optional[Union[str, List[str]]],
    entity_metadata_field: Optional[Dict[str, dict]],
    start_date: Optional[str],
    end_date: Optional[str],
    model: Optional[str],
    api_key: Optional[str],
    exclude_entity_ids: Optional[List[str]] = None,
    timezone_offset_minutes: Optional[int] = None,
) -> SpendAnalyticsPaginatedResponse:
    """Aggregated variant that returns the full result set (no pagination).

    Uses SQL GROUP BY to aggregate rows in the database rather than fetching
    all individual rows into Python. This collapses rows across entities
    (users/teams/orgs), reducing ~150k rows to ~2-3k grouped rows.

    Matches the response model of the paginated endpoint so the UI does not need to transform.

    Raises:
        HTTPException: 500 when the DB is unavailable or the query fails,
            400 when the date range is missing.
    """
    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    if start_date is None or end_date is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={"error": "Please provide start_date and end_date"},
        )

    try:
        sql_query, sql_params = _build_aggregated_sql_query(
            table_name=table_name,
            entity_id_field=entity_id_field,
            entity_id=entity_id,
            start_date=start_date,
            end_date=end_date,
            model=model,
            api_key=api_key,
            exclude_entity_ids=exclude_entity_ids,
            timezone_offset_minutes=timezone_offset_minutes,
        )

        # Execute GROUP BY query — returns pre-aggregated dicts
        rows = await prisma_client.db.query_raw(sql_query, *sql_params)
        if rows is None:
            rows = []

        # Convert dicts to objects for compatibility with _aggregate_spend_records
        # (it reads attributes, e.g. record.api_key / record.date).
        records = [SimpleNamespace(**row) for row in rows]

        # entity_id_field=None skips entity breakdown (entity dimension was
        # collapsed by the GROUP BY, so per-entity data is not available)
        aggregated = await _aggregate_spend_records(
            prisma_client=prisma_client,
            records=records,
            entity_id_field=None,
            entity_metadata_field=None,
        )

        # Single unpaginated page: page=1, total_pages=1, has_more=False.
        return SpendAnalyticsPaginatedResponse(
            results=aggregated["results"],
            metadata=DailySpendMetadata(
                total_spend=aggregated["totals"].spend,
                total_prompt_tokens=aggregated["totals"].prompt_tokens,
                total_completion_tokens=aggregated["totals"].completion_tokens,
                total_tokens=aggregated["totals"].total_tokens,
                total_api_requests=aggregated["totals"].api_requests,
                total_successful_requests=aggregated["totals"].successful_requests,
                total_failed_requests=aggregated["totals"].failed_requests,
                total_cache_read_input_tokens=aggregated[
                    "totals"
                ].cache_read_input_tokens,
                total_cache_creation_input_tokens=aggregated[
                    "totals"
                ].cache_creation_input_tokens,
                page=1,
                total_pages=1,
                has_more=False,
            ),
        )

    except Exception as e:
        # Surface any query/aggregation failure as a 500 with the cause.
        verbose_proxy_logger.exception(
            f"Error fetching aggregated daily activity: {str(e)}"
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={"error": f"Failed to fetch analytics: {str(e)}"},
        )
@@ -0,0 +1,473 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.caching import DualCache
|
||||
from litellm.proxy._types import (
|
||||
KeyRequestBase,
|
||||
LiteLLM_ManagementEndpoint_MetadataFields,
|
||||
LiteLLM_ManagementEndpoint_MetadataFields_Premium,
|
||||
LiteLLM_OrganizationTable,
|
||||
LiteLLM_ProjectTable,
|
||||
LiteLLM_TeamTable,
|
||||
LiteLLM_UserTable,
|
||||
LitellmUserRoles,
|
||||
NewProjectRequest,
|
||||
UpdateProjectRequest,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.utils import _premium_user_check
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.proxy._types import NewProjectRequest, UpdateProjectRequest
|
||||
from litellm.proxy.utils import PrismaClient, ProxyLogging
|
||||
|
||||
|
||||
def _user_has_admin_view(user_api_key_dict: UserAPIKeyAuth) -> bool:
    """Return True when the caller holds a proxy-admin or admin-viewer role."""
    admin_view_roles = (
        LitellmUserRoles.PROXY_ADMIN,
        LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY,
    )
    return user_api_key_dict.user_role in admin_view_roles
def _is_user_team_admin(
    user_api_key_dict: UserAPIKeyAuth, team_obj: LiteLLM_TeamTable
) -> bool:
    """Return True when the calling user is an 'admin' member of the team.

    A membership with a null user_id never matches, even when the caller's
    user_id is also None.
    """
    caller_id = user_api_key_dict.user_id
    return any(
        member.user_id is not None
        and member.user_id == caller_id
        and member.role == "admin"
        for member in team_obj.members_with_roles
    )
async def _is_user_org_admin_for_team(
    user_api_key_dict: UserAPIKeyAuth, team_obj: LiteLLM_TeamTable
) -> bool:
    """Check if user is an org admin for the team's organization.

    Returns True if:
    - The team belongs to an organization, AND
    - The user has org_admin role in that organization
    """
    # A team without an org, or a caller without a user_id, can never match.
    if not team_obj.organization_id or not user_api_key_dict.user_id:
        return False

    # Imported lazily to avoid a circular import with proxy_server.
    from litellm.proxy.auth.auth_checks import get_user_object
    from litellm.proxy.proxy_server import (
        prisma_client,
        proxy_logging_obj,
        user_api_key_cache,
    )

    caller = await get_user_object(
        user_id=user_api_key_dict.user_id,
        prisma_client=prisma_client,
        user_api_key_cache=user_api_key_cache,
        user_id_upsert=False,
        proxy_logging_obj=proxy_logging_obj,
    )
    if caller is None:
        return False

    memberships = caller.organization_memberships or []
    return any(
        membership.organization_id == team_obj.organization_id
        and membership.user_role == LitellmUserRoles.ORG_ADMIN.value
        for membership in memberships
    )
def _team_member_has_permission(
    user_api_key_dict: UserAPIKeyAuth,
    team_obj: LiteLLM_TeamTable,
    permission: str,
) -> bool:
    """Check if a non-admin team member has a specific permission on a team.

    True only when the team grants ``permission`` in its
    team_member_permissions AND the caller is a member of the team.
    """
    granted = team_obj.team_member_permissions
    if not granted or permission not in granted:
        return False

    caller_id = user_api_key_dict.user_id
    return any(
        member.user_id is not None and member.user_id == caller_id
        for member in team_obj.members_with_roles
    )
async def _user_has_admin_privileges(
    user_api_key_dict: UserAPIKeyAuth,
    prisma_client: Optional["PrismaClient"] = None,
    user_api_key_cache: Optional["DualCache"] = None,
    proxy_logging_obj: Optional["ProxyLogging"] = None,
) -> bool:
    """
    Check if user has admin privileges (proxy admin, team admin, or org admin).

    Args:
        user_api_key_dict: User API key authentication object
        prisma_client: Prisma client for database operations
        user_api_key_cache: Cache for user API keys
        proxy_logging_obj: Proxy logging object

    Returns:
        True if user is proxy admin, team admin for any team, or org admin for any organization
    """
    # Check if user is proxy admin
    if user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN:
        return True

    # If no database connection, can't check team/org admin status
    if prisma_client is None or user_api_key_dict.user_id is None:
        return False

    # Get user object to check team and org admin status
    from litellm.caching import DualCache as DualCacheImport
    from litellm.proxy.auth.auth_checks import get_user_object

    try:
        user_obj = await get_user_object(
            user_id=user_api_key_dict.user_id,
            prisma_client=prisma_client,
            # Fall back to a throwaway cache when the caller supplied none.
            user_api_key_cache=user_api_key_cache or DualCacheImport(),
            user_id_upsert=False,
            proxy_logging_obj=proxy_logging_obj,
        )

        if user_obj is None:
            return False

        # Check if user is org admin for any organization
        if user_obj.organization_memberships is not None:
            for membership in user_obj.organization_memberships:
                if membership.user_role == LitellmUserRoles.ORG_ADMIN.value:
                    return True

        # Check if user is team admin for any team
        if user_obj.teams is not None and len(user_obj.teams) > 0:
            # Get all teams user is in
            teams = await prisma_client.db.litellm_teamtable.find_many(
                where={"team_id": {"in": user_obj.teams}}
            )

            for team in teams:
                team_obj = LiteLLM_TeamTable(**team.model_dump())
                if _is_user_team_admin(
                    user_api_key_dict=user_api_key_dict, team_obj=team_obj
                ):
                    return True

    except Exception as e:
        # If there's an error checking, default to False for security
        verbose_proxy_logger.debug(
            f"Error checking admin privileges for user {user_api_key_dict.user_id}: {e}"
        )
        return False

    # Not a proxy admin, not an org admin anywhere, not a team admin anywhere.
    return False
def _org_admin_can_invite_user(
    admin_user_obj: LiteLLM_UserTable,
    target_user_obj: LiteLLM_UserTable,
) -> bool:
    """
    Check if an org admin can invite the target user.
    Target user must be in at least one org where the admin has org admin role.

    Args:
        admin_user_obj: The admin user's full object (from get_user_object)
        target_user_obj: The target user's full object (from get_user_object)

    Returns:
        True if target user is in an org where admin has org admin role
    """
    admin_memberships = admin_user_obj.organization_memberships
    target_memberships = target_user_obj.organization_memberships
    if admin_memberships is None or target_memberships is None:
        return False

    # Orgs where the caller actually holds the ORG_ADMIN role.
    orgs_admin_controls = {
        m.organization_id
        for m in admin_memberships
        if m.user_role == LitellmUserRoles.ORG_ADMIN.value
    }
    if not orgs_admin_controls:
        return False

    target_orgs = {m.organization_id for m in target_memberships}
    # Any shared org grants invite permission.
    return not orgs_admin_controls.isdisjoint(target_orgs)
async def _team_admin_can_invite_user(
    user_api_key_dict: UserAPIKeyAuth,
    admin_user_obj: LiteLLM_UserTable,
    target_user_obj: LiteLLM_UserTable,
    prisma_client: "PrismaClient",
) -> bool:
    """
    Check if a team admin can invite the target user.
    Target user must be in at least one team where the admin has team admin role.

    Args:
        user_api_key_dict: The admin user's API key auth object
        admin_user_obj: The admin user's full object (from get_user_object)
        target_user_obj: The target user's full object (from get_user_object)
        prisma_client: Prisma client for database operations

    Returns:
        True if target user is in a team where admin has team admin role
    """
    # Either side having no teams means no possible overlap.
    if not admin_user_obj.teams:
        return False
    if not target_user_obj.teams:
        return False

    candidate_teams = await prisma_client.db.litellm_teamtable.find_many(
        where={"team_id": {"in": admin_user_obj.teams}}
    )

    # Teams where the caller actually holds the 'admin' member role.
    admin_led_teams = {
        team.team_id
        for team in candidate_teams
        if _is_user_team_admin(
            user_api_key_dict=user_api_key_dict,
            team_obj=LiteLLM_TeamTable(**team.model_dump()),
        )
    }
    if not admin_led_teams:
        return False

    # Any shared team grants invite permission.
    return not admin_led_teams.isdisjoint(target_user_obj.teams)
async def admin_can_invite_user(
    target_user_id: str,
    user_api_key_dict: UserAPIKeyAuth,
    prisma_client: Optional["PrismaClient"] = None,
    user_api_key_cache: Optional["DualCache"] = None,
    proxy_logging_obj: Optional["ProxyLogging"] = None,
) -> bool:
    """
    Check if the admin can create an invitation for the target user.
    - Proxy admins: can invite any user
    - Org admins: can only invite users in their org(s)
    - Team admins: can only invite users in their team(s)

    Uses get_user_object for caching of both admin and target user objects.

    Args:
        target_user_id: The user_id of the user to invite
        user_api_key_dict: The admin user's API key auth object
        prisma_client: Prisma client for database operations
        user_api_key_cache: Cache for user API keys
        proxy_logging_obj: Proxy logging object

    Returns:
        True if user can invite the target user
    """
    # Proxy admins bypass all scoping checks.
    if user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN:
        return True

    # Without a DB or an identified caller, scoped checks are impossible.
    if prisma_client is None or user_api_key_dict.user_id is None:
        return False

    # Imported lazily to avoid a circular import.
    from litellm.caching import DualCache as DualCacheImport
    from litellm.proxy.auth.auth_checks import get_user_object

    try:
        # Share one cache instance across both lookups.
        cache = user_api_key_cache or DualCacheImport()
        admin_user_obj = await get_user_object(
            user_id=user_api_key_dict.user_id,
            prisma_client=prisma_client,
            user_api_key_cache=cache,
            user_id_upsert=False,
            proxy_logging_obj=proxy_logging_obj,
        )
        if admin_user_obj is None:
            return False

        target_user_obj = await get_user_object(
            user_id=target_user_id,
            prisma_client=prisma_client,
            user_api_key_cache=cache,
            user_id_upsert=False,
            proxy_logging_obj=proxy_logging_obj,
        )
        if target_user_obj is None:
            return False

        # Org-admin scope: shared org where caller holds ORG_ADMIN.
        if _org_admin_can_invite_user(admin_user_obj, target_user_obj):
            return True

        # Team-admin scope: shared team where caller is a team admin.
        if await _team_admin_can_invite_user(
            user_api_key_dict=user_api_key_dict,
            admin_user_obj=admin_user_obj,
            target_user_obj=target_user_obj,
            prisma_client=prisma_client,
        ):
            return True

        return False
    except Exception as e:
        # On any lookup failure, deny for safety and log at debug level.
        verbose_proxy_logger.debug(
            f"Error checking invite permission for user {user_api_key_dict.user_id}: {e}"
        )
        return False
def _set_object_metadata_field(
    object_data: Union[
        LiteLLM_TeamTable,
        KeyRequestBase,
        LiteLLM_OrganizationTable,
        LiteLLM_ProjectTable,
        "NewProjectRequest",
        "UpdateProjectRequest",
    ],
    field_name: str,
    value: Any,
) -> None:
    """Set a metadata field on a management object, enforcing premium gating.

    Args:
        object_data: The team/key/organization/project data object to modify.
        field_name: Name of the metadata field to set.
        value: Value to set for the field.
    """
    # Premium-only metadata fields require a paid license before being set.
    if field_name in LiteLLM_ManagementEndpoint_MetadataFields_Premium:
        _premium_user_check(field_name)

    # Ensure the metadata dict exists, then write the field into it.
    object_data.metadata = object_data.metadata or {}
    object_data.metadata[field_name] = value
async def _upsert_budget_and_membership(
    tx,
    *,
    team_id: str,
    user_id: str,
    max_budget: Optional[float],
    existing_budget_id: Optional[str],
    user_api_key_dict: UserAPIKeyAuth,
    tpm_limit: Optional[int] = None,
    rpm_limit: Optional[int] = None,
):
    """
    Helper function to Create/Update or Delete the budget within the team membership
    Args:
        tx: The transaction object
        team_id: The ID of the team
        user_id: The ID of the user
        max_budget: The maximum budget for the team
        existing_budget_id: The ID of the existing budget, if any
        user_api_key_dict: User API Key dictionary containing user information
        tpm_limit: Tokens per minute limit for the team member
        rpm_limit: Requests per minute limit for the team member

    If max_budget, tpm_limit, and rpm_limit are all None, the user's budget is removed from the team membership.
    If any of these values exist, a budget is updated or created and linked to the team membership.

    NOTE(review): `existing_budget_id` is accepted but never read below — a
    brand-new budget row is created on every call, even when one already
    exists. Confirm whether the old budget row should be reused or cleaned
    up, as repeated calls will accumulate orphaned budget rows.
    """
    if max_budget is None and tpm_limit is None and rpm_limit is None:
        # disconnect the budget since all limits are None
        await tx.litellm_teammembership.update(
            where={"user_id_team_id": {"user_id": user_id, "team_id": team_id}},
            data={"litellm_budget_table": {"disconnect": True}},
        )
        return

    # create a new budget
    create_data: Dict[str, Any] = {
        "created_by": user_api_key_dict.user_id or "",
        "updated_by": user_api_key_dict.user_id or "",
    }
    # Only include limits the caller actually supplied.
    if max_budget is not None:
        create_data["max_budget"] = max_budget
    if tpm_limit is not None:
        create_data["tpm_limit"] = tpm_limit
    if rpm_limit is not None:
        create_data["rpm_limit"] = rpm_limit

    new_budget = await tx.litellm_budgettable.create(
        data=create_data,
        include={"team_membership": True},
    )
    # upsert the team membership with the new/updated budget
    await tx.litellm_teammembership.upsert(
        where={
            "user_id_team_id": {
                "user_id": user_id,
                "team_id": team_id,
            }
        },
        data={
            "create": {
                "user_id": user_id,
                "team_id": team_id,
                "litellm_budget_table": {
                    "connect": {"budget_id": new_budget.budget_id},
                },
            },
            "update": {
                "litellm_budget_table": {
                    "connect": {"budget_id": new_budget.budget_id},
                },
            },
        },
    )
def _update_metadata_field(updated_kv: dict, field_name: str) -> None:
    """
    Move a single metadata field into the nested ``metadata`` dict, running
    the premium-user check first when the field is premium-gated.

    Args:
        updated_kv: The key-value dict being used for the update (mutated in place)
        field_name: Name of the metadata field being updated
    """
    if field_name in LiteLLM_ManagementEndpoint_MetadataFields_Premium:
        field_value = updated_kv.get(field_name)
        # Empty collections ([] / {}) skip the premium check: the UI sends
        # these as defaults even when the user hasn't configured any
        # enterprise features (see issue #20304). The move below still runs
        # so users can intentionally clear a previously-set field by
        # sending an empty list/dict.
        if field_value is not None and field_value != [] and field_value != {}:
            _premium_user_check()

    if updated_kv.get(field_name) is not None:
        # Relocate the field from the top level into the metadata dict
        moved_value = updated_kv.pop(field_name)
        metadata = updated_kv.get("metadata")
        if metadata is None:
            updated_kv["metadata"] = {field_name: moved_value}
        else:
            metadata[field_name] = moved_value
||||
|
||||
def _has_non_empty_value(value: Any) -> bool:
|
||||
"""Check if a value has real content (not None, not empty list, not blank string)."""
|
||||
if value is None:
|
||||
return False
|
||||
if isinstance(value, list) and len(value) == 0:
|
||||
return False
|
||||
if isinstance(value, str) and value.strip() == "":
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _update_metadata_fields(updated_kv: dict) -> None:
    """
    Helper function to update all metadata fields (both premium and standard).

    Args:
        updated_kv: The key-value dict being used for the update (mutated in place)
    """
    # Premium-gated fields are processed first, then the standard ones;
    # each present, non-None field is moved into the nested metadata dict.
    for field_group in (
        LiteLLM_ManagementEndpoint_MetadataFields_Premium,
        LiteLLM_ManagementEndpoint_MetadataFields,
    ):
        for field in field_group:
            if updated_kv.get(field) is not None:
                _update_metadata_field(updated_kv=updated_kv, field_name=field)
||||
@@ -0,0 +1,79 @@
|
||||
"""
|
||||
COMPLIANCE CHECK ENDPOINTS
|
||||
|
||||
Endpoints for checking regulatory compliance of LLM request logs.
|
||||
|
||||
/compliance/eu-ai-act - Check EU AI Act compliance
|
||||
/compliance/gdpr - Check GDPR compliance
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.compliance_checks import ComplianceChecker
|
||||
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
|
||||
from litellm.types.proxy.compliance_endpoints import (
|
||||
ComplianceCheckRequest,
|
||||
ComplianceResponse,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post(
    "/compliance/eu-ai-act",
    tags=["compliance"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=ComplianceResponse,
)
@management_endpoint_wrapper
async def check_eu_ai_act_compliance(
    data: ComplianceCheckRequest,
    http_request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
) -> ComplianceResponse:
    """
    Check EU AI Act compliance for a spend log entry.

    Checks:
    - Art. 9: Guardrails applied (any guardrail)
    - Art. 5: Content screened before LLM (pre-call guardrails)
    - Art. 12: Audit record complete (user_id, model, timestamp, guardrail_results)
    """
    check_results = ComplianceChecker(data).check_eu_ai_act()
    # Overall compliance requires every individual check to pass.
    overall_pass = all(result.passed for result in check_results)
    return ComplianceResponse(
        compliant=overall_pass,
        regulation="EU AI Act",
        checks=check_results,
    )
|
||||
|
||||
|
||||
@router.post(
    "/compliance/gdpr",
    tags=["compliance"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=ComplianceResponse,
)
@management_endpoint_wrapper
async def check_gdpr_compliance(
    data: ComplianceCheckRequest,
    http_request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
) -> ComplianceResponse:
    """
    Check GDPR compliance for a spend log entry.

    Checks:
    - Art. 32: Data protection applied (pre-call guardrails)
    - Art. 5(1)(c): Sensitive data protected (masked/blocked or no issues)
    - Art. 30: Audit record complete (user_id, model, timestamp, guardrail_results)
    """
    check_results = ComplianceChecker(data).check_gdpr()
    # Overall compliance requires every individual check to pass.
    overall_pass = all(result.passed for result in check_results)
    return ComplianceResponse(
        compliant=overall_pass,
        regulation="GDPR",
        checks=check_results,
    )
|
||||
@@ -0,0 +1,413 @@
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Any, Dict, Set
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
from litellm.litellm_core_utils.safe_json_loads import safe_json_loads
|
||||
|
||||
try:
|
||||
from prisma.errors import RecordNotFoundError
|
||||
except ImportError:
|
||||
RecordNotFoundError = Exception # type: ignore
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.litellm_core_utils.sensitive_data_masker import SensitiveDataMasker
|
||||
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
|
||||
from litellm.proxy._types import (
|
||||
CommonProxyErrors,
|
||||
KeyManagementSystem,
|
||||
LitellmUserRoles,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.types.llms.custom_http import httpxSpecialProvider
|
||||
from litellm.types.proxy.management_endpoints.config_overrides import (
|
||||
ConfigOverrideSettingsResponse,
|
||||
HashicorpVaultConfig,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# --- Hashicorp Vault constants ---

# Maps HashicorpVaultConfig field names to the HCP_VAULT_* environment
# variables the secret manager reads when it initializes.
HASHICORP_ENV_VAR_MAPPING: Dict[str, str] = {
    "vault_addr": "HCP_VAULT_ADDR",
    "vault_token": "HCP_VAULT_TOKEN",
    "approle_role_id": "HCP_VAULT_APPROLE_ROLE_ID",
    "approle_secret_id": "HCP_VAULT_APPROLE_SECRET_ID",
    "approle_mount_path": "HCP_VAULT_APPROLE_MOUNT_PATH",
    "client_cert": "HCP_VAULT_CLIENT_CERT",
    "client_key": "HCP_VAULT_CLIENT_KEY",
    "vault_cert_role": "HCP_VAULT_CERT_ROLE",
    "vault_namespace": "HCP_VAULT_NAMESPACE",
    "vault_mount_name": "HCP_VAULT_MOUNT_NAME",
    "vault_path_prefix": "HCP_VAULT_PATH_PREFIX",
}

# Config fields whose values are secrets; these are masked before being
# returned in API responses (see _mask_sensitive_fields).
HASHICORP_SENSITIVE_FIELDS: Set[str] = {
    "vault_token",
    "approle_secret_id",
    "client_key",
}

# Module-level masker shared by the response-masking helper below.
_sensitive_masker = SensitiveDataMasker()
|
||||
|
||||
|
||||
# --- Shared helpers ---
|
||||
|
||||
|
||||
def _mask_sensitive_fields(
    data: Dict[str, Any], sensitive_fields: Set[str]
) -> Dict[str, Any]:
    """Mask sensitive fields for API responses. Non-sensitive fields are left as-is."""

    def _should_mask(key: str, value: Any) -> bool:
        # Only mask string values of fields flagged as sensitive.
        return value is not None and key in sensitive_fields and isinstance(value, str)

    return {
        key: _sensitive_masker._mask_value(value) if _should_mask(key, value) else value
        for key, value in data.items()
    }
|
||||
|
||||
|
||||
def _get_current_env_values(env_var_mapping: Dict[str, str]) -> Dict[str, Any]:
|
||||
"""Read current env var values as fallback when no DB record exists."""
|
||||
values = {}
|
||||
for field_name, env_var_name in env_var_mapping.items():
|
||||
env_value = os.environ.get(env_var_name)
|
||||
values[field_name] = env_value
|
||||
return values
|
||||
|
||||
|
||||
def _extract_field_type(field_info: Dict[str, Any]) -> str:
|
||||
"""Extract the non-null type from a Pydantic v2 JSON schema field."""
|
||||
if "type" in field_info:
|
||||
return field_info["type"]
|
||||
for option in field_info.get("anyOf", []):
|
||||
if option.get("type") != "null":
|
||||
return option.get("type", "string")
|
||||
return "string"
|
||||
|
||||
|
||||
def _build_field_schema(model_class: type) -> Dict[str, Any]:
    """Build field_schema dict from a Pydantic model for UI rendering."""
    json_schema = TypeAdapter(model_class).json_schema(by_alias=True)
    # One entry per model field: its description plus the non-null JSON type.
    field_properties = {
        name: {
            "description": info.get("description", ""),
            "type": _extract_field_type(info),
        }
        for name, info in json_schema.get("properties", {}).items()
    }
    return {
        "description": json_schema.get("description", ""),
        "properties": field_properties,
    }
|
||||
|
||||
|
||||
def _parse_config_value(raw: Any) -> Dict[str, Any]:
    """Parse a config_value from DB (may be JSON string or dict)."""
    if not isinstance(raw, str):
        # Already a mapping — copy into a plain dict.
        return dict(raw)
    return safe_json_loads(raw, default={})
|
||||
|
||||
|
||||
def _set_env_vars(config_data: Dict[str, Any]) -> None:
    """Set HCP_VAULT_* env vars from config data. Unsets vars for missing/None/empty fields."""
    for field_name, env_var_name in HASHICORP_ENV_VAR_MAPPING.items():
        field_value = config_data.get(field_name)
        if field_value is None or field_value == "":
            # Missing / empty field clears the variable (pop is a no-op
            # when the variable isn't set).
            os.environ.pop(env_var_name, None)
        else:
            os.environ[env_var_name] = str(field_value)
|
||||
|
||||
|
||||
def _clear_hashicorp_vault_state(proxy_config: Any) -> None:
    """Clear all Hashicorp Vault state: env vars, secret manager, and change-detection cache."""
    # Empty dict unsets every HCP_VAULT_* variable (see _set_env_vars).
    _set_env_vars({})
    # Only tear down the global secret manager when Vault is the active
    # system, so a differently-configured KMS is left untouched.
    if litellm._key_management_system == KeyManagementSystem.HASHICORP_VAULT:
        litellm.secret_manager_client = None
        litellm._key_management_system = None
    # Reset the change-detection cache used by the background config reload.
    proxy_config._last_hashicorp_vault_config = None
|
||||
|
||||
|
||||
# --- Hashicorp Vault endpoints ---
|
||||
|
||||
|
||||
@router.post(
    "/config_overrides/hashicorp_vault",
    tags=["Config Overrides"],
    dependencies=[Depends(user_api_key_auth)],
)
async def update_hashicorp_vault_config(
    config: HashicorpVaultConfig,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Update Hashicorp Vault secret manager configuration.
    Sets environment variables, encrypts sensitive fields, and stores in DB.
    Reinitializes the secret manager on this pod.

    Merge semantics: an omitted field keeps its existing value (from the DB
    record if present, otherwise from current env vars); an empty string
    clears the field.

    Args:
        config: Partial or full Vault configuration from the caller.
        user_api_key_dict: Authenticated caller; must be a proxy admin.

    Raises:
        HTTPException: 403 for non-admin callers; 500 when the DB is not
            connected or the secret manager fails to initialize; 400 when
            the merged config is missing the address or an auth method.
    """
    from litellm.proxy.proxy_server import prisma_client, proxy_config

    # Admin-only endpoint
    if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
        raise HTTPException(
            status_code=403,
            detail="Only admin users can update config overrides",
        )

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail=CommonProxyErrors.db_not_connected_error.value,
        )

    # exclude_none=True keeps only fields the caller actually sent, which
    # is what makes the merge-with-existing logic below work.
    config_data = config.model_dump(exclude_none=True)

    # Merge ALL fields the user didn't send: try DB first, fall back to env vars.
    # Omitted field = keep existing; empty string = clear/remove the field.
    existing_record = await prisma_client.db.litellm_configoverrides.find_unique(
        where={"config_type": "hashicorp_vault"}
    )
    if existing_record is not None and existing_record.config_value is not None:
        existing_data = _parse_config_value(existing_record.config_value)
        # Stored values are encrypted — decrypt before merging them in.
        existing_decrypted = proxy_config._decrypt_db_variables(existing_data)
        for field in HASHICORP_ENV_VAR_MAPPING:
            if field not in config_data and existing_decrypted.get(field):
                config_data[field] = existing_decrypted[field]
    else:
        # No DB record yet — merge from current env vars
        env_values = _get_current_env_values(HASHICORP_ENV_VAR_MAPPING)
        for field in HASHICORP_ENV_VAR_MAPPING:
            if field not in config_data and env_values.get(field):
                config_data[field] = env_values[field]

    # Strip empty strings — they signal "clear this field"
    config_data = {k: v for k, v in config_data.items() if v != ""}

    # Validate that the config has enough fields to initialize
    has_vault_addr = bool(config_data.get("vault_addr"))
    has_token_auth = bool(config_data.get("vault_token"))
    has_approle_auth = bool(
        config_data.get("approle_role_id") and config_data.get("approle_secret_id")
    )
    has_tls_cert_auth = bool(
        config_data.get("client_cert") and config_data.get("client_key")
    )

    if not has_vault_addr:
        raise HTTPException(
            status_code=400,
            detail="Vault Address is required",
        )

    if not has_token_auth and not has_approle_auth and not has_tls_cert_auth:
        raise HTTPException(
            status_code=400,
            detail="At least one authentication method is required: "
            "provide a Token, both AppRole Role ID and Secret ID, "
            "or both Client Certificate and Client Key",
        )

    # Snapshot current env vars so we can restore on failure
    previous_env = _get_current_env_values(HASHICORP_ENV_VAR_MAPPING)

    # Set env vars and verify the secret manager can initialize before persisting
    _set_env_vars(config_data)

    try:
        proxy_config.initialize_secret_manager(key_management_system="hashicorp_vault")
    except Exception as e:
        # Roll back env vars so a broken config doesn't stick on this pod.
        _set_env_vars(previous_env)
        verbose_proxy_logger.exception(
            "Error reinitializing Hashicorp Vault secret manager: %s", str(e)
        )
        raise HTTPException(
            status_code=500,
            detail=f"Failed to initialize secret manager: {e}",
        )

    # Only persist to DB after successful init
    encrypted_data = proxy_config._encrypt_env_variables(config_data)
    config_value = safe_dumps(encrypted_data)
    await prisma_client.db.litellm_configoverrides.upsert(
        where={"config_type": "hashicorp_vault"},
        data={
            "create": {
                "config_type": "hashicorp_vault",
                "config_value": config_value,
            },
            "update": {
                "config_value": config_value,
            },
        },
    )

    # Update change-detection cache so the background reload doesn't redundantly re-init
    proxy_config._last_hashicorp_vault_config = safe_json_loads(config_value)

    return {
        "message": "Hashicorp Vault configuration updated successfully",
        "status": "success",
    }
|
||||
|
||||
|
||||
@router.get(
    "/config_overrides/hashicorp_vault",
    tags=["Config Overrides"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=ConfigOverrideSettingsResponse,
)
async def get_hashicorp_vault_config(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get current Hashicorp Vault configuration.
    Returns decrypted values from DB, or falls back to current env vars.

    Sensitive fields (token, AppRole secret id, client key) are masked in
    either case, so plaintext secrets are never returned.

    Raises:
        HTTPException: 403 for non-admin callers; 500 when the DB is not connected.
    """
    from litellm.proxy.proxy_server import prisma_client, proxy_config

    # Admin-only endpoint
    if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
        raise HTTPException(
            status_code=403,
            detail="Only admin users can view config overrides",
        )

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail=CommonProxyErrors.db_not_connected_error.value,
        )

    # Field metadata (descriptions/types) for UI form rendering
    field_schema = _build_field_schema(HashicorpVaultConfig)

    # Try to load from DB
    db_record = await prisma_client.db.litellm_configoverrides.find_unique(
        where={"config_type": "hashicorp_vault"}
    )

    if db_record is not None and db_record.config_value is not None:
        config_data = _parse_config_value(db_record.config_value)

        # Decrypt then mask sensitive fields so plaintext secrets are never sent to the UI
        decrypted_data = proxy_config._decrypt_db_variables(config_data)
        masked_data = _mask_sensitive_fields(decrypted_data, HASHICORP_SENSITIVE_FIELDS)

        return ConfigOverrideSettingsResponse(
            config_type="hashicorp_vault",
            values=masked_data,
            field_schema=field_schema,
        )

    # Fallback to env vars — also mask sensitive values
    env_values = _get_current_env_values(HASHICORP_ENV_VAR_MAPPING)
    masked_env_values = _mask_sensitive_fields(env_values, HASHICORP_SENSITIVE_FIELDS)

    return ConfigOverrideSettingsResponse(
        config_type="hashicorp_vault",
        values=masked_env_values,
        field_schema=field_schema,
    )
|
||||
|
||||
|
||||
@router.delete(
    "/config_overrides/hashicorp_vault",
    tags=["Config Overrides"],
    dependencies=[Depends(user_api_key_auth)],
)
async def delete_hashicorp_vault_config(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """Delete Hashicorp Vault configuration. Idempotent.

    Removes the DB record (if any), then clears env vars, the active secret
    manager, and the change-detection cache on this pod.

    Raises:
        HTTPException: 403 for non-admin callers; 500 when the DB is not connected.
    """
    from litellm.proxy.proxy_server import prisma_client, proxy_config

    # Admin-only endpoint
    if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
        raise HTTPException(
            status_code=403,
            detail="Only admin users can delete config overrides",
        )

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail=CommonProxyErrors.db_not_connected_error.value,
        )

    # Delete DB record if it exists — ignore if not found
    try:
        await prisma_client.db.litellm_configoverrides.delete(
            where={"config_type": "hashicorp_vault"}
        )
    except RecordNotFoundError:
        # Nothing to delete — swallow so the endpoint stays idempotent.
        verbose_proxy_logger.debug(
            "No existing Hashicorp Vault config record to delete"
        )

    _clear_hashicorp_vault_state(proxy_config)

    return {
        "message": "Hashicorp Vault configuration deleted successfully",
        "status": "success",
    }
|
||||
|
||||
|
||||
@router.post(
    "/config_overrides/hashicorp_vault/test_connection",
    tags=["Config Overrides"],
    dependencies=[Depends(user_api_key_auth)],
)
async def test_hashicorp_vault_connection(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Test the connection to the currently configured Hashicorp Vault.
    Uses the already-initialized secret manager client. Does not modify any state.

    Raises:
        HTTPException: 403 for non-admin callers; 400 when Vault is not the
            active secret manager; 502 when authentication or token
            validation against Vault fails.
    """
    from litellm.secret_managers.hashicorp_secret_manager import (
        HashicorpSecretManager,
    )

    # Admin-only endpoint
    if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
        raise HTTPException(
            status_code=403,
            detail="Only admin users can test Vault connection",
        )

    # The global secret manager client must already be a Vault client.
    client = litellm.secret_manager_client
    if not isinstance(client, HashicorpSecretManager):
        raise HTTPException(
            status_code=400,
            detail="Hashicorp Vault is not configured. Save a configuration first.",
        )

    # Step 1: Authenticate (exercises AppRole login, TLS cert login, or direct token)
    try:
        # Run the synchronous auth call in a worker thread so the event
        # loop isn't blocked while Vault is contacted.
        headers = await asyncio.to_thread(client._get_request_headers)
    except Exception as e:
        raise HTTPException(
            status_code=502,
            detail=f"Vault authentication failed: {e}",
        )

    # Step 2: Verify the token is valid via token/lookup-self
    try:
        async_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.SecretManager
        )
        lookup_url = f"{client.vault_addr}/v1/auth/token/lookup-self"
        if client.vault_namespace:
            headers["X-Vault-Namespace"] = client.vault_namespace
        response = await async_client.get(lookup_url, headers=headers)
        response.raise_for_status()
    except Exception as e:
        raise HTTPException(
            status_code=502,
            detail=f"Vault token validation failed: {e}",
        )

    return {
        "status": "success",
        "message": f"Successfully connected to Vault at {client.vault_addr}",
    }
|
||||
@@ -0,0 +1,588 @@
|
||||
"""
|
||||
COST TRACKING SETTINGS MANAGEMENT
|
||||
|
||||
Endpoints for managing cost discount and margin configuration
|
||||
|
||||
GET /config/cost_discount_config - Get current cost discount configuration
|
||||
PATCH /config/cost_discount_config - Update cost discount configuration
|
||||
GET /config/cost_margin_config - Get current cost margin configuration
|
||||
PATCH /config/cost_margin_config - Update cost margin configuration
|
||||
POST /cost/estimate - Estimate cost for a given model and token counts
|
||||
"""
|
||||
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.cost_calculator import completion_cost
|
||||
from litellm.proxy._types import (
|
||||
CommonProxyErrors,
|
||||
CostEstimateRequest,
|
||||
CostEstimateResponse,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.types.utils import LlmProvidersSet
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _resolve_model_for_cost_lookup(model: str) -> Tuple[str, Optional[str]]:
    """
    Resolve a model name (which may be a router alias/model_group) to the
    underlying litellm model name for cost lookup.

    Args:
        model: The model name from the request (could be a router alias like
            'e-model-router' or an actual model name like 'azure_ai/gpt-4')

    Returns:
        Tuple of (resolved_model_name, custom_llm_provider)
        - resolved_model_name: The actual model name to use for cost lookup
        - custom_llm_provider: The provider if resolved from router, None otherwise
    """
    from litellm.proxy.proxy_server import llm_router

    # Try to resolve from router if available
    if llm_router is not None:
        try:
            # Get deployments for this model name (handles aliases, wildcards, etc.)
            deployments = llm_router.get_model_list(model_name=model)

            if deployments:
                first_deployment = deployments[0]
                litellm_params = first_deployment.get("litellm_params", {})
                model_info = first_deployment.get("model_info", {})

                # Provider string is shared by both resolution paths below;
                # compute it once instead of duplicating the None-guarded cast.
                provider = litellm_params.get("custom_llm_provider")
                provider_str = str(provider) if provider is not None else None

                # Check base_model first (needed for Azure custom deployment names)
                base_model = model_info.get("base_model") or litellm_params.get(
                    "base_model"
                )
                if base_model:
                    verbose_proxy_logger.debug(
                        f"Resolved model '{model}' to base_model '{base_model}' from router"
                    )
                    return str(base_model), provider_str

                resolved_model = litellm_params.get("model")
                if resolved_model:
                    verbose_proxy_logger.debug(
                        f"Resolved model '{model}' to '{resolved_model}' from router"
                    )
                    return str(resolved_model), provider_str
        except Exception as e:
            verbose_proxy_logger.debug(
                f"Could not resolve model '{model}' from router: {e}"
            )

    # Return original model if not resolved. The provider is always None
    # here: it was only ever assigned immediately before the returns above.
    return model, None
|
||||
|
||||
|
||||
def _calculate_period_costs(
|
||||
num_requests, cost_per_request, input_cost, output_cost, margin_cost
|
||||
):
|
||||
"""
|
||||
Calculate costs for a given number of requests.
|
||||
|
||||
Returns tuple of (total_cost, input_cost, output_cost, margin_cost) or all None if num_requests is None/0.
|
||||
"""
|
||||
if not num_requests:
|
||||
return None, None, None, None
|
||||
return (
|
||||
cost_per_request * num_requests,
|
||||
input_cost * num_requests,
|
||||
output_cost * num_requests,
|
||||
margin_cost * num_requests,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
    "/config/cost_discount_config",
    tags=["Cost Tracking"],
    dependencies=[Depends(user_api_key_auth)],
)
async def get_cost_discount_config(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get current cost discount configuration.

    Returns the cost_discount_config from litellm_settings.
    """
    from litellm.proxy.proxy_server import prisma_client, proxy_config

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    try:
        # Pull the discount section out of the persisted litellm_settings;
        # missing keys fall back to an empty dict.
        db_config = await proxy_config.get_config()
        settings = db_config.get("litellm_settings", {})
        return {"values": settings.get("cost_discount_config", {})}
    except Exception as e:
        verbose_proxy_logger.error(f"Error fetching cost discount config: {str(e)}")
        return {"values": {}}
|
||||
|
||||
|
||||
@router.patch(
    "/config/cost_discount_config",
    tags=["Cost Tracking"],
    dependencies=[Depends(user_api_key_auth)],
)
async def update_cost_discount_config(
    cost_discount_config: Dict[str, float],
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Update cost discount configuration.

    Updates the cost_discount_config in litellm_settings.
    Discounts should be between 0 and 1 (e.g., 0.05 = 5% discount).

    Example:
    ```json
    {
        "vertex_ai": 0.05,
        "gemini": 0.05,
        "openai": 0.01
    }
    ```

    Raises:
        HTTPException: 500 when the DB is not connected, STORE_MODEL_IN_DB
            is not enabled, or saving fails; 400 for unknown providers or
            out-of-range discount values.
    """
    from litellm.proxy.proxy_server import (
        prisma_client,
        proxy_config,
        store_model_in_db,
    )

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    if store_model_in_db is not True:
        raise HTTPException(
            status_code=500,
            detail={
                "error": "Set `'STORE_MODEL_IN_DB='True'` in your env to enable this feature."
            },
        )

    # Validate that all providers are valid LiteLLM providers
    # (comprehension instead of a manual append loop — ruff PERF401)
    invalid_providers = [
        provider
        for provider in cost_discount_config
        if provider not in LlmProvidersSet
    ]

    if invalid_providers:
        raise HTTPException(
            status_code=400,
            detail={
                "error": f"Invalid provider(s): {', '.join(invalid_providers)}. Must be valid LiteLLM providers. See https://docs.litellm.ai/docs/providers for the full list."
            },
        )

    # Validate discount values are between 0 and 1
    for provider, discount in cost_discount_config.items():
        if not isinstance(discount, (int, float)):
            raise HTTPException(
                status_code=400, detail=f"Discount for {provider} must be a number"
            )
        if not (0 <= discount <= 1):
            raise HTTPException(
                status_code=400,
                detail=f"Discount for {provider} must be between 0 and 1 (0% to 100%)",
            )

    try:
        # Load existing config
        config = await proxy_config.get_config()

        # Ensure litellm_settings exists
        if "litellm_settings" not in config:
            config["litellm_settings"] = {}

        # Update cost_discount_config
        config["litellm_settings"]["cost_discount_config"] = cost_discount_config

        # Save the updated config to DB
        await proxy_config.save_config(new_config=config)

        # Update in-memory litellm.cost_discount_config
        litellm.cost_discount_config = cost_discount_config

        # Lazy %-args: the message is only formatted when INFO is enabled.
        verbose_proxy_logger.info(
            "Updated cost_discount_config: %s", cost_discount_config
        )

        return {
            "message": "Cost discount configuration updated successfully",
            "status": "success",
            "values": cost_discount_config,
        }
    except Exception as e:
        verbose_proxy_logger.error("Error updating cost discount config: %s", str(e))
        raise HTTPException(
            status_code=500,
            detail={"error": f"Failed to update cost discount config: {str(e)}"},
        )
|
||||
|
||||
|
||||
@router.get(
    "/config/cost_margin_config",
    tags=["Cost Tracking"],
    dependencies=[Depends(user_api_key_auth)],
)
async def get_cost_margin_config(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get current cost margin configuration.

    Returns the cost_margin_config from litellm_settings.
    """
    from litellm.proxy.proxy_server import prisma_client, proxy_config

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    try:
        # Pull the margin section out of the persisted litellm_settings;
        # missing keys fall back to an empty dict.
        db_config = await proxy_config.get_config()
        settings = db_config.get("litellm_settings", {})
        return {"values": settings.get("cost_margin_config", {})}
    except Exception as e:
        verbose_proxy_logger.error(f"Error fetching cost margin config: {str(e)}")
        return {"values": {}}
|
||||
|
||||
|
||||
def _validate_margin_value(provider: str, margin_value) -> None:
    """
    Validate one provider's margin entry.

    Accepted shapes:
    - int/float: simple percentage margin, must be within [0, 10] (0% to 1000%).
    - dict: must contain 'percentage' and/or 'fixed_amount';
      'percentage' must be a number in [0, 10], 'fixed_amount' a non-negative number.

    Raises:
        HTTPException: 400 if the value is malformed or out of range.
    """
    if isinstance(margin_value, (int, float)):
        # Simple percentage format: {"openai": 0.10}
        if not (0 <= margin_value <= 10):  # Allow up to 1000% margin
            raise HTTPException(
                status_code=400,
                detail=f"Margin percentage for {provider} must be between 0 and 10 (0% to 1000%)",
            )
        return

    if not isinstance(margin_value, dict):
        raise HTTPException(
            status_code=400,
            detail=f"Margin for {provider} must be a number (percentage) or dict with 'percentage' and/or 'fixed_amount'",
        )

    # Complex format: {"percentage": 0.08, "fixed_amount": 0.0005}
    # BUG FIX: previously a non-empty dict containing only unrecognized keys
    # (e.g. {"openai": {"pct": 0.1}}) passed validation and silently applied
    # no margin; only the fully-empty dict was rejected. Require at least one
    # recognized key.
    if "percentage" not in margin_value and "fixed_amount" not in margin_value:
        raise HTTPException(
            status_code=400,
            detail=f"Margin config for {provider} must include 'percentage' and/or 'fixed_amount'",
        )

    if "percentage" in margin_value:
        percentage = margin_value["percentage"]
        if not isinstance(percentage, (int, float)):
            raise HTTPException(
                status_code=400,
                detail=f"Margin percentage for {provider} must be a number",
            )
        if not (0 <= percentage <= 10):
            raise HTTPException(
                status_code=400,
                detail=f"Margin percentage for {provider} must be between 0 and 10 (0% to 1000%)",
            )

    if "fixed_amount" in margin_value:
        fixed_amount = margin_value["fixed_amount"]
        if not isinstance(fixed_amount, (int, float)):
            raise HTTPException(
                status_code=400,
                detail=f"Fixed margin amount for {provider} must be a number",
            )
        if fixed_amount < 0:
            raise HTTPException(
                status_code=400,
                detail=f"Fixed margin amount for {provider} must be non-negative",
            )


@router.patch(
    "/config/cost_margin_config",
    tags=["Cost Tracking"],
    dependencies=[Depends(user_api_key_auth)],
)
async def update_cost_margin_config(
    cost_margin_config: Dict[str, Union[float, Dict[str, float]]],
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Update cost margin configuration.

    Updates the cost_margin_config in litellm_settings.
    Margins can be:
    - Percentage: {"openai": 0.10} = 10% margin
    - Fixed amount: {"openai": {"fixed_amount": 0.001}} = $0.001 per request
    - Combined: {"vertex_ai": {"percentage": 0.08, "fixed_amount": 0.0005}}
    - Global: {"global": 0.05} = 5% global margin on all providers

    Example:
    ```json
    {
        "global": 0.05,
        "openai": 0.10,
        "anthropic": {"fixed_amount": 0.001},
        "vertex_ai": {"percentage": 0.08, "fixed_amount": 0.0005}
    }
    ```
    """
    from litellm.proxy.proxy_server import (
        prisma_client,
        proxy_config,
        store_model_in_db,
    )

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    if store_model_in_db is not True:
        raise HTTPException(
            status_code=500,
            detail={
                "error": "Set `'STORE_MODEL_IN_DB='True'` in your env to enable this feature."
            },
        )

    # Validate that all providers are valid LiteLLM providers (except "global")
    invalid_providers = [
        provider
        for provider in cost_margin_config
        if provider != "global" and provider not in LlmProvidersSet
    ]
    if invalid_providers:
        raise HTTPException(
            status_code=400,
            detail={
                "error": f"Invalid provider(s): {', '.join(invalid_providers)}. Must be valid LiteLLM providers or 'global'. See https://docs.litellm.ai/docs/providers for the full list."
            },
        )

    # Validate margin values (raises HTTPException(400) on the first bad entry)
    for provider, margin_value in cost_margin_config.items():
        _validate_margin_value(provider, margin_value)

    try:
        # Load existing config, update the margin section, and persist it.
        config = await proxy_config.get_config()
        config.setdefault("litellm_settings", {})
        config["litellm_settings"]["cost_margin_config"] = cost_margin_config
        await proxy_config.save_config(new_config=config)

        # Keep the in-memory value in sync so new requests pick it up immediately.
        litellm.cost_margin_config = cost_margin_config

        verbose_proxy_logger.info(f"Updated cost_margin_config: {cost_margin_config}")

        return {
            "message": "Cost margin configuration updated successfully",
            "status": "success",
            "values": cost_margin_config,
        }
    except Exception as e:
        verbose_proxy_logger.error(f"Error updating cost margin config: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail={"error": f"Failed to update cost margin config: {str(e)}"},
        )
|
||||
|
||||
|
||||
@router.post(
    "/cost/estimate",
    tags=["Cost Tracking"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=CostEstimateResponse,
)
async def estimate_cost(
    request: CostEstimateRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
) -> CostEstimateResponse:
    """
    Estimate cost for a given model and token counts.

    This endpoint uses the same cost calculation logic as actual requests,
    including any configured margins and discounts.

    Parameters:
    - model: Model name (e.g., "gpt-4", "claude-3-opus")
    - input_tokens: Expected input tokens per request
    - output_tokens: Expected output tokens per request
    - num_requests_per_day: Number of requests per day (optional)
    - num_requests_per_month: Number of requests per month (optional)

    Returns cost breakdown including:
    - Per-request costs (input, output, margin)
    - Daily costs (if num_requests_per_day provided)
    - Monthly costs (if num_requests_per_month provided)

    Example:
    ```json
    {
        "model": "gpt-4",
        "input_tokens": 1000,
        "output_tokens": 500,
        "num_requests_per_day": 100,
        "num_requests_per_month": 3000
    }
    ```
    """
    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
    from litellm.types.utils import ModelResponse, Usage

    # Resolve model name (handles router aliases like 'e-model-router' -> 'azure_ai/gpt-4')
    resolved_model, resolved_provider = _resolve_model_for_cost_lookup(request.model)

    verbose_proxy_logger.debug(
        f"Cost estimate: request.model='{request.model}' resolved to '{resolved_model}'"
    )

    # Create a mock response with usage for completion_cost
    mock_response = ModelResponse(
        model=resolved_model,
        usage=Usage(
            prompt_tokens=request.input_tokens,
            completion_tokens=request.output_tokens,
            total_tokens=request.input_tokens + request.output_tokens,
        ),
    )

    # Create a logging object to capture cost breakdown.
    # It exists only so completion_cost can record per-component costs on it;
    # the placeholder ids/start_time are never used for real request logging.
    litellm_logging_obj = LiteLLMLoggingObj(
        model=resolved_model,
        messages=[],
        stream=False,
        call_type="completion",
        start_time=None,
        litellm_call_id="cost-estimate",
        function_id="cost-estimate",
    )

    # Use completion_cost which handles all the logic including margins/discounts
    try:
        cost_per_request = completion_cost(
            completion_response=mock_response,
            model=resolved_model,
            litellm_logging_obj=litellm_logging_obj,
        )
    except Exception as e:
        # Unknown/unpriced models surface as 404 so callers can distinguish
        # "no pricing data" from server errors.
        raise HTTPException(
            status_code=404,
            detail={
                "error": f"Could not calculate cost for model '{request.model}' (resolved to '{resolved_model}'): {str(e)}"
            },
        )

    # Get cost breakdown from the logging object
    # (populated as a side effect of the completion_cost call above; may be None,
    # in which case every component falls back to 0.0 below)
    cost_breakdown = litellm_logging_obj.cost_breakdown

    input_cost = cost_breakdown.get("input_cost", 0.0) if cost_breakdown else 0.0
    output_cost = cost_breakdown.get("output_cost", 0.0) if cost_breakdown else 0.0
    margin_cost = (
        cost_breakdown.get("margin_total_amount", 0.0) if cost_breakdown else 0.0
    )

    # Get model info for per-token pricing display
    try:
        model_info = litellm.get_model_info(model=resolved_model)
        input_cost_per_token = model_info.get("input_cost_per_token")
        output_cost_per_token = model_info.get("output_cost_per_token")
        custom_llm_provider = model_info.get("litellm_provider")
    except Exception:
        # Per-token pricing is display-only in the response — omit on lookup failure.
        input_cost_per_token = None
        output_cost_per_token = None
        custom_llm_provider = None

    # Use provider from router resolution if not found in model_info
    if custom_llm_provider is None and resolved_provider is not None:
        custom_llm_provider = resolved_provider

    # Calculate daily and monthly costs
    # (each helper returns (total, input, output, margin); components are None
    # when the corresponding num_requests_* was not provided — TODO confirm
    # against _calculate_period_costs, which is defined elsewhere in this file)
    (
        daily_cost,
        daily_input_cost,
        daily_output_cost,
        daily_margin_cost,
    ) = _calculate_period_costs(
        num_requests=request.num_requests_per_day,
        cost_per_request=cost_per_request,
        input_cost=input_cost,
        output_cost=output_cost,
        margin_cost=margin_cost,
    )
    (
        monthly_cost,
        monthly_input_cost,
        monthly_output_cost,
        monthly_margin_cost,
    ) = _calculate_period_costs(
        num_requests=request.num_requests_per_month,
        cost_per_request=cost_per_request,
        input_cost=input_cost,
        output_cost=output_cost,
        margin_cost=margin_cost,
    )

    return CostEstimateResponse(
        model=request.model,
        input_tokens=request.input_tokens,
        output_tokens=request.output_tokens,
        num_requests_per_day=request.num_requests_per_day,
        num_requests_per_month=request.num_requests_per_month,
        cost_per_request=cost_per_request,
        input_cost_per_request=input_cost,
        output_cost_per_request=output_cost,
        margin_cost_per_request=margin_cost,
        daily_cost=daily_cost,
        daily_input_cost=daily_input_cost,
        daily_output_cost=daily_output_cost,
        daily_margin_cost=daily_margin_cost,
        monthly_cost=monthly_cost,
        monthly_input_cost=monthly_input_cost,
        monthly_output_cost=monthly_output_cost,
        monthly_margin_cost=monthly_margin_cost,
        input_cost_per_token=input_cost_per_token,
        output_cost_per_token=output_cost_per_token,
        provider=custom_llm_provider,
    )
|
||||
# --- customer management endpoints (originally a separate file; diff hunk header removed) ---
|
||||
"""
|
||||
CUSTOMER MANAGEMENT
|
||||
|
||||
All /customer management endpoints
|
||||
|
||||
/customer/new
|
||||
/customer/info
|
||||
/customer/update
|
||||
/customer/delete
|
||||
"""
|
||||
|
||||
#### END-USER/CUSTOMER MANAGEMENT ####
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Optional
|
||||
|
||||
import fastapi
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import *
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.management_endpoints.common_daily_activity import get_daily_activity
|
||||
from litellm.proxy.management_helpers.object_permission_utils import (
|
||||
_set_object_permission,
|
||||
handle_update_object_permission_common,
|
||||
)
|
||||
from litellm.proxy.utils import handle_exception_on_proxy
|
||||
from litellm.types.proxy.management_endpoints.common_daily_activity import (
|
||||
SpendAnalyticsPaginatedResponse,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post(
    "/end_user/block",
    tags=["Customer Management"],
    dependencies=[Depends(user_api_key_auth)],
    include_in_schema=False,
)
@router.post(
    "/customer/block",
    tags=["Customer Management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def block_user(data: BlockUsers):
    """
    [BETA] Reject calls with this end-user id

    Parameters:
    - user_ids (List[str], required): The unique `user_id`s for the users to block

    (any /chat/completion call with this user={end-user-id} param, will be rejected.)

    ```
    curl -X POST "http://0.0.0.0:8000/user/block"
    -H "Authorization: Bearer sk-1234"
    -d '{
    "user_ids": [<user_id>, ...]
    }'
    ```
    """
    from litellm.proxy.proxy_server import prisma_client

    try:
        if prisma_client is None:
            # Raised inside the try on purpose: the generic handler below
            # wraps it, preserving the original error shape.
            raise HTTPException(
                status_code=500,
                detail={"error": "Postgres DB Not connected"},
            )

        blocked_records = []
        for user_id in data.user_ids:
            # Upsert so blocking is idempotent: create the row if the end-user
            # is unknown, otherwise just flip the blocked flag.
            upserted = await prisma_client.db.litellm_endusertable.upsert(
                where={"user_id": user_id},  # type: ignore
                data={
                    "create": {"user_id": user_id, "blocked": True},  # type: ignore
                    "update": {"blocked": True},
                },
            )
            blocked_records.append(upserted)

        return {"blocked_users": blocked_records}
    except Exception as e:
        verbose_proxy_logger.error(f"An error occurred - {str(e)}")
        raise HTTPException(status_code=500, detail={"error": str(e)})
|
||||
|
||||
|
||||
@router.post(
    "/end_user/unblock",
    tags=["Customer Management"],
    dependencies=[Depends(user_api_key_auth)],
    include_in_schema=False,
)
@router.post(
    "/customer/unblock",
    tags=["Customer Management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def unblock_user(data: BlockUsers):
    """
    [BETA] Unblock calls with this user id

    Example
    ```
    curl -X POST "http://0.0.0.0:8000/user/unblock"
    -H "Authorization: Bearer sk-1234"
    -d '{
    "user_ids": [<user_id>, ...]
    }'
    ```
    """
    # The blocked-user hook lives in the enterprise package; without it this
    # endpoint has nothing to update.
    try:
        from enterprise.enterprise_hooks.blocked_user_list import (
            _ENTERPRISE_BlockedUserList,
        )
    except ImportError:
        raise HTTPException(
            status_code=400,
            detail={
                "error": "Blocked user check was never set. This call has no effect."
                + CommonProxyErrors.missing_enterprise_package_docker.value
            },
        )

    # The hook must be registered as a callback AND a blocked list configured.
    if (
        not any(isinstance(x, _ENTERPRISE_BlockedUserList) for x in litellm.callbacks)
        or litellm.blocked_user_list is None
    ):
        raise HTTPException(
            status_code=400,
            detail={
                "error": "Blocked user check was never set. This call has no effect."
            },
        )

    if isinstance(litellm.blocked_user_list, list):
        for user_id in data.user_ids:
            # BUG FIX: list.remove() raises ValueError when the id is absent,
            # which previously surfaced as an unhandled 500. Unblocking a user
            # who is not currently blocked should be a no-op.
            if user_id in litellm.blocked_user_list:
                litellm.blocked_user_list.remove(user_id)
    else:
        raise HTTPException(
            status_code=500,
            detail={
                "error": "`blocked_user_list` must be set as a list. Filepaths can't be updated."
            },
        )

    return {"blocked_users": litellm.blocked_user_list}
|
||||
|
||||
|
||||
def new_budget_request(data: NewCustomerRequest) -> Optional[BudgetNewRequest]:
    """
    Build a BudgetNewRequest from any budget fields set on *data*.

    Returns None when the request carries no new-budget parameters.
    `budget_id` is skipped — it links to an existing budget rather than
    defining a new one.
    """
    budget_fields = {
        name: getattr(data, name, None)
        for name in BudgetNewRequest.model_fields.keys()
        if name != "budget_id" and getattr(data, name, None) is not None
    }
    if not budget_fields:
        return None

    new_budget = BudgetNewRequest(**budget_fields)
    # Derive the first reset timestamp from the duration when it wasn't set explicitly.
    if new_budget.budget_reset_at is None and new_budget.budget_duration is not None:
        reset_delta = timedelta(
            seconds=duration_in_seconds(duration=new_budget.budget_duration)
        )
        new_budget.budget_reset_at = datetime.utcnow() + reset_delta
    return new_budget
|
||||
|
||||
|
||||
async def _handle_customer_object_permission_update(
    non_default_values: dict,
    end_user_table_data_typed: Optional[LiteLLM_EndUserTable],
    update_end_user_table_data: dict,
    prisma_client,
) -> None:
    """
    Handle object permission updates for customer endpoints.

    Mutates *update_end_user_table_data* in place, setting the resolved
    object_permission_id when the payload carries an "object_permission" update.

    Args:
        non_default_values: Update payload, possibly containing "object_permission".
        end_user_table_data_typed: Existing end-user row, or None.
        update_end_user_table_data: Dict receiving the new object_permission_id.
        prisma_client: Prisma database client.
    """
    if "object_permission" not in non_default_values:
        return

    existing_id = None
    if end_user_table_data_typed is not None:
        existing_id = end_user_table_data_typed.object_permission_id

    new_permission_id = await handle_update_object_permission_common(
        data_json=non_default_values,
        existing_object_permission_id=existing_id,
        prisma_client=prisma_client,
    )
    if new_permission_id is not None:
        update_end_user_table_data["object_permission_id"] = new_permission_id
|
||||
|
||||
|
||||
@router.post(
    "/end_user/new",
    tags=["Customer Management"],
    include_in_schema=False,
    dependencies=[Depends(user_api_key_auth)],
)
@router.post(
    "/customer/new",
    tags=["Customer Management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def new_end_user(
    data: NewCustomerRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Allow creating a new Customer


    Parameters:
    - user_id: str - The unique identifier for the user.
    - alias: Optional[str] - A human-friendly alias for the user.
    - blocked: bool - Flag to allow or disallow requests for this end-user. Default is False.
    - max_budget: Optional[float] - The maximum budget allocated to the user. Either 'max_budget' or 'budget_id' should be provided, not both.
    - budget_id: Optional[str] - The identifier for an existing budget allocated to the user. Either 'max_budget' or 'budget_id' should be provided, not both.
    - allowed_model_region: Optional[Union[Literal["eu"], Literal["us"]]] - Require all user requests to use models in this specific region.
    - default_model: Optional[str] - If no equivalent model in the allowed region, default all requests to this model.
    - metadata: Optional[dict] = Metadata for customer, store information for customer. Example metadata = {"data_training_opt_out": True}
    - budget_duration: Optional[str] - Budget is reset at the end of specified duration. If not set, budget is never reset. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
    - tpm_limit: Optional[int] - [Not Implemented Yet] Specify tpm limit for a given customer (Tokens per minute)
    - rpm_limit: Optional[int] - [Not Implemented Yet] Specify rpm limit for a given customer (Requests per minute)
    - model_max_budget: Optional[dict] - [Not Implemented Yet] Specify max budget for a given model. Example: {"openai/gpt-4o-mini": {"max_budget": 100.0, "budget_duration": "1d"}}
    - max_parallel_requests: Optional[int] - [Not Implemented Yet] Specify max parallel requests for a given customer.
    - soft_budget: Optional[float] - [Not Implemented Yet] Get alerts when customer crosses given budget, doesn't block requests.
    - spend: Optional[float] - Specify initial spend for a given customer.
    - budget_reset_at: Optional[str] - Specify the date and time when the budget should be reset.
    - object_permission: Optional[LiteLLM_ObjectPermissionBase] - Customer-specific object permissions to control access to resources.
      Supported fields:
        * mcp_servers: List[str] - List of allowed MCP server IDs
        * mcp_access_groups: List[str] - List of MCP access group names
        * mcp_tool_permissions: Dict[str, List[str]] - Map of server ID to allowed tool names (e.g., {"server_1": ["tool_a", "tool_b"]})
        * vector_stores: List[str] - List of allowed vector store IDs
        * agents: List[str] - List of allowed agent IDs
        * agent_access_groups: List[str] - List of agent access group names
      Example: {"mcp_servers": ["server_1", "server_2"], "vector_stores": ["vector_store_1"], "agents": ["agent_1"]}
      IF null or {} then no object-level restrictions apply.


    - Allow specifying allowed regions
    - Allow specifying default model

    Example curl:
    ```
    curl --location 'http://0.0.0.0:4000/customer/new' \
    --header 'Authorization: Bearer sk-1234' \
    --header 'Content-Type: application/json' \
    --data '{
        "user_id" : "ishaan-jaff-3",
        "allowed_region": "eu",
        "budget_id": "free_tier",
        "default_model": "azure/gpt-3.5-turbo-eu"
    }'

    # With object permissions
    curl -L -X POST 'http://localhost:4000/customer/new' \
    -H 'Authorization: Bearer sk-1234' \
    -H 'Content-Type: application/json' \
    -d '{
        "user_id": "user_1",
        "object_permission": {
            "mcp_servers": ["server_1"],
            "mcp_access_groups": ["public_group"],
            "vector_stores": ["vector_store_1"]
        }
    }'

    # return end-user object
    ```

    NOTE: This used to be called `/end_user/new`, we will still be maintaining compatibility for /end_user/XXX for these endpoints
    """
    """
    Validation:
        - check if default model exists
        - create budget object if not already created

    - Add user to end user table

    Return
    - end-user object
    - currently allowed models
    """
    from litellm.proxy.proxy_server import (
        litellm_proxy_admin_name,
        llm_router,
        prisma_client,
    )

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )
    try:
        ## VALIDATION ##
        # default_model must be a model the router actually serves.
        if data.default_model is not None:
            if llm_router is None:
                raise HTTPException(
                    status_code=422,
                    detail={"error": CommonProxyErrors.no_llm_router.value},
                )
            elif data.default_model not in llm_router.get_model_names():
                raise HTTPException(
                    status_code=422,
                    detail={
                        "error": "Default Model not on proxy. Configure via `/model/new` or config.yaml. Default_model={}, proxy_model_names={}".format(
                            data.default_model, set(llm_router.get_model_names())
                        )
                    },
                )

        # Accumulates the fields for the end-user row we insert below.
        new_end_user_obj: Dict = {}

        ## CREATE BUDGET ## if set
        _new_budget = new_budget_request(data)
        if _new_budget is not None:
            try:
                budget_record = await prisma_client.db.litellm_budgettable.create(
                    data={
                        **_new_budget.model_dump(exclude_unset=True),
                        "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,  # type: ignore
                        "updated_by": user_api_key_dict.user_id
                        or litellm_proxy_admin_name,
                    }
                )
            except Exception as e:
                raise HTTPException(status_code=422, detail={"error": str(e)})

            new_end_user_obj["budget_id"] = budget_record.budget_id
        elif data.budget_id is not None:
            # Link to an existing budget instead of creating a new one.
            new_end_user_obj["budget_id"] = data.budget_id

        _user_data = data.dict(exclude_none=True)

        # Copy the remaining request fields onto the end-user row; budget-table
        # fields were consumed by new_budget_request above and must not be
        # written to the end-user table.
        for k, v in _user_data.items():
            if k not in BudgetNewRequest.model_fields.keys():
                new_end_user_obj[k] = v

        ## Handle Object Permission - MCP Servers, Vector Stores etc.
        new_end_user_obj = await _set_object_permission(
            data_json=new_end_user_obj,
            prisma_client=prisma_client,
        )

        # Ensure object_permission is not in the data being sent to create
        # It should have been converted to object_permission_id by _set_object_permission
        if "object_permission" in new_end_user_obj:
            verbose_proxy_logger.warning(
                f"object_permission still in new_end_user_obj after _set_object_permission: {new_end_user_obj.get('object_permission')}"
            )
            new_end_user_obj.pop("object_permission", None)

        ## WRITE TO DB ##
        end_user_record = await prisma_client.db.litellm_endusertable.create(
            data=new_end_user_obj,  # type: ignore
            include={"litellm_budget_table": True, "object_permission": True},
        )

        # Convert to dict and clean up recursive fields
        # (reverse relations on object_permission would otherwise recurse
        # back into teams/keys/orgs in the serialized response)
        response_dict = end_user_record.model_dump()
        if response_dict.get("object_permission"):
            # Remove reverse relations from object_permission
            for field in [
                "teams",
                "verification_tokens",
                "organizations",
                "users",
                "end_users",
            ]:
                response_dict["object_permission"].pop(field, None)

        return response_dict
    except Exception as e:
        verbose_proxy_logger.exception(
            "litellm.proxy.management_endpoints.customer_endpoints.new_end_user(): Exception occured - {}".format(
                str(e)
            )
        )
        # Map the Prisma unique-constraint error to a friendly 400 for duplicate user_ids.
        if "Unique constraint failed on the fields: (`user_id`)" in str(e):
            raise ProxyException(
                message=f"Customer already exists, passed user_id={data.user_id}. Please pass a new user_id.",
                type="bad_request",
                code=400,
                param="user_id",
            )
        raise handle_exception_on_proxy(e)
|
||||
|
||||
|
||||
@router.get(
    "/customer/info",
    tags=["Customer Management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=LiteLLM_EndUserTable,
)
@router.get(
    "/end_user/info",
    tags=["Customer Management"],
    include_in_schema=False,
    dependencies=[Depends(user_api_key_auth)],
)
async def end_user_info(
    end_user_id: str = fastapi.Query(
        description="End User ID in the request parameters"
    ),
):
    """
    Get information about an end-user. An `end_user` is a customer (external user) of the proxy.

    Parameters:
    - end_user_id (str, required): The unique identifier for the end-user

    Example curl:
    ```
    curl -X GET 'http://localhost:4000/customer/info?end_user_id=test-litellm-user-4' \
    -H 'Authorization: Bearer sk-1234'
    ```
    """
    try:
        from litellm.proxy.proxy_server import prisma_client

        if prisma_client is None:
            raise HTTPException(
                status_code=500,
                detail={"error": CommonProxyErrors.db_not_connected_error.value},
            )

        end_user_row = await prisma_client.db.litellm_endusertable.find_first(
            where={"user_id": end_user_id},
            include={"litellm_budget_table": True, "object_permission": True},
        )

        if end_user_row is None:
            raise ProxyException(
                message="End User Id={} does not exist in db".format(end_user_id),
                type="not_found",
                code=404,
                param="end_user_id",
            )

        # Serialize the row and strip reverse relations from object_permission
        # so the payload isn't recursive.
        result = end_user_row.model_dump(exclude_none=True)
        object_permission = result.get("object_permission")
        if object_permission:
            for reverse_relation in (
                "teams",
                "verification_tokens",
                "organizations",
                "users",
                "end_users",
            ):
                object_permission.pop(reverse_relation, None)

        return result

    except Exception as e:
        verbose_proxy_logger.exception(
            "litellm.proxy.management_endpoints.customer_endpoints.end_user_info(): Exception occured - {}".format(
                str(e)
            )
        )
        raise handle_exception_on_proxy(e)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/customer/update",
|
||||
tags=["Customer Management"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
)
|
||||
@router.post(
|
||||
"/end_user/update",
|
||||
tags=["Customer Management"],
|
||||
include_in_schema=False,
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
)
|
||||
async def update_end_user(
|
||||
data: UpdateCustomerRequest,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Example curl
|
||||
|
||||
Parameters:
|
||||
- user_id: str
|
||||
- alias: Optional[str] = None # human-friendly alias
|
||||
- blocked: bool = False # allow/disallow requests for this end-user
|
||||
- max_budget: Optional[float] = None
|
||||
- budget_id: Optional[str] = None # give either a budget_id or max_budget
|
||||
- allowed_model_region: Optional[AllowedModelRegion] = (
|
||||
None # require all user requests to use models in this specific region
|
||||
)
|
||||
- default_model: Optional[str] = (
|
||||
None # if no equivalent model in allowed region - default all requests to this model
|
||||
)
|
||||
- object_permission: Optional[LiteLLM_ObjectPermissionBase] - Customer-specific object permissions to control access to resources.
|
||||
Supported fields:
|
||||
* mcp_servers: List[str] - List of allowed MCP server IDs
|
||||
* mcp_access_groups: List[str] - List of MCP access group names
|
||||
* mcp_tool_permissions: Dict[str, List[str]] - Map of server ID to allowed tool names
|
||||
* vector_stores: List[str] - List of allowed vector store IDs
|
||||
* agents: List[str] - List of allowed agent IDs
|
||||
* agent_access_groups: List[str] - List of agent access group names
|
||||
Example: {"mcp_servers": ["server_1"], "vector_stores": ["vector_store_1"]}
|
||||
IF null or {} then no object-level restrictions apply.
|
||||
|
||||
Example curl:
|
||||
```
|
||||
curl --location 'http://0.0.0.0:4000/customer/update' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data '{
|
||||
"user_id": "test-litellm-user-4",
|
||||
"budget_id": "paid_tier"
|
||||
}'
|
||||
|
||||
# Updating object permissions
|
||||
curl -L -X POST 'http://localhost:4000/customer/update' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data '{
|
||||
"user_id": "user_1",
|
||||
"object_permission": {
|
||||
"mcp_servers": ["server_3"],
|
||||
"vector_stores": ["vector_store_2", "vector_store_3"]
|
||||
}
|
||||
}'
|
||||
|
||||
See below for all params
|
||||
```
|
||||
"""
|
||||
|
||||
from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client
|
||||
|
||||
try:
|
||||
data_json: dict = data.json()
|
||||
# get the row from db
|
||||
if prisma_client is None:
|
||||
raise Exception("Not connected to DB!")
|
||||
|
||||
# get non default values for key
|
||||
non_default_values = {}
|
||||
for k, v in data_json.items():
|
||||
if v is not None and v not in (
|
||||
[],
|
||||
{},
|
||||
0,
|
||||
): # models default to [], spend defaults to 0, we should not reset these values
|
||||
non_default_values[k] = v
|
||||
|
||||
## Get end user table data ##
|
||||
end_user_table_data = await prisma_client.db.litellm_endusertable.find_first(
|
||||
where={"user_id": data.user_id}, include={"litellm_budget_table": True}
|
||||
)
|
||||
|
||||
if end_user_table_data is None:
|
||||
raise ProxyException(
|
||||
message="End User Id={} does not exist in db".format(data.user_id),
|
||||
type="not_found",
|
||||
code=404,
|
||||
param="user_id",
|
||||
)
|
||||
|
||||
end_user_table_data_typed = LiteLLM_EndUserTable(
|
||||
**end_user_table_data.model_dump()
|
||||
)
|
||||
|
||||
## Get budget table data ##
|
||||
end_user_budget_table = end_user_table_data_typed.litellm_budget_table
|
||||
|
||||
## Get all params for budget table ##
|
||||
budget_table_data = {}
|
||||
update_end_user_table_data = {}
|
||||
for k, v in non_default_values.items():
|
||||
# budget_id is for linking to existing budget, not for creating new budget
|
||||
if k == "budget_id":
|
||||
update_end_user_table_data[k] = v
|
||||
elif k in LiteLLM_BudgetTable.model_fields.keys():
|
||||
budget_table_data[k] = v
|
||||
|
||||
elif k in LiteLLM_EndUserTable.model_fields.keys():
|
||||
update_end_user_table_data[k] = v
|
||||
|
||||
## Handle object permission updates (MCP servers, vector stores, etc.)
|
||||
await _handle_customer_object_permission_update(
|
||||
non_default_values=non_default_values,
|
||||
end_user_table_data_typed=end_user_table_data_typed,
|
||||
update_end_user_table_data=update_end_user_table_data,
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
|
||||
## Check if we need to create a new budget (only if budget fields are provided, not just budget_id) ##
|
||||
if budget_table_data:
|
||||
if end_user_budget_table is None:
|
||||
## Create new budget ##
|
||||
budget_table_data_record = (
|
||||
await prisma_client.db.litellm_budgettable.create(
|
||||
data={
|
||||
**budget_table_data,
|
||||
"created_by": user_api_key_dict.user_id
|
||||
or litellm_proxy_admin_name,
|
||||
"updated_by": user_api_key_dict.user_id
|
||||
or litellm_proxy_admin_name,
|
||||
},
|
||||
include={"end_users": True},
|
||||
)
|
||||
)
|
||||
|
||||
update_end_user_table_data[
|
||||
"budget_id"
|
||||
] = budget_table_data_record.budget_id
|
||||
else:
|
||||
## Update existing budget ##
|
||||
budget_table_data_record = (
|
||||
await prisma_client.db.litellm_budgettable.update(
|
||||
where={"budget_id": end_user_budget_table.budget_id},
|
||||
data=budget_table_data,
|
||||
)
|
||||
)
|
||||
|
||||
## Update user table, with update params + new budget id (if set) ##
|
||||
verbose_proxy_logger.debug("/customer/update: Received data = %s", data)
|
||||
|
||||
# Ensure object_permission is not in the update data
|
||||
# It should have been converted to object_permission_id by handle_update_object_permission_common
|
||||
if "object_permission" in update_end_user_table_data:
|
||||
verbose_proxy_logger.warning(
|
||||
f"object_permission still in update_end_user_table_data: {update_end_user_table_data.get('object_permission')}"
|
||||
)
|
||||
update_end_user_table_data.pop("object_permission", None)
|
||||
|
||||
if data.user_id is not None and len(data.user_id) > 0:
|
||||
update_end_user_table_data["user_id"] = data.user_id # type: ignore
|
||||
verbose_proxy_logger.debug("In update customer, user_id condition block.")
|
||||
response = await prisma_client.db.litellm_endusertable.update(
|
||||
where={"user_id": data.user_id}, data=update_end_user_table_data, include={"litellm_budget_table": True, "object_permission": True} # type: ignore
|
||||
)
|
||||
if response is None:
|
||||
raise ValueError(
|
||||
f"Failed updating customer data. User ID does not exist passed user_id={data.user_id}"
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
f"received response from updating prisma client. response={response}"
|
||||
)
|
||||
|
||||
# Convert to dict and clean up recursive fields
|
||||
response_dict = response.model_dump()
|
||||
if response_dict.get("object_permission"):
|
||||
# Remove reverse relations from object_permission
|
||||
for field in [
|
||||
"teams",
|
||||
"verification_tokens",
|
||||
"organizations",
|
||||
"users",
|
||||
"end_users",
|
||||
]:
|
||||
response_dict["object_permission"].pop(field, None)
|
||||
|
||||
return response_dict
|
||||
else:
|
||||
raise ValueError(f"user_id is required, passed user_id = {data.user_id}")
|
||||
|
||||
# update based on remaining passed in values
|
||||
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(
|
||||
"litellm.proxy.proxy_server.update_end_user(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
raise handle_exception_on_proxy(e)
|
||||
|
||||
|
||||
@router.post(
    "/customer/delete",
    tags=["Customer Management"],
    dependencies=[Depends(user_api_key_auth)],
)
@router.post(
    "/end_user/delete",
    tags=["Customer Management"],
    include_in_schema=False,
    dependencies=[Depends(user_api_key_auth)],
)
async def delete_end_user(
    data: DeleteCustomerRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Delete multiple end-users.

    Parameters:
    - user_ids (List[str], required): The unique `user_id`s for the users to delete

    Example curl:
    ```
    curl --location 'http://0.0.0.0:4000/customer/delete' \
    --header 'Authorization: Bearer sk-1234' \
    --header 'Content-Type: application/json' \
    --data '{
            "user_ids" :["ishaan-jaff-5"]
    }'

    See below for all params
    ```
    """
    from litellm.proxy.proxy_server import prisma_client

    try:
        if prisma_client is None:
            raise Exception("Not connected to DB!")

        verbose_proxy_logger.debug("/customer/delete: Received data = %s", data)
        if (
            data.user_ids is not None
            and isinstance(data.user_ids, list)
            and len(data.user_ids) > 0
        ):
            # Validate every requested id exists before deleting anything, so
            # the operation is all-or-nothing rather than a partial delete.
            existing_users = await prisma_client.db.litellm_endusertable.find_many(
                where={"user_id": {"in": data.user_ids}}
            )
            existing_user_ids = {user.user_id for user in existing_users}
            missing_user_ids = [
                user_id for user_id in data.user_ids if user_id not in existing_user_ids
            ]

            if missing_user_ids:
                raise ProxyException(
                    message="End User Id(s)={} do not exist in db".format(
                        ", ".join(missing_user_ids)
                    ),
                    type="not_found",
                    code=404,
                    param="user_ids",
                )

            # All users exist, proceed with deletion
            response = await prisma_client.db.litellm_endusertable.delete_many(
                where={"user_id": {"in": data.user_ids}}
            )
            verbose_proxy_logger.debug(
                f"received response from updating prisma client. response={response}"
            )
            return {
                "deleted_customers": response,
                "message": "Successfully deleted customers with ids: "
                + str(data.user_ids),
            }
        else:
            raise ValueError(f"user_id is required, passed user_id = {data.user_ids}")

    except Exception as e:
        # .exception() (not .error()) so the traceback is captured, matching
        # the sibling /customer/update and /customer/list handlers.
        verbose_proxy_logger.exception(
            "litellm.proxy.proxy_server.delete_end_user(): Exception occured - {}".format(
                str(e)
            )
        )
        raise handle_exception_on_proxy(e)
|
||||
|
||||
|
||||
@router.get(
    "/customer/list",
    tags=["Customer Management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=List[LiteLLM_EndUserTable],
)
@router.get(
    "/end_user/list",
    tags=["Customer Management"],
    include_in_schema=False,
    dependencies=[Depends(user_api_key_auth)],
)
async def list_end_user(
    http_request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    [Admin-only] List all available customers

    Example curl:
    ```
    curl --location --request GET 'http://0.0.0.0:4000/customer/list' \
    --header 'Authorization: Bearer sk-1234'
    ```

    """
    try:
        from litellm.proxy.proxy_server import prisma_client

        admin_roles = (
            LitellmUserRoles.PROXY_ADMIN,
            LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY,
        )
        if user_api_key_dict.user_role not in admin_roles:
            raise HTTPException(
                status_code=401,
                detail={
                    "error": "Admin-only endpoint. Your user role={}".format(
                        user_api_key_dict.user_role
                    )
                },
            )

        if prisma_client is None:
            raise HTTPException(
                status_code=400,
                detail={"error": CommonProxyErrors.db_not_connected_error.value},
            )

        end_users = await prisma_client.db.litellm_endusertable.find_many(
            include={"litellm_budget_table": True, "object_permission": True}
        )

        # object_permission carries reverse relations that must not leak into
        # the response models; drop them before re-validating each record.
        reverse_relations = (
            "teams",
            "verification_tokens",
            "organizations",
            "users",
            "end_users",
        )
        results: List[LiteLLM_EndUserTable] = []
        for end_user in end_users:
            end_user_dict = end_user.model_dump()
            object_permission = end_user_dict.get("object_permission")
            if object_permission:
                for relation in reverse_relations:
                    object_permission.pop(relation, None)
            results.append(LiteLLM_EndUserTable(**end_user_dict))
        return results

    except Exception as e:
        verbose_proxy_logger.exception(
            "litellm.proxy.management_endpoints.customer_endpoints.list_end_user(): Exception occured - {}".format(
                str(e)
            )
        )
        raise handle_exception_on_proxy(e)
|
||||
|
||||
|
||||
@router.get(
    "/customer/daily/activity",
    tags=["Customer Management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=SpendAnalyticsPaginatedResponse,
)
@router.get(
    "/end_user/daily/activity",
    tags=["Customer Management"],
    include_in_schema=False,
    dependencies=[Depends(user_api_key_auth)],
)
async def get_customer_daily_activity(
    end_user_ids: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    model: Optional[str] = None,
    api_key: Optional[str] = None,
    page: int = 1,
    page_size: int = 10,
    exclude_end_user_ids: Optional[str] = None,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get daily activity for specific customers (end-users) or all customers.

    Parameters:
    - end_user_ids: comma-separated customer ids to include (all customers if omitted)
    - exclude_end_user_ids: comma-separated customer ids to exclude
    - start_date / end_date: inclusive date-range filter for the activity
    - model / api_key: optional filters on the spend rows
    - page / page_size: pagination controls
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    # Parse comma-separated id strings into lists (None when not provided).
    end_user_ids_list = end_user_ids.split(",") if end_user_ids else None
    exclude_end_user_ids_list: Optional[List[str]] = (
        exclude_end_user_ids.split(",") if exclude_end_user_ids else None
    )

    # Fetch customer aliases so the response metadata can show friendly names.
    where_condition = {}
    if end_user_ids_list:
        where_condition["user_id"] = {"in": end_user_ids_list}
    end_user_aliases = await prisma_client.db.litellm_endusertable.find_many(
        where=where_condition
    )
    end_user_alias_metadata = {e.user_id: {"alias": e.alias} for e in end_user_aliases}

    # Query daily activity for the selected customers.
    return await get_daily_activity(
        prisma_client=prisma_client,
        table_name="litellm_dailyenduserspend",
        entity_id_field="end_user_id",
        entity_id=end_user_ids_list,
        entity_metadata_field=end_user_alias_metadata,
        exclude_entity_ids=exclude_end_user_ids_list,
        start_date=start_date,
        end_date=end_date,
        model=model,
        api_key=api_key,
        page=page,
        page_size=page_size,
    )
|
||||
@@ -0,0 +1,367 @@
|
||||
"""
|
||||
FALLBACK MANAGEMENT ENDPOINTS
|
||||
|
||||
Dedicated endpoints for managing model fallbacks separately from general config.
|
||||
|
||||
POST /fallback - Create or update fallbacks for a specific model
|
||||
GET /fallback/{model} - Get fallbacks for a specific model
|
||||
DELETE /fallback/{model} - Delete fallbacks for a specific model
|
||||
"""
|
||||
# pyright: reportMissingImports=false
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Dict, List, Literal
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.auth.model_checks import get_all_fallbacks
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
else:
|
||||
try:
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
except ImportError:
|
||||
# fastapi is only required for proxy, not for SDK usage
|
||||
pass
|
||||
|
||||
from litellm.types.management_endpoints.router_settings_endpoints import (
|
||||
FallbackCreateRequest,
|
||||
FallbackDeleteResponse,
|
||||
FallbackGetResponse,
|
||||
FallbackResponse,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post(
    "/fallback",
    tags=["Fallback Management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=FallbackResponse,
    status_code=status.HTTP_200_OK,
)
async def create_fallback(
    data: FallbackCreateRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Create or update fallbacks for a specific model.

    This endpoint allows you to configure fallback models separately from the general config.
    Fallbacks are triggered when a model call fails after retries.

    **Example Request:**
    ```json
    {
        "model": "gpt-3.5-turbo",
        "fallback_models": ["gpt-4", "claude-3-haiku"],
        "fallback_type": "general"
    }
    ```

    **Fallback Types:**
    - `general`: Standard fallbacks for any error (default)
    - `context_window`: Fallbacks specifically for context window exceeded errors
    - `content_policy`: Fallbacks specifically for content policy violations
    """
    from litellm.proxy.proxy_server import (
        llm_router,
        prisma_client,
        proxy_config,
        store_model_in_db,
    )

    try:
        # Validate that we have a router
        if llm_router is None:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail={"error": "Router not initialized"},
            )

        # Validate that the model exists in the router
        model_names = llm_router.model_names
        if data.model not in model_names:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail={
                    "error": f"Model '{data.model}' not found in router",
                    "available_models": list(model_names),
                },
            )

        # Validate that all fallback models exist in the router
        invalid_fallback_models = [
            m for m in data.fallback_models if m not in model_names
        ]
        if invalid_fallback_models:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail={
                    "error": f"Invalid fallback models: {invalid_fallback_models}",
                    "available_models": list(model_names),
                },
            )

        # Check if fallback model is the same as the primary model
        # (a self-fallback would retry the same failing deployment)
        if data.model in data.fallback_models:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail={"error": f"Model '{data.model}' cannot be its own fallback"},
            )

        # Check if we need to store in DB; this endpoint persists to the DB,
        # so it refuses to run without DB-backed config storage.
        if store_model_in_db is not True or prisma_client is None:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail={
                    "error": "Database storage not enabled. Set 'STORE_MODEL_IN_DB=True' in your environment to use this feature."
                },
            )

        # Load existing config
        config = await proxy_config.get_config()
        router_settings = config.get("router_settings", {})

        # Get the appropriate fallback list based on type
        # ("fallbacks" is the default/general bucket)
        fallback_key = "fallbacks"
        if data.fallback_type == "context_window":
            fallback_key = "context_window_fallbacks"
        elif data.fallback_type == "content_policy":
            fallback_key = "content_policy_fallbacks"

        # Get existing fallbacks; each entry maps one primary model name to
        # its ordered list of fallback models.
        existing_fallbacks: List[Dict[str, List[str]]] = router_settings.get(
            fallback_key, []
        )

        # Update or add the fallback configuration
        fallback_updated = False
        for i, fallback_dict in enumerate(existing_fallbacks):
            if data.model in fallback_dict:
                # Update existing fallback (replace the whole entry)
                existing_fallbacks[i] = {data.model: data.fallback_models}
                fallback_updated = True
                break

        if not fallback_updated:
            # Add new fallback
            existing_fallbacks.append({data.model: data.fallback_models})

        # Update router settings
        router_settings[fallback_key] = existing_fallbacks

        # Save to database - convert router_settings to JSON string
        # (litellm_config stores param_value as serialized JSON)
        router_settings_json = json.dumps(router_settings)
        await prisma_client.db.litellm_config.upsert(
            where={"param_name": "router_settings"},
            data={
                "create": {
                    "param_name": "router_settings",
                    "param_value": router_settings_json,
                },
                "update": {"param_value": router_settings_json},
            },
        )

        # Update the in-memory router configuration only after the DB write
        # succeeded, so the live router never gets ahead of persisted state.
        setattr(llm_router, fallback_key, existing_fallbacks)

        verbose_proxy_logger.info(
            f"Fallback configured: {data.model} -> {data.fallback_models} (type: {data.fallback_type})"
        )

        return FallbackResponse(
            model=data.model,
            fallback_models=data.fallback_models,
            fallback_type=data.fallback_type,
            message=f"Fallback configuration {'updated' if fallback_updated else 'created'} successfully",
        )

    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.error(f"Error creating fallback: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={"error": f"Failed to create fallback: {str(e)}"},
        )
|
||||
|
||||
|
||||
@router.get(
    "/fallback/{model}",
    tags=["Fallback Management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=FallbackGetResponse,
)
async def get_fallback(
    model: str,
    fallback_type: Literal["general", "context_window", "content_policy"] = "general",
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get fallback configuration for a specific model.

    **Parameters:**
    - `model`: The model name to get fallbacks for
    - `fallback_type`: Type of fallback to retrieve (query parameter)

    **Example:**
    ```
    GET /fallback/gpt-3.5-turbo?fallback_type=general
    ```
    """
    from litellm.proxy.proxy_server import llm_router

    try:
        # Guard: the router must be up before we can read its fallback config.
        if llm_router is None:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail={"error": "Router not initialized"},
            )

        # Resolve the configured fallback list via the shared helper.
        configured_fallbacks = get_all_fallbacks(
            model=model, llm_router=llm_router, fallback_type=fallback_type
        )

        # An empty/missing list means nothing is configured for this model.
        if not configured_fallbacks:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail={
                    "error": f"No {fallback_type} fallbacks configured for model '{model}'"
                },
            )

        return FallbackGetResponse(
            model=model,
            fallback_models=configured_fallbacks,
            fallback_type=fallback_type,
        )

    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.error(f"Error getting fallback: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={"error": f"Failed to get fallback: {str(e)}"},
        )
|
||||
|
||||
|
||||
@router.delete(
    "/fallback/{model}",
    tags=["Fallback Management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=FallbackDeleteResponse,
)
async def delete_fallback(
    model: str,
    fallback_type: Literal["general", "context_window", "content_policy"] = "general",
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Delete fallback configuration for a specific model.

    **Parameters:**
    - `model`: The model name to delete fallbacks for
    - `fallback_type`: Type of fallback to delete (query parameter)

    **Example:**
    ```
    DELETE /fallback/gpt-3.5-turbo?fallback_type=general
    ```
    """
    from litellm.proxy.proxy_server import (
        llm_router,
        prisma_client,
        proxy_config,
        store_model_in_db,
    )

    try:
        # Guard: the router must be up before its fallback config can change.
        if llm_router is None:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail={"error": "Router not initialized"},
            )

        # This endpoint persists the change, so DB-backed config is required.
        if store_model_in_db is not True or prisma_client is None:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail={
                    "error": "Database storage not enabled. Set 'STORE_MODEL_IN_DB=True' in your environment to use this feature."
                },
            )

        # Load existing config
        config = await proxy_config.get_config()
        router_settings = config.get("router_settings", {})

        # Get the appropriate fallback list based on type
        # ("fallbacks" is the default/general bucket)
        fallback_key = "fallbacks"
        if fallback_type == "context_window":
            fallback_key = "context_window_fallbacks"
        elif fallback_type == "content_policy":
            fallback_key = "content_policy_fallbacks"

        # Get existing fallbacks; each entry maps one primary model name to
        # its ordered list of fallback models.
        existing_fallbacks: List[Dict[str, List[str]]] = router_settings.get(
            fallback_key, []
        )

        # Find and remove the fallback configuration by rebuilding the list
        # without any entry keyed on `model`.
        fallback_found = False
        updated_fallbacks = []
        for fallback_dict in existing_fallbacks:
            if model not in fallback_dict:
                updated_fallbacks.append(fallback_dict)
            else:
                fallback_found = True

        if not fallback_found:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail={
                    "error": f"No {fallback_type} fallbacks configured for model '{model}'"
                },
            )

        # Update router settings
        router_settings[fallback_key] = updated_fallbacks

        # Save to database - convert router_settings to JSON string
        # (litellm_config stores param_value as serialized JSON)
        router_settings_json = json.dumps(router_settings)
        await prisma_client.db.litellm_config.upsert(
            where={"param_name": "router_settings"},
            data={
                "create": {
                    "param_name": "router_settings",
                    "param_value": router_settings_json,
                },
                "update": {"param_value": router_settings_json},
            },
        )

        # Update the in-memory router configuration only after the DB write
        # succeeded, so the live router never gets ahead of persisted state.
        setattr(llm_router, fallback_key, updated_fallbacks)

        verbose_proxy_logger.info(f"Fallback deleted: {model} (type: {fallback_type})")

        return FallbackDeleteResponse(
            model=model,
            fallback_type=fallback_type,
            message="Fallback configuration deleted successfully",
        )

    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.error(f"Error deleting fallback: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={"error": f"Failed to delete fallback: {str(e)}"},
        )
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,256 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
|
||||
from litellm.proxy._types import (
|
||||
CreateJWTKeyMappingRequest,
|
||||
DeleteJWTKeyMappingRequest,
|
||||
JWTKeyMappingResponse,
|
||||
LitellmUserRoles,
|
||||
UpdateJWTKeyMappingRequest,
|
||||
UserAPIKeyAuth,
|
||||
hash_token,
|
||||
)
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _to_response(mapping) -> JWTKeyMappingResponse:
    """Build a JWTKeyMappingResponse from a Prisma row, omitting the hashed token."""
    # Explicit allow-list of exposed fields — the hashed `token` is deliberately absent.
    exposed_fields = (
        "id",
        "jwt_claim_name",
        "jwt_claim_value",
        "description",
        "is_active",
        "created_at",
        "updated_at",
        "created_by",
        "updated_by",
    )
    return JWTKeyMappingResponse(
        **{field: getattr(mapping, field) for field in exposed_fields}
    )
|
||||
|
||||
|
||||
@router.post(
    "/jwt/key/mapping/new",
    tags=["JWT Key Mapping"],
    response_model=JWTKeyMappingResponse,
)
async def create_jwt_key_mapping(
    data: CreateJWTKeyMappingRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """Create a JWT-claim -> virtual-key mapping (proxy-admin only)."""
    from litellm.proxy.proxy_server import prisma_client, user_api_key_cache

    if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
        raise HTTPException(
            status_code=403, detail="Only proxy admins can create JWT key mappings"
        )

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")

    try:
        # Only the hash of the key is persisted, never the raw key.
        mapping_payload = {
            "jwt_claim_name": data.jwt_claim_name,
            "jwt_claim_value": data.jwt_claim_value,
            "token": hash_token(data.key),
            "created_by": user_api_key_dict.user_id,
            "updated_by": user_api_key_dict.user_id,
        }
        if data.description is not None:
            mapping_payload["description"] = data.description

        created_mapping = await prisma_client.db.litellm_jwtkeymapping.create(
            data=mapping_payload
        )

        # Drop any stale cached lookup for this claim pair.
        await user_api_key_cache.async_delete_cache(
            f"jwt_key_mapping:{data.jwt_claim_name}:{data.jwt_claim_value}"
        )

        return _to_response(created_mapping)
    except HTTPException:
        raise
    except Exception as e:
        # Map common Prisma error codes to friendly HTTP errors.
        lowered_error = str(e).lower()
        if "unique" in lowered_error or "p2002" in lowered_error:
            raise HTTPException(
                status_code=409,
                detail=f"A mapping for claim '{data.jwt_claim_name}' = '{data.jwt_claim_value}' already exists.",
            )
        if "foreign" in lowered_error or "p2003" in lowered_error:
            raise HTTPException(
                status_code=400,
                detail="The provided key does not match an existing virtual key.",
            )
        raise HTTPException(status_code=500, detail="Failed to create JWT key mapping.")
|
||||
|
||||
|
||||
@router.post(
    "/jwt/key/mapping/update",
    tags=["JWT Key Mapping"],
    response_model=JWTKeyMappingResponse,
)
async def update_jwt_key_mapping(
    data: UpdateJWTKeyMappingRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """Update an existing JWT-claim -> virtual-key mapping (proxy-admin only)."""
    from litellm.proxy.proxy_server import prisma_client, user_api_key_cache

    if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
        raise HTTPException(
            status_code=403, detail="Only proxy admins can update JWT key mappings"
        )

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")

    # Build the DB patch from fields the caller actually set; `id` selects the
    # row and `key` is handled separately (only its hash is stored).
    update_data = data.model_dump(exclude_unset=True, exclude={"id", "key"})
    if data.key is not None:
        update_data["token"] = hash_token(data.key)
    update_data["updated_by"] = user_api_key_dict.user_id

    try:
        # Get old mapping for cache invalidation
        old_mapping = await prisma_client.db.litellm_jwtkeymapping.find_unique(
            where={"id": data.id}
        )

        if old_mapping is None:
            raise HTTPException(status_code=404, detail="Mapping not found")

        # Invalidate the cache entry keyed on the OLD claim values before the
        # update, so the stale mapping cannot be served mid-change.
        cache_key = f"jwt_key_mapping:{old_mapping.jwt_claim_name}:{old_mapping.jwt_claim_value}"
        await user_api_key_cache.async_delete_cache(cache_key)

        updated_mapping = await prisma_client.db.litellm_jwtkeymapping.update(
            where={"id": data.id}, data=update_data
        )

        # Invalidate new cache key if claim fields changed
        # (no-op when the claim pair is unchanged — same key as above)
        cache_key = f"jwt_key_mapping:{updated_mapping.jwt_claim_name}:{updated_mapping.jwt_claim_value}"
        await user_api_key_cache.async_delete_cache(cache_key)

        return _to_response(updated_mapping)
    except HTTPException:
        raise
    except Exception as e:
        # Map common Prisma error codes to friendly HTTP errors.
        error_str = str(e).lower()
        if "unique" in error_str or "p2002" in error_str:
            raise HTTPException(
                status_code=409,
                detail="A mapping with those claim values already exists.",
            )
        if "foreign" in error_str or "p2003" in error_str:
            raise HTTPException(
                status_code=400,
                detail="The provided key does not match an existing virtual key.",
            )
        raise HTTPException(status_code=500, detail="Failed to update JWT key mapping.")
|
||||
|
||||
|
||||
@router.post("/jwt/key/mapping/delete", tags=["JWT Key Mapping"])
async def delete_jwt_key_mapping(
    data: DeleteJWTKeyMappingRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """Delete a JWT key mapping by id (proxy-admin only)."""
    from litellm.proxy.proxy_server import prisma_client, user_api_key_cache

    if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
        raise HTTPException(
            status_code=403, detail="Only proxy admins can delete JWT key mappings"
        )

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")

    try:
        # Look the row up first so its claim pair can be evicted from cache.
        mapping_to_delete = await prisma_client.db.litellm_jwtkeymapping.find_unique(
            where={"id": data.id}
        )

        if mapping_to_delete is None:
            raise HTTPException(status_code=404, detail="Mapping not found")

        await user_api_key_cache.async_delete_cache(
            f"jwt_key_mapping:{mapping_to_delete.jwt_claim_name}:{mapping_to_delete.jwt_claim_value}"
        )

        await prisma_client.db.litellm_jwtkeymapping.delete(where={"id": data.id})
        return {"status": "success"}
    except HTTPException:
        raise
    except Exception:
        raise HTTPException(status_code=500, detail="Failed to delete JWT key mapping.")
|
||||
|
||||
|
||||
@router.get(
    "/jwt/key/mapping/list",
    tags=["JWT Key Mapping"],
)
async def list_jwt_key_mappings(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
    page: int = Query(1, description="Page number", ge=1),
    size: int = Query(50, description="Page size", ge=1, le=100),
):
    """
    List JWT -> virtual key mappings, newest first, with pagination.

    Proxy-admin only.

    Returns:
        dict with keys ``mappings`` (serialized rows), ``total_count``,
        ``current_page`` and ``total_pages``.

    Raises:
        HTTPException 403: caller is not a proxy admin.
        HTTPException 500: database not connected, or query failed.
    """
    from litellm.proxy.proxy_server import prisma_client

    if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
        raise HTTPException(
            status_code=403, detail="Only proxy admins can list JWT key mappings"
        )

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")

    try:
        # Offset-based pagination; page is 1-indexed (enforced by Query ge=1).
        skip = (page - 1) * size
        mappings = await prisma_client.db.litellm_jwtkeymapping.find_many(
            skip=skip,
            take=size,
            order={"created_at": "desc"},
        )
        total_count = await prisma_client.db.litellm_jwtkeymapping.count()
        return {
            "mappings": [_to_response(m) for m in mappings],
            "total_count": total_count,
            "current_page": page,
            "total_pages": -(-total_count // size),  # ceiling division
        }
    except HTTPException:
        raise
    except Exception:
        raise HTTPException(status_code=500, detail="Failed to list JWT key mappings.")
|
||||
|
||||
|
||||
@router.get(
    "/jwt/key/mapping/info",
    tags=["JWT Key Mapping"],
    response_model=JWTKeyMappingResponse,
)
async def info_jwt_key_mapping(
    id: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Fetch a single JWT -> virtual key mapping by its id.

    Proxy-admin only.

    Raises:
        HTTPException 403: caller is not a proxy admin.
        HTTPException 404: no mapping exists with the given id.
        HTTPException 500: database not connected, or lookup failed.
    """
    from litellm.proxy.proxy_server import prisma_client

    if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
        raise HTTPException(
            status_code=403, detail="Only proxy admins can get JWT key mapping info"
        )

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")

    try:
        mapping = await prisma_client.db.litellm_jwtkeymapping.find_unique(
            where={"id": id}
        )
        if mapping is None:
            raise HTTPException(status_code=404, detail="Mapping not found")
        return _to_response(mapping)
    except HTTPException:
        # Preserve the 404 (and 403) instead of masking it as a 500.
        raise
    except Exception:
        raise HTTPException(
            status_code=500, detail="Failed to get JWT key mapping info."
        )
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,765 @@
|
||||
"""
|
||||
Allow proxy admin to manage model access groups
|
||||
|
||||
Endpoints here:
|
||||
- POST /model_group/new - Create a new access group with multiple model names
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
|
||||
# Clear cache and reload models to pick up the access group changes
|
||||
from litellm.proxy.management_endpoints.model_management_endpoints import (
|
||||
clear_cache,
|
||||
)
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from litellm.types.proxy.management_endpoints.model_management_endpoints import (
|
||||
AccessGroupInfo,
|
||||
DeleteModelGroupResponse,
|
||||
ListAccessGroupsResponse,
|
||||
NewModelGroupRequest,
|
||||
NewModelGroupResponse,
|
||||
UpdateModelGroupRequest,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def validate_models_exist(model_names: List[str], llm_router) -> Tuple[bool, List[str]]:
    """Check that every requested model name is known to the router.

    Only exact model-name matches count; no aliasing or wildcards.

    Args:
        model_names: Model names requested for the access group.
        llm_router: Router exposing ``get_model_names()``; may be ``None``.

    Returns:
        Tuple[bool, List[str]]: ``(all_valid, missing_models)``. When no
        router is available, every requested name is reported as missing.
    """
    # Without a router nothing can be validated - report all names missing.
    if llm_router is None:
        return False, model_names

    known_names = set(llm_router.get_model_names())
    not_found = []
    for candidate in model_names:
        if candidate not in known_names:
            not_found.append(candidate)
    return (not not_found, not_found)
|
||||
|
||||
|
||||
def add_access_group_to_deployment(
    model_info: Dict[str, Any], access_group: str
) -> Tuple[Dict[str, Any], bool]:
    """Tag a deployment's ``model_info`` with *access_group*, in place.

    Args:
        model_info: The deployment's ``model_info`` dict (mutated in place).
        access_group: The access group name to add.

    Returns:
        Tuple[Dict[str, Any], bool]: the same dict object, and whether it
        was modified (False when the group was already present).
    """
    # setdefault gives us the existing list, or installs an empty one.
    groups = model_info.setdefault("access_groups", [])
    if access_group in groups:
        # Already tagged - nothing to persist.
        return model_info, False
    groups.append(access_group)
    return model_info, True
|
||||
|
||||
|
||||
async def update_deployments_with_access_group(
    model_names: List[str],
    access_group: str,
    prisma_client: PrismaClient,
) -> int:
    """
    Update all deployments for the given model names to include the access group.

    Args:
        model_names: List of model names whose deployments should be updated
        access_group: The access group name to add
        prisma_client: Database client

    Returns:
        int: Number of deployments updated

    Raises:
        HTTPException 400: a model name has no DB-backed deployments
            (config-file models cannot be tagged).
    """
    models_updated = 0

    for model_name in model_names:
        verbose_proxy_logger.debug(f"Updating deployments for model_name: {model_name}")

        # Get all deployments with this model_name
        deployments = await prisma_client.db.litellm_proxymodeltable.find_many(
            where={"model_name": model_name}
        )

        verbose_proxy_logger.debug(
            f"Found {len(deployments)} deployments for model_name: {model_name}"
        )

        # If no deployments found, this is a config model (not in DB)
        if len(deployments) == 0:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": f"Can't find model '{model_name}' in Database. Access group management is only supported for database models."
                },
            )

        # Update each deployment
        for deployment in deployments:
            model_info = deployment.model_info or {}

            # Add access group using helper
            updated_model_info, was_modified = add_access_group_to_deployment(
                model_info=model_info,
                access_group=access_group,
            )

            # Only update in DB if modified
            if was_modified:
                await prisma_client.db.litellm_proxymodeltable.update(
                    where={"model_id": deployment.model_id},
                    data={"model_info": json.dumps(updated_model_info)},
                )

            # NOTE(review): this counter increments for every deployment
            # visited, even when was_modified is False, unlike
            # update_specific_deployments_with_access_group which counts
            # only modified rows - confirm the asymmetry is intended.
            models_updated += 1
            verbose_proxy_logger.debug(
                f"Updated deployment {deployment.model_id} with access group: {access_group}"
            )

    return models_updated
|
||||
|
||||
|
||||
async def update_specific_deployments_with_access_group(
    model_ids: List[str],
    access_group: str,
    prisma_client: PrismaClient,
) -> int:
    """
    Update specific deployments (by model_id) to include the access group.

    Unlike update_deployments_with_access_group which tags ALL deployments sharing
    a model_name, this function only tags the specific deployments identified by
    their unique model_id.

    Args:
        model_ids: Unique deployment ids to tag.
        access_group: The access group name to add.
        prisma_client: Database client.

    Returns:
        int: Number of deployments actually modified in the DB.

    Raises:
        HTTPException 400: a model_id does not exist in the database.
    """
    models_updated = 0
    for model_id in model_ids:
        verbose_proxy_logger.debug(f"Updating specific deployment model_id: {model_id}")
        deployment = await prisma_client.db.litellm_proxymodeltable.find_unique(
            where={"model_id": model_id}
        )
        if deployment is None:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": f"Deployment with model_id '{model_id}' not found in Database."
                },
            )
        model_info = deployment.model_info or {}
        updated_model_info, was_modified = add_access_group_to_deployment(
            model_info=model_info,
            access_group=access_group,
        )
        # Write back (and count) only when the group was newly added.
        if was_modified:
            await prisma_client.db.litellm_proxymodeltable.update(
                where={"model_id": model_id},
                data={"model_info": json.dumps(updated_model_info)},
            )
            models_updated += 1
            verbose_proxy_logger.debug(
                f"Updated deployment {model_id} with access group: {access_group}"
            )
    return models_updated
|
||||
|
||||
|
||||
def remove_access_group_from_deployment(
    model_info: Dict[str, Any], access_group: str
) -> Tuple[Dict[str, Any], bool]:
    """Strip *access_group* from a deployment's ``model_info``, in place.

    Args:
        model_info: The deployment's ``model_info`` dict (mutated in place).
        access_group: The access group name to remove.

    Returns:
        Tuple[Dict[str, Any], bool]: the same dict object, and whether it
        was modified (False when the group was not present).
    """
    groups = model_info.get("access_groups", [])
    if access_group in groups:
        groups.remove(access_group)
        model_info["access_groups"] = groups
        return model_info, True
    # Group absent (or no access_groups key at all) - leave dict untouched.
    return model_info, False
|
||||
|
||||
|
||||
async def get_all_access_groups_from_db(
    prisma_client: PrismaClient,
) -> Dict[str, AccessGroupInfo]:
    """
    Get all access groups from the database.

    Scans every deployment row and aggregates the ``access_groups`` lists
    stored in each deployment's ``model_info`` JSON.

    Returns:
        Dict[str, AccessGroupInfo]: Dictionary mapping access_group name to info
    """
    # Get all deployments
    deployments = await prisma_client.db.litellm_proxymodeltable.find_many()

    # Build access group map: name -> {"model_names": set, "deployment_count": int}
    access_group_map: Dict[str, Dict[str, Any]] = {}

    for deployment in deployments:
        model_info = deployment.model_info or {}
        access_groups = model_info.get("access_groups", [])
        model_name = deployment.model_name

        for access_group in access_groups:
            if access_group not in access_group_map:
                access_group_map[access_group] = {
                    "model_names": set(),
                    "deployment_count": 0,
                }

            # A set de-dupes model names shared by multiple deployments.
            access_group_map[access_group]["model_names"].add(model_name)
            access_group_map[access_group]["deployment_count"] += 1

    # Convert to AccessGroupInfo objects
    result = {}
    for access_group, data in access_group_map.items():
        result[access_group] = AccessGroupInfo(
            access_group=access_group,
            # sorted() for a deterministic, stable API response ordering
            model_names=sorted(list(data["model_names"])),
            deployment_count=data["deployment_count"],
        )

    return result
|
||||
|
||||
|
||||
@router.post(
    "/access_group/new",
    tags=["model management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=NewModelGroupResponse,
)
async def create_model_group(
    data: NewModelGroupRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Create a new access group containing multiple model names.

    An access group is a named collection of model groups that can be referenced
    by teams/keys for simplified access control.

    Example:
    ```bash
    curl -X POST 'http://localhost:4000/access_group/new' \\
        -H 'Authorization: Bearer sk-1234' \\
        -H 'Content-Type: application/json' \\
        -d '{
            "access_group": "production-models",
            "model_names": ["gpt-4", "claude-3-opus", "gemini-pro"]
        }'
    ```

    Parameters:
    - access_group: str - The access group name (e.g., "production-models")
    - model_names: List[str] - List of existing model groups to include

    Returns:
    - NewModelGroupResponse with the created access group details

    Raises:
    - HTTPException 400: If any model names don't exist
    - HTTPException 500: If database operations fail
    """
    from litellm.proxy.proxy_server import (
        llm_router,
        prisma_client,
    )

    verbose_proxy_logger.debug(
        f"Creating access group: {data.access_group} with models: {data.model_names}"
    )

    # Validation: Check if access_group is provided
    if not data.access_group or not data.access_group.strip():
        raise HTTPException(
            status_code=400,
            detail={"error": "access_group is required and cannot be empty"},
        )

    # Validation: Check that at least one of model_names or model_ids is provided
    has_model_names = data.model_names and len(data.model_names) > 0
    has_model_ids = data.model_ids and len(data.model_ids) > 0

    if not has_model_names and not has_model_ids:
        raise HTTPException(
            status_code=400,
            detail={
                "error": "Either model_names or model_ids must be provided and non-empty"
            },
        )

    # If model_ids is provided, use it (more precise targeting)
    # NOTE(review): when both are supplied, model_names is silently ignored
    # for targeting (but still echoed in the response) - confirm intended.
    use_model_ids = has_model_ids

    # Validate model_names exist in router (only if using model_names path)
    if not use_model_ids and has_model_names:
        assert data.model_names is not None
        all_valid, missing_models = validate_models_exist(
            model_names=data.model_names,
            llm_router=llm_router,
        )

        if not all_valid:
            raise HTTPException(
                status_code=400,
                detail={"error": f"Model(s) not found: {', '.join(missing_models)}"},
            )

    # Check if database is connected
    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": "Database not connected. Cannot create access group."},
        )

    try:
        # Check if access group already exists
        existing_access_groups = await get_all_access_groups_from_db(
            prisma_client=prisma_client
        )

        if data.access_group in existing_access_groups:
            raise HTTPException(
                status_code=409,
                detail={
                    "error": f"Access group '{data.access_group}' already exists. Use PUT /access_group/{data.access_group}/update to modify it."
                },
            )

        # Update deployments using the appropriate method
        if use_model_ids:
            assert data.model_ids is not None
            models_updated = await update_specific_deployments_with_access_group(
                model_ids=data.model_ids,
                access_group=data.access_group,
                prisma_client=prisma_client,
            )
        else:
            assert data.model_names is not None
            models_updated = await update_deployments_with_access_group(
                model_names=data.model_names,
                access_group=data.access_group,
                prisma_client=prisma_client,
            )

        # Clear cache so routing picks up the new access group tags.
        await clear_cache()

        verbose_proxy_logger.info(
            f"Successfully created access group '{data.access_group}' with {models_updated} models updated"
        )

        return NewModelGroupResponse(
            access_group=data.access_group,
            model_names=data.model_names,
            model_ids=data.model_ids,
            models_updated=models_updated,
        )

    except HTTPException:
        # Preserve deliberate HTTP errors (400/409) raised above.
        raise
    except Exception as e:
        verbose_proxy_logger.exception(
            f"Error creating access group '{data.access_group}': {str(e)}"
        )
        raise HTTPException(
            status_code=500,
            detail={"error": f"Failed to create access group: {str(e)}"},
        )
|
||||
|
||||
|
||||
@router.get(
    "/access_group/list",
    tags=["model management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=ListAccessGroupsResponse,
)
async def list_access_groups(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    List all access groups.

    Returns a list of all access groups with their model names and deployment counts.

    Example:
    ```bash
    curl -X GET 'http://localhost:4000/access_group/list' \\
        -H 'Authorization: Bearer sk-1234'
    ```

    Returns:
        - ListAccessGroupsResponse with all access groups

    Raises:
        - HTTPException 500: If the database is unavailable or the scan fails
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": "Database not connected."},
        )

    try:
        # Access groups live inside deployments' model_info; aggregate them.
        access_groups_map = await get_all_access_groups_from_db(
            prisma_client=prisma_client
        )

        # Sort by access group name
        access_groups_list = sorted(
            access_groups_map.values(),
            key=lambda x: x.access_group,
        )

        return ListAccessGroupsResponse(access_groups=access_groups_list)

    except Exception as e:
        verbose_proxy_logger.exception(f"Error listing access groups: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail={"error": f"Failed to list access groups: {str(e)}"},
        )
|
||||
|
||||
|
||||
@router.get(
    "/access_group/{access_group}/info",
    tags=["model management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=AccessGroupInfo,
)
async def get_access_group_info(
    access_group: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get information about a specific access group.

    Example:
    ```bash
    curl -X GET 'http://localhost:4000/access_group/production-models/info' \\
        -H 'Authorization: Bearer sk-1234'
    ```

    Parameters:
    - access_group: str - The access group name (URL path parameter)

    Returns:
    - AccessGroupInfo with the access group details

    Raises:
    - HTTPException 404: If access group not found
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": "Database not connected."},
        )

    try:
        # Aggregate groups from all deployments, then pick the requested one.
        access_groups_map = await get_all_access_groups_from_db(
            prisma_client=prisma_client
        )

        if access_group not in access_groups_map:
            raise HTTPException(
                status_code=404,
                detail={"error": f"Access group '{access_group}' not found"},
            )

        return access_groups_map[access_group]

    except HTTPException:
        # Preserve the 404 instead of masking it as a 500.
        raise
    except Exception as e:
        verbose_proxy_logger.exception(
            f"Error getting access group info for '{access_group}': {str(e)}"
        )
        raise HTTPException(
            status_code=500,
            detail={"error": f"Failed to get access group info: {str(e)}"},
        )
|
||||
|
||||
|
||||
@router.put(
    "/access_group/{access_group}/update",
    tags=["model management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=NewModelGroupResponse,
)
async def update_access_group(
    access_group: str,
    data: UpdateModelGroupRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Update an access group's model names.

    This will:
    1. Remove the access group from all current deployments
    2. Add the access group to all deployments for the new model_names list

    Example:
    ```bash
    curl -X PUT 'http://localhost:4000/access_group/production-models/update' \\
        -H 'Authorization: Bearer sk-1234' \\
        -H 'Content-Type: application/json' \\
        -d '{
            "model_names": ["gpt-4", "claude-3-sonnet"]
        }'
    ```

    Parameters:
    - access_group: str - The access group name (URL path parameter)
    - model_names: List[str] - New list of model groups to include

    Returns:
    - NewModelGroupResponse with the updated access group details

    Raises:
    - HTTPException 400: If any model names don't exist
    - HTTPException 404: If access group not found
    """
    from litellm.proxy.proxy_server import llm_router, prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": "Database not connected."},
        )

    verbose_proxy_logger.debug(
        f"Updating access group: {access_group} with models: {data.model_names}"
    )

    # Validation: Check that at least one of model_names or model_ids is provided
    has_model_names = data.model_names and len(data.model_names) > 0
    has_model_ids = data.model_ids and len(data.model_ids) > 0

    if not has_model_names and not has_model_ids:
        raise HTTPException(
            status_code=400,
            detail={
                "error": "Either model_names or model_ids must be provided and non-empty"
            },
        )

    # model_ids take precedence when both are provided (precise targeting).
    use_model_ids = has_model_ids

    # Validation: Check if access group exists
    try:
        access_groups_map = await get_all_access_groups_from_db(
            prisma_client=prisma_client
        )
        if access_group not in access_groups_map:
            raise HTTPException(
                status_code=404,
                detail={"error": f"Access group '{access_group}' not found"},
            )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail={"error": f"Failed to check access group existence: {str(e)}"},
        )

    # Validation: Check if all new models exist (only if using model_names path)
    if not use_model_ids and has_model_names:
        assert data.model_names is not None
        all_valid, missing_models = validate_models_exist(
            model_names=data.model_names,
            llm_router=llm_router,
        )

        if not all_valid:
            raise HTTPException(
                status_code=400,
                detail={"error": f"Model(s) not found: {', '.join(missing_models)}"},
            )

    try:
        # Step 1: Remove access group from ALL DB deployments (skip config models)
        # NOTE(review): steps 1 and 2 are not atomic - a failure between them
        # leaves the group stripped but not re-applied. Confirm acceptable.
        all_deployments = await prisma_client.db.litellm_proxymodeltable.find_many()

        for deployment in all_deployments:
            model_info = deployment.model_info or {}

            updated_model_info, was_modified = remove_access_group_from_deployment(
                model_info=model_info,
                access_group=access_group,
            )

            if was_modified:
                await prisma_client.db.litellm_proxymodeltable.update(
                    where={"model_id": deployment.model_id},
                    data={"model_info": json.dumps(updated_model_info)},
                )

        # Step 2: Add access group using the appropriate method
        if use_model_ids:
            assert data.model_ids is not None
            models_updated = await update_specific_deployments_with_access_group(
                model_ids=data.model_ids,
                access_group=access_group,
                prisma_client=prisma_client,
            )
        else:
            assert data.model_names is not None
            models_updated = await update_deployments_with_access_group(
                model_names=data.model_names,
                access_group=access_group,
                prisma_client=prisma_client,
            )

        # Clear cache and reload models to pick up the access group changes
        await clear_cache()

        verbose_proxy_logger.info(
            f"Successfully updated access group '{access_group}' with {models_updated} models updated"
        )

        return NewModelGroupResponse(
            access_group=access_group,
            model_names=data.model_names,
            model_ids=data.model_ids,
            models_updated=models_updated,
        )

    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(
            f"Error updating access group '{access_group}': {str(e)}"
        )
        raise HTTPException(
            status_code=500,
            detail={"error": f"Failed to update access group: {str(e)}"},
        )
|
||||
|
||||
|
||||
@router.delete(
    "/access_group/{access_group}/delete",
    tags=["model management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=DeleteModelGroupResponse,
)
async def delete_access_group(
    access_group: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Delete an access group.

    Removes the access group from all deployments that have it.

    Example:
    ```bash
    curl -X DELETE 'http://localhost:4000/access_group/production-models/delete' \\
        -H 'Authorization: Bearer sk-1234'
    ```

    Parameters:
    - access_group: str - The access group name (URL path parameter)

    Returns:
    - DeleteModelGroupResponse with deletion details

    Raises:
    - HTTPException 404: If access group not found
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": "Database not connected."},
        )

    verbose_proxy_logger.debug(f"Deleting access group: {access_group}")

    # Validation: Check if access group exists
    try:
        access_groups_map = await get_all_access_groups_from_db(
            prisma_client=prisma_client
        )
        if access_group not in access_groups_map:
            raise HTTPException(
                status_code=404,
                detail={"error": f"Access group '{access_group}' not found"},
            )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail={"error": f"Failed to check access group existence: {str(e)}"},
        )

    try:
        # Remove access group from all DB deployments (skip config models)
        all_deployments = await prisma_client.db.litellm_proxymodeltable.find_many()
        models_updated = 0

        for deployment in all_deployments:
            model_info = deployment.model_info or {}

            updated_model_info, was_modified = remove_access_group_from_deployment(
                model_info=model_info,
                access_group=access_group,
            )

            # Only write (and count) rows that actually carried the group.
            if was_modified:
                await prisma_client.db.litellm_proxymodeltable.update(
                    where={"model_id": deployment.model_id},
                    data={"model_info": json.dumps(updated_model_info)},
                )
                models_updated += 1

        # Clear cache and reload models to pick up the access group changes
        await clear_cache()

        verbose_proxy_logger.info(
            f"Successfully deleted access group '{access_group}' from {models_updated} deployments"
        )

        return DeleteModelGroupResponse(
            access_group=access_group,
            models_updated=models_updated,
            message=f"Access group '{access_group}' deleted successfully",
        )

    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(
            f"Error deleting access group '{access_group}': {str(e)}"
        )
        raise HTTPException(
            status_code=500,
            detail={"error": f"Failed to delete access group: {str(e)}"},
        )
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,20 @@
|
||||
"""
|
||||
Policy endpoints package.
|
||||
|
||||
Re-exports everything from endpoints module so existing imports
|
||||
like `from litellm.proxy.management_endpoints.policy_endpoints import router`
|
||||
continue to work. Patch targets also resolve correctly since names
|
||||
are imported directly into this namespace.
|
||||
"""
|
||||
|
||||
from litellm.proxy.management_endpoints.policy_endpoints.endpoints import * # noqa: F401, F403
|
||||
from litellm.proxy.management_endpoints.policy_endpoints.endpoints import ( # noqa: F401
|
||||
_build_all_names_per_competitor,
|
||||
_build_comparison_blocked_words,
|
||||
_build_competitor_guardrail_definitions,
|
||||
_build_name_blocked_words,
|
||||
_build_recommendation_blocked_words,
|
||||
_build_refinement_prompt,
|
||||
_clean_competitor_line,
|
||||
_parse_variations_response,
|
||||
)
|
||||
@@ -0,0 +1,131 @@
|
||||
"""
|
||||
AI Policy Suggester - uses LLM tool calling to suggest policy templates
|
||||
based on user-provided attack examples and descriptions.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import List, Optional
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.constants import DEFAULT_COMPETITOR_DISCOVERY_MODEL
|
||||
|
||||
# OpenAI function-calling tool schema used by AiPolicySuggester: forces the
# LLM to return its template picks as structured JSON - a list of
# {template_id, reason} entries plus one overall explanation - rather than
# free-form prose that would need fragile parsing.
SUGGEST_TOOL = {
    "type": "function",
    "function": {
        "name": "select_policy_templates",
        "description": "Select one or more policy templates that best match the user's security requirements",
        "parameters": {
            "type": "object",
            "properties": {
                "selected_templates": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "template_id": {
                                "type": "string",
                                "description": "The ID of the selected template",
                            },
                            "reason": {
                                "type": "string",
                                "description": "Brief reason why this template matches",
                            },
                        },
                        "required": ["template_id", "reason"],
                    },
                    "description": "List of templates that match the user's requirements",
                },
                "explanation": {
                    "type": "string",
                    "description": "Overall explanation of why these templates were suggested",
                },
            },
            "required": ["selected_templates", "explanation"],
        },
    },
}
|
||||
|
||||
|
||||
class AiPolicySuggester:
    """Suggests policy templates using LLM tool calling."""

    async def suggest(
        self,
        templates: list,
        attack_examples: List[str],
        description: str,
        model: Optional[str] = None,
    ) -> dict:
        """Ask an LLM to pick matching templates for the user's requirements.

        Args:
            templates: Candidate template dicts; each must have an ``id``
                (``title``/``description``/``example_sentences`` are read by
                the prompt builder).
            attack_examples: Example attack prompts; blank entries are ignored.
            description: Free-text description of what to block.
            model: Override model; defaults to DEFAULT_COMPETITOR_DISCOVERY_MODEL.

        Returns:
            dict with ``selected_templates`` (filtered to known template ids)
            and ``explanation``.

        Raises:
            Re-raises any error from the LLM call / JSON parsing after logging.
        """
        system_prompt = self._build_system_prompt(templates)
        user_prompt = self._build_user_prompt(attack_examples, description)
        model = model or DEFAULT_COMPETITOR_DISCOVERY_MODEL

        try:
            # tool_choice pins the model to our structured tool; low
            # temperature keeps selections deterministic-ish.
            response = await litellm.acompletion(
                model=model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                tools=[SUGGEST_TOOL],
                tool_choice={
                    "type": "function",
                    "function": {"name": "select_policy_templates"},
                },
                temperature=0.2,
            )

            tool_calls = response.choices[0].message.tool_calls  # type: ignore
            if not tool_calls:
                # Model declined to call the tool - treat as "no match".
                return {
                    "selected_templates": [],
                    "explanation": "No templates could be matched to your requirements.",
                }

            result = json.loads(tool_calls[0].function.arguments)

            # Drop hallucinated template ids the caller never offered.
            valid_ids = {t["id"] for t in templates}
            result["selected_templates"] = [
                s
                for s in result.get("selected_templates", [])
                if s.get("template_id") in valid_ids
            ]

            return result
        except Exception as e:
            verbose_proxy_logger.error("AI policy suggestion failed: %s", e)
            raise

    def _build_system_prompt(self, templates: list) -> str:
        """Render the advisor system prompt listing every available template."""
        template_descriptions = []
        for t in templates:
            examples = t.get("example_sentences", [])
            examples_str = ", ".join(f'"{e}"' for e in examples) if examples else "none"
            entry = (
                f"- ID: {t['id']}\n"
                f"  Title: {t['title']}\n"
                f"  Description: {t['description']}\n"
                f"  Example attacks it protects against: {examples_str}"
            )
            template_descriptions.append(entry)

        return (
            "You are a security policy advisor. The user will describe attacks or content "
            "they want to block. Your job is to select the most relevant policy templates "
            "from the available set. Use the select_policy_templates tool to return your "
            "selections. Only select templates that are clearly relevant to what the user "
            "wants to block.\n\n"
            "Available templates:\n\n" + "\n\n".join(template_descriptions)
        )

    def _build_user_prompt(self, attack_examples: List[str], description: str) -> str:
        """Render the user prompt from non-blank examples plus the description."""
        parts = []
        filtered_examples = [e for e in attack_examples if e.strip()]
        if filtered_examples:
            parts.append("Example attack prompts I want to block:")
            for i, ex in enumerate(filtered_examples, 1):
                parts.append(f"  {i}. {ex}")
        if description.strip():
            parts.append(f"\nDescription of what I want to block: {description}")
        return "\n".join(parts)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,936 @@
|
||||
"""
|
||||
Endpoints for /project operations
|
||||
|
||||
/project/new
|
||||
/project/update
|
||||
/project/delete
|
||||
/project/info
|
||||
/project/list
|
||||
"""
|
||||
|
||||
#### PROJECT MANAGEMENT ####
|
||||
|
||||
import json
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm._uuid import uuid
|
||||
from litellm.proxy._types import *
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.management_endpoints.common_utils import _set_object_metadata_field
|
||||
from litellm.proxy.management_helpers.utils import (
|
||||
management_endpoint_wrapper,
|
||||
)
|
||||
from litellm.proxy.utils import PrismaClient, handle_exception_on_proxy
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
async def _check_user_permission_for_project(
    user_api_key_dict: UserAPIKeyAuth,
    team_id: Optional[str],
    prisma_client: PrismaClient,
    require_admin: bool = False,
    team_object: Optional[LiteLLM_TeamTable] = None,
) -> bool:
    """
    Check if user has permission to manage a project.

    Returns True if user is proxy admin or team admin (when team_id provided).
    If require_admin=True, only proxy admins are allowed.

    If team_object is provided, it will be used instead of fetching from DB
    (avoids duplicate DB queries when team was already fetched for validation).
    """
    is_proxy_admin = user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN

    # Admin-only mode and proxy-admin callers both resolve without a team
    # lookup: require_admin returns the admin bit itself, and a proxy admin
    # is always allowed.
    if require_admin or is_proxy_admin:
        return is_proxy_admin

    # Team-admin check needs both a team and an identifiable caller.
    if not team_id or not user_api_key_dict.user_id:
        return False

    resolved_team = team_object
    if resolved_team is None:
        resolved_team = await prisma_client.db.litellm_teamtable.find_unique(
            where={"team_id": team_id}
        )

    if resolved_team is not None and resolved_team.admins:
        return user_api_key_dict.user_id in resolved_team.admins

    return False
|
||||
|
||||
|
||||
async def _validate_team_exists(
    team_id: str,
    prisma_client: PrismaClient,
):
    """Validate that a team exists. Returns the team row.

    Raises a 404 ProxyException when no row matches ``team_id``.
    """
    team_row = await prisma_client.db.litellm_teamtable.find_unique(
        where={"team_id": team_id},
    )

    if team_row is not None:
        return team_row

    raise ProxyException(
        message=f"Team not found, team_id={team_id}",
        type="not_found",
        code=404,
        param="team_id",
    )
|
||||
|
||||
|
||||
def _check_team_project_limits(
    team_object: LiteLLM_TeamTable,
    data: Union[NewProjectRequest, UpdateProjectRequest],
) -> None:
    """
    Check that project limits respect its parent Team's limits.

    Mirrors _check_org_team_limits() from team_endpoints.py.

    Validates:
    - Project models are a subset of Team models
    - Project max_budget <= Team max_budget
    - Project tpm_limit <= Team tpm_limit
    - Project rpm_limit <= Team rpm_limit
    - Budget values are non-negative
    - soft_budget < max_budget
    """

    def _bad_request(message: str) -> None:
        # All validation failures surface as HTTP 400 with the same shape.
        raise HTTPException(status_code=400, detail={"error": message})

    # --- Budget non-negativity checks ---
    if data.max_budget is not None and data.max_budget < 0:
        _bad_request(f"max_budget cannot be negative. Received: {data.max_budget}")
    if data.soft_budget is not None and data.soft_budget < 0:
        _bad_request(f"soft_budget cannot be negative. Received: {data.soft_budget}")

    # --- soft_budget < max_budget ---
    if (
        data.soft_budget is not None
        and data.max_budget is not None
        and data.soft_budget >= data.max_budget
    ):
        _bad_request(
            f"soft_budget ({data.soft_budget}) must be strictly lower than max_budget ({data.max_budget})"
        )

    # --- Validate project models are a subset of team models ---
    project_models = getattr(data, "models", None)
    team_models = team_object.models or []
    if project_models and team_models:
        # If team has 'all-proxy-models', skip validation as it allows all models
        if SpecialModelNames.all_proxy_models.value not in team_models:
            for m in project_models:
                if m not in team_models:
                    _bad_request(
                        f"Model '{m}' not in team's allowed models. Team allowed models={team_models}. Team: {team_object.team_id}"
                    )

    # --- Validate project max_budget/tpm_limit/rpm_limit <= team values ---
    # Team stores these fields directly (unlike Project, which uses a
    # separate LiteLLM_BudgetTable relation), so compare attribute-for-attribute.
    for limit_attr in ("max_budget", "tpm_limit", "rpm_limit"):
        project_value = getattr(data, limit_attr, None)
        team_value = getattr(team_object, limit_attr, None)
        if (
            project_value is not None
            and team_value is not None
            and project_value > team_value
        ):
            _bad_request(
                f"Project {limit_attr} ({project_value}) exceeds team's {limit_attr} ({team_value}). Team: {team_object.team_id}"
            )
|
||||
|
||||
|
||||
async def _create_budget_for_project(
    data: NewProjectRequest,
    user_id: Optional[str],
    litellm_proxy_admin_name: str,
    prisma_client: PrismaClient,
) -> str:
    """Create a budget for the project and return budget_id."""
    # Project requests carry budget fields inline; keep only those that
    # actually belong on LiteLLM_BudgetTable.
    allowed_fields = LiteLLM_BudgetTable.model_fields.keys()
    # NOTE: .json() on litellm pydantic types yields a dict here (it is
    # iterated with .items() below) — presumably an override in _types.
    request_payload = data.json(exclude_none=True)
    budget_payload = {
        field: value
        for field, value in request_payload.items()
        if field in allowed_fields
    }
    budget_row = LiteLLM_BudgetTable(**budget_payload)

    serialized_budget = prisma_client.jsonify_object(budget_row.json(exclude_none=True))
    actor = user_id or litellm_proxy_admin_name

    created_budget = await prisma_client.db.litellm_budgettable.create(
        data={
            **serialized_budget,
            "created_by": actor,
            "updated_by": actor,
        }
    )

    return created_budget.budget_id
|
||||
|
||||
|
||||
async def _set_project_object_permission(
    data: NewProjectRequest,
    prisma_client: Optional[PrismaClient],
) -> Optional[str]:
    """
    Creates the LiteLLM_ObjectPermissionTable record for the project.
    Returns the object_permission_id if created, otherwise None.
    """
    if prisma_client is None or data.object_permission is None:
        return None

    permission_row = await prisma_client.db.litellm_objectpermissiontable.create(
        data=data.object_permission.model_dump(exclude_none=True),
    )
    # Drop the nested payload so later serialization of `data` does not try
    # to persist it on the project row itself.
    del data.object_permission
    return permission_row.object_permission_id
|
||||
|
||||
|
||||
def _remove_budget_fields_from_project_data(project_data: dict) -> dict:
    """
    Remove budget fields from project data.
    Budget fields belong to LiteLLM_BudgetTable, not LiteLLM_ProjectTable.
    Keep budget_id as it's a foreign key.

    Following the pattern from organization_endpoints.py
    """
    for budget_field in LiteLLM_BudgetTable.model_fields.keys():
        if budget_field == "budget_id":
            # Foreign key stays on the project row.
            continue
        project_data.pop(budget_field, None)
    return project_data
|
||||
|
||||
|
||||
@router.post(
    "/project/new",
    tags=["project management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=NewProjectResponse,
)
@management_endpoint_wrapper
async def new_project(
    data: NewProjectRequest,
    http_request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Create a new project. Projects sit between teams and keys in the hierarchy.

    Only admins or team admins can create projects.

    # Parameters

    - project_alias: *Optional[str]* - The name of the project.
    - description: *Optional[str]* - Description of the project's purpose and use case.
    - team_id: *str* - The team id that this project belongs to. Required.
    - models: *List* - The models the project has access to.
    - budget_id: *Optional[str]* - The id for a budget (tpm/rpm/max budget) for the project.
    ### IF NO BUDGET ID - CREATE ONE WITH THESE PARAMS ###
    - max_budget: *Optional[float]* - Max budget for project
    - tpm_limit: *Optional[int]* - Max tpm limit for project
    - rpm_limit: *Optional[int]* - Max rpm limit for project
    - max_parallel_requests: *Optional[int]* - Max parallel requests for project
    - soft_budget: *Optional[float]* - Get a slack alert when this soft budget is reached. Don't block requests.
    - model_max_budget: *Optional[dict]* - Max budget for a specific model. Example: {"gpt-4": 100.0, "gpt-3.5-turbo": 50.0}
    - model_rpm_limit: *Optional[dict]* - RPM limits per model. Example: {"gpt-4": 1000, "gpt-3.5-turbo": 5000}
    - model_tpm_limit: *Optional[dict]* - TPM limits per model. Example: {"gpt-4": 50000, "gpt-3.5-turbo": 100000}
    - budget_duration: *Optional[str]* - Frequency of reseting project budget
    - metadata: *Optional[dict]* - Metadata for project, store information for project. Example metadata - {"use_case_id": "SNOW-12345", "responsible_ai_id": "RAI-67890"}
    - tags: *Optional[list]* - Tags for the project. Example: ["production", "api"]
    - blocked: *bool* - Flag indicating if the project is blocked or not - will stop all calls from keys with this project_id.
    - object_permission: Optional[LiteLLM_ObjectPermissionBase] - project-specific object permission. Example - {"vector_stores": ["vector_store_1", "vector_store_2"]}. IF null or {} then no object permission.

    Example 1: Create new project **without** a budget_id, with model-specific limits

    ```bash
    curl --location 'http://0.0.0.0:4000/project/new' \\
    --header 'Authorization: Bearer sk-1234' \\
    --header 'Content-Type: application/json' \\
    --data '{
        "project_alias": "flight-search-assistant",
        "description": "AI-powered flight search and booking assistant",
        "team_id": "team-123",
        "models": ["gpt-4", "gpt-3.5-turbo"],
        "max_budget": 100,
        "model_rpm_limit": {
            "gpt-4": 1000,
            "gpt-3.5-turbo": 5000
        },
        "model_tpm_limit": {
            "gpt-4": 50000,
            "gpt-3.5-turbo": 100000
        },
        "metadata": {
            "use_case_id": "SNOW-12345",
            "responsible_ai_id": "RAI-67890"
        }
    }'
    ```

    Example 2: Create new project **with** a budget_id

    ```bash
    curl --location 'http://0.0.0.0:4000/project/new' \\
    --header 'Authorization: Bearer sk-1234' \\
    --header 'Content-Type: application/json' \\
    --data '{
        "project_alias": "hotel-recommendations",
        "description": "Personalized hotel recommendation engine",
        "team_id": "team-123",
        "models": ["claude-3-sonnet"],
        "budget_id": "428eeaa8-f3ac-4e85-a8fb-7dc8d7aa8689",
        "metadata": {
            "use_case_id": "SNOW-54321"
        }
    }'
    ```
    """
    # Imported lazily from proxy_server to read the live, post-startup values.
    from litellm.proxy.proxy_server import (
        litellm_proxy_admin_name,
        premium_user,
        prisma_client,
    )

    try:
        # Feature gating: tags, and project management as a whole, are
        # enterprise-only. Both checks raise 403 for non-premium users.
        if getattr(data, "tags", None) is not None and not premium_user:
            raise HTTPException(
                status_code=403,
                detail={
                    "error": "Only premium users can add tags to projects. "
                    + CommonProxyErrors.not_premium_user.value
                },
            )

        if not premium_user:
            raise HTTPException(
                status_code=403,
                detail={
                    "error": "Project management is an enterprise feature. "
                    + CommonProxyErrors.not_premium_user.value
                },
            )

        # ADD METADATA FIELDS
        # Premium-only metadata fields are moved off the request model into
        # its metadata dict, then removed so they don't hit the DB as columns.
        for field in LiteLLM_ManagementEndpoint_MetadataFields_Premium:
            if getattr(data, field, None) is not None:
                _set_object_metadata_field(
                    object_data=data,
                    field_name=field,
                    value=getattr(data, field),
                )
                delattr(data, field)

        if prisma_client is None:
            raise HTTPException(
                status_code=500,
                detail={"error": CommonProxyErrors.db_not_connected_error.value},
            )

        # Validate team exists and get team object with budget
        team_object = await _validate_team_exists(
            team_id=data.team_id, prisma_client=prisma_client
        )

        # Validate project limits against team limits
        _check_team_project_limits(
            team_object=LiteLLM_TeamTable(**team_object.model_dump()),
            data=data,
        )

        # Check if user has permission to create projects for this team
        # only team admins can create projects for their team
        has_permission = await _check_user_permission_for_project(
            user_api_key_dict=user_api_key_dict,
            team_id=data.team_id,
            prisma_client=prisma_client,
            team_object=LiteLLM_TeamTable(**team_object.model_dump()),
        )

        if not has_permission:
            raise HTTPException(
                status_code=403,
                detail={
                    "error": f"Only admins or team admins can create projects. Your role is {user_api_key_dict.user_role}"
                },
            )

        # Generate project_id if not provided
        if data.project_id is None:
            data.project_id = str(uuid.uuid4())
        else:
            # Check if project_id already exists
            existing_project = await prisma_client.db.litellm_projecttable.find_unique(
                where={"project_id": data.project_id}
            )
            if existing_project is not None:
                raise ProxyException(
                    message=f"Project id = {data.project_id} already exists. Please use a different project id.",
                    type="bad_request",
                    code=400,
                    param="project_id",
                )

        # Create budget if not provided
        if data.budget_id is None:
            data.budget_id = await _create_budget_for_project(
                data=data,
                user_id=user_api_key_dict.user_id,
                litellm_proxy_admin_name=litellm_proxy_admin_name,
                prisma_client=prisma_client,
            )

        ## Handle Object Permission - MCP, Vector Stores etc.
        # Creates the permission row first so its id can be attached below;
        # also strips data.object_permission from the request model.
        object_permission_id = await _set_project_object_permission(
            data=data,
            prisma_client=prisma_client,
        )

        # Create project row (following organization_endpoints.py pattern)
        project_row = LiteLLM_ProjectTable(
            **data.json(exclude_none=True),
            object_permission_id=object_permission_id,
            created_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
            updated_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
        )

        # Non-premium metadata fields are folded into the row's metadata dict.
        for field in LiteLLM_ManagementEndpoint_MetadataFields:
            if getattr(data, field, None) is not None:
                _set_object_metadata_field(
                    object_data=project_row,
                    field_name=field,
                    value=getattr(data, field),
                )

        new_project_row = prisma_client.jsonify_object(
            project_row.json(exclude_none=True)
        )

        # Remove budget fields (following organization_endpoints.py pattern)
        # They were persisted on the budget row above; only budget_id stays.
        new_project_row = _remove_budget_fields_from_project_data(new_project_row)

        verbose_proxy_logger.info(
            f"new_project_row: {json.dumps(new_project_row, indent=2)}"
        )
        response = await prisma_client.db.litellm_projecttable.create(
            data={
                **new_project_row,  # type: ignore
            },
            include={"litellm_budget_table": True},
        )

        return response
    except Exception as e:
        verbose_proxy_logger.exception(
            "litellm.proxy.management_endpoints.project_endpoints.new_project(): Exception occured - {}".format(
                str(e)
            )
        )
        raise handle_exception_on_proxy(e)
|
||||
|
||||
|
||||
@router.post(
    "/project/update",
    tags=["project management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=LiteLLM_ProjectTable,
)
@management_endpoint_wrapper
async def update_project(  # noqa: PLR0915
    data: UpdateProjectRequest,
    http_request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Update a project

    Parameters:
    - project_id: *str* - The project id to update. Required.
    - project_alias: *Optional[str]* - Updated name for the project
    - description: *Optional[str]* - Updated description for the project
    - team_id: *Optional[str]* - Updated team_id for the project
    - metadata: *Optional[dict]* - Updated metadata for project
    - models: *Optional[list]* - Updated list of models for the project
    - blocked: *Optional[bool]* - Updated blocked status
    - max_budget: *Optional[float]* - Updated max budget
    - tpm_limit: *Optional[int]* - Updated tpm limit
    - rpm_limit: *Optional[int]* - Updated rpm limit
    - model_rpm_limit: *Optional[dict]* - Updated RPM limits per model
    - model_tpm_limit: *Optional[dict]* - Updated TPM limits per model
    - budget_duration: *Optional[str]* - Updated budget duration
    - tags: *Optional[list]* - Updated list of tags for the project
    - object_permission: Optional[LiteLLM_ObjectPermissionBase] - Updated object permission

    Example:
    ```bash
    curl --location 'http://0.0.0.0:4000/project/update' \\
    --header 'Authorization: Bearer sk-1234' \\
    --header 'Content-Type: application/json' \\
    --data '{
        "project_id": "project-123",
        "description": "Updated flight search system with enhanced capabilities",
        "max_budget": 200,
        "model_rpm_limit": {
            "gpt-4": 2000,
            "gpt-3.5-turbo": 10000
        },
        "metadata": {
            "use_case_id": "SNOW-12345",
            "status": "active"
        }
    }'
    ```
    """
    # Imported lazily from proxy_server to read the live, post-startup values.
    from litellm.proxy.proxy_server import (
        litellm_proxy_admin_name,
        premium_user,
        prisma_client,
    )

    try:
        # Feature gating: tags, and project management as a whole, are
        # enterprise-only. Both checks raise 403 for non-premium users.
        if getattr(data, "tags", None) is not None and not premium_user:
            raise HTTPException(
                status_code=403,
                detail={
                    "error": "Only premium users can add tags to projects. "
                    + CommonProxyErrors.not_premium_user.value
                },
            )

        if not premium_user:
            raise HTTPException(
                status_code=403,
                detail={
                    "error": "Project management is an enterprise feature. "
                    + CommonProxyErrors.not_premium_user.value
                },
            )

        # ADD METADATA FIELDS
        # Premium-only metadata fields are folded into the request's metadata
        # dict, then removed so they don't hit the DB as columns.
        for field in LiteLLM_ManagementEndpoint_MetadataFields_Premium:
            if getattr(data, field, None) is not None:
                _set_object_metadata_field(
                    object_data=data,
                    field_name=field,
                    value=getattr(data, field),
                )
                delattr(data, field)

        if prisma_client is None:
            raise HTTPException(
                status_code=500,
                detail={"error": CommonProxyErrors.db_not_connected_error.value},
            )

        if data.project_id is None:
            raise HTTPException(
                status_code=400,
                detail={"error": "project_id is required"},
            )

        # Fetch existing project
        existing_project = await prisma_client.db.litellm_projecttable.find_unique(
            where={"project_id": data.project_id}
        )

        if existing_project is None:
            raise ProxyException(
                message=f"Project not found, project_id={data.project_id}",
                type="not_found",
                code=404,
                param="project_id",
            )

        # Validate team exists and get team object for limit + permission checks
        # Prefer the incoming team_id (project may be moving teams).
        team_id_to_check = data.team_id or existing_project.team_id
        team_obj_for_checks = None
        if team_id_to_check is not None:
            team_obj_for_checks = await _validate_team_exists(
                team_id=team_id_to_check, prisma_client=prisma_client
            )

        # Check if user has permission to update this project
        # NOTE: permission is checked against the project's *current* team,
        # not the (possibly new) team_id from the request.
        has_permission = await _check_user_permission_for_project(
            user_api_key_dict=user_api_key_dict,
            team_id=existing_project.team_id,
            prisma_client=prisma_client,
            team_object=LiteLLM_TeamTable(**team_obj_for_checks.model_dump())
            if team_obj_for_checks
            else None,
        )

        if not has_permission:
            raise HTTPException(
                status_code=403,
                detail={"error": "Only admins or team admins can update projects"},
            )

        # Validate project limits against team limits
        if team_obj_for_checks is not None:
            _check_team_project_limits(
                team_object=LiteLLM_TeamTable(**team_obj_for_checks.model_dump()),
                data=data,
            )

        # Prepare update data
        update_data = data.json(exclude_none=True, exclude={"project_id"})
        update_data = prisma_client.jsonify_object(update_data)
        update_data["updated_by"] = (
            user_api_key_dict.user_id or litellm_proxy_admin_name
        )

        # Handle budget updates
        # Budget values live on the related LiteLLM_BudgetTable row, not on
        # the project row itself.
        budget_fields = LiteLLM_BudgetTable.model_fields.keys()
        budget_updates = {k: v for k, v in update_data.items() if k in budget_fields}

        if budget_updates and existing_project.budget_id:
            # Update existing budget
            await prisma_client.db.litellm_budgettable.update(
                where={"budget_id": existing_project.budget_id},
                data={
                    **budget_updates,
                    "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
                },
            )
            # Remove budget fields from project update
            for field in budget_updates.keys():
                update_data.pop(field, None)
        # NOTE(review): if the project has no budget_id, budget fields are
        # still stripped below without creating a budget row — the values are
        # silently dropped. Confirm this matches the org-endpoints behavior.

        # Handle object permissions
        if "object_permission" in update_data:
            object_permission_data = update_data.pop("object_permission")
            if object_permission_data:
                if existing_project.object_permission_id:
                    # Update existing permission
                    await prisma_client.db.litellm_objectpermissiontable.update(
                        where={
                            "object_permission_id": existing_project.object_permission_id
                        },
                        data=object_permission_data,
                    )
                else:
                    # Create new permission
                    created_permission = (
                        await prisma_client.db.litellm_objectpermissiontable.create(
                            data=object_permission_data,
                        )
                    )
                    update_data[
                        "object_permission_id"
                    ] = created_permission.object_permission_id

        # Handle metadata fields
        # Move management metadata fields under the metadata dict column.
        for field in LiteLLM_ManagementEndpoint_MetadataFields:
            if field in update_data:
                if update_data.get("metadata") is None:
                    update_data["metadata"] = {}
                update_data["metadata"][field] = update_data.pop(field)

        # Remove budget fields (following organization_endpoints.py pattern)
        update_data = _remove_budget_fields_from_project_data(update_data)

        # Update project
        updated_project = await prisma_client.db.litellm_projecttable.update(
            where={"project_id": data.project_id},
            data=update_data,
            include={"litellm_budget_table": True, "object_permission": True},
        )

        return updated_project
    except Exception as e:
        verbose_proxy_logger.exception(
            "litellm.proxy.management_endpoints.project_endpoints.update_project(): Exception occured - {}".format(
                str(e)
            )
        )
        raise handle_exception_on_proxy(e)
|
||||
|
||||
|
||||
@router.delete(
    "/project/delete",
    tags=["project management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=List[LiteLLM_ProjectTable],
)
@management_endpoint_wrapper
async def delete_project(
    data: DeleteProjectRequest,
    http_request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Delete projects

    Parameters:
    - project_ids: *List[str]* - List of project ids to delete

    Example:
    ```bash
    curl --location --request DELETE 'http://0.0.0.0:4000/project/delete' \\
    --header 'Authorization: Bearer sk-1234' \\
    --header 'Content-Type: application/json' \\
    --data '{
        "project_ids": ["project-123", "project-456"]
    }'
    ```
    """
    # Imported lazily from proxy_server to read the live, post-startup values.
    from litellm.proxy.proxy_server import premium_user, prisma_client

    try:
        # Project management is enterprise-only.
        if not premium_user:
            raise HTTPException(
                status_code=403,
                detail={
                    "error": "Project management is an enterprise feature. "
                    + CommonProxyErrors.not_premium_user.value
                },
            )

        if prisma_client is None:
            raise HTTPException(
                status_code=500,
                detail={"error": CommonProxyErrors.db_not_connected_error.value},
            )

        # Check if user is admin (only admins can delete projects)
        # require_admin=True restricts this to proxy admins; team admins
        # cannot delete.
        has_permission = await _check_user_permission_for_project(
            user_api_key_dict=user_api_key_dict,
            team_id=None,
            prisma_client=prisma_client,
            require_admin=True,
        )

        if not has_permission:
            raise HTTPException(
                status_code=403,
                detail={"error": "Only admins can delete projects"},
            )

        deleted_projects = []

        # NOTE(review): deletions are performed one-by-one, not in a DB
        # transaction — a failure mid-list leaves earlier projects deleted.
        for project_id in data.project_ids:
            # Check if project exists
            existing_project = await prisma_client.db.litellm_projecttable.find_unique(
                where={"project_id": project_id}
            )

            if existing_project is None:
                raise ProxyException(
                    message=f"Project not found, project_id={project_id}",
                    type="not_found",
                    code=404,
                    param="project_ids",
                )

            # Check if there are any keys associated with this project
            # Refuse to delete while keys still reference it.
            associated_keys = (
                await prisma_client.db.litellm_verificationtoken.find_many(
                    where={"project_id": project_id}
                )
            )

            if len(associated_keys) > 0:
                raise ProxyException(
                    message=f"Cannot delete project {project_id}. {len(associated_keys)} key(s) are associated with it. Please delete or reassign the keys first.",
                    type="bad_request",
                    code=400,
                    param="project_ids",
                )

            # Delete the project
            deleted_project = await prisma_client.db.litellm_projecttable.delete(
                where={"project_id": project_id}
            )

            deleted_projects.append(deleted_project)

        return deleted_projects
    except Exception as e:
        verbose_proxy_logger.exception(
            "litellm.proxy.management_endpoints.project_endpoints.delete_project(): Exception occured - {}".format(
                str(e)
            )
        )
        raise handle_exception_on_proxy(e)
|
||||
|
||||
|
||||
@router.get(
    "/project/info",
    tags=["project management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=LiteLLM_ProjectTable,
)
async def project_info(
    project_id: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get information about a specific project

    Parameters:
    - project_id: *str* - The project id to fetch info for

    Example:
    ```bash
    curl --location 'http://0.0.0.0:4000/project/info?project_id=project-123' \\
    --header 'Authorization: Bearer sk-1234'
    ```
    """
    from litellm.proxy.proxy_server import prisma_client

    try:
        if prisma_client is None:
            raise HTTPException(
                status_code=500,
                detail={"error": CommonProxyErrors.db_not_connected_error.value},
            )

        # Fetch project, eager-loading its budget and object-permission rows.
        project = await prisma_client.db.litellm_projecttable.find_unique(
            where={"project_id": project_id},
            include={"litellm_budget_table": True, "object_permission": True},
        )

        if project is None:
            raise ProxyException(
                message=f"Project not found, project_id={project_id}",
                type="not_found",
                code=404,
                param="project_id",
            )

        # Check if user has access to this project (admin or team member)
        is_admin = user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN
        is_team_member = False

        if project.team_id and user_api_key_dict.user_id:
            team = await prisma_client.db.litellm_teamtable.find_unique(
                where={"team_id": project.team_id}
            )
            if team:
                # FIX: guard against admins/members being None — `x in None`
                # raises TypeError, which would surface here as a 500 instead
                # of a clean 403. Mirrors the None-safety in
                # _check_user_permission_for_project.
                is_team_member = (
                    user_api_key_dict.user_id in (team.admins or [])
                    or user_api_key_dict.user_id in (team.members or [])
                )

        if not (is_admin or is_team_member):
            raise HTTPException(
                status_code=403,
                detail={"error": "You don't have access to this project"},
            )

        return project
    except Exception as e:
        verbose_proxy_logger.exception(
            "litellm.proxy.management_endpoints.project_endpoints.project_info(): Exception occured - {}".format(
                str(e)
            )
        )
        raise handle_exception_on_proxy(e)
|
||||
|
||||
|
||||
@router.get(
    "/project/list",
    tags=["project management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=List[LiteLLM_ProjectTable],
)
async def list_projects(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    List all projects that the user has access to

    Example:
    ```bash
    curl --location 'http://0.0.0.0:4000/project/list' \\
    --header 'Authorization: Bearer sk-1234'
    ```
    """
    from litellm.proxy.proxy_server import prisma_client

    try:
        if prisma_client is None:
            raise HTTPException(
                status_code=500,
                detail={"error": CommonProxyErrors.db_not_connected_error.value},
            )

        relations = {"litellm_budget_table": True, "object_permission": True}

        # Proxy admins see every project.
        if user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN:
            return await prisma_client.db.litellm_projecttable.find_many(
                include=relations
            )

        # Everyone else only sees projects belonging to teams they are a
        # member or admin of.
        caller_teams = await prisma_client.db.litellm_teamtable.find_many(
            where={
                "OR": [
                    {"members": {"has": user_api_key_dict.user_id}},
                    {"admins": {"has": user_api_key_dict.user_id}},
                ]
            }
        )
        caller_team_ids = [team.team_id for team in caller_teams]

        return await prisma_client.db.litellm_projecttable.find_many(
            where={"team_id": {"in": caller_team_ids}},
            include=relations,
        )
    except Exception as e:
        verbose_proxy_logger.exception(
            "litellm.proxy.management_endpoints.project_endpoints.list_projects(): Exception occured - {}".format(
                str(e)
            )
        )
        raise handle_exception_on_proxy(e)
|
||||
@@ -0,0 +1,181 @@
|
||||
"""
|
||||
ROUTER SETTINGS MANAGEMENT
|
||||
|
||||
Endpoints for accessing router configuration and metadata
|
||||
|
||||
GET /router/settings - Get router configuration including available routing strategies
|
||||
GET /router/fields - Get router settings field definitions without values (for UI rendering)
|
||||
"""
|
||||
|
||||
import inspect
|
||||
from typing import Any, Dict, List, get_args
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.router import Router
|
||||
from litellm.types.management_endpoints import (
|
||||
ROUTER_SETTINGS_FIELDS,
|
||||
ROUTING_STRATEGY_DESCRIPTIONS,
|
||||
RouterSettingsField,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class RouterSettingsResponse(BaseModel):
    """Response schema for GET /router/settings.

    Bundles the router-settings field metadata together with the currently
    active values and human-readable routing-strategy descriptions.
    """

    fields: List[RouterSettingsField] = Field(
        description="List of all configurable router settings with metadata"
    )
    current_values: Dict[str, Any] = Field(
        description="Current values of router settings"
    )
    routing_strategy_descriptions: Dict[str, str] = Field(
        description="Descriptions for each routing strategy option"
    )
|
||||
|
||||
|
||||
class RouterFieldsResponse(BaseModel):
    """Response schema for GET /router/fields.

    Same field metadata as RouterSettingsResponse, but without current values —
    intended for UI rendering where values come from a separate endpoint.
    """

    fields: List[RouterSettingsField] = Field(
        description="List of all configurable router settings with metadata (without field values)"
    )
    routing_strategy_descriptions: Dict[str, str] = Field(
        description="Descriptions for each routing strategy option"
    )
|
||||
|
||||
|
||||
def _get_routing_strategies_from_router_class() -> List[str]:
    """
    Dynamically extract routing strategies from the Router class __init__ method.

    Reads the `routing_strategy` parameter's Literal annotation and returns its
    allowed values.

    Raises:
        ValueError: if the parameter or its Literal values cannot be found.
    """
    init_params = inspect.signature(Router.__init__).parameters
    strategy_param = init_params.get("routing_strategy")

    if strategy_param is not None and strategy_param.annotation:
        # get_args unwraps Literal[...] into its member values.
        strategies = get_args(strategy_param.annotation)
        if strategies:
            return list(strategies)

    raise ValueError("Unable to extract routing strategies from Router class")
|
||||
|
||||
|
||||
@router.get(
    "/router/settings",
    tags=["Router Settings"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=RouterSettingsResponse,
)
async def get_router_settings(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get router configuration and available settings.

    Returns:
    - fields: List of all configurable router settings with their metadata (type, description, default, options)
        The routing_strategy field includes available options extracted from the Router class
    - current_values: Current values of router settings from config
    """
    from litellm.proxy.proxy_server import llm_router, proxy_config

    try:
        # Get available routing strategies dynamically from Router class
        available_routing_strategies = _get_routing_strategies_from_router_class()

        # Deep-copy the shared templates so per-request mutation below does not
        # leak into the module-level ROUTER_SETTINGS_FIELDS constants.
        router_fields = [
            field.model_copy(deep=True) for field in ROUTER_SETTINGS_FIELDS
        ]

        # Populate routing_strategy field with available options
        for field in router_fields:
            if field.field_name == "routing_strategy":
                field.options = available_routing_strategies
                break

        # Try to get router settings from config
        config = await proxy_config.get_config()
        router_settings_from_config = config.get("router_settings", {})

        # Read current values off the live router, if one is initialized.
        current_values = {}
        if llm_router is not None:
            for field in router_fields:
                if hasattr(llm_router, field.field_name):
                    current_values[field.field_name] = getattr(
                        llm_router, field.field_name
                    )

        # Merge with config values (config takes precedence over live router state)
        current_values.update(router_settings_from_config)

        # Copy the merged values onto the field metadata for the UI.
        for field in router_fields:
            if field.field_name in current_values:
                field.field_value = current_values[field.field_name]

        return RouterSettingsResponse(
            fields=router_fields,
            current_values=current_values,
            routing_strategy_descriptions=ROUTING_STRATEGY_DESCRIPTIONS,
        )
    except Exception as e:
        # .exception (not .error) so the stack trace is captured, matching the
        # logging style of the other management endpoints.
        verbose_proxy_logger.exception(f"Error fetching router settings: {str(e)}")
        raise
|
||||
|
||||
|
||||
@router.get(
    "/router/fields",
    tags=["Router Settings"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=RouterFieldsResponse,
)
async def get_router_fields(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get router settings field definitions without values.

    Returns only the field metadata (type, description, default, options) without
    populating field_value. This is useful for UI components that need to know
    what fields to render, but will get the actual values from a different endpoint.

    Returns:
    - fields: List of all configurable router settings with their metadata (type, description, default, options)
        The routing_strategy field includes available options extracted from the Router class
        Note: field_value will be None for all fields
    - routing_strategy_descriptions: Descriptions for each routing strategy option
    """
    try:
        strategy_options = _get_routing_strategies_from_router_class()

        # Deep copies so mutations below never touch the shared templates.
        field_definitions = [
            template.model_copy(deep=True) for template in ROUTER_SETTINGS_FIELDS
        ]

        # Single pass: clear values (metadata-only endpoint) and attach the
        # routing-strategy options to the first matching field.
        options_assigned = False
        for definition in field_definitions:
            definition.field_value = None
            if not options_assigned and definition.field_name == "routing_strategy":
                definition.options = strategy_options
                options_assigned = True

        return RouterFieldsResponse(
            fields=field_definitions,
            routing_strategy_descriptions=ROUTING_STRATEGY_DESCRIPTIONS,
        )
    except Exception as e:
        verbose_proxy_logger.error(f"Error fetching router fields: {str(e)}")
        raise
|
||||
@@ -0,0 +1,118 @@
|
||||
# SCIM v2 Integration for LiteLLM Proxy
|
||||
|
||||
This module provides SCIM v2 (System for Cross-domain Identity Management) endpoints for LiteLLM Proxy, allowing identity providers to manage users and teams (groups) within the LiteLLM ecosystem.
|
||||
|
||||
## Overview
|
||||
|
||||
SCIM is an open standard designed to simplify user management across different systems. This implementation allows compatible identity providers (like Okta, Azure AD, OneLogin, etc.) to automatically provision and deprovision users and groups in LiteLLM Proxy.
|
||||
|
||||
## Endpoints
|
||||
|
||||
The SCIM v2 API follows the standard specification with the following base URL:
|
||||
|
||||
```
|
||||
/scim/v2
|
||||
```
|
||||
|
||||
### User Management
|
||||
|
||||
| Endpoint | Method | Description |
|
||||
|----------|--------|-------------|
|
||||
| `/Users` | GET | List all users with pagination support |
|
||||
| `/Users/{user_id}` | GET | Get a specific user by ID |
|
||||
| `/Users` | POST | Create a new user |
|
||||
| `/Users/{user_id}` | PUT | Update an existing user |
|
||||
| `/Users/{user_id}` | DELETE | Delete a user |
|
||||
|
||||
### Group Management
|
||||
|
||||
| Endpoint | Method | Description |
|
||||
|----------|--------|-------------|
|
||||
| `/Groups` | GET | List all groups with pagination support |
|
||||
| `/Groups/{group_id}` | GET | Get a specific group by ID |
|
||||
| `/Groups` | POST | Create a new group |
|
||||
| `/Groups/{group_id}` | PUT | Update an existing group |
|
||||
| `/Groups/{group_id}` | DELETE | Delete a group |
|
||||
|
||||
## SCIM Schema
|
||||
|
||||
This implementation follows the standard SCIM v2 schema with the following mappings:
|
||||
|
||||
### Users
|
||||
|
||||
- SCIM User ID → LiteLLM `user_id`
|
||||
- SCIM User Email → LiteLLM `user_email`
|
||||
- SCIM User Group Memberships → LiteLLM User-Team relationships
|
||||
|
||||
### Groups
|
||||
|
||||
- SCIM Group ID → LiteLLM `team_id`
|
||||
- SCIM Group Display Name → LiteLLM `team_alias`
|
||||
- SCIM Group Members → LiteLLM Team members list
|
||||
|
||||
## Configuration
|
||||
|
||||
To enable SCIM in your identity provider, use the full URL to the SCIM endpoint:
|
||||
|
||||
```
|
||||
https://your-litellm-proxy-url/scim/v2
|
||||
```
|
||||
|
||||
Most identity providers will require authentication. You should use a valid LiteLLM API key with administrative privileges.
|
||||
|
||||
## Features
|
||||
|
||||
- Full CRUD operations for users and groups
|
||||
- Pagination support
|
||||
- Basic filtering support
|
||||
- Automatic synchronization of user-team relationships
|
||||
- Proper status codes and error handling per SCIM specification
|
||||
|
||||
|
||||
## Example Usage
|
||||
|
||||
### Listing Users
|
||||
|
||||
```
|
||||
GET /scim/v2/Users?startIndex=1&count=10
|
||||
```
|
||||
|
||||
### Creating a User
|
||||
|
||||
```json
|
||||
POST /scim/v2/Users
|
||||
{
|
||||
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
|
||||
"userName": "john.doe@example.com",
|
||||
"active": true,
|
||||
"emails": [
|
||||
{
|
||||
"value": "john.doe@example.com",
|
||||
"primary": true
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### Adding a User to Groups
|
||||
|
||||
```json
|
||||
PUT /scim/v2/Users/{user_id}
|
||||
{
|
||||
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
|
||||
"userName": "john.doe@example.com",
|
||||
"active": true,
|
||||
"emails": [
|
||||
{
|
||||
"value": "john.doe@example.com",
|
||||
"primary": true
|
||||
}
|
||||
],
|
||||
"groups": [
|
||||
{
|
||||
"value": "team-123",
|
||||
"display": "Engineering Team"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1,177 @@
|
||||
from typing import List, Union
|
||||
|
||||
from litellm.proxy._types import (
|
||||
LiteLLM_TeamTable,
|
||||
LiteLLM_UserTable,
|
||||
Member,
|
||||
NewUserResponse,
|
||||
)
|
||||
from litellm.types.proxy.management_endpoints.scim_v2 import *
|
||||
|
||||
|
||||
class ScimTransformations:
    """Convert LiteLLM DB records (users / teams) into SCIM v2 resources.

    SCIM requires several display/name fields to be non-empty, so each helper
    falls back to a default placeholder when no real value is available.
    """

    DEFAULT_SCIM_NAME = "Unknown User"
    DEFAULT_SCIM_FAMILY_NAME = "Unknown Family Name"
    DEFAULT_SCIM_DISPLAY_NAME = "Unknown Display Name"
    DEFAULT_SCIM_MEMBER_VALUE = "Unknown Member Value"

    @staticmethod
    async def transform_litellm_user_to_scim_user(
        user: Union[LiteLLM_UserTable, NewUserResponse],
    ) -> SCIMUser:
        """Build a SCIMUser from a LiteLLM user record, resolving team memberships.

        Raises:
            HTTPException(500): if no database client is connected.
        """
        from litellm.proxy.proxy_server import prisma_client

        if prisma_client is None:
            raise HTTPException(
                status_code=500, detail={"error": "No database connected"}
            )

        # Get user's teams/groups
        groups = []
        for team_id in user.teams or []:
            team = await prisma_client.db.litellm_teamtable.find_unique(
                where={"team_id": team_id}
            )
            if team:
                team_alias = getattr(team, "team_alias", team.team_id)
                groups.append(SCIMUserGroup(value=team.team_id, display=team_alias))

        user_created_at = user.created_at.isoformat() if user.created_at else None
        user_updated_at = user.updated_at.isoformat() if user.updated_at else None

        emails = []
        # Only add email if it's a valid email address (contains @)
        # user_email can be a UUID when users are created without an email
        if user.user_email and "@" in user.user_email:
            emails.append(SCIMUserEmail(value=user.user_email, primary=True))

        return SCIMUser(
            schemas=["urn:ietf:params:scim:schemas:core:2.0:User"],
            id=user.user_id,
            userName=ScimTransformations._get_scim_user_name(user),
            displayName=ScimTransformations._get_scim_user_name(user),
            name=SCIMUserName(
                familyName=ScimTransformations._get_scim_family_name(user),
                givenName=ScimTransformations._get_scim_given_name(user),
            ),
            emails=emails,
            groups=groups,
            active=True,
            meta={
                "resourceType": "User",
                "created": user_created_at,
                "lastModified": user_updated_at,
            },
        )

    @staticmethod
    def _get_scim_user_name(user: Union[LiteLLM_UserTable, NewUserResponse]) -> str:
        """
        SCIM requires a display name with length > 0

        We use the same userName and displayName for SCIM users
        """
        if user.user_email and len(user.user_email) > 0:
            return user.user_email
        return ScimTransformations.DEFAULT_SCIM_DISPLAY_NAME

    @staticmethod
    def _get_scim_metadata_name(
        user: Union[LiteLLM_UserTable, NewUserResponse], attribute: str
    ) -> Union[str, None]:
        """Return a non-empty `attribute` ("familyName"/"givenName") from the
        user's scim_metadata, or None when absent/empty.

        Shared by _get_scim_family_name and _get_scim_given_name, which
        previously duplicated this parsing.
        """
        metadata = user.metadata or {}
        if "scim_metadata" in metadata:
            scim_metadata = LiteLLM_UserScimMetadata(**metadata["scim_metadata"])
            value = getattr(scim_metadata, attribute, None)
            if value and len(value) > 0:
                return value
        return None

    @staticmethod
    def _get_scim_family_name(user: Union[LiteLLM_UserTable, NewUserResponse]) -> str:
        """
        SCIM requires a family name with length > 0
        """
        from_metadata = ScimTransformations._get_scim_metadata_name(user, "familyName")
        if from_metadata is not None:
            return from_metadata
        if user.user_alias and len(user.user_alias) > 0:
            return user.user_alias
        return ScimTransformations.DEFAULT_SCIM_FAMILY_NAME

    @staticmethod
    def _get_scim_given_name(user: Union[LiteLLM_UserTable, NewUserResponse]) -> str:
        """
        SCIM requires a given name with length > 0
        """
        from_metadata = ScimTransformations._get_scim_metadata_name(user, "givenName")
        if from_metadata is not None:
            return from_metadata
        if user.user_alias and len(user.user_alias) > 0:
            # Note: `user.user_alias or DEFAULT` was dead code here —
            # user_alias is already known truthy at this point.
            return user.user_alias
        return ScimTransformations.DEFAULT_SCIM_NAME

    @staticmethod
    async def transform_litellm_team_to_scim_group(
        team: Union[LiteLLM_TeamTable, dict],
    ) -> SCIMGroup:
        """Build a SCIMGroup from a LiteLLM team record (or its raw dict form).

        Raises:
            HTTPException(500): if no database client is connected.
        """
        from litellm.proxy.proxy_server import prisma_client

        if prisma_client is None:
            raise HTTPException(
                status_code=500, detail={"error": "No database connected"}
            )

        if isinstance(team, dict):
            team = LiteLLM_TeamTable(**team)

        # Get team members with proper display names
        scim_members: List[SCIMMember] = []
        for member in team.members_with_roles or []:
            if isinstance(member, dict):
                member = Member(**member)

            scim_members.append(
                SCIMMember(
                    value=ScimTransformations._get_scim_member_value(member),
                    display=ScimTransformations._get_scim_member_display(member),
                )
            )

        team_alias = getattr(team, "team_alias", team.team_id)
        team_created_at = team.created_at.isoformat() if team.created_at else None
        team_updated_at = team.updated_at.isoformat() if team.updated_at else None

        return SCIMGroup(
            schemas=["urn:ietf:params:scim:schemas:core:2.0:Group"],
            id=team.team_id,
            displayName=team_alias,
            members=scim_members,
            meta={
                "resourceType": "Group",
                "created": team_created_at,
                "lastModified": team_updated_at,
            },
        )

    @staticmethod
    def _get_scim_member_value(member: Member) -> str:
        """
        Get the SCIM member value. Use user_email if available, otherwise use user_id.
        SCIM member value should be the unique identifier for the user.
        """
        if hasattr(member, "user_email") and member.user_email:
            return member.user_email
        elif hasattr(member, "user_id"):
            return member.user_id or ScimTransformations.DEFAULT_SCIM_MEMBER_VALUE
        return ScimTransformations.DEFAULT_SCIM_MEMBER_VALUE

    @staticmethod
    def _get_scim_member_display(member: Member) -> str:
        """
        Get the SCIM member display name.

        Currently identical to the member value (email preferred, then user_id);
        delegates to _get_scim_member_value to keep the two in sync instead of
        duplicating the logic.
        """
        return ScimTransformations._get_scim_member_value(member)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,11 @@
|
||||
"""
|
||||
SSO (Single Sign-On) related modules for LiteLLM Proxy.
|
||||
|
||||
This package contains custom SSO implementations and utilities.
|
||||
"""
|
||||
|
||||
from litellm.proxy.management_endpoints.sso.custom_microsoft_sso import (
|
||||
CustomMicrosoftSSO,
|
||||
)
|
||||
|
||||
__all__ = ["CustomMicrosoftSSO"]
|
||||
@@ -0,0 +1,94 @@
|
||||
"""
|
||||
Custom Microsoft SSO class that allows overriding default Microsoft endpoints.
|
||||
|
||||
This module provides a subclass of fastapi_sso's MicrosoftSSO that allows
|
||||
custom authorization, token, and userinfo endpoints to be specified via environment
|
||||
variables.
|
||||
|
||||
Environment Variables:
|
||||
- MICROSOFT_AUTHORIZATION_ENDPOINT: Custom authorization endpoint URL
|
||||
- MICROSOFT_TOKEN_ENDPOINT: Custom token endpoint URL
|
||||
- MICROSOFT_USERINFO_ENDPOINT: Custom userinfo endpoint URL
|
||||
|
||||
If these are not set, the default Microsoft endpoints are used.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import pydantic
|
||||
from fastapi_sso.sso.base import DiscoveryDocument
|
||||
from fastapi_sso.sso.microsoft import MicrosoftSSO
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
|
||||
|
||||
class CustomMicrosoftSSO(MicrosoftSSO):
    """
    Microsoft SSO subclass that allows overriding default endpoints via environment variables.

    Supports:
    - MICROSOFT_AUTHORIZATION_ENDPOINT
    - MICROSOFT_TOKEN_ENDPOINT
    - MICROSOFT_USERINFO_ENDPOINT
    """

    def __init__(
        self,
        client_id: str,
        client_secret: str,
        redirect_uri: Optional[Union[pydantic.AnyHttpUrl, str]] = None,
        allow_insecure_http: bool = False,
        scope: Optional[List[str]] = None,
        tenant: Optional[str] = None,
    ):
        # No extra state of our own — simply delegate to MicrosoftSSO.
        super().__init__(
            client_id=client_id,
            client_secret=client_secret,
            redirect_uri=redirect_uri,
            allow_insecure_http=allow_insecure_http,
            scope=scope,
            tenant=tenant,
        )

    async def get_discovery_document(self) -> DiscoveryDocument:
        """
        Build the discovery document, preferring env-var overrides.

        Falls back to the standard Microsoft endpoints when the corresponding
        environment variable is unset.
        """
        # Env-var overrides (None when unset) keyed by discovery-document field.
        overrides = {
            "authorization_endpoint": os.getenv(
                "MICROSOFT_AUTHORIZATION_ENDPOINT", None
            ),
            "token_endpoint": os.getenv("MICROSOFT_TOKEN_ENDPOINT", None),
            "userinfo_endpoint": os.getenv("MICROSOFT_USERINFO_ENDPOINT", None),
        }
        defaults = {
            "authorization_endpoint": f"https://login.microsoftonline.com/{self.tenant}/oauth2/v2.0/authorize",
            "token_endpoint": f"https://login.microsoftonline.com/{self.tenant}/oauth2/v2.0/token",
            "userinfo_endpoint": f"https://graph.microsoft.com/{self.version}/me",
        }
        resolved = {key: overrides[key] or defaults[key] for key in defaults}

        if any(overrides.values()):
            verbose_proxy_logger.debug(
                f"Using custom Microsoft SSO endpoints - "
                f"authorization: {resolved['authorization_endpoint']}, "
                f"token: {resolved['token_endpoint']}, "
                f"userinfo: {resolved['userinfo_endpoint']}"
            )

        return DiscoveryDocument(
            authorization_endpoint=resolved["authorization_endpoint"],
            token_endpoint=resolved["token_endpoint"],
            userinfo_endpoint=resolved["userinfo_endpoint"],
        )
|
||||
@@ -0,0 +1,27 @@
|
||||
from typing import Dict, Union
|
||||
|
||||
from litellm.proxy._types import LitellmUserRoles
|
||||
|
||||
|
||||
def check_is_admin_only_access(ui_access_mode: Union[str, Dict]) -> bool:
    """Return True only when the UI access mode is the string "admin_only".

    Any non-string mode (e.g. a dict-based policy) is treated as not
    admin-only.
    """
    return isinstance(ui_access_mode, str) and ui_access_mode == "admin_only"
|
||||
|
||||
|
||||
def has_admin_ui_access(user_role: str) -> bool:
    """
    Check if the user has admin access to the UI.

    Returns:
        bool: True if user is 'proxy_admin' or 'proxy_admin_view_only', False otherwise.
    """
    admin_roles = (
        LitellmUserRoles.PROXY_ADMIN.value,
        LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value,
    )
    return user_role in admin_roles
|
||||
@@ -0,0 +1,574 @@
|
||||
"""
|
||||
TAG MANAGEMENT
|
||||
|
||||
All /tag management endpoints
|
||||
|
||||
/tag/new
|
||||
/tag/info
|
||||
/tag/update
|
||||
/tag/delete
|
||||
/tag/list
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.management_endpoints.common_daily_activity import (
|
||||
SpendAnalyticsPaginatedResponse,
|
||||
get_daily_activity,
|
||||
)
|
||||
from litellm.proxy.management_helpers.utils import handle_budget_for_entity
|
||||
from litellm.types.tag_management import (
|
||||
TagConfig,
|
||||
TagDeleteRequest,
|
||||
TagInfoRequest,
|
||||
TagNewRequest,
|
||||
TagUpdateRequest,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm import Router
|
||||
from litellm.types.router import Deployment
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
async def _get_model_names(prisma_client, model_ids: list) -> Dict[str, str]:
    """Map each model_id in `model_ids` to its model_name via the proxy model table.

    Best-effort: returns an empty dict (after logging) on any DB error.
    """
    try:
        records = await prisma_client.db.litellm_proxymodeltable.find_many(
            where={"model_id": {"in": model_ids}}
        )
        return {record.model_id: record.model_name for record in records}
    except Exception as e:
        verbose_proxy_logger.error(f"Error getting model names: {str(e)}")
        return {}
|
||||
|
||||
|
||||
async def get_deployments_by_model(
    model: str, llm_router: "Router"
) -> List["Deployment"]:
    """
    Get all deployments by model

    `model` is resolved first as a deployment id, then as a model (group) name.
    Returns an empty list when neither matches.
    """
    from litellm.types.router import Deployment, LiteLLM_Params, ModelInfo

    # Deployment-id lookup takes precedence over model-name lookup.
    by_id = llm_router.get_deployment(model_id=model)
    if by_id is not None:
        return [by_id]

    matches = llm_router.get_model_list(model_name=model)
    if matches is None:
        return []

    results: List["Deployment"] = []
    for entry in matches:
        results.append(
            Deployment(
                model_name=entry["model_name"],
                litellm_params=LiteLLM_Params(**entry["litellm_params"]),  # type: ignore
                model_info=ModelInfo(**entry.get("model_info") or {}),
            )
        )
    return results
|
||||
|
||||
|
||||
@router.post(
    "/tag/new",
    tags=["tag management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def new_tag(
    tag: TagNewRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Create a new tag.

    Parameters:
    - name: str - The name of the tag
    - description: Optional[str] - Description of what this tag represents
    - models: List[str] - List of either 'model_id' or 'model_name' allowed for this tag
    - budget_id: Optional[str] - The id for a budget (tpm/rpm/max budget) for the tag

    ### IF NO BUDGET ID - CREATE ONE WITH THESE PARAMS ###
    - max_budget: Optional[float] - Max budget for tag
    - tpm_limit: Optional[int] - Max tpm limit for tag
    - rpm_limit: Optional[int] - Max rpm limit for tag
    - max_parallel_requests: Optional[int] - Max parallel requests for tag
    - soft_budget: Optional[float] - Get a slack alert when this soft budget is reached
    - model_max_budget: Optional[dict] - Max budget for a specific model
    - budget_duration: Optional[str] - Frequency of resetting tag budget

    Raises:
    - 400 if a tag with the same name already exists
    - 500 if the DB or router is unavailable, or on unexpected errors
    """
    from litellm.proxy._types import CommonProxyErrors
    from litellm.proxy.proxy_server import (
        litellm_proxy_admin_name,
        llm_router,
        prisma_client,
    )

    if prisma_client is None:
        raise HTTPException(
            status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
        )
    if llm_router is None:
        raise HTTPException(
            status_code=500, detail=CommonProxyErrors.no_llm_router.value
        )
    try:
        # Check if tag already exists
        existing_tag = await prisma_client.db.litellm_tagtable.find_unique(
            where={"tag_name": tag.name}
        )
        if existing_tag is not None:
            raise HTTPException(
                status_code=400, detail=f"Tag {tag.name} already exists"
            )

        # Handle budget creation/assignment using common helper
        budget_id = await handle_budget_for_entity(
            data=tag,
            existing_budget_id=None,
            user_api_key_dict=user_api_key_dict,
            prisma_client=prisma_client,
            litellm_proxy_admin_name=litellm_proxy_admin_name,
        )

        # Get model names for model_info
        model_info = await _get_model_names(prisma_client, tag.models or [])

        # Create new tag in database
        new_tag_record = await prisma_client.db.litellm_tagtable.create(
            data={
                "tag_name": tag.name,
                "description": tag.description,
                "models": tag.models or [],
                "model_info": json.dumps(model_info),
                "spend": 0.0,
                "budget_id": budget_id,
                "created_by": user_api_key_dict.user_id,
            }
        )

        # Propagate the tag onto every matching deployment, concurrently.
        if tag.models:
            tasks = []
            for model in tag.models:
                deployments = await get_deployments_by_model(model, llm_router)
                tasks.extend(
                    [
                        _add_tag_to_deployment(
                            deployment=deployment,
                            tag=tag.name,
                        )
                        for deployment in deployments
                    ]
                )
            await asyncio.gather(*tasks)

        # Build response
        tag_config = TagConfig(
            name=new_tag_record.tag_name,
            description=new_tag_record.description,
            models=new_tag_record.models,
            model_info=model_info,
            created_at=new_tag_record.created_at.isoformat(),
            updated_at=new_tag_record.updated_at.isoformat(),
            created_by=new_tag_record.created_by,
        )

        return {
            "message": f"Tag {tag.name} created successfully",
            "tag": tag_config,
        }
    except HTTPException:
        # BUG FIX: the generic handler below used to re-wrap deliberate
        # HTTPExceptions (e.g. the 400 "already exists") into opaque 500s.
        # Re-raise them unchanged so clients see the intended status code.
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error creating tag: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
async def _add_tag_to_deployment(deployment: "Deployment", tag: str):
    """Helper function to add tag to deployment

    Appends `tag` to the deployment's litellm_params["tags"] list in the DB,
    re-serializing litellm_params while leaving all other (possibly encrypted)
    fields untouched. Idempotent: a tag already present is not duplicated.

    Raises:
    - 404 if the deployment's model id is not found in the database
    - 500 if the database is unavailable or on unexpected errors
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")

    try:
        # Get current model from database to preserve encrypted fields
        db_model = await prisma_client.db.litellm_proxymodeltable.find_unique(
            where={"model_id": deployment.model_info.id}
        )

        if db_model is None:
            raise HTTPException(
                status_code=404,
                detail=f"Model {deployment.model_info.id} not found in database",
            )

        # Prisma returns litellm_params as dict (already parsed from JSON)
        existing_params = db_model.litellm_params
        if isinstance(existing_params, str):
            # If it's a string, parse it
            existing_params = json.loads(existing_params)
        elif not isinstance(existing_params, dict):
            raise Exception(f"Unexpected litellm_params type: {type(existing_params)}")

        # Add tag to tags array (preserve encryption of other fields)
        if "tags" not in existing_params:
            existing_params["tags"] = []
        if tag not in existing_params["tags"]:
            existing_params["tags"].append(tag)

        # Update database with modified params (keeps encrypted fields encrypted)
        await prisma_client.db.litellm_proxymodeltable.update(
            where={"model_id": deployment.model_info.id},
            data={"litellm_params": json.dumps(existing_params)},
        )
    except HTTPException:
        # BUG FIX: re-raise deliberate HTTPExceptions (e.g. the 404 above)
        # instead of letting the generic handler convert them into 500s.
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error adding tag to deployment: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post(
    "/tag/update",
    tags=["tag management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def update_tag(
    tag: TagUpdateRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Update an existing tag.

    Parameters:
    - name: str - The name of the tag to update
    - description: Optional[str] - Updated description
    - models: List[str] - Updated list of allowed LLM models
    - budget_id: Optional[str] - The id for a budget to associate with the tag

    ### BUDGET UPDATE PARAMS ###
    - max_budget: Optional[float] - Max budget for tag
    - tpm_limit: Optional[int] - Max tpm limit for tag
    - rpm_limit: Optional[int] - Max rpm limit for tag
    - max_parallel_requests: Optional[int] - Max parallel requests for tag
    - soft_budget: Optional[float] - Get a slack alert when this soft budget is reached
    - model_max_budget: Optional[dict] - Max budget for a specific model
    - budget_duration: Optional[str] - Frequency of resetting tag budget

    Raises:
    - HTTPException 404 if the tag does not exist
    - HTTPException 500 if the DB is unavailable or an unexpected error occurs
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")

    try:
        # Check if tag exists
        existing_tag = await prisma_client.db.litellm_tagtable.find_unique(
            where={"tag_name": tag.name}
        )
        if existing_tag is None:
            raise HTTPException(status_code=404, detail=f"Tag {tag.name} not found")

        from litellm.proxy.proxy_server import litellm_proxy_admin_name

        # Handle budget updates using common helper (creates/updates the
        # associated budget row and returns the budget id to link).
        budget_id = await handle_budget_for_entity(
            data=tag,
            existing_budget_id=existing_tag.budget_id,
            user_api_key_dict=user_api_key_dict,
            prisma_client=prisma_client,
            litellm_proxy_admin_name=litellm_proxy_admin_name,
        )

        # Get model names for model_info
        model_info = await _get_model_names(prisma_client, tag.models or [])

        # Prepare update data
        update_data = {
            "description": tag.description,
            "models": tag.models or [],
            "model_info": json.dumps(model_info),
        }

        # Only touch budget_id when it actually changed
        if budget_id != existing_tag.budget_id:
            update_data["budget_id"] = budget_id

        # Update tag in database
        updated_tag_record = await prisma_client.db.litellm_tagtable.update(
            where={"tag_name": tag.name},
            data=update_data,
        )

        # Build response
        tag_config = TagConfig(
            name=updated_tag_record.tag_name,
            description=updated_tag_record.description,
            models=updated_tag_record.models,
            model_info=model_info,
            created_at=updated_tag_record.created_at.isoformat(),
            updated_at=updated_tag_record.updated_at.isoformat(),
            created_by=updated_tag_record.created_by,
        )

        return {
            "message": f"Tag {tag.name} updated successfully",
            "tag": tag_config,
        }
    except HTTPException:
        # Preserve intentional status codes (e.g. the 404 above); previously
        # they were swallowed by the generic handler and re-raised as 500s.
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error updating tag: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post(
    "/tag/info",
    tags=["tag management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def info_tag(
    data: TagInfoRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get information about specific tags.

    Parameters:
    - names: List[str] - List of tag names to get information for

    Returns:
        dict mapping tag name -> tag details (budget info included under
        "litellm_budget_table" when a budget is attached).

    Raises:
    - HTTPException 404 if any requested tag does not exist
    - HTTPException 500 if the DB is unavailable or an unexpected error occurs
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")

    try:
        # Query tags from database with budget info
        tag_records = await prisma_client.db.litellm_tagtable.find_many(
            where={"tag_name": {"in": data.names}},
            include={"litellm_budget_table": True},
        )

        # Check if any requested tags don't exist
        found_tag_names = {tag.tag_name for tag in tag_records}
        missing_tags = [name for name in data.names if name not in found_tag_names]
        if missing_tags:
            raise HTTPException(
                status_code=404, detail=f"Tags not found: {missing_tags}"
            )

        # Build response
        requested_tags = {}
        for tag_record in tag_records:
            # model_info may be persisted as a JSON string or already decoded
            model_info = {}
            if tag_record.model_info:
                if isinstance(tag_record.model_info, str):
                    model_info = json.loads(tag_record.model_info)
                else:
                    model_info = tag_record.model_info

            tag_dict = {
                "name": tag_record.tag_name,
                "description": tag_record.description,
                "models": tag_record.models,
                "model_info": model_info,
                "created_at": tag_record.created_at.isoformat(),
                "updated_at": tag_record.updated_at.isoformat(),
                "created_by": tag_record.created_by,
            }

            # Add budget info if available
            if (
                hasattr(tag_record, "litellm_budget_table")
                and tag_record.litellm_budget_table
            ):
                tag_dict["litellm_budget_table"] = tag_record.litellm_budget_table

            requested_tags[tag_record.tag_name] = tag_dict

        return requested_tags
    except HTTPException:
        # Preserve the 404 raised above; previously it was converted into a
        # generic 500 by the catch-all handler below.
        raise
    except Exception as e:
        # Log before wrapping, consistent with the other /tag endpoints.
        verbose_proxy_logger.exception(f"Error getting tag info: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get(
    "/tag/list",
    tags=["tag management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def list_tags(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    List all available tags with their budget information.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")

    try:
        ## QUERY STORED TAGS ##
        stored_records = await prisma_client.db.litellm_tagtable.find_many(
            include={"litellm_budget_table": True}
        )

        stored_names = set()
        stored_tags = []
        for record in stored_records:
            stored_names.add(record.tag_name)

            # model_info may be persisted as a JSON string or already decoded
            parsed_model_info = {}
            if record.model_info:
                if isinstance(record.model_info, str):
                    parsed_model_info = json.loads(record.model_info)
                else:
                    parsed_model_info = record.model_info

            entry = {
                "name": record.tag_name,
                "description": record.description,
                "models": record.models,
                "model_info": parsed_model_info,
                "created_at": record.created_at.isoformat(),
                "updated_at": record.updated_at.isoformat(),
                "created_by": record.created_by,
            }

            # Attach budget info when the relation was loaded and is non-empty
            if getattr(record, "litellm_budget_table", None):
                entry["litellm_budget_table"] = record.litellm_budget_table

            stored_tags.append(entry)

        ## QUERY DYNAMIC TAGS ##
        # group_by is used instead of find_many(distinct=["tag"]): Prisma's
        # distinct fetches all columns for all rows and deduplicates in
        # application code, which is extremely slow on large tables.
        # See: https://www.prisma.io/docs/orm/prisma-client/queries/aggregation-grouping-summarizing#distinct-under-the-hood
        grouped_rows = await prisma_client.db.litellm_dailytagspend.group_by(
            by=["tag"],
            where={"tag": {"not": None}},
            # MIN/MAX give meaningful timestamps (earliest appearance and most
            # recent activity) instead of whichever row Prisma picked for the
            # old find_many(distinct=...) query.
            _min={"created_at": True},
            _max={"updated_at": True},
        )

        dynamic_tags = []
        for row in grouped_rows:
            if row["tag"] in stored_names:
                continue  # already represented by a stored tag config
            dynamic_tags.append(
                {
                    "name": row["tag"],
                    "description": "This is just a spend tag that was passed dynamically in a request. It does not control any LLM models.",
                    "models": None,
                    "created_at": row["_min"]["created_at"].isoformat(),
                    "updated_at": row["_max"]["updated_at"].isoformat(),
                }
            )

        return stored_tags + dynamic_tags
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post(
    "/tag/delete",
    tags=["tag management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def delete_tag(
    data: TagDeleteRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Delete a tag.

    Parameters:
    - name: str - The name of the tag to delete

    Raises:
    - HTTPException 404 if the tag does not exist
    - HTTPException 500 if the DB is unavailable or an unexpected error occurs
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")

    try:
        # Check if tag exists
        existing_tag = await prisma_client.db.litellm_tagtable.find_unique(
            where={"tag_name": data.name}
        )
        if existing_tag is None:
            raise HTTPException(status_code=404, detail=f"Tag {data.name} not found")

        # Delete tag from database
        await prisma_client.db.litellm_tagtable.delete(where={"tag_name": data.name})

        return {"message": f"Tag {data.name} deleted successfully"}
    except HTTPException:
        # Keep the 404 raised above; previously it was re-wrapped as a 500
        # by the catch-all handler below.
        raise
    except Exception as e:
        # Log before wrapping, consistent with the other /tag endpoints.
        verbose_proxy_logger.exception(f"Error deleting tag: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get(
    "/tag/daily/activity",
    response_model=SpendAnalyticsPaginatedResponse,
    tags=["tag management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def get_tag_daily_activity(
    tags: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    model: Optional[str] = None,
    api_key: Optional[str] = None,
    page: int = 1,
    page_size: int = 10,
):
    """
    Get daily activity for specific tags or all tags.

    Args:
        tags (Optional[str]): Comma-separated list of tags to filter by. If not provided, returns data for all tags.
        start_date (Optional[str]): Start date for the activity period (YYYY-MM-DD).
        end_date (Optional[str]): End date for the activity period (YYYY-MM-DD).
        model (Optional[str]): Filter by model name.
        api_key (Optional[str]): Filter by API key.
        page (int): Page number for pagination.
        page_size (int): Number of items per page.

    Returns:
        SpendAnalyticsPaginatedResponse: Paginated response containing daily activity data.

    Raises:
        HTTPException: 500 if the database is not connected.
    """
    from litellm.proxy.proxy_server import prisma_client

    # Fail fast with a clear error when the DB is unavailable, consistent
    # with the other /tag endpoints (previously a None client was passed
    # straight through to get_daily_activity).
    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")

    # Convert comma-separated tags string to list if provided
    tag_list = tags.split(",") if tags else None

    return await get_daily_activity(
        prisma_client=prisma_client,
        table_name="litellm_dailytagspend",
        entity_id_field="tag",
        entity_id=tag_list,
        entity_metadata_field=None,
        start_date=start_date,
        end_date=end_date,
        model=model,
        api_key=api_key,
        page=page,
        page_size=page_size,
        # metadata_metrics_func=None because litellm_dailytagspend rows are
        # pre-aggregated per (date, tag, model, …) and have no request_id.
        # Deduplication across tags is therefore not possible at this level —
        # a request tagged with N tags contributes its spend to N separate rows,
        # so passing compute_tag_metadata_totals would double-count spend when
        # multiple tags are present. The panel is primarily used to inspect
        # individual tags, making this trade-off acceptable.
        metadata_metrics_func=None,
    )
|
||||
@@ -0,0 +1,346 @@
|
||||
"""
|
||||
Endpoints to control callbacks per team
|
||||
|
||||
Use this when each team should control its own callbacks
|
||||
"""
|
||||
|
||||
import json
|
||||
import traceback
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import (
|
||||
AddTeamCallback,
|
||||
ProxyErrorTypes,
|
||||
ProxyException,
|
||||
TeamCallbackMetadata,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post(
    "/team/{team_id:path}/callback",
    tags=["team management"],
    dependencies=[Depends(user_api_key_auth)],
)
@management_endpoint_wrapper
async def add_team_callbacks(
    data: AddTeamCallback,
    http_request: Request,
    team_id: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
    litellm_changed_by: Optional[str] = Header(
        None,
        description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability",
    ),
):
    """
    Add a success/failure callback to a team

    Use this if you want different teams to have different success/failure callbacks

    Parameters:
    - callback_name (Literal["langfuse", "langsmith", "gcs"], required): The name of the callback to add
    - callback_type (Literal["success", "failure", "success_and_failure"], required): The type of callback to add. One of:
        - "success": Callback for successful LLM calls
        - "failure": Callback for failed LLM calls
        - "success_and_failure": Callback for both successful and failed LLM calls
    - callback_vars (StandardCallbackDynamicParams, required): A dictionary of variables to pass to the callback
        - langfuse_public_key: The public key for the Langfuse callback
        - langfuse_secret_key: The secret key for the Langfuse callback
        - langfuse_secret: The secret for the Langfuse callback
        - langfuse_host: The host for the Langfuse callback
        - gcs_bucket_name: The name of the GCS bucket
        - gcs_path_service_account: The path to the GCS service account
        - langsmith_api_key: The API key for the Langsmith callback
        - langsmith_project: The project for the Langsmith callback
        - langsmith_base_url: The base URL for the Langsmith callback

    Example curl:
    ```
    curl -X POST 'http://localhost:4000/team/dbe2f686-a686-4896-864a-4c3924458709/callback' \
        -H 'Content-Type: application/json' \
        -H 'Authorization: Bearer sk-1234' \
        -d '{
        "callback_name": "langfuse",
        "callback_type": "success",
        "callback_vars": {"langfuse_public_key": "pk-lf-xxxx1", "langfuse_secret_key": "sk-xxxxx"}

    }'
    ```

    This means for the team where team_id = dbe2f686-a686-4896-864a-4c3924458709, all LLM calls will be logged to langfuse using the public key pk-lf-xxxx1 and the secret key sk-xxxxx

    Raises:
    - HTTPException 500 if no database is connected
    - HTTPException 400 if the team does not exist
    - ProxyException 400 if the same callback_name/callback_type pair is already registered for the team
    - ProxyException 500 on any other unexpected error
    """
    try:
        from litellm.proxy._types import CommonProxyErrors
        from litellm.proxy.proxy_server import prisma_client

        if prisma_client is None:
            raise HTTPException(
                status_code=500,
                detail={"error": CommonProxyErrors.db_not_connected_error.value},
            )

        # Check if team_id exists already
        _existing_team = await prisma_client.get_data(
            team_id=team_id, table_name="team", query_type="find_unique"
        )
        if _existing_team is None:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": f"Team id = {team_id} does not exist. Please use a different team id."
                },
            )

        # store team callback settings in metadata (under the "logging" key)
        team_metadata = _existing_team.metadata
        team_callback_settings: List[dict] = team_metadata.get(
            "logging"
        )  # will be dict of type AddTeamCallback
        # Normalize: treat a missing or malformed "logging" entry as empty.
        if team_callback_settings is None or not isinstance(
            team_callback_settings, list
        ):
            team_callback_settings = []

        ## check if it already exists, for the same callback event
        # (same callback_name AND same callback_type ⇒ duplicate, reject)
        for callback in team_callback_settings:
            if (
                callback.get("callback_name") == data.callback_name
                and callback.get("callback_type") == data.callback_type
            ):
                raise ProxyException(
                    message=f"callback_name = {data.callback_name} already exists in team_callback_settings, for team_id = {team_id} and event = {data.callback_type}",
                    code=status.HTTP_400_BAD_REQUEST,
                    type=ProxyErrorTypes.bad_request_error,
                    param="callback_name",
                )

        team_callback_settings.append(data.model_dump())

        # Persist the mutated metadata back as JSON on the team row.
        team_metadata["logging"] = team_callback_settings
        team_metadata_json = json.dumps(team_metadata)  # update team_metadata

        new_team_row = await prisma_client.db.litellm_teamtable.update(
            where={"team_id": team_id}, data={"metadata": team_metadata_json}  # type: ignore
        )

        return {
            "status": "success",
            "data": new_team_row,
        }

    # Re-raise intentional errors untouched; only unexpected failures are
    # wrapped into a generic 500 ProxyException below.
    except HTTPException as e:
        raise e
    except ProxyException as e:
        raise e
    except Exception as e:
        verbose_proxy_logger.exception(
            "litellm.proxy.proxy_server.add_team_callbacks(): Exception occured - {}".format(
                str(e)
            )
        )
        raise ProxyException(
            message="Internal Server Error, " + str(e),
            type=ProxyErrorTypes.internal_server_error.value,
            param=getattr(e, "param", "None"),
            code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )
|
||||
|
||||
|
||||
@router.post(
    "/team/{team_id}/disable_logging",
    tags=["team management"],
    dependencies=[Depends(user_api_key_auth)],
)
@management_endpoint_wrapper
async def disable_team_logging(
    http_request: Request,
    team_id: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Disable all logging callbacks for a team.

    Clears the team's success/failure callback lists stored under
    metadata["callback_settings"]; other callback settings (e.g. callback
    variables) are left in place.

    Parameters:
    - team_id (str, required): The unique identifier for the team

    Example curl:
    ```
    curl -X POST 'http://localhost:4000/team/dbe2f686-a686-4896-864a-4c3924458709/disable_logging' \
        -H 'Authorization: Bearer sk-1234'
    ```

    Raises:
    - HTTPException 500 if no database is connected
    - HTTPException 404 if the team does not exist
    - ProxyException on any other unexpected error
    """
    try:
        from litellm.proxy.proxy_server import prisma_client

        if prisma_client is None:
            raise HTTPException(status_code=500, detail={"error": "No db connected"})

        # Check if team exists
        _existing_team = await prisma_client.get_data(
            team_id=team_id, table_name="team", query_type="find_unique"
        )
        if _existing_team is None:
            raise HTTPException(
                status_code=404,
                detail={"error": f"Team id = {team_id} does not exist."},
            )

        # Update team metadata to disable logging.
        # Round-trip the stored settings through TeamCallbackMetadata so the
        # persisted structure stays well-formed.
        team_metadata = _existing_team.metadata
        team_callback_settings = team_metadata.get("callback_settings", {})
        team_callback_settings_obj = TeamCallbackMetadata(**team_callback_settings)

        # Reset callbacks (only the callback lists; callback_vars are kept)
        team_callback_settings_obj.success_callback = []
        team_callback_settings_obj.failure_callback = []

        # Update metadata
        team_metadata["callback_settings"] = team_callback_settings_obj.model_dump()
        team_metadata_json = json.dumps(team_metadata)

        # Update team in database
        updated_team = await prisma_client.db.litellm_teamtable.update(
            where={"team_id": team_id}, data={"metadata": team_metadata_json}  # type: ignore
        )

        if updated_team is None:
            raise HTTPException(
                status_code=404,
                detail={
                    "error": f"Team id = {team_id} does not exist. Error updating team logging"
                },
            )

        return {
            "status": "success",
            "message": f"Logging disabled for team {team_id}",
            "data": {
                "team_id": updated_team.team_id,
                "success_callbacks": [],
                "failure_callbacks": [],
            },
        }

    except Exception as e:
        verbose_proxy_logger.error(
            f"litellm.proxy.proxy_server.disable_team_logging(): Exception occurred - {str(e)}"
        )
        verbose_proxy_logger.debug(traceback.format_exc())
        # HTTPExceptions (404/500 above) are converted to ProxyExceptions
        # while preserving their status code and detail.
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "detail", f"Internal Server Error({str(e)})"),
                type=ProxyErrorTypes.internal_server_error.value,
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR),
            )
        elif isinstance(e, ProxyException):
            raise e
        raise ProxyException(
            message="Internal Server Error, " + str(e),
            type=ProxyErrorTypes.internal_server_error.value,
            param=getattr(e, "param", "None"),
            code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )
|
||||
|
||||
|
||||
@router.get(
    "/team/{team_id:path}/callback",
    tags=["team management"],
    dependencies=[Depends(user_api_key_auth)],
)
@management_endpoint_wrapper
async def get_team_callbacks(
    http_request: Request,
    team_id: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get the success/failure callbacks and variables for a team.

    Parameters:
    - team_id (str, required): The unique identifier for the team

    Example curl:
    ```
    curl -X GET 'http://localhost:4000/team/dbe2f686-a686-4896-864a-4c3924458709/callback' \
        -H 'Authorization: Bearer sk-1234'
    ```

    Returns a payload of the form:
    {
        "status": "success",
        "data": {
            "team_id": ...,
            "success_callbacks": [...],
            "failure_callbacks": [...],
            "callback_vars": {...},
        },
    }
    """
    try:
        from litellm.proxy.proxy_server import prisma_client

        if prisma_client is None:
            raise HTTPException(status_code=500, detail={"error": "No db connected"})

        # Verify the team exists before reading its metadata.
        team_row = await prisma_client.get_data(
            team_id=team_id, table_name="team", query_type="find_unique"
        )
        if team_row is None:
            raise HTTPException(
                status_code=404,
                detail={"error": f"Team id = {team_id} does not exist."},
            )

        # Callback settings live under metadata["callback_settings"];
        # round-trip through TeamCallbackMetadata so the response always has
        # a consistent structure even when the key is missing.
        settings_obj = TeamCallbackMetadata(
            **team_row.metadata.get("callback_settings", {})
        )

        return {
            "status": "success",
            "data": {
                "team_id": team_id,
                "success_callbacks": settings_obj.success_callback,
                "failure_callbacks": settings_obj.failure_callback,
                "callback_vars": settings_obj.callback_vars,
            },
        }

    except Exception as e:
        verbose_proxy_logger.error(
            "litellm.proxy.proxy_server.get_team_callbacks(): Exception occurred - {}".format(
                str(e)
            )
        )
        verbose_proxy_logger.debug(traceback.format_exc())
        # Convert HTTPExceptions to ProxyExceptions, keeping status + detail.
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "detail", f"Internal Server Error({str(e)})"),
                type=ProxyErrorTypes.internal_server_error.value,
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR),
            )
        if isinstance(e, ProxyException):
            raise e
        raise ProxyException(
            message="Internal Server Error, " + str(e),
            type=ProxyErrorTypes.internal_server_error.value,
            param=getattr(e, "param", "None"),
            code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,605 @@
|
||||
"""
|
||||
TOOL POLICY MANAGEMENT
|
||||
|
||||
All /tool management endpoints
|
||||
|
||||
GET /v1/tool/list - List all discovered tools and their policies
|
||||
GET /v1/tool/policy/options - List available input/output policy options with descriptions
|
||||
GET /v1/tool/{tool_name} - Get a single tool's details
|
||||
POST /v1/tool/policy - Update the input_policy / output_policy for a tool
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import CommonProxyErrors, UserAPIKeyAuth
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.types.tool_management import (
|
||||
LiteLLM_ToolTableRow,
|
||||
ToolDetailResponse,
|
||||
ToolInputPolicy,
|
||||
ToolListResponse,
|
||||
ToolPolicyOption,
|
||||
ToolPolicyOptionsResponse,
|
||||
ToolPolicyUpdateRequest,
|
||||
ToolPolicyUpdateResponse,
|
||||
ToolUsageLogEntry,
|
||||
ToolUsageLogsResponse,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Static catalog of selectable tool policies, returned verbatim by
# GET /v1/tool/policy/options (no DB access). NOTE(review): the "value"
# strings presumably mirror the ToolInputPolicy / output-policy enum values
# in litellm.types.tool_management — confirm they stay in sync.
TOOL_POLICY_OPTIONS = ToolPolicyOptionsResponse(
    input_policies=[
        ToolPolicyOption(
            value="untrusted",
            label="Untrusted",
            description="Tool accepts any input, including data from untrusted tool outputs. Default for newly discovered tools.",
        ),
        ToolPolicyOption(
            value="trusted",
            label="Trusted",
            description="Tool requires trusted input. Blocked if the conversation contains output from any tool with output_policy=untrusted.",
        ),
        ToolPolicyOption(
            value="blocked",
            label="Blocked",
            description="Tool is completely prohibited. Any attempt to call it is rejected.",
        ),
    ],
    output_policies=[
        ToolPolicyOption(
            value="untrusted",
            label="Untrusted",
            description="Tool output may contain unsafe content (prompt injection, risky code). Downstream tools with input_policy=trusted will be blocked.",
        ),
        ToolPolicyOption(
            value="trusted",
            label="Trusted",
            description="Tool output is verified safe. Will not trigger trust-chain blocks on downstream tools.",
        ),
    ],
)
|
||||
|
||||
|
||||
@router.get(
    "/v1/tool/policy/options",
    tags=["tool management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=ToolPolicyOptionsResponse,
)
async def get_tool_policy_options(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Return the available input/output policy options with their descriptions.

    The catalog is static module-level data, so no database call is made.
    """
    return TOOL_POLICY_OPTIONS
|
||||
|
||||
|
||||
@router.get(
    "/v1/tool/list",
    tags=["tool management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=ToolListResponse,
)
async def list_tools(
    input_policy: Optional[ToolInputPolicy] = None,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    List every auto-discovered tool along with its policies.

    Parameters:
    - input_policy: Optional filter — one of "trusted", "untrusted", "blocked"
    """
    from litellm.proxy.db.tool_registry_writer import list_tools as db_list_tools
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
        )

    try:
        discovered = await db_list_tools(
            prisma_client=prisma_client, input_policy=input_policy
        )
        return ToolListResponse(tools=discovered, total=len(discovered))
    except Exception as e:
        verbose_proxy_logger.exception("Error listing tools: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get(
    "/v1/tool/{tool_name:path}/detail",
    tags=["tool management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=ToolDetailResponse,
)
async def get_tool_detail(
    tool_name: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get a single tool plus its policy overrides (for the UI detail view).
    """
    from litellm.proxy.db.tool_registry_writer import get_tool as db_get_tool
    from litellm.proxy.db.tool_registry_writer import list_overrides_for_tool
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
        )

    try:
        tool_row = await db_get_tool(prisma_client=prisma_client, tool_name=tool_name)
        if tool_row is None:
            raise HTTPException(status_code=404, detail=f"Tool '{tool_name}' not found")

        policy_overrides = await list_overrides_for_tool(
            prisma_client=prisma_client, tool_name=tool_name
        )
        return ToolDetailResponse(tool=tool_row, overrides=policy_overrides)
    except HTTPException:
        # Let the 404 above (and any other deliberate status) pass through.
        raise
    except Exception as e:
        verbose_proxy_logger.exception("Error getting tool detail: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
def _input_snippet_for_tool_log(sl: Any, max_len: int = 200) -> Optional[str]:
    """Short snippet from a spend-log row's messages or proxy_server_request.

    Prefers the structured ``messages`` field; falls back to the raw
    ``proxy_server_request`` payload (JSON-decoding it when it is a string).
    Returns None when nothing usable is present.
    """
    if sl is None:
        return None

    direct_messages = getattr(sl, "messages", None)
    if direct_messages is not None:
        snippet = _snippet_str(direct_messages, max_len)
        if snippet:
            return snippet

    raw_request = getattr(sl, "proxy_server_request", None)
    if not raw_request:
        return None

    if isinstance(raw_request, str):
        import json

        try:
            raw_request = json.loads(raw_request)
        except Exception:
            # Not valid JSON — snippet the raw string itself.
            return _snippet_str(raw_request, max_len)

    if isinstance(raw_request, dict):
        # Messages may live at the top level or nested under "body".
        msgs = raw_request.get("messages")
        if msgs is None and isinstance(raw_request.get("body"), dict):
            msgs = raw_request["body"].get("messages")
        snippet = _snippet_str(msgs, max_len)
        if snippet:
            return snippet

    return _snippet_str(raw_request, max_len)
|
||||
|
||||
|
||||
def _snippet_str(text: Any, max_len: int = 200) -> Optional[str]:
|
||||
if text is None:
|
||||
return None
|
||||
if isinstance(text, str):
|
||||
s = text
|
||||
elif isinstance(text, list):
|
||||
parts = []
|
||||
for item in text:
|
||||
if isinstance(item, dict) and "content" in item:
|
||||
c = item["content"]
|
||||
parts.append(c if isinstance(c, str) else str(c))
|
||||
else:
|
||||
parts.append(str(item))
|
||||
s = " ".join(parts)
|
||||
else:
|
||||
s = str(text)
|
||||
if not s or s == "{}":
|
||||
return None
|
||||
return (s[:max_len] + "...") if len(s) > max_len else s
|
||||
|
||||
|
||||
@router.get(
    "/v1/tool/{tool_name:path}/logs",
    tags=["tool management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=ToolUsageLogsResponse,
)
async def get_tool_usage_logs(
    tool_name: str,
    page: int = Query(1, ge=1),
    page_size: int = Query(50, ge=1, le=100),
    start_date: Optional[str] = Query(None, description="YYYY-MM-DD"),
    end_date: Optional[str] = Query(None, description="YYYY-MM-DD"),
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Return paginated spend logs for requests that used this tool (from SpendLogToolIndex).

    Parameters:
    - tool_name: tool to look up (path-style names allowed via `:path`)
    - page / page_size: pagination over the tool-index rows, newest first
    - start_date / end_date: optional inclusive date window (interpreted as UTC)

    Raises 500 when the DB is not connected or an unexpected error occurs.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
        )

    try:
        where: dict = {"tool_name": tool_name}
        if start_date or end_date:
            start_time_filter: Optional[datetime] = None
            end_time_filter: Optional[datetime] = None
            if start_date:
                try:
                    # Start of the given day in UTC.
                    start_time_filter = datetime.strptime(
                        start_date + "T00:00:00", "%Y-%m-%dT%H:%M:%S"
                    ).replace(tzinfo=timezone.utc)
                except ValueError:
                    # Malformed dates are silently ignored rather than rejected.
                    pass
            if end_date:
                try:
                    # 23:59:59 UTC so the end date is inclusive.
                    end_time_filter = datetime.strptime(
                        end_date + "T23:59:59", "%Y-%m-%dT%H:%M:%S"
                    ).replace(tzinfo=timezone.utc)
                except ValueError:
                    pass
            if start_time_filter is not None or end_time_filter is not None:
                where["start_time"] = {}
                if start_time_filter is not None:
                    where["start_time"]["gte"] = start_time_filter
                if end_time_filter is not None:
                    where["start_time"]["lte"] = end_time_filter

        # Page over the tool index first, then batch-fetch matching spend logs.
        total = await prisma_client.db.litellm_spendlogtoolindex.count(where=where)
        index_rows = await prisma_client.db.litellm_spendlogtoolindex.find_many(
            where=where,
            order={"start_time": "desc"},
            skip=(page - 1) * page_size,
            take=page_size,
        )
        request_ids = [r.request_id for r in index_rows]
        if not request_ids:
            return ToolUsageLogsResponse(
                logs=[], total=total, page=page, page_size=page_size
            )

        spend_logs = await prisma_client.db.litellm_spendlogs.find_many(
            where={"request_id": {"in": request_ids}}
        )
        # In-memory join by request_id.
        log_by_id = {s.request_id: s for s in spend_logs}

        logs_out: List[ToolUsageLogEntry] = []
        # Iterate index_rows (not spend_logs) to keep the newest-first order.
        for r in index_rows:
            sl = log_by_id.get(r.request_id)
            if not sl:
                # Index row without a spend log (e.g. log deleted) — skip it.
                continue
            ts = (
                sl.startTime.isoformat()
                if hasattr(sl.startTime, "isoformat")
                else str(sl.startTime)
            )
            logs_out.append(
                ToolUsageLogEntry(
                    id=sl.request_id,
                    timestamp=ts,
                    model=getattr(sl, "model", None) or None,
                    spend=getattr(sl, "spend", None),
                    total_tokens=getattr(sl, "total_tokens", None),
                    input_snippet=_input_snippet_for_tool_log(sl),
                )
            )

        return ToolUsageLogsResponse(
            logs=logs_out, total=total, page=page, page_size=page_size
        )
    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception("Error getting tool usage logs: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get(
    "/v1/tool/{tool_name:path}",
    tags=["tool management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=LiteLLM_ToolTableRow,
)
async def get_tool(
    tool_name: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get details for a single tool.

    Returns the tool row, raises 404 when the tool does not exist,
    and 500 when the DB is unavailable or lookup fails.
    """
    from litellm.proxy.db.tool_registry_writer import get_tool as db_get_tool
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
        )

    try:
        record = await db_get_tool(prisma_client=prisma_client, tool_name=tool_name)
        if record is not None:
            return record
        raise HTTPException(status_code=404, detail=f"Tool '{tool_name}' not found")
    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception("Error getting tool: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
async def _resolve_key_hash_to_object_permission_id(
    prisma_client: "PrismaClient",
    key_hash: str,
) -> Optional[str]:
    """Resolve key (hash or raw) to object_permission_id; create permission if key has none.

    Returns None when the key does not exist. Creation uses a
    create-then-conditional-attach sequence so two concurrent callers
    cannot attach different permission rows to the same key.
    """
    from litellm.proxy.proxy_server import hash_token

    # Raw keys contain "sk-"; hash them so lookups always use the stored hash.
    hashed = key_hash if "sk-" not in (key_hash or "") else hash_token(key_hash)
    if not hashed:
        return None
    row = await prisma_client.db.litellm_verificationtoken.find_unique(
        where={"token": hashed}
    )
    if row is None:
        return None
    op_id = getattr(row, "object_permission_id", None)
    if op_id:
        return op_id
    # Key exists but has no permission row yet — create one and try to attach.
    new_id = str(uuid.uuid4())
    await prisma_client.db.litellm_objectpermissiontable.create(
        data={"object_permission_id": new_id, "blocked_tools": []}
    )
    # Conditional update acts as a compare-and-set: it only attaches while
    # object_permission_id is still NULL.
    updated_count = await prisma_client.db.litellm_verificationtoken.update_many(
        where={"token": hashed, "object_permission_id": None},
        data={"object_permission_id": new_id},
    )
    if updated_count == 0:
        # Lost the race: another caller attached a permission first.
        # Delete our orphaned row and return whichever id won.
        await prisma_client.db.litellm_objectpermissiontable.delete(
            where={"object_permission_id": new_id}
        )
        row = await prisma_client.db.litellm_verificationtoken.find_unique(
            where={"token": hashed}
        )
        return getattr(row, "object_permission_id", None) if row else None
    return new_id
|
||||
|
||||
|
||||
async def _resolve_team_id_to_object_permission_id(
    prisma_client: "PrismaClient",
    team_id: str,
) -> Optional[str]:
    """Resolve team_id to object_permission_id; create permission if team has none.

    Returns None for blank input or an unknown team. Uses the same
    create-then-conditional-attach pattern as the key resolver so
    concurrent callers cannot attach two different permission rows.
    """
    if not team_id or not team_id.strip():
        return None
    team_id_clean = team_id.strip()
    row = await prisma_client.db.litellm_teamtable.find_unique(
        where={"team_id": team_id_clean},
        select={"object_permission_id": True},
    )
    if row is None:
        return None
    op_id = getattr(row, "object_permission_id", None)
    if op_id:
        return op_id
    # Team exists but has no permission row yet — create one and try to attach.
    new_id = str(uuid.uuid4())
    await prisma_client.db.litellm_objectpermissiontable.create(
        data={"object_permission_id": new_id, "blocked_tools": []}
    )
    # Conditional attach: only succeeds while the column is still NULL.
    updated_count = await prisma_client.db.litellm_teamtable.update_many(
        where={"team_id": team_id_clean, "object_permission_id": None},
        data={"object_permission_id": new_id},
    )
    if updated_count == 0:
        # Lost the race — drop our orphan row and return the winner's id.
        await prisma_client.db.litellm_objectpermissiontable.delete(
            where={"object_permission_id": new_id}
        )
        row = await prisma_client.db.litellm_teamtable.find_unique(
            where={"team_id": team_id_clean},
            select={"object_permission_id": True},
        )
        return getattr(row, "object_permission_id", None) if row else None
    return new_id
|
||||
|
||||
|
||||
@router.post(
    "/v1/tool/policy",
    tags=["tool management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=ToolPolicyUpdateResponse,
)
async def update_tool_policy(
    data: ToolPolicyUpdateRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Set the input_policy and/or output_policy for a tool (global), or block for a specific team/key (override).

    Parameters:
    - tool_name: str - The tool to update
    - input_policy: optional - "trusted" | "untrusted" | "blocked"
    - output_policy: optional - "trusted" | "untrusted"
    - team_id: optional - if set, create/update override for this team only
    - key_hash: optional - if set, create/update override for this key only
    """
    from litellm.proxy.db.tool_registry_writer import (
        add_tool_to_object_permission_blocked,
        get_tool_policy_registry,
        remove_tool_from_object_permission_blocked,
    )
    from litellm.proxy.db.tool_registry_writer import (
        update_tool_policy as db_update_tool_policy,
    )
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
        )

    try:
        # --- Scoped override path: block/unblock for one team or one key ---
        if data.team_id is not None or data.key_hash is not None:
            if data.team_id is not None and data.key_hash is not None:
                raise HTTPException(
                    status_code=400,
                    detail="Provide either team_id or key_hash, not both",
                )
            if data.key_hash is not None:
                op_id = await _resolve_key_hash_to_object_permission_id(
                    prisma_client, data.key_hash
                )
            else:
                op_id = await _resolve_team_id_to_object_permission_id(
                    prisma_client, data.team_id or ""
                )
            if op_id is None:
                raise HTTPException(
                    status_code=404,
                    detail="Key or team not found for the given identifier",
                )
            # Only input_policy == "blocked" adds to the blocked list;
            # any other value removes an existing block for that scope.
            is_blocking = data.input_policy == "blocked"
            if is_blocking:
                ok = await add_tool_to_object_permission_blocked(
                    prisma_client=prisma_client,
                    object_permission_id=op_id,
                    tool_name=data.tool_name,
                )
            else:
                ok = await remove_tool_from_object_permission_blocked(
                    prisma_client=prisma_client,
                    object_permission_id=op_id,
                    tool_name=data.tool_name,
                )
            if not ok:
                raise HTTPException(
                    status_code=500,
                    detail=f"Failed to update policy override for tool '{data.tool_name}'",
                )
            # Refresh the in-memory registry so the change takes effect now.
            registry = get_tool_policy_registry()
            if registry.is_initialized():
                await registry.sync_tool_policy_from_db(prisma_client)
            return ToolPolicyUpdateResponse(
                tool_name=data.tool_name,
                input_policy=data.input_policy,
                output_policy=data.output_policy,
                updated=True,
                team_id=data.team_id,
                key_hash=data.key_hash,
            )

        # --- Global policy path ---
        if data.input_policy is None and data.output_policy is None:
            raise HTTPException(
                status_code=400,
                detail="At least one of input_policy or output_policy must be provided",
            )

        updated = await db_update_tool_policy(
            prisma_client=prisma_client,
            tool_name=data.tool_name,
            updated_by=user_api_key_dict.user_id,
            input_policy=data.input_policy,
            output_policy=data.output_policy,
        )
        if updated is None:
            raise HTTPException(
                status_code=500,
                detail=f"Failed to update policy for tool '{data.tool_name}'",
            )
        registry = get_tool_policy_registry()
        if registry.is_initialized():
            await registry.sync_tool_policy_from_db(prisma_client)
        return ToolPolicyUpdateResponse(
            tool_name=updated.tool_name,
            input_policy=updated.input_policy,
            output_policy=updated.output_policy,
            updated=True,
        )
    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception("Error updating tool policy: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete(
    "/v1/tool/{tool_name:path}/overrides",
    tags=["tool management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def delete_tool_policy_override(
    tool_name: str,
    team_id: Optional[str] = Query(
        None, description="Team ID of the override to remove"
    ),
    key_hash: Optional[str] = Query(
        None, description="Key hash of the override to remove"
    ),
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Remove a policy override for a tool. Specify the override by team_id or key_hash
    (exactly one required).

    Raises 400 when zero or both identifiers are given, 404 when the
    key/team or override cannot be found, and 500 on unexpected errors.
    """
    from litellm.proxy.db.tool_registry_writer import (
        get_tool_policy_registry,
        remove_tool_from_object_permission_blocked,
    )
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
        )
    # Validate the scope selector before touching the DB.
    if team_id is None and key_hash is None:
        raise HTTPException(
            status_code=400,
            detail="At least one of team_id or key_hash is required to identify the override",
        )
    if team_id is not None and key_hash is not None:
        raise HTTPException(
            status_code=400,
            detail="Provide either team_id or key_hash, not both",
        )
    try:
        if key_hash is not None:
            op_id = await _resolve_key_hash_to_object_permission_id(
                prisma_client, key_hash
            )
        else:
            op_id = await _resolve_team_id_to_object_permission_id(
                prisma_client, team_id or ""
            )
        if op_id is None:
            raise HTTPException(
                status_code=404,
                detail="Key or team not found for the given identifier",
            )
        deleted = await remove_tool_from_object_permission_blocked(
            prisma_client=prisma_client,
            object_permission_id=op_id,
            tool_name=tool_name,
        )
        if not deleted:
            raise HTTPException(
                status_code=404,
                detail=f"No override found for tool '{tool_name}' with the given scope",
            )
        # Keep the in-memory policy registry in sync with the DB change.
        registry = get_tool_policy_registry()
        if registry.is_initialized():
            await registry.sync_tool_policy_from_db(prisma_client)
        return {"deleted": True, "tool_name": tool_name}
    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception("Error deleting tool policy override: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -0,0 +1,59 @@
|
||||
"""
|
||||
Types for the management endpoints
|
||||
|
||||
Might include fastapi/proxy requirements.txt related imports
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
from fastapi_sso.sso.base import OpenID
|
||||
|
||||
from litellm.proxy._types import LitellmUserRoles
|
||||
|
||||
|
||||
def is_valid_litellm_user_role(role_str: str) -> bool:
    """
    Check if a string is a valid LitellmUserRoles enum value (case-insensitive).

    Args:
        role_str: String to validate (e.g., "proxy_admin", "PROXY_ADMIN", "internal_user")

    Returns:
        True if the string matches a valid LitellmUserRoles value, False otherwise
    """
    try:
        normalized = role_str.lower()
    except Exception:
        # Non-string input (e.g. None) cannot be a valid role.
        return False
    # _value2member_map_ gives an O(1) membership check over enum values.
    return normalized in LitellmUserRoles._value2member_map_
|
||||
|
||||
|
||||
def get_litellm_user_role(role_str) -> Optional[LitellmUserRoles]:
    """
    Convert a string (or list of strings) to a LitellmUserRoles enum if valid (case-insensitive).

    Handles list inputs since some SSO providers (e.g., Keycloak) return roles
    as arrays like ["proxy_admin"] instead of plain strings.

    Args:
        role_str: String or list to convert (e.g., "proxy_admin", ["proxy_admin"])

    Returns:
        LitellmUserRoles enum if valid, None otherwise
    """
    try:
        candidate = role_str
        if isinstance(candidate, list):
            if not candidate:
                return None
            # Only the first entry of a role array is considered.
            candidate = candidate[0]
        # O(1) case-insensitive lookup against the enum's value map.
        member = LitellmUserRoles._value2member_map_.get(candidate.lower())
        return cast(Optional[LitellmUserRoles], member)
    except Exception:
        return None
|
||||
|
||||
|
||||
class CustomOpenID(OpenID):
    """OpenID payload extended with LiteLLM-specific SSO fields."""

    # Teams the SSO provider reports for this user.
    team_ids: List[str]
    # Role parsed from the provider's role claim, when present and valid.
    user_role: Optional[LitellmUserRoles] = None
    # Provider-specific claims preserved for downstream consumers.
    extra_fields: Optional[Dict[str, Any]] = None
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
Usage endpoints package.
|
||||
|
||||
Re-exports the router from endpoints module.
|
||||
"""
|
||||
|
||||
from litellm.proxy.management_endpoints.usage_endpoints.endpoints import ( # noqa: F401
|
||||
router,
|
||||
)
|
||||
@@ -0,0 +1,578 @@
|
||||
"""
|
||||
AI Usage Chat - uses LLM tool calling to answer questions about
|
||||
usage/spend data by querying the aggregated daily activity endpoints.
|
||||
"""
|
||||
|
||||
import json
|
||||
from datetime import date
|
||||
from typing import Any, AsyncIterator, Callable, Dict, List, Literal, Optional, cast
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.constants import DEFAULT_COMPETITOR_DISCOVERY_MODEL
|
||||
from litellm.types.proxy.management_endpoints.common_daily_activity import (
|
||||
SpendAnalyticsPaginatedResponse,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
USAGE_AI_TEMPERATURE = 0.2
|
||||
|
||||
TABLE_DAILY_USER_SPEND = "litellm_dailyuserspend"
|
||||
TABLE_DAILY_TEAM_SPEND = "litellm_dailyteamspend"
|
||||
TABLE_DAILY_TAG_SPEND = "litellm_dailytagspend"
|
||||
|
||||
ENTITY_FIELD_USER = "user_id"
|
||||
ENTITY_FIELD_TEAM = "team_id"
|
||||
ENTITY_FIELD_TAG = "tag"
|
||||
|
||||
PAGINATED_PAGE_SIZE = 200
|
||||
MAX_CHAT_MESSAGES = 20
|
||||
TOP_N_MODELS = 15
|
||||
TOP_N_PROVIDERS = 10
|
||||
TOP_N_KEYS = 10
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class SSEStatusEvent(TypedDict):
    """Progress message streamed while the assistant is working."""

    type: Literal["status"]
    message: str
|
||||
|
||||
|
||||
class SSEToolCallEvent(TypedDict, total=False):
    """Lifecycle event for one tool invocation (total=False: not all keys sent every time)."""

    type: Literal["tool_call"]
    tool_name: str
    tool_label: str
    arguments: Dict[str, str]
    status: Literal["running", "complete", "error"]
    error: str
|
||||
|
||||
|
||||
class SSEChunkEvent(TypedDict):
    """One streamed chunk of the final assistant answer."""

    type: Literal["chunk"]
    content: str
|
||||
|
||||
|
||||
class SSEDoneEvent(TypedDict):
    """Terminal event signalling the stream has finished."""

    type: Literal["done"]
|
||||
|
||||
|
||||
class SSEErrorEvent(TypedDict):
    """Terminal event carrying an error message to the client."""

    type: Literal["error"]
    message: str
|
||||
|
||||
|
||||
# Union of every event shape emitted over the SSE stream.
SSEEvent = (
    SSEStatusEvent | SSEToolCallEvent | SSEChunkEvent | SSEDoneEvent | SSEErrorEvent
)
|
||||
|
||||
|
||||
class ToolHandler(TypedDict):
    """Dispatch-table entry: data fetcher, text summariser, and display label."""

    # Awaitable that fetches raw usage data for the tool's arguments.
    fetch: Callable[..., Any]
    # Converts the raw dict into concise text the LLM can reason over.
    summarise: Callable[[Dict[str, Any]], str]
    # Human-readable label used in SSE events and error messages.
    label: str
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool definitions (OpenAI function-calling schema)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_DATE_PARAMS = {
|
||||
"start_date": {"type": "string", "description": "Start date in YYYY-MM-DD format"},
|
||||
"end_date": {"type": "string", "description": "End date in YYYY-MM-DD format"},
|
||||
}
|
||||
|
||||
_TOOL_USAGE = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_usage_data",
|
||||
"description": (
|
||||
"Fetch aggregated global usage/spend data. Returns daily spend, "
|
||||
"token counts, request counts, and breakdowns by model, provider, "
|
||||
"and API key. Use for overall spend, top models, top providers."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
**_DATE_PARAMS,
|
||||
"user_id": {
|
||||
"type": "string",
|
||||
"description": "Optional user ID filter. Omit for global view.",
|
||||
},
|
||||
},
|
||||
"required": ["start_date", "end_date"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_TOOL_TEAM = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_team_usage_data",
|
||||
"description": (
|
||||
"Fetch usage/spend data broken down by team. Use for questions "
|
||||
"like 'which team spends the most' or 'show me team X usage'."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
**_DATE_PARAMS,
|
||||
"team_ids": {
|
||||
"type": "string",
|
||||
"description": "Optional comma-separated team IDs. Omit for all teams.",
|
||||
},
|
||||
},
|
||||
"required": ["start_date", "end_date"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_TOOL_TAG = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_tag_usage_data",
|
||||
"description": (
|
||||
"Fetch usage/spend data broken down by tag. Tags are labels "
|
||||
"attached to requests (features, environments, credentials)."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
**_DATE_PARAMS,
|
||||
"tags": {
|
||||
"type": "string",
|
||||
"description": "Optional comma-separated tag names. Omit for all tags.",
|
||||
},
|
||||
},
|
||||
"required": ["start_date", "end_date"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
TOOLS_BASE = [_TOOL_USAGE]
|
||||
TOOLS_ADMIN = [_TOOL_USAGE, _TOOL_TEAM, _TOOL_TAG]
|
||||
|
||||
|
||||
def get_tools_for_role(is_admin: bool) -> List[Dict[str, Any]]:
    """Return the tool schemas this user may call (admins also get team/tag tools)."""
    if is_admin:
        return TOOLS_ADMIN
    return TOOLS_BASE
|
||||
|
||||
|
||||
_SYSTEM_PROMPT_BASE = (
|
||||
"You are an AI assistant embedded in the LiteLLM Usage dashboard. "
|
||||
"You help users understand their LLM API spend and usage data.\n\n"
|
||||
"ALWAYS call the appropriate tool(s) first to fetch data before answering. "
|
||||
"You may call multiple tools if the question spans different dimensions.\n\n"
|
||||
"Guidelines:\n"
|
||||
"- Be concise and specific. Use exact numbers from the data.\n"
|
||||
"- Format costs as dollar amounts (e.g. $12.34).\n"
|
||||
"- When comparing entities, show a ranked list.\n"
|
||||
"- If data is empty or no results found, say so clearly.\n"
|
||||
"- Do not hallucinate data — only use what the tools return.\n"
|
||||
"- Today's date will be provided below. Use it to interpret relative dates "
|
||||
"like 'this week', 'this month', 'last 7 days', etc."
|
||||
)
|
||||
|
||||
_TOOL_DESCRIPTIONS_ADMIN = (
|
||||
"You have access to these tools:\n"
|
||||
"- `get_usage_data`: Global/user-level usage (spend, models, providers, API keys)\n"
|
||||
"- `get_team_usage_data`: Team-level usage breakdown\n"
|
||||
"- `get_tag_usage_data`: Tag-level usage breakdown\n\n"
|
||||
)
|
||||
|
||||
_TOOL_DESCRIPTIONS_BASE = (
|
||||
"You have access to this tool:\n"
|
||||
"- `get_usage_data`: Your usage data (spend, models, providers, API keys)\n\n"
|
||||
)
|
||||
|
||||
|
||||
def _build_system_prompt(is_admin: bool) -> str:
    """Build role-appropriate system prompt with today's date appended."""
    if is_admin:
        tool_section = _TOOL_DESCRIPTIONS_ADMIN
    else:
        tool_section = _TOOL_DESCRIPTIONS_BASE
    today = date.today().isoformat()
    return _SYSTEM_PROMPT_BASE + "\n\n" + tool_section + "Today's date: " + today
|
||||
|
||||
|
||||
# keep a public reference for test assertions
|
||||
SYSTEM_PROMPT = _SYSTEM_PROMPT_BASE
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data fetchers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _parse_csv_ids(raw: Optional[str]) -> Optional[List[str]]:
|
||||
if not raw:
|
||||
return None
|
||||
return [t.strip() for t in raw.split(",") if t.strip()]
|
||||
|
||||
|
||||
async def _query_activity(
    table_name: str,
    entity_id_field: str,
    entity_id: Optional[Any],
    start_date: str,
    end_date: str,
    *,
    use_aggregated: bool = False,
) -> SpendAnalyticsPaginatedResponse:
    """Shared helper that calls the daily activity query layer.

    With use_aggregated=True the whole window is rolled up via
    get_daily_activity_aggregated; otherwise the first
    PAGINATED_PAGE_SIZE rows of the day-by-day breakdown are returned.
    entity_id may be a single id, a list of ids, or None for "all".
    """
    from litellm.proxy.management_endpoints.common_daily_activity import (
        get_daily_activity,
        get_daily_activity_aggregated,
    )
    from litellm.proxy.proxy_server import prisma_client

    if use_aggregated:
        return await get_daily_activity_aggregated(
            prisma_client=prisma_client,
            table_name=table_name,
            entity_id_field=entity_id_field,
            entity_id=entity_id,
            entity_metadata_field=None,
            start_date=start_date,
            end_date=end_date,
            model=None,
            api_key=None,
        )
    return await get_daily_activity(
        prisma_client=prisma_client,
        table_name=table_name,
        entity_id_field=entity_id_field,
        entity_id=entity_id,
        entity_metadata_field=None,
        start_date=start_date,
        end_date=end_date,
        model=None,
        api_key=None,
        page=1,
        page_size=PAGINATED_PAGE_SIZE,
    )
|
||||
|
||||
|
||||
async def _fetch_usage_data(
    start_date: str, end_date: str, user_id: Optional[str] = None
) -> Dict[str, Any]:
    """Fetch aggregated global/per-user usage data as a JSON-safe dict."""
    response = await _query_activity(
        table_name=TABLE_DAILY_USER_SPEND,
        entity_id_field=ENTITY_FIELD_USER,
        entity_id=user_id,
        start_date=start_date,
        end_date=end_date,
        use_aggregated=True,
    )
    return response.model_dump(mode="json")
|
||||
|
||||
|
||||
async def _fetch_team_usage_data(
    start_date: str, end_date: str, team_ids: Optional[str] = None
) -> Dict[str, Any]:
    """Fetch per-team daily usage data as a JSON-safe dict."""
    response = await _query_activity(
        table_name=TABLE_DAILY_TEAM_SPEND,
        entity_id_field=ENTITY_FIELD_TEAM,
        entity_id=_parse_csv_ids(team_ids),
        start_date=start_date,
        end_date=end_date,
    )
    return response.model_dump(mode="json")
|
||||
|
||||
|
||||
async def _fetch_tag_usage_data(
    start_date: str, end_date: str, tags: Optional[str] = None
) -> Dict[str, Any]:
    """Fetch per-tag daily usage data as a JSON-safe dict."""
    response = await _query_activity(
        table_name=TABLE_DAILY_TAG_SPEND,
        entity_id_field=ENTITY_FIELD_TAG,
        entity_id=_parse_csv_ids(tags),
        start_date=start_date,
        end_date=end_date,
    )
    return response.model_dump(mode="json")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Summarisers — convert raw JSON to concise text the LLM can reason over
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _accumulate_breakdown(
|
||||
results: List[Dict[str, Any]], dimension: str, fields: List[str]
|
||||
) -> Dict[str, Dict[str, float]]:
|
||||
"""Aggregate a single breakdown dimension across days."""
|
||||
totals: Dict[str, Dict[str, float]] = {}
|
||||
for day in results:
|
||||
for key, entry in day.get("breakdown", {}).get(dimension, {}).items():
|
||||
if key not in totals:
|
||||
totals[key] = {f: 0.0 for f in fields}
|
||||
m = entry.get("metrics", {})
|
||||
for f in fields:
|
||||
totals[key][f] += m.get(f, 0)
|
||||
return totals
|
||||
|
||||
|
||||
def _ranked_lines(
|
||||
totals: Dict[str, Dict[str, float]],
|
||||
fmt: Callable[[str, Dict[str, float]], str],
|
||||
limit: int,
|
||||
) -> List[str]:
|
||||
"""Sort by spend descending, format each entry, and truncate."""
|
||||
return [
|
||||
fmt(name, vals)
|
||||
for name, vals in sorted(totals.items(), key=lambda x: -x[1].get("spend", 0))[
|
||||
:limit
|
||||
]
|
||||
]
|
||||
|
||||
|
||||
def _summarise_usage_data(data: Dict[str, Any]) -> str:
    """Render aggregated usage JSON as a compact text report for the LLM.

    Emits overall totals from the response metadata, then top models and
    providers ranked by spend across the window.
    """
    meta = data.get("metadata", {})
    results = data.get("results", [])

    header = (
        f"Total Spend: ${meta.get('total_spend', 0):.4f}\n"
        f"Total Requests: {meta.get('total_api_requests', 0)}\n"
        f"Successful: {meta.get('total_successful_requests', 0)} | "
        f"Failed: {meta.get('total_failed_requests', 0)}\n"
        f"Total Tokens: {meta.get('total_tokens', 0)}"
    )

    # Roll per-day breakdowns up into window totals before ranking.
    models = _accumulate_breakdown(
        results, "models", ["spend", "api_requests", "total_tokens"]
    )
    providers = _accumulate_breakdown(results, "providers", ["spend", "api_requests"])

    model_lines = _ranked_lines(
        models,
        lambda n, d: f" - {n}: ${d['spend']:.4f} ({int(d['api_requests'])} reqs, {int(d['total_tokens'])} tokens)",
        TOP_N_MODELS,
    )
    provider_lines = _ranked_lines(
        providers,
        lambda n, d: f" - {n}: ${d['spend']:.4f} ({int(d['api_requests'])} reqs)",
        TOP_N_PROVIDERS,
    )

    sections = [header, ""]
    sections += ["Top Models by Spend:"] + (model_lines or [" (no data)"]) + [""]
    sections += ["Top Providers by Spend:"] + (provider_lines or [" (no data)"])
    return "\n".join(sections)
|
||||
|
||||
|
||||
def _summarise_entity_data(data: Dict[str, Any], entity_label: str) -> str:
    """Summarise team/tag entity usage data.

    entity_label is a display word like "Team" or "Tag"; entities are
    ranked by total spend over the window.
    """
    results = data.get("results", [])
    if not results:
        return f"No {entity_label} usage data found for the given date range."

    totals: Dict[str, Dict[str, Any]] = {}
    for day in results:
        for eid, entry in day.get("breakdown", {}).get("entities", {}).items():
            if eid not in totals:
                # Prefer a human-readable alias when metadata provides one.
                alias = entry.get("metadata", {}).get("alias", eid)
                totals[eid] = {"alias": alias, "spend": 0.0, "requests": 0, "tokens": 0}
            m = entry.get("metrics", {})
            totals[eid]["spend"] += m.get("spend", 0)
            totals[eid]["requests"] += m.get("api_requests", 0)
            totals[eid]["tokens"] += m.get("total_tokens", 0)

    lines = [f"{entity_label} Usage ({len(totals)} {entity_label.lower()}s):", ""]
    # Highest spend first.
    for eid, d in sorted(totals.items(), key=lambda x: -x[1]["spend"]):
        label = d["alias"] if d["alias"] != eid else eid
        lines.append(
            f"- {label} (ID: {eid}): ${d['spend']:.4f} | "
            f"{int(d['requests'])} reqs | {int(d['tokens'])} tokens"
        )
    return "\n".join(lines)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool dispatch registry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
TOOL_HANDLERS: Dict[str, ToolHandler] = {
|
||||
"get_usage_data": ToolHandler(
|
||||
fetch=_fetch_usage_data,
|
||||
summarise=_summarise_usage_data,
|
||||
label="global usage data",
|
||||
),
|
||||
"get_team_usage_data": ToolHandler(
|
||||
fetch=_fetch_team_usage_data,
|
||||
summarise=lambda data: _summarise_entity_data(data, "Team"),
|
||||
label="team usage data",
|
||||
),
|
||||
"get_tag_usage_data": ToolHandler(
|
||||
fetch=_fetch_tag_usage_data,
|
||||
summarise=lambda data: _summarise_entity_data(data, "Tag"),
|
||||
label="tag usage data",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SSE streaming
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _sse(event: SSEEvent) -> str:
    """Encode *event* as one SSE `data:` frame (JSON payload, blank-line terminated)."""
    return "data: " + json.dumps(event) + "\n\n"
|
||||
|
||||
|
||||
def _resolve_fetch_kwargs(
|
||||
fn_name: str,
|
||||
fn_args: Dict[str, str],
|
||||
user_id: Optional[str],
|
||||
is_admin: bool,
|
||||
) -> Dict[str, Any]:
|
||||
"""Build keyword arguments for a tool's fetch function."""
|
||||
start_date = fn_args.get("start_date", "")
|
||||
end_date = fn_args.get("end_date", "")
|
||||
if not start_date or not end_date:
|
||||
raise ValueError("Missing required start_date or end_date from tool arguments")
|
||||
kwargs: Dict[str, Any] = {"start_date": start_date, "end_date": end_date}
|
||||
if fn_name == "get_usage_data":
|
||||
if not is_admin:
|
||||
kwargs["user_id"] = user_id
|
||||
elif fn_args.get("user_id"):
|
||||
kwargs["user_id"] = fn_args["user_id"]
|
||||
elif fn_name == "get_team_usage_data" and fn_args.get("team_ids"):
|
||||
kwargs["team_ids"] = fn_args["team_ids"]
|
||||
elif fn_name == "get_tag_usage_data" and fn_args.get("tags"):
|
||||
kwargs["tags"] = fn_args["tags"]
|
||||
return kwargs
|
||||
|
||||
|
||||
async def _execute_tool_call(
    handler: ToolHandler,
    fn_name: str,
    fn_args: Dict[str, str],
    user_id: Optional[str],
    is_admin: bool,
) -> str:
    """Run a single tool and return the summarised result text.

    Raises ValueError (via _resolve_fetch_kwargs) when the required date
    range is missing; fetch/summarise errors propagate to the caller.
    """
    kwargs = _resolve_fetch_kwargs(fn_name, fn_args, user_id, is_admin)
    raw_data = await handler["fetch"](**kwargs)
    return handler["summarise"](raw_data)
|
||||
|
||||
|
||||
async def _process_tool_call(
    tc: Any,
    chat_messages: List[Dict[str, Any]],
    user_id: Optional[str],
    is_admin: bool,
) -> AsyncIterator[str]:
    """Execute a single tool call, yielding SSE events for status.

    Side effect: always appends a ``role: tool`` message to *chat_messages*
    (either the tool result, an error placeholder, or a "not available"
    notice) so the follow-up LLM call sees a response for every tool call.
    """
    fn_name = tc.function.name
    # assumes tc.function.arguments is a JSON string — json.loads will raise on malformed input
    fn_args = json.loads(tc.function.arguments)

    # Authorization gate: only tools exposed for this role may be executed.
    allowed_names = {t["function"]["name"] for t in get_tools_for_role(is_admin)}
    handler = TOOL_HANDLERS.get(fn_name)

    if fn_name not in allowed_names or not handler:
        # Unknown or disallowed tool: record a refusal for the LLM and stop
        # without emitting any SSE event.
        chat_messages.append(
            {
                "role": "tool",
                "tool_call_id": tc.id,
                "content": f"Tool not available: {fn_name}",
            }
        )
        return

    # Common fields shared by the running/complete/error status events.
    tool_event_base = {
        "type": "tool_call",
        "tool_name": fn_name,
        "tool_label": handler["label"],
        "arguments": fn_args,
    }
    yield _sse(cast(SSEToolCallEvent, {**tool_event_base, "status": "running"}))

    try:
        tool_result = await _execute_tool_call(
            handler, fn_name, fn_args, user_id, is_admin
        )
        yield _sse(cast(SSEToolCallEvent, {**tool_event_base, "status": "complete"}))
    except Exception as e:
        # Log the real failure; hand the LLM a generic message so the chat
        # can still produce a final answer.
        verbose_proxy_logger.error("Tool %s failed: %s", fn_name, e)
        tool_result = f"Error fetching {handler['label']}. Please try again."
        yield _sse(cast(SSEToolCallEvent, {**tool_event_base, "status": "error"}))

    chat_messages.append(
        {"role": "tool", "tool_call_id": tc.id, "content": tool_result}
    )
|
||||
|
||||
|
||||
async def _stream_final_response(
    model: str, chat_messages: List[Dict[str, Any]]
) -> AsyncIterator[str]:
    """Stream the final LLM response after tool results are appended."""
    yield _sse({"type": "status", "message": "Analyzing results..."})

    completion_stream = await litellm.acompletion(
        model=model,
        messages=chat_messages,
        stream=True,
        temperature=USAGE_AI_TEMPERATURE,
    )
    # Forward only non-empty content deltas as SSE chunk events.
    async for part in completion_stream:
        content = part.choices[0].delta.content
        if not content:
            continue
        yield _sse({"type": "chunk", "content": content})
|
||||
|
||||
|
||||
async def stream_usage_ai_chat(
    messages: List[Dict[str, str]],
    model: Optional[str] = None,
    user_id: Optional[str] = None,
    is_admin: bool = False,
) -> AsyncIterator[str]:
    """Stream SSE events: status → tool_call → chunk → done.

    Args:
        messages: user/assistant chat history (system prompt is added here).
        model: optional model override; blank/None falls back to the default.
        user_id: id used to scope non-admin usage queries.
        is_admin: selects the tool set and system prompt variant.
    """
    # NOTE(review): the fallback constant is named for competitor discovery —
    # confirm it is the intended default model for usage chat.
    resolved_model = (model or "").strip() or DEFAULT_COMPETITOR_DISCOVERY_MODEL
    # Keep only the most recent MAX_CHAT_MESSAGES turns to bound prompt size.
    truncated = (
        messages[-MAX_CHAT_MESSAGES:] if len(messages) > MAX_CHAT_MESSAGES else messages
    )
    chat_messages: List[Dict[str, Any]] = [
        {"role": "system", "content": _build_system_prompt(is_admin)},
        *truncated,
    ]

    try:
        yield _sse({"type": "status", "message": "Thinking..."})
        tools = get_tools_for_role(is_admin)
        # First (non-streaming) call: let the model decide whether to use tools.
        response = await litellm.acompletion(
            model=resolved_model,
            messages=chat_messages,
            tools=tools,
            temperature=USAGE_AI_TEMPERATURE,
        )
        choice = response.choices[0]  # type: ignore

        if not choice.message.tool_calls:
            # No tool use: emit the direct answer (if any) and finish.
            if choice.message.content:
                yield _sse({"type": "chunk", "content": choice.message.content})
            yield _sse({"type": "done"})
            return

        # Record the assistant's tool-call turn, execute each requested tool
        # (each appends its result to chat_messages), then stream the final
        # answer built from those results.
        chat_messages.append(choice.message.model_dump())
        for tc in choice.message.tool_calls:
            async for event in _process_tool_call(tc, chat_messages, user_id, is_admin):
                yield event
        async for event in _stream_final_response(resolved_model, chat_messages):
            yield event
        yield _sse({"type": "done"})

    except Exception as e:
        # Log internally; never leak failure details to the client stream.
        verbose_proxy_logger.error("AI usage chat failed: %s", e)
        yield _sse(
            {
                "type": "error",
                "message": "An internal error occurred. Please try again.",
            }
        )
|
||||
@@ -0,0 +1,65 @@
|
||||
"""
|
||||
USAGE AI CHAT ENDPOINTS
|
||||
|
||||
/usage/ai/chat - Stream AI chat responses about usage data
|
||||
"""
|
||||
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
|
||||
# Router holding the usage-AI chat endpoint defined below.
router = APIRouter()
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
    """A single turn in the usage-AI chat history."""

    # Only user/assistant turns are accepted from clients; the system
    # prompt is added server-side.
    role: Literal["user", "assistant"]
    content: str
|
||||
|
||||
|
||||
class UsageAIChatRequest(BaseModel):
    """Request body for POST /usage/ai/chat."""

    messages: List[ChatMessage] = Field(
        ..., description="Chat messages (user/assistant history)"
    )
    # Optional model override; when omitted the server picks its default.
    model: Optional[str] = Field(default=None, description="Model to use for AI chat")
|
||||
|
||||
|
||||
@router.post(
    "/usage/ai/chat",
    tags=["Budget & Spend Tracking"],
    dependencies=[Depends(user_api_key_auth)],
)
async def usage_ai_chat(
    data: UsageAIChatRequest,
    request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    AI chat about usage data. Streams SSE events with the AI response.
    The AI agent has access to tools that query aggregated daily activity data.
    """
    # Imported inside the handler to avoid import cycles at module load time.
    from litellm.proxy.management_endpoints.common_utils import (
        _user_has_admin_view,
    )
    from litellm.proxy.management_endpoints.usage_endpoints.ai_usage_chat import (
        stream_usage_ai_chat,
    )

    chat_history = [{"role": m.role, "content": m.content} for m in data.messages]

    event_stream = stream_usage_ai_chat(
        messages=chat_history,
        model=data.model,
        user_id=user_api_key_dict.user_id,
        is_admin=_user_has_admin_view(user_api_key_dict),
    )
    return StreamingResponse(
        event_stream,
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
|
||||
@@ -0,0 +1,775 @@
|
||||
"""
|
||||
User Agent Analytics Endpoints
|
||||
|
||||
This module provides optimized endpoints for tracking user agent activity metrics including:
|
||||
- Daily Active Users (DAU) by tags for configurable number of days
|
||||
- Weekly Active Users (WAU) by tags for configurable number of weeks
|
||||
- Monthly Active Users (MAU) by tags for configurable number of months
|
||||
- Summary analytics by tags
|
||||
|
||||
These endpoints use optimized single SQL queries with joins to efficiently calculate
|
||||
user metrics from tag activity data and return time series for dashboard visualization.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
|
||||
from litellm.proxy._types import CommonProxyErrors, UserAPIKeyAuth
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
|
||||
# Constants for analytics periods
MAX_DAYS = 7  # Number of days to show in DAU analytics
MAX_WEEKS = 7  # Number of weeks to show in WAU analytics
MAX_MONTHS = 7  # Number of months to show in MAU analytics
MAX_TAGS = 250  # Maximum number of distinct tags to return

# Router holding the user-agent analytics endpoints defined below.
router = APIRouter()
|
||||
|
||||
|
||||
class TagActiveUsersResponse(BaseModel):
    """Response for tag active users metrics.

    One (tag, period) data point in an active-user time series.
    """

    tag: str
    active_users: int
    date: str  # The specific date or period identifier (e.g. a week/month label)
    period_start: Optional[
        str
    ] = None  # For WAU/MAU, this will be the start of the period
    period_end: Optional[str] = None  # For WAU/MAU, this will be the end of the period
|
||||
|
||||
|
||||
class ActiveUsersAnalyticsResponse(BaseModel):
    """Response for active users analytics: a list of per-tag, per-period data points."""

    results: List[TagActiveUsersResponse]
|
||||
|
||||
|
||||
class TagSummaryMetrics(BaseModel):
    """Summary metrics for a tag, aggregated over a date range."""

    tag: str
    unique_users: int
    total_requests: int
    successful_requests: int
    failed_requests: int
    total_tokens: int  # prompt + completion tokens combined
    total_spend: float
|
||||
|
||||
|
||||
class TagSummaryResponse(BaseModel):
    """Response for tag summary analytics: one summary row per tag."""

    results: List[TagSummaryMetrics]
|
||||
|
||||
|
||||
class DistinctTagResponse(BaseModel):
    """A single distinct user agent tag."""

    tag: str
|
||||
|
||||
|
||||
class DistinctTagsResponse(BaseModel):
    """Response for all distinct user agent tags."""

    results: List[DistinctTagResponse]
|
||||
|
||||
|
||||
class PerUserMetrics(BaseModel):
    """Metrics for an individual user.

    Numeric fields default to zero so rows can be built incrementally.
    """

    user_id: str
    user_email: Optional[str] = None
    user_agent: Optional[str] = None
    successful_requests: int = 0
    failed_requests: int = 0
    total_requests: int = 0
    total_tokens: int = 0
    spend: float = 0.0
|
||||
|
||||
|
||||
class PerUserAnalyticsResponse(BaseModel):
    """Paginated response for per-user analytics."""

    results: List[PerUserMetrics]  # rows for the requested page only
    total_count: int
    page: int
    page_size: int
    total_pages: int
|
||||
|
||||
|
||||
@router.get(
    "/tag/distinct",
    response_model=DistinctTagsResponse,
    tags=["tag management", "user agent analytics"],
    dependencies=[Depends(user_api_key_auth)],
)
async def get_distinct_user_agent_tags(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get all distinct user agent tags up to a maximum of {MAX_TAGS} tags.

    This endpoint returns all unique user agent tags found in the database,
    sorted by frequency of usage.

    Returns:
        DistinctTagsResponse: List of distinct user agent tags
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    try:
        # Matches "User-Agent:*" tags plus bare tags with no namespace prefix.
        # MAX_TAGS is a module constant (not user input), so interpolating it
        # into the SQL is injection-safe.
        sql_query = f"""
        SELECT
            dts.tag,
            COUNT(*) as usage_count
        FROM "LiteLLM_DailyTagSpend" dts
        WHERE dts.tag LIKE 'User-Agent:%' OR dts.tag NOT LIKE '%:%'
        GROUP BY dts.tag
        ORDER BY usage_count DESC
        LIMIT {MAX_TAGS}
        """

        db_response = await prisma_client.db.query_raw(sql_query)

        results = [DistinctTagResponse(tag=row["tag"]) for row in db_response]

        return DistinctTagsResponse(results=results)

    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to fetch distinct user agent tags: {str(e)}",
        )
|
||||
|
||||
|
||||
@router.get(
    "/tag/dau",
    response_model=ActiveUsersAnalyticsResponse,
    tags=["tag management", "user agent analytics"],
    dependencies=[Depends(user_api_key_auth)],
)
async def get_daily_active_users(
    tag_filter: Optional[str] = Query(
        default=None,
        description="Filter by specific tag (optional)",
    ),
    tag_filters: Optional[List[str]] = Query(
        default=None,
        description="Filter by multiple specific tags (optional, takes precedence over tag_filter)",
    ),
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get Daily Active Users (DAU) by tags for the last {MAX_DAYS} days ending on UTC today + 1 day.

    This endpoint efficiently calculates unique users per tag for each of the last {MAX_DAYS} days
    using a single optimized SQL query, perfect for dashboard time series visualization.

    Args:
        tag_filter: Optional filter to specific tag (legacy)
        tag_filters: Optional filter to multiple specific tags (takes precedence over tag_filter)

    Returns:
        ActiveUsersAnalyticsResponse: DAU data by tag for each of the last {MAX_DAYS} days
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    try:
        # Calculate end_date as UTC today + 1 day so today's partial data is included.
        from datetime import timezone

        end_dt = datetime.now(timezone.utc).replace(
            hour=0, minute=0, second=0, microsecond=0
        ) + timedelta(days=1)
        end_date = end_dt.strftime("%Y-%m-%d")

        # Calculate date range (last MAX_DAYS days)
        start_dt = end_dt - timedelta(days=MAX_DAYS)
        start_date = start_dt.strftime("%Y-%m-%d")

        # Build SQL query with optional tag filter(s). Date bounds and tag
        # values are always passed as query parameters, never interpolated.
        where_clause = (
            "WHERE dts.date >= $1 AND dts.date <= $2 AND vt.user_id IS NOT NULL"
        )
        params: List[str] = [start_date, end_date]

        # Handle multiple tag filters (takes precedence over single tag filter).
        if tag_filters:
            tag_conditions = []
            for tag in tag_filters:  # fixed: dropped unused enumerate() index
                params.append(tag)
                # Placeholder index is computed from the params list so the
                # clause stays correct if earlier parameters are ever added.
                tag_conditions.append(f"dts.tag = ${len(params)}")
            where_clause += f" AND ({' OR '.join(tag_conditions)})"
        elif tag_filter:
            params.append(f"%{tag_filter}%")
            # fixed: compute the placeholder index instead of hard-coding $3,
            # matching the multi-filter branch above.
            where_clause += f" AND dts.tag ILIKE ${len(params)}"

        sql_query = f"""
        SELECT
            dts.tag,
            dts.date,
            COUNT(DISTINCT vt.user_id) as active_users
        FROM "LiteLLM_DailyTagSpend" dts
        INNER JOIN "LiteLLM_VerificationToken" vt ON dts.api_key = vt.token
        {where_clause}
        GROUP BY dts.tag, dts.date
        ORDER BY dts.date DESC, active_users DESC
        """

        db_response = await prisma_client.db.query_raw(sql_query, *params)

        results = [
            TagActiveUsersResponse(
                tag=row["tag"], active_users=row["active_users"], date=row["date"]
            )
            for row in db_response
        ]

        return ActiveUsersAnalyticsResponse(results=results)

    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to fetch DAU analytics: {str(e)}",
        )
|
||||
|
||||
|
||||
@router.get(
    "/tag/wau",
    response_model=ActiveUsersAnalyticsResponse,
    tags=["tag management", "user agent analytics"],
    dependencies=[Depends(user_api_key_auth)],
)
async def get_weekly_active_users(
    tag_filter: Optional[str] = Query(
        default=None,
        description="Filter by specific tag (optional)",
    ),
    tag_filters: Optional[List[str]] = Query(
        default=None,
        description="Filter by multiple specific tags (optional, takes precedence over tag_filter)",
    ),
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get Weekly Active Users (WAU) by tags for the last {MAX_WEEKS} weeks ending on UTC today + 1 day.

    Shows week-by-week breakdown:
    - Week 1 (Jan 1): Earliest week (7 weeks ago)
    - Week 2 (Jan 8): Next week (6 weeks ago)
    - Week 3 (Jan 15): Next week (5 weeks ago)
    - ... and so on for {MAX_WEEKS} weeks total
    - Week 7: Most recent week ending on UTC today + 1 day

    Args:
        tag_filter: Optional filter to specific tag (legacy)
        tag_filters: Optional filter to multiple specific tags (takes precedence over tag_filter)

    Returns:
        ActiveUsersAnalyticsResponse: WAU data by tag for each of the last {MAX_WEEKS} weeks with descriptive week labels (e.g., "Week 1 (Jan 1)")
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    try:
        # Calculate end_date as UTC today + 1 day
        from datetime import timezone

        end_dt = datetime.now(timezone.utc).replace(
            hour=0, minute=0, second=0, microsecond=0
        ) + timedelta(days=1)
        end_date = end_dt.strftime("%Y-%m-%d")

        # Calculate date range for all weeks (49 days total)
        # Start from 48 days before end_date to cover exactly MAX_WEEKS complete weeks
        start_dt = end_dt - timedelta(
            days=(MAX_WEEKS * 7 - 1)
        )  # MAX_WEEKS weeks * 7 days - 1
        start_date = start_dt.strftime("%Y-%m-%d")

        # Build SQL query with optional tag filter(s)
        where_clause = (
            "WHERE dts.date >= $1 AND dts.date <= $2 AND vt.user_id IS NOT NULL"
        )
        params = [start_date, end_date]

        # Handle multiple tag filters (takes precedence over single tag filter)
        if tag_filters and len(tag_filters) > 0:
            tag_conditions = []
            for i, tag in enumerate(tag_filters):
                param_index = len(params) + 1
                tag_conditions.append(f"dts.tag = ${param_index}")
                params.append(tag)
            where_clause += f" AND ({' OR '.join(tag_conditions)})"
        elif tag_filter:
            # $3 is valid here because params always holds exactly
            # [start_date, end_date] when this branch runs.
            where_clause += " AND dts.tag ILIKE $3"
            params.append(f"%{tag_filter}%")

        # Bucket rows into 7-day offsets from end_date and count distinct users
        # per (tag, week). end_date is server-derived (not user input), so
        # interpolating it into the SQL text is injection-safe.
        # NOTE(review): the in-SQL comment about week numbering is stale —
        # offset 0 is labeled 'Week {MAX_WEEKS}' (most recent); labels count
        # up from the earliest week.
        sql_query = f"""
        WITH weekly_data AS (
            SELECT
                dts.tag,
                dts.date,
                vt.user_id,
                -- Calculate week number (0 = Week 1 most recent, 1 = Week 2, etc.)
                FLOOR((DATE '{end_date}' - dts.date::date) / 7) as week_offset
            FROM "LiteLLM_DailyTagSpend" dts
            INNER JOIN "LiteLLM_VerificationToken" vt ON dts.api_key = vt.token
            {where_clause}
        )
        SELECT
            tag,
            COUNT(DISTINCT user_id) as active_users,
            -- Week identifier with month and day (Week 1 (earliest), Week 2, etc.)
            'Week ' || ({MAX_WEEKS} - week_offset)::text || ' (' ||
            TO_CHAR(DATE '{end_date}' - (week_offset * 7 || ' days')::interval - '6 days'::interval, 'Mon DD') || ')' as date,
            -- Calculate week start and end dates for each week
            (DATE '{end_date}' - (week_offset * 7 || ' days')::interval - '6 days'::interval)::text as period_start,
            (DATE '{end_date}' - (week_offset * 7 || ' days')::interval)::text as period_end,
            week_offset
        FROM weekly_data
        WHERE week_offset < {MAX_WEEKS}
        GROUP BY tag, week_offset
        ORDER BY week_offset DESC, active_users DESC
        """

        db_response = await prisma_client.db.query_raw(sql_query, *params)

        results = [
            TagActiveUsersResponse(
                tag=row["tag"],
                active_users=row["active_users"],
                date=row[
                    "date"
                ],  # This will be "Week 1 (Jan 15)", "Week 2 (Jan 8)", etc.
                period_start=row["period_start"],
                period_end=row["period_end"],
            )
            for row in db_response
        ]

        return ActiveUsersAnalyticsResponse(results=results)

    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to fetch WAU analytics: {str(e)}",
        )
|
||||
|
||||
|
||||
@router.get(
    "/tag/mau",
    response_model=ActiveUsersAnalyticsResponse,
    tags=["tag management", "user agent analytics"],
    dependencies=[Depends(user_api_key_auth)],
)
async def get_monthly_active_users(
    tag_filter: Optional[str] = Query(
        default=None,
        description="Filter by specific tag (optional)",
    ),
    tag_filters: Optional[List[str]] = Query(
        default=None,
        description="Filter by multiple specific tags (optional, takes precedence over tag_filter)",
    ),
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get Monthly Active Users (MAU) by tags for the last {MAX_MONTHS} months ending on UTC today + 1 day.

    Shows month-by-month breakdown:
    - Month 1 (Nov): Earliest month (7 months ago, 30-day period)
    - Month 2 (Dec): Next month (6 months ago)
    - Month 3 (Jan): Next month (5 months ago)
    - ... and so on for {MAX_MONTHS} months total
    - Month 7: Most recent month ending on UTC today + 1 day

    Args:
        tag_filter: Optional filter to specific tag (legacy)
        tag_filters: Optional filter to multiple specific tags (takes precedence over tag_filter)

    Returns:
        ActiveUsersAnalyticsResponse: MAU data by tag for each of the last {MAX_MONTHS} months with descriptive month labels (e.g., "Month 1 (Nov)")
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    try:
        # Calculate end_date as UTC today + 1 day
        from datetime import timezone

        end_dt = datetime.now(timezone.utc).replace(
            hour=0, minute=0, second=0, microsecond=0
        ) + timedelta(days=1)
        end_date = end_dt.strftime("%Y-%m-%d")

        # Calculate date range for all months (210 days total)
        # Start from 209 days before end_date to cover exactly MAX_MONTHS complete months
        # (a "month" here is a fixed 30-day bucket, not a calendar month)
        start_dt = end_dt - timedelta(
            days=(MAX_MONTHS * 30 - 1)
        )  # MAX_MONTHS months * 30 days - 1
        start_date = start_dt.strftime("%Y-%m-%d")

        # Build SQL query with optional tag filter(s)
        where_clause = (
            "WHERE dts.date >= $1 AND dts.date <= $2 AND vt.user_id IS NOT NULL"
        )
        params = [start_date, end_date]

        # Handle multiple tag filters (takes precedence over single tag filter)
        if tag_filters and len(tag_filters) > 0:
            tag_conditions = []
            for i, tag in enumerate(tag_filters):
                param_index = len(params) + 1
                tag_conditions.append(f"dts.tag = ${param_index}")
                params.append(tag)
            where_clause += f" AND ({' OR '.join(tag_conditions)})"
        elif tag_filter:
            # $3 is valid here because params always holds exactly
            # [start_date, end_date] when this branch runs.
            where_clause += " AND dts.tag ILIKE $3"
            params.append(f"%{tag_filter}%")

        # Bucket rows into 30-day offsets from end_date and count distinct
        # users per (tag, month). end_date is server-derived (not user input),
        # so interpolating it into the SQL text is injection-safe.
        sql_query = f"""
        WITH monthly_data AS (
            SELECT
                dts.tag,
                dts.date,
                vt.user_id,
                -- Calculate month number (0 = Month 1 most recent, 1 = Month 2, etc.)
                FLOOR((DATE '{end_date}' - dts.date::date) / 30) as month_offset
            FROM "LiteLLM_DailyTagSpend" dts
            INNER JOIN "LiteLLM_VerificationToken" vt ON dts.api_key = vt.token
            {where_clause}
        )
        SELECT
            tag,
            COUNT(DISTINCT user_id) as active_users,
            -- Month identifier with month name (Month 1 (earliest), Month 2, etc.)
            'Month ' || ({MAX_MONTHS} - month_offset)::text || ' (' ||
            TO_CHAR(DATE '{end_date}' - (month_offset * 30 || ' days')::interval - '29 days'::interval, 'Mon') || ')' as date,
            -- Calculate month start and end dates for each month
            (DATE '{end_date}' - (month_offset * 30 || ' days')::interval - '29 days'::interval)::text as period_start,
            (DATE '{end_date}' - (month_offset * 30 || ' days')::interval)::text as period_end,
            month_offset
        FROM monthly_data
        WHERE month_offset < {MAX_MONTHS}
        GROUP BY tag, month_offset
        ORDER BY month_offset DESC, active_users DESC
        """

        db_response = await prisma_client.db.query_raw(sql_query, *params)

        results = [
            TagActiveUsersResponse(
                tag=row["tag"],
                active_users=row["active_users"],
                date=row["date"],  # This will be "Month 1 (Jan)", "Month 2 (Dec)", etc.
                period_start=row["period_start"],
                period_end=row["period_end"],
            )
            for row in db_response
        ]

        return ActiveUsersAnalyticsResponse(results=results)

    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to fetch MAU analytics: {str(e)}",
        )
|
||||
|
||||
|
||||
@router.get(
    "/tag/summary",
    response_model=TagSummaryResponse,
    tags=["tag management", "user agent analytics"],
    dependencies=[Depends(user_api_key_auth)],
)
async def get_tag_summary(
    start_date: str = Query(description="Start date in YYYY-MM-DD format"),
    end_date: str = Query(description="End date in YYYY-MM-DD format"),
    tag_filter: Optional[str] = Query(
        default=None,
        description="Filter by specific tag (optional)",
    ),
    tag_filters: Optional[List[str]] = Query(
        default=None,
        description="Filter by multiple specific tags (optional, takes precedence over tag_filter)",
    ),
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get summary analytics for tags including unique users, requests, tokens, and spend.

    Args:
        start_date: Start date for the analytics period (YYYY-MM-DD)
        end_date: End date for the analytics period (YYYY-MM-DD)
        tag_filter: Optional filter to specific tag (legacy)
        tag_filters: Optional filter to multiple specific tags (takes precedence over tag_filter)

    Returns:
        TagSummaryResponse: Summary analytics data by tag

    Raises:
        HTTPException: 400 on malformed dates, 500 on DB/unexpected failures.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    try:
        # Validate date format; strptime raises ValueError -> mapped to 400 below.
        # Dates here are user input, but they are only ever passed as query
        # parameters ($1/$2), never interpolated into the SQL text.
        datetime.strptime(start_date, "%Y-%m-%d")
        datetime.strptime(end_date, "%Y-%m-%d")

        # Build SQL query with optional tag filter(s)
        where_clause = "WHERE dts.date >= $1 AND dts.date <= $2"
        params = [start_date, end_date]

        # Handle multiple tag filters (takes precedence over single tag filter)
        if tag_filters and len(tag_filters) > 0:
            tag_conditions = []
            for i, tag in enumerate(tag_filters):
                param_index = len(params) + 1
                tag_conditions.append(f"dts.tag = ${param_index}")
                params.append(tag)
            where_clause += f" AND ({' OR '.join(tag_conditions)})"
        elif tag_filter:
            # $3 is valid here because params always holds exactly
            # [start_date, end_date] when this branch runs.
            where_clause += " AND dts.tag ILIKE $3"
            params.append(f"%{tag_filter}%")

        # LEFT JOIN keeps tag rows whose api_key has no matching token,
        # so request/spend totals include them (unique_users counts only
        # rows with a resolvable user).
        sql_query = f"""
        SELECT
            dts.tag,
            COUNT(DISTINCT vt.user_id) as unique_users,
            SUM(dts.api_requests) as total_requests,
            SUM(dts.successful_requests) as successful_requests,
            SUM(dts.failed_requests) as failed_requests,
            SUM(dts.prompt_tokens + dts.completion_tokens) as total_tokens,
            SUM(dts.spend) as total_spend
        FROM "LiteLLM_DailyTagSpend" dts
        LEFT JOIN "LiteLLM_VerificationToken" vt ON dts.api_key = vt.token
        {where_clause}
        GROUP BY dts.tag
        ORDER BY total_requests DESC
        """

        db_response = await prisma_client.db.query_raw(sql_query, *params)

        # Coerce NULL aggregates (no matching rows) to zero before building models.
        results = [
            TagSummaryMetrics(
                tag=row["tag"],
                unique_users=row["unique_users"] or 0,
                total_requests=int(row["total_requests"] or 0),
                successful_requests=int(row["successful_requests"] or 0),
                failed_requests=int(row["failed_requests"] or 0),
                total_tokens=int(row["total_tokens"] or 0),
                total_spend=float(row["total_spend"] or 0.0),
            )
            for row in db_response
        ]

        return TagSummaryResponse(results=results)

    except ValueError as e:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid date format. Use YYYY-MM-DD: {str(e)}",
        )
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to fetch tag summary analytics: {str(e)}",
        )
|
||||
|
||||
|
||||
@router.get(
    "/tag/user-agent/per-user-analytics",
    response_model=PerUserAnalyticsResponse,
    tags=["tag management", "user agent analytics"],
    dependencies=[Depends(user_api_key_auth)],
)
async def get_per_user_analytics(
    tag_filter: Optional[str] = Query(
        default=None,
        description="Filter by specific tag (optional)",
    ),
    tag_filters: Optional[List[str]] = Query(
        default=None,
        description="Filter by multiple specific tags (optional, takes precedence over tag_filter)",
    ),
    page: int = Query(default=1, description="Page number for pagination", ge=1),
    page_size: int = Query(default=50, description="Items per page", ge=1, le=1000),
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get per-user analytics including successful requests, tokens, and spend by individual users.

    This endpoint provides usage metrics broken down by individual users based on their
    tag activity during the last 30 days ending on UTC today + 1 day.

    Args:
        tag_filter: Optional filter to specific tag (legacy, substring match)
        tag_filters: Optional filter to multiple specific tags (exact match;
            takes precedence over tag_filter)
        page: Page number for pagination (1-based)
        page_size: Number of items per page

    Returns:
        PerUserAnalyticsResponse: Analytics data broken down by individual users
        for the last 30 days, sorted by successful requests (descending).

    Raises:
        HTTPException: 500 if the database is not connected or the query fails.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    try:
        # End of the window is midnight UTC of tomorrow so that today's
        # partial day is always included in the 30-day range.
        from datetime import timezone

        end_dt = datetime.now(timezone.utc).replace(
            hour=0, minute=0, second=0, microsecond=0
        ) + timedelta(days=1)
        end_date = end_dt.strftime("%Y-%m-%d")

        # Start of the window: last 30 days.
        start_dt = end_dt - timedelta(days=30)
        start_date = start_dt.strftime("%Y-%m-%d")

        # Base where clause: date range, then optional tag filtering.
        # tag_filters (exact "in" match) takes precedence over the legacy
        # tag_filter (substring "contains" match).
        where_clause: Dict[str, Any] = {"date": {"gte": start_date, "lte": end_date}}
        if tag_filters:
            where_clause["tag"] = {"in": tag_filters}
        elif tag_filter:
            where_clause["tag"] = {"contains": tag_filter}

        # All daily-tag-spend rows in the window (optionally tag-filtered).
        tag_records = await prisma_client.db.litellm_dailytagspend.find_many(
            where=where_clause
        )

        # Unique api_keys seen in those rows; rows without an api_key are ignored.
        api_keys = {record.api_key for record in tag_records if record.api_key}

        if not api_keys:
            # No attributable activity in the window — return an empty page.
            return PerUserAnalyticsResponse(
                results=[],
                total_count=0,
                page=page,
                page_size=page_size,
                total_pages=0,
            )

        # Resolve api_key -> user_id via the verification-token table.
        api_key_records = await prisma_client.db.litellm_verificationtoken.find_many(
            where={"token": {"in": list(api_keys)}}
        )
        api_key_to_user_id = {
            record.token: record.user_id for record in api_key_records if record.user_id
        }

        # Resolve user_id -> user_email for display purposes.
        user_ids = list(set(api_key_to_user_id.values()))
        user_records = await prisma_client.db.litellm_usertable.find_many(
            where={"user_id": {"in": user_ids}}
        )
        user_id_to_email = {
            record.user_id: record.user_email for record in user_records
        }

        # Aggregate metrics per user across all of their rows in the window.
        user_metrics: Dict[str, PerUserMetrics] = {}

        for record in tag_records:
            if record.api_key not in api_key_to_user_id:
                # Row's key has no associated user — cannot attribute it.
                continue

            user_id = api_key_to_user_id[record.api_key]
            tag = record.tag  # Use the full tag as user_agent

            if user_id not in user_metrics:
                user_metrics[user_id] = PerUserMetrics(
                    user_id=user_id,
                    user_email=user_id_to_email.get(user_id),
                    user_agent=tag,
                )
            elif tag and not user_metrics[user_id].user_agent:
                # Keep the first non-empty tag seen as the user's user_agent.
                user_metrics[user_id].user_agent = tag

            metrics = user_metrics[user_id]
            metrics.successful_requests += record.successful_requests or 0
            metrics.failed_requests += record.failed_requests or 0
            metrics.total_requests += record.api_requests or 0
            # total_tokens is derived from prompt_tokens + completion_tokens.
            prompt_tokens = record.prompt_tokens or 0
            completion_tokens = record.completion_tokens or 0
            metrics.total_tokens += int(prompt_tokens + completion_tokens)
            metrics.spend += record.spend or 0.0

        # Sort by successful requests, most active users first.
        results = sorted(
            user_metrics.values(),
            key=lambda x: x.successful_requests,
            reverse=True,
        )

        # Apply pagination over the sorted results.
        total_count = len(results)
        total_pages = (total_count + page_size - 1) // page_size
        start_idx = (page - 1) * page_size
        end_idx = start_idx + page_size
        paginated_results = results[start_idx:end_idx]

        return PerUserAnalyticsResponse(
            results=paginated_results,
            total_count=total_count,
            page=page,
            page_size=page_size,
            total_pages=total_pages,
        )

    except Exception as e:
        # Chain the original exception so the root cause survives in logs.
        raise HTTPException(
            status_code=500,
            detail=f"Failed to fetch per-user analytics: {str(e)}",
        ) from e