"""
Has all /sso/* routes

/sso/key/generate - handles user signing in with SSO and redirects to /sso/callback

/sso/callback - returns JWT Redirect Response that redirects to LiteLLM UI

/sso/debug/login - handles user signing in with SSO and redirects to /sso/debug/callback

/sso/debug/callback - returns the OpenID object returned by the SSO provider
"""
|
||
|
|
|
||
|
|
import asyncio
|
||
|
|
import base64
|
||
|
|
import hashlib
|
||
|
|
import inspect
|
||
|
|
import os
|
||
|
|
import secrets
|
||
|
|
from copy import deepcopy
|
||
|
|
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union, cast
|
||
|
|
|
||
|
|
if TYPE_CHECKING:
|
||
|
|
import httpx
|
||
|
|
|
||
|
|
import jwt
|
||
|
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||
|
|
from fastapi.responses import RedirectResponse
|
||
|
|
|
||
|
|
import litellm
|
||
|
|
from litellm._logging import verbose_proxy_logger
|
||
|
|
from litellm._uuid import uuid
|
||
|
|
from litellm.caching import DualCache
|
||
|
|
from litellm.constants import (
|
||
|
|
LITELLM_UI_SESSION_DURATION,
|
||
|
|
MAX_SPENDLOG_ROWS_TO_QUERY,
|
||
|
|
MICROSOFT_USER_DISPLAY_NAME_ATTRIBUTE,
|
||
|
|
MICROSOFT_USER_EMAIL_ATTRIBUTE,
|
||
|
|
MICROSOFT_USER_FIRST_NAME_ATTRIBUTE,
|
||
|
|
MICROSOFT_USER_ID_ATTRIBUTE,
|
||
|
|
MICROSOFT_USER_LAST_NAME_ATTRIBUTE,
|
||
|
|
)
|
||
|
|
from litellm.litellm_core_utils.dot_notation_indexing import get_nested_value
|
||
|
|
from litellm.llms.custom_httpx.http_handler import (
|
||
|
|
AsyncHTTPHandler,
|
||
|
|
get_async_httpx_client,
|
||
|
|
httpxSpecialProvider,
|
||
|
|
)
|
||
|
|
from litellm.proxy._types import (
|
||
|
|
CommonProxyErrors,
|
||
|
|
LiteLLM_UserTable,
|
||
|
|
LitellmUserRoles,
|
||
|
|
Member,
|
||
|
|
NewTeamRequest,
|
||
|
|
NewUserRequest,
|
||
|
|
NewUserResponse,
|
||
|
|
ProxyErrorTypes,
|
||
|
|
ProxyException,
|
||
|
|
SSOUserDefinedValues,
|
||
|
|
TeamMemberAddRequest,
|
||
|
|
UserAPIKeyAuth,
|
||
|
|
)
|
||
|
|
from litellm.proxy.auth.auth_checks import ExperimentalUIJWTToken, get_user_object
|
||
|
|
from litellm.proxy.auth.auth_utils import _has_user_setup_sso
|
||
|
|
from litellm.proxy.auth.handle_jwt import JWTHandler
|
||
|
|
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||
|
|
from litellm.proxy.common_utils.admin_ui_utils import (
|
||
|
|
admin_ui_disabled,
|
||
|
|
show_missing_vars_in_env,
|
||
|
|
)
|
||
|
|
from litellm.proxy.common_utils.html_forms.jwt_display_template import (
|
||
|
|
jwt_display_template,
|
||
|
|
)
|
||
|
|
from litellm.proxy.common_utils.html_forms.ui_login import html_form
|
||
|
|
from litellm.proxy.management_endpoints.internal_user_endpoints import new_user
|
||
|
|
from litellm.proxy.management_endpoints.sso import CustomMicrosoftSSO
|
||
|
|
from litellm.proxy.management_endpoints.sso_helper_utils import (
|
||
|
|
check_is_admin_only_access,
|
||
|
|
has_admin_ui_access,
|
||
|
|
)
|
||
|
|
from litellm.proxy.management_endpoints.team_endpoints import new_team, team_member_add
|
||
|
|
from litellm.proxy.management_endpoints.types import (
|
||
|
|
CustomOpenID,
|
||
|
|
get_litellm_user_role,
|
||
|
|
is_valid_litellm_user_role,
|
||
|
|
)
|
||
|
|
from litellm.proxy.utils import (
|
||
|
|
PrismaClient,
|
||
|
|
ProxyLogging,
|
||
|
|
get_custom_url,
|
||
|
|
get_server_root_path,
|
||
|
|
)
|
||
|
|
from litellm.secret_managers.main import get_secret_bool, str_to_bool
|
||
|
|
from litellm.types.proxy.management_endpoints.ui_sso import * # noqa: F403, F401
|
||
|
|
from litellm.types.proxy.management_endpoints.ui_sso import (
|
||
|
|
DefaultTeamSSOParams,
|
||
|
|
MicrosoftGraphAPIUserGroupDirectoryObject,
|
||
|
|
MicrosoftGraphAPIUserGroupResponse,
|
||
|
|
MicrosoftServicePrincipalTeam,
|
||
|
|
RoleMappings,
|
||
|
|
TeamMappings,
|
||
|
|
)
|
||
|
|
from litellm.types.proxy.ui_sso import ParsedOpenIDResult
|
||
|
|
|
||
|
|
if TYPE_CHECKING:
|
||
|
|
from fastapi_sso.sso.base import OpenID
|
||
|
|
else:
|
||
|
|
from typing import Any as OpenID
|
||
|
|
|
||
|
|
# Router holding every /sso/* endpoint; mounted by the proxy server elsewhere.
router = APIRouter()


# OAuth bearer credential fields that must not appear in SSO debug responses
# (received_response is included in restricted-group error messages).
# Metadata fields (token_type, expires_in, scope) are intentionally kept so
# response convertors see the same fields in the PKCE path as in the non-PKCE path.
_OAUTH_TOKEN_FIELDS = frozenset({"access_token", "id_token", "refresh_token"})
|
||
|
|
|
||
|
|
|
||
|
|
def normalize_email(email: Optional[str]) -> Optional[str]:
    """
    Lowercase an email address for consistent storage and comparison.

    SSO treats email addresses as case-insensitive (even though RFC 5321
    technically allows case-sensitive local parts), so providers that return
    emails with different casing still match what is stored in the database.

    Args:
        email: Email address to normalize; may be None.

    Returns:
        The lowercased email, or the input unchanged when it is None or not a str.
    """
    if isinstance(email, str):
        return email.lower()
    # None (and any non-string value) passes through untouched.
    return email
|
||
|
|
|
||
|
|
|
||
|
|
def determine_role_from_groups(
    user_groups: List[str],
    role_mappings: "RoleMappings",
) -> Optional[LitellmUserRoles]:
    """
    Pick the highest-privilege LiteLLM role a user qualifies for via group membership.

    Role precedence (highest to lowest):
    - proxy_admin
    - proxy_admin_viewer
    - internal_user
    - internal_user_viewer

    Args:
        user_groups: List of group names from the SSO token.
        role_mappings: RoleMappings configuration object.

    Returns:
        The first role (in precedence order) whose configured groups overlap the
        user's groups; otherwise role_mappings.default_role (which may be None).
    """
    if not role_mappings.roles:
        # Nothing configured -> fall straight back to the default.
        return role_mappings.default_role

    # Precedence order, highest privilege first.
    ordered_roles = (
        LitellmUserRoles.PROXY_ADMIN,
        LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY,
        LitellmUserRoles.INTERNAL_USER,
        LitellmUserRoles.INTERNAL_USER_VIEW_ONLY,
    )

    # Set membership gives O(1) overlap checks per role.
    membership = set(user_groups) if isinstance(user_groups, list) else set()

    for candidate in ordered_roles:
        mapped_groups = role_mappings.roles.get(candidate)
        if isinstance(mapped_groups, list) and not membership.isdisjoint(
            mapped_groups
        ):
            verbose_proxy_logger.debug(
                f"User groups {user_groups} matched role '{candidate.value}' via groups: {mapped_groups}"
            )
            return candidate

    verbose_proxy_logger.debug(
        f"User groups {user_groups} did not match any role mappings, using default_role: {role_mappings.default_role}"
    )
    return role_mappings.default_role
|
||
|
|
|
||
|
|
|
||
|
|
def process_sso_jwt_access_token(
    access_token_str: Optional[str],
    sso_jwt_handler: Optional[JWTHandler],
    result: Union[OpenID, dict, None],
    role_mappings: Optional["RoleMappings"] = None,
) -> None:
    """
    Process SSO JWT access token and extract team IDs and user role if available.

    Decodes the JWT access token (without signature verification — the token
    was just obtained from the provider over the token endpoint) and extracts
    team IDs and the user role, then sets them on `result` in place. Role
    extraction from the access token is needed because some SSO providers
    (e.g., Keycloak) do not include role claims in the UserInfo response.

    Args:
        access_token_str: The JWT access token string (may be opaque/non-JWT).
        sso_jwt_handler: SSO-specific JWT handler for team ID extraction.
        result: The SSO result object (dict or OpenID-like) updated in place.
        role_mappings: Optional role mappings for group-based role determination.

    Returns:
        None — all effects are mutations of `result`.
    """
    if access_token_str and result:
        import jwt

        try:
            # Signature is deliberately not verified here; see docstring.
            access_token_payload = jwt.decode(
                access_token_str, options={"verify_signature": False}
            )
        except jwt.exceptions.DecodeError:
            # Opaque (non-JWT) access tokens are valid for some providers;
            # simply skip all JWT-based extraction in that case.
            verbose_proxy_logger.debug(
                "Access token is not a valid JWT (possibly an opaque token), skipping JWT-based extraction"
            )
            return

        # Extract team IDs from access token if sso_jwt_handler is available.
        # Existing team_ids (e.g. from UserInfo) are never overwritten.
        if sso_jwt_handler:
            if isinstance(result, dict):
                result_team_ids: Optional[List[str]] = result.get("team_ids", [])
                if not result_team_ids:
                    team_ids = sso_jwt_handler.get_team_ids_from_jwt(
                        access_token_payload
                    )
                    result["team_ids"] = team_ids
            else:
                result_team_ids = getattr(result, "team_ids", []) if result else []
                if not result_team_ids:
                    team_ids = sso_jwt_handler.get_team_ids_from_jwt(
                        access_token_payload
                    )
                    setattr(result, "team_ids", team_ids)

        # Extract user role from access token only if not already set from UserInfo.
        existing_role = (
            result.get("user_role")
            if isinstance(result, dict)
            else getattr(result, "user_role", None)
        )
        if existing_role is None:
            user_role: Optional[LitellmUserRoles] = None

            # Try role_mappings first (group-based role determination).
            if role_mappings is not None and role_mappings.roles:
                group_claim = role_mappings.group_claim
                user_groups_raw: Any = get_nested_value(
                    access_token_payload, group_claim
                )

                # Groups claim may be a list, a comma-separated string, or a
                # single scalar — normalize all three to List[str].
                user_groups: List[str] = []
                if isinstance(user_groups_raw, list):
                    user_groups = [str(g) for g in user_groups_raw]
                elif isinstance(user_groups_raw, str):
                    user_groups = [
                        g.strip() for g in user_groups_raw.split(",") if g.strip()
                    ]
                elif user_groups_raw is not None:
                    user_groups = [str(user_groups_raw)]

                if user_groups:
                    user_role = determine_role_from_groups(user_groups, role_mappings)
                    verbose_proxy_logger.debug(
                        f"Determined role '{user_role}' from access token groups '{user_groups}' using role_mappings"
                    )
                elif role_mappings.default_role:
                    user_role = role_mappings.default_role

            # Fallback: try GENERIC_USER_ROLE_ATTRIBUTE on the access token payload.
            if user_role is None:
                generic_user_role_attribute_name = os.getenv(
                    "GENERIC_USER_ROLE_ATTRIBUTE", "role"
                )
                user_role_from_token = get_nested_value(
                    access_token_payload, generic_user_role_attribute_name
                )
                if user_role_from_token is not None:
                    user_role = get_litellm_user_role(user_role_from_token)
                    verbose_proxy_logger.debug(
                        f"Extracted role '{user_role}' from access token field '{generic_user_role_attribute_name}'"
                    )

            if user_role is not None:
                if isinstance(result, dict):
                    result["user_role"] = user_role
                else:
                    setattr(result, "user_role", user_role)
                verbose_proxy_logger.debug(
                    f"Set user_role='{user_role}' from JWT access token"
                )
|
||
|
|
|
||
|
|
|
||
|
|
@router.get("/sso/key/generate", tags=["experimental"], include_in_schema=False)
async def google_login(
    request: Request,
    source: Optional[str] = None,
    key: Optional[str] = None,
    existing_key: Optional[str] = None,
):  # noqa: PLR0915
    """
    Entry point for UI sign-in: starts SSO (Google/Microsoft/Generic) or
    serves the username/password login form.

    Requires setting PROXY_BASE_URL in .env — it should be your deployed proxy
    endpoint, e.g. PROXY_BASE_URL="https://litellm-production-7002.up.railway.app/".

    Args:
        request: Incoming FastAPI request (used to build the redirect URL).
        source: Optional login source marker (e.g. CLI flow) — passed to state.
        key: Optional CLI key carried through the OAuth state.
        existing_key: Optional existing key to reuse; forwarded to the callback.
    """
    from litellm.proxy.proxy_server import (
        premium_user,
        prisma_client,
        user_custom_ui_sso_sign_in_handler,
    )

    microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
    google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
    generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)

    ####### Check if UI is disabled #######
    _disable_ui_flag = os.getenv("DISABLE_ADMIN_UI")
    if _disable_ui_flag is not None:
        is_disabled = str_to_bool(value=_disable_ui_flag)
        if is_disabled:
            return admin_ui_disabled()

    ####### Check if user is a Enterprise / Premium User #######
    if (
        microsoft_client_id is not None
        or google_client_id is not None
        or generic_client_id is not None
    ):
        if premium_user is not True:
            # Check if under 'free SSO user' limit
            if prisma_client is not None:
                total_users = await prisma_client.db.litellm_usertable.count()
                if total_users and total_users > 5:
                    raise ProxyException(
                        message="You must be a LiteLLM Enterprise user to use SSO for more than 5 users. If you have a license please set `LITELLM_LICENSE` in your env. If you want to obtain a license meet with us here: https://calendly.com/d/cx9p-5yf-2nm/litellm-introductions You are seeing this error message because You set one of `MICROSOFT_CLIENT_ID`, `GOOGLE_CLIENT_ID`, or `GENERIC_CLIENT_ID` in your env. Please unset this",
                        type=ProxyErrorTypes.auth_error,
                        param="premium_user",
                        code=status.HTTP_403_FORBIDDEN,
                    )
            else:
                # No DB connected — cannot verify the free-tier user count.
                raise ProxyException(
                    message=CommonProxyErrors.db_not_connected_error.value,
                    type=ProxyErrorTypes.auth_error,
                    param="premium_user",
                    code=status.HTTP_403_FORBIDDEN,
                )

    ####### Detect DB + MASTER KEY in .env #######
    missing_env_vars = show_missing_vars_in_env()
    if missing_env_vars is not None:
        return missing_env_vars
    ui_username = os.getenv("UI_USERNAME")

    # get url from request - always use regular callback, but set state for CLI
    redirect_url = SSOAuthenticationHandler.get_redirect_url_for_sso(
        request=request,
        sso_callback_route="sso/callback",
        existing_key=existing_key,
    )

    # Store CLI key in state for OAuth flow
    cli_state: Optional[str] = SSOAuthenticationHandler._get_cli_state(
        source=source,
        key=key,
        existing_key=existing_key,
    )

    # check if user defined a custom auth sso sign in handler, if yes, use it
    if user_custom_ui_sso_sign_in_handler is not None:
        try:
            from litellm_enterprise.proxy.auth.custom_sso_handler import (  # type: ignore[import-untyped]
                EnterpriseCustomSSOHandler,
            )

            return await EnterpriseCustomSSOHandler.handle_custom_ui_sso_sign_in(
                request=request,
            )
        except ImportError:
            # Custom handlers are an enterprise-only feature.
            raise ValueError(
                "Enterprise features are not available. Custom UI SSO sign-in requires LiteLLM Enterprise."
            )

    # Check if we should use SSO handler
    if (
        SSOAuthenticationHandler.should_use_sso_handler(
            microsoft_client_id=microsoft_client_id,
            google_client_id=google_client_id,
            generic_client_id=generic_client_id,
        )
        is True
    ):
        verbose_proxy_logger.info(f"Redirecting to SSO login for {redirect_url}")
        return await SSOAuthenticationHandler.get_sso_login_redirect(
            redirect_url=redirect_url,
            microsoft_client_id=microsoft_client_id,
            google_client_id=google_client_id,
            generic_client_id=generic_client_id,
            state=cli_state,
        )
    elif ui_username is not None:
        # No Google, Microsoft SSO
        # Use UI Credentials set in .env
        from fastapi.responses import HTMLResponse

        return HTMLResponse(content=html_form, status_code=200)
    else:
        # No SSO configured at all — fall back to the plain login form too.
        from fastapi.responses import HTMLResponse

        return HTMLResponse(content=html_form, status_code=200)
|
||
|
|
|
||
|
|
|
||
|
|
def generic_response_convertor(
    response,
    jwt_handler: JWTHandler,
    sso_jwt_handler: Optional[JWTHandler] = None,
    role_mappings: Optional["RoleMappings"] = None,
    team_mappings: Optional["TeamMappings"] = None,
) -> CustomOpenID:
    """
    Convert a generic OIDC provider response into a CustomOpenID object.

    Field names are read from the GENERIC_USER_*_ATTRIBUTE env vars (dot
    notation supported via get_nested_value). Team IDs come from
    sso_jwt_handler when configured (optionally augmented by DB-backed
    team_mappings), otherwise from jwt_handler. The user role is resolved via
    role_mappings (generic/okta providers only), falling back to the
    GENERIC_USER_ROLE_ATTRIBUTE claim.

    Args:
        response: Raw dict-like response from the SSO provider.
        jwt_handler: Default JWT handler used for team-ID extraction.
        sso_jwt_handler: SSO-specific JWT handler; takes precedence when set.
        role_mappings: Optional group->role mapping configuration.
        team_mappings: Optional DB-backed team mapping configuration.
    """
    generic_user_id_attribute_name = os.getenv(
        "GENERIC_USER_ID_ATTRIBUTE", "preferred_username"
    )
    generic_user_display_name_attribute_name = os.getenv(
        "GENERIC_USER_DISPLAY_NAME_ATTRIBUTE", "sub"
    )
    generic_user_email_attribute_name = os.getenv(
        "GENERIC_USER_EMAIL_ATTRIBUTE", "email"
    )

    generic_user_first_name_attribute_name = os.getenv(
        "GENERIC_USER_FIRST_NAME_ATTRIBUTE", "first_name"
    )
    generic_user_last_name_attribute_name = os.getenv(
        "GENERIC_USER_LAST_NAME_ATTRIBUTE", "last_name"
    )

    generic_provider_attribute_name = os.getenv(
        "GENERIC_USER_PROVIDER_ATTRIBUTE", "provider"
    )

    generic_user_role_attribute_name = os.getenv("GENERIC_USER_ROLE_ATTRIBUTE", "role")

    generic_user_extra_attributes = os.getenv("GENERIC_USER_EXTRA_ATTRIBUTES", None)

    verbose_proxy_logger.debug(
        f" generic_user_id_attribute_name: {generic_user_id_attribute_name}\n generic_user_email_attribute_name: {generic_user_email_attribute_name}"
    )

    all_teams = []
    if sso_jwt_handler is not None:
        # SSO-specific handler takes precedence for team extraction.
        team_ids = sso_jwt_handler.get_team_ids_from_jwt(cast(dict, response))
        all_teams.extend(team_ids)

        # DB-backed team mappings can contribute additional team IDs.
        if team_mappings is not None and team_mappings.team_ids_jwt_field is not None:
            team_ids_from_db_mapping: Optional[List[str]] = get_nested_value(
                data=cast(dict, response),
                key_path=team_mappings.team_ids_jwt_field,
                default=[],
            )
            if team_ids_from_db_mapping:
                all_teams.extend(team_ids_from_db_mapping)
                verbose_proxy_logger.debug(
                    f"Loaded team_ids from DB team_mappings.team_ids_jwt_field='{team_mappings.team_ids_jwt_field}': {team_ids_from_db_mapping}"
                )
    else:
        team_ids = jwt_handler.get_team_ids_from_jwt(cast(dict, response))
        all_teams.extend(team_ids)

    # Determine user role based on role_mappings if available
    # Only apply role_mappings for GENERIC SSO provider
    user_role: Optional[LitellmUserRoles] = None

    if role_mappings is not None and role_mappings.provider.lower() in [
        "generic",
        "okta",
    ]:
        # Use role_mappings to determine role from groups
        group_claim = role_mappings.group_claim
        user_groups_raw: Any = get_nested_value(response, group_claim)

        # Handle different formats: could be a list, string (comma-separated), or single value
        user_groups: List[str] = []
        if isinstance(user_groups_raw, list):
            user_groups = [str(g) for g in user_groups_raw]
        elif isinstance(user_groups_raw, str):
            # Handle comma-separated string
            user_groups = [g.strip() for g in user_groups_raw.split(",") if g.strip()]
        elif user_groups_raw is not None:
            # Single value
            user_groups = [str(user_groups_raw)]

        if user_groups:
            user_role = determine_role_from_groups(user_groups, role_mappings)
            verbose_proxy_logger.debug(
                f"Determined role '{user_role.value if user_role else None}' from groups '{user_groups}' using role_mappings"
            )
        else:
            # No groups found, use default_role
            user_role = role_mappings.default_role
            verbose_proxy_logger.debug(
                f"No groups found in '{group_claim}', using default_role: {role_mappings.default_role}"
            )

    # Fallback to existing logic if role_mappings not used
    if user_role is None:
        user_role_from_sso = get_nested_value(
            response, generic_user_role_attribute_name
        )
        if user_role_from_sso is not None:
            role = get_litellm_user_role(user_role_from_sso)
            if role is not None:
                user_role = role
                verbose_proxy_logger.debug(
                    f"Found valid LitellmUserRoles '{role.value}' from SSO attribute '{generic_user_role_attribute_name}'"
                )

    # Build extra_fields dict from GENERIC_USER_EXTRA_ATTRIBUTES if specified
    extra_fields: Optional[Dict[str, Any]] = None
    if generic_user_extra_attributes:
        extra_fields = {}
        for attr_name in generic_user_extra_attributes.split(","):
            attr_name = attr_name.strip()
            extra_fields[attr_name] = get_nested_value(response, attr_name)

    return CustomOpenID(
        id=get_nested_value(response, generic_user_id_attribute_name),
        display_name=get_nested_value(
            response, generic_user_display_name_attribute_name
        ),
        email=normalize_email(
            get_nested_value(response, generic_user_email_attribute_name)
        ),
        first_name=get_nested_value(response, generic_user_first_name_attribute_name),
        last_name=get_nested_value(response, generic_user_last_name_attribute_name),
        provider=get_nested_value(response, generic_provider_attribute_name),
        team_ids=all_teams,
        user_role=user_role,
        extra_fields=extra_fields,
    )
|
||
|
|
|
||
|
|
|
||
|
|
def _setup_generic_sso_env_vars(
    generic_client_id: str, redirect_url: str
) -> Tuple[str, List[str], str, str, str, bool]:
    """
    Read and validate the Generic SSO environment variables.

    Args:
        generic_client_id: Client ID for the generic OIDC provider (logged only).
        redirect_url: Redirect URI used for the OAuth flow (logged only).

    Returns:
        Tuple of (client_secret, scope list, authorization endpoint,
        token endpoint, userinfo endpoint, include_client_id flag).

    Raises:
        ProxyException: when any required GENERIC_* variable is missing.
    """
    generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
    generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
    generic_authorization_endpoint = os.getenv("GENERIC_AUTHORIZATION_ENDPOINT", None)
    generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
    generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
    generic_include_client_id = (
        os.getenv("GENERIC_INCLUDE_CLIENT_ID", "false").lower() == "true"
    )

    # Fail fast on the first missing required variable, in a fixed order.
    required_vars = (
        ("GENERIC_CLIENT_SECRET", generic_client_secret),
        ("GENERIC_AUTHORIZATION_ENDPOINT", generic_authorization_endpoint),
        ("GENERIC_TOKEN_ENDPOINT", generic_token_endpoint),
        ("GENERIC_USERINFO_ENDPOINT", generic_userinfo_endpoint),
    )
    for env_name, env_value in required_vars:
        if env_value is None:
            raise ProxyException(
                message=f"{env_name} not set. Set it in .env file",
                type=ProxyErrorTypes.auth_error,
                param=env_name,
                code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )

    verbose_proxy_logger.debug(
        f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}"
    )
    verbose_proxy_logger.debug(
        f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
    )

    return (
        generic_client_secret,
        generic_scope,
        generic_authorization_endpoint,
        generic_token_endpoint,
        generic_userinfo_endpoint,
        generic_include_client_id,
    )
|
||
|
|
|
||
|
|
|
||
|
|
async def _setup_team_mappings() -> Optional["TeamMappings"]:
    """
    Load team mappings from the SSO settings stored in the database.

    Reads the singleton `sso_config` row from litellm_ssoconfig and builds a
    TeamMappings object from its `team_mappings` entry. Any failure (no DB,
    missing record, bad data) is logged at debug level and swallowed, so SSO
    continues with config-based team mapping.

    Returns:
        TeamMappings when configured in the DB, otherwise None.
    """
    team_mappings: Optional["TeamMappings"] = None
    try:
        from litellm.proxy.utils import get_prisma_client_or_throw

        prisma_client = get_prisma_client_or_throw(
            "Prisma client is None, connect a database to your proxy"
        )

        sso_db_record = await prisma_client.db.litellm_ssoconfig.find_unique(
            where={"id": "sso_config"}
        )

        if sso_db_record and sso_db_record.sso_settings:
            sso_settings_dict = dict(sso_db_record.sso_settings)
            team_mappings_data = sso_settings_dict.get("team_mappings")

            if team_mappings_data:
                from litellm.types.proxy.management_endpoints.ui_sso import TeamMappings

                # Stored value may be a raw dict or an already-built model.
                if isinstance(team_mappings_data, dict):
                    team_mappings = TeamMappings(**team_mappings_data)
                elif isinstance(team_mappings_data, TeamMappings):
                    team_mappings = team_mappings_data

            if team_mappings and team_mappings.team_ids_jwt_field:
                verbose_proxy_logger.debug(
                    f"Loaded team_mappings with team_ids_jwt_field: '{team_mappings.team_ids_jwt_field}'"
                )
    except Exception as e:
        # Deliberate best-effort: DB-backed mappings are optional.
        verbose_proxy_logger.debug(
            f"Could not load team_mappings from database: {e}. Continuing with config-based team mapping."
        )

    return team_mappings
|
||
|
|
|
||
|
|
|
||
|
|
async def _setup_role_mappings() -> Optional["RoleMappings"]:
    """
    Load SSO role mappings, with env-var config taking precedence over the DB.

    Resolution order:
      1. Try the singleton `sso_config` row in litellm_ssoconfig; any failure
         (no DB, missing record, bad data) is logged at debug level and ignored.
      2. If GENERIC_ROLE_MAPPINGS_ROLES is set, parse it with ast.literal_eval
         and, when it yields a dict, return env-derived mappings (overriding
         whatever the DB provided).

    Returns:
        A RoleMappings object, or None when no usable configuration exists.
    """
    role_mappings: Optional["RoleMappings"] = None
    try:
        from litellm.proxy.utils import get_prisma_client_or_throw

        prisma_client = get_prisma_client_or_throw(
            "Prisma client is None, connect a database to your proxy"
        )

        sso_db_record = await prisma_client.db.litellm_ssoconfig.find_unique(
            where={"id": "sso_config"}
        )

        if sso_db_record and sso_db_record.sso_settings:
            sso_settings_dict = dict(sso_db_record.sso_settings)
            role_mappings_data = sso_settings_dict.get("role_mappings")

            if role_mappings_data:
                from litellm.types.proxy.management_endpoints.ui_sso import RoleMappings

                # Stored value may be a raw dict or an already-built model.
                if isinstance(role_mappings_data, dict):
                    role_mappings = RoleMappings(**role_mappings_data)
                elif isinstance(role_mappings_data, RoleMappings):
                    role_mappings = role_mappings_data

            if role_mappings:
                verbose_proxy_logger.debug(
                    f"Loaded role_mappings for provider '{role_mappings.provider}'"
                )
    except Exception as e:
        # Deliberate best-effort: DB-backed mappings are optional.
        verbose_proxy_logger.debug(
            f"Could not load role_mappings from database: {e}. Continuing with existing role logic."
        )

    generic_role_mappings = os.getenv("GENERIC_ROLE_MAPPINGS_ROLES", None)
    generic_role_mappings_group_claim = os.getenv(
        "GENERIC_ROLE_MAPPINGS_GROUP_CLAIM", None
    )
    # NOTE: local name fixed from the original typo "mappoings".
    generic_role_mappings_default_role = os.getenv(
        "GENERIC_ROLE_MAPPINGS_DEFAULT_ROLE", None
    )
    if generic_role_mappings is not None:
        verbose_proxy_logger.debug(
            "Found role_mappings for generic provider in environment variables"
        )
        import ast

        try:
            # literal_eval safely parses the python-literal dict from the env var.
            generic_user_role_mappings_data: Dict[
                LitellmUserRoles, List[str]
            ] = ast.literal_eval(generic_role_mappings)
            if isinstance(generic_user_role_mappings_data, dict):
                from litellm.types.proxy.management_endpoints.ui_sso import RoleMappings

                role_mappings_data = {
                    "provider": "generic",
                    "group_claim": generic_role_mappings_group_claim,
                    "default_role": generic_role_mappings_default_role,
                    "roles": generic_user_role_mappings_data,
                }

                role_mappings = RoleMappings(**role_mappings_data)
                verbose_proxy_logger.debug(
                    f"Loaded role_mappings from environments for provider '{role_mappings.provider}'."
                )
                return role_mappings
        # BUG FIX: ast.literal_eval raises ValueError/SyntaxError on malformed
        # input (TypeError only for non-string arguments). Catching only
        # TypeError let a bad GENERIC_ROLE_MAPPINGS_ROLES value crash SSO
        # instead of falling back to the existing role logic.
        except (ValueError, SyntaxError, TypeError) as e:
            verbose_proxy_logger.warning(
                f"Error decoding role mappings from environment variables: {e}. Continuing with existing role logic."
            )
    return role_mappings
|
||
|
|
|
||
|
|
|
||
|
|
def _parse_generic_sso_headers() -> dict:
|
||
|
|
"""Parse comma-separated GENERIC_SSO_HEADERS env var into a dict."""
|
||
|
|
raw = os.getenv("GENERIC_SSO_HEADERS", None)
|
||
|
|
if raw is None:
|
||
|
|
return {}
|
||
|
|
result: Dict[str, str] = {}
|
||
|
|
for header in raw.split(","):
|
||
|
|
header = header.strip()
|
||
|
|
if header:
|
||
|
|
key, value = header.split("=")
|
||
|
|
result[key] = value
|
||
|
|
return result
|
||
|
|
|
||
|
|
|
||
|
|
def _handle_generic_sso_error(
    e: Exception,
    generic_authorization_endpoint: Optional[str],
    generic_token_endpoint: Optional[str],
    additional_headers: dict,
) -> None:
    """
    Handle errors from generic SSO verify_and_process. Always re-raises.

    When the failure looks PKCE-related and PKCE is not enabled, raises a
    ProxyException carrying setup instructions; otherwise logs the failure
    (with the headers that were passed) and re-raises the original exception.
    """
    error_message = str(e)

    # Surface a helpful PKCE misconfiguration hint only when:
    # 1. The error mentions PKCE/code verifier, AND
    # 2. PKCE is not currently configured (GENERIC_CLIENT_USE_PKCE != true)
    pkce_configured = os.getenv("GENERIC_CLIENT_USE_PKCE", "false").lower() == "true"
    mentions_pkce = "PKCE" in error_message or "code verifier" in error_message.lower()
    if mentions_pkce and not pkce_configured:
        endpoints = (generic_authorization_endpoint, generic_token_endpoint)
        is_okta = any(ep and "okta" in ep.lower() for ep in endpoints)
        provider_name = "Okta" if is_okta else "Your OAuth provider"

        detailed_message = (
            f"SSO authentication failed: {provider_name} requires PKCE (Proof Key for Code Exchange) "
            f"but it's not enabled in your LiteLLM configuration.\n\n"
            f"SOLUTION: Add this environment variable and restart your proxy:\n"
            f" GENERIC_CLIENT_USE_PKCE=true\n\n"
        )
        if is_okta:
            # Okta deployments get platform-specific setup instructions.
            detailed_message += (
                "For AWS ECS: Add the environment variable to your task definition.\n"
                "For Docker: Add -e GENERIC_CLIENT_USE_PKCE=true to your docker run command.\n"
                "For .env file: Add GENERIC_CLIENT_USE_PKCE=true to your .env file.\n\n"
            )
        detailed_message += f"Original error: {error_message}"

        raise ProxyException(
            message=detailed_message,
            type=ProxyErrorTypes.auth_error,
            param="GENERIC_CLIENT_USE_PKCE",
            code=status.HTTP_401_UNAUTHORIZED,
        )

    if isinstance(e, ProxyException):
        # Known proxy errors: log without a traceback.
        verbose_proxy_logger.error(
            "SSO authentication failed: %s. Passed in headers: %s",
            e,
            additional_headers,
        )
    else:
        # Unexpected errors: log with full traceback for debugging.
        verbose_proxy_logger.exception(
            "Error verifying and processing generic SSO: %s. Passed in headers: %s",
            e,
            additional_headers,
        )
    raise e
|
||
|
|
|
||
|
|
|
||
|
|
async def get_generic_sso_response(
    request: Request,
    jwt_handler: JWTHandler,
    sso_jwt_handler: Optional[
        JWTHandler
    ],  # sso specific jwt handler - used for restricted sso group access control
    generic_client_id: str,
    redirect_url: str,
) -> Tuple[Union[OpenID, dict], Optional[dict]]:  # return received response
    """Handle the OAuth callback for a generic (OIDC-style) SSO provider.

    Builds a ``fastapi_sso`` generic provider from environment-derived
    settings and exchanges the callback's authorization code for user info.
    Two paths exist:

    * PKCE path (a ``code_verifier`` was stashed during login): performs the
      token exchange manually via ``SSOAuthenticationHandler._pkce_token_exchange``.
    * Non-PKCE path: delegates to ``generic_sso.verify_and_process``.

    Returns:
        Tuple of (converted SSO result — ``{}`` if falsy, and the raw provider
        response for debugging; in the PKCE path bearer-token fields are
        stripped from the raw response before it is returned).

    Raises:
        ProxyException: on missing PKCE prerequisites, or whatever
        ``_handle_generic_sso_error`` raises for provider failures.
    """
    # make generic sso provider
    from fastapi_sso.sso.base import DiscoveryDocument
    from fastapi_sso.sso.generic import create_provider

    received_response: Optional[dict] = None

    # Setup environment variables
    (
        generic_client_secret,
        generic_scope,
        generic_authorization_endpoint,
        generic_token_endpoint,
        generic_userinfo_endpoint,
        generic_include_client_id,
    ) = _setup_generic_sso_env_vars(generic_client_id, redirect_url)

    # Static discovery document — endpoints come from env vars, not from the
    # provider's /.well-known discovery URL.
    discovery = DiscoveryDocument(
        authorization_endpoint=generic_authorization_endpoint,
        token_endpoint=generic_token_endpoint,
        userinfo_endpoint=generic_userinfo_endpoint,
    )

    role_mappings = await _setup_role_mappings()
    team_mappings = await _setup_team_mappings()

    def response_convertor(response, client):
        nonlocal received_response  # return for user debugging
        received_response = response
        return generic_response_convertor(
            response=response,
            jwt_handler=jwt_handler,
            sso_jwt_handler=sso_jwt_handler,
            role_mappings=role_mappings,
            team_mappings=team_mappings,
        )

    SSOProvider = create_provider(
        name="oidc",
        discovery_document=discovery,
        response_convertor=response_convertor,
    )
    generic_sso = SSOProvider(
        client_id=generic_client_id,
        client_secret=generic_client_secret,
        redirect_uri=redirect_url,
        allow_insecure_http=True,
        scope=generic_scope,
    )
    verbose_proxy_logger.debug("calling generic_sso.verify_and_process")
    additional_generic_sso_headers_dict = _parse_generic_sso_headers()

    code_verifier: Optional[str] = None  # assigned inside try; initialized for type tracking

    try:
        token_exchange_params = await SSOAuthenticationHandler.prepare_token_exchange_parameters(
            request=request,
            generic_include_client_id=generic_include_client_id,
        )

        # Extract code_verifier (and the cache key for deferred deletion) before calling fastapi-sso
        code_verifier = token_exchange_params.pop("code_verifier", None)
        pkce_cache_key = token_exchange_params.pop("_pkce_cache_key", None)

        # Get authorization code from query params (only used in the PKCE path below;
        # the non-PKCE path delegates to verify_and_process which handles OAuth error
        # callbacks — user-denied, CSRF mismatch — internally).
        authorization_code = request.query_params.get("code")

        if code_verifier:
            # PKCE path: all prerequisites must be present; each guard raises.
            if not authorization_code:
                raise ProxyException(
                    message="Missing authorization code in callback",
                    type=ProxyErrorTypes.auth_error,
                    param="code",
                    code=status.HTTP_400_BAD_REQUEST,
                )
            if not generic_client_id:
                raise ProxyException(
                    message="GENERIC_CLIENT_ID must be set when PKCE is enabled",
                    type=ProxyErrorTypes.auth_error,
                    param="GENERIC_CLIENT_ID",
                    code=status.HTTP_401_UNAUTHORIZED,
                )
            if not generic_token_endpoint:
                raise ProxyException(
                    message="GENERIC_TOKEN_ENDPOINT must be set when PKCE is enabled",
                    type=ProxyErrorTypes.auth_error,
                    param="GENERIC_TOKEN_ENDPOINT",
                    code=status.HTTP_401_UNAUTHORIZED,
                )
            # All guards above raise, so authorization_code is a non-empty str here.
            # Use an explicit type guard rather than assert (assert is a no-op with -O).
            if not isinstance(authorization_code, str):
                raise ProxyException(
                    message="Missing authorization code in callback",
                    type=ProxyErrorTypes.auth_error,
                    param="code",
                    code=status.HTTP_400_BAD_REQUEST,
                )
            combined_response = await SSOAuthenticationHandler._pkce_token_exchange(
                authorization_code=authorization_code,
                code_verifier=code_verifier,
                client_id=generic_client_id,
                client_secret=generic_client_secret,
                token_endpoint=generic_token_endpoint,
                userinfo_endpoint=generic_userinfo_endpoint,
                include_client_id=generic_include_client_id,
                redirect_url=redirect_url,
                additional_headers=additional_generic_sso_headers_dict,
            )
            # Pass the full response so custom response_convertor implementations
            # can access all fields (including id_token for claim extraction).
            result = response_convertor(combined_response, generic_sso)
            # Strip bearer credentials from combined_response before storing in
            # received_response. received_response may appear in restricted-group
            # error messages — bearer tokens (access_token, id_token, refresh_token)
            # must not be exposed to callers.
            # Assign directly rather than relying on nonlocal mutation so that Pyright
            # can track that received_response is non-None from this point on.
            received_response = {
                k: v for k, v in combined_response.items() if k not in _OAUTH_TOKEN_FIELDS
            }
            # In the PKCE path verify_and_process is skipped, so generic_sso.access_token
            # is never set. Read the token directly from the exchange response instead so
            # process_sso_jwt_access_token can extract JWT-embedded roles/teams.
            access_token_str: Optional[str] = combined_response.get("access_token")
        else:
            result = await generic_sso.verify_and_process(
                request,
                params=token_exchange_params,
                headers=additional_generic_sso_headers_dict,
            )
            access_token_str = generic_sso.access_token

        process_sso_jwt_access_token(
            access_token_str, sso_jwt_handler, result, role_mappings=role_mappings
        )
        # Delete the single-use PKCE verifier only after all downstream processing
        # (response_convertor and process_sso_jwt_access_token) has completed
        # successfully. Deleting earlier would consume the verifier on a transient
        # failure, forcing the user to restart the entire OAuth flow from scratch.
        if pkce_cache_key:
            await SSOAuthenticationHandler._delete_pkce_verifier(pkce_cache_key)

    except Exception as e:
        # NOTE(review): this helper is expected to always raise (otherwise
        # `result` below would be unbound) — confirm against its definition.
        _handle_generic_sso_error(
            e,
            generic_authorization_endpoint,
            generic_token_endpoint,
            additional_generic_sso_headers_dict,
        )
    verbose_proxy_logger.debug("generic result: %s", result)
    return result or {}, received_response
|
||
|
|
|
||
|
|
|
||
|
|
async def create_team_member_add_task(team_id, user_info):
    """Add ``user_info`` to team ``team_id`` as a regular ("user") member.

    Runs with proxy-admin privileges. Best-effort: any failure is logged at
    debug level and swallowed, so the return value is the ``team_member_add``
    response on success and ``None`` on failure.
    """
    try:
        add_request = TeamMemberAddRequest(
            member=Member(user_id=user_info.user_id, role="user"),
            team_id=team_id,
        )
        admin_auth = UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN)
        return await team_member_add(
            data=add_request,
            user_api_key_dict=admin_auth,
        )
    except Exception as e:
        verbose_proxy_logger.debug(
            f"[Non-Blocking] Error trying to add sso user to db: {e}"
        )
|
||
|
|
|
||
|
|
|
||
|
|
async def add_missing_team_member(
    user_info: Union[NewUserResponse, LiteLLM_UserTable], sso_teams: List[str]
):
    """
    Add the user to any SSO teams they are not yet a member of in the DB.

    - Get missing teams (diff b/w user_info.team_ids and sso_teams)
    - Add missing user to missing teams

    Non-blocking: failures are logged at debug level and swallowed.
    """
    # Handle None as empty list for new users
    user_teams = user_info.teams if user_info.teams is not None else []
    missing_teams = set(sso_teams) - set(user_teams)

    # One concurrent add-member task per missing team.
    # (Removed a dead `tasks = []` assignment that was immediately overwritten,
    # and the redundant list() materialization of the set.)
    tasks = [
        create_team_member_add_task(team_id, user_info)
        for team_id in missing_teams
    ]

    try:
        await asyncio.gather(*tasks)
    except Exception as e:
        verbose_proxy_logger.debug(
            f"[Non-Blocking] Error trying to add sso user to db: {e}"
        )
|
||
|
|
|
||
|
|
|
||
|
|
def get_disabled_non_admin_personal_key_creation():
    """Return True when "proxy_admin" is listed in the personal-key-generation
    allowed roles (i.e. personal key creation is admin-only and should be
    disabled for non-admins in the UI).

    Reads ``litellm.key_generation_settings``; returns False when no key
    generation settings are configured at all.
    """
    key_generation_settings = litellm.key_generation_settings
    if key_generation_settings is None:
        return False
    personal_key_generation = (
        key_generation_settings.get("personal_key_generation") or {}
    )
    allowed_user_roles = personal_key_generation.get("allowed_user_roles") or []
    # `in` already evaluates to a bool — the previous bool(...) wrapper was redundant.
    return "proxy_admin" in allowed_user_roles
|
||
|
|
|
||
|
|
|
||
|
|
async def get_existing_user_info_from_db(
    user_id: Optional[str],
    user_email: Optional[str],
    prisma_client: PrismaClient,
    user_api_key_cache: DualCache,
    proxy_logging_obj: ProxyLogging,
) -> Optional[LiteLLM_UserTable]:
    """Look up an existing user record by id and/or email.

    Thin wrapper around ``get_user_object`` with upsert disabled; any lookup
    error is logged at debug level and reported as ``None``.
    """
    try:
        return await get_user_object(
            user_id=user_id,
            user_email=user_email,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            user_id_upsert=False,
            parent_otel_span=None,
            proxy_logging_obj=proxy_logging_obj,
            sso_user_id=user_id,
        )
    except Exception as e:
        verbose_proxy_logger.debug(f"Error getting user object: {e}")
        return None
|
||
|
|
|
||
|
|
|
||
|
|
async def get_user_info_from_db(
    result: Union[CustomOpenID, OpenID, dict],
    prisma_client: PrismaClient,
    user_api_key_cache: DualCache,
    proxy_logging_obj: ProxyLogging,
    user_email: Optional[str],
    user_defined_values: Optional[SSOUserDefinedValues],
    alternate_user_id: Optional[str] = None,
) -> Optional[Union[LiteLLM_UserTable, NewUserResponse]]:
    """Resolve (and upsert) the LiteLLM DB user for an SSO login result.

    Steps:
      1. Collect candidate user ids — ``alternate_user_id`` first, then the
         ``id`` field from ``result`` (attribute or dict key, str only).
      2. Normalize the email from ``result`` (overriding the ``user_email``
         argument).
      3. Try each candidate id against the DB, stopping at the first hit.
      4. Upsert the SSO user via ``SSOAuthenticationHandler.upsert_sso_user``
         and add them to teams from the SSO response.

    Returns the resolved/created user, or ``None`` if any step raised
    (the error is logged; the failure is non-blocking for the SSO flow).
    """
    try:
        potential_user_ids = []
        if alternate_user_id is not None:
            potential_user_ids.append(alternate_user_id)
        # `result` can be an OpenID-style object or a plain dict — read the
        # id with getattr/.get accordingly.
        if not isinstance(result, dict):
            _id = getattr(result, "id", None)
            if _id is not None and isinstance(_id, str):
                potential_user_ids.append(_id)
        else:
            _id = result.get("id", None)
            if _id is not None and isinstance(_id, str):
                potential_user_ids.append(_id)

        # The email on the SSO result is authoritative; it replaces the
        # `user_email` argument after normalization.
        user_email = normalize_email(
            getattr(result, "email", None)
            if not isinstance(result, dict)
            else result.get("email", None)
        )

        user_info: Optional[Union[LiteLLM_UserTable, NewUserResponse]] = None

        # First candidate id that resolves to an existing DB user wins.
        for user_id in potential_user_ids:
            user_info = await get_existing_user_info_from_db(
                user_id=user_id,
                user_email=user_email,
                prisma_client=prisma_client,
                user_api_key_cache=user_api_key_cache,
                proxy_logging_obj=proxy_logging_obj,
            )
            if user_info is not None:
                break

        verbose_proxy_logger.debug(
            f"user_info: {user_info}; litellm.default_internal_user_params: {litellm.default_internal_user_params}"
        )

        # Upsert SSO User to LiteLLM DB
        user_info = await SSOAuthenticationHandler.upsert_sso_user(
            result=result,
            user_info=user_info,
            user_email=user_email,
            user_defined_values=user_defined_values,
            prisma_client=prisma_client,
        )

        await SSOAuthenticationHandler.add_user_to_teams_from_sso_response(
            result=result,
            user_info=user_info,
        )

        return user_info
    except Exception as e:
        verbose_proxy_logger.exception(
            f"[Non-Blocking] Error trying to add sso user to db: {e}"
        )

    return None
|
||
|
|
|
||
|
|
|
||
|
|
def _should_use_role_from_sso_response(sso_role: Optional[str]) -> bool:
    """Decide whether the SSO-provided role should be applied on upsert.

    True only when a role was provided and it is a valid LiteLLM user role;
    invalid roles are logged (debug) and ignored.
    """
    if sso_role is None:
        return False

    if is_valid_litellm_user_role(sso_role):
        return True

    verbose_proxy_logger.debug(
        f"SSO role '{sso_role}' is not a valid LiteLLM user role. "
        "Ignoring role from SSO response. See LitellmUserRoles enum for valid roles."
    )
    return False
|
||
|
|
|
||
|
|
|
||
|
|
def _build_sso_user_update_data(
    result: Optional[Union["CustomOpenID", OpenID, dict]],
    user_email: Optional[str],
    user_id: Optional[str],
) -> dict:
    """
    Build the update data dictionary for SSO user upsert.

    Args:
        result: The SSO response containing user information — either an
            OpenID-style object or a plain dict
        user_email: The user's email from SSO
        user_id: The user's ID for logging purposes

    Returns:
        dict: Update data containing user_email and optionally user_role if valid
    """
    update_data: dict = {"user_email": normalize_email(user_email)}

    # Get SSO role from result and include if valid.
    # FIX: `result` may be a plain dict (per the type annotation), and
    # getattr() on a dict always returns the default — which silently dropped
    # dict-provided roles. Mirror get_user_info_from_db and branch on the type.
    if isinstance(result, dict):
        sso_role = result.get("user_role", None)
    else:
        sso_role = getattr(result, "user_role", None)

    if sso_role is not None:
        # Convert enum to string if needed
        sso_role_str = (
            sso_role.value if isinstance(sso_role, LitellmUserRoles) else sso_role
        )

        # Only include if it's a valid LiteLLM role
        if _should_use_role_from_sso_response(sso_role_str):
            update_data["user_role"] = sso_role_str
            verbose_proxy_logger.info(
                f"Updating user {user_id} role from SSO: {sso_role_str}"
            )

    return update_data
|
||
|
|
|
||
|
|
|
||
|
|
def apply_user_info_values_to_sso_user_defined_values(
    user_info: Optional[Union[LiteLLM_UserTable, NewUserResponse]],
    user_defined_values: Optional[SSOUserDefinedValues],
) -> Optional[SSOUserDefinedValues]:
    """Merge DB-backed user info into the SSO-derived user values.

    - Copies the DB user_id over the SSO-derived one when available.
    - Role precedence: a valid SSO-provided role wins; otherwise the DB role
      is used, falling back to INTERNAL_USER_VIEW_ONLY when neither exists.
    - Preserves the user's existing model list from the DB.

    Returns the mutated ``user_defined_values`` (or ``None`` when none given).
    """
    if user_defined_values is None:
        return None

    if user_info is not None and user_info.user_id is not None:
        user_defined_values["user_id"] = user_info.user_id

    # SSO role takes precedence - only use DB role if SSO didn't provide one.
    # This ensures SSO is the authoritative source for user roles.
    sso_role = user_defined_values.get("user_role")
    db_role = user_info.user_role if user_info else None

    if _should_use_role_from_sso_response(sso_role):
        # SSO provided a valid role — keep it, just log the decision.
        verbose_proxy_logger.info(
            f"Using SSO role: {sso_role} (DB role was: {db_role})"
        )
    elif db_role is None:
        # Neither SSO nor the DB provided a role — apply the conservative default.
        user_defined_values[
            "user_role"
        ] = LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value
        verbose_proxy_logger.debug(
            "No SSO or DB role found, using default: INTERNAL_USER_VIEW_ONLY"
        )
    else:
        user_defined_values["user_role"] = db_role
        verbose_proxy_logger.debug(f"Using DB role: {db_role}")

    # Preserve the user's existing models from the database
    if user_info is not None and hasattr(user_info, "models") and user_info.models:
        user_defined_values["models"] = user_info.models

    return user_defined_values
|
||
|
|
|
||
|
|
|
||
|
|
async def check_and_update_if_proxy_admin_id(
    user_role: str, user_id: str, prisma_client: Optional[PrismaClient]
):
    """Promote ``user_id`` to proxy admin when it matches ``PROXY_ADMIN_ID``.

    - Check if user role in DB is admin
    - If not, update user role in DB to admin role

    Returns the (possibly upgraded) role string.
    """
    designated_admin = os.getenv("PROXY_ADMIN_ID")

    # Not the designated admin (or none configured) — leave the role alone.
    if designated_admin is None or designated_admin != user_id:
        return user_role

    # Already a proxy admin — nothing to do.
    if user_role == LitellmUserRoles.PROXY_ADMIN.value:
        return user_role

    if prisma_client:
        await prisma_client.db.litellm_usertable.update(
            where={"user_id": user_id},
            data={"user_role": LitellmUserRoles.PROXY_ADMIN.value},
        )

    return LitellmUserRoles.PROXY_ADMIN.value
|
||
|
|
|
||
|
|
|
||
|
|
@router.get("/sso/callback", tags=["experimental"], include_in_schema=False)
async def auth_callback(request: Request, state: Optional[str] = None):  # noqa: PLR0915
    """Verify login.

    Shared SSO callback for Google, Microsoft, and generic OIDC providers
    (selected by which *_CLIENT_ID env var is set, in that priority order).
    CLI logins (state prefixed with LITELLM_CLI_SESSION_TOKEN_PREFIX) are
    routed to ``cli_sso_callback``; everything else returns the UI redirect
    built by ``SSOAuthenticationHandler.get_redirect_response_from_openid``.
    """
    verbose_proxy_logger.info(f"Starting SSO callback with state: {state}")

    # Check if this is a CLI login (state starts with our CLI prefix)
    from litellm.constants import LITELLM_CLI_SESSION_TOKEN_PREFIX
    from litellm.proxy._types import LiteLLM_JWTAuth
    from litellm.proxy.auth.handle_jwt import JWTHandler
    from litellm.proxy.proxy_server import (
        general_settings,
        jwt_handler,
        master_key,
        prisma_client,
        user_api_key_cache,
    )

    if prisma_client is None:
        raise HTTPException(
            status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
        )

    # A dedicated JWT handler for restricted SSO group access control — only
    # built when ui_access_mode is configured as a dict.
    sso_jwt_handler: Optional[JWTHandler] = None
    ui_access_mode = general_settings.get("ui_access_mode", None)
    if ui_access_mode is not None and isinstance(ui_access_mode, dict):
        sso_jwt_handler = JWTHandler()
        sso_jwt_handler.update_environment(
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            litellm_jwtauth=LiteLLM_JWTAuth(
                team_ids_jwt_field=general_settings.get("ui_access_mode", {}).get(
                    "sso_group_jwt_field", None
                ),
            ),
            leeway=0,
        )

    microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
    google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
    generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
    received_response: Optional[dict] = None
    # get url from request
    if master_key is None:
        raise ProxyException(
            message="Master Key not set for Proxy. Please set Master Key to use Admin UI. Set `LITELLM_MASTER_KEY` in .env or set general_settings:master_key in config.yaml. https://docs.litellm.ai/docs/proxy/virtual_keys. If set, use `--detailed_debug` to debug issue.",
            type=ProxyErrorTypes.auth_error,
            param="master_key",
            code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )
    redirect_url = SSOAuthenticationHandler.get_redirect_url_for_sso(
        request=request, sso_callback_route="sso/callback"
    )

    verbose_proxy_logger.info(f"Redirecting to {redirect_url}")
    result = None
    # Provider selection mirrors the login route: Google > Microsoft > generic.
    if google_client_id is not None:
        result = await GoogleSSOHandler.get_google_callback_response(
            request=request,
            google_client_id=google_client_id,
            redirect_url=redirect_url,
        )
    elif microsoft_client_id is not None:
        result = await MicrosoftSSOHandler.get_microsoft_callback_response(
            request=request,
            microsoft_client_id=microsoft_client_id,
            redirect_url=redirect_url,
        )

    elif generic_client_id is not None:
        result, received_response = await get_generic_sso_response(
            request=request,
            jwt_handler=jwt_handler,
            generic_client_id=generic_client_id,
            redirect_url=redirect_url,
            sso_jwt_handler=sso_jwt_handler,
        )

    if result is None:
        raise HTTPException(
            status_code=401,
            detail="Result not returned by SSO provider.",
        )

    if state and state.startswith(f"{LITELLM_CLI_SESSION_TOKEN_PREFIX}:"):
        # Extract the key ID and existing_key from the state
        # State format: {PREFIX}:{key}:{existing_key} or {PREFIX}:{key}
        state_parts = state.split(":", 2)  # Split into max 3 parts
        key_id = state_parts[1] if len(state_parts) > 1 else None
        existing_key = state_parts[2] if len(state_parts) > 2 else None

        verbose_proxy_logger.info(
            f"CLI SSO callback detected for key: {key_id}, existing_key: {existing_key}"
        )
        return await cli_sso_callback(
            request=request, key=key_id, existing_key=existing_key, result=result
        )

    return await SSOAuthenticationHandler.get_redirect_response_from_openid(
        result=result,
        request=request,
        received_response=received_response,
        generic_client_id=generic_client_id,
        ui_access_mode=ui_access_mode,
    )
|
||
|
|
|
||
|
|
|
||
|
|
async def cli_sso_callback(
    request: Request,
    key: Optional[str] = None,
    existing_key: Optional[str] = None,
    result: Optional[Union[OpenID, dict]] = None,
):
    """CLI SSO callback - stores session info for JWT generation on polling.

    Validates the CLI-provided key id, resolves/upserts the SSO user in the
    DB, caches a single-use session payload (user, role, models, teams, and
    optional team aliases) under the key for 10 minutes, and renders an HTML
    success page. The CLI then polls ``/sso/cli/poll/{key_id}`` to exchange
    the session for a JWT.
    """
    verbose_proxy_logger.info(
        f"CLI SSO callback for key: {key}, existing_key: {existing_key}"
    )

    from litellm.proxy.proxy_server import (
        prisma_client,
        proxy_logging_obj,
        user_api_key_cache,
    )

    if not key or not key.startswith("sk-"):
        raise HTTPException(
            status_code=400,
            detail="Invalid key parameter. Must be a valid key ID starting with 'sk-'",
        )

    if prisma_client is None:
        raise HTTPException(
            status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
        )

    if result is None:
        raise HTTPException(
            status_code=500,
            detail="SSO authentication failed - no result returned from provider",
        )

    # After None check, cast to non-None type for type checker
    result_non_none: Union[OpenID, dict] = cast(Union[OpenID, dict], result)

    parsed_openid_result = SSOAuthenticationHandler._get_user_email_and_id_from_result(
        result=result_non_none
    )
    verbose_proxy_logger.debug(f"parsed_openid_result: {parsed_openid_result}")

    # NOTE(review): HTTPExceptions raised inside this try (e.g. the 500 for a
    # missing user below) are caught by the broad handler at the bottom and
    # re-wrapped with a different detail message — confirm this is intended.
    try:
        # Get full user info from DB
        user_info = await get_user_info_from_db(
            result=result_non_none,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            proxy_logging_obj=proxy_logging_obj,
            user_email=parsed_openid_result.get("user_email"),
            user_defined_values=None,
            alternate_user_id=parsed_openid_result.get("user_id"),
        )

        if user_info is None:
            raise HTTPException(
                status_code=500, detail="Failed to retrieve user information from SSO"
            )

        # Store session info in cache (10 min TTL)
        from litellm.constants import CLI_SSO_SESSION_CACHE_KEY_PREFIX

        # Get all teams from user_info - CLI will let user select which one
        teams: List[str] = []
        if hasattr(user_info, "teams") and user_info.teams:
            teams = user_info.teams if isinstance(user_info.teams, list) else []

        # Also fetch team aliases for a better CLI UX. We keep the original
        # "teams" list of IDs for backwards compatibility and add an
        # optional "team_details" field containing objects with both
        # team_id and team_alias.
        team_details: List[Dict[str, Any]] = []
        try:
            if teams:
                prisma_teams = await prisma_client.db.litellm_teamtable.find_many(
                    where={"team_id": {"in": teams}}
                )
                for team_row in prisma_teams:
                    team_dict = team_row.model_dump()
                    team_details.append(
                        {
                            "team_id": team_dict.get("team_id"),
                            "team_alias": team_dict.get("team_alias"),
                        }
                    )
        except Exception as e:
            # If anything goes wrong here, fall back gracefully without
            # impacting the SSO flow.
            verbose_proxy_logger.error(
                f"Error fetching team details for CLI SSO session: {e}"
            )

        session_data = {
            "user_id": user_info.user_id,
            "user_role": user_info.user_role,
            "models": user_info.models if hasattr(user_info, "models") else [],
            "user_email": parsed_openid_result.get("user_email"),
            "teams": teams,
            # Optional rich metadata for clients that want nicer display
            "team_details": team_details,
        }

        cache_key = f"{CLI_SSO_SESSION_CACHE_KEY_PREFIX}:{key}"
        user_api_key_cache.set_cache(key=cache_key, value=session_data, ttl=600)

        verbose_proxy_logger.info(
            f"Stored CLI SSO session for user: {user_info.user_id}, teams: {teams}, num_teams: {len(teams)}"
        )

        # Return success page
        from fastapi.responses import HTMLResponse

        from litellm.proxy.common_utils.html_forms.cli_sso_success import (
            render_cli_sso_success_page,
        )

        html_content = render_cli_sso_success_page()
        return HTMLResponse(content=html_content, status_code=200)

    except Exception as e:
        verbose_proxy_logger.error(f"Error with CLI SSO callback: {e}")
        raise HTTPException(
            status_code=500, detail=f"Failed to process CLI SSO: {str(e)}"
        )
|
||
|
|
|
||
|
|
|
||
|
|
@router.get("/sso/cli/poll/{key_id}", tags=["experimental"], include_in_schema=False)
async def cli_poll_key(key_id: str, team_id: Optional[str] = None):
    """
    CLI polling endpoint - retrieves session from cache and generates JWT.

    Flow:
    1. First poll (no team_id): Returns teams list without generating JWT
    2. Second poll (with team_id): Generates JWT with selected team and deletes session

    Args:
        key_id: The session key ID
        team_id: Optional team ID to assign to the JWT. If provided, must be one of user's teams.

    Raises:
        HTTPException: 400 for a malformed key id, 403 when team_id is not one
            of the user's teams, 500 for unexpected errors.
    """
    from litellm.constants import CLI_SSO_SESSION_CACHE_KEY_PREFIX
    from litellm.proxy.auth.auth_checks import ExperimentalUIJWTToken
    from litellm.proxy.proxy_server import user_api_key_cache

    if not key_id.startswith("sk-"):
        raise HTTPException(status_code=400, detail="Invalid key ID format")

    try:
        # Look up session in cache
        cache_key = f"{CLI_SSO_SESSION_CACHE_KEY_PREFIX}:{key_id}"
        session_data = user_api_key_cache.get_cache(key=cache_key)

        if session_data:
            user_teams = session_data.get("teams", [])
            user_team_details = session_data.get("team_details")
            user_id = session_data["user_id"]

            verbose_proxy_logger.info(
                f"CLI poll: user={user_id}, team_id={team_id}, user_teams={user_teams}, num_teams={len(user_teams)}"
            )

            # If no team_id provided and user has teams, return teams list for selection
            # Don't generate JWT yet - let CLI select a team first. For newer
            # clients we return rich team details (id + alias); older clients
            # can continue to rely on the simple "teams" list.
            if team_id is None and len(user_teams) > 1:
                verbose_proxy_logger.info(
                    f"Returning teams list for user {user_id} to select from: {user_teams}"
                )
                # Best-effort construction of team_details if it wasn't
                # already cached for some reason.
                team_details_response: Optional[List[Dict[str, Any]]] = None
                if isinstance(user_team_details, list) and user_team_details:
                    team_details_response = user_team_details
                elif user_teams:
                    team_details_response = [
                        {"team_id": t, "team_alias": None} for t in user_teams
                    ]
                return {
                    "status": "ready",
                    "user_id": user_id,
                    "teams": user_teams,
                    "team_details": team_details_response,
                    "requires_team_selection": True,
                }

            # Validate team_id if provided
            if team_id is not None:
                if team_id not in user_teams:
                    raise HTTPException(
                        status_code=403,
                        detail=f"User does not belong to team: {team_id}. Available teams: {user_teams}",
                    )
            else:
                # If no team_id provided and user has 0 or 1 team, use first team (or None)
                team_id = user_teams[0] if len(user_teams) > 0 else None

            # Create user object for JWT generation
            user_info = LiteLLM_UserTable(
                user_id=user_id,
                user_role=session_data["user_role"],
                models=session_data.get("models", []),
                max_budget=litellm.max_ui_session_budget,
            )

            # Generate CLI JWT on-demand (expiration configurable via LITELLM_CLI_JWT_EXPIRATION_HOURS)
            # Pass selected team_id to ensure JWT has correct team
            jwt_token = ExperimentalUIJWTToken.get_cli_jwt_auth_token(
                user_info=user_info, team_id=team_id
            )

            # Delete cache entry (single-use)
            user_api_key_cache.delete_cache(key=cache_key)

            verbose_proxy_logger.info(
                f"CLI JWT generated for user: {user_id}, team: {team_id}"
            )
            return {
                "status": "ready",
                "key": jwt_token,
                "user_id": user_id,
                "team_id": team_id,
                "teams": user_teams,
                # Echo back any team details we have so clients can
                # present nicer information if needed.
                "team_details": user_team_details,
            }
        else:
            return {"status": "pending"}

    except HTTPException:
        # FIX: previously the broad handler below also caught the 403 raised
        # for an invalid team selection and re-raised it as a generic 500.
        # Re-raise HTTPExceptions unchanged so clients see the real status.
        raise
    except Exception as e:
        verbose_proxy_logger.error(f"Error polling for CLI JWT: {e}")
        raise HTTPException(
            status_code=500, detail=f"Error checking session status: {str(e)}"
        )
|
||
|
|
|
||
|
|
|
||
|
|
async def insert_sso_user(
    result_openid: Optional[Union[OpenID, dict]],
    user_defined_values: Optional[SSOUserDefinedValues] = None,
) -> NewUserResponse:
    """
    Helper function to create a New User in LiteLLM DB after a successful SSO login

    Args:
        result_openid (Optional[Union[OpenID, dict]]): User information in OpenID format if the login was successful.
        user_defined_values (Optional[SSOUserDefinedValues], optional): LiteLLM SSOValues / fields that were read

    Raises:
        ValueError: if ``result_openid`` or ``user_defined_values`` is None.

    Returns:
        NewUserResponse: the newly created user record.
    """
    verbose_proxy_logger.debug(
        f"Inserting SSO user into DB. User values: {user_defined_values}"
    )
    if result_openid is None:
        raise ValueError("result_openid is None")
    if isinstance(result_openid, dict):
        # normalize dict payloads from the provider into an OpenID object
        result_openid = OpenID(**result_openid)

    if user_defined_values is None:
        raise ValueError("user_defined_values is None")

    # Apply default_internal_user_params
    if litellm.default_internal_user_params:
        # Preserve the SSO-extracted role if it's a valid LiteLLM role,
        # regardless of how it was determined (role_mappings, Microsoft app_roles,
        # GENERIC_USER_ROLE_ATTRIBUTE, custom SSO handler, etc.)
        sso_role = user_defined_values.get("user_role")
        if _should_use_role_from_sso_response(sso_role):
            # Preserve the SSO-extracted role, but apply other defaults
            preserved_role = sso_role
            user_defined_values.update(litellm.default_internal_user_params)  # type: ignore
            user_defined_values["user_role"] = preserved_role  # Restore preserved role
            verbose_proxy_logger.debug(
                f"Preserved SSO-extracted role '{preserved_role}'"
            )
        else:
            # SSO didn't provide a valid role, apply all defaults including role
            user_defined_values.update(litellm.default_internal_user_params)  # type: ignore

    # Set budget for internal users (only fill in values not already provided)
    if user_defined_values.get("user_role") == LitellmUserRoles.INTERNAL_USER.value:
        if user_defined_values.get("max_budget") is None:
            user_defined_values["max_budget"] = litellm.max_internal_user_budget
        if user_defined_values.get("budget_duration") is None:
            user_defined_values[
                "budget_duration"
            ] = litellm.internal_user_budget_duration

    # Fall back to the most restrictive role when none was resolved
    if user_defined_values["user_role"] is None:
        user_defined_values["user_role"] = LitellmUserRoles.INTERNAL_USER_VIEW_ONLY

    new_user_request = NewUserRequest(
        user_id=user_defined_values["user_id"],
        user_email=normalize_email(user_defined_values["user_email"]),
        user_role=user_defined_values["user_role"],  # type: ignore
        max_budget=user_defined_values["max_budget"],
        budget_duration=user_defined_values["budget_duration"],
        sso_user_id=user_defined_values["user_id"],
        auto_create_key=False,
    )

    # Record which SSO provider authenticated the user, when available
    if result_openid and hasattr(result_openid, "provider"):
        new_user_request.metadata = {
            "auth_provider": getattr(result_openid, "provider")
        }

    response = await new_user(
        data=new_user_request,
        user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
    )

    return response
|
||
|
|
|
||
|
|
|
||
|
|
@router.get(
    "/sso/get/ui_settings",
    tags=["experimental"],
    include_in_schema=False,
    dependencies=[Depends(user_api_key_auth)],
)
async def get_ui_settings(request: Request):
    """Return UI configuration flags (base URLs, SSO status, spend-log stats) for the Admin UI."""
    from litellm.proxy.proxy_server import general_settings, proxy_state

    spend_logs_row_count = proxy_state.get_proxy_state_variable("spend_logs_row_count")

    # Default-team toggle: config value, overridable via env var.
    default_team_disabled = general_settings.get("default_team_disabled", False)
    env_override = os.environ.get("PROXY_DEFAULT_TEAM_DISABLED")
    if env_override is not None and env_override.lower() == "true":
        default_team_disabled = True

    return {
        "PROXY_BASE_URL": os.getenv("PROXY_BASE_URL", None),
        "PROXY_LOGOUT_URL": os.getenv("PROXY_LOGOUT_URL", None),
        "LITELLM_UI_API_DOC_BASE_URL": os.getenv("LITELLM_UI_API_DOC_BASE_URL", None),
        "DEFAULT_TEAM_DISABLED": default_team_disabled,
        "SSO_ENABLED": _has_user_setup_sso(),
        "NUM_SPEND_LOGS_ROWS": spend_logs_row_count,
        # Expensive queries are disabled once the spend-log table grows too large.
        "DISABLE_EXPENSIVE_DB_QUERIES": spend_logs_row_count
        > MAX_SPENDLOG_ROWS_TO_QUERY,
    }
|
||
|
|
|
||
|
|
|
||
|
|
@router.get(
    "/sso/readiness",
    tags=["experimental"],
    dependencies=[Depends(user_api_key_auth)],
)
async def sso_readiness():
    """
    Health endpoint for checking SSO readiness.

    Checks if the configured SSO provider has all required environment variables set in memory.
    """
    # Provider detection order mirrors the login flow: google > microsoft > generic.
    if os.getenv("GOOGLE_CLIENT_ID") is not None:
        configured_provider = "google"
        required_env_vars = ["GOOGLE_CLIENT_SECRET"]
    elif os.getenv("MICROSOFT_CLIENT_ID") is not None:
        configured_provider = "microsoft"
        required_env_vars = ["MICROSOFT_CLIENT_SECRET", "MICROSOFT_TENANT"]
    elif os.getenv("GENERIC_CLIENT_ID") is not None:
        configured_provider = "generic"
        required_env_vars = [
            "GENERIC_CLIENT_SECRET",
            "GENERIC_AUTHORIZATION_ENDPOINT",
            "GENERIC_TOKEN_ENDPOINT",
            "GENERIC_USERINFO_ENDPOINT",
        ]
    else:
        # No SSO configured is still healthy — SSO is optional.
        return {
            "status": "healthy",
            "sso_configured": False,
            "message": "No SSO provider configured",
        }

    # Check required environment variables for the configured provider.
    missing_vars = [var for var in required_env_vars if os.getenv(var) is None]

    if not missing_vars:
        return {
            "status": "healthy",
            "sso_configured": True,
            "provider": configured_provider,
            "message": f"{configured_provider.capitalize()} SSO is properly configured",
        }

    # Some variables are missing -> report unhealthy with the exact gaps.
    raise HTTPException(
        status_code=503,
        detail={
            "status": "unhealthy",
            "sso_configured": True,
            "provider": configured_provider,
            "missing_environment_variables": missing_vars,
            "message": f"{configured_provider.capitalize()} SSO is configured but missing required environment variables: {', '.join(missing_vars)}",
        },
    )
|
||
|
|
|
||
|
|
|
||
|
|
class SSOAuthenticationHandler:
|
||
|
|
"""
|
||
|
|
Handler for SSO Authentication across all SSO providers
|
||
|
|
"""
|
||
|
|
|
||
|
|
    @staticmethod
    async def get_sso_login_redirect(
        redirect_url: str,
        google_client_id: Optional[str] = None,
        microsoft_client_id: Optional[str] = None,
        generic_client_id: Optional[str] = None,
        state: Optional[str] = None,
    ) -> Optional[RedirectResponse]:
        """
        Step 1. Call Get Login Redirect for the SSO provider. Send the redirect response to `redirect_url`

        Args:
            redirect_url (str): The URL to redirect the user to after login
            google_client_id (Optional[str], optional): The Google Client ID. Defaults to None.
            microsoft_client_id (Optional[str], optional): The Microsoft Client ID. Defaults to None.
            generic_client_id (Optional[str], optional): The Generic Client ID. Defaults to None.
            state (Optional[str], optional): OAuth state to round-trip through the provider.

        Raises:
            ProxyException: when a required provider env var (client secret / endpoints) is unset.
            ValueError: when none of the three client ids is provided.

        Returns:
            RedirectResponse: The redirect response from the SSO provider.
        """
        # Google SSO Auth
        if google_client_id is not None:
            from fastapi_sso.sso.google import GoogleSSO

            google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None)
            if google_client_secret is None:
                raise ProxyException(
                    message="GOOGLE_CLIENT_SECRET not set. Set it in .env file",
                    type=ProxyErrorTypes.auth_error,
                    param="GOOGLE_CLIENT_SECRET",
                    code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                )
            google_sso = GoogleSSO(
                client_id=google_client_id,
                client_secret=google_client_secret,
                redirect_uri=redirect_url,
            )
            verbose_proxy_logger.info(
                f"In /google-login/key/generate, \nGOOGLE_REDIRECT_URI: {redirect_url}\nGOOGLE_CLIENT_ID: {google_client_id}"
            )
            with google_sso:
                return await google_sso.get_login_redirect(state=state)
        # Microsoft SSO Auth
        elif microsoft_client_id is not None:
            microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None)
            microsoft_tenant = os.getenv("MICROSOFT_TENANT", None)
            if microsoft_client_secret is None:
                raise ProxyException(
                    message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file",
                    type=ProxyErrorTypes.auth_error,
                    param="MICROSOFT_CLIENT_SECRET",
                    code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                )
            # NOTE(review): microsoft_tenant may be None here — presumably
            # CustomMicrosoftSSO tolerates tenant=None; confirm against its signature.
            microsoft_sso = CustomMicrosoftSSO(
                client_id=microsoft_client_id,
                client_secret=microsoft_client_secret,
                tenant=microsoft_tenant,
                redirect_uri=redirect_url,
                allow_insecure_http=True,
            )
            with microsoft_sso:
                return await microsoft_sso.get_login_redirect(state=state)
        elif generic_client_id is not None:
            from fastapi_sso.sso.base import DiscoveryDocument
            from fastapi_sso.sso.generic import create_provider

            generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
            # Scopes are space-separated per the OAuth2 spec.
            generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(
                " "
            )
            generic_authorization_endpoint = os.getenv(
                "GENERIC_AUTHORIZATION_ENDPOINT", None
            )
            generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
            generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
            if generic_client_secret is None:
                raise ProxyException(
                    message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
                    type=ProxyErrorTypes.auth_error,
                    param="GENERIC_CLIENT_SECRET",
                    code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                )
            if generic_authorization_endpoint is None:
                raise ProxyException(
                    message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file",
                    type=ProxyErrorTypes.auth_error,
                    param="GENERIC_AUTHORIZATION_ENDPOINT",
                    code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                )
            if generic_token_endpoint is None:
                raise ProxyException(
                    message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file",
                    type=ProxyErrorTypes.auth_error,
                    param="GENERIC_TOKEN_ENDPOINT",
                    code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                )
            if generic_userinfo_endpoint is None:
                raise ProxyException(
                    message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file",
                    type=ProxyErrorTypes.auth_error,
                    param="GENERIC_USERINFO_ENDPOINT",
                    code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                )
            verbose_proxy_logger.debug(
                f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}"
            )
            verbose_proxy_logger.debug(
                f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
            )
            # Build the OIDC provider from the three endpoints (no metadata discovery).
            discovery = DiscoveryDocument(
                authorization_endpoint=generic_authorization_endpoint,
                token_endpoint=generic_token_endpoint,
                userinfo_endpoint=generic_userinfo_endpoint,
            )
            SSOProvider = create_provider(name="oidc", discovery_document=discovery)
            generic_sso = SSOProvider(
                client_id=generic_client_id,
                client_secret=generic_client_secret,
                redirect_uri=redirect_url,
                allow_insecure_http=True,
                scope=generic_scope,
            )
            return await SSOAuthenticationHandler.get_generic_sso_redirect_response(
                generic_sso=generic_sso,
                state=state,
                generic_authorization_endpoint=generic_authorization_endpoint,
            )
        raise ValueError(
            "Unknown SSO provider. Please setup SSO with client IDs https://docs.litellm.ai/docs/proxy/admin_ui_sso"
        )
|
||
|
|
|
||
|
|
    @staticmethod
    async def get_generic_sso_redirect_response(
        generic_sso: Any,
        state: Optional[str] = None,
        generic_authorization_endpoint: Optional[str] = None,
    ) -> Optional[RedirectResponse]:
        """
        Get the redirect response for Generic SSO.

        Builds redirect params (state + optional PKCE), asks fastapi-sso for the
        provider redirect, stores the PKCE code_verifier in cache for the callback,
        and appends the PKCE challenge params to the authorization URL.
        """
        from urllib.parse import parse_qs, urlencode, urlparse, urlunparse

        from litellm.proxy.proxy_server import redis_usage_cache, user_api_key_cache

        with generic_sso:
            # TODO: state should be a random string and added to the user session with cookie
            # or a cryptographicly signed state that we can verify stateless
            # For simplification we are using a static state, this is not perfect but some
            # SSO providers do not allow stateless verification
            (
                redirect_params,
                code_verifier,
            ) = SSOAuthenticationHandler._get_generic_sso_redirect_params(
                state=state,
                generic_authorization_endpoint=generic_authorization_endpoint,
            )

            # Separate PKCE params from state params (fastapi-sso doesn't accept code_challenge)
            pkce_params = {}
            state_only_params = {}
            for key, value in redirect_params.items():
                if key in ("code_challenge", "code_challenge_method"):
                    pkce_params[key] = value
                else:
                    state_only_params[key] = value

            # Get the redirect response from fastapi-sso with only state param
            redirect_response = await generic_sso.get_login_redirect(**state_only_params)  # type: ignore

            # If PKCE is enabled, add PKCE parameters to the redirect URL
            if code_verifier and "state" in redirect_params:
                # Store code_verifier in cache (10 min TTL). Wrap in dict for proper
                # JSON serialization in Redis. Use Redis when available so callbacks
                # landing on another pod can retrieve it (multi-pod SSO).
                cache_key = f"pkce_verifier:{redirect_params['state']}"
                if redis_usage_cache is not None:
                    await redis_usage_cache.async_set_cache(
                        key=cache_key,
                        value={"code_verifier": code_verifier},
                        ttl=600,
                    )
                else:
                    # Fall back to the in-memory cache (single-pod deployments).
                    await user_api_key_cache.async_set_cache(
                        key=cache_key,
                        value={"code_verifier": code_verifier},
                        ttl=600,
                    )
                verbose_proxy_logger.debug(
                    "PKCE code_verifier stored in cache (TTL: 600s)"
                )

            # Add PKCE parameters to the authorization URL
            if pkce_params:
                parsed_url = urlparse(str(redirect_response.headers["location"]))
                query_params = parse_qs(parsed_url.query)

                # Add PKCE parameters
                for key, value in pkce_params.items():
                    query_params[key] = [value]

                # Reconstruct the URL with PKCE parameters
                new_query = urlencode(query_params, doseq=True)
                new_url = urlunparse(
                    (
                        parsed_url.scheme,
                        parsed_url.netloc,
                        parsed_url.path,
                        parsed_url.params,
                        new_query,
                        parsed_url.fragment,
                    )
                )

                # Update the redirect response
                redirect_response.headers["location"] = new_url
            return redirect_response
|
||
|
|
|
||
|
|
@staticmethod
|
||
|
|
def _get_generic_sso_redirect_params(
|
||
|
|
state: Optional[str] = None,
|
||
|
|
generic_authorization_endpoint: Optional[str] = None,
|
||
|
|
) -> Tuple[dict, Optional[str]]:
|
||
|
|
"""
|
||
|
|
Get redirect parameters for Generic SSO with proper state priority handling.
|
||
|
|
Optionally generates PKCE parameters if GENERIC_CLIENT_USE_PKCE is enabled.
|
||
|
|
|
||
|
|
Priority order:
|
||
|
|
1. CLI state (if provided)
|
||
|
|
2. GENERIC_CLIENT_STATE environment variable
|
||
|
|
3. Generated UUID (required by Okta and most OAuth providers)
|
||
|
|
|
||
|
|
|
||
|
|
Args:
|
||
|
|
state: Optional state parameter (e.g., CLI state)
|
||
|
|
generic_authorization_endpoint: Authorization endpoint URL
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
Tuple[dict, Optional[str]]:
|
||
|
|
- Redirect parameters for SSO login (may include PKCE params)
|
||
|
|
- code_verifier (if PKCE is enabled, None otherwise)
|
||
|
|
"""
|
||
|
|
redirect_params = {}
|
||
|
|
code_verifier: Optional[str] = None
|
||
|
|
|
||
|
|
if state:
|
||
|
|
# CLI state takes priority
|
||
|
|
# the litellm proxy cli sends the "state" parameter to the proxy server for auth. We should maintain the state parameter for the cli if it is provided
|
||
|
|
redirect_params["state"] = state
|
||
|
|
else:
|
||
|
|
generic_client_state = os.getenv("GENERIC_CLIENT_STATE", None)
|
||
|
|
if generic_client_state:
|
||
|
|
redirect_params["state"] = generic_client_state
|
||
|
|
else:
|
||
|
|
redirect_params["state"] = uuid.uuid4().hex
|
||
|
|
|
||
|
|
# Handle PKCE (Proof Key for Code Exchange) if enabled
|
||
|
|
# Set GENERIC_CLIENT_USE_PKCE=true to enable PKCE for enhanced OAuth security
|
||
|
|
use_pkce = os.getenv("GENERIC_CLIENT_USE_PKCE", "false").lower() == "true"
|
||
|
|
|
||
|
|
if use_pkce:
|
||
|
|
(
|
||
|
|
code_verifier,
|
||
|
|
code_challenge,
|
||
|
|
) = SSOAuthenticationHandler.generate_pkce_params()
|
||
|
|
redirect_params["code_challenge"] = code_challenge
|
||
|
|
redirect_params["code_challenge_method"] = "S256"
|
||
|
|
verbose_proxy_logger.debug("PKCE enabled for authorization request")
|
||
|
|
|
||
|
|
return redirect_params, code_verifier
|
||
|
|
|
||
|
|
@staticmethod
|
||
|
|
def should_use_sso_handler(
|
||
|
|
google_client_id: Optional[str] = None,
|
||
|
|
microsoft_client_id: Optional[str] = None,
|
||
|
|
generic_client_id: Optional[str] = None,
|
||
|
|
) -> bool:
|
||
|
|
if (
|
||
|
|
google_client_id is not None
|
||
|
|
or microsoft_client_id is not None
|
||
|
|
or generic_client_id is not None
|
||
|
|
):
|
||
|
|
return True
|
||
|
|
return False
|
||
|
|
|
||
|
|
@staticmethod
|
||
|
|
def get_redirect_url_for_sso(
|
||
|
|
request: Request,
|
||
|
|
sso_callback_route: str,
|
||
|
|
existing_key: Optional[str] = None,
|
||
|
|
) -> str:
|
||
|
|
"""
|
||
|
|
Get the redirect URL for SSO
|
||
|
|
|
||
|
|
Note: existing_key is not added to the URL to avoid changing the callback URL.
|
||
|
|
It should be passed via the state parameter instead.
|
||
|
|
"""
|
||
|
|
from litellm.proxy.utils import get_custom_url
|
||
|
|
|
||
|
|
redirect_url = get_custom_url(request_base_url=str(request.base_url))
|
||
|
|
if redirect_url.endswith("/"):
|
||
|
|
redirect_url += sso_callback_route
|
||
|
|
else:
|
||
|
|
redirect_url += "/" + sso_callback_route
|
||
|
|
|
||
|
|
return redirect_url
|
||
|
|
|
||
|
|
    @staticmethod
    async def upsert_sso_user(
        result: Optional[Union[CustomOpenID, OpenID, dict]],
        user_info: Optional[Union[NewUserResponse, LiteLLM_UserTable]],
        user_email: Optional[str],
        user_defined_values: Optional[SSOUserDefinedValues],
        prisma_client: PrismaClient,
    ):
        """
        Connects the SSO Users to the User Table in LiteLLM DB

        - If user on LiteLLM DB, update the user_email and user_role (if SSO provides valid role) with the SSO values
        - If user not on LiteLLM DB, insert the user into LiteLLM DB

        Returns:
            The (possibly newly inserted) user record; on error, returns the
            incoming ``user_info`` unchanged after logging the exception.
        """
        try:
            if user_info is not None:
                # Existing user: refresh their record from the SSO response.
                user_id = user_info.user_id
                update_data = _build_sso_user_update_data(
                    result=result,
                    user_email=user_email,
                    user_id=user_id,
                )

                await prisma_client.db.litellm_usertable.update_many(
                    where={"user_id": user_id}, data=update_data
                )
            else:
                verbose_proxy_logger.info(
                    "user not in DB, inserting user into LiteLLM DB"
                )
                # user not in DB, insert User into LiteLLM DB
                user_info = await insert_sso_user(
                    result_openid=result,
                    user_defined_values=user_defined_values,
                )
            return user_info
        except Exception as e:
            # Best-effort: a DB upsert failure should not block the SSO login.
            verbose_proxy_logger.exception(
                f"Error upserting SSO user into LiteLLM DB: {e}"
            )
            return user_info
|
||
|
|
|
||
|
|
@staticmethod
|
||
|
|
async def add_user_to_teams_from_sso_response(
|
||
|
|
result: Optional[Union[CustomOpenID, OpenID, dict]],
|
||
|
|
user_info: Optional[Union[NewUserResponse, LiteLLM_UserTable]],
|
||
|
|
):
|
||
|
|
"""
|
||
|
|
Adds the user as a team member to the teams specified in the SSO responses `team_ids` field
|
||
|
|
|
||
|
|
|
||
|
|
The `team_ids` field is populated by litellm after processing the SSO response
|
||
|
|
"""
|
||
|
|
if user_info is None:
|
||
|
|
verbose_proxy_logger.debug(
|
||
|
|
"User not found in LiteLLM DB, skipping team member addition"
|
||
|
|
)
|
||
|
|
return
|
||
|
|
sso_teams = getattr(result, "team_ids", [])
|
||
|
|
await add_missing_team_member(user_info=user_info, sso_teams=sso_teams)
|
||
|
|
|
||
|
|
@staticmethod
|
||
|
|
def verify_user_in_restricted_sso_group(
|
||
|
|
general_settings: Dict,
|
||
|
|
result: Optional[Union[CustomOpenID, OpenID, dict]],
|
||
|
|
received_response: Optional[dict],
|
||
|
|
) -> Literal[True]:
|
||
|
|
"""
|
||
|
|
when ui_access_mode.type == "restricted_sso_group":
|
||
|
|
|
||
|
|
- result.team_ids should contain the restricted_sso_group
|
||
|
|
- if not, raise a ProxyException
|
||
|
|
- if so, return True
|
||
|
|
- if result.team_ids is None, return False
|
||
|
|
- if result.team_ids is an empty list, return False
|
||
|
|
- if result.team_ids is a list, return True if the restricted_sso_group is in the list, otherwise return False
|
||
|
|
"""
|
||
|
|
|
||
|
|
ui_access_mode = cast(
|
||
|
|
Optional[Union[Dict, str]], general_settings.get("ui_access_mode")
|
||
|
|
)
|
||
|
|
|
||
|
|
if ui_access_mode is None:
|
||
|
|
return True
|
||
|
|
if isinstance(ui_access_mode, str):
|
||
|
|
return True
|
||
|
|
team_ids = getattr(result, "team_ids", [])
|
||
|
|
|
||
|
|
if ui_access_mode.get("type") == "restricted_sso_group":
|
||
|
|
restricted_sso_group = ui_access_mode.get("restricted_sso_group")
|
||
|
|
if restricted_sso_group not in team_ids:
|
||
|
|
raise ProxyException(
|
||
|
|
message=f"User is not in the restricted SSO group: {restricted_sso_group}. User groups: {team_ids}. Received SSO response: {received_response}",
|
||
|
|
type=ProxyErrorTypes.auth_error,
|
||
|
|
param="restricted_sso_group",
|
||
|
|
code=status.HTTP_403_FORBIDDEN,
|
||
|
|
)
|
||
|
|
return True
|
||
|
|
|
||
|
|
    @staticmethod
    async def create_litellm_team_from_sso_group(
        litellm_team_id: str,
        litellm_team_name: Optional[str] = None,
    ):
        """
        Creates a Litellm Team from a SSO Group ID

        Your SSO provider might have groups that should be created on LiteLLM

        Use this helper to create a Litellm Team from a SSO Group ID

        Args:
            litellm_team_id (str): The ID of the Litellm Team
            litellm_team_name (Optional[str]): The name of the Litellm Team
        """
        from litellm.proxy.proxy_server import prisma_client

        if prisma_client is None:
            raise ProxyException(
                message="Prisma client not found. Set it in the proxy_server.py file",
                type=ProxyErrorTypes.auth_error,
                param="prisma_client",
                code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )
        try:
            team_obj = await prisma_client.db.litellm_teamtable.find_first(
                where={"team_id": litellm_team_id}
            )
            verbose_proxy_logger.debug(f"Team object: {team_obj}")

            # only create a new team if it doesn't exist
            if team_obj:
                verbose_proxy_logger.debug(
                    f"Team already exists: {litellm_team_id} - {litellm_team_name}"
                )
                return

            team_request: NewTeamRequest = NewTeamRequest(
                team_id=litellm_team_id,
                team_alias=litellm_team_name,
            )
            # Overlay admin-configured team defaults, when set.
            if litellm.default_team_params:
                team_request = SSOAuthenticationHandler._cast_and_deepcopy_litellm_default_team_params(
                    default_team_params=litellm.default_team_params,
                    litellm_team_id=litellm_team_id,
                    litellm_team_name=litellm_team_name,
                    team_request=team_request,
                )

            await new_team(
                data=team_request,
                # params used for Audit Logging
                http_request=Request(scope={"type": "http", "method": "POST"}),
                user_api_key_dict=UserAPIKeyAuth(
                    token="",
                    key_alias=f"litellm.{MicrosoftSSOHandler.__name__}",
                ),
            )
        except Exception as e:
            # Best-effort: team creation failure should not abort the SSO login.
            verbose_proxy_logger.exception(f"Error creating Litellm Team: {e}")
|
||
|
|
|
||
|
|
@staticmethod
|
||
|
|
def _cast_and_deepcopy_litellm_default_team_params(
|
||
|
|
default_team_params: Union[DefaultTeamSSOParams, Dict],
|
||
|
|
team_request: NewTeamRequest,
|
||
|
|
litellm_team_id: str,
|
||
|
|
litellm_team_name: Optional[str] = None,
|
||
|
|
) -> NewTeamRequest:
|
||
|
|
"""
|
||
|
|
Casts and deepcopies the litellm.default_team_params to a NewTeamRequest object
|
||
|
|
|
||
|
|
- Ensures we create a new DefaultTeamSSOParams object
|
||
|
|
- Handle the case where litellm.default_team_params is a dict or a DefaultTeamSSOParams object
|
||
|
|
- Adds the litellm_team_id and litellm_team_name to the DefaultTeamSSOParams object
|
||
|
|
"""
|
||
|
|
if isinstance(default_team_params, dict):
|
||
|
|
_team_request = deepcopy(default_team_params)
|
||
|
|
_team_request["team_id"] = litellm_team_id
|
||
|
|
_team_request["team_alias"] = litellm_team_name
|
||
|
|
team_request = NewTeamRequest(**_team_request)
|
||
|
|
elif isinstance(litellm.default_team_params, DefaultTeamSSOParams):
|
||
|
|
_default_team_params = deepcopy(litellm.default_team_params)
|
||
|
|
_new_team_request = team_request.model_dump()
|
||
|
|
_new_team_request.update(_default_team_params)
|
||
|
|
team_request = NewTeamRequest(**_new_team_request)
|
||
|
|
return team_request
|
||
|
|
|
||
|
|
@staticmethod
|
||
|
|
def _get_cli_state(
|
||
|
|
source: Optional[str], key: Optional[str], existing_key: Optional[str] = None
|
||
|
|
) -> Optional[str]:
|
||
|
|
"""
|
||
|
|
Checks the request 'source' if a cli state token was passed in
|
||
|
|
|
||
|
|
This is used to authenticate through the CLI login flow.
|
||
|
|
|
||
|
|
The state parameter format is: {PREFIX}:{key}:{existing_key}
|
||
|
|
- If existing_key is provided, it's included in the state
|
||
|
|
- The state parameter is used to pass data through the OAuth flow without changing the callback URL
|
||
|
|
"""
|
||
|
|
from litellm.constants import (
|
||
|
|
LITELLM_CLI_SESSION_TOKEN_PREFIX,
|
||
|
|
LITELLM_CLI_SOURCE_IDENTIFIER,
|
||
|
|
)
|
||
|
|
|
||
|
|
if source == LITELLM_CLI_SOURCE_IDENTIFIER and key:
|
||
|
|
if existing_key:
|
||
|
|
return f"{LITELLM_CLI_SESSION_TOKEN_PREFIX}:{key}:{existing_key}"
|
||
|
|
else:
|
||
|
|
return f"{LITELLM_CLI_SESSION_TOKEN_PREFIX}:{key}"
|
||
|
|
else:
|
||
|
|
return None
|
||
|
|
|
||
|
|
    @staticmethod
    def _get_user_email_and_id_from_result(
        result: Optional[Union[OpenID, dict]],
        generic_client_id: Optional[str] = None,
    ) -> ParsedOpenIDResult:
        """
        Gets the user email and id from the OpenID result after validating the email domain

        Args:
            result: OpenID object (or None) returned by the SSO provider.
            generic_client_id: set when the generic OIDC provider is in use;
                enables the GENERIC_USER_ROLE_ATTRIBUTE role fallback.

        Raises:
            HTTPException: 401 when ALLOWED_EMAIL_DOMAINS is set and the user's
                email domain is not in the allow-list.

        Returns:
            ParsedOpenIDResult: user_email, user_id and (optional) user_role.
        """
        user_email: Optional[str] = normalize_email(getattr(result, "email", None))
        user_id: Optional[str] = (
            getattr(result, "id", None) if result is not None else None
        )
        user_role: Optional[str] = None

        # Enforce the email-domain allow-list, when configured.
        if user_email is not None and os.getenv("ALLOWED_EMAIL_DOMAINS") is not None:
            email_domain = user_email.split("@")[1]
            allowed_domains = os.getenv("ALLOWED_EMAIL_DOMAINS").split(",")  # type: ignore
            if email_domain not in allowed_domains:
                raise HTTPException(
                    status_code=401,
                    detail={
                        "message": "The email domain={}, is not an allowed email domain={}. Contact your admin to change this.".format(
                            email_domain, allowed_domains
                        )
                    },
                )

        # Extract user_role from result (works for all SSO providers)
        if result is not None:
            _user_role = getattr(result, "user_role", None)
            if _user_role is not None:
                # Convert enum to string if needed
                user_role = (
                    _user_role.value
                    if isinstance(_user_role, LitellmUserRoles)
                    else _user_role
                )
                verbose_proxy_logger.debug(
                    f"Extracted user_role from SSO result: {user_role}"
                )

        # generic client id - override with custom attribute name if specified
        if generic_client_id is not None and result is not None:
            generic_user_role_attribute_name = os.getenv(
                "GENERIC_USER_ROLE_ATTRIBUTE", "role"
            )
            # Re-read id/email on the generic-provider path.
            user_id = getattr(result, "id", None)
            user_email = normalize_email(getattr(result, "email", None))
            if user_role is None:
                _role_from_attr = getattr(result, generic_user_role_attribute_name, None)  # type: ignore
                if _role_from_attr is not None:
                    # Convert enum to string if needed
                    user_role = (
                        _role_from_attr.value
                        if isinstance(_role_from_attr, LitellmUserRoles)
                        else _role_from_attr
                    )

        # Fallback user_id: concatenated first+last name when the provider sends no id.
        if user_id is None and result is not None:
            _first_name = getattr(result, "first_name", "") or ""
            _last_name = getattr(result, "last_name", "") or ""
            user_id = _first_name + _last_name

        # Final fallback: use the email address as the user_id.
        if user_email is not None and (user_id is None or len(user_id) == 0):
            user_id = user_email

        return ParsedOpenIDResult(
            user_email=user_email,
            user_id=user_id,
            user_role=user_role,
        )
|
||
|
|
|
||
|
|
    @staticmethod
    async def get_redirect_response_from_openid(  # noqa: PLR0915
        result: Union[OpenID, dict, CustomOpenID],
        request: Request,
        received_response: Optional[dict] = None,
        generic_client_id: Optional[str] = None,
        ui_access_mode: Optional[Dict] = None,
    ) -> RedirectResponse:
        """
        Exchange a verified SSO identity for a LiteLLM UI session.

        Parses the OpenID result into (user_email, user_id, user_role),
        optionally runs the user-defined `user_custom_sso` hook, enforces the
        restricted-SSO-group and admin-only access checks, generates a
        short-lived UI key, and returns a 303 redirect to the dashboard with a
        signed JWT stored in the "token" cookie.

        Args:
            result: OpenID object (or raw dict) returned by the SSO provider.
            request: Incoming callback request; its base URL is used to build
                the dashboard redirect target.
            received_response: Raw provider response, forwarded to the
                restricted-group check.
            generic_client_id: Generic OAuth client id, used when extracting
                the user role from a custom attribute.
            ui_access_mode: Proxy UI access-mode config (e.g. admin-only).

        Returns:
            RedirectResponse (303) to the LiteLLM UI with the session JWT
            set as the "token" cookie.

        Raises:
            HTTPException: 401 when the role is not allowed under admin-only
                access mode, or when experimental UI login lacks user info.
            ValueError: if `user_custom_sso` is set but not a coroutine function.
            Exception: if no user identity could be mapped.
        """
        # Local import shadows the module-level `jwt` on purpose; keeps this
        # function usable even if import order at module load changes.
        import jwt

        from litellm.proxy.proxy_server import (
            general_settings,
            generate_key_helper_fn,
            master_key,
            premium_user,
            proxy_logging_obj,
            user_api_key_cache,
            user_custom_sso,
        )
        from litellm.proxy.utils import get_prisma_client_or_throw
        from litellm.types.proxy.ui_sso import ReturnedUITokenObject

        prisma_client = get_prisma_client_or_throw(
            "Prisma client is None, connect a database to your proxy"
        )

        # User is Authe'd in - generate key for the UI to access Proxy
        parsed_openid_result = (
            SSOAuthenticationHandler._get_user_email_and_id_from_result(
                result=result, generic_client_id=generic_client_id
            )
        )
        user_email = parsed_openid_result.get("user_email")
        user_id = parsed_openid_result.get("user_id")
        user_role = parsed_openid_result.get("user_role")
        verbose_proxy_logger.info(f"SSO callback result: {result}")

        user_info = None
        user_id_models: List = []
        max_internal_user_budget = litellm.max_internal_user_budget
        internal_user_budget_duration = litellm.internal_user_budget_duration

        # User might not be already created on first generation of key
        # But if it is, we want their models preferences
        default_ui_key_values: Dict[str, Any] = {
            "duration": LITELLM_UI_SESSION_DURATION,
            "key_max_budget": litellm.max_ui_session_budget,
            "aliases": {},
            "config": {},
            "spend": 0,
            "team_id": "litellm-dashboard",
        }
        user_defined_values: Optional[SSOUserDefinedValues] = None

        if user_custom_sso is not None:
            # Custom hook takes full precedence over the parsed OpenID values.
            if inspect.iscoroutinefunction(user_custom_sso):
                user_defined_values = await user_custom_sso(result)  # type: ignore
            else:
                raise ValueError("user_custom_sso must be a coroutine function")
        elif user_id is not None:
            user_defined_values = SSOUserDefinedValues(
                models=user_id_models,
                user_id=user_id,
                user_email=user_email,
                max_budget=max_internal_user_budget,
                user_role=user_role,
                budget_duration=internal_user_budget_duration,
            )

        # (IF SET) Verify user is in restricted SSO group
        SSOAuthenticationHandler.verify_user_in_restricted_sso_group(
            general_settings=general_settings,
            result=result,
            received_response=received_response,
        )

        user_info = await get_user_info_from_db(
            result=result,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            proxy_logging_obj=proxy_logging_obj,
            user_email=user_email,
            user_defined_values=user_defined_values,
            alternate_user_id=user_id,
        )

        # DB values (existing user) take precedence over SSO-derived defaults.
        user_defined_values = apply_user_info_values_to_sso_user_defined_values(
            user_info=user_info, user_defined_values=user_defined_values
        )

        if user_defined_values is None:
            raise Exception(
                "Unable to map user identity to known values. 'user_defined_values' is None. File an issue - https://github.com/BerriAI/litellm/issues"
            )

        verbose_proxy_logger.info(
            f"user_defined_values for creating ui key: {user_defined_values}"
        )

        default_ui_key_values.update(user_defined_values)
        default_ui_key_values["request_type"] = "key"
        response = await generate_key_helper_fn(
            **default_ui_key_values,  # type: ignore
            table_name="key",
        )

        key = response["token"]  # type: ignore
        user_id = response["user_id"]  # type: ignore

        # Fall back to the most restrictive role when none was resolved.
        user_role = (
            user_defined_values["user_role"]
            or LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value
        )
        if user_id and isinstance(user_id, str):
            # Promote to proxy admin if this id matches the configured admin id.
            user_role = await check_and_update_if_proxy_admin_id(
                user_role=user_role, user_id=user_id, prisma_client=prisma_client
            )

        verbose_proxy_logger.debug(
            f"user_role: {user_role}; ui_access_mode: {ui_access_mode}"
        )
        ## CHECK IF ROLE ALLOWED TO USE PROXY ##
        is_admin_only_access = check_is_admin_only_access(ui_access_mode or {})
        if is_admin_only_access:
            has_access = has_admin_ui_access(user_role or "")
            if not has_access:
                raise HTTPException(
                    status_code=401,
                    detail={
                        "error": f"User not allowed to access proxy. User role={user_role}, proxy mode={ui_access_mode}"
                    },
                )

        disabled_non_admin_personal_key_creation = (
            get_disabled_non_admin_personal_key_creation()
        )
        litellm_dashboard_ui = get_custom_url(
            request_base_url=str(request.base_url), route="ui/"
        )

        if get_secret_bool("EXPERIMENTAL_UI_LOGIN"):
            # Experimental path: replace the DB-backed key with a signed JWT
            # derived directly from the user table entry.
            _user_info: Optional[LiteLLM_UserTable] = None
            if (
                user_defined_values is not None
                and user_defined_values["user_id"] is not None
            ):
                _user_info = LiteLLM_UserTable(
                    user_id=user_defined_values["user_id"],
                    user_role=user_defined_values["user_role"] or user_role,
                    models=[],
                    max_budget=litellm.max_ui_session_budget,
                )
            if _user_info is None:
                raise HTTPException(
                    status_code=401,
                    detail={
                        "error": "User Information is required for experimental UI login"
                    },
                )

            key = ExperimentalUIJWTToken.get_experimental_ui_login_jwt_auth_token(
                _user_info
            )

        returned_ui_token_object = ReturnedUITokenObject(
            user_id=cast(str, user_id),
            key=key,
            user_email=user_email,
            user_role=user_role or LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value,
            login_method="sso",
            premium_user=premium_user,
            auth_header_name=general_settings.get(
                "litellm_key_header_name", "Authorization"
            ),
            disabled_non_admin_personal_key_creation=disabled_non_admin_personal_key_creation,
            server_root_path=get_server_root_path(),
        )

        # UI session token is signed with the proxy master key (HS256).
        jwt_token = jwt.encode(
            cast(dict, returned_ui_token_object),
            master_key or "",
            algorithm="HS256",
        )
        if user_id is not None and isinstance(user_id, str):
            litellm_dashboard_ui += "?login=success"
        verbose_proxy_logger.info(f"Redirecting to {litellm_dashboard_ui}")
        redirect_response = RedirectResponse(url=litellm_dashboard_ui, status_code=303)
        redirect_response.set_cookie(key="token", value=jwt_token)
        return redirect_response
|
||
|
|
|
||
|
|
    @staticmethod
    async def prepare_token_exchange_parameters(
        request: Request,
        generic_include_client_id: bool,
    ) -> dict:
        """
        Prepare token exchange parameters for Generic SSO.

        Always sets "include_client_id". When PKCE is enabled
        (GENERIC_CLIENT_USE_PKCE=true) and a 'state' query parameter is
        present, also looks up the stored code_verifier (Redis if available,
        otherwise the in-memory cache) and adds:

        - "code_verifier": the retrieved PKCE verifier
        - "_pkce_cache_key": the cache key, so the caller can delete the
          single-use verifier after a successful exchange

        Args:
            request: Request object (callback request; 'state' is read from
                its query parameters)
            generic_include_client_id: whether to include the client_id in
                the token exchange body

        Returns:
            dict: Token exchange parameters

        Raises:
            ProxyException: via `_handle_missing_pkce_verifier` when the
                verifier is missing/invalid and PKCE_STRICT_CACHE_MISS=true.
        """
        # Prepare token exchange parameters (may add code_verifier: str later)
        token_params: Dict[str, Any] = {"include_client_id": generic_include_client_id}

        # Retrieve PKCE code_verifier if PKCE was used in authorization.
        # Gate on GENERIC_CLIENT_USE_PKCE to avoid an unnecessary Redis round-trip
        # on every non-PKCE SSO callback.
        query_params = dict(request.query_params)
        state = query_params.get("state")

        use_pkce = os.getenv("GENERIC_CLIENT_USE_PKCE", "false").lower() == "true"

        if use_pkce and not state:
            verbose_proxy_logger.warning(
                "PKCE is enabled (GENERIC_CLIENT_USE_PKCE=true) but no 'state' parameter "
                "was found in the callback. The PKCE verifier cannot be retrieved without "
                "a state value — the token exchange will proceed without code_verifier, "
                "which the provider may reject. Ensure your OAuth provider returns 'state' "
                "in the callback redirect."
            )

        if state and use_pkce:
            from litellm.proxy.proxy_server import redis_usage_cache, user_api_key_cache

            cache_key = f"pkce_verifier:{state}"
            # Prefer Redis so multi-instance deployments can share the verifier.
            if redis_usage_cache is not None:
                cached_data = await redis_usage_cache.async_get_cache(key=cache_key)
            else:
                cached_data = await user_api_key_cache.async_get_cache(key=cache_key)

            code_verifier = None
            # Track why code_verifier is absent for accurate strict-mode diagnostics.
            _empty_value_in_dict = False  # dict format correct but value is empty/null

            if cached_data:
                # Extract code_verifier from dict (stored as dict for JSON serialization)
                if isinstance(cached_data, dict) and "code_verifier" in cached_data:
                    code_verifier = cached_data["code_verifier"]
                    if not code_verifier:
                        # Dict format is correct but value is empty or null. This is
                        # a distinct case from an unrecognized format — the entry exists
                        # but was stored with an empty/null verifier (data integrity issue).
                        _empty_value_in_dict = True
                        verbose_proxy_logger.warning(
                            "PKCE verifier dict for state '%s' has an empty/null code_verifier "
                            "value — may indicate a storage bug. Treating as a cache miss.",
                            state,
                        )
                    else:
                        verbose_proxy_logger.debug("PKCE code_verifier retrieved from cache")
                elif isinstance(cached_data, str):
                    # Handle legacy format (plain string) for backward compatibility
                    code_verifier = cached_data
                    verbose_proxy_logger.warning(
                        "Retrieved code_verifier in legacy plain-string format. "
                        "Future storage will use dict format."
                    )
                else:
                    # Defer the detailed ERROR log to the strict-mode branch below
                    # (which includes state and a diagnostic message). Log at DEBUG
                    # here to avoid duplicate ERROR entries in the same request.
                    verbose_proxy_logger.debug(
                        "Unexpected PKCE verifier cache format (type=%s); skipping.",
                        type(cached_data).__name__,
                    )

            if code_verifier:
                # Add code_verifier to token exchange parameters.
                token_params["code_verifier"] = code_verifier
                # Return the cache key so the caller can delete it *after* a
                # successful token exchange (avoids losing the verifier on retry
                # if the exchange fails partway through).
                token_params["_pkce_cache_key"] = cache_key
            else:
                await SSOAuthenticationHandler._handle_missing_pkce_verifier(
                    state=state,
                    cache_key=cache_key,
                    cached_data=cached_data,
                    empty_value_in_dict=_empty_value_in_dict,
                    redis_usage_cache=redis_usage_cache,
                    user_api_key_cache=user_api_key_cache,
                )
        return token_params
|
||
|
|
|
||
|
|
    @staticmethod
    async def _handle_missing_pkce_verifier(
        state: Optional[str],
        cache_key: str,
        cached_data: object,
        empty_value_in_dict: bool,
        redis_usage_cache: object,
        user_api_key_cache: object,
    ) -> None:
        """Handle the case where PKCE verifier could not be extracted from cache.

        In strict mode (PKCE_STRICT_CACHE_MISS=true) raises ProxyException,
        with a message tailored to why extraction failed:
        - dict entry present but empty/null verifier value
        - entry present but in an unrecognized format
        - no entry at all (diagnosis differs for Redis vs in-memory cache)

        Otherwise logs a warning and returns (token exchange proceeds without
        verifier). In both modes, a bad cached entry is deleted best-effort
        before raising/returning so it cannot poison a retry.
        """
        active_cache = redis_usage_cache if redis_usage_cache is not None else user_api_key_cache
        strict_cache_miss = (
            os.getenv("PKCE_STRICT_CACHE_MISS", "false").lower() == "true"
        )
        if strict_cache_miss:
            if empty_value_in_dict:
                # Entry exists but its verifier value is empty/null — delete it
                # so the corrupt entry cannot be retried.
                await SSOAuthenticationHandler._delete_pkce_verifier(cache_key)
                raise ProxyException(
                    message=(
                        f"PKCE verifier for state '{state}' was found in cache but "
                        f"has an empty or null code_verifier value — possible storage bug."
                    ),
                    type=ProxyErrorTypes.auth_error,
                    param="PKCE_CACHE_MISS",
                    code=status.HTTP_401_UNAUTHORIZED,
                )
            elif cached_data is not None:
                # Entry exists but is neither the dict format nor the legacy string.
                await SSOAuthenticationHandler._delete_pkce_verifier(cache_key)
                verbose_proxy_logger.error(
                    "PKCE verifier for state '%s' has an unrecognized format (type=%s); "
                    "treating as a cache miss. Investigate the cached value — it may be "
                    "a corrupt or stale entry.",
                    state,
                    type(cached_data).__name__,
                )
                raise ProxyException(
                    message=(
                        f"PKCE verifier for state '{state}' has an unrecognized format "
                        f"(type={type(cached_data).__name__}). The cached entry may be corrupt."
                    ),
                    type=ProxyErrorTypes.auth_error,
                    param="PKCE_CACHE_MISS",
                    code=status.HTTP_401_UNAUTHORIZED,
                )
            else:
                # True cache miss: pick a diagnosis based on which cache backend
                # was in use (Redis miss suggests cross-pod routing; in-memory
                # miss suggests TTL expiry / restart / missing Redis).
                if redis_usage_cache is not None:
                    cause = (
                        "The authorization and callback were likely handled by different "
                        "instances — the verifier was stored on one pod but not found on another."
                    )
                else:
                    cause = (
                        "The verifier may have expired (TTL), been lost on a pod restart, "
                        "or the PKCE authorization step was never completed. "
                        "Configure Redis so all proxy instances share the PKCE verifier."
                    )
                verbose_proxy_logger.error(
                    "PKCE is enabled but no verifier found in cache for state '%s'. "
                    "%s Cache type: %s.",
                    state,
                    cause,
                    type(active_cache).__name__,
                )
                raise ProxyException(
                    message=f"PKCE verifier not found in cache for state '{state}'. {cause}",
                    type=ProxyErrorTypes.auth_error,
                    param="PKCE_CACHE_MISS",
                    code=status.HTTP_401_UNAUTHORIZED,
                )
        else:
            # Non-strict mode: best-effort cleanup of any unusable entry, then
            # warn and continue without a verifier.
            if cached_data is not None:
                await SSOAuthenticationHandler._delete_pkce_verifier(cache_key)
            verbose_proxy_logger.warning(
                "PKCE is enabled but verifier not found in cache for state '%s' "
                "(cache type: %s, raw data present: %s). "
                "Continuing without code_verifier — set PKCE_STRICT_CACHE_MISS=true to fail fast instead.",
                state,
                type(active_cache).__name__,
                cached_data is not None,
            )
|
||
|
|
|
||
|
|
@staticmethod
|
||
|
|
async def _delete_pkce_verifier(cache_key: str) -> None:
|
||
|
|
"""Delete a single-use PKCE verifier from cache after a successful exchange.
|
||
|
|
|
||
|
|
Failure is non-fatal: a leftover verifier is a minor security concern
|
||
|
|
(unused key in cache) but not worth aborting an otherwise-successful login.
|
||
|
|
"""
|
||
|
|
from litellm.proxy.proxy_server import redis_usage_cache, user_api_key_cache
|
||
|
|
|
||
|
|
try:
|
||
|
|
if redis_usage_cache is not None:
|
||
|
|
await redis_usage_cache.async_delete_cache(key=cache_key)
|
||
|
|
else:
|
||
|
|
await user_api_key_cache.async_delete_cache(key=cache_key)
|
||
|
|
except Exception as exc:
|
||
|
|
verbose_proxy_logger.warning(
|
||
|
|
"PKCE: failed to delete verifier cache key '%s' (best-effort cleanup): %s",
|
||
|
|
cache_key,
|
||
|
|
exc,
|
||
|
|
)
|
||
|
|
|
||
|
|
@staticmethod
|
||
|
|
def generate_pkce_params() -> Tuple[str, str]:
|
||
|
|
"""
|
||
|
|
Generate PKCE (Proof Key for Code Exchange) parameters for OAuth 2.0.
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
Tuple[str, str]: (code_verifier, code_challenge)
|
||
|
|
- code_verifier: Random 43-128 character string (we use 43 for efficiency)
|
||
|
|
- code_challenge: Base64-URL-encoded SHA256 hash of the code_verifier
|
||
|
|
|
||
|
|
Reference: https://datatracker.ietf.org/doc/html/rfc7636
|
||
|
|
"""
|
||
|
|
# Generate a cryptographically random code_verifier (43 characters)
|
||
|
|
# Using 32 random bytes which becomes 43 characters when base64-url-encoded
|
||
|
|
code_verifier = (
|
||
|
|
base64.urlsafe_b64encode(secrets.token_bytes(32))
|
||
|
|
.decode("utf-8")
|
||
|
|
.rstrip("=")
|
||
|
|
)
|
||
|
|
|
||
|
|
# Generate code_challenge using S256 method (SHA256)
|
||
|
|
code_challenge_bytes = hashlib.sha256(code_verifier.encode("utf-8")).digest()
|
||
|
|
code_challenge = (
|
||
|
|
base64.urlsafe_b64encode(code_challenge_bytes).decode("utf-8").rstrip("=")
|
||
|
|
)
|
||
|
|
|
||
|
|
return code_verifier, code_challenge
|
||
|
|
|
||
|
|
@staticmethod
|
||
|
|
def _validate_token_response(response: "httpx.Response") -> dict:
|
||
|
|
"""
|
||
|
|
Parse and validate the token endpoint response.
|
||
|
|
|
||
|
|
Ensures the response is valid JSON, a dict, and contains a non-null
|
||
|
|
access_token string. Raises ProxyException on any validation failure.
|
||
|
|
"""
|
||
|
|
try:
|
||
|
|
token_response_raw = response.json()
|
||
|
|
except Exception as json_err:
|
||
|
|
verbose_proxy_logger.error(
|
||
|
|
"Failed to parse token response as JSON: %s. Body: %s",
|
||
|
|
json_err,
|
||
|
|
response.text[:500],
|
||
|
|
)
|
||
|
|
raise ProxyException(
|
||
|
|
message=f"Token endpoint returned invalid JSON: {json_err}",
|
||
|
|
type=ProxyErrorTypes.auth_error,
|
||
|
|
param="token_exchange",
|
||
|
|
code=status.HTTP_401_UNAUTHORIZED,
|
||
|
|
)
|
||
|
|
|
||
|
|
if not isinstance(token_response_raw, dict):
|
||
|
|
verbose_proxy_logger.error(
|
||
|
|
"Token endpoint returned non-dict JSON (type=%s). Body: %s",
|
||
|
|
type(token_response_raw).__name__,
|
||
|
|
response.text[:500],
|
||
|
|
)
|
||
|
|
raise ProxyException(
|
||
|
|
message=(
|
||
|
|
f"Token endpoint returned unexpected response format "
|
||
|
|
f"(expected JSON object, got {type(token_response_raw).__name__})"
|
||
|
|
),
|
||
|
|
type=ProxyErrorTypes.auth_error,
|
||
|
|
param="token_exchange",
|
||
|
|
code=status.HTTP_401_UNAUTHORIZED,
|
||
|
|
)
|
||
|
|
token_response: dict = token_response_raw
|
||
|
|
|
||
|
|
access_token_val = token_response.get("access_token")
|
||
|
|
if not isinstance(access_token_val, str) or not access_token_val:
|
||
|
|
error = token_response.get("error")
|
||
|
|
error_desc = token_response.get("error_description", "")
|
||
|
|
if error:
|
||
|
|
detail = f"{error} - {error_desc}" if error_desc else error
|
||
|
|
else:
|
||
|
|
detail = (
|
||
|
|
"token endpoint returned HTTP 200 but no access_token "
|
||
|
|
f"(response keys: {sorted(token_response.keys())})"
|
||
|
|
)
|
||
|
|
verbose_proxy_logger.error(
|
||
|
|
"Token response missing or null access_token. detail=%s", detail
|
||
|
|
)
|
||
|
|
raise ProxyException(
|
||
|
|
message=f"Token exchange failed: {detail}",
|
||
|
|
type=ProxyErrorTypes.auth_error,
|
||
|
|
param="token_exchange",
|
||
|
|
code=status.HTTP_401_UNAUTHORIZED,
|
||
|
|
)
|
||
|
|
|
||
|
|
return token_response
|
||
|
|
|
||
|
|
    @staticmethod
    async def _pkce_token_exchange(
        authorization_code: str,
        code_verifier: str,
        client_id: str,
        client_secret: Optional[str],
        token_endpoint: str,
        userinfo_endpoint: Optional[str],
        include_client_id: bool,
        redirect_url: Optional[str],
        additional_headers: Dict[str, str],
    ) -> dict:
        """
        Performs a direct OAuth token exchange including the PKCE code_verifier.

        fastapi-sso does not forward code_verifier, so when PKCE is enabled we
        bypass it and call the token endpoint ourselves, then fetch user info.

        Args:
            authorization_code: the 'code' from the provider callback.
            code_verifier: PKCE verifier matching the earlier code_challenge.
            client_id: OAuth client id.
            client_secret: optional client secret (public PKCE clients omit it).
            token_endpoint: provider token endpoint URL.
            userinfo_endpoint: optional userinfo endpoint URL.
            include_client_id: when True, credentials go in the POST body;
                when False, Basic Auth is used if a secret exists.
            redirect_url: redirect_uri to echo back, if configured.
            additional_headers: extra headers sent on both requests.

        Returns a combined dict of the token response and user info, suitable
        for passing to a response_convertor.

        Raises:
            ProxyException: on network failure, non-200 token response,
                invalid token payload, or unavailable user info.
        """
        verbose_proxy_logger.debug(
            "PKCE: performing direct token exchange (code_verifier length=%d)",
            len(code_verifier),
        )

        token_data: Dict[str, str] = {
            "grant_type": "authorization_code",
            "code": authorization_code,
            "code_verifier": code_verifier,
        }
        # Only include redirect_uri when set — omitting it avoids sending the
        # literal string "None" to the provider if the env var is missing.
        if redirect_url:
            token_data["redirect_uri"] = redirect_url

        request_headers = {
            **additional_headers,
            "Content-Type": "application/x-www-form-urlencoded",  # must not be overridden
            "Accept": "application/json",
        }

        if not include_client_id:
            # Use Basic Auth only when a secret is available; public PKCE clients omit it.
            if client_secret:
                credentials = base64.b64encode(
                    f"{client_id}:{client_secret}".encode()
                ).decode()
                request_headers["Authorization"] = f"Basic {credentials}"
            else:
                token_data["client_id"] = client_id
        else:
            # Provider wants credentials in the POST body.
            token_data["client_id"] = client_id
            if client_secret:
                token_data["client_secret"] = client_secret

        http_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.SSO_HANDLER
        )
        try:
            response = await http_client.post(
                url=token_endpoint,
                data=token_data,
                headers=request_headers,
                timeout=30.0,
            )
        except Exception as exc:
            # Catch network-level errors (SSL, DNS, TCP, timeout, etc.) and
            # wrap them as a clean ProxyException rather than leaking raw
            # httpx or OS exceptions to callers.
            verbose_proxy_logger.error("PKCE token endpoint unreachable: %s", exc)
            raise ProxyException(
                message=f"Token endpoint request failed: {exc}",
                type=ProxyErrorTypes.auth_error,
                param="token_exchange",
                code=status.HTTP_401_UNAUTHORIZED,
            ) from exc
        if response.status_code != 200:
            verbose_proxy_logger.error(
                "PKCE token exchange failed. status=%s body=%s",
                response.status_code,
                response.text[:500],
            )
            raise ProxyException(
                message=f"Token exchange failed: {response.status_code} - {response.text[:500]}",
                type=ProxyErrorTypes.auth_error,
                param="token_exchange",
                code=status.HTTP_401_UNAUTHORIZED,
            )

        token_response = SSOAuthenticationHandler._validate_token_response(response)

        verbose_proxy_logger.debug(
            "PKCE token exchange successful. id_token_present=%s",
            bool(token_response.get("id_token")),
        )
        # Bearer credentials (access_token, id_token, refresh_token) are always sourced
        # from token_response — not from userinfo — in the merge step below.
        userinfo = await SSOAuthenticationHandler._get_pkce_userinfo(
            access_token=token_response["access_token"],
            id_token=token_response.get("id_token"),
            userinfo_endpoint=userinfo_endpoint,
            additional_headers=additional_headers,
        )

        # Merge: userinfo takes precedence for identity claims (sub, email, name, …) per
        # the OpenID Connect spec (userinfo is the authoritative source for identity).
        # Bearer credentials (access_token, id_token, refresh_token) from the token endpoint
        # take precedence over same-named fields in userinfo — non-standard providers sometimes
        # include token fields in userinfo, which must not shadow the real bearer token.
        # If a bearer field is absent from the token response, any userinfo-provided value
        # is preserved as a fallback (useful for non-standard providers that omit id_token
        # from the token response but include it in userinfo).
        #
        # Three-way merge semantics for each bearer-credential field:
        # 1. token_response has a non-null value → use it (token endpoint is authoritative)
        # 2. token_response explicitly sent null → remove the key so callers get a clean
        #    absence signal; the null from the token endpoint overrides userinfo too
        # 3. field absent from token_response → leave whatever userinfo provided as-is
        #    (e.g. userinfo-provided id_token from a non-standard provider)
        merged = {**token_response, **userinfo}
        for field in _OAUTH_TOKEN_FIELDS:
            if token_response.get(field) is not None:
                # Case 1: non-null in token_response — restore authoritative value.
                merged[field] = token_response[field]
            elif field in token_response:
                # Case 2: key exists but value is explicitly null — remove from merged.
                merged.pop(field, None)
            # Case 3: field absent from token_response — leave userinfo value as-is.
        return merged
|
||
|
|
|
||
|
|
    @staticmethod
    async def _get_pkce_userinfo(
        access_token: str,
        id_token: Optional[str],
        userinfo_endpoint: Optional[str],
        additional_headers: Dict[str, str],
    ) -> dict:
        """
        Fetches user info from the userinfo endpoint.
        Falls back to decoding the id_token if the endpoint is unavailable.

        Args:
            access_token: bearer token sent to the userinfo endpoint.
            id_token: optional JWT decoded (signature NOT verified) as a
                fallback identity source.
            userinfo_endpoint: optional userinfo endpoint URL; when absent the
                id_token fallback is the only source.
            additional_headers: extra headers for the userinfo request.

        Returns:
            dict: non-empty user identity claims.

        Raises:
            ProxyException: if the id_token cannot be decoded, or if no
                identity claims could be obtained from either source.
        """
        # None = request not yet attempted, failed, or returned empty/null (treated as failure
        # so the id_token fallback can be attempted instead of returning a session with no claims).
        userinfo: Optional[dict] = None

        if userinfo_endpoint:
            try:
                client = get_async_httpx_client(
                    llm_provider=httpxSpecialProvider.SSO_HANDLER
                )
                resp = await client.get(
                    url=userinfo_endpoint,
                    headers={
                        **additional_headers,
                        "Authorization": f"Bearer {access_token}",  # must not be overridden
                    },
                )
                if resp.status_code == 200:
                    try:
                        userinfo_raw = resp.json()
                        if not userinfo_raw:
                            # JSON null (None) or empty dict ({}) — no identity claims.
                            # Treat as failure so id_token fallback can be attempted.
                            verbose_proxy_logger.warning(
                                "Userinfo endpoint returned an empty or null response "
                                "(type=%s); treating as failure and attempting id_token fallback. "
                                "Check your provider's userinfo endpoint configuration.",
                                type(userinfo_raw).__name__,
                            )
                            userinfo = None
                        else:
                            userinfo = userinfo_raw
                    except Exception as json_err:
                        # Status 200 but non-JSON body — leave userinfo as None.
                        verbose_proxy_logger.warning(
                            "Userinfo endpoint returned non-JSON response (status 200): %s",
                            json_err,
                        )
                else:
                    verbose_proxy_logger.warning(
                        "Userinfo endpoint returned %s (body: %s), falling back to id_token",
                        resp.status_code,
                        resp.text[:500],
                    )
            except Exception as e:
                # Network-level failure — fall through to the id_token path.
                verbose_proxy_logger.warning(
                    "Userinfo endpoint error: %s, falling back to id_token", e
                )

        # Only fall back to id_token when the userinfo request failed (None).
        # Empty dict ({}) and JSON null are both treated as failure (set to None above) since
        # they contain no identity claims — id_token fallback is attempted in that case too.
        # Explicitly check for a non-empty string to avoid attempting JWT decode on
        # a blank or non-string id_token field from a misbehaving provider.
        if userinfo is None and isinstance(id_token, str) and id_token:
            try:
                # Signature verification is intentionally skipped here: the
                # id_token was received directly from the token endpoint.
                userinfo = jwt.decode(id_token, options={"verify_signature": False})
                if not userinfo:
                    # jwt.decode returned an empty dict (payload-free JWT or provider bug).
                    # Treat this the same as a missing userinfo — the session would have no
                    # identity claims, which is equivalent to a broken session.
                    verbose_proxy_logger.warning(
                        "id_token decoded to an empty payload — treating as failure."
                    )
                    userinfo = None
            except Exception as decode_err:
                verbose_proxy_logger.error("Failed to decode id_token: %s", decode_err)
                raise ProxyException(
                    message=f"Failed to decode id_token JWT: {decode_err}",
                    type=ProxyErrorTypes.auth_error,
                    param="userinfo",
                    code=status.HTTP_401_UNAUTHORIZED,
                )

        if userinfo is None:
            # Neither source produced claims — build an error message that
            # names exactly which combination of sources failed.
            id_token_attempted = isinstance(id_token, str) and bool(id_token)
            if userinfo_endpoint:
                if id_token_attempted:
                    detail = (
                        "userinfo endpoint failed and id_token was present but "
                        "decoded to an empty payload — no identity claims available"
                    )
                else:
                    detail = "userinfo endpoint failed and no id_token was present in the token response"
            else:
                if id_token_attempted:
                    detail = (
                        "no userinfo endpoint is configured (GENERIC_USERINFO_ENDPOINT) "
                        "and id_token decoded to an empty payload — no identity claims available"
                    )
                else:
                    detail = "no userinfo endpoint is configured (GENERIC_USERINFO_ENDPOINT) and no id_token was present"
            raise ProxyException(
                message=f"SSO user info unavailable: {detail}.",
                type=ProxyErrorTypes.auth_error,
                param="userinfo",
                code=status.HTTP_401_UNAUTHORIZED,
            )

        return userinfo
|
||
|
|
|
||
|
|
|
||
|
|
class MicrosoftSSOHandler:
    """
    Handles Microsoft SSO callback response and returns a CustomOpenID object
    """

    graph_api_base_url = "https://graph.microsoft.com/v1.0"
    graph_api_user_groups_endpoint = f"{graph_api_base_url}/me/memberOf"

    # Constants
    # Safety cap on Graph API pagination so a pathological @odata.nextLink
    # chain cannot loop forever.
    MAX_GRAPH_API_PAGES = 200

    # used for debugging to show the user groups litellm found from Graph API
    GRAPH_API_RESPONSE_KEY = "graph_api_user_groups"

    @staticmethod
    async def get_microsoft_callback_response(
        request: Request,
        microsoft_client_id: str,
        redirect_url: str,
        return_raw_sso_response: bool = False,
    ) -> Union[CustomOpenID, OpenID, dict]:
        """
        Get the Microsoft SSO callback response

        Args:
            request: Incoming FastAPI request for the SSO callback.
            microsoft_client_id: Azure AD application (client) id.
            redirect_url: Redirect URI registered with the Azure AD app.
            return_raw_sso_response: If True, return the raw SSO response

        Raises:
            ProxyException: if MICROSOFT_CLIENT_SECRET or MICROSOFT_TENANT is
                missing from the environment.
        """
        microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None)
        microsoft_tenant = os.getenv("MICROSOFT_TENANT", None)
        if microsoft_client_secret is None:
            raise ProxyException(
                message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file",
                type=ProxyErrorTypes.auth_error,
                param="MICROSOFT_CLIENT_SECRET",
                code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )
        if microsoft_tenant is None:
            raise ProxyException(
                message="MICROSOFT_TENANT not set. Set it in .env file",
                type=ProxyErrorTypes.auth_error,
                param="MICROSOFT_TENANT",
                code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )
        microsoft_sso = CustomMicrosoftSSO(
            client_id=microsoft_client_id,
            client_secret=microsoft_client_secret,
            tenant=microsoft_tenant,
            redirect_uri=redirect_url,
            allow_insecure_http=True,
        )
        original_msft_result = (
            await microsoft_sso.verify_and_process(
                request=request,
                convert_response=False,  # type: ignore
            )
            or {}
        )

        user_team_ids = await MicrosoftSSOHandler.get_user_groups_from_graph_api(
            access_token=microsoft_sso.access_token
        )

        # Extract app roles from the id_token JWT
        app_roles = MicrosoftSSOHandler.get_app_roles_from_id_token(
            id_token=microsoft_sso.id_token
        )
        verbose_proxy_logger.debug(f"Extracted app roles from id_token: {app_roles}")

        # Map Entra ID app roles to a LiteLLM user role — first valid match wins
        user_role: Optional[LitellmUserRoles] = None
        if app_roles:
            # Check if any app role is a valid LitellmUserRoles
            for role_str in app_roles:
                role = get_litellm_user_role(role_str)
                if role is not None:
                    user_role = role
                    verbose_proxy_logger.debug(
                        f"Found valid LitellmUserRoles '{role.value}' in app_roles"
                    )
                    break

        verbose_proxy_logger.debug(
            f"Combined team_ids (groups + app roles): {user_team_ids}"
        )

        # if user is trying to get the raw sso response for debugging, return the raw sso response
        if return_raw_sso_response:
            original_msft_result[
                MicrosoftSSOHandler.GRAPH_API_RESPONSE_KEY
            ] = user_team_ids
            original_msft_result["app_roles"] = app_roles
            return original_msft_result or {}

        result = MicrosoftSSOHandler.openid_from_response(
            response=original_msft_result,
            team_ids=user_team_ids,
            user_role=user_role,
        )
        return result

    @staticmethod
    def openid_from_response(
        response: Optional[dict],
        team_ids: List[str],
        user_role: Optional[LitellmUserRoles],
    ) -> CustomOpenID:
        """
        Map a raw Microsoft callback payload to a CustomOpenID object.

        Args:
            response: Raw dict returned by the Microsoft SSO callback (may be None).
            team_ids: Group ids resolved via the Graph API.
            user_role: LiteLLM role resolved from Entra ID app roles, if any.

        Returns:
            CustomOpenID: Normalized identity object consumed by the SSO flow.
        """
        response = response or {}
        verbose_proxy_logger.debug(f"Microsoft SSO Callback Response: {response}")
        openid_response = CustomOpenID(
            # fall back to the "mail" attribute when the configured email
            # attribute is absent from the payload
            email=normalize_email(
                response.get(MICROSOFT_USER_EMAIL_ATTRIBUTE) or response.get("mail")
            ),
            display_name=response.get(MICROSOFT_USER_DISPLAY_NAME_ATTRIBUTE),
            provider="microsoft",
            id=response.get(MICROSOFT_USER_ID_ATTRIBUTE),
            first_name=response.get(MICROSOFT_USER_FIRST_NAME_ATTRIBUTE),
            last_name=response.get(MICROSOFT_USER_LAST_NAME_ATTRIBUTE),
            team_ids=team_ids,
            user_role=user_role,
        )
        verbose_proxy_logger.debug(f"Microsoft SSO OpenID Response: {openid_response}")
        return openid_response

    @staticmethod
    def get_app_roles_from_id_token(id_token: Optional[str]) -> List[str]:
        """
        Extract app roles from the Microsoft Entra ID (Azure AD) id_token JWT.

        App roles are assigned in the Azure AD Enterprise Application and appear
        in the 'app_roles' claim of the id_token.

        Args:
            id_token (Optional[str]): The JWT id_token from Microsoft SSO

        Returns:
            List[str]: List of app role names assigned to the user
        """
        if not id_token:
            verbose_proxy_logger.debug("No id_token provided for app role extraction")
            return []

        try:
            # Decode the JWT without signature verification
            # (signature is already verified by fastapi_sso).
            # NOTE: `jwt` (PyJWT) is imported at module level.
            decoded_token = jwt.decode(id_token, options={"verify_signature": False})

            # Extract app_roles claim from the token
            ## check for both 'roles' and 'app_roles' claims
            roles = decoded_token.get("app_roles", []) or decoded_token.get("roles", [])

            if roles and isinstance(roles, list):
                verbose_proxy_logger.debug(
                    f"Found {len(roles)} app role(s) in id_token: {roles}"
                )
                return roles
            else:
                verbose_proxy_logger.debug(
                    "No app roles found in id_token or roles claim is not a list"
                )
                return []

        except Exception as e:
            verbose_proxy_logger.error(f"Error extracting app roles from id_token: {e}")
            return []

    @staticmethod
    async def get_user_groups_from_graph_api(
        access_token: Optional[str] = None,
    ) -> List[str]:
        """
        Returns a list of `team_ids` the user belongs to from the Microsoft Graph API

        Args:
            access_token (Optional[str]): Microsoft Graph API access token

        Returns:
            List[str]: List of group IDs the user belongs to
        """
        try:
            async_client = get_async_httpx_client(
                llm_provider=httpxSpecialProvider.SSO_HANDLER
            )

            # Handle MSFT Enterprise Application Groups
            service_principal_id = os.getenv("MICROSOFT_SERVICE_PRINCIPAL_ID", None)
            # NOTE: these are always lists (never None) — annotated accordingly
            service_principal_group_ids: List[str] = []
            service_principal_teams: List[MicrosoftServicePrincipalTeam] = []
            if service_principal_id:
                (
                    service_principal_group_ids,
                    service_principal_teams,
                ) = await MicrosoftSSOHandler.get_group_ids_from_service_principal(
                    service_principal_id=service_principal_id,
                    async_client=async_client,
                    access_token=access_token,
                )
                verbose_proxy_logger.debug(
                    f"Service principal group IDs: {service_principal_group_ids}"
                )
                if len(service_principal_group_ids) > 0:
                    await MicrosoftSSOHandler.create_litellm_teams_from_service_principal_team_ids(
                        service_principal_teams=service_principal_teams,
                    )

            # Fetch user membership from Microsoft Graph API
            all_group_ids = []
            next_link: Optional[
                str
            ] = MicrosoftSSOHandler.graph_api_user_groups_endpoint
            auth_headers = {"Authorization": f"Bearer {access_token}"}
            page_count = 0

            # Follow @odata.nextLink pagination up to MAX_GRAPH_API_PAGES
            while (
                next_link is not None
                and page_count < MicrosoftSSOHandler.MAX_GRAPH_API_PAGES
            ):
                group_ids, next_link = await MicrosoftSSOHandler.fetch_and_parse_groups(
                    url=next_link, headers=auth_headers, async_client=async_client
                )
                all_group_ids.extend(group_ids)
                page_count += 1

            if (
                next_link is not None
                and page_count >= MicrosoftSSOHandler.MAX_GRAPH_API_PAGES
            ):
                verbose_proxy_logger.warning(
                    f"Reached maximum page limit of {MicrosoftSSOHandler.MAX_GRAPH_API_PAGES}. Some groups may not be included."
                )

            # If service_principal_group_ids is not empty, only return group_ids that are in both all_group_ids and service_principal_group_ids
            if service_principal_group_ids:
                all_group_ids = [
                    group_id
                    for group_id in all_group_ids
                    if group_id in service_principal_group_ids
                ]

            return all_group_ids

        except Exception as e:
            # best-effort: group resolution failure must not block login
            verbose_proxy_logger.error(
                f"Error getting user groups from Microsoft Graph API: {e}"
            )
            return []

    @staticmethod
    async def fetch_and_parse_groups(
        url: str, headers: dict, async_client: AsyncHTTPHandler
    ) -> Tuple[List[str], Optional[str]]:
        """Helper function to fetch and parse group data from a URL"""
        response = await async_client.get(url, headers=headers)
        response_json = response.json()
        response_typed = await MicrosoftSSOHandler._cast_graph_api_response_dict(
            response=response_json
        )
        group_ids = MicrosoftSSOHandler._get_group_ids_from_graph_api_response(
            response=response_typed
        )
        # second element is the @odata.nextLink pagination cursor (None on last page)
        return group_ids, response_typed.get("odata_nextLink")

    @staticmethod
    def _get_group_ids_from_graph_api_response(
        response: MicrosoftGraphAPIUserGroupResponse,
    ) -> List[str]:
        """Collect the non-None `id` of every directory object in the response."""
        group_ids = []
        for _object in response.get("value", []) or []:
            _group_id = _object.get("id")
            if _group_id is not None:
                group_ids.append(_group_id)
        return group_ids

    @staticmethod
    async def _cast_graph_api_response_dict(
        response: dict,
    ) -> MicrosoftGraphAPIUserGroupResponse:
        """Convert a raw Graph API JSON payload into the typed response shape."""
        directory_objects: List[MicrosoftGraphAPIUserGroupDirectoryObject] = []
        for _object in response.get("value", []):
            directory_objects.append(
                MicrosoftGraphAPIUserGroupDirectoryObject(
                    odata_type=_object.get("@odata.type"),
                    id=_object.get("id"),
                    deletedDateTime=_object.get("deletedDateTime"),
                    description=_object.get("description"),
                    displayName=_object.get("displayName"),
                    roleTemplateId=_object.get("roleTemplateId"),
                )
            )
        return MicrosoftGraphAPIUserGroupResponse(
            odata_context=response.get("@odata.context"),
            odata_nextLink=response.get("@odata.nextLink"),
            value=directory_objects,
        )

    @staticmethod
    async def get_group_ids_from_service_principal(
        service_principal_id: str,
        async_client: AsyncHTTPHandler,
        access_token: Optional[str] = None,
    ) -> Tuple[List[str], List[MicrosoftServicePrincipalTeam]]:
        """
        Gets the groups belonging to the Service Principal Application

        Service Principal Id is an `Enterprise Application` in Azure AD

        Users use Enterprise Applications to manage Groups and Users on Microsoft Entra ID
        """
        base_url = "https://graph.microsoft.com/v1.0"
        # Endpoint to get app role assignments for the given service principal
        endpoint = f"/servicePrincipals/{service_principal_id}/appRoleAssignedTo"
        url = base_url + endpoint

        headers = {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
        }

        response = await async_client.get(url, headers=headers)
        response_json = response.json()
        verbose_proxy_logger.debug(
            f"Response from service principal app role assigned to: {response_json}"
        )
        group_ids: List[str] = []
        service_principal_teams: List[MicrosoftServicePrincipalTeam] = []

        for _object in response_json.get("value", []):
            # only Group assignments matter here; user assignments are skipped
            if _object.get("principalType") == "Group":
                # Append the group ID to the list
                group_ids.append(_object.get("principalId"))
                # Append the service principal team to the list
                service_principal_teams.append(
                    MicrosoftServicePrincipalTeam(
                        principalDisplayName=_object.get("principalDisplayName"),
                        principalId=_object.get("principalId"),
                    )
                )

        return group_ids, service_principal_teams

    @staticmethod
    async def create_litellm_teams_from_service_principal_team_ids(
        service_principal_teams: List[MicrosoftServicePrincipalTeam],
    ):
        """
        Creates Litellm Teams from the Service Principal Group IDs

        When a user sets a `SERVICE_PRINCIPAL_ID` in the env, litellm will fetch groups under that service principal and create Litellm Teams from them
        """
        verbose_proxy_logger.debug(
            f"Creating Litellm Teams from Service Principal Teams: {service_principal_teams}"
        )
        for service_principal_team in service_principal_teams:
            litellm_team_id: Optional[str] = service_principal_team.get("principalId")
            litellm_team_name: Optional[str] = service_principal_team.get(
                "principalDisplayName"
            )
            if not litellm_team_id:
                verbose_proxy_logger.debug(
                    f"Skipping team creation for {litellm_team_name} because it has no principalId"
                )
                continue

            await SSOAuthenticationHandler.create_litellm_team_from_sso_group(
                litellm_team_id=litellm_team_id,
                litellm_team_name=litellm_team_name,
            )
|
||
|
|
class GoogleSSOHandler:
    """
    Handles Google SSO callback response and returns a CustomOpenID object
    """

    @staticmethod
    async def get_google_callback_response(
        request: Request,
        google_client_id: str,
        redirect_url: str,
        return_raw_sso_response: bool = False,
    ) -> Union[OpenID, dict]:
        """
        Get the Google SSO callback response

        Args:
            return_raw_sso_response: If True, return the raw SSO response
        """
        from fastapi_sso.sso.google import GoogleSSO

        # Client secret is mandatory — fail loudly if it was never configured.
        google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None)
        if google_client_secret is None:
            raise ProxyException(
                message="GOOGLE_CLIENT_SECRET not set. Set it in .env file",
                type=ProxyErrorTypes.auth_error,
                param="GOOGLE_CLIENT_SECRET",
                code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )

        sso_client = GoogleSSO(
            client_id=google_client_id,
            redirect_uri=redirect_url,
            client_secret=google_client_secret,
        )

        # if user is trying to get the raw sso response for debugging, return the raw sso response
        if return_raw_sso_response:
            raw_response = await sso_client.verify_and_process(
                request=request,
                convert_response=False,  # type: ignore
            )
            return raw_response or {}

        openid_result = await sso_client.verify_and_process(request)
        return openid_result or {}
|
||
|
|
@router.get("/sso/debug/login", tags=["experimental"], include_in_schema=False)
async def debug_sso_login(request: Request):
    """
    Create Proxy API Keys using Google Workspace SSO. Requires setting PROXY_BASE_URL in .env
    PROXY_BASE_URL should be the your deployed proxy endpoint, e.g. PROXY_BASE_URL="https://litellm-production-7002.up.railway.app/"
    Example:
    """
    from litellm.proxy.proxy_server import premium_user

    microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
    google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
    generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)

    ####### Check if user is a Enterprise / Premium User #######
    sso_configured = any(
        client_id is not None
        for client_id in (microsoft_client_id, google_client_id, generic_client_id)
    )
    if sso_configured and premium_user is not True:
        raise ProxyException(
            message="You must be a LiteLLM Enterprise user to use SSO. If you have a license please set `LITELLM_LICENSE` in your env. If you want to obtain a license meet with us here: https://calendly.com/d/cx9p-5yf-2nm/litellm-introductions You are seeing this error message because You set one of `MICROSOFT_CLIENT_ID`, `GOOGLE_CLIENT_ID`, or `GENERIC_CLIENT_ID` in your env. Please unset this",
            type=ProxyErrorTypes.auth_error,
            param="premium_user",
            code=status.HTTP_403_FORBIDDEN,
        )

    # get url from request
    redirect_url = SSOAuthenticationHandler.get_redirect_url_for_sso(
        request=request,
        sso_callback_route="sso/debug/callback",
    )

    # Check if we should use SSO handler
    use_sso = SSOAuthenticationHandler.should_use_sso_handler(
        microsoft_client_id=microsoft_client_id,
        google_client_id=google_client_id,
        generic_client_id=generic_client_id,
    )
    if use_sso is True:
        return await SSOAuthenticationHandler.get_sso_login_redirect(
            redirect_url=redirect_url,
            microsoft_client_id=microsoft_client_id,
            google_client_id=google_client_id,
            generic_client_id=generic_client_id,
        )
|
|
||
|
|
@router.get("/sso/debug/callback", tags=["experimental"], include_in_schema=False)
async def debug_sso_callback(request: Request):
    """
    Returns the OpenID object returned by the SSO provider

    Renders the raw identity payload in an HTML page (jwt_display_template)
    so operators can inspect exactly what their SSO provider sends back.
    """
    import json

    from fastapi.responses import HTMLResponse

    from litellm.proxy._types import LiteLLM_JWTAuth
    from litellm.proxy.auth.handle_jwt import JWTHandler
    from litellm.proxy.proxy_server import (
        general_settings,
        jwt_handler,
        prisma_client,
        user_api_key_cache,
    )

    # Optional JWT handler used by the generic SSO path when `ui_access_mode`
    # (dict form) configures a `sso_group_jwt_field` to read team ids from.
    sso_jwt_handler: Optional[JWTHandler] = None
    ui_access_mode = general_settings.get("ui_access_mode", None)
    if ui_access_mode is not None and isinstance(ui_access_mode, dict):
        sso_jwt_handler = JWTHandler()
        sso_jwt_handler.update_environment(
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            litellm_jwtauth=LiteLLM_JWTAuth(
                team_ids_jwt_field=general_settings.get("ui_access_mode", {}).get(
                    "sso_group_jwt_field", None
                ),
            ),
            leeway=0,
        )

    microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
    google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
    generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)

    # Build the debug-callback redirect URL from PROXY_BASE_URL,
    # falling back to the request's base URL.
    redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
    if redirect_url.endswith("/"):
        redirect_url += "sso/debug/callback"
    else:
        redirect_url += "/sso/debug/callback"

    # Dispatch to whichever provider is configured (Google > Microsoft > generic),
    # always requesting the raw (unconverted) provider payload.
    result = None
    if google_client_id is not None:
        result = await GoogleSSOHandler.get_google_callback_response(
            request=request,
            google_client_id=google_client_id,
            redirect_url=redirect_url,
            return_raw_sso_response=True,
        )
    elif microsoft_client_id is not None:
        result = await MicrosoftSSOHandler.get_microsoft_callback_response(
            request=request,
            microsoft_client_id=microsoft_client_id,
            redirect_url=redirect_url,
            return_raw_sso_response=True,
        )
    elif generic_client_id is not None:
        result, _ = await get_generic_sso_response(
            request=request,
            jwt_handler=jwt_handler,
            generic_client_id=generic_client_id,
            redirect_url=redirect_url,
            sso_jwt_handler=sso_jwt_handler,
        )

    # If result is None, return a basic error message
    if result is None:
        return HTMLResponse(
            content="<h1>SSO Authentication Failed</h1><p>No data was returned from the SSO provider.</p>",
            status_code=400,
        )

    # Convert the OpenID object to a dictionary
    if hasattr(result, "__dict__"):
        result_dict = result.__dict__
    else:
        result_dict = dict(result)

    # Filter out None values / private attributes and coerce the rest into a
    # JSON-serializable form. (The old `or value is None` branch was dead code:
    # None values are already skipped here.)
    filtered_result = {}
    for key, value in result_dict.items():
        if value is None or key.startswith("_"):
            continue
        if isinstance(value, (str, int, float, bool)):
            filtered_result[key] = value
        else:
            try:
                # Try to convert to string or another JSON serializable format
                filtered_result[key] = str(value)
            except Exception as e:
                filtered_result[key] = f"Complex value (not displayable): {str(e)}"

    # Replace the placeholder in the template with the actual data
    html_content = jwt_display_template.replace(
        "const userData = SSO_DATA;",
        f"const userData = {json.dumps(filtered_result, indent=2)};",
    )

    return HTMLResponse(content=html_content)